diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 176c7e576a..4a6e51b8aa 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -1,4 +1,5 @@
## common setting
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_BINARY_DIR})
link_directories(${CMAKE_SOURCE_DIR}/build/mindspore/graphengine)
@@ -35,20 +36,20 @@ if(ENABLE_GPU)
    include_directories(${CUDNN_PATH} ${CUDA_PATH} ${CUDA_INCLUDE_DIRS})
    file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
-        "device/gpu/*.cc"
-        "device/gpu/*.cu"
-        "kernel/gpu/*.cu"
-        "kernel/akg/gpu/*.cc"
-        "kernel/akg/akg_kernel_build.cc"
-        "kernel/akg/akg_kernel_attrs_process.cc"
+        "runtime/device/gpu/*.cc"
+        "runtime/device/gpu/*.cu"
+        "backend/kernel_compiler/gpu/*.cu"
+        "backend/kernel_compiler/akg/gpu/*.cc"
+        "backend/kernel_compiler/akg/akg_kernel_build.cc"
+        "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
    )
    list(APPEND CUDA_NVCC_FLAGS -arch=sm_53)
-    list(REMOVE_ITEM GPU_SRC_LIST "device/gpu/blocking_queue.cc" "device/gpu/gpu_buffer_mgr.cc")
-    list(REMOVE_ITEM GPU_SRC_LIST "device/gpu/mpi/mpi_initializer.cc"
-        "device/gpu/distribution/collective_wrapper.cc"
-        "device/gpu/distribution/mpi_wrapper.cc"
-        "device/gpu/distribution/nccl_wrapper.cc"
+    list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc")
+    list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc"
+        "runtime/device/gpu/distribution/collective_wrapper.cc"
+        "runtime/device/gpu/distribution/mpi_wrapper.cc"
+        "runtime/device/gpu/distribution/nccl_wrapper.cc"
    )
    set(NVCC_TMP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
@@ -101,15 +102,15 @@ if (ENABLE_DUMP_PROTO)
endif ()

if (ENABLE_D)
-    include_directories("${CMAKE_BINARY_DIR}/kernel/aicpu")
+    include_directories("${CMAKE_BINARY_DIR}/backend/kernel_compiler/aicpu")
    include_directories("${CMAKE_BINARY_DIR}/predict/generator/ir")
-    file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "kernel/aicpu/proto/*.proto")
+    file(GLOB_RECURSE PROTO_IN RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "backend/kernel_compiler/aicpu/proto/*.proto")
    ms_protobuf_generate(PROTOSRCS PROTOHDRS ${PROTO_IN})

    file(GLOB_RECURSE PROTO_INNER RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "predict/proto/*.proto")
    ms_protobuf_generate(PREDICT_PROTOSRCS PREDICT_PROTOHDRS ${PROTO_INNER})

-    file(GLOB_RECURSE PROTO_DUMP RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "device/ascend/dump/proto/*.proto")
+    file(GLOB_RECURSE PROTO_DUMP RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "runtime/device/ascend/dump/proto/*.proto")
    ms_protobuf_generate(DUMP_PROTOSRCS PROTOHDRS ${PROTO_DUMP})

    list(APPEND MINDSPORE_PROTO_LIST ${PROTOSRCS})
@@ -125,18 +126,32 @@ if (MINDSPORE_PROTO_LIST)
endif()

## make sub objects
-set(SUB_COMP
-    transform pre_activate parallel pipeline device kernel common debug gvar ir onnx operator optimizer predict
-    pybind_api pynative session utils vm base abstract
+set(SUB_COMP
+    transform/graph_ir
+    transform/onnx
+    backend/optimizer
+    backend/kernel_compiler
+    backend/session
+    runtime/device
+    frontend/optimizer
+    frontend/parallel
+    frontend/operator
+    pipeline/jit
+    pipeline/pynative
+    common debug gvar predict pybind_api utils vm base abstract
)
foreach (_comp ${SUB_COMP})
    add_subdirectory(${_comp})
-    if (TARGET _mindspore_${_comp}_obj)
-        list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_${_comp}_obj>)
-        add_dependencies(_mindspore_${_comp}_obj proto_input flat_input)
+    string(REPLACE "/" "_" sub ${_comp})
+    if (TARGET _mindspore_${sub}_obj)
+        list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_${sub}_obj>)
+        add_dependencies(_mindspore_${sub}_obj proto_input flat_input)
    endif ()
endforeach ()
+add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir)
+list(APPEND SUB_OBJECTS_SRC $<TARGET_OBJECTS:_mindspore_ir_obj>)
+add_dependencies(_mindspore_ir_obj proto_input flat_input)
set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME)
add_library(mindspore STATIC ${SUB_OBJECTS_SRC})
@@ -207,8 +222,8 @@ endif()

# set c_expression building
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
-set_property(SOURCE "pipeline/init.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE)
-pybind11_add_module(_c_expression "pipeline/init.cc")
+set_property(SOURCE "pipeline/jit/init.cc" PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE)
+pybind11_add_module(_c_expression "pipeline/jit/init.cc")
MESSAGE(STATUS "operation system is ${CMAKE_SYSTEM}")

if (CMAKE_SYSTEM_NAME MATCHES "Linux")
@@ -265,8 +280,8 @@ if (ENABLE_CPU)
endif ()

if (ENABLE_MINDDATA)
-    add_subdirectory(mindrecord)
-    add_subdirectory(dataset)
+    add_subdirectory(minddata/mindrecord)
+    add_subdirectory(minddata/dataset)
endif ()

# build inference
@@ -275,7 +290,7 @@ set(LOAD_ONNX_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/utils/load_onnx/anf_model_parser.cc
)
add_library(inference SHARED
-    ${CMAKE_CURRENT_SOURCE_DIR}/session/session.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/session.cc
    ${LOAD_ONNX_SRC}
)
target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
new file mode 100644
index 0000000000..b412d83d11
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt
@@ -0,0 +1,66 @@
+file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+    "kernel_build_info.cc"
+    "kash/*.cc"
+    "common_utils.cc"
+    "oplib/*.cc"
+)
+
+if (ENABLE_D)
+    file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+        "kernel_query.cc"
+        "kernel_fusion.cc"
+        "akg/ascend/*.cc"
+        "akg/akg_kernel_build.cc"
+        "akg/akg_kernel_attrs_process.cc"
+        "akg/akg_kernel_metadata.cc"
+        "tbe/*.cc"
+        "aicpu/*.cc"
+        "rts/*.cc"
+        "hccl/*.cc"
+    )
+    add_compile_definitions(ENABLE_D)
+endif ()
+
+if (ENABLE_CPU)
+    file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+        "cpu/*.cc"
+    )
+
+    list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc"
+        "cpu/ps/pull_kernel.cc"
+        "cpu/ps/embedding_look_up_ps_kernel.cc"
+        "cpu/ps/embedding_look_up_proxy_kernel.cc"
+        "cpu/ps/apply_momentum_ps_kernel.cc"
+        "cpu/ps/sparse_apply_adam_ps_kernel.cc"
+        "cpu/ps/sparse_apply_ftrl_ps_kernel.cc")
+
+    if (NOT ENABLE_MPI)
+        list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc")
+        list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc")
+        list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc")
+    endif ()
+endif ()
+
+if (ENABLE_GPU)
+    file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
+        "gpu/*.cu"
+        "akg/gpu/*.cc"
+        "akg/akg_kernel_build.cc"
+        "akg/akg_kernel_attrs_process.cc"
+    )
+
+    file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
+    list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_gpu_kernel.cc")
+
+    if (ENABLE_MPI)
+        include(ExternalProject)
+        file(GLOB_RECURSE GPU_NCCL_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/nccl/*.cc")
+        list(APPEND GPU_SRC_LIST ${GPU_NCCL_LIST})
+    endif ()
+
+    # add_library(_mindspore_kernel_cuda_obj OBJECT ${CUDA_SRC_LIST})
+endif()
+
+set_property(SOURCE ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST}
+    PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_KERNEL)
+add_library(_mindspore_backend_kernel_compiler_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST})
diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc
new file mode 100644
index 0000000000..7e7fd20f39
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc
@@ -0,0 +1,312 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "runtime/device/kernel_runtime.h"
+#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h"
+#include "backend/kernel_compiler/akg/akg_kernel_build.h"
+#include "proto/tensor.pb.h"
+#include "proto/tensor_shape.pb.h"
+#include "proto/attr.pb.h"
+#include "proto/node_def.pb.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+#include "backend/kernel_compiler/aicpu/aicpu_util.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>;
+
+bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num,
+                   std::vector<size_t> *input_size_list) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(input_size_list);
+  for (size_t i = 0; i < input_num; i++) {
+    std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
+    if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) {
+      if (!anf_node->isa<CNode>()) {
+        MS_LOG(EXCEPTION) << "anf_node is not CNode.";
+      }
+      auto cnode = anf_node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (cnode->inputs().size() < (i + 1)) {
+        MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1;
+        return false;
+      }
+      auto input_node = cnode->inputs()[i + 1];
+      MS_EXCEPTION_IF_NULL(input_node);
+      if (input_node->isa<ValueNode>()) {
+        auto value_ptr = GetValueNode(input_node);
+        auto value = GetValue<std::string>(value_ptr);
+        input_size_list->push_back(value.size());
+      }
+    } else {
+      auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
+      MS_EXCEPTION_IF_NULL(type_ptr);
+      int64_t size_i = 1;
+      for (size_t j = 0; j < shape_i.size(); j++) {
+        size_i = LongMulWithOverflowCheck(size_i, static_cast<int64_t>(shape_i[j]));
+      }
+      size_t type_byte = GetTypeByte(type_ptr);
+      if (type_byte == 0) {
+        return false;
+      }
+      size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
+      input_size_list->push_back(LongToSize(size_i));
+    }
+  }
+  return true;
+}
+
+bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const
std::shared_ptr &kernel_mod_ptr) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + std::vector input_size_list; + std::vector output_size_list; + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + if (!SetIOIputSize(anf_node, input_num, &input_size_list)) { + return false; + } + kernel_mod_ptr->SetInputSizeList(input_size_list); + + for (size_t i = 0; i < output_num; i++) { + std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); + TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); + MS_EXCEPTION_IF_NULL(type_ptr); + int64_t size_i = 1; + for (size_t j = 0; j < shape_i.size(); j++) { + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); + } + size_t type_byte = GetTypeByte(type_ptr); + if (type_byte == 0) { + return false; + } + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); + output_size_list.push_back(LongToSize(size_i)); + } + kernel_mod_ptr->SetOutputSizeList(output_size_list); + return true; +} + +void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, + ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { + MS_EXCEPTION_IF_NULL(node_attr); + MS_EXCEPTION_IF_NULL(value); + if (type == "int") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_i(attr_value); + } else if (type == "str") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_s(attr_value); + } else if (type == "bool") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_b(attr_value); + } else if (type == "float") { + auto attr_value = GetValue(value); + (*node_attr)[attr_name].set_f(attr_value); + } else if (type == "listInt") { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == "Int32") { + int data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + mindspore::AttrValue input_shape_attr; + mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); + MS_EXCEPTION_IF_NULL(input_shape_attr_list); + for (const auto shape : attr_value) { + input_shape_attr_list->add_i(shape); + } + (*node_attr)[attr_name] = input_shape_attr; + } else { + MS_LOG(EXCEPTION) << "type: " << type << "not support"; + } +} + +void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(proto); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + if (op_name == kPrint) { + return; + } + + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); + MS_EXCEPTION_IF_NULL(op_info_ptr); + auto attrs_ptr = op_info_ptr->attrs_ptr(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); + for (const auto &attr_ptr : attrs_ptr) { + MS_EXCEPTION_IF_NULL(attr_ptr); + std::string attr_name = attr_ptr->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + if (attr_name == kQueueName || attr_name == kSharedName) { + attr_name = kChannelName; + } else if (attr_name == kSeed0) { + attr_name = kSeed; + } else if (attr_name == kSeed1) { + attr_name 
= kSeed2; + } + std::string type = attr_ptr->type(); + ParseAttrValue(type, attr_name, value, node_attr); + } + } + MS_LOG(INFO) << "Set node attr end!"; +} + +void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + if (input_num == 0) { + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; + return; + } + + for (size_t input_index = 0; input_index < input_num; input_index++) { + ::mindspore::Tensor *node_inputs = proto->add_inputs(); + MS_EXCEPTION_IF_NULL(node_inputs); + TypeId input_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index); + std::vector input_shape; + int32_t input_data_type; + if (input_type == kObjectTypeString) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = cnode->inputs()[input_index + 1]; + auto value_ptr = GetValueNode(input_node); + auto value = GetValue(value_ptr); + input_shape.push_back(1); + input_shape.push_back(value.size()); + input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown); + } else { + input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); + input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); + } + + mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); + for (auto item : input_shape) { + mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); + dim->set_size((::google::protobuf::int64)item); + } + node_inputs->set_tensor_type((mindspore::DataType)input_data_type); + node_inputs->set_mem_device("HBM"); + } +} + +void SetNodeOutputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(proto); + MS_EXCEPTION_IF_NULL(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + if (output_num == 0) { + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. 
"; + return; + } + + for (size_t output_index = 0; output_index < output_num; output_index++) { + ::mindspore::Tensor *node_outputs = proto->add_outputs(); + MS_EXCEPTION_IF_NULL(node_outputs); + std::vector output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); + mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); + MS_EXCEPTION_IF_NULL(tensorShape); + for (auto item : output_shape) { + mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); + MS_EXCEPTION_IF_NULL(dim); + dim->set_size((::google::protobuf::int64)item); + } + TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); + int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); + node_outputs->set_tensor_type((mindspore::DataType)output_data_type); + node_outputs->set_mem_device("HBM"); + } +} + +void SetNodedefProto(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(proto); + MS_LOG(INFO) << "SetNodedefProto entry"; + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + // set op name + proto->set_op(op_name); + // set inputs tensor + SetNodeInputs(anf_node, proto); + // set outputs tensor + SetNodeOutputs(anf_node, proto); + // set node attr + SetNodeAttr(anf_node, proto); + MS_LOG(INFO) << "SetNodedefProto end!"; +} + +bool CreateNodeDefBytes(const std::shared_ptr &anf_node, + const std::shared_ptr &kernel_mod_ptr) { + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "CreateNodeDefBytes entry"; + + mindspore::NodeDef proto; + SetNodedefProto(anf_node, &proto); + std::string nodeDefStr; + if (!proto.SerializeToString(&nodeDefStr)) { + MS_LOG(ERROR) << "Serialize nodeDef to string failed."; + return false; + } + kernel_mod_ptr->SetNodeDef(nodeDefStr); + MS_LOG(INFO) << "CreateNodeDefBytes end!"; + return true; +} + +KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + auto kernel_mod_ptr = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetAnfNode(anf_node); + kernel_mod_ptr->SetNodeName(op_name); + if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { + MS_LOG(EXCEPTION) << "Create nodeDefBytes faild!"; + } + if (!SetIOSize(anf_node, kernel_mod_ptr)) { + MS_LOG(EXCEPTION) << "Set input output size list failed."; + } + return kernel_mod_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h new file mode 100644 index 0000000000..6e2ee3959b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AicpuOpBuild(const std::shared_ptr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc new file mode 100644 index 0000000000..76c29b9f5c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" +#include +#include +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_LOG(INFO) << "AicpuMetadataInfo."; + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + if (op_name == kInitDataSetQueue) { + op_name = kInitData; + } + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); + if (op_info_ptr == nullptr) { + MS_LOG(DEBUG) << "Aicpu does not have op [" << op_name << "]"; + return; + } + // For compatibility with the current framework + if (op_name == kPrint || op_name == kGetNext || op_name == kPack) { + std::vector inputs_format{}; + std::vector inputs_type{}; + if (op_name == kPrint || op_name == kPack) { + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(kOpFormat_DEFAULT); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); + } + } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(kOpFormat_DEFAULT); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); + } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetProcessor(AICPU); + builder.SetKernelType(AICPU_KERNEL); + builder.SetFusionType(OPAQUE); + kernel_info_list->push_back(builder.Build()); + return; + } + if 
(!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { + MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; + return; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h new file mode 100644 index 0000000000..e21f4eace4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ + +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc new file mode 100644 index 0000000000..e18b3169f3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
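
AicpuMetadataInfo above deliberately returns a list of candidate KernelBuildInfo objects rather than choosing one. A self-contained sketch of the selection a caller might then perform; the stub types stand in for TypeId and KernelBuildInfo, and the framework's real kernel selection also weighs formats:

```cpp
#include <memory>
#include <vector>

using TypeId = int;                                         // stand-in for mindspore::TypeId
struct BuildInfoStub { std::vector<TypeId> input_types; };  // stand-in for KernelBuildInfo

std::shared_ptr<BuildInfoStub> PickFirstMatching(const std::vector<std::shared_ptr<BuildInfoStub>> &candidates,
                                                 TypeId wanted) {
  for (const auto &info : candidates) {
    if (info != nullptr && !info->input_types.empty() && info->input_types[0] == wanted) {
      return info;  // first candidate whose leading input dtype matches
    }
  }
  return nullptr;  // no AICPU variant fits; the caller falls back to another processor
}
```
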
+ */ + +#include "backend/kernel_compiler/aicpu/aicpu_kernel_mod.h" + +#include +#include +#include +#include + +#include "runtime/mem.h" +#include "runtime/rt.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include "utils/convert_utils.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "utils/context/ms_context.h" + +using AicpuTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +constexpr auto AICPU_OPS_SO_NAME = "libaicpu_kernels.so"; + +AicpuOpKernelMod::AicpuOpKernelMod() : anf_node_(nullptr) {} + +AicpuOpKernelMod::~AicpuOpKernelMod() { + args_.clear(); + inputList_.clear(); + outputList_.clear(); + anf_node_ = nullptr; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +void AicpuOpKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetInputSizeList() const { return input_size_list_; } +void AicpuOpKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetOutputSizeList() const { return output_size_list_; } +void AicpuOpKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } +const std::vector &AicpuOpKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } +void AicpuOpKernelMod::SetInputList(const std::vector &inputList) { inputList_ = inputList; } +void AicpuOpKernelMod::SetOutputList(const std::vector &outputList) { outputList_ = outputList; } +void AicpuOpKernelMod::SetNodeDef(const std::string &nodeDef) { (void)node_def_str_.assign(nodeDef); } +void AicpuOpKernelMod::SetNodeName(const std::string &node_name) { node_name_ = node_name; } +void AicpuOpKernelMod::SetAnfNode(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + anf_node_ = anf_node; +} + +void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs, + const std::vector &outputs) { + MS_LOG(INFO) << "CreateCpuKernelInfoOffline start"; + + node_so_ = AICPU_OPS_SO_NAME; + + // InputOutputAddr + vector io_addrs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(io_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(io_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + auto io_addrs_num = io_addrs.size(); + // calculate paramLen: AicpuParamHead.len + ioAddrsSize + notifyId.len + customizedAttr.len + auto param_len = sizeof(AicpuParamHead); + + // get input and output addrs size, no need to check overflow + auto io_addrs_size = io_addrs_num * sizeof(uint64_t); + // refresh paramLen, no need to check overflow + param_len += io_addrs_size; + + auto node_def_len = node_def_str_.length(); + param_len += node_def_len; + + // Create taskArgs: AicpuParamHead + ioAddrs + notifyId + customizedAttr + AicpuParamHead paramHead = {static_cast(param_len), static_cast(io_addrs_num)}; + args_.clear(); + (void)args_.append(reinterpret_cast(¶mHead), sizeof(AicpuParamHead)); + // TaskArgs append ioAddrs + if (io_addrs_size != 0) { + (void)args_.append(reinterpret_cast(io_addrs.data()), io_addrs_size); + } + + // When it's aicpu customized ops, taskArgs should append customized attr + if (node_def_len != 0) { + (void)args_.append(reinterpret_cast(node_def_str_.data()), node_def_len); + } + + MS_LOG(INFO) << "CreateCpuKernelInfoOffline 
end"; +} + +bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + CreateCpuKernelInfo(inputs, outputs); + if (node_name_ == kTopK) { + node_name_ = kTopKV2; + } + MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ + << ", args_size:" << args_.length(); + if (rtCpuKernelLaunch(reinterpret_cast(node_so_.c_str()), + reinterpret_cast(node_name_.c_str()), 1, + reinterpret_cast(args_.data()), static_cast(args_.length()), nullptr, + stream_ptr) != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Aicpu op launch failed!"; + + return false; + } + return true; +} + +std::vector AicpuOpKernelMod::GenTask(const std::vector &inputs, + const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + MS_LOG(INFO) << "AicpuOpKernelMod GenTask start"; + + stream_id_ = stream_id; + node_so_ = AICPU_OPS_SO_NAME; + std::vector input_data_addrs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + + std::vector output_data_addrs; + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + if (node_name_ == kTopK) { + node_name_ = kTopKV2; + } + + AicpuTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs, NeedDump()); + + MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h new file mode 100644 index 0000000000..82260010ea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +namespace mindspore { +namespace kernel { +class AicpuOpKernelMod : public AscendKernelMod { + public: + AicpuOpKernelMod(); + ~AicpuOpKernelMod() override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + void SetInputList(const std::vector &inputList); + void SetOutputList(const std::vector &outputList); + void SetAnfNode(const AnfNodePtr &anf_node); + void SetNodeDef(const std::string &nodeDef); + void SetNodeName(const std::string &node_name); + + /** + * @brief Build AICPU Engine kernel structure, and allocate device memory for offline task generate + * @return SUCCESS + * @return FAIL + * + */ + void CreateCpuKernelInfo(const std::vector &inputs, const std::vector &outputs); + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + void SetWorkspaceSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + + private: + std::string args_; + std::string node_def_str_; + std::string node_name_; + std::string node_so_; + std::vector inputList_; + std::vector outputList_; + AnfNodePtr anf_node_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using AicpuOpKernelModPtr = std::shared_ptr; +using AicputOpKernelModPtrList = std::vector; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc new file mode 100644 index 0000000000..790319daa6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
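
The size-list accessors declared above are the contract the runtime builds on: device buffers are sized from GetInputSizeList/GetOutputSizeList before Launch ever runs, which is why SetIOSize must fill the lists at build time. A standalone sketch of that contract with stand-in types; malloc stands in for the device allocator:

```cpp
#include <cstdlib>
#include <memory>
#include <vector>

struct Address { void *addr; size_t size; };  // stand-in for the AddressPtr pointee used in Launch
using AddressPtr = std::shared_ptr<Address>;

std::vector<AddressPtr> AllocateFor(const std::vector<size_t> &size_list) {
  std::vector<AddressPtr> buffers;
  for (size_t size : size_list) {  // one pre-sized buffer per declared input/output
    buffers.push_back(std::make_shared<Address>(Address{std::malloc(size), size}));
  }
  return buffers;
}
```
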
+ */
+#include "backend/kernel_compiler/aicpu/aicpu_util.h"
+#include
+#include
+#include "proto/types.pb.h"
+#include "runtime/mem.h"
+#include "runtime/rt.h"
+#include "utils/convert_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+static std::map<mindspore::TypeId, mindspore::DataType> MS_PROTO_DATA_TYPE_MAP = {
+  {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN},
+  {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL},
+  {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32},
+  {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8},
+  {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16},
+  {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32},
+  {mindspore::TypeId::kNumberTypeInt64, mindspore::DataType::MS_INT64},
+  {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32},
+  {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8},
+  {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16},
+  {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32},
+  {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64},
+  {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16},
+  {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32},
+  {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32},
+  {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
+};
+
+int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
+  auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type);
+  if (iter != MS_PROTO_DATA_TYPE_MAP.end()) {
+    return iter->second;
+  } else {
+    MS_LOG(ERROR) << "Unsupported ms_type value " << static_cast<int>(ms_type);
+    return -1;
+  }
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
new file mode 100644
index 0000000000..fd4495afeb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
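
The lookup in MsTypeToProtoType above guards with find() and hands back iter->second. A minimal standalone illustration of why that idiom is preferred over indexing the map a second time:

```cpp
#include <iostream>
#include <map>

int main() {
  std::map<int, int> type_map = {{1, 10}};
  auto iter = type_map.find(2);
  // find() leaves the map untouched on a miss; operator[] would default-insert a
  // zero entry as a side effect, and re-indexing after a successful find() is
  // simply a redundant second lookup.
  std::cout << (iter == type_map.end() ? -1 : iter->second) << std::endl;  // prints -1
  std::cout << type_map.size() << std::endl;                               // still 1
  return 0;
}
```
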
+ */ +#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +constexpr auto kInitDataSetQueue = "InitDataSetQueue"; +constexpr auto kInitData = "InitData"; +constexpr auto kGetNext = "GetNext"; +constexpr auto kPrint = "Print"; +constexpr auto kPack = "Pack"; +constexpr auto kOutputTypes = "output_types"; +constexpr auto kOutputShapes = "output_shapes"; +constexpr auto kChannelName = "channel_name"; +constexpr auto kSharedName = "shared_name"; +constexpr auto kShapes = "shapes"; +constexpr auto kTypes = "types"; +constexpr auto kQueueName = "queue_name"; +constexpr auto kSeed = "seed"; +constexpr auto kSeed0 = "Seed0"; +constexpr auto kSeed1 = "Seed1"; +constexpr auto kSeed2 = "seed2"; +constexpr auto kTopK = "TopK"; +constexpr auto kTopKV2 = "TopKV2"; + +struct AicpuParamHead { + uint32_t length; // Total length: include cunstom message + uint32_t ioAddrNum; // Input and output address number + uint32_t extInfoLength; // extInfo struct Length + uint64_t extInfoAddr; // extInfo address +} __attribute__((packed)); + +class AicpuOpUtil { + public: + static int MsTypeToProtoType(TypeId ms_type); + + private: + // kernel id + static uint64_t KernelId_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/proto/attr.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/attr.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/attr.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/attr.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/node_def.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/node_def.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/node_def.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/node_def.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/tensor.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/tensor.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/tensor_shape.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor_shape.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/tensor_shape.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/tensor_shape.proto diff --git a/mindspore/ccsrc/kernel/aicpu/proto/types.proto b/mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/types.proto similarity index 100% rename from mindspore/ccsrc/kernel/aicpu/proto/types.proto rename to mindspore/ccsrc/backend/kernel_compiler/aicpu/proto/types.proto diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc new file mode 100644 index 0000000000..73fdb5c11b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.cc @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
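
The __attribute__((packed)) on AicpuParamHead above pins the byte layout the device-side runtime reads. A standalone check of what packing buys: 4 + 4 + 4 + 8 = 20 bytes, where the natural alignment of the uint64_t would otherwise pad the struct to 24:

```cpp
#include <cstdint>

struct PackedHead {  // field-for-field mirror of AicpuParamHead, for illustration only
  uint32_t length;
  uint32_t ioAddrNum;
  uint32_t extInfoLength;
  uint64_t extInfoAddr;
} __attribute__((packed));

static_assert(sizeof(PackedHead) == 20, "an unpadded 20-byte head is what the device side expects");

int main() { return 0; }
```
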
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" + +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace kernel { +void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // The x and output are akg op input and output param. + std::vector input_names = {"x"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); + + TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + std::string dst_type; + if (dst_type_id == kFloat32->type_id()) { + dst_type = "float32"; + } else if (dst_type_id == kFloat16->type_id()) { + dst_type = "float16"; + } + AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names = {"x"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); + if (origin_shape.size() != kShape4dDims) { + MS_LOG(EXCEPTION) << "The dim of origin_shape is not equal to 4, but it's dim is " << origin_shape.size() << "."; + } + std::vector shape_transform; + (void)std::transform(origin_shape.begin(), origin_shape.end(), std::back_inserter(shape_transform), + [](const int &origin_shape) { return static_cast(origin_shape); }); + AnfAlgo::SetNodeAttr("shape4d", MakeValue(shape_transform), anf_node); + AnfAlgo::SetNodeAttr("output_format", MakeValue(kOpFormat_NCHW), anf_node); + + TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + std::string dst_type; + if (dst_type_id == kFloat32->type_id()) { + dst_type = "float32"; + } else if (dst_type_id == kFloat16->type_id()) { + dst_type = "float16"; + } + AnfAlgo::SetNodeAttr("dstType", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // The x and output are akg op input and output param. 
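+  // (Editorial note) Unlike Four2Five/Five2Four above, Cast also lists the target
+  // type as a second logical input, "dst_type", alongside the data input "x".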
+ std::vector input_names = {"x", "dst_type"}; + std::vector output_names = {"output"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); + + std::string dst_type; + TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); + if (output_type == kFloat32->type_id()) { + dst_type = "float32"; + } else if (output_type == kFloat16->type_id()) { + dst_type = "float16"; + } else if (output_type == kInt32->type_id()) { + dst_type = "int32"; + } else { + MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); + } + AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); +} + +void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dy", "data", "mean"}; + std::vector output_names{"dgamma_red_hw", "dbeta_red_hw", "data_minus_mean"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); +} + +void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node) { + const size_t kBNGrad2InputSize = 5; + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dgamma_red_hw", "dbeta_red_hw", "variance", "gamma"}; + std::vector output_names{"bn_scale", "bn_bias", "rs", "dgamma_dx", "dbeta_dx"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() < kBNGrad2InputSize) { + MS_LOG(EXCEPTION) << "The inputs size of BNGrad2 is less then " << kBNGrad2InputSize; + } + auto input1 = cnode->input(1); + MS_EXCEPTION_IF_NULL(input1); + auto tuple_getitem = input1->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (tuple_getitem->inputs().size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The inputs size of tuple_getitem is less then " << kTupleGetItemInputSize; + } + auto bn_grad1 = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + std::vector data_shape = AnfAlgo::GetInputDeviceShape(bn_grad1, 0); + AnfAlgo::SetNodeAttr(kAttrDataShape, MakeValue(opt::Convert2Int(data_shape)), anf_node); +} + +void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector input_names{"dy", "rs", "dgamma_dx", "dbeta_dx", "data_minus_mean"}; + std::vector output_names{"dx"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); +} + +void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn1 + std::vector fused_bn1_input_names{"data"}; + std::vector fused_bn1_output_names{"mean", "var_part"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn1_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn1_output_names), anf_node); +} + +void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn2 + std::vector fused_bn2_input_names{"mean", "var_part", "running_mean", "running_var"}; + std::vector fused_bn2_output_names{"variance", "running_mean", "running_variance"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn2_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn2_output_names), anf_node); +} + +void 
SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + // Set attr for fused_bn3 + std::vector fused_bn3_input_names{"data", "mean", "variance", "gamma", "beta"}; + std::vector fused_bn3_output_names{"y"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn3_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn3_output_names), anf_node); +} + +void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector conv_bn1_output_names{"data", "var_part", "mean"}; + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(conv_bn1_output_names), anf_node); +} + +void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector bn2_add_relu_input_names{"data", "var_part", "mean", "other_branch_data", + "gamma", "beta", "running_mean", "running_var"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_add_relu_input_names), anf_node); + std::vector bn2_add_relu_output_names{"output", "running_mean", "running_variance", "save_inv_variance"}; + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_add_relu_output_names), anf_node); +} + +void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector bn2_input_names{"data", "var_part", "mean", "gamma", "beta", "running_mean", "running_var"}; + std::vector bn2_output_names{"y", "running_mean", "running_variance", "save_inv_variance"}; + AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node); + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h new file mode 100644 index 0000000000..9ba724db42 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_attrs_process.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
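
Each setter above repeats the same pair of SetNodeAttr calls for the input and output name lists. A hypothetical consolidation, not part of this patch, written against the AnfAlgo API the file already uses:

```cpp
// Sketch under the file's existing includes; SetIONameAttrs itself is hypothetical.
void SetIONameAttrs(const AnfNodePtr &anf_node, const std::vector<std::string> &input_names,
                    const std::vector<std::string> &output_names) {
  MS_EXCEPTION_IF_NULL(anf_node);
  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node);
  AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node);
}
```
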
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H +#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H + +#include +#include +#include +#include +#include "ir/anf.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace kernel { +void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node); +void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node); +void SetAkgAttrsForCast(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node); +void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node); +void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node); +void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node); +void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node); +void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node); + +const std::unordered_map> kAkgKernelAttrsProcessMap = { + {kFour2FiveOpName, SetAkgAttrsForFour2Five}, + {kFive2FourOpName, SetAkgAttrsForFive2Four}, + {"Cast", SetAkgAttrsForCast}, + {kBNGrad1OpName, SetAkgAttrsForBNGrad1}, + {kBNGrad2OpName, SetAkgAttrsForBNGrad2}, + {kBNGrad3OpName, SetAkgAttrsForBNGrad3}, + {kFusedBN1OpName, SetAkgAttrsForFusedBN1}, + {kFusedBN2OpName, SetAkgAttrsForFusedBN2}, + {kFusedBN3OpName, SetAkgAttrsForFusedBN3}, + {kConvBN1OpName, SetAkgAttrsForConvBN1}, + {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, + {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc new file mode 100644 index 0000000000..9c13629b1b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.cc @@ -0,0 +1,623 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
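
kAkgKernelAttrsProcessMap above is a plain name-to-handler registry. A minimal sketch of how a build step would consult it before emitting an AKG kernel; the caller name is illustrative, not an API introduced by this patch:

```cpp
// Assumes this header's declarations; ProcessAkgKernelAttrs is a hypothetical caller.
void ProcessAkgKernelAttrs(const AnfNodePtr &anf_node, const std::string &op_name) {
  auto iter = kAkgKernelAttrsProcessMap.find(op_name);
  if (iter != kAkgKernelAttrsProcessMap.end()) {
    iter->second(anf_node);  // run the op-specific attribute setter
  }
}
```
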
+ */ + +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "utils/any.h" +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" + +namespace mindspore { +namespace kernel { +constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; +constexpr int32_t ARGS_SIZE = 1; +constexpr auto kCompileWithJsonFunc = "compilewithjson"; + +// json key +constexpr auto kOpDesc = "op_desc"; +constexpr auto kInputDesc = "input_desc"; +constexpr auto kShape = "shape"; +constexpr auto kDataType = "data_type"; +constexpr auto kOutputDesc = "output_desc"; +constexpr auto kName = "name"; +constexpr auto kTensorName = "tensor_name"; +constexpr auto kValue = "value"; +constexpr auto KDynInputSizes = "dyn_input_sizes"; +constexpr auto KInputNames = "input_names"; +constexpr auto KInput = "input"; +constexpr auto KDtype = "dtype"; +namespace { +template +std::string Vector2Str(const std::vector &inputs) { + if (!inputs.empty()) { + std::ostringstream oss; + (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator(oss, ", ")); + oss << inputs.back(); + return oss.str(); + } + return ""; +} +} // namespace + +std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { + char *pChar = nullptr; + std::string str_res; + if (PyObj == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr."; + return str_res; + } + PyObject *strArgs = PyObject_Str(PyObj); + if (strArgs != nullptr) { + (void)PyArg_Parse(strArgs, "s", &pChar); + } + if (pChar == nullptr) { + MS_LOG(ERROR) << "pChar is nullptr."; + return str_res; + } + str_res = pChar; + return str_res; +} + +std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, + const std::pair &position) { + if (node_json.count(tag) == 0) { + MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; + return ""; + } + + auto const &tag_desc = node_json[tag]; + nlohmann::json first_index; + if (tag == kOutputDesc) { + first_index = tag_desc; + } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { + MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; + return ""; + } else { + first_index = tag_desc[position.first]; + } + + if (!first_index.is_array() || first_index.size() <= position.second) { + MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; + return ""; + } + auto const &second_index = first_index[position.second]; + if (second_index.count(kTensorName) == 0) { + MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; + return ""; + } + + return second_index[kTensorName]; +} + +void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, + nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(node_json); + if (node_json->count(tag) == 0) { + MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; + return; + } + + nlohmann::json *tag_desc = &((*node_json)[tag]); + nlohmann::json *first_index; + if (tag == kOutputDesc) { + first_index = tag_desc; + } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { + MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has 
no enough value [" << position.first << "]."; + return; + } else { + first_index = &((*tag_desc)[position.first]); + } + + if (!first_index->is_array() || first_index->size() <= position.second) { + MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; + return; + } + nlohmann::json *second_index = &((*first_index)[position.second]); + if (second_index->count(kTensorName) == 0) { + MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; + return; + } + (*second_index)[kTensorName] = new_name; + return; +} + +int AkgKernelBuild::op_cnt_ = 0; +std::mutex AkgKernelBuild::op_cnt_mtx_; + +std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + + case Processor::AICPU: + device = kProcessorAiCpu; + break; + + case Processor::CUDA: + device = kProcessorCuda; + break; + + default: + MS_LOG(ERROR) << "Unknown processor type."; + break; + } + + return device; +} + +bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, + std::vector *const output_size) { + if (input_size == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "input size or output size is nullptr"; + return false; + } + input_size->clear(); + output_size->clear(); + + for (size_t i = 0; i < node_json[kInputDesc].size(); i++) { + for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) { + std::string dtype = node_json[kInputDesc][i][m][kDataType]; + size_t nbyte = GetDtypeNbyte(dtype); + size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(), + node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies()); + input_size->push_back(size_i); + } + } + + for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) { + std::string dtype = node_json[kOutputDesc][i][kDataType]; + size_t nbyte = GetDtypeNbyte(dtype); + size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(), + nbyte, std::multiplies()); + output_size->push_back(size_i); + } + + return true; +} + +int AkgKernelBuild::GetOpCntInc() { + op_cnt_mtx_.lock(); + int cnt = op_cnt_++; + op_cnt_mtx_.unlock(); + return cnt; +} + +bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(inputs_json); + + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + if (op_info == nullptr) { + MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr"; + return false; + } + + std::vector> inputs_ptr = op_info->inputs_ptr(); + if (inputs_ptr.empty()) { + MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info"; + return true; + } + auto op_info_input_num = inputs_ptr.size(); + + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. 
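+  // (Editorial note) As consumed below, entry i of dyn_input_sizes is the tensor count
+  // for logical input i, e.g. {3, 1} when the first input folds three tensors and the
+  // second is an ordinary single tensor.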
+  std::vector<int> dyn_input_sizes;
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+
+  if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
+    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
+  }
+
+  size_t real_input_index = 0;
+  std::vector<nlohmann::json> input_list;
+  for (size_t i = 0; i < op_info_input_num; i++) {
+    size_t input_tensor_num;
+    std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
+    std::string op_input_name;
+    if (input_ptr == nullptr) {
+      MS_LOG(ERROR) << "Apply kernel [" << op_name << "] registered input[" << i << "] is nullptr";
+      return false;
+    }
+
+    op_input_name = input_ptr->name();
+    if (dyn_input_sizes.empty()) {
+      input_tensor_num = 1;
+    } else {
+      input_tensor_num = IntToSize(dyn_input_sizes[i]);
+    }
+
+    input_list.clear();
+    for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
+      // dtype : float16
+      auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index);
+      std::string dtype = TypeId2String(type_id);
+      if (dtype.empty()) {
+        MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null.";
+        return false;
+      }
+      nlohmann::json input_desc_json;
+      input_desc_json[kDataType] = dtype;
+      input_desc_json[kName] = op_input_name;
+      input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
+      auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
+      if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) &&
+          GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+        MS_LOG(WARNING) << "We take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
+                        << "] as const tensor, shape: [" << Vector2Str(input_shape)
+                        << "], value: " << input_desc_json[kValue];
+
+        input_shape.clear();
+      }
+      if (input_shape.empty()) {
+        input_shape.push_back(1);
+      }
+      input_desc_json[kShape] = input_shape;
+      input_list.emplace_back(input_desc_json);
+      real_input_index++;
+    }
+    inputs_json->emplace_back(input_list);
+  }
+  return true;
+}
+
+bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(outputs_json);
+  size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
+  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
+
+  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
+  auto outputs = op_info_ptr->outputs_ptr();
+  for (size_t i = 0; i < output_tensor_num; i++) {
+    nlohmann::json output_json;
+    auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
+    std::string dtype = TypeId2String(type_id);
+    if (dtype.empty()) {
+      MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null.";
+      return false;
+    }
+
+    std::string output_name = outputs[i]->name();
+    output_json[kDataType] = dtype;
+    output_json[kName] = output_name;
+    output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc());
+    output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i);
+    outputs_json->push_back(output_json);
+  }
+  return true;
+}
+
+void GetJson(const AnfNodePtr &anf_node, const std::vector<int> &dyn_input_sizes,
+             const std::shared_ptr<OpAttr> &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(op_attr);
+  MS_EXCEPTION_IF_NULL(attr_json);
+  std::string type = op_attr->type();
+  if (type == "int") {
+    (*attr_json)[kValue] = GetValue<int>(attr_value);
+  } else if (type == "str") {
+    (*attr_json)[kValue] = GetValue<std::string>(attr_value);
+  } else if (type == "bool") {
+    (*attr_json)[kValue] = GetValue<bool>(attr_value);
+  } else if (type == "float") {
+    (*attr_json)[kValue] = GetValue<float>(attr_value);
+  } else if (type == "listInt") {
+    (*attr_json)[kValue] = GetValue<std::vector<int>>(attr_value);
+  } else if (type == "listStr") {
+    std::vector<std::string> data_format;
+    if (op_attr->name() == kArgDataformat) {
+      size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node);
+      for (size_t format_i = 0; format_i < tensor_args_num; format_i++) {
+        auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i);
+        data_format.push_back(input_format);
+      }
+    } else {
+      data_format = GetValue<std::vector<std::string>>(attr_value);
+    }
+    (*attr_json)[kValue] = data_format;
+  } else {
+    MS_LOG(WARNING) << "Unknown attr type: " << type;
+  }
+}
+
+bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name,
+                                        const std::shared_ptr<OpInfo> &op_info, nlohmann::json *const attrs_json) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(attrs_json);
+  MS_EXCEPTION_IF_NULL(op_info);
+  std::vector<std::shared_ptr<OpAttr>> attrs = op_info->attrs_ptr();
+  if (attrs.empty()) {
+    MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty";
+    return true;
+  }
+  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info->inputs_ptr();
+
+  std::vector<int> dyn_input_sizes;
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
+    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
+  }
+
+  if (inputs.empty()) {
+    MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty";
+    return false;
+  }
+
+  // create an input name list to match "x_shape" in attrs with input "x" in the primitive.
+  std::map<size_t, std::string> op_info_shape_name;
+  for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) {
+    std::string input_name = inputs[op_info_input_i]->name();
+    std::string x_shape_name = input_name + "_shape";
+    (void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name));
+  }
+
+  for (const auto &op_attr : attrs) {
+    nlohmann::json attr_json;
+    ValuePtr attr_value = primitive->GetAttr(op_attr->name());
+    if (attr_value == nullptr && op_attr->name() != kArgDataformat) {
+      if (op_attr->param_type() == "required") {
+        // match "x_shape" in attrs with input "x" in the primitive.
+ std::string attr_name = op_attr->name(); + auto find_item = std::find_if( + op_info_shape_name.begin(), op_info_shape_name.end(), + [attr_name](const std::map::value_type item) { return item.second == attr_name; }); + if (find_item != op_info_shape_name.end()) { + if (!dyn_input_sizes.empty()) { + if (find_item->first >= dyn_input_sizes.size() - 1) { + MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first + << " is out of range:" << dyn_input_sizes.size() - 1 << "."; + return false; + } + size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0)); + for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) { + attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + tensor_idx++; + } + } else { + attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first); + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + } + } else { + MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name(); + return false; + } + } + continue; + } + + GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); + + attr_json[kName] = op_attr->name(); + attrs_json->push_back(attr_json); + } + return true; +} + +bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, + nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(node_json); + int op_cnt = GetOpCntInc(); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); + MS_EXCEPTION_IF_NULL(op_info_ptr); + + // get basic params from currentNodeOpDesc + (*node_json)[kName] = op_name; + (*node_json)["impl_path"] = op_info_ptr->impl_path(); + (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); + (*node_json)["composite"] = false; + + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr input_names_v = primitive->GetAttr(KInputNames); + if (input_names_v == nullptr) { + MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; + return false; + } + std::vector prim_input_names = GetValue>(input_names_v); + std::string inputs_name; + for (const auto &prim_input_name : prim_input_names) { + (void)inputs_name.append("_input_").append(prim_input_name).append("_"); + } + + // input desc + nlohmann::json inputs_json; + if (!CreateInputDescJson(anf_node, &inputs_json)) { + MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "]."; + return false; + } + (*node_json)[kInputDesc] = inputs_json; + MS_LOG(INFO) << "Akg create input desc json success."; + std::string inputs_shape = "inputs_shape_"; + for (auto &i : inputs_json) { + for (auto &m : i) { + std::string data_type = m[kDataType]; + (void)inputs_shape.append("_").append(data_type).append("_"); + for (auto &j : m[kShape]) { + size_t n = j; + (void)inputs_shape.append(std::to_string(n)).append("_"); + } + } + } + + // output desc + nlohmann::json outputs_json; + if (!CreateOutputDescJson(anf_node, &outputs_json)) { + MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "]."; + return false; + } + + (*node_json)[kOutputDesc] = outputs_json; + MS_LOG(INFO) << "Akg create output desc json success."; + std::string outputs_shape = "outputs_shape_"; + for (auto &i : outputs_json) { + std::string data_type = i[kDataType]; + 
(void)outputs_shape.append("_").append(data_type).append("_"); + for (auto &j : i[kShape]) { + size_t m = j; + (void)outputs_shape.append(std::to_string(m)).append("_"); + } + } + + // attribute desc + nlohmann::json attrs_json; + if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) { + MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "]."; + return false; + } + (*node_json)["attr"] = attrs_json; + std::string json_str = node_json->dump(); + size_t hash_id = std::hash()(json_str); + json_name_ = op_name + "_"; + (void)json_name_.append(std::to_string(hash_id)); + MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_; + json_info_ = json_str; + (*node_json)["id"] = op_cnt; + (*node_json)["op"] = json_name_; + MS_LOG(INFO) << "Akg create node desc json success."; + return true; +} + +KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + auto processor = AkgKernelBuild::GetProcessor(anf_node); + auto cached_kernel_pack = SearchCache(json_name_, processor); + if (cached_kernel_pack != nullptr) { + MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + return cached_kernel_pack; + } + + PyObject *pModule = nullptr; + PyObject *pFunc = nullptr; + PyObject *pArg = nullptr; + PyObject *pRes = nullptr; + + pModule = PyImport_ImportModule(kAkgModule); + if (pModule == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; + return nullptr; + } + + pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); + pArg = PyTuple_New(ARGS_SIZE); + (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); + + (void)alarm(AUTODIFF_COMPILE_OVERTIME); + pRes = PyEval_CallObject(pFunc, pArg); + (void)alarm(0); + if (pRes == nullptr) { + MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + return nullptr; + } + if (PyObject_IsTrue(pRes) != 1) { + MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(pArg) << ")."; + return nullptr; + } + + auto new_kernel_pack = InsertCache(json_name_, processor); + kernel::SaveJsonInfo(json_name_, json_info_); + if (new_kernel_pack == nullptr) { + MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope[" + << anf_node->fullname_with_scope() << "]."; + return nullptr; + } + return new_kernel_pack; +} + +KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, + std::vector *const output_size) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto it = kAkgKernelAttrsProcessMap.find(op_name); + if (it != kAkgKernelAttrsProcessMap.end()) { + it->second(anf_node); + } + MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; + nlohmann::json node_json; + if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { + MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; + } + + std::string json_str = node_json.dump(); + auto kernel_pack = OpBuild(json_str, anf_node); + if (kernel_pack == nullptr) { + MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str; + return nullptr; + } + + if 
(!GetIOSize(node_json, input_size, output_size)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return nullptr; + } + MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) + << "]"; + return kernel_pack; +} + +size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->inputs().size()) { + MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" + << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; + } + + auto input_node = cnode->input(input_idx + 1); + if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { + size_t index = input_tensor_idx_.size(); + input_tensor_idx_[input_node] = index; + } + + return input_tensor_idx_[input_node]; +} + +size_t AkgKernelBuild::GetOutputTensorIdxInc() { + size_t idx = output_tensor_idx_++; + return idx; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h new file mode 100644 index 0000000000..7b6a2f0b86 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_build.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "ir/dtype.h" +#include +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace kernel { +class AkgKernelBuild { + public: + AkgKernelBuild() { + input_tensor_idx_ = {}; + output_tensor_idx_ = 0; + } + ~AkgKernelBuild() = default; + + KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, + std::vector *const output_size); + static std::string GetProcessor(const AnfNodePtr &anf_node); + static std::string PyObjectToStr(PyObject *const PyObj); + + protected: + bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); + bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); + bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, + const std::shared_ptr &op_info, nlohmann::json *const attrs_json); + KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); + int GetOpCntInc(); + size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); + size_t GetOutputTensorIdxInc(); + bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, + nlohmann::json *const node_json); + + static int op_cnt_; + // lock for variable fusionOpCnt in singleton mode + static std::mutex op_cnt_mtx_; + std::string json_name_; + std::string json_info_; + std::unordered_map input_tensor_idx_; + size_t output_tensor_idx_; +}; + +bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, + std::vector *const output_size); +void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, + nlohmann::json *const node_json); +std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, + const std::pair &position); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc new file mode 100644 index 0000000000..f3567428d3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+void AkgMetadataInfo(const CNodePtr &kernel_node,
+                     std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_info_list);
+
+  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
+  // Try each supported device in order; stop at the first whose metadata parses.
+  for (size_t i = 0; i < support_devices.size(); i++) {
+    auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG);
+    if (op_info_ptr == nullptr) {
+      continue;
+    }
+
+    if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) {
+      MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed.";
+    } else {
+      MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "].";
+      break;
+    }
+  }
+
+  if (kernel_info_list->empty()) {
+    MS_LOG(WARNING) << "Akg does not have metadata of op[" << op_name << "].";
+  }
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h
new file mode 100644
index 0000000000..02785c6cdb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/akg_kernel_metadata.h
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
+#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "backend/kernel_compiler/kernel_build_info.h"
+
+namespace mindspore {
+namespace kernel {
+void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
+} // namespace kernel
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
new file mode 100644
index 0000000000..d698c89bc9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
@@ -0,0 +1,422 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "ir/func_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" +#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +constexpr int32_t PARALLEL_ARGS_SIZE = 3; +constexpr int32_t PROCESS_NUM = 16; +constexpr int32_t TIME_OUT = 300; + +constexpr auto kOpDesc = "op_desc"; +constexpr auto kShape = "shape"; +constexpr auto kDataType = "data_type"; +constexpr auto kInputDesc = "input_desc"; +constexpr auto kOutputDesc = "output_desc"; +constexpr auto kTensorName = "tensor_name"; +constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; +constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; +namespace { +void UpdateTensorNameInJson(const std::vector &anf_nodes, + std::map *node_json_map) { + for (auto const &anf_node : anf_nodes) { + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + bool is_dynamic_input = !dyn_input_sizes.empty(); + size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); + size_t real_input_index = 0; + for (size_t i = 0; i < input_num; ++i) { + size_t input_tensor_num = is_dynamic_input ? 
IntToSize(dyn_input_sizes[i]) : 1; + for (size_t j = 0; j < input_tensor_num; ++j) { + auto tmp_input = GetKernelInput(anf_node, real_input_index); + std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j)); + if (node_json_map->find(tmp_input.first) != node_json_map->end()) { + std::string new_tensor_name = + GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); + SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node])); + MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" + << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" + << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; + } else { + MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" + << anf_node->fullname_with_scope() << "] is out input."; + } + real_input_index++; + } + } + } +} + +nlohmann::json GetInputsJson(const std::vector &anf_nodes, const std::vector &input_list, + std::map *node_json_map) { + nlohmann::json inputs_json; + auto input_index = GetInputIndex(anf_nodes, input_list); + for (size_t i = 0; i < input_index.size(); ++i) { + auto tmp_input = input_index[i]; + auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); + std::string dtype = TypeId2String(type_id); + nlohmann::json input_desc_json; + input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second); + input_desc_json[kDataType] = dtype; + input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); + inputs_json.emplace_back(std::vector{input_desc_json}); + } + + return inputs_json; +} + +nlohmann::json GetOutputsJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list, const nlohmann::json &inputs_json, + std::map *node_json_map) { + nlohmann::json outputs_json; + auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); + for (size_t i = 0; i < output_index.size(); ++i) { + auto tmp_output = output_index[i]; + bool found = false; + nlohmann::json output_desc_json; + for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { + if (tmp_output.first == input_list[input_i]) { + output_desc_json = inputs_json[input_i][0]; + found = true; + break; + } + } + if (!found) { + auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); + std::string dtype = TypeId2String(type_id); + output_desc_json[kTensorName] = + GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); + output_desc_json[kDataType] = dtype; + auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); + if (output_shape.empty()) { + output_shape.push_back(1); + } + output_desc_json[kShape] = output_shape; + } + outputs_json.emplace_back(output_desc_json); + } + + return outputs_json; +} + +std::pair, std::vector>> PreProcessJsonForBuild( + const std::vector> &build_args) { + // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. 
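+  // Dedup example: if three nodes hash to the same json_name, only the first json is handed
+  // to the compiler below; the other two go into repeat_nodes and are wired to the freshly
+  // cached kernel in PostProcessAfterCompile.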
+  std::vector<std::string> jsons;
+  std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> repeat_nodes;
+  std::unordered_set<std::string> json_name_set;
+  for (const auto &[builder, anf_node] : build_args) {
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto json_name = builder.json_name();
+    MS_LOG(DEBUG) << "Akg start compile op: " << json_name;
+    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
+    if (cached_kernel_pack != nullptr) {
+      MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
+      kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
+      kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
+      AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+      continue;
+    }
+
+    if (json_name_set.count(json_name) != 0) {
+      repeat_nodes.push_back({builder, anf_node});
+      continue;
+    }
+    json_name_set.insert(json_name);
+    auto node_json = builder.kernel_json();
+    kernel::SaveJsonInfo(json_name, node_json);
+    jsons.push_back(node_json);
+  }
+
+  return std::make_pair(jsons, repeat_nodes);
+}
+
+bool PostProcessAfterCompile(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args,
+                             const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &repeat_nodes) {
+  for (const auto &[builder, anf_node] : build_args) {
+    auto json_name = builder.json_name();
+    auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
+    if (new_kernel_pack == nullptr) {
+      MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope["
+                    << anf_node->fullname_with_scope() << "].";
+      return false;
+    }
+    auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(new_kernel_pack);
+    kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
+    kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
+    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+    MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!";
+  }
+
+  for (const auto &[builder, anf_node] : repeat_nodes) {
+    auto node_json = builder.kernel_json();
+    auto json_name = builder.json_name();
+    auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node));
+    if (cached_kernel_pack == nullptr) {
+      return false;
+    }
+    MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope["
+                 << anf_node->fullname_with_scope() << "].";
+    auto kernel_mod_ptr = std::make_shared<AkgKernelMod>(cached_kernel_pack);
+    kernel_mod_ptr->SetInputSizeList(builder.input_size_list());
+    kernel_mod_ptr->SetOutputSizeList(builder.output_size_list());
+    AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
+  }
+
+  return true;
+}
+} // namespace
+
+bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
+  auto it = kAkgKernelAttrsProcessMap.find(op_name);
+  if (it != kAkgKernelAttrsProcessMap.end()) {
+    it->second(anf_node);
+  }
+  MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]";
+  nlohmann::json node_json;
+  if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) {
+    MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed.";
+  }
+
+  kernel_json_ = node_json.dump();
+
+  if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) {
MS_LOG(ERROR) << "Cal mem size failed."; + return false; + } + + return true; +} + +bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, + std::map *node_json_map) { + for (auto const &anf_node : anf_nodes) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; + return false; + } + auto it = kAkgKernelAttrsProcessMap.find(op_name); + if (it != kAkgKernelAttrsProcessMap.end()) { + it->second(anf_node); + } + + nlohmann::json node_json; + if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { + MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; + return false; + } + // No need for composite op. + node_json.erase("id"); + node_json.erase("op"); + node_json.erase("composite"); + + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + + if (primitive->GetAttr("fusion") != nullptr) { + node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); + } + + (*node_json_map)[anf_node] = node_json; + } + return true; +} + +bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf_nodes, + const std::vector &input_list, + const std::vector &output_list) { + if (anf_nodes.empty() || input_list.empty()) { + MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() + << "]."; + return false; + } + MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" + << input_list.size() << "]."; + + std::map node_json_map; + if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) { + return false; + } + + UpdateTensorNameInJson(anf_nodes, &node_json_map); + + nlohmann::json fused_node_json; + std::vector node_json_desc; + std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), + [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); + fused_node_json[kOpDesc] = node_json_desc; + fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map); + fused_node_json[kOutputDesc] = + GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map); + + size_t hash_id = std::hash()(fused_node_json.dump()); + json_name_ = "Fused_"; + auto fg = anf_nodes[0]->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (attr_val != nullptr) { + auto fg_attr = GetValue(attr_val); + (void)json_name_.append(fg_attr).append("_"); + } + (void)json_name_.append(std::to_string(hash_id)); + fused_node_json["composite_graph"] = fg->ToString(); + fused_node_json["op"] = json_name_; + fused_node_json["platform"] = "AKG"; + fused_node_json["process"] = "aicore"; + fused_node_json["composite"] = true; + + kernel_json_ = fused_node_json.dump(); + + if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { + MS_LOG(ERROR) << "Cal mem size failed."; + return false; + } + + return true; +} + +void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { + MS_EXCEPTION_IF_NULL(p_args); + *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); + + PyObject *arg1 = PyList_New(kernel_jsons.size()); + for (int i = 0; i < PyList_Size(arg1); ++i) { + PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); + } + PyObject *arg2 = Py_BuildValue("i", 
PROCESS_NUM); + PyObject *arg3 = Py_BuildValue("i", TIME_OUT); + + (void)PyTuple_SetItem(*p_args, 0, arg1); + (void)PyTuple_SetItem(*p_args, 1, arg2); + (void)PyTuple_SetItem(*p_args, 2, arg3); +} + +bool AkgOpParallelBuild(const std::vector> &build_args) { + auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); + if (jsons.empty()) { + return true; + } + + // Try to call python method to compile nodes parallely. + PyObject *p_module = nullptr; + PyObject *p_func = nullptr; + PyObject *p_arg = nullptr; + PyObject *p_res = nullptr; + + p_module = PyImport_ImportModule(kMultiProcModule); + if (p_module == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; + return false; + } + + p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); + GenParallelCompileFuncArgs(jsons, &p_arg); + MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() + << " Akg kernels parallelly."; + p_res = PyEval_CallObject(p_func, p_arg); + if (p_res == nullptr) { + PyErr_Print(); + MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + return false; + } + if (PyObject_IsTrue(p_res) != 1) { + PyErr_Print(); + MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" + << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + return false; + } + + if (!PostProcessAfterCompile(build_args, repeat_nodes)) { + return false; + } + + return true; +} + +bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes) { + std::vector> json_and_node; + for (const auto &anf_node : anf_nodes) { + MS_EXCEPTION_IF_NULL(anf_node); + AkgAscendKernelBuilder akg_cce_kernel_builder; + KernelPackPtr kernel_pack = nullptr; + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsGraphKernel(cnode)) { + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + MS_EXCEPTION_IF_NULL(func_graph); + std::vector node_list; + std::vector input_list; + std::vector output_list; + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; + GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { + MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; + } + } else { + if (!akg_cce_kernel_builder.CollectJson(anf_node)) { + MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; + } + } + json_and_node.push_back({akg_cce_kernel_builder, anf_node}); + } + + if (json_and_node.empty()) { + MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; + return true; + } + + return AkgOpParallelBuild(json_and_node); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h new file mode 100644 index 0000000000..713b65a451 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" + +namespace mindspore { +namespace kernel { +class AkgAscendKernelBuilder : public AkgKernelBuild { + public: + AkgAscendKernelBuilder() = default; + ~AkgAscendKernelBuilder() = default; + + bool CollectJson(const AnfNodePtr &anf_node); + bool CollectFusedJson(const std::vector &anf_nodes, const std::vector &input_list, + const std::vector &output_list); + std::string json_name() const { return json_name_; } + std::string kernel_json() const { return kernel_json_; } + const std::vector &input_size_list() const { return input_size_list_; } + const std::vector &output_size_list() const { return output_size_list_; } + + private: + bool GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, + std::map *node_json_map); + + std::string kernel_json_; + std::vector input_size_list_; + std::vector output_size_list_; +}; + +bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc new file mode 100644 index 0000000000..8bb4940778 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" +#include +#include +#include +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "runtime/rt.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +using std::fstream; +using std::map; +using std::mutex; +using std::string; +using TbeTaskInfoPtr = std::shared_ptr; +using tbe::KernelManager; +constexpr uint32_t DEFAULT_BLOCK_DIM = 1; +/** + * @brief infotable contain func_stub\blockdim\kernel file buffer + */ +AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} + +void AkgKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + +void AkgKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + +void AkgKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } + +const std::vector &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } + +const std::vector &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool AkgKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + + uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); + if (func_stub == 0) { + MS_LOG(ERROR) << "GenFuncStub failed."; + return false; + } + + // pack all addresses into a vector. + std::vector runtime_args; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), + [](const AddressPtr &output) -> void * { return output->addr; }); + + rtL2Ctrl_t *l2ctrl = nullptr; + auto stream = reinterpret_cast(stream_ptr); + if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast(func_stub), block_dim, runtime_args.data(), + SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { + MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; + return false; + } + + return true; +} + +std::vector AkgKernelMod::GenTask(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uint32_t stream_id) { + if (kernel_pack_ == nullptr) { + MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; + } + + std::vector args; + const uint32_t args_size = 0; + std::vector sm_desc; + void *binary = nullptr; + const uint32_t binary_size = 0; + std::vector meta_data; + std::vector input_data_addrs; + std::vector output_data_addrs; + std::vector workspace_addrs; + + // pack all addresses into a vector. 
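+  // Unlike Launch above, GenTask only records the stub name, block dim and raw device
+  // addresses in a TbeTaskInfo; the task-sink runtime performs the actual launch later.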
+ (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + + uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); + if (func_stub == 0) { + MS_LOG(EXCEPTION) << "GenFuncStub failed."; + } + + std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); + + MS_LOG(DEBUG) << "The block_dim is:" << block_dim; + + TbeTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, + input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h new file mode 100644 index 0000000000..3ea36f1a23 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +class AkgKernelMod : public AscendKernelMod { + public: + explicit AkgKernelMod(const KernelPackPtr &kernel_pack); + ~AkgKernelMod() final {} + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + void SetWorkspaceSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using AkgKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc new file mode 100644 index 0000000000..96fcd1869e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + AkgKernelBuild akg_kernel_build; + + std::vector input_size_list; + std::vector output_size_list; + KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list); + MS_EXCEPTION_IF_NULL(kernel_pack); + + auto kernel_mod_ptr = std::make_shared(kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetInputSizeList(input_size_list); + kernel_mod_ptr->SetOutputSizeList(output_size_list); + return kernel_mod_ptr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h new file mode 100644 index 0000000000..abb6d1f030 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ +#include "backend/kernel_compiler/kernel.h" +#include "base/base.h" + +namespace mindspore { +namespace kernel { +KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc new file mode 100644 index 0000000000..d527f8ec76 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.cc @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h" +#include +#include +#include "nlohmann/json.hpp" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +using std::fstream; +using std::string; +using std::vector; + +GpuKernelManagerPtr GpuKernelMod::kernelmanager_ = std::make_shared(); +GpuKernelManager::GpuKernelManager() {} + +CUresult GpuKernelManager::GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, + vector *thread_info, CUfunction *func) { + if (kernel_pack->GetJson() == nullptr || kernel_pack->GetJson()->contents == nullptr || + kernel_pack->GetKernel() == nullptr || kernel_pack->GetKernel()->contents == nullptr) { + MS_LOG(ERROR) << "GPU:Invalid kernel pack, json or kernel is nullptr."; + return CUDA_ERROR_INVALID_IMAGE; + } + auto js = nlohmann::json::parse(kernel_pack->GetJson()->contents, + kernel_pack->GetJson()->contents + kernel_pack->GetJson()->len); + string fn = js["kernelName"]; + if (!force_reload) { + auto iter = infotable_.find(fn); + if (iter != infotable_.end()) { + auto kernelmeta = iter->second; + *thread_info = kernelmeta->thread_info_; + *func = kernelmeta->func_addr_; + return CUDA_SUCCESS; + } + } + thread_info->emplace_back(js["blockIdx.x"]); + thread_info->emplace_back(js["blockIdx.y"]); + thread_info->emplace_back(js["blockIdx.z"]); + thread_info->emplace_back(js["threadIdx.x"]); + thread_info->emplace_back(js["threadIdx.y"]); + thread_info->emplace_back(js["threadIdx.z"]); + CUmodule module; + CUresult result = cuModuleLoadData(&module, kernel_pack->GetKernel()->contents); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "cuModuleLoadData failed."; + return result; + } + result = cuModuleGetFunction(func, module, fn.c_str()); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "cuModuleGetFunction failed."; + return result; + } + infotable_[fn] = std::make_shared(*func, module, *thread_info); + return result; +} + +GpuKernelMod::GpuKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} + +void GpuKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + +void GpuKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + +const std::vector &GpuKernelMod::GetInputSizeList() const { return input_size_list_; } + +const std::vector &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; } + +const std::vector &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } + +bool GpuKernelMod::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == 0) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + vector thread_info; + CUfunction kernel_addr; + CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "GetFunction failed."; + return false; + } + std::vector runtimeargs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), + [](const AddressPtr &input) -> void * { return reinterpret_cast(&(input->addr)); }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), + [](const AddressPtr &output) -> void * { return reinterpret_cast(&(output->addr)); }); + result = cuLaunchKernel(kernel_addr, thread_info[0], 
thread_info[1], thread_info[2], thread_info[3], thread_info[4], + thread_info[5], 0, reinterpret_cast(stream_ptr), + reinterpret_cast(&runtimeargs[0]), 0); + if (result != CUDA_SUCCESS) { + MS_LOG(ERROR) << "Launch Kernel failed."; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h new file mode 100644 index 0000000000..a6a17d033f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/gpu/akg_gpu_kernel_mod.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +struct GpuKernelMeta { + CUfunction func_addr_; + CUmodule module_; + std::vector thread_info_; + GpuKernelMeta(CUfunction funcAddr, CUmodule module, const std::vector &thread_info) + : func_addr_(funcAddr), module_(module), thread_info_(thread_info) {} +}; +using GpuKernelMetaPtr = std::shared_ptr; + +class GpuKernelManager { + public: + GpuKernelManager(); + virtual ~GpuKernelManager() { + for (auto iter = infotable_.begin(); iter != infotable_.end(); ++iter) { + CUresult ret = cuModuleUnload(iter->second->module_); + if (ret != CUDA_SUCCESS && ret != CUDA_ERROR_DEINITIALIZED) { + MS_LOG(ERROR) << "Unload GPU Module failed."; + } + } + } + CUresult GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, std::vector *thread_info, + CUfunction *func); + + private: + std::unordered_map infotable_; +}; +using GpuKernelManagerPtr = std::shared_ptr; + +class GpuKernelMod : public KernelMod { + public: + explicit GpuKernelMod(const KernelPackPtr &kernel_pack); + virtual ~GpuKernelMod() {} + + void SetInputSizeList(const std::vector &size_list); + void SetOutputSizeList(const std::vector &size_list); + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + + static GpuKernelManagerPtr kernelmanager_; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using GpuKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h new file mode 100644 index 0000000000..c6398eda9e --- /dev/null +++ 
b/mindspore/ccsrc/backend/kernel_compiler/ascend_kernel_mod.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ + +#include +#include +#include "framework/ge_runtime/task_info.h" +#include "backend/kernel_compiler/kernel.h" +#ifdef ENABLE_DATA_DUMP +#include "debug/data_dump_parser.h" +#endif + +using TaskInfoPtr = std::shared_ptr; +namespace mindspore { +namespace kernel { +class AscendKernelMod : public KernelMod { + public: + virtual std::vector GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t) = 0; + uint32_t block_dim() { return block_dim_; } + uint32_t stream_id() { return stream_id_; } + virtual bool NeedDump() { +#ifdef ENABLE_DATA_DUMP + return DataDumpParser::GetInstance().NeedDump(kernel_name_); +#else + return false; +#endif + } + + protected: + uint32_t block_dim_{1}; + uint32_t stream_id_{0}; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc new file mode 100644 index 0000000000..f4495cdb9d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -0,0 +1,1029 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/common_utils.h" +#include +#include +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "ir/manager.h" +#include "ir/meta_tensor.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" +#include "utils/graph_utils.h" + +namespace mindspore { +namespace kernel { +constexpr char kAxis[] = "axis"; +constexpr char kTypeInt32[] = "Int32"; +const std::unordered_map type_id_maps = { + {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, + {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, + {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, + {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, + {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, + {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, + {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, + {"bool", TypeId::kNumberTypeBool}, +}; + +const std::map type_id_str_map = { + {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, + {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, + {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, + {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, + {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, + {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, + {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, + {TypeId::kNumberTypeBool, "bool"}, +}; + +const std::unordered_map dtype_shortdtype_map_ = { + {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, + {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, +}; + +const std::unordered_map dtype_nbyte_map = { + {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, + {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, + {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, + {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, +}; + +const std::unordered_map fusion_type_maps = { + {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, + {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, +}; + +void KernelMeta::Initialize() { + kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; + // remove old kernel cache + RemoveKernelCache(); + +#if defined(_WIN32) || defined(_WIN64) + auto ret = mkdir(kernel_meta_path_.c_str()); +#else + auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU); +#endif + if (ret != 0) { + MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later"; + } + initialized_ = true; +} + +void KernelMeta::RemoveKernelCache() { + DIR *dir = opendir(kernel_meta_path_.c_str()); + if (dir == nullptr) { + return; + } + struct dirent *entry; + while ((entry = readdir(dir)) != nullptr) { + std::string kernel_file = entry->d_name; + std::string kernel_file_realpath = kernel_meta_path_ + kernel_file; + (void)remove(kernel_file_realpath.c_str()); + } + (void)closedir(dir); + 
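+  // All cached kernel files have been unlinked above, so the now-empty per-process
+  // kernel_meta directory itself can be removed next.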
+  (void)rmdir(kernel_meta_path_.c_str());
+}
+
+std::string KernelMeta::Search(const std::string &kernel_name) const {
+  if (!initialized_) {
+    return "";
+  }
+
+  auto iter = kernel_meta_map_.find(kernel_name);
+  if (iter == kernel_meta_map_.end()) {
+    return "";
+  } else {
+    return iter->second;
+  }
+}
+
+bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
+  if (!initialized_) {
+    return false;
+  }
+  kernel_meta_map_[kernel_name] = kernel_json;
+  return true;
+}
+
+bool CheckCache(const std::string &kernel_name) {
+  // check cache.
+  KernelMeta *bin_map = KernelMeta::GetInstance();
+  if (bin_map == nullptr) {
+    MS_LOG(DEBUG) << "Kernel cache is invalid.";
+    return false;
+  }
+  std::string kernel_json = bin_map->Search(kernel_name);
+  bool ret = (!kernel_json.empty());
+  if (ret) {
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has been registered.";
+  } else {
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will be registered.";
+  }
+  return ret;
+}
+
+KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
+  // search cache.
+  KernelMeta *bin_map = KernelMeta::GetInstance();
+  if (bin_map == nullptr) {
+    MS_LOG(DEBUG) << "Kernel cache is invalid.";
+    return nullptr;
+  }
+
+  std::string kernel_json = bin_map->Search(kernel_name);
+  if (!kernel_json.empty()) {
+    KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
+    // just a tmp solution.
+    if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
+      MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
+      return nullptr;
+    } else {
+      return kernel_pack;
+    }
+  } else {
+    MS_LOG(INFO) << "Cache kernel not found[" << kernel_name << "].";
+    return nullptr;
+  }
+}
+
+KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
+  MS_LOG(INFO) << "Kernel name:" << kernel_name << ", processor:" << processor;
+  KernelMeta *bin_map = KernelMeta::GetInstance();
+  if (bin_map == nullptr) {
+    MS_LOG(DEBUG) << "Kernel cache is invalid.";
+    return nullptr;
+  }
+  std::string kernel_json;
+  if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
+    kernel_json = kCceKernelMeta;
+  } else {
+    kernel_json = bin_map->GetKernelMetaPath();
+  }
+  (void)kernel_json.append(kernel_name).append(kJsonSuffix);
+  KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
+  if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
+    MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
+    return nullptr;
+  }
+  if (bin_map->Insert(kernel_name, kernel_json)) {
+    MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
+  }
+  return kernel_pack;
+}
+
+TypeId DtypeToTypeId(const std::string &dtypes) {
+  auto iter = type_id_maps.find(dtypes);
+  if (iter != type_id_maps.end()) {
+    return iter->second;
+  } else {
+    MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
+  }
+}
+
+std::string TypeId2String(TypeId type_id) {
+  auto iter = type_id_str_map.find(type_id);
+  if (iter == type_id_str_map.end()) {
+    return std::string(TypeIdLabel(type_id));
+  }
+  return iter->second;
+}
+
+std::string Dtype2ShortType(const std::string &dtypes) {
+  auto iter = dtype_shortdtype_map_.find(dtypes);
+  if (iter != dtype_shortdtype_map_.end()) {
+    return iter->second;
+  } else {
+    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
+  }
+}
+
+size_t GetDtypeNbyte(const std::string &dtypes) {
+  auto iter = dtype_nbyte_map.find(dtypes);
+  if (iter != dtype_nbyte_map.end()) {
+ return iter->second; + } else { + MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; + } +} + +bool SetInputKernelBuilderInfo(const std::vector> &inputs, size_t real_input_num, + size_t builder_idex, const std::vector &dyn_input_sizes, + const std::shared_ptr &builder) { + MS_EXCEPTION_IF_NULL(builder); + + std::vector inputs_device_type; + std::vector inputs_format; + size_t dyn_input_idx = 0; + size_t kernel_info_index = 0; + MS_EXCEPTION_IF_NULL(inputs[0]); + size_t kernel_info_cnt = inputs[0]->dtypes().size(); + + for (const auto &input : inputs) { + MS_EXCEPTION_IF_NULL(input); + std::string param_type = input->param_type(); + std::vector dtypes = input->dtypes(); + std::vector formats = input->formats(); + if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { + MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size."; + return false; + } + + if (param_type == "dynamic") { + if (dyn_input_sizes.empty()) { + MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; + return false; + } + + for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) { + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } + dyn_input_idx++; + } else if (param_type == "required") { + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } else { + if (kernel_info_index < real_input_num) { + MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index; + kernel_info_index++; + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + inputs_device_type.push_back(type_id); + inputs_format.push_back(formats[builder_idex]); + } + } + } + + builder->SetInputsDeviceType(inputs_device_type); + builder->SetInputsFormat(inputs_format); + return true; +} + +bool SetOutputKernelBuilderInfo(const std::vector> &outputs, size_t builder_idex, + const size_t &real_output_num, + const std::shared_ptr &builder) { + // not now but in the next we need to support dynamic output case + MS_EXCEPTION_IF_NULL(builder); + + size_t output_idx = 0; + std::vector outputs_device_type; + std::vector outputs_format; + MS_EXCEPTION_IF_NULL(outputs[0]); + size_t kernel_info_cnt = outputs[0]->dtypes().size(); + + for (const auto &output : outputs) { + MS_EXCEPTION_IF_NULL(output); + if (output_idx >= real_output_num) { + MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!"; + continue; + } + size_t output_num = 0; + if (output->param_type() == "dynamic") { + if (outputs.size() > 1) { + MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!"; + } + output_num = real_output_num; + } else if (output->param_type() == "required") { + output_num = 1; + } else { + if (output_idx < real_output_num) { + MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; + output_num = 1; + } + } + + for (size_t i = 0; i < output_num; i++) { + std::vector dtypes = output->dtypes(); + std::vector formats = output->formats(); + if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { + MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size."; + return false; + } + auto type_id = DtypeToTypeId(dtypes[builder_idex]); + 
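+      // Record this output's device type and format for the kernel-info variant
+      // currently being assembled (variant index builder_idex).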
outputs_device_type.push_back(type_id); + outputs_format.push_back(formats[builder_idex]); + output_idx++; + } + } + + builder->SetOutputsFormat(outputs_format); + builder->SetOutputsDeviceType(outputs_device_type); + return true; +} + +void SetKernelBuildInfo(const std::shared_ptr &builder, Processor processor, + const std::shared_ptr &op_info_ptr) { + MS_EXCEPTION_IF_NULL(builder); + MS_EXCEPTION_IF_NULL(op_info_ptr); + + auto imply_type = op_info_ptr->imply_type(); + builder->SetProcessor(processor); + std::string fusion_type = op_info_ptr->fusion_type(); + auto iter = fusion_type_maps.find(fusion_type); + if (iter != fusion_type_maps.end()) { + builder->SetFusionType(iter->second); + } else { + if (imply_type == kAKG) { + MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type; + } + } + + if (imply_type == kAKG) { + builder->SetKernelType(AKG_KERNEL); + } else if (imply_type == kAICPU) { + builder->SetKernelType(AICPU_KERNEL); + } else { + builder->SetKernelType(TBE_KERNEL); + } +} + +bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, + std::vector> *const kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + std::vector> inputs = op_info_ptr->inputs_ptr(); + std::vector> outputs = op_info_ptr->outputs_ptr(); + std::vector dyn_input_sizes; + auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("dyn_input_sizes") != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr("dyn_input_sizes")); + } + if (inputs.size() > 0) { + MS_EXCEPTION_IF_NULL(inputs[0]); + size_t kernel_info_cnt = inputs[0]->dtypes().size(); + for (size_t j = 0; j < kernel_info_cnt; j++) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + + if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed."; + return false; + } + + if (outputs.size() > 0) { + if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; + return false; + } + } + + kernel_info_list->push_back(builder->Build()); + } + } else if (outputs.size() > 0) { + MS_EXCEPTION_IF_NULL(outputs[0]); + size_t kernel_info_cnt = outputs[0]->dtypes().size(); + for (size_t j = 0; j < kernel_info_cnt; j++) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + + if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { + MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; + return false; + } + + kernel_info_list->push_back(builder->Build()); + } + } else { + if (processor == AICPU) { + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + SetKernelBuildInfo(builder, processor, op_info_ptr); + kernel_info_list->push_back(builder->Build()); + } + } + return true; +} + +void SaveJsonInfo(const std::string &json_name, const std::string &info) { + char real_path[PATH_MAX] = {0}; + std::string path = kCceKernelMeta + json_name + kInfoSuffix; + if (path.size() > PATH_MAX) { + MS_LOG(DEBUG) << "file path " << path << " is too 
long."; + return; + } + std::ofstream filewrite; + filewrite.open(path); + if (!filewrite.is_open()) { + return; + } + filewrite << info << std::endl; + filewrite.close(); +#if defined(_WIN32) || defined(_WIN64) + if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) { + MS_LOG(DEBUG) << "dir " << path << " does not exit."; + return; + } +#else + if (nullptr == realpath(path.c_str(), real_path)) { + MS_LOG(DEBUG) << "dir " << path << " does not exit."; + return; + } +#endif + MS_LOG(INFO) << "real path is :" << real_path; + if (chmod(real_path, S_IRUSR) == -1) { + MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail."; + } +} + +std::string GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + + case Processor::AICPU: + device = kProcessorAiCpu; + break; + + case Processor::CUDA: + device = kProcessorCuda; + break; + + default: + MS_LOG(DEBUG) << "Unknown processor type."; + break; + } + return device; +} + +bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b) { + if (shape_a.size() != shape_b.size()) { + return false; + } + for (size_t i = 0; i < shape_a.size(); ++i) { + if (shape_a[i] != shape_b[i]) { + return false; + } + } + return true; +} + +int Sign(float x) { + if (x > 0) { + return 1; + } + if (x < 0) { + return -1; + } + return 0; +} + +void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim) { + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + std::unordered_map index_map; + size_t unique_indices_size = 0; + for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { + int index = origin_sparse_grad.indices_[i]; + if (index < 0 || IntToSize(index) >= first_dim) { + continue; + } + auto iter = index_map.find(index); + if (iter == index_map.end()) { + index_map[index] = unique_indices_size; + unique_grad->indices_[unique_indices_size] = index; + size_t start_index = unique_indices_size * outer_dim; + size_t end_index = start_index + outer_dim; + for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { + unique_grad->value_[j] = origin_sparse_grad.value_[k]; + } + unique_indices_size++; + } else { + size_t first_index = iter->second; + size_t start_index = first_index * outer_dim; + size_t end_index = start_index + outer_dim; + for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { + unique_grad->value_[j] += origin_sparse_grad.value_[k]; + } + } + } + unique_grad->indices_size_ = unique_indices_size; +} + +struct WorkerParamsForReduceSparseGradient { + size_t slice_start_{0}; + size_t slice_end_{0}; + size_t max_length_{0}; + size_t outer_dim_{0}; + std::vector> *sorted_indices_{nullptr}; + std::vector *slice_positions_{nullptr}; + float *src_value_{nullptr}; + SparseGradient *unique_grad_{nullptr}; +}; + +void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) { + MS_EXCEPTION_IF_NULL(param.sorted_indices_); + MS_EXCEPTION_IF_NULL(param.slice_positions_); + MS_EXCEPTION_IF_NULL(param.src_value_); + MS_EXCEPTION_IF_NULL(param.unique_grad_); + auto outer_dim = param.outer_dim_; + auto &sorted_indices = *(param.sorted_indices_); + auto &slice_positions = 
*(param.slice_positions_); + auto unique_grad = param.unique_grad_; + for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) { + size_t cur_pos = slice_positions[slice_id]; + int index = sorted_indices[cur_pos].first; + unique_grad->indices_[slice_id] = index; + size_t start_index = slice_id * outer_dim; + auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float), + param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + cur_pos++; + size_t end_pos; + if (slice_id + 1 < slice_positions.size()) { + end_pos = slice_positions[slice_id + 1]; + } else { + end_pos = sorted_indices.size(); + } + while (cur_pos < end_pos) { + for (size_t i = 0; i < outer_dim; ++i) { + unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i]; + } + cur_pos++; + } + } +} + +void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, + size_t outer_dim, std::vector> *sorted_indices, + std::vector *slice_positions) { + MS_LOG(DEBUG) << "Start"; + size_t thread_num = 24; + if (slice_positions->size() < thread_num) { + thread_num = slice_positions->size(); + } + size_t stride = (slice_positions->size() + thread_num - 1) / thread_num; + thread_num = (slice_positions->size() + stride - 1) / stride; + std::vector threads; + size_t max_length = sorted_indices->size() * outer_dim; + for (size_t i = 0; i < thread_num; ++i) { + size_t slice_start = i * stride; + size_t slice_end = 0; + if (i == thread_num - 1) { + slice_end = slice_positions->size(); + } else { + slice_end = slice_start + stride; + } + WorkerParamsForReduceSparseGradient params{ + slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_, + unique_grad}; + threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params)); + } + for (size_t i = 0; i < thread_num; ++i) { + threads[i].join(); + } + MS_LOG(DEBUG) << "End"; +} + +void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim, bool use_multi_threads) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + std::vector> sorted_indices; + sorted_indices.reserve(origin_sparse_grad.indices_size_); + for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { + int index = origin_sparse_grad.indices_[i]; + if (index >= 0 && IntToSize(index) < first_dim) { + sorted_indices.emplace_back(std::pair(index, i * outer_dim)); + } + } + std::sort( + sorted_indices.begin(), sorted_indices.end(), + [](const std::pair &left, const std::pair &right) { return left.first < right.first; }); + int last_index = 0; + std::vector slice_positions; + slice_positions.reserve(sorted_indices.size()); + for (size_t i = 0; i < sorted_indices.size(); ++i) { + if (i == 0 || last_index != sorted_indices[i].first) { + slice_positions.emplace_back(i); + } + last_index = sorted_indices[i].first; + } + if (use_multi_threads) { + RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions); + } else { + size_t max_length = sorted_indices.size() * outer_dim; + 
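+    // Single-threaded fallback: one worker reduces every slice in [0, slice_positions.size()).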
WorkerParamsForReduceSparseGradient params{0, + slice_positions.size(), + max_length, + outer_dim, + &sorted_indices, + &slice_positions, + origin_sparse_grad.value_, + unique_grad}; + WorkerForReduceSparseGradient(params); + } + unique_grad->indices_size_ = slice_positions.size(); + MS_LOG(DEBUG) << "End"; +} + +void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, + SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim) { + MS_LOG(DEBUG) << "Start"; + if (unique_slice_grads.empty()) { + return; + } + size_t index_data_size = outer_dim * sizeof(float); + size_t unique_indices_size = 0; + for (size_t i = 0; i < unique_slice_grads.size(); ++i) { + auto &slice_grad = unique_slice_grads[i]; + auto ret_code = memcpy_s(tmp_grad->value_ + unique_indices_size * outer_dim, + (tmp_grad->indices_size_ - unique_indices_size) * index_data_size, slice_grad->value_, + slice_grad->indices_size_ * index_data_size); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + ret_code = + memcpy_s(tmp_grad->indices_ + unique_indices_size, (tmp_grad->indices_size_ - unique_indices_size) * sizeof(int), + slice_grad->indices_, slice_grad->indices_size_ * sizeof(int)); + if (ret_code != EOK) { + MS_LOG(EXCEPTION) << "Failed to copy data!"; + } + unique_indices_size += slice_grad->indices_size_; + } + tmp_grad->indices_size_ = unique_indices_size; + ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim); + MS_LOG(DEBUG) << "End"; +} + +void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, + SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) { + MS_LOG(DEBUG) << "Start"; + MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); + MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); + MS_EXCEPTION_IF_NULL(unique_grad); + MS_EXCEPTION_IF_NULL(unique_grad->value_); + MS_EXCEPTION_IF_NULL(unique_grad->indices_); + MS_EXCEPTION_IF_NULL(tmp_grad); + MS_EXCEPTION_IF_NULL(tmp_grad->value_); + MS_EXCEPTION_IF_NULL(tmp_grad->indices_); + size_t thread_num = 24; + if (origin_sparse_grad.indices_size_ < thread_num) { + thread_num = origin_sparse_grad.indices_size_; + } + size_t thread_indices_size = origin_sparse_grad.indices_size_ / thread_num; + size_t left_indices_size = origin_sparse_grad.indices_size_ % thread_num; + std::vector threads; + threads.reserve(thread_num); + std::vector> unique_slice_grads; + for (size_t i = 0; i < thread_num; ++i) { + size_t indices_size = thread_indices_size; + if (i == thread_num - 1) { + indices_size = thread_indices_size + left_indices_size; + } + size_t value_offset = i * thread_indices_size * outer_dim; + size_t indices_offset = i * thread_indices_size; + auto slice_grad = SparseGradient( + {origin_sparse_grad.value_ + value_offset, origin_sparse_grad.indices_ + indices_offset, indices_size}); + unique_slice_grads.emplace_back(std::make_shared()); + unique_slice_grads[i]->value_ = unique_grad->value_ + value_offset; + unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset; + unique_slice_grads[i]->indices_size_ = indices_size; + threads.emplace_back( + std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false)); + } + for (size_t i = 0; i < thread_num; ++i) { + threads[i].join(); + } + ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim); + MS_LOG(DEBUG) << "End"; +} + +std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index) { + 
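+  // Resolve the (node, output-index) pair that actually produces input "index",
+  // following the CNode's inputs through AnfAlgo::VisitKernel.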
MS_EXCEPTION_IF_NULL(anf_node); + + if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { + MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; + } + + auto cnode = anf_node->cast(); + if (cnode == nullptr) { + return AnfAlgo::VisitKernel(anf_node, 0); + } else { + return AnfAlgo::VisitKernel(anf_node->cast()->input(index + 1), 0); + } +} + +std::vector>> GetInputIndex(const std::vector &node_list, + const std::vector &input_list) { + std::vector>> input_index; + for (size_t i = 0; i < input_list.size(); ++i) { + auto const &input = input_list[i]; + MS_EXCEPTION_IF_NULL(input); + bool found = false; + // using NodeUsersMap = std::unordered_map>>; + auto mng = input->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); + const NodeUsersMap &users = mng->node_users(); + auto input_users = users.find(input); + if (input_users == users.end() || input_users->second.empty()) { + MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" + << input->func_graph()->ToString() << "] has no users."; + } + + for (auto const &input_user : input_users->second) { + for (auto const &anf_node : node_list) { + if (anf_node != input_user.first) { + continue; + } + + std::vector dyn_input_sizes; + auto prim = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(prim); + if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(prim->GetAttr(kAttrDynInputSizes)); + } + + if (dyn_input_sizes.empty()) { + input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); + found = true; + break; + } else { + int used_as_idx = input_user.second - 1; + int accum_idx = 0; + size_t dyn_i = 0; + for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { + accum_idx += dyn_input_sizes[dyn_i]; + if (used_as_idx < accum_idx) { + input_index.push_back(std::make_pair( + anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); + break; + } + } + if (dyn_i != dyn_input_sizes.size()) { + found = true; + break; + } + } + } + if (found) { + break; + } + } + + if (!found) { + MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" + << input->func_graph()->ToString() << "] found no related kernel info."; + } + } + return input_index; +} + +std::vector> GetOutputIndex(const std::vector &node_list, + const std::vector &input_list, + const std::vector &output_list) { + std::vector> output_index; + for (size_t i = 0; i < output_list.size(); ++i) { + auto const &output = output_list[i]; + MS_EXCEPTION_IF_NULL(output); + bool found = false; + auto pree_node = AnfAlgo::VisitKernel(output, 0); + auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); + if (pos != std::end(node_list)) { + output_index.push_back(pree_node); + continue; + } + auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); + if (ret != std::end(input_list)) { + output_index.push_back(std::make_pair(pree_node.first, 0)); + found = true; + } + if (!found) { + MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" + << output->func_graph()->ToString() << "] found no related kernel info."; + } + } + return output_index; +} + +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list) { + MS_EXCEPTION_IF_NULL(node_list); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector node_lists = TopoSort(func_graph->get_return()); + for (auto const &node : node_lists) { + if 
(!AnfAlgo::IsRealKernel(node) || !node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsValueNode(cnode->input(kAnfPrimitiveIndex))) { + node_list->push_back(node); + } + } +} + +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, + std::vector *input_list, std::vector *output_list) { + MS_EXCEPTION_IF_NULL(node_list); + MS_EXCEPTION_IF_NULL(input_list); + MS_EXCEPTION_IF_NULL(output_list); + MS_EXCEPTION_IF_NULL(func_graph); + + GetValidKernelNodes(func_graph, node_list); + + auto parameters = func_graph->parameters(); + input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); + + auto func_output = func_graph->output(); + MS_EXCEPTION_IF_NULL(func_output); + if (func_output->isa()) { + // multi output. + auto cnode = func_output->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { + auto input_node = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(input_node); + output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); + } + } else { + // single output. + output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); + } + } else { + // single output. + output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); + } +} + +bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(node_json); + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->size()) { + MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" + << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; + } + + auto input_node = cnode->input(input_idx + 1); + if (!IsValueNode(input_node)) { + return false; + } + + auto tensor = GetValueNode(input_node); + if (tensor == nullptr) { + return false; + } + + auto type_id = tensor->data_type(); + auto *data = tensor->data_c(); + MS_EXCEPTION_IF_NULL(data); + if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { + // not const tensor. 
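+    // Only a single-element tensor is treated as a scalar constant; larger tensors
+    // fall through with a warning and only their first element is used below.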
+ MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]"; + } + + if (type_id == kFloat32->type_id()) { + float *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = val[0]; + MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; + return true; + } else if (type_id == kFloat16->type_id()) { + float16 *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = static_cast(val[0]); + MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; + return true; + } else if (type_id == kInt32->type_id()) { + int *val = static_cast(data); + MS_EXCEPTION_IF_NULL(val); + (*node_json)["value"] = val[0]; + MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; + return true; + } + MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; + return false; +} + +void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node_list); + auto output = func_graph->output(); + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::IsRealKernel(output)) { + // single output. + node_list->push_back(std::make_pair(output, 0)); + return; + } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + // multi output. + auto &inputs = output_cnode->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); + node_list->push_back(in_with_idx); + } + return; + } + MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) + << " of graph: " << func_graph->ToString(); +} + +bool IsWeightBoundary(const AnfNodePtr &node) { + if (node->isa()) { + return true; + } + if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + return true; + } + return false; +} + +void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, + size_t total_compute_size) { + const size_t kThreadNum = 24; + std::vector threads; + threads.reserve(kThreadNum); + size_t start = 0; + size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum; + while (start < total_compute_size) { + size_t end = (start + once_compute_size) > total_compute_size ? 
total_compute_size : (start + once_compute_size);
+    threads.emplace_back(std::thread(func, params, start, end));
+    start += once_compute_size;
+  }
+  for (size_t i = 0; i < threads.size(); ++i) {
+    threads[i].join();
+  }
+}
+
+std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
+  if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
+      AnfAlgo::GetInputTensorNum(cnode) != 1) {
+    MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString()
+                      << "] should have a single input and a single output.";
+  }
+  std::vector<int> axis;
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto axis_attr = primitive->GetAttr(kAxis);
+  if (axis_attr == nullptr) {
+    MS_LOG(ERROR) << "This node doesn't have the axis attr.";
+    return std::vector<int>();
+  }
+  auto type = axis_attr->type();
+  MS_EXCEPTION_IF_NULL(type);
+  std::vector<int> axis_list;
+  if (type->ToString() == kTypeInt32) {
+    axis_list.emplace_back(GetValue<int>(axis_attr));
+  } else {
+    axis_list = GetValue<std::vector<int>>(axis_attr);
+  }
+  for (const auto &elem : axis_list) {
+    // Normalize negative axes to their positive equivalents.
+    if (elem < 0) {
+      axis.emplace_back(input_shape.size() + elem);
+    } else {
+      axis.emplace_back(elem);
+    }
+  }
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
+  return axis;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
new file mode 100644
index 0000000000..8c9ea84b34
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h
@@ -0,0 +1,145 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/oplib/opinfo.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +constexpr auto kCceKernelMeta = "./kernel_meta/"; +constexpr auto kGpuKernelMeta = "./cuda_meta"; +constexpr auto kProcessorAiCore = "aicore"; +constexpr auto kProcessorAiCpu = "aicpu"; +constexpr auto kProcessorCuda = "cuda"; +constexpr auto kJsonSuffix = ".json"; +constexpr auto kInfoSuffix = ".info"; +constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; +constexpr auto kAkgModule = "_akg"; +constexpr auto kArgDataformat = "data_format"; + +const std::vector support_devices = {"aicore", "aicpu", "cuda"}; + +struct KernelMetaInfo { + uintptr_t func_stub_; + uint32_t block_dim_; +}; +using KernelMetaPtr = std::shared_ptr; + +class KernelMeta { + public: + KernelMeta() = default; + void Initialize(); + void RemoveKernelCache(); + std::string Search(const std::string &kernel_name) const; + bool Insert(const std::string &kernel_name, const std::string &kernel_json); + std::string GetKernelMetaPath() { return kernel_meta_path_; } + + static KernelMeta *GetInstance() { + static KernelMeta kernel_meta; + return &kernel_meta; + } + ~KernelMeta() = default; + + private: + bool initialized_ = false; + std::string kernel_meta_path_; + std::unordered_map kernel_meta_map_; +}; + +struct SparseGradient { + float *value_; + int *indices_; + size_t indices_size_; +}; + +struct MultiThreadComputeParams { + float *var_; + float *accum_; + float *linear_; + float *m_; + float *m_t_; + float *v_; + float lr_; + float l1_; + float l2_; + float lr_power_; + float beta1_; + float beta2_; + float epsilon_; + SparseGradient sparse_grad_; + size_t var_first_dim_size_; + size_t var_outer_dim_size_; + bool use_nesterov_; +}; +using MultiThreadComputeFunc = std::function; + +bool CheckCache(const std::string &kernel_name); +KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); +KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); +TypeId DtypeToTypeId(const std::string &dtypes); +std::string Dtype2ShortType(const std::string &dtypes); +std::string TypeId2String(TypeId type_id); +size_t GetDtypeNbyte(const std::string &dtypes); +bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, + std::vector> *const kernel_info_list); +void SaveJsonInfo(const std::string &json_name, const std::string &info); +std::string GetProcessor(const AnfNodePtr &anf_node); +bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); +int Sign(float x); +void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim); +void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim, bool use_multi_threads = true); +std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); +std::vector>> GetInputIndex(const std::vector &node_list, + const std::vector &input_list); +std::vector> GetOutputIndex(const std::vector &node_list, + const std::vector &input_list, + const std::vector &output_list); +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, 
+ std::vector *input_list, std::vector *output_list); +void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list); +bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); +void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list); +bool IsWeightBoundary(const AnfNodePtr &node); +void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, + size_t total_compute_size); +void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, + size_t outer_dim, std::vector> *sorted_indices, + std::vector *slice_positions); +void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, + SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, + size_t outer_dim); +void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, + SparseGradient *unique_grad, size_t first_dim, size_t outer_dim); +std::vector GetReduceAttrAxis(const CNodePtr &cnode); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc new file mode 100644 index 0000000000..1300847d40 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/kernel_compiler/cpu/addn_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
+  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
+}
+
+bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                           const std::vector<kernel::AddressPtr> & /*workspace*/,
+                           const std::vector<kernel::AddressPtr> &outputs) {
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+
+  // Walk the (expanded) 4-D output shape and sum the matching element of every input.
+  size_t offset = 0;
+  for (size_t i = 0; i < output_shape_[0]; ++i) {
+    for (size_t j = 0; j < output_shape_[1]; ++j) {
+      for (size_t k = 0; k < output_shape_[2]; ++k) {
+        for (size_t m = 0; m < output_shape_[3]; ++m) {
+          float sum = 0;
+          for (size_t index = 0; index < input_num_; ++index) {
+            auto input_addr = reinterpret_cast<float *>(inputs[index]->addr);
+            sum += input_addr[offset];
+          }
+          output_addr[offset++] = sum;
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape.size() > 4) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but AddNCPUKernel only supports 4D or lower.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h
new file mode 100644
index 0000000000..925f0fab50
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/addn_cpu_kernel.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class AddNCPUKernel : public CPUKernel { + public: + AddNCPUKernel() : input_num_(0) {} + ~AddNCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + size_t input_num_; + std::vector output_shape_; +}; + +MS_REG_CPU_KERNEL(AddN, + KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AddNCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc new file mode 100644 index 0000000000..55afecb8fa --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/allgather_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto kRanksGroup = "group"; +constexpr auto kAllGatherInputNum = 1; +} // namespace + +void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != kAllGatherInputNum) { + MS_LOG(EXCEPTION) << "allgather input num:" << input_num; + } + + auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); + if (ranks_group != nullptr) { + ranks_group_ = GetValue>(ranks_group); + } else { + MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; + } +} + +bool AllGatherCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto input_data_num = inputs[0]->size / sizeof(float); + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h new file mode 100644 index 0000000000..42c83ccf0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/allgather_cpu_kernel.h @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class AllGatherCPUKernel : public CPUKernel { + public: + AllGatherCPUKernel() = default; + ~AllGatherCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector ranks_group_; +}; + +MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AllGatherCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc new file mode 100644 index 0000000000..c1ff8d54bd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ApplyMomentumCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} + +bool ApplyMomentumCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector & /*outputs*/) { + if (inputs.size() < 5) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[3]->size) { + MS_LOG(EXCEPTION) << "error input data size!"; + } + auto weight = reinterpret_cast(inputs[0]->addr); + auto accumulate = reinterpret_cast(inputs[1]->addr); + float learning_rate = reinterpret_cast(inputs[2]->addr)[0]; + auto gradient = reinterpret_cast(inputs[3]->addr); + float moment = reinterpret_cast(inputs[4]->addr)[0]; + size_t elem_num = inputs[0]->size / sizeof(float); + for (size_t i = 0; i < elem_num; ++i) { + accumulate[i] = accumulate[i] * moment + gradient[i]; + weight[i] -= accumulate[i] * learning_rate; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h new file mode 100644 index 0000000000..23e8488890 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ApplyMomentumCPUKernel : public MKLCPUKernel { + public: + ApplyMomentumCPUKernel() = default; + ~ApplyMomentumCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ApplyMomentumCPUKernel); +MS_REG_CPU_KERNEL(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + ApplyMomentumCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc new file mode 100644 index 0000000000..d67c4d47ff --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/argmax_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (shape.size() != 2) { + MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis != -1 && axis != 1) { + MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis; + } +} + +bool ArgmaxCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + size_t row_start = 0; + for (size_t i = 0; i < batch_size_; ++i) { + size_t max_index = 0; + float max_value = input[row_start]; + for (size_t j = 1; j < class_num_; ++j) { + size_t index = row_start + j; + if (input[index] > max_value) { + max_value = input[index]; + max_index = j; + } + } + output[i] = SizeToInt(max_index); + row_start += class_num_; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h new file mode 100644 index 0000000000..3883344f96 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
new file mode 100644
index 0000000000..3883344f96
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmax_cpu_kernel.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class ArgmaxCPUKernel : public CPUKernel {
+ public:
+  ArgmaxCPUKernel() = default;
+  ~ArgmaxCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  size_t class_num_{0};
+  size_t batch_size_{0};
+};
+
+MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                  ArgmaxCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc
new file mode 100644
index 0000000000..f42bb6807d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.cc
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/bias_add_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  if (input_shape_.size() == 4) {
+    data_shape_ = 4;
+  } else if (input_shape_.size() == 2) {
+    data_shape_ = 2;
+  } else {
+    MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC";
+  }
+  if (input_shape_.size() != 2 && input_shape_.size() != 4) {
+    MS_LOG(EXCEPTION) << "bias add input shape should be NCHW or NC";
+  }
+  if (bias_shape_.size() != 1) {
+    MS_LOG(EXCEPTION) << "bias shape invalid";
+  }
+  if (input_shape_[1] != bias_shape_[0]) {
+    MS_LOG(EXCEPTION) << "bias shape not match";
+  }
+}
+
+bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                              const std::vector<AddressPtr> &outputs) {
+  if (inputs.size() != 2 || outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "inputs and outputs size not supported";
+  }
+
+  auto src_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto bias_addr = reinterpret_cast<float *>(inputs[1]->addr);
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+
+  if (data_shape_ == 4) {
+    size_t h_size = input_shape_[3];
+    size_t c_size = input_shape_[2] * h_size;
+    size_t n_size = input_shape_[1] * c_size;
+    size_t hw_size = input_shape_[2] * input_shape_[3];
+    size_t n_offset = 0;
+    for (size_t n = 0; n < input_shape_[0]; ++n) {
+      size_t c_offset = 0;
+      for (size_t c = 0; c < input_shape_[1]; ++c) {
+        for (size_t hw = 0; hw < hw_size; ++hw) {
+          size_t offset = n_offset + c_offset + hw;
+          output_addr[offset] = src_addr[offset] + bias_addr[c];
+        }
+        c_offset += c_size;
+      }
+      n_offset += n_size;
+    }
+  } else {
+    size_t n_offset = 0;
+    for (size_t n = 0; n < input_shape_[0]; ++n) {
+      for (size_t c = 0; c < input_shape_[1]; ++c) {
+        output_addr[n_offset + c] = src_addr[n_offset + c] + bias_addr[c];
+      }
+      n_offset += input_shape_[1];
+    }
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
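
The NCHW branch above relies on the usual row-major layout: element (n, c, h, w) lives at offset ((n*C + c)*H + h)*W + w, so every element of channel c within a batch shares the single value bias[c]. A standalone sketch under that assumption (helper name hypothetical):

    #include <cstddef>

    // Add bias[c] to every element of channel c in an NCHW tensor.
    void BiasAddNCHW(const float *src, const float *bias, float *dst,
                     size_t N, size_t C, size_t H, size_t W) {
      for (size_t n = 0; n < N; ++n) {
        for (size_t c = 0; c < C; ++c) {
          const size_t base = (n * C + c) * H * W;
          for (size_t hw = 0; hw < H * W; ++hw) {
            dst[base + hw] = src[base + hw] + bias[c];
          }
        }
      }
    }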
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h
new file mode 100644
index 0000000000..c572f68230
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_cpu_kernel.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_
+#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class BiasAddCPUKernel : public CPUKernel {
+ public:
+  BiasAddCPUKernel() = default;
+  ~BiasAddCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  uint8_t data_shape_{0};
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> bias_shape_;
+};
+MS_REG_CPU_KERNEL(
+  BiasAdd,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BiasAddCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..8b6e2d0188
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.cc
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  if (input_shape_.size() != 4 && input_shape_.size() != 2) {
+    MS_LOG(EXCEPTION) << "input data format not supported";
+  }
+}
+
+bool BiasAddGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                                  const std::vector<AddressPtr> &outputs) {
+  if (inputs.size() != 1 || outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "input and output size not supported";
+  }
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+
+  if (input_shape_.size() == 4) {
+    size_t h_size = input_shape_[3];
+    size_t c_size = h_size * input_shape_[2];
+    size_t n_size = c_size * input_shape_[1];
+    size_t hw_size = input_shape_[2] * input_shape_[3];
+    size_t c_offset = 0;
+    for (size_t c = 0; c < input_shape_[1]; ++c) {
+      output_addr[c] = 0;
+      size_t n_offset = 0;
+      for (size_t n = 0; n < input_shape_[0]; ++n) {
+        for (size_t hw = 0; hw < hw_size; ++hw) {
+          size_t offset = c_offset + n_offset + hw;
+          output_addr[c] += input_addr[offset];
+        }
+        n_offset += n_size;
+      }
+      c_offset += c_size;
+    }
+  } else if (input_shape_.size() == 2) {
+    for (size_t c = 0; c < input_shape_[1]; ++c) {
+      output_addr[c] = 0;
+      size_t n_offset = 0;
+      for (size_t n = 0; n < input_shape_[0]; ++n) {
+        output_addr[c] += input_addr[c + n_offset];
+        n_offset += input_shape_[1];
+      }
+    }
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h
new file mode 100644
index 0000000000..a5743879a7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/bias_add_grad_cpu_kernel.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_
+#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class BiasAddGradCPUKernel : public CPUKernel {
+ public:
+  BiasAddGradCPUKernel() = default;
+  ~BiasAddGradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  std::vector<size_t> input_shape_;
+};
+MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  BiasAddGradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
new file mode 100644
index 0000000000..6776c0f154
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/concat_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+
+  axis_ = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
+  auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (axis_ < 0) {
+    axis_ = axis_ + SizeToInt(input_1_shape.size());
+  }
+  axis_ += 4 - input_1_shape.size();
+
+  auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num; i++) {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+    CPUKernelUtils::ExpandDimsTo4(&input_shape);
+    input_shape_list_.push_back(input_shape);
+  }
+
+  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
+}
+
+bool ConcatCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                             const std::vector<AddressPtr> & /*workspace*/,
+                             const std::vector<AddressPtr> &outputs) {
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto buff_size = outputs[0]->size;
+  size_t dim0 = output_shape_[0];
+  size_t dim1 = output_shape_[1];
+  size_t dim2 = output_shape_[2];
+
+  if (axis_ == 3) {
+    for (size_t i = 0; i < dim0; ++i) {
+      for (size_t j = 0; j < dim1; ++j) {
+        for (size_t k = 0; k < dim2; ++k) {
+          CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
+        }
+      }
+    }
+  } else if (axis_ == 2) {
+    for (size_t i = 0; i < dim0; ++i) {
+      for (size_t j = 0; j < dim1; ++j) {
+        CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
+      }
+    }
+  } else if (axis_ == 1) {
+    for (size_t i = 0; i < dim0; ++i) {
+      CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
+    }
+  } else if (axis_ == 0) {
+    CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
+  }
+  return true;
+}
+
+void ConcatCPUKernel::CopyDataToOutput(const std::vector<AddressPtr> &inputs, size_t dim0, size_t dim1,
+                                       size_t dim2, float **output_addr, size_t *buff_size) {
+  for (size_t i = 0; i < input_shape_list_.size(); ++i) {
+    auto input_i_shape = input_shape_list_[i];
+    auto input_i_addr = reinterpret_cast<float *>(inputs[i]->addr);
+
+    size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_);
+    num *= input_i_shape[axis_];
+    auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0);
+    auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float));
+    if (ret != EOK) {
+      MS_LOG(EXCEPTION) << "memcpy failed.";
+    }
+    *output_addr += num;
+    *buff_size -= num * sizeof(float);
+  }
+}
+
+void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape.size() > 4) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel only supports 4d or lower.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
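
CopyDataToOutput explains the loop structure in Launch: for each fixed index prefix in front of the concat axis, one contiguous run from every input is appended to the output back to back, so only the dimensions before the axis need explicit iteration. A standalone sketch of that pattern (names hypothetical):

    #include <cstring>
    #include <vector>

    // For each prefix, append one contiguous run per input to the output.
    void ConcatAxisRuns(const std::vector<const float *> &inputs,
                        const std::vector<size_t> &run_lens,  // elements per run, per input
                        size_t prefix_count, float *out) {
      for (size_t p = 0; p < prefix_count; ++p) {
        for (size_t i = 0; i < inputs.size(); ++i) {
          std::memcpy(out, inputs[i] + p * run_lens[i], run_lens[i] * sizeof(float));
          out += run_lens[i];
        }
      }
    }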
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
new file mode 100644
index 0000000000..94e4ad40f3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class ConcatCPUKernel : public CPUKernel {
+ public:
+  ConcatCPUKernel() : axis_(0) {}
+  ~ConcatCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  void CheckParam(const CNodePtr &kernel_node);
+  void CopyDataToOutput(const std::vector<AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
+                        float **output_addr, size_t *buff_size);
+  int axis_;
+  std::vector<std::vector<size_t>> input_shape_list_;
+  std::vector<size_t> output_shape_;
+};
+
+MS_REG_CPU_KERNEL(Concat,
+                  KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ConcatCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
new file mode 100644
index 0000000000..fb9398e7c4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
@@ -0,0 +1,80 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  size_t type_size = sizeof(float);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
+    std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
+    size_t tensor_size =
+      shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+    input_size_list_.emplace_back(tensor_size);
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
+    std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
+    size_t tensor_size =
+      shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+    output_size_list_.emplace_back(tensor_size);
+  }
+}
+
+void CPUKernel::Init(const CNodePtr &kernel_node) {
+  InitKernel(kernel_node);
+  InitInputOutputSize(kernel_node);
+}
+
+void CPUKernelUtils::ExpandDimsTo4(std::vector<size_t> *shape) {
+  auto len = shape->size();
+  if (len < 4) {
+    for (size_t i = 0; i < 4 - len; ++i) {
+      shape->insert(shape->begin(), 1);
+    }
+  }
+}
+
+size_t CPUKernelUtils::CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2,
+                                  size_t dim3) {
+  size_t offset = dim0 * shape[1] * shape[2] * shape[3] + dim1 * shape[2] * shape[3] + dim2 * shape[3] + dim3;
+  return offset;
+}
+
+size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector<size_t> &shape, int axis) {
+  if (axis < 0) {
+    axis = axis + SizeToInt(shape.size());
+  }
+  size_t result = 1;
+  for (int j = 3; j > axis; --j) {
+    result *= shape[j];
+  }
+  return result;
+}
+
+void CPUKernelUtils::GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num) {
+  size_t accumulation = 1;
+  element_num->emplace_back(1);
+  for (size_t i = shape.size() - 1; i > 0; --i) {
+    accumulation *= shape[i];
+    element_num->emplace_back(accumulation);
+  }
+  std::reverse(element_num->begin(), element_num->end());
+}
+}  // namespace kernel
+}  // namespace mindspore
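
CalcOffset and GetElementNumOnAxis encode plain row-major arithmetic on the padded 4-d shape: the offset of (d0, d1, d2, d3) is ((d0*s1 + d1)*s2 + d2)*s3 + d3, and the element count "after" an axis is the product of the trailing dimensions. A small self-checking sketch under those assumptions (names hypothetical):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Row-major offset of (d0, d1, d2, d3) in a 4-d shape, as CalcOffset computes it.
    size_t Offset4(const std::vector<size_t> &s, size_t d0, size_t d1, size_t d2, size_t d3) {
      return ((d0 * s[1] + d1) * s[2] + d2) * s[3] + d3;
    }

    int main() {
      std::vector<size_t> shape{2, 3, 4, 5};
      // The last element of a 2x3x4x5 block sits at total size - 1.
      assert(Offset4(shape, 1, 2, 3, 4) == 2 * 3 * 4 * 5 - 1);
      // Elements after axis 1, as GetElementNumOnAxis(shape, 1) returns: 4 * 5.
      size_t n = 1;
      for (int j = 3; j > 1; --j) n *= shape[j];
      assert(n == 20);
      return 0;
    }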
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
new file mode 100644
index 0000000000..f2aa292c6e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
+
+#include <string>
+#include <vector>
+#include <memory>
+#include <numeric>
+#include <functional>
+#include "backend/kernel_compiler/kernel.h"
+#include "ir/anf.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+using mindspore::kernel::Address;
+using mindspore::kernel::AddressPtr;
+namespace mindspore {
+namespace kernel {
+const char KSIZE[] = "ksize";
+const char STRIDE[] = "stride";
+const char STRIDES[] = "strides";
+const char DILATION[] = "dilation";
+const char PAD[] = "pad";
+const char PAD_MODE[] = "pad_mode";
+const char PADDING[] = "padding";
+const char PAD_MODE_LOWER_SAME[] = "same";
+const char PAD_MODE_LOWER_VALID[] = "valid";
+const char PAD_MODE_UPPER_SAME[] = "SAME";
+const char PAD_MODE_UPPER_VALID[] = "VALID";
+const char TRANSPOSE_A[] = "transpose_a";
+const char TRANSPOSE_B[] = "transpose_b";
+const char IS_GRAD[] = "is_grad";
+const char TRANSPOSE_NO = 'N';
+const char TRANSPOSE_YES = 'T';
+const char AXIS[] = "axis";
+const char BEGIN[] = "begin";
+const char END[] = "end";
+const char SIZE[] = "size";
+const char USE_NESTEROV[] = "use_nesterov";
+
+class CPUKernel : public kernel::KernelMod {
+ public:
+  CPUKernel() = default;
+  ~CPUKernel() override = default;
+  virtual void Init(const CNodePtr &kernel_node);
+  virtual void InitKernel(const CNodePtr &kernel_node) = 0;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {
+    return Launch(inputs, workspace, outputs);
+  }
+  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                      const std::vector<AddressPtr> &outputs) = 0;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+ protected:
+  virtual void InitInputOutputSize(const CNodePtr &kernel_node);
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+
+class CPUKernelUtils {
+ public:
+  static void ExpandDimsTo4(std::vector<size_t> *shape);
+  static size_t CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3);
+  static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
+  static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
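
InitInputOutputSize above applies one simple rule: a tensor's buffer size is sizeof(float) times the product of its dimensions, and a scalar (empty shape) still occupies one element. A minimal sketch of that rule, assuming float-only buffers as the base class does:

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Buffer size in bytes for a float tensor; an empty shape means one scalar.
    size_t TensorBytes(const std::vector<size_t> &shape) {
      return shape.empty()
               ? sizeof(float)
               : std::accumulate(shape.begin(), shape.end(), sizeof(float), std::multiplies<size_t>());
    }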
+ */ + +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +#include +#include +#include + +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace kernel { +CPUKernelFactory &CPUKernelFactory::GetInstance() { + static CPUKernelFactory instance; + return instance; +} + +void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, + CPUKernelCreator &&kernel_creator) { + (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator); +#if !defined(_WIN32) && !defined(_WIN64) + MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; +#endif +} + +std::shared_ptr CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { + auto kernel_info = apply_kernel->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(kernel_build_Info); + std::pair ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info); + if (ret_pair.first) { + return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second(); + } + return nullptr; +} + +std::pair CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name, + const KernelBuildInfo &kernel_info) { + auto iter = name_to_attr_creator_.find(kernel_name); + if (iter == name_to_attr_creator_.end()) { + MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!"; + return std::make_pair(false, 0); + } + auto creators = iter->second; + for (size_t index = 0; index < creators.size(); ++index) { + auto attr_creator = creators[index]; + if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { + return std::make_pair(true, index); + } + } + return std::make_pair(false, 0); +} + +bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { + for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { + auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; + if (kernel_info.GetInputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) + << ", register type:" << dtype; + return false; + } + } + for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { + auto dtype = kernel_attr.GetAllSame() ? 
kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; + if (kernel_info.GetOutputDeviceType(i) != dtype) { + MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) + << ", register type:" << dtype; + return false; + } + } + return true; +} + +std::vector CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) { + std::vector result; + auto iter = name_to_attr_creator_.find(kernel_name); + if (iter == name_to_attr_creator_.end()) { + MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; + return result; + } + auto creators = iter->second; + for (size_t index = 0; index < creators.size(); ++index) { + auto attr_creator = creators[index]; + result.push_back(attr_creator.first); + } + return result; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h new file mode 100644 index 0000000000..80f9a342ac --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "runtime/device/cpu/kernel_select_cpu.h" + +namespace mindspore { +namespace kernel { +using mindspore::device::cpu::KernelAttr; +using CPUKernelCreator = std::function()>; +class CPUKernelFactory { + public: + static CPUKernelFactory &GetInstance(); + void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); + std::shared_ptr Create(const std::string &kernel_name, const CNodePtr &apply_kernel); + std::vector GetSupportedKernelAttrList(const std::string &kernel_name); + + private: + CPUKernelFactory() = default; + ~CPUKernelFactory() = default; + DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) + std::pair CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); + bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); + std::map>> name_to_attr_creator_; +}; + +class CPUKernelRegistrar { + public: + CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { + CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); + } + ~CPUKernelRegistrar() = default; +}; + +#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) +#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) +#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ + static_assert(std::is_base_of::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ + []() { return std::make_shared(); }); + +#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) +#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) +#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ + static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ + #OPNAME, ATTR, []() { return std::make_shared>(); }); + +#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ + static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ + static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ + #OPNAME, ATTR, []() { return std::make_shared>(); }); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc new file mode 100644 index 0000000000..344f03cc53 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
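
Taken together, the factory and the registration macros give the extension pattern these files rely on: subclass CPUKernel, implement InitKernel and Launch, and register the class with MS_REG_CPU_KERNEL so Create can find it by op name and dtype signature. A hypothetical example of that pattern (the Square op here is illustrative, not part of this patch):

    #include "backend/kernel_compiler/cpu/cpu_kernel.h"
    #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

    namespace mindspore {
    namespace kernel {
    class SquareCPUKernel : public CPUKernel {  // hypothetical op
     public:
      void InitKernel(const CNodePtr &kernel_node) override { MS_EXCEPTION_IF_NULL(kernel_node); }
      bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                  const std::vector<AddressPtr> &outputs) override {
        auto in = reinterpret_cast<float *>(inputs[0]->addr);
        auto out = reinterpret_cast<float *>(outputs[0]->addr);
        for (size_t i = 0; i < inputs[0]->size / sizeof(float); ++i) {
          out[i] = in[i] * in[i];  // element-wise square
        }
        return true;
      }
    };
    MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      SquareCPUKernel);
    }  // namespace kernel
    }  // namespace mindspore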
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc
new file mode 100644
index 0000000000..344f03cc53
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/debug_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "common/utils.h"
+#ifdef ENABLE_DEBUGGER
+#include "debug/debugger/debugger.h"
+#endif
+
+namespace mindspore {
+namespace kernel {
+void DebugCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); }
+
+bool DebugCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                            const std::vector<AddressPtr> & /*workspace*/,
+                            const std::vector<AddressPtr> &outputs) {
+  if (inputs.size() < 1 || outputs.empty()) {
+    MS_LOG(EXCEPTION) << " input or output empty!";
+  }
+  auto val = reinterpret_cast<int *>(inputs[0]->addr);
+  MS_LOG(DEBUG) << " launch DebugCountCPUKernel val " << *val;
+
+  auto output = reinterpret_cast<int *>(outputs[0]->addr);
+  size_t elem_num = inputs[0]->size / sizeof(int);
+  for (size_t i = 0; i < elem_num; i++) {
+    output[i] = val[i];
+  }
+
+#ifdef ENABLE_DEBUGGER
+  // the debugger will suspend execution if necessary
+  Debugger::GetInstance()->PostDebugOp();
+#endif
+
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h
new file mode 100644
index 0000000000..18302e8992
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/debug_cpu_kernel.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class DebugCPUKernel : public CPUKernel {
+ public:
+  DebugCPUKernel() = default;
+  ~DebugCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+};
+
+MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), DebugCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..1bcc36faa4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.cc
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <chrono>
+#include "backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "runtime/device/cpu/mpi/mpi_adapter.h"
+
+namespace mindspore {
+namespace kernel {
+void EmbeddingLookUpCommGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
+  MS_LOG(INFO) << "split_num: " << split_num_;
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape[0] % split_num_ != 0) {
+    MS_LOG(EXCEPTION) << "Input shape[0] is " << input_shape[0] << ", but it must be a multiple of split_num.";
+  }
+}
+
+bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                              const std::vector<AddressPtr> & /*workspace*/,
+                                              const std::vector<AddressPtr> &outputs) {
+#if defined(_WIN32) || defined(_WIN64)
+  auto start_time = std::chrono::steady_clock::now();
+#else
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
+#endif
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  size_t input_size = inputs[0]->size;
+  size_t output_size = outputs[0]->size;
+  MS_LOG(DEBUG) << "input addr: " << input_addr << " input size: " << input_size;
+  MS_LOG(DEBUG) << "output addr: " << output_addr << " output size: " << output_size;
+  memset_s(output_addr, output_size, 0, output_size);
+  const std::vector<int> &rank_group = {0, 1, 2, 3, 4, 5, 6, 7};
+  size_t input_split_lens = input_size / split_num_ / sizeof(float_t);
+  size_t output_split_lens = output_size / split_num_ / sizeof(float_t);
+  auto mpi_instance = device::cpu::MPIAdapter::Instance();
+  MS_EXCEPTION_IF_NULL(mpi_instance);
+  for (int i = 0; i < split_num_; i++) {
+    mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group,
+                            input_split_lens);
+  }
+#if defined(_WIN32) || defined(_WIN64)
+  auto end_time = std::chrono::steady_clock::now();
+  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
+  MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << cost.count() << " us";
+#else
+  (void)gettimeofday(&end_time, nullptr);
+  uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << time << " us";
+#endif
+  return true;
+}
+
+void EmbeddingLookUpCommGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 1) {
+    MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCommGradCPUKernel needs 1.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
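
The Launch body above cuts the gradient into split_num contiguous slices and all-gathers each slice independently, so slice i from every rank lands in output slice i. A self-contained sketch of that layout, with the collective passed in as a callback because MPIAdapter is framework-internal:

    #include <cstddef>
    #include <functional>

    // Gather each of split_num input segments into the matching output segment.
    void SegmentedAllGather(const float *in, float *out, size_t in_elems, size_t out_elems, int split_num,
                            const std::function<void(const float *, float *, size_t)> &all_gather) {
      const size_t in_seg = in_elems / split_num;
      const size_t out_seg = out_elems / split_num;
      for (int i = 0; i < split_num; ++i) {
        all_gather(in + i * in_seg, out + i * out_seg, in_seg);  // stand-in for MPIAdapter::AllGather
      }
    }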
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h
new file mode 100644
index 0000000000..3e3807f58e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_comm_grad_cpu_kernel.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class EmbeddingLookUpCommGradCPUKernel : public CPUKernel {
+ public:
+  EmbeddingLookUpCommGradCPUKernel() : split_num_(1) {}
+  ~EmbeddingLookUpCommGradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+
+ private:
+  void CheckParam(const CNodePtr &kernel_node);
+  int split_num_;
+};
+
+MS_REG_CPU_KERNEL(EmbeddingLookupCommGrad,
+                  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  EmbeddingLookUpCommGradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
new file mode 100644
index 0000000000..b2feb9204f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc
@@ -0,0 +1,212 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <thread>
+#include <string>
+#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "runtime/device/cpu/mpi/mpi_adapter.h"
+#include "ir/primitive.h"
+
+namespace mindspore {
+namespace kernel {
+void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  input_lens_ = 1;
+  for (auto shape : input_shape_) {
+    input_lens_ = input_lens_ * shape;
+  }
+  indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  indices_lens_ = 1;
+  for (auto shape : indices_shape_) {
+    indices_lens_ = indices_lens_ * shape;
+  }
+  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  axis_ = 4 - input_shape_.size();
+  if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
+    reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
+  }
+#ifdef ENABLE_MPI
+  if (reduce_scatter_flag_) {
+    size_t gatherv2_out_lens = 1;
+    for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
+      if (i == 0) {
+        for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
+          gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
+        }
+      } else {
+        gatherv2_out_lens = gatherv2_out_lens * input_shape_[i];
+      }
+    }
+    gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float);
+    gather_v2_out_ = malloc(gatherv2_out_lens_);
+    if (gather_v2_out_ == nullptr) {
+      MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
+    }
+    auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
+    if (ret != 0) {
+      MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
+    }
+    split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
+  }
+#else
+  if (reduce_scatter_flag_) {
+    MS_LOG(EXCEPTION) << "MPI is not enabled; build with -M on before setting reduce_scatter_flag to true";
+  }
+#endif
+  if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
+    offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
+  }
+  CPUKernelUtils::ExpandDimsTo4(&input_shape_);
+  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
+}
+
+bool EmbeddingLookUpCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                      const std::vector<AddressPtr> & /*workspace*/,
+                                      const std::vector<AddressPtr> &outputs) {
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
+  size_t dim0 = input_shape_[0];
+  size_t dim1 = input_shape_[1];
+  size_t dim2 = input_shape_[2];
+  if (axis_ == 3) {
+    for (size_t i = 0; i < dim0; ++i) {
+      for (size_t j = 0; j < dim1; ++j) {
+        for (size_t k = 0; k < dim2; ++k) {
+          LookUpTable(inputs, i, j, k, &gather_out_addr);
+        }
+      }
+    }
+  } else if (axis_ == 2) {
+    for (size_t i = 0; i < dim0; ++i) {
+      for (size_t j = 0; j < dim1; ++j) {
+        LookUpTable(inputs, i, j, 0, &gather_out_addr);
+      }
+    }
+  } else if (axis_ == 1) {
+    for (size_t i = 0; i < dim0; ++i) {
+      LookUpTable(inputs, i, 0, 0, &gather_out_addr);
+    }
+  } else if (axis_ == 0) {
+    LookUpTable(inputs, 0, 0, 0, &gather_out_addr);
+  }
+#ifdef ENABLE_MPI
+  if (reduce_scatter_flag_) {
+    size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
+    size_t reduce_scatter_out_lens = one_split_lens / 8;
+    const std::vector<int> &group = {0, 1, 2, 3, 4, 5, 6, 7};
+    auto mpi_instance = device::cpu::MPIAdapter::Instance();
+    MS_EXCEPTION_IF_NULL(mpi_instance);
+    for (int i = 0; i < split_num_; i++) {
+      mpi_instance->ReduceScatter(reinterpret_cast<float *>(gather_v2_out_) + i * one_split_lens,
+                                  output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum");
+    }
+  }
+#endif
+  return true;
+}
+
+void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens,
+                      size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis,
+                      std::vector<size_t> input_shape, size_t input_lens) {
+  size_t lens = num * sizeof(float);
+  for (size_t i = 0; i < indices_lens; ++i) {
+    int indices = indices_addr[i] - offset;
+    if (indices >= 0) {
+      size_t index = IntToSize(indices);
+      if (index < input_shape[axis]) {
+        size_t pos = 0;
+        if (axis == 3) {
+          pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index);
+        } else if (axis == 2) {
+          pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0);
+        } else if (axis == 1) {
+          pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0);
+        } else if (axis == 0) {
+          pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0);
+        }
+        if (pos + num <= input_lens) {
+          auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens);
+          if (ret != EOK) {
+            MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
+          }
+        } else {
+          auto ret = memset_s(output_addr, lens, 0, lens);
+          if (ret != EOK) {
+            MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
+          }
+        }
+      } else {
+        auto ret = memset_s(output_addr, lens, 0, lens);
+        if (ret != EOK) {
+          MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
+        }
+      }
+    } else {
+      auto ret = memset_s(output_addr, lens, 0, lens);
+      if (ret != EOK) {
+        MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
+      }
+    }
+    output_addr += num;
+  }
+}
+
+void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<AddressPtr> &inputs, size_t dim0, size_t dim1,
+                                           size_t dim2, float **output_addr) {
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
+  size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_);
+  float *task_out_addr = *output_addr;
+  const size_t thread_num = 8;
+  std::thread threads[8];
+  size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
+  size_t i;
+  size_t task_offset = 0;
+  MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens;
+  for (i = 0; i < thread_num; i++) {
+    if (task_offset >= indices_lens_) {
+      break;
+    }
+    MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lens: " << task_proc_lens;
+    threads[i] =
+      std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset,
+                  task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_);
+    task_offset += task_proc_lens;
+    if (task_offset + task_proc_lens > indices_lens_) {
+      task_proc_lens = indices_lens_ - task_offset;
+    }
+  }
+  for (size_t j = 0; j < i; j++) {
+    threads[j].join();
+  }
+  *output_addr += num * indices_lens_;
+}
+
+void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape.size() > 4) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size()
+                      << ", but EmbeddingLookUpCPUKernel only supports 4d or lower.";
+  }
+
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 2) {
+    MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCPUKernel needs 2.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class EmbeddingLookUpCPUKernel : public CPUKernel { + public: + EmbeddingLookUpCPUKernel() { + axis_ = 0; + offset_ = 0; + split_num_ = 0; + input_lens_ = 0; + indices_lens_ = 0; + gatherv2_out_lens_ = 0; + reduce_scatter_flag_ = false; + gather_v2_out_ = nullptr; + } + ~EmbeddingLookUpCPUKernel() override { + if (gather_v2_out_ != nullptr) { + free(gather_v2_out_); + gather_v2_out_ = nullptr; + } + } + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr); + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector indices_shape_; + std::vector output_shape_; + int axis_; + int offset_; + int split_num_; + size_t input_lens_; + size_t indices_lens_; + size_t gatherv2_out_lens_; + bool reduce_scatter_flag_; + + void *gather_v2_out_; +}; + +MS_REG_CPU_KERNEL( + EmbeddingLookup, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + EmbeddingLookUpCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc new file mode 100644 index 0000000000..a61cd185c6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/equal_count_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void EqualCountCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} + +bool EqualCountCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + if (inputs[0]->size != inputs[1]->size) { + MS_LOG(EXCEPTION) << "input or output size!"; + } + int count = 0; + auto left = reinterpret_cast(inputs[0]->addr); + auto right = reinterpret_cast(inputs[1]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + for (size_t i = 0; i < elem_num; i++) { + if (left[i] == right[i]) { + count++; + } + } + auto output = reinterpret_cast(outputs[0]->addr); + output[0] = count; + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h new file mode 100644 index 0000000000..6e4ed6d5f1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/equal_count_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class EqualCountCPUKernel : public CPUKernel { + public: + EqualCountCPUKernel() = default; + ~EqualCountCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + EqualCountCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc new file mode 100644 index 0000000000..73b11f1c01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/gather_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis_ < 0) { + axis_ = axis_ + SizeToInt(input_shape_.size()); + } + axis_ += 4 - input_shape_.size(); + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +bool GatherV2CPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto buff_size = outputs[0]->size; + size_t dim0 = input_shape_[0]; + size_t dim1 = input_shape_[1]; + size_t dim2 = input_shape_[2]; + if (axis_ == 3) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); + } + } + } + } else if (axis_ == 2) { + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); + } + } + } else if (axis_ == 1) { + for (size_t i = 0; i < dim0; ++i) { + CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); + } + } else if (axis_ == 0) { + CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); + } + return true; +} + +void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, + size_t dim2, float **output_addr, size_t *buff_size) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto indices_addr = reinterpret_cast(inputs[1]->addr); + size_t elem_num = inputs[1]->size / 4; + size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); + for (size_t i = 0; i < elem_num; ++i) { + if (indices_addr[i] < 0) { + MS_LOG(EXCEPTION) << "The indices value is less than 0."; + } + size_t index = IntToSize(indices_addr[i]); + if (index >= input_shape_[IntToSize(axis_)]) { + auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memset failed."; + } + } else { + size_t pos = 0; + if (axis_ == 3) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); + } else if (axis_ == 2) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); + } else if (axis_ == 1) { + pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); + } else if (axis_ == 0) { + pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); + } + auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy failed."; + } + } + *output_addr += num; + *buff_size -= num * sizeof(float); + } +} // namespace kernel + +void GatherV2CPUKernel::CheckParam(const 
CNodePtr &kernel_node) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (input_shape.size() > 4) { + MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h new file mode 100644 index 0000000000..8fdac0dfde --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class GatherV2CPUKernel : public CPUKernel { + public: + GatherV2CPUKernel() : axis_(0) {} + ~GatherV2CPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, + float **output_addr, size_t *buff_size); + void CheckParam(const CNodePtr &kernel_node); + std::vector input_shape_; + std::vector indices_shape_; + std::vector output_shape_; + int axis_; +}; + +MS_REG_CPU_KERNEL( + GatherV2, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + GatherV2CPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc new file mode 100644 index 0000000000..e58b1d319c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << "conv2d only support nchw input!"; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { + MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!"; + } + if (stride_ori[0] != 1 || stride_ori[1] != 1) { + MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "conv2d dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[2]; + int dilation = dilation_ori[2]; + + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + std::vector int_padding_l; + std::vector int_padding_r; + + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_WEIGHTS, weights_desc); + AddArgument(DNNL_ARG_DST, dst_desc); +} + +bool Conv2dCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h new file mode 100644 index 0000000000..c0c64ba4da --- /dev/null +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dCPUKernel : public MKLCPUKernel { + public: + Conv2dCPUKernel() = default; + ~Conv2dCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2D, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc new file mode 100644 index 0000000000..3fa6a91405 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << ("conv2d grad filter only support nchw input!"); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[0]; + int dilation = dilation_ori[2]; + + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + std::vector int_padding_l; + std::vector int_padding_r; + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc forward_desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc( + dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc( + backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, dst_desc); + AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc); +} + +bool Conv2dGradFilterCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); 
+ SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h new file mode 100644 index 0000000000..ae8269c142 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dGradFilterCPUKernel : public MKLCPUKernel { + public: + Conv2dGradFilterCPUKernel() = default; + ~Conv2dGradFilterCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2DBackpropFilter, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dGradFilterCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc new file mode 100644 index 0000000000..1f02d70f86 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 || weight_shape.size() != 4) { + MS_LOG(EXCEPTION) << "conv2d grad filter only support nchw input!"; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); + dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); + + int kernel_size = SizeToInt(weight_shape[3]); + auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); + auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); + if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!"; + } + if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1, and dilation must be 4d!"; + } + if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { + MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; + } + int stride = stride_ori[0]; + int dilation = dilation_ori[2]; + dnnl::memory::dims strides{stride, stride}; + dnnl::memory::dims dilates{dilation - 1, dilation - 1}; + std::vector int_padding_l; + std::vector int_padding_r; + const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); + GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); + if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { + MS_LOG(EXCEPTION) << "conv2d grad get padding failed"; + } + dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; + dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; + dnnl::convolution_forward::desc forward_desc = + dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, + weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc( + dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); + + auto backward_prim_desc = + dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_DIFF_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, dst_desc); + AddArgument(DNNL_ARG_WEIGHTS, weights_desc); +} + +bool Conv2dGradInputCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); + 
SetArgumentHandle(DNNL_ARG_DIFF_SRC, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h new file mode 100644 index 0000000000..6f699130a8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class Conv2dGradInputCPUKernel : public MKLCPUKernel { + public: + Conv2dGradInputCPUKernel() = default; + ~Conv2dGradInputCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Conv2dGradInputCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc new file mode 100644 index 0000000000..626fd1934e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.cc @@ -0,0 +1,141 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h" +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { +#ifdef PLATFORM_86 + _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); + _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); +#endif + MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + auto eng = MKLKernelEngine::Get().engine(); + dnnl::stream s(eng); + dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; + if (bidirectional_) { + direction = dnnl::rnn_direction::bidirectional_concat; + } + dim src_dims = {seq_len_, batch_size_, input_size_}; + dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; + dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; + dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); + dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); + dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); + dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); + dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); + dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); + auto desc = std::make_shared(dnnl::prop_kind::forward_training, direction, src_desc, + src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + 
dst_h_desc, dst_c_desc); + prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng); + primitive_ = std::make_shared(prim_desc_); + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); +} + +bool LstmCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); + auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); + if (has_bias_) { + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + } else { + auto ret = + memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "bias memset error"; + } + } + // set handle + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h new file mode 100644 index 0000000000..761494a931 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ +#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) +#define PLATFORM_86 +#endif +#ifdef PLATFORM_86 +#include +#endif +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" +namespace mindspore { +namespace kernel { +class LstmCPUKernel : public MKLCPUKernel { + public: + LstmCPUKernel() = default; + ~LstmCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + int weight_size_ = 0; + int weight_h_size_ = 0; + int input_size_; + int hidden_size_; + int num_layers_; + int batch_size_; + int seq_len_; + int num_directions_; + bool bidirectional_; + bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_forward::primitive_desc prim_desc_; +}; + +MS_REG_CPU_KERNEL(LSTM, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc new file mode 100644 index 0000000000..56da8ec808 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h" +#include +#include +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + using tag = dnnl::memory::format_tag; + using dim = dnnl::memory::dims; + auto eng = MKLKernelEngine::Get().engine(); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); + bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); + input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); + hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); + num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); + has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); + batch_size_ = SizeToInt(src_shape[1]); + seq_len_ = SizeToInt(src_shape[0]); + num_directions_ = 1; + if (bidirectional_) { + num_directions_ = 2; + } + if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { + MS_LOG(EXCEPTION) << "error iteration shape!"; + } + if (num_layers_ <= 0) { + MS_LOG(EXCEPTION) << "layers must be greater than zero!"; + } + if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { + MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; + } + const int gate_size = 4 * hidden_size_; + for (int i = 0; i < num_layers_; ++i) { + weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); + weight_h_size_ += gate_size * hidden_size_; + } + weight_size_ = weight_size_ * num_directions_; + weight_h_size_ = weight_h_size_ * num_directions_; + dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; + if (bidirectional_) { + direction = dnnl::rnn_direction::bidirectional_concat; + } + dim src_dims = {seq_len_, batch_size_, input_size_}; + dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; + weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; + bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; + dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; + dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; + dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); + dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); + dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); + dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); + dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); + dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); + dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); + auto forward_desc = std::make_shared( + dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, + formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, + dst_c_desc); + auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, 
eng); + auto backward_desc = std::make_shared( + dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), + formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, + src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, + dst_h_desc, dst_c_desc); + prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); + primitive_ = std::make_shared(prim_backward_desc_); + + AddArgument(DNNL_ARG_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); + AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); + AddArgument(DNNL_ARG_BIAS, bias_desc); + AddArgument(DNNL_ARG_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); + AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); + AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); + AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); + AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); + AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); + AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); + AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); +} + +bool LSTMGradCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace /*workspace*/, + const std::vector &outputs) { + using dt = dnnl::memory::data_type; + using tag = dnnl::memory::format_tag; + auto eng = MKLKernelEngine::Get().engine(); + // construct fw memory + auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); + auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); + auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); + user_weights_memory.set_data_handle(inputs[3]->addr); + user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); + Reorder(&user_weights_memory, &weights_memory); + Reorder(&user_weights_h_memory, &weights_h_memory); + if (has_bias_) { + bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); + } else { + if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, + prim_backward_desc_.bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias memset error"; + } + } + // construct bw memory + auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); + auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); + auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); + auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); + auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); + 
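+  // outputs[3] is assumed to be one flat buffer packed as [dw_layer | dw_iter | db],
+  // mirroring the packed weight buffer the forward pass reads from inputs[3].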
user_diff_weights_memory.set_data_handle(outputs[3]->addr); + user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); + if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, + user_diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights grad memset error"; + } + if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, + user_diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "user weights iter grad memset error"; + } + if (has_bias_) { + diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); + } + if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, + prim_backward_desc_.diff_bias_desc().get_size())) { + MS_LOG(EXCEPTION) << "bias grad memset error"; + } + if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, + diff_weights_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights grad memset error"; + } + if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, + diff_weights_h_memory.get_desc().get_size())) { + MS_LOG(EXCEPTION) << "weights iter grad memset error"; + } + SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); + SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); + SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); + SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); + SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); + ExecutePrimitive(); + Reorder(&diff_weights_memory, &user_diff_weights_memory); + Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h new file mode 100644 index 0000000000..b95b5ba792 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class LSTMGradCPUKernel : public MKLCPUKernel { + public: + LSTMGradCPUKernel() = default; + ~LSTMGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + int weight_size_ = 0; + int weight_h_size_ = 0; + int input_size_; + int hidden_size_; + int num_layers_; + int batch_size_; + int seq_len_; + int num_directions_; + bool bidirectional_; + bool has_bias_; + dnnl::memory::dims weights_dims_; + dnnl::memory::dims weights_h_dims_; + dnnl::memory::dims bias_dims_; + dnnl::lstm_backward::primitive_desc prim_backward_desc_; +}; + +MS_REG_CPU_KERNEL(LSTMGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LSTMGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc new file mode 100644 index 0000000000..4bbaa6459f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h" +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "common/utils.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + + if (src_shape.size() != 2 || weight_shape.size() != 2 || dst_shape.size() != 2) { + MS_LOG(EXCEPTION) << "matmul invalid input size"; + } + bool trans_a = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_A); + bool trans_b = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_B); + if (trans_a) { + trans_a_ = TRANSPOSE_YES; + dim_m_ = static_cast(src_shape[1]); + dim_k_ = static_cast(src_shape[0]); + } else { + dim_m_ = static_cast(src_shape[0]); + dim_k_ = static_cast(src_shape[1]); + } + if (trans_b) { + trans_b_ = TRANSPOSE_YES; + } + dim_n_ = static_cast(dst_shape[1]); +} + +bool MatMulCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "matmul error input output size!"; + } + dnnl_dim_t lda = dim_m_; + if (trans_a_ == TRANSPOSE_NO) { + lda = dim_k_; + } + dnnl_dim_t ldb = dim_k_; + if (trans_b_ == TRANSPOSE_NO) { + ldb = dim_n_; + } + auto input_a = reinterpret_cast(inputs[0]->addr); + auto input_b = reinterpret_cast(inputs[1]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + (void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, 1.f, input_a, lda, input_b, ldb, 0.f, output, dim_n_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h new file mode 100644 index 0000000000..ef52f652d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class MatMulCPUKernel : public MKLCPUKernel { + public: + MatMulCPUKernel() = default; + ~MatMulCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + char trans_a_{TRANSPOSE_NO}; + char trans_b_{TRANSPOSE_NO}; + dnnl_dim_t dim_m_{0}; + dnnl_dim_t dim_n_{0}; + dnnl_dim_t dim_k_{0}; +}; + +MS_REG_CPU_KERNEL( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc new file mode 100644 index 0000000000..c71abe809d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" +#include +#include +#include +#include "common/utils.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" + +namespace mindspore { +namespace kernel { +void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, + const std::vector &src_shape, int kernel_size, int stride, + std::vector *padding_l, std::vector *padding_r) { + MS_EXCEPTION_IF_NULL(kernel_node); + if (src_shape.size() < 2) { + MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; + } + std::vector weight_height; + weight_height.emplace_back(src_shape[src_shape.size() - 2]); + weight_height.emplace_back(src_shape[src_shape.size() - 1]); + int rad = kernel_size / 2; + int need_pad = kernel_size - 1; + MS_LOG(INFO) << "pad mode " << pad_mode; + if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { + for (auto wh : weight_height) { + int re = (wh - 1) % stride; + int pad = std::max(rad - (re / 2), 0); + padding_r->emplace_back(pad); + pad = std::max(need_pad - pad - re, 0); + padding_l->emplace_back(pad); + } + } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { + MS_LOG(INFO) << "pad valid"; + padding_l->emplace_back(0); + padding_l->emplace_back(0); + padding_r->emplace_back(0); + padding_r->emplace_back(0); + } else { + std::vector pad = AnfAlgo::GetNodeAttr>(kernel_node, PAD); + if (pad.size() != 4) { + MS_LOG(EXCEPTION) << "wrong pad size in max pooling " << pad.size(); + } + padding_l->emplace_back(pad[0]); + padding_l->emplace_back(pad[1]); + padding_r->emplace_back(pad[2]); + padding_r->emplace_back(pad[3]); + } +} + +dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { + dnnl::memory::format_tag mem_tag; + auto dim_size = dims.size(); + if (dim_size == 4) { + mem_tag = dnnl::memory::format_tag::abcd; + } else if (dim_size == 3) { + mem_tag = dnnl::memory::format_tag::abc; + } else if (dim_size == 2) { + mem_tag = dnnl::memory::format_tag::ab; + } else if (dim_size == 1) { + mem_tag = dnnl::memory::format_tag::a; + } else { + MS_LOG(EXCEPTION) << "kernel dims invalid " << dim_size; + } + return mem_tag; +} + +dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &shape) { + dnnl::memory::dims dims; + dims.insert(dims.end(), shape.begin(), shape.end()); + dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); + dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); + return mem_desc; +} + +void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) { + arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc); +} + +void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { + auto arg_iter = arguments_.find(arg_key); + if (arg_iter != arguments_.end()) { + arg_iter->second.set_data_handle(ptr); + } +} + +void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } + +void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { + MKLKernelEngine::Get().Reorder(src_mem, dst_mem); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h new file mode 100644 index 0000000000..fc7128b10e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei 
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "dnnl.hpp" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class MKLCPUKernel : public CPUKernel { + public: + MKLCPUKernel() = default; + ~MKLCPUKernel() override = default; + + protected: + void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector &src_shape, + int kernel_size, int stride, std::vector *padding_l, std::vector *padding_r); + void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); + void SetArgumentHandle(int arg_key, void *ptr); + dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; + dnnl::memory::desc GetDefaultMemDesc(const std::vector &shape); + void ExecutePrimitive(); + std::unordered_map arguments_; + std::shared_ptr primitive_{nullptr}; + inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { + return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; + } + void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc new file mode 100644 index 0000000000..777668f960 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc
new file mode 100644
index 0000000000..777668f960
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.cc
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "utils/log_adapter.h"
+#include "dnnl.hpp"
+
+namespace mindspore {
+namespace kernel {
+void MKLKernelEngine::Execute(const std::shared_ptr<dnnl::primitive> &primitive,
+                              const std::unordered_map<int, dnnl::memory> &arguments) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  primitive->execute(stream_, arguments);
+  (void)stream_.wait();
+}
+
+dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) {
+  if (alloc) {
+    return dnnl::memory(mem_desc, engine_);
+  } else {
+    return dnnl::memory(mem_desc, engine_, nullptr);
+  }
+}
+
+void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
+  dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem);
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h
similarity index 100%
rename from mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.h
rename to mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h
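Note: MKLKernelEngine is a thin singleton around one dnnl::engine plus one dnnl::stream. For readers unfamiliar with oneDNN, a self-contained sketch of the same calls against the plain oneDNN 1.x API (independent of MindSpore, illustrative only) looks like this:

#include "dnnl.hpp"
int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);  // what engine_ holds
  dnnl::stream strm(eng);                        // what stream_ holds
  float src_buf[6] = {0, 1, 2, 3, 4, 5}, dst_buf[6] = {};
  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::ab);
  dnnl::memory src(md, eng, src_buf);  // user-owned handle, like CreateMemory(..., false)
  dnnl::memory dst(md, eng, dst_buf);
  dnnl::reorder(src, dst).execute(strm, src, dst);  // what Reorder() forwards to
  strm.wait();  // Execute() likewise waits on the stream after running a primitive
  return 0;
}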
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
new file mode 100644
index 0000000000..fddd769047
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.cc
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace kernel {
+void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
+    MS_LOG(EXCEPTION) << "mul only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
+                      << src1_shape.size();
+  }
+  if (src1_shape.size() < src0_shape.size()) {
+    for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
+      src1_shape.emplace_back(1);
+    }
+  }
+  dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
+  dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
+  dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
+  dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_mul, src0_mem_desc, src1_mem_desc, dst_mem_desc);
+  auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
+  primitive_ = std::make_shared<dnnl::binary>(prim_desc);
+  AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
+  AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
+  AddArgument(DNNL_ARG_DST, dst_mem_desc);
+}
+
+bool MulCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                          const std::vector<kernel::AddressPtr> & /*workspace*/,
+                          const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() < 2 || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "mul error input output size!";
+  }
+  SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
+  SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
+  SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+  ExecutePrimitive();
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
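Note on the shape handling above (worked example, not part of the patch): a lower-rank second input is right-padded with 1s, so multiplying a [2, 3] tensor by a [1] scalar rewrites src1_shape to [1, 1], and dnnl::binary then broadcasts each size-1 dimension across the corresponding dimension of src0. The guard at the top of InitKernel only rejects rank mismatches where src1 is not a scalar; equal-rank inputs with mismatched non-unit dimensions are left for the DNNL primitive itself to reject.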
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h
new file mode 100644
index 0000000000..182679f59d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mul_cpu_kernel.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class MulCPUKernel : public MKLCPUKernel {
+ public:
+  MulCPUKernel() = default;
+  ~MulCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+};
+
+MS_REG_CPU_KERNEL(
+  Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  MulCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_
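Note: the MS_REG_CPU_KERNEL macro above binds a kernel class to an operator name plus a dtype signature at static-initialization time; supporting another signature is just one more invocation. A sketch of what a second registration would look like (hypothetical: this patch only registers the float32 variant, and MulInt32CPUKernel is an assumed separate implementation since the DNNL primitive here is configured for f32 only):

MS_REG_CPU_KERNEL(
  Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  MulInt32CPUKernel);  // assumed class, shown only to illustrate the macro shape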
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc
new file mode 100644
index 0000000000..e4bedf23b9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h"
+#include <string>
+#include <vector>
+#include "common/utils.h"
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+  dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
+  dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
+  std::vector<int> kernel_sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, KSIZE);
+  std::vector<int> strides = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
+  if (kernel_sizes.size() != 4 || strides.size() != 4) {
+    MS_LOG(EXCEPTION) << "invalid kernel size " << kernel_sizes.size() << " or stride size " << strides.size();
+  }
+  dnnl::memory::dims strides_dims{strides[2], strides[3]};
+  dnnl::memory::dims kernels_dims{kernel_sizes[2], kernel_sizes[3]};
+  const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PADDING);
+  std::vector<int> int_padding_l;
+  std::vector<int> int_padding_r;
+  GetPadding(kernel_node, pad_mode, src_shape, kernel_sizes[3], strides[3], &int_padding_l, &int_padding_r);
+  if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
+    MS_LOG(EXCEPTION) << "pooling get padding failed";
+  }
+  dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
+  dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
+  dnnl::pooling_forward::desc desc =
+    dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc,
+                                strides_dims, kernels_dims, padding_l, padding_r);
+  auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
+  primitive_ = std::make_shared<dnnl::pooling_forward>(prim_desc);
+  AddArgument(DNNL_ARG_SRC, src_desc);
+  AddArgument(DNNL_ARG_DST, dst_desc);
+  AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
+}
+
+bool PoolingCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                              const std::vector<kernel::AddressPtr> & /*workspace*/,
+                              const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.empty() || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "error input output size!";
+  }
+  SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
+  SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
+  ExecutePrimitive();
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h
new file mode 100644
index 0000000000..8187eaffda
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class PoolingCPUKernel : public MKLCPUKernel {
+ public:
+  PoolingCPUKernel() = default;
+  ~PoolingCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+};
+
+MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  PoolingCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..8189df07ff
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.cc
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h"
+#include <algorithm>
+#include <utility>
+#include <vector>
+#include "common/utils.h"
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  src_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  dst_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+  std::vector<int> kernel_sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, KSIZE);
+  std::vector<int> strides = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
+  if (kernel_sizes.size() != 4 || strides.size() != 4 || src_shape_.size() != 4 || dst_shape_.size() != 4) {
+    MS_LOG(EXCEPTION) << "pooling grad invalid input size";
+  }
+  std::vector<int> padding_r;
+  const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PADDING);
+  kernel_size_ = kernel_sizes[3];
+  stride_ = strides[3];
+  GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r);
+}
+
+void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff,
+                                          const std::vector<std::pair<size_t, size_t>> &box,
+                                          std::vector<std::pair<size_t, float>> *row_max_pair) {
+  float max_value = 0;
+  size_t max_index = box[1].second;
+  size_t src_width = src_shape_[3];
+  size_t index_start;
+  size_t index;
+  for (size_t i = box[1].first; i < box[1].second; ++i) {
+    if ((*row_max_pair)[i].first == 0) {
+      index_start = box[0].first * src_width;
+      for (size_t j = box[0].first; j < box[0].second; ++j) {
+        index = index_start + i;
+        if (input[index] > (*row_max_pair)[i].second || j == box[0].first) {
+          (*row_max_pair)[i].second = input[index];
+          (*row_max_pair)[i].first = index;
+        }
+        index_start += src_width;
+      }
+    }
+    if ((*row_max_pair)[i].second > max_value || max_index == box[1].second) {
+      max_value = (*row_max_pair)[i].second;
+      max_index = i;
+    }
+  }
+
+  output[(*row_max_pair)[max_index].first] += diff;
+}
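+// How the gradient is routed (summary of the two helpers around this point):
+// for every output cell, `box` holds the pooling window as {rows [h0, h1),
+// cols [w0, w1)}; `row_max_pair` caches, per input column, the flattened index
+// and value of that column's maximum inside the current row band, so
+// RowPoolingGrad only scans columns to find the window maximum and then adds
+// the incoming gradient `diff` to that single input position.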
+
+void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) {
+  int src_width = SizeToInt(src_shape_[3]);
+  int src_height = SizeToInt(src_shape_[2]);
+  std::vector<std::pair<size_t, float>> row_max_pair(src_shape_[3]);
+  std::vector<std::pair<size_t, size_t>> box(2);
+  int h_start = -padding_l_[0];
+  size_t diff_index = 0;
+  for (size_t h = 0; h < dst_shape_[2]; ++h) {
+    box[0].first = IntToSize(std::max(h_start, 0));
+    box[0].second = IntToSize(std::min(h_start + kernel_size_, src_height));
+    for (size_t w = 0; w < src_shape_[3]; ++w) {
+      row_max_pair[w].first = 0;
+      row_max_pair[w].second = 0;
+    }
+    int w_start = -padding_l_[1];
+    for (size_t w = 0; w < dst_shape_[3]; ++w) {
+      box[1].first = IntToSize(std::max(w_start, 0));
+      box[1].second = IntToSize(std::min(w_start + kernel_size_, src_width));
+      RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair);
+      diff_index += 1;
+      w_start += stride_;
+    }
+    h_start += stride_;
+  }
+}
+
+bool PoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                  const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                  const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.size() < 3 || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "pooling grad error input output size!";
+  }
+
+  auto input = reinterpret_cast<float *>(inputs[0]->addr);
+  auto diff = reinterpret_cast<float *>(inputs[2]->addr);
+  auto output = reinterpret_cast<float *>(outputs[0]->addr);
+  auto ret = memset_s(output, outputs[0]->size, 0, outputs[0]->size);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "pooling grad memset error";
+  }
+  size_t src_wh = src_shape_[2] * src_shape_[3];
+  size_t dst_wh = dst_shape_[2] * dst_shape_[3];
+  for (size_t n = 0; n < src_shape_[0]; ++n) {
+    for (size_t c = 0; c < src_shape_[1]; ++c) {
+      ChannelPoolingGrad(input, diff, output);
+      input = input + src_wh;
+      output = output + src_wh;
+      diff = diff + dst_wh;
+    }
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h
new file mode 100644
index 0000000000..95a7bb3f66
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_grad_cpu_kernel.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include <utility>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class PoolingGradCPUKernel : public MKLCPUKernel {
+ public:
+  PoolingGradCPUKernel() = default;
+  ~PoolingGradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  void RowPoolingGrad(const float *input, float *output, float diff, const std::vector<std::pair<size_t, size_t>> &box,
+                      std::vector<std::pair<size_t, float>> *row_max_pair);
+  void ChannelPoolingGrad(const float *input, const float *diff, float *output);
+  int stride_{0}, kernel_size_{0};
+  std::vector<int> padding_l_;
+  std::vector<size_t> src_shape_;
+  std::vector<size_t> dst_shape_;
+};
+
+MS_REG_CPU_KERNEL(MaxPoolGrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  PoolingGradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc
new file mode 100644
index 0000000000..29ac9a1062
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.cc
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ReluCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 && src_shape.size() != 2) { + MS_LOG(EXCEPTION) << "relu kernel dims invalid " << src_shape.size(); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + + dnnl::eltwise_forward::desc desc = + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); + auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DST, src_desc); +} + +bool ReluCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h new file mode 100644 index 0000000000..a2da2480e2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ReluCPUKernel : public MKLCPUKernel { + public: + ReluCPUKernel() = default; + ~ReluCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc new file mode 100644 index 0000000000..9139aa7862 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void ReluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + if (src_shape.size() != 4 && src_shape.size() != 2) { + MS_LOG(EXCEPTION) << "relu grad kernel dims invalid " << src_shape.size(); + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + + dnnl::eltwise_forward::desc forward_desc = + dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); + auto forward_prim_desc = dnnl::eltwise_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); + + dnnl::eltwise_backward::desc backward_desc = + dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu, src_desc, src_desc, 0.0, 0.0); + auto backward_prim_desc = + dnnl::eltwise_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); + primitive_ = std::make_shared(backward_prim_desc); + + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_SRC, src_desc); + AddArgument(DNNL_ARG_DIFF_DST, src_desc); +} + +bool ReluGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 2 || outputs.empty()) { + MS_LOG(EXCEPTION) << "relu grad error input output size!"; + } + if (inputs[0]->size != outputs[0]->size) { + MS_LOG(EXCEPTION) << "relu grad error input output data size!"; + } + + SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_SRC, 
inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); + ExecutePrimitive(); + size_t mem_bits = outputs[0]->size; + auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h new file mode 100644 index 0000000000..c895ab2756 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/relu_grad_cpu_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class ReluGradCPUKernel : public MKLCPUKernel { + public: + ReluGradCPUKernel() = default; + ~ReluGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL( + ReluGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReluGradCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc new file mode 100644 index 0000000000..94271b8a69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + std::vector axis_list = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); + if (axis_list.size() != 1) { + MS_LOG(EXCEPTION) << "cpu softmax only support input axis size 1"; + } + int axis = axis_list[0]; + if (axis == -1 || axis >= SizeToInt(src_shape.size())) { + axis = SizeToInt(src_shape.size()) - 1; + } + dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + AddArgument(DNNL_ARG_SRC, src_desc); + AddArgument(DNNL_ARG_DST, src_desc); +} + +bool SoftmaxCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "softmax error input output size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); + ExecutePrimitive(); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h new file mode 100644 index 0000000000..2812dd31af --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cpu_kernel.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SoftmaxCPUKernel : public MKLCPUKernel { + public: + SoftmaxCPUKernel() = default; + ~SoftmaxCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SoftmaxCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc new file mode 100644 index 0000000000..889e2abdec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h"
+#include <cmath>
+#include <functional>
+#include <numeric>
+#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace kernel {
+void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  size_t type_size = sizeof(float);
+  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+  workspace_size_list_.emplace_back(tensor_size);
+}
+
+void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  dnnl::memory::dims mem_dims;
+  mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());
+  if (mem_dims.size() != 2) {
+    MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size();
+  }
+  batch_size_ = shape[0];
+  class_num_ = shape[1];
+  if (batch_size_ == 0 || class_num_ == 0) {
+    MS_LOG(EXCEPTION) << "invalid batch size or class num input!";
+  }
+  dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
+
+  dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
+  auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
+  primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
+
+  AddArgument(DNNL_ARG_SRC, mem_desc);
+  AddArgument(DNNL_ARG_DST, mem_desc);
+}
+
+void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels,
+                                                                float *output1, float *output2) const {
+  float epsilon = 1e-6;
+  for (size_t i = 0; i < batch_size_; ++i) {
+    output1[i] = 0;
+    float loss = 0.0;
+    for (size_t j = 0; j < class_num_; ++j) {
+      float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]);
+      output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j];
+      loss += labels[i * class_num_ + j] * logit;
+    }
+    output1[i] = -loss;
+  }
+}
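+// The math behind ForwardPostExecute, written out: despite the parameter name,
+// `logits` here already holds p = softmax(raw logits), produced by the DNNL
+// primitive into workspace[0] and clamped below by epsilon before the log, so
+//   output1[i] = loss_i = -sum_j labels[i][j] * log(max(p[i][j], epsilon))
+// and output2[i][j] = p[i][j] - labels[i][j] is the standard softmax
+// cross-entropy gradient with respect to the raw logits.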
+
+bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                                    const std::vector<kernel::AddressPtr> &workspace,
+                                                    const std::vector<kernel::AddressPtr> &outputs) {
+  if (inputs.empty() || workspace.empty() || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "error input output size!";
+  }
+  size_t batch_float_size = batch_size_ * sizeof(float);
+  size_t batch_class_float_size = class_num_ * batch_float_size;
+  if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size ||
+      inputs[1]->size != batch_class_float_size) {
+    MS_LOG(EXCEPTION) << "error input data size!";
+  }
+  if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) {
+    MS_LOG(EXCEPTION) << "error output data size!";
+  }
+  SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
+  SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr);
+  ExecutePrimitive();
+  auto labels = reinterpret_cast<float *>(inputs[1]->addr);
+  auto logits = reinterpret_cast<float *>(workspace[0]->addr);
+  auto output1 = reinterpret_cast<float *>(outputs[0]->addr);
+  auto output2 = reinterpret_cast<float *>(outputs[1]->addr);
+  ForwardPostExecute(logits, labels, output1, output2);
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h
new file mode 100644
index 0000000000..d05cb49b7b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { + public: + SoftmaxCrossEntropyWithLogitsCPUKernel() = default; + ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + private: + void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; + size_t class_num_{0}; + size_t batch_size_{0}; +}; +MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SoftmaxCrossEntropyWithLogitsCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc new file mode 100644 index 0000000000..b8bf7b318a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc @@ -0,0 +1,129 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h" +#include +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { + CPUKernel::InitInputOutputSize(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_node); + size_t type_size = sizeof(float); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + workspace_size_list_.emplace_back(tensor_size); +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + dnnl::memory::dims mem_dims; + mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); + if (mem_dims.size() != 2) { + MS_LOG(EXCEPTION) << "SparseSoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); + } + batch_size_ = shape[0]; + class_num_ = shape[1]; + if (batch_size_ == 0 || class_num_ == 0) { + MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; + } + is_grad_ = AnfAlgo::GetNodeAttr(kernel_node, IS_GRAD); + dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); + + dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); + auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, mem_desc); + AddArgument(DNNL_ARG_DST, mem_desc); +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, + float *output) const { + float total_loss = 0; + for (size_t i = 0; i < batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = IntToSize(labels[i]); + if (label > class_num_) { + MS_LOG(EXCEPTION) << "error label input!"; + } + total_loss -= logf(losses[i * class_num_ + label]); + } + output[0] = total_loss / batch_size_; +} + +void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, + float *output) const { + size_t row_start = 0; + for (size_t i = 0; i < batch_size_; ++i) { + if (labels[i] < 0) { + MS_LOG(EXCEPTION) << "label value must >= 0"; + } + size_t label = IntToSize(labels[i]); + if (label > class_num_) { + MS_LOG(EXCEPTION) << "error label input!"; + } + for (size_t j = 0; j < class_num_; ++j) { + size_t index = row_start + j; + if (j == label) { + output[index] = (losses[index] - 1) / batch_size_; + } else { + output[index] = losses[index] / batch_size_; + } + } + row_start += class_num_; + } +} + +bool SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + if (inputs.empty() || workspace.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "error input output size!"; + } + size_t batch_float_size = batch_size_ * sizeof(float); + size_t batch_class_float_size = class_num_ * batch_float_size; + if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || + inputs[1]->size != batch_float_size) { + MS_LOG(EXCEPTION) << 
"error input data size!"; + } + if (is_grad_ && outputs[0]->size != batch_class_float_size) { + MS_LOG(EXCEPTION) << "error output data size!"; + } else if (!is_grad_ && outputs[0]->size != sizeof(float)) { + MS_LOG(EXCEPTION) << "error output data size!"; + } + SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); + SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); + ExecutePrimitive(); + auto labels = reinterpret_cast(inputs[1]->addr); + auto losses = reinterpret_cast(workspace[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + if (is_grad_) { + GradPostExecute(labels, losses, output); + } else { + ForwardPostExecute(labels, losses, output); + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h new file mode 100644 index 0000000000..0d79b0514b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { + public: + SparseSoftmaxCrossEntropyWithLogitsCPUKernel() = default; + ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + protected: + void InitInputOutputSize(const CNodePtr &kernel_node) override; + + private: + void ForwardPostExecute(const int *labels, const float *losses, float *output) const; + void GradPostExecute(const int *labels, const float *losses, float *output) const; + bool is_grad_{false}; + size_t class_num_{0}; + size_t batch_size_{0}; +}; + +MS_REG_CPU_KERNEL( + SparseSoftmaxCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + SparseSoftmaxCrossEntropyWithLogitsCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc new file mode 100644 index 0000000000..5bbc9f49a2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/one_hot_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void OneHotCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (output_shape.size() < 2) { + MS_LOG(EXCEPTION) << "invalid output shape size: " << output_shape.size(); + } + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis != -1 && IntToSize(axis) >= output_shape.size()) { + MS_LOG(EXCEPTION) << "invalid axis: " << axis; + } + if (axis == -1) { + axis_ = output_shape.size() - 1; + } else { + axis_ = IntToSize(axis); + } + depth_ = output_shape[axis_]; + stride_ = 1; + for (size_t i = axis_ + 1; i < output_shape.size(); ++i) { + stride_ *= output_shape[i]; + } +} + +bool OneHotCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.size() < 3 || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output invalid!"; + } + auto indices = reinterpret_cast(inputs[0]->addr); + auto on_value = reinterpret_cast(inputs[1]->addr)[0]; + auto off_value = reinterpret_cast(inputs[2]->addr)[0]; + auto output = reinterpret_cast(outputs[0]->addr); + size_t elem_num = inputs[0]->size / sizeof(int); + + for (size_t i = 0; i < elem_num; i++) { + size_t stride_num = i / stride_; + size_t output_index = stride_num * depth_ * stride_ + i % stride_; + size_t index = IntToSize(indices[i]); + for (size_t j = 0; j < depth_; j++) { + if (index == j) { + output[output_index] = on_value; + } else { + output[output_index] = off_value; + } + output_index += stride_; + } + } + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h new file mode 100644 index 0000000000..393b0e8c41 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/one_hot_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class OneHotCPUKernel : public CPUKernel { + public: + OneHotCPUKernel() = default; + ~OneHotCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t depth_; + size_t stride_; + size_t axis_; +}; + +MS_REG_CPU_KERNEL(OneHot, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + OneHotCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc new file mode 100644 index 0000000000..6537c88840 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +bool ApplyMomentumPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return Launch(inputs, workspace, outputs); +} + +const std::vector &ApplyMomentumPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &ApplyMomentumPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &ApplyMomentumPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h new file mode 100644 index 0000000000..a78f40d04b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/apply_momentum_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel { + public: + ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~ApplyMomentumPSKernel() override = default; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc new file mode 100644 index 0000000000..59ab65014b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" +#include +#include "frontend/parallel/ps/worker.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { + EmbeddingLookUpCPUKernel::InitKernel(kernel_node); + + for (auto dim : input_shape_) { + input_dims_ *= dim; + } + + if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { + key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); + } + std::vector keys{key_, key_, key_}; + std::vector values; + values.insert(values.end(), input_shape_.begin(), input_shape_.end()); + values.insert(values.end(), indices_shape_.begin(), indices_shape_.end()); + values.insert(values.end(), output_shape_.begin(), output_shape_.end()); + std::vector lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()), + SizeToInt(output_shape_.size())}; + const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); + if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { + parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]); + parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens); + } +} + +bool EmbeddingLookUpProxyKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto indices_addr = reinterpret_cast(inputs[1]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t input_size = inputs[1]->size; + size_t output_size = outputs[0]->size; + + size_t size = input_size / sizeof(float); + ::ps::SArray lookup_ids(size, 0); + ::ps::SArray lengths{size}; + ::ps::SArray lookup_result; + + auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result, + parallel::ps::kEmbeddingLookupCmd); + + auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size); + if (ret2 != EOK) { + MS_LOG(EXCEPTION) << "Lookup result memcpy failed."; + } + return true; +} +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h new file mode 100644 index 0000000000..45e0a23fcb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ + +#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" +#include +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class EmbeddingLookUpProxyKernel : public EmbeddingLookUpCPUKernel { + public: + EmbeddingLookUpProxyKernel() = default; + ~EmbeddingLookUpProxyKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + size_t key_{0}; + size_t input_dims_{1}; +}; + +MS_REG_CPU_KERNEL( + EmbeddingLookupProxy, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + EmbeddingLookUpProxyKernel); +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc new file mode 100644 index 0000000000..bcb3ca8ae8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" +#include +#include +#include +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::parallel::ps::Util; +void EmbeddingLookUpPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + input_shape_ = *(shape_vec[0]); + input_lens_ = 1; + for (auto shape : input_shape_) { + input_lens_ = input_lens_ * shape; + } + indices_shape_ = *(shape_vec[1]); + indices_lens_ = 1; + for (auto shape : indices_shape_) { + indices_lens_ = indices_lens_ * shape; + } + output_shape_ = *(shape_vec[2]); + axis_ = 2; + reduce_scatter_flag_ = false; + + size_t offset = 0; + for (size_t i = 0; i < rank_id_; i++) { + offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_); + } + offset_ = offset; + split_num_ = pserver_num_; + + // input shape should be sharded after computing offset_; + Shard(input_shape_, axis_); + + size_t output_size = + std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies()); + output_size_list_.emplace_back(output_size); + CPUKernelUtils::ExpandDimsTo4(&input_shape_); + CPUKernelUtils::ExpandDimsTo4(&output_shape_); +} + +void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + const auto &indices_shape_ = *(shape_vec[0]); + indices_lens_ = indices_shape_[0]; + + size_t output_size = sizeof(float) * indices_lens_; + for (size_t i = axis_ + 1; i < input_shape_.size(); i++) { + output_size *= input_shape_[i]; + } + output_size_list_.clear(); + output_size_list_.emplace_back(output_size); +} + +bool EmbeddingLookUpPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + return Launch(inputs, workspace, outputs); +} + +const std::vector &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } + +const std::vector &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &EmbeddingLookUpPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h new file mode 100644 index 0000000000..e23a90a11c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h" +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel { + public: + EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~EmbeddingLookUpPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc new file mode 100644 index 0000000000..3aa421881a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps {} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h new file mode 100644 index 0000000000..a2b6c4fa61 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::parallel::ps::Util; +class PServerKernel { + public: + PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {} + ~PServerKernel() = default; + PServerKernel(const PServerKernel &) = delete; + PServerKernel &operator=(const PServerKernel &) = delete; + + virtual void InitKernel(const std::shared_ptr>>> &) {} + virtual void ReInit(const std::shared_ptr>>> &) {} + virtual bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) = 0; + + virtual const std::vector &input_sizes() const = 0; + virtual const std::vector &output_sizes() const = 0; + virtual const std::vector &workspace_sizes() const = 0; + + protected: + virtual void ReInit(const std::vector &) {} + void Shard(std::vector *shape, int axis) { + (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_); + } + + size_t rank_id_; + size_t pserver_num_; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc new file mode 100644 index 0000000000..92c901d4c8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.cc @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/ps/pull_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_CPU_KERNEL_T( + Pull, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PullKernel, float); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h new file mode 100644 index 0000000000..84dd9b819e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pull_kernel.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
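PServerKernel::Shard, defined above, rewrites one dimension of a shape in place so that every optimizer kernel on a server sees only its slice of the parameter. Reusing the LocalShard sketch from the embedding kernel note, the helper amounts to:

#include <cstddef>
#include <vector>

// Equivalent of PServerKernel::Shard for one axis: with 3 servers and a
// {10, 80} weight, rank 0 keeps {4, 80} while ranks 1 and 2 keep {3, 80};
// the matching ShardOffset tells each rank which global rows those are.
void ShardAxis(std::vector<size_t> *shape, size_t axis, size_t rank, size_t server_num) {
  (*shape)[axis] = LocalShard((*shape)[axis], rank, server_num);
}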
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
+
+#include <vector>
+#include <string>
+#include "frontend/parallel/ps/worker.h"
+#include "frontend/parallel/ps/util.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class PullKernel : public CPUKernel {
+ public:
+  PullKernel() : key_(0), keys_size_(sizeof(size_t)), var_size_(sizeof(size_t)) {}
+  ~PullKernel() override = default;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &) {
+    // If the parameter is an embedding table, don't Pull from the PServer.
+    if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
+      parallel::ps::Worker<T>::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
+    }
+    return true;
+  }
+  void Init(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but the Pull kernel needs 2 inputs.";
+      return;
+    }
+
+    auto key_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < key_shape.size(); i++) {
+      keys_size_ *= key_shape[i];
+    }
+    auto var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < var_shape.size(); i++) {
+      var_size_ *= var_shape[i];
+    }
+    auto param_node = AnfAlgo::GetInputNode(kernel_node, 1);
+    MS_EXCEPTION_IF_NULL(param_node);
+    param_name_ = param_node->fullname_with_scope();
+
+    if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
+      key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
+    }
+    InitSizeLists();
+    return;
+  }
+  void InitKernel(const CNodePtr &kernel_node) { return; }
+
+ protected:
+  void InitSizeLists() {
+    input_size_list_.push_back(keys_size_);
+    input_size_list_.push_back(var_size_);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t key_;
+  size_t keys_size_;
+  size_t var_size_;
+  std::string param_name_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc
new file mode 100644
index 0000000000..96c1f15bda
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
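Pull deliberately skips embedding parameters: their lookups already go through EmbeddingLookUpProxyKernel, so only dense weights are refreshed. Note that the uint64 key in inputs[0] is the value the matching Push kernel writes to its output, which is how the two kernels pair up. A toy model of the refresh contract assumed here (the map-backed store is an illustration; the real Worker::Pull goes over the wire):

#include <cstring>
#include <map>
#include <vector>

struct Buffer {
  void *addr;
  size_t size;  // bytes
};

// Overwrite the local weight buffer with the server's copy for `key`;
// assumes the stored vector and the buffer agree on size.
void PullWeight(size_t key, const Buffer &weight, const std::map<size_t, std::vector<float>> &server_store) {
  const std::vector<float> &latest = server_store.at(key);
  std::memcpy(weight.addr, latest.data(), weight.size);
}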
+ */ + +#include "backend/kernel_compiler/cpu/ps/push_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_CPU_KERNEL_T(Push, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeUInt64), + PushKernel, float); + +MS_REG_CPU_KERNEL_T( + Push, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), + PushKernel, float); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h new file mode 100644 index 0000000000..938792f3bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/push_kernel.h @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ + +#include +#include +#include "frontend/parallel/ps/worker.h" +#include "frontend/parallel/ps/util.h" +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class PushKernel : public CPUKernel { + public: + PushKernel() : key_(UINT64_MAX) {} + ~PushKernel() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + std::vector keys; + std::vector addrs; + std::vector sizes; + for (auto input : inputs) { + keys.push_back(key_); + addrs.push_back(reinterpret_cast(input->addr)); + sizes.push_back(SizeToInt(input->size) / sizeof(T)); + } + parallel::ps::Worker::GetInstance().Push(keys, addrs, sizes); + memcpy(outputs[0]->addr, &key_, sizeof(size_t)); + return true; + } + + void Init(const CNodePtr &kernel_node) { + key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); + auto optim_input_shapes = AnfAlgo::GetNodeAttr>>(kernel_node, "optim_input_shapes"); + std::vector only_shape_indices = AnfAlgo::GetNodeAttr>(kernel_node, "only_shape_indices"); + MS_LOG(INFO) << "Key " << key_ << " optimizer input shapes are:" << optim_input_shapes; + MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices; + for (size_t i = 0; i < optim_input_shapes.size(); i++) { + auto shape = optim_input_shapes[i]; + mindspore::parallel::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); + if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= shape[j]; + } + input_size_list_.push_back(size); + } + } + + output_size_list_.push_back(sizeof(size_t)); + return; + } + + void InitKernel(const CNodePtr &kernel_node) { return; } + + private: 
+ size_t key_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc new file mode 100644 index 0000000000..c7283954f8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" +#include +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void SparseApplyAdamPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector &var_shape = *(shape_vec[0]); + std::vector &m_shape = *(shape_vec[1]); + std::vector &v_shape = *(shape_vec[2]); + const std::vector &grad_shape = *(shape_vec[9]); + const std::vector &indices_shape = *(shape_vec[10]); + + Shard(&var_shape, 0); + Shard(&m_shape, 0); + Shard(&v_shape, 0); + + if (!IsSameShape(var_shape, m_shape)) { + MS_LOG(EXCEPTION) << "var and m should have the same shape"; + } + if (!IsSameShape(var_shape, v_shape)) { + MS_LOG(EXCEPTION) << "var and v should have the same shape"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; + } + /* + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + } + */ + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); +} + +void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + const std::vector &indices_shape = *(shape_vec[0]); + indices_size_ = indices_shape[0]; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +void SparseApplyAdamPSKernel::ReInit(const std::vector &inputs) { + const auto &indices_addr = inputs[10]; + indices_size_ = indices_addr->size; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} 
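The three workspace entries sized in InitKernel above map onto the sparse optimizer's scratch buffers. A worked sizing example, assuming this server's var shard is [1000, 80] and a gradient batch carries 32 indices (numbers are illustrative):

// var_first_dim_size_ = 1000, var_outer_dim_size_ = 80, indices_size_ = 32
// workspace[0] = 32 * 80 * sizeof(float)   = 10240 bytes   (gathered gradient rows)
// workspace[1] = 32 * sizeof(int)          = 128 bytes     (working copy of the indices)
// workspace[2] = 1000 * 80 * sizeof(float) = 320000 bytes  (full-shard scratch)
// Both ReInit overloads refresh only the first two entries: the index count
// changes per batch, while the shard-sized buffer does not.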
+ +bool SparseApplyAdamPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + ReInit(inputs); + int *indices = reinterpret_cast(inputs[10]->addr); + for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) { + indices[i] -= rank_id_ * var_first_dim_size_; + } + return Launch(inputs, workspace, outputs); +} + +const std::vector &SparseApplyAdamPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &SparseApplyAdamPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &SparseApplyAdamPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h new file mode 100644 index 0000000000..337fcb3bf0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::kernel::SparseApplyAdamCPUKernel; +class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel { + public: + SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~SparseApplyAdamPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; + + protected: + void ReInit(const std::vector &) override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc new file mode 100644 index 0000000000..0392bd5a69 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace ps { +void SparseApplyFtrlPSKernel::InitKernel( + const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector var_shape = *(shape_vec[0]); + std::vector accum_shape = *(shape_vec[1]); + std::vector linear_shape = *(shape_vec[2]); + std::vector grad_shape = *(shape_vec[3]); + std::vector indices_shape = *(shape_vec[4]); + + Shard(&var_shape, 0); + Shard(&accum_shape, 0); + Shard(&linear_shape, 0); + + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be a 1D vector"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + lr_ = 0.01; + l1_ = 1e-8; + l2_ = 1e-8; + lr_power_ = -0.5; + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); +} + +void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr>>> &shapes) { + const std::vector>> &shape_vec = *shapes; + std::vector indices_shape = *(shape_vec[0]); + indices_size_ = indices_shape[0]; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +void SparseApplyFtrlPSKernel::ReInit(const std::vector &inputs) { + const auto &indices_addr = inputs[4]; + indices_size_ = indices_addr->size; + workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); + workspace_size_list_[1] = indices_size_ * sizeof(int); +} + +bool SparseApplyFtrlPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + ReInit(inputs); + int *indices = reinterpret_cast(inputs[4]->addr); + for (size_t i = 0; i < inputs[4]->size / sizeof(int); i++) { + indices[i] -= rank_id_ * var_first_dim_size_; + } + return Launch(inputs, workspace, outputs); +} + +const std::vector &SparseApplyFtrlPSKernel::input_sizes() const { return GetInputSizeList(); } + +const std::vector &SparseApplyFtrlPSKernel::output_sizes() const { return GetOutputSizeList(); } + +const std::vector &SparseApplyFtrlPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } +} // namespace ps +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h new file mode 100644 index 0000000000..d97f19d349 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h" + +namespace mindspore { +namespace kernel { +namespace ps { +using mindspore::kernel::SparseApplyFtrlCPUKernel; +class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel { + public: + SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ~SparseApplyFtrlPSKernel() override = default; + + void InitKernel(const std::shared_ptr>>> &) override; + void ReInit(const std::shared_ptr>>> &) override; + + bool Execute(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + const std::vector &input_sizes() const override; + const std::vector &output_sizes() const override; + const std::vector &workspace_sizes() const override; + + protected: + void ReInit(const std::vector &) override; +}; +} // namespace ps +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc new file mode 100644 index 0000000000..0dddf1d3c4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
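Both sparse PS kernels above rebase the incoming indices in Execute (indices[i] -= rank_id_ * var_first_dim_size_) to turn global row ids into shard-local ones before delegating to the inherited CPU Launch. A sketch of that translation; note it presumes every server owns exactly var_first_dim_size_ rows, which matches the LocalShard arithmetic only when the row count divides evenly:

#include <cstddef>

// Global row id -> local row id, assuming uniform shards: server r owns
// global rows [r * rows_per_server, (r + 1) * rows_per_server).
void RebaseIndices(int *indices, size_t count, size_t rank_id, size_t rows_per_server) {
  const int offset = static_cast<int>(rank_id * rows_per_server);
  for (size_t i = 0; i < count; ++i) {
    indices[i] -= offset;
  }
}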
+ */ +#include +#include +#include +#include "backend/kernel_compiler/cpu/reduce_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +const size_t kReduceTypeMax = 0; +const size_t kReduceTypeMean = 1; +const size_t kReduceTypeSum = 2; +const size_t kMaxDim = 100; +void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == "ReduceMax") { + reduce_type_ = kReduceTypeMax; + } else if (kernel_name == "ReduceMean") { + reduce_type_ = kReduceTypeMean; + } else if (kernel_name == "ReduceSum") { + reduce_type_ = kReduceTypeSum; + } else { + MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; + } + shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); + if (axis_addr->isa()) { + auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); + if (attr_axis.size() > shape_.size()) { + MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size(); + } else if (attr_axis.empty()) { + axis_.push_back(shape_.size() - 1); + } else { + for (auto axis : attr_axis) { + if (IntToSize(axis) >= (shape_.size())) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } + } + } else if (axis_addr->isa()) { + int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + if (axis >= 0 && IntToSize(axis) >= shape_.size()) { + MS_LOG(EXCEPTION) << "axis value is oversize."; + } + axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; + } + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] <= 0) { + MS_LOG(EXCEPTION) << "shape value is invalid."; + } + left_dims_ *= shape_[i]; + } + for (size_t i = 0; i < axis_.size(); ++i) { + stride_ *= shape_[axis_[i]]; + } + if (stride_ <= 0) { + MS_LOG(EXCEPTION) << "stride_ must greater than zero."; + } + left_dims_ = left_dims_ / stride_; +} +bool ReduceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + size_t out_float_size = left_dims_ * sizeof(float); + size_t in_float_size = stride_ * out_float_size; + if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + } + auto input = reinterpret_cast(inputs[0]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + int size = inputs[0]->size / sizeof(float); + std::vector new_input(IntToSize(size), 0.0); + std::vector transpose_axis; + for (size_t i = 0; i < shape_.size(); ++i) { + bool insert = true; + for (size_t j = 0; j < axis_.size(); ++j) { + if (axis_[j] == i) { + insert = false; + break; + } + } + if (insert) { + transpose_axis.push_back(i); + } + } + (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); + Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); + if (reduce_type_ == kReduceTypeMax) { + for (size_t i = 0; i < left_dims_; ++i) { + float value = new_input[i * stride_]; + for (size_t k = 0; k < stride_; ++k) { + if (value < new_input[i * stride_ + k]) { + value = new_input[i * stride_ + k]; + } + } + output[i] = value; + } + } else { + for (size_t i = 0; 
i < left_dims_; ++i) { + float value = 0.0; + for (size_t k = 0; k < stride_; ++k) { + value += new_input[i * stride_ + k]; + } + if (reduce_type_ == kReduceTypeMean) { + output[i] = value / stride_; + } else { + output[i] = value; + } + } + } + return true; +} +void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output) { + int pos_array[kMaxDim]; + int size_offset[kMaxDim]; + size_offset[0] = size / SizeToInt(input_shape[0]); + for (int i = 1; i < shape_size; i++) { + size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]); + } + for (int position = 0; position < size; position += 1) { + int temp_position = position; + pos_array[0] = temp_position / size_offset[0]; + for (int i = 1; i < shape_size; i++) { + temp_position -= pos_array[i - 1] * size_offset[i - 1]; + pos_array[i] = temp_position / size_offset[i]; + } + int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])]; + int new_position_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]); + new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size; + } + output[new_position] = input[position]; + } + return; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h new file mode 100644 index 0000000000..a9696bad49 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_cpu_kernel.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
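The reduction strategy above is transpose-then-scan: Launch permutes the kept axes to the front and the reduced axes to the back, so the data becomes a [left_dims_][stride_] matrix and every output element reduces one contiguous run. A self-contained model of the final scan:

#include <vector>

// `data` is assumed to be already permuted to [left_dims][stride] layout,
// as the kernel's Transpose produces; mean divides by the run length.
std::vector<float> ReduceRuns(const std::vector<float> &data, size_t left_dims, size_t stride, bool mean) {
  std::vector<float> out(left_dims);
  for (size_t i = 0; i < left_dims; ++i) {
    float value = 0.0f;
    for (size_t k = 0; k < stride; ++k) {
      value += data[i * stride + k];
    }
    out[i] = mean ? value / stride : value;
  }
  return out;
}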
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReduceCPUKernel : public CPUKernel { + public: + ReduceCPUKernel() = default; + ~ReduceCPUKernel() override = default; + void InitKernel(const CNodePtr &kernel_node) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void Transpose(const int size, const float *input, const std::vector &input_shape, + const std::vector &input_axis, const int shape_size, float *output); + size_t reduce_type_; + std::vector axis_; + std::vector shape_; + size_t left_dims_ = 1; + size_t stride_ = 1; +}; +MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc new file mode 100644 index 0000000000..f44c109ace --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr auto kRanksGroup = "group"; +} // namespace + +ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} + +void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { + auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); + if (op != nullptr) { + op_type_ = GetValue(op); + } + + auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); + if (ranks_group != nullptr) { + ranks_group_ = GetValue>(ranks_group); + } else { + MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; + } +} + +bool ReduceScatterCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + auto output_data_num = outputs[0]->size / sizeof(float); + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h new file mode 100644 index 0000000000..317d7df443 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reduce_scatter_cpu_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReduceScatterCPUKernel : public CPUKernel { + public: + ReduceScatterCPUKernel(); + ~ReduceScatterCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::string op_type_; + std::vector ranks_group_; +}; + +MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReduceScatterCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc new file mode 100644 index 0000000000..6370fdc78a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/reshape_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } + +bool ReshapeCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + if (inputs.empty() || outputs.empty()) { + MS_LOG(EXCEPTION) << "input or output empty!"; + } + if (inputs[0]->size != outputs[0]->size) { + return false; + } + + if (inputs[0]->addr == outputs[0]->addr) { + return true; + } + + size_t mem_bits = outputs[0]->size; + auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; + return false; + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h new file mode 100644 index 0000000000..04f1db3304 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
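Reshape, Flatten, and ExpandDims can all share the reshape kernel above because, in row-major layout, they change only shape metadata: a [2, 3] tensor and its [3, 2] or [6] views have identical byte sequences. Launch is therefore a size check plus memcpy_s, or nothing at all when input and output alias. A minimal model:

#include <cstring>
#include <vector>

// Row-major reshape: same bytes, new shape; the shape bookkeeping lives in
// the graph, not in the buffer.
std::vector<float> Reshape(const std::vector<float> &input) {
  std::vector<float> output(input.size());
  std::memcpy(output.data(), input.data(), input.size() * sizeof(float));
  return output;
}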
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ReshapeCPUKernel : public CPUKernel { + public: + ReshapeCPUKernel() = default; + ~ReshapeCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; +}; + +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); + +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ReshapeCPUKernel); +MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ReshapeCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc new file mode 100644 index 0000000000..c6657a845a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/cpu/slice_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + input_shape_[i]; + } + } + auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(prim); + auto strides = prim->GetAttr(STRIDES); + if (strides != nullptr) { + strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); + if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { + MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; + } + for (size_t i = 0; i < strides_.size(); ++i) { + if (strides_[i] < 0) { + strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0; + } + if (end_[i] < 0) { + end_[i] = (end_[i] + input_shape_[i]) > 0 ? (end_[i] + input_shape_[i]) : 0; + } + } + } else { + auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { + MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; + } + for (size_t i = 0; i < sizes.size(); ++i) { + if (sizes[i] < 0) { + sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0; + } + strides_.emplace_back(1); + end_.emplace_back(begin_[i] + sizes[i]); + } + } + + ExpandAllMemberDims(); + CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); + CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); +} + +void SliceCPUKernel::ExpandAllMemberDims() { + CPUKernelUtils::ExpandDimsTo4(&output_shape_); + + auto input_len = input_shape_.size(); + if (input_len < 4) { + for (size_t i = 0; i < 4 - input_len; ++i) { + input_shape_.insert(input_shape_.begin(), 1); + begin_.insert(begin_.begin(), 0); + strides_.insert(strides_.begin(), 1); + end_.insert(end_.begin(), 1); + } + } +} + +bool SliceCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; + size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], + begin_[2] * input_element_num_[2]}; + size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1], + strides_[2] * input_element_num_[2]}; + + auto in_n_offset = in_start_offset[0]; + auto out_n_offset = 0; + for (int i = begin_[0]; i < end_[0]; + i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) { + if (can_copy_memory[0]) { + CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); + continue; + } + auto in_c_offset = in_start_offset[1]; + auto out_c_offset = 0; + for (int j = begin_[1]; j < end_[1]; + j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) { + if (can_copy_memory[1]) { + CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, + 
input_element_num_[1]);
+        continue;
+      }
+      auto in_h_offset = in_start_offset[2];
+      auto out_h_offset = 0;
+      for (int k = begin_[2]; k < end_[2];
+           k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
+        if (can_copy_memory[2]) {
+          CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
+                           out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
+          continue;
+        }
+        for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
+          *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
+        }
+      }
+    }
+  }
+
+  return true;
+}
+
+bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
+  for (size_t i = dim + 1; i < 4; ++i) {
+    if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
+                                      const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
+                                      size_t copy_num) const {
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto in_buff_size = inputs[0]->size;
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto out_buff_size = outputs[0]->size;
+
+  if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
+    MS_LOG(EXCEPTION) << "Input memory out of bounds.";
+  }
+  if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
+    MS_LOG(EXCEPTION) << "Output memory out of bounds.";
+  }
+
+  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
+                      copy_num * sizeof(float));
+  if (ret != EOK) {
+    MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
+  }
+}
+
+void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 1) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 input.";
+  }
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output.";
+  }
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape.size() > 4) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel only supports 4D or lower.";
+  }
+  if (input_shape.size() == 0) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h
new file mode 100644
index 0000000000..03b7ecdc17
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_cpu_kernel.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
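The can_copy_memory fast path is the heart of the slice kernel: CanCopyMemoryOnAxis(dim) checks that every axis to the right of dim is taken whole with stride 1, in which case the remaining loops collapse into one memcpy_s per outer index. A worked example:

// input shape [4, 8, 16, 32], begin {1, 0, 0, 0}, end {3, 8, 16, 32}, all strides 1:
// CanCopyMemoryOnAxis(0) holds, so the kernel issues two bulk copies of
// 8 * 16 * 32 = 4096 floats (one per selected outer index) and never
// enters the three inner loops.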
+ */ +#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class SliceCPUKernel : public CPUKernel { + public: + SliceCPUKernel() = default; + ~SliceCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void ExpandAllMemberDims(); + bool CanCopyMemoryOnAxis(size_t dim) const; + void CopyDataToOutput(const std::vector &inputs, size_t in_offset, + const std::vector &outputs, size_t out_offset, size_t copy_num) const; + void CheckParam(const CNodePtr &kernel_node) const; + std::vector begin_; + std::vector end_; + std::vector strides_; + std::vector input_shape_; + std::vector input_element_num_; + std::vector output_shape_; + std::vector output_element_num_; +}; + +MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceCPUKernel); +MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SliceCPUKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc new file mode 100644 index 0000000000..20904e0504 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace kernel { +void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + + begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); + for (size_t i = 0; i < begin_.size(); i++) { + if (begin_[i] < 0) { + begin_[i] = begin_[i] + output_shape_[i]; + } + } + + auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(prim); + auto strides = prim->GetAttr(STRIDES); + if (strides != nullptr) { + strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); + end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); + if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) { + MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; + } + for (size_t i = 0; i < strides_.size(); ++i) { + if (strides_[i] < 0) { + strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..20904e0504
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.cc
@@ -0,0 +1,182 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+#include "ir/primitive.h"
+
+namespace mindspore {
+namespace kernel {
+void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  CheckParam(kernel_node);
+  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+
+  begin_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, BEGIN);
+  for (size_t i = 0; i < begin_.size(); i++) {
+    if (begin_[i] < 0) {
+      begin_[i] = begin_[i] + output_shape_[i];
+    }
+  }
+
+  auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto strides = prim->GetAttr(STRIDES);
+  if (strides != nullptr) {
+    strides_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES);
+    end_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, END);
+    if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) {
+      MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
+    }
+    for (size_t i = 0; i < strides_.size(); ++i) {
+      if (strides_[i] < 0) {
+        strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0;
+      }
+      if (end_[i] < 0) {
+        end_[i] = (end_[i] + output_shape_[i]) > 0 ? (end_[i] + output_shape_[i]) : 0;
+      }
+    }
+  } else {
+    auto sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, SIZE);
+    if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
+      MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
+    }
+    for (size_t i = 0; i < sizes.size(); ++i) {
+      if (sizes[i] < 0) {
+        sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0;
+      }
+      strides_.emplace_back(1);
+      end_.emplace_back(begin_[i] + sizes[i]);
+    }
+  }
+
+  ExpandAllMemberDims();
+  CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
+  CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
+}
+
+void SliceGradCPUKernel::ExpandAllMemberDims() {
+  CPUKernelUtils::ExpandDimsTo4(&input_shape_);
+
+  auto output_len = output_shape_.size();
+  if (output_len < 4) {
+    for (size_t i = 0; i < 4 - output_len; ++i) {
+      output_shape_.insert(output_shape_.begin(), 1);
+      begin_.insert(begin_.begin(), 0);
+      strides_.insert(strides_.begin(), 1);
+      end_.insert(end_.begin(), 1);
+    }
+  }
+}
+
+bool SliceGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+
+  auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
+  if (ret != EOK) {
+    MS_LOG(ERROR) << "output buff memset fail. ret:" << ret;
+    return false;
+  }
+
+  bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
+  size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1],
+                                begin_[2] * output_element_num_[2]};
+  size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1],
+                             strides_[2] * output_element_num_[2]};
+
+  auto in_n_offset = 0;
+  auto out_n_offset = out_start_offset[0];
+  for (int i = begin_[0]; i < end_[0];
+       i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) {
+    if (can_copy_memory[0]) {
+      CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]);
+      continue;
+    }
+    auto in_c_offset = 0;
+    auto out_c_offset = out_start_offset[1];
+    for (int j = begin_[1]; j < end_[1];
+         j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) {
+      if (can_copy_memory[1]) {
+        CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
+                         input_element_num_[1]);
+        continue;
+      }
+      auto in_h_offset = 0;
+      auto out_h_offset = out_start_offset[2];
+      for (int k = begin_[2]; k < end_[2];
+           k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) {
+        if (can_copy_memory[2]) {
+          CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
+                           out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]);
+          continue;
+        }
+        for (int m = begin_[3]; m < end_[3]; m += strides_[3]) {
+          output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
+  for (size_t i = dim + 1; i < 4; ++i) {
+    if (begin_[i] != 0 || end_[i] != SizeToInt(output_shape_[i]) || strides_[i] != 1) {
+      return false;
+    }
+  }
+  return true;
+}
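In InitKernel above, begin/end/strides (or begin/size) attributes may arrive negative, Python-style; they are wrapped by the corresponding axis length, and end, strides, and size are additionally clamped at zero (begin is wrapped but not clamped in this revision). A tiny sketch of that normalization (NormalizeIndex is an invented name):

    #include <cstdio>

    // Wrap a negative index around the axis length, then clamp at zero,
    // matching the end/strides/size handling in InitKernel.
    int NormalizeIndex(int v, int dim) {
      if (v < 0) v += dim;
      return v > 0 ? v : 0;
    }

    int main() {
      int dim = 5;
      std::printf("%d %d\n", NormalizeIndex(-1, dim), NormalizeIndex(-7, dim));  // 4 0
    }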
+
+void SliceGradCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
+                                          const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
+                                          size_t copy_num) const {
+  auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
+  auto in_buff_size = inputs[0]->size;
+  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+  auto out_buff_size = outputs[0]->size;
+
+  if ((in_offset + copy_num) * sizeof(float) > in_buff_size) {
+    MS_LOG(EXCEPTION) << "input memory out of bounds.";
+  }
+  if ((out_offset + copy_num) * sizeof(float) > out_buff_size) {
+    MS_LOG(EXCEPTION) << "output memory out of bounds.";
+  }
+
+  auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset,
+                      copy_num * sizeof(float));
+  if (ret != EOK) {
+    MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
+  }
+}
+
+void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradCPUKernel needs 1 output.";
+  }
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  if (input_shape.size() > 4) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size()
+                      << ", but SliceGradCPUKernel only supports 4d or lower.";
+  }
+  if (input_shape.size() == 0) {
+    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
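Both slice kernels route every bulk copy through CopyDataToOutput, which checks the source and destination ranges against the raw buffer sizes before calling memcpy_s from the bundled securec library. A portable sketch of the same guard, with std::memcpy standing in for memcpy_s (GuardedCopy is an invented name):

    #include <cstddef>
    #include <cstdio>
    #include <cstring>
    #include <stdexcept>

    // Bounds-guarded copy in the spirit of CopyDataToOutput.
    void GuardedCopy(float *dst, size_t dst_cap, size_t dst_off, const float *src, size_t src_cap,
                     size_t src_off, size_t n) {
      if (src_off + n > src_cap) throw std::out_of_range("input memory out of bounds");
      if (dst_off + n > dst_cap) throw std::out_of_range("output memory out of bounds");
      std::memcpy(dst + dst_off, src + src_off, n * sizeof(float));
    }

    int main() {
      float src[8] = {0, 1, 2, 3, 4, 5, 6, 7}, dst[8] = {0};
      GuardedCopy(dst, 8, 2, src, 8, 4, 3);  // dst[2..4] = src[4..6]
      std::printf("%g %g %g\n", dst[2], dst[3], dst[4]);  // 4 5 6
    }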
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h
new file mode 100644
index 0000000000..ec480d7e80
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/slice_grad_cpu_kernel.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SliceGradCPUKernel : public CPUKernel {
+ public:
+  SliceGradCPUKernel() = default;
+  ~SliceGradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  void ExpandAllMemberDims();
+  bool CanCopyMemoryOnAxis(size_t dim) const;
+  void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
+                        const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num) const;
+  void CheckParam(const CNodePtr &kernel_node) const;
+  std::vector<int> begin_;
+  std::vector<int> end_;
+  std::vector<int> strides_;
+  std::vector<size_t> input_shape_;
+  std::vector<size_t> input_element_num_;
+  std::vector<size_t> output_shape_;
+  std::vector<size_t> output_element_num_;
+};
+
+MS_REG_CPU_KERNEL(
+  SliceGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SliceGradCPUKernel);
+MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  SliceGradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc
new file mode 100644
index 0000000000..2ff8e77fcd
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc
@@ -0,0 +1,177 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kSparseApplyAdamInputSize = 11;
+
+void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto m = input_params->m_;
+  auto m_t = input_params->m_t_;
+  auto v = input_params->v_;
+  auto beta1 = input_params->beta1_;
+  auto beta2 = input_params->beta2_;
+  auto use_nesterov = input_params->use_nesterov_;
+  auto unique_sparse_grad = input_params->sparse_grad_;
+  auto var_first_dim_size = input_params->var_first_dim_size_;
+  auto var_outer_dim_size = input_params->var_outer_dim_size_;
+  for (size_t i = start; i < end; ++i) {
+    int index = unique_sparse_grad.indices_[i];
+    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
+    }
+    size_t start_index = var_outer_dim_size * index;
+    size_t end_index = start_index + var_outer_dim_size;
+    for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) {
+      auto summed_grad = unique_sparse_grad.value_[k];
+      m[j] += (1 - beta1) * summed_grad;
+      v[j] += (1 - beta2) * summed_grad * summed_grad;
+      if (use_nesterov) {
+        m_t[j] = m[j] * beta1 + (1 - beta1) * summed_grad;
+      }
+    }
+  }
+}
+
+void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto m = input_params->m_;
+  auto v = input_params->v_;
+  auto beta1 = input_params->beta1_;
+  auto beta2 = input_params->beta2_;
+  for (size_t i = start; i < end; ++i) {
+    m[i] *= beta1;
+    v[i] *= beta2;
+  }
+}
+
+void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto var = input_params->var_;
+  auto m = input_params->m_;
+  auto v = input_params->v_;
+  auto lr = input_params->lr_;
+  auto epsilon = input_params->epsilon_;
+  for (size_t i = start; i < end; ++i) {
+    var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
+  }
+}
+}  // namespace
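The sparse Adam kernel splits one optimizer step into three passes: a dense decay of m and v (ComputeMomentum), a sparse accumulation limited to the deduplicated gradient rows (ComputeAdam), and a dense weight update (ComputeWeight) with the bias correction folded into the learning rate as lr' = lr * sqrt(1 - beta2^t) / (1 - beta1^t). A self-contained sketch of that sequence on toy data (all names and values illustrative):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      float lr = 0.01f, beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;
      float beta1_power = beta1, beta2_power = beta2;  // step t == 1
      std::vector<float> var = {1.0f, 2.0f}, m = {0.0f, 0.0f}, v = {0.0f, 0.0f};
      std::vector<int> indices = {1};                  // only row 1 has a gradient
      std::vector<float> grad = {0.5f};
      float scaled_lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);  // bias correction in lr
      for (auto &mi : m) mi *= beta1;                  // phase 1: dense decay
      for (auto &vi : v) vi *= beta2;
      for (size_t i = 0; i < indices.size(); ++i) {    // phase 2: sparse accumulate
        m[indices[i]] += (1 - beta1) * grad[i];
        v[indices[i]] += (1 - beta2) * grad[i] * grad[i];
      }
      for (size_t j = 0; j < var.size(); ++j) {        // phase 3: dense weight update
        var[j] -= scaled_lr * m[j] / (std::sqrt(v[j]) + eps);
      }
      std::printf("%f %f\n", var[0], var[1]);          // row 0 unchanged (m and v stayed 0)
    }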
+
+void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+}
+
+void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  std::vector<size_t> m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  std::vector<size_t> v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+  std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9);
+  std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10);
+  if (!IsSameShape(var_shape, m_shape)) {
+    MS_LOG(EXCEPTION) << "var and m should have the same shape";
+  }
+  if (!IsSameShape(var_shape, v_shape)) {
+    MS_LOG(EXCEPTION) << "var and v should have the same shape";
+  }
+  if (var_shape.empty()) {
+    MS_LOG(EXCEPTION) << "var must be at least 1D";
+  }
+  var_first_dim_size_ = var_shape[0];
+  for (size_t i = 1; i < var_shape.size(); ++i) {
+    if (var_shape[i] != grad_shape[i]) {
+      MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i;
+    }
+    var_outer_dim_size_ *= var_shape[i];
+  }
+  if (indices_shape.size() != 1) {
+    MS_LOG(EXCEPTION) << "indices must be 1D";
+  }
+  indices_size_ = indices_shape[0];
+  if (grad_shape[0] != indices_size_) {
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
+  }
+  if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) {
+    use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
+  }
+}
+
+bool SparseApplyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyAdamInputSize) {
+    MS_LOG(EXCEPTION) << "Wrong input size!";
+  }
+
+  auto var = reinterpret_cast<float *>(inputs[0]->addr);
+  auto m = reinterpret_cast<float *>(inputs[1]->addr);
+  auto v = reinterpret_cast<float *>(inputs[2]->addr);
+  auto beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
+  if (beta1_power == 1) {
+    MS_LOG(EXCEPTION) << "The beta1_power should not be 1";
+  }
+  auto beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
+  auto lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
+  auto beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
+  auto beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
+  auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
+  auto grad = reinterpret_cast<float *>(inputs[9]->addr);
+  auto indices = reinterpret_cast<int *>(inputs[10]->addr);
+  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
+  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto m_t = reinterpret_cast<float *>(workspace[2]->addr);
+
+  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
+  ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_,
+                       var_outer_dim_size_);
+  size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_;
+  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
+
+  MultiThreadComputeParams input_params;
+  input_params.m_ = m;
+  input_params.v_ = v;
+  input_params.beta1_ = beta1;
+  input_params.beta2_ = beta2;
+  MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size);
+
+  input_params.m_t_ = m_t;
+  input_params.use_nesterov_ = use_nesterov_;
+  input_params.sparse_grad_ = unique_sparse_grad;
+  input_params.var_first_dim_size_ = var_first_dim_size_;
+  input_params.var_outer_dim_size_ = var_outer_dim_size_;
+  MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_);
+
+  if (use_nesterov_) {
+    input_params.m_ = input_params.m_t_;
+  }
+  input_params.var_ = var;
+  input_params.lr_ = lr;
+  input_params.epsilon_ = epsilon;
+  MultiThreadCompute(ComputeWeight, &input_params, total_dim_size);
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
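Each pass above goes through MultiThreadCompute, which hands a callback one contiguous [start, end) range per worker. A rough standalone equivalent of that dispatch pattern, assuming a fixed worker count (ParallelFor is an invented stand-in, not the MindSpore API):

    #include <algorithm>
    #include <cstdio>
    #include <thread>
    #include <vector>

    // Split [0, total) evenly across std::thread workers and run fn on each slice.
    template <typename Func>
    void ParallelFor(Func fn, size_t total, size_t num_threads = 4) {
      std::vector<std::thread> workers;
      size_t chunk = (total + num_threads - 1) / num_threads;
      for (size_t t = 0; t < num_threads; ++t) {
        size_t start = t * chunk;
        size_t end = std::min(total, start + chunk);
        if (start >= end) break;
        workers.emplace_back([=]() { fn(start, end); });
      }
      for (auto &w : workers) w.join();
    }

    int main() {
      std::vector<float> data(1000, 1.0f);
      ParallelFor([&](size_t s, size_t e) {
        for (size_t i = s; i < e; ++i) data[i] *= 0.9f;  // e.g. the momentum decay pass
      }, data.size());
      std::printf("%f\n", data[999]);  // 0.9
    }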
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h
new file mode 100644
index 0000000000..5d3d4193f7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SparseApplyAdamCPUKernel : public CPUKernel {
+ public:
+  SparseApplyAdamCPUKernel() = default;
+  ~SparseApplyAdamCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ protected:
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+  bool use_nesterov_{false};
+};
+
+MS_REG_CPU_KERNEL(SparseApplyAdam,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyAdamCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc
new file mode 100644
index 0000000000..2662604e19
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc
@@ -0,0 +1,157 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kSparseApplyFtrlInputSize = 5;
+
+void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto var = input_params->var_;
+  auto accum = input_params->accum_;
+  auto linear = input_params->linear_;
+  auto lr = input_params->lr_;
+  auto l1 = input_params->l1_;
+  auto l2_plus = 2 * input_params->l2_;
+  auto lr_power = input_params->lr_power_;
+  auto unique_sparse_grad = input_params->sparse_grad_;
+  auto var_first_dim_size = input_params->var_first_dim_size_;
+  auto var_outer_dim_size = input_params->var_outer_dim_size_;
+  for (size_t i = start; i < end; ++i) {
+    int index = unique_sparse_grad.indices_[i];
+    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
+    }
+    size_t start_index = var_outer_dim_size * index;
+    size_t end_index = start_index + var_outer_dim_size;
+    for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) {
+      auto summed_grad = unique_sparse_grad.value_[k];
+      auto accum_new = accum[j] + summed_grad * summed_grad;
+      float y;
+      if (lr_power == -0.5) {
+        y = std::sqrt(accum_new);
+        linear[j] += summed_grad - (y - std::sqrt(accum[j])) / lr * var[j];
+      } else {
+        y = std::pow(accum_new, -lr_power);
+        linear[j] += summed_grad - (y - std::pow(accum[j], -lr_power)) / lr * var[j];
+      }
+      accum[j] = accum_new;
+      auto x = Sign(linear[j]) * l1 - linear[j];
+      y = y / lr + l2_plus;
+      var[j] = std::fabs(linear[j]) > l1 ? x / y : 0;
+    }
+  }
+}
+}  // namespace
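ComputeFtrl implements the FTRL-proximal closed form per element: accumulate the squared gradient, update the linear term, then set var to (sign(linear) * l1 - linear) / (accum_new^(-lr_power) / lr + 2 * l2) when |linear| exceeds l1, and to zero otherwise; the lr_power == -0.5 branch merely replaces pow with the cheaper sqrt. A scalar sketch of the same step (FtrlStep is an invented name; values illustrative):

    #include <cmath>
    #include <cstdio>

    void FtrlStep(float grad, float lr, float l1, float l2, float lr_power,
                  float *var, float *accum, float *linear) {
      float accum_new = *accum + grad * grad;
      float y = std::pow(accum_new, -lr_power);                 // sqrt(accum_new) when lr_power == -0.5
      *linear += grad - (y - std::pow(*accum, -lr_power)) / lr * (*var);
      *accum = accum_new;
      float x = (*linear > 0 ? l1 : -l1) - *linear;             // Sign(linear) * l1 - linear
      float denom = y / lr + 2 * l2;
      *var = std::fabs(*linear) > l1 ? x / denom : 0;           // zero once inside the L1 dead zone
    }

    int main() {
      float var = 1.0f, accum = 0.1f, linear = 0.0f;
      FtrlStep(0.5f, 0.01f, 0.1f, 0.0f, -0.5f, &var, &accum, &linear);
      std::printf("var=%f accum=%f linear=%f\n", var, accum, linear);
    }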
+
+void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
+  CPUKernel::InitInputOutputSize(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+}
+
+void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<size_t> var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  std::vector<size_t> accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  std::vector<size_t> linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+  std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+  std::vector<size_t> indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
+  if (!IsSameShape(var_shape, accum_shape)) {
+    MS_LOG(EXCEPTION) << "var and accum should have the same shape";
+  }
+  if (!IsSameShape(var_shape, linear_shape)) {
+    MS_LOG(EXCEPTION) << "var and linear should have the same shape";
+  }
+  if (var_shape.empty()) {
+    MS_LOG(EXCEPTION) << "var must be at least 1D";
+  }
+  var_first_dim_size_ = var_shape[0];
+  for (size_t i = 1; i < var_shape.size(); ++i) {
+    if (var_shape[i] != grad_shape[i]) {
+      MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i;
+    }
+    var_outer_dim_size_ *= var_shape[i];
+  }
+  if (indices_shape.size() != 1) {
+    MS_LOG(EXCEPTION) << "indices must be a 1D vector";
+  }
+  indices_size_ = indices_shape[0];
+  if (grad_shape[0] != indices_size_) {
+    MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices";
+  }
+  lr_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr");
+  if (lr_ <= 0) {
+    MS_LOG(EXCEPTION) << "lr should be a positive scalar";
+  }
+  l1_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l1");
+  if (l1_ < 0) {
+    MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar";
+  }
+  l2_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "l2");
+  if (l2_ < 0) {
+    MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar";
+  }
+  lr_power_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "lr_power");
+  if (lr_power_ > 0) {
+    MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
+  }
+}
+
+bool SparseApplyFtrlCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                      const std::vector<kernel::AddressPtr> &workspace,
+                                      const std::vector<kernel::AddressPtr> & /*outputs*/) {
+  if (inputs.size() < kSparseApplyFtrlInputSize) {
+    MS_LOG(EXCEPTION) << "Wrong input size!";
+  }
+
+  auto var = reinterpret_cast<float *>(inputs[0]->addr);
+  auto accum = reinterpret_cast<float *>(inputs[1]->addr);
+  auto linear = reinterpret_cast<float *>(inputs[2]->addr);
+  auto grad = reinterpret_cast<float *>(inputs[3]->addr);
+  auto indices = reinterpret_cast<int *>(inputs[4]->addr);
+  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
+  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
+  auto tmp_grad = reinterpret_cast<float *>(workspace[2]->addr);
+  auto tmp_indices = reinterpret_cast<int *>(workspace[3]->addr);
+  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
+  SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_});
+  TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad,
+                               var_first_dim_size_, var_outer_dim_size_);
+
+  MultiThreadComputeParams input_params;
+  input_params.var_ = var;
+  input_params.accum_ = accum;
+  input_params.linear_ = linear;
+  input_params.lr_ = lr_;
+  input_params.l1_ = l1_;
+  input_params.l2_ = l2_;
+  input_params.lr_power_ = lr_power_;
+  input_params.sparse_grad_ = unique_sparse_grad;
+  input_params.var_first_dim_size_ = var_first_dim_size_;
+  input_params.var_outer_dim_size_ = var_outer_dim_size_;
+  MultiThreadCompute(ComputeFtrl, &input_params, unique_sparse_grad.indices_size_);
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h
new file mode 100644
index 0000000000..af8796d8a5
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SparseApplyFtrlCPUKernel : public CPUKernel {
+ public:
+  SparseApplyFtrlCPUKernel() = default;
+  ~SparseApplyFtrlCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ protected:
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+  float lr_{0};
+  float l1_{0};
+  float l2_{0};
+  float lr_power_{0};
+};
+
+MS_REG_CPU_KERNEL(SparseApplyFtrl,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyFtrlCPUKernel);
+
+MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyFtrlCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
new file mode 100644
index 0000000000..636d92dcbb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc
@@ -0,0 +1,151 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kSparseApplyLazyAdamInputSize = 11;
+
+void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto var = input_params->var_;
+  auto m = input_params->m_;
+  auto v = input_params->v_;
+  auto lr = input_params->lr_;
+  auto beta1 = input_params->beta1_;
+  auto beta2 = input_params->beta2_;
+  auto epsilon = input_params->epsilon_;
+  auto use_nesterov = input_params->use_nesterov_;
+  auto unique_sparse_grad = input_params->sparse_grad_;
+  auto var_first_dim_size = input_params->var_first_dim_size_;
+  auto var_outer_dim_size = input_params->var_outer_dim_size_;
+  for (size_t i = start; i < end; ++i) {
+    int index = unique_sparse_grad.indices_[i];
+    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range";
+    }
+    size_t start_index = var_outer_dim_size * index;
+    size_t end_index = start_index + var_outer_dim_size;
+    for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) {
+      auto summed_grad = unique_sparse_grad.value_[k];
+      m[j] = beta1 * m[j] + (1 - beta1) * summed_grad;
+      v[j] = beta2 * v[j] + (1 - beta2) * summed_grad * summed_grad;
+      if (use_nesterov) {
+        var[j] -= lr * (m[j] * beta1 + (1 - beta1) * summed_grad) / (std::sqrt(v[j]) + epsilon);
+      } else {
+        var[j] -= lr * m[j] / (std::sqrt(v[j]) + epsilon);
+      }
+    }
+  }
+}
+}  // namespace
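Unlike the sparse Adam kernel earlier in this diff, the lazy variant touches m, v, and var only for rows named in indices; there is no dense decay pass and no dense weight pass, so untouched rows keep stale state. A toy sketch of that row-local update (values illustrative):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      float lr = 0.01f, b1 = 0.9f, b2 = 0.999f, eps = 1e-8f;
      std::vector<float> var = {1.0f, 2.0f}, m = {0.1f, 0.1f}, v = {0.2f, 0.2f};
      std::vector<int> indices = {1};    // only row 1 is updated
      std::vector<float> grad = {0.5f};
      for (size_t i = 0; i < indices.size(); ++i) {
        int r = indices[i];
        m[r] = b1 * m[r] + (1 - b1) * grad[i];
        v[r] = b2 * v[r] + (1 - b2) * grad[i] * grad[i];
        var[r] -= lr * m[r] / (std::sqrt(v[r]) + eps);
      }
      std::printf("%f %f\n", var[0], var[1]);  // var[0] untouched: that is the "lazy" part
    }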
"var must be at least 1D"; + } + var_first_dim_size_ = var_shape[0]; + for (size_t i = 1; i < var_shape.size(); ++i) { + if (var_shape[i] != grad_shape[i]) { + MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; + } + var_outer_dim_size_ *= var_shape[i]; + } + if (indices_shape.size() != 1) { + MS_LOG(EXCEPTION) << "indices must be 1D"; + } + indices_size_ = indices_shape[0]; + if (grad_shape[0] != indices_size_) { + MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; + } + if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { + use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); + } +} + +bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyLazyAdamInputSize) { + MS_LOG(EXCEPTION) << "Error input size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto m = reinterpret_cast(inputs[1]->addr); + auto v = reinterpret_cast(inputs[2]->addr); + auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; + if (beta1_power == 1) { + MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; + } + auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; + auto lr = reinterpret_cast(inputs[5]->addr)[0]; + auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; + auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; + auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; + auto grad = reinterpret_cast(inputs[9]->addr); + auto indices = reinterpret_cast(inputs[10]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + auto tmp_grad = reinterpret_cast(workspace[2]->addr); + auto tmp_indices = reinterpret_cast(workspace[3]->addr); + + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); + TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, + var_first_dim_size_, var_outer_dim_size_); + + lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); + MultiThreadComputeParams input_params; + input_params.var_ = var; + input_params.m_ = m; + input_params.v_ = v; + input_params.lr_ = lr; + input_params.beta1_ = beta1; + input_params.beta2_ = beta2; + input_params.epsilon_ = epsilon; + input_params.use_nesterov_ = use_nesterov_; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeLazyAdam, &input_params, unique_sparse_grad.indices_size_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h new file mode 100644 index 0000000000..ee95db8f33 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
new file mode 100644
index 0000000000..ee95db8f33
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SparseApplyLazyAdamCPUKernel : public CPUKernel {
+ public:
+  SparseApplyLazyAdamCPUKernel() = default;
+  ~SparseApplyLazyAdamCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+  bool use_nesterov_{false};
+};
+
+MS_REG_CPU_KERNEL(SparseApplyLazyAdam,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyLazyAdamCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
new file mode 100644
index 0000000000..efba35ad8c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+constexpr size_t kSparseApplyProximalAdagradInputSize = 7;
+
+void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start, size_t end) {
+  MS_EXCEPTION_IF_NULL(input_params);
+  auto var = input_params->var_;
+  auto accum = input_params->accum_;
+  auto lr = input_params->lr_;
+  auto l1 = input_params->l1_;
+  auto l2 = input_params->l2_;
+  auto unique_sparse_grad = input_params->sparse_grad_;
+  auto var_first_dim_size = input_params->var_first_dim_size_;
+  auto var_outer_dim_size = input_params->var_outer_dim_size_;
+  for (size_t i = start; i < end; ++i) {
+    int index = unique_sparse_grad.indices_[i];
+    if (index < 0 || IntToSize(index) >= var_first_dim_size) {
+      MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process";
+    }
+    size_t start_index = var_outer_dim_size * index;
+    size_t end_index = start_index + var_outer_dim_size;
+    for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) {
+      auto summed_grad = unique_sparse_grad.value_[k];
+      accum[j] += summed_grad * summed_grad;
+      auto learning_rate = lr * (1 / std::sqrt(accum[j]));
+      auto prox_v = var[j];
+      prox_v -= summed_grad * learning_rate;
+      if (l1 > 0) {
+        var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast<float>(0.0)) /
+                 (1 + l2 * learning_rate);
+      } else {
+        var[j] = prox_v / (1 + l2 * learning_rate);
+      }
+    }
+  }
+}
+}  // namespace
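ComputeProximalAdagrad combines Adagrad's per-element learning-rate scaling with an L1 soft-threshold and an L2 shrink applied to the proximal value. A scalar sketch of the same step (ProxStep is an invented name; values illustrative):

    #include <cmath>
    #include <cstdio>

    float ProxStep(float var, float grad, float lr, float l1, float l2, float *accum) {
      *accum += grad * grad;
      float rate = lr / std::sqrt(*accum);       // Adagrad-scaled learning rate
      float prox = var - grad * rate;            // gradient step first
      if (l1 > 0) {                              // then soft-threshold toward zero
        float sign = prox > 0 ? 1.0f : (prox < 0 ? -1.0f : 0.0f);
        return sign * std::fmax(std::fabs(prox) - rate * l1, 0.0f) / (1 + l2 * rate);
      }
      return prox / (1 + l2 * rate);             // pure L2 shrink when l1 == 0
    }

    int main() {
      float accum = 0.1f;
      std::printf("%f\n", ProxStep(1.0f, 0.5f, 0.1f, 0.01f, 0.0f, &accum));
    }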
scalar"; + } + if (!l2_shape.empty()) { + MS_LOG(EXCEPTION) << "l2 is not a scalar"; + } +} + +bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector & /*outputs*/) { + if (inputs.size() < kSparseApplyProximalAdagradInputSize) { + MS_LOG(EXCEPTION) << "Wrong input size!"; + } + + auto var = reinterpret_cast(inputs[0]->addr); + auto accum = reinterpret_cast(inputs[1]->addr); + auto lr = reinterpret_cast(inputs[2]->addr)[0]; + auto l1 = reinterpret_cast(inputs[3]->addr)[0]; + auto l2 = reinterpret_cast(inputs[4]->addr)[0]; + auto grad = reinterpret_cast(inputs[5]->addr); + auto indices = reinterpret_cast(inputs[6]->addr); + auto new_grad = reinterpret_cast(workspace[0]->addr); + auto new_indices = reinterpret_cast(workspace[1]->addr); + SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); + ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, + var_outer_dim_size_); + + MultiThreadComputeParams input_params; + input_params.var_ = var; + input_params.accum_ = accum; + input_params.lr_ = lr; + input_params.l1_ = l1; + input_params.l2_ = l2; + input_params.sparse_grad_ = unique_sparse_grad; + input_params.var_first_dim_size_ = var_first_dim_size_; + input_params.var_outer_dim_size_ = var_outer_dim_size_; + MultiThreadCompute(ComputeProximalAdagrad, &input_params, unique_sparse_grad.indices_size_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h new file mode 100644 index 0000000000..56b180ec0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
new file mode 100644
index 0000000000..56b180ec0b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SparseApplyProximalAdagradCPUKernel : public CPUKernel {
+ public:
+  SparseApplyProximalAdagradCPUKernel() = default;
+  ~SparseApplyProximalAdagradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+  void InitInputOutputSize(const CNodePtr &kernel_node) override;
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  size_t indices_size_{0};
+  size_t var_first_dim_size_{0};
+  size_t var_outer_dim_size_{1};
+};
+
+MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyProximalAdagradCPUKernel);
+
+MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyProximalAdagradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc
new file mode 100644
index 0000000000..1e759390a2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.cc
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <thread>
+#include "backend/kernel_compiler/cpu/sub_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+  if (shape.size() == 1) {
+    if (shape[0] != 1) {
+      MS_LOG(EXCEPTION) << "input 1 only supports scalar";
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "input 1 only supports scalar";
+  }
+}
+
+void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) {
+  for (size_t i = 0; i < lens; i++) {
+    out_addr[i] = in_addr[i] - offset;
+  }
+}
+
+bool SubCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                          const std::vector<kernel::AddressPtr> & /*workspace*/,
+                          const std::vector<kernel::AddressPtr> &outputs) {
+#if defined(_WIN32) || defined(_WIN64)
+  auto start_time = std::chrono::steady_clock::now();
+#else
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
+#endif
+  auto input_addr = reinterpret_cast<int *>(inputs[0]->addr);
+  auto output_addr = reinterpret_cast<int *>(outputs[0]->addr);
+  offset_ = *reinterpret_cast<int *>(inputs[1]->addr);
+  MS_LOG(INFO) << "offset: " << offset_;
+  auto lens = inputs[0]->size / sizeof(int);
+  if (lens < 10000) {
+    for (size_t i = 0; i < lens; i++) {
+      output_addr[i] = input_addr[i] - offset_;
+    }
+  } else {
+    const size_t thread_num = 4;
+    std::thread threads[4];
+    size_t process_lens = (lens + thread_num - 1) / thread_num;
+    size_t process_offset = 0;
+    for (size_t i = 0; i < thread_num; i++) {
+      // Clamp the final chunk before launching, so no worker runs past the buffer.
+      if (process_offset + process_lens > lens) {
+        process_lens = lens - process_offset;
+      }
+      threads[i] =
+        std::thread(sub_task, input_addr + process_offset, output_addr + process_offset, process_lens, offset_);
+      process_offset += process_lens;
+    }
+    for (size_t i = 0; i < thread_num; i++) {
+      threads[i].join();
+    }
+  }
+#if defined(_WIN32) || defined(_WIN64)
+  auto end_time = std::chrono::steady_clock::now();
+  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
+  MS_LOG(INFO) << "SubCPUKernel, used time: " << cost.count() << " us";
+#else
+  (void)gettimeofday(&end_time, nullptr);
+  uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us";
+#endif
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
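The threaded path above splits the buffer into chunks of ceil(lens / thread_num) elements; the final chunk must be clamped before its thread is launched, or the last worker would read past the buffer (the clamp-then-launch ordering in this revision reflects that). A small self-contained check of the partition arithmetic (Partition is an invented helper):

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Partition [0, lens) into at most `threads` chunks of ceil(lens/threads),
    // clamping the final chunk, mirroring the launch loop above.
    std::vector<std::pair<size_t, size_t>> Partition(size_t lens, size_t threads) {
      std::vector<std::pair<size_t, size_t>> out;
      size_t chunk = (lens + threads - 1) / threads, offset = 0;
      for (size_t i = 0; i < threads && offset < lens; ++i) {
        size_t len = std::min(chunk, lens - offset);
        out.emplace_back(offset, len);
        offset += len;
      }
      return out;
    }

    int main() {
      for (auto [off, len] : Partition(10, 4)) std::printf("[%zu,+%zu) ", off, len);
      std::printf("\n");  // [0,+3) [3,+3) [6,+3) [9,+1)
    }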
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h
new file mode 100644
index 0000000000..d1b55ded90
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_cpu_kernel.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SubCPUKernel : public CPUKernel {
+ public:
+  SubCPUKernel() : offset_(0) {}
+  ~SubCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  int offset_;
+};
+
+MS_REG_CPU_KERNEL(
+  Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  SubCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc
new file mode 100644
index 0000000000..8ec3698cf6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/cpu/transpose_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+namespace mindspore {
+namespace kernel {
+const size_t kMaxDim = 100;
+void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+  axis_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "perm");
+  if (shape_.size() != axis_.size()) {
+    MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal.";
+  }
+}
+bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                   const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                   const std::vector<kernel::AddressPtr> &outputs) {
+  auto input = reinterpret_cast<float *>(inputs[0]->addr);
+  auto output = reinterpret_cast<float *>(outputs[0]->addr);
+  size_t size = IntToSize(inputs[0]->size / sizeof(float));
+  size_t shape_size = IntToSize(shape_.size());
+  if (shape_size > kMaxDim) {
+    MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs.";
+  }
+  size_t pos_array[kMaxDim];
+  size_t size_offset[kMaxDim];
+  size_offset[0] = size / shape_[0];
+  for (size_t i = 1; i < shape_size; i++) {
+    size_offset[i] = size_offset[i - 1] / shape_[i];
+  }
+  for (size_t position = 0; position < size; position += 1) {
+    size_t temp_position = position;
+    pos_array[0] = temp_position / size_offset[0];
+    for (size_t i = 1; i < shape_size; i++) {
+      temp_position -= pos_array[i - 1] * size_offset[i - 1];
+      pos_array[i] = temp_position / size_offset[i];
+    }
+    size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]];
+    size_t new_position_size = 1;
+    for (int j = SizeToInt(shape_size) - 2; j >= 0; j--) {
+      new_position_size *= shape_[axis_[j + 1]];
+      new_position += pos_array[axis_[j]] * new_position_size;
+    }
+    output[new_position] = input[position];
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
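The transpose works purely by index arithmetic: decompose each linear position into per-axis coordinates using the input strides, then rebuild a linear position under the permuted axis order. A standalone sketch for a 2x3 example (variable names are illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<size_t> shape = {2, 3};
      std::vector<int> perm = {1, 0};
      std::vector<float> in = {0, 1, 2, 3, 4, 5}, out(6);
      std::vector<size_t> stride(shape.size());  // elements per step of each axis
      stride.back() = 1;
      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) stride[i] = stride[i + 1] * shape[i + 1];
      for (size_t pos = 0; pos < in.size(); ++pos) {
        std::vector<size_t> coord(shape.size());
        size_t rem = pos;  // decompose into per-axis coordinates
        for (size_t i = 0; i < shape.size(); ++i) { coord[i] = rem / stride[i]; rem %= stride[i]; }
        // rebuild the linear position under the permuted axis order
        size_t out_stride = 1, new_pos = coord[perm[shape.size() - 1]];
        for (int j = static_cast<int>(shape.size()) - 2; j >= 0; --j) {
          out_stride *= shape[perm[j + 1]];
          new_pos += coord[perm[j]] * out_stride;
        }
        out[new_pos] = in[pos];
      }
      for (float v : out) std::printf("%g ", v);  // 0 3 1 4 2 5
      std::printf("\n");
    }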
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h
new file mode 100644
index 0000000000..15796f9f3c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_
+#include <vector>
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+namespace mindspore {
+namespace kernel {
+class TransposeCPUFwdKernel : public CPUKernel {
+ public:
+  TransposeCPUFwdKernel() = default;
+  ~TransposeCPUFwdKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
+              const std::vector<kernel::AddressPtr> &outputs) override;
+
+ private:
+  std::vector<size_t> shape_;
+  std::vector<int> axis_;
+};
+
+MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  TransposeCPUFwdKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
new file mode 100644
index 0000000000..39f535a2af
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+                      ArgmaxGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
new file mode 100644
index 0000000000..61a53c5b40
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmax_gpu_kernel.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh"
+namespace mindspore {
+namespace kernel {
+#define ARGMAX_MAX_DIMENSION 2
+template <typename T>
+class ArgmaxGpuKernel : public GpuKernel {
+ public:
+  ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {}
+  ~ArgmaxGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    int *output = GetDeviceAddress<int>(outputs, 0);
+    CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output,
+              reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output.";
+      return false;
+    }
+    auto output_type = GetValue<TypePtr>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type"));
+    if (output_type->type_id() != TypeId::kNumberTypeInt32) {
+      MS_LOG(EXCEPTION) << "Argmax only supports int32 output type.";
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() > ARGMAX_MAX_DIMENSION) {
+      MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but argmax supports max " << ARGMAX_MAX_DIMENSION
+                        << "-D inputs.";
+    }
+
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      axis_ += SizeToInt(input_shape.size());
+    }
+    if (input_shape.size() == 1) {
+      batch_size_ = 0;
+      channel_size_ = input_shape[0];
+      input_size_ = sizeof(T) * channel_size_;
+      output_size_ = sizeof(int);
+    } else {
+      batch_size_ = input_shape[0];
+      channel_size_ = input_shape[1];
+      input_size_ = sizeof(T) * batch_size_ * channel_size_;
+      output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t batch_size_;
+  size_t channel_size_;
+  int axis_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_
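Note: Init() above collapses a 1-D or 2-D input into (batch_size_, channel_size_) after normalizing a negative axis. A minimal host-only sketch of that bookkeeping (illustrative, not part of the patch):

// Standalone sketch of the shape handling in ArgmaxGpuKernel::Init.
#include <cassert>
#include <cstddef>
#include <vector>

struct ArgmaxDims { size_t batch, channel; };

ArgmaxDims NormalizeArgmax(const std::vector<size_t> &shape, int axis) {
  if (axis < 0) axis += static_cast<int>(shape.size());  // e.g. axis=-1 on 2-D -> 1
  assert(axis >= 0 && axis < static_cast<int>(shape.size()));
  if (shape.size() == 1) return {0, shape[0]};  // batch_size_ = 0 marks the 1-D case
  return {shape[0], shape[1]};
}

int main() {
  ArgmaxDims d = NormalizeArgmax({32, 10}, -1);
  assert(d.batch == 32 && d.channel == 10);
}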
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc
new file mode 100644
index 0000000000..5ead387ccc
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  ArgMaxWithValue,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+  ArgmaxWithValueGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  ArgMaxWithValue,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+  ArgmaxWithValueGpuKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
new file mode 100644
index 0000000000..d2369023fb
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxwithvalue_gpu_kernel.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class ArgmaxWithValueGpuKernel : public GpuKernel {
+ public:
+  ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {}
+  ~ArgmaxWithValueGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 1);
+    S *index = GetDeviceAddress<S>(outputs, 0);
+    CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
+    int dims = SizeToInt(shape.size());
+    int axis = GetAttr<int>(kernel_node, "axis");
+    if (axis < 0) {
+      axis += dims;
+    }
+    input_size_ = sizeof(T);
+    for (auto x : shape) {
+      input_size_ *= x;
+    }
+    output_size_ = sizeof(S);
+    for (auto x : output_shape) {
+      output_size_ *= x;
+    }
+    bound_ = SizeToInt(shape[axis]);
+    outerSize_ = 1;
+    for (int i = axis - 1; i >= 0; i--) {
+      outerSize_ *= SizeToInt(shape[i]);
+    }
+
+    innerSize_ = 1;
+    for (int i = axis + 1; i < dims; i++) {
+      innerSize_ *= SizeToInt(shape[i]);
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  size_t output_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  int bound_;
+  int outerSize_;
+  int innerSize_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_
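Note: the kernel above factors the input around the reduction axis into (outerSize_, bound_, innerSize_). A host-only sketch of that factorization (illustrative, not part of the patch):

// For a reduction axis k of shape s: outer = prod(s[0..k)), bound = s[k],
// inner = prod(s[k+1..)). Element (o, b, i) of the flattened input lives at
// (o * bound + b) * inner + i, so the argmax scans b for each (o, i) pair.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {2, 3, 4};
  int axis = 1;
  int outer = 1, inner = 1, bound = shape[axis];
  for (int i = 0; i < axis; ++i) outer *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) inner *= shape[i];
  std::printf("outer=%d bound=%d inner=%d\n", outer, bound, inner);  // outer=2 bound=3 inner=4
}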
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc
new file mode 100644
index 0000000000..5d34a1c9c2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ArrayReduceGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ArrayReduceGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ArrayReduceGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ArrayReduceGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      ArrayReduceGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      ArrayReduceGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
new file mode 100644
index 0000000000..b96f63670d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/array_reduce_gpu_kernel.h
@@ -0,0 +1,237 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_
+
+#include <map>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+namespace mindspore {
+namespace kernel {
+const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
+  {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX},
+  {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG},
+  {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD},
+};
+template <typename T>
+class ArrayReduceGpuKernel : public GpuKernel {
+ public:
+  ArrayReduceGpuKernel()
+    : cudnn_handle_(nullptr),
+      reduce_tensor_op_(CUDNN_REDUCE_TENSOR_ADD),
+      data_type_(CUDNN_DATA_FLOAT),
+      nan_prop_(CUDNN_NOT_PROPAGATE_NAN),
+      reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES),
+      reduce_tensor_descriptor_(nullptr),
+      inputA_descriptor_(nullptr),
+      outputC_descriptor_(nullptr),
+      keep_dims_(false),
+      all_match_(false),
+      is_null_input_(false),
+      input_size_(0),
+      output_size_(0),
+      workspace_size_(0) {}
+  ~ArrayReduceGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    T *workspace_addr = GetDeviceAddress<T>(workspace, 0);
+
+    const float alpha = 1;
+    const float beta = 0;
+    if (all_match_) {
+      MS_LOG(WARNING)
+        << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel.";
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch.");
+    } else {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_,
+                          &alpha, inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr),
+        "cudnnReduceTensor failed.");
+    }
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but reduce op needs 1 output.";
+      return false;
+    }
+    int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size());
+
+    if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueTuple>() ||
+        AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<ValueList>()) {
+      auto attr_axis = GetAttr<std::vector<int>>(kernel_node, "axis");
+      if (attr_axis.empty()) {
+        axis_.push_back(-1);
+      } else {
+        for (auto axis : attr_axis) {
+          axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis);
+        }
+      }
+    } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa<Int32Imm>()) {
+      int axis = GetAttr<int>(kernel_node, "axis");
+      axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis);
+    } else {
+      MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
+    }
+    keep_dims_ = GetAttr<bool>(kernel_node, "keep_dims");
+
+    auto inputA_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto outputC_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(inputA_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "ArrayReduceGpuKernel input is null.";
+      InitSizeLists();
+      return true;
+    }
+    InferInAndOutDesc(inputA_shape, outputC_shape);
+    InferArrayReduceType(kernel_node);
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_),
+                                "cudnnCreateReduceTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&outputC_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+  }
+  void InitSizeLists() override {
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_),
+                                "cudnnGetTensorSizeInBytes failed.");
+    input_size_list_.push_back(input_size_);
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_),
+                                "cudnnGetTensorSizeInBytes failed.");
+    output_size_list_.push_back(output_size_);
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_,
+                                     outputC_descriptor_, &workspace_size_),
+      "cudnnGetReductionWorkspaceSize failed.");
+    workspace_size_list_.push_back(workspace_size_);
+    return;
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_),
+                               "cudnnDestroyReduceTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+  }
+  void InferArrayReduceType(const CNodePtr &kernel_node) {
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto iter = kReduceTypeMap.find(kernel_name);
+    if (iter == kReduceTypeMap.end()) {
+      MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
+    } else {
+      reduce_tensor_op_ = iter->second;
+    }
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, CUDNN_DATA_FLOAT, nan_prop_,
+                                     reduce_indices_, CUDNN_32BIT_INDICES),
+      "cudnnSetReduceTensorDescriptor failed.");
+    return;
+  }
+  void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
+    std::vector<int> inputA;
+    std::vector<size_t> outputC_shape = output_shape;
+    ShapeNdTo4d(input_shape, &inputA);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                           inputA[0], inputA[1], inputA[2], inputA[3]),
+                                "cudnnSetTensor4dDescriptor failed.");
+
+    if (axis_[0] == -1) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
+        "cudnnSetTensor4dDescriptor failed.");
+      if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
+        all_match_ = true;
+      }
+      return;
+    }
+    if (!keep_dims_) {
+      for (auto i : axis_) {
+        (void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
+      }
+    }
+    std::vector<int> outputC;
+    ShapeNdTo4d(outputC_shape, &outputC);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
+                                                           outputC[0], outputC[1], outputC[2], outputC[3]),
+                                "cudnnSetTensor4dDescriptor failed.");
+    if (inputA == outputC) {
+      all_match_ = true;
+    }
+    return;
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  cudnnReduceTensorOp_t reduce_tensor_op_;
+  cudnnDataType_t data_type_;
+  cudnnNanPropagation_t nan_prop_;
+  cudnnReduceTensorIndices_t reduce_indices_;
+  cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_;
+  cudnnTensorDescriptor_t inputA_descriptor_;
+  cudnnTensorDescriptor_t outputC_descriptor_;
+
+  std::vector<int> axis_;
+  bool keep_dims_;
+  bool all_match_;
+  bool is_null_input_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_
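Note: the keep_dims handling in InferInAndOutDesc() above re-inserts the reduced axes before padding the output shape to 4-D. A host-only sketch of that shape manipulation (illustrative, not part of the patch; ShapeNdTo4d is approximated here by front-padding with 1s):

// ReduceSum of {2, 3, 4} over axis 1 with keep_dims=false infers output {2, 4};
// the reduced axis is restored as a 1 so cuDNN sees matching 4-D descriptors.
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> out_shape = {2, 4};
  std::vector<int> axes = {1};
  for (int a : axes) out_shape.insert(out_shape.begin() + a, 1);              // -> {2, 1, 4}
  while (out_shape.size() < 4) out_shape.insert(out_shape.begin(), 1);        // -> {1, 2, 1, 4}
  for (size_t d : out_shape) std::printf("%zu ", d);                          // prints: 1 2 1 4
}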
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
new file mode 100644
index 0000000000..f5979dc62d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ConcatV2GpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Concat,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      ConcatV2GpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ConcatV2GpuFwdKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
new file mode 100644
index 0000000000..15ccedcaec
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
@@ -0,0 +1,128 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ConcatV2GpuFwdKernel : public GpuKernel {
+ public:
+  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
+  ~ConcatV2GpuFwdKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (inputs.size() == 2) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+
+    if (inputs.size() == 3) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *input_2 = GetDeviceAddress<T>(inputs, 2);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+
+    if (inputs.size() == 4) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *input_2 = GetDeviceAddress<T>(inputs, 2);
+      T *input_3 = GetDeviceAddress<T>(inputs, 3);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+      axis_ += SizeToInt(input_shape.size());
+    }
+
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t i = 0; i < input_num; i++) {
+      auto input_size = sizeof(T);
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+      for (size_t j = 0; j < input_shape.size(); j++) {
+        input_size *= input_shape[j];
+        if (j >= IntToSize(axis_)) {
+          w_[i] *= SizeToInt(input_shape[j]);
+        }
+      }
+      input_size_list_.push_back(input_size);
+    }
+
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    output_size_ = sizeof(T);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      output_size_ *= output_shape[i];
+    }
+    output_size_list_.push_back(output_size_);
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {}
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num < 2 || input_num > 4) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs between 2 and 4 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
+      return false;
+    }
+    return true;
+  }
+  int w_[4] = {1, 1, 1, 1};
+  int axis_;
+  size_t output_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
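Note: for each input, w_[i] above is the product of the dimensions from the concat axis onward, i.e. the contiguous row width the CUDA kernel copies per step. A host-only sketch of that computation (illustrative, not part of the patch):

// Two inputs {2, 3, 5} and {2, 4, 5} concatenated on axis 1.
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<int>> shapes = {{2, 3, 5}, {2, 4, 5}};
  size_t axis = 1;
  for (size_t i = 0; i < shapes.size(); ++i) {
    int w = 1;
    for (size_t j = axis; j < shapes[i].size(); ++j) w *= shapes[i][j];
    std::printf("w_[%zu] = %d\n", i, w);  // w_[0] = 15, w_[1] = 20
  }
}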
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc
new file mode 100644
index 0000000000..8d3c06e805
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  GatherV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+  GatherGpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  GatherV2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+  GatherGpuFwdKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h
new file mode 100644
index 0000000000..2211361cee
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
+#define MINDSPORE_GATHER_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class GatherGpuFwdKernel : public GpuKernel {
+ public:
+  GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {}
+  ~GatherGpuFwdKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    auto input_dim1 = input_shapes_[IntToSize(axis_)];
+    Gather(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+           reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but GatherGpuFwdKernel needs 2 inputs.";
+    }
+    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    axis_ = GetAttr<int>(kernel_node, "axis");
+    if (axis_ < 0) {
+      axis_ = axis_ + SizeToInt(input_shapes_.size());
+    }
+
+    Reshape();
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
+  void InitSizeLists() override {
+    size_t size = GetSize(input_shapes_);
+    input_size_list_.push_back(size);
+
+    size = GetSize(indices_shapes_);
+    input_size_list_.push_back(size);
+
+    size = GetSize(output_shapes_);
+    output_size_list_.push_back(size);
+  }
+
+ private:
+  void Reshape() {
+    size_t dim_before_axis = 1;
+    for (size_t i = 0; i < IntToSize(axis_); i++) {
+      dim_before_axis *= output_shapes_[i];
+    }
+
+    size_t dim_of_indices = 1;
+    for (size_t i = 0; i < indices_shapes_.size(); i++) {
+      dim_of_indices *= indices_shapes_[i];
+    }
+
+    size_t dim_after_indices = 1;
+    for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
+      dim_after_indices *= output_shapes_[i];
+    }
+
+    dims_[0] = dim_before_axis;
+    dims_[1] = dim_of_indices;
+    dims_[2] = dim_after_indices;
+    return;
+  }
+  size_t GetSize(const std::vector<size_t> &shape) const {
+    if (shape.size() == 0) {
+      return 0;
+    }
+    size_t result = sizeof(T);
+    for (size_t i = 0; i < shape.size(); i++) {
+      result *= shape[i];
+    }
+    return result;
+  }
+
+  std::vector<size_t> input_shapes_;
+  std::vector<size_t> indices_shapes_;
+  std::vector<size_t> output_shapes_;
+
+  size_t dims_[3] = {};
+  int axis_;
+  cudnnHandle_t handle_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_GATHER_GPU_KERNEL_H
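Note: Reshape() above views the gather output as (dim_before_axis, dim_of_indices, dim_after_indices) so the CUDA kernel can address it with three nested strides. A host-only sketch (illustrative, not part of the patch):

// Params {2, 4, 5} gathered on axis 1 with indices of shape {6} yield output {2, 6, 5}.
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> indices = {6};
  std::vector<size_t> output = {2, 6, 5};
  size_t axis = 1;
  size_t before = 1, of = 1, after = 1;
  for (size_t i = 0; i < axis; ++i) before *= output[i];
  for (size_t d : indices) of *= d;
  for (size_t i = axis + indices.size(); i < output.size(); ++i) after *= output[i];
  std::printf("%zu %zu %zu\n", before, of, after);  // prints: 2 6 5
}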
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc
new file mode 100644
index 0000000000..e764a08dc8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.cc
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(OneHot,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      OneHotGpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(OneHot,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      OneHotGpuFwdKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h
new file mode 100644
index 0000000000..6c46a63e69
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/one_hot_gpu_kernel.h
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class OneHotGpuFwdKernel : public GpuKernel {
+ public:
+  OneHotGpuFwdKernel() : input_size_(1), output_size_(1), depth_(0), left_dim_size_(1), right_dim_size_(1) {}
+  ~OneHotGpuFwdKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    const S *indices = GetDeviceAddress<S>(inputs, 0);
+    const T *on_value = GetDeviceAddress<T>(inputs, 1);
+    const T *off_value = GetDeviceAddress<T>(inputs, 2);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    OneHot(indices, depth_, on_value, off_value, left_dim_size_, right_dim_size_, output,
+           reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    int axis = GetAttr<int>(kernel_node, "axis");
+    auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    int input_size = SizeToInt(input.size());
+    const int default_axis = -1;
+
+    // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims).
+    for (int i = 0; i < input_size; i++) {
+      auto dim_size = input[IntToSize(i)];
+      if (axis == default_axis || i < axis) {
+        left_dim_size_ *= dim_size;
+      }
+      if (axis != default_axis && i >= axis) {
+        right_dim_size_ *= dim_size;
+      }
+    }
+    for (auto size : input) {
+      input_size_ *= size;
+    }
+    for (auto size : output) {
+      output_size_ *= size;
+    }
+    if (axis >= input_size) {
+      MS_LOG(ERROR) << "Invalid one-hot axis value " << axis << " for input dims size " << input.size() << ".";
+      return false;
+    }
+    if (axis == default_axis) {
+      depth_ = output[output.size() - 1];
+    } else {
+      depth_ = output[IntToSize(axis)];
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    // inputs: indices, depth
+    input_size_list_.push_back((input_size_ + 1) * sizeof(S));
+    output_size_list_.push_back(output_size_ * sizeof(T));
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  size_t input_size_;
+  size_t output_size_;
+
+  size_t depth_;
+  size_t left_dim_size_;
+  size_t right_dim_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
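Note: the (left_dims, depth, right_dims) compression above folds every index dimension before the one-hot axis into left_dim_size_ and everything from the axis onward into right_dim_size_, with depth read from the output shape. A host-only sketch (illustrative, not part of the patch):

// Indices {2, 3} one-hot encoded with axis = -1 and depth 10 produce output {2, 3, 10}.
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> in = {2, 3}, out = {2, 3, 10};
  int axis = -1;
  size_t left = 1, right = 1;
  for (int i = 0; i < static_cast<int>(in.size()); ++i) {
    if (axis == -1 || i < axis) left *= in[i];
    if (axis != -1 && i >= axis) right *= in[i];
  }
  size_t depth = (axis == -1) ? out.back() : out[axis];
  std::printf("left=%zu right=%zu depth=%zu\n", left, right, depth);  // left=6 right=1 depth=10
}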
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
new file mode 100644
index 0000000000..3c1323de07
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.cc
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Select,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeBool)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SelectGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Select,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeBool)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      SelectGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Select,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeBool)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      SelectGpuKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h
new file mode 100644
index 0000000000..73e60c44bd
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/select_gpu_kernel.h
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SelectGpuKernel : public GpuKernel {
+ public:
+  SelectGpuKernel() : input_size_(0), output_size_(0) {}
+  ~SelectGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    bool *input_cond = GetDeviceAddress<bool>(inputs, 0);
+    T *input_x = GetDeviceAddress<T>(inputs, 1);
+    T *input_y = GetDeviceAddress<T>(inputs, 2);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    CalSelect(output_size_ / sizeof(T), input_cond, input_x, input_y, output,
+              reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = sizeof(bool);
+    output_size_ = sizeof(T);
+    for (size_t x : shape) {
+      input_size_ = input_size_ * x;
+      output_size_ = output_size_ * x;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(output_size_);
+    input_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but SelectGpuKernel needs 3 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but SelectGpuKernel needs 1 output.";
+      return false;
+    }
+    return true;
+  }
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  size_t input_size_;
+  size_t output_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
new file mode 100644
index 0000000000..4c9ff2b7f4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SliceGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      SliceGpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SliceGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SliceGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SliceGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      SliceGpuFwdKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h
new file mode 100644
index 0000000000..f8ecb9ccf0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_gpu_kernel.h
@@ -0,0 +1,162 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SliceGpuFwdKernel : public GpuKernel {
+ public:
+  SliceGpuFwdKernel()
+    : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
+  ~SliceGpuFwdKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    if (is_strided_slice_) {
+      CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output,
+                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3],
+                    input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], input, output,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    ShapeNdTo4d(input_shape, &input_shape_);
+    auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides");
+    if (strides) {
+      strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
+      for (auto i = strides_.size(); i < 4; i++) {
+        (void)strides_.insert(strides_.begin(), 1);
+      }
+      size_ = GetAttr<std::vector<int>>(kernel_node, "end");
+      is_strided_slice_ = true;
+    } else {
+      size_ = GetAttr<std::vector<int>>(kernel_node, "size");
+    }
+    for (auto i = begin_.size(); i < 4; i++) {
+      (void)begin_.insert(begin_.begin(), 0);
+    }
+    for (size_t i = size_.size(); i < 4; i++) {
+      (void)size_.insert(size_.begin(), 1);
+    }
+    for (size_t i = 0; i < begin_.size(); i++) {
+      if (begin_[i] < 0) {
+        begin_[i] = begin_[i] + input_shape_[i];
+      }
+    }
+    for (size_t i = 0; i < size_.size(); i++) {
+      if (size_[i] < 0) {
+        size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
+      }
+      if (begin_[i] == size_[i] && is_strided_slice_) {
+        MS_LOG(WARNING) << "Output is null.";
+        is_null_input_ = true;
+      }
+      if (size_[i] == 0 && strides_[i] > 0) {
+        size_[i] = begin_[i] + 1;
+      }
+    }
+
+    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
+    auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    output_size_ = sizeof(T);
+    for (size_t x : out_shape) {
+      output_size_ = output_size_ * x;
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but SliceGpuFwdKernel needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGpuFwdKernel needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() > 4) {
+      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel only supports 4-D or lower.";
+      return false;
+    }
+    if (input_shape.size() == 0) {
+      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
+      return false;
+    }
+    begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      if ((begin_[i] > 0 && (begin_[i] > SizeToInt(input_shape[i]))) ||
+          (begin_[i] < 0 && (std::abs(begin_[i]) > SizeToInt(input_shape[i])))) {
+        MS_LOG(INFO) << "Begin is out of bounds for dimension " << input_shape[i] << " in axis " << i
+                     << ", reset to 0.";
+        begin_[i] = 0;
+      }
+    }
+    return true;
+  }
+  std::vector<int> begin_;
+  std::vector<int> size_;
+  std::vector<int> strides_;
+  std::vector<int> input_shape_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  bool is_strided_slice_;
+  bool is_null_input_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
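Note: Init() above front-pads shapes to 4-D, wraps negative begins, and wraps negative sizes against the dimension (clamped at zero). A host-only sketch of that normalization (illustrative, not part of the patch):

// A 2-D input {4, 6} with begin {1, -4} and size {2, -1}, padded to 4-D.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shape = {1, 1, 4, 6};
  std::vector<int> begin = {1, -4}, size = {2, -1};
  while (begin.size() < 4) begin.insert(begin.begin(), 0);
  while (size.size() < 4) size.insert(size.begin(), 1);
  for (size_t i = 0; i < 4; ++i) {
    if (begin[i] < 0) begin[i] += shape[i];                 // -4 on dim 6 -> 2
    if (size[i] < 0) size[i] = std::max(size[i] + shape[i], 0);  // -1 on dim 6 -> 5
  }
  for (size_t i = 0; i < 4; ++i) std::printf("[%d:+%d] ", begin[i], size[i]);
  // prints: [0:+1] [0:+1] [1:+2] [2:+5]
}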
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..2eeb3acf73
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.cc
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SliceGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SliceGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  SliceGradGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  SliceGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  SliceGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SliceGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      SliceGradGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SliceGradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h
new file mode 100644
index 0000000000..006cbf0266
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/slice_grad_gpu_kernel.h
@@ -0,0 +1,147 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SliceGradGpuKernel : public GpuKernel {
+ public:
+  SliceGradGpuKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
+  ~SliceGradGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy = GetDeviceAddress<T>(inputs, 0);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+    FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (is_strided_slice_) {
+      CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx,
+                          reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kernel_name == "StridedSliceGrad") {
+      is_strided_slice_ = true;
+      input_shape_ = GetAttr<std::vector<int>>(kernel_node, "shapex");
+      for (auto i = input_shape_.size(); i < 4; i++) {
+        (void)input_shape_.insert(input_shape_.begin(), 1);
+      }
+      strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
+      for (auto i = strides_.size(); i < 4; i++) {
+        (void)strides_.insert(strides_.begin(), 1);
+      }
+      size_ = GetAttr<std::vector<int>>(kernel_node, "end");
+    } else {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+      ShapeNdTo4d(input_shape, &input_shape_);
+      size_ = GetAttr<std::vector<int>>(kernel_node, "size");
+    }
+
+    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    ShapeNdTo4d(dy_shape, &dy_shape_);
+    begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
+    DealParam();
+    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
+
+    output_size_ = sizeof(T);
+    for (auto x : dy_shape_) {
+      output_size_ = output_size_ * IntToSize(x);
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(output_size_);
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() > 4) {
+      MS_LOG(ERROR) << "Input dims is " << input_shape.size()
+                    << ", but SliceGradGpuKernel only supports 4-D or lower.";
+      return false;
+    }
+    if (input_shape.size() == 0) {
+      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
+      return false;
+    }
+    return true;
+  }
+  void DealParam() {
+    for (auto i = begin_.size(); i < 4; i++) {
+      (void)begin_.insert(begin_.begin(), 0);
+    }
+    for (auto i = size_.size(); i < 4; i++) {
+      (void)size_.insert(size_.begin(), 1);
+    }
+    for (size_t i = 0; i < begin_.size(); i++) {
+      if (begin_[i] < 0) {
+        begin_[i] = begin_[i] + input_shape_[i];
+      }
+    }
+    for (size_t i = 0; i < size_.size(); i++) {
+      if (size_[i] < 0) {
+        size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
+      }
+    }
+  }
+  std::vector<int> begin_;
+  std::vector<int> size_;
+  std::vector<int> strides_;
+  std::vector<int> input_shape_;
+  std::vector<int> dy_shape_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  bool is_strided_slice_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
new file mode 100644
index 0000000000..77e7de6fef
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h"
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      TransposeGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      TransposeGpuFwdKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
new file mode 100644
index 0000000000..0f9c710e3e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
new file mode 100644
index 0000000000..0f9c710e3e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.h
@@ -0,0 +1,111 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class TransposeGpuFwdKernel : public GpuKernel {
+ public:
+  TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {}
+  ~TransposeGpuFwdKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input = GetDeviceAddress<T>(inputs, 0);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    int *input_shape = GetDeviceAddress<int>(workspace, 0);
+    int *input_axis = GetDeviceAddress<int>(workspace, 1);
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync input_shape failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync input_axis failed");
+    int size = SizeToInt(input_size_ / sizeof(T));
+    CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    shape_size_ = input_shape.size();
+    if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
+      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION
+                        << "-D inputs.";
+    }
+
+    input_size_ = 1;
+    for (size_t i = 0; i < shape_size_; i++) {
+      input_size_ *= input_shape[i];
+      input_shape_.push_back(input_shape[i]);
+    }
+    input_size_ *= sizeof(T);
+    output_size_ = input_size_;
+    auto perm = GetAttr<std::vector<int>>(kernel_node, "perm");
+    for (size_t j = 0; j < perm.size(); j++) {
+      input_axis_.push_back(perm[j]);
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    workspace_size_ = shape_size_ * sizeof(int);
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    return;
+  }
+
+ private:
+  std::vector<int> input_shape_;
+  std::vector<int> input_axis_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t shape_size_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_
diff --git
a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc new file mode 100644 index 0000000000..4be887ec79 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumGpuKernel, float, int) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + UnsortedSegmentSumGpuKernel, float, int64_t) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumGpuKernel, int, int) + +MS_REG_GPU_KERNEL_TWO( + UnsortedSegmentSum, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + UnsortedSegmentSumGpuKernel, int, int64_t) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h new file mode 100644 index 0000000000..1f7884c650 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_sum_gpu_kernel.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class UnsortedSegmentSumGpuKernel : public GpuKernel {
+ public:
+  UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
+  ~UnsortedSegmentSumGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemSet Failed");
+    UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+
+    auto axis = ids_shapes.size();
+    for (size_t i = 0; i < input_shapes.size(); i++) {
+      if (i < axis) {
+        input_dim0_ *= input_shapes[i];
+      } else {
+        input_dim1_ *= input_shapes[i];
+      }
+    }
+
+    output_dim0_ = output_shapes[0];
+    for (size_t j = 1; j < output_shapes.size(); j++) {
+      output_dim1_ *= output_shapes[j];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));
+    input_size_list_.push_back(input_dim0_ * sizeof(S));
+    output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(T));
+  }
+
+ private:
+  size_t input_dim0_;
+  size_t input_dim1_;
+  size_t output_dim0_;
+  size_t output_dim1_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
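The device code is in cuda_impl/unsorted_segment_sum.cu, which this excerpt does not include. Semantically the op is a scatter-add: flattened input row i accumulates into output row ids[i], and the cudaMemsetAsync zero-fill above is what makes empty segments come out as zeros. A simplified kernel with the same contract could look like this (sketch only; output_dim1 equals input_dim1 in this layout):

// Illustrative scatter-add: many input rows may target the same segment, hence atomicAdd.
// The registered dtypes above (float/int) both have native atomicAdd support.
template <typename T, typename S>
__global__ void NaiveUnsortedSegmentSum(const size_t input_dim0, const size_t input_dim1, const size_t output_dim0,
                                        const T *input, const S *ids, T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_dim0 * input_dim1;
       pos += blockDim.x * gridDim.x) {
    size_t row = pos / input_dim1;
    size_t col = pos % input_dim1;
    S segment = ids[row];
    if (segment >= 0 && static_cast<size_t>(segment) < output_dim0) {  // out-of-range ids are dropped
      atomicAdd(&output[segment * input_dim1 + col], input[pos]);
    }
  }
}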
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc
new file mode 100644
index 0000000000..a89d4e9baf
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/control/recv_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h
new file mode 100644
index 0000000000..7de32ade4f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/recv_gpu_kernel.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class RecvGpuKernel : public GpuKernel {
+ public:
+  RecvGpuKernel() {}
+  ~RecvGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *) override {
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed.");
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
+    wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+    return;
+  }
+
+ private:
+  cudaStream_t wait_stream_{nullptr};
+  cudaEvent_t wait_event_{nullptr};
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
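RecvGpuKernel moves no data: Launch enqueues a single cudaStreamWaitEvent, which returns immediately on the host and only orders the device work submitted to wait_stream_ afterwards behind the recorded event. In plain CUDA terms (illustrative kernel name, not from the patch):

cudaStreamWaitEvent(wait_stream, wait_event, 0);          // older CUDA releases require flags == 0
next_kernel<<<grid, block, 0, wait_stream>>>(/* ... */);  // will not start before wait_event fires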
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc
new file mode 100644
index 0000000000..946038bb18
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.cc
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/control/send_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h
new file mode 100644
index 0000000000..beea19a435
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/control/send_gpu_kernel.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class SendGpuKernel : public GpuKernel {
+ public:
+  SendGpuKernel() {}
+  ~SendGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *) override {
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed.");
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
+    record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+    return;
+  }
+
+ private:
+  cudaStream_t record_stream_{nullptr};
+  cudaEvent_t record_event_{nullptr};
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
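Together, Send and Recv express a cross-stream happens-before edge: Send records record_event_ on the producer stream, Recv makes the consumer stream wait on it. A minimal standalone CUDA sketch of the same protocol (all names illustrative, error checking omitted):

cudaEvent_t ev;
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);  // sync-only event, the cheapest kind
producer<<<grid, block, 0, stream_a>>>();               // work the consumer depends on
cudaEventRecord(ev, stream_a);                          // what SendGpuKernel::Launch enqueues
cudaStreamWaitEvent(stream_b, ev, 0);                   // what RecvGpuKernel::Launch enqueues
consumer<<<grid, block, 0, stream_b>>>();               // ordered after producer completes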
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu
new file mode 100644
index 0000000000..615b94723d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cu
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
+
+template <typename T>
+__device__ __forceinline__ T SqrtFunc(T input) {
+  return sqrt(input);
+}
+
+template <>
+__device__ __forceinline__ half SqrtFunc(half input) {
+  return hsqrt(input);
+}
+
+template <typename T>
+__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
+                                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable,
+                                T *m, T *v) {
+  const T one = static_cast<T>(1.0);
+  const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]);
+
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    m[i] += (gradient[i] - m[i]) * (one - beta1[0]);
+    v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]);
+    variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]);
+  }
+}
+
+template <typename T>
+void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
+               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
+               cudaStream_t cuda_stream) {
+  ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+    size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
+}
+
+template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
+                               const float *beta2_power, const float *learning_rate, const float *beta1,
+                               const float *beta2, const float *epsilon, float *variable, float *m, float *v,
+                               cudaStream_t cuda_stream);
+template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power,
+                              const half *beta2_power, const half *learning_rate, const half *beta1,
+                              const half *beta2, const half *epsilon, half *variable, half *m, half *v,
+                              cudaStream_t cuda_stream);
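ApplyAdamKernel folds both bias corrections into a single step size instead of materializing m-hat and v-hat, and applies epsilon outside the square root. A one-element host reference of the identical update, useful for checking against the paper's formulation (sketch, not part of the patch):

#include <cmath>

// One scalar Adam step matching ApplyAdamKernel: lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t),
// with m and v stored uncorrected and epsilon added to sqrt(v), not to v-hat.
void AdamStepRef(float g, float lr, float beta1, float beta2, float beta1_pow, float beta2_pow, float eps, float *m,
                 float *v, float *w) {
  const float lr_t = lr * std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow);
  *m += (g - *m) * (1.0f - beta1);      // == beta1 * m + (1 - beta1) * g
  *v += (g * g - *v) * (1.0f - beta2);  // == beta2 * v + (1 - beta2) * g^2
  *w -= lr_t * (*m) / (std::sqrt(*v) + eps);
}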
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh
new file mode 100644
index 0000000000..7fc4a3e949
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
+               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
+               cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu
new file mode 100644
index 0000000000..3bad9a61e1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cu
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "adam_weight_decay_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
+                                      const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
+                                      const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
+                                      T *param, T *gradient) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
+    float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
+    float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
+    float update = next_m / (sqrt(next_v) + epsilon[0]);
+    if (need_decay && weight_decay != nullptr) {
+      update += weight_decay[0] * param[i];
+    }
+    param[i] -= lr[0] * update;
+    m[i] = next_m;
+    v[i] = next_v;
+  }
+}
+
+template <typename T>
+void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
+                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
+                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
+  AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
+    element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
+    gradient);
+}
+
+template void AdamWeightDecay<float>(const int &element_num_, const bool &need_decay, const float *beta1,
+                                     const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
+                                     const float *epsilon, const float *lr, const float *weight_decay, float *m,
+                                     float *v, float *param, float *gradient, cudaStream_t stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
new file mode 100755
index 0000000000..a4f1f6680b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cu
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "argmax_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+template <typename T>
+__global__ void Argmax1D(const T* input, const int channel_size, int* output) {
+  int max_index = 0;
+  T max = input[0];
+  for (int pos = 1; pos < channel_size; pos++) {
+    if (max < input[pos]) {
+      max = input[pos];
+      max_index = pos;
+    }
+  }
+  output[0] = max_index;
+  return;
+}
+template <typename T>
+__global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int channel_size, int* output) {
+  int pos;
+  int max_index;
+  T max;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) {
+    max = input[i * channel_size];
+    max_index = 0;
+    for (int j = 1; j < channel_size; j++) {
+      pos = i * channel_size + j;
+      if (max < input[pos]) {
+        max = input[pos];
+        max_index = j;
+      }
+    }
+
+    output[i] = max_index;
+  }
+  return;
+}
+template <typename T>
+__global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int channel_size, int* output) {
+  int pos;
+  int max_index;
+  T max;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
+    max = input[i];
+    max_index = 0;
+    for (int j = 1; j < batch_size; j++) {
+      pos = j * channel_size + i;
+      if (max < input[pos]) {
+        max = input[pos];
+        max_index = j;
+      }
+    }
+    output[i] = max_index;
+  }
+  return;
+}
+template <typename T>
+void CalArgmax(const T* input, const int batch_size, const int channel_size, const int axis, int* output,
+               cudaStream_t cuda_stream) {
+  if (batch_size == 0) {
+    Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output);
+  } else if (axis == 1) {
+    ArgmaxDefault2D<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
+  } else {
+    ArgmaxAxis2D<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
+  }
+  return;
+}
+
+template void CalArgmax<float>(const float* input, const int batch_size, const int channel_size, const int axis,
+                               int* output, cudaStream_t cuda_stream);
+template void CalArgmax<half>(const half* input, const int batch_size, const int channel_size, const int axis,
+                              int* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmax_impl.cuh
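Reading the dispatch in CalArgmax above: batch_size == 0 selects the single-thread 1-D path, axis == 1 reduces each row (one thread per row), and any other axis value reduces each column. A host-side reference for the row-wise case, for testing only (sketch, not part of the patch):

#include <vector>

// Row-wise argmax over a row-major rows x cols buffer, matching ArgmaxDefault2D.
std::vector<int> ArgmaxRowsRef(const std::vector<float> &x, int rows, int cols) {
  std::vector<int> out(rows, 0);
  for (int i = 0; i < rows; ++i) {
    for (int j = 1; j < cols; ++j) {
      if (x[i * cols + out[i]] < x[i * cols + j]) out[i] = j;  // strict '<' keeps the first maximum
    }
  }
  return out;
}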
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
new file mode 100644
index 0000000000..46a8a75af9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cu
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "argmaxwithvalue_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+template <typename T, typename S>
+__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) {
+    int inputOutterOffset = pos * innerSize * bound;
+    int outputOutterOffset = pos * innerSize;
+    for (int j = 0; j < innerSize; j++) {
+      auto outputInnerOffset = outputOutterOffset + j;
+      S idx = 0;
+      T maxData = input[j + inputOutterOffset];
+      for (S c = 0; c < bound; c++) {
+        int offset = j + c * innerSize;
+        auto inputData = input[inputOutterOffset + offset];
+        idx = inputData > maxData ? c : idx;
+        maxData = inputData > maxData ? inputData : maxData;
+      }
+      output[outputInnerOffset] = maxData;
+      index[outputInnerOffset] = idx;
+    }
+  }
+  return;
+}
+
+template <typename T, typename S>
+void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_, S* index,
+                        T* output, cudaStream_t cuda_stream) {
+  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
+                                                                           index, output);
+  return;
+}
+
+template void CalArgmaxWithValue<float, int>(const float* input, const int bound_, const int outerSize_,
+                                             const int innerSize_, int* index, float* output,
+                                             cudaStream_t cuda_stream);
+template void CalArgmaxWithValue<half, int>(const half* input, const int bound_, const int outerSize_,
+                                            const int innerSize_, int* index, half* output,
+                                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/argmaxwithvalue_impl.cuh
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
new file mode 100644
index 0000000000..604391ccf3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cu
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "assign_add_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+template <typename T>
+__global__ void AssignAdd(const size_t size, T* ref, const T* value, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    output[pos] = ref[pos] + value[pos];
+    ref[pos] = output[pos];
+  }
+  return;
+}
+
+template <typename T>
+void CalAssignAdd(const size_t size, T* ref, const T* value, T* output, cudaStream_t cuda_stream) {
+  AssignAdd<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, ref, value, output);
+
+  return;
+}
+
+template void CalAssignAdd<float>(const size_t size, float* ref, const float* value, float* output,
+                                  cudaStream_t cuda_stream);
+template void CalAssignAdd<half>(const size_t size, half* ref, const half* value, half* output,
+                                 cudaStream_t cuda_stream);
+template void CalAssignAdd<int>(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cu
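CalAssignAdd above is a fused in-place accumulate: each element is visited exactly once, so no atomics are needed, and both the ref buffer and the kernel output end up holding the updated value. The element-wise host equivalent (sketch, not part of the patch):

// Host reference for AssignAdd: ref is updated in place; output mirrors the new value.
template <typename T>
void AssignAddRef(size_t n, T *ref, const T *value, T *output) {
  for (size_t i = 0; i < n; ++i) {
    ref[i] = ref[i] + value[i];
    output[i] = ref[i];
  }
}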
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh
new file mode 100644
index 0000000000..3a895405b1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
+                           const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
+                           size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
+template <typename T>
+void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
+                                    const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
+                                    T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
+template <typename T>
+void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
+                                 const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
+                                 T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
+template <typename T>
+void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
+                              size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C,
+                                         size_t H, size_t W, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu
new file mode 100755
index 0000000000..dae9a7d629
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cu
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include +#include +#include "batchnorm_fold_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { + running_std[i] = sqrtf(running_std[i] + epsilon); + } + return; +} + +template +__global__ void UpdateBatchStd(int channel_size, T* batch_std) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) { + batch_std[i] = 1 / batch_std[i]; + } + return; +} + +template +__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std, + int batch_size, int channel_size, int height, int width, T* dx) { + int n = batch_size * channel_size * height * width; + int normal_size = batch_size * height * width; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + int channel_index = i / (height * width) % channel_size; + dx[i] = d_batch_mean[channel_index] / normal_size + + d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size; + } + return; +} + +template +void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) { + UpdateRunningStd<<>>(channel_size, epsilon, running_std); + return; +} + +template void CalUpdateRunningStd(int channel_size, double epsilon, float* running_std, + cudaStream_t cuda_stream); + +template +void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) { + UpdateBatchStd<<>>(channel_size, batch_std); + return; +} + +template void CalUpdateBatchStd(int channel_size, float* batch_std, cudaStream_t cuda_stream); + +template +void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, + const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx, + cudaStream_t cuda_stream) { + CalDx<<>>( + d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx); +} + +template void CalBatchNormFoldGrad(const float* d_batch_mean, const float* d_batch_std, const float* x, + const float* batch_mean, const float* batch_std, int batch_size, + int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream); + +template +void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) { + thrust::device_ptr dev_ptr(array); + thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill); +} + +template void ThrustFillWith(float* array, int size, float tofill, cudaStream_t cuda_stream); + diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu new file mode 100644 index 0000000000..262d4c438d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +struct MinimumGradFunc { + __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { + if (x1 < x2) { + atomicAdd(dx1, dy); + } else { + atomicAdd(dx2, dy); + } + } +}; + +template +struct MaximumGradFunc { + __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) { + if (x1 > x2) { + atomicAdd(dx1, dy); + } else { + atomicAdd(dx2, dy); + } + } +}; + +__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + +template +__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3, + const int &r0, const int &r1, const int &r2, const int &r3, + const int &d0, const int &d1, const int &d2, const int &d3, + const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { + int i = pos / (d1 * d2 * d3) % d0; + int j = pos / (d2 * d3) % d1; + int k = pos / d3 % d2; + int l = pos % d3; + + int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); + int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); + Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index); + } +} + +template +__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, + const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2) { + switch (op) { + case BROADCAST_GRAD_TYPE_MINIMUM: + return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, + dx1, dx2); + case BROADCAST_GRAD_TYPE_MAXIMUM: + return BroadcastGradOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy, + dx1, dx2); + } +} + +template +void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, + cudaStream_t stream) { + int size = d0 * d1 * d2 * d3; + BroadcastGradKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, + x1, x2, dy, dx1, dx2); +} + +template +__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos); + } +} + +template +__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2, + const T *dy, T *dx1, T *dx2) { + switch (op) { + 
case BROADCAST_GRAD_TYPE_MINIMUM: + return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); + case BROADCAST_GRAD_TYPE_MAXIMUM: + return NoBroadcastOperator>(nums, x1, x2, dy, dx1, dx2); + } +} + +template +void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2, cudaStream_t stream) { + NoBroadcastGradKernel<<>>(nums, op, x1, x2, dy, dx1, dx2); +} + +template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2, + const float *dy, float *dx1, float *dx2, cudaStream_t stream); +template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2, + const int *dy, int *dx1, int *dx2, cudaStream_t stream); +template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1, + float *dx2, cudaStream_t stream); +template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1, + int *dx2, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh new file mode 100644 index 0000000000..7742043592 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +enum BroadcastGradOpType { + BROADCAST_GRAD_TYPE_MAXIMUM = 0, + BROADCAST_GRAD_TYPE_MINIMUM = 1, + BROADCAST_GRAD_TYPE_INVALID = 0xffffffff, +}; + +template +void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2, + cudaStream_t stream); + +template +void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, + T *dx2, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu new file mode 100644 index 0000000000..a72daa4234 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +struct GreaterFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; } +}; + +template +struct LessFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; } +}; + +template +struct MinimumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } +}; + +template +struct MaximumFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? 
lhs : rhs; } +}; + +template +struct PowerFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); } +}; + +template <> +struct PowerFunc { + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { + return __float2half(pow(__half2float(lhs), __half2float(rhs))); + } +}; + +template +struct RealDivFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); } +}; + +template +struct MulFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); } +}; + +template +struct SubFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); } +}; + +template +struct AddFunc { + __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } +}; + +template <> +struct PowerFunc { + // invalid branch + __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; } +}; + +__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; } + +template +__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3, + const int &r0, const int &r1, const int &r2, const int &r3, + const int &d0, const int &d1, const int &d2, const int &d3, + const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) { + int i = pos / (d1 * d2 * d3) % d0; + int j = pos / (d2 * d3) % d1; + int k = pos / d3 % d2; + int l = pos % d3; + + int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3); + int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3); + output[pos] = Func()(input0[l_index], input1[r_index]); + } +} + +template +__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1, + const int r2, const int r3, const int d0, const int d1, const int d2, const int d3, + enum BroadcastOpType op, const T *input0, const T *input1, S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_LESS: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MINIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MAXIMUM: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_POWER: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_REALDIV: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_MUL: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_SUB: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + case BROADCAST_TYPE_ADD: + return BroadcastOperator>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, + output); + } +} + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int 
&r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream) { + int size = d0 * d1 * d2 * d3; + BroadcastKernel<<>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op, + input0, input1, output); +} + +template +__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) { + output[pos] = Func()(input0[pos], input1[pos]); + } +} + +template +__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1, + S *output) { + switch (op) { + case BROADCAST_TYPE_GREATER: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_LESS: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MINIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MAXIMUM: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_POWER: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_REALDIV: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_MUL: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_SUB: + return NoBroadcastOperator>(nums, input0, input1, output); + case BROADCAST_TYPE_ADD: + return NoBroadcastOperator>(nums, input0, input1, output); + } +} + +template +void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream) { + NoBroadcastKernel<<>>(nums, op, input0, input1, output); +} + +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const float *input0, const float *input1, float *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, bool *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const half *input0, const half *input1, half *output, + cudaStream_t stream); +template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, + const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, + enum BroadcastOpType op, const int *input0, const int *input1, int *output, + cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum 
BroadcastOpType op, const float *input0, const float *input1, + float *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + bool *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, + half *output, cudaStream_t stream); +template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, + int *output, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh new file mode 100644 index 0000000000..dfc4c75c93 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ + +#include "runtime/device/gpu/cuda_common.h" + +enum BroadcastOpType { + BROADCAST_TYPE_GREATER = 0, + BROADCAST_TYPE_LESS = 1, + BROADCAST_TYPE_MAXIMUM = 2, + BROADCAST_TYPE_MINIMUM = 3, + BROADCAST_TYPE_POWER = 4, + BROADCAST_TYPE_REALDIV = 5, + BROADCAST_TYPE_MUL = 6, + BROADCAST_TYPE_SUB = 7, + BROADCAST_TYPE_ADD = 8, + BROADCAST_TYPE_INVALID = 0xffffffff, +}; + +template +void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2, + const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op, + const T *input0, const T *input1, S *output, cudaStream_t stream); + +template +void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output, + cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu new file mode 100755 index 0000000000..147782591a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <cuda_runtime.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
+template <typename T>
+__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int n = pos / (w1 + w2);
+    int m = pos % (w1 + w2);
+    output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
+  }
+  return;
+}
+
+template <typename T>
+__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
+                       const T* input_1, const T* input_2, const T* input_3, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int n = pos / (w1 + w2 + w3);
+    int m = pos % (w1 + w2 + w3);
+    output[pos] = m < w1 ? input_1[n * w1 + m] :
+                  m < w1 + w2 ? input_2[n * w2 + m - w1] :
+                  input_3[n * w3 + m - w1 - w2];
+  }
+  return;
+}
+
+template <typename T>
+__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                       const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int n = pos / (w1 + w2 + w3 + w4);
+    int m = pos % (w1 + w2 + w3 + w4);
+    output[pos] = m < w1 ? input_1[n * w1 + m] :
+                  m < w1 + w2 ? input_2[n * w2 + m - w1] :
+                  m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2] :
+                  input_4[n * w4 + m - w1 - w2 - w3];
+  }
+  return;
+}
+
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
+                  cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
+  return;
+}
+
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
+                  const T* input_1, const T* input_2, const T* input_3, T* output,
+                  cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
+  return;
+}
+
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
+                  cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
+                                                            input_2, input_3, input_4, output);
+  return;
+}
+
+template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
+                           float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
+                           int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
+                           half* output, cudaStream_t cuda_stream);
+
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
+                           const float* input_1, const float* input_2, const float* input_3,
+                           float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
+                           const int* input_1, const int* input_2, const int* input_3,
+                           int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
+                           const half* input_1, const half* input_2, const half* input_3,
+                           half* output, cudaStream_t cuda_stream);
+
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const float* input_1, const float* input_2, const float* input_3, const float* input_4,
+                           float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const int* input_1, const int* input_2, const int* input_3, const int* input_4,
+                           int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const half* input_1, const half* input_2, const half* input_3, const half* input_4,
+                           half* output, cudaStream_t cuda_stream);
+
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
new file mode 100755
index 0000000000..7bd32c140f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
+                  cudaStream_t cuda_stream);
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
+                  const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
+                  cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu
new file mode 100755
index 0000000000..87aaf1351c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cu
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
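The two-input Concat kernel earlier in this hunk treats the output as rows of width w1 + w2: each flat position splits into a row n and a column m, and m selects which input supplies the element. A host-side reference of the same mapping (illustrative only, not part of the patch):

// CPU reference for concat of two tensors along the last axis.
void ConcatRef(size_t size, int w1, int w2, const float *in1, const float *in2, float *out) {
  for (size_t pos = 0; pos < size; ++pos) {
    int n = pos / (w1 + w2);  // output row
    int m = pos % (w1 + w2);  // column within the concatenated row
    out[pos] = (m < w1) ? in1[n * w1 + m] : in2[n * w2 + (m - w1)];
  }
}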
+ */
+
+#include <thrust/reduce.h>
+#include "correction_mul_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw,
+                              T* output) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) {
+    int n = i / chw;
+    output[i] = weight[i] * gamma[n] / running_std[n];
+  }
+  return;
+}
+
+template <typename T>
+__global__ void Mul(int N, const T* a, const T* b, T* c) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
+    c[i] = a[i] * b[i];
+  }
+  return;
+}
+
+template <typename T>
+__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
+    d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus<T>());
+    d_gamma[i] = d_gamma[i] / running_std[i];
+  }
+  return;
+}
+
+template <typename T>
+void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output,
+                      cudaStream_t cuda_stream) {
+  CorrectionMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(weight, gamma, running_std, N, C * H * W,
+                                                                            output);
+}
+
+template void CalCorrectionMul(const float* weight, const float* gamma, const float* running_std, int N, int C,
+                               int H, int W, float* output, cudaStream_t cuda_stream);
+
+template <typename T>
+void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma,
+                          T* tmp, cudaStream_t cuda_stream) {
+  Mul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N * C * H * W, d_out, weight, tmp);
+  Reduce<<<GET_BLOCKS(N), GET_THREADS, 0, cuda_stream>>>(N, C * H * W, tmp, running_std, d_gamma);
+}
+
+template void CalCorrectionMulGrad(const float* d_out, const float* weight, const float* running_std, int N,
+                                   int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh
new file mode 100644
index 0000000000..cb4ccc2c44
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
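CalCorrectionMulGrad above stages the gradient in two launches: Mul writes the elementwise product d_out * weight into tmp, then Reduce collapses each length-CHW slice of tmp and divides by the per-slice running_std. A scalar reference of the combined effect (sketch, assuming row-major [N, CHW] layout):

// d_gamma[n] = (sum over slice n of d_out * weight) / running_std[n]
void CorrectionMulGradRef(const float *d_out, const float *weight, const float *running_std, int N, int CHW,
                          float *d_gamma) {
  for (int n = 0; n < N; ++n) {
    float acc = 0.f;
    for (int i = 0; i < CHW; ++i) {
      acc += d_out[n * CHW + i] * weight[n * CHW + i];
    }
    d_gamma[n] = acc / running_std[n];
  }
}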
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
+                            cudaStream_t cuda_stream);
+
+template <typename T, typename S>
+void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
+                                T *grad, cudaStream_t cuda_stream);
+
+template <typename T, typename S>
+void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
+                  T *dlogits, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cu
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh
new file mode 100644
index 0000000000..3ba27eeeea
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob,
+                    cudaStream_t cuda_stream);
+template <typename T>
+void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu
new file mode 100755
index 0000000000..e6f424c661
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cu
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
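The cross-entropy header above only declares the launchers; the definitions live in the renamed cross_entropy_impl.cu. For orientation, a hedged sketch of how the sparse variant might be invoked (buffer management omitted; the sizes are illustrative assumptions, not values from this patch):

// logits: [batch_size, class_num] on device; labels: one class index per sample; loss: device scalar.
void LaunchSparseCrossEntropy(const float *logits, const int *labels, float *loss, cudaStream_t stream) {
  const size_t batch_size = 32;  // assumed
  const size_t class_num = 10;   // assumed
  CrossEntropyWithSparse(logits, labels, batch_size, class_num, loss, stream);
}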
+ */
+
+#include "equalcount_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+__global__ void EqualCount(const int size, const T* input1, const T* input2, T* output) {
+  T equal_count = 0;
+
+  for (int i = 0; i < size; i++) {
+    if (input1[i] == input2[i]) {
+      equal_count++;
+    }
+  }
+
+  output[0] = equal_count;
+  return;
+}
+template <typename T>
+void CalEqualCount(const int size, const T* input1, const T* input2, T* output, cudaStream_t cuda_stream) {
+  EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output);
+  return;
+}
+
+template void CalEqualCount(const int size, const int* input1, const int* input2, int* output,
+                            cudaStream_t cuda_stream);
+template void CalEqualCount(const int size, const float* input1, const float* input2, float* output,
+                            cudaStream_t cuda_stream);
+template void CalEqualCount(const int size, const half* input1, const half* input2, half* output,
+                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cu
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
new file mode 100644
index 0000000000..e17615db67
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
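EqualCount above deliberately launches a single thread (<<<1, 1, 0, cuda_stream>>>) and counts matches serially, which keeps the kernel trivially correct for the small metric tensors it serves. For larger inputs a parallel variant would need a reduction; one possible shape (an illustration, not what this patch does):

// Alternative sketch: grid-stride scan with an atomic counter; *count must be zeroed before launch.
__global__ void EqualCountAtomic(const int size, const int *in1, const int *in2, int *count) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    if (in1[i] == in2[i]) {
      atomicAdd(count, 1);
    }
  }
}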
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
+                        cudaStream_t cuda_stream);
+
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream);
+
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cu
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
new file mode 100644
index 0000000000..5f6675b2d7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream); + +void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, cudaStream_t cuda_stream); + +void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu new file mode 100644 index 0000000000..bc400eb704 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/cuda_runtime.h" +#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" + +template +__global__ void IsNan(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsNan(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsInf(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsInf(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsFinite(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) == 0 && !isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsFinite(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += 
blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void FloatStatus(const size_t size, const T* input, T* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0 || isnan(input[pos])) { + out[0] = 1; + } + } + return; +} +template <> +__global__ void FloatStatus(const size_t size, const half* input, half* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { + out[0] = 1; + } + } + return; +} + +template +void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { + FloatStatus<<>>(size, input, output); + return; +} +template +void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsNan<<>>(size, input, output); + return; +} +template +void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsInf<<>>(size, input, output); + return; +} +template +void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsFinite<<>>(size, input, output); + return; +} + +template void CalFloatStatus(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalFloatStatus(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh new file mode 100644 index 0000000000..fbe063e72a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
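FloatStatus above has every thread store out[0] = 0 before scanning, so a block that starts late can in principle overwrite a 1 already written by a block that has finished; the kernel relies on the flag being re-derived on every call. A more defensive variant (illustrative sketch, not the patch's code) zeroes the flag once on the stream and only ever writes 1 from the kernel:

__global__ void AnyNanOrInf(const size_t size, const float *in, float *flag) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    if (isinf(in[pos]) != 0 || isnan(in[pos])) {
      flag[0] = 1.0f;  // idempotent once the flag is pre-zeroed
    }
  }
}

void LaunchAnyNanOrInf(const size_t size, const float *in, float *flag, cudaStream_t stream) {
  cudaMemsetAsync(flag, 0, sizeof(float), stream);  // zero once, ordered before the kernel on this stream
  AnyNanOrInf<<<64, 256, 0, stream>>>(size, in, flag);  // grid sizing illustrative
}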
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream);
+template <typename T>
+void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream);
+template <typename T>
+void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream);
+template <typename T>
+void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu
new file mode 100644
index 0000000000..be4415d509
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cu
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh"
+
+template <typename T>
+__device__ __forceinline__ T PowFunc(T x, T y) {
+  return pow(x, y);
+}
+
+template <>
+__device__ __forceinline__ half PowFunc(half x, half y) {
+  return __float2half(pow(__half2float(x), __half2float(y)));
+}
+
+template <typename T>
+__device__ __forceinline__ bool CompareFunc(T x, T y) {
+  return abs(x) > y;
+}
+
+template <>
+__device__ __forceinline__ bool CompareFunc(half x, half y) {
+  return abs(__half2float(x)) > __half2float(y);
+}
+
+template <typename T>
+__device__ __forceinline__ T Sgn(T x) {
+  return static_cast<T>(x != 0 ? (x > 0 ? 1 : -1) : 0);
+}
+
+template <>
+__device__ __forceinline__ half Sgn(half x) {
+  return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0);
+}
+
+template <typename T>
+__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate,
+                                const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power,
+                                T *variable, T *accumulation, T *linear) {
+  const T two = static_cast<T>(2.0);
+  const T learning_rate_power_val = -learning_rate_power[0];
+
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i];
+    const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val);
+    const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val);
+    const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0];
+
+    linear[i] += gradient[i] - sigma * variable[i];
+    variable[i] = CompareFunc(linear[i], l1_regularization[0])
+                    ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) /
+                       (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0]))
+                    : static_cast<T>(0);
+    accumulation[i] = cur_accumulation;
+  }
+}
+
+template <typename T>
+void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization,
+               const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear,
+               cudaStream_t cuda_stream) {
+  ApplyFtrlKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, gradient, learning_rate, l1_regularization,
+                                                                     l2_regularization, learning_rate_power, variable,
+                                                                     accumulation, linear);
+}
+
+template void ApplyFtrl(const size_t size, const float *gradient, const float *learning_rate,
+                        const float *l1_regularization, const float *l2_regularization,
+                        const float *learning_rate_power, float *variable, float *accumulation, float *linear,
+                        cudaStream_t cuda_stream);
+template void ApplyFtrl(const size_t size, const half *gradient, const half *learning_rate,
+                        const half *l1_regularization, const half *l2_regularization,
+                        const half *learning_rate_power, half *variable, half *accumulation, half *linear,
+                        cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh
new file mode 100644
index 0000000000..b5f0f82afe
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization,
+               const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear,
+               cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu
new file mode 100755
index 0000000000..03b58b81a0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cu
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
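ApplyFtrlKernel earlier in this hunk is the standard FTRL-proximal update applied independently per element; with the conventional learning_rate_power = -0.5 the accumulator powers reduce to square roots. A scalar restatement (sketch; p stands for -learning_rate_power):

float FtrlStep(float g, float lr, float l1, float l2, float p, float *var, float *acc, float *linear) {
  float acc_new = *acc + g * g;                           // accumulate squared gradient
  float sigma = (powf(acc_new, p) - powf(*acc, p)) / lr;  // per-step effective rate change
  *linear += g - sigma * *var;
  float sgn = *linear > 0.f ? 1.f : -1.f;
  *var = fabsf(*linear) > l1 ? (l1 * sgn - *linear) / (powf(acc_new, p) / lr + 2.f * l2) : 0.f;
  *acc = acc_new;
  return *var;
}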
+ */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/gather.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void GatherKernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1) { + int num = output_dim0 * output_dim1 * output_dim2; + int i, j, k; + for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; + write_index += blockDim.x * gridDim.x) { + i = write_index / (output_dim1 * output_dim2) % output_dim0; + j = write_index / output_dim2 % output_dim1; + k = write_index % output_dim2; + + if ((indices[j] >= 0) && (indices[j] < input_dim1)) { + int read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; + output[write_index] = input[read_index]; + } else { + output[write_index] = 0; + } + } + + return; +} +template +void Gather(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, + size_t input_dim1, cudaStream_t stream) { + int size = output_dim0 * output_dim1 * output_dim2; + GatherKernel<<>>(input, indices, output, output_dim0, output_dim1, + output_dim2, input_dim1); + return; +} + +template void Gather(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); + +template void Gather(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gather.cuh diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu new file mode 100644 index 0000000000..a4dc6648cc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cu @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
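GatherKernel above views the output as [output_dim0, output_dim1, output_dim2] and replaces the middle coordinate through indices, zero-filling out-of-range rows. The same mapping as a CPU reference (sketch):

// output[i][j][k] = input[i][indices[j]][k], or 0 when indices[j] is out of range.
void GatherRef(const float *input, const int *indices, float *output, size_t d0, size_t d1, size_t d2,
               size_t input_dim1) {
  for (size_t i = 0; i < d0; ++i) {
    for (size_t j = 0; j < d1; ++j) {
      for (size_t k = 0; k < d2; ++k) {
        size_t w = (i * d1 + j) * d2 + k;
        bool ok = indices[j] >= 0 && indices[j] < static_cast<int>(input_dim1);
        output[w] = ok ? input[(i * input_dim1 + indices[j]) * d2 + k] : 0.f;
      }
    }
  }
}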
+ */ + +#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) { + // formula: + // gelu(x) = 0.5 * x * (1.0 + tanh(y)) + // tanh(y) = 2 / (1 + exp(-2y)) - 1) + // y = sqrt(2/pi) * (x + 0.044715 * x^3) + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + float x = input_addr[pos]; + float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); + output_addr[pos] = 0.5 * x * (1.0 + tanh_res); + } +} + +template <> +__global__ void GeluKernel(size_t size, half *input_addr, half *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half x = input_addr[pos]; + float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x))); + output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res)); + } +} + +template <> +__global__ void GeluKernel(size_t size, half2 *input_addr, half2 *output_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + half2 x = input_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res)); + } +} + +template +void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { + GeluKernel<<>>(size, input_addr, output_addr); + return; +} + +template <> +void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluKernel<<>>( + size / 2, reinterpret_cast(input_addr), reinterpret_cast(output_addr)); + } else { + GeluKernel<<>>(size, input_addr, output_addr); + } + return; +} + +template +__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) { + // formula: + // dx = dy * y' + // y' = 0.5 * (1 + tanh(tanh_para)) + + // 0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right + // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3) + // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)) + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + T x = x_addr[pos]; + T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); + T mul_right = 0.7978845608 + 0.1070322244 * x * x; + T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +__global__ void GeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, half2 *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half2 x = x_addr[pos]; + float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); + float2 tanh_res; + tanh_res.x = tanh(tanh_param.x); + tanh_res.y = tanh(tanh_param.y); + half2 tanh_res_half = __float22half2_rn(tanh_res); + half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x; + half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) + + half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +__global__ void 
GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + half x = x_addr[pos]; + half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x); + half tanh_res = __float2half_rn(tanh(__half2float(tanh_param))); + half mul_right = half(0.7978845608) + half(0.1070322244) * x * x; + half y_res = half(0.5) * (half(1.0) + tanh_res) + half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right; + dx_addr[pos] = dy_addr[pos] * y_res; + } +} + +template +void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { + GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); +} + +template <> +void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { + if (size % 2 == 0) { + GeluGradKernel<<>>( + size / 2, reinterpret_cast(dy_addr), reinterpret_cast(x_addr), + reinterpret_cast(dx_addr)); + } else { + GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); + } + return; +} + +template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); +template void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr, cudaStream_t cuda_stream); +template void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh new file mode 100644 index 0000000000..1e69f26d57 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream); + +template +void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu new file mode 100644 index 0000000000..fcb7418952 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cu @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh" + +constexpr int NUM_PER_THREAD_REDUCE = 4; +constexpr int WARP_SIZE = 32; + +template +inline __device__ T my_pow(T a, double b) { + return pow(a, static_cast(b)); +} + +template <> +inline __device__ half my_pow(half a, double b) { + return __float2half(pow(__half2float(a), static_cast(b))); +} + +template +inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, + const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, + T* dg, T* db) { + int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int row = NUM_PER_THREAD_REDUCE * i + j; + if (row >= row_dim) { + return; + } + + int pos = row * col_dim + col; + dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); + db[0] += dy[pos]; + } + } +} + +template +inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); + db[0] += __shfl_down_sync(0xffffffff, db[0], delta); + } +} + +template +inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, + T* db_addr) { + if (threadIdx.x >= row_dim) { + return; + } + + // load data to share memory + // thread(0, 32, 64, 96, ...) 
keep the data + DynamicSharedMem share_mem; + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 2; + share_mem.addr()[offset] = dg[0]; + share_mem.addr()[offset + 1] = db[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 2; + share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset]; + share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1]; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + dg_addr[col] = share_mem.addr()[0]; + db_addr[col] = share_mem.addr()[1]; + } +} + +template +__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, + const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { + // row: [0:param_axis] + // col: [param_axis:] + // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) + // dg[j] = \Sigma_{j}dg[i][j] + for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { + T dg = 0; + T db = 0; + GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); + GammaAndBetaWarpReduce(&dg, &db); + GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); + } +} + +template +inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, + T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, + const T* var, const T* gamma) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + T v1 = dy[pos] * gamma[gamma_offset]; + T v2 = x[pos] - mean[row]; + + sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5); + sum2[0] += v1; + sum3[0] += -2.0 * v2; + } + } +} + +template <> +inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, + half* sum1, half* sum2, half* sum3, const half* dy, const half* x, + const half* mean, const half* var, const half* gamma) { + int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; + for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { + for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { + int col = NUM_PER_THREAD_REDUCE * i + j; + if (col >= col_dim) { + return; + } + + int pos = row * col_dim + col; + int gamma_offset = pos % param_dim; + half v1 = dy[pos] * gamma[gamma_offset]; + half v2 = x[pos] - mean[row]; + + sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5); + sum2[0] += v1; + sum3[0] += __float2half(-2.0) * v2; + } + } +} + +template +inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { + for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { + sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); + sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); + sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); + } +} + +template +inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { + if (threadIdx.x >= col_dim) { + return; + } + + // load data to share memory + // thread(0, 32, 64, 96, ...) 
keep the data + if (threadIdx.x % WARP_SIZE == 0) { + int offset = threadIdx.x / WARP_SIZE * 3; + share_mem[offset] = sum1[0]; + share_mem[offset + 1] = sum2[0]; + share_mem[offset + 2] = sum3[0]; + } + __syncthreads(); + + for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) { + int offset = (threadIdx.x + stride) * 3; + share_mem[threadIdx.x * 3] += share_mem[offset]; + share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; + share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; + } + } + __syncthreads(); +} + +template +inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, + const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, + const T* share_mem) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + T v1 = dy[pos] * gamma[gamma_offset]; + T v2 = x[pos] - mean[row]; + T v3 = my_pow(var[row] + epsilon, -0.5); + dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + + (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); + } +} + +template <> +inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, + const half* dy, const half* x, const half* mean, const half* var, const half* gamma, + half* dx, const half* share_mem) { + for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { + int pos = (row * col_dim + col); + int gamma_offset = pos % param_dim; + half v1 = dy[pos] * gamma[gamma_offset]; + half v2 = x[pos] - mean[row]; + half v3 = my_pow(var[row] + epsilon, -0.5); + dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 + + (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\ + * __float2half(1.0 / col_dim); + } +} + +template +__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx) { + for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { + T sum1 = 0; + T sum2 = 0; + T sum3 = 0; + DynamicSharedMem share_mem; + InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); + InputWarpReduce(&sum1, &sum2, &sum3); + InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr()); + InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr()); + } +} + +template +void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { + int share_mem_size = + ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); + InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, + gamma, dx); + + share_mem_size = + ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); + GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); +} + +template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, + const float* dy, const float* x, const float* mean, const float* var, const float* gamma, + float* dx, float* dg, float* db, cudaStream_t stream); +template void 
LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, + const half* dy, const half* x, const half* mean, const half* var, const half* gamma, + half* dx, half* dg, half* db, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh new file mode 100644 index 0000000000..13d7a58614 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, + const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu new file mode 100644 index 0000000000..138300b303 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cu @@ -0,0 +1,163 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
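The share_mem_size arithmetic in LayerNormGrad above reserves one (sum1, sum2, sum3) triple per warp that survives the thread-level reduction: each thread folds NUM_PER_THREAD_REDUCE = 4 columns, and each warp of 32 threads leaves one triple. Worked through for col_dim = 1024 and T = float: ceil(1024 / 4) = 256 lanes, ceil(256 / 32) = 8 warps, 8 * 3 * 4 = 96 bytes. The same rule as a helper (sketch only):

constexpr int kNumPerThreadReduce = 4;
constexpr int kWarpSize = 32;
size_t InputPropShareMemBytes(int col_dim) {
  int lanes = (col_dim + kNumPerThreadReduce - 1) / kNumPerThreadReduce;  // threads that do work
  int warps = (lanes + kWarpSize - 1) / kWarpSize;                        // one triple kept per warp
  return static_cast<size_t>(warps) * 3 * sizeof(float);
}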
+ */
+
+#include <stdio.h>
+#include <stdint.h>
+#include <cuda_runtime.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh"
+
+constexpr int NUM_PER_THREAD_REDUCE = 4;
+constexpr int WARP_SIZE = 32;
+
+template <typename T>
+inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) {
+  // Welford's online algorithm:
+  // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k
+  // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k)
+  num[0]++;
+  T mean_new = mean[0] + (val - mean[0]) / num[0];
+  var[0] = var[0] + (val - mean[0]) * (val - mean_new);
+  mean[0] = mean_new;
+}
+
+template <typename T>
+inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) {
+  T zero = 0;
+  if (n2 == zero) {
+    return;
+  }
+
+  T count = n1[0] + n2;
+  v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count;
+  m1[0] = (n1[0] * m1[0] + n2 * m2) / count;
+  n1[0] = count;
+}
+
+template <typename T>
+inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int pos = NUM_PER_THREAD_REDUCE * i + j;
+      if (pos >= col_dim) {
+        return;
+      }
+      MeanAndVarAccumulation(mean, var, num, block_addr[pos]);
+    }
+  }
+}
+
+template <typename T>
+inline __device__ void WarpReduce(T *mean, T *var, T *num) {
+  for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
+    T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta);
+    T var_other = __shfl_down_sync(0xffffffff, var[0], delta);
+    T num_other = __shfl_down_sync(0xffffffff, num[0], delta);
+    MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other);
+  }
+}
+
+template <typename T>
+inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr,
+                                   T *share_mem) {
+  if (threadIdx.x >= col_dim) {
+    return;
+  }
+
+  // load data to shared memory; threads (0, 32, 64, 96, ...) keep the data
+  if (threadIdx.x % WARP_SIZE == 0) {
+    int offset = threadIdx.x / WARP_SIZE * 3;
+    share_mem[offset] = mean[0];
+    share_mem[offset + 1] = var[0];
+    share_mem[offset + 2] = num[0];
+  }
+  __syncthreads();
+
+  for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
+    if (threadIdx.x < stride) {
+      int offset = (threadIdx.x + stride) * 3;
+      MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2],
+                      share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]);
+    }
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    mean_addr[blockIdx.x] = share_mem[0];
+    share_mem[1] /= col_dim;
+    var_addr[blockIdx.x] = share_mem[1];
+  }
+}
+
+template <typename T>
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const T *x,
+                                 const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = row * col_dim + col;
+    int i = pos % param_dim;
+    y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
+  }
+}
+
+template <>
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const half *x,
+                                 const half *share_mem, const half *gamma, const half *beta, const half epsilon,
+                                 half *y) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = row * col_dim + col;
+    int i = pos % param_dim;
+    y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
+  }
+}
+
+template <typename T>
+__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
+                                const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
+  for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) {
+    T mean = 0;
+    T var = 0;
+    T num = 0;
+    const T *block_addr = x + row * col_dim;
+    DynamicSharedMem<T> share_mem;
+
+    ThreadReduce(col_dim, block_addr, &mean, &var, &num);
+    WarpReduce(&mean, &var, &num);
+    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr());
+
+    __syncthreads();
+    LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y);
+  }
+}
+
+template <typename T>
+void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const T &epsilon, const T *x,
+               const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) {
+  const dim3 block(row_dim);
+  const dim3 thread(256);
+  // keep the mean/var/num after warp reduce
+  int share_mem_size =
+    ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
+  LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y,
+                                                             mean, var);
+}
+
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
+                        const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
+                        cudaStream_t stream);
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon,
+                        const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var,
+                        cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh
new file mode 100644
index 0000000000..9548b30d44
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+struct DynamicSharedMem;
+template <>
+struct DynamicSharedMem<float> {
+  __device__ float *addr() {
+    extern __shared__ float addr_float[];
+    return addr_float;
+  }
+};
+template <>
+struct DynamicSharedMem<half> {
+  __device__ half *addr() {
+    extern __shared__ half addr_half[];
+    return addr_half;
+  }
+};
+
+template <typename T>
+void LayerNorm(const int &outer, const int &inner, const int &param_dim, const T &epsilon, const T *x, const T *gamma,
+               const T *beta, T *y, T *mean, T *var, cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu
new file mode 100644
index 0000000000..3915dba172
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cu
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <thrust/extrema.h>
+#include <thrust/device_ptr.h>
+#include <thrust/pair.h>
+#include <thrust/execution_policy.h>
+#include <thrust/device_vector.h>
+#include "minmax_update_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
+                                                 float *output_max, const float min, const float max,
+                                                 const float decay) {
+  output_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
+  output_min[0] = input_min[0] > 0 ? 0 : input_min[0];
+  output_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
+  output_max[0] = input_max[0] < 0 ? 0 : input_max[0];
+  return;
+}
+
+__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) {
+  output_min[0] = min > 0 ? 0 : min;
+  output_max[0] = max < 0 ? 0 : max;
+  return;
+}
+
+__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
+                                            float *output_max, int channels, int per_channel_nums, bool ema,
+                                            float ema_decay) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
+    thrust::pair<float *, float *> sum =
+      thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
+    if (ema) {
+      output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
+      output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
+    } else {
+      output_min[i] = sum.first[0];
+      output_max[i] = sum.second[0];
+    }
+    output_min[i] = input_min[i] > 0 ? 0 : input_min[i];
+    output_max[i] = input_max[i] < 0 ? 0 : input_max[i];
+  }
+  return;
+}
+
+void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
+                         cudaStream_t cuda_stream) {
+  int per_channel_num = total_num / channel_num;
+  UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
+    input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay);
+  return;
+}
+
+void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
+                       const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
+  float minel = 0.f;
+  float maxel = 0.f;
+  auto policy = thrust::cuda::par.on(cuda_stream);
+  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
+  tuple =
+    thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num);
+  minel = tuple.first[0];
+  maxel = tuple.second[0];
+
+  if (ema) {
+    UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
+                                                               maxel, ema_decay);
+  } else {
+    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel);
+  }
+  return;
+}
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh
new file mode 100644
index 0000000000..b4b4d582ee
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
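Review note on the per-layer EMA path above: UpdateInputMinMaxPerLayerWithEMA first blends the fresh extrema into the running statistics, but the very next statement unconditionally overwrites each blended value with the clamped running value, so neither the EMA result nor the freshly computed min/max ever reaches the output. This looks unintended and is worth confirming. A host-side sketch of the presumably intended update, with illustrative names that are not part of the patch:

#include <algorithm>

// Sketch only: blend the fresh extrema into the running statistics, then clamp
// so the quantization range always contains zero.
inline void EmaMinMax(float running_min, float running_max, float fresh_min, float fresh_max,
                      float decay, float *out_min, float *out_max) {
  *out_min = std::min(decay * fresh_min + (1 - decay) * running_min, 0.f);
  *out_max = std::max(decay * fresh_max + (1 - decay) * running_max, 0.f);
}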
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int total_num, const int channel_num, const float ema_decay, const bool ema, + cudaStream_t cuda_stream); + +void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, + const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh new file mode 100755 index 0000000000..62708663ad --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, + const S *momentum, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu new file mode 100644 index 0000000000..6dc4d676f2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cu @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
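The renamed momentum_impl.cu body is not shown in this diff, so only the MomentumUpdateVariable declaration above is visible. For reference, a scalar sketch of the conventional fused momentum update this signature suggests (an assumption, not the verified kernel body):

template <typename T, typename S>
void MomentumUpdateVariableRef(size_t size, T *variable, T *accumulation, const S *learning_rate,
                               const T *gradient, const S *momentum) {
  for (size_t i = 0; i < size; ++i) {
    accumulation[i] = momentum[0] * accumulation[i] + gradient[i];  // decay the running buffer
    variable[i] -= learning_rate[0] * accumulation[i];              // apply the step
  }
}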
+ */
+
+#include "one_hot_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T, typename S>
+__global__ void OneHotKernel(size_t size, const S *indices, size_t depth, const T *on_value, const T *off_value,
+                             size_t left_dim_size, size_t right_dim_size, T *output) {
+  T on_v = *on_value;
+  T off_v = *off_value;
+  for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size;
+       thread_idx += blockDim.x * gridDim.x) {
+    if (thread_idx < size) {
+      int left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size;
+      int d_idx = thread_idx / right_dim_size % depth;
+      int right_idx = thread_idx % right_dim_size;
+      int input_idx = left_idx * right_dim_size + right_idx;
+      int output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx;
+      if (indices[input_idx] == d_idx) {
+        output[output_idx] = on_v;
+      } else {
+        output[output_idx] = off_v;
+      }
+    }
+  }
+}
+template <typename T, typename S>
+void OneHot(const S *indices, size_t depth, const T *on_value, const T *off_value, size_t left_dim_size,
+            size_t right_dim_size, T *output, cudaStream_t cuda_stream) {
+  size_t size = left_dim_size * depth * right_dim_size;
+  OneHotKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, indices, depth, on_value, off_value,
+                                                                  left_dim_size, right_dim_size, output);
+  return;
+}
+template void OneHot(const int *indices, size_t depth, const float *on_value, const float *off_value,
+                     size_t left_dim_size, size_t right_dim_size, float *output, cudaStream_t cuda_stream);
+template void OneHot(const int *indices, size_t depth, const half *on_value, const half *off_value,
+                     size_t left_dim_size, size_t right_dim_size, half *output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
new file mode 100755
index 0000000000..3bb4d04a01
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
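Indexing note for OneHot above: the input is treated as [left_dim_size, right_dim_size] and the output as [left_dim_size, depth, right_dim_size]. A usage sketch with hypothetical shapes (device buffers and the stream are assumed to be set up elsewhere):

#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/one_hot_impl.cuh"

// output[l][d][r] = (indices[l][r] == d) ? *d_on : *d_off
void OneHotExample(const int *d_indices, const float *d_on, const float *d_off, float *d_output,
                   cudaStream_t stream) {
  OneHot(d_indices, /*depth=*/4, d_on, d_off, /*left_dim_size=*/2, /*right_dim_size=*/3, d_output, stream);
}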
+ */ + +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" + +template +__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, float pad_value, T* output) { + T pad_value_ = static_cast(pad_value); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int block_num = pos / padded_width / padded_height; + const int padded_w = pos % padded_width; + const int padded_h = pos / padded_width % padded_height; + if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || + padded_w - pad_left >= old_width) { + output[pos] = pad_value_; + } else { + output[pos] = input[(block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; + } + } + return; +} + +template +__global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int block_num = pos / old_width / old_height; + const int padded_w = pos % old_width + pad_left; + const int padded_h = pos / old_width % old_height + pad_top; + dx[pos] = dy[(block_num * padded_height + padded_h) * padded_width + padded_w]; + } + return; +} + +template +void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, + const float pad_value, T* output, cudaStream_t cuda_stream) { + Pad<<>>(size, input, num, channels, old_height, old_width, + padded_height, padded_width, pad_top, pad_left, pad_value, + output); + return; +} + +template +void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx, cudaStream_t cuda_stream) { + PadGrad<<>>(size, dy, num, channels, old_height, old_width, + padded_height, padded_width, pad_top, pad_left, dx); + return; +} + +template void CalPad(const size_t size, const float* input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, const int padded_width, + const int pad_top, const int pad_left, float pad_value, float* output, + cudaStream_t cuda_stream); +template void CalPadGrad(const size_t size, const float* dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int pad_top, const int pad_left, float* dx, + cudaStream_t cuda_stream); +template void CalPad(const size_t size, const half* input, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, const int padded_width, + const int pad_top, const int pad_left, float pad_value, half* output, + cudaStream_t cuda_stream); +template void CalPadGrad(const size_t size, const half* dy, const int num, const int channels, + const int old_height, const int old_width, const int padded_height, + const int padded_width, const int pad_top, const int pad_left, half* dx, + cudaStream_t cuda_stream); diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh new file mode 100755 index 0000000000..b10804fdab --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, + float pad_value, T* output, cudaStream_t cuda_stream); +template +void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, + const int old_width, const int padded_height, const int padded_width, const int pad_top, + const int pad_left, T* dx, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh new file mode 100644 index 0000000000..b099ead9bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
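Usage sketch for CalPad above, with hypothetical NCHW shapes (a 1x1x2x2 input padded to 4x4 with pad_top = pad_left = 1; device buffers assumed to be allocated elsewhere):

#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"

void PadExample(const float *d_in, float *d_out, cudaStream_t stream) {
  const size_t out_size = 1 * 1 * 4 * 4;  // num * channels * padded_height * padded_width
  CalPad(out_size, d_in, /*num=*/1, /*channels=*/1, /*old_height=*/2, /*old_width=*/2,
         /*padded_height=*/4, /*padded_width=*/4, /*pad_top=*/1, /*pad_left=*/1,
         /*pad_value=*/0.f, d_out, stream);
}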
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ + +#include +#include "runtime/device/gpu/cuda_common.h" + +template +void StandardNormal(int seed, int seed2, curandState *globalState, + T *output, size_t count, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu new file mode 100644 index 0000000000..80806b552f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cu @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, + T* mean_square, T*moment, T* gradients, const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i]; + moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, + T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { + RmsPropKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_square, moment, gradients, size); +} + +template +__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, + T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, + const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; + mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; + moment[i] = momentum[0] * moment[i] + learning_rate[0] * + rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, + cudaStream_t cuda_stream) { + RmsPropCenterKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_gradients, mean_square, + moment, gradients, size); +} + +template +void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon, + float* variable, float* mean_square, float* moment, float* gradients, const size_t size, 
+ cudaStream_t cuda_stream); + +template +void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, + float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, + const size_t size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh new file mode 100644 index 0000000000..16ad611381 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#include "runtime/device/gpu/cuda_common.h" + +template +void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square, + T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream); + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu new file mode 100644 index 0000000000..f7086f8093 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cu @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh" + +template +__global__ void Select(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + output[pos] = cond[pos] ? 
input_x[pos] : input_y[pos]; + } + return; +} + +template +void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, + cudaStream_t cuda_stream) { + Select<<>>(size, cond, input_x, input_y, output); + return; +} + +template void CalSelect(const size_t size, const bool* cond, const float* input_X, const float* input_y, + float* output, cudaStream_t cuda_stream); +template void CalSelect(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output, + cudaStream_t cuda_stream); +template void CalSelect(const size_t size, const bool* cond, const half* input_X, const half* input_y, + half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh new file mode 100644 index 0000000000..e201ab352c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/select_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, + cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu new file mode 100644 index 0000000000..f0c64bfb01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" + +template +__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, + T *outputs) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + if (logits[i] >= 0) { + outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; + } else { + const T exp_val = exp(logits[i]); + outputs[i] = exp_val / (1. 
+ exp_val) - labels[i]; + } + } +} + +template +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream) { + SigmoidCrossEntropyWithLogitsGradKernel<<>>(size, logits, labels, + outputs); +} + +template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, + const float *labels, float *outputs, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh new file mode 100644 index 0000000000..6b444d6c02 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu new file mode 100644 index 0000000000..7425ac3809 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
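Both branches of the gradient kernel above evaluate sigmoid(logits[i]) - labels[i]; the split only keeps exp() applied to a non-positive argument so it cannot overflow. A scalar reference of the identity (a sketch, not part of the patch):

#include <cmath>

inline float StableSigmoidGrad(float logit, float label) {
  const float s = (logit >= 0) ? 1.f / (1.f + std::exp(-logit))
                               : std::exp(logit) / (1.f + std::exp(logit));  // both equal sigmoid(logit)
  return s - label;  // d(loss)/d(logit) for sigmoid cross-entropy
}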
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh"
+
+template <typename T, typename S>
+__global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    const T reverse_factor = static_cast<T>(logits[i] >= 0);
+    outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor);
+  }
+}
+
+template <typename T, typename S>
+void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs,
+                                   cudaStream_t cuda_stream) {
+  SigmoidCrossEntropyWithLogitsKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels,
+                                                                                         outputs);
+}
+
+template void SigmoidCrossEntropyWithLogits(const size_t size, const float *logits, const float *labels,
+                                            float *outputs, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh
new file mode 100644
index 0000000000..7e9130857f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T, typename S>
+void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs,
+                                   cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
new file mode 100755
index 0000000000..dd4effc174
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cu
@@ -0,0 +1,191 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
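On the forward kernel above: since reverse_factor = [x >= 0], the exponent logits[i] - 2 * reverse_factor * logits[i] equals -|x|, so the expression is the usual numerically stable form of sigmoid cross-entropy. A scalar sketch of the identity it implements:

#include <algorithm>
#include <cmath>

// loss(x, z) = max(x, 0) - x * z + log(1 + exp(-|x|))
inline float StableSigmoidXent(float x, float z) {
  return std::max(x, 0.f) - x * z + std::log1p(std::exp(-std::fabs(x)));
}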
+ */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" + +template +__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { + int i = pos / (l2 * l3 * l4) % l1; + int j = pos / (l3 * l4) % l2; + int k = pos / l4 % l3; + int o = pos % l4; + + int offset = (i + s1) * (d2 * d3 * d4) + + (j + s2) * (d3 * d4) + + (k + s3) * d4 + + (o + s4); + output[pos] = input[offset]; + } +} +template +__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { + output[start + pos] = dy[p + pos]; + } + return; +} +template +__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); + pos += blockDim.x * gridDim.x) { + output[p + pos] = input[start + pos * stride]; + } + return; +} +template +__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); + pos += blockDim.x * gridDim.x) { + dx[start + pos * stride] = dy[p + pos]; + } + return; +} +template +__global__ void FillArray(T* addr, const size_t len, const float value) { + T value_ = static_cast(value); + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { + addr[pos] = value_; + } + return; +} +template +void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { + FillArray<<>>(addr, input_size, value); + return; +} +template +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream) { + Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, + d1, d2, d3, d4, input, output); +} +template +void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, + const std::vector size, T* output, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int length = size[3]; + int p = 0; + for (int i = begin[0]; i < size[0] + begin[0]; i++) { + for (int j = begin[1]; j < size[1] + begin[1]; j++) { + for (int k = begin[2]; k < size[2] + begin[2]; k++) { + SliceGrad<<>>( + dy, p, i * block + j * map + k * w + begin[3], length, output); + p = p + size[3]; + } + } + } +} +template +void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* output, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int ended = end[3]; + int p = 0; + int start = 0; + for (int i = begin[0]; i < ((end[0] > begin[0]) ? 
end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { + for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { + for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { + start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + + (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; + StridedSlice<<>>(input, p, start, begin[3], strides[3], + ended, output); + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); + } + } + } +} +template +void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* dx, cudaStream_t cuda_stream) { + int block = in_shape[1] * in_shape[2] * in_shape[3]; + int map = in_shape[2] * in_shape[3]; + int w = in_shape[3]; + int ended = end[3]; + int p = 0; + int start = 0; + for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { + for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); + j += std::abs(strides[1])) { + for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); + k += std::abs(strides[2])) { + start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + + (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; + StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], + ended, dx); + p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); + } + } + } +} + +template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const float *input, float *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, float* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const float* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, float* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, float* dx, cudaStream_t cuda_stream); +template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const half *input, half *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, half* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const half* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, half* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, 
const half* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, half* dx, cudaStream_t cuda_stream); +template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const int *input, int *output, cudaStream_t stream); +template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, + const std::vector begin, const std::vector size, int* output, + cudaStream_t cuda_stream); +template void CalStridedSlice(const size_t input_size, const int* input, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, int* output, cudaStream_t cuda_stream); +template void CalStridedSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, + const std::vector strides, int* dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh new file mode 100755 index 0000000000..e04f277c3d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
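Usage sketch for the Slice4DKernel entry point above, with hypothetical shapes: copy a 1x1x2x2 window starting at (0, 0, 1, 1) out of a 1x1x4x4 input (device buffers assumed):

#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh"

void SliceExample(const float *d_in, float *d_out, cudaStream_t stream) {
  Slice4DKernel(/*s1=*/0, /*s2=*/0, /*s3=*/1, /*s4=*/1,   // window start
                /*l1=*/1, /*l2=*/1, /*l3=*/2, /*l4=*/2,   // window lengths
                /*d1=*/1, /*d2=*/1, /*d3=*/4, /*d4=*/4,   // full input dims
                d_in, d_out, stream);
}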
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ + +#include +#include +#include "runtime/device/gpu/cuda_common.h" + + +template +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream); +template +void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); +template +void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* output, cudaStream_t cuda_stream); +template +void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, + const std::vector begin, const std::vector end, const std::vector strides, + T* dx, cudaStream_t cuda_stream); +template +void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu new file mode 100644 index 0000000000..9050044b7f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "smooth_l1_loss_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" + +template +__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target, + T *loss) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { + T value = (prediction[i] - target[i]) > 0 ? 
(prediction[i] - target[i]) : (target[i] - prediction[i]);
+    if (value < sigma) {
+      loss[i] = static_cast<T>(0.5) * value * value;
+    } else {
+      loss[i] = value - static_cast<T>(0.5);
+    }
+  }
+}
+
+template <typename T>
+void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+                  cudaStream_t stream) {
+  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
+}
+
+template <typename T>
+__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+                                       const T *dloss, T *dx) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
+    T value = prediction[i] - target[i];
+    if (value > static_cast<T>(sigma)) {
+      dx[i] = dloss[i];
+    } else if (value < static_cast<T>(-sigma)) {
+      dx[i] = -dloss[i];
+    } else {
+      dx[i] = value * dloss[i];
+    }
+  }
+}
+
+template <typename T>
+void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+                      T *dx, cudaStream_t stream) {
+  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
+                                                                             dloss, dx);
+}
+
+template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target,
+                           float *loss, cudaStream_t stream);
+template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target,
+                               const float *dloss, float *dx, cudaStream_t stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cu
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh
new file mode 100755
index 0000000000..fa32260381
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
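A scalar reference of the SmoothL1Loss forward formula as implemented above (a sketch, not part of the patch). Note the quadratic branch is not divided by sigma, so the two branches only meet continuously when sigma == 1; reviewers may want to confirm that gating at |d| < sigma without rescaling is intended:

#include <cmath>

inline float SmoothL1Ref(float prediction, float target, float sigma) {
  const float d = std::fabs(prediction - target);
  return (d < sigma) ? 0.5f * d * d : d - 0.5f;
}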
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" + +template +void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, + cudaStream_t cuda_stream); + +template +void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu new file mode 100755 index 0000000000..ffcb2c8052 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "transpose_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template +__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis, + const int shape_size, T* output) { + int pos_size; + int temp_pos; + int newpos; + int newpos_size; + int pos_array[TRANSPOSE_MAX_DIMENSION]; + + // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] + + // posArray[1] * input_shape[2] * input_shape[3] + + // posArray[2] * input_shape[3] + + // posArray[3] + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + temp_pos = pos; + pos_size = size / input_shape[0]; + pos_array[0] = temp_pos / pos_size; + for (int i = 1; i < shape_size; i++) { + temp_pos -= pos_array[i - 1] * pos_size; + pos_size = pos_size / input_shape[i]; + pos_array[i] = temp_pos / pos_size; + } + + newpos = pos_array[input_axis[shape_size - 1]]; + newpos_size = 1; + for (int j = shape_size - 2; j >= 0; j--) { + newpos_size *= input_shape[input_axis[j + 1]]; + newpos += pos_array[input_axis[j]] * newpos_size; + } + + output[newpos] = input[pos]; + } + return; +} +template +void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, + T* output, cudaStream_t cuda_stream) { + Transpose<<>>(size, input, input_shape, input_axis, shape_size, + output); + return; +} + +template void CalTranspose(const int size, const float* input, const int* input_shape, const int* input_axis, + const int shape_size, float* output, cudaStream_t cuda_stream); +template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis, + const int shape_size, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh similarity index 100% rename from 
mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cuh rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu similarity index 100% rename from mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu rename to mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh new file mode 100755 index 0000000000..cf8b30866e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu new file mode 100644 index 0000000000..3d299c2352 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cu @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
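Usage sketch for CalTranspose above, hypothetical 2-D case (a 3x4 -> 4x3 transpose). The kernel dereferences input_shape and input_axis on the device, so both arrays must live in device memory:

#include "transpose_impl.cuh"

void TransposeExample(const float *d_in, const int *d_shape /* {3, 4} */,
                      const int *d_axis /* {1, 0} */, float *d_out, cudaStream_t stream) {
  CalTranspose(/*size=*/3 * 4, d_in, d_shape, d_axis, /*shape_size=*/2, d_out, stream);
}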
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh"
+
+template <typename T, typename S>
+__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                                   T *input_addr, S *ids_addr, T *output_addr) {
+  for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1;
+       input_index += blockDim.x * gridDim.x) {
+    size_t j = input_index / input_dim1;
+    size_t k = input_index % input_dim1;
+
+    S i = ids_addr[j];
+    if (i < 0 || i >= output_dim0) {
+      continue;
+    }
+    size_t output_index = i * output_dim1 + k;
+    atomicAdd(output_addr + output_index, input_addr[input_index]);
+  }
+}
+
+template <typename T, typename S>
+void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                        T *input_addr, S *ids_addr, T *output_addr, cudaStream_t stream) {
+  int size = input_dim0 * input_dim1;
+  UnsortedSegmentSum<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input_dim0, input_dim1, output_dim0, output_dim1,
+                                                                   input_addr, ids_addr, output_addr);
+  return;
+}
+
+template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                                 float *input_addr, int *ids_addr, float *output_addr, cudaStream_t stream);
+template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                                 float *input_addr, int64_t *ids_addr, float *output_addr, cudaStream_t stream);
+template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                                 int *input_addr, int *ids_addr, int *output_addr, cudaStream_t stream);
+template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                                 int *input_addr, int64_t *ids_addr, int *output_addr, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh
new file mode 100644
index 0000000000..315677fde4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
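Two behaviors of UnsortedSegmentSum above are worth noting: rows whose segment id falls outside [0, output_dim0) are silently dropped, and accumulation uses atomicAdd, so the output buffer must be zeroed before launch. A usage sketch with hypothetical shapes (device buffers assumed):

#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh"

void SegmentSumExample(float *d_in /* [5, 8] */, int *d_ids /* [5] */, float *d_out /* [3, 8] */,
                       cudaStream_t stream) {
  cudaMemsetAsync(d_out, 0, 3 * 8 * sizeof(float), stream);  // required: the kernel only adds
  UnsortedSegmentSum(/*input_dim0=*/5, /*input_dim1=*/8, /*output_dim0=*/3, /*output_dim1=*/8,
                     d_in, d_ids, d_out, stream);
}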
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh
new file mode 100644
index 0000000000..315677fde4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unsorted_segment_sum.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
+                        T* input_addr, S* ids, T* output_addr, cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_
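One caveat worth noting: the kernel above only accumulates into output_addr via atomicAdd and never writes zeros itself, so correctness assumes the caller clears the output buffer first, along the lines of:

// Assumed caller-side zeroing (not part of this diff); d_output and stream are illustrative names.
cudaMemsetAsync(d_output, 0, output_dim0 * output_dim1 * sizeof(float), stream);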
len: " << len << ", status:" << status; + } + + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h new file mode 100644 index 0000000000..f8cc9b19ea --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_init_kernel.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DATASET_INIT_KERNEL_H +#define MINDSPORE_DATASET_INIT_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DatasetInitKernel : public GpuKernel { + public: + DatasetInitKernel(); + ~DatasetInitKernel() = default; + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + std::string queue_name_; + std::vector shapes_; + size_t total_bytes_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + // The capacity of buffer Q. + size_t buffer_q_capacity_{2}; +}; + +MS_REG_GPU_KERNEL(InitDataSetQueue, DatasetInitKernel) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc new file mode 100644 index 0000000000..67a487ce28 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
new file mode 100644
index 0000000000..67a487ce28
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.cc
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h"
+#include <cuda_runtime_api.h>
+#include <string>
+#include <vector>
+#include "runtime/device/gpu/gpu_buffer_mgr.h"
+#include "runtime/device/gpu/gpu_common.h"
+#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+
+namespace mindspore {
+namespace kernel {
+using mindspore::device::GpuBufferMgr;
+using mindspore::device::HandleMgr;
+
+DatasetIteratorKernel::DatasetIteratorKernel() : handle_(HandleMgr::INVALID_HANDLE), total_bytes_(0) {}
+
+DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); }
+
+const std::vector<size_t> &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &DatasetIteratorKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) {
+  queue_name_ = GetAttr<std::string>(kernel_node, "shared_name");
+  auto shapes = GetAttr<const std::vector<std::vector<int>>>(kernel_node, "shapes");
+  auto types = GetAttr<const std::vector<TypePtr>>(kernel_node, "types");
+  if (shapes.size() != types.size()) {
+    MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types;
+  }
+
+  for (size_t i = 0; i < shapes.size(); i++) {
+    int unit = UnitSizeInBytes(types[i]->type_id());
+    int nums = ElementNums(shapes[i]);
+    int bytes = unit * nums;
+    output_size_list_.push_back(bytes);
+    total_bytes_ += bytes;
+  }
+
+  handle_ = GpuBufferMgr::GetInstance().Open(0, queue_name_, output_size_list_);
+  if (handle_ == HandleMgr::INVALID_HANDLE) {
+    MS_LOG(EXCEPTION) << "Gpu Queue(" << queue_name_ << ") Open Failed";
+  }
+
+  return true;
+}
+
+void DatasetIteratorKernel::InitSizeLists() { return; }
+
+bool DatasetIteratorKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                   const std::vector<AddressPtr> &outputs, void *stream) {
+  void *addr = nullptr;
+  size_t len = 0;
+
+  int repeat = 0;
+  while (true) {
+    auto ret = GpuBufferMgr::GetInstance().Front(handle_, &addr, &len);
+    if (ret == device::SUCCESS) {
+      break;
+    }
+
+    if (ret == device::TIMEOUT) {
+      repeat++;
+      if (repeat < 10) {
+        MS_LOG(INFO) << "Waiting for data...(" << repeat << " / 10)";
+        continue;
+      } else {
+        MS_LOG(ERROR) << "Get data timeout";
+        return false;
+      }
+    }
+
+    MS_LOG(ERROR) << "Get data failed, errcode " << ret;
+    return false;
+  }
+
+  if (total_bytes_ != len) {
+    MS_LOG(ERROR) << "Dataset front error. read: " << len << ", expect: " << total_bytes_;
+    return false;
+  }
+
+  for (size_t i = 0; i < output_size_list_.size(); i++) {
+    void *output_addr = GetDeviceAddress<void>(outputs, i);
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream)),
+                               "Cuda Memcpy Failed");
+    addr = reinterpret_cast<unsigned char *>(addr) + output_size_list_[i];
+  }
+
+  CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)),
+                             "cudaStreamSynchronize failed");
+  (void)GpuBufferMgr::GetInstance().Pop(handle_);
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
read: " << len << ", expect: " << total_bytes_ << ", "; + return false; + } + + for (size_t i = 0; i < output_size_list_.size(); i++) { + void *output_addr = GetDeviceAddress(outputs, i); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice, + reinterpret_cast(stream)), + "Cuda Memcpy Failed"); + addr = reinterpret_cast(addr) + output_size_list_[i]; + } + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream)), + "cudaStreamSynchronize failed"); + (void)GpuBufferMgr::GetInstance().Pop(handle_); + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h new file mode 100644 index 0000000000..746aed3294 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_iterator_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_GET_NEXT_KERNEL_H +#define MINDSPORE_GET_NEXT_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class DatasetIteratorKernel : public GpuKernel { + public: + DatasetIteratorKernel(); + ~DatasetIteratorKernel(); + + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + bool Init(const CNodePtr &kernel_node) override; + + protected: + void InitSizeLists() override; + + private: + std::string queue_name_; + unsigned int handle_; + size_t total_bytes_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernel) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc new file mode 100644 index 0000000000..cb014a3d2b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/data/dataset_utils.h"
+
+namespace mindspore {
+namespace kernel {
+size_t UnitSizeInBytes(const mindspore::TypeId &t) {
+  size_t bytes = 0;
+  switch (t) {
+    case kNumberTypeBool:
+    case kNumberTypeInt8:
+    case kNumberTypeUInt8:
+      bytes = 1;
+      break;
+    case kNumberTypeInt16:
+    case kNumberTypeUInt16:
+    case kNumberTypeFloat16:
+      bytes = 2;
+      break;
+    case kNumberTypeInt:
+    case kNumberTypeUInt:
+    case kNumberTypeInt32:
+    case kNumberTypeUInt32:
+    case kNumberTypeFloat:
+    case kNumberTypeFloat32:
+      bytes = 4;
+      break;
+    case kNumberTypeUInt64:
+    case kNumberTypeInt64:
+    case kNumberTypeFloat64:
+      bytes = 8;
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Invalid types " << t;
+      break;
+  }
+
+  return bytes;
+}
+
+int ElementNums(const std::vector<int> &shape) {
+  if (shape.size() == 0) {
+    return 0;
+  }
+
+  int nums = 1;
+  for (size_t i = 0; i < shape.size(); i++) {
+    nums *= shape[i];
+  }
+
+  return nums;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/data/dataset_utils.h
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/data/dataset_utils.h
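A quick worked example of the two helpers, matching how the data kernels above size their queue slots: a float32 tensor of shape (32, 10) takes 4 * 320 = 1280 bytes. Note that ElementNums returns 0 for an empty shape, so a scalar slot would be sized zero by this helper.

std::vector<int> shape = {32, 10};
int bytes = UnitSizeInBytes(kNumberTypeFloat32) * ElementNums(shape);  // 4 * 320 = 1280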
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
new file mode 100644
index 0000000000..4c179f2173
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
+
+#include <cuda.h>
+#include <cudnn.h>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "runtime/device/gpu/gpu_device_manager.h"
+#include "runtime/device/gpu/gpu_common.h"
+#include "backend/session/anf_runtime_algorithm.h"
+using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
+
+namespace mindspore {
+namespace kernel {
+class GpuKernel : public KernelMod {
+ public:
+  virtual ~GpuKernel() = default;
+  virtual bool Init(const CNodePtr &kernel_node) = 0;
+
+ protected:
+  virtual void InitResource() {}
+  virtual void InitSizeLists() = 0;
+
+  template <typename T>
+  inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
+    if (index >= addr_list.size()) {
+      MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
+    }
+    // Kernels may run normally without workspace, so addr_list[index] may be nullptr.
+    if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) {
+      return nullptr;
+    }
+    MS_EXCEPTION_IF_NULL(addr_list[index]->addr);
+    return reinterpret_cast<T *>(addr_list[index]->addr);
+  }
+
+  template <typename T>
+  inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const {
+    const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node);
+    const ValuePtr &attr = prim->GetAttr(key);
+    if (attr == nullptr) {
+      const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node);
+      MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") does not exist";
+    }
+    return GetValue<T>(attr);
+  }
+  // expand Nd Shape to 4d (N in [0,4])
+  void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
+    if (src.size() > 4) {
+      MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!";
+    }
+    dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4]));
+    dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3]));
+    dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2]));
+    dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
+  }
+
+  inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
+                                      const std::vector<int> &Out) {
+    if (A != Out && B != Out) {
+      MS_EXCEPTION(ValueError)
+        << "Double-sided broadcast is not supported by cudnnOpTensor:\n"
+           "input A must match the corresponding dimension of the destination tensor out C, and each "
+           "dimension of input B "
+           "must match the corresponding dimension of out C or be equal to 1.";
+    }
+  }
+
+  // choose the suitable datatype for cudnn/cublas
+  inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
+    auto type = kCudnnDtypeMap.find(Type);
+    if (type == kCudnnDtypeMap.end()) {
+      MS_EXCEPTION(TypeError) << Type << " is not supported.";
+    }
+    return type->second;
+  }
+  inline cudaDataType_t GetCudaDataType(const std::string &Type) {
+    auto type = kCudaDtypeMap.find(Type);
+    if (type == kCudaDtypeMap.end()) {
+      MS_EXCEPTION(TypeError) << Type << " is not supported.";
+    }
+    return type->second;
+  }
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_
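GetDeviceAddress and GetAttr are the contract every concrete kernel below builds on. A minimal sketch of a derived kernel using them -- the op name, class, and registration here are invented for illustration, not part of this diff:

// Hypothetical pass-through kernel; only APIs visible elsewhere in this diff are used.
template <typename T>
class IdentityGpuKernel : public GpuKernel {
 public:
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Init(const CNodePtr &kernel_node) override {
    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    tensor_size_ = sizeof(T);
    for (auto d : shape) {
      tensor_size_ *= d;
    }
    InitSizeLists();
    return true;
  }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *in = GetDeviceAddress<T>(inputs, 0);   // nullptr for an empty slot; throws on a bad index
    T *out = GetDeviceAddress<T>(outputs, 0);
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(out, in, tensor_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync failed");
    return true;
  }
 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(tensor_size_);
    output_size_list_.push_back(tensor_size_);
  }
 private:
  size_t tensor_size_{0};
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};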
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
new file mode 100644
index 0000000000..3820089e35
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
@@ -0,0 +1,156 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+#include <string>
+#include <utility>
+
+#include "common/utils.h"
+#include "runtime/device/kernel_info.h"
+#include "runtime/device/gpu/cuda_common.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+GpuKernelFactory &GpuKernelFactory::GetInstance() {
+  static GpuKernelFactory instance;
+  return instance;
+}
+
+void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
+                                GpuKernelCreater &&creater) {
+  map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater);
+}
+
+void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
+                                    std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second,
+                                    size_t attr_index) {
+  if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
+    if (iter_second->at(attr_index).first.GetAllSame()) {
+      auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first;
+      for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) {
+        (void)iter_second->at(attr_index).first.AddInputAttr(dtype);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "op[" << kernel_name << "] input size mismatch!";
+    }
+  }
+  if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
+    if (iter_second->at(attr_index).first.GetAllSame()) {
+      auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first;
+      for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) {
+        (void)iter_second->at(attr_index).first.AddOutputAttr(dtype);
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "op[" << kernel_name << "] output size mismatch!";
+    }
+  }
+}
+
+std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) {
+  std::string type_lists = "";
+  auto iter = map_kernel_name_to_creater_.find(kernel_name);
+  if (map_kernel_name_to_creater_.end() == iter) {
+    return type_lists;
+  }
+  for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
+    std::string type_list = "in[";
+    auto attr = (iter->second)[attr_index].first;
+    for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) {
+      type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) +
+                  ((input_index == (attr.GetInputSize() - 1)) ? "" : " ");
+    }
+    type_list = type_list + "], out[";
+    for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) {
+      type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) +
+                  ((input_index == (attr.GetOutputSize() - 1)) ? "" : " ");
+    }
+    type_lists = type_lists + type_list + "]; ";
+  }
+  return type_lists;
+}
+
+std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name,
+                                                             const KernelBuildInfo *kernel_info) {
+  auto iter = map_kernel_name_to_creater_.find(kernel_name);
+  const int major_sm = GET_MAJOR_SM;
+  if (map_kernel_name_to_creater_.end() == iter) {
+    MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!";
+    return std::make_pair(false, 0);
+  }
+  if ((iter->second).size() == 1 && (iter->second)[0].first.GetInputSize() == 0) {
+    return std::make_pair(true, 0);
+  }
+
+  for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) {
+    CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index);
+    bool flag = true;
+    // data type matching check of all input parameters of kernel
+    for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) {
+      if (major_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) {
+        if (major_sm < MINIUM_SM) {
+          MS_LOG(EXCEPTION) << "Half precision ops can only be used on devices whose computing capacity is >= "
+                            << MINIUM_SM << ", but the current device's computing capacity is " << major_sm;
+        }
+        MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM
+                        << ", but the current device's computing capacity is " << major_sm;
+      }
+      if (kernel_info->GetInputDeviceType(input_index) !=
+          (iter->second)[attr_index].first.GetInputAttr(input_index).first) {
+        flag = false;
+        break;
+      }
+    }
+    if (!flag) {
+      continue;
+    }
+    // data type matching check of all output parameters of kernel
+    for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) {
+      if (kernel_info->GetOutputDeviceType(output_index) !=
+          (iter->second)[attr_index].first.GetOutputAttr(output_index).first) {
+        flag = false;
+        break;
+      }
+    }
+    // The data type check is done: return a pair whose first element says whether a match was found;
+    // if it is true, the second element is the index of the matching KernelAttr/creater pair in the vector.
+    if (flag) {
+      size_t match_index = attr_index;
+      return std::make_pair(true, match_index);
+    }
+  }
+  return std::make_pair(false, 0);
+}
+
+GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
+  auto kernel_info = apply_kernel->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  const KernelBuildInfo *kernel_build_info = kernel_info->select_kernel_build_info();
+  MS_EXCEPTION_IF_NULL(kernel_build_info);
+  std::pair<bool, size_t> ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info);
+  if (ret_pair.first) {
+    return (map_kernel_name_to_creater_.find(kernel_name)->second)[ret_pair.second].second();
+  }
+  return nullptr;
+}
+
+bool GpuKernelFactory::SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_build_info) {
+  std::pair<bool, size_t> ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info.get());
+  return ret_pair.first;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
new file mode 100644
index 0000000000..8834fa0f1a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
+
+#include <functional>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "runtime/device/gpu/kernel_info_setter.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+
+namespace mindspore {
+namespace kernel {
+using mindspore::device::gpu::KernelAttr;
+using GpuKernelCreater = std::function<GpuKernel *()>;
+class GpuKernelFactory {
+ public:
+  ~GpuKernelFactory() = default;
+
+  static GpuKernelFactory &GetInstance();
+
+  void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater);
+
+  GpuKernel *Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
+
+  bool SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_info);
+
+  std::string SupportedTypeList(const std::string &kernel_name);
+
+ private:
+  GpuKernelFactory() = default;
+
+  GpuKernelFactory(GpuKernelFactory const &);
+
+  GpuKernelFactory &operator=(const GpuKernelFactory &);
+
+  std::pair<bool, size_t> GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
+  void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info,
+                    std::vector<std::pair<KernelAttr, GpuKernelCreater>> *iter_second, size_t attr_index);
+  // Map from kernel name to its registrations; a KernelAttr and its creater must be registered as a pair.
+  std::map<std::string, std::vector<std::pair<KernelAttr, GpuKernelCreater>>> map_kernel_name_to_creater_;
+};
+
+class GpuKernelRegister {
+ public:
+  GpuKernelRegister(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater) {
+    GpuKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(creater));
+  }
+};
+
+#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS)                                                 \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
+  static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, KernelAttr(), []() { return new OPCLASS(); });
+
+// regular register of fixed accuracy kernels
+#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS)                                   \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS>::value, " must be base of GpuKernel"); \
+  static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); });
+
+// register of mixed accuracy kernels which use template and maintain one typename, ignore input num
+#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T)                                      \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
+  static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
+
+// register of mixed accuracy kernels which use template and maintain one typename
+#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T)                                       \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS<T>>::value, " must be base of GpuKernel"); \
+  static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS<T>(); });
+
+// register of mixed accuracy kernels which use template and maintain two typenames
+#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S)                                       \
+  static_assert(std::is_base_of<GpuKernel, OPCLASS<T, S>>::value, " must be base of GpuKernel"); \
+  static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR,          \
+                                                                         []() { return new OPCLASS<T, S>(); });
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_
diff --git a/mindspore/ccsrc/kernel/gpu/kernel_constants.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
similarity index 100%
rename from mindspore/ccsrc/kernel/gpu/kernel_constants.h
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/kernel_constants.h
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc
new file mode 100644
index 0000000000..86c7d8c108
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.cc
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/kernel_compiler/gpu/math/addn_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AddNGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AddNGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(AddN, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + AddNGpuFwdKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h new file mode 100644 index 0000000000..b69bd20216 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h @@ -0,0 +1,143 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h" +#include "backend/kernel_compiler/gpu/cuda_impl/slice_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class AddNGpuFwdKernel : public GpuKernel { + public: + AddNGpuFwdKernel() + : cudnn_handle_(nullptr), + input_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + input_size_(0), + output_size_(0), + workspace_size_(0), + is_null_input_(false), + num_input_(0) {} + ~AddNGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *output_addr = GetDeviceAddress(outputs, 0); + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast(stream_ptr)); + } + const float alpha = 1; + const float beta = 0; + for (size_t i = 0; i < IntToSize(num_input_); i++) { + T *input_addr = GetDeviceAddress(inputs, i); + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, + &(i > 0 ? 
alpha : beta), input_descriptor_, output_addr), + "cudnnAddTensor failed"); + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + num_input_ = GetAttr(kernel_node, "n"); + if (IntToSize(num_input_) != input_num) { + MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "AddNGpuFwdKernel input is null"; + InitSizeLists(); + return true; + } + for (size_t i = input_shape.size(); i < 4; i++) { + (void)input_shape.insert(input_shape.begin(), 1); + } + int dimA[4]; + for (size_t i = 0; i < input_shape.size(); i++) { + dimA[i] = SizeToInt(input_shape[i]); + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + SizeToInt(input_shape.size()), dimA), + "cudnnSetTensorNdDescriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + for (int i = 0; i < num_input_; i++) { + input_size_list_.push_back(input_size_); + } + output_size_list_.push_back(input_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnDataType_t cudnn_data_type_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + bool is_null_input_; + int num_input_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc new file mode 100644 index 0000000000..bffcca158b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + AssignAddGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE( + AssignAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AssignAddGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + AssignAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AssignAddGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h new file mode 100644 index 0000000000..04a74b3412 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/assign_add_gpu_kernel.h @@ -0,0 +1,95 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/assign_add_impl.cuh" +namespace mindspore { +namespace kernel { +template +class AssignAddGpuFwdKernel : public GpuKernel { + public: + AssignAddGpuFwdKernel() : is_null_input_(false), input_size_(0) {} + ~AssignAddGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *input_addr2 = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + CalAssignAdd(input_size_ / sizeof(T), input_addr, input_addr2, output_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; + return false; + } + auto input_shape = 
AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "AssignAddGpuFwdKernel input is null"; + InitSizeLists(); + return true; + } + input_size_ = sizeof(T); + for (size_t i : input_shape) { + input_size_ = i * input_size_; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(input_size_); + } + + private: + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc new file mode 100644 index 0000000000..a07fb6ddf6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + BiasAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BiasAddGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + BiasAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BiasAddGpuKernel, float16) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h new file mode 100644 index 0000000000..fd344be28a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h @@ -0,0 +1,149 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H +#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class BiasAddGpuKernel : public GpuKernel { + public: + BiasAddGpuKernel() + : cudnn_handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + x_desc_(nullptr), + b_desc_(nullptr), + op_desc_(nullptr), + is_null_input_(false) {} + ~BiasAddGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + + T *x_addr = GetDeviceAddress(inputs, 0); + T *b_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + + try { + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, + b_addr, &beta, x_desc_, output_addr), + "cudnnOpTensor failed"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor"; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto num_dims = x_shape.size(); + is_null_input_ = CHECK_NULL_INPUT(x_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "input is null"; + InitSizeLists(); + return true; + } + + if (num_dims < 2) { + MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; + } + + std::string format = GetAttr(kernel_node, "data_format"); + string::size_type pos = format.find("C"); + if (pos == std::string::npos || pos >= num_dims) { + MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; + } + + // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. + auto cudnn_dims = std::max(num_dims, 4UL); + std::unique_ptr x_dims = std::make_unique(cudnn_dims); + std::unique_ptr b_dims = std::make_unique(cudnn_dims); + for (size_t i = 0; i < cudnn_dims; i++) { + x_dims[i] = (i < num_dims) ? SizeToInt(x_shape[i]) : 1; + b_dims[i] = (i == pos) ? 
SizeToInt(x_shape[i]) : 1; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), + "cudnnSetTensorNdDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), + "cudnnSetOpTensorDescriptor failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&b_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); + } + void InitSizeLists() override { + size_t x_size, b_size; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &x_size), "cudnnGetTensorSizeInBytes failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(b_desc_, &b_size), "cudnnGetTensorSizeInBytes failed."); + input_size_list_.push_back(x_size); + input_size_list_.push_back(b_size); + output_size_list_.push_back(x_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyOpTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnDataType_t cudnn_data_type_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t b_desc_; + cudnnOpTensorDescriptor_t op_desc_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_BIAS_ADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc new file mode 100644 index 0000000000..41e7147328 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +// fp32 +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, float, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + RealDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO( + TensorAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGpuKernel, float, float) + +// fp16 +MS_REG_GPU_KERNEL_TWO( + Greater, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + BroadcastOpGpuKernel, half, bool) +MS_REG_GPU_KERNEL_TWO( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + RealDiv, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO( + TensorAdd, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BroadcastOpGpuKernel, half, half) + +// int32 +MS_REG_GPU_KERNEL_TWO( + TensorAdd, 
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +MS_REG_GPU_KERNEL_TWO( + Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + BroadcastOpGpuKernel, int, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h new file mode 100644 index 0000000000..aaf827723a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class BroadcastOpGpuKernel : public GpuKernel { + public: + BroadcastOpGpuKernel() + : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} + ~BroadcastOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *lhs = GetDeviceAddress(inputs, 0); + T *rhs = GetDeviceAddress(inputs, 1); + S *output = GetDeviceAddress(outputs, 0); + + if (need_broadcast_) { + Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], + rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, + rhs, output, reinterpret_cast(stream_ptr)); + } else { + NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast(stream_ptr)); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + GetOpType(kernel_node); + auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape3 = 
AnfAlgo::GetOutputInferShape(kernel_node, 0); + need_broadcast_ = IsBroadcast(shape1, shape2); + if (need_broadcast_ && shape1.size() > 4) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; + } + + for (size_t i = 0; i < shape3.size(); i++) { + output_shape_[i] = shape3[i]; + output_num_ *= shape3[i]; + } + int lhs_offset = shape3.size() - shape1.size(); + for (size_t j = 0; j < shape1.size(); j++) { + lhs_shape_[j + lhs_offset] = shape1[j]; + input1_num_ *= shape1[j]; + } + int rhs_offset = shape3.size() - shape2.size(); + for (size_t k = 0; k < shape2.size(); k++) { + rhs_shape_[k + rhs_offset] = shape2[k]; + input2_num_ *= shape2[k]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { return; } + void InitSizeLists() override { + input_size_list_.push_back(input1_num_ * sizeof(T)); + input_size_list_.push_back(input2_num_ * sizeof(T)); + output_size_list_.push_back(output_num_ * sizeof(S)); + } + + private: + void GetOpType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + + static std::map kBroadcastTypeMap = { + {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, + {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, + {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, + }; + + auto iter = kBroadcastTypeMap.find(kernel_name); + if (iter == kBroadcastTypeMap.end()) { + MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; + } else { + op_type_ = iter->second; + } + } + + bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] != rhs[i]) { + return true; + } + } + return false; + } + + BroadcastOpType op_type_; + bool need_broadcast_; + int input1_num_; + int input2_num_; + int output_num_; + int lhs_shape_[4] = {1, 1, 1, 1}; + int rhs_shape_[4] = {1, 1, 1, 1}; + int output_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc new file mode 100644 index 0000000000..49be2fd9a6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MaximumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + BroadcastOpGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MinimumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + BroadcastOpGradGpuKernel, int) +MS_REG_GPU_KERNEL_ONE(MaximumGrad, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + BroadcastOpGradGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h new file mode 100644 index 0000000000..6258c5c4e2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.h @@ -0,0 +1,147 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class BroadcastOpGradGpuKernel : public GpuKernel { + public: + BroadcastOpGradGpuKernel() + : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} + ~BroadcastOpGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *x1 = GetDeviceAddress(inputs, 0); + T *x2 = GetDeviceAddress(inputs, 1); + T *dy = GetDeviceAddress(inputs, 2); + T *dx1 = GetDeviceAddress(outputs, 0); + T *dx2 = GetDeviceAddress(outputs, 1); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast(stream_ptr)), + "cudaMemSet Failed"); + if (need_broadcast_) { + BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2], + x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1, + dx2, reinterpret_cast(stream_ptr)); + } else { + NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast(stream_ptr)); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + GetOpType(kernel_node); + auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + need_broadcast_ = IsBroadcast(shape1, shape2); + if (need_broadcast_ && shape1.size() > 4) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; + } + + for (size_t i = 0; i < shape3.size(); i++) { + dy_shape_[i] = shape3[i]; + output_num_ *= shape3[i]; + } + int x1_offset = shape3.size() - shape1.size(); + for (size_t i = 0; i < shape1.size(); i++) { + x1_shape_[i + x1_offset] = shape1[i]; + input1_num_ *= shape1[i]; + } + int x2_offset = shape3.size() - shape2.size(); + for (size_t i = 0; i < shape2.size(); i++) { + x2_shape_[i + x2_offset] = shape2[i]; + input2_num_ *= shape2[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { return; } + void InitSizeLists() override { + input_size_list_.push_back(input1_num_ * sizeof(T)); + input_size_list_.push_back(input2_num_ * sizeof(T)); + input_size_list_.push_back(output_num_ * sizeof(T)); + output_size_list_.push_back(input1_num_ * sizeof(T)); + output_size_list_.push_back(input2_num_ * sizeof(T)); + } + + private: + void GetOpType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + + static std::map kBroadcastTypeMap = { + {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM}, + {"MinimumGrad", 
BROADCAST_GRAD_TYPE_MINIMUM}, + }; + + auto iter = kBroadcastTypeMap.find(kernel_name); + if (iter == kBroadcastTypeMap.end()) { + MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; + } else { + op_type_ = iter->second; + } + } + + bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { + if (lhs.size() != rhs.size()) { + return true; + } + for (size_t i = 0; i < lhs.size(); i++) { + if (lhs[i] != rhs[i]) { + return true; + } + } + return false; + } + + BroadcastGradOpType op_type_; + bool need_broadcast_; + int input1_num_; + int input2_num_; + int output_num_; + int x1_shape_[4] = {1, 1, 1, 1}; + int x2_shape_[4] = {1, 1, 1, 1}; + int dy_shape_[4] = {1, 1, 1, 1}; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc new file mode 100644 index 0000000000..3103f30f52 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + EqualCountGpuKernel, int) +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + EqualCountGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + EqualCount, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + EqualCountGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h new file mode 100644 index 0000000000..eae7a893b7 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/equalcount_gpu_kernel.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
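EqualCount compares two same-shaped tensors and emits a single scalar count in the input type T, which is why Init() below sets output_size_ = sizeof(T). A CPU reference usable as a test oracle (editorial sketch, not part of the patch):

```cpp
// CPU oracle for EqualCount: number of positions where the inputs agree,
// returned in the same type T as the GPU kernel's scalar output.
#include <cstddef>

template <typename T>
T EqualCountRef(const T *a, const T *b, size_t n) {
  T count = static_cast<T>(0);
  for (size_t i = 0; i < n; ++i) {
    if (a[i] == b[i]) {
      count += static_cast<T>(1);
    }
  }
  return count;
}
```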
+ */ + +#ifndef MINDSPORE_EQUALCOUNT_GPU_KERNEL_H +#define MINDSPORE_EQUALCOUNT_GPU_KERNEL_H + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/equalcount_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class EqualCountGpuKernel : public GpuKernel { + public: + EqualCountGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~EqualCountGpuKernel() = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input1 = GetDeviceAddress(inputs, 0); + T *input2 = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(input_size_ / sizeof(T)); + CalEqualCount(size, input1, input2, output, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but equalcount needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but equalcount needs 1 output."; + return false; + } + + output_size_ = sizeof(T); + input_size_ = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + return; + } + + private: + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc new file mode 100644 index 0000000000..313669a647 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h new file mode 100644 index 0000000000..be74f2e9dc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/float_status_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/float_status_impl.cuh" + +namespace mindspore { +namespace kernel { +enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; +static const std::map kOpTypeMap = { + {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; +template +class FloatStatusGpuKernel : public GpuKernel { + public: + FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} + ~FloatStatusGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + + switch (kernel_name_) { + case OP_STATUS: { + T *output = GetDeviceAddress(outputs, 0); + CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_INF: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_NAN: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_FINITE: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kOpTypeMap.find(kernel_name); + if (iter == kOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; + } else { + kernel_name_ = iter->second; + } + if (kernel_name_ == OP_STATUS) { + output_size_ = sizeof(T); + } else { + output_size_ = input_size_ / sizeof(T) * sizeof(bool); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + Optype kernel_name_; + size_t input_size_; + size_t 
output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc new file mode 100644 index 0000000000..471c394598 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + MatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + MatMulGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + BatchMatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MatMulGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + BatchMatMul, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + MatMulGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h new file mode 100644 index 0000000000..7888d442c9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h @@ -0,0 +1,155 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H +#define MINDSPORE_MATMUL_GPU_KERNEL_H + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "utils/convert_utils.h" + +namespace mindspore { +namespace kernel { +template +class MatMulGpuKernel : public GpuKernel { + public: + MatMulGpuKernel() + : batch_(0), + m_(0), + n_(0), + k_(0), + is_null_input_(false), + transpose_x1_(CUBLAS_OP_N), + transpose_x2_(CUBLAS_OP_N), + handle_(nullptr), + dtype_a_(CUDA_R_32F), + dtype_b_(CUDA_R_32F), + dtype_c_(CUDA_R_32F), + algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} + ~MatMulGpuKernel() = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto input1_addr = GetDeviceAddress(inputs, 0); + auto input2_addr = GetDeviceAddress(inputs, 1); + auto output_addr = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_); + const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_); + const int ldc = n_; + + auto stride_a = SizeToInt(m_ * k_); + auto stride_b = SizeToInt(k_ * n_); + auto stride_c = SizeToInt(m_ * n_); + + try { + CHECK_CUBLAS_RET_WITH_EXCEPT( + cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), + &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, + &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), + "cublasSgemm Call Fail"); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas cublasGemmStridedBatchedEx"; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); + dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); + dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(output_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "input is null"; + InitSizeLists(); + return true; + } + auto dims = output_shape.size(); + if (dims < 2) { + MS_LOG(EXCEPTION) << "Output dims " << dims << " not support."; + } + + m_ = output_shape[dims - 2]; + n_ = output_shape[dims - 1]; + batch_ = 1; + for (size_t i = 0; i < dims - 2; i++) { + batch_ *= output_shape[i]; + } + + bool transpose = GetAttr(kernel_node, "transpose_x1"); + transpose_x1_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; + auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + k_ = transpose ? input1_shape[dims - 2] : input1_shape[dims - 1]; + + transpose = GetAttr(kernel_node, "transpose_x2"); + transpose_x2_ = transpose ? 
CUBLAS_OP_T : CUBLAS_OP_N; + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t unit_size = sizeof(T); + + size_t input_size = batch_ * m_ * k_ * unit_size; + input_size_list_.push_back(input_size); + + input_size = batch_ * n_ * k_ * unit_size; + input_size_list_.push_back(input_size); + + size_t output_size = batch_ * m_ * n_ * unit_size; + output_size_list_.push_back(output_size); + } + + private: + size_t batch_; + size_t m_; + size_t n_; + size_t k_; + bool is_null_input_; + + cublasOperation_t transpose_x1_; + cublasOperation_t transpose_x2_; + cublasHandle_t handle_; + cudaDataType_t dtype_a_; + cudaDataType_t dtype_b_; + cudaDataType_t dtype_c_; + cublasGemmAlgo_t algo_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc new file mode 100644 index 0000000000..c72c271c52 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc @@ -0,0 +1,24 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + RandomOpGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h new file mode 100644 index 0000000000..785ac02ee5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
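One non-obvious point in the MatMul kernel above: cuBLAS is column-major while these buffers are row-major, and Launch() exploits the identity C = A·B ⇔ Cᵀ = Bᵀ·Aᵀ instead of transposing anything. A commented walk-through, derived from that identity (editorial):

```cpp
// Why cublasGemmStridedBatchedEx is called with (transpose_x2_, transpose_x1_,
// n, m, k) and the operands in the order input2_addr, input1_addr:
//
//   A is m x k row-major  -> cuBLAS reads the same bytes as A^T (k x m)
//   B is k x n row-major  -> cuBLAS reads B^T (n x k)
//   computing B^T * A^T yields C^T (n x m) column-major,
//   whose bytes are exactly C (m x n) row-major -- no explicit transpose.
//
// The leading dimensions follow the row-major strides: lda is k (or m when
// transpose_x1_ is set), ldb is n (or k when transpose_x2_ is set), ldc is n.
```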
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh" + +namespace mindspore { +namespace kernel { +enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 }; + +const std::map kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}}; +template +class RandomOpGpuKernel : public GpuKernel { + public: + RandomOpGpuKernel() + : random_op_type_(RANDOM_OP_INVALID_TYPE), + input_size_0_(0), + output_size_(sizeof(T)), + workspace_size_(sizeof(curandState)) {} + ~RandomOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + void *workspace_addr = GetDeviceAddress(workspace, 0); + curandState *devStates = reinterpret_cast(workspace_addr); + T *output_addr = GetDeviceAddress(outputs, 0); + + switch (random_op_type_) { + case RANDOM_OP_NORMAL: { + StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), + reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kRandomOpTypeMap.find(kernel_name); + if (iter == kRandomOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Random operation " << kernel_name << " is not supported."; + } else { + random_op_type_ = iter->second; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; + return false; + } + auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < input_shape_0.size(); i++) { + input_size_0_ += input_shape_0[i]; + } + input_size_0_ *= sizeof(int); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + workspace_size_ *= output_shape[i]; + } + seed_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); + seed2_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_0_); + output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(workspace_size_); + } + + private: + RandomOptype random_op_type_; + size_t input_size_0_; + size_t output_size_; + size_t workspace_size_; + int seed_; + int seed2_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // 
MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc new file mode 100644 index 0000000000..ae8e7bbd0b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h new file mode 100644 index 0000000000..26993bc3bd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh" + +namespace mindspore { +namespace kernel { +enum UnaryOptype { + UNARY_OP_EXP = 0, + UNARY_OP_LOG, + UNARY_OP_NEG, + UNARY_OP_RECIPROCAL, + UNARY_OP_ZEROSLIKE, + UNARY_OP_SQUARE, + UNARY_OP_SQRT, + UNARY_OP_RSQRT, + UNARY_OP_INVALID_TYPE = 255 +}; +static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, + {"Log", UNARY_OP_LOG}, + {"Neg", UNARY_OP_NEG}, + {"Reciprocal", UNARY_OP_RECIPROCAL}, + {"ZerosLike", UNARY_OP_ZEROSLIKE}, + {"Square", UNARY_OP_SQUARE}, + {"Sqrt", UNARY_OP_SQRT}, + {"Rsqrt", UNARY_OP_RSQRT}}; +template +class UnaryOpGpuKernel : public GpuKernel { + public: + UnaryOpGpuKernel() + : unary_op_type_(UNARY_OP_INVALID_TYPE), + input_size_(sizeof(T)), + output_size_(sizeof(T)), + workspace_size_(0), + is_null_input_(false) {} + ~UnaryOpGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + switch (unary_op_type_) { + case UNARY_OP_EXP: { + Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_LOG: { + Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_NEG: { + Negative(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_RECIPROCAL: { + Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_SQUARE: { + Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_SQRT: { + Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_RSQRT: { + Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } + case UNARY_OP_ZEROSLIKE: { + Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); + return true; + } + default: { + MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kUnaryOpTypeMap.find(kernel_name); + if (iter == 
kUnaryOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "Unary operation " << kernel_name << " is not supported."; + } else { + unary_op_type_ = iter->second; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "UnaryOpGpuKernel input is null"; + InitSizeLists(); + return true; + } + for (size_t i = 0; i < input_shape.size(); i++) { + input_size_ *= input_shape[i]; + } + output_size_ = input_size_; + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + UnaryOptype unary_op_type_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc new file mode 100644 index 0000000000..c6e3c4c043 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
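UnaryOpGpuKernel dispatches through the name-to-enum map above, so exposing another elementwise op is mechanical. A sketch of the three touch points, using a hypothetical Abs op; the matching device launcher would also have to exist in unary_op_impl.cuh:

```cpp
// 1. Extend the enum and the name->enum map:
//      UNARY_OP_ABS in UnaryOptype; {"Abs", UNARY_OP_ABS} in kUnaryOpTypeMap.
// 2. Add a case in Launch() that calls the corresponding launcher, e.g.
//      Abs(input_addr, output_addr, inputs[0]->size / sizeof(T), stream);
//    (Abs here is hypothetical -- it must be provided by unary_op_impl.cuh).
// 3. Register the op in unary_op_gpu_kernel.cc:
//      MS_REG_GPU_KERNEL_ONE(Abs,
//                            KernelAttr().AddInputAttr(kNumberTypeFloat32)
//                                        .AddOutputAttr(kNumberTypeFloat32),
//                            UnaryOpGpuKernel, float)
```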
+ */ + +#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + NcclGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + NcclGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h new file mode 100644 index 0000000000..4c3c3189fb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h @@ -0,0 +1,181 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "runtime/device/gpu/distribution/collective_init.h" + +namespace mindspore { +namespace kernel { +enum NcclKernelType { NCCL_ALL_REDUCE = 0, NCCL_ALL_GATHER, NCCL_REDUCE_SCATTER, NCCL_INVALID_TYPE = 255 }; +const std::map kNcclTypeMap = { + {"AllReduce", NCCL_ALL_REDUCE}, + {"AllGather", NCCL_ALL_GATHER}, + {"ReduceScatter", NCCL_REDUCE_SCATTER}, +}; + +static std::map kNcclDtypeMap = { + {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; + +typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); +typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t); +typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); + +template +class NcclGpuKernel : public GpuKernel { + public: + NcclGpuKernel() + : nccl_kernel_type_(NCCL_INVALID_TYPE), + nccl_reduce_type_(ncclSum), + input_size_(0), + output_size_(0), + collective_handle_(nullptr), + comm_stream_(nullptr) {} + ~NcclGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + cudaStream_t stream = comm_stream_ ? 
comm_stream_ : reinterpret_cast(stream_ptr); + switch (nccl_kernel_type_) { + case NCCL_ALL_REDUCE: { + auto all_reduce_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllReduce")); + MS_EXCEPTION_IF_NULL(all_reduce_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, nccl_reduce_type_, stream), + "ncclAllReduce failed"); + break; + } + case NCCL_ALL_GATHER: { + auto all_gather_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "AllGather")); + MS_EXCEPTION_IF_NULL(all_gather_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT( + (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream), + "ncclAllGather failed"); + break; + } + case NCCL_REDUCE_SCATTER: { + auto reduce_scatter_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "ReduceScatter")); + MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); + CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), + nccl_data_type_, nccl_reduce_type_, stream), + "ncclReduceScatter failed"); + break; + } + default: { + MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; + } + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + nccl_data_type_ = kNcclDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; ++i) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= IntToSize(shape[j]); + } + input_size_list_.push_back(size); + input_size_ += size; + } + for (size_t i = 0; i < output_num; ++i) { + auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); + size_t size = sizeof(T); + for (size_t j = 0; j < shape.size(); j++) { + size *= IntToSize(shape[j]); + } + output_size_list_.push_back(size); + output_size_ += size; + } + InferCommType(kernel_node); + collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); + MS_EXCEPTION_IF_NULL(collective_handle_); + + auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); + if (comm_stream_attr) { + comm_stream_ = reinterpret_cast(GetValue(comm_stream_attr)); + MS_EXCEPTION_IF_NULL(comm_stream_); + } + return true; + } + + protected: + void InitSizeLists() override { return; } + + private: + void InferCommType(const CNodePtr &kernel_node) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kNcclTypeMap.find(kernel_name); + if (iter == kNcclTypeMap.end()) { + MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; + } else { + nccl_kernel_type_ = iter->second; + } + + auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); + if (reduce_op) { + std::string type = GetValue(reduce_op); + if (type == "sum") { + nccl_reduce_type_ = ncclSum; + } else if (type == "max") { + nccl_reduce_type_ = ncclMax; + } else if (type == "min") { + nccl_reduce_type_ = ncclMin; + } else if (type == "prod") { + nccl_reduce_type_ = ncclProd; + } else { + MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; + } + } + return; + } + + NcclKernelType nccl_kernel_type_; + ncclRedOp_t nccl_reduce_type_; + ncclDataType_t nccl_data_type_; + size_t input_size_; + size_t output_size_; 
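A note on the dispatch mechanism used in Launch() above; the stated motivation is an inference, not something the patch spells out:

```cpp
// The NCCL entry points are never linked directly. CollectiveInitializer
// loads a separate collective library, Launch() resolves "AllReduce",
// "AllGather" and "ReduceScatter" from it via dlsym(collective_handle_, ...),
// and the calls go through the function-pointer typedefs declared at the top
// of this header. Presumably this keeps GPU builds without MPI/NCCL loadable.
```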
+ std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + const void *collective_handle_; + cudaStream_t comm_stream_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc new file mode 100644 index 0000000000..334550b213 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) + +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ActivationGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ActivationGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h new file mode 100644 index 0000000000..d651da75e0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.h @@ -0,0 +1,142 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ActivationGpuFwdKernel : public GpuKernel { + public: + ActivationGpuFwdKernel() + : cudnn_handle_(nullptr), + activation_desc_(nullptr), + mode_(CUDNN_ACTIVATION_RELU), + data_descriptor_(nullptr), + is_null_input_(false), + cudnn_data_type_(CUDNN_DATA_FLOAT), + input_size_(0), + output_size_(0), + workspace_size_(0) {} + ~ActivationGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *) override { + if (is_null_input_) { + return true; + } + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, + &beta, data_descriptor_, output), + "cudnnActivationForward failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kernel_map.find(node_name); + if (iter == kernel_map.end()) { + MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; + } + mode_ = iter->second; + + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + std::vector shape; + ShapeNdTo4d(input_shape, &shape); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), + "cudnnSetActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, + shape[0], shape[1], shape[2], shape[3]), + "cudnnSetTensor4dDescriptor failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), + "cudnnCreateActivationDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + output_size_ = input_size_; + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + void 
DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), + "cudnnDestroyActivationDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + std::map kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU}, + {"Tanh", CUDNN_ACTIVATION_TANH}, + {"ELU", CUDNN_ACTIVATION_ELU}, + {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}}; + + cudnnHandle_t cudnn_handle_; + cudnnActivationDescriptor_t activation_desc_; + cudnnActivationMode_t mode_; + cudnnTensorDescriptor_t data_descriptor_; + bool is_null_input_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc new file mode 100644 index 0000000000..8fd486c08c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
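Two review notes on the forward kernel above, as a commented aside:

```cpp
// cudnnActivationForward applies cuDNN's documented blending rule,
//   output = alpha * act(input) + beta * output,
// so the alpha = 1, beta = 0 used in Launch() reduces it to y = act(x).
// Also, kernel_map knows ELU, but activation_gpu_kernel.cc in this patch
// registers only ReLU, Tanh and Sigmoid, so the ELU entry stays unreachable
// until a corresponding MS_REG_GPU_KERNEL_ONE registration is added.
```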
+ */
+
+#include "backend/kernel_compiler/gpu/nn/activation_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
+
+MS_REG_GPU_KERNEL_ONE(
+  TanhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  TanhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
+
+MS_REG_GPU_KERNEL_ONE(
+  SigmoidGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ActivationGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  SigmoidGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ActivationGradGpuKernel, half)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
new file mode 100644
index 0000000000..ffdb618098
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_
+
+#include <map>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class ActivationGradGpuKernel : public GpuKernel {
+ public:
+  ActivationGradGpuKernel()
+      : cudnn_handle_(nullptr),
+        activation_desc_(nullptr),
+        mode_(CUDNN_ACTIVATION_RELU),
+        data_descriptor_(nullptr),
+        is_null_input_(false),
+        cudnn_data_type_(CUDNN_DATA_FLOAT),
+        input_size_(0) {}
+  ~ActivationGradGpuKernel() override { DestroyResource(); }
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *dy = nullptr;
+    T *y = nullptr;
+    if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) {
+      dy = GetDeviceAddress<T>(inputs, 0);
+      y = GetDeviceAddress<T>(inputs, 1);
+    } else {
+      y = GetDeviceAddress<T>(inputs, 0);
+      dy = GetDeviceAddress<T>(inputs, 1);
+    }
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+
+    const float alpha = 1;
+    const float beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy,
+                              data_descriptor_, y, &beta, data_descriptor_, dx),
+      "cudnnActivationBackward failed");
+
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    auto iter = kernel_map.find(node_name);
+    if (iter == kernel_map.end()) {
+      MS_LOG(EXCEPTION) << "Kernel: " << node_name << " is not supported.";
+    }
+    mode_ = iter->second;
+
+    InitResource();
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "ActivationGradGpuKernel input is null.";
+      InitSizeLists();
+      return true;
+    }
+    std::vector<int> shape;
+    ShapeNdTo4d(input_shape, &shape);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
+                                "SetActivationDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           shape[0], shape[1], shape[2], shape[3]),
+                                "SetTensor4dDescriptor failed");
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
+                                "cudnnCreateActivationDescriptor failed");
+  }
+  void InitSizeLists() override {
+    if (!is_null_input_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_),
+                                  "cudnnGetTensorSizeInBytes failed");
+    }
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
+                               "cudnnDestroyActivationDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed");
+  }
+
+  std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU},
+                                                             {"TanhGrad", CUDNN_ACTIVATION_TANH},
+                                                             {"ELUGrad", CUDNN_ACTIVATION_ELU},
+                                                             {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}};
+  cudnnHandle_t cudnn_handle_;
+  cudnnActivationDescriptor_t activation_desc_;
+  cudnnActivationMode_t mode_;
+  cudnnTensorDescriptor_t data_descriptor_;
+  bool is_null_input_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  cudnnDataType_t cudnn_data_type_;
+  size_t input_size_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc
new file mode 100644
index 0000000000..0f89eb4419
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Adam,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      AdamGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Adam,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      AdamGpuKernel, half)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
new file mode 100644
index 0000000000..e2fc87ed51
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/adam_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class AdamGpuKernel : public GpuKernel {
+ public:
+  AdamGpuKernel()
+      : variable_size_(0),
+        m_size_(0),
+        v_size_(0),
+        beta1_power_size_(0),
+        beta2_power_size_(0),
+        learning_rate_size_(0),
+        beta1_size_(0),
+        beta2_size_(0),
+        epsilon_size_(0),
+        gradient_size_(0) {}
+
+  ~AdamGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *m = GetDeviceAddress<T>(inputs, 1);
+    T *v = GetDeviceAddress<T>(inputs, 2);
+    T *beta1_power = GetDeviceAddress<T>(inputs, 3);
+    T *beta2_power = GetDeviceAddress<T>(inputs, 4);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 5);
+    T *beta1 = GetDeviceAddress<T>(inputs, 6);
+    T *beta2 = GetDeviceAddress<T>(inputs, 7);
+    T *epsilon = GetDeviceAddress<T>(inputs, 8);
+    T *gradient = GetDeviceAddress<T>(inputs, 9);
+    ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
+              variable, m, v, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 10) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but Adam needs 10 inputs.";
+      return false;
+    }
+
+    variable_size_ = sizeof(T);
+    m_size_ = sizeof(T);
+    v_size_ = sizeof(T);
+    beta1_power_size_ = sizeof(T);
+    beta2_power_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(T);
+    beta1_size_ = sizeof(T);
+    beta2_size_ = sizeof(T);
+    epsilon_size_ = sizeof(T);
+    gradient_size_ = sizeof(T);
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      variable_size_ *= variable_shape[i];
+    }
+
+    auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < m_shape.size(); i++) {
+      m_size_ *= m_shape[i];
+    }
+
+    auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    for (size_t i = 0; i < v_shape.size(); i++) {
+      v_size_ *= v_shape[i];
+    }
+
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9);
+    for (size_t i = 0; i < gradient_shape.size(); i++) {
+      gradient_size_ *= gradient_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(variable_size_);
+    input_size_list_.push_back(m_size_);
+    input_size_list_.push_back(v_size_);
+    input_size_list_.push_back(beta1_power_size_);
+    input_size_list_.push_back(beta2_power_size_);
+    input_size_list_.push_back(learning_rate_size_);
+    input_size_list_.push_back(beta1_size_);
+    input_size_list_.push_back(beta2_size_);
+    input_size_list_.push_back(epsilon_size_);
+    input_size_list_.push_back(gradient_size_);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t variable_size_;
+  size_t m_size_;
+  size_t v_size_;
+  size_t beta1_power_size_;
+  size_t beta2_power_size_;
+  size_t learning_rate_size_;
+  size_t beta1_size_;
+  size_t beta2_size_;
+  size_t epsilon_size_;
+  size_t gradient_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc
new file mode 100644
index 0000000000..6131aa8568
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      BiasAddGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      BiasAddGradGpuKernel, float16)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h
new file mode 100644
index 0000000000..3e15b818be
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_grad_gpu_kenel.h
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BiasAddGradGpuKernel : public GpuKernel {
+ public:
+  BiasAddGradGpuKernel()
+      : same_dims_(true),
+        cudnn_handle_(nullptr),
+        cudnn_data_type_(CUDNN_DATA_FLOAT),
+        dy_desc_(nullptr),
+        db_desc_(nullptr),
+        op_desc_(nullptr) {}
+  ~BiasAddGradGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
+    T *db_addr = GetDeviceAddress<T>(outputs, 0);
+    T *indices_addr = GetDeviceAddress<T>(workspace, 0);
+    T *workspace_addr = GetDeviceAddress<T>(workspace, 1);
+
+    const float alpha = 1;
+    const float beta = 0;
+    if (same_dims_) {
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed.");
+    } else {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr,
+                          workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr),
+        "cudnnReduceTensor failed");
+    }
+
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto num_dims = dy_shape.size();
+    if (num_dims < 2) {
+      MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims;
+    }
+
+    std::string format = GetAttr<std::string>(kernel_node, "data_format");
+    std::string::size_type pos = format.find("C");
+    if (pos == std::string::npos || pos >= num_dims) {
+      MS_LOG(EXCEPTION) << "format '" << format << "' invalid";
+    }
+
+    // Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
+    auto cudnn_dims = std::max(num_dims, 4UL);
+    std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims);
+    std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims);
+    for (size_t i = 0; i < cudnn_dims; i++) {
+      dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1;
+      db_dims[i] = (i == pos) ? SizeToInt(dy_shape[i]) : 1;
+
+      if (dy_dims[i] != db_dims[i]) {
+        same_dims_ = false;
+      }
+    }
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()),
+      "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()),
+      "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN,
+                                     CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES),
+      "cudnnSetReduceTensorDescriptor failed");
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&db_desc_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&op_desc_),
+                                "cudnnCreateReduceTensorDescriptor failed");
+  }
+  void InitSizeLists() override {
+    size_t dy_size, db_size;
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size), "cudnnGetTensorSizeInBytes failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(db_desc_, &db_size), "cudnnGetTensorSizeInBytes failed");
+    input_size_list_.push_back(dy_size);
+    output_size_list_.push_back(db_size);
+
+    size_t indices_size, workspace_size;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size),
+      "cudnnGetReductionIndicesSize failed")
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size),
+      "cudnnGetReductionWorkspaceSize failed")
+    workspace_size_list_.push_back(indices_size);
+    workspace_size_list_.push_back(workspace_size);
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(op_desc_),
+                               "cudnnDestroyReduceTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed");
+  }
+
+  bool same_dims_;
+  cudnnHandle_t cudnn_handle_;
+  cudnnDataType_t cudnn_data_type_;
+  cudnnTensorDescriptor_t dy_desc_;
+  cudnnTensorDescriptor_t db_desc_;
+  cudnnReduceTensorDescriptor_t op_desc_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc
new file mode 100644
index 0000000000..f9bb710b94
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Conv2D,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Conv2dGpuFwdKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Conv2D,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  Conv2dGpuFwdKernel, half)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
new file mode 100644
index 0000000000..6072614e22
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
@@ -0,0 +1,320 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class Conv2dGpuFwdKernel : public GpuKernel { + public: + Conv2dGpuFwdKernel() + : cudnn_handle_(nullptr), + input_desc_(nullptr), + output_desc_(nullptr), + filter_desc_(nullptr), + conv_desc_(nullptr), + padded_desc_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + input_size_(0), + filter_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~Conv2dGpuFwdKernel() override { DestroyResource(); } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *input_addr = GetDeviceAddress(inputs, 0); + T *filter_addr = GetDeviceAddress(inputs, 1); + T *output_addr = GetDeviceAddress(outputs, 0); + T *workspace_addr = nullptr; + if (workspace_size_ != 0) { + workspace_addr = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = GetDeviceAddress(workspace, 1); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionForward(cudnn_handle_, &alpha, input_desc_, input_addr, filter_desc_, filter_addr, conv_desc_, + conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), + "cudnnConvolutionForward failed"); + } + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + Set4DDesc(in_shape, filter_shape, output_shape); + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, 
"pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t input_descriptor_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + input_descriptor_real = input_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(input_descriptor_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(filter_desc_, reinterpret_cast(&filter_size_)), + "cudnnGetFilterSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + input_size_list_.push_back(filter_size_); + output_size_list_.push_back(output_size_); + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, padded_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, input_desc_, filter_desc_, conv_desc_, output_desc_, + conv_algorithm_, &workspace_size_), + "cudnnGetConvolutionForwardWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); + 
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; + return false; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + + // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, + old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 
0 : pad_left_, stride_[2], stride_[3],
+                                                  dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
+                                "cudnnSetConvolution2dDescriptor failed");
+  }
+
+  void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
+                 const std::vector<size_t> &output_shape) {
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
+                                 SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
+      "cudnnSetTensor4dDescriptor failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]),
+                                 SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])),
+      "cudnnSetFilter4dDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
+                                 SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
+      "cudnnSetTensor4dDescriptor failed");
+  }
+  void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) {
+    if (group_ > 1 || CUDNN_MAJOR < 7) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionForwardAlgorithm(
+                                    cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_,
+                                    CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_),
+                                  "cudnnGetConvolutionForwardAlgorithm failed");
+    } else {
+      constexpr int requested_algo_count = 1;
+      int returned_algo_count;
+      cudnnConvolutionFwdAlgoPerf_t perf_results;
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_,
+                                               output_desc_, requested_algo_count, &returned_algo_count, &perf_results),
+        "cudnnGetConvolutionForwardAlgorithm_v7 failed");
+      conv_algorithm_ = perf_results.algo;
+    }
+    if (cudnn_data_type_ == CUDNN_DATA_HALF) {
+      conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+    }
+  }
+  void SetStrideAndDilation(const CNodePtr &kernel_node) {
+    stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
+    dilation_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
+    if (stride_.size() != 4) {
+      MS_LOG(EXCEPTION) << "Conv2d's stride must be 4d!";
+    }
+    if (stride_[0] != 1 || stride_[1] != 1) {
+      MS_LOG(EXCEPTION) << "Conv2d stride only support 1 in N axis and C axis!";
+    }
+    if (dilation_.size() != 4) {
+      MS_LOG(EXCEPTION) << "Conv2d's dilation must be 4d!";
+    }
+    if (dilation_[0] != 1 || dilation_[1] != 1) {
+      MS_LOG(EXCEPTION) << "Conv2d dilation only support 1 in N axis and C axis!";
+    }
+  }
+  cudnnHandle_t cudnn_handle_;
+  cudnnTensorDescriptor_t input_desc_;
+  cudnnTensorDescriptor_t output_desc_;
+  cudnnFilterDescriptor_t filter_desc_;
+  cudnnConvolutionFwdAlgo_t conv_algorithm_;
+  cudnnConvolutionDescriptor_t conv_desc_;
+  cudnnTensorDescriptor_t padded_desc_;
+  std::string pad_mode_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  const float pad_value_ = 0.0;
+  cudnnDataType_t cudnn_data_type_;
+  int old_height_;
+  int old_width_;
+  int pad_height_;
+  int pad_width_;
+  int pad_top_;
+  int pad_left_;
+  int n_;
+  int c_;
+  std::vector<int> stride_;
+  std::vector<int> dilation_;
+  int group_;
+  bool is_null_input_;
+  size_t input_size_;
+  size_t filter_size_;
+  size_t output_size_;
+  size_t padded_size_;
+  size_t workspace_size_;
+  bool use_pad_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_
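[Editor's note] The SAME-pad branch above is the subtle part of these conv kernels: cudnnSetConvolution2dDescriptor accepts only one symmetric pad value per dimension, so whenever the total SAME padding is odd the kernel zero-fills the input itself (CalPad into a workspace buffer) and hands cuDNN the padded descriptor with zero pad. A minimal sketch of that decision, assuming the {top, bottom, left, right} layout of the "pad_list" attribute; PadDecision and DecideSamePad are illustrative names, not part of this patch:

#include <vector>

// Mirrors the use_pad_ check in SetPad(): SAME padding splits the total so the
// two sides differ by at most one, hence an even total is exactly symmetric
// and can be given to cudnnSetConvolution2dDescriptor as total / 2 per side.
struct PadDecision {
  int pad_top = 0;   // zero-fill offsets consumed by the manual CalPad path
  int pad_left = 0;
  int total_h = 0;   // total padding per spatial dimension
  int total_w = 0;
  bool use_pad = true;  // true -> pre-pad manually, false -> let cuDNN pad
};

inline PadDecision DecideSamePad(const std::vector<int> &pad_list) {
  PadDecision d;
  d.pad_top = pad_list[0];
  d.pad_left = pad_list[2];
  d.total_h = pad_list[0] + pad_list[1];
  d.total_w = pad_list[2] + pad_list[3];
  d.use_pad = !(d.total_h % 2 == 0 && d.total_w % 2 == 0);
  return d;
}

When use_pad comes back false, the kernel skips the CalPad workspace entirely and passes pad_top_/pad_left_ (each half of the even total) straight to the convolution descriptor.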
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc
new file mode 100644
index 0000000000..ca16e1a18c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Conv2DBackpropFilter,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ConvGradFilterGpuBkwKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Conv2DBackpropFilter,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  ConvGradFilterGpuBkwKernel, half)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
new file mode 100644
index 0000000000..638da4a99f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
@@ -0,0 +1,320 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ConvGradFilterGpuBkwKernel : public GpuKernel { + public: + ConvGradFilterGpuBkwKernel() + : cudnn_handle_(nullptr), + dw_desc_(nullptr), + conv_desc_(nullptr), + dy_desc_(nullptr), + x_desc_(nullptr), + padded_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + input_size_(0), + dy_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~ConvGradFilterGpuBkwKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *dy = GetDeviceAddress(inputs, 0); + T *x = GetDeviceAddress(inputs, 1); + T *dw = GetDeviceAddress(outputs, 0); + T *work_space = nullptr; + if (workspace_size_ != 0) { + work_space = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 1); + CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, + algo_, work_space, workspace_size_, &beta, dw_desc_, dw), + "ConvolutionBackwardFilter failed"); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, dw_desc_, dw), + "ConvolutionBackwardFilter failed"); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null."; + InitSizeLists(); + return true; + } + std::vector filter_shape; + GetFilterShape(kernel_node, &filter_shape); + Set4DDesc(dy_shape, filter_shape, in_shape); + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + + pad_height_ = GetAttr(kernel_node, "pad"); + 
pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t x_desc_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(in_shape, kernel_node); + x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "GetConvolution2dDescriptor failed"); + x_desc_real = x_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(x_desc_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "cudnnCreateFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast(&dy_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast(&input_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast(&output_size_)), + "cudnnGetFilterSizeInBytes failed"); + } + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_, + dw_desc_, algo_, reinterpret_cast(&workspace_size_)), + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_, + reinterpret_cast(&workspace_size_)), + "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed"); + 
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradFilter needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradFilter needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + n_ = SizeToInt(in_shape[0]); + c_ = SizeToInt(in_shape[1]); + old_height_ = SizeToInt(in_shape[2]); + old_width_ = SizeToInt(in_shape[3]); + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { + if (group_ > 1 || CUDNN_MAJOR < 7) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), + "GetConvolutionBackwardFilterAlgorithm failed"); + } else { + constexpr int requested_algo_count = 1; + int returned_algo_count; + cudnnConvolutionBwdFilterAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, + requested_algo_count, &returned_algo_count, &perf_results), + "GetConvolutionBackwardFilterAlgorithm failed"); + algo_ = perf_results.algo; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + } + } + void GetFilterShape(const CNodePtr &kernel_node, std::vector *filter_shape) { + auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast()->value(); + (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), + [](const ValuePtr &e) -> int { return e->cast()->value(); }); + } + void Set4DDesc(const std::vector &dy_shape, const std::vector &filter_shape, + const std::vector &in_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), + SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), + "SetTensor4dDescriptor failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1], + filter_shape[2], 
filter_shape[3]),
+      "SetFilter4dDescriptor failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
+                                 SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
+      "SetTensor4dDescriptor failed");
+  }
+  void SetStrideAndDilation(const CNodePtr &kernel_node) {
+    stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
+    dilation_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "dilation");
+    if (stride_.size() != 2) {
+      MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's stride must be 2d!";
+    }
+    if (dilation_.size() != 4) {
+      MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's dilation must be 4d!";
+    }
+    if (dilation_[0] != 1 || dilation_[1] != 1) {
+      MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!";
+    }
+  }
+  cudnnHandle_t cudnn_handle_;
+  cudnnFilterDescriptor_t dw_desc_;
+  cudnnConvolutionDescriptor_t conv_desc_;
+  cudnnTensorDescriptor_t dy_desc_;
+  cudnnTensorDescriptor_t x_desc_;
+  cudnnTensorDescriptor_t padded_descriptor_;
+  cudnnConvolutionBwdFilterAlgo_t algo_;
+  std::string pad_mode_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  const float pad_value_ = 0.0;
+  cudnnDataType_t cudnn_data_type_;
+  int old_height_;
+  int old_width_;
+  int pad_height_;
+  int pad_width_;
+  int pad_top_;
+  int pad_left_;
+  int n_;
+  int c_;
+  std::vector<int> stride_;
+  std::vector<int> dilation_;
+  int group_;
+  bool is_null_input_;
+  size_t input_size_;
+  size_t dy_size_;
+  size_t output_size_;
+  size_t padded_size_;
+  size_t workspace_size_;
+  bool use_pad_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc
new file mode 100644
index 0000000000..d8441fb67c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ConvGradInputGpuBkwKernel, float) +MS_REG_GPU_KERNEL_ONE( + Conv2DBackpropInput, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ConvGradInputGpuBkwKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h new file mode 100644 index 0000000000..a9a1e5c0cc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h @@ -0,0 +1,315 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class ConvGradInputGpuBkwKernel : public GpuKernel { + public: + ConvGradInputGpuBkwKernel() + : cudnn_handle_(nullptr), + w_desc_(nullptr), + conv_desc_(nullptr), + dy_desc_(nullptr), + dx_desc_(nullptr), + padded_descriptor_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + group_(1), + is_null_input_(false), + dy_size_(0), + w_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~ConvGradInputGpuBkwKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *dy = GetDeviceAddress(inputs, 0); + T *w = GetDeviceAddress(inputs, 1); + T *dx = GetDeviceAddress(outputs, 0); + T *work_space = nullptr; + if (workspace_size_ != 0) { + work_space = GetDeviceAddress(workspace, 0); + } + + const float alpha = 1; + const float beta = 0; + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 1); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardData(cudnn_handle_, &alpha, 
w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, padded_descriptor_, padded), + "ConvolutionBackwardData failed"); + CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, + workspace_size_, &beta, dx_desc_, dx), + "ConvolutionBackwardData failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(dy_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null."; + InitSizeLists(); + return true; + } + std::vector input_shape; + GetInputShape(kernel_node, &input_shape); + Set4DDesc(dy_shape, input_shape, filter_shape); + + group_ = GetAttr(kernel_node, "group"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); + + pad_height_ = GetAttr(kernel_node, "pad"); + pad_width_ = pad_height_; + pad_mode_ = GetAttr(kernel_node, "pad_mode"); + SetStrideAndDilation(kernel_node); + cudnnTensorDescriptor_t dx_desc_real = nullptr; + if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { + SetPad(input_shape, kernel_node); + dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], + dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + dx_desc_real = dx_desc_; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), + "cudnnSetConvolutionMathType failed.") + } + SelectAlgorithm(dx_desc_real); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "cudnnCreateFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), + "cudnnCreateConvolutionDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_), "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(w_desc_, &w_size_), "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_desc_, &output_size_), + 
"cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(dy_size_); + input_size_list_.push_back(w_size_); + output_size_list_.push_back(output_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), + "cudnnGetTensorSizeInBytes failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_, + algo_, &workspace_size_), + "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); + workspace_size_list_.push_back(padded_size_); + } else { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_), + "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); + } + } + (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), + "cudnnDestroyConvolutionDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed"); + } + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradInput needs 2 inputs."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradInput needs 1 output."; + return false; + } + return true; + } + void SetPad(const std::vector &input_shape, const CNodePtr &kernel_node) { + auto pad_list = GetAttr>(kernel_node, "pad_list"); + n_ = input_shape[0]; + c_ = input_shape[1]; + old_height_ = input_shape[2]; + old_width_ = input_shape[3]; + pad_height_ = pad_list[0] + pad_list[1]; + pad_width_ = pad_list[2] + pad_list[3]; + pad_top_ = pad_list[0]; + pad_left_ = pad_list[2]; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( + conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 
0 : pad_left_, stride_[0], stride_[1], + dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), + "cudnnSetConvolution2dDescriptor failed"); + } + void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { + if (group_ > 1 || CUDNN_MAJOR < 7) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), + "cudnnGetConvolutionBackwardDataAlgorithm failed"); + } else { + constexpr int requested_algo_count = 1; + int returned_algo_count; + cudnnConvolutionBwdDataAlgoPerf_t perf_results; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, + requested_algo_count, &returned_algo_count, &perf_results), + "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); + algo_ = perf_results.algo; + } + if (cudnn_data_type_ == CUDNN_DATA_HALF) { + algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + } + } + void GetInputShape(const CNodePtr &kernel_node, std::vector *input_shape) { + auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast()->value(); + (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), + [](const ValuePtr &e) -> int { return e->cast()->value(); }); + } + void Set4DDesc(const std::vector &dy_shape, const std::vector &input_shape, + const std::vector &filter_shape) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), + SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), + "SetFilter4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), + SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), + "SetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1], + input_shape[2], input_shape[3]), + "SetTensor4dDescriptor failed"); + } + void SetStrideAndDilation(const CNodePtr &kernel_node) { + stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); + dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); + if (stride_.size() != 2) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's stride must be 2d!"; + } + if (dilation_.size() != 4) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's dilation must be 4d!"; + } + if (dilation_[0] != 1 || dilation_[1] != 1) { + MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!"; + } + } + cudnnHandle_t cudnn_handle_; + cudnnFilterDescriptor_t w_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t dx_desc_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnConvolutionBwdDataAlgo_t algo_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + std::vector stride_; + std::vector dilation_; + int group_; + bool is_null_input_; + size_t dy_size_; + size_t w_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} 
// namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc new file mode 100644 index 0000000000..155451875c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(CTCLossV2, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + CtcLossGpuKernel, float) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h new file mode 100644 index 0000000000..8b02354516 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctcloss_gpu_kernel.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace kernel { +template +class CtcLossGpuKernel : public GpuKernel { + public: + CtcLossGpuKernel() + : cudnn_handle_(nullptr), + probs_desc_(nullptr), + ctcloss_desc_(nullptr), + label_size_(0), + input_lengths_size_(0), + label_lengths_size_(0) {} + ~CtcLossGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + float *probs = GetDeviceAddress(inputs, 0); + int *labels = GetDeviceAddress(inputs, 1); + int *input_lengths = GetDeviceAddress(inputs, 2); + int *label_lengths = GetDeviceAddress(inputs, 3); + float *costs = GetDeviceAddress(outputs, 0); + float *grads = GetDeviceAddress(outputs, 1); + + // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires + void *labels_host = nullptr; + void *input_lengths_host = nullptr; + void *label_lengths_host = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); + cudaStream_t stream = reinterpret_cast(stream_ptr); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(labels_host, labels, inputs[1]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(input_lengths_host, input_lengths, inputs[2]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + CHECK_CUDA_RET_WITH_EXCEPT( + cudaMemcpyAsync(label_lengths_host, label_lengths, inputs[3]->size, cudaMemcpyDeviceToHost, stream), + "cudaMemcpyAsync failed."); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), + reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, + ctcloss_desc_, &workspace_size), + "cudnnGetCTCLossWorkspaceSize failed."); + void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); + if (workspace == nullptr) { + MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), + reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, + probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), + "cudnnCtcLoss failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); + + device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace); + 
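+ // The cuDNN workspace and the pinned host staging buffers are only needed while cudnnCTCLoss runs; the stream has just been synchronized, so they can be released immediately (freed here in reverse order of allocation).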
CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); + CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (probs_shape.size() != 3) { + MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not supported."; + } + probs_dims_[0] = probs_shape[0]; + probs_dims_[1] = probs_shape[1]; + probs_dims_[2] = probs_shape[2]; + + auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + if (labels_dims.size() != 1 && labels_dims.size() != 2) { + MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not supported."; + } + label_size_ = sizeof(int); + for (auto i : labels_dims) { + label_size_ *= i; + } + + auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + input_lengths_size_ = input_length_dims[0] * sizeof(int); + auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + label_lengths_size_ = label_length_dims[0] * sizeof(int); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_), + "cudnnSetTensorNdDescriptorEx failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT, + CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN), + "cudnnSetCTCLossDescriptorEx failed."); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed."); + } + + void InitSizeLists() override { + input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + input_size_list_.push_back(label_size_); + input_size_list_.push_back(input_lengths_size_); + input_size_list_.push_back(label_lengths_size_); + + output_size_list_.push_back(probs_dims_[1] * sizeof(float)); + output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed."); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed."); + } + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t probs_desc_; + cudnnCTCLossDescriptor_t ctcloss_desc_; + int probs_dims_[3] = {0}; + int label_size_; + int input_lengths_size_; + int label_lengths_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc new file mode 100644 index 0000000000..423a230b6e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + Dropout, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DropoutGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + Dropout, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DropoutGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h new file mode 100644 index 0000000000..2104d7af35 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_gpu_kernel.h @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh" +#include "include/curand.h" + +namespace mindspore { +namespace kernel { +template +class DropoutGpuFwdKernel : public GpuKernel { + public: + DropoutGpuFwdKernel() + : cudnn_handle_(nullptr), + is_null_input_(false), + num_count_(0), + keep_prob_(0.0), + states_init_(false), + mask_generator_(nullptr) {} + + ~DropoutGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + T *mask = GetDeviceAddress(outputs, 1); + float *mask_f = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)); + states_init_ = true; + } + // curandGen only support float or double for mask. 
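+ // curandGenerateUniform fills mask_f with uniform floats in (0, 1]; DropoutForward is then expected to keep an element when its mask value falls below keep_prob_ and to scale the survivors by 1/keep_prob_ (inverted dropout), writing the binarized mask out as T for the backward pass.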
+ curandGenerateUniform(mask_generator_, mask_f, num_count_); + DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1."; + } + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + InitSizeLists(); + return true; + } + + num_count_ = 1; + for (size_t x : input_shape) { + num_count_ *= x; + } + keep_prob_ = GetAttr(kernel_node, "keep_prob"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + + void InitSizeLists() override { + size_t input_size = num_count_ * sizeof(T); + input_size_list_.push_back(input_size); + output_size_list_.push_back(input_size); // output size: the same with input size + output_size_list_.push_back(input_size); // mask size: the same with input size + workspace_size_list_.push_back(num_count_ * sizeof(float)); // temp mask_f for curandGen + } + + private: + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t num_count_; + float keep_prob_; + bool states_init_; + curandGenerator_t mask_generator_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc new file mode 100644 index 0000000000..faf884c2eb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE( + DropoutGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + DropoutGradGpuBwdKernel, float) +MS_REG_GPU_KERNEL_ONE( + DropoutGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + DropoutGradGpuBwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h new file mode 100644 index 0000000000..a3a7250c9b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/dropout_grad_kernel.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/dropout_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class DropoutGradGpuBwdKernel : public GpuKernel { + public: + DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} + ~DropoutGradGpuBwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + + T *dy = GetDeviceAddress(inputs, 0); + T *mask = GetDeviceAddress(inputs, 1); + T *dx = GetDeviceAddress(outputs, 0); + + DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2."; + return false; + } + + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + InitSizeLists(); + return true; + } + + num_count_ = 1; + for (size_t x : input_shape) { + num_count_ *= x; + } + keep_prob_ = GetAttr(kernel_node, "keep_prob"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } + void InitSizeLists() override { + size_t dy_size = 
num_count_ * sizeof(T); + size_t mask_size = dy_size; + size_t dx_size = dy_size; + + input_size_list_.push_back(dy_size); + input_size_list_.push_back(mask_size); + output_size_list_.push_back(dx_size); + } + + private: + cudnnHandle_t cudnn_handle_; + bool is_null_input_; + size_t num_count_; + float keep_prob_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc new file mode 100644 index 0000000000..d8206aedcd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGpuFwdKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h new file mode 100644 index 0000000000..a140579a3c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_gpu_kernel.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class FlattenGpuFwdKernel : public GpuKernel { + public: + FlattenGpuFwdKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~FlattenGpuFwdKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + cudaError_t ret = + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); + if (ret) { + MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGpuFwdKernel::Launch, error code is " << ret; + return false; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t i = 0; i < shape.size(); ++i) { + input_size_ *= shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_ = input_size_; + output_size_list_.push_back(output_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc new file mode 100644 index 0000000000..c07126a2ed --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FlattenGradGpuBkwKernel, float) +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FlattenGradGpuBkwKernel, half) +MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + FlattenGradGpuBkwKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h new file mode 100644 index 0000000000..b21327bc3b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/flatten_grad_gpu_kernel.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ + +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class FlattenGradGpuBkwKernel : public GpuKernel { + public: + FlattenGradGpuBkwKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} + ~FlattenGradGpuBkwKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + cudaError_t ret = + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); + if (ret) { + MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGradGpuBkwKernel::Launch, error code is " << ret; + return false; + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Argument number is " << input_num << ", but FlattenGradGpuBkwKernel needs 1."; + return false; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < shape.size(); ++i) { + if (input_size_ == 0) { + input_size_ = 1; + } + input_size_ *= shape[i]; + } + input_size_ = input_size_ * sizeof(T); + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_ = input_size_;
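+ // FlattenGrad only reinterprets the shape: the output buffer is byte-for-byte the size of the input, and Launch reduces to a single device-to-device copy.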
output_size_list_.push_back(output_size_); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; + size_t output_size_; + size_t workspace_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc new file mode 100644 index 0000000000..0186153745 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FtrlGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FtrlGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h new file mode 100644 index 0000000000..ea08741dba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ftrl_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/ftrl_impl.cuh" +namespace mindspore { +namespace kernel { +template +class FtrlGpuKernel : public GpuKernel { + public: + FtrlGpuKernel() + : variable_size_(0), + accumulation_size_(0), + linear_size_(0), + gradient_size_(0), + learning_rate_size_(0), + l1_regularization_size_(0), + l2_regularization_size_(0), + learning_rate_power_size_(0) {} + + ~FtrlGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + T *linear = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + T *learning_rate = GetDeviceAddress(inputs, 4); + T *l1_regularization = GetDeviceAddress(inputs, 5); + T *l2_regularization = GetDeviceAddress(inputs, 6); + T *learning_rate_power = GetDeviceAddress(inputs, 7); + ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, + learning_rate_power, variable, accumulation, linear, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 8) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; + return false; + } + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + linear_size_ = sizeof(T); + gradient_size_ = sizeof(T); + learning_rate_size_ = sizeof(T); + l1_regularization_size_ = sizeof(T); + l2_regularization_size_ = sizeof(T); + learning_rate_power_size_ = sizeof(T); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + + auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < accumulation_shape.size(); i++) { + accumulation_size_ *= accumulation_shape[i]; + } + + auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + for (size_t i = 0; i < linear_shape.size(); i++) { + linear_size_ *= linear_shape[i]; + } + + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(linear_size_); + input_size_list_.push_back(gradient_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(l1_regularization_size_); + input_size_list_.push_back(l2_regularization_size_); + input_size_list_.push_back(learning_rate_power_size_); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t linear_size_; + size_t 
gradient_size_; + size_t learning_rate_size_; + size_t l1_regularization_size_; + size_t l2_regularization_size_; + size_t learning_rate_power_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc new file mode 100644 index 0000000000..5ef2fd8786 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedAdamWeightDecayGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FusedAdam, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedAdamWeightDecayGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h new file mode 100644 index 0000000000..c4fd31a737 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_adam_weight_decay.h @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/adam_weight_decay_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class FusedAdamWeightDecayGpuKernel : public GpuKernel { + public: + FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {} + ~FusedAdamWeightDecayGpuKernel() override = default; + + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + // only the FusedAdamWeightDecay op (registered above) carries the extra weight_decay input + if (node_name == "FusedAdamWeightDecay") { + weight_decay_ = true; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7); + element_nums_ = 1; + for (auto i : shape) { + element_nums_ *= i; + } + + InitSizeLists(); + return true; + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + float *beta1 = GetDeviceAddress(inputs, 0); + float *one_sub_beta1 = GetDeviceAddress(inputs, 1); + float *beta2 = GetDeviceAddress(inputs, 2); + float *one_sub_beta2 = GetDeviceAddress(inputs, 3); + float *epsilon = GetDeviceAddress(inputs, 4); + float *lr = GetDeviceAddress(inputs, 5); + T *param = GetDeviceAddress(inputs, 6); + T *m = GetDeviceAddress(inputs, 7); + T *v = GetDeviceAddress(inputs, 8); + T *gradient = GetDeviceAddress(inputs, 9); + float *weight_decay = nullptr; + if (weight_decay_) { + weight_decay = GetDeviceAddress(inputs, 10); + } + AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, + param, gradient, reinterpret_cast(stream_ptr)); + return true; + } + + protected: + void InitResource() override{}; + void InitSizeLists() override { + // scalar inputs 0-5: beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + input_size_list_.push_back(sizeof(float)); + // tensor inputs 6-9: param, m, v, gradient + input_size_list_.push_back(element_nums_ * sizeof(T)); + input_size_list_.push_back(element_nums_ * sizeof(T)); + input_size_list_.push_back(element_nums_ * sizeof(T)); + input_size_list_.push_back(element_nums_ * sizeof(T)); + if (weight_decay_) { + input_size_list_.push_back(sizeof(float)); + } + output_size_list_.push_back(element_nums_ * sizeof(T)); + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int element_nums_; + bool weight_decay_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc new file mode 100644 index 0000000000..2ce39b63a0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedBatchNormGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FusedBatchNormGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(BatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FusedBatchNormGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(BatchNorm, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FusedBatchNormGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h new file mode 100644 index 0000000000..774428dc40 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -0,0 +1,190 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class FusedBatchNormGpuKernel : public GpuKernel { + public: + FusedBatchNormGpuKernel() + : batch_(0), + channel_(0), + height_(0), + width_(0), + mode_(CUDNN_BATCHNORM_SPATIAL), + epsilon_(10e-5), + exp_avg_factor_(0.1), + is_train_(false), + is_null_input_(false), + x_desc_(nullptr), + y_desc_(nullptr), + scale_bias_mean_var_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~FusedBatchNormGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto x = GetDeviceAddress(inputs, 0); + auto scale = GetDeviceAddress(inputs, 1); + auto bias = GetDeviceAddress(inputs, 2); + auto running_mean = GetDeviceAddress(inputs, 3); + auto running_variance = GetDeviceAddress(inputs, 4); + auto y = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + if (is_train_) { + auto save_mean = GetDeviceAddress(outputs, 3); + auto save_variance = GetDeviceAddress(outputs, 4); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y, + scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, running_mean, + running_variance, epsilon_, save_mean, save_variance), + "Kernel launch failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x, + y_desc_, y, scale_bias_mean_var_desc_, scale, + bias, running_mean, running_variance, epsilon_), + "Kernel launch failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FusedBatchNormGpuKernel needs 5 inputs."; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (shape.size() != 4) { + MS_LOG(EXCEPTION) << "Tensor shape is " << shape.size() << "-D, but FusedBatchNormGpuKernel needs a 4-D input."; + } + is_null_input_ = CHECK_NULL_INPUT(shape); + if (is_null_input_) { + MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null"; + InitSizeLists(); + return true; + } + batch_ = SizeToInt(shape[0]); + channel_ = SizeToInt(shape[1]); + height_ = SizeToInt(shape[2]); + width_ = SizeToInt(shape[3]); + + mode_ = CUDNN_BATCHNORM_SPATIAL; + epsilon_ = GetAttr(kernel_node, "epsilon"); + // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "FusedBatchNorm") { + is_train_ = true; + exp_avg_factor_ = GetAttr(kernel_node,
"momentum"); + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set x desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set y desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), + "Set para desc failed"); + + InitSizeLists(); + + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); + } + void InitSizeLists() override { + size_t input_size = 0; + size_t para_size = 0; + size_t output_size = 0; + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size), + "Get para size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_desc_, &output_size), "Get para size failed"); + } + input_size_list_.push_back(input_size); + input_size_list_.push_back(para_size); // scale + input_size_list_.push_back(para_size); // bias + input_size_list_.push_back(para_size); // mean + input_size_list_.push_back(para_size); // variance + + output_size_list_.push_back(output_size); + output_size_list_.push_back(para_size); // running mean + output_size_list_.push_back(para_size); // running variance + output_size_list_.push_back(para_size); // save mean + output_size_list_.push_back(para_size); // save variance + return; + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); + } + + int batch_; + int channel_; + int height_; + int width_; + cudnnBatchNormMode_t mode_; + double epsilon_; + double exp_avg_factor_; + bool is_train_; + bool is_null_input_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t y_desc_; + cudnnTensorDescriptor_t scale_bias_mean_var_desc_; + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc new file mode 100644 index 0000000000..546e034f6b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      FusedBatchNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      FusedBatchNormGradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
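
The header introduced below wraps cudnnBatchNormalizationBackward, which consumes the save_mean/save_variance tensors cached by the forward training kernel and emits three gradients in a single call: dx and the per-channel parameter gradients. As a reading aid (these are the standard batch-norm backward identities, not something new in the patch):

\[ \frac{\partial L}{\partial \beta_c} = \sum_{n,h,w} dy, \qquad \frac{\partial L}{\partial \gamma_c} = \sum_{n,h,w} dy \cdot \hat{x}, \]

where \( \hat{x} \) is the normalized input reconstructed from the saved batch statistics.

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
new file mode 100644
index 0000000000..a2d0d741b1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batchnorm_grad_gpu_kernel.h
@@ -0,0 +1,178 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.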
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class FusedBatchNormGradGpuKernel : public GpuKernel { + public: + FusedBatchNormGradGpuKernel() + : batch_(0), + channel_(0), + height_(0), + width_(0), + mode_(CUDNN_BATCHNORM_SPATIAL), + epsilon_(10e-5), + is_null_input_(false), + x_desc_(nullptr), + dy_desc_(nullptr), + dx_desc_(nullptr), + scale_bias_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~FusedBatchNormGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(workspace); + VARIABLE_NOT_USED(stream_ptr); + if (is_null_input_) { + return true; + } + auto dy = GetDeviceAddress(inputs, 0); + auto x = GetDeviceAddress(inputs, 1); + auto scale = GetDeviceAddress(inputs, 2); + auto save_mean = GetDeviceAddress(inputs, 3); + auto save_variance = GetDeviceAddress(inputs, 4); + auto dx = GetDeviceAddress(outputs, 0); + auto bn_scale = GetDeviceAddress(outputs, 1); + auto bn_bias = GetDeviceAddress(outputs, 2); + + const float alpha_data_diff = 1; + const float beta_data_diff = 0; + const float alpha_param_diff = 1; + const float beta_param_diff = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, + &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, + bn_scale, bn_bias, epsilon_, save_mean, save_variance), + "Kernel Launch Failed."); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; + } + + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (shape.size() != 4) { + MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4"; + return false; + } + is_null_input_ = CHECK_NULL_INPUT(shape); + if (is_null_input_) { + MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null"; + InitSizeLists(); + return true; + } + batch_ = SizeToInt(shape[0]); + channel_ = SizeToInt(shape[1]); + height_ = SizeToInt(shape[2]); + width_ = SizeToInt(shape[3]); + + mode_ = CUDNN_BATCHNORM_SPATIAL; + epsilon_ = GetAttr(kernel_node, "epsilon"); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set dy desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + 
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + "Set dx desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), + "Set para desc failed"); + + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_desc_), "Create para desc failed"); + } + + void InitSizeLists() override { + size_t input_size = 0; + size_t para_size = 0; + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), "Get input size failed"); + } + + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(para_size); + input_size_list_.push_back(para_size); + input_size_list_.push_back(para_size); + + output_size_list_.push_back(input_size); + output_size_list_.push_back(para_size); + output_size_list_.push_back(para_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); + } + + int batch_; + int channel_; + int height_; + int width_; + + cudnnBatchNormMode_t mode_; + double epsilon_; + bool is_null_input_; + cudnnTensorDescriptor_t x_desc_; + cudnnTensorDescriptor_t dy_desc_; + cudnnTensorDescriptor_t dx_desc_; + cudnnTensorDescriptor_t scale_bias_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc new file mode 100644 index 0000000000..274e4896c9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(GeluGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      GeLUGpuGradKernel, float)
+MS_REG_GPU_KERNEL_ONE(GeluGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      GeLUGpuGradKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h
new file mode 100644
index 0000000000..823da1fe9f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_grad_kernel.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class GeLUGpuGradKernel : public GpuKernel {
+ public:
+  GeLUGpuGradKernel() : input_size_(0) {}
+  ~GeLUGpuGradKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
+    T *x_addr = GetDeviceAddress<T>(inputs, 1);
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+
+    GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);  // dy
+    input_size_list_.push_back(input_size_);  // x
+    input_size_list_.push_back(input_size_);  // third registered input, not read in Launch
+    output_size_list_.push_back(input_size_);  // dx
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
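
The CUDA body invoked above lives in gelu_impl.cuh, which is not part of this hunk. Assuming it uses the common tanh approximation \( \mathrm{GeLU}(x) \approx \tfrac{x}{2}(1 + \tanh u) \) with \( u = \sqrt{2/\pi}\,(x + 0.044715\,x^3) \), the gradient the kernel must produce, applied elementwise as \( dx = dy \cdot \mathrm{GeLU}'(x) \), is

\[ \mathrm{GeLU}'(x) = \tfrac{1}{2}(1 + \tanh u) + \tfrac{x}{2}\,(1 - \tanh^2 u)\,\sqrt{2/\pi}\,(1 + 3 \cdot 0.044715\,x^2). \]

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc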
new file mode 100644
index 0000000000..03cd9a155b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/gelu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      GeluGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      GeluGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h
new file mode 100644
index 0000000000..76d3861d55
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/gelu_kernel.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/gelu_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class GeluGpuKernel : public GpuKernel {
+ public:
+  GeluGpuKernel() : input_size_(0) {}
+  ~GeluGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+
+    Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    input_size_ = sizeof(T);
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (auto dim : input_shape) {
+      input_size_ *= dim;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t input_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
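
GeLU is elementwise, \( \mathrm{GeLU}(x) = x\,\Phi(x) = \tfrac{x}{2}\left(1 + \operatorname{erf}(x/\sqrt{2})\right) \), which is why the kernel above needs no descriptors or workspace: Init() folds the input shape into a single byte count, input_size_ = sizeof(T) · ∏ dims, and Launch() passes that flat element count straight to the CUDA implementation.

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc
new file mode 100644
index 0000000000..49f556ae64
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.cc
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.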
+ */
+
+#include "backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LayerNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LayerNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h
new file mode 100644
index 0000000000..74669e03de
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_gpu_kernel.h
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class LayerNormGpuKernel : public GpuKernel {
+ public:
+  LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {}
+  ~LayerNormGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto x = GetDeviceAddress<T>(inputs, 0);
+    auto gamma = GetDeviceAddress<T>(inputs, 1);
+    auto beta = GetDeviceAddress<T>(inputs, 2);
+    auto y = GetDeviceAddress<T>(outputs, 0);
+    auto mean = GetDeviceAddress<T>(outputs, 1);
+    auto variance = GetDeviceAddress<T>(outputs, 2);
+
+    const T epsilon = 10e-12;
+    LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance,
+              reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis");
+    int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis");
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (begin_norm_axis < 0) {
+      begin_norm_axis += SizeToInt(input_shape.size());
+    }
+
+    if (begin_params_axis < 0) {
+      begin_params_axis += SizeToInt(input_shape.size());
+    }
+
+    for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) {
+      input_row_ *= input_shape[i];
+    }
+
+    for (size_t i = IntToSize(begin_norm_axis); i < input_shape.size(); i++) {
+      input_col_ *= input_shape[i];
+    }
+
+    for (size_t i = IntToSize(begin_params_axis); i < input_shape.size(); i++) {
+      param_dim_ *= input_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_row_ * input_col_ * sizeof(T));  // x
+    input_size_list_.push_back(param_dim_ * sizeof(T));               // gamma
+    input_size_list_.push_back(param_dim_ * sizeof(T));               // beta
+
+    output_size_list_.push_back(input_row_ * input_col_ * sizeof(T));  // y
+    output_size_list_.push_back(input_row_ * sizeof(T));               // mean, one per row
+    output_size_list_.push_back(input_row_ * sizeof(T));               // variance, one per row
+  }
+
+ private:
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int input_row_;
+  int input_col_;
+  int param_dim_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..b59f95b8a2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LayerNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
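
The axis bookkeeping in LayerNormGpuKernel::Init is easiest to see with concrete (hypothetical) numbers: for an input of shape (32, 128, 768) with begin_norm_axis = begin_params_axis = 2, it yields input_row_ = 32 · 128 = 4096, input_col_ = 768, and param_dim_ = 768. Each of the 4096 rows is then normalized over its 768 elements,

\[ y = \gamma \odot \frac{x - \mu_{\text{row}}}{\sqrt{\sigma_{\text{row}}^2 + \epsilon}} + \beta, \]

so gamma/beta carry param_dim_ values while the mean and variance outputs carry one value per row, matching the input_row_ · sizeof(T) byte counts pushed in InitSizeLists(). The backward kernel whose header follows mirrors exactly this decomposition.

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h
new file mode 100644
index 0000000000..93967adad3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/layer_norm_grad_gpu_kernel.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.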
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/layer_norm_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class LayerNormGradGpuKernel : public GpuKernel { + public: + LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} + ~LayerNormGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + auto x = GetDeviceAddress(inputs, 0); + auto dy = GetDeviceAddress(inputs, 1); + auto var = GetDeviceAddress(inputs, 2); + auto mean = GetDeviceAddress(inputs, 3); + auto gamma = GetDeviceAddress(inputs, 4); + auto dx = GetDeviceAddress(outputs, 0); + auto dg = GetDeviceAddress(outputs, 1); + auto db = GetDeviceAddress(outputs, 2); + + const T epsilon = 10e-12; + LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); + int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + if (begin_norm_axis < 0) { + begin_norm_axis += input_shape.size(); + } + + if (begin_params_axis < 0) { + begin_params_axis += input_shape.size(); + } + + for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { + input_row_ *= input_shape[i]; + } + + for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { + input_col_ *= input_shape[i]; + } + + for (size_t i = begin_params_axis; i < input_shape.size(); i++) { + param_dim_ *= input_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(input_row_ * sizeof(T)); + input_size_list_.push_back(param_dim_ * sizeof(T)); + + output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); + output_size_list_.push_back(param_dim_ * sizeof(T)); + output_size_list_.push_back(param_dim_ * sizeof(T)); + return; + } + + private: + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + int input_row_; + int input_col_; + int param_dim_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc
new file mode 100644
index 0000000000..a24aaeeb96
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(LSTM,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      LstmGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LSTM,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LstmGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
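
Context for the header that follows, summarizing what its code does: cuDNN's RNN API takes one tensor descriptor per time step, so CreateTensorDescGrp() builds seq_len_-sized descriptor arrays for x and y; cudnnGetRNNParamsSize() is checked against the byte size of the flattened weight input; cudnnGetRNNTrainingReserveSize() sizes the reserve buffer that LSTMGradData/LSTMGradWeight later consume (exposed as output 3); and the dropout state buffer is exposed as output 4, with the y output sized as seq_len · batch · hidden · num_directions · sizeof(T). One detail worth flagging for reviewers: w_dims divides weight_size_ by 4, i.e. it assumes 4-byte elements even under the half-precision registration; whether that is intentional cannot be confirmed from this hunk alone.

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
new file mode 100644
index 0000000000..ad3e588f00
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
@@ -0,0 +1,247 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.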
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class LstmGpuKernel : public GpuKernel { + public: + LstmGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + x_desc_(nullptr), + hx_desc_(nullptr), + cx_desc_(nullptr), + w_desc_(nullptr), + dropout_desc_(nullptr), + y_desc_(nullptr), + hy_desc_(nullptr), + cy_desc_(nullptr), + rnn_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto x_addr = GetDeviceAddress(inputs, 0); + auto hx_addr = GetDeviceAddress(inputs, 1); + auto cx_addr = GetDeviceAddress(inputs, 2); + auto w_addr = GetDeviceAddress(inputs, 3); + auto y_addr = GetDeviceAddress(outputs, 0); + auto hy_addr = GetDeviceAddress(outputs, 1); + auto cy_addr = GetDeviceAddress(outputs, 2); + auto reserved_addr = GetDeviceAddress(outputs, 3); + auto states_addr = GetDeviceAddress(outputs, 4); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, output_size_list_[4], 0), + "set dropout_desc failed"); + states_init_ = true; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, + w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr, + workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_), + "launch lstm kernel failed"); + + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + seq_len_ = SizeToInt(input_shape[0]); + batch_size_ = SizeToInt(input_shape[1]); + input_size_ = SizeToInt(input_shape[2]); + + input_size_ = GetAttr(kernel_node, "input_size"); + hidden_size_ = GetAttr(kernel_node, "hidden_size"); + num_layers_ = GetAttr(kernel_node, "num_layers"); + has_bias_ = GetAttr(kernel_node, "has_bias"); + bidirectional_ = GetAttr(kernel_node, "bidirectional"); + dropout_ = GetAttr(kernel_node, "dropout"); + + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + CreateTensorDescGrp(); + int hx_dims[3]{num_layers_ * (bidirectional_ ? 
2 : 1), batch_size_, hidden_size_}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), + "set dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + input_mode, direction, rnn_mode, algo, cudnn_data_type_), + "set rnn_desc failed"); + cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), + "get reserve size failed"); + InitSizeLists(); + return true; + } + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 
2 : 1), 1}; + + x_desc_ = std::make_unique(seq_len_); + y_desc_ = std::make_unique(seq_len_); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); + } + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + void InitSizeLists() override { + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + input_size_list_.push_back(x_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(weight_size_); + + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 
2 : 1)) * sizeof(T); + output_size_list_.push_back(y_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + output_size_list_.push_back(state_size); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed"); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); + } + } + + int batch_size_; + int seq_len_; + int input_size_; + int hidden_size_; + int num_layers_; + + bool has_bias_; + bool bidirectional_; + bool states_init_; + float dropout_; + + size_t weight_size_; + size_t reserved_size_; + + // input desc + std::unique_ptr x_desc_; + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + cudnnFilterDescriptor_t w_desc_; + cudnnDropoutDescriptor_t dropout_desc_; + std::unique_ptr y_desc_; + cudnnTensorDescriptor_t hy_desc_; + cudnnTensorDescriptor_t cy_desc_; + cudnnRNNDescriptor_t rnn_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc new file mode 100644 index 0000000000..1fa47690b3 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LSTMGradData, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmGradDataGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LSTMGradData, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LstmGradDataGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h new file mode 100644 index 0000000000..6d6bed5555 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h @@ -0,0 +1,284 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class LstmGradDataGpuKernel : public GpuKernel { + public: + LstmGradDataGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + rnn_desc_(nullptr), + y_desc_(nullptr), + dy_desc_(nullptr), + dhy_desc_(nullptr), + dcy_desc_(nullptr), + w_desc_(nullptr), + hx_desc_(nullptr), + cx_desc_(nullptr), + dropout_desc_(nullptr), + dx_desc_(nullptr), + dhx_desc_(nullptr), + dcx_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGradDataGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto y_addr = GetDeviceAddress(inputs, 0); + auto dy_addr = GetDeviceAddress(inputs, 1); + auto dhy_addr = GetDeviceAddress(inputs, 2); + auto dcy_addr = GetDeviceAddress(inputs, 3); + auto w_addr = GetDeviceAddress(inputs, 4); + auto hx_addr = GetDeviceAddress(inputs, 5); + auto cx_addr = GetDeviceAddress(inputs, 6); + auto reserved_addr = GetDeviceAddress(inputs, 7); + auto states_addr = GetDeviceAddress(inputs, 8); + auto dx_addr = GetDeviceAddress(outputs, 0); + auto dhx_addr = GetDeviceAddress(outputs, 1); + auto dcx_addr = GetDeviceAddress(outputs, 2); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[8], 0), + "restore dropout state failed"); + states_init_ = true; + } + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnRNNBackwardData(handle_, rnn_desc_, seq_len_, y_desc_.get(), y_addr, dy_desc_.get(), dy_addr, dhy_desc_, + dhy_addr, dcy_desc_, dcy_addr, w_desc_, w_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, + dx_desc_.get(), dx_addr, dhx_desc_, dhx_addr, dcx_desc_, dcx_addr, workspace_addr, + workspace_size_list_[0], reserved_addr, reserved_size_), + "launch lstm back data kernel failed"); + + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "stream synchronize failed."); + return true; + } + void GetAttrs(const CNodePtr &kernel_node) { + input_size_ = GetAttr(kernel_node, "input_size"); + hidden_size_ = GetAttr(kernel_node, "hidden_size"); + num_layers_ = GetAttr(kernel_node, "num_layers"); + has_bias_ = GetAttr(kernel_node, "has_bias"); + bidirectional_ = GetAttr(kernel_node, "bidirectional"); + dropout_ = GetAttr(kernel_node, "dropout"); + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + seq_len_ = SizeToInt(input_shape[0]); + 
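+    // dx (output 0) is laid out as (seq_len, batch_size, input_size), matching the
+    // forward kernel's x, so the sequence and batch dimensions can be read off its
+    // inferred shape here; input_size_ and hidden_size_ come from node attributes
+    // via GetAttrs() below.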
batch_size_ = SizeToInt(input_shape[1]); + GetAttrs(kernel_node); + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + CreateTensorDescGrp(); + int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dhy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dcy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), + "set cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dhx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dcx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), + "set dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, + input_mode, direction, rnn_mode, algo, cudnn_data_type_), + "set rnn_desc failed"); + cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, dx_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set w_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &reserved_size_), "get size failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhy_desc_), "create dhy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcy_desc_), "create dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhx_desc_), "create dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcx_desc_), "create dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create 
dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + + void InitSizeLists() override { + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + input_size_list_.push_back(y_size); + input_size_list_.push_back(y_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(weight_size_); + input_size_list_.push_back(h_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + input_size_list_.push_back(state_size); + + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + output_size_list_.push_back(x_size); + output_size_list_.push_back(h_size); + output_size_list_.push_back(h_size); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcx_desc_), "destroy dcx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhx_desc_), "destroy dhx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcy_desc_), "destroy dcy_desc_ failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhy_desc_), "destroy dhy_desc_ failed"); + DestroyTensorDescGrp(); + } + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 
2 : 1), 1}; + + dx_desc_ = std::make_unique(seq_len_); + y_desc_ = std::make_unique(seq_len_); + dy_desc_ = std::make_unique(seq_len_); + + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_[i]), "create x_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dx_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), + "set dx_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_[i]), "create dy_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensorNdDescriptorEx(dy_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), + "set dy_desc_ failed"); + } + } + + void DestroyTensorDescGrp() { + for (size_t i = 0; i < IntToSize(seq_len_); ++i) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_[i]), "destroy dy_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_[i]), "destroy x_desc failed"); + } + } + + int batch_size_; + int seq_len_; + int input_size_; + int hidden_size_; + int num_layers_; + + bool has_bias_; + bool bidirectional_; + bool states_init_; + float dropout_; + + size_t weight_size_; + size_t reserved_size_; + + cudnnRNNDescriptor_t rnn_desc_; + + // input desc + std::unique_ptr y_desc_; + std::unique_ptr dy_desc_; + cudnnTensorDescriptor_t dhy_desc_; + cudnnTensorDescriptor_t dcy_desc_; + cudnnFilterDescriptor_t w_desc_; + cudnnTensorDescriptor_t hx_desc_; + cudnnTensorDescriptor_t cx_desc_; + + cudnnDropoutDescriptor_t dropout_desc_; + + // output desc + std::unique_ptr dx_desc_; + cudnnTensorDescriptor_t dhx_desc_; + cudnnTensorDescriptor_t dcx_desc_; + + cudnnHandle_t handle_; + cudnnDataType_t cudnn_data_type_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc new file mode 100644 index 0000000000..9ec239491f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + LstmGradWeightGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + LstmGradWeightGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h new file mode 100644 index 0000000000..445d2ce199 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h @@ -0,0 +1,231 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/kernel_constants.h" +namespace mindspore { +namespace kernel { +template +class LstmGradWeightGpuKernel : public GpuKernel { + public: + LstmGradWeightGpuKernel() + : batch_size_(0), + seq_len_(0), + input_size_(0), + hidden_size_(0), + num_layers_(0), + has_bias_(false), + bidirectional_(false), + states_init_(false), + dropout_(0), + weight_size_(0), + reserved_size_(0), + rnn_desc_(nullptr), + dropout_desc_(nullptr), + x_desc_(nullptr), + hx_desc_(nullptr), + y_desc_(nullptr), + dw_desc_(nullptr), + handle_(nullptr), + cudnn_data_type_(CUDNN_DATA_FLOAT) {} + ~LstmGradWeightGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + VARIABLE_NOT_USED(stream_ptr); + auto x_addr = GetDeviceAddress(inputs, 0); + auto hx_addr = GetDeviceAddress(inputs, 1); + auto y_addr = GetDeviceAddress(inputs, 2); + auto reserved_addr = GetDeviceAddress(inputs, 3); + auto states_addr = GetDeviceAddress(inputs, 4); + auto dw_addr = GetDeviceAddress(outputs, 0); + void *workspace_addr = GetDeviceAddress(workspace, 0); + + if (!states_init_) { + CHECK_CUDNN_RET_WITH_EXCEPT( + 
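+      /* Input layout assumed here, matching the LSTMGradWeight registration above:
+       * 0 = x, 1 = hx, 2 = y, 3 = cuDNN reserve space, 4 = dropout RNG states;
+       * the single output is dw. The RNG state blob only has to be bound to
+       * dropout_desc_ once, so the restore call below is a one-shot lazy
+       * initialization guarded by states_init_.
+       */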
cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[4], 0),
+        "restore dropout state failed");
+      states_init_ = true;
+    }
+
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemsetAsync(dw_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemSet Failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnRNNBackwardWeights(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, y_desc_.get(),
+                              y_addr, workspace_addr, workspace_size_list_[0], dw_desc_, dw_addr, reserved_addr,
+                              reserved_size_),
+      "launch lstm back weight kernel failed");
+
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    seq_len_ = SizeToInt(input_shape[0]);
+    batch_size_ = SizeToInt(input_shape[1]);
+
+    input_size_ = GetAttr<int>(kernel_node, "input_size");
+    hidden_size_ = GetAttr<int>(kernel_node, "hidden_size");
+    num_layers_ = GetAttr<int>(kernel_node, "num_layers");
+    has_bias_ = GetAttr<bool>(kernel_node, "has_bias");
+    bidirectional_ = GetAttr<bool>(kernel_node, "bidirectional");
+    dropout_ = GetAttr<float>(kernel_node, "dropout");
+
+    cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
+    cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
+    cudnnRNNMode_t rnn_mode = CUDNN_LSTM;
+    cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
+
+    CreateTensorDescGrp();
+    int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_};
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims),
+                                "set hx_desc_ failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
+                                "set dropout_desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
+                                                      input_mode, direction, rnn_mode, algo, cudnn_data_type_),
+                                "set rnn_desc failed");
+    cudnnRNNBiasMode_t bias_mode = has_bias_ ?
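+    /* CUDNN_RNN_DOUBLE_BIAS keeps cuDNN's default of separate input and
+     * recurrent bias vectors, so has_bias_ = false maps to CUDNN_RNN_NO_BIAS.
+     * A rough per-direction parameter count under the standard cuDNN LSTM
+     * layout (a sketch only -- layers above the first see input = hidden *
+     * num_directions, and the authoritative size still comes from
+     * cudnnGetRNNParamsSize below):
+     *
+     *   size_t GateParams(size_t in, size_t hidden, bool double_bias) {
+     *     size_t w = 4 * hidden * (in + hidden);    // 4 gates, W and R matrices
+     *     size_t b = double_bias ? 8 * hidden : 0;  // two bias vectors per gate
+     *     return w + b;
+     *   }
+     */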
CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); + + auto weight_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), + "get weight_size_ failed"); + if (weight_size != weight_size_) { + MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; + } + int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), + "set dw_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), + "get reserve size failed"); + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "create dw_desc_ failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); + } + void InitSizeLists() override { + size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); + + size_t h_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); + + size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); + input_size_list_.push_back(x_size); + input_size_list_.push_back(h_size); + input_size_list_.push_back(y_size); + input_size_list_.push_back(reserved_size_); + size_t state_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); + input_size_list_.push_back(state_size); + + output_size_list_.push_back(weight_size_); + + size_t workspace_size = 0; + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), + "get workspace size failed"); + workspace_size_list_.push_back(workspace_size); + } + + private: + void CreateTensorDescGrp() { + int x_dims[3]{batch_size_, input_size_, 1}; + int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 
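+  /* x_desc_ and y_desc_ are arrays of seq_len_ descriptors, one per time step,
+   * owned through unique_ptr<cudnnTensorDescriptor_t[]>; .get() hands the raw
+   * array to cudnnGetRNNTrainingReserveSize / cudnnRNNBackwardWeights above, e.g.
+   *
+   *   x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
+   */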
2 : 1), 1};
+
+    x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
+    y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_);
+
+    for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed");
+
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed");
+    }
+  }
+  void DestroyTensorDescGrp() {
+    for (size_t i = 0; i < IntToSize(seq_len_); ++i) {
+      CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed");
+      CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed");
+    }
+  }
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "destroy dw_desc_ failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed");
+    DestroyTensorDescGrp();
+  }
+
+  int batch_size_;
+  int seq_len_;
+  int input_size_;
+  int hidden_size_;
+  int num_layers_;
+
+  bool has_bias_;
+  bool bidirectional_;
+  bool states_init_;
+  float dropout_;
+
+  size_t weight_size_;
+  size_t reserved_size_;
+
+  cudnnRNNDescriptor_t rnn_desc_;
+  cudnnDropoutDescriptor_t dropout_desc_;
+
+  // input desc
+  std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_;
+  cudnnTensorDescriptor_t hx_desc_;
+  std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_;
+
+  // output desc
+  cudnnFilterDescriptor_t dw_desc_;
+
+  cudnnHandle_t handle_;
+  cudnnDataType_t cudnn_data_type_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
new file mode 100644
index 0000000000..99ae2affe8
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + MomentumGpuKernel, float, float) +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, half) +MS_REG_GPU_KERNEL_TWO(ApplyMomentum, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat16), + MomentumGpuKernel, half, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h new file mode 100644 index 0000000000..32d3fbb079 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/momentum_gpu_kernel.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/momentum_impl.cuh" +namespace mindspore { +namespace kernel { +template +class MomentumGpuKernel : public GpuKernel { + public: + MomentumGpuKernel() + : variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), momentum_size_(0) {} + ~MomentumGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + S *learning_rate = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + S *momentum = GetDeviceAddress(inputs, 4); + MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, + reinterpret_cast(stream_ptr)); + return true; + } + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 5) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs."; + return false; + } + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + learning_rate_size_ = sizeof(S); + gradient_size_ = sizeof(T); + momentum_size_ = sizeof(S); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < accumulation_shape.size(); i++) { + accumulation_size_ *= accumulation_shape[i]; + } + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(gradient_size_); + input_size_list_.push_back(momentum_size_); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t learning_rate_size_; + size_t gradient_size_; + size_t momentum_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc new file mode 100644 index 0000000000..902b0d9faf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PoolingGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + PoolingGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + PoolingGpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + PoolingGpuFwdKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h new file mode 100644 index 0000000000..908a4e9b99 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h @@ -0,0 +1,252 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class PoolingGpuFwdKernel : public GpuKernel { + public: + PoolingGpuFwdKernel() + : cudnn_handle_(nullptr), + input_descriptor_(nullptr), + output_descriptor_(nullptr), + pooling_descriptor_(nullptr), + padded_descriptor_(nullptr), + pooling_mode_(CUDNN_POOLING_MAX), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + pad_value_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~PoolingGpuFwdKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (is_null_input_) { + return true; + } + T *input_addr = reinterpret_cast(inputs[0]->addr); + T *output_addr = reinterpret_cast(outputs[0]->addr); + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded_addr = reinterpret_cast(workspace[0]->addr); + CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_, + padded_addr, &beta, output_descriptor_, output_addr), + "cudnnPoolingForward failed"); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, + input_addr, &beta, output_descriptor_, output_addr), + "cudnnPoolingForward failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + is_null_input_ = CHECK_NULL_INPUT(input_shape); + if (is_null_input_) { + MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null."; + InitSizeLists(); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), + SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + auto window = 
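+    /* "ksize" and "strides" are NCHW 4-vectors; only the spatial entries at
+     * indices 2 and 3 are used for 2-D pooling. For example (illustrative
+     * values), a 3x3 max-pool with stride 2 would carry:
+     *
+     *   ksize   = {1, 1, 3, 3};   // N and C are never pooled
+     *   strides = {1, 1, 2, 2};
+     */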
GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
+    int window_height = window[2];
+    int window_width = window[3];
+    stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
+    SetPoolingMode(kernel_node);
+    if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
+      SetPad(input_shape, window_height, window_width);
+    } else {
+      pad_height_ = 0;
+      pad_width_ = 0;
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height,
+                                    window_width, pad_height_, pad_width_, stride_[2], stride_[3]),
+        "cudnnSetPooling2dDescriptor failed");
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
+                                "cudnnCreatePoolingDescriptor failed");
+  }
+  void InitSizeLists() {
+    if (!is_null_input_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)),
+        "cudnnGetTensorSizeInBytes failed");
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast<size_t *>(&output_size_)),
+        "cudnnGetTensorSizeInBytes failed");
+    }
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)),
+        "cudnnGetTensorSizeInBytes failed");
+      workspace_size_list_.push_back(padded_size_);
+      if (padded_size_ == 0) {
+        MS_LOG(EXCEPTION) << "Padded size is 0.";
+      }
+    }
+    return;
+  }
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 input.";
+      return false;
+    }
+    return true;
+  }
+  void SetPad(const std::vector<size_t> &input_shape, const int &window_height, const int &window_width) {
+    n_ = SizeToInt(input_shape[0]);
+    c_ = SizeToInt(input_shape[1]);
+    old_height_ = SizeToInt(input_shape[2]);
+    old_width_ = SizeToInt(input_shape[3]);
+    pad_height_ =
+      std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
+                                                                           : (old_height_ / stride_[2]) + 1) -
+                   1) *
+                      stride_[2] +
+                    window_height - old_height_);
+    pad_width_ =
+      std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ?
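+    /* The expression above is the TensorFlow-style SAME padding rule written
+     * out longhand. An equivalent, easier-to-read sketch (integer
+     * ceil-division; SamePad is a hypothetical helper name):
+     *
+     *   #include <algorithm>
+     *   int SamePad(int in, int stride, int window) {
+     *     int out = (in + stride - 1) / stride;                 // ceil(in / stride)
+     *     return std::max(0, (out - 1) * stride + window - in);
+     *   }
+     *
+     * e.g. in = 5, stride = 2, window = 3 gives out = 3 and a total pad of 2,
+     * split below as pad_top_ = pad_height_ / 2.
+     */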
(old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + window_height, window_width, use_pad_ ? 0 : pad_top_, + use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + } + void SetPoolingMode(const CNodePtr &kernel_node) { + pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); + mode_ = AnfAlgo::GetCNodeName(kernel_node); + if (mode_ == "AvgPool") { + pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + pad_value_ = 0.0; + } else { + pooling_mode_ = CUDNN_POOLING_MAX; + pad_value_ = kSignedMinFloat; + } + } + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), + "cudnnDestroyPoolingDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnTensorDescriptor_t input_descriptor_; + cudnnTensorDescriptor_t output_descriptor_; + cudnnPoolingDescriptor_t pooling_descriptor_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; + std::vector stride_; + std::string mode_; + std::string pad_mode_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + cudnnDataType_t cudnn_data_type_; + + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + float pad_value_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc new file mode 100644 index 0000000000..2948c900d2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PoolingGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + PoolingGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + PoolingGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + PoolingGradGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h new file mode 100644 index 0000000000..a066eacfa0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -0,0 +1,296 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" +#include "backend/kernel_compiler/gpu/kernel_constants.h" + +namespace mindspore { +namespace kernel { +template +class PoolingGradGpuKernel : public GpuKernel { + public: + PoolingGradGpuKernel() + : cudnn_handle_(nullptr), + pooling_descriptor_(nullptr), + y_descriptor_(nullptr), + dy_descriptor_(nullptr), + x_descriptor_(nullptr), + dx_descriptor_(nullptr), + padded_descriptor_(nullptr), + pooling_mode_(CUDNN_POOLING_MAX), + cudnn_data_type_(CUDNN_DATA_FLOAT), + old_height_(0), + old_width_(0), + pad_height_(0), + pad_width_(0), + pad_top_(0), + pad_left_(0), + n_(0), + c_(0), + pad_value_(0), + is_null_input_(false), + input_size_(0), + output_size_(0), + padded_size_(0), + workspace_size_(0), + use_pad_(true) {} + ~PoolingGradGpuKernel() override { DestroyResource(); } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + if (is_null_input_) { + return true; + } + T *x_data = GetDeviceAddress(inputs, 0); + T *y = GetDeviceAddress(inputs, 1); + T *dy = GetDeviceAddress(inputs, 2); + T *dx = GetDeviceAddress(outputs, 0); + + const float alpha = 1; + const float beta = 0; + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { + T *padded = GetDeviceAddress(workspace, 0); + T *padded_dx = GetDeviceAddress(workspace, 1); + + CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, + reinterpret_cast(stream_ptr)); + + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, + padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx), + "cudnnPoolingBackward failed"); + + CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_, + old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, + x_descriptor_, x_data, &beta, dx_descriptor_, dx), + "cudnnPoolingBackward failed"); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + InitResource(); + if (!CheckParam(kernel_node)) { + return false; + } + auto window = GetAttr>(kernel_node, "ksize"); + int window_height = window[2]; + int window_width = window[3]; + SetPoolingMode(kernel_node); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); + if (is_null_input_) { + MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; + InitSizeLists(); + return true; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + 
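+    /* Gradient input layout, matching the MaxPoolGrad / AvgPoolGradGpu
+     * registrations above: input 0 = x (forward input), input 1 = y (forward
+     * output), input 2 = dy (gradient w.r.t. y); the kernel writes dx.
+     * cudnnPoolingBackward consumes all four tensors, so y, dy, x and dx each
+     * get their own descriptor here.
+     */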
cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]), + SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])), + "cudnnSetTensor4dDescriptor"); + + auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]), + SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])), + "cudnnSetTensor4dDescriptor"); + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), + SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), + "cudnnSetTensor4dDescriptor failed"); + if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { + SetPad(input_shape, window_height, window_width); + } else { + if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { + pad_height_ = 0; + pad_width_ = 0; + } + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, + window_width, pad_height_, pad_width_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), + "cudnnSetTensor4dDescriptor"); + } + InitSizeLists(); + return true; + } + + protected: + void InitResource() override { + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), + "cudnnCreatePoolingDescriptor failed"); + } + void InitSizeLists() override { + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + + if (!is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_), + "cudnnGetTensorSizeInBytes failed"); + } + input_size_list_.push_back(input_size_); + + if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), + 
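+      /* With explicit SAME padding in effect, two scratch buffers of
+       * padded_size_ bytes are pushed below: one for the padded forward input
+       * handed to cudnnPoolingBackward, one for the padded dx that CalPadGrad
+       * then crops back to the original spatial shape.
+       */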
"cudnnGetTensorSizeInBytes failed"); + if (padded_size_ == 0) { + MS_LOG(EXCEPTION) << "Padded size is 0."; + } + workspace_size_list_.push_back(padded_size_); + workspace_size_list_.push_back(padded_size_); + } + return; + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs."; + return false; + } + return true; + } + void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { + n_ = SizeToInt(input_shape[0]); + c_ = SizeToInt(input_shape[1]); + old_height_ = SizeToInt(input_shape[2]); + old_width_ = SizeToInt(input_shape[3]); + pad_height_ = + std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) + : (old_height_ / stride_[2]) + 1) - + 1) * + stride_[2] + + window_height - old_height_); + pad_width_ = + std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) + : (old_width_ / stride_[3]) + 1) - + 1) * + stride_[3] + + window_width - old_width_); + pad_top_ = pad_height_ / 2; + pad_left_ = pad_width_ / 2; + if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { + use_pad_ = false; + } + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, + c_, old_height_ + pad_height_, old_width_ + pad_width_), + "cudnnSetTensor4dDescriptor failed"); + CHECK_CUDNN_RET_WITH_EXCEPT( + cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), + SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0), + SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)), + "cudnnSetTensor4dDescriptor"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, + window_height, window_width, use_pad_ ? 0 : pad_top_, + use_pad_ ? 
0 : pad_left_, stride_[2], stride_[3]), + "cudnnSetPooling2dDescriptor failed"); + } + void SetPoolingMode(const CNodePtr &kernel_node) { + pad_mode_ = GetAttr(kernel_node, "padding"); + stride_ = GetAttr>(kernel_node, "strides"); + cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); + mode_ = AnfAlgo::GetCNodeName(kernel_node); + if (mode_ == "AvgPoolGradGpu") { + pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + pad_value_ = 0.0; + } else { + pooling_mode_ = CUDNN_POOLING_MAX; + pad_value_ = kSignedMinFloat; + } + } + void DestroyResource() noexcept { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), + "cudnnDestroyPoolingDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed"); + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_descriptor_), "cudnnDestroyTensorDescriptor failed"); + } + + cudnnHandle_t cudnn_handle_; + cudnnPoolingDescriptor_t pooling_descriptor_; + cudnnTensorDescriptor_t y_descriptor_; + cudnnTensorDescriptor_t dy_descriptor_; + cudnnTensorDescriptor_t x_descriptor_; + cudnnTensorDescriptor_t dx_descriptor_; + cudnnTensorDescriptor_t padded_descriptor_; + cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; + std::vector stride_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + std::string mode_; + std::string pad_mode_; + cudnnDataType_t cudnn_data_type_; + int old_height_; + int old_width_; + int pad_height_; + int pad_width_; + int pad_top_; + int pad_left_; + int n_; + int c_; + float pad_value_; + bool is_null_input_; + size_t input_size_; + size_t output_size_; + size_t padded_size_; + size_t workspace_size_; + bool use_pad_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc new file mode 100644 index 0000000000..c33909a82b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h new file mode 100644 index 0000000000..9811c71094 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/rmsprop_gpu_kernel.h @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/rmsprop_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class RMSPropGpuKernel : public GpuKernel { + public: + RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {} + ~RMSPropGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream) override { + if (!use_center_) { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_square = GetDeviceAddress(inputs, 1); + T *moment = GetDeviceAddress(inputs, 2); + T *learning_rate = GetDeviceAddress(inputs, 3); + T *gradients = GetDeviceAddress(inputs, 4); + + RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_, + reinterpret_cast(stream)); + } else { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_gradients = GetDeviceAddress(inputs, 1); + T *mean_square = GetDeviceAddress(inputs, 2); + T *moment = GetDeviceAddress(inputs, 3); + T *gradients = GetDeviceAddress(inputs, 4); + T *learning_rate = GetDeviceAddress(inputs, 5); + T *decay = GetDeviceAddress(inputs, 6); + T *momentum = GetDeviceAddress(inputs, 7); + T *epsilon = GetDeviceAddress(inputs, 8); + + RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients, + size_, reinterpret_cast(stream)); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "ApplyCenteredRMSProp") { + use_center_ = true; + } + + if (node_name == "ApplyRMSProp") { + decay_ = GetAttr(kernel_node, "rho"); + momentum_ = GetAttr(kernel_node, "momentum"); + epsilon_ = GetAttr(kernel_node, "epsilon"); + } + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = size_ * sizeof(T); + if (!use_center_) { + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(input_size); + output_size_list_.push_back(input_size); + } else { + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(input_size); + } + } + + private: + size_t size_; + bool use_center_; + float decay_; + float momentum_; + float epsilon_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc new file mode 100644 index 0000000000..96d2d29549 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO( + SigmoidCrossEntropyWithLogits, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + SigmoidCrossEntropyWithLogitsGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h new file mode 100644 index 0000000000..a2d3aabb68 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { + public: + SigmoidCrossEntropyWithLogitsGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} + + ~SigmoidCrossEntropyWithLogitsGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *outputs_addr = GetDeviceAddress(outputs, 0); + + SigmoidCrossEntropyWithLogits(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs."; + return false; + } + logits_size_ = sizeof(T); + labels_size_ = sizeof(S); + outputs_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + logits_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + labels_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + outputs_size_ *= output_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(outputs_size_); + } + + private: + size_t logits_size_; + size_t labels_size_; + size_t outputs_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc new file mode 100644 index 0000000000..05c9a4234b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h new file mode 100644 index 0000000000..88ab46a6ba --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { + public: + SigmoidCrossEntropyWithLogitsGradGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} + ~SigmoidCrossEntropyWithLogitsGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *logits_addr = GetDeviceAddress(inputs, 0); + S *labels_addr = GetDeviceAddress(inputs, 1); + T *outputs_addr = GetDeviceAddress(outputs, 0); + + SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, + reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 3) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs."; + return false; + } + logits_size_ = sizeof(T); + labels_size_ = sizeof(S); + outputs_size_ = sizeof(T); + + auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < logits_shape.size(); i++) { + logits_size_ *= logits_shape[i]; + } + + auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < labels_shape.size(); i++) { + labels_size_ *= labels_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + outputs_size_ *= output_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(logits_size_); + input_size_list_.push_back(labels_size_); + output_size_list_.push_back(outputs_size_); + } + + private: + size_t logits_size_; + size_t labels_size_; + size_t outputs_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc new file mode 100644 index 0000000000..ea40bea6a4 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  SmoothL1Loss,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SmoothL1LossGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h
new file mode 100644
index 0000000000..dc20f75077
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh"
+namespace mindspore {
+namespace kernel {
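+// Elementwise smooth L1 (Huber-style) loss with no reduction: quadratic near
+// zero and linear in the tails, with the switch point controlled by the
+// "sigma" attribute (see smooth_l1_loss_impl.cuh for the exact piecewise form).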
+template <typename T>
+class SmoothL1LossGpuKernel : public GpuKernel {
+ public:
+  SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {}
+  ~SmoothL1LossGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *prediction = GetDeviceAddress<T>(inputs, 0);
+    T *target = GetDeviceAddress<T>(inputs, 1);
+    T *loss = GetDeviceAddress<T>(outputs, 0);
+
+    SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+
+    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  float sigma_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..8a4fb38460
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(SmoothL1LossGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SmoothL1LossGradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
new file mode 100644
index 0000000000..02be336932
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh"
+namespace mindspore {
+namespace kernel {
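+// Backward pass of the elementwise smooth L1 loss: the piecewise derivative
+// with respect to `prediction` is scaled by the incoming gradient `dloss`
+// (see smooth_l1_loss_impl.cuh for the exact form).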
+template <typename T>
+class SmoothL1LossGradGpuKernel : public GpuKernel {
+ public:
+  SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
+  ~SmoothL1LossGradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *prediction = GetDeviceAddress<T>(inputs, 0);
+    T *target = GetDeviceAddress<T>(inputs, 1);
+    T *dloss = GetDeviceAddress<T>(inputs, 2);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+
+    SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shape.size(); i++) {
+      input_size_ *= input_shape[i];
+    }
+
+    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    input_size_list_.push_back(input_size_ * sizeof(T));
+    output_size_list_.push_back(input_size_ * sizeof(T));
+  }
+
+ private:
+  size_t input_size_;
+  float sigma_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc
new file mode 100644
index 0000000000..8a64762c0a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(SoftmaxCrossEntropyWithLogits,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      SoftmaxCrossEntropyWithLogitsGpuKernel, float, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
new file mode 100644
index 0000000000..e56cb96fd7
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
@@ -0,0 +1,205 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
+
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
+ public:
+  SoftmaxCrossEntropyWithLogitsGpuKernel()
+    : cudnn_handle_(nullptr),
+      logits_descriptor_(nullptr),
+      softmax_output_descriptor_(nullptr),
+      algo_(CUDNN_SOFTMAX_ACCURATE),
+      mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
+      cudnn_data_type_(CUDNN_DATA_FLOAT),
+      is_null_input_(false),
+      logits_size_(0),
+      labels_size_(0),
+      output1_size_(0),
+      output2_size_(0),
+      softmax_output_logits_size_(0),
+      batch_size_(0),
+      channel_size_(0),
+      height_(0),
+      width_(0) {}
+  ~SoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *logits_addr = GetDeviceAddress<T>(inputs, 0);
+    S *labels_addr = GetDeviceAddress<S>(inputs, 1);
+    T *loss_addr = GetDeviceAddress<T>(outputs, 0);
+    T *dlogits_addr = GetDeviceAddress<T>(outputs, 1);
+    T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0);
+
+    const float alpha = 1;
+    const float beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta,
+                          softmax_output_descriptor_, softmax_output_logits),
+      "cudnnSoftmaxForward failed.");
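+    // The softmax probabilities now sit in the workspace; the fused CrossEntropy
+    // kernel consumes them once to produce both outputs: the per-sample loss and
+    // the gradient with respect to the logits (softmax output minus label).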
+    CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num
+                    << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 2) {
+      MS_LOG(ERROR) << "Output number is " << output_num
+                    << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 outputs.";
+      return false;
+    }
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+
+    InferInputOutputSize(kernel_node);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           batch_size_, channel_size_, height_, width_),
+                                "cudnnSetTensor4dDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_,
+                                 channel_size_, height_, width_),
+      "cudnnSetTensor4dDescriptor failed.");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+  }
+  void InitSizeLists() override {
+    input_size_list_.push_back(logits_size_);
+    input_size_list_.push_back(labels_size_);
+    output_size_list_.push_back(output1_size_);
+    output_size_list_.push_back(output2_size_);
+    workspace_size_list_.push_back(softmax_output_logits_size_);
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+  }
+  void InferInputOutputSize(const CNodePtr &kernel_node) {
+    auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(logits_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null";
+      InitSizeLists();
+      return;
+    }
+    auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    is_null_input_ = CHECK_NULL_INPUT(labels_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null";
+      InitSizeLists();
+      return;
+    }
+    CheckShapeValidation(logits_shape, labels_shape);
+
+    size_t logits_dims = logits_shape.size();
+    batch_size_ = 1;
+    for (size_t i = 0; i < logits_dims - 1; i++) {
+      batch_size_ *= logits_shape[i];
+    }
+    channel_size_ = logits_shape[logits_dims - 1];
+    height_ = 1;
+    width_ = 1;
+    logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
+
+    labels_size_ = 1;
+    size_t labels_dims = labels_shape.size();
+    for (size_t i = 0; i < labels_dims; i++) {
+      labels_size_ *= labels_shape[i];
+    }
+    labels_size_ *= sizeof(S);
+
+    output1_size_ = logits_size_ / logits_shape[logits_dims - 1];
+    output2_size_ = logits_size_;
+    softmax_output_logits_size_ = logits_size_;
+    return;
+  }
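+  // Dense labels must match the logits shape exactly: same rank and the same
+  // size in every dimension.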
+  void CheckShapeValidation(const std::vector<size_t> &logits_shape, const std::vector<size_t> &labels_shape) {
+    size_t logits_dim_length = logits_shape.size();
+    size_t labels_dim_length = labels_shape.size();
+    if (labels_dim_length != logits_dim_length) {
+      MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length for "
+                           "SoftmaxCrossEntropyWithLogits, but got Labels shape length:"
+                        << labels_dim_length << ", Logits shape length:" << logits_dim_length;
+    }
+    if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) {
+      MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits.";
+    }
+    return;
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  cudnnTensorDescriptor_t logits_descriptor_;
+  cudnnTensorDescriptor_t softmax_output_descriptor_;
+  cudnnSoftmaxAlgorithm_t algo_;
+  cudnnSoftmaxMode_t mode_;
+  cudnnDataType_t cudnn_data_type_;
+  bool is_null_input_;
+
+  size_t logits_size_;
+  size_t labels_size_;
+  size_t output1_size_;
+  size_t output2_size_;
+  size_t softmax_output_logits_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t batch_size_;
+  size_t channel_size_;
+  size_t height_;
+  size_t width_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc
new file mode 100644
index 0000000000..24c2c12601
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SoftmaxGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SoftmaxGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      SoftmaxGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      SoftmaxGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
new file mode 100644
index 0000000000..279bac3aa9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_gpu_kernel.h
@@ -0,0 +1,252 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class SoftmaxGpuKernel : public GpuKernel {
+ public:
+  SoftmaxGpuKernel()
+    : cudnn_handle_(nullptr),
+      input_descriptor_(nullptr),
+      output_descriptor_(nullptr),
+      algo_(CUDNN_SOFTMAX_ACCURATE),
+      mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
+      cudnn_data_type_(CUDNN_DATA_FLOAT),
+      is_null_input_(false),
+      input_size_(0),
+      output_size_(0),
+      workspace_size_(0),
+      axis_(0),
+      shape_size_(0),
+      batch_size_(0),
+      channel_size_(0),
+      height_(0),
+      width_(0) {}
+  ~SoftmaxGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    const float alpha = 1;
+    const float beta = 0;
+
+    if (axis_ == 1) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_,
+                                                      input_addr, &beta, output_descriptor_, output_addr),
+                                  "cudnnSoftmaxForward failed");
+    } else {
+      T *transpose_input_addr = GetDeviceAddress<T>(workspace, 0);
+      T *transpose_output_addr = GetDeviceAddress<T>(workspace, 1);
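+      // cuDNN applies softmax along the channel dimension of the descriptor, so
+      // for any other axis the input is first transposed on device (using the
+      // shape/axis metadata staged in the int workspaces below), softmax is run
+      // on the transposed tensor, and the result is transposed back.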
+      int *input_shape = GetDeviceAddress<int>(workspace, 2);
+      int *transpose_shape = GetDeviceAddress<int>(workspace, 3);
+      int *transpose_axis = GetDeviceAddress<int>(workspace, 4);
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync input_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync transpose_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync transpose_axis failed");
+      int size = SizeToInt(input_size_ / sizeof(T));
+      CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, transpose_input_addr, &beta,
+                            output_descriptor_, transpose_output_addr),
+        "cudnnSoftmaxForward failed");
+      CalTranspose(size, transpose_output_addr, transpose_shape, transpose_axis, shape_size_, output_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SoftmaxGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+    shape_size_ = SizeToInt(input_shape.size());
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kernel_name == "LogSoftmax") {
+      algo_ = CUDNN_SOFTMAX_LOG;
+      auto axis = GetAttr<int>(kernel_node, "axis");
+      InitSizeByAxis(input_shape, axis);
+    } else {
+      algo_ = CUDNN_SOFTMAX_ACCURATE;
+      auto axis = GetAttr<std::vector<int>>(kernel_node, "axis");
+      InitSizeByAxis(input_shape, axis[0]);
+    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
+                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
+      "set input_descriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
+                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
+      "set output_descriptor failed");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "create input_descriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "create output_descriptor failed");
+  }
+
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(input_size_);
+    workspace_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    return;
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "destroy output_descriptor failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "destroy input_descriptor failed");
+  }
+
+  void InitSizeByAxis(const std::vector<size_t> &input_shape, const int &axis) {
+    if (input_shape.size() == 2) {
+      InitSizeByAxis2D(input_shape, axis);
+    } else {
+      InitSizeByAxisLastDim(input_shape, axis);
+    }
+  }
+
+  void InitSizeByAxis2D(const std::vector<size_t> &input_shape, const int &axis) {
+    axis_ = axis;
+    if (axis_ < 0) {
+      axis_ += shape_size_;
+    }
+    if (axis_ == 1) {
+      batch_size_ = input_shape[0];
+      channel_size_ = input_shape[1];
+    } else if (axis_ == 0) {
+      batch_size_ = input_shape[1];
+      channel_size_ = input_shape[0];
+      input_shape_.push_back(input_shape[0]);
+      input_shape_.push_back(input_shape[1]);
+      transpose_shape_.push_back(input_shape[1]);
+      transpose_shape_.push_back(input_shape[0]);
+      transpose_axis_.push_back(1);
+      transpose_axis_.push_back(0);
+    } else {
+      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid.";
+    }
+
+    height_ = 1;
+    width_ = 1;
+    input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
+    output_size_ = input_size_;
+    workspace_size_ = IntToSize(shape_size_) * sizeof(int);
+  }
+
+  void InitSizeByAxisLastDim(const std::vector<size_t> &input_shape, const int &axis) {
+    int axis_pos = axis;
+    if (axis_pos < 0) {
+      axis_pos += input_shape.size();
+    }
+    // axis should be -1 with ND
+    if (axis_pos != SizeToInt(input_shape.size() - 1)) {
+      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid.";
+    }
+    // squeeze to 2d, then invoke cudnn
+    size_t n = 1;
+    for (size_t i = 0; i < input_shape.size() - 1; i++) {
+      n *= input_shape[i];
+    }
+    axis_ = 1;
+    batch_size_ = n;
+    channel_size_ = input_shape[axis_pos];
+    height_ = 1;
+    width_ = 1;
+    input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
+    output_size_ = input_size_;
+    input_shape_.push_back(batch_size_);
+    input_shape_.push_back(channel_size_);
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  cudnnTensorDescriptor_t input_descriptor_;
+  cudnnTensorDescriptor_t output_descriptor_;
+  cudnnSoftmaxAlgorithm_t algo_;
+  cudnnSoftmaxMode_t mode_;
+  cudnnDataType_t cudnn_data_type_;
+  bool is_null_input_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  std::vector<int> input_shape_;
+  std::vector<int> transpose_shape_;
+  std::vector<int> transpose_axis_;
+  int axis_;
+  int shape_size_;
+
+  size_t batch_size_;
+  size_t channel_size_;
+  size_t height_;
+  size_t width_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..bd20413d08
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  LogSoftmaxGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  SoftmaxGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  LogSoftmaxGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  SoftmaxGradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h
new file mode 100644
index 0000000000..b814be9969
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/softmax_grad_gpu_kernel.h
@@ -0,0 +1,219 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
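+// Backward (log-)softmax via cudnnSoftmaxBackward. Only 2-D inputs are accepted
+// here; as in the forward kernel, an on-device transpose to and from cuDNN's
+// channel layout handles axis 0.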
+template <typename T>
+class SoftmaxGradGpuKernel : public GpuKernel {
+ public:
+  SoftmaxGradGpuKernel()
+    : cudnn_handle_(nullptr),
+      y_desc_(nullptr),
+      algo_(CUDNN_SOFTMAX_ACCURATE),
+      mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
+      cudnn_data_type_(CUDNN_DATA_FLOAT),
+      is_null_input_(false),
+      input_size_(0),
+      output_size_(0),
+      workspace_size_(0),
+      axis_(0),
+      shape_size_(0),
+      batch_size_(0),
+      channel_size_(0),
+      height_(0),
+      width_(0) {}
+  ~SoftmaxGradGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *y_addr = GetDeviceAddress<T>(inputs, 0);
+    T *dy_addr = GetDeviceAddress<T>(inputs, 1);
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+
+    T *transpose_y_addr = GetDeviceAddress<T>(workspace, 0);
+    T *transpose_dy_addr = GetDeviceAddress<T>(workspace, 1);
+    T *transpose_dx_addr = GetDeviceAddress<T>(workspace, 2);
+    int *input_shape = GetDeviceAddress<int>(workspace, 3);
+    int *transpose_shape = GetDeviceAddress<int>(workspace, 4);
+    int *transpose_axis = GetDeviceAddress<int>(workspace, 5);
+    const float alpha = 1;
+    const float beta = 0;
+
+    if (axis_ == 1) {
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_,
+                                                       dy_addr, &beta, y_desc_, dx_addr),
+                                  "cudnnSoftmaxBackward failed");
+    } else {
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync input_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync transpose_shape failed");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync transpose_axis failed");
+      int size = SizeToInt(input_size_ / sizeof(T));
+      CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, y_desc_,
+                             transpose_dy_addr, &beta, y_desc_, transpose_dx_addr),
+        "cudnnSoftmaxBackward failed");
+      CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+    shape_size_ = SizeToInt(input_shape.size());
+    if (shape_size_ != 2) {
+      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs.";
+    }
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (kernel_name == "LogSoftmaxGrad") {
+      algo_ = CUDNN_SOFTMAX_LOG;
+      auto axis = GetAttr<int>(kernel_node, "axis");
+      InitSizeByAxis(input_shape, axis);
+    } else {
+      algo_ = CUDNN_SOFTMAX_ACCURATE;
+      auto axis = GetAttr<std::vector<int>>(kernel_node, "axis");
+      InitSizeByAxis(input_shape, axis[0]);
+    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_),
+                                 SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)),
+      "set input_descriptor failed");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed");
+  }
+
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(input_size_);
+    workspace_size_list_.push_back(input_size_);
+    workspace_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    workspace_size_list_.push_back(workspace_size_);
+    return;
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "destroy output_descriptor failed");
+  }
+
+  void InitSizeByAxis(const std::vector<size_t> input_shape, const int axis) {
+    axis_ = axis;
+    if (axis_ < 0) {
+      axis_ += shape_size_;
+    }
+    if (axis_ == 1) {
+      batch_size_ = input_shape[0];
+      channel_size_ = input_shape[1];
+    } else if (axis_ == 0) {
+      batch_size_ = input_shape[1];
+      channel_size_ = input_shape[0];
+      input_shape_.push_back(input_shape[0]);
+      input_shape_.push_back(input_shape[1]);
+      transpose_shape_.push_back(input_shape[1]);
+      transpose_shape_.push_back(input_shape[0]);
+      transpose_axis_.push_back(1);
+      transpose_axis_.push_back(0);
+    } else {
+      MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid.";
+    }
+
+    height_ = 1;
+    width_ = 1;
+    input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
+    output_size_ = input_size_;
+    workspace_size_ = IntToSize(shape_size_) * sizeof(int);
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  cudnnTensorDescriptor_t y_desc_;
+  cudnnSoftmaxAlgorithm_t algo_;
+  cudnnSoftmaxMode_t mode_;
+  cudnnDataType_t cudnn_data_type_;
+  bool is_null_input_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  std::vector<int> input_shape_;
+  std::vector<int> transpose_shape_;
+  std::vector<int> transpose_axis_;
+  int axis_;
+  int shape_size_;
+
+  size_t batch_size_;
+  size_t channel_size_;
+  size_t height_;
+  size_t width_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc
new file mode 100644
index 0000000000..81b46f520c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  SparseSoftmaxCrossEntropyWithLogits,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+  SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(
+  SparseSoftmaxCrossEntropyWithLogits,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+  SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int64_t)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h
new file mode 100644
index 0000000000..bcb8a6b333
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h
@@ -0,0 +1,206 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
+
+#include <vector>
+#include <algorithm>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
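+// Sparse variant: labels are class indices (int32/int64) rather than dense
+// probability vectors, so only one label element is read per sample row.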
+template <typename T, typename S>
+class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
+ public:
+  SparseSoftmaxCrossEntropyWithLogitsGpuKernel()
+    : cudnn_handle_(nullptr),
+      logits_descriptor_(nullptr),
+      softmax_output_descriptor_(nullptr),
+      algo_(CUDNN_SOFTMAX_ACCURATE),
+      mode_(CUDNN_SOFTMAX_MODE_INSTANCE),
+      cudnn_data_type_(CUDNN_DATA_FLOAT),
+      is_grad_(false),
+      is_null_input_(false),
+      logits_size_(0),
+      labels_size_(0),
+      output_size_(0),
+      softmax_output_logits_size_(0),
+      batch_size_(0),
+      channel_size_(0),
+      height_(0),
+      width_(0) {}
+  ~SparseSoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+    T *logits_addr = GetDeviceAddress<T>(inputs, 0);
+    S *labels_addr = GetDeviceAddress<S>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0);
+
+    const float alpha = 1;
+    const float beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta,
+                          softmax_output_descriptor_, softmax_output_logits),
+      "cudnnSoftmaxForward failed.");
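+    // Training graphs fuse loss and gradient into one op: with is_grad_ set the
+    // kernel emits dlogits instead of the reduced scalar loss.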
+    is_grad_ ? CrossEntropyGradWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_,
+                                          output_addr, reinterpret_cast<cudaStream_t>(stream_ptr))
+             : CrossEntropyWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr,
+                                      reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num
+                    << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num
+                    << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 1 output.";
+      return false;
+    }
+    is_grad_ = GetAttr<bool>(kernel_node, "is_grad");
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+
+    InferInputOutputSize(kernel_node);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
+                                                           batch_size_, channel_size_, height_, width_),
+                                "cudnnSetTensor4dDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_,
+                                 channel_size_, height_, width_),
+      "cudnnSetTensor4dDescriptor failed.");
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override {
+    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_),
+                                "cudnnCreateTensorDescriptor failed.");
+  }
+  void InitSizeLists() override {
+    input_size_list_.push_back(logits_size_);
+    input_size_list_.push_back(labels_size_);
+    output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(softmax_output_logits_size_);
+    return;
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_),
+                               "cudnnDestroyTensorDescriptor failed.");
+  }
+  void InferInputOutputSize(const CNodePtr &kernel_node) {
+    auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(logits_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SparseSoftmaxCrossEntropyWithLogitsGpuKernel input1 is null";
+      InitSizeLists();
+      return;
+    }
+    auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    is_null_input_ = CHECK_NULL_INPUT(labels_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "SparseSoftmaxCrossEntropyWithLogitsGpuKernel input2 is null";
+      InitSizeLists();
+      return;
+    }
+    CheckShapeValidation(logits_shape, labels_shape);
+
+    size_t logits_dims = logits_shape.size();
+    batch_size_ = 1;
+    for (size_t i = 0; i < logits_dims - 1; i++) {
+      batch_size_ *= logits_shape[i];
+    }
+    channel_size_ = logits_shape[logits_dims - 1];
+    height_ = 1;
+    width_ = 1;
+    logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
+
+    labels_size_ = 1;
+    size_t labels_dims = labels_shape.size();
+    for (size_t i = 0; i < labels_dims; i++) {
+      labels_size_ *= labels_shape[i];
+    }
+    labels_size_ *= sizeof(S);
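+    // Gradient mode returns a full dlogits tensor; loss mode reduces to a
+    // single scalar element.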
+    output_size_ = is_grad_ ? logits_size_ : sizeof(T);
+    softmax_output_logits_size_ = logits_size_;
+    return;
+  }
+  void CheckShapeValidation(const std::vector<size_t> &logits_shape, const std::vector<size_t> &labels_shape) {
+    size_t logits_dim_length = logits_shape.size();
+    size_t labels_dim_length = labels_shape.size();
+    if (labels_dim_length != logits_dim_length - 1) {
+      MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1 for "
+                           "SparseSoftmaxCrossEntropyWithLogits, but got Labels shape length:"
+                        << labels_dim_length << ", Logits shape length:" << logits_dim_length;
+    }
+    if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) {
+      MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last dimension.";
+    }
+    return;
+  }
+
+  cudnnHandle_t cudnn_handle_;
+  cudnnTensorDescriptor_t logits_descriptor_;
+  cudnnTensorDescriptor_t softmax_output_descriptor_;
+  cudnnSoftmaxAlgorithm_t algo_;
+  cudnnSoftmaxMode_t mode_;
+  cudnnDataType_t cudnn_data_type_;
+  bool is_grad_;
+  bool is_null_input_;
+
+  size_t logits_size_;
+  size_t labels_size_;
+  size_t output_size_;
+  size_t softmax_output_logits_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  size_t batch_size_;
+  size_t channel_size_;
+  size_t height_;
+  size_t width_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc
new file mode 100644
index 0000000000..4e07463a6c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/other/assign_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Assign,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  AssignGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Assign,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  AssignGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  AssignGpuKernel, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h
new file mode 100644
index 0000000000..76e863393c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class AssignGpuKernel : public GpuKernel {
+ public:
+  AssignGpuKernel() : input_size_(0) {}
+  ~AssignGpuKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *var = GetDeviceAddress<T>(inputs, 0);
+    T *value = GetDeviceAddress<T>(inputs, 1);
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpyAsync(var, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemcpyAsync failed.");
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpyAsync(output, value, input_size_, cudaMemcpyDeviceToDevice,
+                      reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "cudaMemcpyAsync failed.");
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    if (!CheckParam(kernel_node)) {
+      return false;
+    }
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = sizeof(T);
+    for (size_t x : shape) {
+      input_size_ = input_size_ * x;
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  bool CheckParam(const CNodePtr &kernel_node) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but AssignGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + + size_t input_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc new file mode 100644 index 0000000000..92652f67f9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(BatchNormFold2, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeFloat32), + BatchNormFold2GpuKernel, float) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h new file mode 100644 index 0000000000..83600e20df --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_gpu_kernel.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+template <typename T>
+class BatchNormFold2GpuKernel : public GpuKernel {
+ public:
+  BatchNormFold2GpuKernel()
+    : cudnn_handle_(nullptr),
+      is_null_input_(false),
+      batch_size_(0),
+      channel_(0),
+      height_(0),
+      width_(0),
+      freeze_bn_(0) {}
+
+  ~BatchNormFold2GpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+
+    auto *input = GetDeviceAddress<T>(inputs, 0);
+    auto *beta = GetDeviceAddress<T>(inputs, 1);
+    auto *gamma = GetDeviceAddress<T>(inputs, 2);
+    auto *batch_std = GetDeviceAddress<T>(inputs, 3);
+    auto *batch_mean = GetDeviceAddress<T>(inputs, 4);
+    auto *running_std = GetDeviceAddress<T>(inputs, 5);
+    auto *running_mean = GetDeviceAddress<T>(inputs, 6);
+    auto *global_step = GetDeviceAddress<int32_t>(inputs, 7);
+    auto *output = GetDeviceAddress<T>(outputs, 0);
+
+    BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output,
+                          freeze_bn_, batch_size_, channel_, height_, width_,
+                          reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 8) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8.";
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W).";
+      return false;
+    }
+    batch_size_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+    freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
+
+  void InitSizeLists() override {
+    size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
+    size_t weight_size = channel_ * sizeof(T);
+    input_size_list_.push_back(input_size);
+    input_size_list_.push_back(weight_size);      // beta
+    input_size_list_.push_back(weight_size);      // gamma
+    input_size_list_.push_back(weight_size);      // batch_std
+    input_size_list_.push_back(weight_size);      // batch_mean
+    input_size_list_.push_back(weight_size);      // running_std
+    input_size_list_.push_back(weight_size);      // running_mean
+    input_size_list_.push_back(sizeof(int32_t));  // global_step
+    output_size_list_.push_back(input_size);
+  }
+
+ private:
+  void DestroyResource() noexcept {}
+
+  cudnnHandle_t cudnn_handle_;
+  bool is_null_input_;
+  size_t batch_size_;
+  size_t channel_;
+  size_t height_;
+  size_t width_;
+  size_t freeze_bn_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..6fc080713a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      BatchNormFold2GradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
new file mode 100644
index 0000000000..3335210925
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold2_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BatchNormFold2GradGpuKernel : public GpuKernel {
+ public:
+  BatchNormFold2GradGpuKernel()
+      : cudnn_handle_(nullptr),
+        is_null_input_(false),
+        batch_size_(0),
+        channel_(0),
+        height_(0),
+        width_(0),
+        freeze_bn_(0) {}
+
+  ~BatchNormFold2GradGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    if (is_null_input_) {
+      return true;
+    }
+
+    auto *dout = GetDeviceAddress<T>(inputs, 0);
+    auto *x = GetDeviceAddress<T>(inputs, 1);
+    auto *gamma = GetDeviceAddress<T>(inputs, 2);
+    auto *batch_std = GetDeviceAddress<T>(inputs, 3);
+    auto *batch_mean = GetDeviceAddress<T>(inputs, 4);
+    auto *running_std = GetDeviceAddress<T>(inputs, 5);
+    auto *running_mean = GetDeviceAddress<T>(inputs, 6);
+    auto *global_step = GetDeviceAddress<int32_t>(inputs, 7);
+    auto *d_batch_std = GetDeviceAddress<T>(outputs, 0);
+    auto *d_batch_mean = GetDeviceAddress<T>(outputs, 1);
+    auto *d_beta = GetDeviceAddress<T>(outputs, 2);
+    auto *d_gamma = GetDeviceAddress<T>(outputs, 3);
+    auto *d_x = GetDeviceAddress<T>(outputs, 4);
+    auto *tmp = GetDeviceAddress<T>(workspace, 0);
+    auto *tmp2 = GetDeviceAddress<T>(workspace, 1);
+    auto *reduce_x = GetDeviceAddress<T>(workspace, 2);
+    auto *tmp_x = GetDeviceAddress<T>(workspace, 3);
+
+    int32_t current_step_host[1];
+    size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Failed to copy gpu memory.");
+    CHECK_CUDA_RET_WITH_ERROR(
+      cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "Failed to copy gpu memory.");
+
+    BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (current_step_host[0] < freeze_bn_) {
+      CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_,
+                                          reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma,
+                                     d_gamma, d_batch_mean, d_batch_std, channel_,
+                                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma,
+                                  d_batch_mean, d_batch_std, channel_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 8) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8.";
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    is_null_input_ = CHECK_NULL_INPUT(input_shape);
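+    // CHECK_NULL_INPUT flags a shape with a zero-sized dimension; such a kernel
+    // still registers its size lists and Launch then returns early as a no-op.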
+    if (is_null_input_) {
+      MS_LOG(WARNING) << "BatchNormFold2GradGpuKernel input is null";
+      InitSizeLists();
+      return true;
+    }
+
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W).";
+      return false;
+    }
+    batch_size_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+    freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
+
+  void InitSizeLists() override {
+    size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
+    size_t weight_size = channel_ * sizeof(T);
+    size_t workspace_size = batch_size_ * channel_ * sizeof(T);
+    input_size_list_.push_back(input_size);       // dout
+    input_size_list_.push_back(input_size);       // x
+    input_size_list_.push_back(weight_size);      // gamma
+    input_size_list_.push_back(weight_size);      // batch_std
+    input_size_list_.push_back(weight_size);      // batch_mean
+    input_size_list_.push_back(weight_size);      // running_std
+    input_size_list_.push_back(weight_size);      // running_mean
+    input_size_list_.push_back(sizeof(int32_t));  // global_step
+
+    output_size_list_.push_back(weight_size);  // d_batch_std
+    output_size_list_.push_back(weight_size);  // d_batch_mean
+    output_size_list_.push_back(weight_size);  // d_beta
+    output_size_list_.push_back(weight_size);  // d_gamma
+    output_size_list_.push_back(input_size);   // d_x
+
+    workspace_size_list_.push_back(workspace_size);  // tmp
+    workspace_size_list_.push_back(workspace_size);  // tmp2
+    workspace_size_list_.push_back(weight_size);     // reduce_x
+    workspace_size_list_.push_back(input_size);      // tmp_x
+  }
+
+ private:
+  void DestroyResource() noexcept {}
+
+  cudnnHandle_t cudnn_handle_;
+  bool is_null_input_;
+  size_t batch_size_;
+  size_t channel_;
+  size_t height_;
+  size_t width_;
+  int32_t freeze_bn_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc
new file mode 100644
index 0000000000..95349c84aa
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(BatchNormFold,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      BatchNormFoldGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
new file mode 100644
index 0000000000..11b150686c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_gpu_kernel.h
@@ -0,0 +1,209 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BatchNormFoldGpuKernel : public GpuKernel {
+ public:
+  BatchNormFoldGpuKernel()
+      : input_size_(0),
+        output_size_(0),
+        exp_avg_factor_(0.9),
+        epsilon_(1e-12),
+        is_training_(true),
+        freeze_bn_(0),
+        batch_(0),
+        channel_(0),
+        height_(0),
+        width_(0),
+        mode_(CUDNN_BATCHNORM_SPATIAL),
+        x_desc_(nullptr),
+        scale_bias_mean_var_desc_(nullptr),
+        handle_(nullptr) {}
+
+  ~BatchNormFoldGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto x = GetDeviceAddress<T>(inputs, 0);
+    auto mean = GetDeviceAddress<T>(inputs, 1);
+    auto variance = GetDeviceAddress<T>(inputs, 2);
+    int *current_step = GetDeviceAddress<int>(inputs, 3);
+    int current_step_host[1];
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Copy gpu memory failed.");
+    if (x == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null.";
+      return false;
+    }
+    if (mean == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null.";
+      return false;
+    }
+    if (variance == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null.";
+      return false;
+    }
+    if (current_step == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null.";
+      return false;
+    }
+    auto batch_mean = GetDeviceAddress<T>(outputs, 0);
+    auto batch_std = GetDeviceAddress<T>(outputs, 1);
+    auto running_mean = GetDeviceAddress<T>(outputs, 2);
+    auto running_std = GetDeviceAddress<T>(outputs, 3);
+    auto y = GetDeviceAddress<T>(workspace, 0);
+
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Failed to copy gpu memory.");
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Failed to copy gpu memory.");
+    // running_std still holds the running variance here; CalUpdateRunningStd
+    // folds in epsilon and converts it to a standard deviation.
+    CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (!is_training_ || current_step_host[0] >= freeze_bn_) {
+      CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory.");
+      ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast<cudaStream_t>(stream_ptr));
+      return true;
+    }
+    const T alpha = 1;
+    const T beta = 0;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y,
+                                             scale_bias_mean_var_desc_, mean, mean, exp_avg_factor_, mean, variance,
+                                             epsilon_, batch_mean, batch_std),
+      "Failed to launch kernel.");
+    CalUpdateBatchStd(channel_, batch_std, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 4) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFold GpuKernel OP needs 4 input.";
+      return false;
+    }
+
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 4) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output.";
+      return false;
+    }
+
+    T momentum = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum"));
+    exp_avg_factor_ = 1.0 - momentum;
+    epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon"));
+    is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training"));
+    freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "Input shape is " << input_shape.size()
+                    << "-D, but BatchNormFold GpuKernel OP needs a 4-D (NCHW) tensor input.";
+      return false;
+    }
+    batch_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+
+    input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_;
+    output_size_ = sizeof(T) * channel_;
+
+    cudnnDataType_t cudnnDataType = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_),
+      "Set x desc failed");
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1),
+      "Set para desc failed");
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    // x, mean, variance, current_step
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(output_size_);
+    input_size_list_.push_back(output_size_);
+    input_size_list_.push_back(sizeof(int));
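(Aside: CalUpdateRunningStd and CalUpdateBatchStd above live in batchnorm_fold_impl.cuh, which is not part of this diff. A host-side sketch of the running-std update they are presumed to perform, with illustrative names, is:)

#include <cmath>
#include <cstddef>

// Presumed host-side equivalent of CalUpdateRunningStd: at the call site the
// buffer still holds the running variance, so fold in epsilon and take the
// square root to turn it into a standard deviation. Semantics assumed, not
// confirmed by this diff.
void UpdateRunningStdHost(size_t channels, double epsilon, float *running_std) {
  for (size_t c = 0; c < channels; ++c) {
    running_std[c] = std::sqrt(running_std[c] + static_cast<float>(epsilon));
  }
}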
+
+    // batch_mean, batch_std, running_mean, running_std
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_);
+
+    // store y
+    workspace_size_list_.push_back(input_size_);
+  }
+
+  void InitResource() override {
+    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed");
+  }
+
+ private:
+  void DestroyResource() noexcept {
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
+    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
+  }
+
+  size_t input_size_;
+  size_t output_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  double exp_avg_factor_;
+  double epsilon_;
+  bool is_training_;
+  int freeze_bn_;
+  int batch_;
+  int channel_;
+  int height_;
+  int width_;
+
+  cudnnBatchNormMode_t mode_;
+  cudnnTensorDescriptor_t x_desc_;
+  cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
+
+  cudnnHandle_t handle_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..b727c6c7df
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      BatchNormFoldGradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
new file mode 100644
index 0000000000..93a3cbf46e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/batchnorm_fold_grad_gpu_kernel.h
@@ -0,0 +1,166 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/batchnorm_fold_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class BatchNormFoldGradGpuKernel : public GpuKernel {
+ public:
+  BatchNormFoldGradGpuKernel()
+      : input_size_(0),
+        channel_size_(0),
+        workspace_size_(0),
+        momentum_(0.1),
+        epsilon_(1e-12),
+        is_training_(true),
+        freeze_bn_(0),
+        current_step_(0),
+        batch_(0),
+        channel_(0),
+        height_(0),
+        width_(0) {}
+  ~BatchNormFoldGradGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
+    T *d_batch_mean = GetDeviceAddress<T>(inputs, 0);
+    T *d_batch_std = GetDeviceAddress<T>(inputs, 1);
+    T *x = GetDeviceAddress<T>(inputs, 2);
+    T *batch_mean = GetDeviceAddress<T>(inputs, 3);
+    T *batch_std = GetDeviceAddress<T>(inputs, 4);
+    int *current_step = GetDeviceAddress<int>(inputs, 5);
+    int current_step_host[1];
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Copy gpu memory failed.");
+    if (d_batch_mean == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null.";
+      return false;
+    }
+    if (d_batch_std == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null.";
+      return false;
+    }
+    if (x == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null.";
+      return false;
+    }
+    if (batch_mean == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null.";
+      return false;
+    }
+    if (batch_std == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null.";
+      return false;
+    }
+    if (current_step == nullptr) {
+      MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null.";
+      return false;
+    }
+    T *dx = GetDeviceAddress<T>(outputs, 0);
+
+    if (!is_training_ || current_step_host[0] >= freeze_bn_) {
+      ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
+      return true;
+    }
+    CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 6) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input.";
+      return false;
+    }
+
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 1 output.";
+      return false;
+    }
+
+    epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon"));
+    is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training"));
+    freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "Input shape is " << input_shape.size()
+                    << "-D, but BatchNormFoldGrad GpuKernel OP needs a 4-D (NCHW) tensor input.";
+      return false;
+    }
+    batch_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+
+    input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_;
+    channel_size_ = sizeof(T) * channel_;
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
+    input_size_list_.push_back(channel_size_);
+    input_size_list_.push_back(channel_size_);
+    input_size_list_.push_back(input_size_);
+    input_size_list_.push_back(channel_size_);
+    input_size_list_.push_back(channel_size_);
+    input_size_list_.push_back(sizeof(int));
+    // 'dx'
+    output_size_list_.push_back(input_size_);
+  }
+
+ private:
+  size_t input_size_;
+  size_t channel_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  T momentum_;
+  T epsilon_;
+  bool is_training_;
+  int freeze_bn_;
+  int current_step_;
+  int batch_;
+  int channel_;
+  int height_;
+  int width_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc
new file mode 100644
index 0000000000..9af5451c53
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(CorrectionMul,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      CorrectionMulGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h
new file mode 100644
index 0000000000..4ba6285e4b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_gpu_kernel.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class CorrectionMulGpuKernel : public GpuKernel {
+ public:
+  CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
+  ~CorrectionMulGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto *weight = GetDeviceAddress<T>(inputs, 0);
+    auto *gamma = GetDeviceAddress<T>(inputs, 1);
+    auto *running_std = GetDeviceAddress<T>(inputs, 2);
+    auto *output = GetDeviceAddress<T>(outputs, 0);
+
+    CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3.";
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W).";
+      return false;
+    }
+    batch_size_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
+    // gamma and running_std are sized by dim 0 of the weight
+    size_t weight_size = batch_size_ * sizeof(T);
+    input_size_list_.push_back(input_size);   // weight
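(Aside: the CUDA code behind CalCorrectionMul in the Launch above is in correction_mul_impl.cuh and is not part of this diff. A host-side sketch of the presumed per-channel fold, under the assumption that dim 0 of the weight indexes output channels, with illustrative names:)

#include <cstddef>

// Presumed semantics of CalCorrectionMul: each element in output channel i of
// the weight is scaled by gamma[i] / running_std[i]; `chw` is the element
// count of one channel slice. Illustrative sketch only.
void CorrectionMulHost(const float *weight, const float *gamma, const float *running_std,
                       size_t n, size_t chw, float *out) {
  for (size_t i = 0; i < n; ++i) {
    const float k = gamma[i] / running_std[i];
    for (size_t j = 0; j < chw; ++j) {
      out[i * chw + j] = weight[i * chw + j] * k;
    }
  }
}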
+    input_size_list_.push_back(weight_size);  // gamma
+    input_size_list_.push_back(weight_size);  // running_std
+    output_size_list_.push_back(input_size);
+  }
+
+  void InitResource() override {}
+
+ private:
+  void DestroyResource() noexcept {}
+
+  size_t batch_size_;
+  size_t channel_;
+  size_t height_;
+  size_t width_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..63a47bc452
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.cc
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      CorrectionMulGradGpuKernel, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h
new file mode 100644
index 0000000000..b9fcbf0787
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/correction_mul_grad_gpu_kernel.h
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/correction_mul_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class CorrectionMulGradGpuKernel : public GpuKernel {
+ public:
+  CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
+  ~CorrectionMulGradGpuKernel() override { DestroyResource(); }
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto *d_out = GetDeviceAddress<T>(inputs, 0);
+    auto *weight = GetDeviceAddress<T>(inputs, 1);
+    auto *gamma = GetDeviceAddress<T>(inputs, 2);
+    auto *running_std = GetDeviceAddress<T>(inputs, 3);
+    auto *d_weight = GetDeviceAddress<T>(outputs, 0);
+    auto *d_gamma = GetDeviceAddress<T>(outputs, 1);
+    auto *tmp = GetDeviceAddress<T>(workspace, 0);
+
+    CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    InitResource();
+
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 4) {
+      MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4.";
+      return false;
+    }
+
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    if (input_shape.size() != 4) {
+      MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W).";
+      return false;
+    }
+    batch_size_ = input_shape[0];
+    channel_ = input_shape[1];
+    height_ = input_shape[2];
+    width_ = input_shape[3];
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
+    size_t weight_size = batch_size_ * sizeof(T);
+    input_size_list_.push_back(input_size);      // d_out
+    input_size_list_.push_back(input_size);      // weight
+    input_size_list_.push_back(weight_size);     // gamma
+    input_size_list_.push_back(weight_size);     // running_std
+    output_size_list_.push_back(input_size);     // d_weight
+    output_size_list_.push_back(weight_size);    // d_gamma
+    workspace_size_list_.push_back(input_size);  // tmp d_out * weight
+  }
+  void InitResource() override {}
+
+ private:
+  void DestroyResource() noexcept {}
+
+  size_t batch_size_;
+  size_t channel_;
+  size_t height_;
+  size_t width_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
new file mode 100644
index 0000000000..8a43ce0941
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
@@ -0,0 +1,147 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
+#include <thrust/extrema.h>
+#include <thrust/pair.h>
+#include <thrust/device_vector.h>
+#include <cuda_runtime_api.h>
+
+namespace mindspore {
+namespace kernel {
+FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
+    : input_size_(0),
+      num_channels_(0),
+      num_bits_(0),
+      training_(false),
+      symmetric_(false),
+      narrow_range_(false),
+      quant_delay_(0),
+      quant_min_(0),
+      quant_max_(0),
+      global_step_(0) {}
+
+const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 3) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input.";
+    return false;
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
+    return false;
+  }
+
+  // get attribute
+  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
+  training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
+
+  if (num_bits_ <= 2 || num_bits_ >= 16) {
+    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
+    return false;
+  }
+
+  if (quant_delay_ < 0) {
+    MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << quant_delay_ << " is less than 0, but it must be non-negative.";
+    return false;
+  }
+
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
+  if (narrow_range_) {
+    quant_min_++;
+  }
+
+  // shape info for gpu
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  num_channels_ = SizeToInt(input_shape[0]);
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void FakeQuantPerChannelGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);                    // input tensor
+  input_size_list_.push_back(sizeof(float) * num_channels_);  // min, one per channel
+  input_size_list_.push_back(sizeof(float) * num_channels_);  // max, one per channel
+  output_size_list_.push_back(input_size_);                       // output tensor
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // scale in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // min in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // max in channel
+}
+
+void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
+                                                   float *nudge_min, float *nudge_max, float *scale,
+                                                   void *stream_ptr) {
+  CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
+                     symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+}
+
+bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                          const std::vector<AddressPtr> &workspace,
+                                          const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output = GetDeviceAddress<float>(outputs, 0);
+  float *input = GetDeviceAddress<float>(inputs, 0);
+  float *input_min = GetDeviceAddress<float>(inputs, 1);
+  float *input_max = GetDeviceAddress<float>(inputs, 2);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
+
+  if (input == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
+  }
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
+  }
+
+  if (training_) {
+    if (global_step_ >= quant_delay_) {
+      CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
+    } else {
+      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
+                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                "Copy gpu memory failed.");
+    }
+    global_step_++;
+  } else {
+    CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
+  }
+
+  return true;
+}
+
+MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h
new file mode 100755
index 0000000000..8e2c9524b2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_gpu_kernel.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class FakeQuantPerChannelGpuKernel : public GpuKernel {
+ public:
+  FakeQuantPerChannelGpuKernel();
+  ~FakeQuantPerChannelGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
+                       float *nudge_max, float *scale, void *stream_ptr);
+
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int num_channels_;
+  int num_bits_;
+  bool training_;
+  bool symmetric_;
+  bool narrow_range_;
+  int quant_delay_;
+  float quant_min_;
+  float quant_max_;
+  int global_step_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..598a6a960d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
@@ -0,0 +1,136 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
+    : input_size_(0),
+      num_bits_(0),
+      quant_min_(0),
+      quant_max_(0),
+      num_channels_(0),
+      quant_delay_(0),
+      global_step_(0),
+      narrow_range_(false),
+      symmetric_(false) {}
+
+const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const {
+  return workspace_size_list_;
+}
+
+bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 4) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 input.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
+  }
+
+  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
+  if (num_bits_ <= 2 || num_bits_ >= 16) {
+    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
+  }
+
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
+  if (quant_delay_ < 0) {
+    MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << quant_delay_ << " is less than 0, but it must be non-negative.";
+  }
+
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
+  if (narrow_range_) {
+    quant_min_++;
+  }
+
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  num_channels_ = SizeToInt(input_shape[0]);
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);                        // gradient
+  input_size_list_.push_back(input_size_);                        // input
+  input_size_list_.push_back(sizeof(float) * num_channels_);      // min
+  input_size_list_.push_back(sizeof(float) * num_channels_);      // max
+  output_size_list_.push_back(input_size_);                       // output
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // scale in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // min in channel
+  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // max in channel
+}
+
+bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                              const std::vector<AddressPtr> &workspace,
+                                              const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output = GetDeviceAddress<float>(outputs, 0);
+  float *gradient = GetDeviceAddress<float>(inputs, 0);
+  float *input = GetDeviceAddress<float>(inputs, 1);
+  float *input_min = GetDeviceAddress<float>(inputs, 2);
+  float *input_max = GetDeviceAddress<float>(inputs, 3);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
+
+  if (gradient == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
+  }
+  if (input == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null";
+  }
+  if (input_min == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null";
+  }
+  if (input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null";
+  }
+
+  int total_size = input_size_ / sizeof(float);
+  if (global_step_ >= quant_delay_) {
+    CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
+                       symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
+                               reinterpret_cast<cudaStream_t>(stream_ptr));
+  } else {
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Copy gpu memory failed.");
+  }
+  global_step_++;
+  return true;
+}
+
+MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h
new file mode 100644
index 0000000000..c2611ab8a2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
+ public:
+  FakeQuantPerChannelGradGpuKernel();
+  ~FakeQuantPerChannelGradGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel_node) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int num_bits_;
+  float quant_min_;
+  float quant_max_;
+  int num_channels_;
+  int quant_delay_;
+  int global_step_;
+  bool narrow_range_;
+  bool symmetric_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc
new file mode 100644
index 0000000000..24edec97a9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.cc
@@ -0,0 +1,143 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
+#include <thrust/extrema.h>
+#include <thrust/pair.h>
+#include <thrust/device_vector.h>
+#include <cuda_runtime_api.h>
+
+namespace mindspore {
+namespace kernel {
+FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
+    : input_size_(0),
+      quant_min_(0),
+      quant_max_(0),
+      quant_num_(1),
+      global_step_(0),
+      num_bits_(0),
+      quant_delay_(0),
+      training_(false),
+      narrow_range_(false),
+      symmetric_(false) {}
+
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 3) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
+  }
+
+  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
+  training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+
+  if (num_bits_ <= 2 || num_bits_ >= 16) {
+    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
+  }
+
+  if (quant_delay_ < 0) {
+    MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << quant_delay_ << " is less than 0, but it must be non-negative.";
+  }
+
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
+  if (narrow_range_) {
+    quant_min_++;
+  }
+
+  // init size
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    quant_num_ *= SizeToInt(input_shape[i]);
+  }
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void FakeQuantPerLayerGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);        // x
+  input_size_list_.push_back(sizeof(float));      // min
+  input_size_list_.push_back(sizeof(float));      // max
+  output_size_list_.push_back(input_size_);       // y
+  workspace_size_list_.push_back(sizeof(float));  // scale
+  workspace_size_list_.push_back(sizeof(float));  // nudge_min
+  workspace_size_list_.push_back(sizeof(float));  // nudge_max
+}
+
+bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                        const std::vector<AddressPtr> &workspace,
+                                        const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output = GetDeviceAddress<float>(outputs, 0);
+  float *input = GetDeviceAddress<float>(inputs, 0);
+  float *input_min = GetDeviceAddress<float>(inputs, 1);
+  float *input_max = GetDeviceAddress<float>(inputs, 2);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
+
+  if (input == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
+  }
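(Aside: CalNudgePerLayer and CalFakeQuantPerLayer invoked just below are implemented in fake_quant_perlayer_impl.cuh, which this diff does not contain. Assuming they follow the usual TF-style fake-quant construction, a host-side sketch of the min/max nudging, with illustrative names, is:)

#include <cmath>

// Assumed TF-style nudging: derive a scale from the float range and the integer
// range, round the zero point onto the integer grid, then rebuild nudged float
// bounds so that zero is exactly representable. Requires max > min; semantics
// and names are illustrative, not confirmed by this diff.
void NudgeMinMaxHost(float min, float max, float quant_min, float quant_max,
                     float *nudge_min, float *nudge_max, float *scale) {
  *scale = (max - min) / (quant_max - quant_min);
  const float zp_from_min = quant_min - min / *scale;
  float nudged_zp;
  if (zp_from_min <= quant_min) {
    nudged_zp = quant_min;
  } else if (zp_from_min >= quant_max) {
    nudged_zp = quant_max;
  } else {
    nudged_zp = std::round(zp_from_min);
  }
  *nudge_min = (quant_min - nudged_zp) * (*scale);
  *nudge_max = (quant_max - nudged_zp) * (*scale);
}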
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
+  }
+
+  if (training_) {
+    // control flow for quant_delay
+    if (global_step_ >= quant_delay_) {
+      // real launch
+      CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
+    } else {
+      CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
+                                                reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                "Copy gpu memory failed");
+    }
+    global_step_++;
+  } else {
+    // real launch
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+  }
+
+  return true;
+}
+
+MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
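Note: the device-side helpers CalNudgePerLayer and CalFakeQuantPerLayer used above live in fake_quant_perlayer_impl.cuh and are not part of this diff. As a reading aid, here is a minimal host-side sketch of the nudged-range fake-quantization scheme such kernels typically implement (the function name and the exact nudging rule are assumptions, not the project's API):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Hypothetical host-side illustration of per-layer fake quantization.
    // Nudge the float range so that zero is exactly representable, then
    // quantize/dequantize each value ("fake" quantization keeps float dtype).
    void FakeQuantPerLayerRef(const float *in, float *out, int n, float in_min, float in_max,
                              float quant_min, float quant_max, bool symmetric) {
      if (symmetric) {
        // symmetric mode centers the range around zero
        in_max = std::max(std::fabs(in_min), std::fabs(in_max));
        in_min = -in_max;
      }
      float scale = (in_max - in_min) / (quant_max - quant_min);
      float zp_from_min = quant_min - in_min / scale;
      float nudged_zp = std::round(std::min(std::max(zp_from_min, quant_min), quant_max));
      float nudge_min = (quant_min - nudged_zp) * scale;
      float nudge_max = (quant_max - nudged_zp) * scale;
      for (int i = 0; i < n; ++i) {
        float clamped = std::min(std::max(in[i], nudge_min), nudge_max);
        out[i] = std::round((clamped - nudge_min) / scale) * scale + nudge_min;
      }
    }

    int main() {
      const float x[4] = {-1.2f, -0.3f, 0.4f, 2.7f};
      float y[4];
      // num_bits = 8, narrow_range = false -> quant range [0, 255]
      FakeQuantPerLayerRef(x, y, 4, -1.0f, 2.0f, 0.0f, 255.0f, false);
      for (float v : y) std::printf("%f\n", v);
      return 0;
    }

This also explains the quant range computed in Init: with num_bits bits the range is [0, 2^num_bits - 1], and narrow_range drops the lowest code so the range becomes [1, 2^num_bits - 1].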
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h
new file mode 100755
index 0000000000..6df4da3104
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_gpu_kernel.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class FakeQuantPerLayerGpuKernel : public GpuKernel {
+ public:
+  FakeQuantPerLayerGpuKernel();
+  ~FakeQuantPerLayerGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  float quant_min_;
+  float quant_max_;
+  int quant_num_;
+  int global_step_;
+  int num_bits_;
+  int quant_delay_;
+  bool training_;
+  bool narrow_range_;
+  bool symmetric_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc
new file mode 100644
index 0000000000..f96b6a48d2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc
@@ -0,0 +1,133 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
+    : input_size_(0),
+      workspace_size_(0),
+      num_bits_(0),
+      quant_min_(0),
+      quant_max_(0),
+      quant_num_(1),
+      quant_delay_(0),
+      global_step_(0),
+      narrow_range_(false),
+      symmetric_(false) {}
+
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 4) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 inputs.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 1) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
+  }
+
+  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
+  if (num_bits_ <= 2 || num_bits_ >= 16) {
+    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected greater than 2 and less than 16.";
+  }
+
+  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
+  if (quant_delay_ < 0) {
+    MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << quant_delay_ << " is less than 0, required to be non-negative.";
+  }
+
+  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
+  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
+
+  // quant min and max value
+  quant_min_ = 0;
+  quant_max_ = (1 << num_bits_) - 1;
+  if (narrow_range_) {
+    quant_min_++;
+  }
+
+  // init size
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    quant_num_ *= SizeToInt(input_shape[i]);
+  }
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);        // gradient
+  input_size_list_.push_back(input_size_);        // input
+  input_size_list_.push_back(sizeof(float));      // min
+  input_size_list_.push_back(sizeof(float));      // max
+  output_size_list_.push_back(input_size_);       // output
+  workspace_size_list_.push_back(sizeof(float));  // scale
+  workspace_size_list_.push_back(sizeof(float));  // nudge_min
+  workspace_size_list_.push_back(sizeof(float));  // nudge_max
+}
+
+bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                            const std::vector<AddressPtr> &workspace,
+                                            const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output = GetDeviceAddress<float>(outputs, 0);
+  float *gradient = GetDeviceAddress<float>(inputs, 0);
+  float *input = GetDeviceAddress<float>(inputs, 1);
+  float *input_min = GetDeviceAddress<float>(inputs, 2);
+  float *input_max = GetDeviceAddress<float>(inputs, 3);
+  float *scale = GetDeviceAddress<float>(workspace, 0);
+  float *nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *nudge_max = GetDeviceAddress<float>(workspace, 2);
+
+  if (gradient == nullptr) {
+    MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null.";
+  }
"FakeQuantPerLayerGradGpuKernel gradient is null"; + } + if (input == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null."; + } + if (input_min == nullptr || input_max == nullptr) { + MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null."; + } + + if (global_step_ >= quant_delay_) { + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); + } else { + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); + } + global_step_++; + return true; +} + +MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h new file mode 100644 index 0000000000..475723f684 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h
new file mode 100644
index 0000000000..475723f684
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class FakeQuantPerLayerGradGpuKernel : public GpuKernel {
+ public:
+  FakeQuantPerLayerGradGpuKernel();
+  ~FakeQuantPerLayerGradGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel_node) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  size_t input_size_;
+  size_t workspace_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int num_bits_;
+  float quant_min_;
+  float quant_max_;
+  int quant_num_;
+  int quant_delay_;
+  int global_step_;
+  bool narrow_range_;
+  bool symmetric_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
new file mode 100644
index 0000000000..742a9b8c55
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh"
+#include <thrust/extrema.h>
+#include <thrust/pair.h>
+#include <thrust/device_vector.h>
+#include <cuda_runtime_api.h>
+
+namespace mindspore {
+namespace kernel {
+MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
+    : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {}
+
+const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
+  return workspace_size_list_;
+}
+
+bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 3) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MinMaxUpdatePerChannel GpuKernel OP needs 3 inputs.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 2) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num
+                      << ", but MinMaxUpdatePerChannel GpuKernel OP needs 2 outputs.";
+  }
+
+  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
+  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
+
+  // init size
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  num_channels_ = SizeToInt(input_shape[0]);
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    quant_num_ *= SizeToInt(input_shape[i]);
+  }
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);                     // input
+  input_size_list_.push_back(sizeof(float) * num_channels_);   // min
+  input_size_list_.push_back(sizeof(float) * num_channels_);   // max
+  output_size_list_.push_back(sizeof(float) * num_channels_);  // output min
+  output_size_list_.push_back(sizeof(float) * num_channels_);  // output max
+}
+
+bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                                             const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output_min = GetDeviceAddress<float>(outputs, 0);
+  float *output_max = GetDeviceAddress<float>(outputs, 1);
+  float *input = GetDeviceAddress<float>(inputs, 0);
+  float *input_min = GetDeviceAddress<float>(inputs, 1);
+  float *input_max = GetDeviceAddress<float>(inputs, 2);
+
+  if (input == nullptr) {
+    MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
+  }
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
+  }
+
+  // update the running min and max according to the parameters ema and ema_decay
+  CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
+                      ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  return true;
+}
+
+MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h
new file mode 100644
index 0000000000..9a0fe23e6a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perchannel_gpu_kernel.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
+ public:
+  MinMaxUpdatePerChannelGpuKernel();
+  ~MinMaxUpdatePerChannelGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int quant_num_;
+  bool ema_;
+  float ema_decay_;
+  int num_channels_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
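Both MinMaxUpdatePerChannel above and the per-layer variant below delegate the actual reduction to CalMinMaxPerChannel / CalMinMaxPerLayer in minmax_update_impl.cuh. A minimal host-side sketch of the update rule suggested by the ema / ema_decay attributes follows (the zero-clamping step at the end is an assumption commonly made so that zero stays exactly representable; it is not confirmed by this diff):

    #include <algorithm>
    #include <cstdio>

    // Hypothetical illustration of the min/max update rule.
    // With ema enabled the observed range is smoothed exponentially,
    // otherwise the current batch range is taken directly.
    void MinMaxUpdateRef(float batch_min, float batch_max, float *running_min, float *running_max,
                         bool ema, float ema_decay) {
      if (ema) {
        *running_min = *running_min * ema_decay + batch_min * (1 - ema_decay);
        *running_max = *running_max * ema_decay + batch_max * (1 - ema_decay);
      } else {
        *running_min = batch_min;
        *running_max = batch_max;
      }
      // assumed: keep zero inside the range so it stays exactly representable
      *running_min = std::min(*running_min, 0.0f);
      *running_max = std::max(*running_max, 0.0f);
    }

    int main() {
      float mn = -1.0f, mx = 2.0f;
      MinMaxUpdateRef(-0.5f, 3.0f, &mn, &mx, true, 0.9f);
      std::printf("%f %f\n", mn, mx);  // -0.95 2.1
      return 0;
    }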
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
new file mode 100644
index 0000000000..8f11e907e1
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/minmax_update_impl.cuh"
+#include <thrust/extrema.h>
+#include <thrust/pair.h>
+#include <thrust/device_vector.h>
+#include <cuda_runtime_api.h>
+
+namespace mindspore {
+namespace kernel {
+MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
+    : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {}
+
+const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (input_num != 3) {
+    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MinMaxUpdatePerLayer GpuKernel OP needs 3 inputs.";
+  }
+
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  if (output_num != 2) {
+    MS_LOG(EXCEPTION) << "Output number is " << output_num
+                      << ", but MinMaxUpdatePerLayer GpuKernel OP needs 2 outputs.";
+  }
+
+  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
+  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
+
+  // init size
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+  for (size_t i = 0; i < input_shape.size(); ++i) {
+    quant_num_ *= SizeToInt(input_shape[i]);
+  }
+  input_size_ = sizeof(float);
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    input_size_ *= input_shape[i];
+  }
+  InitSizeLists();
+  return true;
+}
+
+void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() {
+  input_size_list_.push_back(input_size_);     // input
+  input_size_list_.push_back(sizeof(float));   // input min
+  input_size_list_.push_back(sizeof(float));   // input max
+  output_size_list_.push_back(sizeof(float));  // output min
+  output_size_list_.push_back(sizeof(float));  // output max
+}
+
+bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                                           const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  float *output_min = GetDeviceAddress<float>(outputs, 0);
+  float *output_max = GetDeviceAddress<float>(outputs, 1);
+  float *input = GetDeviceAddress<float>(inputs, 0);
+  float *input_min = GetDeviceAddress<float>(inputs, 1);
+  float *input_max = GetDeviceAddress<float>(inputs, 2);
+
+  if (input == nullptr) {
+    MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null.";
+  }
+  if (input_min == nullptr || input_max == nullptr) {
+    MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
+  }
+
+  CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
+
+  return true;
+}
+
+MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h
new file mode 100644
index 0000000000..80ce6185c0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/quant/minmax_update_perlayer_gpu_kernel.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
+ public:
+  MinMaxUpdatePerLayerGpuKernel();
+  ~MinMaxUpdatePerLayerGpuKernel() = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  bool Init(const CNodePtr &kernel) override;
+
+ protected:
+  void InitSizeLists() override;
+
+ private:
+  size_t input_size_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int quant_num_;
+  bool ema_;
+  float ema_decay_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
new file mode 100644
index 0000000000..5ec4f52574
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+#include "runtime/device/ascend/tasksink/runtime_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "utils/context/ms_context.h"
+
+using ge::model_runner::HcclTaskInfo;
+using HcclTaskInfoPtr = std::shared_ptr<HcclTaskInfo>;
+using mindspore::device::ascend::tasksink::RuntimeUtils;
+
+namespace mindspore {
+namespace kernel {
+void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
+  hcclKernelMap_.emplace(name, std::move(fun));
+}
+
+std::shared_ptr<HcclKernel> HcclKernelFactory::Get(const std::string &name) {
+  const auto &map = Get().hcclKernelMap_;
+  auto it = map.find(name);
+  if (it != map.end() && it->second) {
+    return (it->second)();
+  }
+  return nullptr;
+}
+
+HcclKernelFactory &HcclKernelFactory::Get() {
+  static HcclKernelFactory _this;
+  return _this;
+}
+
+HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {}
+
+HcclKernel::~HcclKernel() {
+  hccl_kernel_input_shape_list_.clear();
+  hccl_kernel_output_shape_list_.clear();
+  hccl_data_type_list_.clear();
+  hccl_count_ = 0;
+  op_type_ = HCCL_REP_OP_SUM;
+  root_id_ = 0;
+  input_size_list_.clear();
+  output_size_list_.clear();
+  workspace_size_list_.clear();
+  anf_node_ = nullptr;
+}
+
+bool HcclKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  op_name_ = AnfAlgo::GetCNodeName(anf_node);
+
+  if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) {
+    MS_LOG(ERROR) << "GetKernelInputShape fail!";
+    return false;
+  }
+  if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) {
+    MS_LOG(ERROR) << "GetKernelOutputShape fail!";
+    return false;
+  }
+  if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) {
+    MS_LOG(ERROR) << "GetHcomDataType fail!";
+    return false;
+  }
+  if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) {
+    MS_LOG(ERROR) << "GetHcomCount fail!";
+    return false;
+  }
+  if (op_name_ == kAllReduce || op_name_ == kReduceScatter) {
+    if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) {
+      MS_LOG(ERROR) << "GetHcomOperationType fail!";
+      return false;
+    }
+  }
+  if (op_name_ == kBroadcast) {
+    if (!HcomUtil::GetHcomRootId(anf_node, &root_id_)) {
+      MS_LOG(ERROR) << "GetHcomRootId fail!";
+      return false;
+    }
+  }
+  HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
+  anf_node_ = anf_node;
+  return true;
+}
+
+const std::vector<size_t> &HcclKernel::GetInputSizeList() const {
+  size_t size = 0;
+  if (!input_size_list_.empty()) {
+    return input_size_list_;
+  }
+  for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
+    if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_input_shape_list_[i], &size)) {
+      MS_LOG(ERROR) << "GetHcclOpInputSize failed";
+    }
+    input_size_list_.push_back(size);
+  }
+  return input_size_list_;
+}
+
+const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
+  size_t size = 0;
+  if (!output_size_list_.empty()) {
+    return output_size_list_;
+  }
+  for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
+    if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) {
+      MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
+    }
+    output_size_list_.push_back(size);
+  }
+  return output_size_list_;
+}
+
+const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+
+std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
+                                             const std::vector<AddressPtr> &workspace,
+                                             const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
+  if (inputs.empty() || outputs.empty()) {
+    MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
+  }
+  stream_id_ = stream_id;
+  std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
+  MS_EXCEPTION_IF_NULL(inputs.at(0));
+  auto input_data_addr = inputs.at(0)->addr;
+  MS_EXCEPTION_IF_NULL(outputs.at(0));
+  auto output_data_addr = outputs.at(0)->addr;
+  void *workspace_address = nullptr;
+  const int64_t workspace_num = 0;
+  std::vector<uint8_t> private_def;
+  hcclDataType_t data_type = hccl_data_type_list_[0];
+
+  MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_
+               << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_)
+               << ", data_type=" << static_cast<int>(data_type);
+
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
+    kernel_name_, stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0,
+    private_def, nullptr, hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel,
+    RuntimeUtils::HcomUnbindModel, RuntimeUtils::HcomDistribute, NeedDump());
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  return {task_info_ptr};
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h
new file mode 100644
index 0000000000..db7a0fbf7c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/kernel_compiler/ascend_kernel_mod.h"
+#include "backend/kernel_compiler/hccl/hcom_util.h"
+#include "hccl/hcom.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace kernel {
+class HcclKernel : public AscendKernelMod {
+ public:
+  HcclKernel();
+  ~HcclKernel() override;
+  virtual bool Init(const AnfNodePtr &anf_node);
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+
+ protected:
+  std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_;
+  std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_;
+  std::vector<hcclDataType_t> hccl_data_type_list_;
+  std::vector<std::string> hccl_format_list_;
+  uint64_t hccl_count_;
+  hcclRedOp_t op_type_;
+  uint32_t root_id_;
+  mutable std::vector<size_t> input_size_list_;
+  mutable std::vector<size_t> output_size_list_;
+  mutable std::vector<size_t> workspace_size_list_;
+  AnfNodePtr anf_node_;
+  std::string op_name_;
+  std::string group_;
+};
+
+using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
+
+class HcclKernelFactory {
+  HcclKernelFactory() = default;
+  ~HcclKernelFactory() = default;
+
+ public:
+  static HcclKernelFactory &Get();
+  void Registe(const string &name, HcclKernelCreater &&fun);
+  static std::shared_ptr<HcclKernel> Get(const string &name);
+
+ private:
+  std::map<string, HcclKernelCreater> hcclKernelMap_;
+};
+
+class _HcclKernelRegister {
+ public:
+  _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
+    HcclKernelFactory::Get().Registe(name, std::move(fun));
+  }
+  ~_HcclKernelRegister() = default;
+};
+
+#define _MS_HCCL_REG_KERNEL_REG(KNAME, clazz)                                               \
+  static_assert(std::is_base_of<HcclKernel, clazz>::value, " must be base of HcclKernel"); \
+  static const _HcclKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() {                \
+    std::shared_ptr<clazz> ptr = nullptr;                                                  \
+    ptr = std::make_shared<clazz>();                                                       \
+    MS_EXCEPTION_IF_NULL(ptr);                                                             \
+    return ptr;                                                                            \
+  });
+
+#define MS_HCCL_REG_KERNEL(KNAME, clazz) _MS_HCCL_REG_KERNEL_REG(KNAME, clazz)
+}  // namespace kernel
+}  // namespace mindspore
+#endif
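MS_HCCL_REG_KERNEL relies on a file-scope _HcclKernelRegister object whose constructor runs during static initialization, so every translation unit that defines a kernel registers itself with HcclKernelFactory before main(). A self-contained sketch of the same idiom, with illustrative names rather than the project's API:

    #include <functional>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    // Minimal standalone sketch of the static-registration idiom used by
    // HcclKernelFactory / MS_HCCL_REG_KERNEL (all names here are stand-ins).
    struct Kernel {
      virtual ~Kernel() = default;
      virtual void Run() = 0;
    };

    class Factory {
     public:
      static Factory &Get() {
        static Factory instance;  // Meyers singleton, constructed on first use
        return instance;
      }
      void Register(const std::string &name, std::function<std::shared_ptr<Kernel>()> fn) {
        creators_.emplace(name, std::move(fn));
      }
      std::shared_ptr<Kernel> Create(const std::string &name) {
        auto it = creators_.find(name);
        return it == creators_.end() ? nullptr : it->second();
      }

     private:
      std::map<std::string, std::function<std::shared_ptr<Kernel>()>> creators_;
    };

    // A file-scope object whose constructor performs the registration.
    struct Registrar {
      Registrar(const std::string &name, std::function<std::shared_ptr<Kernel>()> fn) {
        Factory::Get().Register(name, std::move(fn));
      }
    };

    struct AllReduceKernel : Kernel {
      void Run() override { std::cout << "AllReduce\n"; }
    };
    static Registrar g_all_reduce_reg("AllReduce", [] { return std::make_shared<AllReduceKernel>(); });

    int main() {
      if (auto k = Factory::Get().Create("AllReduce")) k->Run();
      return 0;
    }

The Meyers-singleton factory avoids the static-initialization-order problem: the map is constructed on the first Register call, whichever translation unit triggers it.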
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc
new file mode 100644
index 0000000000..8297be0b6d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hccl_kernel_build.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string opname = AnfAlgo::GetCNodeName(anf_node);
+  MS_LOG(INFO) << "Hccl op [" << opname << "]";
+  auto kerPtr = HcclKernelFactory::Get(opname);
+  if (kerPtr == nullptr) {
+    MS_LOG(ERROR) << "Hccl can't find Kernel[" << opname << "]";
+    return nullptr;
+  }
+  if (!kerPtr->Init(anf_node)) {
+    MS_LOG(ERROR) << "Kernel initialize failed!";
+    return nullptr;
+  }
+  return kerPtr;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h
new file mode 100644
index 0000000000..21b34d6522
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_build.h
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_
+
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/kernel.h"
+
+namespace mindspore {
+namespace kernel {
+KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc
new file mode 100755
index 0000000000..55742d383c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
+#include <memory>
+#include <set>
+#include "utils/utils.h"
+#include "backend/kernel_compiler/hccl/hcom_util.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+namespace {
+std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
+  const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0};
+  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
+  auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
+  if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
+    return format;
+  }
+  if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) {
+    return kOpFormat_DEFAULT;
+  }
+  if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) {
+    return kOpFormat_DEFAULT;
+  }
+  return format;
+}
+}  // namespace
+void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
+  const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
+                                                 kNumberTypeFloat32, kNumberTypeInt16};
+  MS_EXCEPTION_IF_NULL(kernel_info_list);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) {
+    MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]";
+    return;
+  }
+  for (const auto &type : kHcclSupportTypes) {
+    std::vector<std::string> inputs_format{};
+    std::vector<TypeId> inputs_type{};
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+      inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
+      inputs_type.push_back(type);
+    }
+    std::vector<std::string> outputs_format;
+    std::vector<TypeId> outputs_type;
+    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+      outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
+      outputs_type.push_back(type);
+    }
+    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+    builder.SetInputsFormat(inputs_format);
+    builder.SetInputsDeviceType(inputs_type);
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputsDeviceType(outputs_type);
+    builder.SetKernelType(HCCL_KERNEL);
+    kernel_info_list->push_back(builder.Build());
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
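GetKernelFormat above falls back to the default format for ReduceScatter/AllGather inputs whose layout cannot be split or concatenated across ranks. An illustrative restatement of the same decision rules with stand-in string constants (the real kOpFormat_* values are defined in utils/utils.h):

    #include <iostream>
    #include <set>
    #include <string>

    // Illustrative re-statement of the GetKernelFormat fallback rules.
    std::string PickFormat(const std::string &op, const std::string &format, size_t rank) {
      static const std::set<std::string> kUnsupported = {"FRACTAL_Z", "FRACTAL_Z_C04", "C1HWNCoC0"};
      if (op != "ReduceScatter" && op != "AllGather") return format;      // other ops keep their format
      if (format == "FRAC_NZ" && rank <= 2) return "DefaultFormat";       // low-rank FRAC_NZ falls back
      if (kUnsupported.count(format) > 0) return "DefaultFormat";         // unsupported for reduce-style ops
      return format;
    }

    int main() {
      std::cout << PickFormat("AllReduce", "FRAC_NZ", 2) << "\n";         // FRAC_NZ
      std::cout << PickFormat("AllGather", "FRAC_NZ", 2) << "\n";         // DefaultFormat
      std::cout << PickFormat("ReduceScatter", "FRACTAL_Z", 4) << "\n";   // DefaultFormat
      return 0;
    }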
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h
new file mode 100755
index 0000000000..25891fdaf6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_
+#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/kernel_build_info.h"
+
+namespace mindspore {
+namespace kernel {
+void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
new file mode 100644
index 0000000000..e9fb4c9314
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hcom_all_broadcast.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace kernel {
+bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                    const std::vector<AddressPtr> & /*workspace*/,
+                                    const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->enable_task_sink()) {
+    return true;
+  }
+  if (inputs.empty() || hccl_data_type_list_.empty()) {
+    MS_LOG(ERROR) << "BroadCast param is empty";
+    return false;
+  }
+  const char *tag = "Hccl-BroadCast";
+  MS_EXCEPTION_IF_NULL(inputs[0]);
+  hcclResult_t ret =
+    hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr);
+  if (ret != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast<int>(ret);
+    return false;
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h
new file mode 100644
index 0000000000..6434b5fb9c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_
+
+#include <memory>
+#include <vector>
+#include "hccl/hcom.h"
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class HcomAllBroadCastKernel : public HcclKernel {
+ public:
+  HcomAllBroadCastKernel() = default;
+  ~HcomAllBroadCastKernel() override = default;
+
+  /* Inherit from kernelmod */
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+ private:
+};
+MS_HCCL_REG_KERNEL(Broadcast, HcomAllBroadCastKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
new file mode 100644
index 0000000000..201071dcb5
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hcom_all_gather.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace kernel {
+bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->enable_task_sink()) {
+    return true;
+  }
+  if (inputs.empty() || hccl_data_type_list_.empty()) {
+    MS_LOG(ERROR) << "AllGather param is empty";
+    return false;
+  }
+  const char *tag = "Hccl-AllGather";
+  hcclResult_t ret =
+    hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr);
+  if (ret != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast<int>(ret);
+    return false;
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h
new file mode 100644
index 0000000000..21d8ffa484
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_
+
+#include <memory>
+#include <vector>
+#include "hccl/hcom.h"
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class HcomAllGatherKernel : public HcclKernel {
+ public:
+  HcomAllGatherKernel() = default;
+  ~HcomAllGatherKernel() override = default;
+
+  /* Inherit from kernelmod */
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+ private:
+};
+MS_HCCL_REG_KERNEL(AllGather, HcomAllGatherKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
new file mode 100644
index 0000000000..533ce1b087
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hcom_all_reduce.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace kernel {
+bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
+                                 const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->enable_task_sink()) {
+    return true;
+  }
+  if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
+    MS_LOG(ERROR) << "AllReduce param is empty";
+    return false;
+  }
+  const char *tag = "Hccl-AllReduce";
+  hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0],
+                                     op_type_, nullptr, stream_ptr);
+  if (ret != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast<int>(ret);
+    return false;
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h
new file mode 100644
index 0000000000..39641f7448
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_
+
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class HcomAllReduceKernel : public HcclKernel {
+ public:
+  HcomAllReduceKernel() = default;
+  ~HcomAllReduceKernel() override = default;
+
+  /* Inherit from kernelmod */
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+ private:
+};
+
+MS_HCCL_REG_KERNEL(AllReduce, HcomAllReduceKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
new file mode 100644
index 0000000000..32c6dacb01
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace kernel {
+bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
+                                        const std::vector<AddressPtr> & /*workspace*/,
+                                        const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->enable_task_sink()) {
+    return true;
+  }
+  if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) {
+    MS_LOG(ERROR) << "ReduceScatter param is empty";
+    return false;
+  }
+  const char *tag = "Hccl-ReduceScatter";
+  hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_,
+                                         hccl_data_type_list_[0], op_type_, nullptr, stream_ptr);
+  if (ret != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast<int>(ret);
+    return false;
+  }
+  return true;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h
new file mode 100644
index 0000000000..2f4ace5aea
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_
+#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_
+
+#include <memory>
+#include <vector>
+#include "hccl/hcom.h"
+#include "backend/kernel_compiler/hccl/hccl_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+class HcomAllReduceScatterKernel : public HcclKernel {
+ public:
+  HcomAllReduceScatterKernel() = default;
+  ~HcomAllReduceScatterKernel() override = default;
+
+  /* Inherit from kernelmod */
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+
+ private:
+};
+
+MS_HCCL_REG_KERNEL(ReduceScatter, HcomAllReduceScatterKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
new file mode 100644
index 0000000000..721c1b6ba0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/hccl/hcom_util.h"
+
+#include <algorithm>
+
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_input_shape_list) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(hccl_kernel_input_shape_list);
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+    std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
+    hccl_kernel_input_shape_list->emplace_back(shape_i);
+  }
+
+  return true;
+}
+
+bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
+  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
+    std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
+    hccl_kernel_output_shape_list->emplace_back(shape_i);
+  }
+
+  return true;
+}
+
+bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(data_type_list);
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+    auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
+    auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
+    if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
+      MS_LOG(EXCEPTION) << "HcomDataType can't support current Ascend data type: " << type_ptr;
+    }
+    data_type_list->emplace_back(iter->second);
+  }
+  auto type_base = *(std::begin(*data_type_list));
+  if (std::any_of(data_type_list->begin(), data_type_list->end(),
+                  [&type_base](hcclDataType_t type) { return type != type_base; })) {
+ MS_LOG(ERROR) << "hccl have different data type"; + return false; + } + return true; +} + +bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector &shape, size_t *size) { + MS_EXCEPTION_IF_NULL(size); + size_t tmp_size = 1; + uint32_t type_size = 4; + for (size_t i = 0; i < shape.size(); i++) { + tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]); + } + + if (!GetHcomTypeSize(data_type, &type_size)) { + return false; + } + + *size = SizetMulWithOverflowCheck(tmp_size, type_size); + + MS_LOG(INFO) << "size[" << *size << "]"; + return true; +} + +bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) { + MS_EXCEPTION_IF_NULL(size); + auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type); + if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) { + MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!"; + return false; + } + *size = iter->second; + return true; +} + +bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector &data_type_list, + const vector> &shape_list, uint64_t *total_count) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(total_count); + const uint32_t align_size = 512; + const uint32_t filled_size = 32; + uint64_t total_size = 0; + uint64_t block_size; + size_t input_size; + uint32_t type_size = 4; + + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) { + if (!GetHcomTypeSize(data_type_list[i], &type_size)) { + return false; + } + + if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) { + MS_LOG(ERROR) << "Get GetHcclOpSize failed"; + return false; + } + + if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) { + int32_t rank_size; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("rank_size") != nullptr) { + rank_size = GetValue(primitive->GetAttr("rank_size")); + } else { + MS_LOG(ERROR) << "Get rank size failed"; + return false; + } + block_size = input_size / IntToSize(rank_size); + total_size = total_size + block_size; + } else { + if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) { + block_size = input_size; + } else { + block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size; + } + total_size = total_size + block_size; + } + } + + if (type_size == 0 || total_size % type_size != 0) { + MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!"; + return false; + } + *total_count = total_size / type_size; + return true; +} + +bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_type); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (primitive->GetAttr("op") == nullptr) { + MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!"; + return false; + } + auto hcom_op_type_get = GetValue(primitive->GetAttr("op")); + string hcom_op_type(hcom_op_type_get); + if (hcom_op_type == "min") { + *op_type = HCCL_REP_OP_MIN; + } else if (hcom_op_type == "max") { + *op_type = HCCL_REP_OP_MAX; + } else if (hcom_op_type == "prod") { + *op_type = HCCL_REP_OP_PROD; + } else if (hcom_op_type == "sum") { + *op_type = HCCL_REP_OP_SUM; + } else { + MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!"; + return false; + } + return true; +} + +bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) { + 
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(root_id);
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (primitive->GetAttr("root_rank") != nullptr) {
+    *root_id = (uint32_t)GetValue<int32_t>(primitive->GetAttr("root_rank"));
+  } else {
+    MS_LOG(ERROR) << "HcomUtil::GetHcomRootId failed: attribute \"root_rank\" is not set!";
+    return false;
+  }
+  return true;
+}
+
+void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto attr = primitive->GetAttr("group");
+  if (attr != nullptr) {
+    *group = GetValue<std::string>(attr);
+  } else {
+    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h
similarity index 100%
rename from mindspore/ccsrc/kernel/hccl/hcom_util.h
rename to mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.h
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc
new file mode 100644
index 0000000000..9933826f2b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc
@@ -0,0 +1,248 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "nlohmann/json.hpp" +#include "securec/include/securec.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace kernel { +constexpr auto kUtilsModule = "mindspore._extends.utils"; +constexpr auto kCalSha256Func = "cal_sha256"; + +namespace { +bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) { + if (js.find("sha256") == js.end()) { + MS_LOG(ERROR) << "No sha256 found in " << json_file; + return false; + } + std::string sha256_str = js["sha256"]; + py::object ret = parse::python_adapter::CallPyFn(kUtilsModule, kCalSha256Func, bin_file); + std::string sha256_cal = py::cast(ret); + if (sha256_cal.empty()) { + MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; + return false; + } + if (sha256_cal != sha256_str) { + MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed."; + return false; + } + return true; +} +} // namespace + +const std::string KernelPack::Serialize() const { + MS_EXCEPTION_IF_NULL(json_); + MS_EXCEPTION_IF_NULL(kernel_); + std::string buffer; + (void)buffer.append((const char *)json_, json_->len + sizeof(json_->len)); + (void)buffer.append((const char *)kernel_, kernel_->len + sizeof(kernel_->len)); + return buffer; +} + +bool KernelPack::ReadFromJsonFileHelper(std::ifstream &kernelbin) { + size_t binsize = LongToSize(kernelbin.seekg(0, std::ios::end).tellg()); + // free old data + if (kernel_ != nullptr) { + delete[] kernel_; + kernel_ = nullptr; + } + + void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); + if (ptr != nullptr) { + kernel_ = static_cast(ptr); + } + if (kernel_ == nullptr) { + MS_LOG(ERROR) << "memory malloc failed."; + kernelbin.close(); + return false; + } + if (memset_s(kernel_, sizeof(KernelPack) + binsize, 0, sizeof(KernelPack) + binsize) != EOK) { + MS_LOG(ERROR) << "memset kernel_ failed."; + delete[] kernel_; + kernel_ = nullptr; + kernelbin.close(); + return false; + } + kernel_->len = binsize; + MS_LOG(INFO) << "kernel len:" << kernel_->len; + (void)kernelbin.seekg(0, std::ios::beg); + (void)kernelbin.read(kernel_->contents, SizeToLong(kernel_->len)); + return true; +} + +bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &processor) { + if (json_f.length() <= strlen(kJsonSuffix)) { + MS_LOG(ERROR) << "please check json path."; + return false; + } + + std::ifstream kerneljson(json_f); + if (!kerneljson.is_open()) { + MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; + return false; + } + nlohmann::json js; + kerneljson >> js; + + size_t binsize = LongToSize(kerneljson.seekg(0, std::ios::end).tellg()); + void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); + if (ptr != nullptr) { + json_ = static_cast(ptr); + } + if (json_ == nullptr) { + MS_LOG(ERROR) << "memory malloc failed."; + kerneljson.close(); + return false; + } + json_->len = binsize; + (void)kerneljson.seekg(0, std::ios::beg); + (void)kerneljson.read(json_->contents, SizeToLong(json_->len)); + + if (processor == kProcessorCuda) { + std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx"; + std::ifstream kernelbin(bin_f); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta."; + kerneljson.close(); + return false; + } + + if 
(ReadFromJsonFileHelper(kernelbin) == false) { + delete[] json_; + json_ = nullptr; + kerneljson.close(); + return false; + } + kerneljson.close(); + if (!CheckHash(json_f, bin_f, js)) { + return false; + } + return true; + } + + std::string binfilesuffix = js["binFileSuffix"]; + std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfilesuffix; + if (binfilesuffix.compare(".so") == 0) { + // change "xx/xx.so" -> "xx/libxx.so" + auto sp = bin_f.rfind('/'); + if (sp == std::string::npos) { + MS_LOG(ERROR) << "illegal bin file path " << bin_f; + kerneljson.close(); + return false; + } + bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); + } + + std::ifstream kernelbin(bin_f, std::ios::binary); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; + kerneljson.close(); + delete[] json_; + json_ = nullptr; + return false; + } + + MS_LOG(INFO) << "kernelbin_name:" << bin_f; + if (ReadFromJsonFileHelper(kernelbin) == false) { + delete[] json_; + json_ = nullptr; + kerneljson.close(); + return false; + } + kerneljson.close(); + + if (!CheckHash(json_f, bin_f, js)) { + return false; + } + + return true; +} + +void KernelPack::ParseKernelJson(const nlohmann::json &js) { + kernel_json_info_.bin_file_name = js["binFileName"]; + kernel_json_info_.bin_file_suffix = js["binFileSuffix"]; + kernel_json_info_.block_dim = js["blockDim"]; + kernel_json_info_.kernel_name = js["kernelName"]; + kernel_json_info_.magic = js["magic"]; + if (js.find("parameters") != js.end()) { + if (!js.at("parameters").is_array()) { + MS_LOG(DEBUG) << "Format error!,parameters should be array."; + } + std::vector sizes = js.at("parameters"); + for (auto size : sizes) { + MS_LOG(INFO) << "parameter " << size; + kernel_json_info_.parameters.push_back(size); + } + } + if (js.find("workspace") != js.end()) { + auto workspace = js.at("workspace"); + std::vector sizes = workspace.at("size"); + for (auto size : sizes) { + MS_LOG(INFO) << "workspace_size_list " << size; + kernel_json_info_.workspaces.push_back(size); + } + } + kernel_json_info_.sha256 = js["sha256"]; +} + +bool KernelPack::LoadKernelMeta(const std::string &json_f, const std::string &processor) { + if (json_f.length() <= strlen(kJsonSuffix)) { + MS_LOG(ERROR) << "please check json path."; + return false; + } + std::ifstream kernel_json(json_f); + if (!kernel_json.is_open()) { + MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; + return false; + } + nlohmann::json js; + kernel_json >> js; + ParseKernelJson(js); + kernel_json.close(); + + std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix; + if (kernel_json_info_.bin_file_suffix == ".so") { + // change "xx/xx.so" -> "xx/libxx.so" + auto sp = bin_f.rfind('/'); + if (sp == std::string::npos) { + MS_LOG(ERROR) << "illegal bin file path " << bin_f; + return false; + } + bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); + } + + std::ifstream kernelbin(bin_f, std::ios::binary); + if (!kernelbin.is_open()) { + MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; + return false; + } + + MS_LOG(INFO) << "kernelbin_name:" << bin_f; + if (!ReadFromJsonFileHelper(kernelbin)) { + return false; + } + + return CheckHash(json_f, bin_f, js); +} + +KernelJsonInfo KernelPack::kernel_json_info() const { return kernel_json_info_; } +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel.h 
b/mindspore/ccsrc/backend/kernel_compiler/kernel.h similarity index 100% rename from mindspore/ccsrc/kernel/kernel.h rename to mindspore/ccsrc/backend/kernel_compiler/kernel.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc new file mode 100644 index 0000000000..68392d1871 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/kernel_build_info.h" +#include +#include "utils/log_adapter.h" +#include "debug/anf_ir_dump.h" +namespace mindspore { +namespace kernel { +std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { + if (input_index >= inputs_format_.size()) { + MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node"; + return kInvalidFormat; + } + return inputs_format_[input_index]; +} + +std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const { + if (output_index >= outputs_format_.size()) { + MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node"; + return kInvalidFormat; + } + return outputs_format_[output_index]; +} + +TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const { + if (input_index >= inputs_device_type_.size()) { + MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input"; + return TypeId::kNumberTypeEnd; + } + return inputs_device_type_[input_index]; +} + +TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { + if (output_index >= outputs_device_type_.size()) { + MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output"; + return TypeId::kNumberTypeEnd; + } + return outputs_device_type_[output_index]; +} + +std::vector KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } + +std::vector KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } + +std::vector KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } + +std::vector KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } + +size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } + +size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } + +std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { + if (input_index >= input_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " + << input_reshape_type_.size(); + } + return input_reshape_type_[input_index]; +} + +std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { + if (output_index >= output_reshape_type_.size()) { + MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " + << output_reshape_type_.size(); + } + return 
output_reshape_type_[output_index]; +} + +std::string KernelBuildInfo::ToString() const { + std::ostringstream output_buffer; + output_buffer << "("; + for (size_t index = 0; index < GetInputNum(); ++index) { + if (index != 0) { + output_buffer << ", "; + } + output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; + } + output_buffer << ") -> ("; + for (size_t index = 0; index < GetOutputNum(); ++index) { + if (index != 0) { + output_buffer << ", "; + } + output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; + } + output_buffer << ")"; + return output_buffer.str(); +} + +bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { + if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { + return false; + } + if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) { + if (op_pattern_ != kFormatAgnosticPattern) { + return false; + } else { + MS_LOG(INFO) << "this kernel build info:" << this->ToString() + << ", other kernel build info: " << other.ToString(); + } + } + return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); +} + +bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } + +bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } + +bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); } + +void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->kernel_type_ = kernel_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector &inputs_format) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->inputs_format_ = inputs_format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector &outputs_format) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->outputs_format_ = outputs_format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector &inputs_device_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->inputs_device_type_ = inputs_device_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector &outputs_device_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->outputs_device_type_ = outputs_device_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->fusion_type_ = fusion_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->processor_ = processor; +} + +std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } + +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( + const std::vector> &input_reshape_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->input_reshape_type_ = input_reshape_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( + const std::vector> &output_reshape_type) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->output_reshape_type_ = 
output_reshape_type; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + kernel_build_info_->op_pattern_ = pattern; +} +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->inputs_format_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->inputs_format_[index] = format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->outputs_format_.size()) { + MS_LOG(EXCEPTION) << "index outof range!"; + } + kernel_build_info_->outputs_format_[index] = format; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h new file mode 100644 index 0000000000..be243c9ae0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h @@ -0,0 +1,147 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
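The builder implemented above is the sanctioned way to assemble a KernelBuildInfo, and the split between the vector-valued setters and the index-based SetInputFormat/SetOutputFormat matters: the index variants refuse to grow the vectors and throw on out-of-range, so the bulk setters must run first. A minimal sketch of the intended flow, assuming the stock MindSpore constants kOpFormat_DEFAULT and kNumberTypeFloat16 are in scope (the concrete values are illustrative, not taken from this patch):

    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetKernelType(TBE_KERNEL);
    builder.SetInputsFormat({kOpFormat_DEFAULT});        // one entry per input
    builder.SetInputsDeviceType({kNumberTypeFloat16});
    builder.SetOutputsFormat({kOpFormat_DEFAULT});
    builder.SetOutputsDeviceType({kNumberTypeFloat16});
    builder.SetFusionType(OPAQUE);
    builder.SetProcessor(AICORE);
    std::shared_ptr<kernel::KernelBuildInfo> info = builder.Build();
    // Build() hands out the same shared object the setters mutated.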
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +class KernelBuildInfo { + public: + class KernelBuildInfoBuilder; + + KernelBuildInfo() { + kernel_type_ = TBE_KERNEL; + fusion_type_ = OPAQUE; + processor_ = AICORE; + op_pattern_ = kCommonPattern; + input_reshape_type_ = {}; + output_reshape_type_ = {}; + inputs_format_ = {}; + outputs_format_ = {}; + inputs_device_type_ = {}; + outputs_device_type_ = {}; + } + + ~KernelBuildInfo() = default; + + KernelType kernel_type() const { return kernel_type_; } + + std::string GetInputFormat(size_t input_index) const; + + std::string GetOutputFormat(size_t output_index) const; + + TypeId GetInputDeviceType(size_t input_index) const; + + TypeId GetOutputDeviceType(size_t output_index) const; + + std::vector GetInputReshapeType(size_t input_index) const; + + bool IsInputDefaultPadding() const; + + bool IsOutputDefaultPadding() const; + + std::vector GetOutputReshapeType(size_t input_index) const; + + std::vector GetAllInputFormats() const; + + std::vector GetAllOutputFormats() const; + + std::vector GetAllInputDeviceTypes() const; + + std::vector GetAllOutputDeviceTypes() const; + + OpPattern op_pattern() const { return op_pattern_; } + + FusionType fusion_type() const { return fusion_type_; } + + Processor processor() const { return processor_; } + + size_t GetInputNum() const; + + size_t GetOutputNum() const; + + std::string ToString() const; + + bool operator==(const KernelBuildInfo &other) const; + + bool operator!=(const KernelBuildInfo &other) const; + + public: + static auto constexpr kInvalidFormat = "InvalidFormat"; + + private: + KernelType kernel_type_; + std::vector inputs_format_; + OpPattern op_pattern_; + std::vector outputs_format_; + std::vector> input_reshape_type_; + std::vector> output_reshape_type_; + std::vector inputs_device_type_; + std::vector outputs_device_type_; + FusionType fusion_type_; + Processor processor_; +}; +using KernelBuildInfoPtr = std::shared_ptr; + +class KernelBuildInfo::KernelBuildInfoBuilder { + public: + KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared(); } + + explicit KernelBuildInfoBuilder(std::shared_ptr kernel_build_info) + : kernel_build_info_(std::move(kernel_build_info)) {} + + ~KernelBuildInfoBuilder() = default; + + void SetKernelType(const KernelType &kernel_type); + + void SetInputsFormat(const std::vector &inputs_format); + + void SetOutputsFormat(const std::vector &outputs_format); + + void SetInputsDeviceType(const std::vector &inputs_device_type); + + void SetOutputsDeviceType(const std::vector &outputs_device_type); + + void SetInputReshapeType(const std::vector> &input_reshape_type); + + void SetOutputReshapeType(const std::vector> &output_reshape_type); + + void SetFusionType(FusionType fusion_type); + + void SetProcessor(Processor processor); + + void SetOpPattern(OpPattern pattern); + + void SetInputFormat(const std::string &format, size_t index); + + void SetOutputFormat(const std::string &format, size_t index); + + std::shared_ptr Build(); + + private: + std::shared_ptr kernel_build_info_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc new file mode 100644 
index 0000000000..0045e49bef --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc @@ -0,0 +1,125 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/kernel_fusion.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +static bool GenPreBuildKernelJson(const std::vector &compute_nodes, + std::vector *prebuild_op_list) { + MS_EXCEPTION_IF_NULL(prebuild_op_list); + TbeKernelJsonCreator creator(PREBUILD); + for (const auto &anf_node : compute_nodes) { + nlohmann::json prebuild; + if (!creator.GenTbeSingleKernelJson(anf_node, &prebuild)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + (*prebuild_op_list).push_back(prebuild); + } + return true; +} + +std::map KernelFusion(const std::vector &fusion_scopes) { + MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); + std::map kernel_mod_ret; + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + for (const auto &fusion_scope_iter : fusion_scopes) { + auto scope_id = fusion_scope_iter.scope_id; + nlohmann::json fusion_op; + string fusion_kernel = "te_fusion"; + if (!TbeKernelBuild::GenFusionScopeJson(fusion_scope_iter.input_nodes, fusion_scope_iter.compute_nodes, &fusion_op, + &fusion_kernel)) { + continue; + } + // gen kernel_name & check cache + std::string json_str = fusion_op.dump(); + size_t hash_id = std::hash()(json_str); + auto json_name = fusion_kernel.append("_").append(std::to_string(hash_id)); + fusion_op["fusion_op_name"] = json_name; + // gen json for prebuild + std::vector prebuild_op_list; + if (!GenPreBuildKernelJson(fusion_scope_iter.compute_nodes, &prebuild_op_list)) { + continue; + } + // get io size + std::vector input_size_list; + std::vector output_size_list; + if (!TbeKernelBuild::GetIOSize(fusion_op["op_list"], fusion_scope_iter.output_nodes, &input_size_list, + &output_size_list)) { + continue; + } + // search cache + auto kernel_pack = TbeUtils::SearchCache(json_name, tbe::kProcessorAiCore); + if (kernel_pack != nullptr) { + MS_LOG(INFO) << "Use cached kernel, kernel json name: " << json_name; + auto kernel_mod = + build_manger->GenKernelMod(json_name, tbe::kProcessorAiCore, input_size_list, output_size_list, kernel_pack); + if (kernel_mod != nullptr) { + kernel_mod_ret[scope_id] = kernel_mod; + continue; + } + } + // fusion build + nlohmann::json fusion_json; + fusion_json["fusion_op"] = fusion_op; + fusion_json["prebuild_ops"] = prebuild_op_list; + auto task_id = build_manger->StartCompileOp(fusion_json); + TbeUtils::SaveJsonInfo(json_name, fusion_json.dump()); + if (task_id < 
0) {
+      MS_EXCEPTION(ArgumentError) << "start compile failed.";
+    }
+    build_manger->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list, scope_id);
+  }
+
+  int build_failed_num = 0;
+  while (!build_manger->IsAllTaskFinish()) {
+    int task_id = -1;
+    char *task_result = nullptr;
+    char *pre_build_result = nullptr;
+    auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
+    if (!ret) {
+      MS_EXCEPTION(ArgumentError) << "Build failed. WaitOne ret:" << ret << ", task id:" << task_id;
+    }
+
+    if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
+      MS_LOG(INFO) << "Fusion warning: fusion op build failed, err log: " << task_result
+                   << " change to single op build.";
+      build_failed_num++;
+    }
+    auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false);
+    if (kernel_mod_item.second != nullptr) {
+      (void)kernel_mod_ret.emplace(kernel_mod_item);
+    }
+  }
+  MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num;
+  return kernel_mod_ret;
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h
new file mode 100644
index 0000000000..2fb3a05b4b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_
+#define MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_
+#include <map>
+#include <vector>
+#include "backend/kernel_compiler/kernel.h"
+namespace mindspore {
+namespace kernel {
+/*
+ * @brief fuse the ops inside one scope and return a callable KernelMod per scope id
+ */
+struct FusionScopeInfo {
+  int32_t scope_id;
+  std::vector<AnfNodePtr> input_nodes;
+  std::vector<AnfNodePtr> compute_nodes;
+  std::vector<AnfNodePtr> output_nodes;
+};
+
+std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> &fusion_scopes);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
new file mode 100755
index 0000000000..81b5d0f996
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
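KernelFusion above keys its compile cache on a hash of the fused-scope JSON. That naming scheme, extracted into a self-contained sketch (the function name is mine, not the patch's):

    #include <functional>
    #include <string>

    // json_name = <fusion kernel name> + "_" + std::hash of the scope's JSON dump,
    // mirroring the cache key built in KernelFusion() before SearchCache() runs.
    std::string MakeFusionJsonName(const std::string &fusion_kernel, const std::string &json_str) {
      size_t hash_id = std::hash<std::string>()(json_str);
      return fusion_kernel + "_" + std::to_string(hash_id);
    }

Note that a cache hit in TbeUtils::SearchCache therefore requires byte-identical JSON, which is why fusion_op_name is written into the JSON only after the hash has been taken.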
+ */ + +#include "backend/kernel_compiler/kernel_query.h" +#include +#include +#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" +#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" +#include "backend/kernel_compiler/akg/akg_kernel_metadata.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +namespace { +void FilterInvalidKernelInfo(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_info_list); + std::vector> filtered_list; + (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), + [&kernel_node](const std::shared_ptr &kernel_build_info) { + return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && + AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); + }); + if (!filtered_list.empty()) { + kernel_info_list->clear(); + (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list)); + } else { + MS_LOG(INFO) << "All kernel Info list does not match any kernel info "; + for (size_t index = 0; index < kernel_info_list->size(); ++index) { + std::ostringstream buffer; + auto kernel_info = kernel_info_list->at(index); + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) { + buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]"; + } else { + buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" + << " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]"; + } + MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str(); + } + kernel_info_list->clear(); + MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" + << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; + } +} +} // namespace + +void KernelQueryAll(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + + TbeMetadataInfo(kernel_node, kernel_info_list); + + if (kernel_info_list->empty()) { + AicpuMetadataInfo(kernel_node, kernel_info_list); + if (!kernel_info_list->empty()) { + MS_LOG(INFO) << "The node [" << kernel_node->DebugString() + << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); + } + } + + if (kernel_info_list->empty()) { + GetRtKelInfo(kernel_node, kernel_info_list); + } + + if (kernel_info_list->empty()) { + HcclMetadataInfo(kernel_node, kernel_info_list); + } + if (kernel_info_list->empty()) { + MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!"; + } +} + +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, + KernelType kernel_type) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if 
(context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) { + kernel_type = KernelType::AKG_KERNEL; + } + + switch (kernel_type) { + case KernelType::AKG_KERNEL: + AkgMetadataInfo(kernel_node, kernel_info_list); + break; + default: + KernelQueryAll(kernel_node, kernel_info_list); + break; + } + + if (kernel_info_list->empty()) { + MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!"; + } + // check output + FilterInvalidKernelInfo(kernel_node, kernel_info_list); +} + +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + kernel_info_list->clear(); + AicpuMetadataInfo(kernel_node, kernel_info_list); + FilterInvalidKernelInfo(kernel_node, kernel_info_list); +} +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + AICPUQuery(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} + +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(select_kernel_build_info); + std::vector> kernel_info_list; + auto cnode = kernel_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + TbeMetadataInfo(cnode, &kernel_info_list); + return std::any_of(kernel_info_list.begin(), kernel_info_list.end(), + [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) { + MS_EXCEPTION_IF_NULL(item); + return *item == *select_kernel_build_info; + }); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h new file mode 100644 index 0000000000..20458f48d0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
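The backend preference hard-coded in KernelQueryAll above (TBE, then AiCPU, then RT, then HCCL) is a probe-until-nonempty chain. The same pattern as a tiny standalone sketch (QueryWithFallback and Probe are my names, not the patch's):

    #include <functional>
    #include <vector>

    template <typename Info>
    using Probe = std::function<void(std::vector<Info> *)>;

    // Runs each candidate-producing probe in order and stops at the first backend
    // that yields at least one kernel info, exactly like KernelQueryAll's if-chain.
    template <typename Info>
    void QueryWithFallback(const std::vector<Probe<Info>> &probes, std::vector<Info> *out) {
      for (const auto &probe : probes) {
        probe(out);
        if (!out->empty()) {
          return;
        }
      }
    }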
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ +#define MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ + +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list, + KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); +void AICPUQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h new file mode 100644 index 0000000000..64ae1009d1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU }; +enum OpIOType { kInput = 0, kOutput }; + +class OpAttr { + public: + OpAttr() = default; + ~OpAttr() = default; + + std::string name() const { return name_; } + std::string param_type() const { return param_type_; } + std::string type() const { return type_; } + std::string value() const { return value_; } + std::string default_value() const { return default_value_; } + + void set_name(const std::string &name) { name_ = name; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_type(const std::string &type) { type_ = type; } + void set_value(const std::string &value) { value_ = value; } + void set_default_value(const std::string &default_value) { default_value_ = default_value; } + + private: + std::string name_; + std::string param_type_; + std::string type_; + std::string value_; + std::string default_value_; +}; + +class OpIOInfo { + public: + OpIOInfo() = default; + ~OpIOInfo() = default; + + int index() const { return index_; } + std::string name() const { return name_; } + bool need_compile() const { return need_compile_; } + std::string param_type() const { return param_type_; } + std::string reshape_type() const { return reshape_type_; } + std::string shape() const { return shape_; } + std::vector dtypes() const { return dtypes_; } + std::vector formats() const { return formats_; } + + void set_index(const int index) { index_ = index; } + void set_name(const std::string &name) { name_ = name; } + void set_need_compile(const bool need_compile) { need_compile_ = 
need_compile; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } + void set_shape(const std::string &shape) { shape_ = shape; } + void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } + void set_formats(const std::vector &formats) { formats_ = formats; } + + private: + int index_ = 0; + std::string name_; + bool need_compile_ = false; + std::string param_type_; + std::string reshape_type_; + std::string shape_; + std::vector dtypes_; + std::vector formats_; +}; + +class OpInfo { + public: + OpInfo() = default; + OpInfo(const OpInfo &opinfo) { + op_name_ = opinfo.op_name(); + imply_type_ = opinfo.imply_type(); + + impl_path_ = opinfo.impl_path(); + fusion_type_ = opinfo.fusion_type(); + async_flag_ = opinfo.async_flag_; + binfile_name_ = opinfo.binfile_name_; + compute_cost_ = opinfo.compute_cost_; + kernel_name_ = opinfo.kernel_name(); + partial_flag_ = opinfo.partial_flag_; + dynamic_format_ = opinfo.dynamic_format_; + op_pattern_ = opinfo.op_pattern(); + processor_ = opinfo.processor_; + for (const auto &attr : opinfo.attrs_ptr()) { + attrs_ptr_.push_back(std::make_shared(*attr)); + } + for (const auto &input : opinfo.inputs_ptr()) { + inputs_ptr_.push_back(std::make_shared(*input)); + } + for (const auto &output : opinfo.outputs_ptr()) { + outputs_ptr_.push_back(std::make_shared(*output)); + } + ref_infos_ = opinfo.ref_infos(); + } + ~OpInfo() = default; + std::string op_name() const { return op_name_; } + OpImplyType imply_type() const { return imply_type_; } + std::string impl_path() const { return impl_path_; } + std::string fusion_type() const { return fusion_type_; } + std::string kernel_name() const { return kernel_name_; } + OpPattern op_pattern() const { return op_pattern_; } + std::string processor() const { return processor_; } + std::vector> attrs_ptr() const { return attrs_ptr_; } + std::vector> inputs_ptr() const { return inputs_ptr_; } + std::vector> outputs_ptr() const { return outputs_ptr_; } + const std::unordered_map &ref_infos() const { return ref_infos_; } + + void set_op_name(const std::string &op_name) { op_name_ = op_name; } + void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } + void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } + void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } + void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } + void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } + void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } + void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } + void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } + void set_processor(const std::string &processor) { processor_ = processor; } + void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } + void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } + void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } + bool is_ref() const { return !ref_infos_.empty(); } + bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } + void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, 
in_index); } + void ClearInputs() { (void)inputs_ptr_.clear(); } + void ClearOutputs() { (void)outputs_ptr_.clear(); } + bool equals_to(const std::shared_ptr &other_info) const { + return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && + this->processor_ == other_info->processor_; + } + + private: + std::string op_name_; + OpImplyType imply_type_ = kTBE; + std::string impl_path_; + std::string fusion_type_; + bool async_flag_ = false; + std::string binfile_name_; + int compute_cost_ = 0; + std::string kernel_name_; + bool partial_flag_ = false; + bool dynamic_format_ = false; + OpPattern op_pattern_ = kCommonPattern; + std::string processor_; + std::vector> attrs_ptr_; + std::vector> inputs_ptr_; + std::vector> outputs_ptr_; + std::unordered_map ref_infos_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc new file mode 100644 index 0000000000..69c4ca7db1 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -0,0 +1,390 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
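Before the registry implementation below, a quick sketch of how the OpInfo/OpIOInfo model declared above is populated by hand; the op name and the dtype/format strings are illustrative only:

    auto op_info = std::make_shared<kernel::OpInfo>();
    op_info->set_op_name("FakeRelu");            // hypothetical op
    op_info->set_imply_type(kernel::kTBE);
    op_info->set_fusion_type("ELEMWISE");
    auto x = std::make_shared<kernel::OpIOInfo>();
    x->set_index(0);
    x->set_name("x");
    x->set_dtypes({"float16", "float32"});       // index-aligned with formats below
    x->set_formats({"DefaultFormat", "DefaultFormat"});
    op_info->add_inputs_ptr(x);
    // dtypes() and formats() must stay the same length; DecodeInputOutput() below
    // rejects op definitions where they diverge.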
+ */ + +#include "backend/kernel_compiler/oplib/oplib.h" +#include +#include +#include +#include +#include +#include "utils/log_adapter.h" +#include "utils/overload.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +constexpr auto kImplyType = "imply_type"; +constexpr auto kOpName = "op_name"; +constexpr auto kFusionType = "fusion_type"; +constexpr auto kAsyncFlag = "async_flag"; +constexpr auto kBinfileName = "binfile_name"; +constexpr auto kComputeCost = "compute_cost"; +constexpr auto kKernelName = "kernel_name"; +constexpr auto kPartialFlag = "partial_flag"; +constexpr auto kReshapeType = "reshape_type"; +constexpr auto kOpPattern = "op_pattern"; +constexpr auto kDynamicFormat = "dynamicFormat"; +constexpr auto kFormatAgnostic = "formatAgnostic"; +constexpr auto kBroadcast = "broadcast"; +constexpr auto kReduce = "reduce"; +constexpr auto kDtypeFormat = "dtype_format"; +constexpr auto kAttr = "attr"; +constexpr auto kIputs = "inputs"; +constexpr auto kOutputs = "outputs"; +constexpr auto kAiCPU = "AiCPU"; +constexpr auto kAiCore = "AiCore"; +constexpr auto kCUDA = "CUDA"; +constexpr auto kTbe = "TBE"; +constexpr auto kAkg = "AKG"; +constexpr auto kName = "name"; +constexpr auto kParamType = "param_type"; +constexpr auto kDtype = "dtype"; +constexpr auto kType = "type"; +constexpr auto kValue = "value"; +constexpr auto kDefaultValue = "default_value"; +constexpr auto kIndex = "index"; +constexpr auto kFormat = "format"; +constexpr auto kNeedCompile = "need_compile"; +constexpr auto kShape = "shape"; +constexpr auto kProcessor = "processor"; +std::vector> OpLib::op_info_; + +static std::string ImplTypeToStr(OpImplyType impl_type) { + switch (impl_type) { + case kTBE: + return kTbe; + case kAKG: + return kAkg; + case kAICPU: + return kAiCPU; + default: + return "unknow"; + } +} +bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { + bool ret = false; + try { + auto op_json = nlohmann::json::parse(json_string); + std::string imply_type_string = op_json.at(kImplyType); + std::string op_name = op_json.at(kOpName); + if (imply_type_string == kTbe) { + OpImplyType imply_type = kTBE; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else if (imply_type_string == kAkg) { + OpImplyType imply_type = kAKG; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else if (imply_type_string == kAiCPU) { + OpImplyType imply_type = kAICPU; + ret = DecodeOpInfo(op_json, imply_type, impl_path); + } else { + MS_LOG(ERROR) << "Not support imply_type"; + } + if (!ret) { + MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string; + } + } catch (const std::exception &e) { + MS_LOG(ERROR) << "get op json elements failed: " << e.what(); + } + return ret; +} + +void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { + const std::map kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern}, + {kBroadcast, kBroadcastPattern}, + {kReduce, kReducePattern}, + {kDynamicFormat, kDynamicFormatPattern}}; + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_async_flag(obj.at(kAsyncFlag)); + op_info->set_binfile_name(obj.at(kBinfileName)); + op_info->set_compute_cost(obj.at(kComputeCost)); + op_info->set_kernel_name(obj.at(kKernelName)); + op_info->set_partial_flag(obj.at(kPartialFlag)); + + if (obj.find(kOpPattern) != obj.end()) { + std::string op_pattern = obj.at(kOpPattern); + auto find_iter = kOpPatternMap.find(op_pattern); + if (find_iter == kOpPatternMap.end()) { + if 
(!op_pattern.empty()) { + MS_LOG(WARNING) << "Op pattern set value error: " << op_pattern; + } + op_info->set_op_pattern(kCommonPattern); + } else { + op_info->set_op_pattern(find_iter->second); + } + } +} + +void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_processor(obj.at(kProcessor)); +} + +bool OpLib::RegOpFromLocalInfo() { + MS_LOG(INFO) << "Start"; + static bool has_load = false; + if (has_load) { + return true; + } + has_load = true; + std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH"); + if (dir.empty()) { + MS_LOG(INFO) << "MindSpore op info path does not been setted. use op info from python pass."; + return true; + } + char real_path[PATH_MAX] = {0}; + if (dir.size() >= PATH_MAX) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#else + if (realpath(common::SafeCStr(dir), real_path) == nullptr) { + MS_LOG(ERROR) << "Op info path is invalid: " << dir; + return false; + } +#endif + MS_LOG(INFO) << "Start to read op info from local file."; + std::ifstream file(real_path); + if (!file.is_open()) { + MS_LOG(ERROR) << "Find op info file failed."; + return false; + } + std::string line; + while (getline(file, line)) { + if (!line.empty()) { + (void)OpLib::RegOp(line, ""); + } + } + MS_LOG(INFO) << "End"; + return true; +} + +bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, + const std::string &impl_path) { + std::shared_ptr op_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_info); + op_info->set_op_name(obj.at(kOpName)); + op_info->set_impl_path(impl_path); + op_info->set_imply_type(imply_type); + op_info->set_fusion_type(obj.at(kFusionType)); + if (imply_type == kTBE) { + DecodeTBESpecificInfo(obj, op_info); + } else if (imply_type == kAKG) { + DecodeAKGSpecificInfo(obj, op_info); + } + auto attrs = obj.at(kAttr); + for (const auto &attr : attrs) { + if (!DecodeAttr(attr, imply_type, op_info)) { + MS_LOG(ERROR) << "DecodeAttr Failed"; + return false; + } + } + nlohmann::json dtype_format; + if (obj.find(kDtypeFormat) != obj.end()) { + dtype_format = obj.at(kDtypeFormat); + } + auto inputs = obj.at(kIputs); + for (const auto &input : inputs) { + if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { + MS_LOG(ERROR) << "DecodeInputOutput Failed"; + return false; + } + } + auto outputs = obj.at(kOutputs); + for (const auto &output : outputs) { + if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { + MS_LOG(ERROR) << "DecodeInputOutput Failed"; + return false; + } + } + if (CheckRepetition(op_info)) { + MS_LOG(WARNING) << "This op info has been already registed. 
op name: " << op_info->op_name() + << ", impl type: " << ImplTypeToStr(op_info->imply_type()) + << ", impl path: " << op_info->impl_path(); + return true; + } + if (!GetRefInfo(op_info)) { + MS_LOG(ERROR) << "GetRefInfo Failed"; + return false; + } + op_info_.push_back(op_info); + return true; +} + +bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + bool ret = true; + try { + std::shared_ptr op_attr = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_attr); + op_attr->set_name(obj.at(kName)); + if (imply_type != kAICPU) { + op_attr->set_param_type(obj.at(kParamType)); + } + op_attr->set_type(obj.at(kType)); + if (imply_type == kTBE) { + op_attr->set_value(obj.at(kValue)); + } + if (obj.find(kDefaultValue) != obj.end()) { + op_attr->set_default_value(obj.at(kDefaultValue)); + } + op_info->add_attrs_ptr(op_attr); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); + ret = false; + } + return ret; +} + +bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, + size_t index) { + MS_EXCEPTION_IF_NULL(op_io); + bool ret = true; + try { + std::vector dtype; + std::vector format; + for (const auto &it : dtype_format) { + dtype.emplace_back(it[index][0]); + format.emplace_back(it[index][1]); + } + op_io->set_dtypes(dtype); + op_io->set_formats(format); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); + ret = false; + } + return ret; +} + +bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { + MS_EXCEPTION_IF_NULL(op_info); + bool ret = true; + try { + std::shared_ptr op_io = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_io); + op_io->set_index(obj.at(kIndex)); + op_io->set_name(obj.at(kName)); + if (!dtype_format.empty()) { + if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { + MS_LOG(ERROR) << "Decode dtype format failed"; + return false; + } + } else { + op_io->set_dtypes(obj.at(kDtype)); + op_io->set_formats(obj.at(kFormat)); + } + if (op_io->dtypes().size() != op_io->formats().size()) { + MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() + << " is not equal to format size: " << op_io->formats(); + return false; + } + if (obj.find(kParamType) != obj.end()) { + op_io->set_param_type(obj.at(kParamType)); + } + if (imply_type == kTBE) { + if (obj.find(kNeedCompile) != obj.end()) { + op_io->set_need_compile(obj.at(kNeedCompile)); + } + if (obj.find(kShape) != obj.end()) { + op_io->set_shape(obj.at(kShape)); + } + if (obj.find(kReshapeType) != obj.end()) { + op_io->set_reshape_type(obj.at(kReshapeType)); + } + } + + if (io_type == kInput) { + op_info->add_inputs_ptr(op_io); + } else if (io_type == kOutput) { + op_info->add_outputs_ptr(op_io); + } + } catch (const std::exception &e) { + MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); + ret = false; + } + return ret; +} + +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { + if (!OpLib::RegOpFromLocalInfo()) { + MS_LOG(INFO) << "Warning reg local op info failed."; + } + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool is_gpu = (context->device_target() == kGPUDevice); + if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { + MS_LOG(ERROR) << 
"FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); + return nullptr; + } + for (const auto &op_info : op_info_) { + MS_EXCEPTION_IF_NULL(op_info); + if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { + auto akg_processor_match = [&]() { + return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore; + }; + if (imply_type != kAKG || akg_processor_match()) { + return op_info; + } + } + } + MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); + return nullptr; +} + +bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + const auto &output_infos = op_info->outputs_ptr(); + const auto &input_infos = op_info->inputs_ptr(); + for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { + MS_EXCEPTION_IF_NULL(output_infos[out_index]); + const auto &out_name = output_infos[out_index]->name(); + for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { + MS_EXCEPTION_IF_NULL(input_infos[in_index]); + const auto &in_name = input_infos[in_index]->name(); + if (out_name == in_name) { + if (op_info->has_ref_index(out_index)) { + MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; + return false; + } + op_info->add_ref_pair(out_index, in_index); + MS_LOG(INFO) << "add ref info, op name is " << op_info->op_name() << ", outindex is " << out_index + << ", in_index is " << in_index; + } + } + } + return true; +} + +bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { + MS_EXCEPTION_IF_NULL(op_info); + for (const auto &exist_op_info : op_info_) { + MS_EXCEPTION_IF_NULL(exist_op_info); + if (exist_op_info->equals_to(op_info)) { + return true; + } + } + return false; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h new file mode 100644 index 0000000000..845edbfc2a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_
+#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_
+#include
+#include
+#include
+#include
+#include "backend/kernel_compiler/oplib/opinfo.h"
+
+namespace mindspore {
+namespace kernel {
+class OpLib {
+ public:
+  OpLib() = default;
+  virtual ~OpLib() = default;
+  static bool RegOp(const std::string &json_string, const std::string &impl_path);
+  static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace_back(opinfo); }
+  static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type);
+  static const std::vector> &GetAllOpsInfo() { return op_info_; }
+
+ protected:
+  static std::vector> op_info_;
+
+ private:
+  static bool RegOpFromLocalInfo();
+  static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path);
+  static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type,
+                         const std::shared_ptr &op_info);
+  static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io,
+                                size_t index);
+  static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info);
+  static void DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info);
+  static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
+                                const std::shared_ptr &op_info, const nlohmann::json &dtype_format);
+  static bool GetRefInfo(const std::shared_ptr &op_info);
+  static bool CheckRepetition(const std::shared_ptr &op_info);
+};
+} // namespace kernel
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h
new file mode 100644
index 0000000000..6b2981e5b3
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oploader.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_OPLOADER_H
+#define MINDSPORE_OPLOADER_H
+
+#include
+#include "backend/kernel_compiler/oplib/oplib.h"
+
+namespace mindspore {
+namespace kernel {
+class OpInfoLoaderPy {
+ public:
+  OpInfoLoaderPy() = default;
+
+  ~OpInfoLoaderPy() = default;
+
+  size_t GetAllOpsInfo() {
+    auto ops = OpLib::GetAllOpsInfo();
+    auto op_infos = new std::vector();
+    for (auto op_info : ops) {
+      auto new_op_info = new OpInfo(*op_info);
+      op_infos->emplace_back(new_op_info);
+    }
+    // Deliberately not freed here: the raw pointer is returned as an integer
+    // handle, so ownership presumably passes to the (Python-side) caller.
+    return (size_t)op_infos;
+  }
+};
+} // namespace kernel
+} // namespace mindspore
+#endif // MINDSPORE_OPLOADER_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc
new file mode 100644
index 0000000000..552468bb71
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.cc
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/assign.h"
+
+#include
+
+#include "runtime/mem.h"
+#include "common/utils.h"
+
+using ge::model_runner::MemcpyAsyncTaskInfo;
+using MemcpyAsyncTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+AssignKernel::AssignKernel() {}
+
+AssignKernel::~AssignKernel() {}
+
+bool AssignKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/,
+                          const std::vector & /*outputs*/, void *stream_ptr) {
+  if (inputs.size() != 2) {
+    MS_LOG(ERROR) << "inputs size is not two";
+    return false;
+  }
+
+  if (inputs[0]->addr == inputs[1]->addr) {
+    MS_LOG(INFO) << "first addr is same as second addr, no need to assign";
+    return true;
+  }
+  rtError_t status = rtMemcpyAsync(inputs[0]->addr, inputs[0]->size, inputs[1]->addr, inputs[1]->size,
+                                   RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
+  if (status != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Assign op rtMemcpyAsync failed!";
+    return false;
+  }
+  return true;
+}
+
+std::vector AssignKernel::GenTask(const std::vector &inputs,
+                                  const std::vector &workspace,
+                                  const std::vector &outputs, uint32_t stream_id) {
+  if (inputs.size() != 2) {
+    MS_LOG(EXCEPTION) << "inputs size is not two";
+  }
+  stream_id_ = stream_id;
+
+  std::shared_ptr task_info_ptr =
+    std::make_shared(kernel_name_, stream_id, inputs[0]->addr, inputs[0]->size, inputs[1]->addr,
+                     inputs[1]->size, RT_MEMCPY_DEVICE_TO_DEVICE, false);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  return {task_info_ptr};
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h
new file mode 100644
index 0000000000..cff946cc36
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/assign.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H
+
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class AssignKernel : public RtKernel {
+ public:
+  AssignKernel();
+  ~AssignKernel() override;
+
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+};
+
+MS_REG_RTKERNEL(assign, AssignKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc
new file mode 100644
index 0000000000..8ec460fe0b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/label_goto.h"
+#include
+#include
+#include "runtime/stream.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+using ge::model_runner::LabelGotoTaskInfo;
+using LabelGotoTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+LabelGotoKernel::LabelGotoKernel() { label_ = 0; }
+
+LabelGotoKernel::~LabelGotoKernel() {}
+
+bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_LOG(INFO) << "LabelGotoKernel init";
+  auto cnode = anf_node->cast();
+  if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
+    MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index";
+  }
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  label_ = GetValue(primitive->GetAttr(kAttrLabelIndex));
+  MS_LOG(INFO) << "LabelGotoKernel get attr label:" << label_;
+  return true;
+}
+
+bool LabelGotoKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/,
+                             const std::vector & /*outputs*/, void * /*stream_ptr*/) {
+  MS_LOG(INFO) << "LabelGotoKernel launch";
+  return true;
+}
+
+std::vector LabelGotoKernel::GenTask(const std::vector &, const std::vector &,
+                                     const std::vector &, uint32_t stream_id) {
+  MS_LOG(INFO) << "LabelGotoKernel GenTask label:" << label_ << ", stream id:" << stream_id;
+  std::vector task_info_list;
+  std::shared_ptr task_info_ptr =
+    std::make_shared(kernel_name_, stream_id, label_);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  task_info_list.emplace_back(task_info_ptr);
+  return task_info_list;
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h
new file mode 100644
index 0000000000..2680d916a5
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_goto.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H
+
+#include
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class LabelGotoKernel : public RtKernel {
+ public:
+  LabelGotoKernel();
+  ~LabelGotoKernel() override;
+
+  bool Init(const AnfNodePtr &anf_node) override;
+
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+
+ private:
+  uint32_t label_;
+};
+
+MS_REG_RTKERNEL(labelgoto, LabelGotoKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc
new file mode 100644
index 0000000000..909885ff17
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.cc
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/label_set.h"
+#include
+#include
+#include "runtime/stream.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+using ge::model_runner::LabelSetTaskInfo;
+using LabelSetTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+LabelSetKernel::LabelSetKernel() { label_ = 0; }
+
+LabelSetKernel::~LabelSetKernel() {}
+
+bool LabelSetKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_LOG(INFO) << "LabelSetKernel init";
+  auto cnode = anf_node->cast();
+  if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
+    MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index";
+  }
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  label_ = GetValue(primitive->GetAttr(kAttrLabelIndex));
+  MS_LOG(INFO) << "LabelSetKernel get attr label:" << label_;
+  return true;
+}
+
+bool LabelSetKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/,
+                            const std::vector & /*outputs*/, void * /*stream_ptr*/) {
+  MS_LOG(INFO) << "LabelSetKernel launch";
+  return true;
+}
+
+std::vector LabelSetKernel::GenTask(const std::vector &, const std::vector &,
+                                    const std::vector &, uint32_t stream_id) {
+  MS_LOG(INFO) << "LabelSetKernel GenTask label:" << label_ << ", stream id:" << stream_id;
+  std::vector task_info_list;
+  std::shared_ptr task_info_ptr = std::make_shared(kernel_name_, stream_id, label_);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  task_info_list.emplace_back(task_info_ptr);
+  return task_info_list;
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h
new file mode 100644
index 0000000000..8d0cfdfb20
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_set.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H
+
+#include
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class LabelSetKernel : public RtKernel {
+ public:
+  LabelSetKernel();
+  ~LabelSetKernel() override;
+
+  bool Init(const AnfNodePtr &anf_node) override;
+
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+
+ private:
+  uint32_t label_;
+};
+
+MS_REG_RTKERNEL(labelset, LabelSetKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc
new file mode 100644
index 0000000000..ccb49d9497
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/label_switch.h"
+#include
+#include
+#include
+#include "runtime/stream.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+using ge::model_runner::LabelSwitchTaskInfo;
+using LabelSwitchTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+LabelSwitchKernel::LabelSwitchKernel() {
+  label_list_ = {};
+  cond_ = nullptr;
+  label_size_ = 0;
+}
+
+LabelSwitchKernel::~LabelSwitchKernel() {}
+
+bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_LOG(INFO) << "LabelSwitchKernel init";
+  auto cnode = anf_node->cast();
+  if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) {
+    MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
+  }
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  label_list_ = GetValue>(primitive->GetAttr(kAttrLabelSwitchList));
+  label_size_ = label_list_.size();
+  MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_;
+  for (auto label : label_list_) {
+    MS_LOG(INFO) << "label: " << label;
+  }
+  return true;
+}
+
+bool LabelSwitchKernel::Launch(const std::vector & /*inputs*/,
+                               const std::vector & /*workspace*/,
+                               const std::vector & /*outputs*/, void * /*stream_ptr*/) {
+  MS_LOG(INFO) << "LabelSwitchKernel launch";
+  return true;
+}
+
+std::vector LabelSwitchKernel::GenTask(const std::vector &inputs,
+                                       const std::vector &workspace,
+                                       const std::vector &outputs, uint32_t stream_id) {
+  MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id;
+  std::vector task_info_list;
+  cond_ = inputs[0]->addr;
+  auto task_info_ptr = std::make_shared(kernel_name_, stream_id, label_size_, label_list_, cond_);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  task_info_list.emplace_back(task_info_ptr);
+  return task_info_list;
+}
+
+std::vector> LabelSwitchDesc::GetKernelInfo() {
+  std::vector> label_switch_build_info{};
+  vector input_format{kOpFormat_DEFAULT};
+  vector input_type{kNumberTypeInt32};
+  if (input_format.size() != input_type.size()) {
+    MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size "
+                      << input_type.size();
+  }
+  for (size_t i = 0; i < input_format.size(); ++i) {
+    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+    builder.SetInputsFormat({input_format[i]});
+    builder.SetInputsDeviceType({input_type[i]});
+    builder.SetProcessor(AICORE);
+    builder.SetKernelType(RT_KERNEL);
+    builder.SetFusionType(OPAQUE);
+    label_switch_build_info.emplace_back(builder.Build());
+  }
+  return label_switch_build_info;
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h
new file mode 100644
index 0000000000..1860d38d74
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H
+
+#include
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class LabelSwitchKernel : public RtKernel {
+ public:
+  LabelSwitchKernel();
+  ~LabelSwitchKernel() override;
+
+  bool Init(const AnfNodePtr &anf_node) override;
+
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+
+ private:
+  std::vector label_list_;
+  uint32_t label_size_;
+  void *cond_;
+};
+
+class LabelSwitchDesc : public RtKerDesc {
+ public:
+  LabelSwitchDesc() = default;
+  ~LabelSwitchDesc() override = default;
+  std::vector> GetKernelInfo() override;
+};
+
+MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc);
+MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc
new file mode 100644
index 0000000000..ca1114a83f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/memcpy_async.h"
+
+#include
+#include
+
+#include "runtime/mem.h"
+#include "common/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/trans.h"
+#include "utils/context/ms_context.h"
+
+using ge::model_runner::MemcpyAsyncTaskInfo;
+using MemcpyAsyncTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+MemCpyAsyncKernel::MemCpyAsyncKernel() {}
+
+MemCpyAsyncKernel::~MemCpyAsyncKernel() {}
+
+bool MemCpyAsyncKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/,
+                               const std::vector &outputs, void *stream_ptr) {
+  if (inputs.size() != 1) {
+    MS_LOG(ERROR) << "inputs size is not one";
+    return false;
+  }
+  if (outputs.size() != 1) {
+    MS_LOG(ERROR) << "outputs size is not one";
+    return false;
+  }
+
+  if (inputs[0]->addr == outputs[0]->addr) {
+    MS_LOG(INFO) << "input addr is same as output addr, no need to execute memcpy async";
+    return true;
+  }
+  if (outputs[0]->size < inputs[0]->size) {
+    MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
+  }
+  // input x -> memcpy_async -> AllReduce
+  if (outputs[0]->size > inputs[0]->size) {
+    MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
+  }
+  rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
+                                   RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
+  if (status != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!";
+    return false;
+  }
+  return true;
+}
+
+bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  GetInputOutputDataType(anf_node);
+  GetInputOutputTotalCount(anf_node);
+  return true;
+}
+
+void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
+  if (input_size != 1) {
+    MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
+  }
+  input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
+}
+
+void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
+  if (input_size != 1) {
+    MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
+  }
+  size_t type_size = trans::TypeIdSize(input_type_id_);
+  std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
+  size_t total_size = 1;
+  for (size_t i = 0; i < shape_i.size(); i++) {
+    total_size = total_size * shape_i[i];
+  }
+  total_size *= type_size;
+  MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]";
+  input_size_list_.emplace_back(total_size);
+  output_size_list_.emplace_back(total_size);
+}
+
+std::vector MemCpyAsyncKernel::GenTask(const std::vector &inputs,
+                                       const std::vector &,
+                                       const std::vector &outputs, uint32_t stream_id) {
+  if (inputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "MemCpyAsync op input size is not one";
+  }
+
+  if (outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "MemCpyAsync op output size is not one";
+  }
+
+  if (outputs[0]->size < inputs[0]->size) {
+    MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
+  }
+  // input x -> memcpy_async -> AllReduce
+  if (outputs[0]->size > inputs[0]->size) {
+    MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
+  }
+
+  stream_id_ = stream_id;
+  std::shared_ptr task_info_ptr =
+    std::make_shared(kernel_name_, stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr,
+                     inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump());
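+  // The task info mirrors the rtMemcpyAsync arguments used in Launch() above
+  // (dst, destMax, src, count, copy kind) plus a dump flag from NeedDump(), so
+  // the GE runtime can issue the equivalent async copy when the task list is
+  // dispatched (assumption: task-sink execution replays GenTask results
+  // instead of calling Launch()).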
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  return {task_info_ptr};
+}
+
+const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
+                                 kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,
+                                 kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16,
+                                 kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
+const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
+                                 kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0,
+                                 kOpFormat_C1HWNCoC0};
+
+MemCpyAsyncDesc::MemCpyAsyncDesc() {}
+
+MemCpyAsyncDesc::~MemCpyAsyncDesc() {}
+
+std::vector> MemCpyAsyncDesc::GetKernelInfo() {
+  std::vector> memcpy_build_info{};
+  for (const auto &format : format_list) {
+    for (const auto &type : data_type_list) {
+      auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+      vector input_format{format};
+      vector input_type{type};
+      vector output_format{format};
+      vector output_type{type};
+      builder.SetInputsFormat(input_format);
+      builder.SetInputsDeviceType(input_type);
+      builder.SetOutputsFormat(output_format);
+      builder.SetOutputsDeviceType(output_type);
+      builder.SetProcessor(AICORE);
+      builder.SetKernelType(RT_KERNEL);
+      builder.SetFusionType(OPAQUE);
+      memcpy_build_info.emplace_back(builder.Build());
+    }
+  }
+  return memcpy_build_info;
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h
new file mode 100644
index 0000000000..07a782be50
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H
+
+#include
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class MemCpyAsyncKernel : public RtKernel {
+ public:
+  MemCpyAsyncKernel();
+  ~MemCpyAsyncKernel() override;
+
+  bool Init(const AnfNodePtr &anf_node) override;
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+
+ private:
+  void GetInputOutputDataType(const AnfNodePtr &anf_node);
+  void GetInputOutputTotalCount(const AnfNodePtr &anf_node);
+  TypeId input_type_id_{};
+};
+
+class MemCpyAsyncDesc : public RtKerDesc {
+ public:
+  MemCpyAsyncDesc();
+  ~MemCpyAsyncDesc() override;
+  std::vector> GetKernelInfo() override;
+};
+
+MS_REG_RTKERNEL_DESC(memcpy_async, MemCpyAsyncDesc);
+MS_REG_RTKERNEL(memcpy_async, MemCpyAsyncKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc
new file mode 100644
index 0000000000..8213468b48
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/profiling_kernel_mod.h"
+
+#include
+#include
+#include
+
+#include "framework/ge_runtime/task_info.h"
+#include "runtime/device/ascend/profiling/profiling_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo;
+using mindspore::device::ascend::ProfilingUtils;
+
+namespace mindspore {
+namespace kernel {
+bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) {
+  MS_LOG(INFO) << "[profiling] init profiling kernel mod";
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+
+  ValuePtr notify_ptr = primitive->GetAttr(ProfilingUtils::kNotify);
+  MS_EXCEPTION_IF_NULL(notify_ptr);
+
+  ValuePtr log_id_ptr = primitive->GetAttr(ProfilingUtils::kProfilerTraceId);
+  MS_EXCEPTION_IF_NULL(log_id_ptr);
+
+  ValuePtr flags_ptr = primitive->GetAttr(ProfilingUtils::kFlags);
+  MS_EXCEPTION_IF_NULL(flags_ptr);
+
+  notify_ = GetValue(notify_ptr);
+  log_id_ = GetValue(log_id_ptr);
+  flags_ = GetValue(flags_ptr);
+  MS_LOG(INFO) << "[profiling] profiling kernel notify_:" << notify_ << ", log_id_:" << log_id_
+               << ", flags_:" << flags_;
+  return true;
+}
+
+bool ProfilingKernelMod::Launch(const std::vector & /*inputs*/,
+                                const std::vector & /*workspace*/,
+                                const std::vector & /*outputs*/, void * /*stream_ptr*/) {
+  return true;
+}
+
+std::vector ProfilingKernelMod::GenTask(const std::vector &inputs,
+                                        const std::vector &workspace,
+                                        const std::vector &outputs, uint32_t stream_id) {
+  MS_LOG(INFO) << "gen task inputs size:" << inputs.size() << ", workspace size:" << workspace.size()
+               << ", outputs size:" << outputs.size();
+  stream_id_ = stream_id;
+  std::shared_ptr task_info_ptr =
+    std::make_shared(kernel_name_, stream_id, log_id_, notify_, flags_);
+  return {task_info_ptr};
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h
new file mode 100644
index 0000000000..cdb43afb3e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_
+#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+namespace mindspore {
+namespace kernel {
+class ProfilingKernelMod : public RtKernel {
+ public:
+  ProfilingKernelMod() = default;
+  ~ProfilingKernelMod() override = default;
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+  bool Init(const AnfNodePtr &anf_node) override;
+
+ private:
+  uint64_t log_id_{0};
+  bool notify_{true};
+  uint32_t flags_{0};
+};
+MS_REG_RTKERNEL(profiling, ProfilingKernelMod);
+} // namespace kernel
+} // namespace mindspore
+#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc
new file mode 100644
index 0000000000..cee0ef2fdc
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.cc
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/recv.h"
+#include
+#include "runtime/stream.h"
+#include "utils/context/ms_context.h"
+#include "runtime/device/ascend/ascend_stream_assign.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace kernel {
+using ge::model_runner::EventWaitTaskInfo;
+using mindspore::device::ascend::AscendStreamAssign;
+using EventWaitTaskInfoPtr = std::shared_ptr;
+
+RecvKernel::RecvKernel() { event_id_ = 0; }
+
+RecvKernel::~RecvKernel() {}
+
+bool RecvKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) {
+    MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId";
+  }
+  event_id_ = GetValue(primitive->GetAttr(kAttrEventId));
+  MS_LOG(INFO) << "recv op event_id_:" << event_id_;
+  return true;
+}
+
+bool RecvKernel::Launch(const std::vector &inputs, const std::vector &workspace,
+                        const std::vector &outputs, void *stream_ptr) {
+  rtEvent_t stream_event{};
+  auto status = rtStreamWaitEvent(stream_ptr, stream_event);
+  if (status != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!";
+    return false;
+  }
+  return true;
+}
+
+std::vector RecvKernel::GenTask(const std::vector &, const std::vector &,
+                                const std::vector &, uint32_t stream_id) {
+  MS_LOG(INFO) << "RecvKernel GenTask event_id_:" << event_id_ << ", stream_id_:" << stream_id;
+  stream_id_ = stream_id;
+  EventWaitTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  return {task_info_ptr};
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h
new file mode 100644
index 0000000000..73e0214eae
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/recv.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RECV_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_RECV_H
+
+#include
+#include
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class RecvKernel : public RtKernel {
+ public:
+  RecvKernel();
+  ~RecvKernel() override;
+
+  bool Init(const AnfNodePtr &anf_node) override;
+  bool Launch(const std::vector &inputs, const std::vector &workspace,
+              const std::vector &outputs, void *stream_ptr) override;
+  std::vector GenTask(const std::vector &inputs, const std::vector &workspace,
+                      const std::vector &outputs, uint32_t stream_id) override;
+
+ private:
+  uint32_t event_id_;
+};
+
+MS_REG_RTKERNEL(recv, RecvKernel);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_RECV_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc
new file mode 100644
index 0000000000..9279a84cf0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.cc
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void RtKernelFactory::Registe(const std::string &name, RtKernelCreater &&fun) {
+  (void)fmap_.emplace(name, std::move(fun));
+}
+
+std::shared_ptr RtKernelFactory::Create(const std::string &name) {
+  const auto &map = Get().fmap_;
+  auto it = map.find(name);
+  if (it != map.end() && it->second) {
+    return (it->second)();
+  }
+  return nullptr;
+}
+
+RtKernelFactory &RtKernelFactory::Get() {
+  static RtKernelFactory _this;
+  return _this;
+}
+
+RtKernel::RtKernel() {}
+
+RtKernel::~RtKernel() {}
+
+bool RtKernel::Init(const mindspore::AnfNodePtr & /*anf_node*/) { return true; }
+
+const std::vector &RtKernel::GetInputSizeList() const { return input_size_list_; }
+
+const std::vector &RtKernel::GetOutputSizeList() const { return output_size_list_; }
+
+const std::vector &RtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h
new file mode 100644
index 0000000000..dc0aa3e283
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel.h
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H
+
+#include
+#include
+#include
+#include
+#include
+#include "backend/kernel_compiler/ascend_kernel_mod.h"
+#include "backend/kernel_compiler/task_stream.h"
+
+namespace mindspore {
+namespace kernel {
+class RtKernel : public AscendKernelMod {
+ public:
+  RtKernel();
+  ~RtKernel() override;
+  virtual bool Init(const AnfNodePtr &anf_node);
+  const std::vector &GetInputSizeList() const override;
+  const std::vector &GetOutputSizeList() const override;
+  const std::vector &GetWorkspaceSizeList() const override;
+
+ protected:
+  mutable std::vector input_size_list_;
+  mutable std::vector output_size_list_;
+  mutable std::vector workspace_size_list_;
+};
+
+using RTKernelPtr = std::shared_ptr;
+
+using RtKernelCreater = std::function()>;
+class RtKernelFactory {
+  RtKernelFactory() = default;
+  ~RtKernelFactory() = default;
+
+ public:
+  static RtKernelFactory &Get();
+  void Registe(const std::string &name, RtKernelCreater &&fun);
+  static std::shared_ptr Create(const std::string &name);
+
+ private:
+  std::map fmap_;
+};
+
+class _RtKernelRegister {
+ public:
+  _RtKernelRegister(const std::string &name, RtKernelCreater &&fun) {
+    RtKernelFactory::Get().Registe(name, std::move(fun));
+  }
+  ~_RtKernelRegister() = default;
+};
+
+#define _MS_REG_RTKERNEL_REG(KNAME, clazz) \
+  static_assert(std::is_base_of::value, " must be base of RtKernel"); \
+  static const _RtKernelRegister g_##KNAME##_##_RtKernel_reg(#KNAME, []() { return std::make_shared(); });
+
+#define MS_REG_RTKERNEL(KNAME, clazz) _MS_REG_RTKERNEL_REG(KNAME, clazz)
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc
new file mode 100644
index 0000000000..9704a9b97f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/rt_kernel_build.h"
+
+#include
+#include
+#include
+#include
+
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+KernelModPtr RtOpBuild(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
+  (void)std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
+  MS_LOG(INFO) << "Op Name(tolower)[" << op_name << "]";
+  auto ker_ptr = RtKernelFactory::Create(op_name);
+  MS_EXCEPTION_IF_NULL(ker_ptr);
+  if (!ker_ptr->Init(anf_node)) {
+    MS_LOG(ERROR) << "Rt Op initialize failed!";
+    return nullptr;
+  }
+
+  return ker_ptr;
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h
new file mode 100644
index 0000000000..ccfb8d923b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_build.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H
+
+#include
+#include
+#include "backend/kernel_compiler/kernel.h"
+namespace mindspore {
+namespace kernel {
+KernelModPtr RtOpBuild(const AnfNodePtr &anf_node);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc
new file mode 100755
index 0000000000..9501aed5f2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+#include
+#include
+#include "utils/convert_utils.h"
+#include "utils/utils.h"
+#include "common/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) {
+  if (fmap_.find(name) == fmap_.end()) {
+    (void)fmap_.emplace(name, std::move(fun));
+  }
+}
+
+std::shared_ptr RtKerDescFactory::Create(const std::string &name) {
+  const auto &map = Get().fmap_;
+  auto it = map.find(name);
+  if (it != map.end() && it->second) {
+    return (it->second)();
+  }
+  return nullptr;
+}
+
+RtKerDescFactory &RtKerDescFactory::Get() {
+  static RtKerDescFactory _this;
+  return _this;
+}
+
+static bool IsDefaultKernelInfo(const std::string &name) {
+  static const std::set white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName,
+                                      kLabelGotoOpName};
+  return white_list.find(name) != white_list.end();
+}
+
+void GetRtKelInfo(const CNodePtr &kernel_node,
+                  std::vector> *kernel_info_list) {
+  MS_EXCEPTION_IF_NULL(kernel_info_list);
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node);
+  (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower);
+
+  auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower);
+  if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) {
+    *kernel_info_list = ker_desc_ptr->GetKernelInfo();
+    return;
+  }
+  // If the kernel info can't be found in the kernel info database, use the default kernel info.
+  auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (IsDefaultKernelInfo(node_name)) {
+    auto kernel_build_info_builder = std::make_shared();
+    // set input infos
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    kernel_build_info_builder->SetInputsFormat(std::vector(input_num, kOpFormat_DEFAULT));
+    std::vector input_types = {};
+    for (size_t i = 0; i < input_num; i++) {
+      input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
+    }
+    kernel_build_info_builder->SetInputsDeviceType(input_types);
+    // set output info
+    auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    kernel_build_info_builder->SetOutputsFormat(std::vector(output_num, kOpFormat_DEFAULT));
+    kernel_build_info_builder->SetOutputsDeviceType(std::vector(output_num, TypeId::kTypeUnknown));
+    // set other info
+    kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
+    kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
+    kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
+    kernel_info_list->push_back(kernel_build_info_builder->Build());
+    return;
+  }
+  MS_LOG(DEBUG) << "Rt does not have op [" << opNameLower << "].";
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h
new file mode 100644
index 0000000000..6048fb3779
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ir/dtype.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace kernel {
+class RtKerDesc {
+ public:
+  virtual ~RtKerDesc() {}
+  virtual std::vector> GetKernelInfo() {
+    return std::vector>{};
+  }
+};
+
+using RtKerDescCreater = std::function()>;
+class RtKerDescFactory {
+  RtKerDescFactory() = default;
+  ~RtKerDescFactory() = default;
+
+ public:
+  static RtKerDescFactory &Get();
+  void Register(const std::string &name, RtKerDescCreater &&fun);
+  static std::shared_ptr Create(const std::string &name);
+
+ private:
+  std::map fmap_;
+};
+
+class _RtKerDescRegister {
+ public:
+  _RtKerDescRegister(const std::string &name, RtKerDescCreater &&fun) {
+    RtKerDescFactory::Get().Register(name, std::move(fun));
+  }
+  ~_RtKerDescRegister() = default;
+};
+
+#define _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) \
+  static_assert(std::is_base_of::value, " must be base of RtKerDesc"); \
+  static const _RtKerDescRegister g_##KNAME##_##_rtkernel_desc_reg(#KNAME, []() { return std::make_shared(); });
+
+#define MS_REG_RTKERNEL_DESC(KNAME, clazz) _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz)
+
+void GetRtKelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc
new file mode 100644
index 0000000000..11c0a7d668
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/send.h"
+#include
+#include "runtime/event.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+using ge::model_runner::EventRecordTaskInfo;
+using EventRecordTaskInfoPtr = std::shared_ptr;
+
+namespace mindspore {
+namespace kernel {
+SendKernel::SendKernel() { event_id_ = 0; }
+
+SendKernel::~SendKernel() {}
+
+bool SendKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) {
+    MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId";
+  }
+  event_id_ = GetValue(primitive->GetAttr(kAttrEventId));
+  MS_LOG(INFO) << "send op event id:" << event_id_;
+  return true;
+}
+
+bool SendKernel::Launch(const std::vector &inputs, const std::vector &workspace,
+                        const std::vector &outputs, void *stream_ptr) {
+  rtEvent_t event{};
+  rtError_t status = rtEventRecord(event, stream_ptr);
+  if (status != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Send op rtEventRecord failed!";
+    return false;
+  }
+  return true;
+}
+
+std::vector SendKernel::GenTask(const std::vector &, const std::vector &,
+                                const std::vector &, uint32_t stream_id) {
+  MS_LOG(INFO) << "SendKernel GenTask event id:" << event_id_ << ", stream id:" << stream_id;
+  stream_id_ = stream_id;
+  EventRecordTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_);
+  MS_EXCEPTION_IF_NULL(task_info_ptr);
+  return {task_info_ptr};
+}
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h
new file mode 100644
index 0000000000..dbadb1ef44
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/send.h b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h
new file mode 100644
index 0000000000..dbadb1ef44
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/send.h
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_RTS_SEND_H
+#define MINDSPORE_CCSRC_KERNEL_RTS_SEND_H
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/rts/rt_kernel.h"
+#include "backend/kernel_compiler/rts/rt_kernel_info.h"
+
+namespace mindspore {
+namespace kernel {
+class SendKernel : public RtKernel {
+ public:
+  SendKernel();
+  ~SendKernel() override;
+  bool Init(const AnfNodePtr &anf_node) override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+
+ private:
+  uint32_t event_id_;
+};
+
+MS_REG_RTKERNEL(send, SendKernel);
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_RTS_SEND_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc
new file mode 100644
index 0000000000..e33549973d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.cc
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/rts/stream_active.h"
+#include <memory>
+#include <vector>
+#include "runtime/stream.h"
+#include "framework/ge_runtime/task_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+
+using ge::model_runner::StreamActiveTaskInfo;
+using StreamActiveTaskInfoPtr = std::shared_ptr<StreamActiveTaskInfo>;
+
+namespace mindspore {
+namespace kernel {
+StreamActiveKernel::StreamActiveKernel() { active_streams_index_ = {}; }
+
+StreamActiveKernel::~StreamActiveKernel() {}
+
+bool StreamActiveKernel::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_LOG(INFO) << "Stream active op init start";
+  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (!AnfAlgo::HasNodeAttr(kAttrActiveStreamList, anf_node->cast<CNodePtr>())) {
+    MS_LOG(EXCEPTION) << "StreamActiveKernel has no attr kAttrActiveStreamList";
+  }
+  active_streams_index_ = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList));
+  return true;
+}
+
+bool StreamActiveKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  MS_LOG(INFO) << "Stream active op launch start";
+
+  if (active_streams_index_.empty()) {
+    MS_LOG(ERROR) << "active_streams_index_ is empty!";
+    return false;
+  }
+
+  rtStream_t act_stream;
+  rtError_t status;
+  for (auto index : active_streams_index_) {
+    act_stream = kernel::TaskStream::GetInstance()->gen_stream_list()[index];
+    status = rtStreamActive(act_stream, stream_ptr);
+    if (status != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "Stream active failed!";
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<TaskInfoPtr> StreamActiveKernel::GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                                     const std::vector<AddressPtr> &, uint32_t stream_id) {
+  MS_LOG(INFO) << "StreamActiveKernel
GenTask active stream size:" << active_streams_index_.size() + << ", stream id:" << stream_id; + stream_id_ = stream_id; + std::vector task_info_list; + for (auto &index : active_streams_index_) { + std::shared_ptr task_info_ptr = + std::make_shared(kernel_name_, stream_id, index); + MS_EXCEPTION_IF_NULL(task_info_ptr); + task_info_list.emplace_back(task_info_ptr); + MS_LOG(INFO) << "StreamActiveKernel GenTask: streamId:" << stream_id << ", Active streamId:" << index; + } + return task_info_list; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h new file mode 100644 index 0000000000..409c3437dc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_active.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class StreamActiveKernel : public RtKernel { + public: + StreamActiveKernel(); + ~StreamActiveKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + std::vector active_streams_index_; +}; + +MS_REG_RTKERNEL(streamactive, StreamActiveKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc new file mode 100644 index 0000000000..5fe03b1960 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/rts/stream_switch.h" + +#include +#include + +#include "runtime/stream.h" +#include "framework/ge_runtime/task_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +using ge::model_runner::StreamSwitchTaskInfo; +using StreamSwitchTaskInfoPtr = std::shared_ptr; + +namespace mindspore { +namespace kernel { +StreamSwitchKernel::StreamSwitchKernel() { + cond_ = RT_EQUAL; + true_stream_index_ = 0; + data_type_ = RT_SWITCH_INT32; +} + +StreamSwitchKernel::~StreamSwitchKernel() {} + +bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "stream switch op init start"; + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition"; + } + cond_ = tagRtCondition(GetValue(primitive->GetAttr(kAttrSwitchCondition))); + if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream"; + } + true_stream_index_ = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); + if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast())) { + MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType"; + } + data_type_ = tagRtSwitchDataType(GetValue(primitive->GetAttr(kAttrDataType))); + MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ + << ", data_type_:" << static_cast(data_type_); + return true; +} + +bool StreamSwitchKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + MS_LOG(INFO) << "stream switch op launch start"; + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2"; + } + + void *loop_cnt = inputs[0]->addr; + void *ites_per_loop = inputs[1]->addr; + rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_]; + rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream_ptr, data_type_); + if (status != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Stream switch failed!"; + return false; + } + return true; +} + +std::vector StreamSwitchKernel::GenTask(const std::vector &inputs, + const std::vector &, const std::vector &, + uint32_t stream_id) { + MS_LOG(INFO) << "StreamSwitchKernel GenTask start"; + if (inputs.size() != 2) { + MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two"; + } + stream_id_ = stream_id; + MS_EXCEPTION_IF_NULL(inputs[0]); + MS_EXCEPTION_IF_NULL(inputs[1]); + auto loop_cnt = inputs[0]->addr; + auto ites_per_loop = inputs[1]->addr; + MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ + << ", stream_id:" << stream_id; + std::shared_ptr task_info_ptr = std::make_shared( + kernel_name_, stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_); + MS_EXCEPTION_IF_NULL(task_info_ptr); + return {task_info_ptr}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h new file mode 100644 index 0000000000..64a51f68bf --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/stream_switch.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H +#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H + +#include +#include +#include "backend/kernel_compiler/rts/rt_kernel.h" +#include "backend/kernel_compiler/rts/rt_kernel_info.h" + +namespace mindspore { +namespace kernel { +class StreamSwitchKernel : public RtKernel { + public: + StreamSwitchKernel(); + ~StreamSwitchKernel() override; + + bool Init(const AnfNodePtr &anf_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, uint32_t stream_id) override; + + private: + rtCondition_t cond_; + uint32_t true_stream_index_; + rtSwitchDataType_t data_type_; +}; + +MS_REG_RTKERNEL(streamswitch, StreamSwitchKernel); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H diff --git a/mindspore/ccsrc/kernel/task_stream.h b/mindspore/ccsrc/backend/kernel_compiler/task_stream.h similarity index 100% rename from mindspore/ccsrc/kernel/task_stream.h rename to mindspore/ccsrc/backend/kernel_compiler/task_stream.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc new file mode 100644 index 0000000000..449a9f4556 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -0,0 +1,424 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/tbe/tbe_adapter.h" + +#include +#include +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/opinfo.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +static std::map tbe_func_adapter_map = { + {"softmax", "softmax_v2"}, + {"log_softmax", "log_softmax_v2"}, + {"apply_momentum", "apply_momentum_d"}, + {"apply_ftrl", "apply_ftrl_d"}, + {"re_lu6", "relu6"}, + {"re_lu6_grad", "relu6_grad"}, + {"re_lu", "relu"}, + {"re_luv2", "relu_v2"}, + {"p_re_lu", "prelu"}, + {"p_re_lu_grad", "prelu_grad"}, + {"tensor_add", "add"}, + {"reduce_mean", "reduce_mean_d"}, + {"reduce_max", "reduce_max_d"}, + {"reduce_min", "reduce_min_d"}, + {"avg_pool_grad", "avg_pool_grad_d"}, + {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, + {"conv2d_backprop_input", "conv2d_backprop_input_d"}, + {"depthwise_conv2d_native", "depthwise_conv2d"}, + {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"}, + {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"}, + {"scatter_nd", "scatter_nd_d"}, + {"tile", "tile_d"}, + {"gather_v2", "gather_v2_d"}, + {"sparse_gather_v2", "gather_v2_d"}, + {"batch_mat_mul", "batch_matmul"}, + {"b_n_training_reduce", "bn_training_reduce"}, + {"b_n_training_update", "bn_training_update"}, + {"b_n_training_update_v2", "bn_training_update_v2"}, + {"b_n_training_update_v3", "bn_training_update_v3"}, + {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, + {"b_n_training_update_grad", "bn_training_update_grad"}, + {"b_n_infer", "bn_infer"}, + {"b_n_infer_grad", "bn_infer_grad"}, + {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, + {"n_pu_get_float_status", "n_p_u_get_float_status"}, + {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, + {"dropout_do_mask", "drop_out_do_mask"}, + {"strided_slice", "strided_slice_d"}, + {"strided_slice_grad", "strided_slice_grad_d"}, + {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, + {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, + {"apply_ada_max", "apply_ada_max_d"}, + {"apply_adadelta", "apply_adadelta_d"}, + {"apply_adagrad", "apply_adagrad_d"}, + {"apply_adagrad_v2", "apply_adagradv2_d"}, + {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, + {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, + {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, + {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, + {"apply_add_sign", "apply_add_sign_d"}, + {"apply_power_sign", "apply_power_sign_d"}, + {"transpose", "transpose_d"}, + {"fill", "fill_d"}, + {"unsorted_segment_sum", "unsorted_segment_sum_d"}, + {"unsorted_segment_prod", "unsorted_segment_prod_d"}, + {"concat", "concat_d"}, + {"slice", "slice_d"}, + {"reduce_sum", "reduce_sum_d"}, + {"inplace_add", "inplace_add_d"}, + {"inplace_sub", "inplace_sub_d"}, + {"one_hot", "one_hot_d"}, + {"sum", "reduce_sum_d"}, + {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, + {"lamb_next_mv", "lamb_next_m_v"}, + {"split", "split_d"}, + {"split_v", "split_v_d"}, + {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, + {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, + {"pad", "pad_d"}, + {"argmax", "arg_max_d"}, + {"argmin", "arg_min_d"}, + {"space_to_batch", "space_to_batch_d"}, + {"batch_to_space", "batch_to_space_d"}, + {"space_to_batch_nd", "space_to_batch_nd_d"}, + {"batch_to_space_nd", "batch_to_space_nd_d"}, + {"resize_bilinear", 
"resize_bilinear_v2_d"}, + {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, + {"adam", "apply_adam_d"}, + {"r_oi_align", "roi_align"}, + {"r_oi_align_grad", "roi_align_grad"}, + {"i_ou", "iou"}, + {"s_gd", "sgd"}, + {"l_rn", "lrn"}, + {"l_rn_grad", "lrn_grad"}, + {"l_ars_update", "lars_v2_update"}, + {"n_ms_with_mask", "nms_with_mask"}, + {"square_sum_all", "square_sum_all"}, + {"cum_sum", "cumsum_d"}, + {"range", "range_d"}, + {"lin_space", "lin_space_d"}, + {"inv_grad", "inv_grad"}, + {"apply_rms_prop", "apply_rms_prop_d"}, + {"cum_prod", "cumprod_d"}, + {"reduce_all", "reduce_all_d"}, + {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, + {"unsorted_segment_min", "unsorted_segment_min_d"}, + {"reduce_prod", "reduce_prod_d"}, + {"a_cos", "acos"}, + {"a_cos_grad", "acos_grad"}, + {"histogram_fixed_width", "histogram_fixed_width_d"}, + {"broadcast_to", "broadcast_to_d"}, + {"inplace_update", "inplace_update_d"}, + {"matrix_diag", "matrix_diag_d"}, + {"matrix_diag_part", "matrix_diag_part_d"}, + {"matrix_set_diag", "matrix_set_diag_d"}}; + +void TbeAdapter::NormalizeFuncName(std::string *func_name) { + if (func_name == nullptr) { + MS_LOG(EXCEPTION) << "func_name is null"; + } + std::string name_tmp; + bool sub_head = false; + for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) { + if (islower(*iter)) { + sub_head = false; + } + if (isdigit(*iter)) { + sub_head = true; + } + if (isupper(*iter) && iter != func_name->begin()) { + if (!sub_head) { + (void)name_tmp.insert(name_tmp.end(), '_'); + sub_head = true; + } else { + string::iterator iter_next = iter + 1; + if (iter_next != func_name->end()) { + if (islower(*iter_next)) { + (void)name_tmp.insert(name_tmp.end(), '_'); + } + } + } + } + (void)name_tmp.insert(name_tmp.end(), *iter); + } + (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower); + *func_name = name_tmp; + auto iter = tbe_func_adapter_map.find(*func_name); + if (iter != tbe_func_adapter_map.end()) { + MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second; + *func_name = iter->second; + } +} + +void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) { + std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0); + std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0); + if (input_format == kOpFormat_DEFAULT) { + input_format = kOpFormat_NCHW; + } + if (output_format == kOpFormat_DEFAULT) { + output_format = kOpFormat_NCHW; + } + AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node); + AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node); + } +} + +std::unordered_set input_order_adjusted_ops = { + "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", + "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; + +void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, + nlohmann::json *inputs_json) { + MS_EXCEPTION_IF_NULL(inputs_json); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + inputs_json->push_back(inputs_list[2]); + inputs_json->push_back(inputs_list[0]); + inputs_json->push_back(inputs_list[1]); + 
for (size_t i = 3; i < inputs_list.size(); ++i) { + inputs_json->push_back(inputs_list[i]); + } + } else if (op_name == "ApplyCenteredRMSProp") { + // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map + // TBE parameter to correspond python API parameter by latter's index using hardcode + inputs_json->push_back(inputs_list[0]); + inputs_json->push_back(inputs_list[1]); + inputs_json->push_back(inputs_list[2]); + inputs_json->push_back(inputs_list[3]); + inputs_json->push_back(inputs_list[5]); + inputs_json->push_back(inputs_list[6]); + inputs_json->push_back(inputs_list[7]); + inputs_json->push_back(inputs_list[8]); + inputs_json->push_back(inputs_list[4]); + } else { + inputs_json->push_back(inputs_list[1]); + inputs_json->push_back(inputs_list[0]); + for (size_t i = 2; i < inputs_list.size(); ++i) { + inputs_json->push_back(inputs_list[i]); + } + } + } +} + +void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, + std::vector *inputs_json) { + MS_EXCEPTION_IF_NULL(inputs_json); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + inputs_json->emplace_back(inputs_list[2]); + inputs_json->emplace_back(inputs_list[0]); + inputs_json->emplace_back(inputs_list[1]); + for (size_t i = 3; i < inputs_list.size(); ++i) { + inputs_json->emplace_back(inputs_list[i]); + } + } else { + inputs_json->emplace_back(inputs_list[1]); + inputs_json->emplace_back(inputs_list[0]); + for (size_t i = 2; i < inputs_list.size(); ++i) { + inputs_json->emplace_back(inputs_list[i]); + } + } + } +} + +void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, + std::vector *reorder_data_layer) { + MS_EXCEPTION_IF_NULL(reorder_data_layer); + if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { + (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer))); + } else { + if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { + reorder_data_layer->emplace_back(data_layer[2]); + reorder_data_layer->emplace_back(data_layer[0]); + reorder_data_layer->emplace_back(data_layer[1]); + for (size_t i = 3; i < data_layer.size(); ++i) { + reorder_data_layer->emplace_back(data_layer[i]); + } + } else { + reorder_data_layer->emplace_back(data_layer[1]); + reorder_data_layer->emplace_back(data_layer[0]); + for (size_t i = 2; i < data_layer.size(); ++i) { + reorder_data_layer->emplace_back(data_layer[i]); + } + } + } +} + +std::map TbeAdapter::build_json_attr_pass_map_ = { + {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, + {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, + {"Cast", TbeAdapter::CastAttrJsonPass}}; + +bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(attrs_json); + auto cnode_name = AnfAlgo::GetCNodeName(anf_node); + auto FPass = build_json_attr_pass_map_.find(cnode_name); + if (FPass != build_json_attr_pass_map_.end()) { + FPass->second(anf_node, op_info_attrs, attrs_json); + return true; + } + return false; +} + +void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + 
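+  // Annotation (not part of the original patch): each registered attr is
+  // serialized as a small JSON object, roughly
+  //   {"name": <attr name>, "value": <bool>, "valid": true}
+  // when the primitive carries the attr, or {"name": <attr name>, "valid": false}
+  // when it does not; the loop below builds one such object per attr.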
MS_EXCEPTION_IF_NULL(attrs_json); + auto attr_num = op_info_attrs.size(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (size_t i = 0; i < attr_num; i++) { + nlohmann::json attr_obj; + MS_EXCEPTION_IF_NULL(op_info_attrs[i]); + std::string attr_name = op_info_attrs[i]->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + bool attr_value = GetValue(value); + attr_obj["value"] = attr_value; + attr_obj["valid"] = true; + } else { + attr_obj["valid"] = false; + } + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + } + MS_LOG(INFO) << "MaximumGradAttrJsonPass done."; +} + +void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + auto attr_num = op_info_attrs.size(); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (size_t i = 0; i < attr_num; i++) { + nlohmann::json attr_obj; + MS_EXCEPTION_IF_NULL(op_info_attrs[i]); + std::string attr_name = op_info_attrs[i]->name(); + auto value = primitive->GetAttr(attr_name); + if (value != nullptr) { + bool attr_value = GetValue(value); + attr_obj["value"] = attr_value; + attr_obj["valid"] = true; + } else { + attr_obj["valid"] = false; + } + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + } + MS_LOG(INFO) << "MinimumGradAttrJsonPass done."; +} + +static int TypeStrToDstType(const std::string &type_str) { + int ret = -1; + if (type_str == "Float" || type_str == "Float32") { + ret = 0; + } else if (type_str == "Float16") { + ret = 1; + } else if (type_str == "Int8") { + ret = 2; + } else if (type_str == "Int32") { + ret = 3; + } else if (type_str == "UInt8") { + ret = 4; + } else if (type_str == "UInt64") { + ret = 10; + } else if (type_str == "Bool") { + ret = 12; + } else { + MS_LOG(INFO) << "Error type str is invailed: " << type_str; + } + return ret; +} + +void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(attrs_json); + if (op_info_attrs.size() != 1) { + MS_LOG(INFO) << "cast node should has dst_type attr"; + return; + } + auto attr_name = op_info_attrs[0]->name(); + auto type_ptr = std::make_shared(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0))); + MS_EXCEPTION_IF_NULL(type_ptr); + auto type_element = type_ptr->element(); + MS_EXCEPTION_IF_NULL(type_element); + auto dtype = type_element->ToString(); + auto dst_type_value = TypeStrToDstType(dtype); + nlohmann::json attr_obj; + attr_obj["value"] = dst_type_value; + attr_obj["valid"] = true; + attr_obj["name"] = attr_name; + attrs_json->push_back(attr_obj); + MS_LOG(INFO) << "CastAttrJsonPass done."; +} + +void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, + size_t real_input_index, std::vector *input_list, + mindspore::kernel::kCreaterType creater_type) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_list); + auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); + size_t last_dim = input_x_shape[input_x_shape.size() - 1]; + std::vector tensor_shape = {last_dim}; + std::vector tensor_origin_shape = {last_dim}; + std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast(real_input_index)); + if (tensor_format == kOpFormat_DEFAULT) { + tensor_format = 
kOpFormat_NCHW; + } + std::string tensor_origin_format = kOpFormat_NCHW; + std::string tensor_dtype = "float16"; + nlohmann::json input_desc_json; + input_desc_json["dtype"] = tensor_dtype; + input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node); + input_desc_json["ori_shape"] = tensor_origin_shape; + input_desc_json["ori_format"] = tensor_origin_format; + input_desc_json["shape"] = tensor_shape; + if (creater_type == OP_SELECT_FORMAT) { + input_desc_json["format"] = tensor_origin_format; + } else { + input_desc_json["format"] = tensor_format; + } + input_desc_json["valid"] = true; + input_list->emplace_back(input_desc_json); +} +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h new file mode 100644 index 0000000000..aa09efc11f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H + +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "base/base.h" +#include "backend/kernel_compiler/oplib/opinfo.h" +// Note: This file is mainly used to adapt the ME front-end operator description and +// the TBE back-end operator implementation difference +namespace mindspore { +namespace kernel { +enum kCreaterType : int { SINGLE_BUILD = 0, PREBUILD, OP_SELECT_FORMAT, CHECK_SUPPORTED, OP_PRE_COMPILE }; +namespace tbe { +using FAttrsPass = void (*)(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); +class TbeAdapter { + public: + TbeAdapter() = default; + ~TbeAdapter() = default; + static void NormalizeFuncName(std::string *func_name); + static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node); + static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, + nlohmann::json *inputs_json); + static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + static void GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, size_t real_input_index, + std::vector *input_list, kCreaterType creater_type); + + static void FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, + std::vector *inputs_json); + static void FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, + std::vector *reorder_data_layer); + + private: + static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + static void MinimumGradAttrJsonPass(const AnfNodePtr &anf_node, + const std::vector> &op_info_attrs, + nlohmann::json *attrs_json); + + static void CastAttrJsonPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, + nlohmann::json 
*attrs_json); + + static std::map build_json_attr_pass_map_; +}; +} // namespace tbe +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc new file mode 100644 index 0000000000..e7fd94ef84 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +const std::unordered_map type_str_id_maps = { + {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, + {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, + {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, + {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, + {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, + {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, + {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, + {"bool", TypeId::kNumberTypeBool}, +}; + +const std::map type_id_str_maps = { + {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, + {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, + {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, + {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, + {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, + {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, + {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, + {TypeId::kNumberTypeBool, "int8"}, +}; + +const std::map type_str_maps = { + {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, + {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, + {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, +}; + +const std::unordered_map type_nbyte_maps = { + {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, + {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, + {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, + {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, +}; + +const std::unordered_map fusion_type_maps = { + {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, + {"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", 
FusionType::OPAQUE}, +}; + +TypeId DtypeToTypeId(const std::string &dtypes) { + auto iter = type_str_id_maps.find(dtypes); + if (iter == type_str_id_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input device dtype: " << dtypes; + } + return iter->second; +} + +std::string TypeIdToString(TypeId type_id) { + auto iter = type_id_str_maps.find(type_id); + if (iter == type_id_str_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id); + } + return iter->second; +} + +size_t GetDtypeNbyte(const std::string &dtypes) { + auto iter = type_nbyte_maps.find(dtypes); + if (iter == type_nbyte_maps.end()) { + MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; + } + return iter->second; +} + +FusionType GetFusionType(const std::string &pattern) { + auto iter = fusion_type_maps.find(pattern); + if (iter == fusion_type_maps.end()) { + MS_LOG(INFO) << "Illegal fusion pattern: " << pattern; + return UNKNOWN_FUSION_TYPE; + } + return iter->second; +} + +std::string GetProcessor(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string device; + switch (AnfAlgo::GetProcessor(anf_node)) { + case Processor::AICORE: + device = kProcessorAiCore; + break; + default: + MS_LOG(INFO) << "Unknown processor type." << anf_node->fullname_with_scope(); + break; + } + return device; +} +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h new file mode 100644 index 0000000000..dea058cd56 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_convert_utils.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "base/base.h" +#include "ir/dtype/type.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +constexpr auto kProcessorAiCore = "aicore"; +TypeId DtypeToTypeId(const std::string &dtypes); + +std::string TypeIdToString(TypeId type_id); + +size_t GetDtypeNbyte(const std::string &dtypes); + +FusionType GetFusionType(const std::string &pattern); + +std::string GetProcessor(const AnfNodePtr &anf_node); +} // namespace tbe +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc new file mode 100644 index 0000000000..73642b291a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -0,0 +1,1019 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include +#include +#include +#include "frontend/operator/ops.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeAdapter; +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kFusionOpList = "op_list"; +constexpr auto kFusionKernelNamePrfix = "te_fusion"; +constexpr auto kOptional = "optional_"; +constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z"; +constexpr auto kPlatform = "platform"; +constexpr auto kPlatTBE = "TBE"; +constexpr auto kGenModel = "gen_model"; +constexpr auto kSingle = "single"; +constexpr auto kImplPath = "impl_path"; +constexpr auto kJInputs = "inputs"; +constexpr auto kJOutputs = "outputs"; +constexpr auto kJAttrs = "attrs"; +constexpr auto kJKernelName = "kernel_name"; +constexpr auto kJOpInfo = "op_info"; +constexpr auto kJDtype = "dtype"; +constexpr auto kJtype = "type"; +constexpr auto kJName = "name"; +constexpr auto kJOriShape = "ori_shape"; +constexpr auto kJOriFormat = "ori_format"; +constexpr auto kJShape = "shape"; +constexpr auto kJFormat = "format"; +constexpr auto kJValid = "valid"; +constexpr auto kJParamType = "param_type"; +constexpr auto kParamDynamic = "dynamic"; +constexpr auto kParamRequred = "required"; +constexpr auto kJDataType = "data_type"; +constexpr auto kJOutputIndex = "output_index"; +constexpr auto kJOutputDesc = "output_desc"; +constexpr auto kJInputDesc = "input_desc"; +constexpr auto kVTypeInt = "int"; +constexpr auto kVTypeStr = "str"; +constexpr auto kVTypeBool = "bool"; +constexpr auto kVTypeFloat = "float"; +constexpr auto kVTypeListInt = "listInt"; +constexpr auto kVTypeInt32 = "Int32"; +constexpr auto kVTypeListUInt64 = "listUInt64"; +constexpr auto kVTypeListFloat = "listFloat"; +constexpr auto kVTypeListListInt = "listListInt"; +constexpr auto kJValue = "value"; +constexpr auto kJDynIndex = "dyn_index"; +constexpr auto kJFuncName = "func_name"; + +std::string NormalizeFullScopeName(const string &full_scope_name) { + // exp:Default/ReLU-op0 -->Default_ReLU_op0 + string normal_ret = full_scope_name; + std::replace(normal_ret.begin(), normal_ret.end(), '/', '_'); + std::replace(normal_ret.begin(), normal_ret.end(), '-', '_'); + return normal_ret; +} + +bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, + nlohmann::json *kernel_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(kernel_json); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); + MS_EXCEPTION_IF_NULL(op_info_ptr); + (*kernel_json)[kPlatform] = kPlatTBE; + (*kernel_json)[kGenModel] = kSingle; + (*kernel_json)[kImplPath] = op_info_ptr->impl_path(); + nlohmann::json op_info_json; + if 
(op_info_ptr->impl_path().empty()) { + tbe::TbeAdapter::NormalizeFuncName(&op_name); + } else { + op_name = op_info_ptr->kernel_name(); + } + op_info_json[kJName] = op_name; + // generate inputs json + nlohmann::json inputs_json; + if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { + MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed"; + return false; + } + op_info_json[kJInputs] = inputs_json; + // generate outputs json + nlohmann::json outputs_json; + if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) { + MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed"; + return false; + } + op_info_json[kJOutputs] = outputs_json; + // generate attrs json + nlohmann::json attrs_json; + (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json); + op_info_json[kJAttrs] = attrs_json; + std::string json_str = op_info_json.dump(); + size_t hash_id = std::hash()(json_str); + json_name_ = op_name + "_" + std::to_string(hash_id); + json_info_ = json_str; + if (creater_type_ == PREBUILD) { + op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope()); + } else { + op_info_json[kJKernelName] = json_name_; + } + (*kernel_json)[kJOpInfo] = op_info_json; + if (creater_type_ == SINGLE_BUILD) { + TbeUtils::SaveJsonInfo(json_name_, json_info_); + } + + MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope() + << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump(); + + return true; +} + +bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, + bool value, const std::shared_ptr &input_ptr, + const string &op_input_name, size_t input_i, + std::vector *input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(input_list); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { + TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); + } else { + auto dtype = GetDeviceInputType(anf_node, real_input_index); + auto format = GetDeviceInputFormat(anf_node, real_input_index); + auto shape = GetDeviceInputShape(anf_node, real_input_index); + auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); + if (ori_shape.empty()) { + ori_shape.emplace_back(1); + } + nlohmann::json input_desc_json; + input_desc_json[kJDtype] = dtype; + input_desc_json[kJName] = op_input_name + std::to_string(input_i); + input_desc_json[kJOriShape] = ori_shape; + input_desc_json[kJOriFormat] = kOpFormat_NCHW; + input_desc_json[kJShape] = shape; + input_desc_json[kJFormat] = format; + input_desc_json[kJValid] = value; + input_desc_json[kJParamType] = input_ptr->param_type(); + input_list->emplace_back(input_desc_json); + } + return true; +} + +bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, + const std::shared_ptr &input_ptr, size_t *real_input_index, + string *op_input_name, std::vector *input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(real_input_index); + MS_EXCEPTION_IF_NULL(op_input_name); + MS_EXCEPTION_IF_NULL(input_list); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node); + bool value = true; + for (size_t 
input_i = 0; input_i < input_tensor_num; input_i++) { + if (*real_input_index >= real_input_num) { + if (input_ptr->param_type() == "optional") { + *op_input_name = input_ptr->name() + "_optional_"; + nlohmann::json input_desc_json; + input_desc_json[kJValid] = false; + input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index); + input_list->emplace_back(input_desc_json); + continue; + } + MS_LOG(ERROR) << "Input num: " << *real_input_index << " is not match op inputs"; + return false; + } + if (op_name == "BatchNorm") { + if (input_ptr->name() == "mean" || input_ptr->name() == "variance") { + auto attr = primitive->GetAttr("is_training"); + MS_EXCEPTION_IF_NULL(attr); + bool is_training = GetValue(attr); + MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training " + << is_training; + if (is_training) { + (*real_input_index)++; + break; + } + } + } + bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list); + (*real_input_index)++; + if (!ret) { + return false; + } + } + return true; +} + +bool GetInputNameAndRealNum(const std::shared_ptr &anf_node, const std::shared_ptr &input_ptr, + size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(input_ptr); + MS_EXCEPTION_IF_NULL(dyn_input_index); + MS_EXCEPTION_IF_NULL(input_num); + MS_EXCEPTION_IF_NULL(op_input_name); + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. + std::vector dyn_input_sizes; + if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { + dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); + } + + if (input_ptr->param_type() == kParamDynamic) { + if (*dyn_input_index >= dyn_input_sizes.size()) { + MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size(); + return false; + } + *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]); + *op_input_name = input_ptr->name() + "_dynamic_"; + (*dyn_input_index)++; + // if optional input is exist + } else { + *input_num = 1; + *op_input_name = input_ptr->name() + "_"; + } + return true; +} + +bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *inputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(inputs_json); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kAtomicAddrCleanOpName) { + return true; + } + std::vector> inputs_ptr = op_info->inputs_ptr(); + if (inputs_ptr.empty()) { + MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info"; + return true; + } + auto op_info_input_num = inputs_ptr.size(); + size_t dyn_input_index = 0; + size_t real_input_index = 0; + std::vector> inputs_list; + for (size_t i = 0; i < op_info_input_num; i++) { + size_t input_tensor_num; + std::shared_ptr input_ptr = inputs_ptr[i]; + std::string op_input_name; + MS_EXCEPTION_IF_NULL(input_ptr); + if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) { + return false; + } + std::vector input_list; + if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) { + return false; + } + inputs_list.emplace_back(input_list); + } + + TbeAdapter::InputOrderPass(op_name, inputs_list, 
inputs_json); + return true; +} + +bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *outputs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(outputs_json); + auto op_name = AnfAlgo::GetCNodeName(anf_node); + if (op_name == kAtomicAddrCleanOpName) { + return true; + } + auto outputs_ptr = op_info->outputs_ptr(); + return GenOutputDescJson(anf_node, outputs_ptr, outputs_json); +} + +bool TbeKernelJsonCreator::GenOutputDescJson( + const std::shared_ptr &anf_node, + const std::vector> &outputs_ptr, nlohmann::json *outputs_json) { + MS_EXCEPTION_IF_NULL(outputs_json); + size_t output_idx = 0; + auto op_name = AnfAlgo::GetCNodeName(anf_node); + size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + for (const auto &output_ptr : outputs_ptr) { + size_t output_obj_num = 0; + if (output_ptr->param_type() == kParamRequred) { + output_obj_num = 1; + } else if (output_ptr->param_type() == kParamDynamic) { + if (outputs_ptr.size() > 1) { + MS_LOG(ERROR) << "Dynamic output is unsupported multi output!"; + return false; + } + output_obj_num = real_output_num; + } else { + if (output_idx >= real_output_num) { + MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none."; + std::vector output_list; + nlohmann::json output_obj; + output_obj[kJName] = output_ptr->name(); + output_obj[kJValid] = false; + output_list.emplace_back(output_obj); + (*outputs_json).push_back(output_list); + continue; + } else { + output_obj_num = 1; + } + } + std::vector output_list; + GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list); + (*outputs_json).push_back(output_list); + } + return true; +} + +void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, + const std::shared_ptr &output_ptr, size_t *output_idx, + std::vector *output_list) { + MS_EXCEPTION_IF_NULL(output_idx); + MS_EXCEPTION_IF_NULL(output_list); + for (size_t i = 0; i < output_obj_num; i++) { + auto dtype = GetDeviceOutputType(anf_node, *output_idx); + auto format = GetDeviceOutputFormat(anf_node, *output_idx); + auto shape = GetDeviceOutputShape(anf_node, *output_idx); + std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); + if (ori_shape.empty()) { + ori_shape.emplace_back(1); + } + nlohmann::json output_obj; + output_obj[kJDtype] = dtype; + output_obj[kJShape] = shape; + output_obj[kJFormat] = format; + output_obj[kJOriShape] = ori_shape; + output_obj[kJOriFormat] = kOpFormat_NCHW; + output_obj[kJName] = output_ptr->name(); + output_obj[kJValid] = true; + output_obj[kJParamType] = output_ptr->param_type(); + output_list->emplace_back(output_obj); + (*output_idx)++; + } +} + +bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_node, + const std::shared_ptr &op_info, nlohmann::json *attrs_json) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(op_info); + MS_EXCEPTION_IF_NULL(attrs_json); + auto attrs_ptr = op_info->attrs_ptr(); + std::string op_name = AnfAlgo::GetCNodeName(anf_node); + if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) { + return true; + } + auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); + MS_EXCEPTION_IF_NULL(primitive); + for (const auto &attr_ptr : attrs_ptr) { + std::string attr_name = attr_ptr->name(); + nlohmann::json attr_obj; + attr_obj[kJName] = attr_name; + if (op_name == parallel::LAYER_NORM && 
attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) { + continue; + } + if (primitive->GetAttr(attr_name) != nullptr) { + auto value = primitive->GetAttr(attr_name); + std::string type = attr_ptr->type(); + ParseAttrValue(type, value, &attr_obj); + attr_obj[kJValid] = true; + } else { + if (op_info->impl_path().empty()) { + attr_obj[kJValid] = false; + } else { + if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) { + MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name + << " is required, but not set."; + } else { + attr_obj[kJValid] = false; + } + } + } + (*attrs_json).push_back(attr_obj); + } + return true; +} + +void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, + nlohmann::json *attr_obj) { + MS_EXCEPTION_IF_NULL(value); + MS_EXCEPTION_IF_NULL(attr_obj); + if (type == kVTypeInt) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeStr) { + auto attr_value = GetValue(value); + if (attr_value == kOpFormat_FRAC_Z) { + attr_value = kOpFormat_FRACTAL_Z; + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeBool) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeFloat) { + auto attr_value = GetValue(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListInt) { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == kVTypeInt32) { + int data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListFloat) { + std::vector attr_value; + auto value_type = value->type(); + MS_EXCEPTION_IF_NULL(value_type); + auto value_type_str = value_type->ToString(); + if (value_type_str == kVTypeFloat) { + auto data = GetValue(value); + attr_value.push_back(data); + } else { + attr_value = GetValue>(value); + } + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListUInt64) { + auto attr_value = GetValue>(value); + (*attr_obj)[kJValue] = attr_value; + } else if (type == kVTypeListListInt) { + auto attr_value = GetValue>>(value); + (*attr_obj)[kJValue] = attr_value; + } else { + MS_LOG(EXCEPTION) << "Type: " << type << "not support"; + } +} + +std::vector TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && 
creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetInputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + +std::vector TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::vector shape; + if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { + shape = AnfAlgo::GetOutputInferShape(anf_node, real_index); + } else { + shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index); + } + if (shape.empty()) { + shape.emplace_back(1); + } + return shape; +} + +std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + TypeId type_id; + if (creater_type_ == OP_SELECT_FORMAT) { + type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index); + } else { + type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index); + } + return tbe::TypeIdToString(type_id); +} + +std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { + MS_EXCEPTION_IF_NULL(anf_node); + std::string format = kOpFormat_NCHW; + if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { + format = AnfAlgo::GetOutputFormat(anf_node, real_index); + if (format == kOpFormat_FRAC_Z) { + format = kOpFormat_FRACTAL_Z; + } else if (format == kOpFormat_DEFAULT) { + format = kOpFormat_NCHW; + } + } + return format; +} + +bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, + std::vector *output_size_list) { + if (input_size_list == nullptr || output_size_list == nullptr) { + MS_LOG(ERROR) << "Input size or output size is nullptr"; + return false; + } + input_size_list->clear(); + output_size_list->clear(); + for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) { + for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) { + size_t size_i = 1; + if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) { + std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName]; + MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false."; + continue; + } + for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) { + size_i *= static_cast(j); + } + std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype]; + size_t nbyte = tbe::GetDtypeNbyte(dtype); + size_i *= nbyte; + input_size_list->push_back(size_i); + } + } + for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) { + for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) { + size_t size_i = 1; + if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) { + std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName]; + MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false."; + continue; + } + for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) { + size_i *= static_cast(j); + } + std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype]; + size_t nbyte = tbe::GetDtypeNbyte(dtype); + size_i *= nbyte; + output_size_list->push_back(size_i); + } + } + return true; +} + +bool TbeKernelBuild::GenFusionScopeJson(const std::vector &input_nodes, + const std::vector &compute_nodes, + nlohmann::json *fusion_str, std::string *fusion_kernel) { + 
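+  // Annotation (not part of the original patch): the fusion JSON built here is
+  // a flat "op_list" holding the data-input descriptors first and the
+  // compute-op descriptors after them, while *fusion_kernel accumulates the
+  // "te_fusion"-prefixed kernel name as compute ops are appended.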
+  MS_EXCEPTION_IF_NULL(fusion_str);
+  MS_EXCEPTION_IF_NULL(fusion_kernel);
+  // get input layer info
+  std::vector<std::vector<mindspore::AnfNodePtr>> input_layers;
+  std::map<const AnfNodePtr, FusionDataType> spec_data_input;
+  if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) {
+    return false;
+  }
+  // gen fusion scope_op json
+  std::vector<nlohmann::json> compute_list;
+  (*fusion_kernel) = kFusionKernelNamePrfix;
+  // index: running index of fusion data inputs, starting from 0
+  static size_t index = 0;
+  auto layer_iter = input_layers.begin();
+  auto compute_op_iter = compute_nodes.begin();
+  for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) {
+    nlohmann::json compute_op_str;
+    (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index);
+    compute_list.push_back(compute_op_str);
+  }
+  index = 0;
+  // gen data input json
+  std::vector<nlohmann::json> data_list;
+  for (const auto &layer : input_layers) {
+    for (const auto &data_input : layer) {
+      nlohmann::json data_str;
+      if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) {
+        MS_LOG(INFO) << "Fusion error: gen fusion data input json failed.";
+        return false;
+      }
+      data_list.push_back(data_str);
+    }
+  }
+  index = 0;
+  data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
+  (*fusion_str)[kFusionOpList] = data_list;
+  return true;
+}
+
+void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
+                                 size_t desc_output_idx, nlohmann::json *output_desc,
+                                 FusionDataType fusion_data_type) {
+  std::string output_desc_name = anf_node->fullname_with_scope();
+  if (node_out_idx > 0) {
+    output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
+  }
+  (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
+  auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
+  (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
+  auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
+  if (ori_shape.empty()) {
+    ori_shape.emplace_back(1);
+  }
+  (*output_desc)[kJOriShape] = ori_shape;
+  auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx);
+  if (shape.empty()) {
+    shape.emplace_back(1);
+  }
+  (*output_desc)[kJShape] = shape;
+  auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
+  if (format == kOpFormat_DEFAULT) {
+    format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
+  }
+  (*output_desc)[kJFormat] = format;
+  (*output_desc)[kJOriFormat] = kOpFormat_NCHW;
+  (*output_desc)[kJOutputIndex] = desc_output_idx;
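+  // Specially tagged fusion inputs get their 5-D NC1HWC0 device shape folded
+  // to 4-D below (H and W merged); ReLUGradV2 masks additionally use a fixed
+  // last dimension of 16 and a bool data type.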
+  if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
+    std::vector<size_t> spec_shape = {};
+    spec_shape.emplace_back(shape[0]);
+    spec_shape.emplace_back(shape[1]);
+    spec_shape.emplace_back(shape[2] * shape[3]);
+    spec_shape.emplace_back(shape[4]);
+    (*output_desc)[kJShape] = spec_shape;
+  } else if (fusion_data_type == kFusionReLUGradV2) {
+    std::vector<size_t> spec_shape = {};
+    spec_shape.emplace_back(shape[0]);
+    spec_shape.emplace_back(shape[1]);
+    spec_shape.emplace_back(shape[2] * shape[3]);
+    spec_shape.emplace_back(16);
+    (*output_desc)[kJShape] = spec_shape;
+    (*output_desc)[kJDataType] = kVTypeBool;
+  }
+}
+
+void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
+                                         size_t output_index, nlohmann::json *output_desc) {
+  std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
+  (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
+  (*output_desc)[kJOutputIndex] = output_index;
+  std::vector<size_t> shape;
+  (*output_desc)[kJShape] = shape;
+}
+
+bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
+                                        const std::vector<mindspore::AnfNodePtr> &reorder_layer,
+                                        std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
+  if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) {
+    MS_LOG(INFO) << "Fusion error: node (" << op_name << ")'s input is null.";
+    return false;
+  }
+  MS_LOG(INFO) << "Fusion info: op_name: " << op_name << ", input layer size: " << reorder_layer.size();
+  if (op_name == kReluGradV2OpName) {
+    (*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2;
+  } else if (op_name == kAddNOpName) {
+    for (const auto &it : reorder_layer) {
+      (*spec_data_input)[it] = kFusionAddN;
+    }
+  }
+  return true;
+}
+
+bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
+                                    const std::vector<mindspore::AnfNodePtr> &compute_nodes,
+                                    std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
+                                    std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
+  MS_EXCEPTION_IF_NULL(input_layers);
+  MS_EXCEPTION_IF_NULL(spec_data_input);
+  auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) {
+    auto op_name = AnfAlgo::GetCNodeName(it);
+    return op_name == kConv2DBackpropInputOpName;
+  });
+  bool need_spec = (result != compute_nodes.end());
+  size_t input_size = 0;
+  for (const auto &compute_node : compute_nodes) {
+    std::vector<mindspore::AnfNodePtr> layer = {};
+    std::vector<mindspore::AnfNodePtr> reorder_layer = {};
+    MS_EXCEPTION_IF_NULL(compute_node);
+    auto op_name = AnfAlgo::GetCNodeName(compute_node);
+    auto ccompute_node = compute_node->cast<CNodePtr>();
+    if (ccompute_node == nullptr) {
+      MS_LOG(INFO) << "Fusion error: fusion compute node must be a cnode.";
+      return false;
+    }
+    MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope();
+    for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) {
+      auto input = ccompute_node->input(i);
+      auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input);
+      if (find_iter != input_nodes.end()) {
+        MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope();
+        layer.emplace_back((*find_iter));
+      } else {
+        MS_LOG(INFO) << "Fusion warning: this input [" << i << "] may be the output of a pre-compute node ("
+                     << input->fullname_with_scope() << ").";
+      }
+    }
+    TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer);
+    if (need_spec) {
+      MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... pattern.";
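+      // Conv2DBackpropInput fusion needs specially shaped data inputs, so tag
+      // the matching layer entries in spec_data_input; GenDescJson later
+      // rewrites the shapes of the tagged inputs.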
patten."; + if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) { + return false; + } + } + input_size += reorder_layer.size(); + input_layers->emplace_back(reorder_layer); + } + if (input_nodes.size() != input_size) { + MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size + << ", input_node:" << input_nodes.size(); + return false; + } + return true; +} + +bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr &data_input, + const std::map &spec_data_input, + nlohmann::json *data_str, size_t *index) { + MS_EXCEPTION_IF_NULL(data_str); + MS_EXCEPTION_IF_NULL(index); + std::vector output_desc_list; + if (!data_input) { + MS_LOG(INFO) << "Data input is optional node"; + auto name = std::string(kOptional) + std::to_string(*index); + (*data_str)[kJName] = name; + nlohmann::json output_desc; + output_desc[kJName] = name; + output_desc[kJShape] = "NULL"; + output_desc_list.push_back(output_desc); + (*index)++; + } else { + FusionDataType fusion_data_type = kFusionNormal; + if (spec_data_input.find(data_input) != spec_data_input.end()) { + fusion_data_type = spec_data_input.at(data_input); + } + auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0); + auto real_node = kernel_idx.first; + size_t real_idx = kernel_idx.second; + MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx; + // kJOutputDesc + nlohmann::json output_desc; + GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type); + output_desc_list.push_back(output_desc); + (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope()); + } + (*data_str)[kJOutputDesc] = output_desc_list; + (*data_str)[kJtype] = "Data"; + return true; +} + +bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. 
+  bool ret = false;
+  std::vector<int> dyn_input_sizes;
+  auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes);
+  if (dynamic_input_attr != nullptr) {
+    dyn_input_sizes = GetValue<std::vector<int>>(dynamic_input_attr);
+    auto real_input_size = cnode->inputs().size() - 1;
+    auto dyn_input_size = dyn_input_sizes.size();
+    if (dyn_input_size != 1) {
+      MS_LOG(INFO) << "Fusion error: fusion build does not support dyn_input_sizes > 1.";
+      return ret;
+    }
+    if (IntToSize(dyn_input_sizes[0]) != real_input_size) {
+      MS_LOG(INFO) << "Fusion error: dyn_input_size " << dyn_input_sizes[0] << " is not equal to real_input_size "
+                   << real_input_size;
+      return ret;
+    }
+    ret = true;
+  }
+  return ret;
+}
+
+size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (is_dynamic_input) {
+    return 0;
+  }
+  auto node_name = AnfAlgo::GetCNodeName(cnode);
+  auto op_info = OpLib::FindOp(node_name, kTBE);
+  MS_EXCEPTION_IF_NULL(op_info);
+  if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
+    MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
+  }
+  return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
+}
+
+std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
+  static std::map<std::string, std::string> buffer_fusion_op_map = {
+    {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
+  std::string result = origin_type;
+  auto iter = buffer_fusion_op_map.find(origin_type);
+  if (iter != buffer_fusion_op_map.end()) {
+    result = iter->second;
+  }
+  return result;
+}
+
+bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
+                                               std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
+                                               std::vector<nlohmann::json> *input_desc_list, size_t *index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(input_desc_list);
+  std::vector<nlohmann::json> input_desc_list_tmp = {};
+  bool is_dynamic_input = IsDynamicInput(cnode);
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    auto input = cnode->input(i);
+    auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
+    auto real_node = kernel_idx.first;
+    size_t real_idx = kernel_idx.second;
+    MS_LOG(INFO) << "Real name: " << real_node->fullname_with_scope() << ", index: " << real_idx;
+    nlohmann::json input_desc;
+    GenDescJson(real_node, real_idx, real_idx, &input_desc);
+    if (is_dynamic_input) {
+      MS_LOG(INFO) << "Node has dynamic input.";
+      input_desc[kJDynIndex] = (i - 1);
+    }
+    input_desc_list_tmp.emplace_back(input_desc);
+  }
+  size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
+  if (optional_num > 0) {
+    MS_LOG(INFO) << "Node has optional input.";
+    for (size_t i = 0; i < optional_num; ++i) {
+      nlohmann::json optional_input_desc;
+      optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
+      (*index)++;
+      (*layer_iter)->emplace_back(nullptr);
+      input_desc_list_tmp.emplace_back(optional_input_desc);
+    }
+  }
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list);
+  return true;
+}
+
+std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &output_used_nums) {
+  std::vector<size_t> desc_output_index = {};
+  for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
+    auto output_use_num_item = output_used_nums[idx];
+    MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item;
+    desc_output_index.emplace_back(idx);
+    if (output_use_num_item > 1) {
+      desc_output_index.emplace_back(idx);
+    }
+
} + return desc_output_index; +} + +bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list) { + MS_EXCEPTION_IF_NULL(output_desc_list); + auto output_size = AnfAlgo::GetOutputTensorNum(cnode); + if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { + auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); + if (output_used_nums.size() != output_size) { + MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" + << " is not match output used num(" << output_used_nums.size() << ")"; + return false; + } + auto desc_output_index = GetDescOutputIndex(output_used_nums); + for (size_t i = 0; i < output_size; ++i) { + MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; + nlohmann::json output_desc; + GenDescJson(cnode, i, desc_output_index[i], &output_desc); + output_desc_list->emplace_back(output_desc); + } + for (size_t j = output_size; j < desc_output_index.size(); ++j) { + MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j]; + nlohmann::json output_desc; + GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc); + output_desc_list->emplace_back(output_desc); + } + } else { + for (size_t i = 0; i < output_size; ++i) { + nlohmann::json output_desc; + GenDescJson(cnode, i, i, &output_desc); + output_desc_list->push_back(output_desc); + } + } + return true; +} + +bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, + std::vector>::iterator *layer_iter, + nlohmann::json *compute_op_str, std::string *fusion_kernel_name, + size_t *index) { + MS_EXCEPTION_IF_NULL(compute_node); + auto cnode = compute_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // gen input desc + std::vector input_desc_list; + (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); + (*compute_op_str)[kJInputDesc] = input_desc_list; + // gen output desc + std::vector output_desc_list; + if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { + MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope(); + return false; + } + (*compute_op_str)[kJOutputDesc] = output_desc_list; + // gen others + auto origin_type = AnfAlgo::GetCNodeName(cnode); + // replace special op type for buffer fusion op + auto type = GetRealOpType(origin_type); + (*compute_op_str)[kJtype] = type; + tbe::TbeAdapter::NormalizeFuncName(&type); + (*compute_op_str)[kJFuncName] = type; + (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope()); + (void)(*fusion_kernel_name).append("_"); + (void)(*fusion_kernel_name).append(type); + return true; +} + +size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) { + size_t ret = 1; + for (const auto &shape_item : desc[kJShape]) { + ret *= static_cast(shape_item); + } + std::string data_type = desc[kJDataType]; + size_t nbyte = tbe::GetDtypeNbyte(data_type); + ret *= nbyte; + return ret; +} + +bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, + const std::vector &output_nodes, + std::vector *input_size_list, std::vector *output_size_list) { + MS_EXCEPTION_IF_NULL(input_size_list); + MS_EXCEPTION_IF_NULL(output_size_list); + input_size_list->clear(); + output_size_list->clear(); + + for (const auto &op : fusion_op_list) { + if (op[kJtype] == "Data") { + const auto &data_output_desc = 
op[kJOutputDesc]; + for (const auto &data_output : data_output_desc) { + if (data_output[kJShape] == "NULL") { + break; + } + auto ret = GetIOSizeImpl(data_output); + input_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope input name: " << op[kJName] << ", size: " << ret; + } + } + } + + for (const auto &output_node : output_nodes) { + auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0); + auto real_node = kernel_idx.first; + size_t real_idx = kernel_idx.second; + auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope()); + MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx; + for (const auto &op : fusion_op_list) { + if (op[kJName] == normal_name) { + auto op_output_desces = op[kJOutputDesc]; + if (output_node != real_node) { + // tuple_get item + MS_LOG(INFO) << "Output is a tuple getitem node"; + auto output_desc = op_output_desces[real_idx]; + if (output_desc[kJShape].empty()) { + MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx; + return false; + } + auto ret = GetIOSizeImpl(output_desc); + output_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret; + } else { + for (const auto &output_desc : op_output_desces) { + if (output_desc[kJShape].empty()) { + MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output"; + continue; + } + auto ret = GetIOSizeImpl(output_desc); + output_size_list->push_back(ret); + MS_LOG(INFO) << "Fusion info: scope output size: " << ret; + } + } + } + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h new file mode 100644 index 0000000000..768f811055 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -0,0 +1,122 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "ir/dtype.h" +#include "backend/kernel_compiler/kernel.h" +#include "pybind11/stl.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" + +namespace mindspore { +namespace kernel { +// kernel operate type used for generate json + +class TbeKernelBuild { + enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 }; + + public: + static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, + std::vector *output_size_list); + // Ub Fuison + static bool GenFusionScopeJson(const std::vector &input_nodes, + const std::vector &compute_nodes, nlohmann::json *fusion_str, + std::string *fusion_kernel); + static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector &output_nodes, + std::vector *input_size_list, std::vector *output_size_list); + + private: + TbeKernelBuild() = default; + ~TbeKernelBuild() = default; + static bool GenFusionDataInputJson(const std::shared_ptr &data_input, + const std::map &spec_data_input, + nlohmann::json *data_str, size_t *index); + static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, + std::vector>::iterator *layer_iter, + nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); + static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index); + static std::vector GetDescOutputIndex(const std::vector &output_used_nums); + static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list); + static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc, + FusionDataType fusion_data_type = kFusionNormal); + static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, + size_t output_index, nlohmann::json *output_desc); + static size_t GetIOSizeImpl(const nlohmann::json &desc); + static bool GetSpecInputLayers(const std::string &op_name, const std::vector &reorder_layer, + std::map *spec_data_input); + static bool GetInputLayers(const std::vector &input_nodes, + const std::vector &compute_nodes, + std::vector> *input_layers, + std::map *spec_data_input); + static bool IsDynamicInput(const CNodePtr &cnode); + static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); + static std::string GetRealOpType(const std::string &origin_type); +}; + +class TbeKernelJsonCreator { + public: + explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {} + ~TbeKernelJsonCreator() = default; + bool GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json); + std::string json_name() { return json_name_; } + + private: + bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *inputs_json); + bool GenTbeOutputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *outputs_json); + bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, + nlohmann::json *attrs_json); + static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); + bool GenInputDescJson(const std::shared_ptr 
&anf_node, size_t real_input_index, bool value, + const std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i, + std::vector *input_list); + bool GenOutputDescJson(const std::shared_ptr &anf_node, + const std::vector> &outputs_ptr, nlohmann::json *outputs_json); + bool GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, + const std::shared_ptr &input_ptr, size_t *real_input_index, string *op_input_name, + std::vector *input_list); + void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, + const std::shared_ptr &output_ptr, size_t *output_idx, + std::vector *output_list); + std::vector GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + std::vector GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; + std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; + + kCreaterType creater_type_; + std::string json_name_; + std::string json_info_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc new file mode 100644 index 0000000000..e6cb4cf30d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" +#include +#include "runtime/rt.h" +#include "utils/context/ms_context.h" +#include "graphengine/inc/framework/ge_runtime/task_info.h" + +namespace mindspore { +namespace kernel { +using TbeTaskInfoPtr = std::shared_ptr; +using tbe::KernelManager; +bool TbeKernelMod::Launch(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + if (stream_ptr == nullptr) { + MS_LOG(ERROR) << "stream_ptr should not be nullptr."; + return false; + } + + if (kernel_pack_ == nullptr) { + MS_LOG(ERROR) << "kernel pack should not be nullptr."; + return false; + } + + uint32_t blockdim = 1; // default blockdim equal to 1. + auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim); + if (func_stub == 0) { + MS_LOG(ERROR) << "GenFuncStub failed."; + return false; + } + + // pack all addresses into a vector. 
+ std::vector runtimeargs; + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), + [](const AddressPtr &output) -> void * { return output->addr; }); + if (!workspace.empty()) { + (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs), + [](const AddressPtr &addr) -> void * { return addr->addr; }); + } + rtL2Ctrl_t *l2ctrl = nullptr; + const void *stubFunc = reinterpret_cast(func_stub); + auto argsSize = static_cast(UlongToUint(sizeof(void *)) * runtimeargs.size()); + if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) { + MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; + return false; + } + + return true; +} + +std::vector TbeKernelMod::GenTask(const std::vector &inputs, + const std::vector &workspaces, + const std::vector &outputs, uint32_t stream_id) { + if (kernel_pack_ == nullptr) { + MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr."; + } + + std::vector args; + std::vector sm_desc; + std::vector meta_data; + std::vector input_data_addrs; + std::vector output_data_addrs; + std::vector workspace_addrs; + + // pack all addresses into a vector. + (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), + [](const AddressPtr &output) -> void * { return output->addr; }); + if (!workspaces.empty()) { + (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs), + [](const AddressPtr &workspace) -> void * { return workspace->addr; }); + } + + stream_id_ = stream_id; + auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim_); + if (funcstub == 0) { + MS_EXCEPTION(ArgumentError) << "GenFuncStub failed."; + } + + std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); + + MS_LOG(INFO) << "block_dim is:" << block_dim_; + + TbeTaskInfoPtr task_info_ptr = make_shared( + kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs, + output_data_addrs, workspace_addrs, NeedDump()); + return {task_info_ptr}; +} + +vector TbeKernelMod::GenParameters() { + auto kernel_json_info = kernel_pack_->kernel_json_info(); + return kernel_json_info.parameters; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h new file mode 100644 index 0000000000..de48c83d9b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +class TbeKernelMod : public AscendKernelMod { + public: + explicit TbeKernelMod(KernelPackPtr kernel_pack) : kernel_pack_(std::move(kernel_pack)) {} + ~TbeKernelMod() override = default; + + void SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } + void SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } + void SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &inputs, const std::vector &workspaces, + const std::vector &outputs, uint32_t stream_id) override; + std::vector GenParameters() override; + + private: + KernelPackPtr kernel_pack_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using TbeKernelModPtr = std::shared_ptr; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc new file mode 100644 index 0000000000..48223f40c6 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -0,0 +1,326 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" + +#include +#include +#include +#include +#include +#include + +#include "utils/context/ms_context.h" +#include "backend/kernel_compiler/tbe/tbe_adapter.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "./common.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; +constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; +constexpr auto kStartCompileOp = "start_compile_op"; +constexpr auto kWaitOne = "wait_one"; +constexpr auto kResetTaskInfo = "reset_task_info"; + +bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + for (const auto &anf_node : anf_nodes) { + // gen kernel json + MS_EXCEPTION_IF_NULL(anf_node); + nlohmann::json kernel_json; + TbeKernelJsonCreator creator(OP_PRE_COMPILE); + if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + kernel_json["compile_type"] = "pre_build"; + // op build + auto task_id = build_manger->StartCompileOp(kernel_json); + build_manger->SavePreTaskInfo(task_id, anf_node); + } + while (!build_manger->IsAllPreTaskFinish()) { + int task_id = -1; + char *task_result = nullptr; + char *pre_build_result = nullptr; + auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + + if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result; + } + + build_manger->PreTaskFinishProcess(task_id, pre_build_result); + } + return true; +} + +bool TbeOpParallelBuild(const std::vector &anf_nodes) { + auto build_manger = std::make_shared(); + MS_EXCEPTION_IF_NULL(build_manger); + set processed_kernel; + for (const auto &anf_node : anf_nodes) { + // gen kernel json + tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node); + if (AnfAlgo::GetKernelMod(anf_node) != nullptr) { + continue; + } + const std::string &processor = tbe::GetProcessor(anf_node); + nlohmann::json kernel_json; + TbeKernelJsonCreator creator(SINGLE_BUILD); + if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { + MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; + return false; + } + // get size + std::vector input_size_list; + std::vector output_size_list; + (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); + // search cache + const std::string &json_name = creator.json_name(); + if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { + MS_LOG(INFO) << "Use cached kernel, kernel json name:." 
<< json_name; + continue; + } + // same op not need build, but need wait build finish to set kernel mode + if (processed_kernel.find(json_name) != processed_kernel.end()) { + build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list); + continue; + } + (void)processed_kernel.insert(json_name); + // op build + auto task_id = build_manger->StartCompileOp(kernel_json); + build_manger->SaveTaskInfo(task_id, anf_node, json_name, input_size_list, output_size_list); + } + while (!build_manger->IsAllTaskFinish()) { + int task_id = -1; + char *task_result = nullptr; + char *pre_build_result = nullptr; + auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + if (!ret) { + MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; + } + + if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { + MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; + } + (void)build_manger->TaskFinishProcess(task_id); + } + return build_manger->GenSameOpKernelMod(); +} + +ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); } + +ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); } + +int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const { + PyObject *pRes = nullptr; + PyObject *pArgs = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); + (void)PyTuple_SetItem(pArgs, 0, arg1); + pRes = PyObject_CallMethod(tbe_parallel_compiler_, kStartCompileOp, "O", pArgs); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function start_compile_op"; + } + int task_id; + (void)PyArg_Parse(pRes, "i", &task_id); + MS_LOG(INFO) << "start compile , task id:" << task_id; + return task_id; +} + +bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const { + MS_LOG(INFO) << "wait task start."; + MS_EXCEPTION_IF_NULL(task_id); + MS_EXCEPTION_IF_NULL(task_result); + PyObject *pRes = nullptr; + PyObject *pArg = Py_BuildValue("()"); + pRes = PyObject_CallMethod(tbe_parallel_compiler_, kWaitOne, "O", pArg); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one"; + return false; + } + (void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result); + return true; +} + +void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) { + MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id; + pre_task_map_[task_id] = anf_node; +} + +void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node, + const std::string &json_name, const std::vector &input_size_list, + const std::vector &output_size_list, int32_t scope_id) { + MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id; + struct KernelBuildTaskInfo task_info; + task_info.node = anf_node.get(); + task_info.json_name = json_name; + if (anf_node == nullptr) { + task_info.processor = tbe::kProcessorAiCore; + } else { + task_info.processor = tbe::GetProcessor(anf_node); + } + task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); + task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); + task_info.scope_id = scope_id; + task_map_[task_id] = task_info; +} + +bool ParallelBuildManager::IsAllPreTaskFinish() 
const { + MS_LOG(INFO) << "wait pre build process task_num: " << pre_task_map_.size(); + return pre_task_map_.empty(); +} + +bool ParallelBuildManager::IsAllTaskFinish() const { + MS_LOG(INFO) << "wait process task_num: " << task_map_.size(); + return task_map_.empty(); +} + +void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) { + auto task_iter = pre_task_map_.find(task_id); + if (task_iter == pre_task_map_.end()) { + MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id; + } + auto node = task_iter->second; + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + std::string start_flag = "fusion_pattern_start"; + std::string end_flag = "fusion_pattern_end"; + int start = pre_build_result.find(start_flag); + int end = pre_build_result.find(end_flag); + if (start != -1 && end != -1 && end >= start) { + std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size()); + if (result == "") { + (void)pre_task_map_.erase(task_iter); + return; + } + transform(result.begin(), result.end(), result.begin(), ::toupper); + FusionType fusion_type = tbe::GetFusionType(result); + builder->SetFusionType(fusion_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + (void)pre_task_map_.erase(task_iter); +} + +std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) { + auto task_iter = task_map_.find(task_id); + if (task_iter == task_map_.end()) { + MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id; + } + auto json_name = task_iter->second.json_name; + auto processor = task_iter->second.processor; + auto kernel_pack = TbeUtils::InsertCache(json_name, processor); + if (kernel_pack == nullptr) { + if (set_kernel_mod) { + MS_EXCEPTION(ArgumentError) << "build kernel name:" << task_iter->second.json_name << " failed."; + } else { + MS_LOG(INFO) << "fusion build kernel name:" << task_iter->second.json_name << "failed."; + auto ret = std::make_pair(task_iter->second.scope_id, nullptr); + (void)task_map_.erase(task_iter); + return ret; + } + } + auto kernel_mod = GenKernelMod(json_name, processor, task_iter->second.input_size_list, + task_iter->second.output_size_list, kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod); + if (set_kernel_mod) { + AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); + } + auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); + (void)task_map_.erase(task_iter); + MS_LOG(INFO) << "wait process remain task_num:" << task_map_.size(); + return ret; +} + +void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, + const std::vector &output_size_list) { + struct KernelBuildTaskInfo task_info; + task_info.node = anf_node.get(); + task_info.json_name = json_name; + task_info.processor = tbe::GetProcessor(anf_node); + task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); + task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end()); + same_op_list_.push_back(task_info); +} + +bool ParallelBuildManager::GenSameOpKernelMod() const { + for (const auto &task_info : same_op_list_) { + bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list, + task_info.output_size_list, task_info.node); + if (!ret) { + MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache."; + return false; + } + } + return true; +} + 
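+// Cache lookup: if a kernel with this json_name was already compiled for the
+// given processor, reuse the cached KernelPack and only attach a freshly
+// configured KernelMod (I/O sizes and workspaces) to the node.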
+bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std::string &processor, + const std::vector &input_size_list, + const std::vector &output_size_list, mindspore::AnfNode *node) const { + auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor); + if (cached_kernel_pack != nullptr) { + MS_LOG(INFO) << "Find cached kernel, kernel json name" << json_name; + auto kernel_mod_ptr = GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + AnfAlgo::SetKernelMod(kernel_mod_ptr, node); + return true; + } else { + return false; + } +} + +KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor, + const vector &input_size_list, + const vector &output_size_list, + const mindspore::kernel::KernelPackPtr &kernel_pack) const { + MS_EXCEPTION_IF_NULL(kernel_pack); + auto kernel_json_info = kernel_pack->kernel_json_info(); + auto kernel_mod_ptr = std::make_shared(kernel_pack); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + kernel_mod_ptr->SetInputSizeList(input_size_list); + kernel_mod_ptr->SetOutputSizeList(output_size_list); + kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces); + return kernel_mod_ptr; +} + +void ParallelBuildManager::ResetTaskInfo() { + if (task_map_.empty()) { + MS_LOG(INFO) << "All tasks are compiled success."; + return; + } + task_map_.clear(); + same_op_list_.clear(); + if (tbe_parallel_compiler_ != nullptr) { + PyObject *pArg = Py_BuildValue("()"); + (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg); + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h new file mode 100644 index 0000000000..a29469b47c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "pybind11/stl.h" +#include +namespace mindspore { +namespace kernel { +bool TbeOpParallelPreBuild(const std::vector &anf_nodes); +bool TbeOpParallelBuild(const std::vector &anf_nodes); + +struct KernelBuildTaskInfo { + AnfNode *node; + std::string processor; + std::string json_name; + std::vector input_size_list; + std::vector output_size_list; + int32_t scope_id; +}; + +class ParallelBuildManager { + public: + ParallelBuildManager(); + ~ParallelBuildManager(); + int32_t StartCompileOp(const nlohmann::json &kernel_json) const; + void SavePreTaskInfo(int32_t task_id, const AnfNodePtr &anf_node); + void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, const std::vector &output_size_list, + int32_t scope_id = 0); + void SaveSameOpInfo(const AnfNodePtr &anf_node, const std::string &json_name, + const std::vector &input_size_list, const std::vector &output_size_list); + bool GenSameOpKernelMod() const; + bool SearchInCache(const std::string &json_name, const std::string &processor, + const std::vector &input_size_list, const std::vector &output_size_list, + AnfNode *node) const; + + bool WaitOne(int *task_id, char **task_result, char **pre_build_result) const; + bool IsAllPreTaskFinish() const; + bool IsAllTaskFinish() const; + void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result); + std::pair TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true); + KernelModPtr GenKernelMod(const string &json_name, const string &processor, + const std::vector &input_size_list, const std::vector &output_size_list, + const KernelPackPtr &kernel_pack) const; + void ResetTaskInfo(); + + private: + PyObject *tbe_parallel_compiler_; + std::map pre_task_map_; + std::map task_map_; + std::vector same_op_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h similarity index 100% rename from mindspore/ccsrc/kernel/tbe/tbe_kernel_select/common_utils.h rename to mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc new file mode 100644 index 0000000000..c5e882949b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc @@ -0,0 +1,318 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kInputIndex_0 = 0; +constexpr size_t kChannelN = 0; +constexpr size_t kChannelC = 1; +constexpr size_t kAlignmented16 = 16; +// 1. all shape no scalar and same +// 2. part scalar : no_scalar (shape size > xxx && alig xxx) +// 3. all no_scalar and not same (broad cast xxx dim) +bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) { + MS_EXCEPTION_IF_NULL(support_format); + input_num_ = 0; + output_num_ = 0; + input_shapes_.clear(); + output_shapes_.clear(); + if (AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) { + MS_LOG(INFO) << "This broadcast node has dynamic input."; + auto dynamic_size_vec = AnfAlgo::GetNodeAttr>(cnode_ptr_, kAttrDynInputSizes); + if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) { + MS_LOG(EXCEPTION) << "dynamic attr set error, please check."; + } + auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); + PadScalarShape(&dynamic_input_shape0_); + input_shapes_.emplace_back(dynamic_input_shape0_); + input_num_ = 1; + } else { + input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_); + for (size_t i = 0; i < input_num_; ++i) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i); + PadScalarShape(&input_shape); + input_shapes_.emplace_back(input_shape); + } + } + + output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + for (size_t i = 0; i < output_num_; ++i) { + auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i); + PadScalarShape(&output); + output_shapes_.emplace_back(output); + } + AssignSupportFormat(kOpFormat_DEFAULT, support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_NC1HWC0, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelC] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_NC1HWC0); + } + } + } else { + for (const auto &shape : input_shapes_) { + if (!Is4DShape(shape)) { + return false; + } + } + auto shape_tmp = input_shapes_[0]; + auto broadcast_c_axis = std::any_of( + input_shapes_.begin(), input_shapes_.end(), + [&shape_tmp](const std::vector &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); }); + if (broadcast_c_axis) { + MS_LOG(INFO) << "This node broadcast c channel."; + return false; + } + input_support_format.assign(input_num_, kOpFormat_NC1HWC0); + } + GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + 
AssignSupportFormat(kOpFormat_FRAC_Z, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_FRAC_Z); + } + } + } else { + return false; + } + GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} +bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (!Is4DShape(shape)) { + return false; + } + if (shape[kChannelN] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_C1HWNCoC0); + } + } + } else { + for (const auto &shape : input_shapes_) { + if (!Is4DShape(shape)) { + return false; + } + } + auto shape_tmp = input_shapes_[0]; + auto broadcast_nc_axis = + std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { + return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN)); + }); + if (broadcast_nc_axis) { + MS_LOG(INFO) << "This node broadcast n || c channel."; + return false; + } + input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0); + } + GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (IsSameShape()) { + if (!HasScalarInput()) { + AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); + return true; + } else { + return false; + } + } + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + if (HasScalarInput()) { + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + input_support_format.emplace_back(kOpFormat_DEFAULT); + } else { + if (shape.size() < kShape2dDims) { + return false; + } + if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) { + return false; + } + input_support_format.emplace_back(kOpFormat_FRAC_NZ); + } + } + } else { + auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(), + [](const std::vector &elem) { return elem.size() < kShape2dDims; }); + if (less_2dims) { + MS_LOG(INFO) << "This node dim less 2."; + return false; + } + + auto shape_tmp = input_shapes_[0]; + auto broadcast_last_dim = + std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector &elem) { + return 
(shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) || + (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2)); + }); + if (broadcast_last_dim) { + MS_LOG(INFO) << "This node broadcast last channel."; + return false; + } + + input_support_format.assign(input_num_, kOpFormat_FRAC_NZ); + } + GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); + return true; +} + +bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + return false; +} + +bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector &shape) const { + return shape.size() == kShape4dDims; +} + +bool TbeKernelBroadCastSelecter::IsSameShape() const { + auto shape = input_shapes_.begin(); + for (const auto &item : input_shapes_) { + if (shape->size() != item.size()) { + return false; + } + for (size_t i = 0; i < shape->size(); ++i) { + if (shape->at(i) != item.at(i)) { + return false; + } + } + } + return true; +} + +void TbeKernelBroadCastSelecter::PadScalarShape(std::vector *shape) const { + MS_EXCEPTION_IF_NULL(shape); + if (shape->empty()) { + shape->emplace_back(1); + } +} + +bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector &shape) const { + return (shape.size() == 1 && shape[0] == 1); +} + +bool TbeKernelBroadCastSelecter::HasScalarInput() const { + bool ret = false; + for (const auto &shape : input_shapes_) { + if (IsScalarShape(shape)) { + ret = true; + break; + } + } + return ret; +} + +void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format, + SupportFormatItem *output_support_item) const { + MS_EXCEPTION_IF_NULL(output_support_item); + for (const auto &shape : output_shapes_) { + if (IsScalarShape(shape)) { + output_support_item->emplace_back(kOpFormat_DEFAULT); + } else { + output_support_item->emplace_back(support_format); + } + } +} + +void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + SupportFormatItem input_support_format; + SupportFormatItem output_support_format; + input_support_format.assign(input_num_, support_format_str); + output_support_format.assign(output_num_, support_format_str); + support_format->input_format.emplace_back(input_support_format); + support_format->output_format.emplace_back(output_support_format); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h new file mode 100644 index 0000000000..4685df6724 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +class TbeKernelBroadCastSelecter { + public: + explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} + ~TbeKernelBroadCastSelecter() = default; + bool GetShapeInfo(SupportFormat *support_format); + bool IsBroadCastSupport5HD(SupportFormat *support_format) const; + bool IsBroadCastSupportFracZ(SupportFormat *support_format) const; + bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const; + bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const; + bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const; + + private: + bool IsSameShape() const; + void PadScalarShape(std::vector *shape) const; + bool Is4DShape(const std::vector &shape) const; + bool IsScalarShape(const std::vector &shape) const; + bool HasScalarInput() const; + void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const; + void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; + // broadcast + CNodePtr cnode_ptr_; + size_t input_num_{}; + size_t output_num_{}; + std::vector> input_shapes_; + std::vector> output_shapes_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_TBE_KERNEL_BROADCAST_SELECTER_HELPER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc new file mode 100644 index 0000000000..61aa9dfb91 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" +#include +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kInputIndex_0 = 0; +constexpr size_t kOutputIndex_0 = 0; +constexpr size_t kChannelN = 0; +constexpr size_t kChannelC = 1; +constexpr size_t kReduceNZMinDim = 3; + +bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { + MS_EXCEPTION_IF_NULL(support_format); + input_shape_.clear(); + output_shape_.clear(); + axis_.clear(); + auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); + auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + if (input_num != 1 || output_num != 1) { + MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num + << ", output num: " << output_num; + } + // get input/output shape + input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); + PadScalarShape(&input_shape_); + output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); + PadScalarShape(&output_shape_); + // get keep dim attr + GetReduceAttrKeepDim(); + // get axis attr + axis_ = GetReduceAttrAxis(cnode_ptr_); + AssignSupportFormat(kOpFormat_DEFAULT, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); + if (reduce_c_axis) { + return false; + } + AssignSupportFormat(kOpFormat_NC1HWC0, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + // like to 5HD + return false; +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { + return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); +} + +bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (input_shape_.size() < kReduceNZMinDim) { + return false; + } + if (axis_.empty()) { + return false; + } + auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { + return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); + }); + if (reduce_last_axis) { + return false; + } + AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); + return true; +} + +bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, + mindspore::kernel::SupportFormat *support_format) const { + MS_EXCEPTION_IF_NULL(support_format); + if (!Is4DShape(input_shape_)) { + return false; + } + if (!keep_dims_ || axis_.empty()) { + return false; + } + auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), + [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); + if (reduce_n_c_axis) { + return false; + } + AssignSupportFormat(format, support_format); 
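+  // Reaching this point means keep_dims is set and no reduced axis is N or C,
+  // so the requested format is recorded for both the input and the output.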
+  return true;
+}
+
+void TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
+  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) {
+    MS_LOG(INFO) << "This node doesn't have the keep_dims attr.";
+    keep_dims_ = false;
+    return;
+  }
+  keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kAttrKeepDims);
+}
+
+void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
+                                                  mindspore::kernel::SupportFormat *support_format) const {
+  MS_EXCEPTION_IF_NULL(support_format);
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  input_support_format.emplace_back(support_format_str);
+  output_support_format.emplace_back(support_format_str);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+}
+
+bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const {
+  return shape.size() == kShape4dDims;
+}
+
+void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->empty()) {
+    shape->emplace_back(1);
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
new file mode 100644
index 0000000000..196bb7b06a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ +#include +#include +#include +#include "ir/anf.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" +namespace mindspore { +namespace kernel { +class TbeKernelReduceSelecter { + public: + explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} + ~TbeKernelReduceSelecter() = default; + bool GetShapeInfo(SupportFormat *support_format); + bool IsReduceSupport5HD(SupportFormat *support_format) const; + bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const; + bool IsReduceSupportFracZ(SupportFormat *support_format) const; + bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const; + bool IsReduceSupportFracNZ(SupportFormat *support_format) const; + + private: + bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const; + void GetReduceAttrKeepDim(); + void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; + bool Is4DShape(const std::vector &shape) const; + void PadScalarShape(std::vector *shape) const; + CNodePtr cnode_ptr_; + std::vector input_shape_{}; + std::vector output_shape_{}; + std::vector axis_{}; + bool keep_dims_ = false; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_TBE_KERNEL_REDUCE_SELECTER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc new file mode 100644 index 0000000000..d0563e0ffa --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -0,0 +1,623 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
+#include
+#include
+#include
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
+#include "nlohmann/json.hpp"
+#include "utils/context/ms_context.h"
+#include "backend/kernel_compiler/tbe/tbe_python_funcs.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+constexpr auto kName = "name";
+constexpr auto kDtype = "dtype";
+constexpr auto kFormat = "format";
+constexpr auto kPrefixInput = "input";
+constexpr auto kPrefixOutput = "output";
+constexpr char kParamTypeDynamic[] = "dynamic";
+constexpr char kParamTypeRequre[] = "required";
+constexpr char kParamTypeOptional[] = "optional";
+
+void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
+  auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
+  tbe_selecter.TbeMetadataInfoEx();
+}
+
+TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
+    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
+
+void TbeKernelSelect::TbeMetadataInfoEx() {
+  MS_EXCEPTION_IF_NULL(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(kernel_info_list_);
+  node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
+  auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
+  if (!op_info_ptr) {
+    MS_LOG(INFO) << "Warning: can't find tbe core opinfo, node type: " << node_name_;
+    return;
+  }
+  MS_LOG(INFO) << "Start tbe metadata info, node type: " << node_name_
+               << ", node name: " << cnode_ptr_->fullname_with_scope();
+  OpPattern pattern = op_info_ptr->op_pattern();
+  if (pattern == kCommonPattern) {
+    GetCommonPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kDynamicFormatPattern) {
+    GetDynamicFormatPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kFormatAgnosticPattern) {
+    GetAgnosticPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kBroadcastPattern) {
+    GetBroadcastPatternKernelInfo(*op_info_ptr);
+  } else if (pattern == kReducePattern) {
+    GetReducePatternKernelInfo(*op_info_ptr);
+  } else {
+    MS_LOG(INFO) << "Warning: op pattern is invalid.";
+  }
+  // check support
+  FilterInVaildKernelInfo();
+  MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
+}
+
+void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  // get dynamic inputs
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
+  MS_EXCEPTION_IF_NULL(primitive);
+  std::vector<int> dyn_input_sizes;
+  if (primitive->HasAttr(kAttrDynInputSizes)) {
+    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
+  }
+  // get real input/output num
+  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  const auto inputs_info = op_info.inputs_ptr();
+  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
+  const auto outputs_info = op_info.outputs_ptr();
+  if (inputs_info.empty() && outputs_info.empty()) {
+    MS_LOG(EXCEPTION) << "Op info input & output are both empty, please check.";
+  }
+  // create kernel build info from opinfo
+  size_t kernel_build_info_num =
+    inputs_info.empty() ? outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size();
+  for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num;
+       ++kernel_build_info_index) {
+    auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+    SetTbeBuildCommonInfo(op_info, &builder);
+    std::vector<std::string> inputs_format;
+    std::vector<TypeId> inputs_device_type;
+    std::vector<std::vector<Axis>> inputs_reshape_type;
+    // input
+    if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes,
+                        &inputs_format, &inputs_device_type, &inputs_reshape_type)) {
+      break;
+    }
+    builder.SetInputsDeviceType(inputs_device_type);
+    builder.SetInputsFormat(inputs_format);
+    builder.SetInputReshapeType(inputs_reshape_type);
+    // output
+    std::vector<std::string> outputs_format;
+    std::vector<TypeId> outputs_device_type;
+    std::vector<std::vector<Axis>> outputs_reshape_type;
+    if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes,
+                        &outputs_format, &outputs_device_type, &outputs_reshape_type)) {
+      break;
+    }
+    builder.SetOutputsDeviceType(outputs_device_type);
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputReshapeType(outputs_reshape_type);
+    kernel_info_list_->emplace_back(builder.Build());
+  }
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  OpInfo op_info_new;
+  CreateNewOpInfo(op_info, &op_info_new);
+  GetCommonPatternKernelInfo(op_info_new);
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  if (op_info.inputs_ptr().size() != 1) {
+    MS_LOG(EXCEPTION) << "AgnosticPattern only supports one input.";
+  }
+  auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0);
+  if (kOpFormatList.find(format) == kOpFormatList.end()) {
+    MS_LOG(INFO) << "Got the unknown format " << format;
+    format = kOpFormat_DEFAULT;
+  }
+  SupportFormat support_format;
+  SupportFormatItem input_item;
+  SupportFormatItem output_item;
+  input_item.assign(op_info.inputs_ptr().size(), format);
+  output_item.assign(op_info.outputs_ptr().size(), format);
+  support_format.input_format.emplace_back(input_item);
+  support_format.output_format.emplace_back(output_item);
+  PrintSupportedFormat(support_format);
+  OpInfo op_info_new;
+  CreateNewOpInfo(op_info, support_format, &op_info_new);
+  GetCommonPatternKernelInfo(op_info_new);
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_);
+  SupportFormat support_format;
+  broadcast_selecter.GetShapeInfo(&support_format);
+  if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) {
+    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD.";
+  }
+  if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) {
+    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ.";
+  }
+  if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) {
+    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0.";
+  }
+  if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
+    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
+  }
+  PrintSupportedFormat(support_format);
+  OpInfo op_info_new;
+  CreateNewOpInfo(op_info, support_format, &op_info_new);
+  GetCommonPatternKernelInfo(op_info_new);
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
+  MS_LOG(INFO) << "start.";
+  auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
+  SupportFormat support_format;
+  reduce_selecter.GetShapeInfo(&support_format);
+  if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support 5HD.";
+  }
+  if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracZ.";
+  }
+  if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support C1HWNCoC0.";
+  }
+  if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
+    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracNZ.";
+  }
+  PrintSupportedFormat(support_format);
+  OpInfo op_info_new;
+  CreateNewOpInfo(op_info, support_format, &op_info_new);
+  GetCommonPatternKernelInfo(op_info_new);
+  MS_LOG(INFO) << "end.";
+}
+
+void TbeKernelSelect::FilterInVaildKernelInfo() {
+  if (kernel_info_list_->empty()) {
+    MS_LOG(INFO) << "Warning: get kernel build info failed.";
+    return;
+  }
+  auto kernel_build_info_iter = kernel_info_list_->begin();
+  while (kernel_build_info_iter != kernel_info_list_->end()) {
+    if (!FilterInVaildShape(kernel_build_info_iter)) {
+      MS_LOG(INFO) << "Filter invalid shape, filter item info: " << (*kernel_build_info_iter)->ToString();
+      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+      continue;
+    }
+    if (!TbeCheckSupported(kernel_build_info_iter)) {
+      MS_LOG(INFO) << "Check supported failed, filter item info: " << (*kernel_build_info_iter)->ToString();
+      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
+      continue;
+    }
+    kernel_build_info_iter++;
+  }
+}
+
+bool TbeKernelSelect::FilterInVaildShape(
+  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
+  auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
+  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
+    auto format = kernel_build_info_inputs_format.at(i);
+    if (!IsShapeMatchFormat(shape, format)) {
+      MS_LOG(INFO) << "The " << i << "th input check failed.";
+      return false;
+    }
+  }
+  auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
+  for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
+    auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
+    auto format = kernel_build_info_outputs_format.at(j);
+    if (!IsShapeMatchFormat(shape, format)) {
+      MS_LOG(INFO) << "The " << j << "th output check failed.";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
+  // the default format means all formats are supported.
+  if (format == kOpFormat_DEFAULT) {
+    return true;
+  }
+  static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
+  // an unknown format is an error.
+  if (kOpFormatList.find(format) == kOpFormatList.end()) {
+    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
+  }
+  // the server does not support formats with a C04 suffix.
+  if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
+      kServerNotSupportFormat.end()) {
+    MS_LOG(INFO) << "Warning: Server does not support formats with a C04 suffix.";
+    return false;
+  }
+  // unsupported combinations:
+  // 1. NDHWC with shape size != 5
+  // 2. FRAC_NZ with shape size < 2
+  // 3. non-NDHWC with shape size > 4
+  if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
+      (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
+      (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
+    MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
+    return false;
+  }
+  return true;
+}
+
+bool TbeKernelSelect::TbeCheckSupported(
+  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
+  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
+  static const std::set<std::string> kCheckSupportedOpType = {parallel::MATMUL,
+                                                              parallel::BATCHMATMUL,
+                                                              parallel::TOPK,
+                                                              parallel::IN_TOPK,
+                                                              parallel::PACK,
+                                                              parallel::GATHER_ND,
+                                                              parallel::UNSORTEF_SEGMENT_MIND,
+                                                              parallel::UNSORTEF_SEGMENT_PRODD,
+                                                              parallel::CAST};
+  auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_);
+  if (iter == kCheckSupportedOpType.end()) {
+    return true;
+  }
+  MS_LOG(INFO) << "Check support start.";
+  // replace kernel_info with the current kernel info
+  auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_);
+  AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get());
+  nlohmann::json kernel_json;
+  TbeKernelJsonCreator creator(CHECK_SUPPORTED);
+  bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
+  }
+  ret = TbePythonFuncs::CheckSupported(kernel_json);
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
+  return ret;
+}
+
+void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info,
+                                            mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) {
+  MS_EXCEPTION_IF_NULL(builder);
+  builder->SetProcessor(AICORE);
+  std::string fusion_type = op_info.fusion_type();
+  if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
+    builder->SetFusionType(tbe::GetFusionType(fusion_type));
+  }
+  builder->SetOpPattern(op_info.op_pattern());
+  builder->SetKernelType(TBE_KERNEL);
+}
+
+bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num,
+                                     const std::vector<std::shared_ptr<OpIOInfo>> &ios_info,
+                                     const std::vector<int> &dyn_input_sizes, std::vector<std::string> *formats,
+                                     std::vector<TypeId> *device_types,
+                                     std::vector<std::vector<Axis>> *reshape_types) {
+  MS_EXCEPTION_IF_NULL(formats);
+  MS_EXCEPTION_IF_NULL(device_types);
+  MS_EXCEPTION_IF_NULL(reshape_types);
+  size_t dynamic_input_index = 0;
+  size_t real_io_tensor_index = 0;
+  size_t io_info_index = 0;
+  size_t io_info_num = ios_info.size();
+  for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) {
+    std::shared_ptr<OpIOInfo> io_info_item = ios_info[io_info_index];
+    auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index);
+    std::string kernel_build_info_format;
+    if (!io_info_item->formats().empty()) {
+      kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index);
+    }
+    std::string io_param_type = io_info_item->param_type();
+    std::vector<Axis> reshape_type;
+    StringToAxisVector(io_info_item->reshape_type(), &reshape_type);
+    if (io_param_type == kParamTypeDynamic) {
+      // dynamic io
+      if (is_input) {
+        if (dynamic_input_index >= dyn_input_sizes.size()) {
+          MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index
+                            << ", dyn_input_sizes size: " << dyn_input_sizes.size();
+        }
+        int dynamic_input_size = dyn_input_sizes[dynamic_input_index];
+        for (int i = 0; i < dynamic_input_size; ++i) {
+          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
+          formats->emplace_back(kernel_build_info_format);
+          reshape_types->emplace_back(reshape_type);
+        }
+        dynamic_input_index++;
+        real_io_tensor_index += dynamic_input_size;
+      } else {
+        if (ios_info.size() != 1) {
+          MS_LOG(EXCEPTION) << "If the output is dynamic, the op must have exactly one output.";
+        }
+        for (size_t i = 0; i < real_io_tensor_num; ++i) {
+          device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
+          formats->emplace_back(kernel_build_info_format);
+          reshape_types->emplace_back(reshape_type);
+        }
+        real_io_tensor_index += real_io_tensor_num;
+      }
+    } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) {
+      // required or optional io
+      device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype));
+      formats->emplace_back(kernel_build_info_format);
+      reshape_types->emplace_back(reshape_type);
+      real_io_tensor_index++;
+    } else {
+      MS_LOG(EXCEPTION) << "Op info param type is invalid: " << io_param_type;
+    }
+  }
+
+  if (io_info_index != io_info_num) {
+    MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num
+                 << "), this node may have optional input/output.";
+  }
+  if (real_io_tensor_index != real_io_tensor_num) {
+    std::string io_type = is_input ? "inputs" : "outputs";
+    MS_LOG(INFO) << node_name_ << "'s " << io_type << " op io info num: " << io_info_num
+                 << ", real io tensor num: " << real_io_tensor_num << ", real_io_tensor_index("
+                 << real_io_tensor_index << ") != real_io_tensor_num(" << real_io_tensor_num << ")";
+    return false;
+  }
+  return true;
+}
+
+void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+  MS_EXCEPTION_IF_NULL(reshape_type_vec);
+  for (const auto &c : reshape_type_str) {
+    switch (c) {
+      case 'N':
+        reshape_type_vec->push_back(kernel::N);
+        break;
+      case 'C':
+        reshape_type_vec->push_back(kernel::C);
+        break;
+      case 'H':
+        reshape_type_vec->push_back(kernel::H);
+        break;
+      case 'W':
+        reshape_type_vec->push_back(kernel::W);
+        break;
+      default:
+        MS_LOG(EXCEPTION) << "Unknown axis " << c << " in reshape type.";
+    }
+  }
+}
+
+void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info,
+                                        const std::vector<std::vector<std::string>> &support_format_item, size_t index,
+                                        mindspore::kernel::OpIOInfo *op_io_info_new) {
+  MS_EXCEPTION_IF_NULL(op_io_info_new);
+  op_io_info_new->set_index(op_io_info.index());
+  op_io_info_new->set_name(op_io_info.name());
+  op_io_info_new->set_param_type(op_io_info.param_type());
+  op_io_info_new->set_need_compile(op_io_info.need_compile());
+  op_io_info_new->set_reshape_type(op_io_info.reshape_type());
+  op_io_info_new->set_shape(op_io_info.shape());
+  // dtype
+  std::vector<std::string> dtype_new;
+  auto dtype = op_io_info.dtypes();
+  for (size_t i = 0; i < support_format_item.size(); ++i) {
+    dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end());
+  }
+  op_io_info_new->set_dtypes(dtype_new);
+  // format
+  std::vector<std::string> format_new;
+  for (const auto &formats : support_format_item) {
+    auto format = formats.at(index);
+    for (size_t j = 0; j < dtype.size(); ++j) {
+      format_new.emplace_back(format);
+    }
+  }
+  op_io_info_new->set_formats(format_new);
+}
+
+std::vector<std::string> TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) {
+  const std::map<std::string, std::string> kDynamicFormatMap = {
+    {"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}};
+  if (op_select_json_item.empty()) {
+    MS_LOG(EXCEPTION) << "Op select ret item is null.";
+  }
+  const char space = ' ';
+  const char sep = ',';
+  std::string op_select_tmp = op_select_json_item + ",";
+  std::vector<std::string> ret;
+  auto begin = op_select_tmp.find_first_not_of(space, 0);
+  auto sep_pos = op_select_tmp.find(sep);
+  if (begin >= sep_pos) {
+    MS_LOG(EXCEPTION) << "Select ret json is invalid.";
+  }
+  while (sep_pos != std::string::npos) {
+    auto obj = op_select_tmp.substr(begin, sep_pos - begin);
+    if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) {
+      obj = kDynamicFormatMap.at(obj);
+    }
+    ret.emplace_back(obj);
+    begin = op_select_tmp.find_first_not_of(space, sep_pos + 1);
+    sep_pos = op_select_tmp.find(sep, begin);
+  }
+  return ret;
+}
+
+std::string TbeKernelSelect::OpSelectFormat() {
+  nlohmann::json kernel_json;
+  std::string res_json_str;
+  TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
+  bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
+  }
+  res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
+  if (res_json_str.empty()) {
+    MS_LOG(EXCEPTION) << "Op select format returned an empty result.";
+  }
+  MS_LOG(INFO) << "Dynamic select format response result: " << res_json_str;
+  return res_json_str;
+}
+
+void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format,
+                                      mindspore::kernel::OpInfo *op_info_new) {
+  MS_EXCEPTION_IF_NULL(op_info_new);
+  if (op_info.inputs_ptr().size() != support_format.input_format[0].size() ||
+      op_info.outputs_ptr().size() != support_format.output_format[0].size()) {
+    MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size()
+                      << ", input support size: " << support_format.input_format[0].size()
+                      << ", op info output size: " << op_info.outputs_ptr().size()
+                      << ", output support size: " << support_format.output_format[0].size();
+  }
+  *op_info_new = op_info;
+  op_info_new->ClearInputs();
+  op_info_new->ClearOutputs();
+  for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
+    auto input = op_info.inputs_ptr().at(i);
+    auto input_new = std::make_shared<OpIOInfo>();
+    CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get());
+    op_info_new->add_inputs_ptr(input_new);
+  }
+  for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) {
+    auto output = op_info.outputs_ptr().at(j);
+    auto output_new = std::make_shared<OpIOInfo>();
+    CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get());
+    op_info_new->add_outputs_ptr(output_new);
+  }
+}
+
+struct SelectOpIOInfo {
+  std::string name;
+  std::vector<std::string> dtypes;
+  std::vector<std::string> formats;
+};
+
+void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info,
+                                      mindspore::kernel::OpInfo *op_info_new) {
+  MS_EXCEPTION_IF_NULL(op_info_new);
+  auto op_select_json = OpSelectFormat();
+  if (!op_select_json.empty()) {
+    nlohmann::json json_obj = nlohmann::json::parse(op_select_json);
+    if (!json_obj.is_object()) {
+      MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_select_json;
+    }
+    std::vector<SelectOpIOInfo> inputs;
+    std::vector<SelectOpIOInfo> outputs;
+    for (const auto &item : json_obj.items()) {
+      const std::string &item_name = item.key();
+      bool is_input = (item_name.find(kPrefixInput) != std::string::npos);
+      bool is_output = (item_name.find(kPrefixOutput) != std::string::npos);
+      if (!is_input && !is_output) {
+        MS_LOG(EXCEPTION) << "Op select returned an invalid json.";
+      }
+      if (is_input) {
+        SelectOpIOInfo select_input;
+        select_input.name = item.value().at(kName);
+        std::string input_dtype_item = item.value().at(kDtype);
+        select_input.dtypes = SplitStrToVec(input_dtype_item);
+        std::string input_format_item = item.value().at(kFormat);
+        select_input.formats = SplitStrToVec(input_format_item);
+        inputs.emplace_back(select_input);
+      } else if (is_output) {
+        SelectOpIOInfo select_output;
+        select_output.name = item.value().at(kName);
+        std::string output_dtype_item = item.value().at(kDtype);
+        select_output.dtypes = SplitStrToVec(output_dtype_item);
+        std::string output_format_item = item.value().at(kFormat);
+        select_output.formats = SplitStrToVec(output_format_item);
+        outputs.emplace_back(select_output);
+      }
+    }
+
+    if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) {
+      MS_LOG(EXCEPTION) << "Select format input/output size not equal, please check the register info.";
+    }
+
+    *op_info_new = op_info;
+    op_info_new->ClearInputs();
+    op_info_new->ClearOutputs();
+    for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) {
+      auto input_new = std::make_shared<OpIOInfo>();
+      CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get());
+      op_info_new->add_inputs_ptr(input_new);
+    }
+    for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) {
+      auto output_new = std::make_shared<OpIOInfo>();
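+      // Rebuild the output IO info from the dtype/format lists parsed out of
+      // the op_select_format result, mirroring the input loop above.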
CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); + op_info_new->add_outputs_ptr(output_new); + } + } +} + +void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, + const std::vector &support_dtype, + const std::vector &support_format, + mindspore::kernel::OpIOInfo *op_io_info_new) { + MS_EXCEPTION_IF_NULL(op_io_info_new); + op_io_info_new->set_index(op_io_info.index()); + op_io_info_new->set_name(op_io_info.name()); + op_io_info_new->set_param_type(op_io_info.param_type()); + op_io_info_new->set_need_compile(op_io_info.need_compile()); + op_io_info_new->set_reshape_type(op_io_info.reshape_type()); + op_io_info_new->set_shape(op_io_info.shape()); + // dtype && format + op_io_info_new->set_dtypes(support_dtype); + op_io_info_new->set_formats(support_format); +} + +void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { + if (support_format.input_format.size() != support_format.output_format.size()) { + MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" + << support_format.output_format.size() << ") size not match."; + } + for (size_t i = 0; i < support_format.input_format.size(); ++i) { + auto input_items = support_format.input_format.at(i); + auto output_items = support_format.output_format.at(i); + std::string print_str = "["; + for (const auto &input : input_items) { + print_str.append(input); + print_str.append(", "); + } + print_str.append("] -->"); + for (const auto &output : output_items) { + print_str.append(output); + print_str.append(", "); + } + MS_LOG(INFO) << "Support format: " << print_str; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h new file mode 100644 index 0000000000..679c56379f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_TBE_KERNEL_SELECT_H +#define MINDSPORE_TBE_KERNEL_SELECT_H + +#include +#include +#include +#include "backend/kernel_compiler/oplib/opinfo.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" + +namespace mindspore { +namespace kernel { +void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); + +class TbeKernelSelect { + using OpInfoPtr = std::shared_ptr; + using KernelBuildInfoIter = std::vector>::iterator; + + public: + TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list); + ~TbeKernelSelect() = default; + void TbeMetadataInfoEx(); + + private: + void GetCommonPatternKernelInfo(const OpInfo &op_info); + void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); + void GetAgnosticPatternKernelInfo(const OpInfo &op_info); + void GetBroadcastPatternKernelInfo(const OpInfo &op_info); + void GetReducePatternKernelInfo(const OpInfo &op_info); + void FilterInVaildKernelInfo(); + bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); + static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); + bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); + static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); + bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, + const std::vector> &ios_info, const std::vector &dyn_input_sizes, + std::vector *formats, std::vector *device_types, + std::vector> *reshape_types); + static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); + static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, + const std::vector> &support_format_item, size_t index, + OpIOInfo *op_io_info_new); + // op select(dynamic) + void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); + static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector &support_dtype, + const std::vector &support_format, OpIOInfo *op_io_info_new); + static std::vector SplitStrToVec(const std::string &op_select_json_item); + std::string OpSelectFormat(); + + static void PrintSupportedFormat(const SupportFormat &support_format); + + private: + CNodePtr cnode_ptr_; + std::vector> *kernel_info_list_; + std::string node_name_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_TBE_KERNEL_SELECT_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc new file mode 100644 index 0000000000..facb07991a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.cc @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "common/utils.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace kernel { +using mindspore::kernel::tbe::TbeUtils; +constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; +constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; +constexpr auto kOpSelectFormatFunc = "op_select_format"; +constexpr auto kCheckSupportedFunc = "check_supported"; +constexpr auto kTBEException = "TBEException"; + +PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; +PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; +PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr; +PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr; +bool TbePythonFuncs::Init() { + static bool initialized = false; + if (initialized) { + return true; + } + // Initialize cache + TbeUtils::LoadCache(); + + // tbe_process + PyObject *pTbeProcessModule = nullptr; + pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule); + if (pTbeProcessModule == nullptr) { + MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module."; + return false; + } + + pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc); + if (pCreateTbeParallelCompilerFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "]."; + return false; + } + + pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr); + if (pTbeCompiler_ == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler."; + return false; + } + + pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc); + if (pOpSelectFormatFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kOpSelectFormatFunc << "]."; + return false; + } + + pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc); + if (pCheckSupportedFunc_ == nullptr) { + MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule + << "], FuncName:[" << kCheckSupportedFunc << "]."; + return false; + } + initialized = true; + MS_LOG(INFO) << "TbePythonFuncs initialized Success."; + return true; +} + +std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) { + char *pChar = nullptr; + std::string str_res; + if (PyObj == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr."; + return str_res; + } + PyObject *strArgs = PyObject_Str(PyObj); + if (strArgs != nullptr) { + (void)PyArg_Parse(strArgs, "s", &pChar); + } + if (pChar == nullptr) { + MS_LOG(ERROR) << "pChar is nullptr."; + return str_res; + } + str_res = pChar; + return str_res; +} + +std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { + PyObject *pArg = nullptr; + PyObject *pRet = nullptr; + std::string res_json_str; + + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return res_json_str; + } + + // assembly Args + pArg = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + (void)PyTuple_SetItem(pArg, 0, 
Py_BuildValue("s", json_str.c_str())); + if (pArg == nullptr) { + MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; + return res_json_str; + } + + // call functions + if (pOpSelectFormatFunc_ == nullptr) { + MS_LOG(ERROR) << "function is nullptr."; + return res_json_str; + } + + pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg); + if (pRet == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc + << "], function args:" << PyObjectToStr(pArg); + } + + char *pstr = nullptr; + (void)PyArg_Parse(pRet, "s", &pstr); + res_json_str = pstr; + if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str + << " ,function args:" << PyObjectToStr(pArg); + } + return res_json_str; +} + +bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { + PyObject *pArg = nullptr; + PyObject *pRes = nullptr; + bool ret = false; + + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return ret; + } + // assembly Args + pArg = PyTuple_New(1); + std::string json_str = kernel_json.dump(); + PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); + (void)PyTuple_SetItem(pArg, 0, arg1); + if (pArg == nullptr) { + MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; + return ret; + } + + // call functions + if (pCheckSupportedFunc_ == nullptr) { + MS_LOG(ERROR) << "function is nullptr."; + return ret; + } + + pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg); + if (pRes == nullptr) { + PyErr_Print(); + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc + << "], function args: " << PyObjectToStr(pArg); + } + if (PyBool_Check(pRes)) { + ret = PyObject_IsTrue(pRes) != 0; + } else { + char *pstr = nullptr; + (void)PyArg_Parse(pRes, "s", &pstr); + std::string res_str = pstr; + if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str + << ", function args: " << PyObjectToStr(pArg); + } + } + + return ret; +} + +PyObject *TbePythonFuncs::TbeParallelCompiler() { + if (!Init()) { + MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; + return nullptr; + } + return pTbeCompiler_; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h similarity index 100% rename from mindspore/ccsrc/kernel/tbe/tbe_python_funcs.h rename to mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_python_funcs.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc new file mode 100644 index 0000000000..76ef7b08d5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.cc @@ -0,0 +1,254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/tbe/tbe_utils.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "runtime/kernel.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+#include "runtime/device/kernel_info.h"
+#include "ir/dtype/type.h"
+#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
+#include "securec/include/securec.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace kernel {
+namespace tbe {
+constexpr auto kCceKernelMeta = "./kernel_meta/";
+constexpr auto kJsonSuffix = ".json";
+constexpr auto kInfoSuffix = ".info";
+
+uintptr_t KernelManager::kernel_stub_gen_ = 0;
+std::unordered_map<string, KernelMetaPtr> KernelManager::info_table_ = {};
+
+void TbeUtils::SaveJsonInfo(const std::string &json_name, const std::string &info) {
+  char real_path[PATH_MAX] = {0};
+  std::string path = kCceKernelMeta + json_name + kInfoSuffix;
+  if (path.size() > PATH_MAX) {
+    MS_LOG(ERROR) << "File path: " << path << " is too long.";
+    return;
+  }
+  std::ifstream fin(path);
+  if (fin) {
+    MS_LOG(INFO) << "Json file exists, no need to create.";
+    return;
+  }
+  std::ofstream file_write;
+  file_write.open(path);
+  if (!file_write.is_open()) {
+    return;
+  }
+  file_write << info << std::endl;
+  file_write.close();
+  if (realpath(path.c_str(), real_path) == nullptr) {
+    MS_LOG(INFO) << "Dir: " << path << " does not exist.";
+    return;
+  }
+  MS_LOG(INFO) << "Real path is: " << real_path;
+  if (chmod(real_path, S_IRUSR) == -1) {
+    MS_LOG(INFO) << "Modifying file " << real_path << " to read-only failed.";
+  }
+}
+
+void TbeUtils::LoadCache() {
+  static bool has_load = false;
+  if (!has_load) {
+    KernelMeta *bin_map = KernelMeta::GetInstance();
+    if (bin_map != nullptr && !bin_map->ReadIndex(kCceKernelMeta)) {
+      MS_LOG(INFO) << "Cache initialize failed[" << kCceKernelMeta << "]";
+    } else {
+      MS_LOG(INFO) << "Cache initialized to " << kCceKernelMeta;
+    }
+    has_load = true;
+  }
+}
+
+KernelPackPtr TbeUtils::SearchCache(const std::string &kernel_name, const std::string &processor) {
+  // search cache.
+  KernelMeta *bin_map = KernelMeta::GetInstance();
+  if (bin_map == nullptr) {
+    MS_LOG(INFO) << "Kernel cache is invalid.";
+    return nullptr;
+  }
+  return bin_map->GetKernelPack(kernel_name, processor);
+}
+
+KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::string &processor) {
+  MS_LOG(INFO) << "Kernel name: " << kernel_name << ", processor: " << processor;
+  if (processor != kProcessorAiCore) {
+    MS_LOG(EXCEPTION) << "Process type should be aicore, actually is: " << processor;
+  }
+  return SearchCache(kernel_name, processor);
+}
+
+int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module,
+                                  const string &magic) {
+  static std::map<string, uint32_t> magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
+                                                  {"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN},
+                                                  {"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU},
+                                                  {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}};
+  // object for device register.
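+  // The descriptor below is filled from the kernel buffer and handed to
+  // rtDevBinaryRegister, which yields the module handle later consumed by
+  // rtFunctionRegister in GenFuncStub.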
+  rtDevBinary_t dev_bin;
+  dev_bin.data = kernel_buffer.contents;
+  auto iter = magic_maps.find(magic);
+  if (iter == magic_maps.end()) {
+    MS_LOG(INFO) << "Invalid magic number: " << magic;
+    return -1;
+  }
+  dev_bin.magic = iter->second;
+  dev_bin.length = kernel_buffer.len;
+  dev_bin.version = 2;
+  if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) {
+    MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error.";
+    return -1;
+  }
+  return 0;
+}
+
+uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload,
+                                     uint32_t *block_dim) {
+  auto kernel = kernel_pack.GetKernel();
+  if (kernel == nullptr) {
+    MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr.";
+  }
+  auto kernel_contents = kernel->contents;
+  if (kernel_contents == nullptr) {
+    MS_LOG(EXCEPTION) << "Invalid kernel contents, json or kernel is nullptr.";
+  }
+  auto kernel_json_info = kernel_pack.kernel_json_info();
+
+  *block_dim = kernel_json_info.block_dim;
+  string func_name = kernel_json_info.kernel_name;
+  string magic = kernel_json_info.magic;
+
+  if (!force_reload) {
+    // use the cached object.
+    auto iter = info_table_.find(func_name);
+    if (iter != info_table_.end()) {
+      auto kernelmeta = iter->second;
+      *block_dim = kernelmeta->block_dim_;
+      return kernelmeta->func_stub_;
+    }
+  }
+  void *module = nullptr;
+  if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) {
+    MS_LOG(INFO) << "Call runtime BinaryRegister error.";
+    return 0;
+  }
+  // to diff different funcs.
+  uintptr_t func_stub = ++kernel_stub_gen_;
+  if (RT_ERROR_NONE !=
+      rtFunctionRegister(module, reinterpret_cast<void *>(func_stub), func_name.c_str(), func_name.c_str(), 0)) {
+    MS_LOG(INFO) << "Call runtime rtFunctionRegister error.";
+    return 0;
+  }
+  // cache the registered kernelmeta.
+  info_table_[func_name] = std::make_shared<KernelMetaInfo>(KernelMetaInfo{func_stub, *block_dim});
+  return func_stub;
+}
+
+std::string KernelManager::GetStubFuncName(const KernelPackPtr &kernel_pack) {
+  MS_EXCEPTION_IF_NULL(kernel_pack);
+  auto kernel_json_info = kernel_pack->kernel_json_info();
+  return kernel_json_info.kernel_name;
+}
+
+KernelMeta *KernelMeta::GetInstance() {
+  static KernelMeta inst;
+  return &inst;
+}
+
+bool KernelMeta::ReadIndex(const std::string &bin_dir) {
+  DIR *dir = opendir(bin_dir.c_str());
+  if (dir == nullptr) {
+    auto ret = mkdir(bin_dir.c_str(), S_IRWXG | S_IRWXU);
+    if (ret != 0) {
+      MS_LOG(INFO) << "Kernel dir: " << bin_dir << " does not exist.";
+      return false;
+    }
+    dir = opendir(bin_dir.c_str());
+  }
+  struct dirent *entry;
+  while ((entry = readdir(dir)) != nullptr) {
+    string bin_dir_tmp = bin_dir;
+    std::string cce_json = entry->d_name;
+    if (cce_json.length() <= 5) {
+      continue;
+    }
+    std::string suffix = cce_json.substr(cce_json.length() - 5);
+    if (suffix != kJsonSuffix) {
+      continue;
+    }
+    auto sp = cce_json.rfind('/');
+    if (sp != std::string::npos) {
+      continue;
+    }
+    sp = cce_json.rfind('.');
+    if (sp == std::string::npos) {
+      continue;
+    }
+    auto kernel_name = cce_json.substr(0, sp);
+    (void)bin_dir_tmp.append("/");
+    (void)bin_dir_tmp.append(cce_json);
+    kernel_index_map_[kernel_name] = bin_dir_tmp;
+  }
+  (void)closedir(dir);
+
+  MS_LOG(INFO) << "Cache kernel initialized, kernel size: " << kernel_index_map_.size();
+  return true;
+}
+
+KernelPackPtr KernelMeta::GetKernelPack(const std::string &kernel_name, const std::string &processor) {
+  KernelPackPtr ret = nullptr;
+  // 1. the pack has already been created.
+  auto kernel_pack_iter = kernel_pack_map_.find(kernel_name);
+  if (kernel_pack_iter != kernel_pack_map_.end()) {
+    MS_LOG(INFO) << "Kernel pack [" << kernel_name << "] has been created.";
+    ret = kernel_pack_iter->second;
+  } else {
+    // 2. the kernel file has been created, but the pack has not.
+    std::string cce_json = kCceKernelMeta;
+    (void)cce_json.append(kernel_name).append(kJsonSuffix);
+    ret = std::make_shared<KernelPack>();
+    if (!ret->LoadKernelMeta(cce_json, processor)) {
+      MS_LOG(INFO) << "Read cache json and bin file failed[" << cce_json << "]";
+      return nullptr;
+    }
+    kernel_pack_map_[kernel_name] = ret;
+    auto iter = kernel_index_map_.find(kernel_name);
+    if (iter == kernel_index_map_.end()) {
+      MS_LOG(INFO) << "Kernel name [" << kernel_name << "] has been created for the first time.";
+      kernel_index_map_[kernel_name] = cce_json;
+    }
+  }
+  return ret;
+}
+}  // namespace tbe
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h
new file mode 100644
index 0000000000..39ddaaa73d
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_utils.h
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ +#include +#include +#include +#include +#include +#include + +#include "backend/session/kernel_graph.h" +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +namespace tbe { +using std::string; +using std::vector; + +class TbeUtils { + public: + TbeUtils() = default; + + ~TbeUtils() = default; + + static void SaveJsonInfo(const std::string &json_name, const std::string &info); + + static void LoadCache(); + + static KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); + + static KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); +}; + +struct KernelMetaInfo { + uintptr_t func_stub_; + uint32_t block_dim_; +}; +using KernelMetaPtr = std::shared_ptr; + +class KernelManager { + public: + static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim); + static std::string GetStubFuncName(const KernelPackPtr &kernel_pack); + + private: + KernelManager() = default; + ~KernelManager() = default; + static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic); + static std::unordered_map info_table_; + static uintptr_t kernel_stub_gen_; +}; + +class KernelMeta { + public: + static KernelMeta *GetInstance(); + bool ReadIndex(const std::string &bin_dir); + KernelPackPtr GetKernelPack(const std::string &kernel_name, const std::string &processor); + + private: + KernelMeta() = default; + ~KernelMeta() = default; + std::unordered_map kernel_index_map_{}; + std::unordered_map kernel_pack_map_{}; +}; +} // namespace tbe +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/CMakeLists.txt b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt new file mode 100644 index 0000000000..ee1532a416 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "common/*.cc" + "mem_reuse/*.cc" + "pass/*.cc" + "gpu/*.cc" +) + +if (ENABLE_D) + file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") + list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) +endif () + +set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) +add_library(_mindspore_backend_optimizer_obj OBJECT ${_PREACTIVATE_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc new file mode 100644 index 0000000000..40e7a29c92 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -0,0 +1,495 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ascend_backend_optimization.h" +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fission/bn_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" +#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" +#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" +#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" +#include "backend/optimizer/pass/communication_op_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h" +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h" +#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" +#include "backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" +#include "backend/optimizer/ascend/ir_fission/topk_split.h" +#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h" +#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" +#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" +#include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/pass/optimize_dependence.h" +#include "backend/optimizer/pass/erase_visit_attr.h" +#include "backend/optimizer/ascend/format_type/insert_cast.h" +#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" +#include "backend/optimizer/pass/eliminate_redundant_op.h" +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include "backend/optimizer/pass/fuse_graph_kernel.h" +#include 
"backend/optimizer/pass/fuse_basic.h" +#include "backend/optimizer/pass/add_atomic_clean.h" +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" +#include "backend/optimizer/ascend/format_type/check_consistency.h" +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" +#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" +#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" +#include "backend/optimizer/ascend/ir_fission/addn_fission.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "backend/optimizer/ascend/ir_fission/split_fission.h" +#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h" +#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h" +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" + +namespace mindspore { +namespace opt { +namespace { +void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + 
ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} +} // namespace + +void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto data_layout_pm = std::make_shared("pynative_transop_pm"); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(data_layout_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + MS_EXCEPTION_IF_NULL(optimizer); + auto common_process = std::make_shared("graph_kernel_common_process"); + MS_EXCEPTION_IF_NULL(common_process); + common_process->AddPass(std::make_shared()); + common_process->AddPass(std::make_shared()); + optimizer->AddPassManager(common_process); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendDataLayout(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto data_layout_pm = std::make_shared("transop_pm"); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(data_layout_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendMixPrecision(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto mixed_precision_pm = std::make_shared("cast_pm"); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + 
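// The cast_pm group normalizes dtype boundaries after kernel selection: wherever the device
// dtype a kernel was selected with differs from the inferred dtype of its input, an explicit
// Cast node is spliced in (InsertCastForInput in ascend_helper.cc, later in this patch,
// carries the actual logic). Per input the decision reduces to a sketch like:
//
//   if (origin_type != device_type) {
//     cur_input = AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type,
//                                      device_type, origin_shape, infer_type);
//   }
//
// with AddCastOpNodeToGraph declared in ascend_helper.h in this same patch.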
mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(mixed_precision_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" + + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id())); + } + auto optimizer = std::make_shared<GraphOptimizer>(); + auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); + if (context_ptr->execution_mode() == kPynativeMode) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } else { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } + ir_fusion_pm->AddPass(std::make_shared()); + if (context_ptr->ir_fusion_flag()) { + AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); + } + + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(ir_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } +} + +void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->ir_fusion_flag()) { + MS_LOG(INFO) << "IRFusion is not enabled, skip."; + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir"; + DumpIR(file_path, kernel_graph); + } + auto optimizer = std::make_shared<GraphOptimizer>(); + auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm"); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + + optimizer->AddPassManager(ir_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { +
std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir"; + DumpIR(file_path, kernel_graph); + } +} + +void AscendBackendOptimization(const std::shared_ptr &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + // data layout optimization + AscendDataLayout(kernel_graph); + // mixed precision optimization + AscendMixPrecision(kernel_graph); + // other optimization + auto optimizer = std::make_shared(); + auto other_pm = std::make_shared("other_pm"); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(other_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + // buffer fusion + AscendBackendUBFusionOptimization(kernel_graph); + + // other2 optimization + auto optimizer2 = std::make_shared(); + auto other2_pm = std::make_shared("other2_pm"); + other2_pm->AddPass(std::make_shared()); + other2_pm->AddPass(std::make_shared()); + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + other2_pm->AddPass(std::make_shared()); + } + other2_pm->AddPass(std::make_shared()); + optimizer2->AddPassManager(other2_pm); + (void)optimizer2->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph, true); + DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); + kernel_graph->DumpFuncGraph("hwopt_d_end"); + } +} + +void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph); + } + + // Fuse graph kernels with basic ops + FuseGraphKernel(kernel_graph, is_before_kernel_select); + + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = 
context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } + + // Fuse basic ops with basic ops + FuseBasic(kernel_graph, is_before_kernel_select); + + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + + std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + + AddAtomicClean(kernel_graph); + + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph, true); + } +} + +void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->ir_fusion_flag()) { + MS_LOG(INFO) << "UBFusion is not enabled, skip."; + return; + } + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } + auto fusion_id_allocator = std::make_shared<FusionIdAllocator>(); + MS_EXCEPTION_IF_NULL(fusion_id_allocator); + fusion_id_allocator->Init(); + auto optimizer = std::make_shared<GraphOptimizer>(); + auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm"); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared()); + optimizer->AddPassManager(ub_fusion_pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + if (save_graphs) { + std::string file_path = + save_graphs_path +
"/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; + DumpIR(file_path, kernel_graph); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h new file mode 100644 index 0000000000..8194ab467b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ +#include +#include "backend/session/kernel_graph.h" +namespace mindspore { +namespace opt { +void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph); +void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); +void AscendDataLayout(const std::shared_ptr &kernel_graph); +void AscendMixPrecision(const std::shared_ptr &kernel_graph); +void AscendBackendOptimization(const std::shared_ptr &kernel_graph); +void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph); +void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select = false); +void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, + bool is_before_kernel_select = false); +void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph); +void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); +void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc new file mode 100644 index 0000000000..fd4c0e5952 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -0,0 +1,345 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/ascend_helper.h" +#include +#include "common/trans.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; +namespace { +const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; +AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, + const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { + std::vector trans_inputs; + auto prim = std::make_shared(prim::kPrimReshape->name()); + trans_inputs.emplace_back(NewValueNode(prim)); + trans_inputs.emplace_back(input_node); + auto reshape = func_graph->NewCNode(trans_inputs); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); + AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); + reshape->set_scope(input_node->scope()); + kernel_select->SelectKernel(reshape); + return reshape; +} + +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + AnfNodePtr input_node = node; + CNodePtr trans_data = nullptr; + std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); + std::string dst_format = is_insert_input ? 
AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; + std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); + MS_EXCEPTION_IF_NULL(node); + // if inserting a transdata for an input, we need to change the input + if (is_insert_input) { + if (!node->isa<CNode>()) { + MS_LOG(EXCEPTION) << "Cannot insert a transdata node to the input of a node that is not a cnode"; + } + auto cnode = node->cast<CNodePtr>(); + dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); + input_node = AnfAlgo::GetInputNode(cnode, insert_index); + padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); + } + bool need_padding = false; + if (is_insert_input) { + need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } else { + need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); + } + if (!need_padding) { + // no padding needed, insert the transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else if (is_insert_input) { + // padding needed and this is an input, so insert a transdata: + // reshape[padding shape] -> transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); + trans_node = trans_data; + } else { + // padding needed and this is an output, so insert a transdata: + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); + auto reshape_node = + CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + trans_node = reshape_node; + } + // refresh the transdata's kernel build info with the original and destination formats + RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); + return trans_node; +} + +AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + auto input_node = AnfAlgo::GetInputNode(node, index); + auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); + MS_EXCEPTION_IF_NULL(node_with_index.first); + auto real_input = node_with_index.first; + if (real_input->isa<ValueNode>() || real_input->isa<Parameter>()) { + input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); + MS_EXCEPTION_IF_NULL(input_node); + AnfAlgo::SetNodeInput(node, input_node, index); + } + std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); + std::string dest_format = AnfAlgo::GetInputFormat(node, index); + if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << node->DebugString() << " Insert transdata " << AnfAlgo::GetInputFormat(node, index) + << " to DefaultFormat, index: " << index; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); + } + return input_node; +} + +AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + std::string output_format = AnfAlgo::GetOutputFormat(node, 0); + std::vector<size_t> origin_shape =
AnfAlgo::GetOutputInferShape(node, 0); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "Got the hw format " << output_format << " when inserting the transdata node " + << node->DebugString(); + } + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " to default, index: 0"; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); + } + return node; +} + +AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + std::vector<AnfNodePtr> make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { + std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "Got the special format " << output_format << " when inserting the transdata node " + << node->DebugString(); + } + auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); + std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { + make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); + } else { + // No need to insert a trans op. + make_tuple_inputs.push_back(tuple_getitem); + } + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { + MS_EXCEPTION_IF_NULL(trans_data); + auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); + MS_EXCEPTION_IF_NULL(ori_build_info); + auto builder = std::make_shared<KernelBuildInfoBuilder>(ori_build_info); + builder->SetInputsFormat({input_format}); + builder->SetInputReshapeType({reshape_type}); + builder->SetOutputReshapeType({reshape_type}); + builder->SetOutputsFormat({output_format}); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); +} + +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(input); + std::vector<AnfNodePtr> trans_inputs; + auto prim = std::make_shared<Primitive>(op_name); + trans_inputs.push_back(NewValueNode(prim)); + trans_inputs.push_back(input); + CNodePtr trans_node = func_graph->NewCNode(trans_inputs); + MS_EXCEPTION_IF_NULL(trans_node); + auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); + if (need_padding) { + // if padding is needed, set the transdata node's shape to the padded shape + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, + trans_node.get()); + } else { + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, + {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); + } + // special handling for UT + if (trans_node->kernel_info() == nullptr) { + auto kernel_info = std::make_shared<device::KernelInfo>(); + trans_node->set_kernel_info(kernel_info); + } + MS_EXCEPTION_IF_NULL(kernel_select); +
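// kernel_select->SelectKernel(...) below assigns a concrete device kernel to the freshly
// built TransData node. A hedged usage sketch for NewTransOpNode itself: with
// need_padding=true, the node's inferred shape is first widened to 4D by
// trans::PaddingShapeTo4d before the device-format conversion, so a call site looks like
//
//   auto trans_data = NewTransOpNode(func_graph, input_node, kernel_select,
//                                    /*need_padding=*/true, prim::KPrimTransData->name());
//
// mirroring the AddTransOpNodeToGraph branches earlier in this file.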
kernel_select->SelectKernel(trans_node); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); + MS_EXCEPTION_IF_NULL(trans_node); + trans_node->set_scope(input->scope()); + return trans_node; +} + +AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector<size_t> &origin_shape, const TypeId &origin_type) { + MS_EXCEPTION_IF_NULL(func_graph); + std::string input_format = format; + std::string output_format = format; + std::vector<AnfNodePtr> new_cast_inputs; + auto prim = std::make_shared<Primitive>(prim::kPrimCast->name()); + new_cast_inputs.push_back(NewValueNode(prim)); + new_cast_inputs.push_back(input); + CNodePtr cast = func_graph->NewCNode(new_cast_inputs); + MS_EXCEPTION_IF_NULL(cast); + // set kernel build info + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat({input_format}); + builder.SetOutputsFormat({output_format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetFusionType(kernel::FusionType::OPAQUE); + builder.SetProcessor(kernel::Processor::AICORE); + if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { + builder.SetKernelType(KernelType::TBE_KERNEL); + } else { + builder.SetKernelType(KernelType::AKG_KERNEL); + } + // if kernel info is null, it means this function is running in UT + if (cast->kernel_info() == nullptr) { + auto kernel_info = std::make_shared<device::KernelInfo>(); + cast->set_kernel_info(kernel_info); + } + AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); + return cast; +} + +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); + if (outputs_num == 0) { + return node; + } + // Single output + if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { + return InsertTransOpForSingleOutput(func_graph, node, kernel_select); + } + // Multiple output + return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); +} + +AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(cnode); + std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); + MS_EXCEPTION_IF_NULL(input_node); + new_inputs.push_back(input_node); + } + CNodePtr new_cnode = nullptr; + // the cnode's inputs changed, so make a new cnode to distinguish it from the original one.
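// Why the branch just below takes two paths (an inference from this file's own
// "special handling for UT" pattern, not a documented contract): when func_graph is really
// a session::KernelGraph, the replacement node must be allocated through
// kernel_graph->NewCNode(cnode) so the graph's own bookkeeping stays consistent, while the
// bare std::make_shared<CNode>(*cnode) copy only serves unit tests that pass a plain FuncGraph.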
+ auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); + if (kernel_graph == nullptr) { + new_cnode = std::make_shared<CNode>(*cnode); + } else { + new_cnode = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_inputs(new_inputs); + return new_cnode; +} + +CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + TypeId origin_type(kTypeUnknown); + auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); + auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); + auto real_input_node = kernel_with_index.first; + if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + // weight + origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); + if (origin_type == kTypeUnknown) { + origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); + } + } else { + // feature map + origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + } + const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); + const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); + const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); + // In graph kernel, we check the parameter; + // the eliminate pass will not eliminate this case, so we just do not insert the unused cast. + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode<tensor::Tensor>(cur_input)) { + new_inputs.push_back(cur_input); + } else if (origin_type != device_type) { + auto cast = + AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); + MS_EXCEPTION_IF_NULL(cast); + cast->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); + new_inputs.push_back(cast); + } else { + new_inputs.push_back(cur_input); + } + } + auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>(); + CNodePtr new_node = nullptr; + if (kernel_graph == nullptr) { + new_node = std::make_shared<CNode>(*cnode); + } else { + new_node = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_inputs(new_inputs); + return new_node; +} + +AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto prim = std::make_shared<Primitive>(kMemCpyAsyncOpName); + std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim), node}; + auto new_node = graph->NewCNode(new_node_inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_abstract(node->abstract()); + new_node->set_scope(node->scope()); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h new file mode 100644 index 0000000000..cb308a09a0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ + +#include +#include +#include +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +class KernelSelect { + public: + KernelSelect() = default; + virtual ~KernelSelect() = default; + virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } +}; +using KernelSelectPtr = std::shared_ptr; + +class SupportedChecker { + public: + SupportedChecker() = default; + virtual ~SupportedChecker() = default; + virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); + } + virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, + const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); + } +}; +using SupportedCheckerPtr = std::shared_ptr; + +class KernelQuery { + public: + KernelQuery() = default; + virtual ~KernelQuery() = default; + virtual void Query(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { + kernel::KernelQuery(kernel_node, kernel_info_list); + } + virtual bool IsTbeRef(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); + if (op_info != nullptr) { + return op_info->is_ref(); + } + return false; + } +}; +using KernelQueryPtr = std::shared_ptr; + +class OpFinder { + public: + OpFinder() = default; + virtual ~OpFinder() = default; + virtual int GetOpRegisteredOutputNum(const std::string &op_name) { + auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); + if (op_info == nullptr) { + return -1; + } + return op_info->outputs_ptr().size(); + } +}; +using OpFinderPtr = std::shared_ptr; + +void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); + +CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, + const bool need_padding, const std::string &op_name); + +AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, + const TypeId &input_type, const TypeId &output_type, + const std::vector &origin_shape, const TypeId &origin_type); + +AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select); + +AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select); + +CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + +AnfNodePtr 
CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..22183c9050 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(relu_input); + auto add = relu_input->cast(); + MS_EXCEPTION_IF_NULL(add); + auto tuple_getitem = add->input(1); + MS_EXCEPTION_IF_NULL(tuple_getitem); + if (tuple_getitem->isa() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { + auto getitem = tuple_getitem->cast(); + MS_EXCEPTION_IF_NULL(getitem); + auto bnupdate = getitem->input(1); + MS_EXCEPTION_IF_NULL(bnupdate); + if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { + std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); + for (auto out_getitem : manager->node_users()[bnupdate]) { + MS_EXCEPTION_IF_NULL(out_getitem.first); + auto out_getitem_ptr = out_getitem.first->cast(); + MS_EXCEPTION_IF_NULL(out_getitem_ptr); + auto input2 = out_getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); + } + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); + std::unordered_set record{cnode, relu_input, bnupdate}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + 
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { + MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h new file mode 100644 index 0000000000..dfc45b4688 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { + public: + explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} + ~BnupdateEltwiseEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..59915d43d4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(relu_input); + auto getitem = relu_input->cast(); + MS_EXCEPTION_IF_NULL(getitem); + auto bnupdate = getitem->input(1); + MS_EXCEPTION_IF_NULL(bnupdate); + if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { + std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); + for (auto out_getitem : manager->node_users()[bnupdate]) { + MS_EXCEPTION_IF_NULL(out_getitem.first); + auto out_getitem_ptr = out_getitem.first->cast(); + MS_EXCEPTION_IF_NULL(out_getitem_ptr); + auto input2 = out_getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); + } + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); + std::unordered_set record{cnode, bnupdate}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { + MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h new file mode 100644 index 0000000000..abaf264d2e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class BnupdateEltwiseFusionPass : public FusionBasePass { + public: + explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} + ~BnupdateEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..1bfff1b50e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( + const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + } else { + return; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto double_in_eltwise_input = input_cnode->input(1); + MS_EXCEPTION_IF_NULL(double_in_eltwise_input); + if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(double_in_eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h new file mode 100644 index 0000000000..6bf74d5268 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h
new file mode 100644
index 0000000000..6bf74d5268
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass {
+ public:
+  explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {}
+  ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                                              FusedNodeRecord *candidate_fusion);
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..144ab4b53f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h new file mode 100644 index 0000000000..93aa324566 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { + public: + explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} + ~Conv2DBackpropEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc new file mode 100644 index 0000000000..a2ebfbe79e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto conv = cnode->input(1); + MS_EXCEPTION_IF_NULL(conv); + if (conv->isa() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) { + std::vector output_used_num{SizeToInt(manager->node_users()[conv].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); + std::unordered_set record{cnode, conv}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { + MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h new file mode 100644 index 0000000000..224422530b --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
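Before recording the {BNTrainingReduce, Conv2D} pair, MatchConvBnreduce stamps the conv with kAttrOutputUsedNum, the number of graph users of the conv output; the fused kernel needs that count because any consumer outside the fusion forces the intermediate conv result to stay visible. A small standalone sketch of that bookkeeping, where NodeUsers is only an illustrative stand-in for FuncGraphManager::node_users():

    #include <cstddef>
    #include <map>
    #include <string>
    #include <vector>

    using NodeUsers = std::map<std::string, std::vector<std::string>>;

    // Mirrors: std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
    std::vector<int> OutputUsedNum(const NodeUsers &node_users, const std::string &conv) {
      const auto it = node_users.find(conv);
      const std::size_t n = it == node_users.end() ? 0 : it->second.size();
      return {static_cast<int>(n)};
    }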
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class ConvBnReduceFusionPass : public FusionBasePass {
+ public:
+  explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {}
+  ~ConvBnReduceFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                         FusedNodeRecord *candidate_fusion);
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc
new file mode 100644
index 0000000000..1a67e3c39b
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.cc
@@ -0,0 +1,78 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + } else { + return; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto double_in_eltwise_input = input_cnode->input(1); + MS_EXCEPTION_IF_NULL(double_in_eltwise_input); + if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) { + (void)record.insert(double_in_eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h new file mode 100644 index 0000000000..911cf744de --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ConvDoubleInFusionPass : public FusionBasePass { + public: + explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} + ~ConvDoubleInFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc new file mode 100644 index 0000000000..1eb26b12bc --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) { + (void)record.insert(eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h new file mode 100644 index 0000000000..6dddd600c2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
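MatchConvSingleInEltwise collects a bounded chain of single-input element-wise producers and only records it when a convolution sits at the head. Below is a minimal sketch of that bounded upward walk; Node is an illustrative stand-in for the MindSpore IR, and kMaxEltwiseNum mirrors MAX_ELTWISE_NUM from fusion_base_pass.h.

    #include <cstddef>
    #include <memory>
    #include <unordered_set>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;
    struct Node {
      bool is_eltwise = false;
      bool is_conv = false;  // a TBE kernel whose fusion type is convolution in the real pass
      NodePtr input;         // the single data input, like cnode->input(1)
    };

    constexpr std::size_t kMaxEltwiseNum = 3;  // mirrors MAX_ELTWISE_NUM

    std::unordered_set<NodePtr> MatchConvSingleIn(const NodePtr &tail) {
      std::unordered_set<NodePtr> record{tail};
      NodePtr cur = tail->input;
      // Stop extending once the record is full, like the record.size() == MAX_ELTWISE_NUM break.
      while (cur && cur->is_eltwise && record.size() < kMaxEltwiseNum) {
        record.insert(cur);
        cur = cur->input;
      }
      if (!cur || !cur->is_conv) return {};  // window exhausted without reaching a conv
      record.insert(cur);
      return record;
    }

Capping the walk keeps each fused TBE kernel small; a chain longer than the window simply fails the conv check and is left for other passes.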
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class ConvSingleInFusionPass : public FusionBasePass { + public: + explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} + ~ConvSingleInFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..285b8f6c07 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion, bool is_order) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + if (is_order) { + // DepthwiseConvolution--->Elemwise + auto depthwise_conv = cnode->input(1); + MS_EXCEPTION_IF_NULL(depthwise_conv); + if (cnode->isa() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) { + std::vector output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv); + std::unordered_set record{cnode, depthwise_conv}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } else { + // Elemwise-->DepthwiseConvolution + auto relu = cnode->input(1); + MS_EXCEPTION_IF_NULL(relu); + if (cnode->isa() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) { + std::vector output_used_num{SizeToInt(manager->node_users()[relu].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); + std::unordered_set record{cnode, relu}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); + } + } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h new file mode 100644 index 0000000000..6746dad984 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
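DepthwiseConvEltwiseFusionPass is the one pass here that matches in both directions, selected by the is_order flag: the conv feeding an element-wise op, or a ReLU feeding the conv. A compact sketch of the two orientations, with illustrative stand-in types:

    #include <memory>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;
    struct Node {
      bool is_depthwise_conv = false;
      bool is_relu = false;  // covers both kPrimRelu and kPrimReluV2 in the real pass
      NodePtr input;
    };

    bool MatchesDepthwisePair(const NodePtr &cnode, bool is_order) {
      const NodePtr producer = cnode ? cnode->input : nullptr;
      if (!producer) return false;
      // is_order == true : cnode is the element-wise consumer, producer the depthwise conv.
      // is_order == false: cnode is the depthwise conv, producer the ReLU.
      return is_order ? producer->is_depthwise_conv : producer->is_relu;
    }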
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class DepthwiseConvEltwiseFusionPass : public FusionBasePass {
+ public:
+  explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {}
+  ~DepthwiseConvEltwiseFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                              FusedNodeRecord *candidate_fusion, bool is_order);
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..1e24cce0e4
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.cc
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + if (record.size() == MAX_ELTWISE_SIZE) { + break; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } + if (record.size() < MIN_ELTWISE_SIZE) { + return; + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h new file mode 100644 index 0000000000..ae63687631 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
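EltwiseFusionPass reverses the TopoSort output before scanning, so each walk starts at the last element-wise node of a chain, whose consumers were already visited, and can absorb the whole chain in one pass; a chain is only kept when its size lands in the [MIN_ELTWISE_SIZE, MAX_ELTWISE_SIZE] window. A simplified standalone sketch, using an illustrative Node model and a claimed set in place of the fusion-id attribute (the real pass records every qualifying chain this way):

    #include <algorithm>
    #include <cstddef>
    #include <memory>
    #include <unordered_set>
    #include <vector>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;
    struct Node { bool is_eltwise = false; NodePtr input; };

    constexpr std::size_t kMinEltwiseSize = 2;  // MIN_ELTWISE_SIZE
    constexpr std::size_t kMaxEltwiseSize = 6;  // MAX_ELTWISE_SIZE

    std::vector<std::unordered_set<NodePtr>> CollectChains(std::vector<NodePtr> topo) {
      std::reverse(topo.begin(), topo.end());  // visit sinks before their producers
      std::vector<std::unordered_set<NodePtr>> chains;
      std::unordered_set<NodePtr> claimed;     // stand-in for HasFusionIdAttr
      for (const auto &n : topo) {
        if (!n->is_eltwise || claimed.count(n) != 0) continue;
        std::unordered_set<NodePtr> record{n};
        NodePtr cur = n->input;
        while (cur && cur->is_eltwise && claimed.count(cur) == 0 && record.size() < kMaxEltwiseSize) {
          record.insert(cur);
          cur = cur->input;
        }
        if (record.size() < kMinEltwiseSize) continue;  // too short to pay off
        for (const auto &m : record) claimed.insert(m);
        chains.push_back(record);
      }
      return chains;
    }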
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class EltwiseFusionPass : public FusionBasePass { + public: + explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} + ~EltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc new file mode 100644 index 0000000000..27a7a786d1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include <memory>
+#include <string>
+#include "debug/anf_ir_dump.h"
+#include "utils/context/ms_context.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(manager);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto user_nodes = manager->node_users()[node];
+  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
+         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE &&
+         cnode->inputs().size() == ELTWISE_INPUT_SIZE;
+}
+
+bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(manager);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto user_nodes = manager->node_users()[node];
+  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
+         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE &&
+         cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
+}
+
+bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(manager);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto user_nodes = manager->node_users()[node];
+  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
+         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_MULTI_USE &&
+         cnode->inputs().size() == ELTWISE_INPUT_SIZE;
+}
+
+void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
+  auto id = fusion_id_allocator->AllocateFusionId();
+  for (auto node : record) {
+    fusion_id_allocator->SetFusionId(node, id);
+  }
+}
+
+bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto return_node = kernel_graph.get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->inputs().size() <= 1) {
+    return false;
+  }
+  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
+  FusedNodeRecord candidate_fusion;
+  MatchSingleFusionPattern(kernel_graph, &candidate_fusion);
+  if (candidate_fusion.empty()) {
+    return false;
+  }
+  MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
+  return true;
+}
+
+bool FusionBasePass::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  return MatchUBFusionPattern(*kernel_graph);
+}
+}  // namespace opt
+}  // namespace mindspore
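The three Check*EltWiseNode helpers above share one shape and differ only in two numbers: the expected user count and the expected input arity. Note that cnode->inputs() counts the primitive at index 0, so ELTWISE_INPUT_SIZE == 2 means a single data operand. A parameterized sketch of that shared shape, where NodeInfo is an illustrative stand-in rather than the real API:

    #include <cstddef>

    struct NodeInfo {
      bool is_real_tbe_eltwise = false;  // TBE_KERNEL, FusionType ELEMWISE, no fusion id yet
      std::size_t num_users = 0;         // manager->node_users()[node].size()
      std::size_t num_inputs = 0;        // cnode->inputs().size(), primitive included
    };

    bool CheckEltwiseShape(const NodeInfo &n, std::size_t users, std::size_t inputs) {
      return n.is_real_tbe_eltwise && n.num_users == users && n.num_inputs == inputs;
    }

    // CheckEltWiseNode            -> CheckEltwiseShape(n, ELTWISE_USE,       ELTWISE_INPUT_SIZE)           == (1, 2)
    // CheckDoubleInEltWiseNode    -> CheckEltwiseShape(n, ELTWISE_USE,       ELTWISE_DOUBLE_IN_INPUT_SIZE) == (1, 3)
    // CheckMultiOutputEltWiseNode -> CheckEltwiseShape(n, ELTWISE_MULTI_USE, ELTWISE_INPUT_SIZE)           == (2, 2)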
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h
new file mode 100644
index 0000000000..dced2c2fa2
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
+#include <vector>
+#include <string>
+#include <memory>
+#include <unordered_set>
+
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+const int8_t MAX_ELTWISE_NUM = 3;
+const int8_t MIN_ELTWISE_SIZE = 2;
+const int8_t ELTWISE_INPUT_SIZE = 2;
+const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
+const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
+const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
+const int8_t ELTWISE_USE = 1;
+const int8_t ELTWISE_MULTI_USE = 2;
+const int8_t MAX_ELTWISE_SIZE = 6;
+const int8_t MULTI_ELTWISE_SIZE = 4;
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+struct BufferFusionInfo_t {
+  std::vector<AnfNodePtr> anf_nodes;
+  std::vector<AnfNodePtr> inputs_list;
+  std::vector<AnfNodePtr> outputs_list;
+  kernel::KernelBuildInfoPtr kernel_build_info;
+};
+
+class FusionBasePass : public Pass {
+ public:
+  FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator)
+      : Pass(name), fusion_id_allocator(idAllocator) {}
+  ~FusionBasePass() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+  bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph);
+
+ protected:
+  virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
+                                        FusedNodeRecord *candidate_fusion) = 0;
+  void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
+  bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
+  bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
+  bool CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
+  FusionIdAllocatorPtr fusion_id_allocator;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..7fcc6e45e0
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector output_used_num{SizeToInt(manager->node_users()[relu_input].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); + std::unordered_set record{cnode, relu_input}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) { + MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h new file mode 100644 index 0000000000..e0d08bb58d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class MatmulEltwiseFusionPass : public FusionBasePass { + public: + explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} + ~MatmulEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc new file mode 100644 index 0000000000..58a219aec7 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); + if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { + std::vector output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } else { + return; + } + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + if (record.size() == MULTI_ELTWISE_SIZE) { + break; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + } + if (record.size() != MULTI_ELTWISE_SIZE) { + return; + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); +} + +void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h new file mode 100644 index 0000000000..40a45360a1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/multi_output_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
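MultiOutputFusionPass is stricter than the plain eltwise pass: the chain head must be an element-wise node with two consumers, and the final record must hold exactly MULTI_ELTWISE_SIZE nodes or it is discarded. A sketch of that final gate, with the constant mirrored from fusion_base_pass.h:

    #include <cstddef>

    constexpr std::size_t kMultiEltwiseSize = 4;  // MULTI_ELTWISE_SIZE

    // Mirrors: if (record.size() != MULTI_ELTWISE_SIZE) return;
    bool MultiOutputRecordFusable(std::size_t record_size) {
      return record_size == kMultiEltwiseSize;
    }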
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class MultiOutputFusionPass : public FusionBasePass { + public: + explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("MultiOutputFusionPass", idAllocator) {} + ~MultiOutputFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..95955818eb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) { + (void)record.insert(eltwise_input); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); + auto previous_size = record.size(); + while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { + (void)record.insert(previous_eltwise_input); + auto previous_node = previous_eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_node); + previous_eltwise_input = previous_node->input(1); + if (record.size() - previous_size == MAX_ELTWISE_NUM) { + break; + } + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchReduceEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h new file mode 100644 index 0000000000..4d56eee7b3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
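MatchReduceEltwise builds a window centered on a COMMREDUCE node: up to MAX_ELTWISE_NUM element-wise nodes behind it and up to MAX_ELTWISE_NUM in front of it. A compact standalone sketch of the two bounded walks; Node is an illustrative stand-in for the MindSpore IR.

    #include <cstddef>
    #include <memory>
    #include <unordered_set>

    struct Node;
    using NodePtr = std::shared_ptr<Node>;
    struct Node { bool is_eltwise = false; bool is_reduce = false; NodePtr input; };

    constexpr std::size_t kMaxEltwiseNum = 3;  // mirrors MAX_ELTWISE_NUM

    std::unordered_set<NodePtr> MatchReduceCentered(const NodePtr &tail) {
      std::unordered_set<NodePtr> record{tail};
      NodePtr cur = tail->input;
      while (cur && cur->is_eltwise && record.size() < kMaxEltwiseNum) {  // suffix window
        record.insert(cur);
        cur = cur->input;
      }
      if (!cur || !cur->is_reduce) return {};  // the window must end at a reduce
      record.insert(cur);
      const std::size_t before = record.size();
      cur = cur->input;
      while (cur && cur->is_eltwise && record.size() - before < kMaxEltwiseNum) {  // prefix window
        record.insert(cur);
        cur = cur->input;
      }
      return record;
    }

The segment pass that follows uses the same structure with FusionType::SEGMENT in place of COMMREDUCE.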
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class ReduceEltwiseFusionPass : public FusionBasePass {
+ public:
+  explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {}
+  ~ReduceEltwiseFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                          FusedNodeRecord *candidate_fusion);
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..f2117f9374
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc
@@ -0,0 +1,92 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + while (CheckEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + eltwise_input = input_cnode->input(1); + if (record.size() == MAX_ELTWISE_NUM) { + break; + } + } + MS_EXCEPTION_IF_NULL(eltwise_input); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) { + (void)record.insert(eltwise_input); + auto previous_input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_input_cnode); + auto previous_eltwise_input = previous_input_cnode->input(1); + auto previous_size = record.size(); + while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { + (void)record.insert(previous_eltwise_input); + auto previous_node = previous_eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(previous_node); + previous_eltwise_input = previous_node->input(1); + if (record.size() - previous_size == MAX_ELTWISE_NUM) { + break; + } + } + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + std::reverse(node_list.begin(), node_list.end()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { + MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h new file mode 100644 index 0000000000..f3b97f8357 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class SegmentEltwiseFusionPass : public FusionBasePass {
+ public:
+  explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
+      : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {}
+  ~SegmentEltwiseFusionPass() override = default;
+  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
+
+ private:
+  void MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                           FusedNodeRecord *candidate_fusion);
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
new file mode 100644
index 0000000000..d93b47b66c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h" + +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "utils/context/ms_context.h" +#include "backend/optimizer/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto write_input = cnode->input(1); + if (CheckEltWiseNode(manager.get(), write_input)) { + (void)record.insert(write_input); + auto input_cnode = write_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + write_input = input_cnode->input(1); + } + MS_EXCEPTION_IF_NULL(write_input); + if (!write_input->isa() || !AnfAlgo::IsRealCNodeKernel(write_input) || + fusion_id_allocator->HasFusionIdAttr(write_input)) { + return; + } + auto conv_cnode = write_input->cast(); + MS_EXCEPTION_IF_NULL(conv_cnode); + if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONVLUTION && + conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE && + conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { + (void)record.insert(write_input); + auto conv_input = conv_cnode->input(1); + MS_EXCEPTION_IF_NULL(conv_input); + if (!conv_input->isa() || !AnfAlgo::IsRealCNodeKernel(conv_input) || + fusion_id_allocator->HasFusionIdAttr(conv_input)) { + return; + } + if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { + (void)record.insert(conv_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void StridedReadConvStridedWriteFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kStridedWriteOpName) { + MatchStridedReadConvStridedWrite(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h new file mode 100644 index 0000000000..371c206399 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ + +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class StridedReadConvStridedWriteFusionPass : public FusionBasePass { + public: + explicit StridedReadConvStridedWriteFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("StridedReadConvStridedWriteFusionPass", idAllocator) {} + ~StridedReadConvStridedWriteFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchStridedReadConvStridedWrite(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc new file mode 100644 index 0000000000..9685530705 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -0,0 +1,448 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +namespace { +const int8_t MAX_PATTERN_SIZE = 7; +const int8_t MIN_PATTERN_SIZE = 2; +const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_USE = 1; +const int8_t MULTI_ELTWISE_USE = 2; +const int8_t MAX_MULTI_ELTWISE_SIZE = 4; +const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; +constexpr auto kOpAttrFusionId = "fusion_id"; + +#ifdef DEBUG +std::string GetFusionTypeName(const kernel::FusionType &type) { + switch (type) { + case kernel::FusionType::COMMREDUCE: + return "COMMREDUCE"; + case kernel::FusionType::SEGMENT: + return "SEGMENT"; + case kernel::FusionType::ELEMWISE: + return "ELEMWISE"; + case kernel::FusionType::CONVLUTION: + return "CONVLUTION"; + case kernel::FusionType::OPAQUE: + return "OPAQUE"; + default: + return "OPAQUE"; + } +} + +void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { + MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; + for (auto &node : info.input_nodes) { + MS_LOG(INFO) << "=== Input: " << node->DebugString(); + } + for (auto &node : info.output_nodes) { + MS_LOG(INFO) << "=== Output: " << node->DebugString(); + } + for (auto &node : info.compute_nodes) { + MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) + << ")"; + } + MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; +} +#endif +CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, + const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { + MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; + MS_EXCEPTION_IF_NULL(kernel_graph); + std::string fusion_op_name = "FusionOp"; + for (auto node : anf_nodes) { + fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); + } + auto fusion_op = std::make_shared(fusion_op_name); + MS_EXCEPTION_IF_NULL(fusion_op); + + std::vector input_names; + for (uint8_t i = 0; i < inputs_list.size(); i++) { + input_names.emplace_back("input" + std::to_string(i)); + } + std::vector output_names; + for (uint8_t i = 0; i < outputs_list.size(); i++) { + output_names.emplace_back("output" + std::to_string(i)); + } + + ValuePtr input_names_v = MakeValue(input_names); + ValuePtr output_names_v = MakeValue(output_names); + fusion_op->set_attr("input_names", input_names_v); + fusion_op->set_attr("output_names", output_names_v); + std::vector fusion_inputs_list = inputs_list; + auto value_node = std::make_shared(fusion_op); + (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); + auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); + if (buffer_fusion_kernel == nullptr) { + MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; + } + buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); + + return buffer_fusion_kernel; +} + +kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, + const std::vector &outputs_list) { + MS_LOG(DEBUG) << "Start Create Kernel Info"; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + // inputs format and data type + std::vector inputs_format; + std::vector inputs_data_type; + for (const auto 
&input : inputs_list) { + auto real_input = AnfAlgo::VisitKernel(input, 0); + inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); + inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); + } + // outputs format and data type + std::vector outputs_format; + std::vector outputs_data_type; + for (const auto &output : outputs_list) { + if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto tuple_getitem = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + outputs_format.push_back(AnfAlgo::GetOutputFormat( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + } else { + outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); + } + } + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_data_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_data_type); + builder.SetKernelType(KernelType::TBE_KERNEL); + return builder.Build(); +} + +AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, + size_t output_index) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector tuple_getitem_inputs_list; + auto value = std::make_shared(prim::kPrimTupleGetItem); + MS_EXCEPTION_IF_NULL(value); + auto idx = NewValueNode(SizeToInt(output_index)); + MS_EXCEPTION_IF_NULL(idx); + int temp = SizeToInt(output_index); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + tuple_getitem_inputs_list.push_back(value); + tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); + tuple_getitem_inputs_list.push_back(idx); + auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); + MS_EXCEPTION_IF_NULL(tuple_item); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, + {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, + tuple_item.get()); + return tuple_item; +} + +void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const AnfNodePtr &output_item, + const AnfNodePtr &replace_item) { + for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { + auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), + output_item); + if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { + MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; + *itr = replace_item; + } + } +} + +void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, + const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + if (buffer_fusion_info.outputs_list.size() == 1) { // single output + (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], + buffer_fusion_kernel); + } else { // multiple output + for (size_t index = 0; index < 
buffer_fusion_info.outputs_list.size(); ++index) { + auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); + (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], + tuple_item); + } + } +} + +void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto nodes = TopoSort(kernel_graph->get_return()); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { + auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); + (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); + } + } +} + +void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == + fusion_info.anf_nodes.end()) { + if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), + (*buffer_fusion_infos)[fusion_id].inputs_list.end(), + cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); + } + } + } + } + } +} + +bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + auto getitem1 = node1->cast(); + auto getitem2 = node2->cast(); + MS_EXCEPTION_IF_NULL(getitem1); + MS_EXCEPTION_IF_NULL(getitem2); + if (getitem1->size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" + << getitem1->DebugString() << "]"; + } + if (getitem2->size() < kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" + << getitem2->DebugString() << "]"; + } + auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); + auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); + return output_idx1 < output_idx2; +} + +void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + if (AnfAlgo::GetOutputTensorNum(node) == 1) { + for (auto use_node : manager->node_users()[node]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == + 
fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); + break; + } + } + } else { + int prev_idx = 0; + std::vector tuple_getitem_nodes; + std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), + std::back_inserter(tuple_getitem_nodes), + [](const std::pair &use_node) { return use_node.first; }); + std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); + for (auto getitem : tuple_getitem_nodes) { + MS_EXCEPTION_IF_NULL(getitem); + auto getitem_ptr = getitem->cast(); + auto input2 = getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { + auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); + } + prev_idx = output_idx + 1; + for (auto item_use_node : manager->node_users()[getitem]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == + fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); + break; + } + } + } + } + } + } +} + +void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, + const AnfNodePtr &fusion_kernel) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (size_t idx = 0; idx < outputs_list.size(); ++idx) { + auto output = outputs_list[idx]; + MS_EXCEPTION_IF_NULL(output); + if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto real_output = AnfAlgo::VisitKernel(output, 0); + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto input2 = output_cnode->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + session::AnfWithOutIndex out_pair(real_output.first, output_idx); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } else { + session::AnfWithOutIndex out_pair(output, 0); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } + } +} +} // namespace + +void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) const { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); + GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); + GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + buffer_fusion_info.second.kernel_build_info = + CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); + } +} + +bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + bool change = false; + std::unordered_map buffer_fusion_infos; + buffer_fusion_infos.clear(); + GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); + + std::vector fusion_scope_infos; + for (auto &buffer_fusion_info : 
buffer_fusion_infos) { + mindspore::kernel::FusionScopeInfo fusion_scope_info; + fusion_scope_info.scope_id = buffer_fusion_info.first; + fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; + fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; + fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; + fusion_scope_infos.push_back(fusion_scope_info); +#ifdef DEBUG + DumpFusionScopeInfo(fusion_scope_info); +#endif + } + auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); + std::vector fusion_ids; + for (auto &buffer_fusion_info : buffer_fusion_infos) { + MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() + << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() + << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); + fusion_ids.push_back(buffer_fusion_info.first); + } + // Replace fusion op from return to head + std::sort(fusion_ids.begin(), fusion_ids.end()); + for (auto &fusion_id : fusion_ids) { + // Get kernel mod when supporting tbe + if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { + MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; + continue; + } + change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); + } + MS_LOG(DEBUG) << "End Buffer Fusion"; + return change; +} + +bool UbPatternFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, + session::KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, + buffer_fusion_info.anf_nodes, kernel_graph); + AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); + // Set abstract of fusion_op node + std::vector types; + std::vector> shapes; + for (const auto &out_node : buffer_fusion_info.outputs_list) { + for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { + types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); + shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); + } + } + if (types.empty() || shapes.empty()) { + MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; + return false; + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); + AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); + SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); + ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); + return true; +} + +bool UbPatternFusion::Run(const FuncGraphPtr &graph) { + bool changed = false; + MS_EXCEPTION_IF_NULL(graph); + auto kernel_graph = graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); + changed = FuseBufferFusionPattern(kernel_graph.get()); + // clear fusion_id attr + for (auto &node : graph->nodes()) { + if (node != nullptr && node->isa()) { + AnfAlgo::EraseNodeAttr(kAttrFusionId, node); + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h new file mode 100644 index 0000000000..69eb0f43d4 --- /dev/null +++ 
b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ +#include +#include +#include + +#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/fusion_id_allocator.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class UbPatternFusion : public Pass { + public: + UbPatternFusion() : Pass("TbeBufferFusion") {} + ~UbPatternFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + void GetBufferFusionInfo(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) const; + bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, + const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; + bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc new file mode 100644 index 0000000000..a729cdd0f9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+
+namespace mindspore::opt {
+
+const BaseRef GetnextMemcpyElimination::DefinePattern() const {
+  auto prim_memcpy = std::make_shared<Primitive>(kMemCpyAsyncOpName);
+  VarPtr x = std::make_shared<Var>();
+  VectorRef memcpy_async({prim_memcpy, x});
+  return memcpy_async;
+}
+
+const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  if (graph == nullptr || node == nullptr || equiv == nullptr) {
+    return nullptr;
+  }
+  auto memcpy_cnode = node->cast<CNodePtr>();
+  if (memcpy_cnode == nullptr) {
+    return nullptr;
+  }
+
+  // 1. memcpy_async must carry the attr kAttrLabelForInsertStreamActive
+  if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) {
+    MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr";
+    return nullptr;
+  }
+
+  // 2. memcpy_async's output must have exactly one user, next_node
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(memcpy_cnode) == manager->node_users().end()) {
+    MS_LOG(EXCEPTION) << "memcpy has no output in manager";
+  }
+  auto next_nodes = manager->node_users()[memcpy_cnode];
+  if (next_nodes.size() > 1) {
+    MS_LOG(DEBUG) << "node's output has more than one user";
+    return nullptr;
+  }
+
+  // 3. next_node must not be a nop node and its only input must be memcpy_async's output
+  for (auto &item : next_nodes) {
+    auto next_node = item.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(next_node);
+    if (opt::IsNopNode(next_node)) {
+      return nullptr;
+    }
+    if (next_node->inputs().size() != 2) {
+      MS_LOG(DEBUG) << "next node has more than one input";
+      return nullptr;
+    }
+    // move the attr label_for_insert_stream_active onto next_node
+    AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), next_node);
+  }
+
+  return memcpy_cnode->input(1);
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h
new file mode 100644
index 0000000000..365088b34a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class GetnextMemcpyElimination : public PatternProcessPass { + public: + explicit GetnextMemcpyElimination(bool multigraph = true) + : PatternProcessPass("getnext_memcpy_elimination", multigraph) {} + ~GetnextMemcpyElimination() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc new file mode 100644 index 0000000000..bac9f54ace --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include +#include +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (func_graph == nullptr || node == nullptr) { + return nullptr; + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num == 0) { + MS_LOG(DEBUG) << "Output number is zero, no need to insert memcpy_async!"; + return node; + } + + // getnext output is tuple and dynamic + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (size_t output_index = 0; output_index < output_num; ++output_index) { + auto tuple_get_item = CreatTupleGetItemNode(func_graph, node, output_index); + auto new_node = CreateMemcpyAsyncOp(func_graph, tuple_get_item); + if (new_node == nullptr) { + MS_LOG(EXCEPTION) << "Create memcpy_async op failed!"; + } + AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node); + make_tuple_inputs.push_back(new_node); + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} + +const BaseRef InsertMemcpyAsyncForGetNext::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + auto prim = std::make_shared(kGetNextOpName); + + return VectorRef({prim, Xs}); +} + +const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + 
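+  // The rewritten output still contains the original GetNext node, so the
+  // pattern can match it again on a later pass iteration; the kAttrVisited
+  // guard below makes the pass idempotent and inserts the memcpy_async chain
+  // at most once per GetNext node.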
+ auto cnode = node->cast(); + if (AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) { + MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited."; + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode); + + return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h new file mode 100644 index 0000000000..6fefc32230 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class InsertMemcpyAsyncForGetNext : public PatternProcessPass { + public: + explicit InsertMemcpyAsyncForGetNext(bool multigraph = true) + : PatternProcessPass("insert_memcpy_async_for_getnext", multigraph) {} + ~InsertMemcpyAsyncForGetNext() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc new file mode 100644 index 0000000000..2585006be6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include +#include +#include +#include "utils/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +namespace { +// insert memcpy for some cnode even if not a Ref cnode +const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, + kLambUpdateWithLROpName}; + +bool IsParameterOrValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + return kernel_with_index.first->isa() || kernel_with_index.first->isa(); +} + +void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(hccl_node); + MS_EXCEPTION_IF_NULL(memcpy_async); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(hccl_node); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // find hccl_node's output which is a control depend + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + int output_index = node_index.second; + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + CNodePtr control_depend = output->cast(); + MS_EXCEPTION_IF_NULL(control_depend); + std::vector new_inputs; + for (size_t i = 0; i < control_depend->size(); ++i) { + if (i == IntToSize(output_index)) { + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(control_depend->input(i)); + } + } + control_depend->set_inputs(new_inputs); + } + } +} +} // namespace + +bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input); + // when input is a parameter or is a value node + if (IsParameterOrValueNode(input)) { + return true; + } + + // when input is a Ref or some special cnodes + if (kernel_query_->IsTbeRef(input) || + kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { + return true; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(input); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + // when input is used by others + if (iter->second.size() > 1) { + return true; + } + return false; +} + +void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + bool has_insert_memcpy = false; + AnfNodePtr memcpy_async = nullptr; + std::vector new_inputs = {hccl_node->input(0)}; + for (size_t i = 1; i < hccl_node->size(); ++i) { + auto input = hccl_node->input(i); + if (NeedInsertMemcpy(graph, input)) { + memcpy_async = CreateMemcpyAsyncOp(graph, input); + has_insert_memcpy = true; + new_inputs.push_back(memcpy_async); + } else { + new_inputs.push_back(input); + } + } + + if (has_insert_memcpy) { + CNodePtr new_hccl_node = std::make_shared(*hccl_node); + new_hccl_node->set_inputs(new_inputs); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; + 
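+      // A fresh copy of hccl_node is replaced into the graph instead of
+      // mutating it in place, so the manager rewires every user of the old
+      // node to new_hccl_node in a single step.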
(void)manager->Replace(hccl_node, new_hccl_node);
+    MS_LOG(DEBUG) << "end replace";
+
+    // transfer the hccl op's control edge to the memcpy_async
+    if (hccl_node->size() == 2) {
+      TransferControl(new_hccl_node, memcpy_async, graph);
+    }
+  }
+}
+
+const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                     const EquivPtr &) const {
+  if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (!AnfAlgo::IsCommunicationOp(node)) {
+    return nullptr;
+  }
+  InsertMemcpyAsync(func_graph, cnode);
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h
new file mode 100644
index 0000000000..7bd730a84d
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
+ public:
+  explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true)
+      : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph),
+        kernel_query_(std::make_shared<KernelQuery>()) {}
+  ~InsertMemcpyAsyncForHcclOp() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
+  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
+  KernelQueryPtr kernel_query_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc
new file mode 100644
index 0000000000..be61833fe4
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" +#include +#include +#include +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler//oplib/oplib.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef InsertPadForNMSWithMask::DefinePattern() const { + VarPtr Xs = std::make_shared(); + return VectorRef({prim::kPrimNMSWithMask, Xs}); +} + +AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, + const std::vector &origin_shape) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector new_pad_inputs; + auto prim = std::make_shared(prim::kPrimPad->name()); + new_pad_inputs.push_back(NewValueNode(prim)); + new_pad_inputs.push_back(input); + CNodePtr pad = func_graph->NewCNode(new_pad_inputs); + MS_EXCEPTION_IF_NULL(pad); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get()); + return pad; +} + +const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + + size_t input_num = AnfAlgo::GetInputTensorNum(node); + if (input_num == 0) { + return nullptr; + } + std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; + for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { + auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); + auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); + auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); + if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) { + return nullptr; + } + origin_shape[1] = 8; + auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); + MS_EXCEPTION_IF_NULL(pad); + pad->set_scope(cnode->scope()); + AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector>{{0, 0}, {0, 3}}), pad); + new_inputs.push_back(pad); + } + auto kernel_graph = func_graph->cast>(); + CNodePtr new_node = nullptr; + if (kernel_graph == nullptr) { + new_node = std::make_shared(*cnode); + } else { + new_node = kernel_graph->NewCNode(cnode); + } + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_inputs(new_inputs); + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h new file mode 100644 index 0000000000..6aed678ff2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +class InsertPadForNMSWithMask : public PatternProcessPass { + public: + explicit InsertPadForNMSWithMask(bool multigraph = true) + : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {} + ~InsertPadForNMSWithMask() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc new file mode 100644 index 0000000000..f508bb2868 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
+
+#include <string>
+#include <memory>
+#include <vector>
+#include <map>
+
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+using ConvertFunction = std::function<void(const CNodePtr &)>;
+
+void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode);
+const size_t kAxis_H = 2;
+const size_t kAxis_W = 3;
+const size_t kAxis_6HD_H = 1;
+const size_t kAxis_6HD_W = 2;
+const std::map<std::string, ConvertFunction> kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD},
+                                                                  {kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}};
+void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int> &reduce_axis) {
+  if (reduce_axis.empty()) {
+    MS_LOG(EXCEPTION) << "The node " << cnode->DebugString() << "'s reduce axis got an empty vector";
+  }
+  if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
+      AnfAlgo::GetInputTensorNum(cnode) != 1) {
+    MS_LOG(EXCEPTION) << "The kind of reduce node [" << cnode->DebugString()
+                      << "] is not single input or single output";
+  }
+  for (auto elem : reduce_axis) {
+    if (elem > 4) {
+      MS_LOG(INFO) << "Reduce axis is larger than 4 dims, reduce axis: [" << elem << "]";
+    }
+  }
+}
+
+void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) {
+  auto axis = kernel::GetReduceAttrAxis(cnode);
+  std::vector<int> convert_axis;
+  SafeCheckFunction(cnode, axis);
+  auto format = AnfAlgo::GetInputFormat(cnode, 0);
+  if (format != kOpFormat_FRAC_Z && format != kOpFormat_C1HWNCoC0) {
+    MS_LOG(EXCEPTION) << "The node [" << cnode->DebugString() << "] format " << format
+                      << " is not FRAC_Z or C1HWNCoC0";
+  }
+  for (auto elem : axis) {
+    switch (elem) {
+      case kAxis_H:
+        convert_axis.emplace_back(kAxis_6HD_H);
+        break;
+      case kAxis_W:
+        convert_axis.emplace_back(kAxis_6HD_W);
+        break;
+      default:
+        MS_LOG(INFO) << "Reduce axis [" << elem << "] is not supported by this format, skip converting it";
+    }
+  }
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(convert_axis), cnode);
+}
+}  // namespace
+
+const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({X, Xs});
+}
+
+const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node,
+                                                   const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
+    return nullptr;
+  }
+  if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
+    return nullptr;
+  }
+  auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
+  if (convert_map == kReduceConvertMap.end()) {
+    return nullptr;
+  }
+  convert_map->second(node->cast<CNodePtr>());
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h
new file mode 100644
index 0000000000..6bf1287ae7
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ChangeAxisOfReduceKernel : public PatternProcessPass {
+ public:
+  explicit ChangeAxisOfReduceKernel(bool multigraph = true)
+      : PatternProcessPass("change_axis_of_reduce_kernel", multigraph) {}
+  ~ChangeAxisOfReduceKernel() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc
new file mode 100644
index 0000000000..7da0027310
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/format_type/check_consistency.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "common/utils.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) {
+  MS_EXCEPTION_IF_NULL(node);
+  // get prior node's device output format
+  string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index);
+  string selected_input_format = AnfAlgo::GetInputFormat(node, input_index);
+  if (pre_output_format == selected_input_format) {
+    return true;
+  }
+  auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index);
+  if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) {
+    string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format;
+    // when input shape size is 1D, default format and NC1HWC0 are compatible
+    if (input_origin_shape.size() == 1 && checking_format == kOpFormat_NC1HWC0) {
+      return true;
+    }
+    if (kDefaultCompatibleFormat.find(checking_format) != kDefaultCompatibleFormat.end()) {
+      return true;
+    }
+  }
+  if (input_origin_shape.size() == 0) {
+    return true;
+  }
+  MS_LOG(ERROR) << "Found inconsistent format! input format " << input_index << ": " << pre_output_format
+                << ", selected input format: " << selected_input_format;
+  return false;
+}
+
+bool CheckDataTypeForConsistency(const CNodePtr &node, const size_t input_index) {
+  MS_EXCEPTION_IF_NULL(node);
+  TypeId input_data_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(node, input_index);
+  TypeId selected_data_type = AnfAlgo::GetInputDeviceDataType(node, input_index);
+  if (input_data_type == selected_data_type) {
+    return true;
+  }
+  MS_LOG(ERROR) << "Found inconsistent dtype! input dtype " << input_index << ": " << TypeIdLabel(input_data_type)
+                << ", selected dtype: " << TypeIdLabel(selected_data_type);
+  return false;
+}
+}  // namespace
+
+const BaseRef CheckConsistency::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({X, Xs});
+}
+
+const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
+    return nullptr;
+  }
+
+  std::vector<AnfNodePtr> todos = {node};
+  if (AnfAlgo::IsGraphKernel(node)) {
+    auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+    MS_EXCEPTION_IF_NULL(sub_graph);
+    kernel::GetValidKernelNodes(sub_graph, &todos);
+  }
+
+  for (auto &t : todos) {
+    CNodePtr cnode = t->cast<CNodePtr>();
+    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) {
+      if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) {
+        MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "["
+                          << cnode->DebugString() << "]";
+      }
+    }
+  }
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h
new file mode 100644
index 0000000000..bf956895de
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/check_consistency.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class CheckConsistency : public PatternProcessPass { + public: + explicit CheckConsistency(bool multigraph = true) : PatternProcessPass("check_consistency", multigraph) {} + ~CheckConsistency() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc new file mode 100644 index 0000000000..48948dca06 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/kernel_query.h" +namespace mindspore { +namespace opt { +const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, + const mindspore::AnfNodePtr &node, + const mindspore::EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto node_name = AnfAlgo::GetCNodeName(node); + if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { + return nullptr; + } + auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); + if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { + return nullptr; + } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { + auto builder = std::make_shared(kernel_builder_info); + builder->SetKernelType(AICPU_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); + } else { + MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" + << node->DebugString() << "]"; + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h new file mode 100644 
index 0000000000..e534a851ad --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_unsupported_transnode_to_aicpu.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H +namespace mindspore { +namespace opt { +class ConvertUnSupportNodeToAICPU : public PatternProcessPass { + public: + explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) + : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), + supported_checker_(std::make_shared()) {} + ~ConvertUnSupportNodeToAICPU() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SupportedCheckerPtr supported_checker_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc new file mode 100644 index 0000000000..3dbe2d9f8a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -0,0 +1,226 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
+  session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0);
+  AnfNodePtr cur_node = kernel_with_index.first;
+  size_t cur_out_index = kernel_with_index.second;
+  MS_EXCEPTION_IF_NULL(cur_node);
+  if (cur_node->isa<CNode>()) {
+    auto cnode = cur_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::string op_name = AnfAlgo::GetCNodeName(cnode);
+    auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
+    // deal ref op
+    if (op_info != nullptr && op_info->is_ref()) {
+      auto ref_infos = op_info->ref_infos();
+      if (ref_infos.count(cur_out_index) != 0) {
+        auto in_index = ref_infos.at(cur_out_index);
+        if (in_index > cnode->inputs().size()) {
+          MS_LOG(EXCEPTION) << "Ref op has wrong inputs: op inputs num is " << cnode->inputs().size()
+                            << ", ref info is " << cur_out_index;
+        }
+        AnfNodePtr next_node = cnode->input(in_index + 1);
+        return FindRefOriginNode(next_node);
+      }
+    }
+
+    // deal special (trans, cast, reshape) op
+    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
+        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
+      AnfNodePtr next_node = cnode->input(1);
+      return FindRefOriginNode(next_node);
+    }
+  }
+
+  return kernel_with_index;
+}
+
+void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item,
+                             const AnfNodePtr &final_node, size_t final_index,
+                             const session::KernelWithIndex &origin_pair) {
+  // record the ref_pair
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  // if the final node is the get_item, no trans or cast op was added, so the final node is the cnode itself;
+  // add the pair for the cnode, because the get_item will be removed later
+  auto final_ref = (final_node == get_item ? cnode : final_node);
+  session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index);
+  if (kernel_graph->IsInRefOutputMap(final_pair)) {
+    MS_LOG(EXCEPTION) << "Ref pair is already in ref map, node is " << final_ref->DebugString() << ", index is "
+                      << final_index;
+  }
+  MS_LOG(DEBUG) << "Add ref pair, final {node ptr " << final_pair.first.get() << ", info is "
+                << final_pair.first->DebugString() << ", index is " << final_pair.second << "}, origin {node ptr "
+                << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index "
+                << origin_pair.second << "}";
+  kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair);
+}
+
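+// A "ref" op writes one of its outputs through to one of its inputs, so that output
+// must keep the origin parameter's device format and dtype. The helper below walks
+// the input back to its origin with FindRefOriginNode, inserts TransData/Cast when
+// the op's output format or dtype disagrees with the origin, records the
+// (final_node, origin) ref pair in the kernel graph, and finally inserts a Depend
+// on the op so the inserted chain stays ordered after the op itself.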
+// if get_item is nullptr, the additional node will link to the cnode,
+// else the additional node will link to the get_item node (the get_item node links to cnode)
+AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index,
+                                    size_t input_index, const AnfNodePtr &get_item) {
+  AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item);
+  size_t final_index = output_index;
+  AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index);
+  session::KernelWithIndex origin_pair;
+  origin_pair = FindRefOriginNode(input_node);
+  MS_EXCEPTION_IF_NULL(origin_pair.first);
+  if (!origin_pair.first->isa<Parameter>()) {
+    MS_LOG(EXCEPTION) << "Ref op origin node is not a parameter";
+  }
+  MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
+                << origin_pair.first->DebugString() << ", index is " << origin_pair.second;
+  auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second);
+  auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second);
+  auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index);
+  auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index);
+  auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index);
+  // insert trans
+  if (origin_format != cur_format && cur_shape.size() > 1) {
+    auto kernel_select = std::make_shared<KernelSelect>();
+    final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(cur_format, origin_format, final_node);
+    final_index = 0;
+    MS_EXCEPTION_IF_NULL(final_node);
+    MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
+  }
+  // insert cast
+  if (origin_type != cur_type) {
+    final_node =
+      AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type);
+    MS_EXCEPTION_IF_NULL(final_node);
+    final_node->set_scope(cnode->scope());
+    final_index = 0;
+    MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString();
+  }
+  // add ref pair
+  AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair);
+  // insert depend
+  if (origin_format != cur_format || origin_type != cur_type) {
+    std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node};
+    final_node = func_graph->NewCNode(depend_nodes);
+    MS_LOG(INFO) << "DealRefTransAndCast add depend, op debug info is " << final_node->DebugString();
+  }
+
+  return final_node;
+}
+
+AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                    const std::shared_ptr<kernel::OpInfo> &op_info) {
+  MS_EXCEPTION_IF_NULL(op_info);
+  auto ref_infos = op_info->ref_infos();
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  AbstractBasePtrList abstract_list;
+  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
+    AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index);
+    // deal with ref output
+    if (ref_infos.count(output_index) != 0) {
+      auto input_index = ref_infos.at(output_index);
+      final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node);
+    }
+    MS_EXCEPTION_IF_NULL(final_node);
+    abstract_list.push_back(final_node->abstract());
+    make_tuple_inputs.push_back(final_node);
+  }
+  MS_EXCEPTION_IF_NULL(func_graph);
+  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  return make_tuple;
+}
+
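+// Single-output variant of the rewrite above: with no tuple_getitem involved
+// (get_item == nullptr), the trans/cast chain is linked to the op itself, and
+// only the first ref_info entry can apply to a single-output node.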
+AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                              const std::shared_ptr<kernel::OpInfo> &op_info) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(op_info);
+  auto ref_infos = op_info->ref_infos();
+  for (const auto &ref_info : ref_infos) {
+    if (ref_info.second > cnode->inputs().size()) {
+      MS_LOG(EXCEPTION) << "Ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is "
+                        << ref_info.second;
+    }
+    return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr);
+  }
+  return nullptr;
+}
+}  // namespace
+
+const BaseRef DealRefTransAndCast::DefinePattern() const {
+  VarPtr V = std::make_shared<CondVar>(UnVisited);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({V, Xs});
+}
+
+void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) {
+    auto input_size = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
+      auto input_node = input_node_with_index.first;
+      MS_EXCEPTION_IF_NULL(input_node);
+      MS_LOG(INFO) << "Origin node: " << input_node->fullname_with_scope();
+      AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index);
+    }
+  }
+}
+
+const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
+    return nullptr;
+  }
+
+  DealBroadCastAsRef(graph, cnode);
+
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
+  if (op_info == nullptr || !op_info->is_ref()) {
+    return nullptr;
+  }
+  auto type = cnode->Type();
+  MS_EXCEPTION_IF_NULL(type);
+  if (!type->isa<Tuple>()) {
+    return DealRefSigleOutput(graph, cnode, op_info);
+  }
+  return DealRefForMultipleOutput(graph, cnode, op_info);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h
new file mode 100644
index 0000000000..cb3b13dc49
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class DealRefTransAndCast : public PatternProcessPass { + public: + explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} + ~DealRefTransAndCast() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc new file mode 100644 index 0000000000..c3f7900645 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -0,0 +1,195 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/ascend/format_type/insert_cast.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "runtime/device/kernel_info.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "utils/utils.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                       const std::vector<bool> &need_insert_cast) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  AbstractBasePtrList abstract_list;
+  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
+    AnfNodePtr replace_node = nullptr;
+    const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
+    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx);
+    auto idx = NewValueNode(SizeToInt(output_idx));
+    MS_EXCEPTION_IF_NULL(idx);
+    auto imm = std::make_shared<Int32Imm>(SizeToInt(output_idx));
+    idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
+    auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
+    AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get());
+    if (need_insert_cast[output_idx]) {
+      const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx);
+      TypeId origin_type(kTypeUnknown);
+      if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
+        origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
+      }
+      origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
+      const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx);
+      if (origin_type != device_type) {
+        replace_node =
+          AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type);
+        MS_EXCEPTION_IF_NULL(replace_node);
+        replace_node->set_scope(cnode->scope());
+        AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+      } else {
+        replace_node = getitem;
+      }
+    } else {
+      replace_node = getitem;
+    }
+    abstract_list.push_back(replace_node->abstract());
+    make_tuple_inputs.push_back(replace_node);
+  }
+  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
+  return make_tuple;
+}
+
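+// InsertCastForOutput: for a single-output node, cast the device dtype back to
+// the inferred dtype (or the graph-kernel output precision) when the two
+// disagree; for a tuple output it defers to InsertCastForMultipleOutput above.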
+AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                               const std::vector<bool> &need_insert_cast) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (AnfAlgo::GetOutputTensorNum(cnode) == 0) {
+    return cnode;
+  }
+  MS_EXCEPTION_IF_NULL(cnode->Type());
+  // Single output
+  if (!cnode->Type()->isa<Tuple>()) {
+    if (!need_insert_cast[0]) {
+      return cnode;
+    }
+
+    const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0);
+    std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
+    const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0);
+    TypeId origin_type(kTypeUnknown);
+    if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
+      origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode);
+    }
+    origin_type = origin_type == kTypeUnknown ? infer_type : origin_type;
+    const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0);
+    AnfNodePtr replace_node = cnode;
+    if (origin_type != device_type) {
+      replace_node =
+        AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type);
+      MS_EXCEPTION_IF_NULL(replace_node);
+      replace_node->set_scope(cnode->scope());
+      AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
+    }
+    return replace_node;
+  }
+  // Multiple output
+  return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast);
+}
+
+AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  // insert cast for ops in graph kernel.
+  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  MS_EXCEPTION_IF_NULL(sub_graph);
+  auto mng = sub_graph->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  std::vector<AnfNodePtr> todo;
+  std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
+  kernel::GetValidKernelNodes(sub_graph, &todo);
+  kernel::GetGraphRealOutput(sub_graph, &graph_rets);
+  for (auto &t : todo) {
+    AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t);
+    // process input
+    CNodePtr t_cnode = t->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(t_cnode);
+    auto t_new_node = InsertCastForInput(sub_graph, t_cnode);
+    AnfNodePtr t_new_node_1 = nullptr;
+    std::vector<bool> need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true);
+    // process output
+    auto iter = std::find_if(graph_rets.begin(), graph_rets.end(),
+                             [&t](const std::pair<AnfNodePtr, size_t> &ret) { return ret.first == t; });
+    if (iter != graph_rets.end()) {
+      auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t);
+      auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second);
+      auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin());
+      if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) {
+        need_insert_cast[iter->second] = false;
+      } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) {
+        need_insert_cast[iter->second] = false;
+      }
+      t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
+    } else {
+      t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast);
+    }
+
+    if (t_new_node_1 != nullptr && t_new_node_1 != t) {
+      (void)mng->Replace(t, t_new_node_1);
+    }
+  }
+
+  // insert cast for graph kernel.
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
+  // process input
+  CNodePtr cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto new_node = InsertCastForInput(func_graph, cnode);
+  // process output
+  return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
+}
+}  // namespace
+
+const BaseRef InsertCast::DefinePattern() const {
+  VarPtr V = std::make_shared<CondVar>(UnVisited);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({V, Xs});
+}
+
+const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
+    return nullptr;
+  }
+
+  if (AnfAlgo::IsGraphKernel(node)) {
+    return ProcessGraphKernelOp(func_graph, node);
+  }
+  // insert cast for single op.
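+  // Same two steps as the graph-kernel branch above: mark the node visited,
+  // align the input dtypes, then cast every output back to its inferred type.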
+ AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + // process input + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto new_node = InsertCastForInput(func_graph, cnode); + // process output + return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h new file mode 100644 index 0000000000..19c282aac9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ +#include + +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "ir/anf.h" + +namespace mindspore { +namespace opt { +class InsertCast : public PatternProcessPass { + public: + explicit InsertCast(bool multigraph = true) : PatternProcessPass("insert_cast", multigraph) {} + ~InsertCast() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc new file mode 100644 index 0000000000..a22a1faa5f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include +#include +#include "utils/utils.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +const BaseRef InsertTransOp::DefinePattern() const { + std::shared_ptr V = std::make_shared(UnVisited); + std::shared_ptr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +bool IsGraphOutput(const AnfNodePtr &node, const std::vector &outputs) { + auto iter = std::find(outputs.begin(), outputs.end(), node); + if (iter != outputs.end()) { + return true; + } + + return false; +} + +const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + AnfNodePtr front_node; + auto kernel_graph = func_graph->cast>(); + if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { + front_node = kernel_graph->GetFrontNodeByInternalOutput(node); + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + MS_LOG(DEBUG) << "====process op: " << node->DebugString(); + AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { + if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { + return new_node; + } + } + auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); + if (kernel_graph != nullptr && front_node != nullptr) { + auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); + kernel_graph->ReplaceInternalOutput(old_node, final_node); + } + return final_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h new file mode 100644 index 0000000000..0b21375327 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertTransOp : public PatternProcessPass { + public: + explicit InsertTransOp(bool multigraph = true) + : PatternProcessPass("insert_trans_op", multigraph), kernel_select_(std::make_shared()) {} + ~InsertTransOp() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc new file mode 100644 index 0000000000..d0b92b250d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" +#include +#include "utils/utils.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace opt { +const BaseRef RunOpInsertTransData::DefinePattern() const { + std::shared_ptr V = std::make_shared(UnVisited); + MS_EXCEPTION_IF_NULL(V); + std::shared_ptr Xs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + return VectorRef({V, Xs}); +} + +const AnfNodePtr RunOpInsertTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + MS_LOG(DEBUG) << "====process op: " << node->DebugString(); + return InsertTransOpForInput(func_graph, node, kernel_select_); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h new file mode 100644 index 0000000000..82ff5f2b9a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class RunOpInsertTransData : public PatternProcessPass { + public: + explicit RunOpInsertTransData(bool multigraph = true) + : PatternProcessPass("insert_transdata_for_runop", multigraph), + kernel_select_(std::make_shared()) {} + ~RunOpInsertTransData() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc new file mode 100644 index 0000000000..88e9fa77b8 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kCastInputNum = 2; +const size_t kTupleGetitemInputNum = 3; +bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, + const std::shared_ptr &candidate_kernel_info) { + if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) { + return false; + } + + // checkout inputs' fmt and dtype except index equal change_idx + for (size_t i = 0; i < candidate_kernel_info->GetInputNum(); i++) { + if (i == change_idx) { + if (candidate_kernel_info->GetInputDeviceType(i) != dst_type || + candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { + return false; + } + } else if (candidate_kernel_info->GetInputDeviceType(i) != AnfAlgo::GetInputDeviceDataType(node, i) || + candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { + return false; + } + } + + // check outputs's fmt and dtype + for (size_t i = 0; i < candidate_kernel_info->GetOutputNum(); i++) { + if (candidate_kernel_info->GetOutputDeviceType(i) != AnfAlgo::GetOutputDeviceDataType(node, i) || + candidate_kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(node, i)) { + return false; + } + } + return true; +} + +bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, + size_t *cast_index) { + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + if (output_node_list->size() != 1) { + return false; + } + auto node_pair = output_node_list->at(0); + *next_node = node_pair.first; + *cast_index = node_pair.second - 1; + return true; +} + +bool CheckInputs(const CNodePtr &node, const std::shared_ptr &kernel_info) { + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetInputTensorNum(node) != kernel_info->GetInputNum()) { + return false; + } + + for (size_t index = 0; index < kernel_info->GetInputNum(); ++index) { + if (AnfAlgo::GetInputFormat(node, index) != kernel_info->GetInputFormat(index) || + AnfAlgo::GetInputDeviceDataType(node, index) != kernel_info->GetInputDeviceType(index)) { + return false; + } + } + return true; +} + +bool CheckOtherOutputs(const CNodePtr &node, const std::shared_ptr &kernel_info, + const size_t idx) { + MS_EXCEPTION_IF_NULL(kernel_info); + if (AnfAlgo::GetOutputTensorNum(node) != kernel_info->GetOutputNum()) { + return false; + } + for (size_t index = 0; index < kernel_info->GetOutputNum(); ++index) { + if (idx == index) { + continue; + } + if (AnfAlgo::GetOutputFormat(node, index) != kernel_info->GetOutputFormat(index) || + AnfAlgo::GetOutputDeviceDataType(node, index) != kernel_info->GetOutputDeviceType(index)) { + return false; + } + } + return true; +} + +bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr &kernel_info, size_t index) { + if (kernel_info == nullptr) { + return false; + } + + if (AnfAlgo::GetOutputDeviceDataType(node, 0) != kernel_info->GetOutputDeviceType(index)) { + return false; + } + if (AnfAlgo::GetOutputInferShape(node, 0).size() == 4 && AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && + kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { + return true; + } + return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); +} + 
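+// ChangeNodeInferInfo (below) rewrites cnode's inferred output signature so that
+// output slot cast_index takes the Cast's inferred dtype and shape, while every
+// other slot keeps its current inferred dtype/shape.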
+void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) { + using Shape = std::vector; + auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0); + auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); + std::vector shapes; + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + if (cast_index == index) { + shapes.emplace_back(cast_shape); + types.emplace_back(cast_dtype); + continue; + } + shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index)); + types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index)); + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get()); +} + +AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(kernel_query); + AnfNodePtr next_node = nullptr; + size_t cast_index = 0; + if (!GetNextNodeAndCastIndex(graph, node, &next_node, &cast_index)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(next_node); + if (!next_node->isa() || !AnfAlgo::IsRealKernel(next_node)) { + return nullptr; + } + auto next_cnode = next_node->cast(); + if (AnfAlgo::IsGraphKernel(next_node)) { + return nullptr; + } + auto next_op_name = AnfAlgo::GetCNodeName(next_node); + std::vector> kernel_info_list; + kernel_query->Query(next_cnode, &kernel_info_list); + + auto dst_type_id = AnfAlgo::GetInputDeviceDataType(node, 0); + auto alternative_kernel_info = std::find_if( + kernel_info_list.begin(), kernel_info_list.end(), + [&next_cnode, &dst_type_id, &cast_index](const std::shared_ptr &candidate_kernel_info) { + return AlternativeKernelInfoForInput(next_cnode, dst_type_id, cast_index, candidate_kernel_info); + }); + if (alternative_kernel_info == kernel_info_list.end()) { + return nullptr; + } + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*alternative_kernel_info)->ToString(); + AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); + ChangeNodeInferInfo(next_cnode, node, cast_index); + if (node->inputs().size() < kCastInputNum) { + MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; + } + return node->input(1); +} + +bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_output, size_t *output_idx) { + MS_EXCEPTION_IF_NULL(x_node); + if (x_node->isa()) { + auto x_cnode = x_node->cast(); + *prior_op = x_cnode; + // when x_node is tuple_getitem + if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) { + if (x_cnode->inputs().size() < kTupleGetitemInputNum) { + MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size(); + } + MS_EXCEPTION_IF_NULL(output_idx); + AnfNodePtr input1 = x_cnode->input(1); + MS_EXCEPTION_IF_NULL(input1); + if (!input1->isa()) { + return false; + } + *prior_op = input1->cast(); + MS_EXCEPTION_IF_NULL(*prior_op); + AnfNodePtr input2 = x_cnode->input(2); + MS_EXCEPTION_IF_NULL(input2); + auto value_ptr = input2->cast(); + MS_EXCEPTION_IF_NULL(value_ptr); + *output_idx = IntToSize(GetValue(value_ptr->value())); + *single_output = false; + } + return AnfAlgo::IsRealKernel(*prior_op); + } + return false; +} + +AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr 
kernel_query) { + MS_EXCEPTION_IF_NULL(cur_node); + MS_EXCEPTION_IF_NULL(kernel_query); + if (cur_node->inputs().size() < kCastInputNum) { + MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:"; + } + AnfNodePtr x_node = cur_node->input(1); + if (IsUsedByOthers(graph, x_node)) { + return nullptr; + } + + CNodePtr prior_op = nullptr; + bool single_output = true; + size_t output_idx = 0; + if (!GetPriorOp(x_node, &prior_op, &single_output, &output_idx)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(prior_op); + if (AnfAlgo::IsGraphKernel(prior_op)) { + return nullptr; + } + + std::vector> kernel_info_list; + kernel_query->Query(prior_op, &kernel_info_list); + auto kernel_info_it = std::find_if( + kernel_info_list.begin(), kernel_info_list.end(), + [&prior_op, &cur_node, &output_idx](const std::shared_ptr &item_kernel_info) { + return CheckInputs(prior_op, item_kernel_info) && CheckOtherOutputs(prior_op, item_kernel_info, output_idx) && + CheckIndexOutput(cur_node, item_kernel_info, output_idx); + }); + if (kernel_info_it == kernel_info_list.end()) { + return nullptr; + } + auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op); + MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString() + << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" + << (*kernel_info_it)->ToString(); + AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get()); + ChangeNodeInferInfo(prior_op, cur_node, output_idx); + if (!single_output) { + MS_EXCEPTION_IF_NULL(x_node); + ChangeNodeInferInfo(x_node->cast(), cur_node, 0); + } + auto prior_name = AnfAlgo::GetCNodeName(prior_op); + if (prior_name == kFive2FourOpName) { + AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op); + } else if (prior_name == kFour2FiveOpName) { + AnfAlgo::CopyNodeAttr("dst_type", cur_node, prior_op); + } + return single_output ? prior_op : x_node; +} +} // namespace + +const BaseRef MergeCastToOp::DefinePattern() const { + VarPtr X = std::make_shared(); + return VectorRef({prim::kPrimCast, X}); +} + +const AnfNodePtr MergeCastToOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + auto new_node = MergeCastToNextOp(graph, cnode, kernel_query_); + if (new_node == nullptr) { + new_node = MergeCastToPriorOp(graph, cnode, kernel_query_); + } + return new_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h new file mode 100644 index 0000000000..d0e467b7a3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class MergeCastToOp : public PatternProcessPass { + public: + explicit MergeCastToOp(bool multigraph = true) + : PatternProcessPass("merge_cast_to_op", multigraph), kernel_query_(std::make_shared()) {} + ~MergeCastToOp() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelQueryPtr kernel_query_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc new file mode 100644 index 0000000000..adca536f04 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
+#include <vector>
+#include <memory>
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+  auto input_format = AnfAlgo::GetInputFormat(cnode, 0);
+  if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) {
+    return nullptr;
+  }
+  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) {
+    return nullptr;
+  }
+
+  AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode);
+  return cnode;
+}
+
+AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) {
+  auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0);
+  if (input_shape.size() != 5) {
+    return nullptr;
+  }
+  if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) {
+    return nullptr;
+  }
+
+  auto multiples = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrMultiples);
+  if (multiples.size() == 4 && multiples[1] == 1) {
+    multiples.push_back(1);
+    AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode);
+  }
+
+  return cnode;
+}
+
+AnfNodePtr ModifyAttrs(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  if (op_name == prim::kPrimTile->name()) {
+    return ModifyTileOpAttrs(cnode);
+  } else if (op_name == prim::kPrimReduceSum->name()) {
+    // kPrimReduceMean
+    // kPrimReduceSum
+    // kPrimReduceAll
+    // kPrimReduceMax
+    // kPrimReduceMin
+    return ModifyReduceOpsAttrs(cnode);
+  }
+  return nullptr;
+}
+}  // namespace
+
+const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                        const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
+    return nullptr;
+  }
+  MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
+  auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  MS_EXCEPTION_IF_NULL(fg);
+  auto manager = fg->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::vector<AnfNodePtr> todos;
+  kernel::GetValidKernelNodes(fg, &todos);
+  for (auto &t : todos) {
+    auto new_node = ModifyAttrs(t->cast<CNodePtr>());
+    if (new_node != nullptr && new_node != t) {
+      (void)manager->Replace(t, new_node);
+    }
+  }
+  return node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h
new file mode 100644
index 0000000000..f5608db05a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/modify_ops_attrs.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ModifyOpAttrs : public PatternProcessPass {
+ public:
+  explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {}
+  ~ModifyOpAttrs() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc
new file mode 100644
index 0000000000..91b9326cc1
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc
@@ -0,0 +1,184 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
+
+#include <vector>
+#include <map>
+#include <string>
+#include <memory>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "utils/utils.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "utils/context/ms_context.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({X, Xs});
+}
+
+const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->execution_mode() == kPynativeMode) {
+    return RectifyKernelInfoInPynativeProcess(node);
+  }
+  if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
+    return nullptr;
+  }
+  std::vector<CNodePtr> do_mask_node_list;
+  auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode);
+  MS_EXCEPTION_IF_NULL(gen_mask_output_nodes);
+  for (const auto &output_node : *gen_mask_output_nodes) {
+    if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
+      MS_EXCEPTION_IF_NULL(output_node.first);
+      auto output_cnode = output_node.first->cast<CNodePtr>();
+      do_mask_node_list.push_back(output_cnode);
+    }
+  }
+  std::vector<size_t> input_shape;
+  for (const auto &output_node : do_mask_node_list) {
+    if (input_shape.empty()) {
+      input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
+      continue;
+    }
+    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
+    if (!kernel::IsSameShape(shape, input_shape)) {
+      MS_LOG(EXCEPTION)
<< "The DropOutGenMask connected with same genmask's shape must be equal!" + << " GenMask " << node->DebugString(); + } + } + RectifyKernelInfo(do_mask_node_list, graph); + return nullptr; +} + +void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list, + const FuncGraphPtr &graph) const { + std::map format_counter; + std::string special_format; + std::string convert_format; + for (const auto &do_mask : do_mask_node_list) { + auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); + if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) { + special_format = do_mask_data_format; + } + if (format_counter.find(do_mask_data_format) == format_counter.end()) { + format_counter[do_mask_data_format] = 1; + } else { + format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; + } + } + if (format_counter.size() == 1) { + return; + } + if (convert_format.empty()) { + convert_format = GetConvertFormat(format_counter); + } + RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph); +} + +std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { + std::string convert_format = kOpFormat_DEFAULT; + size_t counter = 0; + if (format_counter.size() > 2) { + return kOpFormat_DEFAULT; + } + if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { + return kOpFormat_DEFAULT; + } + for (const auto &iter : format_counter) { + if (counter < iter.second) { + convert_format = iter.first; + counter = iter.second; + } else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) { + convert_format = iter.first; + } + } + return convert_format; +} + +void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, + const std::string &format, + const FuncGraphPtr &graph) const { + for (const auto &do_mask : do_mask_node_list) { + if (AnfAlgo::GetInputFormat(do_mask, 0) != format) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); + builder->SetInputFormat(format, 0); + builder->SetOutputFormat(format, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); + ReSelecChildNodeKernelInfo(do_mask, graph); + } + } +} + +AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { + return nullptr; + } + auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); + if (do_mask_input_format != kOpFormat_DEFAULT) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + builder->SetInputFormat(kOpFormat_DEFAULT, 0); + builder->SetOutputFormat(kOpFormat_DEFAULT, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + return nullptr; +} + +void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const { + MS_EXCEPTION_IF_NULL(cnode); + auto output_node_list = GetRealNodeUsedList(graph, cnode); + MS_EXCEPTION_IF_NULL(output_node_list); + for (const auto &out_node_info : *output_node_list) { + MS_EXCEPTION_IF_NULL(out_node_info.first); + auto out_node = out_node_info.first->cast(); + if (AnfAlgo::IsRealKernel(out_node_info.first)) { + auto ori_build_info = 
      kernel_selector->SelectKernel(out_node);
+      auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
+      MS_EXCEPTION_IF_NULL(new_build_info);
+      MS_EXCEPTION_IF_NULL(ori_build_info);
+      if ((*new_build_info) != (*ori_build_info)) {
+        ReSelectChildNodeKernelInfo(out_node, graph);
+      }
+    } else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() ||
+               AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) {
+      ReSelectChildNodeKernelInfo(out_node, graph);
+    } else {
+      MS_LOG(INFO) << "Failed to reselect the kernel info of node " << cnode->DebugString();
+    }
+  }
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h
new file mode 100644
index 0000000000..cc9333a013
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
+#include <vector>
+#include <map>
+#include <string>
+#include <memory>
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+namespace mindspore {
+namespace opt {
+class RectifyDoMaskKernelInfo : public PatternProcessPass {
+ public:
+  explicit RectifyDoMaskKernelInfo(bool multigraph = true)
+      : PatternProcessPass("rectify_do_mask_kernel_info", multigraph),
+        kernel_selector(std::make_shared<KernelSelect>()) {}
+  ~RectifyDoMaskKernelInfo() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const FuncGraphPtr &graph) const;
+  AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const;
+  std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
+  void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format,
+                                      const FuncGraphPtr &graph) const;
+  void ReSelectChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const;
+  KernelSelectPtr kernel_selector;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc
new file mode 100644
index 0000000000..09992005a4
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.cc
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License,
Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
+#include <memory>
+#include <vector>
+#include "backend/optimizer/common/helper.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
+  if (op_name != prim::kPrimReshape->name()) {
+    return nullptr;
+  }
+
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
+  auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
+  if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
+    return nullptr;
+  }
+
+  return cnode->input(1);
+}
+}  // namespace
+
+const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
+    return nullptr;
+  }
+  MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node);
+  auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  MS_EXCEPTION_IF_NULL(fg);
+  auto manager = fg->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::vector<AnfNodePtr> todos;
+  kernel::GetValidKernelNodes(fg, &todos);
+  for (auto &t : todos) {
+    auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
+    if (new_node != nullptr && new_node != t) {
+      (void)manager->Replace(t, new_node);
+    }
+  }
+  return node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h
new file mode 100644
index 0000000000..135f11f52c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class RemoveNoUseReshapeOp : public PatternProcessPass {
+ public:
+  explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
+  ~RemoveNoUseReshapeOp() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc
new file mode 100644
index 0000000000..a3fd704bc5
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/addn_fission.h"
+#include <memory>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
+                         size_t offset) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(origin_addn_cnode);
+  std::vector<AnfNodePtr> new_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
+  for (size_t i = begin_index; i < begin_index + offset; ++i) {
+    new_addn_inputs.push_back(origin_addn_cnode->input(i));
+  }
+  CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);
+  MS_EXCEPTION_IF_NULL(new_addn);
+  new_addn->set_scope(origin_addn_cnode->scope());
+  new_addn->set_abstract(origin_addn_cnode->abstract());
+  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn);
+  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
+  return new_addn;
+}
+}  // namespace
+
+const BaseRef AddnFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimAddN, Xs});
+}
+
+const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // The real input begins with index 1.
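+  // Each round of the loop below packs up to inputs_divisor_ inputs into new sub-AddN nodes and
+  // sums their partial results with a base AddN, repeating until the base AddN itself fits within
+  // the inputs_divisor_ limit (kAddnInputsDivisor = 63, declared in addn_fission.h).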
+  size_t origin_input_size = cnode->inputs().size() - 1;
+  if (origin_input_size <= inputs_divisor_) {
+    return nullptr;
+  }
+  CNodePtr new_cnode = cnode;
+  while (origin_input_size > inputs_divisor_) {
+    MS_EXCEPTION_IF_NULL(new_cnode);
+    std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
+    size_t cur_input_index = 1;
+    // Divide the inputs of addn by inputs_divisor_.
+    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
+      base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
+      cur_input_index += inputs_divisor_;
+    }
+    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
+      base_addn_inputs.push_back(new_cnode->input(i));
+    }
+    CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
+    MS_EXCEPTION_IF_NULL(base_addn);
+    base_addn->set_scope(new_cnode->scope());
+    base_addn->set_abstract(new_cnode->abstract());
+    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn);
+    std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)};
+    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
+    new_cnode = base_addn;
+    origin_input_size = base_addn->inputs().size() - 1;
+  }
+
+  return new_cnode;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h
new file mode 100644
index 0000000000..e04cdfdf7b
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/addn_fission.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr size_t kAddnInputsDivisor = 63; +class AddnFission : public PatternProcessPass { + public: + explicit AddnFission(bool multigraph = true) + : PatternProcessPass("addn_fission", multigraph), inputs_divisor_(kAddnInputsDivisor) {} + ~AddnFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + size_t inputs_divisor_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc new file mode 100644 index 0000000000..f0edefd5f5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h"
+#include <vector>
+#include <memory>
+#include <algorithm>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+const std::vector<int> kOutputIndex{0, 3, 4, 5};
+constexpr size_t kBatchNormRealOutputNum = 3;
+constexpr size_t kBatchNormRealInputNum = 3;
+
+bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn_outputs);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(bn) == manager->node_users().end()) {
+    return false;
+  }
+  size_t output_num = 0;
+  for (const auto &node_index : manager->node_users()[bn]) {
+    AnfNodePtr output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      continue;
+    }
+    auto tuple_getitem_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
+    auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
+    MS_EXCEPTION_IF_NULL(index_node);
+    auto value_node = index_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    int index = GetValue<int>(value_node->value());
+    if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
+      return false;
+    }
+    bn_outputs->push_back(output);
+    output_num++;
+  }
+  return output_num == kBatchNormRealOutputNum;
+}
+
+AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn);
+  auto bn_cnode = bn->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
+                      << kBatchNormRealInputNum + 1;
+  }
+  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
+  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_reduce);
+  auto bn_input1 = bn_cnode->input(2);
+  MS_EXCEPTION_IF_NULL(bn_input1);
+  auto bn_input2 = bn_cnode->input(3);
+  MS_EXCEPTION_IF_NULL(bn_input2);
+  AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_reduce->set_abstract(abstract_tuple);
+  bn_training_reduce->set_scope(bn->scope());
+  AnfAlgo::CopyNodeAttrs(bn, bn_training_reduce);
+  return bn_training_reduce;
+}
+
+AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn);
+  auto bn_cnode = bn->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
+                      << kBatchNormRealInputNum + 1;
+  }
+  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
+                      << ", but it is " << bn_training_reduce_outputs.size();
+  }
+  std::vector<AnfNodePtr> bn_training_update_v2_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV2OpName)),
+    bn_cnode->input(1),
+    bn_training_reduce_outputs[0],
    bn_training_reduce_outputs[1],
+    bn_cnode->input(2),
+    bn_cnode->input(3)};
+  auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update_v2);
+
+  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
+  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
+  if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
+    MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
+                      << bn_abstract_tuple->elements().size();
+  }
+  std::vector<AbstractBasePtr> abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3],
+                                             bn_abstract_tuple->elements()[4]};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_update_v2->set_abstract(abstract_tuple);
+  bn_training_update_v2->set_scope(bn->scope());
+  AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2);
+  return bn_training_update_v2;
+}
+}  // namespace
+
+const BaseRef BatchNormBertFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimBatchNorm, Xs});
+}
+
+const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<AnfNodePtr> bn_outputs;
+  if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
+    MS_LOG(INFO) << "The BatchNorm node should only have outputs 0, 3 and 4. The node should not be changed";
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().size() != kBatchNormRealInputNum + 1) {
+    MS_LOG(INFO) << "The input num of BatchNorm should be " << kBatchNormRealInputNum
+                 << ". The node should not be changed";
+    return nullptr;
+  }
+  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
+  std::vector<AnfNodePtr> bn_training_reduce_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
+                                 &bn_training_reduce_outputs);
+
+  AnfNodePtr bn_training_update_v2 = CreateBNTrainingUpdateV2(func_graph, node, bn_training_reduce_outputs);
+  std::vector<AnfNodePtr> bn_training_update_v2_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v2, kBNTrainingUpdateV2OutputNum,
+                                 &bn_training_update_v2_outputs);
+  if (bn_training_update_v2_outputs.size() != kBNTrainingUpdateV2OutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn_training_update_v2 must be " << kBNTrainingUpdateV2OutputNum
+                      << ", but it is " << bn_training_update_v2_outputs.size();
+  }
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
+  size_t output_index = 0;
+  for (const auto &output : bn_outputs) {
+    (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
+    output_index++;
+  }
+  // Return the new node for control depends.
+  return bn_training_update_v2;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h
new file mode 100644
index 0000000000..23f0e56035
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormBertFission : public PatternProcessPass { + public: + explicit BatchNormBertFission(bool multigraph = true) : PatternProcessPass("batch_norm_bert_fission", multigraph) {} + ~BatchNormBertFission() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc new file mode 100644 index 0000000000..97c67e4441 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.cc @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h"
+#include <memory>
+#include "backend/optimizer/common/helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBatchNormGradInferOutputNum = 3;
+bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(node) == manager->node_users().end()) {
+    MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs";
+    return false;
+  }
+  for (const auto &node_index : manager->node_users()[node]) {
+    AnfNodePtr output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      continue;
+    }
+    auto tuple_getitem_cnode = output->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
+    auto index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
+    MS_EXCEPTION_IF_NULL(index_node);
+    auto value_node = index_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    int index = GetValue<int>(value_node->value());
+    if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) {
+      MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change";
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace
+
+AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
+                                                        const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn_grad);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set inputs
+  auto iter_input0 = (*equiv).find(input0_var_);
+  if (iter_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input0 var after matching.";
+  }
+  auto iter_input2 = (*equiv).find(input2_var_);
+  if (iter_input2 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input2 var after matching.";
+  }
+  auto iter_input4 = (*equiv).find(input4_var_);
+  if (iter_input4 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input4 var after matching.";
+  }
+  std::vector<AnfNodePtr> bn_infer_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNInferGradOpName)), utils::cast<AnfNodePtr>(iter_input0->second),
+    utils::cast<AnfNodePtr>(iter_input2->second), utils::cast<AnfNodePtr>(iter_input4->second)};
+  auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_infer_grad);
+  // Set abstract: the output of the new node takes the place of the 0th output of bn_grad.
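+  // BNInferGrad computes only the input gradient (dx), so it borrows the first element of
+  // BatchNormGrad's abstract tuple.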
+  auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
+  MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
+  if (bn_grad_abstract_tuple->elements().empty()) {
+    MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << " should not be empty";
+  }
+  bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad);
+  bn_infer_grad->set_scope(bn_grad->scope());
+  return bn_infer_grad;
+}
+
+AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph,
+                                                                 const AnfNodePtr &bn_grad,
+                                                                 const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn_grad);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set inputs
+  auto iter_input0 = (*equiv).find(input0_var_);
+  if (iter_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input0 var after matching.";
+  }
+  auto iter_input1 = (*equiv).find(input1_var_);
+  if (iter_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input1 var after matching.";
+  }
+  auto iter_input3 = (*equiv).find(input3_var_);
+  if (iter_input3 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input3 var after matching.";
+  }
+  auto iter_input4 = (*equiv).find(input4_var_);
+  if (iter_input4 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the input4 var after matching.";
+  }
+  std::vector<AnfNodePtr> bn_training_update_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)),
+    utils::cast<AnfNodePtr>(iter_input0->second), utils::cast<AnfNodePtr>(iter_input1->second),
+    utils::cast<AnfNodePtr>(iter_input3->second), utils::cast<AnfNodePtr>(iter_input4->second)};
+  auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update_grad);
+  // Set abstract: the outputs of the new node take the place of the 1st and 2nd outputs of bn_grad.
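+  // BNTrainingUpdateGrad yields the scale and offset gradients, which correspond to elements 1 and 2
+  // of BatchNormGrad's abstract tuple.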
+  auto bn_grad_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn_grad->abstract());
+  MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple);
+  if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) {
+    MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << " should not be less than "
+                      << kBatchNormGradInferOutputNum;
+  }
+  std::vector<AbstractBasePtr> abstract_list{bn_grad_abstract_tuple->elements()[1],
+                                             bn_grad_abstract_tuple->elements()[2]};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_update_grad->set_abstract(abstract_tuple);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad);
+  bn_training_update_grad->set_scope(bn_grad->scope());
+  return bn_training_update_grad;
+}
+
+const BaseRef BatchNormGradInferFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs});
+}
+
+const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast<CNodePtr>())) {
+    MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed";
+    return nullptr;
+  }
+  if (AnfAlgo::GetNodeAttr<bool>(node, kAttrIsTraining)) {
+    MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change";
+    return nullptr;
+  }
+  if (!CheckOutputsIndex(func_graph, node)) {
+    MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change";
+    return nullptr;
+  }
+  AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv);
+  AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv);
+  std::vector<AnfNodePtr> bn_training_update_grad_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum,
+                                 &bn_training_update_grad_outputs);
+  if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad->DebugString() << " should be "
+                      << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size();
+  }
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad,
+                                               bn_training_update_grad_outputs[0],
+                                               bn_training_update_grad_outputs[1]};
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h
new file mode 100644
index 0000000000..97100de284
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class BatchNormGradInferFission : public PatternProcessPass {
+ public:
+  explicit BatchNormGradInferFission(bool multigraph = true)
+      : PatternProcessPass("batch_norm_grad_infer_fission", multigraph),
+        input0_var_(std::make_shared<Var>()),
+        input1_var_(std::make_shared<Var>()),
+        input2_var_(std::make_shared<Var>()),
+        input3_var_(std::make_shared<Var>()),
+        input4_var_(std::make_shared<Var>()) {}
+  ~BatchNormGradInferFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const;
+  AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
+                                        const EquivPtr &equiv) const;
+
+  VarPtr input0_var_;
+  VarPtr input1_var_;
+  VarPtr input2_var_;
+  VarPtr input3_var_;
+  VarPtr input4_var_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc
new file mode 100644
index 0000000000..97122386c6
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc
@@ -0,0 +1,131 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
+
+#include <vector>
+#include <memory>
+#include <string>
+
+#include "utils/utils.h"
+#include "utils/context/ms_context.h"
+#include "common/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                               std::vector<AnfNodePtr> *bn_update_grad_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_grad_node);
+  auto bn_grad_inputs = bn_grad_node->inputs();
+  if (bn_grad_inputs.size() < kBNGradInputNum) {
+    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
+  }
+  std::vector<AnfNodePtr> bn_update_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
+    bn_grad_inputs[4], bn_grad_inputs[5]};
+  auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_update_grad);
+  bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
+  bn_update_grad->set_scope(bn_grad_node->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get());
+
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
+  CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
+}
+
+void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                               const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                               std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_grad_node);
+  auto bn_grad_inputs = bn_grad_node->inputs();
+  if (bn_grad_inputs.size() < kBNGradInputNum) {
+    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
+  }
+  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
+    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
+  }
+  std::vector<AnfNodePtr> bn_reduce_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceGradOpName)),
+    bn_grad_inputs[1],
+    bn_grad_inputs[2],
+    bn_update_grad_outputs[0],
+    bn_update_grad_outputs[1],
+    bn_grad_inputs[3],
+    bn_grad_inputs[4],
+    bn_grad_inputs[5]};
+  auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_reduce_grad);
+  bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
+  bn_reduce_grad->set_scope(bn_grad_node->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get());
+
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
+  (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
+}
+}  // namespace
+const BaseRef BatchNormGradSplit::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto prim = std::make_shared<Primitive>(kBatchNormGradOpName);
+  return VectorRef({prim, Xs});
+}
+
+const AnfNodePtr BatchNormGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                             const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
  if (!primitive->HasAttr(kAttrIsTraining)) {
+    MS_LOG(INFO) << "Op BatchNormGrad must have the is_training attr";
+    return nullptr;
+  }
+  if (!AnfAlgo::GetNodeAttr<bool>(cnode, kAttrIsTraining)) {
+    MS_LOG(INFO) << "is_training must be true";
+    return nullptr;
+  }
+
+  std::vector<AnfNodePtr> bn_update_grad_outputs;
+  CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
+  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
+    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
+  }
+
+  std::vector<AnfNodePtr> bn_reduce_grad_outputs;
+  CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs);
+  if (bn_reduce_grad_outputs.size() != kSingleOutputNum) {
+    MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size";
+  }
+
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0],
+                                               bn_update_grad_outputs[0], bn_update_grad_outputs[1]};
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h
new file mode 100644
index 0000000000..e5378d8332
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class BatchNormGradSplit : public PatternProcessPass {
+ public:
+  explicit BatchNormGradSplit(bool multigraph = true) : PatternProcessPass("batch_norm_grad_split", multigraph) {}
+  ~BatchNormGradSplit() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc
new file mode 100644
index 0000000000..6c4e226120
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc
@@ -0,0 +1,123 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h"
+
+#include <vector>
+#include <memory>
+#include <string>
+
+#include "utils/utils.h"
+#include "utils/context/ms_context.h"
+#include "common/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                               std::vector<AnfNodePtr> *bn_update_grad_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_grad_node);
+  auto bn_grad_inputs = bn_grad_node->inputs();
+  if (bn_grad_inputs.size() != kBNGradInputNum) {
+    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
+  }
+  std::vector<AnfNodePtr> bn_update_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
+    bn_grad_inputs[4], bn_grad_inputs[5]};
+  auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_update_grad);
+  bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
+  bn_update_grad->set_scope(bn_grad_node->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get());
+
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
+  CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
+}
+
+void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
+                               const std::vector<AnfNodePtr> &bn_update_grad_outputs,
+                               std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_grad_node);
+  auto bn_grad_inputs = bn_grad_node->inputs();
+  if (bn_grad_inputs.size() != kBNGradInputNum) {
+    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
+  }
+  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
+    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
+  }
+  std::vector<AnfNodePtr> bn_reduce_grad_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceGradOpName)),
+    bn_grad_inputs[1],
+    bn_grad_inputs[2],
+    bn_update_grad_outputs[0],
+    bn_update_grad_outputs[1],
+    bn_grad_inputs[3],
+    bn_grad_inputs[4],
+    bn_grad_inputs[5]};
+  auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
+  MS_EXCEPTION_IF_NULL(bn_reduce_grad);
+  bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
+  bn_reduce_grad->set_scope(bn_grad_node->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get());
+
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
+  (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
+}
+
+CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> bn_update_grad_outputs;
+  CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
+  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
+    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
+  }
+
+  std::vector<AnfNodePtr> bn_reduce_grad_outputs;
+  CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs);
+  if (bn_reduce_grad_outputs.size() != 1) {
+    MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size";
+  }
+
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0],
+                                               bn_update_grad_outputs[0], bn_update_grad_outputs[1]};
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  return make_tuple;
+}
+}  // namespace
+
+const BaseRef BnGradSplit::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimFusedBatchNormGrad, Xs});
+}
+
+const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  return BNGradSplitForTBE(func_graph, cnode);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h
new file mode 100644
index 0000000000..6fe78d4724
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class BnGradSplit : public PatternProcessPass {
+ public:
+  explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {}
+  ~BnGradSplit() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc
new file mode 100644
index 0000000000..33670e5703
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc
@@ -0,0 +1,132 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/bn_split.h"
+
+#include <vector>
+#include <memory>
+#include <string>
+
+#include "utils/utils.h"
+#include "utils/context/ms_context.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() != kBnInputNum) {
+    MS_LOG(INFO) << "FusedBatchNorm's input size is less than " << kBnInputNum << ". " << bn_cnode->DebugString();
+    return false;
+  }
+  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
+  bn_training_reduce_inputs.push_back(bn_cnode->input(1));
+  auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_reduce);
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  bn_training_reduce->set_kernel_info(kernel_info);
+  std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
+  if (bn_shape_i0.size() < kShape2dDims) {
+    MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
+    return false;
+  }
+  std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
+  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
+  auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
+  bn_training_reduce->set_scope(bn_cnode->scope());
+  AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);
+
+  CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
+  return true;
+}
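+
+// Builds the second kernel of the fission: BNTrainingUpdate consumes the two outputs
+// of BNTrainingReduce together with all five real inputs of the original BN node and
+// reuses the fused node's abstract, so it produces the same outputs; kAttrMomentum is
+// forwarded as kAttrFactor and the node is marked as a ref op via kAttrIsRef.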
+AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
+                                           const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() != kBnInputNum) {
+    MS_LOG(EXCEPTION) << "BN node has wrong input size";
+  }
+  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
+    MS_LOG(EXCEPTION) << "BN1 outputs has wrong size";
+  }
+  // the inputs of BNTrainingUpdate are from the outputs of BNTrainingReduce and the inputs of BN
+  std::vector<AnfNodePtr> bn_training_update_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName))};
+  bn_training_update_inputs.push_back(bn_cnode->input(1));
+  bn_training_update_inputs.push_back(bn_training_reduce_outputs[0]);
+  bn_training_update_inputs.push_back(bn_training_reduce_outputs[1]);
+  bn_training_update_inputs.push_back(bn_cnode->input(2));
+  bn_training_update_inputs.push_back(bn_cnode->input(3));
+  bn_training_update_inputs.push_back(bn_cnode->input(4));
+  bn_training_update_inputs.push_back(bn_cnode->input(5));
+  auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update);
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  bn_training_update->set_kernel_info(kernel_info);
+  bn_training_update->set_abstract(bn_cnode->abstract());
+  bn_training_update->set_scope(bn_cnode->scope());
+  auto factor = AnfAlgo::GetNodeAttr<float>(bn_cnode, kAttrMomentum);
+  AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue(factor), bn_training_update);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update);
+  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
+  return bn_training_update;
+}
+
+AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().size() < kBnInputNum) {
+    MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
+    return nullptr;
+  }
+  // Create BNTrainingReduce node and get outputs of BNTrainingReduce
+  std::vector<AnfNodePtr> bn_training_reduce_outputs;
+  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
+    MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
+    return nullptr;
+  }
+  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
+    MS_LOG(EXCEPTION) << "Make outputs of op BNTrainingReduce fail";
+  }
+
+  // Create BNTrainingUpdate node
+  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
+}
+}  // namespace
+
+const BaseRef BnSplit::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  return VectorRef({prim::kPrimFusedBatchNorm, Xs});
+}
+
+const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  return SplitFusedBatchNormForTBE(func_graph, node);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h
new file mode 100644
index 0000000000..4340ba0af6
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
+
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class BnSplit : public PatternProcessPass {
+ public:
+  explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {}
+  ~BnSplit() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc
new file mode 100644
index 0000000000..e8a778b36f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc
@@ -0,0 +1,91 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h"
+#include <vector>
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/helper.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
+                                 std::vector<AnfNodePtr> *square_sum_all_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(lars_v2);
+  if (lars_v2->size() != kLarsV2InputNum) {
+    MS_LOG(EXCEPTION) << "Op lars_v2's input size is not equal to " << kLarsV2InputNum;
+  }
+
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName))};
+  inputs.push_back(lars_v2->input(1));
+  inputs.push_back(lars_v2->input(2));
+  auto square_sum_all = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(square_sum_all);
+  square_sum_all->set_scope(lars_v2->scope());
+
+  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
+  std::vector<size_t> shape;
+  auto shapes = {shape, shape};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get());
+
+  CreateMultipleOutputsOfAnfNode(graph, square_sum_all, 2, square_sum_all_outputs);
+}
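+
+// Rebuilds the fused op from the two scalar norms computed above: LarsV2Update takes
+// the first two inputs of lars_v2, the two SquareSumAll outputs, and the remaining
+// inputs 3 and 4 of lars_v2, and reuses the original node's abstract so it can
+// replace it directly.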
+CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
+                            const std::vector<AnfNodePtr> &square_sum_all_outputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(lars_v2);
+  if (square_sum_all_outputs.size() != 2) {
+    MS_LOG(EXCEPTION) << "square_sum_all_outputs' size is not equal to 2";
+  }
+  if (lars_v2->size() != kLarsV2InputNum) {
+    MS_LOG(EXCEPTION) << "Op lars_v2's input size is not equal to " << kLarsV2InputNum;
+  }
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName))};
+  inputs.push_back(lars_v2->input(1));
+  inputs.push_back(lars_v2->input(2));
+  inputs.push_back(square_sum_all_outputs[0]);
+  inputs.push_back(square_sum_all_outputs[1]);
+  inputs.push_back(lars_v2->input(3));
+  inputs.push_back(lars_v2->input(4));
+  auto lars_v2_update = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(lars_v2_update);
+  lars_v2_update->set_scope(lars_v2->scope());
+  lars_v2_update->set_abstract(lars_v2->abstract());
+  return lars_v2_update;
+}
+}  // namespace
+
+const BaseRef LarsV2Fission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto lars_v2_prim = std::make_shared<Primitive>(kLarsV2OpName);
+  return VectorRef({lars_v2_prim, Xs});
+}
+
+const AnfNodePtr LarsV2Fission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto lars_v2 = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(lars_v2);
+
+  std::vector<AnfNodePtr> square_sum_all_outputs;
+  CreateOutputsOfSquareSumAll(graph, lars_v2, &square_sum_all_outputs);
+  return CreateLarsV2Update(graph, lars_v2, square_sum_all_outputs);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h
new file mode 100644
index 0000000000..3a165f2b29
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class LarsV2Fission : public PatternProcessPass {
+ public:
+  explicit LarsV2Fission(bool multigraph = true) : PatternProcessPass("lars_v2_fission", multigraph) {}
+  ~LarsV2Fission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc
new file mode 100644
index 0000000000..1d19def787
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc
@@ -0,0 +1,117 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
+#include "ir/primitive.h"
+#include "common/utils.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop(
+  const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+  std::vector<AnfNodePtr> *layer_norm_x_backprop_outputs) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(layer_norm_grad);
+  auto prim = std::make_shared<Primitive>(kLayerNormXBackpropOpName);
+  std::vector<AnfNodePtr> layer_norm_x_backprop_inputs = {NewValueNode(prim)};
+  for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) {
+    layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i));
+  }
+  auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs);
+  MS_EXCEPTION_IF_NULL(layer_norm_x_backprop);
+  layer_norm_x_backprop->set_scope(layer_norm_grad->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get());
+
+  (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop);
+}
+
+void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop(
+  const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+  std::vector<AnfNodePtr> *layer_norm_beta_gamma_backprop_outputs) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(layer_norm_grad);
+  auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
+  std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)};
+  for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) {
+    layer_norm_beta_gamma_backprop_inputs.push_back(layer_norm_grad->input(i));
+  }
+  auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs);
+  MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop);
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info);
+  layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope());
+
+  auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1),
+                AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 1), AnfAlgo::GetOutputInferShape(layer_norm_grad, 2)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get());
+
+  // get device shape of LayerNormGrad's 5th Input, and convert it to attr
+  std::vector<size_t> shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4);
+  AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop);
+
+  CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum,
+                                 layer_norm_beta_gamma_backprop_outputs);
+}
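+
+// The pass below matches LayerNormGrad and splits it: LayerNormXBackprop keeps all of
+// the original real inputs and produces the output typed after LayerNormGrad output 0,
+// while LayerNormBetaGammaBackprop drops the last input (its inferred shape travels as
+// the kAttrShapeGamma attr instead) and produces outputs typed after outputs 1 and 2.
+// A MakeTuple restores the original three-output interface.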
+const BaseRef LayerNormGradSplit::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VectorRef pattern({prim::kPrimLayerNormGrad, Xs});
+  return pattern;
+}
+
+const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                             const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode->inputs().size() != kLayerNormGradInputNum) {
+    return nullptr;
+  }
+
+  // create layer_norm_x_backprop
+  std::vector<AnfNodePtr> layer_norm_x_backprop_outputs;
+  CreateOutputsOfLayerNormXBackprop(graph, cnode, &layer_norm_x_backprop_outputs);
+  if (layer_norm_x_backprop_outputs.size() != kSingleOutputNum) {
+    MS_LOG(EXCEPTION) << "layer_norm_x_backprop_outputs has wrong size";
+  }
+
+  // create layer_norm_beta_gamma_backprop
+  std::vector<AnfNodePtr> layer_norm_beta_gamma_backprop_outputs;
+  CreateOutputsOfLayerNormBetaGammaBackprop(graph, cnode, &layer_norm_beta_gamma_backprop_outputs);
+  if (layer_norm_beta_gamma_backprop_outputs.size() != kLayerNormBetaGammaBackpropOutputNum) {
+    MS_LOG(EXCEPTION) << "layer_norm_beta_gamma_backprop_outputs has wrong size";
+  }
+
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), layer_norm_x_backprop_outputs[0],
+                                               layer_norm_beta_gamma_backprop_outputs[0],
+                                               layer_norm_beta_gamma_backprop_outputs[1]};
+  auto make_tuple = graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  return make_tuple;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h
new file mode 100644
index 0000000000..c1501b1593
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_
+
+#include <vector>
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class LayerNormGradSplit : public PatternProcessPass {
+ public:
+  explicit LayerNormGradSplit(bool multigraph = true) : PatternProcessPass("layer_norm_grad_split", multigraph) {}
+  ~LayerNormGradSplit() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  void CreateOutputsOfLayerNormXBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+                                         std::vector<AnfNodePtr> *layer_norm_grad_outputs) const;
+  void CreateOutputsOfLayerNormBetaGammaBackprop(const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad,
+                                                 std::vector<AnfNodePtr> *layer_norm_beta_gamma_outputs) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LAYER_NORM_GRAD_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc
new file mode 100644
index 0000000000..133d51734f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc
@@ -0,0 +1,117 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h"
+#include <vector>
+#include <memory>
+#include <string>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kBatchNormRealInputNum = 3;
+
+AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn);
+  auto bn_cnode = bn->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
+                      << kBatchNormRealInputNum + 1;
+  }
+  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
+  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_reduce);
+
+  // set abstract
+  auto bn_input1 = bn_cnode->input(2);
+  MS_EXCEPTION_IF_NULL(bn_input1);
+  AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_reduce->set_abstract(abstract_tuple);
+  bn_training_reduce->set_scope(bn->scope());
+  return bn_training_reduce;
+}
+
+AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
+                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(bn);
+  auto bn_cnode = bn->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(bn_cnode);
+  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
+    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
+                      << kBatchNormRealInputNum + 1;
+  }
+  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
+                      << ", but it is " << bn_training_reduce_outputs.size();
+  }
+  std::vector<AnfNodePtr> bn_training_update_v3_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV3OpName)),
+    bn_cnode->input(1),
+    bn_training_reduce_outputs[0],
+    bn_training_reduce_outputs[1],
+    bn_cnode->input(2),
+    bn_cnode->input(3)};
+  auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update_v3);
+
+  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
+  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
+  if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
+    MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
+                      << bn_abstract_tuple->elements().size();
+  }
+  bn_training_update_v3->set_abstract(bn->abstract());
+  bn_training_update_v3->set_scope(bn->scope());
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3);
+  return bn_training_update_v3;
+}
+}  // namespace
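+
+// SingleBatchNormFission rewrites a stand-alone training-mode BatchNorm into
+// BNTrainingReduce followed by BNTrainingUpdateV3; nodes with is_training == false
+// are left untouched.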
+const BaseRef SingleBatchNormFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimBatchNorm, Xs});
+}
+
+const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() < kBatchNormRealInputNum + 1) {
+    MS_LOG(INFO) << "The input num of BatchNorm is less than " << kBatchNormRealInputNum
+                 << ". The node should not be changed";
+    return nullptr;
+  }
+  if (!GetBoolAttr(cnode, kAttrIsTraining)) {
+    MS_LOG(INFO) << "The attr is_training should be true if the fission is performed";
+    return nullptr;
+  }
+  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
+  std::vector<AnfNodePtr> bn_training_reduce_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
+                                 &bn_training_reduce_outputs);
+
+  return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h
new file mode 100644
index 0000000000..fb641c12d6
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SingleBatchNormFission : public PatternProcessPass {
+ public:
+  explicit SingleBatchNormFission(bool multigraph = true)
+      : PatternProcessPass("single_batch_norm_fission", multigraph) {}
+  ~SingleBatchNormFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc
new file mode 100644
index 0000000000..063f81a1ca
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc
@@ -0,0 +1,197 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/split_fission.h"
+#include <memory>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input_node);
+  std::vector<AnfNodePtr> splitv_inputs{NewValueNode(std::make_shared<Primitive>(kSplitVOpName)), input_node};
+  CNodePtr splitv = func_graph->NewCNode(splitv_inputs);
+  MS_EXCEPTION_IF_NULL(splitv);
+  splitv->set_scope(input_node->scope());
+  return splitv;
+}
+
+CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) {
+  MS_EXCEPTION_IF_NULL(origin_cnode);
+  if (origin_cnode->inputs().size() < kSplitInputNum) {
+    MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be "
+                      << kSplitInputNum - 1;
+  }
+  return CreateSplitVNode(func_graph, origin_cnode->input(1));
+}
+
+void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector<int> &size_splits, int split_dim,
+                          int num_split) {
+  AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv);
+  AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv);
+  AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv);
+}
+
+size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) {
+  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0);
+  if (split_dim < 0) {
+    split_dim += SizeToInt(input_shape.size());
+  }
+  if (IntToSize(split_dim) >= input_shape.size()) {
+    MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0";
+  }
+  return input_shape[split_dim] / num_split;
+}
+
+void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num,
+                   std::vector<AnfNodePtr> *inputs) {
+  MS_EXCEPTION_IF_NULL(inputs);
+  std::vector<AnfNodePtr> new_splitv_output;
+  CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output);
+  inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end());
+}
+
+AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto idx = NewValueNode(SizeToInt(index));
+  MS_EXCEPTION_IF_NULL(idx);
+  auto imm = std::make_shared<Int32Imm>(SizeToInt(index));
+  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
+  idx->set_abstract(abstract_scalar);
+  auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx});
+  return tuple_getitem;
+}
+
+void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split,
+                                std::vector<TypeId> *new_type_ids,
+                                std::vector<std::vector<size_t>> *new_output_shapes) {
+  MS_EXCEPTION_IF_NULL(new_type_ids);
+  MS_EXCEPTION_IF_NULL(new_output_shapes);
+  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
+  if (split_dim < 0) {
+    split_dim += SizeToInt(output_shape.size());
+  }
+  output_shape[split_dim] = split_size;
+  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
+  for (int i = 0; i < num_split; ++i) {
+    new_type_ids->emplace_back(type_id);
+    new_output_shapes->emplace_back(output_shape);
+  }
+}
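+
+// The Split node is lowered in two levels: one base SplitV first cuts the input into
+// at most divisor-sized chunks along split_dim, then one SplitV per chunk cuts it
+// into the final small outputs. SetAttrAndAbstractForBaseSplitv fixes up the base
+// node once the per-chunk sizes are known, and DoFission wires every small output
+// into a single MakeTuple that replaces the original Split.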
+void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv,
+                                     const std::vector<int> &size_splits_base, int split_dim, int num_split) {
+  SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split);
+  std::vector<TypeId> base_type_ids;
+  std::vector<std::vector<size_t>> base_output_shapes_base;
+  auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0);
+  TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0);
+  if (split_dim < 0) {
+    split_dim += SizeToInt(output_shape.size());
+  }
+  for (int i = 0; i < num_split; ++i) {
+    output_shape[split_dim] = size_splits_base[i];
+    base_output_shapes_base.emplace_back(output_shape);
+    base_type_ids.emplace_back(type_id);
+  }
+  AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get());
+}
+
+AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto split_dim = AnfAlgo::GetNodeAttr<int>(cnode, kAttrAxis);
+  CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode);
+
+  // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs.
+  auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split));
+  std::vector<int> size_splits_new;
+  for (int i = 0; i < divisor; ++i) {
+    size_splits_new.emplace_back(small_split_size);
+  }
+  // Create new output shape and new output type id for each new Splitv node which has full inputs.
+  std::vector<TypeId> new_type_ids;
+  std::vector<std::vector<size_t>> new_output_shapes;
+  CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes);
+
+  // Create make_tuple input to create a make_tuple for replacing the old Split node.
+  std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
+  // Start to divide the outputs of Split.
+  std::vector<int> size_splits_base;
+  const auto base_split_size = divisor * small_split_size;
+  int nodes_num = 0;
+  int cur_output_index = 0;
+  while (num_split - cur_output_index > divisor) {
+    CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
+    SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor);
+    AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get());
+    AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs);
+    cur_output_index += divisor;
+    size_splits_base.emplace_back(base_split_size);
+    nodes_num++;
+  }
+  if (cur_output_index < num_split) {
+    auto last_node_num_split = num_split - cur_output_index;
+    if (last_node_num_split > 1) {
+      CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num));
+      std::vector<int> size_splits_new_last;
+      for (int i = 0; i < last_node_num_split; ++i) {
+        size_splits_new_last.emplace_back(small_split_size);
+      }
+      SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split);
+      // Create new output shape and new output type id for the last Splitv node
+      std::vector<TypeId> last_new_type_ids;
+      std::vector<std::vector<size_t>> last_new_output_shapes;
+      CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids,
+                                 &last_new_output_shapes);
+      AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get());
+      AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs);
+      size_splits_base.emplace_back(last_node_num_split * small_split_size);
+    } else {
+      make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num));
+      size_splits_base.emplace_back(small_split_size);
+    }
+    nodes_num++;
+  }
+  // Set attr and abstract for the base splitv
+  SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num);
+  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+}  // namespace
+
+const BaseRef SplitFission::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto split_prim = std::make_shared<Primitive>(kSplitOpName);
+  return VectorRef({split_prim, Xs});
+}
+
+const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                       const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // Check output num
+  if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) {
+    return nullptr;
+  }
+  auto num_split = AnfAlgo::GetNodeAttr<int>(cnode, kAttrOutputNum);
+  if (num_split <= outputs_divisor_) {
+    return nullptr;
+  }
+  return DoFission(func_graph, cnode, num_split, outputs_divisor_);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h
new file mode 100644
index 0000000000..6428a21e73
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+constexpr int kSplitOutputsDivisor = 63;
+class SplitFission : public PatternProcessPass {
+ public:
+  explicit SplitFission(bool multigraph = true)
+      : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {}
+  ~SplitFission() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  int outputs_divisor_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc
new file mode 100644
index 0000000000..6eeb7a61f7
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc
@@ -0,0 +1,182 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/topk_split.h"
+#include <vector>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include "backend/optimizer/common/helper.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "utils/utils.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+constexpr size_t kFloat16Len = 2;  // size of float16
+constexpr size_t kTopkIndexK = 1;
+namespace {
+tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
+  // 1 create tensor
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  auto last_dim = shape[shape.size() - 1];
+  std::vector<int> indices_shape = {SizeToInt(last_dim * 2)};
+  TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
+  MS_EXCEPTION_IF_NULL(tensor_type);
+  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
+  tensor::TensorPtr indices_tensor = std::make_shared<tensor::Tensor>(kFloat16->type_id(), indices_shape);
+  MS_EXCEPTION_IF_NULL(indices_tensor);
+  indices_tensor->set_device_info(device_info);
+
+  // 2 set value of tensor
+  auto data_ptr = indices_tensor->data_c();
+  MS_EXCEPTION_IF_NULL(data_ptr);
+  std::vector<Eigen::half> half_data;
+  for (size_t i = 0; i < last_dim; ++i) {
+    half_data.emplace_back(Eigen::half(static_cast<float>(i)));
+  }
+  for (size_t i = 0; i < last_dim; ++i) {
+    auto gap = static_cast<int>(i) - static_cast<int>(Eigen::half(static_cast<float>(i)));
+    half_data.emplace_back(Eigen::half(static_cast<float>(gap)));
+  }
+  auto elem_num = last_dim * kFloat16Len * 2;
+  auto ret_code =
+    memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()), half_data.data(), elem_num);
+  if (ret_code != 0) {
+    MS_LOG(ERROR) << "Failed to copy data into Tensor.";
+    return nullptr;
+  }
+  return indices_tensor;
+}
+
+ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
+  tensor::TensorPtr indices_tensor = CreateTensor(node);
+  MS_EXCEPTION_IF_NULL(indices_tensor);
+  auto indices_const = std::make_shared<ValueNode>(indices_tensor);
+  MS_EXCEPTION_IF_NULL(indices_const);
+  auto indices_abstract = indices_tensor->ToAbstract();
+  indices_const->set_abstract(indices_abstract);
+  auto indices_kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(indices_kernel_info);
+  indices_const->set_kernel_info(indices_kernel_info);
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
+  builder1.SetOutputsFormat({kOpFormat_DEFAULT});
+  builder1.SetOutputsDeviceType({kNumberTypeFloat16});
+  AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get());
+  return indices_const;
+}
+
+kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetKernelType(TBE_KERNEL);
+  builder.SetFusionType(kernel::OPAQUE);
+  builder.SetProcessor(kernel::AICORE);
+  builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
+  builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
+  builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
+  builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
+  return builder.Build();
+}
+
+bool CheckInputNamesSize(const CNodePtr &cnode) {
+  auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
+  if (input_names_vec.size() < kTopkIndexK + 1) {
+    MS_LOG(INFO) << "The input k of topk has been converted to attr";
+    return false;
+  }
+  return true;
+}
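+
+// The assist indices tensor built above stores each index (and its float16 rounding
+// gap) as float16, so the last output dimension must stay below kMaxFloat16; larger
+// shapes fall back to the AICPU implementation instead of being split.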
+bool CheckOutputShape(const AnfNodePtr &node) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  if (shape.empty()) {
+    MS_LOG(INFO) << "The output shape of topk to split must not be empty";
+    return false;
+  }
+  auto last_dim = shape[shape.size() - 1];
+  const size_t kMaxFloat16 = 65500;
+  if (last_dim > kMaxFloat16) {
+    MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops.";
+    return false;
+  }
+  return true;
+}
+}  // namespace
+
+const BaseRef TopKSplit::DefinePattern() const {
+  VarPtr X1 = std::make_shared<Var>();
+  VarPtr X2 = std::make_shared<Var>();
+  auto prim = std::make_shared<Primitive>(kTopKOpName);
+  return VectorRef({prim, X1, X2});
+}
+
+const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_graph = func_graph->cast<KernelGraphPtr>();
+  // set value node as topk's input
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!CheckInputNamesSize(cnode)) {
+    return nullptr;
+  }
+  if (!CheckOutputShape(cnode)) {
+    return nullptr;
+  }
+  // Copy a new node to check supported.
+  std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
+  new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
+  CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
+  MS_EXCEPTION_IF_NULL(new_cnode);
+  new_cnode->set_abstract(cnode->abstract());
+  new_cnode->set_scope(cnode->scope());
+  AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
+  CheckCNodeInputSize(new_cnode, kTopkInputNum);
+  // Convert the tensor input to scalar and convert it to attr
+  auto input_k = new_cnode->input(kTopkIndexK + 1);
+  MS_EXCEPTION_IF_NULL(input_k);
+  if (!IsValueNode<tensor::Tensor>(input_k)) {
+    return nullptr;
+  }
+  ValuePtr value = GetValueNode(input_k);
+  MS_EXCEPTION_IF_NULL(value);
+  auto tensor = value->cast<tensor::TensorPtr>();
+  MS_EXCEPTION_IF_NULL(tensor);
+  int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
+  MS_EXCEPTION_IF_NULL(data);
+  auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
+  new_cnode->set_input(kTopkIndexK + 1, new_value_node);
+
+  std::unordered_set<size_t> attr_index{kTopkIndexK};
+  ConstInputToAttr(new_cnode, attr_index);
+  auto indices_const = CreateValueNode(new_cnode);
+  new_cnode->add_input(indices_const);
+  MS_EXCEPTION_IF_NULL(supported_checker_);
+  if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
+    MS_LOG(INFO) << "Split topk failed, fall back to aicpu.";
+    return nullptr;
+  }
+
+  if (kernel_graph != nullptr) {
+    MS_LOG(INFO) << "Split topk success, use TBE aicore.";
+    kernel_graph->AddValueNodeToGraph(indices_const);
+  }
+
+  return new_cnode;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h
new file mode 100644
index 0000000000..e005a83a2f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class TopKSplit : public PatternProcessPass {
+ public:
+  explicit TopKSplit(bool multigraph = true)
+      : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared<SupportedChecker>()) {}
+  ~TopKSplit() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  SupportedCheckerPtr supported_checker_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc
new file mode 100644
index 0000000000..057cf8deed
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
+#include <set>
+#include "backend/optimizer/ascend/ascend_helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "debug/anf_ir_dump.h"

+namespace mindspore {
+namespace opt {
+const std::set<std::pair<std::string, std::string>> invalid_formats_pair = {
+  {kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
+  {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
+  {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
+  {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
+
+bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  bool changed = false;
+  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
+  for (auto &node : node_list) {
+    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
+      CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
+      if (IsFormatInvalid(node)) {
+        changed = DoSplit(func_graph, node);
+      }
+    }
+  }
+  return changed;
+}
+
+bool TransDataSplit::IsFormatInvalid(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_format = AnfAlgo::GetInputFormat(node, 0);
+  auto output_format = AnfAlgo::GetOutputFormat(node, 0);
+  auto format_pair = std::make_pair(input_format, output_format);
+
+  return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
+}
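+
+// TransData cannot convert the format pairs listed above in one step, so DoSplit
+// rewrites one unsupported TransData into a two-hop conversion through HWCN:
+// when the output format is NCHW/default it emits TransData(input -> HWCN) followed
+// by Transpose(HWCN -> output); otherwise Transpose(input -> HWCN) followed by
+// TransData(HWCN -> output). The tail of the new chain then replaces the old node.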
+bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_node = node->cast<CNodePtr>()->input(1);
+  MS_EXCEPTION_IF_NULL(input_node);
+
+  auto input_format = AnfAlgo::GetInputFormat(node, 0);
+  auto output_format = AnfAlgo::GetOutputFormat(node, 0);
+  AnfNodePtr new_transdata_node = nullptr;
+  AnfNodePtr new_transpose_node = nullptr;
+  AnfNodePtr new_replace_node = nullptr;
+  // if output_format is default, the transdata is split as transdata->transpose, else transpose->transdata
+  if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
+    // trans input_format to hwcn
+    new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
+                                        false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
+    // trans hwcn to default_format
+    new_transpose_node =
+      NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
+    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
+    new_replace_node = new_transpose_node;
+  } else {
+    // trans default to hwcn
+    new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
+                                        false, prim::kPrimTranspose->name());
+    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
+
+    // trans hwcn to output_format
+    new_transdata_node =
+      NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
+    new_replace_node = new_transdata_node;
+  }
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(func_graph);
+
+  if (!manager->Replace(node, new_replace_node)) {
+    MS_LOG(EXCEPTION) << "Manager replace node failed";
+  }
+  MS_LOG(INFO) << "Transdata node: " << cnode->DebugString() << " split success.";
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h
new file mode 100644
index 0000000000..bc681944c3
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
+#include <string>
+#include <memory>
+#include <vector>
+#include <utility>
+
+#include "backend/optimizer/common/pass.h"
+#include "ir/func_graph.h"
+#include "ir/anf.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class TransDataSplit : public Pass {
+ public:
+  TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared<KernelSelect>()) {}
+  ~TransDataSplit() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
+  bool IsFormatInvalid(const AnfNodePtr &node);
+  KernelSelectPtr kernel_select_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc
new file mode 100644
index 0000000000..189ac94546
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.cc
@@ -0,0 +1,150 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h"
+#include "backend/optimizer/common/helper.h"
+namespace mindspore {
+namespace opt {
+AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto prim = std::make_shared<Primitive>(kAdamApplyOneOpName);
+  std::vector<AnfNodePtr> new_node_inputs = {NewValueNode(prim)};
+  for (const auto &input_var : input_vars_) {
+    auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_var]);
+    MS_EXCEPTION_IF_NULL(input_node);
+    new_node_inputs.push_back(input_node);
+  }
+  for (const auto &mul_x_input_var : mul_x_input_vars_) {
+    auto mul_x_input_node = utils::cast<AnfNodePtr>((*equiv)[mul_x_input_var]);
+    MS_EXCEPTION_IF_NULL(mul_x_input_node);
+    new_node_inputs.push_back(mul_x_input_node);
+  }
+  auto add2_y_node = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
+  MS_EXCEPTION_IF_NULL(add2_y_node);
+  new_node_inputs.push_back(add2_y_node);
+  auto new_node = func_graph->NewCNode(new_node_inputs);
+  return new_node;
+}
+
+const BaseRef AdamApplyOneFusion::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})});
+  VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})});
+  return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})});
+}
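+
+// The four Cond variants below match the same AdamApplyOne computation as the base
+// pattern; they differ only in the commutative operand order of the TensorAdd, Mul
+// and Sub sub-patterns, so the fusion still fires whichever order the frontend
+// emitted the graph in.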
VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); + const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { + const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName); + const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName); + VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); + VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); + VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); + return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); +} + +const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto new_node = CreateAdamApplyOneNode(func_graph, equiv); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(node->scope()); + // Set abstract of new node + AbstractBasePtrList new_node_abstract_list; + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contain the add0 var after matching."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contain the add1 var after matching."; + } + auto add0 = utils::cast<AnfNodePtr>(iter_add0->second); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + new_node_abstract_list.push_back(add1->abstract()); + new_node_abstract_list.push_back(add0->abstract()); + new_node_abstract_list.push_back(node->abstract()); + auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(new_node_abstract_list); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector<AnfNodePtr> new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs); + if 
(new_node_outputs.size() != kAdamApplyOneOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be " + << kAdamApplyOneOutputNum; + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, new_node_outputs[0]); + (void)manager->Replace(add0, new_node_outputs[1]); + return new_node_outputs[2]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h new file mode 100644 index 0000000000..683a345cdb --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h @@ -0,0 +1,95 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ + +#include <memory> +#include <string> +#include <vector> +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +constexpr size_t kAdamApplyOneInputVarNum = 5; +constexpr size_t kAdamApplyOneMulInputVarNum = 4; + +class AdamApplyOneFusion : public PatternProcessPass { + public: + explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { + input_vars_.push_back(std::make_shared<Var>()); + } + for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { + mul_x_input_vars_.push_back(std::make_shared<Var>()); + } + add2_y_ = std::make_shared<Var>(); + add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); + } + + ~AdamApplyOneFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; + std::vector<VarPtr> input_vars_; + std::vector<VarPtr> mul_x_input_vars_; + VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond1Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} + + ~AdamApplyOneCond1Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond2Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} + + ~AdamApplyOneCond2Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { + public: + 
explicit AdamApplyOneCond3Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} + + ~AdamApplyOneCond3Fusion() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { + public: + explicit AdamApplyOneCond4Fusion(bool multigraph = true) + : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} + + ~AdamApplyOneCond4Fusion() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc new file mode 100644 index 0000000000..b1afa338d4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -0,0 +1,189 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + auto input0 = utils::cast((*equiv)[input0_]); + auto input1 = utils::cast((*equiv)[input1_]); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto input4 = utils::cast((*equiv)[input4_]); + auto mul0_x = utils::cast((*equiv)[mul0_x_]); + auto mul1_x = utils::cast((*equiv)[mul1_x_]); + auto mul2_x = utils::cast((*equiv)[mul2_x_]); + auto mul3_x = utils::cast((*equiv)[mul3_x_]); + auto mul4_x = utils::cast((*equiv)[mul4_x_]); + auto add2_y = utils::cast((*equiv)[add2_y_]); + auto prim = std::make_shared(kAdamApplyOneWithDecayOpName); + return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; +} + +const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef 
mul5({prim::kPrimMul, input4_, add3}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); + VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, 
real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (graph == nullptr || node == nullptr || equiv == nullptr) { + return nullptr; + } + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv); + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(node->scope()); + + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contain the add0 var after matching."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contain the add1 var after matching."; + } + auto add0 = utils::cast<AnfNodePtr>(iter_add0->second); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast<AnfNodePtr>(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), + AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), + AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + + std::vector<AnfNodePtr> fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, fusion_node, kAdamApplyOneWithDecayOutputNum, &fusion_node_outputs); + if (fusion_node_outputs.size() != kAdamApplyOneWithDecayOutputNum) { + MS_LOG(ERROR) << "Failed to create multiple outputs for the fusion node!"; + return nullptr; + } + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, fusion_node_outputs[0]); + (void)manager->Replace(add0, fusion_node_outputs[1]); + return fusion_node_outputs[2]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h new file mode 100644 index 0000000000..2d599a8cc9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" +namespace mindspore { +namespace opt { +class AdamApplyOneWithDecayRule : public PatternProcessPass { + public: + explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) + : PatternProcessPass(name, multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_x_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_x_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + ~AdamApplyOneWithDecayRule() override = default; + const BaseRef DefinePattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr mul0_x_; + VarPtr mul1_x_; + VarPtr mul2_x_; + VarPtr mul3_x_; + VarPtr mul4_x_; + VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond4() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond5() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git 
a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc new file mode 100644 index 0000000000..cc58d2b057 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" +#include <string> +#include <vector> +#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" + +namespace mindspore { +namespace opt { +namespace { +void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector<std::string> *names_vec) { + MS_EXCEPTION_IF_NULL(names_vec); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr names_value = primitive->GetAttr(attr_name); + if (names_value == nullptr) { + return; + } + *names_vec = GetValue<std::vector<std::string>>(names_value); +} + +void AddOutputs(const CNodePtr &cnode, const std::vector<size_t> &input_indices) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector<std::string> input_names_vec; + GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec); + std::vector<std::string> output_names_vec; + GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec); + AbstractBasePtrList abstract_list; + auto origin_abstract = cnode->abstract(); + MS_EXCEPTION_IF_NULL(origin_abstract); + if (origin_abstract->isa<abstract::AbstractTuple>()) { + auto origin_abstract_tuple = dyn_cast<abstract::AbstractTuple>(origin_abstract); + MS_EXCEPTION_IF_NULL(origin_abstract_tuple); + AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements(); + (void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list)); + } else { + abstract_list.emplace_back(origin_abstract); + } + + for (size_t i = 0; i < input_indices.size(); ++i) { + size_t index = input_indices[i]; + if (index + 1 >= cnode->inputs().size()) { + MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, " + << "node: " << cnode->DebugString(); + continue; + } + auto node_to_output = cnode->input(index + 1); + MS_EXCEPTION_IF_NULL(node_to_output); + abstract_list.emplace_back(node_to_output->abstract()); + if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) { + output_names_vec.emplace_back(input_names_vec[index]); + } + } + if (!output_names_vec.empty()) { + AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode); + } + auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); + cnode->set_abstract(abstract_tuple); +} +} // namespace + +const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + auto cnode = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(cnode); 
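+ // Look up the op in the input-to-output registry by name; ops without a registration are left untouched.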
+ std::string op_name = AnfAlgo::GetCNodeName(cnode); + InputToOutputRegister reg; + if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, &reg)) { + return nullptr; + } + int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); + // No need to add an output when it is not a TBE op. + if (output_num == -1) { + return nullptr; + } + // No need to add an output if the output num already matches the registered output num for TBE. + if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) { + return nullptr; + } + bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode); + AddOutputs(cnode, reg.input_indices()); + // No need to create tuple_getitem if the origin output is a tuple, because there are already tuple_getitems + // pointing to the outputs. + if (is_origin_tuple_output) { + return nullptr; + } + std::vector<AnfNodePtr> new_outputs; + auto new_abstract_tuple = dyn_cast<abstract::AbstractTuple>(cnode->abstract()); + MS_EXCEPTION_IF_NULL(new_abstract_tuple); + CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs); + if (new_outputs.size() != new_abstract_tuple->size()) { + MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString(); + } + return new_outputs[0]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h new file mode 100644 index 0000000000..6e5560bfb0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class AddInputToOutput : public PatternProcessPass { + public: + explicit AddInputToOutput(bool multigraph = true) + : PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared()) {} + ~AddInputToOutput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + OpFinderPtr op_finder_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc new file mode 100644 index 0000000000..51bcd880cd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnorm); + MS_EXCEPTION_IF_NULL(node); + auto prim = std::make_shared(kBNInferOpName); + std::vector inputs = {NewValueNode(prim)}; + for (size_t i = 1; i < batchnorm->size(); ++i) { + inputs.push_back(batchnorm->input(i)); + } + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(batchnorm->scope()); + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node); + return new_node; +} + +bool CheckIndex(const AnfNodePtr &index_node) { + MS_EXCEPTION_IF_NULL(index_node); + if (!IsValueNode(index_node)) { + return false; + } + ValueNodePtr value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index != 0) { + MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNorm"; + return false; + } + return true; +} + +bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnorm); + if (batchnorm->size() < kBatchNormInputNum + 1) { + MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; + return false; + } + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { + return false; + } + auto is_training = AnfAlgo::GetNodeAttr(batchnorm, kAttrIsTraining); + if (is_training) { + MS_LOG(DEBUG) << "is_training is true, no need do fusion"; + return false; + } + + if (IsUsedByOthers(graph, batchnorm)) { + MS_LOG(DEBUG) << "Only the 0th output of BatchNorm is used, then do fusion"; + return false; + } + return true; +} + +bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); + AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + if (!CheckIndex(index_node)) { + return false; + } + + AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(batchnorm_anf); + MS_EXCEPTION_IF_NULL(batchnorm); + *batchnorm = batchnorm_anf->cast(); + MS_EXCEPTION_IF_NULL(*batchnorm); + return CheckBatchNorm(graph, *batchnorm); +} +} // namespace + +const BaseRef BatchNorm2BNInfer::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr Y = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Y); + VectorRef batchnorm({prim::kPrimBatchNorm, Xs}); + VectorRef pattern({prim::kPrimTupleGetItem, batchnorm, Y}); + return pattern; +} + +const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + CNodePtr batchnorm = nullptr; + if (!NeedFusion(graph, node, &batchnorm)) { + return nullptr; + } + return 
CreateBNInfer(graph, batchnorm, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h new file mode 100644 index 0000000000..46872aa959 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNorm2BNInfer : public PatternProcessPass { + public: + explicit BatchNorm2BNInfer(bool multigraph = true) : PatternProcessPass("batchnorm_to_bninfer", multigraph) {} + ~BatchNorm2BNInfer() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc new file mode 100644 index 0000000000..defb011396 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc @@ -0,0 +1,127 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "frontend/operator/ops.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr CreateBNInferGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnormgrad); + auto prim = std::make_shared(kBNInferGradOpName); + std::vector inputs = {NewValueNode(prim)}; + inputs.push_back(batchnormgrad->input(1)); + inputs.push_back(batchnormgrad->input(3)); + inputs.push_back(batchnormgrad->input(5)); + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(batchnormgrad->scope()); + new_node->set_abstract(node->abstract()); + AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnormgrad, new_node); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnormgrad, new_node); + return new_node; +} + +bool CheckIndex(const AnfNodePtr &index_node) { + MS_EXCEPTION_IF_NULL(index_node); + if (!IsValueNode(index_node)) { + return false; + } + ValueNodePtr value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index != 0) { + MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNormGrad"; + return false; + } + return true; +} + +bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(batchnormgrad); + if (batchnormgrad->size() < kBatchNormInputNum + 1) { + MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; + return false; + } + if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { + return false; + } + auto is_training = AnfAlgo::GetNodeAttr(batchnormgrad, kAttrIsTraining); + if (is_training) { + MS_LOG(DEBUG) << "is_training is true, no need do fusion"; + return false; + } + + if (IsUsedByOthers(graph, batchnormgrad)) { + MS_LOG(DEBUG) << "Only the 0th output of BatchNormGrad is used, then do fusion"; + return false; + } + return true; +} + +bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto tuple_getitem = node->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); + AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + if (!CheckIndex(index_node)) { + return false; + } + + AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(batchnormgrad_anf); + MS_EXCEPTION_IF_NULL(batchnormgrad); + *batchnormgrad = batchnormgrad_anf->cast(); + MS_EXCEPTION_IF_NULL(*batchnormgrad); + return CheckBatchNormGrad(graph, *batchnormgrad); +} +} // namespace + +const BaseRef BatchNormGrad2BNInferGrad::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr Y = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Y); + VectorRef batchnormgrad({prim::kPrimBatchNormGrad, Xs}); + VectorRef pattern({prim::kPrimTupleGetItem, batchnormgrad, Y}); + return pattern; +} + +const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + 
MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + CNodePtr batchnormgrad = nullptr; + if (!NeedFusion(graph, node, &batchnormgrad)) { + return nullptr; + } + return CreateBNInferGrad(graph, batchnormgrad, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h new file mode 100644 index 0000000000..0676f8a040 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class BatchNormGrad2BNInferGrad : public PatternProcessPass { + public: + explicit BatchNormGrad2BNInferGrad(bool multigraph = true) + : PatternProcessPass("batchnormgrad_to_bninfergrad", multigraph) {} + ~BatchNormGrad2BNInferGrad() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc new file mode 100644 index 0000000000..1d89bfd388 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "common/utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const BaseRef ClipByNormNoDivSquareSumFusion::DefinePattern() const { + auto greater = std::make_shared(kGreaterOpName); + MS_EXCEPTION_IF_NULL(greater); + auto sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(sqrt); + + VectorRef greater_pattern({greater, input_, constant_greater_}); + VectorRef pattern( + {prim::kPrimMaximum, + VectorRef({prim::kPrimSelect, greater_pattern, + VectorRef({sqrt, VectorRef({prim::kPrimSelect, greater_pattern, input_, constant_select_})}), input_}), + constant_maximum_}); + return pattern; +} + +const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + BaseRef &input_gnode = (*equiv)[input_]; + BaseRef &constant_select_gnode = (*equiv)[constant_select_]; + BaseRef &constant_greater_gnode = (*equiv)[constant_greater_]; + BaseRef &constant_maximum_gnode = (*equiv)[constant_maximum_]; + auto input = utils::cast(input_gnode); + auto constant_select = utils::cast(constant_select_gnode); + auto constant_greater = utils::cast(constant_greater_gnode); + auto constant_maximum = utils::cast(constant_maximum_gnode); + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(constant_select); + MS_EXCEPTION_IF_NULL(constant_greater); + MS_EXCEPTION_IF_NULL(constant_maximum); + + auto prim = std::make_shared(kClipByNormNoDivSumOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + fusion_node->set_scope(node->scope()); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h new file mode 100644 index 0000000000..9282b75527 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +constexpr auto kInputVarName = "input"; +constexpr auto kConstantSelectVarName = "constant_select"; +constexpr auto kConstantGreaterVarName = "constant_greater"; +constexpr auto kConstantMaximumVarName = "constant_maximum"; + +class ClipByNormNoDivSquareSumFusion : public PatternProcessPass { + public: + explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true) + : PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) { + input_ = std::make_shared(kInputVarName); + constant_select_ = std::make_shared(kConstantSelectVarName); + constant_greater_ = std::make_shared(kConstantGreaterVarName); + constant_maximum_ = std::make_shared(kConstantMaximumVarName); + } + ~ClipByNormNoDivSquareSumFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_; + VarPtr constant_select_; + VarPtr constant_greater_; + VarPtr constant_maximum_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc new file mode 100644 index 0000000000..e1b0cb81e3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" + +#include <memory> +#include <string> +#include <vector> + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +bool GetMinimumOp(const AnfNodePtr &input0, const AnfNodePtr &input1, CNodePtr *minimum, bool *is_first_input) { + MS_EXCEPTION_IF_NULL(input0); + MS_EXCEPTION_IF_NULL(input1); + + CNodePtr cnode = nullptr; + if (input0->isa<CNode>() && !input1->isa<CNode>()) { + cnode = input0->cast<CNodePtr>(); + *is_first_input = true; + } else if (!input0->isa<CNode>() && input1->isa<CNode>()) { + cnode = input1->cast<CNodePtr>(); + *is_first_input = false; + } else if (input0->isa<CNode>() && input1->isa<CNode>()) { + if (AnfAlgo::GetCNodeName(input0) == prim::kPrimMinimum->name()) { + cnode = input0->cast<CNodePtr>(); + *is_first_input = true; + } else { + cnode = input1->cast<CNodePtr>(); + *is_first_input = false; + } + } else { + return false; + } + + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMinimum->name()) { + return false; + } + *minimum = cnode; + return true; +} +} // namespace + +const BaseRef ClipByValueFusion::DefinePattern() const { + VectorRef pattern({prim::kPrimMaximum, maximum_input0_, maximum_input1_}); + return pattern; +} + +const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto maximum_input0 = utils::cast<AnfNodePtr>((*equiv)[maximum_input0_]); + auto maximum_input1 = utils::cast<AnfNodePtr>((*equiv)[maximum_input1_]); + MS_EXCEPTION_IF_NULL(maximum_input0); + MS_EXCEPTION_IF_NULL(maximum_input1); + + CNodePtr minimum = nullptr; + bool is_first_input = true; + if (!GetMinimumOp(maximum_input0, maximum_input1, &minimum, &is_first_input)) { + return nullptr; + } + MS_EXCEPTION_IF_NULL(minimum); + if (minimum->inputs().size() != kMinimumInputNum) { + return nullptr; + } + + auto prim = std::make_shared<Primitive>(kClipByValueOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector<AnfNodePtr> inputs = {NewValueNode(prim), minimum->input(1), + is_first_input ? maximum_input1 : maximum_input0, minimum->input(2)}; + auto clip_by_value = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clip_by_value); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, clip_by_value.get()); + clip_by_value->set_scope(node->scope()); + return clip_by_value; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h new file mode 100644 index 0000000000..05bf713bdd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ClipByValueFusion : public PatternProcessPass { + public: + explicit ClipByValueFusion(bool multigraph = true) : PatternProcessPass("clip_by_value_fusion", multigraph) { + maximum_input0_ = std::make_shared(); + maximum_input1_ = std::make_shared(); + } + ~ClipByValueFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr maximum_input0_; + VarPtr maximum_input1_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc new file mode 100644 index 0000000000..6ccf3e29bd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h" +#include <algorithm> +#include <memory> +#include <string> +#include <utility> +#include <vector> +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kConfusionMulGradOutputNum = 2; + +CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf, + const AnfNodePtr &input3) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(input3); + auto mul0 = mul0_anf->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(mul0); + + auto prim = std::make_shared<Primitive>(kConfusionMulGradOpName); + std::vector<AnfNodePtr> inputs = {NewValueNode(prim), mul0->input(1), mul0->input(2), input3}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_scope(reduce_sum->scope()); + AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); + AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); + auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); + return fusion_node; +} + +AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input2); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(input2) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no users in manager"; + } + + AnfNodePtr mul0 = nullptr; + const AnfNodeIndexSet &outputs_set = manager->node_users()[input2]; + // input2 must be the 2nd input of mul0 + auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul1](const std::pair<AnfNodePtr, int> &node_index) { + return node_index.first != mul1 && node_index.second == 2; + }); + if (it != outputs_set.end() && AnfAlgo::GetCNodeName(it->first) == prim::kPrimMul->name()) { + mul0 = it->first; + } + return mul0; +} + +bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, + const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { + MS_EXCEPTION_IF_NULL(mul0_anf); + MS_EXCEPTION_IF_NULL(mul1_anf); + MS_EXCEPTION_IF_NULL(reduce_sum); + MS_EXCEPTION_IF_NULL(input2); + auto addn = input2->cast<CNodePtr>(); + if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { + MS_LOG(INFO) << "Mul's second input is not AddN"; + return true; + } + std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0); + if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { + MS_LOG(INFO) << "AddN's inferred shape is not [x,1024] or [x,768]"; + return true; + } + if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) { + return true; + } + auto mul1 = mul1_anf->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(mul1); + auto mul0 = mul0_anf->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(mul0); + + if (IsDepend(graph, mul0->input(1), reduce_sum)) { + MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; + return true; + } + if (IsDepend(graph, mul1->input(1), mul0)) { + MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; + return true; + } + return false; +} +} // namespace + +const BaseRef ConfusionMulGradFusion::DefinePattern() const { + 
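// Match ReduceSum(Mul(input3, input2)); Process then locates the sibling Mul (mul0) that shares input2 and fuses both into ConfusionMulGrad. + 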
VectorRef mul1({prim::kPrimMul, input3_, input2_}); + VectorRef reduce_sum({prim::kPrimReduceSum, mul1}); + return reduce_sum; +} + +const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto input2 = utils::cast<AnfNodePtr>((*equiv)[input2_]); + auto input3 = utils::cast<AnfNodePtr>((*equiv)[input3_]); + auto reduce_sum = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(reduce_sum); + auto mul1 = reduce_sum->input(1); + if (IsUsedByOthers(graph, mul1)) { + MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; + return nullptr; + } + auto mul0 = GetMul0(graph, input2, mul1); + if (mul0 == nullptr) { + MS_LOG(INFO) << "Mul0 does not exist, quit fusion"; + return nullptr; + } + if (QuitFusion(graph, mul0, mul1, node, input2)) { + return nullptr; + } + + auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); + std::vector<AnfNodePtr> fusion_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, fusion_node, kConfusionMulGradOutputNum, &fusion_node_outputs); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(mul0, fusion_node_outputs[0]); + return fusion_node_outputs[1]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h new file mode 100644 index 0000000000..932f0d2890 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ConfusionMulGradFusion : public PatternProcessPass {
+ public:
+  explicit ConfusionMulGradFusion(bool multigraph = true)
+      : PatternProcessPass("confusion_mul_grad_fusion", multigraph) {
+    input2_ = std::make_shared<Var>();
+    input3_ = std::make_shared<Var>();
+  }
+  ~ConfusionMulGradFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input2_;
+  VarPtr input3_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_MUL_GRAD_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc
new file mode 100644
index 0000000000..a8cf0af465
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.cc
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h"
+
+#include <memory>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
+  return VectorRef({prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})});
+}
+
+const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_);
+  AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_);
+  AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_);
+  if (sum_anf == nullptr || !sum_anf->isa<CNode>()) {
+    MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!";
+    return nullptr;
+  }
+  if (!GetBoolAttr(sum_anf, kAttrKeepDims)) {
+    MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true for this fusion; otherwise the calculation will be wrong";
+    return nullptr;
+  }
+
+  auto prim = std::make_shared<Primitive>(kConfusionSoftmaxGradOpName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input0, input1};
+  auto fusion_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_abstract(node->abstract());
+  fusion_node->set_scope(node->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node);
+  return fusion_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h
new file mode 100644
index 0000000000..e3a86e22c9
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ConfusionSoftmaxGradRule : public PatternProcessPass {
+ public:
+  explicit ConfusionSoftmaxGradRule(bool multigraph = true)
+      : PatternProcessPass("confusion_softmax_grad_rule", multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
+  }
+  ~ConfusionSoftmaxGradRule() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input0_;
+  VarPtr input1_;
+  VarPtr reduce_sum_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
new file mode 100644
index 0000000000..0fe042dc4e
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
@@ -0,0 +1,121 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t kReluV2OutputNum = 2; + +CNodePtr GetRelu(const CNodePtr &relu_grad) { + MS_EXCEPTION_IF_NULL(relu_grad); + if (relu_grad->size() != kReluGradInputNum) { + MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); + } + auto relu_anf = relu_grad->input(2); + MS_EXCEPTION_IF_NULL(relu_anf); + return relu_anf->cast(); +} + +CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu); + if (relu->size() != kReluInputNum) { + MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); + } + + auto prim = std::make_shared(kReluV2OpName); + std::vector inputs = {NewValueNode(prim), relu->input(1)}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(relu->scope()); + + // ReluV2's 2rd output is mask whose data type is uint8 + TypeId mask_dtype = kNumberTypeUInt8; + std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); + if (mask_shape.size() != 4) { + MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; + return nullptr; + } + auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); + if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { + mask_shape[1] = (mask_shape[1] + 31) / 32; + mask_shape.push_back(4); + } else { + mask_shape[1] = (mask_shape[1] + 15) / 16; + mask_shape.push_back(2); + } + + auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; + auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); + return new_node; +} + +CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(relu_grad); + MS_EXCEPTION_IF_NULL(second_input); + + auto prim = std::make_shared(kReluGradV2OpName); + std::vector inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; + auto new_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(new_node); + new_node->set_scope(relu_grad->scope()); + new_node->set_abstract(relu_grad->abstract()); + return new_node; +} +} // namespace + +const BaseRef DereluFusion::DefinePattern() const { + VarPtr i0 = std::make_shared(); + VarPtr i1 = std::make_shared(); + VectorRef relu({prim::kPrimRelu, i1}); + VectorRef relu_grad({prim::kPrimReluGrad, i0, relu}); + return relu_grad; +} + +const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto relu_grad = node->cast(); + MS_EXCEPTION_IF_NULL(relu_grad); + auto relu = GetRelu(relu_grad); + MS_EXCEPTION_IF_NULL(relu); + + auto relu_v2 = CreateReluV2(graph, relu); + if (relu_v2 == nullptr) { + return nullptr; + } + std::vector relu_v2_node_outputs; + CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); + + auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); + + auto manage = graph->manager(); + MS_EXCEPTION_IF_NULL(manage); + manage->Replace(relu, relu_v2_node_outputs[0]); + return relu_grad_v2; +} +} // 
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h
new file mode 100644
index 0000000000..7506960ecb
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class DereluFusion : public PatternProcessPass {
+ public:
+  explicit DereluFusion(bool multigraph = true) : PatternProcessPass("derelu_fusion", multigraph) {}
+  ~DereluFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc
new file mode 100644
index 0000000000..dbff0374f3
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc
@@ -0,0 +1,340 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h" +#include +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kReplaceOutputIndex0 = 3; +constexpr size_t kReplaceOutputIndex1 = 4; +bool IsC(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + return in->isa(); + } + return false; +} + +void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(bn); + MS_EXCEPTION_IF_NULL(bn_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(bn) == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; + } + for (const auto &node_index : manager->node_users()[bn]) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + bn_outputs->push_back(output); + } +} +} // namespace + +const BaseRef FusedBatchNormFusion::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); + VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} + +ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + auto iter_constant_input0 = (*equiv).find(constant_input0_var_); + if (iter_constant_input0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; + } + auto constant_input = utils::cast(iter_constant_input0->second); + MS_EXCEPTION_IF_NULL(constant_input); + if (!constant_input->isa()) { + return nullptr; + } + auto value_node = constant_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + return nullptr; + } + auto tensor_ptr = value->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + if (tensor_ptr->data_type() == kNumberTypeFloat16) { + auto *half_data = static_cast(tensor_ptr->data_c()); + MS_EXCEPTION_IF_NULL(half_data); + float float_data = Eigen::half_impl::half_to_float(half_data[0]); + return MakeValue(float_data); + } else if (tensor_ptr->data_type() == kNumberTypeFloat32) { + auto *tensor_data = static_cast(tensor_ptr->data_c()); + 
+    MS_EXCEPTION_IF_NULL(tensor_data);
+    return MakeValue(tensor_data[0]);
+  } else {
+    MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32";
+    return nullptr;
+  }
+}
+
+AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                        const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set input to create node
+  auto iter_data_input0 = (*equiv).find(data_input0_var_);
+  if (iter_data_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input0 var after matching.";
+  }
+  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
+    utils::cast<AnfNodePtr>(iter_data_input0->second)};
+  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_reduce);
+  bn_training_reduce->set_scope(node->scope());
+  // Set abstract
+  auto iter_data_input1 = (*equiv).find(data_input1_var_);
+  if (iter_data_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input1 var after matching.";
+  }
+  auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
+  MS_EXCEPTION_IF_NULL(data_input1);
+  auto iter_data_input2 = (*equiv).find(data_input2_var_);
+  if (iter_data_input2 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input2 var after matching.";
+  }
+  auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
+  MS_EXCEPTION_IF_NULL(data_input2);
+  AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()};
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_reduce->set_abstract(abstract_tuple);
+  return bn_training_reduce;
+}
+
+void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv,
+                                                     const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
+                                                     std::vector<AnfNodePtr> *bn_training_update_inputs) const {
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(bn_training_update_inputs);
+  auto iter_data_input0 = (*equiv).find(data_input0_var_);
+  if (iter_data_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input0 var after matching.";
+  }
+  auto iter_data_input1 = (*equiv).find(data_input1_var_);
+  if (iter_data_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input1 var after matching.";
+  }
+  auto iter_data_input2 = (*equiv).find(data_input2_var_);
+  if (iter_data_input2 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the data_input2 var after matching.";
+  }
+  auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
+  if (iter_variable_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the variable_input0 var after matching.";
+  }
+  auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
+  if (iter_variable_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the variable_input1 var after matching.";
+  }
+  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
+    MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
+                      << ", but it is " << bn_training_reduce_outputs.size();
+  }
+  *bn_training_update_inputs = {
+    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName)),
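+    // Operand order assembled below (a reading aid, not part of the op definition): x, then the
+    // sum/square_sum produced by BNTrainingReduce, then the scale and offset inputs, then the
+    // moving mean/variance variables.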
+    utils::cast<AnfNodePtr>(iter_data_input0->second),
+    bn_training_reduce_outputs[0],
+    bn_training_reduce_outputs[1],
+    utils::cast<AnfNodePtr>(iter_data_input1->second),
+    utils::cast<AnfNodePtr>(iter_data_input2->second),
+    utils::cast<AnfNodePtr>(iter_variable_input0->second),
+    utils::cast<AnfNodePtr>(iter_variable_input1->second),
+  };
+}
+
+void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
+                                                           std::vector<AbstractBasePtr> *abstract_list) const {
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(bn);
+  MS_EXCEPTION_IF_NULL(abstract_list);
+  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
+  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
+  if (bn_abstract_tuple->elements().size() < kBnOutputNum) {
+    MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
+                      << bn_abstract_tuple->elements().size();
+  }
+  auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
+  if (iter_variable_input0 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the variable_input0 var after matching.";
+  }
+  auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second);
+  MS_EXCEPTION_IF_NULL(variable_input0);
+  auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
+  if (iter_variable_input1 == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the variable_input1 var after matching.";
+  }
+  auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second);
+  MS_EXCEPTION_IF_NULL(variable_input1);
+  *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(),
+                    bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]};
+}
+
+AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
+  const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
+  const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  // Set input
+  std::vector<AnfNodePtr> bn_training_update_inputs;
+  GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs);
+  auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs);
+  MS_EXCEPTION_IF_NULL(bn_training_update);
+  // Set abstract
+  auto iter_batch_norm = (*equiv).find(batch_norm_var_);
+  if (iter_batch_norm == (*equiv).end()) {
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contain the batch_norm var after matching.";
+  }
+  AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
+  MS_EXCEPTION_IF_NULL(bn);
+  AbstractBasePtrList abstract_list;
+  GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list);
+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  bn_training_update->set_abstract(abstract_tuple);
+  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update);
+  ValuePtr factor = GetFactor(equiv);
+  if (factor == nullptr) {
+    return nullptr;
+  }
+  AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update);
+  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
+  bn_training_update->set_scope(node->scope());
+  return bn_training_update;
+}
+
+const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(node);
+  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv);
+  std::vector<AnfNodePtr> bn_training_reduce_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph,
bn_training_reduce, kBNTrainingReduceOutputNum, + &bn_training_reduce_outputs); + AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs); + if (bn_training_update == nullptr) { + MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString(); + return nullptr; + } + std::vector bn_training_update_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum, + &bn_training_update_outputs); + if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) { + MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is " + << bn_training_update_outputs.size(); + } + // Replace old bn outputs with new outputs + auto iter_batch_norm = (*equiv).find(batch_norm_var_); + if (iter_batch_norm == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; + } + AnfNodePtr bn = utils::cast(iter_batch_norm->second); + std::vector bn_outputs; + GetBNOutput(func_graph, bn, &bn_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (const auto &output : bn_outputs) { + MS_EXCEPTION_IF_NULL(output); + if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + continue; + } + auto tuple_getitem_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); + AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_node); + auto value_node = index_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int index = GetValue(value_node->value()); + if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) { + (void)manager->Replace(output, bn_training_update_outputs[index]); + } + } + return bn_training_update_outputs[0]; +} + +const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); + VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); + VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); + VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); + VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); + VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} + +const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = 
std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); + VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); + VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); + VectorRef cast0 = VectorRef({prim::kPrimCast, sub0}); + VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); + VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); + VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); + VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h new file mode 100644 index 0000000000..b3bbedc36e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +class FusedBatchNormFusion : public PatternProcessPass { + public: + explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) + : PatternProcessPass(name, multigraph), + data_input0_var_(std::make_shared()), + data_input1_var_(std::make_shared()), + data_input2_var_(std::make_shared()), + variable_input0_var_(std::make_shared()), + variable_input1_var_(std::make_shared()), + constant_input0_var_(std::make_shared()), + constant_input1_var_(std::make_shared()), + batch_norm_var_(std::make_shared(std::make_shared(prim::kPrimBatchNorm->name()))) {} + ~FusedBatchNormFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + protected: + AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const; + void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector &bn_training_reduce_outputs, + std::vector *bn_training_update_inputs) const; + void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, + std::vector *abstract_list) const; + AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + const std::vector &bn_training_reduce_outputs) const; + ValuePtr GetFactor(const EquivPtr &equiv) const; + + VarPtr data_input0_var_; + VarPtr data_input1_var_; + VarPtr data_input2_var_; + VarPtr variable_input0_var_; + VarPtr variable_input1_var_; + VarPtr constant_input0_var_; + VarPtr constant_input1_var_; + VarPtr batch_norm_var_; +}; + +class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion { + public: + explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true) + : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} + + ~FusedBatchNormMixPrecisionFusion0() override = default; + const BaseRef DefinePattern() const override; +}; + +class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion { + public: + explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true) + : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {} + + ~FusedBatchNormMixPrecisionFusion1() override = default; + const BaseRef DefinePattern() const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc new file mode 100644 index 0000000000..2fb42f9bd6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/input_to_output_registry.h"
+#include <string>
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool ApplyRMSPropPreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+
+bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) {
+  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  return data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16;
+}
+
+bool SparseApplyRMSPropPreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+
+bool ApplyAdagradV2PreCheck(const CNodePtr &node) {
+  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  return data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16;
+}
+
+bool ApplyKerasMomentumPreCheck(const CNodePtr &node) {
+  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
+  return data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat16;
+}
+
+bool SparseApplyFtrlPreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+
+bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+
+bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+
+bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) {
+  return AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) == kNumberTypeFloat32;
+}
+}  // namespace
+InputToOutputRegistry::InputToOutputRegistry() {
+  Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck);
+  Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck);
+  Register(kApplyAdagradOpName, {1});
+  Register(kApplyAdagradDAName, {1, 2});
+  Register(kApplyAdadeltaOpName, {1, 2});
+  Register(kApplyPowerSignOpName, {1});
+  Register(kApplyProximalAdagradOpName, {1});
+  Register(kApplyAdaMaxOpName, {1, 2});
+  Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck);
+  Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck);
+  Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck);
+  Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck);
+  Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck);
+  Register(kSparseApplyProximalAdagradOpName, {1});
+  Register(kSparseApplyAdagradOpName, {1});
+  Register(kApplyFtrlV2OpName, {1, 2});
+  Register(kApplyMomentumOpName, {1});
+  Register(kApplyFtrlOpName, {1, 2});
+  Register(kApplyAdamOpName, {1, 2});
+  Register(kApplyCenteredRMSPropOpName, {1, 2, 3});
+  Register(kApplyAddSignOpName, {1});
+  Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck);
+  Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck);
+  Register(kApplyAdamWithAmsgradOpName, {1, 2});
+}
+
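+// Usage sketch (the call site below is hypothetical, and the input_indices accessor is assumed):
+// ask the registry whether an optimizer op mirrors some inputs as outputs, e.g. ApplyMomentum
+// registers input index 1 above:
+//   InputToOutputRegister reg;
+//   if (InputToOutputRegistry::Instance().GetRegisterByOpName(kApplyMomentumOpName, &reg)) {
+//     // reg describes which inputs to expose as extra outputs
+//   }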
&InputToOutputRegistry::Instance() { + static InputToOutputRegistry instance; + return instance; +} + +void InputToOutputRegistry::Register(const InputToOutputRegister ®) { + auto op_name = reg.op_name(); + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +void InputToOutputRegistry::Register(const std::string &op_name, const std::vector &input_indices, + const PreCheckFunc &pre_check_func) { + if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) { + InputToOutputRegister reg(op_name, pre_check_func); + reg.set_input_indices(input_indices); + (void)op_input_to_output_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " input2output register successfully!"; + } +} + +bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const { + if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) { + *reg = op_input_to_output_map_.at(op_name); + MS_LOG(DEBUG) << op_name << " input2output find in registry."; + return true; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h similarity index 100% rename from mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.h rename to mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/input_to_output_registry.h diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc new file mode 100644 index 0000000000..fd9fd31f12 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.cc @@ -0,0 +1,266 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + std::vector *old_pattern_outputs) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_); + auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &users = manager->node_users(); + if (users.find(real_div0) == users.end() || users[real_div0].size() < 2) { + return false; + } + AnfNodeIndexSet real_div0_outputs = users[real_div0]; + auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(), + [&real_div2, &equiv, this](const std::pair &node_index) { + return node_index.first != real_div2 && node_index.second == 1 && + MatchAnotherPattern(node_index.first, equiv); + }); + if (iter == real_div0_outputs.end()) { + return false; + } + + (*old_pattern_outputs).push_back(node); + (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_)); + (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_)); + (*old_pattern_outputs).push_back(iter->first); + + return true; +} + +AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph, + const std::vector &old_pattern_outputs, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto prim = std::make_shared(kLambNextMVOpName); + std::vector lamb_next_mv_rule_inputs = {NewValueNode(prim)}; + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input0_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input1_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input2_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input3_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input4_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input5_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input6_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul0_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul1_sub_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul2_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul3_sub1_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul4_x_])); + lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[add2_y_])); + auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs); + MS_EXCEPTION_IF_NULL(lamb_next_mv_rule); + + // Set abstract of new node + AbstractBasePtrList new_abstracts; + (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts), + [](const AnfNodePtr &out) { return out->abstract(); }); + auto abstract_tuple = std::make_shared(new_abstracts); + MS_EXCEPTION_IF_NULL(abstract_tuple); + lamb_next_mv_rule->set_abstract(abstract_tuple); + + // Create tuple_getitem node for outputs + std::vector lamb_next_mv_rule_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs); + + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + 
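+  // Outputs 1..3 of the fused node replace the old add0/add1/second-pattern outputs below,
+  // while output 0 is returned so the pass replaces the matched node itself.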
(void)manager->Replace(old_pattern_outputs[1], lamb_next_mv_rule_outputs[1]); + (void)manager->Replace(old_pattern_outputs[2], lamb_next_mv_rule_outputs[2]); + (void)manager->Replace(old_pattern_outputs[3], lamb_next_mv_rule_outputs[3]); + + return lamb_next_mv_rule_outputs[0]; +} + +bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { + return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) && + IsSameNode(equiv1, equiv2, add2_y_); +} + +const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + std::vector old_pattern_outputs; + if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { + return nullptr; + } + + return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv); +} + +const BaseRef LambNextMVRuleCond1::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); + auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); + auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond2::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); + auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); + auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = 
std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond3::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); + auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); + auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); + auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_}); + auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); + + return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); +} + +BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); + return real_div4; +} + +const BaseRef LambNextMVRuleCond4::DefinePattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + + auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); + auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); + auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); + auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); + auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); + auto add0 = VectorRef({add0_var_, mul0, mul1}); + auto add1 = VectorRef({add1_var_, mul2, mul3}); + + auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); + auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); + + auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); + auto sqrt0 = VectorRef({prim_rsqrt, add2}); + auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); + + return VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); +} + +BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + const auto prim_real_div = std::make_shared(kRealDivOpName); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + // Two patterns share: real_div0, real_div1, add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); + VectorRef real_div4 = VectorRef({prim_real_div, real_div0, 
add4}); + return real_div4; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h new file mode 100644 index 0000000000..d14ce6e3fe --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambNextMVRule : public MultipleOutputPatternProcessPass { + public: + explicit LambNextMVRule(const std::string &name = "", bool multigraph = true) + : MultipleOutputPatternProcessPass(name, multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + input6_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_sub_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_sub1_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div2_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + ~LambNextMVRule() override = default; + const BaseRef DefinePattern() const override = 0; + BaseRef DefineAnotherPattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; + + protected: + bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, + std::vector *old_pattern_outputs) const; + AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector &old_pattern_outputs, + const EquivPtr &equiv) const; + + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr input6_; + VarPtr mul0_x_; + VarPtr mul1_sub_; + VarPtr mul2_x_; + VarPtr mul3_sub1_; + VarPtr mul4_x_; + VarPtr add2_y_; + // nodes which two patterns share, and add2_y_ also. 
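+  // DefinePattern() and DefineAnotherPattern() must bind these shared vars to the same nodes;
+  // IsShareNodes() re-checks that before the two matched subgraphs are fused into one LambNextMV.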
+ VarPtr real_div0_var_; + VarPtr real_div1_var_; + // part of output nodes + VarPtr add0_var_; + VarPtr add1_var_; + // other node + VarPtr real_div2_var_; +}; + +class LambNextMVRuleCond1 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} + + ~LambNextMVRuleCond1() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond2 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} + + ~LambNextMVRuleCond2() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond3 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} + + ~LambNextMVRuleCond3() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVRuleCond4 : public LambNextMVRule { + public: + explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} + + ~LambNextMVRuleCond4() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc new file mode 100644 index 0000000000..4ef3fa269f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc @@ -0,0 +1,278 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace opt { +AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, + const AnfNodePtr &new_node, const AnfNodePtr &add3, + const AnfNodePtr &add5, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(new_node); + MS_EXCEPTION_IF_NULL(add3); + MS_EXCEPTION_IF_NULL(add5); + MS_EXCEPTION_IF_NULL(equiv); + auto add0 = GetAnfNodeByVar(equiv, add0_var_); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = GetAnfNodeByVar(equiv, add1_var_); + MS_EXCEPTION_IF_NULL(add1); + + // Set abstract of new node + AbstractBasePtrList new_node_list; + new_node_list.push_back(add3->abstract()); + new_node_list.push_back(add0->abstract()); + new_node_list.push_back(add1->abstract()); + new_node_list.push_back(add5->abstract()); + auto abstract_tuple = std::make_shared(new_node_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add3, new_node_outputs[0]); + (void)manager->Replace(add0, new_node_outputs[1]); + (void)manager->Replace(add1, new_node_outputs[2]); + return new_node_outputs[3]; +} + +AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, + const AnfNodePtr &add3, const AnfNodePtr &add5, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(add3); + MS_EXCEPTION_IF_NULL(equiv); + // Create new node with all the inputs + auto prim = std::make_shared(kLambNextMVWithDecayOpName); + std::vector new_node_inputs = {NewValueNode(prim)}; + for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { + auto input_node = utils::cast((*equiv)[input_vars_[i]]); + MS_EXCEPTION_IF_NULL(input_node); + new_node_inputs.push_back(input_node); + } + for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { + auto constant_mul_input_node = utils::cast((*equiv)[constant_mul_input_vars_[i]]); + MS_EXCEPTION_IF_NULL(constant_mul_input_node); + new_node_inputs.push_back(constant_mul_input_node); + } + auto constant_add2_y_node = utils::cast((*equiv)[constant_add2_y_]); + MS_EXCEPTION_IF_NULL(constant_add2_y_node); + new_node_inputs.push_back(constant_add2_y_node); + auto new_node = func_graph->NewCNode(new_node_inputs); + return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); +} + +bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { + return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && + IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); +} + +const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); + MS_EXCEPTION_IF_NULL(mul4); + // Get add3 and match the add3 pattern + auto manager = func_graph->manager(); + 
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(mul4) == manager->node_users().end()) {
+    MS_LOG(EXCEPTION) << "The Mul4 node should be used by at least one other node";
+  }
+  AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4];
+  auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(),
+                           [&node, &equiv, this](const std::pair<AnfNodePtr, int> &node_index) {
+                             return node_index.first != node && MatchAnotherPattern(node_index.first, equiv);
+                           });
+  if (iter != mul4_outputs.end()) {
+    return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv);
+  }
+  return nullptr;
+}
+
+BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return add3;
+}
+
+const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
+  const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_sqrt);
+  const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
+  MS_EXCEPTION_IF_NULL(prim_deal_div);
+  VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
+  VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]});
+  VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
+  VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
+  VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
+  VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
+  VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
+  VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
+  VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
+  VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
+  VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
+  VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
+  VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
+  return add5;
+}
+
+BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  MS_EXCEPTION_IF_NULL(prim_rsqrt);
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr Ys = std::make_shared<SeqVar>();
+  VarPtr Zs = std::make_shared<SeqVar>();
+  MS_EXCEPTION_IF_NULL(Xs);
+  MS_EXCEPTION_IF_NULL(Ys);
+  MS_EXCEPTION_IF_NULL(Zs);
+  VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
+  VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
+  VectorRef mul4 = VectorRef({mul4_var_, Zs});
+
+  VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
+  VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
+  VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
+  VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
+  return add3;
+}
+
+const BaseRef
LambNextMVWithDecayRuleCond2::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); + VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); + return add5; +} + +BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const { + const auto prim_rsqrt = std::make_shared(kRsqrtOpName); + MS_EXCEPTION_IF_NULL(prim_rsqrt); + VarPtr Xs = std::make_shared(); + VarPtr 
Ys = std::make_shared(); + VarPtr Zs = std::make_shared(); + MS_EXCEPTION_IF_NULL(Xs); + MS_EXCEPTION_IF_NULL(Ys); + MS_EXCEPTION_IF_NULL(Zs); + // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ + VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); + VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); + VectorRef mul4 = VectorRef({mul4_var_, Zs}); + + VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); + VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); + VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); + VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); + return add3; +} + +const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(prim_deal_div); + VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); + VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); + VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); + VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); + VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); + VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); + VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); + VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); + VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); + VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); + VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); + VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); + return add5; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h new file mode 100644 index 0000000000..23114c37ee --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ + +#include +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass { + public: + explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true) + : MultipleOutputPatternProcessPass(name, multigraph) { + for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { + input_vars_.push_back(std::make_shared()); + } + for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { + constant_mul_input_vars_.push_back(std::make_shared()); + } + constant_add2_y_ = std::make_shared(); + mul4_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); + real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + } + + ~LambNextMVWithDecayRule() override = default; + const BaseRef DefinePattern() const override = 0; + BaseRef DefineAnotherPattern() const override = 0; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; + + protected: + AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node, + const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const; + AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3, + const AnfNodePtr &add5, const EquivPtr &equiv) const; + std::vector input_vars_; + std::vector constant_mul_input_vars_; + // nodes which two patterns share + VarPtr constant_add2_y_; + VarPtr mul4_var_; + VarPtr real_div0_var_; + VarPtr real_div1_var_; + // part of output nodes + VarPtr add0_var_; + VarPtr add1_var_; +}; + +class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond1(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {} + + ~LambNextMVWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond2(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {} + + ~LambNextMVWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond3(bool multigraph = true) + : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {} + + ~LambNextMVWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; + +class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule { + public: + explicit LambNextMVWithDecayRuleCond4(bool multigraph = true) + : 
LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {} + + ~LambNextMVWithDecayRuleCond4() override = default; + const BaseRef DefinePattern() const override; + BaseRef DefineAnotherPattern() const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc new file mode 100644 index 0000000000..f21433b3c6 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h" + +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace opt { +namespace { +std::tuple GetSharedNodes(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto add3 = node->cast(); + MS_EXCEPTION_IF_NULL(add3); + if (add3->inputs().size() < kAddInputNum) { + MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum; + } + auto real_div2_anf = add3->input(1); + MS_EXCEPTION_IF_NULL(real_div2_anf); + auto real_div2 = real_div2_anf->cast(); + MS_EXCEPTION_IF_NULL(real_div2); + if (real_div2->inputs().size() < kRealDivInputNum) { + MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum; + } + auto sqrt0_anf = real_div2->input(2); + MS_EXCEPTION_IF_NULL(sqrt0_anf); + auto sqrt0 = sqrt0_anf->cast(); + MS_EXCEPTION_IF_NULL(sqrt0); + if (sqrt0->inputs().size() < kRsqrtInputNum) { + MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum; + } + auto add2_anf = sqrt0->input(1); + MS_EXCEPTION_IF_NULL(add2_anf); + auto add2 = add2_anf->cast(); + if (add2->inputs().size() < kAddInputNum) { + MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum; + } + return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2)); +} + +bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfNodePtr &real_div0, + const AnfNodePtr &real_div1, const AnfNodePtr &add2_y) { + if (node == nullptr || !node->isa()) { + return false; + } + auto add5 = node->cast(); + if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) { + return false; + } + auto real_div4_anf = add5->input(1); + if (real_div4_anf == nullptr || !real_div4_anf->isa()) { + return false; + } + auto real_div4 = real_div4_anf->cast(); + if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) { + return false; + } + auto add4_anf = real_div4->input(2); + if (add4_anf == nullptr || !add4_anf->isa()) 
+    return false;
+  }
+  auto add4 = add4_anf->cast<CNodePtr>();
+  if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) {
+    return false;
+  }
+  auto sqrt1_anf = add4->input(1);
+  if (sqrt1_anf == nullptr || !sqrt1_anf->isa<CNode>()) {
+    return false;
+  }
+  auto sqrt1 = sqrt1_anf->cast<CNodePtr>();
+  if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) {
+    return false;
+  }
+  return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 &&
+         *add4->input(2) == *add2_y;
+}
+
+std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Nodes(const AnfNodePtr &real_div0_anf, const AnfNodePtr &real_div1_anf) {
+  MS_EXCEPTION_IF_NULL(real_div0_anf);
+  MS_EXCEPTION_IF_NULL(real_div1_anf);
+  auto real_div0 = real_div0_anf->cast<CNodePtr>();
+  auto real_div1 = real_div1_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(real_div0);
+  MS_EXCEPTION_IF_NULL(real_div1);
+  if (real_div0->inputs().size() != kRealDivInputNum) {
+    MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size";
+  }
+  if (real_div1->inputs().size() != kRealDivInputNum) {
+    MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size";
+  }
+  return std::make_tuple(real_div0->input(1), real_div1->input(1));
+}
+} // namespace
+
+std::vector<AnfNodePtr> LambNextMVWithDecayV1Rule::GetFusionNodeInputs(const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto i0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
+  auto i1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
+  auto i2 = utils::cast<AnfNodePtr>((*equiv)[input2_]);
+  auto i3 = utils::cast<AnfNodePtr>((*equiv)[input3_]);
+  auto i4 = utils::cast<AnfNodePtr>((*equiv)[input4_]);
+  auto i5 = utils::cast<AnfNodePtr>((*equiv)[input5_]);
+  auto i6 = utils::cast<AnfNodePtr>((*equiv)[input6_]);
+  auto i7 = utils::cast<AnfNodePtr>((*equiv)[mul0_x_]);
+  auto i8 = utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]);
+  auto i9 = utils::cast<AnfNodePtr>((*equiv)[mul2_x_]);
+  auto i10 = utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]);
+  auto i11 = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]);
+  auto i12 = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
+  auto prim = std::make_shared<Primitive>(kLambNextMVWithDecayV1OpName);
+  return {NewValueNode(prim), i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12};
+}
+
+const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const {
+  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
+  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
+  VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_});
+  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
+  VectorRef add1({prim::kPrimTensorAdd, mul2, mul3});
+  VectorRef real_div1({prim_real_div, add1, input2_});
+  VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_});
+  VectorRef mul0({prim::kPrimMul, mul0_x_, input4_});
+  VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_});
+  VectorRef sqrt0({prim_rsqrt, add2});
+  VectorRef add0({prim::kPrimTensorAdd, mul0, mul1});
+  VectorRef real_div0({prim_real_div, add0, input5_});
+  VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0});
+  VectorRef mul4({prim::kPrimMul, mul4_x_, input6_});
+  VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4});
+  return add3;
+}
+
+const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const EquivPtr &equiv) const {
+  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
+    return nullptr;
+  }
+  if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
+    return nullptr;
+  }
+  AnfNodePtr mul4 = nullptr;
+  AnfNodePtr real_div0 = nullptr;
+  AnfNodePtr real_div1 = nullptr;
+  AnfNodePtr add2_y = nullptr;
+  std::tie(mul4, real_div0, real_div1, add2_y) = GetSharedNodes(node);
+
+  auto manager = func_graph->manager();
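+  // Unlike LambNextMVWithDecayRule above, this V1 variant registers only one pattern: the add5 tree and
+  // the nodes it shares with add3 are recovered by hand (GetSharedNodes/MatchAdd5Pattern) before the
+  // single fused LambNextMVWithDecayV1 node is built and its four outputs are spliced back into the graph.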
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(mul4) == manager->node_users().end()) {
+    MS_LOG(EXCEPTION) << "The Mul4 node should be used by at least one other node";
+  }
+  AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4];
+  auto iter = std::find_if(
+    mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(),
+    [&node, &mul4, &real_div0, &real_div1, &add2_y](const std::pair<AnfNodePtr, int> &node_index) {
+      return node_index.first != node && MatchAdd5Pattern(node_index.first, mul4, real_div0, real_div1, add2_y);
+    });
+  if (iter == mul4_output_node_index_set.end()) {
+    return nullptr;
+  }
+
+  std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
+  auto fusion_node = func_graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_scope(node->scope());
+
+  AnfNodePtr add0 = nullptr;
+  AnfNodePtr add1 = nullptr;
+  AnfNodePtr add5 = iter->first;
+  std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(add0, 0),
+                AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add5, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0), AnfAlgo::GetOutputInferShape(add0, 0),
+                 AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add5, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
+
+  std::vector<AnfNodePtr> fusion_node_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, fusion_node, kLambNextMVWithDecayV1OutputNum, &fusion_node_outputs);
+  if (fusion_node_outputs.size() != kLambNextMVWithDecayV1OutputNum) {
+    MS_LOG(ERROR) << "Failed to create multiple outputs for the fusion node";
+    return nullptr;
+  }
+
+  (void)manager->Replace(add0, fusion_node_outputs[1]);
+  (void)manager->Replace(add1, fusion_node_outputs[2]);
+  (void)manager->Replace(add5, fusion_node_outputs[3]);
+  return fusion_node_outputs[0];
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h
new file mode 100644
index 0000000000..58f05c37ba
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +class LambNextMVWithDecayV1Rule : public PatternProcessPass { + public: + explicit LambNextMVWithDecayV1Rule(bool multigraph = true) + : PatternProcessPass("lamb_next_mv_with_decay_v1_rule", multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + input6_ = std::make_shared(); + mul0_x_ = std::make_shared(); + mul1_sub_ = std::make_shared(); + mul2_x_ = std::make_shared(); + mul3_sub1_ = std::make_shared(); + mul4_x_ = std::make_shared(); + add2_y_ = std::make_shared(); + } + + ~LambNextMVWithDecayV1Rule() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr input6_; + VarPtr mul0_x_; + VarPtr mul1_sub_; + VarPtr mul2_x_; + VarPtr mul3_sub1_; + VarPtr mul4_x_; + VarPtr add2_y_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc new file mode 100644 index 0000000000..03bc1e0484 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h" +#include +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + std::vector new_node_inputs; + auto prim = std::make_shared(kLambNextRightOpName); + MS_EXCEPTION_IF_NULL(prim); + new_node_inputs.push_back(NewValueNode(prim)); + auto input0 = utils::cast((*equiv)[input0_]); + MS_EXCEPTION_IF_NULL(input0); + new_node_inputs.push_back(input0); + auto input1 = utils::cast((*equiv)[input1_]); + MS_EXCEPTION_IF_NULL(input1); + new_node_inputs.push_back(input1); + auto mul2_x = utils::cast((*equiv)[mul2_x_]); + MS_EXCEPTION_IF_NULL(mul2_x); + new_node_inputs.push_back(mul2_x); + auto mul3_x = utils::cast((*equiv)[mul3_x_]); + MS_EXCEPTION_IF_NULL(mul3_x); + new_node_inputs.push_back(mul3_x); + auto true_div1_recip = utils::cast((*equiv)[true_div1_recip_]); + MS_EXCEPTION_IF_NULL(true_div1_recip); + new_node_inputs.push_back(true_div1_recip); + auto add2_y = utils::cast((*equiv)[add2_y_]); + MS_EXCEPTION_IF_NULL(add2_y); + new_node_inputs.push_back(add2_y); + auto new_node = func_graph->NewCNode(new_node_inputs); + return new_node; +} + +const BaseRef LambNextRightRule::DefinePattern() const { + const auto prim_sqrt = std::make_shared(kSqrtOpName); + MS_EXCEPTION_IF_NULL(prim_sqrt); + VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); + VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); + return VectorRef( + {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); +} + +const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto new_node = CreateLambNextRightNode(func_graph, equiv); + MS_EXCEPTION_IF_NULL(new_node); + // Set abstract of new node + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add1 = utils::cast(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); + AbstractBasePtrList new_node_abstract_list; + new_node_abstract_list.push_back(add1->abstract()); + new_node_abstract_list.push_back(node->abstract()); + auto abstract_tuple = std::make_shared(new_node_abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + new_node->set_abstract(abstract_tuple); + // Create tuple_getitem node for outputs + std::vector new_node_outputs; + CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextRightOutputNum, &new_node_outputs); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + (void)manager->Replace(add1, new_node_outputs[0]); + return new_node_outputs[1]; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h new file mode 100644 index 0000000000..67687cc037 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +class LambNextRightRule : public PatternProcessPass { + public: + explicit LambNextRightRule(bool multigraph = true) + : PatternProcessPass("lamb_next_right_rule", multigraph), + input0_(std::make_shared()), + input1_(std::make_shared()), + mul2_x_(std::make_shared()), + mul3_x_(std::make_shared()), + true_div1_recip_(std::make_shared()), + add2_y_(std::make_shared()), + add1_var_(std::make_shared(std::make_shared(prim::kPrimTensorAdd->name()))) {} + + ~LambNextRightRule() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + AnfNodePtr CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; + + VarPtr input0_; + VarPtr input1_; + VarPtr mul2_x_; + VarPtr mul3_x_; + VarPtr true_div1_recip_; + VarPtr add2_y_; + VarPtr add1_var_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc new file mode 100644 index 0000000000..8e38c3cc2e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "common/utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +const BaseRef LambUpdateWithLRRuleFusion::DefinePattern() const { + auto real_div = std::make_shared(kRealDivOpName); + MS_EXCEPTION_IF_NULL(real_div); + auto greater = std::make_shared(kGreaterOpName); + MS_EXCEPTION_IF_NULL(greater); + + VectorRef pattern_real_div0({real_div, input1_, input2_}); + VectorRef pattern_greater0({greater, input0_, constant_greater_max_}); + VectorRef pattern_greater1({greater, input1_, constant_greater_max_}); + VectorRef pattern_select0({prim::kPrimSelect, pattern_greater0, pattern_real_div0, constant_select_}); + VectorRef pattern_select1({prim::kPrimSelect, pattern_greater1, pattern_select0, constant_select_}); + VectorRef pattern_minimum0({prim::kPrimMinimum, pattern_select1, constant_minimum_}); + VectorRef pattern_maximum0({prim::kPrimMaximum, pattern_minimum0, constant_greater_max_}); + VectorRef pattern_mul0({prim::kPrimMul, pattern_maximum0, input3_}); + VectorRef pattern_mul1({prim::kPrimMul, pattern_mul0, input4_}); + VectorRef pattern({prim::kPrimSub, input5_, pattern_mul1}); + return pattern; +} + +const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto input0 = utils::cast((*equiv)[input0_]); + auto input1 = utils::cast((*equiv)[input1_]); + auto input2 = utils::cast((*equiv)[input2_]); + auto input3 = utils::cast((*equiv)[input3_]); + auto input4 = utils::cast((*equiv)[input4_]); + auto input5 = utils::cast((*equiv)[input5_]); + auto input6 = utils::cast((*equiv)[constant_greater_max_]); + auto input7 = utils::cast((*equiv)[constant_select_]); + auto input8 = utils::cast((*equiv)[constant_minimum_]); + + auto prim = std::make_shared(kLambUpdateWithLROpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = { + NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8}; + auto lamb_update_with_lr = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(lamb_update_with_lr); + + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lamb_update_with_lr.get()); + lamb_update_with_lr->set_scope(node->scope()); + return lamb_update_with_lr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h new file mode 100644 index 0000000000..5ea01ccf65 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambUpdateWithLRRuleFusion : public PatternProcessPass { + public: + explicit LambUpdateWithLRRuleFusion(bool multigraph = true) + : PatternProcessPass("lamb_update_with_lr_rule_fusion", multigraph) { + input0_ = std::make_shared(); + input1_ = std::make_shared(); + input2_ = std::make_shared(); + input3_ = std::make_shared(); + input4_ = std::make_shared(); + input5_ = std::make_shared(); + constant_greater_max_ = std::make_shared(); + constant_select_ = std::make_shared(); + constant_minimum_ = std::make_shared(); + } + ~LambUpdateWithLRRuleFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input0_; + VarPtr input1_; + VarPtr input2_; + VarPtr input3_; + VarPtr input4_; + VarPtr input5_; + VarPtr constant_greater_max_; + VarPtr constant_select_; + VarPtr constant_minimum_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc new file mode 100644 index 0000000000..59511a611a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h" +#include +#include +#include +#include "utils/utils.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef LambUpdateWithLrV2::DefinePattern() const { + const auto prim_greater = std::make_shared(kGreaterOpName); + const auto prim_deal_div = std::make_shared(kRealDivOpName); + + VectorRef greater0({prim_greater, input_varptr_[0], input_varptr_[5]}); + VectorRef greater1({prim_greater, input_varptr_[1], input_varptr_[5]}); + VectorRef real_div0({prim_deal_div, input_varptr_[0], input_varptr_[1]}); + VectorRef select0({prim::kPrimSelect, greater1, real_div0, input_varptr_[6]}); + VectorRef select1({prim::kPrimSelect, greater0, select0, input_varptr_[6]}); + VectorRef mul0({prim::kPrimMul, select1, input_varptr_[2]}); + VectorRef mul1({prim::kPrimMul, mul0, input_varptr_[3]}); + + return VectorRef({prim::kPrimSub, input_varptr_[4], mul1}); +} + +const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + if (!CheckSupportDataType(node, kFloatDataTypeSet)) { + return nullptr; + } + auto prim = std::make_shared(kLambUpdateWithLrV2OpName); + std::vector inputs = {NewValueNode(prim)}; + (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), + [&equiv](const VarPtr &in) { return utils::cast((*equiv)[in]); }); + auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2); + lamb_update_with_lr_v2->set_abstract(node->abstract()); + + return lamb_update_with_lr_v2; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h new file mode 100644 index 0000000000..c5396178a5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class LambUpdateWithLrV2 : public PatternProcessPass { + public: + explicit LambUpdateWithLrV2(bool multigraph = true) : PatternProcessPass("lamb_update_with_lr_v2", multigraph) { + for (size_t i = 0; i < kLambUpdateWithLrV2InputNum - 1; ++i) { + input_varptr_.push_back(std::make_shared()); + } + } + ~LambUpdateWithLrV2() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector input_varptr_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc new file mode 100644 index 0000000000..fa1e92120d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc @@ -0,0 +1,162 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +using common::SafeCStr; +namespace { +void GetOutputCastNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, std::vector *cast_nodes) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (manager->node_users().find(node) == manager->node_users().end()) { + return; + } + for (const auto &node_index : manager->node_users()[node]) { + AnfNodePtr output = node_index.first; + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + if (AnfAlgo::GetCNodeName(output_cnode) != prim::kPrimTupleGetItem->name()) { + MS_LOG(EXCEPTION) << "The output of node " << node->DebugString() << " should be " + << prim::kPrimTupleGetItem->name(); + } + if (manager->node_users().find(output) == manager->node_users().end() || + manager->node_users()[output].size() != 1) { + continue; + } + AnfNodePtr transitive_output = manager->node_users()[output].begin()->first; + MS_EXCEPTION_IF_NULL(transitive_output); + auto transitive_output_cnode = transitive_output->cast(); + MS_EXCEPTION_IF_NULL(transitive_output_cnode); + if (AnfAlgo::GetCNodeName(transitive_output_cnode) == prim::kPrimCast->name()) { + cast_nodes->push_back(transitive_output_cnode); + } + } +} + +bool CheckKernelBuildInfo(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_info) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(kernel_info); + for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) { + if (kernel_info->GetInputDeviceType(i) != kNumberTypeFloat16 || + kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(cnode, i)) { + return false; + } + } + for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) { + if (kernel_info->GetOutputDeviceType(i) != kNumberTypeFloat32 || + kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(cnode, i)) { + return false; + } + } + return true; +} + +bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + std::vector *cast_nodes) { + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::HasNodeAttr(kAttrShapeGamma, cnode)) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; + return false; + } + if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " + << kLayerNormBetaGammaBackpropInputNum; + return false; + } + if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { + MS_LOG(INFO) << "The node " << cnode->DebugString() << " outputs num is not equal to " + << kLayerNormBetaGammaBackpropOutputNum; + return false; + } + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { + if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { + MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; + return false; + } + } + GetOutputCastNodes(func_graph, cnode, cast_nodes); + if (cast_nodes->size() != kLayerNormBetaGammaBackpropOutputNum) { + MS_LOG(INFO) << "The num of cast node in node " << cnode->DebugString() << " outputs is not equal to " + << kLayerNormBetaGammaBackpropOutputNum; + return false; + } + for (const auto &cast : *cast_nodes) { + if (AnfAlgo::GetInputDeviceDataType(cast, 0) != 
kNumberTypeFloat16 ||
+        AnfAlgo::GetOutputDeviceDataType(cast, 0) != kNumberTypeFloat32) {
+      MS_LOG(INFO) << "The cast " << cast->DebugString() << " should be fp16->fp32";
+      return false;
+    }
+  }
+  return true;
+}
+} // namespace
+
+const BaseRef LayerNormBetaGammaBackpropFusion::DefinePattern() const {
+  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
+  const auto prim = std::make_shared<Primitive>(kLayerNormBetaGammaBackpropOpName);
+  return VectorRef({prim, Xs});
+}
+
+const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                           const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  if (AnfAlgo::IsGraphKernel(node)) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<CNodePtr> cast_nodes;
+  if (!CheckLayernormBetaGammaBackprop(func_graph, cnode, &cast_nodes)) {
+    return nullptr;
+  }
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
+  MS_EXCEPTION_IF_NULL(kernel_query_);
+  kernel_query_->Query(cnode, &kernel_info_list);
+  auto alternative_kernel_build_info =
+    std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
+                 [&cnode](const kernel::KernelBuildInfoPtr &candidate_kernel_build_info) {
+                   return CheckKernelBuildInfo(cnode, candidate_kernel_build_info);
+                 });
+  if (alternative_kernel_build_info == kernel_info_list.end()) {
+    MS_LOG(INFO) << "Cannot find an alternative kernel build info for node " << node->DebugString();
+    return nullptr;
+  }
+  AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_build_info, cnode.get());
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  // The cast_nodes size has been checked above.
+  MS_EXCEPTION_IF_NULL(cast_nodes[0]);
+  MS_EXCEPTION_IF_NULL(cast_nodes[1]);
+  if (cast_nodes[0]->inputs().size() != kCastInputNum) {
+    MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum;
+  }
+  (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1));
+  if (cast_nodes[1]->inputs().size() != kCastInputNum) {
+    MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum;
+  }
+  (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1));
+  return nullptr;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h
new file mode 100644
index 0000000000..5bf1608143
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class LayerNormBetaGammaBackpropFusion : public PatternProcessPass { + public: + explicit LayerNormBetaGammaBackpropFusion(bool multigraph = true) + : PatternProcessPass("layer_norm_beta_gamma_backprop_fusion", multigraph), + kernel_query_(std::make_shared()) {} + + ~LayerNormBetaGammaBackpropFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + KernelQueryPtr kernel_query_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc new file mode 100644 index 0000000000..fdd390677a --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr size_t kMatMulInputIndex = 1; +constexpr size_t kBiasInputIndex = 2; +} // namespace + +const BaseRef MatmulBiasaddFusion::DefinePattern() const { + VarPtr X0 = std::make_shared(); + VarPtr X1 = std::make_shared(); + VarPtr X2 = std::make_shared(); + const auto prim_bias_add = std::make_shared(kBiasAddOpName); + return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2}); +} + +const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, kBiasAddInputNum); + AnfNodePtr matmul = cnode->input(kMatMulInputIndex); + MS_EXCEPTION_IF_NULL(matmul); + auto matmul_cnode = matmul->cast(); + MS_EXCEPTION_IF_NULL(matmul_cnode); + matmul_cnode->add_input(cnode->input(kBiasInputIndex)); + AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul); + return matmul; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h new file mode 100644 index 0000000000..8c762435a9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class MatmulBiasaddFusion : public PatternProcessPass { + public: + explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {} + + ~MatmulBiasaddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc new file mode 100644 index 0000000000..90c5ac19a9 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/optimizer/common/helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kAccumIndex = 1;
+bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<ValueNode>()) {
+    return false;
+  }
+  std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
+  return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
+}
+} // namespace
+
+const BaseRef MomentumLossscaleFusion::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr X0 = std::make_shared<Var>();
+  VarPtr X1 = std::make_shared<Var>();
+  VarPtr X2 = std::make_shared<Var>();
+  VarPtr X4 = std::make_shared<Var>();
+  return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4});
+}
+
+const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  CheckCNodeInputSize(cnode, kApplyMomentumInputNum);
+  AnfNodePtr mul = cnode->input(4);
+  MS_EXCEPTION_IF_NULL(mul);
+  auto mul_cnode = mul->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(mul_cnode);
+  CheckCNodeInputSize(mul_cnode, kMulInputNum);
+  size_t value_node_index = 0;
+  for (size_t i = 1; i < kMulInputNum; ++i) {
+    if (CheckValueNodeInputOfMul(mul_cnode->input(i))) {
+      value_node_index = i;
+      break;
+    }
+  }
+  if (value_node_index == 0) {
+    MS_LOG(DEBUG) << "The Mul " << mul->DebugString() << " to be fused must have a scalar constant input";
+    return nullptr;
+  }
+  auto new_prim = std::make_shared<Primitive>(kFusedMulApplyMomentumOpName);
+  std::vector<AnfNodePtr> new_node_inputs{NewValueNode(new_prim),
+                                          cnode->input(1),
+                                          cnode->input(2),
+                                          cnode->input(3),
+                                          mul_cnode->input(kMulInputNum - value_node_index),
+                                          cnode->input(5),
+                                          mul_cnode->input(value_node_index)};
+  auto new_node = func_graph->NewCNode(new_node_inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  AnfAlgo::CopyNodeAttrs(node, new_node);
+  auto input_names_value = AnfAlgo::GetNodeAttr<std::vector<std::string>>(new_node, kAttrInputNames);
+  input_names_value[3] = "x1";
+  input_names_value.emplace_back("x2");
+  AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
+  new_node->set_abstract(node->abstract());
+  new_node->set_scope(node->scope());
+  return new_node;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h
new file mode 100644
index 0000000000..8d36684a11
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h
new file mode 100644
index 0000000000..8d36684a11
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class MomentumLossscaleFusion : public PatternProcessPass {
+ public:
+  explicit MomentumLossscaleFusion(bool multigraph = true)
+      : PatternProcessPass("momentum_lossscale_fusion", multigraph) {}
+
+  ~MomentumLossscaleFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc
new file mode 100644
index 0000000000..2d766891a0
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(add);
+
+  for (size_t index = 1; index < add->size(); ++index) {
+    auto input = add->input(index);
+    MS_EXCEPTION_IF_NULL(input);
+    if (input->isa<CNode>()) {
+      auto cnode = input->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
+        if (!opt::IsUsedByOthers(graph, cnode)) {
+          auto full_name = cnode->fullname_with_scope();
+          // Exclude Adam and Lamb optimizers; this fusion only applies to BERT networks.
+          if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") ||
+              std::string::npos == full_name.find("bert")) {
+            MS_LOG(INFO) << "The Mul belongs to adam or lamb, or the network is not bert, skip fusion";
+            return false;
+          }
+
+          *mul = cnode;
+          *mul_index = index;
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+}  // namespace
+const BaseRef MulAddFusion::DefinePattern() const {
+  VarPtr x = std::make_shared<Var>();
+  VarPtr y = std::make_shared<Var>();
+  VectorRef pattern({prim::kPrimTensorAdd, x, y});
+  return pattern;
+}
+
+const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  if (graph == nullptr || node == nullptr) {
+    return nullptr;
+  }
+  auto add = node->cast<CNodePtr>();
+  if (add == nullptr || add->inputs().size() != kAddInputNum) {
+    return nullptr;
+  }
+  CNodePtr mul = nullptr;
+  size_t mul_index = 0;
+  if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
+    MS_LOG(DEBUG) << "Cannot find a Mul that is used only by this Add among the Add's inputs";
+    return nullptr;
+  }
+
+  auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  for (size_t index = 1; index < mul->size(); ++index) {
+    inputs.push_back(mul->input(index));
+  }
+  auto another_input_node = add->input(add->size() - mul_index);
+  if (another_input_node->isa<CNode>() &&
+      AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) {
+    MS_LOG(INFO) << "The Add's other input node has multiple outputs, do not fuse";
+    return nullptr;
+  }
+  inputs.push_back(another_input_node);
+  auto fusion_node = graph->NewCNode(inputs);
+  fusion_node->set_scope(add->scope());
+  fusion_node->set_abstract(add->abstract());
+  return fusion_node;
+}
+}  // namespace opt
+}  // namespace mindspore
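The effect of MulAddFusion, sketched (names illustrative); note that GetMul deliberately bails out for Adam/Lamb scopes, for non-BERT networks, and for Mul nodes with more than one user:

    // before: TensorAdd(Mul(x, y), z)   -- Mul used only by this Add
    // after:  FusedMulAdd(x, y, z)
    // The replacement keeps the Add's scope and abstract, so downstream
    // passes see the same output type and shape.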
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h
new file mode 100644
index 0000000000..0ad13e10e6
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class MulAddFusion : public PatternProcessPass {
+ public:
+  explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {}
+  ~MulAddFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc
new file mode 100644
index 0000000000..3567864e2f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn,
+                          const size_t &lossscale_input_index) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(mul);
+  MS_EXCEPTION_IF_NULL(addn);
+  auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+  inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
+  inputs.push_back(addn->input(2));
+  // The scalar (loss scale) input must be the third input of the fused node.
+  inputs.push_back(mul->input(lossscale_input_index));
+  auto fusion_node = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_scope(addn->scope());
+  fusion_node->set_abstract(addn->abstract());
+  return fusion_node;
+}
+}  // namespace
+
+const BaseRef MulAddNFusion::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Y = std::make_shared<Var>();
+  VarPtr Z = std::make_shared<Var>();
+
+  VectorRef mul({prim::kPrimMul, X, Z});
+  VectorRef addn({prim::kPrimAddN, mul, Y});
+  return addn;
+}
+
+const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                        const EquivPtr &equiv) const {
+  if (graph == nullptr || node == nullptr || equiv == nullptr) {
+    return nullptr;
+  }
+
+  auto addn = node->cast<CNodePtr>();
+  if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
+    return nullptr;
+  }
+  auto mul_anf = addn->input(1);
+  if (mul_anf == nullptr) {
+    return nullptr;
+  }
+  auto mul = mul_anf->cast<CNodePtr>();
+  if (mul == nullptr || mul->inputs().size() != kMulInputNum) {
+    return nullptr;
+  }
+  if (IsUsedByOthers(graph, mul)) {
+    MS_LOG(DEBUG) << "The Mul is used by more than one node, cannot fuse";
+    return nullptr;
+  }
+
+  size_t lossscale_input_index = 1;
+  for (size_t index = 1; index < mul->inputs().size(); ++index) {
+    auto input_node = mul->input(index);
+    MS_EXCEPTION_IF_NULL(input_node);
+    if (input_node->isa<ValueNode>()) {
+      lossscale_input_index = index;
+      break;
+    }
+  }
+  auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0);
+  if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) {
+    MS_LOG(DEBUG) << "The constant input of Mul must be a scalar or of shape (1,), but the shape size is "
+                  << constant_shape.size() << " and shape[0] is " << constant_shape[0];
+    return nullptr;
+  }
+
+  return CreateFusionNode(graph, mul, addn, lossscale_input_index);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h
new file mode 100644
index 0000000000..484cb75237
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class MulAddNFusion : public PatternProcessPass {
+ public:
+  explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {}
+  ~MulAddNFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
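MulAddNFusion is the AddN counterpart of the previous pass; a sketch of its rewrite (names illustrative):

    // before: AddN(Mul(x, loss_scale), y)   -- loss_scale is a scalar ValueNode
    // after:  FusedMulAddN(x, y, loss_scale)
    // The scalar constant is moved to the third input slot, as CreateFusionNode requires.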
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc
new file mode 100644
index 0000000000..9f44eb9d89
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h"
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "frontend/operator/ops.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
+                                std::vector<CNodePtr> *trans_road) {
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "The node to trace the trans-op road from is nullptr";
+    return nullptr;
+  }
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    auto op_name = AnfAlgo::GetCNodeName(cnode);
+    auto manager = func_graph->manager();
+    if (manager == nullptr) {
+      return nullptr;
+    }
+    if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() ||
+        op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) {
+      auto users = manager->node_users()[node];
+      if (users.size() > 1 && !first_flag) {
+        return nullptr;
+      }
+      trans_road->push_back(cnode);
+      first_flag = false;
+      auto next_node = AnfAlgo::GetInputNode(cnode, 0);
+      if (next_node->isa<Parameter>() || next_node->isa<ValueNode>()) {
+        return next_node;
+      }
+      return ParamTransRoad(func_graph, next_node, first_flag, trans_road);
+    }
+  } else if (node->isa<Parameter>() || node->isa<ValueNode>()) {
+    return node;
+  }
+  return nullptr;
+}
+
+kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type,
+                                              TypeId output_type) {
+  MS_EXCEPTION_IF_NULL(cast);
+  auto kernel_info = cast->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  auto cast_build_info = kernel_info->select_kernel_build_info();
+  MS_EXCEPTION_IF_NULL(cast_build_info);
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetOutputsFormat({format});
+  builder.SetInputsFormat({format});
+  builder.SetInputsDeviceType({input_type});
+  builder.SetOutputsDeviceType({output_type});
+  builder.SetKernelType(cast_build_info->kernel_type());
+  builder.SetFusionType(cast_build_info->fusion_type());
+  builder.SetProcessor(cast_build_info->processor());
+  return builder.Build();
+}
+}  // namespace
+bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Func graph is nullptr";
+    return false;
+  }
+  auto manager = func_graph->manager();
+  if (manager == nullptr) {
+    return false;
+  }
+  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
+  bool changed = false;
+  for (auto node : node_list) {
+    if (node == nullptr || !node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto node_name = AnfAlgo::GetCNodeName(cnode);
+    if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() ||
+        node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) {
+      MS_LOG(DEBUG) << "Skip the trans op itself";
+      continue;
+    }
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
+      std::vector<CNodePtr> trans_road;
+      bool first_flag = true;
+      auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road);
+      if (final_node != nullptr && trans_road.size() == 3 &&
+          AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName &&
+          AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() &&
+          AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) {
+        auto cur_transop = trans_road[0];
+        auto format = AnfAlgo::GetOutputFormat(cur_transop, 0);
+        auto dtype = AnfAlgo::GetOutputDeviceDataType(cur_transop, 0);
+        auto param_format = AnfAlgo::GetOutputFormat(final_node, 0);
+        auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
+
+        auto cast = trans_road[1];
+        if (param_format == format && param_dtype != dtype) {
+          AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
+          manager->Replace(trans_road[2], final_node);
+          manager->Replace(cur_transop, cast);
+        }
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h
new file mode 100644
index 0000000000..0479fd3d63
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class ParameterTransOpFusion : public Pass {
+ public:
+  explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {}
+  ~ParameterTransOpFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  size_t groups_ = 1;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_
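The pass above targets a fixed three-op "road" between a parameter and its consumer; sketched (names illustrative):

    // before: op(... TransData(Cast(TransData(param))) ...)
    // after:  op(... cast'(param) ...)
    // where cast' is rebuilt via GetKernelBuildInfo() to read the parameter's
    // format and device type directly, making both TransData nodes redundant.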
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc
new file mode 100644
index 0000000000..ebaa429ebf
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "frontend/operator/ops.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+void DoRefresh(const CNodePtr &cnode) {
+  if (cnode == nullptr) {
+    MS_LOG(EXCEPTION) << "The node to refresh is nullptr";
+  }
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
+    auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index);
+    if (input_kernel_node->isa<Parameter>()) {
+      std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
+        std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+      auto cnode_input_format = AnfAlgo::GetInputFormat(cnode, input_index);
+      auto kernel_node_format = AnfAlgo::GetOutputFormat(input_kernel_node, 0);
+      auto dtype = AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0);
+      if (kernel_node_format != cnode_input_format) {
+        builder->SetOutputsFormat({cnode_input_format});
+        builder->SetOutputsDeviceType({dtype});
+        AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
+      }
+    }
+  }
+}
+
+bool RefreshParameterFormat::Run(const FuncGraphPtr &func_graph) {
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "func_graph is nullptr.";
+    return false;
+  }
+  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
+  for (auto node : node_list) {
+    if (node == nullptr || !node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+    auto node_name = AnfAlgo::GetCNodeName(cnode);
+    if (node_name == kBNTrainingUpdateOpName) {
+      DoRefresh(cnode);
+    }
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h
new file mode 100644
index 0000000000..122bdf55ca
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class RefreshParameterFormat : public Pass {
+ public:
+  explicit RefreshParameterFormat(size_t groups = 1) : Pass("refresh_parameter_format"), groups_(groups) {}
+  ~RefreshParameterFormat() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  size_t groups_ = 1;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc
new file mode 100644
index 0000000000..6f48eabbc5
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef RemoveReshapePair::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  MS_EXCEPTION_IF_NULL(X);
+  return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})});
+}
+
+const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                            const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(reshape_op_1);
+  // If a Reshape operator is used by more than one other operator, it cannot be removed directly.
+  if (IsUsedByOthers(func_graph, reshape_op_1)) {
+    return nullptr;
+  }
+  auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(reshape_op_2);
+  if (IsUsedByOthers(func_graph, reshape_op_2)) {
+    return nullptr;
+  }
+  auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
+  auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
+  if (input_shape == output_shape) {
+    auto input_node = reshape_op_2->input(1);
+    return input_node;
+  }
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
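RemoveReshapePair, sketched (names illustrative): a back-to-back pair of Reshape ops whose overall device shapes match is dropped entirely, provided neither Reshape has any other user:

    // before: Reshape(Reshape(x))   -- input and output device shapes are equal
    // after:  x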
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h
new file mode 100644
index 0000000000..848713201a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_
+
+#include <memory>
+#include <string>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class RemoveReshapePair : public PatternProcessPass {
+ public:
+  explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {}
+  ~RemoveReshapePair() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc
new file mode 100644
index 0000000000..02a866930c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
+const BaseRef ReshapeTransposeFusion::DefinePattern() const {
+  const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
+  VectorRef reshape({prim_reshape, input_varptr_});
+
+  return VectorRef({prim::kPrimTranspose, reshape});
+}
+
+const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(transpose_cnode);
+  auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(reshape_cnode);
+  std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) {
+    return nullptr;
+  }
+  auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
+  auto new_node = func_graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_abstract(node->abstract());
+
+  AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
+  AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
+  AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node);
+  auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
+  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
+
+  return new_node;
+}
+}  // namespace opt
+}  // namespace mindspore
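ReshapeTransposeFusion, sketched (names illustrative): the Reshape+Transpose pair becomes one ConfusionTransposeD whose transpose_first attribute is false, guarded by the kCubeSize divisibility checks above:

    // before: Transpose(Reshape(x))
    // after:  ConfusionTransposeD(x)  with perm copied from Transpose,
    //         shape taken from Reshape, and transpose_first = false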
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class ReshapeTransposeFusion : public PatternProcessPass {
+ public:
+  explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) {
+    input_varptr_ = std::make_shared<Var>();
+  }
+  ~ReshapeTransposeFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input_varptr_;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc
new file mode 100644
index 0000000000..a3706bfb68
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.cc
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef SoftmaxGradExtFusion::DefinePattern() const {
+  VectorRef mul({prim::kPrimMul, input1_, input0_});
+  VectorRef sum({sum_var_, mul});
+  VectorRef sub({prim::kPrimSub, input0_, sum});
+  VectorRef mul1({prim::kPrimMul, input2_, input1_});
+  VectorRef mul_grad({prim::kPrimMul, mul1, sub});
+  return mul_grad;
+}
+
+const BaseRef SoftmaxGradExtFusionV2::DefinePattern() const {
+  VectorRef mul({prim::kPrimMul, input1_, input0_});
+  VectorRef sum({sum_var_, mul});
+  VectorRef sub({prim::kPrimSub, input0_, sum});
+  VectorRef mul1({prim::kPrimMul, input1_, sub});
+  VectorRef mul_grad({prim::kPrimMul, input2_, mul1});
+  return mul_grad;
+}
+
+const BaseRef SoftmaxGradExtFusionV3::DefinePattern() const {
+  VectorRef mul({prim::kPrimMul, input1_, input0_});
+  VectorRef sum({sum_var_, mul});
+  VectorRef sub({prim::kPrimSub, input0_, sum});
+  VectorRef mul1({prim::kPrimMul, input1_, sub});
+  VectorRef mul_grad({prim::kPrimMul, mul1, input2_});
+  return mul_grad;
+}
+
+const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                               const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  MS_EXCEPTION_IF_NULL(node);
+  auto input0 = GetAnfNodeByVar(equiv, input0_);
+  auto input1 = GetAnfNodeByVar(equiv, input1_);
+  auto input2 = GetAnfNodeByVar(equiv, input2_);
+  auto sum = GetAnfNodeByVar(equiv, sum_var_);
+  if (!GetBoolAttr(sum, kAttrKeepDims)) {
+    MS_LOG(INFO) << "The keep_dims attr of sum must be true for this fusion";
+    return nullptr;
+  }
+
+  auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
+  auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
+  MS_EXCEPTION_IF_NULL(fusion_node);
+  fusion_node->set_scope(node->scope());
+  fusion_node->set_abstract(node->abstract());
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node);
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node);
+  return fusion_node;
+}
+}  // namespace opt
+}  // namespace mindspore
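The three pattern variants above cover the operand orders in which the softmax-grad subgraph can appear; the common rewrite is, sketched (names illustrative):

    // before: Mul(Mul(input2, input1), Sub(input0, ReduceSum(Mul(input1, input0))))
    //         (plus the V2/V3 operand permutations)
    // after:  SoftmaxGradExt(input0, input1, input2)
    //         with keep_dims and axis copied from the ReduceSum, which must keep dims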
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_
+
+#include <memory>
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SoftmaxGradExtFusion : public PatternProcessPass {
+ public:
+  explicit SoftmaxGradExtFusion(const std::string &name = "softmax_grad_ext_fusion", bool multigraph = true)
+      : PatternProcessPass(name, multigraph) {
+    input0_ = std::make_shared<Var>();
+    input1_ = std::make_shared<Var>();
+    input2_ = std::make_shared<Var>();
+    sum_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
+  }
+  ~SoftmaxGradExtFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ protected:
+  VarPtr input0_;
+  VarPtr input1_;
+  VarPtr input2_;
+  VarPtr sum_var_;
+};
+
+class SoftmaxGradExtFusionV2 : public SoftmaxGradExtFusion {
+ public:
+  explicit SoftmaxGradExtFusionV2(bool multigraph = true)
+      : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v2", multigraph) {}
+  ~SoftmaxGradExtFusionV2() override = default;
+  const BaseRef DefinePattern() const override;
+};
+
+class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion {
+ public:
+  explicit SoftmaxGradExtFusionV3(bool multigraph = true)
+      : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v3", multigraph) {}
+  ~SoftmaxGradExtFusionV3() override = default;
+  const BaseRef DefinePattern() const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc
new file mode 100644
index 0000000000..67c881759a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc
@@ -0,0 +1,133 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h"
+
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "frontend/operator/ops.h"
+#include "backend/optimizer/common/helper.h"
+#include "runtime/device/kernel_info.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(square);
+  MS_EXCEPTION_IF_NULL(sum);
+  if (square->inputs().size() != kSquareNodeInputNum) {
+    MS_LOG(EXCEPTION) << "The Square node has a wrong input size";
+  }
+  auto prim = std::make_shared<Primitive>(kSquareSumV1OpName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> square_sumv1_inputs = {NewValueNode(prim), square->input(1)};
+  auto square_sumv1 = graph->NewCNode(square_sumv1_inputs);
+  MS_EXCEPTION_IF_NULL(square_sumv1);
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  square_sumv1->set_kernel_info(kernel_info);
+  auto types = {AnfAlgo::GetOutputInferDataType(sum, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv1.get());
+  square_sumv1->set_scope(sum->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1);
+  auto names = MakeValue<std::vector<std::string>>({square->fullname_with_scope(), sum->fullname_with_scope()});
+  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1);
+  return square_sumv1;
+}
+
+CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(square);
+  MS_EXCEPTION_IF_NULL(sum);
+  if (square->inputs().size() != kSquareNodeInputNum) {
+    MS_LOG(EXCEPTION) << "The Square node has a wrong input size";
+  }
+  auto prim = std::make_shared<Primitive>(kSquareSumV2OpName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> square_sumv2_inputs = {NewValueNode(prim), square->input(1)};
+  auto square_sumv2 = graph->NewCNode(square_sumv2_inputs);
+  MS_EXCEPTION_IF_NULL(square_sumv2);
+  auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv2.get());
+  square_sumv2->set_scope(sum->scope());
+  AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2);
+  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2);
+  auto names = MakeValue<std::vector<std::string>>({square->fullname_with_scope(), sum->fullname_with_scope()});
+  AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2);
+  return square_sumv2;
+}
+
+std::tuple<CNodePtr, AnfNodePtr, CNodePtr> GetPrevNodes(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto sum = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(sum);
+  if (sum->inputs().size() != kSumNodeInputNum) {
+    MS_LOG(EXCEPTION) << "The ReduceSumD node has a wrong input size";
+  }
+  auto square_anf = sum->input(1);
+  MS_EXCEPTION_IF_NULL(square_anf);
+  auto square = square_anf->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(square);
+
+  return std::make_tuple(sum, square_anf, square);
+}
+}  // namespace
+
+const BaseRef SquareSumFusion::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  MS_EXCEPTION_IF_NULL(X);
+  return VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimSquare, X})});
+}
+
+const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                          const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  CNodePtr sum = nullptr;
+  AnfNodePtr square_anf = nullptr;
+  CNodePtr square = nullptr;
+  std::tie(sum, square_anf, square) = GetPrevNodes(node);
+
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (manager->node_users().find(square_anf) == manager->node_users().end()) {
+    MS_LOG(EXCEPTION) << "The Square node has no output in the NodeUsersMap";
+  }
+  AnfNodePtr ret_node = nullptr;
+  if (manager->node_users()[square_anf].size() == 1) {
+    ret_node = GenerateSquareSumV1(graph, square, sum);
+  } else if (manager->node_users()[square_anf].size() == 2) {
+    auto square_sumv2 = GenerateSquareSumV2(graph, square, sum);
+
+    std::vector<AnfNodePtr> square_sumv2_outputs;
+    CreateMultipleOutputsOfAnfNode(graph, square_sumv2, kSquareSumv2OutputNum, &square_sumv2_outputs);
+    if (square_sumv2_outputs.size() != kSquareSumv2OutputNum) {
+      MS_LOG(EXCEPTION) << "Failed to create the outputs of SquareSumV2";
+    }
+    (void)manager->Replace(square, square_sumv2_outputs[1]);
+    ret_node = square_sumv2_outputs[0];
+  }
+  return ret_node;
+}
+}  // namespace opt
+}  // namespace mindspore
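SquareSumFusion picks one of two fused kernels depending on how many users the Square node has; sketched (names illustrative):

    // Square used only by ReduceSum:  ReduceSum(Square(x))  ->  SquareSumV1(x)
    // Square used by one extra node:  SquareSumV2(x) with two outputs, where
    //   output 0 replaces the ReduceSum and output 1 replaces the Square itself.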
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h
new file mode 100644
index 0000000000..54189606ba
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SquareSumFusion : public PatternProcessPass {
+ public:
+  explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {}
+  ~SquareSumFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc
new file mode 100644
index 0000000000..46bf2a8604
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
+  if (shape.empty()) {
+    return false;
+  }
+  if (shape.size() == 1 && shape[0] % kCubeSize != 0) {
+    return false;
+  }
+  return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
+}
+}  // namespace
+
+const BaseRef TransposeReshapeFusion::DefinePattern() const {
+  const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
+  VectorRef transpose({prim::kPrimTranspose, input_varptr_});
+
+  return VectorRef({prim_reshape, transpose});
+}
+
+const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                 const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(reshape_cnode);
+  auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
+  MS_EXCEPTION_IF_NULL(transpose_cnode);
+  std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
+  std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
+  if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
+    return nullptr;
+  }
+  auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
+  auto new_node = func_graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(new_node);
+
+  new_node->set_abstract(node->abstract());
+  AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node);
+  AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node);
+  AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node);
+  auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
+  AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node);
+
+  return new_node;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h
new file mode 100644
index 0000000000..39b8fe4687
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class TransposeReshapeFusion : public PatternProcessPass {
+ public:
+  explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) {
+    input_varptr_ = std::make_shared<Var>();
+  }
+  ~TransposeReshapeFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input_varptr_;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc
new file mode 100644
index 0000000000..b6da588e89
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h"
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef TransposeTransDataFusion::DefinePattern() const {
+  const auto prim_transdata = std::make_shared<Primitive>(prim::KPrimTransData->name());
+  VectorRef transpose({prim::kPrimTranspose, input_varptr_});
+
+  return VectorRef({prim_transdata, transpose});
+}
+
+const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                   const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum);
+  MS_EXCEPTION_IF_NULL(transdata_cnode);
+  auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum);
+  MS_EXCEPTION_IF_NULL(transpose_cnode);
+  auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode);
+  auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode);
+  MS_EXCEPTION_IF_NULL(transpose_kernel_build_info);
+  MS_EXCEPTION_IF_NULL(transdata_kernel_build_info);
+
+  auto new_transdata_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  auto transpose_input_formats = transpose_kernel_build_info->GetAllInputFormats();
+  new_transdata_builder->SetInputsFormat(transpose_input_formats);
+  new_transdata_builder->SetOutputsFormat(transdata_kernel_build_info->GetAllOutputFormats());
+  new_transdata_builder->SetInputsDeviceType(transdata_kernel_build_info->GetAllInputDeviceTypes());
+  new_transdata_builder->SetOutputsDeviceType(transdata_kernel_build_info->GetAllOutputDeviceTypes());
+  new_transdata_builder->SetKernelType(transdata_kernel_build_info->kernel_type());
+  new_transdata_builder->SetFusionType(transdata_kernel_build_info->fusion_type());
+  new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
+
+  auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
+  if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
+    std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
+                                      utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
+    auto new_node = func_graph->NewCNode(inputs);
+    MS_EXCEPTION_IF_NULL(new_node);
+    new_node->set_abstract(node->abstract());
+    AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node);
+    AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(transpose_input_formats[0]), new_node);
+    AnfAlgo::SetSelectKernelBuildInfo(new_transdata_builder->Build(), new_node.get());
+    MS_LOG(INFO) << "Fused Transpose into TransData for node " << node->fullname_with_scope() << " successfully";
+    return new_node;
+  } else {
+    MS_LOG(INFO) << "Transpose-TransData fusion for node " << node->fullname_with_scope()
+                 << " is not supported by AICore, keep the original nodes";
+    return node;
+  }
+}
+}  // namespace opt
+}  // namespace mindspore
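TransposeTransDataFusion, sketched (names illustrative): when AICore supports a TransData taking the Transpose's input format directly, the pair collapses into a single TransData whose src_format comes from the Transpose input; otherwise both nodes are kept:

    // before: TransData(Transpose(x))
    // after:  TransData'(x)  with src_format = input format of Transpose
    //         (only if supported_checker_->CheckAICoreSupported(...) holds)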
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "ir/anf.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class TransposeTransDataFusion : public PatternProcessPass {
+ public:
+  explicit TransposeTransDataFusion(bool multigraph = true)
+      : PatternProcessPass("transpose_transdata_fusion", multigraph) {
+    input_varptr_ = std::make_shared<Var>();
+    supported_checker_ = std::make_shared<SupportedChecker>();
+  }
+  ~TransposeTransDataFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input_varptr_;
+  SupportedCheckerPtr supported_checker_;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
new file mode 100644
index 0000000000..887b9a76a1
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+  ~TransposeTransDataFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr input_varptr_;
+  SupportedCheckerPtr supported_checker_;
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
new file mode 100644
index 0000000000..887b9a76a1
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/common/common_backend_optimization.h"
+#include <memory>
+#include <string>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/pass/convert_const_input_to_attr.h"
+#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h"
+#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h"
+#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h"
+#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h"
+#include "utils/context/ms_context.h"
+#include "debug/anf_ir_dump.h"
+
+namespace mindspore {
+namespace opt {
+void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id();
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool save_graphs = context_ptr->save_graphs_flag();
+  auto save_graphs_path = context_ptr->save_graphs_path();
+  if (save_graphs_path.empty()) {
+    save_graphs_path = ".";
+  }
+  if (save_graphs) {
+    std::string file_path =
+      save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
+    DumpIR(file_path, kernel_graph);
+  }
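+  // (Added sketch, not part of the original patch) A typical caller invokes
+  // this right after kernel graph construction, e.g.
+  //   opt::BackendCommonOptimization(kernel_graph);
+  // The pass manager built below rewrites the graph in place, so the
+  // hwopt_common_before/after dumps bracket exactly what these passes changed.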
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto common_pm = std::make_shared<PassManager>("common_pm");
+  common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
+  common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
+  common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
+  common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
+  common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
+  optimizer->AddPassManager(common_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+  if (save_graphs) {
+    std::string file_path =
+      save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir";
+    DumpIR(file_path, kernel_graph);
+  }
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h
new file mode 100644
index 0000000000..4127fc05de
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
+#include <memory>
+#include "backend/session/kernel_graph.h"
+namespace mindspore {
+namespace opt {
+void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc
new file mode 100644
index 0000000000..d21cabe54a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/common/fusion_id_allocator.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; }
+
+FusionIdAllocator::~FusionIdAllocator() {}
+
+void FusionIdAllocator::Init() { fusion_id = 0; }
+
+int32_t FusionIdAllocator::AllocateFusionId() {
+  fusion_id++;
+  return fusion_id;
+}
+
+bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode);
+}
+
+int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
+  if (HasFusionIdAttr(node)) {
+    return AnfAlgo::GetNodeAttr<int32_t>(node, kAttrFusionId);
+  }
+  return -1;
+}
+
+void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) {
+  ValuePtr fusion_id_v = MakeValue(id);
+  AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h b/mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h
similarity index 100%
rename from mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h
rename to mindspore/ccsrc/backend/optimizer/common/fusion_id_allocator.h
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
new file mode 100644
index 0000000000..266130c6b1
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -0,0 +1,785 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/common/helper.h" +#include +#include +#include +#include +#include +#include +#include +#include "utils/utils.h" +#include "utils/base_ref.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "common/utils.h" +#include "runtime/device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +constexpr size_t kType32Len = 4; +std::vector Convert2Int(const std::vector &v) { + std::vector result; + (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt); + return result; +} + +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + std::vector node_list = TopoSort(graph->get_return()); + std::map> control_depend_map; + for (auto &nd : node_list) { + MS_EXCEPTION_IF_NULL(nd); + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + auto control_depend = nd->cast(); + auto prior_node = control_depend->input(kControlDependPriorIndex); + auto behind_node = control_depend->input(kControlDependBehindIndex); + auto it = control_depend_map.find(behind_node); + if (it == control_depend_map.end()) { + control_depend_map[behind_node] = std::set{prior_node}; + } else { + it->second.insert(prior_node); + } + } + } + + FuncGraphManagerPtr manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + std::unordered_set seen_node; + std::deque todo{node1}; + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + + if (node == node2) { + return true; + } + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); + } + auto it = control_depend_map.find(node); + if (it != control_depend_map.end()) { + (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); + } + } + return false; +} + +bool UnVisited(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + if (IsValueNode(in)) { + auto value_node = in->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto prim_py = value->cast(); + MS_EXCEPTION_IF_NULL(prim_py); + return !prim_py->HasAttr(kAttrVisited); + } else if (IsValueNode(in)) { + auto func_graph = GetValueNode(in); + MS_EXCEPTION_IF_NULL(func_graph); + return !func_graph->has_flag(kAttrVisited); + } + return false; + } + return false; +} + +bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(ERROR) << "The node is expected to be a cnode"; + return false; + } + *cnode = node->cast(); + if (*cnode == nullptr) { + return false; + } + if ((*cnode)->inputs().size() < IntToSize(input_size)) { + auto op_name = AnfAlgo::GetCNodeName(*cnode); + MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs."; + return false; + } + return true; +} + +CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "The node is 
expected to be a cnode"; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != IntToSize(input_size)) { + auto op_name = AnfAlgo::GetCNodeName(cnode); + MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; + } + return cnode; +} + +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != input_size) { + MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; + } +} + +bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) { + MS_EXCEPTION_IF_NULL(node_x); + MS_EXCEPTION_IF_NULL(node_y); + return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) && + AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0)); +} + +const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + + auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); + MS_EXCEPTION_IF_NULL(transop_cnode); + auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); + auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); + MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); + MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); + auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); + MS_EXCEPTION_IF_NULL(transed_node); + + std::vector replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, + depend_cnode->input(kDependInputNum - 1)}; + AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); + MS_EXCEPTION_IF_NULL(replace_depend); + auto transed_abstract = transed_node->abstract(); + replace_depend->set_abstract(transed_abstract); + return replace_depend; +} + +bool Visited(const BaseRef &n) { + if (utils::isa(n)) { + AnfNodePtr in = utils::cast(n); + MS_EXCEPTION_IF_NULL(in); + if (IsValueNode(in)) { + auto value_node = in->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto prim_py = value->cast(); + MS_EXCEPTION_IF_NULL(prim_py); + return prim_py->HasAttr(kAttrVisited); + } else if (IsValueNode(in)) { + auto func_graph = GetValueNode(in); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->has_flag(kAttrVisited); + } + return false; + } + return false; +} + +void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, + std::vector *conv_bn1_outputs) { + auto prim = std::make_shared(kConvBN1OpName); + std::vector conv_bn1_inputs = {NewValueNode(prim)}; + MS_EXCEPTION_IF_NULL(conv_cnode); + // All the inputs of conv_bn1 are from the inputs of conv + for (size_t i = 1; i < conv_cnode->inputs().size(); i++) { + conv_bn1_inputs.push_back(conv_cnode->input(i)); + } + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs); + MS_EXCEPTION_IF_NULL(conv_bn1_cnode); + auto kernel_info = std::make_shared(); + conv_bn1_cnode->set_kernel_info(kernel_info); + // Set attr for conv_bn1 + AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode); + // Set abstract of conv_bn1 + MS_EXCEPTION_IF_NULL(bn_cnode); + auto bn_abstract_tuple = dyn_cast(bn_cnode->abstract()); + MS_EXCEPTION_IF_NULL(bn_abstract_tuple); + 
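+  // (Added comment) Build the fused node's abstract as a 3-tuple matching the
+  // ConvBn1Output enum: the conv output (kData), a float32 tensor for the
+  // variance part (kVarPart), and save_mean taken from BN's abstract (kMean).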
AbstractBasePtrList conv_bn1_abstract_list; + conv_bn1_abstract_list.push_back(conv_cnode->abstract()); + auto abstract_tensor = std::make_shared( + kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1))); + conv_bn1_abstract_list.push_back(abstract_tensor); + conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]); + auto abstract_tuple = std::make_shared(conv_bn1_abstract_list); + conv_bn1_cnode->set_abstract(abstract_tuple); + + CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs); +} + +void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, + const CNodePtr &bn_node, std::vector *fused_bn2_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(bn_node); + MS_EXCEPTION_IF_NULL(fused_bn2_outputs); + if (bn_node->inputs().size() != kBnInputNum) { + MS_LOG(EXCEPTION) << "BN node has wrong input size"; + } + if (fused_bn1_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; + } + + // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn + std::vector fused_bn2_inputs = {NewValueNode(std::make_shared(kFusedBN2OpName))}; + fused_bn2_inputs.push_back(fused_bn1_outputs[0]); + fused_bn2_inputs.push_back(fused_bn1_outputs[1]); + fused_bn2_inputs.push_back(bn_node->input(4)); + fused_bn2_inputs.push_back(bn_node->input(5)); + auto fused_bn2 = graph->NewCNode(fused_bn2_inputs); + MS_EXCEPTION_IF_NULL(fused_bn2); + auto kernel_info = std::make_shared(); + fused_bn2->set_kernel_info(kernel_info); + auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1), + AnfAlgo::GetOutputInferDataType(bn_node, 2)}; + auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1), + AnfAlgo::GetOutputInferShape(bn_node, 2)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get()); + fused_bn2->set_scope(bn_node->scope()); + AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2); + + CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs); +} + +void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, + const std::vector &fused_bn1_outputs, + const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, + std::vector *fused_bn3_outputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(data_input); + MS_EXCEPTION_IF_NULL(bn_node); + MS_EXCEPTION_IF_NULL(fused_bn3_outputs); + if (bn_node->inputs().size() != kBnInputNum) { + MS_LOG(EXCEPTION) << "BN node has wrong input size"; + } + + if (fused_bn1_outputs.size() != kBN1OutputNum) { + MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; + } + + if (fused_bn2_outputs.size() != kBN2OutputNum) { + MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size"; + } + + // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn + std::vector fused_bn3_inputs = {NewValueNode(std::make_shared(kFusedBN3OpName))}; + fused_bn3_inputs.push_back(data_input); + fused_bn3_inputs.push_back(fused_bn1_outputs[0]); + fused_bn3_inputs.push_back(fused_bn2_outputs[0]); + fused_bn3_inputs.push_back(bn_node->input(2)); + fused_bn3_inputs.push_back(bn_node->input(3)); + auto fused_bn3 = graph->NewCNode(fused_bn3_inputs); + MS_EXCEPTION_IF_NULL(fused_bn3); + auto kernel_info = std::make_shared(); + fused_bn3->set_kernel_info(kernel_info); + auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)}; + auto 
shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get()); + + fused_bn3->set_scope(bn_node->scope()); + AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3); + + (*fused_bn3_outputs).push_back(fused_bn3); +} + +void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, + std::vector *outputs) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(outputs); + for (size_t i = 0; i < output_num; i++) { + auto idx = NewValueNode(SizeToInt(i)); + MS_EXCEPTION_IF_NULL(idx); + int temp = SizeToInt(i); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)}, + {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get()); + (*outputs).push_back(tuple_getitem); + } +} + +template +tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, + size_t data_length) { + MS_EXCEPTION_IF_NULL(value_tuple_ptr); + MS_EXCEPTION_IF_NULL(type_ptr); + std::vector values; + for (const auto &v : value_tuple_ptr->value()) { + MS_EXCEPTION_IF_NULL(v); + if (v->isa()) { + ScalarPtr scalar = v->cast(); + values.push_back(GetValue(scalar)); + } else { + MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + } + std::vector tensor_shape = {SizeToInt(values.size())}; + tensor::TensorPtr tensor = std::make_shared(type_ptr->type_id(), tensor_shape); + MS_EXCEPTION_IF_NULL(tensor); + tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; + tensor->set_device_info(device_info); + auto data_ptr = tensor->data_c(); + MS_EXCEPTION_IF_NULL(data_ptr); + auto elem_num = values.size() * data_length; + auto ret_code = memcpy_s(data_ptr, static_cast(tensor->data().nbytes()), values.data(), elem_num); + if (ret_code != 0) { + MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; + } + return tensor; +} + +tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { + MS_EXCEPTION_IF_NULL(value_tuple); + tensor::TensorPtr tensor = nullptr; + if (value_tuple->value().empty()) { + MS_LOG(WARNING) << "The value tuple is empty."; + return nullptr; + } + ValuePtr v = *(value_tuple->value().begin()); + MS_EXCEPTION_IF_NULL(v); + // Currently we only deal with the scalar tuple + if (!v->isa()) { + MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; + return nullptr; + } + ScalarPtr scalar = v->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kInt32, kType32Len); + } else if (scalar->isa()) { + tensor = CreateTensorWithValueTuple(value_tuple, kFloat32, kType32Len); + } else { + auto type = scalar->type(); + auto type_str = (type == nullptr) ? 
"nullptr" : type->ToString(); + MS_LOG(ERROR) << "Invalid scalar type: " << type_str; + return nullptr; + } + return tensor; +} + +bool IsNopNode(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { + return false; + } + static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, + prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), + kFlattenGradOpName}; + if (node == nullptr || !node->isa()) { + return false; + } + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) { + return false; + } + return true; +} + +bool IsAllNopNode(const session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + auto execution_order = graph->execution_order(); + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsNopNode(cnode)) { + return false; + } + } + return true; +} + +void HideNopNode(session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + if (IsAllNopNode(graph) == true) { + return; + } + auto execution_order = graph->execution_order(); + MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); + std::vector new_nodes; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsNopNode(cnode)) { + new_nodes.push_back(cnode); + } + } + graph->set_execution_order(new_nodes); + MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size(); +} + +void RemoveNopNode(session::KernelGraph *const graph) { + MS_EXCEPTION_IF_NULL(graph); + if (IsAllNopNode(graph) == true) { + return; + } + bool changed = true; + while (changed) { + changed = false; + std::vector new_nodes; + for (auto &cnode : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(cnode); + // ignore nop node itself + if (IsNopNode(cnode)) { + continue; + } + // Replace the input which is nop node + std::vector new_inputs; + new_inputs.push_back(cnode->input(0)); + bool need_update = false; + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto input = cnode->input(i); + MS_EXCEPTION_IF_NULL(input); + auto cinput = input->cast(); + if (cinput == nullptr || !IsNopNode(cinput)) { + new_inputs.push_back(input); + continue; + } + if (cinput->inputs().size() == 2) { + new_inputs.push_back(cinput->input(1)); + need_update = true; + changed = true; + } else { + new_inputs.push_back(input); + } + } + if (need_update) { + cnode->set_inputs(new_inputs); + } + // push into new execution list + new_nodes.push_back(cnode); + } + graph->set_execution_order(new_nodes); + } +} + +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node) { + auto output_node_list = std::make_shared>>(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto iter = manager->node_users().find(node); + if (iter == manager->node_users().end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + auto output_info_list = iter->second; + for (const auto &output_info : output_info_list) { + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { + continue; + } + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && + output_info.second == kDependAttachNodeIndex) { + continue; + } + output_node_list->push_back(output_info); + } + return output_node_list; 
+} + +bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_node_list = GetRealNodeUsedList(graph, node); + MS_EXCEPTION_IF_NULL(output_node_list); + return output_node_list->size() > 1; +} + +AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs; + std::vector new_input_names; + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto input_names = primitive->GetAttr(kAttrInputNames); + if (input_names == nullptr) { + MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; + return; + } + auto input_names_vec = GetValue>(input_names); + auto inputs = cnode->inputs(); + new_inputs.push_back(inputs[0]); + bool need_update = false; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + auto input_node = inputs[i + 1]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; + if (i >= input_names_vec.size()) { + MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; + } + primitive->set_attr(input_names_vec[i], value_node->value()); + need_update = true; + } else { + new_inputs.push_back(input_node); + if (i < input_names_vec.size()) { + new_input_names.push_back(input_names_vec[i]); + } + } + } + if (need_update) { + // Update cnode's inputs + cnode->set_inputs(new_inputs); + // Update cnode's input_names attr + primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); + } +} + +bool AnfEqual(const BaseRef &a, const BaseRef &b) { + if (utils::isa(a) && utils::isa(b)) { + auto a_node = utils::cast(a); + auto b_node = utils::cast(b); + MS_EXCEPTION_IF_NULL(a_node); + MS_EXCEPTION_IF_NULL(b_node); + if (IsValueNode(a_node) && IsValueNode(b_node)) { + auto a_value_node = a_node->cast(); + MS_EXCEPTION_IF_NULL(a_value_node); + auto a_value = a_value_node->value(); + MS_EXCEPTION_IF_NULL(a_value); + auto a_prim = a_value->cast(); + MS_EXCEPTION_IF_NULL(a_prim); + + auto b_value_node = b_node->cast(); + MS_EXCEPTION_IF_NULL(b_value_node); + auto b_value = b_value_node->value(); + MS_EXCEPTION_IF_NULL(b_value); + auto b_prim = b_value->cast(); + MS_EXCEPTION_IF_NULL(b_prim); + + return a_prim->name() == b_prim->name(); + } else if (a_node->isa() && b_node->isa()) { + auto a_value_node_ptr = a_node->cast(); + if (a_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto 
a_value_ptr = a_value_node_ptr->value(); + if (a_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + auto b_value_node_ptr = b_node->cast(); + if (b_value_node_ptr == nullptr) { + MS_LOG(EXCEPTION) << "cast value node ptr fail"; + } + auto b_value_ptr = b_value_node_ptr->value(); + if (b_value_ptr == nullptr) { + MS_LOG(EXCEPTION) << "value ptr is nullptr"; + } + + return (*a_value_ptr) == (*b_value_ptr); + } + MS_LOG(DEBUG) << "check AnfNodePtr equal"; + } + if (utils::isa(a) && utils::isa(b)) { + MS_LOG(DEBUG) << "check GraphPtr equal"; + } + return a == b; +} + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { + // To matchCNode and Kernel's type + if (utils::isa(a) && utils::isa(b)) { + return true; + } + return a.type() == b.type(); +} + +namespace { +ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + if (utils::isa(sexp)) { + return NewValueNode(utils::cast(sexp)); + } + return nullptr; +} + +CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + if (utils::isa(graph)) { + return std::make_shared(input_nodes, utils::cast(graph)); + } + return nullptr; +} + +VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); + return std::make_shared(utils::cast(sexp), nullptr); + } + if (utils::isa(graph)) { + MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); + return std::make_shared(utils::cast(sexp), utils::cast(graph)); + } + MS_LOG(ERROR) << "VarNode, should input a Var in graph. 
It's " + graph.ToString(); + return nullptr; +} + +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { + MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + std::vector input_nodes; + const auto &tuple = utils::cast(sexp); + if (multigraph && utils::isa(graph)) { + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); + input_nodes.push_back(node); + } + VarPtr var_ptr = utils::cast(graph); + return std::make_shared(input_nodes, var_ptr); + } + + for (auto &x : tuple) { + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); + input_nodes.push_back(node); + } + return CreateCNodeWithGraph(input_nodes, graph); +} +} // namespace + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { + MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); + if (utils::isa(sexp)) { + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); + } + if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } + return CreateVarNodeWithSexp(sexp, graph); + } + if (utils::isa(sexp)) { + return utils::cast(sexp); + } + auto value_node = CreateValueNodeWithSexp(sexp); + if (value_node == nullptr) { + MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); + } + return value_node; +} + +bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv1); + MS_EXCEPTION_IF_NULL(equiv2); + MS_EXCEPTION_IF_NULL(var_node); + auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); + MS_EXCEPTION_IF_NULL(equiv1_node); + auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); + MS_EXCEPTION_IF_NULL(equiv2_node); + return *equiv1_node == *equiv2_node; +} + +AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { + MS_EXCEPTION_IF_NULL(equiv); + MS_EXCEPTION_IF_NULL(var_node); + auto iter = (*equiv).find(var_node); + if (iter == (*equiv).end()) { + MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; + return nullptr; + } + auto res = utils::cast(iter->second); + if (res == nullptr) { + MS_LOG(EXCEPTION) << "Cast fail! 
Maybe var is not a anf node"; + } + return res; +} + +bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { + MS_EXCEPTION_IF_NULL(n1); + MS_EXCEPTION_IF_NULL(n2); + auto n1_cnode = n1->cast(); + auto n2_cnode = n2->cast(); + MS_EXCEPTION_IF_NULL(n1_cnode); + MS_EXCEPTION_IF_NULL(n2_cnode); + auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_input1); + auto value_node1 = index_input1->cast(); + MS_EXCEPTION_IF_NULL(value_node1); + auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(index_input2); + auto value_node2 = index_input2->cast(); + MS_EXCEPTION_IF_NULL(value_node2); + return GetValue(value_node1->value()) < GetValue(value_node2->value()); +} + +bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(INFO) << "node is not a cnode"; + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); +} + +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { + MS_EXCEPTION_IF_NULL(node); + TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); + if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { + return true; + } + MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h new file mode 100644 index 0000000000..a267e65b53 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -0,0 +1,199 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ + +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "backend/session/kernel_graph.h" +#include "common/utils.h" +#include "backend/optimizer/common/pattern_engine.h" + +namespace mindspore { +namespace opt { +constexpr size_t kTransOpInputNum = 2; +constexpr size_t kCastInputNum = 2; +constexpr size_t kDependInputNum = 3; +constexpr size_t kReluInputNum = 2; +constexpr size_t kReluGradInputNum = 3; +constexpr size_t kAddInputNum = 3; +constexpr size_t kAddNInputNum = 3; +constexpr size_t kTupleGetitemInputNum = 3; +constexpr size_t kConvInputNum = 3; +constexpr size_t kRealDivInputNum = 3; +constexpr size_t kSqrtInputNum = 2; +constexpr size_t kMulInputNum = 3; +constexpr size_t kRsqrtInputNum = 2; +constexpr size_t kSubInputNum = 3; +constexpr size_t kAssignSubInputNum = 3; + +constexpr size_t kConvBn1OutputNum = 3; +constexpr size_t kBn2ReluOutputNum = 4; + +constexpr size_t kBnInputNum = 6; +constexpr size_t kBnOutputNum = 5; +constexpr size_t kBatchNormInputNum = 5; +constexpr size_t kBatchNormOutputNum = 5; + +constexpr size_t kBN1OutputNum = 2; +constexpr size_t kBN2OutputNum = 3; +constexpr size_t kBN3OutputNum = 1; + +constexpr size_t kBNGradInputNum = 6; +constexpr size_t kBNGradOutputNum = 3; + +constexpr size_t kBNGrad1OutputNum = 3; +constexpr size_t kBNGrad2OutputNum = 5; +constexpr size_t kBNGrad3OutputNum = 1; + +constexpr size_t kBNTrainingReduceOutputNum = 2; +constexpr size_t kBNTrainingUpdateOutputNum = 5; +constexpr size_t kBNTrainingUpdateV2OutputNum = 3; +constexpr size_t kBNTrainingUpdateV3OutputNum = 5; +constexpr size_t kBNTrainingUpdateGradOutputNum = 2; + +constexpr size_t kSingleOutputNum = 1; +constexpr size_t kSumNodeInputNum = 2; +constexpr size_t kSquareNodeInputNum = 2; +constexpr size_t kSquareSumv2OutputNum = 2; +constexpr size_t kMinimumInputNum = 3; + +constexpr size_t kLambNextMVWithDecayInputNum = 7; +constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; +constexpr size_t kLambNextMVWithDecayOutputNum = 4; +constexpr size_t kLambNextMVWithDecayV1OutputNum = 4; +constexpr size_t kLambNextRightOutputNum = 2; +constexpr size_t kLambUpdateWithLrV2InputNum = 8; +constexpr size_t kLambNextMVRuleInputNum = 14; +constexpr size_t kLambNextMVRuleOutputNum = 4; +constexpr size_t kBackendReshapeInputNum = 2; +constexpr size_t kBackendTransposeInputNum = 2; +constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; +constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; +constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; +constexpr size_t kLayerNormGradInputNum = 6; +constexpr size_t kAdamApplyOneOutputNum = 3; +constexpr size_t kBackendTransDataInputNum = 2; +constexpr size_t kApplyMomentumInputNum = 6; +constexpr size_t kBiasAddInputNum = 3; +constexpr size_t kTopkInputNum = 3; +constexpr size_t kLarsV2InputNum = 5; +constexpr size_t kFusedMulApplyMomentumOutputNum = 2; +constexpr size_t kSplitInputNum = 2; + +enum FusedBatchNormInput { + kX = 1, + kVariance = 5, +}; +enum FusedBatchNormOutput { + kY = 0, + kRunningMean, + kRunningVariance, + kSaveMean, + kSaveInvVariance, +}; +enum ConvBn1Output { + kData = 0, + kVarPart, + kMean, +}; + +std::vector Convert2Int(const std::vector &v); + +// check whether node1 depends on node2 or not +bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); + +bool UnVisited(const BaseRef &n); + +bool Visited(const 
BaseRef &n); + +// check if the input node is CNode, then check it's input_size, if meet condition above, return true, otherwise return +// false. cnode can only be used when return true. +bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode); + +// check if the input node is CNode, then check it's input_size, return CNodePtr if check success. +CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size); + +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); + +bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); + +const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node); + +void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, + std::vector *conv_bn1_outputs); + +void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, + const CNodePtr &bn_node, std::vector *fused_bn2_outputs); +void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, + const std::vector &fused_bn1_outputs, + const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, + std::vector *fused_bn3_outputs); + +void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num, + std::vector *outputs); + +tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, + size_t data_length); + +tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); + +bool IsAllNopNode(const session::KernelGraph *const graph); + +bool IsNopNode(const AnfNodePtr &node); + +void HideNopNode(session::KernelGraph *const graph); + +void RemoveNopNode(session::KernelGraph *const graph); + +AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); + +bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); + +std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, + const AnfNodePtr &node); + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); + +bool AnfEqual(const BaseRef &a, const BaseRef &b); + +bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); + +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph = false); + +// Check var_node in two equivs is the same node +bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); + +// Get anf_node from equiv by var_node +AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); + +// Compare tuple getitem's index, return bool[n1's index < n2's index] +bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); + +// Get attr which is bool from cnode +bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); + +// Check node's data type is in supported data type set +bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/node_pass.cc b/mindspore/ccsrc/backend/optimizer/common/node_pass.cc new file mode 100644 index 0000000000..16f5284a57 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/node_pass.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/common/node_pass.h"
+
+#include <unordered_set>
+#include <deque>
+#include <algorithm>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "ir/manager.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+bool NodePass::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(func_graph);
+
+  std::unordered_set<AnfNodePtr> seen_node;
+  std::deque<AnfNodePtr> todo{func_graph->output()};
+  bool changes = false;
+  while (!todo.empty()) {
+    AnfNodePtr node = todo.front();
+    todo.pop_front();
+    if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
+      continue;
+    }
+    (void)seen_node.insert(node);
+    AnfNodePtr new_node = Run(func_graph, node);
+    bool change = (new_node != nullptr);
+    if (new_node != nullptr && new_node != node) {
+      (void)manager->Replace(node, new_node);
+      (void)seen_node.erase(node);
+    } else if (new_node == nullptr) {
+      new_node = node;
+    }
+    if (new_node && IsValueNode<FuncGraph>(new_node)) {
+      auto const_func_graph = GetValueNode<FuncGraphPtr>(new_node);
+      MS_EXCEPTION_IF_NULL(const_func_graph);
+      if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
+        todo.push_back(const_func_graph->output());
+      }
+    } else if (new_node && new_node->isa<CNode>()) {
+      if (AnfAlgo::IsGraphKernel(new_node)) {
+        todo.push_back(new_node);
+      }
+      auto cnode = new_node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      auto inputs = cnode->inputs();
+      (void)todo.insert(todo.end(), inputs.begin(), inputs.end());
+    }
+    changes = changes || change;
+  }
+  return changes;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/common/node_pass.h b/mindspore/ccsrc/backend/optimizer/common/node_pass.h
new file mode 100644
index 0000000000..780ae1a056
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/node_pass.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ +#include +#include + +#include "backend/optimizer/common/pass.h" + +namespace mindspore { +namespace opt { +// @brief ANF Node level optimization base pass +class NodePass : public Pass { + public: + explicit NodePass(const std::string &name) : Pass(name) {} + ~NodePass() override = default; + bool Run(const FuncGraphPtr &func_graph) final; + virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; +}; +using NodePassPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/optimizer.cc b/mindspore/ccsrc/backend/optimizer/common/optimizer.cc new file mode 100644 index 0000000000..01e9111e86 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/optimizer.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/common/optimizer.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "backend/optimizer/common/pass_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/manager.h" + +namespace mindspore { +namespace opt { +PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) + : NodePass(name), + multigraph_(multigraph), + pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + primitive_vars_(std::make_shared()) {} + +const BaseRef PatternProcessPass::DefinePattern() const { + VarPtr X = std::make_shared(); + return BaseRef({X}); +} + +void PatternProcessPass::Build() { + VarPtr fg = std::make_shared("RootG"); + BaseRef pattern = std::move(DefinePattern()); + pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); +} + +AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (pattern_ == nullptr) { + Build(); + } + + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(primitive_vars_); + EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); + if (equiv != nullptr && !equiv->empty()) { + return Process(func_graph, node, equiv); + } + return nullptr; +} + +bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + VarPtr fg = std::make_shared("RootG"); + auto empty_equiv = std::make_shared(); + MS_EXCEPTION_IF_NULL(child_primitive_vars_); + EquivPtr another_equiv = + child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, + *child_primitive_vars_, empty_equiv); + if (another_equiv != nullptr && !another_equiv->empty()) { + return IsShareNodes(equiv, another_equiv); + } + return false; +} + +void 
GraphOptimizer::AddPassManager(const PassManagerPtr &pass_manager) { + if (pass_manager != nullptr) { + pass_managers_.push_back(pass_manager); + } +} + +FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { + MS_EXCEPTION_IF_NULL(func_graph); + run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; + // Performance risk by creating new manager each time + auto manager = Manage(func_graph, true); + + bool changed = true; + while (changed) { + changed = false; + for (size_t i = 0; i < pass_managers_.size(); ++i) { + const PassManagerPtr &pm = pass_managers_[i]; + if (pm != nullptr && pm->Run(func_graph)) { + changed = true; + } + } + if (run_only_once_) { + break; + } + } + + std::vector func_graphs; + func_graphs.push_back(func_graph); + manager->KeepRoots(func_graphs); + (void)TopoSort(func_graph->get_return()); + return func_graph; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/optimizer.h b/mindspore/ccsrc/backend/optimizer/common/optimizer.h new file mode 100644 index 0000000000..0b03c9c0ee --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/optimizer.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ + +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "utils/graph_utils.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +using PatternListType = std::initializer_list; + +class PatternProcessPass : public NodePass { + public: + explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); + ~PatternProcessPass() override = default; + virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; + virtual const BaseRef DefinePattern() const; + AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + + private: + void Build(); + + AnfNodePtr pattern_ = nullptr; + bool multigraph_ = true; + PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; +}; + +class MultipleOutputPatternProcessPass : public PatternProcessPass { + public: + explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) + : PatternProcessPass(name, multigraph), + child_pattern_engine_(PatternEngine(std::make_shared(), + std::function(AnfEqual), + std::function(CNodeTypeEqual))), + child_primitive_vars_(std::make_shared()) {} + ~MultipleOutputPatternProcessPass() override = default; + virtual BaseRef DefineAnotherPattern() const = 0; + // check two patterns whether share the same nodes or not + virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; + + protected: + bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; + PatternEngine child_pattern_engine_; + PrimitiveVarMapPtr child_primitive_vars_; +}; + +class GraphOptimizer { + public: + explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} + virtual ~GraphOptimizer() = default; + + void AddPassManager(const PassManagerPtr &pass_manager); + FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); + + private: + const std::string name_ = "graph_optimizer"; + std::vector pass_managers_{}; + bool run_only_once_ = true; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pass.h b/mindspore/ccsrc/backend/optimizer/common/pass.h new file mode 100644 index 0000000000..6e35fb1dc4 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_
+#include <memory>
+#include <string>
+
+#include "ir/anf.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+// @brief ANF Graph level optimization base pass
+class Pass {
+ public:
+  explicit Pass(const std::string &name = "pass") : name_(name) {}
+  virtual ~Pass() = default;
+  virtual bool Run(const FuncGraphPtr &func_graph) = 0;
+  virtual std::string name() const { return name_; }
+
+ private:
+  const std::string name_;
+};
+using PassPtr = std::shared_ptr<Pass>;
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_
diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc
new file mode 100644
index 0000000000..f9f41237e0
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.cc
@@ -0,0 +1,102 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/common/pass_manager.h"
+
+#include <sys/time.h>
+#include <chrono>
+#include <deque>
+#include <string>
+#include <algorithm>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "ir/manager.h"
+#include "utils/utils.h"
+#include "utils/context/ms_context.h"
+#include "debug/anf_ir_dump.h"
+
+namespace mindspore {
+namespace opt {
+const std::vector<PassPtr> &PassManager::Passes() const { return passes_; }
+
+void PassManager::AddPass(const PassPtr &pass) {
+  if (pass != nullptr) {
+    passes_.push_back(pass);
+  }
+}
+
+bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr> &passes) const {
+  if (func_graph == nullptr) {
+    return false;
+  }
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool save_graphs = context_ptr->save_graphs_flag();
+  auto save_graphs_path = context_ptr->save_graphs_path();
+  if (save_graphs_path.empty()) {
+    save_graphs_path = ".";
+  }
+  bool changed = false;
+  size_t num = 0;
+  for (const auto &pass : passes) {
+    if (pass != nullptr) {
+#if defined(_WIN32) || defined(_WIN64)
+      auto start_time = std::chrono::steady_clock::now();
+#else
+      struct timeval start_time {};
+      struct timeval end_time {};
+      (void)gettimeofday(&start_time, nullptr);
+#endif
+      if (pass->Run(func_graph)) {
+        changed = true;
+      }
+#if defined(_WIN32) || defined(_WIN64)
+      auto end_time = std::chrono::steady_clock::now();
+      std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
+      MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us";
+#else
+      (void)gettimeofday(&end_time, nullptr);
+      const uint64_t kUSecondInSecond = 1000000;
+      uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+      cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+      MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us";
+#endif
+      if (save_graphs) {
+        auto dump_file_path =
+          save_graphs_path +
"/" + "hwopt_" + name() + "_" + std::to_string(num) + "_" + pass->name() + ".ir"; + DumpIR(dump_file_path, func_graph); + } + num++; + } + } + return changed; +} + +bool PassManager::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/pass_manager.h b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h new file mode 100644 index 0000000000..51db27d250 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pass_manager.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ + +#include +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "backend/optimizer/common/node_pass.h" + +namespace mindspore { +namespace opt { +// @brief For optimization passes management +class PassManager { + public: + explicit PassManager(const std::string &name = "pm", bool run_only_once = true) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassManager() = default; + // Get all the passes added by AddPass + const std::vector &Passes() const; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PassPtr &pass); + // Run passes added in pass manager on the input graph + // @param [inout] graph The graph to be optimized + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph) const; + // Run the given graph passes on the input graph + // @param [inout] graph The graph to be optimized + // @param [in] passes The given graph passes + // @return true, graph changed + // @return false, graph not changed + bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; + std::string name() const { return name_; } + + private: + const std::string name_; + std::vector passes_; + bool run_only_once_; +}; +using PassManagerPtr = std::shared_ptr; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc new file mode 100644 index 0000000000..bd4efd82ef --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.cc @@ -0,0 +1,360 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/common/pattern_engine.h" + +#include +#include +#include +#include + +#include "frontend/optimizer/opt.h" + +#include "ir/anf.h" +#include "utils/convert_utils_base.h" +#include "utils/overload.h" + +namespace mindspore { +static int GetNextTag() { + static int kID = 0; + return kID++; +} + +void Var::EnsureTag() { + if (tag_.length() == 0) { + std::ostringstream buffer; + buffer << "_" << GetNextTag(); + tag_ = buffer.str(); + } +} + +bool operator==(const VarPtr &lhs, const VarPtr &rhs) { + if (lhs->isa() && rhs->isa()) { + CondVarPtr v1 = dyn_cast(lhs); + CondVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + + if (lhs->isa() && rhs->isa()) { + SVarPtr v1 = dyn_cast(lhs); + SVarPtr v2 = dyn_cast(rhs); + return *v1 == *v2; + } + return (*lhs == *rhs); +} + +std::string SeqVar::ToString() const { + std::ostringstream buffer; + buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; + return buffer.str(); +} + +std::ostream &operator<<(std::ostream &os, const VarPtr &var) { + if (var == nullptr) { + os << ""; + } else { + os << var->ToString(); + } + return os; +} + +template <> +std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { + os << "[Equiv]" + << "\n"; + for (auto &equiv_item : equiv) { + auto k = equiv_item.first; + os << k << ":"; + BaseRef x = equiv_item.second; + if (utils::isa(x)) { + auto node = utils::cast(x); + os << "TypeString[" << node->type_name() << "]"; + if (IsValueNode(node)) { + os << "IsValueNodeGraph "; + } + os << "type " << node->type_name(); + if (node->isa()) { + os << " value " << GetValueNode(node); + } + os << " addr: " << node; + } else if (utils::isa(x)) { + os << "Named " << x.ToString().c_str(); + } else if (utils::isa(x)) { + os << "TypeString[Var]"; + os << utils::cast(x); + } else if (utils::isa(x)) { + os << "TypeString[Graph]"; + } + os << "\n"; + } + return os; +} + +static BaseRef GetVar(const BaseRef &x) { + MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); + if (utils::isa(x)) { + auto node = utils::cast(x); + MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; + if (node->isa()) { + MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); + return node->cast()->var_; + } + if (node->isa()) { + MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); + } else { + MS_LOG(DEBUG) << "type " + node->type_name(); + } + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "Named " + x.ToString(); + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "VectorRef"; + } else if (utils::isa(x)) { + MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); + } + MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); + return x; +} + +EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { + MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); + MS_EXCEPTION_IF_NULL(equiv); + if (utils::isa(pattern)) { + VarPtr var = utils::cast(pattern); + if (var->matches(expr)) { + (*equiv)[var] = expr; + MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); + return 
equiv; + } + } + + return nullptr; +} + +bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + if (utils::isa(pattern_ref)) { + *values_pattern = pattern_ref; + *values_expr = expr_ref; + return true; + } + return false; +} + +bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { + MS_EXCEPTION_IF_NULL(values_expr); + // visitor to visite the list + auto appender_pattern = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(GetVar(u)); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_pattern(*values_pattern)); + MS_LOG(DEBUG) << "visit pattern_ref"; + bool success = visitor_->Visit(pattern_ref, nullptr); + if (!success) { + return false; + } + + auto appender_expr = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { + values.push_back(u); + return u; + }; + return fn; + }; + + visitor_->SetFn(appender_expr(*values_expr)); + MS_LOG(DEBUG) << "visit expr_ref"; + return visitor_->Visit(expr_ref, nullptr); +} + +static int GetSVarStartIndex(const VectorRef &values) { + int index = -1; + int count = 0; + for (auto &value : values) { + if (utils::isa(value) && utils::cast(value)->isa()) { + if (index != -1) { + MS_LOG(DEBUG) << "Multiple SVars in sequence"; + return kInvalidVarIndex; + } + index = count; + } + count++; + } + return index; +} + +void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) { + if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || + !utils::isa(expr_ref)) { + return; + } + auto real_node = utils::cast(expr_ref); + MS_EXCEPTION_IF_NULL(real_node); + if (!real_node->isa()) { + return; + } + auto prim_node = utils::cast(values_pattern[0]); + MS_EXCEPTION_IF_NULL(prim_node); + if (!IsValueNode(prim_node)) { + return; + } + ValuePtr value = GetValueNode(prim_node); + MS_EXCEPTION_IF_NULL(value); + auto prim = value->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto iter = primitive_vars.find(prim); + if (iter == primitive_vars.end()) { + return; + } + (*equiv)[iter->second] = real_node; +} + +EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { + int svar_index = GetSVarStartIndex(values_pattern); + if (svar_index == kInvalidVarIndex) { + return nullptr; + } + + size_t values_pattern_len = values_pattern.size(); + size_t values_expr_len = values_expr.size(); + + if (svar_index == -1) { + if (values_pattern_len != values_expr_len) { + MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len " + << values_expr_len; + return nullptr; + } + } + if (values_expr_len < values_pattern_len - 1) { + MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; + return nullptr; + } + size_t diff = values_expr_len - values_pattern_len + 1; + for (size_t i = 0; i < values_pattern_len; i++) { + size_t expr_i = i; + if (svar_index != -1 && i == IntToSize(svar_index)) { + auto seq = + std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); + equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); + } else { + if (svar_index != -1 && 
i > IntToSize(svar_index)) { + expr_i = i + diff - 1; + } + equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); + } + if (equiv == nullptr) { + return nullptr; + } + } + return equiv; +} + +EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const { + MS_LOG(DEBUG) << "-----[in Match]"; + MS_LOG(DEBUG) << "GetVar w"; + BaseRef pattern_ref = GetVar(pattern); + MS_LOG(DEBUG) << "GetVar v"; + BaseRef expr_ref = expr; + + if (equiv == nullptr) { + MS_LOG(EXCEPTION) << "Equiv pointer is null"; + } + + MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); + // 1. if pattern_ref is var and already in equiv, replace it. + if (utils::isa(pattern_ref)) { + VarPtr var = utils::cast(pattern_ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + pattern_ref = iter->second; + } + } + + // 2. check equal + if (eq_(pattern_ref, expr_ref)) { + return equiv; + } + + // 3. match var + EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); + if (ret_equiv) { + return ret_equiv; + } + + // 4. here the type can be std:vector, std:list, + // or cnode. + if (!type_eq_(pattern_ref, expr_ref)) { + MS_LOG(DEBUG) << "Type mismatch"; + return nullptr; + } + + // 5. transfer the Containers by visitor to std::vector + VectorRef values_pattern; + VectorRef values_expr; + if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { + return nullptr; + } + + // 6. if any svar in both side, find the SeqVar index, + // try to pack the Var s in std::vector to a Seq and match elements one by one. + // check svar + equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); + UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); + return equiv; +} + +BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(equiv); + MS_LOG(DEBUG) << "-----[in Replace]"; + BaseRef ref = GetVar(pattern); + BaseRef out; + bool is_match = false; + + // w is var + if (utils::isa(ref)) { + const VarPtr &var = utils::cast(ref); + auto iter = equiv->find(var); + if (iter != equiv->end()) { + out = iter->second; + is_match = true; + } + } + if (is_match) { + return out; + } + + // visitor to visit the list + std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; + + visitor_->SetFn(fn); + BaseRef visit_out; + if (!visitor_->Visit(pattern, &visit_out)) { + return pattern; + } + return visit_out; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h new file mode 100644 index 0000000000..51fa8801b2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/common/pattern_engine.h @@ -0,0 +1,204 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_
+
+#include <algorithm>
+#include <functional>
+#include <initializer_list>
+#include <iostream>
+#include <list>
+#include <map>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "backend/optimizer/common/visit.h"
+#include "base/base.h"
+#include "utils/log_adapter.h"
+#include "utils/base_ref.h"
+
+namespace mindspore {
+class CondVar;
+class SeqVar;
+using CondVarPtr = std::shared_ptr<CondVar>;
+using SVarPtr = std::shared_ptr<SeqVar>;
+const int kInvalidVarIndex = -2;
+
+using ConditionFunc = std::function<bool(const BaseRef &)>;
+
+// Base wildcard variable which could match any anf node.
+class Var : public Base {
+  friend class VarHasher;
+
+ public:
+  explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); }
+  explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) {
+    EnsureTag();
+  }
+  Var(const Var &other) : Base(other), tag_(other.tag_) {}
+  virtual Var &operator=(const Var &other) {
+    if (&other == this) {
+      return *this;
+    }
+    this->tag_ = other.tag_;
+    return *this;
+  }
+  ~Var() override = default;
+  MS_DECLARE_PARENT(Var, Base);
+
+  virtual bool matches(const BaseRef &) { return true; }
+
+  virtual bool operator==(const Var &other) const { return tag_ == other.tag_; }
+  bool operator!=(const Var &other) const { return !(&other == this); }
+
+  std::string tag() const { return tag_; }
+  PrimitivePtr primitive() const { return primitive_; }
+  std::string ToString() const override {
+    std::ostringstream buffer;
+    buffer << "Var(" << tag_ << ")";
+    return buffer.str();
+  }
+  std::size_t hash() const override { return std::hash<std::string>()(tag_); }
+
+ protected:
+  void EnsureTag();
+
+  std::string tag_;
+  PrimitivePtr primitive_;
+};
+
+// VarNode means variable node, a subclass of AnfNode
+class VarNode : public AnfNode {
+ public:
+  VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {}
+  ~VarNode() override = default;
+  MS_DECLARE_PARENT(VarNode, AnfNode);
+
+  const VarPtr var_;
+};
+using VarNodePtr = std::shared_ptr<VarNode>;
+
+class VarHasher {
+ public:
+  std::size_t operator()(const Var &var) const { return var.hash(); }
+};
+
+// Condition Var, match an anf node when condition function return true.
+class CondVar : public Var {
+ public:
+  explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
+  ~CondVar() override = default;
+  MS_DECLARE_PARENT(CondVar, Var);
+  bool matches(const BaseRef &value) override {
+    MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
+    if (utils::isa<VarPtr>(value)) {
+      return false;
+    }
+    return cond_fn_(value);
+  }
+  ConditionFunc cond_fn_;
+};
+
+using Seq = VectorRef;
+using SeqPtr = std::shared_ptr<Seq>;
+
+// Sequence Var which could match multiple consecutive input nodes of a CNode.
+class SeqVar : public Var {
+ public:
+  SeqVar() { subvar_ = std::make_shared<Var>(); }
+  ~SeqVar() override = default;
+  MS_DECLARE_PARENT(SeqVar, Var);
+  explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
+  bool matches(const BaseRef &value) override {
+    // match Seq.
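+    // A SeqVar stands for a run of consecutive inputs: it matches a whole
+    // Seq (a VectorRef) when every element satisfies subvar_->matches();
+    // the default SeqVar (subvar_ is a plain Var) accepts any sequence.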
+    if (utils::isa<Seq>(value)) {
+      const Seq &seq = utils::cast<Seq>(value);
+      return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
+        auto eq = subvar_->matches(v);
+        return eq;
+      });
+    }
+    return false;
+  }
+  bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
+  std::string ToString() const override;
+
+ private:
+  VarPtr subvar_;
+};
+
+bool operator==(const VarPtr &lhs, const VarPtr &rhs);
+
+inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
+
+std::ostream &operator<<(std::ostream &os, const VarPtr &var);
+
+using Equiv = std::map<VarPtr, BaseRef>;
+using EquivPtr = std::shared_ptr<Equiv>;
+using PrimitiveVarMap = std::unordered_map<PrimitivePtr, VarPtr>;
+using PrimitiveVarMapPtr = std::shared_ptr<PrimitiveVarMap>;
+
+inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }
+
+class PatternEngine {
+ public:
+  PatternEngine(const std::shared_ptr<Visitor> &visitor,
+                const std::function<bool(const BaseRef &, const BaseRef &)> &eq,
+                const std::function<bool(const BaseRef &, const BaseRef &)> &type_eq = DefaultTypeEq)
+      : visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
+  ~PatternEngine() = default;
+
+  EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
+                 EquivPtr equiv) const;
+  // Replace pattern with equivalent
+  BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;
+
+ private:
+  EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
+                     const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
+  bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
+                VectorRef *const values_expr) const;
+  bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
+                VectorRef *const values_expr) const;
+  std::shared_ptr<Visitor> visitor_;
+  std::function<bool(const BaseRef &, const BaseRef &)> eq_;
+  std::function<bool(const BaseRef &, const BaseRef &)> type_eq_;
+};
+}  // namespace mindspore
+namespace std {
+using mindspore::ERROR;
+using mindspore::LogStream;
+using mindspore::NoExceptionType;
+template <>
+struct hash<mindspore::VarPtr> {
+  std::size_t operator()(const mindspore::VarPtr var) const {
+    if (var == nullptr) {
+      MS_LOG(ERROR) << "Invalid var ptr";
+      return 0;
+    }
+    return std::hash<std::string>{}(var->tag());
+  }
+};
+}  // namespace std
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_
diff --git a/mindspore/ccsrc/backend/optimizer/common/visit.cc b/mindspore/ccsrc/backend/optimizer/common/visit.cc
new file mode 100644
index 0000000000..d0b52609f8
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/common/visit.cc
@@ -0,0 +1,166 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/optimizer/common/visit.h" + +#include +#include +#include +#include + +#include "backend/optimizer/common/pattern_engine.h" +#include "utils/any.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "utils/log_adapter.h" + +/* namespace to support utils definition */ +namespace mindspore { +bool CheckIfNeedExpand(const std::vector &list) { + return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); +} + +std::shared_ptr ExpandList(const std::vector &list) { + std::shared_ptr new_list = std::make_shared(); + for (auto &item : list) { + if (utils::isa(item)) { + const Seq &seq = utils::cast(item); + new_list->insert(new_list->end(), seq.begin(), seq.end()); + } else { + new_list->push_back(item); + } + } + return new_list; +} + +bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { + std::vector out; + (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), + [this](const BaseRef &item) { return fn_(item); }); + if (visit_out != nullptr) { + *visit_out = ExpandList(out); + } + return true; +} + +bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { + if (utils::isa(any)) { + return Visit(utils::cast(any), visit_out); + } else if (utils::isa(any)) { + auto nodeptr = utils::cast(any); + AnfNodePtr output; + AnfNodePtr *p_output = &output; + if (visit_out == nullptr) { + p_output = nullptr; + } + Visit(nodeptr, fn_, p_output); + if (visit_out != nullptr) { + *visit_out = output; + } + return true; + } + MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); + return false; +} + +void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (node->isa()) { + Visit(node->cast(), fn, output); + return; + } + + if (output != nullptr) { + *output = node; + } +} + +void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { + // if output is nullptr, it's not required to make the new CNode node. 
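+  // In that case fn is only applied to each input (and to the func graph, or
+  // its var for a pattern CNode) for its side effects. Otherwise a new CNode
+  // is rebuilt from the results of fn, expanding any Seq a SeqVar produced.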
+ if (output == nullptr) { + for (auto &inp : cnode->inputs()) { + (void)fn(inp); + } + + if (cnode->func_graph() != nullptr) { + (void)fn(cnode->func_graph()); + } else { + (void)fn(cnode->func_graph_as_var()); + } + return; + } + + std::vector new_inputs; + std::vector after_cnode_fn; + std::shared_ptr out; + (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); + if (CheckIfNeedExpand(after_cnode_fn)) { + out = ExpandList(after_cnode_fn); + } + + std::vector &outs = after_cnode_fn; + if (out != nullptr) { + outs = out->elements(); + } + + for (auto &any_item : outs) { + if (!utils::isa(any_item)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; + } + new_inputs.push_back(utils::cast(any_item)); + } + + BaseRef any_fg; + AnfNodePtr new_cnode = nullptr; + if (cnode->func_graph() != nullptr) { + any_fg = fn(cnode->func_graph()); + if (!utils::isa(any_fg)) { + MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; + } + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + any_fg = fn(cnode->func_graph_as_var()); + if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else if (utils::isa(any_fg)) { + new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); + } else { + MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; + } + } + new_cnode->set_abstract(cnode->abstract()); + *output = new_cnode; +} + +void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { + const BaseRef &value = utils::cast(fn(vnode->value())); + if (utils::isa(value)) { + if (output != nullptr) { + auto ct = NewValueNode(utils::cast(value)); + ct->set_abstract(vnode->abstract()); + *output = ct; + } + return; + } + MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/visit.h b/mindspore/ccsrc/backend/optimizer/common/visit.h similarity index 100% rename from mindspore/ccsrc/pre_activate/common/visit.h rename to mindspore/ccsrc/backend/optimizer/common/visit.h diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc new file mode 100644 index 0000000000..41e4abee27 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "backend/optimizer/gpu/adam_fusion.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) {
+  std::vector<std::string> inputs_format;
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> inputs_type;
+  std::vector<TypeId> outputs_type;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) {
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
+    inputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) {
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index));
+    outputs_format.push_back(kOpFormat_DEFAULT);
+  }
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetInputsFormat(inputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  return builder.Build();
+}
+}  // namespace
+
+const BaseRef AdamFusion::DefinePattern() const {
+  VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}),
+                                VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})});
+  VectorRef next_v =
+    VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}),
+               VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})});
+  VectorRef update = VectorRef(
+    {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})});
+  VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update});
+  VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr});
+  VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})});
+  VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})});
+  VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})});
+  return depend3;
+}
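+
+// For reference: the pattern above encodes the standard Adam step that the
+// fused kernel replaces,
+//   m' = beta1 * m + (1 - beta1) * g
+//   v' = beta2 * v + (1 - beta2) * g^2
+//   param' = param - lr * m' / (eps + sqrt(v'))
+// with the trailing Depend/Assign chain ordering the writes to param, m, v.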
+
+const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(equiv);
+  auto beta1_input = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
+  auto one_sub_beta1_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta1_]);
+  auto beta2_input = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
+  auto one_sub_beta2_input = utils::cast<AnfNodePtr>((*equiv)[one_sub_beta2_]);
+  auto eps_input = utils::cast<AnfNodePtr>((*equiv)[eps_]);
+  auto lr_input = utils::cast<AnfNodePtr>((*equiv)[lr_]);
+  auto param_input = utils::cast<AnfNodePtr>((*equiv)[param_]);
+  auto m_input = utils::cast<AnfNodePtr>((*equiv)[m_]);
+  auto v_input = utils::cast<AnfNodePtr>((*equiv)[v_]);
+  auto gradient_input = utils::cast<AnfNodePtr>((*equiv)[gradient_]);
+  MS_EXCEPTION_IF_NULL(beta1_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta1_input);
+  MS_EXCEPTION_IF_NULL(beta2_input);
+  MS_EXCEPTION_IF_NULL(one_sub_beta2_input);
+  MS_EXCEPTION_IF_NULL(eps_input);
+  MS_EXCEPTION_IF_NULL(lr_input);
+  MS_EXCEPTION_IF_NULL(param_input);
+  MS_EXCEPTION_IF_NULL(m_input);
+  MS_EXCEPTION_IF_NULL(v_input);
+  MS_EXCEPTION_IF_NULL(gradient_input);
+
+  auto prim = std::make_shared<Primitive>(kFusedAdamName);
+  MS_EXCEPTION_IF_NULL(prim);
+  std::vector<AnfNodePtr> inputs = {
+    NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input,
+    eps_input,          lr_input,    param_input,         m_input,     v_input,
+    gradient_input};
+  auto adam = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(adam);
+  auto types = {AnfAlgo::GetOutputInferDataType(node, 0)};
+  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)};
+  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get());
+  adam->set_scope(node->scope());
+
+  auto build_info = GenerateKernelBuildInfo(adam);
+  AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get());
+  return adam;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
new file mode 100644
index 0000000000..f87defc04c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class AdamFusion : public PatternProcessPass {
+ public:
+  explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) {
+    beta1_ = std::make_shared<Var>();
+    one_sub_beta1_ = std::make_shared<Var>();
+    beta2_ = std::make_shared<Var>();
+    one_sub_beta2_ = std::make_shared<Var>();
+    eps_ = std::make_shared<Var>();
+    lr_ = std::make_shared<Var>();
+    param_ = std::make_shared<Var>();
+    m_ = std::make_shared<Var>();
+    v_ = std::make_shared<Var>();
+    gradient_ = std::make_shared<Var>();
+  }
+  ~AdamFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  VarPtr beta1_;
+  VarPtr one_sub_beta1_;
+  VarPtr beta2_;
+  VarPtr one_sub_beta2_;
+  VarPtr eps_;
+  VarPtr lr_;
+  VarPtr param_;
+  VarPtr m_;
+  VarPtr v_;
+  VarPtr gradient_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
new file mode 100644
index 0000000000..c95945c980
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc
@@ -0,0 +1,117 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/adam_weight_decay_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { + std::vector inputs_format; + std::vector outputs_format; + std::vector inputs_type; + std::vector outputs_type; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); + inputs_format.push_back(kOpFormat_DEFAULT); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); + outputs_format.push_back(kOpFormat_DEFAULT); + } + builder.SetInputsDeviceType(inputs_type); + builder.SetInputsFormat(inputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetOutputsFormat(outputs_format); + return builder.Build(); +} +} // namespace + +const BaseRef AdamWeightDecayFusion::DefinePattern() const { + VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), + VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef next_v = + VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), + VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); + VectorRef update = VectorRef( + {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); + VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); + + VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); + VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); + VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); + VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); + VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); + return depend3; +} + +const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + auto beta1_input = utils::cast((*equiv)[beta1_]); + auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); + auto beta2_input = utils::cast((*equiv)[beta2_]); + auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); + auto eps_input = utils::cast((*equiv)[eps_]); + auto lr_input = utils::cast((*equiv)[lr_]); + auto weight_decay_input = utils::cast((*equiv)[weight_decay_]); + auto param_input = utils::cast((*equiv)[param_]); + auto m_input = utils::cast((*equiv)[m_]); + auto v_input = utils::cast((*equiv)[v_]); + auto gradient_input = utils::cast((*equiv)[gradient_]); + MS_EXCEPTION_IF_NULL(beta1_input); + MS_EXCEPTION_IF_NULL(one_sub_beta1_input); + MS_EXCEPTION_IF_NULL(beta2_input); + MS_EXCEPTION_IF_NULL(one_sub_beta2_input); + MS_EXCEPTION_IF_NULL(eps_input); + 
MS_EXCEPTION_IF_NULL(lr_input); + MS_EXCEPTION_IF_NULL(weight_decay_input); + MS_EXCEPTION_IF_NULL(param_input); + MS_EXCEPTION_IF_NULL(m_input); + MS_EXCEPTION_IF_NULL(v_input); + MS_EXCEPTION_IF_NULL(gradient_input); + + auto prim = std::make_shared(kFusedAdamWeightDecayName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = { + NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, + eps_input, lr_input, param_input, m_input, v_input, + gradient_input, weight_decay_input}; + auto adam_weight_decay = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(adam_weight_decay); + auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; + auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get()); + adam_weight_decay->set_scope(node->scope()); + + auto build_info = GenerateKernelBuildInfo(adam_weight_decay); + AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); + return adam_weight_decay; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h new file mode 100644 index 0000000000..53477ec898 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class AdamWeightDecayFusion : public PatternProcessPass { + public: + explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) { + beta1_ = std::make_shared(); + one_sub_beta1_ = std::make_shared(); + beta2_ = std::make_shared(); + one_sub_beta2_ = std::make_shared(); + eps_ = std::make_shared(); + lr_ = std::make_shared(); + weight_decay_ = std::make_shared(); + param_ = std::make_shared(); + m_ = std::make_shared(); + v_ = std::make_shared(); + gradient_ = std::make_shared(); + } + ~AdamWeightDecayFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr beta1_; + VarPtr one_sub_beta1_; + VarPtr beta2_; + VarPtr one_sub_beta2_; + VarPtr eps_; + VarPtr lr_; + VarPtr weight_decay_; + VarPtr param_; + VarPtr m_; + VarPtr v_; + VarPtr gradient_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc new file mode 100644 index 0000000000..b531b0caa5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include +#include "utils/log_adapter.h" +namespace mindspore { +namespace memreuse { +/** + * Add some set && get function + */ +void KernelRefCount::SetKernelRefCountInfo(int index, size_t size, RefCountType reftype) { + index_ = index; + size_ = size; + reftype_ = reftype; +} + +std::vector KernelDef::GetInputRefIndexs() const { + std::vector input_ref_indexs; + if (input_refs_.empty()) { + return input_ref_indexs; + } + (void)std::transform(input_refs_.begin(), input_refs_.end(), std::back_inserter(input_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return input_ref_indexs; +} + +std::vector KernelDef::GetOutputRefIndexs() const { + std::vector output_ref_indexs; + if (output_refs_.empty()) { + return output_ref_indexs; + } + (void)std::transform(output_refs_.begin(), output_refs_.end(), std::back_inserter(output_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return output_ref_indexs; +} + +std::vector KernelDef::GetWorkspaceRefIndexs() const { + std::vector wk_ref_indexs; + if (wk_space_.empty()) { + return wk_ref_indexs; + } + // only one key + auto wk_refs_iter = wk_space_.begin(); + auto wk_refs = wk_refs_iter->second; + (void)std::transform(wk_refs.begin(), wk_refs.end(), std::back_inserter(wk_ref_indexs), + [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); + return wk_ref_indexs; +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h similarity index 100% rename from mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.h rename to mindspore/ccsrc/backend/optimizer/mem_reuse/kernel_refcount.h diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h new file mode 100644 index 0000000000..1952415515 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_
+
+#include <map>
+#include <memory>
+#include <queue>
+#include <utility>
+#include <vector>
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/kernel.h"
+
+using HostAddress = mindspore::kernel::Address;
+namespace mindspore {
+namespace device {
+namespace memswap {
+enum class SwapKind { kDeviceToHost = 0, kHostToDevice = 1 };
+
+struct TensorInfo {
+  size_t tensor_size_{0};
+  AnfNodePtr kernel_{nullptr};
+  size_t output_idx_{0};
+};
+
+struct KernelExecutionInfo {
+  size_t topo_order_{0};
+  float execution_perform_{0.0};
+  bool trigger_swap_{false};
+  bool need_swap_{false};
+  // output index to topo orders of node users
+  std::map<size_t, std::vector<size_t>> node_users_map_;
+  // kernel output idx to host addr
+  std::map<size_t, HostAddress> host_addrs_;
+
+  KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {}
+  explicit KernelExecutionInfo(size_t topo_order)
+      : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {}
+  KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap)
+      : topo_order_(topo_order),
+        execution_perform_(execution_perform),
+        trigger_swap_(trigger_swap),
+        need_swap_(need_swap) {}
+};
+
+// trigger swap
+struct MemSwapInfo {
+  SwapKind swap_kind_;
+  // kernel need to be swapped
+  AnfNodePtr kernel_{nullptr};
+  size_t output_idx_{0};
+};
+
+class MemCopyManager {
+ public:
+  MemCopyManager() = default;
+
+  virtual ~MemCopyManager() = default;
+
+  virtual void Init() {}
+
+  virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {}
+
+  virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {}
+
+  virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; }
+
+  virtual DeviceAddressPtr UpdateSwapOutQueue() { return nullptr; }
+
+  virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; }
+
+  virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; }
+
+  virtual void FreeHostPinnedMem(void *addr) const {}
+
+  virtual void ClearSwapQueue() {}
+};
+using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>;
+}  // namespace memswap
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc
new file mode 100644
index 0000000000..8f705be556
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.cc
@@ -0,0 +1,326 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +DynamicMemPoolBestFit::~DynamicMemPoolBestFit() { + global_mem_block_list_.clear(); + global_idle_mem_buf_map_.clear(); +} + +DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { + size_t align_size = AlignMemorySize(size); + // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf. + DeviceMemPtr device_addr = FindIdleMemBuf(align_size); + if (!device_addr) { + device_addr = AddMemBlockAndMemBuf(align_size); + } + return device_addr; +} + +std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, + std::vector size_list) { + std::vector device_addr_list; + // Pre-alloc the one whole piece memory. + auto device_addr = AllocTensorMem(total_size); + if (!device_addr) { + return device_addr_list; + } + // Remove the pre-alloc memory. + auto mem_block = FindMemBlock(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + auto rest_size = mem_buf->size_ - total_size; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + // Split the pre-alloc memory into continuous memory by the size list. + DynamicMemBufPtr continuous_mem_buf; + auto buf_addr = device_addr; + for (size_t i = 0; i < size_list.size(); i++) { + continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); + (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); + device_addr_list.emplace_back(buf_addr); + buf_addr = AddressOffset(buf_addr, size_list[i]); + } + // Update the size of the last memory buf. 
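+  // (Any tail left over after carving out the requested sizes stays attached
+  // to the last buffer rather than being returned to the idle map.)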
+ continuous_mem_buf->size_ += rest_size; + return device_addr_list; +} + +size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { + if (size == 0) { + return DYNAMIC_MEM_ALIGN_SIZE; + } + return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; +} + +DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) { + auto iter = global_idle_mem_buf_map_.lower_bound(size); + if (iter != global_idle_mem_buf_map_.end()) { + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ != kMemBufIdle) { + MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_ + << "] mem_buf_address[" << mem_buf->device_addr_ << "]."; + } + mem_buf->status_ = kMemBufUsed; + // Remove map of old idle memory buf + (void)global_idle_mem_buf_map_.erase(iter); + // Divide memory buf + if (IsDivide(size, mem_buf->size_)) { + DivideMemBuf(size, mem_buf); + } + // Memory statistics + total_used_mem_statistics_ += mem_buf->size_; + if (total_used_mem_statistics_ > used_mem_peak_statistics_) { + used_mem_peak_statistics_ = total_used_mem_statistics_; + } + return mem_buf->device_addr_; + } + return nullptr; +} + +DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) { + size_t alloc_mem_size = CalMemBlockAllocSize(size); + if (alloc_mem_size == 0) { + return nullptr; + } + // Add new memory block + DeviceMemPtr device_addr = nullptr; + auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr); + if (real_alloc_size < size) { + MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size + << "]."; + return nullptr; + } + auto mem_block = std::make_shared(device_addr, real_alloc_size); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); + (void)global_mem_block_list_.insert(iter, mem_block); + // Add new memory buf + auto mem_buf = std::make_shared(device_addr, kMemBufUsed, real_alloc_size); + MS_EXCEPTION_IF_NULL(mem_buf); + // Add map of new memory buf in the block + (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf); + // Divide memory buf + if (IsDivide(size, mem_buf->size_)) { + DivideMemBuf(size, mem_buf); + } + // Memory statistics + total_mem_statistics_ += real_alloc_size; + total_used_mem_statistics_ += mem_buf->size_; + if (total_used_mem_statistics_ > used_mem_peak_statistics_) { + used_mem_peak_statistics_ = total_used_mem_statistics_; + } + return mem_buf->device_addr_; +} + +size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) { + auto device_free_mem_size = free_mem_size(); + if (device_free_mem_size < size) { + MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size + << "] is smaller than required size[" << size << "]."; + return 0; + } + auto alloc_mem_size = mem_alloc_unit_size(); + // Growing at twice of alloc size + while (alloc_mem_size < size) { + alloc_mem_size = alloc_mem_size * 2; + } + alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size); + return alloc_mem_size; +} + +bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const { + return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; +} + +void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { + MS_EXCEPTION_IF_NULL(mem_buf); + auto mem_block = FindMemBlock(mem_buf->device_addr_); + 
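+  // (Sketch of the split below: the first `size` bytes stay with the caller;
+  // the remainder becomes a new idle buf registered both in the owning block
+  // and in the global idle map.)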
MS_EXCEPTION_IF_NULL(mem_block); + // Divide new memory buf + size_t newbuf_size = mem_buf->size_ - size; + mem_buf->size_ = size; + DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size); + auto new_mem_buf = std::make_shared(newbuf_addr, kMemBufIdle, newbuf_size); + // Add map of new memory buf in the block + (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf); + // Add map of new idle memory buf + (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf); +} + +bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block) { + MS_EXCEPTION_IF_NULL(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + return device_addr < mem_block->device_addr(); +} + +DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); + if (iter != global_mem_block_list_.begin()) { + return *(--iter); + } + return nullptr; +} + +void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto mem_block = FindMemBlock(device_addr); + if (mem_block == nullptr) { + MS_LOG(WARNING) << "Can't find the mem_block of the device address[" << device_addr << "]."; + return; + } + CombineMemBuf(mem_block, device_addr); +} + +void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(mem_block); + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ != kMemBufUsed) { + MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "]."; + } + mem_buf->status_ = kMemBufIdle; + total_used_mem_statistics_ -= mem_buf->size_; + // Combine backward(combine the next_mem_buf to mem_buf) + auto next_iter = iter; + (void)next_iter++; + if (next_iter != mem_block->block_all_mem_buf_map_.end()) { + auto next_mem_buf = next_iter->second; + MS_EXCEPTION_IF_NULL(next_mem_buf); + if (next_mem_buf->status_ == kMemBufIdle) { + mem_buf->size_ += next_mem_buf->size_; + EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_); + (void)mem_block->block_all_mem_buf_map_.erase(next_iter); + } + } + // Combine forward(combine the mem_buf to prev_mem_buf) + bool forward_combine = false; + DynamicMemBufPtr prev_mem_buf; + if (iter != mem_block->block_all_mem_buf_map_.begin()) { + auto prev_iter = iter; + (void)prev_iter--; + prev_mem_buf = prev_iter->second; + MS_EXCEPTION_IF_NULL(prev_mem_buf); + if (prev_mem_buf->status_ == kMemBufIdle) { + EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_); + prev_mem_buf->size_ += mem_buf->size_; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + forward_combine = true; + } + } + // Add map of new idle memory + if (forward_combine) { + (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf); + } else { + (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf); + } +} + +void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr) { + MS_EXCEPTION_IF_NULL(device_addr); + auto iter = global_idle_mem_buf_map_.equal_range(size); + 
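+  // equal_range returns every idle buf that shares this exact size key, so
+  // the loop below disambiguates by device address before erasing.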
while (iter.first != iter.second) { + MS_EXCEPTION_IF_NULL(iter.first->second); + // Remove map of the idle memory buf by size and device address + if (iter.first->second->device_addr_ == device_addr) { + (void)global_idle_mem_buf_map_.erase(iter.first); + return; + } + (void)iter.first++; + } + MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf."; +} + +void DynamicMemPoolBestFit::ReleaseDeviceRes() { + MS_LOG(INFO) << "The dynamic memmory pool total size is " << total_mem_statistics_ << ", total used size is " + << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << "."; + for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { + auto device_addr = (*iter)->device_addr(); + if (device_addr != nullptr) { + if (!FreeDeviceMem(device_addr)) { + MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error."; + } + } + } +} + +void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() { + MS_LOG(INFO) << "Start dump dynamic memory pool info."; + DeviceAddrMapMemBuf mem_block_map; + DynamicMemBufPtr mem_buf; + size_t total_mem = 0; + size_t total_used_mem = 0; + size_t total_idle_mem1 = 0; + size_t total_idle_mem2 = 0; + // Dump the memory block info and memory buf info + MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "]."; + for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) { + total_mem += (*iter)->size(); + mem_block_map = (*iter)->block_all_mem_buf_map_; + MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts[" + << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size[" + << (*iter)->size() << "]."; + for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) { + mem_buf = iter_mem_buf->second; + MS_EXCEPTION_IF_NULL(mem_buf); + if (mem_buf->status_ == kMemBufIdle) { + total_idle_mem1 += mem_buf->size_; + } else { + total_used_mem += mem_buf->size_; + } + MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status[" + << mem_buf->status_ << "]."; + } + } + // Dump all the idle memory buf info + MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "]."; + for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) { + mem_buf = iter_idle->second; + MS_EXCEPTION_IF_NULL(mem_buf); + total_idle_mem2 += mem_buf->size_; + MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status[" + << mem_buf->status_ << "]."; + } + // Dump the memory statistical info + MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory[" + << total_idle_mem1 << "]."; + if (total_idle_mem1 != total_idle_mem2) { + MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory."; + } + if (total_mem != total_used_mem + total_idle_mem1) { + MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory."; + } + MS_LOG(INFO) << "Finish dump dynamic memory pool info."; +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h 
b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h similarity index 100% rename from mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h rename to mindspore/ccsrc/backend/optimizer/mem_reuse/mem_dynamic_allocator.h diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc new file mode 100644 index 0000000000..263ceaec63 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -0,0 +1,436 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include +#include +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace memreuse { +bool MemReuseUtil::InitDynamicOutputKernelRef() { + int index = util_index_; + auto kernel_cnodes = graph_->execution_order(); + if (kernel_cnodes.empty()) { + return true; + } + int kernel_out_ref_num = 0; + for (auto &kernel_cnode : kernel_cnodes) { +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().CheckSignalOps(kernel_cnode); +#endif + if (kernel_cnode == nullptr) { + return false; + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); + if (kernel_mod == nullptr) { + return false; + } + auto key = kernel_cnode.get(); + // for every apply_kernel to set new output + auto iter = kernel_output_refs_.find(key); + if (iter == kernel_output_refs_.end()) { + auto output_sizes = kernel_mod->GetOutputSizeList(); + KernelRefCountPtrList kernel_refs; + for (auto size : output_sizes) { + total_dy_size_ += size; + // do not MallocDynamicMem just record this + KernelRefCountPtr kernel_ref = std::make_shared(); + index++; + auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode); + kernel_ref->stream_id_ = curr_stream_id; + kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount); + kernel_refs.push_back(kernel_ref); + kernel_out_ref_num++; + total_refs_list_.push_back(kernel_ref); + } + if (!kernel_refs.empty()) { + kernel_output_refs_[key] = kernel_refs; + } + } + } + return true; +} + +bool MemReuseUtil::InitDynamicWorkspaceKernelRef() { + int WkIndex = util_index_; + auto kernel_cnodes = graph_->execution_order(); + if (kernel_cnodes.empty()) { + return true; + } + for (auto &kernel_cnode : kernel_cnodes) { + if (kernel_cnode == nullptr) { + return false; + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode); + if (kernel_mod == nullptr) { + return false; + } + auto key = kernel_cnode.get(); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + KernelRefCountPtrList workspace_kernel_refs; + for (auto size : workspace_sizes) { + total_workspace_size_ += size; + ++WkIndex; + KernelRefCountPtr workspace_ref = std::make_shared(); + workspace_ref->SetKernelRefCountInfo(WkIndex, size, kDynamicRefCount); + workspace_kernel_refs.push_back(workspace_ref); + // total wk ref + total_wk_ref_list_.push_back(workspace_ref); + } + 
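+    // Workspace refs are only recorded for kernels that actually request
+    // workspace memory; the map key is the raw address of the kernel CNode.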
if (!workspace_kernel_refs.empty()) { + // every key index wk_refs + kernel_workspace_refs_[key] = workspace_kernel_refs; + } + } + return true; +} + +bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + graph_ = graph; + is_all_nop_node_ = opt::IsAllNopNode(graph); + if (!InitDynamicOutputKernelRef()) { + MS_LOG(INFO) << "InitDynamicOutputKernelRef fail"; + return false; + } + if (!InitDynamicWorkspaceKernelRef()) { + MS_LOG(INFO) << "InitDynamicWorkspaceKernelRef fail"; + return false; + } + return true; +} + +// set longest worspace list && largest workspace sizes +void MemReuseUtil::SetWorkSpaceList() { + int max_list_size = 0; + std::vector total_sizes; + std::vector max_list; + auto kernel_cnodes = graph_->execution_order(); + for (auto &kernel_cnode : kernel_cnodes) { + MS_EXCEPTION_IF_NULL(kernel_cnode); + auto cnode_key = kernel_cnode.get(); + auto cnode_iter = kernel_workspace_refs_.find(cnode_key); + if (cnode_iter != kernel_workspace_refs_.end()) { + auto kernel_refs = cnode_iter->second; + std::vector current_list; + for (size_t i = 0; i < kernel_refs.size(); ++i) { + auto size = kernel_refs[i]->size_; + current_list.push_back(size); + } + if (max_list_size < SizeToInt(current_list.size())) { + max_list_size = SizeToInt(current_list.size()); + } + (void)std::copy(current_list.begin(), current_list.end(), std::back_inserter(total_sizes)); + } + } + sort(total_sizes.rbegin(), total_sizes.rend()); + max_list.resize(IntToSize(max_list_size)); + if (SizeToInt(total_sizes.size()) < max_list_size) { + MS_LOG(EXCEPTION) << "total workspace size is less than required max list size"; + } + max_list.assign(total_sizes.begin(), total_sizes.begin() + max_list_size); + for (auto &ma : max_list) { + total_reuseworkspace_size_ += ma; + } + max_workspace_size_ = max_list_size; + max_workspace_list_ = max_list; +} + +void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto key = kernel.get(); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto ref_ptr = GetKernelInputRef(kernel, i); + if (ref_ptr != nullptr) { + if (ref_ptr->reftype() == kStaticRefCount) { + continue; + } else if (ref_ptr->reftype() == kDynamicRefCount) { + auto iter = kernel_def_ptr->inputs_.find(key); + if (iter == kernel_def_ptr->inputs_.end()) { + kernel_def_ptr->inputs_[key].push_back(ref_ptr); + } else { + iter->second.push_back(ref_ptr); + } + } + } + } +} + +void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto key = kernel.get(); + auto iter = kernel_def_ptr->outputs_.find(key); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t k = 0; k < kernel_mod->GetOutputSizeList().size(); ++k) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[key][k]; + if (iter == kernel_def_ptr->outputs_.end()) { + kernel_def_ptr->outputs_[key].push_back(kernel_ref); + } else { + iter->second.push_back(kernel_ref); + } + } +} + +void MemReuseUtil::SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_def_ptr); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto key = kernel.get(); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + if (kernel_workspace_refs_.find(key) != 
kernel_workspace_refs_.end()) {
+      auto wk_refs = kernel_workspace_refs_[key];
+      if (i < wk_refs.size()) {
+        auto wk_ref = wk_refs[i];
+        kernel_def_ptr->wk_space_[key].push_back(wk_ref);
+      } else {
+        MS_LOG(EXCEPTION) << "current index: " << i << " larger than wk_refs size " << wk_refs.size();
+      }
+    } else {
+      MS_LOG(EXCEPTION) << "kernel_workspace_refs_ init error";
+    }
+  }
+}
+
+KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
+  if (node == nullptr) {
+    MS_LOG(EXCEPTION) << "The node pointer is a nullptr.";
+  }
+  if (node->isa<CNode>()) {
+    auto ak_node = node->cast<CNodePtr>();
+    auto key = ak_node.get();
+    MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx));
+    return kernel_output_refs_[key][IntToSize(output_idx)];
+  }
+  return nullptr;
+}
+
+KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
+  if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
+    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
+                      << AnfAlgo::GetInputTensorNum(kernel);
+  }
+  auto input_node = kernel->input(input_idx + 1);
+  // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+  session::KernelWithIndex kernel_input;
+  if (is_all_nop_node_) {
+    // The graph does not remove the nop node.
+    kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
+  } else {
+    // The graph removes the nop node.
+    kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+  }
+  if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
+    MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
+  }
+  auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second));
+  return result;
+}
+
+void MemReuseUtil::SetKernelDefMap() {
+  auto kernel_cnodes = graph_->execution_order();
+  for (auto &kernel : kernel_cnodes) {
+    KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
+    kernel_def_ptr->set_kernel_name(AnfAlgo::GetCNodeName(kernel));
+    kernel_def_ptr->set_scope_full_name(kernel->fullname_with_scope());
+    kernel_def_ptr->set_stream_id(AnfAlgo::GetStreamId(kernel));
+    SetInputMap(kernel, kernel_def_ptr.get());
+    SetOutputMap(kernel, kernel_def_ptr.get());
+    SetWkMap(kernel, kernel_def_ptr.get());
+    auto key = kernel.get();
+    kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
+    kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
+    kernel_def_ptr_list_.push_back(kernel_def_ptr);
+    kernel_map_[key] = kernel_def_ptr;
+  }
+  SetKernelDefInputs();
+}
+
+void MemReuseUtil::SetKernelDefInputs() {
+  for (const auto &kernel : graph_->execution_order()) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    auto key = kernel.get();
+    // find kernel_def according to cnode addr
+    auto iter = kernel_map_.find(key);
+    if (iter == kernel_map_.end()) {
+      MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init.";
+    }
+    auto kernel_def = iter->second;
+    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+      auto ref_ptr = GetKernelInputRef(kernel, i);
+      if (ref_ptr != nullptr) {
+        // set the inputs of this kernel_def
+        auto input_node = AnfAlgo::GetInputNode(kernel, i);
+        // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+        session::KernelWithIndex input;
+        if (is_all_nop_node_) {
+          // The graph does not remove the nop node.
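+          // (The last argument of VisitKernelWithReturnType toggles nop-node
+          // skipping: false keeps nop nodes, true skips through to the real kernel.)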
+ input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); + } else { + // The graph removes the nop node. + input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + } + if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { + MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; + } + auto input_key = (input.first).get(); + auto input_iter = kernel_map_.find(input_key); + if (input_iter == kernel_map_.end()) { + MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] is not init."; + } + kernel_def->InsertInputKernel(input_iter->second); + } + } + } +} + +void MemReuseUtil::SetReuseRefCount() { + auto kernels = graph_->execution_order(); + for (auto &kernel : kernels) { + auto key = kernel.get(); + for (auto &def : kernel_def_ptr_list_) { + auto iter = def->inputs_.find(key); + if (iter != def->inputs_.end()) { + for (auto &input : iter->second) { + input->ref_count_++; + input->ref_count_dynamic_use_++; + } + } + } + } +} + +void MemReuseUtil::SetSummaryNodesRefCount() { + bool summary_exist = graph_->summary_node_exist(); + if (!summary_exist) { + return; + } + + auto summary_nodes = graph_->summary_nodes(); + if (summary_nodes.empty()) { + return; + } + + size_t total_summary_size = 0; + for (auto &node_item : summary_nodes) { + auto node = node_item.second.first; + size_t index = IntToSize(node_item.second.second); + if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { + KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; + kernel_ref->ref_count_ = kMaxRefCount; + kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; + total_summary_size += kernel_ref->size_; + MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; + } else { + MS_LOG(WARNING) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index; + } + } +#ifdef MEM_REUSE_DEBUG + auto graph = *graph_; + MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); +#endif + MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size; +} + +void MemReuseUtil::SetGraphOutputRefCount() { + auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); + for (const auto &node : nodes) { + session::KernelWithIndex kernel_input; + if (is_all_nop_node_) { + // The graph does not remove the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); + } else { + // The graph removes the nop node. 
+      kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
+    }
+    MS_EXCEPTION_IF_NULL(kernel_input.first);
+    if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
+      continue;
+    }
+    auto ak_node = kernel_input.first->cast<CNodePtr>();
+    auto key = ak_node.get();
+    auto iter = kernel_output_refs_.find(key);
+    if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) {
+      auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second];
+      MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr);
+      kernel_ref_count_ptr->ref_count_ = kMaxRefCount;
+      kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount;
+    }
+  }
+#ifdef MEM_REUSE_DEBUG
+  auto graph = *graph_;
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
+#endif
+}
+
+void MemReuseUtil::ResetDynamicUsedRefCount() {
+  for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) {
+    for (auto &ref_count : iter->second) {
+      MS_EXCEPTION_IF_NULL(ref_count);
+      ref_count->ref_count_dynamic_use_ = ref_count->ref_count_;
+    }
+  }
+}
+
+void MemReuseUtil::SetAllInfo(KernelGraph *graph) {
+  if (!InitDynamicKernelRef(graph)) {
+    MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory failed";
+  }
+  SetKernelDefMap();
+  SetReuseRefCount();
+  SetSummaryNodesRefCount();
+  SetWorkSpaceList();
+#ifdef MEM_REUSE_DEBUG
+  MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
+#endif
+}
+
+uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
+  auto key = node.get();
+  auto iter = kernel_output_refs_.find(key);
+  uint8_t *ptr = nullptr;
+  if (iter != kernel_output_refs_.end()) {
+    if (index >= iter->second.size()) {
+      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than its output size:[" << iter->second.size() << "]";
+    }
+    auto output_ref = iter->second[index];
+    ptr = mem_base_ + output_ref->offset_;
+  } else {
+    MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] does not exist in kernel_output_refs";
+  }
+  return ptr;
+}
+
+uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const {
+  auto key = node.get();
+  auto iter = kernel_workspace_refs_.find(key);
+  uint8_t *ptr = nullptr;
+  if (iter != kernel_workspace_refs_.end()) {
+    if (index >= iter->second.size()) {
+      MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than its workspace size:[" << iter->second.size() << "]";
+    }
+    auto wk_ref = iter->second[index];
+    ptr = mem_base_ + wk_ref->offset_;
+  }
+  return ptr;
+}
+}  // namespace memreuse
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
new file mode 100644
index 0000000000..b286bcbc2c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_
+#include <map>
+#include <memory>
+#include <vector>
+#include "backend/optimizer/mem_reuse/kernel_refcount.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/tbe/tbe_utils.h"
+using mindspore::kernel::tbe::TbeUtils;
+namespace mindspore {
+namespace memreuse {
+static constexpr int kMaxRefCount = 9999;
+static constexpr size_t kDefaultMemAlignSize = 512;
+static constexpr size_t kAttAlignSize = 31;
+static constexpr int kInvalidIndex = -2;
+
+using KernelDefPtrMaps = std::vector<mindspore::memreuse::KernelDefPtr>;
+using KernelRefs = std::map<KernelKey, KernelRefCountPtrList>;
+
+using KernelGraph = mindspore::session::KernelGraph;
+
+class MemReuseUtil {
+ public:
+  KernelRefs kernel_output_refs_;
+  KernelRefCountPtrList total_refs_list_;
+  KernelRefCountPtrList total_wk_ref_list_;
+  KernelRefs kernel_workspace_refs_;
+  MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {}
+  ~MemReuseUtil() {
+    if (graph_ != nullptr) {
+      graph_ = nullptr;
+    }
+    MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_;
+    MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_;
+    MS_LOG(INFO) << "Total Reused WorkSpace Memory Size: " << total_reuseworkspace_size_;
+  }
+
+  void SetAllInfo(KernelGraph *graph);
+  bool InitDynamicOutputKernelRef();
+  bool InitDynamicWorkspaceKernelRef();
+  bool InitDynamicKernelRef(const KernelGraph *graph);
+  void SetWorkSpaceList();
+  void SetKernelDefMap();
+  void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
+  void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
+  void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr);
+  void SetKernelDefInputs();
+  void SetReuseRefCount();
+  void SetSummaryNodesRefCount();
+  // Set the reference count of graph output specially.
+  void SetGraphOutputRefCount();
+  // Reset the dynamic used reference count by ref_count_.
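+  // (Copies ref_count_ back into ref_count_dynamic_use_ for every recorded
+  // output ref, so the dynamic counters can be consumed again on the next run.)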
+  void ResetDynamicUsedRefCount();
+
+  KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx);
+  KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx);
+  KernelRefCountPtrList total_refs_list() const { return total_refs_list_; }
+  KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; }
+  KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; }
+  int max_workspace_size() const { return max_workspace_size_; }
+  std::vector<size_t> max_workspace_list() const { return max_workspace_list_; }
+  void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; }
+  void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) {
+    kernel_def_ptr_list_ = kernel_def_ptr_list;
+  }
+  void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; }
+  uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
+  uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
+
+ private:
+  int util_index_;
+  const KernelGraph *graph_;
+  bool is_all_nop_node_;
+  KernelRefCountPtrList ref_list_;
+  KernelDefPtrMaps kernel_def_ptr_list_;
+  KernelRefCountPtrList last_ref_list_;
+  int max_workspace_size_ = 0;
+  std::vector<size_t> max_workspace_list_;
+  size_t total_dy_size_ = 0;
+  size_t total_workspace_size_ = 0;
+  size_t total_reuseworkspace_size_ = 0;
+  uint8_t *mem_base_{nullptr};
+  // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef
+  std::map<KernelKey, KernelDefPtr> kernel_map_;
+};
+using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
+}  // namespace memreuse
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc
new file mode 100644
index 0000000000..787d334a1a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.cc
@@ -0,0 +1,411 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#ifdef ENABLE_D +#include "runtime/device/ascend/ascend_stream_assign.h" +#endif + +namespace mindspore { +namespace memreuse { +void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + set_tensor_ptr_list(mem_reuse_util_ptr->total_refs_list()); + set_workspace_ptr_list(mem_reuse_util_ptr->total_wk_ref_list()); + set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); + // check info Correctness + for (auto &tensor : tensor_ptr_list_) { + tensor->size_ = AlignMemorySize(tensor->size_); + } + // align wk size to 512 && refcount == 1 + for (auto &wk : wk_tensor_list_) { + wk->size_ = AlignMemorySize(wk->size_); + wk->ref_count_ = 1; + } +#ifdef ENABLE_D + stream_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group(); +#endif +} + +void BestFitMemReuse::InitKernelDependence() { + for (const auto &kernel : op_ptr_list_) { + std::set front; + std::queue to_visit; + to_visit.push(kernel); + // find all kernels before current kernel + while (!to_visit.empty()) { + auto curr = to_visit.front(); + to_visit.pop(); + if (front.count(curr)) { + continue; + } + front.insert(curr); + auto iter = kernel_front_map_.find(curr); + if (iter != kernel_front_map_.end()) { + auto visited_front = iter->second; + front.insert(visited_front.begin(), visited_front.end()); + continue; + } + for (const auto &input : curr->input_kernels()) { + to_visit.push(input); + } + } + kernel_front_map_[kernel] = front; + } +} + +bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf) { + // determine whether the kernel_curr can reuse kernel_prev's output tensor membuf + MS_EXCEPTION_IF_NULL(kernel_curr); + MS_EXCEPTION_IF_NULL(mem_buf); + auto kernel_prev = mem_buf->used_kernel_; + MS_EXCEPTION_IF_NULL(kernel_prev); + auto curr_stream_id = kernel_curr->stream_id(); + auto prev_stream_id = kernel_prev->stream_id(); + if (curr_stream_id == prev_stream_id) { + mem_buf->type_ = IN_STREAM_REUSE; + return true; + } + + bool reuse_between_streams = true; + for (auto &stream_group : stream_groups_) { + size_t cur_index = UINT32_MAX; + size_t prev_index = UINT32_MAX; + for (size_t index = 0; index < stream_group.size(); index++) { + if (curr_stream_id == stream_group[index]) { + cur_index = index; + continue; + } + if (prev_stream_id == stream_group[index]) { + prev_index = index; + continue; + } + } + if ((prev_index != UINT32_MAX) && (cur_index == UINT32_MAX || (prev_index > cur_index))) { + // previous stream and current stream are not in the same group can't be reused + // previous stream is behind current stream can't be reused + reuse_between_streams = false; + break; + } + } + + if (reuse_between_streams) { + mem_buf->type_ = BETWEEN_STREAMS_REUSE; + return true; + } + + auto iter = kernel_front_map_.find(kernel_curr); + if (iter == kernel_front_map_.end()) { + MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init."; + } + auto kernel_curr_front = iter->second; + auto depend_count = kernel_curr_front.count(kernel_prev); + if (depend_count) { + mem_buf->type_ = KERNEL_DEPENDENCE_REUSE; + return true; + } + + return false; +} + +void BestFitMemReuse::AssignNodeOutputOffset() { + for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t index = GetTensorIndex(tensor_idx); + auto tensor_desc = 
tensor_ptr_list_[index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); + if (!reusable_membuf_map.empty()) { + auto membuf_index = reusable_membuf_map.begin()->second; + // find the best suitable membuf in membuf list, and reuse it + ReuseExistMembuf(tensor_desc.get(), membuf_index, kDynamicMem); + } else { + // no membuf can reuse, add new membuf after the membuf_ptr_list + AddNewMembufPtr(tensor_desc.get(), kDynamicMem); +#ifdef MEM_REUSE_DEBUG + MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; +#endif + } + } +} + +void BestFitMemReuse::AssignNodeWorkspaceOffset() { + for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) { + size_t index = GetWorkspaceIndex(wk_idx); + auto wk_ref = wk_tensor_list_[index]; + MS_EXCEPTION_IF_NULL(wk_ref); + auto re_wk_membuf_map = GetReusableMembufMap(wk_ref->size_); + if (!re_wk_membuf_map.empty()) { + auto membuf_index = re_wk_membuf_map.begin()->second; + ReuseExistMembuf(wk_ref.get(), membuf_index, kWorkspaceMem); + } else { + AddNewMembufPtr(wk_ref.get(), kWorkspaceMem); + } + } +} + +void BestFitMemReuse::ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + CheckMembufIndx(membuf_index); + auto membuf = membuf_ptr_list_[membuf_index]; + MS_EXCEPTION_IF_NULL(membuf); + // first to split && then update membuf_info + if (IsSplit(tensor_desc->size_, membuf->size_)) { + // split the membuf, and insert a new membuf after this membuf + SplitMembuf(tensor_desc, membuf_index); + } + // update membuf status, and set tensor offset + UpdateMembufInfo(tensor_desc, membuf.get(), flag); +} + +std::map BestFitMemReuse::GetReusableMembufMap(size_t tensor_size) { + std::map size_map; + for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) { + auto membuf = membuf_ptr_list_[i]; + auto index = i; + bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size; + if (is_membuf_ok && IsUsable(current_kernel_, membuf)) { + (void)size_map.insert(std::make_pair(membuf->size_, index)); + break; + } + } + return size_map; +} + +void BestFitMemReuse::UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + MS_EXCEPTION_IF_NULL(membuf); + auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); + membuf->status_ = kReused; + membuf->index_ = real_index; + membuf->used_kernel_ = current_kernel_; + tensor_desc->offset_ = membuf->offset_; +} + +bool BestFitMemReuse::IsSplit(size_t tensor_size, size_t membuf_size) const { return tensor_size < membuf_size; } + +void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index) { + MS_EXCEPTION_IF_NULL(tensor_desc); + CheckMembufIndx(membuf_index); + auto membuf = membuf_ptr_list_[membuf_index]; + MS_EXCEPTION_IF_NULL(membuf); + auto bias = membuf->size_ - tensor_desc->size_; + membuf->size_ = tensor_desc->size_; + // to check if spilt membuf can be merge + auto new_membuf = std::make_shared(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, + membuf->type_, current_kernel_); + (void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf); +} + +void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { + MS_EXCEPTION_IF_NULL(tensor_desc); + size_t membuf_offset = 0; + if (!membuf_ptr_list_.empty()) { + membuf_offset = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; + } + auto membuf_size = 
tensor_desc->size_; + auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); + auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_); + membuf_ptr_list_.push_back(membuf); + tensor_desc->offset_ = membuf_offset; +} + +void BestFitMemReuse::UpdateNodeInputAndMembuf() { + // process node input tensor + for (const auto &tensor_idx : current_kernel_->GetInputRefIndexs()) { + size_t tensor_index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[tensor_index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + tensor_desc->ref_count_--; + if (tensor_desc->ref_count_ == 0) { + ReleaseMembuf(tensor_index, kDynamicMem); + } else if (tensor_desc->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ + << " check error"; + } + } +} + +void BestFitMemReuse::ReleaseNodeUnusedOutput() { + for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { + size_t tensor_index = GetTensorIndex(tensor_idx); + auto tensor_desc = tensor_ptr_list_[tensor_index]; + MS_EXCEPTION_IF_NULL(tensor_desc); + if (tensor_desc->ref_count_ == 0) { + ReleaseMembuf(tensor_index, kDynamicMem); + } else if (tensor_desc->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ + << " check error"; + } + } +} + +void BestFitMemReuse::ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr) { + for (auto &workspace_index : kernel_def_ptr->GetWorkspaceRefIndexs()) { + size_t index = GetWorkspaceIndex(workspace_index); + auto wk_tensor = wk_tensor_list_[index]; + wk_tensor->ref_count_--; + if (wk_tensor->ref_count_ == 0) { + ReleaseMembuf(index, kWorkspaceMem); + } else if (wk_tensor->ref_count_ < 0) { + MS_LOG(EXCEPTION) << "tensor: " << wk_tensor->index_ << " refcount: " << wk_tensor->ref_count_ << " check error"; + } + } +} + +void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { + if (membuf_ptr_list_.empty()) { + return; + } + auto real_index = GetRealIndex(tensor_index, flag); + auto membuf_iter = std::find_if(membuf_ptr_list_.begin(), membuf_ptr_list_.end(), + [real_index](const MembufPtr &membuf) { return membuf->index_ == real_index; }); + if (membuf_iter == membuf_ptr_list_.end()) { + return; + } + auto membuf = (*membuf_iter); + MS_EXCEPTION_IF_NULL(membuf); + membuf->status_ = kUnused; + if (membuf_iter != membuf_ptr_list_.end() - 1) { + auto next_iter = membuf_iter + 1; + auto membuf_next = (*next_iter); + MS_EXCEPTION_IF_NULL(membuf_next); + if (membuf_next->status_ == kUnused) { + bool is_merge = IsUsable(current_kernel_, membuf_next); + if (is_merge) { + membuf->size_ += membuf_next->size_; + (void)membuf_ptr_list_.erase(next_iter); + } + } + } + if (membuf_iter != membuf_ptr_list_.begin()) { + auto prev_iter = membuf_iter - 1; + auto membuf_prev = (*prev_iter); + MS_EXCEPTION_IF_NULL(membuf_prev); + if (membuf_prev->status_ == kUnused) { + bool is_merge = IsUsable(current_kernel_, membuf_prev); + if (is_merge) { + membuf->size_ += membuf_prev->size_; + membuf->offset_ = membuf_prev->offset_; + (void)membuf_ptr_list_.erase(prev_iter); + } + } + } +} + +size_t BestFitMemReuse::AlignMemorySize(size_t size) const { + // memory size 512 align + return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; +} + +size_t BestFitMemReuse::GetAllocatedSize() { + size_t AllocatedSize = kTotalSize; + if (membuf_ptr_list_.empty()) { + return AllocatedSize; + } + AllocatedSize = 
membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_;
+  MS_LOG(INFO) << "MemReuse Allocated Dynamic Size: " << AllocatedSize;
+  return AllocatedSize;
+}
+
+bool BestFitMemReuse::IsRelease() {
+  // unable_used_node contains the node types whose output tensors cannot be released,
+  // even if their refcount is equal to zero.
+  std::unordered_set<std::string> unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(),
+                                                      prim::kPrimFusedBatchNorm->name(),
+                                                      prim::kPrimFusedBatchNormGrad->name()};
+  return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end();
+}
+
+size_t BestFitMemReuse::GetTensorIndex(int index) const {
+  if (index < 0 || IntToSize(index) >= tensor_ptr_list_.size()) {
+    MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
+    MS_LOG(EXCEPTION) << "invalid tensor index";
+  }
+  return IntToSize(index);
+}
+
+size_t BestFitMemReuse::GetWorkspaceIndex(int index) const {
+  if (index < 0 || IntToSize(index) >= wk_tensor_list_.size()) {
+    MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
+    MS_LOG(EXCEPTION) << "invalid workspace index";
+  }
+  return IntToSize(index);
+}
+
+int BestFitMemReuse::GetRealIndex(size_t index, int flag) const {
+  if (flag == kDynamicMem) {
+    return SizeToInt(index);
+  } else if (flag == kWorkspaceMem) {
+    return kWorkspaceIndexFactor * SizeToInt(index + 1);
+  } else {
+    MS_LOG(EXCEPTION) << "flag " << flag << " is invalid";
+  }
+}
+
+void BestFitMemReuse::CheckMembufIndx(size_t membuf_index) const {
+  if (membuf_index >= membuf_ptr_list_.size()) {
+    MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name();
+    MS_LOG(EXCEPTION) << "invalid membuf index: " << membuf_index << ", real size: " << membuf_ptr_list_.size();
+  }
+}
+
+void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
+  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
+  InitMemReuseInfo(mem_reuse_util_ptr);
+  InitKernelDependence();
+  KernelDefPtr pre_op = nullptr;
+#ifdef MEM_REUSE_DEBUG
+  size_t op_num = 0;
+#endif
+  for (const auto &op_def_ptr : op_ptr_list_) {
+    current_kernel_ = op_def_ptr;
+    // release the previous kernel def's workspace
+    if (pre_op != nullptr) {
+      ReleasePreNodeWorkspace(pre_op.get());
+    }
+    MemReuseChecker::GetInstance().IsAddNewMembuf_ = false;
+    // process node output tensor
+    AssignNodeOutputOffset();
+#ifdef MEM_REUSE_DEBUG
+    if (MemReuseChecker::GetInstance().IsAddNewMembuf_) {
+      MemReuseChecker::GetInstance().SetAddNewMembuInfos(op_def_ptr.get(), membuf_ptr_list_, op_num);
+    }
+#endif
+    // deal with the current op's workspace
+    AssignNodeWorkspaceOffset();
+    pre_op = op_def_ptr;
+    // update node input tensor refcount, and membuf list status
+    UpdateNodeInputAndMembuf();
+    // check node output tensor which refcount is equal to zero
+    if (IsRelease()) {
+      ReleaseNodeUnusedOutput();
+    }
+#ifdef MEM_REUSE_DEBUG
+    MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_);
+    ++op_num;
+#endif
+  }
+#ifdef MEM_REUSE_DEBUG
+  MemReuseChecker::GetInstance().ExportMembufInfoIR();
+  MemReuseChecker::GetInstance().ExportAddNewMmebufIR();
+  MemReuseChecker::GetInstance().set_kernel_front_map(kernel_front_map_);
+  MemReuseChecker::GetInstance().ExportKernelDependence();
+#endif
+}
+}  // namespace memreuse
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h
new file mode 100644
index 0000000000..ef1cfd3e11
--- /dev/null
+++
b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_allocator.h @@ -0,0 +1,159 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/optimizer/mem_reuse/kernel_refcount.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" + +namespace mindspore { +namespace memreuse { +static constexpr int kWorkspaceIndexFactor = -1000; +static constexpr int kDynamicMem = -1; +static constexpr int kWorkspaceMem = 1; +static constexpr size_t kTotalSize = 0; +enum Status { kUnused, kReused }; +enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE }; +class Membuf { + public: + Membuf() = default; + Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel) + : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} + ~Membuf() = default; + // Memory block status flags + Status status_ = kUnused; + size_t size_{0}; + size_t offset_{0}; + // Store the tensor index stored in this memory block at a certain moment + int index_{0}; + MEMTYPE type_{NEW}; + KernelDefPtr used_kernel_; +}; +using MembufPtr = std::shared_ptr; + +class BestFitMemReuse { + public: + BestFitMemReuse() = default; + ~BestFitMemReuse() { membuf_ptr_list_.clear(); } + /** + * Init all information need by memory reuse + * @param mem_reuse_util_ptr, initialize in the memreuse.cc + */ + void InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr); + void CheckMembufIndx(size_t check_idx) const; + void AssignNodeWorkspaceOffset(); + void ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr); + /** + * Assign output tensor memory offset of current kernel + */ + void AssignNodeOutputOffset(); + /** + * Update input tensor's status of current kernel, and the status of membuf used by current kernel + */ + void UpdateNodeInputAndMembuf(); + /** + * Check whether to release the kernel output tensor which refcount is equal to zero + */ + void ReleaseNodeUnusedOutput(); + /** + * Reuse the exist membuf if possible + * @param tensor_desc, the output tensor of current kernel + * @param membuf_index, the index of membuf to be reused + * @param flag + */ + void ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag); + /** + * Get the membuf that can be reused + * @param tensor_size, the size of the tensor ready to assign memory offset + * @return membuf map, key: the membuf size, value: the membuf index + */ + std::map GetReusableMembufMap(size_t tensor_size); + /** + * Update the status of the reused memory block + * @param tensor_desc, the tensor ready to assign memory + * @param membuf, the membuf to be reused + * @param flag, distinguish 
dynamic memory and workspace + */ + void UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag); + // If the size of the memory block is greater than the size of the tensor, split the extra memory + void SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index); + // Determine if the memory block needs to be split + bool IsSplit(size_t tensor_size, size_t membuf_size) const; + // If there is no memory block that can be reused, add a new memory block at the end + void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag); + // Merge unused membuf + void ReleaseMembuf(size_t tensor_index, int flag); + // Memory address alignment 512 + size_t AlignMemorySize(size_t size) const; + int GetRealIndex(size_t index, int flag = kDynamicMem) const; + size_t GetTensorIndex(int index) const; + size_t GetWorkspaceIndex(int index) const; + // Memory reuse main program entry + void Reuse(const MemReuseUtil *mem_reuse_util_ptr); + // Get the total memory that needs to be applied eventually + size_t GetAllocatedSize(); + // return false, when the node output cannot be released + bool IsRelease(); + /** + * determine if the kernel_curr can reuse the output tensor add of kernel_prev + * @param kernel_curr, current kernel + * @param mem_buf, the membuf + * @return bool + */ + bool IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf); + /** + * init the dependence of all kernels in the graph + */ + void InitKernelDependence(); + // set tensor_def and op_def + void set_tensor_ptr_list(const std::vector &tensor_ptr_list) { + tensor_ptr_list_ = tensor_ptr_list; + } + void set_workspace_ptr_list(const std::vector &workspace_ptr_list) { + wk_tensor_list_ = workspace_ptr_list; + } + void set_op_ptr_list(const std::vector &op_ptr_list) { op_ptr_list_ = op_ptr_list; } + + private: + KernelDefPtr current_kernel_; + // Save all tensor information + std::vector tensor_ptr_list_; + std::vector wk_tensor_list_; + // Save all op information, including input and output tensor index + std::vector op_ptr_list_; + // Memory block information sequence, temporary variables + std::vector membuf_ptr_list_; + // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def + std::map> kernel_front_map_; + std::vector> stream_groups_; +}; +} // namespace memreuse +} // namespace mindspore +#endif // #define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc new file mode 100644 index 0000000000..b93bf42f9f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -0,0 +1,572 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
+#include <algorithm>
+#include <fstream>
+#include <numeric>
+#include <string>
+
+namespace mindspore {
+namespace memreuse {
+MemReuseChecker &MemReuseChecker::GetInstance() {
+  static MemReuseChecker instance;
+  return instance;
+}
+
+void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) {
+  std::string node_name = AnfAlgo::GetCNodeName(c_node);
+  if (node_name == kSend || node_name == kRecv) {
+    MS_LOG(INFO) << "MemReuseChecker check op_name of Send or Recv";
+    // get op's info && check
+    MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node)
+                 << " out_num: " << AnfAlgo::GetOutputTensorNum(c_node);
+  }
+}
+
+void MemReuseChecker::CheckWorkSpace(const std::vector<size_t> &max_list) {
+  for (auto &ma : max_list) {
+    total_re_wkspe_size_checker_ += ma;
+  }
+}
+
+void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx) {
+  auto key = c_node.get();
+  auto iter = kernel_refs.find(key);
+  auto node_name = AnfAlgo::GetCNodeName(c_node);
+  if (iter == kernel_refs.end()) {
+    MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString()
+                      << " output index: " << output_idx;
+  }
+  if (output_idx >= iter->second.size()) {
+    MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str();
+    MS_LOG(EXCEPTION) << "The index: " << output_idx
+                      << " is out of the size of kernel_output_refs_:" << iter->second.size();
+  }
+}
+
+int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  int64_t static_input_size = 0;
+  for (auto &item : graph->inputs()) {
+    if (!item->isa<Parameter>()) {
+      continue;
+    }
+    auto output_size = AnfAlgo::GetOutputTensorNum(item);
+    for (size_t index = 0; index < output_size; index++) {
+      TypeId ou_type = AnfAlgo::GetOutputDeviceDataType(item, index);
+      // parameter has not been initialized by a cnode
+      if (ou_type == kTypeUnknown) {
+        ou_type = AnfAlgo::GetOutputInferDataType(item, index);
+      }
+      size_t type_size = GetTypeByte(TypeIdToType(ou_type));
+      std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(item, index);
+      size_t tensor_size =
+        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
+      auto checker_size = SizeToLong(tensor_size);
+      static_input_size += checker_size;
+    }
+  }
+  return static_input_size;
+}
+
+int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  int64_t static_value_size = 0;
+  for (auto &value_node : graph->graph_value_nodes()) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto &node_value = value_node->value();
+    MS_EXCEPTION_IF_NULL(node_value);
+    auto tensor = node_value->cast<tensor::TensorPtr>();
+    if (tensor == nullptr) {
+      continue;
+    }
+    size_t tensor_size = tensor->data().nbytes();
+    auto checker_size = SizeToLong(tensor_size);
+    static_value_size += checker_size;
+  }
+  return static_value_size;
+}
+
+int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const {
+  // calculate static input size
+  auto static_input_size = CalculOriInput(graph);
+  // do not calculate output size
+  auto static_value_size = CalculOriValue(graph);
+  auto total_ori_static_size = static_input_size + static_value_size;
+  return total_ori_static_size;
+}
+
+int64_t MemReuseChecker::CalculOriDy(const KernelGraph *graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  int64_t ori_dy_size = 0;
+  auto kernels = graph->execution_order();
+  for (auto &kernel : kernels) {
+    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+    MS_EXCEPTION_IF_NULL(kernel_mod);
+    for (auto &dy_size : kernel_mod->GetOutputSizeList()) {
+      auto checker_size = SizeToLong(dy_size);
+      ori_dy_size += checker_size;
+    }
+  }
+  return ori_dy_size;
+}
+
+int64_t MemReuseChecker::CalculOriWk(const KernelGraph *graph) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  int64_t ori_wk_size = 0;
+  auto kernels = graph->execution_order();
+  for (auto &kernel : kernels) {
+    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+    MS_EXCEPTION_IF_NULL(kernel_mod);
+    for (auto &wk_size : kernel_mod->GetWorkspaceSizeList()) {
+      auto checker_size = SizeToLong(wk_size);
+      ori_wk_size += checker_size;
+    }
+  }
+  return ori_wk_size;
+}
+
+std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const {
+  auto indx = scope_name.rfind(kSplitC);
+  if (indx == std::string::npos) {
+    return scope_name;
+  } else {
+    if (indx < scope_name.size() - 1) {
+      auto split_name = scope_name.substr(indx + 1);
+      return split_name;
+    }
+    return scope_name;
+  }
+}
+
+void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list,
+                                      const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) {
+  total_ori_static_size_ = CalculOriStatic(graph);
+  total_ori_input_size_ = CalculOriInput(graph);
+  total_ori_value_size_ = CalculOriValue(graph);
+  total_ori_dy_size_ = CalculOriDy(graph);
+  total_ori_wkspace_size_ = CalculOriWk(graph);
+  std::string graph_id = std::to_string(graph->graph_id());
+  std::string filename = "./memreuse_" + graph_id + ".ir";
+  std::ofstream ofs(filename);
+  if (!ofs.is_open()) {
+    MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
+    return;
+  }
+  ofs << "all_tensor_refs:\n";
+  ofs << "index:"
+      << "\tsize:"
+      << "\trefcount:\n";
+  for (auto &ref : total_refs_list) {
+    ofs << "%" << ref->index_ << "T"
+        << "\t"
+        << "#" << ref->size_ << "S"
+        << "\t" << ref->ref_count_ << "C"
+        << "\n";
+  }
+  ofs << "kernel_def exc_order:\n";
+  int def_idx = 0;
+  for (auto &def : kernel_def_ptr_list) {
+    ExportMemOpIr(def.get(), ofs, def_idx);
+    def_idx++;
+  }
+  ofs.close();
+}
+
+void MemReuseChecker::ExportKernelDependence() {
+  std::string filename = "./memreuse_dependence.ir";
+  std::ofstream ofs(filename);
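+  // The dump lists each kernel followed by its transitive "front" set, i.e. every
+  // kernel that must have executed before it (as collected by InitKernelDependence).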
+ if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; + return; + } + size_t i = 0; + for (const auto &kernel_front : kernel_front_map_) { + auto kernel = kernel_front.first; + auto front = kernel_front.second; + ofs << "[" << i++ << "] " << kernel->scope_full_name() << "\n"; + for (const auto &node : front) { + ofs << node->scope_full_name() << "\n"; + } + ofs << "\n\n"; + } + + ofs.close(); +} + +bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph) { + // set real graph output node to be special who's refcount equal kMaxRefCount + for (const auto &output : graph->outputs()) { + MS_EXCEPTION_IF_NULL(output); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { + if (output->isa()) { + auto cnode = output->cast(); + auto input_node = cnode->input(i + 1); + auto kernel_input_with_idx = AnfAlgo::VisitKernel(input_node, 0); + auto kernel_input = kernel_input_with_idx.first; + MS_EXCEPTION_IF_NULL(kernel_input); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_input); + if (kernel_mod == nullptr) { + continue; + } + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + continue; + } + for (size_t j = 0; j < output_sizes.size(); ++j) { + if (!AnfAlgo::OutputAddrExist(kernel_input, j)) { + return false; + } + } + } + } + } + return true; +} + +void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) { + auto scope_name = def->scope_full_name(); + std::string split_name = GetSplitName(scope_name); + ofs << "$" << def_idx << "\t" << split_name << "\t"; + ofs << "inputs["; + for (auto &in : def->inputs_) { + for (auto &in_ref : in.second) { + ofs << "%" << in_ref->index_ << "T" + << ","; + } + } + ofs << "]"; + ofs << "\toutpus["; + for (auto &ou : def->outputs_) { + for (auto &ou_ref : ou.second) { + ofs << "%" << ou_ref->index_ << "T" + << ","; + } + } + ofs << "]"; + ofs << "\tstreamID[" + << "@" << def->stream_id() << "]\n"; +} + +void MemReuseChecker::ExportNormalTensorIR(std::ofstream &ofs) { + ofs << "all_tensor_refs:\n"; + ofs << "index:" + << "\tsize:" + << "\trefcount:\n"; + size_t ou_idx = 0; + for (auto &ou : nor_output_tensors_) { + ofs << "%" << ou_idx << "T" + << "\t" + << "#" << nor_tensor_sizes_[ou_idx] << "S" + << "\t"; + auto iter_ref = ptr_refs_.find(ou); + if (iter_ref != ptr_refs_.end()) { + ofs << iter_ref->second << "C" + << "\n"; + } else { + MS_LOG(EXCEPTION) << "can not find refs for output"; + } + ou_idx++; + } + ofs << "kernel_def exc_order:\n"; +} + +int MemReuseChecker::GetTensorIdx(const void *in) const { + auto iter = ptr_idx_.find(in); + if (iter == ptr_idx_.end()) { + return kInvalidIndex; + } else { + return SizeToInt(iter->second); + } +} + +void MemReuseChecker::ExportNormalOpIr(const std::vector &cnodes) { + std::ofstream ofs("./normal_mem.ir"); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file failed!"; + return; + } + ExportNormalTensorIR(ofs); + size_t node_idx = 0; + for (const auto &node : cnodes) { + MS_EXCEPTION_IF_NULL(node); + ofs << "$" << node_idx << "\t" << GetSplitName(node->fullname_with_scope()) << "\t"; + std::vector in_idx; + auto iter = node_ins_.find(node.get()); + if (iter != node_ins_.end()) { + for (auto &in : iter->second) { + if (GetTensorIdx(in) != kInvalidIndex) { + in_idx.push_back(GetTensorIdx(in)); + } + } + } + std::vector ou_idx; + iter = node_ous_.find(node.get()); + if (iter != node_ous_.end()) { + for (auto &ou : iter->second) { + if (GetTensorIdx(ou) != kInvalidIndex) { + 
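+          // Only addresses registered in ptr_idx_ have a tensor index;
+          // GetTensorIdx returns kInvalidIndex for unknown pointers.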
ou_idx.push_back(GetTensorIdx(ou)); + } + } + } + ofs << "inputs["; + for (auto idx : in_idx) { + bool has_in_ou = std::any_of(ou_idx.begin(), ou_idx.end(), [idx](int odx) { return idx == odx; }); + if (!has_in_ou) { + ofs << "%" << idx << "T,"; + } + } + ofs << "]\toutpus["; + for (auto odx : ou_idx) { + ofs << "%" << odx << "T,"; + } + ofs << "]\tstreamID[@" << AnfAlgo::GetStreamId(node) << "]\n"; + node_idx++; + } + ofs.close(); +} + +void MemReuseChecker::SetTesnorFromAndToInfo(const KernelDef *op_def) { + auto split_name = GetSplitName(op_def->scope_full_name()); + for (auto &in : op_def->inputs_) { + auto in_tensors = in.second; + for (auto &tensor : in_tensors) { + auto indx = tensor->index_; + tensor_to_[indx].push_back(split_name); + } + } + for (auto &ou : op_def->outputs_) { + auto ou_tensors = ou.second; + for (auto &tensor : ou_tensors) { + auto indx = tensor->index_; + tensor_from_[indx].push_back(split_name); + } + } +} + +void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { + const auto &cnodes = graph->execution_order(); + for (const auto &node : cnodes) { + std::vector curr_ous; + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { + auto it = AnfAlgo::GetOutputAddr(node, i); + MS_EXCEPTION_IF_NULL(it); + auto ptr = it->GetPtr(); + nor_output_tensors_.push_back(ptr); + nor_tensor_sizes_.push_back(it->GetSize()); + curr_ous.push_back(it->GetPtr()); + } + (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); + std::vector curr_ins; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { + if (i + 1 >= node->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index: " << i + << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); + } + auto real_input_index = AnfAlgo::GetRealInputIndex(node, i); + auto input = node->input(real_input_index + 1); + MS_EXCEPTION_IF_NULL(input); + auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); + if (kernel_with_index.first->isa()) { + continue; + } + auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, real_input_index); + MS_EXCEPTION_IF_NULL(device_address); + nor_input_tensors_.push_back(device_address->GetPtr()); + curr_ins.push_back(device_address->GetPtr()); + } + (void)node_ins_.insert(std::make_pair(node.get(), curr_ins)); + } + size_t ou_idx = 0; + for (const auto &ou : nor_output_tensors_) { + (void)ptr_idx_.insert(std::make_pair(ou, ou_idx)); + (void)ptr_refs_.insert(std::make_pair(ou, 0)); + ou_idx++; + } + for (const auto &in : nor_input_tensors_) { + if (ptr_idx_.find(in) != ptr_idx_.end()) { + if (ptr_refs_.find(in) != ptr_refs_.end()) { + auto iter = ptr_refs_.find(in); + (iter->second)++; + } else { + MS_LOG(EXCEPTION) << "ptr_refs is not equal to ptr_idx"; + } + } + } + ExportNormalOpIr(cnodes); +} + +void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list) { + std::vector curr_mem_infos; + for (const auto &mem : membuf_ptr_list) { + auto mem_checker = + std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); + curr_mem_infos.push_back(mem_checker); + } + membuf_all_infos_.push_back(curr_mem_infos); + auto split_name = GetSplitName(op_def->scope_full_name()); + all_split_names_.push_back(split_name); + SetTesnorFromAndToInfo(op_def); +} + +void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, + size_t op_idx) { + std::vector add_new_curr_mem; + + for (const auto &mem : membuf_ptr_list) { + auto mem_checker = + 
std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); + add_new_curr_mem.push_back(mem_checker); + } + add_new_mem_infos_.push_back(add_new_curr_mem); + auto split_name = GetSplitName(op_def->scope_full_name()); + add_new_names_.push_back(split_name); + add_new_op_indxs_.push_back(op_idx); + add_new_stream_ids_.push_back(op_def->stream_id()); +} + +void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) { + size_t i = 0; + std::vector each_node_used_size; + std::vector each_node_allocated_size; + for (const auto &curr_membuf_list : membuf_all_infos_) { + ofs << all_split_names_.at(i) << "\n"; + ++i; + ofs << "mem_num\t" + << "stream_id\t" + << "status\t" + << "tensor_idex\t" + << "mem_size\t" + << "mem_head\t" + << "mem_tail\t" + << "mem_type\t" + << "used_kernel\n"; + size_t curr_used = 0; + size_t curr_allocated = 0; + for (size_t j = 0; j < curr_membuf_list.size(); ++j) { + auto membuf = curr_membuf_list.at(j); + auto used_kernel = membuf->used_kernel_->scope_full_name(); + ofs << "&" << j << "\t" + << "streamID[@" << membuf->used_kernel_->stream_id() << "]" + << "\t" + << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" + << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t\t" << membuf->offset_ + membuf->size_ << "\t" + << "\t" << static_cast(membuf->type_) << "\t" << GetSplitName(used_kernel) << "\n"; + if (membuf->status_ == kReused) { + curr_used += membuf->size_; + } + } + if (!curr_membuf_list.empty()) { + curr_allocated = curr_membuf_list.back()->offset_ + curr_membuf_list.back()->size_; + } + each_node_used_size.push_back(curr_used); + each_node_allocated_size.push_back(curr_allocated); + ofs << "curr real used size: \t" << curr_used << "\n"; + ofs << "curr allocated size: \t" << curr_allocated << "\n"; + ofs << "\n\n"; + } + auto optimal_iter = std::max_element(each_node_used_size.begin(), each_node_used_size.end()); + ofs << "theoretical optimal size: " << *optimal_iter << "\n"; + ofs << "each node used size: \n"; + for (auto size : each_node_used_size) { + ofs << size << "\t"; + } + ofs << "\n\n"; + ofs << "each node allocated size: \n"; + for (auto size : each_node_allocated_size) { + ofs << size << "\t"; + } + ofs << "\n\n"; +} + +void MemReuseChecker::ExportMembufInfoIR() { + std::string ir_file_name = "./mem_buf_info.ir"; + std::ofstream ofs(ir_file_name); + int64_t total_reuse_size = 0; + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; + } + ofs << "Total static size:\t" << total_ori_static_size_ << "\n"; + ofs << "Graph inputs size:\t" << total_ori_input_size_ << "\n"; + ofs << "Value nodes size:\t" << total_ori_value_size_ << "\n"; + ofs << "Total dynamic size:\t" << total_ori_dy_size_ << "\n"; + ofs << "Total workspace size:\t" << total_ori_wkspace_size_ << "\n"; + // get last membuf_list + if (membuf_all_infos_.empty()) { + return; + } + auto last_membuf_list = membuf_all_infos_.back(); + for (const auto &membuf : last_membuf_list) { + auto checker_size = SizeToLong(membuf->size_); + total_reuse_size += checker_size; + } + ofs << "After reuse size:\t" << total_reuse_size << "\n\n"; + ExportEachMembufInfo(ofs); + ofs.close(); +} + +void MemReuseChecker::ExportAddNewMmebufIR() { + std::string ir_file_name = "./AddNewMembuf.ir"; + std::ofstream ofs(ir_file_name); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; + } + auto check_idx = add_new_mem_infos_.size(); + if (check_idx == add_new_op_indxs_.size() && 
check_idx == add_new_names_.size() && + check_idx == add_new_stream_ids_.size()) { + size_t i = 0; + for (const auto &curr_membuf_list : add_new_mem_infos_) { + ofs << "op_idx:$" << add_new_op_indxs_.at(i) << "\t" << add_new_names_.at(i) << "\t"; + ofs << "streamID[@" << add_new_stream_ids_.at(i) << "]" + << "\n"; + i++; + ofs << "mem_num\t" + << "status\t" + << "tensor_idex\t" + << "mem_size\t" + << "mem_head\t" + << "mem_tail\t" + << "FromOp\t" + << "ToOp\n"; + for (size_t j = 0; j < curr_membuf_list.size(); ++j) { + auto membuf = curr_membuf_list.at(j); + ofs << "&" << j << "\t" + << "\t" + << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" + << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"; + auto in_idx_iter = tensor_from_.find(membuf->index_); + if (in_idx_iter != tensor_from_.end()) { + for (auto &in_name : in_idx_iter->second) { + ofs << in_name << ","; + } + ofs << "\t"; + } + auto ou_idx_iter = tensor_to_.find(membuf->index_); + if (ou_idx_iter != tensor_to_.end()) { + for (auto &ou_name : ou_idx_iter->second) { + ofs << ou_name << ","; + } + ofs << "\n"; + } + } + ofs << "\n"; + } + } + ofs.close(); +} +} // namespace memreuse +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h new file mode 100644 index 0000000000..3c4a00a3ca --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h new file mode 100644 index 0000000000..3c4a00a3ca --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.h @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ +#include <iostream> +#include <fstream> +#include <string> +#include <map> +#include <set> +#include <vector> +#include "mindspore/core/ir/anf.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +namespace mindspore { +namespace memreuse { +constexpr auto kSend = "Send"; +constexpr auto kRecv = "Recv"; +constexpr auto kSplitC = '/'; +class MemReuseChecker { + public: + bool IsAddNewMembuf_ = false; + static MemReuseChecker &GetInstance(); + MemReuseChecker(const MemReuseChecker &) = delete; + MemReuseChecker &operator=(const MemReuseChecker &) = delete; + void CheckSignalOps(const CNodePtr &c_node); + void CheckWorkSpace(const std::vector<size_t> &max_list); + void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); + bool CheckGraphOutputAssigned(const session::KernelGraph *graph); + void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, + KernelGraph *graph); + int64_t CalculOriStatic(KernelGraph *graph) const; + int64_t CalculOriInput(const KernelGraph *graph) const; + int64_t CalculOriValue(KernelGraph *graph) const; + int64_t CalculOriDy(const KernelGraph *graph) const; + int64_t CalculOriWk(const KernelGraph *graph) const; + std::string GetSplitName(const std::string &scope_name) const; + int GetTensorIdx(const void *in) const; + void SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list); + void SetTesnorFromAndToInfo(const KernelDef *op_def); + void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx); + void ExportNormalOpIr(const std::vector<CNodePtr> &cnodes); + void ExportNormalTensorIR(std::ofstream &ofs); + void CheckNormalIR(const session::KernelGraph *graph); + void ExportMembufInfoIR(); + void ExportEachMembufInfo(std::ofstream &ofs); + void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list, size_t op_idx); + void ExportAddNewMmebufIR(); + void set_kernel_front_map(const std::map<KernelDefPtr, std::set<KernelDefPtr>> &kernel_front_map) { + kernel_front_map_ = kernel_front_map; + } + void ExportKernelDependence(); + + private: + MemReuseChecker() = default; + ~MemReuseChecker() {} + size_t total_re_wkspe_size_checker_{0}; + std::vector<std::vector<MembufPtr>> membuf_all_infos_; + std::vector<const void *> nor_output_tensors_; + std::vector<size_t> nor_tensor_sizes_; + std::vector<const void *> nor_input_tensors_; + std::map<const void *, int> ptr_idx_; + std::map<const void *, int> ptr_refs_; + std::map<int, std::vector<std::string>> node_ins_; + std::map<int, std::vector<std::string>> node_ous_; + std::vector<std::vector<MembufPtr>> add_new_mem_infos_; + std::vector<std::string> add_new_names_; + std::vector<size_t> add_new_op_indxs_; + std::vector<uint32_t> add_new_stream_ids_; + std::vector<std::string> all_split_names_; + std::map<int, std::vector<std::string>> tensor_from_; + std::map<int, std::vector<std::string>> tensor_to_; + std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_; + int64_t total_ori_static_size_ = 0; + int64_t total_ori_input_size_ = 0; + int64_t total_ori_value_size_ = 0; + int64_t total_ori_dy_size_ = 0; + int64_t total_ori_wkspace_size_ = 0; +}; +} // namespace memreuse +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_
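A minimal sketch of how this checker is typically driven from the memory-reuse pipeline (the call site below is illustrative, not part of this patch; all methods called are declared in the header above):

// Illustrative only: dump all diagnostics for one kernel graph.
void DumpReuseDiagnostics(const mindspore::session::KernelGraph *graph) {
  auto &checker = mindspore::memreuse::MemReuseChecker::GetInstance();
  checker.CheckNormalIR(graph);    // tensor-level IR of the graph
  checker.ExportMembufInfoIR();    // writes ./mem_buf_info.ir
  checker.ExportAddNewMmebufIR();  // writes ./AddNewMembuf.ir
}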
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc new file mode 100644 index 0000000000..41bf5460c3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -0,0 +1,344 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/mem_reuse/mem_swap_manager.h" +#include <algorithm> +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace device { +namespace memswap { +void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + graph_manager_ = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(graph_manager_); + auto &kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { + execution_order_.push_back(kernel); + } + } + + size_t kernel_index = 0; + for (const auto &kernel : execution_order_) { + // parse topo order of kernel + (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); + // parse tensor info + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { + TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; + ordered_tensors_.push_back(tensor_info); + } + } + + // parse topo order of user kernel + SaveUserKernelTopoOrder(); + + std::sort(ordered_tensors_.begin(), ordered_tensors_.end(), + [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); + + auto cur_tensor_size = ordered_tensors_.front().tensor_size_; + for (auto &tensor_info : ordered_tensors_) { + if (cur_tensor_size != tensor_info.tensor_size_) { + cur_tensor_size = tensor_info.tensor_size_; + tensor_size_num_++; + } + } + tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; + tensor_size_threshold_idx_ = 0; + + distance_threshold_ = kernel_index / kDistanceInitFactor; + mem_swap_initialized_ = true; + MS_EXCEPTION_IF_NULL(mem_copy_manager_); + mem_copy_manager_->Init(); +} + +bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + NodeUsersMap &user_map = graph_manager_->node_users(); + auto iter = user_map.find(kernel); + bool adjacent_with_communication_op = false; + if (iter != user_map.end()) { + AnfNodeIndexSet node_set = iter->second; + adjacent_with_communication_op = std::any_of( + node_set.begin(), node_set.end(), + [](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); + } + return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; +} + +void MemSwapManager::SaveUserKernelTopoOrder() { + NodeUsersMap &user_map = graph_manager_->node_users(); + for (const auto &kernel : execution_order_) { + auto iter = user_map.find(kernel); + if (iter == user_map.end()) { + continue; + } + AnfNodeIndexSet node_set = iter->second; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + for (auto &node_pair : node_set) { + auto user_kernel = 
node_pair.first; + if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { + continue; + } + + size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); + auto &output_idx = kernel_with_index.second; + if (kernel_with_index.first.get() != kernel.get()) { + MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); + } + for (auto &node_user_pair : kernel_exec_info.node_users_map_) { + std::sort(node_user_pair.second.begin(), node_user_pair.second.end()); + } + } +} + +void MemSwapManager::AddSwapInfo() { + for (const auto &tensor : ordered_tensors_) { + size_t tensor_size = tensor.tensor_size_; + if (tensor_size < tensor_size_threshold_) { + break; + } + + size_t output_idx = tensor.output_idx_; + const AnfNodePtr &kernel = tensor.kernel_; + if (IsCommunicationRelevantOp(kernel)) { + continue; + } + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &node_users_map = kernel_exec_info.node_users_map_; + + auto iter = node_users_map.find(output_idx); + if (iter == node_users_map.end()) { + continue; + } + auto &node_users = iter->second; + bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) || + (node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_); + if (!need_swap) { + continue; + } + AddKernelNeedSwap(kernel, true); + HostAddress host_addr; + host_addr.size = tensor_size; + auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr)); + if (!ret) { + MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; + } + kernel_exec_info.host_addrs_[output_idx] = host_addr; + MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; + if (node_users.size() > 1) { + AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); + AddKernelTriggerSwap(execution_order_[node_users[0]], true); + } else { + AddKernelMemSwapInfo(kernel, mem_swap_out_info); + AddKernelTriggerSwap(kernel, true); + } + + size_t swap_in_order = node_users.size() == 1 ? 
node_users[0] - 1 : node_users[1] - 1; + if (swap_in_order <= kernel_exec_info.topo_order_) { + MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + auto swap_in_kernel = execution_order_[swap_in_order]; + MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; + AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); + AddKernelTriggerSwap(swap_in_kernel, true); + + host_addrs_list_.push_back(host_addr); + } +} + +void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, + const HostAddress &host_address) const { + if (swap_kind == SwapKind::kDeviceToHost) { + mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); + } else if (swap_kind == SwapKind::kHostToDevice) { + mem_copy_manager_->AddMemSwapInTask(device_address, host_address); + } +} + +bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const { + return mem_copy_manager_->SyncMemCopyStream(swap_kind); +} + +DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { + if (swap_kind == SwapKind::kDeviceToHost) { + return mem_copy_manager_->UpdateSwapOutQueue(); + } else { + return mem_copy_manager_->UpdateSwapInQueue(); + } +} + +// retreat to find a workable swap scheme +bool MemSwapManager::RetreatSwapInfo() { + if (!trigger_swap_) { + trigger_swap_ = true; + } + if (swap_info_already_set_) { + ResetSwapInfo(); + if (distance_threshold_ >= kDistanceLowerBound) { + auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; + distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1); + } + + while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { + ++tensor_size_threshold_idx_; + if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { + tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; + break; + } + } + + if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { + MS_LOG(ERROR) << "Retreat swap info failed"; + return false; + } + } else { + swap_info_already_set_ = true; + } + AddSwapInfo(); + return true; +} + +KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter = kernel_execution_info_.find(kernel.get()); + if (iter == kernel_execution_info_.end()) { + MS_LOG(EXCEPTION) << "Can not find execution info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return const_cast<KernelExecutionInfo &>(iter->second); +} + +void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.execution_perform_ = perform; +} + +void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.trigger_swap_ = trigger_swap; +} + +void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + kernel_exec_info.need_swap_ = need_swap; +} + +void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, + const std::pair<float, float> &perform) { + MS_EXCEPTION_IF_NULL(kernel); + kernel_swap_perform_[kernel.get()][output_idx] = perform; +} + +void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { 
MS_EXCEPTION_IF_NULL(kernel); + mem_swap_info_[kernel.get()].push_back(mem_swap_info); +} + +float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.execution_perform_; +} + +bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.trigger_swap_; +} + +bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.need_swap_; +} + +const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter_kernel = kernel_swap_perform_.find(kernel.get()); + if (iter_kernel == kernel_swap_perform_.end()) { + MS_LOG(EXCEPTION) << "Can not find swap performance data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + + auto &perform_map = iter_kernel->second; + auto iter_output = perform_map.find(output_idx); + if (iter_output == perform_map.end()) { + MS_LOG(EXCEPTION) << "Can not find swap performance data of output[" << output_idx << "] of op[" + << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter_output->second; +} + +const std::vector<MemSwapInfo> &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { + MS_EXCEPTION_IF_NULL(kernel); + auto iter = mem_swap_info_.find(kernel.get()); + if (iter == mem_swap_info_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter->second; +} + +void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } + +bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { + auto iter = swap_in_blacklist_.find(device_ptr); + return iter != swap_in_blacklist_.end(); +} + +const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return iter->second; +} + +bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { + return mem_copy_manager_->AllocHostPinnedMem(size, addr); +} + +void MemSwapManager::ReleaseHostPinnedMem() { + for (const auto &host_addr : host_addrs_list_) { + if (host_addr.addr) { + mem_copy_manager_->FreeHostPinnedMem(host_addr.addr); + } + } + host_addrs_list_.clear(); +} + +void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); } + +void MemSwapManager::ResetSwapInfo() { + ClearSwapQueue(); + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + kernel_exec_info.trigger_swap_ = false; + kernel_exec_info.need_swap_ = false; + kernel_exec_info.host_addrs_.clear(); + } + ReleaseHostPinnedMem(); + swap_in_blacklist_.clear(); + mem_swap_info_.clear(); +} +} // namespace memswap +} // namespace device +} // namespace mindspore
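A sketch of the retry loop a device runtime is expected to build around this class (the allocator hook is an assumption for illustration; only Init and RetreatSwapInfo are real API here):

// Illustrative driver, not part of this patch.
bool CompileWithSwap(const mindspore::session::KernelGraph *graph,
                     const mindspore::device::memswap::MemSwapManagerPtr &swap_manager) {
  swap_manager->Init(graph);
  while (!TryAssignAllDeviceMemory(graph)) {  // hypothetical allocator hook
    // Relax the tensor-size and distance thresholds, then regenerate swap points.
    if (!swap_manager->RetreatSwapInfo()) {
      return false;  // no workable swap scheme remains
    }
  }
  return true;
}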
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h new file mode 100644 index 0000000000..d8620c8516 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ + +#include <map> +#include <memory> +#include <utility> +#include <vector> +#include <unordered_map> +#include <unordered_set> +#include "backend/optimizer/mem_reuse/mem_copy_manager.h" + +using PerformPair = std::pair<float, float>; +namespace mindspore { +namespace device { +namespace memswap { +class MemSwapManager { + public: + explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) + : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { + mem_copy_manager_ = mem_copy_manager; + } + + MemSwapManager(const MemSwapManager &) = delete; + + MemSwapManager &operator=(const MemSwapManager &) = delete; + + ~MemSwapManager() = default; + + void Init(const mindspore::session::KernelGraph *kernel_graph); + + void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, + const HostAddress &host_address) const; + + bool SyncMemCopyStream(SwapKind swap_kind) const; + + DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; + + // retreat to find a workable swap scheme + bool RetreatSwapInfo(); + + bool trigger_swap() const { return trigger_swap_; } + + bool mem_swap_init() const { return mem_swap_initialized_; } + + KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; + + void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform); + + float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const; + + void AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const PerformPair &perform); + + const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const; + + bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; + + bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; + + const std::vector<MemSwapInfo> &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; + + void InsertSwapInBlackList(const void *device_ptr); + + bool FindInSwapInBlackList(const void *device_ptr) const; + + const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; + + bool AllocHostPinnedMem(size_t size, void **addr) const; + + void ReleaseHostPinnedMem(); + + void ClearSwapQueue() const; + + private: + void AddSwapInfo(); + + void ResetSwapInfo(); + + void SaveUserKernelTopoOrder(); + + void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); + + void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); + + void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + + bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; + + std::vector<AnfNodePtr> execution_order_; + std::vector<TensorInfo> ordered_tensors_; + std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_; + std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_; + // trigger swap kernel key : MemSwapInfo 
of kernel need to be swapped + std::unordered_map<void *, std::vector<MemSwapInfo>> mem_swap_info_; + std::vector<HostAddress> host_addrs_list_; + std::unordered_set<const void *> swap_in_blacklist_; + + size_t tensor_size_threshold_; + size_t tensor_size_threshold_idx_; + size_t tensor_size_num_; + size_t distance_threshold_; + + MemCopyManagerPtr mem_copy_manager_{nullptr}; + FuncGraphManagerPtr graph_manager_{nullptr}; + bool mem_swap_initialized_{false}; + bool swap_info_already_set_{false}; + bool trigger_swap_{false}; + + static constexpr size_t kDistanceInitFactor = 3; + static constexpr size_t kDistanceLowerBound = 3; +}; +using MemSwapManagerPtr = std::shared_ptr<MemSwapManager>; +} // namespace memswap +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc new file mode 100644 index 0000000000..900dd0d563 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/add_atomic_clean.h" +#include <functional> +#include <memory> +#include <numeric> +#include <vector> +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "utils/log_adapter.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +namespace { + +static std::vector<size_t> g_output_idx; + +bool HasAtomic(const AnfNodePtr &input) { + if (IsPrimitiveCNode(input)) { + const auto &cnode = input->cast<CNodePtr>(); + const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0)); + return prim->HasAttr("atomic_add"); + } + return false; +} + +std::vector<size_t> CalCleanSize(const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(pre_node); + std::vector<size_t> clean_size_list; + // clean output + for (auto &index : g_output_idx) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index); + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); + std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(pre_node, index); + auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); + clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize); + } + MS_LOG(DEBUG) << "Clear output size: " << clean_size_list.size() << ", pre_node: " << pre_node->fullname_with_scope(); + return clean_size_list; +} + +CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, + const mindspore::CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(pre_node); + auto clean_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName); + auto new_value_node = NewValueNode(clean_zero_prim); + std::vector<AnfNodePtr> inputs = {new_value_node}; + CNodePtr clean_zero = kernel_graph->NewCNode(inputs); + AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>(); + clean_zero->set_abstract(abstract); + auto 
builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); + builder->SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); + auto clean_size = CalCleanSize(pre_node); + AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clean_zero); + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(g_output_idx), clean_zero); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); + return clean_zero; +} +} // namespace + +void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + auto &todos = kernel_graph->execution_order(); + for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) { + auto node = *iter; + if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) { + auto fg = GetValueNode<FuncGraphPtr>(node->input(kAnfPrimitiveIndex)); + MS_EXCEPTION_IF_NULL(fg); + auto input = fg->get_return()->input(1); + if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + const auto &cnode = input->cast<CNodePtr>(); + for (size_t i = 0; i < cnode->inputs().size(); ++i) { + if (HasAtomic(cnode->input(i))) { + g_output_idx.push_back(i - 1); + } + } + } else if (HasAtomic(input)) { + g_output_idx.push_back(0); + } + + if (!g_output_idx.empty()) { + auto zero_node = CreateTbeAtomicCleanNode(kernel_graph, node); + auto depend = kernel_graph->NewCNode({NewValueNode(prim::kPrimDepend), node->input(1), zero_node}); + std::vector<AnfNodePtr> new_input = node->inputs(); + new_input[1] = depend; + auto new_cnode = std::make_shared<CNode>(new_input, kernel_graph); + // Set abstract + new_cnode->set_abstract(node->abstract()); + // Set kernel info + new_cnode->set_kernel_info(node->kernel_info_ptr()); + mng->Replace(node, new_cnode); + g_output_idx.clear(); + } + } + } +} +} // namespace opt +} // namespace mindspore
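The rounding in CalCleanSize deserves a worked example. Assuming kMemAlignSize is 512 (the constant lives in utils/utils.h, not in this diff):

//   size = 1000:
//   (1000 + 512 + 31) / 512 * 512 = 1543 / 512 * 512 = 3 * 512 = 1536
// The cleaned region is padded past the next 512-byte boundary, leaving
// headroom beyond a plain align-up, which would give 1024.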
diff --git a/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h new file mode 100644 index 0000000000..7e3fbdb472 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/add_atomic_clean.h @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ + +#include <memory> +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +void AddAtomicClean(const std::shared_ptr<session::KernelGraph> &kernel_graph); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc new file mode 100644 index 0000000000..a485b196af --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include <memory> +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(main); + MS_EXCEPTION_IF_NULL(node); + auto main_kernel_info = main->kernel_info(); + auto node_kernel_info = node->kernel_info(); + if (main_kernel_info == nullptr && node_kernel_info == nullptr) { + return true; + } + if (main_kernel_info != nullptr && node_kernel_info != nullptr) { + return *main_kernel_info == *node_kernel_info; + } + return false; +} +} // namespace + +bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { + MS_EXCEPTION_IF_NULL(main); + MS_EXCEPTION_IF_NULL(node); + + bool replace = false; + if (main->isa<ValueNode>() && node->isa<ValueNode>()) { + auto main_value = GetValueNode(main); + auto node_value = GetValueNode(node); + if (main_value->isa<Primitive>() && node_value->isa<Primitive>()) { + replace = false; + } else if (main_value->isa<tensor::Tensor>() && node_value->isa<tensor::Tensor>()) { + replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); + } else { + replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); + } + } else if (main->isa<CNode>() && node->isa<CNode>()) { + if (!CheckEqualKernelBuildInfo(main, node)) { + replace = false; + } else { + auto c_main = main->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(c_main); + auto c_node = node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(c_node); + const auto &inp1 = c_main->inputs(); + const auto &inp2 = c_node->inputs(); + if (inp1.size() == inp2.size()) { + bool appsame = true; + for (size_t j = 0; j < inp1.size(); j++) { + MS_EXCEPTION_IF_NULL(inp1[j]); + MS_EXCEPTION_IF_NULL(inp2[j]); + if (!(*inp1[j] == *inp2[j])) { + appsame = false; + break; + } + } + replace = appsame; + } + } + } + return replace; +} + +bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto backend_cse = std::make_shared<BackendCSE>(); + return backend_cse->Cse(func_graph, func_graph->manager()); +} +} // namespace opt +} // namespace mindspore
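Condensed, the replacement rule BackendCSE::CheckReplace implements is:

// Two value nodes merge iff they are not primitives, their abstracts match, and
//   - tensor values: their kernel build info is also identical;
//   - other values:  the values themselves compare equal.
// Two CNodes merge iff their kernel build info is identical and their
// inputs are pairwise equal. Kernel build info is the backend addition:
// frontend CSE alone could merge nodes selected onto different kernels.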
diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h new file mode 100644 index 0000000000..bac870e59f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ +#include "backend/optimizer/common/pass.h" +#include "frontend/optimizer/cse.h" + +namespace mindspore { +namespace opt { +class CommonSubexpressionElimination : public Pass { + public: + CommonSubexpressionElimination() : Pass("cse") {} + ~CommonSubexpressionElimination() override = default; + bool Run(const FuncGraphPtr &func_graph) override; +}; + +class BackendCSE : public CSE { + public: + BackendCSE() = default; + ~BackendCSE() override = default; + bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc new file mode 100644 index 0000000000..3ba055880c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -0,0 +1,274 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/pass/communication_op_fusion.h" + +#include +#include +#include + +#include "utils/graph_utils.h" +#include "frontend/operator/ops.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "frontend/parallel/context.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr auto kAttrDefaultGroup = "default_group"; +constexpr auto kAttrDefaultOp = "default_op"; + +kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index, + size_t end_index) { + if (end_index >= communication_op_info.communication_op_nodes.size()) { + MS_LOG(EXCEPTION) << "end index out of vector size"; + } + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); + inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); + builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); + builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); + } + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + +std::string GetFusionGroupKey(const AnfNodePtr &node) { + auto primitive = AnfAlgo::GetCNodePrimitive(node); + MS_EXCEPTION_IF_NULL(primitive); + ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion); + if (attr_fusion == nullptr) { + return ""; + } + int fusion = GetValue(attr_fusion); + if (fusion == 0) { + return ""; + } + std::string group = kAttrDefaultGroup; + ValuePtr attr_group = primitive->GetAttr(kAttrGroup); + if (attr_group != nullptr) { + group = GetValue(attr_group); + } + std::string op = kAttrDefaultOp; + ValuePtr attr_op = primitive->GetAttr(kAttrOp); + if (attr_op != nullptr) { + op = GetValue(attr_op); + } + return group + op + std::to_string(fusion); +} +} // namespace + +bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, + std::vector *segment_index, const std::string &group) const { + MS_EXCEPTION_IF_NULL(segment_num); + MS_EXCEPTION_IF_NULL(segment_index); + size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); + MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size; + + auto parallel_context = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context); + const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); + + size_t 
segments = 0; + if (split_indices.size() != 0) { + uint32_t last_index = 0; + for (size_t i = 0; i < split_indices.size(); ++i) { + uint32_t index = split_indices[i]; + if (index <= last_index || index >= communication_op_node_size) { + MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index; + } + segment_index->push_back(index); + last_index = index; + segments++; + } + if (last_index != communication_op_node_size - 1) { + segment_index->push_back(communication_op_node_size - 1); + segments++; + } + } else { + segments = groups_; + for (size_t i = 0; i < segments - 1; ++i) { + segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1); + } + segment_index->push_back(communication_op_node_size - 1); + } + + if (segments >= communication_op_node_size) { + MS_LOG(INFO) << "fusion not changed: segment_num=" << segments + << ", communication_op_node_size=" << communication_op_node_size; + return false; + } + if (segment_index->at(segments - 1) != communication_op_node_size - 1) { + MS_LOG(EXCEPTION) << "the last segment index is invalid."; + } + for (size_t i = 0; i < segments - 1; ++i) { + if (segment_index->at(i) > segment_index->at(i + 1)) { + MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ " + << i + 1 << "]=" << segment_index->at(i + 1); + } + } + *segment_num = segments; + return true; +} + +AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, + const CommunicationOpInfo &communication_op_info, + size_t start_index, size_t end_index) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto prim = std::make_shared<Primitive>(op_name_); + MS_EXCEPTION_IF_NULL(prim); + std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)}; + // get all inputs of current segment + if (end_index >= communication_op_info.communication_op_nodes.size()) { + MS_LOG(EXCEPTION) << "end index out of vector size"; + } + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + } + AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); + MS_EXCEPTION_IF_NULL(fused_node); + auto kernel_info = std::make_shared<device::KernelInfo>(); + MS_EXCEPTION_IF_NULL(kernel_info); + fused_node->set_kernel_info(kernel_info); + AbstractBasePtrList abstract_list; + for (size_t idx = start_index; idx <= end_index; ++idx) { + auto cnode = communication_op_info.communication_op_nodes[idx]; + MS_EXCEPTION_IF_NULL(cnode); + AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); + AnfAlgo::CopyNodeAttr("op", cnode, fused_node); + AnfAlgo::CopyNodeAttr("group", cnode, fused_node); + abstract_list.push_back(cnode->abstract()); + } + auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); + auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + fused_node->set_abstract(abstract_tuple); + return fused_node; +} + +bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, + size_t segment_num, const std::vector<size_t> &segment_index) const { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + bool changed = false; + size_t start_index = 0; + for (size_t segment_idx = 0; 
segment_idx < segment_num; ++segment_idx) { + size_t end_index = segment_index.at(segment_idx); + if (end_index - start_index < 1) { + start_index = end_index + 1; + continue; + } + AnfNodePtr new_communication_op = + CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index); + // replace old communication op with new communication op + for (auto idx = start_index; idx <= end_index; ++idx) { + std::vector<AnfNodePtr> tuple_getitem_input; + tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem)); + tuple_getitem_input.push_back(new_communication_op); + auto index = NewValueNode(SizeToInt(idx - start_index)); + MS_EXCEPTION_IF_NULL(index); + auto imm = std::make_shared<Int32Imm>(idx - start_index); + MS_EXCEPTION_IF_NULL(imm); + auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(); + MS_EXCEPTION_IF_NULL(abstract_scalar); + index->set_abstract(abstract_scalar); + tuple_getitem_input.push_back(index); + AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input); + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx); + MS_EXCEPTION_IF_NULL(communication_op_node_item); + tuple_getitem->set_abstract(communication_op_node_item->abstract()); + if (!manager->Replace(communication_op_node_item, tuple_getitem)) { + MS_LOG(EXCEPTION) << "manager replace node failed"; + } + } + start_index = end_index + 1; + changed = true; + } + return changed; +} + +bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + const float input_grad_size_num = 0.0; + const float input_grad_time_num = 0.0; + // divide candidate fusion groups with the same (group, op, fusion) attrs; fusion == 0 means no fusion + std::unordered_map<std::string, CommunicationOpInfo> candidate_groups; + std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) { + std::string key = GetFusionGroupKey(node); + if (key.empty()) { + continue; + } + if (candidate_groups.find(key) == candidate_groups.end()) { + CommunicationOpInfo communication_op_info; + candidate_groups[key] = communication_op_info; + } + candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>()); + candidate_groups[key].input_grad_size.push_back(input_grad_size_num); + candidate_groups[key].input_grad_time.push_back(input_grad_time_num); + } + } + // split each candidate group into segments according to the groups_ class member + bool changed = false; + for (auto &it : candidate_groups) { + if (it.second.communication_op_nodes.size() <= 1) { + continue; + } + auto first_node = it.second.communication_op_nodes[0]; + if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr<int>(first_node, kAttrIndex) > 0) { + std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(), + [](const CNodePtr &a, const CNodePtr &b) { + return AnfAlgo::GetNodeAttr<int>(a, kAttrIndex) < AnfAlgo::GetNodeAttr<int>(b, kAttrIndex); + }); + } + size_t segment_num = 0; + std::vector<size_t> segment_index; + if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { + if (DoFusion(func_graph, it.second, segment_num, segment_index)) { + changed = true; + } + } + } + return changed; +} +} // namespace opt +} // namespace mindspore
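A worked example of the segment computation in GetSplitSegments (numbers invented): with 10 communication nodes and split_indices = {3, 7}, segment_index becomes {3, 7, 9} and segment_num is 3, so DoFusion fuses node ranges [0..3], [4..7] and [8..9]; with no split indices and groups_ = 2, segment_index becomes {4, 9}, i.e. two roughly equal halves.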
diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h new file mode 100644 index 0000000000..0e7cf9762d --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.h @@ -0,0 +1,80 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ +#include <string> +#include <utility> +#include <vector> + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/utils.h" + +namespace mindspore { +namespace opt { +struct CommunicationOpInfo { + std::vector<CNodePtr> communication_op_nodes; + std::vector<float> input_grad_size; + std::vector<float> input_grad_time; +}; + +class CommunicationOpFusion : public Pass { + public: + explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) + : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} + ~CommunicationOpFusion() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num, + const std::vector<size_t> &segment_index) const; + AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, + const CommunicationOpInfo &communication_op_info, size_t start_index, + size_t end_index) const; + bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, + std::vector<size_t> *segment_index, const std::string &group) const; + std::string op_name_; + size_t groups_ = 1; +}; + +class AllReduceFusion : public CommunicationOpFusion { + public: + explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} + ~AllReduceFusion() override = default; +}; + +class AllGatherFusion : public CommunicationOpFusion { + public: + explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} + ~AllGatherFusion() override = default; +}; + +class BroadcastFusion : public CommunicationOpFusion { + public: + explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} + ~BroadcastFusion() override = default; +}; + +class ReduceScatterFusion : public CommunicationOpFusion { + public: + explicit ReduceScatterFusion(size_t groups = 1) + : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} + ~ReduceScatterFusion() override = default; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc new file mode 100644 index 0000000000..814ad9567c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); 
+ * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/const_input_to_attr_registry.h" + +#include <utility> + +#include "utils/utils.h" +#include "utils/log_adapter.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { + Register(prim::kPrimCast->name(), {1}); + Register(prim::kPrimAvgPoolGrad->name(), {0}); + Register(prim::kPrimConv2DBackpropInput->name(), {2}); + Register(prim::kPrimConv2DBackpropFilter->name(), {2}); + Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); + Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); + Register(prim::kPrimReshape->name(), {1}); + Register(prim::kPrimReduceMax->name(), {1}); + Register(prim::kPrimReduceMin->name(), {1}); + Register(prim::kPrimReduceSum->name(), {1}); + Register(prim::kPrimReduceMean->name(), {1}); + Register(prim::kPrimGatherV2->name(), {2}); + Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); + Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1}); + Register(prim::kPrimSubscalar->name(), {1}); + Register(prim::kPrimTranspose->name(), {1}); + Register(prim::kPrimUnsortedSegmentSum->name(), {2}); + Register(prim::kPrimOneHot->name(), {1}); + Register(prim::kPrimConcat->name(), {0}); + Register(prim::kPrimCumSum->name(), {1}); + Register(prim::kPrimCumProd->name(), {1}); + Register(prim::kPrimReduceAll->name(), {1}); + Register(prim::kPrimUnsortedSegmentMin->name(), {2}); + Register(kSparseGatherV2, {2}); + Register(kUnsortedSegmentProdOpName, {2}); + Register(kSimpleMeanGradOpName, {1}); + Register(kMeanGradOpName, {1}); + Register(kSliceOpName, {1, 2}); + Register(kSliceGradOpName, {2, 3}); + Register(kTileOpName, {1}); + Register(kScatterNdOpName, {2}); + Register(kStridedSliceAssignOpName, {1, 2, 3}); + Register(kStridedSliceOpName, {1, 2, 3}); + Register(kFlattenGradOpName, {1}); + Register(kExpandDimsOpName, {1}); + Register(kSplitOpName, {0}); + Register(kErfOpName, {1}); + Register(kSparseApplyAdagradOpName, {2}); + Register(kResizeNearestNeighborGradOpName, {1}); + Register(kResizeNearestNeighborV2OpName, {1}); + Register(kResizeNearestNeighborV2GradOpName, {1}); + Register(kApplyRMSPropOpname, {5, 6, 7}); + Register(kResizeBilinearV2OpName, {1}); + Register(kReduceProdOpName, {1}); + Register(kCumprodOpName, {1}); + Register(kSpaceToBatchOpName, {1}); + Register(kBatchToSpaceOpName, {1}); + Register(kPadOpName, {1}); + Register(kPushOpName, {1}); +} + +ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() { + static ConstInputToAttrInfoRegistry instance; + return instance; +} + +void ConstInputToAttrInfoRegistry::Register(const ConstInputToAttrInfoRegister &reg) { + auto op_name = reg.GetOpName(); + if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { + (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " const2attr registered successfully!"; + } +} + +void ConstInputToAttrInfoRegistry::Register(const std::string &op_name, + const std::unordered_set<size_t> 
&input_attr_set) { + if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { + ConstInputToAttrInfoRegister reg(op_name); + (void)reg.SetConstInputToAttr(input_attr_set); + (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); + MS_LOG(DEBUG) << op_name << " const2attr registered successfully!"; + } +} + +bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_name, + ConstInputToAttrInfoRegister *reg) const { + if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { + *reg = op_input_to_attr_map_.at(op_name); + MS_LOG(DEBUG) << op_name << " const2attr found in registry."; + return true; + } + return false; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h similarity index 100% rename from mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h rename to mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.h
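A sketch of how a backend pass consumes this registry (mirroring ConvertConstInputToAttr later in this patch; "MyOp" is a hypothetical op registered with Register("MyOp", {1, 2})):

ConstInputToAttrInfoRegister reg;
if (ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName("MyOp", &reg)) {
  // reg.GetConstInputAttrInfo() == {1, 2}: constant inputs 1 and 2 of the
  // CNode are folded into attributes before kernel selection.
  ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
}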
+ */ +#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/context/ms_context.h" +#include "utils/utils.h" +#include "abstract/abstract_value.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +namespace { +const size_t strides_index = 5; + +bool GetStridesValues(const CNodePtr &strided_slice_grad, ValuePtrList *strides_values) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (strided_slice_grad->size() < 6) { + MS_LOG(DEBUG) << "Op strided_slice_grad's inputs size less than 6, graph not changed"; + return false; + } + auto strides_input = strided_slice_grad->input(strides_index); + MS_EXCEPTION_IF_NULL(strides_input); + auto strides_value_node = strides_input->cast(); + if (strides_value_node == nullptr) { + MS_LOG(DEBUG) << "strides is not a value node."; + return false; + } + auto value = strides_value_node->value(); + if (value == nullptr) { + MS_LOG(DEBUG) << "strides has no value."; + return false; + } + auto value_tuple = value->cast(); + if (value_tuple == nullptr) { + MS_LOG(DEBUG) << "strides is not a value tuple."; + return false; + } + *strides_values = value_tuple->value(); + return true; +} + +bool CheckValues(const ValuePtrList &strides_values) { + if (strides_values.empty()) { + MS_LOG(DEBUG) << "strides_values is empty"; + return false; + } + for (auto &value : strides_values) { + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + auto scalar = value->cast(); + MS_EXCEPTION_IF_NULL(scalar); + if (!scalar->isa()) { + MS_LOG(DEBUG) << "strides value is not a Integer"; + return false; + } + if (GetValue(scalar) != 1) { + MS_LOG(DEBUG) << "StridedSliceGrad has no 1 value"; + return false; + } + } else { + MS_LOG(DEBUG) << "The value " << value << "of tuple is not a scalar"; + return false; + } + } + return true; +} + +bool CheckAttrs(const CNodePtr &strided_slice_grad) { + MS_EXCEPTION_IF_NULL(strided_slice_grad); + if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) || + !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]"; + return false; + } + auto new_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrNewAxisMask); + auto shrink_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrShrinkAxisMask); + if (new_axis_mask != 0 || shrink_axis_mask != 0) { + MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0"; + return false; + } + return true; +} +} // namespace + +const BaseRef ConstToAttrStridedSliceGradPass::DefinePattern() const { + VarPtr Xs = std::make_shared(); + auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); + return VectorRef({strided_slice_grad_prim, Xs}); +} + +const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto strided_slice_grad = node->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_grad); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + + if (ms_context->device_target() == kAscendDevice) { + if (!CheckAttrs(strided_slice_grad)) { + MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; + return nullptr; + } + + ValuePtrList strides_values; + if (!GetStridesValues(strided_slice_grad, 
&strides_values)) { + return nullptr; + } + + if (!CheckValues(strides_values)) { + MS_LOG(INFO) << "Check strides' values failed, graph not changed"; + return nullptr; + } + } + + ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4}); + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h new file mode 100644 index 0000000000..83b44d5f51 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ + +#include <memory> +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConstToAttrStridedSliceGradPass : public PatternProcessPass { + public: + explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) + : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} + ~ConstToAttrStridedSliceGradPass() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_
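The net effect of the pass on a matching node, sketched with invented operands:

// Before: StridedSliceGrad(dout, shapex, begin, end, strides)  with strides == (1, 1, 1)
// After:  StridedSliceGrad(dout)  {shapex: ..., begin: ..., end: ..., strides: (1, 1, 1)}
// Inputs 1..4 become attributes. On Ascend the rewrite is applied only when
// every stride equals 1 and new_axis_mask/shrink_axis_mask are both 0.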
+ */ +#include "backend/optimizer/pass/convert_const_input_to_attr.h" + +#include +#include +#include +#include + +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { + return nullptr; + } + std::vector todos; + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + kernel::GetValidKernelNodes(sub_graph, &todos); + } else { + todos.push_back(node); + } + + for (auto &t : todos) { + CNodePtr cnode = t->cast(); + ConstInputToAttrInfoRegister reg; + if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { + continue; + } + ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h new file mode 100644 index 0000000000..e6def42fa1 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ +#include +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertConstInputToAttr : public PatternProcessPass { + public: + explicit ConvertConstInputToAttr(bool multigraph = true) + : PatternProcessPass("convert_const_input_to_attr", multigraph) {} + ~ConvertConstInputToAttr() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::unordered_map> op_input_attr_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc new file mode 100644 index 0000000000..f204841f3c --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc
new file mode 100644
index 0000000000..f204841f3c
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h"
+
+#include <vector>
+#include <memory>
+#include <utility>
+
+#include "utils/graph_utils.h"
+#include "backend/optimizer/common/helper.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "runtime/device/kernel_info.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
+  MS_EXCEPTION_IF_NULL(value_node);
+  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
+  new_value_node->set_abstract(value_node->abstract());
+  // create kernel_info for new value node
+  auto kernel_info = std::make_shared<device::KernelInfo>();
+  new_value_node->set_kernel_info(kernel_info);
+  // create kernel_build_info for new value node
+  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  // set the format of value_node to DEFAULT_FORMAT
+  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
+  // set value node initial device data type = infer data type
+  std::vector<TypeId> types;
+  for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
+    types.push_back(kTypeUnknown);
+  }
+  kernel_build_info_builder->SetOutputsDeviceType(types);
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
+  return new_value_node;
+}
+
+AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
+  MS_EXCEPTION_IF_NULL(input_node);
+  auto value_node = input_node->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(value_node);
+  auto value = value_node->value();
+  MS_EXCEPTION_IF_NULL(value);
+  tensor::TensorPtr tensor_ptr = nullptr;
+  if (value->isa<Scalar>()) {
+    tensor_ptr = ScalarToTensor(value->cast<ScalarPtr>());
+  } else if (value->isa<ValueTuple>()) {
+    tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
+  } else {
+    MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
+  }
+  if (tensor_ptr == nullptr) {
+    MS_LOG(WARNING) << "Create tensor failed";
+    return nullptr;
+  }
+  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
+  MS_EXCEPTION_IF_NULL(tensor_input);
+  tensor_input->set_abstract(tensor_ptr->ToAbstract());
+  if (kernel_graph != nullptr) {
+    tensor_input = kernel_graph->NewValueNode(tensor_input);
+    kernel_graph->AddValueNodeToGraph(tensor_input);
+  } else {
+    tensor_input = MakeValueNode(tensor_input);
+  }
+  tensor_input->set_scope(input_node->scope());
+  return tensor_input;
+}
+
+AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> new_inputs;
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  auto inputs = cnode->inputs();
+  new_inputs.push_back(inputs[0]);
+  bool need_update = false;
+  // the first input is primitive node which is not the real input
+  for (size_t i = 0; i < inputs.size() - 1; ++i) {
+    auto input_node = inputs[i + 1];
+    if (IsValueNode<Scalar>(input_node) || IsValueNode<ValueTuple>(input_node)) {
+      auto tensor_input = CreateTensorInput(kernel_graph, input_node);
+      if (tensor_input == nullptr) {
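+        // CreateTensorInput can return nullptr (e.g. when CreateTupleTensor cannot build a
+        // tensor from the tuple value); keep the original value node as an input in that case.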
+        new_inputs.push_back(input_node);
+        continue;
+      }
+      new_inputs.push_back(tensor_input);
+      need_update = true;
+    } else {
+      new_inputs.push_back(input_node);
+    }
+  }
+  if (need_update) {
+    MS_EXCEPTION_IF_NULL(func_graph);
+    auto new_cnode = func_graph->NewCNode(new_inputs);
+    MS_EXCEPTION_IF_NULL(new_cnode);
+    new_cnode->set_abstract(cnode->abstract());
+    new_cnode->set_scope(cnode->scope());
+    AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
+    return new_cnode;
+  }
+  return nullptr;
+}
+
+AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) {
+  auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+  MS_EXCEPTION_IF_NULL(sub_graph);
+  auto mng = sub_graph->manager();
+  MS_EXCEPTION_IF_NULL(mng);
+  std::vector<AnfNodePtr> todo;
+  std::vector<std::pair<AnfNodePtr, size_t>> graph_rets;
+  kernel::GetValidKernelNodes(sub_graph, &todo);
+  kernel::GetGraphRealOutput(sub_graph, &graph_rets);
+
+  for (auto &t : todo) {
+    auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast<CNodePtr>());
+    if (t_new_node != nullptr && t_new_node != t) {
+      (void)mng->Replace(t, t_new_node);
+    }
+  }
+
+  return node;
+}
+} // namespace
+
+const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                         const EquivPtr &) const {
+  if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
+    return nullptr;
+  }
+  if (AnfAlgo::IsGraphKernel(node)) {
+    return ProcessGraphKernelOp(node);
+  } else {
+    return ConstInputToTensorInput(func_graph, node->cast<CNodePtr>());
+  }
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h
new file mode 100644
index 0000000000..072652497a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertConstInputToTensorInput : public PatternProcessPass { + public: + explicit ConvertConstInputToTensorInput(bool multigraph = true) + : PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {} + ~ConvertConstInputToTensorInput() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc new file mode 100644 index 0000000000..b96a7af8f3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace opt { +namespace { +bool MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return false; + } + + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + TypeId infer_data_type; + if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { + infer_data_type = kTypeUnknown; + } else { + infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); + } + kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); + return true; +} + +void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, + std::vector *plant_inputs, std::vector *dyn_input_sizes) { + MS_EXCEPTION_IF_NULL(plant_inputs); + MS_EXCEPTION_IF_NULL(dyn_input_sizes); + MS_EXCEPTION_IF_NULL(graph); + auto output_size = AnfAlgo::GetOutputTensorNum(input_node); + dyn_input_sizes->push_back(output_size); + std::vector convert_inputs; + auto kernel_graph = graph->cast(); + 
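+  // A tuple ValueNode is split into one value node per element via SplitTupleValueNodeToNodeList;
+  // any other tuple-output node is expanded through TupleGetItem nodes, one per output index.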
MS_EXCEPTION_IF_NULL(kernel_graph); + if (input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); + } else { + for (size_t index = 0; index < output_size; ++index) { + auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, + {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); + convert_inputs.emplace_back(tuple_get_item); + } + } + (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); +} + +void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(cnode_ptr); + MS_EXCEPTION_IF_NULL(graph); + auto &ori_args = cnode_ptr->inputs(); + if (ori_args.size() < 1) { + return; + } + std::vector plant_inputs; + std::vector dyn_input_sizes; + plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); + for (size_t i = 1; i < ori_args.size(); ++i) { + auto input_node = ori_args[i]; + if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { + auto input_size = AnfAlgo::GetOutputTensorNum(input_node); + dyn_input_sizes.push_back(input_size); + auto cnode = input_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + for (size_t j = 1; j < inputs.size(); ++j) { + MS_EXCEPTION_IF_NULL(inputs[j]); + if (IsValueNode(inputs[j])) { + auto success = MakeValueNode(inputs[j]); + if (!success) { + MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); + } + } + plant_inputs.push_back(inputs[j]); + } + } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { + ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); + } else { + dyn_input_sizes.push_back(-1); + plant_inputs.push_back(input_node); + } + } + // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs. 
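+  // e.g. AddN(MakeTuple(a, b, c)) is re-plumbed to AddN(a, b, c) with dyn_input_sizes = {3};
+  // every plain (non-tuple) input contributes -1, so kernels can recover the original grouping.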
+ if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int s) { return s >= 0; })) { + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr); + cnode_ptr->set_inputs(plant_inputs); + } +} +} // namespace + +const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const { + VarPtr V = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { + return nullptr; + } + if (AnfAlgo::IsGraphKernel(node)) { + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + std::vector todos; + kernel::GetValidKernelNodes(sub_graph, &todos); + for (auto &t : todos) { + ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); + } + } else { + ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); + } + return node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h new file mode 100644 index 0000000000..63d2415dc5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ + +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertTupleInputToDynamicInput : public PatternProcessPass { + public: + explicit ConvertTupleInputToDynamicInput(bool multigraph = true) + : PatternProcessPass("convert_tuple_input_to_dynamic_input", multigraph) {} + + ~ConvertTupleInputToDynamicInput() override = default; + + const BaseRef DefinePattern() const override; + + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc new file mode 100644 index 0000000000..34ba83ef17 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +namespace { +CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(cnode_ptr); + MS_EXCEPTION_IF_NULL(graph); + std::vector convert_inputs = {cnode_ptr->input(0)}; + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { + auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); + if (AnfAlgo::IsTupleOutput(input_node)) { + std::vector types; + std::vector> shapes; + std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; + for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { + make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); + types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); + } + auto make_tuple = graph->NewCNode(make_tuple_inputs_list); + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); + convert_inputs.emplace_back(make_tuple); + } else { + convert_inputs.push_back(input_node); + } + } + return graph->NewCNode(convert_inputs); +} +} // namespace + +const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { + VarPtr V = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + return nullptr; + } + if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { + return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); + })) { + return ConvertTupleInputToMakeTuple(func_graph, cnode); + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h new file mode 100644 index 0000000000..9ff5ca91ed --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H +#define MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H +#include +#include + +#include "ir/anf.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class ConvertTupleOutputToMaketuple : public PatternProcessPass { + public: + explicit ConvertTupleOutputToMaketuple(bool multigraph = true) + : PatternProcessPass("convert_tuple_output_to_maketuple", multigraph) {} + + ~ConvertTupleOutputToMaketuple() override = default; + + const BaseRef DefinePattern() const override; + + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc new file mode 100644 index 0000000000..3ef912bcec --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc @@ -0,0 +1,190 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/optimizer/pass/eliminate_redundant_op.h" +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace opt { +using KernelWithIndex = std::pair; +namespace { +CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector *pass_vector) { + MS_EXCEPTION_IF_NULL(pass_vector); + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsRealCNodeKernel(cnode)) { + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return cnode; + } + + auto input0 = cnode->input(0); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + auto temp_node = cnode->input(index + IntToSize(1)); + MS_EXCEPTION_IF_NULL(temp_node); + pass_vector->push_back(make_pair(cnode, index + IntToSize(1))); + return GetRealPrevCNode(temp_node, 0, pass_vector); + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + auto input2 = cnode->input(2); + MS_EXCEPTION_IF_NULL(input2); + auto value_node = input2->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return GetRealPrevCNode(cnode->input(1), IntToSize(item_idx), pass_vector); + } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + pass_vector->push_back(make_pair(cnode, IntToSize(1))); + return GetRealPrevCNode(cnode->input(1), 0, pass_vector); + } else { + return nullptr; + } +} + +bool TransOpEliminateCondition(const CNodePtr &, const CNodePtr &) { return true; } + +bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { + return HasSymmetricalKernelInfo(node1, node2); +} + +bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { + return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) && + AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0); +} + +const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode, + std::vector *pass_vector) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(pass_vector); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + bool has_depend_node = false; + bool has_node_used_more_than_once = false; + auto &users = manager->node_users(); + + auto pass_size = pass_vector->size(); + for (size_t idx = 1; idx <= pass_size - 1; ++idx) { + auto nd = (*pass_vector)[idx].first; + if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || + AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { + has_depend_node = true; + } + if (users[nd].size() >= 2) { + has_node_used_more_than_once = true; + } + } + + // when no depend node and no node used more than once, no need to rebuild the pass nodes + if (!has_depend_node) { + return prev_cnode->input(1); + } else if (!has_node_used_more_than_once) { + (void)manager->Replace(prev_cnode, prev_cnode->input(1)); + return cnode->input(1); + } else { // rebuild the pass nodes + for (size_t idx = pass_size - 2; idx > 0; --idx) { + auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs()); + new_node->set_input((*pass_vector)[idx].second, + (*pass_vector)[idx + 
1].first->input((*pass_vector)[idx + 1].second)); + (*pass_vector)[idx].first = new_node; + } + return (*pass_vector)[1].first; + } +} +} // namespace + +void EliminateRedundantOp::Init() { + (void)redundant_process_map_.emplace(std::pair( + kFour2FiveOpName, std::pair(kFive2FourOpName, TransOpEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + kFive2FourOpName, std::pair(kFour2FiveOpName, TransOpEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + prim::kPrimCast->name(), std::pair(prim::kPrimCast->name(), CastEliminateCondition))); + (void)redundant_process_map_.emplace(std::pair( + kTransDataOpName, std::pair(kTransDataOpName, TransDataOpEliminateCondition))); +} + +const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { + // match the first name + auto name1 = AnfAlgo::GetCNodeName(cnode); + auto it = redundant_process_map_.find(name1); + if (it == redundant_process_map_.end()) { + return nullptr; + } + std::vector pass_vector; + pass_vector.push_back(make_pair(cnode, 1)); + auto prev_cnode = GetRealPrevCNode(cnode->input(1), 0, &pass_vector); + if (prev_cnode == nullptr) { + return nullptr; + } + // match the second name + auto name2 = AnfAlgo::GetCNodeName(prev_cnode); + if (name2 != it->second.first) { + return nullptr; + } + // match condition + auto condition_func = it->second.second; + if (condition_func == nullptr) { + return nullptr; + } + if (!condition_func(cnode, prev_cnode)) { + return nullptr; + } + + return ProcessMatchedNodes(func_graph, cnode, prev_cnode, &pass_vector); +} + +const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr || func_graph == nullptr) { + return nullptr; + } + + if (AnfAlgo::IsGraphKernel(node)) { + // do eliminate for ops in graph kernel. + auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(sub_graph); + auto mng = sub_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + std::vector todo; + kernel::GetValidKernelNodes(sub_graph, &todo); + for (auto &t : todo) { + CNodePtr t_cnode = t->cast(); + MS_EXCEPTION_IF_NULL(t_cnode); + auto t_new_node = DoEliminate(sub_graph, t_cnode); + if (t_new_node != nullptr && t_new_node != t) { + (void)mng->Replace(t, t_new_node); + } + } + return node; + } + // do eliminate for single op. + return DoEliminate(func_graph, cnode); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h new file mode 100644 index 0000000000..2fb4715cff --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +using ConditionFunc = std::function; +using RedundantOpPair = std::pair; + +class EliminateRedundantOp : public PatternProcessPass { + public: + explicit EliminateRedundantOp(bool multigraph = true) : PatternProcessPass("eliminate_redundant_op", multigraph) { + Init(); + } + ~EliminateRedundantOp() override = default; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void Init(); + const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; + std::unordered_map redundant_process_map_; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc new file mode 100644 index 0000000000..8c6cb4beb5 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.cc @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/erase_visit_attr.h" +#include +#include +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef EraseVisitAttr::DefinePattern() const { + std::shared_ptr V = std::make_shared(Visited); + std::shared_ptr Xs = std::make_shared(); + return VectorRef({V, Xs}); +} + +const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { + if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) { + if (AnfAlgo::IsGraphKernel(node)) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + std::vector todos; + kernel::GetValidKernelNodes(fg, &todos); + for (auto &t : todos) { + AnfAlgo::EraseNodeAttr(kAttrVisited, t); + } + } + AnfAlgo::EraseNodeAttr(kAttrVisited, node); + } else { + AnfAlgo::EraseNodeAttr(kAttrVisited, node); + } + return nullptr; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h new file mode 100644 index 0000000000..37b88a4e39 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/erase_visit_attr.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class EraseVisitAttr : public PatternProcessPass { + public: + explicit EraseVisitAttr(bool multigraph = true) : PatternProcessPass("erase_visit_attr", multigraph) {} + ~EraseVisitAttr() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc new file mode 100644 index 0000000000..32655f1ec2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.cc @@ -0,0 +1,222 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/fuse_basic.h" +#include "backend/optimizer/pass/fuse_graph_kernel.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "vm/segment_runner.h" +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace { +std::vector get_fusable_basic_ops(bool is_before_kernel_select) { + std::vector fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, + prim::kPrimExpandDims}; + if (!is_before_kernel_select) { + fusable_basic_ops.push_back(prim::kPrimCast); + } + return fusable_basic_ops; +} + +IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); + bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + + return is_fusable ? 
FOLLOW : EXCLUDE; +} + +std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { + GraphKernelInfo info; + info.is_before_kernel_select = is_before_kernel_select; + // Search fusable nodes according input direction. + auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); + auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); + if (used_nodes.size() > 1) { + used_nodes = RemoveCircle(used_nodes, false); + } + TopoSortForNodeList(&used_nodes); + return used_nodes; +} + +void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { + AnfNodeSet outputs_set; + for (auto out : *outputs) { + outputs_set.insert(out); + } + + AnfNodePtrList vir_outputs; + std::unordered_map eqv; + auto fg_outputs = fg->output(); + if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { + auto cnode = fg_outputs->cast(); + for (size_t i = 1; i < cnode->size(); ++i) { + vir_outputs.push_back(cnode->input(i)); + } + } else { + vir_outputs.push_back(fg_outputs); + } + + if (vir_outputs.size() != outputs->size()) { + MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; + } + bool has_erase_outs = false; + size_t index = -1; + for (auto it = outputs->begin(); it != outputs->end();) { + index++; + auto out = *it; + eqv[out] = vir_outputs[index]; + auto users = mng->node_users()[out]; + bool is_only_control_depend_use = true; + std::vector control_depend_use_index; + std::vector control_depend_nodes; + AnfNodePtr use_out = nullptr; + for (auto &user : users) { + auto use_node = user.first; + if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { + is_only_control_depend_use = false; + continue; + } + if (outputs_set.count(use_node) != 0) { + use_out = use_node; + } + + if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { + control_depend_nodes.push_back(use_node->cast()); + control_depend_use_index.push_back(user.second); + } + } + + if (is_only_control_depend_use && !control_depend_nodes.empty()) { + MS_EXCEPTION_IF_NULL(use_out); + it = outputs->erase(it); + for (size_t i = 0; i < control_depend_nodes.size(); ++i) { + auto control_depend_node = control_depend_nodes[i]; + std::vector new_control_depend_inputs; + for (size_t j = 0; j < control_depend_node->size(); ++j) { + if (j == control_depend_use_index[i]) { + new_control_depend_inputs.push_back(use_out); + } else { + new_control_depend_inputs.push_back(control_depend_node->input(j)); + } + } + auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); + mng->Replace(control_depend_node, new_control_depend); + has_erase_outs = true; + } + } else { + it++; + } + } + + if (!has_erase_outs) { + return; + } + + AnfNodePtr fg_new_output; + if (outputs->size() > 1) { + std::vector output_args; + output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); + (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), + [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); + // Set output for AnfGraph + fg_new_output = fg->NewCNode(output_args); + } else { + fg_new_output = eqv[(*outputs)[0]]; + } + fg->set_output(fg_new_output, true); +} + +void FuseBasic(const std::shared_ptr &kernel_graph, const std::vector &todos, + std::unordered_set *fused_ops, bool is_before_kernel_select) { + auto mng = kernel_graph->manager(); + for (auto iter = todos.cbegin(); iter != 
todos.cend(); ++iter) { + auto node = (*iter)->cast(); + if (node == nullptr) { + continue; + } + if (fused_ops->count(node)) { + continue; + } + auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select); + bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + if (!is_basic_op || !kernel_graph->nodes().contains(node)) { + continue; + } + + auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); + if (fuse_nodes.size() <= 1) { + continue; + } + + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); + RemoveControlDependOut(fg, &outputs, mng); + auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); + + ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); + + // Set graph kernel attr + std::string fuse_op_name = ""; + for (auto &fuse_node : fuse_nodes) { + fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; + } + fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); + fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); + } +} +} // namespace + +void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + std::unordered_set fused_ops; + auto todos = TopoSort(kernel_graph->get_return()); + std::reverse(todos.begin(), todos.end()); + FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h new file mode 100644 index 0000000000..9b3916fe28 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_basic.h @@ -0,0 +1,29 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace opt { +void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select); +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc new file mode 100644 index 0000000000..e04110d8a0 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.cc @@ -0,0 +1,562 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/pass/fuse_graph_kernel.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "utils/graph_utils.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "vm/segment_runner.h" +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +std::vector get_fusable_basic_ops(bool is_before_kernel_select) { + std::vector fusable_basic_ops = { + prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum, + prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt, + prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual}; + if (!is_before_kernel_select) { + fusable_basic_ops.push_back(prim::kPrimCast); + } + return fusable_basic_ops; +} + +std::vector get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) { + std::vector fusable_basic_ops_with_reduce; + if (!is_before_kernel_select) { + fusable_basic_ops_with_reduce.push_back(prim::kPrimCast); + } + return fusable_basic_ops_with_reduce; +} + +std::vector get_reduce_ops() { + std::vector reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin, + prim::kPrimReduceMax, prim::kPrimReduceAll}; + return reduce_ops; +} + +void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) { + MS_EXCEPTION_IF_NULL(fg); + auto reduce_ops = get_reduce_ops(); + const auto &nodes = fg->nodes(); + info->op_type = ELEWISE; + info->cal_step = -1; + info->reduce_op_num = 0; + for (auto node : nodes) { + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + info->cal_step++; + auto prim = GetValueNode(cnode->input(0)); + if (prim != nullptr) { + bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) { + return op->hash() == prim->hash() && op->name() == prim->name(); + }); + if (is_reudce) { + info->op_type = REDUCE; + info->reduce_op_num++; + } + } + } +} + +bool IsFuse(const GraphKernelInfo &info, const 
AnfNodePtr &node) { + auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); + auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select); + bool is_fusable = false; + if (info.op_type == REDUCE && + (info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) { + is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + } else { + is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), + [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); + } + + return is_fusable; +} + +IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + bool is_fusable = IsFuse(info, node); + return is_fusable ? FOLLOW : EXCLUDE; +} + +IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, + const AnfNodePtr &node) { + if (cur_node == node) { + return FOLLOW; + } + if (AnfAlgo::IsGraphKernel(node)) { + auto cnode = node->cast(); + auto fg = GetValueNode(cnode->input(kAnfPrimitiveIndex)); + auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + MS_EXCEPTION_IF_NULL(fg_attr_val); + auto fg_attr = GetValue(fg_attr_val); + if (fg_attr == kApplyMomentumOpName) { + return FOLLOW; + } + return EXCLUDE; + } + if (!IsPrimitiveCNode(node)) { + return EXCLUDE; + } + + bool is_fusable = IsFuse(info, node); + return is_fusable ? FOLLOW : EXCLUDE; +} + +bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, + std::set *cached_unconnected_set) { + if (!check_node->isa() || AnfAlgo::IsGraphKernel(check_node)) { + return false; + } + + auto cnode = check_node->cast(); + const auto &inputs = cnode->inputs(); + // there is a input not in fused_op_set, but the input depends on the fused_op_set + bool has_circle = false; + for (auto input : inputs) { + if (input->isa() && !fused_op_set.count(input)) { + std::set done; + std::vector todos = {input}; + while (!todos.empty()) { + auto node = todos.back(); + todos.pop_back(); + if (done.count(node) || cached_unconnected_set->count(node)) { + continue; + } + + done.insert(node); + if (fused_op_set.count(node)) { + has_circle = true; + break; + } + + if (node->isa()) { + auto cnode_ptr = node->cast(); + for (auto it : cnode_ptr->inputs()) { + if (it->isa()) { + todos.push_back(it); + } + } + } + } + + if (has_circle) { + return true; + } + cached_unconnected_set->insert(done.begin(), done.end()); + } + } + + return false; +} + +bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { + if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { + auto &inputs = out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + real_outs->push_back(inputs[i]); + } + return true; + } + + if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); + auto fg_out = fg->output(); + if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) { + auto inputs = fg_out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + real_outs->push_back(inputs[i]); + } + return true; + } + } + return false; +} + +std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { + std::set cached_unconnected_set; + std::set 
fused_op_set(fused_op.begin(), fused_op.end()); + auto include = [&fused_op_set](const AnfNodePtr &node) { + if (fused_op_set.count(node)) { + return FOLLOW; + } + return EXCLUDE; + }; + for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { + bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); + // delete the circle node and the node which depend on the circle node in fused op + if (has_circle) { + auto mng = (*iter)->func_graph()->manager(); + std::vector erase_nodes; + if (is_backward) { + erase_nodes = DeepUsersSearch(*iter, include, mng); + } else { + erase_nodes = DeepLinkedGraphSearch(*iter, include); + } + for (auto erase_node : erase_nodes) { + fused_op_set.erase(erase_node); + } + } + } + + std::vector res; + for (auto node : fused_op) { + if (fused_op_set.count(node)) { + res.push_back(node); + } + } + return res; +} + +void TopoSortForNodeList(std::vector *lst) { + if (lst->size() < 2) { + return; + } + + std::vector res; + std::set node_sets(lst->begin(), lst->end()); + std::map> ins; + std::map> outs; + std::queue q; + for (auto node : *lst) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (auto input : cnode->inputs()) { + if (!node_sets.count(input)) { + continue; + } + // out_degree + outs[input].insert(node); + // in_degree + ins[node].insert(input); + } + if (!ins.count(node)) { + ins[node] = {}; + } + } + + for (auto p : ins) { + if (p.second.size() == 0) { + q.push(p.first); + } + } + + while (!q.empty()) { + auto node = q.front(); + q.pop(); + res.push_back(node); + if (!outs.count(node)) { + continue; + } + for (auto out : outs[node]) { + if (!ins.count(out)) { + continue; + } + ins[out].erase(node); + if (ins[out].size() == 0) { + q.push(out); + } + } + } + + lst->assign(res.begin(), res.end()); +} + +std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { + auto func_graph = cnode->func_graph(); + auto graph_kernel_g = GetValueNode(cnode->input(0)); + GraphKernelInfo info; + info.is_before_kernel_select = is_before_kernel_select; + GetGraphKernelInfo(graph_kernel_g, &info); + auto mng = func_graph->manager(); + // Search fusable nodes according input direction. + auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); + auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); + std::reverse(used_nodes.begin(), used_nodes.end()); + // Search fusable nodes according output direction. 
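+  // Fusion candidates are collected in both directions (producers via DeepLinkedGraphSearch above,
+  // consumers via DeepUsersSearch below); merging the two lists can pull in a node that would form
+  // a cycle with the fused subgraph, which is why RemoveCircle runs before the topological re-sort.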
+ auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1); + auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); + + used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); + if (used_nodes.size() > 1) { + used_nodes = RemoveCircle(used_nodes); + } + TopoSortForNodeList(&used_nodes); + return used_nodes; +} + +AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { + auto out_spec = node->abstract(); + if (out_spec->isa()) { + return out_spec->cast()->elements()[output_idx]; + } + return out_spec; +} + +AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, + const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, + bool is_before_kernel_select) { + auto func_node = NewValueNode(fg); + std::vector fn_inputs; + fn_inputs.push_back(func_node); + fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); + auto fuse_cnode = kernel_graph->NewCNode(fn_inputs); + // Set output abstract + if (outputs.size() > 1) { + std::vector out_specs; + for (size_t i = 0; i < outputs.size(); ++i) { + out_specs.push_back(outputs[i]->abstract()); + } + auto out_spec = std::make_shared(out_specs); + fuse_cnode->set_abstract(out_spec); + } else { + fuse_cnode->set_abstract(outputs[0]->abstract()); + } + // Set parameter abstract. + for (size_t i = 0; i < inputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); + auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); + fg->parameters()[i]->set_abstract(input_abs); + if (is_before_kernel_select) { + fg->parameters()[i]->set_kernel_info(std::make_shared()); + } + } + // Set kernel info. + if (!is_before_kernel_select) { + std::vector graph_input_format; + std::vector graph_input_type; + std::vector graph_output_format; + std::vector graph_output_type; + for (size_t i = 0; i < inputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); + auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + graph_input_format.push_back(input_format); + auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + graph_input_type.push_back(input_type); + auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); + fg->parameters()[i]->set_abstract(input_abs); + } + auto new_outputs = outputs; + if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) { + std::vector real_outs; + if (IsMakeTupleOut(outputs[0], &real_outs)) { + new_outputs = real_outs; + } + } + for (size_t i = 0; i < new_outputs.size(); ++i) { + auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0); + auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); + graph_output_format.push_back(output_format); + graph_output_type.push_back(output_type); + } + kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; + graph_info_builder.SetInputsFormat(graph_input_format); + graph_info_builder.SetInputsDeviceType(graph_input_type); + graph_info_builder.SetOutputsFormat(graph_output_format); + graph_info_builder.SetOutputsDeviceType(graph_output_type); + graph_info_builder.SetProcessor(kernel::Processor::AICORE); + graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); + 
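+    // The fused subgraph will be compiled by AKG for the AI Core; FusionType::OPAQUE keeps the
+    // resulting kernel out of later buffer-fusion matching.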
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); + auto graph_selected_info = graph_info_builder.Build(); + AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get()); + } + return fuse_cnode; +} + +void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, + const AnfNodePtrList &outputs) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + // single out + if (outputs.size() == 1) { + mng->Replace(outputs[0], new_fuse_cnode); + return; + } + + std::vector fn_inputs; + for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { + AnfNodePtrList real_outs; + // not make tuple out, replace + if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { + fn_inputs.clear(); + fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); + fn_inputs.push_back(new_fuse_cnode); + fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx)))); + auto new_out = kernel_graph->NewCNode(fn_inputs); + new_out->set_abstract(outputs[out_idx]->abstract()); + mng->Replace(outputs[out_idx], new_out); + continue; + } + + // the out is make tuple , modify the get_item node's value + auto users = mng->node_users()[outputs[out_idx]]; + for (auto &user : users) { + auto use_node = user.first; + if (use_node->isa() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) { + auto get_item_cnode = use_node->cast(); + auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(value_input); + auto value_node = value_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + int new_item_idx = SizeToInt(out_idx) + item_idx; + fn_inputs.clear(); + fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); + fn_inputs.push_back(new_fuse_cnode); + fn_inputs.push_back(NewValueNode(new_item_idx)); + auto new_out = kernel_graph->NewCNode(fn_inputs); + new_out->set_abstract(get_item_cnode->abstract()); + mng->Replace(get_item_cnode, new_out); + } + } + } +} + +AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) { + AnfNodePtrList outs; + auto out_node = (*fg)->output(); + if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { + std::vector output_args; + auto out_cnode = out_node->cast(); + for (auto out : out_cnode->inputs()) { + if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { + auto inputs = out->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + output_args.push_back(inputs[i]); + } + } else { + output_args.push_back(out); + } + } + if (output_args.size() != out_cnode->inputs().size()) { + auto new_out = (*fg)->NewCNode(output_args); + (*mng)->Replace(out_node, new_out); + } + + for (size_t i = 1; i < output_args.size(); ++i) { + outs.push_back(output_args[i]); + } + return outs; + } + + outs.push_back(out_node); + return outs; +} + +AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { + AnfNodePtrList res; + if (outs.size() <= 1) { + return outs; + } + + for (auto out : outs) { + AnfNodePtrList real_outs; + if (IsMakeTupleOut(out, &real_outs)) { + res.insert(res.end(), real_outs.begin(), real_outs.end()); + continue; + } + res.push_back(out); + } + return res; +} + +void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto mng = kernel_graph->manager(); + if (mng == nullptr) { + mng = Manage(kernel_graph, true); + kernel_graph->set_manager(mng); + } + auto &todos = 
kernel_graph->execution_order(); + for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { + auto node = *iter; + if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { + continue; + } + + auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_attr != nullptr) { + auto fg_name = GetValue(fg_attr); + if (graph_kernel_black_list.count(fg_name) != 0) { + continue; + } + } + + auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); + if (fuse_nodes.size() <= 1) { + continue; + } + + FuncGraphPtr fg; + AnfNodePtrList inputs; + AnfNodePtrList outputs; + std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); + + // Remove nest make tuple in outs + auto expand_out = GetExpandOuts(outputs); + auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select); + + ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); + + // Inline origin graphkernel + auto cnodes = fg->GetOrderedCnodes(); + for (const auto &n : cnodes) { + if (!AnfAlgo::IsGraphKernel(n)) { + continue; + } + auto graph_kernel_g = GetValueNode(n->input(0)); + AnfNodePtrList ins; + ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); + auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); + mng->Replace(n, out); + } + + EliminateMakeTuple(&fg, &mng); + // Set graphkernel flag + auto ori_fg = GetValueNode(node->input(kAnfPrimitiveIndex)); + fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h new file mode 100644 index 0000000000..e14661dfdf --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/fuse_graph_kernel.h @@ -0,0 +1,63 @@ + +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include <set>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+enum GraphKernelType {
+  ELEWISE = 0,  // contains only elementwise basic ops
+  REDUCE,       // contains reduce ops
+  CUBE,         // contains cube ops
+};
+struct GraphKernelInfo {
+  GraphKernelType op_type = ELEWISE;
+  bool is_before_kernel_select = false;
+  int reduce_op_num = 0;
+  int cal_step = 0;
+};
+
+// If a reduce graph kernel's calculation step count exceeds this number, do not fuse it.
+const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5;
+// If a reduce graph kernel contains more reduce ops than this number, do not fuse it.
+const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2;
+
+const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
+                                                       "LambNextMV", "LambUpdateWithLR"};
+
+std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward = true);
+
+void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
+
+AnfNodePtr CreateNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const FuncGraphPtr &fg,
+                              const AnfNodePtrList &inputs, const AnfNodePtrList &outputs,
+                              bool is_before_kernel_select);
+
+void ReplaceNewFuseCNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const AnfNodePtr &new_fuse_cnode,
+                         const AnfNodePtrList &outputs);
+
+void FuseGraphKernel(const std::shared_ptr<session::KernelGraph> &kernel_graph, bool is_before_kernel_select = false);
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc
new file mode 100644
index 0000000000..a51a6bab42
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/pass/getitem_tuple.h"
+
+#include <memory>
+#include "frontend/operator/ops.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsC(const BaseRef &n) {
+  MS_EXCEPTION_IF_NULL(n);
+  if (utils::isa<AnfNodePtr>(n)) {
+    AnfNodePtr in = utils::cast<AnfNodePtr>(n);
+    MS_EXCEPTION_IF_NULL(in);
+    return in->isa<ValueNode>();
+  } else {
+    return false;
+  }
+}
+}  // namespace
+
+const BaseRef GetitemTuple::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  VarPtr C = std::make_shared<CondVar>(IsC);
+  return VectorRef({prim::kPrimTupleGetItem, VectorRef({prim::kPrimMakeTuple, Xs}), C});
+}
+
+const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  CNodePtr tuple_getitem = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(tuple_getitem);
+  if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) {
+    MS_LOG(EXCEPTION) << "The input number of tuple_getitem is wrong";
+  }
+  AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(make_tuple_anf);
+  AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(index_node);
+  if (IsValueNode<Int32Imm>(index_node)) {
+    ValueNodePtr value_node = index_node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    int index = GetValue<int>(value_node->value());
+    CNodePtr make_tuple = make_tuple_anf->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(make_tuple);
+    if (make_tuple->inputs().size() > IntToSize(index + 1)) {
+      auto ret = make_tuple->input(IntToSize(index + 1));
+      MS_EXCEPTION_IF_NULL(ret);
+      return ret;
+    }
+  }
+  return nullptr;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h
new file mode 100644
index 0000000000..9a25b924bd
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class GetitemTuple : public PatternProcessPass { + public: + explicit GetitemTuple(bool multigraph = true) : PatternProcessPass("getitem_tuple", multigraph) {} + ~GetitemTuple() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc new file mode 100644 index 0000000000..710e130a85 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/pass/optimize_dependence.h" +#include +#include +#include +#include "backend/optimizer/common/helper.h" +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +constexpr auto kSingleInputIndex = 1; +namespace { +AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + string op_name = AnfAlgo::GetCNodeName(cnode); + // Currently we only eliminate transdata or cast nodes. + if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { + return nullptr; + } + CheckCNodeInputSize(cnode, kSingleInputIndex + 1); + return cnode->input(kSingleInputIndex); +} + +AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { + return nullptr; + } + std::vector new_make_tuple_inputs; + bool need_update = false; + for (const auto &input : cnode->inputs()) { + AnfNodePtr replace_input = GetReplaceNode(input); + // If replace input is not null, it will be the input of the TransData or Cast. 
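+    // A null replace_input means this input is not an eliminable TransData/Cast,
+    // so the original input is kept unchanged.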
+    if (replace_input == nullptr) {
+      new_make_tuple_inputs.push_back(input);
+      continue;
+    }
+    new_make_tuple_inputs.push_back(replace_input);
+    need_update = true;
+  }
+  if (need_update) {
+    auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+    CNodePtr new_make_tuple = nullptr;
+    if (kernel_graph == nullptr) {
+      new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs);
+    } else {
+      new_make_tuple = kernel_graph->NewCNode(cnode);
+    }
+    MS_EXCEPTION_IF_NULL(new_make_tuple);
+    new_make_tuple->set_inputs(new_make_tuple_inputs);
+    auto manager = func_graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    manager->Replace(cnode, new_make_tuple);
+    return new_make_tuple;
+  }
+  return nullptr;
+}
+}  // namespace
+
+const BaseRef OptimizeDependence::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({X, Xs});
+}
+
+const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                             const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto node_name = AnfAlgo::GetCNodeName(node);
+  if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) {
+    return nullptr;
+  }
+  size_t index = 0;
+  auto depend_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(depend_cnode);
+  std::vector<AnfNodePtr> new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)};
+  if (node_name == prim::kPrimDepend->name()) {
+    index = 1;
+    new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend));
+  }
+  if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) {
+    MS_LOG(EXCEPTION) << "The depend node should have at least 2 inputs, but got "
+                      << AnfAlgo::GetInputTensorNum(depend_cnode) << " in node " << depend_cnode->DebugString();
+  }
+  auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode);
+  while (index < input_num) {
+    auto replace_node = GetConvertNode(func_graph, node, index);
+    MS_EXCEPTION_IF_NULL(replace_node);
+    new_depend_inputs.push_back(replace_node);
+    ++index;
+  }
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_depend = nullptr;
+  if (kernel_graph == nullptr) {
+    new_depend = func_graph->NewCNode(new_depend_inputs);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_abstract(node->abstract());
+    new_depend->set_scope(node->scope());
+  } else {
+    new_depend = kernel_graph->NewCNode(depend_cnode);
+    MS_EXCEPTION_IF_NULL(new_depend);
+    new_depend->set_inputs(new_depend_inputs);
+  }
+  return new_depend;
+}
+
+const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                                    const size_t index) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto depend_cnode = node->cast<CNodePtr>();
+  auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index);
+  MS_EXCEPTION_IF_NULL(replacing_node);
+  if (!replacing_node->isa<CNode>()) {
+    return replacing_node;
+  }
+  auto replacing_cnode = replacing_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(replacing_cnode);
+  // Deal with the make_tuple with TransData or Cast inputs.
+  auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode);
+  if (make_tuple_replace_node != nullptr) {
+    return make_tuple_replace_node;
+  }
+  AnfNodePtr replace_node = GetReplaceNode(replacing_cnode);
+  if (replace_node == nullptr) {
+    MS_LOG(DEBUG) << "Cannot find the TransData or Cast with single output node.
Depend node: " << node->DebugString(); + return replacing_node; + } + return replace_node; +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h new file mode 100644 index 0000000000..8ddd4d662e --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class OptimizeDependence : public PatternProcessPass { + public: + explicit OptimizeDependence(bool multigraph = true) : PatternProcessPass("optimize_dependence", multigraph) {} + ~OptimizeDependence() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc new file mode 100644 index 0000000000..cd34464cda --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/optimizer/pass/replace_node_by_proxy.h" +#include +#include +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace opt { +kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector inputs_device_format; + std::vector outputs_device_format; + std::vector inputs_device_type; + std::vector outputs_device_type; + std::vector> outputs_shape; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); + inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); + } + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); + outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); + } + builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); + builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); + builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); + + builder.SetInputsFormat(inputs_device_format); + builder.SetOutputsFormat(outputs_device_format); + builder.SetInputsDeviceType(inputs_device_type); + builder.SetOutputsDeviceType(outputs_device_type); + return builder.Build(); +} + +bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { + CNodePtr cnode = node->cast(); + auto prim = std::make_shared(kEmbeddingLookupProxyOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector proxy_inputs = {NewValueNode(prim)}; + proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); + MS_EXCEPTION_IF_NULL(proxy_node); + + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + proxy_node->set_kernel_info(kernel_info); + + AbstractBasePtrList abstract_list; + AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); + AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); + AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); + abstract_list.push_back(cnode->abstract()); + auto abstract_tuple = std::make_shared(abstract_list); + MS_EXCEPTION_IF_NULL(abstract_tuple); + proxy_node->set_abstract(abstract_tuple); + + auto kernel_build_info = GenerateKernelBuildInfo(cnode); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); + + if (!manager->Replace(cnode, proxy_node)) { + MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; + } + } + } + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h new file mode 100644 index 0000000000..382b08304f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 
Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace opt { +class ReplaceNodeByProxy : public Pass { + public: + explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} + ~ReplaceNodeByProxy() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt new file mode 100644 index 0000000000..b7b791ada9 --- /dev/null +++ b/mindspore/ccsrc/backend/session/CMakeLists.txt @@ -0,0 +1,32 @@ +file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "kernel_graph.cc" + "session_basic.cc" + "session_factory.cc" + "anf_runtime_algorithm.cc" +) + +if (ENABLE_GPU) + file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "gpu_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) +endif () + +if (ENABLE_CPU) + file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "cpu_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) +endif () + +if (ENABLE_D) + file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "ascend_session.cc" + "ascend_control_parser.cc" + "ascend_inference_session.cc" + ) + list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) +endif () + +set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) +add_library(_mindspore_backend_session_obj OBJECT ${_SESSION_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc new file mode 100644 index 0000000000..0e5af203bc --- /dev/null +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -0,0 +1,1121 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/anf_runtime_algorithm.h" +#include +#include +#include +#include +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "frontend/operator/ops.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "runtime/device/device_address.h" +#include "backend/optimizer/common/helper.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "common/utils.h" +#include "common/trans.h" + +namespace mindspore { +namespace session { +using abstract::AbstractTensor; +using abstract::AbstractTuple; +using device::KernelInfo; +using device::ascend::AscendDeviceAddress; +using kernel::KernelBuildInfoPtr; +using kernel::KernelMod; +using kernel::KernelModPtr; +namespace { +std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { + MS_EXCEPTION_IF_NULL(shape); + std::vector shape_size_t; + std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); + return shape_size_t; +} +} // namespace + +KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(0); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimMakeTuple)) { + auto node = cnode->input(index + IntToSize(1)); + MS_EXCEPTION_IF_NULL(node); + return VisitKernel(node, 0); + } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(input2); + auto value_node = input2->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); + } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + return VisitKernel(cnode->input(kRealInputIndexInDepend), 0); + } else { + return std::make_pair(anf_node, index); + } + } else { + MS_LOG(EXCEPTION) << "The input is invalid"; + } +} + +KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, + bool visit_nop_node, + const std::vector &return_types) { + MS_EXCEPTION_IF_NULL(anf_node); + for (const auto &prim_type : return_types) { + if (CheckPrimitiveType(anf_node, prim_type)) { + return std::make_pair(anf_node, index); + } + } + if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + return std::make_pair(anf_node, 0); + } else if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(0); + MS_EXCEPTION_IF_NULL(input0); + if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { + if (cnode->inputs().size() != kTupleGetItemInputSize) { + MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; + } + auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); + MS_EXCEPTION_IF_NULL(input2); + auto value_node = input2->cast(); + MS_EXCEPTION_IF_NULL(value_node); + int item_idx = GetValue(value_node->value()); + return 
VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), + visit_nop_node, return_types); + } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); + } else if (opt::IsNopNode(cnode) && visit_nop_node) { + if (cnode->inputs().size() == 2) { + return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); + } else { + MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; + } + } else { + return std::make_pair(anf_node, index); + } + } else { + MS_LOG(EXCEPTION) << "The input is invalid"; + } +} + +std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, + const std::vector &return_types) { + std::vector ret; + auto return_prim_type = return_types; + // if visited make_tuple should return back + return_prim_type.push_back(prim::kPrimMakeTuple); + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type); + if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { + MS_EXCEPTION_IF_NULL(item_with_index.first); + auto make_tuple = item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t i = 1; i < make_tuple->inputs().size(); i++) { + auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types); + (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret)); + } + return ret; + } + ret.push_back(item_with_index.first); + return ret; +} + +AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return node->input(kAnfPrimitiveIndex); +} + +PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = GetCNodePrimitiveNode(cnode); + MS_EXCEPTION_IF_NULL(attr_input); + auto value_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + auto primitive = value->cast(); + return primitive; +} + +bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); +} + +FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto value_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + return value->cast(); +} + +std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + return primitive->name(); + } + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(func_graph); + return func_graph->ToString(); + } + MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString(); +} + +std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + return 
node->DebugString(); +} + +void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + primitive->set_attr(key, value); + return; + } + // graph kernel cnode. + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + fg->set_attr(key, value); +} + +void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) { + CopyNodeAttr(key, key, from, to); +} + +void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, + const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (!from->isa() || !to->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is " + << to->DebugString(); + } + auto from_primitive = AnfAlgo::GetCNodePrimitive(from); + MS_EXCEPTION_IF_NULL(from_primitive); + auto to_primitive = AnfAlgo::GetCNodePrimitive(to); + MS_EXCEPTION_IF_NULL(to_primitive); + to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key)); +} + +void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (!from->isa() || !to->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is " + << from->DebugString(); + } + auto from_primitive = AnfAlgo::GetCNodePrimitive(from); + MS_EXCEPTION_IF_NULL(from_primitive); + auto to_primitive = AnfAlgo::GetCNodePrimitive(to); + MS_EXCEPTION_IF_NULL(to_primitive); + (void)to_primitive->SetAttrs(from_primitive->attrs()); +} + +void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + primitive->EraseAttr(key); + return; + } + // graph kernel cnode. + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + fg->erase_flag(key); +} + +bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString(); + return false; + } + // single op cnode. + auto primitive = AnfAlgo::GetCNodePrimitive(node); + if (primitive != nullptr) { + return primitive->HasAttr(key); + } + // graph kernel cnode. 
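+  // A graph kernel stores attributes on its sub func_graph rather than on a Primitive,
+  // so fall through to the FuncGraph lookup below.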
+ auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + return fg->has_attr(key); +} + +size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString(); + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + size_t input_num = cnode->inputs().size(); + if (input_num == 0) { + MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"; + } + // exclude intputs[0],which is value_node storing attr,inputs left are real input + return input_num - 1; +} + +size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TypePtr type = node->Type(); + if (type == nullptr) { + return 0; + } + if (type->isa()) { + auto tuple_type = type->cast(); + MS_EXCEPTION_IF_NULL(tuple_type); + return tuple_type->size(); + } else if (type->isa() || type->isa()) { + return 1; + } else if (type->isa()) { + return 0; + } else { + return 1; + } +} + +std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Output index:" << output_idx + << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" + << node->DebugString() << "]"; + } + if (!AnfAlgo::IsRealKernel(node)) { + return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto format = build_info->GetOutputFormat(output_idx); + if (format == kernel::KernelBuildInfo::kInvalidFormat) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid output format"; + } + return format; +} + +std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Input index :" << input_idx + << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" + << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + GetPrevNodeOutputFormat(node, input_idx); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto format = build_info->GetInputFormat(input_idx); + if (format == kernel::KernelBuildInfo::kInvalidFormat) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid input format"; + } + return format; +} + +KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (input_idx + 1 >= cnode->inputs().size()) { + MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); + } + auto node = cnode->input(input_idx + 1); + MS_EXCEPTION_IF_NULL(node); + return VisitKernel(node, 0); +} + +std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return 
AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + abstract::BaseShapePtr base_shape = node->Shape(); + MS_EXCEPTION_IF_NULL(base_shape); + if (base_shape->isa() && output_idx == 0) { + return TransShapeToSizet(base_shape->cast()); + } else if (base_shape->isa()) { + auto tuple_shape = base_shape->cast(); + MS_EXCEPTION_IF_NULL(tuple_shape); + if (output_idx >= tuple_shape->size()) { + MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size() + << "."; + } + auto b_shp = (*tuple_shape)[output_idx]; + if (b_shp->isa()) { + return TransShapeToSizet(b_shp->cast()); + } else if (b_shp->isa()) { + return std::vector(); + } else { + MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx + << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString(); + } + } else if (base_shape->isa()) { + return std::vector(); + } + MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " + << base_shape->ToString(); +} + +std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { + auto format = GetOutputFormat(node, output_idx); + auto infer_shape = GetOutputInferShape(node, output_idx); + if (infer_shape.empty()) { + return infer_shape; + } + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); + } + return trans::TransShapeToDevice(infer_shape, format); +} + +std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { + auto format = GetInputFormat(node, input_idx); + auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); + if (infer_shape.empty()) { + return infer_shape; + } + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); + } + return trans::TransShapeToDevice(infer_shape, format); +} + +std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index:" << input_idx + << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" + << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, input_idx); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + 
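// When the input uses default padding, no axis-reshape metadata is needed and an
+  // empty list is returned below.
+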
MS_EXCEPTION_IF_NULL(build_info); + if (build_info->IsInputDefaultPadding()) { + return {}; + } + return build_info->GetInputReshapeType(input_idx); +} + +std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputReshapeType(node, output_idx); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + if (build_info->IsOutputDefaultPadding()) { + return {}; + } + return build_info->GetOutputReshapeType(output_idx); +} + +TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + TypePtr type_ptr = node->Type(); + MS_EXCEPTION_IF_NULL(type_ptr); + if (type_ptr->isa() && output_idx == 0) { + auto tensor_ptr = type_ptr->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + TypePtr elem = tensor_ptr->element(); + MS_EXCEPTION_IF_NULL(elem); + return elem->type_id(); + } else if (type_ptr->isa()) { + auto tuple_ptr = type_ptr->cast(); + MS_EXCEPTION_IF_NULL(tuple_ptr); + if (output_idx >= tuple_ptr->size()) { + MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); + } + auto tuple_i = (*tuple_ptr)[output_idx]; + MS_EXCEPTION_IF_NULL(tuple_i); + if (tuple_i->isa()) { + auto tensor_ptr = tuple_i->cast(); + MS_EXCEPTION_IF_NULL(tensor_ptr); + TypePtr elem = tensor_ptr->element(); + MS_EXCEPTION_IF_NULL(elem); + return elem->type_id(); + } else if (tuple_i->isa()) { + return tuple_i->type_id(); + } else { + MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); + return tuple_i->type_id(); + } + } else if (type_ptr->isa()) { + return type_ptr->type_id(); + } + return type_ptr->type_id(); +} + +TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); + return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); +} + +TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, output_idx); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto dtype = build_info->GetOutputDeviceType(output_idx); + if (dtype == TypeId::kNumberTypeEnd) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid dtype"; + } + return dtype; +} + +TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { + MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " + << GetInputTensorNum(node) 
<< "#node [ " << node->DebugString() << "]"; + } + if (!IsRealKernel(node)) { + return GetPrevNodeOutputDeviceDataType(node, 0); + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + auto dtype = build_info->GetInputDeviceType(input_idx); + if (dtype == TypeId::kNumberTypeEnd) { + MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" + << " has a invalid dtype"; + } + return dtype; +} + +TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); +} + +// get output device addr of anf_node +const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, + bool visit_nop_node) { + MS_EXCEPTION_IF_NULL(node); + if (opt::IsNopNode(node) && visit_nop_node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() == 2) { + return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); + } else { + MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; + } + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetOutputAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() + << " output addr is not exist"; + } + return addr; +} + +DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, + bool visit_nop_node) { + MS_EXCEPTION_IF_NULL(node); + if (opt::IsNopNode(node) && visit_nop_node) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() == 2) { + return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); + } else { + MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; + } + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetMutableOutputAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() + << " output addr is not exist"; + } + return addr; +} + +// get output device addr of anf_node +bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->OutputAddrExist(output_idx); +} + +const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); +} + +DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node) { + KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); + return 
AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); +} + +// set output device addr of anf_node +void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + if (!kernel_info->SetOutputAddr(addr, output_idx)) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; + } +} + +// set workspace device addr of anf_node +void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; + } +} + +// get workspace device addr of anf_node +DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetWorkspaceAddr(output_idx); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() + << "] workspace addr is not exist"; + } + return addr; +} + +// set infer shapes and types of anf node +void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector &types, + const std::vector> &shapes, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + if (types.size() != shapes.size()) { + MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size(); + } + if (shapes.empty()) { + node->set_abstract(std::make_shared()); + } else if (shapes.size() == 1) { + // single output handle + std::vector shape_int; + std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt); + auto abstract = std::make_shared(TypeIdToType(types[0]), shape_int); + node->set_abstract(abstract); + } else { + // multiple output handle + std::vector abstract_list; + for (size_t i = 0; i < types.size(); ++i) { + std::vector shape_int; + std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt); + abstract_list.push_back(std::make_shared(TypeIdToType(types[i]), shape_int)); + } + auto abstract_tuple = std::make_shared(abstract_list); + node->set_abstract(abstract_tuple); + } +} +// copy an abstract of a node to another node +void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) { + to_node->set_abstract(from_node->abstract()); +} + +kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + // select_kernel_build_info() has checked whether return pointer is null + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->op_pattern(); +} + +// get KernelBuildType of node, such as ATT,RT,FWK and so on +KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + // select_kernel_build_info() has checked whether return pointer is null + auto build_info = kernel_info->select_kernel_build_info(); + 
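// The kernel type records which backend compiled this node; e.g. the TBE-specific
+  // input remapping in GetRealInputIndex below only applies to TBE_KERNEL nodes.
+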
MS_EXCEPTION_IF_NULL(build_info); + return build_info->kernel_type(); +} + +kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->processor(); +} + +kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(build_info); + return build_info->fusion_type(); +} + +// set select kernel_build_info +void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->set_select_kernel_build_info(select_kernel_build_info); +} + +// get select kernel_build_info +KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->GetMutableSelectKernelBuildInfo(); +} + +// get kernelMode +KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->MutableKernelMod(); +} + +// set kernel mod +void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + kernel_info->set_kernel_mod(kernel_mod); +} + +bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real kernel too + if (!node->isa()) { + return true; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString(); + } + auto input = cnode->inputs()[0]; + bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || + IsPrimitive(input, prim::kPrimTensorSummary) || + IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || + IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || + IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || + IsPrimitive(input, prim::kPrimReturn); + return !is_virtual_node; +} + +bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // parameter and value node is not a real cnode kernel + if (!node->isa()) { + return false; + } + // return considered as a real node + if (CheckPrimitiveType(node, prim::kPrimReturn)) { + return true; + } + return IsRealKernel(node); +} + +bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + // graph kernel should be a real cnode kernel. + if (!IsRealCNodeKernel(node)) { + return false; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input = cnode->input(kAnfPrimitiveIndex); + // graph kernel should has func_graph as first input. 
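+  // That is, input(kAnfPrimitiveIndex) must be a ValueNode holding a FuncGraph that
+  // carries the FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute, as checked below.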
+  if (!IsValueNode<FuncGraph>(input)) {
+    return false;
+  }
+
+  auto func_graph = GetValueNode<FuncGraphPtr>(input);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
+}
+
+bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  return node->has_default();
+}
+
+void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  kernel_info->set_stream_id(stream_id);
+}
+
+uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  return kernel_info->stream_id();
+}
+
+void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  kernel_info->set_stream_distinction_label(stream_label);
+}
+
+uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  return kernel_info->stream_distinction_label();
+}
+
+void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  kernel_info->set_graph_id(graph_id);
+}
+
+uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  return kernel_info->graph_id();
+}
+
+bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
+  MS_EXCEPTION_IF_NULL(anf);
+  TypePtr type = anf->Type();
+  MS_EXCEPTION_IF_NULL(type);
+  return type->isa<Tuple>();
+}
+
+AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto get_input_index = index + 1;
+  if (index + 1 > node->inputs().size()) {
+    MS_LOG(EXCEPTION) << "Input index " << get_input_index << " is out of range; the node only has "
+                      << node->inputs().size() << " inputs";
+  }
+  // input 0 is the primitive node
+  return node->input(get_input_index);
+}
+
+bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->isa<ValueNode>()) {
+    return false;
+  }
+  auto kernel_info = node->kernel_info();
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  return kernel_info->is_feature_map();
+}
+
+bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
+  if (!node->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Cannot check the input of a parameter or a value node; only a cnode's input"
+                      << " can be judged to be a feature map";
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_node = cnode->input(input_index + 1);
+  return IsFeatureMapOutput(input_node);
+}
+
+size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
+    {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
+    {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
+    {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
+    {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
+    {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
+    {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
+    {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
+    {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
+    {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
+    {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
+    {prim::kPrimApplyCenteredRMSProp->name(),
+     {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
+  size_t ret = cur_index;
+  auto node_name = AnfAlgo::GetCNodeName(anf_node);
+  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
+    auto find = spec_node_list.find(node_name);
+    if (find != spec_node_list.end()) {
+      ret = find->second[cur_index];
+      MS_LOG(INFO) << "Real input index changed to " << ret << ", node name: " << node_name;
+    }
+  }
+  return ret;
+}
+
+void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(input_node);
+  node->set_input(index + 1, input_node);
+}
+
+bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto kernel_name = AnfAlgo::GetCNodeName(node);
+  if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
+      kernel_name == kReduceScatterOpName) {
+    return true;
+  }
+  return false;
+}
+
+bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
+  auto kernel_name = AnfAlgo::GetCNodeName(node);
+  return kernel_name == kGetNextOpName;
+}
+
+FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto value_node = node->cast<ValueNodePtr>();
+  if (value_node == nullptr) {
+    return nullptr;
+  }
+  auto value = value_node->value();
+  if (value == nullptr) {
+    return nullptr;
+  }
+  auto func_graph = value->cast<FuncGraphPtr>();
+  return func_graph;
+}
+
+std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
+  MS_EXCEPTION_IF_NULL(call_node);
+  if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
+    MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << " is not a call node.";
+  }
+  auto input1 = call_node->input(1);
+  MS_EXCEPTION_IF_NULL(input1);
+  if (input1->isa<ValueNode>()) {
+    auto value_node = input1->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto kernel_graph = value_node->value();
+    MS_EXCEPTION_IF_NULL(kernel_graph);
+    return {kernel_graph->cast<KernelGraphPtr>()};
+  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
+    auto switch_node = input1->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(switch_node);
+    auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
+      auto partial = switch_node->input(input_index);
+      MS_EXCEPTION_IF_NULL(partial);
+      if (IsValueNode<KernelGraph>(partial)) {
+        return GetValueNode<KernelGraphPtr>(partial);
+      }
+      auto partial_cnode = partial->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(partial_cnode);
+      auto graph_node = partial_cnode->input(1);
+      MS_EXCEPTION_IF_NULL(graph_node);
+      auto graph_value_node = graph_node->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(graph_value_node);
+      auto graph_value = graph_value_node->value();
+      MS_EXCEPTION_IF_NULL(graph_value);
+      auto child_graph = graph_value->cast<KernelGraphPtr>();
+      return child_graph;
+    };
+    return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
+  }
+  return {};
+}
+
+bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
+  MS_EXCEPTION_IF_NULL(call_node);
+  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
+    MS_LOG(EXCEPTION) << "Call node should be a 'call', but is a " << call_node->DebugString();
+  }
+  auto input1 = call_node->input(1);
+  if (input1->isa<ValueNode>()) {
+    return false;
+  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
+    return true;
+  }
+  MS_LOG(EXCEPTION) << "Unexpected input1 of call node, input1: " << input1->DebugString();
+}
+
+bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
+  if (shape.empty()) {
+    return true;
+  }
+  return shape.size() == kShape1dDims && shape[0] == 1;
+}
+
+bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) {
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index);
+  if (shape.empty()) {
+    return true;
+  }
+  return shape.size() == kShape1dDims && shape[0] == 1;
+}
+
+void AnfRuntimeAlgorithm::ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list) {
+  std::vector<CNodePtr> all_opt_list;
+  std::vector<CNodePtr> non_opt_list;
+
+  for (const auto &node : *node_list) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
+      all_opt_list.emplace_back(node);
+    } else {
+      non_opt_list.emplace_back(node);
+    }
+  }
+  node_list->clear();
+  std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
+  std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
+}
+
+TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto prim = AnfAlgo::GetCNodePrimitive(node);
+  if (prim == nullptr) {
+    return kTypeUnknown;
+  }
+
+  TypeId except_type = kTypeUnknown;
+  if (prim->GetAttr(kAttrOutputPrecision) != nullptr) {
+    auto output_type_str = GetValue<std::string>(prim->GetAttr(kAttrOutputPrecision));
+    if (output_type_str == "float16") {
+      except_type = kNumberTypeFloat16;
+    } else if (output_type_str == "float32") {
+      except_type = kNumberTypeFloat32;
+    } else {
+      MS_LOG(EXCEPTION) << "The fixed precision must be float16 or float32, but got " << output_type_str;
+    }
+  }
+
+  return except_type;
+}
+
+TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) {
+  if (!node->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not a CNode.";
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (input_idx + 1 >= cnode->inputs().size()) {
+    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
+  }
+  auto input_node = cnode->input(input_idx + 1);
+  MS_EXCEPTION_IF_NULL(input_node);
+  auto kernel_with_index = VisitKernel(input_node, 0);
+  if (!kernel_with_index.first->isa<CNode>()) {
+    return kTypeUnknown;
+  }
+  return GetCNodeOutputPrecision(kernel_with_index.first);
+}
+}  // namespace session
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
new file mode 100644
index 0000000000..6bfc714d66
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -0,0 +1,210 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H +#include +#include +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "ir/dtype.h" +#include "base/base.h" +#include "ir/primitive.h" +#include "runtime/device/device_address.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "frontend/operator/ops.h" +#include "utils/contract.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace session { +using AnfVisitFuncion = std::function; +using KernelWithIndex = std::pair; +class AnfRuntimeAlgorithm { + public: + // get input_anf_node's real kernel by recurse + static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); + static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, + bool visit_nop_node = false, + const std::vector &return_types = { + prim::kPrimMakeTuple}); + static std::vector GetAllOutput(const AnfNodePtr &node, + const std::vector &return_types = {}); + // get cnode primitive + static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); + static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); + static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); + // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple + static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); + // get cnode primitive + static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); + // get kernel_name of anf node + static std::string GetCNodeName(const AnfNodePtr &node); + // get detail info of anf node + static std::string GetNodeDebugString(const AnfNodePtr &node); + // get attr of anf node + template + static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + std::string node_debug_log = node->DebugString(); + MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); + } + // single op cnode. + if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { + return GetValue(primitive->GetAttr(key)); + } + // graph kernel cnode. 
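+    // Editor's sketch of intended usage; the attr key "axis" is only an
+    // illustrative assumption, not an attr guaranteed to exist on any node:
+    //   auto axis = AnfAlgo::GetNodeAttr<int>(cnode, "axis");
+    // For a plain op the attr is read from the front Primitive above; for a
+    // graph kernel it is read from the fused FuncGraph below.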
+ auto fg = GetCNodeFuncGraphPtr(node); + MS_EXCEPTION_IF_NULL(fg); + return GetValue(fg->get_attr(key)); + } + static bool IsTupleOutput(const AnfNodePtr &anf); + // set attr of anf node + static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); + // set attr of key from 'from' node to 'to' node + static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); + // set a new key for attr from 'from' node to 'to' node + static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, + const AnfNodePtr &to); + // set all attrs from 'from' node to 'to' node + static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); + // check whether a cnode has the specified attr. + static bool HasNodeAttr(const std::string &key, const CNodePtr &node); + // delete attr of anf node + static void EraseNodeAttr(const std::string &key, AnfNodePtr node); + // get the num of input real_kernel(which can be build and run in device) + static size_t GetInputTensorNum(const AnfNodePtr &node); + // get the num of output real_kernel(which can be build and run in device) + static size_t GetOutputTensorNum(const AnfNodePtr &node); + // get output format select of anf node + static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); + // get input format select of anf node + static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); + // get prev node output width output index + static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); + // get output format from prev node,input_index is the input index of current node related to prev node + static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); + // get reshape_type of from the output of input node. + static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); + // get output shapes inferred by ME from input nodes. + static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); + // get input shapes inferred by ME from input nodes. 
+ static std::vector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); + // get output shapes which will built and run in device + static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); + // get input shapes which will built and run in device + static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); + // Get Input Padding Axis + static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); + // Get Output Padding Axis + static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); + // get output data type inferred by ME of anf node + static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); + // get output original data type from prev node,input_index is the input index of current node related to prev node + static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); + // get output select data type of anf node + static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); + // get input select data type of anf node + static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); + // get output select data type from prev node,input_index is the input index of current node related to prev node + static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); + // get output device addr of anf_node + static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); + // get mutable output device addr of anf_node + static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); + // check whether output addr is exist or not + static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); + // get address from prev node,input_index is the input index of current node related to prev node + static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, + bool visit_nop_node = true); + static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node = true); + // set output device addr of anf_node + static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); + // set workspace device addr of anf_node + static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); + // get workspace device addr of anf_node + static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); + // set infer shapes and types of anf node + static void SetOutputInferTypeAndShape(const std::vector &types, + const std::vector> &shapes, AnfNode *node); + static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); + // get op pattern of the node + static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); + // get KernelBuildType of node ,such as ATT,RT,FWK and so on + static KernelType GetKernelType(const AnfNodePtr &node); + // get processor type:AICORE,AICPU... + static kernel::Processor GetProcessor(const AnfNodePtr &node); + // get fusion type:AICORE,AICPU... 
+ static kernel::FusionType GetFusionType(const AnfNodePtr &node); + // set select kernel_build_info + static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); + // get select kernel_build_info + static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); + // get kernelMode + static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); + // set kernel mod + static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); + // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too + static bool IsRealKernel(const AnfNodePtr &node); + // checkout whether the anf node is a real kernel that is a cnode and can run on device + static bool IsRealCNodeKernel(const AnfNodePtr &node); + // checkout whether the anf node is a graph kernel. + static bool IsGraphKernel(const AnfNodePtr &node); + // check parameter is weight or data + static bool IsParameterWeight(const ParameterPtr &node); + // set stream id of kernel,which will be set in stream assign and be used in stream generate + static void SetStreamId(uint32_t stream_id, AnfNode *node); + // get stream id + static uint32_t GetStreamId(const AnfNodePtr &node); + // set stream distinction label to distinguish different ops in different streams + static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); + // get stream distinction label + static uint32_t GetStreamDistinctionLabel(const AnfNode *node); + // set graph id + static void SetGraphId(uint32_t graph_id, AnfNode *node); + // get graph id + static uint32_t GetGraphId(const AnfNode *node); + static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); + // charge if the node's output is a feature map output + static bool IsFeatureMapOutput(const AnfNodePtr &node); + // charge if the node's input is from a feature map output + static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); + // get real input index for some tbe ops which input order is different between me and tbe impl + static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); + static bool IsCommunicationOp(const AnfNodePtr &node); + static bool IsGetNext(const NotNull &node); + static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); + static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); + static bool IsSwitchCall(const CNodePtr &call_node); + static bool IsScalarInput(const CNodePtr &cnode, size_t index); + static bool IsScalarOutput(const CNodePtr &cnode, size_t index); + static void ReorderExecList(NotNull *> node_list); + // get fix output precision of cnode. + static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); + // get fix output precision from prev node, input_idx is the input index of current node related to prev node. 
+ static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); +}; +} // namespace session +using AnfAlgo = session::AnfRuntimeAlgorithm; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc new file mode 100644 index 0000000000..656a6b40ed --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -0,0 +1,643 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/ascend_control_parser.h" +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/union_find_set.h" +#include "runtime/device/ascend/ascend_label_assign.h" + +static constexpr size_t kCNodePrim = 0; +static constexpr size_t kCNodeCallArg = 1; +static constexpr size_t kCNodeSwitchCond = 1; +static constexpr size_t kCNodeSwitchTrue = 2; +static constexpr size_t kCNodeSwitchFalse = 3; +static constexpr size_t kCNodeSwitchLength = 4; +static constexpr size_t kCNodePartialLength = 2; +static constexpr size_t kCNodePartialFunc = 1; +static constexpr size_t kCNodeSwitchLayerBranch = 2; +static constexpr size_t kCNodeSwitchLayerLength = 3; + +namespace mindspore { +namespace session { +static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { + auto &nodes = parent_graph->execution_order(); + CNodePtr last_jump_node = nullptr; + for (auto &node : nodes) { + if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) { + if (child_graph->get_start_label() == node->input(kCNodeCallArg)) { + return node; + } + last_jump_node = node; + } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) { + if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || + child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) { + return node; + } + last_jump_node = node; + } + } + if (last_jump_node == nullptr) { + MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); + } + return last_jump_node; +} + +static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, + const NotNull *> memo) { + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + + const std::vector>> &real_inputs = kg->real_inputs(); + for (auto &iter : real_inputs) { + auto ¶ = iter.first; + MS_EXCEPTION_IF_NULL(para); + if (para->isa()) { + union_find_set->Add(para); + } + for (auto &arg : iter.second) { + MS_EXCEPTION_IF_NULL(arg); + if (!arg->isa()) { + continue; + } + union_find_set->Add(arg); + } + } + for (auto &child : kg->child_graph_order()) { + InitUnionFindSet(NOT_NULL(child), union_find_set, memo); + } +} + +static void UnionParentParameter(NotNull kg, const NotNull *> union_find_set, + const NotNull *> memo) { + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + + const std::vector>> 
&real_inputs = kg->real_inputs();
+  for (auto &iter : real_inputs) {
+    auto &para = iter.first;
+    for (auto &arg : iter.second) {
+      MS_EXCEPTION_IF_NULL(arg);
+      if (!arg->isa<Parameter>()) {
+        continue;
+      }
+      if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
+        continue;
+      }
+      union_find_set->Union(arg, para);
+    }
+  }
+  for (auto &child : kg->child_graph_order()) {
+    UnionParentParameter(NOT_NULL(child), union_find_set, memo);
+  }
+}
+
+static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) {
+  UnionFindSet<AnfNodePtr> result;
+  std::set<KernelGraphPtr> memo;
+  InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
+  memo.clear();
+  UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
+  return result;
+}
+
+static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
+                                 const std::set<AnfNodePtr> &parameter_reuse_set,
+                                 const NotNull<std::set<KernelGraphPtr> *> memo) {
+  if (parameter_reuse_set.empty()) {
+    MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty.";
+  }
+  if (memo->find(kg.get()) != memo->end()) {
+    return;
+  }
+  memo->insert(kg.get());
+
+  for (auto &para : parameter_reuse_set) {
+    if (para == main_parameter.get()) {
+      continue;
+    }
+    MS_EXCEPTION_IF_NULL(para);
+    MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
+                 << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
+    kg->ReplaceNode(NOT_NULL(para), main_parameter);
+  }
+
+  for (auto &child : kg->child_graph_order()) {
+    RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo);
+  }
+}
+
+static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
+                                   const std::set<AnfNodePtr> &parameter_reuse_set) {
+  AnfNodePtr main_parameter = key;
+  std::set<AnfNodePtr> root_inputs_set;
+  const auto &root_inputs_vector = root_kg->inputs();
+  root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
+  for (auto &node : parameter_reuse_set) {
+    if (root_inputs_set.find(node) != root_inputs_set.end()) {
+      main_parameter = node;
+      break;
+    }
+  }
+  return main_parameter;
+}
+
+static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
+  auto parameter_reuse_sets = parameter_set->GetSets();
+  for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
+    if (parameter_reuse_set.size() <= 1) {
+      continue;
+    }
+    auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
+    std::set<KernelGraphPtr> memo;
+    RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
+  }
+}
+
+CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
+  // i + 1 < list.size() keeps the original "skip the last node" bound while
+  // avoiding size_t underflow when the list is empty
+  for (size_t i = start; i + 1 < list.size(); ++i) {
+    if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) {
+      return list[i];
+    }
+  }
+  return nullptr;
+}
+
+void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
+  std::set<KernelGraphPtr> memo;
+  (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
+  device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
+  std::map<uint32_t, KernelGraphPtr> graph_id_map;
+  for (auto &g : memo) {
+    MS_EXCEPTION_IF_NULL(g);
+    if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) {
+      MS_LOG(EXCEPTION) << "Two graphs have the same graph id " << g->graph_id()
+                        << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString();
+    }
+    graph_id_map[g->graph_id()] = g;
+  }
+
+  // Insert Assign
+  ChildGraphDataAssign(graph_id_map);
+  // Make UnionFindSet
+  UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
+  // Reuse Parameter
+  ReuseParameter(kg, NOT_NULL(&parameter_set));
+}
+
+void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
+  std::set<KernelGraphPtr> memo;
+  (void)RecurseGraph(root_graph, NOT_NULL(&memo));
+}
+
+void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
+  for (auto &iter : graph_id_map) {
+    auto &kg = iter.second;
+    MS_LOG(INFO) << "Data assign graph: " << kg->graph_id();
+    MS_EXCEPTION_IF_NULL(kg);
+    std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
+    const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
+    for (auto &it : real_inputs) {
+      auto &parameter = it.first;
+      auto &args = it.second;
+      for (auto &arg : args) {
+        MS_EXCEPTION_IF_NULL(arg);
+        if (memo.find({parameter, arg}) != memo.end()) {
+          continue;
+        } else {
+          memo.emplace(parameter, arg);
+        }
+        auto unreuse_args_map = kg->unreuse_args();
+        auto unreuse_arg_iter = unreuse_args_map.find(arg);
+        if (unreuse_arg_iter == unreuse_args_map.end()) {
+          MS_EXCEPTION_IF_NULL(arg);
+          MS_EXCEPTION_IF_NULL(parameter);
+          if (!arg->isa<Parameter>()) {
+            MS_LOG(EXCEPTION) << "Reused arg must be a parameter, arg: " << arg->DebugString() << ".";
+          }
+          MS_LOG(DEBUG) << "Parameter should be reused, no need to insert assign, parameter: "
+                        << parameter->DebugString() << ", arg: " << arg->DebugString();
+          continue;
+        }
+        auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get()));
+        if (target_graph_iter == graph_id_map.end()) {
+          MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
+        }
+        InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
+                                    NOT_NULL(parameter));
+      }
+    }
+    kg->SetExecOrderByDefault();
+  }
+}
+
+NotNull<CNodePtr> AscendControlParser::GetStartLabel(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
+                                                     const CNodePtr &last_label) {
+  CNodePtr start_label;
+  if (last_node != nullptr && last_label != nullptr) {
+    start_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
+    MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString();
+    kg->set_start_label(start_label);
+  } else {
+    // no goto node will jump to the start label of the root graph, so return a fake label
+    start_label = std::make_shared<CNode>(std::vector<AnfNodePtr>(), FuncGraphPtr(nullptr));
+  }
+  return NOT_NULL(start_label);
+}
+
+NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &last_node,
+                                                          const CNodePtr &last_label,
+                                                          const NotNull<std::set<KernelGraphPtr> *> memo) {
+  MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString();
+
+  // 1. recursive condition
+  if (memo->find(kg) != memo->end()) {
+    MS_LOG(INFO) << "KernelGraph has been processed: " << kg->ToString();
+    return NOT_NULL(kg->get_start_label());
+  }
+  memo->insert(kg.get());
+
+  // 2. replace the args placeholders
+  LinkParentGraph(kg, last_node, last_label);
+
+  // 3. topological sort
+  kg->SetExecOrderByDefault();
+  const std::vector<CNodePtr> &nodes = kg->execution_order();
+  // 4. insert first_label
+  CNodePtr start_label = GetStartLabel(kg, last_node, last_label);
+
+  // 5. traverse
+  for (size_t i = 0; i < nodes.size(); ++i) {
+    auto &cnode = nodes[i];
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (cnode->size() < kCNodePrim + 1) {
+      MS_LOG(EXCEPTION) << "Inputs of apply node are empty";
+    }
+    AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
+    if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
+      MS_LOG(DEBUG) << "Continue node " << cnode->DebugString();
+      continue;
+    }
+    AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
+    MS_EXCEPTION_IF_NULL(arg);
+    if (IsValueNode<KernelGraph>(arg)) {
+      RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
+    } else if (!arg->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString();
+    } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitch)) {
+      auto arg_cnode = arg->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(arg_cnode);
+      cnode->set_inputs(arg_cnode->inputs());
+      RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
+    } else if (IsPrimitiveCNode(arg->cast<CNodePtr>(), prim::kPrimSwitchLayer)) {
+      auto arg_cnode = arg->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(arg_cnode);
+      cnode->set_inputs(arg_cnode->inputs());
+      RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
+    }
+  }
+  kg->SetExecOrderByDefault();
+  MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString();
+  return NOT_NULL(start_label);
+}
+
+void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> attch_node) {
+  auto return_node = kg->get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
+                                    return_node->input(kFirstDataInputIndex), attch_node.get()};
+  auto depend_node = kg->NewCNode(inputs);
+  return_node->set_input(1, depend_node);
+}
+
+void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
+                                                     NotNull<AnfNodePtr> second_node) {
+  MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
+               << ", the second node is " << second_node->DebugString();
+  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
+                                    first_node, second_node};
+  auto control_depend = kg->NewCNode(inputs);
+  InsertDependToGraph(kg, NOT_NULL(control_depend));
+}
+
+void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
+                                          const CNodePtr &last_label) {
+  // if this is not the entry graph, replace return with label_goto
+  if (from_graph_call_node != nullptr && last_label != nullptr) {
+    auto label_goto =
+      kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
+    MS_EXCEPTION_IF_NULL(label_goto);
+    MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
+    kg->set_end_goto(label_goto);
+  }
+}
+
+void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
+                                      const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
+  MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
+
+  // 1 get kernel graph
+  const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
+  if (kCNodeCallArg >= origin_inputs.size()) {
+    MS_LOG(EXCEPTION) << "Index out of range, size: " << origin_inputs.size();
+  }
+  std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
+  if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
+    MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
+    return;
+  }
+  // 2 return label
+  auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
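+  // Editor's note (sketch of the rewrite performed below, names from this
+  // function): after step 4 the call site becomes
+  //   LabelGoto(sub_label)   // cur_node, rewritten in place
+  //   LabelSet(back_label)   // return point the callee jumps back to
+  // and the control-depend edges added in step 3 pin back_label between the
+  // rewritten call and next_node.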
+  MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node "
+               << cur_node->DebugString();
+  // 3 add depend relationship
+  InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
+  if (next_node != nullptr && next_node != kg->get_return()) {
+    InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
+  }
+  auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
+  // 4 modify call op to goto op
+  cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
+  // 5 recurse sub graph
+  CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
+  new_inputs.push_back(sub_label);
+  cur_node->set_inputs(new_inputs);
+  cur_node->set_abstract(nullptr);
+  MS_LOG(INFO) << "Succeeded in processing call func " << cur_node->DebugString();
+}
+
+void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
+                                        const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
+  MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
+
+  if (cur_node->size() < kCNodeSwitchLength) {
+    MS_LOG(EXCEPTION) << "Inputs of apply node must be more than " << kCNodeSwitchLength;
+  }
+  // 1 return label
+  auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
+  MS_EXCEPTION_IF_NULL(back_label);
+  MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node "
+               << cur_node->DebugString();
+  // 2 add depend relationship
+  InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
+  if (next_node != nullptr && next_node != kg->get_return()) {
+    InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
+  }
+  // 3 recurse sub graph
+  const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
+  if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
+    MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond;
+  }
+  std::vector<AnfNodePtr> new_switch_inputs = {
+    std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
+    origin_switch_inputs[kCNodeSwitchCond]};
+  for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
+    // 3.1 branch kernel graph and args
+    KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
+    // 3.2 recurse sub graph
+    CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
+    new_switch_inputs.push_back(branch_label);
+  }
+  // swap the true and false labels to match label_switch's branch order
+  std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
+
+  cur_node->set_inputs(new_switch_inputs);
+  cur_node->set_abstract(nullptr);
+  MS_LOG(INFO) << "Succeeded in processing switch func " << cur_node->DebugString();
+}
+
+void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
+                                             const CNodePtr &next_node,
+                                             const NotNull<std::set<KernelGraphPtr> *> memo) {
+  MS_LOG(INFO) << "Process switch layer node " << cur_node->DebugString();
+
+  if (cur_node->size() < kCNodeSwitchLayerLength) {
+    MS_LOG(EXCEPTION) << "Inputs of apply node must be more than " << kCNodeSwitchLayerLength;
+  }
+
+  auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch);
+  MS_EXCEPTION_IF_NULL(branch_tuple);
+  if (!branch_tuple->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode";
+  }
+  const std::vector<AnfNodePtr> &branch_partial = utils::cast<CNodePtr>(branch_tuple)->inputs();
+  // 1 return label
+  auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
+  // 2 add depend relationship
+  InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
+  if (next_node
!= nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + // 3 recurse sub graph + const std::vector &origin_switch_inputs = cur_node->inputs(); + if (kCNodeSwitchCond >= origin_switch_inputs.size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << "."; + } + std::vector new_switch_inputs = { + std::make_shared(std::make_shared(kLabelSwitchOpName)), + origin_switch_inputs[kCNodeSwitchCond]}; + for (size_t i = 0; i < branch_partial.size(); ++i) { + // 3.1 branch kernel graph and args + KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + // 3.2 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); + new_switch_inputs.push_back(branch_label); + } + new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); + cur_node->set_inputs(new_switch_inputs); + cur_node->set_abstract(nullptr); + MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); +} + +KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { + if (!node.get()->isa()) { + if (IsValueNode(node)) { + return GetValueNode(node); + } + MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); + } + // 2.1 branch kernel graph and args + auto partial_cnode = utils::cast(node.get()); + MS_EXCEPTION_IF_NULL(partial_cnode); + if (partial_cnode->size() < kCNodePartialLength) { + MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; + } + + const auto &partial_inputs = partial_cnode->inputs(); + if (kCNodePartialFunc >= partial_inputs.size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; + } + auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); + return branch_kg; +} + +void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, + NotNull to_graph, NotNull from, + NotNull to) { + std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); + std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); + MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; + if (from_outputs.size() != to_outputs.size()) { + MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" + << to_outputs.size() << "]"; + } + for (size_t i = 0; i < from_outputs.size(); i++) { + auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); + if (assign_node != nullptr) { + auto jump_node = GetJumpNode(from_graph, to_graph); + const auto &from_graph_exe_order = from_graph->execution_order(); + auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); + if (jump_node_iter == from_graph_exe_order.end()) { + MS_EXCEPTION_IF_NULL(jump_node); + MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); + } + // insert assign between jump_node -1 and jump_node + if (jump_node_iter != from_graph_exe_order.begin()) { + InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); + } + if (jump_node != nullptr) { + InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); + } + } + } +} + +AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, + 
NotNull to) { + if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && + AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { + return nullptr; + } + if (from.get() == to.get()) { + return nullptr; + } + MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " + << to->DebugString(); + // config inputs of assign node + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; + // generate a new cnode + auto assign_node = kg->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_node); + assign_node->set_abstract(to->abstract()); + return assign_node; +} + +std::vector AscendControlParser::RecurseGraph(NotNull graph, + const NotNull *> memo) { + MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start"; + if (memo->find(graph) != memo->end()) { + return {}; + } + memo->insert(graph.get()); + graph->SetExecOrderByDefault(); + std::vector cnodes = graph->execution_order(); + + auto end_label_goto = graph->get_end_goto(); + if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { + cnodes.pop_back(); + } + AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); + if (end_label_goto != nullptr) { + cnodes.push_back(end_label_goto); + } + + std::vector execution_order; + uint32_t child_order_index = 0; + for (auto &node : cnodes) { + execution_order.push_back(node); + if (node == graph->get_end_goto()) { + continue; + } + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); + for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { + if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + if (child_order_index >= graph->child_graph_order().size()) { + MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); + } + auto child_graph = graph->child_graph_order()[child_order_index++]; + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + auto child_graph = graph->child_graph_order()[child_order_index++]; + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } + } + graph->set_execution_order(execution_order); + graph->PrintGraphExecuteOrder(); + return execution_order; +} + +bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, + NotNull graph) { + const std::vector> &child_graph_order = graph->child_graph_order(); + // check index and child order size + if (child_graph_order.size() <= IntToSize(order_index)) { + MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " + << child_graph_order.size() << " goto index " << order_index; + } + auto child_graph = child_graph_order[order_index]; + MS_EXCEPTION_IF_NULL(child_graph); + + // get start_label_set_index of child graph + auto start_label_set = child_graph->get_start_label(); + uint32_t 
start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); + if (label_index != start_label_set_index) { + MS_EXCEPTION_IF_NULL(cur_label); + MS_EXCEPTION_IF_NULL(start_label_set); + MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() + << " index " << start_label_set_index << " current child graph order : " << order_index; + return false; + } else { + return true; + } +} + +void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { + MS_LOG(INFO) << "Graph id:" << kg->graph_id(); + kg->SetExecOrderByDefault(); + auto call_nodes = kg->FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + std::vector child_graph_order; + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + for (const auto &child_graph : call_child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph != kg->parent_graph()) { + child_graph->set_parent_graph(kg.get()); + } + child_graph_order.push_back(child_graph); + } + } + for (size_t i = 0; i < child_graph_order.size(); i++) { + MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; + } + kg->set_child_graph_order(child_graph_order); +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h new file mode 100644 index 0000000000..bd35d68b36 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H + +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "utils/base_ref.h" +#include "utils/contract.h" +#include "utils/union_find_set.h" + +namespace mindspore { +namespace session { +class AscendControlParser { + public: + static void ChildGraphDataAssign(const std::map &graph_id_map); + static void LinkGraph(NotNull kg); + + static void InsertDependToGraph(NotNull kg, NotNull attch_node); + static void InsertControlDependToGraph(NotNull kg, NotNull first_node, + NotNull second_node); + static void ExecutorValidate(NotNull root_graph); + static void UpdateChildGraphOrder(NotNull kg); + + private: + static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label); + static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, + const CNodePtr &last_label, + const NotNull *> memo); + static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, + const NotNull *> memo); + + static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, + const CNodePtr &last_label); + static KernelGraphPtr ParsePartial(NotNull node); + + static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, + NotNull from, NotNull to); + static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); + + // root graph order + static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, + NotNull graph); + static std::vector RecurseGraph(NotNull graph, + const NotNull *> memo); +}; +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.cc b/mindspore/ccsrc/backend/session/ascend_inference_session.cc new file mode 100644 index 0000000000..d251eb2039 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/ascend_inference_session.h" +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/anf.h" +#include "ir/param_value.h" +#include "runtime/device/kernel_runtime.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "common/trans.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "utils/config_manager.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +namespace session { +void AscendInferenceSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector inputs(inputs_const); + auto input_nodes = kernel_graph->inputs(); + + size_t no_weight_input = 0; + for (size_t i = 0; i < input_nodes.size(); ++i) { + tensor::TensorPtr tensor = nullptr; + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + if (!AnfAlgo::IsParameterWeight(pk_node)) { + tensor = inputs[no_weight_input++]; + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } +} + +GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { + auto graph_id = AscendSession::CompileGraph(func_graph); + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + // load weight data to device + auto input_nodes = kernel_graph->inputs(); + for (size_t i = 0; i < input_nodes.size(); ++i) { + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + if (AnfAlgo::IsParameterWeight(pk_node)) { + const auto ¶m_value = pk_node->default_param(); + MS_EXCEPTION_IF_NULL(param_value); + auto tensor = std::dynamic_pointer_cast(param_value->value()); + MS_EXCEPTION_IF_NULL(tensor); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + return graph_id; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_inference_session.h b/mindspore/ccsrc/backend/session/ascend_inference_session.h new file mode 100644 index 0000000000..5364ae8d4e --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_inference_session.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/session/ascend_session.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/session_factory.h" +#include "backend/session/ascend_control_parser.h" + +namespace mindspore { +namespace session { +class AscendInferenceSession : public AscendSession { + public: + AscendInferenceSession() = default; + ~AscendInferenceSession() = default; + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const; + GraphId CompileGraph(NotNull func_graph) override; +}; +MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc new file mode 100644 index 0000000000..9995518c00 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -0,0 +1,1752 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/ascend_session.h" +#include +#include +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/anf.h" +#include "common/trans.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/device/ascend/kernel_build_ascend.h" +#include "runtime/device/ascend/ascend_kernel_runtime.h" +#include "runtime/device/ascend/ascend_device_address.h" +#include "backend/optimizer/ascend/ascend_backend_optimization.h" +#include "backend/optimizer/common/common_backend_optimization.h" +#include "runtime/device/kernel_adjust.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_label_assign.h" +#include "predict/predict.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/scalar.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "debug/draw.h" +#include "common/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "utils/config_manager.h" +#include "utils/base_ref_extends.h" +#include "debug/tensor_load.h" + +namespace mindspore { +namespace session { +const size_t kInvalidIndex = SIZE_MAX; +constexpr size_t kReturnDataIndex = 1; +namespace { +void DumpGraphExeOrder(const std::vector &execution_order, const std::string &tag = "") { + MS_LOG(INFO) << "Dump execution_order size " << execution_order.size(); + MS_LOG(INFO) << "[index][stream_label][graph_id][node string]"; + int i = 0; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + MS_LOG(INFO) << "[ " << i << "]" + << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]" + << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]" + << "[" << cnode->DebugString() << "]"; + i++; + } + + std::stringstream buf; + buf << "================== execution order ==================\n"; + if (!tag.empty()) { + buf << tag << "\n"; + } + buf << "execution_order size: " << execution_order.size() << "\n"; + i = 0; + for (auto &cnode : execution_order) { + MS_EXCEPTION_IF_NULL(cnode); + buf << i << ":\n"; + buf << "\t" << cnode->DebugString() << "\n"; + buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n"; + buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n"; + i++; + } + buf << "================== execution order ==================\n"; + // std::cout << buf.str() << std::endl; +} + +void DumpGraphInputArgs(const VectorRef &args) { + MS_LOG(INFO) << "Args size[%lu]" << args.size(); + for (size_t i = 0; i < args.size(); i++) { + if (utils::isa(args[i])) { + auto anf = utils::cast(args[i]); + MS_EXCEPTION_IF_NULL(anf); + MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString(); + } else if (utils::isa(args[i])) { + auto value = utils::cast(args[i]); + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString(); + } else { + MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString(); + } + } +} + +void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { + MS_EXCEPTION_IF_NULL(graph); + if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) { + graph->set_stream_distinction_label(label); + } +} + +std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { + MS_EXCEPTION_IF_NULL(graph); + std::vector graph_inputs = graph->inputs(); + auto 
valid_inputs = graph->valid_inputs(); + size_t real_args_size = 0; + std::vector<BaseRef> real_args = {}; + for (size_t i = 0; i < args.size(); i++) { + if (utils::isa<AnfNodePtr>(args[i])) { + auto tmp_args = AnfAlgo::GetAllOutput(utils::cast<AnfNodePtr>(args[i]), {prim::kPrimTupleGetItem}); + for (auto &real_arg : tmp_args) { + auto anf_node = utils::cast(real_arg); + MS_EXCEPTION_IF_NULL(anf_node); + auto abstract = anf_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if the real kernel has a tuple output + if (abstract->isa<abstract::AbstractTuple>() && + !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + real_args_size += tuple_abstract->size(); + continue; + } + real_args_size += 1; + real_args.push_back(real_arg); + } + } else { + real_args_size += 1; + real_args.push_back(args[i]); + } + } + if (graph_inputs.size() != valid_inputs.size()) { + MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size() + << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; + } + if (real_args_size != graph_inputs.size()) { + for (size_t j = 0; j < valid_inputs.size(); j++) { + if (valid_inputs[j]) { + MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); + } + } + MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() + << " not equal"; + } + return real_args; +} + +std::vector<CNodePtr> GetCNodes(const std::vector<AnfNodePtr> &anf_nodes) { + std::vector<CNodePtr> cnodes = {}; + size_t i = 0; + for (const auto &anf : anf_nodes) { + MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString(); + MS_EXCEPTION_IF_NULL(anf); + if (anf->isa<CNode>()) { + cnodes.push_back(anf->cast<CNodePtr>()); + } + } + return cnodes; +} + +static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePtr> &cnodes, + const std::set<PrimitivePtr> &cut_prims) { + size_t after_cut_index = 0; + std::vector<std::vector<CNodePtr>> ret; + for (size_t i = 0; i < cnodes.size(); ++i) { + bool is_cut_node = false; + for (auto &prim : cut_prims) { + if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { + is_cut_node = true; + break; + } + } + if (is_cut_node) { + if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { + // a cut node that is not a call: cut into two lists + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); + after_cut_index = i; + } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { + // a call that is not a switch call: cut into three lists, the call forming a list of its own + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); + ret.emplace_back(1, cnodes[i]); + after_cut_index = i + 1; + continue; + } + } + // get the last child graph list + if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { + ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); + continue; + } + } + return ret; +} + +static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args, + const KernelGraphPtr &graph, KernelGraphPtr child_graph, + const NotNull<std::set<KernelGraphPtr> *> memo) { + MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "Start binding parameters of child graph:" << child_graph->graph_id(); + if (args.empty()) { + return; + } + if (parameters.size() != args.size()) { + MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() + << " and args size:" << args.size() << " not equal!"; + } + child_graph->SetExecOrderByDefault(); + for (size_t i = 0; i < parameters.size(); i++) { + MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ", args[" << i << "]" + << args[i]->DebugString();
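+ // If the arg is literally the same node as the parameter, there is nothing to bind; otherwise SetRealInput below records the arg as the real input behind this parameter, and AddUnreuseArgs appears to mark args that cannot be reused across visits (inferred from the branches that follow).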
+ if (args[i] == parameters[i]) { + MS_LOG(INFO) << "Parameter and arg are the same."; + continue; + } + child_graph->SetRealInput(parameters[i], args[i]); + if (memo->find(child_graph) != memo->end() || !args[i]->isa()) { + MS_LOG(INFO) << "Add unreused arg, graph:" << graph->graph_id(); + child_graph->AddUnreuseArgs(args[i], graph); + } + } +} + +// If a call has kernel inputs, it is a child graph split from ME, so these kernel inputs should be set as the real +// inputs of the graph. For example, if the call inputs are (prim, graph, kernel1, kernel2), then real_input = +// [kernel1, kernel2]. +static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag, + const NotNull<std::set<KernelGraphPtr> *> memo) { + MS_EXCEPTION_IF_NULL(memo.get()); + auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); + if (child_graphs.size() == 1) { + MS_EXCEPTION_IF_NULL(child_graphs[0]); + std::vector<AnfNodePtr> real_args = + std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end()); + std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs(); + BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); + if (split_flag) { + call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2)); + } + } else if (child_graphs.size() == 2) { + auto get_partial_args = [&](size_t input_index) -> std::vector<AnfNodePtr> { + auto switch_node = call_node->input(1); + MS_EXCEPTION_IF_NULL(switch_node); + auto switch_cnode = switch_node->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(switch_cnode); + auto partial = switch_cnode->input(input_index); + MS_EXCEPTION_IF_NULL(partial); + if (IsValueNode<KernelGraph>(partial)) { + return {}; + } + auto partial_cnode = partial->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(partial_cnode); + auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); + if (split_flag) { + partial_cnode->set_inputs( + std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); + } + return ret; + }; + BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); + BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); + } + } +} + +static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, + const NotNull<std::set<KernelGraphPtr> *> memo) { + memo->insert(graph.get()); + MS_LOG(INFO) << "Start graph id:" << graph->graph_id(); + for (auto &child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) != memo->end()) { + MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() + << ", parent graph:" << graph->parent_graph()->graph_id(); + continue; + } + RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); + } + // this action should be applied from bottom to top + graph->UpdateCallRealInput(); +} +} // namespace + +GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + MS_LOG(INFO) << "Start"; + // construct the graph; on success, graph_sum_ increases by 1 + auto graph = ConstructKernelGraph(lst, outputs); + auto graph_id = graph->graph_id(); + MS_LOG(INFO) << "Compile graph " << graph_id << " success"; + return graph_id; +} + +GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { + MS_LOG(INFO) << "Start"; + std::vector<KernelGraphPtr> all_graphs; + auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); + BackendOptimization(all_graphs); + // split switch + SplitGraphs(NOT_NULL(root_graph)); + // an empty graph does not enter the backend + if
(root_graph->execution_order().empty()) { + MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; + root_graph->set_executable(false); + InitRuntimeResource(); + return root_graph->graph_id(); + } + // insert goto labels and label_sets + LinkChildGraphs(NOT_NULL(root_graph)); + // resource initialize + InitRuntimeResource(); + // recursively compile the child graphs of root_graph + std::set<KernelGraphPtr> memo; + RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); + // validate the root graph, including generating the execution order and so on + RootGraphExecutorValidate(NOT_NULL(root_graph)); + // adjust kernel + AdjustKernel(root_graph); + // assign stream + AssignStream(NOT_NULL(root_graph)); + // insert profiling point + device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); + // build kernel + BuildKernel(root_graph); + // alloc mem + MemoryAlloc(root_graph.get()); + // task generate + GenerateTaskInfo(root_graph); + // load task into device + LoadTask(root_graph); + DumpAllGraphs(all_graphs); + // return the root_graph id to backend + auto graph_id = root_graph->graph_id(); + return graph_id; +} + +void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto graph_order = GetGraphOrder(kernel_graph->graph_id()); + for (auto graph_id : graph_order) { + auto child_graph = GetGraph(graph_id); + if (child_graph == nullptr) { + continue; + } + if (child_graph->summary_node_exist()) { + kernel_graph->set_summary_node_exist(true); + return; + } + } + kernel_graph->set_summary_node_exist(false); +} + +void AscendSession::BuildGraph(GraphId graph_id) { + MS_LOG(INFO) << "Start"; + auto graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(graph); + // resource initialize + InitRuntimeResource(); + // handle multiple graphs + if (graph_id == final_graph_id_) { + if (!graph->executable()) { + return; + } + // insert assigns to child graph + InsertAllAssigns(); + // insert switch and active to child graph + MergeSwitchCompile(); + SetFinalGraphSummaryFlag(graph); + // OptChildGraphs + auto graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + for (size_t i = 0; i < graph_order.size(); i++) { + if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { + continue; + } + MS_LOG(INFO) << "Start build child graph " << graph_order[i]; + auto child_graph = GetGraph(graph_order[i]); + CompileChildGraph(child_graph); + } + GetSummaryNodes(graph.get()); + // merge child graph + MergeGraphExecOrder(); + } else { + auto single_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(single_graph); + CompileChildGraph(single_graph); + // set the distinction label of the single graph + single_graph->set_stream_distinction_label(graph_id); + single_graph->UpdateExecuteKernelStreamLabel(); + } + // adjust execution order after merging child graphs and other special operations + AdjustKernel(graph); + // Assign streams for control sink and hccl and so on + AssignStream(NOT_NULL(graph)); + + device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); + // build kernel if node is cnode + BuildKernel(graph); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->precompile_only()) { + MS_LOG(INFO) << "Precompile only, stop in build kernel step"; + } else { + // alloc memory, including static memory and dynamic memory + MemoryAlloc(graph.get()); + // generate task info for task sink mode + GenerateTaskInfo(graph); + // load task info to device if it is sink
mode + LoadTask(graph); + } + // sync the initial const tensors to device + SyncInitialTenosrToDevice(); + DumpAllGraphs({graph}); + MS_LOG(INFO) << "End"; +} + +void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { + MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); + opt::AscendBackendIRFusionOptimization(child_graph); + opt::AscendBackendFuseBasicOpt(child_graph, true); + opt::AscendBackendGraphKernelOpt(child_graph, true); + child_graph->SetExecOrderByDefault(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; + DumpIR(file_path, child_graph); + } + // select kernel build info + SelectKernel(*child_graph); + if (save_graphs) { + std::string file_path = + save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; + DumpIR(file_path, child_graph); + } + // convert kernel graph to model + predictmodel::StepConvertGraph(child_graph); + // optimize graph + HardwareOptimize(child_graph); + // assign static memory of parameters + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignStaticMemoryInput(child_graph.get()); + runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); +} + +void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, + VectorRef *const outputs) { + MS_LOG(INFO) << "Start"; + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + // if no child graph exists and there is no anf output + if (!kernel_graph->executable()) { + MS_LOG(INFO) << "No child graph has anf output"; + UpdateOutputs(kernel_graph, outputs, inputs); + return; + } + // load input data from user input + LoadInputData(kernel_graph, inputs); + // convert inputs to model + predictmodel::StepConvertWeight(inputs); +#ifdef ENABLE_DEBUGGER + // debugger pre-execution processing + if (debugger_) { + debugger_->PreExecute(kernel_graph); + } +#endif + { + py::gil_scoped_release release; + // run task on device + ExecTask(kernel_graph); + } + // get result from device + UpdateOutputs(kernel_graph, outputs, inputs); + // summary + Summary(kernel_graph.get()); +#ifdef ENABLE_DEBUGGER + // load tensor from device for debugger + if (debugger_ && debugger_->debugger_enabled()) { + LoadTensor(kernel_graph); + } +#endif + // dump, used for debugging + Dump(kernel_graph); +#ifdef ENABLE_DEBUGGER + // debugger post-execution processing + if (debugger_) { + debugger_->PostExecute(); + } +#endif + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_LOG(INFO) << "Start"; + // data layout optimization + opt::RunOpAscendDataLayout(kernel_graph); + // mixed precision optimization + opt::AscendMixPrecision(kernel_graph); + MS_LOG(INFO) << "Finish"; +} + +void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); +
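// KernelRuntimeManager presumably caches one runtime per device id, so this per-call lookup is assumed to return the same instance each time +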
MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Run task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { + return run_op_graphs_.find(graph_info) != run_op_graphs_.end(); +} + +void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) { + MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start!"; + if (GraphCacheExist(graph_info)) { + MS_LOG(INFO) << "Build op " << op_run_info.op_name << ": graph cache already exists!"; + return; + } + + // construct a graph that contains a single op + auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(graph); + opt::RunOpAscendBackendIRFusionOptimization(graph); + // kernel select + SelectKernel(*graph); + // optimize + RunOpHardwareOptimize(graph); + // init runtime resource + InitRuntimeResource(); + // build kernel + RunOpAdjustKernel(graph); + BuildKernel(graph); + run_op_graphs_[graph_info] = graph; + MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish!"; +} + +py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector<tensor::TensorPtr> &input_tensors) { + auto graph = run_op_graphs_[graph_info]; + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; + // malloc mem + RunOpMemoryAlloc(input_tensors, graph.get()); + // load input data to device + LoadInputData(graph, input_tensors); + // run op + RunOpExecTask(graph); + // get output + VectorRef outputs; + UpdateOutputs(graph, &outputs, input_tensors); + // transform output to tuple + auto output_tensors = TransformBaseRefListToTuple(outputs); + if (!utils::isa<PyObjectRef>(output_tensors) || + !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) { + MS_LOG(EXCEPTION) << "The output tensors should be a tuple!"; + } + py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_; + py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj); + RunOpMemoryClear(graph.get()); + MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; + return tuple_tensors; +} + +// compile graph steps +void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + size_t raise_precision_count = 0; + size_t reduce_precision_count = 0; + for (const auto &cnode : kernel_graph.execution_order()) { + auto status = device::ascend::SelectKernelInfo(cnode); + if (status == device::ascend::kStatusRaisePrecision) { + raise_precision_count++; + } else if (status == device::ascend::kStatusReducePrecision) { + reduce_precision_count++; + } + MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); + } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kGraphMode) { + if (raise_precision_count > 0) { + MS_LOG(WARNING) << "There are " << raise_precision_count + << " node(s) that raised precision to select the kernel!"; + } + if (reduce_precision_count > 0) { + MS_LOG(WARNING) << "There are " << reduce_precision_count + << " node(s) that reduced precision to select the kernel!"; + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::InitRuntimeResource() { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); +
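// fail fast: a null runtime or a failed Init() below aborts compilation before any graph work touches the device +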
MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(EXCEPTION) << "Kernel runtime init error."; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + device::ascend::KernelPreBuild(kernel_graph.get()); + MS_LOG(INFO) << "HardwareOptimize start!"; + opt::AscendBackendOptimization(kernel_graph); + opt::AscendGraphKernelCommonProcess(kernel_graph); + opt::AscendBackendFuseBasicOpt(kernel_graph, false); + opt::AscendBackendAddAtomicClean(kernel_graph); + kernel_graph->SetExecOrderByDefault(); + MS_LOG(INFO) << "HardwareOptimize Finish!"; +} + +void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + opt::HideNopNode(kernel_graph.get()); + // Insert ClearZero op + // prepare for the next step: get atomic info from json + BuildKernel(kernel_graph); + device::ascend::KernelBuildPreprocess(kernel_graph.get()); + device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + if (save_graphs) { + std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir"; + DumpIR(file_path, kernel_graph); + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + opt::HideNopNode(kernel_graph.get()); + // Insert ClearZero op + // prepare for the next step: get atomic info from json + BuildKernel(kernel_graph); + device::ascend::KernelBuildPreprocess(kernel_graph.get()); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const { + MS_LOG(INFO) << "Start!"; + device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + auto ret = device::ascend::KernelBuild(kernel_graph.get()); + if (!ret) { + MS_LOG(EXCEPTION) << "Kernel build error."; + } + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); + cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "KernelBuild run in " << cost << " us"; + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + opt::RemoveNopNode(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignMemory(kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, + KernelGraph *kernel_graph) const { + MS_LOG(INFO) << "Start memory alloc!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + opt::RemoveNopNode(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpAssignMemory(input_tensors,
kernel_graph); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpClearMemory(kernel_graph); +} + +void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Generate task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::LoadTask(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->LoadTask(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "Load task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::ExecTask(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + bool ret_ok = runtime_instance->Run(kernel_graph.get()); + if (!ret_ok) { + MS_LOG(EXCEPTION) << "run task error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::Dump(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + (void)runtime_instance->DumpData(kernel_graph.get()); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::DumpAllGraphs(const std::vector &all_graphs) { +#ifdef ENABLE_DUMP_IR + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool save_graphs = context_ptr->save_graphs_flag(); + if (!save_graphs) { + return; + } + auto save_graphs_path = context_ptr->save_graphs_path(); + if (save_graphs_path.empty()) { + save_graphs_path = "."; + } + for (auto &graph : all_graphs) { + MS_EXCEPTION_IF_NULL(graph); + std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(graph->graph_id()) + ".ir"; + DumpIR(file_path, graph, true); + DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id())); + } +#endif +} + +void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) const { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(kernel_graph); +#ifdef ENABLE_DEBUGGER + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + DebugServices *debug_services = debugger_->debug_services(); + TensorLoader *tensor_loader = debug_services->get_tensor_loader(); + tensor_loader->EmptyTensor(); + uint32_t iter_num = tensor_loader->GetIterNum(); + tensor_loader->set_iter_num(++iter_num); + (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get()); + tensor_loader->EmptyPrevTensor(); +#endif + MS_LOG(INFO) << "Finish!"; +} + +GraphId 
AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) { + MS_LOG(INFO) << "Start! Args size " << args.size(); + auto final_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(final_graph); + final_graph_id_ = final_graph->graph_id(); + MS_LOG(INFO) << "Create a new final graph " << final_graph_id_ << " success"; + // init private variables and bind them with final_graph_id + graph_execute_orders_[final_graph_id_] = std::vector<GraphId>(); + graph_order_types_[final_graph_id_] = std::vector<GraphType>(); + for (const auto &parameter : args) { + MS_EXCEPTION_IF_NULL(parameter); + if (!parameter->isa<Parameter>()) { + MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!"; + } + AnfNodePtr parameter_backend = nullptr; + // if GetGraphIdByNode returns kInvalidGraphId, the parameter does not exist in any child graph + auto parameter_belong_graph_id = GetGraphIdByNode(parameter); + if (parameter_belong_graph_id == kInvalidGraphId) { + parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get()); + final_graph->FrontBackendlMapAdd(parameter, parameter_backend); + MS_LOG(INFO) << "New parameter " << parameter->DebugString() << " in final_graph"; + } else { + // the parameter is a parameter of a child graph + auto graph = GetGraph(parameter_belong_graph_id); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph [" + << parameter_belong_graph_id << "]"; + parameter_backend = graph->GetBackendAnfByFrontAnf(parameter); + // add the backend parameter to the final graph inputs + auto final_graph_inputs = final_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(final_graph_inputs); + final_graph_inputs->push_back(parameter_backend); + } + MS_EXCEPTION_IF_NULL(parameter_backend); + MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id " + << AnfAlgo::GetGraphId(parameter_backend.get()); + } + MS_LOG(INFO) << "End final_graph_id " << final_graph_id_; + return final_graph_id_; +} + +void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, + std::map<std::string, std::pair<AnfNodePtr, int>> *summary) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); + // if the final graph has no child graph + auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); + if (graph_order_iter == graph_execute_orders_.end()) { + SessionBasic::GetSummaryNodes(graph); + auto summary_nodes = graph->summary_nodes(); + summary->insert(summary_nodes.begin(), summary_nodes.end()); + return; + } + // for every child graph, find summary nodes + auto graph_order = GetGraphOrder(graph->graph_id()); + for (size_t i = 0; i < graph_order.size(); i++) { + auto child_graph = GetGraph(graph_order[i]); + if (child_graph == nullptr) { + continue; + } + SessionBasic::GetSummaryNodes(child_graph.get()); + auto child_graph_summary = child_graph->summary_nodes(); + summary->insert(child_graph_summary.begin(), child_graph_summary.end()); + RecurseGetSummaryNodes(child_graph.get(), summary); + } + graph->set_summary_nodes(*summary); +} + +void AscendSession::GetSummaryNodes(KernelGraph *graph) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + auto summary_nodes = graph->summary_nodes(); + std::map<std::string, std::pair<AnfNodePtr, int>> summary; + summary.insert(summary_nodes.begin(), summary_nodes.end()); + RecurseGetSummaryNodes(graph, &summary); + graph->set_summary_nodes(summary); + MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); +} + +AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { + auto fake_graph =
GetGraph(fake_graph_id); + MS_EXCEPTION_IF_NULL(fake_graph); + auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); + auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { + auto parameter = fake_graph->NewParameter(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->set_abstract(abstract); + auto new_parameter = fake_graph->NewParameter(parameter); + // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory. + auto graph_inputs = fake_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->push_back(new_parameter); + return new_parameter; + }; + auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr { + MS_EXCEPTION_IF_NULL(cnode); + auto abstract = cnode->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa()) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]"; + return create_parameter((*tuple_abstract)[output_idx]); + } + return create_parameter(cnode->abstract()); + }; + if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) { + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + auto make_tuple = output_item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + for (size_t i = 1; i < make_tuple->inputs().size(); i++) { + auto input = make_tuple->inputs()[i]; + make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input)); + } + return fake_graph->NewCNode(make_tuple_inputs); + } + return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); +} + +void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { + // get the backend anf node related to the output node of front + auto output_from_graph_id = GetGraphIdByNode(node); + auto output_from_graph = GetGraph(output_from_graph_id); + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id + << "] to final graph"; + MS_EXCEPTION_IF_NULL(output_from_graph); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + // if output is from final graph,it remarks no child graph exist + if (final_graph_id_ == output_from_graph_id) { + MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); + final_graph->set_output(ConstructOutput({node}, final_graph)); + final_graph->set_executable(false); + return; + } + final_graph->set_output(output_from_graph->output()); +} + +void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { + auto value_node = NewValueNode(value); + auto kernel_info = std::make_shared(); + value_node->set_kernel_info(kernel_info); + value_node->set_abstract(abstract::FromValue(value)); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); + final_graph->set_executable(false); + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; +} + +void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { + for (auto &output : vec_output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = 
utils::cast(output); + SetFinalGraphOutput(value); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } + } +} + +void AscendSession::SetFinalGraphOutput(const BaseRef &output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = utils::cast(output); + SetFinalGraphOutput(value); + } else if (utils::isa(output)) { + auto vec_output = utils::cast(output); + SetFinalGraphOutput(vec_output); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } +} + +void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { + MS_LOG(INFO) << "Start!"; + MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; + auto condition_graph = GetGraph(condition_graph_id); + MS_EXCEPTION_IF_NULL(condition_graph); + tensor::TensorPtr tensor = std::make_shared(kNumberTypeInt32, std::vector{1}); + int32_t *val = nullptr; + val = static_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + auto value_node = std::make_shared(tensor); + value_node->set_abstract(abstract::FromValue(tensor, false)); + auto counter_const = condition_graph->NewValueNode(value_node); + condition_graph->AddValueNodeToGraph(counter_const); + // create a new switch op + auto switch_primitive = std::make_shared("StreamSwitch"); + auto cond_output_it = condition_output_.find(condition_graph_id); + if (cond_output_it == condition_output_.end()) { + MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; + } + auto cond_output_kernel = + AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; + MS_EXCEPTION_IF_NULL(cond_output_kernel); + std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; + CNodePtr switch_node = condition_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(switch_node); + switch_node->set_abstract(std::make_shared()); + AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); + // set attr: cond_ RT_GREATER + AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); + // set attr:data_type + AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue(static_cast(RT_SWITCH_INT64)), switch_node); + // set attr:true branch graph id ,which is same to stream distinction label + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); + // append switch at the end of condition graph + auto return_node = condition_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + InsertControlDependToGraph(condition_graph_id, return_node->input(kReturnDataIndex), switch_node); + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { + auto &graph_execute_order = GetGraphOrder(final_graph_id_); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id); + if (false_index == kInvalidIndex || false_index == 0) { + return; + } + for (int i = SizeToInt(false_index) - 1; i >= 0; i--) { + size_t graph_index = IntToSize(i); + if (graph_index >= graph_execute_order.size()) { + MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]"; + } + if (graph_order_type[graph_index] == COMMON_GRAPH) { + auto true_last_id = graph_execute_order[graph_index]; + MS_LOG(INFO) << "The last graph of if 
true branch is " << true_last_id; + auto true_last = GetGraph(true_last_id); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + auto false_last = GetGraph(false_graph_id); + MS_EXCEPTION_IF_NULL(true_last); + MS_EXCEPTION_IF_NULL(false_last); + MS_LOG(INFO) << "The last graph of the false branch is " << false_graph_id; + // create fake output + auto fake_output_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(fake_output_graph); + graph_execute_order.push_back(fake_output_graph->graph_id()); + graph_order_type.push_back(COMMON_GRAPH); + fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); + final_graph->set_output(fake_output_graph->output()); + InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); + InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output()); + // insert stream active for loop sink + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && + ConfigManager::GetInstance().iter_num() > 1) { + // insert active in the true graph; another active will be inserted in kernel adjust + InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); + } + break; + } + } +} + +void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, + const AnfNodePtr &output) { + if (switches_.find(cond_graph_id) != switches_.end()) { + MS_LOG(WARNING) << "Condition graph " << cond_graph_id << " has been set before"; + return; + } + switches_[cond_graph_id] = std::pair<GraphId, GraphId>(true_graph_id, false_graph_id); + condition_output_[cond_graph_id] = output; + MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; + // set the type of the condition graph + auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + if (cond_graph_index >= graph_order_type.size()) { + MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_type.size(); + } + graph_order_type[cond_graph_index] = CONDITION_GRAPH; + // update the distinction label of the false graph; update before merging to ensure the distinction + if (false_graph_id != kInvalidGraphId) { + // the false graph and the condition graph are in the same stream + auto condition_graph = GetGraph(cond_graph_id); + MS_EXCEPTION_IF_NULL(condition_graph); + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); + // if the false graph is a condition graph and has been switch-compiled before, its false branch should be + // updated again + auto cond_it = switches_.find(false_graph_id); + while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { + cond_graph_id = cond_it->first; + false_graph_id = cond_it->second.second; + condition_graph = GetGraph(cond_graph_id); + if (condition_graph == nullptr) { + // a 'continue' here would re-test the same iterator forever; stop walking the chain instead + break; + } + SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); + cond_it = switches_.find(false_graph_id); + } + } +} + +void AscendSession::MergeSwitchCompile() { + auto graph_execute_order = GetGraphOrder(final_graph_id_); + auto &graph_order_type = GetGraphOrderType(final_graph_id_); + for (auto switch_compile : switches_) { + auto cond_graph_id = switch_compile.first; + auto true_graph_id = switch_compile.second.first; +
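// switches_ maps a condition graph id to its (true branch, false branch) pair, as recorded in SwitchCompile above +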
auto false_graph_id = switch_compile.second.second; + MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; + auto condition_graph = GetGraph(cond_graph_id); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(condition_graph); + MS_EXCEPTION_IF_NULL(final_graph); + // insert switch to condition graph + InsertSwitchToGraph(cond_graph_id, true_graph_id); + auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); + auto prev_graph_id = kInvalidGraphId; + // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph + if (cond_graph_index == 0 && !final_graph->execution_order().empty()) { + prev_graph_id = final_graph_id_; + // set the distinction label of final graph + SetStreamDistinctionLabel(final_graph, final_graph_id_, true); + // if condition graph is not the first graph + } else if ((cond_graph_index - 1 < graph_execute_order.size()) && + (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) { + prev_graph_id = graph_execute_order[cond_graph_index - 1]; + } + // insert stream active to common graph + if (prev_graph_id != kInvalidGraphId) { + InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label()); + } + // if this is a 'if' condition + auto it = while_condition_graphs_.find(cond_graph_id); + if (it == while_condition_graphs_.end()) { + CopyOutputOfIf(false_graph_id); + } else { + // if it is a while,insert a stream active to true graph + GraphId from_graph = it->second; + InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label()); + } + } + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::InsertAllAssigns() { + std::vector> assigns; + for (auto assign : assigns_) { + auto front_anf = std::get<0>(assign); + auto to_graph_id = std::get<1>(assign); + auto input_idx = std::get<2>(assign); + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + assigns.emplace_back(std::pair(front_anf, backend_parameter)); + } + // erase the repeat assign + std::set> inserted_nodes; + for (auto &assign : assigns) { + auto front_anf = assign.first; + auto backend_parameter = assign.second; + auto from_graph_id = GetGraphIdByNode(front_anf); + auto from_graph = GetGraph(from_graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); + if (inserted_nodes.find(assign) == inserted_nodes.end()) { + InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); + (void)inserted_nodes.insert(assign); + } + } +} + +// insert active to graph +void AscendSession::SetActive(GraphId from, GraphId to) { + if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) { + MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from " + << while_condition_graphs_[to]; + return; + } + MS_LOG(INFO) << "From " << from << " to " << to; + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + std::vector graph_order_new; + std::vector graph_type_new; + for (size_t i = 0; i < graph_order.size(); i++) { + auto graph_id = graph_order[i]; + graph_order_new.push_back(graph_id); + graph_type_new.push_back(graph_type[i]); 
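+ // right after the 'from' graph, reserve a BRANCH_END placeholder; e.g. an order [cond, body] with from=body becomes [cond, body, BRANCH_END]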
+ if (from == graph_id) { + graph_order_new.push_back(kInvalidGraphId); + graph_type_new.push_back(BRANCH_END); + } + } + graph_order = graph_order_new; + graph_type = graph_type_new; + // set the graph type of the condition graph + graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH; + // record the condition graph into the while condition set + while_condition_graphs_[to] = from; +} + +void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) { + MS_LOG(INFO) << "Start!"; + MS_EXCEPTION_IF_NULL(front_anf); + auto from_graph_id = GetGraphIdByNode(front_anf); + auto from_graph = GetGraph(from_graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector<AnfNodePtr> graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + MS_EXCEPTION_IF_NULL(backend_parameter); + auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); + MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "] to node[" + << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) + << "]"; + // a node should not be assigned to itself + if (backend_arg.get() == backend_parameter.get()) { + return; + } + // if the arg is the parameter of a child graph, it is a parameter of the final graph too + if (front_anf->isa<Parameter>()) { + MS_EXCEPTION_IF_NULL(backend_arg); + MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() + << "] will be replaced."; + to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); + return; + } + MS_LOG(INFO) << "Assign node " << backend_arg->DebugString() << " of graph " << from_graph_id << " to node " + << backend_parameter->DebugString() << " of graph " << to_graph_id; + assigns_.emplace_back(std::tuple<AnfNodePtr, GraphId, size_t>(front_anf, to_graph_id, input_idx)); +} + +void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, + size_t input_idx) { + MS_LOG(INFO) << "Start!"; + std::pair<GraphId, size_t> graph_input_pair(to_graph_id, input_idx); + initial_tenosrs_[graph_input_pair] = front_tensor; + MS_LOG(INFO) << "Finish!"; +} + +void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { + MS_LOG(INFO) << "To_graph_id " << to_graph_id; + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + for (size_t i = 0; i < graph_order.size(); i++) { + if (graph_order[i] == to_graph_id) { + return; + } + } + // if the graph is not in the graph order, add it to the graph order + SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false); + graph_order.push_back(to_graph_id); + graph_type.push_back(COMMON_GRAPH); + for (size_t i = 0; i < graph_order.size(); i++) { + MS_LOG(INFO) << "Index " << i << ", graph_id " << graph_order[i] << ", graph_type " << graph_type[i]; + } +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + auto output_num = AnfAlgo::GetOutputTensorNum(node); + if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + return input_index + output_num; + } + auto valid_inputs = graph->valid_inputs(); + if (valid_inputs[input_index]) { +
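// only args whose corresponding graph input is marked valid are bound; the others are skipped below (presumably weights or otherwise pre-bound inputs) +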
SetChildGraphParameter(node, graph->graph_id(), input_index); + } else { + MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); + } + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value Node should be a tensor, unexpected value: " << value->ToString(); + } + SetChildGraphParameter(value->cast(), graph->graph_id(), input_index); + return ++input_index; +} + +size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) { + auto index = input_index; + for (auto &arg : vec_args) { + if (utils::isa(arg)) { + // arg is a anf node + auto node = utils::cast(arg); + index = SetChildGraphInput(graph, node, input_index); + } else if (utils::isa(arg)) { + // arg is a tensor + auto value = utils::cast(arg); + index = SetChildGraphInput(graph, value, input_index); + } else { + MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString(); + } + } + return index; +} + +void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { + MS_LOG(INFO) << "Set input of graph " << g; + auto to_graph = GetGraph(g); + MS_EXCEPTION_IF_NULL(to_graph); + DumpGraphInputArgs(args); + UpdateGraphOrder(g); + auto &graph_inputs = to_graph->inputs(); + auto real_args = GetRealArgs(to_graph, args); + size_t input_index = 0; + for (size_t i = 0; i < real_args.size(); i++) { + if (input_index >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size(); + } + auto &real_arg = real_args[i]; + if (utils::isa(real_arg)) { + // arg is a anf node + auto node = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, node, input_index); + } else if (utils::isa(real_arg)) { + // arg is a tensor + auto value = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, value, input_index); + } else if (utils::isa(real_arg)) { + // arg is a VectorRef + auto vec_args = utils::cast(real_arg); + input_index = SetChildGraphInput(to_graph, vec_args, input_index); + } else { + MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString(); + } + } + MS_LOG(INFO) << "Finish!"; +} + +GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const { + for (const auto &graph_item : graphs_) { + auto graph = graph_item.second; + MS_EXCEPTION_IF_NULL(graph); + // if front_anf is a parameter,the backend parameter may have two + if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { + return graph_item.first; + } + } + MS_EXCEPTION_IF_NULL(front_anf); + MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; + return kInvalidGraphId; +} + +void AscendSession::MergeGraphExecOrder() { + MS_LOG(INFO) << "Start!"; + // merge graph order + auto &graph_order = GetGraphOrder(final_graph_id_); + auto &graph_type = GetGraphOrderType(final_graph_id_); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + if (graph_order.empty()) { + MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!"; + return; + } + if (graph_order.size() > 1) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!context_ptr->enable_task_sink()) { + MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!"; + } + } + // if first graph is common,the final graph 
has no label,then set the stream of final graph same with the first graph + SetStreamDistinctionLabel(final_graph, graph_order[0], false); + std::vector final_exec_order = final_graph->execution_order(); + KernelGraphPtr last_graph = nullptr; + for (size_t i = 0; i < graph_order.size(); i++) { + auto graph_id = graph_order[i]; + if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { + continue; + } + auto child_graph = GetGraph(graph_id); + last_graph = child_graph; + MS_EXCEPTION_IF_NULL(child_graph); + auto exec_order = child_graph->execution_order(); + MS_LOG(INFO) << "Merge graph,graph_id " << graph_id; + (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order), + [&](CNodePtr node) -> CNodePtr { + AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get()); + return node; + }); + // add all value nodes of child graphs to final graph + for (auto &value_node : child_graph->graph_value_nodes()) { + final_graph->AddValueNodeToGraph(value_node); + } + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (final_graph->IsInRefOutputMap(item.first)) { + MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; + } + final_graph->AddRefCorrespondPairs(item.first, item.second); + } + } + // set final_exec_order into final graph + MS_EXCEPTION_IF_NULL(final_graph); + DumpGraphExeOrder(final_exec_order); + final_graph->set_execution_order(final_exec_order); +} + +void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { + MS_EXCEPTION_IF_NULL(from); + MS_EXCEPTION_IF_NULL(to); + if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && + AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { + return; + } + if (from.get() == to.get()) { + return; + } + MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to " + << to->DebugString(); + auto graph = graphs_[graph_id]; + MS_EXCEPTION_IF_NULL(graph); + // config inputs of assign node + std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; + // generate a new cnode + auto assign_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(assign_node); + assign_node->set_abstract(to->abstract()); + // append the assign at the end of from graph + InsertDependToGraph(graph_id, assign_node); +} + +void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { + std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); + std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); + MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to [" + << AnfAlgo::GetGraphId(to.get()) << "]"; + if (from_outputs.size() != to_outputs.size()) { + MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]"; + MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" + << to_outputs.size() << "]"; + } + for (size_t i = 0; i < from_outputs.size(); i++) { + InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]); + } +} + +void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) { + MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream; + auto from_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(from_graph); + std::vector 
inputs = {NewValueNode(std::make_shared("StreamActive"))}; + auto active_node = from_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(active_node); + active_node->set_abstract(std::make_shared()); + // set the active stream id into the attr of active node + std::vector active_index_value = {}; + active_index_value.push_back(actived_stream); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_value), active_node); + // append the active node at the end of from graph + auto return_node = from_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + InsertControlDependToGraph(graph_id, return_node->input(kReturnDataIndex), active_node); +} + +void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { + AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); +} + +void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, + const AnfNodePtr &second_node) { + AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), + NOT_NULL(second_node)); +} + +size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { + auto &graph_order = GetGraphOrder(final_graph); + for (size_t i = 0; i < graph_order.size(); i++) { + if (child_graph == graph_order[i]) { + return i; + } + } + return kInvalidIndex; +} + +std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { + auto graph_order_iter = graph_execute_orders_.find(final_graph_id); + if (graph_order_iter == graph_execute_orders_.end()) { + MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph"; + } + return graph_order_iter->second; +} + +// get graph order type vector by graph id +std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) { + auto graph_type_iter = graph_order_types_.find(final_graph_id); + if (graph_type_iter == graph_order_types_.end()) { + MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_"; + } + return graph_type_iter->second; +} + +void AscendSession::SyncInitialTenosrToDevice() { + for (auto &item : initial_tenosrs_) { + auto to_graph_id = item.first.first; + auto input_idx = item.first.second; + auto front_tensor = item.second; + auto to_graph = GetGraph(to_graph_id); + MS_EXCEPTION_IF_NULL(to_graph); + std::vector graph_inputs = to_graph->inputs(); + if (input_idx >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); + } + auto backend_parameter = graph_inputs[input_idx]; + // sync data from host to device + MS_EXCEPTION_IF_NULL(front_tensor); + size_t tensor_size = front_tensor->data().nbytes(); + auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); + MS_EXCEPTION_IF_NULL(addr); + if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, + front_tensor->data_type(), front_tensor->data_c())) { + MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; + } + } +} + +static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector &list) { + // count the output of every anf node + std::set has_output_nodes; + for (auto &anf_node : list) { + MS_EXCEPTION_IF_NULL(anf_node); + for (auto &input : anf_node->inputs()) { + (void)has_output_nodes.insert(input); + } + } + + auto make_tuple_primitve = NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())); + std::vector make_tuple_inputs = {make_tuple_primitve}; + int output_idx = 
0; + MS_EXCEPTION_IF_NULL(new_kernel_graph); + for (auto &anf_node : list) { + if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { + new_kernel_graph->set_return(anf_node); + } + if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString(); + make_tuple_inputs.push_back(anf_node); + } + } + if (new_kernel_graph->get_return() == nullptr) { + new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); + } +} + +std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, + const std::vector<CNodePtr> &list) { + MS_EXCEPTION_IF_NULL(new_kernel_graph); + MS_LOG(INFO) << "Start constructing split kernel graph:" << new_kernel_graph->graph_id(); + MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); + std::vector<AnfNodePtr> call_node_inputs; + std::vector<AnfNodePtr> new_graph_inputs; + // create new parameters from cnodes + for (auto &anf_node : list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto cnode = anf_node->cast<CNodePtr>(); + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { + auto input = cnode->inputs()[input_idx]; + MS_EXCEPTION_IF_NULL(input); + AnfNodePtr new_parameter = nullptr; + // check whether the input has already been put into the args of the call; if one parameter or cnode is used + // multiple times, only set one parameter in the graph inputs and one arg in the call node + auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input); + if (call_input_it != call_node_inputs.end()) { + cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]); + continue; + } + // a value node is used directly and is not converted into a parameter + if (input->isa<ValueNode>()) { + cnode->set_input(input_idx, input); + continue; + } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { + // the input is a cnode that is not in the current child graph + new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); + cnode->set_input(input_idx, new_parameter); + } else { + // the input is a cnode in the current graph + continue; + } + new_graph_inputs.push_back(new_parameter); + call_node_inputs.push_back(input); + } + } + // set the graph inputs of the new graph + auto graph_inputs = new_kernel_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->clear(); + std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs)); + + MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); + ConstructSplitedGraphOutput(new_kernel_graph, list); + MS_LOG(INFO) << "End"; + return call_node_inputs; +} + +void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) { + MS_LOG(INFO) << "Start BackendCommonOptimization"; + for (auto &graph : all_graphs) { + opt::BackendCommonOptimization(graph); + } + MS_LOG(INFO) << "End."; +} + +void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) { + std::set<KernelGraphPtr> memo; + // if the output of the graph is nullptr, there is no need to insert a maketuple at the end of the graph + if (root_graph->output() == nullptr) { + return; + } + // if the root graph output is a call node, the root graph is the condition graph of an 'if' statement + auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first; + if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) { + SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo)); + for (auto &child_graph : root_graph->child_graph_order()) { +
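// memo records graphs that have already been split, so shared child graphs are processed only once (see RecurseSplitGraph) +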
RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo)); + } + } else { + RecurseSplitGraph(root_graph, NOT_NULL(&memo)); + } + memo.clear(); + // add maketuple to the end of the last child graph to suit old process + auto output_graph = root_graph->child_graph_order().empty() ? root_graph : root_graph->child_graph_order().back(); + auto make_tuple = output_graph->NewCNode( + {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), output_graph->output()}); + output_graph->set_output(make_tuple); + // replace the real input if the real input is a call + RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo)); +} + +AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull graph, + const std::vector &child_graph_list) { + // if child graph list only has a call ,then return the exist call + if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) { + return child_graph_list[0]; + } + // create new child graph + auto child_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(child_graph); + // create new value node to bind child graph + auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph)); + std::vector new_call_input = {NewValueNode(std::make_shared(prim::kPrimCall->name())), + graph_value_node}; + // set the graph id of all node of child graph + for (auto &child_graph_node : child_graph_list) { + AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); + } + auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); + std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); + auto new_call = graph->NewCNode(new_call_input); + AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call); + return new_call; +} + +void AscendSession::SplitGraph(NotNull graph, const std::set &cut_prims, + const NotNull *> memo) { + MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id(); + bool split_flag = false; + auto apply_list = GetCNodes(TopoSort(graph->get_return())); + // update the root graph child graph order + AscendControlParser::UpdateChildGraphOrder(graph); + // get child list from current graph + std::vector> child_graph_lists = GetChildList(apply_list, cut_prims); + if (child_graph_lists.size() > 1) { + std::list depend_input = {}; + for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { + auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]); + MS_EXCEPTION_IF_NULL(call_node); + // if call node is the last call of true graph,no need create child graph after that + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + depend_input.push_front(call_node); + if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { + break; + } + } + depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name())))); + auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end())); + auto new_return_primitive = + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))); + graph->set_return(graph->NewCNode({new_return_primitive, depend})); + AnfNodePtr pre_call_node = nullptr; + AnfNodePtr cur_call_node = nullptr; + auto iter = depend_input.begin(); + for (++iter; iter != depend_input.end(); ++iter) { + pre_call_node = cur_call_node; + cur_call_node = *iter; + if (pre_call_node != nullptr && cur_call_node != nullptr) { + AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), 
NOT_NULL(pre_call_node)); + } + } + split_flag = true; + } + AscendControlParser::UpdateChildGraphOrder(graph); + UpdateRealInput(graph, split_flag, memo); + MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end"; +} + +void AscendSession::RecurseSplitGraph(NotNull graph, const NotNull *> memo) { + memo->insert(graph.get()); + SplitGraph(graph, {prim::kPrimCall}, memo); + for (auto &child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) == memo->end()) { + RecurseSplitGraph(NOT_NULL(child_graph), memo); + } + } +} + +void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } + +void AscendSession::RootGraphExecutorValidate(NotNull graph) { + AscendControlParser::ExecutorValidate(graph); +} + +void AscendSession::RecurseCompileGraph(NotNull graph, const NotNull *> memo) { + memo->insert(graph.get()); + CompileChildGraph(graph); + for (auto child_graph : graph->child_graph_order()) { + if (memo->find(child_graph) != memo->end()) { + continue; + } + RecurseCompileGraph(NOT_NULL(child_graph), memo); + // copy ref map to final graph + auto child_ref_map = child_graph->GetRefMap(); + for (auto &item : child_ref_map) { + if (graph->IsInRefOutputMap(item.first)) { + MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; + } + graph->AddRefCorrespondPairs(item.first, item.second); + } + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h new file mode 100755 index 0000000000..f8ec7e8545 --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/session/session_factory.h" +#include "backend/session/ascend_control_parser.h" + +namespace mindspore { +namespace session { +enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; + +class AscendSession : public SessionBasic { + public: + AscendSession() { final_graph_id_ = kInvalidGraphId; } + ~AscendSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kAscendDevice, device_id); + } + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + GraphId CompileGraph(NotNull func_graph) override; + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void BuildGraph(GraphId) override; + void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) override; + py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) override; + + // set parameters of final graph + GraphId SetFinalGraphInput(const std::vector &args) override; + // set output of final graph + void SetFinalGraphOutput(const BaseRef &output) override; + // insert switch and set the relative active ops + void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; + // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter + void SetChildGraphInput(GraphId g, const VectorRef &args) override; + // get graph id in child graphs by ME front anf node pointer + GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; + // get graph id of final graph + GraphId GetFinalRunGraph() const override { return final_graph_id_; } + // insert active to graph + void SetActive(GraphId, GraphId) override; + // compile child graph when session have multiple child graphs + void CompileChildGraph(const KernelGraphPtr &child_graph); + void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); + void GetSummaryNodes(KernelGraph *graph); + + private: + void InitRuntimeResource(); + void SelectKernel(const KernelGraph &kernel_graph) const; + void HardwareOptimize(const std::shared_ptr &kernel_graph) const; + void AdjustKernel(const std::shared_ptr &kernel_graph) const; + void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; + void AssignStream(NotNull kernel_graph) const; + void BuildKernel(const std::shared_ptr &kernel_graph) const; + void MemoryAlloc(KernelGraph *kernel_graph) const; + void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + void RunOpMemoryClear(const KernelGraph *kernel_graph) const; + void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; + void LoadTask(const std::shared_ptr &kernel_graph) const; + void ExecTask(const std::shared_ptr &kernel_graph) const; + void Dump(const std::shared_ptr &kernel_graph) const; + void DumpAllGraphs(const std::vector &all_graphs); + void LoadTensor(const std::shared_ptr &kernel_graph) const; + // below functions are used for run op + void 
RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+
+  size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
+  size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
+  size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
+
+  void SetFinalGraphOutput(const AnfNodePtr &node);
+  void SetFinalGraphOutput(const ValuePtr &value);
+  void SetFinalGraphOutput(const VectorRef &vec_output);
+
+  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
+                  const NotNull<std::set<KernelGraphPtr> *> memo);
+  // split graphs recursively, starting from the root graph
+  void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
+  void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
+  void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
+  void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
+  std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
+                                                const std::vector<CNodePtr> &list);
+  void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
+  void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
+  AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
+
+  // merge the execution order lists of the child graphs
+  void MergeGraphExecOrder();
+  // insert an assign op to sync data between different graphs
+  void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
+  // insert multiple assigns to a graph
+  void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
+  // insert an active op to a graph
+  void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
+  // get the execution index of a child graph
+  size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
+  // handle a condition graph from the vm
+  void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
+  // insert a depend to a graph, used to attach control nodes to the graph
+  void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attach_node);
+  // insert a control depend to a graph, used to attach control nodes to the graph
+  void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
+  // set a child graph parameter if the front arg is an anf node
+  void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
+  // set a child graph parameter if the front arg is a tensor
+  void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
+  // update the execution order of all child graphs
+  void UpdateGraphOrder(GraphId to_graph);
+  // handle switch when merging
+  void MergeSwitchCompile();
+  // get the graph order vector by graph id
+  std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
+  // get the graph order type vector by graph id
+  std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
+  // copy the outputs of if and else
+  void CopyOutputOfIf(GraphId false_graph_id);
+  // check whether a graph cache exists
+  bool GraphCacheExist(const GraphInfo &graph_info) const;
+  // insert all assigns to the child graphs
+  void InsertAllAssigns();
+  // create a fake output for the final graph
+  AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
+  // sync initial tensors' data to device
+  void SyncInitialTenosrToDevice();
+  void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
+
+  // member variables
+  // key is final_graph_id, value is the child graph execution order of the final graph
std::unordered_map> graph_execute_orders_; + // key is final_graph_id,value is the graph types of child graphs + std::unordered_map> graph_order_types_; + // record condition graph of while + std::unordered_map while_condition_graphs_; + // record all conditions + std::unordered_map> switches_; + std::unordered_map condition_output_; + // share parameters + std::vector> assigns_; + // initial tensors, these tensor will sync data to device before run graph + std::map, tensor::TensorPtr> initial_tenosrs_; + // final_graph_id is used in every root graph has it's own session situation + GraphId final_graph_id_; +}; +MS_REG_SESSION(kAscendDevice, AscendSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc new file mode 100644 index 0000000000..ca1c78d206 --- /dev/null +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/cpu_session.h" +#include +#include "ir/tensor.h" +#include "ir/anf.h" +#include "backend/kernel_compiler/kernel.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_runtime.h" +#include "predict/predict.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "runtime/device/cpu/kernel_select_cpu.h" +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif + +namespace mindspore { +namespace session { +ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; + } + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + ParameterPtr new_parameter = graph->NewParameter(anf->cast()); + TraceManager::EndTrace(); + graph_inputs->push_back(new_parameter); + valid_inputs->push_back(valid_input); + return new_parameter; +} + +GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + auto graph_id = graph_sum_; + auto graph = ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Set kernel info"; + SetKernelInfo(graph.get()); + predictmodel::StepConvertGraph(graph); + MS_LOG(INFO) << "Build kernel"; + BuildKernel(graph.get()); + MS_LOG(INFO) << "Assign kernel address"; + runtime_.AssignKernelAddress(graph.get()); + return graph_id; +} + +void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { + auto &kernel_graph = graphs_[graph_id]; + 
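+  // graphs_ maps the GraphId handed out by CompileGraph back to the compiled
+  // KernelGraph; the rest of RunGraph binds the user tensors to the kernels'
+  // addresses, runs the reordered kernel list and syncs the outputs back to host.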
MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Bind input output address"; + std::vector need_sync_outputs; + runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs, &need_sync_outputs); + MS_LOG(INFO) << "Run graph start"; + predictmodel::StepConvertWeight(inputs); + auto execution_order = kernel_graph->execution_order(); + Reorder(&execution_order); + + bool enable_summary = summary_callback_ != nullptr; + kernel_graph->set_execution_order(execution_order); + NamedSummaryOutputs summary_outputs; + if (enable_summary) { + GetSummaryNodes(kernel_graph.get()); + summary_outputs = kernel_graph->summary_nodes(); + runtime_.IncreaseSummaryRefCount(summary_outputs); + } +#ifdef ENABLE_DEBUGGER + // debugger pre-execution processing + if (debugger_) { + debugger_->PreExecute(kernel_graph); + } +#endif + bool ret = runtime_.Run(kernel_graph.get()); + if (!ret) { + MS_LOG(EXCEPTION) << "Run graph failed"; + } + for (auto output : need_sync_outputs) { + (void)output->data_sync(); + } + + if (enable_summary) { + Summary(kernel_graph.get()); + runtime_.DecreaseSummaryRefCount(summary_outputs); + } + +#ifdef ENABLE_DEBUGGER + // debugger post-execution processing + if (debugger_) { + debugger_->PostExecute(); + } +#endif + MS_LOG(INFO) << "Run graph end"; +} + +void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + device::cpu::SetKernelInfo(kernel_node); + } +} + +void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &kernel_nodes = kernel_graph->execution_order(); + for (const auto &kernel_node : kernel_nodes) { + MS_EXCEPTION_IF_NULL(kernel_node); + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "]."; + std::shared_ptr cpu_kernel = + kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); + if (cpu_kernel == nullptr) { + MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; + } + cpu_kernel->Init(kernel_node); + AnfAlgo::SetKernelMod(cpu_kernel, kernel_node.get()); + MS_LOG(INFO) << "Cpu build success operator[" << kernel_name << "]."; + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h new file mode 100644 index 0000000000..b0dbd1cc2b --- /dev/null +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#define MINDSPORE_CCSRC_SESSION_CPU_SESSION_H +#include +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "runtime/device/cpu/cpu_kernel_runtime.h" +#include "backend/session/session_factory.h" +namespace mindspore { +namespace session { +class CPUSession : public SessionBasic { + public: + CPUSession() = default; + ~CPUSession() override = default; + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kCPUDevice, device_id); + } + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + + protected: + ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; + + private: + void SetKernelInfo(const KernelGraph *kernel_graph); + void BuildKernel(const KernelGraph *kernel_graph); + device::cpu::CPUKernelRuntime runtime_; +}; +MS_REG_SESSION(kCPUDevice, CPUSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_CPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc new file mode 100644 index 0000000000..1f109e0a6a --- /dev/null +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/gpu_session.h" +#include "runtime/device/gpu/kernel_info_setter.h" +#include "runtime/device/gpu/gpu_kernel_build.h" +#include "runtime/device/gpu/gpu_kernel_runtime.h" +#include "runtime/device/gpu/gpu_stream_assign.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/pass/communication_op_fusion.h" +#include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/optimizer/gpu/adam_weight_decay_fusion.h" +#include "backend/optimizer/gpu/adam_fusion.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "predict/predict.h" +#include "common/utils.h" +#include "common/trans.h" +#include "utils/context/ms_context.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +namespace session { +namespace gpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; + +void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + for (const auto &kernel_node : kernel_graph->execution_order()) { + MS_EXCEPTION_IF_NULL(kernel_node); + device::gpu::SetKernelInfo(kernel_node); + } +} + +void GPUSession::StartKernelRT() const { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(EXCEPTION) << "GPU start kernel runtime failed"; + } +} + +void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_graph) { + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + (void)optimizer->Optimize(kernel_graph); + kernel_graph->SetExecOrderByDefault(); +} + +void GPUSession::AssignStream(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + device::gpu::AssignGpuStream(kernel_graph); +} + +void GPUSession::BuildKernel(const std::shared_ptr &kernel_graph) const { + device::gpu::GpuBuild(kernel_graph); +} + +void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->AssignMemory(kernel_graph); +} + +void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, + KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); +} + +void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->RunOpClearMemory(kernel_graph); +} + +void 
GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + std::vector inputs(inputs_const); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto input_nodes = kernel_graph->inputs(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + + for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + MS_EXCEPTION_IF_NULL(tensor); + auto input_node = input_nodes[i]; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + auto pk_node = input_node->cast(); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + auto tensor_address = tensor->device_address(); + bool need_sync = false; + if (ms_context->enable_pynative_infer()) { + if (tensor_address == nullptr || tensor_address != device_address) { + need_sync = true; + } + } else if (tensor->is_dirty() || tensor_address == nullptr) { + need_sync = true; + } else if (tensor_address != device_address) { + if (tensor_address->DeviceType() == device_address->DeviceType()) { + AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); + } else { + need_sync = true; + } + } + if (need_sync) { + tensor->set_device_address(device_address); + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + tensor->set_dirty(false); + } +} + +void GPUSession::Execute(const std::shared_ptr &kernel_graph) const { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Run(kernel_graph.get())) { + MS_LOG(EXCEPTION) << "GPU execute graph failed!"; + } +} + +GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + // Construct graph, if successfully, graph_sum_ + 1 + auto graph_id = graph_sum_; + auto graph = ConstructKernelGraph(lst, outputs); + MS_EXCEPTION_IF_NULL(graph); + // Optimize + Optimize(graph); + // Select kernel build info + SelectKernel(graph); + // Convert kernel Graph to model + predictmodel::StepConvertGraph(graph); + // Start gpu kernel runtime + StartKernelRT(); + // HardwareOptimize + HardwareOptimize(graph); + // Assign CUDA streams + AssignStream(graph); + // Hide NoOp from execution graph + opt::HideNopNode(graph.get()); + // Build kernel if node is cnode + BuildKernel(graph); + // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph + auto execution_order = graph->execution_order(); + Reorder(&execution_order); + graph->set_execution_order(execution_order); + // Get summary nodes. + GetSummaryNodes(graph.get()); + // Remove NoOp from execution graph + opt::RemoveNopNode(graph.get()); + // Set graph manager. 
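+  // MakeManager builds a FuncGraphManager for this single graph; the graph and
+  // the manager are linked in both directions below so that later passes can
+  // query node users, and context_ holds a reference to the manager.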
+ MS_EXCEPTION_IF_NULL(context_); + FuncGraphManagerPtr manager = MakeManager({graph}); + context_->AddManager(manager); + if (manager) { + manager->AddFuncGraph(graph); + graph->set_manager(manager); + } + // Alloc memory, including static memory and dynamic memory + AllocateMemory(graph.get()); + return graph_id; +} + +void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { + auto &kernel_graph = graphs_[graph_id]; + // Load input data from user input + LoadInputData(kernel_graph, inputs); + MS_EXCEPTION_IF_NULL(kernel_graph); + // Convert inputs to model + predictmodel::StepConvertWeight(inputs); + { + py::gil_scoped_release gil_release; + // Run graph on GPU + Execute(kernel_graph); + } + // Get result from GPU + UpdateOutputs(kernel_graph, outputs, inputs); + // Summary + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_gpu_summary()) { + Summary(kernel_graph.get()); + } +} + +void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) { + // Check if the graph cache exists. + if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { + return; + } + // Prepare the graph + auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(kernel_graph); + SelectKernel(kernel_graph); + StartKernelRT(); + // Hide NoOp from execution graph + opt::HideNopNode(kernel_graph.get()); + BuildKernel(kernel_graph); + run_op_graphs_[graph_info] = kernel_graph; +} + +py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) { + auto kernel_graph = run_op_graphs_[graph_info]; + MS_EXCEPTION_IF_NULL(kernel_graph); + // Remove NoOp from execution graph + opt::RemoveNopNode(kernel_graph.get()); + RunOpAllocateMemory(input_tensors, kernel_graph.get()); + // Execute the computation + LoadInputData(kernel_graph, input_tensors); + Execute(kernel_graph); + // Fetch outputs + VectorRef outputs; + UpdateOutputs(kernel_graph, &outputs, input_tensors); + // Trans output to tuple + auto output_tensors = TransformBaseRefListToTuple(outputs); + if (!utils::isa(output_tensors) || + !py::isinstance(utils::cast(output_tensors).object_)) { + MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; + } + py::object tuple_obj = utils::cast(output_tensors).object_; + py::tuple tuple_tensors = py::cast(tuple_obj); + RunOpClearMemory(kernel_graph.get()); + return tuple_tensors; +} +} // namespace gpu +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h new file mode 100644 index 0000000000..7e07dfbcbd --- /dev/null +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_SESSION_GPU_SESSION_H +#define MINDSPORE_CCSRC_SESSION_GPU_SESSION_H + +#include +#include +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/session_factory.h" +using KernelGraph = mindspore::session::KernelGraph; + +namespace mindspore { +namespace session { +namespace gpu { +class GPUSession : public SessionBasic { + public: + GPUSession() = default; + ~GPUSession() override = default; + + void Init(uint32_t device_id) override { + SessionBasic::Init(device_id); + context_ = std::make_shared(kGPUDevice, device_id); + } + + GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; + + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; + void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors, const std::vector &tensors_mask) override; + py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + const std::vector &input_tensors) override; + + private: + void SelectKernel(const std::shared_ptr &kernel_graph) const; + + void StartKernelRT() const; + + void Optimize(const std::shared_ptr &kernel_graph); + + void HardwareOptimize(const std::shared_ptr &kernel_graph); + + void AssignStream(const std::shared_ptr &kernel_graph); + + void BuildKernel(const std::shared_ptr &kernel_graph) const; + + void AllocateMemory(KernelGraph *kernel_graph) const; + + void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; + + void RunOpClearMemory(KernelGraph *kernel_graph) const; + + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const override; + + void Execute(const std::shared_ptr &kernel_graph) const; +}; +using GPUSessionPtr = std::shared_ptr; +MS_REG_SESSION(kGPUDevice, GPUSession); +} // namespace gpu +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_GPU_SESSION_H diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc new file mode 100644 index 0000000000..0bf447751b --- /dev/null +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -0,0 +1,998 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/kernel_graph.h" +#include +#include +#include +#include +#include "frontend/operator/ops.h" +#include "ir/param_value.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace session { +namespace { +constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; +constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; +void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(que); + MS_EXCEPTION_IF_NULL(visited_nodes); + if (visited_nodes->find(node) == visited_nodes->end()) { + que->push(node); + (void)visited_nodes->insert(node); + MS_LOG(DEBUG) << "Push que:" << node->DebugString(); + } +} + +std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { + auto item_with_index = + AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); + AnfNodePtr node = item_with_index.first; + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { + auto outputs = AnfAlgo::GetAllOutput(node); + std::set memo; + std::vector new_output; + for (auto &output : outputs) { + if (memo.find(output) != memo.end()) { + continue; + } + memo.insert(output); + new_output.push_back(output); + } + if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { + node = new_output[0]; + } + } + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { + return {node}; + } + std::vector real_inputs; + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); + for (const auto &child_graph : child_graphs) { + if (child_graph->get_output_null()) { + continue; + } + auto real_input = child_graph->output(); + auto child_real_inputs = GetCallRealOutputs(real_input); + std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs)); + } + return real_inputs; +} + +AnfNodePtr MakeValueNode(const AnfNodePtr &node) { + auto value_node = node->cast(); + if (value_node == nullptr) { + return nullptr; + } + + ValueNodePtr new_value_node = std::make_shared(value_node->value()); + new_value_node->set_abstract(value_node->abstract()); + // create kernel_info fo new value node + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + std::vector types; + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { + types.push_back(kTypeUnknown); + } + kernel_build_info_builder->SetOutputsDeviceType(types); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + return new_value_node; +} + +bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { + if (left == right) { + return true; + } + if (left == nullptr || right == nullptr) { + return false; + } + if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { + return false; + } + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, 
right)) { + return AnfAlgo::GetNodeAttr(left, kAttrLabelIndex) == + AnfAlgo::GetNodeAttr(right, kAttrLabelIndex); + } + return false; +} +} // namespace +std::vector KernelGraph::outputs() const { + auto graph_output = output(); + if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { + auto make_tuple = output()->cast(); + MS_EXCEPTION_IF_NULL(make_tuple); + auto &inputs = make_tuple->inputs(); + return std::vector(inputs.begin() + 1, inputs.end()); + } + return std::vector(1, graph_output); +} + +void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(visit_queue); + MS_EXCEPTION_IF_NULL(visited_nodes); + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + // value node and parameter has no input,no need to print log + if (node->isa()) { + MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; + } + return; + } + + // visit all reduce node first, then other nodes + std::vector active_nodes; + for (const auto &output_edge : it->second) { + auto next_node = output_edge.first; + MS_EXCEPTION_IF_NULL(next_node); + if (node_input_num_.find(next_node) == node_input_num_.end()) { + MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; + } + MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() + << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; + if (node_input_num_[next_node] < output_edge.second) { + MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] + << ",depend edge:" << output_edge.second; + } + node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; + // allreduce first + if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { + (void)visited_nodes->insert(next_node); + if (AnfAlgo::IsCommunicationOp(next_node)) { + MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); + visit_queue->push(next_node); + } else { + active_nodes.emplace_back(next_node); + } + } + } + + for (auto &node : active_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); + visit_queue->push(node); + } +} + +void KernelGraph::SetExecOrderByDefault() { + std::queue seed_nodes; + UpdateNodeEdgeList(&seed_nodes); + execution_order_.clear(); + std::unordered_set visited_nodes; + std::queue zero_input_nodes; + AnfNodePtr last_communication_node = nullptr; + std::queue communication_descendants; + while (!seed_nodes.empty() || last_communication_node != nullptr) { + // seed nodes first, then visit last all reduce node descendant + if (seed_nodes.empty()) { + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + last_communication_node = nullptr; + } else { + zero_input_nodes.push(seed_nodes.front()); + seed_nodes.pop(); + } + // all reduce node descendant first, then common queue + while (!zero_input_nodes.empty() || !communication_descendants.empty()) { + AnfNodePtr node = nullptr; + bool is_communication_descendant = false; + if (communication_descendants.empty()) { + node = zero_input_nodes.front(); + zero_input_nodes.pop(); + } else { + node = communication_descendants.front(); + communication_descendants.pop(); + is_communication_descendant = true; + } + // add execute node + MS_EXCEPTION_IF_NULL(node); + if (node->isa() && AnfAlgo::IsRealKernel(node)) { + 
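+        // only real kernels (cnodes that map to an actual device op) enter the
+        // execution order; virtual nodes such as MakeTuple carry no computation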
execution_order_.push_back(node->cast()); + } + // for all reduce node, visit last all reduce node descendant + if (AnfAlgo::IsCommunicationOp(node)) { + if (last_communication_node != nullptr) { + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + } + last_communication_node = node; + } else if (is_communication_descendant) { + VisitNodeDescendants(node, &communication_descendants, &visited_nodes); + } else { + VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); + } + } + } + CheckLoop(); + // resort start label / end goto + std::vector re_order; + if (start_label_ != nullptr) { + re_order.push_back(start_label_); + } + for (auto &node : execution_order_) { + if (node == start_label_ || node == end_goto_) { + continue; + } + + if (IsSameLabel(node, end_goto_)) { + end_goto_ = node; + MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); + continue; + } + + if (IsSameLabel(node, start_label_)) { + start_label_ = node; + MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); + continue; + } + + re_order.push_back(node); + } + if (end_goto_ != nullptr) { + re_order.push_back(end_goto_); + } + execution_order_ = re_order; +} + +void KernelGraph::CheckLoop() { + std::map none_zero_nodes; + if (node_input_edges_.size() != node_input_num_.size()) { + MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() + << "not equal to node_input_num_ size:" << node_input_num_.size(); + } + for (auto &it : node_input_num_) { + MS_EXCEPTION_IF_NULL(it.first); + string str; + auto node_input_it = node_input_edges_.find(it.first); + if (node_input_it == node_input_edges_.end()) { + MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; + } + for (const auto &input_edge : node_input_edges_[it.first]) { + MS_EXCEPTION_IF_NULL(input_edge.first); + str = str.append(input_edge.first->DebugString()).append("|"); + } + if (it.second != 0) { + MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; + none_zero_nodes[it.first] = it.second; + } + } + // if don't consider control depend and loop exit,a exception will be throw + if (!none_zero_nodes.empty()) { + MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); + } +} + +CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { + auto cnode = FuncGraph::NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_abstract(std::make_shared()); + CreateKernelInfoFromNewParameter(cnode); + + auto kernel_info = std::make_shared(); + std::vector feature_map_input_indexs; + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + for (size_t index = 1; index < inputs.size(); ++index) { + auto node = inputs[index]; + if (AnfAlgo::IsFeatureMapOutput(node)) { + feature_map_input_indexs.push_back(index); + } + } + if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { + AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); + } + if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { + kernel_info->SetFeatureMapFlag(true); + } + if (AnfAlgo::IsRealCNodeKernel(cnode)) { + AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); + AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); + } + cnode->set_kernel_info(kernel_info); + AnfAlgo::SetGraphId(graph_id_, cnode.get()); + return cnode; +} 
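+
+// A minimal usage sketch (illustrative only, not part of this change): creating a
+// CNode through the override above fills in kernel_info, the feature-map flags and
+// the owning graph id in one step. The names `graph`, `x` and `add` are hypothetical.
+//
+//   auto graph = std::make_shared<KernelGraph>();
+//   auto x = graph->NewParameter(nullptr);  // fresh feature-map parameter
+//   auto add = graph->NewCNode(
+//       {NewValueNode(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())), x, x});
+//   // `add` now carries a KernelInfo and, being a real kernel, the
+//   // kIsFeatureMapOutput attribute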
+ +void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { + if (!AnfAlgo::IsGraphKernel(cnode)) { + return; + } + auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector node_list; + std::vector input_list; + std::vector output_list; + kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); + for (auto &anf_node : node_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_info = std::make_shared(); + anf_node->set_kernel_info(kernel_info); + auto anf_cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(anf_cnode); + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) { + auto input_node = anf_cnode->input(i + 1); + MS_EXCEPTION_IF_NULL(input_node); + if (IsValueNode(input_node)) { + auto new_input_node = MakeValueNode(input_node); + if (new_input_node != nullptr) { + anf_cnode->set_input(i + 1, new_input_node); + } + } + } + } + for (auto &anf_node : input_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_info = std::make_shared(); + anf_node->set_kernel_info(kernel_info); + } +} + +CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto new_cnode = std::make_shared(*cnode); + // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map + if (BackendNodeExistInFrontBackendMap(cnode)) { + FrontBackendlMapUpdate(cnode, new_cnode); + } + AnfAlgo::SetGraphId(graph_id_, cnode.get()); + if (IsInternalOutput(cnode)) { + ReplaceInternalOutput(cnode, new_cnode); + } + return new_cnode; +} + +ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { + ParameterPtr new_parameter = add_parameter(); + MS_EXCEPTION_IF_NULL(new_parameter); + // create kernel_info form new parameter + auto kernel_info = std::make_shared(); + size_t output_tensor_num = 1; + // if use default parameter = nullptr,it remarks create a new parameter from no parameter + if (parameter == nullptr) { + new_parameter->set_abstract(std::make_shared()); + kernel_info->SetFeatureMapFlag(true); + } else { + // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter + new_parameter->set_abstract(parameter->abstract()); + new_parameter->set_name(parameter->name()); + if (AnfAlgo::IsParameterWeight(parameter)) { + new_parameter->set_default_param(parameter->default_param()); + kernel_info->SetFeatureMapFlag(false); + } else { + kernel_info->SetFeatureMapFlag(true); + } + } + new_parameter->set_kernel_info(kernel_info); + // create kernel_build_info for new parameter + auto kernel_build_info_builder = std::make_shared(); + // create init data type, + std::vector init_data_type = {}; + + TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); + init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? 
kTypeUnknown : infer_data_type); + + // set the format of parameter to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat(std::vector(output_tensor_num, kOpFormat_DEFAULT)); + // set parameter initaial device data type + kernel_build_info_builder->SetOutputsDeviceType(init_data_type); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get()); + AnfAlgo::SetGraphId(graph_id_, new_parameter.get()); + return new_parameter; +} + +std::vector KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto node_value = value_node->value(); + auto output_size = AnfAlgo::GetOutputTensorNum(value_node); + std::vector convert_inputs; + if (!node_value->isa()) { + MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString(); + } + auto value_tuple = node_value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + if (value_tuple->size() != output_size) { + MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size() + << " is not mathced with the value node's output size" << output_size; + } + for (size_t index = 0; index < value_tuple->value().size(); ++index) { + auto new_value_node = std::make_shared(value_tuple->value()[index]); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)}, + {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get()); + AddValueNodeToGraph(new_value_node); + auto kernel_info = std::make_shared(); + new_value_node->set_kernel_info(kernel_info); + kernel_info->SetFeatureMapFlag(false); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + // set the format of value_node to DEFAULT_FORMAT + kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); + // set value node initial device data type = infer data type + kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); + AddValueNodeToGraph(new_value_node); + convert_inputs.emplace_back(new_value_node); + } + if (!RemoveValueNodeFromGraph(value_node)) { + MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString(); + } + return convert_inputs; +} + +ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto new_value_node = MakeValueNode(value_node)->cast(); + AnfAlgo::SetGraphId(graph_id_, new_value_node.get()); + return new_value_node; +} + +const std::vector &KernelGraph::inputs() const { + MS_EXCEPTION_IF_NULL(inputs_); + return *inputs_; +} + +void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) { + MS_EXCEPTION_IF_NULL(front_anf); + MS_EXCEPTION_IF_NULL(backend_anf); + if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " has been exist in the front_backend_anf_map_"; + } + if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << "has been exist in the backend_front_anf_map_"; + } + front_backend_anf_map_[front_anf] = backend_anf; + backend_front_anf_map_[backend_anf] = front_anf; +} + +void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { + 
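+  // keep the two mirror maps consistent when a backend node is replaced: the
+  // front-end node is re-pointed at new_backend_anf, and the stale
+  // backend->front entry for old_backend_anf is erased at the end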
MS_EXCEPTION_IF_NULL(old_backend_anf); + MS_EXCEPTION_IF_NULL(new_backend_anf); + if (old_backend_anf == new_backend_anf) { + MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString(); + return; + } + if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { + MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map"; + return; + } + if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) { + MS_LOG(EXCEPTION) << "Anf is not exist in the map ,old " << old_backend_anf->DebugString(); + } + front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf; + backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf]; + // delete old kernel + (void)backend_front_anf_map_.erase(old_backend_anf); +} +// get kernel by anf +AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { + if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) { + return nullptr; + } + return front_backend_anf_map_[front_anf]; +} + +bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { + return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); +} + +ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) { + if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) { + return nullptr; + } + return tensor_to_value_node_map_[tensor]; +} + +void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(tensor); + MS_EXCEPTION_IF_NULL(value_node); + tensor_to_value_node_map_[tensor] = value_node; +} + +void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(input); + MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ",num:" << depend_edge_num; + auto output_depend_edge = std::pair(node, depend_edge_num); + // add output depend edge of input + auto output_it = node_output_edges_.find(input); + if (output_it == node_output_edges_.end()) { + node_output_edges_[input] = std::vector>{output_depend_edge}; + } else { + output_it->second.push_back(output_depend_edge); + } + // add input depend edge of output + auto input_depend_edge = std::pair(input, depend_edge_num); + auto input_it = node_input_edges_.find(node); + if (input_it == node_input_edges_.end()) { + node_input_edges_[node] = std::vector>{input_depend_edge}; + } else { + input_it->second.push_back(input_depend_edge); + } + // add node input depend num + auto depend_it = node_input_num_.find(node); + if (depend_it == node_input_num_.end()) { + node_input_num_[node] = depend_edge_num; + } else { + depend_it->second += depend_edge_num; + } +} + +std::vector KernelGraph::GetOutputNodes(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]"; + } + std::vector output_nodes; + auto trans = [](const std::pair &pair) -> AnfNodePtr { return pair.first; }; + (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans); + return output_nodes; +} + +// Find control_depend real input nodes. 
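+// "Real" inputs are the nodes a control edge can actually be attached to: the
+// wrappers MakeTuple, TupleGetItem and Depend are recursed through until a node
+// that passes AnfAlgo::IsRealKernel is reached.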
+void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
+                          std::set<AnfNodePtr> *visited) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(result);
+  MS_EXCEPTION_IF_NULL(visited);
+  if (visited->find(anf_node) != visited->end()) {
+    MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
+    return;
+  }
+  visited->insert(anf_node);
+  if (AnfAlgo::IsRealKernel(anf_node)) {
+    result->emplace_back(anf_node);
+    return;
+  }
+  if (!anf_node->isa<CNode>()) {
+    return;
+  }
+  auto cnode = anf_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->inputs().empty()) {
+    MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << anf_node->DebugString();
+  }
+  auto input0 = cnode->input(0);
+  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
+    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+      GetAllFatherRealNode(cnode->input(i), result, visited);
+    }
+  } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
+    if (cnode->inputs().size() != kTupleGetItemInputSize) {
+      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
+    }
+    GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
+  } else if (IsPrimitive(input0, prim::kPrimDepend)) {
+    if (cnode->inputs().size() != kDependInputSize) {
+      MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
+    }
+    GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
+    GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
+  }
+}
+
+// update the depend relations described by control depend nodes
+void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
+  for (const auto &node : depends) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      return;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
+      MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
+    }
+    auto prior_node = cnode->input(kControlDependPriorIndex);
+    auto depend_node = cnode->input(kControlDependBehindIndex);
+    MS_EXCEPTION_IF_NULL(prior_node);
+    MS_EXCEPTION_IF_NULL(depend_node);
+    std::vector<AnfNodePtr> prior_nodes = {prior_node};
+    std::vector<AnfNodePtr> depend_nodes = {depend_node};
+    int depend_mode = 0;
+    if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
+      depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
+    }
+    MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
+                 << "], depend_mode: " << depend_mode << ".";
+    if (prior_node->isa<Parameter>() && depend_mode == 1) {
+      prior_nodes = GetOutputNodes(prior_node);
+    }
+    if (depend_node->isa<Parameter>()) {
+      depend_nodes = depend_mode == 1 ?
GetOutputNodes(depend_node) : std::vector{}; + } + + std::vector real_prior_nodes; + std::set prior_visited; + for (const auto &tmp : prior_nodes) { + GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); + } + + std::vector real_depend_nodes; + std::set depend_visited; + for (const auto &tmp : depend_nodes) { + GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); + } + + for (auto &first_node : real_prior_nodes) { + if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { + continue; + } + for (auto &second_node : real_depend_nodes) { + if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { + continue; + } + MS_EXCEPTION_IF_NULL(first_node); + MS_EXCEPTION_IF_NULL(second_node); + MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); + AddDependEdge(second_node, first_node, 1); + } + } + } +} + +bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(que); + MS_EXCEPTION_IF_NULL(visited_nodes); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { + return false; + } + // set the control depend visited but don't push it into the que + if (visited_nodes->find(node) != visited_nodes->end()) { + return true; + } + (void)visited_nodes->insert(cnode); + // add a 0 depend num to keep the link relations to prepare for finding zero output nodes + auto prior_node = cnode->input(kControlDependPriorIndex); + auto depend_node = cnode->input(kControlDependBehindIndex); + for (const auto &input : cnode->inputs()) { + AddDependEdge(node, input, 0); + } + PushNoVisitedNode(depend_node, que, visited_nodes); + PushNoVisitedNode(prior_node, que, visited_nodes); + return true; +} + +void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { + MS_EXCEPTION_IF_NULL(seed_nodes); + node_output_edges_.clear(); + node_input_num_.clear(); + node_input_edges_.clear(); + std::vector control_depends; + std::unordered_set visited_nodes; + std::queue que; + que.push(get_return()); + while (!que.empty()) { + auto node = que.front(); + que.pop(); + MS_EXCEPTION_IF_NULL(node); + if (node->isa() || node->isa()) { + seed_nodes->push(node); + continue; + } + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // handle data links + for (const auto &input : cnode->inputs()) { + size_t depend_edge_num = 1; + // handle control depend,all inputs of control depend has no depend edge + if (HandleControlDependNode(input, &que, &visited_nodes)) { + control_depends.push_back(input); + depend_edge_num = 0; + } + PushNoVisitedNode(input, &que, &visited_nodes); + AddDependEdge(node, input, depend_edge_num); + } + } + UpdateControlDependRelations(control_depends); +} + +void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); } + +bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; } + +AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const { + if (!IsInRefOutputMap(out_pair)) { + MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap"; + } + return ref_out_in_map_.at(out_pair); +} + +void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) 
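+
+/*
+ * Sketch of how this ref-output map is meant to be used (the pairs are
+ * hypothetical): for an in-place kernel whose output aliases one of its
+ * inputs, the final {node, output_index} pair is registered against the
+ * origin {node, index} pair it writes through, so later passes can honour
+ * the aliasing when assigning device memory.
+ *
+ *   AnfWithOutIndex final_pair{assign_cnode, 0};
+ *   AnfWithOutIndex origin_pair{assign_cnode, 1};
+ *   graph->AddRefCorrespondPairs(final_pair, origin_pair);
+ *   if (graph->IsInRefOutputMap(final_pair)) {
+ *     auto origin = graph->GetRefCorrespondOutput(final_pair);  // {assign_cnode, 1}
+ *   }
+ */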
+void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
+  if (IsInRefOutputMap(final_pair)) {
+    MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap";
+  }
+  (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
+}
+
+bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
+  if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
+    (void)graph_value_nodes_.erase(value_node);
+    return true;
+  }
+  return false;
+}
+
+void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
+  MS_EXCEPTION_IF_NULL(inputs_);
+  {
+    std::queue<AnfNodePtr> seed_nodes;
+    UpdateNodeEdgeList(&seed_nodes);
+  }
+  auto it = node_output_edges_.find(old_anf_node);
+  if (it != node_output_edges_.end()) {
+    const auto &outputs = it->second;
+    for (auto &output_node : outputs) {
+      MS_EXCEPTION_IF_NULL(output_node.first);
+      auto output_cnode = output_node.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(output_cnode);
+      auto &output_node_inputs = output_cnode->inputs();
+      // don't replace the node if the edge is a control edge (output_node.second == 0)
+      if (output_node.second == 0) {
+        continue;
+      }
+      for (size_t i = 1; i < output_node_inputs.size(); i++) {
+        if (output_node_inputs[i] == old_anf_node.get()) {
+          output_cnode->set_input(i, new_anf_node);
+        }
+      }
+      // update graph inputs
+      for (size_t i = 0; i < inputs_->size(); i++) {
+        if ((*inputs_)[i] == old_anf_node.get()) {
+          MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
+                       << ", new graph input:" << new_anf_node->DebugString();
+          (*inputs_)[i] = new_anf_node.get();
+          break;
+        }
+      }
+    }
+    // update front to backend map
+    FrontBackendlMapUpdate(old_anf_node, new_anf_node);
+  }
+  {
+    std::queue<AnfNodePtr> seed_nodes;
+    UpdateNodeEdgeList(&seed_nodes);
+  }
+  // update graph inputs in child graph
+  auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
+                                     [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
+                                       return n.first == old_anf_node.get();
+                                     });
+  if (it_real_inputs != real_inputs_.end()) {
+    // erase old parameter in map
+    auto old_args = it_real_inputs->second;
+    real_inputs_.erase(it_real_inputs);
+    // insert new parameter to map
+    auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
+                             [&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
+                               return n.first == new_anf_node.get();
+                             });
+    if (iter != real_inputs_.end()) {
+      MS_LOG(WARNING) << new_anf_node->DebugString() << " already exists in real inputs, it will be rewritten.";
+      iter->second = old_args;
+    } else {
+      real_inputs_.emplace_back(new_anf_node, old_args);
+    }
+  }
+}
+
+void KernelGraph::UpdateExecuteKernelStreamLabel() {
+  for (auto &kernel : execution_order_) {
+    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
+  }
+}
+
+std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
+  std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
+  if (IsLeafGraph()) {
+    leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
+  } else {
+    for (const auto &child_graph : child_graph_order_) {
+      MS_EXCEPTION_IF_NULL(child_graph);
+      auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
+      std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
+    }
+  }
+  return leaf_graph_order;
+}
+
+bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
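+
+/*
+ * GetLeafGraphOrder above flattens the child-graph tree depth-first; with a
+ * hypothetical hierarchy root -> {g1 -> {g2, g3}, g4} it returns {g2, g3, g4},
+ * i.e. only graphs for which IsLeafGraph() holds appear, in left-to-right
+ * order:
+ *
+ *   auto leaves = root_graph->GetLeafGraphOrder();
+ *   for (const auto &leaf : leaves) {
+ *     MS_LOG(INFO) << leaf->ToString();
+ *   }
+ */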
+std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
+  std::vector<CNodePtr> result;
+  for (const auto &anf : execution_order_) {
+    if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
+      result.push_back(anf->cast<CNodePtr>());
+    }
+  }
+  return result;
+}
+
+void KernelGraph::SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg) {
+  MS_EXCEPTION_IF_NULL(parameter);
+  MS_EXCEPTION_IF_NULL(arg);
+  MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input: " << arg->DebugString();
+  auto iter = std::find_if(
+    real_inputs_.begin(), real_inputs_.end(),
+    [&parameter](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool { return n.first == parameter; });
+  if (iter != real_inputs_.end()) {
+    auto &args = iter->second;
+    args.push_back(arg);
+  } else {
+    real_inputs_.emplace_back(parameter, std::vector<AnfNodePtr>(1, arg));
+  }
+}
+
+void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
+  unreuse_args_[arg] = from_graph;
+}
+
+void KernelGraph::UpdateCallRealInput() {
+  MS_LOG(INFO) << "Update graph id: " << graph_id_;
+  std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
+  for (auto &it : real_inputs_) {
+    auto parameter = it.first;
+    MS_EXCEPTION_IF_NULL(parameter);
+    auto real_inputs = it.second;
+    std::vector<AnfNodePtr> new_real_inputs;
+    for (auto &real_input : real_inputs) {
+      // if the real input is a call node, find the child graph outputs to act as the new real inputs
+      auto tmp_real_input = GetCallRealOutputs(real_input);
+      std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
+      // replace the call in unreuse_args_
+      auto unreuse_arg_it = unreuse_args_.find(real_input);
+      if (unreuse_arg_it != unreuse_args_.end()) {
+        auto old_graph = unreuse_arg_it->second;
+        for (auto new_real_input : new_real_inputs) {
+          // if the call references a graph output that is a parameter, it is allowed to be reused
+          if (!new_real_input->isa<Parameter>()) {
+            unreuse_args_[new_real_input] = old_graph;
+          }
+        }
+      }
+    }
+    real_inputs_map.emplace_back(parameter, new_real_inputs);
+  }
+  real_inputs_ = real_inputs_map;
+}
+
+void KernelGraph::PrintGraphExecuteOrder() const {
+  MS_LOG(INFO) << "Graph:" << graph_id_ << " execution order";
+  for (size_t i = 0; i < execution_order_.size(); i++) {
+    CNodePtr cur_cnode_ptr = execution_order_[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    std::string event_str;
+    std::string label_str;
+    if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
+      event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
+    }
+
+    if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
+      label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
+    }
+
+    if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
+      auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
+      label_str = ", label_id[";
+      for (size_t j = 0; j < label_list.size(); ++j) {
+        label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
+      }
+    }
+
+    MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
+                 << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
+                 << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
+                 << event_str << label_str;
+  }
+}
+
+void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) {
+  if (front_node == nullptr || node == nullptr) {
+    MS_LOG(INFO) << "Front node or node is nullptr";
+    return;
+  }
+  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
+  front_to_internal_outputs_map_[front_node] = node;
+  internal_outputs_to_front_map_[node] = front_node;
+}
+
+void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
+  if (new_node == nullptr || node == nullptr) {
+    MS_LOG(INFO) << "New node or node is nullptr";
+    return;
+  }
+  if (node == new_node) {
+    MS_LOG(INFO) << "New node and old node are the same";
+    return;
+  }
+  auto iter = internal_outputs_to_front_map_.find(node);
+  if (iter == internal_outputs_to_front_map_.end()) {
+    MS_LOG(INFO) << "Node is not an internal output";
+    return;
+  }
+  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " to " << new_node->DebugString();
+  internal_outputs_to_front_map_[new_node] = iter->second;
+  front_to_internal_outputs_map_[iter->second] = new_node;
+  internal_outputs_to_front_map_.erase(iter);
+}
+
+AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
+  auto iter = front_to_internal_outputs_map_.find(front_node);
+  if (iter != front_to_internal_outputs_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
+  if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) {
+    return true;
+  }
+  return false;
+}
+
+AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
+  auto iter = internal_outputs_to_front_map_.find(node);
+  if (iter != internal_outputs_to_front_map_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
+  if (node == nullptr) {
+    return;
+  }
+  (void)final_output_kernels_.insert(node);
+}
+
+bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const {
+  if (node == nullptr) {
+    return false;
+  }
+  if (final_output_kernels_.find(node) != final_output_kernels_.end()) {
+    return true;
+  }
+  return false;
+}
+
+std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }
+
+KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); }
+}  // namespace session
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h
new file mode 100644
index 0000000000..f353ed1dda
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/kernel_graph.h
@@ -0,0 +1,226 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H
+#define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H
+
+#include <vector>
+#include <memory>
+#include <utility>
+#include <string>
+#include <queue>
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include "ir/func_graph.h"
+#include "ir/anf.h"
+#include "utils/graph_utils.h"
+#include "utils/contract.h"
+#include "runtime/device/kernel_info.h"
+
+namespace mindspore {
+namespace session {
+using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
+class KernelGraph : public FuncGraph {
+ public:
+  KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) {
+    inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
+    execution_order_ = {};
+    executable_ = true;
+    summary_node_exist_ = false;
+    stream_distinction_label_ = kInvalidDistincLabel;
+  }
+  ~KernelGraph() override;
+
+  MS_DECLARE_PARENT(KernelGraph, FuncGraph);
+
+  const std::vector<AnfNodePtr> &inputs() const;
+  std::vector<AnfNodePtr> *MutableInputs() const { return inputs_.get(); }
+  std::vector<AnfNodePtr> outputs() const;
+  CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs) override;
+  void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
+  CNodePtr NewCNode(const CNodePtr &cnode);
+  ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
+  ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
+  std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
+  void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
+  const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
+  void SetExecOrderByDefault();
+  uint32_t graph_id() const { return graph_id_; }
+  void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
+
+  // add a new front to backend anf relation to the map
+  void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf);
+  // replace old backend anf with new backend anf
+  void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf);
+  // get backend anf by front anf
+  AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf);
+  // check whether a backend node exists in the map
+  bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf);
+  // get value node by tensor
+  ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor);
+  // add value node tensor relation map
+  void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node);
+  // get all value nodes of graph
+  const std::unordered_set<ValueNodePtr> graph_value_nodes() const { return graph_value_nodes_; }
+  // add value node to graph
+  void AddValueNodeToGraph(const ValueNodePtr &value_node);
+  // check whether a ref output is in the map
+  bool IsInRefOutputMap(const AnfWithOutIndex &pair) const;
+  // get ref correspond pairs
+  AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const;
+  // add ref correspond pairs
+  void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
+  // get map
+  std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
+  // check whether a loop exists in the graph
+  void CheckLoop();
+  // check whether the graph is executable
+  bool executable() const { return executable_; }
+  // set executable of graph
+  void set_executable(bool executable) { executable_ = executable; }
+  // set summary_node of graph
+  void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; }
+  // check whether a summary node exists in the graph
+  bool summary_node_exist() const { return summary_node_exist_; }
+  // set invalid inputs for control sink
+  std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
+  std::vector<bool> valid_inputs() const { return valid_inputs_; }
+  // replace node in graph
+  void ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node);
+  // set stream label of graph
+  void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
+  // get stream label of graph
+  uint32_t stream_distinction_label() { return stream_distinction_label_; }
+  // refresh execute kernel stream label
+  void UpdateExecuteKernelStreamLabel();
+  // calculate the leaf graph order of root graph
+  std::vector<std::shared_ptr<KernelGraph>> GetLeafGraphOrder();
+  // the child graphs of current graph
+  const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order() const { return child_graph_order_; }
+  void set_child_graph_order(const std::vector<std::shared_ptr<KernelGraph>> &order) { child_graph_order_ = order; }
+  // check whether current graph is a leaf graph
+  bool IsLeafGraph() const;
+
+  // set input_tensors pointer of control parameter
+  void set_input_ctrl_tensors(const std::shared_ptr<std::vector<tensor::TensorPtr>> &input_tensors_ptr) {
+    input_ctrl_tensors_ = input_tensors_ptr;
+  }
+  // get input_tensors pointer of control parameter
+  std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors() const { return input_ctrl_tensors_; }
+  // get parent kernel graph
+  std::shared_ptr<KernelGraph> parent_graph() const { return parent_graph_; }
+  // set parent kernel graph
+  void set_parent_graph(const std::shared_ptr<KernelGraph> &parent_graph) { parent_graph_ = parent_graph; }
+  // find anf node in graph
+  std::vector<CNodePtr> FindNodeByPrimitive(const PrimitivePtr &primitive) const;
+  // get real inputs
+  const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
+  void SetRealInput(const AnfNodePtr &parameter, const AnfNodePtr &arg);
+  // mark args that must not be reused
+  void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
+  const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
+  // used to dump ir
+  std::string ToString() const override;
+  // update the real input if the node is a call
+  void UpdateCallRealInput();
+
+  void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; }
+  CNodePtr get_start_label() { return start_label_; }
+  void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; }
+  CNodePtr get_end_goto() { return end_goto_; }
+  bool get_output_null() { return null_output_; }
+  void set_output_null(bool is_output_null) { null_output_ = is_output_null; }
+  void PrintGraphExecuteOrder() const;
+  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
+  void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
+  void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
+  void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
+  AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
+  bool IsInternalOutput(const AnfNodePtr &node) const;
+  AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
+  void AddFinalOutputKernel(const AnfNodePtr &node);
+  bool IsFinalOutputKernel(const AnfNodePtr &node) const;
+  uint32_t current_epoch() const { return current_epoch_; }
+  void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
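+
+  /*
+   * Sketch of the construction flow this interface supports on the session
+   * side (names are illustrative, not part of this header):
+   *
+   *   auto graph = std::make_shared<KernelGraph>();
+   *   auto backend_cnode = graph->NewCNode(cnode_inputs);
+   *   graph->FrontBackendlMapAdd(front_cnode, backend_cnode);
+   *   graph->SetExecOrderByDefault();   // derive execution_order_
+   *   graph->PrintGraphExecuteOrder();
+   */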
+
+ private:
+  // remove value node from graph
+  bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
+  void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
+                            std::unordered_set<AnfNodePtr> *visited_nodes);
+  // update node edge list
+  void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
+  // add node depend edge by data edge or control depend
+  void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
+  // handle control depend
+  std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
+  bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
+                               std::unordered_set<AnfNodePtr> *visited_nodes);
+  void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
+
+  std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
+  std::vector<CNodePtr> execution_order_;
+  uint32_t graph_id_;
+  uint32_t stream_distinction_label_;
+
+  // record the map between front anf and backend anf; two maps implement a bidirectional map
+  std::unordered_map<AnfNodePtr, AnfNodePtr> front_backend_anf_map_;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> backend_front_anf_map_;
+  // a tensor may come from the ME backend; a value node will be created according to the tensor, this map records it
+  std::unordered_map<tensor::TensorPtr, ValueNodePtr> tensor_to_value_node_map_;
+  // include all value nodes
+  std::unordered_set<ValueNodePtr> graph_value_nodes_;
+  std::unordered_map<AnfNodePtr, size_t> node_input_num_;
+  std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_input_edges_;
+  // record the map between ref final output anf with index and ref origin input with index
+  std::map<AnfWithOutIndex, AnfWithOutIndex> ref_out_in_map_;
+  std::unordered_map<AnfNodePtr, std::vector<std::pair<AnfNodePtr, size_t>>> node_output_edges_;
+  std::map<std::string, std::pair<AnfNodePtr, int>> summary_nodes_;
+  // the graph needn't execute
+  bool executable_;
+  // whether a summary node exists in the graph
+  bool summary_node_exist_;
+  // valid inputs
+  std::vector<bool> valid_inputs_;
+
+  // new members for control sink process
+  // all child graphs referred to by partial nodes
+  std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> node_to_child_graphs_;
+  // child graph execute order in root graph
+  std::vector<std::shared_ptr<KernelGraph>> child_graph_order_;
+
+  // input_tensors of control parameter
+  std::shared_ptr<std::vector<tensor::TensorPtr>> input_ctrl_tensors_;
+
+  // parent graph
+  std::shared_ptr<KernelGraph> parent_graph_;
+  // record real parameters; inputs_ holds the formal parameters
+  std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
+  std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
+
+  CNodePtr start_label_;
+  CNodePtr end_goto_;
+  bool null_output_;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
+  std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
+  std::set<AnfNodePtr> final_output_kernels_;
+  uint32_t current_epoch_;
+};
+}  // namespace session
+using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H
diff --git a/mindspore/ccsrc/backend/session/session.cc b/mindspore/ccsrc/backend/session/session.cc
new file mode 100644
index 0000000000..95484a1113
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/session.cc
@@ -0,0 +1,208 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include +#include "include/inference.h" +#include "backend/session/session.h" +#include "utils/load_onnx/anf_converter.h" +#include "backend/session/session_basic.h" +#include "backend/session/session_factory.h" +#include "utils/base_ref_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#ifdef ENABLE_D +#include "utils/context/ms_context.h" +#include "backend/session/ascend_session.h" +#else +#include "backend/session/cpu_session.h" +#endif + +namespace py = pybind11; +namespace mindspore::inference { +std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { + try { + inference::Session::RegAllOp(); + auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); + return anf_graph; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference LoadModel failed"; + return nullptr; + } +} + +void ExitInference() { + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return; + } + if (!ms_context->CloseTsd()) { + MS_LOG(ERROR) << "Inference CloseTsd failed!"; + return; + } +} + +std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { + try { + auto session = std::make_shared(); + auto ret = session->Init(device, device_id); + if (ret != 0) { + return nullptr; + } + return session; + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference CreatSession failed"; + return nullptr; + } +} + +void Session::RegAllOp() { + static std::mutex init_mutex; + static bool Initialized = false; + + std::lock_guard lock(init_mutex); + if (Initialized) { + return; + } + Initialized = true; + MsContext::GetInstance()->set_execution_mode(kGraphMode); + Py_Initialize(); + auto c_expression = PyImport_ImportModule("mindspore._c_expression"); + if (c_expression == nullptr) { + MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; + return; + } + PyObject *c_expression_dict = PyModule_GetDict(c_expression); + + PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); + if (op_info_loader_class == nullptr) { + MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; + return; + } + PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); + if (op_info_loader == nullptr) { + MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; + return; + } + PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); + if (op_info_loader_ins == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; + return; + } + auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); + if (all_ops_info_vector_addr_ul == nullptr) { + MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; + return; + } + auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); + auto all_ops_info = static_cast *>(all_ops_info_vector_addr); + for (auto op_info : *all_ops_info) { + kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); + } + all_ops_info->clear(); + delete all_ops_info; + Py_DECREF(op_info_loader); + Py_DECREF(op_info_loader_class); + Py_DECREF(c_expression_dict); + Py_DECREF(c_expression); + return; +} + +uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { + MS_ASSERT(session_impl_ != nullptr); + try { + auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); + py::gil_scoped_release gil_release; + return graph_id; + } catch 
(std::exception &e) { + MS_LOG(ERROR) << "Inference CompileGraph failed"; + return static_cast(-1); + } +} + +MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { + try { + std::vector inTensors; + inTensors.resize(inputs.size()); + bool has_error = false; + std::transform(inputs.begin(), inputs.end(), inTensors.begin(), + [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { + if (tensor_ptr == nullptr) { + MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; + has_error = true; + return nullptr; + } + auto tensor = static_cast(tensor_ptr.get()); + if (tensor == nullptr) { + MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; + has_error = true; + return nullptr; + } + return tensor->tensor(); + }); + if (has_error) { + MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; + std::vector> multiTensor; + return multiTensor; + } + VectorRef outputs; + session_impl_->RunGraph(graph_id, inTensors, &outputs); + + return TransformVectorRefToMultiTensor(outputs); + } catch (std::exception &e) { + MS_LOG(ERROR) << "Inference Rungraph failed"; + return MultiTensor(); + } +} +namespace { +string AjustTargetName(const std::string &device) { + if (device == kAscendDevice) { + return std::string(kAscendDevice) + "Inference"; + } else { + MS_LOG(ERROR) << "Only support device Ascend right now"; + return ""; + } +} +} // namespace +int Session::Init(const std::string &device, uint32_t device_id) { + RegAllOp(); + auto ms_context = MsContext::GetInstance(); + ms_context->set_execution_mode(kGraphMode); + ms_context->set_device_id(device_id); + auto ajust_device = AjustTargetName(device); + if (ajust_device == "") { + return -1; + } + ms_context->set_device_target(device); + session_impl_ = session::SessionFactory::Get().Create(ajust_device); + if (session_impl_ == nullptr) { + MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; + return -1; + } + session_impl_->Init(device_id); + if (ms_context == nullptr) { + MS_LOG(ERROR) << "Get Context failed!"; + return -1; + } + if (!ms_context->OpenTsd()) { + MS_LOG(ERROR) << "Session init OpenTsd failed!"; + return -1; + } + return 0; +} + +Session::Session() = default; +} // namespace mindspore::inference diff --git a/mindspore/ccsrc/backend/session/session.h b/mindspore/ccsrc/backend/session/session.h new file mode 100644 index 0000000000..6ea9cfaa47 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H +#define MINDSPORE_CCSRC_SESSION_SESSION_H + +#include +#include +#include +#include +#include +#include + +#include "backend/session/session_basic.h" +#include "ir/anf.h" +#include "include/inference.h" + +namespace mindspore { +namespace inference { +class Session : public MSSession { + public: + Session(); + + uint32_t CompileGraph(std::shared_ptr funcGraphPtr) override; + + MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) override; + + int Init(const std::string &device, uint32_t device_id); + + static void RegAllOp(); + + private: + std::shared_ptr session_impl_ = nullptr; + std::vector graph_id_; +}; +} // namespace inference +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc new file mode 100644 index 0000000000..a7960c4695 --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -0,0 +1,1128 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/session/session_basic.h" +#include +#include +#include +#include +#include "pipeline/jit/parse/data_converter.h" +#include "ir/manager.h" +#include "ir/param_value.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "common/trans.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/optimizer/common/common_backend_optimization.h" +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "common/utils.h" +#include "ir/dtype.h" +#include "ir/anf.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace session { +static std::shared_ptr> python_paras; +void ClearPythonParasMap() { python_paras = nullptr; } +namespace { +const int kSummaryGetItem = 2; + +ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { + if (node == nullptr) { + return nullptr; + } + auto parameter = node->cast(); + if (parameter == nullptr || !parameter->has_default()) { + return nullptr; + } + return parameter->default_param(); +} + +BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; + // if node is a value node, no need sync addr from device to host + if (!AnfAlgo::OutputAddrExist(node, output_index)) { + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); + } + if (node->isa()) { + for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { + if (input_idx >= input_tensors.size()) { + 
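+          // Graph inputs and the user-provided input_tensors correspond by
+          // position, so an index that is valid for graph.inputs() but not
+          // for input_tensors can only mean too few tensors were fed; for
+          // example, feeding 2 tensors to a graph with 3 parameters trips
+          // this check at input_idx == 2.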
MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); + } + if (graph.inputs()[input_idx] == node) { + return input_tensors[input_idx]; + } + } + MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr"; + } + } + // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) + auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); + MS_EXCEPTION_IF_NULL(address); + auto shape = AnfAlgo::GetOutputInferShape(node, output_index); + TypeId type_id = kNumberTypeFloat32; + type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + std::vector temp_shape; + if (graph.IsInternalOutput(node)) { + temp_shape.emplace_back(1); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + tensor->set_device_address(address); + tensor->set_dirty(false); + return tensor; + } + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + // if in paynative mode,data only copyed to host when user want to print data + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { + tensor->set_device_address(address); + tensor->set_dirty(false); + } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index), + LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) { + MS_LOG(INFO) << "Output sync device to host error!!!"; + tensor->set_dirty(false); + } + return tensor; +} + +BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(anf); + MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); + MS_EXCEPTION_IF_NULL(item_with_index.first); + MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString(); + // special handle for maketuple + if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { + auto cnode = item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + VectorRef ret; + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors); + ret.push_back(out); + } + return ret; + } + // if is graph return nothing ,the function should return a null anylist + size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first); + if (size == 0) { + return VectorRef(); + } + return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors); +} + +BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(anf); + if (!AnfAlgo::IsRealKernel(anf)) { + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel"; + } + if (anf->isa()) { + return CreateOneTensor(anf, 0, graph, input_tensors); + } + VectorRef ret; + if (anf->isa() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) { + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) { + auto out = CreateOneTensor(anf, i, graph, input_tensors); + ret.emplace_back(out); + } + } + return ret; +} + +ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + auto value_node = 
anf->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + return nullptr; + } + auto new_value_node = graph->NewValueNode(value_node); + graph->FrontBackendlMapAdd(anf, new_value_node); + graph->AddValueNodeToGraph(new_value_node); + return new_value_node; +} + +size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Load kInputCtrlTensors"; + auto inputs_params = graph->input_ctrl_tensors(); + if (inputs_params == nullptr) { + return 0; + } + if (inputs_params->size() < 2) { + MS_LOG(EXCEPTION) << "Illegal inputs_params size"; + } + auto tensor = (*inputs_params)[0]; + MS_EXCEPTION_IF_NULL(tensor); + auto *val = static_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + tensor->set_dirty(true); + // set loop_count to zero + MS_EXCEPTION_IF_NULL(inputs); + inputs->push_back(tensor); + + auto epoch_tensor = (*inputs_params)[1]; + MS_EXCEPTION_IF_NULL(epoch_tensor); + auto *epoch_val = static_cast(epoch_tensor->data_c()); + MS_EXCEPTION_IF_NULL(epoch_val); + *epoch_val = graph->current_epoch(); + epoch_tensor->set_dirty(true); + inputs->push_back(epoch_tensor); + MS_LOG(INFO) << "Load epoch_val:" << *epoch_val; + + graph->set_current_epoch(graph->current_epoch() + 1); + + return inputs_params->size(); +} + +ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(input_tensor); + auto value_node = std::make_shared(input_tensor); + MS_EXCEPTION_IF_NULL(value_node); + // construct abstract of value node + auto type_of_tensor = input_tensor->Dtype(); + auto shape_of_tensor = input_tensor->shape(); + auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); + value_node->set_abstract(abstract); + // add value node to graph + auto input_value_node = graph->NewValueNode(value_node); + graph->AddValueNodeToGraph(input_value_node); + return input_value_node; +} + +ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor, + int tensor_mask) { + MS_EXCEPTION_IF_NULL(graph); + auto param = graph->NewParameter(); + MS_EXCEPTION_IF_NULL(param); + if (tensor_mask == kParameterWeightTensorMask) { + auto param_value_new = std::make_shared(); + param->set_default_param(param_value_new); + } + // set the kernel info of parameter + auto kernel_build_info_builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(input_tensor); + if (input_tensor->device_address().get() == nullptr) { + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? 
kTypeUnknown : input_tensor->data_type(); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{param_init_data_type}); + } else { + kernel_build_info_builder->SetOutputsFormat(std::vector{input_tensor->device_address()->format()}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{input_tensor->device_address()->type_id()}); + } + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); + // construct abstract of parameter + auto type_of_tensor = input_tensor->Dtype(); + auto shape_of_tensor = input_tensor->shape(); + auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); + param->set_abstract(abstract); + return param; +} + +void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { + MS_LOG(INFO) << "Graph outputs:"; + const size_t max_deep = 10; + if (recurse_level > max_deep) { + MS_LOG(INFO) << "Recurse too deep"; + return; + } + std::string tab_str; + for (size_t i = 0; i < recurse_level; i++) { + tab_str = tab_str.append(" "); + } + if (any.is()) { + (void)tab_str.append("{"); + MS_LOG(INFO) << tab_str; + auto any_list = any.cast(); + for (auto &it : any_list) { + DumpGraphOutput(it, recurse_level + 1); + } + (void)tab_str.append("}"); + MS_LOG(INFO) << tab_str; + } + (void)tab_str.append(any.ToString()); + MS_LOG(INFO) << tab_str; +} + +bool ExistSummaryNode(const KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto ret = graph->get_return(); + MS_EXCEPTION_IF_NULL(ret); + auto all_nodes = DeepLinkedGraphSearch(ret); + for (auto &n : all_nodes) { + if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || + IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { + return true; + } + } + return false; +} +} // namespace + +GraphId SessionBasic::graph_sum_ = 0; + +KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { + auto it = graphs_.find(graph_id); + if (it == graphs_.end()) { + MS_LOG(WARNING) << "Can't find graph " << graph_id; + return nullptr; + } + return it->second; +} + +void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { + auto graph_id = GetGraphIdByNode(out_node); + if (graph_id == kInvalidGraphId) { + return; + } + auto node_graph = GetGraph(graph_id); + if (node_graph == nullptr) { + return; + } + MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); + auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); + if (ref_node == nullptr) { + MS_LOG(INFO) << "No corresponding internal output for output node"; + return; + } + auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); + auto ref_real_node = real_kernel.first; + auto ref_real_node_index = real_kernel.second; + if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && + node_graph->IsFinalOutputKernel(ref_real_node)) { + auto kernel_info = ref_real_node->kernel_info(); + if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { + MS_LOG(INFO) << "No kernel info"; + return; + } + auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); + if (address == nullptr) { + MS_LOG(INFO) << "No kernel address"; + return; + } + auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); + auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); + parameter->set_kernel_info(std::make_shared()); + auto d_kernel_info = parameter->kernel_info(); + 
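+    // The parameter deliberately inherits the producing kernel's output
+    // format, device type and device address: a graph consuming the previous
+    // graph's internal output can then read it in place on the device instead
+    // of syncing it back to host, e.g. an output kept in a private device
+    // format is handed over without converting to the default format.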
MS_EXCEPTION_IF_NULL(d_kernel_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsDeviceType({type}); + builder.SetOutputsFormat({format}); + d_kernel_info->set_select_kernel_build_info(builder.Build()); + AnfAlgo::SetOutputAddr(address, 0, parameter.get()); + } +} + +std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, + KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + std::vector parameters; + std::vector pre_graph_out = {node}; + // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive + if (!AnfAlgo::IsRealKernel(node)) { + pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); + } + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { + auto parameter = graph->NewParameter(); + MS_EXCEPTION_IF_NULL(parameter); + parameter->set_abstract(abstract); + auto new_parameter = graph->NewParameter(parameter); + parameters.push_back(new_parameter); + valid_inputs->push_back(valid_input); + graph_inputs->push_back(new_parameter); + }; + for (const auto &out_node : pre_graph_out) { + MS_EXCEPTION_IF_NULL(out_node); + auto abstract = out_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + // create multiple parameters if is a tuple output real kernel + if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { + auto tuple_abstract = abstract->cast(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; + for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { + create_parameter((*tuple_abstract)[output_idx]); + } + continue; + } + // create single parameter if is a abstract real kernel + create_parameter(out_node->abstract()); + InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); + } + return parameters; +} + +ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, + KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; + } + MS_EXCEPTION_IF_NULL(graph); + auto param_value = GetParamDefaultValue(anf); + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + ParameterPtr new_parameter = nullptr; + // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter + if (python_paras == nullptr) { + python_paras = std::make_shared>(); + } + auto iter = python_paras->find(param_value); + if (iter != python_paras->end()) { + new_parameter = iter->second; + } else { + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + new_parameter = graph->NewParameter(anf->cast()); + if (param_value != nullptr) { + (*python_paras)[param_value] = new_parameter; + } + TraceManager::EndTrace(); + } + graph_inputs->push_back(new_parameter); + valid_inputs->push_back(valid_input); + return new_parameter; +} + +AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Create a new parameter from cnode[" << 
anf->DebugString() << "]"; + auto parameters = CreateParameterFromTuple(anf, valid_input, graph); + if (parameters.empty()) { + MS_LOG(EXCEPTION) << "No parameter exist!!"; + } + if (parameters.size() == 1) { + return parameters[0]; + } + std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; + (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); + auto make_tuple = graph->NewCNode(make_tuple_input); + MS_EXCEPTION_IF_NULL(make_tuple); + MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; + return make_tuple; +} + +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, + bool *from_other_graph, + std::unordered_map *other_graph_cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(from_other_graph); + MS_EXCEPTION_IF_NULL(other_graph_cnode); + *from_other_graph = false; + // get primitive of old node + std::vector cnode_inputs; + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + if (prim != nullptr) { + // push attr to inputs[0] of new cnode + cnode_inputs.push_back(std::make_shared(std::make_shared(*prim))); + } else { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(fg); + auto new_fg = BasicClone(fg); + cnode_inputs.push_back(std::make_shared(new_fg)); + } + auto origin_inputs = cnode->inputs(); + bool optimize_depend = false; + if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && + origin_inputs[kRealInputIndexInDepend]->isa()) { + optimize_depend = true; + } + // if has multiple depends,only select first depend as parameter + for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { + auto anf = origin_inputs[input_idx]; + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { + cnode_inputs.push_back((*other_graph_cnode)[anf]); + continue; + } else if (anf->isa() && !IsValueNode(anf)) { + // if input is a value node, + auto new_value_node = CreateNewValueNode(anf, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + continue; + } else if (anf->isa()) { + auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); + cnode_inputs.push_back(new_parameter); + if (GetGraphIdByNode(anf) == kInvalidGraphId) { + graph->FrontBackendlMapAdd(anf, new_parameter); + } else { + (*other_graph_cnode)[anf] = new_parameter; + } + continue; + } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { + cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); + continue; + } else { + *from_other_graph = true; + // the input node is a cnode from other graph + auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + cnode_inputs.push_back(parameter_from_cnode); + (*other_graph_cnode)[anf] = parameter_from_cnode; + } + } + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; +} + +CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(node_input); + MS_EXCEPTION_IF_NULL(graph); + // switch input generalizes partial + if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || + 
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { + return node_input->cast(); + } + if (node_input->isa()) { + MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call."; + } + std::vector partial_inputs = {NewValueNode(std::make_shared(prim::kPrimPartial->name()))}; + if (node_input->isa() && IsValueNode(node_input)) { + partial_inputs.emplace_back(node_input); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; + } + KernelGraphPtr kernel_graph = NewKernelGraph(); + MS_EXCEPTION_IF_NULL(kernel_graph); + kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); + partial_inputs.emplace_back(std::make_shared(kernel_graph)); + auto partial_node = graph->NewCNode(partial_inputs); + return partial_node; +} + +CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(graph); + auto node = anf_node->cast(); + MS_EXCEPTION_IF_NULL(node); + if (node->inputs().size() < kSwitchInputSize) { + MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; + } + auto primitive = NewValueNode(std::make_shared(prim::kPrimSwitch->name())); + std::vector switch_inputs = {primitive, node->input(1)}; + for (size_t index = 2; index < node->inputs().size(); index++) { + auto input = CreateSwitchInput(node->input(index), graph); + switch_inputs.emplace_back(input); + } + auto switch_node = graph->NewCNode(switch_inputs); + return switch_node; +} + +std::vector SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + // create primitive of cnode:call(partial or switch) + std::vector cnode_inputs = { + graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); + if (cnode_input == nullptr) { + MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() + << ", but input[0] has not been created."; + } + // if the node is partial, insert the inputs of partial to the call + if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { + auto partial_node = attr_input->cast(); + MS_EXCEPTION_IF_NULL(partial_node); + auto partial_inputs = partial_node->inputs(); + std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(), + std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node)); + return graph->GetBackendAnfByFrontAnf(node); + }); + return cnode_inputs; + } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { + auto switch_node = HandleSwitchInputs(cnode_input, graph); + cnode_inputs.emplace_back(switch_node); + return cnode_inputs; + } + MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; +} + +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(graph); + std::vector cnode_inputs; + auto attr_input = cnode->input(kAnfPrimitiveIndex); + MS_EXCEPTION_IF_NULL(attr_input); + if (AnfAlgo::IsGraphKernel(cnode)) { + auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); + MS_EXCEPTION_IF_NULL(fg); + auto new_fg = BasicClone(fg); + cnode_inputs.push_back(std::make_shared(new_fg)); + } else if (IsValueNode(attr_input)) { + // create 
primitive of cnode:call + cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; + // create a ValueNode as input of cnode:call + if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); + } else { + auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); + if (new_value_node != nullptr) { + cnode_inputs.emplace_back(new_value_node); + } + } + } else if (attr_input->isa()) { + cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); + } else { + // get primitive of old node + auto prim = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(prim); + // push attr to inputs[0] of new cnode + cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; + } + + for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { + auto anf = cnode->input(input_idx); + MS_EXCEPTION_IF_NULL(anf); + // anf has been created before + if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { + cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); + continue; + } else if (IsValueNode(anf)) { + continue; + } + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; + } + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; +} + +ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + auto value_node = anf->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); + MS_EXCEPTION_IF_NULL(sub_func_graph); + if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { + MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; + } + auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; + + ValueNodePtr new_value_node = std::make_shared(sub_kernel_graph); + new_value_node->set_abstract(value_node->abstract()); + // create new kernel_info of new value_node + auto kernel_info = std::make_shared(); + kernel_info->SetFeatureMapFlag(false); + new_value_node->set_kernel_info(kernel_info); + // create kernel_build_info for new value node + auto kernel_build_info_builder = std::make_shared(); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); + AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); + + graph->FrontBackendlMapAdd(anf, new_value_node); + + return new_value_node; +} + +ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(graph); + if (!anf->isa()) { + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; + } + + auto param_value = GetParamDefaultValue(anf); + ParameterPtr new_parameter = nullptr; + if (python_paras == nullptr) { + python_paras = std::make_shared>(); + } + auto iter = python_paras->find(param_value); + if (iter != python_paras->end()) { + new_parameter = iter->second; + } else { + TraceManager::DebugTrace(std::make_shared(anf->debug_info())); + new_parameter = graph->NewParameter(anf->cast()); + if (param_value != nullptr) { + (*python_paras)[param_value] = new_parameter; + } + TraceManager::EndTrace(); + } + + return new_parameter; +} + +KernelGraphPtr SessionBasic::ConstructKernelGraph(const 
+
+KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
+  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
+  auto graph = NewKernelGraph();
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
+  size_t from_other_graph_depend_num = 0;
+  for (const auto &node : lst) {
+    MS_EXCEPTION_IF_NULL(node);
+    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
+    if (!node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not a CNode";
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    // create a new cnode object
+    bool from_other_graph = false;
+    // only the first depend coming from another graph is treated as a valid input
+    bool valid_input = true;
+    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
+      valid_input = false;
+    }
+    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
+    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
+      from_other_graph_depend_num++;
+    }
+    MS_EXCEPTION_IF_NULL(new_cnode);
+    new_cnode->set_abstract(cnode->abstract());
+    new_cnode->set_scope(cnode->scope());
+    // record map relations between anf from ME and the new anf node used in backend
+    graph->FrontBackendlMapAdd(node, new_cnode);
+  }
+  // add a make_tuple at the end of the graph as output
+  graph->set_output(ConstructOutput(outputs, graph));
+  MS_EXCEPTION_IF_NULL(context_);
+  FuncGraphManagerPtr manager = MakeManager({graph});
+  if (manager) {
+    manager->AddFuncGraph(graph);
+    graph->set_manager(manager);
+  }
+  graph->SetExecOrderByDefault();
+  if (ExistSummaryNode(graph.get())) {
+    graph->set_summary_node_exist(true);
+  }
+  opt::BackendCommonOptimization(graph);
+  return graph;
+}
+
+void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(graph);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  // create a new cnode object
+  auto new_cnode = CreateNewCNode(cnode, graph.get());
+  MS_EXCEPTION_IF_NULL(new_cnode);
+  new_cnode->set_abstract(cnode->abstract());
+  new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
+  new_cnode->set_scope(cnode->scope());
+  graph->FrontBackendlMapAdd(node, new_cnode);
+  if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
+    graph->set_return(new_cnode);
+  }
+}
+
+std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
+                                                                std::vector<KernelGraphPtr> *all_out_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(all_out_graph);
+  auto node_list = TopoSort(func_graph->get_return());
+  auto graph = NewKernelGraph();
+  MS_EXCEPTION_IF_NULL(graph);
+  front_backend_graph_map_[func_graph] = graph;
+  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
+
+  bool is_trace_back = false;
+  for (const auto &node : node_list) {
+    MS_EXCEPTION_IF_NULL(node);
+    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
+    if (node->isa<Parameter>()) {
+      auto graph_inputs = graph->MutableInputs();
+      MS_EXCEPTION_IF_NULL(graph_inputs);
+      auto new_parameter = CreateNewParameter(node, graph.get());
+      graph_inputs->push_back(new_parameter);
+      graph->FrontBackendlMapAdd(node, new_parameter);
+      continue;
+    } else if (node->isa<ValueNode>()) {
+      if (!IsValueNode<FuncGraph>(node)) {
+        // if input is a common value node, create a corresponding backend value node
+        (void)CreateNewValueNode(node, graph.get());
+      } else {
+        // if input is a ValueNode of FuncGraph, transform the child graph first
+        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
+        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
+          is_trace_back = true;
+        } else {
+          (void)ConstructKernelGraph(child_graph, all_out_graph);
+        }
+        (void)CreateValueNodeKernelGraph(node, graph.get());
+      }
+      continue;
+    } else {
+      CreateCNodeKernelGraph(node, graph);
+    }
+  }
+  // if a graph jumps back unconditionally, the return op of this graph will never be executed, so its output is null
+  graph->set_output_null(is_trace_back);
+  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
+  graph->SetExecOrderByDefault();
+  if (ExistSummaryNode(graph.get())) {
+    graph->set_summary_node_exist(true);
+  }
+  all_out_graph->push_back(graph);
+  return graph;
+}
+
+void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto graph_inputs = graph->MutableInputs();
+  MS_EXCEPTION_IF_NULL(graph_inputs);
+  graph_inputs->clear();
+  for (auto &parameter : parameters) {
+    MS_EXCEPTION_IF_NULL(parameter);
+    auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
+    if (backend_parameter == nullptr) {
+      // for example "def f(x, y, z) { return x + y }", parameter z is unused
+      auto new_parameter = CreateNewParameter(parameter, graph);
+      graph_inputs->push_back(new_parameter);
+      MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
+      continue;
+    }
+    MS_LOG(INFO) << "Graph[" << graph->graph_id() << "], parameter:" << parameter->DebugString();
+    graph_inputs->push_back(backend_parameter);
+  }
+}
+
+// run graph steps
+void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
+  std::vector<tensor::TensorPtr> inputs(inputs_const);
+  size_t input_ctrl_size = 2;
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  if (kernel_graph->input_ctrl_tensors()) {
+    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
+  }
+  auto input_nodes = kernel_graph->inputs();
+  if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
+    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
+                      << ", input_ctrl_size:" << input_ctrl_size;
+  }
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    auto tensor = inputs[i];
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto input_node = input_nodes[i];
+    MS_EXCEPTION_IF_NULL(input_node);
+    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
+      auto pk_node = input_node->cast<ParameterPtr>();
+      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
+      bool need_sync = false;
+      if (ms_context->enable_pynative_infer()) {
+        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
+          need_sync = true;
+        }
+      } else {
+        if (tensor->is_dirty()) {
+          need_sync = true;
+        } else if (tensor->device_address() != device_address) {
+          (void)tensor->data_sync();
+          need_sync = true;
+        }
+      }
+      if (need_sync) {
+        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
+          tensor->set_device_address(device_address);
+        }
+        MS_EXCEPTION_IF_NULL(device_address);
+        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
+                                              tensor->data_c())) {
+          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
+        }
+      }
+    }
+    tensor->set_dirty(false);
+  }
+}
+
+void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
+                                 const std::vector<tensor::TensorPtr> &input_tensors) const {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(outputs);
+  if (!kernel_graph->child_graph_order().empty()) {
+    // use the last child graph output as the root graph output
+    UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
+    return;
+  }
+  auto anf_outputs = kernel_graph->outputs();
+  for (auto &item : anf_outputs) {
+    MS_EXCEPTION_IF_NULL(item);
+    MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
+    if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
+      outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
+      continue;
+    }
+    outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
+  }
+}
+
+void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
+  MS_EXCEPTION_IF_NULL(callback);
+  summary_callback_ = callback;
+}
+
+void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
+
+void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
+  MS_LOG(DEBUG) << "Update summary start";
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!graph->summary_node_exist()) {
+    return;
+  }
+  auto summary = graph->summary_nodes();
+  auto apply_list = TopoSort(graph->get_return());
+  for (auto &n : apply_list) {
+    MS_EXCEPTION_IF_NULL(n);
+    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
+        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
+      auto cnode = n->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (cnode->inputs().size() <= kSummaryGetItem) {
+        MS_LOG(EXCEPTION) << "The node Summary should have at least 2 inputs!";
+      }
+      auto node = cnode->input(kSummaryGetItem);
+      MS_EXCEPTION_IF_NULL(node);
+      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
+      MS_EXCEPTION_IF_NULL(item_with_index.first);
+      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
+        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
+      }
+      summary[n->fullname_with_scope()] = item_with_index;
+    }
+  }
+  graph->set_summary_nodes(summary);
+  MS_LOG(DEBUG) << "Update summary end, size: " << summary.size();
+}
+
+void SessionBasic::Summary(KernelGraph *graph) {
+  if (summary_callback_ == nullptr) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(graph);
+  bool exist_summary = graph->summary_node_exist();
+  if (!exist_summary) {
+    return;
+  }
+  GetSummaryNodes(graph);
+  auto summary_outputs = graph->summary_nodes();
+  std::map<std::string, tensor::TensorPtr> params_list;
+  // fetch output apply kernels in session & run callback functions
+  for (auto &output_item : summary_outputs) {
+    auto node = output_item.second.first;
+    size_t index = IntToSize(output_item.second.second);
+    auto address = AnfAlgo::GetOutputAddr(node, index);
+    auto shape = AnfAlgo::GetOutputInferShape(node, index);
+    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
+    std::vector<int> temp_shape;
+    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
+    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
+    MS_EXCEPTION_IF_NULL(address);
+    if (!address->GetPtr()) {
+      continue;
+    }
+    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
+                                   tensor->data_type(), tensor->data_c())) {
+      MS_LOG(ERROR) << "Failed to sync output from device to host.";
+    }
+    tensor->set_dirty(false);
+    params_list[output_item.first] = tensor;
+  }
+  // call callback function here
+  summary_callback_(0, params_list);
+}
+
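RegisterSummaryCallBackFunc and Summary above implement a plain C-style callback: after a run, the session syncs each summary tensor from device to host into a name-to-tensor map and invokes the registered function once, passing graph id 0. The standalone sketch below mirrors only that registration-and-dispatch shape; MiniSession, FakeTensor and PrintSummaries are invented stand-ins for illustration, not MindSpore API.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct FakeTensor {  // invented stand-in for tensor::TensorPtr
  std::string data;
};

// Same shape as the document's CallBackFunc: a plain function pointer.
using MiniCallBackFunc = uint32_t (*)(uint32_t graph_id, const std::map<std::string, FakeTensor> &params_list);

class MiniSession {
 public:
  void RegisterSummaryCallBackFunc(MiniCallBackFunc callback) { summary_callback_ = callback; }

  // Analogue of SessionBasic::Summary: collect host copies of the summary
  // tensors, then fire the registered callback once.
  void Summary(uint32_t graph_id) {
    if (summary_callback_ == nullptr) {
      return;  // nothing registered, nothing to do
    }
    std::map<std::string, FakeTensor> params_list;
    params_list["loss[:Scalar]"] = FakeTensor{"0.42"};  // the real session does SyncDeviceToHost here
    summary_callback_(graph_id, params_list);
  }

 private:
  MiniCallBackFunc summary_callback_ = nullptr;
};

uint32_t PrintSummaries(uint32_t graph_id, const std::map<std::string, FakeTensor> &params_list) {
  for (const auto &item : params_list) {
    std::cout << "graph " << graph_id << ": " << item.first << " = " << item.second.data << "\n";
  }
  return 0;
}

int main() {
  MiniSession session;
  session.RegisterSummaryCallBackFunc(PrintSummaries);
  session.Summary(0);
  return 0;
}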
+CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> output_args;
+  for (const auto &output : outputs) {
+    MS_EXCEPTION_IF_NULL(output);
+    MS_LOG(INFO) << "Output:" << output->DebugString();
+  }
+  auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
+    auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
+    if (backend_anf != nullptr) {
+      auto context_ptr = MsContext::GetInstance();
+      MS_EXCEPTION_IF_NULL(context_ptr);
+      if (context_ptr->execution_mode() == kPynativeMode) {
+        return backend_anf;
+      }
+      auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
+      auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
+      MS_EXCEPTION_IF_NULL(out);
+      auto out_func_graph = out->func_graph();
+      MS_EXCEPTION_IF_NULL(out_func_graph);
+      auto out_func_graph_manager = out_func_graph->manager();
+      if (out_func_graph_manager == nullptr) {
+        return backend_anf;
+      }
+      auto node_users = out_func_graph_manager->node_users();
+      auto users = node_users[out];
+      bool internal_output = true;
+      std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
+      for (auto user : users) {
+        if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
+          internal_output = false;
+          break;
+        }
+      }
+      if (internal_output) {
+        MS_LOG(INFO) << "Internal output: " << out->DebugString() << " to "
+                     << backend_real_kernel.first->DebugString();
+        graph->AddInternalOutput(out, backend_real_kernel.first);
+      }
+      return backend_anf;
+    }
+    MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
+  };
+  output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
+  (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
+                       [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
+  return graph->NewCNode(output_args);
+}
+
+void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
+  MS_LOG(INFO) << "Start!";
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  MS_EXCEPTION_IF_NULL(graph);
+  if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
+    for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
+      auto idx = NewValueNode(SizeToInt(output_index));
+      MS_EXCEPTION_IF_NULL(idx);
+      auto imm = std::make_shared<Int32Imm>(SizeToInt(output_index));
+      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
+      auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
+      std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
+      std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
+      AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
+      make_tuple_inputs.push_back(getitem);
+    }
+  } else {
+    make_tuple_inputs.push_back(cnode);
+  }
+  // create output
+  auto g_output = graph->NewCNode(make_tuple_inputs);
+  graph->set_output(g_output);
+  // set graph manager, which is now only used to get value nodes and for hardware optimization
+  MS_EXCEPTION_IF_NULL(context_);
+  FuncGraphManagerPtr manager = context_->manager();
+  if (manager != nullptr) {
+    manager->AddFuncGraph(graph);
+    graph->set_manager(manager);
+  }
+  MS_LOG(INFO) << "Finish!";
+}
+
+std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
+                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
+                                                                  const std::vector<int> &tensors_mask) {
+  auto graph = std::make_shared<KernelGraph>();
+  std::vector<AnfNodePtr> inputs;
+  // set input[0]: the primitive of the op
+  PrimitivePtr op_prim = op_run_info.py_primitive;
+  MS_EXCEPTION_IF_NULL(op_prim);
+  inputs.push_back(std::make_shared<ValueNode>(op_prim));
+  // set input parameters
+  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
+  if (input_tensors.size() != tensors_mask.size()) {
+    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
+                      << tensors_mask.size();
+  }
+  for (size_t i = 0; i < input_tensors.size(); ++i) {
+    if (tensors_mask[i] == kValueNodeTensorMask) {
+      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
+      inputs.push_back(value_node);
+      continue;
+    }
+    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
+    inputs.push_back(parameter);
+    auto mutable_inputs = graph->MutableInputs();
+    MS_EXCEPTION_IF_NULL(mutable_inputs);
+    mutable_inputs->push_back(parameter);
+  }
+  // create the cnode
+  auto cnode = graph->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(cnode);
+  // set abstract, which includes inferred shapes and types
+  cnode->set_abstract(op_run_info.abstract);
+  // set execution order
+  std::vector<CNodePtr> exe_order = {cnode};
+  graph->set_execution_order(exe_order);
+  // set output
+  CreateOutputNode(cnode, graph);
+  return graph;
+}
+
+BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
+  if (utils::isa<VectorRef>(base_ref)) {
+    auto ref_list = utils::cast<VectorRef>(base_ref);
+    py::tuple output_tensors(ref_list.size());
+    for (size_t i = 0; i < ref_list.size(); ++i) {
+      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use PyObjectRef
+      if (utils::isa<tensor::TensorPtr>(output)) {
+        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
+        MS_EXCEPTION_IF_NULL(tensor_ptr);
+        output_tensors[i] = tensor_ptr;
+      } else if (utils::isa<PyObjectRef>(output)) {
+        py::object obj = utils::cast<PyObjectRef>(output).object_;
+        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
+        output_tensors[i] = tensor_tuple;
+      } else {
+        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+      }
+    }
+    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
+  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
+    return base_ref;
+  } else {
+    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+  }
+}
+
+KernelGraphPtr SessionBasic::NewKernelGraph() {
+  auto graph = std::make_shared<KernelGraph>();
+  graph->set_graph_id(graph_sum_);
+  graphs_[graph_sum_++] = graph;
+  return graph;
+}
+}  // namespace session
+}  // namespace mindspore
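TransformBaseRefListToTuple above recursively flattens a possibly nested VectorRef of result tensors into nested Python tuples. The shape of that recursion can be shown without pybind11; the sketch below is illustration only, with Ref, string leaves and ToTuple as invented stand-ins for BaseRef/VectorRef and the py::tuple conversion.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in for BaseRef: either a leaf "tensor" (a string here) or a nested list.
struct Ref {
  std::string tensor;                      // set when this node is a leaf
  std::vector<std::shared_ptr<Ref>> list;  // set when this node is a nested list
  bool is_list() const { return !list.empty(); }
};

// Analogue of TransformBaseRefListToTuple: recurse into nested lists,
// converting each level into a printable "tuple".
std::string ToTuple(const Ref &ref) {
  if (!ref.is_list()) {
    return ref.tensor;
  }
  std::string out = "(";
  for (size_t i = 0; i < ref.list.size(); ++i) {
    out += ToTuple(*ref.list[i]);
    if (i + 1 < ref.list.size()) {
      out += ", ";
    }
  }
  return out + ")";
}

int main() {
  auto leaf = [](std::string s) {
    auto r = std::make_shared<Ref>();
    r->tensor = std::move(s);
    return r;
  };
  Ref nested;
  nested.list = {leaf("t0"), leaf("t1")};
  auto inner = std::make_shared<Ref>();
  inner->list = {leaf("t2")};
  nested.list.push_back(inner);
  std::cout << ToTuple(nested) << "\n";  // prints: (t0, t1, (t2))
  return 0;
}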
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
new file mode 100755
index 0000000000..c662e3978b
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
+#define MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
+
+#include <vector>
+#include <string>
+#include <utility>
+#include <memory>
+#include <map>
+#include <unordered_map>
+
+#include "utils/base_ref_extends.h"
+#include "backend/session/session_context.h"
+#include "backend/session/kernel_graph.h"
+#include "ir/anf.h"
+#include "ir/tensor.h"
+#include "utils/any.h"
+#include "utils/contract.h"
+#include "pipeline/pynative/pynative_execute.h"
+#include "runtime/device/kernel_info.h"
+#ifdef ENABLE_DEBUGGER
+#include "debug/debugger/debugger.h"
+#endif
+
+namespace mindspore {
+using GraphId = uint32_t;
+using GraphInfo = std::string;
+namespace session {
+void ClearPythonParasMap();
+using CallBackFunc = uint32_t (*)(uint32_t graph_id,
+                                  const std::map<std::string, tensor::TensorPtr> &params_list);
+using AnyList = std::vector<Any>;
+using AnyListPtr = std::shared_ptr<AnyList>;
+
+using OpRunInfo = pynative::OpExecInfo;
+using OpRunInfoPtr = std::shared_ptr<OpRunInfo>;
+
+class SessionBasic {
+ public:
+  SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) {
+#ifdef ENABLE_DEBUGGER
+    debugger_ = nullptr;
+#endif
+  }
+
+  virtual void Init(uint32_t device_id) { device_id_ = device_id; }
+
+  virtual ~SessionBasic() { summary_callback_ = nullptr; }
+
+  virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
+  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
+  // build graph, used to handle multiple child graphs
+  virtual void BuildGraph(GraphId) {}
+
+  virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
+
+  virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
+                       const std::vector<int> &tensors_mask) {}
+
+  virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
+    return py::tuple();
+  }
+
+  virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
+
+  void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph);
+
+  std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
+  std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
+                                                    std::vector<KernelGraphPtr> *all_out_graph);
+
+  CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
+                          std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
+  CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph);
+
+  CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph);
+  CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph);
+  std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph);
+
+  // set parameters of final graph
+  virtual GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &) { return kInvalidGraphId; }
+  // set output of final graph
+  virtual void SetFinalGraphOutput(const BaseRef &) {}
+  // insert switch and set the related active ops
+  virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {}
+  // set args of a child graph; an arg may come from the output of another child graph or from a final graph parameter
+  virtual void SetChildGraphInput(GraphId, const VectorRef &) {}
+  // get graph id in child graphs by ME front anf node pointer
+  virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
+  virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
+  virtual void SetActive(GraphId, GraphId) {}
+  virtual void GetSummaryNodes(KernelGraph *graph);
+
+#ifdef ENABLE_DEBUGGER
+  // set debugger
+  void SetDebugger() {
+    debugger_ = Debugger::GetInstance();
+    debugger_->Init(device_id_);
+  }
+#endif
+
+ protected:
+  // get graph by graph id; if it does not exist, return nullptr
+  KernelGraphPtr GetGraph(GraphId graph_id);
+  virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
+                             const std::vector<tensor::TensorPtr> &inputs_const) const;
+  void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
+                     const std::vector<tensor::TensorPtr> &input_tensors) const;
+  void Reorder(std::vector<CNodePtr> *node_list);
+  void Summary(KernelGraph *graph);
+  // create graph output for RunOp
+  void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
+  CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
+  // create a single run op graph
+  std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
+                                                      const std::vector<tensor::TensorPtr> &input_tensors,
+                                                      const std::vector<int> &tensors_mask);
+  // trans BaseRef list to py::tuple
+  BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
+  // create a new kernel graph and update the graph sum
+  KernelGraphPtr NewKernelGraph();
+  std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph);
+  virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
+  ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph);
+  ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph);
+  AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph);
+  void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
+  void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
+
+  std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
+  std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
+  std::unordered_map<FuncGraphPtr, KernelGraphPtr> front_backend_graph_map_;
+  std::shared_ptr<Context> context_;
+  CallBackFunc summary_callback_;
+  static GraphId graph_sum_;
+  uint32_t device_id_;
+#ifdef ENABLE_DEBUGGER
+  std::shared_ptr<Debugger> debugger_;
+#endif
+};
+
+using SessionPtr = std::shared_ptr<session::SessionBasic>;
+using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>;
+}  // namespace session
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
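session_basic.h defines the abstract device session: each backend subclasses SessionBasic, overrides CompileGraph/RunGraph, and registers a creator keyed by device name (the MS_REG_SESSION macro in session_factory.h further below expands to exactly such a registration). Below is a self-contained sketch of that subclass-plus-registrar pattern; the Mini* names are invented stand-ins, not the real MindSpore classes.

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Invented stand-in for the SessionBasic interface; illustration only.
class MiniSessionBasic {
 public:
  virtual ~MiniSessionBasic() = default;
  virtual void Init(uint32_t device_id) { device_id_ = device_id; }
  virtual void RunGraph(uint32_t graph_id) = 0;  // each backend overrides this

 protected:
  uint32_t device_id_ = 0;
};

using MiniCreator = std::function<std::shared_ptr<MiniSessionBasic>()>;

// Analogue of SessionFactory: a name -> creator map behind a singleton.
class MiniFactory {
 public:
  static MiniFactory &Get() {
    static MiniFactory instance;
    return instance;
  }
  void Register(const std::string &device, MiniCreator creator) { creators_[device] = std::move(creator); }
  std::shared_ptr<MiniSessionBasic> Create(const std::string &device) {
    auto iter = creators_.find(device);
    return iter == creators_.end() ? nullptr : iter->second();
  }

 private:
  std::map<std::string, MiniCreator> creators_;
};

class MiniCPUSession : public MiniSessionBasic {
 public:
  void RunGraph(uint32_t graph_id) override { std::cout << "CPU runs graph " << graph_id << "\n"; }
};

// Analogue of MS_REG_SESSION: a static object performs the registration at startup.
static const bool g_cpu_registered = [] {
  MiniFactory::Get().Register("CPU", [] { return std::make_shared<MiniCPUSession>(); });
  return true;
}();

int main() {
  auto session = MiniFactory::Get().Create("CPU");
  if (session != nullptr) {
    session->Init(0);
    session->RunGraph(0);
  }
  return 0;
}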
+ */ +#include "backend/session/session_context.h" +namespace mindspore { +namespace session { +std::shared_ptr Context::GetInstance() { + static std::shared_ptr context_singleton = std::make_shared(); + return context_singleton; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_context.h b/mindspore/ccsrc/backend/session/session_context.h new file mode 100644 index 0000000000..22cc0c813a --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_context.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#define MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "pipeline/jit/resource.h" +#include "utils/context/ms_context.h" +namespace mindspore { +namespace session { +const char kInputCtrlTensors[] = "input_ctrl_tensors"; + +class Context : public pipeline::ResourceBase { + public: + explicit Context(std::string target = kAscendDevice, uint32_t device_id = 0) + : target_(std::move(target)), device_id_(device_id) {} + ~Context() override = default; + + uint32_t device_id() const { return device_id_; } + static std::shared_ptr GetInstance(); + void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); } + + private: + std::vector manager_list_; + std::string target_; + uint32_t device_id_; +}; +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H diff --git a/mindspore/ccsrc/backend/session/session_factory.cc b/mindspore/ccsrc/backend/session/session_factory.cc new file mode 100644 index 0000000000..8a8f9a9cea --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_factory.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "backend/session/session_factory.h" +#include +#include +#include +namespace mindspore { +namespace session { +SessionFactory &SessionFactory::Get() { + static SessionFactory instance; + return instance; +} + +void SessionFactory::Register(const std::string &device_name, SessionCreator &&session_creator) { + if (session_creators_.end() == session_creators_.find(device_name)) { + (void)session_creators_.emplace(device_name, session_creator); + } +} + +std::shared_ptr SessionFactory::Create(const std::string &device_name) { + auto iter = session_creators_.find(device_name); + if (session_creators_.end() != iter) { + MS_EXCEPTION_IF_NULL(iter->second); + return (iter->second)(); + } + return nullptr; +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/session_factory.h b/mindspore/ccsrc/backend/session/session_factory.h new file mode 100644 index 0000000000..054f03cf4b --- /dev/null +++ b/mindspore/ccsrc/backend/session/session_factory.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ +#define MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ + +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "backend/session/session_basic.h" +namespace mindspore { +namespace session { +using SessionCreator = std::function()>; +class SessionFactory { + public: + static SessionFactory &Get(); + void Register(const std::string &device_name, SessionCreator &&session_creator); + std::shared_ptr Create(const std::string &device_name); + + private: + SessionFactory() = default; + ~SessionFactory() = default; + DISABLE_COPY_AND_ASSIGN(SessionFactory) + std::map session_creators_; +}; + +class SessionRegistrar { + public: + SessionRegistrar(const std::string &device_name, SessionCreator &&session_creator) { + SessionFactory::Get().Register(device_name, std::move(session_creator)); + } + ~SessionRegistrar() = default; +}; + +#define MS_REG_SESSION(DEVICE_NAME, SESSION_CLASS) \ + static const SessionRegistrar g_session_registrar__##DEVICE_NAME##_##_reg( \ + DEVICE_NAME, []() { return std::make_shared(); }); +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ diff --git a/mindspore/ccsrc/common.h b/mindspore/ccsrc/common.h index a545be32c7..6b882a15d4 100644 --- a/mindspore/ccsrc/common.h +++ b/mindspore/ccsrc/common.h @@ -25,11 +25,11 @@ #include "abstract/dshape.h" #include "abstract/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/resolve.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/resolve.h" 
namespace py = pybind11; #endif // MINDSPORE_CCSRC_COMMON_H_ diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 9cf6eb3a5a..1841826ca9 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -18,9 +18,9 @@ #include #include #include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel.h" -#include "device/convert_tensor_utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/convert_tensor_utils.h" #include "utils/convert_utils.h" #include "utils/log_adapter.h" #include "utils/utils.h" diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index a8fc7c8a00..286c76afd0 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -24,7 +24,7 @@ #include #include #include "ir/dtype.h" -#include "kernel/kernel.h" +#include "backend/kernel_compiler/kernel.h" #include "ir/dtype/type.h" namespace mindspore { diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt deleted file mode 100644 index 4b84c4d797..0000000000 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ /dev/null @@ -1,159 +0,0 @@ -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-reorder") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-switch") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sequence-point") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable") - -if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-uninitialized") -else() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-maybe-uninitialized") -endif() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") - -############################# Options ################################ -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - add_definitions(-D _CRT_RAND_S) -endif () -if (ENABLE_GPUQUE) - add_definitions(-D ENABLE_GPUQUE) - message(STATUS "GPU queue is enabled") -endif () -if (ENABLE_TDTQUE) - add_definitions(-D ENABLE_TDTQUE) - message(STATUS "TDT queue is enabled") -endif () - -# conde coverage -# option(ENABLE_COVERAGE "Enable code coverage report" OFF) -# if (ENABLE_COVERAGE) -# include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake) -# append_coverage_compiler_flags() -# endif () - -########### Set up the include directories ########################### -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc) -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/device/ascend/platform) - -include_directories(${CMAKE_BINARY_DIR}) # for protobuf generated .h - -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/mindrecord/include) -include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include) -###################################################################### - -####################### Flags ######################################## -# compile flags -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") - -ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) - -################## Include sub-modules ############################### -add_subdirectory(util) -add_subdirectory(core) -add_subdirectory(kernels) -add_subdirectory(engine) -add_subdirectory(api) -add_subdirectory(text) 
-###################################################################### -add_dependencies(utils core) -add_dependencies(kernels-image core) -add_dependencies(kernels-data core) -add_dependencies(kernels core) -add_dependencies(engine-datasetops-source core) -add_dependencies(engine-datasetops-source-sampler core) -add_dependencies(engine-datasetops core) -add_dependencies(engine-opt core) -add_dependencies(engine-perf core) -add_dependencies(engine-gnn core) -add_dependencies(engine core) -add_dependencies(text core) -add_dependencies(text-kernels core) -add_dependencies(cpp-API core) -if (ENABLE_PYTHON) - add_dependencies(APItoPython core) -endif() -if (ENABLE_TDTQUE) - add_dependencies(engine-tdt core) -endif () -################### Create _c_dataengine Library ###################### -set(submodules - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - $ - ) - -if (ENABLE_PYTHON) - set(submodules - ${submodules} - $) -endif() - -if (ENABLE_TDTQUE) - add_library(_c_dataengine SHARED ${submodules} $) -else () - add_library(_c_dataengine SHARED ${submodules}) -endif () - -add_dependencies(_c_dataengine generated_engine_files) - -set_target_properties(_c_dataengine PROPERTIES - PREFIX "${PYTHON_MODULE_PREFIX}" - SUFFIX "${PYTHON_MODULE_EXTENSION}" - ) - -###################################################################### - -################# Link with external libraries ######################## -target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) -else() - set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) - target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) -endif() -target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs - mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB}) -if (ENABLE_GPUQUE) - target_link_libraries(_c_dataengine PRIVATE gpu_queue - ${CUDNN_PATH}/lib64/libcudnn.so - ${CUDA_PATH}/lib64/libcudart.so - ${CUDA_PATH}/lib64/stubs/libcuda.so) -endif () - -if (ENABLE_TDTQUE) - target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT}) -endif () - -add_dependencies(_c_dataengine _c_mindrecord) -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - set(MINDRECORD_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/mindrecord/CMakeFiles/_c_mindrecord.dir/objects.a) - target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) -else() - target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) -endif() - -if (USE_GLOG) - target_link_libraries(_c_dataengine PRIVATE mindspore::glog) -else() - if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_options(_c_dataengine PRIVATE -Wl,-init,mindspore_log_init) - elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") - set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) - endif () -endif() diff --git a/mindspore/ccsrc/dataset/api/datasets.cc b/mindspore/ccsrc/dataset/api/datasets.cc deleted file mode 100644 index 5684e6770a..0000000000 --- a/mindspore/ccsrc/dataset/api/datasets.cc +++ /dev/null @@ -1,446 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include "dataset/include/datasets.h" -#include "dataset/include/transforms.h" -#include "dataset/include/samplers.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/project_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" - -#include "dataset/core/config_manager.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace api { - -#define RETURN_NULL_IF_ERROR(_s) \ - do { \ - Status __rc = (_s); \ - if (__rc.IsError()) { \ - return nullptr; \ - } \ - } while (false) - -// Function to create the iterator, which will build and launch the execution tree. -std::shared_ptr Dataset::CreateIterator() { - std::shared_ptr iter; - try { - iter = std::make_shared(); - Status rc = iter->BuildAndLaunchTree(shared_from_this()); - if (rc.IsError()) { - MS_LOG(ERROR) << "CreateIterator failed."; - return nullptr; - } - - return iter; - } catch (const std::exception &err) { - MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what(); - return nullptr; - } - - return iter; -} - -// Constructor -Dataset::Dataset() { - // Fetch some default value from config manager - std::shared_ptr cfg = GlobalContext::config_manager(); - num_workers_ = cfg->num_parallel_workers(); - rows_per_buffer_ = cfg->rows_per_buffer(); - connector_que_size_ = cfg->op_connector_size(); -} - -// Function to create a ImageFolderDataset. -std::shared_ptr ImageFolder(std::string dataset_dir, bool decode, - std::shared_ptr sampler, std::set extensions, - std::map class_indexing) { - // This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false. - bool recursive = false; - - // Create logical representation of ImageFolderDataset. - auto ds = std::make_shared(dataset_dir, decode, sampler, recursive, extensions, class_indexing); - - // Call derived class validation method. - return ds->ValidateParams() ? ds : nullptr; -} - -// Function to create a MnistDataset. -std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler) { - auto ds = std::make_shared(dataset_dir, sampler); - - // Call derived class validation method. - return ds->ValidateParams() ? ds : nullptr; -} - -// Function to create a Cifar10Dataset. -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler) { - auto ds = std::make_shared(dataset_dir, num_samples, sampler); - - // Call derived class validation method. - return ds->ValidateParams() ? 
ds : nullptr; -} - -// Function to create a Batch dataset -std::shared_ptr Dataset::Batch(int32_t batch_size, bool drop_remainder) { - // Default values - std::vector cols_to_map = {}; - std::map>> pad_map; - bool pad = false; - auto ds = std::make_shared(batch_size, drop_remainder, pad, cols_to_map, pad_map); - - if (!ds->ValidateParams()) { - return nullptr; - } - - ds->children.push_back(shared_from_this()); - - return ds; -} - -// Function to create Repeat dataset. -std::shared_ptr Dataset::Repeat(int32_t count) { - // Workaround for repeat == 1, do not inject repeat. - if (count == 1) { - return shared_from_this(); - } - - auto ds = std::make_shared(count); - - if (!ds->ValidateParams()) { - return nullptr; - } - - ds->children.push_back(shared_from_this()); - - return ds; -} - -// Function to create a Map dataset. -std::shared_ptr Dataset::Map(std::vector> operations, - std::vector input_columns, - std::vector output_columns, - const std::vector &project_columns) { - auto ds = std::make_shared(operations, input_columns, output_columns, project_columns); - - if (!ds->ValidateParams()) { - return nullptr; - } - - ds->children.push_back(shared_from_this()); - - return ds; -} - -// Function to create a ShuffleOp -std::shared_ptr Dataset::Shuffle(int32_t shuffle_size) { - // Pass in reshuffle_each_epoch with true - auto ds = std::make_shared(shuffle_size, true); - - if (!ds->ValidateParams()) { - return nullptr; - } - - ds->children.push_back(shared_from_this()); - - return ds; -} - -// Function to create a ProjectDataset. -std::shared_ptr Dataset::Project(const std::vector &columns) { - auto ds = std::make_shared(columns); - // Call derived class validation method. - if (!ds->ValidateParams()) { - return nullptr; - } - - ds->children.push_back(shared_from_this()); - - return ds; -} - -// Helper function to create default RandomSampler. -std::shared_ptr CreateDefaultSampler() { - int32_t num_samples = 0; // 0 means to sample all ids. - bool replacement = false; - return std::make_shared(replacement, num_samples); -} - -/* ####################################### Derived Dataset classes ################################# */ - -ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, - bool recursive, std::set extensions, - std::map class_indexing) - : dataset_dir_(dataset_dir), - decode_(decode), - sampler_(sampler), - recursive_(recursive), - class_indexing_(class_indexing), - exts_(extensions) {} - -bool ImageFolderDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - - return true; -} - -std::shared_ptr>> ImageFolderDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - // If user does not specify Sampler, create a default sampler, i.e., RandomSampler. - if (sampler_ == nullptr) { - sampler_ = CreateDefaultSampler(); - } - - // Do internal Schema generation. - // This arg is exist in ImageFolderOp, but not externalized (in Python API). 
- std::unique_ptr schema = std::make_unique(); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( - schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_NULL_IF_ERROR( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); - node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - recursive_, decode_, exts_, class_indexing_, std::move(schema), - std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); -} - -MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), sampler_(sampler) {} - -bool MnistDataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - - return true; -} - -std::shared_ptr>> MnistDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - // If user does not specify Sampler, create a default sampler, i.e., RandomSampler. - if (sampler_ == nullptr) { - sampler_ = CreateDefaultSampler(); - } - - // Do internal Schema generation. - auto schema = std::make_unique(); - RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - - node_ops.push_back(std::make_shared(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, - std::move(schema), std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); -} - -BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, - std::map>> pad_map) - : batch_size_(batch_size), - drop_remainder_(drop_remainder), - pad_(pad), - cols_to_map_(cols_to_map), - pad_map_(pad_map) {} - -std::shared_ptr>> BatchDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - -#ifdef ENABLE_PYTHON - py::function noop; - node_ops.push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, - cols_to_map_, noop, noop, pad_map_)); -#else - node_ops.push_back(std::make_shared(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_, - cols_to_map_, pad_map_)); -#endif - return std::make_shared>>(node_ops); -} - -bool BatchDataset::ValidateParams() { - if (batch_size_ <= 0) { - return false; - } - - return true; -} - -RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {} - -std::shared_ptr>> RepeatDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - node_ops.push_back(std::make_shared(repeat_count_)); - return std::make_shared>>(node_ops); -} - -bool RepeatDataset::ValidateParams() { - if (repeat_count_ <= 0) { - return false; - } - - return true; -} -MapDataset::MapDataset(std::vector> operations, std::vector input_columns, - std::vector output_columns, const std::vector &project_columns) - : operations_(operations), - input_columns_(input_columns), - output_columns_(output_columns), - project_columns_(project_columns) {} - -std::shared_ptr>> MapDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will 
create - std::vector> node_ops; - - // Currently default is true, and this is not exposed to user. - bool perf_mode = true; - - std::vector> tensor_ops; - - // Build tensorOp from tensorOperation vector - // This is to ensure each iterator hold its own copy of the tensorOp objects. - (void)std::transform( - operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), - [](std::shared_ptr operation) -> std::shared_ptr { return operation->Build(); }); - - // This parameter will be removed with next rebase - std::vector col_orders; - auto map_op = - std::make_shared(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_, perf_mode); - if (!project_columns_.empty()) { - auto project_op = std::make_shared(project_columns_); - node_ops.push_back(project_op); - } - - node_ops.push_back(map_op); - return std::make_shared>>(node_ops); -} - -bool MapDataset::ValidateParams() { - if (operations_.empty()) { - return false; - } - - return true; -} - -// Constructor for ShuffleDataset -ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch) - : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {} - -// Function to build the ShuffleOp -std::shared_ptr>> ShuffleDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - node_ops.push_back(std::make_shared(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_, - rows_per_buffer_)); - return std::make_shared>>(node_ops); -} - -// Function to validate the parameters for ShuffleDataset -bool ShuffleDataset::ValidateParams() { - if (shuffle_size_ <= 1) { - MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_; - return false; - } - - return true; -} - -// Constructor for Cifar10Dataset -Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler) - : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {} - -bool Cifar10Dataset::ValidateParams() { - if (dataset_dir_.empty()) { - MS_LOG(ERROR) << "No dataset path is specified."; - return false; - } - if (num_samples_ < 0) { - MS_LOG(ERROR) << "Number of samples cannot be negative"; - return false; - } - return true; -} - -// Function to build CifarOp -std::shared_ptr>> Cifar10Dataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - // If user does not specify Sampler, create a default sampler based on the shuffle variable. - if (sampler_ == nullptr) { - sampler_ = CreateDefaultSampler(); - } - - // Do internal Schema generation. 
- auto schema = std::make_unique(); - RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_NULL_IF_ERROR( - schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - - node_ops.push_back(std::make_shared(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_, - dataset_dir_, connector_que_size_, std::move(schema), - std::move(sampler_->Build()))); - return std::make_shared>>(node_ops); -} - -// Function to build ProjectOp -ProjectDataset::ProjectDataset(const std::vector &columns) : columns_(columns) {} - -bool ProjectDataset::ValidateParams() { - if (columns_.empty()) { - MS_LOG(ERROR) << "No columns are specified."; - return false; - } - return true; -} - -std::shared_ptr>> ProjectDataset::Build() { - // A vector containing shared pointer to the Dataset Ops that this object will create - std::vector> node_ops; - - node_ops.push_back(std::make_shared(columns_)); - return std::make_shared>>(node_ops); -} - -} // namespace api -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc deleted file mode 100644 index 6d4a60cdc5..0000000000 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ /dev/null @@ -1,1605 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/api/de_pipeline.h" - -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" -#include "dataset/engine/datasetops/cache_op.h" -#include "dataset/engine/datasetops/filter_op.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/kernels/py_func_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); - -static std::unordered_map g_parse_op_func_ = { - {kShuffle, &DEPipeline::ParseShuffleOp}, - {kMindrecord, &DEPipeline::ParseMindRecordOp}, - {kMap, &DEPipeline::ParseMapOp}, - {kFilter, &DEPipeline::ParseFilterOp}, - {kBatch, &DEPipeline::ParseBatchOp}, - {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, - {kBarrier, &DEPipeline::ParseBarrierOp}, - {kRepeat, &DEPipeline::ParseRepeatOp}, - {kSkip, &DEPipeline::ParseSkipOp}, - {kZip, &DEPipeline::ParseZipOp}, - {kConcat, &DEPipeline::ParseConcatOp}, - {kRename, &DEPipeline::ParseRenameOp}, - {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, - {kGenerator, &DEPipeline::ParseGeneratorOp}, - {kTfReader, &DEPipeline::ParseTFReaderOp}, - {kProject, &DEPipeline::ParseProjectOp}, - {kTake, &DEPipeline::ParseTakeOp}, - {kImageFolder, &DEPipeline::ParseImageFolderOp}, - {kMnist, &DEPipeline::ParseMnistOp}, - {kManifest, &DEPipeline::ParseManifestOp}, - {kVoc, &DEPipeline::ParseVOCOp}, - {kCoco, &DEPipeline::ParseCocoOp}, - {kCifar10, &DEPipeline::ParseCifar10Op}, - {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}, - {kRandomData, &DEPipeline::ParseRandomDataOp}, - {kTextFile, &DEPipeline::ParseTextFileOp}, - {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, - {kClue, &DEPipeline::ParseClueOp}}; - -DEPipeline::DEPipeline() : iterator_(nullptr) { - try { - // One time init - (void)GlobalInit(); - - // Instantiate the execution tree - tree_ = std::make_shared(); - repeat_num_ = 1; - batch_size_ = 1; - num_rows_ = 0; - num_classes_ = 0; - temp_batch_size_ = 1; - temp_drop_remainder_ = false; - } catch (const std::exception &err) { - MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; - return; - } -} - -DEPipeline::~DEPipeline() { - { - // Release GIL before joining all threads - py::gil_scoped_release gil_release; - // Release tree - tree_.reset(); - } -} - -// Function to add a Node to the Execution Tree. 
-Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) { - // For each operator, Parse through the list of arguments, then call the respective builder/constructor. - // Note that each call to the parse function may result in building more than one dataset operator. - // For example, one call to ParseNNNOp may result in multiple internal C nodes: - // nodeA - // | - // nodeB - // | - // nodeC - // However, the python side dataset is more abstract, and it does not know about the potential subtree that - // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the - // python side needs to know about nodeA and NodeC to be able to appropriately hook up parents and child - // to this subtee. - // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse - // function. - DsOpPtr top = nullptr; - DsOpPtr bottom = nullptr; - auto iter = g_parse_op_func_.find(op_name); - if (iter != g_parse_op_func_.end()) { - pFunction func = iter->second; - RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom)); - - if (top == nullptr) { - RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node."); - } - - // It is not required that the parse function always produces the bottom pointer. If it's still null, - // then set top and bottom to be the same operator - if (bottom == nullptr) bottom = top; - - // Pack these pointers into a py dict so that we can return both back to python. - (*output)["top"] = top; - (*output)["bottom"] = bottom; - } else { - RETURN_STATUS_UNEXPECTED("No such Op"); - } - // Associate current dataset op node with the tree. - RETURN_IF_NOT_OK(tree_->AssociateNode(top)); - return Status::OK(); -} -// Function to add a child and parent relationship. -Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) { - // Link this relationship. - // Note parent node takes ownership of the child - return (parent_op->AddChild(child_op)); -} - -// Function to assign the node as root. -Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); } - -// Function to launch the tree execution. -Status DEPipeline::LaunchTreeExec() { - RETURN_IF_NOT_OK(tree_->Prepare()); - RETURN_IF_NOT_OK(tree_->Launch()); - iterator_ = std::make_unique(tree_); - if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator."); - return Status::OK(); -} - -void DEPipeline::PrintTree() { - for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) { - std::stringstream ss; - ss << *itr; - MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". 
Details: " << ss.str().c_str() << "."; - } -} - -Status DEPipeline::GetNextAsMap(py::dict *output) { - TensorMap row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetNextAsMap(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python dict as return - for (auto el : row) { - (*output)[common::SafeCStr(el.first)] = el.second; - } - return Status::OK(); -} - -Status DEPipeline::GetNextAsList(py::list *output) { - TensorRow row; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->FetchNextTensorRow(&row); - } - RETURN_IF_NOT_OK(s); - // Generate Python list as return - for (auto el : row) { - output->append(el); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputShapes(py::list *output) { - std::vector shapes; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputShapes(&shapes); - } - RETURN_IF_NOT_OK(s); - for (auto el : shapes) { - py::list shape; - for (auto dim : el.AsVector()) { - shape.append(dim); - } - output->append(shape); - } - return Status::OK(); -} - -Status DEPipeline::GetOutputTypes(py::list *output) { - std::vector types; - Status s; - { - py::gil_scoped_release gil_release; - s = iterator_->GetOutputTypes(&types); - } - RETURN_IF_NOT_OK(s); - for (auto el : types) { - output->append(el.AsNumpyType()); - } - return Status::OK(); -} - -int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } - -int DEPipeline::GetBatchSize() const { return batch_size_; } - -int DEPipeline::GetRepeatCount() const { return repeat_num_; } - -float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } - -std::vector ToStringVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) - vector.push_back(py::str(l)); - else - vector.emplace_back(""); - } - return vector; -} - -std::set ToStringSet(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::set set; - for (auto l : list) { - if (!l.is_none()) { - (void)set.insert(py::str(l)); - } - } - return set; -} - -std::map ToStringMap(const py::handle handle) { - py::dict dict = py::reinterpret_borrow(handle); - std::map map; - for (auto p : dict) { - (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); - } - return map; -} - -std::vector ToIntVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (!l.is_none()) { - vector.push_back(ToInt(l)); - } - } - return vector; -} - -std::vector ToTypeVector(const py::handle handle) { - py::list list = py::reinterpret_borrow(handle); - std::vector vector; - for (auto l : list) { - if (l.is_none()) { - vector.emplace_back(DataType()); - } else { - vector.push_back(l.cast()); - } - } - return vector; -} - -Status DEPipeline::SetBatchParameters(const py::dict &args) { - if (args["batch_size"].is_none()) { - std::string err_msg = "Error: batchSize is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - temp_batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = 
arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - temp_drop_remainder_ = ToBool(value); - } - } - } - - return Status::OK(); -} - -Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - if (!args["buffer_size"].is_none()) { - (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); - } else { - std::string err_msg = "Error: Shuffle buffer size is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "reshuffle_each_epoch") { - (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded) { - auto sampler = py::reinterpret_borrow(handle); - auto create = sampler.attr("create_for_minddataset"); - auto op = create().cast>(); - std::stack> stack_ops; - while (op != nullptr) { - auto sampler_op = std::dynamic_pointer_cast(op); - if (sampler_op && num_padded > 0) { - sampler_op->SetNumPaddedSamples(num_padded); - stack_ops.push(sampler_op); - } else { - stack_ops.push(op); - } - op = op->GetChildOp(); - } - while (!stack_ops.empty()) { - operators->push_back(stack_ops.top()); - stack_ops.pop(); - } - return Status::OK(); -} - -Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: at least one of dataset_files is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - bool load_dataset = ToBool(args["load_dataset"]); - if (load_dataset == true) { - (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); - } else { - (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); - } - (void)builder->SetLoadDataset(load_dataset); - std::vector in_col_names; - if (!args["columns_list"].is_none()) { - in_col_names = ToStringVector(args["columns_list"]); - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: columns_list is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetColumnsToLoad(in_col_names); - } - - if (!args["padded_sample"].is_none()) { - (void)builder->SetPaddedSample(args["padded_sample"]); - (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); - } - std::vector> operators; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumMindRecordWorkers(ToInt(value)); - } else if (key == "block_reader" && ToBool(value) == true) { - (void)builder->SetBlockReader(); - } else if (key == "sampler") { - int num_padded = 0; - if (!args["num_padded"].is_none()) { - num_padded = ToInt(args["num_padded"]); - } - RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); - } - } - } - - if (!operators.empty()) { - (void)builder->SetOperators(operators); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - num_rows_ = op->num_rows(); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, - 
std::shared_ptr *bottom) { - MapOp::Builder map_builder; - std::vector> tensor_op_list; - std::vector project_columns; - std::shared_ptr cache_client = nullptr; - int num_workers = 0; - - if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)map_builder.SetInColNames(in_col_names); - } else if (key == "output_columns") { - (void)map_builder.SetOutColNames(ToStringVector(value)); - } else if (key == "columns_order") { - project_columns = ToStringVector(value); - } else if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)map_builder.SetNumWorkers(num_workers); - } else if (key == "prefetch_size") { - (void)map_builder.SetOpConnectorSize(ToInt(value)); - } else if (key == "operations") { - py::handle tensor_ops = args["operations"]; - // operation can be a list of TensorOps or a single TensorOp. - if (py::isinstance(tensor_ops)) { - for (auto op : tensor_ops) { - std::shared_ptr tensor_op; - if (py::isinstance(op)) { - tensor_op = op.cast>(); - } else if (py::isinstance(op)) { - tensor_op = std::make_shared(op.cast()); - } else { - RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); - } - tensor_op_list.push_back(tensor_op); - } - } - if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); - (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); - } else if (key == "cache") { - cache_client = value.cast>(); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr map_op; - RETURN_IF_NOT_OK(map_builder.Build(&map_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); - *top = map_op; - - // Add a project op over top of the map if the user wanted to reposition the columns - if (!project_columns.empty()) { - ProjectOp::Builder proj_builder(project_columns); - std::shared_ptr proj_op; - RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); - RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); - *top = proj_op; - *bottom = map_op; - } - - // Additionally, add a cache if required. This will go over top of the project op if one - // was created, otherwise it goes over top of the map op - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); - *top = cache_op; - *bottom = map_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - - if (args["predicate"].is_none()) { - RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. 
\n"); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "predicate") { - py::handle op = args["predicate"]; - if (!py::isinstance(op)) { - RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); - } - py::function predicate_func = op.cast(); - (void)builder->SetPredicateFunc(std::move(predicate_func)); - } else if (key == "input_columns") { - std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)builder->SetInColNames(in_col_names); - } else { - RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - repeat_num_ = ToInt(args["count"]); - std::shared_ptr op; - RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "source") { - py::object obj = py::cast(&value); - if (!py::isinstance(obj)) { - std::string err_msg = "Error: generator is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetGeneratorFunction(obj.cast()); - } else if (key == "column_names") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "column_types") { - (void)builder->SetColumnTypes(ToTypeVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder; - if (py::isinstance(args["batch_size"])) { - batch_size_ = ToInt(args["batch_size"]); - CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); - builder = std::make_shared(ToInt(args["batch_size"])); - } else if (py::isinstance(args["batch_size"])) { - builder = std::make_shared(1); - (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); - } else { - std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "drop_remainder") { - (void)builder->SetDrop(ToBool(value)); - } - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } - if (key == "per_batch_map") { - (void)builder->SetBatchMapFunc(value.cast()); - } - if (key == "input_columns") { - 
(void)builder->SetColumnsToMap(ToStringVector(value)); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPaddingMap(pad_info, true); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", - "bucket_batch_sizes"}; - for (auto name : mandatory_arguments) { - if (args[name.c_str()].is_none()) { - std::string err_msg = "Error: " + name + " is not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - - std::shared_ptr builder = std::make_shared( - ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), - ToIntVector(args[mandatory_arguments[2].c_str()])); - - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "length_dependent_columns") { - (void)builder->SetLengthDependentColumns(ToStringVector(value)); - } - if (key == "bucket_boundaries") { - (void)builder->SetBucketBoundaries(ToIntVector(value)); - } - if (key == "bucket_batch_sizes") { - (void)builder->SetBucketBatchSizes(ToIntVector(value)); - } - if (key == "element_length_function") { - (void)builder->SetElementLengthFunction(value.cast()); - } - if (key == "pad_info") { - PadInfo pad_info; - RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); - (void)builder->SetPadInfo(pad_info); - } - if (key == "pad_to_bucket_boundary") { - (void)builder->SetPadToBucketBoundary(ToBool(value)); - } - if (key == "drop_remainder") { - (void)builder->SetDropRemainder(ToBool(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - // Right now barrier should only take num_rows_per_buffer = 1 - // The reason for this is because having it otherwise can lead to blocking issues - // See barrier_op.h for more details - (void)builder->SetRowsPerBuffer(1); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "condition_name") { - (void)builder->SetConditionName(ToString(value)); - } else if (key == "condition_func") { - (void)builder->SetConditionFunc(value.cast()); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - int32_t prefetch_size = 0; - if (args.contains("prefetch_size")) { - if (args["prefetch_size"].is_none()) { - prefetch_size = 16; - } else { - prefetch_size = ToInt(args["prefetch_size"]); - } - } - std::shared_ptr builder = std::make_shared(prefetch_size); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "queue_name") { - (void)builder->SetChannelName(ToString(value)); - } else if (key == "device_type") { - (void)builder->SetDeviceType(ToString(value)); - } else if (key == "device_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "num_batch") { - (void)builder->SetNumBatch(ToInt(value)); - } 
- } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector in_col_names; - std::vector out_col_names; - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "input_columns") { - in_col_names = ToStringVector(value); - } else if (key == "output_columns") { - out_col_names = ToStringVector(value); - } - } - } - if (in_col_names.empty() || in_col_names[0].empty()) { - std::string err_msg = "Error: input_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (out_col_names.empty() || out_col_names[0].empty()) { - std::string err_msg = "Error: output_column_names is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - (void)builder->SetInColNames(in_col_names); - (void)builder->SetOutColNames(out_col_names); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["count"].is_none()) { - std::string err_msg = "Error: count is invalid or not set."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr op; - RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; - int num_workers = 0; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetDatasetFilesList(files_list); - } else { - std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - bool shuffle_required = false; - int64_t num_devices = 0; - int64_t total_rows = 0; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder->SetNumWorkers(num_workers); - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - (void)builder->SetColumnsToLoad(columns_to_load); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "num_samples") { - total_rows = ToInt(value); - 
-        (void)builder->setTotalRows(total_rows);
-      } else if (key == "num_shards") {
-        num_devices = ToInt(value);
-        (void)builder->SetNumDevices(num_devices);
-      } else if (key == "shard_id") {
-        (void)builder->SetDeviceId(ToInt(value));
-      } else if (key == "shard_equal_rows") {
-        (void)builder->SetShardEqualRows(ToBool(value));
-      } else if (key == "cache") {
-        cache_client = value.cast<std::shared_ptr<CacheClient>>();
-      } else if (key == "sampler") {
-        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
-      }
-    }
-  }
-  if (schema_exists) {
-    std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
-    if (args.contains("schema_file_path")) {
-      RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load));
-    } else {
-      RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load));
-    }
-    (void)builder->SetDataSchema(std::move(schema));
-  }
-
-  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
-  // because TFReaderOp is a non-mappable dataset that does not support sampling.
-  // However, if a cache operator is injected at some other place higher in the tree, that cache can
-  // inherit this sampler from the leaf, providing sampling support from the caching layer.
-  // That is why we save the sampler here in a leaf node that does not use sampling.
-  if (sampler) {
-    (void)builder->SetSampler(std::move(sampler));
-  } else if (cache_client) {
-    int64_t num_samples = 0;
-    int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
-    (void)builder->SetSampler(std::move(sampler));
-  }
-
-  std::shared_ptr<TFReaderOp> tf_op;
-  RETURN_IF_NOT_OK(builder->Build(&tf_op));
-  RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
-  *top = tf_op;
-
-  if (!cache_client && shuffle_required) {
-    const bool estimate = true;
-    const int64_t workers = 8;
-    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
-    int64_t shuffle_size = 0;
-    int64_t num_rows = 0;
-
-    // First, get the number of rows in the dataset via estimate and then compute the shuffle size
-    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate));
-    RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size));
-
-    // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
-    RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op));
-    *top = shuffle_op;
-    *bottom = tf_op;
-  }
-
-  // Add a cache op over this op if required and update the output subtree (top/bottom)
-  if (cache_client) {
-    // Note, it is not allowed to have both shuffle and cache
-    std::shared_ptr<DatasetOp> cache_op = nullptr;
-    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
-    *top = cache_op;
-    *bottom = tf_op;
-  }
-
-  return Status::OK();
-}
-
-Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
-                                  std::shared_ptr<DatasetOp> *bottom) {
-  if (args["columns"].is_none()) {
-    std::string err_msg = "Error: columns is missing";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  std::vector<std::string> columns_to_project = ToStringVector(args["columns"]);
-  std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(columns_to_project);
-  std::shared_ptr<ProjectOp> op;
-  RETURN_IF_NOT_OK(builder->Build(&op));
-  *top = op;
-  return Status::OK();
-}
-
-Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
-                                      std::shared_ptr<DatasetOp> *bottom) {
-  // Required arguments
-  if (args["dataset_dir"].is_none()) {
-    std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg); - } - int num_workers = 0; - std::shared_ptr cache_client = nullptr; - std::shared_ptr builder = std::make_shared(); - (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder->SetNumWorkers(num_workers); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "cache") { - cache_client = value.cast>(); - } - } - } - std::shared_ptr if_op; - RETURN_IF_NOT_OK(builder->Build(&if_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); - *top = if_op; - - // Additionally, add a cache if required. - // Note that this cache op is only acting as a place holder for the caching position - // within the tree. Later, a pre-pass will execute a tree transform to set up the actual - // caching logic in the tree. - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); - *top = cache_op; - *bottom = if_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_file"].is_none()) { - std::string err_msg = "Error: No dataset files specified for manifest"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::shared_ptr builder = std::make_shared(); - (void)builder->SetManifestFile(ToString(args["dataset_file"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "usage") { - (void)builder->SetUsage(ToString(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["mode"].is_none()) { - std::string err_msg = "Error: No mode specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetTask(ToString(args["task"])); - (void)builder->SetMode(ToString(args["mode"])); - for (auto arg : args) { - std::string key = 
py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "class_indexing") { - (void)builder->SetClassIndex(ToStringMap(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - - return Status::OK(); -} - -Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["annotation_file"].is_none()) { - std::string err_msg = "Error: No annotation_file specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - if (args["task"].is_none()) { - std::string err_msg = "Error: No task specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - (void)builder->SetFile(ToString(args["annotation_file"])); - (void)builder->SetTask(ToString(args["task"])); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(true); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetCifarDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); 
- } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - - (void)builder->SetCifarType(false); - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - RandomDataOp::Builder builder; - std::shared_ptr cache_client = nullptr; - std::shared_ptr sampler = nullptr; - int num_workers = 0; - - if (args["total_rows"].is_none()) { - std::string err_msg = "Error: total_rows is a required argument"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::vector columns_to_load; - bool schema_exists = false; - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - num_workers = ToInt(value); - (void)builder.SetNumWorkers(num_workers); - } else if (key == "schema_file_path" || key == "schema_json_string") { - schema_exists = true; - } else if (key == "columns_list") { - columns_to_load = ToStringVector(value); - } else if (key == "total_rows") { - // This is not sampling here. The random data op needs to know how much data to generate. - (void)builder.SetTotalRows(ToInt(value)); - } else if (key == "cache") { - cache_client = value.cast>(); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - sampler = create().cast>(); - } - } - } - if (schema_exists) { - std::unique_ptr schema = std::make_unique(); - if (args.contains("schema_file_path")) { - RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); - } else { - RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); - } - (void)builder.SetDataSchema(std::move(schema)); - } - - // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed - // because RandomDataOp is a non-mappable dataset that does not support sampling. - // However, if a cache operator is injected at some other place higher in the tree, that cache can - // inherit this sampler from the leaf, providing sampling support from the caching layer. - // That is why we save the sampler here in a leaf node that does not use sampling. 
- if (sampler) { - (void)builder.SetSampler(std::move(sampler)); - } else if (cache_client) { - int64_t num_samples = 0; - int64_t start_index = 0; - sampler = std::make_shared(num_samples, start_index); - (void)builder.SetSampler(std::move(sampler)); - } - - std::shared_ptr random_op = nullptr; - RETURN_IF_NOT_OK(builder.Build(&random_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); - *top = random_op; - - // Add a cache op over this op if required and update the output subtree (top/bottom) - if (cache_client) { - std::shared_ptr cache_op = nullptr; - RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); - *top = cache_op; - *bottom = random_op; - } - - return Status::OK(); -} - -int32_t DEPipeline::GetNumClasses() const { return num_classes_; } - -Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - - std::shared_ptr builder = std::make_shared(); - (void)builder->SetDir(ToString(args["dataset_dir"])); - - // Optional arguments - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - if (args["dataset_dir"].is_none()) { - std::string err_msg = "Error: No dataset path specified"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - - std::shared_ptr builder = std::make_shared(); - if (builder == nullptr) { - std::string err_msg = "Create celebaop builder failed"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - (void)builder->SetCelebADir(ToString(args["dataset_dir"])); - for (const auto &arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "sampler") { - auto create = py::reinterpret_borrow(value).attr("create"); - std::shared_ptr sampler = create().cast>(); - (void)builder->SetSampler(std::move(sampler)); - } else if (key == "decode") { - (void)builder->SetDecode(ToBool(value)); - } else if (key == "extensions") { - (void)builder->SetExtensions(ToStringSet(value)); - } else if (key == "dataset_type") { - (void)builder->SetDatasetType(ToString(value)); - } - } - } - - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - // Required arguments - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetTextFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = 
false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetTotalRows(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } - } - } - - std::shared_ptr txt_op; - RETURN_IF_NOT_OK(builder->Build(&txt_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); - *top = txt_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); - RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); - - // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller - RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); - *top = shuffle_op; - *bottom = txt_op; - } - - return Status::OK(); -} - -Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - auto tp = py::reinterpret_borrow(p.second); - CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); - TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); - std::shared_ptr pad_val = nullptr; - if (py::isinstance(tp[1])) { - std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED( - Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), - "Cannot create pad_value Tensor"); - } else { - float pad_val_float = tp[1].is_none() ? 
0 : ToFloat(tp[1]); - CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), - DataType(DataType::DE_FLOAT32)), - "Cannot create pad_value Tensor"); - pad_val->SetItemAt({}, pad_val_float); - } - (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); - } else { // tuple is None - (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); - } - } - return Status::OK(); -} - -Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::shared_ptr builder = std::make_shared(); - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "freq_range") { - py::tuple tp = py::reinterpret_borrow(value); - if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); - if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); - } else if (key == "top_k") { - builder->SetTopK(py::reinterpret_borrow(value)); - } else if (key == "columns") { - (void)builder->SetColumnNames(ToStringVector(value)); - } else if (key == "vocab") { - (void)builder->SetVocab(value.cast>()); - } else if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "special_first") { - (void)builder->SetSpecialFirst(ToBool(value)); - } else if (key == "special_tokens") { - (void)builder->SetSpecialTokens(ToStringVector(value)); - } - } - } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *top = op; - return Status::OK(); -} - -Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom) { - std::vector files_list; - std::shared_ptr builder = std::make_shared(); - if (!args["dataset_files"].is_none()) { - files_list = ToStringVector(args["dataset_files"]); - (void)builder->SetClueFilesList(files_list); - } else { - RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); - } - // Optional arguments - bool shuffle_required = false; - int64_t num_devices = 0; - for (auto arg : args) { - std::string key = py::str(arg.first); - py::handle value = arg.second; - if (!value.is_none()) { - if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); - } else if (key == "shuffle_files") { - (void)builder->SetShuffleFiles(ToBool(value)); - } else if (key == "shuffle_global") { - shuffle_required = ToBool(value); - } else if (key == "num_samples") { - (void)builder->SetNumSamples(ToInt(value)); - } else if (key == "num_shards") { - num_devices = ToInt(value); - (void)builder->SetNumDevices(num_devices); - } else if (key == "shard_id") { - (void)builder->SetDeviceId(ToInt(value)); - } else if (key == "cols_to_keyword") { - std::map map_dict; - for (auto p : py::reinterpret_borrow(value)) { - if (!p.second.is_none()) { - map_dict.insert({ToString(p.first), ToString(p.second)}); - } else { - map_dict.insert({ToString(p.first), ToString(p.first)}); - } - } - (void)builder->SetColsKeyMap(map_dict); - } - } - } - - std::shared_ptr clue_op; - RETURN_IF_NOT_OK(builder->Build(&clue_op)); - RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); - *top = clue_op; - - if (shuffle_required) { - std::shared_ptr shuffle_op = nullptr; - int64_t shuffle_size = 0; - int64_t num_rows = 0; - - // First, get the number of rows in the dataset and then compute the shuffle size - RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); - 
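Before that row count feeds into the ComputeShuffleSize call that follows, the heuristic (defined near the end of this file) is worth restating in standalone form: rows are divided across shards rounding up, optionally capped by a user row limit, and the buffer is four times the average rows per file with a floor of 10000. The constants come from the code itself; the function name here is illustrative:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Standalone restatement of DEPipeline::ComputeShuffleSize for checking
// expected buffer sizes (not the member function itself).
int64_t ShuffleSizeHeuristic(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows) {
  const int64_t average_files_multiplier = 4;
  const int64_t shuffle_max = 10000;
  // Rows per shard, rounded up, if sharding was requested.
  if (num_devices > 0) num_rows = (num_rows + num_devices - 1) / num_devices;
  // Cap based on a total-rows directive; ops without one pass 0.
  if (total_rows > 0) num_rows = std::min(num_rows, total_rows);
  return std::max((num_rows / num_files) * average_files_multiplier, shuffle_max);
}

int main() {
  // 16 files, 4 shards, 1,000,000 rows: 250,000 rows per shard,
  // 15,625 rows per file on average -> shuffle buffer of 62,500.
  std::cout << ShuffleSizeHeuristic(16, 4, 1000000, 0) << '\n';  // 62500
}
```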
-    RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size));
-
-    // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
-    RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op));
-    *top = shuffle_op;
-    *bottom = clue_op;
-  }
-
-  return Status::OK();
-}
-
-// Helper function to inject the cache operator over top of the current operation being built.
-Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
-                              std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
-  std::shared_ptr<CacheOp> new_cache_op = nullptr;
-  CacheOp::Builder cache_builder;
-  // use the same number of workers as the leaf. We need some optimization here, the user does not
-  // give the cache op number of workers directly.
-  if (num_workers != 0) {
-    (void)cache_builder.SetNumWorkers(num_workers);
-  }
-  (void)cache_builder.SetClient(cache_client);
-  RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
-  RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
-  RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));
-  // We have now created:
-  //
-  // CacheOp
-  //   |
-  // input_op
-  //
-  *cache_op = new_cache_op;
-
-  return Status::OK();
-}
-
-// Helper function to inject a shuffle operator over top of the current operation being built.
-Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
-                                std::shared_ptr<DatasetOp> *shuffle_op) {
-  std::shared_ptr<ShuffleOp> new_shuffle_op = nullptr;
-  ShuffleOp::Builder shuffle_builder;
-
-  (void)shuffle_builder.SetShuffleSize(shuffle_size);
-  RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op));
-  RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op));
-  RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op));
-  // We have now created:
-  //
-  // ShuffleOp
-  //   |
-  // input_op
-  //
-  *shuffle_op = new_shuffle_op;
-
-  return Status::OK();
-}
-
-// Common code for computing a default shuffle size
-Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
-                                      int64_t *shuffle_size) {
-  const int64_t average_files_multiplier = 4;
-  const int64_t shuffle_max = 10000;
-  int64_t avg_rows_per_file = 0;
-
-  // Adjust the num rows per shard if sharding was given
-  if (num_devices > 0) {
-    if (num_rows % num_devices == 0) {
-      num_rows = num_rows / num_devices;
-    } else {
-      num_rows = (num_rows / num_devices) + 1;
-    }
-  }
-
-  // Cap based on total rows directive. Some ops do not have this and give value of 0.
-  if (total_rows > 0) {
-    num_rows = std::min(num_rows, total_rows);
-  }
-
-  // get the average per file
-  avg_rows_per_file = num_rows / num_files;
-
-  *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h
deleted file mode 100644
index aac2d686af..0000000000
--- a/mindspore/ccsrc/dataset/api/de_pipeline.h
+++ /dev/null
@@ -1,225 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_API_DE_PIPELINE_H_ -#define DATASET_API_DE_PIPELINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/client.h" // DE client -#include "dataset/engine/dataset_iterator.h" -#include "dataset/util/status.h" -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -using DsOpPtr = std::shared_ptr; - -class CacheClient; - -// enum for the dataset operator names -enum OpName { - kShuffle, - kMindrecord, - kBatch, - kBucketBatch, - kBarrier, - kCache, - kRepeat, - kSkip, - kTake, - kZip, - kConcat, - kMap, - kFilter, - kDeviceQueue, - kGenerator, - kRename, - kTfReader, - kProject, - kImageFolder, - kMnist, - kManifest, - kVoc, - kCoco, - kCifar10, - kCifar100, - kCelebA, - kRandomData, - kTextFile, - kBuildVocab, - kClue -}; - -// The C++ binder class that we expose to the python script. -class DEPipeline { - public: - DEPipeline(); - - ~DEPipeline(); - - // Function to add a Node to the Execution Tree. - Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); - - // Function to add a child and parent relationship. - static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); - - // Function to assign the node as root. - Status AssignRootNode(const DsOpPtr &dataset_op); - - // Function to launch the tree execution. - Status LaunchTreeExec(); - - // Get a row of data as dictionary of column name to the value. - Status GetNextAsMap(py::dict *output); - - // Get a row of data as list. - Status GetNextAsList(py::list *output); - - Status GetOutputShapes(py::list *output); - - Status GetOutputTypes(py::list *output); - - int GetDatasetSize() const; - - int GetBatchSize() const; - - int GetRepeatCount() const; - - Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status BuildMindrecordSamplerChain(const py::handle &handle, - std::vector> *operators, - int num_padded); - - Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, - std::shared_ptr *bottom); - - Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTFReaderOp(const 
py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - void PrintTree(); - - int32_t GetNumClasses() const; - - Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status SetBatchParameters(const py::dict &args); - - Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - - private: - // Execution tree that links the dataset operators. - std::shared_ptr tree_; - - std::unique_ptr iterator_; - - static Status ParsePadInfo(py::handle value, PadInfo *pad_info); - - /// \brief Helper function to inject a cache operator over top of the current operation being built. - /// \param[in] cache_client The client to use for caching - /// \param[in] num_workers The number of workers to use in the cache op - /// \param[in] input_op The operator to build the cache on top of - /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be - /// the cache operator - /// \return Status return code - Status AddCacheOp(std::shared_ptr cache_client, int num_workers, std::shared_ptr input_op, - std::shared_ptr *cache_op); - - /// \brief Helper function to inject a shuffle operator over top of the current operation being built. - /// \param[in] shuffle_size The size to use in the shuffle buffer - /// \param[in] input_op The operator to build shuffle on top of - /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). 
In this case it will be - /// the shuffle operator - /// \return Status return code - Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, - std::shared_ptr *shuffle_op); - - /// \brief Helper function to compute the shuffle size - /// \param[in] num_files The number of files in the dataset - /// \param[in] num_devices The number of devices in the dataset - /// \param[in] num_rows The number of rows in the dataset - /// \param[in] total_rows An upper bound on the total rows in the dataset - /// \param[out] shuffle_size The resultant computed shuffle size - /// \return Status return code - Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, - int64_t *shuffle_size); - - int batch_size_; - int repeat_num_; - int num_rows_; - int num_classes_; - - int temp_batch_size_; - bool temp_drop_remainder_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_API_DE_PIPELINE_H_ diff --git a/mindspore/ccsrc/dataset/api/iterator.cc b/mindspore/ccsrc/dataset/api/iterator.cc deleted file mode 100644 index 3875dcf8aa..0000000000 --- a/mindspore/ccsrc/dataset/api/iterator.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/include/iterator.h" -#include "dataset/core/client.h" -#include "dataset/include/datasets.h" - -namespace mindspore { -namespace dataset { -namespace api { - -// Get the next row from the data pipeline. -void Iterator::GetNextRow(TensorMap *row) { - Status rc = iterator_->GetNextAsMap(row); - if (rc.IsError()) { - MS_LOG(ERROR) << "GetNextRow: Failed to get next row."; - row->clear(); - } -} - -// Shut down the data pipeline. -void Iterator::Stop() { - // Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_. - iterator_.reset(); - - // Release ownership of tree_ shared pointer. This will decrement the ref count. - tree_.reset(); -} - -// Function to build and launch the execution tree. -Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { - // One time init - Status rc; - rc = GlobalInit(); - RETURN_IF_NOT_OK(rc); - - // Instantiate the execution tree - tree_ = std::make_shared(); - - // Iterative BFS converting Dataset tree into runtime Execution tree. - std::queue, std::shared_ptr>> q; - - if (ds != nullptr) { - // Convert the current root node. - auto root_op = ds->Build()->front(); - RETURN_UNEXPECTED_IF_NULL(root_op); - - RETURN_IF_NOT_OK(tree_->AssociateNode(root_op)); - - q.push(std::make_pair(ds, root_op)); - - // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. 
execution tree nodes) - while (!q.empty()) { - auto node_pair = q.front(); - q.pop(); - // Iterate through all the direct children of the first element in our BFS queue - for (auto child : node_pair.first->children) { - auto child_ops = child->Build(); - RETURN_UNEXPECTED_IF_NULL(child_ops); - auto node_op = node_pair.second; - // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them - // with the execution tree and add the child and parent relationship between the nodes - // Note that some Dataset objects might return more than one DatasetOps - // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset - for (auto child_op : *child_ops) { - RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); - RETURN_IF_NOT_OK(node_op->AddChild(child_op)); - node_op = child_op; - } - // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current - // execution tree) to the BFS queue - q.push(std::make_pair(child, child_ops->back())); - } - } - RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); - } - - // Launch the execution tree. - RETURN_IF_NOT_OK(tree_->Prepare()); - RETURN_IF_NOT_OK(tree_->Launch()); - iterator_ = std::make_unique(tree_); - RETURN_UNEXPECTED_IF_NULL(iterator_); - - return rc; -} - -} // namespace api -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc deleted file mode 100644 index 63bd5eccdc..0000000000 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ /dev/null @@ -1,954 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "dataset/api/de_pipeline.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/python_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/gnn/graph.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/kernels/data/concatenate_op.h" -#include "dataset/kernels/data/duplicate_op.h" -#include "dataset/kernels/data/fill_op.h" -#include "dataset/kernels/data/mask_op.h" -#include "dataset/kernels/data/one_hot_op.h" -#include "dataset/kernels/data/pad_end_op.h" -#include "dataset/kernels/data/slice_op.h" -#include "dataset/kernels/data/to_float16_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/kernels/image/bounding_box_augment_op.h" -#include "dataset/kernels/image/center_crop_op.h" -#include "dataset/kernels/image/cut_out_op.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/hwc_to_chw_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/normalize_op.h" -#include "dataset/kernels/image/pad_op.h" -#include "dataset/kernels/image/random_color_adjust_op.h" -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" -#include "dataset/kernels/image/random_crop_decode_resize_op.h" -#include "dataset/kernels/image/random_crop_op.h" -#include "dataset/kernels/image/random_crop_with_bbox_op.h" -#include "dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" -#include "dataset/kernels/image/random_horizontal_flip_op.h" -#include "dataset/kernels/image/random_resize_op.h" -#include "dataset/kernels/image/random_resize_with_bbox_op.h" -#include "dataset/kernels/image/random_rotation_op.h" -#include "dataset/kernels/image/random_vertical_flip_op.h" -#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" -#include "dataset/kernels/image/rescale_op.h" -#include "dataset/kernels/image/resize_bilinear_op.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/no_op.h" -#include "dataset/text/kernels/jieba_tokenizer_op.h" -#include "dataset/text/kernels/lookup_op.h" -#include "dataset/text/kernels/ngram_op.h" -#include "dataset/text/kernels/to_number_op.h" -#include "dataset/text/kernels/unicode_char_tokenizer_op.h" 
-#include "dataset/text/kernels/wordpiece_tokenizer_op.h" -#include "dataset/text/vocab.h" -#include "dataset/util/random.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_pk_sample.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_sequential_sample.h" -#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "pybind11/stl_bind.h" - -#ifdef ENABLE_ICU4C -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include "dataset/text/kernels/bert_tokenizer_op.h" -#include "dataset/text/kernels/case_fold_op.h" -#include "dataset/text/kernels/normalize_utf8_op.h" -#include "dataset/text/kernels/regex_replace_op.h" -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include "dataset/text/kernels/unicode_script_tokenizer_op.h" -#include "dataset/text/kernels/whitespace_tokenizer_op.h" -#endif - -namespace py = pybind11; - -namespace mindspore { -namespace dataset { -#define THROW_IF_ERROR(s) \ - do { \ - Status rc = std::move(s); \ - if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ - } while (false) - -void bindDEPipeline(py::module *m) { - (void)py::class_(*m, "DEPipeline") - .def(py::init<>()) - .def( - "AddNodeToTree", - [](DEPipeline &de, const OpName &op_name, const py::dict &args) { - py::dict out; - THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); - return out; - }, - py::return_value_policy::reference) - .def_static("AddChildToParentNode", - [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { - THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); - }) - .def("AssignRootNode", - [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) - .def("SetBatchParameters", - [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) - .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) - .def("GetNextAsMap", - [](DEPipeline &de) { - py::dict out; - THROW_IF_ERROR(de.GetNextAsMap(&out)); - return out; - }) - .def("GetNextAsList", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetNextAsList(&out)); - return out; - }) - .def("GetOutputShapes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputShapes(&out)); - return out; - }) - .def("GetOutputTypes", - [](DEPipeline &de) { - py::list out; - THROW_IF_ERROR(de.GetOutputTypes(&out)); - return out; - }) - .def("GetDatasetSize", &DEPipeline::GetDatasetSize) - .def("GetBatchSize", &DEPipeline::GetBatchSize) - .def("GetNumClasses", &DEPipeline::GetNumClasses) - .def("GetRepeatCount", &DEPipeline::GetRepeatCount); -} -void bindDatasetOps(py::module *m) { - (void)py::class_>(*m, "TFReaderOp") - .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { - int64_t count = 0; - std::vector filenames; - for (auto l : files) { - !l.is_none() ? 
filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); - return count; - }); - - (void)py::class_>(*m, "CifarOp") - .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { - int64_t count = 0; - THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); - return count; - }); - - (void)py::class_>(*m, "ImageFolderOp") - .def_static("get_num_rows_and_classes", [](const std::string &path) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }); - - (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, - const int64_t num_padded) { - int64_t count = 0; - std::shared_ptr op; - if (py::hasattr(sampler, "create_for_minddataset")) { - auto create = sampler.attr("create_for_minddataset"); - op = create().cast>(); - } - THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); - return count; - }); - - (void)py::class_>(*m, "ManifestOp") - .def_static("get_num_rows_and_classes", - [](const std::string &file, const py::dict &dict, const std::string &usage) { - int64_t count = 0, num_classes = 0; - THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); - return py::make_tuple(count, num_classes); - }) - .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { - std::map output_class_indexing; - THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); - return output_class_indexing; - }); - - (void)py::class_>(*m, "MnistOp") - .def_static("get_num_rows", [](const std::string &dir) { - int64_t count = 0; - THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); - return count; - }); - - (void)py::class_>(*m, "TextFileOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); - } - THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "ClueOp") - .def_static("get_num_rows", [](const py::list &files) { - int64_t count = 0; - std::vector filenames; - for (auto file : files) { - file.is_none() ? 
(void)filenames.emplace_back("") : filenames.push_back(py::str(file)); - } - THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); - return count; - }); - - (void)py::class_>(*m, "VOCOp") - .def_static("get_num_rows", - [](const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t numSamples) { - int64_t count = 0; - THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); - return count; - }) - .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, - const std::string &task_mode, const py::dict &dict) { - std::map output_class_indexing; - THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); - return output_class_indexing; - }); - (void)py::class_>(*m, "CocoOp") - .def_static("get_class_indexing", - [](const std::string &dir, const std::string &file, const std::string &task) { - std::vector>> output_class_indexing; - THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); - return output_class_indexing; - }) - .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) { - int64_t count = 0; - THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); - return count; - }); -} -void bindTensor(py::module *m) { - (void)py::class_(*m, "GlobalContext") - .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); - - (void)py::class_>(*m, "ConfigManager") - .def("__str__", &ConfigManager::ToString) - .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) - .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) - .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) - .def("set_op_connector_size", &ConfigManager::set_op_connector_size) - .def("set_seed", &ConfigManager::set_seed) - .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) - .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) - .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) - .def("get_worker_connector_size", &ConfigManager::worker_connector_size) - .def("get_op_connector_size", &ConfigManager::op_connector_size) - .def("get_seed", &ConfigManager::seed) - .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) - .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); - - (void)py::class_>(*m, "Tensor", py::buffer_protocol()) - .def(py::init([](py::array arr) { - std::shared_ptr out; - THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); - return out; - })) - .def_buffer([](Tensor &tensor) { - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); - return info; - }) - .def("__str__", &Tensor::ToString) - .def("shape", &Tensor::shape) - .def("type", &Tensor::type) - .def("as_array", [](py::object &t) { - auto &tensor = py::cast(t); - if (tensor.type() == DataType::DE_STRING) { - py::array res; - tensor.GetDataAsNumpyStrings(&res); - return res; - } - py::buffer_info info; - THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); - return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); - }); - - (void)py::class_(*m, "TensorShape") - .def(py::init()) - .def("__str__", &TensorShape::ToString) - .def("as_list", &TensorShape::AsPyList) - .def("is_known", &TensorShape::known); - - (void)py::class_(*m, "DataType") - .def(py::init()) - .def(py::self 
== py::self) - .def("__str__", &DataType::ToString) - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); -} - -void bindTensorOps1(py::module *m) { - (void)py::class_>(*m, "TensorOp") - .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); - - (void)py::class_>( - *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") - .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), - py::arg("stdR"), py::arg("stdG"), py::arg("stdB")); - - (void)py::class_>( - *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.") - .def(py::init(), py::arg("rescale"), py::arg("shift")); - - (void)py::class_>( - *m, "CenterCropOp", "Tensor operation to crop and image in the middle. Takes height and width (optional)") - .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); - - (void)py::class_>( - *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); - - (void)py::class_>( - *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, - py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); - - (void)py::class_>( - *m, "RandomResizeWithBBoxOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); - - (void)py::class_>( - *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") - .def(py::init>, int32_t>(), py::arg("operations"), - py::arg("NumOps") = UniformAugOp::kDefNumOps); - - (void)py::class_>( - *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.") - .def(py::init, float>(), py::arg("transform"), - py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio); - - (void)py::class_>( - *m, "ResizeBilinearOp", - "Tensor operation to resize an image using " - "Bilinear mode. 
Takes height and width.") - .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); - - (void)py::class_>(*m, "DecodeOp", - "Tensor operation to decode a jpg image") - .def(py::init<>()) - .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); - - (void)py::class_>( - *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomHorizontalFlipWithBBoxOp", - "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") - .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); -} - -void bindTensorOps2(py::module *m) { - (void)py::class_>( - *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); - - (void)py::class_>( - *m, "RandomVerticalFlipWithBBoxOp", - "Tensor operation to randomly flip an image vertically" - " and adjust bounding boxes.") - .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); - - (void)py::class_>(*m, "RandomCropOp", - "Gives random crop of specified size " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, - py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, - py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, - py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); - (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); - - (void)py::class_>(*m, "RandomCropWithBBoxOp", - "Gives random crop of given " - "size + adjusts bboxes " - "Takes crop size") - .def(py::init(), - py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, - py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, - py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, - py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, - py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, - py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, - py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, - py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); - - (void)py::class_>( - *m, "OneHotOp", "Tensor operation to apply one hot encoding. 
Takes number of classes.") - .def(py::init()); - - (void)py::class_>( - *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") - .def(py::init>()); - - (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") - .def(py::init()) - .def(py::init([](const py::list &py_list) { - std::vector c_list; - for (auto l : py_list) { - if (!l.is_none()) { - c_list.push_back(py::reinterpret_borrow(l)); - } - } - return std::make_shared(c_list); - })) - .def(py::init([](const py::tuple &py_slice) { - if (py_slice.size() != 3) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - Slice c_slice; - if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), - py::reinterpret_borrow(py_slice[2])); - } else if (py_slice[0].is_none() && py_slice[2].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[1])); - } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { - c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); - } - - if (!c_slice.valid()) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); - } - return std::make_shared(c_slice); - })); - - (void)py::enum_(*m, "RelationalOp", py::arithmetic()) - .value("EQ", RelationalOp::kEqual) - .value("NE", RelationalOp::kNotEqual) - .value("LT", RelationalOp::kLess) - .value("LE", RelationalOp::kLessEqual) - .value("GT", RelationalOp::kGreater) - .value("GE", RelationalOp::kGreaterEqual) - .export_values(); - - (void)py::class_>(*m, "MaskOp", - "Tensor mask operation using relational comparator") - .def(py::init, DataType>()); - - (void)py::class_>(*m, "DuplicateOp", "Duplicate tensor.") - .def(py::init<>()); - - (void)py::class_>( - *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") - .def(py::init()); - - (void)py::class_>(*m, "ConcatenateOp", - "Tensor operation concatenate tensors.") - .def(py::init, std::shared_ptr>(), py::arg("axis"), - py::arg("prepend").none(true), py::arg("append").none(true)); - - (void)py::class_>( - *m, "RandomRotationOp", - "Tensor operation to apply RandomRotation." - "Takes a range for degrees and " - "optional parameters for rotation center and image expand") - .def(py::init(), - py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, - py::arg("centerY") = RandomRotationOp::kDefCenterY, - py::arg("interpolation") = RandomRotationOp::kDefInterpolation, - py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, - py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); - - (void)py::class_>( - *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") - .def(py::init>()); -} - -void bindTensorOps3(py::module *m) { - (void)py::class_>( - *m, "RandomCropAndResizeOp", - "Tensor operation to randomly crop an image and resize to a given size." 
- "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomCropAndResizeWithBBoxOp", - "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." - "Takes output height and width and" - "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," - "interpolation mode, and max attempts to crop") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, - py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, - py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); - - (void)py::class_>( - *m, "RandomColorAdjustOp", - "Tensor operation to adjust an image's color randomly." - "Takes range for brightness, contrast, saturation, hue and") - .def(py::init(), py::arg("bright_factor_start"), - py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"), - py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"), - py::arg("hue_factor_end")); - - (void)py::class_>( - *m, "RandomResizeOp", - "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); - - (void)py::class_>( - *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. 
Takes height and width.") - .def(py::init(), py::arg("boxHeight"), - py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, - py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, - py::arg("fillB") = CutOutOp::kDefFillB); -} - -void bindTensorOps4(py::module *m) { - (void)py::class_>( - *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); - - (void)py::class_>(*m, "NoOp", - "TensorOp that does nothing, for testing purposes only.") - .def(py::init<>()); - - (void)py::class_>( - *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.") - .def(py::init<>()); - - (void)py::class_>( - *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") - .def(py::init(), py::arg("targetHeight"), - py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, - py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, - py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, - py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, - py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, - py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); - - (void)py::class_>( - *m, "PadOp", - "Pads image with specified color, default black, " - "Takes amount to pad for top, bottom, left, right of image, boarder type and color") - .def(py::init(), py::arg("padTop"), - py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, - py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); - (void)py::class_>(*m, "ToNumberOp", - "TensorOp to convert strings to numbers.") - .def(py::init(), py::arg("data_type")) - .def(py::init(), py::arg("data_type")); -} - -void bindTokenizerOps(py::module *m) { - (void)py::class_>(*m, "JiebaTokenizerOp", "") - .def(py::init(), py::arg("hmm_path"), - py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix, - py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets) - .def("add_word", - [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); }); - (void)py::class_>( - *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") - .def(py::init(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets); - (void)py::class_>(*m, "LookupOp", - "Tensor operation to LookUp each word.") - .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { - if (vocab == nullptr) { - THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); - } - if (py_word.is_none()) { - return std::make_shared(vocab, Vocab::kNoTokenExists); - } - std::string word = py::reinterpret_borrow(py_word); - WordIdType default_id = vocab->Lookup(word); - if (default_id == Vocab::kNoTokenExists) { - THROW_IF_ERROR( - Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab.")); - } - return std::make_shared(vocab, default_id); - })); - (void)py::class_>(*m, "NgramOp", "TensorOp performs ngram mapping.") - .def(py::init &, int32_t, int32_t, const std::string &, const std::string &, - const std::string &>(), - py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"), - 
py::arg("separator")); - (void)py::class_>( - *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") - .def( - py::init &, const std::string &, const int &, const std::string &, const bool &>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), - py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); -} - -void bindDependIcuTokenizerOps(py::module *m) { -#ifdef ENABLE_ICU4C - (void)py::class_>( - *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") - .def(py::init(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") - .def(py::init<>()) - .def(py::init(), - py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace, - py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") - .def(py::init<>()); - (void)py::class_>( - *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") - .def(py::init<>()) - .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); - (void)py::class_>( - *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") - .def(py::init(), py::arg("pattern"), py::arg("replace"), - py::arg("replace_all")); - (void)py::class_>( - *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") - .def(py::init(), py::arg("delim_pattern"), - py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets); - (void)py::class_>( - *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") - .def(py::init(), - py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, - py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets); - (void)py::class_>(*m, "BertTokenizerOp", - "Tokenizer used for Bert text process.") - .def(py::init &, const std::string &, const int &, const std::string &, const bool &, - const bool &, const NormalizeForm &, const bool &, const bool &>(), - py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), - py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, - py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), - py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, - py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, - py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, - py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, - py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); -#endif -} - -void bindSamplerOps(py::module *m) { - (void)py::class_>(*m, "Sampler") - .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) - .def("set_num_samples", [](Sampler &self, int64_t 
samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) - .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) - .def("get_indices", - [](Sampler &self) { - py::array ret; - THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); - return ret; - }) - .def("add_child", - [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); - - (void)py::class_>(*m, "ShardOperator") - .def("add_child", [](std::shared_ptr self, - std::shared_ptr child) { self->SetChildOp(child); }); - - (void)py::class_>(*m, "DistributedSampler") - .def(py::init()); - - (void)py::class_>(*m, "PKSampler") - .def(py::init()); - - (void)py::class_>(*m, "RandomSampler") - .def(py::init()); - - (void)py::class_>(*m, "SequentialSampler") - .def(py::init()); - - (void)py::class_>(*m, "SubsetRandomSampler") - .def(py::init>()); - - (void)py::class_>( - *m, "MindrecordSubsetRandomSampler") - .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); - - (void)py::class_>( - *m, "MindrecordPkSampler") - .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { - if (shuffle == true) { - return std::make_shared(kColumn, kVal, std::numeric_limits::max(), - GetSeed()); - } else { - return std::make_shared(kColumn, kVal); - } - })); - - (void)py::class_>(*m, "MindrecordDistributedSampler") - .def(py::init()); - - (void)py::class_>( - *m, "MindrecordRandomSampler") - .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { - return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); - })); - - (void)py::class_>(*m, "MindrecordSequentialSampler") - .def(py::init([](int num_samples, int start_index) { - return std::make_shared(num_samples, start_index); - })); - - (void)py::class_>(*m, "WeightedRandomSampler") - .def(py::init, bool>()); - - (void)py::class_>(*m, "PythonSampler") - .def(py::init()); -} - -void bindInfoObjects(py::module *m) { - (void)py::class_(*m, "CBatchInfo") - .def(py::init()) - .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) - .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); -} - -void bindCacheClient(py::module *m) { - (void)py::class_>(*m, "CacheClient") - .def(py::init()); -} - -void bindVocabObjects(py::module *m) { - (void)py::class_>(*m, "Vocab") - .def(py::init<>()) - .def_static("from_list", - [](const py::list &words, const py::list &special_tokens, bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_file", - [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, - bool special_first) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); - return v; - }) - .def_static("from_dict", [](const py::dict &words) { - std::shared_ptr v; - THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); - return v; - }); -} - -void bindGraphData(py::module *m) { - (void)py::class_>(*m, "Graph") - .def(py::init([](std::string dataset_file, int32_t num_workers) { - std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); - THROW_IF_ERROR(g_out->Init()); - return g_out; - })) - .def("get_all_nodes", - [](gnn::Graph &g, gnn::NodeType node_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); - return out; - }) - .def("get_all_edges", - [](gnn::Graph &g, gnn::EdgeType edge_type) { - std::shared_ptr out; - 
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); - return out; - }) - .def("get_nodes_from_edges", - [](gnn::Graph &g, std::vector edge_list) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); - return out; - }) - .def("get_all_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); - return out; - }) - .def("get_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, - std::vector neighbor_types) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); - return out; - }) - .def("get_neg_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, - gnn::NodeType neg_neighbor_type) { - std::shared_ptr out; - THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); - return out; - }) - .def("get_node_feature", - [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { - TensorRow out; - THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); - return out.getRow(); - }) - .def("get_edge_feature", - [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { - TensorRow out; - THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); - return out.getRow(); - }) - .def("graph_info", - [](gnn::Graph &g) { - py::dict out; - THROW_IF_ERROR(g.GraphInfo(&out)); - return out; - }) - .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, - float step_home_param, float step_away_param, gnn::NodeIdType default_node) { - std::shared_ptr out; - THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); - return out; - }); -} - -// This is where we externalize the C logic as python modules -PYBIND11_MODULE(_c_dataengine, m) { - m.doc() = "pybind11 for _c_dataengine"; - (void)py::class_>(m, "DatasetOp"); - - (void)py::enum_(m, "OpName", py::arithmetic()) - .value("SHUFFLE", OpName::kShuffle) - .value("BATCH", OpName::kBatch) - .value("BUCKETBATCH", OpName::kBucketBatch) - .value("BARRIER", OpName::kBarrier) - .value("MINDRECORD", OpName::kMindrecord) - .value("CACHE", OpName::kCache) - .value("REPEAT", OpName::kRepeat) - .value("SKIP", OpName::kSkip) - .value("TAKE", OpName::kTake) - .value("ZIP", OpName::kZip) - .value("CONCAT", OpName::kConcat) - .value("MAP", OpName::kMap) - .value("FILTER", OpName::kFilter) - .value("DEVICEQUEUE", OpName::kDeviceQueue) - .value("GENERATOR", OpName::kGenerator) - .export_values() - .value("RENAME", OpName::kRename) - .value("TFREADER", OpName::kTfReader) - .value("PROJECT", OpName::kProject) - .value("IMAGEFOLDER", OpName::kImageFolder) - .value("MNIST", OpName::kMnist) - .value("MANIFEST", OpName::kManifest) - .value("VOC", OpName::kVoc) - .value("COCO", OpName::kCoco) - .value("CIFAR10", OpName::kCifar10) - .value("CIFAR100", OpName::kCifar100) - .value("RANDOMDATA", OpName::kRandomData) - .value("BUILDVOCAB", OpName::kBuildVocab) - .value("CELEBA", OpName::kCelebA) - .value("TEXTFILE", OpName::kTextFile) - .value("CLUE", OpName::kClue); - - (void)py::enum_(m, "JiebaMode", py::arithmetic()) - .value("DE_JIEBA_MIX", JiebaMode::kMix) - .value("DE_JIEBA_MP", JiebaMode::kMp) - .value("DE_JIEBA_HMM", JiebaMode::kHmm) - .export_values(); - -#ifdef ENABLE_ICU4C - (void)py::enum_(m, "NormalizeForm", py::arithmetic()) - .value("DE_NORMALIZE_NONE", 
NormalizeForm::kNone) - .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) - .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) - .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) - .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) - .export_values(); -#endif - - (void)py::enum_(m, "InterpolationMode", py::arithmetic()) - .value("DE_INTER_LINEAR", InterpolationMode::kLinear) - .value("DE_INTER_CUBIC", InterpolationMode::kCubic) - .value("DE_INTER_AREA", InterpolationMode::kArea) - .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) - .export_values(); - - (void)py::enum_(m, "BorderType", py::arithmetic()) - .value("DE_BORDER_CONSTANT", BorderType::kConstant) - .value("DE_BORDER_EDGE", BorderType::kEdge) - .value("DE_BORDER_REFLECT", BorderType::kReflect) - .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) - .export_values(); - bindDEPipeline(&m); - bindTensor(&m); - bindTensorOps1(&m); - bindTensorOps2(&m); - bindTensorOps3(&m); - bindTensorOps4(&m); - bindTokenizerOps(&m); - bindSamplerOps(&m); - bindDatasetOps(&m); - bindInfoObjects(&m); - bindCacheClient(&m); - bindVocabObjects(&m); - bindGraphData(&m); - bindDependIcuTokenizerOps(&m); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/samplers.cc b/mindspore/ccsrc/dataset/api/samplers.cc deleted file mode 100644 index 44d01c2f0c..0000000000 --- a/mindspore/ccsrc/dataset/api/samplers.cc +++ /dev/null @@ -1,224 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/include/samplers.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/pk_sampler.h" - -namespace mindspore { -namespace dataset { -namespace api { - -SamplerObj::SamplerObj() {} - -/// Function to create a Distributed Sampler. -std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, - int64_t num_samples, uint32_t seed) { - auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/// Function to create a PK Sampler. -std::shared_ptr PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { - auto sampler = std::make_shared(num_val, shuffle, num_samples); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/// Function to create a Random Sampler. 
-std::shared_ptr RandomSampler(bool replacement, int64_t num_samples) { - auto sampler = std::make_shared(replacement, num_samples); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/// Function to create a Sequential Sampler. -std::shared_ptr SequentialSampler(int64_t start_index, int64_t num_samples) { - auto sampler = std::make_shared(start_index, num_samples); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/// Function to create a Subset Random Sampler. -std::shared_ptr SubsetRandomSampler(const std::vector &indices, int64_t num_samples) { - auto sampler = std::make_shared(indices, num_samples); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/// Function to create a Weighted Random Sampler. -std::shared_ptr WeightedRandomSampler(const std::vector &weights, int64_t num_samples, - bool replacement) { - auto sampler = std::make_shared(weights, num_samples, replacement); - // Input validation - if (!sampler->ValidateParams()) { - return nullptr; - } - return sampler; -} - -/* ####################################### Derived Sampler classes ################################# */ - -// DistributedSampler -DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, - uint32_t seed) - : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {} - -bool DistributedSamplerObj::ValidateParams() { - if (num_shards_ <= 0) { - MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; - return false; - } - - if (shard_id_ < 0 || shard_id_ >= num_shards_) { - MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; - return false; - } - - if (num_samples_ < 0) { - MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; - return false; - } - - return true; -} - -std::shared_ptr DistributedSamplerObj::Build() { - return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_); -} - -// PKSampler -PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) - : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} - -bool PKSamplerObj::ValidateParams() { - if (num_val_ <= 0) { - MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; - return false; - } - - if (num_samples_ < 0) { - MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; - return false; - } - return true; -} - -std::shared_ptr PKSamplerObj::Build() { - return std::make_shared(num_samples_, num_val_, shuffle_); -} - -// RandomSampler -RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) - : replacement_(replacement), num_samples_(num_samples) {} - -bool RandomSamplerObj::ValidateParams() { - if (num_samples_ < 0) { - MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; - return false; - } - return true; -} - -std::shared_ptr RandomSamplerObj::Build() { - bool reshuffle_each_epoch = true; - auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); - return sampler; -} - -// SequentialSampler -SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) - : start_index_(start_index), num_samples_(num_samples) {} - -bool SequentialSamplerObj::ValidateParams() { - if (num_samples_ < 0) { - MS_LOG(ERROR) << "SequentialSampler: 
invalid num_samples: " << num_samples_; - return false; - } - - if (start_index_ < 0) { - MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; - return false; - } - - return true; -} - -std::shared_ptr SequentialSamplerObj::Build() { - auto sampler = std::make_shared(num_samples_, start_index_); - return sampler; -} - -// SubsetRandomSampler -SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples) - : indices_(indices), num_samples_(num_samples) {} - -bool SubsetRandomSamplerObj::ValidateParams() { - if (num_samples_ < 0) { - MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; - return false; - } - - return true; -} - -std::shared_ptr SubsetRandomSamplerObj::Build() { - auto sampler = std::make_shared(num_samples_, indices_); - return sampler; -} - -// WeightedRandomSampler -WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples, - bool replacement) - : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} - -bool WeightedRandomSamplerObj::ValidateParams() { - if (num_samples_ < 0) { - MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; - return false; - } - return true; -} - -std::shared_ptr WeightedRandomSamplerObj::Build() { - auto sampler = std::make_shared(num_samples_, weights_, replacement_); - return sampler; -} - -} // namespace api -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/transforms.cc b/mindspore/ccsrc/dataset/api/transforms.cc deleted file mode 100644 index e086837447..0000000000 --- a/mindspore/ccsrc/dataset/api/transforms.cc +++ /dev/null @@ -1,491 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/include/transforms.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/normalize_op.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/random_crop_op.h" -#include "dataset/kernels/image/center_crop_op.h" -#include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/image/random_horizontal_flip_op.h" -#include "dataset/kernels/image/random_vertical_flip_op.h" -#include "dataset/kernels/image/random_rotation_op.h" -#include "dataset/kernels/image/cut_out_op.h" -#include "dataset/kernels/image/random_color_adjust_op.h" -#include "dataset/kernels/image/pad_op.h" - -namespace mindspore { -namespace dataset { -namespace api { - -TensorOperation::TensorOperation() {} - -// Transform operations for computer vision. -namespace vision { - -// Function to create NormalizeOperation. -std::shared_ptr Normalize(std::vector mean, std::vector std) { - auto op = std::make_shared(mean, std); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create DecodeOperation. 
-std::shared_ptr Decode(bool rgb) { - auto op = std::make_shared(rgb); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create ResizeOperation. -std::shared_ptr Resize(std::vector size, InterpolationMode interpolation) { - auto op = std::make_shared(size, interpolation); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create RandomCropOperation. -std::shared_ptr RandomCrop(std::vector size, std::vector padding, - bool pad_if_needed, std::vector fill_value) { - auto op = std::make_shared(size, padding, pad_if_needed, fill_value); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create CenterCropOperation. -std::shared_ptr CenterCrop(std::vector size) { - auto op = std::make_shared(size); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create UniformAugOperation. -std::shared_ptr UniformAugment(std::vector> operations, - int32_t num_ops) { - auto op = std::make_shared(operations, num_ops); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create RandomHorizontalFlipOperation. -std::shared_ptr RandomHorizontalFlip(float prob) { - auto op = std::make_shared(prob); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create RandomVerticalFlipOperation. -std::shared_ptr RandomVerticalFlip(float prob) { - auto op = std::make_shared(prob); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create RandomRotationOperation. -std::shared_ptr RandomRotation(std::vector degrees, InterpolationMode resample, - bool expand, std::vector center, - std::vector fill_value) { - auto op = std::make_shared(degrees, resample, expand, center, fill_value); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create PadOperation. -std::shared_ptr Pad(std::vector padding, std::vector fill_value, - BorderType padding_mode) { - auto op = std::make_shared(padding, fill_value, padding_mode); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create CutOutOp. -std::shared_ptr CutOut(int32_t length, int32_t num_patches) { - auto op = std::make_shared(length, num_patches); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -// Function to create RandomColorAdjustOperation. 
-std::shared_ptr RandomColorAdjust(std::vector brightness, - std::vector contrast, - std::vector saturation, std::vector hue) { - auto op = std::make_shared(brightness, contrast, saturation, hue); - // Input validation - if (!op->ValidateParams()) { - return nullptr; - } - return op; -} - -/* ####################################### Derived TensorOperation classes ################################# */ - -// NormalizeOperation -NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} - -bool NormalizeOperation::ValidateParams() { - if (mean_.size() != 3) { - MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size(); - return false; - } - - if (std_.size() != 3) { - MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size(); - return false; - } - - return true; -} - -std::shared_ptr NormalizeOperation::Build() { - return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); -} - -// DecodeOperation -DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {} - -bool DecodeOperation::ValidateParams() { return true; } - -std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } - -// ResizeOperation -ResizeOperation::ResizeOperation(std::vector size, InterpolationMode interpolation) - : size_(size), interpolation_(interpolation) {} - -bool ResizeOperation::ValidateParams() { - if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size(); - return false; - } - return true; -} - -std::shared_ptr ResizeOperation::Build() { - int32_t height = size_[0]; - int32_t width = 0; - - // User specified the width value. - if (size_.size() == 2) { - width = size_[1]; - } - - return std::make_shared(height, width, interpolation_); -} - -// RandomCropOperation -RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, - std::vector fill_value) - : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {} - -bool RandomCropOperation::ValidateParams() { - if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size(); - return false; - } - - if (padding_.empty() || padding_.size() != 4) { - MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: padding.size()"; - return false; - } - - if (fill_value_.empty() || fill_value_.size() != 3) { - MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: fill_value.size()"; - return false; - } - return true; -} - -std::shared_ptr RandomCropOperation::Build() { - int32_t crop_height = size_[0]; - int32_t crop_width = 0; - - int32_t pad_top = padding_[0]; - int32_t pad_bottom = padding_[1]; - int32_t pad_left = padding_[2]; - int32_t pad_right = padding_[3]; - - uint8_t fill_r = fill_value_[0]; - uint8_t fill_g = fill_value_[1]; - uint8_t fill_b = fill_value_[2]; - - // User has specified the crop_width value. 
- if (size_.size() == 2) { - crop_width = size_[1]; - } - - auto tensor_op = std::make_shared(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, - BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b); - return tensor_op; -} - -// CenterCropOperation -CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} - -bool CenterCropOperation::ValidateParams() { - if (size_.empty() || size_.size() > 2) { - MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; - return false; - } - return true; -} - -std::shared_ptr CenterCropOperation::Build() { - int32_t crop_height = size_[0]; - int32_t crop_width = 0; - - // User has specified crop_width. - if (size_.size() == 2) { - crop_width = size_[1]; - } - - std::shared_ptr tensor_op = std::make_shared(crop_height, crop_width); - return tensor_op; -} - -// UniformAugOperation -UniformAugOperation::UniformAugOperation(std::vector> operations, int32_t num_ops) - : operations_(operations), num_ops_(num_ops) {} - -bool UniformAugOperation::ValidateParams() { return true; } - -std::shared_ptr UniformAugOperation::Build() { - std::vector> tensor_ops; - (void)std::transform(operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), - [](std::shared_ptr op) -> std::shared_ptr { return op->Build(); }); - std::shared_ptr tensor_op = std::make_shared(tensor_ops, num_ops_); - return tensor_op; -} - -// RandomHorizontalFlipOperation -RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} - -bool RandomHorizontalFlipOperation::ValidateParams() { return true; } - -std::shared_ptr RandomHorizontalFlipOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(probability_); - return tensor_op; -} - -// RandomVerticalFlipOperation -RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} - -bool RandomVerticalFlipOperation::ValidateParams() { return true; } - -std::shared_ptr RandomVerticalFlipOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(probability_); - return tensor_op; -} - -// Function to create RandomRotationOperation. 
-RandomRotationOperation::RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, - bool expand, std::vector center, - std::vector fill_value) - : degrees_(degrees), - interpolation_mode_(interpolation_mode), - expand_(expand), - center_(center), - fill_value_(fill_value) {} - -bool RandomRotationOperation::ValidateParams() { - if (degrees_.empty() || degrees_.size() != 2) { - MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: degrees.size()"; - return false; - } - if (center_.empty() || center_.size() != 2) { - MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: center.size()"; - return false; - } - if (fill_value_.empty() || fill_value_.size() != 3) { - MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: fill_value.size()"; - return false; - } - return true; -} - -std::shared_ptr RandomRotationOperation::Build() { - std::shared_ptr tensor_op = - std::make_shared(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_, - fill_value_[0], fill_value_[1], fill_value_[2]); - return tensor_op; -} - -// PadOperation -PadOperation::PadOperation(std::vector padding, std::vector fill_value, BorderType padding_mode) - : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} - -bool PadOperation::ValidateParams() { - if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) { - MS_LOG(ERROR) << "Pad: padding vector has incorrect size: padding.size()"; - return false; - } - - if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) { - MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: fill_value.size()"; - return false; - } - return true; -} - -std::shared_ptr PadOperation::Build() { - int32_t pad_top, pad_bottom, pad_left, pad_right; - switch (padding_.size()) { - case 1: - pad_left = padding_[0]; - pad_top = padding_[0]; - pad_right = padding_[0]; - pad_bottom = padding_[0]; - break; - case 2: - pad_left = padding_[0]; - pad_top = padding_[1]; - pad_right = padding_[0]; - pad_bottom = padding_[1]; - break; - default: - pad_left = padding_[0]; - pad_top = padding_[1]; - pad_right = padding_[2]; - pad_bottom = padding_[3]; - } - uint8_t fill_r, fill_g, fill_b; - - fill_r = fill_value_[0]; - fill_g = fill_value_[0]; - fill_b = fill_value_[0]; - - if (fill_value_.size() == 3) { - fill_r = fill_value_[0]; - fill_g = fill_value_[1]; - fill_b = fill_value_[2]; - } - - std::shared_ptr tensor_op = - std::make_shared(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b); - return tensor_op; -} - -// CutOutOperation -CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {} - -bool CutOutOperation::ValidateParams() { - if (length_ < 0) { - MS_LOG(ERROR) << "CutOut: length cannot be negative"; - return false; - } - if (num_patches_ < 0) { - MS_LOG(ERROR) << "CutOut: number of patches cannot be negative"; - return false; - } - return true; -} - -std::shared_ptr CutOutOperation::Build() { - std::shared_ptr tensor_op = std::make_shared(length_, length_, num_patches_, false, 0, 0, 0); - return tensor_op; -} - -// RandomColorAdjustOperation. -RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector brightness, std::vector contrast, - std::vector saturation, std::vector hue) - : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {} - -bool RandomColorAdjustOperation::ValidateParams() { - // Do some input validation. 
-// RandomColorAdjustOperation.
-RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
-                                                       std::vector<float> saturation, std::vector<float> hue)
-    : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
-
-bool RandomColorAdjustOperation::ValidateParams() {
-  // Do some input validation.
-  if (brightness_.empty() || brightness_.size() > 2) {
-    MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values";
-    return false;
-  }
-  if (contrast_.empty() || contrast_.size() > 2) {
-    MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values";
-    return false;
-  }
-  if (saturation_.empty() || saturation_.size() > 2) {
-    MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values";
-    return false;
-  }
-  if (hue_.empty() || hue_.size() > 2) {
-    MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values";
-    return false;
-  }
-  return true;
-}
-
-std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
-  float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub;
-
-  brightness_lb = brightness_[0];
-  brightness_ub = brightness_[0];
-
-  if (brightness_.size() == 2) brightness_ub = brightness_[1];
-
-  contrast_lb = contrast_[0];
-  contrast_ub = contrast_[0];
-
-  if (contrast_.size() == 2) contrast_ub = contrast_[1];
-
-  saturation_lb = saturation_[0];
-  saturation_ub = saturation_[0];
-
-  if (saturation_.size() == 2) saturation_ub = saturation_[1];
-
-  hue_lb = hue_[0];
-  hue_ub = hue_[0];
-
-  if (hue_.size() == 2) hue_ub = hue_[1];
-
-  std::shared_ptr<RandomColorAdjustOp> tensor_op = std::make_shared<RandomColorAdjustOp>(
-    brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub);
-  return tensor_op;
-}
-
-}  // namespace vision
-}  // namespace api
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/core/client.cc b/mindspore/ccsrc/dataset/core/client.cc
deleted file mode 100644
index 6247ddae7d..0000000000
--- a/mindspore/ccsrc/dataset/core/client.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/core/client.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/global_context.h"
-#include "dataset/util/services.h"
-#include "dataset/util/sig_handler.h"
-
-namespace mindspore {
-namespace dataset {
-// This is a one-time global initializer which includes the call to instantiate singletons.
-// It is an external API call, not a member of GlobalContext directly.
-Status GlobalInit() {
-  // Bring up all the services (logger, task, bufferpool)
-  return (Services::CreateInstance());
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h
deleted file mode 100644
index 96553c9169..0000000000
--- a/mindspore/ccsrc/dataset/core/client.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_CORE_CLIENT_H_
-#define DATASET_CORE_CLIENT_H_
-
-// client.h
-// Include file for DE client functions
-
-#include "dataset/core/constants.h"
-#include "dataset/core/data_type.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/tensor_shape.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/dataset_iterator.h"
-#include "dataset/engine/datasetops/source/mindrecord_op.h"
-#include "dataset/engine/datasetops/source/tf_reader_op.h"
-
-#ifdef ENABLE_PYTHON
-#include "dataset/engine/datasetops/barrier_op.h"
-#include "dataset/engine/datasetops/filter_op.h"
-#include "dataset/engine/datasetops/source/generator_op.h"
-#include "dataset/engine/datasetops/build_vocab_op.h"
-#endif
-
-#include "dataset/engine/datasetops/batch_op.h"
-#include "dataset/engine/datasetops/dataset_op.h"
-#include "dataset/engine/datasetops/device_queue_op.h"
-#include "dataset/engine/datasetops/map_op.h"
-#include "dataset/engine/datasetops/project_op.h"
-#include "dataset/engine/datasetops/rename_op.h"
-#include "dataset/engine/datasetops/repeat_op.h"
-#include "dataset/engine/datasetops/skip_op.h"
-#include "dataset/engine/datasetops/shuffle_op.h"
-#include "dataset/engine/datasetops/take_op.h"
-#include "dataset/engine/datasetops/zip_op.h"
-#include "dataset/engine/datasetops/concat_op.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-// This is a one-time global initializer that needs to be called at the
-// start of any minddata applications.
-extern Status GlobalInit();
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_CORE_CLIENT_H_
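client.cc above reduces GlobalInit() to a single Services::CreateInstance() call, a one-shot bring-up of the logger/task/buffer-pool singletons. A minimal sketch of that one-shot idiom with std::call_once (illustrative only; the real Services class is not reproduced here):

    #include <iostream>
    #include <mutex>

    class ServicesSketch {
     public:
      // Safe to call from multiple threads; the body runs exactly once.
      static ServicesSketch &CreateInstance() {
        static std::once_flag flag;
        static ServicesSketch *instance = nullptr;
        std::call_once(flag, []() {
          instance = new ServicesSketch();  // deliberately never freed (process-lifetime singleton)
          std::cout << "services (logger/task/bufferpool) brought up once\n";
        });
        return *instance;
      }

     private:
      ServicesSketch() = default;
    };

    int main() {
      ServicesSketch::CreateInstance();
      ServicesSketch::CreateInstance();  // second call is a no-op
      return 0;
    }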
diff --git a/mindspore/ccsrc/dataset/core/config_manager.cc b/mindspore/ccsrc/dataset/core/config_manager.cc
deleted file mode 100644
index 9291a8f832..0000000000
--- a/mindspore/ccsrc/dataset/core/config_manager.cc
+++ /dev/null
@@ -1,92 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/core/config_manager.h"
-
-#include <fstream>
-#include <iostream>
-#include <string>
-
-#include "dataset/util/system_pool.h"
-
-namespace mindspore {
-namespace dataset {
-// A print method typically used for debugging
-void ConfigManager::Print(std::ostream &out) const {
-  // Don't show the test/internal ones. Only display the main ones here.
-  // fyi, boolalpha tells the output stream to write "true" and "false" for bools
-  out << "\nClient config settings :"
-      << "\nDataCache Rows per buffer : " << rows_per_buffer_
-      << "\nParallelOp workers : " << num_parallel_workers_
-      << "\nParallelOp worker connector size : " << worker_connector_size_
-      << "\nSize of each Connector : " << op_connector_size_ << std::endl;
-}
-
-// Private helper function that takes a nlohmann json format and populates the settings
-Status ConfigManager::FromJson(const nlohmann::json &j) {
-  set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_));
-  set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_));
-  set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_));
-  set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
-  set_seed(j.value("seed", seed_));
-  set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
-  return Status::OK();
-}
-
-// Loads a json file with the default settings and populates all the settings
-Status ConfigManager::LoadFile(const std::string &settingsFile) {
-  Status rc;
-  if (!Path(settingsFile).Exists()) {
-    RETURN_STATUS_UNEXPECTED("File is not found.");
-  }
-  // Some settings are mandatory, others are not (with default). If a setting
-  // is optional it will set a default value if the config is missing from the file.
-  try {
-    std::ifstream in(settingsFile);
-    nlohmann::json js;
-    in >> js;
-    rc = FromJson(js);
-  } catch (const nlohmann::json::type_error &e) {
-    std::ostringstream ss;
-    ss << "Client file failed to load:\n" << e.what();
-    std::string err_msg = ss.str();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  } catch (const std::exception &err) {
-    RETURN_STATUS_UNEXPECTED("Client file failed to load.");
-  }
-  return rc;
-}
-
-// Setter function
-void ConfigManager::set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; }
-
-// Setter function
-void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) {
-  num_parallel_workers_ = num_parallel_workers;
-}
-
-// Setter function
-void ConfigManager::set_worker_connector_size(int32_t connector_size) { worker_connector_size_ = connector_size; }
-
-// Setter function
-void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector_size_ = connector_size; }
-
-uint32_t ConfigManager::seed() const { return seed_; }
-
-void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
-
-void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
-}  // namespace dataset
-}  // namespace mindspore
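ConfigManager::FromJson above leans on nlohmann::json's value() accessor: a present key overrides the current setting, a missing key silently keeps it. A small sketch of that optional-key-with-default pattern (assumes nlohmann/json.hpp is available):

    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
      nlohmann::json j = nlohmann::json::parse(R"({"rowsPerBuffer": 3})");
      int rows_per_buffer = 1;  // current default, in the spirit of ConfigManager
      int num_workers = 4;      // current default
      // Present key overrides the default; missing key keeps it.
      rows_per_buffer = j.value("rowsPerBuffer", rows_per_buffer);  // -> 3
      num_workers = j.value("numParallelWorkers", num_workers);     // -> 4 (key absent)
      std::cout << rows_per_buffer << " " << num_workers << "\n";
      return 0;
    }

This is why LoadFile can treat every setting as optional: anything not named in the file simply retains its compiled-in default.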
diff --git a/mindspore/ccsrc/dataset/core/config_manager.h b/mindspore/ccsrc/dataset/core/config_manager.h
deleted file mode 100644
index 807591daa1..0000000000
--- a/mindspore/ccsrc/dataset/core/config_manager.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_CORE_CONFIG_MANAGER_H_
-#define DATASET_CORE_CONFIG_MANAGER_H_
-
-#include <iostream>
-#include <sstream>
-#include <string>
-
-#include <nlohmann/json.hpp>
-
-#include "dataset/core/constants.h"
-#include "dataset/util/path.h"
-#include "dataset/util/status.h"
-
-// Config settings for the client-side
-// example config file:
-// {
-//   "rowsPerBuffer": 3
-// }
-//
-
-namespace mindspore {
-namespace dataset {
-// The ConfigManager is a class for managing default values. When a user is constructing any objects
-// in the framework, often they may choose to omit some settings instead of overriding them.
-// This class manages some of the default values, for cases when the user does not manually specify
-// those values.
-class ConfigManager {
- public:
-  ConfigManager() = default;
-
-  // destructor
-  ~ConfigManager() = default;
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  void Print(std::ostream &out) const;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param cS - reference to the ConfigManager to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const ConfigManager &cS) {
-    cS.Print(out);
-    return out;
-  }
-
-  // Another debug print helper. Converts the print info to a string for you.
-  // @return The string version of the debug print
-  std::string ToString() {
-    std::stringstream ss;
-    ss << *this;
-    return ss.str();
-  }
-
-  // Loads a json file with the default settings and populates all the settings
-  // @param settingsFile - A json file with a set of default settings
-  // @return Status error code
-  Status LoadFile(const std::string &settingsFile);
-
-  // getter function
-  // @return The rows per buffer setting
-  int32_t rows_per_buffer() const { return rows_per_buffer_; }
-
-  // getter function
-  // @return The number of workers setting
-  int32_t num_parallel_workers() const { return num_parallel_workers_; }
-
-  // getter function
-  // @return The queue size of the operator's output connector
-  int32_t op_connector_size() const { return op_connector_size_; }
-
-  // getter function
-  // @return The internal worker-to-master connector queue size
-  int32_t worker_connector_size() const { return worker_connector_size_; }
-
-  // setter function
-  // @param rows_per_buffer - The setting to apply to the config
-  void set_rows_per_buffer(int32_t rows_per_buffer);
-
-  // setter function
-  // @param num_parallel_workers - The setting to apply to the config
-  void set_num_parallel_workers(int32_t num_parallel_workers);
-
-  // setter function
-  // @param connector_size - The setting to apply to the config
-  void set_worker_connector_size(int32_t connector_size);
-
-  // setter function
-  // @param connector_size - The setting to apply to the config
-  void set_op_connector_size(int32_t connector_size);
-
-  uint32_t seed() const;
-
-  // setter function
-  // @param seed - The default seed to use
-  void set_seed(uint32_t seed);
-
-  // setter function
-  // @param interval - The setting to apply to the config
-  void set_monitor_sampling_interval(uint32_t interval);
-
-  // getter function
-  // @return The interval of monitor sampling
-  int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
-
- private:
-  int32_t rows_per_buffer_{kCfgRowsPerBuffer};
-  int32_t num_parallel_workers_{kCfgParallelWorkers};
-  int32_t worker_connector_size_{kCfgWorkerConnectorSize};
-  int32_t op_connector_size_{kCfgOpConnectorSize};
-  uint32_t seed_{kCfgDefaultSeed};
-  uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
-
-  // Private helper function that takes a nlohmann json format and populates the settings
-  // @param j - The json nlohmann json info
-  Status FromJson(const nlohmann::json &j);
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_CORE_CONFIG_MANAGER_H_
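The deleted header wires Print(), a friend operator<<, and ToString() into one chain, so a single Print() implementation serves both streams and strings. A compact sketch of that trio (PrintableSketch is an invented name, not MindSpore code):

    #include <iostream>
    #include <sstream>
    #include <string>

    class PrintableSketch {
     public:
      void Print(std::ostream &out) const { out << "workers=" << workers_; }
      friend std::ostream &operator<<(std::ostream &out, const PrintableSketch &p) {
        p.Print(out);  // operator<< delegates to Print()
        return out;
      }
      std::string ToString() const {
        std::stringstream ss;
        ss << *this;  // ToString() reuses operator<<, which reuses Print()
        return ss.str();
      }

     private:
      int workers_ = 4;
    };

    int main() {
      PrintableSketch p;
      std::cout << p << "\n" << p.ToString() << "\n";
      return 0;
    }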
diff --git a/mindspore/ccsrc/dataset/core/cv_tensor.cc b/mindspore/ccsrc/dataset/core/cv_tensor.cc
deleted file mode 100644
index 16921e8b2d..0000000000
--- a/mindspore/ccsrc/dataset/core/cv_tensor.cc
+++ /dev/null
@@ -1,101 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/core/cv_tensor.h"
-
-#include
-#include
-
-#include "dataset/core/constants.h"
-#include "dataset/core/tensor.h"
-
-namespace mindspore {
-namespace dataset {
-CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) {
-  (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_);
-}
-
-CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) {
-  (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_);
-}
-
-CVTensor::CVTensor(std::shared_ptr<Tensor> tensor) : Tensor(std::move(*tensor)) {
-  (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_);
-}
-
-std::pair<std::array<int, 2>, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) {
-  std::array<int, 2> size = {1, 1};
-  if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) {
-    uint8_t ch = 1;
-    if (shape.Rank() == 3) {
-      ch = static_cast<uint8_t>(shape[2]);
-    }
-    if (shape.Rank() > 0) size[0] = static_cast<int>(shape[0]);
-    if (shape.Rank() > 1) size[1] = static_cast<int>(shape[1]);
-    if (type.AsCVType() == kCVInvalidType) return std::make_pair(size, -1);
-
-    int cv_type = CV_MAKETYPE(type.AsCVType(), ch);
-    return std::make_pair(size, cv_type);
-  }
-  return std::make_pair(size, -1);
-}
-
-std::shared_ptr<CVTensor> CVTensor::AsCVTensor(std::shared_ptr<Tensor> t) {
-  std::shared_ptr<CVTensor> cv_t = std::dynamic_pointer_cast<CVTensor>(t);
-  if (cv_t != nullptr) {
-    return cv_t;
-  } else {
-    return std::make_shared<CVTensor>(t);
-  }
-}
-
-Status CVTensor::MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat) {
-  std::pair<std::array<int, 2>, int> cv_shape_type = IsValidImage(shape, type);
-  if (cv_shape_type.second == -1) {
-    std::vector<dsize_t> sizes = shape.AsVector();
-    std::vector<int> sizes32(sizes.begin(), sizes.end());  // convert long to int for usage with OpenCV
-    if (static_cast<int>(shape.Rank()) != shape.Rank()) {
-      RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Wrong shape.");
-    }
-
-    uint8_t cv_type = type.AsCVType();
-    if (cv_type == kCVInvalidType) {
-      RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Invalid type.");
-    }
-    *mat = cv::Mat(static_cast<int>(shape.Rank()), &sizes32[0], cv_type, data);
-  } else {
-    *mat = cv::Mat(2, &(cv_shape_type.first[0]), cv_shape_type.second, data);
-  }
-  return Status::OK();
-}
-
-Status CVTensor::Reshape(const TensorShape &shape) {
-  RETURN_IF_NOT_OK(Tensor::Reshape(shape));
-  RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_));
-  return Status::OK();
-}
-
-Status CVTensor::ExpandDim(const dsize_t &axis) {
-  RETURN_IF_NOT_OK(Tensor::ExpandDim(axis));
-  RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_));
-  return Status::OK();
-}
-
-void CVTensor::Squeeze() {
-  Tensor::Squeeze();
-  (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_);
-}
-}  // namespace dataset
-}  // namespace mindspore
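CVTensor::MatInit above constructs a cv::Mat header over the tensor's existing buffer rather than copying it, which is why Reshape, ExpandDim, and Squeeze must each re-run MatInit after the shape changes. A standalone sketch of that zero-copy wrap (assumes OpenCV is installed; this is illustrative, not MindSpore code):

    #include <cstdint>
    #include <iostream>
    #include <opencv2/core.hpp>

    int main() {
      // A 2x3 single-channel uint8 "tensor" buffer owned by us, not by the Mat.
      uint8_t data[6] = {1, 2, 3, 4, 5, 6};
      int sizes[2] = {2, 3};
      cv::Mat mat(2, sizes, CV_8UC1, data);  // header only; shares the buffer
      mat.at<uint8_t>(1, 2) = 42;            // writes through to the original buffer
      std::cout << static_cast<int>(data[5]) << "\n";  // prints 42
      return 0;
    }

The same sharing is the reason AsCVTensor documents that the passed Tensor is invalidated: ownership of the buffer moves, it is never duplicated.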
Invalid type."); - } - *mat = cv::Mat(static_cast(shape.Rank()), &sizes32[0], cv_type, data); - } else { - *mat = cv::Mat(2, &(cv_shape_type.first[0]), cv_shape_type.second, data); - } - return Status::OK(); -} - -Status CVTensor::Reshape(const TensorShape &shape) { - RETURN_IF_NOT_OK(Tensor::Reshape(shape)); - RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); - return Status::OK(); -} - -Status CVTensor::ExpandDim(const dsize_t &axis) { - RETURN_IF_NOT_OK(Tensor::ExpandDim(axis)); - RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_)); - return Status::OK(); -} - -void CVTensor::Squeeze() { - Tensor::Squeeze(); - (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/cv_tensor.h b/mindspore/ccsrc/dataset/core/cv_tensor.h deleted file mode 100644 index 8c136f5f3c..0000000000 --- a/mindspore/ccsrc/dataset/core/cv_tensor.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_CV_TENSOR_H_ -#define DATASET_CORE_CV_TENSOR_H_ - -#include -#include -#include - -#include - -#include "./securec.h" - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { -class CVTensor : public Tensor { - public: - // Create an empty CVTensor of shape `shape` and type `type`. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - CVTensor(const TensorShape &shape, const DataType &type); - - // Create a CVTensor from a given buffer, shape and type. - // @note This constructor allocates a new space in the memory and copies the buffer into it. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - CVTensor(const TensorShape &shape, const DataType &type, const uchar *data); - - // Create a CVTensor from a given CV::Mat. - // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it. - // @param mat CV::Mat - explicit CVTensor(const cv::Mat &mat) - : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {} - - ~CVTensor() = default; - - // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor, - // this function would be treated as a no-op. Fot other tensor types, a new CVTensor is created based on the data - // provided. The Passed Tensor will be invalidated. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - // @return CVTensor - static std::shared_ptr AsCVTensor(std::shared_ptr tensor); - - // Create a CVTensor from a given tensor. 
The input tensor will be invalidated (i.e., the shape and type will be - // set to unknown and the data buffer will point to null. - // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. - // @param tensor - explicit CVTensor(std::shared_ptr tensor); - - // Getter function for the CV::Mat - // @return - cv::Mat mat() const { return mat_; } - - // Static function to check if the passed information (shape and type) can be treated as a valid description - // of an image in OpenCV. Moreover, it returns OpenCV shape and type - // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. - // In case of invalid shape or type, the function will return pair - // @param shape TensorShape - // @param type DataType - // @return std::pair of OpenCV shape and type - std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); - - Status Reshape(const TensorShape &shape) override; - - Status ExpandDim(const dsize_t &axis) override; - - void Squeeze() override; - - Status Mat(const std::vector &index, cv::Mat *mat) { - uchar *start = nullptr; - TensorShape remaining({-1}); - RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); - RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); - return Status::OK(); - } - - private: - cv::Mat mat_; - - // Initialize CV::Mat with the data_, shape_ and type_ - Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_CV_TENSOR_H_ diff --git a/mindspore/ccsrc/dataset/core/data_type.cc b/mindspore/ccsrc/dataset/core/data_type.cc deleted file mode 100644 index dd97c10bae..0000000000 --- a/mindspore/ccsrc/dataset/core/data_type.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/data_type.h" -#ifdef ENABLE_PYTHON -#include "dataset/core/pybind_support.h" -#endif - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { - -uint8_t DataType::SizeInBytes() const { - if (type_ < DataType::NUM_OF_TYPES) - return kTypeInfo[type_].sizeInBytes_; - else - return 0; -} - -#ifdef ENABLE_PYTHON -py::dtype DataType::AsNumpyType() const { - if (type_ < DataType::NUM_OF_TYPES) - return py::dtype(kTypeInfo[type_].pybindType_); - else - return py::dtype("unknown"); -} -#endif - -uint8_t DataType::AsCVType() const { - uint8_t res = kCVInvalidType; - if (type_ < DataType::NUM_OF_TYPES) { - res = kTypeInfo[type_].cvType_; - } - - if (res == kCVInvalidType) { - MS_LOG(ERROR) << "Cannot convert to OpenCV type. 
Return invalid type!"; - } - - return res; -} // namespace dataset - -DataType DataType::FromCVType(int cv_type) { - auto depth = static_cast(cv_type) & static_cast(CV_MAT_DEPTH_MASK); - switch (depth) { - case CV_8S: - return DataType(DataType::DE_INT8); - case CV_8U: - return DataType(DataType::DE_UINT8); - case CV_16S: - return DataType(DataType::DE_INT16); - case CV_16U: - return DataType(DataType::DE_UINT16); - case CV_32S: - return DataType(DataType::DE_INT32); - case CV_16F: - return DataType(DataType::DE_FLOAT16); - case CV_32F: - return DataType(DataType::DE_FLOAT32); - case CV_64F: - return DataType(DataType::DE_FLOAT64); - default: - MS_LOG(ERROR) << "Cannot convert from OpenCV type, unknown CV type. Unknown data type is returned!"; - return DataType(DataType::DE_UNKNOWN); - } -} - -DataType::DataType(const std::string &type_str) { - if (type_str == "bool") - type_ = DE_BOOL; - else if (type_str == "int8") - type_ = DE_INT8; - else if (type_str == "uint8") - type_ = DE_UINT8; - else if (type_str == "int16") - type_ = DE_INT16; - else if (type_str == "uint16") - type_ = DE_UINT16; - else if (type_str == "int32") - type_ = DE_INT32; - else if (type_str == "uint32") - type_ = DE_UINT32; - else if (type_str == "int64") - type_ = DE_INT64; - else if (type_str == "uint64") - type_ = DE_UINT64; - else if (type_str == "float16") - type_ = DE_FLOAT16; - else if (type_str == "float32") - type_ = DE_FLOAT32; - else if (type_str == "float64") - type_ = DE_FLOAT64; - else if (type_str == "string") - type_ = DE_STRING; - else - type_ = DE_UNKNOWN; -} - -std::string DataType::ToString() const { - if (type_ < DataType::NUM_OF_TYPES) - return kTypeInfo[type_].name_; - else - return "unknown"; -} - -#ifdef ENABLE_PYTHON -DataType DataType::FromNpArray(const py::array &arr) { - if (py::isinstance>(arr)) { - return DataType(DataType::DE_BOOL); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT8); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT8); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_INT64); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_UINT64); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT16); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT32); - } else if (py::isinstance>(arr)) { - return DataType(DataType::DE_FLOAT64); - } else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') { - return DataType(DataType::DE_STRING); - } else { - MS_LOG(ERROR) << "Cannot convert from numpy type. 
Unknown data type is returned!"; - return DataType(DataType::DE_UNKNOWN); - } -} - -std::string DataType::GetPybindFormat() const { - std::string res; - if (type_ < DataType::NUM_OF_TYPES) { - res = kTypeInfo[type_].pybindFormatDescriptor_; - } - - if (res.empty()) { - MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!"; - } - return res; -} -#endif - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/data_type.h b/mindspore/ccsrc/dataset/core/data_type.h deleted file mode 100644 index e15b6ed272..0000000000 --- a/mindspore/ccsrc/dataset/core/data_type.h +++ /dev/null @@ -1,350 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_DATA_TYPE_H_ -#define DATASET_CORE_DATA_TYPE_H_ - -#include - -#include -#ifdef ENABLE_PYTHON -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "dataset/core/pybind_support.h" -namespace py = pybind11; -#else -#include "Eigen/Core" -using float16 = Eigen::half; -#endif -#include "dataset/core/constants.h" -namespace mindspore { -namespace dataset { - -// Class that represents basic data types in DataEngine. -class DataType { - public: - enum Type : uint8_t { - DE_UNKNOWN = 0, - DE_BOOL, - DE_INT8, - DE_UINT8, - DE_INT16, - DE_UINT16, - DE_INT32, - DE_UINT32, - DE_INT64, - DE_UINT64, - DE_FLOAT16, - DE_FLOAT32, - DE_FLOAT64, - DE_STRING, - NUM_OF_TYPES - }; - - struct TypeInfo { - const char *name_; // name to be represent the type while printing - const uint8_t sizeInBytes_; // number of bytes needed for this type - const char *pybindType_; // Python matching type, used in get_output_types - const std::string pybindFormatDescriptor_; // pybind format used for numpy types - const uint8_t cvType_; // OpenCv matching type - }; - -#ifdef ENABLE_PYTHON - static inline const TypeInfo kTypeInfo[] = { - // name, sizeInBytes, pybindTypem formatDescriptor, openCV - {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN - {"bool", 1, "bool", py::format_descriptor::format(), CV_8U}, // DE_BOOL - {"int8", 1, "int8", py::format_descriptor::format(), CV_8S}, // DE_INT8 - {"uint8", 1, "uint8", py::format_descriptor::format(), CV_8U}, // DE_UINT8 - {"int16", 2, "int16", py::format_descriptor::format(), CV_16S}, // DE_INT16 - {"uint16", 2, "uint16", py::format_descriptor::format(), CV_16U}, // DE_UINT16 - {"int32", 4, "int32", py::format_descriptor::format(), CV_32S}, // DE_INT32 - {"uint32", 4, "uint32", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT32 - {"int64", 8, "int64", py::format_descriptor::format(), kCVInvalidType}, // DE_INT64 - {"uint64", 8, "uint64", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT64 - {"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16 - {"float32", 4, "float32", py::format_descriptor::format(), CV_32F}, // DE_FLOAT32 - {"float64", 8, "double", py::format_descriptor::format(), CV_64F}, // DE_FLOAT64 - {"string", 0, "bytes", "S", 
kCVInvalidType} // DE_STRING - }; -#else - static inline const TypeInfo kTypeInfo[] = { - // name, sizeInBytes, pybindTypem formatDescriptor, openCV - {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN - {"bool", 1, "bool", "", CV_8U}, // DE_BOOL - {"int8", 1, "int8", "", CV_8S}, // DE_INT8 - {"uint8", 1, "uint8", "", CV_8U}, // DE_UINT8 - {"int16", 2, "int16", "", CV_16S}, // DE_INT16 - {"uint16", 2, "uint16", "", CV_16U}, // DE_UINT16 - {"int32", 4, "int32", "", CV_32S}, // DE_INT32 - {"uint32", 4, "uint32", "", kCVInvalidType}, // DE_UINT32 - {"int64", 8, "int64", "", kCVInvalidType}, // DE_INT64 - {"uint64", 8, "uint64", "", kCVInvalidType}, // DE_UINT64 - {"float16", 2, "float16", "", CV_16F}, // DE_FLOAT16 - {"float32", 4, "float32", "", CV_32F}, // DE_FLOAT32 - {"float64", 8, "double", "", CV_64F}, // DE_FLOAT64 - {"string", 0, "bytes", "", kCVInvalidType} // DE_STRING - }; -#endif - - // No arg constructor to create an unknown shape - DataType() : type_(DE_UNKNOWN) {} - - // Create a type from a given string - /// \param type_str - explicit DataType(const std::string &type_str); - - // Default destructor - ~DataType() = default; - - // Create a type from a given enum - /// \param d - constexpr explicit DataType(Type d) : type_(d) {} - - constexpr bool operator==(const DataType a) const { return type_ == a.type_; } - - constexpr bool operator==(const Type a) const { return type_ == a; } - - constexpr bool operator!=(const DataType a) const { return type_ != a.type_; } - - constexpr bool operator!=(const Type a) const { return type_ != a; } - - // Disable this usage `if(d)` where d is of type DataType - /// \return - operator bool() = delete; - - // To be used in Switch/case - /// \return - operator Type() const { return type_; } - - // The number of bytes needed to store one value of this type - /// \return - uint8_t SizeInBytes() const; - - // Convert from DataType to OpenCV type - /// \return - uint8_t AsCVType() const; - - // Convert from OpenCV type to DataType - /// \param cv_type - /// \return - static DataType FromCVType(int cv_type); - - // Returns a string representation of the type - /// \return - std::string ToString() const; - - // returns true if the template type is the same as the Tensor type_ - /// \tparam T - /// \return true or false - template - bool IsCompatible() const { - return type_ == FromCType(); - } - - // returns true if the template type is the same as the Tensor type_ - /// \tparam T - /// \return true or false - template - bool IsLooselyCompatible() const; - - // << Stream output operator overload - /// \notes This allows you to print the info using stream operators - /// \param out - reference to the output stream being overloaded - /// \param rO - reference to the DataType to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DataType &so) { - out << so.ToString(); - return out; - } - - template - static DataType FromCType(); - -#ifdef ENABLE_PYTHON - // Convert from DataType to Pybind type - /// \return - py::dtype AsNumpyType() const; - - // Convert from NP type to DataType - /// \param type - /// \return - static DataType FromNpType(const py::dtype &type); - - // Convert from NP array to DataType - /// \param py array - /// \return - static DataType FromNpArray(const py::array &arr); -#endif - - // Get the buffer string format of the current type. Used in pybind buffer protocol. 
- /// \return - std::string GetPybindFormat() const; - - bool IsSignedInt() const { - return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 || - type_ == DataType::DE_INT64; - } - - bool IsUnsignedInt() const { - return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 || - type_ == DataType::DE_UINT64; - } - - bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); } - - bool IsFloat() const { - return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64; - } - - bool IsBool() const { return type_ == DataType::DE_BOOL; } - - bool IsNumeric() const { return type_ != DataType::DE_STRING; } - - Type value() const { return type_; } - - private: - Type type_; -}; - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_BOOL); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_FLOAT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT64); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT32); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT16); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_INT8); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_UINT8); -} - -template <> -inline DataType DataType::FromCType() { - return DataType(DataType::DE_STRING); -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_BOOL; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT32; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_FLOAT16; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || - type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || - type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT16 
|| type_ == DataType::DE_UINT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_INT8; -} - -template <> -inline bool DataType::IsLooselyCompatible() const { - return type_ == DataType::DE_UINT8; -} -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_DATA_TYPE_H_ diff --git a/mindspore/ccsrc/dataset/core/global_context.cc b/mindspore/ccsrc/dataset/core/global_context.cc deleted file mode 100644 index 3de8e0fcd8..0000000000 --- a/mindspore/ccsrc/dataset/core/global_context.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/global_context.h" - -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/tensor.h" -#include "dataset/util/allocator.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/system_pool.h" - -namespace mindspore { -namespace dataset { -// Global static pointer for the singleton GlobalContext -std::unique_ptr GlobalContext::global_context_ = nullptr; -std::once_flag GlobalContext::init_instance_flag_; - -constexpr int GlobalContext::kArenaSize; -constexpr int GlobalContext::kMaxSize; -constexpr bool GlobalContext::kInitArena; - -// Singleton initializer -GlobalContext *GlobalContext::Instance() { - // If the single global context is not created yet, then create it. Otherwise the - // existing one is returned. - std::call_once(init_instance_flag_, []() { - global_context_.reset(new GlobalContext()); - Status rc = global_context_->Init(); - if (rc.IsError()) { - std::terminate(); - } - }); - return global_context_.get(); -} - -Status GlobalContext::Init() { - config_manager_ = std::make_shared(); - mem_pool_ = std::make_shared(); - // For testing we can use Dummy pool instead - - // Create some tensor allocators for the different types and hook them into the pool. - tensor_allocator_ = std::make_unique>(mem_pool_); - cv_tensor_allocator_ = std::make_unique>(mem_pool_); - int_allocator_ = std::make_unique(mem_pool_); - return Status::OK(); -} - -// A print method typically used for debugging -void GlobalContext::Print(std::ostream &out) const { - out << "GlobalContext contains the following default config: " << *config_manager_ << "\n"; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/global_context.h b/mindspore/ccsrc/dataset/core/global_context.h deleted file mode 100644 index ee0cbfbbe0..0000000000 --- a/mindspore/ccsrc/dataset/core/global_context.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_GLOBAL_CONTEXT_H_ -#define DATASET_CORE_GLOBAL_CONTEXT_H_ - -#include -#include - -#include "dataset/core/constants.h" -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// forward declare -class MemoryPool; -class ConfigManager; -class Tensor; -class CVTensor; - -using TensorAlloc = Allocator; // An allocator for Tensors -using CVTensorAlloc = Allocator; // An allocator CVTensors -using IntAlloc = Allocator; - -class GlobalContext { - // some consts for pool config - static constexpr int kArenaSize = 128; - static constexpr int kMaxSize = -1; - static constexpr bool kInitArena = true; - - public: - // Singleton pattern. This method either: - // - creates the single version of the GlobalContext for the first time and returns it - // OR - // - returns the already existing single instance of the GlobalContext - // @return the single global context - static GlobalContext *Instance(); - - // Destructor - ~GlobalContext() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - void Print(std::ostream &out) const; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param g_c - reference to the GlobalContext to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const GlobalContext &g_c) { - g_c.Print(out); - return out; - } - - // Getter method - // @return the client config as raw const pointer - static std::shared_ptr config_manager() { return Instance()->config_manager_; } - - // Getter method - // @return the mem pool - std::shared_ptr mem_pool() const { return mem_pool_; } - - // Getter method - // @return the tensor allocator as raw pointer - const TensorAlloc *tensor_allocator() const { return tensor_allocator_.get(); } - - // Getter method - // @return the CVTensor allocator as raw pointer - const CVTensorAlloc *cv_tensor_allocator() const { return cv_tensor_allocator_.get(); } - - // Getter method - // @return the integer allocator as raw pointer - const IntAlloc *int_allocator() const { return int_allocator_.get(); } - - private: - // Constructor. - // @note Singleton. Instantiation flows through instance() - // @return This is a constructor. 
- GlobalContext() = default; - - Status Init(); - - static std::once_flag init_instance_flag_; - static std::unique_ptr global_context_; // The instance of the singleton (global) - std::shared_ptr mem_pool_; // A global memory pool - std::shared_ptr config_manager_; // The configs - std::unique_ptr tensor_allocator_; // An allocator for Tensors - std::unique_ptr cv_tensor_allocator_; // An allocator for CV Tensors - std::unique_ptr int_allocator_; // An allocator for ints -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CORE_GLOBAL_CONTEXT_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc deleted file mode 100644 index eda5239852..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ /dev/null @@ -1,1034 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/core/tensor.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/constants.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/global_context.h" -#ifdef ENABLE_PYTHON -#include "dataset/core/pybind_support.h" -namespace py = pybind11; -#endif -#include "dataset/core/tensor_shape.h" - -namespace mindspore { -namespace dataset { -// Helper macros for printing tensor elements -#define CASE_PRINT(de_type, native_type) \ - case de_type: { \ - native_type o; \ - rc = GetItemAt(&o, index); \ - out << o; \ - break; \ - } - -#define CASE_PRINT_HEX(de_type, native_type) \ - case de_type: { \ - native_type o; \ - rc = GetItemAt(&o, index); \ - out << std::hex << std::setw(2) << std::setfill('0') << o << std::dec << std::setfill(' '); \ - break; \ - } - -Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), type_(type), data_(nullptr) { - // grab the mem pool from global context and create the allocator for char data area - std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - data_allocator_ = std::make_unique>(global_pool); -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { - if (type.IsNumeric()) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Given the shape/type of this tensor, compute the data size and copy in the input bytes. 
- int64_t byte_size = this->SizeInBytes(); - Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself - if (s.IsOk() && data_ != nullptr) { - int ret_code = memcpy_s(data_, byte_size, data, byte_size); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } else { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - } - } else { - MS_LOG(ERROR) << "Type should be numeric to use this constructor."; - } -} - -Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) - : Tensor(shape, type) { - // If the data pointer was given, then we can also populate the tensor with data - if (data != nullptr) { - // Allocates data_ inside itself - Status s = AllocateBuffer(length); - if (s.IsError()) { - MS_LOG(ERROR) << "Failed to create memory for Tensor!"; - } - if (data_ != nullptr) { - int ret_code = memcpy_s(data_, length, data, length); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor!"; - } - } - } -} - -Tensor::Tensor(Tensor &&other) noexcept - : shape_(other.shape()), - type_(other.type()), - data_(other.GetMutableBuffer()), - data_allocator_(std::move(other.data_allocator_)) { - other.Invalidate(); -} - -Tensor &Tensor::operator=(Tensor &&other) noexcept { - if (&other != this) { - shape_ = other.shape(); - type_ = other.type(); - data_ = other.GetMutableBuffer(); - data_end_ = other.data_end_; - data_allocator_ = std::move(other.data_allocator_); - other.Invalidate(); - } - return *this; -} - -Tensor::Tensor(const std::vector &strings, const TensorShape &shape) - : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { - auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; - dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); - - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (const auto &str : strings) { - // insert the start index of the string. - offset_arr[i++] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; - // next string will be stored right after the current one. 
- offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; - } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - this->data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); -} - -Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) - : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { - // total bytes needed = offset array + strings - // offset array needs to store one offset var per element + 1 extra to get the length of the last string. - // strings will be null-terminated --> need 1 extra byte per element - dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); - - data_ = data_allocator_->allocate(num_bytes); - - auto offset_arr = reinterpret_cast(data_); - uchar *buf = GetStringsBuffer(); - - offset_t offset = buf - data_; // the first string will start here - uint32_t i = 0; - for (; i < bytes_list.value_size(); i++) { - const std::string &str = bytes_list.value(i); - // insert the start index of the string. - offset_arr[i] = offset; - // total bytes are reduced by kOffsetSize - num_bytes -= kOffsetSize; - // insert actual string - int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); - if (ret_code != 0) { - MS_LOG(ERROR) << "Cannot copy string into Tensor"; - } - // next string will be stored right after the current one. - offset = offset + str.length() + 1; - // total bytes are reduced by the length of the string - num_bytes -= str.length() + 1; - } - // store one more offset value so we can get the length of the last string - // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] - offset_arr[i] = offset; - - data_end_ = data_ + offset_arr[i]; - - MS_ASSERT(num_bytes == 0); - if (shape.known()) Tensor::Reshape(shape); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, - DataType type, const unsigned char *data) { - if (!shape.known()) { - RETURN_STATUS_UNEXPECTED("Invalid shape."); - } - if (type == DataType::DE_UNKNOWN) { - RETURN_STATUS_UNEXPECTED("Invalid data type."); - } - - switch (tensor_impl) { - case TensorImpl::kFlexible: { - // The flex tensor is really just the base class tensor implementation - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - case TensorImpl::kCv: { - const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); - *ptr = std::allocate_shared(*alloc, shape, type, data); - break; - } - default: { - std::string err_msg("Invalid tensor implementation type."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); // returns base-class shared_ptr -} - -#ifdef ENABLE_PYTHON -Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { - std::vector shape; - for (dsize_t i = 0; i < arr.ndim(); i++) { - shape.push_back(static_cast(arr.shape()[i])); - } - arr.resize({arr.size()}); // flatten the py::array so we can iterate once - std::vector strings; - - if (arr.dtype().kind() == 'U') { - std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); - } else { - 
std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); - } - - arr.resize(shape); // resize arr back to the original shape - - return CreateTensor(ptr, strings, TensorShape{shape}); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { - if (DataType::FromNpArray(arr) == DataType::DE_STRING) { - return CreateTensorFromNumpyString(ptr, arr); - } - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); - - std::vector shape; - for (dsize_t i = 0; i < arr.ndim(); i++) { - shape.push_back(static_cast(arr.shape()[i])); - } - - (*ptr)->shape_ = TensorShape(shape); - (*ptr)->type_ = DataType::FromNpArray(arr); - if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); - - if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); - - std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); - (*ptr)->data_allocator_ = std::make_unique>(global_pool); - int64_t byte_size = (*ptr)->SizeInBytes(); - RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); - - unsigned char *data = static_cast(arr.request().ptr); - if ((*ptr)->data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); - } - - std::vector strides; - for (dsize_t i = 0; i < arr.ndim(); i++) { - strides.push_back(static_cast(arr.strides()[i])); - } - - // check if strides are contiguous - bool is_strided = false; - dsize_t count = (*ptr)->shape_.NumOfElements(); - for (size_t i = 0; i < shape.size(); i++) { - count /= shape[i]; - if (strides[i] != (*ptr)->type_.SizeInBytes() * count) { - is_strided = true; - break; - } - } - - if (is_strided) { - RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); - } else { - int ret_code = memcpy_s((*ptr)->data_, byte_size, data, byte_size); - if (ret_code != 0) { - RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); - } - } - - return Status::OK(); // returns base-class shared_ptr -} -#endif - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, strings, shape); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *ptr = std::allocate_shared(*alloc, bytes_list, shape); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { - std::ifstream fs; - fs.open(file_path, std::ios::binary | std::ios::in); - CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); - int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); - CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); - RETURN_IF_NOT_OK( - Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); - int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); - CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); - fs.close(); - return Status::OK(); -} - -Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, - const 
TensorShape &shape, const DataType &type, dsize_t pad_size) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); - - unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); - int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; - - for (int i = 0; i < bytes_list.value_size(); i++) { - // read string data into tensor - const std::string ¤t_element = bytes_list.value(i); - int return_code = - memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size()); - - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor"); - - current_tensor_addr += current_element.size(); - tensor_bytes_remaining -= current_element.size(); - - // pad - int64_t chars_to_pad = pad_size - current_element.size(); - return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast(' '), chars_to_pad); - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor"); - - current_tensor_addr += chars_to_pad; - tensor_bytes_remaining -= chars_to_pad; - } - - return Status::OK(); -} - -// Memcpy the given strided array's used part to consecutive memory -// Consider a 3-d array -// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]] -// Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) -Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, - std::vector strides, uint8_t type_size) { - dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - for (dsize_t i = 0; i < size; ++i) { - dsize_t offset = 0; - dsize_t count = i; - for (size_t j = 0; j < shape.size(); ++j) { - // convert 1d array's index to 3d array's index (A -> B) - dsize_t idx = count % shape[shape.size() - 1 - j]; - count /= shape[shape.size() - 1 - j]; - // calculate the raw data offset based on strides (B -> C) - offset += idx * strides[shape.size() - 1 - j]; - // once count = 0, the following idxes are all zero, skip them - if (count == 0) break; - } - // strides already consider byte size of the data type, but dst doesn't. - // dst[i] = dst + i * type_size = src + offset - int ret_code = memcpy_s(dst + i * type_size, type_size, src + offset, type_size); - if (ret_code != 0) { - RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); - } - } - return Status::OK(); -} - -// Name: Destructor -// Description: Destructor -Tensor::~Tensor() { - if (data_ != nullptr) { - if (data_allocator_ != nullptr) { - data_allocator_->deallocate(data_); - data_ = nullptr; - data_end_ = nullptr; - } else { - // If we didn't have an allocator, but data_ is not null then it must - // be a stand-alone tensor that used malloc directly. - free(data_); - data_ = nullptr; - data_end_ = nullptr; - } - } -} - -bool Tensor::operator==(const Tensor &rhs) const { - // 1. different shape 2. different type 3. 
-// Name: Destructor -// Description: Destructor -Tensor::~Tensor() { - if (data_ != nullptr) { - if (data_allocator_ != nullptr) { - data_allocator_->deallocate(data_); - data_ = nullptr; - data_end_ = nullptr; - } else { - // If we didn't have an allocator, but data_ is not null then it must - // be a stand-alone tensor that used malloc directly. - free(data_); - data_ = nullptr; - data_end_ = nullptr; - } - } -} - -bool Tensor::operator==(const Tensor &rhs) const { - // 1. different shape 2. different type 3. one data_ is nullptr and the other is not - if (shape_ != rhs.shape() || type_ != rhs.type_ || (data_ == nullptr && rhs.data_ != nullptr) || - (data_ != nullptr && rhs.data_ == nullptr)) { - return false; - } - if (data_ == nullptr && rhs.data_ == nullptr) { - return true; - } - // use mem compare to compare the two data, sizes are already verified - return memcmp(data_, rhs.data_, SizeInBytes()) == 0; -} - -// Name: PrintItemAt() -// Description: A function that prints the value as specified by its index -void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) const { - Status rc; - MS_ASSERT(data_); - - switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, bool); - - CASE_PRINT_HEX(DataType::DE_INT8, int8_t); - - CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); - - CASE_PRINT(DataType::DE_INT16, int16_t); - - CASE_PRINT(DataType::DE_UINT16, uint16_t); - - CASE_PRINT(DataType::DE_INT32, int32_t); - - CASE_PRINT(DataType::DE_UINT32, uint32_t); - - CASE_PRINT(DataType::DE_INT64, int64_t); - - CASE_PRINT(DataType::DE_UINT64, uint64_t); - - CASE_PRINT(DataType::DE_FLOAT16, float16); - - CASE_PRINT(DataType::DE_FLOAT32, float); - - CASE_PRINT(DataType::DE_FLOAT64, double); - - case DataType::DE_STRING: { - std::string_view o{""}; - rc = GetItemAt(&o, index); - out << "\"" << o << "\""; - break; - } - default: { - out << "?"; - break; - } - } - if (rc.IsError()) { - out << rc.ToString(); - } -} - -// Name: PrintRecursive() -// Description: A function that prints Tensor recursively, first called by print -void Tensor::PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector<dsize_t> &cur_index) const { - if (cur_index.size() == shape_.Rank()) { - PrintItemAt(cur_index, out); - } else { - out << "["; - for (dsize_t i = 0; i < shape_[cur_dim]; i++) { - std::vector<dsize_t> new_index = cur_index; - new_index.push_back(i); - PrintRecursive(out, cur_dim + 1, new_index); - if (i < shape_[cur_dim] - 1) { - out << ","; - } - } - out << "]"; - } -} - -// Name: Print() -// Description: A function that prints info about the tensor -void Tensor::Print(std::ostream &out) const { - out << "Tensor (shape: "; - out << shape_; - out << ", Type: " << type_ << ")\n"; - if (data_) { - PrintRecursive(out, 0, std::vector<dsize_t>{}); - } else { - out << "[Data area is null]"; - } -} -Status Tensor::AllocateBuffer(const dsize_t &length) { - if (data_ == nullptr) { - if (data_allocator_ != nullptr) { - data_ = data_allocator_->allocate(length); - RETURN_UNEXPECTED_IF_NULL(data_); - data_end_ = data_ + length; - } else { - data_ = static_cast<unsigned char *>(malloc(length)); - data_end_ = data_ + length; - RETURN_UNEXPECTED_IF_NULL(data_); - } - } - return Status::OK(); -} -const unsigned char *Tensor::GetBuffer() const { - // This version cannot modify anything. data_ could possibly be null. - return data_; -} - -// check for empty; note the inverted sense of the name: returns true when no data is allocated, per the header doc -bool Tensor::HasData() const { return data_ == nullptr; } - -unsigned char *Tensor::GetMutableBuffer() { - if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { - return nullptr; - } - // If the data area is already created, return the pointer to it - if (data_ != nullptr) { - return data_; - } else { - // If the data area is not created, then identify the memory size based - // on the shape and type and allocate it. 
- if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { - return data_; - } else { - return nullptr; - } - } -} - -Status Tensor::Reshape(const TensorShape &shape) { - if (shape.NumOfElements() == shape_.NumOfElements()) { - shape_ = shape; - return Status::OK(); - } else { - std::string err = "Cannot reshape, Number of elements do not match"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -void Tensor::Invalidate() { - shape_ = TensorShape::CreateUnknownRankShape(); - type_ = DataType(DataType::DE_UNKNOWN); - data_ = nullptr; - data_end_ = nullptr; - data_allocator_ = nullptr; -} - -template -Status Tensor::GetItemPtr(T **ptr, const std::vector &index) const { - if (type_.IsCompatible()) { - if (data_ == nullptr) { - std::string err = "Data is not allocated yet"; - RETURN_STATUS_UNEXPECTED(err); - } - dsize_t flat_idx; - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); - *ptr = reinterpret_cast(data_ + flat_idx * type_.SizeInBytes()); - - return Status::OK(); - } else { - std::string err = "data type not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -Status Tensor::GetItemPtr(uchar **ptr, const std::vector &index, offset_t *length) const { - if (type_ == DataType::DE_STRING) { - if (data_ == nullptr) { - std::string err = "Data is not allocated yet"; - RETURN_STATUS_UNEXPECTED(err); - } - dsize_t flat_idx; - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); - offset_t length_temp = 0; - RETURN_IF_NOT_OK(GetStringAt(flat_idx, ptr, &length_temp)); - if (length != nullptr) *length = length_temp; - return Status::OK(); - } else { - std::string err = "data type not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } -} - -Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining) { - if (type() == DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); - } - - dsize_t flat_ind; - std::vector t_shape = shape().AsVector(); - std::vector r(t_shape.begin() + ind.size(), t_shape.end()); - *remaining = TensorShape(r); - ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); - - RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); - // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only - // be true is the tensor failed to allocate memory. - if (GetMutableBuffer() == nullptr) { - RETURN_STATUS_UNEXPECTED("Invalid GetBuffer in Tensor, got nullptr"); - } - *start_addr_of_index = GetMutableBuffer() + flat_ind * this->type().SizeInBytes(); - return Status::OK(); -} - -Status Tensor::InsertTensor(const std::vector &ind, const std::shared_ptr &tensor) { - std::string err_msg; - err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; - uchar *start_addr_of_ind = nullptr; - TensorShape remaining_shape({-1}); - err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : ""; - err_msg += !(remaining_shape == tensor->shape()) ? 
"[Tensor] memory error\n" : ""; - if (!err_msg.empty()) { - MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - if (start_addr_of_ind != nullptr) { - int ret_code = - memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); - if (ret_code == 0) { - return Status::OK(); - } else { - err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; - MS_LOG(DEBUG) << "Tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } else { - RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); - } - } -} - -Status Tensor::Concatenate(const std::vector &index, const std::shared_ptr &tensor) { - std::string err_msg; - err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation \n" : ""; - err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : ""; - err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : ""; - - err_msg += - (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : ""; - err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : ""; - uchar *start_addr_of_ind = nullptr; - - TensorShape remaining_shape = tensor->shape(); - StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape); - err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : ""; - - if (!err_msg.empty()) { - MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; - - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - int ret_code = - memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); - - if (ret_code == 0) { - return Status::OK(); - } else { - err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; - MS_LOG(DEBUG) << "Tensor message: " << err_msg; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } -} - -Status Tensor::ExpandDim(const dsize_t &axis) { - if (axis > Rank()) { - std::string err = "Axis is out of bound"; - RETURN_STATUS_UNEXPECTED(err); - } - if (axis == Rank()) { - shape_ = shape_.AppendDim(1); - } else { - shape_ = shape_.InsertDim(axis, 1); - } - return Status::OK(); -} - -std::vector Tensor::Strides() { - std::vector strides = shape_.Strides(); - uint8_t size = type_.SizeInBytes(); - std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); - return strides; -} - -#ifdef ENABLE_PYTHON -Status Tensor::GetBufferInfo(Tensor *t, py::buffer_info *out) { - RETURN_UNEXPECTED_IF_NULL(t); - CHECK_FAIL_RETURN_UNEXPECTED(t->type().IsNumeric(), "Cannot use GetBufferInfo on tensor of strings."); - - std::string format_desc = t->type().GetPybindFormat(); - if (format_desc.empty()) { - RETURN_STATUS_UNEXPECTED("Cannot convert DE type tp pybind format"); - } - *out = py::buffer_info(t->GetMutableBuffer(), /* Pointer to buffer */ - t->type().SizeInBytes(), /* Size of one scalar */ - format_desc, /* Python struct-style format descriptor */ - t->Rank(), /* Number of dimensions */ - t->shape().AsVector(), /* Buffer dimensions */ - t->Strides()); - return Status::OK(); -} -#endif - -template -Status Tensor::GetItemAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - 
} - if (type_.IsUnsignedInt()) { - RETURN_IF_NOT_OK(GetUnsignedIntAt(o, index)); - } else if (type_.IsSignedInt()) { - RETURN_IF_NOT_OK(GetSignedIntAt(o, index)); - } else if (type_.IsFloat()) { - RETURN_IF_NOT_OK(GetFloatAt(o, index)); - } else if (type_.IsBool()) { - bool *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - } else { - std::string err = "Tensor Type is unknown"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) const { - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(o); - CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string"); - - uchar *start = nullptr; - offset_t length = 0; - RETURN_IF_NOT_OK(GetItemPtr(&start, index, &length)); - std::string_view sv{reinterpret_cast(start)}; - o->swap(sv); - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -// return data as numpy, should return status -Status Tensor::GetDataAsNumpy(py::array *data) { - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(data); - if (type_ == DataType::DE_BOOL) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT8) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_INT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT8) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_UINT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT16) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT32) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_FLOAT64) { - *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); - } else if (type_ == DataType::DE_STRING) { - GetDataAsNumpyStrings(data); - } else { - RETURN_STATUS_UNEXPECTED("Got unexpected type when returning numpy"); - } - return Status::OK(); -} -Status Tensor::GetDataAsNumpyStrings(py::array *data) { - auto itr = begin(); - uint64_t max = 0; - for (; itr != end(); itr++) { - max = std::max((*itr).length(), max); - } - // if all strings are empty, numpy stores a byte for each string |S1 - max = (max == 0 ? 
1 : max); - uint64_t total_size = shape_.NumOfElements() * max; - char *tmp_data = reinterpret_cast(data_allocator_->allocate(total_size)); - if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array."); - int ret_code = memset_s(tmp_data, total_size, 0, total_size); - CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to initialize temp memory"); - - itr = begin(); - uint64_t i = 0; - for (; itr != end(); itr++, i++) { - if (!(*itr).empty()) { - ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length()); - CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data."); - } - } - auto strides = shape_.Strides(); - std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; }); - *data = py::array(py::dtype("S" + std::to_string(max)), shape_.AsVector(), strides, tmp_data); - data_allocator_->deallocate(reinterpret_cast(tmp_data)); - return Status::OK(); -} -#endif - -void Tensor::Squeeze() { shape_ = shape_.Squeeze(); } - -template -Status Tensor::GetUnsignedIntAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_UINT8: { - uint8_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT16: { - uint16_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT32: { - uint32_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_UINT64: { - uint64_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not an unsigned Integer"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -template -Status Tensor::GetSignedIntAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_INT8: { - int8_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT16: { - int16_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT32: { - int32_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_INT64: { - int64_t *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not a signed Integer"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} - -template -Status Tensor::GetFloatAt(T *o, const std::vector &index) const { - if (data_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); - } - if (!type_.IsLooselyCompatible()) { - std::string err = "Template type and Tensor type are not compatible"; - RETURN_STATUS_UNEXPECTED(err); - } - switch (type_.value()) { - case DataType::DE_FLOAT16: { - float16 *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o 
= static_cast(*ptr); - break; - } - case DataType::DE_FLOAT32: { - float *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - case DataType::DE_FLOAT64: { - double *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *o = static_cast(*ptr); - break; - } - default: - std::string err = "Tensor Type is not a float/double"; - RETURN_STATUS_UNEXPECTED(err); - } - return Status::OK(); -} -Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const { - CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not string"); - RETURN_UNEXPECTED_IF_NULL(data_); - RETURN_UNEXPECTED_IF_NULL(string_start); - RETURN_UNEXPECTED_IF_NULL(length); - auto *offset_ptr = reinterpret_cast(data_); // offsets starts here - offset_t start = offset_ptr[index]; - *string_start = data_ + start; - *length = offset_ptr[index + 1] - start - 1; // -1 to skip the \0 from the string length - return Status::OK(); -} -Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vector &index) { - CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type"); - CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0"); - - uint8_t type_size = type_.SizeInBytes(); - size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size; - dsize_t src_flat_ind = 0, dst_flat_ind = 0; - RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind)); - RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind)); - - const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size; - unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size; - CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); - return Status::OK(); -} -Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { - CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); - CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); - if (type_.IsNumeric()) { - return SliceNumeric(out, indices); - } else { - return SliceString(out, indices); - } -} -Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { - RETURN_IF_NOT_OK( - CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); - (*out)->GetMutableBuffer(); - dsize_t out_index = 0; - dsize_t dim_length = shape_[0]; - dsize_t type_size = type_.SizeInBytes(); - dsize_t src_start = HandleNeg(indices[0], dim_length); - uchar *dst_addr = (*out)->data_; - dsize_t count = 1; - - for (dsize_t i = 0; i < indices.size(); i++) { - dsize_t cur_index = HandleNeg(indices[i], dim_length); - CHECK_FAIL_RETURN_UNEXPECTED( - cur_index >= 0 && cur_index < dim_length, - "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); - if (i < indices.size() - 1) { - dsize_t next_index = HandleNeg(indices[i + 1], dim_length); - if (next_index == cur_index + 1) { - count++; - continue; - } - } - int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, - count * type_size); - CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); - out_index += count; - if (i < indices.size() - 1) { - src_start = HandleNeg(indices[i + 1], dim_length); // next index - } - count = 1; - } - return Status::OK(); -} -Status Tensor::SliceString(std::shared_ptr *out, const 
std::vector &indices) { - dsize_t dim_length = shape_[0]; - std::vector strings; - for (dsize_t index : indices) { - dsize_t cur_index = HandleNeg(index, dim_length); - CHECK_FAIL_RETURN_UNEXPECTED( - cur_index >= 0 && cur_index < dim_length, - "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); - std::string_view sv; - GetItemAt(&sv, {cur_index}); - strings.emplace_back(sv); - } - return CreateTensor(out, strings); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h deleted file mode 100644 index 337535a2c3..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor.h +++ /dev/null @@ -1,668 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_TENSOR_H_ -#define DATASET_CORE_TENSOR_H_ - -#include -#include -#include -#include -#include "./securec.h" -#include "utils/log_adapter.h" -#if defined(_WIN32) || defined(_WIN64) -#undef HAVE_STDDEF_H -#undef HAVE_STDLIB_H -#endif - -#ifdef ENABLE_PYTHON -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#endif - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/util/status.h" -#include "proto/example.pb.h" - -#ifdef ENABLE_PYTHON -namespace py = pybind11; -#endif -namespace mindspore { -namespace dataset { -class Tensor; -template -class Allocator; - -using CharAllocPtr = std::unique_ptr>; -using TensorAllocPtr = std::shared_ptr>; // An allocator shared_ptr for Tensors - -class Tensor { - public: - Tensor() = delete; - - // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor. - // @note The shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - Tensor(const TensorShape &shape, const DataType &type); - - // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor. - // @note The buffer should be valid and the shape and type information should be known and valid. - // @param shape TensorShape - // @param type DataType - // @param data unsigned char*, pointer to the data. - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data); - - Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length); - - Tensor(const Tensor &other) = delete; - - Tensor &operator=(const Tensor &other) = delete; - - Tensor(Tensor &&other) noexcept; - - Tensor &operator=(Tensor &&other) noexcept; - - Status AllocateBuffer(const dsize_t &length); - - // type of offest values to store strings information - using offset_t = uint32_t; - // const of the size of the offset variable - static constexpr uint8_t kOffsetSize = sizeof(offset_t); - // Tensor base class which holds the data in an unsigned char* buffer. 
- - // Construct a scalar string Tensor - explicit Tensor(const std::string &str) : Tensor(std::vector<std::string>{str}, TensorShape::CreateScalar()) {} - - // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is - // the size of the vector `strings`. - // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings. - // The offset array will store one extra value to find the length of the last string. - // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn - // The value of each offset is the start index of the corresponding string - // Offsets are of type offset_t - // strings will be null-terminated - // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING) - // |----------------------------------------------------------------| - // | OFFSET ARRAY | STRINGS | - // | bytes 0-3 | bytes 4-7 | bytes 8-11 | bytes 12-15 | bytes 16-18 | - // | 12 | 16 | 19 | abc\0 | de\0 | - // |----------------------------------------------------------------| - explicit Tensor(const std::vector<std::string> &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // Same as Tensor(vector) but the input is protobuf bytelist - explicit Tensor(const dataengine::BytesList &bytes_list, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // A static factory method to create the given flavour of derived Tensor - // Returns the base class reference for the Tensor. - // @param ptr output argument to hold the created Tensor of given tensor_impl - // @param tensor_impl - which implementation of Tensor - // @param shape - shape of the tensor - // @param type - datatype of the tensor - // @param data - data to be copied to Tensor new allocation - // @return Status Code - static Status CreateTensor(std::shared_ptr<Tensor> *, TensorImpl tensor_impl, const TensorShape &shape, DataType type, - const unsigned char *data = nullptr); - - // Create a copy of the input tensor - // @param out [out] output tensor to be generated - // @param in [in] original tensor to be copied - // @return Status - static Status CreateTensor(std::shared_ptr<Tensor> *out, const std::shared_ptr<Tensor> &in) { - const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); - *out = std::allocate_shared<Tensor>(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes()); - return Status::OK(); - } - -#ifdef ENABLE_PYTHON - // A static factory method to create a Tensor from a given py::array. - // @param ptr output argument to hold the created Tensor - // @param arr py::array - // @return Status Code - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, py::array arr); - - // Helper function to create a tensor from Numpy of strings - static Status CreateTensorFromNumpyString(std::shared_ptr<Tensor> *ptr, py::array arr); -#endif - - // A static factory method to create a Tensor from a given list of strings. - // @param ptr output argument to hold the created Tensor - // @param strings elements of the tensor - // @param shape shape of the tensor - // @return Status Code - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<std::string> &strings, - const TensorShape &shape = TensorShape::CreateUnknownRankShape()); - - // create tensor from protobuf bytelist with strings - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape);
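The offset layout documented above can be reproduced in isolation. Below is a minimal sketch, not the class's actual code path, that builds the same buffer for `Tensor(['abc', 'de'])` and checks the arithmetic `GetStringAt` relies on (length = next offset minus this offset minus the null terminator); `offset_t` is assumed to be `uint32_t`, matching the header.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

using offset_t = uint32_t;  // the header's offset type

// Build the flat buffer for a string tensor: n+1 offsets, then null-terminated strings.
std::vector<unsigned char> BuildStringBuffer(const std::vector<std::string> &strings) {
  size_t header = (strings.size() + 1) * sizeof(offset_t);
  size_t total = header;
  for (const auto &s : strings) total += s.size() + 1;  // +1 for the '\0'
  std::vector<unsigned char> buf(total);
  auto *offsets = reinterpret_cast<offset_t *>(buf.data());
  offset_t start = static_cast<offset_t>(header);  // first string begins right after the offsets
  for (size_t i = 0; i < strings.size(); ++i) {
    offsets[i] = start;
    std::memcpy(buf.data() + start, strings[i].c_str(), strings[i].size() + 1);
    start += strings[i].size() + 1;
  }
  offsets[strings.size()] = start;  // the extra offset gives the last string's length
  return buf;
}

int main() {
  auto buf = BuildStringBuffer({"abc", "de"});
  auto *offsets = reinterpret_cast<const offset_t *>(buf.data());
  // Three 4-byte offsets occupy bytes 0-11, then "abc\0" and "de\0": offsets are 12, 16, 19.
  assert(offsets[0] == 12 && offsets[1] == 16 && offsets[2] == 19);
  // GetStringAt arithmetic: length = offsets[i + 1] - offsets[i] - 1 (skips the '\0').
  assert(offsets[1] - offsets[0] - 1 == 3);
  assert(std::memcmp(buf.data() + offsets[0], "abc", 3) == 0);
  return 0;
}
```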
- - // A static factory method to create a Tensor from a given list of numbers. - // @param ptr output argument to hold the created Tensor - // @param items elements of the tensor - // @param shape shape of the tensor - // @return Status Code - template <typename T> - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<T> &items, - const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) { - DataType type = DataType::FromCType<T>(); - auto items_ptr = reinterpret_cast<const unsigned char *>(&items[0]); - TensorShape shape = shape_req; - if (!shape.known()) { - shape = TensorShape({static_cast<dsize_t>(items.size())}); - } - return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr); - } - - // A static factory method to create a Tensor from a given number. - // @param ptr output argument to hold the created Tensor - // @param item value - // @return Status Code - template <typename T> - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const T &item) { - return CreateTensor<T>(ptr, {item}, TensorShape::CreateScalar()); - } - - // Create tensor from protobuf bytelist with uint8 or int8 types - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list, - const TensorShape &shape, const DataType &type, dsize_t pad_size); - - static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &path); - - // Copy raw data of an array based on shape and strides to the destination pointer - // @param dst Pointer to the destination array where the content is to be copied - // @param src Pointer to the source of strided array to be copied - // @param shape - shape of the source array - // @param strides - strides of the source array - // @param type_size - number of bytes needed to store one array element's type - // @return Status Code - static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector<dsize_t> shape, - std::vector<dsize_t> strides, uint8_t type_size); - - // Release the memory using the allocator - virtual ~Tensor(); - - // compare the tensor shape and data - bool operator==(const Tensor &rhs) const; - - bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); } - - // Get item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector<dsize_t> - // @return the item specified at index - template <typename T> - Status GetItemAt(T *o, const std::vector<dsize_t> &index) const; - - // Get string located at `index`. 
- // @param index vector<dsize_t> - // @return std::string_view specified at index - Status GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const; - - template <typename T> - Status GetUnsignedIntAt(T *o, const std::vector<dsize_t> &index) const; - - template <typename T> - Status GetSignedIntAt(T *o, const std::vector<dsize_t> &index) const; - - template <typename T> - Status GetFloatAt(T *o, const std::vector<dsize_t> &index) const; - - // set item at location specified by index - // @tparam `T` - // @param index - // @param value of type `T` - template <typename T> - Status SetItemAt(const std::vector<dsize_t> &index, const T &value) { - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); - T *ptr = nullptr; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); - *ptr = value; - return Status::OK(); - } - - // set string item at location specified by index - // @param index - // @param value of type std::string - Status SetItemAt(const std::vector<dsize_t> &index, const std::string &value) { - RETURN_UNEXPECTED_IF_NULL(data_); - uchar *ptr = nullptr; - offset_t length = 0; - RETURN_IF_NOT_OK(GetItemPtr(&ptr, index, &length)); - if (value.length() != length) { - RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item."); - } - memcpy_s(reinterpret_cast<char *>(ptr), length, value.c_str(), length); - - return Status::OK(); - } - // fill tensor with zeros. Does not support strings. - Status Zero() { - CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings."); - dsize_t size = SizeInBytes(); - CHECK_FAIL_RETURN_UNEXPECTED(memset_s(GetMutableBuffer(), size, 0, size) == 0, - "Failed to fill tensor with zeroes."); - return Status::OK(); - } - - // Fill all elements in the Tensor with the given value of type `T`. Does not support strings. - // @tparam T - // @param value - template <typename T> - Status Fill(const T &value) { - CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings."); - RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes())); - int64_t cellSize = type_.SizeInBytes(); - if ((data_ != nullptr) && type_.IsCompatible<T>()) { - for (dsize_t i = 0; i < Size(); i++) { - CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err"); - } - return Status::OK(); - } else { - std::string err; - err += (data_ == nullptr) ? "data_ is nullptr \t" : ""; - err += (!type_.IsCompatible<T>()) ? "data type not compatible\t" : ""; - return Status(StatusCode::kUnexpectedError, err); - } - } - - // Getter function for shape - // @return - const TensorShape &shape() const { return shape_; } - - /// Check if tensor has data - /// \return bool - true if tensor is empty (no data allocated) - bool HasData() const; - - // Reshape the tensor. The given shape should have the same number of elements as the Tensor - // @param shape - virtual Status Reshape(const TensorShape &shape); - - // @return number of elements in this tensor - dsize_t Size() const { return shape().NumOfElements(); } - - // @return the number of bytes this tensor needs - dsize_t SizeInBytes() const { - if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements(); - return data_end_ - data_; - } - - // @return the rank of the tensor - dsize_t Rank() const { return shape().Rank(); } - - // Get the starting memory address as a constant for the data of the tensor. This potentially - // drives an allocation if the data area is null. 
- // @return const unsigned char* - const unsigned char *GetBuffer() const; - - // Getter of the type - // @return - DataType type() const { return type_; } - - // Provide stream operator for displaying it - // @param output stream - // @param so the Tensor object to be printed - // @return output stream - friend std::ostream &operator<<(std::ostream &out, const Tensor &so) { - so.Print(out); - return out; - } - - // Invalidate this Tensor by setting the type and shape to unknown and data_ to null. - // Calling this method will make the Tensor and its data inaccessible, use it with caution. - void Invalidate(); - - // Copy input tensor into self at the location index. - // Index is a vector of axes which can be incomplete: - // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell. - // @param index - // @param input - // @return Status code - Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input); - - // Find the address of the given index. Used in InsertTensor. - // Example: - // Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1 - // @param index incomplete index - // @param output: startAddrofIndex - // @param output: remaining - // @return Status code - Status StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_index, TensorShape *remaining); - - // Expand the shape of the Tensor with one extra dimension. - // For example, if the shape is <512,512,3>: - // *- ExpandDim(0) gives: <1,512,512,3> - // *- ExpandDim(1) gives: <512,1,512,3> - // *- ExpandDim(3) gives: <512,512,3,1> - // @param axis location of the dim - virtual Status ExpandDim(const dsize_t &axis); - - virtual void Squeeze(); - - // Calculates the strides of the Tensor - // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte) - // The strides will be {4,2,1}. - // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 byte) - // The strides will be {16,8,4}. - // @return vector of integers - std::vector<dsize_t> Strides(); - - std::string ToString() { - std::stringstream ss; - this->Print(ss); - return ss.str(); - } - - // Handle negative indices. - static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } - - // Slice tensor based on the given indices. Copy the sliced data into the out tensor. Only rank-1 tensors are supported. - // Based on the type of tensor, SliceNumeric or SliceString will be called - // @param out Tensor - // @param indices vector of indices - // @return Status error code - Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
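As a quick illustration of the index normalization Slice performs, here is a standalone sketch with made-up values, mirroring the one-liner HandleNeg above; the out-of-range behaviour in the comment matches the range check the slicing code applies afterwards.

```cpp
#include <cassert>
#include <cstdint>

using dsize_t = int64_t;

// Same normalization Slice applies to each index before bounds checking.
static inline dsize_t HandleNeg(dsize_t index, dsize_t length) {
  return (index < 0) ? (index + length) : index;
}

int main() {
  // For a rank-1 tensor of length 5: -1 maps to 4, -5 to 0; 2 passes through.
  assert(HandleNeg(-1, 5) == 4);
  assert(HandleNeg(-5, 5) == 0);
  assert(HandleNeg(2, 5) == 2);
  // Values like -6 stay negative (-1) and are rejected by the callers'
  // range check: cur_index >= 0 && cur_index < dim_length.
  assert(HandleNeg(-6, 5) < 0);
  return 0;
}
```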
- - // Slice numeric tensors. - Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices); - - // Slice string tensors - Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices); - -#ifdef ENABLE_PYTHON - // Constructs numpy array from input tensor - // @param data this data is the location of python data - // @return Status code - Status GetDataAsNumpy(py::array *data); - - Status GetDataAsNumpyStrings(py::array *data); - - static Status GetBufferInfo(Tensor *t, py::buffer_info *out); -#endif - - // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor - Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input); - - // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor - // The order of the elements is as in the memory layout (i.e., row-major) [[1,2,3],[4,5,6]] --> 1,2,3,4,5,6 - // @tparam T type of values in the Tensor Iterator - template <typename T, bool = true> - class TensorIterator { - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = T; - using difference_type = ptrdiff_t; - using pointer = T *; - using reference = T &; - - explicit TensorIterator(uchar *ptr = nullptr) { ptr_ = reinterpret_cast<T *>(ptr); } - - TensorIterator(const TensorIterator &raw_iterator) { ptr_ = raw_iterator.ptr_; } - - ~TensorIterator() = default; - - TensorIterator &operator=(const TensorIterator &rhs) { - ptr_ = rhs.ptr_; - return *this; - } - - TensorIterator &operator=(T *rhs) { - ptr_ = rhs; - return *this; - } - - bool operator==(const TensorIterator &rhs) { return ptr_ == rhs.ptr_; } - - bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } - - operator bool() const { return ptr_ != nullptr; } - - T &operator*() { return *ptr_; } - - const T &operator*() const { return *ptr_; } - - T *operator->() { return ptr_; } - - TensorIterator &operator+=(const ptrdiff_t &inc) { - ptr_ += inc; - return *this; - } - - TensorIterator &operator-=(const ptrdiff_t &inc) { - ptr_ -= inc; - return *this; - } - - TensorIterator &operator++() { - ++ptr_; - return *this; - } - - TensorIterator &operator--() { - --ptr_; - return *this; - } - - TensorIterator operator++(int) { - auto temp(*this); - ++ptr_; - return temp; - } - - TensorIterator operator--(int) { - auto temp(*this); - --ptr_; - return temp; - } - - TensorIterator operator+(const ptrdiff_t &inc) { - auto oldPtr = ptr_; - ptr_ += inc; - auto temp(*this); - ptr_ = oldPtr; - return temp; - } - - TensorIterator operator-(const ptrdiff_t &inc) { - auto oldPtr = ptr_; - ptr_ -= inc; - auto temp(*this); - ptr_ = oldPtr; - return temp; - } - - protected: - T *ptr_; - }; - - // Specialization of TensorIterator for strings. It returns std::string_view for every item. 
- // @tparam DUMMY, used to mbe able to specialize the inner class - template - class TensorIterator { - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = std::string_view; - using difference_type = ptrdiff_t; - using pointer = std::string_view *; - using reference = std::string_view &; - - explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) { - data_ = reinterpret_cast(data); - index_ = index; - } - - TensorIterator(const TensorIterator &raw_iterator) { - data_ = raw_iterator.data_; - index_ = raw_iterator.index_; - } - - ~TensorIterator() = default; - - bool operator==(const TensorIterator &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; } - - bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); } - - operator bool() const { return data_ != nullptr; } - - std::string_view operator*() const { - auto offset_ = reinterpret_cast(data_); - offset_t start = offset_[index_]; - return std::string_view{data_ + start}; - } - - TensorIterator &operator+=(const dsize_t &inc) { - index_ += inc; - return *this; - } - - TensorIterator &operator-=(const dsize_t &inc) { - index_ -= inc; - return *this; - } - - TensorIterator &operator++() { - ++index_; - return *this; - } - - TensorIterator &operator--() { - --index_; - return *this; - } - - TensorIterator operator++(int) { - auto temp(*this); - ++index_; - return temp; - } - - TensorIterator operator--(int) { - auto temp(*this); - --index_; - return temp; - } - - TensorIterator operator+(const dsize_t &inc) { - auto oldPtr = index_; - index_ += inc; - auto temp(*this); - index_ = oldPtr; - return temp; - } - - TensorIterator operator-(const dsize_t &inc) { - auto oldPtr = index_; - index_ -= inc; - auto temp(*this); - index_ = oldPtr; - return temp; - } - - protected: - dsize_t index_; - const char *data_; - }; - - // Return a TensorIterator that points to the start of the Tensor. - // It's the user responsibility to use the correct type that matches the Tensor type - // @param T The type of values in the Tensor - // @return TensorIterator - template - TensorIterator begin() { - AllocateBuffer(SizeInBytes()); - return TensorIterator(data_); - } - - // Return a linear iterator that points to the place after the last element of the Tensor. - // @tparam T The type of values in the Tensor - // @return TensorIterator - template - TensorIterator end() { - return TensorIterator(data_end_); - } - - // Copies the last dimension at `index` from Tensor `src` to this Tensor. - // @param src Tensor - // @param index vector to the start of the dimension. The last dim should be 0 - // @return Status - Status CopyLastDimAt(const std::shared_ptr &src, const std::vector &index); - - protected: - // Get the starting memory address for the data of the tensor. This potentially - // drives an allocation if the data is null. 
- // @return unsigned char* - unsigned char *GetMutableBuffer(); - - // A function that prints Tensor recursively, first called by print - // @param out - // @param cur_dim - // @param cur_index - void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const; - - // A function that prints info about the tensor - // @param out output stream - void Print(std::ostream &out) const; - - // A function that print the value as specified by its index - // @param index vector representing the index - // @param out - void PrintItemAt(const std::vector &index, std::ostream &out) const; - - // Get pointer to item located at `index`, caller needs to provide the type. - // @tparam T - // @param index vector - // @return return a pointer to the item specified at index of type `T` - template - Status GetItemPtr(T **, const std::vector &index) const; - - // Get pointer to string located at `index` and the length of string - // @param index vector - // @return return a pointer to the string specified at index and the length of the string - Status GetItemPtr(uchar **, const std::vector &index, offset_t *length = nullptr) const; - - // Given a flat index of an item string, return the start and length of the item - // @param index flat index of the item - // @return start address of the ths string - // @return length of the string - Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const; - - // Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the - // tensor's type is a string, otherwise undefined address would be returned. - // @return address of the first string of the tensor. - uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; } - - // all access to shape_ should be via shape - TensorShape shape_; - // data type of tensor - DataType type_; - // pointer to the start of the physical data - unsigned char *data_; - // An allocator for data_ - CharAllocPtr data_allocator_; - // pointer to the end of the physical data - unsigned char *data_end_ = nullptr; -}; -template <> -inline Tensor::TensorIterator Tensor::end() { - return TensorIterator(data_, shape_.NumOfElements()); -} -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor_row.cc b/mindspore/ccsrc/dataset/core/tensor_row.cc deleted file mode 100644 index 930608d108..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_row.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "dataset/core/tensor_row.h" - -namespace mindspore { -namespace dataset { - -TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {} - -TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {} - -TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {} - -TensorRow::TensorRow(row_id_type id, const std::initializer_list &lst) : id_(id), row_(lst) {} - -TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {} - -TensorRow &TensorRow::operator=(const TensorRow &tr) { - if (this == &tr) { - return *this; - } - row_ = tr.row_; - id_ = tr.id_; - return *this; -} - -TensorRow &TensorRow::operator=(const std::initializer_list &lst) { - row_ = lst; - return *this; -} - -TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {} - -TensorRow::TensorRow(row_id_type id, std::initializer_list &&lst) noexcept - : id_(id), row_(std::move(lst)) {} - -TensorRow::TensorRow(TensorRow &&tr) noexcept { - id_ = tr.id_; - row_ = std::move(tr.row_); -} - -TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept { - if (this == &tr) { - return *this; - } - row_ = std::move(tr.row_); - id_ = tr.id_; - tr.id_ = kDefaultRowId; - return *this; -} - -TensorRow &TensorRow::operator=(std::initializer_list &&lst) noexcept { - row_ = std::move(lst); - return *this; -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor_row.h b/mindspore/ccsrc/dataset/core/tensor_row.h deleted file mode 100644 index 49bc61657c..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_row.h +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef DATASET_CORE_TENSOR_ROW_H_ -#define DATASET_CORE_TENSOR_ROW_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" - -namespace mindspore { -namespace dataset { - -class TensorRow; // A set of Tensor pointers with an id -using TensorTable = std::vector; // The table of tensors is a vector of rows -using TensorQTable = std::deque; // A different flavour of tensor table, this one has queue functionality - -class TensorRow { - public: - static constexpr row_id_type kDefaultRowId = -1; // Default row id - - // Type definitions - using size_type = dsize_t; - using value_type = std::shared_ptr; - using reference = std::shared_ptr &; - using const_reference = const std::shared_ptr &; - using vector_type = std::vector>; - using iterator = std::vector>::iterator; - using const_iterator = std::vector>::const_iterator; - - TensorRow() noexcept; - - TensorRow(size_type n, value_type t) noexcept; - - // Copy Constructors - explicit TensorRow(const vector_type &v); - - TensorRow(row_id_type id, const std::initializer_list &lst); - - TensorRow(const TensorRow &tr); - - TensorRow &operator=(const TensorRow &tr); - - TensorRow &operator=(const std::initializer_list &lst); - - // Move Constructors - explicit TensorRow(vector_type &&v) noexcept; - - TensorRow(row_id_type id, std::initializer_list &&lst) noexcept; - - TensorRow(TensorRow &&tr) noexcept; - - TensorRow &operator=(TensorRow &&tr) noexcept; - - TensorRow &operator=(std::initializer_list &&lst) noexcept; - - // Destructor - ~TensorRow() = default; - - // Functions to fetch/set id/vector - row_id_type getId() const { return id_; } - - void setId(row_id_type id) { id_ = id; } - - const vector_type &getRow() const { return row_; } - - // Wrapper functions to support vector operations - void emplace_back(value_type t) { row_.emplace_back(t); } - - void push_back(value_type t) { row_.push_back(t); } - - void clear() noexcept { row_.clear(); } - - size_type size() const noexcept { return row_.size(); } - - void reserve(size_type size) { row_.reserve(size); } - - void resize(size_type size) { row_.resize(size); } - - bool empty() { return row_.empty(); } - - void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); } - - // Wrapper functions to support vector element access - reference at(size_type index) { return row_.at(index); } - - const_reference at(size_type index) const { return row_.at(index); } - - reference front() { return row_.front(); } - - const_reference front() const { return row_.front(); } - - reference back() { return row_.back(); } - - const_reference back() const { return row_.back(); } - - reference operator[](size_type index) { return row_[index]; } - - const_reference operator[](size_type index) const { return row_[index]; } - - // Wrapper functions to support vector iteration - iterator begin() { return row_.begin(); } - - const_iterator begin() const { return row_.begin(); } - - iterator end() { return row_.end(); } - - const_iterator end() const { return row_.end(); } - - protected: - row_id_type id_; - std::vector> row_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_ROW_H_ diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.cc b/mindspore/ccsrc/dataset/core/tensor_shape.cc deleted file mode 100644 index 953b9dfc9f..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_shape.cc +++ /dev/null @@ -1,235 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 
(the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#define MAX_INTEGER_DTYPE 9223372036854775807 - -#include "dataset/core/tensor_shape.h" - -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -constexpr dsize_t TensorShape::kDimUnknown; - -bool multi_ok(dsize_t x, dsize_t y) { - dsize_t p = x * y; - if (x == 0) { - return true; - } - return p / x == y; -} - -dsize_t TensorShape::NumOfElements() const { - if (!known()) { - return 0; - } - return strides_[0]; -} - -void TensorShape::Print(std::ostream &out) const { - if (!known() && raw_shape_.empty()) { - out << ""; - } else { - out << "<"; - for (auto i = 0; i < this->Rank(); i++) { - if (raw_shape_[i] == kDimUnknown) { - out << "*"; - } else { - out << raw_shape_[i]; - } - if (i != this->Rank() - 1) { - out << ","; - } - } - out << ">"; - } -} - -TensorShape::TensorShape(const std::initializer_list &list) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(list); -} - -TensorShape::TensorShape(const std::vector &list) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(list); -} - -TensorShape::TensorShape(const TensorShape &shape) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - AddListToShape(shape.AsVector()); - known_ = shape.known_; // override with the input shape in case of unknown-rank tensor shape. 
-} - -#ifdef ENABLE_PYTHON -TensorShape::TensorShape(py::list l) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - std::vector list_c; - for (auto &i : l) { - if (!i.is_none()) { - list_c.push_back(i.cast()); - } else { - list_c.push_back(TensorShape::kDimUnknown); - } - } - AddListToShape(list_c); -} -#endif - -TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) - : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) { - for (int i = 0; i < cv_size.dims(); i++) { - raw_shape_.push_back(cv_size[i]); - } - auto channels = static_cast(1 + (type >> static_cast(CV_CN_SHIFT))); - if (channels != 1) { - raw_shape_.push_back(channels); - } - known_ = true; -} - -TensorShape TensorShape::CreateUnknownRankShape() { - TensorShape s({}); - s.known_ = false; - return s; -} - -TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const { - std::vector tmp = AsVector(); - (void)tmp.insert(tmp.begin() + axis, dim); - return TensorShape(tmp); -} - -std::vector TensorShape::AsVector() const { - return std::vector(raw_shape_.begin(), raw_shape_.end()); -} - -bool TensorShape::IsValidIndex(const std::vector &index) const { - dsize_t s_rank = Rank(); - if (index.size() != s_rank) { - return false; - } - for (dsize_t i = 0; i < s_rank; i++) { - if (index[i] < 0 || raw_shape_[i] <= index[i]) { - return false; - } - } - return true; -} - -template -void TensorShape::AddListToShape(const T &list) { - raw_shape_.resize(list.size()); - strides_.resize(list.size() + 1); - strides_[list.size()] = 1; - known_ = true; - dsize_t size = 0; - auto itr = std::rbegin(list); // iterate over the list in reverse order - auto s = list.size() - 1; // to compute strides while adding dims - for (; itr != std::rend(list); itr++, s--) { - dsize_t dim = *itr; - if (dim > 0) { - if (strides_[s + 1] > std::numeric_limits::max() / dim) { - MS_LOG(ERROR) << "Invalid shape data, overflow occurred!"; - known_ = false; - raw_shape_.clear(); - return; - } - strides_[s] = dim * strides_[s + 1]; - } - if (dim < 0) { - known_ = false; - } - if (dim > kDeMaxDim) { - std::stringstream ss; - ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!"; - MS_LOG(ERROR) << ss.str().c_str(); - known_ = false; - raw_shape_.clear(); - return; - } - raw_shape_[s] = dim; - size++; - } - if (size > kDeMaxRank) { - std::stringstream ss; - ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ")."; - MS_LOG(ERROR) << ss.str().c_str(); - known_ = false; - raw_shape_.clear(); - return; - } -} - -TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) { - TensorShape s({}); - for (dsize_t i = 0; i < rank; i++) { - s.raw_shape_.push_back(kDimUnknown); - } - s.known_ = false; - return s; -} - -TensorShape TensorShape::PrependDim(dsize_t dim) const { - if (Size() == 0) { - return TensorShape({dim}); - } - return InsertDim(0, dim); -} - -TensorShape TensorShape::AppendDim(dsize_t dim) const { - auto vec = AsVector(); - vec.push_back(dim); - return TensorShape(vec); -} - -#ifdef ENABLE_PYTHON -py::list TensorShape::AsPyList() { - py::list list; - for (auto i : raw_shape_) { - list.append(i); - } - return list; -} -#endif - -TensorShape TensorShape::Squeeze() const { - std::vector new_shape; - for (auto s : AsVector()) { - if (s != 1) { - new_shape.push_back(s); - } - } - return TensorShape(new_shape); -} - 
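AddListToShape above fills strides_ back-to-front with rank+1 entries, so strides_[0] ends up equal to NumOfElements(); Strides() and ToFlatIndex() below both lean on that layout. A standalone sketch of the same arithmetic for shape <4,2,2> (plain C++, with dsize_t assumed to be int64_t; not the class's actual code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using dsize_t = int64_t;

int main() {
  // strides_ for shape <4,2,2>, built back-to-front as in AddListToShape:
  // strides_[3] = 1, strides_[2] = 2*1, strides_[1] = 2*2, strides_[0] = 4*4.
  std::vector<dsize_t> shape = {4, 2, 2};
  std::vector<dsize_t> strides(shape.size() + 1);
  strides[shape.size()] = 1;
  for (size_t s = shape.size(); s > 0; --s) strides[s - 1] = shape[s - 1] * strides[s];
  assert(strides == (std::vector<dsize_t>{16, 4, 2, 1}));  // strides_[0] is NumOfElements()

  // Strides() drops the leading element, leaving the element strides {4, 2, 1}.
  // ToFlatIndex({2, 1, 1}) = 2*4 + 1*2 + 1*1 = 11, indexing strides_ at k + 1.
  std::vector<dsize_t> index = {2, 1, 1};
  dsize_t flat = 0;
  for (size_t k = 0; k < index.size(); ++k) flat += index[k] * strides[k + 1];
  assert(flat == 11);
  return 0;
}
```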
-std::vector<dsize_t> TensorShape::Strides() const { return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()}; } - -// Name: ToFlatIndex() -// Description: convert a vector-style index to a flat number; used to access memory, internal use only -Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const { - *flat_index = 0; - for (size_t k = 0; k < index.size(); k++) { - *flat_index += index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements - } - CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index"); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor_shape.h b/mindspore/ccsrc/dataset/core/tensor_shape.h deleted file mode 100644 index 3d2681271a..0000000000 --- a/mindspore/ccsrc/dataset/core/tensor_shape.h +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CORE_TENSOR_SHAPE_H_ -#define DATASET_CORE_TENSOR_SHAPE_H_ - -#include <cstdint> -#include <ostream> -#include <sstream> -#include <string> -#include <vector> - -#include <opencv2/core/mat.hpp> - -#ifdef ENABLE_PYTHON -#include "pybind11/pybind11.h" -namespace py = pybind11; -#endif - -#include "dataset/core/constants.h" -#include "dataset/util/status.h" -#include "dataset/core/global_context.h" -#include "dataset/util/allocator.h" - -namespace mindspore { -namespace dataset { -// Class that represents a shape of a Tensor. A shape can be: -// -# Known shape (mKnown = true) -// -# Scalar --> empty vector --> <> -// -# n-Dim --> not empty vector --> <d1, d2, ..., dn> where di is >= 0\n -// Example: <1,2>, <1>, <1,13,10,11,1> -// -# Unknown shape (mKnown = false) -// -# Rank is unknown --> empty vector --> <> -// -# one or more dim is unknown --> not empty vector --> <d1, d2, ..., dn> where some di are unknown\n -// Example: <3,?> (the 1st dim is unknown)\n -// <2,?,?,?> (all dims but the 0th dim are unknown) - -/// \brief TensorShape supports any dim > 0 and < 2^31-1 -class TensorShape { - public: - static constexpr dsize_t kDimUnknown = -1; // constant for an unknown dimension - - // Force the compiler to not create a no-arg constructor - TensorShape() = delete; - - /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}). - /// If one of the dims is set to DIM_UNKNOWN, the shape will be flagged as unknown - /// \param[in] list - explicit TensorShape(const std::initializer_list<dsize_t> &list); - - /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ). 
- /// If one of the dims is set to DIM_UNKNOWN, the shape will flagged as unKnown - /// \param[in] list - explicit TensorShape(const std::vector &list); - - /// \brief Copy constructor - /// \param[in] shape - TensorShape(const TensorShape &shape); - -#ifdef ENABLE_PYTHON - /// \brief construct a TensorShape via a python list - /// \param[in] py::list l - a list object from python - explicit TensorShape(py::list l); -#endif - - ~TensorShape() = default; - - /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true) - /// \return TensorShape - static TensorShape CreateScalar() { return TensorShape({}); } - - /// \brief Create a shape with an unknown rank. - /// \return TensorShape - static TensorShape CreateUnknownRankShape(); - - /// \brief Create a shape with a known rank . - /// \return TensorShape - static TensorShape CreateUnknownShapeWithRank(dsize_t rank); - - /// \brief Insert a new dim into a copy of the current shape. - /// \param[in] dim to be added - /// \param[in] axis the index where dim should be added - /// \return New modified shape - TensorShape InsertDim(dsize_t axis, dsize_t dim) const; - - /// \brief Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4> - /// \param[in] dim - /// \return - TensorShape PrependDim(dsize_t dim) const; - - /// \brief Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4> - /// \param[in] dim - /// \return - TensorShape AppendDim(dsize_t dim) const; - - /// \brief Create a shape based on OpenCV shape and type - /// \param[in] cv_size - /// \param[in] type int that represent the type in OpenCV, example CV_8U, CV_64S - TensorShape(cv::MatSize cv_size, uint32_t type); - - dsize_t Size() const { return raw_shape_.size(); } - - dsize_t Rank() const { return raw_shape_.size(); } - - bool known() const { return known_; } - - bool empty() const { return raw_shape_.empty(); } - - dsize_t NumOfElements() const; - - bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; } - - bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); } - - dsize_t operator[](const dsize_t index) const { - if (index < 0) return raw_shape_[raw_shape_.size() + index]; - return raw_shape_[index]; - } - - /// \brief Return the Shape as a vector - /// \return - std::vector AsVector() const; - - /// \brief Returns the class info as a string - /// \return - std::string ToString() const { - std::stringstream ss; - ss << *this; - return ss.str(); - } - - /// \brief Actual print function used by operator<< - /// \param out output string stream - void Print(std::ostream &out) const; - - /// \brief << Stream output operator overload - /// This allows you to print the info using stream operators - /// \param[in] out - reference to the output stream being overloaded - /// \param[in] rO - reference to the TensorShape to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) { - so.Print(out); - return out; - } - -#ifdef ENABLE_PYTHON - py::list AsPyList(); -#endif - - /// \brief Checks if the given index is a valid index for this tensor. - /// For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not. - /// \param[in] index - /// \return bool - bool IsValidIndex(const std::vector &index) const; - - TensorShape Squeeze() const; - - std::vector Strides() const; - - /// \brief Returns the location of the item assuming row major memory layout. 
- /// \param[in] index - /// \param[out] flat_index - /// \return - Status ToFlatIndex(const std::vector &index, dsize_t *flat_index) const; - - private: - // True if known and valid shape, false otherwise - bool known_; - // Vector to keep the dims of the shape. - std::vector raw_shape_; - // Vector to keep the strides of the shape. The size is rank+1 - std::vector strides_; - - /// \brief Internal utility function to iterate over a list, - /// check if the dim is valid and then insert it into the shape. - /// \param[in] list Iterable list - /// \return true if the shape is valid and no overflow would be generated when counting the number of elements. - /// False otherwise. - template - void AddListToShape(const T &list); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_TENSOR_SHAPE_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/dataset/engine/cache/cache_client.cc deleted file mode 100644 index 1dc97ac43a..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_client.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/cache/cache_request.h" -#include "dataset/util/bit.h" - -namespace mindspore { -namespace dataset { - -// Constructor -CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill) - : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {} - -// print method for display cache details -void CacheClient::Print(std::ostream &out) const { - out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_ - << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_ - << "\n Spilling: " << std::boolalpha << spill_; -} - -Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { - CacheRowRequest rq(server_connection_id_, cookie()); - RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row)); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - if (row_id_from_server != nullptr) { - *row_id_from_server = rq.GetRowIdAfterCache(); - } - return Status::OK(); -} - -Status CacheClient::WriteBuffer(std::unique_ptr &&in) const { - std::unique_ptr db_ptr = std::move(in); - auto num_rows = db_ptr->NumRows(); - std::vector all_rows; - if (num_rows > 0) { - all_rows.reserve(num_rows); - // Break down the DataBuffer into TensorRow. We will send the requests async - // and then do a final wait. 
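
// Illustrative note, not part of the original file: the loop below is a
// scatter/gather pattern. One CacheRowRequest per row is pushed to the
// server queue without blocking, and only after all rows are sent does a
// second loop block on each rq->Wait(), so serializing later rows overlaps
// with the server processing earlier ones.
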
- MemGuard rq_arr; - RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie())); - CacheServer &cs = CacheServer::GetInstance(); - for (auto i = 0; i < num_rows; ++i) { - TensorRow row; - auto rq = rq_arr[i]; - RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); - RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row)); - RETURN_IF_NOT_OK(cs.PushRequest(rq)); - // We can't let row go out of scope. Otherwise it will free all the tensor memory. - // So park it in the vector. When this function go out of scope, its memory - // will be freed. - all_rows.push_back(std::move(row)); - } - // Now we wait for the requests to be done. - for (auto i = 0; i < num_rows; ++i) { - auto rq = rq_arr[i]; - RETURN_IF_NOT_OK(rq->Wait()); - } - } - return Status::OK(); -} - -Status CacheClient::GetRows(const std::vector &row_id, TensorTable *out) const { - RETURN_UNEXPECTED_IF_NULL(out); - BatchFetchRequest rq(server_connection_id_, row_id); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - RETURN_IF_NOT_OK(rq.RestoreRows(out)); - return Status::OK(); -} - -Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { - UniqueLock lck(&mux_); - // To create a cache, we identify ourself at the client by: - // - the shared session id - // - a crc for the tree nodes from the cache downward - // Pack these 2 into a single 64 bit request id - // - // Consider this example: - // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch - // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch - // These are different trees in a single session, but the user wants to share the cache. - // This is not allowed because the data of these caches are different. - // - // Consider this example: - // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch - // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch - // These are different trees in the same session, but the cached data is the same, so it is okay - // to allow the sharing of this cache between these pipelines. - - // The CRC is computed by the tree prepare phase and passed to this function when creating the cache. - // If we already have a server_connection_id_, then it means this same cache client has already been used - // to create a cache and some other tree is trying to use the same cache. - // That is allowed, however the crc better match! - if (server_connection_id_) { - if (cache_crc_ != tree_crc) { - RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!"); - } - // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should - // skip the build phase. - lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. - CacheClient::ServiceStat stat{}; - RETURN_IF_NOT_OK(GetStat(&stat)); - if (stat.cache_service_state == static_cast(CacheService::State::kFetchPhase)) { - return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); - } - } else { - cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client - // Combine the session and crc. This will form our client cache identifier. 
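
// Illustrative note, not part of the original file: packing both values into
// one 64-bit key gives the server a single flat cache index. For example,
// session_id_ = 1 and cache_crc_ = 123 pack to (1ULL << 32) | 123, i.e.
// 0x10000007B, so two pipelines share a cache only when both the session id
// and the tree crc match.
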
- connection_id_type connection_identification = (static_cast(session_id_) << 32) | cache_crc_; - // Now execute the cache create request using this identifier and other configs - BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone; - if (spill_) { - createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk; - } - if (generate_id) { - createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId; - } - CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - Status rc = rq.Wait(); - if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { - server_connection_id_ = rq.GetServerConnectionId(); - if (rc.IsOk()) { - // The 1st guy creating the cache will get a cookie back. - // But this object may be shared among pipelines and we don't want - // overwrite it. - cookie_ = rq.cookie(); - } - } - // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the - // CacheOp to bypass the build phase. - return rc; - } - return Status::OK(); -} - -Status CacheClient::PurgeCache() { - UniqueLock lck(&mux_); - PurgeCacheRequest rq(server_connection_id_); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - return rq.Wait(); -} - -Status CacheClient::DestroyCache() { - UniqueLock lck(&mux_); - DestroyCacheRequest rq(server_connection_id_); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - return rq.Wait(); -} - -Status CacheClient::GetStat(ServiceStat *stat) { - SharedLock lck(&mux_); - RETURN_UNEXPECTED_IF_NULL(stat); - GetStatRequest rq(server_connection_id_); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - stat->num_disk_cached = rq.GetNumDiskCached(); - stat->num_mem_cached = rq.GetNumMemCached(); - stat->min_row_id = rq.GetMinRowId(); - stat->max_row_id = rq.GetMaxRowId(); - stat->cache_service_state = rq.GetState(); - return Status::OK(); -} - -Status CacheClient::CacheSchema(const std::unordered_map &map) { - SharedLock lck(&mux_); - CacheSchemaRequest rq(server_connection_id_); - RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map)); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - return Status::OK(); -} - -Status CacheClient::FetchSchema(std::unordered_map *map) { - SharedLock lck(&mux_); - RETURN_UNEXPECTED_IF_NULL(map); - FetchSchemaRequest rq(server_connection_id_); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - *map = rq.GetColumnMap(); - return Status::OK(); -} - -Status CacheClient::BuildPhaseDone() const { - SharedLock lck(&mux_); - BuildPhaseDoneRequest rq(server_connection_id_, cookie()); - RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); - RETURN_IF_NOT_OK(rq.Wait()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/dataset/engine/cache/cache_client.h deleted file mode 100644 index ffdb9e9fdd..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_client.h +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
-
- * http://www.apache.org/licenses/LICENSE-2.0
-
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
-#define DATASET_ENGINE_CACHE_CLIENT_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "./de_tensor_generated.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/cache/cache_server.h"
-#include "dataset/util/lock.h"
-
-namespace mindspore {
-namespace dataset {
-/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
-/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, restoring
-/// previously cached rows, etc.
-class CacheClient {
- public:
-  /// \brief Constructor
-  /// \param session_id A user assigned session id for the current pipeline
-  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
-  /// \param spill Spill to disk if out of memory
-  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
-
-  /// \brief Destructor
-  ~CacheClient() = default;
-
-  /// \brief Getter function for returning the current session id
-  /// \return session id
-  uint64_t session_id() const { return session_id_; }
-
-  /// \brief Send a TensorRow to the cache server
-  /// \param[in] row
-  /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
-  /// \return return code
-  Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
-
-  /// \brief Send a DataBuffer to the cache server
-  /// \param in Unique pointer of the DataBuffer to be cached
-  /// \return return code
-  Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
-
-  /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
-  /// any cache miss
-  /// \param row_id A vector of row id's
-  /// \param out A TensorTable of TensorRows.
-  /// \return return code
-  Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
-
-  /// \brief Create a cache.
-  /// \param tree_crc A crc that was generated during tree prepare phase
-  /// \param generate_id Let the cache service generate row id
-  /// \return Status object
-  Status CreateCache(uint32_t tree_crc, bool generate_id);
-
-  /// \brief Purge a cache. Cache can be reused after reset.
-  /// \return Status object
-  Status PurgeCache();
-
-  /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
-  /// \return Status object
-  Status DestroyCache();
-
-  /// \brief Get the statistics from a cache.
-  /// \param[in,out] stat Pointer to a pre-allocated ServiceStat object
-  /// \return Status object
-  struct ServiceStat {
-    int64_t num_mem_cached;
-    int64_t num_disk_cached;
-    row_id_type min_row_id;
-    row_id_type max_row_id;
-    int8_t cache_service_state;
-  };
-  Status GetStat(ServiceStat *);
-
-  /// \brief Cache the schema at the cache server
-  /// \param map The unordered map of the schema
-  /// \return Status object
-  Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
-
-  /// \brief Fetch the schema from the cache server
-  /// \param map Pointer to pre-allocated map object
-  /// \return Status object.
- Status FetchSchema(std::unordered_map *map); - - /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache - /// client that holds cookie can be allowed to make this request - /// \return Status object - Status BuildPhaseDone() const; - - /// \brief A print method typically used for debugging - /// \param out The output stream to write output to - void Print(std::ostream &out) const; - - /// \brief Stream output operator overload - /// \return the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) { - cc.Print(out); - return out; - } - - /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it. - /// \return Cookie - std::string cookie() const { return cookie_; } - - private: - mutable RWLock mux_; - uint64_t cache_mem_sz_; - bool spill_; - // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow - // sharing of the cache. - uint32_t session_id_; - uint32_t cache_crc_; - // The server_connection_id_ is the actual id we use for operations after the cache is built - connection_id_type server_connection_id_; - // Some magic cookie returned from the cache server. - std::string cookie_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_CACHE_CLIENT_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/dataset/engine/cache/cache_request.cc deleted file mode 100644 index 5485c22b6a..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_request.cc +++ /dev/null @@ -1,223 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/engine/cache/cache_request.h" - -namespace mindspore { -namespace dataset { - -Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { - buffers_.reserve(row.size() + 1); - RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); - buffers_.push_back(fbb_->GetBufferPointer()); - for (const auto &ts : row) { - buffers_.push_back(ts->GetBuffer()); - } - return Status::OK(); -} - -Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { - try { - fbb_ = std::make_shared(); - std::vector> v; - std::vector tensor_sz; - v.reserve(row.size()); - tensor_sz.reserve(row.size()); - // We will go through each column in the row. - for (const std::shared_ptr &ts_ptr : row) { - flatbuffers::Offset ts_off; - RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); - v.push_back(ts_off); - tensor_sz.push_back(ts_ptr->SizeInBytes()); - } - auto column_off = fbb_->CreateVector(v); - auto data_sz_off = fbb_->CreateVector(tensor_sz); - TensorRowHeaderMsgBuilder row_builder(*fbb_); - row_builder.add_column(column_off); - row_builder.add_data_sz(data_sz_off); - // Pass the row_id even if it may not be known. - row_builder.add_row_id(row.getId()); - row_builder.add_size_of_this(-1); // fill in later after we call Finish. 
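
// Illustrative note, not part of the original file: this is the usual
// two-pass flatbuffers technique. The total buffer size is unknown until
// Finish() has run, so size_of_this is first written as a -1 placeholder
// and then patched in place below via mutate_size_of_this(), which lets the
// receiver read the header's own length out of the header.
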
- auto out = row_builder.Finish(); - fbb_->Finish(out); - // Now go back to fill in size_of_this in the flat buffer. - auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); - auto success = msg->mutate_size_of_this(fbb_->GetSize()); - if (!success) { - RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); - } - return Status::OK(); - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } -} - -Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, - flatbuffers::Offset *out_off) { - RETURN_UNEXPECTED_IF_NULL(out_off); - const Tensor *ts = ts_ptr.get(); - auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); - const auto ptr = ts->GetBuffer(); - if (ptr == nullptr) { - RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); - } - auto src = ts->type().value(); - TensorType dest; -#define CASE(t) \ - case DataType::t: \ - dest = TensorType::TensorType_##t; \ - break - // Map the type to fill in the flat buffer. - switch (src) { - CASE(DE_BOOL); - CASE(DE_INT8); - CASE(DE_UINT8); - CASE(DE_INT16); - CASE(DE_UINT16); - CASE(DE_INT32); - CASE(DE_UINT32); - CASE(DE_INT64); - CASE(DE_UINT64); - CASE(DE_FLOAT16); - CASE(DE_FLOAT32); - CASE(DE_FLOAT64); - CASE(DE_STRING); - default: - MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; - RETURN_STATUS_UNEXPECTED("Unknown type"); - } -#undef CASE - - TensorMetaMsgBuilder ts_builder(*fbb_); - ts_builder.add_dims(shape_off); - ts_builder.add_type(dest); - auto ts_off = ts_builder.Finish(); - *out_off = ts_off; - return Status::OK(); -} - -Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, - std::shared_ptr *out) { - RETURN_UNEXPECTED_IF_NULL(col_ts); - auto shape_in = col_ts->dims(); - auto type_in = col_ts->type(); - std::vector v; - v.reserve(shape_in->size()); - v.assign(shape_in->begin(), shape_in->end()); - TensorShape shape(v); - DataType::Type dest = DataType::DE_UNKNOWN; -#define CASE(t) \ - case TensorType_##t: \ - dest = DataType::Type::t; \ - break - - switch (type_in) { - CASE(DE_BOOL); - CASE(DE_INT8); - CASE(DE_UINT8); - CASE(DE_INT16); - CASE(DE_UINT16); - CASE(DE_INT32); - CASE(DE_UINT32); - CASE(DE_INT64); - CASE(DE_UINT64); - CASE(DE_FLOAT16); - CASE(DE_FLOAT32); - CASE(DE_FLOAT64); - CASE(DE_STRING); - } -#undef CASE - - DataType type(dest); - std::shared_ptr ts = - std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); - // Next we restore the real data which can be embedded or stored separately. - if (ts->SizeInBytes() != data.GetSize()) { - MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" - << "Dumping tensor\n" - << *ts << "\n"; - RETURN_STATUS_UNEXPECTED("Length mismatch. 
See log file for details."); - } - *out = std::move(ts); - return Status::OK(); -} - -Status BatchFetchRequest::RestoreRows(TensorTable *out) { - RETURN_UNEXPECTED_IF_NULL(out); - auto num_elements = row_id_.size(); - auto *offset_array = reinterpret_cast(mem_.GetPointer()); - TensorTable tbl; - tbl.reserve(num_elements); - ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); - for (auto i = 0; i < num_elements; ++i) { - auto len = offset_array[i + 1] - offset_array[i]; - TensorRow row; - row.setId(row_id_.at(i)); - if (len > 0) { - ReadableSlice row_data(all, offset_array[i], len); - // Next we de-serialize flat buffer to get back each column - auto msg = GetTensorRowHeaderMsg(row_data.GetPointer()); - auto msg_sz = msg->size_of_this(); - // Start of the tensor data - auto ts_offset = msg_sz; - row.reserve(msg->column()->size()); - for (auto k = 0; k < msg->column()->size(); ++k) { - auto col_ts = msg->column()->Get(k); - std::shared_ptr ts; - ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); - RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); - row.push_back(ts); - ts_offset += data.GetSize(); - } - } - tbl.push_back(std::move(row)); - } - *out = std::move(tbl); - return Status::OK(); -} - -Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { - try { - fbb_ = std::make_shared(); - std::vector> v; - v.reserve(map.size()); - for (auto &column : map) { - auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); - v.push_back(c); - } - auto v_off = fbb_->CreateVector(v); - auto final_off = CreateSchemaMsg(*fbb_, v_off); - fbb_->Finish(final_off); - buf_ = fbb_->GetBufferPointer(); - len_of_buf_ = fbb_->GetSize(); - return Status::OK(); - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } -} - -std::unordered_map FetchSchemaRequest::GetColumnMap() { - if (column_name_id_map_.empty()) { - auto *map_msg = flatbuffers::GetRoot(mem_.GetPointer()); - auto v = map_msg->column(); - for (auto i = 0; i < v->size(); ++i) { - auto col = map_msg->column()->Get(i); - column_name_id_map_.emplace(col->name()->str(), col->id()); - } - } - return column_name_id_map_; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/dataset/engine/cache/cache_request.h deleted file mode 100644 index 3182816e54..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_request.h +++ /dev/null @@ -1,225 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
-*/ -#ifndef DATASET_ENGINE_CACHE_REQ_H_ -#define DATASET_ENGINE_CACHE_REQ_H_ - -#include -#include -#include -#include -#include -#include - -#include "./de_tensor_generated.h" -#include "dataset/core/tensor_row.h" -#include "dataset/util/slice.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -/// \brief CacheClient communicates with CacheServer using Requests. -class BaseRequest { - public: - // Request types - enum class RequestType : int16_t { - kCacheRow = 0, - kBatchFetchRows = 1, - kCreateCache = 2, - kPurgeCache = 3, - kDestroyCache = 4, - kGetStat = 5, - kCacheSchema = 6, - kFetchSchema = 7, - kBuildPhaseDone = 8, - // Add new request before it. - kRequestUnknown = 32767 - }; - // For kCreateCache - enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; - friend class CacheServer; - /// \brief Base class of a cache server request - /// \param connection_id A combination of session id and crc that uniquely identifies a connection. - /// \param type Type of the request - explicit BaseRequest(connection_id_type connection_id, RequestType type) - : type_(type), connection_id_(connection_id) {} - virtual ~BaseRequest() = default; - /// \brief Wait for the completion of a request - /// \return Status returned from the cache server - Status Wait() { - RETURN_IF_NOT_OK(wp_.Wait()); - return rc_; - } - - /// \brief Getter function of the current connection id - /// \return Connection id - connection_id_type GetServerConnectionId() const { return connection_id_; } - - private: - RequestType type_; - connection_id_type connection_id_; - Status rc_; - WaitPost wp_; -}; -/// \brief Request to cache a single TensorRow -class CacheRowRequest : public BaseRequest { - public: - friend class CacheServer; - explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) - : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} - ~CacheRowRequest() = default; - - /// \brief Serialize a TensorRow for streaming to the cache server - /// \param row TensorRow - /// \return Status object - Status SerializeCacheRowRequest(const TensorRow &row); - /// \brief Return the row id assigned to this row for non-mappable dataset - /// \return row id of the cached row - row_id_type GetRowIdAfterCache() { return row_id_from_server_; } - - private: - std::shared_ptr fbb_; - row_id_type row_id_from_server_; - std::vector buffers_; - std::string cookie_; - - /// \brief Private function to serialize one TensorRow - /// \param row TensorRow - /// \return Status object - Status SerializeTensorRowHeader(const TensorRow &row); - /// \brief Private function to serialize one Tensor - /// \param ts_ptr Tensor - /// \return Status object - Status SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, flatbuffers::Offset *out_off); -}; -/// \brief Request to fetch rows in batch -class BatchFetchRequest : public BaseRequest { - public: - friend class CacheServer; - friend class CacheService; - BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) - : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} - Status RestoreRows(TensorTable *out); - - private: - std::vector row_id_; - MemGuard mem_; - Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); -}; -/// \brief Request to create a cache for the current connection -class CreationCacheRequest : public BaseRequest { - public: - friend class 
CacheServer; - /// \brief Constructor - /// \param connection_id - /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited - /// \param flag Attributes of the cache. - explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, - CreateCacheFlag flag = CreateCacheFlag::kNone) - : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} - - std::string cookie() const { return cookie_; } - - private: - uint64_t cache_mem_sz; - CreateCacheFlag flag_; - std::string cookie_; -}; -/// \brief Request to purge a cache. -class PurgeCacheRequest : public BaseRequest { - public: - friend class CacheServer; - explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} -}; -/// \brief Request to destroy a cache -class DestroyCacheRequest : public BaseRequest { - public: - friend class CacheServer; - explicit DestroyCacheRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kDestroyCache) {} -}; -/// \brief Obtain the statistics of the current connection -class GetStatRequest : public BaseRequest { - public: - friend class CacheServer; - friend class CacheService; - explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} - row_id_type GetMinRowId() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->min_row_id(); - } - row_id_type GetMaxRowId() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->max_row_id(); - } - int64_t GetNumMemCached() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->num_mem_cached(); - } - int64_t GetNumDiskCached() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->num_disk_cached(); - } - uint8_t GetState() const { - auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); - return msg->state(); - } - - private: - MemGuard mem_; -}; -/// \brief Request to cache a schema -class CacheSchemaRequest : public BaseRequest { - public: - friend class CacheServer; - explicit CacheSchemaRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} - ~CacheSchemaRequest() = default; - - Status SerializeCacheSchemaRequest(const std::unordered_map &map); - const void *GetBuffer() const { return buf_; } - - private: - std::shared_ptr fbb_; - const void *buf_; - int64_t len_of_buf_; -}; -/// \brief Request to fetch a schema -class FetchSchemaRequest : public BaseRequest { - public: - friend class CacheServer; - explicit FetchSchemaRequest(connection_id_type connection_id) - : BaseRequest(connection_id, RequestType::kFetchSchema) {} - ~FetchSchemaRequest() = default; - - std::unordered_map GetColumnMap(); - - private: - MemGuard mem_; - std::unordered_map column_name_id_map_; -}; -/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. 
-class BuildPhaseDoneRequest : public BaseRequest { - public: - friend class CacheServer; - BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) - : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} - - private: - std::string cookie_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/dataset/engine/cache/cache_server.cc deleted file mode 100644 index 88d617b598..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_server.cc +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/engine/cache/cache_server.h" -#include "dataset/engine/cache/cache_service.h" -#include "dataset/engine/cache/cache_request.h" -#include "dataset/util/bit.h" - -namespace mindspore { -namespace dataset { -Status CacheServer::DoServiceStart() { - if (!top_.empty()) { - Path spill(top_); - RETURN_IF_NOT_OK(spill.CreateDirectories()); - MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; - } - RETURN_IF_NOT_OK(vg_.ServiceStart()); - cache_q_ = std::make_shared>(1024); - RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); - auto f = std::bind(&CacheServer::ServerRequest, this); - // Spawn a a few threads to serve the request. - for (auto i = 0; i < num_workers_; ++i) { - RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); - } - return Status::OK(); -} - -Status CacheServer::DoServiceStop() { - Status rc; - Status rc2; - // First stop all the threads. - RETURN_IF_NOT_OK(vg_.ServiceStop()); - // Clean up all the caches if any. - UniqueLock lck(&rwLock_); - auto it = all_caches_.begin(); - while (it != all_caches_.end()) { - auto cs = std::move(it->second); - rc2 = cs->ServiceStop(); - if (rc2.IsError()) { - rc = rc2; - } - ++it; - } - return rc; -} - -CacheService *CacheServer::GetService(connection_id_type id) const { - SharedLock lck(&rwLock_); - auto it = all_caches_.find(id); - if (it != all_caches_.end()) { - return it->second.get(); - } - return nullptr; -} - -Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, - BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { - // We can't do spilling unless this server is setup with a spill path in the first place - bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; - bool generate_id = - (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; - if (spill && top_.empty()) { - RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); - } - RETURN_UNEXPECTED_IF_NULL(out_cookie); - *out_cookie = ""; - // Before creating the cache, first check if this is a request for a shared usage of an existing cache - // If two CreateService come in with identical connection_id, we need to serialize the create. 
- // The first create will be successful and be given a special cookie. - UniqueLock lck(&rwLock_); - auto end = all_caches_.end(); - auto it = all_caches_.find(connection_id); - if (it == end) { - std::unique_ptr cs; - try { - cs = std::make_unique(cache_mem_sz, spill ? top_ : "", generate_id); - RETURN_IF_NOT_OK(cs->ServiceStart()); - *out_cookie = cs->cookie(); - all_caches_.emplace(connection_id, std::move(cs)); - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory); - } - } else { - MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; - // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it - // treat it as OK. - return Status(StatusCode::kDuplicateKey); - } - return Status::OK(); -} - -/// This is the main loop the cache server thread(s) are running. -/// Each thread will pop a request and save the result in the same request. -/// The sender will wait on the wait post in the request. Once the request -/// is fulfilled, the server thread will do a post signalling the request is -/// is processed. -/// \return -Status CacheServer::ServerRequest() { - TaskManager::FindMe()->Post(); - // Loop forever until we are interrupted. - while (true) { - BaseRequest *base_rq = nullptr; - RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); - auto cs = GetService(base_rq->connection_id_); - // Except for creating a new session, we expect cs is not null. - switch (base_rq->type_) { - case BaseRequest::RequestType::kCacheRow: { - if (cs == nullptr) { - std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - // Only if the cookie matches, we can accept insert into this cache that has a build phase - if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { - rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); - } - } - break; - } - case BaseRequest::RequestType::kBatchFetchRows: { - if (cs == nullptr) { - std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); - } - break; - } - case BaseRequest::RequestType::kCreateCache: { - // If the cache is already created we still need to run the creation so that we do sanity checks on the - // client id and return the cache id back to the user. - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); - break; - } - case BaseRequest::RequestType::kPurgeCache: { - if (cs != nullptr) { - base_rq->rc_ = cs->Purge(); - } else { - // it is already purged. Ignore it. - base_rq->rc_ = Status::OK(); - } - break; - } - case BaseRequest::RequestType::kDestroyCache: { - if (cs != nullptr) { - // We need a strong lock to protect the map. - connection_id_type id = base_rq->connection_id_; - UniqueLock lck(&rwLock_); - // std::map will invoke the constructor of CacheService. So we don't need to do anything here. - auto n = all_caches_.erase(id); - if (n == 0) { - // It has been destroyed by another duplicate request. 
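
// Illustrative note, not part of the original file: erase() returning 0 is
// deliberately treated as success, making kDestroyCache idempotent; two
// racing destroy requests both complete with Status::OK() instead of one of
// them surfacing a spurious error.
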
- MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; - } - base_rq->rc_ = Status::OK(); - } else { - // it is already destroyed. Ignore it. - base_rq->rc_ = Status::OK(); - } - break; - } - case BaseRequest::RequestType::kGetStat: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - CacheService::ServiceStat svc_stat; - rq->rc_ = cs->GetStat(&svc_stat); - if (rq->rc_.IsOk()) { - flatbuffers::FlatBufferBuilder fbb; - ServiceStatMsgBuilder bld(fbb); - bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); - bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); - bld.add_max_row_id(svc_stat.max_); - bld.add_min_row_id(svc_stat.min_); - bld.add_state(svc_stat.state_); - auto offset = bld.Finish(); - fbb.Finish(offset); - rq->rc_ = rq->mem_.allocate(fbb.GetSize()); - if (rq->rc_.IsOk()) { - WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); - ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); - RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); - } - } - } - break; - } - case BaseRequest::RequestType::kCacheSchema: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); - } - break; - } - case BaseRequest::RequestType::kFetchSchema: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - rq->rc_ = cs->FetchSchema(&rq->mem_); - } - break; - } - case BaseRequest::RequestType::kBuildPhaseDone: { - if (cs == nullptr) { - std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); - } else { - auto *rq = reinterpret_cast(base_rq); - // We can only allow to switch phase is the cookie match. - if (rq->cookie_ == cs->cookie()) { - rq->rc_ = cs->BuildPhaseDone(); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); - } - } - break; - } - default: - base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); - } - // Notify it is done, and move on to the next request. - base_rq->wp_.Set(); - } - return Status::OK(); -} -CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) - : top_(spill_path), num_workers_(num_workers) {} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/dataset/engine/cache/cache_server.h deleted file mode 100644 index f83fa1cb6d..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_server.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -#ifndef DATASET_ENGINE_CACHE_SERVER_H_ -#define DATASET_ENGINE_CACHE_SERVER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/engine/cache/cache_service.h" -#include "dataset/core/tensor.h" -#include "dataset/util/arena.h" -#include "dataset/util/cache_pool.h" -#include "dataset/util/lock.h" -#include "dataset/util/service.h" -#include "dataset/util/services.h" -#include "dataset/util/system_pool.h" -#include "dataset/util/queue.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -class BaseRequest; -/// \brief A server which provides CacheService services. -class CacheServer : public Service { - public: - friend class Services; - using cache_index = std::map>; - - CacheServer(const CacheServer &) = delete; - CacheServer &operator=(const CacheServer &) = delete; - CacheServer(CacheServer &&) = delete; - CacheServer &operator=(CacheServer &) = delete; - static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } - Status DoServiceStart() override; - Status DoServiceStop() override; - ~CacheServer() { (void)ServiceStop(); } - - /// \brief For the current demonstration, a cache client contacts cache server using a Queue. - /// \param rq - /// \return Status object - Status PushRequest(BaseRequest *rq) { - RETURN_UNEXPECTED_IF_NULL(rq); - RETURN_IF_NOT_OK(cache_q_->Add(rq)); - return Status::OK(); - } - - private: - mutable RWLock rwLock_; - std::string top_; - cache_index all_caches_; - std::shared_ptr> cache_q_; - TaskGroup vg_; - int32_t num_workers_; - - /// \brief Constructor - /// \param spill_path Top directory for spilling buffers to. - /// \param num_workers Number of threads for handling requests. - explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); - - /// \brief Locate a cache service from connection id. - /// \return Pointer to cache service. Null if not found - CacheService *GetService(connection_id_type id) const; - - /// \brief Create a cache service. We allow multiple clients to create the same cache service. - /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given - /// a special unique cookie. - /// \param[in] connection_id This is from a Cache client. - /// \param[in] cache_mem_sz - /// \param[in] flag - /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator - /// \return Status object - Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, - std::string *out_cookie); - - /// \brief Entry point for all server threads. 
- Status ServerRequest(); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CORE_CACHE_TENSOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/dataset/engine/cache/cache_service.cc deleted file mode 100644 index 555413a566..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_service.cc +++ /dev/null @@ -1,265 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/engine/cache/cache_service.h" -#include "dataset/util/slice.h" - -namespace mindspore { -namespace dataset { -CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) - : root_(root), - cache_mem_sz_(mem_sz), - cp_(nullptr), - map_(nullptr), - next_id_(0), - generate_id_(generate_id), - schema_key_(-1), - st_(generate_id ? State::kBuildPhase : State::kNone) {} -CacheService::~CacheService() { (void)ServiceStop(); } -bool CacheService::UseArena() { - // If fixed size, use Arena instead of the pool from global context. - return (cache_mem_sz_ > 0); -} -Status CacheService::DoServiceStart() { - std::shared_ptr mp_; - if (UseArena()) { - // Create a fixed size arena based on the parameter. - std::shared_ptr arena; - RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); - mp_ = std::move(arena); - } else { - // Unlimited size. Simply use a system pool. Another choice is CircularPool. - mp_ = std::make_shared(); - } - // Put together a CachePool for backing up the Tensor - cp_ = std::make_shared(CachePool::value_allocator(mp_), root_); - RETURN_IF_NOT_OK(cp_->ServiceStart()); - // Set up the B+ tree as well. But use the system pool instead. - map_ = std::make_shared(); - // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. - cookie_ = cp_->MyName(); - return Status::OK(); -} -Status CacheService::DoServiceStop() { - if (cp_ != nullptr) { - RETURN_IF_NOT_OK(cp_->ServiceStop()); - } - return Status::OK(); -} -Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) { - SharedLock rw(&rw_lock_); - RETURN_UNEXPECTED_IF_NULL(row_id_generated); - if (st_ == State::kFetchPhase) { - // For this kind of cache service, once we are done with the build phase into fetch phase, we can't - // allow other to cache more rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - try { - // The first buffer is a flatbuffer which describes the rest of the buffers follow - auto fb = buf.front(); - RETURN_UNEXPECTED_IF_NULL(fb); - auto msg = GetTensorRowHeaderMsg(fb); - // If the server side is designed to ignore incoming row id, we generate row id. - if (generate_id_) { - *row_id_generated = GetNextRowId(); - // Some debug information on how many rows we have generated so far. 
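
// Illustrative note, not part of the original file: GetNextRowId() is a
// single fetch_add on the atomic next_id_ counter, so concurrent writers
// obtain unique, monotonically increasing row ids without taking the
// service's reader-writer lock.
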
- if ((*row_id_generated) % 1000 == 0) { - MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; - } - } else { - if (msg->row_id() < 0) { - std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); - RETURN_STATUS_UNEXPECTED(errMsg); - } - *row_id_generated = msg->row_id(); - } - auto size_of_this = msg->size_of_this(); - auto column_hdr = msg->column(); - // Number of tensor buffer should match the number of columns plus one. - if (buf.size() != column_hdr->size() + 1) { - std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) + - " but get " + std::to_string(buf.size()); - RETURN_STATUS_UNEXPECTED(errMsg); - } - // Next we store in either memory or on disk. Low level code will consolidate everything in one piece. - std::vector all_data; - all_data.reserve(column_hdr->size() + 1); - all_data.emplace_back(fb, size_of_this); - for (auto i = 0; i < column_hdr->size(); ++i) { - all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); - } - // Now we cache the flat buffer. - CachePool::key_type key; - RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); - Status rc = map_->DoInsert(*row_id_generated, key); - if (rc == Status(StatusCode::kDuplicateKey)) { - MS_LOG(DEBUG) << "Ignoring duplicate key."; - } else { - RETURN_IF_NOT_OK(rc); - } - return Status::OK(); - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } -} -std::ostream &operator<<(std::ostream &out, const CacheService &cs) { - // Then show any custom derived-internal stuff - out << "\nCache memory size: " << cs.cache_mem_sz_; - out << "\nSpill path: "; - if (cs.root_.empty()) { - out << "None"; - } else { - out << cs.GetSpillPath(); - } - return out; -} -Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } -Status CacheService::Purge() { - // First we must lock exclusively. No one else can cache/restore anything. - UniqueLock rw(&rw_lock_); - RETURN_IF_NOT_OK(cp_->ServiceStop()); - auto new_map = std::make_shared(); - map_.reset(); - map_ = std::move(new_map); - next_id_ = 0; - RETURN_IF_NOT_OK(cp_->ServiceStart()); - return Status::OK(); -} -Status CacheService::GetStat(CacheService::ServiceStat *out) { - SharedLock rw(&rw_lock_); - RETURN_UNEXPECTED_IF_NULL(out); - if (st_ == State::kNone || st_ == State::kFetchPhase) { - out->stat_ = cp_->GetStat(); - out->state_ = static_cast(st_); - auto it = map_->begin(); - if (it != map_->end()) { - out->min_ = it.key(); - auto end_it = map_->end(); - --end_it; - out->max_ = end_it.key(); - } - } else { - out->state_ = static_cast(st_); - } - return Status::OK(); -} -Status CacheService::BatchFetch(const std::vector &v, MemGuard *out) const { - RETURN_UNEXPECTED_IF_NULL(out); - SharedLock rw(&rw_lock_); - if (st_ == State::kBuildPhase) { - // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. 
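
// Illustrative note, not part of the original file: note the locking split in
// CacheService. CacheRow() and BatchFetch() take a SharedLock so inserts and
// fetches of different rows can run concurrently; only Purge() and
// BuildPhaseDone() take the exclusive UniqueLock, because they replace the
// row index or flip the build/fetch state.
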
- RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - const auto num_elements = v.size(); - int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); - int64_t data_offset = mem_sz; - std::vector sz_v; - std::vector keys; - sz_v.reserve(num_elements); - keys.reserve(num_elements); - for (auto row_id : v) { - auto r = map_->Search(row_id); - if (r.second) { - auto &it = r.first; - CachePool::key_type key = it.value(); - auto sz = cp_->GetSize(key); - if (sz == 0) { - std::string errMsg = "Key not found: "; - errMsg += std::to_string(key); - RETURN_STATUS_UNEXPECTED(errMsg); - } - keys.push_back(key); - sz_v.push_back(sz); - mem_sz += sz; - } else { - keys.push_back(-1); - sz_v.push_back(0); - } - } - MemGuard mem; - RETURN_IF_NOT_OK(mem.allocate(mem_sz)); - auto *offset_array = reinterpret_cast(mem.GetMutablePointer()); - offset_array[0] = data_offset; - WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); - for (auto i = 0; i < num_elements; ++i) { - auto sz = sz_v.at(i); - offset_array[i + 1] = offset_array[i] + sz; - if (sz > 0) { - WritableSlice row_data(all, offset_array[i], sz); - auto key = keys.at(i); - size_t bytesRead = 0; - RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); - if (bytesRead != sz) { - MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." - << " Internal key: " << key << "\n"; - RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); - } - } - } - *out = std::move(mem); - return Status::OK(); -} -Status CacheService::CacheSchema(const void *buf, int64_t len) { - SharedLock rw(&rw_lock_); - if (st_ == State::kFetchPhase) { - // For this kind of cache service, once we are done with the build phase into fetch phase, we can't - // allow other to cache more rows. - RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - // This is a special request and we need to remember where we store it. - // In case we are calling the same function from multiple threads, only - // the first one is considered. Rest is ignored. - CachePool::key_type cur_key = schema_key_; - CachePool::key_type key; - if (cur_key < 0) { - RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); - auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); - MS_LOG(DEBUG) << "Caching Schema. Result = " << result; - } else { - MS_LOG(DEBUG) << "Caching Schema already done"; - } - return Status::OK(); -} -Status CacheService::FetchSchema(MemGuard *out) const { - SharedLock rw(&rw_lock_); - if (st_ == State::kBuildPhase) { - // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. 
- RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); - } - RETURN_UNEXPECTED_IF_NULL(out); - MemGuard mem; - if (schema_key_ >= 0) { - auto len = cp_->GetSize(schema_key_); - RETURN_IF_NOT_OK(mem.allocate(len)); - auto slice = WritableSlice(mem.GetMutablePointer(), len); - RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); - *out = std::move(mem); - } else { - return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); - } - return Status::OK(); -} -Status CacheService::BuildPhaseDone() { - if (HasBuildPhase()) { - // Exclusive lock to switch phase - UniqueLock rw(&rw_lock_); - st_ = State::kFetchPhase; - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/dataset/engine/cache/cache_service.h deleted file mode 100644 index 60cfa40a50..0000000000 --- a/mindspore/ccsrc/dataset/engine/cache/cache_service.h +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ - -#ifndef DATASET_ENGINE_CACHE_SERVICE_H_ -#define DATASET_ENGINE_CACHE_SERVICE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "./de_tensor_generated.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/cache/cache_request.h" -#include "dataset/util/arena.h" -#include "dataset/util/btree.h" -#include "dataset/util/cache_pool.h" -#include "dataset/util/service.h" -#include "dataset/util/services.h" -#include "dataset/util/system_pool.h" - -namespace mindspore { -namespace dataset { -struct CacheStat; -/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is -/// created to support spilling -class CacheService : public Service { - public: - friend class CacheServer; - using row_map = BPlusTree; - - enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; - - /// \brief Constructor - /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited - /// \param root Spill path. Empty string means no spilling - /// \param generate_id If the cache service should generate row id for buffer that is cached. - /// For non-mappable dataset, this should be set to true. - CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); - ~CacheService(); - - /// \brief For fixed size memory, we will create an Arena. - /// \return false if unlimited memory. - bool UseArena(); - - Status DoServiceStart() override; - Status DoServiceStop() override; - - /// \brief Main function to cache a row which is in form a series of buffers. - /// The first buffer is a Google flatbuffer which describes the rest of the buffers followed. 
diff --git a/mindspore/ccsrc/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/dataset/engine/cache/cache_service.h
deleted file mode 100644
index 60cfa40a50..0000000000
--- a/mindspore/ccsrc/dataset/engine/cache/cache_service.h
+++ /dev/null
@@ -1,143 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
-
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
-
- * http://www.apache.org/licenses/LICENSE-2.0
-
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
-*/
-
-#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
-#define DATASET_ENGINE_CACHE_SERVICE_H_
-
-#include <atomic>
-#include <iostream>
-#include <memory>
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-
-#include "./de_tensor_generated.h"
-#include "dataset/core/global_context.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/cache/cache_request.h"
-#include "dataset/util/arena.h"
-#include "dataset/util/btree.h"
-#include "dataset/util/cache_pool.h"
-#include "dataset/util/service.h"
-#include "dataset/util/services.h"
-#include "dataset/util/system_pool.h"
-
-namespace mindspore {
-namespace dataset {
-struct CacheStat;
-/// \brief A cache service for storing/fetching buffers to an in-memory cache, which may spill to disk.
-/// The cache service is created to support spilling.
-class CacheService : public Service {
- public:
-  friend class CacheServer;
-  using row_map = BPlusTree<row_id_type, CachePool::key_type>;
-
-  enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
-
-  /// \brief Constructor
-  /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
-  /// \param root Spill path. Empty string means no spilling
-  /// \param generate_id If the cache service should generate row id for buffer that is cached.
-  /// For non-mappable dataset, this should be set to true.
-  CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
-  ~CacheService();
-
-  /// \brief For fixed size memory, we will create an Arena.
-  /// \return false if unlimited memory.
-  bool UseArena();
-
-  Status DoServiceStart() override;
-  Status DoServiceStop() override;
-
-  /// \brief Main function to cache a row which is in the form of a series of buffers.
-  /// The first buffer is a Google flatbuffer which describes the rest of the buffers that follow.
-  /// \param[in] buf Vector of buffer
-  /// \param[out] row_id_generated The row id assigned to this row if any
-  /// \return Status object
-  Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
-  /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
-  /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
-  /// \param[in] v A vector of row id.
-  /// \param[out] out A contiguous memory buffer that holds the requested rows.
-  /// \return Status object
-  Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
-
-  /// \brief Getter function
-  /// \return Spilling path
-  Path GetSpillPath() const;
-  /// \brief A structure returned from the cache server for statistics request.
-  class ServiceStat {
-   public:
-    using state_type = std::underlying_type<State>::type;
-    ServiceStat() : min_(0), max_(0), state_(0) {}
-    CachePool::CacheStat stat_{};
-    row_id_type min_;
-    row_id_type max_;
-    state_type state_;
-  };
-  /// \brief Statistics for the current service
-  /// \param[in/out] A pointer to a pre-allocated ServiceStat structure
-  /// \return Status Object
-  Status GetStat(ServiceStat *);
-  /// \brief Cache schema
-  /// \param buf A Google Flatbuffer that contains the schema
-  /// \param len size of the buffer
-  /// \return Status object
-  Status CacheSchema(const void *buf, int64_t len);
-  /// \brief Fetch schema
-  /// \param out A contiguous memory that contains the serialized form of schema.
-  /// \return Status object
-  Status FetchSchema(MemGuard<uint8_t> *out) const;
-  /// \brief Purge the content of a cache
-  /// \return Status object
-  Status Purge();
-  /// \brief Overload the << operator to print a cache service
-  /// \param out std::ostream
-  /// \param cs A cache service
-  /// \return std::ostream
-  friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
-  /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
-  /// is the creator
-  /// \return Cookie
-  std::string cookie() const { return cookie_; }
-  /// \brief If this cache service generates row ids for the buffers cached, it is divided into two phases, a build
-  /// phase and a read phase.
-  /// \return True if it has two phases.
-  bool HasBuildPhase() const { return generate_id_; }
-  /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
-  /// \return Status object
-  Status BuildPhaseDone();
-
- private:
-  mutable RWLock rw_lock_;
-  std::string root_;
-  uint64_t cache_mem_sz_;
-  std::shared_ptr<CachePool> cp_;
-  std::shared_ptr<row_map> map_;
-  std::atomic<row_id_type> next_id_;
-  bool generate_id_;
-  std::atomic<CachePool::key_type> schema_key_;
-  std::string cookie_;
-  State st_;
-
-  /// \brief Private function to generate a row id
-  /// \return Row id assigned.
-  row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_CACHE_SERVICE_H_
diff --git a/mindspore/ccsrc/dataset/engine/connector.h b/mindspore/ccsrc/dataset/engine/connector.h
deleted file mode 100644
index bd66172be5..0000000000
--- a/mindspore/ccsrc/dataset/engine/connector.h
+++ /dev/null
@@ -1,211 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_CONNECTOR_H_
-#define DATASET_ENGINE_CONNECTOR_H_
-
-#include <atomic>
-#include <mutex>
-#include <string>
-#include <utility>
-#include "dataset/util/task_manager.h"
-#include "dataset/util/queue.h"
-#include "dataset/util/services.h"
-#include "dataset/util/cond_var.h"
-
-namespace mindspore {
-namespace dataset {
-// Connector is a communication data structure between two groups of threads that
-// preserves ordering.
-//
-// Example use case:
-// An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list,
-// and pushing the processed elements to a Connector in any order, whichever thread finishes processing first.
-// If the consumer of the Connector is single threaded, then when the consumer pops the
-// elements from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9].
-//
-// Requirements:
-// 1. Each thread in the group of consumer or producer threads must be assigned ids starting from 0.
-// 2. If your multi-threaded program is not reading from a Connector class but
-//    wants to push to a Connector class, you must follow round-robin element distribution,
-//    i.e., thread-id 0 must have the first element, thread-id 1 has the second element,
-//    and so on; then each of these workers can push to the Connector class asynchronously in parallel.
-//
-// Blocking conditions:
-// 1. Connector.push(int, T) can block when the internal queue it's trying to push to is full.
-// 2. Connector.pop(int) can block when
-//    - The internal queue it's trying to pop from is empty.
-//    - The caller thread of pop() is not equal to expect_consumer_. This is to enforce
-//      the ordering.
-//
-// Future improvement:
-// 1. Fault tolerance: Right now, if one of the workers dies, the Connector will not work
-//    properly.
-template <typename T>
-class Connector {
- public:
-  // Name: Constructor
-  // Description: Initializing private members with the given input arguments.
-  //     expect_consumer_ and pop_from_ are initialized to 0 as part of
-  //     our requirements. We instantiate n_producers number of internal
-  //     queues so that each producer thread can push to its queue without
-  //     any sync overhead.
-  // @param n_producers The number of threads producing data into this Connector.
-  // @param n_consumers The number of threads consuming data from this Connector.
-  // @param queue_capacity The number of elements (DataBuffer) for each queue.
-  Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
-      : num_producers_(n_producers), num_consumers_(n_consumers) {
-    MS_LOG(DEBUG) << "A connector is created with " << n_producers << " producers and " << n_consumers
-                  << " consumers.";
-    my_name_ = Services::GetUniqueID();
-    // We require the consumers to have ids sequentially from 0 to num_consumers_ - 1;
-    // otherwise an ordered list of consumer ids would have to be passed here (not implemented yet).
-    expect_consumer_ = 0;
-
-    // Round-robin pop starts from index 0 of queues_.
-    pop_from_ = 0;
-
-    // Initialize queues_ to have num_producers_ number of queues.
-    // Each queue is a blocking queue and has the same queue_capacity.
-    queues_.Init(num_producers_, queue_capacity);
-  }
-
-  // Destructor of Connector
-  virtual ~Connector() = default;
-
-  // Get an element from the Connector.
-  // @note A call to pop() can block the caller thread; see the blocking conditions at the top of this file.
-  // @param worker_id The id of a worker thread calling this method.
-  // @param result The address of an object where the popped element will be placed.
-  virtual Status Pop(int32_t worker_id,  // The worker-id of the caller. See the requirements at the top of this file.
-                     T *result) noexcept {
-    {
-      MS_ASSERT(worker_id < num_consumers_);
-      std::unique_lock<std::mutex> lk(m_);
-      RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; }));
-      RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
-      pop_from_ = (pop_from_ + 1) % num_producers_;
-      out_buffers_count_++;
-      expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
-    }
-
-    cv_.NotifyAll();
-    return Status::OK();
-  }
-
-  // Add an element into the Connector without the overhead of synchronization.
-  // It may block when the internal queue is full.
-  // The element passed to this function will be copied into the internal queue.
-  // @param worker_id The id of a worker thread calling this method.
-  // @param el A const lvalue element to be passed/added/pushed.
-  Status Push(int32_t worker_id, const T &el) noexcept {
-    MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
-    MS_ASSERT(queues_[worker_id] != nullptr);
-    return (queues_[worker_id]->Add(el));
-  }
-
-  auto out_buffers_count() const { return out_buffers_count_.load(); }
-
-  // Add an element into the Connector without the overhead of synchronization.
-  // It may block when the internal queue is full.
-  // The element passed to this function will be forwarded into the internal queue.
-  // @param worker_id The id of a worker thread calling this method.
-  // @param el An element to be passed/added/pushed.
-  virtual Status Push(int32_t worker_id, T &&el) noexcept {
-    MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
-    MS_ASSERT(queues_[worker_id] != nullptr);
-    return (queues_[worker_id]->Add(std::forward<T>(el)));
-  }
-
-  // Resets the internal index tracking of the queue so that it can be used again with new inputs,
-  // starting from the beginning.
-  void Reset() {
-    for (int32_t i = 0; i < static_cast<int32_t>(queues_.size()); ++i) {
-      queues_[i]->ResetQue();
-    }
-    expect_consumer_ = 0;
-    pop_from_ = 0;
-    out_buffers_count_ = 0;
-    MS_LOG(DEBUG) << "Connector counters reset.";
-  }
-
-  void Print(std::ostream &out, bool showAll) const {
-    out << "\n--------- Connector ------------"
-        << "\nConnector Name      : " << my_name_ << "\nNumber of consumers : " << num_consumers_
-        << "\nNumber of producers : " << num_producers_ << "\n";
-  }
-
-  friend std::ostream &operator<<(std::ostream &out, const Connector &con) {
-    con.Print(out, false);
-    return out;
-  }
-
-  // Get current size of connector.
-  int32_t size() const {
-    int32_t size = 0;
-    for (int32_t i = 0; i < static_cast<int32_t>(queues_.size()); ++i) {
-      size += queues_[i]->size();
-    }
-    return size;
-  }
-
-  int32_t capacity() const {
-    int32_t capacity = 0;
-    for (int32_t i = 0; i < static_cast<int32_t>(queues_.size()); ++i) {
-      capacity += queues_[i]->capacity();
-    }
-    return capacity;
-  }
-
-  // Register the internal resources with the Task group for the interruption service.
-  // @param vg
-  // @return
-  Status Register(TaskGroup *vg) {
-    Status rc = queues_.Register(vg);
-    if (rc.IsOk()) {
-      rc = cv_.Register(vg->GetIntrpService());
-    }
-    return rc;
-  }
-
- protected:
-  std::string my_name_;
-
-  // A list of Queues that are thread safe.
-  QueueList<T> queues_;
-
-  // The consumer that we allow to get the next data from pop()
-  int32_t expect_consumer_;
-
-  // The index to the queues_ where the next data should be popped.
-  int32_t pop_from_;
-
-  int32_t num_producers_;
-  int32_t num_consumers_;
-
-  // Used in Pop() when a thread calls pop() but it is not the expect_consumer_.
-  std::mutex m_;
-  CondVar cv_;
-  std::atomic<std::int64_t> out_buffers_count_ = 0;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_CONNECTOR_H_
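The ordering contract documented at the top of this header is easiest to see with one consumer: if producers distribute elements round-robin by worker id, a single consumer pops them back in the original order. A hedged sketch (Demo is hypothetical and assumes the per-producer queues are sized to hold the demo elements):

```cpp
// Hypothetical demo of the round-robin contract with one consumer.
void Demo(Connector<int32_t> *conn, int32_t num_workers) {
  // Producer side: element k must be pushed by worker (k % num_workers).
  for (int32_t k = 0; k < 9; ++k) {
    (void)conn->Push(k % num_workers, k + 1);  // pushes 1..9 round-robin
  }
  // Consumer side: a single consumer (id 0) pops 1,2,...,9 back in order,
  // because Pop() cycles pop_from_ across the producer queues.
  for (int32_t k = 0; k < 9; ++k) {
    int32_t v = 0;
    (void)conn->Pop(0, &v);
  }
}
```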
diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.cc b/mindspore/ccsrc/dataset/engine/data_buffer.cc
deleted file mode 100644
index 718721b906..0000000000
--- a/mindspore/ccsrc/dataset/engine/data_buffer.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/data_buffer.h"
-#include "dataset/util/allocator.h"
-#include "dataset/core/global_context.h"
-#include "dataset/core/tensor.h"
-
-namespace mindspore {
-namespace dataset {
-// Name: Constructor #1
-// Description: This is the main constructor that is used for making a buffer
-DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
-
-// A method for debug printing of the buffer
-void DataBuffer::Print(std::ostream &out, bool show_all) const {
-  out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
-
-  // If the column counts are set then it means that data has been set into
-  // the tensor table. Display the tensor table here.
-  if (this->NumCols() > 0) {
-    out << "Tensor table:\n";
-    for (int32_t row = 0; row < DataBuffer::NumRows(); ++row) {
-      out << "Row # : " << row << "\n";
-      TensorRow currRow = (*tensor_table_)[row];
-      for (int32_t col = 0; col < this->NumCols(); ++col) {
-        out << "Column #: " << col << "\n";  // Should add the column name here as well?
-        // Call the tensor display
-        out << *(currRow[col]) << "\n";
-      }
-    }
-  }
-}
-
-// Remove me!! Callers should fetch rows via pop
-Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
-  if (static_cast<size_t>(row_id) < tensor_table_->size() &&
-      static_cast<size_t>(col_id) < tensor_table_->at(row_id).size()) {
-    *ptr = (tensor_table_->at(row_id)).at(col_id);
-  } else {
-    std::string err_msg =
-      "indices for the tensor table are out of range: (" + std::to_string(row_id) + "," + std::to_string(col_id) + ").";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  return Status::OK();
-}
-
-// Remove me!! Callers should fetch rows via pop
-Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const {
-  if (tensor_table_ && !tensor_table_->empty() && static_cast<size_t>(row_id) < tensor_table_->size()) {
-    *ptr = tensor_table_->at(row_id);
-  } else {
-    std::string err_msg = "row id for the tensor table is out of range: " + std::to_string(row_id);
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  return Status::OK();
-}
-
-Status DataBuffer::PopRow(TensorRow *ptr) {
-  if (tensor_table_ && !tensor_table_->empty()) {
-    *ptr = std::move(tensor_table_->front());
-    tensor_table_->pop_front();
-  }
-
-  return Status::OK();
-}
-
-Status DataBuffer::SliceOff(int64_t number_of_rows) {
-  while (number_of_rows > 0) {
-    tensor_table_->pop_back();
-    number_of_rows--;
-  }
-
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/data_buffer.h b/mindspore/ccsrc/dataset/engine/data_buffer.h
deleted file mode 100644
index b539bdaf7b..0000000000
--- a/mindspore/ccsrc/dataset/engine/data_buffer.h
+++ /dev/null
@@ -1,108 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATA_BUFFER_H_
-#define DATASET_ENGINE_DATA_BUFFER_H_
-
-#include <iostream>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "dataset/util/allocator.h"
-#include "dataset/util/status.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/tensor_row.h"
-
-namespace mindspore {
-namespace dataset {
-/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
-///     connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
-///     where n TensorRows may consist of m tensors (columns).
-class DataBuffer {
- public:
-  // Buffer flags
-  enum BufferFlags : uint32_t {
-    kDeBFlagNone = 0,
-    kDeBFlagEOF = 1,       // The buffer is an eof end-of-data msg
-    kDeBFlagEOE = 1u << 1  // The buffer is an eoe end-of-epoch msg
-  };
-
-  // Name: Constructor #1
-  // Description: This is the main constructor that is used for making a buffer
-  DataBuffer(int32_t id, BufferFlags flags);
-
-  /// \brief default destructor
-  ~DataBuffer() = default;
-
-  /// \brief A method for debug printing of the buffer
-  /// \param[inout] out The stream to write to
-  /// \param[in] show_all A boolean to toggle between details and summary printing
-  void Print(std::ostream &out, bool show_all) const;
-
-  // Provide stream operator for displaying it
-  friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
-    cb.Print(out, false);
-    return out;
-  }
-
-  // Convenience getter functions for flag checking
-  bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
-
-  bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
-
-  // Simple getter funcs
-  int32_t id() const { return buffer_id_; }
-
-  void set_id(int32_t id) { buffer_id_ = id; }
-
-  int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); }
-
-  int32_t NumCols() const {
-    return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size();
-  }
-
-  BufferFlags buffer_flags() const { return buffer_flags_; }
-
-  // Remove me!! Callers should fetch rows via pop
-  Status GetTensor(std::shared_ptr<Tensor> *, int32_t row_id, int32_t col_id) const;
-
-  // Remove me!! Callers should drain rows via pop.
-  Status GetRow(int32_t row_id, TensorRow *) const;
-
-  // Get a row from the TensorTable
-  Status PopRow(TensorRow *);
-
-  Status SliceOff(int64_t number_of_rows);
-
-  // Replacing the tensor table; the unique_ptr assignment will release the old TensorTable.
-  void set_tensor_table(std::unique_ptr<TensorQTable> new_table) { tensor_table_ = std::move(new_table); }
-
-  void set_flag(BufferFlags in_flag) {
-    buffer_flags_ = static_cast<BufferFlags>(static_cast<uint32_t>(buffer_flags_) | static_cast<uint32_t>(in_flag));
-  }
-
-  void Shuffle() {}  // does nothing right now. possibly remove later
-
- protected:
-  int32_t buffer_id_;                           // An id for the buffer.
-  std::unique_ptr<TensorQTable> tensor_table_;  // A table (row major) of Tensors
-  BufferFlags buffer_flags_;                    // bit mask for various buffer properties
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATA_BUFFER_H_
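The eoe/eof accessors above are plain bit tests, and set_flag() ORs new bits into the mask, so a buffer can carry both markers at once. A short hedged sketch (FlagDemo is hypothetical):

```cpp
// Hypothetical demo of the flag bitmask.
void FlagDemo() {
  DataBuffer buf(/*id=*/0, DataBuffer::kDeBFlagEOE);
  bool is_eoe = buf.eoe();                // true: bit (1u << 1) is set
  bool is_eof = buf.eof();                // false so far
  buf.set_flag(DataBuffer::kDeBFlagEOF);  // ORs in the eof bit; both are now set
  is_eof = buf.eof();                     // true
  (void)is_eoe;
  (void)is_eof;
}
```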
diff --git a/mindspore/ccsrc/dataset/engine/data_schema.cc b/mindspore/ccsrc/dataset/engine/data_schema.cc
deleted file mode 100644
index 6c5f882bed..0000000000
--- a/mindspore/ccsrc/dataset/engine/data_schema.cc
+++ /dev/null
@@ -1,451 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/data_schema.h"
-
-#include <algorithm>
-#include <fstream>
-#include <memory>
-#include <sstream>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "common/utils.h"
-#include "dataset/util/status.h"
-#include "dataset/core/tensor_shape.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace dataset {
-// A macro for converting an input string representing the column type to its actual
-// numeric column type.
-#define STR_TO_TENSORIMPL(in_col_str, out_type) \
-  do {                                          \
-    if (in_col_str == "cvmat") {                \
-      out_type = TensorImpl::kCv;               \
-    } else if (in_col_str == "flex") {          \
-      out_type = TensorImpl::kFlexible;         \
-    } else if (in_col_str == "np") {            \
-      out_type = TensorImpl::kNP;               \
-    } else {                                    \
-      out_type = TensorImpl::kNone;             \
-    }                                           \
-  } while (false)
-
-// Constructor 1: Simple constructor that leaves things uninitialized.
-ColDescriptor::ColDescriptor()
-    : type_(DataType::DE_UNKNOWN), rank_(0), tensor_impl_(TensorImpl::kNone), tensor_shape_(nullptr) {}
-
-// Constructor 2: Main constructor
-ColDescriptor::ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank,
-                             const TensorShape *in_shape)
-    : type_(col_type), rank_(rank), tensor_impl_(tensor_impl), col_name_(col_name) {
-  // If a shape was provided, create a unique pointer for it and copy construct it into
-  // our shape. Otherwise, set our shape to be empty.
-  if (in_shape != nullptr) {
-    // Create a shape and copy construct it into our column's shape.
-    tensor_shape_ = std::make_unique<TensorShape>(*in_shape);
-  } else {
-    tensor_shape_ = nullptr;
-  }
-  // If the user input a shape, then the rank of the input shape needs to match
-  // the input rank
-  if (in_shape != nullptr && in_shape->known() && in_shape->Size() != rank_) {
-    rank_ = in_shape->Size();
-    MS_LOG(WARNING) << "Rank does not match the number of dimensions in the provided shape."
-                    << " Overriding rank with the number of dimensions in the provided shape.";
-  }
-}
-
-// Explicit copy constructor is required
-ColDescriptor::ColDescriptor(const ColDescriptor &in_cd)
-    : type_(in_cd.type_), rank_(in_cd.rank_), tensor_impl_(in_cd.tensor_impl_), col_name_(in_cd.col_name_) {
-  // If it has a tensor shape, make a copy of it with our own unique_ptr.
-  tensor_shape_ = in_cd.hasShape() ? std::make_unique<TensorShape>(in_cd.shape()) : nullptr;
-}
-
-// Assignment overload
-ColDescriptor &ColDescriptor::operator=(const ColDescriptor &in_cd) {
-  if (&in_cd != this) {
-    type_ = in_cd.type_;
-    rank_ = in_cd.rank_;
-    tensor_impl_ = in_cd.tensor_impl_;
-    col_name_ = in_cd.col_name_;
-    // If it has a tensor shape, make a copy of it with our own unique_ptr.
-    tensor_shape_ = in_cd.hasShape() ? std::make_unique<TensorShape>(in_cd.shape()) : nullptr;
-  }
-  return *this;
-}
-
-// Destructor
-ColDescriptor::~ColDescriptor() = default;
-
-// A print method typically used for debugging
-void ColDescriptor::Print(std::ostream &out) const {
-  out << "  Name  : " << col_name_ << "\n  Type  : " << type_ << "\n  Rank  : " << rank_
-      << "\n  Shape : (";
-  if (tensor_shape_) {
-    out << *tensor_shape_ << ")\n";
-  } else {
-    out << "no shape provided)\n";
-  }
-}
-
-// Given a number of elements, this function will compute what the actual Tensor shape would be.
-// If there is no starting TensorShape in this column, or if there is a shape but it contains
-// an unknown dimension, then the output shape returned shall resolve dimensions as needed.
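Before the definition that follows, a compact summary of its three resolution paths (numbers are hypothetical):

```cpp
#include <cstdint>

// Worked example of the resolution rule described above (hypothetical numbers):
// - no stored shape, rank 0, 1 element -> scalar shape {}
// - no stored shape, 24 elements       -> {24}
// - stored shape {kDimUnknown, 4, 3} with 24 elements: known product = 4 * 3 = 12,
//   24 % 12 == 0, so the unknown dim becomes 24 / 12 = 2, materializing {2, 4, 3};
//   25 elements would fail the divisibility check with an error status.
int64_t ResolveUnknownDim(int64_t num_elements, int64_t known_product) {
  return (known_product > 0 && num_elements % known_product == 0) ? num_elements / known_product : -1;
}
```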
-Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const {
-  if (out_shape == nullptr) {
-    RETURN_STATUS_UNEXPECTED("Unexpected null output shape argument.");
-  }
-
-  // If the shape is not given in this column, then we assume the shape will be: {numElements}
-  if (tensor_shape_ == nullptr) {
-    if (this->rank() == 0 && num_elements == 1) {
-      *out_shape = TensorShape::CreateScalar();
-      return Status::OK();
-    }
-    *out_shape = TensorShape({num_elements});
-    return Status::OK();
-  }
-
-  // Build the real TensorShape based on the requested shape and the number of elements in the data.
-  // If there are unknown dimensions, then the unknown dimension needs to be filled in.
-  // Example: requestedShape: {?,4,3}.
-  // If numElements is 24, then the output shape can be computed to: {2,4,3}
-  std::vector<dsize_t> requested_shape = tensor_shape_->AsVector();
-  int64_t num_elements_of_shape = 1;  // init to 1 as a starting multiplier.
-
-  // The unknown_dim_position variable is overloaded to provide 2 meanings:
-  // 1) If it's set to kDimUnknown, then it provides the boolean knowledge that there are
-  //    no unknown dimensions at all.
-  // 2) If it's set to a numeric value, then this is the vector index position within the shape
-  //    where the single unknown dimension can be found.
-  int64_t unknown_dim_position = TensorShape::kDimUnknown;  // Assume there are no unknown dims to start
-
-  for (size_t i = 0; i < requested_shape.size(); ++i) {
-    // If we already had an unknown dimension, then we cannot have a second unknown dimension.
-    // We only support the compute of a single unknown dim.
-    if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) {
-      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                    "Requested shape has more than one unknown dimension!");
-    }
-
-    // If the current dimension in the requested shape is a known value, then compute the number of
-    // elements so far.
-    if (requested_shape[i] != TensorShape::kDimUnknown) {
-      num_elements_of_shape *= requested_shape[i];
-    } else {
-      // This dimension is unknown so track which dimension position has it.
-      unknown_dim_position = i;
-    }
-  }
-
-  // Sanity check that the computed element count divides evenly into the input element count
-  if (num_elements < num_elements_of_shape || num_elements_of_shape == 0 ||
-      num_elements % num_elements_of_shape != 0) {
-    RETURN_STATUS_UNEXPECTED("Requested shape has an invalid element count!");
-  }
-
-  // If there was an unknown dimension, then update the requested shape to fill in the unknown
-  // dimension with the correct value. If there were no unknown dims then the output shape will
-  // remain the same as the requested shape.
-  if (unknown_dim_position != TensorShape::kDimUnknown) {
-    requested_shape[unknown_dim_position] = (num_elements / num_elements_of_shape);
-  }
-
-  // Any unknown dimension is filled in now. Set the output shape
-  *out_shape = TensorShape(requested_shape);
-  return Status::OK();
-}
-
-// getter function for the shape
-TensorShape ColDescriptor::shape() const {
-  if (tensor_shape_ != nullptr) {
-    return *tensor_shape_;  // copy construct a shape to return
-  } else {
-    return TensorShape::CreateUnknownRankShape();  // empty shape to return
-  }
-}
-
-const char DataSchema::DEFAULT_DATA_SCHEMA_FILENAME[] = "datasetSchema.json";
-
-// Constructor 1: Simple constructor that leaves things uninitialized.
-DataSchema::DataSchema() : num_rows_(0) {} - -// Internal helper function. Parses the json schema file in any order and produces a schema that -// does not follow any particular order (json standard does not enforce any ordering protocol). -// This one produces a schema that contains all of the columns from the schema file. -Status DataSchema::AnyOrderLoad(nlohmann::json column_tree) { - // Iterate over the json file. Each parent json node is the column name, - // followed by the column properties in the child tree under the column. - // Outer loop here iterates over the parents (i.e. the column name) - if (!column_tree.is_array()) { - for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { - std::string col_name = it.key(); - nlohmann::json column_child_tree = it.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); - } - } else { - // Case where the schema is a list of columns not a dict - for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) { - nlohmann::json column_child_tree = it.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, "")); - } - } - return Status::OK(); -} - -// Internal helper function. For each input column name, perform a lookup to the json document to -// find the matching column. When the match is found, process that column to build the column -// descriptor and add to the schema in the order in which the input column names are given.id -Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load) { - if (!column_tree.is_array()) { - // the json file is dict (e.g., {image: ...}) - // Loop over the column name list - for (const auto &curr_col_name : columns_to_load) { - // Find the column in the json document - auto column_info = column_tree.find(common::SafeCStr(curr_col_name)); - if (column_info == column_tree.end()) { - RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); - } - // At this point, columnInfo.value() is the subtree in the json document that contains - // all of the data for a given column. This data will formulate our schema column. - const std::string &col_name = column_info.key(); - nlohmann::json column_child_tree = column_info.value(); - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name)); - } - } else { - // the json file is array (e.g., [name: image...]) - // Loop over the column name list - for (const auto &curr_col_name : columns_to_load) { - // Find the column in the json document - int32_t index = -1; - int32_t i = 0; - for (const auto &it_child : column_tree.items()) { - auto name = it_child.value().find("name"); - if (name == it_child.value().end()) { - RETURN_STATUS_UNEXPECTED("Name field is missing for this column."); - } - if (name.value() == curr_col_name) { - index = i; - break; - } - i++; - } - if (index == -1) { - RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); - } - nlohmann::json column_child_tree = column_tree[index]; - RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, curr_col_name)); - } - } - return Status::OK(); -} - -// Internal helper function for parsing shape info and building a vector for the shape construction. 
-static Status buildShape(const nlohmann::json &shapeVal, std::vector<dsize_t> *outShape) {
-  if (outShape == nullptr) {
-    RETURN_STATUS_UNEXPECTED("null output shape");
-  }
-  if (shapeVal.empty()) return Status::OK();
-
-  // Iterate over the integer list and add those values to the output shape tensor
-  auto items = shapeVal.items();
-  using it_type = decltype(items.begin());
-  (void)std::transform(items.begin(), items.end(), std::back_inserter(*outShape), [](it_type j) { return j.value(); });
-  return Status::OK();
-}
-
-// Internal helper function. Given the json tree for a given column, load it into our schema.
-Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name) {
-  int32_t rank_value = -1;
-  TensorImpl t_impl_value = TensorImpl::kFlexible;
-  std::string name, type_str;
-  std::vector<dsize_t> tmp_shape = {};
-  bool shape_field_exists = false;
-  // Iterate over this column's attributes.
-  // Manually iterating each of the child nodes/trees here so that we can provide our own error handling.
-  for (const auto &it_child : column_child_tree.items()) {
-    // Save the data for each of the attributes into variables. We'll use these to construct later.
-    if (it_child.key() == "name") {
-      name = it_child.value();
-    } else if (it_child.key() == "type") {
-      type_str = it_child.value();
-    } else if (it_child.key() == "rank") {
-      rank_value = it_child.value();
-    } else if (it_child.key() == "t_impl") {
-      STR_TO_TENSORIMPL(it_child.value(), t_impl_value);
-    } else if (it_child.key() == "shape") {
-      shape_field_exists = true;
-      RETURN_IF_NOT_OK(buildShape(it_child.value(), &tmp_shape));
-    } else {
-      std::string err_msg = "Unexpected column attribute " + it_child.key() + " for column " + col_name;
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-  }
-  if (!name.empty()) {
-    if (!col_name.empty() && col_name != name) {
-      std::string err_msg =
-        "json schema file for column " + col_name + " has column name that does not match columnsToLoad";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-  } else {
-    if (col_name.empty()) {
-      std::string err_msg = "json schema file for column " + col_name + " has invalid or missing column name.";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    } else {
-      name = col_name;
-    }
-  }
-  // data type is a mandatory field
-  if (type_str.empty())
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "json schema file for column " + col_name + " has invalid or missing column type.");
-
-  // rank number is a mandatory field
-  if (rank_value <= -1)
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "json schema file for column " + col_name + " must define a non-negative rank value.");
-
-  // Create the column descriptor for this column from the data we pulled from the json file
-  TensorShape col_shape = TensorShape(tmp_shape);
-  if (shape_field_exists)
-    (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value, &col_shape));
-  else
-    // Create a column descriptor that doesn't have a shape
-    (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value));
-  return Status::OK();
-}
-
-// Parses a schema json file and populates the columns and meta info.
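Ahead of the loaders that follow, a hypothetical schema string in the dict form this parser accepts (column names and type strings are illustrative; "numRows" is optional, while "columns" and per-column "type"/"rank" are mandatory):

```cpp
// Hypothetical schema string in the dict form parsed above. "shape" may carry
// one unknown (-1) dimension, and "t_impl" is optional.
const char kExampleSchema[] = R"({
  "numRows": 2,
  "columns": {
    "image": { "type": "uint8", "rank": 1 },
    "label": { "type": "int32", "rank": 1, "shape": [1] }
  }
})";
// Hedged usage: DataSchema ds;  ds.LoadSchemaString(kExampleSchema, {});
```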
-Status DataSchema::LoadSchemaFile(const std::string &schema_file_path, - const std::vector &columns_to_load) { - try { - std::ifstream in(schema_file_path); - - nlohmann::json js; - in >> js; - RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); - try { - num_rows_ = js.at("numRows").get(); - } catch (nlohmann::json::out_of_range &e) { - num_rows_ = 0; - } catch (nlohmann::json::exception &e) { - RETURN_STATUS_UNEXPECTED("Unable to parse \"numRows\" from schema"); - } - nlohmann::json column_tree = js.at("columns"); - if (column_tree.empty()) { - RETURN_STATUS_UNEXPECTED("columns is null"); - } - if (columns_to_load.empty()) { - // Parse the json tree and load the schema's columns in whatever order that the json - // layout decides - RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); - } else { - RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Schema file failed to load"); - } - return Status::OK(); -} - -// Parses a schema json string and populates the columns and meta info. -Status DataSchema::LoadSchemaString(const std::string &schema_json_string, - const std::vector &columns_to_load) { - try { - nlohmann::json js = nlohmann::json::parse(schema_json_string); - RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); - num_rows_ = js.value("numRows", 0); - nlohmann::json column_tree = js.at("columns"); - if (column_tree.empty()) { - RETURN_STATUS_UNEXPECTED("columns is null"); - } - if (columns_to_load.empty()) { - // Parse the json tree and load the schema's columns in whatever order that the json - // layout decides - RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); - } else { - RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Schema file failed to load"); - } - return Status::OK(); -} - -// Destructor -DataSchema::~DataSchema() = default; - -// Getter for the ColDescriptor by index -const ColDescriptor &DataSchema::column(int32_t idx) const { - MS_ASSERT(idx < static_cast(col_descs_.size())); - return col_descs_[idx]; -} - -// A print method typically used for debugging -void DataSchema::Print(std::ostream &out) const { - out << "Dataset schema: ("; - for (const auto &col_desc : col_descs_) { - out << col_desc << "\n"; - } -} - -// Adds a column descriptor to the schema -Status DataSchema::AddColumn(const ColDescriptor &cd) { - // Sanity check there's not a duplicate name before adding the column - for (int32_t i = 0; i < col_descs_.size(); ++i) { - if (col_descs_[i].name() == cd.name()) { - std::ostringstream ss; - ss << "column name '" << cd.name() << "' already exists in schema."; - std::string err_msg = ss.str(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - col_descs_.push_back(cd); - return Status::OK(); -} - -// Internal helper function. Performs sanity checks on the json file setup. -Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) { - // Check if columns node exists. It is required for building schema from file. - if (js.find("columns") == js.end()) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "\"columns\" node is required in the schema json file."); - return Status::OK(); -} - -// Loops through all columns in the schema and returns a map with the column -// name to column index number. 
-Status DataSchema::GetColumnNameMap(std::unordered_map *out_column_name_map) { - if (out_column_name_map == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map."); - } - - for (int32_t i = 0; i < col_descs_.size(); ++i) { - if (col_descs_[i].name().empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Constructing column name map from schema, but found empty column name."); - } - (*out_column_name_map)[col_descs_[i].name()] = i; - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/data_schema.h b/mindspore/ccsrc/dataset/engine/data_schema.h deleted file mode 100644 index ce61b8952d..0000000000 --- a/mindspore/ccsrc/dataset/engine/data_schema.h +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATA_SCHEMA_H_ -#define DATASET_ENGINE_DATA_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -/// \class ColDescriptor data_schema.h -/// \brief A simple class to provide meta info about a column. -class ColDescriptor { - public: - /// \brief Constructor 1: Simple constructor that leaves things uninitialized. - ColDescriptor(); - - /// \brief Constructor 2: Main constructor - /// \param[in] col_name - The name of the column - /// \param[in] col_type - The DE Datatype of the column - /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column - /// \param[in] rank - The number of dimension of the data - /// \param[in] in_shape - option argument for input shape - ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, - const TensorShape *in_shape = nullptr); - - /// \brief Explicit copy constructor is required - /// \param[in] in_cd - the source ColDescriptor - ColDescriptor(const ColDescriptor &in_cd); - - /// \brief Assignment overload - /// \param in_cd - the source ColDescriptor - ColDescriptor &operator=(const ColDescriptor &in_cd); - - /// \brief Destructor - ~ColDescriptor(); - - /// \brief A print method typically used for debugging - /// \param out - The output stream to write output to - void Print(std::ostream &out) const; - - /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be. - /// If there is no starting TensorShape in this column, or if there is a shape but it contains - /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. 
- /// \param[in] num_elements - The number of elements in the data for a Tensor - /// \param[inout] out_shape - The materialized output Tensor shape - /// \return Status - The error code return - Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; - - /// \brief << Stream output operator overload - /// This allows you to write the debug print info using stream operators - /// \param[in] out - reference to the output stream being overloaded - /// \param[in] cd - reference to the ColDescriptor to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) { - cd.Print(out); - return out; - } - - /// \brief getter function - /// \return The column's DataType - DataType type() const { return type_; } - - /// \brief getter function - /// \return The column's rank - int32_t rank() const { return rank_; } - - /// \brief getter function - /// \return The column's name - std::string name() const { return col_name_; } - - /// \brief getter function - /// \return The column's shape - TensorShape shape() const; - - /// \brief getter function - /// \return TF if the column has an assigned fixed shape. - bool hasShape() const { return tensor_shape_ != nullptr; } - - /// \brief getter function - /// \return The column's tensor implementation type - TensorImpl tensorImpl() const { return tensor_impl_; } - - private: - DataType type_; // The columns type - int32_t rank_; // The rank for this column (number of dimensions) - TensorImpl tensor_impl_; // The initial flavour of the tensor for this column - std::unique_ptr tensor_shape_; // The fixed shape (if given by user) - std::string col_name_; // The name of the column -}; - -/// \class DataSchema data_schema.h -/// \brief A list of the columns. -class DataSchema { - public: - /// \brief Constructor - DataSchema(); - - /// \brief Destructor - ~DataSchema(); - - /// \brief Parses a schema json file and populates the columns and meta info. - /// \param[in] schema_file_path - the schema file that has the column's info to load - /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. - /// \return Status - The error code return - Status LoadSchemaFile(const std::string &schema_file_path, const std::vector &columns_to_load); - - /// \brief Parses a schema JSON string and populates the columns and meta info. - /// \param[in] schema_json_string - the schema file that has the column's info to load - /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. - /// \return Status - The error code return - Status LoadSchemaString(const std::string &schema_json_string, const std::vector &columns_to_load); - - /// \brief A print method typically used for debugging - /// \param[in] out - The output stream to write output to - void Print(std::ostream &out) const; - - /// \brief << Stream output operator overload. 
This allows you to write the debug print info using stream operators - /// \param[in] out - reference to the output stream being overloaded - /// \param[in] ds - reference to the DataSchema to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) { - ds.Print(out); - return out; - } - - /// \brief Adds a column descriptor to the schema - /// \param[in] cd - The ColDescriptor to add - /// \return Status - The error code return - Status AddColumn(const ColDescriptor &cd); - - /// \brief getter - /// \return The reference to a ColDescriptor to get (const version) - const ColDescriptor &column(int32_t idx) const; - - /// \brief getter - /// \return The number of columns in the schema - int32_t NumColumns() const { return col_descs_.size(); } - - bool Empty() const { return NumColumns() == 0; } - - /// \brief getter - /// \return The number of rows read from schema - int64_t num_rows() const { return num_rows_; } - - static const char DEFAULT_DATA_SCHEMA_FILENAME[]; - - /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. - /// \param[inout] out_column_name_map - The output map of columns names to column index - /// \return Status - The error code return - Status GetColumnNameMap(std::unordered_map *out_column_name_map); - - private: - /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that - /// does not follow any particular order (json standard does not enforce any ordering protocol). - /// This one produces a schema that contains all of the columns from the schema file. - /// \param[in] column_tree - The nlohmann tree from the json file to parse - /// \return Status - The error code return - Status AnyOrderLoad(nlohmann::json column_tree); - - /// \brief Internal helper function. For each input column name, perform a lookup to the json document to - /// find the matching column. When the match is found, process that column to build the column - /// descriptor and add to the schema in the order in which the input column names are given. - /// \param[in] column_tree - The nlohmann tree from the json file to parse - /// \param[in] columns_to_load - list of strings for the columns to add to the schema - /// \return Status - The error code return - Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load); - - /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. - /// \param[in] columnTree - The nlohmann child tree for a given column to load. - /// \param[in] col_name - The string name of the column for that subtree. - /// \return Status - The error code return - Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); - - /// \brief Internal helper function. Performs sanity checks on the json file setup. 
- /// \param[in] js - The nlohmann tree for the schema file - /// \return Status - The error code return - Status PreLoadExceptionCheck(const nlohmann::json &js); - - std::vector col_descs_; // Vector of column descriptors - int64_t num_rows_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATA_SCHEMA_H_ diff --git a/mindspore/ccsrc/dataset/engine/dataset_iterator.cc b/mindspore/ccsrc/dataset/engine/dataset_iterator.cc deleted file mode 100644 index be333741b1..0000000000 --- a/mindspore/ccsrc/dataset/engine/dataset_iterator.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/dataset_iterator.h" -#include -#include -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/status.h" -#include "dataset/engine/datasetops/dataset_op.h" - -namespace mindspore { -namespace dataset { -// Constructor of the IteratorBase -IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {} - -IteratorBase::~IteratorBase() = default; - -// Fetches one row of data from the iterator as a column map. -Status IteratorBase::GetNextAsMap(TensorMap *out_map) { - if (out_map == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output map in iterator!"); - } - - out_map->clear(); - - TensorRow curr_row; - RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); - - // Return empty map if there's no data - if (curr_row.empty()) { - return Status::OK(); - } - - // The column name mapping is needed to be able to produce the tensor map output. - // The column name mapping comes from the source operator that is producing the data into the iterator. - // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping - // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping. - if (col_name_id_map_.empty()) { - // Determine the column name map by calling the derived class method to retrieve the column - // name map - col_name_id_map_ = this->GetColumnNameMap(); - } - - // Populate the out map from the row and return it - for (auto colMap : col_name_id_map_) { - (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); - } - - return Status::OK(); -} - -// Fetches one row of data from the iterator. -// The base class version simply performs error handling and returns empty row. Actual -// functionality exists in the derived versions of this function. 
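Before the base-class implementation that follows, a hedged sketch of how GetNextAsMap() above is typically driven until the empty row that signals eoe/eof (DrainTree and the "image" column are hypothetical):

```cpp
// Hypothetical consumer loop; tree construction is elided. An empty map from
// GetNextAsMap() means end of data was reached.
Status DrainTree(const std::shared_ptr<ExecutionTree> &exe_tree) {
  DatasetIterator itr(exe_tree);
  TensorMap row;
  RETURN_IF_NOT_OK(itr.GetNextAsMap(&row));
  while (!row.empty()) {
    std::shared_ptr<Tensor> t = row["image"];  // column name -> tensor
    (void)t;
    RETURN_IF_NOT_OK(itr.GetNextAsMap(&row));
  }
  return Status::OK();
}
```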
-Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) { - if (out_row == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output row in iterator!"); - } - - // clear the old tensor row - out_row->clear(); - - return Status::OK(); -} - -// Constructor of the DatasetIterator -DatasetIterator::DatasetIterator(std::shared_ptr exe_tree) - : IteratorBase(), - root_(exe_tree->root()), - tracing_(nullptr), - cur_batch_num_(0), - cur_connector_size_(0), - cur_connector_capacity_(0) { - std::shared_ptr node; - Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); - if (s.IsOk()) { - tracing_ = std::dynamic_pointer_cast(node); - } -} - -DatasetIterator::~DatasetIterator() = default; - -// Fetches one row of data from the iterator. Overrides the base class. This one fetches -// from the tree root node directly. -Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { - // Common code init and error checking in the base class. - RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); - - // Once eof is handled, always return empty row. Class must be destroyed and recreated if you - // want to iterate again. - if (eof_handled_) { - return Status::OK(); - } - - // Check if we need to get a new DataBuffer to iterate. - if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { - if (tracing_ != nullptr) { - cur_connector_size_ = root_->ConnectorSize(); - cur_connector_capacity_ = root_->ConnectorCapacity(); - } - RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); - - // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually - // handle eoe and eof messages here. - // - // An eoe buffer means we have iterated fully to the end of the tree. - // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of - // all operators. - if (curr_buffer_->eoe()) { - MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; - - // Before returning the last empty vector, fetch the eof buffer which should be the last - // buffer, and then free it. - RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); - - if (!curr_buffer_->eof()) { - RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); - } - eof_handled_ = true; - curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - - return Status::OK(); - } - - if (curr_buffer_->eof()) { - // An eof by itself, without being preceded by an eoe, is possible if a repeat operator - // exists below us in the stack. Repeat operator eats eoe's but eventually allows the - // flow of an eof up the pipeline by itself. 
- eof_handled_ = true; - curr_buffer_.reset(); // explicitly free the eof buffer - // Set tree to Finished state - root_->Tree()->SetFinished(); - return Status::OK(); - } - } - - // If we got this far, now it's time to pop that next row for return to caller - RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); - if (tracing_ != nullptr) { - cur_batch_num_++; - tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_); - } - return Status::OK(); -} - -Status DatasetIterator::GetOutputShapes(std::vector *out_shapes) { - if (out_shapes == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output shape argument"); - } - if (device_queue_row_.empty()) { - RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); - } - for (auto ts : device_queue_row_) { - out_shapes->push_back(ts->shape()); - } - - return Status::OK(); -} - -Status DatasetIterator::GetOutputTypes(std::vector *out_types) { - if (out_types == nullptr) { - RETURN_STATUS_UNEXPECTED("Null output type argument"); - } - if (device_queue_row_.empty()) { - RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_)); - } - for (auto ts : device_queue_row_) { - out_types->push_back(ts->type()); - } - return Status::OK(); -} - -// Getter -std::unordered_map DatasetIterator::GetColumnNameMap() const { - return root_->column_name_id_map(); -} - -// Constructor of the ChildIterator -ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx) - : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {} - -ChildIterator::~ChildIterator() { current_op_ = nullptr; } - -// Fetches one row of data from the iterator. Overrides the base class. This one fetches -// only from the child/worker id as given from the constructor. -Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) { - // Common code init and error checking in the base class. - RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); - - // Once eof is handled, always return empty row. Class must be destroyed and recreated if you - // want to iterate again. - if (eof_handled_) { - return Status::OK(); - } - - // Check if we need to get a new DataBuffer to iterate. - if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { - RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); - - // Unlike the DatasetIterator, this child iterator does not quit after eoe. - // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the - // caller to decide what it wants to do next. - if (curr_buffer_->eoe()) { - MS_LOG(DEBUG) << "Child iterator picked up EOE."; - end_epoch_ = true; - return Status::OK(); - } - - if (curr_buffer_->eof()) { - MS_LOG(DEBUG) << "Child iterator picked up EOF."; - eof_handled_ = true; - return Status::OK(); - } - } - - // If we got this far, now it's time to pop that next row for return to caller - RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row)); - - return Status::OK(); -} - -// drain till the next eoe -Status ChildIterator::Drain() { - if (end_epoch_ == true) { - // Calling drain against a child that is already at it's eoe state will not result in any action. - // This allows you to do: - // - fetch until empty row - // - drain (will not actually drain because you are already at the end of the iteration) - // However, the next time after that, it will perform it's normal draining activities. 
- end_epoch_ = false; - MS_LOG(DEBUG) << "No operation drain, already at end of epoch."; - return Status::OK(); - } - MS_LOG(DEBUG) << "Child draining buffers until eoe."; - // else we drain until eoe or eof, eof here is for sanity check - while (!curr_buffer_->eoe() && !curr_buffer_->eof()) { - RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_)); - } - if (curr_buffer_->eof()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain."); - } - return Status::OK(); -} - -// Getter -std::unordered_map ChildIterator::GetColumnNameMap() const { - return current_op_->child(child_idx_)->column_name_id_map(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/dataset_iterator.h b/mindspore/ccsrc/dataset/engine/dataset_iterator.h deleted file mode 100644 index 4e40e77c74..0000000000 --- a/mindspore/ccsrc/dataset/engine/dataset_iterator.h +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASET_ITERATOR_H_ -#define DATASET_ENGINE_DATASET_ITERATOR_H_ - -#include -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/perf/dataset_iterator_tracing.h" - -namespace mindspore { -namespace dataset { -using TensorMap = std::unordered_map>; - -// forward declare -class ExecutionTree; - -class DataBuffer; - -// IteratorBase class is used to iterate data from an executionTree one row at a time. -// The base class provides the general interface, whereas derived classes provide slightly -// different implementations. -class IteratorBase { - public: - // Constructor of IteratorBase - IteratorBase(); - - // Destructor - virtual ~IteratorBase(); - - // Fetches one row of data from the iterator. - // the base class version simply performs error handling and returns empty row. Actual - // functionality exists in the derived versions of this function. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the of data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - // @note The position of a Tensor/column might be different from the initial column order - // in corresponding Dataset Op. User must be aware that MapOp, ZipOps, and others might change - // the column ordering. - virtual Status FetchNextTensorRow(TensorRow *out_row); - - // Fetches one row of data from the iterator as a column map. - // @return A unordered map from column name to shared pointer to Tensor. 
- Status GetNextAsMap(TensorMap *out_map); - - // Getter - // @return T/F if this iterator is completely done after getting an eof - bool eof_handled() const { return eof_handled_; } - - // Getter - // @return The string to column id mapping. - virtual std::unordered_map<std::string, int32_t> GetColumnNameMap() const = 0; - - protected: - std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer - bool eof_handled_; // T/F if this op got an eof - std::unordered_map<std::string, int32_t> col_name_id_map_; -}; - -// The DatasetIterator derived class is for fetching rows off the end/root of the execution tree. -class DatasetIterator : public IteratorBase { - public: - // Constructor of the DatasetIterator - // @param exe_tree The execution tree we want to pull/iterate the data from using its root node. - explicit DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree); - - // Destructor - ~DatasetIterator(); - - // Fetches one row of data from the iterator. Overrides the base class. This one fetches - // from the tree root node directly. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - Status FetchNextTensorRow(TensorRow *out_row) override; - - // Fetches the next tensor row into device row, and returns its shapes. - // @param out_shapes - A vector of tensor shapes (one shape per column) - // @return Status - The error code return - Status GetOutputShapes(std::vector<TensorShape> *out_shapes); - - // Fetches the next tensor row into device row, and returns its types. - // @param out_types - A vector of tensor types (one type per column) - // @return Status - The error code return - Status GetOutputTypes(std::vector<DataType> *out_types); - - // Getter - // @return The string to column id mapping. - std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; - - private: - std::shared_ptr<DatasetOp> root_; // saves the root of the executionTree - TensorRow device_queue_row_; - std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data - int32_t cur_batch_num_; // current batch number, used for profiling - int32_t cur_connector_size_; // current connector size of root op, used for profiling - int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling -}; - -// The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree. -// This one should only be used by internal Dataset operators, rather than an end-user. -class ChildIterator : public IteratorBase { - public: - // Constructor of the ChildIterator - // @param current_op - The parent op from whose children we'll fetch. - // @param worker_id - The worker id to use when fetching from the children. - // @param child_idx - The index to the child to fetch from. - ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx); - - // Destructor - ~ChildIterator(); - - // Fetches one row of data from the iterator. Overrides the base class. This one fetches - // only from the child/worker id as given from the constructor. - // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data - // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. - // @return Status - The error code return - Status FetchNextTensorRow(TensorRow *out_row) override; - - // This function drains buffer until next eoe has been received. - // It will be a no-op if the previous row returned is empty.
- // @return Status - The error code return - Status Drain(); - - // Getter - // @return The string to column id mapping. - std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; - - private: - DatasetOp *current_op_; // The parent operator. We consume from its children. - int32_t child_idx_; // The specific child this iterator will fetch from. - int32_t worker_id_; // The worker id used for fetching the child data. - bool end_epoch_; // the flag used when an empty row has been returned. -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASET_ITERATOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc deleted file mode 100644 index 6fc276a75e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc +++ /dev/null @@ -1,242 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/barrier_op.h" -#include <iomanip> -#include <utility> -#include "dataset/core/constants.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -BarrierOp::Builder::Builder() { - // Some arguments to the BarrierOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the BarrierOp by - // using the various builder set methods. - - std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } - -Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, - builder_condition_func_); - return Status::OK(); -} - -// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions -BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, - py::function condition_func) - : PipelineOp(op_connector_size), - rows_per_buffer_(rows_per_buffer), - buffer_id_(0), - clean_up_(false), - eof_(false), - condition_name_(condition_name), - condition_function_(condition_func) {} - -// destructor -BarrierOp::~BarrierOp() {} - -// Entry point for Barrier, called by launch() -Status BarrierOp::operator()() { - // The children_num_ parameter needs to be put here - // Synchronize with TaskManager once the thread is created.
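- // Every op functor follows the same launch handshake: the newly created thread
- // posts to TaskManager before entering its main loop, unblocking the launcher.
- // In outline (a condensed sketch of the pattern used throughout these ops,
- // not part of the original file):
- //
- //   Status SomeOp::operator()() {
- //     TaskManager::FindMe()->Post();  // signal that the thread has started
- //     while (!eof_) { /* pull rows from the child, push buffers downstream */ }
- //     return Status::OK();
- //   }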
- TaskManager::FindMe()->Post(); - - // create child iterator, right now this barrier is a pipeline operator - const int32_t worker_id = 0; - const int32_t child_idx = 0; - child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx); - - // Loop until eof is true - while (!eof_) { - // Create new table to put the new tensor rows - std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>(); - RETURN_IF_NOT_OK(prepare(curr_table.get())); - - // If an eof got picked up during the above prepare, then we're done - if (eof_) { - break; - } - - // we have to output new buffer with possibly different buffer size, possibly one row - while (!clean_up_) { - // 1. If a previous loop iteration sent the current table out, then create a new one. - - if (curr_table == nullptr) { - curr_table = std::make_unique<TensorQTable>(); - } - - // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished - RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); - - // 3 create and update buffer and send it to the out connector - if (!curr_table->empty()) { - std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone); - curr_buffer->set_tensor_table(std::move(curr_table)); - MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " - << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - buffer_id_++; - } - } - - // 4 handle drain state. - if (clean_up_) { - MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; - // Send the eoe up. - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)))); - } - } - // 5 handle eof - // propagate eof here. - MS_LOG(INFO) << "Barrier operator got EOF, propagating."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Handles preprocessing of the main loop, used when starting new epoch -Status BarrierOp::prepare(TensorQTable *const table) { - MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; - clean_up_ = false; - buffer_id_ = 0; - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); - } - // fill initial row - TensorRow new_row = {}; - // use iterator to get next row and invoke pyfunc wait - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - - // If the first row fetching resulted in eof, then we are done. - if (eof_) { - return Status::OK(); - } - if (new_row.empty()) { - // This epoch is empty - return Status::OK(); - } - // Pack this first row into our tensor table - // first row we also have to check if we should block - RETURN_IF_NOT_OK(blockCond()); - - table->push_back(std::move(new_row)); - - // the update code below shouldn't do anything bad if the column name already exists. - return Status::OK(); -} - -// fillBuffer always expects a new table to fill -Status BarrierOp::fillBuffer(TensorQTable *const table) { - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); - } - TensorRow new_row = {}; - while (table->size() < static_cast<size_t>(rows_per_buffer_)) { - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - // Early exit the loop if we got empty row from any of our child iterations - if (new_row.empty()) { - return Status::OK(); - } - // else we got a row so pack it into the tensor table.
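- // Each fetched row is gated on the python-side condition before being packed.
- // The contract blockCond() enforces, in outline (error paths elided; this
- // condenses the implementation that follows):
- //
- //   py::gil_scoped_acquire gil;              // pyfunc calls require the GIL
- //   py::object ret = condition_function_();  // user-supplied callable
- //   // must return a python bool, otherwise kPyFuncException is returned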
- RETURN_IF_NOT_OK(blockCond()); - - table->push_back(std::move(new_row)); - } - return Status::OK(); -} - -// function executes a py_func and blocks until condition becomes true. -Status BarrierOp::blockCond() { - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - // we have condition name, however the flexibility is in python today - try { - // Invoke python function - py::object ret_py_obj = condition_function_(); - // Process the return value - if (!py::isinstance<py::bool_>(ret_py_obj)) { - return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - } - return Status::OK(); -} - -// fetches next Barrier buffer row -Status BarrierOp::getNextTensorRow(TensorRow *new_row) { - // iterate over all iterators and generate a row - RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); - // check if the row from the child iterator is empty; if so, return the empty row to the caller - if (new_row->empty()) { - // If we did not get a row from any of the children, then it's the end of an epoch and we can move - // to drain state. - MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; - clean_up_ = true; - // If we picked up an eof here, then we are completely done. - if ((child_iterator_)->eof_handled()) { - MS_LOG(INFO) << "Barrier operator iterator got EOF."; - eof_ = true; - } - return Status::OK(); - } - return Status::OK(); -} - -// A function that prints info about the Operator -void BarrierOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nCondition: " << condition_name_ << "\n\n"; - } -} - -// overwrite function and handle eof -Status BarrierOp::EofReceived(int32_t) { - MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; - return Status::OK(); -} - -// overwrite function and handle eoe -Status BarrierOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h deleted file mode 100644 index 379b8f146b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ - -#include <memory> -#include <string> -#include <unordered_map> -#include <utility> -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class DataBuffer; -class ExecutionTree; - -// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has -// been received. This signal is given from the python layer. The current barrier design respects the -// rows per buffer design and will only output a buffer with rows once it has received rows per buffer -// signals from python. - -class BarrierOp : public PipelineOp { - public: - // The nested builder class inside of the BarrierOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @param const std::string & condition_name - // @return Builder setter method returns reference to the builder. - Builder &SetConditionName(const std::string &condition_name) { - builder_condition_name_ = condition_name; - return *this; - } - - // Setter method. - // @param py::function condition_func - blocking condition function - // @return Builder setter method returns reference to the builder. - Builder &SetConditionFunc(py::function condition_func) { - builder_condition_func_ = condition_func; - return *this; - } - - // The builder "build" method creates the BarrierOp dataset Operator. - // @return shared_ptr to the new BarrierOp object - Status Build(std::shared_ptr<BarrierOp> *); - - private: - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::string builder_condition_name_; - py::function builder_condition_func_; - - Status SanityCheck() const; - }; - - // Constructor for BarrierOp - // @param rows_per_buffer - number of rows in output buffer - // @param op_connector_size - connector size - // @param condition_name - the condition name associated with this operator - // @param condition_func - the blocking condition check per row - // @note - currently rows_per_buffer should = 1 for barrier. - // The reason for this is having other values would complicate how the pipeline behaves with other operators - // One example of such case is having batch after barrier.
Batch would be waiting for data, and having - rows per buffer in this case can result in hanging - BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, - py::function condition_func); - - // Destructor - ~BarrierOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Barrier - // @param out - output stream to print to - // @param show_all - if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { - bo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Handles preprocessing of the main loop, used when starting new epoch - // @param table - a table of tensors to be moved into a buffer - Status prepare(TensorQTable *const table); - - // This function takes a table and repeatedly adds rows to it. - // @param table - a table of tensors to be moved into a buffer - Status fillBuffer(TensorQTable *const table); - - // Gets next tensor row and sets control signals - Status getNextTensorRow(TensorRow *new_row); - - // This function runs the wait function on condition - Status blockCond(); - - private: - // clean up variable to return incomplete buffer - bool clean_up_; - // end of file state, we stop reading data and shut down - bool eof_; - // rows per buffer - int32_t rows_per_buffer_; - // buffer_id - int32_t buffer_id_; - // iterator to pull new rows, we only have one child - std::unique_ptr<ChildIterator> child_iterator_; - // condition name, to support multiple barriers - std::string condition_name_; - // Function pointer of blocking function - py::function condition_function_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc deleted file mode 100644 index 93b4864040..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.cc +++ /dev/null @@ -1,446 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "dataset/engine/datasetops/batch_op.h" - -#include <iomanip> -#include <utility> - -#include "common/utils.h" -#ifdef ENABLE_PYTHON -#include "dataset/core/pybind_support.h" -#endif -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/data/data_utils.h" - -using float16 = Eigen::half; - -namespace mindspore { -namespace dataset { -BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) { - builder_batch_size_ = batch_size; - std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status BatchOp::Builder::Build(std::shared_ptr<BatchOp> *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); -#ifdef ENABLE_PYTHON - *ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, - builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_, - builder_batch_map_func_, builder_pad_map_); -#else - *ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, - builder_num_workers_, builder_cols_to_map_, builder_pad_map_); -#endif - return Status::OK(); -} - -Status BatchOp::Builder::SanityCheck() { - std::string err; - err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; - err += builder_batch_size_ <= 0 ? "batch size <= 0\n" : ""; - err += builder_num_workers_ <= 0 ? "batch num_parallel_workers <= 0\n" : ""; - return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); -} - -#ifdef ENABLE_PYTHON -BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func, - PadInfo pad_map) - : ParallelOp(num_workers, op_queue_size), - start_batch_size_(batch_size), - drop_(drop), - pad_(pad), - pyfunc_column_names_(cols_to_map), - batch_size_func_(batch_size_func), - batch_map_func_(batch_map_func), - pad_info_(pad_map) { - worker_queues_.Init(num_workers, op_queue_size); -} -#else -BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector<std::string> &cols_to_map, PadInfo pad_map) - : ParallelOp(num_workers, op_queue_size), - start_batch_size_(batch_size), - drop_(drop), - pad_(pad), - pyfunc_column_names_(cols_to_map), - pad_info_(pad_map) { - worker_queues_.Init(num_workers, op_queue_size); -} -#endif - -Status BatchOp::operator()() { - Status rc = LaunchThreadsAndInitOp(); - // Synchronize with TaskManager - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - int64_t epoch_num = 0, batch_num = 0, cnt = 0; - TensorRow new_row; - std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>(); - child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - int32_t cur_batch_size = 0; - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0))); - while (child_iterator_->eof_handled() == false) { - while (new_row.empty() == false) { - table->emplace_back(new_row); - // if # of rows is enough to make 1 batch (1 batch is buffer), send it to worker_queue - if (table->size() == static_cast<size_t>(cur_batch_size)) { - RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( - std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); - table =
std::make_unique<TensorQTable>(); - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); - } - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } - // Remainder logic, execute only when there is a remainder (table is non-empty) and don't drop - if (drop_ == false && table->empty() == false) { - RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( - std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); - } - table = std::make_unique<TensorQTable>(); // this drops when drop == true - // end of the current epoch, batch_num should start from 0 again - batch_num = 0; - epoch_num++; - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE)))); - RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); - } // end of eof_handled() == false - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF)))); - // EOF received, send quit signal (an empty buffer) to all workers - for (int32_t ind = 0; ind < num_workers_; ind++) { - RETURN_IF_NOT_OK( - worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit)))); - } - return Status::OK(); -} - -void BatchOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [batch size: " << start_batch_size_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nStart batch size: " << start_batch_size_ << "\nDrop remainder: " << (drop_ ?
"yes" : "no") << "\n\n"; - } -} - -Status BatchOp::BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, - dsize_t batch_size) { - if ((*src)->size() != batch_size) { - RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Source table size does not match the batch_size"); - } - - if (batch_size == 1) { - TensorRow row = std::move((*src)->front()); - (*src)->pop_front(); - (*dest)->push_back(row); - for (const auto &tensor : (*dest)->front()) { - RETURN_IF_NOT_OK(tensor->ExpandDim(0)); - } - return Status::OK(); - } - - TensorRow batched_row; - auto num_columns = (*src)->front().size(); - for (size_t i = 0; i < num_columns; i++) { - std::shared_ptr first_tensor = (*src)->at(0).at(i); // first row, column i - TensorShape first_shape = first_tensor->shape(); - DataType first_type = first_tensor->type(); - TensorShape new_shape = first_shape.PrependDim(static_cast(batch_size)); - - std::shared_ptr new_tensor; - if (first_type.IsNumeric()) { // numeric tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); - dsize_t j = 0; - for (auto row : **src) { - std::shared_ptr old_tensor = row.at(i); // row j, column i - if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first - RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); - } else { - RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); - } - } - } else { // handle string column differently - std::vector strings; - for (dsize_t j = 0; j < batch_size; j++) { - std::shared_ptr old_tensor = (*src)->at(j).at(i); - for (auto itr = old_tensor->begin(); itr != old_tensor->end(); itr++) { - strings.emplace_back(*itr); - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); - } - batched_row.emplace_back(new_tensor); - } - - (*dest)->emplace_back(batched_row); - - return Status::OK(); -} - -Status BatchOp::WorkerEntry(int32_t workerId) { - TaskManager::FindMe()->Post(); - std::pair, CBatchInfo> table_pair; - RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); - while (table_pair.second.ctrl_ != batchCtrl::kQuit) { - if (table_pair.second.ctrl_ == batchCtrl::kEOE) { - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) { - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) { - std::unique_ptr db = nullptr; - RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &db)); - RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::move(db))); - } - RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); - } - return Status::OK(); -} - -Status BatchOp::MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, - std::unique_ptr *db) { - RETURN_UNEXPECTED_IF_NULL(table_pair.first); -#ifdef ENABLE_PYTHON - if (!pyfunc_column_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc -#endif - if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed - (*db) = std::make_unique(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone); - std::unique_ptr dest_table = std::make_unique(); - RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size())); - (*db)->set_tensor_table(std::move(dest_table)); - return Status::OK(); -} - -Status 
BatchOp::LaunchThreadsAndInitOp() { - RETURN_UNEXPECTED_IF_NULL(tree_); - RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1))); - return Status::OK(); -} - -Status BatchOp::EofReceived(int32_t) { return Status::OK(); } - -Status BatchOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -Status BatchOp::MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair) { - TensorBatchTable input_table; - input_table.reserve(pyfunc_column_names_.size()); - for (std::string col_name : pyfunc_column_names_) { - if (column_name_id_map_.find(col_name) == column_name_id_map_.end()) { - RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n"); - } - TensorBatch tensor_batch; - tensor_batch.reserve(table_pair->first->size()); - size_t col_idx = static_cast<size_t>(column_name_id_map_[col_name]); - for (size_t row_idx = 0; row_idx < table_pair->first->size(); row_idx++) { - tensor_batch.push_back(std::move(table_pair->first->at(row_idx)[col_idx])); - } - input_table.push_back(std::move(tensor_batch)); - } - - // Perform batch map - TensorBatchTable output_table; - RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second)); - - // Write back to TensorQTable - for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) { - size_t col_idx = static_cast<size_t>(column_name_id_map_[pyfunc_column_names_[input_idx]]); - size_t row_id = 0; - for (TensorRow &row : *(table_pair->first)) { - row[col_idx] = std::move(output_table[input_idx][row_id++]); - } - } - return Status::OK(); -} -#endif - -Status BatchOp::GetBatchSize(int32_t *batch_size, CBatchInfo info) { -#ifdef ENABLE_PYTHON - if (batch_size_func_ != nullptr) { - RETURN_IF_NOT_OK(InvokeBatchSizeFunc(batch_size, info)); - } else { - (*batch_size) = start_batch_size_; - } -#else - (*batch_size) = start_batch_size_; -#endif - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py::object size = batch_size_func_(info); - *batch_size = size.cast<int32_t>(); - if (*batch_size <= 0) { - return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); - } - } - return Status(StatusCode::kOK, "Batch size func call succeeded"); -} - -Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info) { - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - // Prepare batch map call back parameters - py::tuple input_args(input->size() + 1); - for (size_t i = 0; i < input->size(); i++) { - std::vector<py::array> np_batch; - for (std::shared_ptr<Tensor> t : input->at(i)) { - py::array np_array; - RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array)); - np_batch.push_back(std::move(np_array)); - } - input_args[i] = np_batch; - } - input_args[input->size()] = info; - //
Invoke batch map func - py::object ret_py_obj = batch_map_func_(*input_args); - // Parse batch map return value - py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj); - if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance<py::tuple>(ret_tuple)) { - return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple"); - } - for (size_t i = 0; i < ret_tuple.size(); i++) { - TensorBatch output_batch; - py::list output_list = py::cast<py::list>(ret_tuple[i]); - for (size_t j = 0; j < output_list.size(); j++) { - std::shared_ptr<Tensor> out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast<py::array>(output_list[j]))); - output_batch.push_back(std::move(out)); - } - output->push_back(std::move(output_batch)); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple of lists of numpy arrays"); - } - } - return Status(StatusCode::kOK); -} -#endif - -Status BatchOp::PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, - const std::unordered_map<std::string, int32_t> &column_name_id_map) { - RETURN_UNEXPECTED_IF_NULL(table); // placeholder for now, might need this in the future - CHECK_FAIL_RETURN_UNEXPECTED((*table)->front().size() == column_name_id_map.size(), "col_name_map mismatch"); - std::vector<std::shared_ptr<Tensor>> pad_vals(column_name_id_map.size(), - 0); // value to pad each column's tensor with, default 0 - std::set<int32_t> pad_cols; - // padded_shape provided by user, maximum shapes of current batch of tensors - std::vector<std::vector<dsize_t>> pad_shapes(column_name_id_map.size()), max_shapes(column_name_id_map.size()); - RETURN_IF_NOT_OK(UnpackPadInfo(pad_info, column_name_id_map, &pad_cols, &pad_vals, &pad_shapes)); - - // init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well - for (size_t col_id : pad_cols) { - max_shapes[col_id] = std::vector<dsize_t>((*table)->front()[col_id]->Rank(), -1); - if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1 - CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape"); - } - - // calculate maximum shape for each column that needs to be padded - for (const TensorRow &row : **table) { // iterate over each row in the batch - for (size_t col_id : pad_cols) { // iterate over each tensor in the row - CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(), - "Tensors to be padded together need to have the same rank"); - for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension - max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]); - } - } - } - - // if user sets a dimension to -1 (None in python), use the max value for current dimension - for (size_t col_id : pad_cols) { - for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) { - if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim]; - } - } - - // call pad on each tensor that needs to be padded - for (TensorRow &row : **table) { - for (size_t col_id : pad_cols) { - std::shared_ptr<Tensor> pad_tensor; - RETURN_IF_NOT_OK(PadEnd(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id])); - row[col_id] = pad_tensor; - } - } - return Status::OK(); -} - -Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, - const std::unordered_map<std::string, int32_t> &column_name_id_map, - std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, - std::vector<std::vector<dsize_t>> *pad_shapes) { - if (pad_info.empty())
{ // if pad_info is empty, pad every column automatically - for (dsize_t col_id = 0; col_id < column_name_id_map.size(); col_id++) { - pad_cols->insert(col_id); - } - } else { - for (const auto &p : pad_info) { - auto location = column_name_id_map.find(p.first); - CHECK_FAIL_RETURN_UNEXPECTED(location != column_name_id_map.end(), "no column exists with name:" + p.first); - auto col_id = static_cast<dsize_t>(location->second); - CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound"); - pad_cols->insert(col_id); - (*pad_vals)[col_id] = p.second.second; // set pad values - (*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown - } - } - return Status::OK(); -} - -// Visitor accept method for NodePass -Status BatchOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base<BatchOp>(), modified); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h deleted file mode 100644 index acf2e5a0c0..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/batch_op.h +++ /dev/null @@ -1,287 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ - -#include <algorithm> -#include <map> -#include <memory> -#include <queue> -#include <set> -#include <string> -#include <unordered_map> -#include <utility> -#include <vector> - -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DataBuffer; - -using TensorBatch = TensorRow; -using TensorBatchTable = std::vector<TensorBatch>; -using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>; - -class BatchOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor for Batch, batch size needs to be specified - // @param int32_t batch_size - explicit Builder(int32_t batch_size); - - // Default destructor - ~Builder() = default; - - // set number of parallel Workers on batch - // @param int32_t num_workers - // @return Builder & reference to builder class object - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // set drop for batch op, default false - // @param bool drop - // @return Builder & reference to builder class object - Builder &SetDrop(bool drop) { - builder_drop_ = drop; - return *this; - } - - Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) { - builder_pad_ = pad; - builder_pad_map_ = pad_map; - return *this; - } - - // set connector size for batch - // @param int32_t op_conn_size - // @return Builder & reference to builder class object - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = (op_connector_size == 0 ?
builder_op_connector_size_ : op_connector_size); - return *this; - } - - // set columns to perform map on - // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on - // @return Builder & reference to builder class object - Builder &SetColumnsToMap(const std::vector<std::string> &cols_to_map) { - builder_cols_to_map_ = cols_to_map; - return *this; - } - -#ifdef ENABLE_PYTHON - // set columns to perform map on - // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on - // @return Builder & reference to builder class object - Builder &SetBatchMapFunc(py::function batch_map_func) { - builder_batch_map_func_ = batch_map_func; - return *this; - } - - // SetBatchSizeFunc, a function that calls to python after every batch is made - // @param py::function batch_size_func - python function to call, GIL required before calling - // @return Builder & reference to builder class object - Builder &SetBatchSizeFunc(py::function batch_size_func) { - builder_batch_size_func_ = batch_size_func; - return *this; - } -#endif - - // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg - // @return Status - The error code return - Status Build(std::shared_ptr<BatchOp> *); - - private: - // Sanity check for builder class args - // @return Status - The error code return - Status SanityCheck(); - - bool builder_drop_; - bool builder_pad_; - int32_t builder_batch_size_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - std::vector<std::string> builder_cols_to_map_; - PadInfo builder_pad_map_; -#ifdef ENABLE_PYTHON - py::function builder_batch_size_func_; - py::function builder_batch_map_func_; -#endif - }; - - enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 }; - - // Parameters associate with one batch. - // This struct is used for both internal control and python callback. - // This struct is bound to python with read-only access. - struct CBatchInfo { - CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl) - : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {} - CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {} - CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {} - explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {} - int64_t epoch_num_; // i-th epoch. i starts from 0 - int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0 - int64_t total_batch_num_; // i-th batch since the start of first epoch.
i starts from 0 - batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3 - const int64_t get_batch_num() const { return batch_num_; } - const int64_t get_epoch_num() const { return epoch_num_; } - }; - -#ifdef ENABLE_PYTHON - // BatchOp constructor - // @param int32_t batch_size - // @param bool drop - // @param int32_t op_queue_size - // @param int32_t num_workers - BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector<std::string> &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map); -#else - BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, - const std::vector<std::string> &, PadInfo pad_map); -#endif - - // BatchOp destructor - ~BatchOp() {} - - // @param int32_t workerId - // @return Status - The error code return - Status EofReceived(int32_t) override; - - // @param int32_t workerId - // @return Status - The error code return - Status EoeReceived(int32_t) override; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param sO - reference to the BatchOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) { - bo.Print(out, false); - return out; - } - - // Main loop of batch - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit.
- Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "BatchOp"; } - - // batch the rows in src table then put it to dest table - // @param const std::unique_ptr<TensorQTable> *src - table that has the rows for batching - // @param const std::unique_ptr<TensorQTable> *dest - dest_table to hold batched rows - // @param int32_t size - batch_size - // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status BatchRows(const std::unique_ptr<TensorQTable> *src, const std::unique_ptr<TensorQTable> *dest, - dsize_t batch_size); - - // @param table - // @param const PadInfo &pad_info pad info - // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping - // @return Status - The error code return - static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, - const std::unordered_map<std::string, int32_t> &column_name_id_map); - - private: - // Worker thread for doing the memcpy of batch - // @param int32_t param workerId - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Generate buffer with batched tensors - // @return Status - The error code return - Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, - std::unique_ptr<DataBuffer> *db); - -#ifdef ENABLE_PYTHON - // Function that calls pyfunc to perform map on batch - // @param std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair - contains un-batched tensor - // @return Status - The error code return - Status MapColumns(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> *table_pair); -#endif - - // @param const PadInfo &pad_info pad info to unpack - // @param const std::unordered_map<std::string, int32_t>& column_name_id_map - column names to index mapping - // @param std::set<int32_t> *cols, col ids to perform pad on - // @param std::vector<std::shared_ptr<Tensor>> *vals, default padding value for each column - // @param std::vector<std::vector<dsize_t>> *shapes, padding shape specified by user - // @return Status - The error code return - static Status UnpackPadInfo(const PadInfo &pad_info, - const std::unordered_map<std::string, int32_t> &column_name_id_map, - std::set<int32_t> *pad_cols, std::vector<std::shared_ptr<Tensor>> *pad_vals, - std::vector<std::vector<dsize_t>> *pad_shapes); - - // the number of threads pulling from the mOutConnector of the Op below - // @return int32_t, 1 - int32_t num_consumers() const override { return 1; } - - // get the batch size for next batch - // @return Status - The error code return - Status GetBatchSize(int32_t *batch_size, CBatchInfo info); - - // Do the initialization of all queues then start all worker threads - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - -#ifdef ENABLE_PYTHON - // Invoke batch size function with current BatchInfo to generate batch size. - // @return Status - The error code return - Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); - - // Invoke batch map function with current BatchInfo to generate tensors to batch.
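- // In outline, the callable receives one list of numpy arrays per mapped
- // column plus a CBatchInfo as its last argument, and must return a tuple
- // containing one list of numpy arrays per mapped column (this contract is
- // enforced in InvokeBatchMapFunc in the .cc definition above).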
- // @return Status - The error code return - Status InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info); -#endif - - int32_t start_batch_size_; - bool drop_; // bool for whether to drop remainder or not - bool pad_; // bool for whether to perform padding on tensor - std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on - PadInfo pad_info_; // column names to perform padding on - std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1 - QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_; // internal queue for syncing worker -#ifdef ENABLE_PYTHON - py::function batch_size_func_; // Function pointer of batch size function - py::function batch_map_func_; // Function pointer of per batch map function -#endif -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc deleted file mode 100644 index 5e143b700f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "dataset/engine/datasetops/bucket_batch_by_length_op.h" - -#include <map> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "pybind11/numpy.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/status.h" - -namespace py = pybind11; -namespace mindspore { -namespace dataset { -BucketBatchByLengthOp::Builder::Builder(std::vector<std::string> length_dependent_columns, - std::vector<int32_t> bucket_boundaries, std::vector<int32_t> bucket_batch_sizes) - : builder_length_dependent_columns_(length_dependent_columns), - builder_bucket_boundaries_(bucket_boundaries), - builder_bucket_batch_sizes_(bucket_batch_sizes), - builder_pad_info_({}), - builder_pad_to_bucket_boundary_(false), - builder_drop_remainder_(false) { - std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); - builder_op_connector_size_ = config_manager->op_connector_size(); -} - -Status BucketBatchByLengthOp::Builder::SanityCheck() { - std::string error_message; - - if (builder_length_dependent_columns_.empty()) { - error_message += "At least 1 column must be specified for element length calculation.\n"; - } - - if (builder_bucket_boundaries_.empty()) { - error_message += "At least 1 bucket boundary must be specified.\n"; - } - - if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { - error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; - } - - CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); - - return Status::OK(); -} - -Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op) { - RETURN_IF_NOT_OK(SanityCheck()); - - // insert 0 for the first bucket - builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); - - *new_bucket_batch_by_length_op = std::make_shared<BucketBatchByLengthOp>( - builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, - builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, - builder_op_connector_size_); - - return Status::OK(); -} - -BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, - std::vector<int32_t> bucket_boundaries, - std::vector<int32_t> bucket_batch_sizes, - py::function element_length_function, PadInfo pad_info, - bool pad_to_bucket_boundary, bool drop_remainder, - int32_t op_connector_size) - : PipelineOp(op_connector_size), - length_dependent_columns_(length_dependent_columns), - bucket_boundaries_(bucket_boundaries), - bucket_batch_sizes_(bucket_batch_sizes), - element_length_function_(element_length_function), - pad_info_(pad_info), - pad_to_bucket_boundary_(pad_to_bucket_boundary), - drop_remainder_(drop_remainder), - batch_count_(0) { - for (int i = 0; i < bucket_batch_sizes_.size(); i++) { - buckets_.push_back(std::make_unique<TensorQTable>()); - } -} - -Status BucketBatchByLengthOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } - -Status BucketBatchByLengthOp::operator()() { - TaskManager::FindMe()->Post(); - - TensorRow current_row; - child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
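- // Bucket selection below scans the boundaries from the top. With boundaries
- // {0, 3, 6} (a leading 0 is prepended in Builder::Build) and an element of
- // length 4, bucket_index moves 2 -> 1, i.e. the [3, 6) bucket; the values
- // here are illustrative only.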
- RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); - while (!child_iterator_->eof_handled()) { - while (!current_row.empty()) { - int32_t element_length; - RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); - - int bucket_index = bucket_boundaries_.size() - 1; - while (element_length < bucket_boundaries_[bucket_index]) { - bucket_index--; - } - - buckets_[bucket_index]->push_back(current_row); - - if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { - RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); - } - - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); - } - - // got EOE, do what we need to do with remainders in each bucket - if (!drop_remainder_) { - for (int i = 0; i < bucket_boundaries_.size(); i++) { - if (!buckets_[i]->empty()) { - RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); - } - } - } - - // need to send EOE manually since we set state to idle in EoeReceived() - std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); - } - - return Status::OK(); -} - -Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { - // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of - // the single column specified in length_dependent_columns_ - if (element_length_function_) { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - size_t number_of_arguments = length_dependent_columns_.size(); - py::tuple input_arguments(number_of_arguments); - for (size_t i = 0; i < number_of_arguments; i++) { - py::array argument_value; - int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; - RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); - input_arguments[i] = argument_value; - } - - py::object length = element_length_function_(*input_arguments); - *out_element_length = length.cast<int32_t>(); - if (*out_element_length < 0) { - return Status(StatusCode::kPyFuncException, "Element length function should return a non-negative integer."); - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Could not cast output of element length function to int32_t."); - } - } else { - *out_element_length = element[0]->shape()[0]; - } - - return Status::OK(); -} - -Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { - std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index]; - - PadInfo pad_info_copy = pad_info_; - if (pad_to_bucket_boundary_) { - for (auto &pair : pad_info_copy) { - std::vector<dsize_t> pad_shape = pair.second.first.AsVector(); - - for (size_t i = 0; i < pad_shape.size(); i++) { - if (pad_shape[i] == TensorShape::kDimUnknown) { - if (bucket_index + 1 >= bucket_boundaries_.size()) { - std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); - } - - pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; - } - } - - pair.second.first = TensorShape(pad_shape); - } - } - - // PadColumns will change the data in bucket
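- // A worked example of the boundary padding above (illustrative values): with
- // bucket_boundaries_ = {0, 3, 6}, an element in the [3, 6) bucket has upper
- // boundary 6, so every kDimUnknown pad dimension is resolved to 6 - 1 = 5.
- // Elements in the last (unbounded) bucket cannot be padded to a boundary,
- // hence the error returned above.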
- RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); - - std::unique_ptr<TensorQTable> batched_bucket = std::make_unique<TensorQTable>(); - RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); - (*bucket)->clear(); - - std::unique_ptr<DataBuffer> batched_buffer = std::make_unique<DataBuffer>(batch_count_, DataBuffer::kDeBFlagNone); - batched_buffer->set_tensor_table(std::move(batched_bucket)); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); - - batch_count_++; - - return Status::OK(); -} - -Status BucketBatchByLengthOp::Reset() { - batch_count_ = 0; - - for (int i = 0; i < buckets_.size(); i++) { - buckets_[i] = std::make_unique<TensorQTable>(); - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h deleted file mode 100644 index bf0bcb0e78..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ -#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ - -#include <map> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DataBuffer; - -class BucketBatchByLengthOp : public PipelineOp { - public: - class Builder { - public: - Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries, - std::vector<int32_t> bucket_batch_sizes); - - ~Builder() = default; - - Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) { - builder_length_dependent_columns_ = length_dependent_columns; - return *this; - } - - Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) { - builder_bucket_boundaries_ = bucket_boundaries; - return *this; - } - - Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) { - builder_bucket_batch_sizes_ = bucket_batch_sizes; - return *this; - } - - Builder &SetElementLengthFunction(py::function element_length_function) { - builder_element_length_function_ = element_length_function; - return *this; - } - - Builder &SetPadInfo(PadInfo pad_info) { - builder_pad_info_ = pad_info; - return *this; - } - - Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { - builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; - return *this; - } - - Builder &SetDropRemainder(bool drop_remainder) { - builder_drop_remainder_ = drop_remainder; - return *this; - } - - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - Status Build(std::shared_ptr<BucketBatchByLengthOp>
-Status BucketBatchByLengthOp::Reset() {
-  batch_count_ = 0;
-
-  for (int i = 0; i < buckets_.size(); i++) {
-    buckets_[i] = std::make_unique<TensorQTable>();
-  }
-
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h
deleted file mode 100644
index bf0bcb0e78..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
-#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/core/config_manager.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/dataset_iterator.h"
-#include "dataset/engine/datasetops/batch_op.h"
-#include "dataset/engine/datasetops/pipeline_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class DataBuffer;
-
-class BucketBatchByLengthOp : public PipelineOp {
- public:
-  class Builder {
-   public:
-    Builder(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
-            std::vector<int32_t> bucket_batch_sizes);
-
-    ~Builder() = default;
-
-    Builder &SetLengthDependentColumns(std::vector<std::string> length_dependent_columns) {
-      builder_length_dependent_columns_ = length_dependent_columns;
-      return *this;
-    }
-
-    Builder &SetBucketBoundaries(std::vector<int32_t> bucket_boundaries) {
-      builder_bucket_boundaries_ = bucket_boundaries;
-      return *this;
-    }
-
-    Builder &SetBucketBatchSizes(std::vector<int32_t> bucket_batch_sizes) {
-      builder_bucket_batch_sizes_ = bucket_batch_sizes;
-      return *this;
-    }
-
-    Builder &SetElementLengthFunction(py::function element_length_function) {
-      builder_element_length_function_ = element_length_function;
-      return *this;
-    }
-
-    Builder &SetPadInfo(PadInfo pad_info) {
-      builder_pad_info_ = pad_info;
-      return *this;
-    }
-
-    Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) {
-      builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary;
-      return *this;
-    }
-
-    Builder &SetDropRemainder(bool drop_remainder) {
-      builder_drop_remainder_ = drop_remainder;
-      return *this;
-    }
-
-    Builder &SetOpConnectorSize(int32_t op_connector_size) {
-      builder_op_connector_size_ = op_connector_size;
-      return *this;
-    }
-
-    Status Build(std::shared_ptr<BucketBatchByLengthOp> *new_bucket_batch_by_length_op);
-
-   private:
-    Status SanityCheck();
-
-    std::vector<std::string> builder_length_dependent_columns_;
-    std::vector<int32_t> builder_bucket_boundaries_;
-    std::vector<int32_t> builder_bucket_batch_sizes_;
-    py::function builder_element_length_function_;
-    PadInfo builder_pad_info_;
-    bool builder_pad_to_bucket_boundary_;
-    bool builder_drop_remainder_;
-    int32_t builder_op_connector_size_;
-  };
-
-  BucketBatchByLengthOp(std::vector<std::string> length_dependent_columns, std::vector<int32_t> bucket_boundaries,
-                        std::vector<int32_t> bucket_batch_sizes, py::function element_length_function, PadInfo pad_info,
-                        bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size);
-
-  // Destructor
-  ~BucketBatchByLengthOp() = default;
-
-  // Might need to batch remaining buckets after receiving eoe, so override this method.
-  // @param int32_t workerId
-  // @return Status - The error code returned
-  Status EoeReceived(int32_t) override;
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param sO - reference to the BucketBatchByLengthOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) {
-    bo.Print(out, false);
-    return out;
-  }
-
-  // Main loop of batch
-  // @return Status - The error code returned
-  Status operator()() override;
-
-  // Function that is called by ResetOp at the end of every epoch
-  // @return Status - The error code returned
-  Status Reset() override;
-
- private:
-  Status ObtainElementLength(int32_t *out_element_length, TensorRow element);
-
-  Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size);
-
-  std::vector<std::string> length_dependent_columns_;
-  std::vector<int32_t> bucket_boundaries_;
-  std::vector<int32_t> bucket_batch_sizes_;
-  py::function element_length_function_;
-  PadInfo pad_info_;
-  bool pad_to_bucket_boundary_;
-  bool drop_remainder_;
-
-  int32_t batch_count_;
-  std::unique_ptr<ChildIterator> child_iterator_;
-  std::vector<std::unique_ptr<TensorQTable>> buckets_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc
deleted file mode 100644
index ceb5058593..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.cc
+++ /dev/null
@@ -1,206 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "dataset/engine/datasetops/build_vocab_op.h"
-
-#include
-#include
-#include
-#include
-#include
-#include "dataset/core/config_manager.h"
-
-namespace mindspore {
-namespace dataset {
-
-BuildVocabOp::BuildVocabOp(std::shared_ptr<Vocab> vocab, std::vector<std::string> col_names,
-                           std::pair<int64_t, int64_t> freq_r, int64_t top_k, const std::vector<std::string> &tokens,
-                           bool prepend, int32_t num_workers, int32_t op_conn_size)
-    : ParallelOp(num_workers, op_conn_size),
-      interval_(op_conn_size * num_workers),
-      vocab_(vocab),
-      col_names_(col_names),
-      freq_range_(freq_r),
-      top_k_(top_k),
-      special_tokens_(tokens),
-      special_first_(prepend) {
-  // init two queues for thread sync
-  distributor_queue_ = std::make_unique<Queue<TensorRow>>(num_workers * op_conn_size);
-  collector_queue_ =
-    std::make_unique<Queue<std::unique_ptr<std::unordered_map<std::string, int64_t>>>>(num_workers * op_conn_size);
-}
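The two queues set up in this constructor decouple row distribution from count aggregation: workers drain distributor_queue_ and push partial word-count maps into collector_queue_, which a single collector thread folds together. A minimal sketch of that fold, with hypothetical free-standing names:

    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    using WordCounts = std::unordered_map<std::string, int64_t>;

    // Fold per-worker partial counts into one global table, as CollectorThread does below.
    WordCounts MergeCounts(const std::vector<WordCounts> &partials) {
      WordCounts total;
      for (const auto &partial : partials) {
        for (const auto &kv : partial) total[kv.first] += kv.second;
      }
      return total;
    }
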
-
-Status BuildVocabOp::WorkerEntry(int32_t worker_id) {
-  TaskManager::FindMe()->Post();
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
-  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map =
-    std::make_unique<std::unordered_map<std::string, int64_t>>();
-  int32_t row_cnt = 0;
-  while (!new_row.empty()) {
-    for (int32_t col : col_ids_) {
-      CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns");
-      for (auto itr = new_row[col]->begin<std::string_view>(); itr != new_row[col]->end<std::string_view>(); itr++) {
-        (*wrkr_map)[std::string(*itr)] += 1;
-      }
-    }
-    row_cnt++;  // row is processed by this point
-    if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) {
-      RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
-      wrkr_map = std::make_unique<std::unordered_map<std::string, int64_t>>();
-    }
-    RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row));
-  }
-  // clean up
-  if (!wrkr_map->empty()) {
-    RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map)));
-  }
-  // empty map as quit signal
-  RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique<std::unordered_map<std::string, int64_t>>()));
-  return Status::OK();
-}
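The flush condition in WorkerEntry staggers queue traffic: worker w pushes its partial map only when its own row_cnt is a multiple of interval_ and (row_cnt / interval_) % num_workers_ == w. As a worked example, with num_workers_ = 2 and op_conn_size = 4 (so interval_ = 8), worker 1 flushes after its 8th, 24th, 40th, ... row while worker 0 flushes after its 16th, 32nd, ... row, so the collector queue is fed evenly rather than in synchronized bursts.
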
-
-Status BuildVocabOp::operator()() {
-  // launch the collector thread
-  RETURN_UNEXPECTED_IF_NULL(tree_);
-  RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks()));
-  // launch worker threads and collector thread
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1)));
-  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this)));
-  TaskManager::FindMe()->Post();
-  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-  if (!col_names_.empty()) {
-    col_ids_.reserve(col_names_.size());
-    for (std::string col : col_names_) {
-      auto itr = column_name_id_map_.find(col);
-      CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist");
-      col_ids_.push_back(itr->second);
-    }
-  } else {
-    col_ids_.reserve(column_name_id_map_.size());
-    for (const auto &p : column_name_id_map_) {
-      col_ids_.push_back(p.second);
-    }
-  }
-  bool eoe_warning = false;  // give out a warning if we receive more than 1 eoe
-  while (child_iterator_->eof_handled() == false) {
-    while (new_row.empty() == false) {
-      RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row));
-      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-    }
-    CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)");
-    eoe_warning = true;
-  }
-
-  // tell all workers to quit
-  for (int32_t wrkr_id = 0; wrkr_id < num_workers_; wrkr_id++) {
-    RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow()));
-  }
-  return Status::OK();
-}
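CollectorThread, just below, keeps only words whose frequency falls inside freq_range_ and then takes the top-k by frequency with a lexicographic tie-break. A self-contained sketch of that selection (a hypothetical free function using the same comparator):

    #include <algorithm>
    #include <cstdint>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Partially sort so the first k entries are the most frequent words;
    // ties are broken alphabetically, mirroring the lambda in CollectorThread.
    std::vector<std::string> TopKWords(std::unordered_map<std::string, int64_t> &freq,
                                       std::vector<std::string> words, int64_t k) {
      const int64_t n = std::min<int64_t>(static_cast<int64_t>(words.size()), k);
      std::partial_sort(words.begin(), words.begin() + n, words.end(),
                        [&freq](const std::string &a, const std::string &b) {
                          return freq[a] == freq[b] ? a < b : freq[a] > freq[b];
                        });
      words.resize(n);
      return words;
    }
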
-
-Status BuildVocabOp::CollectorThread() {
-  TaskManager::FindMe()->Post();
-  int32_t num_quited_worker = 0;
-  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map;
-  while (num_quited_worker != num_workers_) {
-    RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map));
-    RETURN_UNEXPECTED_IF_NULL(wrkr_map);
-    if (!wrkr_map->empty()) {
-      for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second;
-    } else {
-      ++num_quited_worker;
-    }
-  }  // all frequencies are obtained
-  CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
-  std::vector<std::string> words;
-  // make sure enough is reserved; this will become a partially sorted list eventually
-  words.reserve(wrkr_map->size());
-
-  for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
-    if (it->second >= freq_range_.first && it->second <= freq_range_.second) {
-      words.push_back(it->first);
-      it++;
-    } else {
-      it = word_cnt_.erase(it);
-    }
-  }
-  std::string err_msg;
-
-  for (const std::string &sp_tk : special_tokens_) {
-    // if a special word exists in the dataset, warn the user about this
-    err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : "");
-  }
-
-  CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These special words are already in the dataset: " + err_msg + ".");
-
-  int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
-  if (num_words == 0) {
-    MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second
-                    << ") vocab would be empty (except for special tokens).";
-  }
-
-  // this would take the top-k most frequent words
-  std::partial_sort(words.begin(), words.begin() + num_words, words.end(),
-                    [this](const std::string &w1, const std::string &w2) {
-                      int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
-                      return f1 == f2 ? w1 < w2 : f1 > f2;
-                    });
-
-  if (special_first_) {
-    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
-  }
-
-  for (int64_t i = 0; i < num_words; i++) {
-    vocab_->append_word(words[i]);
-  }
-
-  if (!special_first_) {
-    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
-  }
-
-  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
-  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
-  // then use std::nth_element to partial sort
-  return Status::OK();
-}
-
-Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
-  CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers need to be greater than 0");
-  CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be a positive number");
-  CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
-                               "frequency range [a,b] should be 0 <= a <= b (a,b are inclusive)");
-  (*op) = std::make_shared<BuildVocabOp>(
-    builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
-    builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_);
-  return Status::OK();
-}
-
-BuildVocabOp::Builder::Builder()
-    : builder_top_k_(std::numeric_limits<int64_t>::max()),
-      builder_min_freq_(0),
-      builder_max_freq_(std::numeric_limits<int64_t>::max()),
-      builder_special_first_(true) {
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  builder_num_workers_ = cfg->num_parallel_workers();
-  builder_connector_size_ = cfg->op_connector_size();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
deleted file mode 100644
index bf358c48c6..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
+++ /dev/null
@@ -1,174 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
-#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/core/tensor.h"
-#include "dataset/engine/dataset_iterator.h"
-#include "dataset/engine/datasetops/parallel_op.h"
-#include "dataset/text/vocab.h"
-#include "dataset/util/queue.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class BuildVocabOp : public ParallelOp {
- public:
-  class Builder {
-   public:
-    Builder();
-
-    // Destructor.
-    ~Builder() = default;
-
-    // Setter method
-    // @param int32_t size
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOpConnectorSize(int32_t size) {
-      builder_connector_size_ = size;
-      return *this;
-    }
-
-    // Setter method
-    // @param int32_t num_workers
-    // @return Builder setter method returns reference to the builder.
- Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param int64_t top_k - // @return Builder setter method returns reference to the builder. - Builder &SetTopK(int64_t top_k) { - builder_top_k_ = top_k; - return *this; - } - - // Setter method - // @param int64_t min_freq - // @return Builder setter method returns reference to the builder. - Builder &SetMinFreq(int64_t min_freq) { - builder_min_freq_ = min_freq; - return *this; - } - - // Setter method - // @param int64_t max_freq - // @return Builder setter method returns reference to the builder. - Builder &SetMaxFreq(int64_t max_freq) { - builder_max_freq_ = max_freq; - return *this; - } - - // set columns names - // @param const std::vector & col_names - name of columns to get words - // @return Builder & reference to builder class object - Builder &SetColumnNames(const std::vector &col_names) { - builder_col_names_ = col_names; - return *this; - } - - // set special tokens - // @param const std::vector & col_names - name of columns to get words - // @return Builder & reference to builder class object - Builder &SetSpecialTokens(const std::vector &tokens) { - builder_speical_tokens_ = tokens; - return *this; - } - - // set vocab object - Builder &SetVocab(std::shared_ptr vocab) { - builder_vocab_ = vocab; - return *this; - } - - // set special tokens first (or last) - Builder &SetSpecialFirst(bool prepend) { - builder_special_first_ = prepend; - return *this; - } - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - int32_t builder_num_workers_; - int32_t builder_connector_size_; - int64_t builder_min_freq_; - int64_t builder_max_freq_; - bool builder_special_first_; - std::vector builder_col_names_; - std::vector builder_speical_tokens_; - std::shared_ptr builder_vocab_; - int64_t builder_top_k_; - }; - - BuildVocabOp(std::shared_ptr vocab, std::vector col_names, std::pair freq_range, - int64_t top_k, const std::vector &tokens, bool prepend, int32_t num_workers, - int32_t op_connector_size); - - ~BuildVocabOp() = default; - - Status WorkerEntry(int32_t worker_id) override; - - // collect the work product from each worker - Status CollectorThread(); - - Status EofReceived(int32_t) override { return Status::OK(); } - - Status EoeReceived(int32_t) override { return Status::OK(); } - - Status operator()() override; - - // Getter - // @return the number of workers - int32_t num_producers() const override { return 1; } - - // Getter - // @return the number of threads consuming from the previous Connector - int32_t num_consumers() const override { return 1; } - - Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } - - private: - const int32_t interval_; - bool special_first_; - std::shared_ptr vocab_; - std::vector col_names_; - std::vector col_ids_; - std::vector special_tokens_; - // pair = {min_f, max_f} - // make sure that 0<= min_f < max_f <= int32_max in the builder - std::pair freq_range_; - - int64_t top_k_; // every thing means top_k_ == int32_max - std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 - std::unique_ptr> distributor_queue_; // master thread assigns each worker TensorRow via this - std::unique_ptr>>> collector_queue_; - std::unordered_map word_cnt_; -}; -} // namespace dataset -} // namespace mindspore -#endif // 
DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc deleted file mode 100644 index c943f8bd7a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc +++ /dev/null @@ -1,185 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/cache_base_op.h" -#include -#include -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -// A print method typically used for debugging -void CacheBase::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nCache client:\n" << *cache_client_ << "\n\n"; - } -} -// Overrides base class reset method. When an operator does a reset, it cleans up any state -// info from it's previous execution and then initializes itself so that it can be executed -// again. -Status CacheBase::Reset() { - if (sampler_ != nullptr) { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - } - // Wake up the workers to get them going again in a new epoch - MS_LOG(DEBUG) << Name() << " resetting."; - epoch_sync_.Set(); - return Status::OK(); -} -CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, sampler), - cache_client_(cache_client), - rows_per_buffer_(rows_per_buf), - // We can cause deadlock if this internal Connector size is too small. 
- keys_miss_(num_workers_, 1, connector_capacity_) { - io_block_queues_.Init(num_workers, op_connector_size); -} -// Common function to fetch samples from the sampler and send them using the io_block_queues to -// the parallel workers -Status CacheBase::FetchSamplesToWorkers() { - int64_t buf_cnt = 0; - int64_t wait_cnt = 0; - do { - epoch_sync_.Clear(); - std::vector keys; - int64_t row_cnt = 0; - keys.reserve(rows_per_buffer_); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (!sampler_buffer->eoe()) { - TensorRow sample_row; - RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); - std::shared_ptr sample_ids = sample_row[0]; - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { - keys.push_back(*itr); - ++row_cnt; - if (row_cnt % rows_per_buffer_ == 0) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (!keys.empty()) { - auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); - } - // send the eoe - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - // If repeat but the not last repeat, wait for reset. - if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; - RETURN_IF_NOT_OK(epoch_sync_.Wait()); - } else { - // We can break out from the loop. - break; - } - } while (true); - // Flow the eof before exit - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - // Ask all the workers to quit. - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); -} -Status CacheBase::FetchFromCache(int32_t worker_id) { - int64_t buffer_id = worker_id; - std::unique_ptr blk; - do { - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); - if (blk->eof()) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else if (blk->eoe()) { - if (AllowCacheMiss()) { - // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from - // a sampler, send a eoe to physical leaf op as well. 
- std::vector eoe; - eoe.push_back(eoe_row_id); - RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); - } - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(blk->GetKeys(&keys)); - if (keys.empty()) { - // empty key is a quit signal for workers - break; - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - std::unique_ptr que = std::make_unique(); - TensorTable ttbl; - RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); - auto row_it = ttbl.begin(); - std::vector cache_miss; - cache_miss.reserve(keys.size()); - for (auto row_id : keys) { - auto &row = *row_it; - if (row.empty()) { - if (AllowCacheMiss()) { - cache_miss.push_back(row_id); - } else { - std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; - RETURN_STATUS_UNEXPECTED(errMsg); - } - } - que->push_back(std::move(row)); - ++row_it; - } - db->set_tensor_table(std::move(que)); - if (AllowCacheMiss()) { - // Because of the way connector works, we push unconditionally even cache_miss can be empty. - RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); - } - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - } while (true); - return Status::OK(); -} -Status CacheBase::RegisterResources() { - RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); - return Status::OK(); -} -CacheBase::~CacheBase() {} -Status CacheBase::UpdateColumnMapFromCache() { - Status rc; - // Get the schema from the server. It may not be there yet. So tolerate the error. - if (column_name_id_map_.empty()) { - rc = cache_client_->FetchSchema(&column_name_id_map_); - if (rc == Status(StatusCode::kFileNotExist)) { - MS_LOG(DEBUG) << "Schema not in the server yet."; - rc = Status::OK(); - } - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h deleted file mode 100644 index 9f90b7cd9d..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ - -#include -#include -#include -#include -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/cache/cache_service.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/util/queue.h" -#include "dataset/util/wait_post.h" -#include "dataset/engine/datasetops/cache_base_op.h" -namespace mindspore { -namespace dataset { -/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. -/// \see CacheOp -/// \see CacheLookupOp -class CacheBase : public ParallelOp { - public: - /// \brief Base class constructor - /// \param num_workers Number of parallel workers - /// \param op_connector_size Connector size - /// \param rows_per_buf Number of rows per buffer - /// \param cache_client CacheClient for communication to the CacheServer - /// \param sampler Sampler which is mandatory - CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler); - /// \brief Destructor - ~CacheBase(); - - /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state - /// info from it's previous execution and then initializes itself so that it can be executed - /// again. - /// \return Status - The error code return - Status Reset() override; - - /// \brief A print method typically used for debugging - /// \param out The output stream to write output to - /// \param show_all A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - /// \brief << Stream output operator overload - /// \notes This allows you to write the debug print info using stream operators - /// \param out reference to the output stream being overloaded - /// \param mo reference to the CacheOp to display - /// \return the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { - mo.Print(out, false); - return out; - } - - /// \brief Getter for the cache client - /// \return shared ptr to the cache client - std::shared_ptr cache_client() { return cache_client_; } - /// \brief Setter for the cache client - void SetCacheClient(std::shared_ptr cache_client) { cache_client_ = std::move(cache_client); } - /// \brief Derived class must implement this method if a cache miss is treated as error - virtual bool AllowCacheMiss() = 0; - - protected: - constexpr static int32_t eoe_row_id = -1; - std::shared_ptr cache_client_; - WaitPost epoch_sync_; - int32_t rows_per_buffer_; - Connector> keys_miss_; - - /// \brief Common function to register resources for interrupt - /// \note Derived should override this function for extra resources to be registered - virtual Status RegisterResources(); - /// \brief This function is called by main thread to send samples to the worker thread. 
- /// \note It is a non-virtual function - /// \return Status object - Status FetchSamplesToWorkers(); - /// \brief This function is called by each worker to fetch rows from the cache server for a given set of - /// sample row id's - /// \return Status object - Status FetchFromCache(int32_t worker_id); - /// \brief Get the column map from cache server - Status UpdateColumnMapFromCache(); - - private: - constexpr static int32_t connector_capacity_ = 1024; - QueueList> io_block_queues_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc deleted file mode 100644 index 196a8790df..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/cache_lookup_op.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/execution_tree.h" -#include "utils/log_adapter.h" -#include "utils/system/crc32c.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - rows_per_buffer_ = cfg->rows_per_buffer(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. 
-Status CacheLookupOp::Builder::SanityCheck() const { - if (build_cache_client_ == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient"); - } - // Make sure the cache client has a valid session - if (!build_cache_client_->session_id()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Cache client for CacheLookupOp is missing session id"); - } - return Status::OK(); -} - -// The builder "build" method creates the final object and does some init on it -Status CacheLookupOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, rows_per_buffer_, - build_cache_client_, build_sampler_); - return Status::OK(); -} -Status CacheLookupOp::operator()() { - if (!sampler_) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "CacheLookupOp requires a sampler before it can be executed!"); - } - RETURN_IF_NOT_OK(RegisterResources()); - // Kick off the workers - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1))); - // required task group sync after launching workers - TaskManager::FindMe()->Post(); - // We have to wait until the leaf op has handshake with us. - RETURN_IF_NOT_OK(leaf_op_wp_.Wait()); - RETURN_IF_NOT_OK(FetchSamplesToWorkers()); - return Status::OK(); -} -Status CacheLookupOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(FetchFromCache(worker_id)); - return Status::OK(); -} -Status CacheLookupOp::ResetSampler() { return Status::OK(); } -Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { - // We act like a sampler and as a dataset op. During handshake with leaf op, - // We must wait until the leaf op has indexed everything. - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); - // Now we notify the main thread handshake has finished. - leaf_op_wp_.Set(); - return Status::OK(); -} -Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } -void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } -Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { - std::vector cache_miss; - RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); - // Ignore the case we have no cache miss, we can't return empty samples. - while (cache_miss.empty()) { - RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); - } - // Special code for eoe - if (cache_miss.at(0) == eoe_row_id) { - *out_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - std::shared_ptr sample_ts; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); - auto idPtr = sample_ts->begin(); - for (auto i = 0; i < cache_miss.size(); ++i) { - *idPtr = cache_miss.at(i); - ++idPtr; - } - TensorRow row; - row.push_back(sample_ts); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} -Status CacheLookupOp::RegisterResources() { - RETURN_IF_NOT_OK(CacheBase::RegisterResources()); - RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks())); - return Status::OK(); -} -Status CacheLookupOp::ComputeColMap() { - // We don't know the column map at this point unless we contact the cache server - // to fetch the schema but the cache server may not have it at this point either. 
- // So we will just return OK and let MergeOp (our parent) to handle it. - return Status::OK(); -} - -// Visitor accept method for NodePass -Status CacheLookupOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h deleted file mode 100644 index 526fb7c3a7..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/cache_base_op.h" - -namespace mindspore { -namespace dataset { -/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. -/// \note For non-mappable dataset, please see CacheOp -/// \see CacheOp -class CacheLookupOp : public CacheBase, public Sampler { - public: - class Builder { - public: - /// \brief Builder constructor. Creates the builder object. - /// \note No default args - Builder(); - - /// Default destructor - ~Builder() = default; - - /// Setter method. - /// \treturn Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheLookupOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t rows_per_buffer_; - int32_t build_op_connector_size_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - // Check if the required parameters are set by the builder. - // \return Status The error code return - Status SanityCheck() const; - }; - /// \brief Constructor - /// \note It takes the same argument as the base class. 
- /// \see CacheBase - CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler) - : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} - ~CacheLookupOp() = default; - // As a parallel op, we override these two functions - Status operator()() override; - Status WorkerEntry(int32_t worker_id) override; - // As a sampler, we override the following functions - Status ResetSampler() override; - Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; - Status InitSampler() override; - Status GetNextSample(std::unique_ptr *out_buffer) override; - void Print(std::ostream &out, bool show_all) const override; - bool AllowCacheMiss() override { return true; } - std::string Name() const override { return "CacheLookupOp"; } - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - protected: - Status ComputeColMap() override; - - private: - WaitPost leaf_op_wp_; - - Status RegisterResources() override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc deleted file mode 100644 index f2d5173348..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc +++ /dev/null @@ -1,302 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/cache_merge_op.h" - -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -CacheMergeOp::~CacheMergeOp() = default; -void CacheMergeOp::Print(std::ostream &out, bool show_all) - const { // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\n\n"; - } -} -CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler) - : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} -Status CacheMergeOp::operator()() { - // A queue of row id to let cleaner send cache miss rows to the cache server - // We don't want a small queue as this will block the parallel op workers. - // A row id is 8 byte integer. So bigger size doesn't consume a lot of memory. - static const int32_t queue_sz = 512; - io_que_ = std::make_unique>(queue_sz); - RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); - // One dedicated thread to move TensorRow from the pool to the cache server - for (auto i = 0; i < num_cleaners_; ++i) { - RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); - } - TaskManager::FindMe()->Post(); - return Status::OK(); -} -// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait -// until it shows up in the pool. -Status CacheMergeOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::shared_ptr cache_hit_stream = child_[kCacheHitChildIdx]; - std::unique_ptr db_ptr; - RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); - while (!db_ptr->eof()) { - if (db_ptr->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - db_ptr.reset(); - RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); - } else { - // See if there is any missing row - auto tbl = std::make_unique(); - while (db_ptr->NumRows() > 0) { - TensorRow row; - RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); - if (row.empty()) { - auto row_id = row.getId(); - TensorRowRequest *rq = nullptr; - RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - // Block until the row shows up in the pool. 
- RETURN_IF_NOT_OK(rq->Wait(&row)); - } - tbl->push_back(std::move(row)); - } - db_ptr->set_tensor_table(std::move(tbl)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); - RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); - } - } - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); - return Status::OK(); -} -Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { - TaskManager::FindMe()->Post(); - // We will simply pop TensorRow from the stream and insert them into the pool and - // wake up any worker that is awaiting on the missing TensorRow. - // If we see an eoe, ignore it. For eof, we exit. - std::shared_ptr cache_missing_stream = child_[kCacheMissChildIdx]; - // Before we start, cache the schema at the server. Pick one of the workers - // do it. The schema should have been done at prepare time. - if (workerId == 0) { - RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); - } - std::unique_ptr db_ptr; - RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); - while (!db_ptr->eof()) { - if (db_ptr->eoe()) { - // Ignore it. - MS_LOG(DEBUG) << "Ignore eoe"; - } else { - while (db_ptr->NumRows() > 0) { - TensorRow row; - RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); - row_id_type row_id = row.getId(); - if (row_id < 0) { - std::string errMsg = "Expect positive row id: " + std::to_string(row_id); - RETURN_STATUS_UNEXPECTED(errMsg); - } - TensorRowRequest *rq = nullptr; - RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - rq->WakeUpAny(std::move(row)); - // Let the cleaner to flush out this row (async) to the cache server. - RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); - } - } - RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); - } - return Status::OK(); -} -Status CacheMergeOp::Cleaner() { - TaskManager::FindMe()->Post(); - while (true) { - row_id_type row_id; - RETURN_IF_NOT_OK(io_que_->PopFront(&row_id)); - if (row_id < 0) { - break; - } - TensorRowRequest *rq = nullptr; - RETURN_IF_NOT_OK(GetRq(row_id, &rq)); - if (rq->GetState() == TensorRowRequest::State::kClean) { - // If already flushed, move on to the next one. - continue; - } - TensorRow row; - RETURN_IF_NOT_OK(rq->Release(&row)); - CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); - Status rc = cache_client_->WriteRow(row); - // Bad rc should not bring down the pipeline - if (rc.IsError()) { - MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); - } - rq->SetState(TensorRowRequest::State::kClean); - } - return Status::OK(); -} - -Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { - RETURN_UNEXPECTED_IF_NULL(out); - std::unique_lock lck(mux_); - auto it = cache_miss_map_.find(row_id); - if (it != cache_miss_map_.end()) { - *out = it->second.GetMutablePointer(); - } else { - // We will create a new one. 
- auto alloc = Services::GetAllocator(); - auto r = cache_miss_map_.emplace(row_id, MemGuard>(alloc)); - if (r.second) { - auto &mem = r.first->second; - RETURN_IF_NOT_OK(mem.allocate(1, row_id)); - *out = mem.GetMutablePointer(); - } else { - RETURN_STATUS_UNEXPECTED("Map insert fail."); - } - } - return Status::OK(); -} -Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own - // specific logic - CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); - RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - // Get the computed check sum from all ops in the cache miss class - uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]); - // This is a mappable cache op so the id's need to be generated. - // Construct the cache - const bool generate_ids = false; - Status rc = cache_client_->CreateCache(cache_crc, generate_ids); - if (rc.get_code() == StatusCode::kDuplicateKey) { - // We are told the cache has been created already. - MS_LOG(INFO) << "Cache created already"; - rc = Status::OK(); - } - RETURN_IF_NOT_OK(rc); - return Status::OK(); -} -Status CacheMergeOp::ComputeColMap() { - CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); - if (column_name_id_map().empty()) { - column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map(); - } - CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); - return Status::OK(); -} -Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { - RETURN_UNEXPECTED_IF_NULL(out); - // Block until the missing row is in the pool. - RETURN_IF_NOT_OK(use_count_.P()); - std::unique_lock lck(dq_mux_); - CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); - *out = std::move(row_.front()); - row_.pop_front(); - return Status::OK(); -} -void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { - std::unique_lock lck(dq_mux_); - // Technically number of this row shows up in the cache miss stream is equal to the number - // of P() call. However the cleaner wants it too. So we need an extra copy. - if (GetState() == State::kEmpty) { - // We will do a deep copy - for (auto &ts : row) { - auto out_ts = std::make_shared(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); - cleaner_copy_.push_back(out_ts); - } - cleaner_copy_.setId(row.getId()); - // Change the state to dirty - SetState(State::kDirty); - } - row_.push_back(std::move(row)); - // Bump up the use count by 1. This wake up any parallel worker which is waiting - // for this row. - use_count_.V(); -} -Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { - RETURN_UNEXPECTED_IF_NULL(out); - // We are not holding any mutex here because the cleaner isn't really touching the deque row_. - // In case we have multiple cleaners and they all see the copy, only one of them will - // get it. - auto expected = State::kDirty; - if (st_.compare_exchange_strong(expected, State::kClean)) { - *out = std::move(cleaner_copy_); - } - return Status::OK(); -} -// Builder constructor. Creates the builder object. -CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); - build_num_cleaners_ = 1; -} - -// Check if the required parameters are set by the builder. 
-Status CacheMergeOp::Builder::SanityCheck() const { - if (build_cache_client_ == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient"); - } - // Make sure the cache client has a valid session - if (!build_cache_client_->session_id()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Cache client for CacheMergeOp is missing session id"); - } - return Status::OK(); -} - -// The builder "build" method creates the final object and does some init on it -Status CacheMergeOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_num_cleaners_, - build_cache_client_, build_sampler_); - return Status::OK(); -} - -// Pre-Visitor accept method for NodePass -Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) { - // Downcast shared pointer then call the pre-visitation - return p->PreRunOnNode(shared_from_base(), modified); -} - -// Visitor accept method for NodePass -Status CacheMergeOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status CacheMergeOp::EoeReceived(int32_t worker_id) { - // If we are in a repeat path, send the eoe up. - // Otherwise ignore it. - if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { - return DatasetOp::EoeReceived(worker_id); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h deleted file mode 100644 index 60e2ebd0be..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/core/tensor_row.h" -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/util/queue.h" -#include "dataset/util/semaphore.h" - -namespace mindspore { -namespace dataset { -/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single -/// stream -class CacheMergeOp : public ParallelOp { - public: - // Some handshake structures among the main thread, cleaner threads and parallel op threads. - class TensorRowRequest { - public: - enum class State : uint8_t { - kEmpty = 0, // No row in the deque - kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. - kClean = 2 // The row has been flushed already. 
- }; - explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} - ~TensorRowRequest() = default; - State GetState() const { return st_; } - void SetState(State newState) { st_ = newState; } - Status Wait(TensorRow *out); - void WakeUpAny(TensorRow &&row); - Status Release(TensorRow *out); - - private: - std::mutex dq_mux_; - std::atomic st_; - Semaphore use_count_; - std::deque row_; - TensorRow cleaner_copy_; - }; - - constexpr static int kCacheHitChildIdx = 0; // Cache hit stream - constexpr static int kCacheMissChildIdx = 1; // Cache miss stream - - /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of - /// the arguments for constructing it. Use the builder by setting each argument - /// with the provided set methods, and then finally call the build method to execute - /// the actual construction. - class Builder { - public: - /// Builder constructor. Creates the builder object. - /// \note No default args - Builder(); - - /// Default destructor - ~Builder() = default; - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method - /// \param sampler - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief Setter method - /// \param num_cleaners - /// \return Builder setter method returns reference to the builder. - Builder &SetNumCleaner(int32_t num_cleaners) { - build_num_cleaners_ = num_cleaners; - return *this; - } - - /// The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheMergeOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t build_op_connector_size_; - int32_t build_num_cleaners_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - /// Check if the required parameters are set by the builder. 
- /// \return Status The error code return - Status SanityCheck() const; - }; - - /// \brief Constructor - /// \param numWorkers Number of parallel workers as a derived class of ParallelOp - /// \param opConnector Size Connector size as a derived class of ParallelOp - /// \param numCleaners Number of cleaners to move cache miss rows into the cache server - /// \param cache_client CacheClient to commmunicate with the Cache server - /// \param sampler as a derived class of ParallelOp - CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, - std::shared_ptr cache_client, const std::shared_ptr &sampler); - ~CacheMergeOp(); - void Print(std::ostream &out, bool show_all) const override; - friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) { - mo.Print(out, false); - return out; - } - /// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and - /// the threads for the cleaners. - /// \return - Status operator()() override; - /// \brief Entry function for worker thread that fetch rows from CacheLookupOp - /// \param workerId - /// \return Status object - Status WorkerEntry(int32_t workerId) override; - Status PrepareNodePostAction() override; - /// \brief Entry function for worker thread that fetch rows from the cache miss stream - /// \param workerId - /// \return Status object - Status CacheMissWorkerEntry(int32_t workerId); - Status GetRq(row_id_type row_id, TensorRowRequest **); - - /// \brief Base-class override for NodePass pre-visit acceptor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status PreAccept(NodePass *p, bool *modified) override; - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - /// \brief Base-class override for eoe handling - /// \param worker_id - /// \return Status object - Status EoeReceived(int32_t worker_id) override; - - protected: - Status ComputeColMap() override; - - private: - std::mutex mux_; - std::map>> cache_miss_map_; - std::unique_ptr> io_que_; - std::shared_ptr cache_client_; - int32_t num_cleaners_; - - /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for - /// moving cache miss TensorRow into the CacheServer. - /// \return Status object - Status Cleaner(); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc deleted file mode 100644 index 149f2b0bbb..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc
deleted file mode 100644
index 149f2b0bbb..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/cache_op.h"
-
-#include <memory>
-#include <vector>
-#include "dataset/core/config_manager.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/repeat_op.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/opt/pass.h"
-#include "dataset/util/task_manager.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace dataset {
-// Builder constructor. Creates the builder object.
-CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  build_num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
-  build_op_connector_size_ = cfg->op_connector_size();
-}
-
-// Check if the required parameters are set by the builder.
-Status CacheOp::Builder::SanityCheck() const {
-  if (build_cache_client_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
-  }
-  // Make sure the cache client has a valid session
-  if (!build_cache_client_->session_id()) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
-  }
-  return Status::OK();
-}
-
-// The builder "build" method creates the final object and does some init on it
-Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
-                                   build_cache_client_, build_sampler_);
-  RETURN_IF_NOT_OK((*ptr)->InitCache());
-
-  return Status::OK();
-}
-
-// Constructor of CacheOp
-CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
-    : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
-      num_guys_in_(0),
-      phase_(Phase::kBuildPhase) {}
-
-// Destructor
-CacheOp::~CacheOp() = default;
-
-// Private function for cache setup/init work just after construction
-Status CacheOp::InitCache() { return Status::OK(); }
-
-// This class functor will provide the master loop that drives the logic for performing the work
-Status CacheOp::operator()() {
-  if (!sampler_) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "CacheOp requires a sampler before it can be executed!");
-  }
-  RETURN_IF_NOT_OK(RegisterResources());
-  // Kick off the workers
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
-  // required task group sync after launching workers
-  TaskManager::FindMe()->Post();
-  // Wait for the workers to finish caching the rows.
-  RETURN_IF_NOT_OK(WaitForCachingAllRows());
-  RETURN_IF_NOT_OK(FetchSamplesToWorkers());
-  return Status::OK();
-}
-Status CacheOp::CacheAllRows(int32_t worker_id) {
-  // If the current phase is to fill the cache, do it then.
-  if (phase_ == Phase::kBuildPhase) {
-    // We will take the chance to cache the schema at the server.
-    // Just do it once and pick one worker to do it.
-    if (worker_id == 0) {
-      RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
-    }
-    MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. 
Worker: " << worker_id; - // SAVE mode loop - std::unique_ptr db_ptr; - RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); - while (!db_ptr->eof()) { - if (!db_ptr->eoe()) { - RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr))); - } else { - // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up - // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the - // the eoe to indicate the end of the epoch, we should next expect to get the eof. - // Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch - // from again. - RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); - if (!db_ptr->eof()) { - RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child."); - } - } - RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0)); - } - } - // Let the main guy know we are done. - auto last_guy_in = num_guys_in_.fetch_add(1); - if ((last_guy_in + 1) == num_workers_) { - rows_cache_done_.Set(); - } else { - // Let's do a sync up here. - RETURN_IF_NOT_OK(rows_cache_done_.Wait()); - } - return Status::OK(); -} -Status CacheOp::WaitForCachingAllRows() { - // Wait for the workers to finish caching the rows. - RETURN_IF_NOT_OK(rows_cache_done_.Wait()); - // Move from build phase to fetch phase if we are the one to fill the cache - if (phase_ == Phase::kBuildPhase) { - RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone()); - // Move to the next phase - phase_ = Phase::kFetchPhase; - } - // Get statistics from the server, and if we are not the one to create the cache, - // wait until the state changed from build phase to fetch base. - CacheClient::ServiceStat stat{}; - bool BuildPhaseDone = true; - do { - RETURN_IF_NOT_OK(cache_client_->GetStat(&stat)); - BuildPhaseDone = stat.cache_service_state == static_cast(CacheService::State::kFetchPhase); - if (!BuildPhaseDone) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - } while (!BuildPhaseDone); - const row_id_type min_key = stat.min_row_id; - const row_id_type max_key = stat.max_row_id; - num_rows_ = max_key - min_key + 1; - MS_LOG(INFO) << "Number of rows cached: " << num_rows_; - MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached; - MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached; - // Now all rows are cached and we have done a sync point check up. Next phase is - // is pick up fetch input from sampler and pass up to the caller. - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} -Status CacheOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(CacheAllRows(worker_id)); - RETURN_IF_NOT_OK(FetchFromCache(worker_id)); - return Status::OK(); -} -Status CacheOp::RegisterResources() { - RETURN_IF_NOT_OK(CacheBase::RegisterResources()); - RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks())); - return Status::OK(); -} - -// Base-class override for setting specific CacheOp configurations. This code will be called -// during the execution tree prepare phase BEFORE traversing down to child operators. -uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; } -// Base-class override for special eoe handler. -// CacheOp must override this because it shall not perform default handling of eoe. Instead -// the CacheOp manages actions related to the end of the epoch. 
-Status CacheOp::EoeReceived(int32_t worker_id) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} -// Base-class override for handling cases when an eof is received. -Status CacheOp::EofReceived(int32_t worker_id) { - // eofReceived is overloaded because we want to manually handle this eof. - // Specifically, the default behaviour is to pack it and flow it up to the next connection. - // In this case, we want a no-op behaviour so that we can perform correct action. - return Status::OK(); -} - -// Pre-Visitor accept method for NodePass -Status CacheOp::PreAccept(NodePass *p, bool *modified) { - // Downcast shared pointer then call the pre-visitation - return p->PreRunOnNode(shared_from_base(), modified); -} - -// Visitor accept method for NodePass -Status CacheOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -// A public wrapper for creating the cache through the client -Status CacheOp::CreateCache(uint32_t cache_crc) { - // This is a non-mappable cache op so the id's need to be generated. - // Construct the cache - const bool generate_ids = true; - Status rc = cache_client_->CreateCache(cache_crc, generate_ids); - if (rc.get_code() == StatusCode::kDuplicateKey) { - // We are told the cache has been created already. So we skip the build phase. - phase_ = Phase::kFetchPhase; - rc = Status::OK(); - } - RETURN_IF_NOT_OK(rc); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h deleted file mode 100644 index 6ec7e95ecf..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/cache_op.h +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ - -#include -#include -#include -#include -#include "dataset/engine/datasetops/cache_base_op.h" - -namespace mindspore { -namespace dataset { -/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset. -/// \note For mappable dataset, please see CacheLookupOp. -/// \see CacheLookupOp -class CacheOp : public CacheBase, public RandomAccessOp { - public: - // This CacheOp is for non-mappable case where it is divided into two phases. - // The first phase is we cache all the rows from the child (and let the cache server - // assigns row id). No read access in the first phase. Once the cache is fully built, - // we switch to second phase and fetch requests from the sampler. - enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 }; - - /// \brief The nested builder class inside of the CacheOp is used to help manage all of - /// the arguments for constructing it. 
Use the builder by setting each argument - /// with the provided set methods, and then finally call the build method to execute - /// the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - build_num_workers_ = num_workers; - return *this; - } - - /// \brief Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - build_op_connector_size_ = connector_size; - return *this; - } - - /// Setter method. - /// \return Builder setter method returns reference to the builder. - Builder &SetClient(std::shared_ptr cache_client) { - build_cache_client_ = cache_client; - return *this; - } - - /// \brief Setter method - /// \param rows_per_buffer - /// \return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - rows_per_buffer_ = rows_per_buffer; - return *this; - } - - /// \brief Setter method - /// \param sampler - /// \return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - build_sampler_ = std::move(sampler); - return *this; - } - - /// \brief The builder "build" method creates the final object and does some init on it. - /// \param ptr The shared_ptr to the new CacheOp object - /// \return Status - Status Build(std::shared_ptr *ptr); - - private: - int32_t build_num_workers_; - int32_t rows_per_buffer_; - int32_t build_op_connector_size_; - std::shared_ptr build_cache_client_; - std::shared_ptr build_sampler_; - - /// \brief Check if the required parameters are set by the builder. - /// \return Status The error code return - Status SanityCheck() const; - }; - - /// \brief Constructor of CacheOp - /// \note The builder class should be used to call it. - /// \param num_workers The number of worker threads. - /// \param op_connector_size The size of each queue in the connector. - CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, - std::shared_ptr cache_client, std::shared_ptr sampler); - - // Destructor - ~CacheOp(); - - /// \brief Base-class override for setting specific CacheOp configurations. This code will be called - /// during the execution tree prepare phase BEFORE traversing down to child operators. - uint32_t PrepareFlags() const override; - /// \brief Base-class override for special eoe handler. - /// CacheOp must override this because it shall not perform default handling of eoe. Instead - /// the CacheOp manages actions related to the end of the epoch. - /// \return Status - The error code return - Status EoeReceived(int32_t worker_id) override; - /// \brief Base-class override for NodePass pre-visit acceptor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status PreAccept(NodePass *p, bool *modified) override; - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - /// \brief Base-class override for handling cases when an eof is received. 
- /// \param worker_id - The worker id - /// \return Status - The error code return - Status EofReceived(int32_t worker_id) override; - Status operator()() override; - Status WorkerEntry(int32_t worker_id) override; - /// \brief Base-class override for handling cases if we allow cache miss - bool AllowCacheMiss() override { return false; } - /// \brief Base-class override for the name of this operator - std::string Name() const override { return "CacheOp"; } - /// \brief A public wrapper for creating the cache through the client - /// \param[in] cache_crc The crc that identifies the cache - /// \see cache_pass.cc - /// \return Status return code - Status CreateCache(uint32_t cache_crc); - - private: - WaitPost rows_cache_done_; - std::atomic num_guys_in_; - Phase phase_; - /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler. - /// \return Status object - Status WaitForCachingAllRows(); - /// \brief For non-mappable dataset, there is a build phase where we cache all the rows. - /// \return Status object - Status CacheAllRows(int32_t worker_id); - Status RegisterResources() override; - /// \brief Private function for cache setup/init work just after construction - /// \return Status The error code return - Status InitCache(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc deleted file mode 100644 index 2cf2e8045f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/concat_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -ConcatOp::Builder::Builder() { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -// The builder "build" method creates the final object. -Status ConcatOp::Builder::Build(std::shared_ptr *ptr) { - *ptr = std::make_shared(builder_op_connector_size_); - return Status::OK(); -} - -// Constructor of the ConcatOp. 
-ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
-
-// A function that prints info about the Operator
-void ConcatOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") :";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nDatasets: " << children_num_ << "\n\n";
-  }
-}
-
-// Main entry point for Concat
-Status ConcatOp::operator()() {
-  // The children_num_ parameter needs to be put here
-  children_num_ = static_cast<int32_t>(child_.size());
-  TaskManager::FindMe()->Post();
-  std::unique_ptr<DataBuffer> buf;
-  int eof_count = 0;
-  while (eof_count == 0) {
-    for (int i = 0; i < children_num_; i++) {
-      // 1. Read the first buffer
-      RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
-      if (buf->eof()) {
-        eof_count++;
-        continue;
-      }
-      // 2. Verify the column name, column data type, and rank of the column data
-      if (!buf->eoe()) {
-        RETURN_IF_NOT_OK(Verify(i, buf));
-      }
-      // 3. Put the data into output_connector
-      while (!buf->eoe() && !buf->eof()) {
-        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
-        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
-      }
-    }
-    // 4. Add an eoe buffer after getting a buffer from every child
-    if (eof_count == 0) {
-      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
-    }
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
-                               "Something went wrong, eof count does not match the number of children.");
-  // 5. Add an eof buffer at the end manually
-  MS_LOG(DEBUG) << "Add the eof buffer manually in the end.";
-  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
-  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
-  return Status::OK();
-}
-
-Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
-  TensorRow new_row;
-  buf->GetRow(0, &new_row);
-
-  if (id == 0) {
-    // Obtain the data type and data rank in child[0]
-    for (auto item : new_row) {
-      data_type_.push_back(item->type());
-      data_rank_.push_back(item->Rank());
-    }
-  } else {
-    // Compare the data type and data rank with those in child[0]
-    int32_t index = 0;
-    for (auto item : new_row) {
-      if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
-        RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same as the previous dataset.");
-      }
-    }
-  }
-  return Status::OK();
-}
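The flag algebra implemented above is easiest to see on a toy stream. A self-contained simulation, independent of the real DataBuffer/Connector types (all names here are illustrative):

  #include <deque>
  #include <iostream>
  #include <vector>

  enum class Flag { kRow, kEOE, kEOF };

  // Concatenate well-formed child streams: per pass, forward each child's rows,
  // then emit one EOE; once every child has reported EOF, emit a single EOF.
  std::vector<Flag> Concat(std::vector<std::deque<Flag>> children) {
    std::vector<Flag> out;
    int eof_count = 0;
    while (eof_count == 0) {
      for (auto &c : children) {
        Flag f = c.front();
        c.pop_front();
        if (f == Flag::kEOF) {
          eof_count++;
          continue;
        }
        while (f == Flag::kRow) {  // forward rows until this child's EOE
          out.push_back(Flag::kRow);
          f = c.front();
          c.pop_front();
        }
      }
      if (eof_count == 0) out.push_back(Flag::kEOE);  // one EOE per full pass
    }
    out.push_back(Flag::kEOF);
    return out;
  }

  int main() {
    // Two children with one epoch each: rows, then EOE, then EOF.
    std::deque<Flag> c0{Flag::kRow, Flag::kRow, Flag::kEOE, Flag::kEOF};
    std::deque<Flag> c1{Flag::kRow, Flag::kEOE, Flag::kEOF};
    for (Flag f : Concat({c0, c1})) {
      std::cout << (f == Flag::kRow ? "row " : f == Flag::kEOE ? "EOE " : "EOF ");
    }
    std::cout << "\n";  // prints: row row row EOE EOF
  }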
-// We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
-Status ConcatOp::ComputeColMap() {
-  if (column_name_id_map_.empty()) {
-    // Obtain columns_name_id_map from child_[0]
-    column_name_id_map_ = child_[0]->column_name_id_map();
-    if (column_name_id_map_.empty()) {
-      RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
-    }
-    // Verify all children have the same column name map
-    for (int32_t i = 0; i < child_.size(); ++i) {
-      if (child_[i]->column_name_id_map() != column_name_id_map_) {
-        RETURN_STATUS_UNEXPECTED("The column name or column order is not the same as the previous dataset.");
-      }
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
deleted file mode 100644
index e3dd890d07..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
-#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-#include "dataset/engine/datasetops/pipeline_op.h"
-
-namespace mindspore {
-namespace dataset {
-class ConcatOp : public PipelineOp {
- public:
-  // The nested builder class inside of the ConcatOp is used to help manage all of the arguments
-  // for constructing it. This Concat op is very simple though, so this builder is really just
-  // provided for a consistent look and feel for creators of Dataset operators overall.
-  class Builder {
-   public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @return This is a constructor.
-    Builder();
-
-    // Default destructor
-    ~Builder() = default;
-
-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new ConcatOp object
-    Status Build(std::shared_ptr<ConcatOp> *);
-
-   private:
-    int32_t builder_op_connector_size_;
-  };
-
-  // Constructor of the ConcatOp.
-  // @note The builder class should be used to call it
-  // @param op_connector_size - connector size
-  explicit ConcatOp(int32_t op_connector_size);
-
-  // Destructor
-  ~ConcatOp() = default;
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param ro - reference to the ConcatOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
-    ro.Print(out, false);
-    return out;
-  }
-
-  // All dataset ops operate by launching a thread (see ExecutionTree).
This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ConcatOp"; } - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - private: - Status Verify(int32_t id, const std::unique_ptr &buf); - - int32_t children_num_; // The num of child of parent node. - std::unordered_map column_name_id_; // Mapping between col index and col name - std::vector data_type_; - std::vector data_rank_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc deleted file mode 100644 index a963033833..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc +++ /dev/null @@ -1,391 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/dataset_op.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "utils/system/crc32c.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// Constructor -DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr sampler) - : oc_queue_size_(op_connector_size), - sampler_(sampler), - operator_id_(kInvalidOperatorId), - tree_(nullptr), - state_(OpState::kDeOpIdle), - op_ctrl_flags_(kDeOpNone), - out_connector_(nullptr) { - // The operator starts out with an invalid operator id. The only way to - // get it out of invalid state is to assign the operator to an execution tree. -} - -// Adds a operator to become our child. -Status DatasetOp::AddChild(std::shared_ptr child) { - if (std::dynamic_pointer_cast(child) != nullptr) { - std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node"); - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (operator_id_ == kInvalidOperatorId) { - std::string err_msg( - "Cannot add child node. Tree node connections can only" - "be made if the node belongs to a tree."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // disallow relationships with other trees - if (tree_ != child->tree_) { - std::string err_msg( - "Cannot add child node. 
Tree node connections can only be made if both nodes belong to the same tree.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  child_.push_back(child);
-  child->AddParent(this);
-  return Status::OK();
-}
-
-Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
-  if (operator_id_ == kInvalidOperatorId) {
-    std::string err_msg(
-      "Cannot remove child node. Tree node connections can only"
-      "be made if the node belongs to a tree.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  // disallow relationships with other trees
-  if (tree_ != child->tree_) {
-    std::string err_msg(
-      "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
-  child->RemoveParent(this);
-  return Status::OK();
-}
-
-Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
-  for (auto &prev_parent : this->parent_) {
-    RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
-    RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
-  }
-  RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
-  if (tree_->root()->id() == this->id()) {
-    tree_->AssignRoot(to_add);
-  }
-  return Status::OK();
-}
-
-// Adds a parent operator to this operator
-void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
-
-// Removes a parent operator from this operator
-void DatasetOp::RemoveParent(const DatasetOp *parent) {
-  parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
-}
-
-// Removes this node from the tree and connects its parent and child together
-Status DatasetOp::Remove() {
-  if (parent_.size() > 1) {
-    std::string err_msg("No support for op removal if the operator has more than one parent");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (child_.size() > 1) {
-    std::string err_msg("No support for op removal if the operator has more than one child");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  // Scenarios when removing node B:
-  // A -> B -> C
-  // A -> B
-  // B -> C
-  //
-  // If we remove B, we first take B's child (C) and point that child's parent at B's parent (A).
-  // It's possible the parent is null if we are the root node being removed.
-  if (!child_.empty()) {
-    // If we have a parent, then assign the child's parent to point to our parent.
-    if (!parent_.empty()) {
-      child_[0]->parent_[0] = parent_[0];
-    } else {
-      // We don't have a parent, so we are the root node being removed.
-      // clear the parent list of our child so that it becomes the new root.
-      child_[0]->parent_.clear();
-      tree_->AssignRoot(child_[0]);
-    }
-  }
-
-  // Next, if we had a parent, then set its child to be our child.
-  if (!parent_.empty()) {
-    // if we have a child, then set our parent to point to it
-    if (!child_.empty()) {
-      parent_[0]->child_[0] = child_[0];
-    } else {
-      // We don't have a child, so clear the child list of the current
-      // parent because it will be empty once we are removed.
-      parent_[0]->child_.clear();
-    }
-  }
-
-  // Finally, clear "this" op's parent and child pointers since we have just
-  // disconnected it from the tree and invalidated its fields.
-  child_.clear();
-  parent_.clear();
-  operator_id_ = kInvalidOperatorId;
-  tree_ = nullptr;
-
-  return Status::OK();
-}
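A worked check of the relink logic above on the A -> B -> C scenario, using a minimal stand-in node type rather than the real DatasetOp:

  #include <cassert>
  #include <vector>

  // Minimal stand-in for the parent/child bookkeeping in Remove(); raw pointers only.
  struct Node {
    std::vector<Node *> parent, child;
  };

  // Remove b from between its single parent and single child, as in DatasetOp::Remove().
  void Remove(Node *b) {
    if (!b->child.empty()) {
      if (!b->parent.empty()) b->child[0]->parent[0] = b->parent[0];
      else b->child[0]->parent.clear();  // b was the root; its child becomes the new root
    }
    if (!b->parent.empty()) {
      if (!b->child.empty()) b->parent[0]->child[0] = b->child[0];
      else b->parent[0]->child.clear();
    }
    b->parent.clear();
    b->child.clear();
  }

  int main() {
    Node a, b, c;
    a.child = {&b}; b.parent = {&a}; b.child = {&c}; c.parent = {&b};
    Remove(&b);
    assert(a.child[0] == &c && c.parent[0] == &a);  // A -> C after removal
  }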
-
-// Getter function to get a shared pointer to our child
-std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
-  std::shared_ptr<DatasetOp> return_op = nullptr;
-  if (child_.empty()) {
-    return return_op;
-  }
-  MS_ASSERT(child_index < static_cast<int32_t>(child_.size()));
-  // Return a shared pointer
-  return child_[child_index];
-}
-
-// Getter function to get the parent pointer
-void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
-  if (parent_.empty()) {
-    // common case if this is a root node
-    *parent = nullptr;
-  } else {
-    MS_ASSERT(parent_index < static_cast<int32_t>(parent_.size()));
-    *parent = parent_[parent_index];
-  }
-}
-
-// Creates the connector within this operator
-void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
-  MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
-                << ". Consumer: " << num_consumers << ".";
-  if (oc_queue_size_ > 0) {
-    out_connector_ = std::make_unique<DbConnector>(num_producers,  // The number of producers
-                                                   num_consumers,  // Only one consumer (the training App)
-                                                   oc_queue_size_);
-  } else {
-    // Some ops may choose not to have an output connector
-    MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
-    out_connector_ = nullptr;
-  }
-}
-
-// A print method typically used for debugging. showAll of true will recursively descend to child prints
-void DatasetOp::Print(std::ostream &out, bool show_all) const {
-  // When show_all is false, we display a 1 liner piece of text for the op.
-  // When show_all is true, we display more detailed output for the op.
-  // Derived printers should show their own header info, then call base class printer, followed by
-  // derived-specific items.
-  // For now, the base class doesn't have any summary info to show so it's a no-op in that case.
-  if (show_all) {
-    // The detailed display will show common base class info of the op. Allow the derived class to print
-    // its own id and name though as the first line.
-    out << "\nNumber of children : " << child_.size();
-    for (size_t i = 0; i < child_.size(); i++) {
-      out << "\n  Child[" << i << "] id: " << child_[i]->id();
-    }
-    out << "\nNumber of parents : " << parent_.size();
-    for (size_t i = 0; i < parent_.size(); i++) {
-      out << "\n  Parent[" << i << "] id: " << parent_[i]->id();
-    }
-    out << "\nConnector queue size : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
-        << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
-    if (sampler_) {
-      sampler_->Print(out, show_all);
-    }
-  }
-}
-
-// Gets the next buffer from the given child
-Status DatasetOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
-#if defined(_WIN32) || defined(_WIN64)
-  RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), p_buffer, retry_if_eoe));
-#else
-  std::unique_ptr<DataBuffer> next_buff;
-  // pop is a blocking call and will throw an interruption if the whole group shuts down.
-  RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), &next_buff, retry_if_eoe));
-
-  *p_buffer = std::move(next_buff);
-#endif
-  return Status::OK();
-}
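For orientation, a sketch of how a consumer typically drives this interface: pop buffers until an eof arrives, treating each eoe as an epoch boundary. DrainToEof is a hypothetical helper, not part of the class:

  // Drain `op` until EOF; assumes `op` is an already-prepared DatasetOp.
  Status DrainToEof(std::shared_ptr<DatasetOp> op) {
    std::unique_ptr<DataBuffer> buf;
    RETURN_IF_NOT_OK(op->GetNextBuffer(&buf));
    while (!buf->eof()) {
      if (!buf->eoe()) {
        // consume the rows in buf here
      }  // on eoe: one epoch has ended; fetch again for the next epoch
      RETURN_IF_NOT_OK(op->GetNextBuffer(&buf));
    }
    return Status::OK();
  }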
-// Gets the next buffer from the given child. This function also has built-in eoe and eof
-// message handling so that child classes don't have to manually code pass-through logic when
-// those messages are received.
-Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, int32_t child_index) {
-  if (child_.size() == 0) {
-    return this->GetNextBuffer(p_buffer, worker_id);
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index));
-  std::shared_ptr<DatasetOp> child = child_[child_index];
-  std::unique_ptr<DataBuffer> buf;
-  RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
-  // Loop until a non-EOE buffer is received
-  while (buf->eoe()) {
-    RETURN_IF_NOT_OK(EoeReceived(worker_id));
-    if (state_ == OpState::kDeOpIdle) {
-      *p_buffer = std::move(buf);
-      return Status::OK();
-    }
-    RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
-  }
-  // Check if the last buffer is an eof
-  if (buf->eof()) {
-    RETURN_IF_NOT_OK(EofReceived(worker_id));
-  }
-  *p_buffer = std::move(buf);
-  return Status::OK();
-}
-
-// Performs handling for when an eoe message is received.
-// The base class implementation simply flows the eoe message to output. Derived classes
-// may override if they need to perform special eoe handling.
-Status DatasetOp::EoeReceived(int32_t worker_id) {
-  std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-  return (out_connector_->Add(static_cast<int>(worker_id), std::move(eoe_buffer)));
-}
-
-// Performs handling for when an eof message is received.
-// The base class implementation simply flows the eof message to output. Derived classes
-// may override if they need to perform special eof handling.
-Status DatasetOp::EofReceived(int32_t worker_id) {
-  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
-  return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
-}
-
-// During tree prepare phase, operators may have specific pre-operations to perform depending on
-// their role.
-Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
-
-// During tree prepare phase, operators may have specific post-operations to perform depending on
-// their role.
-Status DatasetOp::PrepareNodePostAction() {
-  // Creating Connector object for each op.
-  // The consumer of the root node is assumed to be one thread.
-  // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
-  if (parent_.empty()) {
-    this->CreateConnector(num_producers(), 1);
-  } else {
-    this->CreateConnector(num_producers(), parent_[0]->num_consumers());
-  }
-  if (out_connector_) {
-    RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
-  }
-  RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
-
-  // Generate the column name map for the current op.
-  RETURN_IF_NOT_OK(this->ComputeColMap());
-
-  return Status::OK();
-}
-
-// Getter function. Base class does not have any special flags setting.
-uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; }
-
-// Derived classes may implement the reset function if the operator is stateful and needs
-// specific reset handling that is not contained in this common code version of the reset.
-Status DatasetOp::Reset() {
-  state_ = OpState::kDeOpRunning;
-  return Status::OK();
-}
-
-// gives a string output for the column map for handy debug printing
-std::string DatasetOp::ColumnNameMapAsString() const {
-  std::string outStr = "Column name id map: ";
-  for (auto &it : column_name_id_map_) {
-    outStr += (" " + it.first + ":" + std::to_string(it.second));
-  }
-  return outStr;
-}
-
-// Computing the assignment of the column name map.
-// This just inherits the column map from its first child, can only be used if the number of children is 1. -// Operations changing the column map must overwrite this function. -Status DatasetOp::ComputeColMap() { - if (child_.size() > 1) { - RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators."); - } - if (column_name_id_map_.empty()) { - column_name_id_map_ = child_[0]->column_name_id_map(); - if (column_name_id_map_.empty()) { - RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!"); - } - MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -Status DatasetOp::PreAccept(NodePass *p, bool *modified) { - // DatasetOp is the base class of visitor target pre-visit. - // This method will only be called if its derived class does not implement one. - return p->PreRunOnNode(shared_from_this(), modified); -} - -Status DatasetOp::Accept(NodePass *p, bool *modified) { - // DatasetOp is the base class of visitor target. - // This method will only be called if its derived class does not implement one. - return p->RunOnNode(shared_from_this(), modified); -} - -// Getter for the sampler, and it also removes the sampler from the op -Status DatasetOp::FetchRemoveSampler(std::shared_ptr *sampler) { - *sampler = sampler_; // It's okay if it sampler_ points to nullptr - sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler - return Status::OK(); -} - -uint32_t DatasetOp::GenerateCRC(const std::shared_ptr &op) { - std::stringstream ss; - op->tree_->Print(ss, op); - std::string ss_str = ss.str(); - - // Filter out the Operator control flags field when generating the check sum - ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), ""); - - // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline - ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), ""); - ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), ""); - - // The Cache crc and Server cache id field is different when creating new cache_client and re-using the same - // cache_client later. So we filter out these two fields to allow cache sharing. - ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), ""); - ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), ""); - - uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length()); - return cache_crc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h deleted file mode 100644 index b5bcb17b4b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h +++ /dev/null @@ -1,363 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/db_connector.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -// Forward declare -class ExecutionTree; - -class DataBuffer; - -class NodePass; - -class Sampler; - -/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so -/// the actual implementation of the operators will be derived from here. -class DatasetOp : public std::enable_shared_from_this { - // Allow execution tree to access internal members - friend class ExecutionTree; - - public: - static constexpr int32_t kInvalidOperatorId = -1; - - // Operator control flags - enum OpControlFlags { - kDeOpNone = 0, - kDeOpRepeated = 1, // Operator is a node in a repeat path - kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop - }; - - // Flags that control operator runtime behaviours - enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated }; - - /// Constructor - /// \param op_connector_size - The size for the output connector of this operator. - /// \param sampler - The sampler for the op - explicit DatasetOp(int32_t op_connector_size, std::shared_ptr sampler); - - /// Destructor - virtual ~DatasetOp() { tree_ = nullptr; } - - /// Adds a operator to become our child. - /// \param child - shared pointer to the child to add. - Status AddChild(std::shared_ptr child); - - /// Remove a operator from our children. - /// \param child - shared pointer to the child to remove. - Status RemoveChild(std::shared_ptr child); - - /// \brief Removes this node from the tree and connects it's parent/child together - /// \return Status eerror code returned - Status Remove(); - - /// \brief Getter function to get a shared pointer to our child - /// \param[in] child_index An operator can have n children. Indicates which child to return. - /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index - std::shared_ptr child(int32_t child_index) const; - - /// \brief Getter function to get the pointer to our parent - /// If there are no parents, it returns null regardless of the given index - /// \param[in] parent_index An operator can have n parents. Indicates which parent to return. - void Parent(DatasetOp **parent, int32_t parent_index) const; - - // Inserts a operator as the parent current op. - // Inserted op will become the sole parent of the current op. - // The existing parent of the current op will be transferred to the inserted op. 
- Status InsertAsParent(std::shared_ptr to_add); - - /// \brief Creates the connector within this operator - /// \param num_producers - number of threads that write into this connector - /// \param num_consumers - number of threads that read from this connector - void CreateConnector(int32_t num_producers, int32_t num_consumers); - - /// \brief A print method typically used for debugging - /// \param out - The output stream to write output to - /// \param show_all - A bool to control if you want to show all info or just a summary - virtual void Print(std::ostream &out, bool show_all) const; - - /// \brief << Stream output operator overload - /// \notes This allows you to write the debug print info using stream operators - /// \param out - reference to the output stream being overloaded - /// \param dO - reference to the DatasetOp to display - /// \return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) { - dO.Print(out, false); - return out; - } - - /// \brief Class functor operator (). - /// DatasetOps operate by launching a thread (see ExecutionTree). - /// This pure virtual version makes the requirement that derived classes must provide a functor - /// that will execute their main runtime loop code. - /// \return Status - The error code return - virtual Status operator()() = 0; - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id) { - return GetNextBuffer(p_buffer, worker_id, false); - } - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer) { return GetNextBuffer(p_buffer, 0, false); } - - /// \brief Gets the next buffer from the given child - /// \notes See GetNextInput for similar function that has built-in message handling - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE. - /// \return Status - The error code return - virtual Status GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe); - - /// \brief Gets the next buffer from the given child . This function also has built-in eoe and eof - /// message handling so that child classes don't have to manually code pass-through logic when - /// those messages are received. - /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference) - /// \param worker_id - The worker id - /// \return Status - The error code return - Status GetNextInput(std::unique_ptr *p_buffer, int32_t worker_id = 0, int32_t child_index = 0); - - /// \brief Performs handling for when an eoe message is received. - /// The base class implementation simply flows the eoe message to output. Derived classes - /// may override if they need to perform special eoe handling. 
- /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status EoeReceived(int32_t worker_id); - - /// \brief Performs handling for when an eof message is received. - /// The base class implementation simply flows the eof message to output. Derived classes - /// may override if they need to perform special eof handling. - /// \param worker_id - The worker id - /// \return Status - The error code return - virtual Status EofReceived(int32_t worker_id); - - /// \brief Derived classes may implement the reset function if the operator is stateful and needs - /// specific reset handling that is not contained in this common code version of the reset - /// \return Status - The error code return - virtual Status Reset(); - - /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on - /// their role. - /// \notes Derived versions of this function should always call it's superclass version first - /// before providing their own implementations. - virtual Status PrepareNodePreAction(); - - /// \brief During tree prepare phase, operators may have specific post-operations to perform depending on - /// their role. - /// \notes Derived versions of this function should always call it's superclass version first - /// before providing their own implementations. - virtual Status PrepareNodePostAction(); - - /// \brief Getter function - /// \return The operator id - int32_t id() const { return operator_id_; } - - /// \brief Getter function - /// \return The prepare flags - virtual uint32_t PrepareFlags() const; - - /// \brief Getter function - /// \return The number of workers in this op - virtual int32_t num_workers() const = 0; - - /// \brief Getter function - /// \return The number of threads consuming from previous op. - virtual int32_t num_consumers() const = 0; - - /// \brief Getter function - /// \return The number of threads producing to the output connector. - virtual int32_t num_producers() const = 0; - - /// \brief Getter function - /// \return T/F if this is an inlined operator - bool inlined() const { return (oc_queue_size_ == 0); } - - /// \brief Setter function - /// \return Sets the control flags - void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); } - - /// \brief Setter function - /// \return Sets the control flags - void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); } - - /// \brief Register the internal worker connectors. No op unless it is a parallel op - /// \return Status - virtual Status RegisterWorkerConnectors() { return Status::OK(); } - - /// \brief Getter for the column name mapping - /// \return The returned map - std::unordered_map column_name_id_map() const { return column_name_id_map_; } - - /// \brief Checks if the column name map has been set up yet for this op - /// \return - T/F if the operator has the map set up - bool HasColumnNameMap() const { return (column_name_id_map_.empty()); } - - /// \brief gives a string output for the column map for handy debug printing - /// \return - the column name map as a string - std::string ColumnNameMapAsString() const; - - /// \brief Getter function - /// \return connector size of current op - int32_t ConnectorSize() const { - if (!inlined()) { - return out_connector_->size(); - } - // Return child connector size for inlined op - return ChildOpConnectorSize(); - } - - /// \brief Counting number of buffer sent out by a connector - int64_t ConnectorOutBufferCount() const { - return out_connector_ == nullptr ? 
int64_t(-1) : static_cast(out_connector_->out_buffers_count()); - } - - /// \brief Getter function - /// \return connector size of current op - int32_t ConnectorCapacity() const { - if (!inlined()) { - return out_connector_->capacity(); - } - // Return child connector capacity for inlined op - return ChildOpConnectorCapacity(); - } - - /// \brief Getter function - /// \return connector size of child op - int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); } - - /// \brief Getter function - /// \return connector capacity of child op - int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); } - - /// \brief Children Getter - /// \return Vector of Children - std::vector> Children() const { return child_; } - - /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up - /// in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main - /// visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it - /// requires special node visit access. Check "dataset/engine/opt/pass.h" for more details. - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - virtual Status PreAccept(NodePass *p, bool *modified); - - /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access. - /// Check "dataset/engine/opt/pass.h" for more details. - /// \param[in] p The node to visit - /// \param[out] modified Indicator if the node was modified - /// \return Status of the node visit - virtual Status Accept(NodePass *p, bool *modified); - - /// Op name getter - /// \return Name of the current Op - virtual std::string Name() const { return "DatasetOp"; } - - /// Execution Tree getter - /// \return Pointer to the ExecutionTree the current op belongs to, no ownership - ExecutionTree *Tree() { return tree_; } - - /// Getter for the sampler - /// \return Shared pointer to the sampler (may return nullptr) - std::shared_ptr sampler() { return sampler_; } - - /// \brief Getter for the sampler, and it also removes the sampler from the op - /// \param[out] sampler A pointer to the output sampler that was removed - /// \return Status error code - Status FetchRemoveSampler(std::shared_ptr *sampler); - - // Computes a CRC value for the operator - static uint32_t GenerateCRC(const std::shared_ptr &op); - - /// \brief A helper templated function for casting "this" pointer to shared_ptr - /// Similar to shared_from_this, except this one will give you the derived class as shared_ptr - /// \return A shared_ptr casted to the derived class - template - std::shared_ptr shared_from_base() { - return std::static_pointer_cast(shared_from_this()); - } - - /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one. 
- void SetSampler(std::shared_ptr sampler) { sampler_ = sampler; } - - /// \brief Checks if this is a leaf node (0 children) - /// \return boolean returns true if it's a leaf - bool IsLeaf() { return (child_.empty()); } - - protected: - /// \brief Removes a parent operator from this operator - /// \notes External callers do not have access to this function - /// \param[in] parent The parent node to remove - void RemoveParent(const DatasetOp *parent); - - /// \brief Adds a parent operator to this operator - /// \notes External callers do not have access to this function - /// \param[in] parent The parent node to add - void AddParent(DatasetOp *parent); - - /// Compute the current op's column map using its child's column map. - /// Get called during the tree post-prepare phase in PrepareNodePostAction. - /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1. - /// Operations changing the column map it inherits from the child must overwrite this function. - /// \return - Status - virtual Status ComputeColMap(); - - std::vector> child_; // Child nodes - std::vector parent_; // Parent nodes. No ownership - std::shared_ptr sampler_; // Some leaf ops might have a sampler - int32_t oc_queue_size_; // Capacity for each out_connector_ - int32_t operator_id_; // Generated id for the node - ExecutionTree *tree_; // Back pointer to our tree. - OpState state_; // The state of the operator, Running, Idle, Terminated - uint32_t op_ctrl_flags_; // Flags for the operator - std::unique_ptr out_connector_; // Output Connector - std::unordered_map column_name_id_map_; // Mapping between col index and col name - std::mutex column_name_map_mutex_; // For protecting shared access to the column map - - private: - /// Sets the operator id. - /// \notes No public interface. Only the class itself, or it's friend the execution tree can set - /// this - /// \param op_id - the Id value to set into the operator - void set_id(int32_t op_id) { operator_id_ = op_id; } - - /// Sets the tree into the op so that the operator has a back pointer to the tree. - /// \param tree - the tree to assign to the op. - void set_tree(ExecutionTree *tree) { tree_ = tree; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc deleted file mode 100644 index 0f1fefc0f0..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/util/status.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch) - : PipelineOp(op_connector_size), - channel_name_(channel_name), - device_type_(device_type), - device_id_(device_id), - prefetch_size_(prefetch_size), - num_batch_(num_batch) {} - -DeviceQueueOp::~DeviceQueueOp() {} - -#ifdef ENABLE_GPUQUE -void ReleaseData(void *addr) { - if (addr != nullptr) { - free(addr); - } -} -#endif - -DeviceQueueOp::Builder::Builder(int32_t prefetch_size) - : builder_prefetch_size_(prefetch_size), - builder_device_id_(0), - builder_device_type_(DeviceType::CPU), - builder_channel_name_(""), - builder_num_batch_(0) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status DeviceQueueOp::EoeReceived(int32_t worker_id) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -Status DeviceQueueOp::operator()() { - TaskManager::FindMe()->Post(); - - if (device_type_ == DeviceType::Ascend) { -#ifdef ENABLE_TDTQUE - RETURN_IF_NOT_OK(SendDataToAscend()); -#endif - } else if (device_type_ == DeviceType::GPU) { -#ifdef ENABLE_GPUQUE - RETURN_IF_NOT_OK(SendDataToGPU()); -#endif - } else if (device_type_ == DeviceType::CPU) { - RETURN_IF_NOT_OK(SendDataToCPU()); - } - - return Status::OK(); -} - -Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { - // this method checks if the buffer meets the conditions to be sent to TDT - if (buffer->NumRows() != 0) { - TensorRow row; - buffer->GetRow(0, &row); - for (const auto &item : row) { - CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); - } - } - return Status::OK(); -} - -#ifdef ENABLE_TDTQUE -Status DeviceQueueOp::SendDataToAscend() { - MS_LOG(INFO) << "Device queue, sending data to Ascend."; - int64_t total_batch = 0; - bool is_break_loop = false; - double batch_start_time, end_time; - int32_t batch_cost, tdt_cost; - int32_t connector_size = 0; - int32_t connector_capacity; - std::shared_ptr profiling_node; - bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); - if (isProfilingEnable) { - std::shared_ptr node; - RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node)); - profiling_node = std::dynamic_pointer_cast(node); - batch_start_time = ProfilingTime::GetCurMilliSecond(); - connector_capacity = ChildOpConnectorCapacity(); - } - std::unique_ptr current_buffer; - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - - while (!current_buffer->eof() && !is_break_loop) { - while (!current_buffer->eoe() && !is_break_loop) { - RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); - TensorRow currRow; - for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { - RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); - auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); - 
if (status == TdtStatus::FAILED) { - return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); - } - - if (isProfilingEnable) { - end_time = ProfilingTime::GetCurMilliSecond(); - // record push tdt time - profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost); - batch_cost = (int32_t)(end_time - batch_start_time); - // record batch time - profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost); - // record pipeline time - profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost); - batch_start_time = end_time; - // record connector depth - profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); - } - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } - } - if (isProfilingEnable) { - connector_size = ChildOpConnectorSize(); - connector_capacity = ChildOpConnectorCapacity(); - } - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - } - if (isProfilingEnable) { - connector_size = ChildOpConnectorSize(); - connector_capacity = ChildOpConnectorCapacity(); - } - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - } - - tree_->SetFinished(); - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - return Status::OK(); -} -#endif - -#ifdef ENABLE_GPUQUE -Status DeviceQueueOp::SendDataToGPU() { - MS_LOG(INFO) << "Device queue, sending data to GPU."; - int64_t total_batch = 0; - bool is_break_loop = false; - bool is_open = false; - uint32_t handle = INVALID_HANDLE; - - std::unique_ptr current_buffer; - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - - while (!current_buffer->eof() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { - while (!current_buffer->eoe() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { - RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); - TensorRow curr_row; // batch data - for (int row_id = 0; - row_id < current_buffer->NumRows() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed(); row_id++) { - RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &curr_row)); - - std::vector data_size; - for (int i = 0; i < curr_row.size(); i++) { - data_size.push_back(static_cast(curr_row[i]->SizeInBytes())); - } - if (!is_open) { - handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, ReleaseData); - if (handle == INVALID_HANDLE) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "open failed"); - } - is_open = true; - } - RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - is_break_loop = true; - } - } - if (!TaskManager::FindMe()->Interrupted()) - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - else - is_break_loop = true; - } - if (!TaskManager::FindMe()->Interrupted()) - RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); - else - is_break_loop = true; - } - - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - GpuBufferMgr::GetInstance().Close(handle); - - GpuBufferMgr::GetInstance().CloseConfirm(); - - return Status::OK(); -} - -Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, - uint32_t handle) { - std::vector items; - for (int i = 0; i < data_size.size(); i++) { - device::DataItemGpu data_item; - data_item.data_len_ = data_size[i]; - data_item.data_ptr_ = nullptr; - items.push_back(data_item); - } - - while (!GpuBufferMgr::GetInstance().IsClosed() 
&& !TaskManager::FindMe()->Interrupted()) { - RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row)); - BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); - if (ret) { - for (int i = 0; i < items.size(); i++) { - free(items[i].data_ptr_); - } - if (ret == BlockQueueStatus_T::ERROR_INPUT) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); - } else { - MS_LOG(WARNING) << "Retry pushing data..."; - continue; - } - } else { - break; - } - } - return Status::OK(); -} - -Status DeviceQueueOp::MallocForGPUData(std::vector *items, const TensorRow &curr_row) { - int i = 0; - for (auto &sub_item : *items) { - sub_item.data_ptr_ = (unsigned char *)malloc(sub_item.data_len_); - if (sub_item.data_ptr_ == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memory malloc failed."); - } - (void)memset_s(sub_item.data_ptr_, sub_item.data_len_, 0, sub_item.data_len_); - const unsigned char *column_data = curr_row[i]->GetBuffer(); - if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, - static_cast(curr_row[i++]->SizeInBytes())) != 0) { - MS_LOG(ERROR) << "memcpy_s failed!"; - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed."); - } - } - - return Status::OK(); -} -#endif - -Status DeviceQueueOp::SendDataToCPU() { - MS_LOG(INFO) << "Device queue, sending data to CPU."; - int64_t total_batch = 0; - - std::unique_ptr child_iterator = std::make_unique(this, 0, 0); - while (!(child_iterator->eof_handled())) { - TensorRow curr_row; - RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&curr_row)); - - if (!curr_row.empty()) { - MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; - MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; - total_batch++; - if (num_batch_ > 0 && total_batch == num_batch_) { - break; - } - } - } - - MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; - - return Status::OK(); -} - -void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; - } -} - -// Visitor accept method for NodePass -Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h deleted file mode 100644 index a854004593..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -#ifdef ENABLE_TDTQUE -#include "dataset/engine/tdt/tdt_plugin.h" -#endif - -#ifdef ENABLE_GPUQUE -#include "device/gpu/gpu_buffer_mgr.h" -using mindspore::device::BlockQueueStatus_T; -using mindspore::device::GpuBufferMgr; -#endif - -namespace mindspore { -namespace dataset { -class DeviceQueueOp : public PipelineOp { - public: - static const uint32_t INVALID_HANDLE = 0xffffffffUL; - static const uint32_t WAIT_TIME = 5; - - enum class DeviceType { Ascend = 0, GPU = 1, CPU = 2 }; - - // The nested builder class inside of the DeviceQueueOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - explicit Builder(int32_t prefetch_size); - - // Default destructor - ~Builder() = default; - - Builder &SetPrefetchSize(int32_t prefetch_size) { - builder_prefetch_size_ = prefetch_size; - return *this; - } - - Builder &SetChannelName(const std::string &channel_name) { - builder_channel_name_ = channel_name; - return *this; - } - - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - Builder &SetDeviceType(const std::string &device_type) { - if (device_type == "Ascend") { - builder_device_type_ = DeviceType::Ascend; - } else if (device_type == "GPU") { - builder_device_type_ = DeviceType::GPU; - } else if (device_type == "CPU") { - builder_device_type_ = DeviceType::CPU; - } - return *this; - } - - Builder &SetDeviceId(int32_t device_id) { - builder_device_id_ = device_id; - return *this; - } - - Builder &SetNumBatch(int64_t num_batch) { - builder_num_batch_ = num_batch; - return *this; - } - - // Name: Build() - // Description: The final step for building a DeviceQueueOp via the Builder is - // to call this Build() method. It will instantiate the DeviceQueueOp - // and return it to caller as a shared pointer. 
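To make that flow concrete, here is a minimal usage sketch of this builder (the channel name, device id, and batch count are illustrative values only, not taken from any real pipeline):

    std::shared_ptr<DeviceQueueOp> op;
    Status rc = DeviceQueueOp::Builder(/*prefetch_size=*/16)
                  .SetChannelName("example_channel")  // illustrative name
                  .SetDeviceType("Ascend")            // mapped to DeviceType::Ascend by the setter above
                  .SetDeviceId(0)
                  .SetNumBatch(100)                   // 0 would mean: run until EOF
                  .Build(&op);

The op connector size defaults to the value from the global ConfigManager unless SetOpConnectorSize() is called explicitly.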
- Status Build(std::shared_ptr *ptr) { - *ptr = std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, - builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); - return Status::OK(); - } - - private: - int32_t builder_prefetch_size_; - int32_t builder_device_id_; - DeviceType builder_device_type_; - std::string builder_channel_name_; - int64_t builder_num_batch_; - int32_t builder_op_connector_size_; - }; - - // Name: constructor - // Description - DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, - int32_t op_connector_size, int64_t num_batch); - - // Name: destructor - // Description - ~DeviceQueueOp(); - - Status EoeReceived(int32_t worker_id) override; - - const int32_t get_prefetch_size() { return prefetch_size_; } - - // Name: Print() - // Description: A function that prints info about the node - void Print(std::ostream &out, // In: The output stream to print to - bool show_all) const override; // In: T/F if it should print everything - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const DeviceQueueOp &to) { - to.Print(out, false); - return out; - } - - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "DeviceQueueOp"; } - - private: - // Name: checkExceptions(DataBuffer); - // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp - Status CheckExceptions(const std::unique_ptr &buffer) const; - -#ifdef ENABLE_TDTQUE - Status SendDataToAscend(); -#endif - -#ifdef ENABLE_GPUQUE - Status SendDataToGPU(); - Status RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, uint32_t handle); - Status MallocForGPUData(std::vector *items, const TensorRow &curr_row); -#endif - - Status SendDataToCPU(); - std::string channel_name_; - DeviceType device_type_; - const int32_t device_id_; - const int32_t prefetch_size_; - const int64_t num_batch_; - -#ifdef ENABLE_TDTQUE - std::shared_ptr tdtInstancePtr; -#endif -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc deleted file mode 100644 index 81c93c6e1c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ /dev/null @@ -1,267 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/filter_op.h" -#include -#include -#include -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { - -Status FilterOp::Builder::SanityCheck() { - std::string err; - err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; - err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; - return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); -} - -FilterOp::Builder::Builder() { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status FilterOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, - builder_predicate_func_); - return Status::OK(); -} - -FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, - py::function predicate_func) - : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} - -Status FilterOp::operator()() { - // The operator class just starts off threads by calling the tree_ function. - RETURN_UNEXPECTED_IF_NULL(tree_); - filter_queues_.Init(num_workers_, oc_queue_size_); - RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); - Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)); - // Synchronize with TaskManager. - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - RETURN_IF_NOT_OK(Collector()); - return Status::OK(); -} - -Status FilterOp::EofReceived(int32_t) { return Status::OK(); } - -Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } - -// Validating if each of the input_columns exists in the DataBuffer. -Status FilterOp::ValidateInColumns(const std::vector *input_columns) { - for (const auto &inCol : *input_columns) { - bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false; - if (!found) { - std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); -} - -// A print method typically used for debugging. 
-void FilterOp::Print(std::ostream &out, bool show_all) const {
- // Always show the id and name as first line regardless if this summary or detailed print
- out << "(" << std::setw(2) << operator_id_ << ") :";
- if (!show_all) {
-   // Call the super class for displaying any common 1-liner info
-   ParallelOp::Print(out, show_all);
-   // Then show any custom derived-internal 1-liner info for this op
-   out << "\n";
- } else {
-   // Call the super class for displaying any common detailed info
-   ParallelOp::Print(out, show_all);
-   // Then show any custom derived-internal stuff
-   out << "\nInput column names:";
-   for (size_t i = 0; i < in_columns_.size(); i++) {
-     out << " " << in_columns_[i];
-   }
-   out << "\n\n";
- }
-}
-
-Status FilterOp::WorkerEntry(int32_t worker_id) {
- // Handshake with TaskManager that thread creation is successful.
- TaskManager::FindMe()->Post();
- std::unique_ptr<DataBuffer> in_buffer;
- bool worker_stop = false;
- while (worker_stop == false) {
-   // Getting a databuffer to work on.
-   RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
-   if (in_buffer->eoe()) {
-     filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
-     continue;
-   } else if (in_buffer->eof()) {
-     filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
-     worker_stop = true;
-     continue;
-   }
-
-   RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_));
-
-   // If the databuffer was fully filtered out, it is marked as kFilterEmpty.
-   // If the databuffer was partially filtered, it is marked as kFilterPartial.
-   // If the databuffer was not filtered at all, it is marked as kFilterFull.
-   int32_t num_rows = in_buffer->NumRows();
-   std::unique_ptr<TensorQTable> new_tensor_table;
-   RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table));
-
-   if (new_tensor_table->empty()) {
-     RETURN_IF_NOT_OK(
-       filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
-   } else if (new_tensor_table->size() == num_rows) {
-     in_buffer->set_tensor_table(std::move(new_tensor_table));
-     RETURN_IF_NOT_OK(
-       filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
-   } else {  // kFilterPartial
-     in_buffer->set_tensor_table(std::move(new_tensor_table));
-     RETURN_IF_NOT_OK(
-       filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
-   }
- }
- return Status::OK();
-}
-
-Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out) {
- *out = std::make_unique<TensorQTable>();
- int32_t num_rows = in_buffer->NumRows();
- for (int32_t i = 0; i < num_rows; i++) {
-   TensorRow to_process;
-   TensorRow cur_row;
-   RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
-   if (in_columns_.empty() == true) {
-     MS_LOG(INFO) << "Input columns in filter operator are empty; the predicate will be applied to all columns in the current table.";
-     to_process = cur_row;
-   } else {
-     (void)std::transform(
-       in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process),
-       [&cur_row, this](const auto &it) -> std::shared_ptr<Tensor> { return cur_row[column_name_id_map_[it]]; });
-   }
-   bool predicate = true;
-   RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
-   if (predicate) {
-     (*out)->push_back(std::move(cur_row));
-   }
- }
- return Status::OK();
-}
-
-// If the filtered DataBuffers were written directly to out_connector_,
-// the thread fetching data would block on a queue.
-// The Collector function therefore reorders the DataBuffers,
-// for example in two work queues: -// int filter_queues_: -// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) -// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) -// after reorder in out_connector_: -// queue1: DB(data2) DB(data4) DB(eof) -// queue2: DB(eoe) DB(eoe) -Status FilterOp::Collector() { - bool collector_stop = false; - uint64_t task_id_cnt = 0; - uint64_t out_id_cnt = 0; - std::pair, filterCtrl> in_pair; - while (collector_stop == false) { - uint32_t w_id = task_id_cnt % num_workers_; - RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); - if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || - in_pair.second == filterCtrl::kFilterEoe) { - uint32_t out_task_id = out_id_cnt % num_workers_; - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); - out_id_cnt++; - task_id_cnt++; - } else if (in_pair.second == filterCtrl::kFilterEof) { - uint32_t out_task_id = out_id_cnt % num_workers_; - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); - collector_stop = true; - } else { // kFilterEmpty - task_id_cnt++; - } - } - return Status::OK(); -} - -// Private function for checking the column legality. -Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { - int32_t num_rows = in_buf->NumRows(); - int32_t num_cols = in_buf->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); - } - // Check if there is invalid column name in the inColumns. - RETURN_IF_NOT_OK(ValidateInColumns(input_columns)); - return Status::OK(); -} - -Status FilterOp::CheckInput(const TensorRow &input) const { - for (auto &item : input) { - if (item == nullptr) { - RETURN_STATUS_UNEXPECTED("input is null."); - } - } - return Status::OK(); -} - -Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { - RETURN_IF_NOT_OK(CheckInput(input)); - // Acquire Python GIL. - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - // Transform input tensor vector into numpy array vector. - py::tuple input_args(input.size()); - for (size_t i = 0; i < input.size(); i++) { - py::array new_data; - RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); - input_args[i] = new_data; - } - // Invoke python function. 
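      // The *input_args syntax unpacks the tuple, so the user's predicate receives
      // one positional numpy argument per selected column; the result is cast to
      // bool below, all while the GIL acquired above is still held.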
- py::object ret_py_obj = predicate_func_(*input_args); - *out_predicate = ret_py_obj.cast(); - } catch (const py::error_already_set &e) { - std::stringstream ss; - ss << e.what() << std::endl; - ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; - return Status(StatusCode::kPyFuncException, ss.str()); - } - return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); -} - -// Visitor accept method for NodePass -Status FilterOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h deleted file mode 100644 index 36f70cb82f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h +++ /dev/null @@ -1,188 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/queue.h" - -namespace mindspore { -namespace dataset { - -class FilterOp : public ParallelOp { - public: - // The nested builder class inside of the FilterOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args. - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetPredicateFunc(py::function func) { - builder_predicate_func_ = std::move(func); - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetInColNames(const std::vector &in_col_names) { - build_in_col_names_ = in_col_names; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t connector_size) { - builder_op_connector_size_ = connector_size; - return *this; - } - - // The builder "build" method creates the final object. - // @param ptr The shared_ptr to the new FilterOp object. - // @return Status. - Status Build(std::shared_ptr *ptr); - - private: - // Sanity check for builder class args. - // @return Status - The error code return. 
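  // (As implemented in filter_op.cc above, the sanity check requires a positive
  //  op connector size and a positive number of parallel workers.)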
- Status SanityCheck();
- std::vector<std::string> build_in_col_names_;
- py::function builder_predicate_func_;
- int32_t builder_num_workers_;
- int32_t builder_op_connector_size_;
- };
-
- enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
-
- // Constructor of FilterOp
- // @note The builder class should be used to call it.
- // @param in_col_names A list of input column names; when it is empty, the predicate will be
- // applied to all columns in the dataset.
- // @param num_workers The number of worker threads.
- // @param op_connector_size The size of each queue in the connector.
- // @param predicate_func A Python callable which returns a boolean value.
- FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
-          py::function predicate_func);
-
- // Destructor
- ~FilterOp() = default;
-
- // Class functor operator () override.
- // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
- // provide the master loop that drives the logic for performing the work.
- // @return Status The error code return
- Status operator()() override;
-
- // @param int32_t workerId.
- // @return Status - The error code return.
- Status EofReceived(int32_t) override;
-
- // @param int32_t workerId.
- // @return Status - The error code return.
- Status EoeReceived(int32_t) override;
-
- // A print method typically used for debugging.
- // @param out The output stream to write output to.
- // @param show_all A bool to control if you want to show all info or just a summary.
- void Print(std::ostream &out, bool show_all) const override;
-
- // Base-class override for NodePass visitor acceptor.
- // @param p - Pointer to the NodePass to be accepted.
- // @param modified - Whether this node visit modified the pipeline.
- // @return - Status of the node visit.
- Status Accept(NodePass *p, bool *modified) override;
-
- // Op name getter
- // @return Name of the current Op
- std::string Name() const override { return "FilterOp"; }
-
- private:
- // A Python callable which returns a boolean value.
- py::function predicate_func_;
-
- // Variable to store the column names that will be fed to the predicate function.
- std::vector<std::string> in_columns_;
-
- // Internal queues for filter.
- QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;
-
- // Private function for worker/thread to loop continuously. It comprises the main
- // logic of FilterOp: getting the data from the previous Op, validating user specified column names,
- // applying the predicate to each row, and filtering out rows for which the predicate returns false.
- // @param worker_id The id assigned to this thread/worker upon creation.
- // @return Status The error code return.
- Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_
-
- // Filter the data by the predicate function.
- // @param in_buffer The input data buffer.
- // @param[out] out The data buffer holding the rows that passed the predicate.
- // @return Status The error code return.
- Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);
-
- // Collect and reorder the filtered DataBuffers.
- // @return Status The error code return.
- Status Collector();
-
- // @param input The input tensor row.
- // @return Status - The error code return.
- Status CheckInput(const TensorRow &input) const;
-
- // Invoke the Python predicate function.
- // @param input The input tensor row.
- // @param[out] out_predicate The result of the predicate.
- // @return Status - The error code return.
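A minimal, self-contained sketch of what such an invocation involves with pybind11; the helper name CallPredicate and the standalone form are illustrative, not part of this class:

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    // Calls a Python predicate with positional arguments and coerces the result
    // to bool, holding the GIL for the duration of the call.
    bool CallPredicate(const py::function &predicate, const py::tuple &args) {
      py::gil_scoped_acquire gil;          // Python objects may only be touched under the GIL
      py::object ret = predicate(*args);   // *args unpacks the tuple into positionals
      return ret.cast<bool>();             // throws if the result is not convertible to bool
    }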
- Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); - - // Private function for validating if each of the user specified input column names - // exist in the DataBuffer. - // @param input_columns The vector of input column names used in the current thread. - // @return Status The error code return. - Status ValidateInColumns(const std::vector *input_columns); - - // Private function for checking the column legality - // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory - // and is not shared with other threads. - // @param[out] to_process_indices Indices of columns that will feed to predicate. - // @param input_columns The vector of input column names used in the current thread. - Status CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns); -}; - -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc deleted file mode 100644 index 05a1ac7925..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ /dev/null @@ -1,373 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/map_op.h" -#include -#include -#include -#include -#include -#include "dataset/core/config_manager.h" - -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/tensor_op.h" -#include "utils/log_adapter.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -MapOp::Builder::Builder() : build_perf_mode_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_workers_ = cfg->num_parallel_workers(); - build_op_connector_size_ = cfg->op_connector_size(); -} - -// Check if the required parameters are set by the builder. -Status MapOp::Builder::sanityCheck() const { - if (build_tensor_funcs_.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Building a MapOp that has not provided any function/operation to apply"); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. 
-Status MapOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(sanityCheck()); - *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), - std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, - build_perf_mode_); - return Status::OK(); -} - -// Constructor of MapOp -MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode) - : ParallelOp(num_workers, op_connector_size), - tfuncs_(std::move(tensor_funcs)), - in_columns_(in_col_names), - out_columns_(out_col_names), - perf_mode_(perf_mode) { - // If caller didn't specify the out_col_names, assume they are same as the in_columns. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - MS_LOG(DEBUG) << "Performance Mode in map operator is " << perf_mode_ << "."; -} - -// The number of threads consuming data from previous op's output Connector. -int32_t MapOp::num_consumers() const { - // When Performance Mode is on, there is only one thread consuming from the previous Connector. - return perf_mode_ == true ? 1 : num_workers_; -} - -// A print method typically used for debugging -void MapOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nInput column names:"; - for (size_t i = 0; i < in_columns_.size(); i++) { - out << " " << in_columns_[i]; - } - out << "\n TensorOps:"; - for (size_t i = 0; i < tfuncs_.size(); i++) { - out << " " << *(tfuncs_[i].get()); - } - out << "\n\n"; - } -} - -// This class functor will provide the master loop that drives the logic for performing the work -Status MapOp::operator()() { - if (perf_mode_) { - // Create and register the local queues. - local_queues_.Init(num_workers_, oc_queue_size_); - Status rc = local_queues_.Register(tree_->AllTasks()); - if (rc.IsError()) { - TaskManager::FindMe()->Post(); - return rc; - } - } - - // The operator class just starts off threads by calling the tree_ function - Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); - // Synchronize with TaskManager - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(rc); - - if (perf_mode_) { - int64_t que_id = 0; - std::unique_ptr buff; - bool is_eof = false; - // Draining output connector of the previous op and distribute it to local queues. - // Stop when all worker threads are finished (received EOF). - while (!is_eof) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); - is_eof = buff->eof(); - RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); - que_id = (que_id + 1) % num_workers_; - } - } - - return Status::OK(); -} - -// Private function for worker/thread to loop continuously. 
It comprises the main -// logic of MapOp: getting the data from previous Op, validating user specified column names, -// applying a list of TensorOps to each of the data, process the results and then -// pushing them back to MapOp's output Connector to be fetched by the next Op. -Status MapOp::WorkerEntry(int32_t worker_id) { - // Handshake with TaskManager that thread creation is successful. - TaskManager::FindMe()->Post(); - std::unique_ptr in_buffer; - - // Getting a databuffer to work on. - // Perform the first fetch here outside of the loop. This allows us to execute one-time only - // initializations that happen after the first fetch. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - - // Sanity check the databuffer. - // Special case: if there's more threads than buffers, some threads simply get the final control - // messages (eoe/eof), and so they will not perform the check. - if (!in_buffer->eoe() && !in_buffer->eof()) { - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - if (num_rows == 0 || num_cols == 0) { - RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); - } - } - - // Now that init work is done, drop into the main fetching loop. - // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself - // rather than use the base-class defaults. - while (true) { - // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work - // with Performance Mode design. - if (in_buffer->eoe()) { - // Calling base class EoeReceived to forward eoe buffer. - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - continue; - } else if (in_buffer->eof()) { - // Calling base class EofReceived to forward eof buffer. - RETURN_IF_NOT_OK(EofReceived(worker_id)); - break; - } - - std::unique_ptr new_tensor_table(std::make_unique()); - // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. - RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get())); - - // Replace the TensorTable in DataBuffer with the new one. - in_buffer->set_tensor_table(std::move(new_tensor_table)); - - // Push the buffer onto the connector for next operator to consume. - RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); - - // Fetch the next buffer and loop back to the top. - RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); - } - - return Status::OK(); -} - -Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table) { - // Getting number of rows and cols in this buffer. - int32_t num_rows = in_buffer->NumRows(); - int32_t num_cols = in_buffer->NumCols(); - - for (int32_t r = 0; r < num_rows; r++) { - // to_process : A vector of Tensors only holding cols in input_columns. - // result_row; : A vector of Tensors to hold the result after Compute(). - // cur_row : A vector of Tensors holding all the columns from DataBuffer. - TensorRow to_process, result_row, cur_row; - RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); - - // Populate the Tensor from the current row to be processed by TensorOp - for (const auto &idx : to_process_indices_) { - to_process.push_back(std::move(cur_row[idx])); - } - - // Looping over multiple TensorOps supplied in to MapOp. - // The assumption is that the result of one TensorOp matches the required input to the next TensorOp. - for (size_t i = 0; i < tfuncs_.size(); i++) { - // TensorOp can operate on single col or multiple cols. 
MapOp always call compute for multiple cols. - // TensorOp base class will call the single column Compute() depending on the ops. - // Note: The columns of the result_row is not preallocated, the compute function of each tensor op are - // required to resize/push back the result_row - RETURN_IF_NOT_OK(tfuncs_[i]->Compute(to_process, &result_row)); - - // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list. - if (i + 1 < tfuncs_.size()) { - to_process = std::move(result_row); - } - } - - if (out_columns_.size() != result_row.size()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Result of a tensorOp doesn't match output column names"); - } - - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < result_row.size(); i++) { - cur_row[to_process_indices_[i]] = std::move(result_row[i]); - } - new_tensor_table->push_back(std::move(cur_row)); - } else { - // Add the columns we did not touch to the result_row. - for (int32_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - result_row.push_back(std::move(cur_row[i])); - } - } - - // Add this final result_row to our new TensorTable. - new_tensor_table->push_back(std::move(result_row)); - } - } - - return Status::OK(); -} - -Status MapOp::ComputeColMap() { - // If the map has not been set up yet in the base class, then set it up - if (column_name_id_map_.empty()) { - std::unordered_map current_name_id_map = child_[0]->column_name_id_map(); - // Initialize private variables - RETURN_IF_NOT_OK(InitPrivateVariable(¤t_name_id_map)); - // Create the final column name to index mapping in the base class field - CreateFinalColMap(¤t_name_id_map); - MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// Validating if each of the input_columns exists in the DataBuffer. -Status MapOp::ValidateInColumns(const std::unordered_map &col_name_id_map) { - for (const auto &inCol : in_columns_) { - bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; - if (!found) { - std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - } - return Status::OK(); -} - -Status MapOp::InitPrivateVariable(std::unordered_map *col_name_id_map) { - // If input_columns is empty(), The col at index-0 will be picked. - if (in_columns_.empty()) { - for (const auto &pair : *col_name_id_map) { - if (pair.second == 0) { - MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table."; - in_columns_.push_back(pair.first); - break; - } - } - - // If caller didn't specify the out_col_names, assume they are same as the input_columns. - // This was done in the constructor, but if input columns was empty to start we have to redo it here. - if (out_columns_.empty() || out_columns_[0].empty()) { - out_columns_ = in_columns_; - } - } - - // Before we continue, issue a sanity check to make sure the input columns from user and the incoming - // columns from child are correct - RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map)); - - // initialize keep_input_columns, true means to keep the column. 
- keep_input_columns_.resize(col_name_id_map->size(), true); - for (const auto &col_name : in_columns_) { - int32_t missed = (*col_name_id_map)[col_name]; - keep_input_columns_[missed] = false; - } - - // initialize to_process_indices. - for (const auto &col_name : in_columns_) { - to_process_indices_.push_back((*col_name_id_map)[col_name]); - } - return Status::OK(); -} - -// Create the final column name to index mapping and get indices of the columns this mapop does not use. -void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { - std::unordered_map final_col_name_id_map; - size_t num_cols = col_name_id_map->size(); - std::vector new_ids(num_cols); - if (in_columns_.size() == out_columns_.size()) { - for (size_t i = 0; i < in_columns_.size(); i++) { - int32_t loc = (*col_name_id_map)[in_columns_[i]]; - (void)col_name_id_map->erase(in_columns_[i]); - (*col_name_id_map)[out_columns_[i]] = loc; - } - - // Set the base class final column id map result - column_name_id_map_ = *col_name_id_map; - } else { - int32_t fill_idx = 0; - // First columns of the tables are occupied by the output columns from tensorOp. - for (const auto &col_name : out_columns_) { - final_col_name_id_map[col_name] = fill_idx++; - } - - // Creating new_ids mapping for the columns we keep. - for (size_t i = 0; i < num_cols; i++) { - if (keep_input_columns_[i]) { - new_ids[i] = fill_idx++; - } - } - - // Iterating through the old mapping to update the final mapping for the columns we kept. - std::string name; - for (const auto &pair : *col_name_id_map) { - name = pair.first; - int32_t old_id = pair.second; - if (keep_input_columns_[old_id]) { - final_col_name_id_map[name] = new_ids[old_id]; - } - } - - // Set the base class final column id map result - column_name_id_map_ = final_col_name_id_map; - } -} - -// Visitor accept method for NodePass -Status MapOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h deleted file mode 100644 index db7ad7e504..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_MAP_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/queue.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class DataBuffer; -class ExecutionTree; - -// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. -// The column order behavior after MapOp is as follows. 
-// [Case 1] If the number of Input Columns == the number of Output Columns, the column ordering after MapOp
-// is the same as the original column order: the Remainder Columns stay in the same positions,
-// and the Output Columns are placed at the same positions as the Input Columns.
-// For example, if the dataset initially has column order |A, B, C, D, E|,
-// and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y},
-// the column order after applying MapOp will be |A, X, Y, D, E|.
-// Note that in this case, |X, Y| are the Output Columns and |A, D, E|, the Remainder Columns, stay in
-// their original positions; column B is replaced by column X and column C is replaced by column Y.
-// [Case 2] If the number of Input Columns != the number of Output Columns, the column ordering after MapOp
-// is Output Columns followed by Remainder Columns.
-// For example, if the dataset initially has column order |A, B, C, D, E|,
-// and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y},
-// the column order after applying MapOp will be |X, Y, D, E|.
-// Note that in this case, |X, Y| are the Output Columns and |D, E| are the Remainder Columns;
-// the Input Columns are gone and replaced by the Output Columns.
-// (A worked code sketch of both cases appears further below.)
-
-// Keywords:
-// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which
-// Tensors are taken and passed to the TensorOp Compute().
-// Output Columns : a vector of column names (string) passed to MapOp specifying what the column names are
-// for the Tensors produced by TensorOp Compute().
-// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns.
-// These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns.
-class MapOp : public ParallelOp {
- public:
- // The nested builder class inside of the MapOp is used to help manage all of
- // the arguments for constructing it. Use the builder by setting each argument
- // with the provided set methods, and then finally call the build method to execute
- // the actual construction.
- class Builder {
-  public:
-   // Builder constructor. Creates the builder object.
-   // @note No default args
-   // @return This is a constructor.
-   Builder();
-
-   // Default destructor
-   ~Builder() = default;
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
-   Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
-     build_in_col_names_ = in_col_names;
-     return *this;
-   }
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
-   Builder &SetOutColNames(const std::vector<std::string> &out_col_names) {
-     build_out_col_names_ = out_col_names;
-     return *this;
-   }
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
-   Builder &SetTensorFuncs(std::vector<std::shared_ptr<TensorOp>> funcs) {
-     build_tensor_funcs_ = std::move(funcs);
-     return *this;
-   }
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
-   Builder &SetNumWorkers(int32_t num_workers) {
-     build_num_workers_ = num_workers;
-     return *this;
-   }
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
-   Builder &SetOpConnectorSize(int32_t connector_size) {
-     build_op_connector_size_ = connector_size;
-     return *this;
-   }
-
-   // Setter method.
-   // @return Builder setter method returns reference to the builder.
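The worked sketch referenced in the class comment above; the helper name MapOrder and the use of plain string vectors are illustrative only, reproducing the two documented cases:

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Given the current column order, the input columns and the output columns,
    // return the column order MapOp produces, following the two cases above.
    std::vector<std::string> MapOrder(std::vector<std::string> schema,
                                      const std::vector<std::string> &in_cols,
                                      const std::vector<std::string> &out_cols) {
      if (in_cols.size() == out_cols.size()) {
        // Case 1: each input column is replaced in place by its output column.
        for (size_t i = 0; i < in_cols.size(); ++i) {
          auto it = std::find(schema.begin(), schema.end(), in_cols[i]);
          *it = out_cols[i];
        }
        return schema;
      }
      // Case 2: output columns first, then the remainder columns in original order.
      std::vector<std::string> result(out_cols);
      for (const auto &col : schema) {
        if (std::find(in_cols.begin(), in_cols.end(), col) == in_cols.end()) {
          result.push_back(col);
        }
      }
      return result;
    }

    int main() {
      std::vector<std::string> schema = {"A", "B", "C", "D", "E"};
      assert((MapOrder(schema, {"B", "C"}, {"X", "Y"}) ==
              std::vector<std::string>{"A", "X", "Y", "D", "E"}));  // Case 1
      assert((MapOrder(schema, {"B", "C", "A"}, {"X", "Y"}) ==
              std::vector<std::string>{"X", "Y", "D", "E"}));       // Case 2
      return 0;
    }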
- Builder &SetPerformanceMode(bool perf_mode) { - build_perf_mode_ = perf_mode; - return *this; - } - - // The builder "build" method creates the final object. - // @param ptr The shared_ptr to the new MapOp object - // @return Status - Status Build(std::shared_ptr *ptr); - - private: - std::vector build_in_col_names_; - std::vector build_out_col_names_; - std::vector> build_tensor_funcs_; - int32_t build_num_workers_; - int32_t build_op_connector_size_; - bool build_perf_mode_; // Default true. - - // Check if the required parameters are set by the builder. - // @return Status The error code return - Status sanityCheck() const; - }; - - // Constructor of MapOp - // @note The builder class should be used to call it. - // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). - // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). - // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. - // @param num_workers The number of worker threads. - // @param op_connector_size The size of each queue in the connector. - MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, - bool perf_mode); - - // Destructor - ~MapOp() = default; - - // A print method typically used for debugging - // @param out The output stream to write output to - // @param show_all A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out reference to the output stream being overloaded - // @param mo reference to the MapOp to display - // @return the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) { - mo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status The error code return - Status operator()() override; - - // Getter - // @return the number of threads consuming data from previous op's output Connector. - int32_t num_consumers() const override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MapOp"; } - - // List of tensor ops getter/setter - // @Return the vector of tensor ops by non-const reference - - auto &TFuncs() { return tfuncs_; } - - const auto &TFuncs() const { return tfuncs_; } - - private: - // Local queues where worker threads can pop from. - // Popping directly from the Connector can block if the previous designated threads haven't pop. - // Setting the size of these queues to 0 is essentially the same as pulling directly from Connector. 
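As a minimal, single-threaded sketch of the round-robin hand-off that operator() performs in performance mode; a plain std::queue stands in for the Connector and the QueueList, so this is illustrative only:

    #include <cstddef>
    #include <queue>
    #include <vector>

    struct Buffer { bool eof = false; };  // stand-in for DataBuffer

    // Mirrors the master loop in MapOp::operator(): drain the source and deal
    // buffers out to the per-worker queues in round-robin order until EOF.
    void Distribute(std::queue<Buffer> *source, std::vector<std::queue<Buffer>> *local_queues) {
      std::size_t que_id = 0;
      bool is_eof = false;
      while (!is_eof && !source->empty()) {
        Buffer buff = source->front();
        source->pop();
        is_eof = buff.eof;                 // the EOF buffer is forwarded, then the loop stops
        (*local_queues)[que_id].push(buff);
        que_id = (que_id + 1) % local_queues->size();
      }
    }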
- QueueList<std::unique_ptr<DataBuffer>> local_queues_;
-
- // Static variables to be read by worker threads; set once, not modified afterwards.
- std::vector<std::shared_ptr<TensorOp>> tfuncs_;
-
- // Variable to store the column names that the tensorOps are consuming
- std::vector<std::string> in_columns_;
-
- // Variable to store the column names that the tensorOps are producing
- std::vector<std::string> out_columns_;
-
- // Boolean mapping, true means to keep the column.
- std::vector<bool> keep_input_columns_;
-
- // Indices of the columns to process.
- std::vector<size_t> to_process_indices_;
-
- // Performance mode is when the main thread creates local queues, pulls databuffers from the previous
- // op's Connector and distributes them to the local queues. Workers pull from the local queues.
- // If this flag is false, each worker pulls directly from the Connector. This uses fewer resources
- // (threads and memory), but when the computation cost is heavy (e.g. DecodeOp) and fluctuating, it can
- // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order.
- bool perf_mode_;
-
- // Private function for worker/thread to loop continuously. It comprises the main
- // logic of MapOp: getting the data from the previous Op, validating user specified column names,
- // applying a list of TensorOps to each of the data, processing the results and then
- // pushing them back to MapOp's output Connector to be fetched by the next Op.
- // @param worker_id The id assigned to this thread/worker upon creation.
- // @return Status The error code return
- Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_
-
- // Private helper function for getting the next buffer.
- // When PerformanceMode is enabled, workers pop from the local queue.
- // Otherwise, workers pop from the first child output Connector.
- // @param p_buffer - the buffer to return
- // @return Status return code
- Status FetchNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
-   if (perf_mode_) {
-     RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(p_buffer));
-   } else {
-     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id));
-   }
-   return Status::OK();
- }
-
- // Private function for worker thread to perform TensorOp's compute function and get the result.
- // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function doesn't manage memory
- // and is not shared with other threads.
- // @param[out] new_tensor_table A new Tensor Table to be populated in this function.
- Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table);
-
- // Private function that creates the final column name to index mapping and
- // gets indices of the columns this mapop does not use.
- // @param col_name_id_map The column name to index mapping obtained from the child operator
- void CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name_id_map);
-
- // Validating if each of the input_columns exists in the DataBuffer.
- // @param col_name_id_map - the column map to check
- // @return - status return code
- Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map);
-
- // Private function for computing the assignment of the column name map.
- // @return - Status
- Status ComputeColMap() override;
-
- // Private function for initializing private variables such as in_columns_ and out_columns_.
- // @return - Status - Status InitPrivateVariable(std::unordered_map *col_name_id_map); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_MAP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc deleted file mode 100644 index 244861a6c8..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/parallel_op.h" - -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/db_connector.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Constructor -ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr sampler) - : DatasetOp(op_connector_size, sampler), - num_workers_(num_workers), - num_producers_(num_workers), - worker_connector_size_(1), - worker_connector_(nullptr) {} - -// Creates the internal worker connector for the parallel op if the derived class wants to use it -Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) { - if (worker_connector_size == 0) { - RETURN_STATUS_UNEXPECTED("Worker connector size 0 is invalid."); - } - num_producers_ = 1; - worker_connector_size_ = worker_connector_size; - // Instantiate the worker connector. This is the internal connector, not the operators - // output connector. It has single master consuming from it (num producers is 1), and the number - // of workers is the defined count from the op. - worker_connector_ = std::make_unique(num_workers_, num_producers_, worker_connector_size); - - return Status::OK(); -} - -// A print method typically used for debugging -void ParallelOp::Print(std::ostream &out, bool show_all) const { - // Summary 1-liner print - if (!show_all) { - out << " [workers: " << num_workers_ << "]"; - // Call super class printer - DatasetOp::Print(out, show_all); - } else { - // Detailed print - DatasetOp::Print(out, show_all); - out << "\nNum workers: " << num_workers_; - } -} - -// Override base class reset to provide reset actions specific to the ParallelOp class. -Status ParallelOp::Reset() { - RETURN_IF_NOT_OK(DatasetOp::Reset()); // Perform any super class reset work - - // ParallelOp is abstract, but we do own the connector between workers and master - // (if the parallel op is configured for this). Reset that connector here. 
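  // (Only the worker connector needs attention here: the DatasetOp::Reset() call
  //  above already handles the state shared by every operator.)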
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h deleted file mode 100644 index f59d4bfc53..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/parallel_op.h +++ /dev/null @@ -1,126 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
-#define DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
-
-#include
-#include
-#include "dataset/core/constants.h"
-#include "dataset/engine/datasetops/dataset_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-// global const in our namespace
-constexpr int32_t kEndOfActions = -1;
-
-// Forward declares
-class DataBuffer;
-
-class DbConnector;
-
-// A ParallelOp provides a multi-threaded DatasetOp
-class ParallelOp : public DatasetOp {
- public:
-  // Constructor
-  // @param num_workers
-  // @param op_connector_size - size of the output connector for this operator
-  // @param sampler - The sampler for the op
-  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
-
-  // Destructor
-  ~ParallelOp() = default;
-
-  // Creates the internal worker connector for the parallel op if the derived class wants to use it.
-  // @notes This changes the number of producers of this op to 1, since it establishes a master/worker
-  // relationship within the op, making all production flow through a single master.
-  // @return Status - The error return code
-  Status CreateWorkerConnector(int32_t worker_connector_size);
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param pO - reference to the ParallelOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) {
-    po.Print(out, false);
-    return out;
-  }
-
-  // During tree prepare phase, operators may have specific pre-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  // @return Status - The error return code
-  Status PrepareNodePreAction() override {
-    // Run common code from super class before adding ParallelOp specific logic
-    return (DatasetOp::PrepareNodePreAction());
-  }
-
-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  // @return Status - The error return code
-  Status PrepareNodePostAction() override {
-    // Run common code from super class before adding ParallelOp specific logic
-    return (DatasetOp::PrepareNodePostAction());
-  }
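The "call the superclass hook first" convention used by these two prepare-phase hooks is a small but recurring pattern in this class hierarchy. A minimal sketch with illustrative names (not the real DatasetOp):

// Sketch of the convention: a derived override runs the shared base-class
// logic first, then layers its own preparation on top.
#include <cstdio>

struct Status { bool ok = true; };

struct DatasetOpSketch {
  virtual ~DatasetOpSketch() = default;
  virtual Status PrepareNodePreAction() {
    std::printf("common pre-action work\n");
    return Status{};
  }
};

struct ParallelOpSketch : DatasetOpSketch {
  Status PrepareNodePreAction() override {
    Status rc = DatasetOpSketch::PrepareNodePreAction();  // superclass first
    if (!rc.ok) return rc;
    std::printf("parallel-op pre-action work\n");  // then specific logic
    return rc;
  }
};

int main() {
  ParallelOpSketch op;
  op.PrepareNodePreAction();
  return 0;
}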
-  // Override base class reset to provide reset actions specific to the ParallelOp class.
-  // @return Status - The error code return
-  Status Reset() override;
-
-  // Getter
-  // @return the number of workers
-  int32_t num_workers() const override { return num_workers_; }
-
-  // Getter
-  // @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return num_workers_; }
-
-  // Getter
-  // @notes The number of producers is commonly the same as the number of workers, except in the case
-  // when a worker connector is set up. In that case, there are n workers, and a single master,
-  // such that only 1 thread is a producer rather than the n workers.
-  // @return the number of producers pushing to the output Connector
-  int32_t num_producers() const override { return num_producers_; }
-
-  // Register the internal worker connectors.
-  // @return Status
-  Status RegisterWorkerConnectors() override;
-
- protected:
-  // Interface for derived classes to implement. All derived classes must provide the entry
-  // function with the main execution loop for worker threads.
-  // @return Status - The error code return
-  virtual Status WorkerEntry(int32_t workerId) = 0;
-
-  int32_t num_workers_;     // The number of worker threads
-  int32_t num_producers_;   // The number of threads pushing to the out_connector_
-  int32_t worker_connector_size_;
-  std::unique_ptr<DbConnector> worker_connector_;  // The internal connector for worker threads
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc deleted file mode 100644 index 1d017a4d3e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.cc +++ /dev/null @@ -1,50 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/pipeline_op.h"
-#include
-#include
-
-namespace mindspore {
-namespace dataset {
-// Constructor
-PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
-    : DatasetOp(op_connector_size, sampler) {}
-
-// A print method typically used for debugging
-void PipelineOp::Print(std::ostream &out, bool show_all) const {
-  // Summary 1-liner print
-  if (!show_all) {
-    out << " [workers: ";
-    if (this->inlined()) {
-      out << "0 (inlined)]";
-    } else {
-      out << "1]";  // Pipeline ops only have 1 worker
-    }
-    // Call super class printer
-    DatasetOp::Print(out, show_all);
-  } else {
-    // Detailed print
-    DatasetOp::Print(out, show_all);
-    out << "\nNum workers: ";
-    if (this->inlined()) {
-      out << "0 (inlined)";
-    } else {
-      out << "1";  // Pipeline ops only have 1 worker
-    }
-  }
-}
-}  // namespace dataset
-}  // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h deleted file mode 100644 index cb3c76813b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/pipeline_op.h +++ /dev/null @@ -1,98 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_
-#define DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_
-
-#include
-#include
-#include "dataset/engine/datasetops/dataset_op.h"
-
-namespace mindspore {
-namespace dataset {
-// forward declare
-class ExecutionTree;
-
-class DataBuffer;
-
-class PipelineOp : public DatasetOp {
- public:
-  // Constructor
-  // @param op_connector_size - size of the output connector
-  // @param sampler - The sampler for the op
-  explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
-
-  // Destructor
-  ~PipelineOp() = default;
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param po - reference to the PipelineOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const PipelineOp &po) {
-    po.Print(out, false);
-    return out;
-  }
-
-  // Getter
-  // @return The number of workers inside this op. Pipeline ops only have a single worker.
-  int32_t num_workers() const override { return 1; }
-
-  // Getter
-  // @return the number of threads consuming from the previous Connector
-  int32_t num_consumers() const override { return 1; }
-
-  // Getter
-  // @return The number of threads that push data to the output connector
-  int32_t num_producers() const override { return 1; }
-
-  // During tree prepare phase, operators may have specific pre-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePreAction() override {
-    // Run common code from super class before adding PipelineOp specific logic
-    return (DatasetOp::PrepareNodePreAction());
-  }
-
-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePostAction() override {
-    // Run common code from super class before adding PipelineOp specific logic
-    return (DatasetOp::PrepareNodePostAction());
-  }
-
- protected:
-  // *******************************************************************************
-  // I'm predicting there will be common arguments or functionality for pipeline ops,
-  // just not sure yet what those are. Perhaps this intermediate class between
-  // DatasetOp and the actual ops is not needed at all?
-  // For example, if there's no common code for all of the non-parallel ops, then
-  // they can just inherit from DatasetOp directly and we can put this class into the
-  // trash.
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc deleted file mode 100644 index 5ce4056024..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.cc +++ /dev/null @@ -1,159 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "dataset/engine/datasetops/project_op.h"
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/opt/pass.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace dataset {
-ProjectOp::Builder::Builder(const std::vector<std::string> &columns_to_project)
-    : builder_columns_to_project_(columns_to_project) {}
-
-Status ProjectOp::Builder::SanityCheck() const {
-  if (builder_columns_to_project_.empty()) {
-    std::string err_msg("Columns to project is empty.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  return Status::OK();
-}
-
-Status ProjectOp::Builder::Build(std::shared_ptr<ProjectOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<ProjectOp>(builder_columns_to_project_);
-  return Status::OK();
-}
-
-ProjectOp::ProjectOp(const std::vector<std::string> &columns_to_project)
-    : PipelineOp(0), columns_to_project_(columns_to_project) {}
-
-void ProjectOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as first line regardless if this summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") :";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nColumns that are projected:";
-    for (size_t i = 0; i < columns_to_project_.size(); i++) {
-      out << "\n" << columns_to_project_[i];
-    }
-    out << "\n\n";
-  }
-}
-
-// Gets a buffer from the child operator and projects the buffer.
-Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
-  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe));
-  if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
-    RETURN_IF_NOT_OK(Project(p_buffer));
-  }
-  return Status::OK();
-}
-
-Status ProjectOp::Project(std::unique_ptr<DataBuffer> *data_buffer) {
-  std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
-  while ((*data_buffer)->NumRows() > 0) {
-    TensorRow current_row;
-    RETURN_IF_NOT_OK((*data_buffer)->PopRow(&current_row));
-    TensorRow new_row;
-    (void)std::transform(projected_column_indices_.begin(), projected_column_indices_.end(),
-                         std::back_inserter(new_row), [&current_row](uint32_t x) { return current_row[x]; });
-    new_tensor_table->push_back(new_row);
-  }
-  (*data_buffer)->set_tensor_table(std::move(new_tensor_table));
-  return Status::OK();
-}
-
-// Class functor operator () override.
-// Most dataset ops operate by launching a thread (see ExecutionTree).
-// However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the
-// functor since this op runs inlined inside another operator. The function is overloaded to
-// ensure that it is not called by mistake (it will generate an error).
-Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
-
-int32_t ProjectOp::num_consumers() const {
-  if (parent_.empty()) {
-    MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
-    return 1;
-  } else if (parent_[0] == nullptr) {
-    MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
-    return 0;
-  } else {
-    return parent_[0]->num_consumers();
-  }
-}
-
-int32_t ProjectOp::num_producers() const {
-  if (child_.empty() || child_[0] == nullptr) {
-    MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
-    return 0;
-  } else {
-    return child_[0]->num_producers();
-  }
-}
-
-Status ProjectOp::EoeReceived(int32_t worker_id) {
-  state_ = OpState::kDeOpIdle;
-  return Status::OK();
-}
-
-Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
-
-// Visitor accept method for NodePass
-Status ProjectOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<ProjectOp>(), modified);
-}
-
-// Compute the column map and save it into our own column name map
-// We cannot use the super class ComputeColMap here because we're making a modification of the
-// map from the child map.
-Status ProjectOp::ComputeColMap() {
-  if (column_name_id_map_.empty()) {
-    std::unordered_map<std::string, int32_t> child_column_name_mapping = child_[0]->column_name_id_map();
-    for (size_t i = 0; i < columns_to_project_.size(); i++) {
-      std::string &current_column = columns_to_project_[i];
-      if (child_column_name_mapping.find(current_column) == child_column_name_mapping.end()) {
-        std::string err_msg = "ProjectOp: column " + current_column + " does not exist in child operator.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-      // Setup the new column name mapping for ourself (base class field)
-      column_name_id_map_[current_column] = i;
-      projected_column_indices_.push_back(child_column_name_mapping[current_column]);
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
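The projection in ProjectOp::Project above is just an index gather per row. A condensed standalone sketch of the same std::transform step, with a plain string vector standing in for TensorRow:

// Condensed sketch of the projection: each output row is built by gathering
// the projected column indices from the input row.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <string>
#include <vector>

using Row = std::vector<std::string>;  // stand-in for TensorRow

int main() {
  // Suppose the child produces columns {image, label, filename} and the
  // projection keeps {label, image}, in that order.
  std::vector<uint32_t> projected_indices = {1, 0};
  Row current_row = {"img_bytes", "7", "cat.jpg"};

  Row new_row;
  std::transform(projected_indices.begin(), projected_indices.end(), std::back_inserter(new_row),
                 [&current_row](uint32_t x) { return current_row[x]; });

  for (const auto &col : new_row) std::printf("%s ", col.c_str());
  std::printf("\n");  // prints: 7 img_bytes
  return 0;
}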
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/dataset/engine/datasetops/project_op.h deleted file mode 100644 index 628c1342ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/project_op.h +++ /dev/null @@ -1,127 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_
-#define DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_
-
-#include
-#include
-#include
-
-#include "dataset/engine/datasetops/pipeline_op.h"
-
-namespace mindspore {
-namespace dataset {
-class ProjectOp : public PipelineOp {
- public:
-  // The nested builder class inside of the ProjectOp is used to help manage all of the arguments
-  // for constructing it. This op is very simple though, so this builder is really just
-  // provided for a consistent look and feel for creators of Dataset operators overall.
-  class Builder {
-   public:
-    // Builder constructor. Creates the builder object.
-    // @param columns_to_project -
-    // @return This is a constructor.
-    explicit Builder(const std::vector<std::string> &columns_to_project);
-
-    // Builder destructor.
-    ~Builder() = default;
-
-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new ProjectOp object.
-    Status Build(std::shared_ptr<ProjectOp> *);
-
-   private:
-    std::vector<std::string> builder_columns_to_project_;
-    Status SanityCheck() const;
-  };
-
-  // Constructor of the ProjectOp.
-  // @param columns_to_project -
-  explicit ProjectOp(const std::vector<std::string> &columns_to_project);
-
-  // Destructor.
-  ~ProjectOp() = default;
-
-  // A print method typically used for debugging.
-  // @param out - The output stream to write output to.
-  // @param show_all - A bool to control if you want to show all info or just a summary.
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload.
-  // @notes This allows you to write the debug print info using stream operators.
-  // @param out - reference to the output stream being overloaded.
-  // @param project_op - reference to the ProjectOp to display.
-  // @return - the output stream must be returned.
-  friend std::ostream &operator<<(std::ostream &out, const ProjectOp &project_op) {
-    project_op.Print(out, false);
-    return out;
-  }
-
-  // Class functor operator () override.
-  // Most dataset ops operate by launching a thread (see ExecutionTree).
-  // However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the
-  // functor since this op runs inlined inside another operator. The function is overloaded to
-  // ensure that it is not called by mistake (it will generate an error).
-  // @return Status - The error code returned.
-  Status operator()() override;
-
-  // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node.
-  // @param p_buffer - output pointer to the projected buffer.
-  // @param worker_id - The worker id
-  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
-
-  // Base-class override. Return the number of consumers in the first parent.
-  int32_t num_consumers() const override;
-
-  // Base-class override. Return the number of producers in the first child.
-  int32_t num_producers() const override;
-
-  // Base-class override for special eoe handler.
-  // Inline operators must override this because there is no connector to push eoe onto.
-  // @return Status - The error code returned.
-  Status EoeReceived(int32_t worker_id) override;
-
-  // Base-class override for special eof handler.
-  // Inline operators must override this because there is no connector to push eof onto.
-  // @return Status - The error code returned.
-  Status EofReceived(int32_t worker_id) override;
-
-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
-  Status Accept(NodePass *p, bool *modified) override;
-
-  // Op name getter
-  // @return Name of the current Op
-  std::string Name() const override { return "ProjectOp"; }
-
- private:
-  std::vector<std::string> columns_to_project_;
-  std::vector<int32_t> projected_column_indices_;
-
-  Status Project(std::unique_ptr<DataBuffer> *data_buffer);
-
-  // Computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc deleted file mode 100644 index 23cd29d295..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.cc +++ /dev/null @@ -1,182 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/rename_op.h"
-#include
-#include
-#include
-#include
-
-#include "dataset/core/config_manager.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/opt/pass.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace dataset {
-// Builder constructor
-RenameOp::Builder::Builder() {
-  // Some arguments to the RenameOp constructor have a default argument that is taken
-  // from the client config.
-  // The user may choose to change these values for the construction of the RenameOp by
-  // using the various builder set methods.
-
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  builder_op_connector_size_ = cfg->op_connector_size();
-}
-
-Status RenameOp::Builder::SanityCheck() const { return Status::OK(); }
-
-// build method for RenameOp
-Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_, builder_op_connector_size_);
-  return Status::OK();
-}
-
-// constructor
-RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-                   int32_t op_connector_size)
-    : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}
-
-// destructor
-RenameOp::~RenameOp() {}
-
-// main entry point for rename
-Status RenameOp::operator()() {
-  TaskManager::FindMe()->Post();
-  std::unique_ptr<DataBuffer> curr_buffer;
-  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
-  if (curr_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) {
-    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
-    std::string err_msg = "Rename: the first buffer received was a control signal";
-    // if 1st eoe or eof, pass it on then return
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  while (curr_buffer->eof() == false) {
-    while (curr_buffer->eoe() == false) {
-      // push the renamed input buffer
-      MS_LOG(DEBUG) << "Rename operator pushing next buffer.";
-      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
-      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
-    }  // end of while eoe loop
-
-    // we got eoe, now try again until we get eof
-    MS_LOG(DEBUG) << "Rename operator EOE Received.";
-    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
-    MS_LOG(DEBUG) << "Rename operator fetching buffer after EOE.";
-    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
-  }  // end of while eof loop
-
-  MS_LOG(DEBUG) << "Rename operator EOF Received.";
-  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
-  return Status::OK();
-}
-
-// Rename core functionality to compute the new column name id map.
-// We need to overwrite the super class ComputeColMap here because we're making a modification of the
-// map from the child map.
-Status RenameOp::ComputeColMap() {
-  if (column_name_id_map_.empty()) {
-    column_name_id_map_ = child_[0]->column_name_id_map();
-    // iterate over my index in input vector, find the corresponding position
-    std::unordered_map<std::string, int32_t> new_col_name_id_map = {};
-    // parameter for input check
-    size_t found = 0;
-
-    // iterate over all the pairs and if there is a name match with rename, rename the column and add it to new map
-    // by doing it this way we recreate a new ColNameIdMap and allow for switching
-    for (const auto &pair : column_name_id_map_) {
-      std::string name = pair.first;
-      int32_t id = pair.second;
-      // find name
-      std::vector<std::string>::iterator it;
-      it = std::find(in_columns_.begin(), in_columns_.end(), name);
-      // for the input check we have to count the number of times a name is found in in_columns_,
-      // because we iterate over the whole column map rather than over in_columns_ itself
-      if (it != in_columns_.end()) {
-        // found
-        found += 1;
-        int index = std::distance(in_columns_.begin(), it);
-        MS_LOG(DEBUG) << "Rename operator index found " << index << " value " << id << ".";
-
-        new_col_name_id_map[out_columns_[index]] = id;
-      } else {
-        // not found
-        MS_LOG(DEBUG) << "Rename operator index not found: " << id << " is the column id.";
-        new_col_name_id_map[name] = id;
-      }
-    }
-    // this only checks that every renamed column was found; the input check doesn't validate everything
-    if (found != in_columns_.size()) {
-      MS_LOG(DEBUG) << "Rename operator column names found: " << found << " out of " << in_columns_.size() << ".";
-      std::string err_msg = "Renamed column doesn't exist in dataset";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-
-    // Now, overwrite our column map with the new renamed columns/ids
-    column_name_id_map_ = new_col_name_id_map;
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-
-// prints rename
-void RenameOp::Print(std::ostream &out,      // In: The output stream to print to
-                     bool show_all) const {  // In: T/F if it should print everything
-  // Always show the id and name as first line regardless if this summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") :";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nIn columns:";
-    for (size_t i = 0; i < in_columns_.size(); ++i) {
-      out << "\n  " << in_columns_[i];
-    }
-    for (size_t i = 0; i < out_columns_.size(); ++i) {
-      out << "\n  " << out_columns_[i];
-    }
-    out << "\n\n";
-  }
-}
-
-Status RenameOp::EofReceived(int32_t) {
-  MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now.";
-  return Status::OK();
-}
-
-Status RenameOp::EoeReceived(int32_t) {
-  state_ = OpState::kDeOpIdle;
-  return Status::OK();
-}
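The remap in RenameOp::ComputeColMap above is a pure map transformation; a self-contained sketch with plain STL types (illustrative column names, not from the source):

// Condensed sketch of the rename remap: columns named in in_columns are
// re-keyed to the matching out_columns name; all other columns pass through.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, int32_t> child_map = {{"image", 0}, {"label", 1}};
  std::vector<std::string> in_columns = {"label"};
  std::vector<std::string> out_columns = {"ground_truth"};

  std::unordered_map<std::string, int32_t> new_map;
  size_t found = 0;
  for (const auto &pair : child_map) {
    auto it = std::find(in_columns.begin(), in_columns.end(), pair.first);
    if (it != in_columns.end()) {
      ++found;
      auto index = std::distance(in_columns.begin(), it);
      new_map[out_columns[index]] = pair.second;  // renamed, id preserved
    } else {
      new_map[pair.first] = pair.second;  // untouched column
    }
  }
  if (found != in_columns.size()) {
    std::printf("error: renamed column doesn't exist in dataset\n");
    return 1;
  }
  for (const auto &p : new_map) std::printf("%s -> %d\n", p.first.c_str(), p.second);
  return 0;
}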
-// Visitor accept method for NodePass
-Status RenameOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<RenameOp>(), modified);
-}
-}  // namespace dataset
-}  // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h deleted file mode 100644 index e209c075d6..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/rename_op.h +++ /dev/null @@ -1,138 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_RENAME_OP_H_
-#define DATASET_ENGINE_DATASETOPS_RENAME_OP_H_
-
-#include
-#include
-#include
-#include
-#include "dataset/core/tensor.h"
-#include "dataset/engine/datasetops/pipeline_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-// forward declare
-class DataBuffer;
-
-class RenameOp : public PipelineOp {
- public:
-  // The nested builder class inside of the RenameOp is used to help manage all of
-  // the arguments for constructing it. Use the builder by setting each argument
-  // with the provided set methods, and then finally call the build method to execute
-  // the actual construction.
-  class Builder {
-   public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @return This is a constructor.
-    Builder();
-
-    // Default destructor
-    ~Builder() = default;
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
-      builder_in_columns_ = in_col_names;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOutColNames(const std::vector<std::string> &out_col_names) {
-      builder_out_columns_ = out_col_names;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOpConnectorSize(int32_t op_connector_size) {
-      builder_op_connector_size_ = op_connector_size;
-      return *this;
-    }
-
-    // The builder "build" method creates the final RenameOp object.
- // @return shared_ptr to the new RenameOp object - Status Build(std::shared_ptr *); - - private: - std::vector builder_in_columns_; - std::vector builder_out_columns_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor for RenameOp - // @param in_col_names names of columns to rename - // @param out_col_names names of columns after rename - // @param op_connector_size connector size - RenameOp(const std::vector &in_col_names, // In: Col names to consume - const std::vector &out_col_names, // In: Col names to produce - int32_t op_connector_size); - - // Destructor - ~RenameOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Rename - // @param out output stream to print to - // @param show_all if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RenameOp &ro) { - ro.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "RenameOp"; } - - protected: - // Rename core functionality - // Computing the assignment of the new column name map. - // @return - Status - Status ComputeColMap() override; - - // Variable to store the input column names - std::vector in_columns_; - - // Variable to store the output column names - std::vector out_columns_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_RENAME_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc deleted file mode 100644 index a0de649284..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. 
-RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {} - -Status RepeatOp::Builder::SanityCheck() const { - if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) { - std::string err_msg("Repeat count must be > 0 or -1."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status RepeatOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_repeats_); - return Status::OK(); -} - -// Constructor of the RepeatOp. -RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {} - -// Destructor -RepeatOp::~RepeatOp() {} - -// A print method typically used for debugging -void RepeatOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [repeats: " << max_repeats_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_ - << "\nLeaf Nodes in execution path:"; - if (!eoe_ops_.empty()) { - for (size_t i = 0; i < eoe_ops_.size(); i++) { - out << "\n Operator: " << eoe_ops_[i]->id(); - } - } else { - out << " None."; - } - out << "\n\n"; - } -} - -// This function returns the buffer that is at the top of our output connector. The caller is -// typically our parent node, when the parent is asking us to provide the next buffer of data. -// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get -// a buffer from our child. -// This function sets the `retryIfEoe` flag when popping from the child connector. This way, -// this function will retry to pop the connector again and will get the non-EOE buffer if any. -Status RepeatOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t worker_id, bool retry_if_eoe) { - if (child_.empty()) { - RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node."); - } - - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - // Loop until non EOE is received - while (buf->eoe()) { - RETURN_IF_NOT_OK(EoeReceived(worker_id)); - if (state_ == OpState::kDeOpIdle) { - *p_buffer = std::move(buf); - return Status::OK(); - } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); - } - // Check if the last buf is next eof - if (buf->eof()) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - *p_buffer = std::move(buf); - return Status::OK(); -} - -// Base-class override for handling cases when an eoe is received. -Status RepeatOp::EoeReceived(int32_t worker_id) { - repeat_count_++; - MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ - << ") end of epoch message received. Repeat count is now: " << repeat_count_ << "."; - bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated); - bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat); - // If we've reached the requested repeat count, then flag the eoe nodes - // to tell them they've got one more epoch to perform. When they reach the end - // of the last epoch, they quit rather than loop again. 
This happens in two cases:
-  // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
-  // 2- We are not repeated
-  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
-    for (auto &eoe_op : eoe_ops_) {
-      eoe_op->set_control_flag(kDeOpLastRepeat);
-    }
-  }
-  if (repeat_count_ == max_repeats_) {
-    repeat_count_ = 0;
-    state_ = OpState::kDeOpIdle;
-    return Status::OK();
-  }
-
-  // Invoke a reset against the eoe nodes only.
-  for (auto &eoe_op : eoe_ops_) {
-    RETURN_IF_NOT_OK(eoe_op->Reset());
-  }
-
-  return Status::OK();
-}
-
-// Class functor operator () override.
-// Most dataset ops operate by launching a thread (see ExecutionTree).
-// However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the
-// functor since this op runs inlined inside another operator. The function is overloaded to
-// ensure that it is not called by mistake (it will generate an error).
-Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); }
-
-// Base-class override for handling cases when an eof is received.
-Status RepeatOp::EofReceived(int32_t worker_id) {
-  MS_LOG(DEBUG) << "Repeat operator EOF received, do nothing now.";
-  return Status::OK();
-}
-
-int32_t RepeatOp::num_consumers() const {
-  if (parent_.empty()) {
-    MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's the root and returning 1.";
-    return 1;
-  } else if (parent_[0] == nullptr) {
-    MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
-    return 0;
-  } else {
-    return parent_[0]->num_consumers();
-  }
-}
-
-// Drive reset actions if needed
-Status RepeatOp::Reset() {
-  // If there are nested repeats, an ancestor repeat may have us listed as an eoe op.
-  // In that case, we now have to bounce the reset down to our own eoe ops.
-  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
-  for (auto &eoe_op : eoe_ops_) {
-    RETURN_IF_NOT_OK(eoe_op->Reset());
-  }
-  state_ = OpState::kDeOpRunning;
-  return Status::OK();
-}
-
-int32_t RepeatOp::num_producers() const {
-  if (child_.empty() || child_[0] == nullptr) {
-    MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
-    return 0;
-  } else {
-    return child_[0]->num_producers();
-  }
-}
-
-// Pre-Visitor accept method for NodePass
-Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call the pre-visitation
-  return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
-}
-
-// Visitor accept method for NodePass
-Status RepeatOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
-}
-}  // namespace dataset
-}  // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h deleted file mode 100644 index 7993737aeb..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h +++ /dev/null @@ -1,146 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
-#define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
-
-#include
-#include
-#include
-#include
-#include "dataset/engine/datasetops/pipeline_op.h"
-
-namespace mindspore {
-namespace dataset {
-class RepeatOp : public PipelineOp {
- public:
-  static constexpr int32_t kInfiniteRepeat = -1;
-
-  // The nested builder class inside of the RepeatOp is used to help manage all of the arguments
-  // for constructing it. This repeat op is very simple though, so this builder is really just
-  // provided for a consistent look and feel for creators of Dataset operators overall.
-  class Builder {
-   public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @param count - The number of repeats to do
-    // @return This is a constructor.
-    explicit Builder(int32_t count);
-
-    // Default destructor
-    ~Builder() = default;
-
-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new RepeatOp object
-    Status Build(std::shared_ptr<RepeatOp> *);
-
-   private:
-    int32_t build_max_repeats_;
-
-    Status SanityCheck() const;
-  };
-
-  // Constructor of the RepeatOp.
-  // @note The builder class should be used to call it
-  // @param count - The number of repeats to do
-  explicit RepeatOp(int32_t count);
-
-  // Destructor
-  ~RepeatOp();
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param ro - reference to the RepeatOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) {
-    ro.Print(out, false);
-    return out;
-  }
-
-  // Class functor operator () override.
-  // Most dataset ops operate by launching a thread (see ExecutionTree).
-  // However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the
-  // functor since this op runs inlined inside another operator. The function is overloaded to
-  // ensure that it is not called by mistake (it will generate an error).
-  // @return Status - The error code return
-  Status operator()() override;
-
-  // This function returns the buffer that is at the top of our output connector. The caller is
-  // typically our parent node, when the parent is asking us to provide the next buffer of data.
-  // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
-  // a buffer from our child.
-  // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
-  // this function will retry to pop the connector again and will get the non-EOE buffer if any.
-  // @param p_buffer - output pointer to the buffer that it will fetch.
-  // @param worker_id - The worker id
-  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
-  // @return Status - The error code return
-  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
-
-  // Base-class override for handling cases when an eoe is received.
-  // @param worker_id - The worker id
-  Status EoeReceived(int32_t worker_id) override;
-
-  // Base-class override for handling cases when an eof is received.
-  // @param worker_id - The worker id
-  Status EofReceived(int32_t worker_id) override;
-
-  /// \brief reset Op
-  /// \return Status - The error code return
-  Status Reset() override;
-
-  // Base-class override. Return the number of consumers in the first parent.
-  int32_t num_consumers() const override;
-
-  // Base-class override. Return the number of producers in the first child.
-  int32_t num_producers() const override;
-
-  /// \brief Base-class override for NodePass pre-visit acceptor
-  /// \param[in] p The node to visit
-  /// \param[out] modified Indicator if the node was modified
-  /// \return Status of the node visit
-  Status PreAccept(NodePass *p, bool *modified) override;
-
-  /// \brief Base-class override for NodePass visitor acceptor
-  /// \param[in] p The node to visit
-  /// \param[out] modified Indicator if the node was modified
-  /// \return Status of the node visit
-  Status Accept(NodePass *p, bool *modified) override;
-
-  // Op name getter
-  // @return Name of the current Op
-  std::string Name() const override { return "RepeatOp"; }
-
-  /// \brief Adds an operator to the repeat op's list of tracked leaf/eoe nodes
-  /// \param[in] eoe_op The input leaf/eoe operator to add to the list
-  void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
-
- private:
-  int32_t max_repeats_;   // The number of repeats that the user requested
-  int32_t repeat_count_;  // A counter for the current number of executed repeats
-  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;  // List of operators that can generate EOE underneath this repeat.
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc deleted file mode 100644 index f86fcc602b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ /dev/null @@ -1,304 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#if defined(_WIN32) || defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -constexpr int32_t ShuffleOp::kShuffleStateInit; -constexpr int32_t ShuffleOp::kShuffleStateActive; -constexpr int32_t ShuffleOp::kShuffleStateDrain; - -// Builder constructor. Creates the builder object. -ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) { - std::shared_ptr cfg = GlobalContext::config_manager(); - build_op_connector_size_ = cfg->op_connector_size(); - build_rows_per_buffer_ = cfg->rows_per_buffer(); - build_shuffle_seed_ = GetSeed(); -} - -Status ShuffleOp::Builder::SanityCheck() const { - if (build_shuffle_size_ < 2) { - RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1."); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status ShuffleOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_, - build_reshuffle_each_epoch_, build_rows_per_buffer_); - return Status::OK(); -} - -// Constructor of the ShuffleOp -ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, - int32_t rows_per_buffer) - : PipelineOp(op_connector_size), - shuffle_size_(shuffle_size), - shuffle_seed_(shuffle_seed), - reshuffle_each_epoch_(reset_every_epoch), - rng_(shuffle_seed), - buffer_counter_(0), - rows_per_buffer_(rows_per_buffer), - shuffle_buffer_(std::make_unique()), - shuffle_last_row_idx_(0), - shuffle_buffer_state_(kShuffleStateInit) {} - -// Private function to re-init the shuffle op for another epoch. Shuffle op calls this by -// itself rather than waiting for the reset driven from operators above it in the pipeline. -Status ShuffleOp::SelfReset() { - MS_LOG(DEBUG) << "Shuffle operator performing a self-reset."; - // If reshuffle_each_epoch is false, then we always use the same seed for every - // epoch. 
-  // If reshuffle_each_epoch is true, then the first epoch uses the given seed,
-  // and all subsequent epochs will then keep on using the rng_ without resetting it
-  if (!reshuffle_each_epoch_) {
-    rng_ = std::mt19937_64(shuffle_seed_);
-  }
-
-  shuffle_buffer_ = std::make_unique<TensorTable>();
-  buffer_counter_ = 0;
-  shuffle_last_row_idx_ = 0;
-  shuffle_buffer_state_ = kShuffleStateInit;
-  return Status::OK();
-}
-
-// A print method typically used for debugging
-void ShuffleOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as first line regardless if this summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") :";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << " [shuffle size: " << shuffle_size_ << "]\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_
-        << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n";
-  }
-}
-
-// Private function to add a new row to the shuffle buffer.
-Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) {
-  // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are
-  // filling it during the initial fill codepath and thus growing its size. In that case, we push
-  // back the new row to grow our shuffle buffer size by 1.
-  // If we are already at the full size, then we overwrite the last slot with our row (and the last
-  // slot better be empty because it should already have been swapped out during the random row
-  // selection that was done previously!)
-  if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) {
-    shuffle_buffer_->push_back(std::move(new_shuffle_row));
-    shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1;
-  } else {
-    if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) {
-      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                    "Last row of shuffle buffer should not be occupied!");
-    }
-    (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row);
-  }
-  return Status::OK();
-}
-
-// Class functor operator () override.
-// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
-// provide the master loop that drives the logic for performing the work
-Status ShuffleOp::operator()() {
-  std::unique_ptr<TensorQTable> new_buffer_table;  // A tensor table to be used for output.
-
-  // Synchronize with TaskManager once the thread is launched.
-  TaskManager::FindMe()->Post();
-
-  // Shuffle op does not have workers, and only consumes from child 0.
-  // Create the child iterator to fetch our data from.
-  int32_t worker_id = 0;
-  int32_t child_idx = 0;
-  child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
-
-  // Main operator loop
-  while (true) {
-    // Do an initial populate of the shuffle buffer
-    RETURN_IF_NOT_OK(InitShuffleBuffer());
-
-    // This is our main loop exit condition: the iterator has no more data at all.
-    if (child_iterator_->eof_handled()) {
-      break;
-    }
-
-    // Next, enter into the main execution loop of the shuffle op.
-    // When the tail index position of our shuffle buffer goes negative it means that we've
-    // fully drained the data from the shuffle buffer and we're done.
-    while (shuffle_last_row_idx_ >= 0) {
-      // Step 1)
-      // Create an output tensor table if one is not created yet.
-      if (!new_buffer_table) {
-        new_buffer_table = std::make_unique<TensorQTable>();
-      }
-
-      // Step 2)
-      // Randomly select a slot from our shuffle buffer and copy that row into the output
-      // tensor table. We remove the data from the shuffle buffer, leaving that slot
-      // in the table as an empty vector
-      int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1);
-      new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot]));
-
-      // Step 3)
-      // If the output tensor table is at the requested size, then create a buffer for it
-      // and send this buffer on its way up the pipeline. Special case is if this is the
-      // last row then we also send it.
-      if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) {
-        auto new_buffer = std::make_unique<DataBuffer>(buffer_counter_, DataBuffer::kDeBFlagNone);
-        new_buffer->set_tensor_table(std::move(new_buffer_table));
-        buffer_counter_++;
-        MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output.";
-        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer)));
-      }
-
-      // Step 4)
-      // Take the last row from shuffle buffer, and swap it into the row position that was
-      // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the
-      // tail of the shuffle buffer.
-      if (random_slot != shuffle_last_row_idx_) {
-        (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]);
-      }
-
-      // Step 5)
-      // Refill the last slot of the shuffle buffer with the next row from input if we are in the
-      // active state.
-      // If we are in the draining state, we do not need to fetch another row to replace the one we
-      // just drained.
-      if (shuffle_buffer_state_ == kShuffleStateActive) {
-        TensorRow new_row;
-        RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-
-        if (!new_row.empty()) {
-          RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
-        } else {
-          shuffle_buffer_state_ = kShuffleStateDrain;
-        }
-      }
-
-      // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we
-      // just drained a row from it.
-      if (shuffle_buffer_state_ == kShuffleStateDrain) {
-        shuffle_last_row_idx_--;
-      }
-    }
-
-    // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the
-    // pipeline manually now that we are done draining the shuffle buffer
-    MS_LOG(DEBUG) << "Shuffle operator sending EOE.";
-    auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
-
-    // Do not wait for any reset to be flown down from operators above us.
-    // Instead, manually update ourselves and then go reloop to start fetching from child operator
-    // right away. Any Reset() from the parent will still perform common reset actions.
-    RETURN_IF_NOT_OK(this->SelfReset());
-  }
-  return Status::OK();
-}
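The drain loop above is a swap-with-last uniform shuffle. A self-contained sketch of the same idea on plain integers (no refill step, so it shows the pure draining state; with a buffer as large as the dataset this degenerates to a Fisher-Yates shuffle, while smaller buffers give a windowed approximation):

// Self-contained sketch of the shuffle-buffer drain: pick a uniformly random
// slot, emit it, then swap the tail row into the vacated slot so the buffer
// stays contiguous with an empty slot at the tail.
#include <cstdio>
#include <random>
#include <vector>

int main() {
  std::mt19937_64 rng(42);  // plays the role of shuffle_seed_/rng_
  std::vector<int> shuffle_buffer = {0, 1, 2, 3, 4, 5, 6, 7};
  long last = static_cast<long>(shuffle_buffer.size()) - 1;  // shuffle_last_row_idx_

  while (last >= 0) {
    long random_slot = static_cast<long>(rng() % (last + 1));
    std::printf("%d ", shuffle_buffer[random_slot]);  // emit the chosen row
    if (random_slot != last) {
      shuffle_buffer[random_slot] = shuffle_buffer[last];  // swap tail into the hole
    }
    --last;  // draining only: no refill, so the tail shrinks each iteration
  }
  std::printf("\n");
  return 0;
}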
-  // This shuffle buffer initialization phase stops when we've either filled up the
-  // shuffle buffer to its max size, or the dataset below us is not providing any more
-  // rows.
-  if (shuffle_buffer_state_ != kShuffleStateInit) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)");
-  }
-
-  // Before we drop into the fetching loop, call the fetch once for the first time
-  // to fill the first row and grab the first buffer.
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-
-  if (child_iterator_->eof_handled()) {
-    MS_LOG(DEBUG) << "Shuffle operator init picked up EOF. No more epochs.";
-    return Status::OK();
-  }
-
-  if (new_row.empty()) {
-    RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer.");
-  }
-
-  // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached
-  // the desired shuffle buffer size.
-  while (!new_row.empty() && shuffle_buffer_->size() < static_cast<size_t>(shuffle_size_ - 1)) {
-    // Add the previously fetched row
-    RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
-
-    // Fetch the next row
-    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-  }
-
-  // If we quit the loop due to being at the shuffle size, still need to add the last row here.
-  if (!new_row.empty()) {
-    RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row)));
-    shuffle_buffer_state_ = kShuffleStateActive;  // Transition to the active state
-  } else {
-    // If init phase doesn't have more rows, then skip the active state and jump straight to the
-    // shuffle buffer draining state
-    shuffle_buffer_state_ = kShuffleStateDrain;
-  }
-
-  MS_LOG(DEBUG) << "Shuffle operator finished initializing the shuffle buffer.";
-  return Status::OK();
-}
-
-Status ShuffleOp::EoeReceived(int32_t worker_id) {
-  state_ = OpState::kDeOpIdle;
-  return Status::OK();
-}
-
-// Visitor accept method for NodePass
-Status ShuffleOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<ShuffleOp>(), modified);
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
deleted file mode 100644
index 14b1e4511e..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.h
+++ /dev/null
@@ -1,204 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
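
Aside, before the header below: a minimal standalone sketch of the two per-epoch seeding policies handled by SelfReset() above (hypothetical example; assumes rng_ is the std::mt19937_64 declared in shuffle_op.h, which follows):

#include <cstdint>
#include <iostream>
#include <random>

void PrintEpoch(std::mt19937_64 &rng, int rows) {
  for (int i = 0; i < rows; ++i) std::cout << rng() % 100 << ' ';
  std::cout << '\n';
}

int main() {
  const uint64_t seed = 1234;
  // reshuffle_each_epoch == false: re-seed every epoch -> identical order each epoch.
  std::mt19937_64 fixed(seed);
  PrintEpoch(fixed, 5);
  fixed = std::mt19937_64(seed);  // what the reset path does in this mode
  PrintEpoch(fixed, 5);           // same sequence again
  // reshuffle_each_epoch == true: keep the engine running -> new order each epoch.
  std::mt19937_64 rolling(seed);
  PrintEpoch(rolling, 5);
  PrintEpoch(rolling, 5);  // different sequence
  return 0;
}
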
- */ -#ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Forward declare -class ExecutionTree; - -class DbConnector; - -class DataBuffer; - -class ShuffleOp : public PipelineOp { - // Shuffle buffer state flags - // - // Shuffle buffer is in a state of being initialized - static constexpr int32_t kShuffleStateInit = 0; - - // Shuffle buffer is in a state of being actively drained from, but refilling as well - static constexpr int32_t kShuffleStateActive = 1; - - // Shuffle buffer is in a state of being drained - static constexpr int32_t kShuffleStateDrain = 2; - - public: - // The nested builder class inside of the ShuffleOp is used to help manage all of the arguments - // for constructing it. The shuffle op is fairly simple though, but the builder provides a - // consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetShuffleSize(int32_t shuffle_size) { - build_shuffle_size_ = shuffle_size; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetShuffleSeed(uint32_t shuffle_seed) { - build_shuffle_seed_ = shuffle_seed; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - build_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetReshuffleEachEpoch(bool reshuffle_each_epoch) { - build_reshuffle_each_epoch_ = reshuffle_each_epoch; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - build_op_connector_size_ = op_connector_size; - return *this; - } - - // The builder "build" method creates the final object. - // @return shared_ptr to the new ShuffleOp object - Status Build(std::shared_ptr *); - - private: - // The builder saves all ShuffleOp construction arguments internally. - // The following are the arguments. 
- int32_t build_shuffle_size_; - uint32_t build_shuffle_seed_; - int32_t build_rows_per_buffer_; - bool build_reshuffle_each_epoch_; - int32_t build_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the ShuffleOp - // @note The builder class should be used to call it - // @param shuffle_size - The size for the shuffle buffer - // @param shuffle_seed - The seed to use for random number generation - // @param op_connector_size - The output connector queue size - // @param rows_per_buffer - The requested number of rows per buffer - ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch, - int32_t rows_per_buffer); - - // Destructor - ~ShuffleOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param so - reference to the ShuffleOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) { - so.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for special eoe handler. - // ShuffleOp must override this because it shall not perform default handling of eoe. Instead - // the ShuffleOp needs to manage actions related to the end of the epoch itself. - // @return Status - The error code return - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ShuffleOp"; } - - private: - // Private function to add a new row to the shuffle buffer. - // @return Status - The error code return - Status AddRowToShuffleBuffer(TensorRow new_shuffle_row); - - // Private function to populate the shuffle buffer initially by fetching from the child output - // connector until the shuffle buffer is full (or there is no more data coming). - // @return Status - The error code return - Status InitShuffleBuffer(); - - // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by - // itself rather than waiting for the reset driven from operators above it in the pipeline. - // @return Status - The error code return - Status SelfReset(); - - int32_t shuffle_size_; // User config for the size of the shuffle buffer (number of rows) - uint32_t shuffle_seed_; - bool reshuffle_each_epoch_; - // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period. - // specifically mt19937_64 is used to generate larger random numbers to reduce bias when - // modding to fit within our desired range. 
we don't use a distribution
-  // (i.e. uniform_int_distribution) because we will need to create up to |dataset| instances
-  // of the distribution object in the common case of a perfect shuffle
-  std::mt19937_64 rng_;
-  int32_t buffer_counter_;   // For creating new buffer ids
-  int32_t rows_per_buffer_;  // Number of rows to pack into output buffer
-  // A single (potentially large) buffer of tensor rows for performing shuffling.
-  std::unique_ptr<TensorTable> shuffle_buffer_;
-  int32_t shuffle_last_row_idx_;  // Internal tracking of the last slot of our shuffle buffer
-  int32_t shuffle_buffer_state_;  // State tracking for the shuffle buffer phases of work
-
-  std::unique_ptr<ChildIterator> child_iterator_;  // An iterator for fetching.
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
deleted file mode 100644
index f6b0fe689c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include
-#include
-#include
-
-#include "dataset/core/config_manager.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/datasetops/skip_op.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/opt/pass.h"
-
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace dataset {
-// Builder constructor. Creates the builder object.
-SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  builder_op_connector_size_ = cfg->op_connector_size();
-}
-
-Status SkipOp::Builder::SanityCheck() const {
-  if (build_max_skips_ < 0) {
-    std::string err_msg("Skip count must be a positive integer or 0.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  return Status::OK();
-}
-
-// The builder "build" method creates the final object.
-Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
-  return Status::OK();
-}
-
-// Constructor of the SkipOp.
-SkipOp::SkipOp(int32_t count, int32_t op_connector_size) - : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {} - -// Destructor -SkipOp::~SkipOp() {} - -// A print method typically used for debugging -void SkipOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [skips: " << max_skips_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nSkip count: " << skip_count_ << "\nMax skips: " << max_skips_ << "\n\n"; - } -} - -// Base-class override for handling cases when an eoe is received. -Status SkipOp::EoeReceived(int32_t worker_id) { - skip_count_ = 0; - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// main entry point for skip -Status SkipOp::operator()() { - TaskManager::FindMe()->Post(); - std::unique_ptr curr_buffer; - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - - while (curr_buffer->eof() == false) { - // Reset count - skip_count_ = 0; - while (curr_buffer->eoe() == false) { - // Drop first count rows - while (skip_count_ < max_skips_) { - if (curr_buffer->eoe() || curr_buffer->eof()) { - break; - } - // Consider the rows of buffer more than one - TensorRow drop_row; - int row_num = curr_buffer->NumRows(); - int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; - skip_count_ += drop_num; - for (int i = 0; i < drop_num; i++) { - RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row)); - } - if (curr_buffer->NumRows() == 0) { - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - } - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - // we got eoe, now try again until we got eof - MS_LOG(DEBUG) << "Skip operator EOE Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - RETURN_IF_NOT_OK(GetNextInput(&curr_buffer)); - } - - MS_LOG(DEBUG) << "Skip operator EOF Received."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Base-class override for handling cases when an eof is received. -Status SkipOp::EofReceived(int32_t worker_id) { - MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status SkipOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h deleted file mode 100644 index 4cb658b2a7..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
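
Aside: a minimal standalone sketch of the per-buffer drop arithmetic used in SkipOp::operator()() above (hypothetical example with plain ints in place of DataBuffers; not part of the original file):

#include <iostream>

int main() {
  const int max_skips = 7;  // rows to drop at the start of each epoch
  int skip_count = 0;
  const int buffers[] = {3, 3, 3, 3};  // four incoming buffers of 3 rows each
  for (int row_num : buffers) {
    if (skip_count < max_skips) {
      // Drop the whole buffer if it fits in the remaining skip budget,
      // otherwise drop only what is left of the budget.
      int drop_num = (row_num + skip_count < max_skips) ? row_num : max_skips - skip_count;
      skip_count += drop_num;
      std::cout << "dropped " << drop_num << ", forwarded " << (row_num - drop_num) << '\n';
    } else {
      std::cout << "forwarded " << row_num << '\n';
    }
  }
  // With max_skips = 7: drops 3, 3, 1 and forwards 0, 0, 2, 3 rows.
  return 0;
}
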
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ - -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class SkipOp : public PipelineOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @param count - The number of skip to do - // @return This is a constructor. - explicit Builder(int32_t count); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new SkipOp object - Status Build(std::shared_ptr *); - - private: - int32_t build_max_skips_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the SkipOp. - // @note The builder class should be used to call it - // @param count - The number of skips to do - explicit SkipOp(int32_t count, int32_t op_connector_size); - - // Destructor - ~SkipOp(); - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for handling cases when an eoe is received. - // @param worker_id - The worker id - Status EoeReceived(int32_t worker_id) override; - - // Base-class override for handling cases when an eof is received. - // @param worker_id - The worker id - Status EofReceived(int32_t worker_id) override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "SkipOp"; } - - private: - int32_t max_skips_; // The number of skips that the user requested - int32_t skip_count_; // A counter for the current number of executed skips -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc deleted file mode 100644 index db357f42ec..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc +++ /dev/null @@ -1,430 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
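
Aside, before the CelebA reader: a short usage sketch of the nested Builder pattern shown in skip_op.h above (hypothetical call site, not a complete program; assumes the dataset-engine headers and the Status/MS_LOG utilities used throughout this diff):

// Hypothetical usage of the SkipOp::Builder shown above; SanityCheck() runs inside
// Build(), so an invalid count comes back as an error Status rather than a throw.
std::shared_ptr<SkipOp> skip_op;
Status rc = SkipOp::Builder(10)  // skip the first 10 rows of each epoch
              .Build(&skip_op);  // validates the count, then constructs the op
if (rc.IsError()) {
  MS_LOG(ERROR) << "Failed to build SkipOp.";
}
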
- * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/engine/datasetops/source/celeba_op.h" - -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/util/path.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status CelebAOp::Builder::Build(std::shared_ptr *op) { - MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << "."; - MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << "."; - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - // label is like this:0 1 0 0 1...... - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - *op = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, - builder_op_connector_size_, builder_decode_, builder_dataset_type_, - builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_)); - if (*op == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null"); - } - - return Status::OK(); -} - -Status CelebAOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() ? "" : "CelebA path is invalid or not set\n"; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is smaller than 1\n" : ""; - return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
-}
-
-CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
-                   bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
-                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size, std::move(sampler)),
-      rows_per_buffer_(rows_per_buffer),
-      folder_path_(dir),
-      decode_(decode),
-      extensions_(exts),
-      data_schema_(std::move(schema)),
-      num_rows_in_attr_file_(0),
-      dataset_type_(dataset_type) {
-  attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
-  io_block_queues_.Init(num_workers_, queue_size);
-}
-
-Status CelebAOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "tree_ not set");
-  }
-
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-
-  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this)));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(ParseImageAttrInfo());
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-
-  return Status::OK();
-}
-
-Status CelebAOp::ParseAttrFile() {
-  TaskManager::FindMe()->Post();
-  Path folder_path(folder_path_);
-  std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
-  if (!attr_file.is_open()) {
-    return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "Celeba attr file does not exist");
-  }
-
-  const auto PushBackToQueue = [this](std::vector<std::string> &vec, std::ifstream &attr_file,
-                                      std::ifstream &partition_file) {
-    Status s = attr_info_queue_->EmplaceBack(vec);
-    if (s.IsError()) {
-      CLOSE_FILE(attr_file, partition_file);
-      return s;
-    }
-    return Status::OK();
-  };
-
-  std::string rows_num;
-  std::string attr_name;
-  (void)getline(attr_file, rows_num);
-  try {
-    num_rows_in_attr_file_ = static_cast<int64_t>(std::stoul(rows_num));  // First line is rows number in attr file
-  } catch (std::invalid_argument &e) {
-    RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, invalid argument.");
-  } catch (std::out_of_range &e) {
-    RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, out of range.");
-  }
-
-  (void)getline(attr_file, attr_name);  // Second line is attribute name, ignore it
-  std::string image_info;
-  std::vector<std::string> image_infos;
-  image_infos.reserve(oc_queue_size_);
-  while (getline(attr_file, image_info)) {
-    if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
-      continue;
-    }
-    image_infos.push_back(image_info);
-    if (image_infos.size() % oc_queue_size_ == 0) {
-      RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_));
-      image_infos.clear();
-    }
-  }
-  if (!image_infos.empty()) {
-    RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_));
-  }
-  std::vector<std::string> end_indicator = std::vector<std::string>(0);
-  RETURN_IF_NOT_OK(PushBackToQueue(end_indicator, attr_file, partition_file_));  // end indicator
-  CLOSE_FILE(attr_file, partition_file_);
-  return Status::OK();
-}
-
-bool CelebAOp::CheckDatasetTypeValid() {
-  if (!partition_file_.is_open()) {
-    Path folder_path(folder_path_);
-    partition_file_.open((folder_path / "list_eval_partition.txt").toString());
-    if (!partition_file_.is_open()) {
-      MS_LOG(ERROR) << "Celeba partition file does not exist!";
-      return false;
-    }
-  }
-  std::string line;
-  (void)getline(partition_file_, line);
-  std::vector<std::string> vec = Split(line);
-  if (vec.size() != 2) {
-    return false;
-  }
-  int32_t type;
-  try {
-    type = std::stoi(vec[1]);
-  } catch (std::invalid_argument &e) {
-    MS_LOG(WARNING) << "Conversion to unsigned long failed, invalid argument, " << vec[0] << ".";
-    return false;
-  } catch (std::out_of_range &e) {
-    MS_LOG(WARNING) << "Conversion to unsigned long failed, out of range, " << vec[0] << ".";
-    return false;
-  }
-  // train: 0, valid: 1, test: 2
-  if (dataset_type_ == "train" && (type == 0)) {
-    return true;
-  } else if (dataset_type_ == "valid" && (type == 1)) {
-    return true;
-  } else if (dataset_type_ == "test" && (type == 2)) {
-    return true;
-  }
-
-  return false;
-}
-
-Status CelebAOp::ParseImageAttrInfo() {
-  std::vector<std::string> image_infos;
-  bool needMoreData = true;
-  RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
-  while (!image_infos.empty() && needMoreData) {
-    for (uint32_t index = 0; index < image_infos.size(); index++) {
-      std::string image_info = image_infos[index];
-      std::vector<std::string> split = Split(image_info);
-      std::pair<std::string, std::vector<int32_t>> image_labels;
-
-      Path path(folder_path_);
-      Path file_path = path / split[0];
-      if (!extensions_.empty() && extensions_.find(file_path.Extension()) == extensions_.end()) {
-        MS_LOG(WARNING) << "Unsupported file found at " << file_path.toString().c_str() << ", its extension is "
-                        << file_path.Extension().c_str() << ".";
-        continue;
-      }
-      image_labels.first = split[0];
-      for (uint32_t label_index = 1; label_index < split.size(); label_index++) {
-        int32_t value;
-        try {
-          value = std::stoi(split[label_index]);
-        } catch (std::invalid_argument &e) {
-          RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
-        } catch (std::out_of_range &e) {
-          RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
-        }
-        image_labels.second.push_back(value);
-      }
-
-      image_labels_vec_.push_back(image_labels);
-    }
-
-    RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
-  }
-
-  num_rows_ = image_labels_vec_.size();
-  if (num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API CelebADataset. Please check file path or dataset API "
-      "validation first.");
-  }
-  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
-  return Status::OK();
-}
-
-std::vector<std::string> CelebAOp::Split(const std::string &line) {
-  std::string str = line;
-  std::string::size_type pos;
-  std::vector<std::string> split;
-  str += " ";
-  int size = str.size();
-  for (uint32_t index = 0; index < size;) {
-    pos = str.find(" ", index);
-    if (pos != index) {  // skip space
-      std::string s = str.substr(index, pos - index);
-      split.push_back(s);
-    }
-    index = pos + 1;
-  }
-
-  return split;
-}
-
-// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
-Status CelebAOp::operator()() {
-  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
-  std::unique_ptr<DataBuffer> data_buffer;
-  RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer));
-  RETURN_IF_NOT_OK(AddIOBlock(&data_buffer));
-  return Status::OK();
-}
-
-Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
-  int64_t buff_count = 0;
-  while (true) {
-    std::vector<int64_t> keys;
-    keys.reserve(rows_per_buffer_);
-    int64_t row_count = 0;
-    while (!(*data_buffer)->eoe()) {
-      TensorRow sample_row;
-      RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
-      std::shared_ptr<Tensor> sample_ids = sample_row[0];
-      for (auto itr = sample_ids->begin(); itr !=
sample_ids->end(); ++itr) { - if ((*itr) >= num_rows_) { - MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << "."; - continue; - } - keys.push_back(*itr); - row_count++; - if (row_count % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buff_count++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - keys.clear(); - } - } - RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); - } - - if (!keys.empty()) { - RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer)); - } - } -} - -Status CelebAOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty()) { - return Status::OK(); // empty key is a quit signal for workers - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unexpected nullptr received in worker"); -} - -Status CelebAOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - for (const auto &key : keys) { - TensorRow row; - RETURN_IF_NOT_OK(LoadTensorRow(key, image_labels_vec_[key], &row)); - deq->push_back(std::move(row)); - } - - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::pair> &image_label, - TensorRow *row) { - std::shared_ptr image; - std::shared_ptr label; - - Path path(folder_path_); - Path image_path = path / image_label.first; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, image_path.toString())); - if (decode_ == true) { - Status rc = Decode(image, &image); - if (rc.IsError()) { - image = nullptr; - std::string err_msg = "Fail to decode image: " + image_path.toString(); - return 
Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); - } - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), - TensorShape({1, (uint32_t)image_label.second.size()}), - data_schema_->column(1).type())); - RETURN_IF_NOT_OK(label->Zero()); - for (uint32_t index = 0; index < image_label.second.size(); index++) { - if (image_label.second[index] == 1) { - label->SetItemAt({0, static_cast(index)}, 1); - } else { - label->SetItemAt({0, static_cast(index)}, 0); - } - } - label->Squeeze(); - - (*row) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -void CelebAOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status CelebAOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// Visitor accept method for NodePass -Status CelebAOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status CelebAOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t index = 0; index < data_schema_->NumColumns(); index++) { - column_name_id_map_[data_schema_->column(index).name()] = index; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h deleted file mode 100644 index fa81babe4c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
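
Aside, before the CelebA header below: a minimal standalone sketch of the attribute mapping performed by CelebAOp::LoadTensorRow() above, where each attribute parsed from list_attr_celeba.txt is 1 or -1 and the emitted label vector is 0/1 (hypothetical example; not part of the original file):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int32_t> parsed = {1, -1, -1, 1, 1};  // values parsed from one attr line
  std::vector<uint32_t> attr(parsed.size(), 0);     // zero-initialised, like label->Zero()
  for (size_t i = 0; i < parsed.size(); ++i) {
    attr[i] = (parsed[i] == 1) ? 1 : 0;             // -1 (and anything else) maps to 0
  }
  for (uint32_t v : attr) std::cout << v << ' ';    // prints: 1 0 0 1 1
  std::cout << '\n';
  return 0;
}
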
-*/
-
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
-#define DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/util/status.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/parallel_op.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/util/queue.h"
-#include "dataset/engine/datasetops/source/io_block.h"
-
-#define CLOSE_FILE(attr_file, partition_file) \
-  do {                                        \
-    attr_file.close();                        \
-    if (partition_file.is_open()) {           \
-      partition_file.close();                 \
-    }                                         \
-  } while (false)
-
-namespace mindspore {
-namespace dataset {
-class CelebAOp : public ParallelOp, RandomAccessOp {
- public:
-  class Builder {
-   public:
-    // Constructor for Builder class of CelebAOp
-    // @return Builder setter method returns reference to the builder.
-    Builder();
-
-    // Destructor.
-    ~Builder() = default;
-
-    // Setter method
-    // @param int32_t rows_per_buffer
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-      builder_rows_per_buffer_ = rows_per_buffer;
-      return *this;
-    }
-
-    // Setter method
-    // @param int32_t size
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOpConnectorSize(int32_t size) {
-      builder_op_connector_size_ = size;
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::set<std::string> &exts, file extensions to be read
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetExtensions(const std::set<std::string> &exts) {
-      builder_extensions_ = exts;
-      return *this;
-    }
-
-    // Setter method
-    // @param bool decode
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetDecode(bool decode) {
-      builder_decode_ = decode;
-      return *this;
-    }
-
-    // Setter method
-    // @param int32_t num_workers
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumWorkers(int32_t num_workers) {
-      builder_num_workers_ = num_workers;
-      return *this;
-    }
-
-    // Setter method
-    // @param std::shared_ptr<Sampler> sampler
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
-      builder_sampler_ = std::move(sampler);
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::string &dir
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetCelebADir(const std::string &dir) {
-      builder_dir_ = dir;
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::string dataset_type: type to be read
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetDatasetType(const std::string &dataset_type) {
-      builder_dataset_type_ = dataset_type;
-      return *this;
-    }
-    // Check validity of input args
-    // @return - The error code return
-    Status SanityCheck();
-
-    // The builder "build" method creates the final object.
- // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - std::string builder_dir_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::set builder_extensions_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - std::string builder_dataset_type_; - }; - - // Constructor - // @param int32_t - num_workers - Num of workers reading images in parallel - // @param int32_t - rows_per_buffer Number of images (rows) in each buffer - // @param std::string - dir directory of celeba dataset - // @param int32_t queueSize - connector queue size - // @param std::unique_ptr sampler - sampler tells CelebAOp what to read - CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, - const std::string &dataset_type, const std::set &exts, std::unique_ptr schema, - std::shared_ptr sampler); - - ~CelebAOp() override = default; - - // Main Loop of CelebaOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Method in operator(), to fill IOBlockQueue - // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue - // @return Status - The error code return - Status AddIOBlock(std::unique_ptr *data_buffer); - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p Pointer to the NodePass to be accepted - /// \param[out] modified Indicator if the node was changed at all - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const { return "CelebAOp"; } - - private: - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // Parse attribute file - // @return - Status ParseAttrFile(); - - // Parse each image line in attribute file - // @return - Status ParseImageAttrInfo(); - - // Split attribute info with space - // @param std::string - line - Line from att or partition file - // @return std::vector - string after split - std::vector Split(const std::string &line); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param std::pair - > - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::pair> &image_label, - TensorRow *row); - - // Check if need read according to dataset type - // @return bool - if need read - bool CheckDatasetTypeValid(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the 
assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-
-  int32_t rows_per_buffer_;
-  std::string folder_path_;  // directory of celeba folder
-  bool decode_;
-  std::set<std::string> extensions_;  // extensions allowed
-  std::unique_ptr<DataSchema> data_schema_;
-  std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
-  int64_t num_rows_in_attr_file_;  // rows number specified in attr file
-  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
-  WaitPost wp_;
-  std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
-  std::string dataset_type_;
-  std::ifstream partition_file_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
deleted file mode 100644
index d378933c04..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
+++ /dev/null
@@ -1,472 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/cifar_op.h"
-
-#include
-#include
-#include
-#include
-
-#include "common/utils.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/tensor_shape.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/opt/pass.h"
-
-namespace mindspore {
-namespace dataset {
-constexpr uint32_t kCifarImageHeight = 32;
-constexpr uint32_t kCifarImageWidth = 32;
-constexpr uint32_t kCifarImageChannel = 3;
-constexpr uint32_t kCifarBlockImageNum = 5;
-constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
-
-CifarOp::Builder::Builder() : sampler_(nullptr) {
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
-  op_connect_size_ = cfg->op_connector_size();
-  cifar_type_ = kCifar10;
-}
-
-Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  if (sampler_ == nullptr) {
-    const int64_t num_samples = 0;
-    const int64_t start_index = 0;
-    sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
-  }
-  schema_ = std::make_unique<DataSchema>();
-  TensorShape scalar = TensorShape::CreateScalar();
-  RETURN_IF_NOT_OK(schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
-  if (cifar_type_ == kCifar10) {
-    RETURN_IF_NOT_OK(
-      schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  } else {
-    RETURN_IF_NOT_OK(schema_->AddColumn(
-      ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-    TensorShape another_scalar = TensorShape::CreateScalar();
-    RETURN_IF_NOT_OK(schema_->AddColumn(
-      ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
-  }
-
-  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
-                                   std::move(schema_), std::move(sampler_));
-  return Status::OK();
-}
-
-Status CifarOp::Builder::SanityCheck() {
-  Path dir(dir_);
-  std::string err_msg;
-  err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : "";
-  err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : "";
-  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
-}
-
-CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
-                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_works, queue_size, std::move(sampler)),
-      cifar_type_(type),
-      rows_per_buffer_(rows_per_buf),
-      folder_path_(file_dir),
-      data_schema_(std::move(data_schema)),
-      row_cnt_(0),
-      buf_cnt_(0) {
-  constexpr uint64_t kUtilQueueSize = 512;
-  cifar_raw_data_block_ = std::make_unique<Queue<std::vector<unsigned char>>>(kUtilQueueSize);
-  io_block_queues_.Init(num_workers_, queue_size);
-}
-
-// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
-Status CifarOp::operator()() {
-  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
-  std::unique_ptr<DataBuffer> sampler_buffer;
-  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-  while (true) {  // each iteration is 1 epoch
-    std::vector<int64_t> keys;
-    keys.reserve(rows_per_buffer_);
-    while (sampler_buffer->eoe() == false) {
-      TensorRow sample_row;
-      RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
-      std::shared_ptr<Tensor> sample_ids = sample_row[0];
-      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
-        keys.push_back(*itr);
-        row_cnt_++;
-        if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
-        if (row_cnt_ % rows_per_buffer_ == 0) {
-          RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
-            std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
-          keys.clear();
-        }
-      }
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-    if (keys.empty() == false) {
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
-        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
-    }
-    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
-      for (int32_t i = 0; i < num_workers_; i++) {
-        RETURN_IF_NOT_OK(
-          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
-      }
-      return Status::OK();
-    } else {  // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
-      wp_.Clear();
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-  }
-}
-
-Status CifarOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(
-    tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this)));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  // The order of the following 2 functions must not be changed!
-  RETURN_IF_NOT_OK(ParseCifarData());  // Parse cifar data and get num rows, blocking
-  RETURN_IF_NOT_OK(InitSampler());     // Pass numRows to Sampler
-  return Status::OK();
-}
-
-// contains the main logic of pulling an IOBlock from IOBlockQueue, loading a buffer and pushing it to out_connector_
-// IMPORTANT: 1 IOBlock produces 1 DataBuffer
-Status CifarOp::WorkerEntry(int32_t worker_id) {
-  TaskManager::FindMe()->Post();
-  int64_t buffer_id = worker_id;
-  std::unique_ptr<IOBlock> io_block;
-  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
-  while (io_block != nullptr) {
-    if (io_block->eoe() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
-      buffer_id = worker_id;
-    } else if (io_block->eof() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
-    } else {
-      std::vector<int64_t> keys;
-      RETURN_IF_NOT_OK(io_block->GetKeys(&keys));
-      if (keys.empty() == true) {
-        return Status::OK();  // empty key is a quit signal for workers
-      }
-      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
-      RETURN_IF_NOT_OK(LoadBuffer(keys, &db));
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
-      buffer_id += num_workers_;
-    }
-    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
-  }
-  RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker");
-}
-
-// Load 1 TensorRow (image,label). 1 function call produces 1 TensorRow in a DataBuffer
-Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) {
-  std::shared_ptr<Tensor> label;
-  std::shared_ptr<Tensor> fine_label;
-  std::shared_ptr<Tensor> ori_image = cifar_image_label_pairs_[index].first;
-  std::shared_ptr<Tensor> copy_image =
-    std::make_shared<Tensor>(ori_image->shape(), ori_image->type(), ori_image->GetBuffer());
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(),
-                                        data_schema_->column(1).type(),
-                                        reinterpret_cast<unsigned char *>(&cifar_image_label_pairs_[index].second[0])));
-  if (cifar_image_label_pairs_[index].second.size() > 1) {
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(
-      &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(),
-      data_schema_->column(2).type(), reinterpret_cast<unsigned char *>(&cifar_image_label_pairs_[index].second[1])));
-    (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)});
-  } else {
-    (*trow) = TensorRow(index, {copy_image, std::move(label)});
-  }
-
-  return Status::OK();
-}
-
-// Looping over LoadTensorRow to make 1 DataBuffer.
1 function call produces 1 buffer
-Status CifarOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
-  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
-  for (const int64_t &key : keys) {
-    TensorRow trow;
-    RETURN_IF_NOT_OK(LoadTensorRow(key, &trow));
-    deq->push_back(std::move(trow));
-  }
-  (*db)->set_tensor_table(std::move(deq));
-  return Status::OK();
-}
-
-void CifarOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as the first line, regardless of whether this is a summary or a detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") <CifarOp>:";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nNumber of rows: " << num_rows_ << "\nCifar directory: " << folder_path_ << "\n\n";
-  }
-}
-
-// Reset Sampler and wake up Master thread (functor)
-Status CifarOp::Reset() {
-  RETURN_IF_NOT_OK(sampler_->ResetSampler());
-  row_cnt_ = 0;
-  wp_.Set();  // wake up master thread after reset is done
-  return Status::OK();
-}
-
-// Handshake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
-Status CifarOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-  return Status::OK();
-}
-
-Status CifarOp::ReadCifarBlockDataAsync() {
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(GetCifarFiles());
-  if (cifar_type_ == kCifar10) {
-    RETURN_IF_NOT_OK(ReadCifar10BlockData());
-  } else {
-    RETURN_IF_NOT_OK(ReadCifar100BlockData());
-  }
-
-  return Status::OK();
-}
-
-Status CifarOp::ReadCifar10BlockData() {
-  constexpr uint32_t num_cifar10_records = 10000;
-  uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum;  // about 2M
-  std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
-  for (auto &file : cifar_files_) {
-    std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      std::string err_msg = file + " cannot be opened.";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-
-    for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
-      (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Failed to read cifar file " + file);
-      }
-      (void)cifar_raw_data_block_->EmplaceBack(image_data);
-    }
-    in.close();
-  }
-  (void)cifar_raw_data_block_->EmplaceBack(std::vector<unsigned char>());  // end block
-
-  return Status::OK();
-}
-
-Status CifarOp::ReadCifar100BlockData() {
-  uint32_t num_cifar100_records = 0;                                  // test:10000, train:50000
-  uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum;  // about 2M
-  std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
-  for (auto &file : cifar_files_) {
-    std::size_t pos = file.find_last_of('/');
-    if (pos == std::string::npos) {
-      RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
-    }
-    std::string file_name(file.substr(pos + 1));
-    if (file_name.find("test") != std::string::npos) {
-      num_cifar100_records = 10000;
-    } else if (file_name.find("train") != std::string::npos) {
-      num_cifar100_records = 50000;
-    } else {
-      RETURN_STATUS_UNEXPECTED("Cifar100 file not found!");
-    }
-
-    std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      RETURN_STATUS_UNEXPECTED(file + " cannot be opened.");
-    }
-
-    for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
-      (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Failed to read cifar file " + file);
-      }
-      (void)cifar_raw_data_block_->EmplaceBack(image_data);
-    }
-    in.close();
-  }
-  (void)cifar_raw_data_block_->EmplaceBack(std::vector<unsigned char>());  // block end
-  return Status::OK();
-}
-
-Status CifarOp::GetCifarFiles() {
-  // Initialize queue to hold the file names
-  const std::string kExtension = ".bin";
-  Path dataset_directory(folder_path_);
-  auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
-  if (dirIt) {
-    while (dirIt->hasNext()) {
-      Path file = dirIt->next();
-      std::string filename = file.toString();
-      if (filename.find(kExtension) != std::string::npos) {
-        cifar_files_.push_back(filename);
-        MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
-      }
-    }
-  } else {
-    std::string err_msg = "Unable to open directory " + dataset_directory.toString();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  std::sort(cifar_files_.begin(), cifar_files_.end());
-  return Status::OK();
-}
-
-Status CifarOp::ParseCifarData() {
-  std::vector<unsigned char> block;
-  RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block));
-  uint32_t cur_block_index = 0;
-  while (!block.empty()) {
-    for (uint32_t index = 0; index < kCifarBlockImageNum; ++index) {
-      std::vector<uint32_t> labels;
-      uint32_t label = block[cur_block_index++];
-      labels.push_back(label);
-      if (cifar_type_ == kCifar100) {
-        uint32_t fine_label = block[cur_block_index++];
-        labels.push_back(fine_label);
-      }
-
-      std::shared_ptr<Tensor> image_tensor;
-      RETURN_IF_NOT_OK(Tensor::CreateTensor(&image_tensor, data_schema_->column(0).tensorImpl(),
-                                            TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}),
-                                            data_schema_->column(0).type()));
-      auto itr = image_tensor->begin<uint8_t>();
-      uint32_t total_pix = kCifarImageHeight * kCifarImageWidth;
-      for (uint32_t pix = 0; pix < total_pix; ++pix) {
-        for (uint32_t ch = 0; ch < kCifarImageChannel; ++ch) {
-          *itr = block[cur_block_index + ch * total_pix + pix];
-          itr++;
-        }
-      }
-      cur_block_index += total_pix * kCifarImageChannel;
-      cifar_image_label_pairs_.emplace_back(std::make_pair(image_tensor, labels));
-    }
-    RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block));
-    cur_block_index = 0;
-  }
-  cifar_image_label_pairs_.shrink_to_fit();
-  num_rows_ = cifar_image_label_pairs_.size();
-  if (num_rows_ == 0) {
-    std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ". Please check the file path or the dataset API first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  cifar_raw_data_block_->Reset();
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
-  if (cls_ids == nullptr || !cls_ids->empty()) {
-    RETURN_STATUS_UNEXPECTED("ImageLabelPair not set");
-  }
-
-  for (uint64_t index = 0; index < cifar_image_label_pairs_.size(); ++index) {
-    uint32_t label = (cifar_image_label_pairs_[index].second)[0];
-    (*cls_ids)[label].push_back(index);
-  }
-
-  for (auto &pair : (*cls_ids)) {
-    pair.second.shrink_to_fit();
-  }
-  return Status::OK();
-}
-
-Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
-  // the logic of counting the number of samples is copied from ReadCifar100BlockData() and ReadCifar10BlockData()
-  std::shared_ptr<CifarOp> op;
-  *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
-  RETURN_IF_NOT_OK(op->GetCifarFiles());
-  if (op->cifar_type_ == kCifar10) {
-    constexpr int64_t num_cifar10_records = 10000;
-    for (auto &file : op->cifar_files_) {
-      std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " cannot be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-      *count = *count + num_cifar10_records;
-    }
-    return Status::OK();
-  } else {
-    int64_t num_cifar100_records = 0;
-    for (auto &file : op->cifar_files_) {
-      std::size_t pos = file.find_last_of('/');
-      if (pos == std::string::npos) {
-        std::string err_msg = "Invalid cifar100 file path";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-      std::string file_name;
-      if (file.size() > 0)
-        file_name = file.substr(pos + 1);
-      else
-        RETURN_STATUS_UNEXPECTED("Invalid string length!");
-      if (file_name.find("test") != std::string::npos) {
-        num_cifar100_records = 10000;
-      } else if (file_name.find("train") != std::string::npos) {
-        num_cifar100_records = 50000;
-      }
-      std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " cannot be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-    }
-    *count = num_cifar100_records;
-    return Status::OK();
-  }
-}
-
-// Visitor accept method for NodePass
-Status CifarOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<CifarOp>(), modified);
-}
-
-Status CifarOp::ComputeColMap() {
-  // set the column name map (base class field)
-  if (column_name_id_map_.empty()) {
-    for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) {
-      column_name_id_map_[data_schema_->column(i).name()] = i;
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
deleted file mode 100644
index 24324bbebb..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
+++ /dev/null
@@ -1,236 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class CifarOp : public ParallelOp, public RandomAccessOp { - public: - enum CifarType { kCifar10, kCifar100 }; - - class Builder { - public: - // Constructor for Builder class of CifarOp - // @return Builder setter method returns reference to the builder. - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param uint32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param uint32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - op_connect_size_ = size; - return *this; - } - - // Setter method - // @param uint32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetCifarDir(const std::string &dir) { - dir_ = dir; - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetCifarType(const bool cifar10) { - if (cifar10) { - cifar_type_ = kCifar10; - } else { - cifar_type_ = kCifar100; - } - return *this; - } - - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. 
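// [Editor's note] An illustrative usage sketch of this builder, added for
// clarity; it is not part of the original header. The directory path is
// hypothetical, and error handling is elided:
//
//   std::shared_ptr<CifarOp> op;
//   RETURN_IF_NOT_OK(CifarOp::Builder()
//                      .SetCifarDir("/data/cifar-10-batches-bin")
//                      .SetCifarType(true)  // true -> kCifar10, false -> kCifar100
//                      .SetNumWorkers(4)
//                      .Build(&op));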
- // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::string dir_; - int32_t num_workers_; - int32_t rows_per_buffer_; - int32_t op_connect_size_; - std::shared_ptr sampler_; - std::unique_ptr schema_; - CifarType cifar_type_; - }; - - // Constructor - // @param CifarType type - Cifar10 or Cifar100 - // @param uint32_t numWorks - Num of workers reading images in parallel - // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer - // @param std::string - dir directory of cifar dataset - // @param uint32_t - queueSize - connector queue size - // @param std::unique_ptr sampler - sampler tells ImageFolderOp what to read - CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - // Destructor. - ~CifarOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param uint32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of CifarOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Function to count the number of samples in the CIFAR dataset - // @param dir path to the CIFAR directory - // @param isCIFAR10 true if CIFAR10 and false if CIFAR100 - // @param count output arg that will hold the actual dataset size - // @return - static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p Pointer to the NodePass to be accepted - /// \param[out] modified Indicator if the node was changed at all - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "CifarOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param uint64_t index - index need to load - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(uint64_t index, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Read block data from cifar file - // @return - Status ReadCifarBlockDataAsync(); - - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Get cifar files in dir - // @return - Status GetCifarFiles(); - - // Read cifar10 data as block - // @return - Status ReadCifar10BlockData(); - - // Read cifar100 data as block - // @return - Status ReadCifar100BlockData(); - - // Parse cifar data - // @return - Status ParseCifarData(); - - // Method derived from RandomAccess Op, enable 
Sampler to get all ids for each class
-  // @param cls_ids - key: label, val: all ids for this class
-  // @return Status - The error code return
-  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
-
-  // Private function for computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-
-  CifarType cifar_type_;
-  int32_t rows_per_buffer_;
-  std::string folder_path_;
-  std::unique_ptr<DataSchema> data_schema_;
-  int64_t row_cnt_;
-  int64_t buf_cnt_;
-
-  WaitPost wp_;
-  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
-  std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
-  std::vector<std::string> cifar_files_;
-  std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
deleted file mode 100644
index 9fceb6f333..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
+++ /dev/null
@@ -1,555 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/clue_op.h"
-
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/core/config_manager.h"
-#include "dataset/util/task_manager.h"
-#include "dataset/engine/jagged_connector.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/datasetops/source/io_block.h"
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-ClueOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
-  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
-  builder_num_workers_ = config_manager->num_parallel_workers();
-  builder_op_connector_size_ = config_manager->op_connector_size();
-  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
-  builder_worker_connector_size_ = config_manager->worker_connector_size();
-}
-
-Status ClueOp::Builder::ValidateInputs() const {
-  std::string err;
-  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
-  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
-  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
-}
-
-Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
-  RETURN_IF_NOT_OK(ValidateInputs());
-
-  // Throttle the number of workers if we have more workers than files!
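// [Editor's note] An illustrative sketch, not part of the original file: a few
// lines below, Build() turns each user-supplied column mapping into a ColKeyMap
// entry by splitting the keyword on '/'. For example, {"label_desc", "label/desc"}
// becomes {"label_desc", {"label", "desc"}}, and GetValue() later walks one JSON
// level per element of that chain. The splitting itself is just:
//
//   std::stringstream ss("label/desc");
//   std::string item;
//   std::vector<std::string> keys;
//   while (std::getline(ss, item, '/')) keys.push_back(item);  // {"label", "desc"}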
- if (static_cast(builder_num_workers_) > builder_clue_files_list_.size()) { - builder_num_workers_ = builder_clue_files_list_.size(); - MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - ColKeyMap ck_map; - for (auto &p : builder_cols_to_keyword_) { - ck_map.insert({p.first, split(p.second, '/')}); - } - - std::shared_ptr clue_op = std::make_shared( - builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, - builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, - builder_device_id_); - RETURN_IF_NOT_OK(clue_op->Init()); - *op = std::move(clue_op); - - return Status::OK(); -} - -std::vector ClueOp::Builder::split(const std::string &s, char delim) { - std::vector res; - std::stringstream ss(s); - std::string item; - - while (getline(ss, item, delim)) { - res.push_back(item); - } - return res; -} - -ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_device, int32_t device_id) - : ParallelOp(num_workers, op_connector_size), - rows_per_buffer_(rows_per_buffer), - num_rows_per_shard_(0), - all_num_rows_(0), - num_samples_(num_samples), - filename_index_(std::make_unique()), - clue_files_list_(std::move(clue_files_list)), - load_jagged_connector_(true), - cols_to_keyword_(cols_to_keyword), - shuffle_files_(shuffle_files), - finished_reading_dataset_(false), - num_devices_(num_device), - device_id_(device_id), - load_io_block_queue_(true) { - worker_connector_size_ = worker_connector_size; -} - -Status ClueOp::Init() { - RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); - - int32_t safe_queue_size = static_cast(std::ceil(clue_files_list_.size() / num_workers_) + 1); - io_block_queues_.Init(num_workers_, safe_queue_size); - - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - - return Status::OK(); -} - -Status ClueOp::Reset() { - load_jagged_connector_ = true; - load_io_block_queue_ = true; - - RETURN_IF_NOT_OK(ParallelOp::Reset()); - NotifyToFillIOBlockQueue(); - return Status::OK(); -} - -Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { - TensorRow tRow(1, nullptr); - (*tensor_table)->push_back(std::move(tRow)); - - std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); - (**tensor_table)[row][0] = std::move(tensor); - return Status::OK(); -} - -Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t) { - nlohmann::json cursor = js; - for (int i = 0; i < key_chain.size(); i++) { - if (cursor.find(key_chain[i]) != cursor.end()) { - cursor = cursor[key_chain[i]]; - } else { - RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]); - } - } - std::string final_str = key_chain.back(); - switch (cursor.type()) { - case nlohmann::detail::value_t::string: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); - break; - - case nlohmann::detail::value_t::number_integer: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case 
nlohmann::detail::value_t::number_unsigned: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case nlohmann::detail::value_t::number_float: - RETURN_IF_NOT_OK( - Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); - (*t)->SetItemAt({0}, cursor.get()); - break; - case nlohmann::detail::value_t::array: - RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); - break; - default: - break; - } - return Status::OK(); -} - -Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id) { - std::ifstream handle(file); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Failed to open file " + file); - } - - int64_t rows_each_buffer = 0; - int64_t rows_total = 0; - std::string line; - std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - - while (getline(handle, line)) { - if (line.empty()) { - continue; - } - // If read to the end offset of this file, break. - if (rows_total >= end_offset) { - break; - } - // Skip line before start offset. - if (rows_total < start_offset) { - rows_total++; - continue; - } - - try { - nlohmann::json js = nlohmann::json::parse(line); - int cols_count = cols_to_keyword_.size(); - TensorRow tRow(cols_count, nullptr); - tensor_table->push_back(std::move(tRow)); - - int cout = 0; - for (auto &p : cols_to_keyword_) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor)); - (*tensor_table)[rows_each_buffer][cout] = std::move(tensor); - cout++; - } - } catch (const std::exception &err) { - // Catch any exception and convert to Status return code - RETURN_STATUS_UNEXPECTED("Failed to load json file"); - } - - // RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); - rows_each_buffer++; - rows_total++; - if (rows_each_buffer == rows_per_buffer_) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - - cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - tensor_table = std::make_unique(); - rows_each_buffer = 0; - } - } - - if (rows_each_buffer > 0) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - } - return Status::OK(); -} - -Status ClueOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling IoBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this))); - - RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. 
- TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); - NotifyToFillIOBlockQueue(); - - while (!finished_reading_dataset_) { - int64_t buffer_id = 0; - int32_t workers_done = 0; - int64_t rows_read = 0; - load_io_block_queue_ = true; - - while (workers_done < num_workers_) { - std::unique_ptr buffer; - RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); - if (buffer->eoe()) { - workers_done++; - } else if (num_samples_ == 0 || rows_read < num_samples_) { - if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { - int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); - RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); - } - rows_read += buffer->NumRows(); - buffer->set_id(buffer_id++); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); - } else { - // end of epoch - load_jagged_connector_ = false; - load_io_block_queue_ = false; - } - } - - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - finished_reading_dataset_ = true; - NotifyToFillIOBlockQueue(); - } else { - jagged_buffer_connector_->DoReset(); - buffer_id = 0; - } - } - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - RETURN_IF_NOT_OK(PostEndOfData()); - return Status::OK(); -} - -Status ClueOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - return Status::OK(); -} - -// A print method typically used for debugging -void ClueOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ - << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ - << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") << "\nClue files list:\n"; - for (int i = 0; i < clue_files_list_.size(); ++i) { - out << " " << clue_files_list_[i]; - } - out << "\n\n"; - } -} - -// Pops an element from a queue in io_block_queues -Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); - - return Status::OK(); -} - -// Pushes an element to a queue in io_block_queues -Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); - - return Status::OK(); -} - -static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); -} - -Status ClueOp::WaitToFillIOBlockQueue() { - // must be called first if called by worker spanwed by taskgroup - TaskManager::FindMe()->Post(); - - std::vector i_keys; - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; - while (true) { - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); - io_block_queue_wait_post_.Clear(); - - if (finished_reading_dataset_) { - break; - } - - if (shuffle_files_) { - ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); - } - RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); - } - return Status::OK(); -} - -Status ClueOp::FillIOBlockQueue(const std::vector &i_keys) { - int32_t queue_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - while (!finish) { - std::vector> file_index; - if (!i_keys.empty()) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); - } - } else { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair(it.value(), it.key())); - } - } - for (auto file_info : file_index) { - if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { - auto ioBlock = - std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_info.first]; - } - - if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { - finish = false; - } else { - finish = true; - } - } - - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } - -bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; 
- } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. -Status ClueOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -Status ClueOp::CalculateNumRowsPerShard() { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - int64_t count = CountTotalRows(it.value()); - filename_numrows_[it.value()] = count; - all_num_rows_ += count; - } - if (all_num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API " - "validation first."); - } - - num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); - MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; - return Status::OK(); -} - -int64_t ClueOp::CountTotalRows(const std::string &file) { - std::ifstream handle(file); - if (!handle.is_open()) { - MS_LOG(ERROR) << "Failed to open file: " << file; - return 0; - } - - std::string line; - int64_t count = 0; - while (getline(handle, line)) { - if (!line.empty()) { - count++; - } - } - - return count; -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status ClueOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -Status ClueOp::CountAllFileRows(const std::vector &files, int64_t *count) { - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op)); - for (auto file : files) { - *count += op->CountTotalRows(file); - } - return Status::OK(); -} - -Status ClueOp::ComputeColMap() { - // Set the column name mapping (base class field) - if (column_name_id_map_.empty()) { - int count = 0; - for (auto &p : cols_to_keyword_) { - column_name_id_map_[p.first] = count; - count++; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h deleted file mode 100644 index 487ed0d47f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/util/auto_index.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" - -namespace mindspore { -namespace dataset { -using StringIndex = AutoIndexObj; -using ColKeyMap = std::map>; - -class JaggedConnector; - -class ClueOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - // Create the final object. - // @param op - dataset op. - // @return - the error code return. - Status Build(std::shared_ptr *op); - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetClueFilesList(const std::vector &files_list) { - builder_clue_files_list_ = files_list; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleFiles(bool shuffle_files) { - builder_shuffle_files_ = shuffle_files; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumSamples(int64_t num_samples) { - builder_num_samples_ = num_samples; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. 
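// [Editor's note] An illustrative end-to-end sketch of this builder, added for
// clarity; it is not part of the original header. The file path and column
// mapping are hypothetical:
//
//   std::shared_ptr<ClueOp> op;
//   RETURN_IF_NOT_OK(ClueOp::Builder()
//                      .SetClueFilesList({"/data/afqmc/train.json"})
//                      .SetColsKeyMap({{"sentence1", "sentence1"},
//                                      {"sentence2", "sentence2"},
//                                      {"label_desc", "label/desc"}})
//                      .SetNumWorkers(4)
//                      .Build(&op));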
- Builder &SetColsKeyMap(const std::map &cols_to_key) { - builder_cols_to_keyword_ = cols_to_key; - return *this; - } - - // Split string based on a character delimiter - // @return - the a string vector - std::vector split(const std::string &s, char delim); - - private: - int32_t builder_device_id_; - int32_t builder_num_devices_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_num_samples_; - int32_t builder_worker_connector_size_; - std::vector builder_clue_files_list_; - bool builder_shuffle_files_; - std::map builder_cols_to_keyword_; - }; - - // Constructor of ClueOp - ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, - ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, int32_t num_devices, int32_t device_id); - - // Default destructor - ~ClueOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors - // @return Status - the error code returned - Status Init(); - - // Class functor operator () override. - // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - the error code returned. - Status operator()() override; - - // Overrides base class reset method. Cleans up any state info from it's previous execution - // reinitializes itself so that it can be executed again, as if it was just created. - // @return Status - the error code returned. - Status Reset() override; - - // Get total rows in files. - // @param files - all clue files. - // @param count - number of rows. - // @return Status - the error coed returned. - static Status CountAllFileRows(const std::vector &files, int64_t *count); - - // File names getter - // @return Vector of the input file names - std::vector FileNames() { return clue_files_list_; } - - private: - // The entry point for when workers are launched. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status WorkerEntry(int32_t worker_id) override; - - // Parses a single row and puts the data into a tensor table. - // @param line - the content of the row. - // @param tensor_table - the tensor table to put the parsed data in. - // @param row - the id of the row filled in the tensor table. - // @return Status - the error code returned. - Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); - - // Reads a clue file and loads the data into multiple buffers. - // @param file - the file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id); - - // Pops an element from a queue in IOBlockQueue. - // @param index - the index of the queue to pop from. - // @param out_block - the popped element. - // @return Status - the error code returned. 
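// [Editor's note] Added context, not part of the original header: the IOBlock
// queues form a single-producer / multi-consumer hand-off. One filler thread
// round-robins file spans across the per-worker queues, and each worker blocks
// on its own queue until a data block or an EOE/EOF control block arrives:
//
//   filler:  PushIoBlockQueue(queue_index, std::move(io_block));
//            queue_index = (queue_index + 1) % num_workers_;
//   worker:  PopIoBlockQueue(worker_id, &io_block);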
- Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); - - // Pushes an element to a queue in IOBlockQueue. - // @param index - the index of the queue to push to. - // @param io_block - the element to push onto the queue. - // @return Status - the error code returned. - Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); - - // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. - // @return Status - the error code returned. - Status WaitToFillIOBlockQueue(); - - // Fill the IOBlockQueue. - // @para i_keys - keys of file to fill to the IOBlockQueue - // @return Status - the error code returned. - Status FillIOBlockQueue(const std::vector &i_keys); - - // Notifies the thread which called FillIoBlockQueue to resume execution - void NotifyToFillIOBlockQueue(); - - // Select file and push it to the block queue. - // @param file_name - File name. - // @param start_file - If file contains the first sample of data. - // @param end_file - If file contains the end sample of data. - // @param pre_count - Total rows of previous files. - // @return Status - the error code returned. - bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Calculate number of rows in each shard. - // @return Status - the error code returned. - Status CalculateNumRowsPerShard(); - - // Count number of rows in each file. - // @param filename - clue file name. - // @return int64_t - the total number of rows in file. - int64_t CountTotalRows(const std::string &file); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. - // When the worker pops this control indicator, it will shut itself down gracefully. - // @return Status - the error code returned. - Status PostEndOfData(); - - // @return Status - the error code returned. - Status GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t device_id_; - bool shuffle_files_; - bool finished_reading_dataset_; - int32_t num_devices_; - int64_t rows_per_buffer_; - bool load_io_block_queue_; - int64_t num_rows_per_shard_; - int64_t all_num_rows_; - int64_t num_samples_; - std::map filename_numrows_; - std::unique_ptr filename_index_; - std::vector clue_files_list_; - WaitPost io_block_queue_wait_post_; - std::unique_ptr jagged_buffer_connector_; - QueueList> io_block_queues_; - bool load_jagged_connector_; - ColKeyMap cols_to_keyword_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc deleted file mode 100644 index 7d14163544..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc +++ /dev/null @@ -1,646 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/coco_op.h" - -#include -#include -#include -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -const char kColumnImage[] = "image"; -const char kJsonImages[] = "images"; -const char kJsonImagesFileName[] = "file_name"; -const char kJsonId[] = "id"; -const char kJsonAnnotations[] = "annotations"; -const char kJsonAnnoSegmentation[] = "segmentation"; -const char kJsonAnnoCounts[] = "counts"; -const char kJsonAnnoSegmentsInfo[] = "segments_info"; -const char kJsonAnnoIscrowd[] = "iscrowd"; -const char kJsonAnnoBbox[] = "bbox"; -const char kJsonAnnoArea[] = "area"; -const char kJsonAnnoImageId[] = "image_id"; -const char kJsonAnnoNumKeypoints[] = "num_keypoints"; -const char kJsonAnnoKeypoints[] = "keypoints"; -const char kJsonAnnoCategoryId[] = "category_id"; -const char kJsonCategories[] = "categories"; -const char kJsonCategoriesIsthing[] = "isthing"; -const char kJsonCategoriesName[] = "name"; -const float kDefaultPadValue = -1.0; -const unsigned int kPadValueZero = 0; - -CocoOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); - builder_task_type_ = TaskType::Detection; -} - -Status CocoOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - switch (builder_task_type_) { - case TaskType::Detection: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Stuff: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoSegmentation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Keypoint: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - 
ColDescriptor(std::string(kJsonAnnoKeypoints), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoNumKeypoints), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - case TaskType::Panoptic: - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kJsonAnnoArea), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - break; - default: - RETURN_STATUS_UNEXPECTED("Invalid task type"); - } - *ptr = std::make_shared(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_, - builder_rows_per_buffer_, builder_op_connector_size_, builder_decode_, - std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status CocoOp::Builder::SanityCheck() { - Path dir(builder_dir_); - Path file(builder_file_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "Coco image folder path is invalid or not set\n" : ""; - err_msg += file.Exists() == false ? "Coco annotation json path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, - int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size), - decode_(decode), - row_cnt_(0), - buf_cnt_(0), - task_type_(task_type), - image_folder_path_(image_folder_path), - annotation_path_(annotation_path), - rows_per_buffer_(rows_per_buffer), - sampler_(std::move(sampler)), - data_schema_(std::move(data_schema)) { - io_block_queues_.Init(num_workers_, queue_size); -} - -Status CocoOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) > num_rows_) continue; - keys->push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); - keys->clear(); - } - } - return Status::OK(); -} - -Status CocoOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); - if (sample_ids->type() != DataType(DataType::DE_INT64)) { - RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); - } - RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - 
RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -void CocoOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows: " << num_rows_ << "\nCOCO Directory: " << image_folder_path_ << "\n\n"; - } -} - -Status CocoOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); - return Status::OK(); -} - -Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { - std::shared_ptr image, coordinate; - auto itr = coordinate_map_.find(image_id); - if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::string kImageFile = image_folder_path_ + image_id; - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - - auto bboxRow = itr->second; - std::vector bbox_row; - dsize_t bbox_row_num = static_cast(bboxRow.size()); - dsize_t bbox_column_num = 0; - for (auto bbox : bboxRow) { - if (static_cast(bbox.size()) > bbox_column_num) { - bbox_column_num = static_cast(bbox.size()); - } - } - - for (auto bbox : bboxRow) { - bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); - dsize_t pad_len = bbox_column_num - static_cast(bbox.size()); - if (pad_len > 0) { - for (dsize_t i = 0; i < pad_len; i++) { - bbox_row.push_back(kDefaultPadValue); - } - } - } - - std::vector bbox_dim = {bbox_row_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(1).type(), - reinterpret_cast(&bbox_row[0]))); - if (task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); - } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { - RETURN_IF_NOT_OK(LoadSimpleTensorRow(row_id, image_id, image, coordinate, trow)); - } else if (task_type_ == TaskType::Panoptic) { - RETURN_IF_NOT_OK(LoadMixTensorRow(row_id, image_id, image, coordinate, trow)); - } else { - RETURN_STATUS_UNEXPECTED("Invalid task type."); - } - - return Status::OK(); -} - 
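// [Editor's note] An illustrative, self-contained sketch of the ragged-bbox
// padding that LoadTensorRow() above performs; it is not part of the original
// diff, and FlattenAndPad is a hypothetical name. Annotation rows of different
// lengths are right-padded with kDefaultPadValue (-1.0) so they can be packed
// into a single rectangular [num_boxes, max_len] tensor:
//
//   #include <algorithm>
//   #include <cstddef>
//   #include <vector>
//
//   std::vector<float> FlattenAndPad(const std::vector<std::vector<float>> &rows) {
//     std::size_t max_len = 0;
//     for (const auto &r : rows) max_len = std::max(max_len, r.size());
//     std::vector<float> flat;
//     for (const auto &r : rows) {
//       flat.insert(flat.end(), r.begin(), r.end());
//       flat.insert(flat.end(), max_len - r.size(), -1.0f);  // kDefaultPadValue
//     }
//     return flat;  // e.g. {{10,20,30,40}, {5,6}} -> {10,20,30,40, 5,6,-1,-1}
//   }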
-// When task is Detection, user can get data with four columns: -// column ["image"] with datatype=uint8 -// column ["bbox"] with datatype=float32 -// column ["category_id"] with datatype=uint32 -// column ["iscrowd"] with datatype=uint32 -// By the way, column ["iscrowd"] is used for some testcases, like fasterRcnn. -// If "iscrowd" is not existed, user will get default value 0. -Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr category_id, iscrowd; - std::vector category_id_row; - std::vector iscrowd_row; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::vector annotation = itr_item->second; - for (int64_t i = 0; i < annotation.size(); i++) { - if (i % 2 == 0) { - category_id_row.push_back(annotation[i]); - } else if (i % 2 == 1) { - iscrowd_row.push_back(annotation[i]); - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); - return Status::OK(); -} - -// When task is "Stuff"/"Keypoint", user can get data with three columns: -// column ["image"] with datatype=uint8 -// column ["segmentation"]/["keypoints"] with datatype=float32 -// column ["iscrowd"]/["num_keypoints"] with datatype=uint32 -Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr item; - std::vector item_queue; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - item_queue = itr_item->second; - std::vector bbox_dim = {static_cast(item_queue.size()), 1}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), - data_schema_->column(2).type(), - reinterpret_cast(&item_queue[0]))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); - return Status::OK(); -} - -// When task is "Panoptic", user can get data with five columns: -// column ["image"] with datatype=uint8 -// column ["bbox"] with datatype=float32 -// column ["category_id"] with datatype=uint32 -// column ["iscrowd"] with datatype=uint32 -// column ["area"] with datattype=uint32 -Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, - std::shared_ptr coordinate, TensorRow *trow) { - std::shared_ptr category_id, iscrowd, area; - std::vector category_id_row; - std::vector iscrowd_row; - std::vector area_row; - auto itr_item = simple_item_map_.find(image_id); - if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); - - std::vector annotation = itr_item->second; - for (int64_t i = 0; i < annotation.size(); i++) { - if (i % 3 == 0) { - category_id_row.push_back(annotation[i]); - } else if (i % 3 == 1) { - 
iscrowd_row.push_back(annotation[i]); - } else if (i % 3 == 2) { - area_row.push_back(annotation[i]); - } - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), - data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), - data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); - - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), - data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); - (*trow) = TensorRow( - row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); - return Status::OK(); -} - -Status CocoOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const int64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status CocoOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -template -Status CocoOp::SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node) { - auto node = input_tree.find(node_name); - if (node == input_tree.end()) RETURN_STATUS_UNEXPECTED("Invalid node found in json : " + node_name); - (*output_node) = *node; - return Status::OK(); -} - -Status CocoOp::ParseAnnotationIds() { - std::ifstream in(annotation_path_); - nlohmann::json js; - in >> js; - - std::vector image_que; - nlohmann::json image_list; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonImages), &image_list)); - RETURN_IF_NOT_OK(ImageColumnLoad(image_list, &image_que)); - if (task_type_ == TaskType::Detection || task_type_ == TaskType::Panoptic) { - nlohmann::json node_categories; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonCategories), &node_categories)); - RETURN_IF_NOT_OK(CategoriesColumnLoad(node_categories)); - } - nlohmann::json annotations_list; - RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonAnnotations), &annotations_list)); - for (auto annotation : annotations_list) { - int32_t image_id = 0, id = 0; - std::string file_name; - RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonAnnoImageId), &image_id)); - auto itr_file = 
-Status CocoOp::ParseAnnotationIds() {
-  std::ifstream in(annotation_path_);
-  nlohmann::json js;
-  in >> js;
-
-  std::vector<std::string> image_que;
-  nlohmann::json image_list;
-  RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonImages), &image_list));
-  RETURN_IF_NOT_OK(ImageColumnLoad(image_list, &image_que));
-  if (task_type_ == TaskType::Detection || task_type_ == TaskType::Panoptic) {
-    nlohmann::json node_categories;
-    RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonCategories), &node_categories));
-    RETURN_IF_NOT_OK(CategoriesColumnLoad(node_categories));
-  }
-  nlohmann::json annotations_list;
-  RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonAnnotations), &annotations_list));
-  for (auto annotation : annotations_list) {
-    int32_t image_id = 0, id = 0;
-    std::string file_name;
-    RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonAnnoImageId), &image_id));
-    auto itr_file = image_index_.find(image_id);
-    if (itr_file == image_index_.end())
-      RETURN_STATUS_UNEXPECTED("Invalid image id of annotations : " + std::to_string(image_id));
-    file_name = itr_file->second;
-    switch (task_type_) {
-      case TaskType::Detection:
-        RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id));
-        RETURN_IF_NOT_OK(DetectionColumnLoad(annotation, file_name, id));
-        break;
-      case TaskType::Stuff:
-        RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id));
-        RETURN_IF_NOT_OK(StuffColumnLoad(annotation, file_name, id));
-        break;
-      case TaskType::Keypoint:
-        RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id));
-        RETURN_IF_NOT_OK(KeypointColumnLoad(annotation, file_name, id));
-        break;
-      case TaskType::Panoptic:
-        RETURN_IF_NOT_OK(PanopticColumnLoad(annotation, file_name, image_id));
-        break;
-      default:
-        RETURN_STATUS_UNEXPECTED("Invalid task type");
-    }
-  }
-  for (auto img : image_que) {
-    if (coordinate_map_.find(img) != coordinate_map_.end()) image_ids_.push_back(img);
-  }
-  num_rows_ = image_ids_.size();
-  return Status::OK();
-}
-
-Status CocoOp::ImageColumnLoad(nlohmann::json image_tree, std::vector<std::string> *image_vec) {
-  if (image_tree.size() == 0) {
-    RETURN_STATUS_UNEXPECTED("No images found in " + annotation_path_);
-  }
-  for (auto img : image_tree) {
-    std::string file_name;
-    int32_t id = 0;
-    RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonImagesFileName), &file_name));
-    RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonId), &id));
-
-    image_index_[id] = file_name;
-    image_vec->push_back(file_name);
-  }
-  return Status::OK();
-}
-
-Status CocoOp::DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file,
-                                   const int32_t &unique_id) {
-  std::vector<float> bbox;
-  nlohmann::json node_bbox;
-  uint32_t category_id = 0, iscrowd = 0;
-  RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoBbox), &node_bbox));
-  RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoCategoryId), &category_id));
-  auto search_category = category_set_.find(category_id);
-  if (search_category == category_set_.end())
-    RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + std::to_string(category_id));
-  auto node_iscrowd = annotation_tree.find(kJsonAnnoIscrowd);
-  if (node_iscrowd != annotation_tree.end()) iscrowd = *node_iscrowd;
-  bbox.insert(bbox.end(), node_bbox.begin(), node_bbox.end());
-  coordinate_map_[image_file].push_back(bbox);
-  simple_item_map_[image_file].push_back(category_id);
-  simple_item_map_[image_file].push_back(iscrowd);
-  return Status::OK();
-}
-
-Status CocoOp::StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file,
-                               const int32_t &unique_id) {
-  uint32_t iscrowd = 0;
-  std::vector<float> bbox;
-  RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoIscrowd), &iscrowd));
-  simple_item_map_[image_file].push_back(iscrowd);
-  nlohmann::json segmentation;
-  RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoSegmentation), &segmentation));
-  if (iscrowd == 0) {
-    for (auto item : segmentation) {
-      if (bbox.size() > 0) bbox.clear();
-      bbox.insert(bbox.end(), item.begin(), item.end());
-      coordinate_map_[image_file].push_back(bbox);
-    }
-  } else if (iscrowd == 1) {
-    nlohmann::json segmentation_count;
-    RETURN_IF_NOT_OK(SearchNodeInJson(segmentation, std::string(kJsonAnnoCounts), &segmentation_count));
-    bbox.insert(bbox.end(), segmentation_count.begin(), segmentation_count.end());
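    // Illustrative note (not in the original source): for crowd (iscrowd == 1)
    // annotations COCO stores the mask run-length encoded, so "counts" is a flat list
    // of RLE run lengths, e.g. {"counts": [272, 2, 4, ...], "size": [h, w]}.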
-    coordinate_map_[image_file].push_back(bbox);
-  }
-  return Status::OK();
-}
-
-Status CocoOp::KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file,
-                                  const int32_t &unique_id) {
-  auto itr_num_keypoint = annotation_tree.find(kJsonAnnoNumKeypoints);
-  if (itr_num_keypoint == annotation_tree.end())
-    RETURN_STATUS_UNEXPECTED("No num_keypoint found in annotations where id: " + std::to_string(unique_id));
-  simple_item_map_[image_file].push_back(*itr_num_keypoint);
-  auto itr_keypoint = annotation_tree.find(kJsonAnnoKeypoints);
-  if (itr_keypoint == annotation_tree.end())
-    RETURN_STATUS_UNEXPECTED("No keypoint found in annotations where id: " + std::to_string(unique_id));
-  coordinate_map_[image_file].push_back(*itr_keypoint);
-  return Status::OK();
-}
-
-Status CocoOp::PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file,
-                                  const int32_t &image_id) {
-  auto itr_segments = annotation_tree.find(kJsonAnnoSegmentsInfo);
-  if (itr_segments == annotation_tree.end())
-    RETURN_STATUS_UNEXPECTED("No segments_info found in annotations where image_id: " + std::to_string(image_id));
-  for (auto info : *itr_segments) {
-    std::vector<float> bbox;
-    uint32_t category_id = 0;
-    auto itr_bbox = info.find(kJsonAnnoBbox);
-    if (itr_bbox == info.end())
-      RETURN_STATUS_UNEXPECTED("No bbox found in segments_info where image_id: " + std::to_string(image_id));
-    bbox.insert(bbox.end(), itr_bbox->begin(), itr_bbox->end());
-    coordinate_map_[image_file].push_back(bbox);
-
-    RETURN_IF_NOT_OK(SearchNodeInJson(info, std::string(kJsonAnnoCategoryId), &category_id));
-    auto search_category = category_set_.find(category_id);
-    if (search_category == category_set_.end())
-      RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " +
-                               std::to_string(category_id));
-    auto itr_iscrowd = info.find(kJsonAnnoIscrowd);
-    if (itr_iscrowd == info.end())
-      RETURN_STATUS_UNEXPECTED("No iscrowd found in segments_info where image_id: " + std::to_string(image_id));
-    auto itr_area = info.find(kJsonAnnoArea);
-    if (itr_area == info.end())
-      RETURN_STATUS_UNEXPECTED("No area found in segments_info where image_id: " + std::to_string(image_id));
-    simple_item_map_[image_file].push_back(category_id);
-    simple_item_map_[image_file].push_back(*itr_iscrowd);
-    simple_item_map_[image_file].push_back(*itr_area);
-  }
-  return Status::OK();
-}
-
-Status CocoOp::CategoriesColumnLoad(nlohmann::json categories_tree) {
-  if (categories_tree.size() == 0) RETURN_STATUS_UNEXPECTED("No categories found in " + annotation_path_);
-  for (auto category : categories_tree) {
-    int32_t id = 0;
-    std::string name;
-    std::vector<int32_t> label_info;
-    auto itr_id = category.find(kJsonId);
-    if (itr_id == category.end()) RETURN_STATUS_UNEXPECTED("No id found in categories of " + annotation_path_);
-    id = *itr_id;
-    label_info.push_back(id);
-    category_set_.insert(id);
-
-    auto itr_name = category.find(kJsonCategoriesName);
-    if (itr_name == category.end())
-      RETURN_STATUS_UNEXPECTED("No name found in categories where id: " + std::to_string(id));
-    name = *itr_name;
-
-    if (task_type_ == TaskType::Panoptic) {
-      auto itr_isthing = category.find(kJsonCategoriesIsthing);
-      if (itr_isthing == category.end())
-        RETURN_STATUS_UNEXPECTED("No isthing found in categories of " + annotation_path_);
-      label_info.push_back(*itr_isthing);
-    }
-    label_index_.emplace_back(std::make_pair(name, label_info));
-  }
-  return Status::OK();
-}
-
-Status CocoOp::InitSampler() {
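  // Illustrative note (not in the original source): the handshake registers this op
  // with the sampler as its RandomAccessOp, letting the sampler read num_rows_
  // (computed in ParseAnnotationIds) before it starts producing sample ids.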
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-  return Status::OK();
-}
-
-Status CocoOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(this->ParseAnnotationIds());
-  RETURN_IF_NOT_OK(this->InitSampler());
-  return Status::OK();
-}
-
-Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor) {
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path));
-
-  if (decode_ == true) {
-    Status rc = Decode(*tensor, tensor);
-    if (rc.IsError()) {
-      RETURN_STATUS_UNEXPECTED("fail to decode file: " + path);
-    }
-  }
-  return Status::OK();
-}
-
-Status CocoOp::CountTotalRows(const std::string &dir, const std::string &file, const std::string &task,
-                              int64_t *count) {
-  std::shared_ptr<CocoOp> op;
-  RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op));
-  RETURN_IF_NOT_OK(op->ParseAnnotationIds());
-  *count = static_cast<int64_t>(op->image_ids_.size());
-  return Status::OK();
-}
-
-Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, const std::string &task,
-                                std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing) {
-  std::shared_ptr<CocoOp> op;
-  RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op));
-  RETURN_IF_NOT_OK(op->ParseAnnotationIds());
-  *output_class_indexing = op->label_index_;
-  return Status::OK();
-}
-
-// Visitor accept method for NodePass
-Status CocoOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<CocoOp>(), modified);
-}
-
-Status CocoOp::ComputeColMap() {
-  // Set the column name map (base class field)
-  if (column_name_id_map_.empty()) {
-    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
-      column_name_id_map_[data_schema_->column(i).name()] = i;
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
deleted file mode 100644
index 2a93d26195..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
+++ /dev/null
@@ -1,340 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
-
-#include <map>
-#include <memory>
-#include <set>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/parallel_op.h"
-#include "dataset/engine/datasetops/source/io_block.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/path.h"
-#include "dataset/util/queue.h"
-#include "dataset/util/status.h"
-#include "dataset/util/wait_post.h"
-
-namespace mindspore {
-namespace dataset {
-// Forward declares
-template <typename T>
-class Queue;
-
-using CoordinateRow = std::vector<std::vector<float>>;
-
-class CocoOp : public ParallelOp, public RandomAccessOp {
- public:
-  enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3 };
-
-  class Builder {
-   public:
-    // Constructor for Builder class of CocoOp
-    Builder();
-
-    // Destructor.
-    ~Builder() = default;
-
-    // Setter method.
-    // @param const std::string &build_dir
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetDir(const std::string &build_dir) {
-      builder_dir_ = build_dir;
-      return *this;
-    }
-
-    // Setter method.
-    // @param const std::string &build_file
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetFile(const std::string &build_file) {
-      builder_file_ = build_file;
-      return *this;
-    }
-
-    // Setter method.
-    // @param const std::string &task_type
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetTask(const std::string &task_type) {
-      if (task_type == "Detection") {
-        builder_task_type_ = TaskType::Detection;
-      } else if (task_type == "Stuff") {
-        builder_task_type_ = TaskType::Stuff;
-      } else if (task_type == "Panoptic") {
-        builder_task_type_ = TaskType::Panoptic;
-      } else if (task_type == "Keypoint") {
-        builder_task_type_ = TaskType::Keypoint;
-      }
-      return *this;
-    }
-
-    // Setter method.
-    // @param int32_t num_workers
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumWorkers(int32_t num_workers) {
-      builder_num_workers_ = num_workers;
-      return *this;
-    }
-
-    // Setter method.
-    // @param int32_t op_connector_size
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOpConnectorSize(int32_t op_connector_size) {
-      builder_op_connector_size_ = op_connector_size;
-      return *this;
-    }
-
-    // Setter method.
-    // @param int32_t rows_per_buffer
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-      builder_rows_per_buffer_ = rows_per_buffer;
-      return *this;
-    }
-
-    // Setter method.
-    // @param std::shared_ptr<Sampler> sampler
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
-      builder_sampler_ = std::move(sampler);
-      return *this;
-    }
-
-    // Setter method.
-    // @param bool do_decode
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetDecode(bool do_decode) {
-      builder_decode_ = do_decode;
-      return *this;
-    }
-
-    // Check validity of input args
-    // @return - The error code return
-    Status SanityCheck();
-
-    // The builder "Build" method creates the final object.
-    // @param std::shared_ptr<CocoOp> *op - DatasetOp
-    // @return - The error code return
-    Status Build(std::shared_ptr<CocoOp> *op);
-
-   private:
-    bool builder_decode_;
-    std::string builder_dir_;
-    std::string builder_file_;
-    TaskType builder_task_type_;
-    int32_t builder_num_workers_;
-    int32_t builder_op_connector_size_;
-    int32_t builder_rows_per_buffer_;
-    std::shared_ptr<Sampler> builder_sampler_;
-    std::unique_ptr<DataSchema> builder_schema_;
-  };
-
-  // Constructor
-  // @param TaskType task_type - task type of Coco
-  // @param std::string image_folder_path - image folder path of Coco
-  // @param std::string annotation_path - annotation json path of Coco
-  // @param int32_t num_workers - number of workers reading images in parallel
-  // @param int32_t rows_per_buffer - number of images (rows) in each buffer
-  // @param int32_t queue_size - connector queue size
-  // @param bool decode - whether to decode images
-  // @param std::unique_ptr<DataSchema> data_schema - the schema of the Coco dataset
-  // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
-  CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
-         int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
-         std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
-
-  // Destructor
-  ~CocoOp() = default;
-
-  // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
-  // @param int32_t workerId - id of each worker
-  // @return Status - The error code return
-  Status WorkerEntry(int32_t worker_id) override;
-
-  // Main Loop of CocoOp
-  // Master thread: Fill IOBlockQueue, then goes to sleep
-  // Worker thread: pulls IOBlock from IOBlockQueue, works on it then puts buffer to mOutConnector
-  // @return Status - The error code return
-  Status operator()() override;
-
-  // A print method typically used for debugging
-  // @param out
-  // @param show_all
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // @param const std::string &dir - Coco image dir path
-  // @param const std::string &file - Coco json file path
-  // @param const std::string &task - task mode of Coco task
-  // @param int64_t *count - output rows number of CocoDataset
-  static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                               int64_t *count);
-
-  // @param const std::string &dir - Coco image dir path
-  // @param const std::string &file - Coco json file path
-  // @param const std::string &task - task mode of Coco task
-  // @param std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing - class index of dataset
-  static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
-                                 std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
-
-  /// \brief Base-class override for NodePass visitor acceptor
-  /// \param[in] p Pointer to the NodePass to be accepted
-  /// \param[out] modified Indicator if the node was changed at all
-  /// \return Status of the node visit
-  Status Accept(NodePass *p, bool *modified) override;
-
- private:
-  // Initialize Sampler, calls sampler->Init() within
-  // @return Status - The error code return
-  Status InitSampler();
-
-  // Load a tensor row according to image id
-  // @param row_id_type row_id - id for this tensor row
-  // @param std::string image_id - image id
-  // @param TensorRow *row - image & target read into this tensor row
-  // @return Status - The error code return
-  Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
-
-  // Load a tensor row by converting a vector of coordinates to a tensor
-  // @param row_id_type row_id - id for this tensor row
-  // @param const std::string &image_id - image id
-  // @param std::shared_ptr<Tensor> image - image tensor
-  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
-  // @param TensorRow *trow - image & target read into this tensor row
-  // @return Status - The error code return
-  Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
-                                std::shared_ptr<Tensor> coordinate, TensorRow *trow);
-
-  // Load a tensor row by converting a vector to a tensor
-  // @param row_id_type row_id - id for this tensor row
-  // @param const std::string &image_id - image id
-  // @param std::shared_ptr<Tensor> image - image tensor
-  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
-  // @param TensorRow *trow - image & target read into this tensor row
-  // @return Status - The error code return
-  Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
-                             std::shared_ptr<Tensor> coordinate, TensorRow *trow);
-
-  // Load a tensor row by converting a vector to multiple tensors
-  // @param row_id_type row_id - id for this tensor row
-  // @param const std::string &image_id - image id
-  // @param std::shared_ptr<Tensor> image - image tensor
-  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
-  // @param TensorRow *trow - image & target read into this tensor row
-  // @return Status - The error code return
-  Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
-                          std::shared_ptr<Tensor> coordinate, TensorRow *trow);
-
-  // @param const std::string &path - path to the image file
-  // @param const ColDescriptor &col - contains tensor implementation and datatype
-  // @param std::shared_ptr<Tensor> *tensor - return
-  // @return Status - The error code return
-  Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
-
-  // @param const std::vector<int64_t> &keys - keys in ioblock
-  // @param std::unique_ptr<DataBuffer> *db
-  // @return Status - The error code return
-  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
-
-  // Read annotation from Annotation folder
-  // @return Status - The error code return
-  Status ParseAnnotationIds();
-
-  // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
-  // @param std::vector<int64_t> *keys - image id
-  // @return Status - The error code return
-  Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
-
-  // Called first when the functor is invoked
-  // @return Status - The error code return
-  Status LaunchThreadsAndInitOp();
-
-  // Reset dataset state
-  // @return Status - The error code return
-  Status Reset() override;
-
-  // @param nlohmann::json image_tree - image tree of json
-  // @param std::vector<std::string> *image_vec - image id list of json
-  // @return Status - The error code return
-  Status ImageColumnLoad(nlohmann::json image_tree, std::vector<std::string> *image_vec);
-
-  // @param nlohmann::json categories_tree - categories tree of json
-  // @return Status - The error code return
-  Status CategoriesColumnLoad(nlohmann::json categories_tree);
-
-  // @param nlohmann::json annotation_tree - annotation tree of json
-  // @param const std::string &image_file - current image name in annotation
-  // @param const int32_t &id - current unique id of annotation
-  // @return Status - The error code return
-  Status DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id);
-
-  // @param nlohmann::json annotation_tree - annotation tree of json
-  // @param const std::string &image_file - current image name in annotation
-  // @param const int32_t &id - current unique id of annotation
-  // @return Status - The error code return
-  Status StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id);
-
-  // @param nlohmann::json annotation_tree - annotation tree of json
-  // @param const std::string &image_file - current image name in annotation
-  // @param const int32_t &id - current unique id of annotation
-  // @return Status - The error code return
-  Status KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id);
-
-  // @param nlohmann::json annotation_tree - annotation tree of json
-  // @param const std::string &image_file - current image name in annotation
-  // @param const int32_t &image_id - current image id of annotation
-  // @return Status - The error code return
-  Status PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &image_id);
-
-  template <typename T>
-  Status SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node);
-
-  // Private function for computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-
-  bool decode_;
-  int64_t row_cnt_;
-  int64_t buf_cnt_;
-  std::string image_folder_path_;
-  std::string annotation_path_;
-  TaskType task_type_;
-  int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
-  std::unique_ptr<DataSchema> data_schema_;
-
-  WaitPost wp_;
-  std::vector<std::string> image_ids_;
-  std::map<int32_t, std::string> image_index_;
-  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
-  std::vector<std::pair<std::string, std::vector<int32_t>>> label_index_;
-  std::map<std::string, CoordinateRow> coordinate_map_;
-  std::map<std::string, std::vector<uint32_t>> simple_item_map_;
-  std::set<uint32_t> category_set_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc
deleted file mode 100644
index 36c221fc16..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.cc
+++ /dev/null
@@ -1,267 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/generator_op.h"
-#include <iomanip>
-#include "dataset/core/global_context.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/util/task_manager.h"
-#include "dataset/engine/opt/pass.h"
-
-namespace mindspore {
-namespace dataset {
-GeneratorOp::Builder::Builder() {
-  // Some arguments to the GeneratorOp constructor have a default argument that is taken
-  // from the client config.
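  // Illustrative note (not in the original source): these defaults mirror the client
  // config, so a Builder used without explicit setters still yields a working op;
  // SanityCheck() below may then enlarge the connector to fit the prefetch size.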
-  build_buffer_size_ = kCfgRowsPerBuffer;
-  build_op_connector_size_ = kCfgOpConnectorSize;
-}
-
-Status GeneratorOp::Builder::SanityCheck() {
-  // Update queue size to fit the prefetch requirement
-  MS_LOG(DEBUG) << "Generator operator sanity check, prefetch size is " << build_prefetch_size_ << ".";
-  if (build_prefetch_size_ > 0) {
-    build_op_connector_size_ = (build_prefetch_size_ + build_buffer_size_ - 1) / build_buffer_size_;
-  }
-  return Status::OK();
-}
-
-Status GeneratorOp::Builder::Build(std::shared_ptr<GeneratorOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<GeneratorOp>(build_generator_function_, build_column_names_, build_column_types_,
-                                       build_prefetch_size_, build_buffer_size_, build_op_connector_size_);
-  return (*ptr)->Init();
-}
-
-GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
-                         std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size,
-                         int32_t connector_size)
-    : PipelineOp(connector_size),
-      generator_function_(generator_function),
-      column_names_(column_names),
-      column_types_(column_types),
-      prefetch_size_(prefetch_size),
-      buffer_size_(buffer_size),
-      buffer_id_(0) {}
-
-GeneratorOp::~GeneratorOp() { this->Dealloc(); }
-
-void GeneratorOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as first line regardless if this summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") <GeneratorOp>:";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nColumn names:\n";
-    for (int i = 0; i < column_names_.size(); ++i) {
-      out << "\n  " << column_names_[i];
-    }
-    out << "\n\n";
-  }
-}
-
-void GeneratorOp::Dealloc() noexcept {
-  // Setup GIL state
-  PyGILState_STATE gstate;
-  gstate = PyGILState_Ensure();
-  // GC the generator object within GIL
-  (void)generator_.dec_ref();
-  // Release GIL
-  PyGILState_Release(gstate);
-}
-
-// Reentrant init method.
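// Illustrative note (not in the original source): Init() is invoked both from
// Builder::Build() and again from Reset() at each new epoch, so it must be safe to
// run repeatedly. The wrapped callable is expected to return a fresh generator whose
// items are tuples of numpy arrays, e.g. (hypothetical Python):
//   def gen():
//       for i in range(3):
//           yield (np.array([i], dtype=np.int64),)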
-Status GeneratorOp::Init() {
-  // Reset BufferID
-  buffer_id_ = 0;
-  Status ret;
-  {
-    // Acquire Python GIL
-    py::gil_scoped_acquire gil_acquire;
-    if (Py_IsInitialized() == 0) {
-      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
-    }
-    // Invoke the generatorFunction to get generator object
-    try {
-      generator_ = generator_function_();
-    } catch (const py::error_already_set &e) {
-      ret = Status(StatusCode::kPyFuncException, e.what());
-    }
-  }
-  return ret;
-}
-
-Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) {
-  if (!py::isinstance<py::tuple>(py_data)) {
-    return Status(StatusCode::kPyFuncException, __LINE__, __FILE__,
-                  "Generator should return a tuple of numpy arrays.");
-  }
-  py::tuple py_row = py_data.cast<py::tuple>();
-  // Check if returned number of columns matches with column names
-  if (py_row.size() != column_names_.size()) {
-    return Status(StatusCode::kPyFuncException, __LINE__, __FILE__,
-                  "Generator should return same number of numpy arrays as specified in column names.");
-  }
-  // Iterate over two containers simultaneously for memory copy
-  for (int i = 0; i < py_row.size(); ++i) {
-    py::object ret_py_ele = py_row[i];
-    if (!py::isinstance<py::array>(ret_py_ele)) {
-      return Status(StatusCode::kPyFuncException, __LINE__, __FILE__,
-                    "Generator should return a tuple of numpy arrays.");
-    }
-    std::shared_ptr<Tensor> tensor;
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast<py::array>()));
-    if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) &&
-        (column_types_[i] != tensor->type())) {
-      return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed.");
-    }
-    tensor_row->push_back(tensor);
-  }
-  return Status(StatusCode::kOK, "");
-}
-
-Status GeneratorOp::FillBuffer(TensorQTable *tt) {
-  for (int i = 0; i < buffer_size_; i++) {
-    TensorRow row;
-    RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row));
-    tt->push_back(std::move(row));
-  }
-  return Status::OK();
-}
-
-// Entry point for Generator, called by launch()
-// Note that this function is very easy to break because of the Python GIL mechanism
-// The master thread has the following workflow
-//
-// while !eof:
-//     Try:
-//         Prepare one data buffer                          GIL, Can throw
-//     Catch:
-//         Fetch Python Exception                           GIL
-//         Check if Exception is StopIteration (EOE)        GIL
-//         Restore Python Exception                         GIL
-//         If not StopIteration:
-//             Return Status PyFuncException
-//
-//     Push data buffer to connector                        Block
-//
-//     if EOE
-//         Push EOE                                         Block
-//         if more epoch:
-//             Block until next epoch                       Block
-//         else:
-//             Push EOF                                     Block
-//             eof = true
-//     Return Status OK
-//
-// Note that any modification of this function need to guarantee:
-// 1. All "Require GIL" operations are protected by GIL
-//    SegFault / Deadlock will occur if this condition is not fulfilled.
-// 2. All "Block" operations are free from GIL, all block targets are registered with tree.
-//    Deadlock will occur if this condition is not fulfilled
-// 3. No Python GC should be triggered outside of GIL.
-//    SegFault will occur if this condition is not fulfilled
-//
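// Illustrative note (not in the original source): the "Block until next epoch" step is
// implemented with wp_ (a WaitPost); Reset(), called by the repeat op, re-runs Init()
// to obtain a fresh Python generator and then Sets wp_ to wake this master thread.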
-Status GeneratorOp::operator()() {
-  // Handshake with TaskManager to synchronize thread creation
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  std::unique_ptr<DataBuffer> fetched_buffer;
-  bool eof = false;
-  while (!eof) {
-    // Create new buffer each iteration
-    fetched_buffer = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
-    std::unique_ptr<TensorQTable> fetched_table = std::make_unique<TensorQTable>();
-    bool eoe = false;
-    {
-      py::gil_scoped_acquire gil_acquire;
-      if (Py_IsInitialized() == 0) {
-        return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
-      }
-      try {
-        RETURN_IF_NOT_OK(FillBuffer(fetched_table.get()));
-      } catch (py::error_already_set &e) {
-        eoe = e.matches(PyExc_StopIteration);
-        // Restore exception to python
-        e.restore();
-        // Pop up non StopIteration Python Exception
-        if (!eoe) {
-          return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what());
-        }
-      }
-    }
-    if (fetched_table->size() > 0) {
-      fetched_buffer->set_tensor_table(std::move(fetched_table));
-      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
-    }
-    if (eoe) {
-      // Push out EOE upon StopIteration exception from generator
-      MS_LOG(DEBUG) << "Generator operator sends out EOE.";
-      std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
-      if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-        // If last repeat or not repeated, push out EOF and exit master loop
-        MS_LOG(DEBUG) << "Generator operator sends out EOF.";
-        std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
-        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
-        MS_LOG(DEBUG) << "Generator operator main execution loop complete.";
-        eof = true;
-      } else {
-        // Waiting for repeatOp to start new epoch
-        // If Reset() is called first by repeat op, this wait() will return right away.
-        // If Reset() is not called yet, this wait() will block until reset.
-        RETURN_IF_NOT_OK(wp_.Wait());
-        // Clear the status of the wait post
-        wp_.Clear();
-      }
-    }
-  }
-  return Status::OK();
-}
-
-Status GeneratorOp::Reset() {
-  // Reset Op state
-  RETURN_IF_NOT_OK(this->Init());
-  // Wake up master thread
-  wp_.Set();
-  return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
-}
-
-// Visitor accept method for NodePass
-Status GeneratorOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<GeneratorOp>(), modified);
-}
-
-Status GeneratorOp::ComputeColMap() {
-  // Setup column names map (base class field)
-  if (column_name_id_map_.empty()) {
-    for (int i = 0; i < column_names_.size(); ++i) {
-      column_name_id_map_[column_names_[i]] = i;
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h
deleted file mode 100644
index 98dd2d70a1..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/generator_op.h
+++ /dev/null
@@ -1,163 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-#include "dataset/core/data_type.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/pipeline_op.h"
-#include "dataset/util/wait_post.h"
-
-namespace mindspore {
-namespace dataset {
-#pragma GCC visibility push(hidden)
-
-class GeneratorOp : public PipelineOp {
- public:
-  class Builder {
-   public:
-    // Builder constructor. Creates the builder object.
-    // @note No default args
-    // @return This is a constructor.
-    Builder();
-
-    ~Builder() = default;
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetGeneratorFunction(py::function generator_function) {
-      build_generator_function_ = generator_function;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetColumnNames(const std::vector<std::string> &column_names) {
-      build_column_names_ = column_names;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetColumnTypes(const std::vector<DataType> &column_types) {
-      build_column_types_ = column_types;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetPrefetchSize(int32_t prefetch_size) {
-      build_prefetch_size_ = prefetch_size;
-      return *this;
-    }
-
-    // The builder "build" method creates the final object.
-    // @return shared_ptr to the new GeneratorOp object
-    Status Build(std::shared_ptr<GeneratorOp> *);
-
-   private:
-    // The builder saves all GeneratorOp construction arguments internally.
-    // The following are the arguments.
-    py::function build_generator_function_;
-    std::vector<std::string> build_column_names_;
-    std::vector<DataType> build_column_types_;
-
-    int32_t build_prefetch_size_ = 0;
-    int32_t build_buffer_size_;
-    int32_t build_op_connector_size_;
-
-    Status SanityCheck();
-  };
-
-  GeneratorOp(py::function generator_function, std::vector<std::string> column_names,
-              std::vector<DataType> column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size);
-
-  ~GeneratorOp();
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param generator_op - reference to the GeneratorOp to display
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) {
-    generator_op.Print(out, false);
-    return out;
-  }
-
-  // Class functor operator () override.
-  // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
-  // provide the master loop that drives the logic for performing the work.
-  // @return Status - The error code return
-  Status operator()() override;
-
-  // Overrides base class reset method. When an operator does a reset, it cleans up any state
-  // info from its previous execution and then initializes itself so that it can be executed
-  // again.
-  // @return Status - The error code return
-  Status Reset() override;
-
-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
-  Status Accept(NodePass *p, bool *modified) override;
-
-  // Op name getter
-  // @return Name of the current Op
-  std::string Name() const override { return "GeneratorOp"; }
-
- private:
-  py::function generator_function_;
-  std::vector<std::string> column_names_;
-  std::vector<DataType> column_types_;
-  int32_t prefetch_size_;
-  int32_t buffer_size_;
-
-  py::object generator_;
-  int32_t buffer_id_;
-
-  WaitPost wp_;
-
-  Status Init();
-
-  void Dealloc() noexcept;
-
-  Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row);
-
-  Status FillBuffer(TensorQTable *tt);
-
-  // Private function for computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-};
-
-#pragma GCC visibility pop
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
deleted file mode 100644
index 837eae1e3c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.cc
+++ /dev/null
@@ -1,429 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/image_folder_op.h"
-#include <algorithm>
-#include <iomanip>
-#include "common/utils.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/tensor_shape.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/db_connector.h"
-#include "dataset/engine/execution_tree.h"
-#include "dataset/engine/opt/pass.h"
-
-namespace mindspore {
-namespace dataset {
-ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) {
-  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  builder_num_workers_ = cfg->num_parallel_workers();
-  builder_rows_per_buffer_ = cfg->rows_per_buffer();
-  builder_op_connector_size_ = cfg->op_connector_size();
-}
-
-Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) {
-  RETURN_IF_NOT_OK(SanityCheck());
-  if (builder_sampler_ == nullptr) {
-    const int64_t num_samples = 0;  // default num samples of 0 means to sample entire set of data
-    const int64_t start_index = 0;
-    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
-  }
-  builder_schema_ = std::make_unique<DataSchema>();
-  TensorShape scalar = TensorShape::CreateScalar();
-  RETURN_IF_NOT_OK(
-    builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
-  RETURN_IF_NOT_OK(builder_schema_->AddColumn(
-    ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
-  *ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                         builder_op_connector_size_, builder_recursive_, builder_decode_,
-                                         builder_extensions_, builder_labels_to_read_, std::move(builder_schema_),
-                                         std::move(builder_sampler_));
-  return Status::OK();
-}
-
-Status ImageFolderOp::Builder::SanityCheck() {
-  Path dir(builder_dir_);
-  std::string err_msg;
-  err_msg += dir.IsDirectory() == false ? "ImageFolder path is invalid or not set\n" : "";
-  err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0\n" : "";
-  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
-}
-
-ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size,
-                             bool recursive, bool do_decode, const std::set<std::string> &exts,
-                             const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema,
-                             std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
-      rows_per_buffer_(rows_per_buffer),
-      folder_path_(file_dir),
-      recursive_(recursive),
-      decode_(do_decode),
-      extensions_(exts),
-      class_index_(map),
-      data_schema_(std::move(data_schema)),
-      row_cnt_(0),
-      buf_cnt_(0),
-      sampler_ind_(0),
-      dirname_offset_(0) {
-  folder_name_queue_ = std::make_unique<Queue<std::string>>(num_wkrs * queue_size);
-  image_name_queue_ = std::make_unique<Queue<FolderImagesPair>>(num_wkrs * queue_size);
-  io_block_queues_.Init(num_workers_, queue_size);
-}
-
-// Master thread that pulls the prescan worker's results.
-// Keep collecting results until all prescan workers quit
-// Then consolidate 2 level shuffles together into 1 giant vector
-// calculate numRows then return
-Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) {
-  std::vector<FolderImagesPair> v;
-  int64_t cnt = 0;
-  while (cnt != num_workers_) {  // count number of end signals
-    FolderImagesPair p;
-    RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p));
-    if (p == nullptr) {
-      cnt++;
-    } else {
-      v.push_back(p);
-    }
-  }
-  std::sort(v.begin(), v.end(),
-            [](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; });
-  // following loop puts the 2 levels of shuffles together into 1 vector
-  for (size_t ind = 0; ind < v.size(); ++ind) {
-    while (v[ind]->second.empty() == false) {
-      MS_ASSERT(!(v[ind]->first.empty()));  // make sure that v[ind]->first.substr(1) is not out of bound
-      v[ind]->second.front()->second = class_index_.empty() ? ind : class_index_[v[ind]->first.substr(1)];
-      image_label_pairs_.push_back(v[ind]->second.front());
-      v[ind]->second.pop();
-    }
-  }
-  image_label_pairs_.shrink_to_fit();
-  num_rows_ = image_label_pairs_.size();
-  if (num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ImageFolderDatasetV2. Please check file path or dataset "
-      "API validation first.");
-  }
-  // free memory of two queues used for pre-scan
-  folder_name_queue_->Reset();
-  image_name_queue_->Reset();
-  return Status::OK();
-}
-
-// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work
-Status ImageFolderOp::operator()() {
-  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
-  std::unique_ptr<DataBuffer> sampler_buffer;
-  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-  while (true) {  // each iteration is 1 epoch
-    std::vector<int64_t> keys;
-    keys.reserve(rows_per_buffer_);
-    while (sampler_buffer->eoe() == false) {
-      TensorRow sample_row;
-      RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
-      std::shared_ptr<Tensor> sample_ids = sample_row[0];
-      if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64");
-      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-        if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
-        keys.push_back(*itr);
-        row_cnt_++;
-        if (row_cnt_ % rows_per_buffer_ == 0) {
-          RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
-            std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone)));
-          keys.clear();
-        }
-      }
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-    if (keys.empty() == false) {
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
-        std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone)));
-    }
-    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-      std::unique_ptr<IOBlock> eoe_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);
-      std::unique_ptr<IOBlock> eof_block = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof);
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block)));
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block)));
-      for (int32_t i = 0; i < num_workers_; ++i) {
-        RETURN_IF_NOT_OK(
-          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
-      }
-      return Status::OK();
-    } else {  // not the last repeat. Sleep master thread, wait for the wake-up from reset
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
-      wp_.Clear();
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-  }
-}
-
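// Illustrative note (not in the original source): the master/worker handoff above
// uses three IOBlock kinds: normal blocks carry row keys, Eoe/Eof blocks mark
// epoch/stream end, and an empty key vector is the workers' quit signal (see
// WorkerEntry below).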
-// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_
-// IMPORTANT: 1 IOBlock produces 1 DataBuffer
-Status ImageFolderOp::WorkerEntry(int32_t worker_id) {
-  TaskManager::FindMe()->Post();
-  int64_t buffer_id = worker_id;
-  std::unique_ptr<IOBlock> io_block;
-  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
-  while (io_block != nullptr) {
-    if (io_block->eoe() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
-      buffer_id = worker_id;
-    } else if (io_block->eof() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
-    } else {
-      std::vector<int64_t> keys;
-      RETURN_IF_NOT_OK(io_block->GetKeys(&keys));
-      if (keys.empty() == true) return Status::OK();  // empty key is a quit signal for workers
-      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
-      RETURN_IF_NOT_OK(LoadBuffer(keys, &db));
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
-      buffer_id += num_workers_;
-    }
-    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
-  }
-  RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker");
-}
-
-// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorRow in a DataBuffer
-Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) {
-  std::shared_ptr<Tensor> image, label;
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(),
-                                        data_schema_->column(1).type(),
-                                        reinterpret_cast<unsigned char *>(&pairPtr->second)));
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first)));
-
-  if (decode_ == true) {
-    Status rc = Decode(image, &image);
-    if (rc.IsError()) {
-      std::string err = "Fail to decode image:" + folder_path_ + (pairPtr->first);
-      RETURN_STATUS_UNEXPECTED(err);
-    }
-  }
-  (*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
-  return Status::OK();
-}
-
-// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer
-Status ImageFolderOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
-  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
-  TensorRow trow;
-  for (const int64_t &key : keys) {
-    RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow));
-    deq->push_back(std::move(trow));
-  }
-  (*db)->set_tensor_table(std::move(deq));
-  return Status::OK();
-}
-
-void ImageFolderOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as first line regardless if this summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") <ImageFolderOp>:";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nNumber of rows:" << num_rows_ << "\nImageFolder directory: " << folder_path_ << "\n\n";
-  }
-}
-
-// Reset Sampler and wakeup Master thread (functor)
-Status ImageFolderOp::Reset() {
-  RETURN_IF_NOT_OK(sampler_->ResetSampler());
-  row_cnt_ = 0;
-  wp_.Set();  // wake up master thread after reset is done
-  return Status::OK();
-}
-
-// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows
-Status ImageFolderOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
-  if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) {
-    RETURN_STATUS_UNEXPECTED("ImageLabelPair not set");
-  }
-  for (size_t i = 0; i < image_label_pairs_.size(); ++i) {
-    (*cls_ids)[image_label_pairs_[i]->second].push_back(i);
-  }
-  for (auto &pair : (*cls_ids)) {
-    pair.second.shrink_to_fit();
-  }
-  return Status::OK();
-}
-
-// Worker entry for pre-scanning all the folders and doing the 1st level shuffle
-// A worker pulls a file name from mFoldernameQueue (which is a Queue), then walks all the images under that foldername
-// After walking is complete, sort all the file names (relative path to all jpeg files under the same directory)
-// (Sort is automatically conducted using a set which is implemented using a Red-Black Tree)
-// Add the sorted filenames into a queue. Then make a pair (foldername, queue*),
-// foldername is used for 2nd level sorting.
-// FYI: 1st level sorting: sort all images under the same directory.
-// FYI: 2nd level sorting: sort all folder names
-// push this pair to mImagenameQueue (which is again a Queue)
-Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) {
-  TaskManager::FindMe()->Post();
-  std::string folder_name;
-  RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name));
-  while (folder_name.empty() == false) {
-    Path folder(folder_path_ + folder_name);
-    std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&folder);
-    if (folder.Exists() == false || dirItr == nullptr) {
-      RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_name);
-    }
-    std::set<std::string> imgs;  // use this for ordering
-    while (dirItr->hasNext()) {
-      Path file = dirItr->next();
-      if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) {
-        (void)imgs.insert(file.toString().substr(dirname_offset_));
-      } else {
-        MS_LOG(WARNING) << "Image folder operator unsupported file found: " << file.toString()
-                        << ", extension: " << file.Extension() << ".";
-      }
-    }
-    FolderImagesPair p = std::make_shared<std::pair<std::string, std::queue<ImageLabelPair>>>();
-    p->first = folder_name;
-    for (const std::string &img : imgs) {
-      p->second.push(std::make_shared<std::pair<std::string, int32_t>>(img, 0));
-    }
-    RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p));
-    RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name));
-  }
-  RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr));  // end signal
-  return Status::OK();
-}
-
-// This helper function recursively walks all folder names, and sends each foldername to mFoldernameQueue
-// if mRecursive == false, don't go into folder of folders
-Status ImageFolderOp::RecursiveWalkFolder(Path *dir) {
-  std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir);
-  RETURN_UNEXPECTED_IF_NULL(dir_itr);
-  while (dir_itr->hasNext()) {
-    Path subdir = dir_itr->next();
-    if (subdir.IsDirectory()) {
-      if (class_index_.empty() ||
-          class_index_.find(subdir.toString().substr(dirname_offset_ + 1)) != class_index_.end()) {
-        RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(subdir.toString().substr(dirname_offset_)));
-      }
-      if (recursive_ == true) {
-        RETURN_IF_NOT_OK(RecursiveWalkFolder(&subdir));
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// A thread that calls RecursiveWalkFolder
-Status ImageFolderOp::startAsyncWalk() {
-  TaskManager::FindMe()->Post();
-  Path dir(folder_path_);
-  if (dir.Exists() == false || dir.IsDirectory() == false) {
-    RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_path_);
-  }
-  dirname_offset_ = folder_path_.length();
-  RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir));
-  // send out num_workers_ end signals to mFoldernameQueue, 1 for each worker.
-  // Upon receiving an end signal, a worker quits and sends another end signal to mImagenameQueue.
-  for (int32_t ind = 0; ind < num_workers_; ++ind) {
-    RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(""));  // end signal
-  }
-  return Status::OK();
-}
-
-Status ImageFolderOp::LaunchThreadsAndInitOp() {
-  RETURN_UNEXPECTED_IF_NULL(tree_);
-  // Registers QueueList and individual Queues for interrupt services
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  // The following code launches 3 thread groups
-  // 1) A thread that walks all folders and pushes the folder names to a util::Queue mFoldernameQueue.
-  // 2) Workers that pull a foldername from mFoldernameQueue, walk it and return the sorted images to mImagenameQueue
-  // 3) Launch main workers that load DataBuffers by reading all images
-  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::startAsyncWalk, this)));
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1)));
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  // The order of the following 2 functions must not be changed!
-  RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_));  // Master thread of pre-scan workers, blocking
-  RETURN_IF_NOT_OK(this->InitSampler());                     // pass numRows to Sampler
-  return Status::OK();
-}
-
-Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts,
-                                          int64_t *num_rows, int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
-  Path dir(path);
-  std::string err_msg = "";
-  int64_t row_cnt = 0;
-  err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
-  err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
-  err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
-  if (err_msg.empty() == false) {
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  std::queue<std::string> foldernames;
-  std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
-  while (dir_itr->hasNext()) {
-    Path subdir = dir_itr->next();
-    if (subdir.IsDirectory()) {
-      foldernames.push(subdir.toString());
-    }
-  }
-  (*num_classes) = foldernames.size();
-  while (foldernames.empty() == false) {
-    Path subdir(foldernames.front());
-    dir_itr = Path::DirIterator::OpenDirectory(&subdir);
-    while (dir_itr->hasNext()) {
-      Path file = dir_itr->next();  // advance to the next entry before checking its extension
-      if (exts.empty() || exts.find(file.Extension()) != exts.end()) {
-        ++row_cnt;
-      }
-    }
-    foldernames.pop();
-  }
-  (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
-  return Status::OK();
-}
-
-// Visitor accept method for NodePass
-Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
-  // Downcast shared pointer then call visitor
-  return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified);
-}
-
-Status ImageFolderOp::ComputeColMap() {
-  // Set the column name map (base class field)
-  if (column_name_id_map_.empty()) {
-    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
-      column_name_id_map_[data_schema_->column(i).name()] = i;
-    }
-  } else {
-    MS_LOG(WARNING) << "Column name map is already set!";
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
deleted file mode 100644
index 6629fd6092..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
+++ /dev/null
@@ -1,274 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
deleted file mode 100644
index 6629fd6092..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/image_folder_op.h
+++ /dev/null
@@ -1,274 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
-
-#include <algorithm>
-#include <deque>
-#include <map>
-#include <memory>
-#include <queue>
-#include <set>
-#include <string>
-#include <utility>
-#include <vector>
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/parallel_op.h"
-#include "dataset/engine/datasetops/source/io_block.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/path.h"
-#include "dataset/util/queue.h"
-#include "dataset/util/services.h"
-#include "dataset/util/status.h"
-#include "dataset/util/wait_post.h"
-
-namespace mindspore {
-namespace dataset {
-// Forward declares
-template <typename T>
-class Queue;
-
-using ImageLabelPair = std::shared_ptr<std::pair<std::string, int32_t>>;
-using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>;
-
-class ImageFolderOp : public ParallelOp, public RandomAccessOp {
- public:
-  class Builder {
-   public:
-    // Constructor for Builder class of ImageFolderOp
-    // @param int32_t num_wkrs - number of parallel workers
-    // @param dir - directory of the image folder
-    Builder();
-
-    // Destructor.
-    ~Builder() = default;
-
-    // Setter method
-    // @param int32_t rows_per_buffer
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-      builder_rows_per_buffer_ = rows_per_buffer;
-      return *this;
-    }
-
-    // Setter method
-    // @param int32_t size
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetOpConnectorSize(int32_t size) {
-      builder_op_connector_size_ = size;
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::set<std::string> &exts - file extensions to be read
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetExtensions(const std::set<std::string> &exts) {
-      builder_extensions_ = exts;
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::map<std::string, int32_t> &map - a class name to label map
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetClassIndex(const std::map<std::string, int32_t> &map) {
-      builder_labels_to_read_ = map;
-      return *this;
-    }
-
-    // Setter method
-    // @param bool do_decode
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetDecode(bool do_decode) {
-      builder_decode_ = do_decode;
-      return *this;
-    }
-
-    // Setter method
-    // @param int32_t num_workers
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetNumWorkers(int32_t num_workers) {
-      builder_num_workers_ = num_workers;
-      return *this;
-    }
-
-    // Setter method
-    // @param std::shared_ptr<Sampler> sampler
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
-      builder_sampler_ = std::move(sampler);
-      return *this;
-    }
-
-    // Setter method
-    // @param const std::string &dir
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetImageFolderDir(const std::string &dir) {
-      builder_dir_ = dir;
-      return *this;
-    }
-
-    // Whether dirs are walked recursively
-    // @param bool recursive - if set to false, only get dirs in the top level dir
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetRecursive(bool recursive) {
-      builder_recursive_ = recursive;
-      return *this;
-    }
-
-    // Check validity of input args
-    // @return Status - The error code return
-    Status SanityCheck();
-
-    // The builder "build" method creates the final object.
-    // @param std::shared_ptr<ImageFolderOp> *op - the built ImageFolderOp
-    // @return Status - The error code return
-    Status Build(std::shared_ptr<ImageFolderOp> *op);
-
-   private:
-    bool builder_decode_;
-    bool builder_recursive_;
-    std::string builder_dir_;
-    int32_t builder_num_workers_;
-    int32_t builder_rows_per_buffer_;
-    int32_t builder_op_connector_size_;
-    std::set<std::string> builder_extensions_;
-    std::shared_ptr<Sampler> builder_sampler_;
-    std::unique_ptr<DataSchema> builder_schema_;
-    std::map<std::string, int32_t> builder_labels_to_read_;
-  };
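
The setters above all return `Builder &` so that configuration calls chain. A runnable toy builder with the same shape (`ToyBuilder` and its fields are invented for illustration; `ImageFolderOp` itself is not needed for the idiom):

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <string>

class ToyBuilder {
 public:
  // Each setter mutates one field and returns *this, enabling call chaining.
  ToyBuilder &SetDir(const std::string &dir) { dir_ = dir; return *this; }
  ToyBuilder &SetNumWorkers(int32_t n) { num_workers_ = n; return *this; }
  ToyBuilder &SetExtensions(const std::set<std::string> &e) { exts_ = e; return *this; }
  // Mirrors SanityCheck(): validate only at the end, once everything is set.
  bool SanityCheck() const { return !dir_.empty() && num_workers_ > 0; }

 private:
  std::string dir_;
  int32_t num_workers_ = 0;
  std::set<std::string> exts_;
};

int main() {
  ToyBuilder b;
  b.SetDir("/data/train").SetNumWorkers(4).SetExtensions({".jpg", ".png"});
  assert(b.SanityCheck());
}
```
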
-
-  // Constructor
-  // @param int32_t num_wkrs - Num of workers reading images in parallel
-  // @param int32_t rows_per_buffer - Number of images (rows) in each buffer
-  // @param std::string file_dir - directory of the image folder
-  // @param int32_t queue_size - connector queue size
-  // @param std::set<std::string> exts - set of file extensions to read; if empty, read everything under the dir
-  // @param std::unique_ptr<DataSchema> data_schema - schema of the output columns
-  // @param std::shared_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
-                bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
-                std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
-
-  // Destructor.
-  ~ImageFolderOp() = default;
-
-  // Initialize ImageFolderOp related variables; calls the function to walk all files
-  // @param std::string dir - file directory of the image folder
-  // @return Status - The error code return
-  Status PrescanMasterEntry(const std::string &dir);
-
-  // Worker thread pulls a number of IOBlocks from the IOBlock Queue, makes a buffer and pushes it to Connector
-  // @param int32_t worker_id - id of each worker
-  // @return Status - The error code return
-  Status WorkerEntry(int32_t worker_id) override;
-
-  // Pre-scan worker thread pulls a folder name from folder_name_queue_, walks it, and pushes the
-  // sorted image names to image_name_queue_
-  // @param int32_t worker_id - id of each worker
-  // @return Status - The error code return
-  Status PrescanWorkerEntry(int32_t worker_id);
-
-  // Main Loop of ImageFolderOp
-  // Master thread: Fill IOBlockQueue, then goes to sleep
-  // Worker thread: pulls an IOBlock from IOBlockQueue, works on it, then puts the buffer to out_connector_
-  // @return Status - The error code return
-  Status operator()() override;
-
-  // Method derived from RandomAccessOp, enables Sampler to get all ids for each class
-  // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key label, val all ids for this class
-  // @return Status - The error code return
-  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
-
-  // A print method typically used for debugging
-  // @param out
-  // @param show_all
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // This function is a hack! It is to return the num_class and num_rows. The result
-  // returned by this function may not be consistent with what image_folder_op is going to return.
-  // Use this at your own risk!
-  static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
-                                    int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);
-
-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
- Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ImageFolderOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param ImageLabelPair pair - - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // @param std::string & dir - dir to walk all images - // @param int64_t * cnt - number of non folder files under the current dir - // @return - Status RecursiveWalkFolder(Path *dir); - - // start walking of all dirs - // @return - Status startAsyncWalk(); - - // Called first when function is called - // @return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; - std::string folder_path_; // directory of image folder - bool recursive_; - bool decode_; - std::set extensions_; // extensions allowed - std::map class_index_; - std::unique_ptr data_schema_; - int64_t row_cnt_; - int64_t buf_cnt_; - int64_t sampler_ind_; - int64_t dirname_offset_; - WaitPost wp_; - std::vector image_label_pairs_; - QueueList> io_block_queues_; // queues of IOBlocks - std::unique_ptr> folder_name_queue_; - std::unique_ptr> image_name_queue_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc deleted file mode 100644 index 0963f1a67a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/io_block.h" - -#include -#include - -namespace mindspore { -namespace dataset { -// IOBlock Class // - -// Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. 
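
The IOBlock constructors below all take an `IOBlockFlags` value whose EOE/EOF bits are later tested with bit masks (see the `eoe()`/`eof()` accessors in the header further down in this diff). A standalone reduction of that flag scheme, assuming nothing beyond the enum values shown in this diff (`IsEoe`/`IsEof` are free-function stand-ins for the member accessors):

```cpp
#include <cassert>
#include <cstdint>

enum IOBlockFlags : uint32_t {
  kDeIoBlockNone = 0,
  kDeIoBlockFlagEoe = 1u,       // end of IOBlocks for one epoch
  kDeIoBlockFlagEof = 1u << 1,  // end of IOBlocks for the entire program
};

// Each flag occupies its own bit, so a single mask test recovers it.
constexpr bool IsEoe(uint32_t flags) { return flags & kDeIoBlockFlagEoe; }
constexpr bool IsEof(uint32_t flags) { return flags & kDeIoBlockFlagEof; }

int main() {
  assert(IsEoe(kDeIoBlockFlagEoe) && !IsEof(kDeIoBlockFlagEoe));
  assert(IsEof(kDeIoBlockFlagEof) && !IsEoe(kDeIoBlockFlagEof));
  assert(!IsEoe(kDeIoBlockNone) && !IsEof(kDeIoBlockNone));
}
```
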
-IOBlock::IOBlock(int64_t inKey, IOBlockFlags io_block_flags) : index_keys_(1, inKey), io_block_flags_(io_block_flags) {} - -// Constructor of the IOBlock (2) -IOBlock::IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) { - index_keys_.insert(index_keys_.end(), in_keys.begin(), in_keys.end()); -} - -// Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. -IOBlock::IOBlock(IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) {} - -// Fetches the first key from this block -Status IOBlock::GetKey(int64_t *out_key) const { - if (out_key == nullptr || index_keys_.empty()) { - RETURN_STATUS_UNEXPECTED("Failed to get the key from IOBlock"); - } - *out_key = index_keys_[0]; - return Status::OK(); -} - -// Fetches the list of keys from this block. -Status IOBlock::GetKeys(std::vector *out_keys) const { - if (out_keys == nullptr) { - RETURN_STATUS_UNEXPECTED("Output arg for GetKeys is null"); - } - *out_keys = index_keys_; // vector copy assign - return Status::OK(); -} - -// FilenameBlock derived class // - -// Constructor of the FilenameBlock (1) -FilenameBlock::FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags) - : IOBlock(key, io_block_flags), start_offset_(start_offset), end_offset_(end_offset) {} - -// Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. -FilenameBlock::FilenameBlock(IOBlockFlags io_block_flags) - : IOBlock(io_block_flags), start_offset_(kInvalidOffset), end_offset_(kInvalidOffset) {} - -// Gets the filename from the block using the provided index container -Status FilenameBlock::GetFilename(std::string *out_filename, const AutoIndexObj &index) const { - if (out_filename == nullptr) { - RETURN_STATUS_UNEXPECTED("Failed to get filename from FilenameBlock"); - } - - // a FilenameBlock only has one key. Call base class method to fetch that key - int64_t fetched_key; - RETURN_IF_NOT_OK(IOBlock::GetKey(&fetched_key)); - - // Do an index lookup using that key to get the filename. - auto r = index.Search(fetched_key); - if (r.second) { - auto &it = r.first; - *out_filename = it.value(); - } else { - RETURN_STATUS_UNEXPECTED("Could not find filename from index"); - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h deleted file mode 100644 index 87b417f027..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/io_block.h +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ - -#include -#include - -#include "dataset/util/auto_index.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// The IOBlock class is used to describe a "unit of work" that a storage leaf operator worker thread -// is responsible for acting on. -// The IOBlocks and it's derived classes abstracts a key-store and key-lookup interface where each -// block contains 1 to n keys, and the keys are used in conjunction with an index to provide the meta -// information for satisfying an IO request. -class IOBlock { - public: - enum IOBlockFlags : uint32_t { - kDeIoBlockNone = 0, - kDeIoBlockFlagEoe = 1u, // end of IOBlocks for one epoch - kDeIoBlockFlagEof = 1u << 1 // end of IOBlocks for entire program - }; - - // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key. - // @param inKey - A single key to add into the block - // @param io_block_flags - The flag setting for the block - IOBlock(int64_t inKey, IOBlockFlags io_block_flags); - - // Constructor of the IOBlock (2). - // @param in_keys - A vector of keys to add into the block - // @param io_block_flags - The flag setting for the block - IOBlock(const std::vector &in_keys, IOBlockFlags io_block_flags); - - // Constructor of the IOBlock (3). A special IOBlock that is used for control messaging. - // @param io_block_flags - The flag setting for the block - explicit IOBlock(IOBlockFlags io_block_flags); - - // Destructor - virtual ~IOBlock() = default; - - // Fetches the first key from the block. - // @note Only useful if you know the block only has 1 key. - // @return A copy of the first key from the block - // @return Status - The error code return - Status GetKey(int64_t *out_key) const; - - // Fetches the list of keys from this block. - // @param out_keys - A copy of the vector of keys from the block. - // @return Status - The error code return - Status GetKeys(std::vector *out_keys) const; - - // Does this block have the eoe flag turned on? - // @return T/F if the IOBlock is eoe - bool eoe() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEoe); } - - // Does this block have the eof flag turned on? - // @return T/F if the IOBlock is eof - bool eof() const { return static_cast(io_block_flags_) & static_cast(kDeIoBlockFlagEof); } - - // Adds a key to this block - // @param key - The key to add to this block - void AddKey(int64_t key) { index_keys_.push_back(key); } - - protected: - std::vector index_keys_; // keys used for lookups to the meta info for the data - IOBlockFlags io_block_flags_; -}; // class IOBlock - -const int64_t kInvalidOffset = -1; - -// The Filename block derived class implements a style of IO block where each block contains only a -// single key that maps to a filename. -class FilenameBlock : public IOBlock { - public: - // Constructor of the FilenameBlock (1) - // @param key - The key identifier that can be used to find the data for this block - // @param start_offset - Start offset - // @param end_offset - End offset - // @param io_block_flags - The flag setting for the block - FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags); - - // Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging. 
- // @param io_block_flags - The flag setting for the block - explicit FilenameBlock(IOBlockFlags io_block_flags); - - // Destructor - ~FilenameBlock() = default; - - // Gets the filename from the block using the provided index container - // @param out_filename - The filename to add to the block - // @param index - The index to perform lookup against - // @return Status - The error code return - Status GetFilename(std::string *out_filename, const AutoIndexObj &index) const; - - // Get the start offset of file - // @return int64_t - Start offset - int64_t GetStartOffset() const { return start_offset_; } - - // Get the end offset of the file - // @return int64_t - Start offset - int64_t GetEndOffset() const { return end_offset_; } - - private: - int64_t start_offset_; - int64_t end_offset_; -}; // class TFBlock -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc deleted file mode 100644 index 4f9a12bd65..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc +++ /dev/null @@ -1,438 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/source/manifest_op.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status ManifestOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); - *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_file_, - builder_op_connector_size_, builder_decode_, builder_labels_to_read_, - std::move(builder_schema_), std::move(builder_sampler_), builder_usage_); - return Status::OK(); -} - -Status ManifestOp::Builder::SanityCheck() { - std::string err_msg; - err_msg += builder_file_.empty() ? "Manifest file is not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers smaller than 1\n" : ""; - return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
-}
-
-ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
-                       const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
-                       std::shared_ptr<Sampler> sampler, std::string usage)
-    : ParallelOp(num_works, queue_size, std::move(sampler)),
-      rows_per_buffer_(rows_per_buffer),
-      io_block_pushed_(0),
-      row_cnt_(0),
-      sampler_ind_(0),
-      data_schema_(std::move(data_schema)),
-      file_(file),
-      class_index_(class_index),
-      decode_(decode),
-      usage_(usage),
-      buf_cnt_(0) {
-  io_block_queues_.Init(num_workers_, queue_size);
-  (void)std::transform(usage_.begin(), usage_.end(), usage_.begin(), ::tolower);
-}
-
-// Main logic: Register Queue with TaskGroup, launch all threads and do the functor's work
-Status ManifestOp::operator()() {
-  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
-  std::unique_ptr<DataBuffer> sampler_buffer;
-  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-  return AddIoBlock(&sampler_buffer);
-}
-
-Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) {
-  while (true) {  // each iteration is 1 epoch
-    std::vector<int64_t> keys;
-    keys.reserve(rows_per_buffer_);
-    while (!(*sampler_buffer)->eoe()) {
-      TensorRow sample_row;
-      RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row));
-      std::shared_ptr<Tensor> sample_ids = sample_row[0];
-      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-        if ((*itr) >= num_rows_) continue;  // index out of bounds, skipping
-        keys.push_back(*itr);
-        row_cnt_++;
-        if (row_cnt_ % rows_per_buffer_ == 0) {
-          RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
-            std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
-          keys.clear();
-        }
-      }
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
-    }
-    if (keys.empty() == false) {
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
-        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
-    }
-    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
-      for (int32_t i = 0; i < num_workers_; i++) {
-        RETURN_IF_NOT_OK(
-          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
-      }
-      return Status::OK();
-    } else {
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
-      wp_.Clear();
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer));
-    }
-  }
-}
-
-Status ManifestOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-
-  RETURN_IF_NOT_OK(
-    tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(ParseManifestFile());
-  RETURN_IF_NOT_OK(CountDatasetInfo());
-  RETURN_IF_NOT_OK(InitSampler());
-  return Status::OK();
-}
-
-// Contains the main logic of pulling an IOBlock from IOBlockQueue, loading a buffer, and pushing it to out_connector_
-// IMPORTANT: 1 IOBlock produces 1 DataBuffer
-Status
ManifestOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty()) { - return Status::OK(); // empty key is a quit signal for workers - } - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer -Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair> &data, - TensorRow *trow) { - std::shared_ptr image; - std::shared_ptr label; - std::vector label_index(data.second.size()); - (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), - [this](const std::string &label_name) { return label_index_[label_name]; }); - if (label_index.size() == 1) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), - data_schema_->column(1).type(), - reinterpret_cast(&label_index[0]))); - } else { - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), - data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); - if (decode_ == true) { - Status rc = Decode(image, &image); - if (rc.IsError()) { - std::string err = "Fail to decode image:" + data.first; - RETURN_STATUS_UNEXPECTED(err); - } - } - (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 
1 function call produces 1 buffer
-Status ManifestOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
-  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
-  for (const auto &key : keys) {
-    TensorRow trow;
-    RETURN_IF_NOT_OK(LoadTensorRow(key, image_labelname_[static_cast<size_t>(key)], &trow));
-    deq->push_back(std::move(trow));
-  }
-  (*db)->set_tensor_table(std::move(deq));
-  return Status::OK();
-}
-
-void ManifestOp::Print(std::ostream &out, bool show_all) const {
-  // Always show the id and name as the first line, regardless if this is a summary or detailed print
-  out << "(" << std::setw(2) << operator_id_ << ") <ManifestOp>:";
-  if (!show_all) {
-    // Call the super class for displaying any common 1-liner info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal 1-liner info for this op
-    out << "\n";
-  } else {
-    // Call the super class for displaying any common detailed info
-    ParallelOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nNumber of rows: " << num_rows_ << "\nManifest file: " << file_ << "\n\n";
-  }
-}
-
-// Reset Sampler and wake up the Master thread (functor)
-Status ManifestOp::Reset() {
-  RETURN_IF_NOT_OK(sampler_->ResetSampler());
-  row_cnt_ = 0;
-  wp_.Set();  // wake up master thread after reset is done
-  return Status::OK();
-}
-
-// Handshake with Sampler, allowing the Sampler to call RandomAccessOp's functions to get NumRows
-Status ManifestOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-  return Status::OK();
-}
-
-// Derived from RandomAccessOp
-Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
-  if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
-    RETURN_STATUS_UNEXPECTED("Class indexing is invalid.");
-  }
-
-  for (size_t i = 0; i < image_labelname_.size(); i++) {
-    size_t image_index = i;
-    for (size_t j = 0; j < image_labelname_[image_index].second.size(); j++) {
-      std::string label_name = (image_labelname_[image_index].second)[j];
-      int32_t label_index = label_index_.at(label_name);
-      (*cls_ids)[label_index].emplace_back(image_index);
-    }
-  }
-
-  for (auto &pair : (*cls_ids)) {
-    pair.second.shrink_to_fit();
-  }
-  return Status::OK();
-}
-
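
`ParseManifestFile` below reads the manifest one JSON object per line, filters on `usage`, and collects label names from the `annotation` array. A standalone sketch of that per-line parse, assuming the single-header nlohmann/json library is on the include path (the two manifest lines are taken from the format comment below; the stream and output are illustrative):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <nlohmann/json.hpp>

int main() {
  std::istringstream manifest(
    R"({"source": "/path/to/image1.jpg", "usage": "train", "annotation": [{"name": "cat"}]})" "\n"
    R"({"source": "/path/to/image2.jpg", "usage": "eval", "annotation": [{"name": "dog"}]})" "\n");
  std::string line;
  while (std::getline(manifest, line)) {
    nlohmann::json js = nlohmann::json::parse(line);
    if (js.value("usage", "") != "train") continue;  // usage filter, as in ParseManifestFile
    for (const auto &annotation : js.at("annotation")) {
      std::cout << js.value("source", "") << " -> " << annotation.value("name", "") << "\n";
    }
  }
}
```
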
-// Manifest file content
-// {"source": "/path/to/image1.jpg", "usage": "train", "annotation": ...}
-// {"source": "/path/to/image2.jpg", "usage": "eval", "annotation": ...}
-Status ManifestOp::ParseManifestFile() {
-  std::ifstream file_handle(file_);
-  if (!file_handle.is_open()) {
-    RETURN_STATUS_UNEXPECTED("Manifest file " + file_ + " cannot be opened.");
-  }
-  std::string line;
-  while (getline(file_handle, line)) {
-    try {
-      nlohmann::json js = nlohmann::json::parse(line);
-      std::string image_file_path = js.value("source", "");
-      // If image is not JPEG/PNG/GIF/BMP, drop it
-      bool valid = false;
-      RETURN_IF_NOT_OK(CheckImageType(image_file_path, &valid));
-      if (!valid) {
-        continue;
-      }
-      std::string usage = js.value("usage", "");
-      (void)std::transform(usage.begin(), usage.end(), usage.begin(), ::tolower);
-      if (usage != usage_) {
-        continue;
-      }
-      std::vector<std::string> labels;
-      nlohmann::json annotations = js.at("annotation");
-      for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) {
-        nlohmann::json annotation = it.value();
-        std::string label_name = annotation.value("name", "");
-        if (label_name == "") {
-          file_handle.close();
-          RETURN_STATUS_UNEXPECTED("Label name is not found in manifest file for " + image_file_path);
-        }
-        if (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) {
-          if (label_index_.find(label_name) == label_index_.end()) {
-            label_index_[label_name] = 0;
-          }
-          labels.emplace_back(label_name);
-        }
-      }
-      if (!labels.empty()) {
-        image_labelname_.emplace_back(std::make_pair(image_file_path, labels));
-      }
-    } catch (const std::exception &err) {
-      file_handle.close();
-      RETURN_STATUS_UNEXPECTED("Failed to parse manifest file: " + file_);
-    }
-  }
-  file_handle.close();
-
-  return Status::OK();
-}
-
-// Only JPEG/PNG/GIF/BMP are supported
-Status ManifestOp::CheckImageType(const std::string &file_name, bool *valid) {
-  std::ifstream file_handle;
-  constexpr int read_num = 3;
-  *valid = false;
-  file_handle.open(file_name, std::ios::binary | std::ios::in);
-  if (!file_handle.is_open()) {
-    RETURN_STATUS_UNEXPECTED("Cannot open image file " + file_name);
-  }
-  unsigned char file_type[read_num];
-  (void)file_handle.read(reinterpret_cast<char *>(file_type), read_num);
-
-  if (file_handle.fail()) {
-    file_handle.close();
-    RETURN_STATUS_UNEXPECTED("Failed to read image file " + file_name);
-  }
-  file_handle.close();
-  if (file_type[0] == 0xff && file_type[1] == 0xd8 && file_type[2] == 0xff) {
-    // Normal JPEGs start with \xff\xd8\xff\xe0
-    // JPEG with EXIF starts with \xff\xd8\xff\xe1
-    // Use \xff\xd8\xff to cover both.
-    *valid = true;
-  } else if (file_type[0] == 0x89 && file_type[1] == 0x50 && file_type[2] == 0x4e) {
-    // It's a PNG
-    *valid = true;
-  } else if (file_type[0] == 0x47 && file_type[1] == 0x49 && file_type[2] == 0x46) {
-    // It's a GIF
-    *valid = true;
-  } else if (file_type[0] == 0x42 && file_type[1] == 0x4d) {
-    // It's a BMP
-    *valid = true;
-  }
-  return Status::OK();
-}
-
-Status ManifestOp::CountDatasetInfo() {
-  int32_t index = 0;
-  for (auto &label : label_index_) {
-    label.second = class_index_.empty() ? index : class_index_[label.first];
-    index++;
-  }
-
-  num_rows_ = static_cast<int64_t>(image_labelname_.size());
-  if (num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API ManifestDataset. Please check the file path or dataset API "
-      "validation first.");
-  }
-  return Status::OK();
-}
-
-Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
-                                  int64_t *count, int64_t *numClasses) {
-  // the logic of counting the number of samples is copied from ParseManifestFile()
-  std::map<std::string, int32_t> map;
-  for (auto p : dict) {
-    (void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
-                                                     py::reinterpret_borrow<py::int_>(p.second)));
-  }
-
-  std::shared_ptr<ManifestOp> op;
-  *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
-  RETURN_IF_NOT_OK(op->ParseManifestFile());
-  *numClasses = static_cast<int64_t>(op->label_index_.size());
-  *count = static_cast<int64_t>(op->image_labelname_.size());
-  return Status::OK();
-}
-
-Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
-                                    std::map<std::string, int32_t> *output_class_indexing) {
-  std::map<std::string, int32_t> input_class_indexing;
-  for (auto p : dict) {
-    (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
-                                                                      py::reinterpret_borrow<py::int_>(p.second)));
-  }
-
-  if (!input_class_indexing.empty()) {
-    *output_class_indexing = input_class_indexing;
-  } else {
-    std::shared_ptr<ManifestOp> op;
-    RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
-    RETURN_IF_NOT_OK(op->ParseManifestFile());
-    RETURN_IF_NOT_OK(op->CountDatasetInfo());
-    uint32_t count = 0;
-    for (const auto &label : op->label_index_) {
-
(*output_class_indexing).insert(std::make_pair(label.first, count)); - count++; - } - } - - return Status::OK(); -} - -// Visitor accept method for NodePass -Status ManifestOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status ManifestOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h deleted file mode 100644 index 864abf676c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h +++ /dev/null @@ -1,250 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/queue.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -class ManifestOp : public ParallelOp, public RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of ManifestOp - Builder(); - - // Destructor - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t size) { - builder_op_connector_size_ = size; - return *this; - } - - // Setter method - // @param const std::map& map - a class name to label map - // @return - Builder &SetClassIndex(const std::map &map) { - builder_labels_to_read_ = map; - return *this; - } - - // Setter method - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. 
- Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return Builder setter method returns reference to the builder. - Builder &SetManifestFile(const std::string &file) { - builder_file_ = file; - return *this; - } - - // Setter method - // @param const std::string & dir - // @return Builder setter method returns reference to the builder. - Builder &SetUsage(const std::string &usage) { - builder_usage_ = usage; - return *this; - } - - // Check validity of input args - // @return Status - The error code return - Status SanityCheck(); - - // The builder "build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::shared_ptr builder_sampler_; - bool builder_decode_; - - std::string builder_file_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::unique_ptr builder_schema_; - std::string builder_usage_; - std::map builder_labels_to_read_; - }; - - // Constructor - // @param int32_t num_works - Num of workers reading images in parallel - // @param int32_t - rows_per_buffer Number of images (rows) in each buffer - // @param std::string - file list of Manifest - // @param int32_t queue_size - connector queue size - // @param td::unique_ptr sampler - sampler tells ImageFolderOp what to read - ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, - const std::map &class_index, std::unique_ptr data_schema, - std::shared_ptr sampler, std::string usage); - // Destructor. 
- ~ManifestOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of ManifestOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, - int64_t *numClasses); - - // Get str-to-int mapping from label name to index - static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, - std::map *output_class_indexing); - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p Pointer to the NodePass to be accepted - /// \param[out] modified Indicator if the node was changed at all - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ManifestOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Method in operator(), to fill IOBlockQueue - // @param std::unique_ptr sampler_buffer - to fill IOBlockQueue - // @return Status - The error code return - Status AddIoBlock(std::unique_ptr *sampler_buffer); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param std::pair> - > - // @param TensorRow row - image & label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::pair> &data, - TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Parse manifest file to get image path and label and so on. - // @return Status - The error code return - Status ParseManifestFile(); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Check if image ia valid.Only support JPEG/PNG/GIF/BMP - // @return - Status CheckImageType(const std::string &file_name, bool *valid); - - // Count label index,num rows and num samples - // @return Status - The error code return - Status CountDatasetInfo(); - - // Private function for computing the assignment of the column name map. 
- // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; - int64_t io_block_pushed_; - int64_t row_cnt_; - int64_t sampler_ind_; - std::unique_ptr data_schema_; - std::string file_; // file that store the information of images - std::map class_index_; - bool decode_; - std::string usage_; - int64_t buf_cnt_; - - WaitPost wp_; - QueueList> io_block_queues_; - std::map label_index_; - std::vector>> image_labelname_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc deleted file mode 100644 index 2b9d010ebb..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ /dev/null @@ -1,513 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/mindrecord_op.h" - -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/constants.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using mindrecord::kInt64Len; -using mindrecord::MSRStatus; -using mindrecord::Schema; -using mindrecord::ShardOperator; -using mindrecord::ShardReader; - -// Builder constructor. Creates the builder object. -MindRecordOp::Builder::Builder() : build_dataset_file_({}) { - // Some arguments to the MindRecordOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the MindRecordOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - build_num_mind_record_workers_ = kDefaultMindRecordWorkers; - build_rows_per_buffer_ = cfg->rows_per_buffer(); - build_op_connector_queue_size_ = cfg->op_connector_size(); - build_block_reader_ = false; - builder_num_workers_ = 0; - build_num_padded_ = 0; - build_sample_ = nullptr; -} - -// The builder "build" method creates the final object. 
-Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { - std::shared_ptr new_mind_record_op; - - if (build_dataset_file_.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Building a MindRecordOp that has not provided a file."); - } - mindrecord::json sample_json; - if (build_num_padded_ > 0) { - sample_json = ToJson(build_sample_); - } - new_mind_record_op = std::make_shared( - build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, - build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, - sample_json, build_sample_bytes_); - - RETURN_IF_NOT_OK(new_mind_record_op->Init()); - *ptr = std::move(new_mind_record_op); - return Status::OK(); -} - -Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } - -mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { - if (obj.is_none()) { - return nullptr; - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { // also catch py::bytes - return obj.cast(); - } - if (py::isinstance(obj)) { - auto out = mindrecord::json::object(); - for (const py::handle &key : obj) { - if (py::isinstance(obj[key])) { - build_sample_bytes_[py::str(key).cast()] = obj[key].cast(); - } else { - out[py::str(key).cast()] = ToJson(obj[key]); - } - } - return out; - } - MS_LOG(ERROR) << "Python object convert to json failed, object is: " << py::cast(obj); - return mindrecord::json(); -} - -// Constructor of the MindRecordOp. -MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, - std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, - const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded, const mindrecord::json &sample_json, - const std::map &sample_bytes) - : ParallelOp(num_mind_record_workers, op_connector_queue_size), - rows_per_buffer_(rows_per_buffer), - dataset_file_(dataset_file), - load_dataset_(load_dataset), - columns_to_load_(columns_to_load), - operators_(operators), - num_mind_record_workers_(num_mind_record_workers), - block_reader_(block_reader), - num_rows_(0), - buffers_needed_(0), - buf_cnt_(0), - ended_worker_(0), - buffer_water_mark_(0), - num_padded_(num_padded), - sample_json_(sample_json), - sample_bytes_(sample_bytes) { - io_blk_queues_.Init(num_workers_, op_connector_queue_size); - if (!block_reader_) return; - for (int32_t i = 0; i < num_workers_; ++i) { - block_buffer_.emplace_back(std::make_unique>(std::vector{})); - } -} - -// Private helper method to encapsulate some common construction/reset tasks -Status MindRecordOp::Init() { - shard_reader_ = std::make_unique(); - auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, - block_reader_, num_padded_); - - CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, - "MindRecordOp init failed. 
Error message: " + ErrnoToMessage(rc)); - - data_schema_ = std::make_unique(); - - std::vector col_names = shard_reader_->GetShardColumn()->GetColumnName(); - CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); - std::vector col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType(); - std::vector> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape(); - - bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything - std::map colname_to_ind; - for (uint32_t i = 0; i < col_names.size(); i++) { - std::string colname = col_names[i]; - ColDescriptor col_desc; - - TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown - std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; - DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} - - if (col_data_types[i] == mindrecord::ColumnBytes) { // rank = 1 - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); - } else if (col_data_types[i] == mindrecord::ColumnString) { // rank = 0 - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 0); - } else if (col_shapes[i].size() > 0) { - std::vector vec(col_shapes[i].size()); // temporary vector to hold shape - (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); - t_shape = TensorShape(vec); - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); - } else { // unknown shape - // create colDesc and add it to schema - col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); - } - - colname_to_ind[colname] = data_schema_->NumColumns(); - RETURN_IF_NOT_OK(data_schema_->AddColumn(col_desc)); - - if (load_all_cols) { - columns_to_load_.emplace_back(colname); - } - } - - if (!load_all_cols) { - std::unique_ptr tmp_schema = std::make_unique(); - for (std::string colname : columns_to_load_) { - CHECK_FAIL_RETURN_UNEXPECTED(colname_to_ind.find(colname) != colname_to_ind.end(), colname + ": doesn't exist"); - RETURN_IF_NOT_OK(tmp_schema->AddColumn(data_schema_->column(colname_to_ind[colname]))); - } - data_schema_ = std::move(tmp_schema); - } - - return Status::OK(); -} - -// Destructor -MindRecordOp::~MindRecordOp() {} - -// A print method typically used for debugging -void MindRecordOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\n Dataset file : "; - for (auto &file : dataset_file_) { - out << file << " "; - } - out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_ - << "\nNumber of buffers : " << buffers_needed_ - << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; - } -} - -Status MindRecordOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe()) 
{ - RETURN_IF_NOT_OK( - out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - if (io_block->eof()) { - RETURN_IF_NOT_OK( - out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - - // load data buffer - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) { - { - std::unique_lock lock(ended_worker_mutex_); - ended_worker_++; - if (ended_worker_ == num_workers_) shard_reader_->Close(); - } - return Status::OK(); // empty key is a quit signal for workers - } - - const uint64_t buffer_id = keys[0]; - std::unique_ptr fetched_buffer; - - // Get the next buffer. Push it up to the output connector. - if (buffer_id % LOG_INTERVAL == 0) { - MS_LOG(DEBUG) << "MindRecord operator consumed buffer " << buffer_id << " by worker " << worker_id << "."; - } - RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); - if (!block_reader_) { - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - continue; - } - - // update block-reader buffer - block_buffer_[buffer_id % num_workers_]->clear(); - { - std::unique_lock lck(mtx_block_reader_); - if (buffer_id == buffer_water_mark_) { - buffer_water_mark_++; - while (block_set_.count(buffer_water_mark_) > 0) (void)block_set_.erase(buffer_water_mark_++); - } else { - (void)block_set_.insert(buffer_id); - } - } - cv_reader_.notify_one(); - RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, - int32_t worker_id) { - *fetched_buffer = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - for (int32_t i = 0; i < rows_per_buffer_; ++i) { - ShardTuple tupled_buffer; - mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; - if (block_reader_) { - if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; - tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); - } else { - int32_t row_id = buffer_id * rows_per_buffer_ + i; - auto rc = shard_reader_->GetNextById(row_id, worker_id); - task_type = rc.first; - tupled_buffer = rc.second; - if (task_type == mindrecord::TaskType::kPaddedTask) { - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); - tensor_table->push_back(std::move(tensor_row)); - } - if (tupled_buffer.empty()) break; - } - if (task_type == mindrecord::TaskType::kCommonTask) { - for (const auto &tupled_row : tupled_buffer) { - std::vector columns_blob = std::get<0>(tupled_row); - mindrecord::json columns_json = std::get<1>(tupled_row); - TensorRow tensor_row; - RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json, task_type)); - tensor_table->push_back(std::move(tensor_row)); - } - } - } - - // Replace the TensorTable in DataBuffer with the new one. 
- (*fetched_buffer)->set_tensor_table(std::move(tensor_table)); - return Status::OK(); -} - -Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json, const mindrecord::TaskType task_type) { - for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { - auto column_name = columns_to_load_[i_col]; - - // Initialize column parameters - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0; - mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; - uint64_t column_data_type_size = 1; - std::vector column_shape; - - // Get column data - auto shard_column = shard_reader_->GetShardColumn(); - if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { - auto rc = - shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); - if (rc.first != MSRStatus::SUCCESS) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve data type."); - } - if (rc.second == mindrecord::ColumnInRaw) { - auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve raw data from padding sample."); - } - } else if (rc.second == mindrecord::ColumnInBlob) { - if (sample_bytes_.find(column_name) == sample_bytes_.end()) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve blob data from padding sample."); - } - std::string ss(sample_bytes_[column_name]); - n_bytes = ss.size(); - data_ptr = std::make_unique(n_bytes); - std::copy(ss.begin(), ss.end(), data_ptr.get()); - } else { - RETURN_STATUS_UNEXPECTED("Retrieved data type is unknown."); - } - if (data == nullptr) { - data = reinterpret_cast(data_ptr.get()); - } - } else { - auto has_column = - shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, - &column_data_type, &column_data_type_size, &column_shape); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); - } - } - - std::shared_ptr tensor; - const ColDescriptor &column = data_schema_->column(i_col); - DataType type = column.type(); - - // Set shape - auto num_elements = n_bytes / column_data_type_size; - if (type == DataType::DE_STRING) { - std::string s{data, data + n_bytes}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {s}, TensorShape::CreateScalar())); - } else if (column.hasShape()) { - auto new_shape = TensorShape(column.shape()); - RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_elements), &new_shape)); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); - } else { - std::vector shapeDetails = {static_cast(num_elements)}; - auto new_shape = TensorShape(shapeDetails); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); - } - tensor_row->push_back(std::move(tensor)); - } - return Status::OK(); -} - -Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { - { - std::unique_lock lck(mtx_block_reader_); - cv_reader_.wait(lck, [buffer_id, this] { return buffer_id < buffer_water_mark_ + num_workers_; }); - } - for (int32_t i = 0; i < rows_per_buffer_; i++) { - // Block reader does NOT care about argument - auto rc = shard_reader_->GetNextById(i, i); - ShardTuple tuple_buffer = rc.second; - if (tuple_buffer.empty()) break; - block_buffer_[buffer_id % 
num_workers_]->push_back(std::move(tuple_buffer)); - } - return Status::OK(); -} - -// Class functor operator () override. -// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work -// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work -Status MindRecordOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); - num_rows_ = shard_reader_->GetNumRows(); - // Compute how many buffers we would need to accomplish rowsPerBuffer - buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; - - while (true) { // each iterator is 1 epoch - for (int32_t i = 0; i < buffers_needed_; ++i) { - if (block_reader_) RETURN_IF_NOT_OK(FetchBlockBuffer(i)); - std::vector keys(1, i); - RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK(io_blk_queues_[i]->Add( - std::move(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone)))); - } - return Status::OK(); - } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset - RETURN_IF_NOT_OK( - io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - - // reset our buffer count and go to loop again. - RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait()); - shard_reader_wait_post_.Clear(); - } - } -} - -// Overrides base class reset method. When an operator does a reset, it cleans up any state -// info from it's previous execution and then initializes itself so that it can be executed -// again. -Status MindRecordOp::Reset() { - RETURN_IF_NOT_OK(ParallelOp::Reset()); // Call our super class reset first. 
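The master loop in MindRecordOp::operator()() above reduces to a ceil-divide of the row count into buffers followed by a round-robin deal of buffer ids to the per-worker IO-block queues. A minimal sketch of that scheduling, with plain vectors standing in for the queues:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// The ids below are the buffer ids that operator()() would wrap in IOBlocks.
int main() {
  const int64_t num_rows = 10;
  const int64_t rows_per_buffer = 3;
  const int32_t num_workers = 4;

  // Same ceil-divide as buffers_needed_ above: 10 rows / 3 per buffer -> 4 buffers.
  const int64_t buffers_needed = (num_rows + rows_per_buffer - 1) / rows_per_buffer;

  std::vector<std::vector<int64_t>> io_blk_queues(num_workers);
  int64_t buf_cnt = 0;
  for (int64_t i = 0; i < buffers_needed; ++i) {
    io_blk_queues[buf_cnt++ % num_workers].push_back(i);  // round-robin deal
  }

  for (int32_t w = 0; w < num_workers; ++w) {
    std::cout << "worker " << w << ":";
    for (int64_t id : io_blk_queues[w]) std::cout << ' ' << id;
    std::cout << '\n';
  }
}
```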
- - if (block_reader_) { - shard_reader_->Reset(); - buffer_water_mark_ = 0; - } else { - shard_reader_->ShuffleTask(); - } - shard_reader_wait_post_.Set(); - - return Status::OK(); -} - -Status MindRecordOp::LaunchThreadAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("tree_ not set"); - } - - RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); - RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); - if (shard_reader_->Launch(!block_reader_) == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); - } - // Launch main workers that load DataBuffers by reading all images - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - return Status::OK(); -} - -Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count, int64_t num_padded) { - std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); - if (rc == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); - } - return Status::OK(); -} - -// Visitor accept method for NodePass -Status MindRecordOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status MindRecordOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - for (int i = 0; i < static_cast(columns_to_load_.size()); i++) { - column_name_id_map_[columns_to_load_[i]] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h deleted file mode 100644 index af405a8f5b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ /dev/null @@ -1,276 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/common/shard_utils.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; -class DataBuffer; - -using mindrecord::ShardOperator; -using mindrecord::ShardReader; -using ShardTuple = std::vector, mindrecord::json>>; // Row of data from ShardReader - -const int32_t LOG_INTERVAL = 19; - -class MindRecordOp : public ParallelOp { - public: - // The nested builder class inside of the MindRecordOp is used to help manage all of the arguments - // for constructing it. Use the builder by setting each argument with the provided set methods, - // and then finally call the build method to execute the actual construction. - class Builder { - public: - Builder(); - - ~Builder() = default; - - Status Build(std::shared_ptr *); - - Builder &SetRowsPerBuffer(int rows_per_buffer) { - build_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) { - build_num_mind_record_workers_ = num_mind_record_workers; - return *this; - } - - Builder &SetOpConnectorQueueSize(int32_t queue_size) { - build_op_connector_queue_size_ = queue_size; - return *this; - } - - Builder &SetDatasetFile(const std::vector &files) { - build_dataset_file_ = files; - return *this; - } - - Builder &SetColumnsToLoad(const std::vector &columns) { - build_columns_to_load_ = columns; - return *this; - } - - Builder &SetOperators(const std::vector> &operators) { - build_operators_ = operators; - return *this; - } - - Builder &SetBlockReader() { - build_block_reader_ = true; - return *this; - } - - Builder &SetLoadDataset(bool load_dataset) { - build_load_dataset_ = load_dataset; - return *this; - } - - Builder &SetNumToPadSamples(int64_t num_padded) { - build_num_padded_ = num_padded; - return *this; - } - - Builder &SetPaddedSample(const py::handle &sample) { - build_sample_ = sample; - return *this; - } - - Status SanityCheck() const; - - static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } - - mindrecord::json ToJson(const py::handle &obj); - - private: - static constexpr int32_t kDefaultMindRecordWorkers = 4; - // The builder saves all MindRecordOp construction arguments internally. - // The following are the arguments. - int32_t build_num_mind_record_workers_; - int32_t builder_num_workers_; - int32_t build_rows_per_buffer_; - int32_t build_op_connector_queue_size_; - std::vector build_dataset_file_; - bool build_load_dataset_; - std::vector build_columns_to_load_; - std::vector> build_operators_; - bool build_block_reader_; - int64_t build_num_padded_; - py::handle build_sample_; - std::map build_sample_bytes_; - }; - - // Constructor of the MindRecordOp. 
- // @note The builder class should be used to call it - // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) - // @param rows_per_buffer - The requested number of rows per buffer - // @param dataset_file - dataset files - // @param op_connector_queue_size - The output connector queue size - // @param columns_to_load - The list of columns to use (column name) - // @param operators - ShardOperators for Shuffle, Category, Sample - MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, - bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, - const std::vector> &operators, const bool &block_reader, - int64_t num_padded_, const mindrecord::json &sample_json, - const std::map &sample_bytes_); - - // Destructor - ~MindRecordOp() override; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param op - reference to the MindRecordOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) { - op.Print(out, false); - return out; - } - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Class functor operator () override. - // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work. - // @return Status - The error code return - Status operator()() override; - - // Called first when function is called - // @return - Status LaunchThreadAndInitOp(); - - // Overrides base class reset method. When an operator does a reset, it cleans up any state - // info from it's previous execution and then initializes itself so that it can be executed - // again. - // @return Status - The error code return - Status Reset() override; - - // Getter method - int32_t num_rows() const { return num_rows_; } - - static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, - const std::shared_ptr &op, int64_t *count, int64_t num_padded); - - // Getter method - int32_t rows_per_buffer() const { return rows_per_buffer_; } - - // Getter method - std::vector dataset_file() const { return dataset_file_; } - - // Getter method - std::vector columns_to_load() const { return columns_to_load_; } - - bool block_reader() const { return block_reader_; } - - bool load_dataset() const { return load_dataset_; } - - Status Init(); - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. 
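For orientation, this is roughly how the nested builder above is meant to be driven. The file name, columns, and sizes here are invented for illustration; the real call sites live in the bindings rather than in this diff:

```cpp
// Hypothetical call site (not from this diff); assumes the MindRecordOp header
// is included and that Status and MindRecordOp are in scope.
std::shared_ptr<MindRecordOp> op;
Status rc = MindRecordOp::Builder()
              .SetDatasetFile({"/data/imagenet.mindrecord0"})  // invented path
              .SetLoadDataset(true)
              .SetColumnsToLoad({"data", "label"})
              .SetRowsPerBuffer(64)
              .SetNumMindRecordWorkers(4)
              .Build(&op);
```

Each setter returns the builder by reference, so the whole configuration chains into the final Build() call, which runs SanityCheck() before constructing the op.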
- Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MindRecordOp"; } - - private: - Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); - - // Parses a single cell and puts the data into a tensor - // @param tensor_row - the tensor row to put the parsed data in - // @param columns_blob - the blob data received from the reader - // @param columns_json - the data for fields received from the reader - Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, - const mindrecord::json &columns_json, const mindrecord::TaskType task_type); - - Status FetchBlockBuffer(const int32_t &buffer_id); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t rows_per_buffer_; // The number of requested rows per buffer. - std::vector dataset_file_; // dataset files - bool load_dataset_; // load dataset from single file or not - std::vector columns_to_load_; // Columns to load from dataset - std::vector> operators_; // ShardOperators to use - int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader - bool block_reader_; // block reader switch - int32_t buffers_needed_; // Counter for the buffers that were fetched - int64_t buf_cnt_; // Buffer counter - int32_t num_rows_; // One more than the last row id in the range for this cache - std::atomic ended_worker_; - std::atomic buffer_water_mark_; - - int64_t num_padded_; - mindrecord::json sample_json_; - std::map sample_bytes_; - - std::unique_ptr data_schema_; // Data schema for column typing - std::vector columns_blob_; // Blob Columns to load from dataset - std::vector columns_blob_index_; // Blob Columns to load from dataset - - std::unique_ptr shard_reader_; - WaitPost shard_reader_wait_post_; - QueueList> io_blk_queues_; - - // For block reader - std::mutex mtx_block_reader_; - std::condition_variable cv_reader_; - std::vector>> block_buffer_; - std::unordered_set block_set_; - - std::mutex ended_worker_mutex_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc deleted file mode 100644 index 8a75cdc579..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc +++ /dev/null @@ -1,450 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/source/mnist_op.h" - -#include -#include -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -const int32_t kMnistImageFileMagicNumber = 2051; -const int32_t kMnistLabelFileMagicNumber = 2049; -const int32_t kMnistImageRows = 28; -const int32_t kMnistImageCols = 28; - -MnistOp::Builder::Builder() : builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status MnistOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); - TensorShape scalar = TensorShape::CreateScalar(); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); - *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, - builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status MnistOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
-}
-
-MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
-                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size, std::move(sampler)),
-      buf_cnt_(0),
-      row_cnt_(0),
-      folder_path_(folder_path),
-      rows_per_buffer_(rows_per_buffer),
-      data_schema_(std::move(data_schema)) {
-  io_block_queues_.Init(num_workers, queue_size);
-}
-
-Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
-  for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
-    if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
-    keys->push_back(*itr);
-    row_cnt_++;
-    if (row_cnt_ % rows_per_buffer_ == 0) {
-      RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
-        std::make_unique<IOBlock>(IOBlock(*keys, IOBlock::kDeIoBlockNone))));
-      keys->clear();
-    }
-  }
-  return Status::OK();
-}
-
-// functor that contains the main logic of MNIST op
-Status MnistOp::operator()() {
-  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
-  std::unique_ptr<DataBuffer> sampler_buffer;
-  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-  while (true) {  // each iteration is 1 epoch
-    std::vector<int64_t> keys;
-    keys.reserve(rows_per_buffer_);
-    while (sampler_buffer->eoe() == false) {
-      std::shared_ptr<Tensor> sample_ids;
-      RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0));
-      if (sample_ids->type() != DataType(DataType::DE_INT64)) {
-        RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't INT64");
-      }
-      RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys));
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-    if (keys.empty() == false) {
-      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
-        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
-    }
-    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
-      for (int32_t i = 0; i < num_workers_; ++i) {
-        RETURN_IF_NOT_OK(
-          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
-      }
-      return Status::OK();
-    } else {
-      RETURN_IF_NOT_OK(
-        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
-      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
-      wp_.Clear();
-      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
-    }
-  }
-}
-
-// contains the logic of pulling an IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_
-Status MnistOp::WorkerEntry(int32_t worker_id) {
-  TaskManager::FindMe()->Post();
-  int64_t buffer_id = worker_id;
-  std::unique_ptr<IOBlock> iOBlock;
-  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock));
-  while (iOBlock != nullptr) {
-    if (iOBlock->eoe() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
-      buffer_id = worker_id;
-    } else if (iOBlock->eof() == true) {
-      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
-    } else {
-      std::vector<int64_t> keys;
-      RETURN_IF_NOT_OK(iOBlock->GetKeys(&keys));
-      if (keys.empty() == true) return Status::OK();  // empty key is a quit signal for workers
-
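The sampler-to-IOBlock flow above amounts to: drop out-of-range row ids, then cut the survivors into rows_per_buffer-sized chunks, flushing the trailing partial chunk once the sampler is exhausted. A standalone sketch, simplified in one respect: the operator's running row counter persists across sampler buffers, whereas here it restarts per call:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// One inner vector per future IOBlock.
std::vector<std::vector<int64_t>> ChunkSampleIds(const std::vector<int64_t> &sample_ids,
                                                 int64_t num_rows, int64_t rows_per_buffer) {
  std::vector<std::vector<int64_t>> blocks;
  std::vector<int64_t> keys;
  for (int64_t id : sample_ids) {
    if (id >= num_rows) continue;  // index out of bound, skipping
    keys.push_back(id);
    if (static_cast<int64_t>(keys.size()) == rows_per_buffer) {
      blocks.push_back(keys);
      keys.clear();
    }
  }
  if (!keys.empty()) blocks.push_back(keys);  // flush trailing partial chunk
  return blocks;
}

int main() {
  auto blocks = ChunkSampleIds({5, 1, 9, 7, 3, 99, 0}, /*num_rows=*/10, /*rows_per_buffer=*/3);
  for (const auto &block : blocks) {
    for (int64_t key : block) std::cout << key << ' ';
    std::cout << '\n';  // prints "5 1 9" then "7 3 0"; 99 was skipped
  }
}
```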
std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -// Load 1 TensorRow (image,label) using 1 MnistLabelPair. -Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { - std::shared_ptr image, label; - int32_t l = mnist_pair.second; - // make a copy of cached tensor - RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), - mnist_pair.first->type(), mnist_pair.first->GetBuffer())); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), - data_schema_->column(1).type(), reinterpret_cast(&l))); - (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); - return Status::OK(); -} - -// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer -Status MnistOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const int64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -void MnistOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows:" << num_rows_ << "\nMNIST Directory: " << folder_path_ << "\n\n"; - } -} - -// Reset Sampler and wakeup Master thread (functor) -Status MnistOp::Reset() { - RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); // wake up master thread after reset is done - return Status::OK(); -} - -// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows -Status MnistOp::InitSampler() { - RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); - return Status::OK(); -} - -// Derived from RandomAccessOp -Status MnistOp::GetClassIds(std::map> *cls_ids) const { - if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { - RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); - } - for (size_t i = 0; i < image_label_pairs_.size(); ++i) { - (*cls_ids)[image_label_pairs_[i].second].push_back(i); - } - for (auto &pair : (*cls_ids)) { - pair.second.shrink_to_fit(); - } - return Status::OK(); -} - -Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { - uint32_t res = 0; - reader->read(reinterpret_cast(&res), 4); - if (reader->fail()) { - RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file"); - } - *result = SwapEndian(res); - return Status::OK(); -} - -uint32_t MnistOp::SwapEndian(uint32_t val) const { - val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); - return (val << 16) | (val >> 16); -} - -Status 
MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
-  if (image_reader->is_open() == false) {
-    RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name);
-  }
-  int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
-  (void)image_reader->seekg(0, std::ios::beg);
-  // The first 16 bytes of the image file are type, number, row and column
-  if (image_len < 16) {
-    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
-  }
-  uint32_t magic_number;
-  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
-  CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
-                               "This is not the mnist image file: " + file_name);
-
-  uint32_t num_items;
-  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items));
-  uint32_t rows;
-  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &rows));
-  uint32_t cols;
-  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols));
-  // The image size of the Mnist dataset is fixed at [28,28]
-  if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) {
-    RETURN_STATUS_UNEXPECTED("Wrong shape of image.");
-  }
-  if ((image_len - 16) != num_items * rows * cols) {
-    RETURN_STATUS_UNEXPECTED("Wrong number of images.");
-  }
-  *num_images = num_items;
-  return Status::OK();
-}
-
-Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
-  if (label_reader->is_open() == false) {
-    RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name);
-  }
-  int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
-  (void)label_reader->seekg(0, std::ios::beg);
-  // The first 8 bytes of the label file are type and number
-  if (label_len < 8) {
-    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
-  }
-  uint32_t magic_number;
-  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
-  CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
-                               "This is not the mnist label file: " + file_name);
-  uint32_t num_items;
-  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
-  if ((label_len - 8) != num_items) {
-    RETURN_STATUS_UNEXPECTED("Wrong number of labels!");
-  }
-  *num_labels = num_items;
-  return Status::OK();
-}
-
-Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
-  uint32_t num_images, num_labels;
-  RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images));
-  RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels));
-  CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num_images != num_labels");
-  // The image size of the Mnist dataset is fixed at [28,28]
-  int64_t size = kMnistImageRows * kMnistImageCols;
-  auto images_buf = std::make_unique<char[]>(size * num_images);
-  auto labels_buf = std::make_unique<char[]>(num_images);
-  if (images_buf == nullptr || labels_buf == nullptr) {
-    std::string err_msg = "Fail to allocate memory for MNIST Buffer.";
-    MS_LOG(ERROR) << err_msg.c_str();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  (void)image_reader->read(images_buf.get(), size * num_images);
-  if (image_reader->fail()) {
-    RETURN_STATUS_UNEXPECTED("Fail to read:" + image_names_[index] + " size:" + std::to_string(size * num_images));
-  }
-  (void)label_reader->read(labels_buf.get(), num_images);
-  if (label_reader->fail()) {
-    RETURN_STATUS_UNEXPECTED("Fail to read:" + label_names_[index] + " size: " + std::to_string(num_images));
-  }
-  TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1});
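The checks above follow the MNIST idx layout: big-endian uint32 header fields led by magic 2051 for images (16-byte header) or 2049 for labels (8-byte header), which is also why CountTotalRows() further down can count samples from headers alone. A standalone reader sketch, assuming a little-endian host so each field needs the same byte swap as MnistOp::SwapEndian:

```cpp
#include <cstdint>
#include <fstream>
#include <iostream>

// Same bit trick as MnistOp::SwapEndian: swap adjacent bytes, then halves.
static uint32_t SwapEndian(uint32_t val) {
  val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0x00FF00FF);
  return (val << 16) | (val >> 16);
}

static bool ReadBigEndianU32(std::ifstream &in, uint32_t *out) {
  uint32_t raw = 0;
  in.read(reinterpret_cast<char *>(&raw), sizeof(raw));
  if (in.fail()) return false;
  *out = SwapEndian(raw);
  return true;
}

int main(int argc, char **argv) {
  if (argc < 2) return 1;
  std::ifstream in(argv[1], std::ios::binary);
  uint32_t magic = 0, count = 0;
  if (!in.is_open() || !ReadBigEndianU32(in, &magic) || !ReadBigEndianU32(in, &count)) {
    std::cerr << "cannot read idx header\n";
    return 1;
  }
  // 2051 marks an image file; 2049 a label file.
  std::cout << "magic=" << magic << " items=" << count << '\n';
  return (magic == 2051 || magic == 2049) ? 0 : 1;
}
```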
-  for (int64_t j = 0; j != num_images; ++j) {
-    auto pixels = &images_buf[j * size];
-    for (int64_t m = 0; m < size; ++m) {
-      pixels[m] = (pixels[m] == 0) ? 0 : 255;
-    }
-    std::shared_ptr<Tensor> image;
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape,
-                                          data_schema_->column(0).type(), reinterpret_cast<unsigned char *>(pixels)));
-    image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j]));
-  }
-  return Status::OK();
-}
-
-Status MnistOp::ParseMnistData() {
-  for (size_t i = 0; i < image_names_.size(); ++i) {
-    std::ifstream image_reader, label_reader;
-    image_reader.open(image_names_[i], std::ios::binary);
-    label_reader.open(label_names_[i], std::ios::binary);
-
-    Status s = ReadImageAndLabel(&image_reader, &label_reader, i);
-    // Close the readers
-    image_reader.close();
-    label_reader.close();
-    RETURN_IF_NOT_OK(s);
-  }
-  image_label_pairs_.shrink_to_fit();
-  num_rows_ = image_label_pairs_.size();
-  if (num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED(
-      "There is no valid data matching the dataset API MnistDataset. Please check file path or dataset API "
-      "validation first.");
-  }
-  return Status::OK();
-}
-
-Status MnistOp::WalkAllFiles() {
-  const std::string kImageExtension = "idx3-ubyte";
-  const std::string kLabelExtension = "idx1-ubyte";
-
-  Path dir(folder_path_);
-  auto dir_it = Path::DirIterator::OpenDirectory(&dir);
-  if (dir_it != nullptr) {
-    while (dir_it->hasNext()) {
-      Path file = dir_it->next();
-      std::string filename = file.toString();
-      if (filename.find(kImageExtension) != std::string::npos) {
-        image_names_.push_back(filename);
-        MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
-      } else if (filename.find(kLabelExtension) != std::string::npos) {
-        label_names_.push_back(filename);
-        MS_LOG(INFO) << "Mnist operator found label file at " << filename << ".";
-      }
-    }
-  } else {
-    MS_LOG(WARNING) << "Mnist operator unable to open directory " << dir.toString() << ".";
-  }
-
-  std::sort(image_names_.begin(), image_names_.end());
-  std::sort(label_names_.begin(), label_names_.end());
-
-  if (image_names_.size() != label_names_.size()) {
-    RETURN_STATUS_UNEXPECTED("Number of images does not equal number of labels");
-  }
-
-  return Status::OK();
-}
-
-Status MnistOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1)));
-  TaskManager::FindMe()->Post();
-  RETURN_IF_NOT_OK(this->WalkAllFiles());
-  RETURN_IF_NOT_OK(this->ParseMnistData());
-  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with sampler
-  return Status::OK();
-}
-
-Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
-  // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckImage()/CheckLabel()
-  std::shared_ptr<MnistOp> op;
-  *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
-
-  RETURN_IF_NOT_OK(op->WalkAllFiles());
-
-  for (size_t i = 0; i < op->image_names_.size(); ++i) {
-    std::ifstream image_reader;
-    image_reader.open(op->image_names_[i], std::ios::binary);
-    std::ifstream label_reader;
-    label_reader.open(op->label_names_[i], std::ios::binary);
-
-    uint32_t num_images;
-    RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images));
-    uint32_t num_labels;
-
RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels)); - CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num of images does not equal to num of labels"); - *count = *count + num_images; - - // Close the readers - image_reader.close(); - label_reader.close(); - } - - return Status::OK(); -} - -// Visitor accept method for NodePass -Status MnistOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status MnistOp::ComputeColMap() { - // set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h deleted file mode 100644 index e57dc21d60..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using MnistLabelPair = std::pair, int32_t>; - -class MnistOp : public ParallelOp, public RandomAccessOp { - public: - class Builder { - public: - // Constructor for Builder class of MnistOp - // @param uint32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. 
- Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method - // @param const std::string & dir - // @return - Builder &SetDir(const std::string &dir) { - builder_dir_ = dir; - return *this; - } - - // Check validity of input args - // @return - The error code return - Status SanityCheck(); - - // The builder "Build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - std::string builder_dir_; - int32_t builder_num_workers_; - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - }; - - // Constructor - // @param int32_t num_workers - number of workers reading images in parallel - // @param int32_t rows_per_buffer - number of images (rows) in each buffer - // @param std::string folder_path - dir directory of mnist - // @param int32_t queue_size - connector queue size - // @param std::unique_ptr data_schema - the schema of the mnist dataset - // @param td::unique_ptr sampler - sampler tells MnistOp what to read - MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, - std::unique_ptr data_schema, std::shared_ptr sampler); - - // Destructor. - ~MnistOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t worker_id - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of MnistOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // Method derived from RandomAccess Op, enable Sampler to get all ids for each class - // @param (std::map> * map - key label, val all ids for this class - // @return Status - The error code return - Status GetClassIds(std::map> *cls_ids) const override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // Function to count the number of samples in the MNIST dataset - // @param dir path to the MNIST directory - // @param count output arg that will hold the minimum of the actual dataset size and numSamples - // @return - static Status CountTotalRows(const std::string &dir, int64_t *count); - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p Pointer to the NodePass to be accepted - /// \param[out] modified Indicator if the node was changed at all - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "MnistOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to a pair - // @param row_id_type row_id - id for this tensor row - // @param ImageLabelPair pair - - // @param TensorRow row - image & 
label read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Iterate through all members in sampleIds and fill them into IOBlock. - // @param std::shared_ptr sample_ids - - // @param std::vector *keys - keys in ioblock - // @return Status - The error code return - Status TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); - - // Check image file stream. - // @param const std::string *file_name - image file name - // @param std::ifstream *image_reader - image file stream - // @param uint32_t num_images - returns the number of images - // @return Status - The error code return - Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); - - // Check label stream. - // @param const std::string &file_name - label file name - // @param std::ifstream *label_reader - label file stream - // @param uint32_t num_labels - returns the number of labels - // @return Status - The error code return - Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); - - // Read 4 bytes of data from a file stream. - // @param std::ifstream *reader - file stream to read - // @return uint32_t - read out data - Status ReadFromReader(std::ifstream *reader, uint32_t *result); - - // Swap endian - // @param uint32_t val - - // @return uint32_t - swap endian data - uint32_t SwapEndian(uint32_t val) const; - - // Read the specified number of images and labels from the file stream - // @param std::ifstream *image_reader - image file stream - // @param std::ifstream *label_reader - label file stream - // @param int64_t read_num - number of image to read - // @return Status - The error code return - Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); - - // Parse all mnist dataset files - // @return Status - The error code return - Status ParseMnistData(); - - // Read all files in the directory - // @return Status - The error code return - Status WalkAllFiles(); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // reset Op - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. 
- // @return - Status - Status ComputeColMap() override; - - int64_t buf_cnt_; - int64_t row_cnt_; - WaitPost wp_; - std::string folder_path_; // directory of image folder - int32_t rows_per_buffer_; - std::unique_ptr data_schema_; - std::vector image_label_pairs_; - std::vector image_names_; - std::vector label_names_; - QueueList> io_block_queues_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc deleted file mode 100644 index f13de2e5c9..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc +++ /dev/null @@ -1,426 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/datasetops/source/random_data_op.h" -#include -#include -#include "dataset/engine/execution_tree.h" -#include "dataset/core/config_manager.h" -#include "dataset/util/random.h" -#include "dataset/util/wait_post.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -RandomDataOp::Builder::Builder() - : builder_data_schema_(nullptr), - builder_num_workers_(0), - builder_op_connector_size_(0), - builder_rows_per_buffer_(0), - builder_total_rows_(0), - builder_sampler_(nullptr) { - // Some arguments to the RandomDataOp have a default argument that is taken from the config. - // The user may override these defaults by using the builder set methods. - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -// The build method that produces the instantiated RandomDataOp as a shared pointer -Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { - RETURN_IF_NOT_OK(SanityCheck()); - - *out_op = - std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, - builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); - - // If the user did not provide a schema, then we will ask the op to generate a pseudo-random - // schema. - // See details of generateSchema function to learn what type of schema it will create. - if ((*out_op)->data_schema_ == nullptr) { - RETURN_IF_NOT_OK((*out_op)->GenerateSchema()); - } - - return Status::OK(); -} - -// Check if the required parameters are set by the builder. -Status RandomDataOp::Builder::SanityCheck() const { - // There actually is no required arguments for the random data op at all. - // Some arguments are preset with global values from config, and if they are not given by the user - // then we create them randomly. 
Leaving this function here for consistency with other operators. - return Status::OK(); -} - -// Constructor for RandomDataOp -RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, - std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - buffer_id_(0), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_rows), - epoch_buffers_sent_(0), - guys_in_(0), - guys_out_(num_workers_), - eoe_worker_id_(0), - data_schema_(std::move(data_schema)) { - rand_gen_.seed(GetSeed()); // seed the random generator - // If total rows was not given, then randomly pick a number - if (total_rows_ == 0) { - total_rows_ = GenRandomInt(1, kMaxTotalRows); - } - // Everyone is already out from the sync area. - all_out_.Set(); -} - -// A print method typically used for debugging -void RandomDataOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [total rows: " << total_rows_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nTotal_rows: " << total_rows_ << "\nRows per buffer: " << rows_per_buffer_ << "\nSchema:\n" - << *data_schema_ << "\n\n"; - } -} - -// Helper function to produce a default/random schema if one didn't exist -Status RandomDataOp::GenerateSchema() { - if (data_schema_ != nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Generating a schema but one already exists!"); - } - - // To randomly create a schema, we need to choose: - // a) how many columns - // b) the type of each column - // c) the shape of each column (number of dimensions i.e. rank) - // d) the shape of each column (dimension values) - data_schema_ = std::make_unique(); - std::unique_ptr newShape; - std::unique_ptr newCol; - - // Loop over the number of chosen columns - int32_t numColumns = GenRandomInt(1, kMaxNumColumns); - for (int32_t i = 0; i < numColumns; i++) { - // For each column: - // - choose a datatype - // - generate a shape that randomly chooses the number of dimensions and the dimension values. - DataType::Type newType = static_cast(GenRandomInt(1, DataType::NUM_OF_TYPES - 2)); - int32_t rank = GenRandomInt(1, kMaxRank); - std::vector dims; - for (int32_t d = 0; d < rank; d++) { - // 0 is not a valid dimension value. however, we can support "*" or unknown, so map the random - // 0 value to the unknown attribute if 0 is chosen - dsize_t dim_value = static_cast(GenRandomInt(0, kMaxDimValue)); - if (dim_value == 0) dim_value = TensorShape::kDimUnknown; - dims.push_back(dim_value); - } - newShape = std::make_unique(dims); - - // Create the column descriptor - std::string colName = "c" + std::to_string(i); - newCol = std::make_unique(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get()); - - data_schema_->AddColumn(*newCol); - } - - return Status::OK(); -} - -// Class functor operator () override. -// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work. 
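GenerateSchema() above randomizes the column count, each column's type and rank, and each dimension value, mapping a rolled 0 to the unknown ("*") dimension. A simplified model of the shape logic, leaving out the DataSchema/ColDescriptor types (here -1 stands in for TensorShape::kDimUnknown):

```cpp
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937 rng(42);  // the op seeds from GetSeed(); fixed seed for the demo
  auto rand_int = [&rng](int lo, int hi) {
    return std::uniform_int_distribution<int>(lo, hi)(rng);
  };

  constexpr int kMaxNumColumns = 4, kMaxRank = 4, kMaxDimValue = 32;
  const int num_columns = rand_int(1, kMaxNumColumns);
  for (int i = 0; i < num_columns; ++i) {
    std::vector<int64_t> dims(rand_int(1, kMaxRank));  // random rank
    for (auto &d : dims) {
      d = rand_int(0, kMaxDimValue);
      if (d == 0) d = -1;  // 0 is invalid; treat it as an unknown ("*") dimension
    }
    std::cout << "c" << i << ": shape [";
    for (size_t j = 0; j < dims.size(); ++j) std::cout << (j ? "," : "") << dims[j];
    std::cout << "]\n";
  }
}
```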
-Status RandomDataOp::operator()() { - // First, compute how many buffers we'll need to satisfy the total row count. - // The only reason we do this is for the purpose of throttling worker count if needed. - int64_t buffers_needed = total_rows_ / rows_per_buffer_; - if (total_rows_ % rows_per_buffer_ != 0) { - buffers_needed++; - } - - // If the amount of workers we have exceeds the number of buffers to produce, then we'll have - // idle workers doing nothing. In that case, let's throttle the worker count. - if (num_workers_ > buffers_needed) { - MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << "to " << buffers_needed; - num_workers_ = buffers_needed; - num_producers_ = num_workers_; - guys_out_ = num_workers_; - // The output connector was already created with a different worker count. We have to drop and recreate - // that connector. - DatasetOp::CreateConnector(num_producers_, num_workers_); - } - - // Assign the number of rows to each worker in a round robin fashion. - worker_max_rows_.reserve(num_workers_); - worker_rows_packed_.reserve(num_workers_); - // init the counts to zero to start. - for (int32_t w = 0; w < num_workers_; w++) { - worker_max_rows_.push_back(0); - worker_rows_packed_.push_back(0); - } - // then assign round robin row counts - int32_t currentWorker = 0; - for (int64_t r = 0; r < total_rows_; r++) { - worker_max_rows_[currentWorker]++; - currentWorker = (currentWorker + 1) % num_workers_; - } - - // Next, compute the total buffer count. This stat is needed during reset logic - for (int32_t w = 0; w < num_workers_; w++) { - int64_t worker_buffers = 0; - worker_buffers = worker_max_rows_[w] / rows_per_buffer_; - if (worker_max_rows_[w] % rows_per_buffer_ != 0) worker_buffers++; - epoch_buffers_sent_ += worker_buffers; - } - - // For the connector to work, we need to target the correct worker channel for the eoe. - // This will initialize it for the first one. reset() handles for the rest of the epochs. - eoe_worker_id_ = epoch_buffers_sent_ % num_workers_; - epoch_buffers_sent_++; // Add the eoe buffer to the count for subsequent epochs - - // RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits. - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1))); - - // required task group setup after launching workers - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks())); - - return Status::OK(); -} - -// Performs a synchronization between workers at the end of an epoch -Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch"; - - // Sync on the guys_in counter - // We have to wait the last guy is out. - all_out_.Wait(); - // If we are not in a repeat loop, or that was the last repeat already, then setup our exit - // condition from the master loop. - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - *quitting = true; - } - - auto prev = guys_in_.fetch_add(1); - bool last_guy_in = (prev + 1) == num_workers_; - // If we are the last worker to hit this sync point, we have some extra tasks - if (last_guy_in) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. 
eoe sent as worker " - << eoe_worker_id_; - // Prepare for sync - all_out_.Clear(); - // Always flow eoe at the end - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(eoe_worker_id_, std::move(eoe_buffer))); - // If we're done then also flow the eof - if (*quitting) { - // The eof needs to be sent from the next sender in the round robin, so +1 - int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_; - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. sending eof as worker " - << eof_worker_id; - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(eof_worker_id, std::move(eof_buffer))); - } - } - - // Wait for the reset to wake us up if we're not quitting - if (!(*quitting)) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait."; - RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait()); - prev = guys_out_.fetch_add(1); - bool last_guy_out = (prev + 1) == num_workers_; - // Last guy out will clear the wait post and set the row counts - if (last_guy_out) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post."; - epoch_sync_wait_post_.Clear(); - guys_in_ = 0; - all_out_.Set(); - } - } - - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete."; - return Status::OK(); -} - -// The entry point code for when workers are launched -Status RandomDataOp::WorkerEntry(int32_t worker_id) { - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry"; - - // handshake with the master first to tell it we're alive - TaskManager::FindMe()->Post(); - - bool quitting = false; - std::unique_ptr new_tensor_table = nullptr; - - // Loop until the quitting variable gets set to true - do { - // If we have not yet reached the row count for this worker then produce another record - if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) { - TensorRow new_row; - - // Start a new tensor table if needed - if (new_tensor_table == nullptr) { - new_tensor_table = std::make_unique(); - } - - // Create the data for the row - RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row)); - - // Add the row to our table - new_tensor_table->push_back(std::move(new_row)); - worker_rows_packed_[worker_id]++; - - // If the tensor table is at capacity then it's time to send it to output - if (new_tensor_table->size() == rows_per_buffer_) { - RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); - } - } else { - // We've reached the total row count for this worker, so it's time for epoch sync. 
- // There is likely some records built but not sent yet, so take care of those first - // (this buffer will be smaller than rows_per_buffer) - if (new_tensor_table != nullptr && new_tensor_table->size() > 0) { - RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table))); - } - - // Now, let's enter the epoch sync - RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting)); - } - } while (!quitting); - - MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting."; - - return Status::OK(); -} - -// A helper function to stuff the tensor table into a buffer and send it to output connector -Status RandomDataOp::PackAndSend(int32_t worker_id, std::unique_ptr in_table) { - auto new_buffer = std::make_unique(GetNextBufferId(), DataBuffer::kDeBFlagNone); - new_buffer->set_tensor_table(std::move(in_table)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(new_buffer))); - return Status::OK(); -} - -// A helper function to create random data for the row -Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) { - if (new_row == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Missing tensor row output"); - } - - // Create a tensor for each column, then add the tensor to the row - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - const ColDescriptor current_col = data_schema_->column(i); - std::vector current_shape = current_col.shape().AsVector(); - std::unique_ptr new_shape = nullptr; - std::unique_ptr buf = nullptr; - std::shared_ptr new_tensor = nullptr; - - // We need to resolve the shape to fill in any unknown dimensions with random - // values, then use that as our shape for this tensor. - for (int j = 0; j < current_shape.size(); ++j) { - if (current_shape[j] == TensorShape::kDimUnknown) { - current_shape[j] = static_cast(GenRandomInt(1, kMaxDimValue)); - } - } - - new_shape = std::make_unique(current_shape); - int64_t size_in_bytes = new_shape->NumOfElements() * current_col.type().SizeInBytes(); - - // Generate a random byte of data. This may cause some funny data for things like doubles,floats, bools - // however the random data op is not too concerned about the physical data itself. - std::uniform_int_distribution uniDist(0, 255); - uint8_t random_byte = uniDist(rand_gen_); - - // Now, create a chunk of memory for the entire tensor and copy this byte in repeatedly. - buf = std::make_unique(size_in_bytes); - int ret_code = memset_s(buf.get(), size_in_bytes, random_byte, size_in_bytes); - if (ret_code != 0) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor."); - } - - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get())); - - // Add this tensor to the tensor row for output - (*new_row).push_back(std::move(new_tensor)); - } - return Status::OK(); -} - -// Overrides base class reset method. When an operator does a reset, it cleans up any state -// info from it's previous execution and then initializes itself so that it can be executed -// again. 
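CreateRandomRow() above first pins every unknown dimension to a concrete random extent and then fills the whole allocation with a single repeated random byte, since the op only needs data of the right size, not meaningful values. A condensed sketch of that core with plain types (the op itself goes through Tensor::CreateTensor and memset_s):

```cpp
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <vector>

// -1 stands in for TensorShape::kDimUnknown.
std::unique_ptr<uint8_t[]> MakeRandomBuffer(std::vector<int64_t> shape, int64_t type_size,
                                            std::mt19937 *rng, int64_t *out_bytes) {
  std::uniform_int_distribution<int> dim_dist(1, 32);  // mirrors kMaxDimValue
  std::uniform_int_distribution<int> byte_dist(0, 255);
  for (auto &d : shape) {
    if (d < 0) d = dim_dist(*rng);  // pin unknown dims to a concrete extent
  }
  const int64_t elements =
    std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
  *out_bytes = elements * type_size;
  auto buf = std::make_unique<uint8_t[]>(*out_bytes);
  // One random byte repeated across the whole buffer.
  std::memset(buf.get(), byte_dist(*rng), *out_bytes);
  return buf;
}

int main() {
  std::mt19937 rng(0);
  int64_t nbytes = 0;
  auto buf = MakeRandomBuffer({2, -1, 3}, /*type_size=*/4, &rng, &nbytes);
  std::cout << nbytes << " bytes, first byte " << static_cast<int>(buf[0]) << "\n";
}
```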
-Status RandomDataOp::Reset() { - MS_LOG(INFO) << "RandomDataOp resetting."; - - // Ensure all guys are in the waitpost - if (guys_in_ != num_workers_) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "Issuing a reset, but some workers are missing from epochSync!"); - } - - // reset the row counters for all workers - for (int32_t w = 0; w < num_workers_; w++) { - worker_rows_packed_[w] = 0; - worker_max_rows_[w] = 0; - } - buffer_id_ = 0; - - // Re-assign round robin row counts, starting from the worker after the one that gave - // the eoe last time - int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_; - for (int64_t r = 0; r < total_rows_; r++) { - worker_max_rows_[currentWorker]++; - currentWorker = (currentWorker + 1) % num_workers_; - } - - // Compute which worker should get the eoe for the next epoch - eoe_worker_id_ = ((epoch_buffers_sent_ % num_workers_) + eoe_worker_id_) % num_workers_; - - // Wake up the workers to get them going again in a new epoch - guys_out_ = 0; - epoch_sync_wait_post_.Set(); - - return Status::OK(); -} - -// Visitor accept method for NodePass -Status RandomDataOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status RandomDataOp::ComputeColMap() { - // Extract the column name mapping from the schema and save it in the class. - if (column_name_id_map_.empty()) { - RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_))); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h deleted file mode 100644 index 76d781ee1c..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h +++ /dev/null @@ -1,291 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_ - -#include -#include -#include -#include -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/core/data_type.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// The RandomDataOp is a leaf node storage operator that generates random data based -// on the schema specifications. Typically, it's used for testing and demonstrating -// various dataset operator pipelines. It is not "real" data to train with. -// The data that is random created is just random and repeated bytes, there is no -// "meaning" behind what these bytes are. -class RandomDataOp : public ParallelOp { - public: - // Some constants to provide limits to random generation. 
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h
deleted file mode 100644
index 76d781ee1c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h
+++ /dev/null
@@ -1,291 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include "dataset/util/status.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/data_type.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/parallel_op.h"
-#include "dataset/util/wait_post.h"
-
-namespace mindspore {
-namespace dataset {
-// The RandomDataOp is a leaf node storage operator that generates random data based
-// on the schema specifications. Typically, it's used for testing and demonstrating
-// various dataset operator pipelines. It is not "real" data to train with.
-// The randomly created data is just random, repeated bytes; there is no
-// "meaning" behind what these bytes are.
-class RandomDataOp : public ParallelOp {
- public:
-  // Some constants to provide limits to random generation.
-  static constexpr int32_t kMaxNumColumns = 4;
-  static constexpr int32_t kMaxRank = 4;
-  static constexpr int32_t kMaxDimValue = 32;
-  static constexpr int32_t kMaxTotalRows = 1024;
-
-  // A nested builder class to aid in the construction of a RandomDataOp
-  class Builder {
-   public:
-    /**
-     * Builder constructor. Creates the builder object.
-     * @note No default args.
-     * @return This is a constructor.
-     */
-    Builder();
-
-    /**
-     * Default destructor
-     */
-    ~Builder() = default;
-
-    /**
-     * The build method that produces the instantiated RandomDataOp as a shared pointer
-     * @param out_op - The output RandomDataOp that was constructed
-     * @return Status - The error code return
-     */
-    Status Build(std::shared_ptr<RandomDataOp> *out_op);
-
-    /**
-     * Builder set method
-     * @param data_schema - A user-provided schema
-     * @return Builder - The modified builder by reference
-     */
-    Builder &SetDataSchema(std::unique_ptr<DataSchema> data_schema) {
-      builder_data_schema_ = std::move(data_schema);
-      return *this;
-    }
-
-    /**
-     * Builder set method
-     * @param num_workers - The number of workers
-     * @return Builder - The modified builder by reference
-     */
-    Builder &SetNumWorkers(int32_t num_workers) {
-      builder_num_workers_ = num_workers;
-      return *this;
-    }
-
-    /**
-     * Builder set method
-     * @param op_connector_size - The size of the output connector
-     * @return Builder - The modified builder by reference
-     */
-    Builder &SetOpConnectorSize(int32_t op_connector_size) {
-      builder_op_connector_size_ = op_connector_size;
-      return *this;
-    }
-
-    /**
-     * Builder set method
-     * @param rows_per_buffer - The number of rows in each DataBuffer
-     * @return Builder - The modified builder by reference
-     */
-    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
-      builder_rows_per_buffer_ = rows_per_buffer;
-      return *this;
-    }
-
-    /**
-     * Builder set method
-     * @param total_rows - The total number of rows in the dataset
-     * @return Builder - The modified builder by reference
-     */
-    Builder &SetTotalRows(int64_t total_rows) {
-      builder_total_rows_ = total_rows;
-      return *this;
-    }
-
-    // Setter method
-    // @param std::shared_ptr<Sampler> sampler
-    // @return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
-      builder_sampler_ = std::move(sampler);
-      return *this;
-    }
-
-   private:
-    /**
-     * Check if the required parameters are set by the builder.
-     * @return Status - The error code return
-     */
-    Status SanityCheck() const;
-
-    std::unique_ptr<DataSchema> builder_data_schema_;
-    std::shared_ptr<Sampler> builder_sampler_;
-    int32_t builder_num_workers_;
-    int32_t builder_op_connector_size_;
-    int64_t builder_rows_per_buffer_;
-    int64_t builder_total_rows_;
-  };  // class Builder
-
-  /**
-   * Constructor for RandomDataOp
-   * @note Private constructor. Must use builder to construct.
-   * @param num_workers - The number of workers
-   * @param op_connector_size - The size of the output connector
-   * @param rows_per_buffer - The number of rows in each DataBuffer
-   * @param data_schema - A user-provided schema
-   * @param total_rows - The total number of rows in the dataset
-   * @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
-   */
-  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
-               std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
-
-  /**
-   * Destructor
-   */
-  ~RandomDataOp() = default;
-
-  /**
-   * A print method typically used for debugging
-   * @param out - The output stream to write output to
-   * @param show_all - A bool to control if you want to show all info or just a summary
-   */
-  void Print(std::ostream &out, bool show_all) const override;
-
-  /**
-   * << Stream output operator overload
-   * @notes This allows you to write the debug print info using stream operators
-   * @param out - reference to the output stream being overloaded
-   * @param op - reference to the RandomDataOp to display
-   * @return - the output stream must be returned
-   */
-  friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) {
-    op.Print(out, false);
-    return out;
-  }
-
-  /**
-   * Class functor operator () override.
-   * All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
-   * provide the master loop that drives the logic for performing the work.
-   * @return Status - The error code return
-   */
-  Status operator()() override;
-
-  /**
-   * Overrides base class reset method. When an operator does a reset, it cleans up any state
-   * info from its previous execution and then initializes itself so that it can be executed
-   * again.
-   * @return Status - The error code return
-   */
-  Status Reset() override;
-
-  /**
-   * Quick getter for total rows.
-   */
-  int64_t GetTotalRows() const { return total_rows_; }
-
-  // Op name getter
-  // @return Name of the current Op
-  std::string Name() const override { return "RandomDataOp"; }
-
- private:
-  /**
-   * The entry point code for when workers are launched
-   * @param worker_id - The worker id
-   * @return Status - The error code return
-   */
-  Status WorkerEntry(int32_t worker_id) override;
-
-  /**
-   * Helper function to produce a default/random schema if one doesn't exist
-   * @return Status - The error code return
-   */
-  Status GenerateSchema();
-
-  /**
-   * Performs a synchronization between workers at the end of an epoch
-   * @param worker_id - The worker id
-   * @return Status - The error code return
-   */
-  Status EpochSync(int32_t worker_id, bool *quitting);
-
-  /**
-   * A helper function to stuff the tensor table into a buffer and send it to the output connector
-   * @param worker_id - The worker id
-   * @param in_table - The tensor table to pack and send
-   * @return Status - The error code return
-   */
-  Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorTable> in_table);
-
-  /**
-   * A helper function to create random data for the row
-   * @param worker_id - The worker id
-   * @param new_row - The output row to produce
-   * @return Status - The error code return
-   */
-  Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);
-
-  /**
-   * A quick inline for producing a random number between (and including) min/max
-   * @param min - minimum number that can be generated
-   * @param max - maximum number that can be generated
-   * @return - The generated random number
-   */
-  inline int32_t GenRandomInt(int32_t min, int32_t max) {
-    std::uniform_int_distribution<int32_t> uniDist(min, max);
-    return uniDist(rand_gen_);
-  }
-
-  /**
-   * A quick inline for producing the next buffer id in sequence, threadsafe
-   * @return - The next buffer id.
-   */
-  inline int32_t GetNextBufferId() {
-    std::unique_lock<std::mutex> lock(buffer_id_mutex_);
-    return ++buffer_id_;
-  }
-
-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
-  Status Accept(NodePass *p, bool *modified) override;
-
-  // Private function for computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-
-  int32_t buffer_id_;
-  int64_t rows_per_buffer_;
-  int64_t total_rows_;
-  int64_t epoch_buffers_sent_;
-  std::atomic<int32_t> guys_in_;
-  std::atomic<int32_t> guys_out_;
-  int32_t eoe_worker_id_;
-  std::unique_ptr<DataSchema> data_schema_;
-  std::vector<int64_t> worker_max_rows_;
-  std::vector<int64_t> worker_rows_packed_;
-  std::mt19937 rand_gen_;
-  WaitPost epoch_sync_wait_post_;
-  WaitPost all_out_;
-  std::mutex buffer_id_mutex_;
-};  // class RandomDataOp
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
deleted file mode 100644
index 9f4a9cf55c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
+++ /dev/null
@@ -1,119 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-
-#include
-#include
-
-#include "dataset/engine/data_buffer.h"
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
-                                       uint32_t seed)
-    : Sampler(num_samples, std::numeric_limits<int64_t>::max()),
-      cnt_(0),
-      seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
-      device_id_(dev_id),
-      num_devices_(num_dev),
-      shuffle_(shuffle) {}
-
-Status DistributedSampler::InitSampler() {
-  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
-  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
-  if (num_samples_ == 0 || num_samples_ > num_rows_) {
-    num_samples_ = num_rows_;
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
-  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
-  CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
-                               "fail to init DistributedSampler");
-  rnd_.seed(seed_++);
-  samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_;  // i.e., ceil(num_rows / num_devices)
-  samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
-  if (shuffle_ == true) {
-    shuffle_vec_.reserve(num_rows_);
-    for (int64_t i = 0; i < num_rows_; i++) {
-      shuffle_vec_.push_back(i);
-    }
-    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
-  }
-  return Status::OK();
-}
-
-Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
-  if (cnt_ > samples_per_buffer_) {
-    RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
-  } else if (cnt_ == samples_per_buffer_) {
-    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-  } else {
-    if (HasChildSampler()) {
-      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
-    }
-
-    (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
-    std::shared_ptr<Tensor> sample_ids;
-    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
-    auto id_ptr = sample_ids->begin<int64_t>();
-    while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end<int64_t>()) {
-      int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
-      if (shuffle_) {
-        sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
-      }
-
-      if (HasChildSampler()) {
-        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
-      }
-
-      *id_ptr = sampled_id;
-      id_ptr++;
-      cnt_++;
-    }
-    TensorRow row(1, sample_ids);
-    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
-  }
-  return Status::OK();
-}
-
-Status DistributedSampler::ResetSampler() {
-  CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
-  cnt_ = 0;
-
-  if (shuffle_ == true) {
-    rnd_.seed(seed_);
-    seed_++;
-    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
-  }
-
-  if (HasChildSampler()) {
-    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
-  }
-
-  return Status::OK();
-}
-
-void DistributedSampler::Print(std::ostream &out, bool show_all) const {
-  out << "\nSampler: DistributedSampler";
-  if (show_all) {
-    Sampler::Print(out, show_all);
-    out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
-        << "\nshuffle: " << shuffle_;
-  }
-}
-
-}  // namespace dataset
-}  // namespace mindspore
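Editor's note: the modulo line above is the whole sharding story: each device walks the row space with stride num_devices_, offset by its device_id_, and draws ceil(num_rows / num_devices) ids, wrapping around when num_rows is not evenly divisible so every shard has equal length. A toy sketch, independent of the surrounding classes:

#include <cstdint>
#include <vector>

// Compute the ids one device would draw, mirroring DistributedSampler:
// stride num_devices, offset device_id, ceil-divided shard length, modulo wrap.
std::vector<int64_t> ShardIds(int64_t num_rows, int64_t num_devices, int64_t device_id) {
  int64_t per_device = (num_rows + num_devices - 1) / num_devices;  // ceil(num_rows / num_devices)
  std::vector<int64_t> ids;
  ids.reserve(per_device);
  for (int64_t cnt = 0; cnt < per_device; ++cnt) {
    ids.push_back((num_devices * cnt + device_id) % num_rows);
  }
  return ids;  // e.g. num_rows=5, num_devices=2, device_id=1 -> {1, 3, 0}
}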
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
deleted file mode 100644
index 7083580c6c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/distributed_sampler.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
-
-#include
-#include
-#include
-#include
-
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-
-namespace mindspore {
-namespace dataset {
-class DistributedSampler : public Sampler {
- public:
-  // @param num_samples
-  // @param int64_t num_dev
-  // @param int64_t dev_id
-  // @param bool shuffle
-  DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
-                     uint32_t seed = std::numeric_limits<uint32_t>::max());
-
-  // default destructor
-  ~DistributedSampler() = default;
-
-  // @param std::unique_ptr<DataBuffer> * pBuffer
-  // @param int32_t workerId
-  // @return - The error code return
-  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
-
-  // Init sampler, called by base class or python
-  Status InitSampler() override;
-
-  // for next epoch of sampleIds
-  // @return - The error code return
-  Status ResetSampler() override;
-
-  void Print(std::ostream &out, bool show_all) const override;
-
- private:
-  int64_t cnt_;  // number of samples that have already been filled in to buffer
-  uint32_t seed_;
-  int64_t device_id_;
-  int64_t num_devices_;
-  bool shuffle_;
-  std::mt19937 rnd_;
-  std::vector<int64_t> shuffle_vec_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
deleted file mode 100644
index cd2cadb9ff..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.cc
+++ /dev/null
@@ -1,125 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
-#include
-#include
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
-    : Sampler(num_samples, samples_per_buffer),
-      shuffle_(shuffle),
-      seed_(GetSeed()),
-      next_id_(0),
-      samples_per_class_(val) {}
-
-Status PKSampler::InitSampler() {
-  labels_.reserve(label_to_ids_.size());
-  for (const auto &pair : label_to_ids_) {
-    if (pair.second.empty() == false) {
-      labels_.push_back(pair.first);
-    }
-  }
-  rnd_.seed(seed_++);
-
-  // The special handshake gives the list of classes and id's, but it does not set num_rows_ to
-  // capture the total number of possible sample ids.
-  // Compute that here for this case to find the total number of samples that are available to return.
-  // (in this case, samples per class * total classes).
-  num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
-
-  // The user may have chosen to sample less than the total amount.
-  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
- // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - - samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; - if (shuffle_ == true) { - std::shuffle(labels_.begin(), labels_.end(), rnd_); - } else { - std::sort(labels_.begin(), labels_.end()); - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive"); - return Status::OK(); -} - -Status PKSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (next_id_ > num_samples_ || num_samples_ == 0) { - RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); - } else if (next_id_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); - std::shared_ptr sample_ids; - int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; - RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); - auto id_ptr = sample_ids->begin(); - while (next_id_ < last_id && id_ptr != sample_ids->end()) { - int64_t cls_id = next_id_++ / samples_per_class_; - const std::vector &samples = label_to_ids_[labels_[cls_id]]; - int64_t rnd_ind = std::uniform_int_distribution(0, samples.size() - 1)(rnd_); - int64_t sampled_id = samples[rnd_ind]; - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *id_ptr = sampled_id; - id_ptr++; - } - - TensorRow row(1, sample_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status PKSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); - next_id_ = 0; - rnd_.seed(seed_++); - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { - RETURN_UNEXPECTED_IF_NULL(op); - RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_)); - RETURN_IF_NOT_OK(InitSampler()); - return Status::OK(); -} - -void PKSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: PKSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h deleted file mode 100644 index cde8a75b5b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/pk_sampler.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-
-namespace mindspore {
-namespace dataset {
-class PKSampler : public Sampler {  // NOT YET FINISHED
- public:
-  // @param num_samples - the number of samples to draw. A value of 0 means to take the full amount
-  // @param int64_t val
-  // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
-  // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
-                     int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
-
-  // default destructor
-  ~PKSampler() = default;
-
-  // Op calls this to get next Buffer that contains all the sampleIds
-  // @param std::unique_ptr<DataBuffer> * pBuffer
-  // @param int32_t workerId
-  // @return - The error code return
-  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
-
-  // first handshake between leaf source op and Sampler. This func will determine the amount of data
-  // in the dataset that we can sample from.
-  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
-  // @return
-  Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
-
-  // init sampler, to be called by python or Handshake
-  Status InitSampler() override;
-
-  // for next epoch of sampleIds
-  // @return - The error code return
-  Status ResetSampler() override;
-
-  // Printer for debugging purposes.
-  // @param out - output stream to write to
-  // @param show_all - bool to show detailed vs summary
-  void Print(std::ostream &out, bool show_all) const override;
-
- private:
-  bool shuffle_;
-  uint32_t seed_;
-  int64_t next_id_;
-  int64_t samples_per_class_;
-  std::mt19937 rnd_;
-  std::vector<int64_t> labels_;
-  std::map<int32_t, std::vector<int64_t>> label_to_ids_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
deleted file mode 100644
index d204c55ce9..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
+++ /dev/null
@@ -1,116 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#include "dataset/engine/datasetops/source/sampler/python_sampler.h" - -#include - -namespace mindspore { -namespace dataset { - -PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} - -Status PythonSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (need_to_reset_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - std::shared_ptr sample_ids; - { - py::gil_scoped_acquire gil_acquire; - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py::object py_ret = py_sampler_instance.attr("_get_indices")(); - py::array np_sample_ids = py_ret.cast(); - Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor - - if (HasChildSampler()) { - for (auto it = sample_ids->begin(); it != sample_ids->end(); ++it) { - int64_t associated_child_id = 0; - RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, associated_child_id)); - *it = associated_child_id; - } - } - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } catch (const py::cast_error &e) { - return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); - } - } - TensorRow row(1, sample_ids); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - need_to_reset_ = true; - } - return Status::OK(); -} - -Status PythonSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. 
- if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - { - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - } - return Status::OK(); -} - -Status PythonSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); - need_to_reset_ = false; - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - } - try { - py_sampler_instance.attr("reset")(); - } catch (const py::error_already_set &e) { - return Status(StatusCode::kPyFuncException, e.what()); - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void PythonSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: PythonSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h deleted file mode 100644 index 7d653b2087..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ - -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class PythonSampler : public Sampler { - public: - // Constructor - // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the - // data from the dataset. - // @param py_sampler_instance - the python instance of the sampler - // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~PythonSampler() = default; - - // Initialize the sampler. 
- // @return Status - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - // Op calls this to get next Buffer that contains all the sampleIds - // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op - // @param int32_t workerId - not meant to be used - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() - - py::object py_sampler_instance; // The handle to the py_sampler python object -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc deleted file mode 100644 index db0a96ea3a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" - -#include -#include -#include -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), - seed_(GetSeed()), - replacement_(replacement), - next_id_(0), - reshuffle_each_epoch_(reshuffle_each_epoch), - dist(nullptr) {} - -Status RandomSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (next_id_ > num_samples_) { - RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error"); - } else if (next_id_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - (*out_buffer) = std::make_unique(next_id_, DataBuffer::kDeBFlagNone); - - std::shared_ptr sampleIds; - int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_); - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_)); - auto id_ptr = sampleIds->begin(); - - for (int64_t i = 0; i < (last_id - next_id_); i++) { - int64_t sampled_id = 0; - if (replacement_) { - sampled_id = (*dist)(rnd_); - } else { - sampled_id = shuffled_ids_[static_cast(i + next_id_)]; - } - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *(id_ptr + i) = sampled_id; - } - next_id_ = last_id; - TensorRow row(1, sampleIds); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status RandomSampler::InitSampler() { - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? 
-  rnd_.seed(seed_);
-
-  if (replacement_ == false) {
-    shuffled_ids_.reserve(num_rows_);
-    for (int64_t i = 0; i < num_rows_; i++) {
-      shuffled_ids_.push_back(i);
-    }
-    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
-  } else {
-    dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
-  }
-
-  return Status::OK();
-}
-
-Status RandomSampler::ResetSampler() {
-  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
-  next_id_ = 0;
-
-  if (reshuffle_each_epoch_) {
-    seed_++;
-  }
-
-  rnd_.seed(seed_);
-
-  if (replacement_ == false && reshuffle_each_epoch_) {
-    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
-  }
-
-  if (HasChildSampler()) {
-    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
-  }
-
-  return Status::OK();
-}
-
-void RandomSampler::Print(std::ostream &out, bool show_all) const {
-  out << "\nSampler: RandomSampler";
-  if (show_all) {
-    // Call the super class for displaying any common detailed info
-    Sampler::Print(out, show_all);
-    // Then add our own info if any
-  }
-}
-}  // namespace dataset
-}  // namespace mindspore
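Editor's note: InitSampler above picks one of two strategies: without replacement it pre-shuffles the full id list once and later reads it in order; with replacement it samples each id independently from a uniform distribution. A compact illustration of the two paths using only standard-library pieces (function name is illustrative; assumes num_rows > 0):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Draw `num_samples` ids from [0, num_rows), with or without replacement --
// the same two strategies RandomSampler::InitSampler() sets up.
std::vector<int64_t> DrawIds(int64_t num_rows, int64_t num_samples, bool replacement, std::mt19937 *rng) {
  std::vector<int64_t> out;
  out.reserve(num_samples);
  if (replacement) {
    std::uniform_int_distribution<int64_t> dist(0, num_rows - 1);
    for (int64_t i = 0; i < num_samples; ++i) out.push_back(dist(*rng));
  } else {
    std::vector<int64_t> ids(num_rows);
    std::iota(ids.begin(), ids.end(), 0);        // 0, 1, ..., num_rows - 1
    std::shuffle(ids.begin(), ids.end(), *rng);  // shuffle everything, then take the first num_samples
    out.assign(ids.begin(), ids.begin() + std::min(num_samples, num_rows));
  }
  return out;
}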
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
deleted file mode 100644
index b1c54eb98c..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/random_sampler.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
-
-#include
-#include
-#include
-
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-
-namespace mindspore {
-namespace dataset {
-class RandomSampler : public Sampler {
- public:
-  // Constructor
-  // @param int64_t num_samples - number of samples to draw
-  // @param bool replacement - put the id back (or not) after a sample
-  // @param reshuffle_each_epoch - T/F to reshuffle after epoch
-  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
-                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
-
-  // Destructor.
-  ~RandomSampler() = default;
-
-  // Op calls this to get next Buffer that contains all the sampleIds
-  // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
-  // @param int32_t workerId - not meant to be used
-  // @return - The error code return
-  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
-
-  // meant to be called by base class or python
-  Status InitSampler() override;
-
-  // for next epoch of sampleIds
-  // @return - The error code return
-  Status ResetSampler() override;
-
-  void Print(std::ostream &out, bool show_all) const override;
-
- private:
-  uint32_t seed_;
-  bool replacement_;
-  std::vector<int64_t> shuffled_ids_;  // only used for NO REPLACEMENT
-  int64_t next_id_;
-  std::mt19937 rnd_;
-  std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
-  bool reshuffle_each_epoch_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
deleted file mode 100644
index 5f0ffd8855..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+++ /dev/null
@@ -1,178 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-
-#include
-
-namespace mindspore {
-namespace dataset {
-Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
-  // The sampler base class itself does not compute its own num_rows_ value.
-  // Instead, this value is computed by the derived leaf op during its own initialization
-  // after it has interacted with its storage layers.
-  // Here, it is just a getter method to return the value. However, it is invalid if there is
-  // not a value set for this count, so generate a failure if that is the case.
-  if (num == nullptr || num_rows_ == 0) {
-    RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed its num rows yet.");
-  }
-  (*num) = num_rows_;
-  return Status::OK();
-}
-
-Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
-    : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
-
-Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
-  std::shared_ptr<Sampler> child_sampler;
-  if (HasChildSampler()) {
-    child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
-    if (!child_sampler) {
-      std::string err_msg("Cannot handshake, child is not a sampler object.");
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-
-    // Handshake and init child first.
-    RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
-  }
-
-  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
-
-  // If there's a child sampler, set the row count to be its sample count
-  if (HasChildSampler()) {
-    num_rows_ = child_sampler->num_samples_;
-  } else {
-    RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
-  }
-
-  // It's up to the derived class to check the validity of the two args,
-  // because some samplers only need one of the args (weighted_random_sampler)
-  RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
-
-  return Status::OK();
-}
-
-Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
-  if (num_elements == 0) {
-    RETURN_STATUS_UNEXPECTED("num of Elements is 0");
-  }
-  if (col_desc_ == nullptr) {
-    // a ColDescriptor for Tensor that holds SampleIds
-    col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
-  }
-  TensorShape shape(std::vector<dsize_t>(1, num_elements));
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type()));
-  RETURN_IF_NOT_OK(
-    (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()));  // allocate memory in case user forgets!
-  return Status::OK();
-}
-
-void Sampler::Print(std::ostream &out, bool show_all) const {
-  // Sampler printing is usually only called in the show_all mode.
-  // Derived classes will display the name, then call back to this base
-  // for common info.
-  // No-op in the summary mode.
-  if (show_all) {
-    out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
-  }
-}
-
-#ifdef ENABLE_PYTHON
-Status Sampler::GetAllIdsThenReset(py::array *data) {
-  std::unique_ptr<DataBuffer> db;
-  std::shared_ptr<Tensor> sample_ids;
-  TensorRow sample_row;
-
-  // A call to derived class to get sample ids wrapped inside a buffer
-  RETURN_IF_NOT_OK(GetNextSample(&db));
-  // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
-  RETURN_IF_NOT_OK(db->GetRow(0, &sample_row));
-  sample_ids = sample_row[0];
-
-  // check this buffer is not a ctrl buffer
-  CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
-  {
-    py::gil_scoped_acquire gil_acquire;
-    if (Py_IsInitialized() == 0) {
-      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
-    }
-    try {
-      RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
-    } catch (const std::runtime_error &e) {
-      return Status(StatusCode::kPyFuncException, e.what());
-    }
-  }
-  // Perform error checking! The next buffer is supposed to be EOE since the last one already contains all ids
-  // for the current epoch.
-  RETURN_IF_NOT_OK(GetNextSample(&db));
-  CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
-  // Reset Sampler since this is the end of the epoch
-  RETURN_IF_NOT_OK(ResetSampler());
-  return Status::OK();
-}
-#endif
-
-Status Sampler::SetNumSamples(int64_t num_samples) {
-  CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
-  num_samples_ = num_samples;
-  return Status::OK();
-}
-
-Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
-  CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
-  num_rows_ = num_rows;
-  return Status::OK();
-}
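Editor's note: GetAllIdsThenReset above encodes the sampler's per-epoch contract: one buffer carrying every id, followed by an EOE control buffer, after which the sampler must be reset before it can be reused. A sketch of that protocol as a generic drain loop; IdBuffer and SamplerLike are hypothetical stand-ins for DataBuffer and Sampler, not types from the source:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical minimal view of the buffer/sampler interfaces, just for the control flow.
struct IdBuffer {
  bool eoe = false;          // control buffer marking the end of the epoch
  std::vector<int64_t> ids;  // sample ids when eoe is false
};
struct SamplerLike {
  virtual IdBuffer GetNext() = 0;
  virtual void Reset() = 0;
  virtual ~SamplerLike() = default;
};

// Drain one epoch under the contract above: an ids buffer, then EOE, then reset.
std::vector<int64_t> DrainEpoch(SamplerLike *sampler) {
  IdBuffer first = sampler->GetNext();
  if (first.eoe) throw std::runtime_error("ERROR ctrl buffer received");  // must carry ids
  IdBuffer second = sampler->GetNext();
  if (!second.eoe) throw std::runtime_error("ERROR Non EOE received");    // must be the EOE marker
  sampler->Reset();  // sampler is only reusable after the reset
  return first.ids;
}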
-Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
-  if (child == nullptr) {
-    return Status::OK();
-  }
-
-  // Only samplers can be added, not any other DatasetOp.
-  std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
-  if (!sampler) {
-    std::string err_msg("Cannot add child, child is not a sampler object.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  // Samplers can have at most 1 child.
-  if (!child_.empty()) {
-    std::string err_msg("Cannot add child sampler, this sampler already has a child.");
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  child_.push_back(child);
-
-  // doesn't work, protected?
-  // child->AddParent(this);
-  return Status::OK();
-}
-
-bool Sampler::HasChildSampler() { return !child_.empty(); }
-
-Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
-  if (child_ids_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
-  }
-
-  TensorRow sample_row;
-  RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
-  std::shared_ptr<Tensor> sample_ids = sample_row[0];
-  RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
-  return Status::OK();
-}
-
-}  // namespace dataset
-}  // namespace mindspore
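Editor's note: GetAssociatedChildId above is the glue for sampler composition: when samplers are chained, the parent's "sampled id" is not a dataset row but an index into the id list its child produced for the epoch. A minimal sketch of that indirection (function name is illustrative):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Resolve a parent-sampler id through the child's epoch output, as
// GetAssociatedChildId does: the parent's id indexes the child's id list.
int64_t ResolveThroughChild(const std::vector<int64_t> &child_ids, int64_t parent_id) {
  if (parent_id < 0 || static_cast<size_t>(parent_id) >= child_ids.size()) {
    throw std::out_of_range("parent id out of range of child sampler output");
  }
  return child_ids[static_cast<size_t>(parent_id)];
}

// Example: a sequential parent over a shuffled child reproduces the child's order:
// child_ids = {4, 2, 0, 3, 1}; parent ids 0..4 resolve to 4, 2, 0, 3, 1.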
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
deleted file mode 100644
index d9da777a48..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.h
+++ /dev/null
@@ -1,161 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/engine/datasetops/dataset_op.h"
-
-namespace mindspore {
-namespace dataset {
-// RandomAccessOp is a base class that all data-producing leaf operators
-// must inherit from if those leaf operators wish to support sampling.
-class RandomAccessOp {
- public:
-  // Sampler gets the number of rows in the dataset
-  // @param int64_t num - return number of rows for this dataset
-  // @return - The error code return
-  Status GetNumRowsInDataset(int64_t *num_rows) const;
-
-  // sampler gets label, imageIds from corresponding Dataset Op; this function is unique to PK
-  // @param std::map<int32_t, std::vector<int64_t>> *map
-  // @return - The error code return
-  virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
-    RETURN_STATUS_UNEXPECTED("GetClassIds needs to be overridden to support PK");
-  }
-
-  // default destructor
-  virtual ~RandomAccessOp() = default;
-
- protected:
-  // The amount of rows in the dataset itself. This is the before-sampling value, the
-  // total count of rows. A sampler may choose to sample less than this amount.
-  int64_t num_rows_;
-};
-
-class Sampler {
- public:
-  // Constructor
-  // @param int64_t num_samples: the user-requested number of sample ids to generate. A value of 0
-  //     indicates that the sampler should produce the complete set of ids.
-  // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
-  explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
-
-  Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
-
-  // default destructor
-  ~Sampler() = default;
-
-  // Get a list of sample ids.
-  // @note It is the Sampler's responsibility to make sure that the id is not out of bound.
-  // @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
-  // @param int32_t workerId - not meant to be used
-  // @return - The error code return
-  virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
-
-// This function is only called by the python layer. Not needed by Android.
-#ifdef ENABLE_PYTHON
-  // return all ids in one epoch as a numpy array, then call reset
-  Status GetAllIdsThenReset(py::array *data);
-#endif
-
-  // for next epoch of sampleIds
-  // @return - The error code return
-  virtual Status ResetSampler() = 0;
-
-  // first handshake between leaf source op and Sampler. This func will determine the amount of data
-  // in the dataset that we can sample from.
-  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
-  // @return
-  virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
-
-  // initialize sampler and perform checks on certain vars
-  virtual Status InitSampler() { return Status::OK(); }
-
-  // setter for num samples
-  // @param num_samples - the number of samples to assign.
-  // @return status error code
-  Status SetNumSamples(int64_t num_samples);
-
-  // setter for num of records in the dataset
-  // @param num_rows - the number of records
-  // @return status error code
-  Status SetNumRowsInDataset(int64_t num_rows);
-
-  // Adds a sampler to become our child.
-  // @param std::shared_ptr<Sampler> - The sampler to add as a child.
-  // @return - The error code returned.
-  Status AddChild(std::shared_ptr<Sampler> child);
-
-  // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
-  // @param std::shared_ptr<Tensor>* sampleIds
-  // @param int64_t numElements - must be a non-zero number
-  // @return - The error code returned.
-  Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  virtual void Print(std::ostream &out, bool show_all) const;
-
-  // << Stream output operator overload
-  // @notes This allows you to write the debug print info using stream operators
-  // @param out - reference to the output stream being overloaded
-  // @param sampler - reference to the sampler to print
-  // @return - the output stream must be returned
-  friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
-    sampler.Print(out, false);
-    return out;
-  }
-
-  // Checks if this sampler has a child sampler.
-  // @return - true if there is a child sampler, false otherwise.
-  bool HasChildSampler();
-
-  // Uses id as an index for the list of ids generated by the child sampler, and gets the
-  // associated id.
-  // @param int64_t* out_associated_id - Out parameter, contains the associated id.
-  // @param int64_t id - The id used as an index to get the associated child id.
-  // @return - The error code returned.
-  Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
-
- protected:
-  // Number of rows of data from the place this sampler is sampling from.
If this sampler - // has a child sampler, num_rows_ is the number of ids the child sampler will - // output. Otherwise, num_rows_ is the number of rows in the dataset. - int64_t num_rows_; - - // The user may want to sample less than the full amount of data. num_samples_ reduces the number - // of id's returned as request by the user. Derived classes will choose how to sample the smaller - // amount. - int64_t num_samples_; - - int64_t samples_per_buffer_; - std::unique_ptr col_desc_; - std::vector> child_; // Child nodes - std::unique_ptr child_ids_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc deleted file mode 100644 index 28598da55f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" - -#include -#include - -namespace mindspore { -namespace dataset { -SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} - -Status SequentialSampler::GetNextSample(std::unique_ptr *out_buffer) { - if (id_count_ > num_samples_) { - RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); - } else if (id_count_ == num_samples_) { - (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); - } else { - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); - } - - (*out_buffer) = std::make_unique(current_id_, DataBuffer::kDeBFlagNone); - std::shared_ptr sampleIds; - - // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for - // samples per buffer though. 
- int64_t remaining_ids = num_samples_ - id_count_; - int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); - - RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); - auto idPtr = sampleIds->begin(); - for (int64_t i = 0; i < num_elements; i++) { - int64_t sampled_id = current_id_; - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *idPtr = sampled_id; - current_id_++; // Move the current id to the next one in the sequence - idPtr++; - } - - id_count_ += num_elements; // Count the packed ids towards our overall sample count - - TensorRow row(1, sampleIds); - (*out_buffer)->set_tensor_table(std::make_unique(1, row)); - } - return Status::OK(); -} - -Status SequentialSampler::InitSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); - CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); - // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample - // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. - int64_t available_row_count = num_rows_ - start_index_; - if (num_samples_ == 0 || num_samples_ > available_row_count) { - num_samples_ = available_row_count; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); - samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; - return Status::OK(); -} - -Status SequentialSampler::ResetSampler() { - CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); - current_id_ = start_index_; - id_count_ = 0; - - if (HasChildSampler()) { - RETURN_IF_NOT_OK(child_[0]->ResetSampler()); - } - - return Status::OK(); -} - -void SequentialSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: SequentialSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info - out << "\nStart index: " << start_index_; - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h deleted file mode 100644 index 06f084fb7a..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ - -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -class SequentialSampler : public Sampler { - public: - // Constructor - // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the - // full amount of ids from the dataset - // @param start_index - The starting index value - // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call - explicit SequentialSampler(int64_t num_samples, int64_t start_index, - int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~SequentialSampler() = default; - - // init sampler, called by python - Status InitSampler() override; - - // for next epoch of sampleIds - // @return - The error code return - Status ResetSampler() override; - - // Op calls this to get next Buffer that contains all the sampleIds - // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op - // @param int32_t workerId - not meant to be used - // @return - The error code return - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. - // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - int64_t current_id_; // The id sequencer. Each new id increments from this - int64_t start_index_; // The starting id. current_id_ begins from here. - int64_t id_count_; // An internal counter that tracks how many ids have been produced -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc deleted file mode 100644 index 08a623ed1b..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" - -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -// Constructor. -SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} - -// Initialized this Sampler. 
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
deleted file mode 100644
index 08a623ed1b..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-
-#include <algorithm>
-#include <memory>
-#include <random>
-#include <vector>
-
-#include "dataset/core/config_manager.h"
-#include "dataset/core/global_context.h"
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-// Constructor.
-SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices,
-                                         int64_t samples_per_buffer)
-    : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {}
-
-// Initializes this sampler.
-Status SubsetRandomSampler::InitSampler() {
-  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
-
-  // A special value of 0 for num_samples means that the user wants to sample the entire set of data.
-  // In this case, the ids are provided by the user. Cap num_samples at the number of ids given.
-  if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) {
-    num_samples_ = static_cast<int64_t>(indices_.size());
-  }
-  // Initialize the random generator with a seed from the config manager.
-  rand_gen_.seed(GetSeed());
-
-  if (samples_per_buffer_ > num_samples_) {
-    samples_per_buffer_ = num_samples_;
-  }
-
-  // num_samples_ could be smaller than the total number of input ids.
-  // We shuffle the full set of ids, but only serve the first num_samples_ of them later.
-  std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
-
-  return Status::OK();
-}
-
-// Reset the internal variables to the initial state and reshuffle the indices.
-Status SubsetRandomSampler::ResetSampler() {
-  // Reset the internal counters.
-  sample_id_ = 0;
-  buffer_id_ = 0;
-
-  // Reshuffle the indices.
-  rand_gen_.seed(GetSeed());
-  std::shuffle(indices_.begin(), indices_.end(), rand_gen_);
-
-  if (HasChildSampler()) {
-    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
-  }
-
-  return Status::OK();
-}
-
-// Get the sample ids.
-Status SubsetRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
-  // All samples have been drawn.
-  if (sample_id_ == num_samples_) {
-    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
-  } else {
-    if (HasChildSampler()) {
-      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
-    }
-
-    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
-    std::shared_ptr<Tensor> outputIds;
-
-    int64_t last_id = sample_id_ + samples_per_buffer_;
-    // Handle returning all samples at once, and the case where the last draw is not a full batch.
-    if (last_id > num_samples_) {
-      last_id = num_samples_;
-    }
-
-    // Allocate tensor.
-    RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
-
-    // Initialize tensor.
-    auto id_ptr = outputIds->begin<int64_t>();
-    while (sample_id_ < last_id) {
-      if (indices_[sample_id_] >= num_rows_) {
-        std::string err_msg =
-          "Generated id is bigger than num_rows (out of bound). indices_: " + std::to_string(indices_[sample_id_]) +
-          " num_rows_: " + std::to_string(num_rows_);
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-
-      int64_t sampled_id = indices_[sample_id_];
-      if (HasChildSampler()) {
-        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
-      }
-
-      *id_ptr = sampled_id;
-      id_ptr++;
-      sample_id_++;
-    }
-
-    // Create a TensorTable from that single tensor and push it into the DataBuffer.
-    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, TensorRow(1, outputIds)));
-  }
-
-  return Status::OK();
-}
-
-void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const {
-  out << "\nSampler: SubsetRandomSampler";
-  if (show_all) {
-    // Call the super class for displaying any common detailed info
-    Sampler::Print(out, show_all);
-    // Then add our own info if any
-  }
-}
-}  // namespace dataset
-}  // namespace mindspore
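The sampling strategy above is shuffle-then-prefix: shuffle the user's index list once per epoch and serve the first num_samples_ entries, which yields a uniform sample without replacement. The same idea in miniature (an illustrative sketch, not the deleted code):

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

// Uniform sampling of k ids without replacement from a user-supplied index
// list: shuffle everything, keep the first k. Reseeding per epoch reshuffles.
std::vector<int64_t> SampleWithoutReplacement(std::vector<int64_t> indices, int64_t k, uint32_t seed) {
  std::mt19937 rng(seed);
  std::shuffle(indices.begin(), indices.end(), rng);
  if (k < static_cast<int64_t>(indices.size())) {
    indices.resize(static_cast<size_t>(k));
  }
  return indices;
}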
indices_: " + std::to_string(indices_[sample_id_]) + - " num_rows_: " + std::to_string(num_rows_); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - int64_t sampled_id = indices_[sample_id_]; - if (HasChildSampler()) { - RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); - } - - *id_ptr = sampled_id; - id_ptr++; - sample_id_++; - } - - // Create a TensorTable from that single tensor and push into DataBuffer - (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); - } - - return Status::OK(); -} - -void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { - out << "\nSampler: SubsetRandomSampler"; - if (show_all) { - // Call the super class for displaying any common detailed info - Sampler::Print(out, show_all); - // Then add our own info if any - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h deleted file mode 100644 index ffc7cb17bc..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/subset_random_sampler.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ - -#include -#include -#include - -#include "dataset/engine/datasetops/source/sampler/sampler.h" - -namespace mindspore { -namespace dataset { -// Randomly samples elements from a given list of indices, without replacement. -class SubsetRandomSampler : public Sampler { - public: - // Constructor. - // @param num_samples The number of samples to draw. 0 for the full amount. - // @param indices List of indices from where we will randomly draw samples. - // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). - // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. - explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, - std::int64_t samples_per_buffer = std::numeric_limits::max()); - - // Destructor. - ~SubsetRandomSampler() = default; - - // Initialize the sampler. - // @return Status - Status InitSampler() override; - - // Reset the internal variable to the initial state and reshuffle the indices. - // @return Status - Status ResetSampler() override; - - // Get the sample ids. - // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. - // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. - Status GetNextSample(std::unique_ptr *out_buffer) override; - - // Printer for debugging purposes. 
- // @param out - output stream to write to - // @param show_all - bool to show detailed vs summary - void Print(std::ostream &out, bool show_all) const override; - - private: - // A list of indices (already randomized in constructor). - std::vector indices_; - - // Current sample id. - int64_t sample_id_; - - // Current buffer id. - int64_t buffer_id_; - - // A random number generator. - std::mt19937 rand_gen_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc deleted file mode 100644 index 6bf3d2d85e..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" - -#include -#include -#include -#include -#include - -#include "dataset/core/global_context.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -// Constructor. -WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, - int64_t samples_per_buffer) - : Sampler(num_samples, samples_per_buffer), - weights_(weights), - replacement_(replacement), - sample_id_(0), - buffer_id_(0) {} - -// Initialized this Sampler. -Status WeightedRandomSampler::InitSampler() { - // Special value of 0 for num_samples means that the user wants to sample the entire set of data. - // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. - if (num_samples_ == 0 || num_samples_ > num_rows_) { - num_samples_ = num_rows_; - } - CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); - CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); - - // Initialize random generator with seed from config manager - rand_gen_.seed(GetSeed()); - - samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; - - if (!replacement_) { - exp_dist_ = std::make_unique>(1); - InitOnePassSampling(); - } else { - discrete_dist_ = std::make_unique>(weights_.begin(), weights_.end()); - } - - return Status::OK(); -} - -// Initialized the computation for generating weighted random numbers without replacement using onepass method. -void WeightedRandomSampler::InitOnePassSampling() { - exp_dist_->reset(); - onepass_ids_.clear(); - std::vector> val_idx; - for (size_t i = 0; i < weights_.size(); i++) { - val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); - } - - // Partial sort the first `numSamples` elements. 
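Another contract shared by these samplers: GetNextSample hands ids out in chunks of at most samples_per_buffer, and exhaustion is signalled in-band with an EOE-flagged buffer rather than an error. A generic, self-contained analogue of that batching contract (std::nullopt stands in for the EOE-flagged DataBuffer; all names here are illustrative):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

class IdBatcher {
 public:
  IdBatcher(std::vector<int64_t> ids, size_t samples_per_buffer)
      : ids_(std::move(ids)), batch_(samples_per_buffer) {}

  // Returns up to `batch_` ids per call; std::nullopt plays the role of EOE.
  std::optional<std::vector<int64_t>> Next() {
    if (pos_ >= ids_.size()) return std::nullopt;  // "EOE": all samples drawn
    size_t end = std::min(ids_.size(), pos_ + batch_);
    std::vector<int64_t> out(ids_.begin() + pos_, ids_.begin() + end);
    pos_ = end;
    return out;
  }

 private:
  std::vector<int64_t> ids_;
  size_t batch_;
  size_t pos_ = 0;
};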
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
deleted file mode 100644
index 6bf3d2d85e..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc
+++ /dev/null
@@ -1,169 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
-
-#include <algorithm>
-#include <memory>
-#include <random>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/global_context.h"
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-// Constructor.
-WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
-                                             int64_t samples_per_buffer)
-    : Sampler(num_samples, samples_per_buffer),
-      weights_(weights),
-      replacement_(replacement),
-      sample_id_(0),
-      buffer_id_(0) {}
-
-// Initializes this sampler.
-Status WeightedRandomSampler::InitSampler() {
-  // A special value of 0 for num_samples means that the user wants to sample the entire set of data.
-  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
-  if (num_samples_ == 0 || num_samples_ > num_rows_) {
-    num_samples_ = num_rows_;
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_ > 0, "num_samples and num_rows need to be positive");
-  CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer <= 0\n");
-
-  // Initialize the random generator with a seed from the config manager.
-  rand_gen_.seed(GetSeed());
-
-  samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
-
-  if (!replacement_) {
-    exp_dist_ = std::make_unique<std::exponential_distribution<>>(1);
-    InitOnePassSampling();
-  } else {
-    discrete_dist_ = std::make_unique<std::discrete_distribution<int64_t>>(weights_.begin(), weights_.end());
-  }
-
-  return Status::OK();
-}
-
-// Initializes the computation for generating weighted random numbers without replacement using the one-pass method.
-void WeightedRandomSampler::InitOnePassSampling() {
-  exp_dist_->reset();
-  onepass_ids_.clear();
-  std::vector<std::pair<double, int64_t>> val_idx;
-  for (size_t i = 0; i < weights_.size(); i++) {
-    val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i));
-  }
-
-  // Partially sort the first `num_samples_` elements.
-  std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end());
-  for (int64_t i = 0; i < num_samples_; i++) {
-    onepass_ids_.push_back(val_idx[i].second);
-  }
-}
-
-// Reset the internal variables to the initial state and reshuffle the indices.
-Status WeightedRandomSampler::ResetSampler() {
-  sample_id_ = 0;
-  buffer_id_ = 0;
-  rand_gen_.seed(GetSeed());
-  if (!replacement_) {
-    InitOnePassSampling();
-  } else {
-    discrete_dist_->reset();
-  }
-
-  if (HasChildSampler()) {
-    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
-  }
-
-  return Status::OK();
-}
-
-// Get the sample ids.
-Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
-  if (weights_.size() > static_cast<size_t>(num_rows_)) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "There are more sample weights than rows; generated ids could be out of bound.");
-  }
-
-  if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
-    RETURN_STATUS_UNEXPECTED("Without replacement, the number of weights must be at least num_samples.");
-  }
-
-  if (sample_id_ == num_samples_) {
-    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE);
-  } else {
-    if (HasChildSampler()) {
-      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
-    }
-
-    (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagNone);
-    std::shared_ptr<Tensor> outputIds;
-
-    int64_t last_id = sample_id_ + samples_per_buffer_;
-    // Handle returning all samples at once, and the case where the last draw is not a full batch.
-    if (last_id > num_samples_) {
-      last_id = num_samples_;
-    }
-
-    // Allocate tensor.
-    RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_));
-
-    // Initialize tensor and assign the data to each tensor element.
-    auto id_ptr = outputIds->begin<int64_t>();
-    while (sample_id_ < last_id) {
-      int64_t genId;
-      if (replacement_) {
-        genId = (*discrete_dist_)(rand_gen_);
-      } else {
-        // Draw a sample without replacement.
-        genId = onepass_ids_.front();
-        onepass_ids_.pop_front();
-      }
-
-      if (genId >= num_rows_) {
-        RETURN_STATUS_UNEXPECTED("Generated id is bigger than num_rows (out of bound).");
-      }
-
-      if (HasChildSampler()) {
-        RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId));
-      }
-
-      *id_ptr = genId;
-      id_ptr++;
-      sample_id_++;
-    }
-
-    // Create a TensorTable from that single tensor and push it into the DataBuffer.
-    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, TensorRow(1, outputIds)));
-  }
-
-  return Status::OK();
-}
-
-void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const {
-  out << "\nSampler: WeightedRandomSampler";
-  if (show_all) {
-    // Call the super class for displaying any common detailed info
-    Sampler::Print(out, show_all);
-    // Then add our own info if any
-  }
-}
-}  // namespace dataset
-}  // namespace mindspore
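The without-replacement path above is the exponential-clocks formulation of weighted sampling: each index i draws a key E_i / w_i with E_i ~ Exp(1), and the num_samples smallest keys form the sample (heavier weights shrink keys, so they win more often); the replacement_ branch simply delegates to std::discrete_distribution. A compact, self-contained sketch of the one-pass method (assumes 0 <= k <= w.size() and all weights positive; names are illustrative):

#include <algorithm>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

// One-pass weighted sampling without replacement: partial-sort the k smallest
// keys E_i / w_i and return their indices, as InitOnePassSampling does above.
std::vector<int64_t> WeightedSampleNoReplacement(const std::vector<double> &w, int64_t k, uint32_t seed) {
  std::mt19937 rng(seed);
  std::exponential_distribution<double> exp1(1.0);
  std::vector<std::pair<double, int64_t>> keys;
  keys.reserve(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    keys.emplace_back(exp1(rng) / w[i], static_cast<int64_t>(i));
  }
  std::partial_sort(keys.begin(), keys.begin() + k, keys.end());
  std::vector<int64_t> out;
  out.reserve(static_cast<size_t>(k));
  for (int64_t i = 0; i < k; ++i) {
    out.push_back(keys[static_cast<size_t>(i)].second);
  }
  return out;
}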
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
deleted file mode 100644
index 1fbe29ed80..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
-#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
-
-#include <deque>
-#include <limits>
-#include <memory>
-#include <random>
-
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-
-namespace mindspore {
-namespace dataset {
-// Samples elements from ids 0, 1, ..., weights.size()-1 with the given probabilities (weights).
-class WeightedRandomSampler : public Sampler {
- public:
-  // Constructor.
-  // @param num_samples - Number of samples to be drawn.
-  // @param weights - A list of sample weights.
-  // @param replacement - Determines whether samples are drawn with or without replacement.
-  // @param samples_per_buffer - The number of ids we draw on each call to GetNextSample().
-  //     When samples_per_buffer = 0, GetNextSample() will draw all the sample ids and return them at once.
-  WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement,
-                        int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
-
-  // Destructor.
-  ~WeightedRandomSampler() = default;
-
-  // Initialize the sampler.
-  // @return Status
-  Status InitSampler() override;
-
-  // Reset the internal variables to the initial state and reshuffle the indices.
-  Status ResetSampler() override;
-
-  // Get the sample ids.
-  // @param[out] out_buffer The address of a unique_ptr to the DataBuffer where the sample ids will be placed.
-  // @note The sample ids (int64_t) will be placed in one Tensor inside that buffer.
-  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
-
-  // Printer for debugging purposes.
-  // @param out - output stream to write to
-  // @param show_all - bool to show detailed vs summary
-  void Print(std::ostream &out, bool show_all) const override;
-
- private:
-  // A list of weights for each sample.
-  std::vector<double> weights_;
-
-  // A flag indicating whether samples are drawn with or without replacement.
-  bool replacement_;
-
-  // Current sample id.
-  int64_t sample_id_;
-
-  // Current buffer id.
-  int64_t buffer_id_;
-
-  // Random engine and device
-  std::mt19937 rand_gen_;
-
-  // Discrete distribution for generating weighted random numbers with replacement.
-  std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_;
-
-  // Exponential distribution for generating weighted random numbers without replacement,
-  // based on "Accelerating weighted random sampling without replacement" by Kirill Muller.
-  std::unique_ptr<std::exponential_distribution<>> exp_dist_;
-
-  // Initializes the computation for generating weighted random numbers without replacement
-  // using the one-pass method.
- void InitOnePassSampling(); - - // Store the random weighted ids generated by onepass method in `InitOnePassSampling` - std::deque onepass_ids_; -}; -} // namespace dataset -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc deleted file mode 100644 index 818b5ab3f4..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ /dev/null @@ -1,498 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/core/config_manager.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/wait_post.h" -#include "dataset/util/random.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { -TextFileOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_shuffle_files_(false), - builder_sampler_(nullptr) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_num_workers_ = config_manager->num_parallel_workers(); - builder_op_connector_size_ = config_manager->op_connector_size(); - builder_rows_per_buffer_ = config_manager->rows_per_buffer(); - builder_worker_connector_size_ = config_manager->worker_connector_size(); -} - -Status TextFileOp::Builder::ValidateInputs() const { - std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; - err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -Status TextFileOp::Builder::Build(std::shared_ptr *op) { - RETURN_IF_NOT_OK(ValidateInputs()); - - // Throttle the number of workers if we have more workers than files! 
- if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { - builder_num_workers_ = builder_text_files_list_.size(); - MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - builder_schema_ = std::make_unique(); - RETURN_IF_NOT_OK( - builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - - std::shared_ptr text_file_op = std::make_shared( - builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, - std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, - builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); - RETURN_IF_NOT_OK(text_file_op->Init()); - *op = std::move(text_file_op); - - return Status::OK(); -} - -TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, - std::unique_ptr schema, std::vector text_files_list, - int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, - std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - device_id_(device_id), - num_devices_(num_device), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_rows), - text_files_list_(std::move(text_files_list)), - shuffle_files_(shuffle_files), - data_schema_(std::move(schema)), - all_num_rows_(0), - num_rows_per_shard_(0), - filename_index_(std::make_unique()), - finished_reading_dataset_(false), - load_io_block_queue_(true), - load_jagged_connector_(true) { - worker_connector_size_ = worker_connector_size; -} - -// A print method typically used for debugging -void TextFileOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ - << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") - << "\nText files list:\n"; - for (int i = 0; i < text_files_list_.size(); ++i) { - out << " " << text_files_list_[i]; - } - out << "\nData Schema:\n"; - out << *data_schema_ << "\n\n"; - } -} - -Status TextFileOp::Init() { - RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); - - int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); - io_block_queues_.Init(num_workers_, safe_queue_size); - - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - return Status::OK(); -} - -Status TextFileOp::Reset() { - load_jagged_connector_ = true; - load_io_block_queue_ = true; - - RETURN_IF_NOT_OK(ParallelOp::Reset()); - NotifyToFillIOBlockQueue(); - return Status::OK(); -} - -Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { - TensorRow tRow(1, nullptr); - (*tensor_table)->push_back(std::move(tRow)); - - std::shared_ptr tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); - (**tensor_table)[row][0] = std::move(tensor); - return Status::OK(); -} - -Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, - const int32_t worker_id) { - std::ifstream handle(file); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Failed to open file " + file); - } - - int64_t rows_each_buffer = 0; - int64_t rows_total = 0; - std::string line; - std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - std::unique_ptr tensor_table = std::make_unique(); - - while (getline(handle, line)) { - if (line.empty()) { - continue; - } - // If read to the end offset of this file, break. - if (rows_total >= end_offset) { - break; - } - // Skip line before start offset. 
- if (rows_total < start_offset) { - rows_total++; - continue; - } - - RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); - rows_each_buffer++; - rows_total++; - if (rows_each_buffer == rows_per_buffer_) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - - cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); - tensor_table = std::make_unique(); - rows_each_buffer = 0; - } - } - - if (rows_each_buffer > 0) { - cur_buffer->set_tensor_table(std::move(tensor_table)); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); - } - - return Status::OK(); -} - -Status TextFileOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - return Status::OK(); -} - -// Pops an element from a queue in io_block_queues -Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); - - return Status::OK(); -} - -// Pushes an element to a queue in io_block_queues -Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { - RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status TextFileOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. 
-Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { - std::mt19937 rng(seed); - std::shuffle(i_keys->begin(), i_keys->end(), rng); -} - -bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { - int32_t queue_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - while (!finish) { - std::vector> file_index; - if (!i_keys.empty()) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); - } - } else { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - { - if (!load_io_block_queue_) { - break; - } - } - file_index.emplace_back(std::pair(it.value(), it.key())); - } - } - for (auto file_info : file_index) { - if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { - auto ioBlock = - std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); - RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); - queue_index = (queue_index + 1) % num_workers_; - } - - pre_count += filename_numrows_[file_info.first]; - } - - if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { - finish = false; - } else { - finish = true; - } - } - - RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); - return Status::OK(); -} - -Status TextFileOp::WaitToFillIOBlockQueue() { - // must be called first if called by worker spanwed by taskgroup - TaskManager::FindMe()->Post(); - - std::vector i_keys; - if (shuffle_files_) { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - i_keys.push_back(it.key()); - } - } - uint32_t seed = 0; - while (true) { - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); - io_block_queue_wait_post_.Clear(); - - if (finished_reading_dataset_) { - break; - } - - if (shuffle_files_) { - ShuffleKeys(&i_keys, num_devices_ == 1 ? 
GetSeed() : ++seed); - } - RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); - } - return Status::OK(); -} - -void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } - -Status TextFileOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling IoBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); - - // Read data from disk into buffers - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. - TaskManager::FindMe()->Post(); - - RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); - NotifyToFillIOBlockQueue(); - while (!finished_reading_dataset_) { - int64_t buffer_id = 0; - int32_t workers_done = 0; - int64_t rows_read = 0; - load_io_block_queue_ = true; - - while (workers_done < num_workers_) { - std::unique_ptr buffer; - RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); - if (buffer->eoe()) { - workers_done++; - } else if (total_rows_ == 0 || rows_read < total_rows_) { - if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { - int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); - RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); - } - rows_read += buffer->NumRows(); - buffer->set_id(buffer_id++); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); - } else { - // end of epoch - load_jagged_connector_ = false; - load_io_block_queue_ = false; - } - } - - std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); - - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - finished_reading_dataset_ = true; - NotifyToFillIOBlockQueue(); - } else { - jagged_buffer_connector_->DoReset(); - buffer_id = 0; - } - } - - std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - - RETURN_IF_NOT_OK(PostEndOfData()); - - return Status::OK(); -} - -int64_t TextFileOp::CountTotalRows(const std::string &file) { - std::ifstream handle(file); - if (!handle.is_open()) { - MS_LOG(ERROR) << "Failed to open file: " << file; - return 0; - } - - std::string line; - int64_t count = 0; - while (getline(handle, line)) { - if (!line.empty()) { - count++; - } - } - - return count; -} - -Status TextFileOp::CalculateNumRowsPerShard() { - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - int64_t count = CountTotalRows(it.value()); - filename_numrows_[it.value()] = count; - all_num_rows_ += count; - } - if (all_num_rows_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API TextFileDataset.Please check file path or dataset API " - "validation first."); - } - - num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); - MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; - return Status::OK(); -} - -Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { - std::shared_ptr op; - *count = 0; - RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); - for (auto file : files) { - *count += op->CountTotalRows(file); - } - return Status::OK(); -} - -Status TextFileOp::ComputeColMap() { - // Set the column name mapping (base class field) - if (column_name_id_map_.empty()) 
{ - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h deleted file mode 100644 index 5b787d4dad..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ - -#include -#include -#include -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/util/auto_index.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/util/queue.h" -#include "dataset/util/wait_post.h" -#include "dataset/engine/jagged_connector.h" - -namespace mindspore { -namespace dataset { -using StringIndex = AutoIndexObj; - -class TextFileOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - // Create the final object. - // @param op - dataset op. - // @return - the error code return. - Status Build(std::shared_ptr *op); - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. 
-    Builder &SetTextFilesList(const std::vector<std::string> &files_list) {
-      builder_text_files_list_ = files_list;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetShuffleFiles(bool shuffle_files) {
-      builder_shuffle_files_ = shuffle_files;
-      return *this;
-    }
-
-    // Setter method.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetTotalRows(int64_t total_rows) {
-      builder_total_rows_ = total_rows;
-      return *this;
-    }
-
-    // Setter method.
-    // @param sampler - the sampler to use.
-    // @return Builder - setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
-      builder_sampler_ = std::move(sampler);
-      return *this;
-    }
-
-   private:
-    int32_t builder_device_id_;
-    int32_t builder_num_devices_;
-    int32_t builder_num_workers_;
-    int32_t builder_op_connector_size_;
-    int64_t builder_rows_per_buffer_;
-    int64_t builder_total_rows_;
-    int32_t builder_worker_connector_size_;
-    std::vector<std::string> builder_text_files_list_;
-    bool builder_shuffle_files_;
-    std::unique_ptr<DataSchema> builder_schema_;
-    std::shared_ptr<Sampler> builder_sampler_;
-  };
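As a usage note: Build() is the only entry point that validates the configuration (ValidateInputs() plus the worker throttling in the .cc above), so call sites chain setters and finish with Build. A compressed, illustrative call site inside namespace mindspore::dataset (the file names are placeholders; CountAllFileRows() below uses the same pattern):

std::shared_ptr<TextFileOp> op;
Status rc = TextFileOp::Builder()
                .SetTextFilesList({"a.txt", "b.txt"})  // placeholder paths
                .SetNumWorkers(4)
                .SetShuffleFiles(true)
                .SetTotalRows(0)  // 0 means "read all rows"
                .Build(&op);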
-
-  // Constructor of TextFileOp
-  // @note The builder class should be used to call this constructor.
-  // @param num_workers - number of worker threads reading data from text files.
-  // @param rows_per_buffer - number of rows that a full buffer will contain.
-  // @param total_rows - number of rows to read.
-  // @param worker_connector_size - size of each worker's internal connector queue.
-  // @param schema - the data schema object.
-  // @param text_files_list - list of filepaths for the dataset files.
-  // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
-  // @param shuffle_files - whether or not to shuffle the files before reading data.
-  // @param num_devices - number of devices sharing the dataset.
-  // @param device_id - id of this device within the shard group.
-  // @param sampler - allows a sampler. Only valid if a cache exists in ascendant tree nodes.
-  TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
-             std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, int32_t op_connector_size,
-             bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
-
-  // Default destructor
-  ~TextFileOp() = default;
-
-  // A print method typically used for debugging
-  // @param out - The output stream to write output to
-  // @param show_all - A bool to control if you want to show all info or just a summary
-  void Print(std::ostream &out, bool show_all) const override;
-
-  // Instantiates the internal queues and connectors
-  // @return Status - the error code returned
-  Status Init();
-
-  // Class functor operator () override.
-  // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
-  // provide the master loop that drives the logic for performing the work
-  // @return Status - the error code returned.
-  Status operator()() override;
-
-  // Overrides the base class reset method. Cleans up any state info from its previous execution and
-  // reinitializes itself so that it can be executed again, as if it was just created.
-  // @return Status - the error code returned.
-  Status Reset() override;
-
-  // Get total rows in files.
-  // @param files - all text files.
-  // @param count - number of rows.
-  // @return Status - the error code returned.
-  static Status CountAllFileRows(const std::vector<std::string> &files, int64_t *count);
-
-  // Op name getter
-  // @return Name of the current Op
-  std::string Name() const override { return "TextFileOp"; }
-
-  // File names getter
-  // @return Vector of the input file names
-  std::vector<std::string> FileNames() { return text_files_list_; }
-
- private:
-  // The entry point for when workers are launched.
-  // @param worker_id - the id of the worker that is executing this function.
-  // @return Status - the error code returned.
-  Status WorkerEntry(int32_t worker_id) override;
-
-  // Parses a single row and puts the data into a tensor table.
-  // @param line - the content of the row.
-  // @param tensor_table - the tensor table to put the parsed data in.
-  // @param row - the id of the row filled in the tensor table.
-  // @return Status - the error code returned.
-  Status LoadTensor(const std::string &line, std::unique_ptr<TensorQTable> *tensor_table, int64_t row);
-
-  // Reads a text file and loads the data into multiple buffers.
-  // @param file - the file to read.
-  // @param start_offset - the start offset of the file.
-  // @param end_offset - the end offset of the file.
-  // @param worker_id - the id of the worker that is executing this function.
-  // @return Status - the error code returned.
-  Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset,
-                  const int32_t worker_id);
-
-  // Calculate the number of rows in each shard.
-  // @return Status - the error code returned.
-  Status CalculateNumRowsPerShard();
-
-  // Count the number of rows in a file.
-  // @param file - the text file name.
-  // @return int64_t - the total number of rows in the file.
-  int64_t CountTotalRows(const std::string &file);
-
-  // Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
-  void NotifyToFillIOBlockQueue();
-
-  // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
-  // @return Status - the error code returned.
-  Status WaitToFillIOBlockQueue();
-
-  // Fill the IOBlockQueue.
-  // @param i_keys - keys of the files to fill into the IOBlockQueue.
-  // @return Status - the error code returned.
-  Status FillIOBlockQueue(const std::vector<int64_t> &i_keys);
-
-  // Decides whether a file should be pushed to the block queue, and computes its offsets.
-  // @param file_name - the file name.
-  // @param[out] start_offset - start offset of the rows this shard reads from the file.
-  // @param[out] end_offset - end offset (exclusive) of the rows this shard reads from the file.
-  // @param pre_count - total rows of the previous files.
-  // @return bool - whether the file should be pushed to the block queue.
-  bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
-                                const int64_t &pre_count);
-
-  // Pops an element from a queue in io_block_queues.
-  // @param index - the index of the queue to pop from.
-  // @param out_block - the popped element.
-  // @return Status - the error code returned.
-  Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
-
-  // Pushes an element to a queue in io_block_queues.
-  // @param index - the index of the queue to push to.
-  // @param io_block - the element to push onto the queue.
-  // @return Status - the error code returned.
-  Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
-
-  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
-  // When the worker pops this control indicator, it will shut itself down gracefully.
-  // @return Status - the error code returned.
-  Status PostEndOfData();
-
-  // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - int32_t device_id_; - int32_t num_devices_; - int64_t rows_per_buffer_; - int64_t total_rows_; - std::vector text_files_list_; - bool shuffle_files_; - std::unique_ptr data_schema_; - int64_t all_num_rows_; - int64_t num_rows_per_shard_; - std::map filename_numrows_; - std::unique_ptr filename_index_; - QueueList> io_block_queues_; - WaitPost io_block_queue_wait_post_; - bool finished_reading_dataset_; - bool load_io_block_queue_; - bool load_jagged_connector_; - std::unique_ptr jagged_buffer_connector_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc deleted file mode 100644 index 6e6d885cb1..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ /dev/null @@ -1,1054 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/datasetops/source/tf_reader_op.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "proto/example.pb.h" -#include "./securec.h" -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/connector.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/jagged_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "dataset/util/task_manager.h" -#include "dataset/util/wait_post.h" -#include "utils/system/crc32c.h" - -namespace mindspore { -namespace dataset { -TFReaderOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_equal_rows_per_shard_(false), - builder_sampler_(nullptr) { - std::shared_ptr config_manager = GlobalContext::config_manager(); - builder_num_workers_ = config_manager->num_parallel_workers(); - builder_worker_connector_size_ = config_manager->worker_connector_size(); - builder_op_connector_size_ = config_manager->op_connector_size(); - builder_rows_per_buffer_ = config_manager->rows_per_buffer(); - builder_shuffle_files_ = false; - builder_data_schema_ = std::make_unique(); -} - -bool ValidateFirstRowCrc(const std::string &filename) { - std::ifstream reader; - reader.open(filename); - if (!reader) { - return false; - } - - // read data - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // read crc from file - uint32_t masked_crc = 0; - (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); - - // generate crc from data - uint32_t generated_crc = - system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); - - return masked_crc == generated_crc; -} - -Status TFReaderOp::Builder::ValidateInputs() const { - std::string err_msg; - - if (builder_num_workers_ <= 0) { - err_msg += "Number of parallel workers is smaller or equal to 0\n"; - } - - if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { - err_msg += "Wrong sharding configs\n"; - } - - std::vector invalid_files(builder_dataset_files_list_.size()); - auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), - [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); - invalid_files.resize(std::distance(invalid_files.begin(), it)); - - if (!invalid_files.empty()) { - err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; - - std::string accumulated_filenames = std::accumulate( - invalid_files.begin(), invalid_files.end(), std::string(""), - [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); - err_msg += accumulated_filenames; - } - - return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) { - RETURN_IF_NOT_OK(ValidateInputs()); - - // Throttle the number of workers if we have more workers than files! 
- if (static_cast(builder_num_workers_) > builder_dataset_files_list_.size()) { - builder_num_workers_ = builder_dataset_files_list_.size(); - MS_LOG(WARNING) << "TFReader operator parallelism reduced to " << builder_num_workers_ << " workers."; - } - - std::shared_ptr new_tf_reader_op = std::make_shared( - builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, - builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, - std::move(builder_sampler_)); - - RETURN_IF_NOT_OK(new_tf_reader_op->Init()); - *out_tf_reader_op = std::move(new_tf_reader_op); - return Status::OK(); -} - -TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, - int64_t total_num_rows, std::vector dataset_files_list, - std::unique_ptr data_schema, int32_t op_connector_size, - std::vector columns_to_load, bool shuffle_files, int32_t num_device, - int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) - : ParallelOp(num_workers, op_connector_size, std::move(sampler)), - device_id_(device_id), - num_devices_(num_device), - rows_per_buffer_(rows_per_buffer), - total_rows_(total_num_rows), - dataset_files_list_(std::move(dataset_files_list)), - columns_to_load_(std::move(columns_to_load)), - finished_reading_dataset_(false), - shuffle_files_(shuffle_files), - data_schema_(std::move(data_schema)), - filename_index_(std::make_unique()), - load_io_block_queue_(true), - load_jagged_connector_(true), - num_rows_(0), - num_rows_per_shard_(0), - equal_rows_per_shard_(equal_rows_per_shard) { - worker_connector_size_ = worker_connector_size; -} - -// A print method typically used for debugging -void TFReaderOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ - << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") - << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; - for (int i = 0; i < dataset_files_list_.size(); ++i) { - out << " " << dataset_files_list_[i]; - } - if (!columns_to_load_.empty()) { - out << "\nColumns to load:\n"; - for (int i = 0; i < columns_to_load_.size(); ++i) { - out << " " << columns_to_load_[i]; - } - } - out << "\nData Schema:\n"; - out << *data_schema_ << "\n\n"; - } -} - -Status TFReaderOp::Init() { - if (data_schema_->Empty()) { - RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_)); - } - - if (total_rows_ == 0) { - total_rows_ = data_schema_->num_rows(); - } - if (total_rows_ < 0) { - RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); - } - - // Build the index with our files such that each file corresponds to a key id. 
- RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); - - // The creation of the internal connector has been delayed until now, since we may have adjusted the - // number of workers. Now that the worker count is established, create the connector now in the - // parallel op base. - RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); - - jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); - - // temporary: make size large enough to hold all files + EOE to avoid hangs - int32_t safe_queue_size = static_cast(std::ceil(dataset_files_list_.size() / num_workers_)) + 1; - io_block_queues_.Init(num_workers_, safe_queue_size); - - return Status::OK(); -} - -Status TFReaderOp::CalculateNumRowsPerShard() { - if (!equal_rows_per_shard_) { - return Status::OK(); - } - - for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { - std::vector file(1, it.value()); - int64_t num = CountTotalRowsSectioned(file, 0, 1); - filename_numrows_[it.value()] = num; - num_rows_ += num; - } - num_rows_per_shard_ = static_cast(std::ceil(num_rows_ * 1.0 / num_devices_)); - if (num_rows_per_shard_ == 0) { - RETURN_STATUS_UNEXPECTED( - "There is no valid data matching the dataset API TFRecordDataset.Please check file path or dataset API " - "validation first."); - } - return Status::OK(); -} -// Class functor operator () override. -// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will -// provide the master loop that drives the logic for performing the work -Status TFReaderOp::operator()() { - RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); - - // launch one thread, responsible for filling mIOBlockQueue - RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this))); - - // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading - // data from disk into buffers - RETURN_IF_NOT_OK( - tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1))); - - // must be called after launching workers. 
workers can't be spawned after this post,
-  // so workers have to be kept alive until the end of the program
-  TaskManager::FindMe()->Post();
-
-  RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
-
-  NotifyToFillIOBlockQueue();
-  while (!finished_reading_dataset_) {
-    int64_t buffer_id = 0;
-    int32_t workers_done = 0;
-    int64_t rows_read = 0;
-    {
-      std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
-      load_io_block_queue_ = true;
-    }
-
-    while (workers_done < num_workers_) {
-      std::unique_ptr<DataBuffer> fetched_buffer;
-      RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
-      if (fetched_buffer->eoe()) {
-        workers_done++;
-      } else if (total_rows_ == 0 || rows_read < total_rows_) {
-        // we need to push a buffer
-        if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
-          // this is the last buffer we need, and we only need a part of it
-          int64_t rowsToRemove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
-          RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rowsToRemove));
-        }
-
-        rows_read += fetched_buffer->NumRows();
-        fetched_buffer->set_id(buffer_id);
-        buffer_id++;
-        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
-      } else {
-        // the user specified the number of rows they want, and we have read enough rows
-        //
-        // IOBlockQueue thread needs to:
-        // -stop pushing stuff to IOBlockQueue
-        // -call PostEndOfEpoch (will send EOE)
-        // -wait for reset
-        //
-        // Worker threads need to:
-        // -stop reading the file they are currently reading and throw it away
-        // -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
-        //
-        // Master thread needs to:
-        // -tell IOBlockQueue thread to stop pushing
-        // -tell worker threads to stop reading the file they are currently reading
-        // -keep pulling until EOE
-
-        // don't think we need a lock for now
-        load_jagged_connector_ = false;
-
-        std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
-        load_io_block_queue_ = false;
-      }
-    }
-
-    // all workers finished reading for this epoch, and we have read all the data from all workers
-    std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
-
-    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
-      finished_reading_dataset_ = true;
-      NotifyToFillIOBlockQueue();
-    } else {
-      jagged_buffer_connector_->DoReset();
-      buffer_id = 0;
-    }
-  }
-
-  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
-  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
-
-  RETURN_IF_NOT_OK(PostEndOfData());
-
-  return Status::OK();
-}
-
-// static local-only helper function
-static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
-  std::mt19937 rng(seed);
-  std::shuffle(i_keys->begin(), i_keys->end(), rng);
-}
-
-// The entry point for when workers are launched.
-Status TFReaderOp::WorkerEntry(int32_t worker_id) { - // must be called first if called by worker spawned by taskgroup - TaskManager::FindMe()->Post(); - - std::unique_ptr io_block; - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - - while (!io_block->eof()) { - if (!io_block->eoe()) { - if (load_jagged_connector_) { - std::string filename; - RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); - int64_t start_offset = io_block->GetStartOffset(); - int64_t end_offset = io_block->GetEndOffset(); - RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); - MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; - } - } else { - std::unique_ptr eoe_buffer = std::make_unique(1, DataBuffer::kDeBFlagEOE); - RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); - } - - RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. -// When the worker pops this control indicator, it will shut itself down gracefully. -Status TFReaderOp::PostEndOfData() { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); - } - - return Status::OK(); -} - -// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker -// pops this control indicator, it will wait until the next epoch starts and then resume execution. -Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { - for (int i = 0; i < num_workers_; ++i) { - std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); - } - - return Status::OK(); -} - -bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, - const int64_t &pre_count) { - *start_offset = 0; - *end_offset = 0; - bool push = false; - int64_t start_index = device_id_ * num_rows_per_shard_; - if (device_id_ + 1 < 0) { - MS_LOG(ERROR) << "Device id is invalid"; - return false; - } - int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; - - if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { - *start_offset = start_index - pre_count; - push = true; - if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - if (pre_count >= start_index && pre_count < end_index) { - *start_offset = 0; - push = true; - if (pre_count + filename_numrows_[file_name] >= end_index) { - *end_offset = end_index - pre_count; - } else { - *end_offset = filename_numrows_[file_name]; - } - } - - return push; -} - -Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) { - int32_t queue_index = 0; - int32_t key_index = 0; - int64_t pre_count = 0; - int64_t start_offset = 0; - int64_t end_offset = 0; - bool finish = false; - bool end_of_epoch = false; - while (!finish) { - for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { - { - std::unique_lock lock(load_io_block_queue_mutex_); - if (load_io_block_queue_ == false) { - end_of_epoch = true; - break; - } - } - if (!equal_rows_per_shard_) { - if (key_index++ % num_devices_ == device_id_) { - auto ioBlock = std::make_unique(*it, kInvalidOffset, 
kInvalidOffset, IOBlock::kDeIoBlockNone);
-          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
-          queue_index = (queue_index + 1) % num_workers_;
-        }
-      } else {
-        // Do an index lookup using that key to get the filename.
-        std::string file_name = (*filename_index_)[*it];
-        if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
-          auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
-          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
-          MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
-          queue_index = (queue_index + 1) % num_workers_;
-        }
-
-        pre_count += filename_numrows_[file_name];
-      }
-    }
-    if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
-        !end_of_epoch) {
-      finish = false;
-    } else {
-      finish = true;
-    }
-  }
-  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
-  return Status::OK();
-}
-
-Status TFReaderOp::FillIOBlockNoShuffle() {
-  int32_t queue_index = 0;
-  int32_t key_index = 0;
-  int64_t pre_count = 0;
-  int64_t start_offset = 0;
-  int64_t end_offset = 0;
-  bool finish = false;
-  bool end_of_epoch = false;
-  while (!finish) {
-    // Iterate over all the keys and add one key to each block.
-    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
-      {
-        std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
-        if (load_io_block_queue_ == false) {
-          end_of_epoch = true;
-          break;
-        }
-      }
-      if (!equal_rows_per_shard_) {
-        if (key_index++ % num_devices_ == device_id_) {
-          auto ioBlock =
-            std::make_unique<FilenameBlock>(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone);
-          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
-          queue_index = (queue_index + 1) % num_workers_;
-        }
-      } else {
-        std::string file_name = it.value();
-        if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
-          auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
-          RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
-          queue_index = (queue_index + 1) % num_workers_;
-        }
-
-        pre_count += filename_numrows_[file_name];
-      }
-    }
-    if (equal_rows_per_shard_ && pre_count < (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_ &&
-        !end_of_epoch) {
-      finish = false;
-    } else {
-      finish = true;
-    }
-  }
-
-  RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
-  return Status::OK();
-}
-
-// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
-Status TFReaderOp::WaitToFillIOBlockQueue() {
-  // must be called first if called by worker spawned by taskgroup
-  TaskManager::FindMe()->Post();
-
-  std::vector<int64_t> i_keys;
-  // Generate a vector of keys that we can shuffle
-  if (shuffle_files_) {
-    for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
-      i_keys.push_back(it.key());
-    }
-  }
-  uint32_t seed = 0;
-  while (true) {
-    RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
-    io_block_queue_wait_post_.Clear();
-
-    if (finished_reading_dataset_) {
-      break;
-    }
-
-    if (shuffle_files_) {
-      shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed);
-      RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys));
-    } else {  // shuffle_files_ == false
-      RETURN_IF_NOT_OK(FillIOBlockNoShuffle());
-    }
-  }
-
-  return Status::OK();
-}
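[Editorial note: both fill paths above rely on NeedPushFileToblockQueue to intersect a file's global row range with this shard's row range [device_id * rows_per_shard, (device_id + 1) * rows_per_shard). A self-contained sketch of that interval arithmetic; the helper name and the numbers are illustrative, not from the original file:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Returns true when the file spanning global rows [file_begin, file_begin + file_rows)
// overlaps this shard's range [shard_begin, shard_end); writes the in-file row offsets.
bool ShardOverlap(int64_t file_begin, int64_t file_rows, int64_t shard_begin, int64_t shard_end,
                  int64_t *start_offset, int64_t *end_offset) {
  int64_t file_end = file_begin + file_rows;
  if (file_end <= shard_begin || file_begin >= shard_end) return false;
  *start_offset = std::max<int64_t>(shard_begin - file_begin, 0);
  *end_offset = std::min<int64_t>(shard_end - file_begin, file_rows);
  return true;
}

int main() {
  // device 1 of 2 with 100 rows per shard -> shard range [100, 200)
  int64_t s = 0, e = 0;
  if (ShardOverlap(/*file_begin=*/80, /*file_rows=*/50, /*shard_begin=*/100, /*shard_end=*/200, &s, &e)) {
    std::cout << "read rows [" << s << ", " << e << ") of the file\n";  // prints [20, 50)
  }
  return 0;
}

pre_count in the real code plays the role of file_begin: the running total of rows in all files already visited.]

-// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.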
-void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); }
-
-// Pops an element from a queue in io_block_queues
-Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block) {
-  RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block));
-
-  return Status::OK();
-}
-
-// Pushes an element to a queue in io_block_queues
-Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block) {
-  RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block)));
-
-  return Status::OK();
-}
-
-// Reads a tf_file and loads the data into multiple buffers.
-Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
-                            const int32_t &worker_id) {
-  std::ifstream reader;
-  reader.open(filename);
-  if (!reader) {
-    RETURN_STATUS_UNEXPECTED("failed to open file: " + filename);
-  }
-
-  int64_t rows_read = 0;
-  int64_t rows_total = 0;
-  std::unique_ptr<DataBuffer> current_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
-  std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
-
-  while (reader.peek() != EOF) {
-    if (!load_jagged_connector_) {
-      break;
-    }
-
-    // read length
-    int64_t record_length = 0;
-    (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
-
-    // ignore crc header
-    (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
-
-    // read serialized Example
-    std::string serialized_example;
-    serialized_example.resize(record_length);
-    (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
-    if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) {
-      dataengine::Example tf_file;
-      if (!tf_file.ParseFromString(serialized_example)) {
-        std::string errMsg = "parse tfrecord failed";
-        RETURN_STATUS_UNEXPECTED(errMsg);
-      }
-      RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read));
-      rows_read++;
-    }
-
-    // ignore crc footer
-    (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
-    rows_total++;
-
-    if (rows_read == rows_per_buffer_) {
-      current_buffer->set_tensor_table(std::move(new_tensor_table));
-      RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer)));
-
-      current_buffer = std::make_unique<DataBuffer>(0, DataBuffer::BufferFlags::kDeBFlagNone);
-      new_tensor_table = std::make_unique<TensorQTable>();
-      rows_read = 0;
-    }
-  }
-
-  if (rows_read > 0) {
-    current_buffer->set_tensor_table(std::move(new_tensor_table));
-    RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer)));
-  }
-
-  return Status::OK();
-}
-
-// Parses a single row and puts the data into a tensor table.
-Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_ptr<TensorQTable> *tensor_table,
-                               int64_t row) {
-  int32_t num_columns = data_schema_->NumColumns();
-  TensorRow newRow(num_columns, nullptr);
-  (*tensor_table)->push_back(std::move(newRow));
-
-  for (int32_t col = 0; col < num_columns; ++col) {
-    const ColDescriptor current_col = data_schema_->column(col);
-    const dataengine::Features &example_features = tf_file->features();
-    const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
-    const dataengine::Feature &column_values_list = feature_map.at(current_col.name());
-    RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col));
-  }
-
-  return Status::OK();
-}
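[Editorial note: LoadFile above walks the TFRecord framing directly. Each record is a little-endian int64 payload length, a 4-byte length CRC that the reader skips, the serialized Example, and a trailing 4-byte data CRC that is also skipped. A minimal standalone reader of that framing, with CRC validation omitted just as in the original and "data.tfrecord" as an illustrative file name:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

int main() {
  std::ifstream reader("data.tfrecord", std::ios::binary);
  if (!reader) return 1;
  int64_t rows = 0;
  while (reader.peek() != EOF) {
    int64_t record_length = 0;
    reader.read(reinterpret_cast<char *>(&record_length), sizeof(int64_t));  // payload size
    reader.ignore(sizeof(int32_t));  // skip the length CRC
    std::string payload(record_length, '\0');
    reader.read(&payload[0], record_length);  // serialized Example bytes
    reader.ignore(sizeof(int32_t));  // skip the data CRC
    ++rows;
  }
  std::cout << rows << " records\n";
  return 0;
}

CountTotalRowsSectioned later in this file uses the same framing but replaces the payload read with an ignore, since it only needs the record count.]

-// Parses a single cell and puts the data into a tensor table.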
-Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table,
-                               const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
-                               int64_t row, int32_t col) {
-  const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
-  std::unique_ptr<float[]> float_array;     // For staging data from protobuf deserialization
-  const unsigned char *data_ptr = nullptr;  // Generic pointer used for populating the Tensor
-
-  // This variable will point into the above staging variables.
-  // Also used for creating shape attributes.
-  int32_t num_elements = 0;
-
-  // we build a tensor first and read directly into it, unless we need to cast first
-  std::shared_ptr<Tensor> ts;
-
-  // Depending on the type of data from the tf_file, we want to extract 2 things:
-  // 1) A pointer to the data as a const unsigned char *
-  // 2) The number of elements of the data
-  // After those are determined, we can then build the tensor to represent this data.
-  switch (column_list_type) {
-    case dataengine::Feature::KindCase::kBytesList: {
-      RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));
-
-      break;
-    }
-    case dataengine::Feature::KindCase::kFloatList: {
-      RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));
-
-      data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
-
-      // only floatList needs to create the tensor here, the other two lists read directly
-      // into the tensor
-      TensorShape current_shape = TensorShape::CreateUnknownRankShape();
-      RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
-      RETURN_IF_NOT_OK(
-        Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr));
-      break;
-    }
-    case dataengine::Feature::KindCase::kInt64List: {
-      RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
-      break;
-    }
-    case dataengine::Feature::KindCase::KIND_NOT_SET: {
-      std::string err_msg = "tf_file column list type enum is KIND_NOT_SET";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-    default: {
-      std::string err_msg = "tf_file column list type enum does not match any known DE type";
-      RETURN_STATUS_UNEXPECTED(err_msg);
-    }
-  }
-
-  (**tensor_table)[row][col] = std::move(ts);
-
-  return Status::OK();
-}
-
-// Overrides base class reset method. Cleans up any state info from its previous execution and
-// reinitializes itself so that it can be executed again, as if it was just created.
-Status TFReaderOp::Reset() {
-  // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
-  load_jagged_connector_ = true;
-
-  {
-    std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
-    load_io_block_queue_ = true;
-  }
-
-  RETURN_IF_NOT_OK(ParallelOp::Reset());
-  NotifyToFillIOBlockQueue();
-
-  return Status::OK();
-}
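[Editorial note: the bytes-list loader that follows pads variable-length byte features. When the user supplies a shape of the form [-1, d1, ..., dn], every element is padded to d1 * d2 * ... * dn bytes so the unknown leading dimension can be materialized. The rule in isolation, with illustrative values:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // shape [-1, 2, 4]: leading dim unknown, so each bytes-list element pads to 2 * 4 = 8 bytes
  std::vector<int64_t> shape = {-1, 2, 4};
  int64_t pad_size = 1;
  for (size_t i = 1; i < shape.size(); ++i) {
    if (shape[i] < 0) return 1;  // at most one unknown dimension is allowed
    pad_size *= shape[i];
  }
  std::cout << "pad each bytes-list element to " << pad_size << " bytes\n";
  return 0;
}

Without a user-provided shape, the pad size falls back to the longest element in the list, as the code below shows.]

-Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
-                                 int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
-  // kBytesList can map to the following DE types ONLY!
-  // DE_UINT8, DE_INT8
-  // Must be single byte type for each element!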
-  if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8 &&
-      current_col.type() != DataType::DE_STRING) {
-    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  const dataengine::BytesList &bytes_list = column_values_list.bytes_list();
-
-  *num_elements = bytes_list.value_size();
-
-  if (current_col.type() == DataType::DE_STRING) {
-    TensorShape shape = TensorShape::CreateScalar();
-    RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape));
-    return Status::OK();
-  }
-
-  uint64_t max_size = 0;
-  for (uint32_t i = 0; i < bytes_list.value_size(); ++i) max_size = std::max(max_size, bytes_list.value(i).size());
-
-  int64_t pad_size = max_size;
-
-  // if the user provides a shape in the form of [-1, d1, d2, ..., dn], we need to pad to d1 * d2 * ... * dn
-  if (current_col.hasShape()) {
-    TensorShape cur_shape = current_col.shape();
-    if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) {
-      int64_t new_pad_size = 1;
-      for (int i = 1; i < cur_shape.Size(); ++i) {
-        if (cur_shape[i] == TensorShape::kDimUnknown) {
-          std::string err_msg = "More than one unknown dimension in the shape of column: " + current_col.name();
-          RETURN_STATUS_UNEXPECTED(err_msg);
-        }
-        new_pad_size *= cur_shape[i];
-      }
-      pad_size = new_pad_size;
-    }
-  }
-
-  // know how many elements there are and the total bytes, create tensor here:
-  TensorShape current_shape = TensorShape::CreateScalar();
-  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, &current_shape));
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size));
-
-  return Status::OK();
-}
-
-Status TFReaderOp::LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
-                                 int32_t *num_elements, std::unique_ptr<float[]> *float_array) {
-  // kFloatList can only map to DE types:
-  // DE_FLOAT32
-  if (current_col.type() != DataType::DE_FLOAT32) {
-    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  const dataengine::FloatList &float_list = column_values_list.float_list();
-
-  // Identify how many values we have and then create a local array of these
-  // to deserialize into
-  *num_elements = float_list.value_size();
-  *float_array = std::make_unique<float[]>(*num_elements);
-  for (int i = 0; i < float_list.value_size(); ++i) {
-    (*float_array)[i] = float_list.value(i);
-  }
-
-  return Status::OK();
-}
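[Editorial note: the int64 path that follows stores every integer feature as int64 on the wire and narrows it to the column's declared DE type via one template instantiation per type. A minimal standalone version of that dispatch pattern, with hypothetical names and only two types shown:

#include <cstdint>
#include <iostream>
#include <vector>

enum class DEType { kInt32, kInt16 };

// Narrow an int64 wire list to the column's element type T.
template <typename T>
std::vector<T> CastInt64List(const std::vector<int64_t> &wire) {
  std::vector<T> out;
  out.reserve(wire.size());
  for (int64_t v : wire) out.push_back(static_cast<T>(v));
  return out;
}

int main() {
  std::vector<int64_t> wire = {1, 2, 300000};
  DEType col_type = DEType::kInt32;
  // the runtime type switch picks the template instantiation, mirroring LoadIntListSwitch
  if (col_type == DEType::kInt32) {
    auto vals = CastInt64List<int32_t>(wire);
    std::cout << vals.back() << "\n";  // 300000 fits in int32
  } else {
    auto vals = CastInt64List<int16_t>(wire);
    std::cout << vals.back() << "\n";  // int16 would truncate this value
  }
  return 0;
}
]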
-// Determines which template type to use and calls LoadIntList
-Status TFReaderOp::LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
-                                     int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
-  if (current_col.type() == DataType::DE_UINT64) {
-    RETURN_IF_NOT_OK(LoadIntList<uint64_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_INT64) {
-    RETURN_IF_NOT_OK(LoadIntList<int64_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_UINT32) {
-    RETURN_IF_NOT_OK(LoadIntList<uint32_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_INT32) {
-    RETURN_IF_NOT_OK(LoadIntList<int32_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_UINT16) {
-    RETURN_IF_NOT_OK(LoadIntList<uint16_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_INT16) {
-    RETURN_IF_NOT_OK(LoadIntList<int16_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_UINT8) {
-    RETURN_IF_NOT_OK(LoadIntList<uint8_t>(current_col, column_values_list, num_elements, tensor));
-  } else if (current_col.type() == DataType::DE_INT8) {
-    RETURN_IF_NOT_OK(LoadIntList<int8_t>(current_col, column_values_list, num_elements, tensor));
-  } else {
-    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  return Status::OK();
-}
-
-// Reads values from an int64 list and casts each value to type T, which must be an integral type
-// compatible with int64_t
-template <typename T>
-Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
-                               int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
-  if (!(current_col.type().IsInt())) {
-    std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-
-  const dataengine::Int64List &int64_list = column_values_list.int64_list();
-
-  // Identify how many values we have and then create a local array of these
-  // to deserialize into
-  *num_elements = int64_list.value_size();
-
-  // know how many elements there are, create tensor here:
-  TensorShape current_shape = TensorShape::CreateUnknownRankShape();
-  RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &current_shape));
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type()));
-
-  // Tensors are lazily allocated, this eagerly allocates memory for the tensor.
-  RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes()));
-
-  int64_t i = 0;
-  auto it = (*tensor)->begin<T>();
-  for (; it != (*tensor)->end<T>(); i++, ++it) {
-    T element = static_cast<T>(int64_list.value(i));
-    *it = element;
-  }
-
-  return Status::OK();
-}
-
-Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load) {
-  std::ifstream reader;
-  reader.open(tf_file);
-
-  // read length
-  int64_t record_length = 0;
-  (void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
-
-  // ignore crc header
-  (void)reader.ignore(static_cast<std::streamsize>(sizeof(int32_t)));
-
-  // read serialized Example
-  std::string serialized_example;
-  serialized_example.resize(record_length);
-  (void)reader.read(&serialized_example[0], static_cast<std::streamsize>(record_length));
-
-  dataengine::Example example;
-  if (!example.ParseFromString(serialized_example)) RETURN_STATUS_UNEXPECTED("parse tf_file failed");
-
-  const dataengine::Features &example_features = example.features();
-  const google::protobuf::Map<std::string, dataengine::Feature> &feature_map = example_features.feature();
-
-  if (columns_to_load.empty()) {
-    (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load),
-                         [](const auto &it) -> std::string { return it.first; });
-    std::sort(columns_to_load.begin(), columns_to_load.end());
-  }
-
-  for (const auto &curr_col_name : columns_to_load) {
-    auto it = feature_map.find(curr_col_name);
-    if (it == feature_map.end()) {
-      RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name);
-    }
-    std::string column_name = it->first;
-
-    std::string column_type;
-
-    const dataengine::Feature &feature = it->second;
-    const dataengine::Feature::KindCase kind_case = feature.kind_case();
-    switch (kind_case) {
-      case dataengine::Feature::KindCase::kBytesList:
column_type = "uint8"; - break; - - case dataengine::Feature::KindCase::kFloatList: - column_type = "float32"; - break; - - case dataengine::Feature::KindCase::kInt64List: - column_type = "int64"; - break; - - case dataengine::Feature::KindCase::KIND_NOT_SET: - RETURN_STATUS_UNEXPECTED("trying to make schema, tf_file column list type enum is KIND_NOT_SET"); - - default: - RETURN_STATUS_UNEXPECTED( - "trying to make schema, tf_file column list type enum does not match any known DE type"); - } - - RETURN_IF_NOT_OK( - data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1))); - } - - return Status::OK(); -} - -Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads, - bool estimate) { - try { - if (threads > filenames.size()) { - threads = filenames.size(); - } - - std::vector> async_results; - - int64_t chunk_size = filenames.size() / threads; - int64_t remainder = filenames.size() % threads; - - int64_t begin = 0; - int64_t end = begin; - for (int i = 0; i < threads; i++) { - end += chunk_size; - if (remainder > 0) { - end++; - remainder--; - } - - if (estimate) { - // Parse a single file for each chunk with estimate mode on - async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1)); - } else { - // Parse the whole chunk with estimate mode off - async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end)); - } - - begin = end; - } - - int64_t total_rows = 0; - for (int i = 0; i < async_results.size(); i++) { - total_rows += async_results[i].get(); - } - - if (estimate) { - // Each thread only scans 1 file - // Estimated total rows = Average rows * total number of files - total_rows = total_rows / threads * filenames.size(); - } - - *out_total_rows = total_rows; - } catch (const std::exception &e) { - std::string err_msg = "Unexpected error occurred: "; - err_msg += e.what(); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - return Status::OK(); -} - -int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &filenames, int64_t begin, int64_t end) { - int64_t rows_read = 0; - for (int i = begin; i < end; i++) { - std::ifstream reader; - reader.open(filenames[i]); - if (!reader) { - MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << "."; - } - - while (reader.peek() != EOF) { - // read length - int64_t record_length = 0; - (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); - - // ignore crc header - (void)reader.ignore(static_cast(sizeof(int32_t))); - - // ignore tf_file contents - (void)reader.ignore(static_cast(record_length)); - - // ignore crc footer - (void)reader.ignore(static_cast(sizeof(int32_t))); - - rows_read++; - } - } - - return rows_read; -} - -// Visitor accept method for NodePass -Status TFReaderOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status TFReaderOp::ComputeColMap() { - // Construct the column name map for this operator (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} - -// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing -// a sampler for 
fetching the data. As such, any options in the tf reader need to be reset to its defaults so -// that this tf reader will produce the full set of data into the cache. -void TFReaderOp::MakeSimpleProducer() { - device_id_ = 0; - num_devices_ = 1; - total_rows_ = 0; - shuffle_files_ = false; - equal_rows_per_shard_ = false; -} - -// During tree prepare phase, operators may have specific post-operations to perform depending on -// their role. -Status TFReaderOp::PrepareNodePostAction() { - // Run common code from super class before adding TFReaderOp specific handling - RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); - - // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into - // a simpler producer of all data (no shuffling or sharding or anything) - if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { - // This sanity check had been delayed until now in the prepare loop. - // If we are not in a cache path, then we can validate the file-based sharding config. - // If we are in a cache path, there is no file-based sharding so the check is not correct in that - // situation. - if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast(num_devices_)) { - RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n"); - } - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h deleted file mode 100644 index 2613bc5e46..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ /dev/null @@ -1,420 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/util/wait_post.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/status.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" - -namespace dataengine { -class Example; -class Feature; -class BytesList; -} // namespace dataengine - -namespace mindspore { -namespace dataset { -template -class Queue; - -template -class Connector; - -class JaggedConnector; -class FilenameBlock; - -using StringIndex = AutoIndexObj; - -class TFReaderOp : public ParallelOp { - public: - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Checks if the inputs of the builder is valid. - // @return Status - the error code returned. - Status ValidateInputs() const; - - Status Build(std::shared_ptr *out_tf_reader_op); - - // Setter method. 
- // @return Builder - setter method returns reference to the builder. - Builder &SetDataSchema(std::unique_ptr data_schema) { - builder_data_schema_ = std::move(data_schema); - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetWorkerConnectorSize(int32_t size) { - builder_worker_connector_size_ = size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetNumDevices(int64_t num_dev) { - builder_num_devices_ = num_dev; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDeviceId(int64_t dev_id) { - builder_device_id_ = dev_id; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &setTotalRows(int64_t total_rows) { - builder_total_rows_ = total_rows; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetDatasetFilesList(const std::vector &dataset_files_list) { - builder_dataset_files_list_ = dataset_files_list; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetColumnsToLoad(const std::vector &columns_to_load) { - builder_columns_to_load_ = columns_to_load; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleFiles(bool shuffle_files) { - builder_shuffle_files_ = shuffle_files; - return *this; - } - - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShardEqualRows(bool shard_equal_rows) { - builder_equal_rows_per_shard_ = shard_equal_rows; - return *this; - } - - // Setter method - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - private: - std::unique_ptr builder_data_schema_; - std::shared_ptr builder_sampler_; - int32_t builder_device_id_; - int32_t builder_num_devices_; - int32_t builder_num_workers_; - int32_t builder_worker_connector_size_; - int32_t builder_op_connector_size_; - int64_t builder_rows_per_buffer_; - int64_t builder_total_rows_; - std::vector builder_dataset_files_list_; - std::vector builder_columns_to_load_; - bool builder_shuffle_files_; - bool builder_equal_rows_per_shard_; - }; - - // Constructor of TFReaderOp (2) - // @note The builder class should be used to call this constructor. - // @param num_workers - number of worker threads reading data from tf_file files. - // @param worker_connector_size - size of each internal queue. - // @param rows_per_buffer - number of rows that a full buffer will contain. 
- // @param total_num_rows - Number of rows to read - // @param dataset_files_list - list of filepaths for the dataset files. - // @param data_schema - the data schema object. - // @param op_connector_size - size of each queue in the connector that the child operator pulls from. - // @param columns_to_load - the names of the columns to load data from. - // @param shuffle_files - whether or not to shuffle the files before reading data. - // @param equal_rows_per_shard - whether or not to get equal rows for each process. - // @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes - TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, - std::vector dataset_files_list, std::unique_ptr data_schema, - int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler); - - // Default destructor - ~TFReaderOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // Instantiates the internal queues and connectors. - // @return Status - the error code returned. - Status Init(); - - // Class functor operator () override. - // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - the error code returned. - Status operator()() override; - - // Overrides base class reset method. Cleans up any state info from it's previous execution and - // reinitializes itself so that it can be executed again, as if it was just created. - // @return Status - the error code returned. - Status Reset() override; - - // Getter method - int64_t rows_per_buffer() const { return rows_per_buffer_; } - - // Reads all the provided tf_file files and counts the total number of rows. filenames will - // first be sectioned into equal parts, then sections are read in parallel. If threads is - // greater than the number of files, threads will be clamped to the number of files. - // @param out_total_tows - output parameter which contains the total number of rows - // @param filenames - a list of tf_file filenames. - // @param threads - number of threads to use to read the tf_file files. - // @param estimate - estimate mode, under this mode each threads will sample a single file from each chunk - // @return Status - the error code returned. - static Status CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads = 1, - bool estimate = false); - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "TFReaderOp"; } - - // File names getter - // @return Vector of the input file names - std::vector FileNames() { return dataset_files_list_; } - - /// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing - /// a sampler for fetching the data. 
As such, any options in the tf reader need to be reset to its defaults so - /// that this tf reader will produce the full set of data into the cache. - void MakeSimpleProducer(); - - // During tree prepare phase, operators may have specific post-operations to perform depending on - // their role. - // @notes Derived versions of this function should always call it's superclass version first - // before providing their own implementations. - Status PrepareNodePostAction() override; - - private: - // The entry point for when workers are launched. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status WorkerEntry(int32_t worker_id) override; - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. - // When the worker pops this control indicator, it will shut itself down gracefully. - // @return Status - the error code returned. - Status PostEndOfData(); - - // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker - // pops this control indicator, it will wait until the next epoch starts and then resume execution. - // @return Status - the error code returned. - Status PostEndOfEpoch(int32_t queue_index); - - // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. - // @return Status - the error code returned. - Status WaitToFillIOBlockQueue(); - - // Notifies the thread which called WaitToFillIOBlockQueue to resume execution. - void NotifyToFillIOBlockQueue(); - - // Pops an element from a queue in IOBlockQueue. - // @param index - the index of the queue to pop from. - // @param out_block - the popped element. - // @return Status - the error code returned. - Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); - - // Pushes an element to a queue in IOBlockQueue. - // @param index - the index of the queue to push to. - // @param io_block - the element to push onto the queue. - // @return Status - the error code returned. - Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); - - // Reads a tf_file file and loads the data into multiple buffers. - // @param filename - the tf_file file to read. - // @param start_offset - the start offset of file. - // @param end_offset - the end offset of file. - // @param worker_id - the id of the worker that is executing this function. - // @return Status - the error code returned. - Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, - const int32_t &worker_id); - - // Parses a single row and puts the data into a tensor table. - // @param tf_file - the row to be parsed. - // @param tensor_table - the tensor table to put the parsed data in. - // @param row - the id of the row filled in the tensor table. - // @return Status - the error code returned. - Status LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, int64_t row); - - // Parses a single cell and puts the data into a tensor table. - // @param tensor_table - the tensor table to put the parsed data in. - // @param column_values_list - the cell to parse. - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @return Status - the error code returned. 
- Status LoadFeature(const std::unique_ptr *tensor_table, const dataengine::Feature &column_values_list, - const ColDescriptor ¤t_col, int64_t row, int32_t col); - - // Reads values from a bytes list - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the bytes list to read from. - // @param elementStr - the string we read the value into. - // @return Status - the error code returned. - static Status LoadBytesList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Reads values from a float list - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the float list to read from. - // @Param numElements - number of values in the float list. - // @param float_array - the array we read the values into. - // @return Status - the error code returned. - Status LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::unique_ptr *float_array); - - // Reads values from a bytes list and casts the value to type T, must be an integral - // type compatible with int64_t - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the int list to read from. - // @Param num_elements - number of values in the int list. - // @param tensor - the tensor we read the values into. - // @return Status - the error code returned. - template - Status LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Determines which template type to use and calls LoadIntList - // @param current_col - the column descriptor containing the expected shape and type of the data. - // @param column_values_list - the cell that contains the int list to read from. - // @Param numElements - number of values in the int list. - // @param tensor - the tensor we read the values into. - // @return Status - the error code returned. - Status LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, - int32_t *num_elements, std::shared_ptr *tensor); - - // Reads one row of data from a tf file and creates a schema based on that row - // @return Status - the error code returned. - Status CreateSchema(const std::string tf_file, std::vector columns_to_load); - - // Meant to be called async. Will read files in the range [begin, end) and return the total rows - // @param filenames - a list of tf data filenames. - // @param begin - index of first file to read. - // @param end - one greater than the index of the last file to read. - // @return int63_t - the total number of rows of files read. - static int64_t CountTotalRowsSectioned(const std::vector &filenames, const int64_t begin, - const int64_t end); - // Fill IO block queue if shuffle is true - // @param i_keys - shuffle keys. - // @return Status - the error code returned. - Status FillIOBlockShuffle(const std::vector &i_keys); - - /** - * Fill IO block queue if shuffle is false - * @param i_keys - shuffle keys. - * @return Status - the error code returned. - */ - Status FillIOBlockNoShuffle(); - - // Select file and push it to the block queue. - // @param file_name - File name. 
-  // @param start_offset - Output, the offset within the file of the first row that belongs to this shard.
-  // @param end_offset - Output, the offset within the file just past the last row that belongs to this shard.
-  // @param pre_count - Total rows of previous files.
-  // @return bool - true if any rows of the file fall inside this shard's row range.
-  bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
-                                const int64_t &pre_count);
-
-  // Calculate number of rows in each shard.
-  // @return Status - the error code returned.
-  Status CalculateNumRowsPerShard();
-
-  // Private function for computing the assignment of the column name map.
-  // @return - Status
-  Status ComputeColMap() override;
-
-  int32_t device_id_;
-  int32_t num_devices_;
-  int64_t rows_per_buffer_;
-  int64_t total_rows_;
-  std::vector<std::string> dataset_files_list_;
-  std::vector<std::string> columns_to_load_;
-  bool finished_reading_dataset_;
-  bool shuffle_files_;
-  std::unique_ptr<DataSchema> data_schema_;
-  std::unique_ptr<StringIndex> filename_index_;
-  bool load_io_block_queue_;
-  bool load_jagged_connector_;
-
-  std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
-  QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
-  WaitPost io_block_queue_wait_post_;
-  std::mutex load_io_block_queue_mutex_;
-  std::map<std::string, int64_t> filename_numrows_;
-  int64_t num_rows_;
-  int64_t num_rows_per_shard_;
-  bool equal_rows_per_shard_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
deleted file mode 100644
index 27a343c973..0000000000
--- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
+++ /dev/null
@@ -1,471 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#include "dataset/engine/datasetops/source/voc_op.h" - -#include -#include -#include -#include "./tinyxml2.h" -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -using tinyxml2::XMLDocument; -using tinyxml2::XMLElement; -using tinyxml2::XMLError; -namespace mindspore { -namespace dataset { -const char kColumnImage[] = "image"; -const char kColumnTarget[] = "target"; -const char kColumnAnnotation[] = "annotation"; -const char kJPEGImagesFolder[] = "/JPEGImages/"; -const char kSegmentationClassFolder[] = "/SegmentationClass/"; -const char kAnnotationsFolder[] = "/Annotations/"; -const char kImageSetsSegmentation[] = "/ImageSets/Segmentation/"; -const char kImageSetsMain[] = "/ImageSets/Main/"; -const char kImageExtension[] = ".jpg"; -const char kSegmentationExtension[] = ".png"; -const char kAnnotationExtension[] = ".xml"; -const char kImageSetsExtension[] = ".txt"; - -VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_num_workers_ = cfg->num_parallel_workers(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); - builder_task_type_ = TaskType::Segmentation; -} - -Status VOCOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - if (builder_sampler_ == nullptr) { - const int64_t num_samples = 0; - const int64_t start_index = 0; - builder_sampler_ = std::make_shared(start_index, num_samples); - } - builder_schema_ = std::make_unique(); - if (builder_task_type_ == TaskType::Segmentation) { - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - } else if (builder_task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); - RETURN_IF_NOT_OK(builder_schema_->AddColumn( - ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); - } - *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, - builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, - builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); - return Status::OK(); -} - -Status VOCOp::Builder::SanityCheck() { - Path dir(builder_dir_); - std::string err_msg; - err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; - err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; - return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); -} - -VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, - const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) - : ParallelOp(num_workers, queue_size, std::move(sampler)), - decode_(decode), - row_cnt_(0), - buf_cnt_(0), - task_type_(task_type), - task_mode_(task_mode), - folder_path_(folder_path), - class_index_(class_index), - rows_per_buffer_(rows_per_buffer), - data_schema_(std::move(data_schema)) { - io_block_queues_.Init(num_workers_, queue_size); -} - -Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { - for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { - if ((*itr) > num_rows_) continue; - keys->push_back(*itr); - row_cnt_++; - if (row_cnt_ % rows_per_buffer_ == 0) { - RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( - std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); - keys->clear(); - } - } - return Status::OK(); -} - -Status VOCOp::operator()() { - RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); - std::unique_ptr sampler_buffer; - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - while (true) { - std::vector keys; - keys.reserve(rows_per_buffer_); - while (sampler_buffer->eoe() == false) { - std::shared_ptr sample_ids; - RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); - if (sample_ids->type() != DataType(DataType::DE_INT64)) { - RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); - } - RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - if (keys.empty() == false) { - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( - std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); - } - if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { - std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); - std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); - RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); - for (int32_t i = 0; i < num_workers_; i++) { - RETURN_IF_NOT_OK( - io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); - } - return Status::OK(); - } else { - RETURN_IF_NOT_OK( - io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); - RETURN_IF_NOT_OK(wp_.Wait()); - wp_.Clear(); - RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); - } - } -} - -void VOCOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - ParallelOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nNumber of rows: " << num_rows_ << "\nVOC Directory: " << folder_path_ << "\n\n"; - } -} - -Status VOCOp::Reset() { - 
RETURN_IF_NOT_OK(sampler_->ResetSampler()); - row_cnt_ = 0; - wp_.Set(); - return Status::OK(); -} - -Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { - if (task_type_ == TaskType::Segmentation) { - std::shared_ptr image, target; - const std::string kImageFile = - folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); - const std::string kTargetFile = - folder_path_ + std::string(kSegmentationClassFolder) + image_id + std::string(kSegmentationExtension); - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); - (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); - } else if (task_type_ == TaskType::Detection) { - std::shared_ptr image, annotation; - const std::string kImageFile = - folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); - const std::string kAnnotationFile = - folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); - RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); - RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); - (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); - } - return Status::OK(); -} - -Status VOCOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { - std::unique_ptr deq = std::make_unique(); - TensorRow trow; - for (const uint64_t &key : keys) { - RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); - deq->push_back(std::move(trow)); - } - (*db)->set_tensor_table(std::move(deq)); - return Status::OK(); -} - -Status VOCOp::WorkerEntry(int32_t worker_id) { - TaskManager::FindMe()->Post(); - int64_t buffer_id = worker_id; - std::unique_ptr io_block; - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - while (io_block != nullptr) { - if (io_block->eoe() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); - buffer_id = worker_id; - } else if (io_block->eof() == true) { - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - } else { - std::vector keys; - RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); - if (keys.empty() == true) return Status::OK(); - std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); - RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); - RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); - buffer_id += num_workers_; - } - RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); - } - RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); -} - -Status VOCOp::ParseImageIds() { - std::string image_sets_file; - if (task_type_ == TaskType::Segmentation) { - image_sets_file = - folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension); - } else if (task_type_ == TaskType::Detection) { - image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension); - } - std::ifstream in_file; - in_file.open(image_sets_file); - if (in_file.fail()) { - RETURN_STATUS_UNEXPECTED("Fail to open file: " + image_sets_file); - } - std::string id; - while (getline(in_file, id)) { - if (id.size() > 0 && id[id.size() - 1] == '\r') { - image_ids_.push_back(id.substr(0, id.size() - 1)); - } else { - image_ids_.push_back(id); - } 
-  }
-  in_file.close();
-  image_ids_.shrink_to_fit();
-  num_rows_ = image_ids_.size();
-  return Status::OK();
-}
-
-Status VOCOp::ParseAnnotationIds() {
-  std::vector<std::string> new_image_ids;
-  for (auto id : image_ids_) {
-    const std::string kAnnotationName =
-      folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
-    RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
-    if (label_map_.find(kAnnotationName) != label_map_.end()) {
-      new_image_ids.push_back(id);
-    }
-  }
-
-  if (image_ids_.size() != new_image_ids.size()) {
-    image_ids_.clear();
-    image_ids_.insert(image_ids_.end(), new_image_ids.begin(), new_image_ids.end());
-  }
-  uint32_t count = 0;
-  for (auto &label : label_index_) {
-    label.second = count++;
-  }
-
-  num_rows_ = image_ids_.size();
-  return Status::OK();
-}
-
-Status VOCOp::ParseAnnotationBbox(const std::string &path) {
-  if (!Path(path).Exists()) {
-    RETURN_STATUS_UNEXPECTED("File not found: " + path);
-  }
-  Bbox bbox;
-  XMLDocument doc;
-  XMLError e = doc.LoadFile(common::SafeCStr(path));
-  if (e != XMLError::XML_SUCCESS) {
-    RETURN_STATUS_UNEXPECTED("XML load failed");
-  }
-  XMLElement *root = doc.RootElement();
-  if (root == nullptr) {
-    RETURN_STATUS_UNEXPECTED("XML load root element error");
-  }
-  XMLElement *object = root->FirstChildElement("object");
-  if (object == nullptr) {
-    RETURN_STATUS_UNEXPECTED("No object found in " + path);
-  }
-  while (object != nullptr) {
-    std::string label_name;
-    float xmin = 0.0, ymin = 0.0, xmax = 0.0, ymax = 0.0, truncated = 0.0, difficult = 0.0;
-    XMLElement *name_node = object->FirstChildElement("name");
-    if (name_node != nullptr && name_node->GetText() != 0) label_name = name_node->GetText();
-    XMLElement *truncated_node = object->FirstChildElement("truncated");
-    if (truncated_node != nullptr) truncated = truncated_node->FloatText();
-    XMLElement *difficult_node = object->FirstChildElement("difficult");
-    if (difficult_node != nullptr) difficult = difficult_node->FloatText();
-
-    XMLElement *bbox_node = object->FirstChildElement("bndbox");
-    if (bbox_node != nullptr) {
-      XMLElement *xmin_node = bbox_node->FirstChildElement("xmin");
-      if (xmin_node != nullptr) xmin = xmin_node->FloatText();
-      XMLElement *ymin_node = bbox_node->FirstChildElement("ymin");
-      if (ymin_node != nullptr) ymin = ymin_node->FloatText();
-      XMLElement *xmax_node = bbox_node->FirstChildElement("xmax");
-      if (xmax_node != nullptr) xmax = xmax_node->FloatText();
-      XMLElement *ymax_node = bbox_node->FirstChildElement("ymax");
-      if (ymax_node != nullptr) ymax = ymax_node->FloatText();
-    } else {
-      RETURN_STATUS_UNEXPECTED("bndbox mismatch in " + path);
-    }
-    if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) && xmin > 0 &&
-        ymin > 0 && xmax > xmin && ymax > ymin) {
-      std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult};
-      bbox.emplace_back(std::make_pair(label_name, bbox_list));
-      label_index_[label_name] = 0;
-    }
-    object = object->NextSiblingElement("object");
-  }
-  if (bbox.size() > 0) label_map_[path] = bbox;
-  return Status::OK();
-}
-
-Status VOCOp::InitSampler() {
-  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
-  return Status::OK();
-}
-
-Status VOCOp::LaunchThreadsAndInitOp() {
-  if (tree_ == nullptr) {
-    RETURN_STATUS_UNEXPECTED("tree_ not set");
-  }
-  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
-  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_,
std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1))); - TaskManager::FindMe()->Post(); - RETURN_IF_NOT_OK(this->ParseImageIds()); - if (task_type_ == TaskType::Detection) { - RETURN_IF_NOT_OK(this->ParseAnnotationIds()); - } - RETURN_IF_NOT_OK(this->InitSampler()); - return Status::OK(); -} - -Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); - if (decode_ == true) { - Status rc = Decode(*tensor, tensor); - if (rc.IsError()) { - RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); - } - } - return Status::OK(); -} - -Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, - std::shared_ptr *tensor) { - Bbox bbox_info = label_map_[path]; - std::vector bbox_row; - dsize_t bbox_column_num = 0, bbox_num = 0; - for (auto box : bbox_info) { - if (label_index_.find(box.first) != label_index_.end()) { - std::vector bbox; - bbox.insert(bbox.end(), box.second.begin(), box.second.end()); - if (class_index_.find(box.first) != class_index_.end()) { - bbox.push_back(static_cast(class_index_[box.first])); - } else { - bbox.push_back(static_cast(label_index_[box.first])); - } - bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); - if (bbox_column_num == 0) { - bbox_column_num = static_cast(bbox.size()); - } - bbox_num++; - } - } - - std::vector bbox_dim = {bbox_num, bbox_column_num}; - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(), - reinterpret_cast(&bbox_row[0]))); - return Status::OK(); -} - -Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count) { - if (task_type == "Detection") { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - std::shared_ptr op; - RETURN_IF_NOT_OK( - Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - *count = static_cast(op->image_ids_.size()); - } else if (task_type == "Segmentation") { - std::shared_ptr op; - RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - *count = static_cast(op->image_ids_.size()); - } - - return Status::OK(); -} - -Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, std::map *output_class_indexing) { - std::map input_class_indexing; - for (auto p : dict) { - (void)input_class_indexing.insert(std::pair(py::reinterpret_borrow(p.first), - py::reinterpret_borrow(p.second))); - } - - if (!input_class_indexing.empty()) { - *output_class_indexing = input_class_indexing; - } else { - std::shared_ptr op; - RETURN_IF_NOT_OK( - Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); - RETURN_IF_NOT_OK(op->ParseImageIds()); - RETURN_IF_NOT_OK(op->ParseAnnotationIds()); - for (const auto label : op->label_index_) { - (*output_class_indexing).insert(std::make_pair(label.first, label.second)); - } - } - - return Status::OK(); -} -// Visitor accept method for NodePass -Status VOCOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - 
return p->RunOnNode(shared_from_base(), modified); -} - -Status VOCOp::ComputeColMap() { - // Set the column name map (base class field) - if (column_name_id_map_.empty()) { - for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { - column_name_id_map_[data_schema_->column(i).name()] = i; - } - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h deleted file mode 100644 index ec46a3c7b1..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h +++ /dev/null @@ -1,294 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ -#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/data_schema.h" -#include "dataset/engine/datasetops/parallel_op.h" -#include "dataset/engine/datasetops/source/io_block.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/path.h" -#include "dataset/util/queue.h" -#include "dataset/util/status.h" -#include "dataset/util/wait_post.h" - -namespace mindspore { -namespace dataset { -// Forward declares -template -class Queue; - -using Bbox = std::vector>>; - -class VOCOp : public ParallelOp, public RandomAccessOp { - public: - enum class TaskType { Segmentation = 0, Detection = 1 }; - - class Builder { - public: - // Constructor for Builder class of ImageFolderOp - // @param uint32_t numWrks - number of parallel workers - // @param dir - directory folder got ImageNetFolder - Builder(); - - // Destructor. - ~Builder() = default; - - // Setter method. - // @param const std::string & build_dir - // @return Builder setter method returns reference to the builder. - Builder &SetDir(const std::string &build_dir) { - builder_dir_ = build_dir; - return *this; - } - - // Setter method. - // @param const std::map &map - a class name to label map - // @return Builder setter method returns reference to the builder. - Builder &SetClassIndex(const std::map &map) { - builder_labels_to_read_ = map; - return *this; - } - - // Setter method. - // @param const std::string & task_type - // @return Builder setter method returns reference to the builder. - Builder &SetTask(const std::string &task_type) { - if (task_type == "Segmentation") { - builder_task_type_ = TaskType::Segmentation; - } else if (task_type == "Detection") { - builder_task_type_ = TaskType::Detection; - } - return *this; - } - - // Setter method. - // @param const std::string & task_mode - // @return Builder setter method returns reference to the builder. 
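For reference, ReadAnnotationToTensor above packs every kept bounding box into one row-major float buffer plus a {num_boxes, box_width} shape. Below is a minimal standalone sketch of that flattening, using illustrative types rather than the dataset classes:

// Sketch only: mirrors the flattening in VOCOp::ReadAnnotationToTensor.
// Each kept box contributes its coordinates plus a trailing label id; the
// result is a row-major buffer and the tensor shape {num_boxes, box_width}.
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

std::pair<std::vector<float>, std::vector<int64_t>> FlattenBboxes(
    const std::vector<std::pair<std::string, std::vector<float>>> &boxes,
    const std::map<std::string, int32_t> &label_index) {
  std::vector<float> buffer;
  int64_t num_boxes = 0, box_width = 0;
  for (const auto &box : boxes) {
    auto it = label_index.find(box.first);
    if (it == label_index.end()) continue;          // skip classes not requested
    std::vector<float> row(box.second);             // box coordinates
    row.push_back(static_cast<float>(it->second));  // append the label id
    buffer.insert(buffer.end(), row.begin(), row.end());
    if (box_width == 0) box_width = static_cast<int64_t>(row.size());
    ++num_boxes;
  }
  return {buffer, {num_boxes, box_width}};  // data + shape
}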
- Builder &SetMode(const std::string &task_mode) { - builder_task_mode_ = task_mode; - return *this; - } - - // Setter method. - // @param int32_t num_workers - // @return Builder setter method returns reference to the builder. - Builder &SetNumWorkers(int32_t num_workers) { - builder_num_workers_ = num_workers; - return *this; - } - - // Setter method. - // @param int32_t op_connector_size - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // Setter method. - // @param int32_t rows_per_buffer - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @param std::shared_ptr sampler - // @return Builder setter method returns reference to the builder. - Builder &SetSampler(std::shared_ptr sampler) { - builder_sampler_ = std::move(sampler); - return *this; - } - - // Setter method. - // @param bool do_decode - // @return Builder setter method returns reference to the builder. - Builder &SetDecode(bool do_decode) { - builder_decode_ = do_decode; - return *this; - } - - // Check validity of input args - // @return = The error code return - Status SanityCheck(); - - // The builder "Build" method creates the final object. - // @param std::shared_ptr *op - DatasetOp - // @return - The error code return - Status Build(std::shared_ptr *op); - - private: - bool builder_decode_; - std::string builder_dir_; - TaskType builder_task_type_; - std::string builder_task_mode_; - int32_t builder_num_workers_; - int32_t builder_op_connector_size_; - int32_t builder_rows_per_buffer_; - std::shared_ptr builder_sampler_; - std::unique_ptr builder_schema_; - std::map builder_labels_to_read_; - }; - - // Constructor - // @param TaskType task_type - task type of VOC - // @param std::string task_mode - task mode of VOC - // @param std::string folder_path - dir directory of VOC - // @param std::map class_index - input class-to-index of annotation - // @param int32_t num_workers - number of workers reading images in parallel - // @param int32_t rows_per_buffer - number of images (rows) in each buffer - // @param int32_t queue_size - connector queue size - // @param bool decode - whether to decode images - // @param std::unique_ptr data_schema - the schema of the VOC dataset - // @param std::shared_ptr sampler - sampler tells VOCOp what to read - VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, - const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, - int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler); - - // Destructor - ~VOCOp() = default; - - // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector - // @param int32_t workerId - id of each worker - // @return Status - The error code return - Status WorkerEntry(int32_t worker_id) override; - - // Main Loop of VOCOp - // Master thread: Fill IOBlockQueue, then goes to sleep - // Worker thread: pulls IOBlock from IOBlockQueue, work on it the put buffer to mOutConnector - // @return Status - The error code return - Status operator()() override; - - // A print method typically used for debugging - // @param out - // @param show_all - void Print(std::ostream &out, bool show_all) const override; - - // @param const std::string 
&dir - VOC dir path - // @param const std::string &task_type - task type of reading voc job - // @param const std::string &task_mode - task mode of reading voc job - // @param const py::dict &dict - input dict of class index - // @param int64_t *count - output rows number of VOCDataset - static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, int64_t *count); - - // @param const std::string &dir - VOC dir path - // @param const std::string &task_type - task type of reading voc job - // @param const std::string &task_mode - task mode of reading voc job - // @param const py::dict &dict - input dict of class index - // @param int64_t numSamples - samples number of VOCDataset - // @param std::map *output_class_indexing - output class index of VOCDataset - static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, - const py::dict &dict, std::map *output_class_indexing); - - /// \brief Base-class override for NodePass visitor acceptor - /// \param[in] p Pointer to the NodePass to be accepted - /// \param[out] modified Indicator if the node was changed at all - /// \return Status of the node visit - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "VOCOp"; } - - private: - // Initialize Sampler, calls sampler->Init() within - // @return Status - The error code return - Status InitSampler(); - - // Load a tensor row according to image id - // @param row_id_type row_id - id for this tensor row - // @param std::string image_id - image id - // @param TensorRow row - image & target read into this tensor row - // @return Status - The error code return - Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row); - - // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return - // @return Status - The error code return - Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); - - // @param const std::string &path - path to the image file - // @param const ColDescriptor &col - contains tensor implementation and datatype - // @param std::shared_ptr tensor - return - // @return Status - The error code return - Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor); - - // @param const std::vector &keys - keys in ioblock - // @param std::unique_ptr db - // @return Status - The error code return - Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); - - // Read image list from ImageSets - // @return Status - The error code return - Status ParseImageIds(); - - // Read annotation from Annotation folder - // @return Status - The error code return - Status ParseAnnotationIds(); - - // @param const std::string &path - path to annotation xml - // @return Status - The error code return - Status ParseAnnotationBbox(const std::string &path); - - // @param const std::shared_ptr &sample_ids - sample ids of tensor - // @param std::vector *keys - image id - // @return Status - The error code return - Status TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); - - // Called first when function is called - // @return Status - The error code return - Status LaunchThreadsAndInitOp(); - - // Reset dataset 
state - // @return Status - The error code return - Status Reset() override; - - // Private function for computing the assignment of the column name map. - // @return - Status - Status ComputeColMap() override; - - bool decode_; - int64_t row_cnt_; - int64_t buf_cnt_; - std::string folder_path_; - TaskType task_type_; - std::string task_mode_; - int32_t rows_per_buffer_; - std::unique_ptr data_schema_; - - WaitPost wp_; - std::vector image_ids_; - QueueList> io_block_queues_; - std::map class_index_; - std::map label_index_; - std::map label_map_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc deleted file mode 100644 index b9fd8a0663..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include - -#include "common/utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/datasetops/take_op.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -// Builder constructor. Creates the builder object. -TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) { - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status TakeOp::Builder::SanityCheck() const { - if (build_max_takes_ <= 0) { - std::string err_msg("Take count must be greater than 0."); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -// The builder "build" method creates the final object. -Status TakeOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(build_max_takes_, builder_op_connector_size_); - return Status::OK(); -} - -// Constructor of the TakeOp. 
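TakeOp::Builder above validates the requested take count before constructing the op (the TakeOp constructor itself follows below). A hedged standalone sketch of the same validate-then-build flow; the Status struct here is an assumed stand-in, not the dataset Status class:

// Sketch of the Builder::SanityCheck() + Build() pattern: validate first,
// construct only on success.
#include <cstdint>
#include <memory>
#include <string>

struct Status {  // assumed minimal stand-in for the real Status type
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
  static Status Error(std::string m) { return {false, std::move(m)}; }
};

class TakeSketch {
 public:
  explicit TakeSketch(int32_t count) : max_takes_(count) {}
 private:
  int32_t max_takes_;
};

Status BuildTake(int32_t count, std::shared_ptr<TakeSketch> *out) {
  if (count <= 0) return Status::Error("Take count must be greater than 0.");
  *out = std::make_shared<TakeSketch>(count);
  return Status::OK();
}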
-TakeOp::TakeOp(int32_t count, int32_t op_connector_size) - : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {} - -// A print method typically used for debugging -void TakeOp::Print(std::ostream &out, bool show_all) const { - // Always show the id and name as first line regardless if this summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << " [takes: " << max_takes_ << "]\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nTake count: " << take_count_ << "\nMax takes: " << max_takes_ << "\n\n"; - } -} - -// Main entry point for Take -Status TakeOp::operator()() { - TaskManager::FindMe()->Post(); - std::unique_ptr buf; - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - - while (buf->eof() == false) { - if (take_count_ == max_takes_) { - // Do drain Operation - while (!buf->eoe() && !buf->eof()) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - } - } - - // Loop until non EOE is received - if (buf->eoe()) { - take_count_ = 0; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf))); - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - continue; - } - - // Get buffer and push back when take_count is still small - if (take_count_ < max_takes_) { - std::unique_ptr p_buffer; - RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer)); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer))); - } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf)); - } - - take_count_ = 0; - MS_LOG(DEBUG) << "Meet the end and push-back eof buffer."; - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); - return Status::OK(); -} - -// Function FillBuffer mainly prepare the buffer for returning -Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer) { - int32_t buffer_size = (*buffer)->NumRows(); - if (take_count_ + buffer_size < max_takes_) { - *data_buffer = std::move(*buffer); - take_count_ = take_count_ + buffer_size; - } else { - MS_LOG(DEBUG) << "In last buffer: Push one buffer."; - std::unique_ptr new_tensor_table = std::make_unique(); - while (take_count_ < max_takes_) { - TensorRow new_row; - RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row)); - take_count_++; - new_tensor_table->push_back(new_row); - } - (*buffer)->set_tensor_table(std::move(new_tensor_table)); - *data_buffer = std::move(*buffer); - } - return Status::OK(); -} - -// Visitor accept method for NodePass -Status TakeOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/dataset/engine/datasetops/take_op.h deleted file mode 100644 index 07626d5f1f..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ -#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ - -#include -#include -#include -#include "dataset/engine/datasetops/pipeline_op.h" - -namespace mindspore { -namespace dataset { -class TakeOp : public PipelineOp { - public: - // The nested builder class inside of the TakeOp is used to help manage all of the arguments - // for constructing it. This take op is very simple though, so this builder is really just - // provided for a consistent look and feel for creators of Dataset operators overall. - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @param count - The number of takes to do - // @return This is a constructor. - explicit Builder(int32_t count); - - // Default destructor - ~Builder() = default; - - // The builder "build" method creates the final object. - // @return shared_ptr to the new TakeOp object - Status Build(std::shared_ptr *); - - private: - int32_t build_max_takes_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor of the TakeOp. - // @note The builder class should be used to call it - // @param count - The number of takes to do - explicit TakeOp(int32_t count, int32_t op_connector_size); - - // Destructor - ~TakeOp() = default; - - // A print method typically used for debugging - // @param out - The output stream to write output to - // @param show_all - A bool to control if you want to show all info or just a summary - void Print(std::ostream &out, bool show_all) const override; - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param ro - reference to the TakeOp to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) { - ro.Print(out, false); - return out; - } - - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. 
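The deleted take_op.cc above forwards whole buffers while the quota allows and splits only the final buffer row by row. Roughly, under assumed illustrative types:

// Sketch of the split decision in TakeOp::FillBuffer: if the whole incoming
// buffer still fits under the take limit it is forwarded as-is; otherwise
// only the rows needed to reach max_takes are kept and the rest are dropped.
#include <cstdint>
#include <vector>

using Row = std::vector<float>;  // stand-in for TensorRow

std::vector<Row> TakeFromBuffer(std::vector<Row> buffer, int32_t max_takes, int32_t *take_count) {
  const int32_t buffer_size = static_cast<int32_t>(buffer.size());
  if (*take_count + buffer_size < max_takes) {
    *take_count += buffer_size;    // fast path: pass the whole buffer through
    return buffer;
  }
  std::vector<Row> partial;
  for (int32_t i = 0; i < buffer_size && *take_count < max_takes; ++i) {
    partial.push_back(buffer[i]);  // last buffer: keep only the remaining quota
    ++(*take_count);
  }
  return partial;
}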
- Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "TakeOp"; } - - private: - int32_t max_takes_; // The number of takes that the user requested - int32_t take_count_; // A counter for the current number of executed takes - - Status FillBuffer(std::unique_ptr *buffer, std::unique_ptr *data_buffer); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc deleted file mode 100644 index 70bce16a89..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/datasetops/zip_op.h" -#include -#include -#include "dataset/core/constants.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/engine/db_connector.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -ZipOp::Builder::Builder() { - // Some arguments to the ZipOp constructor have a default argument that is taken - // from the client config. - // The user may choose to change these values for the construction of the ZipOp by - // using the various builder set methods. - - std::shared_ptr cfg = GlobalContext::config_manager(); - builder_rows_per_buffer_ = cfg->rows_per_buffer(); - builder_op_connector_size_ = cfg->op_connector_size(); -} - -Status ZipOp::Builder::SanityCheck() const { return Status::OK(); } - -Status ZipOp::Builder::Build(std::shared_ptr *ptr) { - RETURN_IF_NOT_OK(SanityCheck()); - *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_); - return Status::OK(); -} - -// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions -ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size) - : PipelineOp(op_connector_size), - children_num_(0), - rows_per_buffer_(rows_per_buffer), - buffer_id_(0), - draining_(false), - eof_(false) {} - -// destructor -ZipOp::~ZipOp() {} - -// Entry point for Zip, called by launch() -Status ZipOp::operator()() { - // The children_num_ parameter needs to be put here - children_num_ = child_.size(); - // Synchronize with TaskManager once the thread is created. - TaskManager::FindMe()->Post(); - - // initialize the iterators - for (int32_t i = 0; i < children_num_; ++i) { - // magic number 0 since Zip is not a parallel Op - child_iterators_.push_back(std::make_unique(this, 0, i)); - } - - // Loop until eof is true - while (!eof_) { - // Create tensor table and prepare it by fetching and packing the first zipped row into it. 
- std::unique_ptr curr_table = std::make_unique(); - RETURN_IF_NOT_OK(prepare(curr_table.get())); - - // If an eof got picked up during the above prepare, then we're done - if (eof_) { - break; - } - while (!draining_) { - // 1. If a previous loop iteration sent the current table out, then create a new one. - if (curr_table == nullptr) { - curr_table = std::make_unique(); - } - - // 2 fill the table. Note: draining mode might get turned on if any of the child inputs were done - RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); - - // 3 create and update buffer and send it to the out connector - if (!curr_table->empty()) { - std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); - curr_buffer->set_tensor_table(std::move(curr_table)); - MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " - << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); - buffer_id_++; - } - } - - // 4 handle drain state. - if (draining_) { - MS_LOG(DEBUG) << "Zip operator is now draining child inputs."; - RETURN_IF_NOT_OK(drainPipeline()); - // Now that we have drained child inputs, send the eoe up. - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); - } - } - - // 5 handle eof - // propagate eof here. - MS_LOG(DEBUG) << "Zip operator got EOF, propagating."; - RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); - return Status::OK(); -} - -// Handles preprocessing of the main loop, used when starting new epoch -Status ZipOp::prepare(TensorQTable *const table) { - MS_LOG(DEBUG) << "Zip operator prepares for new epoch."; - draining_ = false; - buffer_id_ = 0; - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table."); - } - // fill initial row - TensorRow new_row; - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - - // If the first row fetching resulted in eof, then we are done. - if (eof_) { - return Status::OK(); - } - if (new_row.empty()) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!"); - } - - // Pack this first row into our tensor table - table->push_back(std::move(new_row)); - - return Status::OK(); -} - -// fillBuffer always expects a new table to fill -Status ZipOp::fillBuffer(TensorQTable *const table) { - if (table == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer."); - } - TensorRow new_row; - while (table->size() < static_cast(rows_per_buffer_)) { - RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); - // Early exit the loop if we got empty row from any of our child iterations - if (new_row.empty()) { - return Status::OK(); - } - // else we got a row so pack it into the tensor table. 
- table->push_back(std::move(new_row)); - } - return Status::OK(); -} - -// fetches next zip buffer row (merged row) -Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) { - // iterate over all iterators and generate a row - for (int32_t i = 0; i < children_num_; ++i) { - TensorRow new_row = {}; - RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row)); - // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row - if (new_row.empty()) { - // If we did not get a row from any of the children, then it's the end of an epoch and we can move - // to drain state. - MS_LOG(DEBUG) << "Zip operator child iterator produced empty row."; - draining_ = true; - new_zip_row->clear(); - // If we picked up an eof here, then we are completely done. - if ((child_iterators_[i])->eof_handled()) { - MS_LOG(DEBUG) << "Zip operator iterator got EOF."; - eof_ = true; - } - return Status::OK(); - } else { - MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << "."; - // if row isn't empty then we can append the fetched row with new_zip_row - new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end()); - } - } - MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << "."; - return Status::OK(); -} - -// drain end of epoch messages from iterator for this epoch -Status ZipOp::drainPipeline() { - // we don't need to drain if we reached eof - if (eof_) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "ZipOp draining should not be done if already at eof!"); - } - for (int32_t con = 0; con < children_num_; ++con) { - MS_LOG(DEBUG) << "Zip operator draining child at " << con << "."; - RETURN_IF_NOT_OK(child_iterators_[con]->Drain()); - } - // at this point all connectors don't contain end of epoch messages. next iteration should be clean - return Status::OK(); -} - -// A function that prints info about the Operator -void ZipOp::Print(std::ostream &out, // In: The output stream to print to - bool show_all) const { // In: T/F if it should print everything - // Always show the id and name as first line regardless if this is summary or detailed print - out << "(" << std::setw(2) << operator_id_ << ") :"; - if (!show_all) { - // Call the super class for displaying any common 1-liner info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal 1-liner info for this op - out << "\n"; - } else { - // Call the super class for displaying any common detailed info - PipelineOp::Print(out, show_all); - // Then show any custom derived-internal stuff - out << "\nDatasets: " << children_num_ << "\n\n"; - } -} - -// overwrite function and handle eof -Status ZipOp::EofReceived(int32_t) { - MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now."; - return Status::OK(); -} - -// overwrite function and handle eoe -Status ZipOp::EoeReceived(int32_t) { - state_ = OpState::kDeOpIdle; - return Status::OK(); -} - -// Visitor accept method for NodePass -Status ZipOp::Accept(NodePass *p, bool *modified) { - // Downcast shared pointer then call visitor - return p->RunOnNode(shared_from_base(), modified); -} - -Status ZipOp::ComputeColMap() { - if (column_name_id_map_.empty()) { - column_name_id_map_ = {}; - for (int32_t i = 0; i < child_.size(); ++i) { - // Initializing col_name_id_map from the child. 
- const std::unordered_map col_name_id_map = child_[i]->column_name_id_map(); - int32_t colsCurrent = column_name_id_map_.size(); - // the update code below shouldn't do anything bad if the column name already exists. - for (const auto &pair : col_name_id_map) { - std::string name = pair.first; - int32_t old_id = pair.second; - // check if name already exists in column name descriptor - if (column_name_id_map_.count(name) == 1) { - RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets"); - } - column_name_id_map_[name] = old_id + colsCurrent; - } - } - MS_LOG(DEBUG) << "Setting column map:\n" << this->ColumnNameMapAsString(); - } else { - MS_LOG(WARNING) << "Column name map is already set!"; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h deleted file mode 100644 index fad3c22eaa..0000000000 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ -#define DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/engine/dataset_iterator.h" -#include "dataset/engine/datasetops/pipeline_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// forward declare -class DataBuffer; - -class ZipOp : public PipelineOp { - public: - // The nested builder class inside of the ZipOp is used to help manage all of - // the arguments for constructing it. Use the builder by setting each argument - // with the provided set methods, and then finally call the build method to execute - // the actual construction. - // NOTE: the rows per buffer with initial value 0 means to default to the number of rows from the first child - - class Builder { - public: - // Builder constructor. Creates the builder object. - // @note No default args - // @return This is a constructor. - Builder(); - - // Default destructor - ~Builder() = default; - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { - builder_rows_per_buffer_ = rows_per_buffer; - return *this; - } - - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetOpConnectorSize(int32_t op_connector_size) { - builder_op_connector_size_ = op_connector_size; - return *this; - } - - // The builder "build" method creates the ZipOp dataset Operator. 
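ZipOp::getNextTensorRow above concatenates one row from every child and abandons the partial row as soon as any child is exhausted. A compact sketch of that merge, with plain vectors standing in for TensorRow and child iterators:

// Sketch only: fetch one row from each child; if any child has run out,
// signal drain by returning no row, otherwise concatenate the child rows.
#include <cstddef>
#include <optional>
#include <vector>

using Row = std::vector<int>;  // stand-in for TensorRow

std::optional<Row> ZipOneRow(const std::vector<std::vector<Row>> &children,
                             std::vector<size_t> *cursors) {
  Row zipped;
  for (size_t i = 0; i < children.size(); ++i) {
    if ((*cursors)[i] >= children[i].size()) {
      return std::nullopt;  // one child ran out -> drain state, drop partial row
    }
    const Row &r = children[i][(*cursors)[i]++];
    zipped.insert(zipped.end(), r.begin(), r.end());  // append this child's columns
  }
  return zipped;
}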
- // @return shared_ptr to the new ZipOp object - Status Build(std::shared_ptr<ZipOp> *); - - private: - int32_t builder_rows_per_buffer_; - int32_t builder_op_connector_size_; - - Status SanityCheck() const; - }; - - // Constructor for ZipOp - // @param rows_per_buffer - number of rows in output buffer - // @param op_connector_size - connector size - ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); - - // Destructor - ~ZipOp(); - - Status EofReceived(int32_t) override; - - Status EoeReceived(int32_t) override; - - // Print function for Zip - // @param out - output stream to print to - // @param show_all - if it should print everything - void Print(std::ostream &out, bool show_all) const override; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const ZipOp &zo) { - zo.Print(out, false); - return out; - } - - // Class functor operator () override. - // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will - // provide the master loop that drives the logic for performing the work - // @return Status - The error code return - Status operator()() override; - - // Base-class override for NodePass visitor acceptor. - // @param p - Pointer to the NodePass to be accepted. - // @param modified - Whether this node visit modified the pipeline. - // @return - Status of the node visit. - Status Accept(NodePass *p, bool *modified) override; - - // Op name getter - // @return Name of the current Op - std::string Name() const override { return "ZipOp"; } - - private: - // Handles preprocessing of the main loop, used when starting a new epoch - Status prepare(TensorQTable *const table); - - // Takes a table and repeatedly adds fetched rows to it until it is full. - // @param table a table of tensors to be moved into a buffer - Status fillBuffer(TensorQTable *const table); - - // Handles the special case where an empty row has been received from a child iterator - // @note - we need to drain eoe signals from all children connectors. - // @details - when this function is called, we have encountered an eoe at a child iterator - // and have to drain rows from the other child iterators until we hit eoe from all of them - Status drainPipeline(); - - // Merges 1 row from each child iterator together - // @param new_zip_row - input and output, will be a non-empty row if all rows from child connectors are non-empty - // @details merge rows from the child iterators together. This is the main functionality of ZipOp: - // this function takes one row and fills it with tensors from rows fetched - // from the child iterators. - // @example: - // Zips multiple rows at a time, the output is stored in new_zip_row - // 1 a T - // \ | / - // 1, a, T - Status getNextTensorRow(TensorRow *const new_zip_row); - - // Computing the assignment of the column name map.
- // @return - Status - Status ComputeColMap() override; - - int32_t children_num_; - int32_t rows_per_buffer_; - int32_t buffer_id_; - bool draining_; - bool eof_; - std::vector> child_iterators_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DATASETOPS_ZIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/db_connector.h b/mindspore/ccsrc/dataset/engine/db_connector.h deleted file mode 100644 index 54909f51ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/db_connector.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_DB_CONNECTOR_H_ -#define DATASET_ENGINE_DB_CONNECTOR_H_ - -#include -#include -#include "dataset/engine/connector.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -// DbConnector is a derived class from Connector with added logic to handle EOE and EOF. -// The Connector class itself is responsible to ensure deterministic order on every run. -class DbConnector : public Connector> { - public: - // Constructor of DbConnector - // @note DbConnector will create internal N number of blocking queues, where N = nProducers. - // See Connector.h for more details. - // @param n_producers The number of threads producing data into this DbConnector. - // @param n_consumers The number of thread consuming data from this DbConnector. - // @param queue_capacity The number of element (DataBuffer) for each internal queue. - DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity) - : Connector>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {} - - // Destructor of DbConnector - ~DbConnector() = default; - - // Add a unique_ptr into the DbConnector. - // @note The caller of this add method should use std::move to pass the ownership to DbConnector. - // @param worker_id The id of a worker thread calling this method. - // @param el A rvalue reference to an element to be passed/added/pushed. - Status Add(int32_t worker_id, std::unique_ptr &&el) noexcept { - return (Connector>::Push(worker_id, std::move(el))); - } - - // Get a unique_ptr from the DbConnector. - // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer. - // This will provide/propagate the EOF to all consumer threads of this Connector. - // Thus, When the num_consumers < num_producers, there will be extra EOF messages in some of the internal queues - // and reset() must be called before reusing DbConnector. - // @param worker_id The id of a worker thread calling this method. - // @param result The address of a unique_ptr where the popped element will be placed. - // @param retry_if_eoe A flag to allow the same thread invoke pop() again if the current pop returns eoe buffer. 
- Status PopWithRetry(int32_t worker_id, std::unique_ptr *result, bool retry_if_eoe = false) noexcept { - if (result == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "[ERROR] nullptr detected when getting data from db connector"); - } else { - std::unique_lock lk(m_); - RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return (expect_consumer_ == worker_id) || end_of_file_; })); - // Once an EOF message is encountered this flag will be set and we can return early. - if (end_of_file_) { - *result = std::make_unique(0, DataBuffer::kDeBFlagEOF); - } else { - RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); - if (*result == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, - "[ERROR] nullptr detected when getting data from db connector"); - } - // Setting the internal flag once the first EOF is encountered. - if ((*result)->eof()) { - end_of_file_ = true; - } - pop_from_ = (pop_from_ + 1) % num_producers_; - } - // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set. - if (!((*result)->eoe() && retry_if_eoe)) { - expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; - } - } - out_buffers_count_++; - cv_.NotifyAll(); - return Status::OK(); - } - - private: - // A flag to indicate the end of stream has been encountered. - bool end_of_file_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_DB_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc deleted file mode 100644 index b816cb3487..0000000000 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ /dev/null @@ -1,312 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/execution_tree.h" -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/util/task_manager.h" -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/opt/pre/cache_transform_pass.h" -#include "dataset/engine/opt/post/repeat_pass.h" -#include "mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.h" -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/monitor.h" - -namespace mindspore { -namespace dataset { -// Constructor -ExecutionTree::ExecutionTree() : id_count_(0) { - tg_ = std::make_unique(); - tree_state_ = kDeTStateInit; - prepare_flags_ = kDePrepNone; - perf_monitor_ = std::make_unique(this); - profiling_manager_ = std::make_unique(this); - optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false; -} - -// Destructor -ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); } - -// Associates a DatasetOp with this tree. This assigns a valid node id to the operator and -// provides it with a link to the tree. 
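DbConnector::PopWithRetry above pops from the producer queues in strict round-robin order and latches EOF so that every later pop replays it. A single-threaded sketch of those two rules follows; the real class additionally serializes consumers with a mutex and condition variable, which is omitted here:

// Sketch only: assumes each queue is non-empty when popped; one queue per producer.
#include <cstddef>
#include <deque>
#include <vector>

enum class Flag { kNone, kEOE, kEOF };

struct ConnSketch {
  std::vector<std::deque<Flag>> queues;  // one internal queue per producer
  size_t pop_from = 0;
  bool end_of_file = false;

  Flag Pop() {
    if (end_of_file) return Flag::kEOF;         // EOF is latched and replayed
    Flag f = queues[pop_from].front();
    queues[pop_from].pop_front();
    if (f == Flag::kEOF) end_of_file = true;    // latch on the first EOF seen
    pop_from = (pop_from + 1) % queues.size();  // deterministic round-robin
    return f;
  }
};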
A node cannot form any relationships (parent/child) with -// other nodes unless they are associated with the same tree. -Status ExecutionTree::AssociateNode(const std::shared_ptr &op) { - // If we are already a part of the tree, no-op - if (op->tree_ == this) { - return Status::OK(); - } - if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { - std::string err_msg = - "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected states: " + std::to_string(static_cast(kDeTStateInit)) + " or " + - std::to_string(static_cast(kDeTStateBuilding)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // Enter the building state if we were not already there - tree_state_ = kDeTStateBuilding; - - // Assign an id to the operator - op->set_id(id_count_); - id_count_++; - - // Assign our tree into the op so that each op has a link back to the tree - op->set_tree(this); - return Status::OK(); -} - -// Sets the root node of the tree -Status ExecutionTree::AssignRoot(const std::shared_ptr &op) { - // Tree must be in building state before we can assign root to it - if (tree_state_ != kDeTStateBuilding) { - std::string err_msg = - "Invalid tree state for assigning a root node. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStateBuilding)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - - // If they didn't already call AssociateNode for this node before calling AssignRoot, - // then do so now. - if (op->operator_id_ == DatasetOp::kInvalidOperatorId) { - RETURN_IF_NOT_OK(this->AssociateNode(op)); - } - - // Then add it as the root. - root_ = op; - - return Status::OK(); -} - -// A print method typically used for debugging -void ExecutionTree::Print(std::ostream &out, const std::shared_ptr &op) const { - out << "Execution tree summary:\n" - << "-----------------------\n"; - this->PrintNode(out, op == nullptr ? root_ : op, "", true, false); - out << "\nExecution tree operator details:\n" - << "--------------------------------\n"; - this->PrintNode(out, op == nullptr ? root_ : op, "", true, true); -} - -// A helper functions for doing the recursive printing -void ExecutionTree::PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, - bool last, bool detailed) const { - // Decide which printer to use based on detailed arg. - if (!detailed) { - out << indent << "+- " << *dataset_op; - indent += (last ? " " : "| "); - } else { - dataset_op->Print(out, detailed); - } - - // Descend to children - for (int32_t i = 0; i < dataset_op->child_.size(); ++i) { - this->PrintNode(out, dataset_op->child_[i], indent, (i == (dataset_op->child_.size() - 1)), detailed); - } -} - -// Start the execution of the tree -Status ExecutionTree::Launch() { - // Tree must be built and prepared before it can be launched! - if (tree_state_ != kDeTStateReady) { - std::string err_msg = - "Invalid tree state for launching tree. 
Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStateReady)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - std::ostringstream ss; - ss << *this; - - // Profiling infrastructures need to be initialized before Op launching - if (profiling_manager_->IsProfilingEnable()) { - // Setup profiling manager - RETURN_IF_NOT_OK(profiling_manager_->Initialize()); - // Launch Monitor Thread - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Monitor Thread launched", std::ref(*perf_monitor_))); - } - - MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str(); - for (auto itr = this->begin(); itr != this->end(); ++itr) { - // An inlined operator is one that has an output connector size of 0, and it does not - // require a thread to execute. Instead, the work of this operator is executed inlined - // from the tree node directly above it (or in the case of a root node, it runs from within - // the launching tree/user thread. Do not exec any thread for an inlined op. - itr->state_ = DatasetOp::OpState::kDeOpRunning; - if (!itr->inlined()) { - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr))); - // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp - } - } - - tree_state_ = kDeTStateExecuting; - - return Status::OK(); -} - -// A function that traverse the tree in postorder then save the results in nodes -void ExecutionTree::Iterator::PostOrderTraverse(const std::shared_ptr &node) { - if (node == nullptr) { - return; - } - for (int32_t i = 0; i < node->child_.size(); ++i) { - PostOrderTraverse(node->child_[i]); - } - nodes_.push_back(node); -} - -ExecutionTree::Iterator::Iterator(const std::shared_ptr &root) : ind_(0) { - // post-order traverse the tree, if root is null, it return - PostOrderTraverse(root); - nodes_.emplace_back(nullptr); -} - -// Given the number of workers, launches the worker entry function for each. Essentially a -// wrapper for the TaskGroup handling that is stored inside the execution tree. -Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function func) { - // Launch the workers - for (int32_t i = 0; i < num_workers; ++i) { - RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i))); - } - return Status::OK(); -} - -// The driver of the prepare phase of the execution tree. -// Prepare phase consists of three sub phases -// -// 1. PrepareTreePreAction() -// Compulsory transformation/action pre optimization. -// For example, CacheOp Insertion -// -// 2. Optimize() -// Optimization transformation/action, optional -// For example, MapOp Fusion -// -// 3. PrepareTreePostAction() -// Compulsory transformation/action post optimization. 
-// For example, repeatOp inlining -// -// @return Status - The error code return -Status ExecutionTree::Prepare() { - // Pre optimization compulsory transformation - RETURN_IF_NOT_OK(this->PrepareTreePreAction()); - - // If optional optimizations are enabled - if (optimize_) { - RETURN_IF_NOT_OK(this->Optimize()); - } - - // Post optimization compulsory transformation - RETURN_IF_NOT_OK(this->PrepareTreePostAction()); - - // Existing transformation implementation, will be removed later - RETURN_IF_NOT_OK(this->PrepareDeprecated()); - return Status::OK(); -} - -Status ExecutionTree::PrepareTreePreAction() { - bool modified = false; - std::vector> pre_actions; - // Construct pre actions - MS_LOG(INFO) << "Running pre pass loops."; - pre_actions.push_back(std::make_unique()); - pre_actions.push_back(std::make_unique()); - // Apply pre action passes - for (auto &pass : pre_actions) { - RETURN_IF_NOT_OK(pass->Run(this, &modified)); - } - MS_LOG(INFO) << "Pre passes complete."; - return Status::OK(); -} - -Status ExecutionTree::PrepareTreePostAction() { - // The tree is ready to be prepared. - tree_state_ = kDeTStatePrepare; - - bool modified = false; - std::vector> post_actions; - // Construct pre actions - MS_LOG(INFO) << "Running post pass loops."; - post_actions.push_back(std::make_unique()); - - // Apply post action passes - for (auto &pass : post_actions) { - RETURN_IF_NOT_OK(pass->Run(this, &modified)); - } - MS_LOG(INFO) << "Post passes complete."; - - return Status::OK(); -} - -Status ExecutionTree::Optimize() { - // Vector of optimizations, currently only 1, add more as necessary - std::vector> optimizations; - optimizations.push_back(std::make_unique()); - // vector of flags for each optimization - std::vector modified(optimizations.size(), false); - for (auto i = 0; i < optimizations.size(); i++) { - auto m = false; - optimizations[i]->Run(this, &m); - modified[i] = m; - } - return Status::OK(); -} - -// The driver of the prepare phase of the execution tree. The prepare phase will recursively -// walk the tree to perform modifications to the tree or specific nodes within the tree to get -// it ready for execution. -// -// This driver is deprecated. -Status ExecutionTree::PrepareDeprecated() { - // Tree must be in pending prepare state before we can assign root to it - if (tree_state_ != kDeTStatePrepare) { - std::string err_msg = - "Invalid tree state for preparing the tree. Current state: " + std::to_string(static_cast(tree_state_)) + - " Expected state: " + std::to_string(static_cast(kDeTStatePrepare)); - RETURN_STATUS_UNEXPECTED(err_msg); - } - // Start the recursive prepare - RETURN_IF_NOT_OK(this->PrepareNode(root_)); - tree_state_ = kDeTStateReady; - return Status::OK(); -} - -// Recursive function used during prepare phase to visit a node and drive any pre- and post- -// node actions during a tree walk. -Status ExecutionTree::PrepareNode(const std::shared_ptr &dataset_op) { - // execute PreAction - RETURN_IF_NOT_OK(dataset_op->PrepareNodePreAction()); - - // Before going down into children, make any prepare flags updates based on this operator. 
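ExecutionTree::PrepareNode above is a classic pre-action / recurse / post-action tree walk: flags are pushed on the way down and per-node setup runs on the way back up. In outline, with an illustrative node type:

// Sketch of the walk order in PrepareNode (Node is illustrative, not DatasetOp).
#include <memory>
#include <vector>

struct Node {
  std::vector<std::shared_ptr<Node>> children;
  void PreAction() { /* e.g. set prepare flags before descending */ }
  void PostAction() { /* e.g. per-node setup once children are prepared */ }
};

void PrepareWalk(const std::shared_ptr<Node> &node) {
  node->PreAction();                  // runs on the way down
  for (const auto &child : node->children) {
    PrepareWalk(child);               // depth-first descent
  }
  node->PostAction();                 // runs on the way back up
}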
- uint32_t op_prep_flags = dataset_op->PrepareFlags(); - BitSet(&prepare_flags_, op_prep_flags); - - // Now, descend to children - for (const auto &i : dataset_op->child_) { - RETURN_IF_NOT_OK(this->PrepareNode(i)); - } - - // No more children, now we execute any prepare actions before going back up the - // the tree on recursive function - RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction()); - - // Then clear the flags from this op now that we have prepared it. - BitClear(&prepare_flags_, op_prep_flags); - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.h b/mindspore/ccsrc/dataset/engine/execution_tree.h deleted file mode 100644 index 465d200856..0000000000 --- a/mindspore/ccsrc/dataset/engine/execution_tree.h +++ /dev/null @@ -1,257 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_EXECUTION_TREE_H_ -#define DATASET_ENGINE_EXECUTION_TREE_H_ - -#include -#include -#include -#include -#include -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/util/status.h" -#include "mindspore/ccsrc/dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -// Forward declares -class TaskGroup; -class DatasetOp; -class Monitor; - -class ExecutionTree { - public: - // Prepare flags used during tree prepare phase - enum PrepareFlags { - kDePrepNone = 0, - kDePrepRepeat = 1, // Processing a repeat operation - kDePrepCache = 2 // Processing a cache operation - }; - - // State flags for the lifecycle of the tree - enum TreeState { - kDeTStateInit = 0, // The freshly initialized state after construction - kDeTStateBuilding, // The tree is being built, nodes are being added - kDeTStatePrepare, // The tree has been assigned a root node and is pending prepare - kDeTStateReady, // The tree has been prepared and is ready to be launched - kDeTStateExecuting, // The tree has been launched and is executing - kDeTStateFinished // The tree has been drained, dataset iterator received EOF - }; - - class Iterator { - public: - // Constructor - // @param root The root node to start iterating from - explicit Iterator(const std::shared_ptr &root = nullptr); - - // Destructor - ~Iterator() {} - - Iterator &operator++() { - ++ind_; - return *this; - } // prefix ++ overload - Iterator operator++(int) { - Iterator it = *this; - it.ind_ = ind_; - ind_++; - return it; - } // post-fix ++ overload - Iterator &operator--() { - --ind_; - return *this; - } // prefix -- overload - Iterator operator--(int) { - Iterator it = *this; - it.ind_ = ind_; - ind_--; - return it; - } // post-fix -- overload - DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator - std::shared_ptr operator->() { return nodes_[ind_]; } - - // getter function - // @return Shared pointer to the current operator - std::shared_ptr get() { return nodes_[ind_]; } - - bool operator==(const Iterator &rhs) { return nodes_[ind_] == 
rhs.nodes_[rhs.ind_]; } - - bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; } - - int32_t NumNodes() { return nodes_.size(); } - - private: - int32_t ind_; // the cur node our Iterator points to - std::vector> nodes_; // store the nodes in post order - void PostOrderTraverse(const std::shared_ptr &); - }; - - // Constructor - ExecutionTree(); - - // Destructor - ~ExecutionTree(); - - // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and - // provides it with a link to the tree. A node cannot form any relationships (parent/child) with - // other nodes unless they are associated with the same tree. - // @param op - The operator to associate - // @return Status - The error code return - Status AssociateNode(const std::shared_ptr &op); - - // Sets the root node of the tree - // @param op - The operator to assign as root - // @return Status - The error code return - Status AssignRoot(const std::shared_ptr &op); - - // Start the execution of the tree - // @return Status - The error code return - Status Launch(); - - /// A print method typically used for debugging - /// \param out - The output stream to write output to - void Print(std::ostream &out, const std::shared_ptr &op = nullptr) const; - - // Returns an iterator positioned at the start - // @return Iterator - The iterator - ExecutionTree::Iterator begin(const std::shared_ptr &root = nullptr) const { - return Iterator(root == nullptr ? root_ : root); - } - - // Returns an iterator positioned at the end - // @return Iterator - The iterator - ExecutionTree::Iterator end() const { return Iterator(nullptr); } - - // << Stream output operator overload - // @notes This allows you to write the debug print info using stream operators - // @param out - reference to the output stream being overloaded - // @param exe_tree - reference to the execution tree to display - // @return - the output stream must be returned - friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) { - exe_tree.Print(out); - return out; - } - - // Given the number of workers, launches the worker entry function for each. Essentially a - // wrapper for the TaskGroup handling that is stored inside the execution tree. - // @param num_workers - The number of workers to launch - // @param func - The function entry point that workers will execute - // @return Status - The error code return - Status LaunchWorkers(int32_t num_workers, std::function func); - - // Getter method - // @return shared_ptr to the root operator - std::shared_ptr root() const { return root_; } - - // Getter method - // @return the prepare flags - uint32_t PrepareFlags() const { return prepare_flags_; } - - // The driver of the prepare phase of the execution tree. - // Prepare phase consists of three sub phases - // - // 1. PrepareTreePreAction() - // Compulsory transformation/action pre optimization. - // For example, CacheOp Insertion - // - // 2. Optimize() - // Optimization transformation/action, optional - // For example, MapOp Fusion - // - // 3. PrepareTreePostAction() - // Compulsory transformation/action post optimization. - // For example, repeatOp inlining - // - // @return Status - The error code return - Status Prepare(); - - // Compulsory transformation/action pre optimization. - // @return Status - The error code return - Status PrepareTreePreAction(); - - // Compulsory transformation/action post optimization. 
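The Iterator above relies on a post-order flattening of the operator tree, so children are visited before their parents: leaves come first and the root comes last. A minimal sketch of that traversal:

// Sketch of the PostOrderTraverse step that backs ExecutionTree::Iterator.
#include <memory>
#include <vector>

struct OpNode {
  std::vector<std::shared_ptr<OpNode>> children;
};

void PostOrder(const std::shared_ptr<OpNode> &node, std::vector<std::shared_ptr<OpNode>> *out) {
  if (node == nullptr) return;
  for (const auto &child : node->children) PostOrder(child, out);
  out->push_back(node);  // a parent is appended only after all of its children
}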
- // @return Status - The error code return - Status PrepareTreePostAction(); - - // Optimization transformation/action, optional. - // @return Status - The error code return - Status Optimize(); - - // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively - // walk the tree to perform modifications to the tree or specific nodes within the tree to get - // it ready for execution. - // @return Status - The error code return - Status PrepareDeprecated(); - - // Recursive function used during prepare phase to visit a node and drive any pre- and post- - // node actions during a tree walk. - // @param op - The dataset op to work on - // @return Status - The error code return - Status PrepareNode(const std::shared_ptr &dataset_op); - - // Return the pointer to the TaskGroup - // @return raw pointer to the TaskGroup - TaskGroup *AllTasks() const { return tg_.get(); } - - // Return if the ExecutionTree is finished (iterator receives EOF). - // @return Bool - true is ExecutionTree is finished - bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; } - - // Set the ExecutionTree to Finished state. - void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; } - - // Getter for profiling manager, no ownership - ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); } - - // Set optional optimization if tree has not been prepared yet - Status SetOptimize(bool value) { - if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { - std::string optimize = (optimize_ == true) ? "true" : "false"; - std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize; - RETURN_STATUS_UNEXPECTED(msg); - } else { - optimize_ = value; - return Status::OK(); - } - } - - // Optional optimizations status - bool OptimizationEnabled() const { return optimize_; } - - private: - // A helper functions for doing the recursive printing - // @param dataset_op - The dataset op to print - // @param indent - an indent string for aligning child levels in output - // @param last - an indicator if it's the last child or not - // @param detailed - should it display the detailed node output or the summary line - void PrintNode(std::ostream &out, const std::shared_ptr &dataset_op, std::string indent, bool last, - bool detailed) const; - - std::unique_ptr tg_; // Class for worker management - std::shared_ptr root_; // The root node of the tree - int32_t id_count_; // Counter for generating operator id's - uint32_t prepare_flags_; // Flags used during tree prepare - TreeState tree_state_; // Tracking the current tree state - std::unique_ptr perf_monitor_; // Performance Monitor - std::unique_ptr profiling_manager_; // Profiling manager - bool optimize_; // Flag to enable optional optimizations -}; - -inline bool operator==(const ExecutionTree::Iterator &lhs, const ExecutionTree::Iterator &rhs) { return lhs == rhs; } -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_EXECUTION_TREE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/edge.h b/mindspore/ccsrc/dataset/engine/gnn/edge.h deleted file mode 100644 index 47314d97c2..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/edge.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_GNN_EDGE_H_
-#define DATASET_ENGINE_GNN_EDGE_H_
-
-#include <memory>
-#include <unordered_map>
-#include <utility>
-
-#include "dataset/util/status.h"
-#include "dataset/engine/gnn/feature.h"
-#include "dataset/engine/gnn/node.h"
-
-namespace mindspore {
-namespace dataset {
-namespace gnn {
-using EdgeType = int8_t;
-using EdgeIdType = int32_t;
-
-class Edge {
- public:
-  // Constructor
-  // @param EdgeIdType id - edge id
-  // @param EdgeType type - edge type
-  // @param std::shared_ptr<Node> src_node - source node
-  // @param std::shared_ptr<Node> dst_node - destination node
-  Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
-      : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {}
-
-  virtual ~Edge() = default;
-
-  // @return EdgeIdType - Returned edge id
-  EdgeIdType id() const { return id_; }
-
-  // @return EdgeType - Returned edge type
-  EdgeType type() const { return type_; }
-
-  // Get the feature of an edge
-  // @param FeatureType feature_type - type of feature
-  // @param std::shared_ptr<Feature> *out_feature - Returned feature
-  // @return Status - The error code return
-  virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
-
-  // Get nodes on the edge
-  // @param std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node - Source and destination nodes returned
-  Status GetNode(std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node) {
-    *out_node = std::make_pair(src_node_, dst_node_);
-    return Status::OK();
-  }
-
-  // Set node to edge
-  // @param const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node -
-  Status SetNode(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node) {
-    src_node_ = in_node.first;
-    dst_node_ = in_node.second;
-    return Status::OK();
-  }
-
-  // Update feature of edge
-  // @param std::shared_ptr<Feature> feature -
-  // @return Status - The error code return
-  virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
-
- protected:
-  EdgeIdType id_;
-  EdgeType type_;
-  std::shared_ptr<Node> src_node_;
-  std::shared_ptr<Node> dst_node_;
-};
-}  // namespace gnn
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_ENGINE_GNN_EDGE_H_
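The pure-virtual GetFeatures/UpdateFeature above are implemented by LocalEdge later in this patch. SetNode exists because edges are deserialized before their endpoint objects: the loader first creates edges whose endpoints are id-only placeholders, then swaps in the real Node objects once every node is in memory, as GraphLoader::GetNodesAndEdges further down does. A minimal sketch of that rewiring step against the Edge interface above (RewireEdge and node_by_id are illustrative names, not part of the deleted API):

    #include <memory>
    #include <unordered_map>
    #include <utility>

    // Point an edge at the fully loaded Node objects; placeholder endpoints
    // carry a valid id but no real state.
    Status RewireEdge(const std::shared_ptr<Edge> &edge,
                      const std::unordered_map<NodeIdType, std::shared_ptr<Node>> &node_by_id) {
      std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> endpoints;
      RETURN_IF_NOT_OK(edge->GetNode(&endpoints));  // placeholder src/dst
      auto src = node_by_id.find(endpoints.first->id());
      auto dst = node_by_id.find(endpoints.second->id());
      if (src == node_by_id.end() || dst == node_by_id.end()) {
        RETURN_STATUS_UNEXPECTED("edge references an unknown node id");
      }
      return edge->SetNode({src->second, dst->second});  // swap in the real objects
    }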
- */ -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/feature.h b/mindspore/ccsrc/dataset/engine/gnn/feature.h deleted file mode 100644 index 7ce5967fbd..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/feature.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_FEATURE_H_ -#define DATASET_ENGINE_GNN_FEATURE_H_ - -#include - -#include "dataset/core/tensor.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -namespace gnn { -using FeatureType = int16_t; - -class Feature { - public: - // Constructor - // @param FeatureType type_name - feature type - // @param std::shared_ptr value - feature value - Feature(FeatureType type_name, std::shared_ptr value); - - ~Feature() = default; - - // Get feature value - // @return std::shared_ptr *out_value - feature value - const std::shared_ptr Value() const { return value_; } - - // @return NodeIdType - Returned feature type - FeatureType type() const { return type_name_; } - - private: - FeatureType type_name_; - std::shared_ptr value_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc deleted file mode 100644 index bf67772fe5..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph.cc +++ /dev/null @@ -1,681 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/gnn/graph.h" - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor_shape.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -Graph::Graph(std::string dataset_file, int32_t num_workers) - : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { - rnd_.seed(GetSeed()); - MS_LOG(INFO) << "num_workers:" << num_workers; -} - -Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { - auto itr = node_type_map_.find(node_type); - if (itr == node_type_map_.end()) { - std::string err_msg = "Invalid node type:" + std::to_string(node_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); - } - return Status::OK(); -} - -template -Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, - std::shared_ptr *out) { - if (!type.IsCompatible()) { - RETURN_STATUS_UNEXPECTED("Data type not compatible"); - } - if (data.empty()) { - RETURN_STATUS_UNEXPECTED("Input data is empty"); - } - std::shared_ptr tensor; - size_t m = data.size(); - size_t n = data[0].size(); - RETURN_IF_NOT_OK(Tensor::CreateTensor( - &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); - auto ptr = tensor->begin(); - for (const auto &id_m : data) { - CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); - for (const auto &id_n : id_m) { - *ptr = id_n; - ptr++; - } - } - tensor->Squeeze(); - *out = std::move(tensor); - return Status::OK(); -} - -template -Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { - if (!data || data->empty()) { - RETURN_STATUS_UNEXPECTED("Input data is empty"); - } - for (std::vector &vec : *data) { - size_t size = vec.size(); - if (size > max_size) { - RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal"); - } else { - for (size_t i = 0; i < (max_size - size); ++i) { - vec.push_back(default_value); - } - } - } - return Status::OK(); -} - -Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { - auto itr = edge_type_map_.find(edge_type); - if (itr == edge_type_map_.end()) { - std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); - } - return Status::OK(); -} - -Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { - if (edge_list.empty()) { - RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); - } - - std::vector> node_list; - node_list.reserve(edge_list.size()); - for (const auto &edge_id : edge_list) { - auto itr = edge_id_map_.find(edge_id); - if (itr == edge_id_map_.end()) { - std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - std::pair, std::shared_ptr> nodes; - RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); - node_list.push_back({nodes.first->id(), nodes.second->id()}); - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(node_list, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - 
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); - - std::vector> neighbors; - size_t max_neighbor_num = 0; - neighbors.resize(node_list.size()); - for (size_t i = 0; i < node_list.size(); ++i) { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); - RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); - max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); - } - - RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); - - return Status::OK(); -} - -Status Graph::CheckSamplesNum(NodeIdType samples_num) { - NodeIdType all_nodes_number = - std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, - [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); - if ((samples_num < 1) || (samples_num > all_nodes_number)) { - std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + - ", got " + std::to_string(samples_num); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -Status Graph::CheckNeighborType(NodeType neighbor_type) { - if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { - std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } - return Status::OK(); -} - -Status Graph::GetSampledNeighbors(const std::vector &node_list, - const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), - "The sizes of neighbor_nums and neighbor_types are inconsistent."); - for (const auto &num : neighbor_nums) { - RETURN_IF_NOT_OK(CheckSamplesNum(num)); - } - for (const auto &type : neighbor_types) { - RETURN_IF_NOT_OK(CheckNeighborType(type)); - } - std::vector> neighbors_vec(node_list.size()); - for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { - std::shared_ptr input_node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); - neighbors_vec[node_idx].emplace_back(node_list[node_idx]); - std::vector input_list = {node_list[node_idx]}; - for (size_t i = 0; i < neighbor_nums.size(); ++i) { - std::vector neighbors; - neighbors.reserve(input_list.size() * neighbor_nums[i]); - for (const auto &node_id : input_list) { - if (node_id == kDefaultNodeId) { - for (int32_t j = 0; j < neighbor_nums[i]; ++j) { - neighbors.emplace_back(kDefaultNodeId); - } - } else { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); - std::vector out; - RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); - neighbors.insert(neighbors.end(), out.begin(), out.end()); - } - } - neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); - input_list = std::move(neighbors); - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples) { - CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); - std::vector shuffled_id(data.size()); - std::iota(shuffled_id.begin(), shuffled_id.end(), 0); - 
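In GetSampledNeighbors above, each output row is the concatenation of the sampling frontier at every hop: the start node itself, then neighbor_nums[0] first-hop ids, then neighbor_nums[0] * neighbor_nums[1] second-hop ids, and so on; dead ends are filled with kDefaultNodeId so every row keeps the same width. A small standalone sketch of that width arithmetic (SampledRowWidth is an illustrative name):

    #include <cstddef>
    #include <vector>

    // Row width per start node: 1 + n0 + n0*n1 + ... for neighbor_nums = {n0, n1, ...}.
    size_t SampledRowWidth(const std::vector<int> &neighbor_nums) {
      size_t width = 1;     // the start node
      size_t frontier = 1;  // nodes expanded at the current hop
      for (int n : neighbor_nums) {
        frontier *= static_cast<size_t>(n);
        width += frontier;
      }
      return width;
    }

    // neighbor_nums = {2, 3}: width = 1 + 2 + 6 = 9 entries per start node.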
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); - for (const auto &index : shuffled_id) { - if (exclude_data.find(data[index]) != exclude_data.end()) { - continue; - } - out_samples->emplace_back(data[index]); - if (out_samples->size() >= samples_num) { - break; - } - } - return Status::OK(); -} - -Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); - RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); - - std::vector> neg_neighbors_vec; - neg_neighbors_vec.resize(node_list.size()); - for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); - std::vector neighbors; - RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); - std::unordered_set exclude_nodes; - std::transform(neighbors.begin(), neighbors.end(), - std::insert_iterator>(exclude_nodes, exclude_nodes.begin()), - [](const NodeIdType node) { return node; }); - const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; - neg_neighbors_vec[node_idx].emplace_back(node->id()); - if (all_nodes.size() > exclude_nodes.size()) { - while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { - RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), - &neg_neighbors_vec[node_idx])); - } - } else { - MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() - << " neg_neighbor_type:" << neg_neighbor_type; - // If there are no negative neighbors, they are filled with kDefaultNodeId - for (int32_t i = 0; i < samples_num; ++i) { - neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); - } - } - } - RETURN_IF_NOT_OK(CreateTensorByVector(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out) { - RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); - std::vector> walks; - RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); - RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); - return Status::OK(); -} - -Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = default_node_feature_map_.find(feature_type); - if (itr == default_node_feature_map_.end()) { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *out_feature = itr->second; - } - return Status::OK(); -} - -Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = default_edge_feature_map_.find(feature_type); - if (itr == default_edge_feature_map_.end()) { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *out_feature = itr->second; - } - return Status::OK(); -} - -Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out) { - if (!nodes || nodes->Size() == 0) { - RETURN_STATUS_UNEXPECTED("Input nodes is empty"); - } - 
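NegativeSample, completed just above, avoids repeated rejection sampling: it shuffles the candidate pool once, walks it linearly, and skips anything in the exclusion set (the node's true neighbors) until enough ids are collected; GetNegSampledNeighbors then falls back to kDefaultNodeId when the pool is exhausted. A condensed standalone sketch of that shuffle-and-filter scheme (it shuffles values instead of indices, which is equivalent; names are illustrative):

    #include <algorithm>
    #include <random>
    #include <unordered_set>
    #include <vector>

    // Draw up to `want` ids from `pool`, excluding known positives.
    std::vector<int> SampleNegatives(std::vector<int> pool, const std::unordered_set<int> &exclude,
                                     size_t want, std::mt19937 *rng) {
      std::shuffle(pool.begin(), pool.end(), *rng);
      std::vector<int> picked;
      for (int id : pool) {
        if (exclude.count(id) > 0) continue;  // skip real neighbors
        picked.push_back(id);
        if (picked.size() >= want) break;
      }
      return picked;  // may be shorter than `want` if most of the pool is excluded
    }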
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); - TensorRow tensors; - for (const auto &f_type : feature_types) { - std::shared_ptr default_feature; - // If no feature can be obtained, fill in the default value - RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); - - TensorShape shape(default_feature->Value()->shape()); - auto shape_vec = nodes->shape().AsVector(); - dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); - shape = shape.PrependDim(size); - std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); - - dsize_t index = 0; - for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { - std::shared_ptr feature; - if (*node_itr == kDefaultNodeId) { - feature = default_feature; - } else { - std::shared_ptr node; - RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); - if (!node->GetFeatures(f_type, &feature).IsOk()) { - feature = default_feature; - } - } - RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); - index++; - } - - TensorShape reshape(nodes->shape()); - for (auto s : default_feature->Value()->shape().AsVector()) { - reshape = reshape.AppendDim(s); - } - RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); - fea_tensor->Squeeze(); - tensors.push_back(fea_tensor); - } - *out = std::move(tensors); - return Status::OK(); -} - -Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, - TensorRow *out) { - if (!edges || edges->Size() == 0) { - RETURN_STATUS_UNEXPECTED("Input edges is empty"); - } - CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); - TensorRow tensors; - for (const auto &f_type : feature_types) { - std::shared_ptr default_feature; - // If no feature can be obtained, fill in the default value - RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature)); - - TensorShape shape(default_feature->Value()->shape()); - auto shape_vec = edges->shape().AsVector(); - dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); - shape = shape.PrependDim(size); - std::shared_ptr fea_tensor; - RETURN_IF_NOT_OK( - Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); - - dsize_t index = 0; - for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { - std::shared_ptr edge; - RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); - std::shared_ptr feature; - if (!edge->GetFeatures(f_type, &feature).IsOk()) { - feature = default_feature; - } - RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); - index++; - } - - TensorShape reshape(edges->shape()); - for (auto s : default_feature->Value()->shape().AsVector()) { - reshape = reshape.AppendDim(s); - } - RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); - fea_tensor->Squeeze(); - tensors.push_back(fea_tensor); - } - *out = std::move(tensors); - return Status::OK(); -} - -Status Graph::Init() { - RETURN_IF_NOT_OK(LoadNodeAndEdge()); - return Status::OK(); -} - -Status Graph::GetMetaInfo(MetaInfo *meta_info) { - meta_info->node_type.resize(node_type_map_.size()); - std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), - [](auto itr) { return itr.first; }); - std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); - - meta_info->edge_type.resize(edge_type_map_.size()); - 
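One detail of GetNodeFeature/GetEdgeFeature above is worth spelling out: a missing feature never fails the batch. At load time a zero tensor of the right shape and type is registered per feature type (default_node_feature_map_ / default_edge_feature_map_), and any node or edge lacking the requested type contributes that default instead. The per-element lookup, condensed into a sketch reusing the Node/Feature types of this file (FeatureOrDefault is an illustrative name):

    #include <memory>

    // Return the node's feature of `type`, or the pre-built zero tensor when absent.
    std::shared_ptr<Feature> FeatureOrDefault(const std::shared_ptr<Node> &node, FeatureType type,
                                              const std::shared_ptr<Feature> &default_feature) {
      std::shared_ptr<Feature> feature;
      if (node == nullptr || !node->GetFeatures(type, &feature).IsOk()) {
        return default_feature;  // zero-filled, same shape/type as real values
      }
      return feature;
    }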
std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), - [](auto itr) { return itr.first; }); - std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); - - for (const auto &node : node_type_map_) { - meta_info->node_num[node.first] = node.second.size(); - } - - for (const auto &edge : edge_type_map_) { - meta_info->edge_num[edge.first] = edge.second.size(); - } - - for (const auto &node_feature : node_feature_map_) { - for (auto type : node_feature.second) { - meta_info->node_feature_type.emplace_back(type); - } - } - std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); - auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); - meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); - - for (const auto &edge_feature : edge_feature_map_) { - for (const auto &type : edge_feature.second) { - meta_info->edge_feature_type.emplace_back(type); - } - } - std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); - auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); - meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -Status Graph::GraphInfo(py::dict *out) { - MetaInfo meta_info; - RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); - (*out)["node_type"] = py::cast(meta_info.node_type); - (*out)["edge_type"] = py::cast(meta_info.edge_type); - (*out)["node_num"] = py::cast(meta_info.node_num); - (*out)["edge_num"] = py::cast(meta_info.edge_num); - (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); - (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); - return Status::OK(); -} -#endif - -Status Graph::LoadNodeAndEdge() { - GraphLoader gl(dataset_file_, num_workers_); - // ask graph_loader to load everything into memory - RETURN_IF_NOT_OK(gl.InitAndLoad()); - // get all maps - RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, - &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, - &default_edge_feature_map_)); - return Status::OK(); -} - -Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { - auto itr = node_id_map_.find(id); - if (itr == node_id_map_.end()) { - std::string err_msg = "Invalid node id:" + std::to_string(id); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *node = itr->second; - } - return Status::OK(); -} - -Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { - auto itr = edge_id_map_.find(id); - if (itr == edge_id_map_.end()) { - std::string err_msg = "Invalid edge id:" + std::to_string(id); - RETURN_STATUS_UNEXPECTED(err_msg); - } else { - *edge = itr->second; - } - return Status::OK(); -} - -Graph::RandomWalkBase::RandomWalkBase(Graph *graph) - : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} - -Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, const NodeIdType default_node, - int32_t num_walks, int32_t num_workers) { - CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); - node_list_ = node_list; - if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { - std::string err_msg = "Failed, meta path required between 1 and " + 
std::to_string(kMaxNumWalks) + - ". The size of input path is " + std::to_string(meta_path.size()); - RETURN_STATUS_UNEXPECTED(err_msg); - } - for (const auto &type : meta_path) { - RETURN_IF_NOT_OK(graph_->CheckNeighborType(type)); - } - meta_path_ = meta_path; - if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { - std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + - std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + - ", step_away_param: " + std::to_string(step_away_param); - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (default_node < -1) { - std::string err_msg = "Failed, default_node required to be greater or equal to -1."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (num_walks <= 0) { - std::string err_msg = "Failed, num_walks parameter required to be greater than 0"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - if (num_workers <= 0) { - std::string err_msg = "Failed, num_workers parameter required to be greater than 0"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - step_home_param_ = step_home_param; - step_away_param_ = step_away_param; - default_node_ = default_node; - num_walks_ = num_walks; - num_workers_ = num_workers; - return Status::OK(); -} - -Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { - // Simulate a random walk starting from start node. - auto walk = std::vector(1, start_node); // walk is an vector - // walk simulate - while (walk.size() - 1 < meta_path_.size()) { - // current nodE - auto cur_node_id = walk.back(); - std::shared_ptr cur_node; - RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); - - // current neighbors - std::vector cur_neighbors; - RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); - std::sort(cur_neighbors.begin(), cur_neighbors.end()); - - // break if no neighbors - if (cur_neighbors.empty()) { - break; - } - - // walk by the fist node, then by the previous 2 nodes - std::shared_ptr stochastic_index; - if (walk.size() == 1) { - RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); - } else { - NodeIdType prev_node_id = walk[walk.size() - 2]; - RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); - } - NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; - walk.push_back(next_node_id); - } - - while (walk.size() - 1 < meta_path_.size()) { - walk.push_back(default_node_); - } - - *walk_path = std::move(walk); - return Status::OK(); -} - -Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { - for (int32_t i = 0; i < num_walks_; i++) { - for (const auto &node : node_list_) { - std::vector walk; - RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); - walks->push_back(walk); - } - } - return Status::OK(); -} - -Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, - std::shared_ptr *node_probability) { - // Generate alias nodes - std::shared_ptr node; - graph_->GetNodeByNodeId(node_id, &node); - std::vector neighbors; - RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); - std::sort(neighbors.begin(), neighbors.end()); - auto non_normalized_probability = std::vector(neighbors.size(), 1.0); - *node_probability = - std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); - return Status::OK(); -} - -Status Graph::RandomWalkBase::GetEdgeProbability(const 
NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, - std::shared_ptr *edge_probability) { - // Get the alias edge setup lists for a given edge. - std::shared_ptr src_node; - graph_->GetNodeByNodeId(src, &src_node); - std::vector src_neighbors; - RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); - - std::shared_ptr dst_node; - graph_->GetNodeByNodeId(dst, &dst_node); - std::vector dst_neighbors; - RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); - - std::sort(dst_neighbors.begin(), dst_neighbors.end()); - std::vector non_normalized_probability; - for (const auto &dst_nbr : dst_neighbors) { - if (dst_nbr == src) { - non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] - continue; - } - auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); - if (it != src_neighbors.end()) { - // stay close, this node connect both src and dst - non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] - } else { - // step far away - non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] - } - } - - *edge_probability = - std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); - return Status::OK(); -} - -StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { - uint32_t K = probability.size(); - std::vector switch_to_large_index(K, 0); - std::vector weight(K, .0); - std::vector smaller; - std::vector larger; - auto random_device = GetRandomDevice(); - std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); - float accumulate_threshold = 0.0; - for (uint32_t i = 0; i < K; i++) { - float threshold_one = distribution(random_device); - accumulate_threshold += threshold_one; - weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; - weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); - } - - while ((!smaller.empty()) && (!larger.empty())) { - uint32_t small = smaller.back(); - smaller.pop_back(); - uint32_t large = larger.back(); - larger.pop_back(); - switch_to_large_index[small] = large; - weight[large] = weight[large] + weight[small] - 1.0; - weight[large] < 1.0 ? 
diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h
deleted file mode 100644
index 7a50440b27..0000000000
--- a/mindspore/ccsrc/dataset/engine/gnn/graph.h
+++ /dev/null
@@ -1,267 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_ENGINE_GNN_GRAPH_H_
-#define DATASET_ENGINE_GNN_GRAPH_H_
-
-#include <map>
-#include <memory>
-#include <random>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/core/tensor_row.h"
-#include "dataset/engine/gnn/graph_loader.h"
-#include "dataset/engine/gnn/feature.h"
-#include "dataset/engine/gnn/node.h"
-#include "dataset/engine/gnn/edge.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-namespace gnn {
-
-const float kGnnEpsilon = 0.0001;
-const uint32_t kMaxNumWalks = 80;
-using StochasticIndex = std::pair<std::vector<uint32_t>, std::vector<float>>;
-
-struct MetaInfo {
-  std::vector<NodeType> node_type;
-  std::vector<EdgeType> edge_type;
-  std::map<NodeType, NodeIdType> node_num;
-  std::map<EdgeType, EdgeIdType> edge_num;
-  std::vector<FeatureType> node_feature_type;
-  std::vector<FeatureType> edge_feature_type;
-};
-
-class Graph {
- public:
-  // Constructor
-  // @param std::string dataset_file -
-  // @param int32_t num_workers - number of parallel threads
-  Graph(std::string dataset_file, int32_t num_workers);
-
-  ~Graph() = default;
-
-  // Get all nodes from the graph.
-  // @param NodeType node_type - type of node
-  // @param std::shared_ptr<Tensor> *out - Returned nodes id
-  // @return Status - The error code return
-  Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
-
-  // Get all edges from the graph.
- // @param NodeType edge_type - type of edge - // @param std::shared_ptr *out - Returned edge ids - // @return Status - The error code return - Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); - - // Get the node id from the edge. - // @param std::vector edge_list - List of edges - // @param std::shared_ptr *out - Returned node ids - // @return Status - The error code return - Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); - - // All neighbors of the acquisition node. - // @param std::vector node_list - List of nodes - // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported - // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is - // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors - // is not enough, fill in tensor as -1. - // @return Status - The error code return - Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out); - - // Get sampled neighbors. - // @param std::vector node_list - List of nodes - // @param std::vector neighbor_nums - Number of neighbors sampled per hop - // @param std::vector neighbor_types - Neighbor type sampled per hop - // @param std::shared_ptr *out - Returned neighbor's id. - // @return Status - The error code return - Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out); - - // Get negative sampled neighbors. - // @param std::vector node_list - List of nodes - // @param NodeIdType samples_num - Number of neighbors sampled - // @param NodeType neg_neighbor_type - The type of negative neighbor. - // @param std::shared_ptr *out - Returned negative neighbor's id. - // @return Status - The error code return - Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out); - - // Node2vec random walk. - // @param std::vector node_list - List of nodes - // @param std::vector meta_path - node type of each step - // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm - // @param NodeIdType default_node - default node id - // @param std::shared_ptr *out - Returned nodes id in walk path - // @return Status - The error code return - Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out); - - // Get the feature of a node - // @param std::shared_ptr nodes - List of nodes - // @param std::vector feature_types - Types of features, An error will be reported if the feature type - // does not exist. - // @param TensorRow *out - Returned features - // @return Status - The error code return - Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out); - - // Get the feature of a edge - // @param std::shared_ptr edget - List of edges - // @param std::vector feature_types - Types of features, An error will be reported if the feature type - // does not exist. 
- // @param Tensor *out - Returned features - // @return Status - The error code return - Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, - TensorRow *out); - - // Get meta information of graph - // @param MetaInfo *meta_info - Returned meta information - // @return Status - The error code return - Status GetMetaInfo(MetaInfo *meta_info); - -#ifdef ENABLE_PYTHON - // Return meta information to python layer - Status GraphInfo(py::dict *out); -#endif - - Status Init(); - - private: - class RandomWalkBase { - public: - explicit RandomWalkBase(Graph *graph); - - Status Build(const std::vector &node_list, const std::vector &meta_path, - float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, - int32_t num_walks = 1, int32_t num_workers = 1); - - ~RandomWalkBase() = default; - - Status SimulateWalk(std::vector> *walks); - - private: - Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); - - Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, - std::shared_ptr *node_probability); - - Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, - std::shared_ptr *edge_probability); - - static StochasticIndex GenerateProbability(const std::vector &probability); - - static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); - - template - std::vector Normalize(const std::vector &non_normalized_probability); - - Graph *graph_; - std::vector node_list_; - std::vector meta_path_; - float step_home_param_; // Return hyper parameter. Default is 1.0 - float step_away_param_; // Inout hyper parameter. Default is 1.0 - NodeIdType default_node_; - - int32_t num_walks_; // Number of walks per source. Default is 1 - int32_t num_workers_; // The number of worker threads. 
Default is 1 - }; - - // Load graph data from mindrecord file - // @return Status - The error code return - Status LoadNodeAndEdge(); - - // Create Tensor By Vector - // @param std::vector> &data - - // @param DataType type - - // @param std::shared_ptr *out - - // @return Status - The error code return - template - Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); - - // Complete vector - // @param std::vector> *data - To be completed vector - // @param size_t max_size - The size of the completed vector - // @param T default_value - Filled default - // @return Status - The error code return - template - Status ComplementVector(std::vector> *data, size_t max_size, T default_value); - - // Get the default feature of a node - // @param FeatureType feature_type - - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); - - // Get the default feature of a edge - // @param FeatureType feature_type - - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); - - // Find node object using node id - // @param NodeIdType id - - // @param std::shared_ptr *node - Returned node object - // @return Status - The error code return - Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); - - // Find edge object using edge id - // @param EdgeIdType id - - // @param std::shared_ptr *edge - Returned edge object - // @return Status - The error code return - Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge); - - // Negative sampling - // @param std::vector &input_data - The data set to be sampled - // @param std::unordered_set &exclude_data - Data to be excluded - // @param int32_t samples_num - - // @param std::vector *out_samples - Sampling results returned - // @return Status - The error code return - Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples); - - Status CheckSamplesNum(NodeIdType samples_num); - - Status CheckNeighborType(NodeType neighbor_type); - - std::string dataset_file_; - int32_t num_workers_; // The number of worker threads - std::mt19937 rnd_; - RandomWalkBase random_walk_; - - std::unordered_map> node_type_map_; - std::unordered_map> node_id_map_; - - std::unordered_map> edge_type_map_; - std::unordered_map> edge_id_map_; - - std::unordered_map> node_feature_map_; - std::unordered_map> edge_feature_map_; - - std::unordered_map> default_node_feature_map_; - std::unordered_map> default_edge_feature_map_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc deleted file mode 100644 index f3374954b6..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc +++ /dev/null @@ -1,260 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "dataset/engine/gnn/graph_loader.h" -#include "mindspore/ccsrc/mindrecord/include/shard_error.h" -#include "dataset/engine/gnn/local_edge.h" -#include "dataset/engine/gnn/local_node.h" -#include "dataset/util/task_manager.h" - -using ShardTuple = std::vector, mindspore::mindrecord::json>>; - -namespace mindspore { -namespace dataset { -namespace gnn { - -using mindrecord::MSRStatus; - -GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) - : mr_path_(mr_filepath), - num_workers_(num_workers), - row_id_(0), - shard_reader_(nullptr), - keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} - -Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, - EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, - EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, - DefaultEdgeFeatureMap *default_edge_feature_map) { - for (std::deque> &dq : n_deques_) { - while (dq.empty() == false) { - std::shared_ptr node_ptr = dq.front(); - n_id_map->insert({node_ptr->id(), node_ptr}); - (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); - dq.pop_front(); - } - } - - for (std::deque> &dq : e_deques_) { - while (dq.empty() == false) { - std::shared_ptr edge_ptr = dq.front(); - std::pair, std::shared_ptr> p; - RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); - auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); - CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); - CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); - RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); - RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); - e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ - (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); - dq.pop_front(); - } - } - - for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); - for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); - - MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); - return Status::OK(); -} - -Status GraphLoader::InitAndLoad() { - CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n"); - CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n"); - n_deques_.resize(num_workers_); - e_deques_.resize(num_workers_); - n_feature_maps_.resize(num_workers_); - e_feature_maps_.resize(num_workers_); - default_node_feature_maps_.resize(num_workers_); - default_edge_feature_maps_.resize(num_workers_); - TaskGroup vg; - - shard_reader_ = std::make_unique(); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, - "Fail to open" + mr_path_); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); - 
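InitAndLoad above gives every worker a private deque (n_deques_[worker_id] / e_deques_[worker_id]) plus private feature maps, so the hot parsing path needs no locks; the single-threaded GetNodesAndEdges pass afterwards merges the deques and links edges to nodes. A stripped-down sketch of that shard-then-merge pattern, with std::thread standing in for the TaskGroup used here (names are illustrative):

    #include <deque>
    #include <thread>
    #include <vector>

    // One private deque per worker: append lock-free, merge after join.
    // `fetch(worker_id, &row)` must be a thread-safe source that returns false when drained.
    template <typename Row, typename FetchFn>
    std::deque<Row> ParallelLoad(int num_workers, FetchFn fetch) {
      std::vector<std::deque<Row>> per_worker(num_workers);
      std::vector<std::thread> pool;
      for (int w = 0; w < num_workers; ++w) {
        pool.emplace_back([w, &per_worker, &fetch] {
          Row row;
          while (fetch(w, &row)) per_worker[w].push_back(row);  // touches its own deque only
        });
      }
      for (auto &t : pool) t.join();
      std::deque<Row> merged;  // single-threaded merge, as in GetNodesAndEdges
      for (auto &dq : per_worker) merged.insert(merged.end(), dq.begin(), dq.end());
      return merged;
    }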
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); - - mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; - for (const std::string &key : keys_) { - if (schema.find(key) == schema.end()) { - RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); - } - } - - // launching worker threads - for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { - RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); - } - // wait for threads to finish and check its return code - vg.join_all(Task::WaitFlag::kBlocking); - RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); - return Status::OK(); -} - -Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, - std::shared_ptr *node, NodeFeatureMap *feature_map, - DefaultNodeFeatureMap *default_feature) { - NodeIdType node_id = col_jsn["first_id"]; - NodeType node_type = static_cast(col_jsn["type"]); - (*node) = std::make_shared(node_id, node_type); - std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); - - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[node_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); - } - } - return Status::OK(); -} - -Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, - std::shared_ptr *edge, EdgeFeatureMap *feature_map, - DefaultEdgeFeatureMap *default_feature) { - EdgeIdType edge_id = col_jsn["first_id"]; - EdgeType edge_type = static_cast(col_jsn["type"]); - NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; - std::shared_ptr src = std::make_shared(src_id, -1); - std::shared_ptr dst = std::make_shared(dst_id, -1); - (*edge) = std::make_shared(edge_id, edge_type, src, dst); - std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[edge_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); - } - } - return Status::OK(); -} - -Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::shared_ptr *tensor) { - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, 
&col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, - std::move(TensorShape({static_cast(n_bytes / col_type_size)})), - std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); - return Status::OK(); -} - -Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::vector *indices) { - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); - - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - - for (int i = 0; i < n_bytes; i += col_type_size) { - int32_t feature_ind = -1; - if (col_type == mindrecord::ColumnInt32) { - feature_ind = *(reinterpret_cast(data + i)); - } else if (col_type == mindrecord::ColumnInt64) { - feature_ind = *(reinterpret_cast(data + i)); - } else { - RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); - } - if (feature_ind >= 0) indices->push_back(feature_ind); - } - return Status::OK(); -} - -Status GraphLoader::WorkerEntry(int32_t worker_id) { - // Handshake - TaskManager::FindMe()->Post(); - auto ret = shard_reader_->GetNextById(row_id_++, worker_id); - ShardTuple rows = ret.second; - while (rows.empty() == false) { - RETURN_IF_INTERRUPTED(); - for (const auto &tupled_row : rows) { - std::vector col_blob = std::get<0>(tupled_row); - mindrecord::json col_jsn = std::get<1>(tupled_row); - std::string attr = col_jsn["attribute"]; - if (attr == "n") { - std::shared_ptr node_ptr; - RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), - &default_node_feature_maps_[worker_id])); - n_deques_[worker_id].emplace_back(node_ptr); - } else if (attr == "e") { - std::shared_ptr edge_ptr; - RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), - &default_edge_feature_maps_[worker_id])); - e_deques_[worker_id].emplace_back(edge_ptr); - } else { - MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; - } - } - auto rc = shard_reader_->GetNextById(row_id_++, worker_id); - rows = rc.second; - } - return Status::OK(); -} - -void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, - DefaultNodeFeatureMap *default_node_feature_map, - DefaultEdgeFeatureMap *default_edge_feature_map) { - for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { - for (auto &m : n_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); - } - for (auto &m : e_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); - } - for (auto &m : default_node_feature_maps_[wkr_id]) { - (*default_node_feature_map)[m.first] = m.second; - } - for (auto &m : default_edge_feature_maps_[wkr_id]) { - (*default_edge_feature_map)[m.first] = m.second; - } - } - n_feature_maps_.clear(); - e_feature_maps_.clear(); -} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git 
a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h deleted file mode 100644 index 141816d633..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ -#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/graph.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/edge.h" -#include "dataset/util/status.h" -#include "mindrecord/include/shard_reader.h" -namespace mindspore { -namespace dataset { -namespace gnn { - -using mindrecord::ShardReader; -using NodeIdMap = std::unordered_map>; -using EdgeIdMap = std::unordered_map>; -using NodeTypeMap = std::unordered_map>; -using EdgeTypeMap = std::unordered_map>; -using NodeFeatureMap = std::unordered_map>; -using EdgeFeatureMap = std::unordered_map>; -using DefaultNodeFeatureMap = std::unordered_map>; -using DefaultEdgeFeatureMap = std::unordered_map>; - -// this class interfaces with the underlying storage format (mindrecord) -// it returns raw nodes and edges via GetNodesAndEdges -// it is then the responsibility of graph to construct itself based on the nodes and edges -// if needed, this class could become a base where each derived class handles a specific storage format -class GraphLoader { - public: - explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); - - ~GraphLoader() = default; - // Init mindrecord and load everything into memory multi-threaded - // @return Status - the status code - Status InitAndLoad(); - - // this function will query mindrecord and construct all nodes and edges - // nodes and edges are added to map without any connection. That's because there nodes and edges are read in - // random order. src_node and dst_node in Edge are node_id only with -1 as type. 
- // features attached to each node and edge are expected to be filled correctly - Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, - DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); - - private: - // - // worker thread that reads mindrecord file - // @param int32_t worker_id - id of each worker - // @return Status - the status code - Status WorkerEntry(int32_t worker_id); - - // Load a node based on 1 row of mindrecord, returns a shared_ptr - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *node - return value - // @param NodeFeatureMap *feature_map - - // @param DefaultNodeFeatureMap *default_feature - - // @return Status - the status code - Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, - NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); - - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected - // @param FeatureMap *feature_map - // @param DefaultEdgeFeatureMap *default_feature - - // @return Status - the status code - Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, - EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); - - // @param std::string key - column name - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::vector *ind - return value, list of feature index in int32_t - // @return Status - the status code - Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::vector *ind); - - // @param std::string &key - column name - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *tensor - return value feature tensor - // @return Status - the status code - Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::shared_ptr *tensor); - - // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 - void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); - - const int32_t num_workers_; - std::atomic_int row_id_; - std::string mr_path_; - std::unique_ptr shard_reader_; - std::vector>> n_deques_; - std::vector>> e_deques_; - std::vector n_feature_maps_; - std::vector e_feature_maps_; - std::vector default_node_feature_maps_; - std::vector default_edge_feature_maps_; - const std::vector keys_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc deleted file mode 100644 index 7465b689d5..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
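GraphLoader keeps one deque and one feature map per worker (the n_deques_, n_feature_maps_ and related members above) so the worker threads never contend, then folds them together once loading finishes. A minimal sketch of that merge step, assuming FeatureType is a small integer alias (the real alias lives in feature.h); Merge and the map widths here are illustrative, not the original signatures.

#include <cstdint>
#include <iostream>
#include <set>
#include <unordered_map>
#include <vector>

using NodeType = int8_t;
using FeatureType = int16_t;  // assumed width for illustration
using NodeFeatureMap = std::unordered_map<NodeType, std::set<FeatureType>>;

// Fold the per-worker maps into one, as MergeFeatureMaps does: union the
// feature-type sets that each worker thread collected independently.
NodeFeatureMap Merge(std::vector<NodeFeatureMap> *per_worker) {
  NodeFeatureMap merged;
  for (auto &wm : *per_worker) {
    for (auto &kv : wm) {
      merged[kv.first].insert(kv.second.begin(), kv.second.end());
    }
  }
  per_worker->clear();  // the per-worker maps are discarded after the merge
  return merged;
}

int main() {
  std::vector<NodeFeatureMap> workers(2);
  workers[0][1] = {10, 11};
  workers[1][1] = {11, 12};
  workers[1][2] = {20};
  NodeFeatureMap merged = Merge(&workers);
  for (auto &kv : merged) {  // iteration order is unspecified
    std::cout << "node type " << int(kv.first) << ':';
    for (FeatureType f : kv.second) std::cout << ' ' << f;
    std::cout << '\n';
  }
}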
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/gnn/local_edge.h" - -#include - -namespace mindspore { -namespace dataset { -namespace gnn { - -LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : Edge(id, type, src_node, dst_node) {} - -Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = features_.find(feature_type); - if (itr != features_.end()) { - *out_feature = itr->second; - return Status::OK(); - } else { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { - auto itr = features_.find(feature->type()); - if (itr != features_.end()) { - RETURN_STATUS_UNEXPECTED("Feature already exists"); - } else { - features_[feature->type()] = feature; - return Status::OK(); - } -} -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/dataset/engine/gnn/local_edge.h deleted file mode 100644 index a34fc00373..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_edge.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
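LocalEdge above stores its features in an unordered_map keyed by feature type: GetFeatures fails for an unknown type, and UpdateFeature writes each type at most once, rejecting duplicates. A compact sketch of that insert-once store, with bool standing in for Status and hypothetical names (FeatureStore, Get, Update):

#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

using FeatureType = int16_t;  // assumed width for illustration
struct Feature {
  FeatureType type;
  explicit Feature(FeatureType t) : type(t) {}
};

class FeatureStore {
 public:
  // Lookup mirrors LocalEdge::GetFeatures: fail when the type is unknown.
  bool Get(FeatureType type, std::shared_ptr<Feature> *out) const {
    auto itr = features_.find(type);
    if (itr == features_.end()) return false;
    *out = itr->second;
    return true;
  }
  // Insert mirrors LocalEdge::UpdateFeature: each feature type may be set once.
  bool Update(const std::shared_ptr<Feature> &f) {
    return features_.emplace(f->type, f).second;  // false if it already exists
  }

 private:
  std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
};

int main() {
  FeatureStore store;
  auto f = std::make_shared<Feature>(1);
  std::cout << store.Update(f) << '\n';  // 1: first insert succeeds
  std::cout << store.Update(f) << '\n';  // 0: "Feature already exists"
  std::shared_ptr<Feature> out;
  std::cout << store.Get(1, &out) << ' ' << store.Get(2, &out) << '\n';  // 1 0
}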
- */ -#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ -#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/edge.h" -#include "dataset/engine/gnn/feature.h" -#include "dataset/engine/gnn/node.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -class LocalEdge : public Edge { - public: - // Constructor - // @param EdgeIdType id - edge id - // @param EdgeType type - edge type - // @param std::shared_ptr src_node - source node - // @param std::shared_ptr dst_node - destination node - LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); - - ~LocalEdge() = default; - - // Get the feature of a edge - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; - - // Update feature of edge - // @param std::shared_ptr feature - - // @return Status - The error code return - Status UpdateFeature(const std::shared_ptr &feature) override; - - private: - std::unordered_map> features_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc deleted file mode 100644 index c829f8e8ca..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/engine/gnn/local_node.h" - -#include -#include -#include - -#include "dataset/engine/gnn/edge.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } - -Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { - auto itr = features_.find(feature_type); - if (itr != features_.end()) { - *out_feature = itr->second; - return Status::OK(); - } else { - std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { - std::vector neighbors; - auto itr = neighbor_nodes_.find(neighbor_type); - if (itr != neighbor_nodes_.end()) { - if (exclude_itself) { - neighbors.resize(itr->second.size()); - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), - [](const std::shared_ptr node) { return node->id(); }); - } else { - neighbors.resize(itr->second.size() + 1); - neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, - [](const std::shared_ptr node) { return node->id(); }); - } - } else { - MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; - if (!exclude_itself) { - neighbors.emplace_back(id_); - } - } - *out_neighbors = std::move(neighbors); - return Status::OK(); -} - -Status LocalNode::GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out) { - std::vector shuffled_id(neighbors.size()); - std::iota(shuffled_id.begin(), shuffled_id.end(), 0); - std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); - int32_t num = std::min(samples_num, static_cast(neighbors.size())); - for (int32_t i = 0; i < num; ++i) { - out->emplace_back(neighbors[shuffled_id[i]]->id()); - } - return Status::OK(); -} - -Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) { - std::vector neighbors; - neighbors.reserve(samples_num); - auto itr = neighbor_nodes_.find(neighbor_type); - if (itr != neighbor_nodes_.end()) { - while (neighbors.size() < samples_num) { - RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); - } - } else { - MS_LOG(DEBUG) << "There are no neighbors. 
node_id:" << id_ << " neighbor_type:" << neighbor_type; - // If there are no neighbors, they are filled with kDefaultNodeId - for (int32_t i = 0; i < samples_num; ++i) { - neighbors.emplace_back(kDefaultNodeId); - } - } - *out_neighbors = std::move(neighbors); - return Status::OK(); -} - -Status LocalNode::AddNeighbor(const std::shared_ptr &node) { - auto itr = neighbor_nodes_.find(node->type()); - if (itr != neighbor_nodes_.end()) { - itr->second.push_back(node); - } else { - neighbor_nodes_[node->type()] = {node}; - } - return Status::OK(); -} - -Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { - auto itr = features_.find(feature->type()); - if (itr != features_.end()) { - RETURN_STATUS_UNEXPECTED("Feature already exists"); - } else { - features_[feature->type()] = feature; - return Status::OK(); - } -} - -} // namespace gnn -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h deleted file mode 100644 index bc069d073f..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/local_node.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ -#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/node.h" -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { - -class LocalNode : public Node { - public: - // Constructor - // @param NodeIdType id - node id - // @param NodeType type - node type - LocalNode(NodeIdType id, NodeType type); - - ~LocalNode() = default; - - // Get the feature of a node - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; - - // Get the all neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, - bool exclude_itself = false) override; - - // Get the sampled neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) override; - - // Add neighbor of node - // @param std::shared_ptr node - - // @return Status - The error code return - Status AddNeighbor(const std::shared_ptr &node) override; - - // Update feature of node - // @param std::shared_ptr feature - - // @return Status - The error code return - Status 
UpdateFeature(const std::shared_ptr &feature) override; - - private: - Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out); - - std::mt19937 rnd_; - std::unordered_map> features_; - std::unordered_map>> neighbor_nodes_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h deleted file mode 100644 index 282f856797..0000000000 --- a/mindspore/ccsrc/dataset/engine/gnn/node.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_GNN_NODE_H_ -#define DATASET_ENGINE_GNN_NODE_H_ - -#include -#include -#include - -#include "dataset/util/status.h" -#include "dataset/engine/gnn/feature.h" - -namespace mindspore { -namespace dataset { -namespace gnn { -using NodeType = int8_t; -using NodeIdType = int32_t; - -constexpr NodeIdType kDefaultNodeId = -1; - -class Node { - public: - // Constructor - // @param NodeIdType id - node id - // @param NodeType type - node type - Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} - - virtual ~Node() = default; - - // @return NodeIdType - Returned node id - NodeIdType id() const { return id_; } - - // @return NodeIdType - Returned node type - NodeType type() const { return type_; } - - // Get the feature of a node - // @param FeatureType feature_type - type of feature - // @param std::shared_ptr *out_feature - Returned feature - // @return Status - The error code return - virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; - - // Get the all neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, - bool exclude_itself = false) = 0; - - // Get the sampled neighbors of a node - // @param NodeType neighbor_type - type of neighbor - // @param int32_t samples_num - Number of neighbors to be acquired - // @param std::vector *out_neighbors - Returned neighbors id - // @return Status - The error code return - virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, - std::vector *out_neighbors) = 0; - - // Add neighbor of node - // @param std::shared_ptr node - - // @return Status - The error code return - virtual Status AddNeighbor(const std::shared_ptr &node) = 0; - - // Update feature of node - // @param std::shared_ptr feature - - // @return Status - The error code return - virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; - - protected: - NodeIdType id_; - NodeType type_; -}; -} // namespace gnn -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/dataset/engine/jagged_connector.h 
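LocalNode::GetSampledNeighbors above samples by shuffling the neighbour list and taking a prefix, repeating whole rounds until enough ids are collected (so the sampling is with replacement across rounds), and it pads with kDefaultNodeId when the node has no neighbours of the requested type. A self-contained sketch of that policy; SampleNeighbors is an illustrative name.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>

using NodeIdType = int32_t;
constexpr NodeIdType kDefaultNodeId = -1;

std::vector<NodeIdType> SampleNeighbors(const std::vector<NodeIdType> &neighbors,
                                        int32_t k, std::mt19937 *rng) {
  std::vector<NodeIdType> out;
  out.reserve(k);
  if (neighbors.empty()) {
    out.assign(k, kDefaultNodeId);  // no neighbours: fill with the sentinel id
    return out;
  }
  while (static_cast<int32_t>(out.size()) < k) {
    // One round: shuffle an index permutation and take as much as still needed.
    std::vector<size_t> order(neighbors.size());
    std::iota(order.begin(), order.end(), 0);
    std::shuffle(order.begin(), order.end(), *rng);
    int32_t need = k - static_cast<int32_t>(out.size());
    int32_t take = std::min(need, static_cast<int32_t>(neighbors.size()));
    for (int32_t i = 0; i < take; ++i) out.push_back(neighbors[order[i]]);
  }
  return out;
}

int main() {
  std::mt19937 rng(0);
  for (NodeIdType id : SampleNeighbors({4, 8, 15}, 5, &rng)) std::cout << id << ' ';
  std::cout << '\n';
  for (NodeIdType id : SampleNeighbors({}, 3, &rng)) std::cout << id << ' ';  // -1 -1 -1
  std::cout << '\n';
}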
b/mindspore/ccsrc/dataset/engine/jagged_connector.h deleted file mode 100644 index 2058c542a8..0000000000 --- a/mindspore/ccsrc/dataset/engine/jagged_connector.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_JAGGED_CONNECTOR_H_ -#define DATASET_ENGINE_JAGGED_CONNECTOR_H_ - -#include -#include -#include -#include -#include "dataset/engine/connector.h" -#include "dataset/engine/data_buffer.h" -#include "dataset/util/status.h" -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { -class JaggedConnector : public Connector> { - public: - JaggedConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity) - : Connector>(num_producers, num_consumers, queue_capacity) { - for (int i = 0; i < num_producers; i++) { - is_queue_finished_.push_back(false); - } - } - - ~JaggedConnector() = default; - - Status Add(int32_t worker_d, std::unique_ptr &&element) noexcept { - return Connector>::Push(worker_d, std::move(element)); - } - - Status Pop(int32_t worker_id, std::unique_ptr *result) noexcept override { - { - MS_ASSERT(worker_id < num_consumers_); - std::unique_lock lock(m_); - RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; })); - if (is_queue_finished_[pop_from_]) { - std::string errMsg = "ERROR: popping from a finished queue in JaggedConnector"; - RETURN_STATUS_UNEXPECTED(errMsg); - } - - RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); - if ((*result)->eoe()) { - is_queue_finished_[pop_from_] = true; - } - - for (int offset = 1; offset <= num_producers_; offset++) { - int32_t nextQueueIndex = (pop_from_ + offset) % num_producers_; - if (is_queue_finished_[nextQueueIndex] == false) { - pop_from_ = nextQueueIndex; - break; - } - } - - expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; - } - - cv_.NotifyAll(); - return Status::OK(); - } - - void DoReset() { - for (int i = 0; i < is_queue_finished_.size(); i++) { - is_queue_finished_[i] = false; - } - - Connector>::Reset(); - } - - private: - std::vector is_queue_finished_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_JAGGED_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.cc b/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.cc deleted file mode 100644 index 67b742cf6e..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
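The JaggedConnector above pops from one internal queue at a time and, once a queue has delivered its end-of-epoch buffer, marks it finished and skips it when advancing the cursor, so producers may finish at different times (hence "jagged"). A single-threaded sketch of that skip-ahead policy; the real class coordinates consumers with a mutex and condition variable, which is omitted here, and JaggedQueues plus the "eoe" string marker are stand-ins for the buffer-based original.

#include <deque>
#include <iostream>
#include <string>
#include <vector>

class JaggedQueues {
 public:
  explicit JaggedQueues(int n) : queues_(n), finished_(n, false), pop_from_(0) {}

  void Push(int producer, const std::string &item) { queues_[producer].push_back(item); }

  bool Pop(std::string *out) {
    if (finished_[pop_from_]) return false;        // popping a finished queue is an error
    if (queues_[pop_from_].empty()) return false;  // toy stand-in for blocking
    *out = queues_[pop_from_].front();
    queues_[pop_from_].pop_front();
    if (*out == "eoe") finished_[pop_from_] = true;
    // Advance to the next unfinished queue, wrapping around.
    for (size_t off = 1; off <= queues_.size(); ++off) {
      size_t next = (pop_from_ + off) % queues_.size();
      if (!finished_[next]) {
        pop_from_ = next;
        break;
      }
    }
    return true;
  }

 private:
  std::vector<std::deque<std::string>> queues_;
  std::vector<bool> finished_;
  size_t pop_from_;
};

int main() {
  JaggedQueues q(2);
  q.Push(0, "a0"); q.Push(0, "eoe");
  q.Push(1, "b0"); q.Push(1, "b1"); q.Push(1, "eoe");
  std::string item;
  while (q.Pop(&item)) std::cout << item << ' ';  // a0 b0 eoe b1 eoe
  std::cout << '\n';
}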
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/optional/tensor_op_fusion_pass.h" -#include "dataset/kernels/image/decode_op.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/kernels/image/random_crop_decode_resize_op.h" - -namespace mindspore { -namespace dataset { - -Status TensorOpFusionPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp - // Abstract into a more general member function that can find any pattern, expressed - // by regular expressions, for instance. - // Add a list of optimisation policies. For now, just this lambda - auto FindPattern = [](auto &tfuncs) { - auto it = - std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; }); - if (it == tfuncs.end()) { - return it; // no DecodeOp found; do not form it + 1 past the end iterator - } - auto next = it + 1; - if (next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) { - return it; - } - return tfuncs.end(); - }; - - if (modified == nullptr) { - RETURN_STATUS_UNEXPECTED("modified is nullptr"); - } - auto &tfuncs = node->TFuncs(); - auto it = FindPattern(tfuncs); - if (it != tfuncs.end()) { - auto next = it + 1; - auto op = static_cast(next->get()); - *it = std::static_pointer_cast(std::make_shared(*op)); - tfuncs.erase(next); - *modified = true; // report a change only when a fusion actually happened - } else { - *modified = false; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.h b/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.h deleted file mode 100644 index e7fa4f076b..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/optional/tensor_op_fusion_pass.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
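The fusion pass above rewrites a MapOp's tensor-function list when a DecodeOp is immediately followed by a RandomCropAndResizeOp, substituting the fused RandomCropDecodeResizeOp. Here is the same scan as a sketch over plain op names; note that it forms it + 1 only after checking it against end(), and it reports a modification only when a fusion actually happened.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

bool FuseDecodeRandomCropResize(std::vector<std::string> *ops) {
  auto it = std::find(ops->begin(), ops->end(), "Decode");
  if (it == ops->end()) return false;  // check before forming it + 1
  auto next = it + 1;
  if (next == ops->end() || *next != "RandomCropAndResize") return false;
  *it = "RandomCropDecodeResize";  // the fused op replaces the first...
  ops->erase(next);                // ...and the second is dropped
  return true;
}

int main() {
  std::vector<std::string> ops = {"Decode", "RandomCropAndResize", "Normalize"};
  bool modified = FuseDecodeRandomCropResize(&ops);
  std::cout << "modified=" << modified << ':';
  for (const auto &op : ops) std::cout << ' ' << op;
  std::cout << '\n';  // modified=1: RandomCropDecodeResize Normalize
}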
- */ -#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_ -#define DATASET_TENSOR_OP_FUSION_PASS_H_ - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -/// \class TensorOpFusionPass tensor_op_fusion_pass.h -/// \brief An optional optimization pass identifying and fusing -/// tensor ops within MapOp -class TensorOpFusionPass : public NodePass { - /// \brief Identifies and fuses tensor ops within MapOp - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TENSOR_OP_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.cc b/mindspore/ccsrc/dataset/engine/opt/pass.cc deleted file mode 100644 index 17689224ea..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pass.cc +++ /dev/null @@ -1,248 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/datasetops/batch_op.h" -#include "dataset/engine/datasetops/cache_op.h" -#include "dataset/engine/datasetops/cache_merge_op.h" -#include "dataset/engine/datasetops/cache_lookup_op.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/datasetops/device_queue_op.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/project_op.h" -#include "dataset/engine/datasetops/rename_op.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/skip_op.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#ifdef ENABLE_PYTHON -#include "dataset/engine/datasetops/filter_op.h" -#include "dataset/engine/datasetops/source/generator_op.h" -#endif -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/take_op.h" -#include "dataset/engine/datasetops/zip_op.h" - -namespace mindspore { -namespace dataset { - -// Driver method for TreePass -Status TreePass::Run(ExecutionTree *tree, bool *modified) { - if (tree == nullptr || modified == nullptr) { - return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); - } - return this->RunOnTree(tree, modified); -} - -// Driver method for NodePass -Status NodePass::Run(ExecutionTree *tree, bool *modified) { - if (tree == nullptr || modified ==
nullptr) { - return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); - } - std::shared_ptr root = tree->root(); - if (traversalOrder_ == Order::DFS) { - // DFS - return DFSNodeVisit(root, modified); - } else if (traversalOrder_ == Order::BFS) { - // BFS - return BFSNodeVisit(root, modified); - } - return Status::OK(); -} - -// Helper function to perform DFS visit -Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { - RETURN_IF_NOT_OK(node->PreAccept(this, modified)); - for (const auto &c : node->Children()) { - RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); - } - return node->Accept(this, modified); -} - -// Helper function to perform BFS visit -Status NodePass::BFSNodeVisit(std::shared_ptr root, bool *modified) { - // Initialize bfs queue with root - std::queue> bfsQueue; - bfsQueue.push(root); - - // BFS loop - while (!bfsQueue.empty()) { - // Pop the front of the bfs queue - auto curNode = bfsQueue.front(); - bfsQueue.pop(); - - // Run node pass - RETURN_IF_NOT_OK(curNode->Accept(this, modified)); - - // Push children into bfs queue - for (const auto &c : curNode->Children()) { - bfsQueue.push(c); - } - } - return Status::OK(); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -#ifdef ENABLE_PYTHON -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} -#endif - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return 
RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return RunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return PreRunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return PreRunOnNode(std::static_pointer_cast(node), modified); -} - -Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Fallback to base class visitor by default - return PreRunOnNode(std::static_pointer_cast(node), modified); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pass.h b/mindspore/ccsrc/dataset/engine/opt/pass.h deleted file mode 100644 index 8489faa23a..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pass.h +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
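NodePass above combines two mechanisms: a DFS/BFS driver over the execution tree, and per-type RunOnNode overloads that degrade to the base-class handler via static_pointer_cast whenever a derived op has no specialised visitor. A compact sketch of both under assumed names (TreeNode, MapNode, Visitor); the real DFS also makes a PreAccept pre-order call, which is omitted here for brevity.

#include <iostream>
#include <memory>
#include <queue>
#include <vector>

struct MapNode;   // one concrete node type is enough to show the fallback
struct Visitor;

struct TreeNode : std::enable_shared_from_this<TreeNode> {
  std::vector<std::shared_ptr<TreeNode>> children;
  virtual ~TreeNode() = default;
  virtual void Accept(Visitor *v);
};

struct Visitor {
  virtual ~Visitor() = default;
  // Base handler: specialised overloads fall back here by up-casting,
  // mirroring NodePass::RunOnNode on the DatasetOp base.
  virtual void Visit(const std::shared_ptr<TreeNode> &n) { std::cout << "generic\n"; }
  virtual void Visit(const std::shared_ptr<MapNode> &n);

  void DFS(const std::shared_ptr<TreeNode> &n) {
    for (auto &c : n->children) DFS(c);  // children first,
    n->Accept(this);                     // then the node itself (post-order)
  }
  void BFS(const std::shared_ptr<TreeNode> &root) {
    std::queue<std::shared_ptr<TreeNode>> q;
    q.push(root);
    while (!q.empty()) {
      auto cur = q.front();
      q.pop();
      cur->Accept(this);
      for (auto &c : cur->children) q.push(c);
    }
  }
};

struct MapNode : TreeNode {
  // Double dispatch: each concrete type hands the visitor its own static type.
  void Accept(Visitor *v) override {
    v->Visit(std::static_pointer_cast<MapNode>(shared_from_this()));
  }
};

void TreeNode::Accept(Visitor *v) { v->Visit(shared_from_this()); }
void Visitor::Visit(const std::shared_ptr<MapNode> &n) {
  Visit(std::static_pointer_cast<TreeNode>(n));  // default: degrade to the base handler
}

int main() {
  auto root = std::make_shared<TreeNode>();
  root->children.push_back(std::make_shared<MapNode>());
  Visitor v;
  v.DFS(root);  // child then root, both degrade to "generic"
  v.BFS(root);  // root then child
}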
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_H_ - -#include -#include - -#include "dataset/engine/execution_tree.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BatchOp; - -class MapOp; - -class ProjectOp; - -class RenameOp; - -class SkipOp; - -class ShuffleOp; - -class MindRecordOp; - -class TFReaderOp; - -#ifdef ENABLE_PYTHON -class FilterOp; - -class GeneratorOp; -#endif - -class RandomDataOp; - -class RepeatOp; - -class TakeOp; - -class ZipOp; - -class DeviceQueueOp; - -class ImageFolderOp; - -class CacheOp; - -class MnistOp; - -class ManifestOp; - -class CifarOp; - -class VOCOp; - -class CocoOp; - -class CelebAOp; - -class CacheMergeOp; - -class CacheLookupOp; - -// The base class Pass is the basic unit of tree transformation. -// The actual implementation of the passes will be derived from here. -class Pass : public std::enable_shared_from_this { - public: - // Run the transformation pass against the execution tree. - // @param tree - Pointer to the execution tree to be transformed. - // @param modified - Pointer to the modified flag, - virtual Status Run(ExecutionTree *tree, bool *modified) = 0; -}; - -// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. -class TreePass : public Pass { - public: - /// \brief Run the transformation pass against the execution tree. - /// \param[inout] tree Pointer to the execution tree to be transformed. - /// \param[inout] modified Indicate if the tree was modified - Status Run(ExecutionTree *tree, bool *modified) final; - - /// \brief Derived classes may implement the runOnTree function to implement tree transformation. - /// "modified" flag needs to be set to true if tree is modified during the pass execution. - /// \param[inout] tree The tree to operate on. - /// \param[inout] Indicate of the tree was modified. - /// \return Status The error code return - virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } -}; - -// NodePass is a basic Pass class which performs transformation on Node visiting. -// NodePass implements Visitor design pattern. -class NodePass : public Pass { - public: - // Tree traversal order - enum Order { DFS, BFS }; - - // Constructor - // Default DFS traversal - explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } - - ~NodePass() = default; - - /// \brief Run the transformation pass against the execution tree - /// \param[inout] tree Pointer to the execution tree to be transformed - /// \param[inout] modified Indicator if the tree was changed - Status Run(ExecutionTree *tree, bool *modified) final; - - /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down - /// a tree traversal. 
"modified" flag needs to be set to true if tree is modified during the pass execution - /// \param[in] node The node being visited - /// \param[out] modified Indicator if the node was changed at all - /// \return Status The error code return - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } - - /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation - /// "modified" flag needs to be set to true if tree is modified during the pass execution - /// \param[in] node The node being visited - /// \param[out] modified Indicator if the node was changed at all. - /// \return Status The error code return - virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } - - // Visit methods to be overridden. - // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode - // of its own type and override "Accept" from DatasetOp. - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - -#ifdef ENABLE_PYTHON - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); -#endif - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status RunOnNode(std::shared_ptr node, bool *modified); - - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - - virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); - - private: - // Helper function to perform DFS visit - Status DFSNodeVisit(std::shared_ptr node, bool *modified); - - // Helper function to perform BFS visit - Status BFSNodeVisit(std::shared_ptr root, bool *modified); - - // Tree traversal order of the NodePass - Order traversalOrder_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc deleted file mode 100644 index 9f7a561aa6..0000000000 --- 
a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/post/repeat_pass.h" -#include "dataset/engine/datasetops/repeat_op.h" -#include "dataset/engine/datasetops/cache_op.h" -#include "dataset/engine/datasetops/cache_lookup_op.h" -#include "dataset/engine/datasetops/cache_merge_op.h" - -namespace mindspore { -namespace dataset { - -RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {} - -// Identifies the subtree below this node as being in a repeated path of the tree. -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // If we are already repeated, then this is a nested repeat. - if (is_repeated_) { - nested_repeats_++; - } - is_repeated_ = true; - return Status::OK(); -} - -// Identifies the subtree below this node as being in a cache merge path -Status RepeatPass::PreRunOnNode(std::shared_ptr node, bool *modified) { - // Turn on the flag that we're under a merge op - is_merge_ = true; - return Status::OK(); -} - -// Hooks up any identified eoe nodes under this repeat. -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking - std::shared_ptr leaf_op = PopFromEOEOpStack(); - while (leaf_op != nullptr) { - node->AddToEoeList(leaf_op); - leaf_op = PopFromEOEOpStack(); - } - - // We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up - // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area. - if (is_merge_ && cache_lookup_) { - cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated); - node->AddToEoeList(std::move(cache_lookup_)); - } - - // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us. - // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree. - if (nested_repeats_ > 0) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - AddToEOEOpStack(node); - nested_repeats_--; - } - - // If we are not nested, or we were the top-most repeat, now we clear the flag - if (nested_repeats_ == 0) { - is_repeated_ = false; - } - - return Status::OK(); -} - -// CacheOp removes previous leaf ops and replaces them with itself -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - // if we are a cache within a repeat path of the tree, then there will be - // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the - // repeat or epoch ctrl operators can work with them for repeat activity during runtime. 
- // However, since a cache is present: - // - unflag those ops as being repeated ops - // - remove them from the eoe op stack so that repeat op above in the tree won't know about them - // - add ourself (the cache op), as an eoe op - // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead - // the repeating behaviours shall be invoked against the cache op. - std::shared_ptr leaf_op = PopFromEOEOpStack(); - while (leaf_op != nullptr) { - leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); - leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); - leaf_op = PopFromEOEOpStack(); - } - AddToEOEOpStack(std::static_pointer_cast(node)); - } - - return Status::OK(); -} - -// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up -// for use with a controlling repeat above it. -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // If we are in a repeat path, then set our repeated flag - if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - - // if we are a leaf node then save ourself in a stack for the repeat operator above us - if (node->IsLeaf()) { - AddToEOEOpStack(node); - } - } - return Status::OK(); -} - -// Turns off the tracking for operations under merge op -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - // Setting the flag is needed since we didn't call the base class DatasetOp version - if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); - is_merge_ = false; - cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed - return Status::OK(); -} - -// Saves the lookup up in case it needs to be referenced by a repeat -Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { - if (!node->IsLeaf()) { - // By definition, the CacheLookup must be a leaf op. Make that clear here. - RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); - } - - // If we are in a repeat path already, then there must be a repeat above the merge op - // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. - if (is_repeated_) { - node->set_control_flag(DatasetOp::kDeOpRepeated); - AddToEOEOpStack(node); - } else { - // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we - // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself - // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. 
- cache_lookup_ = std::static_pointer_cast(node); - } - return Status::OK(); -} - -// Adds an operator to the eoe operator stack save area -void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } - -// Pops an operator from the eoe operator stack save area -std::shared_ptr RepeatPass::PopFromEOEOpStack() { - std::shared_ptr top_op = nullptr; - if (!eoe_stack_.empty()) { - top_op = eoe_stack_.top(); - eoe_stack_.pop(); - } - return top_op; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h deleted file mode 100644 index 3f5f347a30..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ -#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ - -#include -#include -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -/// \class RepeatPass repeat_pass.h -/// \brief This is a NodePass whose job is to perform setup actions for RepeatOps. A RepeatOp needs to have references -/// to the eoe-producing (typically leaf) nodes underneath it. -class RepeatPass : public NodePass { - public: - /// \brief Constructor - RepeatPass(); - - /// \brief Identifies the subtree below this node as being in a repeated path of the tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Identifies the subtree below this node as being in a cache merge path - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Hooks up any identified eoe nodes under this repeat. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief CacheOp removes previous leaf ops and replaces them with itself - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Turns off the tracking for operations under merge op - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Saves the lookup op in case it needs to be referenced by a repeat - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up - /// for use with a controlling repeat above it. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - /// \brief Adds an operator to the eoe operator stack save area - /// \param op - The dataset op to add to the eoe stack - void AddToEOEOpStack(std::shared_ptr dataset_op); - - /// \brief Pops an operator from the eoe operator stack save area - /// \return shared_ptr to the popped operator - std::shared_ptr PopFromEOEOpStack(); - - bool is_repeated_; // T/F if we are processing under a repeat - bool is_merge_; // T/F if we are processing under a cache merge op - int32_t nested_repeats_; // A counter for nested repeats - std::stack> eoe_stack_; // A save area for leaf/eoe ops - std::shared_ptr cache_lookup_; // A save area for a cache lookup op -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc deleted file mode 100644 index ae0f4d3a04..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
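The RepeatPass above communicates between visits through a save-area stack: leaf ops visited under a repeat push themselves, the repeat pops them into its eoe list, and a nested repeat then pushes itself so the repeat above treats it as a single leaf (a cache likewise drains the stack and substitutes itself as the sole eoe producer). A toy model of that stack discipline, with strings standing in for ops:

#include <iostream>
#include <stack>
#include <string>
#include <vector>

int main() {
  std::stack<std::string> eoe_stack;

  // Visiting a leaf under an (inner) repeat pushes it onto the save area.
  eoe_stack.push("TFReaderOp");

  // The inner repeat collects its eoe ops...
  std::vector<std::string> inner_eoe;
  while (!eoe_stack.empty()) {
    inner_eoe.push_back(eoe_stack.top());
    eoe_stack.pop();
  }
  // ...and, being nested, acts as a leaf for the repeat above it.
  eoe_stack.push("RepeatOp(inner)");

  // The outer repeat now sees exactly one eoe op: the inner repeat.
  std::vector<std::string> outer_eoe;
  while (!eoe_stack.empty()) {
    outer_eoe.push_back(eoe_stack.top());
    eoe_stack.pop();
  }

  std::cout << "inner repeat tracks:";
  for (auto &op : inner_eoe) std::cout << ' ' << op;
  std::cout << "\nouter repeat tracks:";
  for (auto &op : outer_eoe) std::cout << ' ' << op;
  std::cout << '\n';
}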
- */ - -#include -#include "dataset/engine/opt/pre/cache_pass.h" -#include "dataset/engine/opt/pre/cache_transform_pass.h" -#include "dataset/engine/datasetops/cache_op.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/generator_op.h" -#include "dataset/engine/datasetops/source/manifest_op.h" -#include "dataset/engine/datasetops/source/mnist_op.h" -#include "dataset/engine/datasetops/source/voc_op.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/mindrecord_op.h" - -namespace mindspore { -namespace dataset { - -// Constructor -CachePass::CachePass(CacheTransformPass *transform_pass) - : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {} - -// Identifies the subtree below this node as a cached descendant tree. -Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree."; - if (is_caching_) { - RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!"); - } - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache -// transformation -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - is_caching_ = false; // We a no longer in a cache subtree. clear the flag. - if (leaf_op_) { - MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache."; - // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op, - // using base class pointers. - transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node); - } else { - // If there was no leaf_op set, then this is a non-mappable scenario. - - if (sampler_) { - // Grab the sampler that was saved from the leaf and plug it into the cache op - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf."; - } else { - // We're a cache op but no sampler was saved from leaf, so create a default sampler - int64_t num_samples = 0; - int64_t start_index = 0; - sampler_ = std::make_shared(num_samples, start_index); - node->SetSampler(std::move(sampler_)); - MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op."; - } - - // Get the computed check sum from all ops in our cache path below us and ask the cache op to create it's cache - uint32_t cache_crc = DatasetOp::GenerateCRC(node); - RETURN_IF_NOT_OK(node->CreateCache(cache_crc)); - } - - return Status::OK(); -} - -// Common code for mappable leaf setup. -Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // If we are a leaf in the caching path, then save this leaf. 
- if (is_caching_) { - MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected"; - leaf_op_ = std::move(leaf_op); - } - return Status::OK(); -} - -// Common code for non mappable leaf setup. -Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) { - // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree. - if (is_caching_ && leaf_op_) { - RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache."); - } - - // Sampler for non mapable dataset only works if there is a downstream cache. Remove it from the leaf - // as save it for use by cache op in ascendant tree. - if (is_caching_) { - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_)); - MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected"; - } else { - // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can - // remove it here. The leaf itself will provide it's own methods of fetching the data (not sampler-based) - std::shared_ptr sampler_from_leaf; - RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf)); - } - return Status::OK(); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - if (is_caching_) { - // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic - // TF reader that parses all files. Selection of data will come from the sampler on the cache instead. - node->MakeSimpleProducer(); - } - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return NonMappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return MappableCacheLeafSetup(std::static_pointer_cast(node)); -} - -// Perform leaf node cache tranform identifications -Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) { - return 
MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node)); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h deleted file mode 100644 index c842e54bbf..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ - -#include -#include -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class CacheTransformPass; - -/// \class CachePass cache_pass.h -/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache -/// transformation. It works in conjunction with the CacheTransformPass. -class CachePass : public NodePass { - public: - /// \brief Constructor - /// \param[in] transform_pass Raw pointer back to controlling tree pass - explicit CachePass(CacheTransformPass *transform_pass); - - /// \brief Identifies the subtree below this node as a cached descendant tree. - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache - /// transformation - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform
identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - /// \brief Perform leaf node cache transform identifications - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - private: - /// \brief Common code for mappable leaf setup. - /// \param[in] leaf_op The leaf node performing setup work. - /// \return Status The error code return - Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); - - /// \brief Common code for non-mappable leaf setup. - /// \param[in] leaf_op The leaf node performing setup work. - /// \return Status The error code return - Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op); - - bool is_caching_; - std::shared_ptr<DatasetOp> leaf_op_; - std::shared_ptr<Sampler> sampler_; - CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc deleted file mode 100644 index df4933fa1c..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/pre/cache_pass.h" -#include "dataset/engine/opt/pre/cache_transform_pass.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/datasetops/cache_lookup_op.h" -#include "dataset/engine/datasetops/cache_merge_op.h" -#include "dataset/engine/datasetops/cache_op.h" - -namespace mindspore { -namespace dataset { - -// constructor -CacheTransformPass::CacheTransformPass() {} - -// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations -Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) { - MS_LOG(INFO) << "Pre pass: Cache transform pass started."; - // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will - // use to execute a transform. - std::unique_ptr<CachePass> cache_pass = std::make_unique<CachePass>(this); - RETURN_IF_NOT_OK(cache_pass->Run(tree, modified)); - - // Then, execute the transform for each pair - for (auto cache_pair : cache_pairs_) { - MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform."; - ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()); - } - MS_LOG(INFO) << "Pre pass: Cache transform pass complete."; - return Status::OK(); -} - -// Helper function to execute the cache transformation. -Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op, - std::shared_ptr<CacheOp> cache_op, - std::shared_ptr<CacheClient> cache_client) { - // Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was - // the root node. It is also possible that cache_child == leaf_op - std::shared_ptr<DatasetOp> cache_child = cache_op->child(0); - DatasetOp *cache_parent = nullptr; - cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent - - // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later. - std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler(); - - // Construct the merge op with defaults - std::shared_ptr<CacheMergeOp> merge_op; - CacheMergeOp::Builder merge_builder; - RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op)); - RETURN_IF_NOT_OK(tree->AssociateNode(merge_op)); - - // Construct the cache lookup op with defaults - std::shared_ptr<CacheLookupOp> cache_lookup_op; - CacheLookupOp::Builder lookup_builder; - RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op)); - RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op)); - - // Overwrite the old sampler in this leaf op to become the lookup op - leaf_op->SetSampler(cache_lookup_op); - - // If the cache had a parent, then go into that parent to remove the cache from its child list and then - // replace it with the merge op. - if (cache_parent != nullptr) { - RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op)); - RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op)); - } else { - // If we didn't have a parent, then the merge op is the root node - RETURN_IF_NOT_OK(tree->AssignRoot(merge_op)); - } - - // Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op. - // We maintain a local pointer to the old child though.
- RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child)); - - // Connect the merge op - RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op))); - RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child))); - - // At this point, the cache op has already had its children and parents taken away. Calling remove - // on it at this point will not do any node hookups, and instead set internal fields to invalid. - RETURN_IF_NOT_OK(cache_op->Remove()); - - return Status::OK(); -} - -// Assigns the leaf and cache operators that are involved in a cache transformation -void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, - std::shared_ptr<CacheOp> cache_op) { - cache_pairs_.push_back(std::make_pair(leaf_op, cache_op)); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h deleted file mode 100644 index dc31d76d80..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ - -#include -#include -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class DatasetOp; - -class CacheClient; - -/// \class CacheTransformPass cache_transform_pass.h -/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching -/// operations. -class CacheTransformPass : public TreePass { - public: - /// \brief Constructor - CacheTransformPass(); - - /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations - /// \param[inout] tree The tree to operate on. - /// \param[inout] modified Indicates if the tree was modified. - /// \return Status The error code return - Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - /// \brief Assigns the leaf and cache operators that are involved in a cache transformation - /// \param[in] leaf_op The leaf operator involved in the cache transform - /// \param[in] cache_op The cache operator involved in the cache transform - void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op); - - private: - /// \brief Helper function to execute the cache transformation.
- /// - /// Input: - /// Sampler - /// | - /// LeafOp --> OtherOps --> CacheOp - /// - /// Transformed: - /// Sampler --> CacheLookupOp ----------------> - /// | | - /// | MergeOp - /// | | - /// LeafOp --> OtherOps --> - /// - /// \param[in] leaf_op The leaf node in the transform - /// \param[in] cache_op The cache op in the transform (will get removed) - /// \param[in] cache_client The cache client - /// \return Status The error code return - Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op, - std::shared_ptr cache_op, std::shared_ptr cache_client); - - // The two operators that work together to establish the cache transform - std::vector, std::shared_ptr>> cache_pairs_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc deleted file mode 100644 index e361015e48..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/engine/opt/pre/removal_nodes.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/datasetops/shuffle_op.h" - -namespace mindspore { -namespace dataset { - -RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {} - -// Identifies the subtree below this node as a cached descendant tree. -Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree."; - is_caching_ = true; - return Status::OK(); -} - -// Resets the tracking of the cache within the tree -Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) { - *modified = false; - MS_LOG(INFO) << "Removal pass: cache descendant tree complete."; - is_caching_ = false; - return Status::OK(); -} - -// Perform ShuffleOp removal check. 
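The Input/Transformed diagram above is easier to follow as a concrete rewiring sequence. A rough standalone sketch of that sequence (a hypothetical Node type, not the real DatasetOp/ExecutionTree API; the sampler hand-off to the lookup op is elided):

#include <algorithm>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
};

// Replace `cache` (a child of `parent`) with a CacheMergeOp whose two children
// are a new CacheLookupOp and the cache's former subtree, as in the diagram.
void RewireCache(const std::shared_ptr<Node> &parent, const std::shared_ptr<Node> &cache) {
  auto cache_child = cache->children.at(0);  // the LeafOp --> OtherOps subtree
  auto lookup = std::make_shared<Node>(Node{"CacheLookupOp", {}});
  auto merge = std::make_shared<Node>(Node{"CacheMergeOp", {lookup, cache_child}});
  std::replace(parent->children.begin(), parent->children.end(), cache, merge);
  cache->children.clear();  // fully disconnect the old cache op
}

int main() {
  auto leaf = std::make_shared<Node>(Node{"LeafOp", {}});
  auto cache = std::make_shared<Node>(Node{"CacheOp", {leaf}});
  auto root = std::make_shared<Node>(Node{"RootOp", {cache}});
  RewireCache(root, cache);  // root now parents CacheMergeOp -> {CacheLookupOp, LeafOp}
}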
-Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { - *modified = false; - // If we are in a cache descendant tree, then this shuffle op needs to be removed - if (is_caching_) { - MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)"; - if (removal_pass_) { - removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node)); - } else { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!"); - } - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h deleted file mode 100644 index be1aaea645..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ - -#include -#include "dataset/engine/opt/pass.h" -#include "dataset/engine/opt/pre/removal_pass.h" - -namespace mindspore { -namespace dataset { -/// \class RemovalNodes removal_nodes.h -/// \brief This is a NodePass whose job is to identify which nodes should be removed. -/// It works in conjunction with the removal_pass. -class RemovalNodes : public NodePass { - public: - /// \brief Constructor - /// \param[in] removal_pass Raw pointer back to controlling tree pass - explicit RemovalNodes(RemovalPass *removal_pass); - - /// \brief Identifies the subtree below this node as a cached descendant tree.
- /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; - - /// \brief Resets the tracking of the cache within the tree - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override; - - /// \brief Destructor - ~RemovalNodes() = default; - - /// \brief Perform ShuffleOp removal check - /// \param[in] node The node being visited - /// \param[inout] modified Indicator if the node was changed at all - /// \return Status The error code return - Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override; - - private: - bool is_caching_; - RemovalPass *removal_pass_; // Back pointer to the owning removal pass -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc deleted file mode 100644 index db5e37a085..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/opt/pre/removal_nodes.h" -#include "dataset/engine/opt/pre/removal_pass.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { - -// constructor -RemovalPass::RemovalPass() {} - -// Runs a removal_nodes pass first to find out which nodes to remove, then removes them. -Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) { - MS_LOG(INFO) << "Pre pass: removal pass started."; - // Create the removal node pass which can identify which nodes need to be removed. - std::unique_ptr<RemovalNodes> removal_nodes = std::make_unique<RemovalNodes>(this); - RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified)); - - // Then, execute the removal of any nodes that were set up for removal - for (auto node : removal_nodes_) { - node->Remove(); - } - MS_LOG(INFO) << "Pre pass: removal pass complete."; - return Status::OK(); -} - -// Adds an operator to the list of operators to be removed -void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h deleted file mode 100644 index 6c1963b826..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
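The removal machinery above is a two-phase pattern: RemovalNodes only records candidates while the tree is being walked, and RemovalPass mutates the tree afterwards, so the traversal never invalidates itself. A minimal sketch of that shape (toy Op type, not the engine's API):

#include <algorithm>
#include <memory>
#include <vector>

struct Op {
  bool removable = false;
  std::vector<std::shared_ptr<Op>> children;
};

// Phase 1: walk the tree and record candidates; nothing is modified here.
void Collect(const std::shared_ptr<Op> &op, std::vector<std::shared_ptr<Op>> &out) {
  if (op->removable) out.push_back(op);
  for (const auto &c : op->children) Collect(c, out);
}

// Phase 2: splice one victim out of its parent, promoting the victim's
// children. This roughly mirrors what DatasetOp::Remove would have to do.
void Splice(const std::shared_ptr<Op> &parent, const std::shared_ptr<Op> &victim) {
  auto &kids = parent->children;
  auto it = std::find(kids.begin(), kids.end(), victim);
  if (it == kids.end()) return;
  it = kids.erase(it);
  kids.insert(it, victim->children.begin(), victim->children.end());
}

int main() {
  auto shuffle = std::make_shared<Op>();
  shuffle->removable = true;  // e.g. a ShuffleOp sitting under a cache
  auto root = std::make_shared<Op>();
  root->children.push_back(shuffle);
  std::vector<std::shared_ptr<Op>> victims;
  Collect(root, victims);                         // phase 1: identify
  for (const auto &v : victims) Splice(root, v);  // phase 2: remove
}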
- * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ -#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ - -#include -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class DatasetOp; - -/// \class RemovalPass removal_pass.h -/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which -/// nodes should be removed, and then removes them. -class RemovalPass : public TreePass { - public: - /// \brief Constructor - RemovalPass(); - - /// \brief Destructor - ~RemovalPass() = default; - - /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them. - /// \param[inout] tree The tree to operate on. - /// \param[inout] modified Indicates if the tree was modified. - /// \return Status The error code return - Status RunOnTree(ExecutionTree *tree, bool *modified) override; - - /// \brief Adds an operator to the list of operators to be removed - /// \param[in] dataset_op The operator to add to the removal list - void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op); - - private: - std::vector<std::shared_ptr<DatasetOp>> removal_nodes_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc deleted file mode 100644 index 305c3ce121..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include <iostream> -#include "dataset/engine/opt/util/printer_pass.h" - -namespace mindspore { -namespace dataset { - -Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting DatasetOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting BatchOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting MapOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting ProjectOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting RenameOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting SkipOp" << '\n'; - return Status::OK(); -} -Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting ShuffleOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting MindRecordOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting TFReaderOp" << '\n'; - return Status::OK(); -} - -#ifdef ENABLE_PYTHON -Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting FilterOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting GeneratorOp" << '\n'; - return Status::OK(); -} -#endif - -Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting TakeOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting ZipOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting DeviceQueueOp" << '\n'; - return Status::OK(); -} - -Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) { - *modified = false; - std::cout << "Visiting ImageFolderOp" << '\n'; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h deleted file mode 100644 index 2552476ebd..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/util/printer_pass.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
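All of these nearly identical overloads exist because NodePass selects behaviour by the static type of the visited operator. The engine's pass.h (not part of this hunk) presumably has each operator call back into the pass with its own concrete type; a toy model of that double dispatch:

#include <iostream>

class BatchOp;
class MapOp;

// Base pass: one overload per operator type, each defaulting to a no-op.
class NodePass {
 public:
  virtual ~NodePass() = default;
  virtual void RunOnNode(BatchOp *) {}
  virtual void RunOnNode(MapOp *) {}
};

// Each operator's Accept passes `this` with its concrete static type,
// so overload resolution picks the matching RunOnNode.
class BatchOp {
 public:
  void Accept(NodePass *pass) { pass->RunOnNode(this); }
};
class MapOp {
 public:
  void Accept(NodePass *pass) { pass->RunOnNode(this); }
};

class PrinterPassToy : public NodePass {
 public:
  void RunOnNode(BatchOp *) override { std::cout << "Visiting BatchOp\n"; }
  void RunOnNode(MapOp *) override { std::cout << "Visiting MapOp\n"; }
};

int main() {
  PrinterPassToy pass;
  BatchOp batch;
  MapOp map;
  batch.Accept(&pass);  // prints "Visiting BatchOp"
  map.Accept(&pass);    // prints "Visiting MapOp"
}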
- */ - -#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H -#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { - -class PrinterPass : public NodePass { - public: - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - -#ifdef ENABLE_PYTHON - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; -#endif - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; - - Status RunOnNode(std::shared_ptr node, bool *modified) override; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_size.cc b/mindspore/ccsrc/dataset/engine/perf/connector_size.cc deleted file mode 100644 index 0bd2754075..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_size.cc +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/engine/perf/connector_size.h" -#include -#include -#include -#include -#include "dataset/core/config_manager.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/path.h" - -using json = nlohmann::json; -namespace mindspore { -namespace dataset { -using Qrow = std::vector; - -// Sample action -Status ConnectorSize::Sample() { - Qrow cur_row; - std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row), - [](DatasetOp &op) { return op.ConnectorSize(); }); - // Push new row of sample - sample_table_.push_back(cur_row); - return Status::OK(); -} - -// JSON serializer helper function -json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector &size) { - auto children = node.Children(); - std::vector children_id; - std::transform(children.begin(), children.end(), std::back_inserter(children_id), - [](std::shared_ptr op) -> int32_t { return op->id(); }); - json json_node; - json_node["op_id"] = node.id(); - json_node["op_type"] = node.Name(); - json_node["num_workers"] = node.num_workers(); - json metrics; - // DeviceQueueOp is a special op,it is not inlined but its output queue is invalid. - // So we should not output its queue size. 
- if (!node.inlined() && node.Name() != "DeviceQueueOp") { - metrics["output_queue"] = {{"size", size}, {"length", node.ConnectorCapacity()}}; - } - json_node["metrics"] = metrics; - if (!children_id.empty()) { - json_node["children"] = children_id; - } - - return json_node; -} - -// Save profiling data to file -Status ConnectorSize::SaveToFile() { - std::ofstream os(file_path_, std::ios::trunc); - uint32_t idx = 0; - json output; - std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); - output["sampling_interval"] = cfg->monitor_sampling_interval(); - // Traverse the ExecutionTree for JSON node generation - for (auto &node : *tree_) { - std::vector<int32_t> cur_queue_size; - std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size), - [&](const ConnectorSizeSample &sample) { return sample[idx]; }); - json json_node = ParseOpInfo(node, cur_queue_size); - output["op_info"].push_back(json_node); - idx++; - } - os << output; - return Status::OK(); -} -Status ConnectorSize::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + device_id + ".json")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/dataset/engine/perf/connector_size.h deleted file mode 100644 index 2584289fb4..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_size.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_CONNECTOR_SIZE_H -#define DATASET_CONNECTOR_SIZE_H - -#include -#include -#include -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/datasetops/dataset_op.h" - -using json = nlohmann::json; - -namespace mindspore { -namespace dataset { -class ExecutionTree; - -// Connector size sampling samples the output connector size of each op in the pipeline. -// It supports JSON serialization for external usage. -class ConnectorSize : public Sampling { - // Connector size sampling data is stored as a 2D vector - // op_0 ... op_m - // sample_0 size_0_0 ... size_m_0 - // ... ... ... ... - // sample_n size_0_n ... size_m_n - // - // A circular buffer will be implemented in the future to make this table more flexible. - using ConnectorSizeSample = std::vector<int32_t>; - using ConnectorSizeSampleTable = std::vector<ConnectorSizeSample>; - - public: - explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {} - - ~ConnectorSize() override = default; - - // Driver function for connector size sampling.
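For reference, one entry of the "op_info" array written by SaveToFile has roughly the shape that ParseOpInfo assembles above. A sketch reproducing it (field names taken from the code; the values are invented sample data, and the single-header nlohmann/json library is assumed to be available):

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json json_node;
  json_node["op_id"] = 2;
  json_node["op_type"] = "BatchOp";
  json_node["num_workers"] = 4;
  // One "size" entry per Monitor sample; "length" is the connector capacity.
  json_node["metrics"]["output_queue"] = {{"size", {3, 5, 4}}, {"length", 16}};
  json_node["children"] = {1};  // op_ids of the child operators
  std::cout << json_node.dump(2) << '\n';
}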
- // This function samples the connector size of every node within the ExecutionTree - Status Sample() override; - - std::string Name() const override { return kConnectorSizeSamplingName; } - - // Save sampling data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - // Parse op information and transform to json format - json ParseOpInfo(const DatasetOp &node, const std::vector<int32_t> &size); - - private: - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - ConnectorSizeSampleTable sample_table_; // Dataset structure to store all samples of connector size sampling -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CONNECTOR_SIZE_H diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc b/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc deleted file mode 100644 index 4fd59de390..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.cc +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include "dataset/engine/perf/connector_throughput.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -// temporary helper -int ConnectorThroughput::InitNodes() { - auto it = (*tree_).begin(); - return it.NumNodes(); -} -// Sample action -Status ConnectorThroughput::Sample() { - std::vector<int64_t> out_buffer_count_row(n_nodes_); - std::vector<double> throughput_row(n_nodes_); - TimePoint cur_time; // initialised inside the loop, used outside the loop to update prev sample time.
- auto col = 0; - for (const auto &node : *tree_) { - auto cur_out_buffer_count = node.ConnectorOutBufferCount(); - out_buffer_count_row[col] = cur_out_buffer_count; - auto sz = timestamps_.size(); - cur_time = std::chrono::steady_clock::now(); - auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]); - auto dt = std::chrono::duration(_dt).count(); - auto prev_out_buffer_count = out_buffer_count_table_[col][out_buffer_count_table_.size() - 1]; - if (dt != 0) { - auto thr = (cur_out_buffer_count - prev_out_buffer_count) / (1000 * dt); - throughput_row[col] = thr; - } else { - throughput_row[col] = -1; - } - col++; - } - std::vector v = {cur_time}; // temporary fix - timestamps_.AddSample(v); - // Push new row of sample - out_buffer_count_table_.AddSample(out_buffer_count_row); - throughput_.AddSample(throughput_row); - return Status::OK(); -} - -json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector &thr) { - auto children = node.Children(); - std::vector children_id; - std::transform(children.begin(), children.end(), std::back_inserter(children_id), - [](std::shared_ptr op) -> int32_t { return op->id(); }); - json json_node; - json_node["op_id"] = node.id(); - json_node["op_type"] = node.Name(); - json_node["num_workers"] = node.num_workers(); - json metrics; - metrics["output_queue"] = {{"throughput", thr}}; - - json_node["metrics"] = metrics; - if (!children_id.empty()) { - json_node["children"] = children_id; - } - - return json_node; -} - -// Save profiling data to file -Status ConnectorThroughput::SaveToFile() { - std::ofstream os(file_path_); - json output; - output["sampling_interval"] = 10; - // Traverse the ExecutionTree for JSON node generation - int col = 0; - for (auto &node : *tree_) { - std::vector throughput; - for (auto i = 0; i < throughput_.size(); i++) { - throughput.push_back(throughput_[col][i]); - } - json json_node = ParseOpInfo(node, throughput); - output["op_info"].push_back(json_node); - col++; - } - os << output; - return Status::OK(); -} -Status ConnectorThroughput::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + Name() + "_" + device_id + ".json")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h b/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h deleted file mode 100644 index 4dbb4cdad7..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/connector_throughput.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
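Stripped of bookkeeping, the per-op figure computed in Sample() above is the change in the out-buffer count divided by the time between the two most recent samples. A self-contained std::chrono sketch (reporting buffers per second is an assumption here; the deleted file's exact scaling is not recoverable from this diff):

#include <chrono>
#include <iostream>
#include <thread>

int main() {
  using Clock = std::chrono::steady_clock;
  const auto t_prev = Clock::now();
  const long long prev_count = 100;  // out-buffer count at the previous sample
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  const auto t_cur = Clock::now();
  const long long cur_count = 180;   // out-buffer count at this sample
  // duration<double> converts the tick difference into fractional seconds.
  const double dt = std::chrono::duration<double>(t_cur - t_prev).count();
  const double throughput = dt > 0 ? static_cast<double>(cur_count - prev_count) / dt : -1.0;
  std::cout << throughput << " buffers/s\n";  // the -1 sentinel mirrors the code above
}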
- */ - -#ifndef DATASET_CONNECTOR_THROUGHPUT_H -#define DATASET_CONNECTOR_THROUGHPUT_H - -#include -#include -#include -#include -#include -#include "dataset/engine/perf/profiling.h" -#include "dataset/engine/perf/perf_data.h" -#include "dataset/engine/perf/cyclic_array.h" -#include "dataset/engine/datasetops/dataset_op.h" -#include "dataset/engine/execution_tree.h" - -using json = nlohmann::json; -namespace mindspore { -namespace dataset { -// Connector throughput samples the output connector size of each op in the pipeline. -// For the description of the data structure see perf_buffer.h -// It support JSON serialization for external usage. -class ConnectorThroughput : public Sampling { - using OutBufferCount = PerfData>; - using Throughput = PerfData>; - using TimePoint = std::chrono::time_point; - using TimeStamps = PerfData>; - - public: - explicit ConnectorThroughput(ExecutionTree *tree, int64_t max_rows = 1000000) - : tree_(tree), - max_rows_(max_rows), - n_nodes_(InitNodes()), - out_buffer_count_table_(OutBufferCount(max_rows_, n_nodes_)), - throughput_(Throughput(max_rows_, n_nodes_)), - timestamps_(TimeStamps(max_rows_, 1)) { - timestamps_.AddSample(std::vector(1)); - out_buffer_count_table_.AddSample(std::vector(n_nodes_)); - } - - /// \brief Destructor - ~ConnectorThroughput() = default; - - // Driver function for connector size sampling. - // This function samples the connector size of every nodes within the ExecutionTree - Status Sample() override; - - /* Status TestPrint() override { - std::ofstream os("performance_monitor.txt"); - if (throughput_.size() == 0) { - os << "data is empty" << std::endl; - return Status::OK(); - } - for (int i = 0; i < throughput_.size(); i++) { - for (int j = 0; j < n_nodes_; j++) { - os << throughput_[j][i] << " "; - } - os << std::endl; - } - return Status::OK(); - };*/ - - // Traverse the tree nodes and count them - int InitNodes(); - - std::string Name() const override { return name_; }; - - // Save sampling data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id); - - json ParseOpInfo(const DatasetOp &node, const std::vector &thr); - - private: - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - int64_t max_rows_; - int32_t n_nodes_; - OutBufferCount out_buffer_count_table_; - Throughput throughput_; - TimeStamps timestamps_; - std::string name_ = kConnectorThroughputSamplingName; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_CONNECTOR_THROUGHPUT_H diff --git a/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h b/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h deleted file mode 100644 index fa60b401c5..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/cyclic_array.h +++ /dev/null @@ -1,197 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef DATASET_CYCLIC_ARRAY_H -#define DATASET_CYCLIC_ARRAY_H - -#include -#include -#include -#include -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { - -/// \class CyclicArray cyclic_array.h -/// \brief This is a container with a contiguous memory layout that only keeps the last N entries -/// when the number of entries exceeds the capacity. -/// It must be preallocated. -template <typename T> -class CyclicArray { - public: - using value_type = T; - class Iterator { - // Add operator[] and make fully compliant with random access iterator - // and add a const iterator - // add resize(), empty() - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = CyclicArray::value_type; - using difference_type = std::ptrdiff_t; - using pointer = CyclicArray::value_type *; - using reference = CyclicArray::value_type &; - - Iterator() = default; - - Iterator(dsize_t idx, pointer ptr, dsize_t capacity, dsize_t head) - : cur_idx_(idx), ptr_(ptr), capacity_(capacity), head_(head) {} - - Iterator(const Iterator &rhs) = default; - - ~Iterator() = default; - - Iterator &operator++() { - cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); - return *this; - } - - Iterator operator++(int) { - Iterator tmp(*this); - cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1); - return tmp; - } - - Iterator &operator--() { - cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); - return *this; - } - - Iterator operator--(int) { - Iterator tmp(*this); - cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1); - return tmp; - } - - Iterator operator+(dsize_t x) { return Iterator((cur_idx_ + x) % (capacity_ + 1), ptr_, capacity_, head_); } - - Iterator operator-(dsize_t x) { - return Iterator((cur_idx_ + (capacity_ + 1 - x)) % (capacity_ + 1), ptr_, capacity_, head_); - } - - bool operator<(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) < (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator>(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) > (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator>=(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) >= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - bool operator<=(const Iterator &rhs) { - return (head_ + cur_idx_) % (capacity_ + 1) <= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1); - } - - difference_type operator-(const Iterator &rhs) { - return (cur_idx_ - rhs.cur_idx_ + capacity_ + 1) % (capacity_ + 1); - } - - reference operator*() { return ptr_[cur_idx_]; } - - pointer operator->() { return &(ptr_[cur_idx_]); } - - bool operator==(const Iterator &rhs) { return cur_idx_ == rhs.cur_idx_; } - - bool operator!=(const Iterator &rhs) { return cur_idx_ != rhs.cur_idx_; } - - private: - dsize_t cur_idx_; - pointer ptr_; - dsize_t capacity_; - dsize_t head_; - }; - - /// \brief Default constructor - CyclicArray() : buf_(nullptr), head_(0), tail_(0), size_(0), capacity_(0) {} - - /// \brief Constructor - /// \param[in] capacity - explicit CyclicArray(dsize_t capacity) - : buf_(std::make_unique<T[]>(capacity + 1)), head_(0), tail_(0), size_(0), capacity_(capacity) {} - - CyclicArray(const CyclicArray &rhs) - : buf_(std::make_unique<T[]>(rhs.capacity_ + 1)), - head_(rhs.head_), - tail_(rhs.tail_), - size_(rhs.size_), - capacity_(rhs.capacity_) { - std::copy(rhs.begin(), rhs.end(), begin()); - } - - CyclicArray(CyclicArray &&rhs) = default; - - ~CyclicArray() = default; - - /// \brief Iterator begin() - Iterator begin() { return Iterator(head_, buf_.get(),
capacity_, head_); } - - /// \brief Iterator end() - Iterator end() { return Iterator(tail_, buf_.get(), capacity_, head_); } - - // not really const. - Iterator begin() const { return Iterator(head_, buf_.get(), capacity_, head_); } - - Iterator end() const { return Iterator(tail_, buf_.get(), capacity_, head_); } - - /// \brief clear the array. Does not deallocate memory, capacity remains the same - void clear() { - head_ = 0; - tail_ = 0; - size_ = 0; - } - - /// \brief returns current size - dsize_t size() { return size_; } - - /// \brief returns capacity - dsize_t capacity() { return capacity_; } - - /// \brief pushes a value - /// \param[in] val value - void push_back(T val) { - buf_[tail_] = val; - if (size_ >= capacity_) { - (tail_ != capacity_) ? tail_++ : tail_ = 0; - (head_ != capacity_) ? head_++ : head_ = 0; - } else { - tail_++; - size_++; - } - } - - /// \brief returns const reference to an element of the array - /// \param[in] idx index of the element - /// \param[out] const T& reference to an element of the array - const T &operator[](dsize_t idx) const { return buf_[(head_ + idx) % (capacity_ + 1)]; } - - /// \brief returns non-const reference to an element of the array - /// \param[in] idx index of the element - /// \param[out] T& reference to an element of the array - T &operator[](dsize_t idx) { return buf_[(head_ + idx) % (capacity_ + 1)]; } - - private: - std::unique_ptr<T[]> buf_; - dsize_t head_; - dsize_t tail_; - dsize_t size_; - dsize_t capacity_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_CYCLIC_ARRAY_H diff --git a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc b/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc deleted file mode 100644 index 99b0c2d7e0..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "dataset/engine/perf/dataset_iterator_tracing.h" -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -Status DatasetIteratorTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, - const int32_t value) { - // Format: "type extra-info batch-num value" - // type: 0: time, 1: connector size - // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time - // if type is 1 - connector capacity - // batch-num: batch number - // value: if type is 0 - value is time(ms) - // if type is 1 - value is connector size - // Examples: - // 0 0 20 10 - The 20th batch took 10ms to get data from the pipeline. - // 1 64 20 5 - Connector size is 5 when fetching the 20th batch. Connector capacity is 64.
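Returning to the CyclicArray removed above: once size() reaches capacity(), each push_back overwrites the oldest entry and advances head_, so iteration always yields the most recent `capacity` values in insertion order. A usage sketch (it assumes the deleted cyclic_array.h is still on the include path):

#include <iostream>
#include "dataset/engine/perf/cyclic_array.h"  // the header deleted above

int main() {
  mindspore::dataset::CyclicArray<int> arr(3);
  for (int i = 1; i <= 5; ++i) {
    arr.push_back(i);  // pushing 4 and 5 evicts 1 and 2
  }
  for (auto v : arr) {
    std::cout << v << ' ';  // prints: 3 4 5
  }
  std::cout << '\n';
}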
- std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + - std::to_string(value); - value_.emplace_back(data); - return Status::OK(); -} - -Status DatasetIteratorTracing::SaveToFile() { - if (value_.empty()) { - return Status::OK(); - } - - std::ofstream handle(file_path_, std::ios::trunc); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); - } - for (auto value : value_) { - handle << value << "\n"; - } - handle.close(); - - return Status::OK(); -} - -Status DatasetIteratorTracing::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("dataset_iterator_profiling_" + device_id + ".txt")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h b/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h deleted file mode 100644 index 129863c6d1..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DATASET_ITERATOR_TRACING_H -#define MINDSPORE_DATASET_ITERATOR_TRACING_H - -#include -#include -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class DatasetIteratorTracing : public Tracing { - public: - // Constructor - DatasetIteratorTracing() = default; - - // Destructor - ~DatasetIteratorTracing() override = default; - - // Record tracing data - // @return Status - The error code return - Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); - - std::string Name() const override { return kDatasetIteratorTracingName; }; - - // Save tracing data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - private: - std::vector value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_DATASET_ITERATOR_TRACING_H diff --git a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc b/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc deleted file mode 100644 index 204a83e3fb..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/util/path.h" -namespace mindspore { -namespace dataset { - -Status DeviceQueueTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, - const int32_t value) { - // Format: "type extra-info batch-num value" - // type: 0: time, 1: connector size - // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time - // if type is 1 - connector capacity - // batch-num: batch number - // value: if type is 0 - value is time(ms) - // if type is 1 - value is connector size - // Examples: - // 0 0 20 10 - The 20th batch took 10ms to get data from pipeline. - // 1 64 20 5 - Connector size is 5 when get the 20th batch.Connector capacity is 64. - std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " + - std::to_string(value); - value_.emplace_back(data); - return Status::OK(); -} - -Status DeviceQueueTracing::SaveToFile() { - if (value_.empty()) { - return Status::OK(); - } - - std::ofstream handle(file_path_, std::ios::trunc); - if (!handle.is_open()) { - RETURN_STATUS_UNEXPECTED("Profiling file can not be opened."); - } - for (auto value : value_) { - handle << value << "\n"; - } - handle.close(); - - return Status::OK(); -} - -Status DeviceQueueTracing::Init(const std::string &dir_path, const std::string &device_id) { - file_path_ = (Path(dir_path) / Path("device_queue_profiling_" + device_id + ".txt")).toString(); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h b/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h deleted file mode 100644 index 13ef7121c1..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_DEVICE_QUEUE_TRACING_H -#define MINDSPORE_DEVICE_QUEUE_TRACING_H - -#include -#include -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class DeviceQueueTracing : public Tracing { - public: - // Constructor - DeviceQueueTracing() = default; - - // Destructor - ~DeviceQueueTracing() override = default; - - // Record tracing data - // @return Status - The error code return - Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value); - - std::string Name() const override { return kDeviceQueueTracingName; }; - - // Save tracing data to file - // @return Status - The error code return - Status SaveToFile() override; - - Status Init(const std::string &dir_path, const std::string &device_id) override; - - private: - std::vector value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_DEVICE_QUEUE_TRACING_H diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.cc b/mindspore/ccsrc/dataset/engine/perf/monitor.cc deleted file mode 100644 index 8a0d682b81..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "dataset/core/config_manager.h" -#include "dataset/engine/perf/monitor.h" -#include "dataset/engine/execution_tree.h" - -namespace mindspore { -namespace dataset { - -Monitor::Monitor(ExecutionTree *tree) : tree_(tree) { - std::shared_ptr cfg = GlobalContext::config_manager(); - sampling_interval_ = cfg->monitor_sampling_interval(); - max_samples_ = 0; - cur_row_ = 0; -} -Status Monitor::operator()() { - // Register this thread with TaskManager to receive proper interrupt signal. - TaskManager::FindMe()->Post(); - - // Keep sampling if - // 1) Monitor Task is not interrupted by TaskManager AND - // 2) Iterator has not received EOF - while (!this_thread::is_interrupted() && !(tree_->isFinished())) { - for (auto &node : tree_->GetProfilingManager()->GetSamplingNodes()) { - RETURN_IF_NOT_OK(node.second->Sample()); - std::this_thread::sleep_for(std::chrono::milliseconds(sampling_interval_)); - } - } - - // Output all profiling data upon request. - tree_->GetProfilingManager()->SaveProfilingData(); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/monitor.h b/mindspore/ccsrc/dataset/engine/perf/monitor.h deleted file mode 100644 index 8b4245db8e..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/monitor.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MONITOR_H -#define MINDSPORE_MONITOR_H - -#include -#include -#include -#include "dataset/util/status.h" -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -class ExecutionTree; -class Monitor { - public: - // Monitor object constructor - - explicit Monitor(ExecutionTree *tree); - - Monitor() = default; - - ~Monitor() = default; - - // Functor for Perf Monitor main loop. - // This function will be the entry point of mindspore::Dataset::Task - Status operator()(); - - int64_t GetSamplingInterval() { return sampling_interval_; } - - private: - int64_t cur_row_; - int64_t max_samples_; - int64_t sampling_interval_; - ExecutionTree *tree_; - std::vector> sampling_list_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_MONITOR_H diff --git a/mindspore/ccsrc/dataset/engine/perf/perf_data.h b/mindspore/ccsrc/dataset/engine/perf/perf_data.h deleted file mode 100644 index a201d705ea..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/perf_data.h +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_PERF_DATA_H -#define DATASET_PERF_DATA_H - -#include -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { - -// PerfData is a convenience class to record and store the data produced by Monitor -// and represents a 2D column major table with every column storing samples -// for an operator. The number of rows equals to the number of samples, -// the number of columns equals to the number of operators. -// The capacity is determined on construction and cannot be changed. -// ColumnType can be std::vector or CyclicArray. In case of the latter data can be added -// indefinitely without the risk of overflowing otherwise the capacity must not be exceeded. 
-// Given PerfData pd(n_rows, n_cols) an element in the column i and row j can be accessed as -// pd[i][j] - -template <typename ColumnType> -class PerfData { - public: - PerfData() = default; - ~PerfData() = default; - PerfData(dsize_t max_rows, dsize_t n_cols) : counter_(0), max_rows_(max_rows), n_cols_(n_cols) { - for (auto i = 0; i < n_cols_; i++) { - data_.push_back(ColumnType(max_rows_)); - } - } - PerfData(const PerfData &rhs) = default; - PerfData(PerfData &&rhs) = default; - - // Adds a row of data - // T must be any container working with range based loops - template <typename T> - void AddSample(const T &row) { - auto i = 0; - for (const auto &e : row) { - data_[i++].push_back(e); - } - counter_++; - } - - // Fetches a row of data by copy - template <typename T> - auto Row(dsize_t idx) { - std::vector<T> row(n_cols_); - for (auto i = 0; i < n_cols_; i++) { - row[i] = data_[i][idx]; - } - return row; - } - - // returns a column of data - ColumnType &operator[](size_t idx) { return data_[idx]; } - - const ColumnType &operator[](size_t idx) const { return data_[idx]; } - - dsize_t size() { return counter_ < max_rows_ ? counter_ : max_rows_; } - - dsize_t capacity() { return max_rows_; } - - private: - std::vector<ColumnType> data_; - dsize_t counter_; - dsize_t max_rows_; - int n_cols_; -}; - -} // namespace dataset -} // namespace mindspore -#endif // DATASET_PERF_DATA_H diff --git a/mindspore/ccsrc/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/dataset/engine/perf/profiling.cc deleted file mode 100644 index 66f27c46ba..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/profiling.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
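[Editor's note] PerfData above only needs its ColumnType to provide a capacity constructor, push_back(), and operator[]. The CyclicArray it mentions lives elsewhere in the codebase; a minimal ring-buffer column with that interface might look like this (illustrative only, not the real CyclicArray):

#include <cstddef>
#include <vector>

template <typename T>
class RingColumn {
 public:
  explicit RingColumn(std::size_t capacity) : buf_(capacity), head_(0), size_(0) {}

  // Overwrites the oldest sample once capacity is reached, so writers can
  // push indefinitely -- the property the PerfData comment relies on.
  void push_back(const T &v) {
    buf_[head_] = v;
    head_ = (head_ + 1) % buf_.size();
    if (size_ < buf_.size()) ++size_;
  }

  // Index 0 is the oldest retained sample.
  T &operator[](std::size_t i) { return buf_[(head_ + buf_.size() - size_ + i) % buf_.size()]; }

  std::size_t size() const { return size_; }

 private:
  std::vector<T> buf_;
  std::size_t head_;
  std::size_t size_;
};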
- */ -#include "dataset/engine/perf/profiling.h" -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/engine/perf/monitor.h" -#include "dataset/engine/perf/device_queue_tracing.h" -#include "dataset/engine/perf/connector_size.h" -#include "dataset/engine/perf/connector_throughput.h" -#include "dataset/engine/perf/dataset_iterator_tracing.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { - -bool ProfilingManager::IsProfilingEnable() const { - auto profiling = common::GetEnv("PROFILING_MODE"); - if (profiling.empty() || profiling != "true") { - return false; - } - return true; -} - -Status ProfilingManager::Initialize() { - // Register nodes based on config - std::string dir = common::GetEnv("MINDDATA_PROFILING_DIR"); - if (dir.empty()) { - RETURN_STATUS_UNEXPECTED("Profiling dir is not set."); - } - char real_path[PATH_MAX] = {0}; - if (dir.size() >= PATH_MAX) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#else - if (realpath(common::SafeCStr(dir), real_path) == nullptr) { - RETURN_STATUS_UNEXPECTED("Profiling dir is invalid."); - } -#endif - dir_path_ = real_path; - - // If DEVICE_ID is not set,defult value is 0 - device_id_ = common::GetEnv("DEVICE_ID"); - if (device_id_.empty()) { - device_id_ = "0"; - } - - // Register all profiling node. - // device_queue node is used for graph mode - std::shared_ptr device_queue_tracing = std::make_shared(); - RETURN_IF_NOT_OK(RegisterTracingNode(device_queue_tracing)); - // dataset_iterator node is used for graph mode - std::shared_ptr dataset_iterator_tracing = std::make_shared(); - RETURN_IF_NOT_OK(RegisterTracingNode(dataset_iterator_tracing)); - - std::shared_ptr connector_size_sampling = std::make_shared(tree_); - RETURN_IF_NOT_OK(RegisterSamplingNode(connector_size_sampling)); - - std::shared_ptr connector_thr_sampling = std::make_shared(tree_); - RETURN_IF_NOT_OK(RegisterSamplingNode(connector_thr_sampling)); - return Status::OK(); -} - -// Profiling node registration -Status ProfilingManager::RegisterTracingNode(std::shared_ptr node) { - // Check if node with the same name has already been registered. - auto exist = tracing_nodes_.find(node->Name()); - if (exist != tracing_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); - } - // Register the node with its name as key. - RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); - tracing_nodes_[node->Name()] = node; - return Status::OK(); -} - -// Profiling node getter -Status ProfilingManager::GetTracingNode(const std::string &name, std::shared_ptr *node) { - // Check if node with the same name has already been registered. - auto exist = tracing_nodes_.find(name); - if (exist == tracing_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); - } - // Fetch node. - *node = tracing_nodes_[name]; - return Status::OK(); -} - -// Profiling node registration -Status ProfilingManager::RegisterSamplingNode(std::shared_ptr node) { - // Check if node with the same name has already been registered. 
- auto exist = sampling_nodes_.find(node->Name()); - if (exist != sampling_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node already exist: " + node->Name()); - } - // Register the node with its name as key. - RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_)); - sampling_nodes_[node->Name()] = node; - return Status::OK(); -} - -// Profiling node getter -Status ProfilingManager::GetSamplingNode(const std::string &name, std::shared_ptr *node) { - // Check if node with the same name has already been registered. - auto exist = sampling_nodes_.find(name); - if (exist == sampling_nodes_.end()) { - return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name); - } - // Fetch node. - *node = sampling_nodes_[name]; - return Status::OK(); -} - -Status ProfilingManager::SaveProfilingData() { - if (!IsProfilingEnable()) { - return Status::OK(); - } - MS_LOG(INFO) << "Start to save profiling data."; - for (auto node : tracing_nodes_) { - RETURN_IF_NOT_OK(node.second->SaveToFile()); - } - for (auto node : sampling_nodes_) { - RETURN_IF_NOT_OK(node.second->SaveToFile()); - } - MS_LOG(INFO) << "Save profiling data end."; - return Status::OK(); -} - -int64_t ProfilingTime::GetCurMilliSecond() { - // because cpplint does not allow using namespace - using std::chrono::duration_cast; - using std::chrono::milliseconds; - using std::chrono::steady_clock; - return duration_cast(steady_clock::now().time_since_epoch()).count(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/perf/profiling.h b/mindspore/ccsrc/dataset/engine/perf/profiling.h deleted file mode 100644 index e38c2d5e54..0000000000 --- a/mindspore/ccsrc/dataset/engine/perf/profiling.h +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
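[Editor's note] Initialize() above combines three common patterns: configuration through environment variables (MINDDATA_PROFILING_DIR, DEVICE_ID), a PATH_MAX length guard, and platform-split path canonicalization. The resolution step in isolation, as a hedged sketch (error handling collapsed to an empty return):

#include <climits>   // PATH_MAX (as used by the original on both platforms)
#include <cstdlib>   // getenv, realpath / _fullpath
#include <string>

// Resolve a directory named by an environment variable to a canonical
// absolute path; returns "" if unset, too long, or not resolvable.
std::string ResolveDirFromEnv(const char *env_name) {
  const char *dir = std::getenv(env_name);
  if (dir == nullptr || *dir == '\0') return "";
  std::string s(dir);
  if (s.size() >= PATH_MAX) return "";
  char real_path[PATH_MAX] = {0};
#if defined(_WIN32) || defined(_WIN64)
  if (_fullpath(real_path, s.c_str(), PATH_MAX) == nullptr) return "";
#else
  if (realpath(s.c_str(), real_path) == nullptr) return "";
#endif
  return std::string(real_path);
}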
- */ -#ifndef DATASET_UTIL_PROFILE_H_ -#define DATASET_UTIL_PROFILE_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class Monitor; -class ExecutionTree; - -const char kDeviceQueueTracingName[] = "Device_Queue_Tracing"; -const char kDatasetIteratorTracingName[] = "Dataset_Iterator_Tracing"; -const char kConnectorSizeSamplingName[] = "Connector_Size_Sampling"; -const char kConnectorThroughputSamplingName[] = "Connector_Throughput_Sampling"; - -// Profiling is the basic unit of a profiling action. -// This base class encapsulates the serialization output logic. -class Profiling : std::enable_shared_from_this<Profiling> { - public: - // Constructor - Profiling() = default; - - // Destructor - virtual ~Profiling() = default; - - virtual Status Init(const std::string &dir_path, const std::string &device_id) = 0; - - // Default serialization file generator - virtual Status SaveToFile() = 0; - - // Profiling name - virtual std::string Name() const = 0; - - protected: - std::string file_path_; -}; - -// Sampling is a class of profiling which generates samples periodically. -class Sampling : public Profiling { - public: - // Sampling action function. This function will be invoked by the performance monitor thread. - virtual Status Sample() = 0; - // virtual Status TestPrint() = 0; - virtual ~Sampling() = default; -}; - -// Tracing is a class of profiling which records samples upon request. -class Tracing : public Profiling { - // Tracing does not define a fixed interface, to keep data recording flexible. -}; - -// ProfilingManager is a class that manages all profiling infrastructure. -// It serves the following purposes: -// 1) Fetch profiling configs from global contexts -// 2) Set up all profiling nodes based on config -// 3) Provide access to profiling nodes for profiling actions -// 4) Manage the profiling data serialization process -class ProfilingManager { - public: - explicit ProfilingManager(ExecutionTree *tree) : tree_(tree) {} - - ~ProfilingManager() = default; - - Status Initialize(); - - // Save profile data to file - // @return Status - The error code return - Status SaveProfilingData(); - - // Sampling node getter - // @param name - The name of the requested node - // @param node - Pointer to the shared pointer for the Sampling node - // @return Status - The error code return - Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node); - - // Tracing node getter - // @param name - The name of the requested node - // @param node - Pointer to the shared pointer for the Tracing node - // @return Status - The error code return - Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node); - - // If profiling is enabled.
- bool IsProfilingEnable() const; - - const std::unordered_map> &GetSamplingNodes() { return sampling_nodes_; } - - private: - std::unordered_map> tracing_nodes_; - - std::unordered_map> sampling_nodes_; - - // Register profile node to tree - // @param node - Profiling node - // @return Status - The error code return - Status RegisterTracingNode(std::shared_ptr node); - - // Register profile node to tree - // @param node - Profiling node - // @return Status - The error code return - Status RegisterSamplingNode(std::shared_ptr node); - - ExecutionTree *tree_ = nullptr; // ExecutionTree pointer - std::string dir_path_; // where to create profiling file - std::string device_id_; // used when create profiling file,filename_deviceid.suffix -}; - -enum ProfilingType { TIME, CONNECTOR_DEPTH }; - -enum ProfilingTimeSubType { - PIPELINE_TIME, - TDT_PUSH_TIME, - BATCH_TIME, - INVALID_TIME, -}; - -class ProfilingTime { - public: - static int64_t GetCurMilliSecond(); -}; - -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc deleted file mode 100644 index ca9f2176f5..0000000000 --- a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
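[Editor's note] Tying the enums above back to the Record() format from device_queue_tracing.cc: the first field is a ProfilingType and, for TIME records, the second is a ProfilingTimeSubType. A hypothetical call site (the manager and the batch/time values are assumed; the names come from this header):

// Inside code holding a ProfilingManager &pm:
std::shared_ptr<Tracing> node;
RETURN_IF_NOT_OK(pm.GetTracingNode(kDeviceQueueTracingName, &node));
auto dq_tracing = std::static_pointer_cast<DeviceQueueTracing>(node);

// "0 0 20 10": the 20th batch took 10 ms of pipeline time.
RETURN_IF_NOT_OK(dq_tracing->Record(TIME, PIPELINE_TIME, 20, 10));
// "1 64 20 5": connector size 5 (capacity 64) when fetching batch 20.
RETURN_IF_NOT_OK(dq_tracing->Record(CONNECTOR_DEPTH, 64, 20, 5));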
- */ -#include "dataset/engine/tdt/tdt_plugin.h" -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/engine/perf/profiling.h" - -namespace mindspore { -namespace dataset { -static std::shared_ptr instance_ptr_ = nullptr; - -std::shared_ptr TdtPlugin::GetInstance() { - if (instance_ptr_ == nullptr) { - instance_ptr_ = std::shared_ptr(new TdtPlugin); - } - return instance_ptr_; -} - -TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) { - MS_LOG(DEBUG) << "TDT channel name is " << channel_name << "."; - std::vector items; - double start_time; - auto ret = translate(ts_row, items); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "TDT converting tensor failed!"; - return FAILED; - } - if (profiling) { - start_time = ProfilingTime::GetCurMilliSecond(); - } - if (tdt::TdtHostPushData(channel_name, items) != 0) { - MS_LOG(ERROR) << "TDT pushing data failed!"; - return FAILED; - } - if (profiling) { - double end_time = ProfilingTime::GetCurMilliSecond(); - time = (int32_t)(end_time - start_time); - } - return SUCCESS; -} - -TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) { - switch (d_type.value()) { - case DataType::DE_BOOL: - datatype = "bool"; - break; - case DataType::DE_INT8: - datatype = "int8"; - break; - case DataType::DE_UINT8: - datatype = "uint8"; - break; - case DataType::DE_INT16: - datatype = "int16"; - break; - case DataType::DE_UINT16: - datatype = "uint16"; - break; - case DataType::DE_INT32: - datatype = "int32"; - break; - case DataType::DE_UINT32: - datatype = "uint32"; - break; - case DataType::DE_FLOAT16: - datatype = "float16"; - break; - case DataType::DE_FLOAT32: - datatype = "float32"; - break; - case DataType::DE_FLOAT64: - datatype = "float64"; - break; - case DataType::DE_INT64: - datatype = "int64"; - break; - case DataType::DE_UINT64: - datatype = "uint64"; - break; - default: - return FAILED; - } - return SUCCESS; -} - -TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector &items) { - if (ts_row.size() == 0) { - MS_LOG(ERROR) << "TDT the size of row is zero."; - return SUCCESS; - } - for (auto ts : ts_row) { - std::string datatype; - TdtStatus status = getTdtType(ts->type(), datatype); - if (status != SUCCESS) { - return status; - } - TensorShape tsShape = ts->shape(); - std::string dataShapes = "["; - for (auto dim : tsShape.AsVector()) { - (void)dataShapes.append(std::to_string(dim)).append(","); - } - dataShapes.pop_back(); - (void)dataShapes.append("]"); - DataItem data_item; - data_item.dataType_ = tdt::TDT_TENSOR; - data_item.tensorShape_ = dataShapes; - data_item.tensorType_ = datatype; - data_item.dataLen_ = ts->SizeInBytes(); - data_item.dataPtr_ = - std::shared_ptr(reinterpret_cast(&(*ts->begin())), [](const void *elem) {}); - items.emplace_back(data_item); - MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is " - << ts->Size() << "."; - } - return SUCCESS; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h b/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h deleted file mode 100644 index 304b205b81..0000000000 --- a/mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_TDT_TDT_PLUGIN_H_ -#define DATASET_ENGINE_TDT_TDT_PLUGIN_H_ - -#include -#include -#include -#include -#include -#include -#include "tdt/tdt_host_interface.h" - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" - -namespace mindspore { -namespace dataset { -enum TdtStatus { SUCCESS, FAILED }; - -using tdt::DataItem; - -class TdtPlugin { - public: - static std::shared_ptr GetInstance(); - - TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time); - - private: - TdtPlugin() {} - - TdtStatus getTdtType(DataType d_type, std::string &datatype); - - TdtStatus translate(const TensorRow &ts_row, std::vector &items); - - void *tdt_handle_ = nullptr; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_TDT_TDT_PLUGIN_H_ diff --git a/mindspore/ccsrc/dataset/include/datasets.h b/mindspore/ccsrc/dataset/include/datasets.h deleted file mode 100644 index 586fff2107..0000000000 --- a/mindspore/ccsrc/dataset/include/datasets.h +++ /dev/null @@ -1,357 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_INCLUDE_DATASETS_H_ -#define DATASET_INCLUDE_DATASETS_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/include/tensor.h" -#include "dataset/include/iterator.h" -#include "dataset/include/samplers.h" - -namespace mindspore { -namespace dataset { - -// Forward declare -class DatasetOp; -class DataSchema; -class Tensor; -class TensorShape; - -namespace api { - -class TensorOperation; -class SamplerObj; -class ImageFolderDataset; -class MnistDataset; -class BatchDataset; -class RepeatDataset; -class MapDataset; -class ShuffleDataset; -class Cifar10Dataset; -class ProjectDataset; - -/// \brief Function to create an ImageFolderDataset -/// \notes A source dataset that reads images from a tree of directories -/// All images within one folder have the same label -/// The generated dataset has two columns ['image', 'label'] -/// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] decode A flag to decode in ImageFolder -/// \param[in] sampler Object used to choose samples from the dataset. 
If sampler is `nullptr`, -/// A `RandomSampler` will be used to randomly iterate the entire dataset -/// \param[in] extensions File extensions to be read -/// \param[in] class_indexing a class name to label map -/// \return Shared pointer to the current ImageFolderDataset -std::shared_ptr ImageFolder(std::string dataset_dir, bool decode = false, - std::shared_ptr sampler = nullptr, - std::set extensions = {}, - std::map class_indexing = {}); - -/// \brief Function to create a MnistDataset -/// \notes The generated dataset has two columns ['image', 'label'] -/// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, -/// A `RandomSampler` will be used to randomly iterate the entire dataset -/// \return Shared pointer to the current MnistDataset -std::shared_ptr Mnist(std::string dataset_dir, std::shared_ptr sampler = nullptr); - -/// \brief Function to create a Cifar10 Dataset -/// \notes The generated dataset has two columns ['image', 'label'] -/// \param[in] dataset_dir Path to the root directory that contains the dataset -/// \param[in] num_samples The number of images to be included in the dataset -/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler` -/// will be used to randomly iterate the entire dataset -/// \return Shared pointer to the current Dataset -std::shared_ptr Cifar10(const std::string &dataset_dir, int32_t num_samples, - std::shared_ptr sampler); - -/// \class Dataset datasets.h -/// \brief A base class to represent a dataset in the data pipeline. -class Dataset : public std::enable_shared_from_this { - public: - friend class Iterator; - - /// \brief Constructor - Dataset(); - - /// \brief Destructor - ~Dataset() = default; - - /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object - /// \return shared pointer to the list of newly created DatasetOps - virtual std::shared_ptr>> Build() = 0; - - /// \brief Pure virtual function for derived class to implement parameters validation - /// \return bool True if all the params are valid - virtual bool ValidateParams() = 0; - - /// \brief Setter function for runtime number of workers - /// \param[in] num_workers The number of threads in this operator - /// \return Shared pointer to the original object - std::shared_ptr SetNumWorkers(int32_t num_workers) { - num_workers_ = num_workers; - return shared_from_this(); - } - - /// \brief Function to create an Iterator over the Dataset pipeline - /// \return Shared pointer to the Iterator - std::shared_ptr CreateIterator(); - - /// \brief Function to create a BatchDataset - /// \notes Combines batch_size number of consecutive rows into batches - /// \param[in] batch_size Path to the root directory that contains the dataset - /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete - /// batch. If true, and if there are less than batch_size rows - /// available to make the last batch, then those rows will - /// be dropped and not propagated to the next node - /// \return Shared pointer to the current BatchDataset - std::shared_ptr Batch(int32_t batch_size, bool drop_remainder = false); - - /// \brief Function to create a RepeatDataset - /// \notes Repeats this dataset count times. 
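[Editor's note] Putting the factory functions and the fluent Dataset methods declared here together, building a pipeline is a chain of shared_ptr-returning calls. An illustrative use of this C++ API inside application code (the path and sizes are placeholders; Decode and Resize are the vision transforms declared further below):

using namespace mindspore::dataset::api;

std::shared_ptr<Dataset> ds = ImageFolder("/path/to/images", /*decode=*/false);
ds = ds->Map({vision::Decode(), vision::Resize({224, 224})})
       ->Shuffle(100)
       ->Batch(32, /*drop_remainder=*/true)
       ->Repeat(2);
std::shared_ptr<Iterator> iter = ds->CreateIterator();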
Repeat indefinitely if count is -1 - /// \param[in] count Number of times the dataset should be repeated - /// \return Shared pointer to the current Dataset - /// \note Repeat will return shared pointer to `Dataset` instead of `RepeatDataset` - /// due to a limitation in the current implementation - std::shared_ptr Repeat(int32_t count = -1); - - /// \brief Function to create a MapDataset - /// \notes Applies each operation in operations to this dataset - /// \param[in] operations Vector of operations to be applied on the dataset. Operations are - /// applied in the order they appear in this list - /// \param[in] input_columns Vector of the names of the columns that will be passed to the first - /// operation as input. The size of this list must match the number of - /// input columns expected by the first operator. The default input_columns - /// is the first column - /// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation - /// This parameter is mandatory if len(input_columns) != len(output_columns) - /// The size of this list must match the number of output columns of the - /// last operation. The default output_columns will have the same - /// name as the input columns, i.e., the columns will be replaced - /// \param[in] project_columns A list of column names to project - /// \return Shared pointer to the current MapDataset - std::shared_ptr Map(std::vector> operations, - std::vector input_columns = {}, - std::vector output_columns = {}, - const std::vector &project_columns = {}); - - /// \brief Function to create a Shuffle Dataset - /// \notes Randomly shuffles the rows of this dataset - /// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling - /// \return Shared pointer to the current ShuffleDataset - std::shared_ptr Shuffle(int32_t shuffle_size); - - /// \brief Function to create a Project Dataset - /// \notes Applies project to the dataset - /// \param[in] columns The name of columns to project - /// \return Shared pointer to the current Dataset - std::shared_ptr Project(const std::vector &columns); - - protected: - std::vector> children; - std::shared_ptr parent; - - int32_t num_workers_; - int32_t rows_per_buffer_; - int32_t connector_que_size_; -}; - -/* ####################################### Derived Dataset classes ################################# */ - -/// \class ImageFolderDataset -/// \brief A Dataset derived class to represent ImageFolder dataset -class ImageFolderDataset : public Dataset { - public: - /// \brief Constructor - ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr sampler, bool recursive, - std::set extensions, std::map class_indexing); - - /// \brief Destructor - ~ImageFolderDataset() = default; - - /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - std::string dataset_dir_; - bool decode_; - bool recursive_; - std::shared_ptr sampler_; - std::map class_indexing_; - std::set exts_; -}; - -class MnistDataset : public Dataset { - public: - /// \brief Constructor - MnistDataset(std::string dataset_dir, std::shared_ptr sampler); - - /// \brief Destructor - ~MnistDataset() = default; - - /// \brief a base class override function to create the required runtime dataset op 
objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - std::string dataset_dir_; - std::shared_ptr sampler_; -}; - -class BatchDataset : public Dataset { - public: - /// \brief Constructor - BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector cols_to_map, - std::map>> pad_map); - - /// \brief Destructor - ~BatchDataset() = default; - - /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - int32_t batch_size_; - bool drop_remainder_; - bool pad_; - std::vector cols_to_map_; - std::map>> pad_map_; -}; - -class RepeatDataset : public Dataset { - public: - /// \brief Constructor - explicit RepeatDataset(uint32_t count); - - /// \brief Destructor - ~RepeatDataset() = default; - - /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - uint32_t repeat_count_; -}; - -class ShuffleDataset : public Dataset { - public: - ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch); - - ~ShuffleDataset() = default; - - std::shared_ptr>> Build() override; - - bool ValidateParams() override; - - private: - int32_t shuffle_size_; - uint32_t shuffle_seed_; - bool reset_every_epoch_; -}; - -class MapDataset : public Dataset { - public: - /// \brief Constructor - MapDataset(std::vector> operations, std::vector input_columns = {}, - std::vector output_columns = {}, const std::vector &columns = {}); - - /// \brief Destructor - ~MapDataset() = default; - - /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - std::vector> operations_; - std::vector input_columns_; - std::vector output_columns_; - std::vector project_columns_; -}; - -class Cifar10Dataset : public Dataset { - public: - /// \brief Constructor - Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler); - - /// \brief Destructor - ~Cifar10Dataset() = default; - - /// \brief a base class override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - std::string dataset_dir_; - int32_t num_samples_; - std::shared_ptr sampler_; -}; - -class ProjectDataset : public Dataset { - public: - /// \brief Constructor - explicit ProjectDataset(const std::vector &columns); - - /// \brief Destructor - ~ProjectDataset() = default; - - /// \brief a base class 
override function to create the required runtime dataset op objects for this class - /// \return shared pointer to the list of newly created DatasetOps - std::shared_ptr>> Build() override; - - /// \brief Parameters validation - /// \return bool true if all the params are valid - bool ValidateParams() override; - - private: - std::vector columns_; -}; -} // namespace api -} // namespace dataset -} // namespace mindspore -#endif // DATASET_INCLUDE_DATASETS_H_ diff --git a/mindspore/ccsrc/dataset/include/iterator.h b/mindspore/ccsrc/dataset/include/iterator.h deleted file mode 100644 index 1c78031771..0000000000 --- a/mindspore/ccsrc/dataset/include/iterator.h +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_INCLUDE_ITERATOR_H_ -#define DATASET_INCLUDE_ITERATOR_H_ - -#include -#include -#include -#include -#include "dataset/include/status.h" - -namespace mindspore { -namespace dataset { - -// Forward declare -class ExecutionTree; -class DatasetIterator; -class DatasetOp; -class Tensor; - -namespace api { - -class Dataset; - -using TensorMap = std::unordered_map>; - -// Abstract class for iterating over the dataset. -class Iterator { - public: - /// \brief Constructor - Iterator() = default; - - /// \brief Destructor - ~Iterator() = default; - - /// \brief Method for building and launching the pipeline. - /// \param[in] ops - a vector of DatasetOp in the data pipeline. - /// \return - a Status error code, returns OK if no error encountered. - Status BuildAndLaunchTree(std::shared_ptr ds); - - /// \brief Function to get the next row from the data pipeline. - /// \param[out] row - the output tensor row. - void GetNextRow(TensorMap *row); - - /// \brief Function to shut down the data pipeline. - void Stop(); - - class _Iterator { - public: - explicit _Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr} { - if (lt_) { - cur_row_ = new TensorMap(); - lt_->GetNextRow(cur_row_); - } - } - - // Destructor - ~_Iterator() { - if (cur_row_) { - delete cur_row_; - } - } - - _Iterator &operator++() { - if (lt_) { - ++ind_; - lt_->GetNextRow(cur_row_); - } - if (cur_row_ && cur_row_->size() == 0) { - delete cur_row_; - cur_row_ = nullptr; - } - return *this; - } // prefix ++ overload - TensorMap &operator*() { return *cur_row_; } // dereference operator - TensorMap *operator->() { return cur_row_; } - - bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; } - - private: - int ind_; // the cur node our Iterator points to - Iterator *lt_; - TensorMap *cur_row_; - }; - - _Iterator begin() { return _Iterator(this); } - - _Iterator end() { return _Iterator(nullptr); } - - private: - // Runtime tree. - // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. 
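[Editor's note] The nested _Iterator above is what makes C++ range-for work over a pipeline: begin() pre-fetches a row, end() is the null sentinel, and the loop stops once GetNextRow() yields an empty TensorMap. Usage, continuing the hypothetical pipeline sketch earlier:

std::shared_ptr<Iterator> iter = ds->CreateIterator();
for (TensorMap &row : *iter) {
  auto image = row["image"];  // std::shared_ptr<Tensor>
  auto label = row["label"];
  // ... consume the row ...
}
iter->Stop();  // shut the pipeline down explicitly if stopping early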
- std::shared_ptr tree_; - - // Runtime iterator - std::unique_ptr iterator_; -}; -} // namespace api -} // namespace dataset -} // namespace mindspore -#endif // DATASET_INCLUDE_ITERATOR_H_ diff --git a/mindspore/ccsrc/dataset/include/transforms.h b/mindspore/ccsrc/dataset/include/transforms.h deleted file mode 100644 index c3a1540ae8..0000000000 --- a/mindspore/ccsrc/dataset/include/transforms.h +++ /dev/null @@ -1,380 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_API_TRANSFORMS_H_ -#define DATASET_API_TRANSFORMS_H_ - -#include -#include -#include "dataset/core/constants.h" - -namespace mindspore { -namespace dataset { - -class TensorOp; - -namespace api { -// Abstract class to represent a dataset in the data pipeline. -class TensorOperation : public std::enable_shared_from_this { - public: - /// \brief Constructor - TensorOperation(); - - /// \brief Destructor - ~TensorOperation() = default; - - /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object. - /// \return shared pointer to the newly created TensorOp. - virtual std::shared_ptr Build() = 0; - - virtual bool ValidateParams() = 0; -}; - -// Transform operations for performing computer vision. -namespace vision { - -class NormalizeOperation; -class DecodeOperation; -class ResizeOperation; -class RandomCropOperation; -class CenterCropOperation; -class UniformAugOperation; -class RandomHorizontalFlipOperation; -class RandomVerticalFlipOperation; -class RandomRotationOperation; -class PadOperation; -class CutOutOperation; -class RandomColorAdjustOperation; - -/// \brief Function to create a Normalize TensorOperation. -/// \notes Normalize the input image with respect to mean and standard deviation. -/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. -/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr Normalize(std::vector mean, std::vector std); - -/// \brief Function to create a Decode TensorOperation. -/// \notes Decode the input image in RGB mode. -/// \param[in] rgb - a boolean of whether to decode in RGB mode or not. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr Decode(bool rgb = true); - -/// \brief Function to create a Resize TensorOperation. -/// \notes Resize the input image to the given size.. -/// \param[in] size - a vector representing the output size of the resized image. -/// If size is a single value, the image will be resized to this value with -/// the same image aspect ratio. If size has 2 values, it should be (height, width). -/// \param[in] interpolation An enum for the mode of interpolation -/// \return Shared pointer to the current TensorOperation. 
-std::shared_ptr Resize(std::vector size, - InterpolationMode interpolation = InterpolationMode::kLinear); - -/// \brief Function to create a RandomCrop TensorOperation. -/// \notes Crop the input image at a random location. -/// \param[in] size - a vector representing the output size of the cropped image. -/// If size is a single value, a square crop of size (size, size) is returned. -/// If size has 2 values, it should be (height, width). -/// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided, -/// it pads the left, top, right and bottom respectively. -/// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than -/// the given output size. -/// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to -/// fill R, G, B channels respectively. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomCrop(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, - std::vector fill_value = {0, 0, 0}); - -/// \brief Function to create a CenterCrop TensorOperation. -/// \notes Crops the input image at the center to the given size. -/// \param[in] size - a vector representing the output size of the cropped image. -/// If size is a single value, a square crop of size (size, size) is returned. -/// If size has 2 values, it should be (height, width). -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr CenterCrop(std::vector size); - -/// \brief Function to create a UniformAugment TensorOperation. -/// \notes Tensor operation to perform randomly selected augmentation. -/// \param[in] operations - a vector of TensorOperation operations. -/// \param[in] num_ops - integer representing the number of OPs to be selected and applied. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr UniformAugment(std::vector> operations, - int32_t num_ops = 2); - -/// \brief Function to create a RandomHorizontalFlip TensorOperation. -/// \notes Tensor operation to perform random horizontal flip. -/// \param[in] prob - float representing the probability of flip. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomHorizontalFlip(float prob = 0.5); - -/// \brief Function to create a RandomVerticalFlip TensorOperation. -/// \notes Tensor operation to perform random vertical flip. -/// \param[in] prob - float representing the probability of flip. -/// \return Shared pointer to the current TensorOperation. -std::shared_ptr RandomVerticalFlip(float prob = 0.5); - -/// \brief Function to create a RandomRotation TensorOp -/// \notes Rotates the image according to parameters -/// \param[in] degrees A float vector size 2, representing the starting and ending degree -/// \param[in] resample An enum for the mode of interpolation -/// \param[in] expand A boolean representing whether the image is expanded after rotation -/// \param[in] center A float vector size 2, representing the x and y center of rotation. 
-/// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color -/// \return Shared pointer to the current TensorOp -std::shared_ptr RandomRotation( - std::vector degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, - std::vector center = {-1, -1}, std::vector fill_value = {0, 0, 0}); - -/// \brief Function to create a Pad TensorOp -/// \notes Pads the image according to padding parameters -/// \param[in] padding A vector representing the number of pixels to pad the image -/// If vector has one value, it pads all sides of the image with that value -/// If vector has two values, it pads left and right with the first and -/// top and bottom with the second value -/// If vector has four values, it pads left, top, right, and bottom with -/// those values respectively -/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is -/// BorderType.kConstant. If 3 values are provided, -/// it is used to fill R, G, B channels respectively -/// \param[in] padding_mode The method of padding (default=BorderType.kConstant) -/// Can be any of -/// [BorderType.kConstant, BorderType.kEdge, BorderType.kReflect, BorderType.kSymmetric] -/// - BorderType.kConstant, means it fills the border with constant values -/// - BorderType.kEdge, means it pads with the last value on the edge -/// - BorderType.kReflect, means it reflects the values on the edge omitting the last value of edge -/// - BorderType.kSymmetric, means it reflects the values on the edge repeating the last value of edge -/// \return Shared pointer to the current TensorOp -std::shared_ptr Pad(std::vector padding, std::vector fill_value = {0}, - BorderType padding_mode = BorderType::kConstant); - -/// \brief Function to create a CutOut TensorOp -/// \notes Randomly cut (mask) out a given number of square patches from the input image -/// \param[in] length Integer representing the side length of each square patch -/// \param[in] num_patches Integer representing the number of patches to be cut out of an image -/// \return Shared pointer to the current TensorOp -std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); - -/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image -/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values -/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} -/// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values -/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} -/// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values -/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} -/// \param[in] hue Brightness adjustment factor. 
Must be a vector of one or two values -/// if it's a vector of two values it must be in the form of [min, max] where -0.5 <= min <= max <= 0.5 -/// Default value is {0, 0} -/// \return Shared pointer to the current TensorOp -std::shared_ptr RandomColorAdjust(std::vector brightness = {1.0, 1.0}, - std::vector contrast = {1.0, 1.0}, - std::vector saturation = {1.0, 1.0}, - std::vector hue = {0.0, 0.0}); - -/* ####################################### Derived TensorOperation classes ################################# */ - -class NormalizeOperation : public TensorOperation { - public: - NormalizeOperation(std::vector mean, std::vector std); - - ~NormalizeOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector mean_; - std::vector std_; -}; - -class DecodeOperation : public TensorOperation { - public: - explicit DecodeOperation(bool rgb = true); - - ~DecodeOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - bool rgb_; -}; - -class ResizeOperation : public TensorOperation { - public: - explicit ResizeOperation(std::vector size, - InterpolationMode interpolation_mode = InterpolationMode::kLinear); - - ~ResizeOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector size_; - InterpolationMode interpolation_; -}; - -class RandomCropOperation : public TensorOperation { - public: - RandomCropOperation(std::vector size, std::vector padding = {0, 0, 0, 0}, - bool pad_if_needed = false, std::vector fill_value = {0, 0, 0}); - - ~RandomCropOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector size_; - std::vector padding_; - bool pad_if_needed_; - std::vector fill_value_; -}; - -class CenterCropOperation : public TensorOperation { - public: - explicit CenterCropOperation(std::vector size); - - ~CenterCropOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector size_; -}; - -class UniformAugOperation : public TensorOperation { - public: - explicit UniformAugOperation(std::vector> operations, int32_t num_ops = 2); - - ~UniformAugOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector> operations_; - int32_t num_ops_; -}; - -class RandomHorizontalFlipOperation : public TensorOperation { - public: - explicit RandomHorizontalFlipOperation(float probability = 0.5); - - ~RandomHorizontalFlipOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - float probability_; -}; - -class RandomVerticalFlipOperation : public TensorOperation { - public: - explicit RandomVerticalFlipOperation(float probability = 0.5); - - ~RandomVerticalFlipOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - float probability_; -}; - -class RandomRotationOperation : public TensorOperation { - public: - RandomRotationOperation(std::vector degrees, InterpolationMode interpolation_mode, bool expand, - std::vector center, std::vector fill_value); - - ~RandomRotationOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector degrees_; - InterpolationMode interpolation_mode_; - std::vector center_; - bool expand_; - std::vector fill_value_; -}; - -class PadOperation : public 
TensorOperation { - public: - PadOperation(std::vector padding, std::vector fill_value = {0}, - BorderType padding_mode = BorderType::kConstant); - - ~PadOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector padding_; - std::vector fill_value_; - BorderType padding_mode_; -}; - -class CutOutOperation : public TensorOperation { - public: - explicit CutOutOperation(int32_t length, int32_t num_patches = 1); - - ~CutOutOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - int32_t length_; - int32_t num_patches_; -}; - -class RandomColorAdjustOperation : public TensorOperation { - public: - RandomColorAdjustOperation(std::vector brightness = {1.0, 1.0}, std::vector contrast = {1.0, 1.0}, - std::vector saturation = {1.0, 1.0}, std::vector hue = {0.0, 0.0}); - - ~RandomColorAdjustOperation() = default; - - std::shared_ptr Build() override; - - bool ValidateParams() override; - - private: - std::vector brightness_; - std::vector contrast_; - std::vector saturation_; - std::vector hue_; -}; -} // namespace vision -} // namespace api -} // namespace dataset -} // namespace mindspore -#endif // DATASET_API_TRANSFORMS_H_ diff --git a/mindspore/ccsrc/dataset/include/utils/log_adapter.h b/mindspore/ccsrc/dataset/include/utils/log_adapter.h deleted file mode 120000 index 5cecc45938..0000000000 --- a/mindspore/ccsrc/dataset/include/utils/log_adapter.h +++ /dev/null @@ -1 +0,0 @@ -../../../utils/log_adapter.h \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/include/utils/overload.h b/mindspore/ccsrc/dataset/include/utils/overload.h deleted file mode 120000 index d163e52748..0000000000 --- a/mindspore/ccsrc/dataset/include/utils/overload.h +++ /dev/null @@ -1 +0,0 @@ -../../../utils/overload.h \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc deleted file mode 100644 index 87115fd3ce..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
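[Editor's note] The derived classes above all follow the same two-phase contract: ValidateParams() checks arguments eagerly, while Build() produces the runtime TensorOp only when the tree is constructed. A hypothetical new operation (RandomInvertOperation and its kernel RandomInvertOp are invented names for illustration) would slot in the same way:

class RandomInvertOperation : public TensorOperation {
 public:
  explicit RandomInvertOperation(float prob = 0.5) : prob_(prob) {}

  ~RandomInvertOperation() = default;

  // Reject bad arguments before any pipeline is built.
  bool ValidateParams() override { return prob_ >= 0.0 && prob_ <= 1.0; }

  // Assumes a matching runtime kernel (RandomInvertOp) exists.
  std::shared_ptr<TensorOp> Build() override { return std::make_shared<RandomInvertOp>(prob_); }

 private:
  float prob_;
};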
- */ -#include "dataset/kernels/data/concatenate_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); - return Status::OK(); -} - -Status ConcatenateOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - - std::vector inputs_copy; - inputs_copy.push_back(inputs[0].Squeeze()); - - CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); - - outputs.clear(); - dsize_t output_shape = 0; - output_shape = output_shape + inputs.at(0).NumOfElements(); - if (prepend_ != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); - output_shape = output_shape + prepend_->shape().NumOfElements(); - } - if (append_ != nullptr) { - CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); - output_shape = output_shape + append_->shape().NumOfElements(); - } - - outputs.emplace_back(std::vector{output_shape}); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h deleted file mode 100644 index b85d75a68e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/concatenate_op.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ -#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -class ConcatenateOp : public TensorOp { - public: - /// Constructor to ConcatenateOp. - /// @param int8_t axis - axis to concatenate tensors along. - /// @param std::shared_ptr prepend - prepend tensor. - /// @param std::shared_ptr append -append tensor. - explicit ConcatenateOp(int8_t axis, std::shared_ptr prepend, std::shared_ptr append) - : axis_(axis), prepend_(prepend), append_(append) {} - - ~ConcatenateOp() override = default; - - /// Print method to see which tensor Op this is. - /// @param std::ostream &out - output stream object. 
- void Print(std::ostream &out) const override { out << "ConcatenateOp"; } - - /// Compute method allowing multiple tensors as inputs - /// @param TensorRow &input - input tensor rows - /// @param TensorRow *output - output tensor rows - Status Compute(const TensorRow &input, TensorRow *output) override; - - /// Compute tensor output shape - /// @param std::vector<TensorShape> &inputs - vector of input tensor shapes - /// @param std::vector<TensorShape> &outputs - vector of output tensor shapes - Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; - - /// Number of inputs the tensor operation accepts - uint32_t NumInput() override { return 0; } - - std::string Name() const override { return kConcatenateOp; } - - private: - int8_t axis_; - std::shared_ptr<Tensor> prepend_; - std::shared_ptr<Tensor> append_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_CONCATENATE_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/dataset/kernels/data/data_utils.cc deleted file mode 100644 index 0d437675f8..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/data_utils.cc +++ /dev/null @@ -1,656 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/data/data_utils.h" - -#include -#include -#include -#include - -#include "dataset/core/constants.h" -#include "dataset/core/data_type.h" -#ifdef ENABLE_PYTHON -#include "dataset/core/pybind_support.h" -#endif -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, - dsize_t num_classes, int64_t index) { - uint64_t class_idx; - if (input->Rank() == 0) { - RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {})); - } else { - RETURN_IF_NOT_OK(input->GetItemAt<uint64_t>(&class_idx, {index})); - } - if (class_idx >= static_cast<uint64_t>(num_classes)) { - RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); - } - if (input->type() == DataType::DE_UINT64) { - RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT32) { - RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT16) { - RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1)); - } else if (input->type() == DataType::DE_UINT8) { - RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1)); - } else { - RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input."); - } - return Status::OK(); -} - -Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes, - int64_t index) { - int64_t class_idx; - if (input->Rank() == 0) { - RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {})); - } else { - RETURN_IF_NOT_OK(input->GetItemAt<int64_t>(&class_idx, {index})); - } - if (class_idx >= static_cast<int64_t>(num_classes)) {
- RETURN_STATUS_UNEXPECTED("One_hot index values are not in range"); - } - if (input->type() == DataType::DE_INT64) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT32) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT16) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else if (input->type() == DataType::DE_INT8) { - RETURN_IF_NOT_OK((*output)->SetItemAt({index, static_cast(class_idx)}, 1)); - } else { - RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input."); - } - return Status::OK(); -} - -Status OneHotEncoding(std::shared_ptr input, std::shared_ptr *output, dsize_t num_classes) { - input->Squeeze(); - - if (input->Rank() > 1) { // We expect the input to be int he first dimension - RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors."); - } - if (!input->type().IsInt()) { - RETURN_STATUS_UNEXPECTED("One hot does not support input of this type."); - } - try { - dsize_t num_elements = 1; - if (input->Rank() == 1) num_elements = input->shape()[0]; - TensorShape out_shape({num_elements, num_classes}); - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type())); - RETURN_IF_NOT_OK(out->Zero()); - for (dsize_t i = 0; i < num_elements; ++i) { - if (input->type().IsUnsignedInt()) { - RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i)); - } else { - RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i)); - } - } - out->Squeeze(); - *output = out; - return Status::OK(); - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp"); - } -} - -Status Fill(const std::shared_ptr input, std::shared_ptr *output, std::shared_ptr fill_value) { - const DataType &fill_type = fill_value->type(); - const DataType &input_type = input->type(); - const TensorShape &input_shape = input->shape(); - - CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)), - "Types do not match"); - - CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar"); - - std::shared_ptr out, fill_output; - - if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) { - auto op = std::make_unique(input_type); - RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output)); - } else { - fill_output = fill_value; - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type)); - - switch (input_type.value()) { - case DataType::DE_BOOL: { - bool value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT8: { - int8_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT8: { - uint8_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT16: { - uint16_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_INT16: { - int16_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - } - case DataType::DE_UINT32: { - uint32_t value = 0; - RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); - out->Fill(value); - break; - 
-    }
-    case DataType::DE_INT32: {
-      int32_t value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_UINT64: {
-      uint64_t value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_INT64: {
-      int64_t value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_FLOAT16: {
-      int64_t value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_FLOAT32: {
-      float value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_FLOAT64: {
-      double value = 0;
-      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
-      out->Fill(value);
-      break;
-    }
-    case DataType::DE_STRING: {
-      std::vector<std::string> strings;
-      std::string_view fill_string_view;
-      RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
-      std::string fill_string = std::string(fill_string_view);
-      for (int i = 0; i < input_shape.NumOfElements(); i++) {
-        strings.emplace_back(fill_string);
-      }
-      RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape));
-      break;
-    }
-    case DataType::DE_UNKNOWN: {
-      RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
-      break;
-    }
-  }
-
-  *output = out;
-  return Status::OK();
-}
-template <typename FROM, typename TO>
-void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  auto in_itr = input->begin<FROM>();
-  auto out_itr = (*output)->begin<TO>();
-  auto out_end = (*output)->end<TO>();
-
-  for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
-    *out_itr = static_cast<TO>(*in_itr);
-}
-
-template <typename T>
-void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  switch ((*output)->type().value()) {
-    case DataType::DE_BOOL:
-      Cast<T, bool>(input, output);
-      break;
-    case DataType::DE_INT8:
-      Cast<T, int8_t>(input, output);
-      break;
-    case DataType::DE_UINT8:
-      Cast<T, uint8_t>(input, output);
-      break;
-    case DataType::DE_INT16:
-      Cast<T, int16_t>(input, output);
-      break;
-    case DataType::DE_UINT16:
-      Cast<T, uint16_t>(input, output);
-      break;
-    case DataType::DE_INT32:
-      Cast<T, int32_t>(input, output);
-      break;
-    case DataType::DE_UINT32:
-      Cast<T, uint32_t>(input, output);
-      break;
-    case DataType::DE_INT64:
-      Cast<T, int64_t>(input, output);
-      break;
-    case DataType::DE_UINT64:
-      Cast<T, uint64_t>(input, output);
-      break;
-    case DataType::DE_FLOAT16:
-      Cast<T, float16>(input, output);
-      break;
-    case DataType::DE_FLOAT32:
-      Cast<T, float>(input, output);
-      break;
-    case DataType::DE_FLOAT64:
-      Cast<T, double>(input, output);
-      break;
-    case DataType::DE_UNKNOWN:
-      MS_LOG(ERROR) << "Unknown data type.";
-      break;
-  }
-}
-
-// Type cast operator
-Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
-
-  RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
-  switch (input->type().value()) {
-    case DataType::DE_BOOL:
-      CastFrom<bool>(input, output);
-      break;
-    case DataType::DE_INT8:
-      CastFrom<int8_t>(input, output);
-      break;
-    case DataType::DE_UINT8:
-      CastFrom<uint8_t>(input, output);
-      break;
-    case DataType::DE_INT16:
-      CastFrom<int16_t>(input, output);
-      break;
-    case DataType::DE_UINT16:
-      CastFrom<uint16_t>(input, output);
-      break;
-    case DataType::DE_INT32:
-      CastFrom<int32_t>(input, output);
-      break;
-    case DataType::DE_UINT32:
-      CastFrom<uint32_t>(input, output);
-      break;
-    case DataType::DE_INT64:
-      CastFrom<int64_t>(input, output);
-      break;
-    case DataType::DE_UINT64:
-      CastFrom<uint64_t>(input, output);
-      break;
-    case DataType::DE_FLOAT16:
-      CastFrom<float16>(input, output);
-      break;
-    case DataType::DE_FLOAT32:
-      CastFrom<float>(input, output);
-      break;
-    case DataType::DE_FLOAT64:
-      CastFrom<double>(input, output);
-      break;
-    case DataType::DE_UNKNOWN:
-      // sanity check, unreachable code.
-      RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type.");
-  }
-  return Status::OK();
-}
-
-Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  // initiate new tensor for type cast
-  DataType new_type = DataType("float16");
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
-  RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
-
-  auto in_itr = input->begin<float>();
-  auto out_itr = (*output)->begin<float16>();
-  auto out_end = (*output)->end<float16>();
-
-  for (; out_itr != out_end; in_itr++, out_itr++) {
-    float element = *in_itr;
-    float float16_max = static_cast<float>(std::numeric_limits<float16>::max());
-    float float16_min = static_cast<float>(std::numeric_limits<float16>::lowest());
-    if (element > float16_max || element < float16_min) {
-      RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" +
-                               std::to_string(float16_min) + ", " + std::to_string(float16_max) + "].");
-    }
-
-    *out_itr = Eigen::half(*in_itr);
-  }
-
-  return Status::OK();
-}
-
-Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
-              const std::shared_ptr<Tensor> &pad_val) {
-  if (pad_val == nullptr) {
-    if (src->type().IsNumeric()) {
-      return PadEndNumeric(src, dst, pad_shape, 0);
-    } else {
-      return PadEndString(src, dst, pad_shape, "");
-    }
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(),
-                               "Source and pad_value tensors are not of the same type.");
-  if (pad_val->type().IsNumeric()) {
-    std::shared_ptr<Tensor> float_pad_value;
-    RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32)));
-    float val = 0;
-    RETURN_IF_NOT_OK(float_pad_value->GetItemAt(&val, {}));
-    return PadEndNumeric(src, dst, pad_shape, val);
-  }
-  std::string_view val;
-  RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {}));
-  return PadEndString(src, dst, pad_shape, std::string(val));
-}
-
-Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
-                     const std::vector<dsize_t> &pad_shape, float pad_val) {
-  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
-  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
-    (*dst) = src;  // if no padding, copy the pointer
-  } else {
-    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type()));
-    auto tensor_type = src->type().value();
-    if (pad_val == 0) {  // if pad with zero, don't care what type it is
-      RETURN_IF_NOT_OK((*dst)->Zero());
-    } else if (tensor_type == DataType::DE_INT8) {
-      RETURN_IF_NOT_OK((*dst)->Fill<int8_t>(pad_val));
-    } else if (tensor_type == DataType::DE_BOOL) {
-      RETURN_IF_NOT_OK((*dst)->Fill<bool>(pad_val));
-    } else if (tensor_type == DataType::DE_UINT8) {
-      RETURN_IF_NOT_OK((*dst)->Fill<uint8_t>(pad_val));
-    } else if (tensor_type == DataType::DE_INT16) {
-      RETURN_IF_NOT_OK((*dst)->Fill<int16_t>(pad_val));
-    } else if (tensor_type == DataType::DE_FLOAT16) {
-      RETURN_IF_NOT_OK((*dst)->Fill(static_cast<float16>(pad_val)));
-    } else if (tensor_type == DataType::DE_UINT16) {
-      RETURN_IF_NOT_OK((*dst)->Fill<uint16_t>(pad_val));
-    } else if (tensor_type == DataType::DE_INT32) {
-      RETURN_IF_NOT_OK((*dst)->Fill<int32_t>(pad_val));
-    } else if (tensor_type == DataType::DE_UINT32) {
-      RETURN_IF_NOT_OK((*dst)->Fill<uint32_t>(pad_val));
-    } else if (tensor_type == DataType::DE_INT64) {
-      RETURN_IF_NOT_OK((*dst)->Fill<int64_t>(pad_val));
-    } else if (tensor_type == DataType::DE_UINT64) {
-      RETURN_IF_NOT_OK((*dst)->Fill<uint64_t>(pad_val));
-    } else if (tensor_type == DataType::DE_FLOAT32) {
-      RETURN_IF_NOT_OK((*dst)->Fill<float>(pad_val));
-    } else if (tensor_type == DataType::DE_FLOAT64) {
-      RETURN_IF_NOT_OK((*dst)->Fill<double>(pad_val));
-    } else {
-      RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type");
-    }
-    std::vector<dsize_t> cur_ind(src->Rank(), 0);
-    RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0));
-  }
-  return Status::OK();
-}
-Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
-                           std::vector<dsize_t> cur_ind, size_t cur_dim) {
-  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
-    dst->CopyLastDimAt(src, cur_ind);
-  } else {  // not the last dimension, keep doing recursion
-    dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]);
-    for (dsize_t i = 0; i < min_ind; i++) {
-      cur_ind[cur_dim] = i;
-      RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1));
-    }
-  }
-  return Status::OK();
-}
-
-Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
-                    const std::vector<dsize_t> &pad_shape, const std::string &pad_val) {
-  CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr");
-  if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) {
-    (*dst) = src;  // if no padding, copy the pointer
-  } else {
-    CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed");
-    std::vector<dsize_t> cur_ind(src->Rank(), 0);
-    std::vector<std::string> strings;
-    RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val));
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape)));
-  }
-  return Status::OK();
-}
-
-Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
-                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
-                          const std::string &pad_value) {
-  if (cur_dim == src->Rank() - 1) {  // if this is the last dimension, copy the data
-    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
-    for (dsize_t i = 0; i < min_ind; i++) {
-      cur_ind[cur_dim] = i;
-      std::string_view item;
-      RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind));
-      dst->emplace_back(item);
-    }
-    for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) {
-      dst->emplace_back(pad_value);
-    }
-
-  } else {  // not the last dimension, keep doing recursion
-    dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]);
-    for (dsize_t i = 0; i < min_ind; i++) {
-      cur_ind[cur_dim] = i;
-      RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value));
-    }
-    dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim];
-    for (dsize_t i = 0; i < count; i++) {
-      dst->emplace_back(pad_value);
-    }
-  }
-  return Status::OK();
-}
-
-template <typename T>
-Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
-                  const std::shared_ptr<Tensor> &value_tensor, RelationalOp op) {
-  T value;
-  RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
-  auto in_itr = input->begin<T>();
-  auto out_itr = output->begin<bool>();
-  for (; in_itr != input->end<T>(); in_itr++, out_itr++) {
-    switch (op) {
-      case RelationalOp::kEqual:
-        *out_itr = (*in_itr == value);
-        break;
-      case RelationalOp::kNotEqual:
-        *out_itr = (*in_itr != value);
-        break;
-      case RelationalOp::kGreater:
-        *out_itr = (*in_itr > value);
-        break;
-      case RelationalOp::kGreaterEqual:
-        *out_itr = (*in_itr >= value);
-        break;
-      case RelationalOp::kLess:
-        *out_itr = (*in_itr < value);
-        break;
-      case RelationalOp::kLessEqual:
-        *out_itr = (*in_itr <= value);
-        break;
-      default:
-        RETURN_STATUS_UNEXPECTED("Unknown relational operator.");
-    }
-  }
-  return Status::OK();
-}
-
-Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
-            const std::shared_ptr<Tensor> &value, RelationalOp op) {
-  CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(),
-                               "Cannot convert constant value to the type of the input tensor.");
-  CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar");
-
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL)));
-
-  std::unique_ptr<TypeCastOp> value_cast_op(new TypeCastOp(input->type()));
-  std::shared_ptr<Tensor> casted_value;
-  if (input->type().IsNumeric()) {
-    RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value));
-  } else {
-    casted_value = value;
-  }
-
-  switch (input->type().value()) {
-    case DataType::DE_BOOL:
-      RETURN_IF_NOT_OK(MaskHelper<bool>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_INT8:
-      RETURN_IF_NOT_OK(MaskHelper<int8_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_UINT8:
-      RETURN_IF_NOT_OK(MaskHelper<uint8_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_UINT16:
-      RETURN_IF_NOT_OK(MaskHelper<uint16_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_INT16:
-      RETURN_IF_NOT_OK(MaskHelper<int16_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_UINT32:
-      RETURN_IF_NOT_OK(MaskHelper<uint32_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_INT32:
-      RETURN_IF_NOT_OK(MaskHelper<int32_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_UINT64:
-      RETURN_IF_NOT_OK(MaskHelper<uint64_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_INT64:
-      RETURN_IF_NOT_OK(MaskHelper<int64_t>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_FLOAT16:
-      RETURN_IF_NOT_OK(MaskHelper<float16>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_FLOAT32:
-      RETURN_IF_NOT_OK(MaskHelper<float>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_FLOAT64:
-      RETURN_IF_NOT_OK(MaskHelper<double>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_STRING:
-      RETURN_IF_NOT_OK(MaskHelper<std::string_view>(input, *output, casted_value, op));
-      break;
-    case DataType::DE_UNKNOWN:
-      RETURN_STATUS_UNEXPECTED("Unsupported input type.");
-      break;
-  }
-  return Status::OK();
-}
-
-Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
-                   std::shared_ptr<Tensor> append) {
-  CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported");
-  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported");
-
-  axis = Tensor::HandleNeg(axis, input[0]->shape().Rank());
-  CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported");
-
-  std::shared_ptr<Tensor> out;
-  if (prepend != nullptr) {
-    CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported");
-    RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0]));
-  } else {
-    out = input[0];
-  }
-  for (dsize_t i = 1; i < input.size(); i++) {
-    std::shared_ptr<Tensor> out_t;
-    CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported");
-    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i]));
-    out = out_t;
-  }
-  std::shared_ptr<Tensor> out_t;
-  if (append != nullptr) {
-    CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported");
-    RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append));
-  } else {
-    out_t = out;
-  }
-  output->push_back(out_t);
-
-  return Status::OK();
-}
-
-Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
-                         std::shared_ptr<Tensor> append) {
-  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match");
-
-  TensorShape t({});
-
-  for (dsize_t i = 0; i < input->shape().Rank(); i++) {
-    if (i != axis) {
-      t = t.AppendDim(input->shape()[i]);
-    } else {
-      dsize_t new_shape = input->shape()[i] + append->shape()[i];
-
-      t = t.AppendDim(new_shape);
-    }
-  }
-  std::shared_ptr<Tensor> out;
-
-  if (input->type().IsNumeric()) {
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type()));
-
-    RETURN_IF_NOT_OK(out->Concatenate({0}, input));
-    RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append));
-    *output = out;
-  } else {
-    std::vector<std::string> strings;
-
-    auto itr = input->begin<std::string_view>();
-    for (; itr != input->end<std::string_view>(); itr++) {
-      strings.emplace_back(*itr);
-    }
-    itr = append->begin<std::string_view>();
-    for (; itr != append->end<std::string_view>(); itr++) {
-      strings.emplace_back(*itr);
-    }
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t));
-
-    *output = out;
-  }
-
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
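For reference, the layout produced by OneHotEncoding above can be mirrored in a few standalone lines. This is an illustrative sketch only, not MindSpore API; OneHotSketch and its vector-of-vectors return type are hypothetical stand-ins for the Tensor machinery:

#include <cstdint>
#include <vector>

// Sketch: a 1-D input of N class indices becomes an N x num_classes 0/1 matrix.
// The real code walks rows with OneHotEncodingUnsigned/Signed and returns Status.
std::vector<std::vector<int>> OneHotSketch(const std::vector<int64_t> &labels, int64_t num_classes) {
  std::vector<std::vector<int>> out(labels.size(), std::vector<int>(num_classes, 0));
  for (size_t i = 0; i < labels.size(); ++i) {
    if (labels[i] < 0 || labels[i] >= num_classes) return {};  // mirrors "index values are not in range"
    out[i][labels[i]] = 1;
  }
  return out;
}

OneHotSketch({2, 0}, 3) yields {{0, 0, 1}, {1, 0, 0}}; a scalar input is treated as a single row, matching the Squeeze() calls above.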
diff --git a/mindspore/ccsrc/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/dataset/kernels/data/data_utils.h
deleted file mode 100644
index 6034e2a0eb..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/data_utils.h
+++ /dev/null
@@ -1,163 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_
-#define DATASET_KERNELS_DATA_DATA_UTILS_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include "dataset/core/constants.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/data_type.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/tensor_row.h"
-
-namespace mindspore {
-namespace dataset {
-// Returns one-hot encoding of the input tensor.
-// Example: if input=2 and numClasses=3, the output is [0 0 1].
-// @param input: Tensor has type DE_UINT64; the non-one-hot values are stored
-//               along the first dimension (rows).
-//               If the rank of input is not 1 or the type is not DE_UINT64,
-//               then it will fail.
-// @param output: Tensor. The shape of the output tensor is <N, num_classes>
-//                and the type is same as input.
-// @param num_classes: Number of classes.
-Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes);
-
-Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
-                              dsize_t num_classes, int64_t index);
-
-Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
-                            dsize_t num_classes, int64_t index);
-
-// Returns a tensor of shape input filled with the passed fill_value
-// @param input Tensor
-// @param output Tensor. The shape and type of the output tensor is same as input
-// @param fill_value Tensor. A scalar tensor used to fill the output tensor
-
-Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output,
-            std::shared_ptr<Tensor> fill_value);
-
-// Returns a type changed input tensor.
-// Example: if the input tensor is float64, the output will be the specified dataType. See DataTypes.cpp
-// @param input Tensor
-// @param output Tensor. The shape of the output tensor is same as input with the type changed.
-// @param data_type: type of data to cast data to
-// @note: this operation will do a memcpy and if the value is truncated then precision will be lost
-
-template <typename T>
-void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
-
-template <typename FROM, typename TO>
-void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
-
-Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
-
-Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type);
-
-// Pad the input tensor according to pad_shape; source and target must have the same rank.
-// Based on the type of the input tensor, PadEndNumeric/String will be called.
-// @param std::shared_ptr<Tensor> src - tensor to pad from
-// @param std::shared_ptr<Tensor> *dst - return tensor padded
-// @param std::vector<dsize_t> pad_shape - shape to pad to
-// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format
-// @return - The error code returned
-Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
-              const std::shared_ptr<Tensor> &pad_val);
-
-// Pad the input numeric tensor according to pad_shape; source and target must have the same rank.
-// @param std::shared_ptr<Tensor> src - tensor to pad from
-// @param std::shared_ptr<Tensor> *dst - return tensor padded
-// @param std::vector<dsize_t> pad_shape - shape to pad to
-// @param float pad_val - value to pad with
-// @return - The error code returned
-Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
-                     const std::vector<dsize_t> &pad_shape, float pad_val);
-
-// Recursive helper function for padding numeric tensors. This function could be very expensive if called on a
-// multi-dimensional tensor; it is only meant to be called by PadEndNumeric.
-// @tparam T - type of tensor and fill value
-// @param std::shared_ptr<Tensor> src - Tensor to pad from
-// @param std::shared_ptr<Tensor> *dst - Tensor to pad to, return value
-// @param std::vector<dsize_t> cur_ind - recursion helper
-// @param T pad_val - value to pad tensor with
-// @param size_t cur_dim - recursion helper
-// @return Status - The error code returned
-Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
-                           std::vector<dsize_t> cur_ind, size_t cur_dim = 0);
-
-// Pad the input string tensor according to pad_shape; source and target must have the same rank.
-// @param std::shared_ptr<Tensor> src - tensor to pad from
-// @param std::shared_ptr<Tensor> *dst - return tensor padded
-// @param std::vector<dsize_t> pad_shape - shape to pad to
-// @param std::string pad_val - value to pad with
-// @return - The error code returned
-Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
-                    const std::vector<dsize_t> &pad_shape, const std::string &pad_val);
-
-// Recursive helper function for padding string tensors. This function could be very expensive if called on a
-// multi-dimensional tensor; it is only meant to be called by PadEndString.
-// @tparam T - type of tensor and fill value
-// @param std::shared_ptr<Tensor> src - Tensor to pad from
-// @param std::shared_ptr<Tensor> *dst - Tensor to pad to, return value
-// @param std::vector<dsize_t> cur_ind - recursion helper
-// @param std::string pad_val - value to pad tensor with
-// @param size_t cur_dim - recursion helper
-// @return Status - The error code returned
-Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
-                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
-                          const std::string &pad_value);
-
-enum class RelationalOp {
-  kEqual = 0,     // ==
-  kNotEqual,      // !=
-  kLess,          // <
-  kLessEqual,     // <=
-  kGreater,       // >
-  kGreaterEqual,  // >=
-};
-
-/// Helper method that masks the input tensor
-/// @tparam T type of the tensor
-/// @param input[in] input tensor
-/// @param output[out] output tensor
-/// @param value_tensor[in] scalar tensor value to compare with
-/// @param op[in] RelationalOp enum
-/// @return Status ok/error
-template <typename T>
-Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
-                  const std::shared_ptr<Tensor> &value_tensor, RelationalOp op);
-
-/// Mask the input tensor
-/// @param input[in] input tensor
-/// @param output[out] output tensor
-/// @param value[in] scalar tensor value to compare with
-/// @param op[in] RelationalOp enum
-/// @return Status ok/error
-Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
-            const std::shared_ptr<Tensor> &value, RelationalOp op);
-
-Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
-                   std::shared_ptr<Tensor> append);
-
-// helper for concat, always append to the input, and pass that to the output
-Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
-                         std::shared_ptr<Tensor> append);
-
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_DATA_DATA_UTILS_H_
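The padding contract declared above is easiest to see in one dimension. An illustrative sketch, not the MindSpore API (PadEnd1D is a hypothetical name):

#include <cstddef>
#include <string>
#include <vector>

// Sketch: keep existing elements, fill the tail up to the target shape with
// pad_val; the real helpers recurse over all dimensions via cur_ind/cur_dim.
std::vector<std::string> PadEnd1D(std::vector<std::string> src, std::size_t target_len,
                                  const std::string &pad_val) {
  while (src.size() < target_len) src.push_back(pad_val);
  return src;
}

PadEnd1D({"a", "b"}, 4, "<pad>") returns {"a", "b", "<pad>", "<pad>"}.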
- */ - -#include "dataset/kernels/data/duplicate_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); - output->push_back(input[0]); - output->push_back(out); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h b/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h deleted file mode 100644 index 598aa3407d..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/duplicate_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_DUPLICATE_OP_H_ -#define DATASET_KERNELS_DATA_DUPLICATE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -class DuplicateOp : public TensorOp { - public: - DuplicateOp() = default; - - ~DuplicateOp() override = default; - - void Print(std::ostream &out) const override { out << "DuplicateOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - uint32_t NumOutput() override { return 2; } - - std::string Name() const override { return kDuplicateOp; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DUPLICATE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.cc b/mindspore/ccsrc/dataset/kernels/data/fill_op.cc deleted file mode 100644 index 63895d3a95..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/fill_op.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/kernels/data/fill_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status FillOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - Status s = Fill(input, output, fill_value_); - return s; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/dataset/kernels/data/fill_op.h deleted file mode 100644 index 5338dbd2b3..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/fill_op.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_DATA_FILL_OP_H_ -#define DATASET_KERNELS_DATA_FILL_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class FillOp : public TensorOp { - public: - explicit FillOp(std::shared_ptr value) : fill_value_(value) {} - - ~FillOp() override = default; - void Print(std::ostream &out) const override { out << "FillOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kFillOp; } - - private: - std::shared_ptr fill_value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDSPORE_FILL_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/dataset/kernels/data/mask_op.cc deleted file mode 100644 index 2cfeb7e36f..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/mask_op.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/data/mask_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { - -Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - std::shared_ptr temp_output; - CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. Type should be numeric."); - - RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_)); - - // cast the output to the the required type. Skip casting if type_ is bool. 
-  if (type_ != DataType::DE_BOOL) {
-    RETURN_IF_NOT_OK(cast_->Compute(temp_output, output));
-  } else {
-    *output = std::move(temp_output);
-  }
-
-  return Status::OK();
-}
-
-Status MaskOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
-  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
-  outputs[0] = type_;
-  return Status::OK();
-}
-
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/dataset/kernels/data/mask_op.h
deleted file mode 100644
index c610c43715..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/mask_op.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_DATA_MASK_OP_H_
-#define DATASET_KERNELS_DATA_MASK_OP_H_
-
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/kernels/data/type_cast_op.h"
-#include "dataset/kernels/data/data_utils.h"
-
-namespace mindspore {
-namespace dataset {
-
-class MaskOp : public TensorOp {
- public:
-  MaskOp(RelationalOp op, std::shared_ptr<Tensor> value, DataType type = DataType(DataType::DE_BOOL))
-      : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {}
-
-  ~MaskOp() override = default;
-
-  void Print(std::ostream &out) const override { out << "MaskOp"; }
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
-
-  std::string Name() const override { return kMaskOp; }
-
- private:
-  RelationalOp op_;
-  std::shared_ptr<Tensor> value_;
-  DataType type_;
-  std::unique_ptr<TypeCastOp> cast_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_KERNELS_DATA_MASK_OP_H_
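MaskOp above reduces to an elementwise comparison against a scalar, followed by an optional cast of the bool result to type_. A standalone sketch for the kGreater case over ints (illustrative only; MaskGreaterSketch is a hypothetical name):

#include <vector>

std::vector<bool> MaskGreaterSketch(const std::vector<int> &in, int value) {
  std::vector<bool> out;
  out.reserve(in.size());
  for (int v : in) out.push_back(v > value);  // the kGreater branch of MaskHelper
  return out;
}

MaskGreaterSketch({1, 5, 3}, 2) returns {false, true, true}.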
- */ -#include "dataset/kernels/data/one_hot_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status OneHotOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - Status s = OneHotEncoding(input, output, num_classes_); - return s; -} - -Status OneHotOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - std::vector inputs_copy; - inputs_copy.push_back(inputs[0].Squeeze()); - if (inputs_copy[0].Rank() == 0) outputs.emplace_back(std::vector{num_classes_}); - if (inputs_copy[0].Rank() == 1) outputs.emplace_back(std::vector{inputs_copy[0][0], num_classes_}); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h b/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h deleted file mode 100644 index 6c789aa10e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/one_hot_op.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_ONE_HOT_OP_H_ -#define DATASET_KERNELS_DATA_ONE_HOT_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class OneHotOp : public TensorOp { - public: - explicit OneHotOp(int num_classes) : num_classes_(num_classes) {} - - ~OneHotOp() override = default; - - void Print(std::ostream &out) const override { out << "OneHotOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kOneHotOp; } - - private: - int num_classes_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc b/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc deleted file mode 100644 index 5b3b4cbe16..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc b/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc
deleted file mode 100644
index 5b3b4cbe16..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/kernels/data/pad_end_op.h"
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/data/data_utils.h"
-#include "dataset/kernels/tensor_op.h"
-
-namespace mindspore {
-namespace dataset {
-Status PadEndOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-  Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_);
-  return s;
-}
-
-Status PadEndOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
-  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
-  outputs.clear();
-  for (auto s : inputs) {
-    outputs.emplace_back(TensorShape(output_shape_.AsVector()));
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape");
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h b/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h
deleted file mode 100644
index eeb4ce4695..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/pad_end_op.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_
-#define DATASET_KERNELS_DATA_PAD_END_OP_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-
-namespace mindspore {
-namespace dataset {
-class PadEndOp : public TensorOp {
- public:
-  explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
-      : output_shape_(pad_shape), pad_val_(pad_value) {}
-
-  ~PadEndOp() override = default;
-
-  void Print(std::ostream &out) const override { out << "PadEndOp"; }
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
-
-  std::string Name() const override { return kPadEndOp; }
-
- private:
-  TensorShape output_shape_;
-  std::shared_ptr<Tensor> pad_val_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_KERNELS_DATA_PAD_END_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc
deleted file mode 100644
index 2eebf26e84..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/kernels/data/slice_op.h"
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-
-namespace mindspore {
-namespace dataset {
-Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now.");
-
-  // if the `all` flag is true, the output is just the input.
-  if (all_) {
-    *output = input;
-    return Status::OK();
-  }
-
-  // if a slice object was provided, indices_ should be empty; generate indices from the slice object.
-  if (slice_.valid() && indices_.empty()) {
-    dsize_t len = input->shape()[0];
-    std::vector<dsize_t> indices = slice_.Indices(len);
-    return input->Slice(output, indices);
-  }
-
-  // if indices_ is not empty, the slice object should be invalid; use indices_ to slice
-  if (!indices_.empty() && !slice_.valid()) {
-    return input->Slice(output, indices_);
-  }
-  RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid");
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/dataset/kernels/data/slice_op.h
deleted file mode 100644
index b180c9d0a9..0000000000
--- a/mindspore/ccsrc/dataset/kernels/data/slice_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_
-#define DATASET_KERNELS_DATA_SLICE_OP_H_
-
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-
-namespace mindspore {
-namespace dataset {
-class Slice {
- public:
-  Slice() : start_(0), stop_(0), step_(0) {}
-  Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
-  Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
-  explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
-
-  ~Slice() = default;
-
-  std::vector<dsize_t> Indices(dsize_t length) {
-    std::vector<dsize_t> indices;
-    dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
-    dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length);
-    if (step_ > 0) {
-      for (; index < end_index; index += step_) {
-        indices.push_back(index);
-      }
-    } else {
-      for (; index > end_index; index += step_) {
-        indices.push_back(index);
-      }
-    }
-    return indices;
-  }
-
-  bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); }
-
-  dsize_t start_;
-  dsize_t stop_;
-  dsize_t step_;
-};
-
-class SliceOp : public TensorOp {
- public:
-  explicit SliceOp(std::vector<dsize_t> indices) : indices_(std::move(indices)) {}
-  explicit SliceOp(Slice slice) : slice_(slice) {}
-  explicit SliceOp(bool all) : all_(all) {}
-
-  ~SliceOp() override = default;
-
-  void Print(std::ostream &out) const override { out << "SliceOp"; }
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  std::string Name() const override { return kSliceOp; }
-
- private:
-  // only one of the following will be valid
-  // given indices to slice the Tensor. Empty vector if invalid.
-  std::vector<dsize_t> indices_;
-  // Slice object. All start, stop and step are 0 if invalid.
-  Slice slice_;
-  // Flag to read all indices in the dim.
-  bool all_ = false;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_KERNELS_DATA_SLICE_OP_H_
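Slice::Indices above implements Python-style slicing. The same arithmetic in standalone form, assuming HandleNeg wraps a negative position once by adding length (sketch only, not the Tensor API; the names are hypothetical):

#include <algorithm>
#include <cstdint>
#include <vector>

int64_t HandleNegSketch(int64_t idx, int64_t length) { return idx < 0 ? idx + length : idx; }

std::vector<int64_t> IndicesSketch(int64_t start, int64_t stop, int64_t step, int64_t length) {
  std::vector<int64_t> out;
  int64_t index = std::min(HandleNegSketch(start, length), length);
  int64_t end_index = std::min(HandleNegSketch(stop, length), length);
  if (step > 0) {
    for (; index < end_index; index += step) out.push_back(index);
  } else {
    for (; index > end_index; index += step) out.push_back(index);
  }
  return out;
}

IndicesSketch(1, 6, 2, 5) returns {1, 3}; IndicesSketch(-2, 5, 1, 5) returns {3, 4}.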
- */ -#include "dataset/kernels/data/to_float16_op.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -Status ToFloat16Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - return ToFloat16(input, output); -} -Status ToFloat16Op::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_FLOAT16); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h b/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h deleted file mode 100644 index b4aa84d10e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/to_float16_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDDATA_TOFLOAT16OP_H -#define MINDDATA_TOFLOAT16OP_H - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class ToFloat16Op : public TensorOp { - public: - ToFloat16Op() = default; - - ~ToFloat16Op() override = default; - - // Overrides the base class compute function - // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor - // and transforms its data to float16, the output memory is manipulated to contain the result - // @return Status - The error code return - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - void Print(std::ostream &out) const override { out << "ToFloat16Op"; } - - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kToFloat16Op; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // MINDDATA_TOFLOAT16OP_H diff --git a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc b/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc deleted file mode 100644 index 74c84a668a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/kernels/data/type_cast_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -TypeCastOp::TypeCastOp(const DataType &new_type) : type_(new_type) {} - -TypeCastOp::TypeCastOp(const std::string &data_type) { type_ = DataType(data_type); } - -Status TypeCastOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - return TypeCast(input, output, type_); -} -Status TypeCastOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = type_; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h b/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h deleted file mode 100644 index 82fc4bea35..0000000000 --- a/mindspore/ccsrc/dataset/kernels/data/type_cast_op.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ -#define DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class TypeCastOp : public TensorOp { - public: - // Constructor for TypecastOp - // @param data_type datatype to cast to - explicit TypeCastOp(const DataType &data_type); - - // Constructor for TypecastOp - // @param data_type datatype to cast to - explicit TypeCastOp(const std::string &data_type); - - ~TypeCastOp() override = default; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - void Print(std::ostream &out) const override { out << "TypeCastOp"; } - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kTypeCastOp; } - - private: - DataType type_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_DATA_TYPE_CAST_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc deleted file mode 100644 index 8f738b6e78..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc
deleted file mode 100644
index 8f738b6e78..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-#include "dataset/kernels/image/bounding_box_augment_op.h"
-#include "dataset/kernels/image/resize_op.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/core/cv_tensor.h"
-
-namespace mindspore {
-namespace dataset {
-const float BoundingBoxAugmentOp::kDefRatio = 0.3;
-
-BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio)
-    : ratio_(ratio), uniform_(0, 1), transform_(std::move(transform)) {
-  rnd_.seed(GetSeed());
-}
-
-Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) {
-  IO_CHECK_VECTOR(input, output);
-  BOUNDING_BOX_CHECK(input);  // check if bounding boxes are valid
-  uint32_t num_of_boxes = input[1]->shape()[0];
-  std::shared_ptr<Tensor> crop_out;
-  std::shared_ptr<Tensor> res_out;
-  std::shared_ptr<CVTensor> input_restore = CVTensor::AsCVTensor(input[0]);
-  for (uint32_t i = 0; i < num_of_boxes; i++) {
-    // using a uniform distribution to ensure op happens with probability ratio_
-    if (uniform_(rnd_) < ratio_) {
-      float min_x = 0;
-      float min_y = 0;
-      float b_w = 0;
-      float b_h = 0;
-      // get the required items
-      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&min_x, {i, 0}));
-      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&min_y, {i, 1}));
-      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&b_w, {i, 2}));
-      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&b_h, {i, 3}));
-      RETURN_IF_NOT_OK(Crop(input_restore, &crop_out, static_cast<int>(min_x), static_cast<int>(min_y),
-                            static_cast<int>(b_w), static_cast<int>(b_h)));
-      // transform the cropped bbox region
-      RETURN_IF_NOT_OK(transform_->Compute(crop_out, &res_out));
-      // place the transformed region back in the restored input
-      std::shared_ptr<CVTensor> res_img = CVTensor::AsCVTensor(res_out);
-      // check if transformed crop is out of bounds of the box
-      if (res_img->mat().cols > b_w || res_img->mat().rows > b_h || res_img->mat().cols < b_w ||
-          res_img->mat().rows < b_h) {
-        // if so, resize to fit in the box
-        std::shared_ptr<TensorOp> resize_op =
-          std::make_shared<ResizeOp>(static_cast<int32_t>(b_h), static_cast<int32_t>(b_w));
-        RETURN_IF_NOT_OK(resize_op->Compute(std::static_pointer_cast<Tensor>(res_img), &res_out));
-        res_img = CVTensor::AsCVTensor(res_out);
-      }
-      res_img->mat().copyTo(input_restore->mat()(cv::Rect(min_x, min_y, res_img->mat().cols, res_img->mat().rows)));
-    }
-  }
-  (*output).push_back(std::move(std::static_pointer_cast<Tensor>(input_restore)));
-  (*output).push_back(input[1]);
-  return Status::OK();
-}
-
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h
deleted file mode 100644
index 9b1d2d18dd..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
diff --git a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h
deleted file mode 100644
index 9b1d2d18dd..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/bounding_box_augment_op.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
-#define DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
-
-#include <memory>
-#include <random>
-#include <string>
-#include <utility>
-#include <vector>
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-#include "dataset/util/random.h"
-
-namespace mindspore {
-namespace dataset {
-class BoundingBoxAugmentOp : public TensorOp {
- public:
-  // Default values, also used by python_bindings.cc
-  static const float kDefRatio;
-
-  // Constructor for BoundingBoxAugmentOp
-  // @param std::shared_ptr<TensorOp> transform: C++ operation to apply on selected bounding boxes
-  // @param float ratio: ratio of bounding boxes to have the transform applied on
-  BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio);
-
-  ~BoundingBoxAugmentOp() override = default;
-
-  // Provide stream operator for displaying it
-  friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) {
-    so.Print(out);
-    return out;
-  }
-
-  void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; }
-
-  Status Compute(const TensorRow &input, TensorRow *output) override;
-
-  std::string Name() const override { return kBoundingBoxAugmentOp; }
-
- private:
-  float ratio_;
-  std::mt19937 rnd_;
-  std::uniform_real_distribution<float> uniform_;
-  std::shared_ptr<TensorOp> transform_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
"crop size needs to be positive integers\t" : ""; - - if (err_msg.length() != 0) RETURN_STATUS_UNEXPECTED(common::SafeCStr(err_msg)); - - int32_t top = crop_het_ - input->shape()[0]; // number of pixels to pad (top and bottom) - int32_t left = crop_wid_ - input->shape()[1]; - std::shared_ptr pad_image; - if (top > 0 && left > 0) { // padding only - return Pad(input, output, top / 2 + top % 2, top / 2, left / 2 + left % 2, left / 2, BorderType::kConstant); - } else if (top > 0) { - RETURN_IF_NOT_OK(Pad(input, &pad_image, top / 2 + top % 2, top / 2, 0, 0, BorderType::kConstant)); - return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, - (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); - } else if (left > 0) { - RETURN_IF_NOT_OK(Pad(input, &pad_image, 0, 0, left / 2 + left % 2, left / 2, BorderType::kConstant)); - return Crop(pad_image, output, (static_cast(pad_image->shape()[1]) - crop_wid_) / 2, - (static_cast(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_); - } - return Crop(input, output, (input->shape()[1] - crop_wid_) / 2, (input->shape()[0] - crop_het_) / 2, crop_wid_, - crop_het_); -} - -void CenterCropOp::Print(std::ostream &out) const { - out << "CenterCropOp: " - << "cropWidth: " << crop_wid_ << "cropHeight: " << crop_het_ << "\n"; -} -Status CenterCropOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{crop_het_, crop_wid_}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h deleted file mode 100644 index 87164fe816..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class CenterCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - - explicit CenterCropOp(int32_t het, int32_t wid = kDefWidth) : crop_het_(het), crop_wid_(wid == 0 ? 
diff --git a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h
deleted file mode 100644
index 87164fe816..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/center_crop_op.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
-#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class CenterCropOp : public TensorOp {
- public:
-  // Default values, also used by python_bindings.cc
-  static const int32_t kDefWidth;
-
-  explicit CenterCropOp(int32_t het, int32_t wid = kDefWidth) : crop_het_(het), crop_wid_(wid == 0 ? het : wid) {}
-
-  ~CenterCropOp() override = default;
-
-  void Print(std::ostream &out) const override;
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
-
-  std::string Name() const override { return kCenterCropOp; }
-
- private:
-  int32_t crop_het_;
-  int32_t crop_wid_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc
deleted file mode 100644
index 74d9df5d6b..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
-
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
-
- * http://www.apache.org/licenses/LICENSE-2.0
-
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
-*/
-#include "dataset/kernels/image/cut_out_op.h"
-
-#include <random>
-
-#include "dataset/core/config_manager.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/random.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-const bool CutOutOp::kDefRandomColor = false;
-const uint8_t CutOutOp::kDefFillR = 0;
-const uint8_t CutOutOp::kDefFillG = 0;
-const uint8_t CutOutOp::kDefFillB = 0;
-
-// constructor
-CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r,
-                   uint8_t fill_g, uint8_t fill_b)
-    : rnd_(GetSeed()),
-      box_height_(box_height),
-      box_width_(box_width),
-      num_patches_(num_patches),
-      random_color_(random_color),
-      fill_r_(fill_r),
-      fill_g_(fill_g),
-      fill_b_(fill_b) {}
-
-// main function call for cut out
-Status CutOutOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-  std::shared_ptr<CVTensor> inputCV = CVTensor::AsCVTensor(input);
-  // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black
-  RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_,
-                         fill_g_, fill_b_));
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h
deleted file mode 100644
index 5c46e5f013..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
-
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
-
- * http://www.apache.org/licenses/LICENSE-2.0
-
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
-*/
-#ifndef DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
-#define DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
-
-#include <memory>
-#include <random>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class CutOutOp : public TensorOp {
- public:
-  // Default values, also used by python_bindings.cc
-  static const bool kDefRandomColor;
-  static const uint8_t kDefFillR;
-  static const uint8_t kDefFillG;
-  static const uint8_t kDefFillB;
-
-  // Constructor for CutOutOp
-  // @param box_height box height
-  // @param box_width box_width
-  // @param num_patches how many patches to erase from image
-  // @param random_color boolean value to indicate fill patch with random color
-  // @param fill_r R value for the color to fill patch with
-  // @param fill_g G value for the color to fill patch with
-  // @param fill_b B value for the color to fill patch with
-  // @note maybe using unsigned long int isn't the best here according to our coding rules
-  CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color = kDefRandomColor,
-           uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB);
-
-  ~CutOutOp() override = default;
-
-  void Print(std::ostream &out) const override {
-    out << "CutOut:: box_height: " << box_height_ << " box_width: " << box_width_ << " num_patches: " << num_patches_;
-  }
-
-  // Overrides the base class compute function
-  // Calls the erase function in ImageUtils, this function takes an input tensor
-  // and overwrites some of its data using openCV, the output memory is manipulated to contain the result
-  // @return Status - The error code return
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  std::string Name() const override { return kCutOutOp; }
-
- private:
-  std::mt19937 rnd_;
-  int32_t box_height_;
-  int32_t box_width_;
-  int32_t num_patches_;
-  bool random_color_;
-  uint8_t fill_r_;
-  uint8_t fill_g_;
-  uint8_t fill_b_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
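A usage sketch relying only on the constructor defaults listed above; the helper name `ApplyCutOut` is hypothetical:

```cpp
// Sketch only: cut two 50x50 black patches out of an HWC uint8 image.
#include <memory>

#include "dataset/kernels/image/cut_out_op.h"

using mindspore::dataset::CutOutOp;
using mindspore::dataset::Status;
using mindspore::dataset::Tensor;

Status ApplyCutOut(const std::shared_ptr<Tensor> &img, std::shared_ptr<Tensor> *out) {
  CutOutOp cut_out(50, 50, /*num_patches=*/2);  // defaults: fixed fill, RGB(0, 0, 0)
  return cut_out.Compute(img, out);             // patches are clipped at image borders
}
```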
- */ -#include "dataset/kernels/image/decode_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const bool DecodeOp::kDefRgbFormat = true; - -DecodeOp::DecodeOp(bool is_rgb_format) : is_rgb_format_(is_rgb_format) { - if (is_rgb_format_) { // RGB colour mode - MS_LOG(DEBUG) << "Decode colour mode is RGB."; - } else { - MS_LOG(DEBUG) << "Decode colour mode is BGR."; - } -} - -Status DecodeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (is_rgb_format_) { // RGB colour mode - return Decode(input, output); - } else { // BGR colour mode - RETURN_STATUS_UNEXPECTED("Decode BGR is deprecated"); - } -} -Status DecodeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels - if (inputs[0].Rank() == 1) outputs.emplace_back(out); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} - -Status DecodeOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_UINT8); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h deleted file mode 100644 index f55baf62b4..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h
deleted file mode 100644
index f55baf62b4..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_IMAGE_DECODE_OP_H_
-#define DATASET_KERNELS_IMAGE_DECODE_OP_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class DecodeOp : public TensorOp {
- public:
-  // Default values, also used by python_bindings.cc
-  static const bool kDefRgbFormat;
-
-  explicit DecodeOp(bool is_rgb_format = true);
-
-  ~DecodeOp() = default;
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  void Print(std::ostream &out) const override { out << "DecodeOp"; }
-  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
-  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
-
-  std::string Name() const override { return kDecodeOp; }
-
- private:
-  bool is_rgb_format_ = true;
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_IMAGE_DECODE_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc b/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc
deleted file mode 100644
index 8ed2229cd1..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.cc
+++ /dev/null
@@ -1,39 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/kernels/image/hwc_to_chw_op.h"
-
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-Status HwcToChwOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-  // input.shape == HWC
-  // output.shape == CHW
-  return HwcToChw(input, output);
-}
-Status HwcToChwOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
-  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
-  outputs.clear();
-  TensorShape in = inputs[0];
-  TensorShape out = TensorShape{in[2], in[0], in[1]};
-  if (inputs[0].Rank() == 3) outputs.emplace_back(out);
-  if (!outputs.empty()) return Status::OK();
-  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h b/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h
deleted file mode 100644
index 5e1d442148..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/hwc_to_chw_op.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
-#define DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class HwcToChwOp : public TensorOp {
- public:
-  void Print(std::ostream &out) const override { out << "HwcToChw"; }
-
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
-
-  std::string Name() const override { return kHwcToChwOp; }
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
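The layout change performed by `HwcToChwOp`, shown with concrete numbers; this mirrors the index shuffle in `OutputShape` above and is illustrative only:

```cpp
// HwcToChwOp turns a packed HWC image into planar CHW, one channel plane at a time.
#include "dataset/core/tensor_shape.h"

using mindspore::dataset::TensorShape;

// For a 224x224 RGB input: {224, 224, 3} (H, W, C) -> {3, 224, 224} (C, H, W).
TensorShape in{224, 224, 3};
TensorShape out{in[2], in[0], in[1]};  // same index shuffle as OutputShape() above
```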
diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc
deleted file mode 100644
index 5bf7b6ba8e..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc
+++ /dev/null
@@ -1,836 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/kernels/image/image_utils.h"
-#include <algorithm>
-#include <random>
-#include <stdexcept>
-#include <utility>
-#include <vector>
-#include <opencv2/imgproc/types_c.h>
-#include "common/utils.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/tensor_shape.h"
-#include "dataset/util/random.h"
-
-#define MAX_INT_PRECISION 16777216  // float int precision is 16777216
-namespace mindspore {
-namespace dataset {
-int GetCVInterpolationMode(InterpolationMode mode) {
-  switch (mode) {
-    case InterpolationMode::kLinear:
-      return static_cast<int>(cv::InterpolationFlags::INTER_LINEAR);
-    case InterpolationMode::kCubic:
-      return static_cast<int>(cv::InterpolationFlags::INTER_CUBIC);
-    case InterpolationMode::kArea:
-      return static_cast<int>(cv::InterpolationFlags::INTER_AREA);
-    case InterpolationMode::kNearestNeighbour:
-      return static_cast<int>(cv::InterpolationFlags::INTER_NEAREST);
-    default:
-      return static_cast<int>(cv::InterpolationFlags::INTER_LINEAR);
-  }
-}
-
-int GetCVBorderType(BorderType type) {
-  switch (type) {
-    case BorderType::kConstant:
-      return static_cast<int>(cv::BorderTypes::BORDER_CONSTANT);
-    case BorderType::kEdge:
-      return static_cast<int>(cv::BorderTypes::BORDER_REPLICATE);
-    case BorderType::kReflect:
-      return static_cast<int>(cv::BorderTypes::BORDER_REFLECT101);
-    case BorderType::kSymmetric:
-      return static_cast<int>(cv::BorderTypes::BORDER_REFLECT);
-    default:
-      return static_cast<int>(cv::BorderTypes::BORDER_CONSTANT);
-  }
-}
-
-Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code) {
-  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
-
-  std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
-  RETURN_UNEXPECTED_IF_NULL(output_cv);
-  RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes()));
-
-  if
(input_cv->mat().data) { - try { - cv::flip(input_cv->mat(), output_cv->mat(), flip_code); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in flip op."); - } - } else { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null"); - } -} - -Status HorizontalFlip(std::shared_ptr input, std::shared_ptr *output) { - return Flip(std::move(input), output, 1); -} - -Status VerticalFlip(std::shared_ptr input, std::shared_ptr *output) { - return Flip(std::move(input), output, 0); -} - -Status Resize(const std::shared_ptr &input, std::shared_ptr *output, int32_t output_height, - int32_t output_width, double fx, double fy, InterpolationMode mode) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of or "); - } - cv::Mat in_image = input_cv->mat(); - // resize image too large or too small - if (output_height == 0 || output_height > in_image.rows * 1000 || output_width == 0 || - output_width > in_image.cols * 1000) { - std::string err_msg = - "The resizing width or height 1) is too big, it's up to " - "1000 times the original image; 2) can not be 0."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - try { - TensorShape shape{output_height, output_width}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - auto cv_mode = GetCVInterpolationMode(mode); - cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image resize."); - } -} - -bool IsNonEmptyJPEG(const std::shared_ptr &input) { - const unsigned char *kJpegMagic = (unsigned char *)"\xFF\xD8\xFF"; - constexpr size_t kJpegMagicLen = 3; - return input->SizeInBytes() > kJpegMagicLen && memcmp(input->GetBuffer(), kJpegMagic, kJpegMagicLen) == 0; -} - -Status Decode(const std::shared_ptr &input, std::shared_ptr *output) { - if (IsNonEmptyJPEG(input)) { - return JpegCropAndDecode(input, output); - } else { - return DecodeCv(input, output); - } -} - -Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - try { - cv::Mat img_mat = cv::imdecode(input_cv->mat(), cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION); - if (img_mat.data == nullptr) { - std::string err = "Error in decoding\t"; - RETURN_STATUS_UNEXPECTED(err); - } - cv::cvtColor(img_mat, img_mat, static_cast(cv::COLOR_BGR2RGB)); - std::shared_ptr output_cv = std::make_shared(img_mat); - RETURN_UNEXPECTED_IF_NULL(output_cv); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image Decode"); - } -} - -static void JpegInitSource(j_decompress_ptr cinfo) {} - -static boolean JpegFillInputBuffer(j_decompress_ptr cinfo) { - if (cinfo->src->bytes_in_buffer == 0) { - ERREXIT(cinfo, JERR_INPUT_EMPTY); - return FALSE; - } - return TRUE; -} - -static 
void JpegTermSource(j_decompress_ptr cinfo) {} - -static void JpegSkipInputData(j_decompress_ptr cinfo, int64_t jump) { - if (jump < 0) { - return; - } - if (static_cast(jump) > cinfo->src->bytes_in_buffer) { - cinfo->src->bytes_in_buffer = 0; - return; - } else { - cinfo->src->bytes_in_buffer -= jump; - cinfo->src->next_input_byte += jump; - } -} - -void JpegSetSource(j_decompress_ptr cinfo, const void *data, int64_t datasize) { - cinfo->src = static_cast( - (*cinfo->mem->alloc_small)(reinterpret_cast(cinfo), JPOOL_PERMANENT, sizeof(struct jpeg_source_mgr))); - cinfo->src->init_source = JpegInitSource; - cinfo->src->fill_input_buffer = JpegFillInputBuffer; -#if defined(_WIN32) || defined(_WIN64) - cinfo->src->skip_input_data = reinterpret_cast(JpegSkipInputData); -#else - cinfo->src->skip_input_data = JpegSkipInputData; -#endif - cinfo->src->resync_to_restart = jpeg_resync_to_restart; - cinfo->src->term_source = JpegTermSource; - cinfo->src->bytes_in_buffer = datasize; - cinfo->src->next_input_byte = static_cast(data); -} - -static Status JpegReadScanlines(jpeg_decompress_struct *const cinfo, int max_scanlines_to_read, JSAMPLE *buffer, - int buffer_size, int crop_w, int crop_w_aligned, int offset, int stride) { - // scanlines will be read to this buffer first, must have the number - // of components equal to the number of components in the image - int64_t scanline_size = crop_w_aligned * cinfo->output_components; - std::vector scanline(scanline_size); - JSAMPLE *scanline_ptr = &scanline[0]; - while (cinfo->output_scanline < static_cast(max_scanlines_to_read)) { - int num_lines_read = jpeg_read_scanlines(cinfo, &scanline_ptr, 1); - if (cinfo->out_color_space == JCS_CMYK && num_lines_read > 0) { - for (int i = 0; i < crop_w; ++i) { - int cmyk_pixel = 4 * i + offset; - const int c = scanline_ptr[cmyk_pixel]; - const int m = scanline_ptr[cmyk_pixel + 1]; - const int y = scanline_ptr[cmyk_pixel + 2]; - const int k = scanline_ptr[cmyk_pixel + 3]; - int r, g, b; - if (cinfo->saw_Adobe_marker) { - r = (k * c) / 255; - g = (k * m) / 255; - b = (k * y) / 255; - } else { - r = (255 - c) * (255 - k) / 255; - g = (255 - m) * (255 - k) / 255; - b = (255 - y) * (255 - k) / 255; - } - buffer[3 * i + 0] = r; - buffer[3 * i + 1] = g; - buffer[3 * i + 2] = b; - } - } else if (num_lines_read > 0) { - int copy_status = memcpy_s(buffer, buffer_size, scanline_ptr + offset, stride); - if (copy_status != 0) { - jpeg_destroy_decompress(cinfo); - RETURN_STATUS_UNEXPECTED("memcpy failed"); - } - } else { - jpeg_destroy_decompress(cinfo); - std::string err_msg = "failed to read scanline"; - RETURN_STATUS_UNEXPECTED(err_msg); - } - buffer += stride; - buffer_size = buffer_size - stride; - } - return Status::OK(); -} - -static Status JpegSetColorSpace(jpeg_decompress_struct *cinfo) { - switch (cinfo->num_components) { - case 1: - // we want to output 3 components if it's grayscale - cinfo->out_color_space = JCS_RGB; - return Status::OK(); - case 3: - cinfo->out_color_space = JCS_RGB; - return Status::OK(); - case 4: - // Need to manually convert to RGB - cinfo->out_color_space = JCS_CMYK; - return Status::OK(); - default: - jpeg_destroy_decompress(cinfo); - std::string err_msg = "wrong number of components"; - RETURN_STATUS_UNEXPECTED(err_msg); - } -} - -void JpegErrorExitCustom(j_common_ptr cinfo) { - char jpeg_last_error_msg[JMSG_LENGTH_MAX]; - (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg); - throw std::runtime_error(jpeg_last_error_msg); -} - -Status JpegCropAndDecode(const std::shared_ptr &input, 
std::shared_ptr *output, int crop_x, int crop_y, - int crop_w, int crop_h) { - struct jpeg_decompress_struct cinfo; - auto DestroyDecompressAndReturnError = [&cinfo](const std::string &err) { - jpeg_destroy_decompress(&cinfo); - RETURN_STATUS_UNEXPECTED(err); - }; - struct JpegErrorManagerCustom jerr; - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = JpegErrorExitCustom; - try { - jpeg_create_decompress(&cinfo); - JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); - (void)jpeg_read_header(&cinfo, TRUE); - RETURN_IF_NOT_OK(JpegSetColorSpace(&cinfo)); - jpeg_calc_output_dimensions(&cinfo); - } catch (std::runtime_error &e) { - return DestroyDecompressAndReturnError(e.what()); - } - if (crop_x == 0 && crop_y == 0 && crop_w == 0 && crop_h == 0) { - crop_w = cinfo.output_width; - crop_h = cinfo.output_height; - } else if (crop_w == 0 || static_cast(crop_w + crop_x) > cinfo.output_width || crop_h == 0 || - static_cast(crop_h + crop_y) > cinfo.output_height) { - return DestroyDecompressAndReturnError("Crop window is not valid"); - } - const int mcu_size = cinfo.min_DCT_scaled_size; - unsigned int crop_x_aligned = (crop_x / mcu_size) * mcu_size; - unsigned int crop_w_aligned = crop_w + crop_x - crop_x_aligned; - try { - (void)jpeg_start_decompress(&cinfo); - jpeg_crop_scanline(&cinfo, &crop_x_aligned, &crop_w_aligned); - } catch (std::runtime_error &e) { - return DestroyDecompressAndReturnError(e.what()); - } - JDIMENSION skipped_scanlines = jpeg_skip_scanlines(&cinfo, crop_y); - // three number of output components, always convert to RGB and output - constexpr int kOutNumComponents = 3; - TensorShape ts = TensorShape({crop_h, crop_w, kOutNumComponents}); - auto output_tensor = std::make_shared(ts, DataType(DataType::DE_UINT8)); - const int buffer_size = output_tensor->SizeInBytes(); - JSAMPLE *buffer = reinterpret_cast(&(*output_tensor->begin())); - const int max_scanlines_to_read = skipped_scanlines + crop_h; - // stride refers to output tensor, which has 3 components at most - const int stride = crop_w * kOutNumComponents; - // offset is calculated for scanlines read from the image, therefore - // has the same number of components as the image - const int offset = (crop_x - crop_x_aligned) * cinfo.output_components; - RETURN_IF_NOT_OK( - JpegReadScanlines(&cinfo, max_scanlines_to_read, buffer, buffer_size, crop_w, crop_w_aligned, offset, stride)); - *output = output_tensor; - jpeg_destroy_decompress(&cinfo); - return Status::OK(); -} - -Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat input_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); - try { - input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift); - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image rescale"); - } - return Status::OK(); -} - -Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Shape not or "); - } - try { - 
TensorShape shape{h, w}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr output_cv = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Rect roi(x, y, w, h); - (input_cv->mat())(roi).copyTo(output_cv->mat()); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in crop."); - } -} - -Status HwcToChw(std::shared_ptr input, std::shared_ptr *output) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() == 2) { - // If input tensor is 2D, we assume we have hw dimensions - *output = input; - return Status::OK(); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->shape().Size() < 2 || input_cv->shape().Size() > 3 || - (input_cv->shape().Size() == 3 && num_channels != 3 && num_channels != 1)) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3 nor 1"); - } - cv::Mat output_img; - - int height = input_cv->shape()[0]; - int width = input_cv->shape()[1]; - - auto output_cv = std::make_unique(TensorShape{num_channels, height, width}, input_cv->type()); - for (int i = 0; i < num_channels; ++i) { - cv::Mat mat; - RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat)); - cv::extractChannel(input_cv->mat(), mat, i); - } - *output = std::move(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap."); - } -} - -Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); - int num_channels = input_cv->shape()[2]; - if (input_cv->shape().Size() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_BGR2RGB)); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode."); - } -} - -Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, - int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { - RETURN_STATUS_UNEXPECTED("Shape not or "); - } - // image too large or too small - if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 || - target_width == 0 || target_height > crop_width * 1000) { - std::string err_msg = - "The resizing width or height 1) is too big, it's up to " - "1000 times the original image; 2) can not be 0."; - RETURN_STATUS_UNEXPECTED(err_msg); - } - cv::Rect roi(x, y, crop_width, crop_height); - auto cv_mode = GetCVInterpolationMode(mode); - cv::Mat cv_in = input_cv->mat(); - TensorShape shape{target_height, target_width}; - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels); - std::shared_ptr 
cvt_out = std::make_shared(shape, input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(cvt_out); - cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode); - *output = std::static_pointer_cast(cvt_out); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize."); - } -} - -Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, - InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat input_img = input_cv->mat(); - if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) { - RETURN_STATUS_UNEXPECTED("Image too large center not precise"); - } - // default to center of image - if (fx == -1 && fy == -1) { - fx = (input_img.cols - 1) / 2.0; - fy = (input_img.rows - 1) / 2.0; - } - cv::Mat output_img; - cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); - // maybe don't use uint32 for image dimension here - cv::Point2f pc(fx, fy); - cv::Mat rot = cv::getRotationMatrix2D(pc, degree, 1.0); - std::shared_ptr output_cv; - if (!expand) { - // this case means that the shape doesn't change, size stays the same - // We may not need this memcpy if it is in place. - output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - // using inter_nearest to comply with python default - cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation), - cv::BORDER_CONSTANT, fill_color); - } else { - // we resize here since the shape changes - // create a new bounding box with the rotate - cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f(); - rot.at(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0; - rot.at(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0; - // use memcpy and don't compute the new shape since openCV has a rounding problem - cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation), - cv::BORDER_CONSTANT, fill_color); - output_cv = std::make_shared(output_img); - RETURN_UNEXPECTED_IF_NULL(output_cv); - } - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in image rotation"); - } - return Status::OK(); -} - -Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, - const std::shared_ptr &mean, const std::shared_ptr &std) { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - if (!(input_cv->mat().data && input_cv->Rank() == 3)) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - cv::Mat in_image = input_cv->mat(); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), DataType(DataType::DE_FLOAT32)); - RETURN_UNEXPECTED_IF_NULL(output_cv); - mean->Squeeze(); - if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) { - std::string err_msg = "Mean tensor should be of size 3 and type float."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - std->Squeeze(); - if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) { - std::string err_msg = "Std tensor should be of size 3 and type float."; - return Status(StatusCode::kShapeMisMatch, err_msg); - } - try { - // 
NOTE: We are assuming the input image is in RGB and the mean - // and std are in RGB - cv::Mat rgb[3]; - cv::split(in_image, rgb); - for (uint8_t i = 0; i < 3; i++) { - float mean_c, std_c; - RETURN_IF_NOT_OK(mean->GetItemAt(&mean_c, {i})); - RETURN_IF_NOT_OK(std->GetItemAt(&std_c, {i})); - rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c)); - } - cv::merge(rgb, 3, output_cv->mat()); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize"); - } -} - -Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - output_cv->mat() = input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust brightness"); - } - return Status::OK(); -} - -Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - cv::Mat gray, output_img; - cv::cvtColor(input_img, gray, CV_RGB2GRAY); - int mean_img = static_cast(cv::mean(gray).val[0] + 0.5); - std::shared_ptr output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1); - output_img = output_img + mean_img; - cv::cvtColor(output_img, output_img, CV_GRAY2RGB); - output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust contrast"); - } - return Status::OK(); -} - -Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Mat output_img = output_cv->mat(); - cv::Mat gray; - cv::cvtColor(input_img, gray, CV_RGB2GRAY); - cv::cvtColor(gray, output_img, CV_GRAY2RGB); - output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha; - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust saturation"); - } - return 
Status::OK(); -} - -Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue) { - if (hue > 0.5 || hue < -0.5) { - MS_LOG(ERROR) << "Hue factor is not in [-0.5, 0.5]."; - RETURN_STATUS_UNEXPECTED("hue_factor is not in [-0.5, 0.5]."); - } - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - cv::Mat input_img = input_cv->mat(); - if (!input_cv->mat().data) { - RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); - } - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3"); - } - auto output_cv = std::make_shared(input_cv->shape(), input_cv->type()); - RETURN_UNEXPECTED_IF_NULL(output_cv); - cv::Mat output_img; - cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL); - for (int y = 0; y < output_img.cols; y++) { - for (int x = 0; x < output_img.rows; x++) { - uint8_t cur1 = output_img.at(cv::Point(y, x))[0]; - uint8_t h_hue = 0; - h_hue = static_cast(hue * 255); - cur1 += h_hue; - output_img.at(cv::Point(y, x))[0] = cur1; - } - } - cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL); - *output = std::static_pointer_cast(output_cv); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in adjust hue"); - } - return Status::OK(); -} - -Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, - uint8_t fill_g, uint8_t fill_b) { - try { - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - int num_channels = input_cv->shape()[2]; - if (input_cv->mat().data == nullptr || input_cv->Rank() != 3 || num_channels != 3) { - RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); - } - cv::Mat input_img = input_cv->mat(); - int32_t image_h = input_cv->shape()[0]; - int32_t image_w = input_cv->shape()[1]; - // check if erase size is bigger than image itself - if (box_height > image_h || box_width > image_w) { - RETURN_STATUS_UNEXPECTED("input box size too large for image erase"); - } - - // for random color - std::normal_distribution normal_distribution(0, 1); - std::uniform_int_distribution height_distribution_bound(0, image_h - box_height); - std::uniform_int_distribution width_distribution_bound(0, image_w - box_width); - std::uniform_int_distribution height_distribution_unbound(0, image_h + box_height); - std::uniform_int_distribution width_distribution_unbound(0, image_w + box_width); - // core logic - // update values based on random erasing or cutout - - for (int32_t i = 0; i < num_patches; i++) { - // rows in cv mat refers to the height of the cropped box - // we determine h_start and w_start using two different distributions as erasing is used by two different - // image augmentations. The bounds are also different in each case. - int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); - int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); - - int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; - int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; - // check for starting range >= 0, here the start range is checked after for cut out, for random erasing - // w_start and h_start will never be less than 0. - h_start = (h_start < 0) ? 
0 : h_start; - w_start = (w_start < 0) ? 0 : w_start; - for (int y = w_start; y < max_width; y++) { - for (int x = h_start; x < max_height; x++) { - if (random_color) { - // fill each box with a random value - input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); - input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); - input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); - } else { - input_img.at(cv::Point(y, x))[0] = fill_r; - input_img.at(cv::Point(y, x))[1] = fill_g; - input_img.at(cv::Point(y, x))[2] = fill_b; - } - } - } - } - *output = std::static_pointer_cast(input); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Error in erasing"); - } -} - -Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, - const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, - uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { - try { - // input image - std::shared_ptr input_cv = CVTensor::AsCVTensor(input); - // get the border type in openCV - auto b_type = GetCVBorderType(border_types); - // output image - cv::Mat out_image; - if (b_type == cv::BORDER_CONSTANT) { - cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); - cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type, fill_color); - } else { - cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); - } - std::shared_ptr output_cv = std::make_shared(out_image); - RETURN_UNEXPECTED_IF_NULL(output_cv); - // pad the dimension if shape information is only 2 dimensional, this is grayscale - int num_channels = input_cv->shape()[2]; - if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); - *output = std::static_pointer_cast(output_cv); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("Unexpected error in pad"); - } -} -// -------- BBOX OPERATIONS -------- // -Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, - int CB_Ymax) { - // PASS LIST, COUNT OF BOUNDING BOXES - // Also PAss X/Y Min/Max of image cropped region - normally obtained from 'GetCropBox' functions - float bb_Xmin = 0.0, bb_Ymin = 0.0, bb_Xmax = 0.0, bb_Ymax = 0.0; - std::vector correct_ind; - std::vector copyVals; - dsize_t bboxDim = (*bboxList)->shape()[1]; - bool retFlag = false; // true unless overlap found - for (int i = 0; i < *bboxCount; i++) { - RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmin, {i, 0})); - RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymin, {i, 1})); - RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmax, {i, 2})); - RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymax, {i, 3})); - bb_Xmax = bb_Xmin + bb_Xmax; - bb_Ymax = bb_Ymin + bb_Ymax; - // check for image / BB overlap - if (((bb_Xmin > CB_Xmax) || (bb_Ymin > CB_Ymax)) || ((bb_Xmax < CB_Xmin) || (bb_Ymax < CB_Ymin))) { - continue; // no overlap found - } - // Update this bbox and select it to move to the final output tensor - correct_ind.push_back(i); - // adjust BBox corners by bringing into new CropBox if beyond - // Also reseting/adjusting for boxes to lie within CropBox instead of Image - subtract CropBox Xmin/YMin - - bb_Xmin = bb_Xmin - std::min(static_cast(0.0), (bb_Xmin - CB_Xmin)) - CB_Xmin; - bb_Xmax = bb_Xmax - std::max(static_cast(0.0), (bb_Xmax - CB_Xmax)) - CB_Xmin; - bb_Ymin = bb_Ymin - 
std::min(static_cast<float>(0.0), (bb_Ymin - CB_Ymin)) - CB_Ymin;
-    bb_Ymax = bb_Ymax - std::max(static_cast<float>(0.0), (bb_Ymax - CB_Ymax)) - CB_Ymin;
-
-    // bound check for float values
-    bb_Xmin = std::max(bb_Xmin, static_cast<float>(0));
-    bb_Ymin = std::max(bb_Ymin, static_cast<float>(0));
-    bb_Xmax = std::min(bb_Xmax, static_cast<float>(CB_Xmax - CB_Xmin));  // find max value relative to new image
-    bb_Ymax = std::min(bb_Ymax, static_cast<float>(CB_Ymax - CB_Ymin));
-
-    // reset min values and calculate width/height from Box corners
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, bb_Xmin));
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, bb_Ymin));
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 2}, bb_Xmax - bb_Xmin));
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 3}, bb_Ymax - bb_Ymin));
-  }
-  // create new tensor and copy over bboxes still valid to the image
-  // bboxes outside of new cropped region are ignored - empty tensor returned in case of none
-  *bboxCount = correct_ind.size();
-  float temp = 0.0;
-  for (auto slice : correct_ind) {  // for every index in the loop
-    for (int ix = 0; ix < bboxDim; ix++) {
-      RETURN_IF_NOT_OK((*bboxList)->GetItemAt<float>(&temp, {slice, ix}));
-      copyVals.push_back(temp);
-    }
-  }
-  std::shared_ptr<Tensor> retV;
-  RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast<dsize_t>(*bboxCount), bboxDim})));
-  (*bboxList) = retV;  // reset pointer
-  return Status::OK();
-}
-
-Status PadBBoxes(const std::shared_ptr<Tensor> *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left) {
-  for (int i = 0; i < bboxCount; i++) {
-    float xMin = 0.0, yMin = 0.0;
-    RETURN_IF_NOT_OK((*bboxList)->GetItemAt<float>(&xMin, {i, 0}));
-    RETURN_IF_NOT_OK((*bboxList)->GetItemAt<float>(&yMin, {i, 1}));
-    xMin += pad_left;
-    yMin += pad_top;
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, xMin));
-    RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, yMin));
-  }
-  return Status::OK();
-}
-
-Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
-                             int32_t target_height_, int orig_width, int orig_height) {
-  float bb_Xmin = 0, bb_Ymin = 0, bb_Xwidth = 0, bb_Ywidth = 0;
-  // cast to float to preserve fractional
-  float W_aspRatio = (target_width_ * 1.0) / (orig_width * 1.0);
-  float H_aspRatio = (target_height_ * 1.0) / (orig_height * 1.0);
-  for (int i = 0; i < bboxCount; i++) {
-    // for each bounding box
-    RETURN_IF_NOT_OK(bboxList->GetItemAt<float>(&bb_Xmin, {i, 0}));
-    RETURN_IF_NOT_OK(bboxList->GetItemAt<float>(&bb_Ymin, {i, 1}));
-    RETURN_IF_NOT_OK(bboxList->GetItemAt<float>(&bb_Xwidth, {i, 2}));
-    RETURN_IF_NOT_OK(bboxList->GetItemAt<float>(&bb_Ywidth, {i, 3}));
-    // update positions and widths
-    bb_Xmin = bb_Xmin * W_aspRatio;
-    bb_Ymin = bb_Ymin * H_aspRatio;
-    bb_Xwidth = bb_Xwidth * W_aspRatio;
-    bb_Ywidth = bb_Ywidth * H_aspRatio;
-    // reset bounding box values
-    RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 0}, bb_Xmin));
-    RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 1}, bb_Ymin));
-    RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 2}, bb_Xwidth));
-    RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 3}, bb_Ywidth));
-  }
-  return Status::OK();
-}
-
-}  // namespace dataset
-}  // namespace mindspore
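A numeric check of the aspect-ratio update above (values are illustrative): halving both image dimensions halves every box coordinate and extent:

```cpp
// UpdateBBoxesForResize, worked example: 640x480 image resized to 320x240.
float w_ratio = 320.0f / 640.0f;  // W_aspRatio = 0.5
float h_ratio = 240.0f / 480.0f;  // H_aspRatio = 0.5
// A box (x=100, y=60, w=200, h=120) becomes (50, 30, 100, 60):
float x = 100 * w_ratio, y = 60 * h_ratio;
float w = 200 * w_ratio, h = 120 * h_ratio;
```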
diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/dataset/kernels/image/image_utils.h
deleted file mode 100644
index 212d81f7fc..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/image_utils.h
+++ /dev/null
@@ -1,259 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
-#define DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
-
-#include <setjmp.h>
-
-#include <memory>
-#include <random>
-#include <string>
-#include <vector>
-#if defined(_WIN32) || defined(_WIN64)
-#undef HAVE_STDDEF_H
-#undef HAVE_STDLIB_H
-#endif
-#include "./jpeglib.h"
-#include "./jerror.h"
-#include <opencv2/imgproc/imgproc.hpp>
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-void JpegErrorExitCustom(j_common_ptr cinfo);
-
-struct JpegErrorManagerCustom {
-  // "public" fields
-  struct jpeg_error_mgr pub;
-  // for return to caller
-  jmp_buf setjmp_buffer;
-};
-
-// Returns the interpolation mode in openCV format
-// @param mode: interpolation mode in DE format
-int GetCVInterpolationMode(InterpolationMode mode);
-
-// Returns the openCV equivalent of the border type used for padding.
-// @param type
-// @return
-int GetCVBorderType(BorderType type);
-
-// Returns flipped image
-// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
-// @param flip_code: 1 for Horizontal (around y-axis), 0 for Vertical (around x-axis), -1 for both
-// The flipping happens in place.
-Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code);
-
-// Returns Horizontally flipped image
-// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
-// The flipping happens in place.
-Status HorizontalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
-
-// Returns Vertically flipped image
-// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
-// The flipping happens in place.
-Status VerticalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
-
-// Returns Resized image.
-// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
-// @param output_height: height of output
-// @param output_width: width of output
-// @param fx: horizontal scale
-// @param fy: vertical scale
-// @param InterpolationMode: the interpolation mode
-// @param output: Resized image of shape <output_height, output_width, C> or <output_height, output_width>
-//                and same type as input
-Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t output_height,
-              int32_t output_width, double fx = 0.0, double fy = 0.0,
-              InterpolationMode mode = InterpolationMode::kLinear);
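An illustrative call of the `Resize` declaration above; the helper name `ResizeTo256` is hypothetical, and the guard mentioned in the comment comes from the implementation in image_utils.cc:

```cpp
// Sketch only: shrink an HWC image to 256x256 with bilinear interpolation.
#include <memory>

#include "dataset/kernels/image/image_utils.h"

using namespace mindspore::dataset;

Status ResizeTo256(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
  // The implementation rejects output sizes of 0 or > 1000x the input dimension.
  return Resize(in, out, /*output_height=*/256, /*output_width=*/256,
                /*fx=*/0.0, /*fy=*/0.0, InterpolationMode::kLinear);
}
```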
Pixel order is RGB -Status Decode(const std::shared_ptr &input, std::shared_ptr *output); - -Status DecodeCv(const std::shared_ptr &input, std::shared_ptr *output); - -bool IsNonEmptyJPEG(const std::shared_ptr &input); - -void JpegSetSource(j_decompress_ptr c_info, const void *data, int64_t data_size); - -Status JpegCropAndDecode(const std::shared_ptr &input, std::shared_ptr *output, int x = 0, int y = 0, - int w = 0, int h = 0); -// Returns Rescaled image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param rescale: rescale parameter -// @param shift: shift parameter -// @param output: Rescaled image Tensor of same input shape and type DE_FLOAT32 -Status Rescale(const std::shared_ptr &input, std::shared_ptr *output, float rescale, float shift); - -// Returns cropped ROI of an image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param x: starting horizontal position of ROI -// @param y: starting vertical position of ROI -// @param w: width of the ROI -// @param h: height of the ROI -// @param output: Cropped image Tensor of shape or and same input type. -Status Crop(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, int w, int h); - -// Swaps the channels in the image, i.e. converts HWC to CHW -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param output: Tensor of shape or and same input type. -Status HwcToChw(std::shared_ptr input, std::shared_ptr *output); - -// Swap the red and blue pixels (RGB <-> BGR) -// @param input: Tensor of shape and any OpenCv compatible type, see CVTensor. -// @param output: Swapped image of same shape and type -Status SwapRedAndBlue(std::shared_ptr input, std::shared_ptr *output); - -// Crops and resizes the image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param x: horizontal start point -// @param y: vertical start point -// @param crop_height: height of the cropped ROI -// @param crop_width: width of the cropped ROI -// @param target_width: width of the final resized image -// @param target_height: height of the final resized image -// @param InterpolationMode: the interpolation used in resize operation -// @param output: Tensor of shape or -// and same type as input -Status CropAndResize(const std::shared_ptr &input, std::shared_ptr *output, int x, int y, - int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode); - -// Returns rotated image -// @param input: Tensor of shape or and any OpenCv compatible type, see CVTensor. -// @param fx: rotation center x coordinate -// @param fy: rotation center y coordinate -// @param degree: degree to rotate -// @param expand: if reshape is necessary -// @param output: rotated image of same input type. -Status Rotate(const std::shared_ptr &input, std::shared_ptr *output, float fx, float fy, float degree, - InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, bool expand = false, - uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); - -// Returns Normalized image -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. 
-// @param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order -// @param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order -// @param output: Normalized image Tensor of same input shape and type DE_FLOAT32 -Status Normalize(const std::shared_ptr &input, std::shared_ptr *output, - const std::shared_ptr &mean, const std::shared_ptr &std); - -// Returns image with adjusted brightness. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust brightness by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// This will output original image multiplied by alpha. 0 gives a black image, 1 gives the -// original image while 2 increases the brightness by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted contrast. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust contrast by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// 0 gives a solid gray image, 1 gives the original image while 2 increases -// the contrast by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted saturation. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param alpha: Alpha value to adjust saturation by. Should be a positive number. -// If user input one value in python, the range is [1 - value, 1 + value]. -// 0 will give a black and white image, 1 will give the original image while -// 2 will enhance the saturation by a factor of 2. -// @param output: Adjusted image of same shape and type. -Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha); - -// Returns image with adjusted hue. -// @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. -// @param hue: Hue value to adjust by, should be within range [-0.5, 0.5]. 0.5 and - 0.5 will reverse the hue channel -// completely. -// If user input one value in python, the range is [-value, value]. -// @param output: Adjusted image of same shape and type. -Status AdjustHue(const std::shared_ptr &input, std::shared_ptr *output, const float &hue); - -// Masks out a random section from the image with set dimension -// @param input: input Tensor -// @param output: cutOut Tensor -// @param box_height: height of the cropped box -// @param box_width: width of the cropped box -// @param num_patches: number of boxes to cut out from the image -// @param bounded: boolean flag to toggle between random erasing and cutout -// @param random_color: whether or not random fill value should be used -// @param fill_r: red fill value for erase -// @param fill_g: green fill value for erase -// @param fill_b: blue fill value for erase. 
-Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
-             int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd,
-             uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
-
-// Pads the input image and puts the padded image in the output
-// @param input: input Tensor
-// @param output: padded Tensor
-// @param pad_top: amount of padding added to the top
-// @param pad_bottom: amount of padding added to the bottom
-// @param pad_left: amount of padding added to the left
-// @param pad_right: amount of padding added to the right
-// @param border_types: the border type used when padding
-// @param fill_r: red fill value for pad
-// @param fill_g: green fill value for pad
-// @param fill_b: blue fill value for pad.
-Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
-           const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right,
-           const BorderType &border_types, uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
-
-// -------- BBOX OPERATIONS -------- //
-// Updates and checks bounding boxes for a newly cropped region of the image
-// @param bboxList: A tensor containing bounding box tensors
-// @param bboxCount: total number of bounding boxes - required within caller function to run update loop
-// @param CB_Xmin: Image's CropBox Xmin coordinate
-// @param CB_Ymin: Image's CropBox Ymin coordinate
-// @param CB_Xmax: Image's CropBox Xmax coordinate - (Xmin + width)
-// @param CB_Ymax: Image's CropBox Ymax coordinate - (Ymin + height)
-Status UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin,
-                           int CB_Xmax, int CB_Ymax);
-
-// Updates bounding boxes with the required Top and Left padding
-// Top and Left padding amounts are needed to adjust each bbox's min X,Y values according to the padding 'push'
-// Top/Left only, since the image's (0,0) coordinate is at the top left
-// @param bboxList: A tensor containing bounding box tensors
-// @param bboxCount: total number of bounding boxes - required within caller function to run update loop
-// @param pad_top: Total amount of padding applied to the image top
-// @param pad_left: Total amount of padding applied to the image left side
-Status PadBBoxes(const std::shared_ptr<Tensor> *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left);
-
-// Updates bounding boxes for an Image Resize Operation - takes in a set of valid BBoxes,
-// e.g. those that remain after a crop
-// @param bboxList: A tensor containing bounding box tensors
-// @param bboxCount: total number of bounding boxes - required within caller function to run update loop
-// @param target_width_: required width of image post resize
-// @param target_height_: required height of image post resize
-// @param orig_width: current width of image pre resize
-// @param orig_height: current height of image pre resize
-Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
-                             int32_t target_height_, int orig_width, int orig_height);
-
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
diff --git a/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc
deleted file mode 100644
index 638eaad264..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/normalize_op.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/kernels/image/normalize_op.h"
-
-#include <memory>
-
-#include "dataset/core/cv_tensor.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) {
-  int size[] = {3};
-  cv::Mat mean_cv(1, size, CV_32F);
-  mean_cv.at<float>(0) = mean_r;
-  mean_cv.at<float>(1) = mean_g;
-  mean_cv.at<float>(2) = mean_b;
-  mean_ = std::make_shared<CVTensor>(mean_cv);
-  mean_->Squeeze();
-
-  cv::Mat std_cv(1, size, CV_32F);
-  std_cv.at<float>(0) = std_r;
-  std_cv.at<float>(1) = std_g;
-  std_cv.at<float>(2) = std_b;
-  std_ = std::make_shared<CVTensor>(std_cv);
-  std_->Squeeze();
-}
-
-Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-  // Doing the normalization
-  return Normalize(input, output, mean_, std_);
-}
-
-void NormalizeOp::Print(std::ostream &out) const {
-  out << "NormalizeOp, mean: " << mean_->mat().at<float>(0) << ", " << mean_->mat().at<float>(1) << ", "
-      << mean_->mat().at<float>(2) << ", std: " << std_->mat().at<float>(0) << ", " << std_->mat().at<float>(1)
-      << ", " << std_->mat().at<float>(2) << std::endl;
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/dataset/kernels/image/normalize_op.h
deleted file mode 100644
index a66f95a2b5..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/normalize_op.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
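The NormalizeOp removed above stores the per-channel mean and std as squeezed CVTensors and defers the arithmetic to the Normalize() helper declared in image_utils.h. Stripped of the CVTensor/OpenCV plumbing, the operation reduces to out = (in - mean[c]) / std[c] per channel. A minimal self-contained sketch of that arithmetic on a packed HWC uint8 buffer (NormalizeHWC and its parameters are illustrative stand-ins, not part of the deleted API):

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Normalize an H x W x 3 image stored row-major as HWC bytes.
    std::vector<float> NormalizeHWC(const std::vector<uint8_t> &img,
                                    const std::array<float, 3> &mean,
                                    const std::array<float, 3> &stddev) {
      std::vector<float> out(img.size());
      for (std::size_t i = 0; i < img.size(); ++i) {
        std::size_t c = i % 3;  // HWC layout: the channel index cycles fastest
        out[i] = (static_cast<float>(img[i]) - mean[c]) / stddev[c];
      }
      return out;
    }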
- */ -#ifndef DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ - -#include -#include - -#include "dataset/core/cv_tensor.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class NormalizeOp : public TensorOp { - public: - NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b); - - ~NormalizeOp() override = default; - - void Print(std::ostream &out) const override; - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kNormalizeOp; } - - private: - std::shared_ptr mean_; - std::shared_ptr std_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/pad_op.cc b/mindspore/ccsrc/dataset/kernels/image/pad_op.cc deleted file mode 100644 index baeceeed77..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/pad_op.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/pad_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/constants.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const BorderType PadOp::kDefBorderType = BorderType::kConstant; -const uint8_t PadOp::kDefFillR = 0; -const uint8_t PadOp::kDefFillG = 0; -const uint8_t PadOp::kDefFillB = 0; - -PadOp::PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, - uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) - : pad_top_(pad_top), - pad_bottom_(pad_bottom), - pad_left_(pad_left), - pad_right_(pad_right), - boarder_type_(border_types), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) {} - -Status PadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - return Pad(input, output, pad_top_, pad_bottom_, pad_left_, pad_right_, boarder_type_, fill_r_, fill_g_, fill_b_); -} - -Status PadOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels - if (inputs[0].Rank() == 1) outputs.emplace_back(out); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/dataset/kernels/image/pad_op.h deleted file mode 100644 index 0457fbc01b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/pad_op.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 
2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_ -#define DATASET_KERNELS_IMAGE_PAD_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/core/constants.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class PadOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const BorderType kDefBorderType; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - // Constructor for PadOp. - // @param pad_top number of pixels to pad the top of image with. - // @param pad_bottom number of pixels to pad the bottom of the image with. - // @param pad_left number of pixels to pad the left of the image with. - // @param pad_right number of pixels to pad the right of the image with. - // @param border_types BorderType enum, the type of boarders that we are using. - // @param fill_r R value for the color to pad with. - // @param fill_g G value for the color to pad with. - // @param fill_b B value for the color to pad with. - PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~PadOp() override = default; - - void Print(std::ostream &out) const override { out << "PadOp: "; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kPadOp; } - - private: - int32_t pad_top_; - int32_t pad_bottom_; - int32_t pad_left_; - int32_t pad_right_; - BorderType boarder_type_; - uint8_t fill_r_; - uint8_t fill_g_; - uint8_t fill_b_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_PAD_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc deleted file mode 100644 index e420f86e9a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
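PadOp, deleted just above, forwards its four margins to the Pad() helper; for BorderType::kConstant the geometry is an offset copy into a larger buffer pre-filled with the fill color, which is also why its OutputShape() can only report {-1, -1, 3} before the pad amounts are applied. A hedged single-channel sketch of that geometry (PadConstant is an illustrative name; the real helper operates on CVTensor/cv::Mat, handles three channels, and supports other border types):

    #include <cstdint>
    #include <vector>

    // Pads an h x w single-channel image with a constant fill value.
    std::vector<uint8_t> PadConstant(const std::vector<uint8_t> &src, int h, int w,
                                     int top, int bottom, int left, int right, uint8_t fill) {
      const int out_w = w + left + right;
      std::vector<uint8_t> dst(static_cast<size_t>(h + top + bottom) * out_w, fill);
      for (int y = 0; y < h; ++y) {
        for (int x = 0; x < w; ++x) {
          dst[static_cast<size_t>(y + top) * out_w + (x + left)] = src[static_cast<size_t>(y) * w + x];
        }
      }
      return dst;
    }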
- */
-#include "dataset/kernels/image/random_color_adjust_op.h"
-
-#include <random>
-
-#include "dataset/core/config_manager.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/random.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor,
-                                         float e_contrast_factor, float s_saturation_factor, float e_saturation_factor,
-                                         float s_hue_factor, float e_hue_factor)
-    : bright_factor_start_(s_bright_factor),
-      bright_factor_end_(e_bright_factor),
-      contrast_factor_start_(s_contrast_factor),
-      contrast_factor_end_(e_contrast_factor),
-      saturation_factor_start_(s_saturation_factor),
-      saturation_factor_end_(e_saturation_factor),
-      hue_factor_start_(s_hue_factor),
-      hue_factor_end_(e_hue_factor) {
-  rnd_.seed(GetSeed());
-}
-
-Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-
-  // apply the augmentations in a random order; every requested transformation runs exactly once
-  std::vector<std::string> params_vector = {"brightness", "contrast", "saturation", "hue"};
-
-  std::shuffle(params_vector.begin(), params_vector.end(), rnd_);
-
-  *output = std::static_pointer_cast<Tensor>(input);
-  // determine whether a given augmentation needs to be executed:
-  for (const auto &param : params_vector) {
-    if (param == "brightness") {
-      if (CmpFloat(bright_factor_start_, bright_factor_end_) && CmpFloat(bright_factor_start_, 1.0f)) {
-        MS_LOG(DEBUG) << "Not running brightness.";
-      } else {
-        // adjust the brightness of the image
-        float random_factor = std::uniform_real_distribution<float>(bright_factor_start_, bright_factor_end_)(rnd_);
-        RETURN_IF_NOT_OK(AdjustBrightness(*output, output, random_factor));
-      }
-    } else if (param == "contrast") {
-      if (CmpFloat(contrast_factor_start_, contrast_factor_end_) && CmpFloat(contrast_factor_start_, 1.0f)) {
-        MS_LOG(DEBUG) << "Not running contrast.";
-      } else {
-        float random_factor = std::uniform_real_distribution<float>(contrast_factor_start_, contrast_factor_end_)(rnd_);
-        RETURN_IF_NOT_OK(AdjustContrast(*output, output, random_factor));
-      }
-    } else if (param == "saturation") {
-      // adjust the saturation of the image
-      if (CmpFloat(saturation_factor_start_, saturation_factor_end_) && CmpFloat(saturation_factor_start_, 1.0f)) {
-        MS_LOG(DEBUG) << "Not running saturation.";
-      } else {
-        float random_factor =
-          std::uniform_real_distribution<float>(saturation_factor_start_, saturation_factor_end_)(rnd_);
-        RETURN_IF_NOT_OK(AdjustSaturation(*output, output, random_factor));
-      }
-    } else if (param == "hue") {
-      if (CmpFloat(hue_factor_start_, hue_factor_end_) && CmpFloat(hue_factor_start_, 0.0f)) {
-        MS_LOG(DEBUG) << "Not running hue.";
-      } else {
-        // adjust the hue of the image
-        float random_factor = std::uniform_real_distribution<float>(hue_factor_start_, hue_factor_end_)(rnd_);
-        RETURN_IF_NOT_OK(AdjustHue(*output, output, random_factor));
-      }
-    }
-  }
-  // all requested transformations have now been applied, in a random order
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h b/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h
deleted file mode 100644
index 23ccf4aa93..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/random_color_adjust_op.h
+++ /dev/null
@@ -1,80 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_
-#define DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_
-
-#include <memory>
-#include <random>
-#include <string>
-#include <vector>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-class RandomColorAdjustOp : public TensorOp {
- public:
-  static const uint32_t kDefSeed;
-
-  // Constructor for RandomColorAdjustOp.
-  // @param s_bright_factor brightness change range start value.
-  // @param e_bright_factor brightness change range end value.
-  // @param s_contrast_factor contrast change range start value.
-  // @param e_contrast_factor contrast change range end value.
-  // @param s_saturation_factor saturation change range start value.
-  // @param e_saturation_factor saturation change range end value.
-  // @param s_hue_factor hue change factor start value, this should be greater than -0.5.
-  // @param e_hue_factor hue change factor end value, this should be less than 0.5.
-  // @details the randomly chosen factor is uniformly distributed.
-  RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, float e_contrast_factor,
-                      float s_saturation_factor, float e_saturation_factor, float s_hue_factor, float e_hue_factor);
-
-  ~RandomColorAdjustOp() override = default;
-
-  // Print function for RandomColorAdjustOp.
-  // @param out output stream to print to.
-  void Print(std::ostream &out) const override { out << "RandomColorAdjustOp: "; }
-
-  // Overrides the base class compute function.
-  // Calls multiple transform functions in ImageUtils; this function takes an input tensor
-  // and transforms its data using OpenCV, manipulating the output memory to contain the result.
-  // @return Status - The error code returned.
-  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
-
-  std::string Name() const override { return kRandomColorAdjustOp; }
-
- private:
-  std::mt19937 rnd_;
-  float bright_factor_start_;
-  float bright_factor_end_;
-  float contrast_factor_start_;
-  float contrast_factor_end_;
-  float saturation_factor_start_;
-  float saturation_factor_end_;
-  float hue_factor_start_;
-  float hue_factor_end_;
-  // Compare two floating point variables. Return true if they are the same or very close.
- inline bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) const { - return (std::fabs(a - b) < epsilon); - } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc deleted file mode 100644 index c5b5f20c63..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomCropAndResizeOp::kDefScaleLb = 0.08; -const float RandomCropAndResizeOp::kDefScaleUb = 1.0; -const float RandomCropAndResizeOp::kDefAspectLb = 0.75; -const float RandomCropAndResizeOp::kDefAspectUb = 1.333333; -const InterpolationMode RandomCropAndResizeOp::kDefInterpolation = InterpolationMode::kLinear; -const int32_t RandomCropAndResizeOp::kDefMaxIter = 10; - -RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb, - float scale_ub, float aspect_lb, float aspect_ub, - InterpolationMode interpolation, int32_t max_iter) - : target_height_(target_height), - target_width_(target_width), - rnd_scale_(scale_lb, scale_ub), - rnd_aspect_(log(aspect_lb), log(aspect_ub)), - interpolation_(interpolation), - aspect_lb_(aspect_lb), - aspect_ub_(aspect_ub), - max_iter_(max_iter) { - rnd_.seed(GetSeed()); -} - -Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); - - int h_in = input->shape()[0]; - int w_in = input->shape()[1]; - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); - return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); -} -Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{target_height_, target_width_}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { - *crop_width = w_in; - *crop_height = h_in; - CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0"); - 
CHECK_FAIL_RETURN_UNEXPECTED(h_in != 0, "Height is 0");
-  CHECK_FAIL_RETURN_UNEXPECTED(aspect_lb_ > 0, "Aspect lower bound must be greater than zero");
-  for (int32_t i = 0; i < max_iter_; i++) {
-    double const sample_scale = rnd_scale_(rnd_);
-    // In case of non-symmetrical aspect ratios, use a uniform distribution on a logarithmic scale.
-    // Note rnd_aspect_ already samples the aspect ratio uniformly on a logarithmic scale.
-    double const sample_aspect = exp(rnd_aspect_(rnd_));
-
-    *crop_width = static_cast<int32_t>(std::round(std::sqrt(h_in * w_in * sample_scale * sample_aspect)));
-    *crop_height = static_cast<int32_t>(std::round(*crop_width / sample_aspect));
-    if (*crop_width <= w_in && *crop_height <= h_in) {
-      std::uniform_int_distribution<> rd_x(0, w_in - *crop_width);
-      std::uniform_int_distribution<> rd_y(0, h_in - *crop_height);
-      *x = rd_x(rnd_);
-      *y = rd_y(rnd_);
-      return Status::OK();
-    }
-  }
-  double const img_aspect = static_cast<double>(w_in) / h_in;
-  if (img_aspect < aspect_lb_) {
-    *crop_width = w_in;
-    *crop_height = static_cast<int32_t>(std::round(*crop_width / static_cast<double>(aspect_lb_)));
-  } else {
-    if (img_aspect > aspect_ub_) {
-      *crop_height = h_in;
-      *crop_width = static_cast<int32_t>(std::round(*crop_height * static_cast<double>(aspect_ub_)));
-    } else {
-      *crop_width = w_in;
-      *crop_height = h_in;
-    }
-  }
-  *x = static_cast<int32_t>(std::round((w_in - *crop_width) / 2.0));
-  *y = static_cast<int32_t>(std::round((h_in - *crop_height) / 2.0));
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h
deleted file mode 100644
index 04e4135e7b..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
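For intuition on GetCropBox() above: the sampled crop area is scale * (h_in * w_in) and the sampled aspect ratio is width/height, so crop_width = sqrt(area * aspect) and crop_height = crop_width / aspect. A self-contained numeric sketch using the class defaults (scale in [0.08, 1.0], aspect in [3/4, 4/3]); the image size and seed here are illustrative only:

    #include <cmath>
    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(42);
      std::uniform_real_distribution<double> rnd_scale(0.08, 1.0);
      std::uniform_real_distribution<double> rnd_log_aspect(std::log(0.75), std::log(4.0 / 3));
      const int h_in = 480, w_in = 640;
      double scale = rnd_scale(rnd);
      double aspect = std::exp(rnd_log_aspect(rnd));  // log-uniform aspect sampling
      int crop_w = static_cast<int>(std::round(std::sqrt(h_in * w_in * scale * aspect)));
      int crop_h = static_cast<int>(std::round(crop_w / aspect));
      std::printf("scale=%.3f aspect=%.3f -> crop %dx%d\n", scale, aspect, crop_w, crop_h);
      return 0;
    }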
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropAndResizeOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefScaleLb; - static const float kDefScaleUb; - static const float kDefAspectLb; - static const float kDefAspectUb; - static const InterpolationMode kDefInterpolation; - static const int32_t kDefMaxIter; - - RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, - InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); - - RandomCropAndResizeOp() = default; - - RandomCropAndResizeOp(const RandomCropAndResizeOp &rhs) = default; - - RandomCropAndResizeOp(RandomCropAndResizeOp &&rhs) = default; - - ~RandomCropAndResizeOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropAndResize: " << target_height_ << " " << target_width_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width); - - std::string Name() const override { return kRandomCropAndResizeOp; } - - protected: - int32_t target_height_; - int32_t target_width_; - std::uniform_real_distribution rnd_scale_; - std::uniform_real_distribution rnd_aspect_; - std::mt19937 rnd_; - InterpolationMode interpolation_; - int32_t max_iter_; - double aspect_lb_; - double aspect_ub_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc deleted file mode 100644 index fbaf2c9326..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
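A note on the design choice visible in rnd_aspect_ above: sampling exp(U(log a, log b)) instead of U(a, b) makes a ratio r and its reciprocal 1/r equally likely whenever the bounds are reciprocals (e.g. 3/4 and 4/3), so portrait and landscape crops are balanced. A quick, purely illustrative empirical check:

    #include <cmath>
    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(0);
      std::uniform_real_distribution<double> log_aspect(std::log(0.75), std::log(4.0 / 3));
      int wider = 0;
      const int total = 100000;
      for (int i = 0; i < total; ++i) {
        if (std::exp(log_aspect(rnd)) > 1.0) ++wider;  // r > 1 means wider than tall
      }
      std::printf("wider-than-tall fraction: %.3f (expect ~0.5)\n", static_cast<double>(wider) / total);
      return 0;
    }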
- */ - -#include -#include - -#include "dataset/util/random.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" - -namespace mindspore { -namespace dataset { - -Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal"); - - output->resize(2); - (*output)[1] = std::move(input[1]); // move boxes over to output - - size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor - int h_in = input[0]->shape()[0]; - int w_in = input[0]->shape()[1]; - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - - RETURN_IF_NOT_OK(RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width)); - - int maxX = x + crop_width; // max dims of selected CropBox on image - int maxY = y + crop_height; - - RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &bboxCount, x, y, maxX, maxY)); // IMAGE_UTIL - RETURN_IF_NOT_OK(CropAndResize(input[0], &(*output)[0], x, y, crop_height, crop_width, target_height_, target_width_, - interpolation_)); - - RETURN_IF_NOT_OK( - UpdateBBoxesForResize((*output)[1], bboxCount, target_width_, target_height_, crop_width, crop_height)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h deleted file mode 100644 index 2e28495658..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
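The Compute() above chains two bbox updates: UpdateBBoxesForCrop() shifts boxes into crop coordinates and clips them to the crop window, then UpdateBBoxesForResize() rescales them by the target/crop ratios. Reduced to one box in plain arithmetic (all values illustrative):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // crop window (x=100, y=50, 200x200), resize target 100x100
      const float cx = 100, cy = 50, cw = 200, ch = 200, tw = 100, th = 100;
      float xmin = 150, ymin = 40, w = 100, h = 60;  // box in original image coords
      // stage 1: shift into crop coordinates and clip to the window
      float x0 = std::max(xmin - cx, 0.0f), y0 = std::max(ymin - cy, 0.0f);
      float x1 = std::min(xmin + w - cx, cw), y1 = std::min(ymin + h - cy, ch);
      // stage 2: scale by the resize ratios
      float sx = tw / cw, sy = th / ch;
      std::printf("box after crop+resize: (%.1f, %.1f, %.1f, %.1f)\n",
                  x0 * sx, y0 * sy, (x1 - x0) * sx, (y1 - y0) * sy);
      return 0;
    }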
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ - -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include - -namespace mindspore { -namespace dataset { - -class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp { - public: - // Constructor for RandomCropAndResizeWithBBoxOp, with default value and passing to base class constructor - RandomCropAndResizeWithBBoxOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, - float aspect_ub = kDefAspectUb, InterpolationMode interpolation = kDefInterpolation, - int32_t max_iter = kDefMaxIter) - : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, - max_iter) {} - - ~RandomCropAndResizeWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropAndResizeWithBBox: " << RandomCropAndResizeOp::target_height_ << " " - << RandomCropAndResizeOp::target_width_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kRandomCropAndResizeWithBBoxOp; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc deleted file mode 100644 index 36d80aea98..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
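The header above shows the pattern used by every WithBBox variant in this diff: inherit the base op to reuse its sampling state and override only the TensorRow overload of Compute(). A generic sketch of the idea with stand-in types (none of these names come from the dataset code):

    #include <cstdio>

    class CropBase {
     public:
      virtual ~CropBase() = default;
      virtual void Compute() { std::printf("sample crop box, crop image\n"); }
     protected:
      int target_h_ = 224, target_w_ = 224;  // sampling state shared with derived ops
    };

    class CropWithBoxes : public CropBase {
     public:
      void Compute() override {
        CropBase::Compute();  // reuse the base sampling and cropping
        std::printf("then remap the bounding boxes\n");
      }
    };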
- */ -#include "dataset/kernels/image/random_crop_decode_resize_op.h" -#include -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/config_manager.h" -#include "dataset/kernels/image/decode_op.h" - -namespace mindspore { -namespace dataset { -RandomCropDecodeResizeOp::RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb, - float scale_ub, float aspect_lb, float aspect_ub, - InterpolationMode interpolation, int32_t max_iter) - : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, - max_iter) {} - -Status RandomCropDecodeResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - if (input == nullptr) { - RETURN_STATUS_UNEXPECTED("input tensor is null"); - } - if (!IsNonEmptyJPEG(input)) { - DecodeOp op(true); - std::shared_ptr decoded; - RETURN_IF_NOT_OK(op.Compute(input, &decoded)); - return RandomCropAndResizeOp::Compute(decoded, output); - } else { - struct jpeg_decompress_struct cinfo {}; - struct JpegErrorManagerCustom jerr {}; - cinfo.err = jpeg_std_error(&jerr.pub); - jerr.pub.error_exit = JpegErrorExitCustom; - try { - jpeg_create_decompress(&cinfo); - JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); - (void)jpeg_read_header(&cinfo, TRUE); - jpeg_calc_output_dimensions(&cinfo); - } catch (std::runtime_error &e) { - jpeg_destroy_decompress(&cinfo); - RETURN_STATUS_UNEXPECTED(e.what()); - } - int h_in = cinfo.output_height; - int w_in = cinfo.output_width; - jpeg_destroy_decompress(&cinfo); - - int x = 0; - int y = 0; - int crop_height = 0; - int crop_width = 0; - (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); - - std::shared_ptr decoded; - RETURN_IF_NOT_OK(JpegCropAndDecode(input, &decoded, x, y, crop_width, crop_height)); - return Resize(decoded, output, target_height_, target_width_, 0.0, 0.0, interpolation_); - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h deleted file mode 100644 index 57d1161961..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_decode_resize_op.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
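The Compute() above probes only the JPEG header to learn the image dimensions before deciding where to crop, so the full image is never decoded up front. A hedged sketch of that probe; it assumes libjpeg(-turbo) with jpeg_mem_src available, and omits the setjmp-based error handling the production code installs (the default libjpeg error manager exits the process on failure):

    #include <cstdint>
    #include <jpeglib.h>

    bool ProbeJpegSize(const uint8_t *buf, unsigned long size, int *w, int *h) {
      jpeg_decompress_struct cinfo;
      jpeg_error_mgr jerr;
      cinfo.err = jpeg_std_error(&jerr);
      jpeg_create_decompress(&cinfo);
      jpeg_mem_src(&cinfo, const_cast<uint8_t *>(buf), size);
      if (jpeg_read_header(&cinfo, TRUE) != JPEG_HEADER_OK) {
        jpeg_destroy_decompress(&cinfo);
        return false;
      }
      jpeg_calc_output_dimensions(&cinfo);  // fills cinfo.output_width / output_height
      *w = static_cast<int>(cinfo.output_width);
      *h = static_cast<int>(cinfo.output_height);
      jpeg_destroy_decompress(&cinfo);
      return true;
    }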
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_crop_and_resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropDecodeResizeOp : public RandomCropAndResizeOp { - public: - RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb, - float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb, - InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter); - - explicit RandomCropDecodeResizeOp(const RandomCropAndResizeOp &rhs) : RandomCropAndResizeOp(rhs) {} - - ~RandomCropDecodeResizeOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropDecodeResize: " << RandomCropAndResizeOp::target_height_ << " " - << RandomCropAndResizeOp::target_width_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kRandomCropDecodeResizeOp; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc deleted file mode 100644 index 110d769f26..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
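The point of RandomCropDecodeResizeOp, declared above, is that it accepts the still-encoded JPEG bytes: JpegCropAndDecode() decodes only the sampled crop window, usually a small fraction of the image, before a cheap resize. A back-of-the-envelope pixel count (numbers illustrative):

    #include <cstdio>

    int main() {
      const long full = 4000L * 3000L;  // full decode of a 12 MP JPEG
      const long crop = 600L * 600L;    // decode restricted to the sampled crop window
      std::printf("crop-first decodes %.1f%% of the pixels\n", 100.0 * crop / full);
      return 0;
    }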
- */
-#include "dataset/kernels/image/random_crop_op.h"
-#include <random>
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/random.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-const int32_t RandomCropOp::kDefPadTop = 0;
-const int32_t RandomCropOp::kDefPadBottom = 0;
-const int32_t RandomCropOp::kDefPadLeft = 0;
-const int32_t RandomCropOp::kDefPadRight = 0;
-const BorderType RandomCropOp::kDefBorderType = BorderType::kConstant;
-const bool RandomCropOp::kDefPadIfNeeded = false;
-const uint8_t RandomCropOp::kDefFillR = 0;
-const uint8_t RandomCropOp::kDefFillG = 0;
-const uint8_t RandomCropOp::kDefFillB = 0;
-
-RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top, int32_t pad_bottom,
-                           int32_t pad_left, int32_t pad_right, BorderType border_types, bool pad_if_needed,
-                           uint8_t fill_r, uint8_t fill_g, uint8_t fill_b)
-    : crop_height_(crop_height),
-      crop_width_(crop_width),
-      pad_top_(pad_top),
-      pad_bottom_(pad_bottom),
-      pad_left_(pad_left),
-      pad_right_(pad_right),
-      pad_if_needed_(pad_if_needed),
-      border_type_(border_types),
-      fill_r_(fill_r),
-      fill_g_(fill_g),
-      fill_b_(fill_b) {
-  rnd_.seed(GetSeed());
-}
-
-Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image,
-                                  int32_t *t_pad_top, int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right,
-                                  int32_t *padded_image_w, int32_t *padded_image_h, bool *crop_further) {
-  *t_pad_top = pad_top_;
-  *t_pad_bottom = pad_bottom_;
-  *t_pad_left = pad_left_;
-  *t_pad_right = pad_right_;
-
-  RETURN_IF_NOT_OK(
-    Pad(input, pad_image, pad_top_, pad_bottom_, pad_left_, pad_right_, border_type_, fill_r_, fill_g_, fill_b_));
-  CHECK_FAIL_RETURN_UNEXPECTED((*pad_image)->shape().Size() >= 2, "Abnormal shape");
-
-  *padded_image_h = (*pad_image)->shape()[0];
-  *padded_image_w = (*pad_image)->shape()[1];
-
-  if (*padded_image_h == crop_height_ && *padded_image_w == crop_width_) {
-    *crop_further = false;  // no need for further crop
-    return Status::OK();
-  } else if (pad_if_needed_) {
-    // check the dimensions of the image for padding, if we do need padding, then we change the pad values
-    if (*padded_image_h < crop_height_) {
-      RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, crop_height_ - *padded_image_h, crop_height_ - *padded_image_h, 0, 0,
-                           border_type_, fill_r_, fill_g_, fill_b_));
-
-      // update the total top/bottom padding reported back to the caller
-      *t_pad_top += (crop_height_ - *padded_image_h);
-      *t_pad_bottom += (crop_height_ - *padded_image_h);
-    }
-    if (*padded_image_w < crop_width_) {
-      RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, 0, 0, crop_width_ - *padded_image_w, crop_width_ - *padded_image_w,
-                           border_type_, fill_r_, fill_g_, fill_b_));
-      // update the total left/right padding reported back to the caller
-      *t_pad_left += (crop_width_ - *padded_image_w);
-      *t_pad_right += (crop_width_ - *padded_image_w);
-    }
-    *padded_image_h = (*pad_image)->shape()[0];
-    *padded_image_w = (*pad_image)->shape()[1];
-  }
-
-  if (*padded_image_h < crop_height_ || *padded_image_w < crop_width_ || crop_height_ == 0 || crop_width_ == 0) {
-    return Status(StatusCode::kShapeMisMatch, __LINE__, __FILE__,
-                  "Crop size is greater than the image dimensions or is zero.");
-  }
-  return Status::OK();
-}
-
-void RandomCropOp::GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h) {
-  // generate random crop points
-  *x = std::uniform_int_distribution<int>(0, padded_image_w - crop_width_)(rnd_);
-  *y = std::uniform_int_distribution<int>(0, padded_image_h - crop_height_)(rnd_);
-}
-
-Status RandomCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
-  IO_CHECK(input, output);
-
-  // Apply padding first, then crop
-  std::shared_ptr<Tensor> pad_image;
-  int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right;
-  int32_t padded_image_w;
-  int32_t padded_image_h;
-  bool crop_further = true;  // whether the image needs further cropping based on the new size & requirements
-
-  RETURN_IF_NOT_OK(  // error code sent back directly
-    ImagePadding(input, &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, &padded_image_w,
-                 &padded_image_h, &crop_further));
-  if (!crop_further) {
-    *output = pad_image;
-    return Status::OK();
-  }
-
-  int x, y;
-  GenRandomXY(&x, &y, padded_image_w, padded_image_h);
-  return Crop(pad_image, output, x, y, crop_width_, crop_height_);
-}
-
-Status RandomCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
-  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
-  outputs.clear();
-  TensorShape out = TensorShape{crop_height_, crop_width_};
-  if (inputs[0].Rank() == 2) outputs.emplace_back(out);
-  if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2]));
-  if (!outputs.empty()) return Status::OK();
-  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h
deleted file mode 100644
index f0b1ec828c..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/random_crop_op.h
+++ /dev/null
@@ -1,101 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
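GenRandomXY() above is the whole of the placement logic: with a padded image of size (pw, ph) and a crop of (cw, ch), any top-left corner in [0, pw - cw] x [0, ph - ch] keeps the crop fully inside the image. Isolated as a runnable snippet (all values illustrative):

    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(1234);
      const int pw = 260, ph = 260, cw = 224, ch = 224;
      int x = std::uniform_int_distribution<int>(0, pw - cw)(rnd);
      int y = std::uniform_int_distribution<int>(0, ph - ch)(rnd);
      std::printf("crop origin (%d, %d), spans to (%d, %d)\n", x, y, x + cw, y + ch);
      return 0;
    }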
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefPadTop; - static const int32_t kDefPadBottom; - static const int32_t kDefPadLeft; - static const int32_t kDefPadRight; - static const BorderType kDefBorderType; - static const bool kDefPadIfNeeded; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, - int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, int32_t pad_right = kDefPadRight, - BorderType border_types = kDefBorderType, bool pad_if_needed = kDefPadIfNeeded, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - RandomCropOp(const RandomCropOp &rhs) = default; - - RandomCropOp(RandomCropOp &&rhs) = default; - - ~RandomCropOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // Function breaks out the compute function's image padding functionality and makes available to other Ops - // Using this class as a base - restructrued to allow for RandomCropWithBBox Augmentation Op - // @param input: Input is the original Image - // @param pad_image: Pointer to new Padded image - // @param t_pad_top: Total Top Padding - Based on input and value calculated in function if required - // @param t_pad_bottom: Total bottom Padding - Based on input and value calculated in function if required - // @param t_pad_left: Total left Padding - Based on input and value calculated in function if required - // @param t_pad_right: Total right Padding - Based on input and value calculated in function if required - // @param padded_image_w: Final Width of the 'pad_image' - // @param padded_image_h: Final Height of the 'pad_image' - // @param crop_further: Whether image required cropping after padding - False if new padded image matches required - // dimensions - Status ImagePadding(const std::shared_ptr &input, std::shared_ptr *pad_image, int32_t *t_pad_top, - int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, int32_t *padded_image_w, - int32_t *padded_image_h, bool *crop_further); - - // Function breaks X,Y generation functionality out of original compute function and makes available to other Ops - void GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h); - - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kRandomCropOp; } - - protected: - int32_t crop_height_ = 0; - int32_t crop_width_ = 0; - - private: - int32_t pad_top_ = 0; - int32_t pad_bottom_ = 0; - int32_t pad_left_ = 0; - int32_t pad_right_ = 0; - bool pad_if_needed_ = false; - BorderType border_type_; - uint8_t fill_r_ = 0; - uint8_t fill_g_ = 0; - uint8_t fill_b_ = 0; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_ diff --git 
a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
deleted file mode 100644
index c873307afd..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <memory>
-#include <random>
-#include <utility>
-
-#include "dataset/kernels/image/random_crop_with_bbox_op.h"
-#include "dataset/kernels/image/image_utils.h"
-#include "dataset/util/random.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
-  IO_CHECK_VECTOR(input, output);
-  BOUNDING_BOX_CHECK(input);
-
-  std::shared_ptr<Tensor> pad_image;
-  int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right;
-  size_t boxCount = input[1]->shape()[0];  // number of rows
-
-  int32_t padded_image_h;
-  int32_t padded_image_w;
-
-  output->resize(2);
-  (*output)[1] = std::move(input[1]);  // since some boxes may be removed
-
-  bool crop_further = true;  // whether further cropping will be required, true unless the required size matches
-  RETURN_IF_NOT_OK(  // error passed back to caller
-    RandomCropOp::ImagePadding(input[0], &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right,
-                               &padded_image_w, &padded_image_h, &crop_further));
-
-  // update bounding boxes with new values based on relevant image padding
-  if (t_pad_left || t_pad_top) {
-    RETURN_IF_NOT_OK(PadBBoxes(&(*output)[1], boxCount, t_pad_top, t_pad_left));
-  }
-  if (!crop_further) {
-    // no further cropping required; (*output)[1] already holds the (padded) boxes
-    (*output)[0] = pad_image;
-    return Status::OK();
-  }
-
-  int x, y;
-  RandomCropOp::GenRandomXY(&x, &y, padded_image_w, padded_image_h);
-  int maxX = x + RandomCropOp::crop_width_;  // max dims of selected CropBox on image
-  int maxY = y + RandomCropOp::crop_height_;
-  RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &boxCount, x, y, maxX, maxY));
-  return Crop(pad_image, &(*output)[0], x, y, RandomCropOp::crop_width_, RandomCropOp::crop_height_);
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h
deleted file mode 100644
index 37b5ffc38b..0000000000
--- a/mindspore/ccsrc/dataset/kernels/image/random_crop_with_bbox_op.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ - -#include -#include -#include - -#include "dataset/kernels/image/random_crop_op.h" - -namespace mindspore { -namespace dataset { -class RandomCropWithBBoxOp : public RandomCropOp { - public: - // Constructor for RandomCropWithBBoxOp, with default value and passing to base class constructor - RandomCropWithBBoxOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, - int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, - int32_t pad_right = kDefPadRight, BorderType border_types = kDefBorderType, - bool pad_if_needed = kDefPadIfNeeded, uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, - uint8_t fill_b = kDefFillB) - : RandomCropOp(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, border_types, pad_if_needed, - fill_r, fill_g, fill_b) {} - - ~RandomCropWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { - out << "RandomCropWithBBoxOp: " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kRandomCropWithBBoxOp; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc deleted file mode 100644 index ae76e1bf59..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/kernels/image/random_horizontal_flip_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomHorizontalFlipOp::kDefProbability = 0.5; - -Status RandomHorizontalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (distribution_(rnd_)) { - return HorizontalFlip(input, output); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h deleted file mode 100644 index a0ea3822d3..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_op.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomHorizontalFlipOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomHorizontalFlipOp() override = default; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipOp &so) { - so.Print(out); - return out; - } - - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kRandomHorizontalFlipOp; } - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc deleted file mode 100644 index cf8a4640ff..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" -#include "dataset/core/cv_tensor.h" - -namespace mindspore { -namespace dataset { -const float RandomHorizontalFlipWithBBoxOp::kDefProbability = 0.5; - -Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - if (distribution_(rnd_)) { - // To test bounding boxes algorithm, create random bboxes from image dims - size_t num_of_boxes = input[1]->shape()[0]; // set to give number of bboxes - float img_center = (input[0]->shape()[1] / 2.); // get the center of the image - for (int i = 0; i < num_of_boxes; i++) { - float b_w = 0; // bounding box width - float min_x = 0; - // get the required items - RETURN_IF_NOT_OK(input[1]->GetItemAt(&min_x, {i, 0})); - RETURN_IF_NOT_OK(input[1]->GetItemAt(&b_w, {i, 2})); - // do the flip - float diff = img_center - min_x; // get distance from min_x to center - float refl_min_x = diff + img_center; // get reflection of min_x - float new_min_x = refl_min_x - b_w; // subtract from the reflected min_x to get the new one - RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 0}, new_min_x)); - } - (*output).resize(2); - // move input to output pointer of bounding boxes - (*output)[1] = std::move(input[1]); - // perform HorizontalFlip on the image - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); - return HorizontalFlip(std::static_pointer_cast(input_cv), &(*output)[0]); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h deleted file mode 100644 index 3480e2ac6b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
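The min_x update in RandomHorizontalFlipWithBBoxOp::Compute above is a reflection about the vertical center line: refl_min_x = 2 * img_center - min_x, and because the reflected left edge becomes the box's right edge, the new left corner is that reflection minus the box width. A worked standalone check with illustrative numbers:

#include <cstdio>

int main() {
  float img_w = 10.f;
  float img_center = img_w / 2.f;  // 5, matching input[0]->shape()[1] / 2
  float min_x = 1.f, b_w = 3.f;    // box spans x in [1, 4)

  float refl_min_x = (img_center - min_x) + img_center;  // 2*center - min_x = 9
  float new_min_x = refl_min_x - b_w;                    // 6, so the box now spans [6, 9)
  std::printf("new min_x = %g\n", new_min_x);
  return 0;
}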
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ - -#include -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomHorizontalFlipWithBBoxOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomHorizontalFlipWithBBoxOp() override = default; - - // Provide stream operator for displaying it - friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipWithBBoxOp &so) { - so.Print(out); - return out; - } - - void Print(std::ostream &out) const override { out << "RandomHorizontalFlipWithBBoxOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kRandomHorizontalFlipWithBBoxOp; } - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc deleted file mode 100644 index c14224a930..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_resize_op.h" - -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t RandomResizeOp::kDefTargetWidth = 0; - -Status RandomResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - // Randomly selects from the following four interpolation methods - // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area - interpolation_ = static_cast(distribution_(random_generator_)); - return ResizeOp::Compute(input, output); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h deleted file mode 100644 index 9e60867353..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomResizeOp : public ResizeOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefTargetWidth; - - explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { - random_generator_.seed(GetSeed()); - } - - ~RandomResizeOp() = default; - - // Description: A function that prints info about the node - void Print(std::ostream &out) const override { - out << "RandomResizeOp: " << ResizeOp::size1_ << " " << ResizeOp::size2_; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kRandomResizeOp; } - - private: - std::mt19937 random_generator_; - std::uniform_int_distribution distribution_{0, 3}; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc deleted file mode 100644 index de69c02e39..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
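RandomResizeOp leaves the geometry to ResizeOp and only randomizes interpolation_: the uniform draw over {0..3} is cast to the interpolation enum in the order the comment documents (bilinear, nearest-neighbour, bicubic, area). A small sketch of the selection; the Interp enum below is a stand-in, since the real InterpolationMode enumerators live in image_utils.h:

#include <cstdio>
#include <random>

// Stand-in enum; only kLinear and kNearestNeighbour are confirmed names
// elsewhere in this diff, the other two are assumptions for illustration.
enum class Interp { kLinear = 0, kNearestNeighbour = 1, kCubic = 2, kArea = 3 };

int main() {
  std::mt19937 gen(1234);
  std::uniform_int_distribution<int> dist(0, 3);
  Interp mode = static_cast<Interp>(dist(gen));  // re-drawn on every Compute() call
  std::printf("picked interpolation %d\n", static_cast<int>(mode));
  return 0;
}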
- */ - -#include "dataset/kernels/image/random_resize_with_bbox_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0; - -Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - // Randomly selects from the following four interpolation methods - // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area - interpolation_ = static_cast(distribution_(random_generator_)); - RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h deleted file mode 100644 index e5106f9cf5..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefTargetWidth; - explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { - random_generator_.seed(GetSeed()); - } - - ~RandomResizeWithBBoxOp() = default; - - // Description: A function that prints info about the node - void Print(std::ostream &out) const override { - out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kRandomResizeWithBBoxOp; } - - private: - std::mt19937 random_generator_; - std::uniform_int_distribution distribution_{0, 3}; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc deleted file mode 100644 index 65e024865b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/random_rotation_op.h" - -#include - -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomRotationOp::kDefCenterX = -1; -const float RandomRotationOp::kDefCenterY = -1; -const InterpolationMode RandomRotationOp::kDefInterpolation = InterpolationMode::kNearestNeighbour; -const bool RandomRotationOp::kDefExpand = false; -const uint8_t RandomRotationOp::kDefFillR = 0; -const uint8_t RandomRotationOp::kDefFillG = 0; -const uint8_t RandomRotationOp::kDefFillB = 0; - -// constructor -RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float center_x, float center_y, - InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, - uint8_t fill_b) - : degree_start_(start_degree), - degree_end_(end_degree), - center_x_(center_x), - center_y_(center_y), - interpolation_(interpolation), - expand_(expand), - fill_r_(fill_r), - fill_g_(fill_g), - fill_b_(fill_b) { - rnd_.seed(GetSeed()); -} - -// main function call for random rotation : Generate the random degrees -Status RandomRotationOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - float random_double = distribution_(rnd_); - // get the degree rotation range, mod by 360 because full rotation doesn't affect - // the way this op works (uniform distribution) - // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number - // Note: the range technically is greater than 360 degrees, but will be halved - float degree_range = (degree_end_ - degree_start_) / 2; - float mid = (degree_end_ + degree_start_) / 2; - float degree = mid + random_double * degree_range; - - return Rotate(input, output, center_x_, center_y_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_); -} -Status RandomRotationOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - int32_t outputH = -1, outputW = -1; - // if expand_, then we cannot know the shape. 
We need the input image to find the output shape --> set it to - // <-1,-1[,3]> - if (!expand_) { - outputH = inputs[0][0]; - outputW = inputs[0][1]; - } - TensorShape out = TensorShape{outputH, outputW}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h b/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h deleted file mode 100644 index 7ae65fe02b..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_rotation_op.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" - -namespace mindspore { -namespace dataset { -class RandomRotationOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefCenterX; - static const float kDefCenterY; - static const InterpolationMode kDefInterpolation; - static const bool kDefExpand; - static const uint8_t kDefFillR; - static const uint8_t kDefFillG; - static const uint8_t kDefFillB; - - // Constructor for RandomRotationOp - // @param startDegree starting range for random degree - // @param endDegree ending range for random degree - // @param centerX x coordinate for center of image rotation - // @param centerY y coordinate for center of image rotation - // @param interpolation DE interpolation mode for rotation - // @param expand option for the output image shape to change - // @param fill_r R value for the color to pad with - // @param fill_g G value for the color to pad with - // @param fill_b B value for the color to pad with - // @details the randomly chosen degree is uniformly distributed - // @details the output shape, if changed, will contain the entire rotated image - // @note maybe using unsigned long int isn't the best here according to our coding rules - RandomRotationOp(float start_degree, float end_degree, float center_x = kDefCenterX, float center_y = kDefCenterY, - InterpolationMode interpolation = kDefInterpolation, bool expand = kDefExpand, - uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); - - ~RandomRotationOp() override = default; - - // Print function for RandomRotation - // @param out output stream to print to - void Print(std::ostream &out) const override { out << "RandomRotationOp: "; } - - // Overrides the base class compute function - // Calls the rotate function in ImageUtils, this function takes an input tensor - 
// and transforms its data using openCV, the output memory is manipulated to contain the result - // @return Status - The error code return - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kRandomRotationOp; } - - private: - float degree_start_; - float degree_end_; - float center_x_; - float center_y_; - InterpolationMode interpolation_; - bool expand_; - uint8_t fill_r_; - uint8_t fill_g_; - uint8_t fill_b_; - std::uniform_real_distribution distribution_{-1.0, 1.0}; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc deleted file mode 100644 index 096923a9ec..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "dataset/kernels/image/random_vertical_flip_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const float RandomVerticalFlipOp::kDefProbability = 0.5; - -Status RandomVerticalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - if (distribution_(rnd_)) { - return VerticalFlip(input, output); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h deleted file mode 100644 index 3664ed7d3a..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_op.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
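The degree sampling in RandomRotationOp::Compute above maps a uniform u in [-1, 1] onto [degree_start_, degree_end_] via the interval midpoint and half-range: degree = mid + u * (end - start) / 2. A worked standalone check of that mapping:

#include <cstdio>
#include <random>

int main() {
  float start = -30.f, end = 90.f;
  float half_range = (end - start) / 2.f;  // 60
  float mid = (end + start) / 2.f;         // 30

  std::mt19937 rnd(7);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (int i = 0; i < 3; ++i) {
    float degree = mid + dist(rnd) * half_range;  // always lands inside [-30, 90]
    std::printf("degree = %.2f\n", degree);
  }
  return 0;
}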
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -class RandomVerticalFlipOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - - explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomVerticalFlipOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomVerticalFlipOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kRandomVerticalFlipOp; } - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc deleted file mode 100644 index 7e897536e8..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "dataset/util/status.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" - -namespace mindspore { -namespace dataset { -const float RandomVerticalFlipWithBBoxOp::kDefProbability = 0.5; -Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - - if (distribution_(rnd_)) { - dsize_t imHeight = input[0]->shape()[0]; - size_t boxCount = input[1]->shape()[0]; // number of rows in tensor - - // one time allocation -> updated in the loop - // type defined based on VOC test dataset - for (int i = 0; i < boxCount; i++) { - float boxCorner_y = 0.0, boxHeight = 0.0; - float newBoxCorner_y = 0.0; - RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxCorner_y, {i, 1})); // get min y of bbox - RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxHeight, {i, 3})); // get height of bbox - - // subtract (curCorner + height) from (max) for new Corner position - newBoxCorner_y = (imHeight - 1.0) - ((boxCorner_y + boxHeight) - 1.0); - RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); - } - - output->resize(2); - (*output)[1] = std::move(input[1]); - - return VerticalFlip(input[0], &(*output)[0]); - } - *output = input; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h deleted file mode 100644 index 15a96fe749..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/random_vertical_flip_with_bbox_op.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ -#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -class RandomVerticalFlipWithBBoxOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const float kDefProbability; - // Constructor for RandomVerticalFlipWithBBoxOp - // @param probability: Probability of image flipping, 0.5 by default - explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { - rnd_.seed(GetSeed()); - } - - ~RandomVerticalFlipWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { out << "RandomVerticalFlipWithBBoxOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kRandomVerticalFlipWithBBoxOp; } - - private: - std::mt19937 rnd_; - std::bernoulli_distribution distribution_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc b/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc deleted file mode 100644 index fd1807991c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/rescale_op.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/rescale_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status RescaleOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { - IO_CHECK(input, output); - return Rescale(input, output, rescale_, shift_); -} -Status RescaleOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = DataType(DataType::DE_FLOAT32); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/rescale_op.h b/mindspore/ccsrc/dataset/kernels/image/rescale_op.h deleted file mode 100644 index b91226a9f8..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/rescale_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESCALE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESCALE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class RescaleOp : public TensorOp { - public: - RescaleOp(float rescale_ratio, float shift_ratio) : rescale_(rescale_ratio), shift_(shift_ratio) {} - - ~RescaleOp() override = default; - - void Print(std::ostream &out) const override { - out << "RescaleOp: shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; - } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kRescaleOp; } - - private: - float rescale_; - float shift_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_IMAGE_RESCALE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc deleted file mode 100644 index 658caac6a5..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/resize_bilinear_op.h" -#include - -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t ResizeBilinearOp::kDefWidth = 0; - -void ResizeBilinearOp::Print(std::ostream &out) const { out << "ResizeBilinearOp: "; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h deleted file mode 100644 index c14beda067..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_bilinear_op.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
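RescaleOp above is a per-pixel affine map, output = input * rescale_ + shift_, with the output type forced to float32 by OutputType. The common configuration (an assumption here, not mandated by the op) is rescale 1/255 and shift 0, which brings uint8 images into [0, 1]:

#include <cstdint>
#include <cstdio>

int main() {
  float rescale = 1.0f / 255.0f, shift = 0.0f;
  uint8_t pixels[4] = {0, 64, 128, 255};
  for (uint8_t p : pixels) {
    float out = p * rescale + shift;  // what Rescale() applies element-wise
    std::printf("%3u -> %.4f\n", p, out);
  }
  return 0;
}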
- */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ - -#include -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class ResizeBilinearOp : public ResizeOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - - // Name: constructor - // Resizes the image to the output specified size using Bilinear interpolation. - // If only one value is provided, it will resize the smaller dimension and maintain - // the aspect ratio. - // @param size1: the first size of output. If only this parameter is provided - // the smaller dimension will be resized to this and then the other dimension changes - // such that the aspect ratio is maintained. - // @param size2: the second size of output. If this is also provided, the output size - // will be (size1, size2) - explicit ResizeBilinearOp(int32_t size1, int32_t size2 = kDefWidth) - : ResizeOp(size1, size2, ResizeOp::kDefInterpolation) {} - - // Name: Destructor - // Description: Destructor - ~ResizeBilinearOp() = default; - - // Name: Print() - // Description: A function that prints info about the node - void Print(std::ostream &out) const override; - - std::string Name() const override { return kResizeBilinearOp; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_op.cc deleted file mode 100644 index 7c0252188e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_op.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "dataset/kernels/image/resize_op.h" - -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t ResizeOp::kDefWidth = 0; -const InterpolationMode ResizeOp::kDefInterpolation = InterpolationMode::kLinear; - -Status ResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape size " + std::to_string(input->shape().Size()) + - " of input tensor is invalid"); - int32_t output_h, output_w = 0; - int32_t input_h = static_cast(input->shape()[0]); - int32_t input_w = static_cast(input->shape()[1]); - if (size2_ == 0) { - if (input_h < input_w) { - CHECK_FAIL_RETURN_UNEXPECTED(input_h != 0, "The input height is 0"); - output_h = size1_; - output_w = static_cast(std::lround(static_cast(input_w) / input_h * output_h)); - } else { - CHECK_FAIL_RETURN_UNEXPECTED(input_w != 0, "The input width is 0"); - output_w = size1_; - output_h = static_cast(std::lround(static_cast(input_h) / input_w * output_w)); - } - } else { - output_h = size1_; - output_w = size2_; - } - return Resize(input, output, output_h, output_w, 0, 0, interpolation_); -} - -Status ResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - int32_t outputH = -1, outputW = -1; - // if size2_ == 0, then we cannot know the shape. We need the input image to find the output shape --> set it to - // <-1,-1[,3]> - if (size2_ != 0) { - outputH = size1_; - outputW = size2_; - } - TensorShape out = TensorShape{outputH, outputW}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_op.h deleted file mode 100644 index efbe9dab06..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_op.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_ -#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_ - -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class ResizeOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefWidth; - static const InterpolationMode kDefInterpolation; - - // Resizes the image to the output specified size. If only one value is provided, - // the it will resize the smaller size and maintains the aspect ratio. 
- // @param size1: the first size of output. If only this parameter is provided - // the smaller dimension will be resized to this and then the other dimension changes - // such that the aspect ratio is maintained. - // @param size2: the second size of output. If this is also provided, the output size - // will be (size1, size2) - // @param InterpolationMode: the interpolation mode being used. - explicit ResizeOp(int32_t size1, int32_t size2 = kDefWidth, InterpolationMode mInterpolation = kDefInterpolation) - : size1_(size1), size2_(size2), interpolation_(mInterpolation) {} - - ResizeOp(const ResizeOp &rhs) = default; - - ResizeOp(ResizeOp &&rhs) = default; - - ~ResizeOp() override = default; - - void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kResizeOp; } - - protected: - int32_t size1_; - int32_t size2_; - InterpolationMode interpolation_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc deleted file mode 100644 index 8a633d5678..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
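When size2_ is 0, ResizeOp::Compute above resizes the shorter side to size1_ and scales the other side to preserve the aspect ratio, rounding with std::lround; OutputShape mirrors this by reporting <-1, -1> whenever the result depends on the input image. Worked numbers for a 480x640 input with size1 = 256:

#include <cmath>
#include <cstdio>

int main() {
  int input_h = 480, input_w = 640, size1 = 256;
  int output_h, output_w;
  if (input_h < input_w) {  // shorter side is the height
    output_h = size1;
    output_w = static_cast<int>(std::lround(static_cast<float>(input_w) / input_h * output_h));
  } else {
    output_w = size1;
    output_h = static_cast<int>(std::lround(static_cast<float>(input_h) / input_w * output_w));
  }
  std::printf("%dx%d -> %dx%d\n", input_h, input_w, output_h, output_w);  // 480x640 -> 256x341
  return 0;
}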
- */ - -#include "dataset/kernels/image/resize_with_bbox_op.h" -#include -#include -#include "dataset/kernels/image/resize_op.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/core/pybind_support.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - BOUNDING_BOX_CHECK(input); - - int32_t input_h = input[0]->shape()[0]; - int32_t input_w = input[0]->shape()[1]; - - output->resize(2); - (*output)[1] = std::move(input[1]); // move boxes over to output - - std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input[0])); - - RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast(input_cv), &(*output)[0])); - - int32_t output_h = (*output)[0]->shape()[0]; // output height if ResizeWithBBox - int32_t output_w = (*output)[0]->shape()[1]; // output width if ResizeWithBBox - - size_t bboxCount = input[1]->shape()[0]; // number of rows in bbox tensor - RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h deleted file mode 100644 index 2fa3e711b8..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H -#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H - -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/kernels/image/resize_op.h" - -namespace mindspore { -namespace dataset { -class ResizeWithBBoxOp : public ResizeOp { - public: - // Constructor for ResizeWithBBoxOp, with default value and passing to base class constructor - explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth, - InterpolationMode mInterpolation = kDefInterpolation) - : ResizeOp(size_1, size_2, mInterpolation) {} - - ~ResizeWithBBoxOp() override = default; - - void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kResizeWithBBoxOp; } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc deleted file mode 100644 index 7889b3b157..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include -#include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/util/random.h" - -namespace mindspore { -namespace dataset { -const int UniformAugOp::kDefNumOps = 2; - -UniformAugOp::UniformAugOp(std::vector> op_list, int32_t num_ops) - : tensor_op_list_(op_list), num_ops_(num_ops) { - rnd_.seed(GetSeed()); -} - -// compute method to apply uniformly random selected augmentations from a list -Status UniformAugOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - - // randomly select ops to be applied - std::vector> selected_tensor_ops; - std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); - - bool first = true; - for (const auto &tensor_op : selected_tensor_ops) { - // Do NOT apply the op, if second random generator returned zero - if (std::uniform_int_distribution(0, 1)(rnd_)) { - continue; - } - // apply C++ ops (note: python OPs are not accepted) - if (first) { - RETURN_IF_NOT_OK(tensor_op->Compute(input, output)); - first = false; - } else { - RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output)); - } - } - - // The case where no tensor op is applied. 
- if (output->empty()) { - *output = input; - } - - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h deleted file mode 100644 index aa96b9f33c..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ -#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class UniformAugOp : public TensorOp { - public: - // Default number of Operations to be applied - static const int kDefNumOps; - - // Constructor for UniformAugOp - // @param std::vector<std::shared_ptr<TensorOp>> op_list: list of candidate C++ operations - // @param int32_t num_ops: number of augmentation operations to be applied - UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops); - - // Destructor - ~UniformAugOp() override = default; - - void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } - - // Overrides the base class compute function - // @return Status - The error code return - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kUniformAugOp; } - - private: - int32_t num_ops_; - std::vector<std::shared_ptr<TensorOp>> tensor_op_list_; - std::mt19937 rnd_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/no_op.h b/mindspore/ccsrc/dataset/kernels/no_op.h deleted file mode 100644 index 83d0d4baa7..0000000000 --- a/mindspore/ccsrc/dataset/kernels/no_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
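UniformAugOp's selection above combines std::sample, a without-replacement draw of num_ops_ candidates that keeps the candidates' relative order, with an independent 50% coin flip per selected op; fewer than num_ops_ ops may actually run, possibly none, which the output->empty() fallback handles. The sampling half in isolation (C++17):

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <random>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ops = {"crop", "flip", "rotate", "resize"};
  std::vector<std::string> picked;
  std::mt19937 rnd(99);
  // Draw 2 distinct candidates without replacement, order preserved.
  std::sample(ops.begin(), ops.end(), std::back_inserter(picked), 2, rnd);
  for (const auto &op : picked) std::printf("selected: %s\n", op.c_str());
  return 0;
}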
- */ -#ifndef DATASET_KERNELS_NO_OP_H_ -#define DATASET_KERNELS_NO_OP_H_ - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class NoOp : public TensorOp { - public: - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override { - *output = input; - return Status::OK(); - } - - void Print(std::ostream &out) const override { out << "NoOp"; }; - - std::string Name() const override { return kNoOp; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_NO_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/dataset/kernels/py_func_op.cc deleted file mode 100644 index 0a6a1452b5..0000000000 --- a/mindspore/ccsrc/dataset/kernels/py_func_op.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/py_func_op.h" - -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - Status ret = Status(StatusCode::kOK, "PyFunc Call Succeed"); - { - // Acquire Python GIL - py::gil_scoped_acquire gil_acquire; - if (Py_IsInitialized() == 0) { - ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); - goto ComputeReturn; - } - try { - // Transform input tensor vector into numpy array vector - py::tuple input_args(input.size()); - for (size_t i = 0; i < input.size(); i++) { - py::array new_data; - RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); - // possible memcpy here - input_args[i] = new_data; - } - // Invoke python function - py::object ret_py_obj = this->py_func_ptr_(*input_args); - // Process the return value - if (py::isinstance(ret_py_obj)) { - // In case of a n-1 mapping, the return value will be a numpy array - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast())); - output->push_back(out); - } else if (py::isinstance(ret_py_obj)) { - // In case of a n-m mapping, the return value will be a tuple of numpy arrays - py::tuple ret_py_tuple = ret_py_obj.cast(); - // Iterate over two containers simultaneously for memory copy - for (size_t i = 0; i < ret_py_tuple.size(); i++) { - py::object ret_py_ele = ret_py_tuple[i]; - if (!py::isinstance(ret_py_ele)) { - goto ShapeMisMatch; - } - std::shared_ptr out; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast())); - output->push_back(out); - } - } else { - goto ShapeMisMatch; - } - } catch (const py::error_already_set &e) { - ret = Status(StatusCode::kPyFuncException, e.what()); - } - } - -ComputeReturn: - return ret; - -ShapeMisMatch: - ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple"); - goto ComputeReturn; -} -} // 
diff --git a/mindspore/ccsrc/dataset/kernels/py_func_op.h b/mindspore/ccsrc/dataset/kernels/py_func_op.h deleted file mode 100644 index 473e75ec97..0000000000 --- a/mindspore/ccsrc/dataset/kernels/py_func_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_KERNELS_PY_FUNC_OP_H_ -#define DATASET_KERNELS_PY_FUNC_OP_H_ - -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" - -namespace mindspore { -namespace dataset { -class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp { - public: - explicit PyFuncOp(py::function func) : py_func_ptr_(std::move(func)) {} - - ~PyFuncOp() override = default; - - uint32_t NumInput() override { return 0; } - uint32_t NumOutput() override { return 0; } - - // Compute function for n-n mapping. - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kPyFuncOp; } - - private: - py::function py_func_ptr_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_PY_FUNC_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/tensor_op.cc b/mindspore/ccsrc/dataset/kernels/tensor_op.cc deleted file mode 100644 index 92aef8dc9e..0000000000 --- a/mindspore/ccsrc/dataset/kernels/tensor_op.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/tensor_op.h" -#include <iostream> -#include <memory> -#include <string> -#include <vector> - -namespace mindspore { -namespace dataset { -// Name: Compute() -// Description: This Compute() takes 1 Tensor and produces 1 Tensor. -// The derived class should override this function; otherwise it is an error. -Status TensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { - IO_CHECK(input, output); - if (!OneToOne()) { - return Status(StatusCode::kUnexpectedError, "Wrong Compute() function is called. This is not 1-1 TensorOp."); - } else { - return Status(StatusCode::kUnexpectedError, - "Is this TensorOp 1-1? If yes, please implement this Compute() in the derived class."); - } -} - -// Name: Compute() -// Description: This Compute() takes multiple Tensors from different columns and produces multiple Tensors too. -// The derived class should override this function; otherwise it is an error.
-Status TensorOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - if (OneToOne()) { - output->resize(1); - return Compute(input[0], &(*output)[0]); - } - - return Status(StatusCode::kUnexpectedError, - "Is this TensorOp one-to-one? If not, please implement this Compute() in the derived class."); -} - -void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; } - -Status TensorOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { - if (inputs.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, - "The size of the input argument vector does not match the number of inputs"); - outputs = inputs; - return Status::OK(); -} - -Status TensorOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { - if (inputs.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, - "The size of the input argument vector does not match the number of inputs"); - outputs = inputs; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/tensor_op.h b/mindspore/ccsrc/dataset/kernels/tensor_op.h deleted file mode 100644 index 444919b78d..0000000000 --- a/mindspore/ccsrc/dataset/kernels/tensor_op.h +++ /dev/null @@ -1,212 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
 - */ -#ifndef DATASET_KERNELS_TENSOR_OP_H_ -#define DATASET_KERNELS_TENSOR_OP_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_row.h" -#include "dataset/util/status.h" - -#define IO_CHECK(input, output) \ - do { \ - if (input == nullptr || output == nullptr) { \ - RETURN_STATUS_UNEXPECTED("input or output is null."); \ - } \ - } while (false) - -#define IO_CHECK_VECTOR(input, output) \ - do { \ - if (output == nullptr) { \ - RETURN_STATUS_UNEXPECTED("output is null."); \ - } \ - for (auto &_i : input) { \ - if (_i == nullptr) { \ - RETURN_STATUS_UNEXPECTED("input is null."); \ - } \ - } \ - } while (false) - -#define BOUNDING_BOX_CHECK(input) \ - do { \ - if (input.size() != 2) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Requires Image and Bounding Boxes, likely missed bounding boxes."); \ - } \ - if (input[1]->shape().Size() < 2) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Bounding boxes shape should have at least two dimensions."); \ - } \ - uint32_t num_of_features = input[1]->shape()[1]; \ - if (num_of_features < 4) { \ - return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \ - "Bounding boxes should have at least 4 features."); \ - } \ - uint32_t num_of_boxes = input[1]->shape()[0]; \ - uint32_t img_h = input[0]->shape()[0]; \ - uint32_t img_w = input[0]->shape()[1]; \ - for (uint32_t i = 0; i < num_of_boxes; i++) { \ - float min_x = 0.0, min_y = 0.0, b_w = 0.0, b_h = 0.0; \ - bool passing_data_fetch = true; \ - passing_data_fetch &= input[1]->GetItemAt<float>(&min_x, {i, 0}).IsOk(); \ - passing_data_fetch &= input[1]->GetItemAt<float>(&min_y, {i, 1}).IsOk(); \ - passing_data_fetch &= input[1]->GetItemAt<float>(&b_w, {i, 2}).IsOk(); \ - passing_data_fetch &= input[1]->GetItemAt<float>(&b_h, {i, 3}).IsOk(); \ - if (!passing_data_fetch) { \ - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, \ - "Fetching BBox values failed in BOUNDING_BOX_CHECK."); \ - } \ - if ((min_x + b_w > img_w) || (min_y + b_h > img_h)) { \ - return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ - "At least one of the bounding boxes is out of bounds of the image."); \ - } \ - if (static_cast<int>(min_x) < 0 || static_cast<int>(min_y) < 0) { \ - return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \ - "At least one of the bounding boxes has negative min_x or min_y."); \ - } \ - } \ - } while (false) - -namespace mindspore { -namespace dataset { - -// image -constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; -constexpr char kDecodeOp[] = "DecodeOp"; -constexpr char kCenterCropOp[] = "CenterCropOp"; -constexpr char kCutOutOp[] = "CutOutOp"; -constexpr char kHwcToChwOp[] = "HwcToChwOp"; -constexpr char kNormalizeOp[] = "NormalizeOp"; -constexpr char kPadOp[] = "PadOp"; -constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; -constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp"; -constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp"; -constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp"; -constexpr char kRandomCropOp[] = "RandomCropOp"; -constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp"; -constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp"; -constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp"; -constexpr char kRandomResizeOp[] = "RandomResizeOp"; -constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp"; -constexpr char kRandomRotationOp[] = "RandomRotationOp"; -constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; -constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; -constexpr char kRescaleOp[] = "RescaleOp"; -constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; -constexpr char kResizeOp[] = "ResizeOp"; -constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; -constexpr char kUniformAugOp[] = "UniformAugOp"; - -// text -constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; -constexpr char kBertTokenizerOp[] = "BertTokenizerOp"; -constexpr char kCaseFoldOp[] = "CaseFoldOp"; -constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; -constexpr char kLookupOp[] = "LookupOp"; -constexpr char kNgramOp[] = "NgramOp"; -constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; -constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; -constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; -constexpr char kToNumberOp[] = "ToNumberOp"; -constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp"; -constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp"; -constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp"; -constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp"; -constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp"; - -// data -constexpr char kConcatenateOp[] = "kConcatenateOp"; -constexpr char kDuplicateOp[] = "DuplicateOp"; -constexpr char kFillOp[] = "FillOp"; -constexpr char kMaskOp[] = "MaskOp"; -constexpr char kOneHotOp[] = "OneHotOp"; -constexpr char kPadEndOp[] = "PadEndOp"; -constexpr char kSliceOp[] = "SliceOp"; -constexpr char kToFloat16Op[] = "ToFloat16Op"; -constexpr char kTypeCastOp[] = "TypeCastOp"; - -// other -constexpr char kPyFuncOp[] = "PyFuncOp"; -constexpr char kNoOp[] = "NoOp"; - -// A class that does a computation on a Tensor -class TensorOp { - public: - TensorOp() = default; - - virtual ~TensorOp() = default; - - // A function that prints info about the tensor operation - // @param out - virtual void Print(std::ostream &out) const; - - // Provide stream operator for displaying it - // @param output stream - // @param so the TensorOp object to be printed - // @return output stream - friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) { - so.Print(out); - return out; - } - - // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp - // @param input shares the ownership of the Tensor (increase the ref count). - // @param output the address to a shared_ptr where the result will be placed. - // @return Status - virtual Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output); - - // Perform an operation on Tensors from multiple columns, and produce multiple Tensors. - // This is for m-to-n column MapOp. - // @param input is a vector of shared_ptr to Tensor (pass by const reference). - // @param output is the address to an empty vector of shared_ptr to Tensor. - // @return Status - virtual Status Compute(const TensorRow &input, TensorRow *output); - - // Returns true if the TensorOp takes one input and returns one output. - // @return true/false - bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } - - // Function to determine the number of inputs the TensorOp can take. 0: means undefined. - // @return uint32_t - virtual uint32_t NumInput() { return 1; } - - // Function to determine the number of outputs the TensorOp generates. 0: means undefined. - // @return uint32_t - virtual uint32_t NumOutput() { return 1; } - - // Function to determine the shapes of the output tensor given the input tensors' shapes. - // If a subclass did not override this function, it means that the shape does not change. - // @param inputs in: vector of the shapes of the input tensors. - // @param outputs out: vector of the shapes of the output tensors to be filled. - // @return Status - virtual Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs); - - // Function to determine the types of the output tensor given the input tensor's types. - // If a subclass did not override this function, it means that the type does not change. - // @param inputs in: vector of the types of the input tensors. - // @param outputs out: vector of the types of the output tensors to be filled. - // @return Status - virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs); - - virtual std::string Name() const = 0; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_TENSOR_OP_H_
diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc deleted file mode 100644 index c0217b2083..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.cc +++ /dev/null @@ -1,173 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include -#include -#include -#include -#include -#include - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { - -const bool BasicTokenizerOp::kDefLowerCase = false; -const bool BasicTokenizerOp::kDefKeepWhitespace = false; -const NormalizeForm BasicTokenizerOp::kDefNormalizationForm = NormalizeForm::kNone; -const bool BasicTokenizerOp::kDefPreserveUnusedToken = true; -const bool BasicTokenizerOp::kDefWithOffsets = false; -const char BasicTokenizerOp::kCommonPattern[] = - "[!-/]" - "|[:-@]" - "|[\\[-`]" - "|[{-~]" - "|[\\p{P}]" - "|[\\x{4E00}-\\x{9FFF}]" - "|[\\x{3400}-\\x{4DBF}]" - "|[\\x{20000}-\\x{2A6DF}]" - "|[\\x{2A700}-\\x{2B73F}]" - "|[\\x{2B740}-\\x{2B81F}]" - "|[\\x{2B820}-\\x{2CEAF}]" - "|[\\x{F900}-\\x{FAFF}]" - "|[\\x{2F800}-\\x{2FA1F}]"; -const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|"; -const std::unordered_set BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"}; - -BasicTokenizerOp::BasicTokenizerOp(const bool &lower_case, const bool &keep_whitespace, - const NormalizeForm &normalization_form, const bool &preserve_unused_token, - const bool &with_offsets) - : lower_case_(lower_case), - keep_whitespace_(keep_whitespace), - preserve_unused_token_(preserve_unused_token), - with_offsets_(with_offsets), - case_fold_(std::make_unique()), - nfd_normalize_(std::make_unique(NormalizeForm::kNfd)), - normalization_form_(normalization_form), - common_normalize_(std::make_unique(normalization_form)), - replace_accent_chars_(std::make_unique("\\p{Mn}", "")), - replace_control_chars_(std::make_unique("\\p{Cc}|\\p{Cf}", " ")) { - std::string delim_pattern = std::string("\\s+|") + kCommonPattern; - std::string keep_delim_pattern; - if (keep_whitespace_) { - keep_delim_pattern = delim_pattern; - } else { - keep_delim_pattern = kCommonPattern; - } - if (preserve_unused_token_) { - keep_delim_pattern = kUnusedPattern + keep_delim_pattern; - delim_pattern = kUnusedPattern + delim_pattern; - } - regex_tokenizer_ = std::make_unique(delim_pattern, keep_delim_pattern, with_offsets_); -} - -Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text, - const std::unordered_set &unused_words, - std::string *outupt) { - icu::ErrorCode error; - const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); - outupt->clear(); - - // 1. get start and end offsets of not case fold strs - std::queue> offsets; // offsets of not used words - int start = -1; - int len = 0; - for (int i = 0; i < text.length(); i++) { - if (text[i] == '[') { - start = i; - ++len; - } else if (text[i] == ']' && start >= 0) { - ++len; - std::string word(text.substr(start, len)); - if (unused_words.find(word) != unused_words.end()) { - offsets.push(std::make_pair(start, start + len - 1)); - } - start = -1; - len = 0; - } else if (start >= 0) { - ++len; - } - } - - // 2. 
Do not apply case fold on `unused_words` - start = 0; - for (int i = 0; i < text.length();) { - std::string_view process_text; - std::string preserve_token; - if (offsets.empty()) { - i = text.length(); - process_text = text.substr(start, i - start); - } else { - preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1); - process_text = text.substr(start, offsets.front().first - start); - i = offsets.front().second + 1; - offsets.pop(); - } - std::string temp; - icu::StringByteSink sink(&temp); - nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr, error); - *outupt += temp + preserve_token; - } - return Status::OK(); -} - -Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr &input, - std::shared_ptr *output) { - IO_CHECK(input, output); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++])); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} - -Status BasicTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::shared_ptr cur_input; - std::shared_ptr processed_tensor; - if (lower_case_) { - if (!preserve_unused_token_) { - // to lower case - RETURN_IF_NOT_OK(case_fold_->Compute(input[0], &processed_tensor)); - } else { - // to lower case except words in kUnusedWords - RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input[0], &processed_tensor)); - } - cur_input = processed_tensor; - // strip accent characters - RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor)); - cur_input = processed_tensor; - RETURN_IF_NOT_OK(replace_accent_chars_->Compute(cur_input, &processed_tensor)); - } else { - RETURN_IF_NOT_OK(common_normalize_->Compute(input[0], &processed_tensor)); - } - // strip control characters - cur_input = processed_tensor; - RETURN_IF_NOT_OK(replace_control_chars_->Compute(cur_input, &processed_tensor)); - return regex_tokenizer_->Compute(TensorRow(0, {std::move(processed_tensor)}), output); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h deleted file mode 100644 index 96bf3e1ae2..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/basic_tokenizer_op.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/text/kernels/case_fold_op.h" -#include "dataset/text/kernels/normalize_utf8_op.h" -#include "dataset/text/kernels/regex_replace_op.h" -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class BasicTokenizerOp : public TensorOp { - public: - static const bool kDefLowerCase; - static const bool kDefKeepWhitespace; - static const NormalizeForm kDefNormalizationForm; - static const bool kDefPreserveUnusedToken; - static const bool kDefWithOffsets; - - explicit BasicTokenizerOp(const bool &lower_case = kDefLowerCase, const bool &keep_whitespace = kDefKeepWhitespace, - const NormalizeForm &normalization_form = kDefNormalizationForm, - const bool &preserve_unused_token = kDefPreserveUnusedToken, - const bool &with_offsets = kDefWithOffsets); - - ~BasicTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "BasicTokenizerOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - protected: - Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set &unused_words, - std::string *outupt); - Status CaseFoldWithoutUnusedWords(const std::shared_ptr &input, std::shared_ptr *output); - - std::string Name() const override { return kBasicTokenizerOp; } - - private: - static const char kCommonPattern[]; - static const char kUnusedPattern[]; - static const std::unordered_set kUnusedWords; - bool with_offsets_; - bool lower_case_; - bool keep_whitespace_; - NormalizeForm normalization_form_; - bool preserve_unused_token_; - std::unique_ptr case_fold_; - std::unique_ptr nfd_normalize_; - std::unique_ptr common_normalize_; - std::unique_ptr replace_accent_chars_; - std::unique_ptr replace_control_chars_; - std::unique_ptr regex_tokenizer_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc deleted file mode 100644 index 3e7f1251ed..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.cc +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/text/kernels/bert_tokenizer_op.h" -namespace mindspore { -namespace dataset { -Status BertTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - TensorRow basic_tensor; - RETURN_IF_NOT_OK(basic_tokenizer_.Compute(input, &basic_tensor)); - RETURN_IF_NOT_OK(wordpiece_tokenizer_.Compute(basic_tensor, output)); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h deleted file mode 100644 index b3ae1d2ab1..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/bert_tokenizer_op.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/text/kernels/basic_tokenizer_op.h" -#include "dataset/text/kernels/wordpiece_tokenizer_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class BertTokenizerOp : public TensorOp { - public: - explicit BertTokenizerOp(const std::shared_ptr &vocab, - const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator, - const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken, - const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken, - const bool &lower_case = BasicTokenizerOp::kDefLowerCase, - const bool &keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace, - const NormalizeForm &normalization_form = BasicTokenizerOp::kDefNormalizationForm, - const bool &preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken, - const bool &with_offsets = WordpieceTokenizerOp::kDefWithOffsets) - : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets), - basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token, with_offsets) {} - - ~BertTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "BertTokenizerOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kBertTokenizerOp; } - - private: - WordpieceTokenizerOp wordpiece_tokenizer_; - BasicTokenizerOp basic_tokenizer_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc deleted file mode 100644 index d935608efd..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file 
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/case_fold_op.h" -#include -#include -#include -#include -#include - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { - -Status CaseFoldOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - icu::ErrorCode error; - const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed."); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - icu::StringByteSink sink(&strs[i++]); - nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h b/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h deleted file mode 100644 index 87fe05ae8d..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/case_fold_op.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ -#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class CaseFoldOp : public TensorOp { - public: - CaseFoldOp() {} - - ~CaseFoldOp() override = default; - - void Print(std::ostream &out) const override { out << "CaseFoldOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kCaseFoldOp; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc deleted file mode 100644 index b221e9cafd..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.cc +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/jieba_tokenizer_op.h" - -#include -#include -#include -#include "dataset/util/path.h" - -namespace mindspore { -namespace dataset { - -const bool JiebaTokenizerOp::kDefWithOffsets = false; - -JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, const JiebaMode &mode, - const bool &with_offsets) - : jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path), with_offsets_(with_offsets) { - jieba_parser_ = std::make_unique(mp_dict_path_, hmm_model_path_, ""); -} - -Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - RETURN_UNEXPECTED_IF_NULL(jieba_parser_); - - if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("the input tensor should be scalar string tensor"); - } - - std::string_view sentence_v; - RETURN_IF_NOT_OK(input[0]->GetItemAt(&sentence_v, {})); - std::string sentence{sentence_v}; - std::vector words; - std::vector offsets_start, offsets_limit; - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - if (sentence == "") { - words.push_back(""); - } else { - std::vector tmp; - if (jieba_mode_ == JiebaMode::kMp) { - std::unique_ptr mp_seg = std::make_unique(jieba_parser_->GetDictTrie()); - mp_seg->Cut(sentence, tmp, MAX_WORD_LENGTH); - } else if (jieba_mode_ == JiebaMode::kHmm) { - std::unique_ptr hmm_seg = - std::make_unique(jieba_parser_->GetHMMModel()); - hmm_seg->Cut(sentence, tmp); - } else { // Mix - std::unique_ptr mix_seg = - std::make_unique(jieba_parser_->GetDictTrie(), jieba_parser_->GetHMMModel()); - mix_seg->Cut(sentence, tmp, true); - } - GetStringsFromWords(tmp, words); - for (auto item : tmp) { - offsets_start.push_back(static_cast(item.offset)); - offsets_limit.push_back(static_cast(item.offset + 
item.word.length())); - } - } - token_tensor = std::make_shared<Tensor>(words, TensorShape({(dsize_t)words.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast<unsigned char *>(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast<unsigned char *>(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} - -Status JiebaTokenizerOp::AddWord(const std::string &word, int freq) { - RETURN_UNEXPECTED_IF_NULL(jieba_parser_); - if (jieba_parser_->InsertUserWord(word, freq, "") == false) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "add word error"); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore
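// Editor's note: hedged usage sketch for the tokenizer deleted above; the model
// paths are placeholders, JiebaMode is declared in the header that follows, and
// `row_in` must hold one scalar string tensor per the checks in Compute().
JiebaTokenizerOp op("/path/to/hmm_model.utf8", "/path/to/jieba.dict.utf8", JiebaMode::kMix);
Status s = op.AddWord("MindSpore", /*freq=*/10);  // extend the user dictionary
TensorRow row_in, row_out;
s = op.Compute(row_in, &row_out);  // row_out: tokens, plus offsets when with_offsets = true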
diff --git a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h deleted file mode 100644 index 09123d0e34..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/jieba_tokenizer_op.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_ -#define DATASET_ENGINE_TEXT_JIEBA_OP_H_ - -#include <memory> -#include <string> - -#include "cppjieba/Jieba.hpp" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 }; - -class JiebaTokenizerOp : public TensorOp { - public: - // default constant for Jieba MPSegment algorithm. - static constexpr size_t MAX_WORD_LENGTH = 512; - // default constant for whether Jieba outputs the offsets tensor. - static const bool kDefWithOffsets; - // Constructor for JiebaTokenizerOp. - // @param hmm_path HMM model file. - // @param mp_path MP model file. - // @mode tokenization mode [Default "MIX"]: "MP" mode tokenizes with the MPSegment algorithm, "HMM" mode - // tokenizes with the Hidden Markov Model Segment algorithm, and "MIX" mode tokenizes with a mix of the - // MPSegment and HMMSegment algorithms. - // @with_offsets whether to output the offsets tensor. - JiebaTokenizerOp(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode = JiebaMode::kMix, - const bool &with_offsets = kDefWithOffsets); - ~JiebaTokenizerOp() override = default; - - void Print(std::ostream &out) const override { - out << "JiebaTokenizerOp: " << jieba_mode_ << " hmm_model_path_ " << hmm_model_path_ << " mp_dict_path_ " - << mp_dict_path_; - } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - // @word the word to be added to the JiebaTokenizer. - // @freq [Default 0] the frequency of the word to be added. - // @tag [Default ""] the tag of the word to be added. - Status AddWord(const std::string &word, int freq = 0); - - std::string Name() const override { return kJiebaTokenizerOp; } - - protected: - std::string hmm_model_path_; - std::string mp_dict_path_; - std::unique_ptr<cppjieba::Jieba> jieba_parser_; - JiebaMode jieba_mode_; - bool with_offsets_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_ENGINE_TEXT_JIEBA_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc deleted file mode 100644 index 1793301e1d..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/lookup_op.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/lookup_op.h" - -#include <string> - -namespace mindspore { -namespace dataset { - -LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id) - : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {} - -Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { - IO_CHECK(input, output); - RETURN_UNEXPECTED_IF_NULL(vocab_); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor."); - std::vector<WordIdType> word_ids; - word_ids.reserve(input->Size()); - for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) { - WordIdType word_id = vocab_->Lookup(std::string(*itr)); - word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id); - CHECK_FAIL_RETURN_UNEXPECTED( - word_ids.back() != Vocab::kNoTokenExists, - "Lookup Error: token " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified."); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_, - reinterpret_cast<unsigned char *>(word_ids.data()))); - return Status::OK(); -} -Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match"); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type"); - outputs[0] = type_; - return Status::OK(); -} - -void LookupOp::Print(std::ostream &out) const { - out << "LookupOp: " - << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n"; -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/dataset/text/kernels/lookup_op.h deleted file mode 100644 index 7ef259474e..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/lookup_op.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_ -#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_ - -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" -#include "dataset/text/vocab.h" - -namespace mindspore { -namespace dataset { -class LookupOp : public TensorOp { - public: - // constructor for lookup, takes in a vocab object - // @param std::shared_ptr vocab - - // @param WordIdType default_id, id to lookup if a word is not in vocab - explicit LookupOp(std::shared_ptr vocab, WordIdType default_id = 1); - - ~LookupOp() = default; - - // perform actual lookup on each tensor - // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // print method - // @param std::ostream out - void Print(std::ostream &out) const override; - - // @param std::vector &inputs - - // @param std::vector &outputs - - // @return error code - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - std::string Name() const override { return kLookupOp; } - - private: - std::shared_ptr vocab_; - WordIdType default_id_; - DataType type_; // type of tensor after lookup -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc b/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc deleted file mode 100644 index bbe449a89a..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/ngram_op.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "dataset/text/kernels/ngram_op.h" - -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { - -NgramOp::NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, - const std::string &r_pad, const std::string &separator) - : ngrams_(ngrams), - l_len_(l_len), - r_len_(r_len), - l_pad_with_sp_(l_pad + separator), - r_pad_with_sp_(r_pad + separator), - separator_(separator) {} - -Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor"); - std::vector offsets; // offsets for each str - std::vector res; // holds the result of ngrams - std::string str_buffer; // concat all pad tokens with string interleaved with separators - res.reserve(input->shape().NumOfElements()); // this should be more than enough - offsets.reserve(1 + l_len_ + r_len_ + input->shape().NumOfElements()); - str_buffer.reserve(l_pad_with_sp_.size() * l_len_ + r_pad_with_sp_.size() * r_len_ + input->SizeInBytes()); - offsets.push_back(str_buffer.size()); // insert 0 as the starting pos - for (int i = 0; i < l_len_; i++) offsets.push_back((str_buffer += l_pad_with_sp_).size()); - - for (auto itr = input->begin(); itr != input->end(); itr++) { - str_buffer += (*itr); - str_buffer += separator_; - offsets.push_back(str_buffer.size()); - } - - for (int i = 0; i < r_len_; i++) offsets.push_back((str_buffer += r_pad_with_sp_).size()); - - for (auto n : ngrams_) { - CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n"); - int32_t start_ind = l_len_ - std::min(l_len_, n - 1); - int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1); - if (end_ind - start_ind <= n) { - res.emplace_back(std::string()); // push back empty string - } else { - CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition"); - - for (int i = start_ind; i < end_ind - n; i++) { - res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); - } - } - } - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, res, TensorShape({static_cast(res.size())}))); - return Status::OK(); -} - -void NgramOp::Print(std::ostream &out) const { - out << "NgramOp: " - << "left pad width: " << l_len_ << " left pad token with separator: " << l_pad_with_sp_ << "\n" - << "right pad width: " << r_len_ << " right pad token with separator: " << r_pad_with_sp_ << "\n" - << "separator: " << separator_ << "\n"; -} - -Status NgramOp::OutputShape(const std::vector &inputs, std::vector &outputs) { - CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); - CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1, "ngram only works with 1-dim data\n"); - dsize_t num_elements = ngrams_.size(); - for (int32_t n : ngrams_) { - // here since rank == 1, NumOfElements == shape[0]. 
add padding length to string - int32_t len_with_padding = inputs[0].NumOfElements() + std::min(n - 1, l_len_) + std::min(n - 1, r_len_); - // if len_with_padding - n < 0, this would return an empty string - num_elements += std::max(len_with_padding - n, 0); - } - outputs.emplace_back(TensorShape({num_elements})); - CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/ngram_op.h b/mindspore/ccsrc/dataset/text/kernels/ngram_op.h deleted file mode 100644 index 33d2587f9b..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/ngram_op.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_ -#define DATASET_TEXT_KERNELS_NGRAM_OP_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class NgramOp : public TensorOp { - public: - // Constructor of Ngram model - // @param const std::vector<int32_t> &ngrams - // @param int32_t l_len - padding length on the left - // @param int32_t r_len - padding length on the right - // @param const std::string &l_pad - padding token on the left - // @param const std::string &r_pad - padding token on the right - // @param const std::string &separator - used to join strings - NgramOp(const std::vector<int32_t> &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, - const std::string &r_pad, const std::string &separator); - - // perform ngram model on each tensor - // @param const std::shared_ptr<Tensor> &input - // @param std::shared_ptr<Tensor> *output - // @return error code - Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; - - // destructor - ~NgramOp() override = default; - - // @param std::vector<TensorShape> &inputs - shape of input tensors - // @param std::vector<TensorShape> &outputs - shape of output tensors - // @return error code - Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override; - - // print arg for debugging - // @param std::ostream &out - void Print(std::ostream &out) const override; - - std::string Name() const override { return kNgramOp; } - - private: - std::vector<int32_t> ngrams_; // list of n grams - int32_t l_len_; // left padding length - int32_t r_len_; // right padding length - std::string l_pad_with_sp_; // left padding appended with separator - std::string r_pad_with_sp_; // right padding appended with separator - std::string separator_; // separator -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_NGRAM_OP_H_
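// Editor's note: worked example of the padding arithmetic in NgramOp::OutputShape
// above. For a 1-D input ["a", "b", "c"] (NumOfElements() == 3) with
// ngrams_ = {2} and l_len_ = r_len_ = 1:
//   len_with_padding = 3 + min(2 - 1, 1) + min(2 - 1, 1) = 5
//   extra elements   = max(5 - 2, 0)                     = 3
//   num_elements     = ngrams_.size() + 3                = 4
// i.e. the four bigrams of "<pad> a b c <pad>".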
diff --git a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc deleted file mode 100644 index b902286576..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/text/kernels/normalize_utf8_op.h" -#include <memory> -#include <string> -#include <string_view> -#include <utility> -#include <vector> - -#include "unicode/errorcode.h" -#include "unicode/normalizer2.h" -#include "unicode/utypes.h" - -namespace mindspore { -namespace dataset { -const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; -Status NormalizeUTF8Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { - IO_CHECK(input, output); - icu::ErrorCode error; - const icu::Normalizer2 *normalize = nullptr; - switch (normalize_form_) { - case NormalizeForm::kNone: { - *output = input; - return Status::OK(); - } - case NormalizeForm::kNfc: { - normalize = icu::Normalizer2::getNFCInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed"); - break; - } - case NormalizeForm::kNfkc: { - normalize = icu::Normalizer2::getNFKCInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed"); - break; - } - case NormalizeForm::kNfd: { - normalize = icu::Normalizer2::getNFDInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed"); - break; - } - case NormalizeForm::kNfkd: { - normalize = icu::Normalizer2::getNFKDInstance(error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed"); - break; - } - default: { - RETURN_STATUS_UNEXPECTED("unexpected normalize form"); - break; - } - } - std::vector<std::string> strs(input->Size()); - int i = 0; - for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { - icu::StringByteSink<std::string> sink(&strs[i++]); - normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); - CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); - } - *output = std::make_shared<Tensor>(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h b/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h deleted file mode 100644 index d85f0fdf8f..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/normalize_utf8_op.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ -#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -enum class NormalizeForm { - kNone = 0, - kNfc, - kNfkc, - kNfd, - kNfkd, -}; - -class NormalizeUTF8Op : public TensorOp { - public: - static const NormalizeForm kDefNormalizeForm; - explicit NormalizeUTF8Op(NormalizeForm normalize_form = kDefNormalizeForm) : normalize_form_(normalize_form) {} - - ~NormalizeUTF8Op() override = default; - - void Print(std::ostream &out) const override { out << "NormalizeUTF8Op"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kNormalizeUTF8Op; } - - private: - NormalizeForm normalize_form_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc deleted file mode 100644 index 1ce2c5ea61..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/text/kernels/regex_replace_op.h" -#include -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { - -Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, - std::string *out) const { - CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null"); - UErrorCode icu_error = U_ZERO_ERROR; - icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); - matcher->reset(unicode_text); - icu::UnicodeString unicode_out; - if (replace_all_) { - unicode_out = matcher->replaceAll(replace_, icu_error); - } else { - unicode_out = matcher->replaceFirst(replace_, icu_error); - } - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed"); - unicode_out.toUTF8String(*out); - return Status::OK(); -} - -Status RegexReplaceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - IO_CHECK(input, output); - UErrorCode icu_error = U_ZERO_ERROR; - icu::RegexMatcher matcher(pattern_, 0, icu_error); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, you may input one error pattern"); - std::vector strs(input->Size()); - int i = 0; - for (auto iter = input->begin(); iter != input->end(); iter++) { - RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); - } - *output = std::make_shared(std::move(strs), input->shape()); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h b/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h deleted file mode 100644 index 9e4ae243e7..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_replace_op.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ -#define DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ -#include -#include - -#include "unicode/regex.h" -#include "unicode/errorcode.h" -#include "unicode/utypes.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class RegexReplaceOp : public TensorOp { - public: - RegexReplaceOp(const std::string &pattern, const std::string &replace, bool replace_all = true) - : pattern_(icu::UnicodeString::fromUTF8(pattern)), - replace_(icu::UnicodeString::fromUTF8(replace)), - replace_all_(replace_all) {} - - ~RegexReplaceOp() override = default; - - void Print(std::ostream &out) const override { out << "RegexReplaceOp"; } - - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - std::string Name() const override { return kRegexReplaceOp; } - - protected: - Status RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, std::string *out) const; - - private: - const icu::UnicodeString pattern_; - const icu::UnicodeString replace_; - const bool replace_all_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc deleted file mode 100644 index b15df9af67..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.cc +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
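For reference alongside the deleted RegexReplaceOp, a minimal standalone sketch of the ICU replace path it wraps (not part of this diff; it assumes ICU4C, and the helper name is hypothetical):

#include <string>
#include "unicode/regex.h"

// Replace every match of `pattern` in `text` with `replace`; on error, return the input.
std::string RegexReplaceAll(const std::string &text, const std::string &pattern, const std::string &replace) {
  UErrorCode err = U_ZERO_ERROR;
  icu::UnicodeString upattern = icu::UnicodeString::fromUTF8(pattern);
  icu::RegexMatcher matcher(upattern, 0, err);
  if (U_FAILURE(err)) return text;  // invalid pattern
  icu::UnicodeString utext = icu::UnicodeString::fromUTF8(text);
  matcher.reset(utext);  // the matcher aliases utext, so utext must outlive it
  icu::UnicodeString out = matcher.replaceAll(icu::UnicodeString::fromUTF8(replace), err);
  if (U_FAILURE(err)) return text;
  std::string result;
  out.toUTF8String(result);
  return result;
}

For example, RegexReplaceAll("a_b_c", "_", "-") yields "a-b-c"; RegexReplaceOp::Compute above applies the same matcher to every element of a string tensor.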
- */ -#include "dataset/text/kernels/regex_tokenizer_op.h" -#include -#include -#include -#include -#include - -namespace mindspore { -namespace dataset { - -const bool RegexTokenizerOp::kDefWithOffsets = false; - -Status RegexTokenizerOp::GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, - std::string *out_utf8, icu::UnicodeString *out_unicode) const { - CHECK_FAIL_RETURN_UNEXPECTED((out_utf8 != nullptr || out_unicode != nullptr), "Wrong input"); - int total_len = input.length(); - int end = start + len; - CHECK_FAIL_RETURN_UNEXPECTED((start >= 0 && len > 0 && end <= total_len), "Out of range"); - icu::UnicodeString temp; - input.extract(start, len, temp); - if (out_utf8 != nullptr) { - temp.toUTF8String(*out_utf8); - } - if (out_unicode != nullptr) { - *out_unicode = temp; - } - return Status::OK(); -} - -Status RegexTokenizerOp::GetRegexTokens(const std::string &text, std::vector *out_tokens, - std::vector *offsets_start, - std::vector *offsets_limit) const { - UErrorCode status = U_ZERO_ERROR; - out_tokens->clear(); - icu::RegexMatcher token_matcher(delim_pattern_, 0, status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); - icu::RegexMatcher delim_matcher(keep_delim_pattern_, 0, status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, you may input one error pattern"); - - icu::UnicodeString utext(icu::UnicodeString::fromUTF8(text)); - token_matcher.reset(utext); - - int text_start_index = 0; - int token_start_index = 0; - status = U_ZERO_ERROR; - while (token_matcher.find(status) && U_SUCCESS(status)) { - int deli_start_index = token_matcher.start(status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); - int deli_end_index = token_matcher.end(status); - CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); - - // Add non-empty token - int token_len = deli_start_index - token_start_index; - if (token_len > 0) { - std::string token; - uint32_t token_offset = 0; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, token_len, &token)); - token_offset = token.length(); - out_tokens->emplace_back(std::move(token)); - offsets_start->push_back(static_cast(text_start_index)); - offsets_limit->push_back(static_cast(text_start_index + token_offset)); - text_start_index += token_offset; - } - - int delim_len = deli_end_index - deli_start_index; - if (delim_len > 0) { - icu::UnicodeString delim_str; - std::string delim_utf8_str; - uint32_t delim_str_offset = 0; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, deli_start_index, delim_len, &delim_utf8_str, &delim_str)); - delim_matcher.reset(delim_str); - delim_str_offset = delim_utf8_str.length(); - if (keep_delim_ && delim_matcher.matches(status) && U_SUCCESS(status)) { - out_tokens->emplace_back(std::move(delim_utf8_str)); - offsets_start->push_back(static_cast(text_start_index)); - offsets_limit->push_back(static_cast(text_start_index + delim_str_offset)); - } - text_start_index += delim_str_offset; - } - token_start_index = deli_end_index; - } - - if (token_start_index < utext.length()) { - std::string temp; - uint32_t temp_offset = 0; - RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, utext.length() - token_start_index, &temp)); - temp_offset = temp.length(); - out_tokens->emplace_back(std::move(temp)); - offsets_start->push_back(static_cast(text_start_index)); - 
offsets_limit->push_back(static_cast(text_start_index + temp_offset)); - } - return Status::OK(); -} - -Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view text; - std::vector tokens; - std::vector offsets_start; - std::vector offsets_limit; - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {})); - RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit)); - token_tensor = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h deleted file mode 100644 index 174a8419b0..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/regex_tokenizer_op.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_TEXT_REGEX_TOKENIZER_OP_H_ -#define DATASET_TEXT_REGEX_TOKENIZER_OP_H_ -#include -#include -#include - -#include "unicode/regex.h" -#include "unicode/errorcode.h" -#include "unicode/utypes.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class RegexTokenizerOp : public TensorOp { - public: - static const bool kDefWithOffsets; - - RegexTokenizerOp(const std::string &delim_pattern, const std::string &keep_delim_pattern, - const bool &with_offsets = kDefWithOffsets) - : delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)), - keep_delim_pattern_(icu::UnicodeString::fromUTF8(keep_delim_pattern)), - with_offsets_(with_offsets), - keep_delim_(!keep_delim_pattern.empty()) {} - - ~RegexTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "RegexTokenizerOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - protected: - Status GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, std::string *out_utf8, - icu::UnicodeString *out_unicode = nullptr) const; - Status GetRegexTokens(const std::string &text, std::vector *out_tokens, - std::vector *offsets_start, std::vector *offsets_limit) const; - - std::string Name() const override { return kRegexTokenizerOp; } - - private: - const icu::UnicodeString delim_pattern_; - const icu::UnicodeString keep_delim_pattern_; - bool with_offsets_; - const bool keep_delim_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_REGEX_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc deleted file mode 100644 index 1368684daf..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/to_number_op.cc +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
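The tokenizer ops in this family all return, alongside the token strings, optional offsets_start/offsets_limit columns holding [start, limit) byte offsets into the original UTF-8 input. A minimal standalone sketch of that convention on a plain space-delimited split (illustrative names, not the real API):

#include <cstdint>
#include <string>
#include <vector>

struct Token {
  std::string text;
  uint32_t start;  // byte offset of the first byte of the token
  uint32_t limit;  // byte offset one past the last byte of the token
};

std::vector<Token> SplitOnSpaces(const std::string &s) {
  std::vector<Token> tokens;
  uint32_t i = 0, n = static_cast<uint32_t>(s.size());
  while (i < n) {
    while (i < n && s[i] == ' ') ++i;   // skip delimiters
    uint32_t start = i;
    while (i < n && s[i] != ' ') ++i;   // consume one token
    if (i > start) tokens.push_back({s.substr(start, i - start), start, i});
  }
  return tokens;
}

So SplitOnSpaces("ab  cd") yields {"ab", 0, 2} and {"cd", 4, 6}; the ops above emit the same ranges as two extra uint32 tensors when with_offsets_ is true.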
- */ - -#include "dataset/text/kernels/to_number_op.h" - -#include -#include -#include -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/core/tensor_shape.h" -#include "dataset/kernels/data/data_utils.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {} - -ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {} - -Status ToNumberOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { - CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tenosrs should have type string."); - - switch (cast_to_type_.value()) { - case DataType::DE_INT8: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT16: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT32: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_INT64: - RETURN_IF_NOT_OK(ToSignedIntegral(input, output)); - break; - case DataType::DE_UINT8: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT16: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT32: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_UINT64: - RETURN_IF_NOT_OK(ToUnsignedIntegral(input, output)); - break; - case DataType::DE_FLOAT16: - RETURN_IF_NOT_OK(this->ToFloat16(input, output)); - break; - case DataType::DE_FLOAT32: - RETURN_IF_NOT_OK(ToFloat(input, output)); - break; - case DataType::DE_FLOAT64: - RETURN_IF_NOT_OK(ToDouble(input, output)); - break; - } - - return Status::OK(); -} - -void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; } - -Status ToNumberOp::OutputShape(const std::vector &input_shapes, std::vector &output_shapes) { - (void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes)); - return Status::OK(); -} - -template -Status ToNumberOp::ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - int64_t result = 0; - - try { - result = std::stoll(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". 
The valid range is: [" + - std::to_string(std::numeric_limits::min()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - T casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -template -Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - uint64_t result = 0; - - // If there is a - at the start of the string, it is considered by us to - // be out of bounds. If the - is somewhere else in the string, it is - // deemed invalid by std::stoull and will throw std::invalid_argument - for (int i = 0; i < (*it).size(); i++) { - if ((*it)[i] == '-') { - is_cast_out_of_range = true; - break; - } - } - - try { - result = std::stoull(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::min()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - T casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { - // special case, float16 does not exist in c++, no native support for - // casting, so cast to float first then use this method, which use Eigen. - std::shared_ptr temp; - RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); - RETURN_IF_NOT_OK(ToFloat(input, &temp)); - RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); - return Status::OK(); -} - -Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - float result = 0; - - try { - result = std::stof(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || - is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". 
The valid range is: [" + - std::to_string(std::numeric_limits::lowest()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - float casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_ptr *output) { - std::vector casted; - - for (auto it = input->begin(); it != input->end(); ++it) { - bool is_cast_out_of_range = false; - double result = 0; - - try { - result = std::stod(std::string(*it)); - } catch (const std::out_of_range &) { - is_cast_out_of_range = true; - } catch (const std::invalid_argument &) { - RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); - } - - if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || - is_cast_out_of_range) { - std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + - cast_to_type_.ToString() + ". The valid range is: [" + - std::to_string(std::numeric_limits::lowest()) + ", " + - std::to_string(std::numeric_limits::max()) + "]."; - - RETURN_STATUS_UNEXPECTED(error_message); - } - - double casted_result = static_cast(result); - casted.push_back(casted_result); - } - - RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/dataset/text/kernels/to_number_op.h deleted file mode 100644 index 765749b778..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/to_number_op.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ -#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ - -#include -#include -#include - -#include "dataset/core/data_type.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class ToNumberOp : public TensorOp { - public: - // Constructor of ToNumberOp - // @param const DataType &cast_to_type - the type to convert string inputs to. - explicit ToNumberOp(const DataType &cast_to_type); - - // Constructor of ToNumberOp - // @param const std::string &cast_to_type - the type in string form to convert string inputs to. - explicit ToNumberOp(const std::string &cast_to_type); - - ~ToNumberOp() override = default; - - // Perform numeric conversion on each string in each tensor. 
- // @param const std::shared_ptr &input - // @param std::shared_ptr *output - // @return error code - Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - - // For each input shape, find the output shape - // @param std::vector &inputs - shape of input tensors - // @param std::vector &outputs - shape of output tensors - // @return error code - Status OutputShape(const std::vector &input_shapes, std::vector &output_shapes) override; - - // print arg for debugging - // @param std::ostream &out - void Print(std::ostream &out) const override; - - std::string Name() const override { return kToNumberOp; } - - private: - template - Status ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); - - template - Status ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToFloat(const std::shared_ptr &input, std::shared_ptr *output); - - Status ToDouble(const std::shared_ptr &input, std::shared_ptr *output); - - DataType cast_to_type_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc deleted file mode 100644 index 136d5006df..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
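The deleted ToNumberOp parses through the widest integer type and then range-checks before narrowing. A standalone sketch of that pattern (illustrative only; the real op reports failures through Status rather than exceptions):

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Parse `s` as the widest signed type, then verify the value fits T before casting.
template <typename T>
T ParseSignedChecked(const std::string &s) {
  long long wide = std::stoll(s);  // throws std::invalid_argument / std::out_of_range
  if (wide > std::numeric_limits<T>::max() || wide < std::numeric_limits<T>::min()) {
    throw std::out_of_range(s + " does not fit the target type");
  }
  return static_cast<T>(wide);
}

// e.g. ParseSignedChecked<int8_t>("127") == 127, while ParseSignedChecked<int8_t>("128")
// throws, which is exactly the out-of-bounds case ToSignedIntegral turns into an error Status.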
- */ - -#include "dataset/text/kernels/truncate_sequence_pair_op.h" - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/data/slice_op.h" - -namespace mindspore { -namespace dataset { - -Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); - std::shared_ptr seq1 = input[0]; - std::shared_ptr seq2 = input[1]; - CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, - "Both sequences should be of rank 1"); - dsize_t length1 = seq1->shape()[0]; - dsize_t length2 = seq2->shape()[0]; - dsize_t outLength1 = length1; - dsize_t outLength2 = length2; - - dsize_t total = length1 + length2; - while (total > max_length_) { - if (outLength1 > outLength2) - outLength1--; - else - outLength2--; - total--; - } - std::shared_ptr outSeq1; - if (length1 != outLength1) { - std::unique_ptr slice1(new SliceOp(Slice(outLength1 - length1))); - RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); - } else { - outSeq1 = std::move(seq1); - } - - std::shared_ptr outSeq2; - if (length2 != outLength2) { - std::unique_ptr slice2(new SliceOp(Slice(outLength2 - length2))); - RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); - } else { - outSeq2 = std::move(seq2); - } - output->push_back(outSeq1); - output->push_back(outSeq2); - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h deleted file mode 100644 index e9bd00f9de..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ -#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ - -#include -#include -#include -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/kernels/data/type_cast_op.h" -#include "dataset/kernels/data/data_utils.h" - -namespace mindspore { -namespace dataset { - -class TruncateSequencePairOp : public TensorOp { - public: - explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} - - ~TruncateSequencePairOp() override = default; - - void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kTruncateSequencePairOp; } - - private: - dsize_t max_length_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc deleted file mode 100644 index d2bd22058b..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
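The while loop in TruncateSequencePairOp::Compute above trims the currently longer sequence one element at a time until the pair fits. A standalone sketch of just that policy (hypothetical helper name; the real op then slices the tensors to these lengths):

#include <cstdint>
#include <utility>

// Return the truncated lengths of a sequence pair so that their sum fits max_length.
std::pair<int64_t, int64_t> TruncatedLengths(int64_t len1, int64_t len2, int64_t max_length) {
  while (len1 + len2 > max_length) {
    if (len1 > len2) {
      --len1;  // always shorten the currently longer sequence
    } else {
      --len2;
    }
  }
  return {len1, len2};
}

// e.g. TruncatedLengths(7, 4, 8) == {4, 4}: the longer sequence absorbs all the trimming,
// which is why BERT-style pair inputs keep both segments roughly balanced.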
- */ -#include "dataset/text/kernels/unicode_char_tokenizer_op.h" -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; - -namespace mindspore { -namespace dataset { - -const bool UnicodeCharTokenizerOp::kDefWithOffsets = false; - -Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view str; - RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); - - RuneStrArray runes; - if (!DecodeRunesInString(str.data(), str.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - std::vector splits(runes.size()); - std::vector offsets_start, offsets_limit; - for (size_t i = 0; i < runes.size(); i++) { - offsets_start.push_back(runes[i].offset); - offsets_limit.push_back(runes[i].offset + runes[i].len); - splits[i] = str.substr(runes[i].offset, runes[i].len); - } - if (splits.empty()) { - splits.emplace_back(""); - offsets_start.push_back(0); - offsets_limit.push_back(0); - } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h deleted file mode 100644 index 116b8028da..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_char_tokenizer_op.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class UnicodeCharTokenizerOp : public TensorOp { - public: - static const bool kDefWithOffsets; - - explicit UnicodeCharTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {} - - ~UnicodeCharTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kUnicodeCharTokenizerOp; } - - private: - bool with_offsets_; -}; - -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc deleted file mode 100644 index 0760fea90a..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/text/kernels/unicode_script_tokenizer_op.h" -#include -#include -#include -#include -#include - -#include "cppjieba/Unicode.hpp" -#include "unicode/errorcode.h" -#include "unicode/uchar.h" -#include "unicode/uscript.h" - -using cppjieba::DecodeRunesInString; -using cppjieba::RuneStrArray; - -namespace mindspore { -namespace dataset { - -const bool UnicodeScriptTokenizerOp::kDefKeepWhitespace = false; -const bool UnicodeScriptTokenizerOp::kDefWithOffsets = false; - -Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { - IO_CHECK_VECTOR(input, output); - CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); - if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { - RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); - } - std::string_view str; - RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); - RuneStrArray runes; - if (!DecodeRunesInString(str.data(), str.size(), runes)) { - RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); - } - - std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; - UScriptCode last_script = USCRIPT_INVALID_CODE; - icu::ErrorCode status; - int start = 0; - int len = 0; - std::vector splits; - std::vector offsets_start, offsets_limit; - - bool was_space = false; - for (size_t i = 0; i < runes.size(); i++) { - bool is_space = u_isUWhiteSpace(runes[i].rune); - UScriptCode script = uscript_getScript(runes[i].rune, status); - if (status.isFailure()) { - status.reset(); - script = USCRIPT_INVALID_CODE; - } - // 1) Seperate UTF-8 strings of different UScriptCode values - // (such as: "Chinese中国" should be splited to ["Chinese", "中国"]) - // 2) Seperate whitespace and non-whitespace UTF-8 strings - // (such as: " ." should be split to [" ", "."]) - if (len > 0 && (script != last_script || is_space != was_space)) { - // 3) If keep_whitespace_ is false, all the whitespace characters will be discard - if (keep_whitespace_ || !was_space) { - offsets_start.push_back(static_cast(start)); - offsets_limit.push_back(static_cast(start + len)); - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - } - start = runes[i].offset; - len = runes[i].len; - } else { - len += runes[i].len; - } - last_script = script; - was_space = is_space; - } - - if (len > 0 && (keep_whitespace_ || !was_space)) { - offsets_start.push_back(static_cast(start)); - offsets_limit.push_back(static_cast(start + len)); - std::string temp(str.substr(start, len)); - splits.emplace_back(std::move(temp)); - } - // 4) If the input is empty scalar string, the output will be 1-D empty string. 
- if (splits.empty()) { - splits.emplace_back(""); - offsets_start.push_back(0); - offsets_limit.push_back(0); - } - token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); - output->push_back(token_tensor); - if (with_offsets_) { - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_start[0]))); - RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, - TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), - reinterpret_cast(&offsets_limit[0]))); - output->push_back(offsets_start_tensor); - output->push_back(offsets_limit_tensor); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h deleted file mode 100644 index ec1be52533..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/unicode_script_tokenizer_op.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ -#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ -#include -#include - -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { - -class UnicodeScriptTokenizerOp : public TensorOp { - public: - static const bool kDefKeepWhitespace; - static const bool kDefWithOffsets; - - explicit UnicodeScriptTokenizerOp(const bool &keep_whitespace = kDefKeepWhitespace, - const bool &with_offsets = kDefWithOffsets) - : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {} - - ~UnicodeScriptTokenizerOp() override = default; - - void Print(std::ostream &out) const override { out << "UnicodeScriptTokenizerOp"; } - - Status Compute(const TensorRow &input, TensorRow *output) override; - - std::string Name() const override { return kUnicodeScriptTokenizerOp; } - - private: - bool keep_whitespace_; // If or not keep whitespace tokens - bool with_offsets_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc deleted file mode 100644 index 16bc2c87a3..0000000000 --- a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
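The script-boundary rule above starts a new token whenever the Unicode script of the current code point differs from the previous one. A standalone sketch of just that rule over an ICU UnicodeString (assumes ICU4C; it omits the whitespace handling the real op layers on top):

#include <string>
#include <vector>
#include "unicode/uchar.h"
#include "unicode/unistr.h"
#include "unicode/uscript.h"

std::vector<icu::UnicodeString> SplitByScript(const icu::UnicodeString &text) {
  std::vector<icu::UnicodeString> pieces;
  UScriptCode last = USCRIPT_INVALID_CODE;
  int32_t start = 0;
  for (int32_t i = 0; i < text.length(); i = text.moveIndex32(i, 1)) {
    UErrorCode err = U_ZERO_ERROR;
    UScriptCode script = uscript_getScript(text.char32At(i), &err);
    if (U_FAILURE(err)) script = USCRIPT_INVALID_CODE;
    if (i > 0 && script != last) {  // script changed: close the current piece
      pieces.push_back(text.tempSubStringBetween(start, i));
      start = i;
    }
    last = script;
  }
  if (text.length() > start) pieces.push_back(text.tempSubStringBetween(start, text.length()));
  return pieces;
}

// e.g. splitting "Chinese中国" yields ["Chinese", "中国"], matching comment 1) above.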
diff --git a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc
deleted file mode 100644
index 16bc2c87a3..0000000000
--- a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/text/kernels/whitespace_tokenizer_op.h"
-#include <memory>
-#include <string>
-#include <string_view>
-#include <utility>
-#include <vector>
-
-#include "cppjieba/Unicode.hpp"
-#include "unicode/errorcode.h"
-#include "unicode/uchar.h"
-#include "unicode/uscript.h"
-
-using cppjieba::DecodeRunesInString;
-using cppjieba::RuneStrArray;
-
-namespace mindspore {
-namespace dataset {
-
-const bool WhitespaceTokenizerOp::kDefWithOffsets = false;
-
-Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
-  IO_CHECK_VECTOR(input, output);
-  CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor");
-  if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
-    RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor");
-  }
-  std::string_view str;
-  RETURN_IF_NOT_OK(input[0]->GetItemAt<std::string_view>(&str, {}));
-
-  RuneStrArray runes;
-  if (!DecodeRunesInString(str.data(), str.size(), runes)) {
-    RETURN_STATUS_UNEXPECTED("Decode utf8 string failed.");
-  }
-
-  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
-  std::vector<uint32_t> offsets_start, offsets_limit;
-  std::vector<std::string> splits;
-  int start = 0;
-  int len = 0;
-  for (size_t i = 0; i < runes.size(); i++) {
-    if (u_isUWhiteSpace(runes[i].rune)) {
-      if (len > 0) {
-        offsets_start.push_back(static_cast<uint32_t>(start));
-        offsets_limit.push_back(static_cast<uint32_t>(start + len));
-        std::string temp(str.substr(start, len));
-        splits.emplace_back(std::move(temp));
-        len = 0;
-      }
-    } else {
-      if (len == 0) {
-        start = runes[i].offset;
-      }
-      len += runes[i].len;
-    }
-  }
-  if (len > 0) {
-    offsets_start.push_back(static_cast<uint32_t>(start));
-    offsets_limit.push_back(static_cast<uint32_t>(start + len));
-    std::string temp(str.substr(start, len));
-    splits.emplace_back(std::move(temp));
-  }
-  if (splits.empty()) {
-    splits.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  token_tensor = std::make_shared<Tensor>(splits, TensorShape({(dsize_t)splits.size()}));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible,
-                                          TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32),
-                                          reinterpret_cast<unsigned char *>(&offsets_start[0])));
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible,
-                                          TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32),
-                                          reinterpret_cast<unsigned char *>(&offsets_limit[0])));
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
-}
-}  // namespace dataset
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h
deleted file mode 100644
index e507e5b393..0000000000
--- a/mindspore/ccsrc/dataset/text/kernels/whitespace_tokenizer_op.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
-#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
-#include <memory>
-#include <string>
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/util/status.h"
-
-namespace mindspore {
-namespace dataset {
-
-class WhitespaceTokenizerOp : public TensorOp {
- public:
-  static const bool kDefWithOffsets;
-
-  explicit WhitespaceTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {}
-
-  ~WhitespaceTokenizerOp() override = default;
-
-  void Print(std::ostream &out) const override { out << "WhitespaceTokenizerOp"; }
-
-  Status Compute(const TensorRow &input, TensorRow *output) override;
-
-  std::string Name() const override { return kWhitespaceTokenizerOp; }
-
- private:
-  bool with_offsets_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_
diff --git a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc
deleted file mode 100644
index b97f696da7..0000000000
--- a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.cc
+++ /dev/null
@@ -1,157 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "dataset/text/kernels/wordpiece_tokenizer_op.h"
-#include <algorithm>
-#include <utility>
-
-namespace mindspore {
-namespace dataset {
-
-const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##";
-const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100;
-const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]";
-const bool WordpieceTokenizerOp::kDefWithOffsets = false;
-
-WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator,
-                                           const int &max_bytes_per_token, const std::string &unknown_token,
-                                           const bool &with_offsets)
-    : vocab_(vocab),
-      suffix_indicator_(suffix_indicator),
-      max_bytes_per_token_(max_bytes_per_token),
-      unknown_token_(unknown_token),
-      with_offsets_(with_offsets) {}
-
-Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start,
-                                        bool *out_found, int *out_end) const {
-  CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range");
-  *out_found = false;
-  for (int i = runes.size() - 1; i >= 0; i--) {
-    *out_end = runes[i].offset + runes[i].len;
-    int len = *out_end - start;
-    std::string word = input_token.substr(start, len);
-    if (start > 0) {
-      word = suffix_indicator_ + word;
-    }
-    if (vocab_->Lookup(word) != Vocab::kNoTokenExists) {
-      *out_found = true;
-      break;
-    }
-  }
-  return Status::OK();
-}
-
-Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start,
-                                          std::vector<std::string> *out_tokens, std::vector<uint32_t> *offsets_start,
-                                          std::vector<uint32_t> *offsets_limit) const {
-  out_tokens->clear();
-  offsets_start->push_back(basic_start);
-  if (unknown_token_.empty()) {
-    out_tokens->emplace_back(input_token);
-    offsets_limit->push_back(basic_start + input_token.length());
-  } else {
-    out_tokens->emplace_back(unknown_token_);
-    offsets_limit->push_back(basic_start + input_token.length());
-  }
-  return Status::OK();
-}
-
-Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end,
-                                        std::vector<std::string> *out_tokens) const {
-  CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range");
-  std::string subword = input_token.substr(start, end - start);
-  if (start > 0) {
-    subword = suffix_indicator_ + subword;
-  }
-  out_tokens->emplace_back(subword);
-  return Status::OK();
-}
-
-Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start,
-                                       std::vector<std::string> *out_tokens, std::vector<uint32_t> *offsets_start,
-                                       std::vector<uint32_t> *offsets_limit) const {
-  if (input_token.size() > max_bytes_per_token_) {
-    offsets_start->push_back(basic_start);
-    if (!unknown_token_.empty()) {
-      offsets_limit->push_back(basic_start + unknown_token_.size());
-      out_tokens->emplace_back(unknown_token_);
-    } else {
-      out_tokens->emplace_back(input_token);
-      offsets_limit->push_back(basic_start + input_token.size());
-    }
-    return Status::OK();
-  }
-  RuneStrArray runes;
-  if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) {
-    RETURN_STATUS_UNEXPECTED("Decode utf8 string failed.");
-  }
-  int end = 0;
-  for (int start = 0; start < input_token.size();) {
-    bool found = false;
-    RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end));
-    if (found) {
-      RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens));
-      offsets_start->push_back(static_cast<uint32_t>(basic_start + start));
-      offsets_limit->push_back(static_cast<uint32_t>(basic_start + end));
-      start = end;
-    } else {
-      return FoundNoToken(input_token, basic_start, out_tokens, offsets_start, offsets_limit);
-    }
-  }
-  return Status::OK();
-}
-
-Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
-  IO_CHECK_VECTOR(input, output);
-  if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) {
-    RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor");
-  }
-  dsize_t count = 0;
-  std::vector<std::string> out_tokens;
-  std::vector<uint32_t> offsets_start, offsets_limit;
-  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
-  for (auto iter = input[0]->begin<std::string_view>(); iter != input[0]->end<std::string_view>(); iter++) {
-    uint32_t basic_start = 0;
-    std::vector<std::string> temp_tokens;
-    if (with_offsets_ && input.size() == 3) {
-      RETURN_IF_NOT_OK(input[1]->GetItemAt<uint32_t>(&basic_start, {count, 0}));
-    }
-    RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit));
-    out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end());
-    count++;
-  }
-  if (out_tokens.empty()) {
-    out_tokens.emplace_back("");
-    offsets_start.push_back(0);
-    offsets_limit.push_back(0);
-  }
-  token_tensor = std::make_shared<Tensor>(out_tokens, TensorShape({(dsize_t)out_tokens.size()}));
-  output->push_back(token_tensor);
-  if (with_offsets_) {
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible,
-                                          TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32),
-                                          reinterpret_cast<unsigned char *>(&offsets_start[0])));
-    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible,
-                                          TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32),
-                                          reinterpret_cast<unsigned char *>(&offsets_limit[0])));
-    output->push_back(offsets_start_tensor);
-    output->push_back(offsets_limit_tensor);
-  }
-  return Status::OK();
-}
-
-}  // namespace dataset
-}  // namespace mindspore
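The LookupWord/GetTokens pair above implements greedy longest-match-first wordpiece splitting. A standalone sketch of that algorithm over a plain set vocabulary and byte indices (the real op walks UTF-8 runes and reports errors via Status):

#include <string>
#include <unordered_set>
#include <vector>

std::vector<std::string> WordpieceSplit(const std::string &word,
                                        const std::unordered_set<std::string> &vocab,
                                        const std::string &suffix = "##",
                                        const std::string &unk = "[UNK]") {
  std::vector<std::string> out;
  size_t start = 0;
  while (start < word.size()) {
    size_t end = word.size();
    std::string piece;
    bool found = false;
    while (end > start) {  // try the longest candidate first, then shrink
      piece = (start > 0 ? suffix : "") + word.substr(start, end - start);
      if (vocab.count(piece) > 0) { found = true; break; }
      --end;
    }
    if (!found) return {unk};  // no prefix matches: the whole word is unknown
    out.push_back(piece);
    start = end;
  }
  return out;
}

// With vocab {"un", "##aff", "##able"}, WordpieceSplit("unaffable", vocab) yields
// ["un", "##aff", "##able"], mirroring FoundNoToken's all-or-nothing fallback to [UNK].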
diff --git a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h
deleted file mode 100644
index 502da4cef2..0000000000
--- a/mindspore/ccsrc/dataset/text/kernels/wordpiece_tokenizer_op.h
+++ /dev/null
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_
-#define DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "cppjieba/Unicode.hpp"
-
-#include "dataset/core/tensor.h"
-#include "dataset/kernels/tensor_op.h"
-#include "dataset/text/vocab.h"
-#include "dataset/util/status.h"
-
-using cppjieba::DecodeRunesInString;
-using cppjieba::RuneStrArray;
-namespace mindspore {
-namespace dataset {
-
-class WordpieceTokenizerOp : public TensorOp {
- public:
-  static const char kDefSuffixIndicator[];
-  static const int kDefMaxBytesPerToken;
-  static const char kDefUnknownToken[];
-  static const bool kDefWithOffsets;
-  WordpieceTokenizerOp(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = kDefSuffixIndicator,
-                       const int &max_bytes_per_token = kDefMaxBytesPerToken,
-                       const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets);
-
-  ~WordpieceTokenizerOp() override = default;
-
-  void Print(std::ostream &out) const override { out << "WordpieceTokenizerOp"; }
-
-  Status Compute(const TensorRow &input, TensorRow *output) override;
-
- protected:
-  Status AddSubword(const std::string &input_token, const int &start, const int &end,
-                    std::vector<std::string> *out_token) const;
-  Status FoundNoToken(const std::string &input_token, const uint32_t &basic_start, std::vector<std::string> *out_tokens,
-                      std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) const;
-  Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found,
-                    int *out_end) const;
-  Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector<std::string> *out_tokens,
-                   std::vector<uint32_t> *offsets_start, std::vector<uint32_t> *offsets_limit) const;
-
-  std::string Name() const override { return kWordpieceTokenizerOp; }
-
- private:
-  const std::shared_ptr<Vocab> vocab_;
-  const std::string suffix_indicator_;
-  const bool with_offsets_;
-  const int max_bytes_per_token_;
-  const std::string unknown_token_;
-};
-}  // namespace dataset
-}  // namespace mindspore
-#endif  // DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_
diff --git a/mindspore/ccsrc/dataset/text/vocab.cc b/mindspore/ccsrc/dataset/text/vocab.cc
deleted file mode 100644
index 399a9dee37..0000000000
--- a/mindspore/ccsrc/dataset/text/vocab.cc
+++ /dev/null
@@ -1,107 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <fstream>
-#include <string>
-#include <unordered_set>
-#include <utility>
-
-#include "dataset/text/vocab.h"
-
-namespace mindspore {
-namespace dataset {
-Vocab::Vocab(std::unordered_map<WordType, WordIdType> word2id) { word2id_ = std::move(word2id); }
-
-WordIdType Vocab::Lookup(const WordType &word) const {
-  auto itr = word2id_.find(word);
-  return itr == word2id_.end() ? kNoTokenExists : itr->second;
-}
-
-Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special,
-                              std::shared_ptr<Vocab> *vocab) {
-  // check of duplication on both words and special_tokens will be performed in python
-  // special_tokens and words both need to be unique, and shouldn't overlap
-  std::unordered_map<WordType, WordIdType> word2id;
-  // if special is added in front, normal words id will start from number of special tokens
-  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
-
-  for (auto word : words) {
-    word2id[py::str(word)] = word_id++;
-  }
-
-  word_id = prepend_special ? 0 : word2id.size();
-
-  for (auto special_token : special_tokens) {
-    word2id[py::str(special_token)] = word_id++;
-  }
-
-  *vocab = std::make_shared<Vocab>(std::move(word2id));
-  return Status::OK();
-}
-
-Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size,
-                            const py::list &special_tokens, bool prepend_special, std::shared_ptr<Vocab> *vocab) {
-  // python validator checks special_tokens doesn't contain any duplicate words
-  std::unordered_set<std::string> specials;
-  // used to check that words in file don't contain any special token that already exists
-  for (auto word : special_tokens) {
-    specials.insert(py::str(word));
-  }
-  WordIdType word_id = prepend_special ? static_cast<WordIdType>(special_tokens.size()) : 0;
-  std::unordered_map<WordType, WordIdType> word2id;
-  std::fstream handle(path, std::ios::in);
-  CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "failed to open: " + path);
-  std::string word;
-  while (std::getline(handle, word)) {
-    if (!delimiter.empty()) {
-      // if the delimiter is not found, find_first_of returns std::string::npos and the whole line is kept
-      word = word.substr(0, word.find_first_of(delimiter));
-    }
-    CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + ".");
-    CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens.");
-    word2id[word] = word_id++;
-    // stop once vocab_size rows have been read; never triggers if vocab_size is smaller than 0
-    if (word2id.size() == vocab_size) break;
-  }
-
-  word_id = prepend_special ? 0 : word2id.size();
-
-  for (auto special_token : special_tokens) {
-    word2id[py::str(special_token)] = word_id++;
-  }
-
-  *vocab = std::make_shared<Vocab>(std::move(word2id));
-  return Status::OK();
-}
-
-Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr<Vocab> *vocab) {
-  std::unordered_map<WordType, WordIdType> word2id;
-  for (auto p : words) {
-    word2id[py::str(p.first)] = py::reinterpret_borrow<py::int_>(p.second);
-  }
-  *vocab = std::make_shared<Vocab>(std::move(word2id));
-  return Status::OK();
-}
-
-void Vocab::append_word(const std::string &word) {
-  if (word2id_.find(word) == word2id_.end()) {
-    word2id_[word] = word2id_.size();
-  }
-}
-
-const WordIdType Vocab::kNoTokenExists = -1;
-
-}  // namespace dataset
-}  // namespace mindspore
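Vocab::BuildFromPyList above assigns ids in two passes so that special tokens can either precede or follow the normal words. A standalone sketch of that id-assignment scheme without the pybind11 plumbing (illustrative names only):

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

std::unordered_map<std::string, int32_t> BuildVocab(const std::vector<std::string> &words,
                                                    const std::vector<std::string> &specials,
                                                    bool prepend_special) {
  std::unordered_map<std::string, int32_t> word2id;
  // if specials go in front, normal word ids start after them
  int32_t id = prepend_special ? static_cast<int32_t>(specials.size()) : 0;
  for (const auto &w : words) word2id[w] = id++;
  // specials then take ids [0, n) or continue after the words
  id = prepend_special ? 0 : static_cast<int32_t>(word2id.size());
  for (const auto &s : specials) word2id[s] = id++;
  return word2id;
}

// BuildVocab({"hello", "world"}, {"<pad>", "<unk>"}, true) yields
// {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}.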
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_TEXT_VOCAB_H_ -#define DATASET_TEXT_VOCAB_H_ - -#include -#include -#include -#include - -#include "dataset/util/status.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace mindspore { -namespace dataset { -namespace py = pybind11; - -using WordIdType = int32_t; -using WordType = std::string; - -class Vocab { - public: - // Build a vocab from a python dictionary key is each word ,id needs to start from 2, no duplicate and continuous - // @param const py::dict &words - a dictionary containing word, word id pair. - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab); - - // Build a vocab from a python list, id will be assigned automatically, start from 2 - // @param const py::list &words - a list of string, used to build vocab, id starts from 2 - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, - std::shared_ptr *vocab); - - // Build a vocab from reading a vocab file, id are automatically assigned, start from 2 - // @param std::string &path - path to vocab file , each line is assumed to contain 1 word - // @param std::string &delimiter - delimiter to break each line with - // @param int32_t vocab_size - number of words to read from file - // @param std::shared_ptr *vocab - return value, vocab object - // @return error code - static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, - const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab); - - // Lookup the id of a word, if word doesn't exist in vocab, return default_id - // @param const WordType word - word to look up - // @param WordIdType default_id - word id to return to user when its not in the vocab - // @return WordIdType, word_id - WordIdType Lookup(const WordType &word) const; - - // constructor, shouldn't be called directly, can't be private due to std::make_unique() - // @param std::unordered_map map - sanitized word2id map - explicit Vocab(std::unordered_map map); - - Vocab() = default; - - // add one word to vocab, increment it's index automatically - // @param std::string & word - word to be added will skip if word already exists - void append_word(const std::string &word); - - // destructor - ~Vocab() = default; - - static const WordIdType kNoTokenExists; - - private: - std::unordered_map word2id_; -}; - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_TEXT_VOCAB_H_ diff --git a/mindspore/ccsrc/dataset/util/allocator.h b/mindspore/ccsrc/dataset/util/allocator.h deleted file mode 100644 index 1998716438..0000000000 --- a/mindspore/ccsrc/dataset/util/allocator.h +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_ALLOCATOR_H_ -#define DATASET_UTIL_ALLOCATOR_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/memory_pool.h" - -namespace mindspore { -namespace dataset { -// The following conforms to the requirements of -// std::allocator. Do not rename/change any needed -// requirements, e.g. function names, typedef etc. -template -class Allocator { - public: - template - friend class Allocator; - - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using reference = T &; - using const_reference = const T &; - using size_type = uint64_t; - - template - struct rebind { - using other = Allocator; - }; - - using propagate_on_container_copy_assignment = std::true_type; - using propagate_on_container_move_assignment = std::true_type; - using propagate_on_container_swap = std::true_type; - - explicit Allocator(const std::shared_ptr &b) : pool_(b) {} - - ~Allocator() = default; - - template - explicit Allocator(Allocator const &rhs) : pool_(rhs.pool_) {} - - template - bool operator==(Allocator const &rhs) const { - return pool_ == rhs.pool_; - } - - template - bool operator!=(Allocator const &rhs) const { - return pool_ != rhs.pool_; - } - - pointer allocate(std::size_t n) { - void *p; - Status rc = pool_->Allocate(n * sizeof(T), &p); - if (rc.IsOk()) { - return reinterpret_cast(p); - } else if (rc.IsOutofMemory()) { - throw std::bad_alloc(); - } else { - throw std::exception(); - } - } - - void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); } - - size_type max_size() { return pool_->get_max_size(); } - - private: - std::shared_ptr pool_; -}; -/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will -/// be released when the object goes out of scope -/// \tparam T The type of object to be allocated -/// \tparam C Allocator. Default to std::allocator -template > -class MemGuard { - public: - using allocator = C; - MemGuard() : n_(0) {} - explicit MemGuard(allocator a) : n_(0), alloc_(a) {} - // There is no copy constructor nor assignment operator because the memory is solely owned by this object. - MemGuard(const MemGuard &) = delete; - MemGuard &operator=(const MemGuard &) = delete; - // On the other hand, We can support move constructor - MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} - MemGuard &operator=(MemGuard &&lhs) noexcept { - if (this != &lhs) { - this->deallocate(); - n_ = lhs.n_; - alloc_ = std::move(lhs.alloc_); - ptr_ = std::move(lhs.ptr_); - } - return *this; - } - /// \brief Explicitly deallocate the memory if allocated - void deallocate() { - if (ptr_) { - auto *p = ptr_.release(); - if (!std::is_arithmetic::value && std::is_destructible::value) { - for (auto i = 0; i < n_; ++i) { - p[i].~T(); - } - } - alloc_.deallocate(p, n_); - n_ = 0; - } - } - /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is - /// allocated. 
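- /// For example (an illustrative sketch): MemGuard<int> mg; followed by RETURN_IF_NOT_OK(mg.allocate(16));
- /// reserves room for sixteen ints and releases it automatically when mg goes out of scope. For
- /// non-arithmetic T, allocate also constructs each of the n elements with the forwarded arguments.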
- /// \param n Number of objects of type T to be allocated - /// \tparam Args Extra arguments pass to the constructor of T - template - Status allocate(size_t n, Args &&... args) noexcept { - try { - deallocate(); - if (n > 0) { - T *data = alloc_.allocate(n); - if (!std::is_arithmetic::value) { - for (auto i = 0; i < n; i++) { - std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); - } - } - ptr_ = std::unique_ptr(data); - n_ = n; - } - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory); - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); - } - ~MemGuard() noexcept { deallocate(); } - /// \brief Getter function - /// \return The pointer to the memory allocated - T *GetPointer() const { return ptr_.get(); } - /// \brief Getter function - /// \return The pointer to the memory allocated - T *GetMutablePointer() { return ptr_.get(); } - /// \brief Overload [] operator to access a particular element - /// \param x index to the element. Must be less than number of element allocated. - /// \return pointer to the x-th element - T *operator[](size_t x) { return GetMutablePointer() + x; } - /// \brief Overload [] operator to access a particular element - /// \param x index to the element. Must be less than number of element allocated. - /// \return pointer to the x-th element - T *operator[](size_t x) const { return GetPointer() + x; } - /// \brief Return how many bytes are allocated in total - /// \return Number of bytes allocated in total - size_t GetSizeInBytes() const { return n_ * sizeof(T); } - - private: - allocator alloc_; - std::unique_ptr ptr_; - size_t n_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/dataset/util/arena.cc b/mindspore/ccsrc/dataset/util/arena.cc deleted file mode 100644 index af4f522678..0000000000 --- a/mindspore/ccsrc/dataset/util/arena.cc +++ /dev/null @@ -1,256 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/arena.h" -#include -#include -#include "dataset/util/system_pool.h" -#include "./securec.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -struct MemHdr { - uint32_t sig; - uint64_t addr; - uint64_t blk_size; - MemHdr(uint64_t a, uint64_t sz) : sig(0xDEADBEEF), addr(a), blk_size(sz) {} - static void setHdr(void *p, uint64_t addr, uint64_t sz) { new (p) MemHdr(addr, sz); } - static void getHdr(void *p, MemHdr *hdr) { - auto *tmp = reinterpret_cast(p); - *hdr = *tmp; - } -}; -Status Arena::Init() { - RETURN_IF_NOT_OK(DeMalloc(size_in_MB_ * 1048576L, &ptr_, false)); - // Divide the memory into blocks. Ignore the last partial block. 
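- // For example (illustrative numbers): with the default 4096 MB arena and 64-byte blocks,
- // num_blks = (4096 * 1048576) / 64 = 67108864 blocks, all inserted as one free extent below.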
-  uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ;
-  MS_LOG(DEBUG) << "Memory pool consists of " << num_blks << " blocks of size " << ARENA_BLK_SZ << ".";
-  tr_.Insert(0, num_blks);
-  return Status::OK();
-}
-
-Status Arena::Allocate(size_t n, void **p) {
-  if (n == 0) {
-    *p = nullptr;
-    return Status::OK();
-  }
-  std::unique_lock lck(mux_);
-  // Round the request (plus the wall overhead) up to a whole number of blocks
-  uint64_t req_size = static_cast(n) + ARENA_WALL_OVERHEAD_SZ;
-  if (req_size > this->get_max_size()) {
-    return Status(StatusCode::kOutOfMemory);
-  }
-  uint64_t reqBlk = SizeToBlk(req_size);
-  // Do a first fit search
-  auto blk = tr_.Top();
-  if (blk.second && reqBlk <= blk.first.priority) {
-    uint64_t addr = blk.first.key;
-    uint64_t size = blk.first.priority;
-    // Trim to the required size and return the rest to the tree.
-    tr_.Pop();
-    if (size > reqBlk) {
-      tr_.Insert(addr + reqBlk, size - reqBlk);
-    }
-    lck.unlock();
-    char *q = static_cast(ptr_) + addr * ARENA_BLK_SZ;
-    MemHdr::setHdr(q, addr, reqBlk);
-    *p = get_user_addr(q);
-  } else {
-    return Status(StatusCode::kOutOfMemory);
-  }
-  return Status::OK();
-}
-
-void Arena::Deallocate(void *p) {
-  auto *q = get_base_addr(p);
-  MemHdr hdr(0, 0);
-  MemHdr::getHdr(q, &hdr);
-  MS_ASSERT(hdr.sig == 0xDEADBEEF);
-  // We are going to insert a free block back to the treap. But first, check if we can combine
-  // with the free blocks before and after to form a bigger block.
-  std::unique_lock lck(mux_);
-  // Query if we have a free block after us.
-  auto nextBlk = tr_.Search(hdr.addr + hdr.blk_size);
-  if (nextBlk.second) {
-    // Form a bigger block
-    hdr.blk_size += nextBlk.first.priority;
-    tr_.DeleteKey(nextBlk.first.key);
-  }
-  // Next find a block in front of us.
-  auto result = FindPrevBlk(hdr.addr);
-  if (result.second) {
-    // We can combine with this block
-    hdr.addr = result.first.first;
-    hdr.blk_size += result.first.second;
-    tr_.DeleteKey(result.first.first);
-  }
-  // Now we can insert the free node
-  tr_.Insert(hdr.addr, hdr.blk_size);
-}
-
-Status Arena::Reallocate(void **pp, size_t old_sz, size_t new_sz) {
-  MS_ASSERT(pp);
-  MS_ASSERT(*pp);
-  uint64_t actual_size = static_cast(new_sz) + ARENA_WALL_OVERHEAD_SZ;
-  if (actual_size > this->get_max_size()) {
-    RETURN_STATUS_UNEXPECTED("Request size too big : " + std::to_string(new_sz));
-  }
-  uint64_t req_blk = SizeToBlk(actual_size);
-  char *oldAddr = reinterpret_cast(*pp);
-  auto *oldHdr = get_base_addr(oldAddr);
-  MemHdr hdr(0, 0);
-  MemHdr::getHdr(oldHdr, &hdr);
-  MS_ASSERT(hdr.sig == 0xDEADBEEF);
-  std::unique_lock lck(mux_);
-  if (hdr.blk_size > req_blk) {
-    // Refresh the header with the new smaller size.
-    MemHdr::setHdr(oldHdr, hdr.addr, req_blk);
-    // Return the unused memory back to the tree. Unlike allocate, we need to merge with the block after us.
-    auto next_blk = tr_.Search(hdr.addr + hdr.blk_size);
-    if (next_blk.second) {
-      hdr.blk_size += next_blk.first.priority;
-      tr_.DeleteKey(next_blk.first.key);
-    }
-    tr_.Insert(hdr.addr + req_blk, hdr.blk_size - req_blk);
-  } else if (hdr.blk_size < req_blk) {
-    uint64_t addr = hdr.addr;
-    // Attempt a block enlarge. No guarantee it is always successful.
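-    // BlockEnlarge tries to absorb free neighbouring blocks so the payload can stay in place;
-    // if the extra space comes from the block in front of us, the base address changes and the
-    // payload must be moved (hence the memmove_s below).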
- bool success = BlockEnlarge(&addr, hdr.blk_size, req_blk); - if (success) { - auto *newHdr = static_cast(ptr_) + addr * ARENA_BLK_SZ; - MemHdr::setHdr(newHdr, addr, req_blk); - if (addr != hdr.addr) { - errno_t err = - memmove_s(get_user_addr(newHdr), (req_blk * ARENA_BLK_SZ) - ARENA_WALL_OVERHEAD_SZ, oldAddr, old_sz); - if (err) { - RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); - } - } - *pp = get_user_addr(newHdr); - return Status::OK(); - } - // If we reach here, allocate a new block and simply move the content from the old to the new place. - // Unlock since allocate will grab the lock again. - lck.unlock(); - return FreeAndAlloc(pp, old_sz, new_sz); - } - return Status::OK(); -} - -std::ostream &operator<<(std::ostream &os, const Arena &s) { - for (auto &it : s.tr_) { - os << "Address : " << it.key << ". Size : " << it.priority << "\n"; - } - return os; -} - -Arena::Arena(size_t val_in_MB) : ptr_(nullptr), size_in_MB_(val_in_MB), size_in_bytes_(val_in_MB * 1048576L) {} - -Status Arena::CreateArena(std::shared_ptr *p_ba, size_t val_in_MB) { - if (p_ba == nullptr) { - RETURN_STATUS_UNEXPECTED("p_ba is null"); - } - Status rc; - auto ba = new (std::nothrow) Arena(val_in_MB); - if (ba == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = ba->Init(); - if (rc.IsOk()) { - (*p_ba).reset(ba); - } else { - delete ba; - } - return rc; -} - -int Arena::PercentFree() const { - uint64_t sz = 0; - for (auto &it : tr_) { - sz += it.priority; - } - double ratio = static_cast(sz * ARENA_BLK_SZ) / static_cast(size_in_bytes_); - return static_cast(ratio * 100.0); -} - -uint64_t Arena::get_max_size() const { return (size_in_bytes_ - ARENA_WALL_OVERHEAD_SZ); } - -std::pair, bool> Arena::FindPrevBlk(uint64_t addr) { - for (auto &it : tr_) { - if (it.key + it.priority == addr) { - return std::make_pair(std::make_pair(it.key, it.priority), true); - } else if (it.key > addr) { - break; - } - } - return std::make_pair(std::make_pair(0, 0), false); -} - -bool Arena::BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz) { - uint64_t size = old_sz; - // The logic is very much identical to Deallocate. We will see if we can combine with the blocks before and after. - auto next_blk = tr_.Search(*addr + old_sz); - if (next_blk.second) { - size += next_blk.first.priority; - if (size >= new_sz) { - // In this case, we can just enlarge the block without doing any moving. - tr_.DeleteKey(next_blk.first.key); - // Return unused back to the tree. - if (size > new_sz) { - tr_.Insert(*addr + new_sz, size - new_sz); - } - } - return true; - } - // If we still get here, we have to look at the block before us. - auto result = FindPrevBlk(*addr); - if (result.second) { - // We can combine with this block together with the next block (if any) - size += result.first.second; - *addr = result.first.first; - if (size >= new_sz) { - // We can combine with this block together with the next block (if any) - tr_.DeleteKey(*addr); - if (next_blk.second) { - tr_.DeleteKey(next_blk.first.key); - } - // Return unused back to the tree. 
-      if (size > new_sz) {
-        tr_.Insert(*addr + new_sz, size - new_sz);
-      }
-      return true;
-    }
-  }
-  return false;
-}
-
-Status Arena::FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz) {
-  MS_ASSERT(pp);
-  MS_ASSERT(*pp);
-  void *p = nullptr;
-  void *q = *pp;
-  RETURN_IF_NOT_OK(Allocate(new_sz, &p));
-  errno_t err = memmove_s(p, new_sz, q, old_sz);
-  if (err) {
-    RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err));
-  }
-  *pp = p;
-  // Free the old one.
-  Deallocate(q);
-  return Status::OK();
-}
-} // namespace dataset
-} // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/util/arena.h b/mindspore/ccsrc/dataset/util/arena.h
deleted file mode 100644
index 8c5d1e1093..0000000000
--- a/mindspore/ccsrc/dataset/util/arena.h
+++ /dev/null
@@ -1,105 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_UTIL_ARENA_H_
-#define DATASET_UTIL_ARENA_H_
-
-#include
-#include
-#include
-#include "dataset/util/memory_pool.h"
-#include "dataset/util/treap.h"
-
-#define ARENA_LOG_BLK_SZ (6u)
-#define ARENA_BLK_SZ (static_cast(1u << ARENA_LOG_BLK_SZ))
-#define ARENA_WALL_OVERHEAD_SZ 32
-namespace mindspore {
-namespace dataset {
-// This is a memory arena based on a treap data structure.
-// The constructor of the Arena takes the initial memory size (in MB).
-// Internally we divide the memory into multiple blocks. Each block is 64 bytes.
-// The treap contains all the free blocks with the relative memory address as key
-// and the size of the block as priority.
-//
-// Initially the treap has only one root which is the whole memory piece.
-//
-// For memory suballocation, we pop the root node of the treap which contains the largest free block.
-// We allocate what we need and return the rest back to the treap. We search for the first fit instead
-// of the best fit so as to give us constant-time memory allocation.
-//
-// When a block of memory is freed, it is joined with the blocks before and after (if they are available) to
-// form a bigger block.
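-//
-// A minimal usage sketch (illustrative only):
-//
-//   std::shared_ptr<Arena> arena;
-//   RETURN_IF_NOT_OK(Arena::CreateArena(&arena, 32));  // a 32 MB arena
-//   void *p = nullptr;
-//   RETURN_IF_NOT_OK(arena->Allocate(1024, &p));  // rounded up to 64-byte blocks plus a 32-byte wall
-//   arena->Deallocate(p);  // the freed block is merged with its free neighbours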
-class Arena : public MemoryPool {
- public:
-  Arena(const Arena &) = delete;
-
-  Arena &operator=(const Arena &) = delete;
-
-  ~Arena() override {
-    if (ptr_ != nullptr) {
-      free(ptr_);
-      ptr_ = nullptr;
-    }
-  }
-
-  Status Allocate(size_t n, void **p) override;
-
-  Status Reallocate(void **, size_t old_sz, size_t new_sz) override;
-
-  void Deallocate(void *) override;
-
-  uint64_t get_max_size() const override;
-
-  static uint64_t SizeToBlk(uint64_t sz) {
-    uint64_t req_blk = sz / ARENA_BLK_SZ;
-    if (sz % ARENA_BLK_SZ) {
-      ++req_blk;
-    }
-    return req_blk;
-  }
-
-  int PercentFree() const override;
-
-  const void *get_base_addr() const { return ptr_; }
-
-  friend std::ostream &operator<<(std::ostream &os, const Arena &s);
-
-  static Status CreateArena(std::shared_ptr *p_ba, size_t val_in_MB = 4096);
-
- private:
-  std::mutex mux_;
-  Treap tr_;
-  void *ptr_;
-  size_t size_in_MB_;
-  size_t size_in_bytes_;
-
-  explicit Arena(size_t val_in_MB = 4096);
-
-  std::pair, bool> FindPrevBlk(uint64_t addr);
-
-  Status Init();
-
-  bool BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz);
-
-  Status FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz);
-
-  void *get_user_addr(void *base_addr) const { return reinterpret_cast(base_addr) + ARENA_WALL_OVERHEAD_SZ; }
-
-  void *get_base_addr(void *user_addr) const { return reinterpret_cast(user_addr) - ARENA_WALL_OVERHEAD_SZ; }
-};
-} // namespace dataset
-} // namespace mindspore
-
-#endif // DATASET_UTIL_ARENA_H_
diff --git a/mindspore/ccsrc/dataset/util/auto_index.h b/mindspore/ccsrc/dataset/util/auto_index.h
deleted file mode 100644
index 5c43ecfd80..0000000000
--- a/mindspore/ccsrc/dataset/util/auto_index.h
+++ /dev/null
@@ -1,99 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef DATASET_UTIL_AUTO_INDEX_H_
-#define DATASET_UTIL_AUTO_INDEX_H_
-
-#include
-#include
-#include
-#include
-
-#include "dataset/util/btree.h"
-#include "dataset/util/system_pool.h"
-
-namespace mindspore {
-namespace dataset {
-/// This is a B+ tree with a generated int64_t value as key.
-/// Use the min_key() function to query the min key.
-/// Use the max_key() function to query the max key.
-/// @tparam T
-template >
-class AutoIndexObj : public BPlusTree {
- public:
-  using my_tree = BPlusTree;
-  using key_type = typename my_tree::key_type;
-  using value_type = typename my_tree::value_type;
-
-  AutoIndexObj() : my_tree::BPlusTree(), inx_(kMinKey) {}
-
-  explicit AutoIndexObj(const Allocator &alloc) : my_tree::BPlusTree(alloc), inx_(kMinKey) {}
-
-  ~AutoIndexObj() = default;
-
-  // Insert an object into the tree.
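-  // The key is drawn from an atomic counter starting at kMinKey (0), so concurrent inserts each
-  // obtain a unique, monotonically increasing key. For example (illustrative):
-  //   AutoIndexObj<int> ai;
-  //   AutoIndexObj<int>::key_type k;
-  //   RETURN_IF_NOT_OK(ai.insert(5, &k));  // k == 0 on the first insert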
- // @param val - // @return - Status insert(const value_type &val, key_type *key = nullptr) { - key_type my_inx = inx_.fetch_add(1); - if (key != nullptr) { - *key = my_inx; - } - return my_tree::DoInsert(my_inx, val); - } - - Status insert(std::unique_ptr &&val, key_type *key = nullptr) { - key_type my_inx = inx_.fetch_add(1); - if (key) { - *key = my_inx; - } - return my_tree::DoInsert(my_inx, std::move(val)); - } - - // Insert a vector of objects into the tree. - // @param v - // @return - Status insert(std::vector v) { - uint64_t num_ele = v.size(); - if (num_ele > 0) { - // reserve a range of keys rather than getting it one by one. - key_type my_inx = inx_.fetch_add(num_ele); - for (uint64_t i = 0; i < num_ele; i++) { - RETURN_IF_NOT_OK(my_tree::DoInsert(my_inx + i, v.at(i))); - } - } - return Status::OK(); - } - - // @return the minimum key - key_type min_key() const { - auto it = this->cbegin(); - return it.key(); - } - - // @return the maximum key - key_type max_key() const { - auto it = this->cend(); - --it; - return it.key(); - } - - private: - static constexpr key_type kMinKey = 0; - std::atomic inx_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_AUTO_INDEX_H_ diff --git a/mindspore/ccsrc/dataset/util/btree.h b/mindspore/ccsrc/dataset/util/btree.h deleted file mode 100644 index ccf642e366..0000000000 --- a/mindspore/ccsrc/dataset/util/btree.h +++ /dev/null @@ -1,459 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INDEX_H_ -#define DATASET_UTIL_INDEX_H_ - -#include -#include -#include -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/list.h" -#include "dataset/util/lock.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Default traits for a B+ tree -struct BPlusTreeTraits { - // This determines the limit of number of keys in a node. - using slot_type = uint16_t; - // Number of slots in each leaf of the tree. 
-  static constexpr slot_type kLeafSlots = 256;
-  // Number of slots in each inner node of the tree
-  static constexpr slot_type kInnerSlots = 128;
-};
-
-/// Implementation of a B+ tree
-/// @tparam K -- the type of key
-/// @tparam V -- the type of value
-/// @tparam A -- allocator
-/// @tparam C -- comparison class
-/// @tparam T -- trait
-template , typename C = std::less,
-          typename T = BPlusTreeTraits>
-class BPlusTree {
- public:
-  enum class IndexRc : char {
-    kOk = 0,
-    kDuplicateKey = 1,
-    kSlotFull = 2,
-    kKeyNotFound = 3,
-    kNullPointer = 4,
-    kOutOfMemory = 5,
-    kRetry = 6,
-    kUnexpectedError = 127
-  };
-#define RETURN_IF_BAD_RC(_s) \
-  do { \
-    IndexRc __rc = (_s); \
-    if (__rc != IndexRc::kOk) { \
-      return __rc; \
-    } \
-  } while (false)
-
-  Status IndexRc2Status(IndexRc rc) {
-    if (rc == IndexRc::kOk) {
-      return Status(StatusCode::kOK);
-    } else if (rc == IndexRc::kOutOfMemory) {
-      return Status(StatusCode::kOutOfMemory);
-    } else if (rc == IndexRc::kDuplicateKey) {
-      return Status(StatusCode::kDuplicateKey);
-    } else {
-      RETURN_STATUS_UNEXPECTED(std::to_string(static_cast(rc)));
-    }
-  }
-
-  using key_type = K;
-  using value_type = V;
-  using key_compare = C;
-  using slot_type = typename T::slot_type;
-  using traits = T;
-  using value_allocator = A;
-  using key_allocator = typename value_allocator::template rebind::other;
-  using slot_allocator = typename value_allocator::template rebind::other;
-
-  BPlusTree();
-
-  explicit BPlusTree(const Allocator &alloc);
-
-  ~BPlusTree() noexcept;
-
-  BPlusTree(const BPlusTree &) = delete;
-
-  BPlusTree(BPlusTree &&) = delete;
-
-  BPlusTree &operator=(const BPlusTree &) = delete;
-
-  BPlusTree &operator=(BPlusTree &&) = delete;
-
-  key_compare key_comp() const { return key_less_; }
-
-  size_t size() const { return stats_.size_; }
-
-  bool empty() const { return (size() == 0); }
-
-  /// @param key
-  /// @param value
-  /// @return
-  Status DoInsert(const key_type &key, const value_type &value);
-  Status DoInsert(const key_type &key, std::unique_ptr &&value);
-
-  // Update the value for a given key.
-  std::unique_ptr DoUpdate(const key_type &key, const value_type &new_value);
-  std::unique_ptr DoUpdate(const key_type &key, std::unique_ptr &&new_value);
-
-  // Statistics
-  struct tree_stats {
-    std::atomic size_;
-    uint32_t leaves_;
-    uint32_t inner_nodes_;
-    uint32_t level_;
-
-    tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {}
-  };
-
- private:
-  // Abstract class of a node (leaf or inner)
-  class BaseNode {
-   public:
-    friend class BPlusTree;
-
-    virtual bool is_leafnode() const = 0;
-
-    virtual bool is_full() const = 0;
-
-    explicit BaseNode(const value_allocator &alloc) : alloc_(alloc) {}
-
-    virtual ~BaseNode() = default;
-
-   protected:
-    mutable RWLock rw_lock_;
-    value_allocator alloc_;
-
-   private:
-    Node lru_;
-  };
-
-  // This control block keeps track of all the nodes we traverse on insert.
-  // To maximize concurrency, internal nodes are latched S. If a node split
-  // is required, we must release all the latches, change the latch mode
-  // from S to X, and redo the traversal.
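-  // For example, a leaf split must insert a separator key into the parent, which may split in
-  // turn; that cannot be done safely under shared latches, so the whole traversal is redone
-  // with retryWithXlock set and exclusive latches taken instead (see IndexRc::kRetry).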
- struct LockPathCB { - enum class LockMode : char { kShared = 0, kExclusive = 1, kNone = 2 }; - - struct path { - BaseNode *node_; - bool locked_; - - path() : node_(nullptr), locked_(false) {} - - path(BaseNode *p, LockMode lockmode) : node_(p), locked_(false) { - if (lockmode == LockMode::kExclusive) { - p->rw_lock_.LockExclusive(); - locked_ = true; - } else if (lockmode == LockMode::kShared) { - p->rw_lock_.LockShared(); - locked_ = true; - } - } - }; - - LockPathCB(BPlusTree *tree, bool retryWithXlock) : self_(tree), latch_shared_(true) { - if (retryWithXlock) { - latch_shared_ = false; - } - if (latch_shared_) { - tree->rw_lock_.LockShared(); - } else { - tree->rw_lock_.LockExclusive(); - } - } - - ~LockPathCB() noexcept { - // Make sure all locks are released. - while (!paths_.empty()) { - path p = paths_.back(); - paths_.pop_back(); - if (p.locked_) { - p.node_->rw_lock_.Unlock(); - } - } - self_->rw_lock_.Unlock(); - self_ = nullptr; - } - - void LockNode(BaseNode *p, LockMode locktype) { paths_.emplace_back(p, locktype); } - - void UnlockMyParents(BaseNode *me) { - path p = paths_.front(); - while (p.node_ != me) { - if (p.locked_) { - p.node_->rw_lock_.Unlock(); - } - paths_.pop_front(); - p = paths_.front(); - } - } - - BPlusTree *self_; - std::deque paths_; - bool latch_shared_; - }; - - // Definition of inner node which fans to either inner node or leaf node. - class InnerNode : public BaseNode { - public: - friend class BPlusTree; - - using alloc_type = typename value_allocator::template rebind::other; - - bool is_leafnode() const override { return false; } - - bool is_full() const override { return (slotuse_ == traits::kInnerSlots); } - - IndexRc Sort(); - - // 50/50 split - IndexRc Split(InnerNode *to, key_type *split_key); - - IndexRc InsertIntoSlot(slot_type slot, const key_type &key, BaseNode *ptr); - - explicit InnerNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} - - ~InnerNode() = default; - - slot_type slot_dir_[traits::kInnerSlots] = {0}; - key_type keys_[traits::kInnerSlots] = {0}; - BaseNode *data_[traits::kInnerSlots + 1] = {nullptr}; - slot_type slotuse_; - }; - - // Definition of a leaf node which contains the key/value pair - class LeafNode : public BaseNode { - public: - friend class BPlusTree; - - using alloc_type = typename value_allocator::template rebind::other; - Node link_; - - bool is_leafnode() const override { return true; } - - bool is_full() const override { return (slotuse_ == traits::kLeafSlots); } - - IndexRc Sort(); - - // 50/50 split - IndexRc Split(LeafNode *to); - - IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::unique_ptr &&value); - - explicit LeafNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} - - ~LeafNode() = default; - - slot_type slot_dir_[traits::kLeafSlots] = {0}; - key_type keys_[traits::kLeafSlots] = {0}; - std::unique_ptr data_[traits::kLeafSlots]; - slot_type slotuse_; - }; - - mutable RWLock rw_lock_; - value_allocator alloc_; - // All the leaf nodes. Used by the iterator to traverse all the key/values. - List leaf_nodes_; - // All the nodes (inner + leaf). Used by the destructor to free the memory of all the nodes. - List all_; - // Pointer to the root of the tree. 
- BaseNode *root_; - // Key comparison object - key_compare key_less_; - // Stat - tree_stats stats_; - - bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } - - bool EqualOrLessThan(const key_type &a, const key_type &b) const { return !key_less_(b, a); } - - bool Equal(const key_type &a, const key_type &b) const { return !key_less_(a, b) && !key_less_(b, a); } - - IndexRc AllocateInner(InnerNode **p); - - IndexRc AllocateLeaf(LeafNode **p); - - template - slot_type FindSlot(const node_type *node, const key_type &key, bool *duplicate = nullptr) const { - slot_type lo = 0; - while (lo < node->slotuse_ && key_comp()(node->keys_[node->slot_dir_[lo]], key)) { - ++lo; - } - bool keymatch = (lo < node->slotuse_ && Equal(key, node->keys_[node->slot_dir_[lo]])); - if (keymatch && !node->is_leafnode()) { - // For an inner node and we match a key during search, we should look into the next slot. - ++lo; - } - if (duplicate != nullptr) { - *duplicate = keymatch; - } - return lo; - } - - IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, - std::unique_ptr &&value, key_type *split_key, LeafNode **split_node); - - IndexRc InnerInsertKeyChild(InnerNode *node, const key_type &key, BaseNode *ptr, key_type *split_key, - InnerNode **split_node); - - inline BaseNode *FindBranch(InnerNode *inner, slot_type slot) const { - BaseNode *child = nullptr; - if (slot == 0) { - child = inner->data_[0]; - } else { - child = inner->data_[inner->slot_dir_[slot - 1] + 1]; - } - return child; - } - - IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::unique_ptr &&value, - key_type *split_key, BaseNode **split_node); - - IndexRc Locate(RWLock *parent_lock, bool forUpdate, BaseNode *top, const key_type &key, LeafNode **ln, - slot_type *s) const; - - public: - class Iterator : public std::iterator { - public: - using reference = BPlusTree::value_type &; - using pointer = BPlusTree::value_type *; - - explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} - - Iterator(LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} - - ~Iterator(); - - explicit Iterator(const Iterator &); - - Iterator &operator=(const Iterator &lhs); - - Iterator(Iterator &&); - - Iterator &operator=(Iterator &&lhs); - - pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } - - reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - - value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - // Prefix++ - Iterator &operator++(); - - // Postfix++ - Iterator operator++(int); - - // Prefix-- - Iterator &operator--(); - - // Postfix-- - Iterator operator--(int); - - bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } - bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } - - private: - typename BPlusTree::LeafNode *cur_; - slot_type slot_; - bool locked_; - }; - - class ConstIterator : public std::iterator { - public: - using reference = BPlusTree::value_type &; - using pointer = BPlusTree::value_type *; - - explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} - - ~ConstIterator(); - - ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) - : cur_(leaf), 
slot_(slot), locked_(locked) {} - - explicit ConstIterator(const ConstIterator &); - - ConstIterator &operator=(const ConstIterator &lhs); - - ConstIterator(ConstIterator &&); - - ConstIterator &operator=(ConstIterator &&lhs); - - pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } - - reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } - - value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } - - // Prefix++ - ConstIterator &operator++(); - - // Postfix++ - ConstIterator operator++(int); - - // Prefix-- - ConstIterator &operator--(); - - // Postfix-- - ConstIterator operator--(int); - - bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } - bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } - - private: - const typename BPlusTree::LeafNode *cur_; - slot_type slot_; - bool locked_; - }; - - Iterator begin(); - Iterator end(); - - ConstIterator begin() const; - ConstIterator end() const; - - ConstIterator cbegin() const; - ConstIterator cend() const; - - // Locate the entry with key - std::pair Search(const key_type &key) const; - std::pair Search(const key_type &key); - - value_type operator[](key_type key); -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INDEX_H_ - -#include "btree_impl.tpp" -#include "btree_iterator.tpp" diff --git a/mindspore/ccsrc/dataset/util/buddy.cc b/mindspore/ccsrc/dataset/util/buddy.cc deleted file mode 100644 index 540fa993d6..0000000000 --- a/mindspore/ccsrc/dataset/util/buddy.cc +++ /dev/null @@ -1,388 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#include "dataset/util/buddy.h"
-#include
-#include
-#include "dataset/util/memory_pool.h"
-#include "dataset/util/system_pool.h"
-#include "utils/log_adapter.h"
-#include "./securec.h"
-
-inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); }
-
-inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); }
-
-inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; }
-
-inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; }
-
-inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; }
-
-namespace mindspore {
-namespace dataset {
-Status BuddySpace::Init() {
-  if (log_min_ < 0) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "log_min must be non-negative : " + std::to_string(log_min_));
-  }
-  if (num_lvl_ < 3 || num_lvl_ > 18) {
-    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
-                  "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_));
-  }
-  min_ = BitLeftShift(1, log_min_);
-  max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1);
-  size_t offset_1 = sizeof(rel_addr_t) * num_lvl_;
-  size_t offset_2 = sizeof(int) * num_lvl_ + offset_1;
-  size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2;
-  RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true));
-  hint_ = reinterpret_cast(ptr_);
-  count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1));
-  map_ = reinterpret_cast(ptr_) + offset_2;
-  count_[num_lvl_ - 1] = 1;
-  map_[0] = BitOr(MORE_BIT, num_lvl_ - 3);
-  return Status::OK();
-}
-
-Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept {
-  std::lock_guard lock(mutex_);
-  addr_t addr = AllocNoLock(sz, desc);
-  if (addr != NOSPACE) {
-    *p = addr;
-    return Status::OK();
-  } else {
-    return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error.
Please ignore."); - } -} - -addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { - MS_ASSERT(sz <= max_); - uint32_t reqSize = SizeToBlock(sz); - rel_addr_t rel_addr = AllocBuddySeg(reqSize); - if (rel_addr != static_cast(NOSPACE)) { - (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); - desc->sig = static_cast(0xDEADBEEF); - desc->addr = rel_addr; - desc->req_size = reqSize; - desc->blk_size = NextPowerOf2(reqSize); - return static_cast(rel_addr * min_); - } else { - return NOSPACE; - } -} - -void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { - MS_ASSERT(desc->sig == 0XDEADBEEF); - rel_addr_t rel_addr = desc->addr; - size_t blk_size = desc->blk_size; - size_t req_size = desc->req_size; - FreeBuddySeg(rel_addr, blk_size, req_size); -} - -void BuddySpace::Free(const BSpaceDescriptor *desc) { - std::lock_guard lock(mutex_); - return FreeNoLock(desc); -} - -std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { - os << "1 unit = " << s.GetMinSize() << "\n" - << "Size of buddy space = " << s.GetMaxSize() << "\n" - << "Number of levels = " << s.num_lvl_ << "\n\n" - << "Percent free = " << s.PercentFree() << "\n" - << "Dumping count array : " - << "\n"; - for (int i = 0; i < s.num_lvl_; i++) { - os << "[" << i << "] = " << s.count_[i] << " "; - if (((i + 1) % 4) == 0) { - os << "\n"; - } - } - os << "\n"; - os << "Dumping allocation info:" - << "\n"; - auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); - rel_addr_t addr = 0; - while (addr < max_addr) { - size_t sz = 0; - BuddySpace::STATE st; - s.GetBuddySegState(addr, &sz, &st); - os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " - << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unkonwn")) - << "\n"; - addr += sz; - } - return os; -} - -void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { - char byte; - int pos; - int offset; - uint64_t val = 0; - int shift; - pos = BitRightShift(rel_addr, 2); - offset = rel_addr % 4; - shift = offset * 2; - byte = map_[pos]; - switch (offset) { - case 0: - val = byte; - break; - case 1: - case 3: - if (offset == 1) { - val = BitLeftShift(BitAnd(byte, 0x30), shift); - } else { - val = BitLeftShift(BitAnd(byte, 0x03), shift); - } - break; - case 2: - val = BitLeftShift(BitAnd(byte, 0x0F), shift); - break; - } - if (BitAnd(val, ONE_BIT)) { - *rel_sz = 1; - } else if (BitAnd(val, TWO_BIT)) { - *rel_sz = 2; - } else if (BitAnd(val, MORE_BIT)) { - log_t lg = BitAnd(val, 0x0F); - *rel_sz = BitLeftShift(1, lg + 2); - } else { - *st = STATE::kEmpty; - return; - } - *st = BitAnd(val, ALLOC_BIT) ? 
STATE::kAlloc : STATE::kFree; -} - -void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { - int clr; - int mask; - int pos; - int offset; - int val = 0; - int shift; - auto log_sz = static_cast(Log2(rel_sz)); - pos = BitRightShift(rel_addr, 2); - offset = rel_addr % 4; - shift = offset * 2; - if (rel_sz == 1) { - val = ONE_BIT; - mask = 0xC0; - } else if (rel_sz == 2) { - val = TWO_BIT; - mask = 0xF0; - } else { - val = BitOr(log_sz - 2, MORE_BIT); - mask = 0xFF; - } - if (st == STATE::kAlloc) { - val = BitOr(val, ALLOC_BIT); - } else if (st == STATE::kFree) { - val = BitAnd(val, ~(static_cast(ALLOC_BIT))); - } else if (st == STATE::kEmpty) { - val = 0; - } - clr = static_cast(~(BitRightShift(mask, shift))); - map_[pos] = static_cast(BitAnd(map_[pos], clr)); - map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); - if (st == STATE::kAlloc) { - count_[log_sz]--; - } else if (st == STATE::kFree) { - count_[log_sz]++; - if (rel_addr < hint_[log_sz]) { - hint_[log_sz] = rel_addr; - } - } -} - -void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { - while (blk_sz < BitLeftShift(1, num_lvl_)) { - rel_addr_t buddy = BitEx(addr, blk_sz); - size_t sz = 0; - STATE st; - GetBuddySegState(buddy, &sz, &st); - if (st == STATE::kFree && sz == blk_sz) { - auto log_sz = static_cast(Log2(blk_sz)); - rel_addr_t left = (buddy < addr) ? buddy : addr; - rel_addr_t right = left + blk_sz; - MS_ASSERT(count_[log_sz] >= 2); - count_[log_sz] -= 2; - SetBuddySegState(right, blk_sz, STATE::kEmpty); - SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); - for (int i = 0; i < log_sz; i++) { - if (hint_[i] == right) { - hint_[i] = left; - } - } - addr = left; - blk_sz <<= 1u; - } else { - break; - } - } -} - -void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { - MS_ASSERT(ask_sz < blk_sz); - uint32_t inx = Log2(blk_sz); - size_t remaining_sz = ask_sz; - for (int i = inx; i > 0; i--) { - size_t b_size = BitLeftShift(1, i); - size_t half_sz = BitRightShift(b_size, 1); - count_[i]--; - SetBuddySegState(addr, half_sz, STATE::kFree); - SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); - if (remaining_sz >= half_sz) { - SetBuddySegState(addr, half_sz, STATE::kAlloc); - remaining_sz -= half_sz; - if (remaining_sz == 0) { - break; - } - addr += half_sz; - } - } -} - -void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { - MS_ASSERT(ask_sz < blk_sz); - uint32_t inx = Log2(blk_sz); - size_t remaining_sz = ask_sz; - for (int i = inx; i > 0; i--) { - size_t b_size = BitLeftShift(1, i); - size_t half_sz = BitRightShift(b_size, 1); - if (remaining_sz >= half_sz) { -#ifdef DEBUG - { - size_t sz = 0; - STATE st; - GetBuddySegState(addr, &sz, &st); - MS_ASSERT(sz == half_sz && st == STATE::kAlloc); - } -#endif - SetBuddySegState(addr, half_sz, STATE::kFree); - remaining_sz -= half_sz; - if (remaining_sz == 0) { - JoinBuddySeg(addr, half_sz); - break; - } - addr += half_sz; - } - } -} - -rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { - uint32_t blk_size = NextPowerOf2(req_size); - int start_inx = static_cast(Log2(blk_size)); - bool found = false; - rel_addr_t ask_addr = 0; - auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); - STATE st; - size_t sz = 0; - for (int i = start_inx; !found && i < num_lvl_; i++) { - MS_ASSERT(count_[i] >= 0); - if (count_[i] == 0) { - continue; - } - auto blk_sz = static_cast(BitLeftShift(1, i)); - ask_addr = hint_[i]; - while (ask_addr < max_addr && !found) 
{ - GetBuddySegState(ask_addr, &sz, &st); - if (st == STATE::kFree && sz == blk_sz) { - found = true; - } else { - MS_ASSERT(st != STATE::kEmpty); - ask_addr += ((sz > blk_sz) ? sz : blk_sz); - } - } - } - if (found) { - if (sz > req_size) { - TrimBuddySeg(ask_addr, sz, req_size); - } else { - SetBuddySegState(ask_addr, sz, STATE::kAlloc); - hint_[start_inx] = ask_addr; - } - return ask_addr; - } else { - return static_cast(NOSPACE); - } -} - -void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { - if (req_size == blk_size) { -#ifdef DEBUG - { - size_t sz = 0; - STATE st; - GetBuddySegState(addr, &sz, &st); - } -#endif - SetBuddySegState(addr, blk_size, STATE::kFree); - JoinBuddySeg(addr, blk_size); - } else { - UnTrimBuddySeg(addr, blk_size, req_size); - } -} - -int BuddySpace::PercentFree() const { - uint64_t total_free_sz = 0; - uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); - // Go through the count array without lock - for (int i = 0; i < num_lvl_; i++) { - int cnt = count_[i]; - if (cnt == 0) { - continue; - } - uint64_t blk_sz = BitLeftShift(1, i); - total_free_sz += (blk_sz * cnt); - } - return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); -} - -BuddySpace::BuddySpace(int log_min, int num_lvl) - : hint_(nullptr), - count_(nullptr), - map_(nullptr), - log_min_(log_min), - num_lvl_(num_lvl), - min_(0), - max_(0), - ptr_(nullptr) {} - -BuddySpace::~BuddySpace() { - if (ptr_ != nullptr) { - free(ptr_); - } - hint_ = nullptr; - count_ = nullptr; - map_ = nullptr; -} - -Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { - Status rc; - auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); - if (bs == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = bs->Init(); - if (rc.IsOk()) { - (*out_bs).reset(bs); - } else { - delete bs; - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/buddy.h b/mindspore/ccsrc/dataset/util/buddy.h deleted file mode 100644 index 08c05cbbdb..0000000000 --- a/mindspore/ccsrc/dataset/util/buddy.h +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_BUDDY_H_ -#define DATASET_UTIL_BUDDY_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/util/status.h" - -using addr_t = int64_t; -using rel_addr_t = int32_t; -using log_t = int; -#define ALLOC_BIT 0x80 -#define ONE_BIT 0x40 -#define TWO_BIT 0x20 -#define MORE_BIT 0x10 -#define NOSPACE ((addr_t)(-1)) -namespace mindspore { -namespace dataset { -struct BSpaceDescriptor { - int32_t sig; - rel_addr_t addr; - size_t req_size; - size_t blk_size; -}; - -class BuddySpace { - public: - // C++11 feature. Change STATE into a type safe class with - // the keyword. 
Don't take out the keyword 'class' - enum class STATE { kFree, kAlloc, kEmpty }; - - BuddySpace(const BuddySpace &) = delete; - - BuddySpace &operator=(const BuddySpace &) = delete; - - virtual ~BuddySpace(); - - Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; - - void Free(const BSpaceDescriptor *desc); - - uint64_t GetMinSize() const { return min_; } - - uint64_t GetMaxSize() const { return max_; } - - int PercentFree() const; - - friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); - - static uint64_t NextPowerOf2(uint64_t n) { - if (n <= 1) { - return 1; - } - n = n - 1; - while (n & (n - 1)) { - n = n & (n - 1); - } - return n << 1; - } - - static uint32_t Log2(uint64_t n) { - uint32_t cnt = 0; - while (n >>= 1) { - cnt++; - } - return cnt; - } - - static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); - - private: - rel_addr_t *hint_; - int *count_; - char *map_; - int log_min_; - int num_lvl_; - uint64_t min_; - uint64_t max_; - void *ptr_; - std::mutex mutex_; - - explicit BuddySpace(int log_min = 15, int num_lvl = 18); - - Status Init(); - - addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; - - void FreeNoLock(const BSpaceDescriptor *desc); - - uint32_t SizeToBlock(const uint64_t sz) const { - uint32_t reqSize = (sz / min_); - if (sz % min_) { - reqSize++; - } - return reqSize; - } - - void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; - - void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); - - void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); - - void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); - - void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); - - rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; - - void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/dataset/util/cache_pool.cc b/mindspore/ccsrc/dataset/util/cache_pool.cc deleted file mode 100644 index 7d7a2a4a94..0000000000 --- a/mindspore/ccsrc/dataset/util/cache_pool.cc +++ /dev/null @@ -1,197 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include "common/utils.h" -#include "dataset/util/cache_pool.h" -#include "dataset/util/services.h" - -namespace mindspore { -namespace dataset { -CachePool::CachePool(const value_allocator &alloc, const std::string &root) - : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} - -Status CachePool::DoServiceStart() { - tree_ = std::make_shared(); - // If we are given a disk path, set up the StorageManager - if (!root_.toString().empty()) { - Path spill = GetSpillPath(); - RETURN_IF_NOT_OK(spill.CreateDirectories()); - sm_ = std::make_shared(spill); - RETURN_IF_NOT_OK(sm_->ServiceStart()); - MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); - } - return Status::OK(); -} -Status CachePool::DoServiceStop() { - Status rc; - Status rc2; - if (sm_ != nullptr) { - rc = sm_->ServiceStop(); - if (rc.IsError()) { - rc2 = rc; - } - } - sm_.reset(); - for (auto &bl : *tree_) { - if (bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, bl.sz); - } - } - tree_.reset(); - if (!root_.toString().empty()) { - Path spill = GetSpillPath(); - auto it = Path::DirIterator::OpenDirectory(&spill); - while (it->hasNext()) { - rc = it->next().Remove(); - if (rc.IsError() && rc2.IsOk()) { - rc2 = rc; - } - } - rc = spill.Remove(); - if (rc.IsError() && rc2.IsOk()) { - rc2 = rc; - } - } - return rc2; -} -CachePool::~CachePool() noexcept { (void)ServiceStop(); } -Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { - DataLocator bl; - Status rc; - size_t sz = 0; - // We will consolidate all the slices into one piece. - for (auto &v : buf) { - sz += v.GetSize(); - } - bl.sz = sz; - try { - bl.ptr = alloc_.allocate(sz); - // We will do a piecewise copy. - WritableSlice dest(bl.ptr, bl.sz); - size_t pos = 0; - for (auto &v : buf) { - WritableSlice out(dest, pos); - rc = WritableSlice::Copy(&out, v); - if (rc.IsError()) { - break; - } - pos += v.GetSize(); - } - if (rc.IsError()) { - alloc_.deallocate(bl.ptr, sz); - bl.ptr = nullptr; - return rc; - } - } catch (std::bad_alloc &e) { - if (sm_ != nullptr) { - RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); - } else { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - rc = tree_->insert(bl, key); - if (rc.IsError() && bl.ptr != nullptr) { - alloc_.deallocate(bl.ptr, sz); - } - return rc; -} -Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { - RETURN_UNEXPECTED_IF_NULL(dest); - auto r = tree_->Search(key); - if (r.second) { - auto &it = r.first; - if (it->ptr != nullptr) { - ReadableSlice src(it->ptr, it->sz); - RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); - } else if (sm_ != nullptr) { - size_t expectedLength = 0; - RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); - if (expectedLength != it->sz) { - MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." - << " Internal key: " << key << "\n"; - RETURN_STATUS_UNEXPECTED("Length mismatch. 
See log file for details."); - } - } - if (bytesRead != nullptr) { - *bytesRead = it->sz; - } - } else { - RETURN_STATUS_UNEXPECTED("Key not found"); - } - return Status::OK(); -} -const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } -Path CachePool::GetSpillPath() const { - auto spill = Path(root_) / subfolder_; - return spill; -} -CachePool::CacheStat CachePool::GetStat() const { - CacheStat cs{0}; - for (auto &it : *tree_) { - if (it.ptr != nullptr) { - ++cs.num_mem_cached; - } else { - ++cs.num_disk_cached; - } - } - return cs; -} -Status CachePool::Spill(CachePool::DataLocator *dl) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to spill"); - } - RETURN_UNEXPECTED_IF_NULL(dl); - RETURN_UNEXPECTED_IF_NULL(dl->ptr); - if (dl->storage_key == 0) { - ReadableSlice data(dl->ptr, dl->sz); - RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); - } - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return Status::OK(); -} -Status CachePool::Locate(CachePool::DataLocator *dl) { - RETURN_UNEXPECTED_IF_NULL(dl); - if (dl->ptr == nullptr) { - if (sm_ == nullptr) { - RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); - } - try { - dl->ptr = alloc_.allocate(dl->sz); - WritableSlice dest(dl->ptr, dl->sz); - Status rc = Read(dl->storage_key, &dest); - if (rc.IsError()) { - alloc_.deallocate(dl->ptr, dl->sz); - dl->ptr = nullptr; - return rc; - } - } catch (const std::bad_alloc &e) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - return Status::OK(); -} -size_t CachePool::GetSize(CachePool::key_type key) const { - auto r = tree_->Search(key); - if (r.second) { - auto &it = r.first; - return it->sz; - } else { - return 0; - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cache_pool.h b/mindspore/ccsrc/dataset/util/cache_pool.h deleted file mode 100644 index d35617d0e4..0000000000 --- a/mindspore/ccsrc/dataset/util/cache_pool.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_CACHE_POOL_H_ -#define DATASET_UTIL_CACHE_POOL_H_ - -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/service.h" -#include "dataset/util/slice.h" -#include "dataset/util/storage_manager.h" -#include "dataset/util/auto_index.h" - -namespace mindspore { -namespace dataset { -/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of -/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to -/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to -/// restore the buffer. 
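-/// A minimal round trip might look like this (illustrative sketch):
-///   CachePool::key_type key;
-///   RETURN_IF_NOT_OK(pool.Insert({ReadableSlice(src, sz)}, &key));  // back up src
-///   WritableSlice dest(dst, sz);
-///   RETURN_IF_NOT_OK(pool.Read(key, &dest));  // restore into dst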
-/// \see ReadableSlice -class CachePool : public Service { - public: - using base_type = uint8_t; - using pointer = base_type *; - using const_pointer = const base_type *; - using reference = base_type &; - using const_reference = const base_type &; - using value_allocator = Allocator; - - // An internal class to locate the whereabouts of a backed up buffer which can be either in - class DataLocator { - public: - DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} - ~DataLocator() = default; - DataLocator(const DataLocator &other) = default; - DataLocator &operator=(const DataLocator &other) = default; - DataLocator(DataLocator &&other) noexcept { - ptr = other.ptr; - sz = other.sz; - storage_key = other.storage_key; - other.ptr = nullptr; - other.sz = 0; - other.storage_key = 0; - } - DataLocator &operator=(DataLocator &&other) noexcept { - if (&other != this) { - ptr = other.ptr; - sz = other.sz; - storage_key = other.storage_key; - other.ptr = nullptr; - other.sz = 0; - other.storage_key = 0; - } - return *this; - } - pointer ptr; - size_t sz; - StorageManager::key_type storage_key; - }; - - using data_index = AutoIndexObj; - using key_type = data_index::key_type; - using bl_alloc_type = typename value_allocator::template rebind::other; - - /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and - /// how many elements are spilled to disk. - struct CacheStat { - int64_t num_mem_cached; - int64_t num_disk_cached; - }; - - /// \brief Constructor - /// \param alloc Allocator to allocate memory from - /// \param root Optional disk folder to spill - explicit CachePool(const value_allocator &alloc, const std::string &root = ""); - - CachePool(const CachePool &) = delete; - CachePool(CachePool &&) = delete; - CachePool &operator=(const CachePool &) = delete; - CachePool &operator=(CachePool &&) = delete; - ~CachePool() noexcept; - - Status DoServiceStart() override; - Status DoServiceStop() override; - - Path GetSpillPath() const; - - /// \brief Insert a sequence of ReadableSlice objects into the pool. - /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. - /// \param[in] buf A sequence of ReadableSlice objects. - /// \param[out] key Generated key - /// \return Error code - Status Insert(const std::vector &buf, key_type *key); - /// \brief Restore a cached buffer (from memory or disk) - /// \param[in] key A previous key returned from Insert - /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice - /// \param[out] bytesRead Optional. Number of bytes read. - /// \return Error code - Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; - - Status Spill(DataLocator *dl); - - Status Locate(DataLocator *dl); - - size_t GetSize(key_type key) const; - - /// \brief Get statistics. 
- /// \return CacheStat object - CacheStat GetStat() const; - - const value_allocator &get_allocator() const; - - std::string MyName() const { return subfolder_; } - - private: - value_allocator alloc_; - Path root_; - const std::string subfolder_; - std::shared_ptr sm_; - std::shared_ptr tree_; -}; -} // namespace dataset -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/dataset/util/circular_pool.cc b/mindspore/ccsrc/dataset/util/circular_pool.cc deleted file mode 100644 index 42cccd87ed..0000000000 --- a/mindspore/ccsrc/dataset/util/circular_pool.cc +++ /dev/null @@ -1,225 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/circular_pool.h" - -#include -#include -#include -#include "./securec.h" -#include "dataset/util/system_pool.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -Status CircularPool::AddOneArena() { - Status rc; - std::shared_ptr b; - RETURN_IF_NOT_OK(Arena::CreateArena(&b, arena_size_)); - tail_ = b.get(); - cur_size_in_mb_ += arena_size_; - mem_segments_.push_back(std::move(b)); - return Status::OK(); -} - -ListOfArenas::iterator CircularPool::CircularIterator::Next() { - ListOfArenas::iterator it = dp_->mem_segments_.begin(); - uint32_t size = dp_->mem_segments_.size(); - // This is what we return - it += cur_; - // Prepare for the next round - cur_++; - if (cur_ == size) { - if (start_ == 0) { - has_next_ = false; - } else { - wrap_ = true; - cur_ = 0; - } - } else if (cur_ == start_) { - has_next_ = false; - } - return it; -} - -bool CircularPool::CircularIterator::has_next() const { return has_next_; } - -void CircularPool::CircularIterator::Reset() { - wrap_ = false; - has_next_ = false; - if (!dp_->mem_segments_.empty()) { - // Find the buddy arena that corresponds to the tail. 
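      // An allocation scan therefore begins at the most recently used arena:
      // start_ and cur_ are set to its index, and Next() hands out each arena at
      // most once (wrapping to index 0) before has_next() goes false.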
- cur_tail_ = dp_->tail_; - auto list_end = dp_->mem_segments_.end(); - auto it = std::find_if(dp_->mem_segments_.begin(), list_end, - [this](const std::shared_ptr &b) { return b.get() == cur_tail_; }); - MS_ASSERT(it != list_end); - start_ = std::distance(dp_->mem_segments_.begin(), it); - cur_ = start_; - has_next_ = true; - } -} - -CircularPool::CircularIterator::CircularIterator(CircularPool *dp) : dp_(dp) { Reset(); } - -Status CircularPool::Allocate(size_t n, void **p) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - Status rc; - void *ptr = nullptr; - do { - SharedLock lock_s(&rw_lock_); - int prevSzInMB = cur_size_in_mb_; - bool move_tail = false; - CircularIterator cirIt(this); - while (cirIt.has_next()) { - auto it = cirIt.Next(); - Arena *ba = it->get(); - if (ba->get_max_size() < n) { - return Status(StatusCode::kOutOfMemory); - } - // If we are asked to move forward the tail - if (move_tail) { - Arena *expected = cirIt.cur_tail_; - (void)atomic_compare_exchange_weak(&tail_, &expected, ba); - move_tail = false; - } - rc = ba->Allocate(n, &ptr); - if (rc.IsOk()) { - *p = ptr; - break; - } else if (rc.IsOutofMemory()) { - // Make the next arena a new tail and continue. - move_tail = true; - } else { - return rc; - } - } - - // Handle the case we have done one round robin search. - if (ptr == nullptr) { - // If we have room to expand. - if (unlimited_ || cur_size_in_mb_ < max_size_in_mb_) { - // lock in exclusively mode. - lock_s.Upgrade(); - // Check again if someone has already expanded. - if (cur_size_in_mb_ == prevSzInMB) { - RETURN_IF_NOT_OK(AddOneArena()); - } - // Re-acquire the shared lock and try again - lock_s.Downgrade(); - } else { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } - } - } while (ptr == nullptr); - return rc; -} - -void CircularPool::Deallocate(void *p) { - // Lock in the chain in shared mode and find out which - // segment it comes from - SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { - char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + b->get_max_size()); - }); - lock.Unlock(); - it->get()->Deallocate(p); -} - -Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { - // Lock in the chain in shared mode and find out which - // segment it comes from - if (pp == nullptr) { - RETURN_STATUS_UNEXPECTED("pp is null"); - } - void *p = *pp; - SharedLock lock(&rw_lock_); - auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { - char *q = reinterpret_cast(p); - char *base = const_cast(reinterpret_cast(b->get_base_addr())); - return (q > base && q < base + b->get_max_size()); - }); - lock.Unlock(); - MS_ASSERT(it != mem_segments_.end()); - Arena *ba = it->get(); - Status rc = ba->Reallocate(pp, old_sz, new_sz); - if (rc.IsOutofMemory()) { - // The current arena has no room for the bigger size. - // Allocate free space from another arena and copy - // the content over. 
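    // In other words: grab a fresh block with Allocate() (which may even grow the
    // pool by one arena), copy the payload across with memcpy_s, repoint *pp, and
    // only then return the original block to the arena it came from.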
- void *q = nullptr; - rc = this->Allocate(new_sz, &q); - RETURN_IF_NOT_OK(rc); - errno_t err = memcpy_s(q, new_sz, p, old_sz); - if (err) { - this->Deallocate(q); - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - *pp = q; - ba->Deallocate(p); - } - return Status::OK(); -} - -uint64_t CircularPool::get_max_size() const { return mem_segments_.front()->get_max_size(); } - -int CircularPool::PercentFree() const { - int percent_free = 0; - int num_arena = 0; - for (auto const &p : mem_segments_) { - percent_free += p->PercentFree(); - num_arena++; - } - if (num_arena) { - return percent_free / num_arena; - } else { - return 100; - } -} - -CircularPool::CircularPool(int max_size_in_gb, int arena_size) - : unlimited_(max_size_in_gb <= 0), - max_size_in_mb_(unlimited_ ? std::numeric_limits::max() : max_size_in_gb * 1024), - arena_size_(arena_size), - cur_size_in_mb_(0) {} - -Status CircularPool::CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb, int arena_size, - bool createOneArena) { - Status rc; - if (out_pool == nullptr) { - RETURN_STATUS_UNEXPECTED("pPool is null"); - } - auto pool = new (std::nothrow) CircularPool(max_size_in_gb, arena_size); - if (pool == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - if (createOneArena) { - rc = pool->AddOneArena(); - } - if (rc.IsOk()) { - (*out_pool).reset(pool); - } else { - delete pool; - } - return rc; -} - -CircularPool::~CircularPool() = default; -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/circular_pool.h b/mindspore/ccsrc/dataset/util/circular_pool.h deleted file mode 100644 index 3c52659799..0000000000 --- a/mindspore/ccsrc/dataset/util/circular_pool.h +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_CIRCULAR_POOL_H_ -#define DATASET_UTIL_CIRCULAR_POOL_H_ - -#include -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/arena.h" -#include "dataset/util/lock.h" - -namespace mindspore { -namespace dataset { -using ListOfArenas = std::vector>; - -// This is a dynamic memory pool built on top of memory -// segment each of which is 4G in size. Initially we start -// with one segment, and gradually add segments (not -// guaranteed contiguous) until we reach 32G in size. There -// is an assumption about this kind of memory pool. Allocated -// memory is not held for the whole duration of the pool and -// will be released soon. Based on this assumption, memory is -// obtained from the tail while allocated memory is returned -// to the head of the pool. 
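Concretely, the pool is built through the static factory declared in the header below; the arena size is in megabytes, the overall cap is in gigabytes, and a non-positive cap means unlimited. A small usage sketch under those assumptions, borrowing the 16 MB arena size that Services uses later in this patch (the function name is illustrative):

```cpp
#include "dataset/util/circular_pool.h"

using mindspore::dataset::CircularPool;
using mindspore::dataset::MemoryPool;
using mindspore::dataset::Status;

Status Demo() {
  std::shared_ptr<MemoryPool> pool;
  // -1: no overall cap; 16 MB arenas; create the first arena eagerly.
  RETURN_IF_NOT_OK(CircularPool::CreateCircularPool(&pool, -1, 16, true));

  void *p = nullptr;
  RETURN_IF_NOT_OK(pool->Allocate(1 << 20, &p));  // 1 MB, carved from the tail arena
  pool->Deallocate(p);                            // returned to whichever arena owns it
  return Status::OK();
}
```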
-class CircularPool : public MemoryPool {
- public:
-  class CircularIterator {
-    friend class CircularPool;
-
-   public:
-    explicit CircularIterator(CircularPool *dp);
-
-    ~CircularIterator() = default;
-
-    bool has_next() const;
-
-    ListOfArenas::iterator Next();
-
-    void Reset();
-
-   private:
-    CircularPool *dp_;
-    Arena *cur_tail_{};
-    uint32_t start_{};
-    uint32_t cur_{};
-    bool wrap_{};
-    bool has_next_{};
-  };
-
-  CircularPool(const CircularPool &) = delete;
-
-  CircularPool &operator=(const CircularPool &) = delete;
-
-  ~CircularPool() override;
-
-  Status Allocate(size_t n, void **) override;
-
-  Status Reallocate(void **, size_t old_size, size_t new_size) override;
-
-  void Deallocate(void *) override;
-
-  uint64_t get_max_size() const override;
-
-  int PercentFree() const override;
-
-  friend std::ostream &operator<<(std::ostream &os, const CircularPool &s) {
-    int i = 0;
-    for (auto it = s.mem_segments_.begin(); it != s.mem_segments_.end(); ++it, ++i) {
-      os << "Dumping segment " << i << "\n" << *(it->get());
-    }
-    return os;
-  }
-
-  static Status CreateCircularPool(std::shared_ptr<MemoryPool> *out_pool, int max_size_in_gb = -1,
-                                   int arena_size = 4096, bool create_one_arena = false);
-
- private:
-  ListOfArenas mem_segments_;
-  std::atomic<Arena *> tail_{};
-  bool unlimited_;
-  int max_size_in_mb_;
-  int arena_size_;
-  int cur_size_in_mb_;
-  RWLock rw_lock_;
-
-  // We can take negative or 0 as input, which means unlimited.
-  CircularPool(int max_size_in_gb, int arena_size);
-
-  Status AddOneArena();
-};
-}  // namespace dataset
-}  // namespace mindspore
-
-#endif  // DATASET_UTIL_CIRCULAR_POOL_H_
diff --git a/mindspore/ccsrc/dataset/util/cond_var.cc b/mindspore/ccsrc/dataset/util/cond_var.cc
deleted file mode 100644
index 8b1099fb71..0000000000
--- a/mindspore/ccsrc/dataset/util/cond_var.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "dataset/util/cond_var.h"
-#include <chrono>
-#include <exception>
-#include "dataset/util/services.h"
-#include "dataset/util/task_manager.h"
-
-namespace mindspore {
-namespace dataset {
-CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {}
-
-Status CondVar::Wait(std::unique_lock<std::mutex> *lck, const std::function<bool()> &pred) {
-  try {
-    if (svc_ != nullptr) {
-      // If this cv registers with a global resource tracking, then wait unconditionally.
-      auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); };
-      cv_.wait(*lck, f);
-      // If we are interrupted, override the return value if this is the master thread.
-      // The master thread is interrupted mostly because some other thread is reporting an error.
-      RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus()));
-    } else {
-      // Otherwise we wake up once in a while to check for interrupt (for this thread).
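      // The 1 ms wait_for below is the polling interval, so a waiter that is not
      // registered with an interrupt service notices this_thread::is_interrupted()
      // within roughly a millisecond of the flag being raised.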
- auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; - while (!f()) { - (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); - } - RETURN_IF_INTERRUPTED(); - } - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); -} - -CondVar::~CondVar() noexcept { - if (svc_ != nullptr) { - (void)svc_->Deregister(my_name_); - svc_ = nullptr; - } -} - -void CondVar::NotifyOne() noexcept { cv_.notify_one(); } - -void CondVar::NotifyAll() noexcept { cv_.notify_all(); } - -Status CondVar::Register(std::shared_ptr svc) { - Status rc = svc->Register(my_name_, this); - if (rc.IsOk()) { - svc_ = svc; - } - return rc; -} - -void CondVar::Interrupt() { - IntrpResource::Interrupt(); - cv_.notify_all(); -} - -std::string CondVar::my_name() const { return my_name_; } - -Status CondVar::Deregister() { - if (svc_) { - Status rc = svc_->Deregister(my_name_); - svc_ = nullptr; - return rc; - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/cond_var.h b/mindspore/ccsrc/dataset/util/cond_var.h deleted file mode 100644 index b23dcd566e..0000000000 --- a/mindspore/ccsrc/dataset/util/cond_var.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_COND_VAR_H_ -#define DATASET_UTIL_COND_VAR_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/intrp_resource.h" -#include "dataset/util/intrp_service.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class CondVar : public IntrpResource { - public: - CondVar(); - - ~CondVar() noexcept; - - Status Wait(std::unique_lock *lck, const std::function &pred); - - void Interrupt() override; - - void NotifyOne() noexcept; - - void NotifyAll() noexcept; - - Status Register(std::shared_ptr svc); - - std::string my_name() const; - - Status Deregister(); - - protected: - std::condition_variable cv_; - std::shared_ptr svc_; - - private: - std::string my_name_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_COND_VAR_H_ diff --git a/mindspore/ccsrc/dataset/util/intrp_resource.h b/mindspore/ccsrc/dataset/util/intrp_resource.h deleted file mode 100644 index 52024cb90a..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_resource.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INTRP_RESOURCE_H_ -#define DATASET_UTIL_INTRP_RESOURCE_H_ - -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class IntrpResource { - public: - enum class State : int { kRunning, kInterrupted }; - - IntrpResource() : st_(State::kRunning) {} - - virtual ~IntrpResource() = default; - - virtual void Interrupt() { st_ = State::kInterrupted; } - - virtual void ResetIntrpState() { st_ = State::kRunning; } - - State CurState() const { return st_; } - - bool Interrupted() const { return CurState() == State::kInterrupted; } - - virtual Status GetInterruptStatus() const { - if (Interrupted()) { - return Status(StatusCode::kInterrupted); - } - return Status::OK(); - } - - protected: - std::atomic st_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INTRP_RESOURCE_H_ diff --git a/mindspore/ccsrc/dataset/util/intrp_service.cc b/mindspore/ccsrc/dataset/util/intrp_service.cc deleted file mode 100644 index da8dde992c..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_service.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/intrp_service.h" -#include -#include "common/utils.h" -#include "dataset/util/services.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); } - -IntrpService::~IntrpService() noexcept { - MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; - if (!all_intrp_resources_.empty()) { - try { - InterruptAll(); - } catch (const std::exception &e) { - // Ignore all error as we can't throw in the destructor. - } - } - (void)ServiceStop(); -} - -Status IntrpService::Register(const std::string &name, IntrpResource *res) { - SharedLock stateLck(&state_lock_); - // Now double check the state - if (ServiceState() != STATE::kRunning) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Interrupt service is shutting down"); - } else { - std::lock_guard lck(mutex_); - try { - std::ostringstream ss; - ss << this_thread::get_id(); - MS_LOG(DEBUG) << "Register resource with name " << name << ". Thread ID " << ss.str() << "."; - auto it = all_intrp_resources_.emplace(name, res); - if (it.second == false) { - return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, name); - } - high_water_mark_++; - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } - return Status::OK(); -} - -Status IntrpService::Deregister(const std::string &name) noexcept { - std::lock_guard lck(mutex_); - try { - std::ostringstream ss; - ss << this_thread::get_id(); - MS_LOG(DEBUG) << "De-register resource with name " << name << ". 
Thread ID is " << ss.str() << "."; - auto n = all_intrp_resources_.erase(name); - if (n == 0) { - MS_LOG(INFO) << "Key " << name << " not found."; - } - } catch (std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - return Status::OK(); -} - -void IntrpService::InterruptAll() noexcept { - std::lock_guard lck(mutex_); - for (auto const &it : all_intrp_resources_) { - std::string kName = it.first; - try { - it.second->Interrupt(); - } catch (const std::exception &e) { - // continue the clean up. - } - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/intrp_service.h b/mindspore/ccsrc/dataset/util/intrp_service.h deleted file mode 100644 index de1d5eb753..0000000000 --- a/mindspore/ccsrc/dataset/util/intrp_service.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_INTRP_SERVICE_H_ -#define DATASET_UTIL_INTRP_SERVICE_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/intrp_resource.h" -#include "dataset/util/service.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -using SvcAllocator = Allocator>; - -class IntrpService : public Service { - public: - IntrpService(); - - ~IntrpService() noexcept override; - - IntrpService(const IntrpService &) = delete; - - IntrpService &operator=(const IntrpService &) = delete; - - Status Register(const std::string &name, IntrpResource *res); - - Status Deregister(const std::string &name) noexcept; - - void InterruptAll() noexcept; - - Status DoServiceStart() override { return Status::OK(); } - - Status DoServiceStop() override { return Status::OK(); } - - private: - int high_water_mark_; - std::mutex mutex_; - std::map all_intrp_resources_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_INTRP_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/lock.cc b/mindspore/ccsrc/dataset/util/lock.cc deleted file mode 100644 index bde9d84005..0000000000 --- a/mindspore/ccsrc/dataset/util/lock.cc +++ /dev/null @@ -1,185 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/util/lock.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -void SpinLock::Lock() { - while (true) { - int expected = kUnlocked; - if (val_.compare_exchange_weak(expected, kLocked)) { - break; - } - } -} - -bool SpinLock::TryLock() { - int expected = kUnlocked; - return val_.compare_exchange_strong(expected, kLocked); -} - -void SpinLock::Unlock() noexcept { val_.store(kUnlocked); } - -void RWLock::LockShared() { - std::unique_lock lck(mtx_); - waiting_readers_ += 1; - read_cv_.wait(lck, [this]() { return (waiting_writers_ == 0 && status_ >= 0); }); - waiting_readers_ -= 1; - status_ += 1; -} - -void RWLock::Unlock() noexcept { - std::unique_lock lck(mtx_); - if (status_ == -1) { - // I am the writer. By definition, no other writer nor reader. - status_ = 0; - } else if (status_ > 0) { - // One less reader - status_ -= 1; - } - // Wake up writer only if there is no reader. - if (waiting_writers_ > 0) { - if (status_ == 0) { - write_cv_.notify_one(); - } - } else { - read_cv_.notify_all(); - } -} - -void RWLock::Upgrade() { - std::unique_lock lck(mtx_); - MS_ASSERT(status_); - if (status_ == -1) { - // I am a writer already. - return; - } else if (status_ == 1) { - // If I am the only reader. Just change the status. - status_ = -1; - return; - } else { - // In all other cases, let of the shared lock and relock in exclusive. - lck.unlock(); - this->Unlock(); - this->LockExclusive(); - } -} - -void RWLock::Downgrade() { - std::unique_lock lck(mtx_); - MS_ASSERT(status_); - if (status_ == -1) { - // If there are no other writers waiting, just change the status - if (waiting_writers_ == 0) { - status_ = 1; - } else { - // Otherwise just unlock and relock in shared - lck.unlock(); - this->Unlock(); - this->LockShared(); - } - } else if (status_ > 0) { - return; - } -} - -SharedLock::SharedLock(RWLock *rw) : rw_(rw), ownlock_(false) { - rw_->LockShared(); - ownlock_ = true; -} - -SharedLock::~SharedLock() { - if (ownlock_) { - rw_->Unlock(); - ownlock_ = false; - } - rw_ = nullptr; -} - -void SharedLock::Unlock() { - MS_ASSERT(ownlock_ == true); - rw_->Unlock(); - ownlock_ = false; -} - -void SharedLock::Lock() { - MS_ASSERT(ownlock_ == false); - rw_->LockShared(); - ownlock_ = true; -} - -void SharedLock::Upgrade() { - MS_ASSERT(ownlock_ == true); - rw_->Upgrade(); -} - -void SharedLock::Downgrade() { - MS_ASSERT(ownlock_ == true); - rw_->Downgrade(); -} - -UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) { - rw_->LockExclusive(); - ownlock_ = true; -} - -UniqueLock::~UniqueLock() { - if (ownlock_) { - rw_->Unlock(); - ownlock_ = false; - } - rw_ = nullptr; -} - -void UniqueLock::Unlock() { - MS_ASSERT(ownlock_ == true); - rw_->Unlock(); - ownlock_ = false; -} - -void UniqueLock::Lock() { - MS_ASSERT(ownlock_ == false); - rw_->LockExclusive(); - ownlock_ = true; -} - -LockGuard::LockGuard(SpinLock *lock) : lck_(lock), own_lock_(false) { - lck_->Lock(); - own_lock_ = true; -} - -LockGuard::~LockGuard() { - if (own_lock_) { - lck_->Unlock(); - own_lock_ = false; - } - lck_ = nullptr; -} - -void LockGuard::Unlock() { - MS_ASSERT(own_lock_); - lck_->Unlock(); - own_lock_ = false; -} - -void LockGuard::Lock() { - MS_ASSERT(own_lock_ == false); - lck_->Lock(); - own_lock_ = true; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/memory_pool.cc b/mindspore/ccsrc/dataset/util/memory_pool.cc deleted file mode 100644 index 5d66b4bd6d..0000000000 --- 
a/mindspore/ccsrc/dataset/util/memory_pool.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/memory_pool.h" -#include "./securec.h" - -namespace mindspore { -namespace dataset { -Status DeMalloc(std::size_t s, void **p, bool init_to_zero = false) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - void *q = ::malloc(s); - if (q == nullptr) { - return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); - } else { - *p = q; - if (init_to_zero) { - (void)memset_s(q, s, 0, s); - } - return Status::OK(); - } -} -} // namespace dataset -} // namespace mindspore - -void *operator new(std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { - void *ptr = nullptr; - *rc = b->Allocate(s, &ptr); - return ptr; -} - -void *operator new[](std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { - void *ptr = nullptr; - *rc = b->Allocate(s, &ptr); - return ptr; -} - -void operator delete(void *p, std::shared_ptr b) { - if (p != nullptr) b->Deallocate(p); -} - -void operator delete[](void *p, std::shared_ptr b) { - if (p != nullptr) b->Deallocate(p); -} diff --git a/mindspore/ccsrc/dataset/util/memory_pool.h b/mindspore/ccsrc/dataset/util/memory_pool.h deleted file mode 100644 index ee1da3bda1..0000000000 --- a/mindspore/ccsrc/dataset/util/memory_pool.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_MEMORY_POOL_H_ -#define DATASET_UTIL_MEMORY_POOL_H_ - -#include -#include -#include -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -// Abstract class of a memory pool -class MemoryPool { - public: - // Allocate a block of size n - virtual Status Allocate(size_t, void **) = 0; - - // Enlarge or shrink a block from oldSz to newSz - virtual Status Reallocate(void **, size_t old_sz, size_t new_sz) = 0; - - // Free a pointer - virtual void Deallocate(void *) = 0; - - // What is the maximum size I can allocate ? 
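  // (For CircularPool this is a single arena's capacity, since one allocation can
  // never straddle two arenas.)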
- virtual uint64_t get_max_size() const = 0; - - virtual int PercentFree() const = 0; - - // Destructor - virtual ~MemoryPool() {} -}; - -Status DeMalloc(std::size_t s, void **p, bool); -} // namespace dataset -} // namespace mindspore - -void *operator new(std::size_t, mindspore::dataset::Status *, std::shared_ptr); - -void *operator new[](std::size_t, mindspore::dataset::Status *, std::shared_ptr); - -void operator delete(void *, std::shared_ptr); - -void operator delete[](void *, std::shared_ptr); - -#endif // DATASET_UTIL_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/dataset/util/path.cc b/mindspore/ccsrc/dataset/util/path.cc deleted file mode 100644 index cdd2343799..0000000000 --- a/mindspore/ccsrc/dataset/util/path.cc +++ /dev/null @@ -1,340 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/path.h" - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -#if defined(_WIN32) || defined(_WIN64) -char Path::separator_ = '\\'; -#else -char Path::separator_ = '/'; -#endif - -Path::Path(const std::string &s) : path_(s) {} - -Path::Path(const char *p) : path_(p) {} - -Path::Path(const Path &p) : path_(p.path_) {} - -Path &Path::operator=(const Path &p) { - if (&p != this) { - this->path_ = p.path_; - } - return *this; -} - -Path &Path::operator=(Path &&p) noexcept { - if (&p != this) { - this->path_ = std::move(p.path_); - } - return *this; -} - -Path::Path(Path &&p) noexcept { this->path_ = std::move(p.path_); } - -Path Path::operator+(const Path &p) { - std::string q = path_ + p.toString(); - return Path(q); -} - -Path Path::operator+(const std::string &p) { - std::string q = path_ + p; - return Path(q); -} - -Path Path::operator+(const char *p) { - std::string q = path_ + p; - return Path(q); -} - -Path &Path::operator+=(const Path &rhs) { - path_ += rhs.toString(); - return *this; -} - -Path &Path::operator+=(const std::string &p) { - path_ += p; - return *this; -} - -Path &Path::operator+=(const char *p) { - path_ += p; - return *this; -} - -Path Path::operator/(const Path &p) { - std::string q = path_ + separator_ + p.toString(); - return Path(q); -} - -Path Path::operator/(const std::string &p) { - std::string q = path_ + separator_ + p; - return Path(q); -} - -Path Path::operator/(const char *p) { - std::string q = path_ + separator_ + p; - return Path(q); -} - -std::string Path::Extension() const { - std::size_t found = path_.find_last_of('.'); - if (found != std::string::npos) { - return path_.substr(found); - } else { - return std::string(""); - } -} - -bool Path::Exists() { - struct stat sb; - int rc = stat(common::SafeCStr(path_), &sb); - if (rc == -1 && errno != ENOENT) { - MS_LOG(WARNING) << "Unable to query the status of " << path_ << ". 
Errno = " << errno << "."; - } - return (rc == 0); -} - -bool Path::IsDirectory() { - struct stat sb; - int rc = stat(common::SafeCStr(path_), &sb); - if (rc == 0) { - return S_ISDIR(sb.st_mode); - } else { - return false; - } -} - -Status Path::CreateDirectory() { - if (!Exists()) { -#if defined(_WIN32) || defined(_WIN64) - int rc = mkdir(common::SafeCStr(path_)); -#else - int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); -#endif - if (rc) { - std::ostringstream oss; - oss << "Unable to create directory " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - return Status::OK(); - } else { - if (IsDirectory()) { - return Status::OK(); - } else { - std::ostringstream oss; - oss << "Unable to create directory " << path_ << ". It exists but is not a directory"; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } -} - -std::string Path::ParentPath() { - std::string r(""); - std::size_t found = path_.find_last_of(separator_); - if (found != std::string::npos) { - if (found == 0) { - r += separator_; - } else { - r = std::string(path_.substr(0, found)); - } - } - return r; -} - -Status Path::CreateDirectories() { - if (IsDirectory()) { - MS_LOG(DEBUG) << "Directory " << toString() << " already exists."; - return Status::OK(); - } else { - MS_LOG(DEBUG) << "Creating directory " << toString() << "."; - std::string parent = ParentPath(); - if (!parent.empty()) { - if (Path(parent).CreateDirectories()) { - return CreateDirectory(); - } - } else { - return CreateDirectory(); - } - } - return Status::OK(); -} - -Status Path::Remove() { - if (Exists()) { - if (IsDirectory()) { - errno_t err = rmdir(common::SafeCStr(path_)); - if (err == -1) { - std::ostringstream oss; - oss << "Unable to delete directory " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } else { - errno_t err = unlink(common::SafeCStr(path_)); - if (err == -1) { - std::ostringstream oss; - oss << "Unable to delete file " << path_ << ". Errno = " << errno; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - } - } - return Status::OK(); -} - -Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } - -Status Path::OpenFile(int *file_descriptor, bool create) { - int fd; - if (file_descriptor == nullptr) { - RETURN_STATUS_UNEXPECTED("null pointer"); - } - if (IsDirectory()) { - std::ostringstream oss; - oss << "Unable to create file " << path_ << " which is a directory."; - RETURN_STATUS_UNEXPECTED(oss.str()); - } - // Convert to canonical form. - if (strlen(common::SafeCStr(path_)) > PATH_MAX) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - char canonical_path[PATH_MAX + 1] = {0x00}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { -#else - if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { -#endif - if (errno == ENOENT && create) { - // File doesn't exist and we are to create it. Let's break it down. 
- auto file_part = Basename(); - auto parent_part = ParentPath(); -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { -#else - if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { -#endif - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto cur_inx = strlen(canonical_path); - if ((cur_inx + file_part.length() + 1) > PATH_MAX) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - canonical_path[cur_inx++] = separator_; - if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != - EOK) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - } else { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - } - if (create) { - fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); - } else { - fd = open(canonical_path, O_RDWR); - } - if (fd == -1) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - *file_descriptor = fd; - return Status::OK(); -} - -Status Path::CloseFile(int fd) const { - if (close(fd) < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - return Status::OK(); -} - -Status Path::TruncateFile(int fd) const { - int rc; - rc = ftruncate(fd, 0); - if (rc == 0) { - return Status::OK(); - } else { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } -} - -std::string Path::Basename() { - std::size_t found = path_.find_last_of(separator_); - if (found != std::string::npos) { - return path_.substr(found + 1); - } else { - return path_; - } -} - -std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { - auto it = new (std::nothrow) DirIterator(f); - - if (it == nullptr) { - return nullptr; - } - - if (it->dp_) { - return std::shared_ptr(it); - } else { - delete it; - return nullptr; - } -} - -Path::DirIterator::~DirIterator() { - if (dp_) { - (void)closedir(dp_); - } - dp_ = nullptr; - dir_ = nullptr; - entry_ = nullptr; -} - -Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { - MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; - dp_ = opendir(f->toString().c_str()); -} - -bool Path::DirIterator::hasNext() { - do { - entry_ = readdir(dp_); - if (entry_) { - if (strcmp(entry_->d_name, ".") == 0 || strcmp(entry_->d_name, "..") == 0) { - continue; - } - } - break; - } while (true); - return (entry_ != nullptr); -} - -Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } - -std::ostream &operator<<(std::ostream &os, const Path &s) { - os << s.path_; - return os; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/path.h b/mindspore/ccsrc/dataset/util/path.h deleted file mode 100644 index fbf65b8c23..0000000000 --- a/mindspore/ccsrc/dataset/util/path.h +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_UTIL_PATH_H_ -#define DATASET_UTIL_PATH_H_ - -#include -#include -#include - -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class Path { - public: - class DirIterator { - public: - static std::shared_ptr OpenDirectory(Path *f); - - ~DirIterator(); - - bool hasNext(); - - Path next(); - - private: - explicit DirIterator(Path *f); - - Path *dir_; - DIR *dp_; - struct dirent *entry_; - }; - - explicit Path(const std::string &); - - explicit Path(const char *); - - ~Path() = default; - - Path(const Path &); - - Path &operator=(const Path &); - - Path(Path &&) noexcept; - - Path &operator=(Path &&) noexcept; - - std::string toString() const { return path_; } - - Path operator+(const Path &); - - Path operator+(const std::string &); - - Path operator+(const char *); - - Path &operator+=(const Path &rhs); - - Path &operator+=(const std::string &); - - Path &operator+=(const char *); - - Path operator/(const Path &); - - Path operator/(const std::string &); - - Path operator/(const char *); - - bool Exists(); - - bool IsDirectory(); - - Status CreateDirectory(); - - Status CreateDirectories(); - - std::string Extension() const; - - std::string ParentPath(); - - Status Remove(); - - Status CreateFile(int *fd); - - Status OpenFile(int *fd, bool create = false); - - Status CloseFile(int fd) const; - - Status TruncateFile(int fd) const; - - std::string Basename(); - - friend std::ostream &operator<<(std::ostream &os, const Path &s); - - private: - static char separator_; - std::string path_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_PATH_H_ diff --git a/mindspore/ccsrc/dataset/util/queue.h b/mindspore/ccsrc/dataset/util/queue.h deleted file mode 100644 index 52309962d5..0000000000 --- a/mindspore/ccsrc/dataset/util/queue.h +++ /dev/null @@ -1,256 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_QUEUE_H_ -#define DATASET_UTIL_QUEUE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "utils/log_adapter.h" -#include "dataset/util/allocator.h" -#include "dataset/util/services.h" -#include "dataset/util/cond_var.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -template -struct is_shared_ptr : public std::false_type {}; - -template -struct is_shared_ptr> : public std::true_type {}; - -template -struct is_unique_ptr : public std::false_type {}; - -template -struct is_unique_ptr> : public std::true_type {}; - -// A simple thread safe queue using a fixed size array -template -class Queue { - public: - using value_type = T; - using pointer = T *; - using const_pointer = const T *; - using reference = T &; - using const_reference = const T &; - - void Init() { - if (sz_ > 0) { - // We allocate a block of memory and then call the default constructor for each slot. 
Maybe simpler to call
-    // new[] but we want to control where the memory is allocated from.
-      arr_ = alloc_.allocate(sz_);
-      for (uint64_t i = 0; i < sz_; i++) {
-        std::allocator_traits<Allocator<T>>::construct(alloc_, &(arr_[i]));
-      }
-    }
-  }
-
-  explicit Queue(int sz)
-      : sz_(sz),
-        arr_(nullptr),
-        head_(0),
-        tail_(0),
-        my_name_(Services::GetUniqueID()),
-        alloc_(Services::GetInstance().GetServiceMemPool()) {
-    Init();
-    MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << ".";
-  }
-
-  virtual ~Queue() {
-    ResetQue();
-    if (arr_) {
-      // Simply free the pointer. Since there is nothing in the queue, we don't want to invoke the destructor
-      // of T in each slot.
-      alloc_.deallocate(arr_);
-      arr_ = nullptr;
-    }
-  }
-
-  int size() const {
-    int v = tail_ - head_;
-    return (v >= 0) ? v : 0;
-  }
-
-  int capacity() const { return sz_; }
-
-  bool empty() const { return head_ == tail_; }
-
-  void Reset() { ResetQue(); }
-
-  // Producer
-  Status Add(const_reference ele) noexcept {
-    std::unique_lock<std::mutex> _lock(mux_);
-    // Block when full
-    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
-    if (rc.IsOk()) {
-      uint32_t k = tail_++ % sz_;
-      arr_[k] = ele;
-      empty_cv_.NotifyAll();
-      _lock.unlock();
-    } else {
-      empty_cv_.Interrupt();
-    }
-    return rc;
-  }
-
-  Status Add(T &&ele) noexcept {
-    std::unique_lock<std::mutex> _lock(mux_);
-    // Block when full
-    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
-    if (rc.IsOk()) {
-      uint32_t k = tail_++ % sz_;
-      arr_[k] = std::forward<T>(ele);
-      empty_cv_.NotifyAll();
-      _lock.unlock();
-    } else {
-      empty_cv_.Interrupt();
-    }
-    return rc;
-  }
-
-  template <typename... Ts>
-  Status EmplaceBack(Ts &&... args) noexcept {
-    std::unique_lock<std::mutex> _lock(mux_);
-    // Block when full
-    Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); });
-    if (rc.IsOk()) {
-      uint32_t k = tail_++ % sz_;
-      new (&(arr_[k])) T(std::forward<Ts>(args)...);
-      empty_cv_.NotifyAll();
-      _lock.unlock();
-    } else {
-      empty_cv_.Interrupt();
-    }
-    return rc;
-  }
-
-  // Consumer
-  Status PopFront(pointer p) {
-    std::unique_lock<std::mutex> _lock(mux_);
-    // Block when empty
-    Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); });
-    if (rc.IsOk()) {
-      uint32_t k = head_++ % sz_;
-      *p = std::move(arr_[k]);
-      if (std::is_destructible<T>::value) {
-        // std::move above only changes arr_[k] from rvalue to lvalue.
-        // The real implementation of the move constructor depends on T.
-        // It may be compiler generated or user defined. But in either case
-        // the result of arr_[k] is still a valid object of type T, and
-        // we will not keep any extra copy in the queue.
-        arr_[k].~T();
-        // For gcc 9, an extra fix is needed here to clear the memory content
-        // of arr_[k] because this slot can be reused by another Add which can
-        // do another std::move. We have seen SEGV here in this case.
-        std::allocator_traits<Allocator<T>>::construct(alloc_, &(arr_[k]));
-      }
-      full_cv_.NotifyAll();
-      _lock.unlock();
-    } else {
-      full_cv_.Interrupt();
-    }
-    return rc;
-  }
-
-  void ResetQue() noexcept {
-    std::unique_lock<std::mutex> _lock(mux_);
-    // If there are elements in the queue, invoke their destructors one by one.
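    // (Only slots head_ .. tail_-1, taken modulo sz_, can hold live objects,
    // because both counters grow monotonically between resets.)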
- if (!empty() && std::is_destructible::value) { - for (uint64_t i = head_; i < tail_; i++) { - uint32_t k = i % sz_; - arr_[k].~T(); - } - } - for (uint64_t i = 0; i < sz_; i++) { - std::allocator_traits>::construct(alloc_, &(arr_[i])); - } - empty_cv_.ResetIntrpState(); - full_cv_.ResetIntrpState(); - head_ = 0; - tail_ = 0; - } - - Status Register(TaskGroup *vg) { - Status rc1 = empty_cv_.Register(vg->GetIntrpService()); - Status rc2 = full_cv_.Register(vg->GetIntrpService()); - if (rc1.IsOk()) { - return rc2; - } else { - return rc1; - } - } - - private: - uint64_t sz_; - pointer arr_; - uint64_t head_; - uint64_t tail_; - std::string my_name_; - std::mutex mux_; - CondVar empty_cv_; - CondVar full_cv_; - Allocator alloc_; -}; - -// A container of queues with [] operator accessors. Basically this is a wrapper over of a vector of queues -// to help abstract/simplify code that is maintaining multiple queues. -template -class QueueList { - public: - QueueList() {} - - void Init(int num_queues, int capacity) { - queue_list_.reserve(num_queues); - for (int i = 0; i < num_queues; i++) { - queue_list_.emplace_back(std::make_unique>(capacity)); - } - } - - Status Register(TaskGroup *vg) { - if (vg == nullptr) { - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Null task group during QueueList registration."); - } - for (int i = 0; i < queue_list_.size(); ++i) { - RETURN_IF_NOT_OK(queue_list_[i]->Register(vg)); - } - return Status::OK(); - } - - int size() const { return queue_list_.size(); } - - std::unique_ptr> &operator[](const int index) { return queue_list_[index]; } - - const std::unique_ptr> &operator[](const int index) const { return queue_list_[index]; } - - ~QueueList() = default; - - private: - // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector - // requirement that objects must have copy semantics. To resolve this, we use a vector of unique - // pointers. This allows us to provide dynamic creation of queues in a container. - std::vector>> queue_list_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_QUEUE_H_ diff --git a/mindspore/ccsrc/dataset/util/random.h b/mindspore/ccsrc/dataset/util/random.h deleted file mode 100644 index 957a4214a8..0000000000 --- a/mindspore/ccsrc/dataset/util/random.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef DATASET_UTIL_RANDOM_H_ -#define DATASET_UTIL_RANDOM_H_ - -#if defined(_WIN32) || defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include - -#include "dataset/core/config_manager.h" -#include "dataset/core/global_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -inline std::mt19937 GetRandomDevice() { -#if defined(_WIN32) || defined(_WIN64) - unsigned int number; - rand_s(&number); - std::mt19937 random_device{static_cast(number)}; -#else - int i = 0; - while (i < 5) { - try { - std::mt19937 random_device{std::random_device("/dev/urandom")()}; - return random_device; - } catch (const std::exception &e) { - MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - i++; - } - } - std::mt19937 random_device{std::random_device("/dev/urandom")()}; -#endif - return random_device; -} - -inline uint32_t GetNewSeed() { - std::mt19937 random_device = GetRandomDevice(); - std::uniform_int_distribution distribution(0, std::numeric_limits::max()); - return distribution(random_device); -} - -inline uint32_t GetSeed() { - uint32_t seed = GlobalContext::config_manager()->seed(); - if (seed == std::mt19937::default_seed) { - seed = GetNewSeed(); - } - return seed; -} - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/dataset/util/semaphore.cc b/mindspore/ccsrc/dataset/util/semaphore.cc deleted file mode 100644 index 36ddf5511d..0000000000 --- a/mindspore/ccsrc/dataset/util/semaphore.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/semaphore.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -Status Semaphore::P() { - std::unique_lock lck(mutex_); - RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; })); - --value_; - return Status::OK(); -} -void Semaphore::V() { - std::unique_lock lck(mutex_); - ++value_; - wait_cond_.NotifyOne(); -} -int Semaphore::Peek() { - std::unique_lock lck(mutex_); - return value_; -} -Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } -Status Semaphore::Deregister() { return (wait_cond_.Deregister()); } -void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); } - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/semaphore.h b/mindspore/ccsrc/dataset/util/semaphore.h deleted file mode 100644 index 07b9e83e7f..0000000000 --- a/mindspore/ccsrc/dataset/util/semaphore.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SEMAPHORE_H_ -#define DATASET_UTIL_SEMAPHORE_H_ - -#include "dataset/util/cond_var.h" - -namespace mindspore { -namespace dataset { -class TaskGroup; - -/// \brief A counting semaphore. There are two external functions P and V. P decrements the internal count and will be -/// blocked if the count is 0 (zero). V increments the internal count and wake up one of the waiters. -class Semaphore { - public: - /// \brief Constructor - /// \param init Initial value of the internal counter. - explicit Semaphore(int init) : value_(init) {} - - virtual ~Semaphore() {} - /// \brief Decrement the internal counter. Will be blocked if the value is 0. - /// \return Error code. Can get interrupt. - Status P(); - /// \brief Increment the internal counter. Wakeup on of the watiers if any. - void V(); - /// \brief Peek the internal value - /// \return The internal value - int Peek(); - Status Register(TaskGroup *vg); - Status Deregister(); - void ResetIntrpState(); - - private: - int value_; - - std::mutex mutex_; - CondVar wait_cond_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SEMAPHORE_H_ diff --git a/mindspore/ccsrc/dataset/util/service.cc b/mindspore/ccsrc/dataset/util/service.cc deleted file mode 100644 index c89f7287f6..0000000000 --- a/mindspore/ccsrc/dataset/util/service.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/service.h" -#include - -namespace mindspore { -namespace dataset { -Status Service::ServiceStart() { - do { - UniqueLock lck(&state_lock_); - // No-op if it is already up or some other thread is - // in the process of bring it up. - if (state_ == STATE::kRunning || state_ == STATE::kStartInProg) { - return Status::OK(); - } - // If a stop is in progress, we line up after it - // is done. - if (state_ == STATE::kStopInProg) { - std::this_thread::yield(); - } else { - state_ = STATE::kStartInProg; - // At this point, we will let go of the lock. This allow others to proceed. - lck.Unlock(); - RETURN_IF_NOT_OK(DoServiceStart()); - // Lock again to change state. - lck.Lock(); - state_ = STATE::kRunning; - return Status::OK(); - } - } while (true); -} - -Status Service::ServiceStop() noexcept { - do { - UniqueLock lck(&state_lock_); - // No-op if it is already stopped or some other thread is - // in the process of shutting it down - if (state_ == STATE::kStopped || state_ == STATE::kStopInProg) { - return Status::OK(); - } - // If a start is in progress, we line up after it - // is done. 
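    // Yielding briefly lets the in-flight ServiceStart() finish and publish
    // kRunning; the next pass through the loop then performs the actual stop.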
- if (state_ == STATE::kStartInProg) { - std::this_thread::yield(); - } else { - state_ = STATE::kStopInProg; - // At this point, we will let go of the lock. This allows others to proceed. - lck.Unlock(); - RETURN_IF_NOT_OK(DoServiceStop()); - // Lock again to change state. - lck.Lock(); - state_ = STATE::kStopped; - return Status::OK(); - } - } while (true); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/service.h b/mindspore/ccsrc/dataset/util/service.h deleted file mode 100644 index 1113fc1d14..0000000000 --- a/mindspore/ccsrc/dataset/util/service.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SERVICE_H_ -#define DATASET_UTIL_SERVICE_H_ - -#include -#include "dataset/util/lock.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class Service { - public: - enum class STATE : int { kStartInProg = 1, kRunning, kStopInProg, kStopped }; - - Service() : state_(STATE::kStopped) {} - - Service(const Service &) = delete; - - Service &operator=(const Service &) = delete; - - virtual ~Service() {} - - STATE ServiceState() const { return state_; } - - virtual Status DoServiceStart() = 0; - - virtual Status DoServiceStop() = 0; - - Status ServiceStart(); - - Status ServiceStop() noexcept; - - protected: - STATE state_; - RWLock state_lock_; -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc deleted file mode 100644 index 755d217311..0000000000 --- a/mindspore/ccsrc/dataset/util/services.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "dataset/util/services.h" - -#include <limits.h> -#if !defined(_WIN32) && !defined(_WIN64) -#include <sys/syscall.h> -#else -#include <stdlib.h> -#endif -#include <unistd.h> -#include "dataset/engine/cache/cache_server.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/random.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -std::unique_ptr<Services> Services::instance_ = nullptr; -std::once_flag Services::init_instance_flag_; - -#if !defined(_WIN32) && !defined(_WIN64) -std::string Services::GetUserName() { - char user[LOGIN_NAME_MAX]; - (void)getlogin_r(user, sizeof(user)); - return std::string(user); -} - -std::string Services::GetHostName() { - char host[LOGIN_NAME_MAX]; - (void)gethostname(host, sizeof(host)); - return std::string(host); -} - -int Services::GetLWP() { return syscall(SYS_gettid); } -#endif - -std::string Services::GetUniqueID() { - const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; - std::mt19937 gen = GetRandomDevice(); - std::uniform_int_distribution dist(0, kStr.size() - 1); - char buffer[UNIQUEID_LEN]; - for (int i = 0; i < UNIQUEID_LEN; i++) { - buffer[i] = kStr[dist(gen)]; - } - return std::string(buffer, UNIQUEID_LEN); -} - -TaskManager &Services::getTaskMgrInstance() { - Services &sm = GetInstance(); - return *(static_cast<TaskManager *>(sm.sa_[kSlotTaskMgr_])); -} - -CacheServer &Services::getCacheServer() { - Services &sm = GetInstance(); - return *(static_cast<CacheServer *>(sm.sa_[kSlotCacheMgr_])); -} - -Status Services::CreateAllInstances() { - // In order: TaskMgr, CacheServer - Status rc; - sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); - RETURN_IF_NOT_OK(rc); - rc = sa_[kSlotTaskMgr_]->ServiceStart(); - RETURN_IF_NOT_OK(rc); - // TODO(jesse) : Get the parameters from config file. Right now spill to /tmp and spawn 3 workers - sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); - RETURN_IF_NOT_OK(rc); - rc = sa_[kSlotCacheMgr_]->ServiceStart(); - return rc; -} - -Services::Services() : pool_(nullptr), sa_{nullptr} { - Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M - if (rc.IsError()) { - std::terminate(); - } -} - -Services::~Services() noexcept { - try { - // In reverse order - CacheServer *cs = static_cast<CacheServer *>(sa_[kSlotCacheMgr_]); - if (cs != nullptr) { - (void)cs->ServiceStop(); - cs->~CacheServer(); - pool_->Deallocate(cs); - } - TaskManager *tm = static_cast<TaskManager *>(sa_[kSlotTaskMgr_]); - if (tm != nullptr) { - (void)tm->ServiceStop(); - tm->~TaskManager(); - pool_->Deallocate(tm); - } - } catch (const std::exception &e) { - // Do nothing. - } -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/services.h b/mindspore/ccsrc/dataset/util/services.h deleted file mode 100644 index e82b3e47f1..0000000000 --- a/mindspore/ccsrc/dataset/util/services.h +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
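// [Editor's note] GetUniqueID() above draws UNIQUEID_LEN characters uniformly
// from [a-z0-9]. A standalone equivalent is sketched below; it seeds from
// std::random_device, whereas the original seeds via its own GetRandomDevice()
// helper. MakeUniqueId is a hypothetical name, not part of the patch.
#include <random>
#include <string>

std::string MakeUniqueId(size_t len = 36) {
  static const std::string kChars = "abcdefghijklmnopqrstuvwxyz0123456789";
  std::mt19937 gen{std::random_device{}()};
  std::uniform_int_distribution<size_t> dist(0, kChars.size() - 1);
  std::string id(len, '\0');
  for (auto &c : id) c = kChars[dist(gen)];  // one uniform draw per character
  return id;
}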
- */ -#ifndef DATASET_UTIL_SERVICES_H_ -#define DATASET_UTIL_SERVICES_H_ - -#include -#include -#include -#include "dataset/util/memory_pool.h" -#include "dataset/util/allocator.h" -#include "dataset/util/service.h" - -#define UNIQUEID_LEN 36 -namespace mindspore { -namespace dataset { -class TaskManager; -class CacheServer; -class Services { - public: - static Status CreateInstance() { - std::call_once(init_instance_flag_, [&]() -> Status { - instance_.reset(new Services()); - return (instance_->CreateAllInstances()); - }); - - if (instance_ == nullptr) { - instance_.reset(new Services()); - return (instance_->CreateAllInstances()); - } - - return Status::OK(); - } - - static Services &GetInstance() { - if (instance_ == nullptr) { - if (!CreateInstance()) { - std::terminate(); - } - } - return *instance_; - } - - Services(const Services &) = delete; - - Services &operator=(const Services &) = delete; - - ~Services() noexcept; - - static TaskManager &getTaskMgrInstance(); - - static CacheServer &getCacheServer(); - - std::shared_ptr GetServiceMemPool() { return pool_; } - -#if !defined(_WIN32) && !defined(_WIN64) - static std::string GetUserName(); - - static std::string GetHostName(); - - static int GetLWP(); -#endif - - static std::string GetUniqueID(); - - template - static Allocator GetAllocator() { - return Allocator(Services::GetInstance().GetServiceMemPool()); - } - - private: - static std::once_flag init_instance_flag_; - static std::unique_ptr instance_; - // A small pool used for small objects that last until the - // Services Manager shuts down. Used by all sub-services. - std::shared_ptr pool_; - // We use pointers here instead of unique_ptr because we - // want to have ultimate control on the order of - // construction and destruction. - static constexpr int kSlotTaskMgr_ = 0; - static constexpr int kSlotCacheMgr_ = 1; - static constexpr int kNumServices_ = 2; - Service *sa_[kNumServices_]; - - Services(); - - Status CreateAllInstances(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_SERVICES_H_ diff --git a/mindspore/ccsrc/dataset/util/sig_handler.cc b/mindspore/ccsrc/dataset/util/sig_handler.cc deleted file mode 100644 index 644a633066..0000000000 --- a/mindspore/ccsrc/dataset/util/sig_handler.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/sig_handler.h" -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#endif -#include -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// Register the custom signal handlers -#if !defined(_WIN32) && !defined(_WIN64) -void RegisterHandlers() { - struct sigaction new_int_action; - - // For the interrupt handler, we do not use SA_RESETHAND so this handler remains in play - // permanently, do not use the OS default handler for it. 
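// [Editor's note] Services::CreateInstance() above builds the singleton under
// std::call_once; std::call_once discards the lambda's Status return value,
// which is presumably why the code re-checks instance_ afterwards. A minimal
// sketch of the same once-only construction pattern (Registry is a
// hypothetical stand-in type):
#include <memory>
#include <mutex>

class Registry {
 public:
  static Registry &GetInstance() {
    static std::once_flag flag;
    std::call_once(flag, []() { instance_.reset(new Registry()); });  // runs at most once
    return *instance_;
  }

 private:
  Registry() = default;
  static std::unique_ptr<Registry> instance_;
};
std::unique_ptr<Registry> Registry::instance_;  // defined once per program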
- new_int_action.sa_sigaction = &IntHandler; - (void)sigemptyset(&new_int_action.sa_mask); - new_int_action.sa_flags = SA_RESTART | SA_SIGINFO; - (void)sigaction(SIGINT, &new_int_action, nullptr); -} - -extern void IntHandler(int sig_num, // The signal that was raised - siginfo_t *sig_info, // The siginfo structure. - void *context) { // context info - // Wake up the watchdog which is designed as async-signal-safe. - TaskManager::WakeUpWatchDog(); -} -#endif -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.cc b/mindspore/ccsrc/dataset/util/slice.cc deleted file mode 100644 index f1798b4f44..0000000000 --- a/mindspore/ccsrc/dataset/util/slice.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * http://www.apache.org/licenses/LICENSE-2.0 - - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. -*/ -#include "dataset/util/slice.h" - -namespace mindspore { -namespace dataset { -WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { - mutable_data_ = static_cast(src.mutable_data_) + offset; -} -WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) - : WritableSlice(src, offset, src.GetSize() - offset) {} -Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { - RETURN_UNEXPECTED_IF_NULL(dest); - RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); - if (dest->GetSize() <= 0) { - RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); - } - auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); - if (err) { - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/slice.h b/mindspore/ccsrc/dataset/util/slice.h deleted file mode 100644 index b44f4d6a39..0000000000 --- a/mindspore/ccsrc/dataset/util/slice.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SLICE_H_ -#define DATASET_UTIL_SLICE_H_ - -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/status.h" -namespace mindspore { -namespace dataset { -/// \brief A ReadableSlice wraps a const pointer in memory and its size. 
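// [Editor's note] WritableSlice::Copy above delegates its bounds check to
// memcpy_s from securec. The same contract with plain <cstring>, returning a
// bool instead of a dataset Status, is sketched below (CopySlice is a
// hypothetical helper, not part of the patch):
#include <cstring>

bool CopySlice(void *dst, size_t dst_cap, const void *src, size_t src_len) {
  if (dst == nullptr || src == nullptr) return false;
  if (dst_cap == 0 || src_len > dst_cap) return false;  // destination too small
  std::memcpy(dst, src, src_len);
  return true;
}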
-/// \see WritableSlice for a non-const version -/// -class ReadableSlice { - public: - ReadableSlice() : ptr_(nullptr), sz_(0) {} - ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} - - /// \brief Destructor - ~ReadableSlice() = default; - - ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { - ptr_ = static_cast(src.GetPointer()) + offset; - sz_ = len; - } - ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} - ReadableSlice(const ReadableSlice &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - } - ReadableSlice &operator=(const ReadableSlice &lhs) { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - } - return *this; - } - ReadableSlice(ReadableSlice &&lhs) noexcept { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - lhs.ptr_ = nullptr; - lhs.sz_ = 0; - } - } - ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { - if (this != &lhs) { - ptr_ = lhs.ptr_; - sz_ = lhs.sz_; - lhs.ptr_ = nullptr; - lhs.sz_ = 0; - } - return *this; - } - /// \brief Getter function - /// \return Const version of the pointer - const void *GetPointer() const { return ptr_; } - /// \brief Getter function - /// \return Size of the slice - size_t GetSize() const { return sz_; } - bool empty() const { return ptr_ == nullptr; } - - private: - const void *ptr_; - size_t sz_; -}; -/// \brief A WritableSlice inherits from ReadableSlice to allow -/// one to write to the address pointed to by the pointer. -/// -class WritableSlice : public ReadableSlice { - public: - friend class StorageContainer; - /// \brief Default constructor - WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} - /// \brief This form of a constructor takes a pointer and its size. - WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} - WritableSlice(const WritableSlice &src, off64_t offset, size_t len); - WritableSlice(const WritableSlice &src, off64_t offset); - WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } - /// \brief Destructor - ~WritableSlice() = default; - WritableSlice &operator=(const WritableSlice &lhs) { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - ReadableSlice::operator=(lhs); - } - return *this; - } - WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - lhs.mutable_data_ = nullptr; - } - } - WritableSlice &operator=(WritableSlice &&lhs) noexcept { - if (this != &lhs) { - mutable_data_ = lhs.mutable_data_; - lhs.mutable_data_ = nullptr; - ReadableSlice::operator=(std::move(lhs)); - } - return *this; - } - /// \brief Copy the content from one slice onto another. - static Status Copy(WritableSlice *dest, const ReadableSlice &src); - - private: - void *mutable_data_; - void *GetMutablePointer() { return mutable_data_; } -}; -} // namespace dataset -} // namespace mindspore -#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/dataset/util/status.cc b/mindspore/ccsrc/dataset/util/status.cc deleted file mode 100644 index 27e9dfbc83..0000000000 --- a/mindspore/ccsrc/dataset/util/status.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
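// [Editor's note] The (src, offset, len) constructors above create cheap
// sub-views over the same memory without copying. A usage sketch, assuming the
// slice.h header above is included and using hypothetical buffer names:
#include <cstdint>

void SliceViews() {
  uint8_t block[1024] = {};
  ReadableSlice whole(block, sizeof(block));
  ReadableSlice header(whole, 0, 16);  // view of the first 16 bytes
  ReadableSlice payload(whole, 16);    // view of everything after the header
  // header.GetPointer() and payload.GetPointer() alias `block`; no allocation occurs.
  (void)header;
  (void)payload;
}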
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/status.h" -#include -#include "common/utils.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -std::string CodeAsString(const StatusCode c) { - const char *s = nullptr; - if (c == StatusCode::kOK) { - // Optimize the most frequent case - return std::string("OK"); - } else { - switch (c) { - case StatusCode::kOutOfMemory: - s = "Out of memory"; - break; - case StatusCode::kInterrupted: - s = "Interrupted system call"; - break; - case StatusCode::kShapeMisMatch: - s = "Shape is incorrect."; - break; - case StatusCode::kNoSpace: - s = "No space left on device"; - break; - case StatusCode::kPyFuncException: - s = "Exception thrown from PyFunc"; - break; - case StatusCode::kDuplicateKey: - s = "Duplicate key"; - break; - case StatusCode::kProfilingError: - s = "Error encountered while profiling"; - break; - case StatusCode::kUnexpectedError: - default: - s = "Unexpected error"; - break; - } - } - return std::string(s); -} - -Status::Status(StatusCode c) noexcept : code_(c), err_msg_(std::move(CodeAsString(c))) {} - -Status::Status() noexcept : code_(StatusCode::kOK), err_msg_("") {} - -Status::~Status() noexcept {} - -Status::Status(const Status &s) : code_(s.code_), err_msg_(s.err_msg_) {} - -Status &Status::operator=(const Status &s) { - if (this == &s) { - return *this; - } - code_ = s.code_; - err_msg_ = s.err_msg_; - return *this; -} - -Status::Status(Status &&s) noexcept { - code_ = s.code_; - s.code_ = StatusCode::kOK; - err_msg_ = std::move(s.err_msg_); -} - -Status &Status::operator=(Status &&s) noexcept { - if (this == &s) { - return *this; - } - code_ = s.code_; - s.code_ = StatusCode::kOK; - err_msg_ = std::move(s.err_msg_); - return *this; -} - -Status::Status(const StatusCode code, const std::string &msg) : code_(code), err_msg_(msg) {} - -Status::Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { - code_ = code; - std::ostringstream ss; - ss << "Thread ID " << this_thread::get_id() << " " << CodeAsString(code) << ". "; - if (!extra.empty()) { - ss << extra; - } - ss << "\n"; - ss << "Line of code : " << line_of_code << "\n"; - if (file_name != nullptr) { - ss << "File : " << file_name << "\n"; - } - err_msg_ = ss.str(); - MS_LOG(INFO) << err_msg_; -} - -std::ostream &operator<<(std::ostream &os, const Status &s) { - os << s.ToString(); - return os; -} - -std::string Status::ToString() const { return err_msg_; } - -StatusCode Status::get_code() const { return code_; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_container.cc b/mindspore/ccsrc/dataset/util/storage_container.cc deleted file mode 100644 index 3a4c13e2d9..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_container.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
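// [Editor's note] The Status class above threads an error code plus a
// formatted message up the call stack; macros such as RETURN_IF_NOT_OK (defined
// in status.h, which is not shown in this patch) early-return on failure. A
// minimal sketch of that pattern with hypothetical names:
#include <string>

struct ToyStatus {
  int code = 0;  // 0 == OK
  std::string msg;
  bool IsOk() const { return code == 0; }
};

#define TOY_RETURN_IF_NOT_OK(expr)  \
  do {                              \
    ToyStatus _s = (expr);          \
    if (!_s.IsOk()) return _s;      \
  } while (false)

ToyStatus Step() { return {}; }

ToyStatus Pipeline() {
  TOY_RETURN_IF_NOT_OK(Step());  // propagate the first failure upward
  return {};
}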
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/storage_container.h" - -#include -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -Status StorageContainer::Create() { - RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); - RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); - is_open_ = true; - MS_LOG(INFO) << "Container " << cont_ << " created"; - return Status::OK(); -} - -Status StorageContainer::Open() noexcept { - std::lock_guard<std::mutex> lck(mutex_); - // Check again - if (!is_open_) { - RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); - is_open_ = true; - } - return Status::OK(); -} - -Status StorageContainer::Close() noexcept { - if (is_open_) { - std::lock_guard<std::mutex> lck(mutex_); - // Check again - if (is_open_) { - RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); - is_open_ = false; - fd_ = -1; - } - } - return Status::OK(); -} - -Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { - MS_ASSERT(is_open_); - RETURN_UNEXPECTED_IF_NULL(dest); - auto sz = dest->GetSize(); -#if defined(_WIN32) || defined(_WIN64) - // There doesn't seem to be any pread64 on mingw. - // So we will do a seek and then a read under - // the protection of a mutex. - std::lock_guard<std::mutex> lck(mutex_); - auto seek_err = lseek(fd_, offset, SEEK_SET); - if (seek_err < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto r_sz = read(fd_, dest->GetMutablePointer(), sz); -#else - auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); -#endif - if (r_sz != sz) { - errno_t err = (r_sz == 0) ? EOF : errno; - RETURN_STATUS_UNEXPECTED(strerror(err)); - } - return Status::OK(); -} - -Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { - MS_ASSERT(is_open_); - auto sz = dest.GetSize(); -#if defined(_WIN32) || defined(_WIN64) - // There doesn't seem to be any pwrite64 on mingw. - // So we will do a seek and then a write under - // the protection of a mutex. - std::lock_guard<std::mutex> lck(mutex_); - auto seek_err = lseek(fd_, offset, SEEK_SET); - if (seek_err < 0) { - RETURN_STATUS_UNEXPECTED(strerror(errno)); - } - auto r_sz = write(fd_, dest.GetPointer(), sz); -#else - auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); -#endif - if (r_sz != sz) { - errno_t err = (r_sz == 0) ? EOF : errno; - RETURN_STATUS_UNEXPECTED(strerror(err)); - } - return Status::OK(); -} - -Status StorageContainer::Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept { - size_t sz = 0; - for (auto &v : buf) { - sz += v.GetSize(); - } - if (sz == 0) { - RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); - } - if (sz > bs_->GetMaxSize()) { - RETURN_STATUS_UNEXPECTED("Request size too big"); - } - BSpaceDescriptor bspd{0}; - addr_t addr = 0; - RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); - *offset = static_cast<off64_t>(addr); - // We will do piecewise copy of the data to disk.
- for (auto &v : buf) { - RETURN_IF_NOT_OK(Write(v, addr)); - addr += v.GetSize(); - } - return Status::OK(); -} - -Status StorageContainer::Truncate() const noexcept { - if (is_open_) { - RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); - MS_LOG(INFO) << "Container " << cont_ << " truncated"; - } - return Status::OK(); -} - -StorageContainer::~StorageContainer() noexcept { - (void)Truncate(); - (void)Close(); -} - -std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { - os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); - return os; -} - -Status StorageContainer::CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path) { - Status rc; - auto sc = new (std::nothrow) StorageContainer(path); - if (sc == nullptr) { - return Status(StatusCode::kOutOfMemory); - } - rc = sc->Create(); - if (rc.IsOk()) { - (*out_sc).reset(sc); - } else { - delete sc; - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_container.h b/mindspore/ccsrc/dataset/util/storage_container.h deleted file mode 100644 index 07e41bd66a..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_container.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ -#define DATASET_UTIL_STORAGE_CONTAINER_H_ - -#include -#include -#include -#include -#include -#include -#include "dataset/util/system_pool.h" -#include "dataset/util/buddy.h" -#include "dataset/util/path.h" -#include "dataset/util/slice.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class StorageManager; - -class StorageContainer { - public: - friend class StorageManager; - - ~StorageContainer() noexcept; - - StorageContainer(const StorageContainer &) = delete; - - StorageContainer &operator=(const StorageContainer &) = delete; - - friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); - - Status Open() noexcept; - - Status Close() noexcept; - - Status Insert(const std::vector<ReadableSlice> &buf, off64_t *offset) noexcept; - - Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; - - Status Read(WritableSlice *dest, off64_t offset) const noexcept; - - Status Truncate() const noexcept; - - bool IsOpen() const { return is_open_; } - - static Status CreateStorageContainer(std::shared_ptr<StorageContainer> *out_sc, const std::string &path); - - private: - mutable std::mutex mutex_; - Path cont_; - int fd_; - bool is_open_; - std::unique_ptr<BuddySpace> bs_; - - // Use the default value of BuddySpace - // which can map up to 4G of space.
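// [Editor's note] Read()/Write() above use pread64/pwrite64 so concurrent
// callers never share a file offset, falling back to lseek+read under a mutex
// on mingw. A condensed sketch of the POSIX branch, using pread (pread64 is
// the large-file variant) and a hypothetical helper name:
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <string>

// Returns an empty string on success, else an error description.
std::string PositionalRead(int fd, void *buf, size_t len, off_t offset) {
  ssize_t n = pread(fd, buf, len, offset);  // does not touch the shared file offset
  if (n < 0) return std::strerror(errno);
  if (static_cast<size_t>(n) != len) return "short read";
  return "";
}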
- explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} - - Status Create(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/dataset/util/storage_manager.cc b/mindspore/ccsrc/dataset/util/storage_manager.cc deleted file mode 100644 index 1d958576ba..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_manager.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/storage_manager.h" - -#include -#include -#include -#include -#include "common/utils.h" -#include "dataset/util/path.h" -#include "dataset/util/services.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { - std::ostringstream oss; - oss << prefix << std::setfill('0') << std::setw(5) << file_id; - return oss.str(); -} - -std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { - std::string base_name = GetBaseName(prefix, file_id); - return (base_name + "." + suffix); -} - -Status StorageManager::AddOneContainer() { - const std::string kPrefix = "IMG"; - const std::string kSuffix = "LB"; - Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix); - std::shared_ptr<StorageContainer> sc; - RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString())); - containers_.push_back(sc); - file_id_++; - return Status::OK(); -} - -Status StorageManager::DoServiceStart() { - containers_.reserve(1000); - if (root_.IsDirectory()) { - RETURN_IF_NOT_OK(AddOneContainer()); - } else { - RETURN_STATUS_UNEXPECTED("Not a directory"); - } - return Status::OK(); -} - -Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) { - RETURN_UNEXPECTED_IF_NULL(key); - size_t sz = 0; - for (auto &v : buf) { - sz += v.GetSize(); - } - if (sz == 0) { - RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); - } - std::shared_ptr<StorageContainer> cont; - key_type out_key; - value_type out_value; - bool create_new_container = false; - do { - SharedLock lock_s(&rw_lock_); - size_t num_containers = containers_.size(); - if (create_new_container) { - // Upgrade to exclusive lock. - lock_s.Upgrade(); - create_new_container = false; - // Check again if someone has already added a - // new container after we got the exclusive lock - if (containers_.size() == num_containers) { - RETURN_IF_NOT_OK(AddOneContainer()); - } - // Refresh how many containers there are. - num_containers = containers_.size(); - // Downgrade back to shared lock - lock_s.Downgrade(); - } - if (num_containers == 0) { - RETURN_STATUS_UNEXPECTED("num_containers is zero"); - } - // Go to the last container to insert.
- cont = containers_.at(num_containers - 1); - off64_t offset; - Status rc = cont->Insert(buf, &offset); - if (rc.IsNoSpace()) { - create_new_container = true; - } else if (rc.IsOk()) { - out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz)); - RETURN_IF_NOT_OK(index_.insert(out_value, &out_key)); - *key = out_key; - break; - } else { - return rc; - } - } while (true); - return Status::OK(); -} - -Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const { - RETURN_UNEXPECTED_IF_NULL(dest); - auto r = index_.Search(key); - if (r.second) { - auto &it = r.first; - value_type v = *it; - int container_inx = v.first; - off_t offset = v.second.first; - size_t sz = v.second.second; - if (dest->GetSize() < sz) { - std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) + - " but length = " + std::to_string(dest->GetSize()); - RETURN_STATUS_UNEXPECTED(errMsg); - } - if (bytesRead != nullptr) { - *bytesRead = sz; - } - auto cont = containers_.at(container_inx); - RETURN_IF_NOT_OK(cont->Read(dest, offset)); - } else { - RETURN_STATUS_UNEXPECTED("Key not found"); - } - return Status::OK(); -} - -Status StorageManager::DoServiceStop() noexcept { - Status rc; - Status rc1; - for (auto const &p : containers_) { - // The destructor of StorageContainer is not called automatically until the use - // count drops to 0. But it is not always the case. We will do it ourselves. - rc = p.get()->Truncate(); - if (rc.IsError()) { - rc1 = rc; - } - } - containers_.clear(); - file_id_ = 0; - return rc1; -} - -StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {} - -StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); } - -std::ostream &operator<<(std::ostream &os, const StorageManager &s) { - os << "Dumping all containers ..." - << "\n"; - for (auto const &p : s.containers_) { - os << *(p.get()); - } - return os; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/storage_manager.h b/mindspore/ccsrc/dataset/util/storage_manager.h deleted file mode 100644 index 075ac713d2..0000000000 --- a/mindspore/ccsrc/dataset/util/storage_manager.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
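// [Editor's note] Write() above walks the container list under a shared lock
// and upgrades to exclusive only when the tail container reports no space;
// after upgrading it re-checks whether another writer already appended a new
// container. std::shared_mutex offers no Upgrade()/Downgrade(), so a
// standard-library sketch of the same idea drops the shared lock and re-takes
// it exclusively (names here are hypothetical):
#include <shared_mutex>
#include <vector>

std::shared_mutex g_mu;
std::vector<int> g_containers;  // stand-in for the container list

void EnsureTailCapacity(size_t seen_count) {
  std::unique_lock<std::shared_mutex> xlck(g_mu);  // exclusive section
  if (g_containers.size() == seen_count) {
    g_containers.push_back(0);  // nobody beat us to it; append a fresh container
  }
}  // writers then resume under a fresh shared lock and retry the insert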
- */ -#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ -#define DATASET_UTIL_STORAGE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/lock.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/path.h" -#include "dataset/util/service.h" -#include "dataset/util/slice.h" -#include "dataset/util/storage_container.h" - -using ListOfContainers = std::vector<std::shared_ptr<StorageContainer>>; -namespace mindspore { -namespace dataset { -class StorageManager : public Service { - public: - using storage_index = AutoIndexObj<std::pair<int, std::pair<off64_t, size_t>>>; - using key_type = storage_index::key_type; - using value_type = storage_index::value_type; - - explicit StorageManager(const Path &); - - ~StorageManager() override; - - StorageManager(const StorageManager &) = delete; - - StorageManager &operator=(const StorageManager &) = delete; - - Status Write(key_type *out_key, const std::vector<ReadableSlice> &buf); - - Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; - - Status DoServiceStart() override; - - Status DoServiceStop() noexcept override; - - friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); - - private: - Path root_; - ListOfContainers containers_; - int file_id_; - RWLock rw_lock_; - storage_index index_; - - std::string GetBaseName(const std::string &prefix, int32_t file_id); - - std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); - - Status AddOneContainer(); -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/system_pool.h b/mindspore/ccsrc/dataset/util/system_pool.h deleted file mode 100644 index 286e30a615..0000000000 --- a/mindspore/ccsrc/dataset/util/system_pool.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_SYSTEM_POOL_H_ -#define DATASET_UTIL_SYSTEM_POOL_H_ - -#include -#include -#include -#include -#include -#include "./securec.h" -#include "dataset/util/allocator.h" -#include "dataset/util/memory_pool.h" - -namespace mindspore { -namespace dataset { -// This class demonstrates how to implement a simple MemoryPool -// for minddata/dataset using malloc/free/realloc. We need to -// implement 4 virtual functions. Other MemoryPool -// implementations are, e.g., BuddyArena and CircularPool. All -// these MemoryPools can be used together with Allocator.h for -// C++ STL containers. -class SystemPool : public MemoryPool { - public: - ~SystemPool() override {} - - Status Allocate(size_t n, void **pp) override { return DeMalloc(n, pp, false); } - - void Deallocate(void *p) override { free(p); } - - Status Reallocate(void **p, size_t old_sz, size_t new_sz) override { - if (old_sz >= new_sz) { - // Do nothing if we shrink.
- return Status::OK(); - } else { - void *ptr = *p; - void *q = nullptr; - RETURN_IF_NOT_OK(DeMalloc(new_sz, &q, false)); - errno_t err = memcpy_s(q, new_sz, ptr, old_sz); - if (err) { - free(q); - RETURN_STATUS_UNEXPECTED(std::to_string(err)); - } - free(ptr); - *p = q; - return Status::OK(); - } - } - - uint64_t get_max_size() const override { return std::numeric_limits<uint64_t>::max(); } - - int PercentFree() const override { return 100; } - - template <typename T> - static Allocator<T> GetAllocator() { - return Allocator<T>(std::make_shared<SystemPool>()); - } -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_SYSTEM_POOL_H_ diff --git a/mindspore/ccsrc/dataset/util/task.cc b/mindspore/ccsrc/dataset/util/task.cc deleted file mode 100644 index 93db55d5f9..0000000000 --- a/mindspore/ccsrc/dataset/util/task.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/task.h" -#include "common/utils.h" -#include "dataset/util/task_manager.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -thread_local Task *gMyTask = nullptr; - -void Task::operator()() { -#if !defined(_WIN32) && !defined(_WIN64) - gMyTask = this; -#endif - id_ = this_thread::get_id(); - std::stringstream ss; - ss << id_; - MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; - try { - // Previously there was a timing hole where the thread was spawned but hit an error immediately before we could - // set the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can - // get the thread id. - TaskGroup *vg = MyTaskGroup(); - rc_ = vg->GetIntrpService()->Register(ss.str(), this); - if (rc_.IsOk()) { - // Now we can run the given task. - rc_ = fnc_obj_(); - } - // Some error codes are ignored, e.g. interrupt. For others we just shut down the group. - if (rc_.IsError() && !rc_.IsInterrupted()) { - ShutdownGroup(); - } - } catch (const std::bad_alloc &e) { - rc_ = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what()); - ShutdownGroup(); - } catch (const std::exception &e) { - rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); - ShutdownGroup(); - } -} - -void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine. - { - std::lock_guard<std::mutex> lk(mux_); - caught_severe_exception_ = true; - } - TaskGroup *vg = MyTaskGroup(); - // If multiple threads hit severe errors in the same group, keep the first one and - // discard the rest.
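// [Editor's note] SystemPool above adapts malloc/free/realloc to the MemoryPool
// interface so that Allocator<T> can back C++ STL containers. A minimal
// std-compatible allocator over the same raw calls (MallocAllocator is a
// hypothetical name for illustration):
#include <cstdlib>
#include <new>
#include <vector>

template <typename T>
struct MallocAllocator {
  using value_type = T;
  MallocAllocator() = default;
  template <typename U>
  MallocAllocator(const MallocAllocator<U> &) {}  // rebind-copy support
  T *allocate(std::size_t n) {
    void *p = std::malloc(n * sizeof(T));
    if (p == nullptr) throw std::bad_alloc();
    return static_cast<T *>(p);
  }
  void deallocate(T *p, std::size_t) { std::free(p); }
};

template <typename T, typename U>
bool operator==(const MallocAllocator<T> &, const MallocAllocator<U> &) { return true; }
template <typename T, typename U>
bool operator!=(const MallocAllocator<T> &, const MallocAllocator<U> &) { return false; }

// Usage: a vector whose storage comes from the custom allocator.
std::vector<int, MallocAllocator<int>> g_v{1, 2, 3};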
- if (vg->rc_.IsOk()) { - std::unique_lock<std::mutex> rcLock(vg->rc_mux_); - // Check again after we get the lock - if (vg->rc_.IsOk()) { - vg->rc_ = rc_; - rcLock.unlock(); - TaskManager::InterruptMaster(rc_); - TaskManager::InterruptGroup(*this); - } - } -} - -Status Task::GetTaskErrorIfAny() const { - std::lock_guard<std::mutex> lk(mux_); - if (caught_severe_exception_) { - return rc_; - } else { - return Status::OK(); - } -} - -Task::Task(const std::string &myName, const std::function<Status()> &f) - : my_name_(myName), - rc_(), - fnc_obj_(f), - task_group_(nullptr), - is_master_(false), - running_(false), - caught_severe_exception_(false) { - IntrpResource::ResetIntrpState(); - wp_.ResetIntrpState(); - wp_.Clear(); -} - -Status Task::Run() { - Status rc; - if (running_ == false) { - try { - thrd_ = std::async(std::launch::async, std::ref(*this)); - running_ = true; - caught_severe_exception_ = false; - } catch (const std::exception &e) { - rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what()); - } - } - return rc; -} - -Status Task::Join(WaitFlag blocking) { - if (running_) { - RETURN_UNEXPECTED_IF_NULL(MyTaskGroup()); - auto interrupt_svc = MyTaskGroup()->GetIntrpService(); - try { - if (blocking == WaitFlag::kBlocking) { - // If we are asked to wait, then wait - thrd_.get(); - } else if (blocking == WaitFlag::kNonBlocking) { - // There is a race condition in the global resource tracking such that a thread can miss the - // interrupt and become blocked on a condition variable forever. As a result, calling - // join() will not come back. We need some timeout version of join such that if the thread - // doesn't come back in a reasonable amount of time, we will send the interrupt again. - while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) { - // We can't tell which condition_variable this thread is waiting on. So we may need - // to interrupt everything one more time. - MS_LOG(INFO) << "Some threads not responding. Interrupt again"; - interrupt_svc->InterruptAll(); - } - } else { - RETURN_STATUS_UNEXPECTED("Unknown WaitFlag"); - } - std::stringstream ss; - ss << get_id(); - MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped."; - running_ = false; - RETURN_IF_NOT_OK(wp_.Deregister()); - RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str())); - } catch (const std::exception &e) { - RETURN_STATUS_UNEXPECTED(e.what()); - } - } - return Status::OK(); -} - -TaskGroup *Task::MyTaskGroup() { return task_group_; } - -void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; } - -Task::~Task() { task_group_ = nullptr; } -Status Task::OverrideInterruptRc(const Status &rc) { - if (rc.IsInterrupted() && this_thread::is_master_thread()) { - // If we are interrupted, override the return value if this is the master thread. - // Master thread is being interrupted mostly because some thread is reporting an error. - return TaskManager::GetMasterThreadRc(); - } - return rc; -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/task.h b/mindspore/ccsrc/dataset/util/task.h deleted file mode 100644 index 49eb16b182..0000000000 --- a/mindspore/ccsrc/dataset/util/task.h +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
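// [Editor's note] Join(kNonBlocking) above polls the task's future with a
// one-second timeout and re-sends the interrupt whenever the thread has not
// exited, closing the race where a thread misses an interrupt. The core shape
// of that loop, standalone (JoinWithNudge is a hypothetical name):
#include <chrono>
#include <future>

void JoinWithNudge(std::future<void> &fut) {
  using namespace std::chrono_literals;
  while (fut.wait_for(1s) != std::future_status::ready) {
    // The real code calls interrupt_svc->InterruptAll() here; this sketch
    // simply keeps waiting (placeholder for the re-interrupt).
  }
  fut.get();  // propagate any exception raised in the task body
}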
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_TASK_H_ -#define DATASET_UTIL_TASK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "dataset/util/intrp_resource.h" -#include "dataset/util/list.h" -#include "dataset/util/memory_pool.h" -#include "dataset/util/services.h" -#include "dataset/util/wait_post.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace dataset { -class TaskManager; - -class Task : public IntrpResource { - public: - friend class TaskManager; - friend class TaskGroup; - - enum class WaitFlag : int { kBlocking, kNonBlocking }; - - Task(const std::string &myName, const std::function &f); - - // Future objects are not copyable. - Task(const Task &) = delete; - - ~Task() override; - - Task &operator=(const Task &) = delete; - - // Move constructor and Assignment are not supported. - // Too many things in this class. - Task(Task &&) = delete; - - Task &operator=(Task &&) = delete; - - Status GetTaskErrorIfAny() const; - - void ChangeName(const std::string &newName) { my_name_ = newName; } - - // To execute the _fncObj - void operator()(); - - Node node; - Node group; - Node free; - - // Run the task - Status Run(); - - Status Join(WaitFlag wf = WaitFlag::kBlocking); - - bool Running() const { return running_; } - - bool CaughtSevereException() const { return caught_severe_exception_; } - - bool IsMasterThread() const { return is_master_; } - - std::thread::id get_id() { return id_; } - - std::string MyName() { return my_name_; } - - // An operator used by std::find - bool operator==(const Task &other) const { return (this == &other); } - - bool operator!=(const Task &other) const { return !(*this == other); } - - void Post() { wp_.Set(); } - - Status Wait() { return (wp_.Wait()); } - - static Status OverrideInterruptRc(const Status &rc); - - private: - mutable std::mutex mux_; - std::string my_name_; - Status rc_; - WaitPost wp_; - // Task need to provide definition for this function. It - // will be called by thread function. - std::function fnc_obj_; - // Misc fields used by TaskManager. - TaskGroup *task_group_; - std::future thrd_; - std::thread::id id_; - bool is_master_; - volatile bool running_; - volatile bool caught_severe_exception_; - - void ShutdownGroup(); - TaskGroup *MyTaskGroup(); - void set_task_group(TaskGroup *vg); -}; - -extern thread_local Task *gMyTask; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_TASK_H_ diff --git a/mindspore/ccsrc/dataset/util/task_manager.cc b/mindspore/ccsrc/dataset/util/task_manager.cc deleted file mode 100644 index 3965e35564..0000000000 --- a/mindspore/ccsrc/dataset/util/task_manager.cc +++ /dev/null @@ -1,353 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include "./securec.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -// This takes the same parameter as Task constructor. -Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, - Task **task) { - // We need to block the destructor from running, otherwise we will deadlock. We will grab the - // state lock in shared mode, allowing CreateAsyncTask to run concurrently. - SharedLock stateLck(&state_lock_); - // Now double check the state - if (ServiceState() == STATE::kStopInProg || ServiceState() == STATE::kStopped) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "TaskManager is shutting down"); - } - RETURN_IF_NOT_OK(GetFreeTask(my_name, f, task)); - if (vg == nullptr) { - RETURN_STATUS_UNEXPECTED("TaskGroup is null"); - } - // Previously there was a timing hole where the thread was spawned but hit an error immediately before we could - // set the TaskGroup pointer. We will do the set here before we call run(). The run() will do the registration. - (*task)->set_task_group(vg); - // Link to the master lru list. - { - UniqueLock lck(&lru_lock_); - lru_.Append(*task); - } - // Link to the group list as well before we spawn. - { - UniqueLock lck(&vg->rw_lock_); - vg->grp_list_.Append(*task); - } - // Track all the TaskGroups. Used for control-c - { - LockGuard lck(&tg_lock_); - this->grp_list_.insert(vg); - } - RETURN_IF_NOT_OK((*task)->wp_.Register(vg)); - RETURN_IF_NOT_OK((*task)->Run()); - // Wait for the thread to initialize successfully. - RETURN_IF_NOT_OK((*task)->Wait()); - return Status::OK(); -} - -Status TaskManager::join_all() { - Status rc; - Status rc2; - SharedLock lck(&lru_lock_); - for (Task &tk : lru_) { - rc = tk.Join(); - if (rc.IsError()) { - rc2 = rc; - } - } - return rc2; -} - -void TaskManager::interrupt_all() noexcept { - global_interrupt_ = 1; - LockGuard lck(&tg_lock_); - for (TaskGroup *vg : grp_list_) { - auto svc = vg->GetIntrpService(); - if (svc) { - // Stop the interrupt service. No new request is accepted. - svc->ServiceStop(); - svc->InterruptAll(); - } - } - master_->Interrupt(); -} - -Task *TaskManager::FindMe() { -#if !defined(_WIN32) && !defined(_WIN64) - return gMyTask; -#else - TaskManager &tm = TaskManager::GetInstance(); - SharedLock lock(&tm.lru_lock_); - auto id = this_thread::get_id(); - auto tk = std::find_if(tm.lru_.begin(), tm.lru_.end(), [id](const Task &tk) { return tk.id_ == id; }); - if (tk != tm.lru_.end()) { - return &(*tk); - } - // If we get here, either I am the watchdog or the master thread.
- if (tm.master_->id_ == id) { - return tm.master_.get(); - } else if (tm.watchdog_ != nullptr && tm.watchdog_->id_ == id) { - return tm.watchdog_; - } - MS_LOG(ERROR) << "Task not found."; - return nullptr; -#endif -} - -TaskManager::TaskManager() try : global_interrupt_(0), - lru_(&Task::node), - free_lst_(&Task::free), - watchdog_grp_(nullptr), - watchdog_(nullptr) { - auto alloc = Services::GetAllocator(); - // Create a dummy Task for the master thread (this thread) - master_ = std::allocate_shared(alloc, "master", []() -> Status { return Status::OK(); }); - master_->id_ = this_thread::get_id(); - master_->running_ = true; - master_->is_master_ = true; -#if !defined(_WIN32) && !defined(_WIN64) - gMyTask = master_.get(); - // Initialize the semaphore for the watchdog - errno_t rc = sem_init(&sem_, 0, 0); - if (rc == -1) { - MS_LOG(ERROR) << "Unable to initialize a semaphore. Errno = " << rc << "."; - std::terminate(); - } -#endif -} catch (const std::exception &e) { - MS_LOG(ERROR) << "MindData initialization failed: " << e.what() << "."; - std::terminate(); -} - -TaskManager::~TaskManager() { - if (watchdog_) { - WakeUpWatchDog(); - watchdog_->Join(); - // watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool() which we will free it - // on shutdown. So no need to free these pointers one by one. - watchdog_grp_ = nullptr; - watchdog_ = nullptr; - } -#if !defined(_WIN32) && !defined(_WIN64) - (void)sem_destroy(&sem_); -#endif -} - -Status TaskManager::DoServiceStart() { - MS_LOG(INFO) << "Starting Task Manager."; -#if !defined(_WIN32) && !defined(_WIN64) - // Create a watchdog for control-c - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - // A dummy group just for the watchdog. We aren't really using it. But most code assumes a thread must - // belong to a group. - auto f = std::bind(&TaskManager::WatchDog, this); - Status rc; - watchdog_grp_ = new (&rc, mp) TaskGroup(); - RETURN_IF_NOT_OK(rc); - rc = watchdog_grp_->CreateAsyncTask("Watchdog", f, &watchdog_); - if (rc.IsError()) { - ::operator delete(watchdog_grp_, mp); - watchdog_grp_ = nullptr; - return rc; - } - grp_list_.erase(watchdog_grp_); - lru_.Remove(watchdog_); -#endif - return Status::OK(); -} - -Status TaskManager::DoServiceStop() { - WakeUpWatchDog(); - interrupt_all(); - return Status::OK(); -} - -Status TaskManager::WatchDog() { - TaskManager::FindMe()->Post(); -#if !defined(_WIN32) && !defined(_WIN64) - errno_t err = sem_wait(&sem_); - if (err == -1) { - RETURN_STATUS_UNEXPECTED("Errno = " + std::to_string(errno)); - } - // We are woken up by control-c and we are going to stop all threads that are running. - // In addition, we also want to prevent new thread from creating. This can be done - // easily by calling the parent function. - RETURN_IF_NOT_OK(ServiceStop()); -#endif - return Status::OK(); -} - -// Follow the group link and interrupt other -// Task in the same group. It is used by -// Watchdog only. 
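// [Editor's note] The watchdog above parks on a POSIX semaphore; on control-c,
// WakeUpWatchDog() posts it. sem_post is async-signal-safe per POSIX, which is
// why it may be called from the SIGINT handler. The bare mechanism, with a
// retry-on-EINTR loop added as a common hardening (the original instead treats
// a failed sem_wait as an error):
#include <semaphore.h>

sem_t g_sem;

void WatchdogInit() { sem_init(&g_sem, 0, 0); }            // initial count 0: waiter blocks
void WatchdogWait() { while (sem_wait(&g_sem) == -1) {} }  // retry if interrupted by a signal
void WatchdogPoke() { sem_post(&g_sem); }                  // safe from a signal handler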
-void TaskManager::InterruptGroup(Task &curTk) { - TaskGroup *vg = curTk.MyTaskGroup(); - vg->interrupt_all(); -} - -void TaskManager::InterruptMaster(const Status &rc) { - TaskManager &tm = TaskManager::GetInstance(); - std::shared_ptr master = tm.master_; - std::lock_guard lck(master->mux_); - master->Interrupt(); - if (rc.IsError() && master->rc_.IsOk()) { - master->rc_ = rc; - master->caught_severe_exception_ = true; - } -} - -Status TaskManager::GetMasterThreadRc() { - TaskManager &tm = TaskManager::GetInstance(); - std::shared_ptr master = tm.master_; - Status rc = tm.master_->GetTaskErrorIfAny(); - if (rc.IsError()) { - // Reset the state once we retrieve the value. - std::lock_guard lck(master->mux_); - master->rc_ = Status::OK(); - master->caught_severe_exception_ = false; - master->ResetIntrpState(); - } - return rc; -} - -void TaskManager::ReturnFreeTask(Task *p) noexcept { - // Take it out from lru_ if any - { - UniqueLock lck(&lru_lock_); - auto it = std::find(lru_.begin(), lru_.end(), *p); - if (it != lru_.end()) { - lru_.Remove(p); - } - } - // We need to deallocate the string resources associated with the Task class - // before we cache its memory for future use. - p->~Task(); - // Put it back into free list - { - LockGuard lck(&free_lock_); - free_lst_.Append(p); - } -} - -Status TaskManager::GetFreeTask(const std::string &my_name, const std::function &f, Task **p) { - if (p == nullptr) { - RETURN_STATUS_UNEXPECTED("p is null"); - } - Task *q = nullptr; - // First try the free list - { - LockGuard lck(&free_lock_); - if (free_lst_.count > 0) { - q = free_lst_.head; - free_lst_.Remove(q); - } - } - if (q) { - new (q) Task(my_name, f); - } else { - std::shared_ptr mp = Services::GetInstance().GetServiceMemPool(); - Status rc; - q = new (&rc, mp) Task(my_name, f); - RETURN_IF_NOT_OK(rc); - } - *p = q; - return Status::OK(); -} - -Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::function &f, Task **ppTask) { - auto pMytask = TaskManager::FindMe(); - // We need to block ~TaskGroup coming otherwise we will deadlock. We will grab the - // stateLock in shared allowing CreateAsyncTask to run concurrently. - SharedLock state_lck(&state_lock_); - // Now double check the state - if (ServiceState() != STATE::kRunning) { - return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Taskgroup is shutting down"); - } - TaskManager &dm = TaskManager::GetInstance(); - Task *pTask = nullptr; - // If the group is already in error, early exit too. - // We can't hold the rc_mux_ throughout because the thread spawned by CreateAsyncTask may hit error which - // will try to shutdown the group and grab the rc_mux_ and we will deadlock. - { - std::unique_lock rcLock(rc_mux_); - if (rc_.IsError()) { - return pMytask->IsMasterThread() ? 
rc_ : Status(StatusCode::kInterrupted); - } - } - RETURN_IF_NOT_OK(dm.CreateAsyncTask(my_name, f, this, &pTask)); - if (ppTask) { - *ppTask = pTask; - } - return Status::OK(); -} - -void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); } - -Status TaskGroup::join_all(Task::WaitFlag wf) { - Status rc; - Status rc2; - SharedLock lck(&rw_lock_); - for (Task &tk : grp_list_) { - rc = tk.Join(wf); - if (rc.IsError()) { - rc2 = rc; - } - } - return rc2; -} - -Status TaskGroup::DoServiceStop() { - intrp_svc_->ServiceStop(); - interrupt_all(); - return (join_all(Task::WaitFlag::kNonBlocking)); -} - -TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) { - auto alloc = Services::GetAllocator(); - intrp_svc_ = std::allocate_shared(alloc); - (void)Service::ServiceStart(); -} - -TaskGroup::~TaskGroup() { - (void)Service::ServiceStop(); - // The TaskGroup is going out of scope, and we can return the Task list to the free list. - Task *cur = grp_list_.head; - TaskManager &tm = TaskManager::GetInstance(); - while (cur) { - Task *next = cur->group.next; - grp_list_.Remove(cur); - tm.ReturnFreeTask(cur); - cur = next; - } - { - LockGuard lck(&tm.tg_lock_); - (void)tm.grp_list_.erase(this); - } -} - -Status TaskGroup::GetTaskErrorIfAny() { - SharedLock lck(&rw_lock_); - for (Task &tk : grp_list_) { - RETURN_IF_NOT_OK(tk.GetTaskErrorIfAny()); - } - return Status::OK(); -} - -std::shared_ptr TaskGroup::GetIntrpService() { return intrp_svc_; } -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/task_manager.h b/mindspore/ccsrc/dataset/util/task_manager.h deleted file mode 100644 index 5961c9000e..0000000000 --- a/mindspore/ccsrc/dataset/util/task_manager.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
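// [Editor's note] GetFreeTask()/ReturnFreeTask() above recycle Task memory:
// ReturnFreeTask runs the destructor explicitly (p->~Task()) to release string
// and future resources but parks the raw bytes on a free list, and GetFreeTask
// revives a parked block with placement new. A minimal sketch of that
// object-pool idiom with a hypothetical type:
#include <new>
#include <string>
#include <vector>

struct Job {
  explicit Job(std::string n) : name(std::move(n)) {}
  std::string name;
};

std::vector<void *> g_free;  // cached raw blocks, one Job-sized each

Job *AcquireJob(const std::string &n) {
  if (!g_free.empty()) {
    void *p = g_free.back();
    g_free.pop_back();
    return new (p) Job(n);  // revive cached memory via placement new
  }
  return new (::operator new(sizeof(Job))) Job(n);  // first use: fresh block
}

void ReleaseJob(Job *j) {
  j->~Job();            // free the string's resources, keep the raw bytes
  g_free.push_back(j);  // park the block for reuse
}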
- */ -#ifndef DATASET_UTIL_TASK_MANAGER_H_ -#define DATASET_UTIL_TASK_MANAGER_H_ - -#if !defined(_WIN32) && !defined(_WIN64) -#include <semaphore.h> -#include <signal.h> // for sig_atomic_t -#endif -#include -#include -#include -#include -#include -#include "dataset/util/allocator.h" -#include "dataset/util/intrp_service.h" -#include "dataset/util/lock.h" -#include "dataset/util/services.h" -#include "dataset/util/status.h" -#include "dataset/util/task.h" - -namespace mindspore { -namespace dataset { -namespace thread { -using id = std::thread::id; -} // namespace thread - -namespace this_thread { -inline thread::id get_id() { return std::this_thread::get_id(); } -} // namespace this_thread - -class TaskManager : public Service { - public: - friend class Services; - - friend class TaskGroup; - - ~TaskManager() override; - - TaskManager(const TaskManager &) = delete; - - TaskManager &operator=(const TaskManager &) = delete; - - static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); } - - Status DoServiceStart() override; - - Status DoServiceStop() override; - - // A public global interrupt flag for signal handlers - volatile sig_atomic_t global_interrupt_; - - // API - // This takes the same parameter as Task constructor. Take a look - // at test-thread.cc for usage. - Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, Task **); - - // Same usage as a boost thread group - Status join_all(); - - void interrupt_all() noexcept; - - // Locate a particular Task. - static Task *FindMe(); - - static void InterruptGroup(Task &); - - static Status GetMasterThreadRc(); - - static void InterruptMaster(const Status &rc = Status::OK()); - - static void WakeUpWatchDog() { -#if !defined(_WIN32) && !defined(_WIN64) - TaskManager &tm = TaskManager::GetInstance(); - (void)sem_post(&tm.sem_); -#endif - } - - void ReturnFreeTask(Task *p) noexcept; - - Status GetFreeTask(const std::string &my_name, const std::function<Status()> &f, Task **p); - - Status WatchDog(); - - private: - RWLock lru_lock_; - SpinLock free_lock_; - SpinLock tg_lock_; - std::shared_ptr<Task> master_; - List<Task> lru_; - List<Task> free_lst_; -#if !defined(_WIN32) && !defined(_WIN64) - sem_t sem_; -#endif - TaskGroup *watchdog_grp_; - std::set<TaskGroup *> grp_list_; - Task *watchdog_; - - TaskManager(); -}; - -// A group of related tasks. -class TaskGroup : public Service { - public: - friend class Task; - friend class TaskManager; - - Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **pTask = nullptr); - - void interrupt_all() noexcept; - - Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); - - int size() const noexcept { return grp_list_.count; } - - Status DoServiceStart() override { return Status::OK(); } - - Status DoServiceStop() override; - - TaskGroup(); - - ~TaskGroup() override; - - Status GetTaskErrorIfAny(); - - std::shared_ptr GetIntrpService(); - - private: - Status rc_; - // Can't use rw_lock_ as it will lead to deadlock. Create another mutex to serialize access to rc_.
-// A group of related tasks. -class TaskGroup : public Service { - public: - friend class Task; - friend class TaskManager; - - Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **pTask = nullptr); - - void interrupt_all() noexcept; - - Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking); - - int size() const noexcept { return grp_list_.count; } - - Status DoServiceStart() override { return Status::OK(); } - - Status DoServiceStop() override; - - TaskGroup(); - - ~TaskGroup() override; - - Status GetTaskErrorIfAny(); - - std::shared_ptr<IntrpService> GetIntrpService(); - - private: - Status rc_; - // Can't use rw_lock_ as it would lead to deadlock. Create another mutex to serialize access to rc_. - std::mutex rc_mux_; - RWLock rw_lock_; - List<Task> grp_list_; - std::shared_ptr<IntrpService> intrp_svc_; -}; - -namespace this_thread { -inline bool is_interrupted() { - TaskManager &tm = TaskManager::GetInstance(); - if (tm.global_interrupt_ == 1) { - return true; - } - Task *my_task = TaskManager::FindMe(); - return my_task->Interrupted(); -} - -inline bool is_master_thread() { - Task *my_task = TaskManager::FindMe(); - return my_task->IsMasterThread(); -} - -inline Status GetInterruptStatus() { - Task *my_task = TaskManager::FindMe(); - return my_task->GetInterruptStatus(); -} -} // namespace this_thread - -#define RETURN_IF_INTERRUPTED() \ - do { \ - if (mindspore::dataset::this_thread::is_interrupted()) { \ - return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \ - } \ - } while (false) - -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_TASK_MANAGER_H_ diff --git a/mindspore/ccsrc/dataset/util/wait_post.cc b/mindspore/ccsrc/dataset/util/wait_post.cc deleted file mode 100644 index 204f203d9a..0000000000 --- a/mindspore/ccsrc/dataset/util/wait_post.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/util/wait_post.h" -#include "dataset/util/task_manager.h" - -namespace mindspore { -namespace dataset { -WaitPost::WaitPost() : value_(0) {} - -Status WaitPost::Wait() { - std::unique_lock<std::mutex> lck(mutex_); - return (wait_cond_.Wait(&lck, [this]() { return value_ != 0; })); -} - -void WaitPost::Set() { - std::unique_lock<std::mutex> lck(mutex_); - value_ = 1; - wait_cond_.NotifyAll(); -} - -void WaitPost::Clear() { - std::unique_lock<std::mutex> lck(mutex_); - value_ = 0; -} - -Status WaitPost::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); } - -void WaitPost::ResetIntrpState() { wait_cond_.ResetIntrpState(); } - -Status WaitPost::Deregister() { return wait_cond_.Deregister(); } -} // namespace dataset -} // namespace mindspore
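// [Editor's sketch - illustrative only, not part of the original diff.] WaitPost is a
// manual-reset event built on CondVar: Wait() blocks while value_ == 0, Set() wakes all
// waiters and stays signalled, Clear() re-arms it. A hypothetical pairing:
//
//   WaitPost wp;
//   RETURN_IF_NOT_OK(wp.Register(&vg));   // hook into the group's interrupt service
//   // consumer task:  RETURN_IF_NOT_OK(wp.Wait());
//   // producer task:  wp.Set();          // release the consumer; wp.Clear() to reuse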
diff --git a/mindspore/ccsrc/dataset/util/wait_post.h b/mindspore/ccsrc/dataset/util/wait_post.h deleted file mode 100644 index 4e60995bd9..0000000000 --- a/mindspore/ccsrc/dataset/util/wait_post.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_UTIL_WAIT_POST_H_ -#define DATASET_UTIL_WAIT_POST_H_ - -#include <mutex> -#include "dataset/util/cond_var.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class TaskGroup; - -class WaitPost { - public: - WaitPost(); - - ~WaitPost() = default; - - Status Wait(); - - void Set(); - - void Clear(); - - Status Register(TaskGroup *vg); - - Status Deregister(); - - void ResetIntrpState(); - - private: - std::mutex mutex_; - CondVar wait_cond_; - int value_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_UTIL_WAIT_POST_H_ diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index fc32e0fb5f..c7f2e2b14d 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -24,9 +24,9 @@ #include "ir/primitive.h" #include "ir/func_graph.h" -#include "device/kernel_info.h" +#include "runtime/device/kernel_info.h" #include "utils/graph_utils.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" namespace mindspore { const std::string ToShortString(const TypeId &typeId) { diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 894e59fe4b..273a6f6458 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -28,17 +28,17 @@ #include "ir/meta_func_graph.h" #include "ir/param_value.h" #include "ir/tensor_py.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/resolve.h" -#include "operator/composite/composite.h" -#include "operator/composite/map.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/resolve.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/composite/map.h" #include "utils/ordered_map.h" #include "utils/ordered_set.h" #include "utils/utils.h" #include "debug/trace.h" #include "debug/label.h" #include "utils/context/ms_context.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" using mindspore::tensor::TensorPy; diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 4503692eb9..ed5e3b8a5d 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -28,9 +28,9 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "ir/meta_func_graph.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/resolve.h" -#include "operator/composite/composite.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/resolve.h" +#include "frontend/operator/composite/composite.h" #include "utils/symbolic.h" #include "utils/ordered_map.h" #include "utils/ordered_set.h" diff --git a/mindspore/ccsrc/debug/debugger/debugger.cc b/mindspore/ccsrc/debug/debugger/debugger.cc index c061fba6e7..369f33d79c 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.cc +++ b/mindspore/ccsrc/debug/debugger/debugger.cc @@ -19,8 +19,8 @@ #include #include #include "debug/debugger/debugger.h" -#include "pipeline/pipeline.h" -#include "session/anf_runtime_algorithm.h" +#include "pipeline/jit/pipeline.h" +#include "backend/session/anf_runtime_algorithm.h" using debugger::EventReply; using debugger::GraphProto; diff --git a/mindspore/ccsrc/debug/debugger/debugger.h b/mindspore/ccsrc/debug/debugger/debugger.h index 9b03d6b0b7..da1f325291 100644 --- a/mindspore/ccsrc/debug/debugger/debugger.h +++ b/mindspore/ccsrc/debug/debugger/debugger.h @@ -19,7 +19,7 @@ #include <list>
#include <memory> #include <string> -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" #include "debug/debugger/grpc_client.h" #include "debug/debug_services.h" diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 6cbd5b7f5f..ff8132fb28 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -29,7 +29,7 @@ #include "ir/primitive.h" #include "utils/graph_utils.h" #include "utils/utils.h" -#include "operator/composite/composite.h" +#include "frontend/operator/composite/composite.h" #include "ir/tensor.h" namespace py = pybind11; diff --git a/mindspore/ccsrc/debug/draw.h b/mindspore/ccsrc/debug/draw.h index 7804c6e94a..cb670fe0f6 100644 --- a/mindspore/ccsrc/debug/draw.h +++ b/mindspore/ccsrc/debug/draw.h @@ -22,7 +22,7 @@ #include #include "ir/anf.h" #include "utils/any.h" -#include "pipeline/parse/resolve.h" +#include "pipeline/jit/parse/resolve.h" namespace mindspore { namespace draw { diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index e12a7b1209..b8d3f0a7c7 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -29,10 +29,10 @@ #include "ir/meta_func_graph.h" #include "utils/graph_utils.h" -#include "operator/composite/composite.h" +#include "frontend/operator/composite/composite.h" #include "ir/tensor.h" #include "debug/anf_ir_utils.h" -#include "pipeline/static_analysis/evaluator.h" +#include "pipeline/jit/static_analysis/evaluator.h" namespace mindspore { // namespace to support debug trace information diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 9583997e93..7cf45abe30 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -27,7 +27,7 @@ #include "debug/info.h" #include "ir/anf.h" #include "ir/func_graph.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/static_analysis.h" #include "utils/any.h" namespace mindspore { diff --git a/mindspore/ccsrc/device/CMakeLists.txt b/mindspore/ccsrc/device/CMakeLists.txt deleted file mode 100644 index 652c04d4cd..0000000000 --- a/mindspore/ccsrc/device/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" - "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" -) - -if (ENABLE_GPU) - list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_init.cc") -else () - list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_fake_init.cc") -endif () - -if (ENABLE_D) - file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc") -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc") -endif () - -if (ENABLE_MPI) - # _ms_mpi - file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") - set_property(SOURCE ${MPI_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) - target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) - - set_property(SOURCE "gpu/mpi/mpi_initializer.cc" - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") - target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) -endif () - -# gpu -if (ENABLE_GPU) - file(GLOB_RECURSE CUDA_SRC_LIST
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu") - - set(GPU_QUEUE_SRCS "gpu/blocking_queue.cc" "gpu/gpu_buffer_mgr.cc") - set(GPU_COLLECTIVE_SRCS "gpu/distribution/collective_wrapper.cc" - "gpu/distribution/mpi_wrapper.cc" - "gpu/distribution/nccl_wrapper.cc") - - # gpu_queue - list(REMOVE_ITEM CUDA_SRC_LIST ${GPU_QUEUE_SRCS}) - set_property(SOURCE ${GPU_QUEUE_SRCS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(gpu_queue SHARED ${GPU_QUEUE_SRCS}) - target_link_libraries(gpu_queue ${CMAKE_THREAD_LIBS_INIT} ${CUDA_PATH}/lib64/libcudart.so) - - list(REMOVE_ITEM CUDA_SRC_LIST "gpu/mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS}) - - if (ENABLE_MPI) - include(ExternalProject) - # gpu_collective - set_property(SOURCE ${GPU_COLLECTIVE_SRCS} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) - add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS}) - target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl) - endif () - - # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST}) -endif () - -set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) -add_library(_mindspore_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}) diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc deleted file mode 100644 index 1b5645ab30..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ /dev/null @@ -1,415 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "device/ascend/ascend_device_address.h" -#include -#include -#include -#include -#include "runtime/mem.h" -#include "device/kernel_runtime_manager.h" -#include "device/convert_tensor_utils.h" -#include "ir/dtype/type.h" -#include "ir/tensor.h" -#include "kernel/common_utils.h" -#include "utils/utils.h" -#include "common/utils.h" -#include "common/trans.h" -#ifdef ENABLE_DUMP_E2E -#include "debug/e2e_dump.h" -#endif -#ifdef ENABLE_DEBUGGER -#include "debug/tensor_load.h" -#endif - -namespace mindspore { -namespace device { -namespace ascend { -const int FLOAT_LEN = sizeof(float); -const int FLOAT16_LEN = 2; // sizeof(float16); -const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, - kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, - kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; - -void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { - auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; - } -} - -bool FloatToHalfAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { - auto elem_num = src_size / FLOAT_LEN; - if (elem_num != (dst_size / FLOAT16_LEN)) { - MS_EXCEPTION(ArgumentError) << "FloatToHalf failed. size not match src_size[" << src_size << "], dst_size[" - << dst_size << "]"; - } - std::vector half_data(elem_num); - FloatToHalf(half_data.data(), src, elem_num); - SyncMemory(dst, half_data.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); - return true; -} - -bool Float64ToFloatAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { - if (src_size / 2 != dst_size) { - MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; - } - size_t elem_num = dst_size / sizeof(float); - auto host_tmp = std::vector(elem_num); - DoubleToFloat(host_tmp.data(), src, elem_num); - SyncMemory(dst, host_tmp.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); - return true; -} - -bool SyncDeviceToHostAndHalfToFloat(void *dst, size_t dst_size, const void *src, size_t src_size) { - auto elem_num = src_size / FLOAT16_LEN; - if (elem_num != (dst_size / FLOAT_LEN)) { - MS_EXCEPTION(ArgumentError) << "HalfToFloat failed. 
size not match src_size[" << src_size << "], dst_size[" - << dst_size << "]"; - } - std::vector<float16> half_data(elem_num); - SyncMemory(half_data.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); - HalfToFloat(dst, half_data.data(), elem_num); - return true; -} - -bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *src, size_t src_size) { - if (src_size != dst_size / 2) { - MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; - } - size_t elem_num = src_size / sizeof(float); - auto host_tmp = std::vector<float>(elem_num); - SyncMemory(host_tmp.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); - FloatToDouble(dst, host_tmp.data(), elem_num); - return true; -} - -void AscendDeviceAddress::SyncStream() const { - MS_LOG(INFO) << "Start!"; - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() != kPynativeMode) { - MS_LOG(INFO) << "Finish!"; - return; - } - auto device_id = ms_context->device_id(); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); - MS_EXCEPTION_IF_NULL(runtime_instance); - auto ret = runtime_instance->SyncStream(); - if (!ret) { - MS_LOG(EXCEPTION) << "Sync stream error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t size, mindspore::TypeId type, - void *host_ptr) const { - MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - SyncStream(); - bool sync_ok = false; - std::vector<size_t> host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { - if (type_id_ == type) { - SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); - sync_ok = true; - } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { - sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); - } else { - auto shape_size = trans::ShapeSize(host_shape); - auto host = std::vector<uint8_t>(size_); - SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; - sync_ok = trans::TransDataType(type_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "trans data type failed."; - return false; - } - } - } else { - auto iter = kOpNeedTransFormat.find(format_); - if (iter != kOpNeedTransFormat.end()) { - sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); - } else { - MS_LOG(INFO) << "Can not find format transfer for :" << format_; - } - } - if (!sync_ok) { - MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) - << ", host_type:" << TypeIdLabel(type); - return false; - } - return sync_ok; -} - -bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, - mindspore::TypeId type, void *host_ptr) const { - MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - bool sync_ok = false; - auto host_tmp = std::vector<uint8_t>(size_); - SyncMemory(host_tmp.data(),
ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - std::vector<size_t> host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - std::vector<size_t> device_shape; - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { - device_shape = trans::TransShapeToDevice(host_shape, format_); - } else { - if (host_shape_.empty()) { - host_shape = trans::PaddingShapeTo4d(host_shape); - } else { - host_shape.clear(); - (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize); - } - - device_shape = trans::TransShapeToDevice(host_shape, format_); - } - if (type_id_ != type) { - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - auto host = std::vector<uint8_t>(size_); - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; - sync_ok = trans::TransDataType(type_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - } else { - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - } - return sync_ok; -} - -bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t size, mindspore::TypeId type, - const void *host_ptr) const { - MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - SyncStream(); - bool sync_ok = false; - std::vector<size_t> host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { - if (type_id_ == type) { - SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); - sync_ok = true; - } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { - sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); - } else { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; - auto host_tmp = std::vector<uint8_t>(size_); - sync_ok = trans::TransDataType(type_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans data type failed."; - return false; - } - SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } - } else { - auto iter = kOpNeedTransFormat.find(format_); - if (iter != kOpNeedTransFormat.end()) { - sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); - } else { - MS_LOG(INFO) << "Can not find format transfer for :" << format_; - } - } - if (!sync_ok) { - MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) - << ", host_type:" << TypeIdLabel(type); - return false; - } - return sync_ok; -} - -bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, - mindspore::TypeId type, const void
*host_ptr) const { - bool sync_ok = false; - MS_LOG(INFO) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) - << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; - std::vector<size_t> host_shape; - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.emplace_back(1); - } - std::vector<size_t> device_shape; - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { - device_shape = trans::TransShapeToDevice(host_shape, format_); - } else { - host_shape = trans::PaddingShapeTo4d(host_shape); - device_shape = trans::TransShapeToDevice(host_shape, format_); - } - if (type_id_ != type) { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; - auto host_tmp = std::vector<uint8_t>(size_); - sync_ok = trans::TransDataType(type_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans datatype failed."; - return false; - } - const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, - host_shape, device_shape, type_id_}; - auto dst_tmp = std::vector<uint8_t>(size_); - sync_ok = trans::TransFormat(format_args, dst_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } else { - const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; - auto host_tmp = std::vector<uint8_t>(size_); - sync_ok = trans::TransFormat(format_args, host_tmp.data()); - if (!sync_ok) { - MS_LOG(ERROR) << "Trans format failed."; - return false; - } - SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); - } - return sync_ok; -} - -void AscendDeviceAddress::UpdateCommunicationAddress() { - MS_EXCEPTION_IF_NULL(ptr_); - communication_ptr_ = reinterpret_cast<uint8_t *>(ptr_) - kMemAlignSize; -} - -AscendDeviceAddress::~AscendDeviceAddress() { - if (ptr_ == nullptr) { - return; - } - if (from_mem_pool_) { - if (communication_ptr_ != nullptr) { - AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_); - communication_ptr_ = nullptr; - } else { - AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); - } - ptr_ = nullptr; - } -} - -#ifdef ENABLE_DUMP_E2E -bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &filepath, const std::string &host_fmt, - const std::vector<int> &host_shape, TypeId host_type) const { - bool ret = false; - if (filepath.empty()) { - MS_LOG(ERROR) << "Dump file path is null!"; - return ret; - } - std::string shape = "shape"; - if (host_shape.size()) { - for (auto &value : host_shape) { - shape = shape + '_' + std::to_string(value); - } - } else { - shape = shape + "_0"; - } - std::string file_extension = ".bin"; - if (trans_flag) { - std::string path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension; - MS_LOG(INFO) << "E2E Dump path is " << path; - mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape); - size_t host_size = out_tensor->data().nbytes(); - ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); - if (!ret) { - MS_LOG(ERROR) << "Copy device mem to host failed"; - return ret; - } - ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); - } else { - auto host_tmp = std::vector<uint8_t>(size_); - auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_,
ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; - } - std::string path = - filepath + '_' + shape + '_' + TypeIdToType(type_id_)->ToString() + '_' + format_ + file_extension; - MS_LOG(INFO) << "E2E Dump path is " << path; - ret = mindspore::Dump::DumpToFile(path, host_tmp.data(), size_); - } - - return ret; -} -#endif - -#ifdef ENABLE_DEBUGGER -bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, - const std::string &host_fmt, const std::vector<int> &host_shape, - TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const { - bool ret = false; - - DebugServices *debug_services = debugger->debug_services(); - TensorLoader *tensor_loader = debug_services->get_tensor_loader(); - - if (trans_flag) { - MS_LOG(INFO) << "E2E tensor name is " << tensor_name; - mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape); - size_t host_size = out_tensor->data().nbytes(); - ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); - if (!ret) { - MS_LOG(ERROR) << "Copy device mem to host failed"; - return ret; - } - auto tensor_data = std::make_shared<TensorData>(); - tensor_data->SetName(tensor_name); - tensor_data->SetExecutionOrder(execution_order); - tensor_data->SetTensor(out_tensor); - tensor_data->SetSlot(slot); - ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); - } else { - mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape); - size_t host_size = out_tensor->data().nbytes(); - auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); - - auto tensor_data = std::make_shared<TensorData>(); - tensor_data->SetName(tensor_name); - tensor_data->SetExecutionOrder(execution_order); - tensor_data->SetTensor(out_tensor); - tensor_data->SetSlot(slot); - ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); - if (ret_rt_memcpy != RT_ERROR_NONE) { - MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; - } - MS_LOG(INFO) << "E2E tensor name is " << tensor_name; - } - return ret; -} -#endif -} // namespace ascend -} // namespace device -} // namespace mindspore
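// [Editor's note - illustrative arithmetic only, not part of the original diff.] The
// fp32/fp16 helpers in the file above pair FLOAT_LEN (4) host bytes with FLOAT16_LEN (2)
// device bytes per element, so for a hypothetical 8-element tensor:
//
//   size_t src_size = 8 * FLOAT_LEN;    // 32-byte float buffer on the host
//   size_t dst_size = 8 * FLOAT16_LEN;  // 16-byte float16 buffer on the device
//   // FloatToHalfAndSyncHostToDevice(dst, dst_size, src, src_size) accepts this pair
//   // because src_size / FLOAT_LEN == dst_size / FLOAT16_LEN; any mismatch raises
//   // ArgumentError before any rtMemcpy is issued.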
diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.h b/mindspore/ccsrc/device/ascend/ascend_device_address.h deleted file mode 100644 index 27bcea814c..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ - -#include <memory> -#include <string> -#include <vector> -#include "device/device_address.h" -#include "device/ascend/ascend_memory_pool.h" -#include "ir/dtype.h" - -namespace mindspore { -#ifdef ENABLE_DEBUGGER -class Debugger; -#endif -namespace device { -namespace ascend { -class AscendDeviceAddress : public DeviceAddress { - public: - explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - ~AscendDeviceAddress() override; - bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override; - DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } - void UpdateCommunicationAddress() override; -#ifdef ENABLE_DUMP_E2E - bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, - const std::vector<int> &host_shape, TypeId host_type) const; -#endif -#ifdef ENABLE_DEBUGGER - bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, - const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger, - bool keep_prev) const; -#endif - - private: - bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const; - bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, - const void *host_ptr) const; - void SyncStream() const; - uint8_t *communication_ptr_{nullptr}; -}; -using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc deleted file mode 100644 index 42b1d93dd5..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ /dev/null @@ -1,713 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#define PATH_MAX 0x3ffff -#include "device/ascend/ascend_kernel_runtime.h" -#include <string> -#include <vector> -#include <memory> -#include <utility> -#include <algorithm> -#include <exception> -#include "device/ascend/ascend_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "utils/context/ms_context.h" -#include "utils/mpi/mpi_config.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "hccl/hcom.h" -#include "common/trans.h" -#include "runtime/context.h" -#include "device/ascend/ascend_label_assign.h" -#include "device/ascend/ascend_stream_assign.h" -#include "device/ascend/ascend_memory_pool.h" -#include "framework/ge_runtime/model_runner.h" -#include "device/ascend/tasksink/task_generator.h" -#include "session/anf_runtime_algorithm.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "pre_activate/mem_reuse/mem_reuse_checker.h" -#include "device/ascend/ascend_memory_manager.h" -#include "debug/tensor_load.h" - -using ge::model_runner::ModelRunner; -using mindspore::device::ascend::ProfilingManager; -using mindspore::device::ascend::ProfilingUtils; -using mindspore::device::ascend::tasksink::TaskGenerator; -using mindspore::kernel::tbe::TbeUtils; -using std::vector; - -namespace mindspore { -namespace device { -namespace ascend { -static const size_t PRAMATER_OUTPUT_INDEX = 0; -namespace { -std::string GetRankId() { - std::string rank_id_str; -#ifdef ENABLE_MPI - auto mpi_config_ptr = MpiConfig::GetInstance(); - MS_EXCEPTION_IF_NULL(mpi_config_ptr); - if (mpi_config_ptr->enable_mpi()) { - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - int rank_id = mpi_instance->GetRankId(); - const char *offset = std::getenv("RANK_OFFSET"); - if (offset != nullptr) { - try { - int rank_offset = std::stoi(offset); - rank_id += rank_offset; - } catch (std::invalid_argument) { - MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset; - } catch (std::out_of_range) { - MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset; - } - } - rank_id_str = std::to_string(rank_id); - } else { - rank_id_str = std::getenv("RANK_ID"); - } -#else - rank_id_str = std::getenv("RANK_ID"); -#endif - if (rank_id_str.empty()) { - MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID"; - } - return rank_id_str; -} -} // namespace - -AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } - -void AscendKernelRuntime::ClearGraphModelMap() { -#ifdef ENABLE_DATA_DUMP - for (auto &iter : graph_data_dumper_) { - MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; - iter.second->UnloadDumpInfo(); - } - graph_data_dumper_.clear(); -#endif - for (auto &iter : graph_model_map_) { - MS_LOG(INFO) << "Ge UnloadModel " << iter.first; - auto ret = ModelRunner::Instance().UnloadModel(iter.first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } - } -} - -void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { - MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; - auto iter = graph_model_map_.find(graph_id); - if (iter == graph_model_map_.end()) { - MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; - return; - } - MS_LOG(DEBUG) << "Ge UnloadModel " << iter->first; - auto ret = ModelRunner::Instance().UnloadModel(iter->first); - if (!ret) { - MS_LOG(ERROR) << "UnloadModel failed"; - } - graph_model_map_.erase(iter); -} - -bool AscendKernelRuntime::NeedDestroyHccl() { - auto context_ptr = MsContext::GetInstance(); -
MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->enable_hccl()) { - MS_LOG(INFO) << "Hccl is not enabled"; - return false; - } - // Note: make sure hcom_connectivity_detection api never be used. - return true; -} - -void AscendKernelRuntime::ReleaseDeviceRes() { - MS_LOG(INFO) << "Ascend finalize start"; - // release ge runtime - ClearGraphModelMap(); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - auto ret = rtSetDevice(context_ptr->device_id()); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]"; - } - - if (mem_manager_ != nullptr) { - mem_manager_->FreeDeviceMemory(); - } - - (void)DestroyHccl(); - (void)ResetDevice(); - (void)ProfilingManager::GetInstance().StopProfiling(); - MS_LOG(INFO) << "Ascend finalize end"; -} - -bool AscendKernelRuntime::Init() { - if (initialized_) { - return true; - } - bool ret = false; -#ifdef ENABLE_DUMP_E2E - ret = SetDumpConf(); - if (!ret) { - MS_LOG(INFO) << "No dump conf to set!"; - } -#endif - -#ifdef ENABLE_DATA_DUMP - DataDumpParser::GetInstance().ParseDumpConfig(); -#endif - - // Start up profiling before rtSetDevice - ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); - if (!ret) { - MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed."; - } - - ret = InitDevice(); - if (!ret) { - return ret; - } - mem_manager_ = std::make_shared<AscendMemoryManager>(); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->MallocDeviceMemory(); - - initialized_ = true; - return ret; -} - -#ifdef ENABLE_DUMP_E2E -namespace { -void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(dump_conf); - bool trans_flag = dump_conf->trans_flag(); - const auto &apply_kernels = graph->execution_order(); - for (const auto &node : apply_kernels) { - MS_EXCEPTION_IF_NULL(node); - auto node_name = AnfAlgo::GetCNodeName(node); - std::string kernel_name = node->fullname_with_scope(); - if (!dump_conf->IsKernelNeedDump(kernel_name)) { - continue; - } - const std::string strsrc = "/"; - const std::string strdst = "--"; - std::string::size_type pos = 0; - std::string::size_type srclen = strsrc.size(); - std::string::size_type dstlen = strdst.size(); - while ((pos = kernel_name.find(strsrc, pos)) != std::string::npos) { - kernel_name.replace(pos, srclen, strdst); - pos += dstlen; - } - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t j = 0; j < output_size; ++j) { - auto addr = AnfAlgo::GetOutputAddr(node, j); - std::vector<int> int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(node, j); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(node, j); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto type = AnfAlgo::GetOutputInferDataType(node, j); - auto format = kOpFormat_DEFAULT; - string filepath = dump_path + '/' + kernel_name + '_' + "output_" + std::to_string(j); - auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr); - auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); - if (!ret) { - MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath - << ", host_format:" << format << ".!"; - } - } - } -} - -void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { - MS_EXCEPTION_IF_NULL(graph); -
MS_EXCEPTION_IF_NULL(dump_conf); - bool trans_flag = dump_conf->trans_flag(); - const auto &parameters = graph->inputs(); - for (auto &item : parameters) { - if (!item->isa<Parameter>()) { - continue; - } - std::string parameter_name = item->fullname_with_scope(); - if (!dump_conf->IsKernelNeedDump(parameter_name)) { - continue; - } - auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); - std::vector<int> int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); - auto format = kOpFormat_DEFAULT; - string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; - auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr); - auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); - if (!ret) { - MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath - << ", host_format:" << format << ".!"; - } - } -} -} // namespace -#endif - -bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); -#ifdef ENABLE_DUMP_E2E - MS_LOG(INFO) << "Start dump step"; - DumpConfPtr dump_conf = GetDumpConf(); - MS_EXCEPTION_IF_NULL(dump_conf); - dump_conf->UpdataCurIter(); - bool dump_flag = dump_conf->dump_enable(); - if (!dump_flag) { - MS_LOG(INFO) << "Dump flag is disable, pass dump step"; - return true; - } - uint32_t cur_iter = dump_conf->cur_iter(); - if (dump_conf->dump_iter() != 0) { - if (cur_iter != dump_conf->dump_iter()) { - return true; - } - } - MS_LOG(INFO) << "Cur iter is " << cur_iter; - std::string net_name = dump_conf->dump_net_name(); - std::string iterator = to_string(cur_iter); - std::string dump_path = dump_conf->dump_path(); - if (dump_path.back() == '/') { - dump_path = dump_path + net_name + '/' + iterator; - } else { - dump_path = dump_path + '/' + net_name + '/' + iterator; - } - // dump output - DumpOutput(graph, dump_path, dump_conf); - // dump parameters - DumpParameters(graph, dump_path, dump_conf); -#endif - return true; -} - -#ifdef ENABLE_DEBUGGER -namespace { -void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); - bool trans_flag = false; - const auto &apply_kernels = graph->execution_order(); - // for kernels, execution order starts from 1 - int exec_order = 1; - for (const auto &node : apply_kernels) { - MS_EXCEPTION_IF_NULL(node); - auto node_name = AnfAlgo::GetCNodeName(node); - std::string kernel_name = node->fullname_with_scope(); - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t j = 0; j < output_size; ++j) { - auto addr = AnfAlgo::GetOutputAddr(node, j); - auto type = AnfAlgo::GetOutputInferDataType(node, j); - auto format = kOpFormat_DEFAULT; - string tensor_name = kernel_name + ':' + std::to_string(j); - auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr); - std::vector<int> int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(node, j); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(node, j); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto ret = - ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j,
debugger, false); - if (!ret) { - MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name - << ", host_format:" << format << ".!"; - } - } - exec_order = exec_order + 1; - } -} - -void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); - bool trans_flag = false; - const auto &parameters = graph->inputs(); - // for parameters, set its execution order to be 0; - int exec_order = 0; - for (auto &item : parameters) { - if (!item->isa<Parameter>()) { - continue; - } - std::string parameter_name = item->fullname_with_scope(); - auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); - auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); - auto format = kOpFormat_DEFAULT; - string tensor_name = parameter_name + ':' + "0"; - auto ascend_addr = dynamic_cast<const mindspore::device::ascend::AscendDeviceAddress *>(addr); - std::vector<int> int_shapes; - if (trans_flag) { - int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); - } else { - auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); - (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), - [](size_t inner_item) { return SizeToInt(inner_item); }); - } - auto ret = - ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true); - if (!ret) { - MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name - << ", host_format:" << format << ".!"; - } - } -} -} // namespace -#endif - -bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { - MS_EXCEPTION_IF_NULL(graph); -#ifdef ENABLE_DEBUGGER - MS_LOG(INFO) << "Start load step"; - uint32_t cur_iter = 0; - MS_LOG(INFO) << "Cur iter is " << cur_iter; - // load output - LoadOutput(graph, debugger); - // load parameters - LoadParameters(graph, debugger); -#endif - return true; -} - -bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { - if (AnfAlgo::OutputAddrExist(kernel, index)) { - auto address = AnfAlgo::GetOutputAddr(kernel, index); - MS_EXCEPTION_IF_NULL(address); - return address->DeviceType() == DeviceAddressType::kAscend; - } - return false; -} - -DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id); -} - -bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { - if (graph == nullptr) { - MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; - } - MS_LOG(INFO) << "GenTask start. GraphId:" << graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); - if (!is_task_sink) { - return true; - } -#ifdef MEM_REUSE_DEBUG - if (!context_ptr->enable_mem_reuse()) { - // Get normal graph ir for memreuse - mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); - } -#endif - vector<std::shared_ptr<TaskInfo>> task_info_list; - auto anf_node_list = graph->execution_order(); - TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); - // Store the task_info_list - auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); - if (!insert_ret.second) { - MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; - } - // Graph may have no compute node, such as TensorAddGrad.
- if (task_info_list.empty()) { - MS_LOG(WARNING) << "Graph " << graph->graph_id() << " have no compute node"; - return true; - } - AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); - // the streams' flag not HEAD_STREAM - std::vector<uint32_t> wait_active_stream_list; - assign_instance.GetWaitStreams(&wait_active_stream_list); - std::vector<uint32_t> force_copy_stream_list; - assign_instance.GetHcomStreams(&force_copy_stream_list); - MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() - << ", total event num:" << resource_manager.get_cur_event_num() - << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) - << ", wait_active_stream_list size:" << wait_active_stream_list.size() - << ", force_copy_stream_list size:" << force_copy_stream_list.size(); - std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; - auto model = std::make_shared<ge::model_runner::DavinciModel>( - task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), - resource_manager.get_cur_event_num(), 0); - auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); - if (!ret.second) { - MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; - } - MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; - return true; -} - -bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { - if (graph == nullptr) { - MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; - } - MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_task_sink = context_ptr->enable_task_sink(); - if (!is_task_sink) { - return true; - } - - if (GraphWithEmptyTaskList(graph)) { - MS_LOG(WARNING) << "LoadTask end, task list is empty"; - return true; - } - - auto model_iter = graph_model_map_.find(graph->graph_id()); - if (model_iter == graph_model_map_.end()) { - MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid!
Graph LoadTask without GenTask."; - return false; - } - - std::shared_ptr<ge::ModelListener> listener; - MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first; - bool status = - ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener); - if (!status) { - MS_LOG(EXCEPTION) << "Load Task Failed"; - } - if (ProfilingManager::GetInstance().IsProfiling()) { - auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first); - auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first); - ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph)); - } - -#ifdef ENABLE_DATA_DUMP - LaunchDataDump(NOT_NULL(graph)); -#endif - if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) { - MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed"; - return false; - } - return true; -} - -#ifdef ENABLE_DATA_DUMP -void AscendKernelRuntime::LaunchDataDump(NotNull<const session::KernelGraph *> graph) { - if (!DataDumpParser::GetInstance().DumpEnabled()) { - return; - } - auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph->graph_id()); - auto data_dumper = std::make_shared<DataDumper>(graph.get(), runtime_info_map); - MS_EXCEPTION_IF_NULL(data_dumper); - data_dumper->LoadDumpInfo(); - auto ret = graph_data_dumper_.try_emplace(graph->graph_id(), data_dumper); - if (!ret.second) { - MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph->graph_id() << " data dumper failed"; - } -} -#endif - -void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) { - auto task_ids = ModelRunner::Instance().GetTaskIdList(graph_id); - auto graph_task_names = ProfilingUtils::graph_kernel_name(); - auto iter = graph_task_names.find(graph_id); - if (iter != graph_task_names.end()) { - const auto &task_names = iter->second; - if (task_ids.size() != task_names.size()) { - MS_LOG(WARNING) << "Task_ids and task_names size not match"; - return; - } - for (size_t i = 0; i < task_ids.size(); ++i) { - MS_LOG(INFO) << "Task_id:" << task_ids[i] << " task_name:" << task_names[i]; - } - } -} - -bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - ge::InputData input_tensors = ge::InputData(); - ge::OutputData *output_tensors = nullptr; - if (GraphWithEmptyTaskList(graph)) { - MS_LOG(WARNING) << "RunTask end, no task info found"; - return true; - } - - if (!CheckGraphIdValid(graph->graph_id())) { - MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid!
Graph RunTask without GenTask."; - return false; - } - - bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); - if (!status) { - MS_LOG(ERROR) << "Run task failed"; - DebugTaskIdName(graph->graph_id()); - return false; - } - return true; -} - -bool AscendKernelRuntime::SyncStream() { - if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) { // o for switch stream - MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error."; - return false; - } - return true; -} - -bool AscendKernelRuntime::InitDevice() { - int device_count = 0; - auto ret = rtGetDeviceCount(&device_count); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]"; - } - - ret = rtSetDevice(device_id_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]"; - } - - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr == nullptr) { - MS_LOG(ERROR) << "Get MsContext instance failed"; - return false; - } - if (context_ptr->enable_hccl()) { - if (!HcclInit()) { - MS_LOG(ERROR) << "HcclInit init failed"; - return false; - } - } - - ret = rtCtxCreate(&rt_context_, 0, device_id_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]"; - } - - ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; - } - - ret = rtStreamCreate(&stream_, 0); - if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; - } - - return true; -} - -bool AscendKernelRuntime::ResetDevice() { - auto ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Call rtCtxSetCurrent failed"; - return false; - } - - if (stream_ != nullptr) { - ret = rtStreamDestroy(stream_); - if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]"; - } - stream_ = nullptr; - } - - if (rt_context_ != nullptr) { - ret = rtCtxDestroy(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; - } - rt_context_ = nullptr; - } - return true; -} - -bool AscendKernelRuntime::HcclInit() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->IsTsdOpened()) { - MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; - } - MS_LOG(INFO) << "Do hcom init"; - auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); - if (config_path_str == nullptr) { - config_path_str = std::getenv("RANK_TABLE_FILE"); - if (config_path_str == nullptr) { - MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"; - return false; - } - } - if (strlen(config_path_str) > PATH_MAX) { - MS_LOG(ERROR) << "File path oversize"; - return false; - } - std::string rank_id_str = GetRankId(); - auto full_path = realpath(config_path_str, nullptr); - if (full_path == nullptr) { - MS_LOG(ERROR) << "File path " << config_path_str << " does not exist"; - return false; - } - MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; - hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); - free(full_path); - if (res != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res); - return false; - } - return true; -} -
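// [Editor's sketch - illustrative only, not part of the original diff.] HcclInit above
// resolves the rank table path with MINDSPORE_HCCL_CONFIG_PATH taking precedence over
// RANK_TABLE_FILE, then realpath()s it before handing it to hcom_init. A minimal
// standalone model of that lookup:
//
//   const char *config = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
//   if (config == nullptr) config = std::getenv("RANK_TABLE_FILE");
//   if (config == nullptr) return false;  // neither env var set: report and bail out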
-bool AscendKernelRuntime::DestroyHccl() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!NeedDestroyHccl()) { - MS_LOG(INFO) << "Hccl is not enable, no need to close."; - return true; - } - hcclResult_t res = hcom_destroy(); - if (res != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Hccl destroy failed"; - return false; - } - MS_LOG(INFO) << "Hccl destroy successful, status = " << res << "."; - context_ptr->set_enable_hccl(false); - return true; -} - -bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const { - auto iter = task_map_.find(graph->graph_id()); - if (iter == task_map_.end()) { - MS_LOG(EXCEPTION) << "Unknown graph ptr"; - } - return iter->second.empty(); -} - -bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const { - return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h deleted file mode 100644 index 771c3f8c4f..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ -#include <memory> -#include <vector> -#include <string> -#include <unordered_map> -#include "device/kernel_runtime.h" -#include "runtime/context.h" -#include "framework/ge_runtime/davinci_model.h" -#include "device/kernel_runtime_manager.h" -#include "session/session_basic.h" -#ifdef ENABLE_DATA_DUMP -#include "debug/data_dump_parser.h" -#include "device/ascend/dump/data_dumper.h" -#endif - -using ge::model_runner::TaskInfo; -using std::unordered_map; -using std::vector; -namespace mindspore { -namespace device { -namespace ascend { -class AscendKernelRuntime : public KernelRuntime { - public: - AscendKernelRuntime() = default; - ~AscendKernelRuntime() override; - bool Init() override; - bool DumpData(session::KernelGraph *graph) override; - bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; - bool GenTask(const session::KernelGraph *graph) override; - bool RunTask(const session::KernelGraph *graph) override; - bool LoadTask(const session::KernelGraph *graph) override; - void ClearGraphRuntimeResource(uint32_t graph_id) override; - bool SyncStream() override; - - protected: - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; - - private: - bool InitDevice(); - bool ResetDevice(); - bool HcclInit(); - bool NeedDestroyHccl(); - bool DestroyHccl(); - - void ClearGraphModelMap(); - void ReleaseDeviceRes() override; - bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const; - bool CheckGraphIdValid(GraphId graph_id) const; - static void DebugTaskIdName(GraphId graph_id); - - rtContext_t rt_context_{nullptr}; - bool initialized_{false}; - unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_; - unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_; -#ifdef ENABLE_DATA_DUMP - void LaunchDataDump(NotNull<const session::KernelGraph *> graph); - unordered_map<GraphId, std::shared_ptr<DataDumper>> graph_data_dumper_; -#endif -}; - -MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc deleted file mode 100644 index 2db81a1725..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include -#include -#include -#include "device/ascend/ascend_label_assign.h" -#include "session/anf_runtime_algorithm.h" - -static constexpr uint32_t kLabelGotoLabelId = 1; -static constexpr uint32_t kLabelSwitchLabelId = 2; - -namespace mindspore { -namespace device { -namespace ascend { -static void UpdateLabelGoto(NotNull node) { - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { - return; - } - if (node->size() <= kLabelGotoLabelId) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); - } - - auto input = node->input(kLabelGotoLabelId); - uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); - AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); - MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; - node->set_inputs({node->input(0)}); -} - -static void UpdateLabelSwitch(NotNull node) { - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { - return; - } - if (node->size() <= kLabelGotoLabelId) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); - } - std::vector label_list; - for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) { - auto input = node->input(i); - if (!input->isa() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) { - break; - } - - uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); - label_list.push_back(goto_label_id); - MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; - } - AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue>(label_list), node.get()); - node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)}); -} - -static void AssignLabelForLabelSet(NotNull> graph, NotNull label_id, - NotNull> *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - - MS_LOG(INFO) << "Assign label for " << graph->ToString(); - graph->SetExecOrderByDefault(); - auto nodes = graph->execution_order(); - - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string node_name = AnfAlgo::GetCNodeName(node); - if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(*label_id), node); - MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id; - ++(*label_id); - } - } - - for (auto &cg : graph->child_graph_order()) { - AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo); - } -} - -static void AssignLabelForGotoSwitch(NotNull> graph, - NotNull> *> memo) { - if (memo->find(graph.get()) != memo->end()) { - return; - } - memo->insert(graph.get()); - - MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); - - auto nodes = graph->execution_order(); - auto end_goto = graph->get_end_goto(); - if (end_goto != nullptr) { - nodes.push_back(end_goto); - } - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string node_name = AnfAlgo::GetCNodeName(node); - if (node_name == kLabelGotoOpName) { - UpdateLabelGoto(NOT_NULL(cnode)); - cnode->set_abstract(nullptr); - } - - if (node_name == kLabelSwitchOpName) { - UpdateLabelSwitch(NOT_NULL(cnode)); - } - } - for (auto &cg : graph->child_graph_order()) { - AssignLabelForGotoSwitch(NOT_NULL(cg), memo); - } - 
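AssignLabelForLabelSet above hands out consecutive label ids depth-first, keeping a memo set so that a child graph shared by several parents is labeled exactly once. A stripped-down sketch of that traversal over a hypothetical Graph type (names and fields invented for illustration):

```cpp
#include <cstdint>
#include <memory>
#include <set>
#include <vector>

// Hypothetical graph node: label slots plus child graphs.
struct Graph {
  std::vector<uint32_t> label_slots;             // stands in for LabelSet nodes
  std::vector<std::shared_ptr<Graph>> children;  // child_graph_order()
};

// Depth-first labeling with a memo, mirroring AssignLabelForLabelSet:
// each graph is visited once even if reachable through several parents.
void AssignLabels(const std::shared_ptr<Graph> &g, uint32_t *next_id,
                  std::set<Graph *> *memo) {
  if (!memo->insert(g.get()).second) {
    return;  // already labeled via another parent
  }
  for (auto &slot : g->label_slots) {
    slot = (*next_id)++;
  }
  for (auto &child : g->children) {
    AssignLabels(child, next_id, memo);
  }
}
```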
graph->SetExecOrderByDefault(); -} - -void AscendLabelAssign::AssignLabel(NotNull> graph) { - MS_LOG(INFO) << "Assign label start."; - std::set> memo; - uint32_t label_id = 0; - AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo)); - memo.clear(); - { - std::lock_guard lock(label_num_mutex_); - label_num_[graph.get().get()] = label_id; - } - AssignLabelForGotoSwitch(graph, NOT_NULL(&memo)); - MS_LOG(INFO) << "Assign label end."; -} - -uint32_t AscendLabelAssign::GetLabelNum(NotNull graph) { - std::lock_guard lock(label_num_mutex_); - auto iter = label_num_.find(graph.get()); - if (iter == label_num_.end()) { - MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 0."; - return 0; - } - return iter->second; -} - -uint32_t AscendLabelAssign::GetLabelNum(NotNull> graph) { - return GetLabelNum(NOT_NULL(graph.get().get())); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/device/ascend/ascend_label_assign.h deleted file mode 100644 index 98055576eb..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "utils/contract.h" - -namespace mindspore { -namespace device { -namespace ascend { -class AscendLabelAssign { - public: - static AscendLabelAssign &GetInstance() { - static AscendLabelAssign instance; // Guaranteed to be destroyed. - return instance; - } - - AscendLabelAssign(const AscendLabelAssign &) = delete; - AscendLabelAssign &operator=(const AscendLabelAssign &) = delete; - - void AssignLabel(NotNull> graph); - uint32_t GetLabelNum(NotNull graph); - uint32_t GetLabelNum(NotNull> graph); - - private: - AscendLabelAssign() = default; - ~AscendLabelAssign() = default; - - std::map label_num_; - std::mutex label_num_mutex_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc deleted file mode 100644 index a664232a28..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.cc +++ /dev/null @@ -1,137 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
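GetLabelNum above is a mutex-guarded lookup that falls back to 0 for graphs that were never labeled, and AssignLabel fills that cache under the same lock. A minimal sketch of the cache, assuming a plain pointer key as in label_num_:

```cpp
#include <cstdint>
#include <map>
#include <mutex>

// Sketch of the label-count cache kept by AscendLabelAssign: a map from
// graph pointer to assigned label count, guarded by a mutex because
// AssignLabel and GetLabelNum may be called from different threads.
class LabelCountCache {
 public:
  void Set(const void *graph, uint32_t num) {
    std::lock_guard<std::mutex> lock(mutex_);
    counts_[graph] = num;
  }
  // Unlabeled graphs report 0, matching GetLabelNum's default.
  uint32_t Get(const void *graph) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = counts_.find(graph);
    return it == counts_.end() ? 0 : it->second;
  }

 private:
  std::map<const void *, uint32_t> counts_;
  std::mutex mutex_;
};
```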
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "device/ascend/ascend_memory_manager.h" -#include "device/ascend/ascend_memory_pool.h" -#include "utils/context/ms_context.h" -#include "runtime/mem.h" -namespace mindspore { -namespace device { -namespace ascend { -constexpr uint64_t kAscendDeviceMemGB = 30; -constexpr uint64_t kMemSizeGB = 30; -constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB); - -void AscendMemoryManager::MallocDeviceMemory() { - auto context_mem = GetDeviceMemSizeFromContext(); - device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem; - dynamic_mem_offset_ = device_mem_size_; - auto ret = rtMalloc(reinterpret_cast(&device_mem_base_), dynamic_mem_offset_, RT_MEMORY_HBM); - - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << dynamic_mem_offset_ << "] fail, ret[" << ret << "]"; - } - - AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_); - AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); -} - -uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - auto variable_memory_max_size = context->variable_memory_max_size(); - if (variable_memory_max_size == "0") { - return 0; - } - MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size; - auto pos = variable_memory_max_size.find('*'); - if (pos == std::string::npos) { - MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size"; - } - auto gb_str = variable_memory_max_size.substr(0, pos); - auto gb_var = std::stoull(gb_str); - MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var; - if (gb_var > kAscendDeviceMemGB || gb_var == 0) { - MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB"; - } - return gb_var << kMemSizeGB; -} - -void AscendMemoryManager::FreeDeviceMemory() { - if (device_mem_base_ != nullptr) { - auto ret = rtFree(device_mem_base_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; - } - device_mem_base_ = nullptr; - } - if (device_mem_pool_base_ != nullptr) { - auto ret = rtFree(device_mem_pool_base_); - if (ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "rtFree mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; - } - device_mem_pool_base_ = nullptr; - } -} - -void AscendMemoryManager::ResetDynamicMemory() { - total_dynamic_size_ = 0; - dynamic_mem_offset_ = device_mem_size_; - AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); -} - -void *AscendMemoryManager::MallocMemFromMemPool(size_t size) { - auto align_size = GetCommonAlignSize(size); - return AscendMemoryPool::GetInstance().AllocTensorMem(align_size); -} - -uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - if (communication_mem) { - // create protect area [kMemAlignSize 
-- data -- kMemAlignSize] - uint8_t *alloc_address = reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); - return alloc_address + kMemAlignSize; - } else { - return reinterpret_cast(AscendMemoryPool::GetInstance().AllocTensorMem(align_size)); - } -} - -uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - if (dynamic_mem_offset_ < align_size) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ - << "]) malloc [" << align_size << "] failed!"; - } - auto new_offset = dynamic_mem_offset_ - align_size; - auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset(); - if (new_offset <= device_mem_pool_offset) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_ - << "] memory pool[" << device_mem_pool_offset << "])" - << " malloc [" << align_size << "] failed!"; - } - total_dynamic_size_ += align_size; - dynamic_mem_offset_ = new_offset; - AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_); - if (communication_mem) { - // create protect area [kMemAlignSize -- data -- kMemAlignSize] - return device_mem_base_ + new_offset + kMemAlignSize; - } else { - return device_mem_base_ + new_offset; - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/device/ascend/ascend_memory_manager.h deleted file mode 100644 index 5b52412d78..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_manager.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
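MallocDynamicMem above carves graph memory downward from the top of the device block while the memory pool grows upward from the base, failing once the two cursors would cross; communication buffers additionally reserve guard bytes on each side of the payload. A compact sketch of that two-cursor arena, with illustrative sizes rather than the runtime's real constants:

```cpp
#include <cstddef>
#include <stdexcept>

// Two-cursor arena sketch: the pool bumps upward from 0, dynamic graph
// memory bumps downward from the top, and allocation fails when they meet.
// Capacity and guard width are invented for illustration.
struct TwoCursorArena {
  size_t capacity = 1024;    // device_mem_size_ stand-in
  size_t pool_offset = 0;    // grows upward
  size_t dyn_offset = 1024;  // grows downward
  static constexpr size_t kGuard = 32;  // kMemAlignSize stand-in

  // Returns the offset of the new block within the arena.
  size_t AllocDynamic(size_t size, bool communication) {
    if (communication) size += 2 * kGuard;  // protect area on both sides
    if (dyn_offset < size || dyn_offset - size <= pool_offset) {
      throw std::runtime_error("out of device memory");
    }
    dyn_offset -= size;
    // Communication buffers hand back the offset past the front guard.
    return communication ? dyn_offset + kGuard : dyn_offset;
  }
};
```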
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ -#include "device/memory_manager.h" -namespace mindspore { -namespace device { -namespace ascend { -class AscendMemoryManager : public MemoryManager { - public: - AscendMemoryManager() = default; - ~AscendMemoryManager() override = default; - - void MallocDeviceMemory() override; - void FreeDeviceMemory() override; - void ResetDynamicMemory() override; - void *MallocMemFromMemPool(size_t size) override; - - protected: - uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; - uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override; - - private: - uint8_t *device_mem_pool_base_{nullptr}; - uint64_t device_mem_pool_size_{0}; - - uint64_t GetDeviceMemSizeFromContext(); -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc deleted file mode 100644 index f325046486..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/ascend/ascend_memory_pool.h" -#include "device/ascend/ascend_kernel_runtime.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace ascend { -size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "Can not alloc memory size(0) in memory pool !"; - } - if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) { - MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ [" - << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ " << graph_dynamic_mem_offset_ - << "], need memory size [" << size << "]"; - } - *addr = device_mem_pool_base_ + device_mem_pool_offset_; - device_mem_pool_offset_ += size; - if (*addr == nullptr) { - MS_LOG(EXCEPTION) << "Alloc device address is nullptr, failed to alloc memory pool memory!"; - } - return size; -} - -bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { - MS_EXCEPTION_IF_NULL(addr); - return true; -} - -size_t AscendMemoryPool::AlignMemorySize(size_t size) const { - if (size == 0) { - MS_LOG(EXCEPTION) << "The align memory size is a zero !"; - } - return size; -} - -void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { - MS_EXCEPTION_IF_NULL(device_mem_pool_base); - device_mem_pool_base_ = device_mem_pool_base; -} - -void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) { - graph_dynamic_mem_offset_ = graph_dynamic_mem_offset; -} - -uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; } - -size_t AscendMemoryPool::free_mem_size() { - if (graph_dynamic_mem_offset_ < device_mem_pool_offset_) { - MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_ - << "] less than device mem pool offset [" << device_mem_pool_offset_ << "]!"; - } - return graph_dynamic_mem_offset_ - device_mem_pool_offset_; -} - -size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; } -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h deleted file mode 100644 index ef02f21cde..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
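AllocDeviceMem above is the upward-growing half of the same scheme: a bump allocation that must stay strictly below graph_dynamic_mem_offset_, with free_mem_size reporting the remaining gap between the two cursors. A hedged sketch of that allocator in isolation:

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Upward bump allocator with a moving ceiling, mirroring
// AscendMemoryPool::AllocDeviceMem and free_mem_size. Field names map to
// the originals; values are illustrative.
struct PoolSketch {
  uint8_t *base = nullptr;  // device_mem_pool_base_
  size_t offset = 0;        // device_mem_pool_offset_
  size_t ceiling = 0;       // graph_dynamic_mem_offset_

  uint8_t *Alloc(size_t size) {
    if (size == 0 || offset + size >= ceiling) {
      throw std::runtime_error("memory pool exhausted");
    }
    uint8_t *addr = base + offset;
    offset += size;
    return addr;
  }
  size_t FreeSize() const { return ceiling - offset; }
};
```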
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ - -#include -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" - -namespace mindspore { -namespace device { -namespace ascend { -class AscendMemoryPool : public DynamicMemPoolBestFit { - public: - ~AscendMemoryPool() override = default; - AscendMemoryPool(const AscendMemoryPool &) = delete; - AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; - - size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; - bool FreeDeviceMem(const DeviceMemPtr &addr) override; - void set_device_mem_pool_base(uint8_t *device_mem_pool_base); - void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset); - - uint64_t device_mem_pool_offset() const; - size_t free_mem_size() override; - size_t total_mem_size() override; - - static AscendMemoryPool &GetInstance() { - static AscendMemoryPool instance; - return instance; - } - - protected: - // The real size by memory alloc aligned. - size_t AlignMemorySize(size_t size) const override; - - private: - AscendMemoryPool() = default; - uint8_t *device_mem_pool_base_{nullptr}; - uint64_t device_mem_pool_offset_{0}; - uint64_t graph_dynamic_mem_offset_{0}; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc deleted file mode 100644 index a68c408221..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ /dev/null @@ -1,1268 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
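The pool header above uses the same local-static singleton shape as AscendLabelAssign: a private constructor, deleted copy operations, and GetInstance returning a function-local static. The pattern in isolation:

```cpp
// The singleton shape shared by AscendMemoryPool and AscendLabelAssign:
// construction on first use, destruction at program exit, copies disabled.
// Since C++11 the local static is initialized in a thread-safe way.
class Singleton {
 public:
  static Singleton &GetInstance() {
    static Singleton instance;
    return instance;
  }
  Singleton(const Singleton &) = delete;
  Singleton &operator=(const Singleton &) = delete;

 private:
  Singleton() = default;
  ~Singleton() = default;
};
```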
- */ - -#include "device/ascend/ascend_stream_assign.h" - -#include -#include - -#include "ir/manager.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_adjust.h" -#include "predict/generator/utils/ir_model_util.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" - -namespace mindspore { -namespace device { -namespace ascend { -const uint32_t kHcomMaxTask = 5; -const uint32_t kCommonMaxTask = 350; - -void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { - if (IsTaskSink()) { - Reset(); - ReorderIndependentOrders(graph_ptr); - AssignAllNodesStream(graph_ptr); - UpdateAtomicAddrCleanStreamId(graph_ptr); - InsertStreamActive(graph_ptr); - InsertEventForHcomParallel(graph_ptr); - InsertEventForIndependentParallel(graph_ptr); - GetNeedActiveStreams(graph_ptr); - graph_ptr->PrintGraphExecuteOrder(); - CheckResourceAssign(graph_ptr); - MS_LOG(INFO) << "After finish stream assign"; - - FindStreamRelations(graph_ptr); - PrintStreamRelations(); - GetStreamRelations(); - PrintStreamGroups(); - FindEventRelations(graph_ptr); - - // Get info for D Model - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num()); - generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num()); - // Init to 1,temporarily - generator::IRModelUtil::GetInstance().set_batch_num(1); - } -} - -// section 1 -void AscendStreamAssign::ReorderIndependentOrders(const NotNull &graph_ptr) { - std::vector exe_orders; - std::vector independents; - std::vector others; - - auto cnode_ptr_list = graph_ptr->execution_order(); - MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - auto cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - independents.emplace_back(cur_cnode_ptr); - } else { - others.emplace_back(cur_cnode_ptr); - } - } - - if (others.empty() || independents.empty()) { - MS_LOG(INFO) << "Independent or others is empty, no need reorder"; - return; - } - - std::set processed; - for (size_t i = 0; i < others.size(); i++) { - auto begin = others.begin() + i; - auto end = begin + 1; - bool flag = false; - for (size_t j = 0; j < independents.size(); j++) { - auto cur_independent = independents[j]; - auto it = std::find(processed.begin(), processed.end(), cur_independent.get()); - if (it != processed.end()) { - continue; - } - - auto res = FindTargetOp(begin, end, cur_independent); - if (res != end) { - flag = true; - exe_orders.emplace_back(cur_independent); - exe_orders.emplace_back(*begin); - processed.emplace(cur_independent.get()); - break; - } - } - - if (!flag) { - exe_orders.emplace_back(*begin); - } - } - - MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size(); - if (processed.size() != independents.size()) { - MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size"; - return; - } - - graph_ptr->set_execution_order(exe_orders); -} - -// section 2 -void AscendStreamAssign::AssignAllNodesStream(const NotNull &graph_ptr) { - auto cnode_ptr_list = graph_ptr->execution_order(); - bool exit_independent = false; - bool exit_hcom = false; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - 
CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - // node has been assigned stream before - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - exit_hcom = true; - continue; - } - - if (IsIndependentNode(cur_cnode_ptr)) { - exit_independent = true; - continue; - } - - AssignCommonStreamId(cur_cnode_ptr); - } - MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); - - if (exit_hcom) { - uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - // node has been assigned stream before - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - AssignHcomStreamId(cur_cnode_ptr); - } - } - MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); - } - - if (exit_independent) { - uint32_t first_independ = resource_manager.ApplyNewStream(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) { - continue; - } - if (IsIndependentNode(cur_cnode_ptr)) { - AssignIndependentStreamId(cur_cnode_ptr); - } - } - MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size(); - } - - MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); -} - -void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_common_stream_id = 0; - uint32_t cur_stream_num = resource_manager.get_cur_stream_num(); - if (cur_stream_num == 0) { - cur_common_stream_id = resource_manager.ApplyNewStream(); - } else { - cur_common_stream_id = resource_manager.GetCurAllocStreamId(); - } - - auto it = common_stream_map_.find(cur_common_stream_id); - if (it == common_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); - common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); - } else { - if (it->second < kCommonMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_common_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get()); - common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1)); - } - } -} - -void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId(); - auto it = hcom_stream_map_.find(cur_hcom_stream_id); - if (it == hcom_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); - } else { - if (it->second < kHcomMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_hcom_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get()); - hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1)); - } - } -} - -void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr 
&cur_cnode_ptr) { - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId(); - auto it = independent_stream_map_.find(cur_independent_id); - if (it == independent_stream_map_.end()) { - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); - } else { - if (it->second < kCommonMaxTask) { - AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get()); - it->second++; - } else { - cur_independent_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get()); - independent_stream_map_.insert(std::make_pair(cur_independent_id, 1)); - } - } -} - -bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { - MS_EXCEPTION_IF_NULL(node_ptr); - if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) { - return false; - } - - if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { - MS_LOG(INFO) << "GetNext should not be independent node"; - return false; - } - - uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); - if (input_nums == 0) { - MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; - return true; - } - - auto inputs = node_ptr->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - if (!inputs[i]->isa()) { - return false; - } - } - MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; - return true; -} - -// section 3: -void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - // update AtomicAddrClean stream same witch the next node - if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) { - AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); - } - } - MS_LOG(INFO) << "End"; -} - -// section 4 -void AscendStreamAssign::InsertStreamActive(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - GetProcessedStream(graph_ptr); - std::vector update_cnode_list; - CNodePtr cur_cnode_ptr = nullptr; - CNodePtr pre_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - - bool independent_flag = !(independent_stream_map_.empty()); - bool hcom_flag = !(hcom_stream_map_.empty()); - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (IsIndependentNode(cur_cnode_ptr)) { - update_cnode_list.emplace_back(cur_cnode_ptr); - continue; - } - - if (IsHcom(cur_cnode_ptr)) { - update_cnode_list.emplace_back(cur_cnode_ptr); - continue; - } - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - bool processed = IsProcessedStream(cur_stream_id); - // 1)inner stream assign, need insert active op - if (!processed) { - MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - // 1.set stream id - AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); - // 2.set active stream ids - std::vector active_index_list{cur_stream_id}; - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, 
MakeValue>(active_index_list), active_ptr); - update_cnode_list.emplace_back(active_ptr); - } - - if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) { - MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel"; - UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list); - } else { - update_cnode_list.emplace_back(cur_cnode_ptr); - } - - processed_streams_.emplace(cur_stream_id); - pre_stream_id = cur_stream_id; - pre_cnode_ptr = cur_cnode_ptr; - } - graph_ptr->set_execution_order(update_cnode_list); - MS_LOG(INFO) << "End"; -} - -void AscendStreamAssign::GetProcessedStream(const NotNull &graph_ptr) { - // 0 stream is activated at first - processed_streams_.emplace(0); - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - auto cur_cnode_ptr = cnode_ptr_list[i]; - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { - auto true_stream_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrTrueBranchStream); - processed_streams_.emplace(true_stream_id); - - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { - continue; - } - auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); - if (need_active) { - processed_streams_.emplace(cur_stream_id); - } - } - } - for (const auto &item : processed_streams_) { - MS_LOG(INFO) << "Before active:" << item << " is been processed"; - } -} - -void AscendStreamAssign::UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, - vector *orders) { - orders->emplace_back(switch_ptr); - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) { - return; - } - - auto need_active = AnfAlgo::GetNodeAttr(switch_ptr, kStreamNeedActivedFirst); - if (!need_active) { - return; - } - - MS_EXCEPTION_IF_NULL(switch_ptr); - auto true_stream_id = AnfAlgo::GetNodeAttr(switch_ptr, kAttrTrueBranchStream); - MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) - << "; active stream id:" << true_stream_id; - - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); - AnfAlgo::SetStreamId(true_stream_id, active_ptr.get()); - vector active_ids; - // active indepdent stream - for (const auto &item : independent_stream_map_) { - active_ids.emplace_back(item.first); - } - // active hcom stream - for (const auto &item : hcom_stream_map_) { - active_ids.emplace_back(item.first); - } - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_ids), active_ptr); - - // update processed stream - independent_stream_activated_ = true; - for (const auto &item : independent_stream_map_) { - processed_streams_.emplace(item.first); - } - - hcom_stream_activated_ = true; - for (const auto &item : hcom_stream_map_) { - processed_streams_.emplace(item.first); - } - - orders->emplace_back(active_ptr); -} - -bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { - auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id); - if (it != processed_streams_.end()) { - return true; - } - return false; -} - -// section5 -void AscendStreamAssign::InsertEventForHcomParallel(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - InsertEventCommonDependHcom(graph_ptr); - InsertEventHcomDependCommon(graph_ptr); - InsertEventHcomDependHcom(graph_ptr); - MS_LOG(INFO) << "End"; -} - -void AscendStreamAssign::InsertEventCommonDependHcom(const 
NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto it = cnodes.begin(); - while (it != cnodes.end() && (it + 1) != cnodes.end()) { - MS_EXCEPTION_IF_NULL(*it); - MS_EXCEPTION_IF_NULL(*(it + 1)); - if (IsHcom(*it) && !IsHcom(*(it + 1))) { - CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); - it = cnodes.insert(it + 1, send_cnode_ptr); - - auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); - if (target == cnodes.end()) { - MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() - << ", can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; - } - - if (IsHcom(*target)) { - it = cnodes.erase(it); - continue; - } - - // deal recv op - uint32_t stream_id = AnfAlgo::GetStreamId(*target); - CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); - (void)cnodes.insert(target, recv_cnode_ptr); - cur_event_id = resource_manager.ApplyNewEvent(); - } - ++it; - } - // one event allocated additional, should delete - resource_manager.DeleteEvent(); - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes; - CNodePtr cur_cnode_ptr = nullptr; - uint32_t pre_stream_id = UINT32_MAX; - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (i == 0) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (!IsHcom(cur_cnode_ptr)) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (cur_stream_id == pre_stream_id) { - cnodes.emplace_back(cur_cnode_ptr); - pre_stream_id = cur_stream_id; - continue; - } - - if (!IsHcom(cnode_ptr_list[i - 1])) { - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); - cnodes.emplace_back(send); - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); - cnodes.emplace_back(recv); - cnodes.emplace_back(cur_cnode_ptr); - } else { - cnodes.emplace_back(cur_cnode_ptr); - } - pre_stream_id = cur_stream_id; - } - - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - uint32_t first_hcom_stream = kInvalidStreamId; - uint32_t last_hcom_stream = kInvalidStreamId; - // key: stream id, value:hcom index - std::map> hcom_index; - for (size_t i = 0; i < cnode_ptr_list.size(); i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsHcom(cur_cnode)) { - continue; - } - uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - auto it = hcom_index.find(cur_stream_id); - if (it != hcom_index.end()) { - 
hcom_index[cur_stream_id].emplace_back(i); - } else { - hcom_index[cur_stream_id] = {i}; - } - - // record first hcom stream id - if (first_hcom_stream == kInvalidStreamId) { - first_hcom_stream = cur_stream_id; - } - - // record last hcom stream id - if (cur_stream_id != last_hcom_stream) { - last_hcom_stream = cur_stream_id; - } - } - - if (hcom_index.size() < 2) { - MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; - return; - } - InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); - MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); -} - -void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, - const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream) { - vector orders; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); - size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); - std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); - for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { - auto cur_cnode = cnode_ptr_list[i]; - if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) { - orders.emplace_back(cur_cnode); - continue; - } - auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (i == first_stream_last_index) { - // first fusion hcom - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else if (i == last_stream_first_index) { - // last fusion hcom - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - orders.emplace_back(cur_cnode); - } else { - auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); - if (cur_stream_hcom_size == 1) { - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - cur_event_id = resource_manager.ApplyNewEvent(); - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else { - // current stream, first hcom:add recv op - if (i == hcom_index.at(cur_hcom_stream_id).front()) { - auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(recv); - cur_event_id = resource_manager.ApplyNewEvent(); - orders.emplace_back(cur_cnode); - } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { - // current stream, last hcom:add send op - orders.emplace_back(cur_cnode); - auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); - orders.emplace_back(send); - } else { - // current stream, not first and last op - orders.emplace_back(cur_cnode); - } - } - } - } - std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); - graph_ptr->set_execution_order(orders); -} - -bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, - size_t index) { - MS_EXCEPTION_IF_NULL(node_ptr); - auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); - auto it = hcom_index.find(cur_hcom_stream_id); - if (it == hcom_index.end()) { - 
return false; - } - auto iter = std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index); - if (iter == hcom_index.at(cur_hcom_stream_id).end()) { - return false; - } - return true; -} - -// section6 -void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { - MS_LOG(INFO) << "Start"; - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto cnode_ptr_list = graph_ptr->execution_order(); - vector cnodes = cnode_ptr_list; - uint32_t cur_event_id = resource_manager.ApplyNewEvent(); - auto it = cnodes.begin(); - while (it != cnodes.end()) { - MS_EXCEPTION_IF_NULL(*it); - if (IsIndependentNode(*it)) { - MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; - CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); - it = cnodes.insert(it + 1, send_cnode_ptr); - - auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); - if (target == cnodes.end()) { - MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope() - << "] can't find target for insert recv op, no insert send/recv"; - it = cnodes.erase(it); - continue; - } - - // deal recv op - uint32_t stream_id = AnfAlgo::GetStreamId(*target); - CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); - (void)cnodes.insert(target, recv_cnode_ptr); - cur_event_id = resource_manager.ApplyNewEvent(); - } - ++it; - } - // one event allocated additional, should delete - resource_manager.DeleteEvent(); - graph_ptr->set_execution_order(cnodes); - MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num(); - MS_LOG(INFO) << "End"; -} - -// section7 -void AscendStreamAssign::GetNeedActiveStreams(const NotNull &graph_ptr) { - CNodePtr cur_cnode_ptr = nullptr; - auto cnode_ptr_list = graph_ptr->execution_order(); - // 1)first stream 0 should be actived first; - need_first_active_streams_.emplace_back(0); - - // 2)stream witch kStreamNeedActivedFirst attr should be actived; - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { - continue; - } - - auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); - if (need_active) { - auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first"; - need_first_active_streams_.push_back(stream_id); - } - } - - // 3)independent stream:if has not been activate, push to need active vector - if (!independent_stream_activated_) { - for (auto &item : independent_stream_map_) { - need_first_active_streams_.emplace_back(item.first); - } - } - - // 4)hcom stream:if has not been activate, push to need active vector - if (!hcom_stream_activated_) { - for (auto &item : hcom_stream_map_) { - need_first_active_streams_.emplace_back(item.first); - } - } -} - -// section8 -void AscendStreamAssign::CheckResourceAssign(const NotNull &graph_ptr) { - CheckStreamAssign(graph_ptr); - CheckEventAssign(graph_ptr); -} - -void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - std::set streams; - uint32_t max_stream = 0; - uint32_t min_stream = kInvalidStreamId; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr 
cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); - if (stream_id == kInvalidStreamId) { - MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream"; - } - - (void)streams.emplace(stream_id); - if (stream_id > max_stream) { - max_stream = stream_id; - } - if (stream_id < min_stream) { - min_stream = stream_id; - } - } - - // check stream assign - if (!streams.empty()) { - if (min_stream != 0) { - MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream; - } - uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); - if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { - MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream - << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); - } - } -} - -void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - std::map> event_map; - uint32_t max_event_id = 0; - uint32_t min_event_id = kInvalidEventId; - auto cnode_ptr_list = graph_ptr->execution_order(); - for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { - CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr); - if (name == kSendOpName || name == kRecvOpName) { - uint32_t event_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId); - if (event_id > max_event_id) { - max_event_id = event_id; - } - - if (event_id < min_event_id) { - min_event_id = event_id; - } - auto it = event_map.find(event_id); - if (it == event_map.end()) { - event_map[event_id] = {cur_cnode_ptr}; - } else { - event_map[event_id].emplace_back(cur_cnode_ptr); - } - } - } - // check event assign - if (!event_map.empty()) { - if (min_event_id != 0) { - MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id; - } - uint32_t assigned_event_num = resource_manager.get_cur_event_num(); - if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { - MS_LOG(EXCEPTION) << "Event should be consecutive"; - } - for (const auto &item : event_map) { - if (item.second.size() != 2) { - MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id"; - } - auto first_name = AnfAlgo::GetCNodeName(item.second[0]); - auto second_name = AnfAlgo::GetCNodeName(item.second[1]); - if (!(first_name == kSendOpName && second_name == kRecvOpName)) { - MS_LOG(EXCEPTION) << "Send should be before recv"; - } - } - } -} - -// section9 -CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, - uint32_t stream_id) { - auto send_op = std::make_shared(kSendOpName); - MS_EXCEPTION_IF_NULL(send_op); - auto send_apply = std::make_shared(send_op); - MS_EXCEPTION_IF_NULL(send_apply); - std::vector send_input_list = {send_apply}; - CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); - MS_EXCEPTION_IF_NULL(send_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - 
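CheckStreamAssign and CheckEventAssign above both enforce that allocated ids form the exact range [0, n): minimum zero, maximum n-1, and no holes in between. The core of that compactness invariant as a standalone check:

```cpp
#include <cstdint>
#include <set>

// The id-compactness invariant behind CheckStreamAssign/CheckEventAssign:
// the ids used by the graph must be exactly 0..allocated-1 with no gaps.
// std::set is ordered, so begin() is the minimum and rbegin() the maximum;
// the size test then rules out holes.
bool IdsAreCompact(const std::set<uint32_t> &ids, uint32_t allocated) {
  if (ids.empty()) return allocated == 0;
  return *ids.begin() == 0 && *ids.rbegin() == allocated - 1 &&
         ids.size() == allocated;
}
```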
send_node_ptr->set_abstract(abstract_none); - AnfAlgo::SetStreamId(stream_id, send_node_ptr.get()); - return send_node_ptr; -} - -CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, - uint32_t stream_id) { - auto recv_op = std::make_shared(kRecvOpName); - MS_EXCEPTION_IF_NULL(recv_op); - auto recv_apply = std::make_shared(recv_op); - MS_EXCEPTION_IF_NULL(recv_apply); - std::vector recv_input_list = {recv_apply}; - CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); - MS_EXCEPTION_IF_NULL(recv_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); - AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - recv_node_ptr->set_abstract(abstract_none); - return recv_node_ptr; -} - -vector::iterator AscendStreamAssign::FindTargetOp(vector::iterator begin, - vector::iterator end, const CNodePtr &node) { - while (begin != end) { - auto inputs = (*begin)->inputs(); - for (size_t i = 1; i < inputs.size(); i++) { - auto input = inputs[i]; - if (opt::IsNopNode(input)) { - CNodePtr cnode = input->cast(); - auto new_inputs = cnode->inputs(); - for (size_t j = 1; j < new_inputs.size(); j++) { - auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); - if (node == new_real_input.first) { - MS_LOG(INFO) << "Nop node find target op[" << (*begin)->DebugString() << "]"; - return begin; - } - } - } else { - auto real_input = AnfAlgo::VisitKernel(input, 0); - if (node == real_input.first) { - MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]"; - return begin; - } - } - } - ++begin; - } - return end; -} - -bool AscendStreamAssign::IsTaskSink() { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->enable_task_sink()) { - MS_LOG(INFO) << "Task sink mode is not enable"; - return false; - } else { - MS_LOG(INFO) << "Task sink mode is enable"; - return true; - } -} - -void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { - MS_EXCEPTION_IF_NULL(wait_active_stream_list); - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - uint32_t total_stream_num = resource_manager.get_cur_stream_num(); - if (total_stream_num == 0) { - MS_LOG(INFO) << "The total_common_stream_num is zero"; - return; - } - - // common stream:active first common stream - for (uint32_t i = 0; i < total_stream_num; i++) { - auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); - if (it == need_first_active_streams_.end()) { - MS_LOG(INFO) << "Wait common stream id = " << i; - wait_active_stream_list->push_back(i); - } - } -} - -bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { - MS_EXCEPTION_IF_NULL(apply_kernel); - return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; -} - -void AscendStreamAssign::GetHcomStreams(std::vector *streams) { - MS_EXCEPTION_IF_NULL(streams); - for (const auto &item : hcom_stream_map_) { - streams->emplace_back(item.first); - } -} - -void AscendStreamAssign::Reset() { - independent_stream_activated_ = false; - hcom_stream_activated_ = false; - independent_stream_map_.clear(); - hcom_stream_map_.clear(); - common_stream_map_.clear(); - processed_streams_.clear(); - 
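GetWaitStreams above computes the complement of need_first_active_streams_ over all allocated stream ids: every stream that is not activated first must wait to be activated later. A small sketch of that complement:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Mirrors GetWaitStreams: any stream id not on the first-active list
// goes onto the wait list.
std::vector<uint32_t> WaitStreams(uint32_t total_stream_num,
                                  const std::vector<uint32_t> &first_active) {
  std::vector<uint32_t> wait;
  for (uint32_t id = 0; id < total_stream_num; ++id) {
    if (std::find(first_active.begin(), first_active.end(), id) ==
        first_active.end()) {
      wait.push_back(id);
    }
  }
  return wait;
}
```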
need_first_active_streams_.clear(); - stream_groups_.clear(); - stream_relations_.clear(); - event_map_.clear(); -} - -// section 10 -bool AscendStreamAssign::IsVecExist(std::vector *group) { - auto group_size = group->size(); - if (group_size == 0) { - return false; - } - for (const auto &item : stream_groups_) { - if (item.size() < group->size()) { - continue; - } - - bool flag = true; - for (size_t i = 0; i < group_size; i++) { - if (item[i] != group->at(i)) { - flag = false; - break; - } - } - - if (flag) { - return true; - } else { - continue; - } - } - - return false; -} - -void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { - auto it = stream_relations_.find(start); - if (it == stream_relations_.end()) { - if (!IsVecExist(group)) { - stream_groups_.emplace_back(*group); - } else { - MS_LOG(WARNING) << "DFS should not print this log"; - } - return; - } - - vector active_streams = stream_relations_[start]; - - for (const auto &item : active_streams) { - group->emplace_back(item); - DFS(item, group); - group->pop_back(); - } -} - -void AscendStreamAssign::GetStreamRelations() { - for (const auto &start : need_first_active_streams_) { - vector group{start}; - DFS(start, &group); - } -} - -void AscendStreamAssign::FindStreamRelations(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto stream_num = resource_manager.get_cur_stream_num(); - if (stream_num <= 1) { - return; - } - - auto exe_orders = graph_ptr->execution_order(); - for (size_t i = 0; i < exe_orders.size(); i++) { - auto cur_cnode = exe_orders[i]; - auto name = AnfAlgo::GetCNodeName(cur_cnode); - if (name != kStreamSwitchOpName && name != kStreamActiveOpName) { - continue; - } - - // support:streamswitch is begin of the stream - if (name == kStreamSwitchOpName) { - GetStreamSwitchStreamRelation(cur_cnode); - } - - if (name == kStreamActiveOpName) { - GetStreamActiveStreamRelation(graph_ptr, i); - } - } -} - -void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) { - MS_EXCEPTION_IF_NULL(node_ptr); - auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr); - auto true_stream_id = AnfAlgo::GetNodeAttr(node_ptr, kAttrTrueBranchStream); - if (true_stream_id <= cur_stream_id) { - MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id - << " is greater than true branch stream id:" << true_stream_id; - } - auto it = stream_relations_.find(cur_stream_id); - if (it == stream_relations_.end()) { - stream_relations_[cur_stream_id] = {true_stream_id}; - } else { - auto iter = - std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id); - if (iter == stream_relations_[cur_stream_id].end()) { - stream_relations_[cur_stream_id].emplace_back(true_stream_id); - } - } -} - -void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index) { - StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index); - if (kind == kInvalid) { - MS_LOG(INFO) << "Invalid streamActive kind"; - return; - } - - auto orders = graph_ptr->execution_order(); - auto cur_cnode = orders[index]; - auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - auto active_list = AnfAlgo::GetNodeAttr>(cur_cnode, kAttrActiveStreamList); - if (kind == kHead) { - uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id); - if (active_current_node == kInvalidStreamId) { - MS_LOG(EXCEPTION) << "No stream to active streamactive stream"; - } - - for (const auto &item : active_list) { - if 
(item <= active_current_node) { - MS_LOG(WARNING) << "Actived stream is less than activing stream"; - continue; - } - auto it = - std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item); - if (it == stream_relations_[active_current_node].end()) { - stream_relations_[active_current_node].emplace_back(item); - } - } - } - - if (kind == kMiddle) { - for (const auto &stream : active_list) { - if (stream <= cur_stream_id) { - MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal"; - } else { - MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now"; - } - } - } - - if (kind == kTail) { - auto it = stream_relations_.find(cur_stream_id); - if (it == stream_relations_.end()) { - stream_relations_[cur_stream_id] = active_list; - } else { - for (const auto &stream : active_list) { - if (stream <= cur_stream_id) { - MS_LOG(WARNING) << "Actived stream is less than activing stream"; - continue; - } - auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream); - if (iter == stream_relations_[cur_stream_id].end()) { - stream_relations_[cur_stream_id].emplace_back(stream); - } - } - } - } -} - -StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull &graph_ptr, size_t index) { - auto exe_orders = graph_ptr->execution_order(); - if (index >= exe_orders.size()) { - MS_LOG(EXCEPTION) << "Invalid op index:" << index; - } - - auto cur_cnode = exe_orders[index]; - auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); - if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) { - MS_LOG(EXCEPTION) << "Current node name is not StreamActive"; - } - - if (index == 0) { - return kInvalid; - } - - if (index == exe_orders.size() - 1) { - return kInvalid; - } - - uint32_t pre_stream_id = UINT32_MAX; - uint32_t next_stream_id = UINT32_MAX; - int32_t start = SizeToInt(index) - 1; - for (int32_t i = start; i >= 0; i--) { - auto cnode = exe_orders[IntToSize(i)]; - auto name = AnfAlgo::GetCNodeName(cnode); - if (name == kSendOpName || name == kRecvOpName) { - continue; - } - - pre_stream_id = AnfAlgo::GetStreamId(cnode); - break; - } - - for (size_t i = index + 1; i < exe_orders.size(); i++) { - auto cnode = exe_orders[i]; - auto name = AnfAlgo::GetCNodeName(cnode); - if (name == kSendOpName || name == kRecvOpName) { - continue; - } - - next_stream_id = AnfAlgo::GetStreamId(cnode); - break; - } - - // pre_stream_id = UINT32_MAX:means no node active current StreamActive - // next_stream_id = UINT32_MAX:means current StreamActive active no node - if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) { - return kInvalid; - } - - if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) { - return kMiddle; - } - - if (cur_stream_id == pre_stream_id) { - return kTail; - } - - if (cur_stream_id == next_stream_id) { - return kHead; - } - - return kInvalid; -} - -uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) { - if (stream_relations_.empty()) { - return kInvalidStreamId; - } - - for (const auto &item : stream_relations_) { - auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id); - if (it != item.second.end()) { - return item.first; - } - } - - return kInvalidStreamId; -} - -void AscendStreamAssign::PrintStreamRelations() { - MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size(); - for (const auto &item : stream_relations_) { - 
MS_LOG(INFO) << "Stream:" << item.first; - for (const auto &stream : item.second) { - MS_LOG(INFO) << "--actived stream id:" << stream; - } - } -} - -void AscendStreamAssign::PrintStreamGroups() { - MS_LOG(INFO) << "Stream group size:" << stream_groups_.size(); - for (const auto &item : stream_groups_) { - MS_LOG(INFO) << "Group:"; - for (const auto &stream : item) { - MS_LOG(INFO) << "Stream id:" << stream; - } - } -} - -// section 11 -bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const { - size_t send_group = 0; - size_t recv_group = 0; - bool send_flag = true; - bool recv_flag = true; - for (size_t i = 0; i < stream_groups_.size(); i++) { - auto group = stream_groups_[i]; - if (send_flag) { - auto it = std::find(group.begin(), group.end(), send_stream_id); - if (it != group.end()) { - send_group = i; - send_flag = false; - } - } - - if (recv_flag) { - auto it = std::find(group.begin(), group.end(), recv_stream_id); - if (it != group.end()) { - recv_group = i; - recv_flag = false; - } - } - } - - if (!(send_flag || recv_flag)) { - return (send_group != recv_group); - } - - return false; -} - -void AscendStreamAssign::FindEventRelations(const NotNull &graph_ptr) { - AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - auto event_nums = resource_manager.get_cur_event_num(); - if (event_nums == 0) { - return; - } - auto exe_orders = graph_ptr->execution_order(); - // find all event info - for (size_t i = 0; i < exe_orders.size(); i++) { - auto cur_cnode = exe_orders[i]; - auto name = AnfAlgo::GetCNodeName(cur_cnode); - if (name == kSendOpName) { - event_map_[cur_cnode] = {}; - } - - if (name == kRecvOpName) { - auto recv_event_id = AnfAlgo::GetNodeAttr(cur_cnode, kAttrEventId); - for (auto &item : event_map_) { - auto send_event_id = AnfAlgo::GetNodeAttr(item.first, kAttrEventId); - if (recv_event_id == send_event_id) { - item.second = cur_cnode; - break; - } - } - } - } - - // delete useless event info - auto begin = event_map_.begin(); - while (begin != event_map_.end()) { - auto send_stream_id = AnfAlgo::GetStreamId(begin->first); - auto recv_stream_id = AnfAlgo::GetStreamId(begin->second); - bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id); - if (!flag) { - begin = event_map_.erase(begin); - } else { - begin++; - } - } - - MS_LOG(INFO) << "Satisfied event info"; - for (const auto &item : event_map_) { - MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr(item.first, kAttrEventId); - } -} - -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h deleted file mode 100644 index d268e0c975..0000000000 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ /dev/null @@ -1,185 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "runtime/base.h" -#include "runtime/rt_model.h" -#include "runtime/stream.h" -#include "session/kernel_graph.h" -#include "utils/contract.h" - -namespace mindspore { -namespace device { -namespace ascend { -using std::map; -using std::shared_ptr; -using std::unordered_map; -using std::unordered_set; -using std::vector; -const uint32_t kInvalidStreamId = UINT32_MAX; -const uint32_t kInvalidEventId = UINT32_MAX; -class AscendResourceMng { - public: - static AscendResourceMng &GetInstance() { - static AscendResourceMng instance; - return instance; - } - - void ResetResource() { - cur_stream_num_ = 0; - cur_event_num_ = 0; - } - uint32_t ApplyNewStream() { - if (!cur_stream_num_) { - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t cur_stream_id = cur_stream_num_; - cur_stream_num_++; - return cur_stream_id; - } - uint32_t ApplyNewEvent() { - if (!cur_event_num_) { - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - uint32_t cur_event_id = cur_event_num_; - cur_event_num_++; - return cur_event_id; - } - - void DeleteEvent() { - if (!cur_event_num_) { - MS_LOG(WARNING) << "total event num is 0, no event to delete"; - } else { - --cur_event_num_; - } - } - uint32_t get_cur_stream_num() { return cur_stream_num_; } - uint32_t GetCurAllocStreamId() { - if (!cur_stream_num_) { - MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get"; - } - return cur_stream_num_ - 1; - } - uint32_t get_cur_event_num() { return cur_event_num_; } - - private: - uint32_t cur_stream_num_{0}; - uint32_t cur_event_num_{0}; -}; - -enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail }; -class AscendStreamAssign { - public: - static AscendStreamAssign &GetInstance() { - static AscendStreamAssign instance; // Guaranteed to be destroyed. 
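// Note: in AscendResourceMng above, ApplyNewStream and ApplyNewEvent branch on a
// zero counter, but both branches execute the same post-increment logic, so each
// collapses to a single statement. A minimal standalone sketch of the same
// monotonic-id pattern (illustrative names, not part of the original API):
//
//   #include <cstdint>
//   class IdAllocator {
//    public:
//     uint32_t Apply() { return count_++; }             // hand out the next id
//     void Release() { if (count_ > 0) { --count_; } }  // guard against underflow
//     uint32_t current() const { return count_; }
//    private:
//     uint32_t count_{0};
//   };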
- return instance; - } - - AscendStreamAssign(const AscendStreamAssign &) = delete; - AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; - - void AssignStream(const NotNull &graph_ptr); - void GetHcomStreams(std::vector *streams); - void GetWaitStreams(vector *wait_active_stream_list); - CNodePtr CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); - const std::vector> &get_stream_group() const { return stream_groups_; } - const std::map &get_event_map() const { return event_map_; } - - private: - AscendStreamAssign() = default; - ~AscendStreamAssign() = default; - void Reset(); - void CheckResourceAssign(const NotNull &graph_ptr); - void CheckStreamAssign(const NotNull &graph_ptr); - void CheckEventAssign(const NotNull &graph_ptr); - void AssignAllNodesStream(const NotNull &graph_ptr); - void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); - void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); - void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); - void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); - void FindHcomParallelStreams(const NotNull &graph_ptr); - void InsertStreamActive(const NotNull &graph_ptr); - void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, - vector *orders); - void InsertEventForIndependentParallel(const NotNull &graph_ptr); - void InsertEventForHcomParallel(const NotNull &graph_ptr); - void InsertEventCommonDependHcom(const NotNull &graph_ptr); - void InsertEventHcomDependCommon(const NotNull &graph_ptr); - void InsertEventHcomDependHcom(const NotNull &graph_ptr); - void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, - uint32_t first_hcom_stream, uint32_t last_hcom_stream); - bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); - - void GetProcessedStream(const NotNull &graph_ptr); - void GetNeedActiveStreams(const NotNull &graph_ptr); - void ReorderIndependentOrders(const NotNull &graph_ptr); - - bool IsTaskSink(); - bool IsHcom(const CNodePtr &cur_cnode_ptr); - bool IsIndependentNode(const CNodePtr &node_ptr); - bool IsProcessedStream(uint32_t stream_id); - vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr &node); - void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); - - // function for memory resue - void GetStreamRelations(); - void DFS(uint32_t start, std::vector *group); - bool IsVecExist(std::vector *group); - void FindStreamRelations(const NotNull &graph_ptr); - void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); - void GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index); - StreamActiveKind GetStreamActiveKind(const NotNull &graph_ptr, size_t index); - uint32_t GetStreamByActivedStream(uint32_t actived_stream_id); - void PrintStreamRelations(); - void PrintStreamGroups(); - void FindEventRelations(const NotNull &graph_ptr); - bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const; - - bool independent_stream_activated_{false}; - bool hcom_stream_activated_{false}; - std::map independent_stream_map_{}; - std::map hcom_stream_map_{}; - std::map common_stream_map_{}; - std::set processed_streams_{}; - std::vector need_first_active_streams_{}; - - // attr for memory copy reuse - std::map> stream_relations_{}; - std::vector> stream_groups_{}; - std::map 
event_map_; - // new policy end -}; -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/device/ascend/dump/data_dumper.cc deleted file mode 100644 index 14f2c2a524..0000000000 --- a/mindspore/ccsrc/device/ascend/dump/data_dumper.cc +++ /dev/null @@ -1,282 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifdef ENABLE_DATA_DUMP -#include "device/ascend/dump/data_dumper.h" - -#include -#include -#include -#include "utility" -#include "session/anf_runtime_algorithm.h" -#include "runtime/mem.h" -#include "runtime/kernel.h" -#include "device/ascend/dump/ge_dump.h" -#include "proto/op_mapping_info.pb.h" -#include "utils/context/ms_context.h" -#include "debug/data_dump_parser.h" - -constexpr uint32_t kAicpuLoadFlag = 1; -constexpr uint32_t kAicpuUnloadFlag = 0; -constexpr uint32_t kTupleTaskId = 0; -constexpr uint32_t kTupleStreamId = 1; -constexpr uint32_t kTupleArgs = 2; -constexpr uint32_t kCurrentStepTensorIndex = 0; -constexpr uint32_t kCurrentEpochTensorIndex = 1; -constexpr uint32_t kStepsPerEpochTensorIndex = 2; - -namespace mindspore { -namespace device { -namespace ascend { -void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task); -void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task); -void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr); - -DataDumper::~DataDumper() { - ReleaseDevMem(&dev_load_mem_); - ReleaseDevMem(&dev_unload_mem_); -} - -void DataDumper::LoadDumpInfo() { - MS_LOG(INFO) << "[DataDump] LoadDumpInfo start"; - MS_EXCEPTION_IF_NULL(kernel_graph_); - aicpu::dump::OpMappingInfo dump_info; - SetOpMappingInfo(NOT_NULL(&dump_info)); - - auto kernels = kernel_graph_->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - if (!KernelNeedDump(kernel)) { - continue; - } - MS_LOG(INFO) << "[DataDump] LoadDumpInfo kernel:" << kernel->fullname_with_scope(); - dump_kernel_names_.emplace_back(kernel->fullname_with_scope()); - - aicpu::dump::Task task; - ConstructDumpTask(NOT_NULL(kernel), NOT_NULL(&task)); - MS_EXCEPTION_IF_NULL(dump_info.mutable_task()); - dump_info.mutable_task()->Add(std::move(task)); - } - RtLoadDumpData(dump_info, &dev_load_mem_); - load_flag_ = true; - MS_LOG(INFO) << "[DataDump] LoadDumpInfo end"; -} - -void DataDumper::SetOpMappingInfo(NotNull dump_info) const { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(kernel_graph_); - auto dump_path = DataDumpParser::GetInstance().GetDumpPath(); - if (!dump_path.has_value()) { - MS_LOG(EXCEPTION) << "Dump path invalid"; - } - auto device_id = context_ptr->device_id(); - dump_info->set_dump_path(dump_path.value() + "_" + std::to_string(device_id) + "/"); - MS_LOG(INFO) << "[DataDump] dump_path:" << dump_path.value(); - - 
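// The call above suffixes the configured dump root with the device id, so processes
// driving different devices write to distinct directories. A standalone sketch of
// that composition (the helper name is illustrative, not the original API):
//
//   #include <cstdint>
//   #include <string>
//   std::string PerDeviceDumpPath(const std::string &root, uint32_t device_id) {
//     return root + "_" + std::to_string(device_id) + "/";  // e.g. "/var/dump_0/"
//   }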
dump_info->set_model_name(DataDumpParser::GetInstance().net_name() + "_" + std::to_string(kernel_graph_->graph_id())); - dump_info->set_dump_step(std::to_string(DataDumpParser::GetInstance().dump_step())); - dump_info->set_model_id(kernel_graph_->graph_id()); - dump_info->set_flag(kAicpuLoadFlag); - - const auto &input_ctrl_tensors = kernel_graph_->input_ctrl_tensors(); - if (input_ctrl_tensors == nullptr || input_ctrl_tensors->size() < 3) { - MS_LOG(INFO) << "[DataDump] Not data sink mode, input_ctrl_tensor"; - return; - } - const auto ¤t_step_tensor = input_ctrl_tensors->at(kCurrentStepTensorIndex); - const auto &currnet_epoch_tensor = input_ctrl_tensors->at(kCurrentEpochTensorIndex); - const auto &steps_per_epoch_tensor = input_ctrl_tensors->at(kStepsPerEpochTensorIndex); - - MS_EXCEPTION_IF_NULL(current_step_tensor); - MS_EXCEPTION_IF_NULL(currnet_epoch_tensor); - MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor); - MS_EXCEPTION_IF_NULL(current_step_tensor->device_address()); - MS_EXCEPTION_IF_NULL(currnet_epoch_tensor->device_address()); - MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor->device_address()); - - void *current_step = current_step_tensor->device_address()->ptr_; - void *current_epoch = currnet_epoch_tensor->device_address()->ptr_; - void *steps_per_epoch = steps_per_epoch_tensor->device_address()->ptr_; - - if (current_epoch != nullptr && current_step != nullptr && steps_per_epoch != nullptr) { - dump_info->set_step_id_addr(reinterpret_cast(current_epoch)); - dump_info->set_loop_cond_addr(reinterpret_cast(current_step)); - dump_info->set_iterations_per_loop_addr(reinterpret_cast(steps_per_epoch)); - } else { - MS_LOG(INFO) << "Invalid ctrl tensor device address"; - } -} - -bool DataDumper::KernelNeedDump(const CNodePtr &kernel) const { - if (AnfAlgo::GetKernelType(kernel) != TBE_KERNEL && AnfAlgo::GetKernelType(kernel) != AICPU_KERNEL && - AnfAlgo::GetKernelType(kernel) != AKG_KERNEL) { - return false; - } - MS_EXCEPTION_IF_NULL(kernel); - // dump all kernel if mode is set 0 in data_dump.json - return DataDumpParser::GetInstance().NeedDump(kernel->fullname_with_scope()); -} - -void DataDumper::UnloadDumpInfo() { - if (!load_flag_) { - MS_LOG(WARNING) << "Load not success, no need to unload"; - return; - } - MS_EXCEPTION_IF_NULL(kernel_graph_); - MS_LOG(INFO) << "[DataDump] UnloadDumpInfo start. 
graphId:" << kernel_graph_->graph_id(); - - aicpu::dump::OpMappingInfo op_mapping_info; - op_mapping_info.set_model_id(kernel_graph_->graph_id()); - op_mapping_info.set_flag(kAicpuUnloadFlag); - - for (const auto &kernel_name : dump_kernel_names_) { - aicpu::dump::Task task; - auto iter = runtime_info_map_.find(kernel_name); - if (iter == runtime_info_map_.end()) { - MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; - } - MS_EXCEPTION_IF_NULL(iter->second); - auto task_id = std::get(*iter->second); - task.set_task_id(task_id); - MS_EXCEPTION_IF_NULL(op_mapping_info.mutable_task()); - op_mapping_info.mutable_task()->Add(std::move(task)); - } - - RtLoadDumpData(op_mapping_info, &dev_unload_mem_); -} - -void DataDumper::ReleaseDevMem(void **ptr) const { - if (ptr == nullptr) { - return; - } - if (*ptr != nullptr) { - rtError_t rt_error = rtFree(*ptr); - if (rt_error != RT_ERROR_NONE) { - MS_LOG(ERROR) << "[DataDump] Call rtFree failed, ret:" << rt_error; - } - *ptr = nullptr; - } -} - -void DataDumper::ConstructDumpTask(NotNull kernel, NotNull dump_task) const { - dump_task->set_end_graph(false); - auto iter = runtime_info_map_.find(kernel->fullname_with_scope()); - if (iter == runtime_info_map_.end()) { - MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; - } - MS_EXCEPTION_IF_NULL(iter->second); - auto task_id = std::get(*iter->second); - auto stream_id = std::get(*iter->second); - auto args = std::get(*iter->second); - MS_LOG(INFO) << "[DataDump] Get runtime info task_id:" << task_id << " stream_id:" << stream_id; - - dump_task->set_task_id(task_id); - dump_task->set_stream_id(stream_id); - MS_EXCEPTION_IF_NULL(dump_task->mutable_op()); - dump_task->mutable_op()->set_op_name(kernel->fullname_with_scope()); - dump_task->mutable_op()->set_op_type(AnfAlgo::GetCNodeName(kernel.get())); - - DumpKernelOutput(kernel, args, dump_task); - DumpKernelInput(kernel, args, dump_task); -} - -void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr) { - std::string proto_str; - size_t proto_size = dump_info.ByteSizeLong(); - bool ret = dump_info.SerializeToString(&proto_str); - if (!ret || proto_size == 0) { - MS_LOG(EXCEPTION) << "[DataDump] Protobuf SerializeToString failed, proto size %zu."; - } - - rtError_t rt_ret = rtMalloc(ptr, proto_size, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "[DataDump] Call rtMalloc failed"; - } - - if (ptr == nullptr) { - MS_LOG(ERROR) << "[DataDump] rtMalloc failed, ptr is nullptr"; - return; - } - rt_ret = rtMemcpy(*ptr, proto_size, proto_str.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "[DataDump] Call rtMemcpy failed"; - } - - MS_LOG(INFO) << "[DataDump] rtDatadumpInfoLoad start"; - rt_ret = rtDatadumpInfoLoad(*ptr, proto_size); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "[DataDump] Call rtDatadumpInfoLoad failed"; - } -} - -void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task) { - MS_LOG(INFO) << "[DataDump] DumpKernelOutput start. 
Kernel:" << kernel->fullname_with_scope(); - auto input_size = AnfAlgo::GetInputTensorNum(kernel); - auto output_size = AnfAlgo::GetOutputTensorNum(kernel); - uint64_t offset = sizeof(void *) * input_size; - for (size_t i = 0; i < output_size; ++i) { - auto data_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i); - - aicpu::dump::Output output; - output.set_data_type(GetGeDataType(data_type)); - output.set_format(GetGeFormat(output_format, output_shape.size())); - MS_EXCEPTION_IF_NULL(output.mutable_shape()); - for (auto dim : output_shape) { - output.mutable_shape()->add_dim(dim); - } - output.set_original_output_format(GetGeFormat(output_format, output_shape.size())); - output.set_address(static_cast(reinterpret_cast(args)) + offset); - MS_EXCEPTION_IF_NULL(task->mutable_output()); - task->mutable_output()->Add(std::move(output)); - offset += sizeof(void *); - } -} - -void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task) { - MS_LOG(INFO) << "[DataDump] DumpKernelInput start. Kernel:" << kernel->fullname_with_scope(); - auto input_size = AnfAlgo::GetInputTensorNum(kernel); - uint64_t offset = 0; - for (size_t i = 0; i < input_size; ++i) { - aicpu::dump::Input input; - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); - auto input_node = input_node_with_index.first; - auto input_index = input_node_with_index.second; - std::string output_format = AnfAlgo::GetOutputFormat(input_node, input_index); - auto output_type = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); - if (output_type == kTypeUnknown) { - MS_LOG(WARNING) << "[DataDump] It is not suggested to use a lonely weight parameter as the output of graph"; - output_type = AnfAlgo::GetOutputInferDataType(input_node, input_index); - } - auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); - - input.set_data_type(GetGeDataType(output_type)); - input.set_format(GetGeFormat(output_format, output_shape.size())); - MS_EXCEPTION_IF_NULL(input.mutable_shape()); - for (auto dim : output_shape) { - input.mutable_shape()->add_dim(dim); - } - input.set_address(static_cast(reinterpret_cast(args)) + offset); - MS_EXCEPTION_IF_NULL(task->mutable_input()); - task->mutable_input()->Add(std::move(input)); - offset += sizeof(void *); - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif diff --git a/mindspore/ccsrc/device/ascend/dump/data_dumper.h b/mindspore/ccsrc/device/ascend/dump/data_dumper.h deleted file mode 100644 index 65b01c61c4..0000000000 --- a/mindspore/ccsrc/device/ascend/dump/data_dumper.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ -#ifdef ENABLE_DATA_DUMP -#include -#include -#include -#include -#include -#include "session/kernel_graph.h" - -namespace aicpu { -namespace dump { -class OpMappingInfo; -class Task; -} // namespace dump -} // namespace aicpu -namespace mindspore { -namespace device { -namespace ascend { -// tuple(op_name, task_id, stream_id, args) -using RuntimeInfo = std::tuple; -class DataDumper { - public: - DataDumper(const session::KernelGraph *kernel_graph, - const std::map> &runtime_info_map) - : load_flag_(false), - dev_load_mem_(nullptr), - dev_unload_mem_(nullptr), - kernel_graph_(kernel_graph), - runtime_info_map_(runtime_info_map) {} - ~DataDumper(); - void LoadDumpInfo(); - - void UnloadDumpInfo(); - - private: - void ReleaseDevMem(void **ptr) const; - bool KernelNeedDump(const CNodePtr &kernel) const; - void SetOpMappingInfo(NotNull dump_info) const; - void ConstructDumpTask(NotNull kernel, NotNull dump_task) const; - - bool load_flag_; - void *dev_load_mem_; - void *dev_unload_mem_; - std::vector dump_kernel_names_; - const session::KernelGraph *kernel_graph_; - std::map> runtime_info_map_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ diff --git a/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc deleted file mode 100644 index bd0b436344..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_build_ascend.cc +++ /dev/null @@ -1,286 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/ascend/kernel_build_ascend.h" - -#include -#include -#include -#include - -#include "device/ascend/kernel_select_ascend.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_parallel_build.h" -#include "kernel/akg/ascend/akg_ascend_kernel_build.h" -#include "kernel/aicpu/aicpu_kernel_build.h" -#include "kernel/hccl/hccl_kernel_build.h" -#include "kernel/rts/rt_kernel_build.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "./common.h" - -namespace mindspore { -namespace device { -namespace ascend { -using mindspore::kernel::tbe::TbeUtils; -using std::make_shared; -static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { - kernel::KernelModPtr kernel_mod_ptr = nullptr; - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::AICPU_KERNEL: { - kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); - break; - } - case KernelType::RT_KERNEL: { - kernel_mod_ptr = kernel::RtOpBuild(anf_node); - break; - } - case KernelType::HCCL_KERNEL: { - kernel_mod_ptr = kernel::HcclOpBuild(anf_node); - break; - } - default: { - MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type; - } - } - return kernel_mod_ptr; -} - -static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - std::vector tbe_nodes; - for (const auto &anf_node : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - continue; - } - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::TBE_KERNEL: { - if (AnfAlgo::GetKernelMod(anf_node) == nullptr && - AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) { - tbe_nodes.push_back(anf_node); - } - break; - } - default: { - break; - } - } - } - bool ret = kernel::TbeOpParallelPreBuild(tbe_nodes); - return ret; -} - -static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - std::vector tbe_nodes; - std::vector akg_nodes; - std::vector other_nodes; - for (const auto &anf_node : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - continue; - } - KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); - switch (kernel_type) { - case KernelType::TBE_KERNEL: { - if (AnfAlgo::GetKernelMod(anf_node) == nullptr) { - tbe_nodes.push_back(anf_node); - } - break; - } - case KernelType::AKG_KERNEL: { - akg_nodes.push_back(anf_node); - break; - } - default: { - other_nodes.push_back(anf_node); - break; - } - } - } - bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes); - bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes); - auto bin_map = kernel::tbe::KernelMeta::GetInstance(); - (void)bin_map->ReadIndex(kernel::kCceKernelMeta); - for (const auto &anf_node : other_nodes) { - kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - } - return tbe_ret && akg_ret; -} - -static std::vector CalCleanZerosSize(const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(pre_node); - auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); - 
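// The clean sizes computed below pad each buffer and round it up to the memory
// alignment unit via (size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize.
// A standalone sketch of that round-up; 512 is only an assumed stand-in for
// kMemAlignSize, whose real value is defined elsewhere in the runtime:
//
//   #include <cstddef>
//   constexpr std::size_t kAlign = 512;  // assumed placeholder for kMemAlignSize
//   constexpr std::size_t AlignedCleanSize(std::size_t raw) {
//     return (raw + kAlign + 31) / kAlign * kAlign;  // pad by kAlign + 31, round up
//   }
//   static_assert(AlignedCleanSize(1) == kAlign, "1 + 512 + 31 = 544 rounds down to 512");
//   static_assert(AlignedCleanSize(kAlign) == 2 * kAlign, "512 + 543 rounds to 1024");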
MS_EXCEPTION_IF_NULL(kernel_mod);
- std::vector clean_size_list;
- // clean output
- if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
- auto output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs);
- auto output_men_size = kernel_mod->GetOutputSizeList();
- for (auto index : output_indexs) {
- auto clean_item = (output_men_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
- clean_size_list.emplace_back(clean_item);
- }
- }
- // clean workspace
- if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
- auto workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs);
- auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
- for (const auto &index : workspace_indexs) {
- auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
- clean_size_list.emplace_back(clean_item);
- }
- }
- MS_LOG(INFO) << "Clean size list size:" << clean_size_list.size() << ", pre_node:" << pre_node->fullname_with_scope();
- return clean_size_list;
-}
-
-static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
- const mindspore::CNodePtr &pre_node, std::vector *new_nodes) {
- MS_EXCEPTION_IF_NULL(kernel_graph);
- MS_EXCEPTION_IF_NULL(pre_node);
- MS_EXCEPTION_IF_NULL(new_nodes);
- auto clear_zero_prim = std::make_shared(kAtomicAddrCleanOpName);
- MS_EXCEPTION_IF_NULL(clear_zero_prim);
- auto new_value_node = NewValueNode(clear_zero_prim);
- MS_EXCEPTION_IF_NULL(new_value_node);
- std::vector inputs = {new_value_node};
- inputs.push_back(pre_node);
- CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
- MS_EXCEPTION_IF_NULL(clear_zero);
- AbstractBasePtr abstract = std::make_shared();
- MS_EXCEPTION_IF_NULL(abstract);
- clear_zero->set_abstract(abstract);
- auto builder = std::make_shared();
- builder->SetKernelType(KernelType::TBE_KERNEL);
- AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
- auto clean_size = CalCleanZerosSize(pre_node);
- AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero);
- AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
- new_nodes->push_back(clear_zero);
-}
-
-static bool IsAtomicNode(const CNodePtr &kernel_node) {
- MS_EXCEPTION_IF_NULL(kernel_node);
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
- MS_EXCEPTION_IF_NULL(kernel_mod);
- auto parameters_indexs = kernel_mod->GenParameters();
- if (parameters_indexs.empty()) {
- return false;
- }
- size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
- size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
- size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
- size_t param_num = parameters_indexs.size();
- size_t total_num = input_num + workspace_num + output_num;
- MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num;
- size_t pad_index = param_num;
- for (; pad_index < total_num; ++pad_index) {
- parameters_indexs.emplace_back(0);
- }
- // process input
- for (size_t j = 0; j < input_num; ++j) {
- if (parameters_indexs.at(j) == 1) {
- MS_LOG(EXCEPTION) << "Atomic addr clean doesn't support cleaning the input address, input index: " << j;
- }
- }
- // process output
- std::vector output_indexs = {};
- for (size_t i = 0; i < output_num; ++i) {
- auto param_output = parameters_indexs.at(input_num + workspace_num + i);
- if (param_output == 1) {
- output_indexs.emplace_back(i);
- MS_LOG(INFO) << "Atomic clear
output index: " << i; - } - } - if (!output_indexs.empty()) { - AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); - } - // process workspace - std::vector workspace_indexs = {}; - for (size_t k = 0; k < workspace_num; ++k) { - auto param_workspace = parameters_indexs.at(input_num + k); - if (param_workspace == 1) { - workspace_indexs.emplace_back(k); - MS_LOG(INFO) << "Atomic clear workspace index: " << k; - } - } - if (!workspace_indexs.empty()) { - AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node); - } - return !(workspace_indexs.empty() && output_indexs.empty()); -} - -bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - bool ret = device::ascend::KernelPreBuildParallelCompile(kernel_graph_ptr); - return ret; -} - -bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - TbeUtils::LoadCache(); - bool ret; - ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr); - return ret; -} - -void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector new_nodes; - for (const auto &anf_node : kernel_graph->execution_order()) { - std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); - if (apply_function_name == prim::kPrimMaxPoolGrad->name() && - AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { - auto clear_zero_prim = std::make_shared(kClearZeroOpName); - MS_EXCEPTION_IF_NULL(clear_zero_prim); - auto new_value_node = NewValueNode(clear_zero_prim); - MS_EXCEPTION_IF_NULL(new_value_node); - std::vector inputs = {new_value_node}; - inputs.push_back(anf_node); - CNodePtr clear_zero = kernel_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(clear_zero); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - clear_zero->set_kernel_info(kernel_info); - AbstractBasePtr abstract = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract); - AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector({"x"})), clear_zero); - SelectKernelInfo(clear_zero); - // set the distinction label of clear same with anf - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); - new_nodes.push_back(clear_zero); - } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { - if (IsAtomicNode(anf_node)) { - AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); - } - } - new_nodes.push_back(anf_node); - } - kernel_graph->set_execution_order(new_nodes); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/kernel_build_ascend.h b/mindspore/ccsrc/device/ascend/kernel_build_ascend.h deleted file mode 100644 index d987b6ce7a..0000000000 --- a/mindspore/ccsrc/device/ascend/kernel_build_ascend.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
-#define MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
-
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace device {
-namespace ascend {
-/**
- * @brief kernel pre build for ascend.
- */
-bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
-/**
- * @brief kernel build for ascend.
- */
-bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
-/**
- * @brief preprocess of kernel build for ascend, e.g. inserting clear_zero node for maxpool, bn.
- * Must apply these changes just before kernel build, after all other optimizations on the AnfGraph.
- */
-void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph);
-} // namespace ascend
-} // namespace device
-} // namespace mindspore
-
-#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
deleted file mode 100644
index cde79a18f7..0000000000
--- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+++ /dev/null
@@ -1,584 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "device/ascend/kernel_select_ascend.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "debug/anf_ir_dump.h" -#include "operator/ops.h" -#include "ir/func_graph.h" -#include "utils/context/ms_context.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/common_utils.h" -#include "kernel/kernel_query.h" -#include "kernel/oplib/oplib.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace { -const float kWegihtBaseScore = 1; -const float kFeatureMapBaseScore = 10; -constexpr auto kPriChoosenFormat = "pri_format"; -enum MatchCountPriority : int { - MATCH_COUNT_PRIORITY_BEGIN = 0, - MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, - MATCH_FORMAT_COUNT, - MATCH_SPECIAL_FORMAT_COUNT, - MATCH_DEFAULT_FORMAT_COUNT, - MATCH_OUTPUT_DTYPE_COUNT, - MATCH_COUNT_PRIORITY_END -}; - -const int kUnSupportMixedDataTypeIndex = -1; - -bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { - MS_EXCEPTION_IF_NULL(cnode); - // Check input data type - for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { - TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { - return false; - } - } - // Check output data type - for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { - if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { - return false; - } - } - return true; -} - -string GetPriorityMatchFormat(const CNodePtr &cnode) { - string priority_matched_format = kOpFormat_NC1HWC0; - bool is_init = false; - bool need_change_nd = false; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { - auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); - if (AnfAlgo::IsFeatureMapInput(cnode, index) && - kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { - priority_matched_format = !is_init ? 
pre_output_format : priority_matched_format;
- is_init = true;
- }
- // the feature map has two or more special formats;
- if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
- priority_matched_format = kOpFormat_DEFAULT;
- }
- auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
- need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
- }
- if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
- priority_matched_format = kOpFormat_DEFAULT;
- }
- AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
- return priority_matched_format;
-}
-/**
- * Compare two vectors by priority and select the better one, like comparing two numbers: compare the
- * highest-priority position first; if equal, move on to the next position.
- * example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
- */
-bool PriorityChooseItem(const std::vector &cur_item, std::vector *best_item) {
- MS_EXCEPTION_IF_NULL(best_item);
- if (cur_item.size() != best_item->size()) {
- MS_LOG(ERROR) << "Item sizes should be the same!";
- return false;
- }
- // Update the best_item by comparing the cur_item and best_item
- for (size_t i = 0; i < cur_item.size(); i++) {
- if (cur_item[i] > best_item->at(i)) {
- *best_item = cur_item;
- return true;
- } else if (cur_item[i] == best_item->at(i)) {
- continue;
- } else {
- return false;
- }
- }
- return false;
-}
-
-void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr &kernel_node,
- std::vector *const cur_kernelinfo_match_counts) {
- MS_EXCEPTION_IF_NULL(kernel_node);
- MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
- if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
- MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
- }
- auto pri_match_format = GetPriorityMatchFormat(kernel_node);
- for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
- auto input_anf_node = kernel_node->input(input_index + 1);
- // we do not take ValueNode into consideration in graph kernel.
- if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
- if (input_anf_node->isa() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
- continue;
- }
- }
- auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
- if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
- (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
- }
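// PriorityChooseItem above is a lexicographic "greater than" update over the
// match-count vectors: slots are compared from the highest-priority position down,
// and the first strict difference decides; ties keep the incumbent. A standalone
// sketch of the same rule (illustrative name, not the original API):
//
//   #include <cstddef>
//   #include <vector>
//   bool UpdateBest(const std::vector<int> &cur, std::vector<int> *best) {
//     if (cur.size() != best->size()) { return false; }
//     for (std::size_t i = 0; i < cur.size(); ++i) {
//       if (cur[i] > (*best)[i]) { *best = cur; return true; }  // strictly better slot wins
//       if (cur[i] < (*best)[i]) { return false; }              // strictly worse slot loses
//     }
//     return false;  // equal on every slot: keep the incumbent
//   }
//
// Under this rule {3,1,1,1} beats {2,2,2,2}, matching the example documented above.
- // we match output fixed precision first.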
- auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index); - if (prev_device_type == kTypeUnknown) { - prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); - } - if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) { - (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score; - } - if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) { - (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score; - } - if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) { - (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score; - } - } - - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - // cal count of same output dtype between abstract and kernel info - if (kernel_build_info.GetOutputDeviceType(output_index) == - AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { - (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1; - } - } -} - -void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *support_index) { - MS_EXCEPTION_IF_NULL(support_index); - int index = kUnSupportMixedDataTypeIndex; - switch (data_type) { - case kNumberTypeFloat16: - index = 0; - break; - case kNumberTypeFloat32: - case kNumberTypeFloat: - index = 1; - break; - default: - break; - } - support_index->push_back(index); -} - -void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetInputDeviceType(input_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetOutputDeviceType(output_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(cur_input); - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(input_origin_type); -} - -void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); - AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(output_origin_type); -} - -void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> 
*kernel_match_datatype_idx) {
- if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
- MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
- << node_mix_precision_datatype.size();
- }
- MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
- if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
- MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
- << kernel_support_datatypes.size();
- }
-}
-
-bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index,
- const std::vector &node_mix_precision_datatype,
- const std::map> &kernel_support_datatypes,
- std::map> *kernel_match_datatype_idx) {
- MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
- CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
- kernel_match_datatype_idx);
- for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
- if (node_mix_precision_datatype[i] == kTypeUnknown) {
- continue;
- }
- auto iter = kernel_match_datatype_idx->begin();
- while (iter != kernel_match_datatype_idx->end()) {
- if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
- auto find_iter = kernel_support_datatypes.find(iter->first);
- if (find_iter == kernel_support_datatypes.end()) {
- MS_LOG(EXCEPTION) << "Kernel datatype index " << iter->first << " cannot be found";
- }
- if (i >= find_iter->second.size()) {
- MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
- }
- if (node_mix_precision_datatype[i] != find_iter->second[i]) {
- iter = kernel_match_datatype_idx->erase(iter);
- } else {
- ++iter;
- }
- continue;
- }
- auto datatype_indexes = iter->second;
- if (i >= datatype_indexes.size()) {
- MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size();
- }
- if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
- iter = kernel_match_datatype_idx->erase(iter);
- } else {
- ++iter;
- }
- }
- }
- return !kernel_match_datatype_idx->empty();
-}
-
-bool CanDataTypeReduce(const std::vector &datatype_indexes, int check_index,
- const std::vector &node_mix_precision_datatype_index) {
- auto check_index_tmp = IntToSize(check_index);
- if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
- return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
- datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
- }
- MS_LOG(EXCEPTION) << "Check index " << check_index << " is out of range";
-}
-
-bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index,
- const std::vector &node_mix_precision_datatype,
- const std::map> &kernel_support_datatypes,
- std::map> *kernel_match_datatype_idx) {
- MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
- CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
- kernel_match_datatype_idx);
- for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
- if (node_mix_precision_datatype[i] == kTypeUnknown) {
- continue;
- }
- auto iter = kernel_match_datatype_idx->begin();
- while (iter != kernel_match_datatype_idx->end()) {
- if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
- auto find_iter = kernel_support_datatypes.find(iter->first);
- if (find_iter == kernel_support_datatypes.end()) {
- MS_LOG(EXCEPTION) << "Kernel datatype index " << iter->first << " cannot be found";
- }
- if (i >= find_iter->second.size()) {
- MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
- }
- if (node_mix_precision_datatype[i] != find_iter->second[i]) {
- iter = kernel_match_datatype_idx->erase(iter);
- } else {
- ++iter;
- }
- continue;
- }
- auto datatype_indexes = iter->second;
- if (i >= datatype_indexes.size()) {
- MS_LOG(EXCEPTION) << "Index " << i << " > kernel datatype indexes size " << datatype_indexes.size();
- }
- if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
- iter = kernel_match_datatype_idx->erase(iter);
- } else {
- ++iter;
- }
- }
- }
- return !kernel_match_datatype_idx->empty();
-}
-
-void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
- std::vector *support_indexes, std::vector *node_mix_precision_datatype,
- std::vector *support_datatypes,
- std::vector *node_mix_precision_datatype_index) {
- MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
- bool add_node_datatype_flag = false;
- if (node_mix_precision_datatype->empty()) {
- add_node_datatype_flag = true;
- }
- for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
- AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes);
- if (add_node_datatype_flag) {
- AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
- }
- }
- // Check output data type
- for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
- AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes);
- if (add_node_datatype_flag) {
- AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
- }
- }
-}
-
-void PrecisionReduce(const std::vector &node_mix_precision_datatype_index,
- const std::vector &node_mix_precision_datatype,
- const std::map> &kernel_support_datatype,
- std::map> *kernel_match_datatype_idx, bool *precision_reduce) {
- MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
- auto context_ptr = MsContext::GetInstance();
- MS_EXCEPTION_IF_NULL(context_ptr);
- MS_EXCEPTION_IF_NULL(precision_reduce);
- std::map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
- // raise precision
- bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
- kernel_support_datatype, kernel_match_datatype_idx);
- if (selected_ret) {
- *precision_reduce = false;
- return;
- }
- if (context_ptr->enable_reduce_precision()) {
- selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
- kernel_support_datatype, &kernel_match_datatype_idx_copy);
- }
- if (selected_ret) {
- *precision_reduce = true;
- *kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
- }
-}
-
-void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
- const std::shared_ptr &selected_kernel_build_info,
- bool precision_reduce) {
- MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
- MS_EXCEPTION_IF_NULL(cnode);
- std::ostringstream buffer;
- buffer << cnode->DebugString();
- if (precision_reduce) {
- buffer << " Reduce precision, node datatype: \n";
- } else {
-
buffer << " Raise precision, node datatype: \n"; - } - PrintInputAndOutputInferType(buffer, cnode); - buffer << ", select kernel:" << selected_kernel_build_info->ToString(); - MS_LOG(INFO) << buffer.str(); -} - -std::shared_ptr ChooseMatchedKernelInfo( - const CNodePtr &kernel_node, const std::vector> &kernel_info_list) { - if (kernel_info_list.empty()) { - return nullptr; - } - std::vector most_match_counts = {-1, -1, -1, -1, -1}; - size_t selected_index = 0; - for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector cur_kernel_info_match_counts = {0, 0, 0, 0, 0}; - auto kernel_info_ptr = kernel_info_list[info_index]; - MS_EXCEPTION_IF_NULL(kernel_info_ptr); - UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts); - // Currently the selection policy is the match format count first, and then is datatype counts. - if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) { - selected_index = SizeToInt(info_index); - } - } - return kernel_info_list[selected_index]; -} - -std::vector> FilteredKernelInfoByDtype( - const CNodePtr &cnode, const std::vector> &kernel_info_list) { - std::vector> result; - for (const auto &kernel_build_info : kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_build_info); - if (!MatchInferOutputDataType(cnode, *kernel_build_info)) { - continue; - } - result.push_back(kernel_build_info); - } - return result; -} - -std::vector> FilterRaisedOrReducePrecisionMatchedKernelInfo( - const CNodePtr &cnode, const std::vector> &kernel_info_list, - bool *precision_reduce) { - std::vector> filtered_kernel_info_list; - std::map> kernel_match_datatype_idx; - std::map> kernel_support_datatype; - std::vector node_mix_precision_datatype_index; - std::vector node_mix_precision_datatype; - for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector support_indexes; - std::vector support_datatypes; - MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); - AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, - &support_datatypes, &node_mix_precision_datatype_index); - kernel_match_datatype_idx[info_index] = support_indexes; - kernel_support_datatype[info_index] = support_datatypes; - } - PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, - &kernel_match_datatype_idx, precision_reduce); - std::transform( - kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), - [&](const std::pair> &matched_idx) -> std::shared_ptr { - return kernel_info_list[matched_idx.first]; - }); - return filtered_kernel_info_list; -} -} // namespace - -void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(input_kernel_node); - auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); - MS_EXCEPTION_IF_NULL(input_with_index.first); - auto real_input_node = input_with_index.first; - if (real_input_node->isa()) { - continue; - } - if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { - continue; - } - auto builder = std::make_shared(); - if (IsValueNode(input_kernel_node) && - 
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - continue; - } - // we set special device info of a input tensor. - bool is_ref = false; - auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); - if (op_info != nullptr) { - is_ref = op_info->is_ref(); - } - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode && - AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { - continue; - } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { - std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; - builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; - builder->SetOutputsDeviceType(output_type); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); - } - } -} - -KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, - const std::vector> &kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - KernelSelectStatus select_status = kNoMatched; - bool precision_reduce = false; - std::shared_ptr selected_kernel_info = nullptr; - // Matched kernel info - // Filter kernel info matched with me infered type - auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); - if (!filtered_kernel_info_list.empty()) { - selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); - select_status = kStatusAllMatched; - } else { - // selected kernel info using raised precision or reduce precision - filtered_kernel_info_list = - FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce); - selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); - if (selected_kernel_info == nullptr) { - return select_status; - } else { - PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce); - select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision; - } - } - // Set kernel info to the anfnode - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get()); - // Set format and data type for input tensor. 
-  SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
-  return select_status;
-}
-
-KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  if (AnfAlgo::IsGraphKernel(kernel_node)) {
-    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
-    MS_EXCEPTION_IF_NULL(func_graph);
-    SelectGraphKernelInfo(kernel_node, func_graph);
-    return kStatusAllMatched;
-  }
-  kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
-  auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
-  // If AiCore did not find valid kernel info, query the AiCpu kernel info list instead
-  if (select_status == kNoMatched) {
-    MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
-                    << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
-    kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
-    select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
-    AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
-  }
-  // The kernel info was found in neither the aicpu kernel list nor the aicore kernel list
-  if (select_status == kNoMatched) {
-    std::ostringstream buffer;
-    PrintInputAndOutputInferType(buffer, kernel_node);
-    MS_LOG(WARNING) << ">>> Candidates kernel info list:";
-    for (size_t index = 0; index < kernel_info_list.size(); ++index) {
-      MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString();
-    }
-    for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
-      MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
-                      << "] :" << aicpu_kernel_info_list[index]->ToString();
-    }
-    if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
-      auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
-      AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
-      // Set format and data type for input tensor.
-      SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
-    } else {
-      MS_LOG(WARNING) << " <<<";
-      MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
-                              << "] cannot find valid kernel info, unsupported type: " << buffer.str()
-                              << ", please refer to the supported dtypes in candidates kernel info list";
-    }
-  }
-  return select_status;
-}
-}  // namespace ascend
-}  // namespace device
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/device/ascend/kernel_select_ascend.h
deleted file mode 100644
index 7b7a7b9fb9..0000000000
--- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
-#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
-#include "ir/anf.h"
-#include "kernel/kernel_build_info.h"
-namespace mindspore {
-namespace device {
-namespace ascend {
-enum KernelSelectStatus {
-  kNoMatched = -1,
-  kStatusAllMatched = 0,
-  kStatusReducePrecision = 1,
-  kStatusRaisePrecision = 2,
-};
-KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
-                                    KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
-void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
-void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
-}  // namespace ascend
-}  // namespace device
-}  // namespace mindspore
-
-#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
diff --git a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc
deleted file mode 100644
index db31460d31..0000000000
--- a/mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc
+++ /dev/null
@@ -1,531 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "device/ascend/kernel_select_ascend.h"
-#include "session/anf_runtime_algorithm.h"
-#include "device/kernel_info.h"
-#include "ir/func_graph.h"
-#include "kernel/common_utils.h"
-#include "kernel/kernel_query.h"
-#include "kernel/kernel_build_info.h"
-
-namespace mindspore {
-namespace device {
-namespace ascend {
-namespace {
-// sort formats according to the number of occurrences.
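A note on the comparator deleted below: its +1/-1 arithmetic only changes the result when both formats have the same occurrence count, so its effect is simply "more occurrences first, and kOpFormat_DEFAULT wins exact ties". An equivalent, more direct strict-weak-ordering form (hypothetical rewrite, same pair type as cmp_format_num):

    bool FormatCountGreater(const std::pair<std::string, size_t> &a,
                            const std::pair<std::string, size_t> &b) {
      if (a.second != b.second) {
        return a.second > b.second;  // more occurrences first
      }
      return a.first == kOpFormat_DEFAULT && b.first != kOpFormat_DEFAULT;  // tie: prefer DEFAULT
    }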
-bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
-  if (a.second != b.second) {
-    return a.second > b.second;
-  } else if (a.first == kOpFormat_DEFAULT) {
-    return a.second + 1 > b.second;
-  } else if (b.first == kOpFormat_DEFAULT) {
-    return a.second > b.second + 1;
-  }
-  return a.second > b.second;
-}
-
-TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
-  MS_EXCEPTION_IF_NULL(primitive);
-
-  TypeId except_type = kTypeUnknown;
-  if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
-    auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
-    if (strExceptDtype == "float16") {
-      except_type = kNumberTypeFloat16;
-    } else if (strExceptDtype == "float32") {
-      except_type = kNumberTypeFloat32;
-    } else {
-      MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << strExceptDtype;
-    }
-  }
-
-  return except_type;
-}
-}  // namespace
-
-void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t input_index = 0; input_index < input_num; ++input_index) {
-    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
-    if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
-      continue;
-    }
-    // reset format and dtype.
-    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-    builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
-    builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
-    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
-  }
-}
-
-void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
-  for (size_t i = 0; i < node_list.size(); ++i) {
-    // select nodes in subgraph.
-    auto anf_node = node_list[i];
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto cnode = anf_node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    auto fix_precision_type = GetPrimitivePrecision(cnode);
-    if (fix_precision_type != kTypeUnknown) {
-      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-      kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
-
-      for (size_t index = 0; index < kernel_info_list.size(); ++index)
-        // only match the first input
-        if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
-            kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
-            AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
-          auto selected_kernel_info_ptr = kernel_info_list[index];
-          ResetKernelBuildInfo(cnode);
-          AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
-          SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
-          break;
-        }
-    }
-  }
-}
-
-bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
-  for (size_t i = 1; i <= shape.size(); ++i) {
-    if (i > 2) {
-      break;
-    }
-    if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
-      return false;
-    }
-  }
-  return true;
-}
-
-std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
-  std::vector<int> frac_nz_axis = axis;
-  auto shape_len = ori_shape.size();
-  for (size_t i = 0; i < axis.size(); ++i) {
-    auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
-    if (axis_idx == shape_len - 1) {
-      frac_nz_axis[i] = axis_idx - 1;
-      frac_nz_axis.push_back(axis_idx + 2);
-    } else if (axis_idx == shape_len - 2) {
-      frac_nz_axis[i] = axis_idx + 1;
-      frac_nz_axis.push_back(axis_idx + 2);
-    } else {
-      frac_nz_axis[i] = axis_idx;
-    }
-  }
-  return frac_nz_axis;
-}
-
-std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis,
-                                          bool keep_dims) {
-  std::vector<size_t> result;
-  std::set<size_t> positive_idx;
-  for (const auto &a : axis) {
-    positive_idx.insert(a >= 0 ? a : ori_shape.size() + a);
-  }
-  for (size_t i = 0; i < ori_shape.size(); ++i) {
-    if (positive_idx.count(i) == 0) {
-      result.push_back(ori_shape[i]);
-    } else if (keep_dims) {
-      result.push_back(1);
-    }
-  }
-  return result;
-}
-
-void UpdateFracNZReduceOp(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
-  if (input_format == kOpFormat_FRAC_NZ) {
-    // Clone primitive to modify it
-    auto prim = GetCNodePrimitive(cnode);
-    auto new_prim = std::make_shared<Primitive>(*prim);
-    auto new_prim_node = NewValueNode(new_prim);
-    cnode->set_input(0, new_prim_node);
-
-    auto axis_value = new_prim->GetAttr(kAttrAxis);
-    std::vector<int> default_axis;
-    if (axis_value->isa<ValueList>()) {
-      auto value_list = dyn_cast<ValueList>(axis_value);
-      for (const auto &item : value_list->value()) {
-        if (item->isa<Int32Imm>()) {
-          default_axis.push_back(GetValue<int>(item));
-        }
-      }
-    } else if (axis_value->isa<ValueTuple>()) {
-      auto value_tuple = dyn_cast<ValueTuple>(axis_value);
-      for (const auto &item : value_tuple->value()) {
-        if (item->isa<Int32Imm>()) {
-          default_axis.push_back(GetValue<int>(item));
-        }
-      }
-    } else {
-      MS_LOG(ERROR) << "Axis attr type is not correct!";
-    }
-    auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
-    std::vector<int> frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis);
-    AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<std::vector<int>>(frac_nz_axis), cnode);
-    auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
-    if (output_shape.size() == 1) {
-      AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue(true), cnode);
-    }
-  }
-}
-
-void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(default_format);
-  MS_EXCEPTION_IF_NULL(use_same_format);
-  std::unordered_map<std::string, size_t> all_input_formats;
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t i = 0; i < input_num; ++i) {
-    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    if (!input_kernel_node->isa<Parameter>()) {
-      ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)];
-      continue;
-    }
-    auto para = input_kernel_node->cast<ParameterPtr>();
-    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
-      ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)];
-      continue;
-    }
-    *use_same_format = false;
-  }
-
-  if (all_input_formats.empty()) {
-    // all inputs are parameter.
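A worked example for DefaultToFracNZAxis above, assuming the usual FRACTAL_NZ tiling that maps a 2-D [h, w] tensor to [w1, h1, h0, w0]:

    // axis -1 (w): axis_idx = 1 == shape_len - 1, so {-1} -> {0, 3} (reduce w1 and w0)
    // axis -2 (h): axis_idx = 0 == shape_len - 2, so {-2} -> {1, 2} (reduce h1 and h0)
    std::vector<int> nz_axis = DefaultToFracNZAxis({32, 64}, {-1});  // == {0, 3}

GetDefaultFormat resumes below: when no producer format is available to vote with, the default falls back to NC1HWC0.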
-    *default_format = kOpFormat_NC1HWC0;
-  } else {
-    std::vector<std::pair<std::string, size_t>> pairs;
-    for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
-      pairs.push_back(std::make_pair(iter->first, iter->second));
-    }
-
-    std::sort(pairs.begin(), pairs.end(), cmp_format_num);
-    *default_format = pairs.begin()->first;
-  }
-
-  for (size_t i = 0; i < input_num; ++i) {
-    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    if (!input_kernel_node->isa<Parameter>() ||
-        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
-      continue;
-    }
-    auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
-    if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) {
-      *default_format = kOpFormat_DEFAULT;
-      *use_same_format = true;
-      break;
-    }
-  }
-}
-
-void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
-                            const std::string &default_format, bool use_same_format,
-                            std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
-  MS_EXCEPTION_IF_NULL(graph_input_format);
-  MS_EXCEPTION_IF_NULL(graph_input_type);
-  // We set the same format for all inputs of the graph kernel subgraph, and process this later.
-  // We set the dtype of each subgraph input to its inferred dtype.
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t i = 0; i < input_num; ++i) {
-    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    if (use_same_format) {
-      bool can_convert = true;
-      if (default_format == kOpFormat_FRAC_NZ) {
-        auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-        if (!CanConvertDefaultShapeToNZ(infer_shape)) {
-          MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
-          can_convert = false;
-        }
-      }
-      if (can_convert) {
-        graph_input_format->push_back(default_format);
-      } else {
-        graph_input_format->push_back(kOpFormat_DEFAULT);
-      }
-      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
-      continue;
-    }
-
-    if (!input_kernel_node->isa<Parameter>()) {
-      // subgraph parameter comes from the output of another node.
-      graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
-      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
-      continue;
-    }
-
-    auto para = input_kernel_node->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(para);
-    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
-      // parameter already selected.
-      graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
-      graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
-      continue;
-    }
-
-    // weight parameter.
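Condensed, the per-input format policy implemented above looks like the following (PickInputFormat is a hypothetical helper; parameters whose kernel info was already selected keep their own format and are not shown):

    std::string PickInputFormat(bool use_same_format, bool is_weight, bool nz_convertible,
                                const std::string &default_format, const std::string &producer_format) {
      if (use_same_format) {
        // one shared format, unless FRAC_NZ was voted but this shape cannot be tiled
        return (default_format != kOpFormat_FRAC_NZ || nz_convertible) ? default_format : kOpFormat_DEFAULT;
      }
      return is_weight ? default_format : producer_format;
    }

The weight-parameter branch resumes below with exactly that fallback: the voted default format plus the inferred dtype.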
-    graph_input_format->push_back(default_format);
-    graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
-  }
-
-  for (size_t i = 0; i < input_num; ++i) {
-    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-    std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
-    std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
-    builder.SetOutputsFormat(outputs_format);
-    builder.SetOutputsDeviceType(outputs_device_type);
-    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
-  }
-}
-
-void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
-                       const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
-                       const FuncGraphManagerPtr &mng) {
-  MS_EXCEPTION_IF_NULL(mng);
-  for (size_t i = 0; i < node_list.size(); ++i) {
-    // select nodes in subgraph.
-    auto anf_node = node_list[i];
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto cnode = anf_node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
-    SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
-    // Update ReduceSum
-    if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
-      continue;
-    }
-    UpdateFracNZReduceOp(cnode);
-    // If ReduceSum's output is 1-D and not in the default format, convert it to the default format
-    auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
-    if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
-      continue;
-    }
-    auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
-    // Insert an EquivFormat node, then select kernel info again
-    std::vector<AnfNodePtr> trans_inputs;
-    trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
-    trans_inputs.push_back(cnode);
-    CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
-                                        {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
-    AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
-
-    if (trans_node->kernel_info() == nullptr) {
-      trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
-    }
-    SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
-    mng->Replace(cnode, trans_node);
-  }
-}
-
-void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
-                           const FuncGraphManagerPtr &mng, const std::string &default_format,
-                           std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
-                           std::vector<bool> *need_update) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(mng);
-  MS_EXCEPTION_IF_NULL(graph_input_format);
-  MS_EXCEPTION_IF_NULL(graph_input_type);
-  MS_EXCEPTION_IF_NULL(need_update);
-  // check graph input format and dtype used by inner ops.
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
-      need_update->size() != input_num) {
-    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
-                      << "], [" << graph_input_format->size() << "] != [" << input_num << "]";
-  }
-  auto &node_users = mng->node_users();
-  for (size_t i = 0; i < input_num; ++i) {
-    auto &input = input_list[i];
-    auto iter = node_users.find(input);
-    if (iter == node_users.end() || iter->second.empty()) {
-      continue;
-    }
-    for (auto &node_user : iter->second) {
-      if (node_user.first->kernel_info() == nullptr ||
-          node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
-        // maybe not a real kernel.
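A note on the users traversal here: the manager's node_users() maps each node to (consumer, input-position) pairs, and the position is 1-based because a CNode's input 0 holds the primitive — hence the `node_user.second - 1` in the checks that follow. A minimal walk, mirroring the deleted code:

    auto &users = mng->node_users();
    auto it = users.find(input);
    if (it != users.end()) {
      for (const auto &node_user : it->second) {
        size_t arg = IntToSize(node_user.second - 1);           // skip the primitive slot
        (void)AnfAlgo::GetInputFormat(node_user.first, arg);    // format the consumer selected
      }
    }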
-        continue;
-      }
-      auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
-      if (user_format != (*graph_input_format)[i]) {
-        MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << "] of ["
-                        << kernel_node->DebugString()
-                        << "] selected different format. we use default: " << default_format;
-        (*graph_input_format)[i] = default_format;
-        (*need_update)[i] = true;
-      }
-
-      if (kernel_node->input(i + 1)->isa<ValueNode>() ||
-          AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
-        continue;
-      }
-
-      TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
-      MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << "] of ["
-                      << kernel_node->DebugString()
-                      << "] selected different dtype. we use default: " << TypeIdLabel(default_dtype);
-      (*graph_input_type)[i] = default_dtype;
-      (*need_update)[i] = true;
-    }
-  }
-}
-
-void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
-                            const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
-                            const std::vector<std::string> &graph_input_format,
-                            const std::vector<TypeId> &graph_input_type) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  // update graph input format and dtype using inner ops.
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
-      need_update.size() != input_num) {
-    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
-                      << "], [" << graph_input_format.size() << "] != [" << input_num << "]";
-  }
-  for (size_t i = 0; i < input_num; ++i) {
-    if (!need_update[i]) {
-      continue;
-    }
-
-    MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
-                  << "] to: " << graph_input_format[i];
-    MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
-                  << "] to: " << TypeIdLabel(graph_input_type[i]);
-    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-    std::vector<std::string> outputs_format = {graph_input_format[i]};
-    std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
-    builder.SetOutputsFormat(outputs_format);
-    builder.SetOutputsDeviceType(outputs_device_type);
-    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
-  }
-
-  ResetKernelBuildInfo(kernel_node);
-  // select nodes in subgraph again.
-  for (size_t i = 0; i < node_list.size(); ++i) {
-    auto anf_node = node_list[i];
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto cnode = anf_node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-    size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
-    for (size_t j = 0; j < cnode_input_num; ++j) {
-      auto input_node = cnode->input(j + 1);
-      MS_EXCEPTION_IF_NULL(input_node);
-      if (!IsValueNode<tensor::Tensor>(input_node)) {
-        continue;
-      }
-      // reset format and dtype of const tensor.
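On the const-tensor reset that resumes below: kTypeUnknown marks a value node as "no kernel info selected yet", so the following SelectKernelInfo pass is free to pick a fresh format and dtype for it. The single-output reset pattern, isolated:

    kernel::KernelBuildInfo::KernelBuildInfoBuilder b;
    b.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    b.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
    AnfAlgo::SetSelectKernelBuildInfo(b.Build(), input_node.get());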
- builder.SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - builder.SetOutputsDeviceType(std::vector{kTypeUnknown}); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get()); - } - SelectKernelInfo(node_list[i]->cast(), KernelType::AKG_KERNEL); - } -} - -void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector> &output_index, - const std::vector &graph_input_format, - const std::vector &graph_input_type) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector graph_output_format; - std::vector graph_output_type; - for (size_t i = 0; i < output_index.size(); ++i) { - auto const &output = output_index[i]; - graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second)); - TypeId output_type(kTypeUnknown); - if (output.first->isa()) { - output_type = AnfAlgo::GetCNodeOutputPrecision(output.first); - } - if (output_type == kTypeUnknown) { - output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second); - } - graph_output_type.push_back(output_type); - } - - kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; - graph_info_builder.SetInputsFormat(graph_input_format); - graph_info_builder.SetInputsDeviceType(graph_input_type); - graph_info_builder.SetOutputsFormat(graph_output_format); - graph_info_builder.SetOutputsDeviceType(graph_output_type); - graph_info_builder.SetProcessor(kernel::Processor::AICORE); - graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); - graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); - auto graph_selected_info = graph_info_builder.Build(); - MS_EXCEPTION_IF_NULL(graph_selected_info); - AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get()); - SetTensorDeviceInfo(*graph_selected_info, kernel_node); -} - -void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(func_graph); - - // collect input info of funcgraph - std::vector node_list; - std::vector input_list; - std::vector output_list; - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - if (input_list.size() != kernel_node->inputs().size() - 1) { - MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode[" - << kernel_node->DebugString() << "], [%" << input_list.size() << "] != [" - << kernel_node->inputs().size() << "]"; - } - - std::string default_format; - bool use_same_format = true; - GetDefaultFormat(kernel_node, &default_format, &use_same_format); - MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] use same input format[" << default_format - << "] for ParameterWeight."; - - std::vector graph_input_format; - std::vector graph_input_type; - UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format, - &graph_input_type); - - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - } - auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list); - UpdateEquivFormat(output_index, node_list, func_graph, mng); - node_list.clear(); - input_list.clear(); - output_list.clear(); - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - - // update graph input format and dtype use inner ops. 
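For orientation, SelectGraphKernelInfo (started above, resuming below) drives the whole pass: GetDefaultFormat → UpdateInputsKernelInfo → UpdateEquivFormat → re-collect valid kernel nodes → CheckFormatsAndDtypes → UpdateFormatsAndDtypes → UpdateKernelInfo → SetGraphKernelInfo. The composite build info finally attached to the fused node mirrors SetGraphKernelInfo above; a sketch with assumed locals:

    kernel::KernelBuildInfo::KernelBuildInfoBuilder g;
    g.SetInputsFormat(graph_input_format);    // one entry per subgraph input
    g.SetInputsDeviceType(graph_input_type);
    g.SetOutputsFormat(graph_output_format);  // derived via kernel::GetOutputIndex
    g.SetOutputsDeviceType(graph_output_type);
    g.SetProcessor(kernel::Processor::AICORE);
    g.SetKernelType(KernelType::AKG_KERNEL);
    g.SetFusionType(kernel::FusionType::OPAQUE);
    AnfAlgo::SetSelectKernelBuildInfo(g.Build(), kernel_node.get());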
- std::vector need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); - CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, - &need_update); - UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); - - // set fix_precision for kernel when the me prim has fix_precision attr - UpdateKernelInfo(node_list); - - output_index = kernel::GetOutputIndex(node_list, input_list, output_list); - SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc deleted file mode 100644 index 7790107aa1..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/ascend/profiling/plugin_impl.h" -#include -#include "utils/log_adapter.h" -using std::string; - -namespace mindspore { -namespace device { -namespace ascend { -Reporter *PluginImpl::reporter_ = nullptr; - -PluginImpl::PluginImpl(const std::string &module) : module_(module) { MS_LOG(INFO) << "Create PluginImpl."; } - -int PluginImpl::Init(const Reporter *reporter) { - MS_LOG(INFO) << "PluginImpl init"; - MS_EXCEPTION_IF_NULL(reporter); - reporter_ = const_cast(reporter); - return 0; -} - -int PluginImpl::UnInit() { - MS_LOG(INFO) << " PluginImpl Uninit "; - reporter_ = nullptr; - return 0; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc deleted file mode 100644 index a393409334..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "device/ascend/profiling/profiling_engine_impl.h" -#include "utils/log_adapter.h" -#include "device/ascend/profiling/plugin_impl.h" - -namespace mindspore { -namespace device { -namespace ascend { -PluginIntf *ProfilingEngineImpl::CreatePlugin() { - MS_LOG(INFO) << "Create Plugin."; - return new (std::nothrow) PluginImpl("Framework"); -} - -int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { - if (plugin != nullptr) { - delete plugin; - plugin = nullptr; - } - return 0; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc deleted file mode 100644 index a2fe5b852d..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc +++ /dev/null @@ -1,207 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/profiling/profiling_manager.h" -#include -#include -#include "securec/include/securec.h" -#include "./prof_mgr_core.h" -#include "device/ascend/profiling/plugin_impl.h" -#include "device/ascend/profiling/profiling_engine_impl.h" -#include "utils/log_adapter.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "runtime/base.h" - -namespace mindspore { -namespace device { -namespace ascend { -ProfilingManager &ProfilingManager::GetInstance() { - static ProfilingManager inst; - return inst; -} - -ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { - engine_0_ = std::make_shared(); -} - -uint64_t ProfilingManager::GetJobId() const { - const char *job_id = std::getenv("JOB_ID"); - return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); -} - -bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { - if (!IsProfiling()) { - MS_LOG(INFO) << "No need profiling. 
Please export PROFILING_MODE and run in train mode.";
-    return false;
-  }
-  if (op_taskId_map.empty()) {
-    MS_LOG(WARNING) << "op_taskId_map is empty.";
-    return false;
-  }
-  auto reporter = PluginImpl::GetPluginReporter();
-  if (reporter == nullptr) {
-    MS_LOG(ERROR) << "No profiling data report!";
-    return false;
-  }
-  MS_LOG(INFO) << "DistributeTask: op taskId map size = " << op_taskId_map.size();
-
-  Msprof::Engine::ReporterData reporter_data = {};
-  for (const auto &iter : op_taskId_map) {
-    auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
-    reporter_data.deviceId = UintToInt(device_id_);
-    reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
-    reporter_data.dataLen = data.size();
-    auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
-    if (ret != 0) {
-      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
-      return false;
-    }
-    ret = reporter->Report(&reporter_data);
-    if (ret != 0) {
-      MS_LOG(ERROR) << "report data failed, errorno(" << ret << ")";
-      return false;
-    }
-  }
-  return true;
-}
-
-static std::vector<std::string> Split(const std::string &str, const char delim) {
-  std::vector<std::string> elems;
-
-  if (str.empty()) {
-    elems.emplace_back("");
-    return elems;
-  }
-
-  std::stringstream ss(str);
-  std::string item;
-
-  while (getline(ss, item, delim)) {
-    elems.push_back(item);
-  }
-  auto str_size = str.size();
-  if (str_size > 0 && str[str_size - 1] == delim) {
-    elems.emplace_back("");
-  }
-
-  return elems;
-}
-
-bool ProfilingManager::StartupProfiling(uint32_t device_id) {
-  auto is_profiling = IsProfiling();
-  if (!is_profiling) {
-    MS_LOG(INFO) << "No need profiling. Please export PROFILING_MODE and run in train mode.";
-    return true;
-  }
-  device_id_ = device_id;
-  // register Framework to profiling
-  int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
-  if (result != 0) {
-    MS_LOG(ERROR) << "Register profiling Engine failed.";
-    return false;
-  }
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  const string prof_options_str = context->profiling_options();
-  std::vector<std::string> opts = Split(prof_options_str, ':');
-  if (opts.empty()) {
-    MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";
-    return true;
-  }
-  // currently one docker uses only one device
-  nlohmann::json p_device;
-  // JOBID
-  auto job_id = GetJobId();
-  p_device["jobID"] = std::to_string(job_id);
-  // device_id
-  p_device["deviceID"] = std::to_string(device_id);
-  // features: 'training_trace', 'task_trace', etc.
-  nlohmann::json features;
-  for (std::vector<std::string>::size_type i = 0; i < opts.size(); i++) {
-    nlohmann::json f;
-    f["name"] = opts[i];
-    features[i] = f;
-  }
-  p_device["features"] = features;
-  // only one device, but the ProfMgrStartUp API requires a device list
-  nlohmann::json devices;
-  devices[0] = p_device;
-  nlohmann::json startCfg;
-  startCfg["startCfg"] = devices;
-
-  if (!ProfStartUp(NOT_NULL(&startCfg))) {
-    MS_LOG(ERROR) << "ProfMgrStartUp failed.";
-    return false;
-  }
-  return true;
-}
-
-bool ProfilingManager::ProfStartUp(NotNull<nlohmann::json *> startCfg) {
-  // convert json to string
-  std::stringstream ss;
-  ss << *startCfg;
-  std::string cfg = ss.str();
-  MS_LOG(INFO) << "profiling config " << cfg;
-  auto ret = rtProfilerStart();
-  if (ret != RT_ERROR_NONE) {
-    MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret;
-    return false;
-  }
-
-  // call profiling startup API
-  ProfMgrCfg prof_cfg = {cfg};
-  prof_handle_ = ProfMgrStartUp(&prof_cfg);
-  if (prof_handle_ == nullptr) {
MS_LOG(ERROR) << "Startup profiling failed."; - return false; - } - return true; -} - -bool ProfilingManager::StopProfiling() { - MS_LOG(INFO) << "StopProfiling"; - if (!IsProfiling()) { - MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; - return true; - } - Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter != nullptr) { - MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); - } - - auto rt_ret = rtProfilerStop(); - if (rt_ret != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Call rtProfilerStop failed"; - return false; - } - - if (prof_handle_ != nullptr) { - int result = ProfMgrStop(prof_handle_); - if (result != 0) { - MS_LOG(ERROR) << "ProfMgr stop return fail:" << result << "."; - prof_handle_ = nullptr; - return false; - } - prof_handle_ = nullptr; - } - - return true; -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc deleted file mode 100644 index 17ac4c4530..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc +++ /dev/null @@ -1,367 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/ascend/profiling/reporter/graph_desc_reporter.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "kernel/kernel.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "utils/utils.h" -#include "device/ascend/profiling/reporter/task_desc_reporter.h" -#include "utils/context/ms_context.h" -#include "device/ascend/profiling/reporter/point_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -constexpr uint32_t kMaxProfilingNodeNum = 100; -constexpr char kCustomNode[] = "PROFILING_CUSTOM_"; -constexpr char kFpStartNode[] = "PROFILING_FP_START"; -constexpr char kBpEndNode[] = "PROFILING_BP_END"; -constexpr char kIterEndNode[] = "PROFILING_ITER_END"; -// PROFILING_CUSTOM_LOGID_START 3 -constexpr uint64_t kProfilingFpStartLogId = 1; -constexpr uint64_t kProfilingBpEndLogId = 2; -constexpr uint64_t kProfilingIterEndLogId = 255; -std::map> ProfilingUtils::graph_profiling_cnode_; -std::map> ProfilingUtils::graph_kernel_name_; -std::map>> ProfilingUtils::graph_point_; -uint32_t ProfilingUtils::custom_node_index_ = 1; - -ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull graph_ptr) { - MS_LOG(INFO) << "get env start"; - custom_node_index_ = 1; - auto &cnode_exec_order = graph_ptr->execution_order(); - ProfilingTraceInfo profiling_trace; - profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); - profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); - profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); - - for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { - std::string env_str = std::string(kCustomNode) + std::to_string(i); - const char *node_full_name = std::getenv(env_str.c_str()); - if (node_full_name == nullptr) { - break; - } - MS_LOG(INFO) << "Get profiling node:" << node_full_name; - profiling_trace.trace_custom_node.insert(node_full_name); - } - MS_LOG(INFO) << "get env end"; - GetTraceHccl(cnode_exec_order, NOT_NULL(&profiling_trace)); - - MS_LOG(INFO) << "[profiling]trace_begin:" << profiling_trace.trace_begin - << " trace_bp_end:" << profiling_trace.trace_bp_end - << " trace_netoutput:" << profiling_trace.trace_netoutput; - return profiling_trace; -} - -void ProfilingUtils::GetTraceHccl(const std::vector &cnode_exec_order, - NotNull profiling_trace) { - for (const auto &node : cnode_exec_order) { - if (AnfAlgo::IsCommunicationOp(node)) { - MS_EXCEPTION_IF_NULL(node); - profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); - MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); - } - } -} - -std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { - const char *trace_begin = std::getenv(kFpStartNode); - if (trace_begin != nullptr) { - return std::string(trace_begin); - } - - std::string fp_start_str; - std::set getnext_outputs; - GetCNodeOutputRealNode(kGetNextOpName, cnode_exec_order, NOT_NULL(&getnext_outputs)); - if (getnext_outputs.empty()) { - auto first_node = cnode_exec_order.front(); - MS_EXCEPTION_IF_NULL(first_node); - fp_start_str = first_node->fullname_with_scope(); - } else { - for (auto &cnode : cnode_exec_order) { - if (getnext_outputs.count(cnode->fullname_with_scope()) != 0) { - fp_start_str = cnode->fullname_with_scope(); - break; - } - } - } - return fp_start_str; -} - -void ProfilingUtils::GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, - NotNull *> 
getnext_outputs) { - for (const auto &cnode : cnode_exec_order) { - MS_EXCEPTION_IF_NULL(cnode); - for (const auto &input : cnode->inputs()) { - auto prev_cnode = AnfAlgo::VisitKernel(input, 0); - if (!prev_cnode.first->isa()) { - continue; - } - if (AnfAlgo::GetCNodeName(prev_cnode.first) == node_name) { - getnext_outputs->insert(cnode->fullname_with_scope()); - MS_LOG(INFO) << "Find GetNext Output CNode:" << cnode->fullname_with_scope(); - } - } - } - if (getnext_outputs->empty()) { - MS_LOG(WARNING) << "GetNext not found"; - } -} - -std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { - const char *trace_bp_end = std::getenv(kBpEndNode); - - if (trace_bp_end != nullptr) { - return std::string(trace_bp_end); - } - std::string bp_end_str; - // Contain hccl kernel - auto iter = cnode_exec_order.rbegin(); - while (iter != cnode_exec_order.rend()) { - if (AnfAlgo::IsCommunicationOp(*iter)) { - // store communication op input nodes' name - std::set ar_input_node_names; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); - auto input_node = input_node_with_index.first; - ar_input_node_names.insert(input_node->fullname_with_scope()); - } - // start from previous node - ++iter; - // find input names in previous node - while (iter != cnode_exec_order.rend()) { - if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { - bp_end_str = (*iter)->fullname_with_scope(); - break; - } - ++iter; - } - break; - } - ++iter; - } - - if (bp_end_str.empty()) { - bp_end_str = GetGraphLastTbeKernelName(cnode_exec_order); - } - return bp_end_str; -} - -std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vector &cnode_exec_order) { - std::string last_tbe_kernel_name; - // find last tbe_kernel - for (auto iter = cnode_exec_order.rbegin(); iter != cnode_exec_order.rend(); ++iter) { - if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL) { - last_tbe_kernel_name = (*iter)->fullname_with_scope(); - break; - } - } - if (last_tbe_kernel_name.empty()) { - MS_LOG(WARNING) << "tbe kernel not found in graph"; - } - return last_tbe_kernel_name; -} - -std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { - const char *trace_netoutput = std::getenv(kIterEndNode); - return trace_netoutput == nullptr ? 
GetGraphLastTbeKernelName(cnode_exec_order) : std::string(trace_netoutput); -} - -NotNull ProfilingUtils::CreateProfilingCNode(const ProfilingContent &profiling_content, - NotNull graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - abstract::AbstractBasePtr type_none_abstract = std::make_shared(); - auto primitive = std::make_shared(ProfilingUtils::kProfiling); - std::vector inputs; - inputs.emplace_back(NewValueNode(primitive)); - CNodePtr cnode_ptr = graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode_ptr); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), cnode_ptr.get()); - cnode_ptr->set_abstract(type_none_abstract); - // set attr - ValuePtr notify_value = MakeValue(profiling_content.notify); - ValuePtr trace_id_value = MakeValue(profiling_content.profiler_trace_id); - ValuePtr flags_value = MakeValue(profiling_content.flags); - AnfAlgo::SetNodeAttr(ProfilingUtils::kNotify, notify_value, cnode_ptr); - AnfAlgo::SetNodeAttr(ProfilingUtils::kProfilerTraceId, trace_id_value, cnode_ptr); - AnfAlgo::SetNodeAttr(ProfilingUtils::kFlags, flags_value, cnode_ptr); - return NOT_NULL(cnode_ptr); -} - -void ProfilingUtils::SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id) { - std::shared_ptr prof_desc_ptr = std::make_shared(node_name, point_id); - auto iter = graph_point_.find(graph_id); - if (iter == graph_point_.end()) { - std::vector> tmp_vect = {prof_desc_ptr}; - graph_point_.insert({graph_id, tmp_vect}); - } else { - iter->second.emplace_back(prof_desc_ptr); - } -} - -void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node, - const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { - MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; - ProfilingTraceJobId(anf_node, graph_ptr, kernel_list); - ProfilingContent fp_profiling_content = {false, kProfilingFpStartLogId, 0}; - auto fp_profiling_node = CreateProfilingCNodeWithStream(anf_node, fp_profiling_content, graph_ptr); - kernel_list->emplace_back(fp_profiling_node); - // insert ProfDesc - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingFpStartLogId); - } -} - -void ProfilingUtils::ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, - NotNull *> kernel_list) { - MS_LOG(INFO) << "Profiling Match start"; - auto job_id = ProfilingManager::GetInstance().GetJobId(); - ProfilingContent job_profiling_context = {false, job_id, 0}; - auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); - kernel_list->emplace_back(job_profiling_node); -} - -CNodePtr ProfilingUtils::CreateProfilingCNodeWithStream(const mindspore::AnfNodePtr &anf_node, - const ProfilingContent &profiling_content, - NotNull graph_ptr) { - CNodePtr profiling_node = CreateProfilingCNode(profiling_content, graph_ptr); - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), profiling_node.get()); - 
AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(anf_node), profiling_node.get()); - return profiling_node; -} - -void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto iter = profiling_trace_info.trace_custom_node.find(anf_node->fullname_with_scope()); - if (iter == profiling_trace_info.trace_custom_node.end()) { - return; - } - MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); - // custom op profiling job start from 3. - auto custom_point_id = 2 * custom_node_index_ + 1; - ProfilingContent front_profiling_content = {false, custom_point_id, 0}; - CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); - kernel_list->insert(kernel_list->end() - 1, front_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id); - - ProfilingContent back_profiling_content = {false, custom_point_id + 1, 0}; - CNodePtr back_node = CreateProfilingCNodeWithStream(anf_node, back_profiling_content, graph_ptr); - kernel_list->insert(kernel_list->end(), back_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id + 1); - ++custom_node_index_; -} - -void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { - MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; - ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; - CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); - kernel_list->emplace_back(bp_end_node); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingBpEndLogId); - } -} - -void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto full_scope_name = anf_node->fullname_with_scope(); - if (profiling_trace_info.trace_netoutput == full_scope_name) { - MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; - ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; - CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); - kernel_list->emplace_back(bp_kernel_ptr); - SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingIterEndLogId); - } -} - -void ProfilingUtils::SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names) { - auto ret = graph_kernel_name_.try_emplace(graph_id, kernel_names); - if (!ret.second) { - MS_LOG(ERROR) << "[profiling]graph " << graph_id << " kernel names already exist"; - } -} - -void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list) { - auto ret = graph_profiling_cnode_.try_emplace(graph_id, profiling_cnode_list); - if (!ret.second) { - MS_LOG(ERROR) << "[profiling]graph " << graph_id << " profiling cnode list already exist"; - } -} - -bool ProfilingUtils::ValidComputeGraph(NotNull graph_ptr) { - for (const auto &node : graph_ptr->execution_order()) { - if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) { - 
return true;
-    }
-  }
-  return false;
-}
-
-void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
-                                         NotNull<const session::KernelGraph *> graph) {
-  if (!ValidComputeGraph(graph)) {
-    MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id();
-    return;
-  }
-
-  auto ret = graph_profiling_cnode_.find(graph->graph_id());
-  if (ret == graph_profiling_cnode_.end()) {
-    MS_LOG(ERROR) << "Graph id not found";
-    return;
-  }
-
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
-  task_reporter.set_task_ids(task_ids);
-  task_reporter.set_stream_ids(stream_ids);
-  task_reporter.ReportData();
-
-  GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
-  graph_profiling_cnode_.erase(ret);
-  graph_reporter.ReportData();
-
-  // Report profiling point
-  auto point_iter = graph_point_.find(graph->graph_id());
-  if (point_iter == graph_point_.end()) {
-    MS_LOG(ERROR) << "Graph id not found in graph_point";
-    return;
-  }
-  PointReporter point_reporter(context->device_id(), "vm.point");
-  for (const auto &point : point_iter->second) {
-    point_reporter.AddReportData(point);
-  }
-  point_reporter.ReportData();
-}
-}  // namespace ascend
-}  // namespace device
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h
deleted file mode 100644
index a3c7739447..0000000000
--- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h
+++ /dev/null
@@ -1,142 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
-#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include "session/kernel_graph.h"
-#include "utils/contract.h"
-#include "device/ascend/profiling/reporter/profiling_desc.h"
-
-namespace mindspore {
-namespace device {
-namespace ascend {
-struct ProfilingTraceInfo {
-  // the first op in the execution order (like Cast or Four2Five ...), except tdt ops (GetNext ...)
-  std::string trace_begin;
-  // the first net_output (apply kernel) found from the graph outputs: fp -> net_output <- bp
-  std::string trace_bp_end;
-  // the last op in the execution order (like Conv2DBackpropFilter)
-  std::string trace_netoutput;
-
-  // profiling specific ops, such as AllReduce
-  std::set<std::string> trace_custom_node;
-
-  // 1. insert profiling_trace_begin if profiling_trace_bp_end is not empty.
-  // 2. op launch gets task info with a callback func.
-  // 3. insert profiling_trace_bp_end.
-  // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty.
-
-  bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); }
-};
-
-struct ProfilingContent {
-  // true - send data from device to host and finish profiling
-  bool notify;
-  uint64_t profiler_trace_id;
-  uint32_t flags;
-};
-
-class ProfilingUtils {
- public:
-  ProfilingUtils() = default;
-  ~ProfilingUtils() = default;
-
-  // Insert the job_id profiling node and the fp_start profiling node.
-  // The job_id is read from the envs and should be a number greater than 255.
-  // The fp_start node should be inserted at the start of a network, and its log_id is hard-coded to 1.
-  static void ProfilingTraceFpStart(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
-                                    NotNull<session::KernelGraph *> graph_ptr,
-                                    NotNull<std::vector<CNodePtr> *> kernel_list);
-
-  static void ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull<session::KernelGraph *> graph_ptr,
-                                  NotNull<std::vector<CNodePtr> *> kernel_list);
-
-  // Insert the net output profiling node, which tells the device to stop profiling.
-  // The notify in struct ProfilingContent should be 'true', which tells the device to send data to the host.
-  static void ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
-                                NotNull<session::KernelGraph *> graph_ptr,
-                                NotNull<std::vector<CNodePtr> *> kernel_list);
-
-  // Insert the bp_end profiling node, which should be inserted after the last backpropagation CNode in the network.
-  static void ProfilingTraceBpEnd(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
-                                  NotNull<session::KernelGraph *> graph_ptr,
-                                  NotNull<std::vector<CNodePtr> *> kernel_list);
-
-  // Map a graph id to the kernels' names in the graph
-  static void SetGraphProfilingCNode(uint32_t graph_id, const std::vector<CNodePtr> &profiling_cnode_list);
-
-  static void SetGraphKernelName(uint32_t graph_id, const std::vector<std::string> &kernel_names);
-
-  // Map task_id to kernel name so the device can generate the time cost of a specific kernel.
-  // The device calculates the time cost of the task marked by the task id,
-  // but we need (kernel name, time cost) pairs.
-  static void ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
-                                  NotNull<const session::KernelGraph *> graph);
-
-  // Get profiling trace points from the envs.
-  // export PROFILING_FP_START='full name of the first cnode to execute'
-  // export PROFILING_BP_END='full name of the last backpropagation cnode to execute'
-  // export PROFILING_ITER_END='full name of the last cnode in the graph to execute'
-  // And for other cnodes, like AllReduce, export PROFILING_CUSTOM_1='full name of the AllReduce cnode';
-  // for GetNext, export PROFILING_CUSTOM_2='full name of the GetNext cnode'.
-  // The variable i in PROFILING_CUSTOM_i should start from 1 without interruption.
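How the custom trace points are consumed — this mirrors GetProfilingTraceFromEnv in the deleted profiling_utils.cc above, where kCustomNode is "PROFILING_CUSTOM_" and kMaxProfilingNodeNum is 100:

    ProfilingTraceInfo trace;
    for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) {
      std::string env_str = std::string(kCustomNode) + std::to_string(i);
      const char *node_full_name = std::getenv(env_str.c_str());
      if (node_full_name == nullptr) {
        break;  // numbering must be contiguous, starting at 1
      }
      trace.trace_custom_node.insert(node_full_name);
    }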
- static ProfilingTraceInfo GetProfilingTraceFromEnv(NotNull graph_ptr); - - // Insert two profiling trace points, one in front and one behind - static void ProfilingCustomOp(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, - NotNull graph_ptr, - NotNull *> kernel_list); - - static std::map> graph_kernel_name() { return graph_kernel_name_; } - - inline static constexpr char kProfiling[] = "Profiling"; - inline static constexpr char kNotify[] = "notify"; - inline static constexpr char kProfilerTraceId[] = "profiler_trace_id"; - inline static constexpr char kFlags[] = "flags"; - - private: - static NotNull CreateProfilingCNode(const ProfilingContent &profiling_content, - NotNull graph_ptr); - static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, - NotNull graph_ptr); - static std::string GetTraceBegin(const std::vector &cnode_exec_order); - static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); - static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); - static std::string GetGraphLastTbeKernelName(const std::vector &cnode_exec_order); - static void GetTraceHccl(const std::vector &cnode_exec_order, - NotNull profiling_trace); - static void GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, - NotNull *> getnext_outputs); - - static bool ValidComputeGraph(NotNull graph_ptr); - static void SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id); - - // graph id --> (kernel name list) - static std::map> graph_profiling_cnode_; - static std::map> graph_kernel_name_; - static std::map>> graph_point_; - static uint32_t custom_node_index_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc deleted file mode 100644 index cf80c07ca9..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" -#include "device/ascend/profiling/plugin_impl.h" -#include "utils/log_adapter.h" - -constexpr size_t kReportMaxLen = 2048; - -namespace mindspore { -namespace device { -namespace ascend { -DescReporter::~DescReporter() = default; - -void DescReporter::ReportByLine(const std::string &data, const std::string &file_name) const { - auto reporter = PluginImpl::GetPluginReporter(); - MS_EXCEPTION_IF_NULL(reporter); - - auto tot_size = data.size(); - size_t cur_size = 0; - while (cur_size < tot_size) { - size_t remain_size = tot_size - cur_size; - size_t report_size = std::min(remain_size, kReportMaxLen); - - Msprof::Engine::ReporterData report_data{}; - report_data.deviceId = device_id_; - report_data.dataLen = report_size; - report_data.data = (unsigned char *)data.c_str() + cur_size; - auto ret = memcpy_s(report_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, file_name.c_str(), file_name.length()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "Memcpy_s report data tag failed"; - } - auto report_ret = reporter->Report(&report_data); - if (report_ret != 0) { - MS_LOG(EXCEPTION) << "Report data failed"; - } - if (report_size == 0) { - MS_LOG(WARNING) << "Report_size is 0"; - break; - } - cur_size += report_size; - } -} - -void DescReporter::ReportAllLine() { - for (const auto &desc : prof_desc_list_) { - auto data = desc->ToString(); - ReportByLine(data, file_name_); - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h deleted file mode 100644 index c8e1b3ed62..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/desc_reporter.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ - -#include -#include -#include -#include -#include "toolchain/prof_reporter.h" -#include "device/ascend/profiling/reporter/profiling_desc.h" -#include "utils/contract.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace device { -namespace ascend { -class DescReporter { - public: - virtual ~DescReporter() = 0; - DescReporter(int device_id, std::string file_name) : device_id_(device_id), file_name_(std::move(file_name)) {} - - virtual void ReportData() = 0; - - protected: - void ReportByLine(const std::string &data, const std::string &file_name) const; - void ReportAllLine(); - - int device_id_; - std::string file_name_; - std::vector> prof_desc_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc deleted file mode 100644 index 1f2d1570bb..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include "device/ascend/profiling/reporter/graph_desc_reporter.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace ascend { -void GraphDescReporter::ReportData() { - for (const auto &node : cnode_list_) { - if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { - MS_LOG(WARNING) << "Skip non tbe kernel"; - continue; - } - std::vector input_data_list; - std::vector output_data_list; - MS_EXCEPTION_IF_NULL(node); - auto op_name = node->fullname_with_scope(); - auto op_type = AnfAlgo::GetCNodeName(node); - auto input_size = AnfAlgo::GetInputTensorNum(node); - for (size_t i = 0; i < input_size; ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); - auto input_node = input_node_with_index.first; - auto input_index = input_node_with_index.second; - DataElement element{}; - element.index_ = i; - element.data_type_ = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); - element.data_format_ = AnfAlgo::GetOutputFormat(input_node, input_index); - element.data_shape_ = AnfAlgo::GetOutputDeviceShape(input_node, input_index); - input_data_list.emplace_back(element); - } - - auto output_size = AnfAlgo::GetOutputTensorNum(node); - for (size_t i = 0; i < output_size; ++i) { - DataElement element{}; - element.index_ = i; - element.data_type_ = AnfAlgo::GetOutputDeviceDataType(node, i); - element.data_format_ = AnfAlgo::GetOutputFormat(node, i); - element.data_shape_ = AnfAlgo::GetOutputDeviceShape(node, i); - output_data_list.emplace_back(element); - } - - auto graph_desc = std::make_shared(op_name, op_type, input_data_list, output_data_list); - prof_desc_list_.emplace_back(graph_desc); - } - ReportAllLine(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h deleted file mode 100644 index 10f78092f2..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/graph_desc_reporter.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ - -#include -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class GraphDescReporter : public DescReporter { - public: - GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector cnode_list) - : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} - ~GraphDescReporter() override = default; - void ReportData() override; - - private: - std::vector cnode_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc deleted file mode 100644 index 0024ab9c22..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/profiling/reporter/point_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -void PointReporter::ReportData() { ReportAllLine(); } - -void PointReporter::AddReportData(const std::shared_ptr &prof_desc) { - prof_desc_list_.emplace_back(prof_desc); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h deleted file mode 100644 index ae12672df6..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/point_reporter.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ - -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class PointReporter : public DescReporter { - public: - PointReporter(uint32_t device_id, const std::string &file_name) : DescReporter(device_id, file_name) {} - ~PointReporter() override = default; - void ReportData() override; - void AddReportData(const std::shared_ptr &prof_desc); -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc deleted file mode 100644 index 082cb81e42..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include "device/ascend/profiling/reporter/profiling_desc.h" - -namespace mindspore { -namespace device { -namespace ascend { -std::string TaskDesc::ToString() { - std::string out = op_name_; - out.append(" ") - .append(std::to_string(block_dim_)) - .append(" ") - .append(std::to_string(task_id_)) - .append(" ") - .append(std::to_string(stream_id_)) - .append("\n"); - return out; -} - -std::string GraphDesc::ToString() { - std::string desc; - desc.append("op_name:").append(op_name_).append(" op_type:").append(op_type_); - int input_id = 0; - for (const auto &element : input_data_list_) { - desc.append(" input_id:") - .append(std::to_string(input_id++)) - .append(" input_format:") - .append(element.data_format_) - .append(" input_data_type:") - .append(std::to_string(element.data_type_)) - .append(" input_shape:") - .append(DataShapeToString(element.data_shape_)); - } - - input_id = 0; - for (const auto &element : output_data_list_) { - desc.append(" output_id:") - .append(std::to_string(input_id++)) - .append(" output_format:") - .append(element.data_format_) - .append(" output_data_type:") - .append(std::to_string(element.data_type_)) - .append(" output_shape:") - .append((DataShapeToString(element.data_shape_))); - } - - desc.append("\n"); - - return desc; -} - -std::string PointDesc::ToString() { - std::string desc; - desc.append(std::to_string(point_id_)).append(" ").append(op_name_).append("\n"); - return desc; -} - -std::string GraphDesc::DataShapeToString(const std::vector &shape) { - std::ostringstream oss; - oss << "\""; - if (!shape.empty()) { - std::copy(shape.begin(), shape.end() - 1, std::ostream_iterator(oss, ",")); - oss << shape.back(); - } - oss << "\""; - return oss.str(); -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git 
a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc b/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc deleted file mode 100644 index 0bd66e31ef..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "device/ascend/profiling/reporter/task_desc_reporter.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/ascend_kernel_mod.h" - -namespace mindspore { -namespace device { -namespace ascend { -void TaskDescReporter::ReportData() { - MS_LOG(INFO) << "cnode_list.size()=" << cnode_list_.size() << " task_ids_.size()=" << task_ids_.size(); - if (cnode_list_.size() != task_ids_.size()) { - MS_LOG(ERROR) << "cnode list size not equal task ids size"; - return; - } - - size_t task_index = 0; - for (const auto &node : cnode_list_) { - if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) { - MS_LOG(WARNING) << "Skip non tbe kernel"; - ++task_index; - continue; - } - auto kernel_mod = AnfAlgo::GetKernelMod(node); - auto ascend_kernel_mod = dynamic_cast(kernel_mod); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(ascend_kernel_mod); - // Check task_id and stream_id valid - CheckStreamTaskValid(task_index, task_index); - auto desc_ptr = std::make_shared(node->fullname_with_scope(), task_ids_[task_index], - ascend_kernel_mod->block_dim(), stream_ids_[task_index]); - prof_desc_list_.emplace_back(desc_ptr); - ++task_index; - } - ReportAllLine(); -} - -void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) { - if (task_id >= task_ids_.size() || stream_id >= stream_ids_.size()) { - MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size() - << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size(); - } -} -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h b/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h deleted file mode 100644 index 087c691a5f..0000000000 --- a/mindspore/ccsrc/device/ascend/profiling/reporter/task_desc_reporter.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ - -#include -#include -#include -#include "device/ascend/profiling/reporter/desc_reporter.h" - -namespace mindspore { -namespace device { -namespace ascend { -class TaskDescReporter : public DescReporter { - public: - TaskDescReporter(int device_id, const std::string &file_name, std::vector cnode_list) - : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} - ~TaskDescReporter() override = default; - void ReportData() override; - void set_task_ids(const std::vector &task_ids) { task_ids_ = task_ids; } - void set_stream_ids(const std::vector &stream_ids) { stream_ids_ = stream_ids; } - - private: - std::vector task_ids_; - std::vector stream_ids_; - void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id); - std::vector cnode_list_; -}; -} // namespace ascend -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc deleted file mode 100644 index 3faeefb820..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.cc +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/ascend/tasksink/runtime_utils.h" - -#include - -#include "hccl/hcom.h" -#include "utils/log_adapter.h" -#include "utils/utils.h" - -constexpr auto kHcomBroadcast = "hcom_broadcast_"; -constexpr auto kHcomAllGather = "hcom_all_gather_"; -constexpr auto kHcomAllReduce = "hcom_all_reduce_"; -constexpr auto kHcomReduceScatter = "hcom_reduce_scatter_"; -constexpr auto kUnderline = "_"; -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { - hcclResult_t ret = hcom_bind_model(model, stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast(ret); - return false; - } - return true; -} - -bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { - hcclResult_t ret = hcom_unbind_model(model); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast(ret); - return false; - } - return true; -} - -bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info, rtStream_t stream) { - MS_LOG(INFO) << "hccl distribute start"; - MS_EXCEPTION_IF_NULL(task_info); - hcclResult_t ret; - static uint32_t task_counter = 0; - auto hccl_group = task_info->group(); - if (task_info->hccl_type() == kBroadcastOpName) { - // call hcom broadcast interface to run op - const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast(task_info->count()), - static_cast(task_info->data_type()), static_cast(task_info->root_id()), - hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast(ret); - return false; - } - } else if (task_info->hccl_type() == kAllGatherOpName) { - // call hcom allgather interface to run op - const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; - return false; - } - } else if (task_info->hccl_type() == kAllReduceOpName) { - // call hcom allreduce interface to run op - const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; - return false; - } - } else if (task_info->hccl_type() == kReduceScatterOpName) { - // call hcom reducescatter interface to run op - const string tag_reduce_scatter = - kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); - ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), - static_cast(task_info->count()), static_cast(task_info->data_type()), - static_cast(task_info->op_type()), hccl_group.c_str(), stream); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; - return 
false; - } - } - return true; -} -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc deleted file mode 100644 index 00489c7299..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc +++ /dev/null @@ -1,200 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/ascend/tasksink/task_generator.h" - -#include -#include "kernel/task_stream.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "device/ascend/profiling/profiling_manager.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id) { - MS_LOG(INFO) << "GenTasks start..."; - MS_EXCEPTION_IF_NULL(task_info_list); - // Traverse graph applykernel list and run - if (!LaunchAllKernel(anf_node_list, task_info_list, graph_id)) { - MS_LOG(ERROR) << "LaunchAllKernel failed"; - return false; - } - MS_LOG(INFO) << "GenTasks end..."; - return true; -} - -void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_inputs); - // akg process - // set atomic clean addr - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node_ptr)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicOutputIndexs); - auto graph = anf_node_ptr->func_graph(); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_users = manager->node_users(); - if (node_users[anf_node_ptr].empty()) { - MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; - } - auto depend_node = node_users[anf_node_ptr].pop().first; - if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { - MS_LOG(EXCEPTION) << "Checking Depend node failed"; - } - if (node_users[depend_node].empty()) { - MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; - } - auto post_node = node_users[depend_node].pop().first; - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(post_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs->push_back(input); - } - MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size(); - } -} - -void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_inputs); - if (anf_node_ptr->inputs().size() != 2) { - 
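// Note: an AtomicAddrClean node generated for an AKG kernel carries no explicit input
// operand (so inputs().size() != 2); the addresses to clean are instead recovered from
// the node's users through the kAttrAtomicOutputIndexs attribute, as done above in
// LaunchAddrCleanAkgKernel.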
LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs); - return; - } - MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); - auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>(); - // set clean output addr - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); - kernel::AddressPtr input = std::make_shared<kernel::Address>(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->push_back(input); - } - MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); - } - // set clean workspace address - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspace_indexs) { - auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); - kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_inputs->push_back(workspace); - } - } - auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize); - if (kernel_inputs->size() != clear_mems.size()) { - MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size is not equal to clean memory size, kernel_inputs size:" - << kernel_inputs->size() << ", clean mem size:" << clear_mems.size(); - } -} - -bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, - std::vector<TaskInfoPtr> *task_info_list) { - MS_EXCEPTION_IF_NULL(task_info_list); - MS_EXCEPTION_IF_NULL(anf_node_ptr); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr); - MS_EXCEPTION_IF_NULL(kernel_mod); - kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope()); - if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) { - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { - auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index); - AddressPtr input = std::make_shared<Address>(); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs.push_back(input); - } - - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node_ptr); ++i) { - auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i); - AddressPtr output = std::make_shared<Address>
(); - output->addr = it->ptr_; - output->size = it->size_; - kernel_outputs.push_back(output); - } - - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i); - kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - workspace->size = device_address->size_; - kernel_workspaces.push_back(workspace); - } - } else { - LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs); - } - - auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); - MS_EXCEPTION_IF_NULL(ascend_kernel_mod); - std::vector<TaskInfoPtr> task_info_ptrs = - ascend_kernel_mod->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id); - task_info_list->insert(task_info_list->end(), task_info_ptrs.begin(), task_info_ptrs.end()); - return true; -} - -bool TaskGenerator::LaunchAllKernel(const std::vector<CNodePtr> &anf_node_list, - std::vector<TaskInfoPtr> *task_info_list, uint32_t graph_id) { - uint32_t current_op_index = 0; - std::vector<CNodePtr> profiling_cnode_list; - std::vector<std::string> kernel_name_list; - for (const auto &anf_node_ptr : anf_node_list) { - size_t old_size = task_info_list->size(); - uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr); - MS_EXCEPTION_IF_NULL(anf_node_ptr); - MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index - << " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id; - if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) { - MS_LOG(ERROR) << "LaunchKernel failed."; - return false; - } - for (size_t i = old_size; i < task_info_list->size(); ++i) { - profiling_cnode_list.emplace_back(anf_node_ptr); - kernel_name_list.emplace_back(anf_node_ptr->fullname_with_scope()); - } - current_op_index++; - } - - ProfilingUtils::SetGraphKernelName(graph_id, kernel_name_list); - if (ProfilingManager::GetInstance().IsProfiling()) { - ProfilingUtils::SetGraphProfilingCNode(graph_id, profiling_cnode_list); - } - return true; -} -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/device/ascend/tasksink/task_generator.h deleted file mode 100644 index ecd5889b04..0000000000 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ - -#include -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "ir/anf.h" -#include "kernel/ascend_kernel_mod.h" -#include "framework/ge_runtime/task_info.h" - -namespace mindspore { -namespace device { -namespace ascend { -namespace tasksink { -using mindspore::kernel::Address; -using mindspore::kernel::AddressPtr; -using AddressPtrList = std::vector; -using ge::model_runner::TaskInfo; -using TaskInfoPtr = std::shared_ptr; -class TaskGenerator { - public: - TaskGenerator() = default; - ~TaskGenerator() = default; - TaskGenerator(const TaskGenerator &in) = delete; - TaskGenerator &operator=(const TaskGenerator &in) = delete; - - static bool GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id); - - private: - static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); - static void LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs); - static bool LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id, std::vector *task_info_list); - static bool LaunchAllKernel(const std::vector &anf_node_list, std::vector *task_info_list, - uint32_t graph_id); -}; -} // namespace tasksink -} // namespace ascend -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_ diff --git a/mindspore/ccsrc/device/convert_tensor_utils.cc b/mindspore/ccsrc/device/convert_tensor_utils.cc deleted file mode 100644 index bac72727c2..0000000000 --- a/mindspore/ccsrc/device/convert_tensor_utils.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "device/convert_tensor_utils.h" -#include -namespace mindspore { -namespace device { -void HalfToFloat(void *dst, const void *src, size_t elem_num) { - auto half_data = static_cast(src); - auto float_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - float tmp = Eigen::half_impl::half_to_float(half_data[i]); - float_data[i] = tmp; - } -} - -void FloatToHalf(void *dst, const void *src, size_t elem_num) { - auto float_data = static_cast(src); - auto half_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - half_data[i] = Eigen::half(float_data[i]); - } -} - -void DoubleToFloat(void *dst, const void *src, size_t elem_num) { - auto double_data = static_cast(src); - auto float_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - float_data[i] = static_cast(double_data[i]); - } -} - -void FloatToDouble(void *dst, const void *src, size_t elem_num) { - auto float_data = static_cast(src); - auto double_data = static_cast(dst); - for (size_t i = 0; i < elem_num; ++i) { - double_data[i] = static_cast(float_data[i]); - } -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/device/cpu/cpu_device_address.cc deleted file mode 100644 index 09ab0da12b..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_device_address.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_device_address.h" -#include -#include "device/convert_tensor_utils.h" - -namespace mindspore { -namespace device { -namespace cpu { -bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size_t size, TypeId type, - void *host_ptr) const { - if (ptr_ == nullptr) { - MS_LOG(ERROR) << "The pointer ptr_ is null!"; - return false; - } - - if (host_ptr == ptr_) { - MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; - return true; - } - - if (type == type_id_) { - auto ret_code = memcpy_s(host_ptr, size, ptr_, size_); - if (ret_code != EOK) { - MS_LOG(ERROR) << "Failed to copy tensor!"; - return false; - } - } else if (type == kNumberTypeFloat16) { - FloatToHalf(host_ptr, ptr_, size / 2); - } else if (type == kNumberTypeFloat64) { - FloatToDouble(host_ptr, ptr_, size / sizeof(double)); - } else { - MS_LOG(ERROR) << "Types not match. 
Device type: " << TypeIdLabel(type_id_) << ", host type: " << TypeIdLabel(type) - << "!"; - return false; - } - return true; -} - -bool CPUDeviceAddress::SyncHostToDevice(const std::vector & /*shape*/, size_t size, TypeId type, - const void *host_ptr) const { - if (type == kNumberTypeFloat16) { - HalfToFloat(ptr_, host_ptr, size / 2); - } else if (type == kNumberTypeFloat64) { - DoubleToFloat(ptr_, host_ptr, size / sizeof(double)); - } - return true; -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_device_address.h b/mindspore/ccsrc/device/cpu/cpu_device_address.h deleted file mode 100644 index a041567f47..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_device_address.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ - -#include -#include -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace cpu { -class CPUDeviceAddress : public DeviceAddress { - public: - CPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - - CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - - ~CPUDeviceAddress() override = default; - - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; - DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc deleted file mode 100644 index f46d10ed82..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "device/cpu/cpu_kernel_runtime.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "utils/profile.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "session/session_basic.h" -#include "operator/ops.h" - -namespace mindspore { -namespace device { -namespace cpu { -const size_t INIT_NODE_REF = 1; -namespace { -TypeId GetCPUSupportOutputTypeId(const TypeId type_id) { - TypeId support_type_id = type_id; - if (type_id == kNumberTypeUInt32) { - support_type_id = kNumberTypeInt32; - } - if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 || - type_id == kNumberTypeFloat64) { - support_type_id = kNumberTypeFloat32; - } - if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) { - MS_LOG(EXCEPTION) << "Check output type failed."; - } - return support_type_id; -} -} // namespace - -void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { - AssignValueNodeAddress(kernel_graph); - AssignInputNodeAddress(kernel_graph); - AssignKernelOutputAddress(kernel_graph); - resource_manager_.MemPlan(kernel_graph); - resource_manager_.MemMalloc(kernel_graph); -} - -void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - size_t type_size = sizeof(float); - for (auto &item_node : kernel_graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(item_node); - if (item_node->isa()) { - auto value_node = item_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto node_value = value_node->value(); - MS_EXCEPTION_IF_NULL(node_value); - if (!node_value->isa()) { - continue; - } - auto tensor = node_value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - std::vector data_shape = tensor->shape(); - size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); - DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); - MS_EXCEPTION_IF_NULL(address); - if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { - address->ptr_ = tensor->data_c(); - } else { - address->ptr_ = resource_manager_.MemMalloc(tensor_size); - if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "Value node sync host to device failed!"; - } - } - address->ref_count_ = INIT_NODE_REF; - AnfAlgo::SetOutputAddr(address, 0, item_node.get()); - } - } -} - -void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - size_t type_size = sizeof(float); - for (auto &item : kernel_graph->inputs()) { - MS_EXCEPTION_IF_NULL(item); - if (item->isa()) { - auto output_num = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_num; index++) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - std::vector fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index); - size_t tensor_size = - fmt_shape.empty() ? 
type_size - : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies()); - auto format = AnfAlgo::GetOutputFormat(item, index); - auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id); - AnfAlgo::SetOutputAddr(address, index, item.get()); - } - } - } -} - -void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto kernels = kernel_graph->execution_order(); - for (auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i, - kernel.get()); - } - auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_sizes.size(); ++i) { - AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32), - i, kernel.get()); - } - } -} - -DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, - std::vector *need_sync_outputs) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(bound_addresses); - MS_EXCEPTION_IF_NULL(need_sync_outputs); - size_t output_size = AnfAlgo::GetOutputTensorNum(node); - if (index >= output_size) { - MS_LOG(EXCEPTION) << "Invalid input index " << index; - } - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - std::vector temp_shape; - (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); - type_id = GetCPUSupportOutputTypeId(type_id); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); - MS_EXCEPTION_IF_NULL(tensor); - if (bound_addresses->find(address) != bound_addresses->end()) { - tensor->set_device_address(address); - need_sync_outputs->emplace_back(tensor); - } else { - address->ptr_ = tensor->data_c(); - address->ref_count_ = INIT_NODE_REF; - (void)bound_addresses->insert(address); - } - tensor->set_dirty(false); - return tensor; -} - -BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, - std::vector *need_sync_outputs) { - auto &input_node = kernel_with_index.first; - auto index = kernel_with_index.second; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa()) { - auto node = input_node->cast(); - MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) { - VectorRef ret; - for (size_t i = 1; i < node->inputs().size(); i++) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); - auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); - ret.push_back(out); - } - return ret; - } - return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); - } else if 
(input_node->isa() || input_node->isa()) { - auto iter = input_map.find(input_node.get()); - if (iter != input_map.end()) { - return iter->second; - } - } - return BaseRef(); -} - -void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, - const std::vector &inputs, VectorRef *outputs, - std::vector *need_sync_outputs) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(outputs); - // bind input ptr - auto &input_nodes = kernel_graph->inputs(); - if (input_nodes.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; - } - std::unordered_map input_map; - size_t input_idx = 0; - for (auto &item : input_nodes) { - MS_EXCEPTION_IF_NULL(item); - input_map[item.get()] = inputs[input_idx]; - if (item->isa()) { - auto address = AnfAlgo::GetMutableOutputAddr(item, 0); - auto tensor = inputs[input_idx]; - auto tensor_address = tensor->device_address(); - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(tensor); - if (tensor_address != nullptr && tensor_address != address) { - (void)tensor->data_sync(); - } - std::vector data_shape = tensor->shape(); - size_t tensor_size = - std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); - if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { - address->ptr_ = tensor->data_c(); - } else { - address->ptr_ = resource_manager_.MemMalloc(tensor_size); - if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; - } - tensor->set_dirty(true); - } - address->ref_count_ = INIT_NODE_REF; - tensor->set_device_address(address); - } - input_idx++; - } - // new output and bind ptr - std::set bound_addresses; - auto output_nodes = kernel_graph->outputs(); - for (const auto &item : output_nodes) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); - auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); - outputs->push_back(std::move(out)); - } -} - -void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector *input_list) { - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(input_list); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - if (address->ptr_ == nullptr) { - address->ptr_ = resource_manager_.MemMalloc(address->size_); - } - MS_EXCEPTION_IF_NULL(address->ptr_); - input->addr = address->ptr_; - input->size = address->size_; - input_list->push_back(input); -} - -void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - resource_manager_.IncreaseSummaryRefCount(summary_outputs); -} - -void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - resource_manager_.DecreaseSummaryRefCount(summary_outputs); -} - -bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - resource_manager_.IncreaseAddressRefCount(kernel_graph); - - auto kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { -#ifdef ENABLE_PROFILE - double start_time = GetTime(); -#endif - std::vector kernel_inputs; - std::vector kernel_workspaces; - std::vector kernel_outputs; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 
i).get(); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_inputs); - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get(); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_outputs); - } - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(device_address); - AddRuntimeAddress(device_address, &kernel_workspaces); - } - auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0); - resource_manager_.DecreaseAddressRefCount(kernel); - if (!ret) { - MS_LOG(EXCEPTION) << "Launch kernel failed."; - } -#ifdef ENABLE_PROFILE - double cost_time = GetTime() - start_time; - MS_LOG(INFO) << "cpu kernel: " << kernel->fullname_with_scope() << " costs " << cost_time * 1e6 << " us"; -#endif - } - return true; -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h deleted file mode 100644 index 354d2922c2..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ - -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "session/kernel_graph.h" -#include "session/session_basic.h" -#include "device/cpu/cpu_resource_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/any.h" -namespace mindspore { -namespace device { -namespace cpu { -class CPUKernelRuntime : public KernelRuntime { - public: - CPUKernelRuntime() = default; - ~CPUKernelRuntime() override = default; - - bool Init() override { return true; } - bool Run(session::KernelGraph *graph) override; - void AssignKernelAddress(session::KernelGraph *kernel_graph); - void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, - VectorRef *outputs, std::vector *need_sync_outputs); - void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - - protected: - bool SyncStream() override { return true; }; - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - - private: - tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, - std::vector *need_sync_outputs); - - BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, - std::vector *need_sync_outputs); - void AssignValueNodeAddress(session::KernelGraph *kernel_graph); - void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); - void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); - void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); - CPUResourceManager resource_manager_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc deleted file mode 100644 index c69ef35305..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.cc +++ /dev/null @@ -1,174 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "device/cpu/cpu_resource_manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace cpu { -CPUResourceManager::~CPUResourceManager() { MemFree(); } - -void CPUResourceManager::MemFree() { - if (mem_ptr_ != nullptr) { - free(mem_ptr_); - mem_ptr_ = nullptr; - mem_size_ = 0; - } - - for (auto &&iter : dynamic_mem_) { - free(iter.first); - } - dynamic_mem_.clear(); -} - -void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { - mem_plan_.MemPlan(graph); - size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph); - if (graph_mem_size > mem_size_) { - MemFree(); - mem_ptr_ = reinterpret_cast(malloc(graph_mem_size)); - if (mem_ptr_ != nullptr) { - mem_size_ = graph_mem_size; - dynamic_malloc_ = false; - } else { - MS_LOG(INFO) << "Switch to dynamic malloc"; - dynamic_malloc_ = true; - } - } -} - -void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) { - if (dynamic_malloc_) { - return; - } - mem_plan_.MemAssign(graph, mem_ptr_); -} - -void *CPUResourceManager::MemMalloc(size_t mem_size) { - void *ptr = malloc(mem_size); - if (ptr != nullptr) { - memset_s(ptr, mem_size, 0, mem_size); - dynamic_mem_[ptr] = mem_size; - return ptr; - } else { - MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size; - } -} - -void CPUResourceManager::MemFree(void *ptr) { - auto iter = dynamic_mem_.find(ptr); - if (iter != dynamic_mem_.end()) { - (void)dynamic_mem_.erase(iter); - free(ptr); - } -} - -void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - if (!dynamic_malloc_) { - return; - } - - if (summary_outputs.empty()) { - return; - } - - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } -} - -void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { - if (!dynamic_malloc_) { - return; - } - - if (summary_outputs.empty()) { - return; - } - - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetMutableOutputAddr(node, index); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } -} - -void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { - if (!dynamic_malloc_) { - return; - } - MS_EXCEPTION_IF_NULL(graph); - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_++; - } - } -} - -void CPUResourceManager::DecreaseAddressRefCount(const AnfNodePtr &kernel) { - if (!dynamic_malloc_) { - return; - } - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for 
(size_t i = 0; i < input_num; ++i) { - auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - address->ref_count_--; - if (address->ref_count_ == 0 && address->ptr_ != nullptr) { - MemFree(address->ptr_); - address->ptr_ = nullptr; - } - } -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/device/cpu/cpu_resource_manager.h deleted file mode 100644 index d130241464..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_resource_manager.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "session/session_basic.h" -#include "device/device_address.h" -#include "device/cpu/cpu_simple_mem_plan.h" -namespace mindspore { -namespace device { -namespace cpu { -class CPUResourceManager { - public: - CPUResourceManager() = default; - ~CPUResourceManager(); - - void MemPlan(const session::KernelGraph *graph); - void MemMalloc(const session::KernelGraph *graph); - void IncreaseAddressRefCount(const session::KernelGraph *graph); - void DecreaseAddressRefCount(const AnfNodePtr &kernel); - void *MemMalloc(size_t mem_size); - void MemFree(void *ptr); - void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); - - private: - void MemFree(); - CPUSimpleMemPlan mem_plan_; - - size_t mem_size_{0}; - uint8_t *mem_ptr_{nullptr}; - bool dynamic_malloc_{false}; - std::unordered_map dynamic_mem_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ diff --git a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc b/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc deleted file mode 100644 index e6cb6ee53a..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc +++ /dev/null @@ -1,118 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
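CPUResourceManager above combines two strategies: one contiguous block sized by the planner, with per-tensor malloc as a fallback when the big allocation fails. A condensed sketch of that policy under hypothetical names (ArenaWithFallback is not a MindSpore class):

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <unordered_map>

class ArenaWithFallback {
 public:
  ~ArenaWithFallback() { Reset(); }

  // Grow the contiguous block to the planned size; on malloc failure,
  // record that allocations must fall back to per-tensor mallocs.
  void Plan(size_t planned_size) {
    if (planned_size <= size_) return;
    Reset();
    base_ = static_cast<uint8_t *>(std::malloc(planned_size));
    dynamic_ = (base_ == nullptr);
    size_ = dynamic_ ? 0 : planned_size;
  }

  bool dynamic_mode() const { return dynamic_; }
  uint8_t *base() const { return base_; }

  // Dynamic path: zero-filled per-tensor allocation, tracked for Free/Reset.
  // (The deleted code uses memset_s; std::memset is the portable analog.)
  void *Alloc(size_t n) {
    void *p = std::malloc(n);
    if (p != nullptr) {
      std::memset(p, 0, n);
      dynamic_mem_[p] = n;
    }
    return p;
  }

  void Free(void *p) {
    if (dynamic_mem_.erase(p) != 0) std::free(p);
  }

 private:
  void Reset() {
    std::free(base_);
    base_ = nullptr;
    size_ = 0;
    for (auto &kv : dynamic_mem_) std::free(kv.first);
    dynamic_mem_.clear();
  }

  uint8_t *base_ = nullptr;
  size_t size_ = 0;
  bool dynamic_ = false;
  std::unordered_map<void *, size_t> dynamic_mem_;
};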
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/cpu/cpu_simple_mem_plan.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace cpu { -void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - size_t total_mem_size = 0; - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); - MS_EXCEPTION_IF_NULL(kernel_with_index.first); - if (kernel_with_index.first->isa()) { - continue; - } - auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto address = AnfAlgo::GetOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - total_mem_size += address->size_; - } - } - } - graph_mem_size_[graph] = total_mem_size; -} - -size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { - auto iter = graph_mem_size_.find(graph); - if (iter != graph_mem_size_.end()) { - return iter->second; - } - return 0; -} - -void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(base_ptr); - uint8_t *mem_ptr = base_ptr; - auto kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel); - for (size_t i = 0; i < input_num; ++i) { - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); - MS_EXCEPTION_IF_NULL(kernel_with_index.first); - if (kernel_with_index.first->isa()) { - continue; - } - auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t i = 0; i < output_num; ++i) { - auto address = AnfAlgo::GetMutableOutputAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if (address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { - auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); - MS_EXCEPTION_IF_NULL(address); - if 
(address->ptr_ == nullptr) { - address->ptr_ = mem_ptr; - mem_ptr = mem_ptr + address->size_; - } - } - } -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h b/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h deleted file mode 100644 index 7633ef3f45..0000000000 --- a/mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ -#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ - -#include -#include -#include "session/kernel_graph.h" -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace cpu { -class CPUSimpleMemPlan { - public: - CPUSimpleMemPlan() = default; - ~CPUSimpleMemPlan() = default; - - void MemPlan(const session::KernelGraph *graph); - void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); - size_t GetGraphMemSize(const session::KernelGraph *graph) const; - - private: - std::unordered_map graph_mem_size_; -}; -} // namespace cpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc deleted file mode 100644 index 9d72bcab89..0000000000 --- a/mindspore/ccsrc/device/cpu/kernel_select_cpu.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
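CPUSimpleMemPlan above is a two-pass bump allocator: MemPlan sums the sizes of every unbound input, output, and workspace, and MemAssign walks the kernels again in the same order, handing out offsets into one block. Stripped to its core (Buffer is a hypothetical stand-in for a device address):

#include <cstdint>
#include <cstdlib>
#include <vector>

struct Buffer {
  size_t size;
  void *ptr = nullptr;
};

// Pass 1: only buffers that nothing has bound yet need planned memory.
size_t PlanTotal(const std::vector<Buffer> &bufs) {
  size_t total = 0;
  for (const auto &b : bufs)
    if (b.ptr == nullptr) total += b.size;
  return total;
}

// Pass 2: must visit the buffers in exactly the same order as pass 1,
// bumping one pointer through the pre-allocated block.
void Assign(std::vector<Buffer> *bufs, uint8_t *base) {
  uint8_t *cur = base;
  for (auto &b : *bufs)
    if (b.ptr == nullptr) {
      b.ptr = cur;
      cur += b.size;
    }
}

int main() {
  std::vector<Buffer> bufs{{64}, {128}, {32}};
  auto *base = static_cast<uint8_t *>(std::malloc(PlanTotal(bufs)));
  if (base == nullptr) return 1;
  Assign(&bufs, base);
  std::free(base);
  return 0;
}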
- */ - -#include "device/cpu/kernel_select_cpu.h" - -#include -#include -#include - -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace device { -namespace cpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; -using mindspore::kernel::KernelBuildInfo; -namespace { -bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { - auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() || input_node->isa()) { - return true; - } - return false; -} - -void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector &input_not_cnode_indexes, - const CNodePtr kernel_node) { - for (auto &input_index : input_not_cnode_indexes) { - auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; - MS_EXCEPTION_IF_NULL(input_node); - std::vector output_types; - output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetOutputsFormat({kOpFormat_DEFAULT}); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); - } -} - -void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector *input_formats, - std::vector *input_types, std::vector *input_no_cnode_indexes) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - TypeId dtype = kTypeUnknown; - if (IsInputNotCNode(kernel_node, input_index)) { - input_no_cnode_indexes->emplace_back(input_index); - dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - } else { - dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); - } - input_formats->emplace_back(kOpFormat_DEFAULT); - input_types->emplace_back(dtype); - } -} - -void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, - std::vector *output_formats, std::vector *output_types) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); - auto dtype = kernel_attr.GetOutputAttr(output_index).first; - output_types->emplace_back(dtype); - } -} - -bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector &input_formats, - const std::vector &input_types, - const std::vector &input_not_cnode_indexes) { - if (kernel_attr.GetInputSize() != input_types.size()) { - MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); - return false; - } - auto input_num = input_types.size(); - for (size_t i = 0; i < input_num; ++i) { - bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), - [i](size_t index) { return index == i; }); - bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); - if (have_cnode_input && is_not_cnode_idx) { - continue; - } - if (kernel_attr.GetInputAttr(i).first != input_types[i]) { - MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first - << ", actual input dtype:" << input_types[i]; - return false; - } - if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { - MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second - << ", actual input format:" 
<< input_formats[i]; - return false; - } - } - return true; -} - -void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { - MS_EXCEPTION_IF_NULL(kernel_attr); - TypeId input_dtype = kernel_attr->GetInputAttr(0).first; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 1; i < input_num; ++i) { - kernel_attr->AddInputAttr(input_dtype); - } - - TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t i = 1; i < output_num; ++i) { - kernel_attr->AddOutputAttr(output_dtype); - } -} -} // namespace - -void SetKernelInfo(const CNodePtr &kernel_node) { - std::vector input_formats; - std::vector input_types; - std::vector input_not_cnode_indexes; - std::vector output_formats; - std::vector output_types; - - MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); - GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); - - auto kernel_attrs = - kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); - - for (size_t index = 0; index < kernel_attrs.size(); ++index) { - auto kernel_attr = kernel_attrs[index]; - if (kernel_attr.GetAllSame()) { - ExpandKernelAttr(kernel_node, &kernel_attr); - } - if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (kernel_attr.GetOutputSize() != output_num) { - MS_LOG(DEBUG) << "Output num is not equal!"; - continue; - } - MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; - GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); - UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); - for (auto &input_index : input_not_cnode_indexes) { - input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; - } - break; - } - } - - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - builder->SetInputsFormat(input_formats); - builder->SetInputsDeviceType(input_types); - builder->SetOutputsFormat(output_formats); - builder->SetOutputsDeviceType(output_types); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); -} -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc deleted file mode 100644 index 9b06c0a40a..0000000000 --- a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.cc +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
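SetKernelInfo above selects a kernel by matching the node's input dtypes and formats against each registered KernelAttr, first expanding "all same" attrs to the node's arity. A reduced sketch of that matching loop, dtype-only and with invented names:

#include <optional>
#include <vector>

enum class TypeId { kFloat32, kInt32 };

struct KernelAttr {
  std::vector<TypeId> inputs;  // one entry per input, unless all_same
  bool all_same = false;       // a single entry stands for any arity
};

std::optional<KernelAttr> Select(const std::vector<KernelAttr> &attrs, const std::vector<TypeId> &node_inputs) {
  for (KernelAttr attr : attrs) {  // copy: expansion must not mutate the registry
    if (attr.all_same && !attr.inputs.empty())
      attr.inputs.assign(node_inputs.size(), attr.inputs.front());
    if (attr.inputs == node_inputs) return attr;  // first match wins
  }
  return std::nullopt;
}

int main() {
  std::vector<KernelAttr> attrs{{{TypeId::kFloat32}, /*all_same=*/true}};
  return Select(attrs, {TypeId::kFloat32, TypeId::kFloat32}) ? 0 : 1;
}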
- */ -#include "device/cpu/mpi/mpi_adapter.h" -#ifdef ENABLE_MPI -#include -#include -#include "pybind11/pybind11.h" -#endif // ENABLE_MPI -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace cpu { -std::shared_ptr MPIAdapter::instance_ = nullptr; -std::shared_ptr MPIAdapter::Instance() { - if (instance_ == nullptr) { - MS_LOG(DEBUG) << "Create new mpi adapter instance."; - instance_.reset(new (std::nothrow) MPIAdapter()); - } - return instance_; -} - -#ifdef ENABLE_MPI - -#define RAISE_EXCEPTION(message) \ - { \ - std::ostringstream oss; \ - oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message; \ - pybind11::pybind11_fail(oss.str()); \ - } - -#define RAISE_EXCEPTION_WITH_PARAM(message, param) \ - { \ - std::ostringstream oss; \ - oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param; \ - pybind11::pybind11_fail(oss.str()); \ - } - -namespace { -MPI_Op GetMpiOp(const std::string &op_type) { - if (op_type == "sum") { - return MPI_SUM; - } else if (op_type == "max") { - return MPI_MAX; - } else if (op_type == "min") { - return MPI_MIN; - } else if (op_type == "prod") { - return MPI_PROD; - } - - RAISE_EXCEPTION_WITH_PARAM("unsupport op_type: ", op_type); - return MPI_SUM; -} - -int GetScatterIndex(int rankid, const std::vector &ranks_group) { - int scatter_index = -1; - for (size_t i = 0; i < ranks_group.size(); ++i) { - if (ranks_group[i] == rankid) { - scatter_index = static_cast(i); - break; - } - } - if (scatter_index == -1) { - RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rankid); - } - return scatter_index; -} -} // namespace - -MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); } - -MPIAdapter::~MPIAdapter() { - int finalized; - MPI_Finalized(&finalized); - if (finalized != 0) { - return; - } - - for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) { - MPI_Group_free(&iter->second); - } - ranks_group_.clear(); - if (comm_group_world_ != MPI_GROUP_NULL) { - MPI_Group_free(&comm_group_world_); - comm_group_world_ = MPI_GROUP_NULL; - } - MPI_Finalize(); -} - -void MPIAdapter::Init() { - static bool init = false; - if (init) { - return; - } - - int init_flag = 0; - if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { - RAISE_EXCEPTION("Check mpi initialized fail!"); - } - if (init_flag == 0) { - auto ret = MPI_Init(nullptr, nullptr); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION("Failed to init mpi!"); - } - } - - MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_); - if (comm_group_world_ == MPI_GROUP_NULL) { - RAISE_EXCEPTION("comm_group_world_ init fail!"); - } - auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION("Failed to init mpi rank id!"); - } - - ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size!rankid:", rank_id_) - } - init = true; -} - -MPI_Group MPIAdapter::AddGroup(const std::vector &ranks) { - if (ranks.size() > static_cast(rank_size_) || ranks.empty()) { - RAISE_EXCEPTION_WITH_PARAM("input rank size:", ranks.size()); - } - - if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) { - RAISE_EXCEPTION_WITH_PARAM("local rankid does not in the input rank group!local rank id:", rank_id_); - } - std::lock_guard lock(group_mutex_); - auto iter = ranks_group_.find(ranks); - if (iter != ranks_group_.end()) { - return iter->second; - } - const auto ranks_size = ranks.size(); - std::vector 
ranks_input(ranks_size, 0); - for (size_t i = 0; i < ranks_size; ++i) { - ranks_input[i] = ranks[i]; - } - - MPI_Group group = MPI_GROUP_NULL; - MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) - } - - ranks_group_[ranks] = group; - return group; -} - -bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, - const std::string &op_type) { - if (ranks_group.empty()) { - RAISE_EXCEPTION("input rank group is empty!"); - return false; - } - - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_) - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); - } - std::vector receive_count(ranks_group.size(), 0); - for (size_t i = 0; i < ranks_group.size(); ++i) { - receive_count[i] = data_num; - } - - auto op = GetMpiOp(op_type); - auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm); - bool result = true; - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret); - result = false; - } - - ret = MPI_Comm_free(&comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret); - } - return result; -} - -bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, - size_t output_size, const std::string &op_type, float *output) { - int scatter_index = GetScatterIndex(rank_id_, ranks_group); - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_); - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); - } - - MPI_Win window; - auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! 
ret = ", ret); - } - MPI_Win_fence(0, window); - for (size_t i = 0; i < ranks_group.size(); ++i) { - int remote_rank = ranks_group[i]; - if (rank_id_ == remote_rank) { - continue; - } - auto op = GetMpiOp(op_type); - ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, - input_data_num, MPI_FLOAT, op, window); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret); - } - } - MPI_Win_fence(0, window); - if (output != nullptr) { - auto data_size = input_data_num * sizeof(float); - if (output_size < data_size) { - std::ostringstream exception_msg; - exception_msg << "output buffer size " << output_size << " < input size " << data_size; - RAISE_EXCEPTION(exception_msg.str()) - } - auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); - if (copy_ret != 0) { - RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret); - } - } - MPI_Win_free(&window); - MPI_Comm_free(&comm); - return true; -} - -bool MPIAdapter::AllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num) { - if (ranks_group.empty()) { - RAISE_EXCEPTION("input rank group is empty!"); - return false; - } - auto group = AddGroup(ranks_group); - if (group == MPI_GROUP_NULL) { - RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_); - } - MPI_Comm comm; - MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); - if (comm == MPI_COMM_NULL) { - RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_); - } - auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret); - } - ret = MPI_Comm_free(&comm); - if (ret != MPI_SUCCESS) { - RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret); - } - return true; -} -#endif // ENABLE_MPI -} // namespace cpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.cc b/mindspore/ccsrc/device/gpu/blocking_queue.cc deleted file mode 100644 index 3b5e75f551..0000000000 --- a/mindspore/ccsrc/device/gpu/blocking_queue.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/gpu/blocking_queue.h" -#include -#include "device/gpu/gpu_common.h" -#include "common/utils.h" - -namespace mindspore { -namespace device { -GpuQueue::GpuQueue(void *addr, const std::vector &shape, const size_t &capacity) - : buffer_(addr), head_(0), tail_(0), shape_(shape), len_(0), capacity_(capacity), stream_(0), node_info_(nullptr) { - CHECK_CUDA_RET_WITH_ERROR(cudaStreamCreate(&stream_), "Cuda Create Stream Failed"); - node_info_ = std::make_unique(capacity); - for (auto item : shape) { - len_ += item; - } -} - -GpuQueue::~GpuQueue() { buffer_ = nullptr; } - -BlockQueueStatus_T GpuQueue::Push(const std::vector &data) { - int offset = 0; - for (size_t i = 0; i < data.size(); i++) { - auto item = data[i]; - if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) { - MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_; - return ERROR_INPUT; - } - - void *addr = reinterpret_cast(buffer_) + tail_ * len_ + offset; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_), - "Cuda Memcpy Error"); - - offset += item.data_len_; - } - - node_info_[tail_].event_.reset(new cudaEvent_t()); - CHECK_CUDA_RET_WITH_ERROR(cudaEventCreate(&(*(node_info_[tail_].event_))), "Cuda Create Event Failed"); - node_info_[tail_].data_ = data; - tail_ = (tail_ + 1) % (capacity_); - return SUCCESS; -} - -BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const { - CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Syn Failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed"); - *addr = (unsigned char *)buffer_ + head_ * len_; - *len = len_; - - for (auto item : node_info_[head_].data_) { - host_release_(item.data_ptr_); - } - return SUCCESS; -} - -BlockQueueStatus_T GpuQueue::Pop() { - head_ = (head_ + 1) % (capacity_); - return SUCCESS; -} - -bool GpuQueue::Destroy() { - if (stream_ != nullptr) { - auto ret = cudaStreamDestroy(stream_); - if (ret == cudaSuccess) { - return true; - } else { - return false; - } - } else { - return true; - } -} - -BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector &shape, const size_t &capacity) { - if (addr == nullptr) { - MS_LOG(ERROR) << "addr is nullptr"; - return INTERNAL_ERROR; - } - queue_ = std::make_shared(addr, shape, capacity); - return SUCCESS; -} - -void BlockingQueue::RegisterRelease(const std::function &func) { queue_->RegisterRelease(func); } - -BlockQueueStatus_T BlockingQueue::Push(const std::vector &data, unsigned int timeout_in_sec) { - std::unique_lock locker(mutex_); - if (queue_->IsFull()) { - if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) { - return TIMEOUT; - } - } - auto ret = queue_->Push(data); - if (ret) { - return ret; - } - not_empty_cond_.notify_one(); - return SUCCESS; -} - -BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) { - std::unique_lock locker(mutex_); - bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); }); - if (!timeout) { - return TIMEOUT; - } - - return queue_->Front(addr, len); -} - -BlockQueueStatus_T BlockingQueue::Pop() { - std::unique_lock locker(mutex_); - not_empty_cond_.wait(locker, [this] { return !queue_->IsEmpty(); }); - auto ret = queue_->Pop(); - if (ret) { - return ret; - } - not_full_cond_.notify_one(); - return SUCCESS; -} - -bool 
BlockingQueue::Destroy() { - if (queue_ != nullptr) { - return queue_->Destroy(); - } else { - return true; - } -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/cuda_common.h b/mindspore/ccsrc/device/gpu/cuda_common.h deleted file mode 100644 index b79ba8bc28..0000000000 --- a/mindspore/ccsrc/device/gpu/cuda_common.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ - -#include -#include "device/gpu/gpu_device_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -class CudaCommon { - public: - inline int threads_num() const { return threads_per_block_; } - inline int major_sm() const { return major_sm_; } - inline int blocks_num(const int total_threads) const { - return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); - } - - static CudaCommon &GetInstance() { - static CudaCommon instance; - return instance; - } - - private: - CudaCommon() { - uint32_t device_id = GPUDeviceManager::GetInstance().cur_device_id(); - cudaDeviceProp prop; - (void)cudaGetDeviceProperties(&prop, device_id); - threads_per_block_ = prop.maxThreadsPerBlock; - max_blocks_ = prop.multiProcessorCount; - major_sm_ = prop.major; - } - ~CudaCommon() = default; - CudaCommon(const CudaCommon &) = delete; - CudaCommon &operator=(const CudaCommon &) = delete; - - int max_blocks_; - int threads_per_block_; - int major_sm_; -}; -#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) -#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() -#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() -#define MINIUM_SM 6 -#define RECOMMEND_SM 7 -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ diff --git a/mindspore/ccsrc/device/gpu/cuda_driver.cc b/mindspore/ccsrc/device/gpu/cuda_driver.cc deleted file mode 100644 index 0dee53df64..0000000000 --- a/mindspore/ccsrc/device/gpu/cuda_driver.cc +++ /dev/null @@ -1,231 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
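BlockingQueue above is a bounded producer/consumer queue: one mutex plus not_full/not_empty condition variables, with timeouts on both sides. The same host-side logic, self-contained (std::queue stands in for the CUDA ring buffer); the predicate form of wait_for used here also absorbs spurious wake-ups, which the deleted Push's bare wait_for did not:

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>

template <typename T>
class BoundedQueue {
 public:
  explicit BoundedQueue(size_t capacity) : capacity_(capacity) {}

  bool Push(T item, std::chrono::seconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    if (!not_full_.wait_for(lk, timeout, [this] { return q_.size() < capacity_; }))
      return false;  // timed out while full
    q_.push(std::move(item));
    not_empty_.notify_one();
    return true;
  }

  bool Pop(T *out, std::chrono::seconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    if (!not_empty_.wait_for(lk, timeout, [this] { return !q_.empty(); }))
      return false;  // timed out while empty
    *out = std::move(q_.front());
    q_.pop();
    not_full_.notify_one();
    return true;
  }

 private:
  std::mutex m_;
  std::condition_variable not_full_, not_empty_;
  std::queue<T> q_;
  size_t capacity_;
};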
- */ - -#include "device/gpu/cuda_driver.h" -#include -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace device { -namespace gpu { -size_t CudaDriver::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - size_t retreat_count = 0; - auto ret = cudaMalloc(reinterpret_cast(addr), size); - // If free memory is not enough, then retry with mem_malloc_retry_rate_. - while (ret == cudaErrorMemoryAllocation) { - size = FloatToSize(size * mem_malloc_retry_rate_); - size = (size / mem_malloc_align_size_) * mem_malloc_align_size_; - ret = cudaMalloc(reinterpret_cast(addr), size); - retreat_count++; - if (retreat_count > mem_malloc_retry_conut_max_) { - break; - } - } - - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMalloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return size; -} - -bool CudaDriver::FreeDeviceMem(const DeviceMemPtr &addr) { - auto ret = cudaFree(addr); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaFree failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -size_t CudaDriver::AllocHostPinnedMem(size_t size, void **addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "The memory allocate size is 0"; - } - auto ret = cudaHostAlloc(addr, size, cudaHostAllocDefault); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaHostAlloc failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return size; -} - -void CudaDriver::FreeHostPinnedMem(void *addr) { - if (addr) { - auto ret = cudaFreeHost(addr); - if (ret != cudaSuccess) { - MS_LOG(EXCEPTION) << "cudaFreeHost failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - } - } -} - -bool CudaDriver::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) { - auto ret = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) { - auto ret = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) { - auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, - DeviceStream stream) { - auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -size_t CudaDriver::total_mem_size() { - size_t free; - size_t total; - auto ret = cudaMemGetInfo(&free, &total); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - return total; -} - -size_t CudaDriver::free_mem_size() { - size_t free; - 
size_t total; - auto ret = cudaMemGetInfo(&free, &total); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return 0; - } - - return free; -} - -bool CudaDriver::CreateStream(DeviceStream *stream) { - auto ret = cudaStreamCreateWithFlags(reinterpret_cast(stream), cudaStreamNonBlocking); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamCreate failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::DestroyStream(const DeviceStream &stream) { - auto ret = cudaStreamDestroy((cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::SyncStream(const DeviceStream &stream) { - auto ret = cudaStreamSynchronize((cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaStreamSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::CreateEvent(DeviceEvent *event, unsigned int flag) { - auto ret = cudaEventCreateWithFlags(reinterpret_cast(event), flag); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventCreateWithFlags failed, ret[" << static_cast(ret) << "], " - << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::DestroyEvent(const DeviceEvent &event) { - auto ret = cudaEventDestroy((cudaEvent_t)event); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventDestroy failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::RecordEvent(DeviceEvent event, DeviceStream stream) { - auto ret = cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventRecord failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::SyncEvent(const DeviceEvent &event) { - auto ret = cudaEventSynchronize((cudaEvent_t)event); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaEventSynchronize failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} - -bool CudaDriver::QueryEvent(const DeviceEvent &event) { - auto ret = cudaEventQuery((cudaEvent_t)event); - if (ret == cudaSuccess) { - return true; - } else if (ret == cudaErrorNotReady) { - return false; - } else { - MS_LOG(ERROR) << "cudaEventQuery failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } -} - -int CudaDriver::device_count() { - int dev_count; - auto ret = cudaGetDeviceCount(&dev_count); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaGetDeviceCount failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - } - return dev_count; -} - -bool CudaDriver::set_current_device(int index) { - auto ret = cudaSetDevice(index); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaSetDevice failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); - return false; - } - return true; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc deleted file mode 100644 index 06497a2e82..0000000000 --- 
a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/collective_fake_init.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace gpu { -void CollectiveFakeInitializer::InitCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } - -void CollectiveFakeInitializer::FinalizeCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc deleted file mode 100644 index d7ab95bbe8..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/collective_init.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -namespace gpu { -CollectiveInitializer &CollectiveInitializer::instance() { - static CollectiveInitializer instance = {}; - return instance; -} - -bool CollectiveInitializer::collective_inited() const { return collective_inited_; } - -const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } - -void CollectiveInitializer::InitCollective() { - void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); - if (handle == nullptr) { - MS_LOG(EXCEPTION) - << "Loading libgpu_collective.so failed. 
Many reasons could cause this:\n1.libgpu_collective.so is not " - "installed.\n2.nccl is not " - "installed or found.\n3.mpi is not installed or found"; - } - auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); - MS_EXCEPTION_IF_NULL(mpi_init_funcptr); - (*mpi_init_funcptr)(); - - CollectiveInitializer::instance().collective_inited_ = true; - CollectiveInitializer::instance().collective_handle_ = handle; -} - -void CollectiveInitializer::FinalizeCollective() { - if (CollectiveInitializer::instance().collective_handle_ != nullptr) { - if (dlclose(CollectiveInitializer::instance().collective_handle_) != 0) { - MS_LOG(EXCEPTION) << "Closing libgpu_collective.so handle failed."; - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc deleted file mode 100644 index 5fb0f74849..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/collective_wrapper.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include "device/gpu/distribution/mpi_wrapper.h" -#include "device/gpu/distribution/nccl_wrapper.h" - -#ifndef EXPORT_WRAPPER -#define EXPORT_WRAPPER __attribute__((visibility("default"))) -#endif - -using MPIWrapper = mindspore::device::gpu::MPIWrapper; -using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; - -extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); } - -extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } - -extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } - -extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream); -} - -extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, cudaStream_t stream) { - return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream); -} - -extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, - ncclDataType_t data_type, ncclRedOp_t reduce_type, - cudaStream_t stream) { - return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream); -} diff --git a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc deleted file mode 100644 index 46b574c575..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the 
Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/distribution/mpi_wrapper.h" - -#include -#include -#include "device/gpu/distribution/nccl_wrapper.h" - -namespace mindspore { -namespace device { -namespace gpu { -MPIWrapper::MPIWrapper() : rank_id_(0), rank_size_(0), local_rank_id_(0) { Init(); } - -MPIWrapper::~MPIWrapper() { - int finalized; - MPI_Finalized(&finalized); - if (finalized == 0) { - MPI_Finalize(); - } -} - -MPIWrapper &MPIWrapper::instance() { - static MPIWrapper instance; - return instance; -} - -int MPIWrapper::local_rank_id() const { return local_rank_id_; } - -void MPIWrapper::Init() { - int initialized; - CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status."); - - if (initialized == 0) { - MPI_Init(nullptr, nullptr); - } - CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id."); - CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size."); - NCCLWrapper::instance().set_rank(rank_id_, rank_size_); - AssignLocalRankId(); - - ncclUniqueId unique_id; - if (rank_id_ == 0) { - unique_id = NCCLWrapper::instance().nccl_unique_id(); - } - CHECK_RET(MPI_Bcast(reinterpret_cast(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), - MPI_SUCCESS, "Failed to broadcast nccl unique id."); - NCCLWrapper::instance().set_nccl_unique_id(unique_id); - return; -} - -void MPIWrapper::AssignLocalRankId() { - char host_name[MAX_HOSTNAME_LEN] = {0}; - CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed."); - size_t host_hash = std::hash()(host_name); - - const int kRankSize = rank_size_; - size_t all_host_hashs[kRankSize]; - all_host_hashs[rank_id_] = host_hash; - CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), - MPI_SUCCESS, "MPI_Allgather host hashs failed."); - for (int global_rank = 0; global_rank < kRankSize; global_rank++) { - if (global_rank == rank_id_) { - break; - } - if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { - local_rank_id_++; - } - } - return; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h deleted file mode 100644 index 6dfedea922..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/mpi_wrapper.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
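CollectiveInitializer and collective_wrapper above split GPU collectives into a dlopen'd shared object exposing extern "C" entry points. The resolution half of that pattern in isolation (the real code keeps the handle open for later dlsym calls instead of closing it immediately):

#include <dlfcn.h>
#include <cstdio>

int main() {
  void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  using InitMPIFunc = void (*)();
  auto init_mpi = reinterpret_cast<InitMPIFunc>(dlsym(handle, "InitMPI"));
  if (init_mpi == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    dlclose(handle);
    return 1;
  }
  init_mpi();       // bootstraps MPI inside the wrapper library
  dlclose(handle);  // demo only; the adapter holds the handle for its lifetime
  return 0;
}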
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ - -#include -#include -#include -#include -#include -#include "device/gpu/distribution/collective_common.h" - -namespace mindspore { -namespace device { -namespace gpu { -class MPIWrapper { - public: - MPIWrapper(MPIWrapper const &) = delete; - MPIWrapper &operator=(const MPIWrapper &) = delete; - static MPIWrapper &instance(); - int local_rank_id() const; - - private: - MPIWrapper(); - ~MPIWrapper(); - void Init(); - void AssignLocalRankId(); - - int rank_id_; - int rank_size_; - int local_rank_id_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ diff --git a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc deleted file mode 100644 index aa4756a69f..0000000000 --- a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
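MPIWrapper::AssignLocalRankId above derives the node-local rank: hash the hostname, allgather all hashes, then count same-host processes with a smaller global rank. The same computation, condensed (the deleted code uses the MPI_IN_PLACE variant; a plain allgather is shown here):

#include <mpi.h>
#include <unistd.h>
#include <functional>
#include <string>
#include <vector>

int LocalRankId(int rank, int size) {
  char host[256] = {0};
  gethostname(host, sizeof(host) - 1);
  size_t my_hash = std::hash<std::string>()(host);

  // Exchange every process's host hash as raw bytes.
  std::vector<size_t> hashes(size);
  MPI_Allgather(&my_hash, static_cast<int>(sizeof(size_t)), MPI_BYTE, hashes.data(),
                static_cast<int>(sizeof(size_t)), MPI_BYTE, MPI_COMM_WORLD);

  int local = 0;
  for (int r = 0; r < rank; ++r)  // same host, smaller global rank
    if (hashes[r] == my_hash) ++local;
  return local;
}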
- */
-
-#include "device/gpu/distribution/nccl_wrapper.h"
-
-namespace mindspore {
-namespace device {
-namespace gpu {
-NCCLWrapper &NCCLWrapper::instance() {
-  static NCCLWrapper instance;
-  return instance;
-}
-
-ncclUniqueId NCCLWrapper::nccl_unique_id() const {
-  ncclUniqueId unique_id;
-  CHECK_RET(ncclGetUniqueId(&unique_id), ncclSuccess, "Failed to create nccl unique id.");
-  return unique_id;
-}
-
-void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; }
-
-void NCCLWrapper::set_rank(int rank_id, int rank_size) {
-  rank_id_ = rank_id;
-  rank_size_ = rank_size;
-}
-
-void NCCLWrapper::InitNCCLComm() {
-  CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
-            "Failed to init nccl communicator.");
-}
-
-ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
-}
-
-ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
-                                    cudaStream_t stream) {
-  return ncclAllGather(input_addr, output_addr, count, data_type, comm_, stream);
-}
-
-ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                        ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream) {
-  return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, comm_, stream);
-}
-} // namespace gpu
-} // namespace device
-} // namespace mindspore
diff --git a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h
deleted file mode 100644
index 5df1e63bb8..0000000000
--- a/mindspore/ccsrc/device/gpu/distribution/nccl_wrapper.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
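Together, MPIWrapper and NCCLWrapper above implement the standard NCCL bootstrap: rank 0 creates the ncclUniqueId, MPI broadcasts it, and every rank joins the communicator. The flow in one function (needs MPI and NCCL headers and libraries; error checks elided for brevity):

#include <mpi.h>
#include <nccl.h>

ncclComm_t BootstrapNccl(int rank, int size) {
  ncclUniqueId id;
  if (rank == 0) ncclGetUniqueId(&id);                      // only rank 0 creates the id
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD);  // share it with everyone
  ncclComm_t comm;
  ncclCommInitRank(&comm, size, id, rank);                  // collective over all ranks
  return comm;
}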
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ - -#include -#include -#include -#include "device/gpu/distribution/collective_common.h" - -namespace mindspore { -namespace device { -namespace gpu { -class NCCLWrapper { - public: - NCCLWrapper(NCCLWrapper const &) = delete; - NCCLWrapper &operator=(const NCCLWrapper &) = delete; - static NCCLWrapper &instance(); - ncclUniqueId nccl_unique_id() const; - void set_nccl_unique_id(ncclUniqueId unique_id); - void set_rank(int rank_id, int rank_size); - void InitNCCLComm(); - ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, cudaStream_t stream); - ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - cudaStream_t stream); - ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, cudaStream_t stream); - - private: - NCCLWrapper() : rank_id_(-1), rank_size_(0) {} - ~NCCLWrapper() = default; - - private: - int rank_id_; - int rank_size_; - ncclUniqueId unique_id_; - ncclComm_t comm_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc b/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc deleted file mode 100644 index 621ba557e5..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/gpu/gpu_buffer_mgr.h" -#include -#include -#include "utils/log_adapter.h" -#include "common/utils.h" - -namespace mindspore { -namespace device { -unsigned int HandleMgr::AllocHandle() { - for (size_t i = 0; i < MAX_HANDLE_NUM; ++i) { - if (!handle_list_[i]) { - handle_list_[i] = true; - return (unsigned int)i; - } - } - return INVALID_HANDLE; -} - -void HandleMgr::FreeHandle(unsigned int handle_id) { - if (handle_id >= MAX_HANDLE_NUM) { - return; - } - handle_list_[handle_id] = false; -} - -GpuBufferMgr &GpuBufferMgr::GetInstance() noexcept { - static GpuBufferMgr instance; - return instance; -} - -BlockQueueStatus_T GpuBufferMgr::Create(unsigned int device_id, const std::string &channel_name, void *addr, - const std::vector &shape, const size_t &capacity) { - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return QUEUE_NOT_EXIST; - } - std::shared_ptr queue = std::make_shared(); - BlockQueueStatus_T rt = queue->Create(addr, shape, capacity); - if (rt != SUCCESS) { - return rt; - } - (void)name_queue_map_.insert(std::make_pair(name, queue)); - init_ = true; - return SUCCESS; -} - -unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, - const std::vector &shape, const std::function func) { - set_device(); - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (!name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return HandleMgr::INVALID_HANDLE; - } - unsigned int handle = handle_mgr_.AllocHandle(); - if (handle == HandleMgr::INVALID_HANDLE) { - MS_LOG(ERROR) << "handle is invalid"; - return HandleMgr::INVALID_HANDLE; - } - (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); - name_queue_map_[name]->RegisterRelease(func); - open_by_dataset_++; - return handle; -} - -unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name, - const std::vector &shape) { - set_device(); - std::string name = std::to_string(device_id) + std::string("_") + channel_name; - if (!name_queue_map_.count(name)) { - MS_LOG(ERROR) << "Queue not exist " << name; - return HandleMgr::INVALID_HANDLE; - } - unsigned int handle = handle_mgr_.AllocHandle(); - if (handle == HandleMgr::INVALID_HANDLE) { - MS_LOG(ERROR) << "handle is invalid"; - return HandleMgr::INVALID_HANDLE; - } - (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name])); - return handle; -} - -void GpuBufferMgr::set_device_id(int device_id) { cur_dev_id_ = device_id; } - -void GpuBufferMgr::set_device() const { - auto ret = cudaSetDevice(cur_dev_id_); - if (ret != cudaSuccess) { - MS_LOG(ERROR) << "cudaSetDevice, ret[" << static_cast(ret) << "]"; - } -} - -BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector &data, - unsigned int timeout_in_sec) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return HANDLE_NOT_EXIST; - } - return iter->second->Push(data, timeout_in_sec); -} - -BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return HANDLE_NOT_EXIST; - } - return iter->second->Front(addr, len); -} - -BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) { - auto iter = handle_queue_map_.find(handle); - if (iter == handle_queue_map_.end()) { - return 
HANDLE_NOT_EXIST;
-  }
-  return iter->second->Pop();
-}
-
-void GpuBufferMgr::Close(unsigned int handle) noexcept {
-  if (!handle_queue_map_.count(handle)) {
-    return;
-  }
-  (void)handle_queue_map_.erase(handle);
-  handle_mgr_.FreeHandle(handle);
-  return;
-}
-
-bool GpuBufferMgr::IsInit() const { return init_; }
-
-bool GpuBufferMgr::IsClosed() const { return closed_; }
-
-bool GpuBufferMgr::Destroy() {
-  for (auto iter = name_queue_map_.begin(); iter != name_queue_map_.end(); ++iter) {
-    std::shared_ptr<BlockingQueue> queue = iter->second;
-    if (queue != nullptr) {
-      if (!queue->Destroy()) {
-        return false;
-      }
-      queue.reset();
-    }
-  }
-  name_queue_map_.clear();
-  return true;
-}
-
-inline bool GpuBufferMgr::isCreated(unsigned int device_id, const std::string &channel_name) {
-  std::string name = std::to_string(device_id) + std::string("_") + channel_name;
-  if (name_queue_map_.count(name) != 0) {
-    return true;
-  }
-  return false;
-}
-
-bool GpuBufferMgr::CloseNotify() {
-  bool result = true;
-  // lock scope
-  {
-    std::lock_guard<std::mutex> lk(close_mutex_);
-    // set closed_ to true so that every dataset retry loop can jump out of its while
-    closed_ = true;
-  }
-
-  // wait for the dataset threads' ack
-  for (int i = 0; i < open_by_dataset_; i++) {
-    if (sema.Wait() == false) {
-      MS_LOG(ERROR) << "time out of receiving signals";
-      result = false;
-    }
-    MS_LOG(DEBUG) << "receive one signal (" << i + 1 << "/" << open_by_dataset_ << ")";
-  }
-  return result;
-}
-
-void GpuBufferMgr::CloseConfirm() { sema.Signal(); }
-}  // namespace device
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h b/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h
deleted file mode 100644
index 5ce4a2cbdc..0000000000
--- a/mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h
+++ /dev/null
@@ -1,139 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "device/gpu/blocking_queue.h" - -#define EXPORT __attribute__((visibility("default"))) - -namespace mindspore { -namespace device { -static const unsigned int MAX_WAIT_TIME_IN_SEC = 60; - -class Semaphore { - public: - explicit Semaphore(int count = 0) : count_(count) {} - - inline void Signal() { - std::unique_lock lock(mutex_); - ++count_; - cv_.notify_one(); - } - - inline bool Wait() { - std::unique_lock lock(mutex_); - while (count_ == 0) { - if (cv_.wait_for(lock, std::chrono::seconds(MAX_WAIT_TIME_IN_SEC)) == std::cv_status::timeout) { - return false; - } - } - --count_; - return true; - } - - private: - std::mutex mutex_; - std::condition_variable cv_; - int count_; -}; - -class HandleMgr { - public: - static const unsigned int MAX_HANDLE_NUM = 32; - static const unsigned int INVALID_HANDLE = 0xffffffffUL; - - unsigned int AllocHandle(); - void FreeHandle(unsigned int); - - private: - bool handle_list_[MAX_HANDLE_NUM]; -}; - -class GpuBufferMgr { - public: - EXPORT GpuBufferMgr() : cur_dev_id_(0), init_(false), closed_(false), open_by_dataset_(0) {} - - EXPORT virtual ~GpuBufferMgr() = default; - - EXPORT static GpuBufferMgr &GetInstance() noexcept; - - EXPORT BlockQueueStatus_T Create(unsigned int device_id, const std::string &channel_name, void *addr, - const std::vector &shape, const size_t &capacity); - - // call for Push thread - EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape, - std::function func); - - // call for Front/Pop thread - EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector &shape); - - EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector &data, - unsigned int timeout_in_sec); - EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len); - EXPORT BlockQueueStatus_T Pop(unsigned int handle); - - EXPORT void set_device_id(int device_id); - - EXPORT void Close(unsigned int handle) noexcept; - - EXPORT bool IsInit() const; - - EXPORT bool IsClosed() const; - - EXPORT bool Destroy(); - - // call for Release GPU Resources - EXPORT bool CloseNotify(); - - // call for dataset send thread - EXPORT void CloseConfirm(); - - private: - void set_device() const; - - int cur_dev_id_; - bool init_; - bool closed_; - std::mutex mutex_; - std::mutex close_mutex_; - // how many queues opened by dataset - int open_by_dataset_; - Semaphore sema; - - HandleMgr handle_mgr_; - - std::map> handle_queue_map_; - std::map> name_queue_map_; - - inline bool isCreated(unsigned int device_id, const std::string &channel_name); - - GpuBufferMgr(const GpuBufferMgr &) = delete; - GpuBufferMgr &operator=(const GpuBufferMgr &) = delete; -}; -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/device/gpu/gpu_device_address.cc deleted file mode 100644 index 401eb9f34e..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_address.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_device_address.h" -#include -#include "device/gpu/gpu_device_manager.h" -#include "utils/log_adapter.h" -#include "device/gpu/gpu_memory_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, TypeId, void *host_ptr) const { - MS_EXCEPTION_IF_NULL(host_ptr); - auto &stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(stream); - auto ret = GPUDeviceManager::GetInstance().SyncStream(stream); - if (!ret) { - MS_LOG(ERROR) << "SyncStream failed"; - return ret; - } - if (size != size_) { - MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_; - return true; - } - return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_); -} - -bool GPUDeviceAddress::SyncHostToDevice(const std::vector &, size_t, TypeId, const void *host_ptr) const { - MS_EXCEPTION_IF_NULL(host_ptr); - auto &stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(stream); - if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) { - MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed"; - return false; - } - return GPUDeviceManager::GetInstance().SyncStream(stream); -} - -GPUDeviceAddress::~GPUDeviceAddress() { - if (ptr_ == nullptr) { - return; - } - if (from_mem_pool_) { - GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_); - ptr_ = nullptr; - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_device_address.h b/mindspore/ccsrc/device/gpu/gpu_device_address.h deleted file mode 100644 index 4074cb6ce9..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_address.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ - -#include -#include -#include "device/device_address.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUDeviceAddress : public DeviceAddress { - public: - GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} - GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) - : DeviceAddress(ptr, size, format, type_id) {} - ~GPUDeviceAddress() override; - - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; - void set_status(DeviceAddressStatus status) { status_ = status; } - DeviceAddressStatus status() const { return status_; } - DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } - - private: - DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc deleted file mode 100644 index 9f5f37c606..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_common.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -#include "device/gpu/gpu_buffer_mgr.h" - -namespace mindspore { -namespace device { -namespace gpu { -void GPUDeviceManager::InitDevice() { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id"); - CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuDNN handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); - CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), - "Failed to set stream for cuBLAS handle."); - CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") -} - -void GPUDeviceManager::ReleaseDevice() { - for (DeviceStream stream : gpu_streams_) { - if (stream != nullptr) { - CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); - } - } - if (cudnn_handle_ != nullptr) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); - } - if (cublas_handle_ != nullptr) { - CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); - } - CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); -} - -bool GPUDeviceManager::CreateStream(DeviceStream *stream) { - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); - gpu_streams_.emplace_back(*stream); - return true; -} - -const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } - -int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } - -bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { - if (!dev_id_init_) { - dev_id_init_ = true; - cur_dev_id_ = device_id; - mindspore::device::GpuBufferMgr::GetInstance().set_device_id(UintToInt(device_id)); - return true; - } else { - MS_LOG(ERROR) << "Device already been set."; - return false; - } -} - -uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } - -bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } - -const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } - -const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } - -bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } - -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { - return CudaDriver::CopyDeviceMemToHost(dst, src, size); -} - -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { - return CudaDriver::CopyHostMemToDevice(dst, src, size); -} - -bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, - DeviceStream stream) const { - return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); -} - -bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, - DeviceStream 
stream) const { - return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h deleted file mode 100644 index b6b630181e..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ - -#include -#include -#include -#include -#include "device/gpu/cuda_driver.h" -#include "device/gpu/gpu_memory_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUDeviceManager { - public: - void InitDevice(); - void ReleaseDevice(); - - int device_count() const; - bool set_cur_device_id(uint32_t device_id); - uint32_t cur_device_id() const; - bool is_device_id_init() const; - - bool CreateStream(DeviceStream *stream); - bool SyncStream(const DeviceStream &stream) const; - const DeviceStream &default_stream() const; - - const cudnnHandle_t &GetCudnnHandle() const; - const cublasHandle_t &GetCublasHandle() const; - - bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; - bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - - bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, DeviceStream stream) const; - bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) const; - - static GPUDeviceManager &GetInstance() { - static GPUDeviceManager instance; - return instance; - } - - private: - GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} - ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager &) = delete; - GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; - - // default CUDA stream used for all the kernels. - DeviceStream default_stream_{nullptr}; - - // all gpu CUDA streams including default_stream_. - std::vector gpu_streams_; - - // handle used for cuDNN kernels. - cudnnHandle_t cudnn_handle_{nullptr}; - - // handle used for cuBLAS kernels. - cublasHandle_t cublas_handle_{nullptr}; - - bool dev_id_init_; - uint32_t cur_dev_id_; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc deleted file mode 100644 index 19d2284510..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_build.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "device/gpu/gpu_kernel_build.h" -#include -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" -#include "kernel/akg/gpu/akg_gpu_kernel_build.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -namespace mindspore { -namespace device { -namespace gpu { -void GpuBuild(const KernelGraphPtr &kernel_graph) { - kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); - MS_EXCEPTION_IF_NULL(bin_map); - bin_map->Initialize(); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - std::string kernel_name = session::AnfRuntimeAlgorithm::GetCNodeName(kernel); - if (kernel_name == prim::kPrimTupleGetItem->name() || kernel_name == prim::kPrimMakeTuple->name() || - kernel_name == prim::kPrimDepend->name() || kernel_name == prim::kPrimStateSetItem->name()) { - continue; - } - - if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { - auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); - if (!gpu_kernel_ptr) { - MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; - } - session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get()); - } else { - auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel); - if (!gpu_kernel_ptr) { - MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed"; - } - if (!gpu_kernel_ptr->Init(kernel)) { - MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed."; - } - session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get()); - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_build.h b/mindspore/ccsrc/device/gpu/gpu_kernel_build.h deleted file mode 100644 index 5770e4d3b1..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_build.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ - -#include -#include "session/kernel_graph.h" -namespace mindspore { -namespace device { -namespace gpu { -void GpuBuild(const std::shared_ptr &kernel_graph); -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc deleted file mode 100644 index 839229be36..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ /dev/null @@ -1,646 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_kernel_runtime.h" -#include "device/gpu/gpu_device_address.h" -#include "device/gpu/cuda_driver.h" -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "device/gpu/distribution/collective_init.h" -#include "utils/convert_utils.h" -#include "utils/context/ms_context.h" -#include "device/kernel_runtime_manager.h" -#include "device/gpu/gpu_common.h" -#include "common/utils.h" -#include "device/gpu/gpu_memory_manager.h" -#include "kernel/common_utils.h" -#include "device/gpu/gpu_memory_copy_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemSwapManager; -using mindspore::device::memswap::SwapKind; -bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } - -bool GPUKernelRuntime::Init() { - if (device_init_ == true) { - GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); - return true; - } - auto ret = InitDevice(); - if (!ret) { - MS_LOG(ERROR) << "InitDevice error."; - return ret; - } - mem_manager_ = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->MallocDeviceMemory(); - const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); - bool collective_inited = CollectiveInitializer::instance().collective_inited(); - if (collective_inited && collective_handle_ != nullptr) { - auto init_nccl_comm_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "InitNCCLComm")); - MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr); - (*init_nccl_comm_funcptr)(); - } - device_init_ = true; - return ret; -} - -DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) { - return std::make_shared(device_ptr, device_size, format, type_id); -} - -bool GPUKernelRuntime::InitDevice() { - if (GPUDeviceManager::GetInstance().device_count() <= 0) { - MS_LOG(ERROR) << "No GPU device found."; - return false; - } - const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); - bool collective_inited = CollectiveInitializer::instance().collective_inited(); - if (collective_inited && 
collective_handle_ != nullptr) { - auto get_local_rank_funcptr = - reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); - MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); - device_id_ = IntToUint((*get_local_rank_funcptr)()); - } - if (!GPUDeviceManager::GetInstance().is_device_id_init()) { - if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) { - MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_); - return false; - } - } - GPUDeviceManager::GetInstance().InitDevice(); - stream_ = GPUDeviceManager::GetInstance().default_stream(); - if (stream_ == nullptr) { - MS_LOG(ERROR) << "No default CUDA stream found."; - return false; - } - return true; -} - -void GPUKernelRuntime::ReleaseDeviceRes() { - // For dataset mode. - if (GpuBufferMgr::GetInstance().IsInit()) { - if (!GpuBufferMgr::GetInstance().IsClosed()) { - if (!GpuBufferMgr::GetInstance().CloseNotify()) { - MS_LOG(EXCEPTION) << "Could not close gpu data queue."; - } - } - CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); - } - - // Destroy remaining memory swap events and free host memory. - for (auto &item : mem_swap_map_) { - auto &mem_swap_manager = item.second; - MS_EXCEPTION_IF_NULL(mem_swap_manager); - if (mem_swap_manager->trigger_swap()) { - mem_swap_manager->ClearSwapQueue(); - mem_swap_manager->ReleaseHostPinnedMem(); - } - } - - GPUDeviceManager::GetInstance().ReleaseDevice(); - if (mem_manager_ != nullptr) { - mem_manager_->FreeDeviceMemory(); - } - - kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); - MS_EXCEPTION_IF_NULL(bin_map); - bin_map->RemoveKernelCache(); -} - -void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->ResetDynamicMemory(); - AssignStaticMemoryInput(graph); - AssignStaticMemoryValueNode(graph); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); - if (is_enable_dynamic_mem) { - // Use the dynamic memory pool. 
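Aside: the dynamic-memory-pool branch set up here (InitKernelRefCount / InitMemorySwapInfo / InitKernelOutputAddress) defers buffer allocation from graph-preparation time to kernel-launch time, and returns a buffer to the pool as soon as its reference count says the last consumer has run. A minimal sketch of that scheme follows; every name in it is invented for illustration and none of it is MindSpore API.

#include <cstdio>
#include <cstdlib>

// One tensor produced by a kernel; ref_count = number of later readers.
struct Tensor {
  void *ptr = nullptr;
  size_t size = 0;
  int ref_count = 0;
};

int main() {
  Tensor t0{nullptr, 256, 1};  // output of kernel 0, read once by kernel 1
  Tensor t1{nullptr, 512, 0};  // output of kernel 1 (a graph output)

  t0.ptr = std::malloc(t0.size);  // allocated lazily, just before kernel 0 runs

  t1.ptr = std::malloc(t1.size);  // just before kernel 1 runs
  if (--t0.ref_count == 0) {      // kernel 1 was t0's last reader
    std::free(t0.ptr);            // memory returns to the pool immediately
    t0.ptr = nullptr;
  }

  std::printf("t0 released after last use, t1 (%zu bytes) kept\n", t1.size);
  std::free(t1.ptr);
  return 0;
}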
- InitKernelRefCount(graph); - InitMemorySwapInfo(graph); - InitKernelOutputAddress(graph); - } else { - AssignDynamicMemory(graph); - } -} - -bool GPUKernelRuntime::Run(session::KernelGraph *graph) { - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - bool ret = true; - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); - bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); - if (is_enable_dynamic_mem && !is_enable_pynative_infer) { - auto graph_id = graph->graph_id(); - auto iter = mem_swap_map_.find(graph_id); - if (iter == mem_swap_map_.end()) { - MS_LOG(EXCEPTION) << "Find memory swap map failed."; - } - mem_swap_manager_ = iter->second; - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - while (!LaunchKernelDynamic(graph)) { - MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; - if (!UpdateMemorySwapInfo(graph)) { - return false; - } - } - } else { - ret = LaunchKernel(graph); - } - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(DEBUG) << "GPU kernel runtime run graph in " << cost << " us"; - return ret; -} - -void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // Init the kernel reference count. - if (!mem_reuse_util_ptr->InitDynamicKernelRef(graph)) { - MS_LOG(EXCEPTION) << "Init kernel reference count failed"; - } - mem_reuse_util_ptr->SetKernelDefMap(); - mem_reuse_util_ptr->SetReuseRefCount(); - // Can't free the device address of graph output, so set the reference count of graph output specially. - mem_reuse_util_ptr->SetGraphOutputRefCount(); - // Can't free the device address of summary nodes, so set the reference count of summary nodes specially. 
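Aside: the two special cases noted above exist because buffers backing graph outputs and summary nodes must outlive the dynamic free pass. One simple way to model such pinning, sketched here with invented names (not necessarily how SetGraphOutputRefCount is implemented), is a sentinel reference count that can never be driven to zero:

#include <climits>
#include <cstddef>

struct Tensor {
  void *ptr = nullptr;
  size_t size = 0;
  int ref_count = 0;
};

// Used instead of the normal per-reader count for output/summary tensors.
inline void PinForGraphLifetime(Tensor *t) { t->ref_count = INT_MAX; }

int main() {
  Tensor out{nullptr, 1024, 0};
  PinForGraphLifetime(&out);  // the decrement-and-free path never reclaims it
  return out.ref_count == INT_MAX ? 0 : 1;
}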
- mem_reuse_util_ptr->SetSummaryNodesRefCount(); - auto graph_id = graph->graph_id(); - mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; -} - -void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared(); - MS_EXCEPTION_IF_NULL(gpu_mem_copy_manager); - MemSwapManagerPtr mem_swap_manager = std::make_shared(gpu_mem_copy_manager); - MS_EXCEPTION_IF_NULL(mem_swap_manager); - auto graph_id = graph->graph_id(); - mem_swap_map_[graph_id] = mem_swap_manager; -} - -void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - } - } -} - -void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (!AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - if (device_address->ptr_) { - mem_manager_->FreeMemFromMemPool(device_address); - } - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } -} - -bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto graph_id = graph->graph_id(); - auto iter = mem_reuse_util_map_.find(graph_id); - if (iter == mem_reuse_util_map_.end()) { - MS_LOG(EXCEPTION) << "Find memory reuse map failed."; - } - auto mem_reuse_util_ptr = iter->second; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // Reset the reference count. - mem_reuse_util_ptr->ResetDynamicUsedRefCount(); - // The inputs and outputs memory of communication kernel need be continuous, so separate processing. 
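Aside: as the comment above says, a communication kernel's fused inputs and outputs must occupy contiguous device memory, so they are carved out of a single allocation at running offsets rather than allocated piecemeal. A toy sketch of that layout, with plain host malloc standing in for device memory:

#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  std::vector<size_t> sizes = {256, 512, 128};  // per-input byte sizes
  size_t total = 0;
  for (size_t s : sizes) total += s;

  char *block = static_cast<char *>(std::malloc(total));  // one contiguous block
  size_t offset = 0;
  for (size_t i = 0; i < sizes.size(); ++i) {
    void *slice = block + offset;  // address handed to input i
    std::printf("input %zu -> offset %zu (%zu bytes)\n", i, offset, sizes[i]);
    offset += sizes[i];
    (void)slice;
  }
  std::free(block);
  return 0;
}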
- AllocCommunicationOpDynamicRes(graph); - - auto &kernels = graph->execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - if (!ret) { - return false; - } - if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { - MS_LOG(EXCEPTION) << "Launch kernel failed."; - } - FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); - UpdateMemorySwapTask(kernel); - } - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - ClearSwapQueue(); - return true; -} - -bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); - for (auto &mem_swap_info : mem_swap_info_list) { - auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); - const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; - auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); - - if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); - } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { - auto status = device_address->status(); - if (status == DeviceAddressStatus::kInDeviceToHost) { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); - device_address->set_status(DeviceAddressStatus::kInDevice); - } else if (status == DeviceAddressStatus::kInHost) { - if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { - return false; - } - if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { - mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); - } - } - } - } - return true; -} - -bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - ClearKernelOutputAddress(graph); - if (!mem_swap_manager_->mem_swap_init()) { - mem_swap_manager_->Init(graph); - } - return mem_swap_manager_->RetreatSwapInfo(); -} - -bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - if (!mem_swap_manager_->trigger_swap()) { - return true; - } - if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { - CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); - if (!AddMemorySwapTask(kernel)) { - return false; - } - } - CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed."); - return true; -} - -void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) { - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - if (!mem_swap_manager_->trigger_swap()) { - return; - } - while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { - device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); - } - auto status = device_address->status(); - switch (status) { - case DeviceAddressStatus::kInDevice: - break; - case DeviceAddressStatus::kInDeviceToHost: { - mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); - 
device_address->set_status(DeviceAddressStatus::kInDevice);
-      break;
-    }
-    case DeviceAddressStatus::kInHostToDevice: {
-      while (device_address->status() != DeviceAddressStatus::kInDevice) {
-        while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
-          device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
-        }
-      }
-      break;
-    }
-    case DeviceAddressStatus::kInHost:
-      MS_LOG(ERROR) << "Invalid device address status:" << status;
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Invalid device address status:" << status;
-  }
-}
-
-void GPUKernelRuntime::UpdateDeviceSwapQueue() {
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  if (!mem_swap_manager_->trigger_swap()) {
-    return;
-  }
-  while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
-    if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
-      device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
-      mem_manager_->FreeMemFromMemPool(device_address_swap_out);
-    }
-  }
-}
-
-void GPUKernelRuntime::ClearSwapQueue() {
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  if (!mem_swap_manager_->trigger_swap()) {
-    return;
-  }
-  mem_swap_manager_->ClearSwapQueue();
-}
-
-bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
-  if (!ret) {
-    if (!mem_swap_manager_->trigger_swap()) {
-      return false;
-    }
-    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
-    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
-      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
-        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
-        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
-      }
-    }
-    ret = mem_manager_->MallocMemFromMemPool(device_address, size);
-    if (!ret) {
-      return false;
-    }
-  }
-  return true;
-}
-
-void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
-  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
-  if (!device_ptr) {
-    if (!mem_swap_manager_->trigger_swap()) {
-      return nullptr;
-    }
-    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
-    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
-      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
-        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
-        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
-      }
-    }
-    device_ptr = mem_manager_->MallocMemFromMemPool(size);
-    if (!device_ptr) {
-      return nullptr;
-    }
-  }
-  return device_ptr;
-}
-
-bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
-                                             const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
-                                             AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
-  if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) {
-    return false;
-  }
-  if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) {
-    return false;
-  }
-  if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) {
-    return false;
-  }
-  return true;
-}
-
-bool
GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_inputs); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - UpdateHostSwapQueue(device_address); - MS_EXCEPTION_IF_NULL(device_address->ptr_); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - return true; -} - -bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_outputs); - UpdateDeviceSwapQueue(); - auto output_sizes = kernel_mod.GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) { - return false; - } - kernel::AddressPtr output = std::make_shared(); - MS_EXCEPTION_IF_NULL(output); - output->addr = device_address->ptr_; - output->size = output_sizes[i]; - kernel_outputs->emplace_back(output); - } - return true; -} - -bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_workspaces) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(kernel_workspaces); - auto workspace_sizes = kernel_mod.GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_sizes.size(); ++i) { - if (workspace_sizes[i] == 0) { - kernel_workspaces->emplace_back(nullptr); - continue; - } - auto device_ptr = AttemptMallocMem(workspace_sizes[i]); - if (!device_ptr) { - return false; - } - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_ptr; - workspace->size = workspace_sizes[i]; - kernel_workspaces->emplace_back(workspace); - } - return true; -} - -void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - if (AnfAlgo::IsCommunicationOp(kernel)) { - AllocCommunicationOpInputDynamicRes(kernel); - AllocCommunicationOpOutputDynamicRes(kernel); - } - } -} - -void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - bool is_need_alloc_memory = false; - bool is_need_free_memory = false; - size_t total_size = 0; - std::vector size_list; - DeviceAddressPtrList addr_list; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr) { - is_need_alloc_memory = true; - } else { - is_need_free_memory = true; - } - total_size += device_address->size_; - size_list.emplace_back(device_address->size_); - addr_list.emplace_back(device_address); - } - AllocCommunicationOpMemory(is_need_alloc_memory, 
is_need_free_memory, addr_list, total_size, size_list); -} - -void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - bool is_need_alloc_memory = false; - bool is_need_free_memory = false; - size_t total_size = 0; - std::vector size_list; - DeviceAddressPtrList addr_list; - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t i = 0; i < output_sizes.size(); ++i) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - MS_EXCEPTION_IF_NULL(device_address); - if (device_address->ptr_ == nullptr) { - is_need_alloc_memory = true; - } else { - is_need_free_memory = true; - } - total_size += output_sizes[i]; - size_list.emplace_back(output_sizes[i]); - addr_list.emplace_back(device_address); - } - AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); -} - -void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, - const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list) { - MS_EXCEPTION_IF_NULL(mem_manager_); - if (!is_need_alloc_memory) { - return; - } - if (is_need_free_memory) { - for (const auto &iter : addr_list) { - MS_EXCEPTION_IF_NULL(iter); - // Free the inputs/outputs of communication kernel which are not released. - if (iter->ptr_ != nullptr) { - mem_manager_->FreeMemFromMemPool(iter); - } - } - } - auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } -} - -void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, - const AddressPtrList &kernel_workspaces, uint32_t graph_id) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsCommunicationOp(kernel)) { - return; - } - // Free the input of kernel by reference count. - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); - if (kernel_ref_count_ptr == nullptr) { - continue; - } - kernel_ref_count_ptr->ref_count_dynamic_use_--; - if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { - MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; - } - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); - mem_manager_->FreeMemFromMemPool(device_address); - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } - // Free the output of kernel, if output has no reference. - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { - auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); - if (kernel_ref_count_ptr == nullptr) { - continue; - } - if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); - mem_manager_->FreeMemFromMemPool(device_address); - device_address->set_status(DeviceAddressStatus::kInDevice); - } - } - // Free the workspace of kernel. 
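Aside: note the asymmetry this function maintains. Inputs and outputs go through the reference-count path, since later kernels may still read them, while a workspace is scratch owned by exactly one launch and is handed back unconditionally once that kernel has been enqueued. A toy version of the workspace side, with invented names and host malloc standing in for pool memory:

#include <cstdio>
#include <cstdlib>

// Scratch buffer for exactly one kernel launch: no reference counting.
struct Workspace {
  void *addr = nullptr;
  size_t size = 0;
};

void LaunchKernelWithScratch(size_t scratch_bytes) {
  Workspace ws{std::malloc(scratch_bytes), scratch_bytes};
  // ... enqueue the kernel that uses ws.addr ...
  std::free(ws.addr);  // released right away; no later kernel may touch it
  ws.addr = nullptr;
  std::printf("released %zu bytes of scratch\n", ws.size);
}

int main() {
  LaunchKernelWithScratch(4096);
  return 0;
}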
- for (size_t i = 0; i < kernel_workspaces.size(); ++i) { - auto workspace = kernel_workspaces[i]; - if (workspace != nullptr) { - MS_EXCEPTION_IF_NULL(workspace->addr); - mem_manager_->FreeMemFromMemPool(workspace->addr); - workspace->addr = nullptr; - } - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h deleted file mode 100644 index bc7e4ed22c..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ - -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "device/kernel_runtime_manager.h" -#include "pre_activate/mem_reuse/mem_swap_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemSwapManagerPtr; -class GPUKernelRuntime : public KernelRuntime { - public: - GPUKernelRuntime() = default; - ~GPUKernelRuntime() override = default; - bool Init() override; - void ReleaseDeviceRes() override; - void AssignMemory(session::KernelGraph *graph) override; - bool Run(session::KernelGraph *graph) override; - - protected: - DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, - TypeId type_id) override; - bool SyncStream() override; - - private: - GPUKernelRuntime(const GPUKernelRuntime &); - GPUKernelRuntime &operator=(const GPUKernelRuntime &); - bool InitDevice(); - bool device_init_{false}; - - // The related functions and members for using dynamic memory pool. 
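Aside: the swap machinery declared around this point drives every device buffer through a small status machine, the DeviceAddressStatus values used throughout this file. A compilable toy walk through that life cycle, with invented names:

#include <cstdio>

enum class Status { kInDevice, kInDeviceToHost, kInHost, kInHostToDevice };

const char *Name(Status s) {
  switch (s) {
    case Status::kInDevice:       return "InDevice";
    case Status::kInDeviceToHost: return "InDeviceToHost";
    case Status::kInHost:         return "InHost";
    case Status::kInHostToDevice: return "InHostToDevice";
  }
  return "?";
}

int main() {
  Status s = Status::kInDevice;
  s = Status::kInDeviceToHost;  // swap-out task enqueued on the copy stream
  s = Status::kInHost;          // copy event completed; device memory freed
  s = Status::kInHostToDevice;  // needed again: swap-in task enqueued
  s = Status::kInDevice;        // copy event completed; a kernel may read it
  std::printf("final state: %s\n", Name(s));
  return 0;
}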
- void InitKernelRefCount(const session::KernelGraph *graph); - void InitKernelOutputAddress(const session::KernelGraph *graph); - void InitMemorySwapInfo(const session::KernelGraph *graph); - void ClearKernelOutputAddress(const session::KernelGraph *graph); - bool LaunchKernelDynamic(const session::KernelGraph *graph); - bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); - void *AttemptMallocMem(size_t size); - bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, - AddressPtrList *kernel_outputs); - bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); - bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, - AddressPtrList *kernel_outputs); - bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, - const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); - void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); - void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); - void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); - void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, - const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list); - void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, - uint32_t graph_id); - bool AddMemorySwapTask(const AnfNodePtr &kernel); - bool UpdateMemorySwapInfo(const session::KernelGraph *graph); - bool UpdateMemorySwapTask(const AnfNodePtr &kernel); - void UpdateHostSwapQueue(const DeviceAddressPtr device_address); - void UpdateDeviceSwapQueue(); - void ClearSwapQueue(); - std::unordered_map mem_reuse_util_map_; - std::unordered_map mem_swap_map_; - MemSwapManagerPtr mem_swap_manager_{nullptr}; -}; -MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc deleted file mode 100644 index 9137945661..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include "device/gpu/gpu_memory_allocator.h" -#include "device/gpu/cuda_driver.h" -#include "utils/log_adapter.h" -#include "utils/context/ms_context.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -namespace device { -namespace gpu { -bool GPUMemoryAllocator::Init() { - size_t total_size = total_mem_size(); - size_t free_size = CudaDriver::free_mem_size(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - limited_device_memory_ = context_ptr->max_device_memory(); - available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024); - if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) { - MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size - << ", set max available memory size " << available_device_memory_ << "."; - } else { - MS_LOG(EXCEPTION) << "GPU device memory error, total memory size " << total_size << ", current free memory size " - << free_size << ", set max available memory size " << available_device_memory_ << "."; - } - return true; -} - -void GPUMemoryAllocator::CheckMaxDeviceMemory() const { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - auto max_device_memory = context_ptr->max_device_memory(); - // Currently not support modifying the max device memory. - if (limited_device_memory_ != max_device_memory) { - MS_LOG(EXCEPTION) - << "Can't change context param max_device_memory in runtime, currently effective max_device_memory(" - << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory << "GB) failed."; - } -} - -bool GPUMemoryAllocator::Finalize() { - if (buffer_q_addr_ != nullptr) { - if (!CudaDriver::FreeDeviceMem(buffer_q_addr_)) { - MS_LOG(ERROR) << "Could not free buffer queue memory."; - return false; - } - } - return true; -} - -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { - auto alloc_size = AllocDeviceMem(size, addr); - buffer_q_addr_ = *addr; - // Buffer queue needs to ensure that the alloc_size and size is equal. - return (alloc_size == size) ? 
true : false; -} - -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { - if (size == 0) { - MS_LOG(EXCEPTION) << "The memory alloc size is 0."; - } - auto free_size = free_mem_size(); - if (size > free_size) { - MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << free_size - << "] is smaller than required size[" << size << "]."; - } - - auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); - if (alloc_size == 0) { - MS_LOG(EXCEPTION) << "Alloc device memory[" << size << "] failed."; - } - total_used_device_memory_ += alloc_size; - available_device_memory_ -= alloc_size; - MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size - << "], total used size[" << total_used_device_memory_ << "]."; - return alloc_size; -} - -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } - -size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } - -size_t GPUMemoryAllocator::total_mem_size() { return CudaDriver::total_mem_size(); } -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h deleted file mode 100644 index 90d7791057..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ - -#include -#include "device/gpu/cuda_driver.h" -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" - -namespace mindspore { -namespace device { -namespace gpu { -class GPUMemoryAllocator : public DynamicMemPoolBestFit { - public: - ~GPUMemoryAllocator() override = default; - bool Init(); - void CheckMaxDeviceMemory() const; - bool Finalize(); - bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - - size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; - bool FreeDeviceMem(const DeviceMemPtr &addr) override; - size_t free_mem_size() override; - size_t total_mem_size() override; - - static GPUMemoryAllocator &GetInstance() { - static GPUMemoryAllocator instance; - return instance; - } - - private: - GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; - GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; - - // Used to track address of data buffer queue. 
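Aside: free_mem_size() above returns the smaller of what the CUDA driver reports and what remains of the configured max_device_memory budget, so the pool can never grow past the user's cap even when the device has more memory free. A tiny self-contained illustration with made-up numbers:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t GiB = size_t{1} << 30;
  size_t driver_free = 10 * GiB;   // what the driver might report as free
  size_t budget = 8 * GiB;         // configured max_device_memory
  size_t already_used = 3 * GiB;   // handed out by the pool so far

  size_t effective_free = std::min(driver_free, budget - already_used);
  std::printf("effective free: %zu GiB\n", effective_free / GiB);
  return 0;
}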
- DeviceMemPtr buffer_q_addr_{nullptr}; - - float limited_device_memory_{0.0}; - size_t total_used_device_memory_{0}; - size_t available_device_memory_{0}; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc deleted file mode 100644 index 80206f309d..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.cc +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_memory_copy_manager.h" -#include "device/gpu/gpu_common.h" -#include "device/gpu/gpu_device_manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace gpu { -void GPUMemCopyManager::Init() { - CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_out_stream_), - "Failed to create CUDA stream of memory swap out."); - CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_in_stream_), - "Failed to create CUDA stream of memory swap in."); -} - -void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(host_addr.addr); - DeviceEvent event = nullptr; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); - DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->set_status(DeviceAddressStatus::kInDeviceToHost); - - CHECK_OP_RET_WITH_EXCEPT( - CudaDriver::CopyDeviceMemToHostAsync(host_addr.addr, device_ptr, host_addr.size, swap_out_stream_), - "Failed to copy device memory to host."); - - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_out_stream_), - "Failed to record CUDA event to swap out stream."); - swap_out_queue_.emplace(device_address, event); -} - -void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(host_addr.addr); - DeviceEvent event = nullptr; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); - DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->set_status(DeviceAddressStatus::kInHostToDevice); - - CHECK_OP_RET_WITH_EXCEPT( - CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_), - "Failed to copy host memory to device."); - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_), - "Failed to record CUDA event to swap in stream."); - swap_in_queue_.emplace(device_address, event); -} - -bool GPUMemCopyManager::SyncMemCopyStream(SwapKind 
swap_kind) { - if (swap_kind == SwapKind::kDeviceToHost) { - return GPUDeviceManager::GetInstance().SyncStream(swap_out_stream_); - } else { - return GPUDeviceManager::GetInstance().SyncStream(swap_in_stream_); - } -} - -DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueue() { - if (swap_out_queue_.empty()) { - return nullptr; - } - auto &task = swap_out_queue_.front(); - auto device_address = task.first; - auto &event = task.second; - bool finish_swap = CudaDriver::QueryEvent(event); - if (!finish_swap) { - return nullptr; - } - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); - swap_out_queue_.pop(); - return device_address; -} - -DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() { - if (swap_in_queue_.empty()) { - return nullptr; - } - auto &task = swap_in_queue_.front(); - auto device_address = task.first; - auto &event = task.second; - bool finish_swap = CudaDriver::QueryEvent(event); - if (!finish_swap) { - return nullptr; - } - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); - swap_in_queue_.pop(); - return device_address; -} - -bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const { - auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr); - return alloc_size == size; -} - -void GPUMemCopyManager::FreeHostPinnedMem(void *addr) const { CudaDriver::FreeHostPinnedMem(addr); } - -void GPUMemCopyManager::ClearSwapQueue() { - CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kDeviceToHost), "Failed to sync swap out stream"); - CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kHostToDevice), "Failed to sync swap in stream"); - - while (!swap_out_queue_.empty()) { - auto &event = swap_out_queue_.front().second; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); - swap_out_queue_.pop(); - } - while (!swap_in_queue_.empty()) { - auto &event = swap_in_queue_.front().second; - CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); - swap_in_queue_.pop(); - } -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h deleted file mode 100644 index 36ff273015..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_copy_manager.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ - -#include -#include -#include -#include "pre_activate/mem_reuse/mem_copy_manager.h" -#include "device/device_address.h" -#include "device/gpu/cuda_driver.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace device { -namespace gpu { -using mindspore::device::memswap::MemCopyManager; -using mindspore::device::memswap::SwapKind; -class GPUMemCopyManager : public MemCopyManager { - public: - GPUMemCopyManager() = default; - - ~GPUMemCopyManager() override = default; - - void Init() override; - - void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; - - void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; - - bool SyncMemCopyStream(SwapKind swap_kind) override; - - DeviceAddressPtr UpdateSwapOutQueue() override; - - DeviceAddressPtr UpdateSwapInQueue() override; - - bool AllocHostPinnedMem(size_t size, void **addr) const override; - - void FreeHostPinnedMem(void *addr) const override; - - void ClearSwapQueue() override; - - private: - DeviceStream swap_out_stream_{nullptr}; - DeviceStream swap_in_stream_{nullptr}; - std::queue> swap_out_queue_; - std::queue> swap_in_queue_; -}; -using GPUMemCopyManagerPtr = std::shared_ptr; -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc deleted file mode 100644 index 9a63921add..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_memory_manager.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "utils/context/ms_context.h" -#include "utils/convert_utils.h" -namespace mindspore { -namespace device { -namespace gpu { -void *GPUMemoryManager::MallocMemFromMemPool(size_t size) { - return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); -} - -void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { - GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); -} - -std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { - return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); -} - -void GPUMemoryManager::MallocDeviceMemory() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - // If use the dynamic memory pool, then alloc the first memory block to init. 
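  // NOTE (editor's sketch of the two paths below, restating only what this
  // function does; not part of the original source):
  //   dynamic pool enabled  -> MallocMemFromMemPool(1) touches the pool once so its
  //                            first memory block is created and initialized lazily;
  //   dynamic pool disabled -> reserve init_gpu_mem_ratio (80%) of the free device
  //                            memory as one static region up front; MallocStaticMem
  //                            then carves allocations downward from static_mem_offset_.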
- if (context_ptr->enable_dynamic_mem_pool()) { - auto device_addr = MallocMemFromMemPool(1); - if (!device_addr) { - MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; - } - } else { - // Need to reserve 20% space for dynamic memory - const float init_gpu_mem_ratio = 0.8; - size_t mem_size = FloatToSize(GPUMemoryAllocator::GetInstance().free_mem_size() * init_gpu_mem_ratio); - auto alloc_size = - GPUMemoryAllocator::GetInstance().AllocDeviceMem(mem_size, reinterpret_cast(&device_mem_base_)); - device_mem_size_ = alloc_size; - static_mem_offset_ = device_mem_size_; - } -} - -void GPUMemoryManager::FreeDeviceMemory() { - if (device_mem_base_ != nullptr) { - if (!GPUMemoryAllocator::GetInstance().FreeDeviceMem(device_mem_base_)) { - MS_LOG(EXCEPTION) << "Could not free gpu device memory."; - } - } - GPUMemoryAllocator::GetInstance().ReleaseDeviceRes(); -} - -uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_dynamic_mem_pool()) { - auto device_ptr = MallocMemFromMemPool(size); - MS_EXCEPTION_IF_NULL(device_ptr); - return AddressOffset(device_ptr, 0); - } - - auto align_size = GetCommonAlignSize(size); - if (static_mem_offset_ < align_size) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - auto offset = static_mem_offset_ - align_size; - if (dynamic_mem_offset_ > offset) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_static_size_ += align_size; - static_mem_offset_ = offset; - return device_mem_base_ + offset; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h deleted file mode 100644 index c79fb9cc22..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ -#include -#include "device/memory_manager.h" -namespace mindspore { -namespace device { -namespace gpu { -class GPUMemoryManager : public MemoryManager { - public: - GPUMemoryManager() = default; - virtual ~GPUMemoryManager() = default; - - void MallocDeviceMemory() override; - void FreeDeviceMemory() override; - - void *MallocMemFromMemPool(size_t size) override; - void FreeMemFromMemPool(void *device_ptr) override; - std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); - - protected: - uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; -}; -} // namespace gpu -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc deleted file mode 100644 index 42cdcf29ec..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/gpu_stream_assign.h" -#include -#include -#include -#include -#include "device/gpu/gpu_common.h" -#include "device/gpu/kernel_info_setter.h" -#include "device/gpu/gpu_device_manager.h" - -namespace mindspore { -namespace device { -namespace gpu { -void AssignGpuStream(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector allreduce_kernels; - auto execution_kernels = kernel_graph->execution_order(); - for (auto kernel_node : execution_kernels) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == kAllReduceOpName) { - allreduce_kernels.emplace_back(kernel_node); - } else { - DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream(); - MS_EXCEPTION_IF_NULL(compute_stream); - AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(compute_stream)), kernel_node); - } - } - if (allreduce_kernels.size() > 1) { - // Assign multiple streams only when there're multiple AllReduce nodes. 
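  // NOTE (editor's sketch, not part of the original source; it summarizes the
  // helpers defined below): for every AllReduce node A in the execution order,
  // FindAllReduceStreamSwitchPos records two send/recv pairs:
  //   (last kernel feeding A)  --Send/Recv-->  A    // hand off to the comm stream
  //   A  --Send/Recv-->  (first kernel reading A)   // hand back to the compute stream
  // All AllReduce kernels are then tagged with one shared communication stream, and
  // InsertStreamSwitchNode splices the Send/Recv nodes into the execution order at
  // the recorded offsets.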
- std::vector send_recv_pairs; - if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) { - DeviceStream comm_stream = nullptr; - GPUDeviceManager::GetInstance().CreateStream(&comm_stream); - std::transform( - allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) { - AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); - return allreduce_kernel; - }); - InsertStreamSwitchNode(kernel_graph, send_recv_pairs); - } else { - return; - } - } -} - -bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, - std::vector *send_recv_pairs) { - auto execution_kernels = kernel_graph->execution_order(); - std::vector::iterator iter, iter_begin; - iter = iter_begin = execution_kernels.begin(); - std::vector::iterator iter_end = execution_kernels.end(); - for (; iter != execution_kernels.end(); ++iter) { - std::string kernel_name = AnfAlgo::GetCNodeName(*iter); - if (kernel_name == kAllReduceOpName) { - // Find AllReduce node's last input node. - std::vector::iterator mock_send_node_iter = - FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); - if (mock_send_node_iter == iter + 1) { - MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; - continue; - } - SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, - IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; - send_recv_pairs->push_back(pair1); - // Find node which uses AllReduce as input[0]. - std::vector::iterator mock_recv_node_iter = - FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); - if (mock_recv_node_iter == iter_end) { - MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; - return false; - } - SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), - IntToSize(mock_recv_node_iter - iter_begin)}; - send_recv_pairs->push_back(pair2); - } - } - return true; -} - -std::vector::iterator FindSendNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_recv_node, - StreamSwitchType stream_switch_type) { - MS_EXCEPTION_IF_NULL(mock_recv_node); - if (stream_switch_type == kAllReduceStreamSwitch) { - for (auto iter = begin; iter != end; iter++) { - if (*(iter + 1) == mock_recv_node) { - return iter; - } - } - } - return end; -} - -std::vector::iterator FindRecvNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_send_node, - StreamSwitchType stream_switch_type) { - MS_EXCEPTION_IF_NULL(mock_send_node); - for (auto iter = begin; iter != end; iter++) { - auto node = *iter; - if (stream_switch_type == kAllReduceStreamSwitch) { - for (auto input : node->inputs()) { - if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) { - return iter; - } - } - } - } - return end; -} - -void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, - const std::vector &send_recv_pairs) { - std::set ordered_stream_switch_nodes; - for (SendRecvPair pair : send_recv_pairs) { - StreamSwitchType stream_switch_type = pair.stream_switch_type; - CNodePtr mock_send_node = pair.mock_send_node; - CNodePtr mock_recv_node = pair.mock_recv_node; - size_t send_node_offset = pair.send_node_offset; - size_t recv_node_offset = pair.recv_node_offset; - CNodePtr send_node = nullptr; - CNodePtr recv_node = nullptr; - // Step 1: generate Send and Recv CNodes. 
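  // NOTE (editor's annotation, not part of the original source): steps 2 and 3
  // depend on StreamSwitchNode::operator< declared in the header: nodes sort by
  // offset, and at equal offsets a Send sorts before a Recv. Because the insertion
  // loop below walks the set from rbegin() to rend(), the largest offsets are
  // spliced first, so earlier insertion positions stay valid and the Send/Recv tie
  // order is preserved.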
- if (stream_switch_type == kAllReduceStreamSwitch) { - if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) { - MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch"; - } - } - // Step 2: sort send and recv CNodes by offset. - ordered_stream_switch_nodes.insert({send_node_offset, send_node}); - ordered_stream_switch_nodes.insert({recv_node_offset, recv_node}); - } - // Step 3: insert stream switch CNodes into execution kernel list. - auto execution_kernels = kernel_graph->execution_order(); - for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) { - execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode); - } - kernel_graph->set_execution_order(execution_kernels); -} - -bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, - const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, - CNodePtr *recv_node) { - *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName); - MS_EXCEPTION_IF_NULL(*send_node); - *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName); - MS_EXCEPTION_IF_NULL(*recv_node); - - cudaEvent_t event = nullptr; - CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed."); - AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast(event)), *send_node); - AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast(event)), *recv_node); - - uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, kAttrStreamId); - AnfAlgo::SetNodeAttr(kAttrRecordEventStream, MakeValue(send_stream), *send_node); - uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, kAttrStreamId); - AnfAlgo::SetNodeAttr(kAttrWaitEventStream, MakeValue(recv_stream), *recv_node); - return true; -} - -CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name) { - auto op = std::make_shared(name); - MS_EXCEPTION_IF_NULL(op); - auto apply = std::make_shared(op); - MS_EXCEPTION_IF_NULL(apply); - std::vector input_list = {apply}; - CNodePtr node = kernel_graph->NewCNode(input_list); - MS_EXCEPTION_IF_NULL(node); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get()); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - node->set_abstract(abstract_none); - SetKernelInfo(node); - return node; -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/device/gpu/gpu_stream_assign.h deleted file mode 100644 index f8041878b2..0000000000 --- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ -#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ - -#include -#include -#include -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace device { -namespace gpu { -enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 }; -struct SendRecvPair { - StreamSwitchType stream_switch_type; - CNodePtr mock_send_node; - CNodePtr mock_recv_node; - size_t send_node_offset; - size_t recv_node_offset; -}; -struct StreamSwitchNode { - size_t offset; - CNodePtr cnode; - bool operator<(const StreamSwitchNode &n) const { - if (offset < n.offset) { - return true; - } else if (offset == n.offset) { - return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false; - } else { - return false; - } - } -}; -void AssignGpuStream(const std::shared_ptr &kernel_graph); -bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, - std::vector *send_recv_pairs); -// Find Send node position according to "mock" recv node. -// "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node. -std::vector::iterator FindSendNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_recv_node, - StreamSwitchType stream_switch_type); -// Find Recv node position according to "mock" send node. -// "mock" send node is a gpu kernel node before a real send node, e.g. AllReduce node. -std::vector::iterator FindRecvNodePos(std::vector::iterator begin, - std::vector::iterator end, const CNodePtr mock_send_node, - StreamSwitchType stream_switch_type); -void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, - const std::vector &send_recv_pairs); -bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, - const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, - CNodePtr *recv_node); -CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name); -} // namespace gpu -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc deleted file mode 100644 index f4367e4714..0000000000 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc +++ /dev/null @@ -1,212 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "device/gpu/kernel_info_setter.h" -#include -#include -#include "kernel/kernel.h" -#include "utils/utils.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/kernel_build_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/common_utils.h" -#include "common/utils.h" -#include "kernel/oplib/oplib.h" -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace device { -namespace gpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; -using mindspore::kernel::KernelBuildInfo; -namespace { -bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, - const std::shared_ptr &selected_kernel_info) { - MS_EXCEPTION_IF_NULL(selected_kernel_info); - MS_EXCEPTION_IF_NULL(alternative_kernel_info); - size_t selected_input_num = selected_kernel_info->GetInputNum(); - size_t alternative_input_num = alternative_kernel_info->GetInputNum(); - if (selected_input_num != alternative_input_num) { - return false; - } - for (size_t i = 0; i < selected_input_num; i++) { - if (selected_kernel_info->GetInputFormat(i) != alternative_kernel_info->GetInputFormat(i)) { - return false; - } - if (selected_kernel_info->GetInputDeviceType(i) != alternative_kernel_info->GetInputDeviceType(i)) { - return false; - } - } - - size_t selected_output_num = selected_kernel_info->GetOutputNum(); - size_t alternative_output_num = alternative_kernel_info->GetOutputNum(); - if (selected_output_num != alternative_output_num) { - return false; - } - for (size_t i = 0; i < selected_output_num; i++) { - if (selected_kernel_info->GetOutputFormat(i) != alternative_kernel_info->GetOutputFormat(i)) { - return false; - } - if (selected_kernel_info->GetOutputDeviceType(i) != alternative_kernel_info->GetOutputDeviceType(i)) { - return false; - } - } - return true; -} - -std::string SupportedTypeList(const CNodePtr &kernel_node) { - std::string supported_type_lists = - kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); - if (!supported_type_lists.empty()) { - return supported_type_lists; - } - std::vector> kernel_info_list; - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); - if (op_info_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Unsupported op [" << op_name << "]"; - } - (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list); - for (size_t i = 0; i < kernel_info_list.size(); i++) { - auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes(); - auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes(); - std::string supported_akg_type_list = "in["; - for (auto type : supported_akg_type) { - supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); - } - supported_type_lists = supported_type_lists + supported_akg_type_list + "], out["; - supported_akg_type_list.clear(); - for (auto type : supported_akg_type_out) { - supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); - } - supported_type_lists = supported_type_lists + supported_akg_type_list + "]; "; - } - return supported_type_lists; -} - -bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(selected_kernel_info); - std::vector> kernel_info_list; - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - - auto op_info_ptr = 
mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG);
-  if (op_info_ptr == nullptr) {
-    MS_LOG(ERROR) << "Cannot find op[" << op_name << "] in akg";
-    return false;
-  }
-  if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) {
-    MS_LOG(EXCEPTION) << "Parsing metadata of op[" << op_name << "] failed.";
-  }
-  if (kernel_info_list.empty()) {
-    MS_LOG(EXCEPTION) << "Akg does not have metadata of op[" << op_name << "].";
-  }
-
-  bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
-                           [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
-                             return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
-                           });
-  if (!match) {
-    MS_LOG(ERROR) << "Cannot find op[" << op_name << "] in akg";
-    return false;
-  }
-  return true;
-}
-
-void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
-    auto input_kernel_node = kernel_node->input(input_index + 1);
-    MS_EXCEPTION_IF_NULL(input_kernel_node);
-    if (!input_kernel_node->isa<Parameter>()) {
-      continue;
-    }
-    std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
-      std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-
-    auto param = input_kernel_node->cast<ParameterPtr>();
-    MS_EXCEPTION_IF_NULL(param);
-    if (!AnfAlgo::IsParameterWeight(param)) {
-      std::vector<std::string> output_format = {kOpFormat_DEFAULT};
-      builder->SetOutputsFormat(output_format);
-      std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)};
-      builder->SetOutputsDeviceType(output_type);
-      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
-      continue;
-    }
-    if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) ||
-        (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
-      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
-      builder->SetOutputsFormat(output_format);
-      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
-      builder->SetOutputsDeviceType(output_type);
-      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
-    }
-  }
-}
-}  // namespace
-
-void SetKernelInfo(const CNodePtr &kernel_node) {
-  std::vector<std::string> inputs_format;
-  std::vector<TypeId> inputs_type;
-  std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
-    std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
-    inputs_format.emplace_back(kOpFormat_DEFAULT);
-    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
-  }
-  builder->SetInputsFormat(inputs_format);
-  builder->SetInputsDeviceType(inputs_type);
-  std::vector<std::string> outputs_format;
-  std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
-    outputs_format.emplace_back(kOpFormat_DEFAULT);
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
-  }
-  builder->SetOutputsFormat(outputs_format);
-  builder->SetOutputsDeviceType(outputs_type);
-
-  bool result =
-    kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
-  KernelType kernel_type = UNKNOWN_KERNEL_TYPE;
-
-  if (!result) {
-    result = SelectAkgKernel(kernel_node, builder->Build());
-    kernel_type = AKG_KERNEL;
-  }
-
-  if (!result) {
-    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-    std::string build_type = "in [";
-    std::for_each(std::begin(inputs_type),
std::end(inputs_type), - [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); - build_type += "] out ["; - std::for_each(std::begin(outputs_type), std::end(outputs_type), - [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; }); - build_type += "]"; - auto supported_type_lists = SupportedTypeList(kernel_node); - MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name - << "] fail! Incompatible data type!\nThe supported data types are " << supported_type_lists - << ", but get " << build_type; - } - builder->SetKernelType(kernel_type); - builder->SetProcessor(kernel::Processor::CUDA); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); - SetTensorDeviceInfo(*(builder->Build()), kernel_node); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc deleted file mode 100644 index bcad74e5b5..0000000000 --- a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/gpu/mpi/mpi_initializer.h" - -#include -#include -#include - -namespace mindspore { -namespace device { -namespace gpu { -MPIInitializer::MPIInitializer() { - int init_flag = 0; - if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { - return; - } - if (init_flag == 0) { - auto ret = MPI_Init(nullptr, nullptr); - if (ret != MPI_SUCCESS) { - return; - } - } - MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); - MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); -} - -MPIInitializer::~MPIInitializer() { - int finalized_flag = 0; - (void)MPI_Finalized(&finalized_flag); - if (finalized_flag == 0) { - (void)MPI_Finalize(); - } -} - -MPIInitializer &MPIInitializer::GetInstance() { - static MPIInitializer instance; - return instance; -} - -int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; } - -int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; } - -PYBIND11_MODULE(_ms_mpi, mpi_initializer) { - mpi_initializer.doc() = "mindspore mpi python wrapper"; - mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id"); - mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size"); -} -} // namespace gpu -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc deleted file mode 100644 index 86dcf2b449..0000000000 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ /dev/null @@ -1,591 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "device/kernel_adjust.h" - -#include -#include -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" -#include "common/trans.h" -#include "utils/config_manager.h" -#include "common/utils.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "device/ascend/profiling/profiling_manager.h" -#include "device/ascend/kernel_select_ascend.h" -#include "runtime/base.h" -#include "device/ascend/ascend_stream_assign.h" - -namespace mindspore { -namespace device { -using device::ascend::ProfilingUtils; -void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); - std::vector getnext_list; - std::vector other_list; - for (const auto &cnode : origin_cnode_list) { - if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { - getnext_list.emplace_back(cnode); - } else { - other_list.emplace_back(cnode); - } - } - std::vector new_order_list; - new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); - new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); - kernel_graph_ptr->set_execution_order(new_order_list); -} - -bool KernelAdjust::NeedInsertSwitch() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1); -} - -CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); - auto send_op = std::make_shared(kSendOpName); - MS_EXCEPTION_IF_NULL(send_op); - auto send_apply = std::make_shared(send_op); - MS_EXCEPTION_IF_NULL(send_apply); - std::vector send_input_list = {send_apply}; - CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); - MS_EXCEPTION_IF_NULL(send_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - send_node_ptr->set_abstract(abstract_none); - return send_node_ptr; -} - -CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, - uint32_t event_id) { - MS_EXCEPTION_IF_NULL(graph_ptr); - auto recv_op = std::make_shared(kRecvOpName); - MS_EXCEPTION_IF_NULL(recv_op); - auto recv_apply = std::make_shared(recv_op); - MS_EXCEPTION_IF_NULL(recv_apply); - std::vector recv_input_list = {recv_apply}; - CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); - MS_EXCEPTION_IF_NULL(recv_node_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - 
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); - AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); - auto abstract_none = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_none); - recv_node_ptr->set_abstract(abstract_none); - return recv_node_ptr; -} - -void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { - device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); - resource_manager.ResetResource(); - if (!NeedInsertSwitch()) { - return; - } - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; - ReorderGetNext(kernel_graph_ptr); - std::map switch_loop_input; - CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); - - std::vector *mute_inputs = kernel_graph_ptr->MutableInputs(); - MS_EXCEPTION_IF_NULL(mute_inputs); - mute_inputs->push_back(switch_loop_input[kLoopCountParamName]); - mute_inputs->push_back(switch_loop_input[kEpochParamName]); - mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); - mute_inputs->push_back(switch_loop_input[kZeroParamName]); - mute_inputs->push_back(switch_loop_input[kOneParamName]); - for (const auto &input : kernel_graph_ptr->inputs()) { - MS_EXCEPTION_IF_NULL(input); - if (input->isa()) { - ParameterPtr param_ptr = input->cast(); - if (param_ptr == nullptr) { - MS_EXCEPTION(NotSupportError) << "Cast to parameter point failed !"; - } - } - } - - const std::vector &orders = kernel_graph_ptr->execution_order(); - if (orders.empty()) { - MS_LOG(EXCEPTION) << "graph execution order is empty"; - } - - std::vector exec_order; - std::vector getnext_active_streams; - std::vector fpbp_active_streams; - CNodePtr getnext_cnode; - uint32_t eos_done_event_id = UINT32_MAX; - - // getnext loop process - // getnext loop stream switch op - CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(getnext_switch_app); - uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); - exec_order.push_back(getnext_switch_app); - - // getnext op - uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); - size_t i = 0; - for (; i < orders.size(); i++) { - auto node = orders[i]; - exec_order.push_back(node); - AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); - if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { - getnext_cnode = node; - break; - } - } - - // update getnext loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); - - // getnext loop fpbp start send - uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); - exec_order.push_back(fpbp_start_send); - - if (eos_mode) { - // getnext loop eos start send - uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); - CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); - AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); - exec_order.push_back(eos_start_send); - - // End Of Sequence loop process - // eos loop stream switch - CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - 
MS_EXCEPTION_IF_NULL(eos_switch_app); - uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); - exec_order.push_back(eos_switch_app); - - // eos loop eos start recv - CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); - uint32_t eos_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); - exec_order.push_back(eos_start_recv); - - // update eos loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); - - // EndOfSequence op - CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); - MS_EXCEPTION_IF_NULL(end_of_sequence_op); - AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); - exec_order.push_back(end_of_sequence_op); - - // eos loop eos done send - eos_done_event_id = resource_manager.ApplyNewEvent(); - CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); - AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); - exec_order.push_back(eos_done_send); - - // eos loop stream active - fpbp_active_streams.push_back(eos_switch_stream_id); - } - - // fpbp loop process - // fpbp loop stream switch - CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(fpbp_switch_app); - uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); - AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); - exec_order.push_back(fpbp_switch_app); - - // fpbp loop fpbp start recv - CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); - uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); - AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); - exec_order.push_back(fpbp_start_recv); - - // update fpbp loop stream switch true_branch_stream attr - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); - - // fpbp loop AssignAdd - CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(assign_add_one); - AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); - exec_order.push_back(assign_add_one); - - // fpbp memcpy - std::vector memcpy_list; - std::vector other_list; - CNodePtr cur_cnode = nullptr; - for (size_t idx = i + 1; idx < orders.size(); idx++) { - cur_cnode = orders[idx]; - if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { - memcpy_list.emplace_back(cur_cnode); - } else { - other_list.emplace_back(cur_cnode); - } - } - - (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); - - // fpbp loop eos done recv - if (eos_mode) { - CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); - AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); - exec_order.push_back(eos_done_recv); - } - - // stream active to activate getnext loop - CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(getnext_active_app); - getnext_active_streams.push_back(getnext_switch_stream_id); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), - getnext_active_app); - 
exec_order.push_back(getnext_active_app); - - // fpbp loop other ops - (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); - - // stream active to activate fpbp loop and eos loop - CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(fpbp_active_app); - fpbp_active_streams.push_back(fpbp_switch_stream_id); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(fpbp_active_streams), fpbp_active_app); - exec_order.push_back(fpbp_active_app); - - kernel_graph_ptr->set_execution_order(exec_order); -} - -void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, - std::map *switch_loop_input) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(switch_loop_input); - std::vector shp = {1}; - tensor::TensorPtr tensor_ptr = std::make_shared(kInt32->type_id(), shp); - MS_EXCEPTION_IF_NULL(tensor_ptr); - mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); - if (paremeter_abstract_ptr == nullptr) { - MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; - } - - ParameterPtr loop_count = std::make_shared(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(loop_count); - loop_count->set_name(kLoopCountParamName); - loop_count->set_abstract(paremeter_abstract_ptr); - ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count); - - (*switch_loop_input)[kLoopCountParamName] = loop_count_new; - - ParameterPtr iter_loop = std::make_shared(kernel_graph_ptr); - iter_loop->set_name(kIterLoopParamName); - iter_loop->set_abstract(paremeter_abstract_ptr); - ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); - (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; - - ParameterPtr zero = std::make_shared(kernel_graph_ptr); - zero->set_name(kZeroParamName); - zero->set_abstract(paremeter_abstract_ptr); - ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero); - (*switch_loop_input)[kZeroParamName] = zero_new; - - ParameterPtr one = std::make_shared(kernel_graph_ptr); - one->set_name(kOneParamName); - one->set_abstract(paremeter_abstract_ptr); - ParameterPtr one_new = kernel_graph_ptr->NewParameter(one); - (*switch_loop_input)[kOneParamName] = one_new; - - ParameterPtr epoch = std::make_shared(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(epoch); - epoch->set_name(kEpochParamName); - epoch->set_abstract(paremeter_abstract_ptr); - ParameterPtr epoch_new = kernel_graph_ptr->NewParameter(epoch); - (*switch_loop_input)[kEpochParamName] = epoch_new; -} - -kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder( - const std::vector &formats, const std::vector &type_ids) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat(formats); - selected_kernel_builder.SetInputsDeviceType(type_ids); - - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); - selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); - return selected_kernel_builder; -} - -CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, - const std::map &switch_loop_input) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - auto typeNone_abstract = std::make_shared(); - auto stream_switch = 
std::make_shared(kStreamSwitchOpName); - std::vector inputs; - inputs.push_back(NewValueNode(stream_switch)); - inputs.push_back(switch_loop_input.at(kLoopCountParamName)); - inputs.push_back(switch_loop_input.at(kIterLoopParamName)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_switch_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get()); - stream_switch_app->set_abstract(typeNone_abstract); - // set attr: cond_ RT_LESS - int condition = static_cast(RT_LESS); - ValuePtr cond = MakeValue(condition); - AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); - // set attr:data_type - int data_type = static_cast(RT_SWITCH_INT64); - ValuePtr dt = MakeValue(data_type); - AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); - // set distinction label and graph id - return stream_switch_app; -} - -CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActiveOpName); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_others)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_others_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); - stream_active_others_app->set_abstract(typeNone_abstract); - return stream_active_others_app; -} - -CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, - const CNodePtr &node, size_t output_idx) { - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(output_idx)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - tuple_getitem->set_scope(node->scope()); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); - return tuple_getitem; -} - -CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, - const CNodePtr &getnext_cnode) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; - selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); - selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); - - selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); - selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); - selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); - - selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); - selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); - // EndOfSequence - auto end_of_sequence = std::make_shared(kEndOfSequence); - std::vector inputs; - inputs.push_back(NewValueNode(end_of_sequence)); - // GetNext output 0 is EndOfSequence's input - auto 
tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0);
-  inputs.push_back(tuple_get_item);
-  CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(end_of_sequence_node);
-  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get());
-  std::vector<std::string> input_names = {"x"};
-  ValuePtr input_names_v = MakeValue(input_names);
-  AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node);
-  std::vector<std::string> output_names = {"y"};
-  ValuePtr output_names_v = MakeValue(output_names);
-  AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node);
-  end_of_sequence_node->set_abstract(tuple_get_item->abstract());
-  return end_of_sequence_node;
-}
-
-CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
-  const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
-  const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
-  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
-  kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
-    {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
-  selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
-  selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
-  // AssignAdd
-  auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(assign_add));
-  inputs.push_back(switch_loop_input.at(kLoopCountParamName));
-  inputs.push_back(switch_loop_input.at(kOneParamName));
-  CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(assign_add_one);
-  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get());
-  std::vector<std::string> input_names = {"ref", "value"};
-  std::vector<std::string> output_names = {"output"};
-  ValuePtr input_names_v = MakeValue(input_names);
-  ValuePtr output_names_v = MakeValue(output_names);
-  AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
-  AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
-  selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
-  MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName));
-  assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract());
-  return assign_add_one;
-}
-
-bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
-  if (!NeedInsertSwitch()) {
-    return true;
-  }
-  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
-  auto input_nodes = kernel_graph_ptr->inputs();
-  std::vector<tensor::TensorPtr> inputs;
-  LoadSwitchInputs(&inputs);
-  std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
-  kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
-  size_t input_ctrl_size = inputs.size();
-  // input_nodes holds the regular graph inputs (e.g. conv weights) followed by the
-  // control inputs appended at the back: loop_count, epoch, iter_loop, zero, one.
-  // Sync each of these control tensors to its device address below.
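  // NOTE (editor's sketch, not part of the original source): LoadSwitchInputs
  // (below) fills these int32 scalars as
  //   loop_count = 0, epoch = 0, iter_loop = ConfigManager iter_num, zero = 0, one = 1;
  // once synced, the graph inserted by InsertSwitchLoop behaves roughly like
  //   while (loop_count < iter_loop) { ...; loop_count = AssignAdd(loop_count, one); }
  // so the training loop iterates entirely on the device side.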
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    auto tensor = inputs[i];
-    size_t deal_index = input_nodes.size() - input_ctrl_size + i;
-    if (deal_index >= input_nodes.size()) {
-      MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range";
-    }
-    auto input_node = input_nodes[deal_index];
-    bool need_sync = false;
-    MS_EXCEPTION_IF_NULL(input_node);
-    if (input_node->isa<Parameter>()) {
-      auto pk_node = input_node->cast<ParameterPtr>();
-      MS_EXCEPTION_IF_NULL(tensor);
-      MS_EXCEPTION_IF_NULL(pk_node);
-      if (tensor->is_dirty() || !pk_node->has_default()) {
-        need_sync = true;
-      }
-    }
-    if (need_sync) {
-      auto pk_node = input_node->cast<ParameterPtr>();
-      MS_EXCEPTION_IF_NULL(pk_node);
-      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
-      MS_EXCEPTION_IF_NULL(device_address);
-      tensor->set_device_address(device_address);
-      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
-                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                            tensor->data_c())) {
-        MS_LOG(INFO) << "SyncHostToDevice failed.";
-        return false;
-      }
-    }
-    tensor->set_dirty(false);
-  }
-  return true;
-}
-
-void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
-  MS_LOG(INFO) << "---------------- LoadSwitchInputs---";
-  MS_EXCEPTION_IF_NULL(inputs);
-  std::vector<int> shp = {1};
-  tensor::TensorPtr loop_count_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(loop_count_tensor);
-  int32_t *val = nullptr;
-  val = static_cast<int32_t *>(loop_count_tensor->data_c());
-  MS_EXCEPTION_IF_NULL(val);
-  *val = 0;
-  inputs->push_back(loop_count_tensor);
-
-  // Epoch in device
-  tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(epoch_tensor);
-  val = static_cast<int32_t *>(epoch_tensor->data_c());
-  MS_EXCEPTION_IF_NULL(val);
-  *val = 0;
-  inputs->push_back(epoch_tensor);
-
-  tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(iter_loop_tensor);
-  val = static_cast<int32_t *>(iter_loop_tensor->data_c());
-  MS_EXCEPTION_IF_NULL(val);
-  *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num()));
-  MS_LOG(INFO) << "iter_loop_tensor = " << *val;
-  inputs->push_back(iter_loop_tensor);
-
-  tensor::TensorPtr zero_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(zero_tensor);
-  val = static_cast<int32_t *>(zero_tensor->data_c());
-  MS_EXCEPTION_IF_NULL(val);
-  *val = 0;
-  inputs->push_back(zero_tensor);
-
-  tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
-  MS_EXCEPTION_IF_NULL(one_tensor);
-  val = static_cast<int32_t *>(one_tensor->data_c());
-  MS_EXCEPTION_IF_NULL(val);
-  *val = 1;
-  inputs->push_back(one_tensor);
-
-  MS_LOG(INFO) << "---------------- LoadSwitchInputs End--";
-}
-
-void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) {
-  if (!ascend::ProfilingManager::GetInstance().IsProfiling()) {
-    MS_LOG(INFO) << "No need to profiling";
-    return;
-  }
-  ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GetProfilingTraceFromEnv(kernel_graph_ptr);
-  if (!profiling_trace_info.IsValid()) {
-    MS_LOG(WARNING) << "[profiling] no profiling node found!";
-    return;
-  }
-  InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr);
-}
-
-void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
-                                         NotNull<session::KernelGraph *> kernel_graph_ptr) {
-  MS_LOG(INFO) << "[profiling] Insert profiling kernel start";
-  if (!profiling_trace_info.IsValid()) {
-    MS_LOG(WARNING) << "Profiling trace point not found";
-    return;
-  }
-  std::vector<CNodePtr> new_cnode_list;
-  std::vector<CNodePtr> cnode_ptr_list = kernel_graph_ptr->execution_order();
-  if (cnode_ptr_list.empty()) {
-    MS_LOG(ERROR) << "No CNode in graph";
-    return;
-  }
-  for (const auto &cnode_ptr : cnode_ptr_list) {
-    ProfilingUtils::ProfilingTraceFpStart(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list));
-    new_cnode_list.emplace_back(cnode_ptr);
-    ProfilingUtils::ProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list));
-    ProfilingUtils::ProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list));
-    ProfilingUtils::ProfilingTraceEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list));
-  }
-  kernel_graph_ptr->set_execution_order(new_cnode_list);
-}
-}  // namespace device
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h
deleted file mode 100644
index 9f59c486bc..0000000000
--- a/mindspore/ccsrc/device/kernel_adjust.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_
-#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "ir/anf.h"
-#include "session/kernel_graph.h"
-#include "kernel/kernel_build_info.h"
-#include "session/session_context.h"
-#include "ir/tensor.h"
-#include "device/ascend/profiling/profiling_utils.h"
-#include "device/kernel_info.h"
-
-using mindspore::device::ascend::ProfilingTraceInfo;
-using mindspore::device::ascend::ProfilingUtils;
-namespace mindspore {
-constexpr auto kLoopCountParamName = "loop_count";
-constexpr auto kIterLoopParamName = "iter_loop";
-constexpr auto kZeroParamName = "zero";
-constexpr auto kOneParamName = "one";
-constexpr auto kEpochParamName = "loop_epoch";
-constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
-constexpr uint32_t kSecondStreamSwitchLabel = 2;
-
-namespace device {
-class KernelAdjust {
- public:
-  static KernelAdjust &GetInstance() {
-    static KernelAdjust instance;
-    return instance;
-  }
-
-  void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
-  bool StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
-  void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);
-  static bool NeedInsertSwitch();
-  CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
-
- private:
-  KernelAdjust() = default;
-  ~KernelAdjust() = default;
-
-  void ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
-  CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
-  CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
-  void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
-                                std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
-  CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
-                                const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
-  CNodePtr CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, const CNodePtr &node,
-                                 size_t output_idx);
-  CNodePtr CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
-                                 const CNodePtr &getnext_cnode);
-  CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
-                                    const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
-  kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats,
-                                                                         const std::vector<TypeId> &type_ids);
-  void LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs);
-  void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
-                             NotNull<session::KernelGraph *> kernel_graph_ptr);
-};
-}  // namespace device
-}  // namespace mindspore
-#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_
diff --git a/mindspore/ccsrc/device/kernel_info.cc b/mindspore/ccsrc/device/kernel_info.cc
deleted file mode 100644
index 59c9b0f411..0000000000
--- a/mindspore/ccsrc/device/kernel_info.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "device/kernel_info.h"
-
-namespace mindspore {
-namespace device {
-const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); }
-
-kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; }
-
-const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const {
-  if (index >= output_address_list_.size()) {
-    MS_LOG(ERROR) << "Index [" << index << "] out of range";
-    return nullptr;
-  }
-  return output_address_list_[index].get();
-}
-
-DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const {
-  if (index >= output_address_list_.size()) {
-    MS_LOG(ERROR) << "Index [" << index << "] out of range";
-    return nullptr;
-  }
-  return output_address_list_[index];
-}
-
-bool KernelInfo::OutputAddrExist(size_t index) const {
-  if (index >= output_address_list_.size()) {
-    return false;
-  }
-  return output_address_list_[index] != nullptr;
-}
-
-bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) {
-  // parameter and valuenode
-  if (kernel_mod_ == nullptr && index >= output_address_list_.size()) {
-    for (size_t i = output_address_list_.size(); i <= index; i++) {
-      output_address_list_.emplace_back(nullptr);
-    }
-  } else if (output_address_list_.empty()) {
-    // set cnode
-    for (size_t i = 0; i < kernel_mod_->GetOutputSizeList().size(); i++) {
-      output_address_list_.emplace_back(nullptr);
-    }
-  }
-  if (index >= output_address_list_.size()) {
-    MS_LOG(ERROR) << "Index [" << index << "] out of range";
-    return false;
-  }
-  output_address_list_[index] = output_address;
-  return true;
-}
-
-DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const {
-  if (index >= workspace_address_list_.size()) {
-    MS_LOG(ERROR) << "Index [" << index << "] out of range";
-    return nullptr;
-  }
-  return workspace_address_list_[index].get();
-}
-
-bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
-  if
(workspace_address_list_.empty()) { - // parameter and valuenode - if (kernel_mod_ == nullptr) { - workspace_address_list_.emplace_back(nullptr); - } else { - // set cnode - for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) { - workspace_address_list_.emplace_back(nullptr); - } - } - } - if (index >= workspace_address_list_.size()) { - MS_LOG(ERROR) << "Index" << index << " out of range"; - return false; - } - workspace_address_list_[index] = output_address; - return true; -} - -void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; } - -kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); } - -const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); } - -bool KernelInfo::operator==(const KernelInfo &other) const { - if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ || - graph_id_ != other.graph_id_) { - return false; - } - if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) || - (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) { - return false; - } - if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) { - if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) { - return false; - } - } - // Currently we only check whether both the kernel_mod_ are initialized or uninitialized. - if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) || - (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) { - return false; - } - // Currently we only check whether both the sizes are equal of output_address_list_ and workspace_address_list_ or - // not. We can complete this check in the future. - if (output_address_list_.size() != other.output_address_list_.size() || - workspace_address_list_.size() != other.workspace_address_list_.size()) { - return false; - } - return true; -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_info.h b/mindspore/ccsrc/device/kernel_info.h deleted file mode 100644 index 84cfaa0fa3..0000000000 --- a/mindspore/ccsrc/device/kernel_info.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_
-#define MINDSPORE_DEVICE_KERNEL_INFO_H_
-
-#include <vector>
-#include <memory>
-#include "kernel/kernel_build_info.h"
-#include "device/ascend/ascend_device_address.h"
-#include "kernel/kernel.h"
-
-namespace mindspore {
-const uint32_t kInvalidGraphId = UINT32_MAX;
-const uint32_t kInvalidDistincLabel = UINT32_MAX;
-namespace device {
-class KernelInfo {
- public:
-  KernelInfo() {
-    kernel_mod_ = nullptr;
-    is_feature_map_ = false;
-    select_kernel_build_info_ = nullptr;
-    output_address_list_ = {};
-    workspace_address_list_ = {};
-    stream_id_ = UINT32_MAX;
-    stream_distinction_label_ = kInvalidDistincLabel;
-    graph_id_ = kInvalidGraphId;
-  }
-  virtual ~KernelInfo() = default;
-
-  const kernel::KernelBuildInfo *select_kernel_build_info() const;
-  kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const;
-  void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
-    select_kernel_build_info_ = select_kernel_build_info;
-  }
-  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
-  const DeviceAddress *GetOutputAddr(size_t index) const;
-  DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
-  bool OutputAddrExist(size_t index) const;
-  bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
-  DeviceAddress *GetWorkspaceAddr(size_t index) const;
-  bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
-  void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
-  kernel::KernelMod *MutableKernelMod() const;
-  const kernel::KernelMod *kernel_mod() const;
-  uint32_t stream_id() const { return stream_id_; }
-  void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
-  uint32_t stream_distinction_label() const { return stream_distinction_label_; }
-  void set_stream_distinction_label(uint32_t stream_distinction_label) {
-    stream_distinction_label_ = stream_distinction_label;
-  }
-  void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
-  uint32_t graph_id() const { return graph_id_; }
-  bool operator==(const KernelInfo &other) const;
-  bool is_feature_map() const { return is_feature_map_; }
-
- private:
-  bool is_feature_map_;
-  kernel::KernelBuildInfoPtr select_kernel_build_info_;
-  std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
-  std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
-  kernel::KernelModPtr kernel_mod_;
-  // stream_id_ is the index of stream object vector
-  uint32_t stream_id_;
-  // stream_distinction_label_ is used mark different op in different stream
-  uint32_t stream_distinction_label_;
-  // record which graph the node belong to
-  uint32_t graph_id_;
-};
-}  // namespace device
-}  // namespace mindspore
-#endif  // MINDSPORE_DEVICE_KERNEL_INFO_H_
diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
deleted file mode 100644
index 7efb4702e0..0000000000
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ /dev/null
@@ -1,772 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "device/kernel_runtime.h"
-#include <functional>
-#include <numeric>
-#include <utility>
-#include <vector>
-#include "common/utils.h"
-#include "common/trans.h"
-#include "utils/utils.h"
-#include "utils/context/ms_context.h"
-#include "operator/ops.h"
-#include "pipeline/parse/python_adapter.h"
-#include "session/kernel_graph.h"
-#include "session/anf_runtime_algorithm.h"
-#include "kernel/common_utils.h"
-#include "kernel/oplib/oplib.h"
-#include "ir/value.h"
-using mindspore::kernel::Address;
-using mindspore::kernel::AddressPtr;
-
-namespace mindspore {
-namespace device {
-KernelRuntime::~KernelRuntime() {
-#ifdef ENABLE_DUMP_E2E
-  dump_conf_ptr_ = nullptr;
-#endif
-}
-
-bool KernelRuntime::Run(session::KernelGraph *graph) {
-  bool ret = false;
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-#if defined(_WIN32) || defined(_WIN64)
-  auto start_time = std::chrono::steady_clock::now();
-#else
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
-#endif
-  bool is_task_sink = context_ptr->enable_task_sink();
-  if (is_task_sink) {
-    ret = RunTask(graph);
-  } else {
-    ret = LaunchKernel(graph);
-  }
-#if defined(_WIN32) || defined(_WIN64)
-  auto end_time = std::chrono::steady_clock::now();
-  std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
-  MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
-#else
-  (void)gettimeofday(&end_time, nullptr);
-  const uint64_t kUSecondInSecond = 1000000;
-  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
-  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
-  MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
-#endif
-  return ret;
-}
-
-// for D to impl
-bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
-  if (graph != nullptr) {
-    return true;
-  }
-  return false;
-}
-
-// for D to impl
-bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
-  if (graph != nullptr) {
-    return true;
-  }
-  return false;
-}
-
-// for D to impl
-bool KernelRuntime::GenTask(const session::KernelGraph *graph) {
-  if (graph != nullptr) {
-    return true;
-  }
-  return false;
-}
-
-bool KernelRuntime::LoadTask(const session::KernelGraph *graph) {
-  if (graph != nullptr) {
-    return true;
-  }
-  return false;
-}
-
-// for D to impl
-bool KernelRuntime::RunTask(const session::KernelGraph *graph) {
-  if (graph != nullptr) {
-    return true;
-  }
-  return false;
-}
-
-bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
-  MS_EXCEPTION_IF_NULL(kernel);
-  if (AnfAlgo::OutputAddrExist(kernel, index)) {
-    return true;
-  }
-  return false;
-}
-
-size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
-    MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
-                                << AnfAlgo::GetOutputTensorNum(node) << "] of node!";
-  }
-  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
-  if (output_type_id == kTypeUnknown) {
-    output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
-  }
-  size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
-  std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
-  auto format = AnfAlgo::GetOutputFormat(node, output_index);
-  if
(shape.empty() && format != kOpFormat_DEFAULT) { - shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); - shape = trans::TransShapeToDevice(shape, format); - } - // scalar's output shape is a empty vector - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - return tensor_size; -} - -void KernelRuntime::AssignMemory(session::KernelGraph *graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(mem_manager_); - mem_manager_->ResetDynamicMemory(); - AssignStaticMemory(graph); - AssignDynamicMemory(graph); - UpdateRefNodeOutputMem(graph); -} - -void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, - session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - RunOpAssignInputMemory(input_tensors, graph); - AssignStaticMemoryValueNode(graph); - for (const auto &cnode : graph->execution_order()) { - RunOpAssignOutputMemory(cnode); - RunOpAssignWorkSpaceMemory(cnode); - } - UpdateRefNodeOutputMem(graph); -} - -void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - // clear input parameter memory resource - for (const auto &input_node : graph->inputs()) { - MS_EXCEPTION_IF_NULL(input_node); - AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); - } - // clear input value node memory resource - for (const auto &value_node : graph->graph_value_nodes()) { - MS_EXCEPTION_IF_NULL(value_node); - AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get()); - } - for (const auto &cnode : graph->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode); - // clear output memory resource - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { - AnfAlgo::SetOutputAddr(nullptr, index, cnode.get()); - } - // clear workspace memory resource - auto kernel_mod = AnfAlgo::GetKernelMod(cnode); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); - for (size_t index = 0; index < workspace_lists.size(); ++index) { - AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get()); - } - } -} - -void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) { - AssignStaticMemoryInput(graph); - AssignStaticMemoryValueNode(graph); - AssignStaticMemoryOutput(graph); -} - -void KernelRuntime::RunOpAssignInputMemory(const std::vector &input_tensors, - const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (input_tensors.size() != graph->inputs().size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() - << " should be equal to graph input parameter size " << graph->inputs().size(); - } - - for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) { - auto item = graph->inputs()[input_index]; - MS_EXCEPTION_IF_NULL(item); - if (!item->isa()) { - continue; - } - auto output_size = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_size; index++) { - MS_EXCEPTION_IF_NULL(input_tensors[input_index]); - if (input_tensors[input_index]->device_address().get() != nullptr) { - AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get()); - continue; - } - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(item, index); - } - auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto device_address = - 
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); - MS_EXCEPTION_IF_NULL(device_address); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetOutputAddr(device_address, index, item.get()); - } - } -} - -void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - return; - } - - for (size_t i = 0; i < output_sizes.size(); ++i) { - if (AnfAlgo::OutputAddrExist(kernel, i)) { - continue; - } - if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - continue; - } - std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); - auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); - device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i)); - MS_EXCEPTION_IF_NULL(device_address); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); - } -} - -void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (kernel->isa()) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); - for (size_t i = 0; i < workspace_lists.size(); ++i) { - auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown); - MS_EXCEPTION_IF_NULL(device_address); - auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]); - if (!ret) { - MS_LOG(EXCEPTION) << "Malloc device memory failed."; - } - AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); - } - } -} - -void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto graph_inputs = graph->inputs(); - auto graph_valid_input = graph->valid_inputs(); - std::vector need_alloc_nodes; - for (size_t i = 0; i < graph_inputs.size(); ++i) { - auto item = graph_inputs[i]; - MS_EXCEPTION_IF_NULL(item); - if (i < graph_valid_input.size() && !graph_valid_input[i]) { - continue; - } - - if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { - auto outs = AnfAlgo::GetAllOutput(item); - for (auto &out : outs) { - MS_EXCEPTION_IF_NULL(out); - if (!out->isa()) { - continue; - } - if (NodeOutputDeviceAddressExist(out, 0)) { - continue; - } - need_alloc_nodes.push_back(out); - } - } - if (!item->isa()) { - continue; - } - if (NodeOutputDeviceAddressExist(item, 0)) { - continue; - } - need_alloc_nodes.push_back(item); - } - - for (auto &item : need_alloc_nodes) { - auto output_size = AnfAlgo::GetOutputTensorNum(item); - for (size_t index = 0; index < output_size; index++) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); - // if graph output is a weight and doesn't 
link to any cnode, it's data type will be unknown - if (output_type_id == kTypeUnknown) { - MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; - output_type_id = AnfAlgo::GetOutputInferDataType(item, index); - } - auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); - auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); - AnfAlgo::SetOutputAddr(address, index, item.get()); - } - } -} - -void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); - std::vector non_communication_op; - // Assign Communicate Op Memory firstly. - for (const auto &node : nodes) { - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - MS_EXCEPTION_IF_NULL(item_with_index.first); - if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { - continue; - } - graph->AddFinalOutputKernel(item_with_index.first); - if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { - AssignCommunicationNodeMem(kStaticMem, item_with_index.first); - } else { - non_communication_op.emplace_back(item_with_index); - } - } - - for (const auto &item_with_index : non_communication_op) { - AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second)); - } -} - -void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto &kernels = graph->execution_order(); - for (auto &kernel : kernels) { - MS_EXCEPTION_IF_NULL(kernel); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel has no output size."; - continue; - } - for (size_t i = 0; i < output_sizes.size(); ++i) { - session::AnfWithOutIndex out_pair(kernel, i); - if (graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = graph->GetRefCorrespondOutput(out_pair); - MS_EXCEPTION_IF_NULL(origin_pair.first); - auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second); - MS_EXCEPTION_IF_NULL(origin_node_output_addr); - auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i); - if (origin_node_output_addr.get() != cur_node_output_addr.get()) { - MS_LOG(INFO) << "REF address is not same, ref node output need address update"; - MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is " - << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i; - AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get()); - } - } - } - } -} - -void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { - AssignCommunicationNodeInputMem(node); - AssignCommunicationNodeOutputMem(flag, node); -} - -void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - auto kernel_mod = AnfAlgo::GetKernelMod(node); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; - return; - } - auto context_ptr = MsContext::GetInstance(); - 
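// [Editorial sketch] The surrounding AssignCommunicationNodeOutputMem packs every output of a
// communication op into one contiguous arena: each raw output size is rounded up to the
// alignment unit, the aligned sizes are summed, and a single MallocOutputMem call backs all
// outputs. A minimal, self-contained illustration of that layout arithmetic follows; kAlign
// and PackAligned are hypothetical names (512-byte kMemAlignSize assumed), not code from this file.
#include <cstddef>
#include <vector>

constexpr std::size_t kAlign = 512;  // assumed alignment unit (kMemAlignSize)

// Compute per-output offsets into one contiguous block plus the block's total size.
std::vector<std::size_t> PackAligned(const std::vector<std::size_t> &raw, std::size_t *total) {
  std::vector<std::size_t> offsets;
  *total = 0;
  for (std::size_t s : raw) {
    offsets.push_back(*total);                     // output j begins at the running total
    *total += (s + kAlign - 1) / kAlign * kAlign;  // round each size up to kAlign
  }
  return offsets;  // base_ptr + offsets[j] is the address handed to output j
}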
MS_EXCEPTION_IF_NULL(context_ptr); - size_t total_size = 0; - size_t output_index = 0; - std::vector align_size_list; - for (uint64_t mem_size : output_sizes) { - if (AnfAlgo::OutputAddrExist(node, output_index++)) { - MS_LOG(INFO) << "communication op addr exist"; - continue; - } - if (context_ptr->enable_hccl()) { - mem_size = mem_manager_->GetCommonAlignSize(mem_size); - } - total_size += mem_size; - align_size_list.emplace_back(mem_size); - } - uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); - for (size_t j = 0; j < align_size_list.size(); ++j) { - std::string output_format = AnfAlgo::GetOutputFormat(node, j); - auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); - auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); - MS_EXCEPTION_IF_NULL(address); - if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { - address->UpdateCommunicationAddress(); - } - AnfAlgo::SetOutputAddr(address, j, node.get()); - output_ptr += align_size_list[j]; - } -} - -DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.size() <= index) { - MS_LOG(EXCEPTION) << "Previous node output size < node index"; - } - std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index); - auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index); - auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type); - AnfAlgo::SetOutputAddr(address, index, anf_node.get()); - return address; -} - -void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - size_t total_size = 0; - std::vector> addr_size; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); - auto input_node = input_node_with_index.first; - DeviceAddressPtr address = nullptr; - if (input_node->isa()) { - address = PreAssignCNodeMemory(input_node, input_node_with_index.second); - } else { - MS_LOG(EXCEPTION) << "Communication node inputs only support CNode"; - } - MS_EXCEPTION_IF_NULL(address); - auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); - total_size += mem_size; - addr_size.emplace_back(address.get(), mem_size); - } - uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(input_ptr); - input_ptr += iter.second; - } -} - -void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(mem_manager_); - if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { - MS_LOG(INFO) << "GetNext disable mem_reuse"; - flag = kDynamicMem; - } - auto kernel_mod = AnfAlgo::GetKernelMod(node); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; - return; - } - for (size_t i = 0; i < output_sizes.size(); ++i) { - if 
((kGetAllOuts != index) && (SizeToInt(i) != index)) {
-      continue;
-    }
-    if (NodeOutputDeviceAddressExist(node, i)) {
-      MS_LOG(INFO) << "Already malloc index:" << i;
-      continue;
-    }
-    auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]);
-    if (ptr == nullptr) {
-      // reused ptr, no need alloc, continue;
-      continue;
-    }
-    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
-    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
-    MS_EXCEPTION_IF_NULL(device_address);
-    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
-    if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) {
-      device_address->UpdateCommunicationAddress();
-    }
-    AnfAlgo::SetOutputAddr(device_address, i, node.get());
-  }
-}
-
-void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
-                                          size_t output_idx) {
-  MS_EXCEPTION_IF_NULL(value_node);
-  MS_EXCEPTION_IF_NULL(node_value);
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  auto tensor = node_value->cast<tensor::TensorPtr>();
-  if (tensor == nullptr) {
-    MS_LOG(WARNING) << "Tensor is null";
-    return;
-  }
-  size_t tensor_size = tensor->data().nbytes();
-  auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
-  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
-  if (output_type_id == kTypeUnknown) {
-    output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
-  }
-  auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
-  DeviceAddressPtr address = nullptr;
-  if (ms_context->enable_pynative_infer()) {
-    address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
-    MS_EXCEPTION_IF_NULL(address);
-    if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
-      MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
-    }
-  } else {
-    auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
-    address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id);
-    MS_EXCEPTION_IF_NULL(address);
-  }
-  AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
-  if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
-                                 tensor->data_c())) {
-    MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
-                                 << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
-                                 << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
-  }
-}
-
-void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  for (auto &value_node : graph->graph_value_nodes()) {
-    MS_EXCEPTION_IF_NULL(value_node);
-    if (NodeOutputDeviceAddressExist(value_node, 0)) {
-      MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exist";
-      continue;
-    }
-    auto &node_value = value_node->value();
-    MS_EXCEPTION_IF_NULL(node_value);
-    if (node_value->isa<tensor::Tensor>()) {
-      AssignValueNodeTensor(value_node, node_value, 0);
-    } else if (node_value->isa<StringImm>()) {
-      auto value = GetValue<std::string>(node_value);
-      size_t tensor_size = value.size();
-      DeviceAddressPtr address = nullptr;
-      if (ms_context->enable_pynative_infer()) {
-        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
-        MS_EXCEPTION_IF_NULL(address);
-        if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
-          MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
-        }
-      } else {
-        auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
-        address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
-        MS_EXCEPTION_IF_NULL(address);
-      }
-      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
-      std::vector<int> shape = {1, SizeToInt(tensor_size)};
-      if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
-        MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
-      }
-    }
-  }
-}
-
-void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  bool is_enable_mem_reuse = context_ptr->enable_mem_reuse();
-  auto mem_flag = kDynamicMem;
-  if (is_enable_mem_reuse) {
-    mem_manager_->MallocReusedDynamicMem(graph);
-    mem_flag = kReuseDynamicMem;
-  }
-  auto &execution_nodes = graph->execution_order();
-  std::vector<CNodePtr> compute_nodes;
-  // communication nodes first
-  for (auto &node : execution_nodes) {
-    if (AnfAlgo::IsCommunicationOp(node)) {
-      // skip if the memory is already alocated
-      AssignCommunicationNodeMem(mem_flag, node);
-    } else {
-      compute_nodes.emplace_back(node);
-    }
-  }
-
-  // then compute nodes
-  for (auto &node : compute_nodes) {
-    AssignNodeOutputMem(mem_flag, node, kGetAllOuts);
-    AssignWorkSpaceMem(mem_flag, node);
-  }
-}
-
-void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(mem_manager_);
-  auto kernel_mod = AnfAlgo::GetKernelMod(node);
-  MS_EXCEPTION_IF_NULL(kernel_mod);
-  size_t index = 0;
-  for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
-    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size);
-    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
-    index++;
-  }
-}
-
-void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
-                                  AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
-                                  AddressPtrList *kernel_outputs) {
-  MS_EXCEPTION_IF_NULL(kernel);
-  MS_EXCEPTION_IF_NULL(kernel_inputs);
-  MS_EXCEPTION_IF_NULL(kernel_workspaces);
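// [Editorial sketch] GenLaunchArgs below walks a kernel's inputs, outputs and workspaces,
// wrapping each device buffer in an {addr, size} record that KernelMod::Launch consumes.
// A minimal standalone illustration of that wrapping step; Address and Append here mirror
// the pattern only and are hypothetical, not declarations from this file.
#include <cstddef>
#include <memory>
#include <vector>

struct Address {
  void *addr;        // device pointer taken from the node's DeviceAddress
  std::size_t size;  // byte size fixed when memory was assigned
};
using AddressPtr = std::shared_ptr<Address>;

// Append one device buffer to a launch-argument list.
void Append(std::vector<AddressPtr> *list, void *ptr, std::size_t size) {
  auto rec = std::make_shared<Address>();
  rec->addr = ptr;
  rec->size = size;
  list->emplace_back(rec);
}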
MS_EXCEPTION_IF_NULL(kernel_outputs); - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { - return GenAddrCleanLaunchArgs(cnode, kernel_inputs); - } - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { - auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); - MS_EXCEPTION_IF_NULL(device_address); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - - for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetOutputAddr(kernel, i); - kernel::AddressPtr output = std::make_shared(); - MS_EXCEPTION_IF_NULL(output); - output->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(output->addr); - output->size = device_address->size_; - kernel_outputs->emplace_back(output); - } - - for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { - auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_workspaces->emplace_back(workspace); - } -} - -void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) { - if (cnode->inputs().size() != 2) { - MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2."; - } - MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); - auto pre_node = (cnode->inputs()[1])->cast(); - // set clean output address - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->emplace_back(input); - } - MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); - } - // set clean workspace address - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspaces_indexs) { - auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_inputs->emplace_back(workspace); - } - } -} - -bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { - auto &kernels = graph.execution_order(); - for (const auto &kernel : kernels) { - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - - AddressPtrList kernel_inputs; - AddressPtrList kernel_workspaces; - AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); - auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); - if (!ret) { - MS_LOG(ERROR) << 
"Launch kernel failed."; - return false; - } - } - return true; -} - -bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - if (!LaunchKernelMod(*graph)) { - MS_LOG(ERROR) << "LaunchKernelMod failed!"; - return false; - } - return true; -} - -void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { - MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; -} - -#ifdef ENABLE_DUMP_E2E -bool KernelRuntime::SetDumpConf() { - dump_conf_ptr_ = std::make_shared(); - MS_EXCEPTION_IF_NULL(dump_conf_ptr_); - bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile(); - return ret; -} - -DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; } -#endif -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h deleted file mode 100644 index 8c6a5eb19b..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ -#define MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ -#include -#include -#include -#include - -#include "device/device_address.h" -#include "ir/tensor.h" -#include "predict/generator/utils/ir_model_util.h" -#ifdef ENABLE_DUMP_E2E -#include "debug/e2e_dump.h" -#endif -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel.h" -#include "utils/context/ms_context.h" -#include "device/memory_manager.h" - -using mindspore::tensor::Tensor; -using std::vector; -using TensorPtr = std::shared_ptr; -using mindspore::kernel::AddressPtr; -using AddressPtrList = std::vector; - -namespace mindspore { -#ifndef ENABLE_DEBUGGER -class Debugger; -#endif -namespace device { -class KernelRuntime { - public: - KernelRuntime() = default; - virtual ~KernelRuntime(); - virtual bool Init() = 0; - virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); - void RunOpClearMemory(const session::KernelGraph *graph); - virtual bool Run(session::KernelGraph *graph); - virtual bool DumpData(session::KernelGraph *graph); - virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); - virtual bool RunTask(const session::KernelGraph *graph); - virtual bool GenTask(const session::KernelGraph *graph); - bool LaunchKernel(const session::KernelGraph *graph); - virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); - virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); - virtual void ClearGraphRuntimeResource(uint32_t graph_id); - virtual bool SyncStream() = 0; - -#ifdef ENABLE_DUMP_E2E - DumpConfPtr GetDumpConf(); -#endif - virtual bool LoadTask(const session::KernelGraph *graph); - // for GPU and D to impl - virtual void 
ReleaseDeviceRes() {}
-  void set_device_id(uint32_t device_id) { device_id_ = device_id; }
-
- protected:
-  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
-                                               TypeId type_id) = 0;
-  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
-  void AssignStaticMemory(session::KernelGraph *graph);
-  void AssignDynamicMemory(session::KernelGraph *graph);
-  void ReuseAssignDynamicMemory(session::KernelGraph *graph);
-  void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index);
-  void AssignWorkSpaceMem(int flag, const AnfNodePtr &node);
-  void AssignReuseWorkSpaceMem(const AnfNodePtr &node);
-
-  void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
-
-  void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
-  void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
-#ifdef ENABLE_DUMP_E2E
-  bool SetDumpConf();
-#endif
-
- private:
-  void AssignStaticMemoryOutput(session::KernelGraph *graph);
-  void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
-                     AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs);
-  bool LaunchKernelMod(const session::KernelGraph &graph);
-  void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
-  size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
-  void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
-  void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
-  void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
-  void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
-  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index);
-
- protected:
-  uint32_t device_id_{0};
-#ifdef ENABLE_DUMP_E2E
-  DumpConfPtr dump_conf_ptr_;
-#endif
-  void *stream_ = nullptr;
-  std::shared_ptr<MemoryManager> mem_manager_{nullptr};
-};
-using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
-}  // namespace device
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_
diff --git a/mindspore/ccsrc/device/kernel_runtime_manager.cc b/mindspore/ccsrc/device/kernel_runtime_manager.cc
deleted file mode 100644
index 29d74762b4..0000000000
--- a/mindspore/ccsrc/device/kernel_runtime_manager.cc
+++ /dev/null
@@ -1,94 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "device/kernel_runtime_manager.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -void KernelRuntimeManager::ClearRuntimeResource() { - std::lock_guard guard(lock_); - for (auto &iter : runtime_map_) { - MS_LOG(INFO) << "Release device " << iter.first; - MS_EXCEPTION_IF_NULL(iter.second); - iter.second->ReleaseDeviceRes(); - } - runtime_map_.clear(); -} - -void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) { - std::lock_guard guard(lock_); - for (auto &iter : runtime_map_) { - MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource"; - if (!iter.second) { - MS_LOG(ERROR) << "Kernel runtime is nullptr"; - continue; - } - iter.second->ClearGraphRuntimeResource(graph_id); - } -} - -void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { - if (runtime_creators_.find(device_name) == runtime_creators_.end()) { - (void)runtime_creators_.emplace(device_name, runtime_creator); - } -} - -KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) { - std::string runtime_key = device_name + "_" + std::to_string(device_id); - auto runtime_iter = runtime_map_.find(runtime_key); - if (runtime_iter != runtime_map_.end()) { - return runtime_iter->second.get(); - } else if (runtime_map_.size() > 0) { - auto cur_runtime_key = runtime_map_.begin()->first; - auto find_pos = cur_runtime_key.rfind('_'); - if (find_pos != std::string::npos) { - if (cur_runtime_key.size() > find_pos + 1) { - auto cur_device_id = cur_runtime_key.substr(find_pos + 1); - MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id - << ", set device id: " << device_id << " failed"; - } else { - MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: " - << device_id << " failed"; - } - } - } - return GetKernelRuntime(device_name, device_id); -} - -KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) { - std::lock_guard guard(lock_); - std::string runtime_key = device_name + "_" + std::to_string(device_id); - auto runtime_iter = runtime_map_.find(runtime_key); - if (runtime_iter != runtime_map_.end()) { - return runtime_iter->second.get(); - } - std::shared_ptr kernel_runtime; - auto creator_iter = runtime_creators_.find(device_name); - if (creator_iter != runtime_creators_.end()) { - MS_EXCEPTION_IF_NULL(creator_iter->second); - kernel_runtime = (creator_iter->second)(); - kernel_runtime->set_device_id(device_id); - MS_EXCEPTION_IF_NULL(kernel_runtime); - runtime_map_[runtime_key] = kernel_runtime; - } else { - MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id; - } - - return kernel_runtime.get(); -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/kernel_runtime_manager.h b/mindspore/ccsrc/device/kernel_runtime_manager.h deleted file mode 100644 index 89b45ff5f8..0000000000 --- a/mindspore/ccsrc/device/kernel_runtime_manager.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_
-#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_
-#include <functional>
-#include <map>
-#include <memory>
-#include <mutex>
-#include <string>
-#include <utility>
-#include "common/utils.h"
-#include "device/kernel_runtime.h"
-namespace mindspore {
-namespace device {
-using KernelRuntimeCreator = std::function<std::shared_ptr<KernelRuntime>()>;
-
-class KernelRuntimeManager {
- public:
-  static KernelRuntimeManager &Instance() {
-    static KernelRuntimeManager instance;
-    return instance;
-  }
-  void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator);
-  KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id);
-  KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id);
-  void ClearRuntimeResource();
-  void ClearGraphResource(uint32_t graph_id);
-
- private:
-  KernelRuntimeManager() = default;
-  ~KernelRuntimeManager() = default;
-  DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager);
-  std::map<std::string, std::shared_ptr<KernelRuntime> > runtime_map_;
-  std::map<std::string, KernelRuntimeCreator> runtime_creators_;
-  std::mutex lock_;
-};
-
-class KernelRuntimeRegistrar {
- public:
-  KernelRuntimeRegistrar(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) {
-    KernelRuntimeManager::Instance().Register(device_name, std::move(runtime_creator));
-  }
-  ~KernelRuntimeRegistrar() = default;
-};
-
-#define MS_REG_KERNEL_RUNTIME(DEVICE_NAME, RUNTIME_CLASS)                   \
-  static const KernelRuntimeRegistrar g_kernel_runtime_##DEVICE_NAME##_reg( \
-    DEVICE_NAME, []() { return std::make_shared<RUNTIME_CLASS>(); });
-}  // namespace device
-}  // namespace mindspore
-#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_
diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc
deleted file mode 100644
index c6a2329e8f..0000000000
--- a/mindspore/ccsrc/device/memory_manager.cc
+++ /dev/null
@@ -1,213 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "device/memory_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/context/ms_context.h" -using mindspore::memreuse::BestFitMemReuse; -using mindspore::memreuse::MemReuseUtilPtr; -namespace mindspore { -namespace device { -size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { - return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; -} - -size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { - return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; -} - -void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - // set all infos - mem_reuse_util_ptr->SetAllInfo(graph); - auto bestfit_mem_reuse = std::make_shared(); - MS_EXCEPTION_IF_NULL(bestfit_mem_reuse); - bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get()); - size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize(); - MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]"; - mem_reuse_util_ptr_ = mem_reuse_util_ptr; - auto base_ptr = MallocDynamicMem(total_allocated_size, false); - mem_reuse_util_ptr_->set_mem_base(base_ptr); -} - -uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - MS_EXCEPTION_IF_NULL(node); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - uint8_t *ptr = nullptr; - if (AnfAlgo::IsCommunicationOp(node)) { - bool communication_mem = false; - if (context_ptr->enable_hccl()) { - communication_mem = true; - } - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, communication_mem); - } else { - ptr = MallocDynamicMem(size, communication_mem); - } - return ptr; - } - - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { - ptr = MallocDynamicMem(size, false); - } else if (flag == kReuseDynamicMem) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); - ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); - } - return ptr; -} - -uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { - if (flag == kReuseDynamicMem) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); - return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); - } - return MallocDynamicMem(size, false); -} - -uint8_t *MemoryManager::MallocMem(int flag, size_t size) { - uint8_t *ptr = nullptr; - if (flag == kStaticMem) { - ptr = MallocStaticMem(size, false); - } else if (flag == kDynamicMem) { - ptr = MallocDynamicMem(size, false); - } - return ptr; -} - -uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - - MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] communication_mem: " << communication_mem; - - if (static_mem_offset_ < align_size) { - MS_LOG(EXCEPTION) << "Out of memory!!! 
total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_static_size_ += align_size; - auto offset = static_mem_offset_ - align_size; - if (dynamic_mem_offset_ > offset) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - static_mem_offset_ = offset; - if (communication_mem) { - return device_mem_base_ + offset + kMemAlignSize; - } else { - return device_mem_base_ + offset; - } -} - -uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { - size_t align_size = 0; - if (communication_mem) { - align_size = GetCommunicationAlignSize(size); - } else { - align_size = GetCommonAlignSize(size); - } - - MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] communication_mem: " << communication_mem; - - uint64_t offset = dynamic_mem_offset_; - auto new_offset = dynamic_mem_offset_ + align_size; - if (new_offset > static_mem_offset_) { - MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ - << "] static[" << total_static_size_ << "])" - << " malloc [" << align_size << "] failed!"; - } - total_dynamic_size_ += align_size; - dynamic_mem_offset_ = new_offset; - - if (communication_mem) { - return device_mem_base_ + offset + kMemAlignSize; - } else { - return device_mem_base_ + offset; - } -} - -bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { - auto device_ptr = MallocMemFromMemPool(size); - if (!device_ptr) { - return false; - } - address->ptr_ = device_ptr; - address->from_mem_pool_ = true; - return true; -} - -void *MemoryManager::MallocMemFromMemPool(size_t size) { - if (size == 0) { - MS_LOG(ERROR) << "MallocMemFromMemPool size is 0."; - } - return nullptr; -} - -void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { - MS_EXCEPTION_IF_NULL(address); - MS_EXCEPTION_IF_NULL(address->ptr_); - FreeMemFromMemPool(address->ptr_); - address->ptr_ = nullptr; -} - -void MemoryManager::FreeMemFromMemPool(void *device_ptr) { - if (device_ptr == nullptr) { - MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; - } -} - -bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list) { - auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); - if (device_ptr_list.size() == 0) { - return false; - } - if (addr_list.size() != device_ptr_list.size()) { - MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; - } - for (size_t i = 0; i < addr_list.size(); i++) { - MS_EXCEPTION_IF_NULL(device_ptr_list[i]); - MS_EXCEPTION_IF_NULL(addr_list[i]); - addr_list[i]->ptr_ = device_ptr_list[i]; - addr_list[i]->from_mem_pool_ = true; - } - return true; -} - -std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { - if (total_size == 0) { - MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; - } - std::vector device_ptr_list; - device_ptr_list.emplace_back(nullptr); - return device_ptr_list; -} -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/device/memory_manager.h 
b/mindspore/ccsrc/device/memory_manager.h deleted file mode 100644 index fb9c539adb..0000000000 --- a/mindspore/ccsrc/device/memory_manager.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ -#include -#include -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -namespace mindspore { -namespace device { -const int kStaticMem = 0; -const int kDynamicMem = 1; -const int kReuseDynamicMem = 2; -const int kGetAllOuts = -1; -const uint64_t kMemAlignSize = 512; -using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; - -class MemoryManager { - public: - MemoryManager() = default; - virtual ~MemoryManager() = default; - - virtual void MallocDeviceMemory() = 0; - virtual void FreeDeviceMemory() = 0; - virtual void ResetDynamicMemory() { - total_dynamic_size_ = 0; - dynamic_mem_offset_ = 0; - } - - void MallocReusedDynamicMem(session::KernelGraph *graph); - uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); - virtual uint8_t *MallocMem(int flag, size_t size); - - virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); - virtual void *MallocMemFromMemPool(size_t size); - virtual void FreeMemFromMemPool(const DeviceAddressPtr address); - virtual void FreeMemFromMemPool(void *device_ptr); - virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, - std::vector size_list); - virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); - - size_t GetCommonAlignSize(size_t input_size) const; - size_t GetCommunicationAlignSize(size_t input_size) const; - - protected: - virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem); - virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); - uint8_t *device_mem_base_{nullptr}; - uint64_t device_mem_size_{0}; - uint64_t dynamic_mem_offset_{0}; - uint64_t static_mem_offset_{0}; - size_t total_static_size_ = 0; - size_t total_dynamic_size_ = 0; - MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; -}; -} // namespace device -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/operator/CMakeLists.txt b/mindspore/ccsrc/frontend/operator/CMakeLists.txt new file mode 100644 index 0000000000..0b6dd77c69 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) +add_library(_mindspore_frontend_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) diff --git 
a/mindspore/ccsrc/frontend/operator/cc_implementations.cc b/mindspore/ccsrc/frontend/operator/cc_implementations.cc new file mode 100644 index 0000000000..3ec3455be7 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/cc_implementations.cc @@ -0,0 +1,432 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/cc_implementations.h" +#include +#include +#include +#include +#include +#include "utils/misc.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support primitive operators definition +namespace prim { +enum class DataType { kInt, kFloat, kDouble, kUnknown }; + +// Whether the AnyPtrList contains a value of type T. +template +bool HasType(const AnyPtrList &list) { + bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); + return ret; +} + +DataType InferType(const AnyPtrList &list) { + if (HasType(list)) { + return DataType::kDouble; + } else if (HasType(list)) { + return DataType::kFloat; + } else if (HasType(list)) { + return DataType::kInt; + } + return DataType::kUnknown; +} + +enum OpType { ADD, SUB, MUL, DIV, MOD }; + +template +bool IsSignedIntOverflow(T x, T y, OpType opType) { + auto max = std::numeric_limits::max(); + auto min = std::numeric_limits::min(); + + if (opType == OpType::ADD) { + return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); + } + + if (opType == OpType::SUB) { + return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); + } + + if (opType == OpType::MUL) { + return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || + (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x); + } + + if (opType == OpType::DIV || opType == OpType::MOD) { + return x == min && static_cast(y) == -1; + } + + MS_LOG(EXCEPTION) << "Unsupported operation type."; + } + +template +T InnerScalarAdd(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::ADD)) { + MS_LOG(EXCEPTION) << "Overflow of the sum of two signed numbers, x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x + y; +} + +template +T InnerScalarSub(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::SUB)) { + MS_LOG(EXCEPTION) << "Overflow of the sub of two signed numbers, x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x - y; +} + +template +T InnerScalarMul(T x, T y) { + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MUL)) { + MS_LOG(EXCEPTION) << "Overflow of the mul of two signed numbers, x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return x * y; +} + +template +float InnerScalarDiv(T x, T y) { + if (y == 0) { + MS_LOG(EXCEPTION) << "Divisor cannot be zero."; + } + if (std::is_integral::value && std::is_signed::value && 
IsSignedIntOverflow(x, y, OpType::DIV)) { + MS_LOG(EXCEPTION) << "Overflow of the div of two signed numbers, x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + return static_cast(x) / static_cast(y); +} + +template +T InnerScalarFloordiv(T x, T y) { + auto ret = std::floor(InnerScalarDiv(x, y)); + if (std::is_integral::value) { + return static_cast(ret); + } + return ret; +} + +template +T InnerScalarMod(T x, T y) { + if (y == 0) { + MS_LOG(EXCEPTION) << "Cannot mod by zero."; + } + if (std::is_integral::value && std::is_signed::value && IsSignedIntOverflow(x, y, OpType::MOD)) { + MS_LOG(EXCEPTION) << "Overflow of the mod of two signed numbers, x: " << std::to_string(x) + << ", y: " << std::to_string(y) << "."; + } + if (std::is_integral::value) { + return static_cast(x) % static_cast(y); + } + int x_int = std::floor(x); + int y_int = std::ceil(y); + int max = x_int / y_int; + float ret = x - y * max; + return ret; +} + +template +T InnerScalarPow(T x, U y) { + return std::pow(x, y); +} + +template +bool InnerScalarEq(T x, U y) { + double error = static_cast(x) - static_cast(y); + error = fabs(error); + return error < DBL_EPSILON; +} + +template +bool InnerScalarLt(T x, U y) { + return x < y; +} + +template +bool InnerScalarGt(T x, U y) { + return x > y; +} + +template +bool InnerScalarNe(T x, U y) { + return !InnerScalarEq(x, y); +} + +template +bool InnerScalarLe(T x, U y) { + return x <= y; +} + +template +bool InnerScalarGe(T x, U y) { + return x >= y; +} + +#define SCALAR_OP(op_t) \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ + do { \ + if (list.size() < 2) { \ + MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ + } \ + ValuePtr x = list[0]; \ + ValuePtr y = list[1]; \ + MS_EXCEPTION_IF_NULL(x); \ + MS_EXCEPTION_IF_NULL(y); \ + if (x->isa() && y->isa()) { \ + double sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + int sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(IntToFloat(GetValue(x)), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + float sum = InnerScalar##op_t(GetValue(x), IntToFloat(GetValue(y))); \ + return MakeValue(sum); \ + } \ + MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ + << ", y: " << y->ToString(); \ + } while (0); \ + } + +SCALAR_OP(Add) +SCALAR_OP(Sub) +SCALAR_OP(Mul) +SCALAR_OP(Div) +SCALAR_OP(Mod) +SCALAR_OP(Pow) +SCALAR_OP(Floordiv) + +#define LOGIC_OP(op_t) \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ + if (list.size() < 2) { \ + MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ + } \ + ValuePtr x = list[0]; \ + ValuePtr y = list[1]; \ + MS_EXCEPTION_IF_NULL(x); \ + MS_EXCEPTION_IF_NULL(y); \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return 
MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + if (x->isa() && y->isa()) { \ + bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ + return MakeValue(sum); \ + } \ + MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ + << ", y: " << y->ToString() << "."; \ + } + +LOGIC_OP(Eq) +LOGIC_OP(Lt) +LOGIC_OP(Gt) +LOGIC_OP(Ne) +LOGIC_OP(Le) +LOGIC_OP(Ge) + +ValuePtr ScalarUAdd(const ValuePtrList &list) { + if (list.size() != 1) { + MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + return x; +} + +ValuePtr ScalarUSub(const ValuePtrList &list) { + if (list.size() != 1) { + MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + + if (x->isa()) { + int32_t sum = -1 * GetValue(x); + return MakeValue(sum); + } + if (x->isa()) { + float sum = -1.0f * GetValue(x); + return MakeValue(sum); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for ScalarUSub, x: " << x->ToString() << "."; +} + +ValuePtr ScalarLog(const ValuePtrList &list) { + if (list.empty()) { + MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + + if (x->isa()) { + double v = log(GetValue(x)); + return MakeValue(v); + } + if (x->isa()) { + auto v = static_cast(log(GetValue(x))); + return MakeValue(v); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for ScalarLog, x: " << x->ToString(); +} + +ValuePtr BoolNot(const ValuePtrList &list) { + if (list.empty()) { + MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; + } + ValuePtr x = list[0]; + MS_EXCEPTION_IF_NULL(x); + bool convert = false; + + if (ValueToBool(x, &convert)) { + auto res = !convert; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for BoolNot, x: " << x->ToString(); +} + +ValuePtr BoolAnd(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less than 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b = false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b && y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for BoolAnd, x: " << x->ToString() << "."; +} + +ValuePtr BoolOr(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less than 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b = false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b || y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for BoolOr, x: " << x->ToString() << "."; +} + +ValuePtr BoolEq(const ValuePtrList &list) { + if (list.size() < 2) { + MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; + } + ValuePtr x = list[0]; + ValuePtr y = list[1]; + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(y); + bool x_b 
= false; + bool y_b = false; + + if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { + auto res = x_b == y_b; + return MakeValue(res); + } + + MS_LOG(EXCEPTION) << "Unsupported Value for BoolEq, x: " << x->ToString() << "."; +} + +std::vector BroadcastShape_(std::vector shpx, std::vector shpy) { + int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); + if (dlen < 0) { + for (int i = 0; i < -dlen; ++i) { + (void)shpx.insert(shpx.begin(), 1); + } + } else if (dlen > 0) { + for (int i = 0; i < dlen; i++) { + (void)shpy.insert(shpy.begin(), 1); + } + } + if (shpx.size() != shpy.size()) { + MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; + } + std::vector shp; + for (size_t i = 0; i < shpx.size(); i++) { + auto a = shpx[i]; + auto b = shpy[i]; + if (a == 1) { + shp.push_back(b); + } else if (b == 1) { + shp.push_back(a); + } else if (a == -1) { + shp.push_back(b); + } else if (b == -1) { + shp.push_back(a); + } else if (a == b) { + shp.push_back(a); + } else { + return std::vector(); + } + } + return shp; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/frontend/operator/cc_implementations.h similarity index 100% rename from mindspore/ccsrc/operator/cc_implementations.h rename to mindspore/ccsrc/frontend/operator/cc_implementations.h diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc new file mode 100644 index 0000000000..7d2573e50a --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -0,0 +1,971 @@ + +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/operator/composite/composite.h" +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" +#include "./common.h" +#include "ir/signature.h" +#include "debug/trace.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractTensor = mindspore::abstract::AbstractTensor; +using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; + +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractEllipsis; +using mindspore::abstract::AbstractEllipsisPtr; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractFunctionPtr; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractNone; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractSlice; +using mindspore::abstract::AbstractTuple; + +ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul}, + {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod}, + {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt}, + {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, + {"__ge__", kPrimScalarGe}}; + +const MetaFuncGraphPtr kTail = std::make_shared("tail"); + +// copy from python API: reduce. +// Apply a function of two arguments cumulatively to the items of a sequence, +// from left to right, so as to reduce the sequence to a single value.For example, +// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). 
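To make the fold order concrete, here is a minimal, self-contained C++ sketch of the same left-to-right reduction over plain ints; ReduceInts and its argument types are illustrative stand-ins, not part of this patch:

#include <functional>
#include <stdexcept>
#include <vector>

int ReduceInts(const std::function<int(int, int)> &func, const std::vector<int> &list) {
  if (list.size() < 2) {
    throw std::invalid_argument("length of inputs of Reduce is less than 2");
  }
  int ret = func(list[0], list[1]);  // seed the fold with the first two items
  for (size_t i = 2; i < list.size(); ++i) {
    ret = func(ret, list[i]);  // then accumulate the remaining items from left to right
  }
  return ret;
}

// ReduceInts([](int a, int b) { return a + b; }, {1, 2, 3, 4, 5}) returns 15,
// i.e. ((((1 + 2) + 3) + 4) + 5), matching the comment above.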
+AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { + std::shared_ptr ret; + size_t size = list.size(); + if (size < 2) { + MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; + } + + AnyPtrList input; + input.push_back(list[0]); + input.push_back(list[1]); + ret = std::make_shared(func(input)); + + for (size_t i = 2; i < size; ++i) { + input.clear(); + input.push_back(ret); + input.push_back(list[i]); + ret = std::make_shared(func(input)); + } + + return ret; +} + +AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { + size_t size = list.size(); + if (size < 2) { + MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; + } + + std::vector input; + input.push_back(list[0]); + input.push_back(list[1]); + AnfNodePtr ret = func(input); + + for (size_t i = 2; i < size; ++i) { + input.clear(); + input.push_back(ret); + input.push_back(list[i]); + ret = func(input); + } + + return ret; +} + +ValuePtr kCompositeHyperMap = std::make_shared(); + +void HyperMap::Init() { + if (fn_leaf_) { + name_ = "hyper_map[" + fn_leaf_->name() + "]"; + } + signatures_ = + // def hypermap(func:read, *args:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); +} + +HyperMap::HyperMap(const std::shared_ptr &fn_leaf) + : MetaFuncGraph("hyper_map"), + fn_leaf_(fn_leaf), + broadcast_(false), + nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { + Init(); +} + +HyperMap::HyperMap(const HyperMap &h) + : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Init(); +} + +AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + std::vector inputs; + if (fn_arg != nullptr) { + inputs.push_back(fn_arg); + } else { + inputs.push_back(NewValueNode(fn_leaf_)); + } + + (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), + [](const std::pair &item) { return item.first; }); + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { + auto lhs = std::static_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "List in HyperMap should have same length"; + } + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. 
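The ownership concern in the comment above can be shown in isolation: if the generated graph held a shared_ptr obtained from shared_from_this(), the functor and the graph would keep each other alive and neither would ever be destroyed. A minimal sketch with hypothetical Functor/Node types (not part of this patch):

#include <memory>

struct Node;
struct Functor : std::enable_shared_from_this<Functor> {
  std::shared_ptr<Node> graph;  // the functor owns the graph it generates
};
struct Node {
  std::shared_ptr<Functor> fn;  // back-reference to the functor that built the node
};

void Build(const std::shared_ptr<Functor> &f) {
  f->graph = std::make_shared<Node>();
  // f->graph->fn = f->shared_from_this();       // f -> graph -> f: a cycle, so nothing is freed
  f->graph->fn = std::make_shared<Functor>(*f);  // copying *this, as done below, breaks the cycle
}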
+ auto fn_rec = NewValueNode(std::make_shared(*this)); + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeList)); + + for (int i = 0; i < SizeToInt(size); ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), + [&func_graph, i](const std::pair &item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(type); + + std::size_t size = type->elements().size(); + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { + auto lhs = std::static_pointer_cast(item.second); + MS_EXCEPTION_IF_NULL(lhs); + return lhs->elements().size() != size; + }); + if (is_not_same) { + MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length"; + } + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. + auto fn_rec = NewValueNode(std::make_shared(*this)); + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (int i = 0; i < SizeToInt(size); ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); + inputs.push_back(NewValueNode(type)); + + // cannot use shared_from_base() also known as this, as it will make a reference cycle on + // hypermap and graph generated, it will cause memory leak. + auto fn_rec = NewValueNode(std::make_shared(*this)); + std::size_t attrSize = type->GetAttributes().size(); + for (std::size_t i = 0; i < attrSize; ++i) { + std::vector inputs2; + inputs2.push_back(fn_rec); + if (fn_arg) { + inputs2.push_back(fn_arg); + } + + int j = 0; + for (auto item : arg_map) { + inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + j++; + } + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { + bool found = false; + TypeId id = kObjectTypeEnd; + std::pair pair; + for (auto &item : arg_map) { + pair = item; + id = item.second->type_id(); + if (nonleaf_.count(id)) { + found = true; + break; + } + } + + if (found) { + // In a nonleaf situation, all arguments must have the same generic. 
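For intuition about the elementwise recursion these FullMake overloads implement: mapping over same-typed containers behaves like a zip-and-apply. A rough stand-alone C++17 analogue over std::tuple, with TupleZipApply as a hypothetical helper (the real code builds graph IR instead):

#include <tuple>

template <typename F, typename... Ts>
auto TupleZipApply(F f, const std::tuple<Ts...> &a, const std::tuple<Ts...> &b) {
  // Unpack both tuples and rebuild a tuple of f applied to matching elements.
  return std::apply(
      [&](const Ts &... xs) {
        return std::apply([&](const Ts &... ys) { return std::make_tuple(f(xs, ys)...); }, b);
      },
      a);
}

// TupleZipApply([](int x, int y) { return x + y; }, std::make_tuple(1, 2), std::make_tuple(10, 20))
// yields (11, 22), much like hyper_map applied to two same-length tuples.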
+ bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { + if (item.first != pair.first) { + return item.second->type_id() != pair.second->type_id(); + } + return false; + }); + if (is_not_same) { + std::ostringstream oss; + oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" + << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + int idx = 0; + for (auto &item : arg_map) { + oss << ++idx << ": " << item.second->ToString() << "\n"; + } + MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); + } + } + + switch (id) { + case kObjectTypeList: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + case kObjectTypeTuple: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + case kObjectTypeClass: { + auto type = std::static_pointer_cast(pair.second); + return FullMake(type, func_graph, fn_arg, arg_map); + } + default: + return FullMake(pair.second, func_graph, fn_arg, arg_map); + } +} + +ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { + TypePtr type_tensor = std::make_shared(); + bool flag = std::any_of( + args_spec_list.begin(), args_spec_list.end(), + [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); + if (flag && broadcast_) { + ArgsPairList ret; + for (auto &item : args_spec_list) { + if (!IsSubType(item.second, type_tensor)) { + TypePtr type_tensor_ele = std::make_shared(item.second); + ret.push_back( + std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele)); + } else { + ret.push_back(std::make_pair(item.first, item.second)); + } + } + return ret; + } + return args_spec_list; +} + +FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { + FuncGraphPtr ptrGraph = std::make_shared(); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); + ptrGraph->debug_info()->set_name("hyper_map"); + + AnfNodePtr ptrFnArg = nullptr; + std::size_t i = 0; + ArgsPairList argmap; + ArgsPairList argmap2; + if (fn_leaf_ == nullptr) { + ptrFnArg = ptrGraph->add_parameter(); + i = 1; + } + + std::size_t size = args_spec_list.size(); + for (; i < size; ++i) { + argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); + } + + argmap2 = Harmonize(ptrGraph, argmap); + ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); + return ptrGraph; +} + +abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + if (fn_leaf_ == nullptr) { + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + // Assert that hypermap's function param does not contain free variables + if (args_spec_list[0]->isa()) { + auto graph_func = dyn_cast(args_spec_list[0]); + auto func_graph = graph_func->func_graph(); + if (func_graph->parent() != nullptr) { + MS_LOG(EXCEPTION) << "HyperMap don't support Closure with free variable yet."; + } + } + } + + AbstractBasePtrList broadened; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + return broadened; +} + +REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { + (void)py::class_>(*m, "HyperMap_") 
+ .def(py::init>(), py::arg("leaf")) + .def(py::init<>()); + })); + +FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { + MS_EXCEPTION_IF_NULL(a_tuple); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("tail"); + AnfNodePtr ptrTup = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + + int tuple_size = SizeToInt(a_tuple->size()); + for (int i = 1; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { + MS_EXCEPTION_IF_NULL(a_list); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("tail"); + AnfNodePtr ptrList = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeList)); + + int list_size = SizeToInt(a_list->size()); + for (int i = 1; i < list_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; + } + + AbstractBasePtr a = args_spec_list[0]; + abstract::AbstractTuplePtr a_tuple = dyn_cast(a); + if (a_tuple != nullptr) { + return GenerateTupleFuncGraph(a_tuple); + } + + abstract::AbstractListPtr a_list = dyn_cast(a); + if (a_list != nullptr) { + return GenerateListFuncGraph(a_list); + } + + MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); +} + +REGISTER_PYBIND_DEFINE( + Tail_, ([](const py::module *m) { + (void)py::class_>(*m, "Tail_").def(py::init()); + })); + +FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + int tuple_size = SizeToInt(args_spec_list.size()); + + std::ostringstream ss; + ss << "▶make_tuple_" << tuple_size; + FuncGraphPtr fg = std::make_shared(); + fg->debug_info()->set_name(ss.str()); + + std::vector params; + params.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (int i = 0; i < tuple_size; ++i) { + params.push_back(fg->add_parameter()); + } + + // make fprob first result, maketuple's forward result. + AnfNodePtr out = fg->NewCNode(params); + + // make fprob second result, maketuple's backward function. 
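Spelling out the contract the two comments above describe (the pair is usually called fprop: the forward value plus a bprop closure): for make_tuple the backward function simply routes each slot of the incoming gradient dout back to the matching input. A hand-written scalar sketch, with std::vector standing in for tuples (hypothetical, not part of this patch):

#include <functional>
#include <utility>
#include <vector>

using Vec = std::vector<double>;

std::pair<Vec, std::function<Vec(const Vec &)>> MakeTupleFprop(const Vec &inputs) {
  Vec out = inputs;  // forward: make_tuple simply packs its inputs
  auto bprop = [n = inputs.size()](const Vec &dout) {
    // backward: the gradient of input i is dout[i], one slot per input
    return Vec(dout.begin(), dout.begin() + static_cast<Vec::difference_type>(n));
  };
  return {out, bprop};
}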
+ FuncGraphPtr b = std::make_shared(); + + ss.clear(); + ss << "◀make_tuple_" << tuple_size; + b->debug_info()->set_name(ss.str()); + AnfNodePtr dout = b->add_parameter(); + + std::vector grads; + grads.push_back(NewValueNode(prim::kPrimMakeTuple)); + grads.push_back(NewValueNode(newenv)); + for (int i = 0; i < tuple_size; ++i) { + grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); + } + + b->set_flag(FUNC_GRAPH_FLAG_CORE, true); + b->set_output(b->NewCNode(grads)); + + fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); + (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); + return fg; +} + +GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) + : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { + if (get_by_list) { + signatures_ = + // def grad(func:read, weight_list:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}}); + } +} + +FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, + const std::vector ¶ms_list, const std::vector &args, + bool applyJ) { + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + auto weights_node = weights; + if (weights == nullptr && !args.empty()) { + weights_node = ret->NewCNode(args); + } + + ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); + ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); + + std::vector inputs; + if (applyJ) { + inputs.push_back(opsJ); + inputs.push_back(node); + node = ret->NewCNode(inputs); + } + + std::vector params; + for (size_t i = 0; i < params_list.size(); ++i) { + params.push_back(ret->add_parameter()); + } + + inputs.clear(); + inputs.push_back(node); + (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); + AnfNodePtr cnode = ret->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(opsTupleItem); + inputs.push_back(cnode); + inputs.push_back(NewValueNode(0)); + auto out = ret->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(opsTupleItem); + inputs.push_back(cnode); + inputs.push_back(NewValueNode(1)); + AnfNodePtr ptrBprop = ret->NewCNode(inputs); + + doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); + return ret; +} + +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, + ValueNodePtr opsTupleItem) { + MS_EXCEPTION_IF_NULL(func_graph); + + AnfNodePtr ptrBPropArg = nullptr; + if (sens_param_) { + ptrBPropArg = func_graph->add_parameter(); + } else { + auto ones_like = prim::GetPythonOps("ones_like"); + ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); + } + + AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); + + CNodePtr fv_bprop = nullptr; + if (get_by_list_) { + // python code: grads = hyper_map(F.partial(env_get, env), weights) + AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); + AnfNodePtr partial_env_get = + func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); + MetaFuncGraphPtr hyper_map = std::make_shared(); + fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); + } + + CNodePtr inputs_bprop = 
nullptr; + if (get_all_) { + inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); + } + + // Gradients wrt inputs and parameters + if (fv_bprop != nullptr && inputs_bprop != nullptr) { + func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); + return; + } + + // Gradients wrt parameters + if (fv_bprop != nullptr) { + func_graph->set_output(fv_bprop); + return; + } + + // Gradients wrt inputs + if (inputs_bprop != nullptr) { + func_graph->set_output(inputs_bprop); + return; + } + + // Gradients wrt first input. + // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input + func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); +} + +// Generate the graph. +FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.size() < 1) { + MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " + << args_spec_list.size() << "."; + } + + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + AbstractFunctionPtr fn = dyn_cast(args_spec_list[0]); + if (fn == nullptr) { + MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); + } + + // Waiting for implementation. + auto real_fn = dyn_cast(fn); + MS_EXCEPTION_IF_NULL(real_fn); + + FuncGraphPtr ptrGraph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(ptrGraph); + TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); + FuncGraphPtr dfBuilder = std::make_shared(); + TraceManager::EndTrace(); + auto nparam = ptrGraph->parameters().size(); + + std::ostringstream ss; + ss << "grad{" << nparam << "}"; + dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); + dfBuilder->debug_info()->set_name(ss.str()); + ParameterPtr param_graph = dfBuilder->add_parameter(); + + AnfNodePtr weights = nullptr; + if (get_by_list_) { + weights = dfBuilder->add_parameter(); + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimJ)); + inputs.push_back(param_graph); + auto jf = dfBuilder->NewCNode(inputs); + // df is checked in GetGrad + TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); + auto df = GetGrad(jf, weights, ptrGraph->parameters()); + TraceManager::EndTrace(); + dfBuilder->set_output(NewValueNode(df)); + + return dfBuilder; +} + +REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { + (void)py::class_>( + *m, "GradOperation_") + .def(py::init(), py::arg("fn")) + .def(py::init(), py::arg("fn"), py::arg("get_all"), + py::arg("get_by_list"), py::arg("sens_param")); + })); + +// Generate the ListMap func graph. +FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + size_t args_num = args_spec_list.size(); + // args: fn, list1, list2, ... 
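As a plain-C++ picture of what the cond/body graphs constructed below compute: fn is applied once per position, consuming one element from every list. A sketch over int lists (ListMapInts is illustrative; the real code emits graph IR with iterators and a switch):

#include <functional>
#include <vector>

std::vector<int> ListMapInts(const std::function<int(const std::vector<int> &)> &fn,
                             const std::vector<std::vector<int>> &lists) {
  std::vector<int> result;
  if (lists.empty()) {
    return result;
  }
  for (size_t i = 0; i < lists[0].size(); ++i) {
    std::vector<int> row;  // the i-th element drawn from every input list
    for (const auto &l : lists) {
      row.push_back(l[i]);
    }
    result.push_back(fn(row));  // one application of fn per position
  }
  return result;
}

// With fn summing its row, ListMapInts(fn, {{1, 2}, {10, 20}}) returns {11, 22}.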
+ if (args_num < 2) { + MS_LOG(EXCEPTION) << "list_map takes at least two arguments"; + } + + for (size_t i = 1; i < args_num; ++i) { + if (typeid(args_spec_list[i]) != typeid(AbstractBase)) { + // This branch is currently not used + MS_LOG(EXCEPTION) << "list_map requires lists as arguments."; + } + } + + FuncGraphPtr fg_ptr = std::make_shared(); + fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fg_ptr->debug_info()->set_name("list_map"); + AnfNodePtr fn = fg_ptr->add_parameter(); + + std::vector lists; + for (size_t i = 1; i < args_num; ++i) { + lists.push_back(fg_ptr->add_parameter()); + } + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item}); + }); + + std::vector nexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + }); + + std::vector values; + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item}); + }); + + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); + }); + + (void)values.insert(values.begin(), fn); + AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); + AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph}); + + FuncGraphPtr fgnext_ptr = std::make_shared(); + fgnext_ptr->debug_info()->set_name("body"); + + FuncGraphPtr fgcond_ptr = std::make_shared(); + fgcond_ptr->debug_info()->set_name("cond"); + + MakeCond(lists, fgnext_ptr, fgcond_ptr); + MakeNext(lists, fgcond_ptr, fgnext_ptr); + + CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + + auto inputs = output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + output_cnode->set_inputs(inputs); + + fg_ptr->set_output(output_cnode); + return fg_ptr; +} + +void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, + const FuncGraphPtr &fg_ptr) { + MS_EXCEPTION_IF_NULL(fg_ptr); + + AnfNodePtr fn = fg_ptr->add_parameter(); + AnfNodePtr resl = fg_ptr->add_parameter(); + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), + [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); + + std::vector hasnexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item}); + }); + + // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) + FuncGraphPtr fgtrue_ptr = std::make_shared(); + fgtrue_ptr->debug_info()->set_name("ftrue"); + fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); + auto inputs = fgtrue_output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + fgtrue_output_cnode->set_inputs(inputs); + fgtrue_ptr->set_output(fgtrue_output_cnode); + + FuncGraphPtr fgfalse_ptr = std::make_shared(); + fgfalse_ptr->debug_info()->set_name("ffalse"); + fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); + fgfalse_ptr->set_output(resl); + + AnfNodePtr output_cnode = 
fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), + NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); + fgtrue_ptr->set_output(output_cnode); +} + +void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, + const FuncGraphPtr &fg_ptr) { + MS_EXCEPTION_IF_NULL(fg_ptr); + AnfNodePtr fn = fg_ptr->add_parameter(); + + std::vector iters; + (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), + [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); + + std::vector nexts; + (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + }); + + std::vector values; + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); + }); + + iters.clear(); + (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { + return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); + }); + + (void)values.insert(values.begin(), fn); + AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); + AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph}); + CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + + auto inputs = output_cnode->inputs(); + (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); + output_cnode->set_inputs(inputs); + fg_ptr->set_output(output_cnode); +} + +FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // args: tuple1, tuple2 + abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); + AbstractBasePtr abs_a = args_spec_list[0]; + AbstractBasePtr abs_b = args_spec_list[1]; + + abstract::AbstractTuplePtr a_tuple = dyn_cast(abs_a); + abstract::AbstractTuplePtr b_tuple = dyn_cast(abs_b); + if (a_tuple == nullptr || b_tuple == nullptr) { + MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple, but " << args_spec_list[0]->ToString() << ", " + << args_spec_list[1]->ToString(); + } + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr p_tup_a = ret->add_parameter(); + AnfNodePtr p_tup_b = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + + int tuple_size = SizeToInt(a_tuple->size()); + for (int i = 0; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); + } + + tuple_size = SizeToInt(b_tuple->size()); + for (int i = 0; i < tuple_size; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { + MS_EXCEPTION_IF_NULL(scalar); + return GetValue(scalar->BuildValue()); +} + +bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); } + +int GetPositiveIndex(int index, int length) { + if (index < 0) { + index += length; + } + return index; +} + +int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { + MS_EXCEPTION_IF_NULL(member); + + if (member->isa()) { + return GetArgScalarValue(dyn_cast(member), member_name); + } + + if (member->isa()) { + return 
default_value; + } + + MS_LOG(EXCEPTION) << member_name << " should be an AbstractScalar or AbstractNone, but got " << member->ToString(); +} + +void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, + int *stop_index, int *step_value) { + MS_EXCEPTION_IF_NULL(tuple); + MS_EXCEPTION_IF_NULL(slice); + MS_EXCEPTION_IF_NULL(start_index); + MS_EXCEPTION_IF_NULL(stop_index); + MS_EXCEPTION_IF_NULL(step_value); + + const std::string start_name("Slice start index"); + const std::string stop_name("Slice stop index"); + const std::string step_name("Slice step value"); + + int tuple_size = SizeToInt(tuple->size()); + int start_default = 0; + int stop_default = tuple_size; + int step_default = 1; + + *step_value = CheckSliceMember(slice->step(), step_default, step_name); + if (*step_value == 0) { + MS_LOG(EXCEPTION) << "TupleSlice requires the step value to be non-zero, but got 0."; + } + + if (*step_value < 0) { + start_default = tuple_size - 1; + stop_default = -1; + } + + *start_index = CheckSliceMember(slice->start(), start_default, start_name); + *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name); + if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) || + !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) { + MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end index " << *stop_index + << " out of range, tuple size " << tuple_size << "."; + } + + *start_index = GetPositiveIndex(*start_index, tuple_size); + if (!slice->stop()->isa()) { + *stop_index = GetPositiveIndex(*stop_index, tuple_size); + } +} + +FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // slice a tuple + // args: tuple, start index, end index, step + const std::string op_name("TupleSlice"); + abstract::CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr tuple = abstract::CheckArg(op_name, args_spec_list, 0); + AbstractSlicePtr slice = abstract::CheckArg(op_name, args_spec_list, 1); + + int start_index; + int stop_index; + int step_value; + GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr p_tuple = ret->add_parameter(); + (void)ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeTuple)); + if (step_value > 0) { + for (int index = start_index; index < stop_index; index = index + step_value) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + } + } else { + for (int index = start_index; index > stop_index; index = index + step_value) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + } + } + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // select indexed item + // args: tuple of items, index + const std::string op_name = std::string("TupleGetItemTensor"); + abstract::CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr branches_abs = abstract::CheckArg(op_name, args_spec_list, 0); + AbstractBasePtrList branches = branches_abs->elements(); + if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa()) { + FuncGraphPtr ret_graph = std::make_shared(); + ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr functions = 
ret_graph->add_parameter(); + auto index = ret_graph->add_parameter(); + + ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); + return ret_graph; + } + + MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; +} + +REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleAdd_") + .def(py::init()); + })); + +REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleSlice_") + .def(py::init()); + })); + +REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { + (void)py::class_>( + *m, "TupleGetItemTensor_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h new file mode 100644 index 0000000000..3821192dba --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/composite.h @@ -0,0 +1,192 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "frontend/operator/composite/zip_operation.h" +#include "frontend/operator/composite/list_append_operation.h" +#include "frontend/operator/composite/do_signature.h" +#include "frontend/operator/composite/unpack_call.h" +#include "frontend/operator/composite/multitype_funcgraph.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractSlicePtr = abstract::AbstractSlicePtr; +using AbstractScalarPtr = abstract::AbstractScalarPtr; +using AbstractTensorPtr = abstract::AbstractTensorPtr; +using ElemwiseMap = std::unordered_map; +using ArgsPairList = std::vector>; + +class HyperMap : public MetaFuncGraph { + public: + explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + HyperMap(const HyperMap &h); + void Init(); + HyperMap &operator=(const HyperMap &h) { + if (this != &h) { + fn_leaf_ = h.fn_leaf_; + broadcast_ = h.broadcast_; + nonleaf_ = h.nonleaf_; + if (fn_leaf_) { + name_ = "hyper_map[" + fn_leaf_->name() + "]"; + } + } + return *this; + } + ~HyperMap() override = default; + MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) + + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; + MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } + + private: + AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr 
&fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); + ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); + + MultitypeFuncGraphPtr fn_leaf_; + bool broadcast_; + std::set nonleaf_; +}; +using HyperMapPtr = std::shared_ptr; + +class HyperMapPy : public HyperMap { + public: + explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} + ~HyperMapPy() override = default; + MS_DECLARE_PARENT(HyperMapPy, HyperMap) +}; +using HyperMapPyPtr = std::shared_ptr; + +extern ValuePtr kCompositeHyperMap; + +class Tail : public MetaFuncGraph { + public: + explicit Tail(const std::string &name) : MetaFuncGraph(name) {} + ~Tail() override = default; + MS_DECLARE_PARENT(Tail, MetaFuncGraph) + + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); + FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); + + friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } +}; +using TailPtr = std::shared_ptr; + +class MakeTupleGradient : public MetaFuncGraph { + public: + explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} + ~MakeTupleGradient() override = default; + MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } +}; +using MakeTupleGradientPtr = std::shared_ptr; + +class GradOperation : public MetaFuncGraph { + public: + explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, + bool sens_param = false); + ~GradOperation() override = default; + MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) + + FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, + const std::vector &args = {}, bool applyJ = false); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + bool sens_param() const { return sens_param_; } + bool get_all_; + bool get_by_list_; + bool sens_param_; + + private: + void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, + ValueNodePtr opsTupleItem); +}; +using GradOperationPtr = std::shared_ptr; + +class ListMap { + public: + explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } + ~ListMap() = default; + void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); + void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); + + private: + std::string name_; + std::map, FuncGraphPtr> cache_; +}; + +class TupleAdd : public MetaFuncGraph { + public: + explicit TupleAdd(const std::string &name) : 
MetaFuncGraph(name) {} + ~TupleAdd() override = default; + MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } +}; +using TupleAddPtr = std::shared_ptr; + +class TupleSlice : public MetaFuncGraph { + public: + explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} + ~TupleSlice() override = default; + MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } +}; +using TupleSlicePtr = std::shared_ptr; + +class TupleGetItemTensor : public MetaFuncGraph { + public: + explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} + ~TupleGetItemTensor() override = default; + MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { + return lhs.name_ == rhs.name_; + } +}; +using TupleGetItemTensorPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc new file mode 100644 index 0000000000..50be3c5b29 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -0,0 +1,338 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/operator/composite/do_signature.h" +#include +#include + +#include "abstract/abstract_value.h" +#include "ir/anf.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "./common.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +const std::map type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, + {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, + {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; +namespace { +const std::vector &GetSignature(const ValuePtr &function) { + static const auto empty = std::vector(); + if (function->isa() && function->cast()->has_signature()) { + return function->cast()->signatures(); + } else if (function->isa()) { + return function->cast()->signatures(); + } + return empty; +} + +void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, + const std::vector &signature, bool has_var, std::vector *const op_inputs) { + std::size_t sig_size = signature.size(); + auto positional_size = sig_size; + if (has_var) { + positional_size = sig_size - 1; + } + if (args_spec_list.size() < positional_size) { + for (size_t i = args_spec_list.size(); i < sig_size; ++i) { + auto default_value = signature[i].default_value; + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; + } else { + (*op_inputs).push_back(NewValueNode(default_value)); + } + } + } +} + +void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { + *max_type_id = type_id; + *max_type_number = type_number; +} + +bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, + TypeId *arg_type = nullptr) { + if (arg_value->isa()) { + if (is_write) { + arg_value = arg_value->cast()->ref_origin(); + } else { + arg_value = arg_value->cast()->ref(); + } + } + if (arg_value->isa()) { + auto tensor = arg_value->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + *arg_type_id = tensor_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeTensorType; + } + return true; + } + if (arg_value->isa()) { + auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + *arg_type_id = scalar_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeNumber; + } + return true; + } + return false; +} + +TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, + const std::set &write_indices) { + TypeId max_type_id = kTypeUnknown; + size_t max_type_number = 0; + bool has_int8 = false; + bool has_scalar_int32 = false; + bool has_scalar_float32 = false; + for (const auto &index : indices) { + TypeId arg_type_id = kTypeUnknown; + TypeId arg_type = kTypeUnknown; + auto is_write = (write_indices.find(index) != write_indices.end()); + if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { + continue; + } + if (arg_type != kObjectTypeTensorType) { + if (arg_type_id == kNumberTypeInt32) { + has_scalar_int32 = true; + } else if (arg_type_id == kNumberTypeFloat32) { + has_scalar_float32 = true; + } + continue; + } + 
auto it = type_map.find(arg_type_id); + if (it == type_map.end()) { + continue; + } + if (arg_type_id == kNumberTypeInt8) { + has_int8 = true; + } + if (max_type_id == kTypeUnknown) { + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); + continue; + } + if (it->second > max_type_number) { + SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); + } + } + + if (max_type_id == kNumberTypeUInt8 && has_int8 == true) { + max_type_id = kNumberTypeInt16; + } + // if bool is the max type, see if there is scalar input + // if so, it means that max is bool tensor, use scalar type instead. + // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) + if (max_type_id == kNumberTypeBool) { + if (has_scalar_int32) { + max_type_id = kNumberTypeInt32; + } + if (has_scalar_float32) { + max_type_id = kNumberTypeFloat32; + } + } + return max_type_id; +} + +// Get the largest type of index in the same SignatureEnumDType of arguments. +std::map GetMaxDtype(const std::vector &dtypes, + const abstract::AbstractBasePtrList &args_spec_list, + const std::set &write_indices) { + // record index for signature.dtypes of the same type + // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} + std::map> type_indices; + for (size_t i = 0; i < dtypes.size(); ++i) { + auto it = type_indices.find(dtypes[i]); + if (it == type_indices.end()) { + (void)type_indices.insert(std::make_pair(dtypes[i], std::vector{i})); + } else { + it->second.push_back(i); + } + } + std::map dst_type; + for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { + auto type = it->first; + auto indices = it->second; + // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. + if (indices.size() < 2) { + continue; + } + bool has_tensor = false; + for (const auto &index : indices) { + AbstractBasePtr arg_value = args_spec_list[index]; + if (arg_value->isa()) { + arg_value = arg_value->cast()->ref(); + } + if (arg_value->isa()) { + has_tensor = true; + break; + } + } + if (!has_tensor) { + (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); + continue; + } + (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); + } + return dst_type; +} + +AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGraphPtr &graph) { + auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations"); + MS_EXCEPTION_IF_NULL(prim_cast_class); + auto dtype_node = NewValueNode(TypeIdToType(type_id)); + auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph); + return NewCNode({cast_node, param, dtype_node}, graph); +} + +void DoAutoCast(const std::string &func_name, const std::vector &signature, + const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, + std::vector *const op_inputs, const std::set &write_indices) { + std::vector dtypes; + (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), + [](const Signature &sig) { return sig.dtype; }); + int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); + if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { + return; + } + // Stat the index of the arguments with the largest type in the same SignatureEnumDType. 
+  std::map<SignatureEnumDType, TypeId> dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices);
+  // Identify which args require an automatic cast.
+  for (size_t i = 0; i < args_spec_list.size(); ++i) {
+    auto it = dst_type.find(dtypes[i]);
+    if (it == dst_type.end() || it->second == kTypeUnknown) {
+      continue;
+    }
+    auto rw_it = write_indices.find(i);
+    auto is_write = (rw_it != write_indices.end());
+
+    TypeId arg_type_id = kTypeUnknown;
+    AbstractBasePtr arg_value = args_spec_list[i];
+    (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id);
+    auto it_map = type_name_map.find(arg_type_id);
+    if (it_map == type_name_map.end()) {
+      continue;
+    }
+    if (is_write) {
+      if (arg_type_id != it->second) {
+        auto it_name_map = type_name_map.find(it->second);
+        if (it_name_map == type_name_map.end()) {
+          continue;
+        }
+        RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second);
+      }
+      continue;
+    }
+    if (arg_value->isa<abstract::AbstractTensor>() && arg_type_id == it->second) {
+      continue;
+    }
+    (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph);
+  }
+}
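+
+// A minimal, self-contained sketch of the promotion rule implemented by GetMaxTypeId above,
+// kept out of the build (illustration only; `PickMaxTypeId` is a hypothetical name that does
+// not exist elsewhere in the code base). The ranks come from `type_map` at the top of this
+// file; the two corrections mirror the special cases handled there: Int8 mixed with UInt8
+// widens to Int16, and a bool tensor mixed with an int32/float32 scalar takes the scalar type.
+#if 0
+static TypeId PickMaxTypeId(const std::vector<TypeId> &tensor_dtypes, bool has_scalar_int32,
+                            bool has_scalar_float32) {
+  TypeId max_type_id = kTypeUnknown;
+  size_t max_rank = 0;
+  bool has_int8 = false;
+  for (auto id : tensor_dtypes) {
+    auto it = type_map.find(id);
+    if (it == type_map.end()) {
+      continue;  // dtypes outside the promotion table never win
+    }
+    has_int8 = has_int8 || (id == kNumberTypeInt8);
+    if (max_type_id == kTypeUnknown || it->second > max_rank) {
+      max_type_id = id;
+      max_rank = it->second;
+    }
+  }
+  if (max_type_id == kNumberTypeUInt8 && has_int8) {
+    max_type_id = kNumberTypeInt16;  // e.g. {Int8 tensor, UInt8 tensor} -> Int16
+  }
+  if (max_type_id == kNumberTypeBool && has_scalar_int32) {
+    max_type_id = kNumberTypeInt32;  // e.g. Tensor([True, True]) * 2 -> Tensor([2, 2])
+  }
+  if (max_type_id == kNumberTypeBool && has_scalar_float32) {
+    max_type_id = kNumberTypeFloat32;
+  }
+  return max_type_id;
+}
+#endif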
+
+AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
+                         const AbstractBasePtrList &args_spec_list, const std::vector<AnfNodePtr> &params_list) {
+  // args: original inputs
+  auto &signature = GetSignature(function);
+  std::size_t sig_size = signature.size();
+  auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional);
+  if (sig_size > 0) {
+    if (has_var) {
+      if (sig_size - 1 > args_spec_list.size()) {
+        MS_LOG(EXCEPTION) << "Function " << func_name
+                          << "'s input length is less than the PositionalKeyword Signature length.";
+      }
+    } else if (args_spec_list.size() > sig_size) {
+      MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to the Signature length.";
+    }
+  }
+  std::vector<AnfNodePtr> op_inputs;
+  std::set<size_t> write_indices;
+  op_inputs.push_back(NewValueNode(function));
+  // Assume the writable input of an op is always its first input. If any input is writable,
+  // add cast ops on the other inputs to keep their types consistent with the assigned parameter.
+  for (size_t i = 0; i < args_spec_list.size(); ++i) {
+    AnfNodePtr param = params_list[i];
+    if (args_spec_list[i] == nullptr) {
+      op_inputs.push_back(param);
+      continue;
+    }
+    SignatureEnumRW sig = SignatureEnumRW::kRWDefault;
+    // If sig_size is 0, use the default.
+    if (sig_size > 0 && i < sig_size) {
+      sig = signature[i].rw;
+    } else if (has_var && i >= sig_size) {
+      sig = signature[sig_size - 1].rw;
+    }
+
+    TypePtr type = args_spec_list[i]->GetTypeTrack();
+    if (type && type->type_id() == kObjectTypeRef) {
+      if (sig == SignatureEnumRW::kRWRead) {
+        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param});
+      } else if (sig == SignatureEnumRW::kRWWrite) {
+        param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param});
+        write_indices.insert(i);
+      }
+      // If sig is SignatureEnumRW::kRWRef, do nothing.
+    } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
+      MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
+    }
+    op_inputs.push_back(param);
+  }
+  // Process default arguments.
+  ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs);
+  DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices);
+  return func_graph->NewCNode(op_inputs);
+}
+}  // namespace
+
+AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function,
+                         const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) {
+  auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs);
+  return new_cnode;
+}
+
+FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
+
+  for (size_t i = 0; i < args_spec_list.size(); ++i) {
+    (void)func_graph->add_parameter();
+  }
+  auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters());
+  func_graph->set_output(new_cnode);
+  func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  return func_graph;
+}
+
+void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type,
+                                      const std::string &target_type) {
+  MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n"
+                    << "the type of the writable argument is '" << ref_type << "', "
+                    << "but the largest type in the same SignatureEnumDType is '" << target_type
+                    << "'. Since the writable argument's type differs from the largest type, "
+                    << "it cannot be cast automatically.";
+}
+}  // namespace prim
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.h b/mindspore/ccsrc/frontend/operator/composite/do_signature.h
new file mode 100644
index 0000000000..9139be806a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +class DoSignatureMetaFuncGraph : public MetaFuncGraph { + public: + explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) + : MetaFuncGraph("S-" + name), function_(function) {} + + ~DoSignatureMetaFuncGraph() override = default; + + MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) + + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; + const ValuePtr function() const { return function_; } + + friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { + return &lhs == &rhs; + } + + private: + ValuePtr function_; +}; +using RWSignaturePtr = std::shared_ptr; + +extern const std::map type_map; + +void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type, + const std::string &target_type); + +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc new file mode 100644 index 0000000000..3dfe2e23d0 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/operator/composite/list_append_operation.h" + +#include +#include +#include + +#include "abstract/param_validator.h" +#include "frontend/optimizer/opt.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { + abstract::CheckArgsSize("ListAppend", args_list, 2); + + AbstractBasePtr arg0 = args_list[0]; + abstract::AbstractListPtr arg0_list = dyn_cast(arg0); + MS_EXCEPTION_IF_NULL(arg0_list); + + FuncGraphPtr ret = std::make_shared(); + ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); + ret->debug_info()->set_name("append"); + AnfNodePtr arg0_node = ret->add_parameter(); + + std::vector elems; + elems.push_back(NewValueNode(prim::kPrimMakeList)); + size_t arg0_length = arg0_list->size(); + for (size_t i = 0; i < arg0_length; ++i) { + elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToInt(i))})); + } + AnfNodePtr arg1_node = ret->add_parameter(); + elems.push_back(arg1_node); + + ret->set_output(ret->NewCNode(elems)); + return ret; +} + +REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { + (void)py::class_>(*m, "ListAppend_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.h b/mindspore/ccsrc/frontend/operator/composite/list_append_operation.h similarity index 100% rename from mindspore/ccsrc/operator/composite/list_append_operation.h rename to mindspore/ccsrc/frontend/operator/composite/list_append_operation.h diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc new file mode 100644 index 0000000000..a5f674187b --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -0,0 +1,292 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/operator/composite/map.h"
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "abstract/abstract_value.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+#include "abstract/dshape.h"
+#include "pybind_api/api_register.h"
+#include "debug/trace.h"
+#include "frontend/operator/ops.h"
+#include "./common.h"
+
+namespace mindspore {
+// namespace to support composite operators definition
+namespace prim {
+using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
+
+AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) {
+  MS_LOG(DEBUG) << "Map FullMakeLeaf non-recursive.\n";
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> inputs;
+  if (fn_arg != nullptr) {
+    inputs.emplace_back(fn_arg);
+  } else {
+    inputs.emplace_back(NewValueNode(fn_leaf_));
+  }
+  inputs.insert(inputs.end(), args.begin(), args.end());
+  return func_graph->NewCNode(inputs);
+}
+
+FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
+  // Generate func for leaf nodes
+  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
+  ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
+  ptrGraph->debug_info()->set_name("map");
+  AnfNodePtr ptrFnArg = nullptr;
+  if (fn_leaf_ == nullptr) {
+    ptrFnArg = ptrGraph->add_parameter();
+  }
+  AnfNodePtrList args;
+  for (size_t i = 0; i < args_size; ++i) {
+    args.emplace_back(ptrGraph->add_parameter());
+  }
+  ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args));
+  return ptrGraph;
+}
+
+AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
+                             const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(type);
+
+  std::size_t size = type->elements().size();
+  bool is_not_same =
+    std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair<AnfNodePtr, TypePtr> &item) {
+      auto lhs = std::dynamic_pointer_cast<List>(item.second);
+      MS_EXCEPTION_IF_NULL(lhs);
+      return lhs->elements().size() != size;
+    });
+  if (is_not_same) {
+    MS_LOG(EXCEPTION) << "Lists in Map must have the same length.";
+  }
+
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(NewValueNode(prim::kPrimMakeList));
+
+  for (int i = 0; i < SizeToInt(size); ++i) {
+    MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target";
+    auto ptrGraph = GenerateLeafFunc(arg_pairs.size());
+    auto fn = NewValueNode(ptrGraph);
+
+    std::vector<AnfNodePtr> inputs2;
+    inputs2.push_back(fn);
+    if (fn_arg != nullptr) {
+      inputs2.push_back(fn_arg);
+    }
+
+    (void)std::transform(
+      arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2),
+      [&func_graph, i](const std::pair<AnfNodePtr, TypePtr> &item) {
+        return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)});
+      });
+
+    inputs.push_back(func_graph->NewCNode(inputs2));
+  }
+  return func_graph->NewCNode(inputs);
+}
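+
+// A minimal sketch of the expansion FullMakeList performs, kept out of the build (illustration
+// only; `ExpandMapOverLists` is a hypothetical name). For map(f, l1, l2) over lists of length 2,
+// the generated graph is equivalent to make_list(f(l1[0], l2[0]), f(l1[1], l2[1])); the sketch
+// renders that expansion as text using plain std::string "nodes".
+#if 0
+static std::string ExpandMapOverLists(const std::string &fn, const std::vector<std::string> &lists,
+                                      size_t length) {
+  std::string out = "make_list(";
+  for (size_t i = 0; i < length; ++i) {
+    out += fn + "(";  // one leaf call per element index
+    for (size_t j = 0; j < lists.size(); ++j) {
+      out += lists[j] + "[" + std::to_string(i) + "]" + (j + 1 < lists.size() ? ", " : "");
+    }
+    out += std::string(")") + (i + 1 < length ? ", " : "");
+  }
+  return out + ")";
+  // ExpandMapOverLists("f", {"l1", "l2"}, 2)
+  //   == "make_list(f(l1[0], l2[0]), f(l1[1], l2[1]))"
+}
+#endif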
length"; + } + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + for (int i = 0; i < SizeToInt(size); ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + (void)std::transform( + arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), + [&func_graph, &i](std::pair item) { + return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + }); + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + MS_EXCEPTION_IF_NULL(type); + MS_EXCEPTION_IF_NULL(func_graph); + + std::vector inputs; + inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); + inputs.push_back(NewValueNode(type)); + + std::size_t attrSize = type->GetAttributes().size(); + for (std::size_t i = 0; i < attrSize; ++i) { + MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; + auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); + auto fn = NewValueNode(ptrGraph); + + std::vector inputs2; + inputs2.push_back(fn); + if (fn_arg != nullptr) { + inputs2.push_back(fn_arg); + } + + int j = 0; + for (auto item : arg_pairs) { + inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + j++; + } + + inputs.push_back(func_graph->NewCNode(inputs2)); + } + return func_graph->NewCNode(inputs); +} + +AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { + if (arg_pairs.empty()) { + MS_EXCEPTION(TypeError) << "map() must have at least two arguments"; + } + bool found = false; + TypeId id = kObjectTypeEnd; + std::pair pair; + for (auto &item : arg_pairs) { + pair = item; + MS_LOG(DEBUG) << "Map " << pair.second->ToString(); + id = item.second->type_id(); + if (nonleaf_.count(id)) { + found = true; + break; + } + } + + if (found) { + // In a nonleaf situation, all arguments must have the same generic. 
+    bool is_not_same =
+      std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) {
+        if (item.first != pair.first) {
+          return item.second->type_id() != pair.second->type_id();
+        }
+        return false;
+      });
+    if (is_not_same) {
+      std::ostringstream oss;
+      oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n"
+          << trace::GetDebugInfo(func_graph->debug_info()) << "\n";
+      int idx = 0;
+      for (auto &item : arg_pairs) {
+        oss << ++idx << ": " << item.second->ToString() << "\n";
+      }
+      MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n"
+                        << oss.str() << pair.second->ToString() << "\n";
+    }
+  }
+
+  switch (id) {
+    case kObjectTypeList: {
+      auto type = std::static_pointer_cast<List>(pair.second);
+      return FullMakeList(type, func_graph, fn_arg, arg_pairs);
+    }
+    case kObjectTypeTuple: {
+      auto type = std::static_pointer_cast<Tuple>(pair.second);
+      return FullMakeTuple(type, func_graph, fn_arg, arg_pairs);
+    }
+    case kObjectTypeClass: {
+      auto type = std::static_pointer_cast<Class>(pair.second);
+      return FullMakeClass(type, func_graph, fn_arg, arg_pairs);
+    }
+    default:
+      MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class, but got "
+                        << pair.second->ToString();
+  }
+}
+
+FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) {
+  FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
+  ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
+  ptrGraph->debug_info()->set_name("map");
+
+  AnfNodePtr ptrFnArg = nullptr;
+  std::size_t i = 0;
+  if (fn_leaf_ == nullptr) {
+    ptrFnArg = ptrGraph->add_parameter();
+    i = 1;
+  }
+  ArgsPairList arg_pairs;
+  std::size_t size = args_spec_list.size();
+  for (; i < size; ++i) {
+    MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString();
+    arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
+  }
+
+  ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs));
+  return ptrGraph;
+}
+
+abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
+  if (fn_leaf_ == nullptr) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+    // Assert that map's function param does not contain free variables
+    if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) {
+      auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]);
+      auto func_graph = graph_func->func_graph();
+      if (func_graph->parent() != nullptr) {
+        MS_LOG(EXCEPTION) << "Map does not support closures with free variables yet.";
+      }
+    }
+  }
+
+  AbstractBasePtrList broadened;
+  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened),
+                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
+                         MS_EXCEPTION_IF_NULL(arg);
+                         return arg->Broaden();
+                       });
+  return broadened;
+}
+
+REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) {
+                         (void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
+                           .def(py::init<std::shared_ptr<MultitypeFuncGraph>>(), py::arg("leaf"))
+                           .def(py::init<>());
+                       }));
+}  // namespace prim
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/composite/map.h b/mindspore/ccsrc/frontend/operator/composite/map.h
new file mode 100644
index 0000000000..428014f9c4
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/composite/map.h
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ + +#include +#include +#include +#include + +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "frontend/operator/composite/multitype_funcgraph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using ArgsPairList = std::vector>; + +class Map : public MetaFuncGraph { + public: + explicit Map(const std::shared_ptr &fn_leaf = nullptr) + : MetaFuncGraph("map"), + fn_leaf_(fn_leaf), + broadcast_(false), + nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { + Init(); + } + Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { + Init(); + } + Map &operator=(const Map &h) { + if (this != &h) { + fn_leaf_ = h.fn_leaf_; + broadcast_ = h.broadcast_; + nonleaf_ = h.nonleaf_; + if (fn_leaf_) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + } + return *this; + } + ~Map() override = default; + MS_DECLARE_PARENT(Map, MetaFuncGraph) + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; + MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } + + private: + FuncGraphPtr GenerateLeafFunc(const size_t &args_size); + AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); + AnfNodePtr FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_pairs); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); + void Init() { + if (fn_leaf_ != nullptr) { + name_ = "map[" + fn_leaf_->name() + "]"; + } + signatures_ = + // def map(func:read, *args:ref): + std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); + } + + MultitypeFuncGraphPtr fn_leaf_; + bool broadcast_; + std::set nonleaf_; +}; +using MapPtr = std::shared_ptr; +class MapPy : public Map { + public: + explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} + ~MapPy() override = default; + MS_DECLARE_PARENT(MapPy, Map) +}; +using MapPyPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc new file mode 100644 index 0000000000..ba0d3d9ebb --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -0,0 +1,198 @@ + +/** + * This is the C++ 
adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/composite/multitype_funcgraph.h" +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "abstract/abstract_value.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "frontend/optimizer/opt.h" +#include "utils/context/ms_context.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" +#include "./common.h" +#include "ir/signature.h" +#include "debug/trace.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { + fn_cache_.clear(); + signatures_ = std::vector({// def multitype(*args:ref): + {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); +} + +void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { + MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; + auto fn = fn_cache_.find(types); + if (fn != fn_cache_.end()) { + MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; + } + fn_cache_[types] = s_fn; +} + +void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { + MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; + auto fn = fn_cache_.find(types); + if (fn != fn_cache_.end()) { + MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; + } + fn_cache_py_[types] = py_fn; +} + +void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { + TypePtrList types; + for (auto &type_name : types_name) { + auto type_ptr = StringToType(type_name); + if (type_ptr == nullptr) { + MS_LOG(EXCEPTION) << type_name << " convert from string error "; + } + types.push_back(type_ptr); + } + Register(types, py_fn); +} + +void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { + std::vector types_name; + for (size_t it = 0; it < tuple.size(); ++it) { + py::object name_py = tuple[it]; + if (py::isinstance(name_py)) { + types_name.push_back(name_py.cast()); + continue; + } + MS_LOG(EXCEPTION) << "Register must be string"; + } + Register(types_name, py_fn); +} +static TypePtr UnwrapRef(const TypePtr &type) { + if (type->isa()) { + return type->cast()->subtype(); + } + return type; +} + +// Return Exact match if exists, else return non ambiguous sub class match +// Return py::none() if matching is ambiguous +const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { + // Exact match + for (auto &item : fn_cache_py_) { + TypePtrList 
sign = item.first; + if (sign.size() != types.size()) { + continue; + } + auto match = true; + for (size_t i = 0; i < sign.size(); ++i) { + if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { + match = false; + break; + } + } + if (!match) { + continue; + } + return item.second; + } + return py::none(); +} + +FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + std::vector parameters; + ParameterPtr undetermined_param = nullptr; + auto stub = std::make_shared(); + for (size_t i = 0; i < types.size(); ++i) { + auto param = stub->add_parameter(); + parameters.push_back(param); + if (types[i]->type_id() == kObjectTypeUndeterminedType) { + undetermined_param = param; + } + } + if (undetermined_param != nullptr) { + std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i]->type_id() == kObjectTypeFunction) { + std::vector call_prim{parameters[i], undetermined_param}; + inputs.push_back(stub->NewCNode(call_prim)); + } else { + inputs.push_back(parameters[i]); + } + } + auto stub_output = stub->NewCNode(inputs); + stub->set_output(stub_output); + stub->set_stub(true); + return stub; + } + return nullptr; +} + +FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { + auto py_fn = SignMatch(types); + std::ostringstream buffer; + buffer << types; + if (py_fn != py::none()) { + FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); + } + MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); + return func_graph; + } + auto stub = GenerateStubFunc(types); + if (stub != nullptr) { + MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString(); + return stub; + } + std::ostringstream oss; + oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ + << "`, corresponding location info:\n"; + int idx = 0; + for (auto &item : fn_cache_py_) { + FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); + if (func_graph == nullptr) { + MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; + continue; + } + oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; + } + MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" + << oss.str(); +} + +REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { + (void)py::class_>( + *m, "MultitypeFuncGraph_") + .def(py::init()) + .def("register_fn", &MultitypeFuncGraph::PyRegister); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h new file mode 100644 index 0000000000..2139a0e9d1 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h @@ -0,0 +1,65 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +class MultitypeFuncGraph : public MetaFuncGraph { + public: + explicit MultitypeFuncGraph(const std::string &name); + ~MultitypeFuncGraph() override = default; + MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) + + using specialize_fn = FuncGraph *(*)(TypePtrList); + // Register a method which specialize based on types vectors; + virtual void Register(const TypePtrList &types, specialize_fn s_fn); + virtual void Register(const TypePtrList &types, const py::function &py_fn); + virtual void Register(const std::vector &types_name, const py::function &py_fn); + virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); + + FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; + size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } + const std::unordered_map &GetPyFunctions() const { + return fn_cache_py_; + } + + private: + const py::function SignMatch(const TypePtrList &types); + std::unordered_map fn_cache_; + std::unordered_map fn_cache_py_; +}; +using MultitypeFuncGraphPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc new file mode 100644 index 0000000000..2c9e0b538f --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/operator/composite/unpack_call.h" +#include +#include + +#include "./common.h" +#include "abstract/abstract_value.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" +#include "frontend/operator/cc_implementations.h" +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractKeywordArg; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { + // slice a tensor + // args: tensor, slice or slice tuple + const std::string op_name = std::string("UnpackCall"); + size_t arg_length = args_spec_list.size(); + if (arg_length < 2) { + MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; + } + + (void)abstract::CheckArg(op_name, args_spec_list, 0); + auto ret_graph = std::make_shared(); + ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); + + AnfNodePtr fnNode = ret_graph->add_parameter(); + std::vector elems; + elems.push_back(fnNode); + for (size_t index = 1; index < arg_length; index++) { + MS_EXCEPTION_IF_NULL(args_spec_list[index]); + if (args_spec_list[index]->isa()) { + auto arg_tuple = args_spec_list[index]->cast(); + AnfNodePtr para_tuple = ret_graph->add_parameter(); + for (size_t i = 0; i < arg_tuple->size(); ++i) { + elems.push_back( + ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); + } + } else if (args_spec_list[index]->isa()) { + AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast(); + AnfNodePtr para_dict = ret_graph->add_parameter(); + auto dict_elems = arg_dict->elements(); + (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), + [ret_graph, para_dict](const AbstractAttribute &item) { + auto dict_get_item = ret_graph->NewCNode( + {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); + return ret_graph->NewCNode( + {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); + }); + } else { + MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " + << args_spec_list[index]->ToString(); + } + } + ret_graph->set_output(ret_graph->NewCNode(elems)); + return ret_graph; +} + +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { + (void)py::class_>(*m, "UnpackCall_") + .def(py::init()); + })); +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.h b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h new file mode 100644 index 0000000000..79c2600f36 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +// Expand the tuple and dict parameters generated when parsing the function call, +// and generate positional parameters and key-value pairs for function. +class UnpackCall : public MetaFuncGraph { + public: + explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} + ~UnpackCall() override = default; + MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } +}; +using UnpackCallPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc b/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc new file mode 100644 index 0000000000..9e2b6d28b2 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/composite/zip_operation.cc @@ -0,0 +1,92 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/operator/composite/zip_operation.h"
+#include <algorithm>
+
+#include "abstract/abstract_value.h"
+#include "ir/anf.h"
+#include "abstract/dshape.h"
+#include "frontend/operator/cc_implementations.h"
+#include "frontend/optimizer/opt.h"
+#include "pybind_api/api_register.h"
+
+namespace mindspore {
+// namespace to support composite operators definition
+namespace prim {
+using mindspore::abstract::AbstractBase;
+using mindspore::abstract::AbstractList;
+using mindspore::abstract::AbstractSequeue;
+using mindspore::abstract::AbstractSequeuePtr;
+using mindspore::abstract::AbstractTuple;
+
+FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  // zip operation:
+  //   input: tuple arguments
+  //   output: tuple of items of input iterated on every input
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "For 'zip', there must be at least one input.";
+  }
+
+  auto is_all_sequeue =
+    std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool {
+      MS_EXCEPTION_IF_NULL(abs);
+      return abs->isa<AbstractSequeue>();
+    });
+  if (!is_all_sequeue) {
+    MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequences.";
+  }
+
+  auto min_abs = std::min_element(
+    args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) {
+      return (x->cast<AbstractSequeuePtr>()->size() < y->cast<AbstractSequeuePtr>()->size());
+    });
+  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
+  ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  for (size_t idx = 0; idx < args_spec_list.size(); idx++) {
+    (void)ret_graph->add_parameter();
+  }
+
+  // Generate the tuple output of the zipped inputs: the i-th element is a tuple of the
+  // i-th items of all inputs, truncated to the length of the shortest input.
+  std::vector<AnfNodePtr> make_tuple_nodes;
+  make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
+  for (size_t idx = 0; idx < (*min_abs)->cast<AbstractSequeuePtr>()->size(); idx++) {
+    std::vector<AnfNodePtr> make_tuple_zip_nodes;
+    make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
+    std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
+    ValuePtr op = prim::GetPythonOps("getitem", module_name);
+    for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) {
+      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx],
+                                                   NewValueNode(SizeToInt(idx))};
+      auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes);
+      make_tuple_zip_nodes.push_back(tuple_get_item_op);
+    }
+    auto make_tuple_zip_op = ret_graph->NewCNode(make_tuple_zip_nodes);
+    make_tuple_nodes.push_back(make_tuple_zip_op);
+  }
+  ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes));
+  return ret_graph;
+}
+
+REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
+                         (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(
+                           *m, "ZipOperation_")
+                           .def(py::init<std::string &>());
+                       }));
+}  // namespace prim
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/composite/zip_operation.h b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h
new file mode 100644
index 0000000000..96697cb472
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/composite/zip_operation.h
@@ -0,0 +1,59 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using AbstractBasePtr = abstract::AbstractBasePtr; +using AbstractBasePtrList = abstract::AbstractBasePtrList; +using AbstractTuplePtr = abstract::AbstractTuplePtr; + +class ZipOperation : public MetaFuncGraph { + public: + explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} + ~ZipOperation() override = default; + MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { + os << op.name_; + return os; + } + friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } +}; +using ZipOperationPtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ diff --git a/mindspore/ccsrc/frontend/operator/ops.cc b/mindspore/ccsrc/frontend/operator/ops.cc new file mode 100755 index 0000000000..5c7672ee3c --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/ops.cc @@ -0,0 +1,288 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/operator/ops.h" +#include +#include + +namespace mindspore { +// namespace to support primitive operators +namespace prim { +// Arithmetic +const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); +const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); +const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); +const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); +const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); +const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); +const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); +const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); +const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); +const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); +const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); +const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); +const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); +const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); +const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); +const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); + +// Comparisons +const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); +const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); +const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); +const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); +const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); +const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); +const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); +const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); +const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); +const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); +const PrimitivePtr kPrimGreater = std::make_shared("Greater"); +const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); +const PrimitivePtr kPrimLess = std::make_shared("Less"); +const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); +const PrimitivePtr kPrimEqual = std::make_shared("Equal"); +const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); + +// Type introspection +const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); +const PrimitivePtr kPrimHasType = std::make_shared("hastype"); + +// Statements +const PrimitivePtr kPrimSwitch = std::make_shared("switch"); +const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); +const PrimitivePtr kPrimReturn = std::make_shared("return"); +const PrimitivePtr kPrimAssign = std::make_shared("Assign"); +const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); +const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); +const PrimitivePtr kPrimSelect = std::make_shared("Select"); +const PrimitivePtr kPrimCall = std::make_shared("call"); + +const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); +const PrimitivePtr kPrimDot = std::make_shared("dot"); +const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); +const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); +const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); +const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); + +const PrimitivePtr kPrimResolve = std::make_shared("resolve"); +const PrimitivePtr kPrimEmbed = std::make_shared("embed"); 
+const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); +const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); + +const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); +const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); +const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); + +// Structure +const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); +const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); +const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); +const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); +const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); +const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); +const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); +const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); +const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); +const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); +const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); +const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); +const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); +const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); +const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); +const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); +const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); +const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); +const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); +const PrimitivePtr kPrimListLen = std::make_shared("list_len"); +const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); +const PrimitivePtr kPrimListMap = std::make_shared("list_map"); +const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); +const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); + +const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); +const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); +const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); +const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); +const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); +const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); +const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); +const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); +const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); +const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); +const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); + +// Arrays +const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); +const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); +const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); +const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); +const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); +const PrimitivePtr kPrimShape = std::make_shared("Shape"); +const PrimitivePtr kPrimCast = 
std::make_shared("Cast"); +const PrimitivePtr kPrimConcat = std::make_shared("Concat"); +const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); +const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); +const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); +const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); +const PrimitivePtr kPrimSize = std::make_shared("Size"); +const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); +const PrimitivePtr kPrimPack = std::make_shared("Pack"); +const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); +const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); +const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); +const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); +const PrimitivePtr kPrimTile = std::make_shared("Tile"); +const PrimitivePtr kPrimAddN = std::make_shared("AddN"); +const PrimitivePtr KPrimTransData = std::make_shared("TransData"); +const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); +const PrimitivePtr kPrimPad = std::make_shared("Pad"); +const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); + +// Maths +const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); +const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); +const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); +const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); +const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); +const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); +const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); +const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); +const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); +const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); +const PrimitivePtr kPrimNeg = std::make_shared("Neg"); +const PrimitivePtr kPrimSub = std::make_shared("Sub"); +const PrimitivePtr kPrimMul = std::make_shared("Mul"); +const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); +const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); +const PrimitivePtr kPrimSquare = std::make_shared("Square"); +const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); +const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); +const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); +const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); +const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); +const PrimitivePtr kPrimPow = std::make_shared("Pow"); +const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); +const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); +const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); +const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); + +// NN +const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); +const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); +const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); +const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); +const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); +const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); +const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); +const PrimitivePtr 
+
+// NN
+const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");
+const PrimitivePtr kPrimSoftmax = std::make_shared<Primitive>("Softmax");
+const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
+const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
+const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
+const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
+const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
+const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
+const PrimitivePtr kPrimMaxPool = std::make_shared<Primitive>("MaxPool");
+const PrimitivePtr kPrimMaxPoolGrad = std::make_shared<Primitive>("MaxPoolGrad");
+const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
+const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
+const PrimitivePtr kPrimFusedBatchNorm = std::make_shared<Primitive>("FusedBatchNorm");
+const PrimitivePtr kPrimConv2D = std::make_shared<Primitive>("Conv2D");
+const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared<Primitive>("FusedBatchNormGrad");
+const PrimitivePtr kPrimBatchNorm = std::make_shared<Primitive>("BatchNorm");
+const PrimitivePtr kPrimBatchNormGrad = std::make_shared<Primitive>("BatchNormGrad");
+const PrimitivePtr kPrimReluGrad = std::make_shared<Primitive>("ReluGrad");
+const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared<Primitive>("Conv2DBackpropInput");
+const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared<Primitive>("Conv2DBackpropFilter");
+const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared<Primitive>("DepthwiseConv2dNative");
+const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter =
+  std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropFilter");
+const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput =
+  std::make_shared<Primitive>("DepthwiseConv2dNativeBackpropInput");
+const PrimitivePtr kPrimBiasAddGrad = std::make_shared<Primitive>("BiasAddGrad");
+const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
+const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits =
+  std::make_shared<Primitive>("SparseSoftmaxCrossEntropyWithLogits");
+const PrimitivePtr kPrimMomentum = std::make_shared<Primitive>("Momentum");
+const PrimitivePtr kPrimApplyMomentum = std::make_shared<Primitive>("ApplyMomentum");
+const PrimitivePtr kPrimLayerNorm = std::make_shared<Primitive>("LayerNorm");
+const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
+const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
+const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
+const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
+const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
+const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
+const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
+const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
+const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
+const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
+const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
+const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
+const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
+const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
+const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
+const PrimitivePtr kPrimApplyRMSProp = std::make_shared<Primitive>("ApplyRMSProp");
+
+// Other miscellaneous
+const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
+const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
+const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
+const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
+const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
+const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
+const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
+const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
+const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
+const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
+const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
+const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
+const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
+const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
+const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
+const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
+
+const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
+const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
+const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
+
+const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
+const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
+const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
+const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
+const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
+const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
+const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
+const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
+const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
+
+// Comm ops
+const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
+const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
+const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
+const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
+
+// Debug ops
+const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
+const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
+const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
+const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
+const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
+
+// IndexedSlices
+const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
+const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
+const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
+const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
+const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
+}  // namespace prim
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h
similarity index 100%
rename from mindspore/ccsrc/operator/ops.h
rename to mindspore/ccsrc/frontend/operator/ops.h
diff --git a/mindspore/ccsrc/frontend/operator/ops_extends.cc b/mindspore/ccsrc/frontend/operator/ops_extends.cc
new file mode 100755
index 0000000000..c406682c3e
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/ops_extends.cc
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/operator/ops.h"
+#include <memory>
+#include <string>
+#include "pipeline/jit/parse/python_adapter.h"
+#include "pipeline/jit/parse/data_converter.h"
+
+namespace mindspore {
+// namespace to support primitive operators
+namespace prim {
+ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
+  py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
+  ValuePtr node = nullptr;
+  bool succ = parse::ConvertData(obj, &node, use_signature);
+  if (!succ) {
+    MS_LOG(EXCEPTION) << "Get Python op " << op_name << " from " << module_name << " failed.";
+  }
+  return node;
+}
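+
+// Usage sketch (editor's note; the module path below is hypothetical, not
+// taken from this patch): GetPythonOps resolves a Python callable and
+// converts it to a graph Value, raising via MS_LOG(EXCEPTION) on failure:
+//   ValuePtr op = GetPythonOps("some_op", "mindspore.ops.functional", false);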
+}  // namespace prim
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_arrays.cc b/mindspore/ccsrc/frontend/operator/prim_arrays.cc
new file mode 100644
index 0000000000..caaf1d1b2a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_arrays.cc
@@ -0,0 +1,170 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
+#include "abstract/utils.h"
+#include "frontend/operator/cc_implementations.h"
+#include "abstract/param_validator.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a scalar.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
+  return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
+}
+
+AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor with 0 shape.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto a_shp = arg->shape();
+  if (!a_shp->shape().empty()) {
+    MS_LOG(EXCEPTION) << "array_to_scalar requires a zero-size shape.";
+  }
+  return arg->element();
+}
+
+AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tuples.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto xs = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  auto ys = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+
+  auto value_tuple_x = xs->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple_x);
+  auto shp_tuple_x = value_tuple_x->value();
+  std::vector<int> shp_x;
+  (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
+                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
+
+  auto value_tuple_y = ys->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple_y);
+  auto shp_tuple_y = value_tuple_y->value();
+  std::vector<int> shp_y;
+  (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
+                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
+
+  std::vector<int> res = prim::BroadcastShape_(shp_x, shp_y);
+  if (res.empty()) {
+    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
+                      << args_spec_list[1]->ToString();
+  }
+
+  AbstractBasePtrList elems;
+  (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int n) -> AbstractBasePtr {
+    return std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(n), kInt32);
+  });
+
+  return std::make_shared<AbstractTuple>(elems);
+}
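+
+// Worked example (editor's note): BroadcastShape_((8, 1, 6, 1), (7, 1, 5))
+// right-aligns the shapes and takes the larger extent wherever one side is 1,
+// yielding (8, 7, 6, 5); incompatible pairs return the empty vector handled
+// by the EXCEPTION branch above.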
+
+AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString();
+
+  AbstractBasePtrList values;
+  auto shp = arg->shape();
+  for (int entry : shp->shape()) {
+    auto entry_v = MakeValue(entry);
+    values.push_back(std::make_shared<AbstractScalar>(entry_v, entry_v->type()));
+  }
+  return std::make_shared<AbstractTuple>(values);
+}
+
+AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto multiples = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+
+  ShapePtr input_shape = arg->shape();
+  (void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s");
+
+  auto mul_shp_value = multiples->BuildValue();
+  if (mul_shp_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "The shape's data field can't be anything: " << args_spec_list[1]->ToString();
+  }
+
+  std::vector<int> mul_shp;
+  auto value_tuple_mul = mul_shp_value->cast<ValueTuplePtr>();
+  auto mul_shp_data = value_tuple_mul->value();
+  (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp),
+                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
+  if (input_shape->shape().size() != mul_shp_data.size()) {
+    MS_LOG(EXCEPTION) << "Tile requires the input and multiples sizes to be equal, while the input size is "
+                      << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << ".";
+  }
+
+  std::vector<int> result_shp;
+  for (size_t i = 0; i < mul_shp_data.size(); ++i) {
+    result_shp.push_back(input_shape->shape()[i] * mul_shp[i]);
+  }
+  return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
+}
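+
+// Worked example (editor's note): for an input of shape (2, 3) and multiples
+// (2, 2), the loop above computes result_shp = (2*2, 3*2) = (4, 6).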
+
+AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tuple of tensors.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  if (arg->elements().empty()) {
+    MS_LOG(EXCEPTION) << "Arg elements is empty.";
+  }
+
+  size_t tuple_len = arg->elements().size();
+  AbstractTensorPtr tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
+  int rank_base = SizeToInt(tensor_base->shape()->shape().size());
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  // Axis value should be in [-(rank_base + 1), rank_base).
+  int axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
+  // If axis is negative, add offset(rank_base + 1) to turn it to positive.
+  axis_value = GetPositiveAxis(axis_value, IntToSize(rank_base + 1));
+
+  for (size_t i = 1; i < tuple_len; ++i) {
+    AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
+    (void)CheckDtypeSame(op_name, tensor_base, tensor);
+    (void)CheckShapeSame(op_name, tensor_base, tensor);
+  }
+
+  primitive->set_attr("N", MakeValue(SizeToInt(tuple_len)));
+  primitive->set_attr("T", tensor_base->element()->BuildType());
+
+  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
+  MS_EXCEPTION_IF_NULL(ret);
+  auto shape = ret->shape()->shape();
+  (void)shape.insert(shape.begin() + axis_value, tuple_len);
+  ret->set_shape(std::make_shared<Shape>(shape));
+  return ret;
+}
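+
+// Worked example (editor's note): packing a 2-tuple of (3, 4) tensors with
+// axis = 1 inserts the tuple length at that position, giving shape (3, 2, 4).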
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_debug.cc b/mindspore/ccsrc/frontend/operator/prim_debug.cc
new file mode 100644
index 0000000000..718dadf5c1
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_debug.cc
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
+#include "abstract/utils.h"
+#include "utils/symbolic.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor(value)
+  const std::string op_name = primitive->name();
+
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto tensor_value = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+
+  int tensor_rank = SizeToInt(tensor_value->shape()->shape().size());
+  if (tensor_rank == 0) {
+    MS_LOG(EXCEPTION) << op_name << " summary evaluator's arg should be a tensor, but got a scalar, rank is 0";
+  }
+
+  return std::make_shared<AbstractTuple>(AbstractBasePtrList({tensor_value->Broaden()}));
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_maths.cc b/mindspore/ccsrc/frontend/operator/prim_maths.cc
new file mode 100644
index 0000000000..e4543a3821
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_maths.cc
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
+#include "abstract/utils.h"
+#include "abstract/param_validator.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three tensors.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+  (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat},
+                              op_name + " evaluator's three inputs should be %s");
+
+  AbstractBasePtr dx = input_x->Broaden();
+  AbstractBasePtr dy = input_y->Broaden();
+
+  return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
+}
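+
+// Editor's note: Min/Max grad routes dout to whichever input supplied the
+// extremum at runtime, so the inferred outputs are just broadened copies of
+// the two input abstracts (shape/dtype kept, concrete values widened).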
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_nn.cc b/mindspore/ccsrc/frontend/operator/prim_nn.cc
new file mode 100644
index 0000000000..96c86d815d
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_nn.cc
@@ -0,0 +1,432 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
+#include "abstract/utils.h"
+#include "abstract/param_validator.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTensorPtr input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s");
+
+  ShapePtr input_shape = dyn_cast<Shape>(input_tensor->GetShapeTrack());  // NCHW
+  MS_EXCEPTION_IF_NULL(input_shape);
+  if (input_shape->shape().size() != 4) {
+    MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor.";
+  }
+  int h_input = input_shape->shape()[2];
+  int w_input = input_shape->shape()[3];
+
+  int window = primitive->GetAttr("window")->cast<Int32ImmPtr>()->value();
+  int stride = primitive->GetAttr("stride")->cast<Int32ImmPtr>()->value();
+  int padding = primitive->GetAttr("pad")->cast<Int32ImmPtr>()->value();
+  int nan_opt = primitive->GetAttr("nan_opt")->cast<Int32ImmPtr>()->value();
+  int data_mode = primitive->GetAttr("data_mode")->cast<Int32ImmPtr>()->value();
+  int ceil_mode = primitive->GetAttr("ceil_mode")->cast<Int32ImmPtr>()->value();
+
+  if (stride <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should be greater than 0";
+  }
+  if (nan_opt != 0) {
+    MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0";
+  }
+  if (data_mode != 1) {
+    MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1";
+  }
+  if (ceil_mode != 0) {
+    MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0";
+  }
+
+  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
+  auto pad_mode_ptr = primitive->GetAttr("pad_mode");
+  if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa<StringImm>()) {
+    auto pad_mode = pad_mode_ptr->cast<StringImmPtr>()->value();
+    if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
+      MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". Use pad, same or valid";
+    }
+    if (pad_mode == "valid") {
+      padding = 0;
+    } else if (pad_mode == "same") {
+      padding = (window - 1) / 2;
+    }
+  }
+
+  std::set<std::string> available_mode{"max", "avg"};
+  auto mode_ptr = primitive->GetAttr("mode");
+  if ((mode_ptr != nullptr) && mode_ptr->isa<StringImm>()) {
+    auto mode = mode_ptr->cast<StringImmPtr>()->value();
+    if (available_mode.find(mode) == available_mode.end()) {
+      MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << ".";
+    }
+  }
+
+  int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1;
+  int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1;
+  std::vector<int> shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out};
+  AbstractBasePtr ret = input_tensor->Broaden();
+  ret->set_shape(std::make_shared<Shape>(shape_out));
+  return ret;
+}
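+
+// Worked example (editor's note): h_input = 32, window = 2, padding = 0,
+// stride = 2 gives h_out = ((32 + 0 - 1 - 1) / 2) + 1 = 16; "same" mode first
+// sets padding = (window - 1) / 2 before this formula is applied.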
+
+AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three tensors(y, dy, x).
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto out_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto d_out = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+  (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat},
+                              op_name + " evaluator's three inputs should be %s");
+
+  AbstractBasePtr ret = d_out->Broaden();
+  auto x_shape = dyn_cast<Shape>(args_spec_list[2]->GetShapeTrack());
+  MS_EXCEPTION_IF_NULL(x_shape);
+
+  ret->set_shape(x_shape);
+  return ret;
+}
+
+void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
+  // Check dimensions: x's rank > 1, the others' rank equals 1.
+  const std::string op_name = primitive->name();
+  for (std::size_t i = 0; i < args_spec_list.size(); ++i) {
+    AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
+    ShapePtr arg_shape = dyn_cast<Shape>(arg->GetShapeTrack());
+    if (arg_shape == nullptr) {
+      MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString();
+    }
+
+    if (i == 0) {
+      if (arg_shape->shape().size() < 2) {
+        MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
+                          << "] should be TensorShape with dimension greater than 1, but shape: "
+                          << arg_shape->ToString();
+      }
+      continue;
+    }
+
+    if (arg_shape->shape().size() != 1) {
+      MS_LOG(EXCEPTION) << op_name << " shape of args[" << i
+                        << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString();
+    }
+  }
+}
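+
+// Editor's note: for FusedBatchNorm the check above enforces an x of rank > 1
+// (e.g. NCHW) and 1-D gamma/beta/mean/variance whose length must later match
+// x's channel dimension C (verified in InferImplFusedBatchNorm below).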
+
+AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list) {
+  // Inputs: five tensors(x, gamma, beta, mean, variance).
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 5);
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString()
+                << ", arg1:" << args_spec_list[1]->ToString();
+  FusedBatchNormCheckDim(primitive, args_spec_list);
+
+  auto input = args_spec_list[0];
+  auto input_shape = dyn_cast<Shape>(input->GetShapeTrack());
+  MS_EXCEPTION_IF_NULL(input_shape);
+  const auto &input_shape_list = input_shape->shape();
+  if (input_shape_list.size() < 2) {
+    MS_LOG(EXCEPTION) << "Input shape size should >= 2.";
+  }
+
+  for (size_t i = 1; i < args_spec_list.size(); ++i) {
+    auto arg_shape = dyn_cast<Shape>(args_spec_list[i]->GetShapeTrack());
+    MS_EXCEPTION_IF_NULL(arg_shape);
+    const auto &arg_shape_list = arg_shape->shape();
+    if (arg_shape_list.size() < 1) {
+      MS_LOG(EXCEPTION) << "Arg shape size should >= 1.";
+    }
+    if (arg_shape_list[0] != input_shape_list[1]) {
+      MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0]
+                        << ") should match the second dimension of tensor"
+                           " param[0](which is "
+                        << input_shape_list[1] << ").";
+    }
+  }
+  auto input_tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s");
+
+  AbstractTensorPtrList tensorPtrList = std::vector<AbstractTensorPtr>();
+  for (size_t i = 1; i < args_spec_list.size(); ++i) {
+    auto param = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
+    tensorPtrList.push_back(param);
+  }
+  (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s");
+
+  // Check validity of the attributes.
+  auto epsilon_value = primitive->GetAttr("epsilon");
+  auto momentum_value = primitive->GetAttr("momentum");
+  MS_EXCEPTION_IF_NULL(epsilon_value);
+  MS_EXCEPTION_IF_NULL(momentum_value);
+  if (!epsilon_value->isa<FP32Imm>() || !momentum_value->isa<FP32Imm>()) {
+    MS_LOG(EXCEPTION) << "Expect epsilon and momentum to be float, but: epsilon: " << epsilon_value->ToString()
+                      << ", momentum: " << momentum_value->ToString();
+  }
+
+  auto epsilon = epsilon_value->cast<FP32ImmPtr>()->value();
+  auto momentum = momentum_value->cast<FP32ImmPtr>()->value();
+
+  if (epsilon > 1.0f || epsilon <= 0.0f) {
+    MS_LOG(EXCEPTION) << "Expect epsilon in (0, 1], but epsilon: " << epsilon;
+  }
+  if (momentum > 1.0f || momentum < 0.0f) {
+    MS_LOG(EXCEPTION) << "Expect momentum in [0, 1], but momentum: " << momentum;
+  }
+
+  // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance.
+  AbstractBasePtr y = input->Broaden();
+  AbstractBasePtr other = args_spec_list[1]->Broaden();
+  MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString();
+
+  AbstractBasePtrList elements = {y, other, other, other, other};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
+AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                            const AbstractBasePtrList &args_spec_list) {
+  // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
+  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
+  MS_EXCEPTION_IF_NULL(args_spec_list[2]);
+  MS_EXCEPTION_IF_NULL(args_spec_list[3]);
+
+  CheckArgsSize(primitive->name(), args_spec_list, 5);
+  auto dx = args_spec_list[1]->Broaden();
+  auto dscale = args_spec_list[2]->Broaden();
+  auto dbias = args_spec_list[3]->Broaden();
+
+  AbstractBasePtrList rets = {dx, dscale, dbias};
+  return std::make_shared<AbstractTuple>(rets);
+}
+
+AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors(y_backprop, x).
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+  return args_spec_list[1]->Broaden();
+}
+
+AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                             const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three tensors(doutput, input, filters).
+  CheckArgsSize(primitive->name(), args_spec_list, 3);
+  return args_spec_list[1]->Broaden();
+}
+
+AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three tensors(inputs, filter, doutput).
+  CheckArgsSize(primitive->name(), args_spec_list, 3);
+  return args_spec_list[2]->Broaden();
+}
+
+AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: at least one tensor(y_backprop)
+  // Outputs: dbias
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << primitive->name() << " evaluator requires at least 1 parameter, while the input size is "
+                      << args_spec_list.size() << ".";
+  }
+
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  ShapePtr shape_y = dyn_cast<Shape>(args_spec_list[0]->GetShapeTrack());
+  MS_EXCEPTION_IF_NULL(shape_y);
+  std::vector<int> y_dims = shape_y->shape();
+  if (y_dims.size() < 2) {
+    MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << ".";
+  }
+  std::vector<int> bias_dims = {y_dims[1]};
+  ShapePtr ret_shape = std::make_shared<Shape>(bias_dims);
+  AbstractBasePtr ret = args_spec_list[0]->Broaden();
+  ret->set_shape(ret_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  return args_spec_list[0]->Broaden();
+}
+
+AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  return args_spec_list[0]->Broaden();
+}
+
+AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tensor.
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  return args_spec_list[0]->Broaden();
+}
+
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: the op inputs followed by (out, dout); broaden all but the last two.
+  AbstractBasePtrList args_list;
+  for (size_t i = 0; i < args_spec_list.size() - 2; i++) {
+    args_list.push_back(args_spec_list[i]->Broaden());
+  }
+  return std::make_shared<AbstractTuple>(args_list);
+}
+
+AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three tensors(x, gamma, beta).
+  // Outputs: y, mean, variance.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto input_shape = input_x->shape();
+  auto const &input_shape_list = input_shape->shape();
+  const size_t input_rank = input_shape_list.size();
+  if (input_rank == 0) {
+    MS_LOG(EXCEPTION) << "input_rank should not be zero";
+  }
+
+  // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1
+  ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis");
+  int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1);
+
+  ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis");
+  int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1);
+  begin_params_axis = GetPositiveAxis(begin_params_axis, input_rank);
+
+  // The gamma and beta shape should be x_shape[begin_params_axis:].
+  auto tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto beta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
+  (void)CheckTensorDType(tensor, {kFloat16, kFloat32}, "input 0 of LayerNorm should be %s");
+  (void)CheckTensorDType(gamma, {kFloat16, kFloat32}, "input 1 of LayerNorm should be %s");
+  (void)CheckTensorDType(beta, {kFloat16, kFloat32}, "input 2 of LayerNorm should be %s");
+  auto gamma_shape = dyn_cast<Shape>(gamma->BuildShape());
+  auto beta_shape = dyn_cast<Shape>(beta->BuildShape());
+  MS_EXCEPTION_IF_NULL(gamma_shape);
+  MS_EXCEPTION_IF_NULL(beta_shape);
+
+  auto const &gamma_shape_list = gamma_shape->shape();
+  auto const &beta_shape_list = beta_shape->shape();
+  if (gamma_shape_list.empty() || beta_shape_list.empty()) {
+    MS_LOG(EXCEPTION) << "LayerNorm evaluator gamma or beta is an AbstractScalar, which is not supported.";
+  }
+
+  size_t begin_params_axis_u = IntToSize(begin_params_axis);
+  if ((begin_params_axis_u > input_shape_list.size()) ||
+      (gamma_shape_list.size() + begin_params_axis_u < input_shape_list.size()) ||
+      (beta_shape_list.size() + begin_params_axis_u < input_shape_list.size())) {
+    MS_LOG(EXCEPTION) << "Gamma and beta shape get wrong size.";
+  }
+  for (size_t i = begin_params_axis_u; i < input_shape_list.size(); ++i) {
+    size_t gamma_beta_shape_dim = i - begin_params_axis_u;
+    if ((gamma_shape_list[gamma_beta_shape_dim] != input_shape_list[i]) ||
+        (beta_shape_list[gamma_beta_shape_dim] != input_shape_list[i])) {
+      MS_LOG(EXCEPTION) << "Gamma or beta shape does not match input shape, input_shape=" << input_shape->ToString()
+                        << ", gamma_shape=" << gamma_shape->ToString() << ", beta_shape=" << beta_shape->ToString();
+    }
+  }
+
+  auto mean_var_shape_value = input_shape->shape();
+  if (begin_norm_axis == -1) {
+    mean_var_shape_value[input_rank - 1] = 1;
+  } else {
+    for (size_t i = begin_norm_axis; i < input_rank; ++i) {
+      mean_var_shape_value[i] = 1;
+    }
+  }
+
+  auto mean = input_x->Broaden();
+  mean->set_shape(std::make_shared<Shape>(mean_var_shape_value));
+  auto var = input_x->Broaden();
+  var->set_shape(std::make_shared<Shape>(mean_var_shape_value));
+
+  AbstractBasePtrList args_list({input_x->Broaden(), mean, var});
+  return std::make_shared<AbstractTuple>(args_list);
+}
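+
+// Worked example (editor's note): x of shape (2, 3, 4) with
+// begin_norm_axis = 1 normalizes over axes 1..2, so the mean/variance above
+// get shape (2, 1, 1) while y keeps (2, 3, 4).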
+
+AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  // Inputs: five tensors(y_backprop, x, variance, mean, gamma).
+  // Outputs: x_backprop, gamma_backprop, beta_backprop.
+  CheckArgsSize(primitive->name(), args_spec_list, 5);
+
+  auto x_backprop = args_spec_list[0]->Broaden();
+  auto gamma_backprop = args_spec_list[4]->Broaden();
+  auto beta_backprop = args_spec_list[4]->Broaden();
+
+  AbstractBasePtrList args_list({x_backprop, gamma_backprop, beta_backprop});
+  return std::make_shared<AbstractTuple>(args_list);
+}
+
+AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tuple and a tensor.
+  // Outputs: mask.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr x_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractTensorPtr keep_prob = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+
+  TypePtr prob_type = keep_prob->element()->BuildType();
+  if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) {
+    MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString()
+                      << ".";
+  }
+
+  auto x_shape_data = x_shape->elements();
+  int count = 1;
+  for (std::size_t i = 0; i < x_shape->size(); ++i) {
+    auto value_track = x_shape_data[i]->GetValueTrack();
+    MS_EXCEPTION_IF_NULL(value_track);
+    if (!value_track->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << ".";
+    }
+
+    int e_value = GetValue<int>(value_track);
+    if (e_value <= 0) {
+      MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0";
+    }
+    if (std::numeric_limits<int>::max() / count / e_value < 1) {
+      MS_LOG(EXCEPTION) << "Integer multiplication overflow";
+    }
+    count = count * e_value;
+  }
+
+  // Convert to a bytes(8 bits) mask, rounding up.
+  int n128s = count / 128;
+  if ((count % 128) != 0) {
+    n128s++;
+  }
+  int bytes_count = n128s * 16;
+  std::vector<int> shape_y{bytes_count};
+
+  primitive->set_attr("T", kInt32);
+  return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
+                                          std::make_shared<Shape>(std::vector<int>{shape_y}));
+}
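+
+// Worked example (editor's note): for x_shape (2, 3, 4), count = 24; a single
+// 128-bit block covers it, so n128s = 1 and the mask above is a uint8 tensor
+// of shape (16,).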
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc
new file mode 100644
index 0000000000..530ad6a10c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_others.cc
@@ -0,0 +1,410 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <sstream>
+
+#include "ir/dtype.h"
+#include "common/utils.h"
+#include "frontend/operator/ops.h"
+#include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "abstract/utils.h"
+#include "utils/context/ms_context.h"
+#include "utils/symbolic.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // An object of a subclass of AbstractBase
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  return args_spec_list[0];
+}
+
+AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const AbstractBasePtrList &args_spec_list) {
+  // args: An object of AbstractFunction.
+  CheckArgsSize(primitive->name(), args_spec_list, 1);
+  MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();
+
+  AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
+  if (x == nullptr) {
+    return std::make_shared<AbstractJTagged>(args_spec_list[0]);
+  }
+
+  AbstractFuncAtomPtrList jv;
+  auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
+    auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
+    jv.push_back(j_closure);
+  };
+  x->Visit(build_jv);
+
+  return AbstractFunction::MakeAbstractFunction(jv);
+}
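+
+// Editor's note: J marks a value for autodiff. Each atomic closure is wrapped
+// in a JTransformedAbstractClosure so that a later application roughly yields
+// the (result, bprop) pair the grad machinery consumes; non-function inputs
+// fall back to AbstractJTagged above.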
+
+AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  // args: Three objects of a subclass of AbstractBase, env, key, dflt(default).
+  CheckArgsSize(primitive->name(), args_spec_list, 3);
+  auto key = args_spec_list[1];
+  auto dflt = args_spec_list[2];
+  TypePtr type = key->GetTypeTrack();
+  MS_EXCEPTION_IF_NULL(type);
+  if (type->type_id() != kObjectTypeSymbolicKeyType) {
+    MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
+  }
+
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  if (enable_sparse && dflt->isa<AbstractTensor>()) {
+    auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
+    return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
+  }
+
+  if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
+    return dflt;
+  }
+  ValuePtr key_value_ptr = key->GetValueTrack();
+  MS_EXCEPTION_IF_NULL(key_value_ptr);
+  auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
+  auto expected = key_value_track->abstract();
+  MS_EXCEPTION_IF_NULL(expected);
+  (void)expected->Join(dflt);
+  return expected;
+}
+
+AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // args: Three objects of a subclass of AbstractBase, env, key, value.
+  CheckArgsSize(primitive->name(), args_spec_list, 3);
+
+  auto key = args_spec_list[1];
+  ValuePtr key_value_ptr = key->GetValueTrack();
+  MS_EXCEPTION_IF_NULL(key_value_ptr);
+  auto key_value_track = key_value_ptr->cast<SymbolicKeyInstancePtr>();
+  if (key_value_track == nullptr) {
+    MS_LOG(EXCEPTION) << "EnvSetItem evaluator args[1] should be castable to SymbolicKeyInstancePtr, but: "
+                      << key_value_ptr->ToString();
+  }
+  auto expected = key_value_track->abstract();
+  MS_EXCEPTION_IF_NULL(expected);
+  return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
+}
+
+AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  // args: Two objects of a subclass of AbstractBase, env1 and env2.
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+  return std::make_shared<AbstractScalar>(kAnyValue, std::make_shared<EnvType>());
+}
+
+AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) {
+  ValuePtr name_value = prim->GetAttr("tag");
+  auto name = name_value->cast<StringImmPtr>();
+  if (name == nullptr) {
+    MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << ".";
+  }
+  auto refkey = std::make_shared<RefKey>(name->value());
+  if (refkey == nullptr) {
+    MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared<RefKey> failed";
+  }
+  return refkey->ToAbstract();
+}
+
+AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                 const AbstractBasePtrList &args_spec_list) {
+  // arguments: key, value, original value
+  if (args_spec_list.size() != 3) {
+    MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size()
+                      << ".";
+  }
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  if (type->type_id() != kObjectTypeRefKey) {
+    MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString();
+  }
+  return std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+}
+
+AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // arguments: value
+  if (args_spec_list.size() != 1) {
+    MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameter, while the input size is " << args_spec_list.size() << ".";
+  }
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  if (type->type_id() != kObjectTypeRef) {
+    MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString();
+  }
+  return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
+}
+
+AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // arguments: value
+  if (args_spec_list.size() != 1) {
+    MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameter, while the input size is " << args_spec_list.size()
+                      << ".";
+  }
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  if (type->type_id() != kObjectTypeRef) {
+    MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString();
+  }
+  return args_spec_list[0]->cast<AbstractRefPtr>()->ref();
+}
+
+AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                      const AbstractBasePtrList &args_spec_list) {
+  // arguments: value
+  if (args_spec_list.size() != 1) {
+    MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameter, while the input size is " << args_spec_list.size()
+                      << ".";
+  }
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  if (type->type_id() != kObjectTypeRef) {
+    MS_LOG(EXCEPTION) << "First input of get_ref_origin should be a Ref but a " << type->ToString();
+  }
+  return args_spec_list[0]->cast<AbstractRefPtr>()->ref_origin();
+}
+
+AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list) {
+  // args: Two objects of a subclass of AbstractBase, key and value.
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  MS_EXCEPTION_IF_NULL(type);
+  if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) {
+    MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString();
+  }
+  return std::make_shared<AbstractScalar>(kAnyValue, kBool);
+}
+
+AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0";
+  }
+  auto depends = args_spec_list[0]->Broaden();
+  return depends;
+}
+
+bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
+  if (x_shape.size() != y_shape.size()) {
+    return false;
+  }
+
+  for (size_t i = 0; i < x_shape.size(); ++i) {
+    if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+enum State {
+  SAME,
+  X_ONE,
+  Y_ONE,
+};
+
+void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
+                        std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
+  const size_t n = reverse_x.size();
+  for (size_t i = 0; i < n; ++i) {
+    State curr;
+    const int32_t x_i = reverse_x[i];
+    const int32_t y_i = reverse_y[i];
+    const int reduce_idx = SizeToInt(n - 1 - i);
+    if (x_i == y_i) {
+      curr = SAME;
+    } else if (x_i == 1) {
+      grad_x_reduce_idx->push_back(reduce_idx);
+      curr = X_ONE;
+    } else if (y_i == 1) {
+      grad_y_reduce_idy->push_back(reduce_idx);
+      curr = Y_ONE;
+    } else {
+      MS_LOG(EXCEPTION) << "Incompatible shape input for BroadcastGradientArgs";
+    }
+    if (curr == SAME && x_i == 1) {
+      grad_x_reduce_idx->push_back(reduce_idx);
+      grad_y_reduce_idy->push_back(reduce_idx);
+      continue;
+    }
+  }
+
+  std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
+  std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
+}
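+
+// Worked example (editor's note): for x_shape (8, 1, 6, 1) and y_shape
+// (7, 1, 5), the reversed inputs are reverse_x = (1, 6, 1, 8) and
+// reverse_y = (5, 1, 7, 1) (y padded with 1s). The loop yields
+// grad_x_reduce_idx = (1, 3) and grad_y_reduce_idy = (0, 2): each side
+// reduces exactly the axes where it broadcast from size 1.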
+
+AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape,
+                                          const std::vector<ValuePtr> &y_shape) {
+  std::vector<int> reverse_x;
+  std::vector<int> reverse_y;
+
+  (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
+                       [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
+  (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
+                       [](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
+
+  if (reverse_x.size() > reverse_y.size()) {
+    reverse_y.resize(reverse_x.size(), 1);
+  } else {
+    reverse_x.resize(reverse_y.size(), 1);
+  }
+
+  std::vector<int> grad_x_reduce_idx;
+  std::vector<int> grad_y_reduce_idy;
+  ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
+
+  AbstractBasePtrList abs_list_x;
+  AbstractBasePtrList abs_list_y;
+  (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
+                       [](int v) { return abstract::FromValue(v); });
+  (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
+                       [](int v) { return abstract::FromValue(v); });
+  auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
+  auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
+  AbstractBasePtrList elem_list;
+  elem_list.push_back(x_reduce_idx);
+  elem_list.push_back(y_reduce_idx);
+
+  return std::make_shared<AbstractTuple>(elem_list);
+}
+
+AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list) {
+  // This primitive gets the indices that need to be reduced.
+  // Input: x's shape and y's shape; both inputs should be tuples.
+  // Output: a tuple of x's and y's reduce indices, each itself a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+
+  ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(arg_x_value);
+
+  ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(arg_y_value);
+
+  const std::vector<ValuePtr> x_shape = arg_x_value->value();
+  const std::vector<ValuePtr> y_shape = arg_y_value->value();
+  bool is_same_shape = CompareShape(x_shape, y_shape);
+  // If the shapes are the same, no reduction is needed; return empty tuples.
+  if (is_same_shape) {
+    AbstractBasePtrList empty_list;
+    auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
+    auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
+
+    AbstractBasePtrList elem_list;
+    elem_list.push_back(x_reduce_idx);
+    elem_list.push_back(y_reduce_idx);
+
+    return std::make_shared<AbstractTuple>(elem_list);
+  }
+
+  return BroadcastGradientArgsDiff(x_shape, y_shape);
+}
+
+AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list) {
+  // args: Two objects of a subclass of AbstractBase
+  CheckArgsSize(primitive->name(), args_spec_list, 2);
+  auto arg_src = args_spec_list[0];
+  auto arg_dst = args_spec_list[1];
+  // Control depend cannot set up a tuple-of-ops to tuple-of-ops dependency relation.
+  if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) {
+    auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
+    auto dst_size = arg_dst->cast<AbstractTuplePtr>()->size();
+    if (src_size > 1 && dst_size > 1) {
+      MS_LOG(EXCEPTION) << "Control depend cannot set up an operator dependency relationship from tuple to tuple";
+    }
+  }
+  return std::make_shared<AbstractScalar>(kAnyValue, kBool);
+}
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
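+
+// Worked example (editor's note): indices of shape (2,), values of shape
+// (2, 4) and dense_shape (6, 4) describe a sparse 6x4 tensor holding two
+// non-zero rows; the abstract above records the dtype plus that dense shape.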
+
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: an IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->values());
+  return indexed_slices->values();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: an IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->indices());
+  return indexed_slices->indices();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: an IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
+  return indexed_slices->dense_shape();
+}
+
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  bool ret = false;
+  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
+    ret = true;
+  }
+  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
+  return std::make_shared<AbstractScalar>(ret);
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/prim_statement.cc b/mindspore/ccsrc/frontend/operator/prim_statement.cc
new file mode 100644
index 0000000000..bb421bdf8a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/operator/prim_statement.cc
@@ -0,0 +1,249 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "abstract/param_validator.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
+#include "abstract/utils.h"
+#include "utils/symbolic.h"
+
+namespace mindspore {
+namespace abstract {
+AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a pointer to an AbstractBase object
+  if (args_spec_list.size() != 1) {
+    MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? "
+                    "While the input size is "
+                 << args_spec_list.size() << ".";
+  }
+  AbstractBasePtr abs_base = args_spec_list[0];
+  return abs_base;
+}
" + "while the input size is " + << args_spec_list.size() << "."; + } + AbstractBasePtr abs_base = args_spec_list[0]; + return abs_base; +} + +AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() + << "."; + } + AbstractBasePtr abs_base = args_spec_list[0]; + MS_EXCEPTION_IF_NULL(abs_base); + TypePtr type = abs_base->BuildType(); + return std::make_shared(type); +} + +AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a pointer to an AbstractBase object and a pointer to a Type + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); + + auto mode_v = abs_type->GetValueTrack(); + MS_EXCEPTION_IF_NULL(mode_v); + if (!mode_v->isa()) { + MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; + } + + TypePtr mode_t = mode_v->cast(); + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + bool v = IsSubtype(args_spec_list[0], mode_t); + return std::make_shared(std::make_shared(v), kBool); +} + +AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); + AbstractTensorPtr input_y = CheckArg(op_name, args_spec_list, 1); + + ShapePtr x_shp = input_x->shape(); + auto x_shp_value = x_shp->shape(); + ShapePtr y_shp = input_y->shape(); + auto y_shp_value = y_shp->shape(); + // Should be matrix which shape size is 2. + if (x_shp_value.size() != 2 || y_shp_value.size() != 2) { + MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are " + << x_shp_value.size() << ", " << y_shp_value.size() << " "; + } + if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) { + MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}"; + } + + auto x_element = input_x->element(); + MS_EXCEPTION_IF_NULL(x_element); + (void)x_element->Join(input_y->element()); + auto param = {x_shp_value[0], y_shp_value[1]}; + + return std::make_shared(input_x->element(), std::make_shared(param)); +} + +AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim, + const AbstractBasePtrList &args_spec_list) { + // Inputs: condition, true branch, false branch + if (args_spec_list.size() != 3) { + MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size() + << "."; + } + + auto cond = args_spec_list[0]; + auto tb = args_spec_list[1]; + auto fb = args_spec_list[2]; + MS_EXCEPTION_IF_NULL(cond); + + auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); + if (unroll_flag != nullptr && GetValue(unroll_flag) == 0) { + return tb->Join(fb); + } + + ValuePtr v = cond->GetValueTrack(); + MS_EXCEPTION_IF_NULL(v); + // for tensor as condition, keeps both true and false branch. 
+  if (v->isa<AnyValue>() || cond->isa<AbstractTensor>()) {
+    MS_EXCEPTION_IF_NULL(tb);
+    return tb->Join(fb);
+  }
+
+  if (v->isa<Scalar>()) {
+    if (v->cast<ScalarPtr>()->IsOne()) {
+      return tb;
+    } else {
+      return fb;
+    }
+  }
+
+  MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
+}
+
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: index, branch
+  const std::string op_name = primitive->name();
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+  AbstractBasePtrList branches = branches_abs->elements();
+  const size_t maximum_layer_num = 1000;
+  // branches.size() is unsigned, so check for emptiness instead of "< 0".
+  if (branches.empty() || branches.size() > maximum_layer_num) {
+    MS_EXCEPTION(ValueError) << op_name << " supports at least 1 and at most " << maximum_layer_num << " but got "
+                             << branches.size() << " branches.";
+  }
+
+  for (size_t i = 0; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    if (!branches[i]->isa<AbstractFunction>()) {
+      MS_LOG(EXCEPTION) << op_name << " requires that the 2nd argument be a tuple of functions, but got "
+                        << branches[i]->ToString() << " as the " << i << "th element.";
+    }
+  }
+
+  auto b = branches[0];
+  for (size_t i = 1; i < branches.size(); i++) {
+    b = b->Join(branches[i]);
+  }
+  return b;
+}
+
+std::vector<ValuePtr> GetSupportedTargetValue() {
+  std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
+  return list;
+}
+
+bool SupportedIsTargetValue(const ValuePtr t) {
+  auto list = GetSupportedTargetValue();
+  auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; });
+  return match;
+}
+
+AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list) {
+  // statement: x is t
+  // Inputs: x, t
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  ValuePtr t = args_spec_list[1]->BuildValue();
+  if (!SupportedIsTargetValue(t)) {
+    MS_LOG(EXCEPTION) << "Unsupported type: " << t->ToString()
+                      << " for statement 'is'; the supported targets are: None, False, True";
+  }
+  ValuePtr x = args_spec_list[0]->BuildValue();
+
+  return std::make_shared<AbstractScalar>(*t == *x);
+}
+
+AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list) {
+  // statement: x is not t
+  // Inputs: x, t
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  ValuePtr t = args_spec_list[1]->BuildValue();
+  if (!SupportedIsTargetValue(t)) {
+    MS_LOG(EXCEPTION) << "Unsupported type: " << t->ToString()
+                      << " for statement 'is not'; the supported targets are: None, False, True";
+  }
+  ValuePtr x = args_spec_list[0]->BuildValue();
+
+  return std::make_shared<AbstractScalar>(!(*t == *x));
+}
+
+bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
+  auto dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 1);
+
+  ValuePtr key_value = key->BuildValue();
+  if (!key_value->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be a string, but got " << key_value->ToString();
+  }
+  auto key_str = GetValue<std::string>(key_value);
+  std::vector<AbstractAttribute> dict_elems = dict->elements();
+  auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
+                         [key_str](const AbstractAttribute &item) {
return item.first == key_str; }); + return it != dict_elems.end(); +} + +AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x in t + // Inputs: x, t + return std::make_shared(IsInDict(primitive, args_spec_list)); +} + +AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: x not in t + // Inputs: x, t + return std::make_shared(!IsInDict(primitive, args_spec_list)); +} + +AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // statement: isconstant(x) + // Inputs: x + if (args_spec_list.size() != 1) { + MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1"; + } + ValuePtr v = args_spec_list[0]->BuildValue(); + return std::make_shared(!v->isa()); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_structures.cc b/mindspore/ccsrc/frontend/operator/prim_structures.cc new file mode 100644 index 0000000000..b602b07a0c --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_structures.cc @@ -0,0 +1,712 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" +#include "abstract/utils.h" +#include "abstract/param_validator.h" +#include "frontend/operator/ops.h" +#include "utils/convert_utils.h" +#include "ir/tensor_py.h" + +using mindspore::tensor::TensorPy; + +namespace mindspore { +namespace abstract { + +AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + bool ret = (value_x->cast()->value() == value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two scalars whose value is a string. 
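The whitelist behind SupportedIsTargetValue above keeps the folding of `is` / `is not` sound: only the literals None, False, and True are accepted as targets, because only for them does value equality coincide with Python identity. A standalone sketch of the same check (names are illustrative):

    #include <algorithm>
    #include <array>
    #include <string>

    // Any other target (numbers, strings, objects) would depend on object identity
    // semantics that the compile-time evaluator cannot observe, so it is rejected.
    bool IsSupportedIsTarget(const std::string &literal) {
      static const std::array<std::string, 3> kAllowed = {"None", "False", "True"};
      return std::any_of(kAllowed.begin(), kAllowed.end(),
                         [&literal](const std::string &v) { return v == literal; });
    }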
+ const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); + + ValuePtr value_x = scalar_x->BuildValue(); + ValuePtr value_y = scalar_y->BuildValue(); + if (!value_x->isa() || !value_y->isa()) { + MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() + << ", param1: " << value_y->ToString(); + } + + std::string ret = (value_x->cast()->value() + value_y->cast()->value()); + return std::make_shared(ret); +} + +AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, + const AbstractBasePtrList &args_spec_list) { + return std::make_shared(args_spec_list); +} + +AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tuples. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr keys = CheckArg(op_name, args_spec_list, 0); + AbstractTuplePtr values = CheckArg(op_name, args_spec_list, 1); + + size_t keys_size = keys->size(); + if (values->size() != keys_size) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; + } + + std::vector key_value; + AbstractScalarPtr key; + AbstractBasePtrList key_list = keys->elements(); + AbstractBasePtrList value_list = values->elements(); + for (size_t index = 0; index < keys_size; index++) { + key = CheckArg(op_name + "key", key_list, index); + ValuePtr keyPtr = key->BuildValue(); + MS_EXCEPTION_IF_NULL(keyPtr); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + key_value.emplace_back(key_string, value_list[index]); + } + return std::make_shared(key_value); +} + +AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); + + ValuePtr keyPtr = key->BuildValue(); + if (!keyPtr->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); + } + std::string key_string = GetValue(keyPtr); + return std::make_shared(key_string, args_spec_list[1]); +} + +AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a string and a keyword. 
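InferImplMakeDict above zips a tuple of string keys with a tuple of values after checking that the two sizes match. The core of that construction, sketched with plain standard-library types (the int value type is a hypothetical stand-in for the abstract values):

    #include <cstddef>
    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>

    // The abstract dict is an ordered sequence of (key, value) pairs, not a hash map.
    std::vector<std::pair<std::string, int>> MakeDict(const std::vector<std::string> &keys,
                                                      const std::vector<int> &values) {
      if (keys.size() != values.size()) throw std::invalid_argument("make_dict: keys' size != values' size");
      std::vector<std::pair<std::string, int>> out;
      for (std::size_t i = 0; i < keys.size(); ++i) out.emplace_back(keys[i], values[i]);
      return out;
    }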
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
+  AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
+
+  ValuePtr key_value = key->BuildValue();
+  if (!key_value->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << op_name << " evaluator key should be a string, but got " << key_value->ToString();
+  }
+  std::string key_input = GetValue<std::string>(key_value);
+  std::string key_actual = kwarg->get_key();
+  if (key_actual != key_input) {
+    MS_LOG(EXCEPTION) << op_name << " evaluator input key should be the same as the AbstractKeywordArg's key, "
+                      << "but the input is " << key_input << ", the AbstractKeywordArg's key is " << key_actual;
+  }
+  return kwarg->get_arg();
+}
+
+AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: three scalars whose value is an int32 number.
+  CheckArgsSize(primitive->name(), args_spec_list, 3);
+  size_t args_size = args_spec_list.size();
+  for (size_t index = 0; index < args_size; index++) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[index]);
+    if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
+      MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
+    }
+    if (args_spec_list[index]->isa<AbstractScalar>() &&
+        !dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
+    }
+  }
+  // Slice: start, end, step
+  return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
+}
+
+// Eval the return type of make_record
+AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: at least two objects of a subclass of AbstractBase.
+  if (args_spec_list.size() < 2) {
+    MS_LOG(EXCEPTION) << "MakeRecord evaluator requires more than 1 parameter, while the input size is "
+                      << args_spec_list.size() << ".";
+  }
+
+  // args_spec_list[0] may be an AbstractScalarPtr or an AbstractTypePtr
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  TypePtr type = args_spec_list[0]->GetTypeTrack();
+  MS_EXCEPTION_IF_NULL(type);
+  if (type->type_id() != kMetaTypeTypeType) {
+    MS_LOG(EXCEPTION) << "Cannot make type(" << type->ToString() << "): it is not a TypeType";
+  }
+
+  ValuePtr value_track = args_spec_list[0]->GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_track);
+  TypePtr type_ptr = value_track->cast<TypePtr>();
+  if (type_ptr == nullptr) {
+    MS_LOG(EXCEPTION) << "Value type error, not a Me type: " << value_track->ToString();
+  }
+
+  auto cls = dyn_cast<Class>(type_ptr);
+  MS_EXCEPTION_IF_NULL(cls);
+  ClassAttrVector attributes = cls->GetAttributes();
+  CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
+
+  std::vector<AbstractAttribute> abs_attributes;
+  for (size_t i = 0; i < attributes.size(); i++) {
+    AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
+    abs_attributes.push_back(elem);
+  }
+
+  return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
+}
+
+template <typename T>
+AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tuple or list and a scalar whose value is an int32 number.
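The getitem body that follows normalizes Python-style negative indices before use. The arithmetic on its own, as a small sketch (not the MindSpore helper):

    #include <cstddef>
    #include <stdexcept>

    // An index i on a sequence of length n is valid in [-n, n); negative values wrap.
    std::size_t NormalizeIndex(int i, std::size_t n) {
      const int len = static_cast<int>(n);
      if (i >= len || i < -len) throw std::out_of_range("index out of range");
      return static_cast<std::size_t>(i >= 0 ? i : i + len);
    }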
+ CheckArgsSize(op_name, args_spec_list, 2); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element + // and continue + if (dyn_cast(queue->elements()[0]) != nullptr) { + return std::make_shared(queue->elements()[0]->BuildType()); + } + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + std::size_t nelems = queue->elements().size(); + if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " + << SizeToInt(nelems) << "), but got " << idx_v << "."; + } + + std::size_t uidx_v = 0; + if (idx_v >= 0) { + uidx_v = IntToSize(idx_v); + } else { + uidx_v = IntToSize(idx_v + SizeToInt(nelems)); + } + return queue->elements()[uidx_v]; +} + +template +AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. + CheckArgsSize(op_name, args_spec_list, 3); + auto queue = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); + + ValuePtr index_value = index->BuildValue(); + if (!index_value->isa()) { + MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " + << index_value->ToString(); + } + int idx_v = GetValue(index_value); + if (idx_v < 0) { + MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v + << "."; + } + + size_t uidx_v = IntToSize(idx_v); + AbstractBasePtrList elements = queue->elements(); + std::size_t nelems = elements.size(); + if (uidx_v >= nelems) { + MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 + << "."; + } + elements[uidx_v] = args_spec_list[2]; + return std::make_shared(elements); +} + +AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListGetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + return InferTupleOrListSetItem(primitive->name(), args_spec_list); +} + +AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string. 
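The two dict evaluators that follow search the ordered (key, value) list linearly; setitem overwrites in place when the key exists and appends otherwise, preserving insertion order. A standalone sketch of that update rule (the int value type is a hypothetical stand-in):

    #include <string>
    #include <utility>
    #include <vector>

    void DictSet(std::vector<std::pair<std::string, int>> *elems, const std::string &key, int value) {
      for (auto &kv : *elems) {
        if (kv.first == key) {
          kv.second = value;  // existing key: overwrite in place, order preserved
          return;
        }
      }
      elems->emplace_back(key, value);  // new key: append at the end
    }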
+ const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + auto key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](const AbstractAttribute &item) { return item.first == key_str; }); + + if (it == dict_elems.end()) { + MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); + } + return it->second; +} + +AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); + AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); + + ValuePtr key_value = key->BuildValue(); + if (!key_value->isa()) { + MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); + } + std::string key_str = GetValue(key_value); + std::vector dict_elems = dict->elements(); + auto it = std::find_if(dict_elems.begin(), dict_elems.end(), + [key_str](AbstractAttribute &item) { return item.first == key_str; }); + + MS_EXCEPTION_IF_NULL(args_spec_list[2]); + auto new_ele = std::make_pair(key_str, args_spec_list[2]); + if (it != dict_elems.end()) { + int index = it - dict_elems.begin(); + dict_elems[IntToSize(index)] = new_ele; + } else { + dict_elems.push_back(new_ele); + } + return std::make_shared(dict_elems); +} + +AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a list and an object of a subclass of AbstractBase. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractListPtr list = CheckArg(op_name, args_spec_list, 0); + (void)AbstractJoin(list->elements()); + return list; +} + +template +AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple or list or dict. 
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto arg = CheckArg<T>(op_name, args_spec_list, 0);
+  return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
+}
+
+AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
+}
+
+AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
+}
+
+AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
+}
+
+AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                  const AbstractBasePtrList &args_spec_list) {
+  return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
+}
+
+AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: fn, list1, list2, ...
+  MS_EXCEPTION_IF_NULL(engine);
+  if (args_spec_list.size() <= 1) {
+    MS_LOG(EXCEPTION) << "list_map requires at least 1 list, while the input size is " << args_spec_list.size() << ".";
+  }
+  AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
+  // Check args from 1.
+  CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
+
+  AbstractBasePtrList subargs;
+  for (std::size_t i = 1; i < args_spec_list.size(); i++) {
+    AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
+    if (l_ptr == nullptr) {
+      MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
+    }
+    subargs.push_back(AbstractJoin(l_ptr->elements()));
+  }
+  EvalResultPtr engin_exc = engine->Execute(fn, subargs);
+  AbstractBasePtrList result;
+  for (std::size_t i = 1; i < args_spec_list.size(); i++) {
+    result.push_back(engin_exc->abstract());
+  }
+  return std::make_shared<AbstractList>(result);
+}
+
+AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a fn, a list and an object of a subclass of AbstractBase.
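DoInferReduceShape, defined a little further below, computes a keep-dims reduced shape: every axis named in the axis set (negative axes counted from the back) collapses to 1, and the remaining axes keep their extent. The same rule as a standalone sketch (illustrative names):

    #include <set>
    #include <vector>

    std::vector<int> ReducedShape(const std::vector<int> &shape, const std::set<int> &axes) {
      const int rank = static_cast<int>(shape.size());
      std::vector<int> out;
      for (int i = 0; i < rank; ++i) {
        // An axis can be given as i or as i - rank; both address the same dimension.
        const bool reduced = axes.count(i) != 0 || axes.count(i - rank) != 0;
        out.push_back(reduced ? 1 : shape[i]);
      }
      return out;  // e.g. shape = (2, 3, 4), axes = {1} gives (2, 1, 4)
    }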
+ MS_EXCEPTION_IF_NULL(engine); + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + AbstractFunctionPtr fn = CheckArg(op_name, args_spec_list, 0); + AbstractListPtr lst = CheckArg(op_name, args_spec_list, 1); + AbstractBasePtr dflt = args_spec_list[2]; + + AbstractBasePtr list_type = AbstractJoin(lst->elements()); + auto result1 = engine->Execute(fn, lst->elements()); + auto result2 = engine->Execute(fn, {dflt, list_type}); + MS_EXCEPTION_IF_NULL(result1->abstract()); + MS_EXCEPTION_IF_NULL(result2->abstract()); + return result1->abstract()->Join(result2->abstract()); +} + +AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tuple + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTuplePtr input = CheckArg(op_name, args_spec_list, 0); + + auto tuple_elements = input->elements(); + AbstractBasePtrList elem_list; + (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list), + [](const AbstractBasePtr &elem) { return elem->Clone(); }); + return std::make_shared(elem_list); +} + +AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, + const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { + size_t x_rank = x_shape->size(); + std::set axis_set; + auto axis_data = axis_value_ptr->value(); + if (axis_data.empty()) { + int size = 1; + AbstractBasePtrList values(x_rank, std::make_shared(size)); + return std::make_shared(values); + } + + for (auto &elem : axis_data) { + int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); + (void)axis_set.insert(e_value); + } + + auto x_shp_data = x_shp_value->cast()->value(); + if (x_shp_data.size() < x_rank) { + MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; + } + AbstractBasePtrList values; + for (size_t i = 0; i < x_rank; i++) { + if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { + auto axis_v = MakeValue(1); + values.push_back(std::make_shared(axis_v, axis_v->type())); + } else { + int dim_value = x_shp_data[i]->cast()->value(); + auto dim = MakeValue(dim_value); + values.push_back(std::make_shared(dim, dim->type())); + } + } + + return std::make_shared(values); +} + +AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: x_shape, axis + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(args_spec_list[1]); + + auto x_shp_value = shape_x->BuildValue(); + if (x_shp_value->isa()) { + MS_LOG(EXCEPTION) << op_name + << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); + } + + // Axis can be scalar, tuple or None + AbstractTuplePtr axis = nullptr; + if (args_spec_list[1]->isa()) { + MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; + AbstractBasePtrList axis_list = {dyn_cast(args_spec_list[1])}; + axis = std::make_shared(axis_list); + } else if (args_spec_list[1]->isa()) { + MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; + axis = args_spec_list[1]->cast(); + } else { + MS_LOG(EXCEPTION) << op_name << " evaluator second parameter 
should be a scalar or tuple, but got "
+                      << args_spec_list[1]->ToString();
+  }
+
+  auto axis_value = axis->BuildValue();
+  if (axis_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << op_name
+                      << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
+  }
+  auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(axis_value_ptr);
+
+  return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
+}
+
+AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tuples.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
+  MS_LOG(INFO) << "DivShape input: " << shape_x->ToString() << ", div: " << div_shp->ToString();
+
+  auto div_shp_value = div_shp->BuildValue();
+  if (div_shp_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "Shape's data field can't be anything: " << args_spec_list[0]->ToString();
+  }
+
+  auto shpx_value = shape_x->BuildValue();
+  if (shpx_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "Shape's data field can't be anything: " << args_spec_list[1]->ToString();
+  }
+
+  if (div_shp->size() != shape_x->size()) {
+    MS_LOG(EXCEPTION) << "tuple_div requires both tuples to have the same size, but div_shp: " << div_shp->size()
+                      << ", shape_x: " << shape_x->size() << ".";
+  }
+
+  auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
+  auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
+  AbstractBasePtrList values;
+
+  for (size_t i = 0; i < div_shp_data.size(); i++) {
+    if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
+      MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
+    }
+    int shapex_value = GetValue<int>(shpx_data[i]);
+    int div_value = GetValue<int>(div_shp_data[i]);
+    MS_LOG(DEBUG) << "div_shp_shape data shapex_value: " << shapex_value << " div_value: " << div_value;
+    if (div_value == 0) {
+      MS_LOG(EXCEPTION) << "Error: the divisor value should not be 0!";
+    }
+    if ((shapex_value % div_value) != 0) {
+      MS_LOG(EXCEPTION) << "tuple_div requires shapex_value to be divisible by div_value, but shapex_value: "
+                        << shapex_value << " div_value: " << div_value;
+    }
+
+    int result = shapex_value / div_value;
+    auto result_v = MakeValue(result);
+    values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
+  }
+
+  return std::make_shared<AbstractTuple>(values);
+}
+
+AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tuple
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+
+  py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
+  py::array data = py::array(data_tuple);
+  auto tensor = TensorPy::MakeTensor(data);
+  auto ret = tensor->ToAbstract();
+  ret->set_value(tensor);
+  MS_LOG(DEBUG) << "Tuple2array result AbstractTensor: " << ret->ToString();
+  return ret;
+}
+
+AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a tuple
+  // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+
+  auto shpx_value = shape_x->BuildValue();
+  if (shpx_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "Shape's data field can't be anything: " << shape_x->ToString();
+  }
+
+  auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
+
+  int result = 1;
+  for (size_t i = 0; i < shpx_data.size(); i++) {
+    int value = GetValue<int>(shpx_data[i]);
+    result = IntMulWithOverflowCheck(result, value);
+  }
+
+  auto result_v = MakeValue(result);
+  MS_LOG(DEBUG) << "shape mul result: " << result_v->ToString();
+  return std::make_shared<AbstractScalar>(result_v, result_v->type());
+}
+
+template <typename T>
+AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tuples or two lists.
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
+  auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
+
+  ValuePtr x_value = input_x->BuildValue();
+  ValuePtr y_value = input_y->BuildValue();
+  return std::make_shared<AbstractScalar>(*x_value == *y_value);
+}
+
+AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list) {
+  return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
+}
+
+AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list) {
+  return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
+}
+
+struct SlideInfo {
+  int start;
+  int step;
+  int stop;
+};
+
+void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
+  int arg1 = 0;
+  int arg2 = 0;
+  if (!args_spec_list.empty()) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+    auto arg_value = args_spec_list[0]->BuildValue();
+    if (!arg_value->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "Only an int32 number is supported as input.";
+    }
+    arg1 = GetValue<int>(arg_value);
+  }
+
+  if (args_spec_list.size() >= 2) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[1]);
+    auto arg_value = args_spec_list[1]->BuildValue();
+    if (!arg_value->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "Only an int32 number is supported as input.";
+    }
+    arg2 = GetValue<int>(arg_value);
+  }
+
+  if (args_spec_list.size() == 3) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[2]);
+    auto arg_value = args_spec_list[2]->BuildValue();
+    if (!arg_value->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "Only an int32 number is supported as input.";
+    }
+    slide->step = GetValue<int>(arg_value);
+    slide->start = arg1;
+    slide->stop = arg2;
+  }
+
+  if (args_spec_list.size() == 2) {
+    slide->start = arg1;
+    slide->stop = arg2;
+  }
+
+  if (args_spec_list.size() == 1) {
+    slide->stop = arg1;
+  }
+}
+
+AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                   const AbstractBasePtrList &args_spec_list) {
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
+  }
+
+  if (args_spec_list.size() > 3) {
+    MS_LOG(EXCEPTION) << "Wrong number of arguments for the make_range operator.";
+  }
+
+  SlideInfo slide = {0, 1, 0};
+  CalcSlidePara(args_spec_list, &slide);
+
+  if (slide.step == 0) {
+    MS_LOG(EXCEPTION) << "Error: the step value is 0.";
+  }
+
+  AbstractBasePtrList args;
+  if (slide.start <= slide.stop) {
+    if (slide.step <= 0) {
+      MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
+    }
+    for (int i = slide.start; i < slide.stop; i += slide.step) {
+      args.push_back(abstract::FromValue(i));
+    }
+  } else {
+    if (slide.step >= 0) {
+      MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
+    }
+    for (int i = slide.start; i > slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i)); + } + } + + return std::make_shared(args); +} + +AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: a tensor + CheckArgsSize(primitive->name(), args_spec_list, 1); + return args_spec_list[0]->Clone(); +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/prim_to_function.cc b/mindspore/ccsrc/frontend/operator/prim_to_function.cc new file mode 100644 index 0000000000..7b9592e80e --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/prim_to_function.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/operator/prim_to_function.h" +#include +#include +#include + +namespace mindspore { +// namespace to support prim related definition +namespace prim { + +PrimToFunction::PrimToFunction() + : prim_func_type_map_({// ONE_ARG prim + {"bool_not", kPrimTypeOneArg}, + {"scalar_cos", kPrimTypeOneArg}, + {"scalar_exp", kPrimTypeOneArg}, + {"scalar_floor", kPrimTypeOneArg}, + {"scalar_log", kPrimTypeOneArg}, + {"scalar_sin", kPrimTypeOneArg}, + {"scalar_tan", kPrimTypeOneArg}, + {"scalar_trunc", kPrimTypeOneArg}, + {"typeof", kPrimTypeOneArg}, + {"scalar_uadd", kPrimTypeOneArg}, + {"scalar_usub", kPrimTypeOneArg}, + // TWO_ARGS prim + {"scalar_add", kPrimTypeTwoArgs}, + {"bool_and", kPrimTypeTwoArgs}, + {"bool_eq", kPrimTypeTwoArgs}, + {"bool_or", kPrimTypeTwoArgs}, + {"scalar_div", kPrimTypeTwoArgs}, + {"scalar_eq", kPrimTypeTwoArgs}, + {"scalar_ge", kPrimTypeTwoArgs}, + {"scalar_gt", kPrimTypeTwoArgs}, + {"scalar_le", kPrimTypeTwoArgs}, + {"scalar_lt", kPrimTypeTwoArgs}, + {"scalar_ne", kPrimTypeTwoArgs}, + {"scalar_mod", kPrimTypeTwoArgs}, + {"scalar_mul", kPrimTypeTwoArgs}, + {"scalar_pow", kPrimTypeTwoArgs}, + {"scalar_sub", kPrimTypeTwoArgs}, + {"scalar_floordiv", kPrimTypeTwoArgs}}) {} + +bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { + bool result = false; + + if (func != nullptr) { + int args_num = GetPrimType(prim); + std::vector one_arg{std::make_shared()}; + std::vector two_args{std::make_shared(), std::make_shared()}; + TypePtr retval = std::make_shared(); + result = true; + switch (args_num) { + case kPrimTypeOneArg: + *func = Function(one_arg, retval).DeepCopy()->cast(); + break; + case kPrimTypeTwoArgs: + *func = Function(two_args, retval).DeepCopy()->cast(); + break; + default: + result = false; + break; + } + } + + return result; +} + +int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { + MS_EXCEPTION_IF_NULL(prim); + int prim_type = static_cast(kPrimTypeUnknown); + + auto value = prim_func_type_map_.find(prim->name()); + if (value != prim_func_type_map_.end()) { + prim_type = value->second; + } + return prim_type; +} +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_to_function.h 
b/mindspore/ccsrc/frontend/operator/prim_to_function.h similarity index 100% rename from mindspore/ccsrc/operator/prim_to_function.h rename to mindspore/ccsrc/frontend/operator/prim_to_function.h diff --git a/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt new file mode 100644 index 0000000000..14fda83052 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) +add_library(_mindspore_frontend_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc new file mode 100644 index 0000000000..60ccf28df4 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/adjoint.h" + +#include +#include + +#include "ir/anf.h" +#include "frontend/optimizer/ad/dfunctor.h" + +namespace mindspore { +namespace ad { +Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) + : primal_(primal), caller_(caller), dout_(nullptr) { + if (k != nullptr) { + k_ = k; + MS_LOG(DEBUG) << "Add adjoint for " << primal->ToString() << " " << k_->ToString(); + } else { + // Init k hole in a recursive case. + auto k_hole = std::make_shared("k_hole"); + (void)k_hole->AddAttr("info", MakeValue(primal->ToString())); + k_ = NewValueNode(k_hole); + MS_LOG(DEBUG) << "Add hole for " << primal->ToString() << " " << k_->ToString(); + } + + dout_hole_ = caller_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); + RegisterKUser(dout_hole_->cast(), 1); +} + +AnfNodePtr Adjoint::k() { return k_; } + +void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } + +void Adjoint::UpdateK(const AnfNodePtr &new_k) { + MS_EXCEPTION_IF_NULL(new_k); + MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); + // In recursive case, it needs update. 
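The loop that follows patches every recorded use site of the old k. The bookkeeping pattern, reduced to plain pointers (types here are illustrative, not the real AnfNode classes):

    #include <cstddef>
    #include <stdexcept>
    #include <utility>
    #include <vector>

    struct Node { std::vector<Node *> inputs; };

    // Each use of a placeholder is recorded as a (user, input-index) pair, so the
    // placeholder can later be swapped for the real node at exactly those sites.
    void PatchUsers(const std::vector<std::pair<Node *, std::size_t>> &users, Node *old_node, Node *new_node) {
      for (const auto &[user, index] : users) {
        if (user->inputs.at(index) != old_node) throw std::logic_error("user relation is set wrongly");
        user->inputs[index] = new_node;
      }
    }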
+ for (auto &user : k_user_) { + MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" + << new_k->ToString(); + if (user.first->input(user.second) != k_) { + MS_LOG(EXCEPTION) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k " + << new_k->ToString() << ", user relation is set wrongly"; + } + user.first->set_input(user.second, new_k); + } + k_ = new_k; +} + +AnfNodePtr Adjoint::primal() { return primal_; } + +AnfNodePtr Adjoint::dout() { return dout_hole_; } + +void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { + dout_user_.emplace_back(std::make_pair(user, index)); +} + +void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { + if (dout_ != nullptr) { + MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); + auto add = prim::GetPythonOps("hyper_add"); + dout_ = caller_->NewCNode({NewValueNode(add), dout_, dout_factor}); + return; + } + dout_ = dout_factor; +} + +void Adjoint::CallDoutHole() { + if (dout_ != nullptr) { + for (auto &user : dout_user_) { + MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " + << dout_->ToString(); + if (user.first->input(user.second) != dout_hole_) { + MS_LOG(EXCEPTION) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " + << dout_->ToString() << ", user relation is set wrongly"; + } + user.first->set_input(user.second, dout_); + } + } +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h new file mode 100644 index 0000000000..37986e6810 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ + +#include +#include +#include + +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" + +namespace mindspore { +namespace ad { +class Adjoint { + public: + Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); + ~Adjoint() = default; + AnfNodePtr primal(); + AnfNodePtr k(); + void UpdateK(const AnfNodePtr &k); + void RegisterKUser(const CNodePtr &user, size_t index); + AnfNodePtr dout(); + void AccumulateDout(const AnfNodePtr &dout_factor); + void RegisterDoutUser(const CNodePtr &user, size_t index); + void CallDoutHole(); + + private: + AnfNodePtr primal_; + FuncGraphPtr caller_; + // For ```def f(x): return expr```, The representation graph k is ```def kf(kx): return expr, bprop{expr}```. 
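The comment above can be read as a type: for a primal f : X -> Y, the transformed kf returns the forward value together with a backpropagator closure. A hypothetical rendering of that shape in C++ types (not part of the real API):

    #include <functional>
    #include <utility>

    // K(f) : X -> (Y, DY -> DX); the second component is the bprop closure that the
    // tape graph later calls with the incoming sensitivity dout.
    template <typename X, typename Y, typename DX, typename DY>
    using KFunc = std::function<std::pair<Y, std::function<DX(DY)>>(X)>;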
+ AnfNodePtr k_; + std::vector> k_user_; + AnfNodePtr dout_; + AnfNodePtr dout_hole_; + std::vector> dout_user_; +}; + +using AdjointPtr = std::shared_ptr; +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc new file mode 100644 index 0000000000..b314b22f81 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -0,0 +1,617 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/dfunctor.h" + +#include +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "debug/info.h" +#include "ir/func_graph_cloner.h" +#include "ir/manager.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/optimizer/ad/adjoint.h" +#include "frontend/optimizer/opt.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "utils/symbolic.h" +#include "utils/context/ms_context.h" +#include "./common.h" + +namespace mindspore { +namespace ad { +std::unordered_map DFunctor::func_graph_to_functor_; +std::unordered_map DFunctor::anfnode_to_adjoin_definition_; +FuncGraphSet DFunctor::scope_; + +DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) + : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { + TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); + k_graph_ = std::make_shared(); + if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); + } + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); + tape_ = std::make_shared(); + // Add "_Grad" postfix + if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; + tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); + } + TraceManager::EndTrace(); + + dout_ = tape_->add_parameter(); +} + +void DFunctor::Init(bool is_top) { + func_graph_to_functor_[primal_graph_] = shared_from_this(); + is_top_ = is_top; + if (is_top) { + scope_ = primal_graph_->scope(); + } +} + +void DFunctor::Clear() { + func_graph_to_functor_.clear(); + anfnode_to_adjoin_definition_.clear(); + scope_.clear(); +} + +void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { + auto fv_adjoint = anfnode_to_adjoin_.find(fv); + if (fv_adjoint == anfnode_to_adjoin_.end()) { + MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() + << " " << fv->ToString() << "."; + fv_adjoint = 
anfnode_to_adjoin_indirect_fv_.find(fv);
+    if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
+      MS_LOG(DEBUG) << "BackPropagateFv cannot find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
+                    << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+      auto parent_adjoint = FindAdjoint(fv);
+      AdjointPtr adjoint = nullptr;
+      if (parent_adjoint != nullptr) {
+        adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
+      } else {
+        MS_LOG(DEBUG) << "BackPropagateFv failed: cannot find the adjoint definition of fv, add a k hole "
+                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+        adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
+      }
+      anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
+      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
+    }
+  }
+  auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
+  fv_adjoint->second->RegisterKUser(node, 1);
+  auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()});
+  fv_adjoint->second->RegisterKUser(default_val, 1);
+  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val});
+  MS_LOG(DEBUG) << "BackPropagateFv found adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
+                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
+  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << ".";
+  fv_adjoint->second->AccumulateDout(dfv);
+}
+
+void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
+  // Take switch_layer as a set of candidate functions.
+  auto input = cnode_morph->input(2);
+  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
+    MS_LOG(EXCEPTION) << "The 2nd input of switch_layer should be a tuple of graphs, but got " << input->ToString()
+                      << ".";
+  }
+  auto tuple_graphs = input->cast<CNodePtr>();
+  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
+    auto graph = tuple_graphs->input(i);
+    if (!IsValueNode<FuncGraph>(graph)) {
+      MS_LOG(EXCEPTION) << "The 2nd input of switch_layer should be a tuple of graphs, but got " << graph->ToString()
+                        << " as the " << i << "th element.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
+    auto functor = func_graph_to_functor_.find(func_graph);
+    if (functor == func_graph_to_functor_.end()) {
+      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed: the functor for the subgraph does not exist, input[" << i
+                        << "] " << func_graph->ToString() << ".";
+    }
+    // Consider direct and indirect fvs.
+    for (auto fv : func_graph->free_variables_nodes()) {
+      BackPropagateFv(fv, env);
+    }
+    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
+      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
+                    << indirect_fv.first->ToString() << ".";
+      BackPropagateFv(indirect_fv.first, env);
+    }
+  }
+}
+
+void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
+  auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
+  // Call with delimited continuation dout.
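What the graph construction below encodes, stated directly: the bprop closure is applied to the accumulated dout, yielding one sensitivity (din) per primal input, and each din is added into that input's adjoint (hyper_add in the real graph). A runnable sketch with plain vectors (types are illustrative only):

    #include <cstddef>
    #include <functional>
    #include <vector>

    using Grad = std::vector<double>;

    void Backprop(const std::function<std::vector<Grad>(const Grad &)> &bprop, const Grad &dout,
                  std::vector<Grad> *input_adjoints) {
      const std::vector<Grad> dins = bprop(dout);  // one sensitivity per primal input
      for (std::size_t i = 0; i < input_adjoints->size(); ++i) {
        Grad &acc = (*input_adjoints)[i];
        if (acc.empty()) {
          acc = dins[i];  // the first contribution is stored as-is
        } else {
          for (std::size_t j = 0; j < acc.size(); ++j) acc[j] += dins[i][j];  // later ones are added
        }
      }
    }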
+  auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
+  node_adjoint->RegisterDoutUser(bprop_app, 1);
+  // Special case for switch_layer
+  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
+    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
+    BackPropagateSwitchLayer(cnode_morph, din);
+    return;
+  }
+  for (size_t i = 0; i < cnode_morph->size(); i++) {
+    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
+    auto input = cnode_morph->input(i);
+    // Backprop sens wrt fvs.
+    if (IsValueNode<FuncGraph>(input)) {
+      auto func_graph = GetValueNode<FuncGraphPtr>(input);
+      auto functor = func_graph_to_functor_.find(func_graph);
+      if (functor == func_graph_to_functor_.end()) {
+        MS_LOG(EXCEPTION) << "BackPropagate failed: the functor for the subgraph does not exist, input[" << i << "] "
+                          << func_graph->ToString() << ".";
+      }
+      // Consider direct and indirect fvs.
+      for (auto fv : func_graph->free_variables_nodes()) {
+        BackPropagateFv(fv, din);
+      }
+      for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
+        MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
+                      << indirect_fv.first->ToString() << ".";
+        BackPropagateFv(indirect_fv.first, din);
+      }
+      continue;
+    }
+    // Backprop sens wrt inputs.
+    auto input_adjoint = anfnode_to_adjoin_.find(input);
+    if (input_adjoint == anfnode_to_adjoin_.end()) {
+      MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist, input[" << i << "] " << input->ToString() << ".";
+    }
+    input_adjoint->second->AccumulateDout(din);
+  }
+}
+
+// Map a morphism.
+AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
+  // MapMorphism: all types except CNode should already have been mapped by MapObject.
+  if (!morph->isa<CNode>()) {
+    return nullptr;
+  }
+  ScopeGuard scope_guard(morph->scope());
+  auto cnode_morph = morph->cast<CNodePtr>();
+
+  std::vector<AnfNodePtr> inputs;
+  std::vector<AdjointPtr> param_adjoints;
+  for (size_t i = 0; i < cnode_morph->size(); i++) {
+    auto node = cnode_morph->input(i);
+    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
+    AdjointPtr node_adjoint = nullptr;
+    AnfNodePtr k = nullptr;
+    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
+      node_adjoint = node_adjoint_iter->second;
+    } else {
+      // The input might be a CNode that needs to be handled beforehand.
+      node_adjoint = MapMorphism(node);
+    }
+    MS_EXCEPTION_IF_NULL(node_adjoint);
+    k = node_adjoint->k();
+    if (k == nullptr) {
+      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
+    }
+    inputs.push_back(k);
+    param_adjoints.push_back(node_adjoint);
+  }
+  TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
+  auto k_app = k_graph_->NewCNode(inputs);
+  TraceManager::EndTrace();
+  for (size_t i = 0; i < param_adjoints.size(); ++i) {
+    param_adjoints[i]->RegisterKUser(k_app, i);
+  }
+
+  // Do forward computation
+  auto forward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)});
+  // K:: cnode -> forward_app
+  auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
+  UpdateAdjoint(node_adjoint);
+  anfnode_to_adjoin_[morph] = node_adjoint;
+  if (cnode_morph->stop_gradient()) {
+    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
+    return node_adjoint;
+  }
+
+  // Do sens backpropagation
+  BackPropagate(cnode_morph, k_app, node_adjoint);
+  MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << ".";
+  return node_adjoint;
+}
+
+bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
+  // Do not care about non-CNode
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Do not care about kPrimReturn
+  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
+    return false;
+  }
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Do not care about isolated morphisms
+  if (users.empty()) {
+    return false;
+  }
+  // Not free if it's used by some node in primal_graph
+  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
+    auto &user = kv.first;
+    return user->func_graph() == primal_graph_;
+  });
+  return !nonfree;
+}
+
+void DFunctor::MapFreeMorphism() {
+  // Handle CNodes that are not attached to the output but might be referred to in other functions.
+  for (auto &node : primal_graph_->nodes()) {
+    if (!IsFreeMorphism(node)) {
+      continue;
+    }
+    MS_LOG(DEBUG) << "MapFreeMorphism map non-output cnode after MapMorphism " << node->ToString() << ".";
+    (void)MapMorphism(node);
+  }
+}
+
+AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
+  AnfNodePtr new_grad_fv = grad_fv;
+  // Add grads wrt fv.
+  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
+  for (auto &fv : free_variables_nodes) {
+    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
+    if (fv_adjoint == anfnode_to_adjoin_.end()) {
+      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
+    }
+    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
+    fv_adjoint->second->RegisterKUser(node, 1);
+    auto sens = fv_adjoint->second->dout();
+    new_grad_fv = tape_->NewCNode({
+      NewValueNode(prim::kPrimEnvSetItem),
+      new_grad_fv,
+      node,
+      sens,
+    });
+    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
+    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
+                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
+  }
+  return new_grad_fv;
+}
+
+AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
+  AnfNodePtr new_grad_fv = grad_fv;
+  // Add indirect fv bprop.
+  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
+    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
+                  << primal_graph_->ToString() << ".";
+    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
+    fv_adjoint.second->RegisterKUser(node, 1);
+    auto sens = fv_adjoint.second->dout();
+    new_grad_fv = tape_->NewCNode({
+      NewValueNode(prim::kPrimEnvSetItem),
+      new_grad_fv,
+      node,
+      sens,
+    });
+    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), 3);
+    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
+                  << new_grad_fv->ToString() << ".";
+  }
+  return new_grad_fv;
+}
+
+void DFunctor::MapMorphism() {
+  // Set stop_gradient before MapMorphism.
+  BroadCastStopFlag();
+
+  // Handle free morphisms before the output, because in some cases a free morphism
+  // might depend on the output's fv tangent.
+  MapFreeMorphism();
+  // Handle morphism from output.
+  (void)MapMorphism(primal_graph_->output());
+
+  // Construct K for primal_graph_
+  auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
+  // Attach dout_ parameter to output_adjoint.
+  output_adjoint->second->AccumulateDout(dout_);
+
+  // Set output for tape closure.
+  auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
+
+  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
+  // Add grads wrt inputs.
+  std::vector<AdjointPtr> param_adjoints;
+  for (auto &param : primal_graph_->parameters()) {
+    auto param_adjoint = anfnode_to_adjoin_.find(param);
+    inputs.push_back(param_adjoint->second->dout());
+    param_adjoints.push_back(param_adjoint->second);
+  }
+  auto tape_output = tape_->NewCNode(inputs);
+  for (size_t i = 0; i < param_adjoints.size(); ++i) {
+    param_adjoints[i]->RegisterDoutUser(tape_output, i + 2);
+  }
+  tape_->set_output(tape_output);
+  // Set output for k_graph_, K:: cnode->forward_app.
+  auto forward_app = output_adjoint->second->k();
+  auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
+  output_adjoint->second->RegisterKUser(output, 1);
+  k_graph_->set_output(output);
+  (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
+  (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
+}
+
+FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
+  // K user defined cell bprop.
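+  // A Cell with a user defined bprop registers it under the "bprop" transform;
+  // when it exists, expand it directly instead of differentiating the primal.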
+  auto bprop = primal->transforms().find("bprop");
+  if (bprop != primal->transforms().end()) {
+    FuncGraphPtr bprop_graph = bprop->second.func_graph();
+    resources_->manager()->AddFuncGraph(bprop_graph);
+
+    if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) {
+      MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope "
+                        << primal->output()->scope()->name() << " does not support Parameter data type.";
+    }
+    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph);
+    if (fg == nullptr) {
+      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
+                        << primal->output()->scope()->name() << ".";
+    }
+
+    // Cache the grad func
+    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
+    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
+    // Reset defer_inline to enable successive inlining
+    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
+
+    auto functor = std::make_shared<DFunctor>(primal, resources_);
+    functor->Init();
+    functor->k_graph_ = fg;
+
+    return fg;
+  }
+  return nullptr;
+}
+
+// MapToK(func)
+AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
+  auto f = func_graph_to_functor_.find(primal);
+  if (f != func_graph_to_functor_.end()) {
+    MS_LOG(DEBUG) << "K graph functor already exists " << primal->ToString() << ".";
+    return NewValueNode(f->second->k_graph_);
+  }
+
+  auto k_user_defined = KUserDefined(primal);
+  if (k_user_defined != nullptr) {
+    MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << ".";
+    return NewValueNode(k_user_defined);
+  }
+
+  auto functor = std::make_shared<DFunctor>(primal, resources_);
+  functor->Init();
+  functor->MapObject();
+  functor->MapMorphism();
+
+  MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << ".";
+  return NewValueNode(functor->k_graph_);
+}
+
+// Construct representation graph for given node.
+AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
+  ScopeGuard scope_guard(primal->scope());
+  // MapToK(prim)
+  if (IsValueNode<Primitive>(primal)) {
+    auto value_node = primal->cast<ValueNodePtr>();
+    auto prim = GetValueNode<PrimitivePtr>(value_node);
+    if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) {
+      MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << ".";
+      need_cut_ = true;
+    }
+    auto k_prim = g_k_prims.KPrimitive(value_node, resources_);
+    if (k_prim != nullptr) {
+      return NewValueNode(k_prim);
+    }
+    // When failed to find k_prim, try k_meta.
+    auto k_meta = g_k_prims.KMetaFuncGraph(prim);
+    if (k_meta != nullptr) {
+      return NewValueNode(k_meta);
+    }
+  }
+
+  // MapToK(func)
+  if (IsValueNode<FuncGraph>(primal)) {
+    auto func_graph = GetValueNode<FuncGraphPtr>(primal);
+    auto k_func = MapToK(func_graph);
+    return k_func;
+  }
+
+  if (primal->isa<Parameter>()) {
+    TraceManager::DebugTrace(std::make_shared(primal->debug_info()));
+    auto ret = k_graph_->add_parameter();
+    TraceManager::EndTrace();
+    return ret;
+  }
+
+  if (!primal->isa<ValueNode>()) {
+    MS_LOG(EXCEPTION) << "K node kept node from primal_graph_ " << primal->ToString() << " that is not a ValueNode.";
+  }
+  return primal;
+}
+
+bool DFunctor::IsInScope(const AnfNodePtr &node) {
+  return std::any_of(scope_.begin(), scope_.end(),
+                     [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
+}
+
+void DFunctor::MapFvObject() {
+  // Map free variable.
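+  // A fv's K may come from a parent functor; if it is not defined yet, record
+  // a hole (k == nullptr) for UpdateAdjoint to patch once a definition appears.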
+  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
+  for (auto &node : free_variables_nodes) {
+    ScopeGuard scope_guard(node->scope());
+    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
+    // Find fv's K from parent.
+    AdjointPtr adjoint = nullptr;
+    auto parent_adjoint = FindAdjoint(node);
+    if (parent_adjoint != nullptr) {
+      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
+    } else {
+      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
+        // Out of ad scope, add adjoint for free variables.
+        adjoint = std::make_shared<Adjoint>(node, node, tape_);
+        UpdateAdjoint(adjoint);
+      } else {
+        MS_LOG(DEBUG) << "MapFvObject failed to find parent adjoint for non-top fv " << node->ToString() << ".";
+        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
+      }
+    }
+    if (adjoint == nullptr) {
+      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
+    }
+    anfnode_to_adjoin_[node] = adjoint;
+  }
+}
+
+void DFunctor::MapParamObject() {
+  // Map parameter.
+  for (auto &p : primal_graph_->parameters()) {
+    ScopeGuard scope_guard(p->scope());
+    MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
+    auto adjoint = std::make_shared<Adjoint>(p, MapToK(p), tape_);
+    UpdateAdjoint(adjoint);
+    anfnode_to_adjoin_[p] = adjoint;
+  }
+}
+
+void DFunctor::MapValueObject() {
+  // Map ValueNode.
+  auto manager = resources_->manager();
+  auto &value_nodes = primal_graph_->value_nodes();
+  for (const auto &value_pair : value_nodes) {
+    auto node = value_pair.first;
+    auto parent_adjoint = FindAdjoint(node);
+    if (parent_adjoint != nullptr) {
+      auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
+      anfnode_to_adjoin_[node] = adjoint;
+      continue;
+    }
+    // Skip Return.
+    if (IsValueNode<Primitive>(node) && GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
+      continue;
+    }
+    MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << ".";
+    auto adjoint = std::make_shared<Adjoint>(node, MapToK(node), tape_);
+    UpdateAdjoint(adjoint);
+    anfnode_to_adjoin_[node] = adjoint;
+  }
+}
+
+// Skip morphism.
+void DFunctor::MapObject() {
+  // The order does not matter
+  MapFvObject();
+  MapParamObject();
+  MapValueObject();
+}
+
+void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
+  auto primal = adjoint_definition->primal();
+  if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
+    MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
+                      << primal->ToString() << ".";
+  }
+  anfnode_to_adjoin_definition_[primal] = adjoint_definition;
+  // Update k hole for primal.
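+  // Recursion may have registered this primal as a hole in other functors
+  // before its definition existed; fill every such hole now.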
+  for (auto &f : func_graph_to_functor_) {
+    auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
+    if (adjoint != f.second->anfnode_to_adjoin_.end()) {
+      adjoint->second->UpdateK(adjoint_definition->k());
+    }
+    adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
+    if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
+      adjoint->second->UpdateK(adjoint_definition->k());
+    }
+  }
+}
+
+AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
+  auto adjoint = anfnode_to_adjoin_definition_.find(primal);
+  if (adjoint != anfnode_to_adjoin_definition_.end()) {
+    MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
+    return adjoint->second;
+  }
+  MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
+  return nullptr;
+}
+
+void DFunctor::CallDoutHoleOnTape() {
+  if (!is_top_) {
+    return;
+  }
+
+  // Call dout hole of all adjoint.
+  for (auto &f : func_graph_to_functor_) {
+    for (auto &adjoint : f.second->anfnode_to_adjoin_) {
+      adjoint.second->CallDoutHole();
+    }
+    for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
+      adjoint.second->CallDoutHole();
+    }
+  }
+}
+
+FuncGraphPtr DFunctor::k_graph() {
+  CallDoutHoleOnTape();
+  return k_graph_;
+}
+
+void DFunctor::BroadCastStopFlag() {
+  // As the stop set expands, all directly or indirectly stopped CNodes will be cut off.
+  while (need_cut_) {
+    need_cut_ = false;
+    for (auto &node : primal_graph_->nodes()) {
+      if (node->isa<CNode>()) {
+        auto cnode = node->cast<CNodePtr>();
+        if (!cnode->stop_gradient()) {
+          // Only cut off the cnode when it is no longer referred to.
+          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) {
+            MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
+            cnode->set_stop_gradient(true);
+            // The stop set changed; another cutting pass is required.
+            need_cut_ = true;
+          }
+        }
+      }
+    }
+  }
+}
+
+bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
+  auto &users = primal_graph_->manager()->node_users()[node];
+  // Only care about cutting caused by stop_gradient.
+  if (users.empty()) {
+    return false;
+  }
+  for (auto &kv : users) {
+    auto &user = kv.first;
+    if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace ad
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
new file mode 100644
index 0000000000..9ee93334e8
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h
@@ -0,0 +1,210 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "ir/func_graph_cloner.h" +#include "pipeline/jit/resource.h" +#include "frontend/optimizer/ad/adjoint.h" +#include "frontend/operator/ops.h" +#include "debug/trace.h" + +namespace mindspore { +namespace ad { +struct PrimitiveTotalEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + MS_EXCEPTION_IF_NULL(t1); + MS_EXCEPTION_IF_NULL(t2); + return *t1 == *t2; + } +}; + +using Registry = std::unordered_map; +class KPrim; +extern KPrim g_k_prims; +class DFunctor; +using DFunctorPtr = std::shared_ptr; + +// D Functor's rules to map closure object and morphisms. +class DFunctor : public std::enable_shared_from_this { + public: + DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); + ~DFunctor() = default; + // Map object in D category to K category. + void MapObject(); + // Map morphism in D category to K category. + void MapMorphism(); + FuncGraphPtr k_graph(); + // Construct user defined k object. + FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); + // Register functor objects to form a global view. + void Init(bool is_top = false); + bool IsInScope(const AnfNodePtr &node); + + // Clear resources. + static void Clear(); + + private: + // Map one morphism. + AdjointPtr MapMorphism(const AnfNodePtr &morph); + bool IsFreeMorphism(const AnfNodePtr &node); + // Map morphism that's not attached to output. + void MapFreeMorphism(); + void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); + void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env); + void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); + AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); + AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); + // Map Anfnode object from D category to K category. + AnfNodePtr MapToK(const AnfNodePtr &primal); + // Map FuncGraph object from D category to K category. + AnfNodePtr MapToK(const FuncGraphPtr &primal); + // MapObject impls. + void MapFvObject(); + void MapValueObject(); + void MapParamObject(); + // Find adjoint with its primary k. + AdjointPtr FindAdjoint(const AnfNodePtr &primal); + // Broadcast stop flags. + void BroadCastStopFlag(); + bool AllReferencesStopped(const CNodePtr &node); + // Update k hole with adjoint_definition, only applied in recursive case. + void UpdateAdjoint(const AdjointPtr &adjoint_definition); + void CallDoutHoleOnTape(); + + std::unordered_map anfnode_to_adjoin_; + // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. + std::unordered_map anfnode_to_adjoin_indirect_fv_; + FuncGraphPtr primal_graph_; + // K object for primal_graph_; + FuncGraphPtr k_graph_; + // The Backprop part of k_graph_. + FuncGraphPtr tape_; + // Dout parameter for primal_graph_. + AnfNodePtr dout_; + pipeline::ResourceBasePtr resources_; + // Cut off stopped objects in category D. + bool need_cut_; + bool is_top_; + static std::unordered_map> func_graph_to_functor_; + static std::unordered_map anfnode_to_adjoin_definition_; + static FuncGraphSet scope_; +}; + +// D Functor's rules to map primitive object. 
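+// Roughly (eliding shapes and types), for a primitive P with bprop rule bprop_P,
+// KPrim builds:
+//   K(P) = lambda (x, y, ...):
+//            out = P(x, y, ...)
+//            return (out, lambda dout: bprop_P(x, y, ..., out, dout))
+// i.e. the (forward result, backprop closure) pair that DFunctor consumes.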
+class KPrim {
+ public:
+  KPrim() = default;
+  ~KPrim() = default;
+
+  FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
+  FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop);
+
+  void clear() {
+    bprop_registry_meta_.clear();
+    bprop_registry_.clear();
+  }
+
+ private:
+  FuncGraphPtr GetBprop(const PrimitivePtr &prim);
+  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
+  FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
+  // Given a bprop rule, do the K mapping.
+  template <typename T>
+  FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
+  AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
+  void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
+                     std::vector<AnfNodePtr> *const transf_args);
+  void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check);
+
+  Registry bprop_registry_;
+  std::unordered_map bprop_registry_meta_;
+};
+
+template <typename T>
+FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
+  MS_EXCEPTION_IF_NULL(primal);
+  MS_EXCEPTION_IF_NULL(bprop_fg);
+  CheckBprop(bprop_fg, primal->ToString());
+
+  auto debug_info = std::make_shared<GraphDebugInfo>();
+  debug_info->set_name(primal->ToString());
+
+  auto cloned_bprop_fg = BasicClone(bprop_fg);
+  MS_EXCEPTION_IF_NULL(cloned_bprop_fg);
+
+  cloned_bprop_fg->debug_info()->set_name("");
+  cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info));
+
+  AnfNodePtr bout = BuildOutput(cloned_bprop_fg);
+  cloned_bprop_fg->set_output(bout);
+
+  TraceManager::DebugTrace(std::make_shared(debug_info));
+  auto outer = std::make_shared<FuncGraph>();
+  (void)outer->transforms().emplace("primal", FuncGraphTransform(primal));
+  outer->set_output(NewValueNode(kNone));
+  TraceManager::EndTrace();
+
+  auto mng = Manage({cloned_bprop_fg, outer}, false);
+
+  // Make sure (out, dout) are provided.
+  if (cloned_bprop_fg->parameters().size() < 2) {
+    MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString()
+                      << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size()
+                      << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info());
+  }
+
+  // In a bprop definition, the last two parameters should be out and dout.
+  auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1];
+  auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2];
+  std::vector<AnfNodePtr> transf_args;
+  TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
+
+  TraceManager::DebugTrace(std::make_shared(dout->debug_info()));
+  (void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
+  auto out_value = outer->NewCNode(transf_args);
+  TraceManager::EndTrace();
+
+  (void)mng->Replace(out_param, out_value);
+
+  TraceManager::DebugTrace(std::make_shared(out_param->debug_info()));
+  auto new_dout = cloned_bprop_fg->add_parameter();
+  (void)mng->Replace(dout, new_dout);
+  // We remove all parameters except new_dout.
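+  // All original (in..., out, dout) parameters were rerouted above, either to
+  // outer's parameters or to out_value, so only new_dout is still meaningful.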
+ std::vector newBpropParams = {new_dout}; + cloned_bprop_fg->set_parameters(newBpropParams); + TraceManager::EndTrace(); + + outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); + return BasicClone(outer); +} +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc new file mode 100644 index 0000000000..ef2d7d400a --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/ad/grad.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "ir/func_graph_cloner.h" +#include "utils/context/ms_context.h" +#include "utils/symbolic.h" +#include "utils/graph_utils.h" + +namespace mindspore { +namespace ad { +FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { + MS_EXCEPTION_IF_NULL(func_graph); + auto gradkv = func_graph->transforms().find("grad"); + if (gradkv != func_graph->transforms().end()) { + return gradkv->second.func_graph(); + } + + auto manager_ptr = resources->manager(); + MS_EXCEPTION_IF_NULL(manager_ptr); + manager_ptr->AddFuncGraph(func_graph); + + auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { + if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + } + } + }; + + auto f = std::make_shared(func_graph, resources); + auto user_defined = f->KUserDefined(func_graph); + if (user_defined != nullptr) { + multi_graph_sink(user_defined); + if (is_top) { + DFunctor::Clear(); + } + return user_defined; + } + f->Init(is_top); + f->MapObject(); + f->MapMorphism(); + auto ret = f->k_graph(); + if (is_top) { + DFunctor::Clear(); + } + + multi_graph_sink(ret); + return ret; +} + +FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto fg = g_k_prims.KPrimitive(value_node, resources); + if (fg == nullptr) { + return nullptr; + } + return BasicClone(fg); +} + +MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) { + MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); + return fg; +} + +void CleanRes() { DFunctor::Clear(); } +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.h b/mindspore/ccsrc/frontend/optimizer/ad/grad.h new file mode 100644 index 0000000000..ee9ab79ffb --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ + +#include +#include + +#include "ir/anf.h" +#include "ir/meta_func_graph.h" +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace ad { +using ResourcePtr = std::shared_ptr; + +FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true); +FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); +MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); +void CleanRes(); +} // namespace ad +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc new file mode 100644 index 0000000000..5ca2ca6c43 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -0,0 +1,291 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/primitive_py.h" +#include "ir/meta_func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/manager.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "frontend/optimizer/opt.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "utils/symbolic.h" +#include "utils/primitive_utils.h" +#include "utils/context/ms_context.h" +#include "debug/info.h" +#include "debug/trace.h" + +#include "./common.h" + +namespace mindspore { +namespace ad { +using PatternListType = std::initializer_list; +KPrim g_k_prims; + +FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { + // Set a child scope named "grad'PrimitiveName'" for the bprop function, + // and add "Gradients" to the front. + static const std::string gradients_scope = "Gradients/"; + static const std::string grad_op_child_scope_prefix = "/grad"; + MS_EXCEPTION_IF_NULL(prim); + auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + + grad_op_child_scope_prefix + prim->name()); + ScopeGuard scope_guard(scope); + py::function fn = prim->is_base() ? 
GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
+  if (fn == nullptr || py::isinstance<py::none>(fn)) {
+    MS_LOG(DEBUG) << "Failed to find bprop function for " << prim->name() << ".";
+    return nullptr;
+  }
+  FuncGraphPtr func_graph = parse::ParsePythonCode(fn);
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Failed to parse bprop function for " << prim->name() << ".";
+    return nullptr;
+  }
+  return func_graph;
+}
+
+FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
+  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
+  std::string func_name = "_fprop_" + prim->name();
+  py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
+  auto func_graph = parse::ParsePythonCode(fn);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return BasicClone(func_graph);
+}
+
+MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
+  MS_EXCEPTION_IF_NULL(prim);
+
+  auto iter = bprop_registry_meta_.find(prim);
+  if (iter != bprop_registry_meta_.end()) {
+    return iter->second;
+  }
+
+  if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
+    MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient");
+    bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
+    return meta;
+  }
+
+  MS_LOG(EXCEPTION) << "Failed to find bprop function for " << prim->name() << ".";
+}
+
+FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
+  if (!IsValueNode<Primitive>(value_node)) {
+    MS_LOG(EXCEPTION) << "Primitive node is not valid.";
+  }
+
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
+    auto fprop = GetFprop(prim);
+    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
+    return fprop;
+  } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
+    return nullptr;
+  }
+
+  FuncGraphPtr bprop_fg = nullptr;
+  if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
+    bprop_fg = BpropCut(value_node, resources);
+  } else {
+    auto iter = bprop_registry_.find(prim);
+    if (iter != bprop_registry_.end()) {
+      bprop_fg = iter->second;
+    }
+
+    if (bprop_fg == nullptr) {
+      bprop_fg = GetBprop(prim);
+      if (bprop_fg != nullptr) {
+        // Cache the bprop graph.
+        bprop_registry_[prim] = bprop_fg;
+      } else {
+        bprop_fg = FakeBprop(value_node, resources);
+      }
+    }
+  }
+
+  auto expanded_fg = BpropToK(prim, bprop_fg);
+  if (expanded_fg == nullptr) {
+    MS_LOG(EXCEPTION) << "Failed to convert " << prim->name()
+                      << " prim bprop function to J expanded func graph. NodeInfo: "
+                      << trace::GetDebugInfo(bprop_fg->debug_info());
+  }
+
+  return expanded_fg;
+}
+
+AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) {
+  // bprop_fg has been checked in caller
+  if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
+    // Set bprop output as (env, dx, dy, dz, ...)
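+    // e.g. a bprop returning (dx, dy) becomes (newenv, dx, dy); slot 0 always
+    // carries the free-variable environment expected by the caller.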
+ auto cbprop = bprop_fg->output()->cast(); + auto &inputs = cbprop->inputs(); + + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + args.push_back(NewValueNode(newenv)); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + return NewCNode(args, bprop_fg); + } + + // Set bprop output as (env, dx) + std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); + std::string python_ops("_tuple_add"); + auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); + return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg); +} + +void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, + std::vector *const transf_args) { + MS_EXCEPTION_IF_NULL(mng); + // bprop_fg has been checked in caller + // transform except the last 2 parameters: out, dout. + for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) { + auto p = bprop_fg->parameters()[i]; + MS_EXCEPTION_IF_NULL(p); + + TraceManager::DebugTrace(std::make_shared(p->debug_info())); + auto transf_p = outer->add_parameter(); + TraceManager::EndTrace(); + + (void)mng->Replace(p, transf_p); + transf_args->push_back(transf_p); + } +} + +void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool check_bprop_flag = context->check_bprop_flag(); + // Skip checking if check_bprop not set + if (!check_bprop_flag) { + return; + } + + // bprop_fg has been checked in caller + auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); + MS_EXCEPTION_IF_NULL(check_bprop_class); + auto check_bprop = + bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); + + std::vector inputs; + inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); + AnfNodePtr params = bprop_fg->NewCNode(inputs); + + inputs.clear(); + inputs.push_back(check_bprop); + inputs.push_back(bprop_fg->output()); + inputs.push_back(params); + AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); + bprop_fg->set_output(bprop_out); +} + +FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { + MS_EXCEPTION_IF_NULL(bprop_fg); + auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); + auto expanded_fg = BpropToK(fprop_fg, bprop_fg); + if (expanded_fg == nullptr) { + MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() + << " Cell bprop function to K expanded func graph. 
NodeInfo: " + << trace::GetDebugInfo(fprop_fg->debug_info()); + } + return expanded_fg; +} + +FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto prim = GetValueNode(value_node); + MS_EXCEPTION_IF_NULL(prim); + auto &node_users = resources->manager()->node_users(); + + auto &users = node_users[value_node]; + auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { + return IsPrimitiveCNode(user.first, prim); + }); + if (cnode == users.end()) { + MS_LOG(EXCEPTION) << "Fail to find cnode."; + } + auto inputs_num = cnode->first->cast()->size() - 1; + + auto func_graph = std::make_shared(); + std::vector outputs; + + auto bprop_cut = std::make_shared("bprop_cut", py::object()); + bprop_cut->CopyHookFunction(prim); + + auto cell_id = GetValue(prim->GetAttr("cell_id")); + if (cell_id != "") { + (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); + (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); + } + + outputs.push_back(NewValueNode(bprop_cut)); + for (size_t i = 0; i < inputs_num; ++i) { + auto param = func_graph->add_parameter(); + outputs.push_back(param); + } + auto p1 = func_graph->add_parameter(); + auto p2 = func_graph->add_parameter(); + outputs.push_back(p1); + outputs.push_back(p2); + + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} + +FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { + auto prim = value_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto &node_users = resources->manager()->node_users(); + + auto &users = node_users[value_node]; + auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { + return IsPrimitiveCNode(user.first, prim); + }); + if (cnode == users.end()) { + MS_LOG(EXCEPTION) << "Fail to find cnode."; + } + auto inputs_num = cnode->first->cast()->inputs().size() - 1; + + auto func_graph = std::make_shared(); + std::vector outputs; + outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto fake_bprop = std::make_shared("fake_bprop"); + (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); + + for (size_t i = 0; i < inputs_num; ++i) { + // Mock params for inputs + auto param = func_graph->add_parameter(); + // Mock derivatives for each inputs + outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); + } + // mock params for out and dout + (void)func_graph->add_parameter(); + (void)func_graph->add_parameter(); + func_graph->set_output(func_graph->NewCNode(outputs)); + return func_graph; +} +} // namespace ad +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc new file mode 100644 index 0000000000..45a271f692 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/clean.cc @@ -0,0 +1,531 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/clean.h" +#include +#include +#include +#include +#include +#include "./common.h" +#include "debug/trace.h" +#include "frontend/operator/composite/composite.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractJTagged; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractUndetermined; + +static AbstractBasePtr Reabs(const AbstractBasePtr &t) { + if (t == nullptr) { + return nullptr; + } + + AbstractBasePtr res = t; + if (t->isa()) { + auto abs_class = dyn_cast(t); + AbstractBasePtrList baselist; + auto attributes = abs_class->attributes(); + (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), + [](const AbstractAttribute &item) { return item.second; }); + res = std::make_shared(baselist); + } else if (t->isa()) { + auto abs_dict = dyn_cast(t); + AbstractBasePtrList baselist; + auto elements = abs_dict->elements(); + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), + [](const AbstractAttribute &item) { return item.second; }); + res = std::make_shared(baselist); + } else if (t->isa()) { + auto abs_dict = dyn_cast(t); + res = std::make_shared(abs_dict->elements()); + } + return res; +} + +AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(node->func_graph()); + + const auto &inputs = node->inputs(); + // Inputs should be [getattr, data, attribute] + MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); + + AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + MS_EXCEPTION_IF_NULL(data); + MS_EXCEPTION_IF_NULL(cons); + + auto dt = data->abstract(); + if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) { + return nullptr; + } + + if (!dt->isa()) { + MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; + } + + auto cons_is_str = IsValueNode(cons); + auto cons_str = cons_is_str ? 
GetValue<std::string>(GetValueNode(cons)) : "";
+
+  auto ct = dyn_cast<AbstractClass>(dt);
+  const auto &cmap = ct->attributes();
+  int count = 0;
+  for (auto &item : cmap) {
+    if (cons_is_str && item.first == cons_str) {
+      break;
+    }
+    count++;
+  }
+
+  auto idx_c = NewValueNode(count);
+  AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
+  idx_c->set_abstract(aptr);
+
+  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
+}
+
+AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  // Inputs should be [dict_getitem, dict, item]
+  const auto &inputs = node->inputs();
+  MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs.");
+
+  AnfNodePtr data = inputs[1];
+  AnfNodePtr cons = inputs[2];
+  MS_EXCEPTION_IF_NULL(data);
+  MS_EXCEPTION_IF_NULL(cons);
+
+  auto dt = data->abstract();
+  MS_EXCEPTION_IF_NULL(dt);
+  if (!dt->isa<AbstractDictionary>()) {
+    MS_LOG(EXCEPTION) << "First parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name();
+  }
+  auto cons_is_str = IsValueNode<StringImm>(cons);
+  auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
+
+  auto ct = dyn_cast<AbstractDictionary>(dt);
+  const auto &cmap = ct->elements();
+  int count = 0;
+  for (auto &item : cmap) {
+    if (cons_is_str && item.first == cons_str) {
+      break;
+    }
+    count++;
+  }
+
+  auto idx_c = NewValueNode(count);
+  AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
+  idx_c->set_abstract(aptr);
+  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c});
+}
+
+AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  // Inputs should be [dict_setitem, dict, item, value]
+  const auto &inputs = node->inputs();
+  MS_ASSERT(inputs.size() == 4 && "DictSetItem should have four inputs.");
+
+  AnfNodePtr data = inputs[1];
+  AnfNodePtr cons = inputs[2];
+  AnfNodePtr item_value = inputs[3];
+  MS_EXCEPTION_IF_NULL(data);
+  MS_EXCEPTION_IF_NULL(cons);
+
+  auto dt = data->abstract();
+  MS_EXCEPTION_IF_NULL(dt);
+  if (!dt->isa<AbstractDictionary>()) {
+    MS_LOG(EXCEPTION) << "First parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name();
+  }
+  auto cons_is_str = IsValueNode<StringImm>(cons);
+  auto cons_str = cons_is_str ? GetValue<std::string>(GetValueNode(cons)) : "";
+
+  auto ct = dyn_cast<AbstractDictionary>(dt);
+  const auto &cmap = ct->elements();
+  int count = 0;
+  for (auto &item : cmap) {
+    if (cons_is_str && item.first == cons_str) {
+      break;
+    }
+    count++;
+  }
+  if (IntToSize(count) >= cmap.size()) {
+    // For dictionary set, if the key does not exist, we should create a new item.
+    auto tuple_add_op = std::make_shared("tuple_add");
+    auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
+    return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
+  }
+  auto idx_c = NewValueNode(count);
+  AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
+  idx_c->set_abstract(aptr);
+  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value});
+}
+
+AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr;
+  (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end());
+  return node->func_graph()->NewCNode(inputs);
+}
+
+AnfNodePtr ErasePartialNode(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  const auto &inputs = node->inputs();
+  // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg;
+  MS_ASSERT(inputs.size() >= 2 && "Partial should have at least two inputs.");
+
+  std::vector<AnfNodePtr> args(inputs.begin() + 2, inputs.end());
+  auto oper = inputs[1];
+  if (IsPrimitive(oper, prim::kPrimMakeRecord)) {
+    if (args.size() == 1) {
+      return NewValueNode(prim::kPrimMakeTuple);
+    }
+
+    if (args.size() > 1) {
+      std::vector<AnfNodePtr> new_inputs;
+      new_inputs.emplace_back(NewValueNode(prim::kPrimPartial));
+      new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+      (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end());
+
+      MS_EXCEPTION_IF_NULL(node->func_graph());
+      return node->func_graph()->NewCNode(new_inputs);
+    }
+  }
+  return nullptr;
+}
+
+AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  std::vector<AnfNodePtr> inputs;
+  inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
+  (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end());
+  return node->func_graph()->NewCNode(inputs);
+}
+
+AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  const auto &inputs = node->inputs();
+  // Inputs should be [list_getitem, list, item]
+  if (inputs.size() < 3) {
+    MS_LOG(EXCEPTION) << "Node's input number < 3.";
+  }
+
+  AnfNodePtr data = inputs[1];
+  AnfNodePtr cons = inputs[2];
+  MS_EXCEPTION_IF_NULL(data);
+  MS_EXCEPTION_IF_NULL(cons);
+
+  auto cons_node = cons->cast<ValueNodePtr>();
+  return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node});
+}
+
+AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+
+  const auto &inputs = node->inputs();
+  // Inputs should be [list_setitem, list, index, item]
+  if (inputs.size() < 4) {
+    MS_LOG(EXCEPTION) << "Node's input number < 4.";
+  }
+
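+  // Once lists are lowered to tuples, list_setitem maps 1:1 onto
+  // tuple_setitem, with the index operand passed through unchanged.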
+ AnfNodePtr data = inputs[1]; + AnfNodePtr cons = inputs[2]; + AnfNodePtr value = inputs[3]; + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); +} + +AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); + return inputs[2]; +} + +AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + // Inputs should be [make_keyword_arg, key, value] + MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); + return inputs[2]; +} + +AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + const auto &inputs = node->inputs(); + // Inputs should be [extract_keyword_arg, arg, key] + MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); + return inputs[2]; +} + +ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { + const int DEPTH_MAX = 5; + if (depth > DEPTH_MAX) { + MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; + } + std::vector elements; + for (const auto &it : value_list->value()) { + ValuePtr value = nullptr; + if (it->isa()) { + value = ConvertValueListToValueTuple(it->cast(), depth + 1); + } else { + value = it; + } + elements.push_back(value); + } + return std::make_shared(elements); +} + +AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + ValuePtr value = node->value(); + auto value_list = value->cast(); + MS_EXCEPTION_IF_NULL(value_list); + int depth = 0; + return std::make_shared(ConvertValueListToValueTuple(value_list, depth)); +} + +// Convert class to Tuple +// Convert getattr to getitem +// Convert make_record to make_tuple +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + bool changed = false; + + // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var + AnfNodeSet all_node = manager->all_nodes(); + for (auto &node : all_node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + AnfNodePtr new_node = nullptr; + if (IsValueNode(node)) { + new_node = NewValueNode(prim::kPrimMakeTuple); + } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) { + new_node = ConvertGetAttrToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) { + new_node = ConvertMakeRecordToMakeTuple(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) { + new_node = ErasePartialNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { + new_node = ConvertDictGetItemToTupleGetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { + new_node = ConvertDictSetItemToTupleSetItem(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { + new_node = EraseMakeDictNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { + new_node = EraseMakeKeywordArgNode(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { + new_node = EraseExtractKeywordArg(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { + new_node = ConvertMakeListToMakeTuple(cnode); + } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { + new_node = ConvertListGetItemToTupleGetItem(cnode); + } else if 
(IsPrimitiveCNode(node, prim::kPrimListSetItem)) {
+      new_node = ConvertListSetItemToTupleSetItem(cnode);
+    } else if (IsValueNode<ValueList>(node)) {
+      new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
+    }
+
+    if (new_node != nullptr) {
+      new_node->set_abstract(node->abstract());
+      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
+      (void)manager->Replace(node, new_node);
+      changed = true;
+    }
+  }
+
+  for (auto &node : manager->all_nodes()) {
+    auto ret = Reabs(node->abstract());
+    node->set_abstract(ret);
+  }
+  return changed;
+}
+
+// expand tuples in graph parameters
+static std::vector<AnfNodePtr> ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph,
+                                             const std::vector<AnfNodePtr> &params) {
+  MS_EXCEPTION_IF_NULL(mng);
+  MS_EXCEPTION_IF_NULL(func_graph);
+
+  std::vector<AnfNodePtr> new_params;
+  for (const auto &param : params) {
+    MS_EXCEPTION_IF_NULL(param);
+    auto param_abs = param->abstract();
+    MS_EXCEPTION_IF_NULL(param_abs);
+
+    if (param_abs->isa<AbstractJTagged>()) {
+      MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info());
+    }
+
+    if (!param_abs->isa<AbstractTuple>()) {
+      new_params.emplace_back(param);
+      continue;
+    }
+
+    std::vector<AnfNodePtr> new_param;
+    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
+    auto abs_tuple = dyn_cast<AbstractTuple>(param_abs);
+    for (auto &elem : abs_tuple->elements()) {
+      auto np = std::make_shared<Parameter>(func_graph);
+      np->set_abstract(elem);
+      new_param.emplace_back(np);
+    }
+    (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end());
+    auto new_tuple = func_graph->NewCNode(inputs);
+    (void)mng->Replace(param, new_tuple);
+
+    auto expand_param = ExpandTuplesP(mng, func_graph, new_param);
+    (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end());
+  }
+  return new_params;
+}
+
+// expand tuples in graph applies
+static std::vector<AnfNodePtr> ExpandTuplesC(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs) {
+  MS_EXCEPTION_IF_NULL(graph);
+
+  std::vector<AnfNodePtr> new_inputs;
+  for (const auto &input : inputs) {
+    MS_EXCEPTION_IF_NULL(input);
+
+    auto input_abs = input->abstract();
+    MS_EXCEPTION_IF_NULL(input_abs);
+
+    if (input_abs->isa<AbstractJTagged>()) {
+      auto abstract_tag = dyn_cast<AbstractJTagged>(input_abs);
+      if (abstract_tag->element()->isa<AbstractTuple>()) {
+        MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << trace::GetDebugInfo(input->debug_info());
+      }
+    }
+
+    if (!input_abs->isa<AbstractTuple>()) {
+      new_inputs.emplace_back(input);
+      continue;
+    }
+
+    int idx = 0;
+    std::vector<AnfNodePtr> new_input;
+    auto abs_tuple = dyn_cast<AbstractTuple>(input_abs);
+    for (auto &elem : abs_tuple->elements()) {
+      auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)});
+      AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(idx));
+      c_node->input(2)->set_abstract(aptr);
+      c_node->set_abstract(elem);
+      new_input.emplace_back(c_node);
+      idx++;
+    }
+
+    auto expand_tuple = ExpandTuplesC(graph, new_input);
+    (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end());
+  }
+
+  return new_inputs;
+}
+
+// remove most uses of tuples from the graph parameters & apply inputs
+// tuples that are returned will be kept
+// tuples in CNode's inputs: AbstractTuple (a, b, c) -->
+// CNode("tuple_getitem", (a,b,c), 0)
+// CNode("tuple_getitem", (a,b,c), 1)
+// CNode("tuple_getitem", (a,b,c), 2)
+// tuples in Graph's parameters: AbstractTuple (a, b, c) -->
+// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c))
+// cppcheck-suppress unusedFunction
+void EraseTuple(const FuncGraphPtr &root, const
FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var + AnfNodeSet all_node = manager->all_nodes(); + for (auto &node : all_node) { + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + + const auto &inputs = cnode->inputs(); + + // Bypass the first input in inputs as it's fn. + if (!IsValueNode(inputs[0])) { + std::vector expand_inputs; + (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end()); + + auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); + if (new_inputs != expand_inputs) { + std::vector cnode_inputs{inputs[0]}; + (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); + + MS_EXCEPTION_IF_NULL(node->func_graph()); + auto new_node = node->func_graph()->NewCNode(cnode_inputs); + new_node->set_abstract(node->abstract()); + + (void)manager->Replace(node, new_node); + } + // Bypass the first 2 inputs in inputs as it's [partial, fn]. + } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode(inputs[1])) { + std::vector expand_inputs; + (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end()); + + auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); + if (new_inputs != expand_inputs) { + std::vector cnode_inputs{inputs[0], inputs[1]}; + (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); + + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + auto new_node = cnode->func_graph()->NewCNode(cnode_inputs); + new_node->set_abstract(cnode->abstract()); + + (void)manager->Replace(node, new_node); + } + } + } + + FuncGraphSet all_graph = manager->func_graphs(); + for (auto &func_graph : all_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); + manager->SetParameters(func_graph, expand_p); + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/clean.h b/mindspore/ccsrc/frontend/optimizer/clean.h new file mode 100644 index 0000000000..54faabaa63 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/clean.h @@ -0,0 +1,43 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ + +#include +#include "ir/anf.h" +#include "frontend/operator/ops.h" +#include "utils/any.h" +#include "ir/manager.h" +#include "abstract/dshape.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +// Remove the class type from graphs +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); + +// Remove most uses of tuples from the graph +// tuples that are returned will be kept +void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); + +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/control_depend.cc b/mindspore/ccsrc/frontend/optimizer/control_depend.cc new file mode 100644 index 0000000000..8cc9bdb7f4 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/control_depend.cc @@ -0,0 +1,122 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/control_depend.h" + +#include +#include +#include +#include +#include + +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, + const std::vector &effect_index, const std::vector &cnodes) { + std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; + std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; + size_t effect_size = effect_index.size(); + for (size_t i = 0; i < effect_size; i++) { + size_t pre_index = 0; + if (i > 0) { + pre_index = effect_index[i - 1] + 1; + } + size_t this_index = effect_index[i]; + size_t last_index = cnodes.size() - 2; + if (i < effect_size - 1) { + last_index = effect_index[i + 1]; + } + + if (this_index > pre_index) { + std::vector pre_segment; + for (size_t k = pre_index; k < this_index; k++) { + // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. + if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || + IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { + continue; + } + pre_segment.push_back(cnodes[k]); + } + auto roots = FindRoots(pre_segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + AnfNodePtr control_depend = + graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); + make_tuple.push_back(control_depend); + } + } + if (last_index > this_index) { + std::vector last_segment; + for (size_t k = this_index + 1; k <= last_index; k++) { + // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. 
+        if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) ||
+            IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) {
+          continue;
+        }
+        last_segment.push_back(cnodes[k]);
+      }
+      auto leaves = FindLeaves(last_segment);
+      for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) {
+        AnfNodePtr control_depend =
+          graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter});
+        make_tuple.push_back(control_depend);
+      }
+    }
+  }
+  depend_nodes.push_back(graph->NewCNode(make_tuple));
+  return depend_nodes;
+}
+
+void AddControlDepend(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+  std::vector<CNodePtr> cnodes(orders.begin(), orders.end());
+  size_t cnodes_size = cnodes.size();
+  // Get the effect indices of the cnodes.
+  std::vector<size_t> effect_index{};
+  for (size_t i = 0; i < cnodes_size; i++) {
+    if (graph->HasEffect(cnodes[i])) {
+      effect_index.push_back(i);
+    }
+  }
+  if (effect_index.empty()) {
+    return;
+  }
+  AnfNodePtr last_node = cnodes[cnodes_size - 1];
+  CNodePtr return_node;
+  if (last_node->isa<CNode>()) {
+    return_node = last_node->cast<CNodePtr>();
+  }
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
+    MS_LOG(EXCEPTION) << "The last cnode after sorting is not a return cnode.";
+  }
+  if (return_node->inputs().size() < 2) {
+    MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
+  }
+
+  auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes);
+  auto depend_cnode = graph->NewCNode(depend_node_inputs);
+  depend_cnode->set_abstract(depend_cnode->input(1)->abstract());
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (!manager->Replace(return_node->input(1), depend_cnode)) {
+    MS_LOG(EXCEPTION) << "Depend replace node failed";
+  }
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/control_depend.h b/mindspore/ccsrc/frontend/optimizer/control_depend.h
similarity index 100%
rename from mindspore/ccsrc/optimizer/control_depend.h
rename to mindspore/ccsrc/frontend/optimizer/control_depend.h
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc
new file mode 100644
index 0000000000..4d968d6d74
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/cse.cc
@@ -0,0 +1,231 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+
+#include "frontend/optimizer/cse.h"
+#include
+#include
+#include
+#include "./common.h"
+
+namespace mindspore {
+/* namespace to support opt */
+namespace opt {
+using mindspore::abstract::AbstractBase;
+using mindspore::abstract::AbstractFunction;
+using mindspore::abstract::AbstractFunctionPtr;
+
+BasePtr AbsOf(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto node_abs = node->abstract();
+  // In testcase TestOptOpt.CSE, node->abstract() is null.
+  if (node_abs == nullptr) {
+    return kAnyValue;
+  }
+
+  return node_abs;
+}
+
+bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
+  bool changed = false;
+  for (FuncGraphPtr fg : manager->func_graphs()) {
+    MS_EXCEPTION_IF_NULL(fg);
+    std::vector<std::size_t> order_group;
+    std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
+    std::unordered_map<AnfNodePtr, std::size_t> hashes;
+
+    std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
+    for (auto node : toposet) {
+      MS_EXCEPTION_IF_NULL(node);
+      if (hashes.find(node) != hashes.end()) {
+        continue;
+      }
+
+      std::size_t h = 0;
+      if (node->isa<ValueNode>()) {
+        ValueNodePtr value_node = node->cast<ValueNodePtr>();
+        auto value = value_node->value();
+        MS_EXCEPTION_IF_NULL(value);
+        h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
+      } else if (node->isa<CNode>()) {
+        auto cnode = node->cast<CNodePtr>();
+        auto &inputs = cnode->inputs();
+        size_t init = 0;
+        h = std::accumulate(inputs.begin(), inputs.end(), init,
+                            [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
+                              return hash_combine(hash, hashes[node_in]);
+                            });
+      } else if (node->isa<Parameter>()) {
+        h = node->hash();
+      } else {
+        MS_LOG(ERROR) << "Unknown node type";
+      }
+
+      hashes[node] = h;
+      if (groups.find(h) == groups.end()) {
+        std::vector<AnfNodePtr> innervec({node});
+        groups[h] = innervec;
+        order_group.emplace_back(h);
+      } else {
+        groups[h].push_back(node);
+      }
+    }
+
+    changed = DoReplace(manager, order_group, &groups) || changed;
+  }
+
+  return changed;
+}
+// Ops like Print and Summary have no real output; they are only ever used as inputs of a
+// Depend node.
+static bool HasSideEffect(const AnfNodePtr &node) {
+  auto prim = GetCNodePrimitive(node);
+  if (prim == nullptr) {
+    return false;
+  }
+  auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
+  if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
+    return GetValue<bool>(side_effect_v);
+  }
+  return false;
+}
+// If this returns true, the two nodes must not be merged.
+bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
+  bool has_random_effect = false;
+  auto prim_main = GetCNodePrimitive(main);
+  auto prim_node = GetCNodePrimitive(node);
+  // A node with a random effect may only be merged with nodes generated by the same op
+  // instance; results from different instances (not the same object) are not interchangeable.
+  if (prim_main != nullptr) {
+    if (prim_main == prim_node) {
+      return false;
+    }
+    auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
+    if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
+      has_random_effect = GetValue<bool>(effect_val);
+    }
+  }
+  return has_random_effect;
+}
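+
+// Merge-rule sketch (illustrative, not exhaustive): two Print cnodes with identical inputs
+// are kept apart because of their side effect; two cnodes flagged with a random effect are
+// kept apart unless they come from the very same op instance; two plain Mul cnodes whose
+// inputs compare equal pairwise are merged.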
+
+bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
+  MS_EXCEPTION_IF_NULL(main);
+  MS_EXCEPTION_IF_NULL(node);
+
+  if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
+    auto main_value = GetValueNode(main);
+    auto node_value = GetValueNode(node);
+    return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
+  } else if (main->isa<CNode>() && node->isa<CNode>()) {
+    auto c_main = main->cast<CNodePtr>();
+    auto c_node = node->cast<CNodePtr>();
+    // If side-effect checking is requested, a node with a side effect must not be merged.
+    if (check_side_effect && HasSideEffect(main)) {
+      return false;
+    }
+    const auto &inp1 = c_main->inputs();
+    const auto &inp2 = c_node->inputs();
+    if (inp1.size() != inp2.size()) {
+      return false;
+    }
+    for (size_t j = 0; j < inp1.size(); j++) {
+      auto inp1_j = inp1[j];
+      auto inp2_j = inp2[j];
+      MS_EXCEPTION_IF_NULL(inp1_j);
+      MS_EXCEPTION_IF_NULL(inp2_j);
+      if (!(*inp1_j == *inp2_j)) {
+        // Handle the case of two different Tensors that hold the same value.
+        if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
+          auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
+          auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
+          if (tensor1->ValueEqual(*tensor2)) {
+            continue;
+          }
+        } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
+          // When the same kind of side-effect node feeds both candidates, we can still merge
+          // them: such a node can only appear as an input of `Depend`, so merging duplicated
+          // `Depend` nodes is safe.
+          if (CheckReplace(inp1_j, inp2_j, false)) {
+            continue;
+          }
+        }
+        return false;
+      }
+    }
+    // Even when all inputs match, do not merge nodes that carry a random effect.
+    if (CheckRandomEffect(c_main, c_node)) {
+      return false;
+    }
+    return true;
+  }
+  // A Parameter node (or mismatched node kinds): do not merge.
+  return false;
+}
+
+bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
+                    std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const {
+  bool changes = false;
+  std::set<size_t> clear_set;
+  for (auto &h : order_group) {
+    std::vector<AnfNodePtr> &group = (*groups)[h];
+    // If a group holds more than one node, its members may be common subexpressions that can
+    // be eliminated.
+    if (group.size() > 1) {
+      for (size_t k = 0; k < group.size() - 1; k++) {
+        AnfNodePtr main = group[k];
+        MS_EXCEPTION_IF_NULL(main);
+
+        // Stop when every remaining node in the group has already been replaced, or when main
+        // is a ValueNode past the first position.
+        if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa<ValueNode>())) {
+          break;
+        }
+
+        // Skip nodes that have already been replaced.
+        if (clear_set.find(k) != clear_set.end()) {
+          continue;
+        }
+
+        // Compare with the remaining elements in this group.
+        for (size_t i = k + 1; i < group.size(); i++) {
+          auto node = group[i];
+          MS_EXCEPTION_IF_NULL(node);
+
+          if (clear_set.find(i) != clear_set.end()) {
+            continue;
+          }
+          if (main->func_graph() != node->func_graph()) {
+            continue;
+          }
+          if (CheckReplace(node, main)) {
+            changes = true;
+            (void)manager->Replace(node, main);
+            (void)clear_set.insert(i);
+          }
+        }
+      }
+      clear_set.clear();
+    }
+  }
+
+  return changes;
+}
+
+bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const {
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(root);
+
+  return BuildOrderGroupAndDoReplace(manager);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h
new file mode 100644
index 0000000000..140f592715
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/cse.h
@@ -0,0 +1,61 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +// Common subexpression elimination. +class CSE { + public: + explicit CSE(bool report_changes = true) : report_changes_(report_changes) {} + virtual ~CSE() = default; + + bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { + bool chg = Cse(root, optimizer->resource()->manager()); + return chg && report_changes_; + } + + virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; + + virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; + + bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; + + private: + bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; + bool DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, + std::unordered_map> *groups) const; + bool report_changes_; +}; + +BasePtr AbsOf(const AnfNodePtr &node); +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc new file mode 100644 index 0000000000..c157777040 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/optimizer/graph_kernel_reuse.h" +#include +#include +#include +#include "./common.h" +#include "utils/graph_utils.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { + +bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { + if (a->abstract() && b->abstract()) { + auto a_type = a->abstract()->GetTypeTrack(); + auto b_type = b->abstract()->GetTypeTrack(); + + if (a_type != b_type) { + return false; + } + + auto a_shape = a->abstract()->GetShapeTrack(); + auto b_shape = b->abstract()->GetShapeTrack(); + if (a_shape != nullptr && a_shape == b_shape) { + return true; + } + + if (a_shape != nullptr && b_shape != nullptr && a_shape->isa() && + b_shape->isa()) { + return a_shape->cast()->shape() == b_shape->cast()->shape(); + } + } + return false; +} + +bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { + bool changed = false; + auto fgs = manager->func_graphs(); + for (FuncGraphPtr &fg : fgs) { + if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + continue; + } + std::string key = GetValue(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { + if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { + FuncGraphPtr new_fg = nullptr; + for (auto &cfg : graph_kernel_ops[key]) { + // If two graphs have different size then continue + auto fg_topos = TopoSort(fg->get_return()); + auto cfg_topos = TopoSort(cfg->get_return()); + if (fg_topos.size() != cfg_topos.size()) { + continue; + } + + // Compare const tensor + bool has_same = true; + for (size_t i = 0; i < fg_topos.size(); ++i) { + if (IsValueNode(fg_topos[i])) { + if (!IsValueNode(cfg_topos[i])) { + has_same = false; + break; + } + + auto tensor1 = GetValueNode(fg_topos[i]); + auto tensor2 = GetValueNode(cfg_topos[i]); + if (!tensor1->ValueEqual(*tensor2)) { + has_same = false; + break; + } + } + } + + if (!has_same) { + continue; + } + + auto fg_input = fg->parameters(); + auto cfg_input = cfg->parameters(); + if (fg_input.size() != cfg_input.size()) { + continue; + } + // Compare input + for (size_t i = 0; i < fg_input.size(); ++i) { + if (!CompareNode(fg_input[i], cfg_input[i])) { + has_same = false; + break; + } + } + if (!has_same) { + continue; + } + + // Compare output + if (!CompareNode(fg->output(), cfg->output())) { + continue; + } + + // Find reusable fg + new_fg = cfg; + break; + } + + if (new_fg != nullptr) { + // Replace current fg with existing fg + auto users = fg->func_graph_cnodes_index(); + for (auto &iter : users) { + auto cnode = iter.first->first->cast(); + auto new_input = cnode->inputs(); + auto main_graph = cnode->func_graph(); + MS_EXCEPTION_IF_NULL(main_graph); + if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { + new_input[1] = NewValueNode(new_fg); + } else { + new_input[0] = NewValueNode(new_fg); + } + auto new_cnode = main_graph->NewCNode(new_input); + manager->Replace(iter.first->first, new_cnode); + changed = true; + } + + } else { + // Add current fg to map + graph_kernel_ops[key].push_back(fg); + } + } + } else { + graph_kernel_ops[key] = {fg}; + } + } + + return changed; +} + +bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) { + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(root); + + return DoReplace(manager); +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h 
b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h new file mode 100644 index 0000000000..a79ef3ce6d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/graph_kernel_reuse.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H +#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H + +#include +#include +#include +#include "mindspore/ccsrc/backend/session/anf_runtime_algorithm.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { + +// Common subexpression elimination. +class GraphKernelReuse { + public: + GraphKernelReuse() : count(0) {} + virtual ~GraphKernelReuse() = default; + + bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { + bool chg = ReuseGraphKernel(root, optimizer->resource()->manager()); + return chg; + } + + bool CompareNode(const AnfNodePtr a, const AnfNodePtr other); + bool DoReplace(const FuncGraphManagerPtr manager); + + bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager); + + private: + std::unordered_map> graph_kernel_ops; + int count; +}; + +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc new file mode 100644 index 0000000000..efc3795a4c --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/arithmetic_simplify.h" +#include "frontend/optimizer/irpass/branch_culling.h" +#include "frontend/optimizer/irpass/cast_eliminate.h" +#include "frontend/optimizer/irpass/convert.h" +#include "frontend/optimizer/irpass/env_item_eliminate.h" +#include "frontend/optimizer/irpass/grad_var_prepare.h" +#include "frontend/optimizer/irpass/gradient_eliminate.h" +#include "frontend/optimizer/irpass/inline.h" +#include "frontend/optimizer/irpass/incorporate_call.h" +#include "frontend/optimizer/irpass/incorporate_getitem.h" +#include "frontend/optimizer/irpass/item_tuple_eliminate.h" +#include "frontend/optimizer/irpass/mark_interface_fusion.h" +#include "frontend/optimizer/irpass/merge_addn.h" +#include "frontend/optimizer/irpass/minmax_grad.h" +#include "frontend/optimizer/irpass/param_replace.h" +#include "frontend/optimizer/irpass/partial_eliminate.h" +#include "frontend/optimizer/irpass/reduce_eliminate.h" +#include "frontend/optimizer/irpass/ref_eliminate.h" +#include "frontend/optimizer/irpass/reshape_eliminate.h" +#include "frontend/optimizer/irpass/special_op_eliminate.h" +#include "frontend/optimizer/irpass/specialize_transform.h" +#include "frontend/optimizer/irpass/symbol_resolver.h" +#include "frontend/optimizer/irpass/tile_eliminate.h" +#include "frontend/optimizer/irpass/transpose_eliminate.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass/indexed_slices_eliminate.h" + +namespace mindspore { +namespace opt { +namespace irpass { +OptimizeIRPassLib::OptimizeIRPassLib() { + arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", + {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, + prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); + arithmetic_simplify2_ = + MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); + special_op_eliminate_ = + MakeSubstitution(std::make_shared(), "special_op_eliminate", + {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, + prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); + zero_like_fill_zero_ = + MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); + adjust_all_reduce_mul_add_ = + MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); + + // ops eliminate + item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", + {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); + cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); + reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); + transpose_eliminate_ = + MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); + reduce_eliminate_ = MakeSubstitution( + std::make_shared(), "reduce_eliminate", + {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); + partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); + same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); + check_bprop_eliminate_ = + MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); + reset_defer_inline_ = + 
MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); + depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); + + // Env Item Eliminate + env_get_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); + new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_ = + MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); + incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), + "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); + + // Ref eliminate + make_ref_eliminate_ = + MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); + get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", + {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", + {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); + + replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", + IsValueNode, opt::FORCE_RENORM); + replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); + // Gradient transforms + expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); + minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); + + // branch culling + switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); + float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), + "float_tuple_getitem_switch", prim::kPrimTupleGetItem); + float_env_getitem_switch_ = + MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); + convert_switch_replacement_ = + MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); + + // Addn + merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); + addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); + + // inline + inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); + replace_applicator_ = + MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); + specialize_transform_ = + MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); + + // Incorporation + incorporate_getitem_set_ = + MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); + incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), + "incorporate_getitem_from_param", IsCNodeGraphKernel); + incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); + incorporate_call_switch_ = + MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); + + // Virtual Dataset + virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), + "virtual_dataset_eliminate", prim::kPrimVirtualDataset); + + // Convert + print_tuple_wrapper_ = + MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); + + // Unused parameter eliminate + unused_parameter_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); + unused_output_eliminate_ = + MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); + + // AddN eliminate + 
addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); + + // Mark interface fusion + mark_interface_fusion_ = + MakeSubstitution(std::make_shared(), "mark_interface_fusion", prim::kPrimSelect); + + // IndexedSlices Eliminate + indexed_slices_eliminate_ = MakeSubstitution( + std::make_shared(), "indexed_slices_eliminate", + {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); +} + +ResolveIRPassLib::ResolveIRPassLib() { + resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); + resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); +} + +InferenceOptPrepareLib::InferenceOptPrepareLib() { + grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h new file mode 100644 index 0000000000..4af8c0789d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -0,0 +1,192 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ + +#include + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/opt.h" +#include "ir/visitor.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// the collection of irpass for optimie action +class OptimizeIRPassLib { + public: + OptimizeIRPassLib(); + ~OptimizeIRPassLib() = default; + + SubstitutionPtr arithmetic_simplify_; + SubstitutionPtr arithmetic_simplify2_; + SubstitutionPtr special_op_eliminate_; + SubstitutionPtr zero_like_fill_zero_; + SubstitutionPtr adjust_all_reduce_mul_add_; + + // ops eliminate + SubstitutionPtr item_tuple_eliminate_; + SubstitutionPtr tile_eliminate_; + SubstitutionPtr cast_eliminate_; + SubstitutionPtr reshape_eliminate_; + SubstitutionPtr transpose_eliminate_; + SubstitutionPtr reduce_eliminate_; + SubstitutionPtr partial_eliminate_; + SubstitutionPtr same_eliminate_; + SubstitutionPtr check_bprop_eliminate_; + SubstitutionPtr reset_defer_inline_; + SubstitutionPtr depend_value_elim_; + + // Env Item Eliminate + SubstitutionPtr env_get_item_eliminate_; + SubstitutionPtr new_env_get_item_; + SubstitutionPtr incorporate_env_getitem_; + SubstitutionPtr incorporate_env_getitem_switch_; + + // Ref eliminate + SubstitutionPtr make_ref_eliminate_; + SubstitutionPtr get_ref_param_eliminate_; + SubstitutionPtr get_make_ref_eliminate_; + SubstitutionPtr replace_refkey_by_param_; + SubstitutionPtr replace_old_param_; + + // Branch culling + SubstitutionPtr switch_simplify_; + SubstitutionPtr float_tuple_getitem_switch_; + SubstitutionPtr float_env_getitem_switch_; + SubstitutionPtr convert_switch_replacement_; + + // AddN + SubstitutionPtr merge_addn_; + SubstitutionPtr addn_zero_filter_; + + // Gradient irpasses + SubstitutionPtr expand_jprim_; + SubstitutionPtr minmaximum_grad_; + + // inline + SubstitutionPtr inline_; + SubstitutionPtr replace_applicator_; + SubstitutionPtr specialize_transform_; + + // Incorporation + SubstitutionPtr incorporate_getitem_set_; + SubstitutionPtr incorporate_getitem_from_param_; + SubstitutionPtr incorporate_call_; + SubstitutionPtr incorporate_call_switch_; + + // virtual dataset + SubstitutionPtr virtual_dataset_eliminate_; + + // Convert + SubstitutionPtr print_tuple_wrapper_; + + // Unused parameter eliminate + SubstitutionPtr unused_parameter_eliminate_; + SubstitutionPtr unused_output_eliminate_; + + // AddN eliminate + SubstitutionPtr addn_eliminate_; + + // Fusion + SubstitutionPtr mark_interface_fusion_; + + // IndexedSlices Eliminate + SubstitutionPtr indexed_slices_eliminate_; +}; + +// the collection of irpass for resolve action +class ResolveIRPassLib { + public: + ResolveIRPassLib(); + ~ResolveIRPassLib() = default; + + SubstitutionPtr resolver_resolve_; + SubstitutionPtr resolver_getattr_; +}; + +class InferenceOptPrepareLib { + public: + InferenceOptPrepareLib(); + ~InferenceOptPrepareLib() = default; + SubstitutionPtr grad_var_prepare_; +}; + +// predicate functions +inline bool IsNode(const AnfNodePtr &) { return true; } + +inline bool IsCNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +inline bool IsVNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +inline bool IsParam(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} + +// Check if CNode Input 0 is Func Graph +inline bool IsCNodeGraph(const AnfNodePtr &node) { + if (node == 
nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + return IsValueNode(inp0); +} + +// Check if CNode Input 0 is Func Graph of graph kernel. +inline bool IsCNodeGraphKernel(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + if (IsValueNode(inp0)) { + auto fg = GetValueNode(inp0); + if (fg == nullptr) { + return false; + } + return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + } + return false; +} + +// Check if CNode Input 0 is CNode +inline bool IsCNodeDup(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + + auto inp0 = node->cast()->input(0); + return (inp0 != nullptr) && inp0->isa(); +} +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc new file mode 100644 index 0000000000..83f7fae582 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -0,0 +1,680 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include + +#include "frontend/optimizer/irpass/arithmetic_simplify.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarMul)(node); + + if (is_zero_) { + return NewValueNode(zero_); + } + if (is_one_) { + return x_; + } + return nullptr; +} + +void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { + if (is_one_ || node->isa()) { + x_ = node; + return; + } + + AnfVisitor::Visit(node); + if (!is_one_) { + x_ = node; + } +} + +void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (*value == *zero_) { + is_zero_ = true; + } else if (*value == *one_) { + is_one_ = true; + } +} + +void MultiplyByZeroOrOne::Reset() { + x_ = nullptr; + is_one_ = false; + is_zero_ = false; +} + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > FLT_EPSILON) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > DBL_EPSILON) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // input Data Types is not supported + return false; +} + +bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); +} + +void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); +} + +// Make a new tensor (when possible) with the same shape as of `node` +// If x is nullptr then fill new tensor will "0" +// If x is a tensor with empty shape then fill new tensor with the single value of x +// If x is a tensor with same shape as `node` then return x as result +AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { + if ((node->abstract() == nullptr) || 
!node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + std::memset(data, 0, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); + } + } else { + memcpy(data, source_data, mem_size); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_zero_) { + if (x_->func_graph() != node->func_graph()) { + return nullptr; + } + return NewTensorFilledWithData(node); + } + return nullptr; +} + +void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { + if (is_zero_) { + x_ = node; + return; + } + + if (IsParam(node)) { + x_ = node; + return; + } + + if (IsCNode(node)) { + CNodePtr cnode = node->cast(); + if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { + is_zero_ = true; + return; + } + x_ = node; + return; + } + auto value = node->cast()->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_one_) { + return NewTensorFilledWithData(node, x_); + } + return nullptr; +} + +void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { + if (is_one_) { + x_ = node; + return; + } + + if (IsParam(node) || IsCNode(node)) { + x_ = node; + return; + } + + auto value = node->cast()->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if 
(CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByOne::Reset() { + x_ = nullptr; + is_one_ = false; +} + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void AddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && + ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void AddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimTensorAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void TensorAddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void TensorAddByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } +} + +void TensorAddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { + return nullptr; + } + + // {PrimMomentum, {...}, Y, Z, Xs} + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { + return nullptr; + } + auto y = inputs[2]; + auto z = inputs[3]; + + // {kPrimZerosLike, X} + if (inputs[1]->cast()->size() != 2) { + return nullptr; + } + + // {prim::kPrimMakeTuple, Z, Y} + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); +} + +// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +// Support function to multiply two constant tensors: partially support broadcasting shapes +template +void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, + void **out_data, int out_data_size) { + T *data_1 = reinterpret_cast(in_data_1); + T *data_2 = reinterpret_cast(in_data_2); + T *data_out = new T[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; +} + +AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, + const AnfNodePtr &node_3) { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto 
value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + void *data_out; + + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), + &data_out, data_out_size); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + // Un-support data types + return nullptr; + } + } + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + memcpy(data, data_out, mem_size); + + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimMul, Tensor1, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + + if (!IsCNode(c_p_node_)) { + return nullptr; + } + + auto tensor1 = vnode_; + auto mul = c_p_node_->cast(); + + Reset(); + // {prim::kPrimMul, Tensor2, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + auto tensor2 = vnode_; + auto c_p_node = c_p_node_; + + auto PrimMul = GetValueNode(mul->input(0)); + auto fg = node->func_graph(); + + auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); + return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); + } + return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); +} + +void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { + if (IsValueNode(node)) { + vnode_ 
= node; + } + + if (IsCNode(node) || IsParam(node)) { + c_p_node_ = node; + } +} + +void ConstantDuplicateMul::Reset() { + vnode_ = nullptr; + c_p_node_ = nullptr; +} + +AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[2])) { + return nullptr; + } + auto scalar = GetValueNode(inputs[2]); + if (scalar->isa() && GetValue(scalar) == 1.0) { + return inputs[1]; + } else if (scalar->isa() && GetValue(scalar) == 1) { + return inputs[1]; + } + return nullptr; +} + +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimAddN, Zs} + if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { + return nullptr; + } + auto addn = node->cast(); + if (addn->size() != 2) { + return nullptr; + } + AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); + if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) { + return nullptr; + } + auto addn_maketuple = addn->input(1); + + auto fg = all_reduce_fg_; + // addn inputs cross the graph, make the inputs same as allreduce node. + if (z_->isa() && fg != z_->func_graph()) { + auto cnode_z = z_->cast(); + z_ = NewCNode(cnode_z->inputs(), fg); + } + + auto addn_op_node = addn->input(0); + auto make_tuple_op_node = addn->input(1)->cast()->input(0); + + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); + AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg); + ProcessDependEdge(fg, addn_maketuple, all_reduce); + return mul; +} + +void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, + const AnfNodePtr &new_node) { + // If has dynamic loss scale. 
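+  // (That is, the result of the old Mul may also feed another MakeTuple, e.g. one used by a
+  //  dynamic loss-scale update; SetEdge below redirects those uses to the rebuilt node so
+  //  that every consumer sees the reordered computation.)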
+ auto &users_map = fg->manager()->node_users(); + auto it = users_map.find(mul_cnode_); + if (it != users_map.end()) { + auto users = it->second; + for (auto &user_pair : users) { + auto node = user_pair.first; + if (node != addn_maketuple) { + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + fg->manager()->SetEdge(node, user_pair.second, new_node); + } + } + } + } +} + +void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { + if (level_ == 0) { + level_ = 1; + is_reduce_match_ = false; + // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} + AnfVisitor::Match(prim::kPrimMul)(node); + level_ = 0; + if (is_reduce_match_) { + mul_ = node->cast()->input(0); + mul_cnode_ = node->cast(); + y_ = tmp_; + } else { + z_ = node; + } + } + + if (level_ == 1) { + // {prim::kPrimAllReduce, X} + if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { + auto cnode = node->cast(); + if (cnode->size() > 1) { + all_reduce_ = cnode->input(0); + x_ = cnode->input(1); + is_reduce_match_ = true; + all_reduce_fg_ = cnode->func_graph(); + } + } else { + tmp_ = node; + } + } +} + +void AdjustAllReduceMulAdd::Reset() { + level_ = 0; + is_reduce_match_ = false; + x_ = nullptr; + y_ = nullptr; + z_ = nullptr; + tmp_ = nullptr; + all_reduce_fg_ = nullptr; +} + +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} + +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h new file mode 100644 index 0000000000..3088231396 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.h @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ + +#include +#include +#include + +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +class MultiplyByZeroOrOne : public AnfVisitor { + public: + MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} + ~MultiplyByZeroOrOne() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}, is_one_{false}; + ValuePtr zero_, one_; + AnfNodePtr x_{nullptr}; +}; + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +class CheckTensorConstant { + public: + explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} + ~CheckTensorConstant() = default; + + bool IsTensorConstant(const ValuePtr &value); + bool IsTensorScalarConstant(const ValuePtr &value); + + private: + int check_value_; +}; + +class TensorMultiplyBase : public AnfVisitor { + protected: + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); + + // Make a new tensor (when possible) with the same shape as of `node` + // If x is nullptr then fill new tensor will "0" + // If x is a tensor with empty shape then fill new tensor with the single value of x + // If x is a tensor with same shape as `node` then return x as result + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); + + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +class TensorMultiplyByZero : public TensorMultiplyBase { + public: + TensorMultiplyByZero() : zero_(MakeValue(0)) {} + ~TensorMultiplyByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; +}; + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +class TensorMultiplyByOne : public TensorMultiplyBase { + public: + TensorMultiplyByOne() {} + ~TensorMultiplyByOne() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_one_{false}; +}; + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +class AddByZero : public AnfVisitor { + public: + AddByZero() : zero_(MakeValue(0)) {} + ~AddByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +class TensorAddByZero : public AnfVisitor { + public: + 
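+  // (Typical instance: TensorAdd(ZerosLike(w), dx) -> dx, a pattern that gradient graphs
+  //  commonly produce for contributions that are identically zero; illustrative example.)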
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + AnfNodePtr x_{nullptr}; +}; + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +class OptUpdateZeroTensor : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +class ConstantDuplicateMul : public AnfVisitor { + public: + // Support function to multiply two constant tensors: partially support broadcasting shapes + template + void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size); + + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + AnfNodePtr vnode_; + AnfNodePtr c_p_node_; +}; + +class PowerOneEliminate : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// grad = AllReduce(grad) / worker_number +// grad = grad + weight * decy +// -> +// grad = grad + weight * decy +// grad = AllReduce(grad) / worker_number + +// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> +// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} +class AdjustAllReduceMulAdd : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + int level_{0}; + bool is_reduce_match_{false}; + AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; + AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; + FuncGraphPtr all_reduce_fg_{nullptr}; +}; + +class ArithmeticSimplify : public OptimizerCaller { + public: + ArithmeticSimplify() + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { + eliminaters_.emplace_back(multiply_by_zero_or_one_); + eliminaters_.emplace_back(tensor_multiply_by_one_); + eliminaters_.emplace_back(add_by_zero_); + eliminaters_.emplace_back(tensor_add_by_zero_); + eliminaters_.emplace_back(identity_); + eliminaters_.emplace_back(opt_update_zero_tensor_); + eliminaters_.emplace_back(constant_duplicate_mul_); + eliminaters_.emplace_back(power_one_); + } + ~ArithmeticSimplify() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + + private: + OptimizerCallerPtr multiply_by_zero_or_one_; + OptimizerCallerPtr tensor_multiply_by_one_; + OptimizerCallerPtr add_by_zero_; + OptimizerCallerPtr tensor_add_by_zero_; + OptimizerCallerPtr identity_; + OptimizerCallerPtr opt_update_zero_tensor_; + OptimizerCallerPtr 
constant_duplicate_mul_; + OptimizerCallerPtr power_one_; + + std::vector eliminaters_{}; +}; + +// Arithmetic Simplifications should be done after step_parallel. +// eg: Mul(0, weight) where weight is a parameter will be simplified to a constant tensor +// with shape(weight), but after step_parallel, shape of weight may be changed, so the +// shape of the constant tensor should also be changed. So this pass is seperated from +// ArithmeticSimplify and deferred until step_parallel. +class ArithmeticSimplify2 : public OptimizerCaller { + public: + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } + ~ArithmeticSimplify2() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + + private: + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc new file mode 100644 index 0000000000..dc580f6b63 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc @@ -0,0 +1,584 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/branch_culling.h" + +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data, + int switch_idx) { + auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast(); + std::vector switch_nodes{NewValueNode(switch_node), data, cond}; + auto switch_apply = graph->NewCNode(switch_nodes); + std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply, + NewValueNode(MakeValue(switch_idx))}; + return graph->NewCNode(tuple_getitem_nodes); +} + +AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + return GenerateSwitchNode(graph, cond, data, 1); +} + +AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { + return GenerateSwitchNode(graph, cond, data, 0); +} + +bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { + // The CNode inputs of the following Primitive with index in std::vector should not be guarded by geswitch + // node because it is attribute or ge specific reason. + // Example : when convert CNode(kPrimReduceSum, x, axis), node of index 2 in CNode->inputs is axis which should not be + // converted to switch guarded. 
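+  // (Guarding such an input with geswitch would turn a compile-time attribute into a
+  //  runtime-selected tensor, which GE could not fold back into the op definition; hence
+  //  the per-primitive input-index lists below.)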
+ std::vector>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, + {prim::kPrimMomentum, {2, 3}}, + {prim::kPrimStateSetItem, {1}}, + {prim::kPrimTupleGetItem, {2}}, + {prim::kPrimEnvGetItem, {1}}, + {prim::kPrimEnvSetItem, {1}}, + {prim::kPrimReduceSum, {2}}, + {prim::kPrimReduceMean, {2}}, + {prim::kPrimReduceAll, {2}}, + {prim::kPrimCast, {2}}, + {prim::kPrimTranspose, {2}}, + {prim::kPrimOneHot, {2}}, + {prim::kPrimGatherV2, {3}}, + {prim::kPrimReshape, {2}}, + {prim::kPrimAssign, {1}}, + {prim::kPrimAssignAdd, {1}}, + {prim::kPrimAssignSub, {1}}, + {prim::kPrimTensorSummary, {1}}, + {prim::kPrimImageSummary, {1}}, + {prim::kPrimScalarSummary, {1}}, + {prim::kPrimApplyRMSProp, {6, 7, 8}}, + {prim::kPrimCumSum, {2}}, + {prim::kPrimTile, {2}}, + {prim::kPrimExpandDims, {2}}, + {prim::kPrimHistogramSummary, {1}}}); + for (auto &item : white_list) { + auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { + return IsPrimitiveCNode(node, item.first) && idx == index; + }); + if (matched) { + return true; + } + } + + std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend}; + for (auto &item : adapter_convert_ops) { + if (IsPrimitiveCNode(node, item)) { + return true; + } + } + return false; +} + +using NodeInputReplMap = std::unordered_map, AnfNodePtr, PairHasher>; +// replace the nodes which should be changed +void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector> nodes_changed, + std::unordered_map repl_node, NodeInputReplMap repl_node_inputs, + const FuncGraphPtr &func_graph) { + for (auto &node_pair : nodes_changed) { + CNodePtr old_node = node_pair.first; + CNodePtr new_node = node_pair.second; + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + for (size_t i = 0; i < old_node->size(); i++) { + auto input = old_node->input(i); + if (repl_node.count(input) != 0) { + new_node->add_input(repl_node[input]); + } else if (repl_node_inputs.count(std::pair(old_node, i)) != 0) { + new_node->add_input(repl_node_inputs[std::pair(old_node, i)]); + } else { + new_node->add_input(input); + } + } + } + + for (auto &item : repl_node) { + if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) { + func_graph->set_output(item.second->cast()->input(1)); + } else if (!manager->Replace(item.first, item.second)) { + MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2) + << " to new: " << item.second->DebugString(2); + } + } +} + +// trace the node that should add switch and replace them with new nodes in the graph +FuncGraphPtr TransformGraphCondBranchNodes( + const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::function &generate_func) { + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + // record the node that has been changed + std::vector> nodes_changed; + // record the node to be replaced + std::unordered_map repl_node; + // record the node input to be replaced + NodeInputReplMap repl_node_inputs; + const AnfNodeSet &nodes = graph->nodes(); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto inputs = node->cast()->inputs(); + bool should_replace = false; + // if the apply input does not belong to graph, insert a switch node + for (size_t index = 0; index < inputs.size(); index++) { + auto input_node = inputs[index]; + MS_EXCEPTION_IF_NULL(input_node); + // for some ops input should not guard it with switch + if (InConvertWhiteList(node, index)) { + continue; + } 
+
+      // If the input of the node does not belong to the graph, or it is a ValueNode,
+      // insert a switch guard. Bypass the Primitive node, which is inputs[0].
+      if ((index >= 1 && inputs[index]->func_graph() != nullptr && inputs[index]->func_graph() != graph) ||
+          ((index >= 1 && inputs[index]->isa<ValueNode>()))) {
+        input_node = generate_func(graph, cond, inputs[index]);
+        repl_node_inputs[std::pair<AnfNodePtr, size_t>(node, index)] = input_node;
+        should_replace = true;
+      }
+      if (input_node == nullptr) {
+        MS_LOG(EXCEPTION) << "generate switch node failed";
+      }
+    }
+    if (should_replace) {
+      auto new_node = graph->NewCNode();
+      repl_node[node] = new_node;
+      nodes_changed.emplace_back(node->cast<CNodePtr>(), new_node);
+    }
+  }
+  RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph);
+  return graph;
+}
+
+struct SharedOp {
+  tensor::TensorPtr const_data;
+  CNodePtr square_ops[2];
+  CNodePtr merge_ops[2];
+} MergeNetOutput;
+
+inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; }
+inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; }
+
+inline CNodePtr GetSquareOp(int switch_idx) { return MergeNetOutput.square_ops[switch_idx]; }
+inline void SetSquareOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; }
+
+inline CNodePtr GetMergeOp(int switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; }
+inline void SetMergeOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; }
+
+inline void ResetSharedOp() {
+  SetConstData(nullptr);
+  SetSquareOp(0, nullptr);
+  SetSquareOp(1, nullptr);
+  SetMergeOp(0, nullptr);
+  SetMergeOp(1, nullptr);
+}
+
+tensor::TensorPtr ConstData() {
+  std::vector<int> shp = {1};
+  tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
+  auto *val = static_cast<int32_t *>(const_data->data_c());
+  *val = 0;
+  return const_data;
+}
+
+CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx,
+                  const tensor::TensorPtr &const_data) {
+  auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  // For the depended node, add two const data to merge the flow: one for the depended node with the same
+  // switch, the other using the opposite switch.
+  auto ctrl_data = NewValueNode(const_data);
+  auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
+
+  std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
+  auto square_op = graph->NewCNode(square_nodes);
+
+  return square_op;
+}
+
+CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx,
+                   const tensor::TensorPtr &const_data, const CNodePtr &square_op) {
+  // For the depended node, add two const data to merge the flow: one for the depended node with the same
+  // switch, the other using the opposite switch.
+  auto opposite_ctrl_data = NewValueNode(const_data);
+  auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, opposite_ctrl_data, 1 - switch_idx);
+
+  std::vector<AnfNodePtr> merge_nodes;
+  auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  merge_nodes.push_back(NewValueNode(PrimMerge));
+  std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
+  merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
+  auto merge_op = graph->NewCNode(merge_nodes);
+
+  return merge_op;
+}
+
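+// A sketch of the structure GenerateSwitchDependNode builds below (illustrative naming only):
+//   square = Square(tuple_getitem(geswitch(const_data, cond), switch_idx))
+//   merge  = Merge(MakeTuple(square, tuple_getitem(geswitch(const_data, cond), 1 - switch_idx)))
+//   result = Depend(merge, ControlDepend(output_node, square))
+// Both switch branches feed the Merge, so data flow resumes whichever branch executes, while the
+// ControlDepend keeps the original output ordered after the Square op.
+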
+// construct a depend node from the merge output node: merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
+// control_depend(output_node, square_op)
+AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
+                                    int switch_idx) {
+  tensor::TensorPtr const_data = GetConstData();
+  if (const_data == nullptr) {
+    const_data = ConstData();
+    SetConstData(const_data);
+  }
+
+  CNodePtr square_op = GetSquareOp(switch_idx);
+  if (square_op == nullptr) {
+    square_op = SquareOp(graph, cond, switch_idx, const_data);
+    SetSquareOp(switch_idx, square_op);
+  }
+
+  CNodePtr merge_op = GetMergeOp(switch_idx);
+  if (merge_op == nullptr) {
+    merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
+    SetMergeOp(switch_idx, merge_op);
+  }
+
+  std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op};
+  auto control_depend_op = graph->NewCNode(control_depend_nodes);
+
+  std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op};
+  auto depend_op = graph->NewCNode(depend_nodes);
+
+  return depend_op;
+}
+
+// construct a merge output and add a dependency on the netoutput node via control_depend;
+// we need to preserve the control_depend node, besides the generated merge node and control_depend node
+CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
+                                         const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst,
+                                         int switch_idx) {
+  auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
+  std::vector<int> shp = {1};
+  tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
+  auto *val = static_cast<int32_t *>(const_data->data_c());
+  *val = 0;
+  // For the control_depend netoutput node, add two const data to merge the flow: one for the depended node
+  // with the same switch, the other using the opposite switch.
+  auto ctrl_data = NewValueNode(const_data);
+  auto opposite_ctrl_data = NewValueNode(const_data);
+  auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
+  auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, opposite_ctrl_data, 1 - switch_idx);
+
+  std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
+  auto square_op = graph->NewCNode(square_nodes);
+
+  std::vector<AnfNodePtr> merge_nodes;
+  merge_nodes.push_back(NewValueNode(PrimMerge));
+  std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
+  merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
+  auto merge_output = graph->NewCNode(merge_nodes);
+
+  std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op};
+  auto cond_dep_output = graph->NewCNode(control_depend_nodes);
+
+  std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output,
+                                                    cond_dep_output};
+  return graph->NewCNode(depended_make_tuple_nodes);
+}
+
+// generate switch nodes for true graph node inputs
+AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
+  // for the switch op, the output is a tuple: the 0th element is the false branch, the 1st is the true branch
+  return GenerateSwitchDependNode(graph, cond, data, 1);
+}
+
+// generate switch nodes for false graph node inputs
+AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
+  // for the switch op, the output is a tuple: the 0th element is the false branch, the 1st is the true branch
+  return GenerateSwitchDependNode(graph, cond, data, 0);
+}
+
+// generate switch nodes for true graph node inputs
+CNodePtr 
GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, + const AnfNodePtr &con_input, const AnfNodePtr &output) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); +} + +// generate switch nodes for false graph node inputs +CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, + const AnfNodePtr &con_input, const AnfNodePtr &output) { + // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch + return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); +} + +// to judge if the node used in ControlDepend is a net output node +bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { + auto uses = manager->node_users()[node]; + bool is_output_node = true; + for (auto &item : uses) { + if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { + continue; + } + is_output_node = false; + break; + } + return is_output_node; +} + +// generate node for Depended MakeTuple +void GenerateReplNodeForDependMakeTuple( + const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, + const std::shared_ptr> &repl_node, + const std::function &generate_func, + const std::function &gen_ctl_depd_func) { + MS_EXCEPTION_IF_NULL(graph->manager()); + + auto make_tuple_inputs = depended_node->cast()->inputs(); + const size_t make_tuple_begin_idx = 1; + std::vector new_make_tuple_nodes; + bool replace_make_tuple = false; + new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) { + auto depended_tuple_input_node = make_tuple_inputs[idx]; + if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimDepend)) { + new_make_tuple_nodes.push_back(depended_tuple_input_node); + continue; + } + if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { + // only when the control depend input is not square op (the op to use as merge output) + auto control_inputs = depended_tuple_input_node->cast()->inputs(); + if (control_inputs.size() != 3) { + MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); + } + // control inputs: primitive, src, dst + auto dst_node = control_inputs[2]; + if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { + auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); + MS_EXCEPTION_IF_NULL(gen_node); + auto tuple_inputs = gen_node->inputs(); + // add depended tuple inputs to new_make_tuple directly + for (size_t i = 1; i < tuple_inputs.size(); i++) { + new_make_tuple_nodes.push_back(tuple_inputs[i]); + } + } + replace_make_tuple = true; + continue; + } + + if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { + auto gen_node = generate_func(graph, cond, depended_tuple_input_node); + new_make_tuple_nodes.push_back(gen_node); + replace_make_tuple = true; + continue; + } + + MS_LOG(WARNING) << "depended node being used by others, "; + } + if (replace_make_tuple) { + auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes); + (*repl_node)[depended_node] = make_tuple_op; + } +} + +// generate a replace depend node for a single network output node +void GenerateRepDepend( + const CNodePtr &node, const FuncGraphPtr &graph, const 
AnfNodePtr &cond,
+  const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
+  const std::function<AnfNodePtr(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data)>
+    &generate_func,
+  const std::function<CNodePtr(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data,
+                               const AnfNodePtr &output)> &gen_ctl_depd_func) {
+  auto inputs = node->inputs();
+  if (inputs.size() != 3) {
+    MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
+  }
+
+  std::vector<AnfNodePtr> new_depened_inputs;
+  // Inputs should be [depend, actual_value, depended_node]
+  auto depended_node = inputs[2];
+  new_depened_inputs.push_back(inputs[0]);
+  new_depened_inputs.push_back(inputs[1]);
+  // the depended node should be a make_tuple or a single depended node
+  if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
+    GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func);
+  } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) {
+    // only when the control depend input is not a square op (the op used as the merge output)
+    auto control_inputs = depended_node->cast<CNodePtr>()->inputs();
+    // control inputs: primitive, src, dst
+    if (control_inputs.size() != 3) {
+      MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
+    }
+    auto dst_node = control_inputs[2];
+    if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
+      auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node);
+      (*repl_node)[depended_node] = gen_node;
+    }
+  } else {
+    // Check if there is only a single user of depended_node.
+    if (graph->manager()->node_users()[depended_node].size() == 1) {
+      auto gen_node = generate_func(graph, cond, depended_node);
+      (*repl_node)[depended_node] = gen_node;
+    } else {
+      MS_LOG(WARNING) << "depended node is being used by others";
+    }
+  }
+}
+
+// generate a depend node for the netoutput node, to resolve the stream synchronization problem of GE:
+// traverse all nodes of the depend node, find the graph output node, generate a merge node of (square, const),
+// and add a control_depend between the graph output node and the square node.
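+//
+// For example (sketch only): for a Depend(actual_value, depended_node) feeding the net output, the
+// depended_node is re-tied to Merge(MakeTuple(Square(geswitch(const)), geswitch(const))) together with a
+// ControlDepend on the graph output, as produced by GenerateRepDepend above, so GE still observes the
+// side effect of the culled branch.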
+FuncGraphPtr TransformGraphDependNode(
+  const FuncGraphPtr &graph, const AnfNodePtr &cond,
+  const std::function<AnfNodePtr(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data)>
+    &gen_depend_func,
+  const std::function<CNodePtr(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data,
+                               const AnfNodePtr &output)> &gen_ctl_depd_func) {
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+
+  ResetSharedOp();
+  std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> repl_node =
+    std::make_shared<std::unordered_map<AnfNodePtr, AnfNodePtr>>();  // record the nodes to be replaced
+  const AnfNodeSet &nodes = graph->nodes();
+  for (auto &node : nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
+      auto cnode = node->cast<CNodePtr>();
+      if (cnode->size() != 3) {
+        MS_LOG(EXCEPTION) << "Depend node input size != 3";
+      }
+      auto depended_node = cnode->input(2);
+      MS_EXCEPTION_IF_NULL(depended_node);
+      if (!depended_node->isa<CNode>()) {
+        continue;
+      }
+      if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) {
+        continue;
+      }
+      GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func);
+    }
+  }
+  ResetSharedOp();
+
+  for (auto &item : *repl_node) {
+    if (!manager->Replace(item.first, item.second)) {
+      MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed";
+    }
+  }
+
+  return graph;
+}
+
+FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
+  (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode);
+  return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode);
+}
+
+FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
+  (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode);
+  return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode);
+}
+
+// judge whether the true and false graph outputs are compatible (they shall have the same tuple size)
+bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) {
+  MS_EXCEPTION_IF_NULL(true_branch_abs);
+  MS_EXCEPTION_IF_NULL(false_branch_abs);
+
+  if (true_branch_abs->isa<abstract::AbstractTuple>() && false_branch_abs->isa<abstract::AbstractTuple>()) {
+    abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast<abstract::AbstractTuplePtr>();
+    abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast<abstract::AbstractTuplePtr>();
+    if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) {
+      MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size()
+                    << ", not equal to false branch size:" << false_branch_tuple->elements().size() << " ";
+      return false;
+    }
+    bool all_compatible = true;
+    for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) {
+      all_compatible =
+        all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]);
+    }
+    return all_compatible;
+  }
+  TypePtr true_branch_type = true_branch_abs->BuildType();
+  TypePtr false_branch_type = false_branch_abs->BuildType();
+  MS_LOG(DEBUG) << "branch output Type equal?"
<< (*true_branch_type == *false_branch_type) + << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString(); + return (*true_branch_type == *false_branch_type); +} + +AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, + const AbstractBasePtr &true_graph_output_abs, + const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph, + const AnfNodePtr &cond) { + MS_EXCEPTION_IF_NULL(true_graph_output_abs); + MS_EXCEPTION_IF_NULL(false_graph_output_abs); + MS_EXCEPTION_IF_NULL(cond); + MS_EXCEPTION_IF_NULL(switch_graph); + auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); + MS_EXCEPTION_IF_NULL(PrimMerge); + + if (!true_graph_output_abs->isa()) { + std::vector merge_nodes; + merge_nodes.push_back(NewValueNode(PrimMerge)); + std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node}; + merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes)); + std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), + switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))}; + return switch_graph->NewCNode(tuple_getitem_nodes); + } else { + abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast(); + abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast(); + + std::vector make_tuple_nodes; + make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { + std::vector true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node, + NewValueNode(MakeValue(SizeToInt(i)))}; + auto true_node = switch_graph->NewCNode(true_getitem_nodes); + std::vector false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node, + NewValueNode(MakeValue(SizeToInt(i)))}; + auto false_node = switch_graph->NewCNode(false_getitem_nodes); + + auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i], + false_branch_tuple->elements()[i], switch_graph, cond); + make_tuple_nodes.push_back(merge_node); + } + return switch_graph->NewCNode(make_tuple_nodes); + } +} + +AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, + const AbstractBasePtr &true_graph_output_abs, + const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, + const FuncGraphPtr &switch_graph) { + if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) { + MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString() + << ", false:" << false_graph_output_abs->ToString(); + } + return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, + switch_graph, cond); +} +} // namespace internal +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h new file mode 100644 index 0000000000..b3f3fe4733 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_
+
+#include <vector>
+#include <memory>
+
+#include "ir/func_graph.h"
+#include "ir/func_graph_cloner.h"
+#include "ir/optimizer_caller.h"
+#include "ir/pattern_matcher.h"
+#include "frontend/operator/ops.h"
+#include "frontend/optimizer/irpass.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimSwitch, true, X, Y}
+// {prim::kPrimSwitch, false, X, Y}
+class SwitchSimplify : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    PatternNode<AnfNodePtr> cond, true_br, false_br;
+    auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr {
+      auto cond_value_ = GetValue<bool>(GetValueNode(cond.GetNode(node)));
+      if (cond_value_) {
+        return true_br.GetNode(node);
+      }
+      return false_br.GetNode(node);
+    };
+
+    MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda,
+                            cond.CheckFunc(IsValueNode<BoolImm>, node));
+
+    return nullptr;
+  }
+};
+
+// {prim::kPrimTupleGetItem, {prim::kPrimSwitch, X0, X1, X2}, C} =>
+// {prim::kPrimSwitch, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}}
+class FloatTupleGetItemSwitch : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    PatternNode<AnfNodePtr> cond, true_br, false_br, x;
+    MATCH_REPLACE_IF(node,
+                     PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x),
+                     PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x),
+                                PPrimitive(prim::kPrimTupleGetItem, false_br, x)),
+                     x.CheckFunc(IsVNode, node));
+    return nullptr;
+  }
+};
+
+// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} =>
+// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}}
+class FloatEnvGetItemSwitch : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    PatternNode<AnfNodePtr> cond, true_br, false_br, x, x2;
+    MATCH_REPLACE(node,
+                  PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2),
+                  PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2),
+                             PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2)));
+
+    return nullptr;
+  }
+};
+
+namespace internal {
+FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
+FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond);
+AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node,
+                                  const AbstractBasePtr &true_graph_output_abs,
+                                  const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond,
+                                  const FuncGraphPtr &func_graph);
+}  // namespace internal
+
+// {{prim::kPrimSwitch, X, G1, G2}, Xs}
+class ConvertSwitchReplacement : public OptimizerCaller {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) 
override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode_ = node->cast(); + if (cnode_->size() < 1) { + return nullptr; + } + + auto node_ = cnode_->input(0); + + PatternNode cond, true_br, false_br; + + auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node_)); + auto g2_ = GetValueNode(false_br.GetNode(node_)); + auto x_ = cond.GetNode(node_); + + // for switch replace method, only graphs without graph inside can be replaced + for (auto &item : g1_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } + } + + for (auto &item : g2_->value_nodes()) { + auto value_node = item.first; + if (IsValueNode(value_node)) { + return nullptr; + } + } + + auto true_output = g1_->output()->abstract(); + auto false_output = g2_->output()->abstract(); + auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); + auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); + + std::vector params; + auto fg = node_->func_graph(); + auto cloned_g1 = InlineClone(trans_g1, fg, params); + auto cloned_g2 = InlineClone(trans_g2, fg, params); + auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); + + return nnode; + }; + + MATCH_REPLACE_LAMBDA_IF( + node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); + + return nullptr; + } +}; + +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc new file mode 100644 index 0000000000..ddb84806e1 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/optimizer/irpass/cast_eliminate.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "ir/func_graph.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimCast, X, T} +AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); + + // check pattern match + if (tgt_ == nullptr) { + return nullptr; + } + + // src type check + auto src_type = src_->Type(); + if (src_type == nullptr || !src_type->isa()) { + return nullptr; + } + + src_type = src_type->cast()->element(); + + // tgt type check + auto tgt_type = GetValueNode(tgt_); + if (tgt_type->isa()) { + tgt_type = tgt_type->cast()->element(); + } + + if (src_type->type_id() == tgt_type->type_id()) { + return src_; + } + + return nullptr; +} + +void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { + if (src_ == nullptr) { + src_ = node; + } else { + tgt_ = node; + } +} + +// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} +AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); + + if (x_ != nullptr && t_ != nullptr) { + auto cast_op = parse::python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); + ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); + auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); + cnode->set_abstract(node->abstract()); + return cnode; + } + return nullptr; +} + +void TwoCastEliminater::Visit(const AnfNodePtr &node) { + if (IsPrimitiveCNode(node, prim::kPrimCast)) { + auto cnode = node->cast(); + // {prim::kPrimCast, X, Y} + if (cnode->size() != 3) { + return; + } + x_ = cnode->input(1); + } else { + t_ = node; + } +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h new file mode 100644 index 0000000000..d5222d4310 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/cast_eliminate.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ + +#include "ir/visitor.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimCast, X, T} +class CastSameTypeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + void Visit(const AnfNodePtr &node) override; + void Reset() { + src_ = nullptr; + tgt_ = nullptr; + } + + private: + AnfNodePtr src_{nullptr}, tgt_{nullptr}; +}; + +// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} +class TwoCastEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + void Visit(const AnfNodePtr &node) override; + void Reset() { + x_ = nullptr; + t_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, t_{nullptr}; +}; + +class CastEliminater : public OptimizerCaller { + public: + CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} + ~CastEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto new_node = cast_same_type_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + new_node = two_cast_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + return nullptr; + } + + private: + CastSameTypeEliminater cast_same_type_eliminater_; + TwoCastEliminater two_cast_eliminater_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/convert.h b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h new file mode 100644 index 0000000000..d887874203 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/convert.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
+
+#include <vector>
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "ir/func_graph.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimPrint, Xs} -> {prim::kPrimPrint, {prim::kPrimMakeTuple, Xs}}
+class PrintTupleWrapper : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimPrint)) {
+      return nullptr;
+    }
+
+    // already of the form {prim::kPrimPrint, {prim::kPrimMakeTuple, Xs}}
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode->size() == 2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
+      return nullptr;
+    }
+
+    std::vector<AnfNodePtr> args;
+    args.push_back(NewValueNode(prim::kPrimMakeTuple));
+
+    // {prim::kPrimPrint, Xs}
+    auto &inputs = cnode->inputs();
+    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
+
+    // {prim::kPrimMakeTuple, Xs}
+    auto fg = node->func_graph();
+    auto tuple = NewCNode(args, fg);
+    auto print = GetValueNode<PrimitivePtr>(cnode->input(0));
+    return NewCNode({NewValueNode(print), tuple}, fg);
+  }
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
new file mode 100644
index 0000000000..14fd8743ff
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h
@@ -0,0 +1,364 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ + +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "utils/symbolic.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class EnvGetitemTransform { + public: + EnvGetitemTransform() : cache_() {} + ~EnvGetitemTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + auto hash_key = std::make_pair(key, default_node); + if (cache.find(hash_key) == cache.end()) { + std::ostringstream ss("env", std::ostringstream::app); + if (key->node() != nullptr) { + ss << key->node()->ToString(); + } + + auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); + auto env = new_fg->output(); + while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { + // {prim::kPrimEnvSetItem, env, symbolickey, value} + auto &inputs = env->cast()->inputs(); + if (inputs.size() != 4 || !IsValueNode(inputs[2])) { + MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; + } + + env = inputs[1]; + auto value = inputs[3]; + auto key2 = GetValueNode(inputs[2]); + if (*key2 == *key) { + new_fg->set_output(value); + cache[hash_key] = new_fg; + cache_[fg] = cache; + return new_fg; + } + } + new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); + cache[hash_key] = new_fg; + } + + return cache[hash_key]; + } + + private: + std::unordered_map, FuncGraphPtr, PairHasher>> + cache_; +}; +} // namespace internal + +// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y +class NewEnvGetItem : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + auto gety = [this](const AnfNodePtr &node) -> bool { + this->y_ = node; + return true; + }; + + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); + if (env_ != nullptr && env_->Len() == 0) { + return y_; + } + return nullptr; + } + + void Visit(const ValueNodePtr &vnode) override { + if (env_ == nullptr) { + env_ = GetValueNode(vnode); + } + } + + void Reset() { + y_ = nullptr; + env_ = nullptr; + } + + private: + AnfNodePtr y_{nullptr}; + EnvInstancePtr env_{nullptr}; +}; + +// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> +// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} +class AddEnvGetItem : public AnfVisitor { + public: + AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} + ~AddEnvGetItem() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsAddCNode = [](const AnfNodePtr &node) -> bool { + return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); + + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Z} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto c = cnode->input(2); + auto z = 
cnode->input(3); + + // {prim::kPrimEnvAdd, X, Y} + auto x = inp1->input(1); + auto y = inp1->input(2); + + auto fg = node->func_graph(); + auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); + auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); + + return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + ValuePtr PrimHyperAdd_; +}; + +// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} +class EnvGetSetItem : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsSetCNode = [](const AnfNodePtr &node) -> bool { + if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { + return false; + } + + // {prim::kPrimEnvSetItem, X, C1, Y} + auto &inputs = node->cast()->inputs(); + if (inputs.size() != 4) { + return false; + } + + return IsValueNode(inputs[2]); + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); + + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C2, Z} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key2 = cnode->input(2); + auto c2 = GetValueNode(key2); + auto default_v = cnode->input(3); + + // {prim::kPrimEnvSetItem, X, C1, Y} + auto env = inp1->input(1); + auto c1 = GetValueNode(inp1->input(2)); + auto last_set = inp1->input(3); + + if (*c1 == *c2) { + return last_set; + } + + while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { + // {prim::kPrimEnvSetItem, env, symbolickey, value} + auto &inputs = env->cast()->inputs(); + if (inputs.size() != 4 || !IsValueNode(inputs[2])) { + MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; + } + + env = inputs[1]; + last_set = inputs[3]; + auto symbolic_c1 = GetValueNode(inputs[2]); + if (*symbolic_c1 == *c2) { + return last_set; + } + } + + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; +}; + +class EnvGetItemEliminater : public OptimizerCaller { + public: + EnvGetItemEliminater() + : new_env_get_item_(std::make_shared()), + add_env_get_item_(std::make_shared()), + env_get_set_item_(std::make_shared()) { + eliminaters_.emplace_back(new_env_get_item_); + eliminaters_.emplace_back(add_env_get_item_); + eliminaters_.emplace_back(env_get_set_item_); + } + ~EnvGetItemEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; + std::vector eliminaters_{}; +}; + +// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} +class IncorporateEnvGetitem : public AnfVisitor { + public: + IncorporateEnvGetitem() : env_get_item_transform_() {} + ~IncorporateEnvGetitem() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsGCNode = [](const AnfNodePtr &node) -> bool { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() < 1) { + return false; + } + return IsValueNode(cnode->input(0)); + }; + 
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); + + if (!is_match_) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Y} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key = GetValueNode(cnode->input(2)); + auto default_v = cnode->input(3); + + // {G, Xs} + auto inputs = inp1->inputs(); + auto fg = GetValueNode(inputs[0]); + auto new_fg = env_get_item_transform_(fg, key, default_v); + + std::vector args; + args.push_back(NewValueNode(new_fg)); + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + return node->func_graph()->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + internal::EnvGetitemTransform env_get_item_transform_; +}; + +// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} +class IncorporateEnvGetitemSwitch : public AnfVisitor { + public: + IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} + ~IncorporateEnvGetitemSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + is_match_ = false; + auto IsSwNode = [](const AnfNodePtr &node) -> bool { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() < 1) { + return false; + } + + return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); + }; + AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + // {prim::kPrimEnvGetItem, {...}, C, Y} + auto cnode = node->cast(); + auto inp1 = cnode->input(1)->cast(); + auto key = GetValueNode(cnode->input(2)); + auto default_v = cnode->input(3); + + // {{prim::kPrimSwitch, X, G1, G2}, Xs} + auto inputs = inp1->inputs(); + is_match_ = false; + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs[0]); + if (!is_match_) { + return nullptr; + } + + // {prim::kPrimSwitch, X, G1, G2} + auto sw = inputs[0]->cast(); + auto x = sw->input(1); + auto g1 = GetValueNode(sw->input(2)); + auto g2 = GetValueNode(sw->input(3)); + auto new_g1 = env_get_item_transform_(g1, key, default_v); + auto new_g2 = env_get_item_transform_(g2, key, default_v); + + auto fg = node->func_graph(); + auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); + + std::vector args{new_sw}; + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + + return fg->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override { is_match_ = true; } + + private: + bool is_match_{false}; + internal::EnvGetitemTransform env_get_item_transform_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc new file mode 100644 index 0000000000..44c1b62fa5 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/grad_var_prepare.h" +#include +#include +#include +#include + +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace irpass { +static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, FuncGraphPtr func_graph, + AnfNodePtr func_node, bool is_unpack, bool sens_param) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_node); + std::vector nodes; + AnfNodePtr unpack_graph_node = nullptr; + if (is_unpack) { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, true); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {unpackcall, {GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr &node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } else { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {{GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr &node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } + return unpack_graph_node; +} + +// get metagraph of value node +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { + ValuePtr value; + if (IsValueNode(node)) { + value = GetValueNode(node)->cast()->function(); + } else { + value = GetValueNode(node); + } + if (value == nullptr) { + return nullptr; + } + return value->cast(); +} + +// check if node is a specific metafuncgraph op +bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { + if (node != nullptr) { + auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); + if (meta_func_graph_ptr == nullptr) { + return false; + } + + if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { + return true; + } + } + return false; +} + +// {{GradOperation, g, w}, Ys} +// {UnPackCall, {GradOperation, g, w}, Ys} +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + // {{...}, Ys} + auto inputs_y = node->cast()->inputs(); + std::vector inputs_x; + if (IsCNode(inputs_y[0])) { + inputs_x = inputs_y[0]->cast()->inputs(); + } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { + inputs_x = inputs_y[1]->cast()->inputs(); + } else { + return nullptr; + } + + // {{...}, Xs} + if (inputs_x.size() < 2) { + return nullptr; + } + + // {GradOperation, g, w} or {GradOperation, g} + if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { + return nullptr; + } + + auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); + if (meta_func == nullptr) { + return nullptr; + } + auto 
grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
+  auto func_node = inputs_x[1];
+  if (!IsValueNode<FuncGraph>(func_node)) {
+    return nullptr;
+  }
+
+  AnfNodePtr unpack_graph_node =
+    GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
+                            IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
+  // construct the new GradOperation
+  inputs_x[1] = unpack_graph_node;
+  auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
+  if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
+    inputs_y[1] = grad_op_cnode;
+  } else {
+    inputs_y[0] = grad_op_cnode;
+  }
+  auto cnode = node->func_graph()->NewCNode(inputs_y);
+  return cnode;
+}
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h
new file mode 100644
index 0000000000..f6992a87c6
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
+
+#include <vector>
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "frontend/operator/composite/composite.h"
+#include "frontend/operator/ops.h"
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "ir/func_graph.h"
+#include "ir/func_graph_cloner.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {{GradOperation, g, w}, Ys}
+// {UnPackCall, {GradOperation, g, w}, Ys}
+class GradVarPrepare : public AnfVisitor {
+ public:
+  GradVarPrepare()
+      : grad_op_(std::make_shared<prim::GradOperation>("grad")),
+        unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
+  ~GradVarPrepare() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
+
+ private:
+  MetaFuncGraphPtr grad_op_;
+  MetaFuncGraphPtr unpack_op_;
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
new file mode 100644
index 0000000000..0d98cffa37
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/gradient_eliminate.h" + +#include + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { + ScopeGuard scope_guard(vnode->scope()); + + auto newg = ad::Kprim(vnode, resource); + if (newg != nullptr) { + return NewValueNode(newg); + } + + // when find in J failed, try in Jmeta + auto prim = GetValueNode(vnode); + MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); + if (meta != nullptr) { + return NewValueNode(meta); + } + + return nullptr; +} + +bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) { + // if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first; + auto func_graph_manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(func_graph_manager); + return func_graph_manager->func_graph_j_total(func_graph); +} + +AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { + if (IsValueNode(vnode)) { + ScopeGuard scope_guard(vnode->scope()); + + auto func_graph = GetValueNode(vnode); + MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); + + // high_order_grad begin; + // if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first; + if (CheckIfEmbedJFuncGraph(func_graph)) { + MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later"; + return nullptr; + } + // high_order_grad end; + + MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; + auto newfg = ad::Grad(func_graph, resource); + return NewValueNode(newfg); + } + + if (IsValueNode(vnode)) { + return ExpandJPrimitive(vnode, resource); + } + + return nullptr; +} +} // namespace internal +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h new file mode 100644 index 0000000000..82312d9e37 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
new file mode 100644
index 0000000000..82312d9e37
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/gradient_eliminate.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+#include <memory>
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "common/utils.h"
+#include "frontend/operator/ops.h"
+#include "frontend/optimizer/ad/grad.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+namespace internal {
+AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource);
+}  // namespace internal
+
+// {prim::kPrimJ, C}
+class ExpandJPrim : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    x_ = nullptr;
+    AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node);
+    if (x_ != nullptr) {
+      TraceManager::DebugTrace(std::make_shared<TraceExpandJ>(node->debug_info()));
+      auto j_node = internal::ExpandJ(x_, optimizer->resource());
+      TraceManager::EndTrace();
+      return j_node;
+    }
+    return nullptr;
+  }
+
+  void Visit(const ValueNodePtr &node) override { x_ = node; }
+
+ private:
+  ValueNodePtr x_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h
new file mode 100644
index 0000000000..2f6404458f
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_call.h
@@ -0,0 +1,208 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ + +#include +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class CallOutputTransform { + public: + CallOutputTransform() : cache_() {} + ~CallOutputTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + if (cache.find(nargs) == cache.end()) { + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("call")); + + std::vector new_items; + new_items.push_back(new_fg->output()); + for (size_t i = 0; i < nargs; i++) { + new_items.push_back(new_fg->add_parameter()); + } + new_fg->set_output(new_fg->NewCNode(new_items)); + + cache[nargs] = new_fg; + } + return cache[nargs]; + } + + private: + std::unordered_map> cache_; +}; +} // namespace internal + +// {{G, Xs}, Ys} +class IncorporateCall : public AnfVisitor { + public: + IncorporateCall() : call_output_transform_() {} + ~IncorporateCall() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs[0] == nullptr || !inputs[0]->isa()) { + return nullptr; + } + + AnfVisitor::Visit(inputs[0]); + if (fg_ == nullptr) { + return nullptr; + } + + auto xs_size = Xs_.size(); + auto ys_size = inputs.size() - 1; + auto new_fg = call_output_transform_(fg_, ys_size); + + std::vector args; + args.push_back(NewValueNode(new_fg)); + + if (xs_size > 0) { + (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); + } + + if (ys_size > 0) { + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + } + + return node->func_graph()->NewCNode(args); + } + + void Visit(const CNodePtr &cnode) override { + // {G, Xs} + if (cnode->size() < 1 || !IsValueNode(cnode->input(0))) { + return; + } + + auto &inputs = cnode->inputs(); + fg_ = GetValueNode(inputs[0]); + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + } + + void Reset() { + Xs_.clear(); + fg_ = nullptr; + } + + private: + FuncGraphPtr fg_; + std::vector Xs_{}; + internal::CallOutputTransform call_output_transform_; +}; + +// {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys} +class IncorporateCallSwitch : public AnfVisitor { + public: + IncorporateCallSwitch() : call_output_transform_() {} + ~IncorporateCallSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + // {{...}, Ys} + auto &inputs = node->cast()->inputs(); + if (inputs[0] == nullptr || !inputs[0]->isa()) { + return nullptr; + } + + // {{...}, Xs} + auto &inputs_x = inputs[0]->cast()->inputs(); + if (inputs_x[0] == nullptr || !inputs_x[0]->isa()) { + return nullptr; + } + + // {prim::kPrimSwitch, X, G1, G2} + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs_x[0]); + if (g2_ == nullptr) { + return nullptr; + } + + auto fg = node->func_graph(); + auto xs_size = inputs_x.size() - 1; + auto ys_size = inputs.size() - 1; + auto new_g1 
= call_output_transform_(g1_, ys_size); + auto new_g2 = call_output_transform_(g2_, ys_size); + auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); + + std::vector args{sw_node}; + if (xs_size > 0) { + (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); + } + if (ys_size > 0) { + (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + } + + return fg->NewCNode(args); + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + return; + } + AnfVisitor::Visit(node); + } + + void Visit(const ValueNodePtr &vnode) override { + auto g = GetValueNode(vnode); + if (g1_ == nullptr) { + g1_ = g; + } else { + g2_ = g; + } + } + + void Reset() { + x_ = nullptr; + g1_ = nullptr; + g2_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}; + FuncGraphPtr g1_{nullptr}, g2_{nullptr}; + internal::CallOutputTransform call_output_transform_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h new file mode 100644 index 0000000000..828e205e4f --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -0,0 +1,416 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ + +#include +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class GetitemTransform { + public: + GetitemTransform() : cache_() {} + ~GetitemTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { + if (cache_.find(fg) == cache_.end()) { + cache_[fg] = {}; + } + + auto &cache = cache_[fg]; + if (cache.find(idx) == cache.end()) { + std::ostringstream ss("tp", std::ostringstream::app); + ss << idx; + + auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); + auto output = new_fg->output(); + if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + auto cnode = output->cast(); + auto ids = IntToSize(idx + 1); + // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. 
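+        // Worked example: if fg returns MakeTuple(a, b, c), then transforming
+        // with idx = 1 rewires the clone's output to cnode->input(2) -- input 0
+        // is the MakeTuple primitive itself, hence the offset -- so the cloned
+        // graph simply returns b.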
+        if (ids >= cnode->size()) {
+          MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size();
+        }
+        new_fg->set_output(cnode->input(ids));
+      } else {
+        new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)}));
+      }
+
+      cache[idx] = new_fg;
+    }
+    return cache[idx];
+  }
+
+ private:
+  std::unordered_map<FuncGraphPtr, std::unordered_map<int, FuncGraphPtr>> cache_;
+};
+}  // namespace internal
+
+// {prim::kPrimTupleGetItem, {G, Xs}, C}
+class IncorporateGetitem : public AnfVisitor {
+ public:
+  IncorporateGetitem() : getitem_transform_() {}
+  ~IncorporateGetitem() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode<Int32Imm>})(node);
+    if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) {
+      return nullptr;
+    }
+
+    if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
+      // If the graph kernel has multiple real outputs, do not split it.
+      // Outputs that are EnvInstance or DeadCode nodes do not count, so such graph kernels may still be split.
+      auto output = fg_->output();
+      if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
+        auto output_cnode = output->cast<CNodePtr>();
+        auto outputs = output_cnode->inputs();
+        int real_output_cnt = 0;
+        for (size_t i = 1; i < outputs.size(); ++i) {
+          if (IsCNode(outputs[i]) || IsValueNode<tensor::Tensor>(outputs[i]) || IsParam(outputs[i])) {
+            real_output_cnt++;
+            if (real_output_cnt > 1) {
+              return nullptr;
+            }
+          }
+        }
+      }
+    }
+
+    auto new_fg = getitem_transform_(fg_, idx_);
+    (void)args_.insert(args_.begin(), NewValueNode(new_fg));
+    return node->func_graph()->NewCNode(args_);
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) {
+      return;
+    }
+
+    auto &inputs = cnode->inputs();
+    fg_ = GetValueNode<FuncGraphPtr>(inputs[0]);
+    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_));
+  }
+
+  void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); }
+
+  void Reset() {
+    idx_ = -1;
+    fg_ = nullptr;
+    args_.clear();
+  }
+
+ private:
+  int idx_{-1};
+  FuncGraphPtr fg_{nullptr};
+  std::vector<AnfNodePtr> args_{};
+  internal::GetitemTransform getitem_transform_;
+};
+
+class IncorporateGetitemFromParam : public AnfVisitor {
+ public:
+  void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &param, size_t input_idx) {
+    auto mng = func_graph->manager();
+    MS_EXCEPTION_IF_NULL(mng);
+    auto &node_users = mng->node_users();
+    if (node_users.find(param) == node_users.end() || node_users[param].empty()) {
+      args_.push_back(cnode->input(input_idx + 1));
+      return;
+    }
+
+    for (auto &user : node_users[param]) {
+      if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
+        // we do not process this case.
+        args_.push_back(cnode->input(input_idx + 1));
+        return;
+      }
+    }
+
+    // update new args.
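+    // Note on the two branches below: case 1 flattens a literal MakeTuple
+    // argument directly into args_; case 2 handles an argument produced by a
+    // graph-kernel call whose output is a MakeTuple, materializing one
+    // TupleGetItem per element so each piece can be passed separately.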
+ if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { + // case 1 + replace_parameters_[input_idx] = true; + need_update_ = true; + auto make_tuple_cnode = cnode->input(input_idx + 1)->cast(); + auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); + inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; + args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); + } else { + // case 2 + auto prev_cnode = cnode->input(input_idx + 1)->cast(); + auto prev_fg = GetValueNode(prev_cnode->input(0)); + auto fg_output = prev_fg->output(); + if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { + MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() + << " should be a make tuple, but got: " << fg_output->DebugString(); + return; + } + replace_parameters_[input_idx] = true; + need_update_ = true; + auto make_tuple_cnode = fg_output->cast(); + inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; + for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { + auto new_getitem = + func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); + auto aptr = std::make_shared(std::make_shared(SizeToInt(output_i))); + new_getitem->input(2)->set_abstract(aptr); + new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); + args_.push_back(new_getitem); + } + } + } + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (node->func_graph() == nullptr) { + return nullptr; + } + + Reset(); + + auto cnode = node->cast(); + if (cnode == nullptr) { + return nullptr; + } + auto &inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0]); + if (fg == nullptr) { + return nullptr; + } + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto parameters = fg->parameters(); + if (parameters.size() != inputs.size() - 1) { + return nullptr; + } + replace_parameters_ = std::vector(parameters.size(), false); + inputs_num_ = std::vector(parameters.size(), 1); + auto node_fg = node->func_graph(); + + for (size_t i = 1; i < inputs.size(); ++i) { + if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { + Process(node_fg, cnode, parameters[i - 1], i - 1); + } else { + args_.push_back(inputs[i]); + } + } + + if (!need_update_) { + return nullptr; + } + + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + auto node_users = mng->node_users(); + std::vector new_fg_parameters = new_fg->parameters(); + std::vector new_parameters; + size_t curr_input_idx{0}; + for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { + if (!replace_parameters_[param_i]) { + if (parameters[param_i]->abstract() != nullptr) { + new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); + } + new_parameters.push_back(new_fg_parameters[param_i]); + curr_input_idx++; + continue; + } + + // make a new parameter. + for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { + auto new_param = std::make_shared(new_fg); + new_param->set_abstract(args_.at(curr_input_idx)->abstract()); + + // update users of new parameter. 
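+        // Every user of the old tuple parameter is expected to be a
+        // TupleGetItem (anything else is reported as an error below); the
+        // getitem extracting element input_i is redirected to new_param.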
+ for (auto &user : node_users[new_fg_parameters[param_i]]) { + idx_ = -1; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode})(user.first); + if (idx_ == -1) { + MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() + << " must be tuple getitem here, but got: " << user.first->DebugString(); + return nullptr; + } + + if (input_i == IntToSize(idx_)) { + for (auto &sub_user : node_users[user.first]) { + auto sub_user_cnode = sub_user.first->cast(); + MS_EXCEPTION_IF_NULL(sub_user_cnode); + sub_user_cnode->set_input(sub_user.second, new_param); + (void)mng->Replace(sub_user.first, sub_user_cnode); + } + } + } + + // (void)mng->Replace(new_fg_parameters[param_i], new_param); + new_parameters.push_back(new_param); + curr_input_idx++; + } + } + + mng->SetParameters(new_fg, new_parameters); + (void)args_.insert(args_.begin(), NewValueNode(new_fg)); + auto new_call = node_fg->NewCNode(args_); + new_call->set_abstract(node->abstract()); + return new_call; + } + + void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } + + void Visit(const CNodePtr &cnode) override {} + + void Reset() { + replace_parameters_.clear(); + args_.clear(); + inputs_num_.clear(); + need_update_ = false; + idx_ = -1; + } + + private: + std::vector replace_parameters_{}; + std::vector args_{}; + std::vector inputs_num_{}; + bool need_update_{false}; + int idx_{-1}; +}; + +// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} +class IncorporateGetitemSwitch : public AnfVisitor { + public: + IncorporateGetitemSwitch() : getitem_transform_() {} + ~IncorporateGetitemSwitch() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + is_in_get_ = true; + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + is_in_get_ = false; + + auto fg = node->func_graph(); + if (idx_ == -1 || switch_ == nullptr || fg == nullptr) { + return nullptr; + } + + is_in_switch_ = true; + AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(switch_); + is_in_switch_ = false; + + if (g2_ == nullptr) { + return nullptr; + } + + auto new_g1 = getitem_transform_(g1_, idx_); + auto new_g2 = getitem_transform_(g2_, idx_); + auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); + (void)args_.insert(args_.begin(), sw_node); + + return fg->NewCNode(args_); + } + + void Visit(const AnfNodePtr &node) override { + if (is_in_switch_ && x_ == nullptr) { + x_ = node; + return; + } + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (is_in_get_ && cnode->size() != 0) { + auto &inputs = cnode->inputs(); + switch_ = inputs[0]; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (is_in_get_) { + idx_ = GetValue(vnode->value()); + } + + if (is_in_switch_) { + auto g = GetValueNode(vnode); + if (g1_ == nullptr) { + g1_ = g; + } else { + g2_ = g; + } + } + } + + void Reset() { + x_ = nullptr; + g1_ = nullptr; + g2_ = nullptr; + switch_ = nullptr; + args_.clear(); + is_in_get_ = false; + is_in_switch_ = false; + } + + private: + int idx_{-1}; + AnfNodePtr switch_{nullptr}, x_{nullptr}; + FuncGraphPtr g1_{nullptr}, g2_{nullptr}; + bool is_in_get_{false}, is_in_switch_{false}; + std::vector args_{}; + internal::GetitemTransform getitem_transform_; +}; + +class IncorporateGetitemSet : public OptimizerCaller { + public: 
+  IncorporateGetitemSet()
+      : incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
+        incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
+    eliminaters_.emplace_back(incorporate_getitem_);
+    eliminaters_.emplace_back(incorporate_getitem_switch_);
+  }
+  ~IncorporateGetitemSet() = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    AnfNodePtr new_node;
+    for (auto &eliminater : eliminaters_) {
+      new_node = (*eliminater)(optimizer, node);
+      if (new_node != nullptr) {
+        return new_node;
+      }
+    }
+    return nullptr;
+  }
+
+ private:
+  OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
+  std::vector<OptimizerCallerPtr> eliminaters_{};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h
new file mode 100644
index 0000000000..dfe345fe01
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/indexed_slices_eliminate.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
+class IndexedSlicesEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
+
+    if (is_match_) {
+      return tuple_->input(1);
+    }
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
+
+    if (is_match_) {
+      return tuple_->input(2);
+    }
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
+
+    if (is_match_) {
+      return tuple_->input(3);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
+      tuple_ = cnode;
+      is_match_ = true;
+    }
+  }
+
+  void Reset() {
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  CNodePtr tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
new file mode 100644
index 0000000000..8cafb268b4
--- /dev/null
+++ 
b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -0,0 +1,204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class ReplaceApplicator : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsValueNode(node)) { + return nullptr; + } + + auto fg = GetValueNode(node); + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { + return nullptr; + } + + auto out = fg->output(); + MS_EXCEPTION_IF_NULL(out); + if (!out->isa()) { + return nullptr; + } + + auto &inputs = out->cast()->inputs(); + auto params = fg->parameters(); + + // Exclude first elements of inputs which is fn. + auto input_size = inputs.size(); + auto param_size = params.size(); + if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size && + std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) { + auto inner = inputs[0]; + if (IsValueNode(inner) || + (IsValueNode(inner) && GetValueNode(inner)->parent() == nullptr)) { + return inner; + } + } + + return nullptr; + } +}; + +using CriterionFuncType = std::function; + +bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { + auto n_cnode = fg->nodes().size() - fg->parameters().size(); + // There is at least one CNode(return, other_node). + return n_cnode <= 2; +} + +bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { + auto &cnodes = fg->func_graph_cnodes_index(); + int n_use = + std::accumulate(cnodes.begin(), cnodes.end(), 0, + [](int sum, const std::pair &item) { return sum + item.second; }); + return n_use == 1; +} + +bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node->func_graph()); + return node->func_graph()->has_flag("inline_inside"); +} + +bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } + +bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } + +// {G, Xs} +class InlinerBase : public AnfVisitor { + public: + explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} + ~InlinerBase() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa()) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 1 || !IsValueNode(inputs[0])) { + return nullptr; + } + + // G + auto fg = GetValueNode(inputs[0]); + if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { + return nullptr; + } + // Do not inline GraphKernel to Cell. 
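+    // That is, a fused GraphKernel subgraph must stay a separate kernel; the
+    // only exception, handled below, is a subgraph containing nothing but its
+    // return node, which is trivially safe to inline.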
+ if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + // If the GraphKernel only contains a return node, we make it inlined. + if (fg->nodes().size() - fg->parameters().size() > 1) { + return nullptr; + } + } + + Reset(); + bool is_match = false; + for (auto &criterion : criterions_) { + if (!criterion.first(fg, node)) { + continue; + } + + if (criterion.second && IsRecursive(fg)) { + continue; + } + + is_match = true; + break; + } + + if (!is_match) { + return nullptr; + } + + std::vector params; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params)); + + if (IsUniqueUse(fg, nullptr)) { + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + ReplaceParams(mng, params, fg); + auto out_node = fg->output(); + mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); + return out_node; + } + + return InlineClone(fg, node->func_graph(), params, inputs[0]->scope()); + } + + void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector &new_params, + const FuncGraphPtr &fg) { + auto params = fg->parameters(); + auto old_size = params.size(); + if (old_size != new_params.size()) { + MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size() + << fg->output()->DebugString(10); + } + for (size_t i = 0; i < old_size; i++) { + (void)mng->Replace(params[i], new_params[i]); + } + } + + bool IsRecursive(const FuncGraphPtr &fg) { + if (!is_checked_) { + is_checked_ = true; + is_recursive_ = fg->recursive(); + } + return is_recursive_; + } + + void Reset() { + is_checked_ = false; + is_recursive_ = false; + } + + private: + bool is_checked_{false}, is_recursive_{false}; + std::vector> criterions_; +}; + +class Inliner : public InlinerBase { + public: + Inliner() + : InlinerBase({ + {IsUniqueUse, true}, + {IsTrivial, false}, + {IsInside, false}, + {IsCore, false}, + {NoCriterion, true}, + }) {} + ~Inliner() override = default; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h new file mode 100644 index 0000000000..acd6844ee7 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_eliminate.h @@ -0,0 +1,301 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "ir/optimizer_caller.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// (a, b, c, ...)[0] => a
+// (a, b, c, ...)[1] => b
+// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
+class GetitemEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
+
+    if (is_match_) {
+      return tuple_->input(id_);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
+      tuple_ = cnode;
+    }
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
+      id_ = IntToSize(GetValue<int>(vnode->value()) + 1);
+      if (tuple_->size() > id_) {
+        is_match_ = true;
+      }
+    }
+  }
+
+  void Reset() {
+    id_ = 0;
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  size_t id_{0};
+  CNodePtr tuple_{nullptr};
+};
+
+// (a, b, c, ...)[0] => a
+// (a, b, c, ...)[1] => b
+// {prim::kPrimTupleGetItem, C1, C}
+class GetitemConstEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);
+
+    if (is_match_) {
+      return NewValueNode((*tuple_)[id_]);
+    }
+    return nullptr;
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    if (IsValueNode<ValueTuple>(vnode)) {
+      tuple_ = GetValueNode<ValueTuplePtr>(vnode);
+    }
+    if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
+      id_ = IntToSize(GetValue<int>(vnode->value()));
+      if (tuple_->size() > id_) {
+        is_match_ = true;
+      }
+    }
+  }
+
+  void Reset() {
+    id_ = 0;
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  size_t id_{0};
+  ValueTuplePtr tuple_{nullptr};
+};
+
+// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
+// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
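+// (only the id_-th element of the rebuilt tuple changes: every other MakeTuple
+// input is reused unchanged via args_, see args_[id_] = z_ in the class below)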
+// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} +class SetitemEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && z_ != nullptr) { + args_[id_] = z_; + return fg->NewCNode(args_); + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (is_match_) { + z_ = node; + return; + } + + AnfVisitor::Visit(node); + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + auto &inputs = cnode->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (args_.size() > 0 && IsValueNode(vnode)) { + id_ = IntToSize(GetValue(vnode->value()) + 1); + if (id_ < args_.size()) { + is_match_ = true; + } + } + } + + void Reset() { + id_ = 0; + z_ = nullptr; + is_match_ = false; + args_.clear(); + } + + private: + bool is_match_{false}; + size_t id_{0}; + AnfNodePtr z_{nullptr}; + std::vector args_{}; +}; + +// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} +class GetSetitemEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { + if (key1_ == key2_) { + return last_; + } + return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_}); + } + return nullptr; + } + + void Visit(const CNodePtr &cnode) override { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { + if (cnode->size() < 4) { + return; + } + + tuple_ = cnode->input(1); + last_ = cnode->input(3); + + // key of setitem + is_in_set_ = true; + AnfVisitor::Visit(cnode->input(2)); + is_in_set_ = false; + } + } + + void Visit(const ValueNodePtr &vnode) override { + if (IsValueNode(vnode)) { + auto key = GetValue(vnode->value()); + if (is_in_set_) { + key1_ = key; + } else { + c2_ = vnode; + key2_ = key; + } + } + } + + void Reset() { + key1_ = -1; + key2_ = -1; + c2_ = nullptr; + last_ = nullptr; + tuple_ = nullptr; + is_in_set_ = false; + } + + private: + bool is_in_set_{false}; + int key1_{-1}, key2_{-1}; + AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; +}; + +// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> +// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} +class GetitemDependReorder : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); + if (x_ == nullptr) { + return nullptr; + } + + auto fg = node->func_graph(); + auto item_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), x_, c_}, fg); + return NewCNode({NewValueNode(prim::kPrimDepend), item_node, y_}, fg); + } + + void Visit(const CNodePtr &cnode) override { + // {prim::kPrimDepend, X, Y} + if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && cnode->size() == 3) { + x_ = cnode->input(1); + y_ = cnode->input(2); + } + } + + void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } + + void Reset() { + x_ = nullptr; + y_ = nullptr; + c_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; +}; + +class ItemTupleEliminater : 
public OptimizerCaller { + public: + ItemTupleEliminater() + : get_item_eliminater_(std::make_shared()), + get_item_const_eliminater_(std::make_shared()), + set_item_eliminater_(std::make_shared()), + get_set_item_eliminater_(std::make_shared()), + get_item_depend_reorder_(std::make_shared()) { + eliminaters_.emplace_back(get_item_eliminater_); + eliminaters_.emplace_back(get_item_const_eliminater_); + eliminaters_.emplace_back(set_item_eliminater_); + eliminaters_.emplace_back(get_set_item_eliminater_); + eliminaters_.emplace_back(get_item_depend_reorder_); + } + ~ItemTupleEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, + get_item_depend_reorder_; + std::vector eliminaters_{}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h new file mode 100644 index 0000000000..8d3839bd9e --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/mark_interface_fusion.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
+
+#include <string>
+#include <sstream>
+#include <unordered_map>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+#include "utils/graph_utils.h"
+#include "frontend/operator/composite/composite.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+
+static int count = 0;
+
+std::string GetFusionNumber() {
+  std::stringstream ss;
+  ss << std::setw(4) << std::setfill('0') << count;
+  std::string num = ss.str();
+  ++count;
+
+  return "_" + num;
+}
+
+// Mark CNodes which can be merged in kernel build
+class MarkInterfaceFusion : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) {
+      auto cnode = node->cast<CNodePtr>();
+      auto condition = cnode->input(1);
+      std::string cmp;
+      std::unordered_map<std::string, std::string> cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"},
+                                                               {"LessEqual", "LE"},    {"Less", "LT"},
+                                                               {"Equal", "EQ"},        {"NotEqual", "NE"}};
+      if (IsPrimitiveCNode(condition)) {
+        auto prim_name = GetCNodeFuncName(condition->cast<CNodePtr>());
+        if (cmp_list.count(prim_name) != 0) {
+          // Mark the Select node and its compare node
+          cmp = cmp_list[prim_name];
+          auto cnt = GetFusionNumber();
+          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition);
+          AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node);
+          for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+            if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) {
+              AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i));
+            }
+          }
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &) override {}
+
+ private:
+  AnfNodePtr y_{nullptr};
+};
+
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h
new file mode 100644
index 0000000000..a3cf6e2231
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h
@@ -0,0 +1,320 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} -> +// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}} +// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} -> +// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} +class MergeAddN : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + Reset(); + optimizer_ = optimizer; + is_outer_ = true; + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + if (!is_match_ || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + auto addn = NewValueNode(GetValueNode(cnode->input(0))); + + // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} + (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); + auto fg = node->func_graph(); + auto make_node = fg->NewCNode(args_); + + return fg->NewCNode({addn, make_node}); + } + + void Visit(const CNodePtr &cnode) override { + if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { + return; + } + + auto &inputs = cnode->inputs(); + + if (is_outer_) { + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_)); + + is_outer_ = false; + is_inner_ = true; + + // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]); + if (is_match_) { + if (!is_unique(inputs[1])) { + is_match_ = false; + return; + } + (void)Ys_.erase(Ys_.begin()); + (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); + (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); + return; + } + + // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back()); + if (is_match_) { + if (!is_unique(inputs.back())) { + is_match_ = false; + return; + } + Ys_.pop_back(); + (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); + (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); + return; + } + + return; + } + + if (is_inner_) { + is_match_ = true; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); + } + } + + bool is_unique(const AnfNodePtr &node) { + auto mng = optimizer_->resource()->manager(); + auto &node_users = mng->node_users(); + if (node_users.find(node) == node_users.end()) { + return false; + } + + size_t n_use = node_users[node].size(); + return n_use == 1; + } + + void Reset() { + Xs_.clear(); + Ys_.clear(); + args_.clear(); + is_inner_ = false; + is_outer_ = false; + is_match_ = false; + } + + private: + OptimizerPtr optimizer_{nullptr}; + std::vector Xs_{}, Ys_{}, args_{}; + bool is_inner_{false}, is_outer_{false}, is_match_{false}; +}; + +// {PrimAddN, {kPrimMakeTuple, Xs}} +class AddNZeroFilter : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); + + if (filtered_Xs_.empty() || node->func_graph() == nullptr) { + return nullptr; + } + + // if only two node in filtered_nodes, {make_tuple, x}. return x. 
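+    // (filtered_Xs_[0] always holds the MakeTuple primitive pushed in Visit(),
+    // so size 2 means exactly one summand survived the zeros_like filter.)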
+    if (filtered_Xs_.size() == 2) {
+      return filtered_Xs_[1];
+    }
+
+    // If only one node is left in filtered_nodes, every input was a zeros_like; return one of the inputs.
+    if (filtered_Xs_.size() == 1 && Xs_.size() > 0) {
+      return Xs_[0];
+    }
+
+    if (!has_zero_like_) {
+      return nullptr;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    auto addn = NewValueNode(GetValueNode(cnode->input(0)));
+    auto fg = node->func_graph();
+    auto make_tuple = fg->NewCNode(filtered_Xs_);
+    return fg->NewCNode({addn, make_tuple});
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
+      return;
+    }
+
+    auto &inputs = cnode->inputs();
+    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
+
+    // {kPrimMakeTuple, X1, X2, ...}
+    filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
+    for (auto &x : Xs_) {
+      if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
+        filtered_Xs_.push_back(x);
+      } else {
+        has_zero_like_ = true;
+      }
+    }
+  }
+
+  void Reset() {
+    Xs_.clear();
+    filtered_Xs_.clear();
+    has_zero_like_ = false;
+  }
+
+ private:
+  std::vector<AnfNodePtr> filtered_Xs_{}, Xs_{};
+  bool has_zero_like_{false};
+};
+
+// {PrimAddN, {kPrimMakeTuple, Xs}}
+// Akg does not support AddN(ValueNode, Tensor, ...); it is converted to TensorAdd.
+// case0: AddN(inputs)(inputs size < 2) -> error
+// case1: AddN(inputs)(all inputs are ValueNodes) -> error
+// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor)
+// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input)
+//        -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...))
+class AddNEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
+    MS_EXCEPTION_IF_NULL(fg);
+    auto mng = fg->manager();
+    MS_EXCEPTION_IF_NULL(mng);
+    if (fg->recursive()) {
+      return nullptr;
+    }
+
+    auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("fg"));
+    mng->AddFuncGraph(new_fg);
+    need_update_ = false;
+    bool changed;
+    do {
+      changed = Process(new_fg);
+    } while (changed);
+
+    if (!need_update_) {
+      return nullptr;
+    } else {
+      auto new_sx = inputs;
+      new_sx[0] = NewValueNode(new_fg);
+      return node->func_graph()->NewCNode(new_sx);
+    }
+  }
+
+  bool Process(const FuncGraphPtr &func_graph) {
+    auto mng = func_graph->manager();
+    MS_EXCEPTION_IF_NULL(mng);
+    auto nodes = TopoSort(func_graph->output());
+    bool changed = false;
+
+    for (size_t i = 0; i < nodes.size(); ++i) {
+      auto node = nodes[i];
+      if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
+        continue;
+      }
+
+      auto cnode = node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      auto &tuple_input = cnode->input(1);
+      MS_EXCEPTION_IF_NULL(tuple_input);
+      auto tuple_input_cnode = tuple_input->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_input_cnode);
+      auto &tuple_inputs = tuple_input_cnode->inputs();
+      if (tuple_inputs.size() < 3) {
+        // case0: inputs size < 2, error
+        MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2);
+      }
+
+      int valuenode_num =
+        std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) {
+          if (IsValueNode<tensor::Tensor>(node)) {
+            return accumulator + 1;
+          } else {
+            return accumulator;
+          }
+        });
+      if (IntToSize(valuenode_num) == tuple_inputs.size()) {
+        // case1: all inputs are ValueNodes, error
+        MS_EXCEPTION(ArgumentError) << "All inputs of AddN are ValueNodes. " << cnode->DebugString(2);
+      }
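+      // Past the two error cases above, at least one input is a non-constant
+      // tensor, so each rewrite below can anchor on a real tensor input.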
" << cnode->DebugString(2); + } + + if (tuple_inputs.size() == 3) { + // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) + MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); + ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); + std::vector new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], + tuple_inputs[2]}; + mng->Replace(node, func_graph->NewCNode(new_xs)); + changed = true; + continue; + } + + auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), + [](const AnfNodePtr &node) { return IsValueNode(node); }); + if (first_valuenode == tuple_inputs.end()) { + // no ValueNode input found. + continue; + } else { + // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) + std::vector make_tuple_new_xs{ + NewValueNode(prim::kPrimMakeTuple), + }; + std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), + [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { + if (node != *first_valuenode) { + make_tuple_new_xs.push_back(node); + } + }); + ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); + auto new_addn = func_graph->NewCNode( + {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); + ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); + auto new_add = + func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); + (void)mng->Replace(node, new_add); + changed = true; + continue; + } + } + + need_update_ = need_update_ || changed; + return changed; + } + + private: + bool need_update_{false}; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h new file mode 100644 index 0000000000..658a287234 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h
new file mode 100644
index 0000000000..658a287234
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/minmax_grad.h
@@ -0,0 +1,110 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
+
+#include <memory>
+#include <vector>
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+namespace internal {
+// check if node is MinimumGrad() or MaximumGrad()
+bool IsOriginMaxMinGrad(const AnfNodePtr &node) {
+  if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) {
+    return false;
+  }
+
+  auto cnode = node->cast<CNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  auto x_v = prim->GetAttr("grad_x");
+  auto y_v = prim->GetAttr("grad_y");
+  if (x_v == nullptr || y_v == nullptr || !x_v->isa<BoolImm>() || !y_v->isa<BoolImm>()) {
+    return false;
+  }
+
+  bool x = GetValue<bool>(x_v);
+  bool y = GetValue<bool>(y_v);
+  return x && y;
+}
+}  // namespace internal
+
+// {prim::kPrimTupleGetItem, {target_grad, Xs}, C}
+class MinMaximumGrad : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTupleGetItem, {internal::IsOriginMaxMinGrad, IsValueNode<Int32Imm>})(node);
+    if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    // check single use
+    auto mng = optimizer->resource()->manager();
+    auto &users = mng->node_users();
+    if (users.find(grad_) == users.end() || users[grad_].size() != 1) {
+      return nullptr;
+    }
+
+    // {target_grad, Xs}
+    auto &inputs = grad_->inputs();
+    auto prim = GetValueNode<PrimitivePtr>(inputs[0]);
+
+    auto new_prim = std::make_shared<Primitive>(prim->name());
+    new_prim->set_attr("grad_x", MakeValue(true));
+    new_prim->set_attr("grad_y", MakeValue(true));
+
+    if (idx_ == 0) {
+      new_prim->set_attr("grad_y", MakeValue(false));
+    }
+    if (idx_ == 1) {
+      new_prim->set_attr("grad_x", MakeValue(false));
+    }
+
+    std::vector<AnfNodePtr> args;
+    args.push_back(NewValueNode(new_prim));
+    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
+
+    auto fg = node->func_graph();
+    auto tuple = fg->NewCNode(args);
+
+    return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple, NewValueNode(MakeValue(idx_))});
+  }
+
+  void Visit(const CNodePtr &cnode) override { grad_ = cnode; }
+
+  void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue<int>(vnode->value()); }
+
+  void Reset() {
+    idx_ = -1;
+    grad_ = nullptr;
+  }
+
+ private:
+  int idx_{-1};
+  CNodePtr grad_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_
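Editor's note: the attribute rewrite in MinMaximumGrad above reduces to a tiny decision table on idx. A standalone sketch (the GradFlags helper is hypothetical, not the MindSpore API):

#include <cassert>
#include <utility>

// When only element idx of (dx, dy) is consumed, the regenerated grad
// primitive only needs to compute that side.
std::pair<bool, bool> GradFlags(int idx) {
  // idx == 0 -> keep grad_x only; idx == 1 -> keep grad_y only
  return {idx == 0, idx == 1};
}

int main() {
  assert((GradFlags(0) == std::pair<bool, bool>{true, false}));
  assert((GradFlags(1) == std::pair<bool, bool>{false, true}));
  return 0;
}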
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h
new file mode 100644
index 0000000000..999376e528
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/param_replace.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
+
+#include <memory>
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+#include "pipeline/jit/parse/parse.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+class ReplaceOldParam : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    if (!IsParam(node)) {
+      return nullptr;
+    }
+    auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
+    MS_EXCEPTION_IF_NULL(resource);
+
+    auto top_graph = resource->func_graph();  // parse::Parser::GetTopFuncGraph();
+    MS_EXCEPTION_IF_NULL(top_graph);
+
+    auto param_node = node->cast<ParameterPtr>();
+    if (!param_node->has_default() || node->func_graph() == top_graph) {
+      return nullptr;
+    }
+    auto para_name = param_node->name();
+    for (const auto &tnode : top_graph->parameters()) {
+      auto para = tnode->cast<ParameterPtr>();
+      if (para != nullptr && para->name() == para_name) {
+        return para;
+      }
+    }
+    return nullptr;
+  }
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
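Editor's note: the lookup in ReplaceOldParam above is a name-based unification against the top graph's parameters. A standalone sketch (toy Param struct and hypothetical FindTopParam helper, not the MindSpore API):

#include <cassert>
#include <string>
#include <vector>

struct Param {
  std::string name;
  bool has_default;
};

// A defaulted parameter in a sub-graph is re-bound to the top graph's
// parameter of the same name; others are left alone.
const Param *FindTopParam(const std::vector<Param> &top_params, const Param &p) {
  if (!p.has_default) return nullptr;  // only defaulted params are re-bound
  for (const auto &tp : top_params) {
    if (tp.name == p.name) return &tp;
  }
  return nullptr;
}

int main() {
  std::vector<Param> top = {{"w", true}, {"b", true}};
  Param w_copy{"w", true};
  assert(FindTopParam(top, w_copy) == &top[0]);
  return 0;
}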
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h
new file mode 100644
index 0000000000..32fc5abc7d
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/partial_eliminate.h
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys}
+class PartialEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    Xs_.clear();
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    Visit(inputs[0]);
+
+    if (Xs_.size() == 0) {
+      return nullptr;
+    }
+
+    // {X, Xs, Ys}
+    std::vector<AnfNodePtr> args{};
+    (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
+    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
+    TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
+    auto new_node = node->func_graph()->NewCNode(args);
+    TraceManager::EndTrace();
+    return new_node;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
+      return;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    // {prim::kPrimPartial, X, Xs}
+    if (inputs.size() < 2) {
+      return;
+    }
+
+    // fill Xs
+    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
+  }
+
+ private:
+  std::vector<AnfNodePtr> Xs_{};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h
new file mode 100644
index 0000000000..d8c96825c9
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/prim_eliminate.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim, X}
+class PrimEliminater : public AnfVisitor {
+ public:
+  explicit PrimEliminater(const PrimitivePtr &prim) : prim_(prim) {}
+  ~PrimEliminater() override = default;
+
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    x_ = nullptr;
+    AnfVisitor::Match(prim_, {IsNode})(node);
+    return x_;
+  }
+
+  void Visit(const AnfNodePtr &node) override { x_ = node; }
+
+ private:
+  AnfNodePtr x_{nullptr};
+  PrimitivePtr prim_;
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h
new file mode 100644
index 0000000000..78b7d3f4f1
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/reduce_eliminate.h
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ + +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace opt { +namespace irpass { +using abstract::Shape; +using abstract::ShapePtr; + +// {ReduceLike, X, axis} +class ReduceOneEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + PrimitivePtr prim; + if (IsPrimitiveCNode(node, prim::kPrimReduceMean) || IsPrimitiveCNode(node, prim::kPrimReduceAll) || + IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimReduceMax) || + IsPrimitiveCNode(node, prim::kPrimReduceMin)) { + prim = GetValueNode(node->cast()->input(0)); + AnfVisitor::Match(prim, {IsNode, IsVNode})(node); + if (!is_axis_one_) { + return nullptr; + } + + // consider keep_dims + auto keep_dims = prim->GetAttr("keep_dims"); + auto is_keep_dims = GetValue(keep_dims); + // {_Reduce, X, axis} -> X + if (is_keep_dims) { + return x_; + } + + // {_Reduce, Tensor} + if (is_tensor_) { + return nullptr; + } + + // {_Reduce, X, axis} -> {Reshape, X, new_shape} + std::vector elements; + for (size_t i = 0; i < x_shape_.size(); i++) { + auto iter = find(axis_.begin(), axis_.end(), i); + if (iter == axis_.end()) { + ValuePtr s = MakeValue(x_shape_[i]); + elements.push_back(s); + } + } + auto new_shape = std::make_shared(elements); + auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast(); + return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (!IsVNode(node) && x_ == nullptr) { + if (IsValueNode(node)) { + is_tensor_ = true; + } + // get X's shape + auto x_shape_abs = node->abstract(); + if (x_shape_abs != nullptr) { + auto x_track = x_shape_abs->GetShapeTrack()->cast(); + if (x_track == nullptr) { + return; + } + auto x_shape = x_track->shape(); + (void)std::copy(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_)); + x_ = node; + } + return; + } + + // check axis + AnfVisitor::Visit(node); + } + + void Visit(const ValueNodePtr &vnode) override { + if (x_shape_.empty()) { + return; + } + + // axis : int + if (IsValueNode(vnode)) { + auto idx = GetValue(vnode->value()); + // axis could be negative + if (idx < 0) { + idx += SizeToInt(x_shape_.size()); + } + if (SizeToInt(x_shape_.size()) > idx && x_shape_[IntToSize(idx)] == 1) { + is_axis_one_ = true; + axis_.push_back(idx); + } + return; + } + + // axis : tuple(int), default () + if (IsValueNode(vnode)) { + auto axis = GetValue>(vnode->value()); + if (axis.empty()) { + return; + } + + auto cmp = std::all_of(axis.cbegin(), axis.cend(), [this](int idx) { + // axis could be negative + if (idx < 0) { + idx += SizeToInt(x_shape_.size()); + } + return SizeToInt(this->x_shape_.size()) > idx && this->x_shape_[IntToSize(idx)] == 1; + }); + if (cmp) { + is_axis_one_ = true; + (void)std::copy(axis.begin(), axis.end(), std::back_inserter(axis_)); + } + } + } + + void Reset() { + axis_.clear(); + x_shape_.clear(); + x_ = nullptr; + is_axis_one_ = false; + is_tensor_ = false; + } + + private: + bool is_axis_one_{false}, is_tensor_{false}; + std::vector axis_{}, x_shape_{}; + AnfNodePtr x_{nullptr}; +}; +} // namespace irpass +} // 
namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h new file mode 100644 index 0000000000..86eb4e761d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/ref_eliminate.h @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ + +#include + +#include "ir/pattern_matcher.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +// {prim::kPrimMakeRef, X, Y, Z} -> Y +class MakeRefEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); + return nullptr; + } +}; + +// {prim::kPrimGetRefValue, Parameter} -> Parameter +// {prim::kPrimGetRefOrigin, Parameter} -> Parameter +class GetRefParamEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); + return nullptr; + } +}; + +// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X +// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y +// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z +class GetMakeRefEliminater : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, y, z; + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); + + return nullptr; + } +}; + +// IsValueNode +class ReplaceRefkeyByParam : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { + auto refkey = GetValueNode(node); + auto resource = std::dynamic_pointer_cast(optimizer->resource()); + MS_EXCEPTION_IF_NULL(resource); + + auto top_graph = resource->func_graph(); + MS_EXCEPTION_IF_NULL(top_graph); + + for (const auto &tnode : top_graph->parameters()) { + auto para = tnode->cast(); + if (para != nullptr && para->name() == refkey->tag()) { + return para; + } + } + return nullptr; + }; + PatternNode x; + MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode, node)); + return 
nullptr; + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h new file mode 100644 index 0000000000..27d4bdad3d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/reshape_eliminate.h @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ + +#include + +#include "ir/func_graph.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "abstract/dshape.h" + +namespace mindspore { +namespace opt { +namespace irpass { +using abstract::Shape; +using abstract::ShapePtr; + +// {reshape_op, X, Shape} +class ReshapeSameShapeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimReshape, {IsNode, IsVNode})(node); + + // check pattern match + if (shape_ == nullptr) { + return nullptr; + } + + auto src_shape_abs = x_->abstract(); + if (src_shape_abs == nullptr) { + return nullptr; + } + + auto src_shape = src_shape_abs->GetShapeTrack(); + auto tgt_shape_abs = node->abstract(); + if (tgt_shape_abs == nullptr) { + return nullptr; + } + auto tgt_shape = tgt_shape_abs->GetShapeTrack(); + if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa() && tgt_shape->isa()) { + auto elements = tgt_shape->cast(); + auto shape = src_shape->cast(); + if (shape->shape() == elements->shape()) { + return x_; + } + } + + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } else { + shape_ = node; + } + } + + void Reset() { + x_ = nullptr; + shape_ = nullptr; + } + + private: + AnfNodePtr x_{nullptr}, shape_{nullptr}; +}; + +// {PrimReshape, {PrimReshape, X, Y}, Shape} +class TwoReshapeEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + Reset(); + AnfVisitor::Match(prim::kPrimReshape, {IsCNode, IsNode})(node); + + auto fg = node->func_graph(); + if (fg != nullptr && x_ != nullptr && shape_ != nullptr) { + auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_}); + new_node->set_abstract(node->abstract()); + return new_node; + } + return nullptr; + } + + void Visit(const AnfNodePtr &node) override { + if (IsPrimitiveCNode(node, prim::kPrimReshape)) { + auto &inputs = node->cast()->inputs(); + // {PrimReshape, X, Y} + if (inputs.size() != 3) { + return; + } + prim_ = GetValueNode(inputs[0]); + x_ = inputs[1]; + } else { + shape_ = node; + } + } + + void Reset() { + prim_ = nullptr; + x_ = nullptr; + shape_ = 
nullptr; + } + + private: + PrimitivePtr prim_{nullptr}; + AnfNodePtr x_{nullptr}, shape_{nullptr}; +}; + +class ReshapeEliminater : public OptimizerCaller { + public: + ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} + ~ReshapeEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + auto new_node = reshape_same_shape_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + new_node = two_reshape_eliminater_(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + + return nullptr; + } + + private: + ReshapeSameShapeEliminater reshape_same_shape_eliminater_; + TwoReshapeEliminater two_reshape_eliminater_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h new file mode 100644 index 0000000000..01efa85e8d --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -0,0 +1,210 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ + +#include +#include +#include +#include + +#include "ir/optimizer_caller.h" +#include "ir/pattern_matcher.h" +#include "ir/visitor.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/prim_eliminate.h" +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace opt { +namespace irpass { +class SpecialOpEliminater : public OptimizerCaller { + public: + SpecialOpEliminater() + : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), + stop_gradient_(std::make_shared(prim::kPrimStopGradient)), + hook_backward_(std::make_shared(prim::kPrimHookBackward)), + print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), + get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), + mirror_(std::make_shared(prim::kPrimMirror)), + virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { + eliminaters_.emplace_back(insert_gradient_of_); + eliminaters_.emplace_back(stop_gradient_); + eliminaters_.emplace_back(hook_backward_); + eliminaters_.emplace_back(print_shape_type_); + eliminaters_.emplace_back(get_ref_value_); + eliminaters_.emplace_back(mirror_); + eliminaters_.emplace_back(virtual_div_); + } + ~SpecialOpEliminater() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; + } + + private: + OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, + virtual_div_; + std::vector eliminaters_{}; +}; + +// {PrimVirtualDataset, X} -> X +// {PrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs} +class VirtualDatasetEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 1) { + return nullptr; + } + + std::vector args; + (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); + if (args.size() == 1) { + return args.front(); + } + + (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple)); + + return node->func_graph()->NewCNode(args); + } + + void Visit(const AnfNodePtr &) override {} +}; + +// {prim::kPrimSameTypeShape, X, Y} -> X +class SameEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } + } + + private: + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimCheckBprop, X, Y} -> X +class CheckBpropEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + x_ = nullptr; + AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); + return x_; + } + + void Visit(const AnfNodePtr &node) override { + if (x_ == nullptr) { + x_ = node; + } + } + + private: + AnfNodePtr x_{nullptr}; +}; + +// Reset defer_inline flag +class ResetDeferInline : public AnfVisitor { + public: + AnfNodePtr 
operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (IsValueNode(node)) { + auto fg = GetValueNode(node); + fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); + } + return nullptr; + } +}; + +// {PrimZerosLike, Y} -> +// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0} +class ZeroLikeFillZero : public AnfVisitor { + public: + ZeroLikeFillZero() + : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast()), + PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast()), + PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast()) {} + ~ZeroLikeFillZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + y_ = nullptr; + AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node); + if (y_ == nullptr || node->func_graph() == nullptr) { + return nullptr; + } + if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { + auto fg = node->func_graph(); + auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); + auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); + return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); + } + + abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); + + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + (void)memset_s(data, mem_size, 0, mem_size); + + auto new_cnode = NewValueNode(new_tensor_ptr); + new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); + + return new_cnode; + } + + void Visit(const AnfNodePtr &node) override { y_ = node; } + + private: + AnfNodePtr y_{nullptr}; + PrimitivePtr PrimFill_, PrimShape_, PrimDType_; +}; + +// {prim::kPrimDepend, X, ValueCond}->X +class DependValueElim : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, cond; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); + return nullptr; + } +}; + +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h new file mode 100644 index 0000000000..d8a15f6d83 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/specialize_transform.h @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ + +#include +#include +#include +#include +#include +#include + +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/manager.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace opt { +namespace irpass { +namespace internal { +class SpecializeTransform { + public: + SpecializeTransform() : cache_() {} + ~SpecializeTransform() = default; + + FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, + std::vector prim_args, std::vector value_args) { + if (cache_.count(func_graph) == 0) { + cache_[func_graph] = {}; + } + + auto &cache = cache_[func_graph]; + auto key = std::make_pair(graph_args, prim_args); + if (cache.count(key) == 0) { + auto mng = func_graph->manager(); + MS_EXCEPTION_IF_NULL(mng); + + FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + std::vector params = new_fg->parameters(); + std::vector new_params; + size_t n = graph_args.size(); + for (size_t i = 0; i < n; i++) { + if (graph_args[i] != nullptr) { + auto arg = NewValueNode(graph_args[i]); + (void)mng->Replace(params[i], arg); + continue; + } + if (prim_args[i] != nullptr) { + auto arg = NewValueNode(prim_args[i]); + (void)mng->Replace(params[i], arg); + continue; + } + if (value_args[i] != nullptr) { + auto &const_tensor = *value_args[i]; + auto const_tensor_ptr = std::make_shared(const_tensor); + AnfNodePtr arg = NewValueNode(const_tensor_ptr); + (void)mng->Replace(params[i], arg); + continue; + } + new_params.push_back(params[i]); + } + + mng->SetParameters(new_fg, new_params); + cache[key] = new_fg; + } + return cache[key]; + } + + private: + std::unordered_map, std::vector>, FuncGraphPtr>> + cache_; +}; +} // namespace internal + +// {G, Xs} +class SpecializeOnGraphArguments : public AnfVisitor { + public: + SpecializeOnGraphArguments() : specialize_transform_() {} + ~SpecializeOnGraphArguments() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[0])) { + return nullptr; + } + + auto inp0_fg = GetValueNode(inputs[0]); + if (inp0_fg->recursive()) { + return nullptr; + } + + std::vector graph_args; + std::vector prim_args; + std::vector value_node_args; + std::vector new_xs; + bool hasVNode = false; + for (size_t i = 1; i < inputs.size(); i++) { + if (IsValueNode(inputs[i])) { + auto fg_vnode = GetValueNode(inputs[i]); + graph_args.push_back(fg_vnode); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(nullptr); + hasVNode = true; + } else if (IsValueNode(inputs[i])) { + auto p_vnode = GetValueNode(inputs[i]); + graph_args.emplace_back(nullptr); + prim_args.push_back(p_vnode); + value_node_args.emplace_back(nullptr); + hasVNode = true; + } else if (IsValueNode(inputs[i])) { + tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); + graph_args.emplace_back(nullptr); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(t_vnode); + hasVNode = true; + } else { + graph_args.emplace_back(nullptr); + prim_args.emplace_back(nullptr); + value_node_args.emplace_back(nullptr); + new_xs.push_back(inputs[i]); + } + } + + if (!hasVNode) { + 
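+      // No constant argument was found: none of the inputs is a FuncGraph, Primitive
+      // or Tensor value node, so there is nothing to bake into a specialized clone.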
return nullptr; + } + + auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); + (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); + + return node->func_graph()->NewCNode(new_xs); + } + + private: + internal::SpecializeTransform specialize_transform_; +}; + +// Eliminate unused parameters. +// {G, Xs} +class UnusedParasEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + auto fg = GetValueNode(inputs[0]); + MS_EXCEPTION_IF_NULL(fg); + + std::vector parameters = fg->parameters(); + size_t size = parameters.size(); + if (size != inputs.size() - 1) { + return nullptr; + } + + std::vector new_xs; + std::vector keep_parameters; + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + auto &node_users = mng->node_users(); + bool has_unused_para = false; + for (size_t i = 0; i < size; ++i) { + auto iter = node_users.find(parameters[i]); + if (iter != node_users.end() && !iter->second.empty()) { + keep_parameters.push_back(true); + new_xs.push_back(inputs[i + 1]); + continue; + } + keep_parameters.push_back(false); + has_unused_para = true; + } + + if (!has_unused_para) { + return nullptr; + } + FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); + mng->AddFuncGraph(new_fg); + + std::vector new_fg_parameters = new_fg->parameters(); + std::vector new_parameters; + for (size_t i = 0; i < size; i++) { + if (keep_parameters[i]) { + if (parameters[i]->abstract() != nullptr) { + new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); + } + new_parameters.push_back(new_fg_parameters[i]); + } + } + mng->SetParameters(new_fg, new_parameters); + + (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); + return node->func_graph()->NewCNode(new_xs); + } +}; + +// Eliminate unused outputs. +// {G, Xs} +class UnusedOutputEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + auto fg = GetValueNode(inputs[0]); + MS_EXCEPTION_IF_NULL(fg); + auto mng = fg->manager(); + MS_EXCEPTION_IF_NULL(mng); + if (fg->recursive()) { + return nullptr; + } + + auto new_fg = TransformableClone(fg, std::make_shared("fg")); + mng->AddFuncGraph(new_fg); + auto new_fg_output = new_fg->output(); + if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { + return nullptr; + } + + auto output_cnode = new_fg_output->cast(); + auto &node_users = mng->node_users(); + if (node_users.count(node) == 0 || node_users[node].empty()) { + return nullptr; + } + std::unordered_set used_output_idx; + std::vector> all_users; + for (auto &node_user : node_users[node]) { + if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { + return nullptr; + } + auto user_cnode = node_user.first->cast(); + size_t used_idx = GetValue(user_cnode->input(2)->cast()->value()); + used_output_idx.insert(used_idx); + all_users.push_back(std::make_pair(node_user.first, used_idx)); + } + + if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { + // all output has users. + return nullptr; + } + + if (used_output_idx.empty()) { + // we do not process this case. 
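+      // (For orientation, the two handled cases below: if only tuple_getitem(%call, 1)
+      //  of make_tuple(a, b, c) is ever used, the clone's output becomes plain `b` and
+      //  each getitem user is replaced by the call itself; if several indices such as
+      //  {0, 2} are used, a smaller make_tuple is built and every getitem index is
+      //  remapped through new_idx_map.)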
+ return nullptr; + } else if (used_output_idx.size() == 1) { + // after eliminate, only one output left. + new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); + // update users. + for (auto &ret_user : all_users) { + (void)mng->Replace(ret_user.first, node); + } + } else { + // after eliminate, create new multi output. + std::vector new_output_inputs{output_cnode->input(0)}; + std::unordered_map new_idx_map; + for (auto idx : used_output_idx) { + new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); + new_output_inputs.push_back(output_cnode->input(idx + 1)); + } + new_fg->set_output(new_fg->NewCNode(new_output_inputs)); + // update users. + for (auto &ret_user : all_users) { + auto ret_user_cnode = ret_user.first->cast(); + ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); + } + } + + auto new_sx = inputs; + new_sx[0] = NewValueNode(new_fg); + return node->func_graph()->NewCNode(new_sx); + } +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h new file mode 100644 index 0000000000..de9e533550 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
+
+#include <string>
+#include <memory>
+
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/irpass.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+#include "pipeline/jit/parse/data_converter.h"
+#include "pipeline/jit/parse/python_adapter.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimResolve, Ns, Sym}
+class ResolverResolve : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node);
+    if (sym_ != nullptr) {
+      return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
+    }
+    return nullptr;
+  }
+
+  void Visit(const ValueNodePtr &vnode) override {
+    if (IsValueNode<parse::NameSpace>(vnode)) {
+      ns_ = GetValueNode<parse::NameSpacePtr>(vnode);
+    } else if (ns_ != nullptr && IsValueNode<parse::Symbol>(vnode)) {
+      sym_ = GetValueNode<parse::SymbolPtr>(vnode);
+    }
+  }
+
+  void Reset() {
+    ns_ = nullptr;
+    sym_ = nullptr;
+  }
+
+ private:
+  parse::NameSpacePtr ns_{nullptr};
+  parse::SymbolPtr sym_{nullptr};
+};
+
+// {prim::kPrimGetAttr, Ns, Str}
+class ResolverGetattr : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node);
+    if (sym_ != nullptr) {
+      return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
+    }
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (IsValueNode<parse::NameSpace>(node)) {
+      ns_ = GetValueNode<parse::NameSpacePtr>(node);
+    } else if (ns_ != nullptr && IsValueNode<StringImm>(node)) {
+      auto str = GetValue<std::string>(GetValueNode(node));
+      sym_ = std::make_shared<parse::Symbol>(str);
+    }
+  }
+
+  void Reset() {
+    ns_ = nullptr;
+    sym_ = nullptr;
+  }
+
+ private:
+  parse::NameSpacePtr ns_{nullptr};
+  parse::SymbolPtr sym_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h
new file mode 100644
index 0000000000..f561e04c10
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/tile_eliminate.h
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// Check whether the node is a value tuple whose elements are all one, e.g. (1, 1, 1)
+// {PrimTile, X, MultiOne}
+class TileMultiplyByOne : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTile, {IsNode, IsVNode})(node);
+
+    // check pattern match
+    if (tuple_ == nullptr) {
+      return nullptr;
+    }
+
+    auto value = GetValueNode(tuple_);
+    auto elements = GetValue<std::vector<int>>(value);
+    if (elements.empty()) {
+      return nullptr;
+    }
+
+    auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; });
+    if (cmp) {
+      return x_;
+    }
+
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (x_ == nullptr) {
+      x_ = node;
+    } else {
+      tuple_ = node;
+    }
+  }
+
+  void Reset() {
+    x_ = nullptr;
+    tuple_ = nullptr;
+  }
+
+ private:
+  AnfNodePtr x_{nullptr}, tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h
new file mode 100644
index 0000000000..70b8898462
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/transpose_eliminate.h
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// Check whether the node is a value tuple that ascends one by one from zero, e.g., (0, 1, 2, 3)
+// {PrimTranspose, X, AscendingNums}
+class TransposeSameIOEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimTranspose, {IsNode, IsVNode})(node);
+
+    // check pattern match
+    if (tuple_ == nullptr) {
+      return nullptr;
+    }
+
+    auto value = GetValueNode(tuple_);
+    auto elements = GetValue<std::vector<int>>(value);
+    if (elements.empty()) {
+      return nullptr;
+    }
+
+    int j = 0;
+    bool cmp = std::all_of(elements.cbegin(), elements.cend(), [&j](int i) { return i == j++; });
+    // same IO settings, eliminate this transpose
+    if (cmp) {
+      return x_;
+    }
+
+    return nullptr;
+  }
+
+  void Visit(const AnfNodePtr &node) override {
+    if (x_ == nullptr) {
+      x_ = node;
+    } else {
+      tuple_ = node;
+    }
+  }
+
+  void Reset() {
+    x_ = nullptr;
+    tuple_ = nullptr;
+  }
+
+ private:
+  AnfNodePtr x_{nullptr}, tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc
new file mode 100644
index 0000000000..44917106fa
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/opt.cc
@@ -0,0 +1,241 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "frontend/optimizer/opt.h" + +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/manager.h" +#include "frontend/optimizer/optimizer.h" +#include "utils/log_adapter.h" +#include "utils/ordered_set.h" + +namespace mindspore { +/* namespace to support opt */ +namespace opt { +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &renorm_action) { + auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; + return std::make_shared(transform, name, fn, renorm_action); +} + +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const std::vector &prims, const RenormAction &renorm_action) { + auto fn = [prims](const AnfNodePtr &node) -> bool { + if (!node->isa()) { + return false; + } + + auto cnode = node->cast(); + auto inp0 = cnode->input(0); + auto prim0 = GetValueNode(inp0); + if (prim0 == nullptr) { + return false; + } + + auto hash = prim0->Hash(); + auto const &name = prim0->name(); + for (auto &prim : prims) { + if (hash == prim->Hash() && name == prim->name()) { + return true; + } + } + return false; + }; + + return std::make_shared(transform, name, fn, renorm_action); +} + +SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &renorm_action) { + return std::make_shared(transform, name, predicate, renorm_action); +} + +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { +#ifdef ENABLE_PROFILE + double t = GetTime(); +#endif + AnfNodePtr result = (*transform_)(optimizer, node); +#ifdef ENABLE_PROFILE + if (optimizer != nullptr) { + auto time = GetTime(); + MsProfile::StatTime("substitution." + name_, time - t); + if (result != nullptr) { + MsProfile::StatTime("match." + name_, time - t); + } + } +#endif + if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { + if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) { + optimizer->set_is_untyped_generated(); + } + } + + return result; +} + +static bool isTraversable(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->isa() || node->isa()) { + return true; + } + if (IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} + +bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, + const SubstitutionPtr &transform) const { +#ifdef ENABLE_PROFILE + double start = GetTime(); +#endif + FuncGraphManagerPtr manager = optimizer->manager(); + auto seen = NewSeenGeneration(); + // 1024 is for the initial capacity of deque + std::deque todo(1024); + todo.clear(); + todo.push_back(root_node); + bool changes = false; + + auto &all_nodes = manager->all_nodes(); + while (!todo.empty()) { + AnfNodePtr node = todo.front(); + todo.pop_front(); + + // check whether this node has been matched. + if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { + continue; + } + node->seen_ = seen; + + // select nodes that this transform can be applied. 
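+    // (How a Substitution is usually wired up -- a hypothetical sketch, not code from
+    //  this file:
+    //    auto sub = MakeSubstitution(std::make_shared<ReshapeEliminater>(),
+    //                                "reshape_eliminate", prim::kPrimReshape);
+    //  predicate_ then answers "is this a Reshape CNode?" and operator() does the rewrite.)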
+ bool is_match = transform->predicate_(node); + + // apply transform on this node + bool change = false; + if (is_match) { + auto ret = (*transform)(optimizer, node); + if (ret != nullptr && ret != node) { + change = true; + changes = true; +#ifdef ENABLE_PROFILE + double t = GetTime(); +#endif + (void)manager->Replace(node, ret); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("replace." + transform->name_, GetTime() - t); +#endif + node = ret; + } + } + + // find success, and add them to todo list + if (IsValueNode(node)) { + todo.push_back(GetValueNode(node)->output()); + } + + if (node->isa()) { + auto &inputs = node->cast()->inputs(); + (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); + } + + auto &node_users = manager->node_users(); + if (change && node_users.find(node) != node_users.end()) { + for (auto &use : node_users[node]) { + auto use_node = use.first; + if (use_node == nullptr) { + continue; + } + todo.push_back(use_node); + if (use_node->seen_ == seen) { + use_node->seen_--; + } + } + } + } + +#ifdef ENABLE_PROFILE + MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start); +#endif + return changes; +} + +bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { + MS_EXCEPTION_IF_NULL(optimizer); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = optimizer->manager(); + manager->AddFuncGraph(func_graph); + + // for transform status counting + size_t space = 0; + std::unordered_map> status; + if (optimizer->is_on_debug_) { + for (size_t i = 0; i < list_.size(); i++) { + status[list_[i]->name_ + std::to_string(i)] = {}; + } + } + + bool loop = false; + bool changes = false; + + do { + loop = false; + for (size_t i = 0; i < list_.size(); i++) { + auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]); + changes = changes || change; + loop = loop || change; + + // record the status of each transform + if (optimizer->is_on_debug_) { + status[list_[i]->name_ + std::to_string(i)].push_back(change); + space = std::max(list_[i]->name_.size(), space); + } + } + + if (is_once_) { + break; + } + } while (loop); + + // display the status of each transform + if (optimizer->is_on_debug_) { + std::stringstream ss; + ss << std::endl + << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name + << std::endl; + for (size_t i = 0; i < list_.size(); i++) { + auto name = list_[i]->name_; + ss << std::left << std::setw(space + 4) << name << "\t"; + for (auto change : status[name + std::to_string(i)]) { + ss << change << " "; + } + ss << std::endl; + } + MS_LOG(DEBUG) << ss.str(); + } + + return changes; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h new file mode 100644 index 0000000000..f440cc71dc --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/opt.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "ir/optimizer_caller.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+/* namespace to support opt */
+namespace opt {
+
+// Define the interaction mode between an Optimize pass and Renormalize pass
+// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
+// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executed
+enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
+
+class Substitution {
+ public:
+  OptimizerCallerPtr transform_;
+  std::string name_;
+  PredicateFuncType predicate_{nullptr};
+  // an enum to mark this Substitution's relation to the renormalize pass
+  RenormAction renorm_action_;
+  Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
+               const RenormAction &renorm_action)
+      : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
+  ~Substitution() = default;
+  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
+};
+
+using SubstitutionPtr = std::shared_ptr<Substitution>;
+
+SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
+                                 const RenormAction &action_renorm = CHECK_RENORM);
+SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
+                                 const std::vector<PrimitivePtr> &prims,
+                                 const RenormAction &action_renorm = CHECK_RENORM);
+SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
+                                 const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
+
+class SubstitutionList {
+ public:
+  explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
+      : list_(patterns), is_once_(is_once) {}
+  ~SubstitutionList() = default;
+
+  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
+
+ private:
+  bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
+  std::vector<SubstitutionPtr> list_;
+  // a flag to mark that this list of Substitutions can be executed only once
+  bool is_once_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/optimizer.h b/mindspore/ccsrc/frontend/optimizer/optimizer.h
new file mode 100644
index 0000000000..a1f11e74d0
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/optimizer.h
@@ -0,0 +1,242 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "debug/draw.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "debug/trace.h" +#include "frontend/optimizer/opt.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +using OptimizeGraphFunc = std::function; + +class OptPassConfig { + public: + explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} + explicit OptPassConfig(const std::vector &list, bool is_once = false) + : list_(list), is_once_(is_once) {} + OptPassConfig(const std::initializer_list &list, bool is_once = false) + : list_(list), is_once_(is_once) {} + ~OptPassConfig() = default; + + const std::vector &list() const { return list_; } + const OptimizeGraphFunc &func() const { return func_; } + + static OptPassConfig Renormalize() { return OptPassConfig(); } + const bool is_renormalize() const { return is_renormalize_; } + + const bool is_once() const { return is_once_; } + + private: + OptPassConfig() : is_renormalize_(true) {} + + OptimizeGraphFunc func_; + std::vector list_; + bool is_renormalize_{false}; + bool is_once_{false}; +}; + +class OptPass { + public: + explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {} + ~OptPass() = default; + + bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { + return pass_func_(func_graph, optimizer); + } + + static OptPass Renormalize() { return OptPass(); } + const bool is_renormalize() const { return is_renormalize_; } + + private: + OptPass() : is_renormalize_(true) {} + + OptimizeGraphFunc pass_func_; + bool is_renormalize_{false}; +}; +using OptPassGroupMap = std::vector>; + +class Optimizer : public std::enable_shared_from_this { + public: + Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) + : name_(name), + resource_(resource_ptr), + run_only_once_(false), + is_watch_renormalize_(false), + is_enable_(true), + is_untyped_generated_(false) {} + virtual ~Optimizer() = default; + + void Init(const OptPassGroupMap &passes, bool run_only_once) { + run_only_once_ = run_only_once; + is_watch_renormalize_ = false; + is_untyped_generated_ = false; + is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); + + for (auto &iter : passes) { + const std::string &name = iter.first; + pass_names_.push_back(name); + + const OptPassConfig &config = iter.second; + if (config.is_renormalize()) { + passes_.push_back(OptPass::Renormalize()); + continue; + } + + if (config.list().size() > 0) { + OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); + passes_.push_back(OptPass(func)); + continue; + } + + passes_.push_back(OptPass(config.func())); + } + + if (passes_.size() == 1) { + run_only_once_ = true; + } + } + + static std::shared_ptr MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, + const OptPassGroupMap &passes, bool run_only_once = false, + bool watch_renormalize = false) { + OptimizerPtr optimizer = std::make_shared(name, resource_ptr); + optimizer->Init(passes, run_only_once); + if (watch_renormalize) { + optimizer->enable_watch_renormalize(); + } + return optimizer; + } + + FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { + if (!is_enable_) { + return func_graph; + } + // Optimizer step counter; + 
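+    // step() sweeps the whole pass list repeatedly until one full sweep leaves the
+    // graph unchanged (or run_only_once_ is set): `changes` is cleared per sweep and
+    // set by any pass that rewrites the graph, while Renormalize entries instead
+    // re-run type/shape inference over the current parameter abstracts.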
int counter = 1; + bool changes = true; + + while (changes) { + changes = false; + auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() { + for (size_t i = 0; i < passes_.size(); ++i) { + const OptPass &opt = passes_[i]; + CurPass_ = {counter, pass_names_[i]}; + auto opt_func = [&func_graph, &changes, &opt, this]() { + if (opt.is_renormalize()) { + auto resource_ptr = std::dynamic_pointer_cast(resource_); + if (resource_ptr != nullptr) { + // StepParallel may replace the AbstractValue of the parameters of func_graph, + // So generate the args_spec from parameters. + abstract::AbstractBasePtrList maybe_new_args_spec; + if (is_watch_renormalize_) { + if (is_untyped_generated_) { + std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), + std::back_inserter(maybe_new_args_spec), + [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); + func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); + clear_is_untyped_generated(); + } else { + MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; + } + } else { + std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), + std::back_inserter(maybe_new_args_spec), + [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); + func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); + } + } + } else if (opt(func_graph, shared_from_this())) { + changes = true; + } + }; + use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); + if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { + MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; + auto fg_name = + "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; + func_graph->DumpFuncGraph(fg_name); + DumpIR(fg_name + ".ir", func_graph); + ExportIR(fg_name + ".dat", "", func_graph); + MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; + } + } + }; + use_profile ? 
(WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc(); + counter++; + + if (run_only_once_) { + break; + } + } + return func_graph; + } + + pipeline::ResourceBasePtr resource() const { return resource_; } + FuncGraphManagerPtr manager() const { + if (resource_ != nullptr) { + return resource_->manager(); + } + MS_LOG(EXCEPTION) << "No ResourceBase exists."; + } + + const std::string name() const { return name_; } + + void set_is_untyped_generated() { is_untyped_generated_ = true; } + void clear_is_untyped_generated() { is_untyped_generated_ = false; } + + void enable_watch_renormalize() { is_watch_renormalize_ = true; } + void disable_watch_renormalize() { is_watch_renormalize_ = false; } + bool is_watch_renormalize() { return is_watch_renormalize_; } + void set_enable(bool enable) { is_enable_ = enable; } + + struct { + int counter; + std::string name; + } CurPass_; + + bool is_on_debug_{false}; + + private: + const std::string name_; + pipeline::ResourceBasePtr resource_; + std::vector passes_; + std::vector pass_names_; + bool run_only_once_; + bool is_watch_renormalize_; + bool is_enable_; + bool is_untyped_generated_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.cc b/mindspore/ccsrc/frontend/optimizer/pass_group.cc new file mode 100644 index 0000000000..3619396215 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "frontend/optimizer/pass_group.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +void PassGroup::AddPass(const PythonPassPtr &pass) { + if (pass != nullptr) { + passes_.push_back(pass); + } +} + +bool PassGroup::DeletePass(const std::string &pass_name) { + for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { + if ((*iter)->name() == pass_name) { + *iter = nullptr; + passes_.erase(iter); + return true; + } + } + return false; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { + if (func_graph == nullptr) { + return false; + } + bool changed = false; + for (const auto &pass : passes) { + if (pass != nullptr) { + if (pass->Run(func_graph)) { + changed = true; + } + } + } + return changed; +} + +bool PassGroup::Run(const FuncGraphPtr &func_graph) const { + bool changed = false; + // run all passes + bool change = true; + while (change) { + change = Run(func_graph, passes_); + changed = change || changed; + if (run_only_once_) { + break; + } + } + return changed; +} + +} // namespace python_pass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/pass_group.h b/mindspore/ccsrc/frontend/optimizer/pass_group.h new file mode 100644 index 0000000000..08fa8018d6 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/pass_group.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ + +#include +#include +#include +#include + +#include "frontend/optimizer/py_pass.h" + +namespace mindspore { +namespace opt { +namespace python_pass { +class PassGroup { + public: + explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) + : name_(name), passes_{}, run_only_once_(run_only_once) {} + virtual ~PassGroup() = default; + // Add graph pass, the pass object will be freed when pass manager freed. + void AddPass(const PythonPassPtr &pass); + // Delete graph pass before the pass manager is freed. 
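+  // Deletion is by name: the first pass whose name() equals pass_name is erased and
+  // true is returned; false means no pass with that name was registered.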
+  // Delete a graph pass before the pass manager is freed.
+  bool DeletePass(const std::string &pass_name);
+  // Run passes added in pass manager on the input graph
+  // @param [inout] graph The graph to be optimized
+  // @return true, graph changed
+  // @return false, graph not changed
+  bool Run(const FuncGraphPtr &func_graph) const;
+  // Run the given graph passes on the input graph
+  // @param [inout] graph The graph to be optimized
+  // @param [in] passes The given graph passes
+  // @return true, graph changed
+  // @return false, graph not changed
+  bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
+  std::string name() const { return name_; }
+
+ private:
+  const std::string name_;
+  std::vector<PythonPassPtr> passes_;
+  bool run_only_once_;
+};
+using PassGroupPtr = std::shared_ptr<PassGroup>;
+}  // namespace python_pass
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass.cc b/mindspore/ccsrc/frontend/optimizer/py_pass.cc
new file mode 100644
index 0000000000..c1bf40fcbb
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/py_pass.cc
@@ -0,0 +1,237 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "frontend/optimizer/py_pass.h"
+#include <algorithm>
+#include <deque>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ir/func_graph.h"
+#include "ir/manager.h"
+#include "pipeline/jit/parse/parse_base.h"
+#include "pipeline/jit/resource.h"
+
+namespace mindspore {
+namespace opt {
+namespace python_pass {
+namespace internal {
+std::string GetNodeRepr(AnfNodePtr node) {
+  if (node != nullptr) {
+    if (node->isa<CNode>()) {
+      std::string repr = "(";
+      auto const &inputs = node->cast<CNodePtr>()->inputs();
+      for (auto &input : inputs) {
+        repr += " ";
+        repr += GetNodeRepr(input);
+        repr += " ";
+      }
+      repr += ")";
+      return repr;
+    }
+    if (node->isa<ValueNode>()) {
+      return GetValueNode(node)->ToString();
+    }
+    return node->ToString();
+  }
+  return "";
+}
+
+void ResolveFuncGraph_(const FuncGraphPtr &fg) {
+  auto manager = Manage(fg, false);
+  parse::python_adapter::set_use_signature_in_resolve(false);
+  parse::ResolveAll(manager);
+  parse::python_adapter::set_use_signature_in_resolve(true);
+}
+
+bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
+  if (node == nullptr) {
+    return false;
+  }
+  MS_EXCEPTION_IF_NULL(pattern);
+  if (pattern->isa<ValueNode>()) {
+    if (!node->isa<ValueNode>()) {
+      return false;
+    }
+    if (GetNodeRepr(pattern) == GetNodeRepr(node)) {
+      // add to equiv_ptr
+      equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node));
+      return true;
+    }
+    return false;
+  } else if (pattern->isa<Parameter>()) {
+    MS_LOG(DEBUG) << pattern->ToString() + "\n";
+    // add to equiv_ptr
+    equiv_ptr->insert(std::make_pair(pattern->ToString(), node));
+    return true;
+  } else if (pattern->isa<CNode>()) {
+    // match every single sub ANode
+    if (!node->isa<CNode>()) {
+      return false;
+    }
+    auto pattern_inputs = pattern->cast<CNodePtr>()->inputs();
+    auto node_inputs = node->cast<CNodePtr>()->inputs();
+    if (pattern_inputs.size() != node_inputs.size()) {
+      return false;
+    }
+    for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end();
+         p_item++, node_item++) {
+      auto res = Match(*p_item, *node_item, equiv_ptr);
+      if (!res) {
+        return false;
+      }
+    }
+    return true;
+  }
+  MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n";
+}
+
+AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_,
+                       const NodeEquivPtr &equiv_ptr) {
+  if (cur_raw_dst_node_->isa<Parameter>()) {
+    auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString());
+    if (sub_pair != equiv_ptr->end()) {
+      return sub_pair->second;
+    }
+    MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n";
+  } else if (cur_raw_dst_node_->isa<ValueNode>()) {
+    // check primitive ValueNode
+    auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString());
+    if (sub_pair != equiv_ptr->end()) {
+      return sub_pair->second;
+    }
+    return cur_raw_dst_node_;
+  } else if (cur_raw_dst_node_->isa<CNode>()) {
+    std::vector<AnfNodePtr> new_inputs;
+    auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs();
+    for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) {
+      auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr);
+      new_inputs.push_back(subed);
+    }
+    return func_graph->NewCNode(new_inputs);
+  }
+  MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_);
+}
+
+bool isTraversable(const AnfNodePtr &node) {
+  if (node == nullptr) {
+    return false;
+  }
+  if (node->isa<CNode>() || node->isa<Parameter>()) {
+    return true;
+  }
+  if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
+    return true;
+  }
+  return false;
+}
+}  // namespace internal
+
+void PythonPass::Build(const py::function &src, const py::function &dst) {
+  // 1. get FuncGraph from py::function
+  auto src_fg_ = parse::ParsePythonCode(src);
+  auto dst_fg_ = parse::ParsePythonCode(dst);
+  if (src_fg_ == nullptr || dst_fg_ == nullptr) {
+    MS_LOG(EXCEPTION) << "Failed to parse python code.\n";
+  }
+  // 2. Resolve
+  internal::ResolveFuncGraph_(src_fg_);
+  internal::ResolveFuncGraph_(dst_fg_);
+  // 3. from FuncGraphPtr to ValueNode
+  src_node_ = src_fg_->output();
+  dst_node_ = dst_fg_->output();
+}
+
+PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once,
+                       bool multigraph)
+    : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {
+  Build(src, dst);
+}
+
+AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  auto equiv_ptr = std::make_shared<NodeEquiv>();
+  bool is_a_match = internal::Match(src_node_, node, equiv_ptr);
+  if (is_a_match) {
+    auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr);
+    MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
+    return new_node;
+  }
+  return nullptr;
+}
+
+bool PythonPass::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(func_graph);
+  auto seen = NewSeenGeneration();
+  // start the deque with 1024 empty slots; null entries are skipped by the nullptr check below
+  std::deque<AnfNodePtr> todo(1024);
+  todo.push_back(func_graph->output());
+  bool changes = false;
+
+  auto &all_nodes = manager->all_nodes();
+  while (!todo.empty()) {
+    AnfNodePtr node = todo.front();
+    todo.pop_front();
+
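+    // [Editorial note] The traversal below relies on the "seen generation" marker:
+    // NewSeenGeneration() hands out a fresh tag per sweep, and a node whose seen_ already
+    // equals the tag is skipped, so each node is normally visited once per Run(). The
+    // invariant, in short:
+    //
+    //   auto seen = NewSeenGeneration();    // fresh tag for this sweep
+    //   if (node->seen_ == seen) continue;  // already handled in this sweep
+    //   node->seen_ = seen;                 // mark visited
+    //
+    // Decrementing use_node->seen_ after a successful rewrite (further down) deliberately
+    // re-arms a user node for one more visit, so replacements can cascade.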
+    // check whether this node has been matched.
+    if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) {
+      continue;
+    }
+    node->seen_ = seen;
+
+    // select nodes that this transform can be applied.
+    AnfNodePtr new_node = Run(func_graph, node);
+    bool change = (new_node != nullptr);
+    changes = changes || change;
+    if (new_node != nullptr && new_node != node) {
+      (void)manager->Replace(node, new_node);
+    } else if (new_node == nullptr) {
+      new_node = node;
+    }
+    if (run_only_once_) {
+      return change;
+    }
+
+    // find success, and add them to todo list
+    if (IsValueNode<FuncGraph>(node)) {
+      todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
+    }
+
+    if (node->isa<CNode>()) {
+      auto &inputs = node->cast<CNodePtr>()->inputs();
+      (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
+    }
+
+    auto &node_users = manager->node_users();
+    if (change && node_users.find(node) != node_users.end()) {
+      for (auto &use : node_users[node]) {
+        auto use_node = use.first;
+        if (use_node == nullptr) {
+          continue;
+        }
+        todo.push_back(use_node);
+        if (use_node->seen_ == seen) {
+          use_node->seen_--;
+        }
+      }
+    }
+  }
+  return changes;
+}
+}  // namespace python_pass
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/py_pass.h b/mindspore/ccsrc/frontend/optimizer/py_pass.h
similarity index 100%
rename from mindspore/ccsrc/optimizer/py_pass.h
rename to mindspore/ccsrc/frontend/optimizer/py_pass.h
diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
new file mode 100644
index 0000000000..86d7067d1c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
@@ -0,0 +1,84 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "frontend/optimizer/py_pass_manager.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "ir/manager.h"
+#include "frontend/optimizer/pass_group.h"
+
+namespace mindspore {
+namespace opt {
+namespace python_pass {
+PyPassManagerPtr PyPassManager::global_instance = nullptr;
+std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
+
+PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
+  auto pm = phase_to_group_.find(phase);
+  if (pm == phase_to_group_.end()) {
+    return nullptr;
+  }
+  return pm->second;
+}
+
+PyPassManagerPtr PyPassManager::GetInstance() {
+  if (global_instance == nullptr) {
+    global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
+  }
+  return global_instance;
+}
+
+PyPassManager::PyPassManager() {
+  phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
+  phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
+}
+
+void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
+                            Phase phase, bool run_only_once, bool multigraph) {
+  auto cur_pm = GetPassGroup(phase);
+  MS_EXCEPTION_IF_NULL(cur_pm);
+  PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
+  cur_pm->AddPass(new_pass);
+}
+
+void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
+  auto cur_pm = GetPassGroup(phase);
+  MS_EXCEPTION_IF_NULL(cur_pm);
+  if (!cur_pm->DeletePass(pass_name)) {
+    MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
+  }
+}
+
+void PyPassManager::ClearRes() {
+  MS_LOG(INFO) << "Clear PyPassManager resources!";
+  global_instance = nullptr;
+  phase_to_group_.clear();
+}
+
+REGISTER_PYBIND_DEFINE(
+  PyPassManager_, ([](const py::module *m) {
+    (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
+    (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
+      .def(py::init([]() { return PyPassManager::GetInstance(); }))
+      .def("registe", &PyPassManager::Registe, "Registe python pass")
+      .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
+  }));
+}  // namespace python_pass
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
new file mode 100644
index 0000000000..84868862a7
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "ir/primitive_py.h"
+#include "utils/graph_utils.h"
+#include "common/utils.h"
+
+#include "pipeline/jit/parse/resolve.h"
+#include "frontend/optimizer/py_pass.h"
+#include "frontend/optimizer/pass_group.h"
+
+namespace mindspore {
+namespace opt {
+namespace python_pass {
+class PyPassManager;
+using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
+
+enum Phase { RESOLVE, OPT };
+
+class PyPassManager {
+ protected:
+  PyPassManager();
+  static PyPassManagerPtr global_instance;
+
+ public:
+  // Singletons should not be cloneable and assignable
+  PyPassManager(const PyPassManager &other) = delete;
+  void operator=(const PyPassManager &) = delete;
+  // Access the only global instance
+  static PyPassManagerPtr GetInstance();
+  virtual ~PyPassManager() = default;
+  void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
+               Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
+  void Unregiste(const std::string &pass_name, Phase phase);
+  PassGroupPtr GetPassGroup(Phase phase);
+  void ClearRes();
+
+ private:
+  static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
+};
+}  // namespace python_pass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/CMakeLists.txt b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt
new file mode 100644
index 0000000000..d2a099cf41
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/CMakeLists.txt
@@ -0,0 +1,8 @@
+file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
+list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc" "ps/scheduler.cc" "ps/optimizer_info.cc" "ps/optimizer_info_builder.cc")
+if (ENABLE_DUMP_PROTO)
+    list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc")
+endif ()
+
+set_property(SOURCE ${_PARALLEL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARALLEL)
+add_library(_mindspore_frontend_parallel_obj OBJECT ${_PARALLEL_SRC_FILES})
diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc
new file mode 100644
index 0000000000..70ae5a7d20
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc
@@ -0,0 +1,435 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
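+ *
+ * [Editorial note] The PyPassManager singleton defined above is the C++ anchor for the
+ * Python-visible "registe"/"unregiste" bindings. A hypothetical C++-side registration,
+ * assuming pattern_fn/target_fn are py::function objects (illustrative only):
+ *
+ *   auto ppm = PyPassManager::GetInstance();
+ *   ppm->Registe("my_pass", pattern_fn, target_fn, Phase::RESOLVE, false, true);
+ *   ppm->Unregiste("my_pass", Phase::RESOLVE);
+ *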
+ */ + +#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "frontend/parallel/costmodel_context.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/step_parallel.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(para); + MS_EXCEPTION_IF_NULL(para->func_graph()); + FuncGraphManagerPtr manager = para->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_set = manager->node_users()[para]; + std::unordered_set cnode_set; + for (auto &node_pair : node_set) { + auto cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + continue; + } + auto node_prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + (void)cnode_set.emplace(cnode); + } else { + auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); + for (auto &cnode_sub : cnode_set_sub) { + (void)cnode_set.emplace(cnode_sub); + } + } + } + return cnode_set; +} + +Status AllreduceFusion::AddNodeToGraph() { + const auto ¶meters = root_graph_->parameters(); + for (auto ¶meter : parameters) { + if (!ParameterRequireGrad(parameter)) { + continue; + } + auto cnode_set = FindCNodesWithPara(parameter); + if (cnode_set.empty()) { + continue; + } + for (auto &cnode : cnode_set) { + MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); + if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { + MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); + return FAILED; + } + } + } + return SUCCESS; +} + +CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! 
Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(from); + std::unordered_map cnode_dist; + if (!from->isa()) { + return cnode_dist; + } + auto cnode = from->cast(); + if (!IsValueNode(cnode->input(0))) { + return cnode_dist; + } + + MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) + << " operator_info: " << (cnode->operator_info() != nullptr); + + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); + MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; + + if (allreduce_graph_.NodeInGraph(cnode)) { + cnode_dist[cnode] = cost; + return cnode_dist; + } else { + auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); + for (auto &ele : cnode_dist_next) { + cnode_dist[ele.first] = cost + ele.second; + } + } + } else { + auto cnode_dist_next = FindNextCNodes(cnode); + for (auto &ele : cnode_dist_next) { + cnode_dist[ele.first] = ele.second; + } + } + return cnode_dist; +} + +CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + const auto &from_inputs = from->inputs(); + std::unordered_map dist_map; + MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; + for (auto &input_node : from_inputs) { + auto cnode_dist = FindCNode(input_node, recursive_times + 1); + for (auto &ele : cnode_dist) { + (void)dist_map.emplace(ele); + } + } + return dist_map; +} + +Status AllreduceFusion::AddEdgeToGraph() { + std::unordered_map cnode_state_map; + const auto &cnodes = allreduce_graph_.cnode_set(); + for (auto &cnode : cnodes) { + cnode_state_map[cnode] = 0; + } + const auto &head_cnode = allreduce_graph_.head_cnode(); + std::queue cnode_queue; + cnode_queue.emplace(head_cnode); + cnode_state_map[head_cnode] = 1; + + while (!cnode_queue.empty()) { + const auto cur_cnode = cnode_queue.front(); + cnode_queue.pop(); + cnode_state_map[cur_cnode] = 2; + auto next = FindNextCNodes(cur_cnode); + for (auto &ele : next) { + auto &cnode = ele.first; + auto &dist = ele.second; + if (cnode_state_map[cnode] == 0) { + cnode_queue.emplace(cnode); + cnode_state_map[cnode] = 1; + } + if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) { + MS_LOG(ERROR) << "AddEdge error"; + return FAILED; + } + MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist; + } + } + return SUCCESS; +} + +std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { + if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { + MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! 
Max recursive call times is " + << MAX_RECURSIVE_CALL_TIMES; + } + MS_EXCEPTION_IF_NULL(para); + MS_EXCEPTION_IF_NULL(para->func_graph()); + FuncGraphManagerPtr manager = para->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[para]; + std::vector cnode_list; + for (auto &node_pair : node_set) { + auto cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + continue; + } + auto node_prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == CAST) { + auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1); + if (mirror_cnodes.empty()) { + MS_LOG(WARNING) << "mirror node after cast not found"; + continue; + } + if (mirror_cnodes.size() > 1) { + MS_LOG(EXCEPTION) << "mirror node after cast number is not 1"; + } + cnode_list.emplace_back(mirror_cnodes[0]); + } + if (node_prim->name() == MIRROR_OPERATOR) { + cnode_list.emplace_back(cnode); + } + } + return cnode_list; +} + +void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { + MS_EXCEPTION_IF_NULL(mirror_cnode); + MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; + auto node_prim = GetValueNode(mirror_cnode->input(0)); + auto old_value_ptr = node_prim->GetAttr(FUSION); + if (old_value_ptr != nullptr) { + if (old_value_ptr->isa()) { + int32_t old_value = old_value_ptr->cast()->value(); + if (old_value < fusion) { + return; + } + } + } + (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared(fusion))); + (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); +} + +Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { + auto mirror_cnodes = FindMirror(para); + if (mirror_cnodes.empty()) { + MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; + return SUCCESS; + } + if (mirror_cnodes.size() > 2) { + for (auto &mirror_cnode : mirror_cnodes) { + MS_EXCEPTION_IF_NULL(mirror_cnode); + MS_LOG(INFO) << mirror_cnode->DebugString(); + } + MS_EXCEPTION_IF_NULL(para); + MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. 
" << mirror_cnodes.size() + << "Mirror CNode found."; + return FAILED; + } + for (auto &mirror_cnode : mirror_cnodes) { + auto parameter_name = ParameterName(para); + SetMirrorFusion(mirror_cnode, fusion, parameter_name); + } + return SUCCESS; +} + +Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { + for (auto ¶m_node : paras) { + if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + } + return SUCCESS; +} + +Status AllreduceFusion::SetFusion(const std::vector &cost_map) { + if (cost_map.size() < 2) { + MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); + return FAILED; + } + int32_t fusion = 1; + for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) { + auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter); + if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + fusion++; + } + return SUCCESS; +} + +std::vector AllreduceFusion::GenerateCostMap(int32_t fusion_times, double tail_percent) const { + double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1); + MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset; + std::vector cost_map; + double begin = 0; + for (auto i = 0; i < fusion_times - 1; i++) { + cost_map.push_back(begin); + begin += offset; + } + cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent)); + cost_map.push_back(allreduce_graph_.max()); + MS_LOG(DEBUG) << "cost_map = " << cost_map; + return cost_map; +} + +Status AllreduceFusion::SetFusionByBackwardCompTime() { + auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times(); + if (fusion_times < 2) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent(); + if (tail_percent < 0 || tail_percent >= 1) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent + << ". Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + const auto cost_map = GenerateCostMap(fusion_times, tail_percent); + MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed."; + if (SetFusion(cost_map) != SUCCESS) { + MS_LOG(ERROR) << "SetFusion failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed."; + return SUCCESS; +} + +Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() { + tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time(); + if (tail_time_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time(); + if (allreduce_inherent_time_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + if (tail_time_ <= allreduce_inherent_time_) { + MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ + << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ + << ".tail_time is not more than allreduce_inherent_time. 
Bypass ProcessAllreduceFusion"; + return FAILED; + } + allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth(); + if (allreduce_bandwidth_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + computation_time_parameter_ = + CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter(); + if (computation_time_parameter_ <= 0) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_ + << ". Bypass ProcessAllreduceFusion"; + return FAILED; + } + return SUCCESS; +} + +Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() { + if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) { + MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!"; + return FAILED; + } + allreduce_graph_.SortArnode(); + if (allreduce_graph_.RemoveExtraParas() != SUCCESS) { + MS_LOG(ERROR) << "RemoveExtraParas failed!"; + return FAILED; + } + double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_; + double to_cost = allreduce_graph_.max(); + int32_t fusion = 1; + while (to_cost != 0) { + MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size; + auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size); + MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second; + auto paras = node_cost_pair.first; + if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { + MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; + return FAILED; + } + fusion++; + para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) / + allreduce_bandwidth_; + to_cost = node_cost_pair.second; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed."; + return SUCCESS; +} + +Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { + if (algorithm == 1) { + return SetFusionByBackwardCompTime(); + } + return SetFusionByBackwardCompAndAllreduceTime(); +} + +Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { + if (ret == nullptr) { + MS_LOG(ERROR) << "ret is nullptr."; + return FAILED; + } + auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm(); + if (algorithm < 1 || algorithm > 2) { + MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". 
Bypass ProcessAllreduceFusion"; + return SUCCESS; + } + ret_ = ret; + root_graph_ = ret_->func_graph(); + MS_EXCEPTION_IF_NULL(root_graph_); + auto graph_set = ForwardGraph(root_graph_); + if (graph_set.size() > 1) { + MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; + return SUCCESS; + } + auto forward_graph = *(graph_set.begin()); + MS_EXCEPTION_IF_NULL(forward_graph); + forward_ret_ = forward_graph->get_return(); + MS_EXCEPTION_IF_NULL(forward_ret_); + + if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed."; + if (AddNodeToGraph() != SUCCESS) { + MS_LOG(ERROR) << "AddNodeToGraph failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed."; + if (AddEdgeToGraph() != SUCCESS) { + MS_LOG(ERROR) << "AddNodeToGraph failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed."; + if (SetFusionByAlgorithm(algorithm) != SUCCESS) { + MS_LOG(ERROR) << "SetFusionByAlgorithm failed."; + return FAILED; + } + MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h new file mode 100644 index 0000000000..7383c477a6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
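+ *
+ * [Editorial note] SetFusion in allreduce_fusion.cc above walks the cost map from the tail:
+ * parameters whose accumulated backward cost lies in (cost_map[i-1], cost_map[i]] get fusion
+ * index 1, the next interval index 2, and so on. Worked example: with fusion_times = 3,
+ * tail_percent = 0.1 and max = 100, GenerateCostMap yields {0, 45, 90, 100} -- two equal
+ * head intervals plus a 10% tail that is fused first.
+ *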
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ + +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_graph.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +using CNodeCostMap = std::unordered_map; + +constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0; +constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; +constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; + +constexpr char FUSION[] = "fusion"; +constexpr char PARAMETER[] = "parameter"; +const uint32_t MAX_RECURSIVE_CALL_TIMES = 100; +class AllreduceFusion { + public: + AllreduceFusion() + : allreduce_graph_(), + ret_(nullptr), + forward_ret_(nullptr), + root_graph_(nullptr), + tail_time_(0), + allreduce_inherent_time_(0), + allreduce_bandwidth_(0), + computation_time_parameter_(0) {} + virtual ~AllreduceFusion() = default; + Status ProcessAllreduceFusion(const CNodePtr &ret); + + private: + Status AddNodeToGraph(); + CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; + CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; + Status AddEdgeToGraph(); + std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; + Status SetFusion(const std::vector &cost_map); + Status SetFusionByAlgorithm(int32_t algorithm); + Status SetFusionByBackwardCompTime(); + Status SetFusionByBackwardCompAndAllreduceTime(); + Status GetSetFusionByBackwardCompAndAllreduceTimeParams(); + + AllreduceGraph allreduce_graph_; + CNodePtr ret_; + CNodePtr forward_ret_; + FuncGraphPtr root_graph_; + double tail_time_; + double allreduce_inherent_time_; + double allreduce_bandwidth_; + double computation_time_parameter_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc new file mode 100644 index 0000000000..ca47b0fa97 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.cc @@ -0,0 +1,209 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/allreduce_fusion/allreduce_graph.h" +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { + AllreduceNodePtr arnode; + auto cnode_emplace_return = cnode_set_.emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!"; + auto cnode_arnode_pair = cnode_arnode_map_.find(node); + if (cnode_arnode_pair == cnode_arnode_map_.end()) { + MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!"; + } + arnode = cnode_arnode_pair->second; + } else { + arnode = std::make_shared(AllreduceNode()); + } + + if (arnode->Init(node) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode Init failed"; + return FAILED; + } + if (arnode->AddPara(para) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode AddPara failed"; + return FAILED; + } + cnode_arnode_map_[node] = arnode; + + auto arnode_emplace_return = arnode_set_.insert(arnode); + if (!arnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!"; + } + cnode_emplace_return = para_cnodeset_map_[para].emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope() + << "'s cnodeset!"; + } + auto para_emplace_return = cnode_paraset_map_[node].emplace(para); + if (!para_emplace_return.second) { + MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString() + << "'s paraset!"; + } + return SUCCESS; +} + +Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { + auto from_arnode_iter = cnode_arnode_map_.find(from); + if (from_arnode_iter == cnode_arnode_map_.end()) { + MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; + PrintCNodeSet(); + return FAILED; + } + auto to_arnode_iter = cnode_arnode_map_.find(to); + if (to_arnode_iter == cnode_arnode_map_.end()) { + MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added"; + PrintCNodeSet(); + return FAILED; + } + auto from_arnode = from_arnode_iter->second; + auto to_arnode = to_arnode_iter->second; + if (from_arnode->AddNext(to_arnode) != SUCCESS) { + MS_LOG(ERROR) << "from_arnode AddNext failed"; + return FAILED; + } + if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) { + MS_LOG(ERROR) << "to_arnode AddPrev failed"; + return FAILED; + } + max_ = std::max(max_, to_arnode->depend_feat_size()); + MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString(); + MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size() + << ", to depend_feat_size: " << to_arnode->depend_feat_size(); + return SUCCESS; +} + +bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { + auto cnode_iter = cnode_set_.find(node); + return !(cnode_iter == cnode_set_.end()); +} + +std::vector AllreduceGraph::GetParaByCost(double from, double to) { + std::vector nodes; + for (auto &cnode_arnode : cnode_arnode_map_) { + MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() + << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() + << " curr_para_size: " << cnode_arnode.second->curr_para_size(); + if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) { 
+ (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(), + cnode_paraset_map_[cnode_arnode.first].end()); + } + } + return nodes; +} + +std::pair, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) { + std::vector nodes; + double cur_para_size = 0; + double from = to; + for (auto &arnode : arnode_vec_) { + if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { + continue; + } + if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) { + return std::make_pair(nodes, from); + } + (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end()); + cur_para_size += arnode.curr_para_size(); + from = arnode.depend_feat_size(); + } + MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size + << " cur_para_size: " << cur_para_size << " from: " << from; + return std::make_pair(nodes, from); +} + +void AllreduceGraph::PrintCNodeSet() const { + MS_LOG(INFO) << "CNodeSet:"; + for (auto &cnode : cnode_set_) { + MS_LOG(INFO) << cnode->DebugString(); + } +} + +void AllreduceGraph::PrintAllredueGraphInfo() const { + MS_LOG(INFO) << "max: " << max_; + for (auto &cnode_arnode : cnode_arnode_map_) { + MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); + MS_LOG(INFO) << "arnode info: "; + cnode_arnode.second->ToString(); + } +} + +void AllreduceGraph::PrintArnodeVec() const { + MS_LOG(INFO) << "ArnodeVec:"; + for (auto &arnode : arnode_vec_) { + arnode.ToString(); + } +} + +void AllreduceGraph::PrintArnodeSet() const { + MS_LOG(INFO) << "ArnodeSet:"; + for (auto &arnode : arnode_set_) { + arnode->ToString(); + } +} + +void AllreduceGraph::SortArnode() { + arnode_vec_.clear(); + for (auto &node : arnode_set_) { + arnode_vec_.emplace_back(*node); + } + std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); +} + +Status AllreduceGraph::RemoveExtraParas() { + std::unordered_set para_map; + for (auto &node : arnode_vec_) { + for (auto ¶ : node.paras()) { + auto emplac_result = para_map.emplace(para); + if (!emplac_result.second) { + MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; + if (node.RemovePara(para) != SUCCESS) { + MS_LOG(ERROR) << "remove para failed"; + return FAILED; + } + } + } + } + return SUCCESS; +} + +Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { + auto arnode = std::make_shared(AllreduceNode()); + if (arnode->Init(node) != SUCCESS) { + MS_LOG(ERROR) << "AllreduceNode Init failed"; + } + head_cnode_ = node; + cnode_arnode_map_[node] = arnode; + auto arnode_emplace_return = arnode_set_.insert(arnode); + if (!arnode_emplace_return.second) { + MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!"; + } + auto cnode_emplace_return = cnode_set_.emplace(node); + if (!cnode_emplace_return.second) { + MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!"; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h new file mode 100644 index 0000000000..a47039f070 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_graph.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
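+ *
+ * [Editorial note] GetParaByParaSize above is a greedy cut: starting from depend_feat_size
+ * = to, it accumulates parameters from the sorted arnode_vec_ until their total size reaches
+ * para_size, then reports the depend_feat_size it stopped at as the next 'from'. The caller
+ * (SetFusionByBackwardCompAndAllreduceTime) re-derives para_size for each group as
+ *
+ *   para_size = ((to - from) * computation_time_parameter_ - allreduce_inherent_time_) / allreduce_bandwidth_
+ *
+ * so that each fused group's allreduce roughly overlaps the backward computation before it.
+ *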
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class AllreduceGraph { + public: + AllreduceGraph() + : head_cnode_(nullptr), + arnode_set_(), + arnode_vec_(), + cnode_set_(), + para_cnode_map_(), + para_cnodeset_map_(), + cnode_paraset_map_(), + cnode_arnode_map_(), + max_(0) {} + virtual ~AllreduceGraph() = default; + Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); + Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); + bool NodeInGraph(const CNodePtr &node) const; + std::vector GetParaByCost(double from, double to); + // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is + // over para_size. + // Return the parameter AnfNodePtr vector corresponding to these AllreduceNodes and the smallest depend_feat_size. + // If the sum of left AllreduceNode's parameter size is less than para_size, the returned depend_feat_size must be 0. + std::pair, double> GetParaByParaSize(double to, double para_size); + // If one parameter is used by multiple AllreduceNode, parameter belong to the last node for backward computation + // is saved by the corresponding AllreduceNode, parameters belong to other AllreduceNode are removed. + // Called during precise optimization, not implemented temporarily. + void SortArnode(); + Status RemoveExtraParas(); + void PrintCNodeSet() const; + void PrintAllredueGraphInfo() const; + void PrintArnodeVec() const; + void PrintArnodeSet() const; + const std::unordered_set &cnode_set() const { return cnode_set_; } + CNodePtr head_cnode() const { return head_cnode_; } + Status set_head_cnode(const CNodePtr &node); + double max() const { return max_; } + + private: + CNodePtr head_cnode_; + std::set arnode_set_; + std::vector arnode_vec_; + std::unordered_set cnode_set_; + // If One ParameterPtr is used by multiple CNode, the last node for backward computation is saved. 
+ std::unordered_map> para_cnode_map_; + // One ParameterPtr may be used by multiple CNode + std::unordered_map> para_cnodeset_map_; + // Multiple Parameter may be inputs to the same CNode + std::unordered_map> cnode_paraset_map_; + std::unordered_map cnode_arnode_map_; + double max_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc new file mode 100644 index 0000000000..1c478887df --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/allreduce_fusion/allreduce_node.h" +#include +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { + if (next_node == nullptr) { + MS_LOG(ERROR) << "next_node is nullptr!"; + return FAILED; + } + next_.emplace_back(next_node); + return SUCCESS; +} + +Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { + if (prev_node == nullptr) { + MS_LOG(ERROR) << "next_node is nullptr!"; + return FAILED; + } + if (dist <= 0) { + MS_LOG(ERROR) << "dist must be positive! 
dist: " << dist; + return FAILED; + } + prev_.emplace_back(prev_node); + double add_dist = prev_node->depend_feat_size() + dist; + depend_feat_size_ += add_dist; + if (depend_feat_size_ > *max) { + *max = depend_feat_size_; + } + std::queue next_queue; + for (auto &next : next_) { + next_queue.push(next); + } + while (!next_queue.empty()) { + auto ele = next_queue.front(); + ele->AddDependFeatSize(add_dist); + if (ele->depend_feat_size() > *max) { + *max = ele->depend_feat_size(); + } + for (auto &next : ele->next()) { + next_queue.push(next); + } + next_queue.pop(); + } + return SUCCESS; +} + +Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { + if (cnode_ptr == nullptr) { + MS_LOG(ERROR) << "cnode_ptr is nullptr!"; + return FAILED; + } + cnode_ptr_ = cnode_ptr; + return SUCCESS; +} + +Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { + if (node_ptr == nullptr) { + MS_LOG(ERROR) << "node_ptr is nullptr!"; + return FAILED; + } + if (!node_ptr->isa()) { + MS_LOG(ERROR) << "node_ptr is not a ParameterPtr!"; + return FAILED; + } + auto para_ptr = node_ptr->cast(); + MS_EXCEPTION_IF_NULL(para_ptr); + auto layout_ptr = para_ptr->tensor_layout(); + if (layout_ptr == nullptr) { + MS_LOG(ERROR) << "layout_ptr is nullptr!"; + return FAILED; + } + auto emplace_return = paras_.emplace(node_ptr); + if (emplace_return.second) { + double para_size = static_cast(layout_ptr->slice_shape().size()); + curr_para_size_ += para_size; + para_size_map_[node_ptr] = para_size; + } else { + MS_LOG(INFO) << "node already exist!"; + } + return SUCCESS; +} + +Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { + if (node_ptr == nullptr) { + MS_LOG(ERROR) << "node_ptr is nullptr!"; + return FAILED; + } + auto erase_num = paras_.erase(node_ptr); + if (erase_num == 0) { + MS_LOG(ERROR) << "para not find!"; + return FAILED; + } + curr_para_size_ -= para_size_map_[node_ptr]; + return SUCCESS; +} + +void AllreduceNode::ToString() const { + MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); + for (auto ¶ : paras_) { + MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); + } + MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h new file mode 100644 index 0000000000..6538381f27 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +class AllreduceNode; +using AllreduceNodePtr = std::shared_ptr; + +class AllreduceNode { + public: + AllreduceNode() + : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} + Status Init(const CNodePtr &cnode_ptr); + Status AddPara(const AnfNodePtr &node_ptr); + Status RemovePara(const AnfNodePtr &node_ptr); + const std::unordered_set ¶s() const { return paras_; } + double curr_para_size() const { return curr_para_size_; } + virtual ~AllreduceNode() = default; + // Add previous node + // prev_node is the previous to be added + // max is the current max depend_feat_size of the AllreduceGraph + Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); + Status AddNext(const AllreduceNodePtr &next_node); + double depend_feat_size() const { return depend_feat_size_; } + void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } + const std::vector &next() const { return next_; } + void ToString() const; + bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } + bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } + + private: + CNodePtr cnode_ptr_; + std::vector prev_; + std::vector next_; + std::unordered_set paras_; + std::unordered_map para_size_map_; + double curr_para_size_; + double depend_feat_size_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc new file mode 100644 index 0000000000..b669fa7782 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
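+ *
+ * [Editorial note] AllreduceNode defines operator< / operator> on depend_feat_size so that
+ * AllreduceGraph::SortArnode can order arnode_vec_ with std::greater<>: nodes with the
+ * largest accumulated backward cost come first, which is the scan order GetParaByParaSize
+ * depends on.
+ *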
+ */ + +#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" +#include +#include +#include "frontend/optimizer/optimizer.h" +#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(optimizer); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); + bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); + // assume no change to graph + bool changes = false; + // control whether use model_parallel mode + if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || + (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { + return changes; + } +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + MS_LOG(INFO) << "Now entering allreduce fusion"; + DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN)); + + pipeline::ResourceBasePtr res = optimizer->resource(); + MS_EXCEPTION_IF_NULL(res); + + FuncGraphManagerPtr manager = res->manager(); + MS_EXCEPTION_IF_NULL(manager); + CNodePtr ret = root->get_return(); + MS_EXCEPTION_IF_NULL(ret); + + AllreduceFusion allreduce_fusion; + if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) { + MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed"; + } + + DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); + + // allreduce fusion only run once + root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); + res->results()[pipeline::kStepParallelGraph] = root; +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); + time += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us"; +#endif + return changes; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h new file mode 100644 index 0000000000..2612e71984 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/step_allreduce_fusion.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ +#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ + +#include "frontend/optimizer/optimizer.h" + +namespace mindspore { +namespace parallel { +constexpr char ALLREDUCE_FUSION_RUN_ONCE_ONLY[] = "allreduce_fusion_run_once_only"; +constexpr char ALLREDUCE_FUSION_BEGIN[] = "allreduce_fusion_begin"; +constexpr char ALLREDUCE_FUSION_END[] = "allreduce_fusion_end"; + +bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc new file mode 100644 index 0000000000..531a5cd7f6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.cc @@ -0,0 +1,123 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/costmodel.h" +#include +#include +#include +#include "frontend/parallel/auto_parallel/graph_costmodel.h" + +namespace mindspore { +namespace parallel { +void Simplify(CostPtrList *clist_ptrs) { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); + } else { + // inference phase + SimplifyForDecreasingCommunicationForward(clist_ptrs); + } +} +void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { + // Sort the cost_list with the computation_cost_ increasing, and communication_forward decreasing order. This method + // excludes the cost with greater computation_cost_ and greater communication_forward. + // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} + if (!COST_MODEL_SIMPLIFY_CALCULATION) { + return; + } + MS_EXCEPTION_IF_NULL(clist_ptrs); + std::vector id(clist_ptrs->size()); + std::iota(id.begin(), id.end(), size_t(0)); + std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { + return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; + }); + CostPtrList ret; + for (size_t i = 0; i < clist_ptrs->size(); ++i) { + if ((ret.size() == size_t(0)) || + (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { + ret.emplace_back(std::move(clist_ptrs->at(id[i]))); + } + } + *clist_ptrs = std::move(ret); +} + +void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { + // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing + // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. 
+  if (!COST_MODEL_SIMPLIFY_CALCULATION) {
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(clist_ptrs);
+  std::vector<size_t> id(clist_ptrs->size());
+  std::iota(id.begin(), id.end(), size_t(0));
+  std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) {
+    return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_;
+  });
+  CostPtrList ret;
+  for (size_t i = 0; i < clist_ptrs->size(); ++i) {
+    if ((ret.size() == size_t(0)) ||
+        (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) {
+      ret.emplace_back(std::move(clist_ptrs->at(id[i])));
+    }
+  }
+  *clist_ptrs = std::move(ret);
+}
+
+void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) {
+  MS_EXCEPTION_IF_NULL(origin_cost);
+  if (is_redistribution) {
+    // Redistribution cost
+    if ((origin_cost->communication_redis_forward_ > EPS) &&
+        (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
+      origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST;
+    } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) {
+      origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS;
+    }
+    if ((origin_cost->communication_redis_backward_ > EPS) &&
+        (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
+      origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST;
+    } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) {
+      origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS;
+    }
+    origin_cost->communication_cost_ =
+      origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_;
+    origin_cost->communication_without_parameter_ = origin_cost->communication_cost_;
+    origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_;
+  } else {
+    // Operator cost
+    double backward = 0.0;
+    if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) {
+      backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_;
+    }
+    // forward cost
+    if ((origin_cost->communication_without_parameter_ > EPS) &&
+        (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) {
+      origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST;
+    } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) {
+      origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS;
+    }
+    // total
+    if (origin_cost->communication_cost_ > EPS) {
+      origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward;
+    }
+    if (origin_cost->communication_with_partial_para_ > EPS) {
+      origin_cost->communication_with_partial_para_ =
+        origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward;
+    }
+  }
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
new file mode 100644
index 0000000000..cc4508681b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/costmodel.h
@@ -0,0 +1,311 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_
+#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
+
+namespace mindspore {
+namespace parallel {
+struct Decision;
+using OperatorName = std::string;
+using Attr = std::pair<std::string, ValuePtr>;
+using Param = std::pair<std::pair<std::string, ValuePtr>, int32_t>;
+using OperatorParams = std::vector<Param>;
+using OperatorAttrs = std::vector<Attr>;
+// OutPutInfo.first: true if the operator's output is a tuple
+// OutPutInfo.second: number of elements in the tuple output. Only meaningful if OutPutInfo.first is true.
+using OutPutInfo = std::pair<bool, size_t>;
+using OutPutInfoVector = std::vector<OutPutInfo>;
+using OperatorArgs = std::pair<OperatorAttrs, OperatorParams>;
+using Operator = std::pair<OperatorName, OperatorArgs>;
+using OperatorVector = std::vector<Operator>;
+using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPutInfoVector>>;
+
+struct Cost {
+  Cost();
+  Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr)
+      : computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) {
+    memory_with_reuse_ = 0.0;
+    communication_without_parameter_ = 0.0;
+    communication_with_partial_para_ = 0.0;
+    communication_redis_forward_ = 0.0;
+    communication_redis_backward_ = 0.0;
+    communication_forward_ = 0.0;
+  }
+  // 'memory_with_reuse_' models the peak memory usage in a training (or inference) phase
+  double memory_with_reuse_;
+  // 'computation_cost_' models the training time of an iteration in a training phase. Currently, it is calculated
+  // from the forward phase ONLY
+  double computation_cost_;
+  // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution)
+  double communication_cost_;
+  // communication_without_parameter_ = communication_cost_ - (backward communication from operators)
+  double communication_without_parameter_;
+  // communication_with_partial_para_ =
+  //   communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_)
+  double communication_with_partial_para_;
+  // communication_forward_ = communication cost from operators (forward phase only) and forward redistribution
+  double communication_forward_;
+  double communication_redis_forward_;
+  double communication_redis_backward_;
+  std::shared_ptr<Decision> decision_ptr_;
+};
+
+using CostPtr = std::shared_ptr<Cost>;
+using CostPtrList = std::vector<std::shared_ptr<Cost>>;
+
+class StrategyWithCost {
+ public:
+  StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_)
+      : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {}
+
+  StrategyWithCost(const StrategyWithCost &swc) = delete;
+  StrategyWithCost(StrategyWithCost &&swc)
+      : strategy_ptr(swc.strategy_ptr),
+        inputs_ptr(swc.inputs_ptr),
+        outputs_ptr(swc.outputs_ptr),
+        cost_list(swc.cost_list) {}
+  ~StrategyWithCost() = default;
+
+  StrategyPtr strategy_ptr;
+  std::vector<TensorInfo> inputs_ptr;
+  std::vector<TensorInfo> outputs_ptr;
+  CostPtrList cost_list;
+};
+
+enum DecisionType {
+  OP_ELIMINATION,
+  EDGE_ELIMINATION,
+  MERGE_ELIMINATION,
+  CONTRACT_ELIMINATION,
+  TRIANGLE_ELIMINATION,
+  STAR_ELIMINATION,
+  FINAL_TYPE,
+  FINAL_SINGLE
+};
+
+struct Decision : public Base {
+  ~Decision() override = default;
+  DecisionType type_;
+};
+
+// 'OpEliminationDecision' is for the Operator Elimination in the DP algorithm: u --> v --> w ==> u --> w.
+// This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the
+// operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w'.
+struct OpEliminationDecision : public Decision {
+  OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost)
+      : op_strategy_(std::move(op_stra)),
+        left_cost_(std::move(l_cost)),
+        middle_cost_(std::move(m_cost)),
+        right_cost_(std::move(r_cost)) {
+    type_ = DecisionType::OP_ELIMINATION;
+  }
+
+  StrategyPtr op_strategy_;
+  CostPtr left_cost_;
+  CostPtr middle_cost_;
+  CostPtr right_cost_;
+  MS_DECLARE_PARENT(OpEliminationDecision, Decision);
+};
+
+/* 'EdgeEliminationDecision' is for the Edge Elimination in the DP algorithm:
+      ____
+     /    \
+    u      v   ==>   u --> v, which replaces the multi-edges with a single edge.
+     \____/
+   This data structure records the cost list 'edges_cost_list_' for all the eliminated edges.
+ */
+struct EdgeEliminationDecision : public Decision {
+  explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) {
+    type_ = DecisionType::EDGE_ELIMINATION;
+  }
+
+  CostPtrList edges_cost_list_;
+  MS_DECLARE_PARENT(EdgeEliminationDecision, Decision);
+};
+
+// 'MergeEliminationDecision' is for the Merge Elimination in the DP algorithm:
+//        w
+//        |
+//        |     ==>   u --> v
+//  u --> v
+// In the original graph, v has two alive incoming edges, w has one alive outgoing edge,
+// and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u --> v'.
+// This data structure records the strategy 'merged_op_strategy_' and the cost 'merged_op_cost_' for operator 'w',
+// the edge cost 'edge_cost_' for 'w --> v', and the strategy 'target_op_strategy_' and cost 'target_op_cost_'
+// for the target operator 'v'.
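+// E.g. (hypothetical values) with strategies/costs s_w, c_w, c_wv, s_v, c_v chosen for the pattern above, the
+// record is built as std::make_shared<MergeEliminationDecision>(s_w, c_w, c_wv, s_v, c_v); the constructor
+// sets type_ to DecisionType::MERGE_ELIMINATION so Phase 3 can dispatch on it.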
+struct MergeEliminationDecision : public Decision {
+  MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra,
+                           CostPtr target_op_c)
+      : merged_op_strategy_(std::move(op_stra)),
+        merged_op_cost_(std::move(op_cost)),
+        edge_cost_(std::move(edge_c)),
+        target_op_strategy_(std::move(tar_op_stra)),
+        target_op_cost_(std::move(target_op_c)) {
+    type_ = DecisionType::MERGE_ELIMINATION;
+  }
+
+  StrategyPtr merged_op_strategy_;
+  CostPtr merged_op_cost_;
+  CostPtr edge_cost_;
+  StrategyPtr target_op_strategy_;
+  CostPtr target_op_cost_;
+  MS_DECLARE_PARENT(MergeEliminationDecision, Decision);
+};
+
+// 'ContractEliminationDecision' is for the Contract Elimination in the DP algorithm:
+//  u --> v
+//  |
+//  |     ==>   u --> w
+//  w
+// In the original graph, u has two alive outgoing edges, v has one alive incoming edge,
+// and v has zero outgoing edges. After the Contract Elimination, the result graph contains only 'u --> w'.
+// This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the operator cost
+// 'contracted_op_cost_', the edge cost 'edge_cost_', and the strategy 'target_op_strategy_' and cost
+// 'target_cost_' for the remaining target operator.
+struct ContractEliminationDecision : public Decision {
+  ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost,
+                              StrategyPtr target_stra, CostPtr tar_cost)
+      : contracted_op_strategy_(std::move(contra_stra)),
+        contracted_op_cost_(std::move(contra_op_cost)),
+        edge_cost_(std::move(edge_cost)),
+        target_op_strategy_(std::move(target_stra)),
+        target_cost_(std::move(tar_cost)) {
+    type_ = DecisionType::CONTRACT_ELIMINATION;
+  }
+
+  StrategyPtr contracted_op_strategy_;
+  CostPtr contracted_op_cost_;
+  CostPtr edge_cost_;
+  StrategyPtr target_op_strategy_;
+  CostPtr target_cost_;
+  MS_DECLARE_PARENT(ContractEliminationDecision, Decision);
+};
+
+/* 'TriangleEliminationDecision' is for the Triangle Elimination in the DP algorithm:
+ *
+ *        u
+ *       / \
+ *      /   \
+ *     v --- w   ==>   v --- w
+ * In the original graph, u has 2 outgoing edges, v has 1 outgoing edge,
+ * and w has 2 incoming edges; u can be eliminated into v.
+ * 'eliminated_op_strategy_' and 'eliminated_op_cost_' are for u, 'left_edge_cost_' is for edge u --> v,
+ * 'right_edge_cost_' is for edge u --> w; 'left_node_strategy_'/'left_node_cost_' and 'right_node_strategy_'
+ * record the decisions for v and w.
+ */
+struct TriangleEliminationDecision : public Decision {
+  TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost,
+                              StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra)
+      : eliminated_op_strategy_(std::move(elimi_stra)),
+        eliminated_op_cost_(std::move(elimi_op_cost)),
+        left_edge_cost_(std::move(l_edge_cost)),
+        right_edge_cost_(std::move(r_edge_cost)),
+        left_node_strategy_(std::move(left_stra)),
+        left_node_cost_(std::move(l_node_cost)),
+        right_node_strategy_(std::move(right_stra)) {
+    type_ = DecisionType::TRIANGLE_ELIMINATION;
+  }
+
+  StrategyPtr eliminated_op_strategy_;
+  CostPtr eliminated_op_cost_;
+  CostPtr left_edge_cost_;
+  CostPtr right_edge_cost_;
+  StrategyPtr left_node_strategy_;
+  CostPtr left_node_cost_;
+  StrategyPtr right_node_strategy_;
+  MS_DECLARE_PARENT(TriangleEliminationDecision, Decision);
+};
+
+/* 'StarEliminationDecision' is for the Star Elimination in the DP algorithm:
+ *
+ *    v <--- u ---> w   ==>   v    w
+ * In the original graph, u has 0 incoming edges and multiple outgoing edges.
+ * In addition, v and w have other complicated connections, so that no other elimination can be applied to
+ * v or w. After the Star Elimination, u is merged into v, and the resulting graph is split into multiple
+ * connected components.
+ * NOTE: this elimination MUST be performed only when the above five operations cannot be applied.
+ */
+struct StarEliminationDecision : public Decision {
+  StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist,
+                          std::vector<StrategyPtr> succ_ops_stra_list, CostPtrList succ_ops_clist)
+      : eliminated_op_strategy_(std::move(elimi_op_stra)),
+        eliminated_op_cost_(std::move(elimi_op_cost)),
+        succ_edges_cost_list_(std::move(succ_edges_clist)),
+        succ_ops_stra_list_(std::move(succ_ops_stra_list)),
+        succ_ops_cost_list_(std::move(succ_ops_clist)) {
+    type_ = DecisionType::STAR_ELIMINATION;
+  }
+
+  StrategyPtr eliminated_op_strategy_;
+  CostPtr eliminated_op_cost_;
+  CostPtrList succ_edges_cost_list_;
+  std::vector<StrategyPtr> succ_ops_stra_list_;
+  CostPtrList succ_ops_cost_list_;
+  MS_DECLARE_PARENT(StarEliminationDecision, Decision);
+};
+
+// This data structure records the decision for the graph which contains two nodes: u --> v. This includes
+// the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', the cost 'left_cost_' for 'u',
+// the edge cost 'middle_cost_' for 'u --> v', and the cost 'right_cost_' for 'v'.
+struct FinalDecision : public Decision {
+  FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost)
+      : u_strategy_(std::move(u_stra)),
+        v_strategy_(std::move(v_stra)),
+        left_cost_(std::move(l_cost)),
+        middle_cost_(std::move(m_cost)),
+        right_cost_(std::move(r_cost)) {
+    type_ = DecisionType::FINAL_TYPE;
+  }
+
+  StrategyPtr u_strategy_;
+  StrategyPtr v_strategy_;
+  CostPtr left_cost_;
+  CostPtr middle_cost_;
+  CostPtr right_cost_;
+  MS_DECLARE_PARENT(FinalDecision, Decision);
+};
+
+// This data structure records the final decision for the graph containing a single node: u. This includes
+// the strategy 'u_strategy_' for 'u' and the cost 'u_cost_' for 'u'.
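+// E.g. (hypothetical) if the final single node u still carries the cost list {<10, 4>, <6, 9>}, the search
+// phase keeps the entry minimizing COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost
+// and records the winning strategy and cost in this decision.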
+struct FinalSingleDecision : public Decision {
+  FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) {
+    type_ = DecisionType::FINAL_SINGLE;
+  }
+
+  StrategyPtr u_strategy_;
+  CostPtr u_cost_;
+  MS_DECLARE_PARENT(FinalSingleDecision, Decision);
+};
+
+using DecisionPtr = std::shared_ptr<Decision>;
+using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>;
+using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>;
+using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>;
+using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>;
+using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>;
+using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>;
+using FinalDecisionPtr = std::shared_ptr<FinalDecision>;
+using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>;
+
+void Simplify(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationForward(CostPtrList *clist);
+void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist);
+void RefineForPracticalCost(const CostPtr &, bool is_redistribution);
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
new file mode 100644
index 0000000000..9408596111
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.cc
@@ -0,0 +1,226 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+namespace mindspore {
+namespace parallel {
+Status GetStrategy(const CostGraphPtr &graph) {
+  MS_LOG(INFO) << "Searching strategies begins.";
+  MS_EXCEPTION_IF_NULL(graph);
+  std::vector<EliminationPtr> eliminations;
+  bool flag = true;
+
+  // Phase 1: Shrink the CostGraph using 6 operations, and record them in order.
+  // Note: the checking and applying of the 6 operations MUST be in the current order.
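+  // E.g. (illustrative) one round on a chain u --> v --> w: CheckOpElimination() returns v, since v has
+  // exactly one alive incoming and one alive outgoing edge; EliminationOp(v) then replaces both edges with
+  // a single edge u --> w, and the recorded OpElimination lets Phase 3 replay this decision in reverse.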
+  while (flag) {
+    flag = false;
+    auto node = graph->CheckOpElimination();
+    if (node != nullptr) {
+      // Applying the Operator Elimination
+      flag = true;
+      auto l_edge = node->GetAlivePrevEdges()[0];
+      auto r_edge = node->GetAliveSuccEdges()[0];
+      auto n_edge = graph->EliminationOp(node);
+      auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
+      eliminations.emplace_back(std::move(elimi));
+    }
+    auto edges = graph->CheckEdgeElimination();
+    if ((!flag) && (!edges.empty())) {
+      // Applying the Edge Elimination
+      flag = true;
+      auto n_edge = graph->EliminationEdges(edges);
+      auto elimi = std::make_shared<EdgeElimination>(n_edge, edges);
+      eliminations.emplace_back(std::move(elimi));
+    }
+    auto merge_node = graph->CheckMergeElimination();
+    if ((!flag) && (merge_node != nullptr)) {
+      // Applying the Merge Elimination
+      flag = true;
+      auto succ_edge = merge_node->GetAliveSuccEdges()[0];
+      auto target_node = graph->EliminationMerge(merge_node);
+      auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
+      eliminations.emplace_back(std::move(elimi));
+    }
+    auto contracted_node = graph->CheckContractElimination();
+    if ((!flag) && (contracted_node != nullptr)) {
+      // Applying the Contract Elimination
+      flag = true;
+      auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
+      auto target_node = graph->EliminationContract(contracted_node);
+      auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
+      eliminations.emplace_back(std::move(elimi));
+    }
+    auto triangle_pair = graph->CheckTriangleElimination();
+    if ((!flag) && (triangle_pair.first != nullptr)) {
+      // Applying the Triangle Elimination
+      flag = true;
+      auto eliminated_node = triangle_pair.first;
+      auto l_r_edge = triangle_pair.second;
+
+      auto left_node = l_r_edge->prev_operator();
+      auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
+      auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
+      MS_EXCEPTION_IF_NULL(left_edge);
+      if (left_edge->next_operator() != left_node) {
+        auto tmp = left_edge;
+        left_edge = right_edge;
+        right_edge = tmp;
+      }
+      auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
+      auto right_node = l_r_edge->next_operator();
+      auto elimi =
+        std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
+      eliminations.emplace_back(std::move(elimi));
+    }
+    auto star_center = graph->CheckStarElimination();
+    if ((!flag) && (star_center != nullptr)) {
+      // Applying the Star Elimination
+      flag = true;
+      auto succ_edges = graph->EliminationStar(star_center);
+      std::vector<OperatorInfoPtr> succ_nodes;
+      for (size_t i = 0; i < succ_edges.size(); ++i) {
+        MS_EXCEPTION_IF_NULL(succ_edges[i]);
+        succ_nodes.push_back(succ_edges[i]->next_operator());
+      }
+      auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
+      eliminations.emplace_back(std::move(elimi));
+    }
+  }
+
+  // Phase 2: Search the cost_list in the final graph, and determine the optimal one
+  if (graph->SearchStrategy() != SUCCESS) {
+    MS_LOG(ERROR) << "Searching strategy for the final graph failed.";
+    return FAILED;
+  }
+
+  // Phase 3: Recover the original CostGraph, then determine the strategy for each operator
+  if (RecoverStrategy(eliminations) == SUCCESS) {
+    MS_LOG(INFO) << "Searching strategies ends.";
+    return SUCCESS;
+  } else {
+    MS_LOG(EXCEPTION) << "Searching strategies failed.";
+  }
+}
+
+Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
+  std::vector<EliminationPtr>::reverse_iterator rit;
+
+  for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
+    if ((*rit)->isa<OpElimination>()) {
+      auto elimination = (*rit)->cast<OpEliminationPtr>();
+      auto e = elimination->new_edge_;
+      auto w = elimination->op_;
+      MS_EXCEPTION_IF_NULL(e);
+      MS_EXCEPTION_IF_NULL(w);
+      auto left_edge = elimination->left_edge_;
+      auto right_edge = elimination->right_edge_;
+      MS_EXCEPTION_IF_NULL(left_edge);
+      MS_EXCEPTION_IF_NULL(right_edge);
+      auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
+      w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_);
+      left_edge->set_selected_cost(decision->left_cost_);
+      right_edge->set_selected_cost(decision->right_cost_);
+      MS_LOG(INFO) << "Recover opElimination succeeded.";
+    } else if ((*rit)->isa<EdgeElimination>()) {
+      auto elimination = (*rit)->cast<EdgeEliminationPtr>();
+      auto new_edge = elimination->new_edge_;
+      MS_EXCEPTION_IF_NULL(new_edge);
+      auto &edges = elimination->edges_;
+      auto decision = new_edge->selected_cost()->decision_ptr_->cast<EdgeEliminationDecisionPtr>();
+      for (size_t j = 0; j < edges.size(); ++j) {
+        MS_EXCEPTION_IF_NULL(edges[j]);
+        edges[j]->set_selected_cost(decision->edges_cost_list_[j]);
+      }
+      MS_LOG(INFO) << "Recover edgeElimination succeeded.";
+    } else if ((*rit)->isa<MergeElimination>()) {
+      auto elimination = (*rit)->cast<MergeEliminationPtr>();
+      auto target_node = elimination->target_node_;
+      MS_EXCEPTION_IF_NULL(target_node);
+      auto merged_node = elimination->merged_node_;
+      MS_EXCEPTION_IF_NULL(merged_node);
+      auto merged_edge = elimination->dir_edge_;
+      MS_EXCEPTION_IF_NULL(merged_edge);
+      MS_EXCEPTION_IF_NULL(target_node->selected_cost());
+      MS_EXCEPTION_IF_NULL(target_node->selected_cost()->decision_ptr_);
+      auto decision = target_node->selected_cost()->decision_ptr_->cast<MergeEliminationDecisionPtr>();
+      merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_);
+      merged_edge->set_selected_cost(decision->edge_cost_);
+      target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_);
+
+      MS_LOG(INFO) << "Recover mergeElimination succeeded.";
+    } else if ((*rit)->isa<ContractElimination>()) {
+      auto elimination = (*rit)->cast<ContractEliminationPtr>();
+      auto target_node = elimination->target_node_;
+      auto contracted_node = elimination->contracted_node_;
+      auto contracted_edge = elimination->dir_edge_;
+      auto decision = target_node->selected_cost()->decision_ptr_->cast<ContractEliminationDecisionPtr>();
+
+      contracted_node->SetSelectedStrategyAndCost(decision->contracted_op_strategy_, decision->contracted_op_cost_);
+      contracted_edge->set_selected_cost(decision->edge_cost_);
+      target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_cost_);
+      MS_LOG(INFO) << "Recover contractElimination succeeded.";
+    } else if ((*rit)->isa<TriangleElimination>()) {
+      auto elimination = (*rit)->cast<TriangleEliminationPtr>();
+      auto left_node = elimination->left_node_;
+      auto left_edge = elimination->left_edge_;
+      auto eliminated_node = elimination->eliminated_node_;
+      auto right_edge = elimination->right_edge_;
+      auto right_node = elimination->right_node_;
+      auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
+
+      eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
+      left_edge->set_selected_cost(decision->left_edge_cost_);
+      right_edge->set_selected_cost(decision->right_edge_cost_);
+      // Since the Triangle is eliminated into 'left_node', only 'left_node' needs to recover its strategy.
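+      // ('right_node' stays in the graph with its strategy already fixed, so it is only checked against the
+      // recorded decision below rather than set again.)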
+      left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_);
+      right_node->CheckSelectedStrategy(decision->right_node_strategy_);
+      MS_LOG(INFO) << "Recover triangleElimination succeeded.";
+    } else if ((*rit)->isa<StarElimination>()) {
+      auto elimination = (*rit)->cast<StarEliminationPtr>();
+      auto merged_node = elimination->eliminated_node_;
+      auto succ_edges = elimination->succ_edges_;
+      auto succ_nodes = elimination->succ_ops_;
+      // the decision is hidden in succ_nodes[0]
+      auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
+
+      merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_);
+      for (size_t i = 0; i < succ_edges.size(); ++i) {
+        succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]);
+      }
+      MS_EXCEPTION_IF_NULL(succ_nodes[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]);
+      MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]);
+      // Since the Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' needs to recover its strategy.
+      succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]);
+      for (size_t k = 1; k < succ_nodes.size(); ++k) {
+        succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]);
+      }
+      MS_LOG(INFO) << "Recover starElimination succeeded.";
+    } else {
+      MS_LOG(ERROR) << "Unknown Elimination type.";
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h
new file mode 100644
index 0000000000..812f375f0b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/dp_algo_costmodel.h
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
+#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+
+namespace mindspore {
+namespace parallel {
+// There are 3 meta phases in the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal
+// is to compute the strategy for each operator in the CostGraph.
+//
+// Phase 1: Shrink the CostGraph using 6 operations, and record them in order.
+// Using four operations: Operator Elimination, Edge Elimination, Merge Elimination, and Contract Elimination,
+// each connected component in the CostGraph can be shrunk into the final graph: u --> v. See the
+// interpretation of the 6 operations in costmodel.h.
+// Phase 2: Search the cost_list in the final graph, and determine the optimal one.
+// Create the cost_list for the final graph, and choose the optimal one: the one with the minimum quantity
+// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost.
+// Phase 3: Recover the original CostGraph, then determine the strategy for each operator.
+// After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying
+// the operations in the reverse order of Phase 1. Because each operation's decision contains the strategy,
+// the operators' strategies can all be determined.
+
+struct Elimination : public Base {
+  enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR };
+  Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
+
+  EdgePtr new_edge_;
+  EliminationType type_;
+};
+
+// Operator Elimination
+struct OpElimination : public Elimination {
+  OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge)
+      : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA),
+        left_edge_(std::move(l_edge)),
+        op_(std::move(op_info)),
+        right_edge_(std::move(r_edge)) {}
+
+  EdgePtr left_edge_;
+  OperatorInfoPtr op_;
+  EdgePtr right_edge_;
+  MS_DECLARE_PARENT(OpElimination, Elimination);
+};
+
+// Edge Elimination
+struct EdgeElimination : public Elimination {
+  EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds)
+      : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {}
+
+  std::vector<EdgePtr> edges_;
+  MS_DECLARE_PARENT(EdgeElimination, Elimination);
+};
+
+// Merge Elimination
+struct MergeElimination : public Elimination {
+  MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info)
+      : Elimination(nullptr, Elimination::EliminationType::MERGE),
+        merged_node_(std::move(u_info)),
+        dir_edge_(std::move(merged_target_edge)),
+        target_node_(std::move(v_info)) {}
+
+  OperatorInfoPtr merged_node_;
+  EdgePtr dir_edge_;
+  OperatorInfoPtr target_node_;
+  MS_DECLARE_PARENT(MergeElimination, Elimination);
+};
+
+// Contract Elimination
+struct ContractElimination : public Elimination {
+  ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info)
+      : Elimination(nullptr, Elimination::EliminationType::CONTRACT),
+        contracted_node_(std::move(con_info)),
+        dir_edge_(std::move(tar_con_edge)),
+        target_node_(std::move(tar_info)) {}
+
+  OperatorInfoPtr contracted_node_;
+  EdgePtr dir_edge_;
+  OperatorInfoPtr target_node_;
+  MS_DECLARE_PARENT(ContractElimination, Elimination);
+};
+
+// Triangle Elimination
+struct TriangleElimination : public Elimination {
+  TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
+                      OperatorInfoPtr r_node)
+      : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
+        eliminated_node_(std::move(elim_node)),
+        left_edge_(std::move(l_edge)),
+        left_node_(std::move(l_node)),
+        right_edge_(std::move(r_edge)),
+        right_node_(std::move(r_node)) {}
+
+  OperatorInfoPtr eliminated_node_;
+  EdgePtr left_edge_;
+  OperatorInfoPtr left_node_;
+  EdgePtr right_edge_;
+  OperatorInfoPtr right_node_;
+  MS_DECLARE_PARENT(TriangleElimination, Elimination);
+};
+
+// Star Elimination
+struct StarElimination : public Elimination {
+  StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops)
+      : Elimination(nullptr, Elimination::EliminationType::STAR),
+        eliminated_node_(std::move(elimi_node)),
+        succ_edges_(std::move(s_edges)),
+        succ_ops_(std::move(s_ops)) {}
+
+  OperatorInfoPtr eliminated_node_;
+  std::vector<EdgePtr> succ_edges_;
+  std::vector<OperatorInfoPtr> succ_ops_;
+  MS_DECLARE_PARENT(StarElimination, Elimination);
+};
+
+using EliminationPtr = std::shared_ptr<Elimination>;
+using OpEliminationPtr = std::shared_ptr<OpElimination>;
+using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
+using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
+using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
+using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
+using StarEliminationPtr = std::shared_ptr<StarElimination>;
+
+// Phase 1 and Phase 2
+Status GetStrategy(const CostGraphPtr &graph);
+
+// Phase 3
+Status RecoverStrategy(std::vector<EliminationPtr> eliminations);
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
new file mode 100644
index 0000000000..e3f1de7207
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.cc
@@ -0,0 +1,324 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+
+#include <algorithm>
+#include <functional>
+#include <utility>
+#include <vector>
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+Status Edge::InitEdgeCost() {
+  bool has_available_cost = false;
+  for (auto &swc : prev_op_->GetStrategyCost()) {
+    MS_EXCEPTION_IF_NULL(swc);
+    pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
+  }
+  for (auto &swc : next_op_->GetStrategyCost()) {
+    MS_EXCEPTION_IF_NULL(swc);
+    next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
+  }
+  if (is_identity_edge) {
+    for (auto &target_output : pre_op_output_) {
+      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
+      auto target_output_str = target_output.first;
+      for (auto &target_input : next_op_input_) {
+        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
+        auto target_input_str = target_input.first;
+        if (target_output_lyt == target_input_lyt) {
+          CostPtrKey ck = {target_output_str, target_input_str};
+          CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
+          MS_EXCEPTION_IF_NULL(cost);
+          cost->communication_without_parameter_ = 0.0;
+          cost->communication_with_partial_para_ = 0.0;
+          CostPtrList cl;
+          cl.push_back(cost);
+          (void)cost_map_.emplace(std::make_pair(ck, cl));
+          has_available_cost = true;
+        }
+      }
+    }
+  } else {
+    for (auto &target_output : pre_op_output_) {
+      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
+      auto target_output_str = target_output.first;
+      auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
+      auto type = prev_op_->outputs_type()[prev_op_output_index_];
+      for (auto &target_input : next_op_input_) {
+        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
+        auto target_input_str = target_input.first;
+        CostPtr cost;
+        if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
+          MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
+        }
+        MS_EXCEPTION_IF_NULL(cost);
+        MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
+                      << ", communication_cost: " << cost->communication_cost_
+                      << ", communication_without_parameter_: " << cost->communication_without_parameter_
+                      << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
+        // refine the communication cost calculation for practice
+        RefineForPracticalCost(cost, true);
+        cost->communication_forward_ = cost->communication_redis_forward_;
+        CostPtrKey ck = {target_output_str, target_input_str};
+        CostPtrList cl;
+        cl.push_back(cost);
+        (void)cost_map_.emplace(std::make_pair(ck, cl));
+        has_available_cost = true;
+      }
+    }
+  }
+  if (!has_available_cost) {
+    if (FULLY_USE_DEVICES) {
+      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
+                        << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
+                           "'fully_use_devices' false.";
+    } else if (ELEMENTWISE_OP_STRA_FOLLOW) {
+      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
+                        << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
+                           "Try to set 'elementwise_op_strategy_follow' false.";
+    }
+    if (edge_name_.find(RESHAPE) != std::string::npos) {
+      MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
+                        << " failed, it may be caused by setting different strategies for operators following "
+                           "Reshape. Try to fix that.";
+    }
+    MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
+  }
+  return Status::SUCCESS;
+}
+
+Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
+                                   size_t type_length, TypePtr type, CostPtr *cost) {
+  MS_EXCEPTION_IF_NULL(prev_op_);
+  MS_EXCEPTION_IF_NULL(cost);
+  RankList dev_list = prev_op_->global_device_list();
+  TensorRedistribution tensor_redistribution(false);
+
+  // Init TensorRedistribution
+  if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
+    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
+  }
+
+  if (tensor_redistribution.ComputeCost() == FAILED) {
+    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
+  }
+
+  double comm_cost = tensor_redistribution.comm_cost();
+  double forward_comm_cost = tensor_redistribution.forward_comm_cost();
+  double backward_comm_cost = tensor_redistribution.backward_comm_cost();
+  double computation_cost = tensor_redistribution.computation_cost();
+  double mem_cost = tensor_redistribution.memory_cost();
+
+  // For now, AllGather, ReduceScatter and AlltoAll do not support the bool type
+  MS_EXCEPTION_IF_NULL(type);
+  if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
+    computation_cost = INF;
+    comm_cost = INF;
+    MS_LOG(WARNING) << "Communication operators do not support the bool dtype!";
+  }
+  *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
+  (*cost)->communication_without_parameter_ = type_length * comm_cost;
+  (*cost)->communication_with_partial_para_ =
+    (*cost)->communication_without_parameter_ +
+    COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
+  (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
+  (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
+  (*cost)->memory_with_reuse_ = mem_cost;
+  return Status::SUCCESS;
+}
+
+CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
+  CostPtrKey ck = {output_str, input_str};
+  CostPtrList result;
+  if (cost_map_.find(ck) != cost_map_.end()) {
+    return cost_map_.at(ck);
+  }
+  return result;
+}
+
+CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
+                                                const StrategyPtr &input_st_ptr) {
+  std::function<CostPtrList(const EdgePtr &)> LocalGetCostList = [&](const EdgePtr &edge) {
+    MS_EXCEPTION_IF_NULL(edge);
+    return edge->GetCostList(output_st_ptr, input_st_ptr);
+  };
+  CostPtrList result;
+  std::vector<CostPtrList> all_cost_list;
+  all_cost_list.resize(edges.size());
+  (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
+
+  CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
+  std::function<void(size_t, double, double, double, double, double)> recursive =
+    [&](size_t k, double computation, double memory, double communication, double communication_without_para,
+        double communication_forward) {
+      if (k == edges.size()) {
+        auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
+        CostPtr new_cost = std::make_shared<Cost>(computation, communication);
+        MS_EXCEPTION_IF_NULL(new_cost);
+        new_cost->communication_without_parameter_ = communication_without_para;
+        new_cost->communication_with_partial_para_ =
+          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+        new_cost->memory_with_reuse_ = memory;
+        new_cost->communication_forward_ = communication_forward;
+        new_cost->decision_ptr_ = decision;
+        result.push_back(new_cost);
+        return;
+      }
+      for (auto &c : all_cost_list[k]) {
+        MS_EXCEPTION_IF_NULL(c);
+        selected_cost_list[k] = c;
+        recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
+                  communication + c->communication_cost_,
+                  communication_without_para + c->communication_without_parameter_,
+                  communication_forward + c->communication_forward_);
+      }
+    };
+  recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
+  Simplify(&result);
+  return result;
+}
+
+void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
+  bool valid = false;
+  for (const auto &output_pair : pre_op_output_) {
+    StrategyPtr output_st_ptr = output_pair.first;
+    for (const auto &input_pair : next_op_input_) {
+      StrategyPtr input_st_ptr = input_pair.first;
+      CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
+      CostPtrKey key = {output_st_ptr, input_st_ptr};
+      cost_map_[key] = clist;
+      if ((!valid) && (!clist.empty())) {
+        valid = true;
+      }
+    }
+  }
+  if (!valid) {
+    MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
+  }
+}
+
+void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
+                                          const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
+                                          CostPtrList *ret_cost_list) {
+  for (auto &left_cost : left_cost_list) {
+    MS_EXCEPTION_IF_NULL(left_cost);
+    for (auto &middle_cost : middle_cost_list) {
+      MS_EXCEPTION_IF_NULL(middle_cost);
+      for (auto &right_cost : right_cost_list) {
+        MS_EXCEPTION_IF_NULL(right_cost);
+        double computation =
+          left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
+        double communication =
+          left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
+        double communication_forward =
+          left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
+        double communication_without_para = left_cost->communication_without_parameter_ +
+                                            middle_cost->communication_without_parameter_ +
+                                            right_cost->communication_without_parameter_;
+        double memory_cost =
+          left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
+
+        auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
+        auto cost = std::make_shared<Cost>(computation, communication, decision);
+        MS_EXCEPTION_IF_NULL(cost);
+        cost->communication_without_parameter_ = communication_without_para;
+        cost->communication_with_partial_para_ =
+          communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+        cost->memory_with_reuse_ = memory_cost;
+        cost->communication_forward_ = communication_forward;
+        ret_cost_list->emplace_back(std::move(cost));
+      }
+    }
+  }
+}
+
+CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
+                                              const OperatorInfoPtr &op, const EdgePtr &e2,
+                                              const StrategyPtr &input_st_ptr) {
+  MS_EXCEPTION_IF_NULL(op);
+  MS_EXCEPTION_IF_NULL(e1);
+  MS_EXCEPTION_IF_NULL(e2);
+  CostPtrList result;
+  for (const auto &op_strategy : op->GetStrategyCost()) {
+    MS_EXCEPTION_IF_NULL(op_strategy);
+    auto middle_strategy = op_strategy->strategy_ptr;
+    CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
+                                   op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
+  }
+  Simplify(&result);
+  return result;
+}
+
+void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
+  bool valid = false;
+  for (const auto &output_pair : pre_op_output_) {
+    StrategyPtr output_st_ptr = output_pair.first;
+    for (const auto &input_pair : next_op_input_) {
+      StrategyPtr input_st_ptr = input_pair.first;
+
+      CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
+      CostPtrKey key = {output_st_ptr, input_st_ptr};
+      cost_map_[key] = clist;
+      if ((!valid) && (!clist.empty())) {
+        valid = true;
+      }
+    }
+  }
+  if (!valid) {
+    MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
+  }
+}
+
+Status Edge::CalculateMemoryCost() {
+  if (is_output_parameter_involve_ == -1) {
+    MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
+    return FAILED;
+  }
+  if (is_output_parameter_involve_ == 0) {
+    // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
+    // unnecessary to keep it in memory.
+    for (auto &cost_kv : cost_map_) {
+      auto &cost_v = cost_kv.second;
+      if (!cost_v.empty()) {
+        cost_v[0]->memory_with_reuse_ = 0;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status Edge::CalculateMemoryCostForInference() {
+  // Currently, the memory cost is NOT calculated for redistribution
+  if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
+    MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
+    return FAILED;
+  }
+  for (auto &cost_kv : cost_map_) {
+    auto &cost_v = cost_kv.second;
+    if (!cost_v.empty()) {
+      cost_v[0]->memory_with_reuse_ = 0;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
new file mode 100644
index 0000000000..3fffd1b86d
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/edge_costmodel.h
@@ -0,0 +1,171 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
+#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "common/utils.h"
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+
+namespace mindspore {
+namespace parallel {
+using CostPtrKey = std::pair<StrategyPtr, StrategyPtr>;
+using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
+using EdgePtr = std::shared_ptr<Edge>;
+
+class Edge {
+  // An 'Edge' connects two Operators in the CostGraph.
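+  // For each (prev-op output strategy, next-op input strategy) pair, cost_map_ stores the candidate costs;
+  // e.g. (hypothetical strategies s_u, s_v) GetCostList(s_u, s_v) returns that list, or an empty list if the
+  // pair was never initialized.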
+ public: + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + prev_op_output_index_(output_index_), + next_op_input_index_(input_index_), + is_combined_(is_com) { + is_identity_edge = false; + } + + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com, const bool &is_iden) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + prev_op_output_index_(output_index_), + next_op_input_index_(input_index_), + is_combined_(is_com), + is_identity_edge(is_iden) {} + + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const std::vector &output_indexs_, + const std::vector &input_indexs_, const bool &is_com) + : edge_name_(edge_name), + prev_op_(prev_op), + next_op_(next_op), + pre_op_output_indexs_(output_indexs_), + next_op_input_indexs_(input_indexs_), + is_combined_(is_com) { + prev_op_output_index_ = 0; + next_op_input_index_ = 0; + is_identity_edge = false; + } + + ~Edge() = default; + std::shared_ptr prev_operator() const { return prev_op_; } + std::shared_ptr next_operator() const { return next_op_; } + std::string edge_name() const { return edge_name_; } + // Init cost_map_: for each output layout and input layout, calculate the cost + Status InitEdgeCost(); + // For two operators u--->v, given the output tensor layout of u, + // and the input tensor layout of v, return the redistribution cost, + // and the op_list to carry out the redistribution. + Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t, TypePtr type, CostPtr *cost); + + void set_pre_op_output(const std::vector, std::vector>> &output_set) { + pre_op_output_ = output_set; + } + void set_next_op_input(const std::vector, std::vector>> &input_set) { + next_op_input_ = input_set; + } + + // Given a pair of output strategy and input strategy, return the corresponding costlist + CostPtrList GetCostList(StrategyPtr output_str, StrategyPtr input_str); + + std::vector, std::vector>> prev_op_output() const { + return pre_op_output_; + } + std::vector, std::vector>> next_op_input() const { + return next_op_input_; + } + + bool is_combined() const { return is_combined_; } + size_t prev_op_output_index() const { return prev_op_output_index_; } + size_t next_op_input_index() const { return next_op_input_index_; } + std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } + std::vector next_op_input_indexs() const { return next_op_input_indexs_; } + + CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, + const std::vector> &edges, + const StrategyPtr &input_st_ptr); + // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. 
This method is used to + // set cost for this new edge + void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, + std::shared_ptr v); + void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list); + + CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, + const std::shared_ptr &op, const std::shared_ptr &e2, + const StrategyPtr &input_st_ptr); + // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. + // This method is used to set cost for this new edge + void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, + const std::shared_ptr &e2); + + void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } + const CostPtr &selected_cost() const { return selected_cost_; } + void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } + // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving + // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory + // at the end of forward phase. + Status CalculateMemoryCost(); + // In the inference phase, + Status CalculateMemoryCostForInference(); + void mark_output_critical() { is_output_critical_ = 1; } + + private: + std::string edge_name_; + std::shared_ptr prev_op_, next_op_; + std::map cost_map_; + // pre_op_output_ + std::vector, std::vector>> pre_op_output_; + std::vector, std::vector>> next_op_input_; + // the index of outputs of prev_op, and the index of inputs of next_op + size_t prev_op_output_index_, next_op_input_index_; + + // pre_op_output_indexs_ and next_op_input_indexs_ store the indexs of inputs and outputs if is_combined = true + std::vector pre_op_output_indexs_; + std::vector next_op_input_indexs_; + // is this edge constructed by combining multiple edges? If is is, then is_combined = true, else is_combined = false + bool is_combined_; + // When a Parameter in the ANF graph being used by multiple operators, we include the Parameter in the costgraph by + // replace the Parameter by a TmpIdentity operator, and connecting this TmpIdentity operator with subsequent + // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them. + // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. + bool is_identity_edge; + CostPtr selected_cost_; + // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator + // is parameter-involved + int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved + // In the inference phase, this is used to mark whether the output of the previous operator is critical. 
+ int is_output_critical_ = 0; +}; +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc new file mode 100644 index 0000000000..1c1fc3a700 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.cc @@ -0,0 +1,1677 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/step_auto_parallel.h" + +namespace mindspore { +namespace parallel { +CostGraphPtr entire_costgraph = nullptr; +size_t TOTAL_OPS = 0; +double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA; +bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; +double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY; +double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; +double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST; +double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS; +bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; +size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; +bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; +bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; +bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; +int32_t RUN_PHASE = DEFAULT_RUN_PHASE; + +void CostGraph::SetDeviceMemoryAndCostParameter() { + MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); + + // DEVICE_MEMORY_CAPACITY + auto device_memory = CostModelContext::GetInstance()->device_memory_capacity(); + if (device_memory <= 0) { + MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive."; + } + dev_memory_ = device_memory; + DEVICE_MEMORY_CAPACITY = device_memory; + MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << "."; + + // COST_MODEL_ALPHA + auto alpha = CostModelContext::GetInstance()->costmodel_alpha(); + if (alpha <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive."; + } + costmodel_alpha_ = alpha; + MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << "."; + + // COST_MODEL_BETA + auto beta = CostModelContext::GetInstance()->costmodel_beta(); + if (beta <= 0) { + MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive."; + } + costmodel_beta_ = beta; + MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << "."; + + // COST_MODEL_GAMMA + auto gamma = CostModelContext::GetInstance()->costmodel_gamma(); + if ((gamma < 0) || (gamma > 1)) { + MS_LOG(EXCEPTION) << "'costmodel_gamma' must in [0, 1]."; + } + COST_MODEL_GAMMA = gamma; + MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << "."; + + // COST_MODEL_SIMPLIFY_CALCULATION + auto simplify = 
CostModelContext::GetInstance()->costmodel_simplify_cal(); + COST_MODEL_SIMPLIFY_CALCULATION = simplify; + if (COST_MODEL_SIMPLIFY_CALCULATION) { + MS_LOG(INFO) << "costmodel_simplify_cal: true."; + } else { + MS_LOG(INFO) << "costmodel_simplify_cal: false."; + } + + // COST_MODEL_COMMUNI_THRESHOLD + auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold(); + if (communi_threshold < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-zero."; + } + COST_MODEL_COMMUNI_THRESHOLD = communi_threshold; + MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << "."; + + // COST_MODEL_COMMUNI_CONST + auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const(); + if (communi_const < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-zero."; + } + COST_MODEL_COMMUNI_CONST = communi_const; + MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << "."; + + // COST_MODEL_COMMUNI_BIAS + auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias(); + if (communi_bias < 0) { + MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-zero."; + } + COST_MODEL_COMMUNI_BIAS = communi_bias; + MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << "."; + + // TENSOR_SLICE_ALIGNMENT_ENABLE + auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable(); + TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable; + if (TENSOR_SLICE_ALIGNMENT_ENABLE) { + MS_LOG(INFO) << "tensor_slice_align_enable: true."; + } else { + MS_LOG(INFO) << "tensor_slice_align_enable: false."; + } + + // TENSOR_SLICE_ALIGNMENT_SIZE + auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size(); + if (align_size == 0) { + MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive."; + } + TENSOR_SLICE_ALIGNMENT_SIZE = align_size; + MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << "."; + + // FULLY_USE_DEVICES + auto fully_devices = CostModelContext::GetInstance()->fully_use_device(); + FULLY_USE_DEVICES = fully_devices; + if (FULLY_USE_DEVICES) { + MS_LOG(INFO) << "fully_use_devices: true."; + } else { + MS_LOG(INFO) << "fully_use_devices: false."; + } + + // ELEMENTWISE_OP_STRA_FOLLOW + auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow(); + ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow; + if (ELEMENTWISE_OP_STRA_FOLLOW) { + MS_LOG(INFO) << "elementwise_op_strategy_follow: true."; + } else { + MS_LOG(INFO) << "elementwise_op_strategy_follow: false."; + } + + // MULTI_SUBGRAPHS + auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs(); + MULTI_SUBGRAPHS = multi_subgraphs; + if (MULTI_SUBGRAPHS) { + MS_LOG(INFO) << "multi_subgraphs: true."; + } else { + MS_LOG(INFO) << "multi_subgraphs: false."; + } + + // RUN_PHASE + auto phase = CostModelContext::GetInstance()->run_phase(); + if (phase != 0 && phase != 1) { + MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}"; + } + RUN_PHASE = phase; + MS_LOG(INFO) << "run_phase: " << RUN_PHASE << "."; +} + +void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { + for (auto it = ops_.begin(); it != ops_.end();) { + if ((*it) == op) { + it = ops_.erase(it); + } else { + ++it; + } + } +} + +bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { + struct IsInGraph { + const OperatorInfoPtr test_; + explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} + bool operator()(const OperatorInfoPtr &in) 
+  };
+  return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
+}
+
+void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
+  std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
+  curr_edges.push_back(edge);
+  edges_[{u_node, v_node}] = curr_edges;
+
+  std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
+  curr_out_edges.push_back(edge);
+  out_edges_[u_node] = curr_out_edges;
+
+  std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
+  curr_in_edges.push_back(edge);
+  in_edges_[v_node] = curr_in_edges;
+}
+
+bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
+  for (auto &edge_pair : edges_) {
+    auto edges = edge_pair.second;
+    for (auto &edge : edges) {
+      MS_EXCEPTION_IF_NULL(edge);
+      bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
+                         (edge->next_op_input_index() == input_index);
+      if (bool_result) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
+  std::vector<OperatorInfoPtr> alive_ops) {
+  std::map<OperatorInfoPtr, bool> visited;
+
+  for (auto &op : alive_ops) {
+    visited[op] = false;
+  }
+
+  MS_LOG(INFO) << "visited: " << visited.size() << ".";
+  for (auto &op : alive_ops) {
+    if ((!visited[op]) && op->is_alive()) {
+      std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
+      MS_EXCEPTION_IF_NULL(new_component);
+      new_component->SetDeviceMemoryAndCostParameter();
+      DFS(op, &visited, new_component);
+      connected_compoents_.push_back(new_component);
+    }
+  }
+  return connected_compoents_;
+}
+
+void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
+                    const std::shared_ptr<CostGraph> &component) {
+  MS_EXCEPTION_IF_NULL(visited);
+  MS_EXCEPTION_IF_NULL(component);
+  visited->at(current_op) = true;
+  component->AddOperator(current_op);
+
+  for (auto &edge : current_op->succ_edges()) {
+    bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
+                     (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
+    if (bool_test) {
+      component->AddEdge(current_op, edge->next_operator(), edge);
+      DFS(edge->next_operator(), visited, component);
+    }
+  }
+
+  for (auto &edge : current_op->prev_edges()) {
+    bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
+                     (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
+    if (bool_test) {
+      component->AddEdge(edge->prev_operator(), current_op, edge);
+      DFS(edge->prev_operator(), visited, component);
+    }
+  }
+}
+
+// Create final cost list for the graph: u --> v
+CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
+                                           const OperatorInfoPtr &v) {
+  MS_EXCEPTION_IF_NULL(u);
+  MS_EXCEPTION_IF_NULL(v);
+  MS_EXCEPTION_IF_NULL(e);
+  CostPtrList ret;
+  for (const auto &u_strategy : u->GetStrategyCost()) {
+    for (const auto &v_strategy : v->GetStrategyCost()) {
+      MS_EXCEPTION_IF_NULL(u_strategy);
+      MS_EXCEPTION_IF_NULL(v_strategy);
+      auto u_strategy_ptr = u_strategy->strategy_ptr;
+      auto v_strategy_ptr = v_strategy->strategy_ptr;
+      CostPtrList clist1 = u_strategy->cost_list;
+      CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
+      CostPtrList clist3 = v_strategy->cost_list;
+      for (const auto &cost1 : clist1) {
+        for (const auto &cost2 : clist2) {
+          for (const auto &cost3 : clist3) {
+            MS_EXCEPTION_IF_NULL(cost1);
+            MS_EXCEPTION_IF_NULL(cost2);
+            MS_EXCEPTION_IF_NULL(cost3);
+            double computation = cost1->computation_cost_ +
+                                 cost2->computation_cost_ + cost3->computation_cost_;
+            double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
+            double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
+            double communication_forward =
+              cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
+            double communication_without_para = cost1->communication_without_parameter_ +
+                                                cost2->communication_without_parameter_ +
+                                                cost3->communication_without_parameter_;
+            auto decision =
+              std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
+            auto cost = std::make_shared<Cost>(computation, communication, decision);
+            MS_EXCEPTION_IF_NULL(cost);
+            cost->communication_without_parameter_ = communication_without_para;
+            cost->communication_with_partial_para_ =
+              communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
+            cost->memory_with_reuse_ = memory;
+            cost->communication_forward_ = communication_forward;
+            ret.push_back(cost);
+          }
+        }
+      }
+    }
+  }
+
+  Simplify(&ret);
+  return ret;
+}
+
+// Create final cost list for the graph containing a single node: u
+CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
+  MS_EXCEPTION_IF_NULL(u);
+  CostPtrList ret;
+  for (const auto &u_strategy : u->GetStrategyCost()) {
+    MS_EXCEPTION_IF_NULL(u_strategy);
+    auto u_strategy_ptr = u_strategy->strategy_ptr;
+    CostPtrList clist1 = u_strategy->cost_list;
+    for (const auto &cost1 : clist1) {
+      MS_EXCEPTION_IF_NULL(cost1);
+      auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
+      auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
+      MS_EXCEPTION_IF_NULL(new_cost);
+      new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
+      new_cost->communication_with_partial_para_ =
+        cost1->communication_without_parameter_ +
+        COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
+      new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
+      new_cost->communication_forward_ = cost1->communication_forward_;
+      ret.push_back(new_cost);
+    }
+  }
+
+  Simplify(&ret);
+  return ret;
+}
+
+CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
+  // Select the cost with minimum inference time. Currently, the inference time is modeled as =
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
+  if (cost_list.empty()) {
+    MS_LOG(ERROR) << "Final cost list is empty.";
+    return nullptr;
+  }
+  CostPtrList after_mem_filter;
+  double minimum_memory = DBL_MAX;
+  // Filter out the invalid costs, keeping only those that fit into the available memory.
+  for (auto &a_cost : cost_list) {
+    if (a_cost->memory_with_reuse_ <= memory) {
+      after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
+    }
+  }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
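// Worked example of the selection below (illustrative numbers only, not taken from a real profile):
// with the header defaults alpha = 1.0 and beta = 400.0, a memory-feasible candidate with
// computation_cost_ = 10.0 and communication_forward_ = 0.05 scores 10.0 + 400.0 * 0.05 = 30.0,
// so it beats a communication-free candidate whose computation_cost_ alone is 35.0.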
+  CostPtr ret = after_mem_filter[0];
+
+  double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_;
+  MS_LOG(INFO) << "Cost 0: "
+               << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
+               << ", communication_forward_: " << ret->communication_forward_
+               << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
+               << ", communication_cost_: " << ret->communication_cost_
+               << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
+  MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
+  for (size_t i = 1; i < after_mem_filter.size(); ++i) {
+    MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
+    MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
+                 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
+                 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
+                 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
+                 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
+                 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
+                 << ".";
+    auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ +
+               costmodel_beta_ * after_mem_filter[i]->communication_forward_;
+    MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
+    if (minimum > tmp) {
+      minimum = tmp;
+      ret = after_mem_filter[i];
+      MS_LOG(INFO) << "Selected: " << i;
+    }
+  }
+  return ret;
+}
+
+CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
+  // Select the cost with minimum training time. Currently, the training time is modeled as =
+  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
+  if (cost_list.empty()) {
+    MS_LOG(ERROR) << "Final cost list is empty.";
+    return nullptr;
+  }
+  CostPtrList after_mem_filter;
+  double minimum_memory = DBL_MAX;
+  // Filter out the invalid costs, keeping only those that fit into the available memory.
+  for (auto &a_cost : cost_list) {
+    if (a_cost->memory_with_reuse_ <= memory) {
+      after_mem_filter.emplace_back(std::move(a_cost));
+    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
+      minimum_memory = a_cost->memory_with_reuse_;
+    }
+  }
+  if (after_mem_filter.empty()) {
+    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
+                  << ", the memory capacity is: " << memory << ".";
+    return nullptr;
+  }
+  // Init the returned value with first cost.
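// The scan below mirrors SelectCostWithMinInferenceTime above, except that training time weighs
// communication_with_partial_para_ rather than communication_forward_, since parameter
// synchronization traffic matters during training. Illustrative numbers only: with alpha = 1.0
// and beta = 400.0, computation_cost_ = 10.0 and communication_with_partial_para_ = 0.1 give a
// score of 10.0 + 400.0 * 0.1 = 50.0.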
+ CostPtr ret = after_mem_filter[0]; + + double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; + MS_LOG(INFO) << "Cost 0: " + << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ + << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ + << ", communication_cost_: " << ret->communication_cost_ + << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; + MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; + for (size_t i = 1; i < after_mem_filter.size(); ++i) { + MS_EXCEPTION_IF_NULL(after_mem_filter[i]); + MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ + << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ + << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ + << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ + << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ + << "."; + auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + + costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; + MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; + if (minimum > tmp) { + minimum = tmp; + ret = after_mem_filter[i]; + MS_LOG(INFO) << "Selected: " << i; + } + } + return ret; +} + +CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, + double available_memory) { + CostPtrList selected_cost_list(all_cost_list.size(), nullptr); + double minimum = DBL_MAX, total_memory = 0.0; + CostPtrList ret(all_cost_list.size(), nullptr); + // Check whether valid costs exist. 
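// The body below first checks that every component has at least one candidate cost whose minimal
// memory fits, then enumerates the Cartesian product of the per-component cost lists with a
// recursive lambda: level k fixes a candidate for component k, and the base case scores the full
// combination against 'available_memory' and the weighted training time. A minimal standalone
// sketch of that enumeration pattern (simplified signature, for illustration only):
//
//   void Enumerate(size_t k, const std::vector<CostPtrList> &lists, CostPtrList *chosen) {
//     if (k == lists.size()) { /* evaluate the combination in *chosen here */ return; }
//     for (const auto &candidate : lists[k]) {
//       (*chosen)[k] = candidate;       // fix component k's pick
//       Enumerate(k + 1, lists, chosen);  // recurse into component k + 1
//     }
//   }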
+ for (size_t i = 0; i < all_cost_list.size(); ++i) { + if (all_cost_list[i][0] == nullptr) { + MS_LOG(ERROR) << "The cost list " << i << " is empty."; + return ret; + } else { + double memory_i_cost = DBL_MAX; + for (size_t j = 0; j < all_cost_list[i].size(); ++j) { + if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) { + memory_i_cost = all_cost_list[i][j]->memory_with_reuse_; + } + } + total_memory += memory_i_cost; + } + } + if (total_memory >= available_memory) { + MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory + << ", minimum strategy cost: " << total_memory << "."; + return selected_cost_list; + } + + std::function recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, + &available_memory, this](size_t k) { + if (k == all_cost_list.size()) { + double tmp_memory = 0.0, tmp_minimum = 0.0; + for (size_t i = 0; i < selected_cost_list.size(); ++i) { + MS_EXCEPTION_IF_NULL(selected_cost_list[i]); + tmp_memory += selected_cost_list[i]->memory_with_reuse_; + tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + + costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; + } + MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum + << "."; + if (tmp_memory < available_memory && tmp_minimum < minimum) { + ret = selected_cost_list; + minimum = tmp_minimum; + MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << "."; + } + return; + } + + MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; + for (auto &c : all_cost_list[k]) { + selected_cost_list[k] = c; + recursive(k + 1); + } + }; + recursive(0); + return ret; +} + +Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { + MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; + auto connected_components = ConstructConnectedComponents(alive_ops); + MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; + std::vector all_list; + for (size_t j = 0; j < connected_components.size(); ++j) { + auto one_component = connected_components[j]; + MS_EXCEPTION_IF_NULL(one_component); + if (one_component->GetOperators().size() == 1) { + MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; + auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); + all_list.push_back(cost_list); + } else if (one_component->GetOperators().size() == 2) { + MS_LOG(INFO) << "There are 2 operators in a component in the final graph."; + OperatorInfoPtr u, v; + auto first_op = one_component->GetOperators()[0]; + auto second_op = one_component->GetOperators()[1]; + MS_EXCEPTION_IF_NULL(first_op); + MS_EXCEPTION_IF_NULL(second_op); + if (!first_op->GetAliveSuccEdges().empty() && + first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { + u = first_op; + v = second_op; + } else if (!second_op->GetAliveSuccEdges().empty() && + second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { + u = second_op; + v = first_op; + } else { + MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size() + << ", " << second_op->GetAliveSuccEdges().size() << "."; + } + MS_EXCEPTION_IF_NULL(u); + auto e = u->GetAliveSuccEdges()[0]; + auto cost_list = one_component->CreateFinalCostList(u, e, v); + 
all_list.push_back(cost_list);
+    } else {
+      MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
+                        << " operators in a component in the final graph.";
+    }
+  }
+  //
+  auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+  for (size_t k = 0; k < selected_cost_list.size(); ++k) {
+    auto selected_cost = selected_cost_list[k];
+    if (selected_cost == nullptr) {
+      MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << dev_memory_ << ".";
+      return FAILED;
+    }
+    MS_EXCEPTION_IF_NULL(connected_components[k]);
+    if (connected_components[k]->GetOperators().size() == 1) {
+      auto u = connected_components[k]->GetOperators()[0];
+      auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
+      u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
+      MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
+    } else if (connected_components[k]->GetOperators().size() == 2) {
+      OperatorInfoPtr u = nullptr, v = nullptr;
+      auto first_op = connected_components[k]->GetOperators()[0];
+      auto second_op = connected_components[k]->GetOperators()[1];
+      MS_EXCEPTION_IF_NULL(first_op);
+      MS_EXCEPTION_IF_NULL(second_op);
+      if (!first_op->GetAliveSuccEdges().empty() &&
+          first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
+        u = first_op;
+        v = second_op;
+      } else if (!second_op->GetAliveSuccEdges().empty() &&
+                 second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
+        u = second_op;
+        v = first_op;
+      }
+      MS_EXCEPTION_IF_NULL(u);
+      auto e = u->GetAliveSuccEdges()[0];
+      MS_EXCEPTION_IF_NULL(v);
+      MS_EXCEPTION_IF_NULL(e);
+      MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
+      auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
+      MS_EXCEPTION_IF_NULL(decision);
+      u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
+      v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
+      e->set_selected_cost(decision->middle_cost_);
+      MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
+    }
+  }
+  return SUCCESS;
+}
+
+// Search the strategy for the final eliminated graph
+Status CostGraph::SearchStrategy() {
+  MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
+  std::vector<OperatorInfoPtr> alive_ops;
+  (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->is_alive()) {
+      alive_ops.push_back(op);
+    }
+  });
+
+  if (alive_ops.size() > 2) {
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      return SearchStrategyForMultiNodeFinalGraph(alive_ops);
+    } else {
+      // inference phase
+      MS_LOG(EXCEPTION)
+        << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
+    }
+  } else if (alive_ops.size() == 1) {
+    MS_LOG(INFO) << "There is a single node in the final graph.";
+    OperatorInfoPtr u = alive_ops[0];
+    auto cost_list = CreateFinalSingleCostList(u);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      // inference phase
+      cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_);
+    }
+    if (cost == nullptr) {
+      MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << dev_memory_ << ".";
+      return FAILED;
+    }
+    MS_EXCEPTION_IF_NULL(u);
+    MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
+    auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
+    MS_EXCEPTION_IF_NULL(decision);
+    u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
+    MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
+    return SUCCESS;
+  } else {
+    // In this case, the final graph should contain exactly 2 nodes.
+    if (alive_ops.empty()) {
+      MS_LOG(INFO) << "0 Operator in the final graph.";
+      return SUCCESS;
+    }
+    OperatorInfoPtr u, v;
+    MS_EXCEPTION_IF_NULL(alive_ops[0]);
+    MS_EXCEPTION_IF_NULL(alive_ops[1]);
+    if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
+        alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
+      u = alive_ops[0];
+      v = alive_ops[1];
+    } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
+               alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
+      u = alive_ops[1];
+      v = alive_ops[0];
+    } else {
+      if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
+        MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
+                          << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
+      } else {
+        // In this case, the final graph consists of two single nodes
+        MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
+        std::vector<CostPtrList> all_list;
+        auto connected_components = ConstructConnectedComponents(alive_ops);
+        MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
+        for (size_t i = 0; i < connected_components.size(); ++i) {
+          MS_LOG(INFO) << "There is 1 operator in a component in the final graph.";
+          auto one_component = connected_components[i];
+          MS_EXCEPTION_IF_NULL(one_component);
+          auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
+          all_list.push_back(cost_list);
+        }
+        CostPtrList selected_cost_list;
+        if (RUN_PHASE == TRAINING_PHASE) {
+          // training phase
+          selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
+        } else {
+          // inference phase
+          MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
+                               "phase is not supported.";
+        }
+        for (size_t k = 0; k < selected_cost_list.size(); ++k) {
+          auto selected_cost = selected_cost_list[k];
+          if (selected_cost == nullptr) {
+            MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << dev_memory_ << ".";
+            return FAILED;
+          }
+          MS_EXCEPTION_IF_NULL(connected_components[k]);
+          auto one_operator = connected_components[k]->GetOperators()[0];
+          MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
+          auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
+          MS_EXCEPTION_IF_NULL(decision);
+          one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
+          MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
+        }
+
+        return SUCCESS;
+      }
+    }
+    MS_LOG(INFO) << "There are 2 nodes in the final graph.";
+    // In this case, the final graph is exactly of the form: u --> v
+    MS_EXCEPTION_IF_NULL(u);
+    MS_EXCEPTION_IF_NULL(v);
+    auto e = u->GetAliveSuccEdges()[0];
+    MS_EXCEPTION_IF_NULL(e);
+    auto cost_list = CreateFinalCostList(u, e, v);
+    CostPtr cost = nullptr;
+    if (RUN_PHASE == TRAINING_PHASE) {
+      // training phase
+      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
+    } else {
+      MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
+                           "phase is not supported.";
+    }
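// Recap of the shapes SearchStrategy() handles: (1) more than two alive nodes go through
// SearchStrategyForMultiNodeFinalGraph (training phase only); (2) a single alive node picks from
// CreateFinalSingleCostList; (3) two nodes joined by an edge form the u --> v case that finishes
// below; (4) two isolated nodes are treated as two single-node components. Any other shape
// raises an exception above.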
+    if (cost == nullptr) {
+      MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << dev_memory_ << ".";
+      return FAILED;
+    }
+    MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
+    auto decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
+    MS_EXCEPTION_IF_NULL(decision);
+    u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
+    v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
+    e->set_selected_cost(decision->middle_cost_);
+    MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
+    return SUCCESS;
+  }
+}
+
+// Given a graph which contains the subgraph u --> v --> w, the node v can be eliminated;
+// this returns v and the edge u --> v.
+OperatorInfoPtr CostGraph::CheckOpElimination() const {
+  for (auto &op : ops_) {
+    bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
+    if (bool_test) {
+      if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
+        return op;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Check whether an EdgeElimination can be performed on the graph
+std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (!op->is_alive()) continue;
+    std::map<void *, int> count;
+    for (auto &edge : op->GetAliveSuccEdges()) {
+      MS_EXCEPTION_IF_NULL(edge);
+      auto v = edge->next_operator();
+      count[v.get()]++;
+    }
+    for (auto &pair : count) {
+      auto *op_ptr = pair.first;
+      int op_count = pair.second;
+      if (op_count > 1) {
+        std::vector<std::shared_ptr<Edge>> ret;
+        for (auto &edge : op->GetAliveSuccEdges()) {
+          MS_EXCEPTION_IF_NULL(edge);
+          if (edge->next_operator().get() == op_ptr) {
+            ret.push_back(edge);
+          }
+        }
+        return ret;
+      }
+    }
+  }
+  return {};
+}
+
+// Check whether a MergeElimination can be performed on the graph
+OperatorInfoPtr CostGraph::CheckMergeElimination() const {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
+    if (bool_test) {
+      auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
+      MS_EXCEPTION_IF_NULL(next_op);
+      if (!next_op->GetAlivePrevEdges().empty()) {
+        return op;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Check whether a ContractElimination can be performed on the graph
+OperatorInfoPtr CostGraph::CheckContractElimination() const {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
+    if (bool_test) {
+      auto edge = op->GetAlivePrevEdges()[0];
+      MS_EXCEPTION_IF_NULL(edge);
+      auto prev_op = edge->prev_operator();
+      MS_EXCEPTION_IF_NULL(prev_op);
+      if (!prev_op->GetAliveSuccEdges().empty()) {
+        return op;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Check whether a TriangleElimination can be performed on the graph
+std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
+    if (bool_test) {
+      auto edge1 = op->GetAliveSuccEdges()[0];
+      auto edge2 = op->GetAliveSuccEdges()[1];
+      MS_EXCEPTION_IF_NULL(edge1);
+      MS_EXCEPTION_IF_NULL(edge2);
+      auto first_op = edge1->next_operator();
+      auto second_op = edge2->next_operator();
+      MS_EXCEPTION_IF_NULL(first_op);
+      for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
+        if (first_op_succ_edge->next_operator() == second_op) {
+          return {op, first_op_succ_edge};
+        }
+      }
+      MS_EXCEPTION_IF_NULL(second_op);
+      for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
+        if (second_op_succ_edge->next_operator() == first_op) {
+          return {op, second_op_succ_edge};
+        }
+      }
+    }
+  }
+  return {nullptr, nullptr};
+}
+
+// Check whether a StarElimination can be performed on the graph.
+// NOTE: this elimination MUST be performed only when the above 5 operations cannot be applied.
+OperatorInfoPtr CostGraph::CheckStarElimination() const {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
+    if (bool_test) {
+      return op;
+    }
+  }
+  return nullptr;
+}
+
+// This method is for the 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
+// 'left_edge', 'op' and 'right_edge'. As a consequence, it creates a new costlist for the new edge.
+std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
+  // In this case, the operators are organised in the form of u-->op-->v, and the goal
+  // is to eliminate 'op'.
+  MS_EXCEPTION_IF_NULL(op);
+  MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
+  auto edge_u_op = op->GetAlivePrevEdges()[0];
+  auto edge_op_v = op->GetAliveSuccEdges()[0];
+  MS_EXCEPTION_IF_NULL(edge_u_op);
+  MS_EXCEPTION_IF_NULL(edge_op_v);
+  auto u = edge_u_op->prev_operator();
+  auto v = edge_op_v->next_operator();
+  std::vector<size_t> output_indexs, input_indexs;
+  size_t output_index, input_index;
+  MS_EXCEPTION_IF_NULL(u);
+  MS_EXCEPTION_IF_NULL(v);
+  std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
+  std::shared_ptr<Edge> new_edge;
+  if (edge_u_op->is_combined()) {
+    output_indexs = edge_u_op->prev_op_output_indexs();
+  } else {
+    output_index = edge_u_op->prev_op_output_index();
+    output_indexs.push_back(output_index);
+  }
+  if (edge_op_v->is_combined()) {
+    input_indexs = edge_op_v->next_op_input_indexs();
+  } else {
+    input_index = edge_op_v->next_op_input_index();
+    input_indexs.push_back(input_index);
+  }
+
+  if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
+    new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
+  } else {
+    new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
+  }
+  MS_EXCEPTION_IF_NULL(new_edge);
+  new_edge->set_pre_op_output(edge_u_op->prev_op_output());
+  new_edge->set_next_op_input(edge_op_v->next_op_input());
+  new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
+  u->ReplaceSuccEdge(op, new_edge);
+  v->ReplacePreEdge(op, new_edge);
+  op->SetNotAlive();
+  MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
+  return new_edge;
+}
+
+// This method is for the 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace
+// the 'edges', and sets a new costlist for the new edge.
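// For intuition, a small example (index values are illustrative only): if u and v are joined by
// edge e1 wiring u's output 0 to v's input 0 and edge e2 wiring output 1 to input 1, the combined
// edge built below carries output_indexs = {0, 1} and input_indexs = {0, 1}, and
// EdgeEliminationSetNewCost merges the per-edge cost lists into a single cost list for it.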
+std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { + MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; + MS_EXCEPTION_IF_NULL(edges[0]); + auto u = edges[0]->prev_operator(); + auto v = edges[0]->next_operator(); + MS_EXCEPTION_IF_NULL(u); + MS_EXCEPTION_IF_NULL(v); + std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); + std::vector output_indexs, input_indexs; + + for (auto &edge : edges) { + MS_EXCEPTION_IF_NULL(edge); + if (edge->is_combined()) { + auto from_output_indexs = edge->prev_op_output_indexs(); + auto from_input_indexs = edge->next_op_input_indexs(); + (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs)); + (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs)); + } else { + output_indexs.push_back(edge->prev_op_output_index()); + input_indexs.push_back(edge->next_op_input_index()); + } + } + + std::shared_ptr new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); + MS_EXCEPTION_IF_NULL(new_edge); + new_edge->set_pre_op_output(edges[0]->prev_op_output()); + new_edge->set_next_op_input(edges[0]->next_op_input()); + + new_edge->EdgeEliminationSetNewCost(u, edges, v); + + u->ReplaceSuccEdges(v, new_edge); + v->ReplacePreEdges(u, new_edge); + MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded."; + return new_edge; +} + +// Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' +// for this contract under the strategy 'op_strategy' +void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, + CostPtrList *const tar_cost_list_new) { + for (size_t i = 0; i < op_cost_list.size(); ++i) { + auto &op_cost = op_cost_list[i]; + MS_EXCEPTION_IF_NULL(op_cost); + for (size_t j = 0; j < edge_cost_list.size(); ++j) { + auto &edge_cost = edge_cost_list[j]; + MS_EXCEPTION_IF_NULL(edge_cost); + for (size_t k = 0; k < tar_cost_list.size(); ++k) { + auto &tar_cost = tar_cost_list[k]; + MS_EXCEPTION_IF_NULL(tar_cost); + double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; + double communication = + op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = + op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; + double communication_without_para = op_cost->communication_without_parameter_ + + edge_cost->communication_without_parameter_ + + tar_cost->communication_without_parameter_; + + auto decision = + std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); + auto new_cost = std::make_shared(computation, communication, decision); + MS_EXCEPTION_IF_NULL(new_cost); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ = memory; + new_cost->communication_forward_ = communication_forward; + MS_EXCEPTION_IF_NULL(tar_cost_list_new); + tar_cost_list_new->emplace_back(std::move(new_cost)); + } + } + } 
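// Note on the aggregation used by this and the other elimination combiners: the component costs
// simply add up, and the partial-parameter figure is the interpolation
//   communication_with_partial_para_ =
//       communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para),
// so COST_MODEL_GAMMA in [0, 1] scales how much parameter-related traffic is counted
// (0 = none of it, 1 = all of it).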
+} + +// This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the +// target_op +OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { + MS_EXCEPTION_IF_NULL(op); + auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); + auto edge_ptr = op->GetAliveSuccEdges()[0]; + MS_EXCEPTION_IF_NULL(target_op); + MS_EXCEPTION_IF_NULL(edge_ptr); + MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; + bool valid = false; + + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(tar_stra_cost); + auto tar_stra = tar_stra_cost->strategy_ptr; + auto tar_clist_origin = tar_stra_cost->cost_list; + CostPtrList tar_clist_new; + + for (auto &op_stra_cost : op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(op_stra_cost); + auto op_stra = op_stra_cost->strategy_ptr; + auto op_clist = op_stra_cost->cost_list; + auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra); + + CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); + } + Simplify(&tar_clist_new); + // Set the new costlist w.r.t the strategy + tar_stra_cost->cost_list = tar_clist_new; + if ((!valid) && (!tar_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed."; + } + op->SetNotAlive(); + MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded."; + return target_op; +} + +// Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' +// for this contract under the strategy 'contract_op_stra' +void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, + const CostPtrList &contract_op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { + for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { + auto &contract_op_cost = contract_op_cost_list[i]; + MS_EXCEPTION_IF_NULL(contract_op_cost); + for (size_t j = 0; j < edge_cost_list.size(); ++j) { + auto &edge_cost = edge_cost_list[j]; + MS_EXCEPTION_IF_NULL(edge_cost); + for (size_t k = 0; k < tar_cost_list.size(); ++k) { + auto &tar_cost = tar_cost_list[k]; + MS_EXCEPTION_IF_NULL(tar_cost); + double computation = + contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; + double memory = + contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; + double communication = + contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; + double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + + tar_cost->communication_forward_; + double communication_without_para = contract_op_cost->communication_without_parameter_ + + edge_cost->communication_without_parameter_ + + tar_cost->communication_without_parameter_; + + auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, + target_op_stra, tar_cost); + auto new_cost = std::make_shared(computation, communication, decision); + new_cost->communication_without_parameter_ = communication_without_para; + new_cost->communication_with_partial_para_ = + communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); + new_cost->memory_with_reuse_ 
= memory; + new_cost->communication_forward_ = communication_forward; + tar_cost_list_new->emplace_back(std::move(new_cost)); + } + } + } +} + +// This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the +// target_op +OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { + MS_EXCEPTION_IF_NULL(op); + auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); + auto edge_ptr = op->GetAlivePrevEdges()[0]; + MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; + bool valid = false; + + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(tar_stra_cost); + auto tar_stra = tar_stra_cost->strategy_ptr; + auto tar_clist_origin = tar_stra_cost->cost_list; + CostPtrList tar_clist_new; + + for (auto &op_stra_cost : op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(op_stra_cost); + auto op_stra = op_stra_cost->strategy_ptr; + auto op_clist = op_stra_cost->cost_list; + auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra); + + CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); + } + Simplify(&tar_clist_new); + // Set the new costlist w.r.t the strategy + tar_stra_cost->cost_list = tar_clist_new; + if ((!valid) && (!tar_clist_new.empty())) { + valid = true; + } + } + if (!valid) { + MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed."; + } + op->SetNotAlive(); + MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded."; + return target_op; +} + +void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, + StrategyPtr right_op_stra, const CostPtr &right_op_cost, + const CostPtrList &elimi_op_clist, + const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { + MS_EXCEPTION_IF_NULL(right_edge_cost); + MS_EXCEPTION_IF_NULL(right_op_cost); + MS_EXCEPTION_IF_NULL(left_node_clist_new); + for (auto &elimi_op_cost : elimi_op_clist) { + MS_EXCEPTION_IF_NULL(elimi_op_cost); + for (auto &left_edge_cost : left_edge_clist) { + MS_EXCEPTION_IF_NULL(left_edge_cost); + for (auto &left_node_cost : left_node_clist_origin) { + MS_EXCEPTION_IF_NULL(left_node_cost); + double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; + double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ + + left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; + double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + + left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; + double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + + left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; + double new_commu_without = + elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + + left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; + + auto decision = std::make_shared( + elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); + auto new_cost = std::make_shared(new_computation, 
new_commu_cost, decision); + new_cost->communication_without_parameter_ = new_commu_without; + new_cost->communication_with_partial_para_ = + new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); + new_cost->memory_with_reuse_ = new_memory; + new_cost->communication_forward_ = new_commu_forward; + left_node_clist_new->emplace_back(std::move(new_cost)); + } + } + } +} + +void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, + const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, + const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, + const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { + MS_EXCEPTION_IF_NULL(elimi_op); + for (auto &right_node_cost : right_node_clist) { + MS_EXCEPTION_IF_NULL(right_node_cost); + for (auto &right_edge_cost : right_edge_clist) { + MS_EXCEPTION_IF_NULL(right_edge_cost); + CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, + elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, + left_node_clist_new); + } + } +} + +OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, + const std::shared_ptr &edge_left_right) { + MS_EXCEPTION_IF_NULL(edge_left_right); + MS_EXCEPTION_IF_NULL(elimi_op); + MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; + auto left_node = edge_left_right->prev_operator(); + auto right_node = edge_left_right->next_operator(); + auto left_edge = elimi_op->GetAliveSuccEdges()[0]; + auto right_edge = elimi_op->GetAliveSuccEdges()[1]; + MS_EXCEPTION_IF_NULL(left_node); + MS_EXCEPTION_IF_NULL(right_node); + MS_EXCEPTION_IF_NULL(left_edge); + MS_EXCEPTION_IF_NULL(right_edge); + MS_LOG(INFO) << "The left operator is: " << left_node->name() << "."; + MS_LOG(INFO) << "The right operator is: " << right_node->name() << "."; + + if (left_edge->next_operator() != left_node) { + auto tmp = left_edge; + left_edge = right_edge; + right_edge = tmp; + } + bool valid = false; + + for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(left_node_stra_cost); + auto left_node_stra = left_node_stra_cost->strategy_ptr; + auto left_node_clist_origin = left_node_stra_cost->cost_list; + CostPtrList left_node_clist_new; + + for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); + auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; + auto elimi_op_clist = elimi_op_stra_cost->cost_list; + auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); + + for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(right_node_stra_cost); + auto right_node_stra = right_node_stra_cost->strategy_ptr; + auto right_node_clist = right_node_stra_cost->cost_list; + auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra); + + CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra, + right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin, + &left_node_clist_new); + } + } + Simplify(&left_node_clist_new); + // Set the new costlist w.r.t the strategy + left_node_stra_cost->cost_list = left_node_clist_new; + if ((!valid) && (!left_node_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << 
"Eliminating triangle: " << elimi_op->name() << " failed."; + } + elimi_op->SetNotAlive(); + MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; + return left_node; +} + +void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + std::vector succ_nodes_stras, + CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, + CostPtrList *first_succ_node_clist_new) { + for (auto &first_succ_node_cost : first_succ_node_clist) { + for (auto &first_succ_edge_cost : first_succ_edge_clist) { + for (auto &merged_node_cost : merged_op_clist) { + MS_EXCEPTION_IF_NULL(merged_node_cost); + succ_nodes_stras[0] = first_succ_node_stra; + succ_edges_costs[0] = first_succ_edge_cost; + succ_nodes_costs[0] = first_succ_node_cost; + + double computation_cost = merged_node_cost->computation_cost_, + memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, + commu_without = merged_node_cost->communication_without_parameter_, + commu_forward = merged_node_cost->communication_forward_; + for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { + MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); + if (i == 0) { + computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; + memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; + commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; + commu_without += succ_edges_costs[i]->communication_without_parameter_ + + succ_nodes_costs[i]->communication_without_parameter_; + } else { + computation_cost += succ_edges_costs[i]->computation_cost_; + memory_cost += succ_edges_costs[i]->memory_with_reuse_; + commu_cost += succ_edges_costs[i]->communication_cost_; + commu_forward += succ_edges_costs[i]->communication_forward_; + commu_without += succ_edges_costs[i]->communication_without_parameter_; + } + } + + auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, + succ_nodes_stras, succ_nodes_costs); + auto new_cost = std::make_shared(computation_cost, commu_cost, decision); + new_cost->communication_without_parameter_ = commu_without; + new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); + new_cost->memory_with_reuse_ = memory_cost; + new_cost->communication_forward_ = commu_forward; + first_succ_node_clist_new->emplace_back(std::move(new_cost)); + } + } + } +} + +void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, + const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + CostPtrList *first_succ_node_clist_new) { + std::vector succ_nodes_stras(succ_edges.size(), nullptr); + CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); + std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, + &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs, + &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive, + this](size_t k) 
{ + if (k == succ_edges.size()) { + CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, + merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs, + succ_nodes_costs, first_succ_node_clist_new); + return; + } + MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size() + << ", first_succ_edge_clist: " << first_succ_edge_clist.size() + << ", merged_op_clist: " << merged_op_clist.size() + << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << "."; + auto succ_edge = succ_edges[k]; + MS_EXCEPTION_IF_NULL(succ_edge); + auto succ_node = succ_edge->next_operator(); + MS_EXCEPTION_IF_NULL(succ_node); + for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(succ_node_stra_cost); + auto succ_node_stra = succ_node_stra_cost->strategy_ptr; + auto succ_node_clist = succ_node_stra_cost->cost_list; + auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); + + for (auto &succ_node_cost : succ_node_clist) { + MS_EXCEPTION_IF_NULL(succ_node_cost); + for (auto &succ_edge_cost : succ_edge_clist) { + MS_EXCEPTION_IF_NULL(succ_edge_cost); + succ_nodes_stras[k] = succ_node_stra; + succ_edges_costs[k] = succ_edge_cost; + succ_nodes_costs[k] = succ_node_cost; + recursive(k + 1); + } + } + } + }; + + recursive(1); +} + +std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { + MS_EXCEPTION_IF_NULL(merged_op); + auto succ_edges = merged_op->GetAliveSuccEdges(); + MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; + for (auto &succ_edge : succ_edges) { + MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); + MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; + } + + MS_EXCEPTION_IF_NULL(succ_edges[0]); + auto first_succ_node = succ_edges[0]->next_operator(); + auto first_succ_edge = succ_edges[0]; + bool valid = false; + + // 'merged_op' is merged into first_node + MS_EXCEPTION_IF_NULL(first_succ_node); + for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); + auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; + auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; + CostPtrList first_succ_node_clist_new; + + for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { + MS_EXCEPTION_IF_NULL(merged_op_stra_cost); + auto merged_op_stra = merged_op_stra_cost->strategy_ptr; + auto merged_op_clist = merged_op_stra_cost->cost_list; + auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra); + + CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, + merged_op_stra, merged_op_clist, &first_succ_node_clist_new); + } + Simplify(&first_succ_node_clist_new); + // Set the new costlist w.r.t the strategy + first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; + if ((!valid) && (!first_succ_node_clist_new.empty())) { + valid = true; + } + } + + if (!valid) { + MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; + } + + merged_op->SetNotAlive(); + MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded."; + return succ_edges; +} + +size_t CostGraph::GetNumEdges() const { + size_t sum = 0; + for (const auto &kv : edges_) { + auto &edges = kv.second; + sum += edges.size(); + } + return sum; +} +Status 
CostGraph::InitSelectedStrategy() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    if (op->name().find(RESHAPEINFO) != std::string::npos) {
+      continue;
+    }
+    auto result = op->InitSelectedStrategy(op->selected_strategy());
+    if (result != SUCCESS) {
+      return result;
+    }
+  }
+  // Reshape init should be applied after the init of its previous node and next node.
+  for (size_t i = 0; i < ops_.size(); ++i) {
+    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
+      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
+      auto in_edges = GetOriginalPrevEdges(ops_[i]);
+      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
+        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
+      });
+      auto out_edges = GetOriginalNextEdges(ops_[i]);
+      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
+        return edge->next_operator()->name() == reshape_info->next_operator_name();
+      });
+      if (pre_iter != in_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
+        int32_t pre_index = reshape_info->pre_operator_index();
+        TensorInfo pre_info;
+        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
+          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
+        } else {
+          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
+        }
+        reshape_info->SetInputLayout(pre_info.tensor_layout());
+        Dimensions stra = pre_info.InferStrategy();
+        if (stra.empty()) {
+          MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
+        }
+        std::vector<Dimensions> stra_inputs = {stra};
+        StrategyPtr reshape_stra =
+          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
+        reshape_info->set_strategy(reshape_stra);
+      }
+      if (next_iter != out_edges.end()) {
+        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
+        int32_t next_index = reshape_info->next_operator_index();
+        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
+      }
+      if (reshape_info->Init(nullptr) != SUCCESS) {
+        return FAILED;
+      }
+    }
+  }
+  return SUCCESS;
+}
+
+Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
+  for (auto &op : ops_) {
+    MS_EXCEPTION_IF_NULL(op);
+    const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
+    if ((output_parameter != 0) && (output_parameter != 1)) {
+      MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+
+void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
+                                std::vector<OperatorInfoPtr> *topo_order) {
+  MS_EXCEPTION_IF_NULL(current_op);
+  MS_EXCEPTION_IF_NULL(visited);
+  MS_EXCEPTION_IF_NULL(topo_order);
+
+  visited->at(current_op) = true;
+  for (const auto &s_edge : current_op->succ_edges()) {
+    if (!visited->at(s_edge->next_operator())) {
+      DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
+    }
+  }
+  topo_order->push_back(current_op);
+}
+
+// Compute a topological order of the costgraph
+void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
+  std::map<OperatorInfoPtr, bool> visited;
+  for (auto &op : ops_) {
+    visited[op] = false;
+  }
+
+  for (auto &op : ops_) {
+    if (!visited[op]) {
+      DFSForTopoOrder(op, &visited, topo_order);
+    }
+  }
+}
+void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
+  for (auto &op : ops_) {
+    auto search = candidate_ops.find(op);
+    if (search != candidate_ops.end()) {
+      // Mark the critical operators
+      op->mark_output_critical();
+      // Mark the successive edges
+      for (auto &s_edge : op->succ_edges()) {
+        s_edge->mark_output_critical();
+      }
+    } else {
+      op->mark_output_not_critical();
+    }
+  }
+}
+
+Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
+  if (topo_order.size() == 0) {
+    MS_LOG(ERROR) << "0 operator in costgraph.";
+    return FAILED;
+  }
+  auto &first_op = topo_order[0];
+  if (first_op->prev_edges().size() > 0) {
+    MS_LOG(ERROR) << "The first operator in the topological order of the "
+                     "costgraph should have 0 incoming edges, but has "
+                  << first_op->prev_edges().size() << " edges.";
+    return FAILED;
+  }
+  // The 'curr_memory_state' records (OperatorInfoPtr, remaining_output_cnt) pairs, where remaining_output_cnt is
+  // the number of outputs of the OperatorInfo that currently have not been used
+  std::map<OperatorInfoPtr, int> curr_memory_state;
+  (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
+  std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
+  // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that have
+  // not been used
+  double curr_memory_size = first_op->GetOutputsTotalSize();
+  double max_memory_size = curr_memory_size;
+
+  for (size_t finished = 1; finished < topo_order.size(); ++finished) {
+    // Produce
+    (void)curr_memory_state.emplace(
+      std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
+    curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
+    // Consume
+    for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
+      const auto &prev_op = prev_edge->prev_operator();
+      curr_memory_state[prev_op]--;
+    }
+    for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
+      const auto &prev_op = prev_edge->prev_operator();
+      if (curr_memory_state[prev_op] < 0) {
+        MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
+        return FAILED;
+      } else if (curr_memory_state[prev_op] == 0) {
+        curr_memory_state.erase(prev_op);
+        curr_memory_size -= prev_op->GetOutputsTotalSize();
+      }
+    }
+
+    if (curr_memory_size < 0) {
+      MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
+    }
+    // Modify the max
+    if (curr_memory_size > max_memory_size) {
+      max_memory_size = curr_memory_size;
+      max_memory_state = curr_memory_state;
+    }
+  }
+  // Mark those critical operators
+  MarkCriticalOpsAndEdges(max_memory_state);
+  return SUCCESS;
+}
+
+Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
+  // Two steps to do:
+  // 1. Compute a topological order of the costgraph
+  // 2.
Determine and mark the operators (and necessary edges) that are critical + std::vector topo_order; + TopologyOrder(&topo_order); + std::reverse(std::begin(topo_order), std::end(topo_order)); + + if (DetermineCriticalOps(topo_order) != SUCCESS) { + MS_LOG(ERROR) << "Determining critical operators failed."; + return FAILED; + } + return SUCCESS; +} + +Status CostGraph::CalculateOpsMemoryCost() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; + return FAILED; + } + } + return SUCCESS; +} + +Status CostGraph::CalculateOpsMemoryCostForInference() { + for (auto &op : ops_) { + MS_EXCEPTION_IF_NULL(op); + if (op->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; + return FAILED; + } + } + return SUCCESS; +} + +Status CostGraph::CalculateEdgesMemoryCost() { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { + if (one_edge->CalculateMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; + return FAILED; + } + } + } + return SUCCESS; +} + +Status CostGraph::CalculateEdgesMemoryCostForInference() { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { + if (one_edge->CalculateMemoryCostForInference() != SUCCESS) { + MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; + return FAILED; + } + } + } + return SUCCESS; +} + +OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { + for (auto one_op : ops_) { + if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { + if (one_op->refkey_parameter_name() == p_name) { + return one_op; + } + } + } + return nullptr; +} +Status CostGraph::CorrectOpsMemoryCost() { + for (auto &one_op : ops_) { + if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { + if (one_op->GetAliveSuccEdges().size() > 1) { + // Filter out the case when the TmpIdentity being used by multiple operators + std::map output_count; + for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { + auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); + output_count[output_index]++; + } + for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) { + auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index(); + if (output_count[output_index] <= 1) { + continue; + } + auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator(); + MS_EXCEPTION_IF_NULL(next_op); + auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index(); + if (next_op->CorrectMemoryCost(input_index) != SUCCESS) { + MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name() + << ", the output_index: " << output_index << ", the input_index: " << input_index << "."; + return FAILED; + } + output_count[output_index]--; + } + } + } + } + return SUCCESS; +} + +Status CostGraph::CalculateMemoryCost() { + if (RUN_PHASE == TRAINING_PHASE) { + // training phase + if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) { + // Calculate operators' memory usage + if (CalculateOpsMemoryCost() != SUCCESS) { + MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed."; + 
+Status CostGraph::CalculateMemoryCost() {
+  if (RUN_PHASE == TRAINING_PHASE) {
+    // training phase
+    if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
+      // Calculate operators' memory usage
+      if (CalculateOpsMemoryCost() != SUCCESS) {
+        MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
+        return FAILED;
+      }
+      // Calculate edges' memory usage
+      if (CalculateEdgesMemoryCost() != SUCCESS) {
+        MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
+        return FAILED;
+      }
+      // Correct memory usage caused by TmpIdentity
+      if (CorrectOpsMemoryCost() != SUCCESS) {
+        MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
+        return FAILED;
+      }
+    } else {
+      MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
+      return FAILED;
+    }
+  } else {
+    // inference phase
+    if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
+      // Calculate operators' memory usage
+      if (CalculateOpsMemoryCostForInference() != SUCCESS) {
+        MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
+        return FAILED;
+      }
+      // Calculate edges' memory usage
+      if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
+        MS_LOG(ERROR) << "Calculating edges' memory cost for inference failed.";
+        return FAILED;
+      }
+    } else {
+      MS_LOG(ERROR) << "Computing operators' critical flag failed.";
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
new file mode 100644
index 0000000000..87f13e3383
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
@@ -0,0 +1,238 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
+#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mindspore/ccsrc/common.h"
+#include "common/utils.h"
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+#include "frontend/parallel/costmodel_context.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/ops_info/tmp_identity_info.h"
+
+namespace mindspore {
+namespace parallel {
+#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
+#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
+#define DEFAULT_COST_MODEL_ALPHA 1.0
+#define DEFAULT_COST_MODEL_BETA 400.0
+#define DEFAULT_COST_MODEL_GAMMA 0.001
+#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
+#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
+#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0
+#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0
+#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false
+#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
+#define DEFAULT_FULLY_USE_DEVICES true
+#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
+#define DEFAULT_IS_MULTI_SUBGRAPHS false
+#define DEFAULT_RUN_PHASE 0
+#define TRAINING_PHASE 0
+#define INFERENCE_PHASE 1
+
+class CostGraph;
+using CostGraphPtr = std::shared_ptr<CostGraph>;
+extern CostGraphPtr entire_costgraph;
+extern size_t TOTAL_OPS;
+extern double COST_MODEL_GAMMA;
+extern bool COST_MODEL_SIMPLIFY_CALCULATION;
+extern double DEVICE_MEMORY_CAPACITY;
+extern double COST_MODEL_COMMUNI_THRESHOLD;
+extern double COST_MODEL_COMMUNI_CONST;
+extern double COST_MODEL_COMMUNI_BIAS;
+extern bool TENSOR_SLICE_ALIGNMENT_ENABLE;
+extern size_t TENSOR_SLICE_ALIGNMENT_SIZE;
+extern bool FULLY_USE_DEVICES;
+extern bool ELEMENTWISE_OP_STRA_FOLLOW;
+extern bool MULTI_SUBGRAPHS;
+extern int32_t RUN_PHASE;
+
+class CostGraph {
+  // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
+  // an output-input dependency relationship.
+ public:
+  CostGraph() {
+    dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
+    costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
+    costmodel_beta_ = DEFAULT_COST_MODEL_BETA;
+  }
+  ~CostGraph() = default;
+  void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
+  OperatorInfoPtr FindOperatorByIndex(size_t index) {
+    if (index >= ops_.size()) {
+      MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << ".";
+      return nullptr;
+    }
+    return ops_[index];
+  }
+  void RemoveOperator(const OperatorInfoPtr &op);
+  bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
+  // the edge is in the form: u --> v
+  void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
+  std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
+  std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
+  // An edge is uniquely identified by its name, and its output index and input index.
+  bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
+
+  void SetDeviceMemoryAndCostParameter();
+
+  std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
+  void DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
+           const std::shared_ptr<CostGraph> &component);
+
+  CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
+  CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
+  CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
+  CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
+  CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, double memory);
+  Status SearchStrategyForMultiNodeFinalGraph(const std::vector<std::shared_ptr<CostGraph>> &);
+  std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
+    return edges_[{u_node, v_node}];
+  }
+  double GetDeviceMemory() const { return dev_memory_; }
+
+  // Search the cost_list in the final graph, and determine the optimal one
+  Status SearchStrategy();
+
+  // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
+  OperatorInfoPtr CheckOpElimination() const;
+  // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges
+  // can be eliminated into one
+  std::vector<EdgePtr> CheckEdgeElimination() const;
+  // Given a graph which contains the following subgraph:
+  //       u
+  //       |
+  //  w --- v --- x
+  // where u has 0 incoming edges, u has 1 outgoing edge, and v has more than 1 incoming edge, u can be merged into v.
+  // u is returned.
+  OperatorInfoPtr CheckMergeElimination() const;
+  // Given a graph which contains the following subgraph:
+  //   u
+  //   |
+  //   v --- x
+  // where v has 2 outgoing edges, and u has 1 incoming edge and no outgoing edges. In this case, u can be contracted
+  // into v. u is returned.
+  OperatorInfoPtr CheckContractElimination() const;
+  /* Given a graph which contains the following subgraph:
+   *       u
+   *      / \
+   *     /   \
+   *    v --- w
+   * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v.
+   * The returned value includes u and the edge <v, w>.
+   */
+  std::pair<OperatorInfoPtr, EdgePtr> CheckTriangleElimination() const;
+  /* Given a graph which contains the following subgraph:
+   *  v <--- u ---> w
+   * where u has 0 incoming edges and multiple outgoing edges. In addition, v and w have other complicated
+   * connections, so ContractElimination cannot be applied to v and w. u is returned.
+   * NOTE: this elimination MUST be performed only when the above 5 operations cannot be applied.
+   */
+  OperatorInfoPtr CheckStarElimination() const;
+  // Applying Operator Elimination in DP algorithm
+  EdgePtr EliminationOp(const OperatorInfoPtr &op);
+  // Applying Edge Elimination in DP algorithm
+  EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges);
+  // Applying Merge Elimination in DP algorithm
+  OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op);
+  void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
+                                         const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
+                                         const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new);
+  // Applying Contract Elimination in DP algorithm
+  OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op);
+  void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr,
+                                            const CostPtrList &, CostPtrList *);
+
+  // Applying Triangle Elimination in DP algorithm. return the left_node
+  OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right);
+  void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &,
+                                         const StrategyPtr &, const StrategyPtr &, const StrategyPtr &,
+                                         const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *);
+  // Given the relevant costlist, create the TriangleElimination cost
+  void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &,
+                                            const CostPtrList &, const CostPtrList &, const CostPtr &,
+                                            const CostPtrList &, CostPtrList *);
+
+  // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
+  // NOTE: this elimination MUST be performed only when the above 5 operations cannot be applied.
+  std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op);
+  void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &,
+                                     const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *);
+  void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
+                                        const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
+                                        CostPtrList &, CostPtrList &, CostPtrList *);
+  // Calculate memory cost for the training phase or the inference phase.
+  Status CalculateMemoryCost();
+  // When the input of an operator is neither a WEIGHT, nor an output of a subsequent operator involving a WEIGHT,
+  // then the memory cost can be reused. This is used to calculate memory in the training phase.
+  Status CalculateOpsMemoryCost();
+  // When the input of an edge is neither a WEIGHT, nor an output of a subsequent operator involving a WEIGHT,
+  // then the memory cost can be reused. This is used to calculate memory in the training phase.
+  Status CalculateEdgesMemoryCost();
+  // Calculate memory cost of operators in the inference phase.
+  Status CalculateOpsMemoryCostForInference();
+  // Calculate memory cost of edges in the inference phase.
+  Status CalculateEdgesMemoryCostForInference();
+  Status ComputeOpsAndEdgesParameterInvolved();
+  // Compute for each operator whether the output is critical.
+  Status ComputeOpsAndEdgesOutputCritical();
+
+  std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
+  size_t GetNumEdges() const;
+  Status InitSelectedStrategy();
+  OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
+  // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
+  // only once (instead of multiple times); this method is used to correct this.
+  Status CorrectOpsMemoryCost();
+  // Needed by rec_parser
+  void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
+    inputs_tensor_name_list_.push_back(inputs_tensor_name);
+  }
+  const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
+  void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
+    auto ret = tuple_getitem_list_.insert(tuple_getitem);
+    if (ret.second == false) {
+      MS_LOG(EXCEPTION) << "The inserted item already exists.";
+    }
+  }
+  const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
+
+ private:
+  void TopologyOrder(std::vector<OperatorInfoPtr> *);
+  void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
+  Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
+  void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &);
+  // Needed by rec_parser
+  std::vector<std::vector<std::string>> inputs_tensor_name_list_;
+  std::map<std::string, std::string> tuple_getitem_list_;
+  double dev_memory_;
+  double costmodel_alpha_;
+  double costmodel_beta_;
+  std::vector<OperatorInfoPtr> ops_;
+  std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
+  std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
+  std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
+  std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc
new file mode 100644
index 0000000000..aaf3fdff3c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.cc
@@ -0,0 +1,892 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+
+#include <algorithm>
+#include <vector>
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }
+
+void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
+  is_parameter_involve_ = is_parameter_inv;
+}
+
+void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
+
+void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
+                                               const std::vector<size_t> &output_lengths) {
+  inputs_type_lengths_ = input_lengths;
+  outputs_type_lengths_ = output_lengths;
+}
+
+void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; }
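These setters feed the byte-based accounting used by `GetMemoryCost` just below: a tensor slice costs the product of its dimensions times the element width, and the training-phase rule charges every output slice plus every input slice that is a parameter. A standalone sketch of that basic rule (illustrative names; the real method additionally applies the `inputs_related_`/parameter-involved refinements):

```cpp
#include <cstddef>
#include <vector>

// "Slice" here is just a shape; type lengths are bytes per element.
using Shape = std::vector<int>;

double Product(const Shape &s) {
  double r = 1.0;
  for (int d : s) r *= d;
  return r;
}

// Count every output slice, plus each input slice that is a parameter.
double MemoryBytes(const std::vector<Shape> &in_slices, const std::vector<size_t> &in_lens,
                   const std::vector<bool> &is_parameter, const std::vector<Shape> &out_slices,
                   const std::vector<size_t> &out_lens) {
  double result = 0.0;
  for (size_t i = 0; i < out_slices.size(); ++i) {
    result += Product(out_slices[i]) * static_cast<double>(out_lens[i]);
  }
  for (size_t i = 0; i < in_slices.size(); ++i) {
    if (is_parameter[i]) result += Product(in_slices[i]) * static_cast<double>(in_lens[i]);
  }
  return result;
}

int main() {
  // A 128x32 fp32 activation plus a 32x16 fp32 weight slice.
  double bytes = MemoryBytes({{128, 32}, {32, 16}}, {4, 4}, {false, true}, {{128, 16}}, {4});
  return bytes == (128.0 * 16 * 4 + 32.0 * 16 * 4) ? 0 : 1;
}
```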
+double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
+                                   const std::vector<TensorInfo> &outputs) const {
+  double result = 0.0;
+  if (output_parameter_involve_ == 1) {
+    // When this operator has multiple outputs, they all contribute to the memory.
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
+    }
+    bool is_any_para_inv =
+      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
+    if (is_any_para_inv) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        if (is_parameter_[i]) {
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
+          // When the inputs of this operator are related, and they are not parameter-involved, then they are
+          // included in the memory cost.
+          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
+        }
+      }
+    }
+  }
+
+  return result;
+}
+
+double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
+                                               const std::vector<TensorInfo> &outputs) const {
+  double result = 0.0;
+  if (is_outputs_critical_ == -1) {
+    MS_LOG(EXCEPTION) << "The critical flag is not set.";
+  }
+  if (is_outputs_critical_ == 1) {
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
+    }
+  }
+  return result;
+}
+
+// return the per device communication cost in the forward phase.
+double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                      int32_t) const {
+  TensorInfo input0 = inputs[0];
+  TensorInfo output0 = outputs[0];
+  Shape input0_shape = input0.shape();
+  Shape input0_slice_shape = input0.slice_shape();
+  if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
+    // If the reduced dimension has not been partitioned, then there is no communication cost.
+    return 0.0;
+  } else {
+    // Else, the communication cost is the size (number of bytes) of a slice of the output tensor.
+    return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
+  }
+}
+
+// return the per device communication cost in the backward phase.
+double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                       int32_t stage_id) const {
+  // In the backward phase, communication cost is incurred only when tensor B is a Parameter and tensor B does not
+  // fully utilize all devices
+  double result = 0.0;
+  if (is_parameter_[1]) {
+    TensorInfo input1 = inputs[1];  // tensor B
+    CheckGlobalDeviceManager();
+    MS_EXCEPTION_IF_NULL(g_device_manager);
+    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+
+    Shape input1_shape = input1.shape();
+    Shape input1_slice_shape = input1.slice_shape();
+    int32_t used_device_num = 1;
+    for (size_t i = 0; i < input1_shape.size(); ++i) {
+      used_device_num *= input1_shape[i] / input1_slice_shape[i];
+    }
+
+    if (total_device_num != IntToSize(used_device_num))
+      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  }
+
+  return result;
+}
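The MatMul rule above is worth restating: if the contraction (last) dimension of A is split across devices, the forward pass needs an AllReduce whose traffic is one slice of the output; otherwise communication is free. A self-contained sketch of the same shape test (hypothetical names, not the patch's API):

```cpp
#include <vector>

using Shape = std::vector<int>;

double Product(const Shape &s) {
  double r = 1.0;
  for (int d : s) r *= d;
  return r;
}

// Mirror of MatMulCost::GetForwardCommCost: an AllReduce of one output slice
// is needed exactly when the reduced (last) dimension of A is partitioned.
double MatMulForwardCommBytes(const Shape &a_shape, const Shape &a_slice,
                              const Shape &out_slice, double out_elem_bytes) {
  bool reduced_dim_split = a_shape.back() != a_slice.back();
  return reduced_dim_split ? Product(out_slice) * out_elem_bytes : 0.0;
}

int main() {
  // A is 64x64, sliced to 64x32 -> the contraction dim is split -> AllReduce.
  double cost = MatMulForwardCommBytes({64, 64}, {64, 32}, {64, 16}, 4.0);
  return cost == 64.0 * 16 * 4 ? 0 : 1;
}
```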
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
+// this operator uses
+double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                             const std::vector<TensorInfo> &outputs, int32_t) const {
+  // In the forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
+  double result = 0.0;
+  TensorInfo output0 = outputs[0];
+  Shape input0_slice_shape = inputs[0].slice_shape();
+  Shape input1_slice_shape = inputs[1].slice_shape();
+  Shape input0_shape = inputs[0].shape();
+  if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
+    // If the reduced dimension has been partitioned, an AllReduce over a slice of the output is incurred.
+    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
+  }
+  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
+            ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  return result;
+}
+
+// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
+// this operator uses
+double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                              int32_t stage_id) const {
+  // In the backward phase, the computation cost = (0 or 1) allreduce(slice(B))
+  double result = 0.0;
+  if (is_parameter_[1]) {
+    TensorInfo input1 = inputs[1];  // tensor B
+    CheckGlobalDeviceManager();
+    MS_EXCEPTION_IF_NULL(g_device_manager);
+    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+
+    Shape input1_shape = input1.shape();
+    Shape input1_slice_shape = input1.slice_shape();
+    int32_t used_device_num = 1;
+    for (size_t i = 0; i < input1_shape.size(); ++i) {
+      used_device_num *= input1_shape[i] / input1_slice_shape[i];
+    }
+
+    if (total_device_num != IntToSize(used_device_num))
+      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  }
+
+  return result;
+}
+
+// Return the per device communication cost in the forward phase.
+double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                          int32_t) const {
+  // ReLU is an element-wise operator, thus it does not need communication in the forward phase
+  return 0.0;
+}
+
+// Return the per device communication cost in the backward phase.
+double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
+                                           int32_t stage_id) const {
+  double result = 0.0;
+  if (is_parameter_[0]) {
+    TensorInfo input1 = inputs[0];
+    MS_EXCEPTION_IF_NULL(g_device_manager);
+    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+    Shape input1_shape = input1.shape();
+    Shape input1_slice_shape = input1.slice_shape();
+    int32_t used_device_num = 1;
+    for (size_t i = 0; i < input1_shape.size(); ++i) {
+      used_device_num *= input1_shape[i] / input1_slice_shape[i];
+    }
+    if (total_device_num != IntToSize(used_device_num)) {
+      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+    }
+  }
+  return result;
+}
+
+// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
+// this operator uses
+double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                                 const std::vector<TensorInfo> &, int32_t) const {
+  TensorInfo input0_info = inputs[0];
+  Shape input0_slice_shape = input0_info.slice_shape();
+  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
+}
+
+// Return the per device computation cost in the backward phase.
The cost is calculated according to the bytes +// this operator uses +double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +// Return the per device communication cost in the forward phase. +double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // In the forward phase, the communication cost = 0 + return 0.0; +} + +// Return the per device communication cost in the backward phase. +double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input1 = inputs[0]; + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In the forward phase, the computation cost = slice(A) + TensorInfo input0 = inputs[0]; + Shape input0_slice_shape = input0.slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double TmpIdentityCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // Identity is the element-wise operator, thus it does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double TmpIdentityCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // Identity is the element-wise operator, thus it does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double TmpIdentityCost::GetForwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { + return 0.0; +} + +// Return the per device PEAK memory cost contributed by this operator in a training iteration. 
+double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { + return 0.0; +} + +double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, + int32_t) const { + double cost = 0.0; + for (size_t i = 0; i < inputs.size(); ++i) { + cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); + } + return cost; +} + +double BatchParallelCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { + return 0.0; +} + +double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + + return result; +} +// return the per device communication cost in the forward phase. +double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { + // prelu does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes +// this operator uses +double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, + int32_t stage_id) const { + // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) + double result = 0.0; + if (is_parameter_[1]) { + TensorInfo input1 = inputs[1]; // tensor B + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// return the per device communication cost in the forward phase. +double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { + // onehot does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // onehot does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In onehot's forward phase, the computation cost = slice(A) + Shape input0_slice_shape = inputs[0].slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); +} + +// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes +// this operator uses +double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { + // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase + return 0.0; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes +// this operator uses +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +// return the per device communication cost in the forward phase. +double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + TensorRedistribution tensor_redistribution(false, true); + if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; + } + if (tensor_redistribution.ComputeCost() == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; + } + return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost()); +} + +// return the per device communication cost in the backward phase. +double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input1 = inputs[0]; + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + Shape input1_shape = input1.shape(); + Shape input1_slice_shape = input1.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input1_shape.size(); ++i) { + used_device_num *= input1_shape[i] / input1_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + } + return result; +} + +// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes +// this operator uses +double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + TensorRedistribution tensor_redistribution(false, true); + if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; + } + if (tensor_redistribution.ComputeCost() == FAILED) { + MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; + } + return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); +} + +// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes +// this operator uses +double ReshapeCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { + return 0.0; +} + +double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + double result; + result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + + ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); + return result; +} + +double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + if (is_parameter_[0]) { + TensorInfo input_a_tensor_info = inputs[0]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + if (is_parameter_[1]) { + TensorInfo input_b_tensor_info = inputs[1]; + Shape input_b_shape = input_b_tensor_info.shape(); + Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_b_shape.size(); ++i) { + used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + return result; +} + +double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + if (is_parameter_[0]) { + TensorInfo input_a_tensor_info = inputs[0]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + if (is_parameter_[1]) { + TensorInfo input_b_tensor_info = inputs[1]; + Shape input_b_shape = input_b_tensor_info.shape(); + Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_b_shape.size(); ++i) { + used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); + } + + return result; +} + +bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + auto strategy0 = shape[0] / slice_shape[0]; + + return (total_device_num == IntToSize(strategy0)); +} + +double ReduceMethodCost::GetForwardCommCost(const std::vector 
&inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + Shape input0_shape = input0.shape(); + Shape input0_slice_shape = input0.slice_shape(); + if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + return result; + } + std::vector dim_list = input0.reduce_dim(); + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } + + return result; +} + +double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_[0]) { + TensorInfo input_tensor_info = inputs[0]; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + Shape input_shape = input_tensor_info.shape(); + Shape input_slice_shape = input_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_shape.size(); ++i) { + used_device_num *= input_shape[i] / input_slice_shape[i]; + } + + if (total_device_num != IntToSize(used_device_num)) + result += ListProduct(input_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + + return result; +} + +double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + std::vector dim_list = input0.reduce_dim(); + Shape input0_slice_shape = input0.slice_shape(); + Shape input0_shape = input0.shape(); + if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]); + } + } + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + + return result; +} + +double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + TensorInfo input0 = inputs[0]; + TensorInfo output0 = outputs[0]; + std::vector dim_list = input0.reduce_dim(); + Shape input0_slice_shape = input0.slice_shape(); + Shape input0_shape = input0.shape(); + if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) { + std::vector::iterator pos; + pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) { + return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)]; + }); + if (pos != dim_list.end()) { + result += ListProduct(output0.slice_shape()) * static_cast(outputs_type_lengths_[0]) * 2.0; + } + } + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); + + return result; +} + +double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + if (inputs.empty()) { + return 0.0; + } + TensorInfo 
input0 = inputs[0]; + Shape input0_slice_shape = input0.slice_shape(); + return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * DROPOUT_COST_RATE; +} + +// return the per device communication cost in the forward phase. +double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { + // GatherV2Cost does not need communication in the forward phase + return 0.0; +} + +// return the per device communication cost in the backward phase. +double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + + return result; +} + +double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + // In forward phase, the computation cost = slice(A) + slice(B) + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + return result; +} + +double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const { + return 0.0; +} + +double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, + int32_t stage_id) const { + double result = 0.0; + if (is_parameter_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost"; + } + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t index = 0; index < inputs.size(); ++index) { + if (is_parameter_[index]) { + TensorInfo tensor_info = inputs[index]; + Shape shape = tensor_info.shape(); + Shape slice_shape = tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (slice_shape[i] == 0) { + MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape); + } + used_device_num *= shape[i] / slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + } + } + return result; +} + +double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, + int32_t) const { + double result = 0.0; + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost"; + } + + for (size_t index = 0; index < 
inputs.size(); ++index) { + TensorInfo tensor_info = inputs[index]; + Shape slice_shape = tensor_info.slice_shape(); + result += ListProduct(slice_shape) * static_cast(inputs_type_lengths_[index]); + } + return result; +} + +double GatherV2PCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + if (outputs_type_lengths_.size() != outputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + return result; + } + + // split axis + auto param_shape = inputs[0].slice_shape(); + auto index_shape = inputs[1].slice_shape(); + Shape reducescatter_shape = index_shape; + if (param_shape.size() == 2) { + reducescatter_shape.push_back(param_shape.at(1 - axis_)); + } + result += ListProduct(reducescatter_shape) * static_cast(outputs_type_lengths_[0]); + return result; +} + +double GatherV2PCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const { + double result = 0.0; + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + + for (size_t j = 0; j < inputs.size(); ++j) { + if (!is_parameter_[j]) { + continue; + } + TensorInfo input_a_tensor_info = inputs[j]; + Shape input_a_shape = input_a_tensor_info.shape(); + Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); + int32_t used_device_num = 1; + for (size_t i = 0; i < input_a_shape.size(); ++i) { + used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; + } + if (total_device_num != IntToSize(used_device_num)) { + result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); + } + } + return result; +} + +double GatherV2PCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { + double result = 0.0; + Shape input0_slice_shape = inputs[0].slice_shape(); + Shape input1_slice_shape = inputs[1].slice_shape(); + if (inputs_type_lengths_.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; + } + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); + } else { + // split axis + result += ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1; + } + + return result; +} + +double GatherV2PCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { + double result = 0.0; + Shape input1_slice_shape = inputs[1].slice_shape(); + Shape output0_slice_shape = outputs[0].slice_shape(); + // don't split axis + if (strategy_.at(IntToSize(axis_)) == 1) { + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]); + } else { + // split axis + result += ListProduct(output0_slice_shape) * static_cast(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 + + ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3; + } + + return result; +} +} // namespace parallel +} // namespace mindspore diff --git 
a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
new file mode 100644
index 0000000000..dda597bd1f
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
@@ -0,0 +1,656 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
+#define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
+
+#include <memory>
+#include <vector>
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
+
+namespace mindspore {
+namespace parallel {
+#define MAXIMUM_INPUT_NUMBER 100
+#define DEFAULT_DATA_TYPE_LENGTH 4
+#define DROPOUT_COST_RATE 1.125  // the DropoutGenMask needs 12.5% extra memory
+#define GATHERV2_COST_WEIGHT0 3
+#define GATHERV2_COST_WEIGHT1 7
+#define GATHERV2_COST_WEIGHT2 2
+#define GATHERV2_COST_WEIGHT3 6
+
+class OperatorCost;
+using OperatorCostPtr = std::shared_ptr<OperatorCost>;
+
+template <typename T>
+double ListProduct(std::vector<T> vec) {
+  double result = 1;
+  for (size_t i = 0; i < vec.size(); ++i) {
+    result *= vec[i];
+  }
+  return result;
+}
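`ListProduct` is the workhorse of every cost routine in these files: a slice's byte size is the product of its dimensions times the element width. A quick standalone check of that formula, re-declaring the helper as above:

```cpp
#include <cstddef>
#include <vector>

// Same helper as in the header: multiply the entries of a shape vector.
template <typename T>
double ListProduct(std::vector<T> vec) {
  double result = 1;
  for (size_t i = 0; i < vec.size(); ++i) {
    result *= vec[i];
  }
  return result;
}

int main() {
  // Bytes of a 32x64 fp32 slice: 32 * 64 elements * 4 bytes each.
  std::vector<int> slice_shape = {32, 64};
  double bytes = ListProduct(slice_shape) * 4.0;
  return bytes == 32 * 64 * 4.0 ? 0 : 1;
}
```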
+// NOTE: Currently, the returned value in each method is the number of bytes of memory size, which is calculated
+// by the number of entries times the length of each entry's data type
+class OperatorCost {
+ public:
+  explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
+    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
+    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
+      is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
+      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+    }
+  }
+  OperatorCost() : inputs_related_(false) {
+    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
+    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
+      is_parameter_.push_back(false);
+      is_parameter_involve_.push_back(false);
+      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
+    }
+  }
+  virtual ~OperatorCost() = default;
+
+  void set_is_parameter(const std::vector<bool> &is_parameter);
+  void set_is_parameter_involve(const std::vector<bool> &);
+  void set_output_parameter_involve(int);
+  void set_output_critical(int);
+  void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
+                                   const std::vector<size_t> &output_lengths);
+  std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
+  std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
+
+  // per device communication cost
+  virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                             int32_t stage_id) const = 0;
+  virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                    int32_t stage_id) const = 0;
+  virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                     int32_t stage_id) const = 0;
+  // per device computation cost
+  virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                    int32_t stage_id) const = 0;
+  virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                           const std::vector<TensorInfo> &outputs, int32_t stage_id) const = 0;
+  virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
+                                            const std::vector<TensorInfo> &outputs, int32_t stage_id) const = 0;
+  // per device PEAK memory cost in a training iteration
+  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
+  // plus necessary inputs.
+  virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
+  // per device memory cost in an inference phase
+  double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
+
+ protected:
+  // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or an output of
+  // a preceding operator that takes parameters as input.
+  std::vector<bool> is_parameter_involve_;
+  int output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
+  // Whether the inputs are related. For example, TensorAdd's two inputs are independent (not related), while
+  // Mul's two inputs are dependent (related).
+  bool inputs_related_;
+  // for each input in 'inputs_', a bool variable indicating whether the corresponding input is a parameter
+  std::vector<bool> is_parameter_;
+  // for each input and output, the following record the number of bytes of each element
+  std::vector<size_t> inputs_type_lengths_;
+  std::vector<size_t> outputs_type_lengths_;
+  // Whether the output is critical, meaning that this output is included when calculating the peak memory cost
+  // in the inference phase.
+ int is_outputs_critical_ = -1; +}; + +using OperatorCostPtr = std::shared_ptr; + +class MatMulCost : public OperatorCost { + public: + explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + MatMulCost() : OperatorCost(true) {} + ~MatMulCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using MatMulCostPtr = std::shared_ptr; + +class ActivationCost : public OperatorCost { + public: + explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ActivationCost() : OperatorCost(false) {} + ~ActivationCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ActivationCostPtr = std::shared_ptr; +using TransposeCost = ActivationCost; +using TransposeCostPtr = std::shared_ptr; + +class SoftmaxCost : public OperatorCost { + public: + explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCost() : OperatorCost(false) {} + ~SoftmaxCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + 
GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; +}; +using SoftmaxCostPtr = std::shared_ptr; + +class TmpIdentityCost : public OperatorCost { + public: + explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + TmpIdentityCost() : OperatorCost(false) {} + ~TmpIdentityCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; +}; +using TmpIdentityCostPtr = std::shared_ptr; + +class BatchParallelCost : public OperatorCost { + public: + explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + BatchParallelCost() : OperatorCost(false) {} + ~BatchParallelCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using BatchParallelCostPtr = std::shared_ptr; + +class VirtualDatasetCost : public OperatorCost { + public: + explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + VirtualDatasetCost() : OperatorCost(false) {} + ~VirtualDatasetCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double 
GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + // per device PEAK memory cost in a training iteration + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { + return 0.0; + } +}; +using VirtualDatasetCostPtr = std::shared_ptr; + +class GeneratorBaseCost : public OperatorCost { + public: + explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + GeneratorBaseCost() : OperatorCost(false) {} + ~GeneratorBaseCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + // Inputs vector is empty for generator ops. + double GetForwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } + // Generator ops don't have backward steps. 
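The GetMemoryCost overrides above specialize the peak-memory rule stated on OperatorCost: charge the inputs that must be kept alive, and charge the output only when it is parameter-involved. A minimal standalone sketch of that rule follows; the struct and function names are invented for illustration and are not part of the codebase:

    #include <cstddef>
    #include <vector>

    struct TensorMeta {
      size_t elements;     // number of elements in the tensor slice on this device
      size_t type_length;  // bytes per element
    };

    // Charges all inputs, and the output only when it is parameter-involved.
    inline double PeakMemorySketch(const std::vector<TensorMeta> &inputs, const std::vector<TensorMeta> &outputs,
                                   bool output_parameter_involved) {
      double bytes = 0.0;
      for (const auto &t : inputs) bytes += static_cast<double>(t.elements * t.type_length);
      if (output_parameter_involved) {
        for (const auto &t : outputs) bytes += static_cast<double>(t.elements * t.type_length);
      }
      return bytes;
    }

VirtualDatasetCost's override above is the degenerate case of this rule, where nothing is charged at all.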
+ double GetBackwardComputationCost(const std::vector &, const std::vector &, + int32_t) const override { + return 0.0; + } +}; +using GeneratorBaseCostPtr = std::shared_ptr; + +class PReLUCost : public OperatorCost { + public: + explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + PReLUCost() : OperatorCost(true) {} + ~PReLUCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using PReLUCostPtr = std::shared_ptr; + +class OneHotCost : public OperatorCost { + public: + explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + OneHotCost() : OperatorCost(true) {} + ~OneHotCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using OneHotCostPtr = std::shared_ptr; + +class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { + public: + explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {} + ~SoftmaxCrossEntropyWithLogitsCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per 
device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; + +class ReshapeCost : public OperatorCost { + public: + explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ReshapeCost() : OperatorCost(true) {} + + ~ReshapeCost() override = default; + + // per device communication cost + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + // per device computation cost + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ReshapeCostPtr = std::shared_ptr; + +class ArithmeticCost : public OperatorCost { + public: + explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ArithmeticCost() : OperatorCost(false) {} + ~ArithmeticCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { + return 0.0; + } + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; +}; +using ArithmeticCostPtr = std::shared_ptr; +using BiasAddCost = ArithmeticCost; +using BiasAddCostPtr = std::shared_ptr; + +class ReduceMethodCost : public OperatorCost { + public: + explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} + ReduceMethodCost() : OperatorCost(true) {} + ~ReduceMethodCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return 
GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                            int32_t stage_id) const override;
+  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                             int32_t stage_id) const override;
+  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                   int32_t stage_id) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                    int32_t) const override {
+    return 0.0;
+  }
+  void set_cross_batch(bool cb) { cross_batch_ = cb; }
+
+ protected:
+  bool cross_batch_ = false;
+};
+using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
+
+class ReduceMeanCost : public ReduceMethodCost {
+ public:
+  explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
+  ReduceMeanCost() : ReduceMethodCost(true) {}
+  ~ReduceMeanCost() override = default;
+
+  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                                   int32_t stage_id) const override;
+};
+using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
+
+class GetNextCost : public OperatorCost {
+ public:
+  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GetNextCost() : OperatorCost(false) {}
+  ~GetNextCost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                     int32_t stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
+    return 0.0;
+  }
+  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
+    return 0.0;
+  }
+  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  // Inputs vector is empty for generator ops.
+  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                   int32_t) const override {
+    return 0.0;
+  }
+  // Generator ops don't have backward steps.
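+  // (GetNextCost mirrors GeneratorBaseCost above: GetNext only feeds data into the graph, so all of its
+  // communication and computation costs are zero.)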
+  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                    int32_t) const override {
+    return 0.0;
+  }
+};
+using GetNextCostPtr = std::shared_ptr<GetNextCost>;
+
+class DropOutCost : public OperatorCost {
+ public:
+  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  DropOutCost() : OperatorCost(true) {}
+  ~DropOutCost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                     int32_t stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
+    return 0.0;
+  }
+  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
+    return 0.0;
+  }
+  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                   int32_t) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                    int32_t) const override {
+    return 0.0;
+  }
+};
+
+using DropOutCostPtr = std::shared_ptr<DropOutCost>;
+
+class LayerNormCost : public OperatorCost {
+ public:
+  explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  LayerNormCost() : OperatorCost(true) {}
+  ~LayerNormCost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                     int32_t stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
+    return 0.0;
+  }
+  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
+  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                   int32_t) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
+                                    int32_t) const override {
+    return 0.0;
+  }
+};
+
+using LayerNormCostPtr = std::shared_ptr<LayerNormCost>;
+
+class GatherV2Cost : public OperatorCost {
+ public:
+  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
+  GatherV2Cost() : OperatorCost(true) {}
+  ~GatherV2Cost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                     int32_t stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override;
+  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                             int32_t stage_id) const override;
+  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
+                            int32_t stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; +}; + +using GatherV2CostPtr = std::shared_ptr; + +class GatherV2PCost : public OperatorCost { + public: + explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {} + GatherV2PCost() : OperatorCost(true), axis_(0) {} + ~GatherV2PCost() override = default; + + double GetCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); + } + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override { + return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); + } + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t stage_id) const override; + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, + int32_t) const override; + void set_axis(int32_t axis) { axis_ = axis; } + void set_strategy(const Shape &strategy) { strategy_ = strategy; } + + protected: + int32_t axis_; + Shape strategy_; +}; + +using GatherV2PCostPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc new file mode 100644 index 0000000000..0a7e6c59d4 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.cc @@ -0,0 +1,750 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_cost.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" + +namespace mindspore { +namespace parallel { + +// Compute redistributed cost +double CostRedis(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::vector> &mode, const Graph &graph) { + // Store value of cost redist + double cost_redis = 0; + + // Number of current strategies. + size_t num_strategy = node_name_to_strategy.size(); + + // Number of node-in and node-out + size_t num_node_in = node.node_in.size(); + size_t num_node_out = node.node_out.size(); + + // Set tensor edge value with original tensor shape and cutting times. 
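+  // That is, the effective per-device element count is the product over {n, c, h, w} of the full dimension size and
+  // its stride fraction str_* (the portion kept after halving cuts); e.g. with str_h = 0.5, a 32x3x224x224 tensor
+  // contributes 32 * 3 * 112 * 224 elements.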
+ double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n * + node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c * + node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h * + node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w; + + double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n * + node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c * + node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h * + node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w; + + // For each strategy candidate. + for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) { + // Find its forward nodes + for (size_t i_node = 0; i_node < num_node_in; i_node++) { + if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first) { + bool is_search_forward = true; + cost_redis += + CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward); + } + } + + // Find its backward nodes + for (size_t i_node = 0; i_node < num_node_out; i_node++) { + if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first) { + bool is_search_forward = false; + cost_redis += + CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward); + } + } + } + + return cost_redis; +} + +double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, + const std::vector> &mode, size_t i_strategy, size_t i_node, + double tensor_size, bool search_forward) { + double new_redis_cost = 0; + int counter = 0; + + if (search_forward) { + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_n) != + static_cast(1 / mode[i_node][0])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_c) != + static_cast(1 / mode[i_node][1])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_h) != + static_cast(1 / mode[i_node][2])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_w) != + static_cast(1 / mode[i_node][3])) { + counter += 1; + } + } else { + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_n) != + static_cast(1 / mode[2][0])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_c) != + static_cast(1 / mode[2][1])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_h) != + static_cast(1 / mode[2][2])) { + counter += 1; + } + if (static_cast(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_w) != + static_cast(1 / mode[2][3])) { + counter += 1; + } + } + + if (counter >= 2) { + new_redis_cost = tensor_size / 4.0; + } else if (counter == 0 || counter == 1) { + new_redis_cost = 0; + } else { + MS_LOG(EXCEPTION) << "Failure: CostRedis failed."; + } + + return new_redis_cost; +} + +// Get optimal strategy for MatMul +StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + int edge_i = + static_cast(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h); + int edge_j = + 
static_cast(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w); + int edge_k = + static_cast(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w); + + std::vector cost_op; + std::vector> mode; + + if (edge_i < 2 || edge_i % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}}, + graph)); + } + + if (edge_j < 2 || edge_j % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, + graph)); + } + + if (edge_k < 2 || edge_k % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}}, + graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Get weight for MatMul +double CostMatMul::GetMinCostIn(const OperatorRec &op) { + int edge_i = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int edge_j = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); + int edge_k = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_in; + cost_in.push_back(StrConcatDimI(edge_j, edge_k)); + cost_in.push_back(StrConcatDimJ(edge_i, edge_k)); + cost_in.push_back(StrReduceDimK(edge_i, edge_j)); + + return *min_element(cost_in.begin(), cost_in.end()); +} + +// Chose strategy for MatMul +StrategyRec CostMatMul::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_i_; + break; + + case 1: + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_j_; + break; + + case 2: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_k_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure:CostMatMul failed."; + } + + return str; +} + +// Get optimal strategy for Conv +StrategyRec CostConvolution::GetOptimalStr( + const Graph::NodeType &node, const std::vector> &node_name_to_strategy, + const Graph &graph, bool channel_partition) { + const OperatorRec &op = node.apply; + + int input_tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int input_tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + int input_tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int input_tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + + int tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c; + + int tensor_filter_h = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h); + int tensor_filter_w = static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w); + int tensor_filter_n = 
static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n); + int tensor_filter_c = static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); + + int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c; + + int output_tensor_h = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h); + int output_tensor_w = static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w); + int output_tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); + int output_tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + int tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c; + + std::vector cost_op; + cost_op.reserve(7); + std::vector> mode; + + if (input_tensor_n < 2 || input_tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}}, graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Get weight for Conv +double CostConvolution::GetMinCostIn(const Graph::NodeType &node) { + const OperatorRec &op = node.apply; + + int tensor_in = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) * + static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) * + static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) * + static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_filter = static_cast(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) * + static_cast(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) * + static_cast(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) * + static_cast(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c); + int tensor_out = static_cast(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) * + static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) * + static_cast(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) * + static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + std::vector cost_in; + cost_in.push_back(StrDimB(tensor_filter)); + cost_in.push_back(StrDimI(tensor_in, tensor_filter)); + cost_in.push_back(StrDimJ(tensor_in, tensor_filter)); + cost_in.push_back(StrDimK(tensor_in)); + cost_in.push_back(StrDimDI(tensor_in, tensor_out)); + cost_in.push_back(StrDimDJ(tensor_in, tensor_out)); + cost_in.push_back(StrDimQ(tensor_out)); + + return 
*min_element(cost_in.begin(), cost_in.end()); +} + +// Chose strategy for Conv +StrategyRec CostConvolution::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_b_; + break; + + case 1: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_i_; + break; + + case 2: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_j_; + break; + + case 3: + str.inputTensor[1].str_n /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_k_; + break; + + case 4: + str.inputTensor[1].str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_di_; + break; + + case 5: + str.inputTensor[1].str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_dj_; + break; + + case 6: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_q_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostConvolution failed."; + } + return str; +} + +// Get optimal strategy for Pooling +StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + int tensor_n = static_cast(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n); + int tensor_c = static_cast(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c); + + std::vector cost_op; + std::vector> mode; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + cost_op.push_back(DOUBLE_MAX); + cost_op.push_back(DOUBLE_MAX); + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for Pooling +StrategyRec CostPooling::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostPooling failed."; + } + return str; +} + +// Chose strategy for Add +StrategyRec CostTensorAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = 
min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.inputTensor[1].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostAdd failed."; + } + return str; +} + +// Get optimal strategy for Reshape +StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); } + +StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; } + +// Chose strategy for BiasAdd +StrategyRec CostBiasAdd::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed."; + } + return str; +} + +// Get optimal strategy for Common OPs +StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph) { + const OperatorRec &op = node.apply; + int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_op; + std::vector> mode; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph)); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph)); + } + + if (tensor_h < 2 || tensor_h % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 
1, 0.5, 1}}, graph)); + } + + if (tensor_w < 2 || tensor_w % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, + mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, graph)); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for Common op +StrategyRec CostCommon::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: Common failed."; + } + return str; +} + +// Get optimal strategy for BatchParallel OPs +StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) { + const OperatorRec &op = node.apply; + int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); + int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); + int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); + int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); + + std::vector cost_op; + + if (tensor_n < 2 || tensor_n % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_c < 2 || tensor_c % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_h < 2 || tensor_h % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + if (tensor_w < 2 || tensor_w % 2 != 0) { + cost_op.push_back(DOUBLE_MAX); + } else { + cost_op.push_back(cost_in_); + } + + return ChoseStr(cost_op, node.apply.str); +} + +// Chose strategy for BatchParallel op +StrategyRec CostBatchParallel::ChoseStr(const std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.outputTensor.str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.outputTensor.str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.outputTensor.str_h /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed."; + } + return str; +} + +// Chose strategy for CostSoftmaxCrossEntropyWithLogits +StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const 
std::vector &cost_op, StrategyRec str) { + uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); + if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { + return str; + } + + switch (min_position) { + case 0: + str.inputTensor[0].str_n /= 2.0; + str.inputTensor[1].str_n /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 1: + str.inputTensor[0].str_c /= 2.0; + str.inputTensor[1].str_c /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 2: + str.inputTensor[0].str_h /= 2.0; + str.inputTensor[1].str_h /= 2.0; + str.outputTensor.str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + case 3: + str.inputTensor[0].str_w /= 2.0; + str.inputTensor[1].str_w /= 2.0; + str.cut_counter += 1; + str.cost = str.cost + cost_in_; + break; + + default: + MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed."; + } + return str; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h new file mode 100644 index 0000000000..563bf4598a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_cost.h @@ -0,0 +1,233 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_COST_H_ +#define PARALLEL_AUTO_PARALLEL_REC_COST_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" + +namespace mindspore { +namespace parallel { +#define DOUBLE_MAX (std::numeric_limits::max)() + +double CostRedis(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const std::vector> &mode, const Graph &graph); + +double CostRedisWithAdjacentNode(const std::vector> &node_name_to_strategy, + const std::vector> &mode, size_t i_strategy, size_t i_node, + double tensor_size, bool is_search_forward); + +// class CostMatMul is used to compute the cost of MatMul operator. +class CostMatMul { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + double GetMinCostIn(const OperatorRec &op); + + private: + double StrConcatDimI(int32_t a, int32_t b) { + cost_in_i_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_i_; + } + + double StrConcatDimJ(int32_t a, int32_t b) { + cost_in_j_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_j_; + } + + double StrReduceDimK(int32_t a, int32_t b) { + cost_in_k_ = (static_cast(a) * static_cast(b)) / 2.0; + + return cost_in_k_; + } + + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_i_ = 0; + + double cost_in_j_ = 0; + + double cost_in_k_ = 0; +}; // class CostMatMul is used to compute the cost of MatMul operator. 
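ChoseStr above implements a greedy halving step: the cheapest candidate cut dimension (i, j, or k) has its stride fraction halved and its cost added to the running total, as the .cc file shows. A condensed standalone sketch of that step follows; ToyStrategy and HalveCheapestDim are illustrative names, and the real code also updates the matching input/output tensors and throws on an unexpected index:

    #include <algorithm>
    #include <vector>

    struct ToyStrategy {
      double str_h = 1.0;  // kept fraction of the row dimension
      double str_w = 1.0;  // kept fraction of the column dimension
      double cost = 0.0;
      int cut_counter = 0;
    };

    inline ToyStrategy HalveCheapestDim(const std::vector<double> &cost_op, ToyStrategy str) {
      size_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
      if (min_position == 0) {
        str.str_h /= 2.0;  // cut along dimension i
      } else {
        str.str_w /= 2.0;  // cut along dimension j (or k)
      }
      str.cost += cost_op[min_position];
      str.cut_counter += 1;
      return str;
    }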
+ +// class CostConvolution is used to compute the cost of Conv operator. +class CostConvolution { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph, bool channel_partition); + + double GetMinCostIn(const Graph::NodeType &node); + + private: + double StrDimB(int32_t TensorFilter) { + cost_in_b_ = static_cast((TensorFilter) / 2.0); + + return cost_in_b_; + } + + double StrDimI(int32_t TensorIn, int32_t TensorFilter) { + cost_in_i_ = static_cast((TensorIn + TensorFilter) / 2.0); + + return cost_in_i_; + } + + double StrDimJ(int32_t TensorIn, int32_t TensorFilter) { + cost_in_j_ = static_cast((TensorIn + TensorFilter) / 2.0); + + return cost_in_j_; + } + + double StrDimK(int32_t TensorIn) { + cost_in_k_ = static_cast((TensorIn) / 2.0); + + return cost_in_k_; + } + + double StrDimDI(int32_t TensorIn, int32_t TensorOut) { + cost_in_di_ = static_cast((TensorIn + TensorOut) / 2.0); + + return cost_in_di_; + } + + double StrDimDJ(int32_t TensorIn, int32_t TensorOut) { + cost_in_dj_ = static_cast((TensorIn + TensorOut) / 2.0); + + return cost_in_dj_; + } + + double StrDimQ(int32_t TensorOut) { + cost_in_q_ = static_cast((TensorOut) / 2.0); + + return cost_in_q_; + } + + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_b_ = 0; + + double cost_in_i_ = 0; + + double cost_in_j_ = 0; + + double cost_in_k_ = 0; + + double cost_in_di_ = 0; + + double cost_in_dj_ = 0; + + double cost_in_q_ = 0; +}; // class CostConvolution is used to compute the cost of Conv operator. + +// class CostPooling is used to compute the cost of Pooling operator. +class CostPooling { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + double GetMinCostIn() const { return cost_in_; } + + private: + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_ = 0; +}; // class CostPooling is used to compute the cost of Pooling operator. + +// class CostReshape is used to compute the cost of Reshape operator. +class CostReshape { + public: + StrategyRec GetOptimalStr(const Graph::NodeType &node) const; + + double GetMinCostIn() const { return cost_in_; } + + private: + StrategyRec ChoseStr(StrategyRec str) const; + + double cost_in_ = 0; +}; // class CostReshape is used to compute the cost of Reshape operator. + +// class CostCommon is used to compute the cost of an element-wise operator +class CostCommon { + public: + virtual StrategyRec GetOptimalStr(const Graph::NodeType &node, + const std::vector> &node_name_to_strategy, + const Graph &graph); + + virtual double GetMinCostIn() const { return cost_in_; } + + protected: + virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); + + double cost_in_ = 0; +}; // class CostCommon is used to compute the cost of an element-wise operator + +// class CostBiasAdd is used to compute the cost of the addition between a tensor and a bias +class CostBiasAdd : public CostCommon { + StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); +}; +// class CostAdd is used to compute the cost of Add operator. 
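+// (CostTensorAdd overrides ChoseStr because a cut on TensorAdd must halve the chosen dimension of both inputs and
+// the output, whereas CostCommon's version halves only the first input and the output.)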
+class CostTensorAdd : public CostCommon {
+  StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
+};
+
+// all the following operations are element-wise and have the same cost
+class CostReLU : public CostCommon {};
+class CostLog : public CostCommon {};
+class CostExp : public CostCommon {};
+class CostAdd : public CostCommon {};
+class CostSub : public CostCommon {};
+class CostMul : public CostCommon {};
+class CostDiv : public CostCommon {};
+class CostSqueeze : public CostCommon {};
+class CostCast : public CostCommon {};
+
+// class CostBatchParallel is used to compute the cost of BatchParallel operator.
+class CostBatchParallel {
+ public:
+  virtual StrategyRec GetOptimalStr(const Graph::NodeType &node);
+
+  virtual double GetMaxCostIn() const { return DOUBLE_MAX; }
+
+ protected:
+  virtual StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
+
+  double cost_in_ = 0;
+};  // class CostBatchParallel is used to compute the cost of BatchParallel operator.
+
+class CostBatchNorm : public CostBatchParallel {};
+class CostOneHot : public CostBatchParallel {};
+class CostPRelu : public CostBatchParallel {};
+class CostSoftmax : public CostBatchParallel {};
+
+class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel {
+  StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // PARALLEL_AUTO_PARALLEL_REC_COST_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
new file mode 100644
index 0000000000..68b776155a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -0,0 +1,837 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h" + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(eli_list); + MS_EXCEPTION_IF_NULL(index_list); + GeneratePartitionedOperatorStrategy(graph, ops, index_list); + std::shared_ptr> no_stra_op_list(new std::vector); + for (size_t i = 0; i < eli_list->size(); i++) { + no_stra_op_list->push_back(eli_list->at(i)[0]); + } + GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); + GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); + GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); +} + +std::vector> PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + std::vector> strategies; + auto attrs = ops[iter_ops]->attrs(); + bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); + bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); + + // HCCL does not support multi-dimension partition, and the hardware does not support excessive + // number of EVENT, so we temporarily disable matmul's multi-dimension partition function. + const auto max_cut = 1.0 / g_device_manager->DeviceNum(); + if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h != max_cut && + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w != max_cut) { + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0; + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0; + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0; + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + + auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; + if (transpose_a) { + shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1]; + } + auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1]; + if (transpose_b) { + shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0]; + } + + bool already_cut = false; + if (shape_1 >= shape_4) { + if (shape_1 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; + already_cut = true; + } + if (!already_cut && shape_4 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; + already_cut = true; + } + } else { + if (shape_4 % g_device_manager->DeviceNum() == 0) { + graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; + already_cut = true; + } + if (!already_cut && shape_1 % g_device_manager->DeviceNum() == 0) { + 
graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; + already_cut = true; + } + } + + if (!already_cut) { + MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid."; + } + } + + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + std::vector s; + if (transpose_a && (iter_op_inputs == 0)) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + } else if (transpose_b && (iter_op_inputs == 1)) { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + } else { + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back( + static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); + } + strategies.push_back(s); + } + return strategies; +} + +std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { + std::vector> strategies; + strategies.push_back(*s); + std::vector s_biasadd; + s_biasadd.push_back(s->at(1)); + strategies.push_back(s_biasadd); + return strategies; +} + +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); + + int32_t axis = -1; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = iter->second->cast()->value(); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; + } + } + if (axis == -1) { + strategies[0][0] = strategies[0][1]; + strategies[0][1] = 1; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + } + + std::vector s_empty = {}; + strategies.push_back(s_empty); + strategies.push_back(s_empty); + return strategies; +} + +std::vector> PrepareGatherV2(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + std::vector> strategies; + + auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); + if (axis_input < 0) { + axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); + } + int32_t axis = axis_input; + if (axis >= SizeToInt(s.size())) { + MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; + } + s[axis] = 1; + strategies.push_back(s); + + auto pos = ops[iter_ops]->name().find("Info"); + auto name = ops[iter_ops]->name().substr(0, pos); + if (name == "GatherV2") { + return strategies; + } + + std::vector s_indices; + for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { + s_indices.push_back(1); + } + strategies.push_back(s_indices); + + return strategies; +} + +std::vector> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + int32_t axis = 0; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter != ops[iter_ops]->attrs().end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis = 
iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int.";
+    }
+  }
+
+  int32_t axis_index = axis;
+  if (axis < 0) {
+    size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
+    axis_index = static_cast<int32_t>(input_dim) + axis;
+  }
+
+  s[IntToSize(axis_index)] = 1;
+
+  std::vector<std::vector<int32_t>> strategies;
+  strategies.push_back(s);
+  return strategies;
+}
+
+std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
+                                                        const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                        const size_t iter_graph, const size_t iter_ops) {
+  if (ops.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
+  }
+  if (iter_ops >= ops.size()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+  }
+
+  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
+  std::vector<std::vector<int32_t>> strategies;
+  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
+      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+    }
+
+    size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
+    std::vector<int32_t> s;
+    if (output_size == 4) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 2) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 1) {
+      s.push_back(
+        static_cast<int32_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
+    } else if (output_size == 0) {
+      s = {};
+    } else {
+      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexpected.";
+    }
+    strategies.push_back(s);
+  }
+  return strategies;
+}
+
+std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
+                                                           const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                           const size_t iter_graph, const size_t iter_ops) {
+  if (ops.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
+  }
+  if (iter_ops >= ops.size()) {
+    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
+  }
+
+  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
+  std::vector<std::vector<int32_t>> strategies;
+  size_t max_device_num = g_device_manager->DeviceNum();
+  size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
+  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
+    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
+      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+    }
+
+    std::vector<int32_t> s;
+    size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
+    for (size_t dim = 0; dim < input_size; dim++) {
+      if (input_size == 1 || input_size == 2 || input_size == 4) {
+        if (dim == 0) {
+          s.push_back(std::min(max_device_num, target_tensor_batch));
+        } else {
+          s.push_back(1);
+        }
+      } else if (input_size == 0) {
+        s = {};
+      } else {
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+      }
+    }
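+    // s now holds a plain data-parallel strategy: only the batch (first)
+    // dimension is split, across min(device number, batch size) devices.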
strategies.push_back(s); + } + + graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { + graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch); + } + + return strategies; +} + +std::vector> PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + MS_EXCEPTION_IF_NULL(ops[iter_ops]); + + auto type = ops[iter_ops]->type(); + auto idx = DictOpType.find(type); + if (idx == DictOpType.end()) { + return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); + } + + if (type == MATMUL) { + return PrepareMatMul(graph, ops, iter_graph, iter_ops); + } else if (type == ONEHOT) { + return PrepareOneHot(graph, ops, iter_graph, iter_ops); + } else { + return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); + } +} + +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::shared_ptr> &index_list) { + for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { + std::vector> strategies; + size_t iter_graph = index_list->at(iter_ops); + if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { + strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); + } + StrategyPtr sp = std::make_shared(0, strategies); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } +} + +size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, + const size_t iter_ops) { + size_t incoming_op_index = SIZE_MAX; + for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) { + for (size_t j = 0; j < input_tensor_names.size(); j++) { + if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) { + incoming_op_index = j; + break; + } + } + if (incoming_op_index != SIZE_MAX) { + break; + } + } + return incoming_op_index; +} + +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph) { + std::vector s; + for (auto input : ops[iter_ops]->inputs_tensor_info()) { + auto input_stra_dim = input.shape().size(); + if (input_stra_dim == 0) { + continue; + } + if (input_stra_dim == 1) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else if (input_stra_dim == 2) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else if (input_stra_dim == 4) { + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); + s.push_back(1 / 
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); + s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); + } else { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; + } + break; + } + return s; +} + +std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index) { + std::vector s; + if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || + ops[incoming_op_index]->type() == TRANSPOSE) { + return s; + } + auto strategy = ops[incoming_op_index]->selected_strategy(); + if (strategy->GetInputNumber() == 0) { + return s; + } + + for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) { + if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) { + continue; + } + for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) { + s.push_back(strategy->GetInputDim()[i][j]); + } + break; + } + return s; +} + +std::vector GetAxisList(const std::vector> &ops, const int iter_ops) { + std::vector axis_list; + auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second; + std::vector elements; + if (axis_param->isa()) { + elements = axis_param->cast()->value(); + } else if (axis_param->isa()) { + elements = axis_param->cast()->value(); + } else { + MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl; + } + + for (auto &element : elements) { + if (!element->isa()) { + MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl; + } + auto axis = element->cast()->value(); + axis_list.push_back(axis); + } + return axis_list; +} + +std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + std::vector s_Squeeze; + std::vector stra_dim_list; + for (size_t i = 0; i < s.size(); i++) { + stra_dim_list.push_back(i); + } + + auto axis_list = GetAxisList(ops, incoming_op_index); + for (auto axis : axis_list) { + auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis); + if (it == stra_dim_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[axis] != 1) { + MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." 
<< std::endl; + } + stra_dim_list.erase(it); + } + + for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) { + s_Squeeze.push_back(s[stra_dim_list[i]]); + } + return s_Squeeze; +} + +bool GetKeepDims(const std::vector> &ops, const size_t iter_ops) { + bool keepdims = false; + auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); + if (keep_dims_iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; + } + MS_EXCEPTION_IF_NULL(keep_dims_iter->second); + if (!keep_dims_iter->second->isa()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; + } + keepdims = keep_dims_iter->second->cast()->value(); + return keepdims; +} + +std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { + std::vector dim_list; + bool keep_dims = GetKeepDims(ops, iter_ops); + if (keep_dims != false) { + return dim_list; + } + auto input_value = ops[iter_ops]->input_value(); + auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + if (input_value.back()->isa()) { + auto attr_axis = GetValue>(input_value.back()); + if (attr_axis.empty()) { + MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." << std::endl; + } + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } else if (input_value.back()->isa()) { + int axis = GetValue(input_value.back()); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl; + } + return dim_list; +} + +std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + std::vector s_Reduce; + std::vector axis_list; + for (size_t i = 0; i < s.size(); i++) { + axis_list.push_back(i); + } + + auto dim_list = GetDimList(ops, incoming_op_index); + for (auto axis : dim_list) { + auto it = find(axis_list.begin(), axis_list.end(), axis); + if (it == axis_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + axis_list.erase(it); + } + + for (size_t i = 0; i < (size_t)axis_list.size(); i++) { + s_Reduce.push_back(s[axis_list[i]]); + } + return s_Reduce; +} + +std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { + std::vector dim_list; + auto iter = ops[iter_ops]->attrs().find(AXIS); + if (iter == ops[iter_ops]->attrs().end()) { + MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; + } + auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto attr_axis = GetValue>(iter->second); + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (iter->second->isa()) { + int axis = GetValue(iter->second); + axis < 0 ? 
dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + return dim_list; +} + +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s) { + bool keepdims = GetKeepDims(ops, incoming_op_index); + if (keepdims) { + return s; + } + + std::vector s_Arg; + std::vector axis_list; + for (size_t i = 0; i < s.size(); i++) { + axis_list.push_back(i); + } + + auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); + for (auto axis : dim_list) { + auto it = find(axis_list.begin(), axis_list.end(), axis); + if (it == axis_list.end()) { + MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; + } + axis_list.erase(it); + } + + for (size_t i = 0; i < (size_t)axis_list.size(); i++) { + s_Arg.push_back(s[axis_list[i]]); + } + return s_Arg; +} + +std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index) { + std::vector s; + s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index); + if (s.size() != 0) { + if (ops[incoming_op_index]->type() == SQUEEZE) { + s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s); + } + if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX || + ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { + s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); + } + if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { + s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); + } + } + return s; +} + +std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, + const size_t iter_ops, + std::vector basic_stra) { + std::vector s_empty = {}; + std::vector> stra; + MS_EXCEPTION_IF_NULL(ops[iter_ops]); + + if (basic_stra.size() == 0) { + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + stra.push_back(basic_stra); + } + return stra; + } + + auto s_ptr = std::make_shared>(basic_stra); + if (ops[iter_ops]->type() == BIAS_ADD) { + return PrepareBiasAdd(s_ptr); + } + if (ops[iter_ops]->type() == GATHERV2) { + return PrepareGatherV2(ops, iter_ops, basic_stra); + } + if (ops[iter_ops]->type() == L2_NORMALIZE) { + return PrepareL2Normalize(ops, iter_ops, basic_stra); + } + + for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); + iter_op_inputs++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { + stra.push_back(s_empty); + continue; + } + + std::vector tmp_stra = basic_stra; + bool modified = false; + for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { + tmp_stra[j] = 1; + modified = true; + } + } + if (modified) { + stra.push_back(tmp_stra); + } else { + stra.push_back(basic_stra); + } + } + return stra; +} + +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + std::vector no_stra_op_list_bis; + + for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; 
iter_list--) { + size_t iter_ops = no_stra_op_list->at(iter_list - 1); + std::vector> stra; + std::vector s; + size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); + if (incoming_op_index != SIZE_MAX) { + auto iter_graph = index_list->at(incoming_op_index); + if (iter_graph != SIZE_MAX) { + s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph); + } else { + s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index); + } + } + + if (s.size() == 0) { + no_stra_op_list_bis.push_back(iter_ops); + } else { + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + } + + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } + + no_stra_op_list->clear(); + for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { + no_stra_op_list->push_back(no_stra_op_list_bis[i]); + } +} + +std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, + const size_t iter_ops, std::vector s) { + std::vector s_Squeeze; + auto axis_list = GetAxisList(ops, iter_ops); + size_t s_index = 0; + size_t axis_list_index = 0; + for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) { + if (i == (size_t)axis_list[axis_list_index]) { + s_Squeeze.push_back(1); + axis_list_index++; + } else { + s_Squeeze.push_back(s[s_index]); + s_index++; + } + } + + size_t cut = 1; + for (size_t i = 0; i < s_Squeeze.size(); i++) { + cut *= s_Squeeze[i]; + } + if (cut != g_device_manager->DeviceNum()) { + s_Squeeze.clear(); + } + + return s_Squeeze; +} + +std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops) { + std::vector s; + if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || + ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || + ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || + ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { + return s; + } + + bool found = false; + size_t outgoing_op_index = SIZE_MAX; + size_t iter_op_inputs = SIZE_MAX; + for (size_t i = 0; i < input_tensor_names.size(); i++) { + for (size_t j = 1; j < input_tensor_names[i].size(); j++) { + if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] && + ops[i]->selected_strategy()->GetInputNumber() != 0) { + outgoing_op_index = i; + iter_op_inputs = j - 1; + found = true; + break; + } + } + if (found) { + break; + } + } + + if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) { + for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) { + s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]); + } + } + return s; +} + +void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + std::vector no_stra_op_list_bis; + + for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { + auto iter_ops = no_stra_op_list->at(iter_list - 1); + std::vector> stra; + std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); + + if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { + s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); + } + if (s.size() != 0) { + stra = GenerateStrategiesFromStrategy(ops, 
iter_ops, s); + } else { + no_stra_op_list_bis.push_back(iter_ops); + } + + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } + + no_stra_op_list->clear(); + for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { + no_stra_op_list->push_back(no_stra_op_list_bis[i]); + } +} + +void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list) { + if (no_stra_op_list->size() == 0) { + return; + } + + size_t no_stra_op_list_size = no_stra_op_list->size(); + do { + no_stra_op_list_size = no_stra_op_list->size(); + GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); + GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); + } while (no_stra_op_list_size > no_stra_op_list->size()); + + for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { + auto iter_ops = no_stra_op_list->at(iter_list); + std::vector> stra; + std::vector s; + + size_t max_dim_num = 0; + for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) { + max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); + } + } + for (size_t i = 0; i < max_dim_num; i++) { + s.push_back(1); + } + + stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); + StrategyPtr sp = std::make_shared(0, stra); + ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); + } +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h new file mode 100644 index 0000000000..9acd05e0a9 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ +#define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, + const std::shared_ptr>> &eli_list, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list); +std::vector> PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareBiasAdd(const std::shared_ptr> &s); +std::vector> PrepareOneHot(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareGatherV2(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector> PrepareL2Normalize(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +std::vector> PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_graph, const size_t iter_ops); +void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const std::shared_ptr> &index_list); +size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, + const size_t iter_ops); +std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, + const std::vector> &ops, + const size_t iter_ops, const size_t iter_graph); +std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t incoming_op_index); +std::vector GetAxisList(const std::vector> &ops, const int iter_ops); +std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); +std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); +std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); +std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, + const size_t incoming_op_index, std::vector s); +std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t incoming_op_index); +std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, + const size_t iter_ops, + std::vector basic_stra); +void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, + const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &index_list, + const std::shared_ptr> &no_stra_op_list); +std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, + const size_t iter_ops, std::vector s); +std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, + const std::vector> &input_tensor_names, + const size_t iter_ops); +void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, + const std::vector> &input_tensor_names, + const std::shared_ptr> &no_stra_op_list); +void 
GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
+                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                       const std::vector<std::vector<std::string>> &input_tensor_names,
+                                       const std::shared_ptr<std::vector<size_t>> &index_list,
+                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
+} // namespace parallel
+} // namespace mindspore
+#endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h
new file mode 100644
index 0000000000..15b8220016
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_graph.h
@@ -0,0 +1,87 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
+#define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
+
+#include
+#include
+#include
+
+#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
+
+namespace mindspore {
+namespace parallel {
+enum OperatorType {
+  kRecUnkownType,
+  kRecMatMul,
+  kRecConvolution,
+  kRecPooling,
+  kRecElmWiseOp,
+  kRecReLU,
+  kRecBatchNorm,
+  kRecReshape,
+  kRecBiasAdd,
+  kRecSoftmax,
+  kRecSparseSoftmaxCrossEntropyWithLogits,
+  kRecSoftmaxCrossEntropyWithLogits,
+  kRecOneHot,
+  kRecLog,
+  kRecExp,
+  kRecAdd,
+  kRecSub,
+  kRecMul,
+  kRecDiv,
+  kRecSqueeze,
+  kRecCast,
+  kRecReduce,
+  kRecPReLU,
+  kRecGatherV2,
+  kRecArgWithValue
+};
+
+enum InfoType { kApplication, kConstant };
+
+struct OperatorRec {
+  OperatorType op_type;
+  TensorParam arguments[MAX_INPUT_NUM];
+  StrategyRec str;
+};
+
+// Define simplified dataflow Graph for partitioning
+class Graph {
+ public:
+  struct NodeType {
+    std::string name;
+    // Nodes that point to this node
+    std::vector<size_t> node_in;
+    // Nodes that point from this node
+    std::vector<size_t> node_out;
+    std::vector<size_t> node_in_aux;
+    // Node Type Info: Application or Constant. Defined in enum <InfoType>.
+    InfoType info;
+    // Operator info. Defined in struct <OperatorRec>.
+    OperatorRec apply;
+    // Tensor info. Defined in tensor.h struct <TensorParam>.
+    TensorParam tensor_parm;
+  };
+
+  std::vector<NodeType> nodes;  // Nodes of the graph. Public.
+};  // Define simplified dataflow Graph for partitioning
+} // namespace parallel
+} // namespace mindspore
+#endif  // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc
new file mode 100644
index 0000000000..a393c825df
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -0,0 +1,264 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +const TensorParam MakeTensor(int n, int c, int h, int w) { + TensorParam new_tensor; + new_tensor.tensor_type = kFloat32; + new_tensor.tensor_shape.shape_n = n; + new_tensor.tensor_shape.shape_c = c; + new_tensor.tensor_shape.shape_h = h; + new_tensor.tensor_shape.shape_w = w; + const TensorParam &tensor = new_tensor; + return tensor; +} + +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { + Graph::NodeType NewOp; + NewOp.name = ops[iter_ops]->name(); + NewOp.info = InfoType::kApplication; + + auto op_type = ops[iter_ops]->type(); + auto idx = DictOpType.find(op_type); + if (idx == DictOpType.end()) { + NewOp.apply.op_type = OperatorType::kRecUnkownType; + MS_LOG(INFO) << "Unknown operator type."; + } else { + NewOp.apply.op_type = DictOpType.at(op_type); + } + + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { + NewOp.tensor_parm = MakeTensor( + ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], + ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], + ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; + } + + NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); + return NewOp; +} + +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); + iter_input_tensors++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { + NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { + 
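+      // 1-D inputs occupy only the innermost (w) slot of the 4-D tensor
+      // parameter; the leading n/c/h slots are padded with 1.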
NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
+    } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
+      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
+    } else {
+      MS_LOG(ERROR) << "Tensor's shape is unknown.";
+    }
+  }
+  return NewTensor.apply;
+}
+
+TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
+                             const size_t iter_input_tensors, Graph::NodeType NewTensor) {
+  if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
+    auto attrs = ops[iter_ops]->attrs();
+    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
+    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
+    if (transpose_a && (iter_input_tensors == 0)) {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
+    } else if (transpose_b && (iter_input_tensors == 1)) {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
+    } else {
+      NewTensor.apply.arguments[iter_input_tensors] =
+        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]);
+    }
+  } else {
+    NewTensor.apply.arguments[iter_input_tensors] =
+      MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
+                 ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]);
+  }
+  return NewTensor.apply.arguments[iter_input_tensors];
+}
+
+std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                  const std::vector<std::vector<std::string>> &input_tensor_names) {
+  std::shared_ptr<Graph> graph(new Graph);
+  if (ops.size() > SIZE_MAX / 2) {
+    MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2;
+  }
+
+  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
+    Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
+    graph->nodes.push_back(NewOp);
+  }
+  MakeEdge(input_tensor_names, graph);
+
+  return graph;
+}
+
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
+  for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
+    for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
+      size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
+      if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
+        graph->nodes[iter_i].node_in.push_back(head_node_index);
+        graph->nodes[head_node_index].node_out.push_back(iter_i);
+      }
+    }
+  }
+}
+
+size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
+                                  const std::string &input_name) {
+  for (size_t index = 0; index < input_tensor_name.size(); index++) {
+    if (input_tensor_name[index][0] == input_name) {
+      return index;
+    }
+  }
+  MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead";
+  return SIZE_MAX;
+}
+
+void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
+                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
+  std::vector<size_t> eli;
+  eli.push_back(node_index);
+  for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
+    eli.push_back(graph->nodes[node_index].node_out[i]);
+  }
+  eli_list->push_back(eli);
+
+  for (size_t i = 0; i <
graph->nodes[node_index].node_in.size(); i++) { + auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; + auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); + if (it != incoming_outputs->end()) { + it = incoming_outputs->erase(it); + incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end()); + } + } + + for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { + auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; + auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); + if (it != aux_incoming_outputs->end()) { + it = aux_incoming_outputs->erase(it); + aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), + graph->nodes[node_index].node_out.end()); + } + } + + for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { + auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; + auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); + if (it != outgoing_inputs->end()) { + if (graph->nodes[node_index].node_in.size() > 0) { + outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0]; + for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); + } + for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) { + graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( + graph->nodes[node_index].node_in_aux[j]); + } + } else { + outgoing_inputs->erase(it); + } + } + } +} + +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list) { + MS_EXCEPTION_IF_NULL(graph); + for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { + auto type = graph->nodes[node_index].apply.op_type; + if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) { + Eliminate_Aux(node_index, graph, eli_list); + } + } + index_list->reserve(graph->nodes.size()); + for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { + index_list->push_back(i); + } + for (size_t i = 0; i < (size_t)eli_list->size(); i++) { + if (eli_list->at(i)[0] >= index_list->size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + index_list->at(eli_list->at(i)[0]) = SIZE_MAX; + for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { + index_list->at(j)--; + } + } + std::shared_ptr new_graph(new Graph); + for (size_t i = 0; i < graph->nodes.size(); i++) { + if (index_list->at(i) > SIZE_MAX / 2) { + continue; + } + new_graph->nodes.push_back(graph->nodes[i]); + auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; + for (size_t j = node_in->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_in->erase(node_in->begin() + j - 1); + } else { + node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); + } + } + auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; + for (size_t j = node_out->size(); j > 0; j--) { + bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); + if (IsEliminated) { + node_out->erase(node_out->begin() + j - 1); + } else { + node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); + } 
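+      // At this point entry j - 1 of node_out has either been erased (its
+      // target was eliminated) or remapped to the compacted node index.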
+ } + } + return new_graph; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h new file mode 100644 index 0000000000..4d0c02f5fe --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -0,0 +1,145 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ +#define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/ops_info/operator_info.h" + +namespace mindspore { +namespace parallel { +static const std::set ElementWiseOpType = { + OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, + OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, + OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, + OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; + +const std::map DictOpType{ + {MATMUL, OperatorType::kRecMatMul}, + {CONV2D, OperatorType::kRecConvolution}, + {MAXPOOL, OperatorType::kRecPooling}, + {MAXPOOLV2, OperatorType::kRecPooling}, + {SIMPLE_MEAN, OperatorType::kRecPooling}, + {RESHAPE, OperatorType::kRecReshape}, + {BIAS_ADD, OperatorType::kRecBiasAdd}, + {BATCH_NORM, OperatorType::kRecBatchNorm}, + {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, + {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, + {ONEHOT, OperatorType::kRecOneHot}, + {SQUEEZE, OperatorType::kRecSqueeze}, + {CAST, OperatorType::kRecCast}, + {REDUCE_SUM, OperatorType::kRecReduce}, + {REDUCE_MAX, OperatorType::kRecReduce}, + {REDUCE_MIN, OperatorType::kRecReduce}, + {REDUCE_MEAN, OperatorType::kRecReduce}, + {GATHERV2, OperatorType::kRecGatherV2}, + {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, + {ARGMINWITHVALUE, OperatorType::kRecArgWithValue}, + + {RELU, OperatorType::kRecReLU}, + {"ReLU6", OperatorType::kRecReLU}, + {"ReLUV2", OperatorType::kRecReLU}, + {SIGMOID, OperatorType::kRecReLU}, + {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, + {"HSigmoid", OperatorType::kRecReLU}, + {GELU, OperatorType::kRecReLU}, + {TANH, OperatorType::kRecReLU}, + + {PRELU, OperatorType::kRecPReLU}, + + {TRANSPOSE, OperatorType::kRecElmWiseOp}, + {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, + {TENSOR_ADD, OperatorType::kRecElmWiseOp}, + {SUB, OperatorType::kRecElmWiseOp}, + {MUL, OperatorType::kRecElmWiseOp}, + {DIV, OperatorType::kRecElmWiseOp}, + {REAL_DIV, OperatorType::kRecElmWiseOp}, + {SOFTMAX, OperatorType::kRecSoftmax}, + {LOG_SOFTMAX, OperatorType::kRecSoftmax}, + {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, 
OperatorType::kRecSoftmaxCrossEntropyWithLogits}, + {SQRT, OperatorType::kRecElmWiseOp}, + {NEG, OperatorType::kRecElmWiseOp}, + {POW, OperatorType::kRecElmWiseOp}, + {EXP, OperatorType::kRecElmWiseOp}, + {LOG, OperatorType::kRecElmWiseOp}, + {COS, OperatorType::kRecElmWiseOp}, + {ACOS, OperatorType::kRecElmWiseOp}, + {LOGICALNOT, OperatorType::kRecElmWiseOp}, + {"LogicalAnd", OperatorType::kRecElmWiseOp}, + {"LogicalOr", OperatorType::kRecElmWiseOp}, + {SQUARE, OperatorType::kRecElmWiseOp}, + {"Abs", OperatorType::kRecElmWiseOp}, + {"Acosh", OperatorType::kRecElmWiseOp}, + {"AddN", OperatorType::kRecElmWiseOp}, + {"AccumulateNV2", OperatorType::kRecElmWiseOp}, + {"Atan2", OperatorType::kRecElmWiseOp}, + {"Erf", OperatorType::kRecElmWiseOp}, + {"Floor", OperatorType::kRecElmWiseOp}, + {FLOORDIV, OperatorType::kRecElmWiseOp}, + {"FloorMod", OperatorType::kRecElmWiseOp}, + {GREATER, OperatorType::kRecElmWiseOp}, + {"GreaterEqual", OperatorType::kRecElmWiseOp}, + {"HSwish", OperatorType::kRecElmWiseOp}, + {"Less", OperatorType::kRecElmWiseOp}, + {"LessEqual", OperatorType::kRecElmWiseOp}, + {MAXIMUM, OperatorType::kRecElmWiseOp}, + {MINIMUM, OperatorType::kRecElmWiseOp}, + {EQUAL, OperatorType::kRecElmWiseOp}, + {NOT_EQUAL, OperatorType::kRecElmWiseOp}, + {"Reciprocal", OperatorType::kRecElmWiseOp}, + {"Round", OperatorType::kRecElmWiseOp}, + {"Rsqrt", OperatorType::kRecElmWiseOp}, + {"Sign", OperatorType::kRecElmWiseOp}, + {"Sin", OperatorType::kRecElmWiseOp}, + {ASSIGN, OperatorType::kRecElmWiseOp}, + {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, + {"AssignAdd", OperatorType::kRecElmWiseOp}}; + +const TensorParam MakeTensor(int n, int c, int h, int w); + +Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops); + +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, + Graph::NodeType NewTensor); + +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, + const size_t iter_input_tensor, Graph::NodeType NewTensor); + +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names); + +void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph); + +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, + const std::string &input_name); + +void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list); + +std::shared_ptr EliminateGraph(const std::shared_ptr &graph, + const std::shared_ptr>> &eli_list, + const std::shared_ptr> &index_list); +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc new file mode 100644 index 0000000000..97d230a49f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.cc @@ -0,0 +1,310 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ir/anf.h"
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+// Get the target node's weight for sorting.
+double GetWeights(const Graph::NodeType &node) {
+  const OperatorRec &op = node.apply;
+
+  if (op.op_type == OperatorType::kRecMatMul) {
+    // For MatMul
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn(op);
+  } else if (op.op_type == OperatorType::kRecConvolution) {
+    // For Convolution
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn(node);
+  } else if (op.op_type == OperatorType::kRecPooling) {
+    // For Pooling
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecElmWiseOp) {
+    // For TensorAdd
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecReLU) {
+    // For Activation
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecReshape) {
+    // For Reshape
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecBiasAdd) {
+    // For BiasAdd
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp ||
+             op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub ||
+             op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv ||
+             op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) {
+    // For element-wise op
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot ||
+             op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits ||
+             op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
+    // For BatchParallel op
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetMaxCostIn();
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For Unknown type
+    return 0.0;
+  } else {
+    MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
+  }
+}
+
+// Sort all the nodes by their weights
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+
+  std::vector<std::pair<double, size_t>> weight_to_node_index;
+  std::vector<size_t> node_index_by_weights;
+
+  // Get node's weight.
+  for (size_t i = 0; i < graph->nodes.size(); i++) {
+    if (graph->nodes[i].info == kApplication) {
+      const Graph::NodeType &node_ptr = graph->nodes[i];
+      double weight = GetWeights(node_ptr);
+      size_t index = i;
+      weight_to_node_index.push_back(std::make_pair(weight, index));
+    }
+  }
+
+  // Ordering ops aka nodes of the graph
+  std::sort(weight_to_node_index.begin(), weight_to_node_index.end());
+
+  // Store the result in node_index_by_weights.
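+  // std::sort ordered the pairs by ascending weight, so traverse them in
+  // reverse to list the heaviest (most partition-worthy) operators first.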
+  uint64_t size = weight_to_node_index.size();
+  for (uint64_t i = 1; i <= size; i++) {
+    node_index_by_weights.push_back(weight_to_node_index[size - i].second);
+  }
+
+  return node_index_by_weights;
+}
+
+// Get optimal strategy to partition the target node
+StrategyRec PartitionNode(const Graph::NodeType &node,
+                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
+                          const std::shared_ptr<Graph> &graph) {
+  bool enable_conv_chw_partition = false;
+  MS_EXCEPTION_IF_NULL(graph);
+
+  if (node.apply.op_type == OperatorType::kRecMatMul) {
+    // For MatMul
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecConvolution) {
+    // For Convolution
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition);
+  } else if (node.apply.op_type == OperatorType::kRecPooling) {
+    // For Pooling
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) {
+    // For TensorAdd
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecReLU) {
+    // For Activation
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecReshape) {
+    // For Reshape
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node);
+  } else if (node.apply.op_type == OperatorType::kRecBiasAdd) {
+    // For BiasAdd
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp ||
+             node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub ||
+             node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv ||
+             node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) {
+    // For element-wise op
+    auto cost_ptr = std::make_shared();
+
+    return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot ||
+             node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax ||
+             node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
+    // For BatchParallel type
+    auto cost_ptr = std::make_shared();
+    return cost_ptr->GetOptimalStr(node);
+  } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) {
+    // For SoftmaxCrossEntropyWithLogits type
+    auto cost_ptr = std::make_shared();
+    return cost_ptr->GetOptimalStr(node);
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+    // For Unknown type
+    StrategyRec default_strategy;
+    return default_strategy;
+  } else {
+    MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
+  }
+}
+
+// Partition graph into all devices.
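+// Example: with num_device = 8 the partitioning loop below runs
+// log2(8) = 3 times; each pass halves one tensor dimension of every
+// operator, so the resulting strategies spread the graph across all 8 devices.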
+Status PartitionForAllDevices(const size_t num_device, const double device_memory,
+                              const std::shared_ptr<Graph> &graph) {
+  if (num_device < 1) {
+    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
+  }
+
+  if (num_device > 1024) {
+    MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024.";
+  }
+
+  MS_EXCEPTION_IF_NULL(graph);
+
+  // Compute iteration times.
+  int iter_times = static_cast<int>(log2(num_device));
+
+  // N-cuts loop
+  for (int loop = 0; loop < iter_times; loop++) {
+    // Sort by weights
+    std::vector<size_t> reorder_node_list = SortByWeight(graph);
+
+    // Get total node number.
+    size_t iter_nodes = reorder_node_list.size();
+
+    // Temporary vector mapping a node's name to its strategy.
+    std::vector<std::pair<std::string, StrategyRec>> node_name_to_strategy;
+
+    // Loop for all the nodes
+    for (size_t i_node = 0; i_node < iter_nodes; i_node++) {
+      // Get current node's index.
+      size_t index = reorder_node_list[i_node];
+
+      Graph::NodeType &node_ptr = graph->nodes[index];
+
+      // Search for the optimal strategy to cut this operator, and store it in the graph.
+      graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph);
+
+      // Apply OP Strategy to Tensor Strategy.
+      graph->nodes[index] = ApplyStrToTensor(node_ptr);
+
+      // Note down the node name and its strategy in this loop.
+      auto node_name_to_str =
+        std::pair<std::string, StrategyRec>(graph->nodes[index].name, graph->nodes[index].apply.str);
+      node_name_to_strategy.push_back(node_name_to_str);
+    }
+  }
+
+  if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) {
+    return FAILED;
+  } else {
+    return SUCCESS;
+  }
+}
+
+// Apply OP Strategy to Tensor Strategy
+Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
+  // Set Node's tensor_parm
+  Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n;
+  Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c;
+  Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h;
+  Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w;
+
+  // Set input tensors' tensor_parm
+  for (int i = 0; i < 2; i++) {
+    Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n;
+    Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c;
+    Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h;
+    Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w;
+  }
+  return Node;
+}
+
+Status DevicesMemoryControl(const size_t num_device, const double device_memory,
+                            const std::shared_ptr<Graph> &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  if (num_device == 0) {
+    MS_LOG(EXCEPTION) << "Failure: device number is 0.";
+  }
+
+  uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;
+
+  for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
+    if (graph->nodes[i_node].info == 0) {
+      Graph::NodeType &Node = graph->nodes[i_node];
+      for (int index = 0; index < 2; index++) {
+        used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
+                       Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
+                       Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h *
+                       Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
+                       GetDataTypeSize(Node.apply.arguments[index].tensor_type);
+      }
+    }
+  }
+
+  if (device_memory < (used_memory / num_device)) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return
+size_t GetDataTypeSize(const TensorType &type) {
+  switch (type) {
+    case kInt8:
+      return sizeof(int8_t);
+    case kFloat16:
+      return sizeof(float) / 2;
+    case kFloat32:
+      return sizeof(float);
+    case kDouble64:
+      return sizeof(double);
+    default:
+      MS_LOG(EXCEPTION) << "GetDataTypeSize failed: unexpected type.";
+  }
+}
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h
new file mode 100644
index 0000000000..528163e4d3
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_partition.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
+#define PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/auto_parallel/rec_core/rec_cost.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h"
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph);
+
+double GetWeights(const Graph::NodeType &node);
+
+StrategyRec PartitionNode(const Graph::NodeType &node,
+                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
+                          const std::shared_ptr<Graph> &graph);
+
+Status PartitionForAllDevices(const size_t num_device, const double device_memory,
+                              const std::shared_ptr<Graph> &graph);
+
+Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
+
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
+
+size_t GetDataTypeSize(const TensorType &type);
+} // namespace parallel
+} // namespace mindspore
+
+#endif // PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_strategy.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_strategy.h
similarity index 100%
rename from mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_strategy.h
rename to mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_strategy.h
diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h
new file mode 100644
index 0000000000..315c52c867
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_tensor.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_
+#define PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_
+
+#include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h"
+
+namespace mindspore {
+namespace parallel {
+enum TensorType { kInt8, kFloat16, kFloat32, kDouble64 };
+
+struct Shape4D {
+  int32_t shape_n = 1;
+  int32_t shape_c = 1;
+  int32_t shape_h = 1;
+  int32_t shape_w = 1;
+};
+
+struct TensorParam {
+  TensorType tensor_type = kFloat32;  // defaults to float32.
+  Shape4D tensor_shape;
+  TensorStr4D tensor_str;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_
diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
new file mode 100644
index 0000000000..7164660be0
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/context.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common/utils.h"
+#include "frontend/parallel/device_manager.h"
+
+namespace mindspore {
+namespace parallel {
+static std::map<std::string, std::vector<int>> param_shapes;
+
+std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL,
+                                               AUTO_PARALLEL};
+std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING};
+
+std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr;
+
+std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
+  if (inst_context_ == nullptr) {
+    inst_context_.reset(new (std::nothrow) ParallelContext());
+  }
+  return inst_context_;
+}
+
+ParallelContext::ParallelContext() { Reset(); }
+
+void ParallelContext::Reset() {
+  mirror_mean_ = false;
+  full_batch_ = false;
+  cast_before_mirror_ = true;
+  loss_repeated_mean_ = true;
+  device_num_ = 1;
+  global_rank_ = 0;
+  communication_backend_ = HCCL_BACKEND;
+  device_num_is_set_ = false;
+  global_rank_is_set_ = false;
+  parallel_mode_ = STAND_ALONE;
+  parameter_broadcast_ = false;
+  parameter_broadcast_is_set_ = false;
+  enable_all_reduce_fusion_ = false;
+  strategy_ckpt_load_file_ = "";
+  strategy_ckpt_save_file_ = "";
+  enable_parallel_optimizer_ = false;
+}
+
+void ParallelContext::set_device_num(int32_t device_num) {
+  device_num_ = device_num;
+  device_num_is_set_ = true;
+}
+
+void ParallelContext::set_global_rank(int32_t global_rank) {
+  global_rank_ = global_rank;
+  global_rank_is_set_ = true;
+}
+
+void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; }
+
+void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
+
+void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; }
+
+void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }
+
+void ParallelContext::set_communication_backend(const std::string &communication_backend) {
+  communication_backend_ = communication_backend;
+}
+
+bool ParallelContext::set_parallel_mode(const std::string &parallel_mode) {
+  auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode);
+  if (iter == PARALLEL_MODE_LIST.end()) {
+    MS_LOG(INFO) << "Invalid parallel mode: " << parallel_mode;
+    return false;
+  }
+  parallel_mode_ = parallel_mode;
+  return true;
+}
+
+bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) {
+  auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode);
+  if (iter == STRATEGY_SEARCH_MODE_LIST.end()) {
+    MS_LOG(INFO) << "Invalid strategy search mode: " << strategy_search_mode;
+    return false;
+  }
+  strategy_search_mode_ = strategy_search_mode;
+  return true;
+}
+
+void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) {
+  parameter_broadcast_ = parameter_broadcast;
+  parameter_broadcast_is_set_ = true;
+}
+
+void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) {
+  strategy_ckpt_load_file_ = strategy_ckpt_load_file;
+}
+
+void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) {
+  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
+}
+
+void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
+  all_reduce_fusion_split_indices_[group] = indices;
+}
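+// Usage sketch (editor's note; the split points are illustrative, not from
+// this patch): cut the gradient all-reduce fusion after parameters 20 and 40
+// in the HCCL world group, then read the setting back:
+//   auto ctx = ParallelContext::GetInstance();
+//   ctx->SetAllReduceFusionSplitIndices({20, 40}, HCCL_WORLD_GROUP);
+//   auto indices = ctx->GetAllReduceFusionSplitIndices(HCCL_WORLD_GROUP);  // {20, 40}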
+const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
+  auto iter = all_reduce_fusion_split_indices_.find(group);
+  if (iter != all_reduce_fusion_split_indices_.end()) {
+    return iter->second;
+  }
+  return {};
+}
+
+void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
+  all_reduce_fusion_split_sizes_[group] = sizes;
+}
+
+const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
+  auto iter = all_reduce_fusion_split_sizes_.find(group);
+  if (iter != all_reduce_fusion_split_sizes_.end()) {
+    return iter->second;
+  }
+  return {};
+}
+
+// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextInit(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
+    return;
+  }
+  param_shapes.clear();
+}
+
+// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                                 AbstractBasePtr ptr) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(param_node);
+  MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) ||
+      func_graph->has_flag(TRAINING)) {
+    return;
+  }
+
+  auto iter = param_shapes.find(param_node->name());
+  if (iter == param_shapes.end()) {
+    MS_LOG(WARNING) << "Cannot find the shape for parameter " << param_node->name();
+    return;
+  }
+  std::vector<int> shape = iter->second;
+  std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
+  ptr->set_shape(base_shape);
+  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
+}
+
+// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode
+void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                            const AbstractBasePtr &ptr) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(param_node);
+  MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) {
+    return;
+  }
+
+  std::vector<int> shape = dyn_cast<abstract::Shape>(ptr->GetShapeTrack())->shape();
+  auto ret = param_shapes.try_emplace(param_node->name(), shape);
+  if (!ret.second) {
+    MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " already exists";
+    return;
+  }
+
+  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
+}
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
new file mode 100644
index 0000000000..1bb40d5c29
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_
+#define MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/status.h"
+#include "utils/convert_utils.h"
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "debug/info.h"
+#include "abstract/abstract_value.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr char STAND_ALONE[] = "stand_alone";
+constexpr char DATA_PARALLEL[] = "data_parallel";
+constexpr char HYBRID_PARALLEL[] = "hybrid_parallel";
+constexpr char AUTO_PARALLEL[] = "auto_parallel";
+constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel";
+
+constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming";
+constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming";
+
+constexpr char TRAINING[] = "training";
+
+class ParallelContext {
+ public:
+  ~ParallelContext() = default;
+  ParallelContext(const ParallelContext &) = delete;
+  ParallelContext &operator=(const ParallelContext &) = delete;
+
+  static std::shared_ptr<ParallelContext> GetInstance();
+
+  void set_mirror_mean(bool mirror_mean);
+  bool mirror_mean() const { return mirror_mean_; }
+
+  void set_full_batch(bool full_batch);
+  bool full_batch() const { return full_batch_; }
+
+  void set_cast_before_mirror(bool cast_before_mirror);
+  bool cast_before_mirror() const { return cast_before_mirror_; }
+
+  void set_loss_repeated_mean(bool loss_repeated_mean);
+  bool loss_repeated_mean() const { return loss_repeated_mean_; }
+
+  void set_device_num(int32_t device_num);
+  int32_t device_num() const { return device_num_; }
+
+  void set_global_rank(int32_t global_rank);
+  int32_t global_rank() const { return global_rank_; }
+
+  void set_communication_backend(const std::string &communication_backend);
+  std::string communication_backend() const { return communication_backend_; }
+
+  bool set_parallel_mode(const std::string &parallel_mode);
+  std::string parallel_mode() const { return parallel_mode_; }
+
+  bool set_strategy_search_mode(const std::string &strategy_search_mode);
+  std::string strategy_search_mode() const { return strategy_search_mode_; }
+
+  void set_parameter_broadcast(bool parameter_broadcast);
+  bool parameter_broadcast() const { return parameter_broadcast_; }
+
+  bool device_num_is_set() const { return device_num_is_set_; }
+  bool global_rank_is_set() const { return global_rank_is_set_; }
+  bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
+
+  void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
+  const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
+  void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
+  const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
+  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
+    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
+  }
+  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
+
+  void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file);
+  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
+  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
+  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+
+  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
+    enable_parallel_optimizer_ = enable_parallel_optimizer;
+  }
+  bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; }
+
+  void Reset();
+
+ private:
+  ParallelContext();
+  static std::shared_ptr<ParallelContext> inst_context_;
+  bool mirror_mean_;
+  bool full_batch_;
+  bool cast_before_mirror_;
+  bool loss_repeated_mean_;
+  int32_t device_num_;
+  int32_t global_rank_;
+  std::string communication_backend_;
+  std::string parallel_mode_;
+  std::string strategy_search_mode_;
+  bool parameter_broadcast_;
+  bool device_num_is_set_;
+  bool global_rank_is_set_;
+  bool parameter_broadcast_is_set_;
+  bool enable_all_reduce_fusion_;
+  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
+  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
+  std::string strategy_ckpt_load_file_;
+  std::string strategy_ckpt_save_file_;
+  bool enable_parallel_optimizer_;
+};
+
+void ParallelParameterContextInit(const FuncGraphPtr &func_graph);
+void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                                 AbstractBasePtr ptr);
+void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr &param_node,
+                                            const AbstractBasePtr &ptr);
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_
diff --git a/mindspore/ccsrc/frontend/parallel/costmodel_context.cc b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
new file mode 100644
index 0000000000..67d087eabd
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/costmodel_context.cc
@@ -0,0 +1,132 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/costmodel_context.h"
+
+#include
+
+#include "frontend/parallel/allreduce_fusion/allreduce_fusion.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+
+namespace mindspore {
+namespace parallel {
+std::shared_ptr<CostModelContext> CostModelContext::cm_context_inst_ = nullptr;
+
+std::shared_ptr<CostModelContext> CostModelContext::GetInstance() {
+  if (cm_context_inst_ == nullptr) {
+    MS_LOG(INFO) << "Create costmodel_context";
+    cm_context_inst_.reset(new (std::nothrow) CostModelContext());
+  }
+  return cm_context_inst_;
+}
+
+CostModelContext::CostModelContext() {
+  ResetCostModel();
+  ResetAlgoParameters();
+}
+
+void CostModelContext::ResetCostModel() {
+  device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY;
+  costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA;
+  costmodel_beta_ = DEFAULT_COST_MODEL_BETA;
+  costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA;
+  costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
+  costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
+  costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
+  is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
+  run_phase_ = DEFAULT_RUN_PHASE;
+  costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
+  costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
+  costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
+  costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME;
+  costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME;
+  costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH;
+  costmodel_allreduce_fusion_computation_time_parameter_ =
+    DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER;
+}
+
+void CostModelContext::ResetAlgoParameters() {
+  costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
+  tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
+  tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
+  fully_use_device_ = DEFAULT_FULLY_USE_DEVICES;
+  elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
+}
+
+void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; }
+
+void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; }
+
+void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; }
+
+void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; }
+
+void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; }
+
+void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) {
+  costmodel_communi_threshold_ = cm_communi_th;
+}
+
+void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
+  costmodel_communi_const_ = cm_communi_const;
+}
+
+void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
+
+void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
+
+void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) {
+  costmodel_allreduce_fusion_algorithm_ = algorithm;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_times(int32_t allreduce_fusion_times) {
+  costmodel_allreduce_fusion_times_ = allreduce_fusion_times;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) {
+  costmodel_allreduce_fusion_tail_percent_ = tail_percent;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) {
+  costmodel_allreduce_fusion_tail_time_ = tail_time;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) {
+  costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) {
+  costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth;
+}
+
+void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) {
+  costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter;
+}
+
+void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; }
+
+void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) {
+  tensor_slice_alignment_size_ = ts_align_size;
+}
+
+void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; }
+
+void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) {
+  elementwise_stra_follow_ = elementwise_follow;
+}
+
+void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; }
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/frontend/parallel/costmodel_context.h
similarity index 100%
rename from mindspore/ccsrc/parallel/costmodel_context.h
rename to mindspore/ccsrc/frontend/parallel/costmodel_context.h
diff --git a/mindspore/ccsrc/frontend/parallel/device.h b/mindspore/ccsrc/frontend/parallel/device.h
new file mode 100644
index 0000000000..c9633623d2
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/device.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_H_
+#define MINDSPORE_CCSRC_PARALLEL_DEVICE_H_
+
+#include
+#include
+#include
+
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+class Device {
+  // This class abstracts the 'device' information used in the Parallel module.
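+  // For illustration (editor's note; the name and rank are assumed values):
+  // Device("worker0", 3) names rank 3, while Device(3) carries only the rank
+  // and leaves the name empty.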
+ public:
+  Device() : rank_(0) { name_.clear(); }
+  explicit Device(int32_t rank) : rank_(rank) { name_.clear(); }
+  Device(std::string name, int32_t rank) : name_(std::move(name)), rank_(rank) {}
+  ~Device() = default;
+  std::string name() const { return name_; }
+  int32_t rank() const { return rank_; }
+
+ private:
+  std::string name_;
+  int32_t rank_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_H_
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.cc b/mindspore/ccsrc/frontend/parallel/device_manager.cc
new file mode 100644
index 0000000000..d3657afdb8
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.cc
@@ -0,0 +1,374 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/device_manager.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/step_parallel.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+DeviceManagerPtr g_device_manager = nullptr;
+
+Stage::Stage(const std::vector<Device> &devices, int num, int rank)
+    : devices_(devices), number_(num), rank_(rank) {
+  gm_ = GroupManager();
+}
+
+// NOTE: '-1' indicates an error.
+int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); }
+
+bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) {
+  if (device_num <= 0) {
+    MS_LOG(ERROR) << "'device_num' must be positive.";
+    return false;
+  }
+  if (global_rank < 0) {
+    MS_LOG(ERROR) << "'global_rank' must be nonnegative.";
+    return false;
+  }
+  if (device_num > MAX_DEVICE_NUM) {
+    MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << ".";
+    return false;
+  }
+  // 'device_num' must be a power of 2
+  if ((IntToUint(device_num) & IntToUint(device_num - 1)) != 0) {
+    MS_LOG(ERROR) << "'device_num' must be a power of 2.";
+    return false;
+  }
+  if (global_rank >= device_num) {
+    MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'.";
+    return false;
+  }
+  if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
+    MS_LOG(ERROR) << "Invalid backend: " << backend;
+    return false;
+  }
+
+  RankList devices, stage_map;
+  for (int i = 0; i < device_num; ++i) {
+    devices.push_back(i);
+  }
+
+  stage_map.push_back(device_num);
+  g_device_manager = std::make_shared<DeviceManager>();
+  if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) {
+    MS_LOG(INFO) << "Device initialization succeeds.";
+    return true;
+  } else {
+    MS_LOG(ERROR) << "Device initialization fails.";
+    return false;
+  }
+}
+
+void CheckGlobalDeviceManager() {
+  if (g_device_manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Device information has not been set!";
+  }
+}
+
+int32_t GetListMemberByIndex(size_t index, const RankList &devices) {
+  size_t i = 0;
+  int32_t result = 0;
+  if ((devices.empty()) || (index >= devices.size())) {
+    MS_LOG(EXCEPTION) << "Index is out of the list scope";
+  }
+  auto it = devices.begin();
+  for (; it != devices.end(); ++it) {
+    if (i == index) {
+      result = *it;
+      break;
+    }
+    ++i;
+  }
+  return result;
+}
+
+std::shared_ptr<Device> GetListMemberByIndex(size_t index, const std::vector<std::shared_ptr<Device>> &device_list) {
+  size_t i = 0;
+  std::shared_ptr<Device> result;
+  if ((device_list.empty()) || (index >= device_list.size())) {
+    MS_LOG(EXCEPTION) << "Index is out of the list scope";
+  }
+  auto it = device_list.begin();
+  for (; it != device_list.end(); ++it) {
+    if (i == index) {
+      result = *it;
+      break;
+    }
+    ++i;
+  }
+  return result;
+}
+
+// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3],
+// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]].
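+// A minimal initialization sketch for the example above (editor's note; the
+// global rank and backend are assumed for illustration):
+//   RankList devices = {4, 5, 2, 1, 7, 8, 10};
+//   RankList stage_map = {4, 3};
+//   DeviceManager dm;
+//   (void)dm.Init(devices, /* global_device_rank = */ 5, stage_map, HCCL_BACKEND);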
+Status DeviceManager::Init(const RankList &devices, int32_t global_device_rank, const RankList &stage_map,
+                           const std::string &backend) {
+  auto dev_it = devices.begin();
+  auto stage_it = stage_map.begin();
+  int32_t sum = 0;
+
+  if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) {
+    MS_LOG(ERROR) << "Invalid backend: " << backend;
+    return Status::FAILED;
+  }
+
+  for (; stage_it != stage_map.end(); ++stage_it) {
+    sum += (*stage_it);
+  }
+  if (IntToSize(sum) != devices.size()) {
+    MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the total "
+                  << "device count given by 'stage_map'";
+    return Status::FAILED;
+  }
+
+  for (; dev_it != devices.end(); ++dev_it) {
+    std::shared_ptr<Device> one = std::make_shared<Device>(*dev_it);
+    devices_.push_back(one);
+  }
+
+  size_t global_index = 0;
+  for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
+    int num_device = *stage_it;
+    if (num_device > MAX_DEVICE_NUM) {
+      MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
+      return Status::FAILED;
+    }
+    if (num_device <= 0) {
+      MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
+      return Status::FAILED;
+    }
+    RankList curr_dev_list;
+    for (int i = 0; i < num_device; ++i) {
+      curr_dev_list.push_back(GetListMemberByIndex(global_index, devices));
+      global_index++;
+    }
+    stage_devices_.push_back(curr_dev_list);
+  }
+
+  global_index = 0;
+  for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) {
+    int num_device = *stage_it;
+    if (num_device > MAX_DEVICE_NUM) {
+      MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM;
+      return Status::FAILED;
+    }
+    if (num_device <= 0) {
+      MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive";
+      return Status::FAILED;
+    }
+    std::vector<Device> curr_dev_list;
+    for (int i = 0; i < num_device; ++i) {
+      curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_));
+      global_index++;
+    }
+    std::shared_ptr<Stage> new_stage = std::make_shared<Stage>(curr_dev_list);
+    stages_.push_back(new_stage);
+  }
+
+  std::shared_ptr<Device> dev = std::make_shared<Device>(global_device_rank);
+  device_ = dev;
+  set_global_rank(global_device_rank);
+  backend_ = backend;
+
+  if (backend == HCCL_BACKEND) {
+    gm_.set_world_group(HCCL_WORLD_GROUP);
+  } else if (backend_ == NCCL_BACKEND) {
+    gm_.set_world_group(NCCL_WORLD_GROUP);
+  } else {
+    gm_.set_world_group(UNDEFINED_WORLD_GROUP);
+  }
+  MS_LOG(INFO) << "The device num: " << devices.size() << ", the rank id: " << global_device_rank
+               << ", the backend: " << backend;
+  return Status::SUCCESS;
+}
+
+std::shared_ptr<Stage> DeviceManager::GetStageById(int32_t stage_id) {
+  std::shared_ptr<Stage> res;
+  if (IntToSize(stage_id) >= stages_.size()) {
+    MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size();
+    return res;
+  }
+  int32_t index = 0;
+  for (auto &stage : stages_) {
+    if (index == stage_id) return stage;
+    index++;
+  }
+  return res;
+}
+
+RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const {
+  if (IntToSize(stage_id) >= stage_devices_.size())
+    MS_LOG(ERROR) << "the 'stage_id': " << stage_id
+                  << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
+  RankList res;
+  int32_t index = 0;
+  for (auto &stage : stage_devices_) {
+    if (index == stage_id) {
+      return stage;
+    }
+    index++;
+  }
+  return res;
+}
+RankList DeviceManager::global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const {
+  RankList res;
+  if (split_num <= 0) {
+    return res;
+  }
+  if (IntToSize(stage_id) >= stage_devices_.size()) {
+    MS_LOG(ERROR) << "the 'stage_id': " << stage_id
+                  << ", is out of the scope of 'stage_devices_': " << stage_devices_.size();
+    return res;
+  }
+
+  RankList global_list = GetDeviceListByStageId(stage_id);
+  if (global_list.size() % IntToSize(split_num)) {
+    MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") is not divisible by 'split_num': " << split_num;
+    return res;
+  }
+
+  std::vector<int32_t> dev_list;
+  (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list));
+
+  size_t index = 0;
+  size_t slice_size = dev_list.size() / IntToSize(split_num);
+  for (int32_t i = 0; i < split_num; ++i) {
+    bool found = false;
+    index = slice_size * IntToSize(i);
+    for (size_t j = 0; j < slice_size; ++j) {
+      if (dev_list[index + j] == rank) {
+        found = true;
+        break;
+      }
+    }
+
+    if (found) {
+      break;
+    }
+  }
+
+  for (size_t k = 0; k < slice_size; ++k) {
+    res.push_back(dev_list[index + k]);
+  }
+  return res;
+}
+
+Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device(rank); }
+
+std::vector<Device> DeviceManager::CreateDeviceListByRankList(RankList ranks) {
+  std::vector<Device> dev_list;
+  for (auto &rank : ranks) {
+    Device one = CreateNewDeviceByRank(rank);
+    dev_list.push_back(one);
+  }
+  return dev_list;
+}
+
+DeviceManager &DeviceManager::GetInstance() {
+  static DeviceManager instance = DeviceManager();
+  return instance;
+}
+
+std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) {
+  std::string tmp = "WORLD_GROUP";
+  if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) {
+    return tmp;
+  }
+  auto iter = group_to_rank_.find(hash_name);
+  if (iter == group_to_rank_.end()) {
+    MS_LOG(WARNING) << "Cannot find the rank list name by hash name: " << hash_name;
+    return tmp;
+  }
+  return iter->second;
+}
+
+std::string HashName(const std::string &origin_name) { return std::to_string(std::hash<std::string>{}(origin_name)); }
+
+// Group name is generated using the increasing ranks of the devices.
+// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name
+// is '0-1-3-5-7'.
+std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) {
+  std::string rank_list_name;
+  std::vector<int32_t>::iterator it;
+  std::sort(ranks.begin(), ranks.end());  // sorted in increasing order
+  for (it = ranks.begin(); it != ranks.end(); ++it) {
+    if (it == ranks.begin()) {
+      rank_list_name = std::to_string(*it);
+    } else {
+      rank_list_name += "-" + std::to_string(*it);
+    }
+  }
+
+  // hash the rank-list name and add the ranks' size as a prefix
+  std::string group_hash_name = HashName(rank_list_name);
+  std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;
+
+  if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) {
+    if (group_to_rank_.find(group_name) == group_to_rank_.end()) {
+      rank_to_group_[rank_list_name] = group_name;
+      group_to_rank_[group_name] = rank_list_name;
+      MS_LOG(INFO) << "The rank list name is " << rank_list_name << ", and the group name is " << group_name;
+    } else {
+      MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name
+                        << ", the old rank list: " << group_to_rank_.find(group_name)->second
+                        << ", the group name: " << group_name;
+    }
+  }
+  return group_name;
+}
+
+// Create the group with the given devices and the given name. The GroupManager
+// gm_ will create a new group only if a group with the same name does not
+// already exist; otherwise, the pointer g is set to point at that group.
+Group DeviceManager::CreateGroup(const std::string &group_name,
+                                 const std::vector<Device> &devices) {
+  if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) {
+    MS_LOG(EXCEPTION) << "Sub-groups are not supported for NCCL";
+  }
+  Group g;
+  (void)gm_.CreateGroup(group_name, devices, &g);
+  return g;
+}
+
+// Create the group with only the given devices' ranks.
+Group DeviceManager::CreateGroup(const RankList &dev_ranks) {
+  std::unordered_set<int32_t> rank_set(dev_ranks.begin(), dev_ranks.end());
+  if (dev_ranks.size() != rank_set.size()) {
+    MS_LOG(EXCEPTION) << "Invalid dev ranks(" << ListToString(dev_ranks) << "), the list contains duplicate elements";
+  }
+
+  std::string group_name = GenerateGroupNameByRanks(dev_ranks);
+  auto dev_list = CreateDeviceListByRankList(dev_ranks);
+  return CreateGroup(group_name, dev_list);
+}
+
+void DeviceManager::Clear() {
+  devices_.clear();
+  stage_devices_.clear();
+  gm_.Clear();
+}
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h
new file mode 100644
index 0000000000..654acd9dff
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.h
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_
+#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "common/utils.h"
+#include "frontend/parallel/device.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/group_manager.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/strategy.h"
+#include "utils/convert_utils.h"
+
+namespace mindspore {
+namespace parallel {
+#define MAX_DEVICE_NUM 1024
+
+constexpr char HCCL_BACKEND[] = "hccl";
+constexpr char NCCL_BACKEND[] = "nccl";
+constexpr char UNDEFINED_BACKEND[] = "undefined_backend";
+
+class DeviceManager;
+using DeviceManagerPtr = std::shared_ptr<DeviceManager>;
+// 'g_device_manager' is the globally unique manager to manage the devices.
+extern DeviceManagerPtr g_device_manager;
+
+class Stage {
+  // This class is used in pipeline parallelism. Available devices are partitioned into multiple stages.
+  // Currently, pipeline parallelism and this class are NOT implemented.
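+  // For example (editor's sketch; the device count and stage map are assumed):
+  // with eight devices and stage_map {4, 4}, stage 0 would hold ranks 0-3 and
+  // stage 1 would hold ranks 4-7.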
+ public:
+  explicit Stage(std::vector<Device> devices) : devices_(std::move(devices)), number_(0), rank_(0) {
+    gm_ = GroupManager();
+  }
+  Stage(const std::vector<Device> &devices, int num, int rank);
+  ~Stage() = default;
+
+  int GetStageNum() const { return number_; }
+  size_t GetDevicesNum() const { return devices_.size(); }
+  std::vector<Device> GetDevicesList() { return devices_; }
+  int global_rank(Group *g) const;
+
+ private:
+  std::vector<Device> devices_;
+  int number_;
+  int32_t rank_;
+  GroupManager gm_;
+};
+
+// This method is used for initializing the global DeviceManager 'g_device_manager';
+// its arguments include 'device_num' and 'global_rank'.
+bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend);
+
+void CheckGlobalDeviceManager();
+
+std::string HashName(const std::string &rank_list_name);
+
+class DeviceManager {
+  // This class is used to manage the abstract devices, including group-related and stage-related management.
+ public:
+  DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); }
+  ~DeviceManager() = default;
+
+  Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend);
+
+  static DeviceManager &GetInstance();
+  RankList GetDeviceListByStageId(int32_t stage_id) const;
+  RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const;
+
+  Device CreateNewDeviceByRank(int32_t rank) const;
+  std::vector<Device> CreateDeviceListByRankList(RankList ranks);
+
+  std::string GenerateGroupNameByRanks(RankList dev_ranks);
+  Group CreateGroup(const std::string &group_name, const std::vector<Device> &devices);
+  Group CreateGroup(const RankList &dev_ranks);
+  std::shared_ptr<Stage> GetStageById(int32_t stage_id);
+
+  size_t DeviceNum() const { return devices_.size(); }
+
+  int32_t GetStageNum() const { return static_cast<int32_t>(stage_devices_.size()); }
+
+  int32_t global_rank() const { return global_rank_; }
+  std::string backend() const { return backend_; }
+  void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; }
+  void Clear();
+  std::string world_group() const { return gm_.world_group(); }
+  std::string FindRankListNameByHashName(const std::string &hash_name);
+
+ private:
+  std::vector<std::shared_ptr<Device>> devices_;
+  // each stage has a list of devices
+  std::vector<std::vector<int32_t>> stage_devices_;
+  std::shared_ptr<Device> device_;
+  std::vector<std::shared_ptr<Stage>> stages_;
+  GroupManager gm_;
+  std::string backend_;
+
+  // bimap between rank-list names and group hash names:
+  std::map<std::string, std::string> rank_to_group_;  // the key is the rank list, the value is the hash name
+  std::map<std::string, std::string> group_to_rank_;  // the key is the hash name, the value is the rank list
+
+  int32_t local_rank_;
+  int32_t global_rank_;
+  int32_t stage_num_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.cc b/mindspore/ccsrc/frontend/parallel/device_matrix.cc
new file mode 100644
index 0000000000..9cc85d9701
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/device_matrix.cc
@@ -0,0 +1,170 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/device_matrix.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/status.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+DeviceMatrix::DeviceMatrix(int32_t rank, RankList dev_list, Shape dev_shape)
+    : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) {
+  if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int32_t a) { return a == rank; })) {
+    MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!";
+  }
+  int32_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies<int32_t>());
+  if (IntToSize(total) != dev_list_.size()) {
+    MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!";
+  }
+}
+
+Status DeviceMatrix::CreateGroupList() {
+  size_t size = dev_shape_.size();
+  RankList group;
+  for (size_t i = 0; i < size; i++) {
+    Status status = GetDevicesAlongDim(SizeToUint(i), &group);
+    group_list_.push_back(group);
+    if (status == Status::FAILED) {
+      return Status::FAILED;
+    }
+  }
+  return Status::SUCCESS;
+}
+
+Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) {
+  if (dim >= dev_shape_.size()) {
+    MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!";
+  }
+  if (dev_shape_[dim] == 1) {
+    *devices = {rank_};
+    return Status::SUCCESS;
+  }
+
+  RankList group;
+  std::vector<RankList> local_group_list;
+
+  // lower than dim
+  int32_t step = 1;
+  for (uint32_t i = dim + 1; i < dev_shape_.size(); i++) {
+    step = step * dev_shape_[i];
+  }
+  int32_t num = *dev_list_.begin();
+  for (int32_t i = 0; i < dev_shape_[dim]; i++) {
+    group.push_back(num);
+    num += step;
+  }
+
+  for (int32_t i = 0; i < step; i++) {
+    local_group_list.push_back(group);
+    (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; });
+  }
+
+  // higher than dim
+  step = step * dev_shape_[dim];
+  int32_t len = SizeToInt(dev_list_.size()) / step;
+
+  // search rank
+  int32_t target = rank_;
+  for (int32_t i = 0; i < len; i++) {
+    for (RankList &temp : local_group_list) {
+      if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) {
+        *devices = temp;
+        return Status::SUCCESS;
+      }
+      (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; });
+    }
+  }
+  MS_LOG(ERROR) << "Can't find groups for rank " << rank_ << " in device list!";
+  return Status::FAILED;
+}
+
+Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) {
+  Shape dev_coordinate;
+  for (size_t i = 0; i < dev_shape.size(); ++i) {
+    int32_t size = dev_shape[dev_shape.size() - i - 1];
+    if (size == 0) {
+      MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape);
+    } else {
+      int32_t index = rank % size;
+      (void)dev_coordinate.insert(dev_coordinate.begin(), index);
+      rank = rank / size;
+    }
+  }
+  return dev_coordinate;
+}
+
+Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) {
+  for (auto &element : tensor_map) {
+    // -1 means the corresponding dimension is not split.
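+    // E.g. with dev_shape_ = [2, 4], a tensor_map of {1, -1} splits the
+    // tensor's first dimension across the device-matrix dimension of size 2
+    // and leaves its second dimension unsplit (editor's illustration; the
+    // shapes are assumed).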
+    if (element == MAP_NONE) {
+      continue;
+    } else if ((element < 0) || (IntToSize(element) >= dev_shape_.size())) {
+      MS_LOG(ERROR) << "Create group by tensor map: the tensor map is invalid";
+      return FAILED;
+    }
+  }
+
+  Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_);
+  for (auto &tmp_rank : dev_list_) {
+    Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_);
+    bool matched = true;
+    for (auto &map : tensor_map) {
+      if (map == MAP_NONE) {
+        continue;
+      }
+      size_t index = dev_shape_.size() - IntToSize(map) - 1;
+      if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) {
+        matched = false;
+        break;
+      }
+    }
+    if (matched) {
+      rank_list->push_back(tmp_rank);
+    }
+  }
+
+  return SUCCESS;
+}
+
+std::string ShapeToString(const Shape &shape) {
+  std::string str = "[";
+  for (size_t i = 0; i < shape.size(); ++i) {
+    str += std::to_string(shape[i]);
+    if (i < shape.size() - 1) {
+      str += ", ";
+    }
+  }
+  return str + "]";
+}
+
+std::string ListToString(const std::vector<int32_t> &list) {
+  std::string str = "[";
+  for (auto &element : list) {
+    str += std::to_string(element) + ", ";
+  }
+  return str + "]";
+}
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/device_matrix.h b/mindspore/ccsrc/frontend/parallel/device_matrix.h
new file mode 100644
index 0000000000..f1e7acec39
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/device_matrix.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
+#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
+
+#include
+#include
+#include
+
+#include "frontend/parallel/status.h"
+#include "utils/convert_utils.h"
+
+namespace mindspore {
+namespace parallel {
+using RankList = std::vector<int32_t>;
+using Shape = std::vector<int32_t>;
+
+class DeviceMatrix {
+ public:
+  DeviceMatrix(int32_t rank, RankList devices, Shape dev_shape);
+  DeviceMatrix() = default;
+  ~DeviceMatrix() = default;
+  std::vector<RankList> group_list() const { return group_list_; }
+  Status CreateGroupList();
+  Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list);
+  Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices);
+
+ private:
+  int32_t rank_ = -1;
+  RankList dev_list_;
+  // From low dim to high dim, e.g. [D0 D1 D2 D3]
+  Shape dev_shape_;
+  std::vector<RankList> group_list_;
+};
+
+std::string ShapeToString(const Shape &shape);
+std::string ListToString(const std::vector<int32_t> &list);
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
new file mode 100644
index 0000000000..3ba40fade9
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_
+#define MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_
+
+#include
+#include
+#include
+#include
+
+#include "frontend/parallel/ops_info/ops_info_head_files.h"
+#include "frontend/parallel/step_parallel.h"
+
+namespace mindspore {
+namespace parallel {
+#define REGISTER(className)                                                                                   \
+  OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \
+    return std::make_shared<className>(name, in, out, attrs);                                                \
+  }                                                                                                          \
+  RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
+
+typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
+                                   const PrimitiveAttrs &attrs);
+
+class DynCreator {
+ public:
+  ~DynCreator() = default;
+
+  // create the static singleton DynCreator instance
+  static DynCreator &Instance() {
+    static DynCreator fac = DynCreator();
+    return fac;
+  }
+  // register
+  void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
+  // creator
+  OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
+                        const PrimitiveAttrs &attrs, size_t count) {
+    std::string op_name = name + std::to_string(count);
+    auto iter = Function_map_.find(name);
+    if (iter == Function_map_.end()) {
+      MS_LOG(INFO) << name << " is not registered yet";
+      return nullptr;
+    }
+    return iter->second(op_name, shape_in, shape_out, attrs);
+  }
+
+ private:
+  DynCreator() = default;
+  std::map<std::string, CreatFn> Function_map_;
+};
+
+class RegisterAction {
+ public:
+  RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
+    DynCreator::Instance().Regist(name, creatfn);
+  }
+  ~RegisterAction() = default;
+
+ private:
+  std::string name_;
+};
+
+// operator register
+REGISTER(MatMulInfo);
+REGISTER(GeluInfo);
+REGISTER(VirtualDatasetInfo);
+REGISTER(BatchParallelInfo);
+REGISTER(TanhInfo);
+REGISTER(SoftmaxInfo);
+REGISTER(LogSoftmaxInfo);
+REGISTER(ActivationInfo);
+REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
+REGISTER(SubInfo);
+REGISTER(TensorAddInfo);
+REGISTER(BiasAddInfo);
+REGISTER(MulInfo);
+REGISTER(DivInfo);
+REGISTER(RealDivInfo);
+REGISTER(PowInfo);
+REGISTER(ExpInfo);
+REGISTER(OneHotInfo);
+REGISTER(EqualInfo);
+REGISTER(NotEqualInfo);
+REGISTER(LogInfo);
+REGISTER(CosInfo);
+REGISTER(ACosInfo);
+REGISTER(LogicalNotInfo);
+REGISTER(L2NormalizeInfo); +REGISTER(LayerNormInfo); +REGISTER(ReduceMaxInfo); +REGISTER(ArgMaxWithValueInfo); +REGISTER(ArgMinWithValueInfo); +REGISTER(ReduceMeanInfo); +REGISTER(ReduceSumInfo); +REGISTER(ReduceMinInfo); +REGISTER(TransposeInfo); +REGISTER(PReLUInfo); +REGISTER(DropoutDoMaskInfo); +REGISTER(ReshapeInfo); +REGISTER(FloorDivInfo); +REGISTER(MaximumInfo); +REGISTER(MinimumInfo); +REGISTER(CastInfo); +REGISTER(GreaterInfo); +REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); +REGISTER(AssignSubInfo); +REGISTER(ReLUInfo); +REGISTER(GatherV2Info); +REGISTER(SparseGatherV2Info); +REGISTER(SqrtInfo); +REGISTER(SigmoidInfo); +REGISTER(GetNextInfo); +REGISTER(NegInfo); +REGISTER(BatchMatMulInfo); +REGISTER(ExpandDimsInfo); +REGISTER(SqueezeInfo); +REGISTER(SigmoidCrossEntropyWithLogitsInfo); +REGISTER(SquareInfo); +REGISTER(GatherV2PInfo); +REGISTER(EmbeddingLookupInfo); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc new file mode 100644 index 0000000000..30c25e5f26 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/parallel/graph_util/generate_graph.h"
+
+#include
+#include
+#include
+#include
+
+using mindspore::tensor::Tensor;
+
+namespace mindspore {
+namespace parallel {
+std::string GetOpPythonPath(const OperatorName &op_name) {
+  // almost all ops are defined in two main paths
+  const std::string ops_module = OP_PATH;
+  const std::string inner_ops_module = INNER_OP_PATH;
+  py::module mod = py::module::import(common::SafeCStr(ops_module));
+  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
+  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
+    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
+      MS_LOG(EXCEPTION) << ops_module << " and " << inner_ops_module << " do not have op: " << op_name;
+    }
+    return inner_ops_module;
+  }
+  return ops_module;
+}
+
+ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
+  std::string op_path = GetOpPythonPath(op_name);
+  py::module mod = py::module::import(common::SafeCStr(op_path));
+  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
+    MS_LOG(ERROR) << "Failure: op_path " << op_path << " doesn't have attr " << op_name;
+    return nullptr;
+  }
+  std::vector<py::object> arg_list;
+  (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
+                       [](const Attr &attr) { return ValuePtrToPyData(attr.second); });
+  py::object obj =
+    parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
+  ValuePtr op_instance = nullptr;
+  bool succ = parse::ConvertData(obj, &op_instance);
+  if (!succ) {
+    MS_LOG(ERROR) << "Failure: getting Python op " << op_name << " from " << op_path << " failed";
+    return nullptr;
+  }
+  return op_instance;
+}
+
+AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
+  auto value_node = NewValueNode(value_ptr);
+  MS_EXCEPTION_IF_NULL(value_node);
+  return value_node->cast<AnfNodePtr>();
+}
+
+static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
+AnfNodePtr CreateInt32Tensor(int32_t value) {
+  auto it = int_tensor_map.find(value);
+  if (it != int_tensor_map.end()) {
+    return it->second;
+  }
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<Tensor>(py::int_(value), kInt32);
+  ValuePtr value_ptr = MakeValue(tensor_ptr);
+  auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
+  int_tensor_map[value] = anf_node_ptr;
+  return anf_node_ptr;
+}
+
+AnfNodePtr CreatTypeInt(int32_t value) {
+  ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
+  return ValuePtrToAnfNodePtr(value_ptr);
+}
+
+AnfNodePtr CreatInt32Imm(int32_t value) {
+  ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
+  return ValuePtrToAnfNodePtr(value_ptr);
+}
+
+std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (!prim) {
+    MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
+  }
+  std::string instance_name = prim->instance_name();
+  return HashInstanceName(instance_name);
+}
+
+std::string HashInstanceName(const std::string &name) {
+  auto using_hash_name = common::GetEnv(USING_HASH_NAME);
+  std::string instance_name;
+  if ((using_hash_name.empty()) || (using_hash_name == "on")) {
+    instance_name = HashName(name);
+  } else {
+    instance_name = name;
+  }
+  return instance_name;
+}
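+// Editor's note on the behavior above: instance names are hashed when the
+// USING_HASH_NAME environment variable is unset or set to "on"; any other
+// value keeps the raw name, e.g. HashInstanceName("fc1") then returns "fc1"
+// unchanged (the example name is assumed for illustration).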
"Init:func_graph_ is nullptr"; + return FAILED; + } + manager_ = func_graph_->manager(); + if (!manager_) { + MS_LOG(ERROR) << "Init:manager_ is nullptr"; + return FAILED; + } + scope_ = cnode_->scope(); + if (!scope_) { + MS_LOG(ERROR) << "Init:scope_ is nullptr"; + return FAILED; + } + virtual_input_node_ = std::make_shared(nullptr); + virtual_input_node_->set_scope(scope_); + instance_name_base_ = GetInstanceNameByCNode(cnode_); + name_idx_ = 0; + return SUCCESS; +} + +AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { + CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode + MS_EXCEPTION_IF_NULL(cnode); + cnode->set_scope(scope_); + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "inputs.size() must be more than 1"; + } + (void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0] + auto new_anf_node_ptr = cnode->cast(); + MS_EXCEPTION_IF_NULL(new_anf_node_ptr); + return new_anf_node_ptr; +} + +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { + name_idx_++; + ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_)); + if (pyop_instance == nullptr) { + MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; + } + auto value_node = NewValueNode(pyop_instance); + return value_node->cast(); +} + +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { + name_idx_++; + OperatorAttrs attrs; + ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); + if (pyop_instance == nullptr) { + MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; + } + auto value_node = NewValueNode(pyop_instance); + return value_node->cast(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h new file mode 100644 index 0000000000..b3ef54a22e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
new file mode 100644
index 0000000000..b3ef54a22e
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
+#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "./common.h"
+#include "frontend/optimizer/opt.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+#define USING_HASH_NAME "USING_HASH_NAME"
+// Get the Python path where the operator has been defined
+std::string GetOpPythonPath(const OperatorName &op_name);
+
+// Init a Python operator instance
+ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
+
+AnfNodePtr CreatTypeInt(int32_t value);
+AnfNodePtr CreatInt32Imm(int32_t value);
+AnfNodePtr CreateInt32Tensor(int32_t value);
+AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
+std::string HashInstanceName(const std::string &name);
+
+class GenerateGraph {
+ public:
+  GenerateGraph() : name_idx_(0) {}
+  Status Init(const CNodePtr &cnode);
+  ~GenerateGraph() = default;
+  AnfNodePtr virtual_input_node() { return virtual_input_node_; }
+  AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs);
+  AnfNodePtr NewOpInst(const OperatorName &op_name);
+  AnfNodePtr PushBack(const std::vector<AnfNodePtr> &inputs);
+
+ private:
+  CNodePtr cnode_;
+  FuncGraphManagerPtr manager_;
+  ScopePtr scope_;
+  FuncGraphPtr func_graph_;
+  AnfNodePtr virtual_input_node_;
+  std::string instance_name_base_;
+  int64_t name_idx_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
new file mode 100644
index 0000000000..21298697f4
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/graph_util/get_parallel_info.h"
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "common/utils.h"
+#include "ir/func_graph.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/graph_util/graph_info.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+
+namespace mindspore {
+namespace parallel {
+py::dict GetParameterLayout(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  py::dict dict;
+  std::vector<AnfNodePtr> graph_params = graph->parameters();
+
+  for (auto para : graph_params) {
+    std::string name = std::static_pointer_cast<Parameter>(para)->name();
+    std::shared_ptr<TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout();
+    if (tensor_layout == nullptr) {
+      MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
+    } else {
+      auto device_arrangement = tensor_layout->device_arrangement().array();
+      auto tensor_map = tensor_layout->tensor_map().array();
+      auto slice_shape = tensor_layout->slice_shape().array();
+      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
+      dict[py::str(name)] = layout;
+      MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
+    }
+  }
+  return dict;
+}
+
+py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  py::dict dict;
+  auto ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  auto nodes = DeepScopedGraphSearch(ret);
+
+  for (auto node : nodes) {
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      auto distributed_operation_info = cnode->operator_info();
+      if (distributed_operation_info != nullptr) {
+        auto strategyPtr = distributed_operation_info->strategy();
+        if (strategyPtr != nullptr) {
+          auto strategy = strategyPtr->GetInputDim();
+          auto name = cnode->fullname_with_scope();
+          dict[py::str(name)] = strategy;
+        }
+      }
+    }
+  }
+  return dict;
+}
+
+py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  py::dict dict;
+  auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE);
+
+  for (auto prim : allreduce_prim_list) {
+    auto name_ptr = prim->GetAttr("parameter");
+    auto fusion_ptr = prim->GetAttr("fusion");
+    if (fusion_ptr == nullptr) {
+      MS_LOG(EXCEPTION) << "fusion_ptr is nullptr";
+    } else if (name_ptr == nullptr) {
+      continue;
+    }
+    if (!name_ptr->isa<StringImm>()) {
+      MS_LOG(EXCEPTION) << "name is not StringImm";
+    }
+    auto name = name_ptr->cast<StringImmPtr>()->value();
+    if (!fusion_ptr->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "fusion is not Int32Imm";
+    }
+    int32_t fusion = fusion_ptr->cast<Int32ImmPtr>()->value();
+    dict[py::str(name)] = fusion;
+  }
+  return dict;
+}
+}  // namespace parallel
+}  // namespace mindspore
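A hedged sketch of consuming the layout dict (illustrative numbers only: a parameter of shape [8, 16] split with strategy (2, 2) on four devices would map to device_arrangement [2, 2], tensor_map [1, 0], slice_shape [4, 8]):

```cpp
// Sketch only; assumes pybind11's stl.h casters are available and
// `graph` is a resolved FuncGraphPtr.
py::dict layouts = GetParameterLayout(graph);
for (auto item : layouts) {
  auto param_name = py::cast<std::string>(item.first);
  auto layout = py::cast<std::vector<std::vector<int32_t>>>(item.second);
  // layout[0]: device arrangement, layout[1]: tensor map, layout[2]: slice shape
  MS_LOG(INFO) << param_name << " slice shape: " << ShapeToString(layout[2]);
}
```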
diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
similarity index 100%
rename from mindspore/ccsrc/parallel/graph_util/get_parallel_info.h
rename to mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.h
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc
new file mode 100644
index 0000000000..45a88c3a23
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/graph_util/graph_info.h"
+#include "debug/anf_ir_dump.h"
+#include "debug/anf_ir_utils.h"
+#include "debug/draw.h"
+#include "ir/func_graph.h"
+#include "utils/context/ms_context.h"
+#include "utils/graph_utils.h"
+
+namespace mindspore {
+namespace parallel {
+std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::string &name) {
+  AnfNodePtr ret = graph->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::vector<PrimitivePtr> prim_list;
+  for (auto &node : all_nodes) {
+    if (!IsValueNode<Primitive>(node)) {
+      continue;
+    }
+    ValueNodePtr prim_node_anf = node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_node_anf);
+    PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    if (node_prim->name() == name) {
+      prim_list.emplace_back(node_prim);
+    }
+  }
+  return prim_list;
+}
+
+void DumpGraph(const FuncGraphPtr &root, const std::string &name) {
+  if (MsContext::GetInstance()->save_graphs_flag()) {
+    draw::Draw(name + ".dot", root);
+    DumpIR(name + ".ir", root);
+    ExportIR(name + ".dat", "0", root);
+  }
+}
+}  // namespace parallel
+}  // namespace mindspore
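A quick illustration of how these two helpers compose (sketch only; `func_graph` and the pass it lives in are assumptions, ALL_REDUCE comes from ops_utils.h):

```cpp
// Sketch: count AllReduce primitives in a graph, then dump the graph when
// the save_graphs context flag is on. Illustrative, not part of the patch.
auto allreduce_prims = FindPrimtive(func_graph, ALL_REDUCE);
MS_LOG(INFO) << "Found " << allreduce_prims.size() << " AllReduce primitive(s)";
DumpGraph(func_graph, "step_parallel_end");
```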
diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h
similarity index 100%
rename from mindspore/ccsrc/parallel/graph_util/graph_info.h
rename to mindspore/ccsrc/frontend/parallel/graph_util/graph_info.h
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
new file mode 100644
index 0000000000..e50df2818b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/graph_util/node_info.h"
+
+#include <string>
+
+#include "ir/anf.h"
+#include "ir/param_value.h"
+#include "pipeline/jit/parse/python_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+std::string ParameterName(const AnfNodePtr &node_ptr) {
+  auto para_ptr = node_ptr->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(para_ptr);
+  return para_ptr->name();
+}
+
+bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
+  auto para_ptr = node_ptr->cast<ParameterPtr>();
+  if (para_ptr == nullptr) {
+    return false;
+  }
+  if (!para_ptr->has_default()) {
+    return false;
+  }
+  return para_ptr->default_param()->requires_grad();
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h
similarity index 100%
rename from mindspore/ccsrc/parallel/graph_util/node_info.h
rename to mindspore/ccsrc/frontend/parallel/graph_util/node_info.h
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc
new file mode 100644
index 0000000000..8929af7b0b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc
@@ -0,0 +1,178 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/group_manager.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "utils/comm_manager.h"
+
+namespace mindspore {
+namespace parallel {
+Group::Group() {
+  name_.clear();
+  devices_.clear();
+}
+
+Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
+  this->name_ = name;
+  this->devices_ = devices;
+  return Status::SUCCESS;
+}
+
+std::vector<Device> Group::GetDevicesList() const { return devices_; }
+
+bool Group::IsInThisGroup(int32_t device_rank) {
+  for (auto &device : devices_) {
+    if (device.rank() == device_rank) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Get the position of the device in the group
+Status Group::GetIndex(size_t *index) {
+  size_t pos = 0;
+  CheckGlobalDeviceManager();
+  int32_t rank = g_device_manager->global_rank();
+  for (auto &device : devices_) {
+    if (device.rank() == rank) {
+      *index = pos;
+      return Status::SUCCESS;
+    } else {
+      pos++;
+    }
+  }
+  MS_LOG(ERROR) << "Could not find device rank " << rank << " in this group!";
+  return Status::FAILED;
+}
+
+GroupManager::GroupManager() { groups_.clear(); }
+
+Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
+                                 mindspore::parallel::Group *const group) {
+  // it is simplest to use the size to determine whether it is the world group
+  uint32_t world_size = 0;
+  if (world_group_ != NCCL_WORLD_GROUP) {
+    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
+  }
+
+  if ((world_group_ == NCCL_WORLD_GROUP) || (devices.size() == world_size)) {
+    auto it = groups_.find(world_group_);
+    if (it == groups_.end()) {
+      (void)group->Init(world_group_, devices);
+      groups_[world_group_] = *group;
+    } else {
+      *group = it->second;
+    }
+    MS_LOG(INFO) << "It is the world group " << world_group_ << ", no need to create it.";
+    return Status::SUCCESS;
+  }
+
+  auto it = groups_.find(group_name);
+  // If there already exists a group with the desired name,
+  // let the pointer point to that group.
+  if (it != groups_.end()) {
+    *group = it->second;
+    return Status::SUCCESS;
+  } else {
+    (void)group->Init(group_name, devices);
+    groups_[group_name] = *group;
+
+    std::vector<uint32_t> ranks;
+    (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
+                         [](const Device dev) { return (uint32_t)dev.rank(); });
+    // Create the group through the CommManager interface
+    bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
+    if (!ret) {
+      MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
+      return Status::FAILED;
+    }
+
+    MS_LOG(INFO) << "Create group success, group name is " << group_name;
+    return Status::SUCCESS;
+  }
+}
+
+Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
+  std::string name = (*group).name();
+  auto it = groups_.find(name);
+  if (it == groups_.end()) {
+    MS_LOG(ERROR) << "Could not find group name: " << name;
+    return Status::FAILED;
+  }
+  (void)groups_.erase(it);
+  bool ret = CommManager::GetInstance().DestroyGroup(name);
+  if (!ret) {
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
+Status GroupManager::DestroyAllGroups() {
+  for (auto &it : groups_) {
+    std::string name = it.first;
+    bool ret = CommManager::GetInstance().DestroyGroup(name);
+    if (!ret) {
+      return Status::FAILED;
+    }
+  }
+  groups_.clear();
+  return Status::SUCCESS;
+}
+
+Status GroupManager::GetRankID(const std::string &name, unsigned int *const rank_id) {
+  auto it = groups_.find(name);
+  if (it == groups_.end()) {
+    MS_LOG(ERROR) << "Could not find group name: " << name;
+    return Status::FAILED;
+  }
+  bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
+  if (!ret) {
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
+Status GroupManager::GetRankSize(const std::string &name, unsigned int *const rank_size) {
+  auto it = groups_.find(name);
+  if (it == groups_.end()) {
+    MS_LOG(ERROR) << "Could not find group name: " << name;
+    return Status::FAILED;
+  }
+  bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
+  if (!ret) {
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
+Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
+  auto it = groups_.find(name);
+  if (it == groups_.end()) {
+    return Status::FAILED;
+  }
+  *group = &it->second;
+  return Status::SUCCESS;
+}
+
+void GroupManager::Clear() { (void)DestroyAllGroups(); }
+}  // namespace parallel
+}  // namespace mindspore
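The expected call pattern, as a hedged sketch (the group name, the two-device list, and the use of Device's rank constructor are illustrative assumptions, not part of this patch):

```cpp
// Sketch: create a 2-rank communication group and query this rank's index.
GroupManager manager;
manager.set_world_group(HCCL_WORLD_GROUP);
std::vector<Device> devices = {Device(0), Device(1)};  // hypothetical ranks
Group group;
if (manager.CreateGroup("group_2_devices", devices, &group) != Status::SUCCESS) {
  MS_LOG(ERROR) << "Create communication group failed";
}
size_t index = 0;
(void)group.GetIndex(&index);  // this rank's position inside the group
```

Note that CreateGroup short-circuits to the cached world group whenever `devices` covers every rank, so callers do not need to special-case data-parallel-only setups.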
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.h b/mindspore/ccsrc/frontend/parallel/group_manager.h
new file mode 100644
index 0000000000..b9cf9663b0
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
+#define MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
+
+#include <cstdint>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "frontend/parallel/device.h"
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr char HCCL_WORLD_GROUP[] = "hccl_world_group";
+constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+constexpr char UNDEFINED_WORLD_GROUP[] = "undefined_world_group";
+
+// Devices that need to communicate should be in the same group. These classes are used to
+// create and destroy groups of devices.
+class Group {
+ public:
+  Group();
+  ~Group() = default;
+  Status Init(const std::string &name, const std::vector<Device> &devices);
+  std::vector<Device> GetDevicesList() const;
+  std::string name() const { return name_; }
+  bool IsInThisGroup(int32_t device_rank);
+  Status GetIndex(size_t *index);
+  size_t GetDevNum() const { return devices_.size(); }
+
+ private:
+  std::string name_;
+  std::vector<Device> devices_;
+};
+
+class GroupManager {
+ public:
+  GroupManager();
+  ~GroupManager() = default;
+
+  Status CreateGroup(const std::string &name, const std::vector<Device> &devices, Group *group);
+  Status DestroyGroup(Group *group);
+  Status DestroyAllGroups();
+  Status GetRankID(const std::string &name, unsigned int *rank_id);
+  Status GetRankSize(const std::string &name, unsigned int *rank_size);
+  Status FindGroup(const std::string &name, Group **group);
+  std::string world_group() const { return world_group_; }
+  void set_world_group(const std::string &name) { world_group_ = name; }
+  void Clear();
+
+ private:
+  // the key is the group name (name_)
+  std::map<std::string, Group> groups_;
+  std::string world_group_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/node_check.cc b/mindspore/ccsrc/frontend/parallel/node_check.cc
new file mode 100644
index 0000000000..de29417a4d
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/node_check.cc
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/node_check.h"
+
+#include <set>
+#include <string>
+
+#include "frontend/parallel/ops_info/ops_utils.h"
+
+namespace mindspore {
+namespace parallel {
+const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
+                                          MAKE_TUPLE,
+                                          J,
+                                          LIST_GETITEM,
+                                          ARRAY_GETITEM,
+                                          TUPLE_SETITEM,
+                                          DEPEND,
+                                          LIST_SETITEM,
+                                          ARRAY_SETITEM,
+                                          DICT_GETITEM,
+                                          LIST_APPEND,
+                                          LIST_MAP,
+                                          LIST_REDUCE,
+                                          TUPLE_REVERSED,
+                                          TILE_SHAPE,
+                                          TUPLE_DIV,
+                                          TUPLE_TO_ARRAY,
+                                          MAKE_LIST,
+                                          MAKE_DICT,
+                                          MAKE_SLICE,
+                                          MAKE_RECORD,
+                                          STRING_EQUAL,
+                                          VIRTUALLOSS,
+                                          RETURN,
+                                          ENV_GETITEM,
+                                          IDENTITY,
+                                          PARTIAL,
+                                          ENVSETITEM,
+                                          ENVGETITEM,
+                                          ENVADD,
+                                          MAKEREFKEY,
+                                          MAKEREF,
+                                          GETREFKEY,
+                                          GETREFVALUE,
+                                          GETREFORIGIN,
+                                          DOT,
+                                          IM2COL,
+                                          COL2IM,
+                                          IM2COLV1,
+                                          STATESETITEM,
+                                          SCALARSUMMARY,
+                                          IMAGESUMMARY,
+                                          TENSORSUMMARY,
+                                          DEBUG,
+                                          HISTOGRAMSUMMARY,
+                                          COL2IMV1,
+                                          RESOLVE,
+                                          BROADCASTGRADIENTARGS,
+                                          INVERTPERMUTATION,
+                                          CONTROLDEPEND,
+                                          DROPOUT_GEN_MASK,
+                                          EMBED,
+                                          CREATINSTANCE,
+                                          ZEROSLIKE,
+                                          ASSIGN,
+                                          REF_TO_EMBED,
+                                          STOP_GRADIENT};
+
+bool IsInBlackList(const PrimitivePtr &prim) {
+  MS_EXCEPTION_IF_NULL(prim);
+  return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end());
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/node_check.h b/mindspore/ccsrc/frontend/parallel/node_check.h
similarity index 100%
rename from mindspore/ccsrc/parallel/node_check.h
rename to mindspore/ccsrc/frontend/parallel/node_check.h
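How a traversal pass is expected to consult this list, as a sketch (illustrative: `all_nodes` would come from DeepScopedGraphSearch, and the loop body stands in for real strategy assignment):

```cpp
// Sketch: skip structural primitives when assigning parallel strategies.
for (auto &node : all_nodes) {
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr || !IsValueNode<Primitive>(cnode->input(0))) {
    continue;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (IsInBlackList(prim)) {
    continue;  // ops such as make_tuple never receive a parallel strategy
  }
  // ... assign operator info / strategy here ...
}
```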
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
new file mode 100644
index 0000000000..35cac1480c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
@@ -0,0 +1,705 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/activation_info.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status Activation::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status ActivationInfo::GetAttrs() {
+  if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
+    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
+    return FAILED;
+  }
+
+  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
+                  << outputs_shape_.size() << ") is wrong.";
+    return FAILED;
+  }
+
+  auto iter = attrs_.find(ACTIVATION_TYPE);
+  if (iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<StringImm>()) {
+      std::string val = iter->second->cast<StringImmPtr>()->value();
+      if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) {
+        MS_LOG(ERROR) << name_ << " : Activation type is wrong.";
+        return FAILED;
+      }
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of activation_type is not a string.";
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status ActivationOther::GetAttrs() {
+  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
+                  << outputs_shape_.size() << ") is wrong.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status Activation::GenerateStrategies(int32_t stage_id) {
+  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
+                  << outputs_shape_.size() << ") is wrong.";
+    return FAILED;
+  }
+
+  is_auto_parallel_ = true;
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  Shapes splittable_inputs = {input0_split};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
+    }
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+
+  for (auto &element : axis_) {
+    int32_t axis_index = element;
+    if (element < 0) {
+      size_t input_dim = inputs_shape_.at(0).size();
+      axis_index = static_cast<int32_t>(input_dim) + element;
+    }
+
+    int32_t axis_strategy = input_strategy.at(IntToSize(axis_index));
+    // The dimension corresponding to axis is un-splittable
+    if (axis_strategy != MIN_SLICE_NUM) {
+      if (is_auto_parallel_) {
+        MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
+      } else {
+        MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
+      }
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status Softmax::GetAttrs() {
+  if (attrs_.size() < SOFTMAX_ATTR_SIZE) {
+    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
+    return FAILED;
+  }
+
+  auto iter = attrs_.find(AXIS);
+  if (iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int32Imm>()) {  // the axis is a number
+      int32_t axis_element = iter->second->cast<Int32ImmPtr>()->value();
+      axis_.push_back(axis_element);
+      MS_LOG(INFO) << name_ << " : The axis is int, value is " << axis_element;
+    } else if (iter->second->isa<ValueTuple>()) {  // the axis is a tuple
+      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
+      if (value_tuple == nullptr) {
+        MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr.";
+        return FAILED;
+      }
+      std::vector<ValuePtr> value_vector = value_tuple->value();
+      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
+                           [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
+      if (axis_.empty()) {
+        MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
+        return FAILED;
+      }
+      MS_LOG(INFO) << name_ << " : The axis is a tuple, value is " << ShapeToString(axis_);
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of axis is not int or tuple of int.";
+      return FAILED;
+    }
+  }
+
+  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
+    return FAILED;
+  }
+
+  // for example: tensor dimension is 4, then axis range [-4, 3]
+  int32_t dim = SizeToInt(inputs_shape_.at(0).size());
+  auto it =
+    std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); });
+  if (it != axis_.end()) {
+    MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "].";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status Softmax::GenerateStrategies(int32_t stage_id) {
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
+    return FAILED;
+  }
+  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
+    return FAILED;
+  }
+
+  is_auto_parallel_ = true;
+  Shape input0_split;
+  (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
+  for (auto &element : axis_) {
+    int32_t axis_index = element;
+    if (element < 0) {
+      size_t input_dim = inputs_shape_.at(0).size();
+      axis_index = static_cast<int32_t>(input_dim) + element;
+    }
+    input0_split[IntToSize(axis_index)] = 0;
+  }
+  Shapes splittable_inputs = {input0_split};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status ActivationBase::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+
+  dev_matrix_shape_ = input_strategy;
+
+  return SUCCESS;
+}
+
+Status ActivationBase::InferMirrorOps() {
+  mirror_ops_.clear();
+
+  Shape tensor_map = inputs_tensor_map_[0];
+  std::vector<Group> group;
+  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Create group failed.";
+    return FAILED;
+  }
+
+  OperatorVector mirror_op;
+  if (group.empty()) {
+    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
+    return SUCCESS;
+  } else {
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    mirror_ops_.push_back(mirror_op);
+    std::string group_name = group[0].name();
+    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
+  }
+
+  return SUCCESS;
+}
+
+Status ActivationBase::InferForwardCommunication() {
+  // do nothing
+  return SUCCESS;
+}
+
+Status ActivationBase::InferTensorMap() {
+  std::vector<int32_t> tensor_map_index;
+  size_t size = inputs_shape_.at(0).size();
+  // such as 4: tensor_map_index [3,2,1,0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_index.push_back((int32_t)(size - i - 1));
+  }
+
+  inputs_tensor_map_.push_back(tensor_map_index);
+  outputs_tensor_map_.push_back(tensor_map_index);
+  return SUCCESS;
+}
+
+Status ActivationBase::InferTensorInfo() {
+  // infer tensor shape
+  Shape input_shape = inputs_shape_.at(0);
+
+  // infer slice shape
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  Strategys inputs_strategy = strategy_->GetInputDim();
+  Strategys outputs_strategy = {inputs_strategy.at(0)};
+  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    return FAILED;
+  }
+  Shape input_slice_shape = inputs_slice_shape.at(0);
+
+  TensorLayout input_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(input_tensor_info);  // the same as input
+
+  return SUCCESS;
+}
+
+Status ActivationBase::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init success.";
+  return SUCCESS;
+}
+
+Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
+    }
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init for cost model success.";
+  return SUCCESS;
+}
+
+Status CastInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+
+  Shape tensor_map = inputs_tensor_map_[0];
+  std::vector<Group> group;
+  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Create group failed.";
+    return FAILED;
+  }
+
+  OperatorVector mirror_op;
+  OperatorVector op_for_value;
+  if (group.empty()) {
+    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
+    return SUCCESS;
+  } else {
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    mirror_ops_.push_back(mirror_op);
+    mirror_ops_.push_back(op_for_value);
+    std::string group_name = group[0].name();
+    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
+  }
+
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::GetAttrs() {
+  if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
+    return FAILED;
+  }
+
+  if (!input_value_.back()->isa<Int32Imm>()) {
+    MS_LOG(ERROR) << name_ << ": The type of axis is not int";
+    return FAILED;
+  }
+
+  int32_t axis = GetValue<int>(input_value_.back());
+
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  int32_t dim = SizeToInt(inputs_shape_[0].size());
+  if ((axis > dim) || (axis < -dim - 1)) {
+    MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]";
+    return FAILED;
+  }
+
+  if (axis < 0) {
+    positive_axis_ = dim + axis + 1;
+  } else {
+    positive_axis_ = axis;
+  }
+  MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorMap() {
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // for example: if the dimension of input is 3, and the axis is 2,
+  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
+  std::vector<int32_t> input_tensor_map, output_tensor_map;
+  size_t size = inputs_shape_[0].size();
+  for (size_t i = 0; i < size; ++i) {
+    input_tensor_map.push_back(SizeToInt(size - i - 1));
+  }
+
+  inputs_tensor_map_.push_back(input_tensor_map);
+
+  output_tensor_map = input_tensor_map;
+  if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) {
+    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
+    return FAILED;
+  }
+  (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
+  outputs_tensor_map_.push_back(output_tensor_map);
+
+  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
+               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorStrategy() {
+  if (strategy_ == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+
+  inputs_strategy_ = strategy_->GetInputDim();
+  if (inputs_strategy_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
+    return FAILED;
+  }
+
+  Shape output_strategy = inputs_strategy_[0];
+  if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) {
+    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
+    return FAILED;
+  }
+  (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
+  outputs_strategy_ = {output_strategy};
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_shape = inputs_shape_[0];
+  Shape output_shape = outputs_shape_[0];
+
+  // infer slice shape
+  if (InferTensorStrategy() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
+    return FAILED;
+  }
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
+    return FAILED;
+  }
+
+  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
+    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_slice_shape = inputs_slice_shape[0];
+  Shape output_slice_shape = outputs_slice_shape[0];
+
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
+    return FAILED;
+  }
+
+  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+Status ExpandDimsInfo::InferMirrorOps() {
+  mirror_ops_.clear();
+
+  if (inputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
+    return FAILED;
+  }
+
+  std::vector<Group> group;
+  if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Create group failed";
+    return FAILED;
+  }
+
+  if (group.empty()) {
+    MS_LOG(INFO) << name_ << ": No need to create mirror ops";
+    return SUCCESS;
+  }
+
+  OperatorVector mirror_op, placeholder_op;
+  mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+  mirror_ops_.push_back(mirror_op);
+  mirror_ops_.push_back(placeholder_op);
+  MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
+  std::vector<int32_t> axis;
+  auto axis_list = value_tuple->value();
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+  Shape input_shape = inputs_shape_.at(0);
+  size_t input_size = input_shape.size();
+  // if the axis tuple is empty, we should exclude the axes whose corresponding slice shape is 1.
+  if (axis_list.empty()) {
+    for (size_t i = 0; i < input_size; ++i) {
+      if (input_shape[i] == 1) {
+        axis.push_back(SizeToInt(i));
+      }
+    }
+    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
+    return SUCCESS;
+  }
+
+  // convert negative axes to positive.
+  for (auto &dim : axis_list) {
+    if (!dim->isa<Int32Imm>()) {
+      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
+      return FAILED;
+    }
+    int32_t dim_value = GetValue<int>(dim);
+    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
+    axis.push_back(positive_value);
+  }
+  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
+  return SUCCESS;
+}
+
+Status SqueezeInfo::GetAttrs() {
+  auto iter = attrs_.find(AXIS);
+  if (iter == attrs_.end()) {
+    MS_LOG(ERROR) << name_ << ": Can't find the axis attribute.";
+    return FAILED;
+  }
+  MS_EXCEPTION_IF_NULL(iter->second);
+  auto value_tuple = iter->second->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  if (InferAxis(value_tuple) != SUCCESS) {
+    return FAILED;
+  }
+  attrs_[AXIS] = axis_;
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) {
+  Attr attr = std::make_pair(AXIS, axis_);
+  OperatorAttrs attrs = {attr};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  replace_op_ = {std::make_pair(SQUEEZE, args)};
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferTensorMap() {
+  // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
+  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
+  std::vector<int32_t> input_tensor_map, output_tensor_map;
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+  size_t size = inputs_shape_[0].size();
+  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
+  for (size_t i = 0; i < size; ++i) {
+    size_t index = size - i - 1;
+    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
+    if (iter == axis.end()) {
+      output_tensor_map.push_back(SizeToInt(index));
+    }
+    input_tensor_map.push_back(SizeToInt(index));
+  }
+  inputs_tensor_map_.push_back(input_tensor_map);
+  outputs_tensor_map_.push_back(output_tensor_map);
+  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
+               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
+
+  return SUCCESS;
+}
+
+Status SqueezeInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_shape = inputs_shape_[0];
+  Shape output_shape = outputs_shape_[0];
+
+  // infer slice shape
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  Strategys inputs_strategy = strategy_->GetInputDim();
+  Dimensions output_strategy;
+  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
+  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
+    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
+    if (iter == axis.end()) {
+      output_strategy.push_back(inputs_strategy[0].at(i));
+    }
+  }
+  Strategys outputs_strategy = {output_strategy};
+  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
+    return FAILED;
+  }
+
+  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
+    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
+    return FAILED;
+  }
+
+  Shape input_slice_shape = inputs_slice_shape[0];
+  Shape output_slice_shape = outputs_slice_shape[0];
+
+  // infer tensor layout
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
+    return FAILED;
+  }
+
+  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+Status SqueezeInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init failed.";
+    return FAILED;
+  }
+
+  if (InferReplaceOps(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Infer replace ops failed";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init success.";
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
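To make the tensor-map bookkeeping above concrete, here is a small standalone sketch of the insertion rule ExpandDimsInfo::InferTensorMap applies (illustrative values only; it mirrors the logic rather than calling the class, and kNoSplitMap stands in for NO_SPLIT_MAP):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors ExpandDimsInfo::InferTensorMap for a rank-3 input and axis 2:
// input map [2, 1, 0] becomes output map [2, 1, -1, 0], where -1 marks the
// newly inserted dimension as unsplittable.
int main() {
  const int32_t kNoSplitMap = -1;  // assumed value of NO_SPLIT_MAP
  const size_t rank = 3;
  const int32_t positive_axis = 2;
  std::vector<int32_t> input_map;
  for (size_t i = 0; i < rank; ++i) {
    input_map.push_back(static_cast<int32_t>(rank - i - 1));  // [2, 1, 0]
  }
  std::vector<int32_t> output_map = input_map;
  output_map.insert(output_map.begin() + positive_axis, kNoSplitMap);
  for (int32_t m : output_map) std::cout << m << ' ';  // prints: 2 1 -1 0
  return 0;
}
```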
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
new file mode 100644
index 0000000000..a74707efbe
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
@@ -0,0 +1,224 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class ActivationBase : public OperatorInfo {
+ public:
+  ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs, OperatorCostPtr cost)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
+  ~ActivationBase() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+
+ protected:
+  Status InferMirrorOps() override;
+  Status InferForwardCommunication() override;
+  Status InferTensorMap() override;
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+};
+
+class Activation : public ActivationBase {
+ public:
+  Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>(false)) {}
+  ~Activation() override = default;
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+};
+
+class ActivationInfo : public Activation {
+ public:
+  ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs)
+      : Activation(name, inputs_shape, outputs_shape, attrs) {}
+  ~ActivationInfo() override = default;
+
+ protected:
+  Status GetAttrs() override;  // activation_type: relu, relu6, sigmoid
+};
+
+class ActivationOther : public Activation {
+ public:
+  ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                  const PrimitiveAttrs &attrs)
+      : Activation(name, inputs_shape, outputs_shape, attrs) {}
+  ~ActivationOther() override = default;
+
+ protected:
+  Status GetAttrs() override;
+};
+
+class GeluInfo : public ActivationOther {
+ public:
+  GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+           const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~GeluInfo() override = default;
+};
+
+class TanhInfo : public ActivationOther {
+ public:
+  TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+           const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~TanhInfo() override = default;
+};
+
+class Softmax : public ActivationBase {
+ public:
+  explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                   const PrimitiveAttrs &attrs)
+      : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>(false)) {}
+  ~Softmax() override = default;
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status GetAttrs() override;
+
+ private:
+  std::vector<int32_t> axis_;
+};
+
+class SoftmaxInfo : public Softmax {
+ public:
+  SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+              const PrimitiveAttrs &attrs)
+      : Softmax(name, inputs_shape, outputs_shape, attrs) {}
+  ~SoftmaxInfo() override = default;
+};
+
+class LogSoftmaxInfo : public Softmax {
+ public:
+  LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs)
+      : Softmax(name, inputs_shape, outputs_shape, attrs) {}
+  ~LogSoftmaxInfo() override = default;
+};
+
+class ReLUInfo : public ActivationOther {
+ public:
+  ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+           const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~ReLUInfo() override = default;
+};
+
+class CastInfo : public ActivationOther {
+ public:
+  CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+           const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~CastInfo() override = default;
+
+ protected:
+  Status InferMirrorOps() override;
+};
+
+class SqrtInfo : public ActivationOther {
+ public:
+  SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+           const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SqrtInfo() override = default;
+};
+
+class NegInfo : public ActivationOther {
+ public:
+  NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~NegInfo() override = default;
+};
+
+class ExpandDimsInfo : public ActivationOther {
+ public:
+  ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                 const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~ExpandDimsInfo() override = default;
+
+ protected:
+  Status GetAttrs() override;
+  Status InferTensorMap() override;
+  Status InferTensorInfo() override;
+  Status InferMirrorOps() override;
+  Status InferTensorStrategy();
+
+ private:
+  int32_t positive_axis_ = -1;
+  Strategys inputs_strategy_;
+  Strategys outputs_strategy_;
+};
+
+class SqueezeInfo : public ActivationOther {
+ public:
+  SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+              const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SqueezeInfo() override = default;
+
+ protected:
+  Status InferAxis(const ValueTuplePtr &value_tuple);
+  Status GetAttrs() override;
+  Status InferReplaceOps(const StrategyPtr &strategy);
+  Status InferTensorMap() override;
+  Status InferTensorInfo() override;
+  Status Init(const StrategyPtr &strategy) override;
+
+ private:
+  ValueTuplePtr axis_;
+};
+
+class SquareInfo : public ActivationOther {
+ public:
+  SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SquareInfo() override = default;
+};
+
+class SigmoidInfo : public ActivationOther {
+ public:
+  SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+              const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SigmoidInfo() override = default;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
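The arithmetic_info.cc diff that follows implements elementwise ops with NumPy-style broadcasting: ExpendShape left-pads the shorter shape with 1s before strategies are compared. A standalone sketch of that rule (illustrative; it mirrors the function rather than linking against it):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;

// Mirrors ExpendShape: left-pad the smaller shape with 1s so both operands
// have the same rank, e.g. [8, 16] vs [16] -> [8, 16] vs [1, 16].
Shape ExpendShape(const Shape &bigger, Shape smaller) {
  size_t insert_num = bigger.size() - smaller.size();
  for (size_t i = 0; i < insert_num; ++i) {
    smaller.insert(smaller.begin(), 1);
  }
  return smaller;
}

int main() {
  Shape a = {8, 16}, b = {16};
  Shape padded = ExpendShape(a, b);
  for (int32_t d : padded) std::cout << d << ' ';  // prints: 1 16
  return 0;
}
```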
a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc new file mode 100644 index 0000000000..1dd9c899ca --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc @@ -0,0 +1,363 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/arithmetic_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) { + size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size(); + for (size_t num = 0; num < insert_num; ++num) { + (void)smaller_size_shape.insert(smaller_size_shape.begin(), 1); + } + return smaller_size_shape; +} + +Shapes ArithmeticBase::InferExpendShape() { + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shapes input_shapes; + size_t input_a_size = input_a_shape.size(); + size_t input_b_size = input_b_shape.size(); + if (input_a_size > input_b_size) { + input_shapes.push_back(input_a_shape); + input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape)); + } else if (input_a_size < input_b_size) { + input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape)); + input_shapes.push_back(input_b_shape); + } else { + input_shapes.push_back(input_a_shape); + input_shapes.push_back(input_b_shape); + } + return input_shapes; +} + +std::vector ExpendStrategy(const StrategyPtr &strategy) { + std::vector expend_strategy; + std::vector stra = strategy->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + size_t input_a_size = sub_a_strategy.size(); + size_t input_b_size = sub_b_strategy.size(); + if (input_a_size > input_b_size) { + expend_strategy.push_back(sub_a_strategy); + expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy)); + } else if (input_a_size < input_b_size) { + expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy)); + expend_strategy.push_back(sub_b_strategy); + } else { + expend_strategy = stra; + } + return expend_strategy; +} + +Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + Shapes input_shapes = InferExpendShape(); + std::vector expend_strategy = ExpendStrategy(strategy); + Dimensions sub_a_strategy = expend_strategy.at(0); + Dimensions sub_b_strategy = expend_strategy.at(1); + Shape input_a_shape = input_shapes.at(0); + Shape input_b_shape = input_shapes.at(1); + + for (size_t i = 0; i < 
input_a_shape.size(); ++i) { + if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + } + return SUCCESS; +} + +Status ArithmeticBase::InferDevMatrixShape() { + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_strategy = expend_strategy.at(0); + Dimensions sub_b_strategy = expend_strategy.at(1); + Shape dev_shape; + for (size_t i = 0; i < sub_a_strategy.size(); ++i) { + if (sub_a_strategy[i] != sub_b_strategy[i]) { + dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]); + } else { + dev_shape.push_back(sub_a_strategy[i]); + } + } + dev_matrix_shape_ = dev_shape; + + return SUCCESS; +} + +TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) { + TensorMap tensor_map_index; + for (size_t i = 0; i < strategy.size(); ++i) { + if (strategy[i] == dev_matrix_shape[i]) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); + } else { + tensor_map_index.push_back(-1); + } + } + return tensor_map_index; +} + +TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) { + TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape); + size_t dev_matrix_size = dev_matrix_shape.size(); + size_t strategy_size = strategy.size(); + if (dev_matrix_size != strategy_size) { + (void)expend_map.erase(expend_map.begin(), + expend_map.begin() + static_cast(dev_matrix_size - strategy_size)); + } + return expend_map; +} + +void ArithmeticBase::ReComputeBatchSplitFlagList() { + Shapes expend_shapes = InferExpendShape(); + Shape expend_a_shape = expend_shapes.at(0); + Shape expend_b_shape = expend_shapes.at(1); + if (expend_a_shape.size() != expend_b_shape.size()) { + MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong."; + } + if (expend_a_shape.empty()) { + split_flag_list_[0] = false; + split_flag_list_[1] = false; + return; + } + (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false); + (expend_b_shape.at(0) != 1) ? 
(split_flag_list_[1] = true) : (split_flag_list_[1] = false); +} + +Status ArithmeticBase::InferTensorMap() { + std::vector tensor_map_index; + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_expend_strategy = expend_strategy.at(0); + Dimensions sub_b_expend_strategy = expend_strategy.at(1); + Strategys stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); + } + + Shape dev_shape; + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { + dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); + } else { + dev_shape.push_back(sub_a_expend_strategy[i]); + } + } + inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy)); + inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy)); + outputs_tensor_map_.push_back(tensor_map_index); + + return SUCCESS; +} + +Status ArithmeticBase::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + Shape input_b_tensor_map = inputs_tensor_map_.at(1); + std::vector input_a_group, input_b_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input b failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b; + if (input_a_group.empty() && input_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } + if (!input_a_group.empty()) { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + if (!input_b_group.empty()) { + op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); + } + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + + return SUCCESS; +} + +Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, + const Shape &dev_matrix_array) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); + TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); + TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); + Shape input_a_shape_array = inputs_shape_.at(0); + Shape input_b_shape_array = inputs_shape_.at(1); + Shape out_shape_array = outputs_shape_.at(0); + + TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; + if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; + return FAILED; + } + if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != + SUCCESS) { + 
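// A failure in InitFromVector here means the device matrix, the tensor map and the corresponding
+    // input shape are mutually inconsistent; with broadcasting, this usually indicates that the
+    // expanded strategy no longer matches the device matrix derived in InferDevMatrixShape.
+    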
MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; + return FAILED; + } + if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; + return FAILED; + } + inputs_layout->push_back(input_a_tensor_layout); + inputs_layout->push_back(input_b_tensor_layout); + outputs_layout->push_back(out_tensor_layout); + + return SUCCESS; +} + +Status ArithmeticBase::InferTensorInfo() { + // infer tensor shape + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + std::vector expend_strategy = ExpendStrategy(strategy_); + Dimensions sub_a_expend_strategy = expend_strategy.at(0); + Dimensions sub_b_expend_strategy = expend_strategy.at(1); + Strategys inputs_strategy = strategy_->GetInputDim(); + Shape dev_shape; + for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { + if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { + dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); + } else { + dev_shape.push_back(sub_a_expend_strategy[i]); + } + } + Strategys outputs_strategy = {dev_shape}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_a_slice_shape = inputs_slice_shape.at(0); + Shape input_b_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; + return FAILED; + } + + TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); + TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); + TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a + inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b + outputs_tensor_info_.push_back(out_tensor_info); // output + + return SUCCESS; +} + +Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shape input1_split(inputs_shape_[1].size(), 1); + Shapes splittable_inputs = {input0_split, input1_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies with broadcast failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; + + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status 
ArithmeticBase::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h new file mode 100644 index 0000000000..1d347e4ec1 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h @@ -0,0 +1,135 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ArithmeticBase : public OperatorInfo { + public: + ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} + ~ArithmeticBase() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override { return SUCCESS; } + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); + Shapes InferExpendShape(); +}; + +class SubInfo : public ArithmeticBase { + public: + SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~SubInfo() override = default; +}; + +class TensorAddInfo : public ArithmeticBase { + public: + TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const 
Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TensorAddInfo() override = default; +}; + +class MulInfo : public ArithmeticBase { + public: + MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MulInfo() override = default; +}; + +class DivInfo : public ArithmeticBase { + public: + DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~DivInfo() override = default; +}; + +class RealDivInfo : public ArithmeticBase { + public: + RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~RealDivInfo() override = default; +}; + +class FloorDivInfo : public ArithmeticBase { + public: + FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~FloorDivInfo() override = default; +}; + +class PowInfo : public ArithmeticBase { + public: + PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~PowInfo() override = default; +}; + +class GreaterInfo : public ArithmeticBase { + public: + GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~GreaterInfo() override = default; +}; + +class AssignSubInfo : public ArithmeticBase { + public: + AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~AssignSubInfo() override = default; +}; + +// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. +class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { + public: + SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~SigmoidCrossEntropyWithLogitsInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc new file mode 100644 index 0000000000..64aceb90f6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/batch_parallel_info.h" + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" + +namespace mindspore { +namespace parallel { +Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + + int32_t stage = strategy->GetInputStage(); + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + dev_num_ = dev_num; + + size_t strategy_size = strategy->GetInputNumber(); + std::vector stra = strategy->GetInputDim(); + for (size_t i = 0; i < strategy_size; ++i) { + Shape sub_strategy = stra.at(i); + size_t strategy_len = sub_strategy.size(); + bool flag = false; + for (size_t j = 0; j < strategy_len; ++j) { + int32_t strategy_value = sub_strategy.at(j); + if (strategy_value > 1) { + if (flag || strategy_value != dev_num_) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : It is not a valid data parallel strategy."; + } else { + MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; + } + return FAILED; + } + flag = true; + } + } + } + return SUCCESS; +} + +Status BatchParallelInfo::InferDevMatrixShape() { + dev_matrix_shape_.push_back(dev_num_); + return SUCCESS; +} + +Status BatchParallelInfo::InferMirrorOps() { + mirror_ops_.clear(); + if (g_device_manager->DeviceNum() == 1) { + MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops."; + return SUCCESS; + } + + MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber(); + for (size_t i = 0; i < input_value_.size(); i++) { + MS_EXCEPTION_IF_NULL(g_device_manager); + OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum()); + mirror_ops_.push_back(op_vec); + } + return SUCCESS; +} + +Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } + +Status BatchParallelInfo::InferTensorMap() { + if (strategy_->GetInputDim()[0][0] != dev_num_) { + MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; + return FAILED; + } + for (size_t i = 0; i < inputs_shape_.size(); i++) { + std::vector tensor_map_index; + for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { + if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { + tensor_map_index.push_back(0); + } else { + tensor_map_index.push_back(MAP_NONE); + } + } + inputs_tensor_map_.push_back(tensor_map_index); + } + for (size_t i = 0; i < outputs_shape_.size(); i++) { + std::vector tensor_map_index; + for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { + if (i == 0 && j == 0) { + tensor_map_index.push_back(0); + } else { + tensor_map_index.push_back(MAP_NONE); + } + } + outputs_tensor_map_.push_back(tensor_map_index); + } 
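+  // Example (illustrative shapes): with dev_num_ = 8, an input shape of [32, 64] and an output
+  // shape of [32, 10], this yields inputs_tensor_map_ = {[0, -1]} and outputs_tensor_map_ =
+  // {[0, -1]}: only the batch dimension is bound to the single device-matrix axis (index 0),
+  // and every other dimension is marked MAP_NONE, i.e. replicated.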
+  return SUCCESS;
+}
+
+Strategys BatchParallelInfo::GetOutputsStrategy() {
+  Strategys outputs_strategy;
+
+  for (size_t i = 0; i < outputs_shape_.size(); ++i) {
+    std::vector<int32_t> strategy;
+    for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
+      if (i == 0 && j == 0) {
+        strategy.push_back(dev_num_);
+      } else {
+        strategy.push_back(1);
+      }
+    }
+    outputs_strategy.push_back(strategy);
+  }
+
+  return outputs_strategy;
+}
+
+Status BatchParallelInfo::InferTensorInfo() {
+  for (size_t i = 0; i < strategy_->GetInputNumber(); i++) {
+    MS_LOG(INFO) << name_ << " : The input size is " << strategy_->GetInputNumber();
+    TensorLayout tensor_layout_in;
+    if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) {
+      return FAILED;
+    }
+    TensorInfo tensor_info_in(tensor_layout_in);
+    inputs_tensor_info_.push_back(tensor_info_in);
+  }
+  for (size_t i = 0; i < outputs_shape_.size(); i++) {
+    TensorLayout tensor_layout_out;
+    if (tensor_layout_out.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(i), outputs_shape_.at(i)) !=
+        SUCCESS) {
+      return FAILED;
+    }
+    TensorInfo tensor_info_out(tensor_layout_out);
+    outputs_tensor_info_.push_back(tensor_info_out);
+  }
+  return SUCCESS;
+}
+
+Status BatchParallelInfo::GetAttrs() { return SUCCESS; }
+
+Status BatchParallelInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Init failed.";
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << " : Init success.";
+  return SUCCESS;
+}
+
+Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
+    }
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << " : Init for cost model success.";
+  return SUCCESS;
+}
+
+Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) {
+  CheckGlobalDeviceManager();
+  is_auto_parallel_ = true;
+  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+  StrategyPtr sp;
+  std::vector<Dimensions> strategy;
+  for (size_t i = 0; i < inputs_shape_.size(); i++) {
+    Shape temp(inputs_shape_[i].size(), 1);
+    if (split_flag_list_[i]) {
+      temp[0] = SizeToInt(total_dev_num);
+    }
+    strategy.push_back(temp);
+  }
+  sp = std::make_shared<Strategy>(stage_id, strategy);
+
+  if (SetCostUnderStrategy(sp) == SUCCESS) {
+    MS_LOG(INFO) << name_ << " : Successfully generated the batch-parallel strategy.";
+    PrintStrategy(sp);
+  } else {
+    MS_LOG(ERROR) << name_ << " : Generating the batch-parallel strategy failed.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
+  for (size_t i = 0; i < inputs_shape_.size(); i++) {
+    split_flag_list_[i] = true;
+  }
+}
+
+Status BatchParallelInfo::InferAsLossDivisor() {
+  as_loss_divisor_ = 1;
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
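BatchParallelInfo::GenerateStrategies above emits a single candidate: split the leading (batch) dimension of every splittable input across all devices and leave every other dimension whole. A standalone sketch of that construction (BuildBatchStrategy, the shapes and dev_num = 8 are illustrative assumptions, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;

// Build a batch-parallel strategy: dev_num on dimension 0 of each splittable input, 1 elsewhere.
std::vector<Shape> BuildBatchStrategy(const std::vector<Shape> &inputs_shape,
                                      const std::vector<bool> &split_flag, int32_t dev_num) {
  std::vector<Shape> strategy;
  for (size_t i = 0; i < inputs_shape.size(); ++i) {
    Shape temp(inputs_shape[i].size(), 1);  // replicate by default
    if (split_flag[i]) {
      temp[0] = dev_num;  // split only the batch dimension
    }
    strategy.push_back(temp);
  }
  return strategy;
}

int main() {
  std::vector<Shape> inputs = {{32, 64}, {64, 10}};
  std::vector<bool> split_flag = {true, false};  // only the first input carries the batch dim
  for (const auto &s : BuildBatchStrategy(inputs, split_flag, 8)) {
    for (auto d : s) std::cout << d << ' ';
    std::cout << '\n';  // prints: 8 1  then  1 1
  }
  return 0;
}

After SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList sets every flag, each input would instead receive dev_num on dimension 0.
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h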
b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h new file mode 100644 index 0000000000..0ba30c385a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class BatchParallelInfo : public OperatorInfo { + public: + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), + dev_num_(1) {} + + ~BatchParallelInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + Status InferAsLossDivisor() override; + + private: + int32_t dev_num_; +}; + +class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { + public: + SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, + const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; + void ReComputeBatchSplitFlagList() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc new file mode 100644 index 0000000000..e8b3afba16 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc @@ -0,0 +1,261 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/bias_add_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + std::vector stra = strategy->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + Dimensions sub_b_strategy = stra.at(1); + int32_t channel_a_strategy = sub_a_strategy.at(1); + int32_t channel_b_strategy = sub_b_strategy.at(0); + if (channel_a_strategy != channel_b_strategy) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + return SUCCESS; +} + +Status BiasAddInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + dev_matrix_shape_ = sub_a_strategy; + return SUCCESS; +} + +void BiasAddInfo::ReComputeBatchSplitFlagList() { + split_flag_list_[0] = true; + split_flag_list_[1] = false; +} + +Status BiasAddInfo::InferTensorMap() { + TensorMap sub_a_tensor_map; + TensorMap sub_b_tensor_map; + std::vector stra = strategy_->GetInputDim(); + Dimensions sub_a_strategy = stra.at(0); + size_t sub_a_strategy_size = sub_a_strategy.size(); + for (size_t i = 0; i < sub_a_strategy_size; ++i) { + sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); + } + sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); + + inputs_tensor_map_.push_back(sub_a_tensor_map); + inputs_tensor_map_.push_back(sub_b_tensor_map); + outputs_tensor_map_.push_back(sub_a_tensor_map); + + return SUCCESS; +} + +Status BiasAddInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + Shape input_b_tensor_map = inputs_tensor_map_.at(1); + std::vector input_a_group, input_b_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input b failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b; + if (input_a_group.empty() && input_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } + if (!input_a_group.empty()) { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + if (!input_b_group.empty()) { + op_for_input_b = 
CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); + } + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + + return SUCCESS; +} + +Status BiasAddInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, + const Shape &dev_matrix_array) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); + TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); + TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); + Shape input_a_shape_array = inputs_shape_.at(0); + Shape input_b_shape_array = inputs_shape_.at(1); + Shape out_shape_array = outputs_shape_.at(0); + + TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; + if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; + return FAILED; + } + if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; + return FAILED; + } + if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; + return FAILED; + } + inputs_layout->push_back(input_a_tensor_layout); + inputs_layout->push_back(input_b_tensor_layout); + outputs_layout->push_back(out_tensor_layout); + + return SUCCESS; +} + +Status BiasAddInfo::InferTensorInfo() { + // infer tensor shape + Shape input_a_shape = inputs_shape_.at(0); + Shape input_b_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_a_slice_shape = inputs_slice_shape.at(0); + Shape input_b_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; + return FAILED; + } + + TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); + TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); + TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a + inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b + outputs_tensor_info_.push_back(out_tensor_info); // output + + return SUCCESS; +} + +Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return 
FAILED; + } + + return SUCCESS; +} + +Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split, input0_split}; + + std::vector sp_vector; + is_auto_parallel_ = true; + Shapes tmp_inputs_shape = {inputs_shape_[0], inputs_shape_[0]}; + Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; + if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &sp_vector) != + SUCCESS) { + return FAILED; + } + MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; + + for (auto &sp : sp_vector) { + std::vector tmp_strategy; + Dimensions input0_strategy = sp->GetInputDim()[0]; + tmp_strategy.push_back(input0_strategy); // input0 + + Dimensions input1_strategy = {input0_strategy.at(1)}; + + // reset the strategy + tmp_strategy.push_back(input1_strategy); // input1 + sp->ResetInputs(tmp_strategy); + } + size_t success = 0; + for (auto &sp : sp_vector) { + PrintStrategy(sp); + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status BiasAddInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h new file mode 100644 index 0000000000..3ede65a3ba --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ + +#include + +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class BiasAddInfo : public OperatorInfo { + public: + BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~BiasAddInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status GetAttrs() override { return SUCCESS; } + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h new file mode 100644 index 0000000000..2829889846 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/comparison_function_info.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ + +#include +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class EqualInfo : public ArithmeticBase { + public: + EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~EqualInfo() override = default; +}; + +class NotEqualInfo : public ArithmeticBase { + public: + NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~NotEqualInfo() override = default; +}; + +class MaximumInfo : public ArithmeticBase { + public: + MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MaximumInfo() override = default; +}; + +class MinimumInfo : public ArithmeticBase { + public: + MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~MinimumInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc new file mode 100644 index 0000000000..3b411ccb0e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc @@ -0,0 +1,323 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/parallel/ops_info/dropout_do_mask_info.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "pipeline/jit/resource.h"
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+static int32_t SEED_NUM = 1;
+
+Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (strategy == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  if (stra.size() != 1) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1";
+    return FAILED;
+  }
+
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // only check the input[0]
+  Shapes input_shape = {inputs_shape_[0]};
+  if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Invalid strategy";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    }
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::InferDevMatrixShape() {
+  if (strategy_ == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+
+  std::vector<Dimensions> strategy = strategy_->GetInputDim();
+  if (strategy.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
+    return FAILED;
+  }
+
+  dev_matrix_shape_ = strategy[0];
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::InferTensorMap() {
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  std::vector<int32_t> tensor_map_index;
+  size_t size = inputs_shape_[0].size();
+  // e.g. if the dimension of the input is 4, tensor_map_index is [3, 2, 1, 0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_index.push_back(SizeToInt(size - i - 1));
+  }
+
+  // input[1] does not need a tensor map
+  inputs_tensor_map_.push_back(tensor_map_index);   // input_0
+  outputs_tensor_map_.push_back(tensor_map_index);  // output
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::InferTensorInfo() {
+  if (inputs_shape_.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": Invalid inputs shape size " << inputs_shape_.size();
+    return FAILED;
+  }
+
+  if (strategy_ == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+
+  Shape input_0_shape = inputs_shape_[0];
+
+  if (inputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
+    return FAILED;
+  }
+
+  TensorLayout input_0_tensor_layout;
+  if (input_0_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_0_shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout failed";
+    return FAILED;
+  }
+
+  TensorInfo input_0_tensor_info(input_0_tensor_layout);
+
+  // input_1 does not need tensor info
+  inputs_tensor_info_.push_back(input_0_tensor_info);   // input_0
+  outputs_tensor_info_.push_back(input_0_tensor_info);  // output
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) {
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  is_auto_parallel_ = true;
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  Shapes splittable_inputs = {input0_split};
+  Shapes used_inputs_shape = {inputs_shape_[0]};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Generate strategies failed";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+std::shared_ptr<std::vector<std::vector<int32_t>>> DropoutDoMaskInfo::GenerateBatchStrategies() {
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  Dimensions strategy(inputs_shape_[0].size() - 1, 1);
+  (void)strategy.insert(strategy.begin(), SizeToInt(dev_num));
+  std::vector<std::vector<int32_t>> strategy_v = {strategy};
+  return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
+}
+
+Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+
+PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+
+  AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (!dropout_gen_mask->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode";
+  }
+
+  auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
+  }
+  if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
+    MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not a primitive";
+  }
+
+  ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(value_node);
+  PrimitivePtr prim = value_node->value()->cast<PrimitivePtr>();
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() != DROPOUT_GEN_MASK) {
+    MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask";
+  }
+  return prim;
+}
+
+void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+
+  AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
+  MS_EXCEPTION_IF_NULL(dropout_gen_mask);
+  if (!dropout_gen_mask->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
+  }
+
+  auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
+  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
+  }
+
+  if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
+    MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not a ValueTuple.";
+  }
+
+  FuncGraphPtr func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
+  }
+
+  ValuePtr new_shape = MakeValue(input_slice_shape);
+  AnfNodePtr val = NewValueNode(new_shape);
+  (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
+}
+
+// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
+// split. Find the DropoutGenMask node in the anf graph according to the DropoutDoMask node, and modify the input shape
+// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
+// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
+std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+  std::vector<Operator> replace_ops;
+  MS_EXCEPTION_IF_NULL(cnode);
+  PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
+  MS_EXCEPTION_IF_NULL(prim);
+
+  if (inputs_tensor_info_.empty()) {
+    MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty";
+  }
+
+  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+
+  if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa<ValueNode>()) {
+    MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not a value node";
+  }
+
+  ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX));
+  MS_EXCEPTION_IF_NULL(keep_prob);
+  auto attr = prim->attrs();
+  if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
+    MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must have seed0 and seed1";
+  }
+
+  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
+  int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
+  int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
+  if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
+    seed_0 = SEED_NUM;
+    seed_1 = SEED_NUM;
+    SEED_NUM++;
+  } else {
+    SetGenMaskShape(cnode, input_slice_shape);
+    MS_LOG(DEBUG) << "The input slice shape of dropout is " << ShapeToString(input_slice_shape);
+    return replace_ops;
+  }
+
+  ValuePtr new_shape = MakeValue(input_slice_shape);
+  Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
+  Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
+  OperatorAttrs attrs = {attr_0, attr_1};
+  Attr param_0 = std::make_pair(SHAPE, new_shape);
+  Attr param_1 = std::make_pair(KEEP_PROB, keep_prob);
+  OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
+  replace_ops.push_back(replace_op);
+  return replace_ops;
+}
+}  // namespace parallel
+}  // namespace mindspore
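GetDropoutGenMaskReplaceOp above either rewrites DropoutGenMask's shape input in place or, when the node is repeat-calculated with both seeds left at 0, emits a replacement DropoutGenMask carrying fresh seeds. The seed bookkeeping in isolation (a sketch; PickSeeds is an illustrative name and SEED_NUM mirrors the file-local counter above, not a MindSpore API):

#include <cstdint>
#include <iostream>
#include <utility>

static int32_t SEED_NUM = 1;

// Returns the (seed0, seed1) pair DropoutGenMask should use. Fresh seeds are handed out only
// when both user seeds are 0 and the op is computed repeatedly (repeated_calc_num > 1), so the
// repeated instances do not all draw the same mask.
std::pair<int32_t, int32_t> PickSeeds(int32_t seed0, int32_t seed1, int32_t repeated_calc_num) {
  if (seed0 == 0 && seed1 == 0 && repeated_calc_num > 1) {
    seed0 = SEED_NUM;
    seed1 = SEED_NUM;
    SEED_NUM++;
  }
  return {seed0, seed1};
}

int main() {
  auto s1 = PickSeeds(0, 0, 2);  // (1, 1)
  auto s2 = PickSeeds(0, 0, 2);  // (2, 2): a new pair for the next such op
  auto s3 = PickSeeds(5, 7, 2);  // (5, 7): user-provided seeds are kept
  std::cout << s1.first << ' ' << s2.first << ' ' << s3.first << '\n';  // prints: 1 2 5
  return 0;
}

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h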
b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h new file mode 100644 index 0000000000..ea7d590071 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class DropoutDoMaskInfo : public OperatorInfo { + public: + DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~DropoutDoMaskInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + std::shared_ptr>> GenerateBatchStrategies() override; + std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorMap() override; + Status GetAttrs() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; +}; + +using DropoutDoMaskInfoPtr = std::shared_ptr; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h new file mode 100644 index 0000000000..e25da9e743 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/elementary_function_info.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ + +#include +#include +#include +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ExpInfo : public ActivationOther { + public: + ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ExpInfo() override = default; +}; + +class LogInfo : public ActivationOther { + public: + LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~LogInfo() override = default; +}; + +class CosInfo : public ActivationOther { + public: + CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~CosInfo() override = default; +}; + +class ACosInfo : public ActivationOther { + public: + ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~ACosInfo() override = default; +}; + +class LogicalNotInfo : public ActivationOther { + public: + LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} + ~LogicalNotInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc new file mode 100644 index 0000000000..4e6e947f68 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc @@ -0,0 +1,350 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/ops_info/gather_v2_info.h" + +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/costmodel.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/strategy.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status GatherV2Info::GetAttrs() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); + return FAILED; + } + if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) { + MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size(); + return FAILED; + } + // the second input is the index tensor + + // the third input is the axis, is a ValueNode + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + + if (inputs_shape_.at(0).size() == 0) { + MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; + return FAILED; + } + int axis = GetValue(input_value_.at(2)); + if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) { + MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", " + << inputs_shape_.at(0).size() << ")."; + } + if (axis < 0) { + axis += SizeToInt(inputs_shape_[0].size()); + } + axis_ = axis; + + index_size_ = inputs_shape_.at(1).size(); + + return SUCCESS; +} + +Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + // Only strategy of the first input should be set. + if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + axis_strategy_ = strategy->GetInputDim().at(0).at(axis_); + if (index_size_ != 1 && axis_strategy_ != 1) { + MS_LOG(ERROR) << name_ + << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy " + "corresponding to axis must be 1, but is " + << axis_strategy_; + return FAILED; + } + if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) { + MS_LOG(ERROR) << name_ + << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to " + "axis. 
The first dimension of index is " + << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_; + return FAILED; + } + return SUCCESS; +} + +Status GatherV2Info::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + dev_matrix_shape_ = stra.at(0); + return SUCCESS; +} + +// If index is a scalar, output dimension is input dimension minus 1; +// If index is a n dimension tensor, output dimension is input dimension plus (n - 1). +// Tensor map dimension is equal to the corresponding input and output dimension. +// If index's dimension is more than 1, we insert -1 for the output tensor map. +Status GatherV2Info::InferTensorMap() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + std::vector tensor_map_in; + std::vector tensor_map_out; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_in.push_back(SizeToInt(size - i - 1)); + tensor_map_out.push_back(SizeToInt(size - i - 1)); + } + + if (index_size_ == 0) { + (void)tensor_map_out.erase(tensor_map_out.begin() + axis_); + } else if (index_size_ > 1) { + (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1); + } + if (tensor_map_out.size() != outputs_shape_.at(0).size()) { + MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() + << " output size is " << outputs_shape_.at(0).size(); + return FAILED; + } + + std::vector tensor_map_in_index; + if (index_size_ >= 1) { + tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); + } + for (size_t i = 1; i < index_size_; ++i) { + tensor_map_in_index.push_back(-1); + } + inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status GatherV2Info::InferTensorInfo() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_shape_.size(); + return FAILED; + } + if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_tensor_map_.size(); + return FAILED; + } + if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " + << outputs_tensor_map_.size(); + return FAILED; + } + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + 
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { + return FAILED; + } + + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +OperatorVector CreateSubOp(int32_t sub_value) { + OperatorVector ops; + OperatorName operator_name = SUB; + OperatorAttrs operator_attrs; + + std::vector tensor_data = {sub_value}; + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, kInt32); + ValuePtr op_param_value = MakeValue(tensor_ptr); + + Attr op1_param = std::make_pair("", op_param_value); + OperatorParams operator_param = {std::make_pair(op1_param, 2)}; + + OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); + Operator op = std::make_pair(operator_name, operator_args); + ops.push_back(op); + return ops; +} + +Status GatherV2Info::InferTensorSubOps() { + sub_ops_.clear(); + if ((index_size_ == 0) || (axis_strategy_ == 1)) { + return SUCCESS; + } + int32_t mod_n = 1; + for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) { + mod_n *= dev_matrix_shape_.at(i); + } + if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) { + MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ")."; + } + int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_); + int32_t rank = g_device_manager->global_rank(); + int32_t mod_rank = rank % mod_p; + mod_rank = static_cast(mod_rank / mod_n); + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + return FAILED; + } + if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) { + MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ")."; + } + int32_t sub_value = static_cast(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank; + + OperatorVector sub_op; + sub_ops_.emplace_back(std::move(sub_op)); + sub_op = CreateSubOp(sub_value); + sub_ops_.emplace_back(std::move(sub_op)); + return SUCCESS; +} + +Status GatherV2Info::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + Status status = InferTensorSubOps(); + if (status != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed."; + return status; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status GatherV2Info::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or 
outputs shape size(" + << outputs_shape_.size() << "is wrong."; + return FAILED; + } + + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != + SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { + if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { + MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " + << inputs_shape_.size(); + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (GetAttrs() != SUCCESS) { + MS_LOG(EXCEPTION) << "GetAttrs failed!"; + } + + Dimensions strategy; + if (index_size_ != 1) { + strategy.push_back(1); + } else { + strategy.push_back(SizeToInt(dev_num)); + } + for (size_t i = 1; i < inputs_shape_[0].size(); i++) { + strategy.push_back(1); + } + std::vector strategy_v = {strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h new file mode 100644 index 0000000000..b3dc0fab87 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +constexpr size_t GATHER_V2_INPUTS_SIZE = 2; +constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; +constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; +// We now supported limited parallel strategies. +// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of +// the input. 
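+// For example (illustrative values only): with params of shape [8, 32], axis = 0 and a 1-dimensional index
+// of length 16, a params strategy of (4, 1) passes the divisibility check, since 16 % 4 == 0 lets every
+// device group read an equally sized slice of the index.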
+// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. +class GatherV2Info : public OperatorInfo { + public: + GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), + axis_(-1), + index_size_(0), + axis_strategy_(1) {} + ~GatherV2Info() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status InferTensorSubOps(); + + int32_t axis_; + size_t index_size_; + int32_t axis_strategy_; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc new file mode 100644 index 0000000000..eb3c9900f8 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -0,0 +1,636 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/gather_v2_p_info.h" + +#include +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/graph_util/generate_graph.h" + +namespace mindspore { +namespace parallel { +Status GatherV2PInfo::GetAttrs() { + // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. 
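+  // A negative axis counts from the last dimension, so for a 2-dimensional params tensor an axis of -1 is
+  // normalized to 1 below (axis += rank); the normalized value is what the tensor-map and bias inference
+  // later rely on.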
+ if (target_ != CPU) { + if (input_value_.at(2) == nullptr) { + MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; + return FAILED; + } + auto axis = GetValue(input_value_.at(2)); + // if axis is negative then convert it to positive + auto params_shape = inputs_shape_.at(0); + if (params_shape.size() == 0) { + MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; + return FAILED; + } + if (axis < 0) { + axis += SizeToInt(inputs_shape_[0].size()); + } + axis_ = axis; + } + + auto target_iter = attrs_.find(TARGET); + if (target_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(target_iter->second); + if (target_iter->second->isa()) { + target_ = target_iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of target is not a string."; + } + } + auto manual_split_iter = attrs_.find("manual_split"); + if (manual_split_iter != attrs_.end()) { + param_split_shapes_.clear(); + manual_split_ = true; + auto var = manual_split_iter->second->cast(); + MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString(); + + if (var->size() > 0) { + std::vector elements = var->value(); + for (auto &ele : elements) { + if (ele->isa()) { + auto value_tuple = ele->cast(); + std::vector value_vector = value_tuple->value(); + if (value_vector.size() != 2) { + MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; + return FAILED; + } + param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); + index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); + } else { + MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; + return FAILED; + } + } + + if (param_split_shapes_.empty()) { + MS_LOG(ERROR) << "Failed to extract param split strategy."; + return FAILED; + } + } + } + + return SUCCESS; +} + +Status GatherV2PInfo::CheckManualSplit() { + auto param_shape = inputs_shape_.at(0); + int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, + [](int32_t s, int32_t shape) { return s + shape; }); + if (split_shape_sum < param_shape.at(0)) { + MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; + return FAILED; + } + + if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { + MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; + return FAILED; + } + + return SUCCESS; +} + +Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + // param slice shape need 32Byte aligned + auto param_shape = inputs_shape_.at(0); + auto param_strategy = strategy->GetInputDim().at(0); + auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1); + if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) { + MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned."; + return FAILED; + } + + // only support 1-dim and 2-dim param + if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { + MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); + return FAILED; + } + + // don't support scalar index + if (inputs_shape_.at(1).size() == 0) { + 
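+    // Presumably a scalar index is rejected because it would drop the gathered dimension entirely,
+    // which the tensor-map inference below has no way to express.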
MS_LOG(DEBUG) << name_ << ": Don't support scalar index."; + return FAILED; + } + + // axis=0, index_shape(0)%param_strategy(0) must be 0 + Shape index_shape = inputs_shape_.at(1); + if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { + MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; + return FAILED; + } + + if (manual_split_) { + if (CheckManualSplit() != SUCCESS) { + return FAILED; + } + // when using manual_split, no need to check belowings. + return SUCCESS; + } + + // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 + if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { + MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; + return FAILED; + } + + // param_strategy(axis) != 1, index can't be splited + auto index_strategy = strategy->GetInputDim().at(1); + auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); + if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) { + MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; + return FAILED; + } + + // param_strategy(axis) != 1, Don't support repeated calc + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; + return FAILED; + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferMirrorOps() { + // There is no mirror operators for manual split + if (manual_split_) { + return SUCCESS; + } + + mirror_ops_.clear(); + Shape input_a_tensor_map = inputs_tensor_map_.at(0); + std::vector input_a_group; + if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group for input a failed."; + return FAILED; + } + + OperatorVector op_for_input_a, op_for_input_b, op_for_axis; + if (input_a_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror group is empty."; + return SUCCESS; + } else { + op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); + MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); + } + + mirror_ops_.push_back(op_for_input_a); + mirror_ops_.push_back(op_for_input_b); + mirror_ops_.push_back(op_for_axis); + + return SUCCESS; +} + +Status GatherV2PInfo::InferDevMatrixShape() { + dev_matrix_shape_.clear(); + out_dev_matrix_shape_.clear(); + // infer input dev_matrix_shape + auto param_strategy = strategy_->GetInputDim().at(0); + auto index_strategy = strategy_->GetInputDim().at(1); + + if (manual_split_) { + dev_matrix_shape_ = param_strategy; + out_dev_matrix_shape_ = dev_matrix_shape_; + return SUCCESS; + } + + dev_matrix_shape_ = param_strategy; + + // param_strategy(axis)!=1, + if (param_strategy.at(IntToSize(axis_)) != 1) { + std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end()); + } else { + dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); + } + + // infer out dev_matrix_shape + // axis!=0, split axis + if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) { + 
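+    // Illustrative case: param_strategy (2, 4) with axis_ == 1 yields out_dev_matrix_shape_ of
+    // (2 * 4, 1) = (8, 1); the split along axis is folded into dimension 0, which appears to match the
+    // ReduceScatter-based graph replacement that regroups the partial results when the axis dimension
+    // is split.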
out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_))); + for (size_t i = 1; i < param_strategy.size(); ++i) { + if (i == IntToSize(axis_)) { + out_dev_matrix_shape_.push_back(1); + } else { + out_dev_matrix_shape_.push_back(param_strategy.at(i)); + } + } + } else { + out_dev_matrix_shape_ = dev_matrix_shape_; + } + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); + if (param_product * index_product < SizeToInt(dev_num)) { + out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product))); + } + + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorMap() { + if (manual_split_) { + inputs_tensor_map_.push_back({1, 0}); + inputs_tensor_map_.push_back({-1, 1}); + outputs_tensor_map_.push_back({-1, 1, 0}); + return SUCCESS; + } + // infer input tensor map + // param_strategy(axis) != 1 + size_t param_size = inputs_shape_.at(0).size(); + size_t index_size = inputs_shape_.at(1).size(); + size_t total_size = param_size + index_size; + std::vector tensor_map_index; + std::vector tensor_map_params; + auto param_strategy = strategy_->GetInputDim().at(0); + if (param_strategy.at(IntToSize(axis_)) != 1) { + tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); + for (size_t i = 0; i < param_size; ++i) { + tensor_map_params.push_back(SizeToInt(i)); + } + } else { + // param_strategy(axis) == 1 + for (size_t i = 0; i < param_size; ++i) { + tensor_map_params.push_back(SizeToInt(total_size - i - 1)); + } + for (size_t i = 0; i < index_size; ++i) { + tensor_map_index.push_back(SizeToInt(index_size - i - 1)); + } + } + + // infer output tensor map + std::vector tensor_map_out; + if (param_strategy.at(IntToSize(axis_)) == 1) { + // param_strategy(axis) == 1 + for (size_t i = 0; i < param_size; ++i) { + if (i == IntToSize(axis_)) { + for (size_t j = 0; j < index_size; ++j) { + tensor_map_out.push_back(SizeToInt(index_size - j - 1)); + } + } else { + tensor_map_out.push_back(SizeToInt(total_size - i - 1)); + } + } + } else { + // param_strategy(axis) != 1 + if (axis_ == 0) { + tensor_map_out.insert(tensor_map_out.end(), 0); + tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); + for (size_t i = 1; i < param_size; ++i) { + tensor_map_out.push_back(i); + } + } else { + for (size_t i = 0; i < param_size; ++i) { + if (i == IntToSize(axis_)) { + tensor_map_out.insert(tensor_map_out.end(), index_size, -1); + } else { + tensor_map_out.push_back(SizeToInt(param_size - i - 1)); + } + } + } + } + + inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); + inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); + outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); + return SUCCESS; +} + +Status GatherV2PInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape input_index_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + int32_t rank = g_device_manager->global_rank(); + // infer tensor layout + TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; + if (manual_split_) { + input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]]; + input_shape[0] = input_shape[0] * dev_matrix_shape_[0]; + } + if 
((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || + (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != + SUCCESS)) { + return FAILED; + } + // infer tensor info + TensorInfo input_tensor_info(input_tensor_layout); + TensorInfo input_index_info(input_index_layout); + TensorInfo output_tensor_info(output_tensor_layout); + + Shape slice_shape = input_tensor_info.slice_shape(); + MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(input_index_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status GatherV2PInfo::InferBias() { + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + auto input_shape = inputs_shape_.at(0); + auto params_strategy = strategy_->GetInputDim().at(0); + // axis don't split + if (params_strategy.at(axis_) == 1) { + bias_ = 0; + return SUCCESS; + } + // params_size=1, axis=0 + if ((input_shape.size() == 1) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank * slice_size_; + return SUCCESS; + } + // params_size=2, axis=0 + if ((input_shape.size() == 2) && (axis_ == 0)) { + slice_size_ = input_shape.at(0) / params_strategy.at(0); + bias_ = rank / params_strategy.at(1) * slice_size_; + return SUCCESS; + } + // params_size=2, axis=1 + if ((input_shape.size() == 2) && (axis_ == 1)) { + slice_size_ = input_shape.at(1) / params_strategy.at(1); + bias_ = rank % params_strategy.at(1) * slice_size_; + return SUCCESS; + } + MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; + return FAILED; +} + +Status GatherV2PInfo::InferOffset() { + CheckGlobalDeviceManager(); + size_t rank = g_device_manager->global_rank(); + if (rank < index_offsets_.size()) { + index_offset_ = index_offsets_.at(rank); + MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; + return SUCCESS; + } + + MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size(); + return FAILED; +} + +Status GatherV2PInfo::InferGroup() { + auto param_strategy = strategy_->GetInputDim().at(0); + size_t dim = IntToSize(axis_); + if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { + dim = (axis_ + 1) % 2; + } + CheckGlobalDeviceManager(); + MS_EXCEPTION_IF_NULL(g_device_manager); + int32_t rank = g_device_manager->global_rank(); + RankList dev_list = g_device_manager->GetDeviceListByStageId(0); + DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Create group failed."; + return FAILED; + } + if (group_devices.size() == 1) { + MS_LOG(INFO) << "the group is empty"; + return SUCCESS; + } + + group_ = g_device_manager->CreateGroup(group_devices); + return SUCCESS; +} + +std::vector GetRankFromGroup(const Group &group) { + std::vector rank_list; + auto device_list = group.GetDevicesList(); + for (auto &device : device_list) { + rank_list.insert(rank_list.end(), device.rank() % 8); + } + return rank_list; +} + +Status GatherV2PInfo::InferForwardCommunication() { + forward_op_.clear(); + auto 
param_strategy = strategy_->GetInputDim().at(0); + // don't split axis or target is not CPU, no need forward communication + if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { + return SUCCESS; + } + // split axis + OperatorName operator_name; + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + Attr attr_group; + operator_name = REDUCE_SCATTER; + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + attr_group = std::make_pair(GROUP, MakeValue(group_.name())); + Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); + OperatorAttrs attrs = {attr_op, attr_group}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + Operator op = std::make_pair(operator_name, args); + + forward_op_.push_back(op); + return SUCCESS; +} + +Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { + GenerateGraph gen_g = GenerateGraph(); + if (gen_g.Init(cnode) != SUCCESS) { + MS_LOG(ERROR) << "GenerateGraph Init failed"; + return FAILED; + } + if (manual_split_) { + if (InferOffset() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Bias failed."; + return FAILED; + } + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)}); + auto gather_v2 = + gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)}); + std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, gather_v2)); + return SUCCESS; + } + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Bias failed."; + return FAILED; + } + auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); + auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); + auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); + auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); + auto gather_v2 = + gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); + auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); + auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); + auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); + auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); + // don't need expandim,if param_size = 1, + if (inputs_shape_.at(0).size() == 1) { + mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); + } + if (InferGroup() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer Group failed."; + return FAILED; + } + Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); + Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); + OperatorAttrs attrs = {attr_op, attr_group}; + auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); + std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; + replace_graph_ = std::make_shared>, AnfNodePtr>>( + std::make_pair(input_nodes, reduce_scatter)); + + return SUCCESS; +} + +ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { + if (manual_split_) { + if (ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return 
replace_graph_; + } + + auto param_strategy = strategy_->GetInputDim().at(0); + // target_ == CPU, no need to raplace graph + if (target_ == CPU) { + return nullptr; + } + if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return replace_graph_; +} + +Status GatherV2PInfo::ComputeReplaceOp() { + if (InferBias() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer offset failed."; + return FAILED; + } + OperatorName op_name = EMBEDDING_LOOKUP; + OperatorAttrs attrs; + Attr param_offset = std::make_pair("offset", MakeValue(bias_)); + OperatorParams params = {std::make_pair(param_offset, 3)}; + OperatorArgs args = std::make_pair(attrs, params); + Operator op = std::make_pair(op_name, args); + replace_op_.push_back(op); + + return SUCCESS; +} + +Status GatherV2PInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + // only target_ == CPU, we need to replace op + if (target_ == CPU && ComputeReplaceOp() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + auto param_strategy = strategy_->GetInputDim().at(0); + // cost model set axis and strategy + auto gatherv2_2cost = std::dynamic_pointer_cast(operator_cost()); + gatherv2_2cost->set_axis(axis_); + gatherv2_2cost->set_strategy(param_strategy); + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shape input1_split(inputs_shape_[1].size(), 1); + Shapes splittable_inputs = {input0_split, input1_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +std::shared_ptr>> GatherV2PInfo::GenerateBatchStrategies() { + CheckGlobalDeviceManager(); + size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + Dimensions param_strategy(inputs_shape_[0].size(), 1); + Dimensions index_strategy; + index_strategy.push_back(SizeToInt(dev_num)); + for (size_t i = 1; i < inputs_shape_[1].size(); i++) { + index_strategy.push_back(1); + } + std::vector strategy_v = {param_strategy, index_strategy}; + return std::make_shared>>(strategy_v); +} +} // namespace parallel +} // 
namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h new file mode 100644 index 0000000000..eb26c616d0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -0,0 +1,100 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GatherV2PInfo : public OperatorInfo { + public: + GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), + axis_(0), + bias_(0), + index_offset_(0), + slice_size_(0) {} + ~GatherV2PInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; + std::shared_ptr>> GenerateBatchStrategies() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + + private: + Status ComputeReplaceGraph(const CNodePtr &cnode); + Status CheckManualSplit(); + Status ComputeReplaceOp(); + Status InferBias(); + Status InferOffset(); + Status InferGroup(); + + int32_t axis_; + std::string target_ = DEVICE; + std::string replace_op_name_ = GATHERV2; + int32_t bias_; + int32_t index_offset_; + int32_t slice_size_; + Shape out_dev_matrix_shape_; + Group group_; + bool manual_split_ = false; + std::vector param_split_shapes_; + std::vector index_offsets_; +}; + +class SparseGatherV2Info : public GatherV2PInfo { + public: + SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} + ~SparseGatherV2Info() override = default; + + private: + std::string replace_op_name_ = SPARSE_GATHERV2; +}; + +class EmbeddingLookupInfo : public GatherV2PInfo { + public: + EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} + ~EmbeddingLookupInfo() override = default; +}; 
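+// SparseGatherV2Info and EmbeddingLookupInfo reuse GatherV2PInfo's strategy checking and tensor-map
+// inference unchanged; SparseGatherV2Info is intended to swap in SPARSE_GATHERV2 as the operator emitted
+// when the graph is rewritten, while EmbeddingLookupInfo follows the same logic for embedding lookups,
+// which carry no axis input in GetAttrs.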
+} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc new file mode 100644 index 0000000000..3606732156 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc @@ -0,0 +1,269 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/get_next_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status GetNextInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + for (auto shp : shapes_) { + TensorMap out_tensor_map; + for (size_t i = 0; i < shp.size(); ++i) { + if (full_batch) { + out_tensor_map.push_back(MAP_NONE); + } else { + out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); + } + } + outputs_tensor_map_.push_back(out_tensor_map); + } + return SUCCESS; +} + +Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { + if (outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << " : The layout is null."; + return FAILED; + } + for (size_t i = 0; i < outputs_shape_.size(); ++i) { + TensorLayout output_layout; + if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) { + return FAILED; + } + outputs_layout->push_back(output_layout); + } + return SUCCESS; +} + +Strategys GetNextInfo::GetOutputStrategy() { + Strategys outputs_strategy; + for (auto shp : shapes_) { + Dimensions out_strategy; + out_strategy.push_back(dev_num_); + for (size_t i = 1; i < shp.size(); ++i) { + out_strategy.push_back(1); + } + outputs_strategy.push_back(out_strategy); + } + return outputs_strategy; +} + +Status GetNextInfo::InferTensorInfo() { + TensorLayouts outputs_layout; + if (InferTensorLayout(&outputs_layout) != SUCCESS) { + return FAILED; + } + for (size_t i = 0; i < outputs_shape_.size(); ++i) { + TensorInfo output_tensor_info(outputs_layout[i]); + outputs_tensor_info_.push_back(output_tensor_info); + } + return SUCCESS; +} + +Status GetNextInfo::InferDevMatrixShape() { + size_t max_shape_length = 0; + for (auto shp : shapes_) { + if (max_shape_length < shp.size()) { + max_shape_length = shp.size(); + } + } + if (max_shape_length == 0) { + MS_LOG(ERROR) << name_ << " : shape is 0"; + } + dev_matrix_shape_.push_back(dev_num_); + for (size_t i = 1; i < max_shape_length; ++i) { + dev_matrix_shape_.push_back(1); + } + return SUCCESS; +} + +Status GetNextInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) 
<< name_ << " : Init failed"; + return FAILED; + } + if (InferReplaceOps(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer replace Ops failed"; + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init success"; + return SUCCESS; +} + +Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { + std::vector stras = strategy->GetInputDim(); + for (Dimensions stra : stras) { + if (stra.size() != 0) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << " : Invalid strategy."; + } + return FAILED; + } + } + int32_t stage = strategy->GetInputStage(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + dev_num_ = dev_num; + return SUCCESS; +} + +Status GetNextInfo::GetAttrTypes() { + auto iter = attrs_.find(TYPES); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto iter_cast = iter->second->cast(); + MS_EXCEPTION_IF_NULL(iter_cast); + auto types = iter_cast->value(); + for (auto &type : types) { + MS_EXCEPTION_IF_NULL(type); + types_.push_back(type->ToString()); + } + } else if (iter->second->isa()) { + auto iter_cast = iter->second->cast(); + MS_EXCEPTION_IF_NULL(iter_cast); + auto types = iter_cast->value(); + for (auto &type : types) { + MS_EXCEPTION_IF_NULL(type); + types_.push_back(type->ToString()); + } + } else { + MS_LOG(ERROR) << name_ << " : The value of types is not list."; + return FAILED; + } + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrShapes() { + shapes_ = outputs_shape_; + if (shapes_.size() == 0) { + MS_LOG(ERROR) << name_ << " : Shape is None."; + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrOutPutNum() { + auto iter = attrs_.find(GETNEXT_NUM); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + output_num_ = iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of output_num is not int."; + return FAILED; + } + } + return SUCCESS; +} + +Status GetNextInfo::GetAttrs() { + if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) { + return FAILED; + } + if (types_.size() != IntToSize(output_num_) || shapes_.size() != IntToSize(output_num_) || output_num_ == 0) { + MS_LOG(ERROR) << name_ << " : The output_num is not equal to shapes size."; + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + Shapes out_shapes = outputs_shape_; + for (size_t i = 0; i < out_shapes.size(); ++i) { + if (dev_num_ <= 0) { + MS_LOG(ERROR) << name_ << " : The dev num is 0."; + return FAILED; + } + if (out_shapes[i][0] % dev_num_ != 0) { + MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; + return FAILED; + } + if (!full_batch) { + out_shapes[i][0] = out_shapes[i][0] / dev_num_; + } + } + ValuePtr new_shapes = MakeValue(out_shapes); + Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); + Attr attr_shapes = std::make_pair(SHAPES, new_shapes); + Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]); + Attr attr_shared_name = std::make_pair(SHARED_NAME, attrs_[SHARED_NAME]); + OperatorAttrs attrs = {attr_types, attr_shapes, attr_num, attr_shared_name}; + OperatorParams params; + OperatorArgs args = std::make_pair(attrs, params); + replace_op_ = {std::make_pair(GET_NEXT, args)}; + return 
SUCCESS; +} + +Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} + +Status GetNextInfo::GenerateStrategies(int32_t stage_id) { + is_auto_parallel_ = true; + std::vector stra; + StrategyPtr sp = std::make_shared(stage_id, stra); + if (SetCostUnderStrategy(sp) == SUCCESS) { + MS_LOG(INFO) << name_ << " : Successfully generated strategy."; + PrintStrategy(sp); + } else { + MS_LOG(ERROR) << name_ << " : Generating strategy failed."; + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h new file mode 100644 index 0000000000..36e7a0fcb3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class GetNextInfo : public OperatorInfo { + public: + GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~GetNextInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *outputs_layout); + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferReplaceOps(const StrategyPtr &strategy); + Status GetAttrTypes(); + Status GetAttrShapes(); + Status GetAttrOutPutNum(); + Strategys GetOutputStrategy(); + Status InferAsLossDivisor() override { return SUCCESS; } + + private: + int32_t dev_num_ = 1; + std::vector types_; + Shapes shapes_; + int32_t output_num_ = 0; + std::string shared_name_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc new file mode 100644 index 0000000000..126fdcf84e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc @@ -0,0 +1,124 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/ops_info/l2_normalize_info.h" + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" + +namespace mindspore { +namespace parallel { +Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Invalid strategy."; + } else { + MS_LOG(INFO) << name_ << " : Init success."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + Dimensions input_strategy = stra.at(0); + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + axis_; + } + + if (input_strategy[IntToSize(axis_index)] != 1) { + MS_LOG(ERROR) << name_ << " : The dim " << axis_index << " of input strategy must be 1."; + return FAILED; + } + + return SUCCESS; +} + +Status L2NormalizeInfo::GetAttrs() { + auto iter = attrs_.find(AXIS); + if (iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + axis_ = iter->second->cast()->value(); + } else { + MS_LOG(ERROR) << name_ << " : The value of axis is not int."; + return FAILED; + } + } + + return SUCCESS; +} + +Status L2NormalizeInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Create group failed."; + return FAILED; + } + + OperatorVector op_for_weight; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group is " << input_group[0].name(); + } + + return SUCCESS; +} + +Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size() - 1, 1); + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_.at(0).size(); + axis_index = static_cast(input_dim) + axis_; + } + (void)input0_split.insert(input0_split.begin() + axis_index, 0); + Shapes splittable_inputs = {input0_split}; + + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h new file mode 100644 index 0000000000..c74dde4b4b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class L2NormalizeInfo : public Activation { + public: + L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : Activation(name, inputs_shape, outputs_shape, attrs) {} + ~L2NormalizeInfo() override = default; + Status GenerateStrategies(int32_t stage_id) override; + + protected: + Status GetAttrs() override; + Status InferMirrorOps() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + + private: + int32_t axis_ = 0; // Default value = 0 +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc new file mode 100644 index 0000000000..62d7c6d61e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc @@ -0,0 +1,324 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "frontend/parallel/ops_info/layer_norm_info.h"
+
+#include <memory>
+#include <vector>
+
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+Status LayerNormInfo::GetAttrs() {
+  auto iter = attrs_.find(BEGIN_NORM_AXIS);
+  if (iter == attrs_.end()) {
+    MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis";
+    return FAILED;
+  }
+  if ((iter->second == nullptr) || !iter->second->isa<Int32Imm>()) {
+    MS_LOG(ERROR) << name_ << ": The axis type is not int";
+    return FAILED;
+  }
+
+  int32_t dim = SizeToInt(input_shape_.size());
+  auto axis = GetValue<int32_t>(iter->second);
+  if ((axis >= dim) || (axis < -dim)) {
+    MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]";
+    return FAILED;
+  }
+
+  if (axis < 0) {
+    axis = axis + dim;
+  }
+  begin_norm_axis_ = IntToSize(axis);
+  return SUCCESS;
+}
+
+Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  if (stra.size() != LAYER_NORM_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
+    return FAILED;
+  }
+
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy value";
+    return FAILED;
+  }
+
+  Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX];
+  Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX];
+  Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX];
+  if (begin_norm_axis_ >= input_strategy.size()) {
+    MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
+    return FAILED;
+  }
+  // check input strategy
+  for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
+    if (input_strategy[i] != NO_SPLIT_STRATEGY) {
+      MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
+      return FAILED;
+    }
+  }
+
+  // check gamma and beta strategy
+  if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) {
+    MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is larger than input strategy";
+    return FAILED;
+  }
+
+  size_t gamma_diff = input_strategy.size() - gamma_strategy.size();
+  for (size_t j = 0; j < gamma_strategy.size(); ++j) {
+    if (gamma_strategy[j] != input_strategy[gamma_diff + j]) {
+      MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy);
+      return FAILED;
+    }
+  }
+
+  size_t beta_diff = input_strategy.size() - beta_strategy.size();
+  for (size_t k = 0; k < beta_strategy.size(); ++k) {
+    if (beta_strategy[k] != input_strategy[beta_diff + k]) {
+      MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy);
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InferDevMatrixShape() {
+  if (strategy_ == nullptr) {
+    MS_LOG(ERROR) << name_ << ": The strategy is null";
+    return FAILED;
+  }
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  if (stra.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
+    return FAILED;
+  }
+  dev_matrix_shape_ = stra[0];
+  return SUCCESS;
+}
+
+Status LayerNormInfo::CreateTensorMap(size_t input_index) {
+  if (inputs_shape_.size() <= input_index) {
+    MS_LOG(ERROR) << name_ << ": Invalid index " << input_index;
+    return FAILED;
+  }
+  Shape shape = inputs_shape_[input_index];
+  Shape tensor_map;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    tensor_map.push_back(SizeToInt(shape.size() - i - 1));
+  }
+  inputs_tensor_map_.push_back(tensor_map);
+  outputs_tensor_map_.push_back(tensor_map);
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InferTensorMap() {
+  if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
+      (CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
+    MS_LOG(ERROR) << name_ << ": Create tensor map failed";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::CreateMirrorOp(size_t input_index) {
+  if (inputs_tensor_map_.size() <= input_index) {
+    MS_LOG(ERROR) << name_ << ": Invalid index " << input_index;
+    return FAILED;
+  }
+  Shape tensor_map = inputs_tensor_map_[input_index];
+  std::vector<Group> group;
+  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed";
+    return FAILED;
+  }
+  OperatorVector mirror_op;
+  if (!group.empty()) {
+    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
+    MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is "
+                 << group[0].name();
+  }
+  mirror_ops_.push_back(mirror_op);
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InferMirrorOps() {
+  if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
+      (CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
+    MS_LOG(ERROR) << name_ << ": Create mirror op failed";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::CreateTensorInfo(size_t input_index) {
+  if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) {
+    MS_LOG(ERROR) << name_ << ": Invalid input index " << input_index;
+    return FAILED;
+  }
+  Shape tensor_map = inputs_tensor_map_[input_index];
+  Shape shape = inputs_shape_[input_index];
+  TensorLayout tensor_layout;
+  if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed";
+    return FAILED;
+  }
+
+  TensorInfo tensor_info(tensor_layout);
+  inputs_tensor_info_.push_back(tensor_info);
+  outputs_tensor_info_.push_back(tensor_info);
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InferTensorInfo() {
+  if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
+      (CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
+    MS_LOG(ERROR) << name_ << ": Create tensor info failed";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InferAsLossDivisor() {
+  if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is wrong";
+    return FAILED;
+  }
+  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
+  MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
+               << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
+               << ", as_loss_divisor_ is " << as_loss_divisor_;
+  return SUCCESS;
+}
+
+Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << " : Set cost failed";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) {
+  if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) {
+    MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is larger than input";
+    return FAILED;
+  }
+
+  size_t gamma_diff = input_shape_.size() - gamma_shape_.size();
+  size_t beta_diff = input_shape_.size() - beta_shape_.size();
+  for (auto &sp : sp_vector) {
+    if ((sp == nullptr) || sp->GetInputDim().empty()) {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy";
+      return FAILED;
+    }
+    std::vector<Dimensions> tmp_strategy;
+    Dimensions input_strategy = sp->GetInputDim()[0];
+    Dimensions gamma_strategy = input_strategy;
+    (void)gamma_strategy.erase(gamma_strategy.begin(),
+                               gamma_strategy.begin() + static_cast<std::ptrdiff_t>(gamma_diff));
+    Dimensions beta_strategy = input_strategy;
+    (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast<std::ptrdiff_t>(beta_diff));
+
+    // reset the strategy
+    tmp_strategy.push_back(input_strategy);
+    tmp_strategy.push_back(gamma_strategy);
+    tmp_strategy.push_back(beta_strategy);
+    sp->ResetInputs(tmp_strategy);
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::GenerateStrategies(int32_t stage_id) {
+  if (InitShapes() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init shapes failed";
+    return FAILED;
+  }
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Get attrs failed";
+    return FAILED;
+  }
+  Shape input_split(input_shape_.size(), SPLIT_FLAG);
+  if (begin_norm_axis_ >= input_split.size()) {
+    MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
+    return FAILED;
+  }
+
+  // Can not split the dimensions from begin norm axis
+  for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) {
+    input_split[i] = NO_SPLIT_FLAG;
+  }
+
+  // Generate strategy for input
+  Shapes splittable_inputs = {input_split};
+  Shapes tmp_inputs_shape = {input_shape_};
+  std::vector<StrategyPtr> sp_vector;
+  is_auto_parallel_ = true;
+  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Generate input strategy failed";
+    return FAILED;
+  }
+
+  // Generate the strategies for gamma and beta
+  if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed";
+    return FAILED;
+  }
+
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy";
+    }
+  }
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InitShapes() {
+  if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) {
+    MS_LOG(ERROR) << name_ << ": Invalid inputs size";
+    return FAILED;
+  }
+  input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX];
+  gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX];
+  beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX];
+  return SUCCESS;
+}
+
+Status LayerNormInfo::Init(const StrategyPtr &strategy) {
+  if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy) != SUCCESS)) {
+    MS_LOG(ERROR) << name_ << ": Init failed";
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init success";
+  return SUCCESS;
+}
+
+Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) {
+    MS_LOG(ERROR) << name_ << ": Init for cost model failed";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success";
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
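GenerateGammaAndBetaStrategies above derives the gamma and beta strategies by dropping the leading entries of the input strategy, mirroring how the parameters broadcast against the input's trailing dimensions. A self-contained sketch of that derivation (DeriveParamStrategy is a hypothetical helper, not part of this patch):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int32_t>;

// Gamma/beta broadcast against the input's trailing dimensions, so their shard
// strategy is the input strategy with the leading (input_rank - param_rank)
// entries removed.
Dimensions DeriveParamStrategy(const Dimensions &input_strategy, size_t param_rank) {
  const size_t diff = input_strategy.size() - param_rank;
  return Dimensions(input_strategy.begin() + static_cast<std::ptrdiff_t>(diff),
                    input_strategy.end());
}

int main() {
  // input [N, S, H] sharded (8, 1, 1); gamma/beta shaped [H] -> strategy (1)
  for (int32_t s : DeriveParamStrategy({8, 1, 1}, 1)) std::cout << s << ' ';
  std::cout << '\n';
}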
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h
new file mode 100644
index 0000000000..9ee11bb215
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr size_t LAYER_NORM_INPUT_SIZE = 3;
+constexpr size_t LAYER_NORM_INPUT_INDEX = 0;
+constexpr size_t LAYER_NORM_GAMMA_INDEX = 1;
+constexpr size_t LAYER_NORM_BETA_INDEX = 2;
+constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis";
+
+// The dimensions of the input tensor starting from begin norm axis cannot be split. Other dimensions can be split
+// arbitrarily. Gamma and beta must match the input to meet the broadcast requirements of mul and add.
+class LayerNormInfo : public OperatorInfo {
+ public:
+  LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                const PrimitiveAttrs &attrs)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
+        begin_norm_axis_(0) {}
+  ~LayerNormInfo() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+  Status GenerateStrategies(int32_t) override;
+  Status SetCostUnderStrategy(const StrategyPtr &) override;
+
+ protected:
+  Status GetAttrs() override;
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferMirrorOps() override;
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferAsLossDivisor() override;
+  Status CreateTensorMap(size_t input_index);
+  Status CreateTensorInfo(size_t input_index);
+  Status CreateMirrorOp(size_t input_index);
+  Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector);
+  Status InitShapes();
+
+ private:
+  size_t begin_norm_axis_;
+  Shape input_shape_;
+  Shape gamma_shape_;
+  Shape beta_shape_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
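The class comment above states the constraint: dimensions from begin_norm_axis onward must stay whole. A tiny sketch of the splittable mask that GenerateStrategies builds, assuming SPLIT_FLAG == 1 and NO_SPLIT_FLAG == 0 (BuildSplitMask is a hypothetical helper):

#include <cstdint>
#include <iostream>
#include <vector>

// Dimensions before begin_norm_axis may be sharded (flag 1); the normalized
// dimensions may not (flag 0).
std::vector<int32_t> BuildSplitMask(size_t rank, size_t begin_norm_axis) {
  std::vector<int32_t> mask(rank, 1);
  for (size_t i = begin_norm_axis; i < rank; ++i) mask[i] = 0;
  return mask;
}

int main() {
  // rank-3 input [N, S, H] with begin_norm_axis = 2: only N and S are splittable
  for (int32_t f : BuildSplitMask(3, 2)) std::cout << f << ' ';
  std::cout << '\n';  // 1 1 0
}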
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
new file mode 100644
index 0000000000..889f204fb0
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
@@ -0,0 +1,232 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/loss_info.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
+    }
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+  Dimensions label_strategy = stra.at(1);
+  if (input_strategy != label_strategy) {
+    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+    return FAILED;
+  }
+
+  int32_t axis_index = axis_;
+  if (axis_ < 0) {
+    size_t input_dim = inputs_shape_.at(0).size();
+    axis_index = static_cast<int32_t>(input_dim) + axis_;
+  }
+
+  int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index));
+  int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index));
+  // Dimension corresponding to axis is un-splittable
+  if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_
+                    << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy
+                    << ", label: " << label_axis_strategy;
+    } else {
+      MS_LOG(ERROR) << name_
+                    << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy
+                    << ", label: " << label_axis_strategy;
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() {
+  if ((inputs_shape_.size() != SoftmaxCrossEntropyWithLogitsInputsSize) ||
+      (outputs_shape_.size() != SoftmaxCrossEntropyWithLogitsOutputsSize)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+  dev_matrix_shape_ = input_strategy;
+  return SUCCESS;
+}
+
+Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() {
+  std::vector<int32_t> tensor_map_index;
+  size_t size = inputs_shape_[0].size();
+  // such as 4: tensor_map_index [3,2,1,0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_index.push_back((int32_t)(size - i - 1));
+  }
+
+  std::vector<int32_t> first_output_tensor_map = {tensor_map_index[0]};
+  inputs_tensor_map_.push_back(tensor_map_index);          // input
+  inputs_tensor_map_.push_back(tensor_map_index);          // label
+  outputs_tensor_map_.push_back(first_output_tensor_map);  // output-0
+  outputs_tensor_map_.push_back(tensor_map_index);         // output-1
+  return SUCCESS;
+}
+
+Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() {
+  // infer tensor shape
+  Shape input_shape = inputs_shape_.at(0);
+  Shape
first_output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape first_output_slice_shape = outputs_slice_shape.at(0); + + TensorMap input_tensor_map = inputs_tensor_map_.at(0); + TensorMap first_output_tensor_map = outputs_tensor_map_.at(0); + + TensorLayout input_tensor_layout, first_output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) || + (first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) != + SUCCESS)) { + return FAILED; + } + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); // input + inputs_tensor_info_.push_back(input_tensor_info); // label + outputs_tensor_info_.push_back(first_output_tensor_info); // output-0 + outputs_tensor_info_.push_back(input_tensor_info); // output-1 + + return SUCCESS; +} + +// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload the function. +Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.size() != 2) { + MS_LOG(ERROR) << name_ << " : The size of outputs tensor map " << outputs_tensor_map_.size() << " is error."; + return FAILED; + } + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[1]); + MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[1]) << ", as_loss_divisor_ is " + << as_loss_divisor_; + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); ++i) { + split_flag_list_[i] = true; + } +} + +Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + int32_t axis_index = axis_; + if (axis_ < 0) { + size_t input_dim = inputs_shape_[0].size(); + axis_index = static_cast(input_dim) + axis_; + } + is_auto_parallel_ = true; + + Shape input0_split; + (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); + input0_split[IntToSize(axis_index)] = 0; + Shapes splittable_inputs = {input0_split, 
input0_split}; + std::vector sp_vector; + if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Generate strategies failed."; + return FAILED; + } + + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + + return SUCCESS; +} + +Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + PrintStrategy(strategy); + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h new file mode 100644 index 0000000000..7e5478bedf --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +// infer shape: +// input_0 : [a, b], input_1 : [a, b] +// output_0 : [a], output_1: [a, b] +class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { + public: + SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, + std::make_shared(false)) {} + ~SoftmaxCrossEntropyWithLogitsInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload + // the InferAsLossDivisor. 
+  Status InferAsLossDivisor() override;
+
+ private:
+  int32_t axis_ = -1;  // default -1
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
new file mode 100644
index 0000000000..60a3d60b39
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
@@ -0,0 +1,647 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/matmul_info.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+namespace mindspore {
+namespace parallel {
+void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b,
+                       Shape *dev_matrix_shape) {
+  MS_EXCEPTION_IF_NULL(dev_matrix_shape);
+  size_t mat_a_size = mat_a_strategy.size();
+  size_t mat_b_size = mat_b_strategy.size();
+  if (mat_a_size >= mat_b_size) {
+    // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
+    // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
+
+    // [2],[4] in the example above
+    for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) {
+      dev_matrix_shape->push_back(mat_a_strategy.at(i));
+    }
+  } else {
+    // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32]
+    // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
+
+    // [2],[4] in the example above
+    for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) {
+      dev_matrix_shape->push_back(mat_b_strategy.at(i));
+    }
+  }
+
+  // [8],[16] in the example above
+  dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size)));
+  dev_matrix_shape->push_back(mat_a_strategy.back());
+
+  // [32] in the example above
+  if (!transpose_b) {
+    dev_matrix_shape->push_back(mat_b_strategy.back());
+  } else {
+    dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size)));
+  }
+}
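SetDevMatrixShape's composition rule can be checked against the worked example in its comments. A hypothetical standalone re-implementation (BuildDevMatrix is not part of this patch) that reproduces [2,4,8,16] x [4,16,32] -> [2,4,8,16,32]:

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int32_t>;

// Take the batch dims from the higher-rank strategy, then mat_a's row and
// reduce dims, then mat_b's output dim (its last dim, or second-from-last
// when transpose_b is set).
Shape BuildDevMatrix(const Shape &a, const Shape &b, bool transpose_b) {
  const Shape &longer = a.size() >= b.size() ? a : b;
  Shape dev(longer.begin(), longer.end() - 2);              // batch dims
  dev.push_back(a[a.size() - 2]);                           // mat_a row dim
  dev.push_back(a.back());                                  // reduce dim
  dev.push_back(transpose_b ? b[b.size() - 2] : b.back());  // mat_b output dim
  return dev;
}

int main() {
  for (int32_t d : BuildDevMatrix({2, 4, 8, 16}, {4, 16, 32}, false)) std::cout << d << ' ';
  std::cout << '\n';  // expected: 2 4 8 16 32
}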
+Status MatMulBase::GetAttrs() {
+  if (attrs_.size() < MATMUL_ATTRS_SIZE) {
+    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 2.";
+    return FAILED;
+  }
+
+  auto transpose_a_iter = attrs_.find(TRANSPOSE_A);
+  if (transpose_a_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(transpose_a_iter->second);
+    if (transpose_a_iter->second->isa<BoolImm>()) {
+      transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
+      return FAILED;
+    }
+  }
+
+  auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
+  if (transpose_b_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
+    if (transpose_b_iter->second->isa<BoolImm>()) {
+      transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of transpose_b is not bool.";
+      return FAILED;
+    }
+  }
+
+  auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
+  if (forward_reduce_scatter_iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
+    if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
+      forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
+      return FAILED;
+    }
+  }
+
+  // infer inputs dimension size
+  if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
+    return FAILED;
+  }
+  mat_a_dimension_ = inputs_shape_.at(0).size();
+  mat_b_dimension_ = inputs_shape_.at(1).size();
+
+  return SUCCESS;
+}
+
+Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) {
+  size_t long_size = long_strategy.size();
+  size_t short_size = short_strategy.size();
+  if (long_size < short_size) {
+    MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is "
+                  << short_size;
+    return FAILED;
+  }
+
+  size_t len_diff = long_size - short_size;
+  for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) {
+    if (long_strategy.at(len_diff + j) != short_strategy.at(j)) {
+      MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is "
+                    << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy);
+      return FAILED;
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
+    }
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  Dimensions mat_a_strategy = stra.at(0);
+  Dimensions mat_b_strategy = stra.at(1);
+
+  size_t mat_a_size = mat_a_strategy.size();
+  size_t mat_b_size = mat_b_strategy.size();
+  if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy are wrong.";
+    } else {
+      MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy are wrong.";
+    }
+    return FAILED;
+  }
+
+  // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
+  // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
+  // [16] in the example above
+  if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
+    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+    return FAILED;
+  } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
+    MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+    return FAILED;
+  }
+
+  if (mat_a_size >= mat_b_size) {
+    if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+      return FAILED;
+    }
+  } else {
+    if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
+      return FAILED;
+    }
+  }
+
+  if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) &&
forward_reduce_scatter_) { + MS_LOG(WARNING) << name_ + << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, " + "setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } + + return SUCCESS; +} + +Status MatMulBase::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions mat_a_strategy = stra.at(0); + Dimensions mat_b_strategy = stra.at(1); + + SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_); + return SUCCESS; +} + +// all-reduce weight's grad +Status MatMulBase::InferMirrorOps() { + mirror_ops_.clear(); + + Shape mat_b_tensor_map = inputs_tensor_map_[1]; + std::vector mat_b_group; + if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) { + return FAILED; + } + + OperatorVector op_for_inputs; // op_for_inputs is empty + OperatorVector op_for_weight; + + if (mat_b_group.empty()) { + MS_LOG(INFO) << name_ << " : The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_inputs); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name(); + } + + return SUCCESS; +} + +Status MatMulBase::InferForwardCommunication() { + forward_op_.clear(); + size_t dimension = dev_matrix_shape_.size(); + size_t relevant_dimension_index = SECOND_FROM_END(dimension); + // Relevant dimension is not split and all reduce is not required + if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { + MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; + return SUCCESS; + } + + std::vector group_list; + if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed."; + return FAILED; + } else if (group_list.empty()) { + MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; + return SUCCESS; + } + + Operator op; + if (forward_reduce_scatter_) { + op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name()); + } else { + op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); + } + + forward_op_.push_back(op); + MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); + return SUCCESS; +} + +Status MatMulBase::InferTensorMap() { + size_t size = dev_matrix_shape_.size(); + if (repeated_calc_num_ > 1) { + // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation + size = dev_matrix_shape_.size() - 1; + } + + std::vector tensor_map_index; + // such as 5: tensor_map_index [4,3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); + } + + // infer output tensor map: [4,3,2,0], delete the second-from-end element + TensorMap output_tensor_map = tensor_map_index; + (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast(SECOND_FROM_END(size))); + + // infer mat_a tensor map + // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1] + TensorMap mat_a_tensor_map = tensor_map_index; + // delete last one element + mat_a_tensor_map.pop_back(); + // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements + (void)mat_a_tensor_map.erase( + mat_a_tensor_map.begin(), + mat_a_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_a_dimension_)); + + // infer mat_b tensor map + TensorMap 
mat_b_tensor_map = tensor_map_index; + // delete the third-to-last element + (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast(THIRD_FROM_END(size))); + // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements + (void)mat_b_tensor_map.erase( + mat_b_tensor_map.begin(), + mat_b_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_b_dimension_)); + if (transpose_b_) { + // swap the last two elements + int32_t last_value = mat_b_tensor_map.back(); + mat_b_tensor_map.pop_back(); + (void)mat_b_tensor_map.insert( + mat_b_tensor_map.begin() + static_cast(LAST_INDEX(mat_b_tensor_map.size())), last_value); + } + + if (forward_reduce_scatter_) { + if (dev_matrix_shape_.size() != 3) { + MS_LOG(WARNING) << name_ + << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, " + "setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) { + MS_LOG(WARNING) << name_ + << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in " + "forward reduce scatter mode, setting the forward reduce scatter mode to false here"; + forward_reduce_scatter_ = false; + } else { + // the forward reduce scatter only support that the dimension of output is 2 + output_tensor_map = {1, 0}; + } + } + + inputs_tensor_map_.push_back(mat_a_tensor_map); + inputs_tensor_map_.push_back(mat_b_tensor_map); + outputs_tensor_map_.push_back(output_tensor_map); + return SUCCESS; +} + +Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + Shape output_dev_matrix_shape; + if (forward_reduce_scatter_) { + if (dev_matrix_shape_.size() != 3) { + MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode"; + return FAILED; + } + output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]}; + } else { + output_dev_matrix_shape = dev_matrix_shape_; + } + + TensorLayout mat_a_layout, mat_b_layout, output_layout; + if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || + (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || + (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { + return FAILED; + } + + inputs_layout->push_back(mat_a_layout); + inputs_layout->push_back(mat_b_layout); + outputs_layout->push_back(output_layout); + return SUCCESS; +} + +Status MatMulBase::InferTensorInfo() { + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + + TensorLayout mat_a_layout = inputs_layout.at(0); + TensorLayout mat_b_layout = inputs_layout.at(1); + TensorLayout output_layout = outputs_layout.at(0); + TensorInfo mat_a_tensor_info(mat_a_layout); + TensorInfo mat_b_tensor_info(mat_b_layout); + TensorInfo output_tensor_info(output_layout); + + inputs_tensor_info_.push_back(mat_a_tensor_info); + inputs_tensor_info_.push_back(mat_b_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status MatMulBase::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << " : Init failed."; + return FAILED; + } + + if (forward_reduce_scatter_) { + virtual_div_op_.clear(); + MS_LOG(INFO) << "The 
forward reduce scatter mode does not involve repeated calculation, clear the virtual div op"; + } + + MS_LOG(INFO) << name_ << " : Init success."; + return SUCCESS; +} + +Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << " : Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << " : Init for cost model success."; + return SUCCESS; +} + +Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { + if (input->size() < 2) { + MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; + return FAILED; + } + auto last_1st_value = input->at(input->size() - 1); + auto last_2nd_value = input->at(input->size() - 2); + input->pop_back(); + input->pop_back(); + input->push_back(last_1st_value); + input->push_back(last_2nd_value); + return SUCCESS; +} + +Status MatMulBase::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << " : GetAttrs failed."; + return FAILED; + } + CheckGlobalDeviceManager(); + std::vector dev_list = g_device_manager->GetDeviceListByStageId(stage_id); + size_t dev_num = dev_list.size(); + Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; + if (transpose_a_) { + if (SwapLastTwoElements(&input0_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + if (transpose_b_) { + if (SwapLastTwoElements(&input1_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + // The shape of input0 (input1) + // E.g., input0 = [100, 200, 300], input1 = [300, 400] + + // Combining the input0_shape and input1_shape + // E.g., combined_shape = [100, 200, 300, 400] + is_auto_parallel_ = true; + size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size(); + Dimensions combined_partitions; + Shape combined_shape; + // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2 + if (input0_shape.size() >= input1_shape.size()) { + combined_shape = input0_shape; + combined_shape.push_back(input1_shape[input1_shape.size() - 1]); + } else { + combined_shape = input1_shape; + combined_shape.push_back(input0_shape[input0_shape.size() - 2]); + } + std::function recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape, + &input1_shape_size, &recursive, &input0_shape_size, + this](uint32_t current_index, size_t n) { + // Finishing the recursive steps, if the strategy is valid, then calculate the cost + // for this operator under the strategy. 
+ if (current_index == combined_shape.size()) { + StrategyPtr sp; + if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) == + FAILED) { + return; + } + if (this->SetCostUnderStrategy(sp) == FAILED) { + MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed."; + return; + } + } else { + MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size + << ", input1_shape_size: " << input1_shape_size; + for (uint32_t i = 1; i <= n; i *= 2) { + if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) { + combined_partitions.push_back(i); + recursive(current_index + 1, n / i); + combined_partitions.pop_back(); + } + } + } + }; + recursive(0, dev_num); + if (strategy_cost_.empty()) { + MS_LOG(EXCEPTION) << name_ << " : No available strategy."; + } + return Status::SUCCESS; +} + +Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, + mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, + size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { + int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); + if (!FULLY_USE_DEVICES) { + if (IntToSize(product) > dev_num) { + return FAILED; + } + } else { + if (IntToSize(product) != dev_num) { + return FAILED; + } + } + Dimensions input0_partitions, input1_partitions; + if (input0_shape_size >= input1_shape_size) { + for (size_t i = 0; i < input0_shape_size; ++i) { + input0_partitions.push_back(combined_partitions[i]); + } + if (input1_shape_size == 2) { + input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]); + input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); + } else { + // input1_shape.size() > 2 + for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) { + if (j == combined_partitions.size() - 3) { + continue; + } + input1_partitions.push_back(combined_partitions[j]); + } + } + } else { + for (size_t i = 0; i < input1_shape_size; ++i) { + input1_partitions.push_back(combined_partitions[i]); + } + for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) { + input0_partitions.push_back(combined_partitions[j]); + } + input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); + input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]); + } + if (transpose_a_) { + if (SwapLastTwoElements(&input0_partitions) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + if (transpose_b_) { + if (SwapLastTwoElements(&input1_partitions) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + } + std::vector stras; + stras.push_back(input0_partitions); + stras.push_back(input1_partitions); + (*sp) = std::make_shared(stage_id, stras); + + return SUCCESS; +} + +void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { + TensorLayout tly; + if (transpose_a_) { + Shape replica_input0_shape(inputs_tensor_info_[0].shape()); + Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape()); + if (SwapLastTwoElements(&replica_input0_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + + TensorInfo 
replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape); + relica_inputs_tensor_vector->push_back(replica_input0_info); + } else { + relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]); + } + if (transpose_b_) { + Shape replica_input1_shape(inputs_tensor_info_[1].shape()); + Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape()); + if (SwapLastTwoElements(&replica_input1_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) { + MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; + } + + TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape); + relica_inputs_tensor_vector->push_back(replica_input1_info); + } else { + relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]); + } +} + +Status MatMulBase::CheckForTensorSliceValid() const { + if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { + return SUCCESS; + } + if (inputs_tensor_info_.empty()) { + return FAILED; + } + for (auto &one_input_tensor : inputs_tensor_info_) { + auto slice_shape = one_input_tensor.slice_shape(); + if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || + (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { + return FAILED; + } + } + return SUCCESS; +} + +Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (InitForCostModel(strategy) == FAILED) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; + } else { + MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed."; + } + return FAILED; + } + PrintStrategy(strategy); + // Check whether the tensor slice of input_tensor_info is valid or not + if (CheckForTensorSliceValid() != SUCCESS) { + MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy."; + return FAILED; + } + // Here, a replicated inputs_ is constructed for the transposed TensorInfo. + std::vector relica_inputs_tensor_vector; + InitTensorInfoForCost(&relica_inputs_tensor_vector); + + int32_t stage_id = strategy->GetInputStage(); + // Here, we use the origin outputs_, because we only use the slice size of the output tensor. + // It does not matter whether the output tensor is transposed or not. 
+  double computation_cost =
+    operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
+  result->communication_without_parameter_ =
+    operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
+  result->communication_with_partial_para_ =
+    result->communication_without_parameter_ +
+    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
+
+  // Breaking ties for preferring data parallelization
+  BreakingTiesForPerferringDataParallel(strategy, result);
+  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
+                << ", communication_cost: " << result->communication_cost_
+                << ", communication_without_parameter_: " << result->communication_without_parameter_
+                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
+  // refine communication cost calculation for practice
+  RefineForPracticalCost(result, false);
+  result->communication_forward_ = result->communication_without_parameter_;
+
+  std::shared_ptr<StrategyWithCost> swc =
+    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
+  swc->cost_list.push_back(result);
+  strategy_cost_.emplace_back(swc);
+
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
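The cost assembly above measures three quantities and derives a fourth: communication_with_partial_para_ interpolates between the communication cost without parameter synchronization and the full communication cost via COST_MODEL_GAMMA. A small numeric sketch, with all values invented for illustration:

#include <iostream>

int main() {
  const double gamma = 0.1;                // stand-in for COST_MODEL_GAMMA
  const double comm_cost = 40.0;           // all communication for the strategy
  const double comm_without_param = 24.0;  // communication excluding parameter sync
  const double comm_partial_para = comm_without_param + gamma * (comm_cost - comm_without_param);
  std::cout << comm_partial_para << '\n';  // 24 + 0.1 * 16 = 25.6
}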
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
new file mode 100644
index 0000000000..d4e144c2b6
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "common/utils.h"
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class MatMulBase : public OperatorInfo {
+ public:
+  MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatMulCost>(true)) {}
+  ~MatMulBase() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+
+  // Generate all strategies and the corresponding cost for this MatMul operator
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,
+                         size_t input1_shape_size, StrategyPtr *sp);
+
+  Status SwapLastTwoElements(Shape *shape);
+
+ protected:
+  Status InferMirrorOps() override;
+  Status InferForwardCommunication() override;
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout);
+  void InitTensorInfoForCost(std::vector<TensorInfo> *);
+  Status CheckForTensorSliceValid() const;
+  Status GetAttrs() override;
+
+  bool transpose_a_ = false;
+  bool transpose_b_ = false;
+  bool forward_reduce_scatter_ = false;
+  size_t mat_a_dimension_ = 0;
+  size_t mat_b_dimension_ = 0;
+};
+
+class MatMul : public MatMulBase {
+ public:
+  MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
+      : MatMulBase(name, inputs_shape, outputs_shape, attrs) {}
+  ~MatMul() override = default;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+};
+
+class MatMulInfo : public MatMul {
+ public:
+  MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : MatMul(name, inputs_shape, outputs_shape, attrs) {}
+  ~MatMulInfo() override = default;
+};
+
+class BatchMatMulInfo : public MatMul {
+ public:
+  BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                  const PrimitiveAttrs &attrs)
+      : MatMul(name, inputs_shape, outputs_shape, attrs) {}
+  ~BatchMatMulInfo() override = default;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_
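SwapLastTwoElements, declared above, normalizes transposed operands by swapping the strategy of the last two dimensions so the inference code can reason about the un-transposed layout. A minimal sketch of the same idea (SwapLastTwo is a hypothetical stand-in):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Swap the last two entries in place; the real method returns FAILED when
// there are fewer than two dimensions.
void SwapLastTwo(std::vector<int32_t> *shape) {
  if (shape->size() < 2) return;
  std::swap((*shape)[shape->size() - 1], (*shape)[shape->size() - 2]);
}

int main() {
  std::vector<int32_t> strategy = {2, 8, 4};  // strategy of a transpose_b operand
  SwapLastTwo(&strategy);
  for (int32_t s : strategy) std::cout << s << ' ';
  std::cout << '\n';  // 2 4 8
}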
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
new file mode 100644
index 0000000000..15acb085f5
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
@@ -0,0 +1,311 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/onehot_info.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
+#include "frontend/parallel/strategy.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status OneHotInfo::GetAttrs() {
+  auto iter = attrs_.find(AXIS);
+  if (iter != attrs_.end()) {
+    MS_EXCEPTION_IF_NULL(iter->second);
+    if (iter->second->isa<Int32Imm>()) {
+      axis_value_ptr_ = iter->second;
+      axis_ = iter->second->cast<Int32ImmPtr>()->value();
+    } else {
+      MS_LOG(ERROR) << name_ << ": The value of axis is not int.";
+      return FAILED;
+    }
+  }
+
+  if (inputs_shape_[0].size() != 1) {
+    MS_LOG(ERROR) << name_ << ": The input shape only supports 1-D now.";
+    return FAILED;
+  }
+
+  if ((axis_ > 1) || (axis_ < -1)) {
+    MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1].";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
+  if (inputs_shape_.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != 1) {
+    MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
+    return FAILED;
+  }
+  if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
+                         is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status OneHotInfo::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+
+  // Now the input only supports a 1-D tensor, so the output is a 2-D tensor.
+  // If the input is a vector of length features, the output shape will be:
+  // [features, depth] if axis == -1 (or axis == 1)
+  // [depth, features] if axis == 0
+  if (axis_ == 0) {
+    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
+    dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
+  } else {
+    dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
+    dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
+  }
+
+  return SUCCESS;
+}
+
+Status OneHotInfo::InferTensorMap() {
+  std::vector<int32_t> input_tensor_map_index, output_tensor_map_index;
+  size_t size = outputs_shape_[0].size();
+  // such as 2: tensor_map_index [1,0]
+  if (axis_ == 0) {
+    for (size_t i = 0; i < size; ++i) {
+      output_tensor_map_index.push_back((int32_t)(i));
+    }
+  } else {
+    for (size_t i = 0; i < size; ++i) {
+      output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i));
+    }
+  }
+  outputs_tensor_map_.push_back(output_tensor_map_index);
+
+  // Now the input only supports a 1-D tensor
+  input_tensor_map_index.push_back(1);
+
+  inputs_tensor_map_.push_back(input_tensor_map_index);
+  return SUCCESS;
+}
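InferTensorMap above reverses the output map when axis == 0 because the output layout becomes [depth, features] rather than [features, depth]. A standalone sketch (OutputMap is a hypothetical helper):

#include <cstdint>
#include <iostream>
#include <vector>

// axis == 0 keeps the map ascending ([0,1]); any other axis reverses it
// ([1,0]) so the features dimension stays mapped to the same device axis.
std::vector<int32_t> OutputMap(int32_t axis, size_t rank) {
  std::vector<int32_t> map;
  for (size_t i = 0; i < rank; ++i) {
    map.push_back(axis == 0 ? static_cast<int32_t>(i) : static_cast<int32_t>(rank - 1 - i));
  }
  return map;
}

int main() {
  for (int32_t m : OutputMap(-1, 2)) std::cout << m << ' ';  // 1 0
  std::cout << '\n';
  for (int32_t m : OutputMap(0, 2)) std::cout << m << ' ';   // 0 1
  std::cout << '\n';
}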
+
+// axis = -1
+// (0,(1,16),(),())  reid           dev_matrix=(1,16)   map_in=(1)  map_out=(1,0)
+// (0,(16,1),(),())  data parallel  dev_matrix=(16,1)   map_in=(1)  map_out=(1,0)
+// (0,(2,8),(),())   16 devices, two machines: model parallel among devices in the same machine,
+//                   data parallel between machines     dev_matrix=(2,8)    map_in=(1)  map_out=(1,0)
+// (0,(2,4),(),())   16 devices                         dev_matrix=(2,4,2)  map_in=(1)  map_out=(1,0)
+// axis = 0
+// (0,(16,1),(),())  reid           dev_matrix=(1,16)   map_in=(1)  map_out=(0,1)
+// (0,(1,16),(),())  data parallel  dev_matrix=(16,1)   map_in=(1)  map_out=(0,1)
+// (0,(8,2),(),())   16 devices, two machines: model parallel among devices in the same machine,
+//                   data parallel between machines     dev_matrix=(2,8)    map_in=(1)  map_out=(0,1)
+// (0,(4,2),(),())   16 devices                         dev_matrix=(2,4,2)  map_in=(1)  map_out=(0,1)
+Status OneHotInfo::InferTensorInfo() {
+  // infer tensor shape
+  Shape input_shape = inputs_shape_.at(0);
+  Shape output_shape = outputs_shape_.at(0);
+
+  TensorLayout input_tensor_layout, output_tensor_layout;
+  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
+      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout);
+  TensorInfo output_tensor_info(output_tensor_layout);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+
+  return SUCCESS;
+}
+
+Status OneHotInfo::ExtractInputInfo() {
+  CheckGlobalDeviceManager();
+  rank_ = g_device_manager->global_rank();
+  mod_rank_ = rank_ % dev_matrix_shape_.back();
+  if (!cnode_) {
+    MS_LOG(ERROR) << "Failure: OneHot cnode_ is nullptr";
+    return FAILED;
+  }
+  if (cnode_->inputs().size() != 5) {
+    MS_LOG(ERROR) << "Failure: The CNode corresponding to the OneHot primitive should have 5 inputs, but the real"
+                     " input size is "
+                  << cnode_->inputs().size();
+    return FAILED;
+  }
+  if (input_value_.size() != 4) {
+    MS_LOG(ERROR) << "Failure: The CNode corresponding to the OneHot primitive should have 5 inputs, and its input"
+                     " value size must be 4, but the real size is "
+                  << input_value_.size();
+    return FAILED;
+  }
+  auto value_ptr = input_value_.at(1);
+  if (value_ptr == nullptr) {
+    MS_LOG(WARNING) << "Input 2 of cnode is not a value node, its type is " << cnode_->input(2)->type_name();
+    return FAILED;
+  }
+
+  if (value_ptr->isa<Int32Imm>()) {
+    total_class_number_ = value_ptr->cast<Int32ImmPtr>()->value();
+  } else {
+    MS_LOG(ERROR) << "The depth of the OneHot primitive must be an int.";
+    return FAILED;
+  }
+  classes_each_device_ = total_class_number_ / dev_matrix_shape_.back();
+
+  return SUCCESS;
+}
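The ComputeReplaceGraph function below builds a small graph (FloorDiv/Mul/Sub/Equal/Cast/TensorAdd/OneHot) that remaps each global label to a device-local class index, or to -1 when another device owns that class (OneHot maps -1 to an all-zero row). An editorial sketch of the same arithmetic in plain C++ (LocalOneHotIndex is hypothetical; the real version is built as graph operators):

    #include <cstdint>
    #include <iostream>

    // Returns the local class index on this device, or -1 if the label's
    // one-hot column lives on another device.
    int32_t LocalOneHotIndex(int32_t label, int32_t classes_each_device, int32_t mod_rank) {
      int32_t owner = label / classes_each_device;           // FloorDiv
      int32_t offset = label - owner * classes_each_device;  // Sub(label, Mul(owner, classes))
      int32_t is_mine = (owner == mod_rank) ? 1 : 0;         // Equal + Cast
      return is_mine * (offset * is_mine + 1) - 1;           // Mul/TensorAdd/Mul/Sub chain below
    }

    int main() {
      // 4 classes per device; this device (mod_rank 1) owns classes [4, 8).
      std::cout << LocalOneHotIndex(5, 4, 1) << "\n";  // 1  -> one-hot column 1
      std::cout << LocalOneHotIndex(9, 4, 1) << "\n";  // -1 -> all-zero row
      return 0;
    }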
+Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  if (dev_matrix_shape_.back() == 1) {
+    replace_graph_ = nullptr;
+    return SUCCESS;
+  }
+  if (ExtractInputInfo() != SUCCESS) {
+    MS_LOG(ERROR) << "ExtractInputInfo failed";
+    return FAILED;
+  }
+  GenerateGraph gen_g = GenerateGraph();
+  Status status = gen_g.Init(cnode);
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed";
+    return FAILED;
+  }
+
+  auto floor_div =
+    gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)});
+  auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)});
+  auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1});
+  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)});
+  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)});
+  auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast});
+  auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)});
+  auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add});
+  auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)});
+  Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_);
+  OperatorAttrs attrs_onehot = {attr_onehot_axis};
+  auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_),
+                                cnode->input(3), cnode->input(4)});
+  std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)};
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, onehot));
+
+  return SUCCESS;
+}
+
+ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
+  if (ComputeReplaceGraph(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
+    return nullptr;
+  }
+  return replace_graph_;
+}
+
+Status OneHotInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+  Status status = ComputeReplaceGraph(cnode_);
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
+    return status;
+  }
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+
+Status OneHotInfo::GenerateStrategies(int32_t stage_id) {
+  Shapes splittable_inputs = {{1, 1}, {}, {}};
+  std::vector<StrategyPtr> sp_vector;
+  if (inputs_shape_.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != 1) {
+    MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
+    return FAILED;
+  }
+  is_auto_parallel_ = true;
+  if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
+                                             splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateStrategies failed.";
+    return FAILED;
+  }
+
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategies.";
+      PrintStrategy(sp);
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+std::shared_ptr<std::vector<std::vector<int32_t>>> OneHotInfo::GenerateBatchStrategies() {
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  Dimensions strategy = {SizeToInt(dev_num), 1};
+  Dimensions empty_strategy;
+  std::vector<Dimensions> strategy_v = {strategy, empty_strategy, empty_strategy};
+  return std::make_shared<std::vector<Dimensions>>(strategy_v);
+}
+} // namespace parallel
+} // namespace mindspore
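For context, the GenerateBatchStrategies override above produces a pure data-parallel strategy: only the 1-D label input is split, while depth/on_value/off_value carry empty strategies because they are scalars. A standalone restatement (editorial; OneHotBatchStrategy is hypothetical):

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int32_t>;

    std::vector<Dimensions> OneHotBatchStrategy(int32_t dev_num) {
      // Split the label's batch dimension across all devices; replicate the rest.
      return {{dev_num, 1}, {}, {}};
    }
    // OneHotBatchStrategy(8) -> {{8, 1}, {}, {}}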
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
new file mode 100644
index 0000000000..dfd7e6cbaf
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class OneHotInfo : public OperatorInfo {
+ public:
+  OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+             const PrimitiveAttrs &attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>(false)) {}
+  ~OneHotInfo() override = default;
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
+  std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status GetAttrs() override;
+  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status ExtractInputInfo();
+
+ private:
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
+
+  int axis_ = -1;
+  int32_t rank_ = 0;
+  int32_t total_class_number_ = 1;
+  int32_t classes_each_device_ = 1;
+  ValuePtr axis_value_ptr_;
+  int32_t mod_rank_ = 0;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_
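operator_info.cc below begins with CheckStrategyValue, which accepts a split factor only if it is a power of two and divides the corresponding dimension evenly. An editorial sketch of those two checks (IsValidSplit is hypothetical):

    #include <cstdint>

    bool IsValidSplit(int32_t dim, int32_t split) {
      if (split < 1) return false;
      // The bit trick below is the same one used in CheckStrategyValue.
      bool power_of_two = (static_cast<uint32_t>(split) & static_cast<uint32_t>(split - 1)) == 0;
      return power_of_two && (dim % split == 0);
    }
    // IsValidSplit(64, 8) -> true
    // IsValidSplit(64, 6) -> false (not a power of two)
    // IsValidSplit(30, 4) -> false (does not divide evenly)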
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
new file mode 100644
index 0000000000..3dd47b1de6
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -0,0 +1,1334 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/operator_info.h"
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "ir/dtype.h"
+#include "ir/tensor.h"
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "frontend/parallel/context.h"
+#include "utils/context/ms_context.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) {
+  if (strategy == nullptr) {
+    MS_LOG(ERROR) << "The strategy is null.";
+    return FAILED;
+  }
+
+  size_t strategy_size = strategy->GetInputNumber();
+  size_t inputs_shape_size = inputs_shape.size();
+  if (strategy_size != inputs_shape_size) {
+    if (is_auto_parallel) {
+      MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
+    } else {
+      MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
+    }
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  for (size_t i = 0; i < strategy_size; ++i) {
+    Shape sub_strategy = stra.at(i);
+    Shape sub_input_shape = inputs_shape.at(i);
+    size_t strategy_len = sub_strategy.size();
+    size_t inputs_len = sub_input_shape.size();
+    if (strategy_len != inputs_len) {
+      if (is_auto_parallel) {
+        MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
+                      << ", index: " << i;
+      } else {
+        MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
+                      << ", index: " << i;
+      }
+      return FAILED;
+    }
+
+    for (size_t j = 0; j < strategy_len; ++j) {
+      int32_t strategy_value = sub_strategy.at(j);
+      if (strategy_value < MIN_SLICE_NUM) {
+        if (is_auto_parallel) {
+          MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value;
+        } else {
+          MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value;
+        }
+        return FAILED;
+      }
+
+      if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) {
+        if (is_auto_parallel) {
+          MS_LOG(DEBUG) << "Invalid strategy value, it is not a power of 2: " << strategy_value;
+        } else {
+          MS_LOG(ERROR) << "Invalid strategy value, it is not a power of 2: " << strategy_value;
+        }
+        return FAILED;
+      }
+
+      int32_t shape_value = sub_input_shape.at(j);
+      if ((shape_value % strategy_value) != 0) {
+        if (is_auto_parallel) {
+          MS_LOG(DEBUG) << "Shape " << shape_value << " is not divisible by strategy " << strategy_value;
+        } else {
+          MS_LOG(ERROR) << "Shape " << shape_value << " is not divisible by strategy " << strategy_value;
+        }
+        return FAILED;
+      }
+    }
+  }
+
+  return SUCCESS;
+}
+
+void OperatorInfo::ResetQueueMember() {
+  inputs_tensor_info_.clear();
+  outputs_tensor_info_.clear();
+  inputs_tensor_map_.clear();
+  outputs_tensor_map_.clear();
+  dev_matrix_shape_.clear();
+  forward_op_.clear();
+  mirror_ops_.clear();
+  sub_ops_.clear();
+  replace_op_.clear();
+  replace_op_info_.clear();
+  virtual_div_op_.clear();
+  global_device_list_.clear();
+}
+
+Status OperatorInfo::InferAttrs() {
+  if (infer_attrs_completed_) {
+    return SUCCESS;
+  }
+
+  if (GetAttrs() != SUCCESS) {
+    return FAILED;
+  }
+  infer_attrs_completed_ = true;
+  return SUCCESS;
+}
+
+void OperatorInfo::SetDeviceListByStrategy() {
+  int32_t stage = strategy_->GetInputStage();
CheckGlobalDeviceManager(); + global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); +} + +Status OperatorInfo::InferRepeatedCalcInfo() { + int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); + int32_t dev_matrix_size = + std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + if (dev_matrix_size == 0) { + MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; + return FAILED; + } + + if (g_dev_list_size == dev_matrix_size) { + repeated_calc_num_ = 1; + } else if (g_dev_list_size % dev_matrix_size == 0) { + repeated_calc_num_ = g_dev_list_size / dev_matrix_size; + } else { + MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " + << dev_matrix_size; + return FAILED; + } + + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + int32_t stage = strategy_->GetInputStage(); + local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_); + + return SUCCESS; +} + +// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix, +// only use for infer tensor layout +void OperatorInfo::SetRepeatedCalcDevMatrix() { + if (repeated_calc_num_ <= 1) { + return; + } + + (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_); +} + +// use for loss repeated calculation +Operator CreateVirtualDivOp(int32_t div_num) { + OperatorName operator_name = VIRTUAL_DIV; + ValuePtr attr0_value = MakeValue(div_num); + Attr attr0 = std::make_pair(DIVISOR, attr0_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + return op; +} + +// use for forward all reduce +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { + OperatorName operator_name = ALL_REDUCE; + ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM + ValuePtr attr1_value = MakeValue(group); // group + Attr attr0 = std::make_pair(OP, attr0_value); + Attr attr1 = std::make_pair(GROUP, attr1_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group; + return op; +} + +Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) { + OperatorName operator_name = REDUCE_SCATTER; + ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM + ValuePtr attr1_value = MakeValue(group); // group + Attr attr0 = std::make_pair(OP, attr0_value); + Attr attr1 = std::make_pair(GROUP, attr1_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group; + return op; +} + +// use for get tensor slice +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { + Shape tensor_map = 
tensor_layout.tensor_map().array(); + Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); + OperatorName operator_name = GET_TENSOR_SLICE; + + OperatorAttrs attrs; + ValuePtr dev_mat_value = MakeValue(dev_matrix_shape); + Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2); + ValuePtr tensor_map_value = MakeValue(tensor_map); + Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3); + OperatorParams params = {dev_mat_param, tensor_map_param}; + OperatorArgs operator_arg = std::make_pair(attrs, params); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is " + << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map); + return op; +} + +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { + if ((dev_num == 0) || (dev_num == 1)) { + MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; + } + OperatorVector op_for_weight; + bool mean_flag = ParallelContext::GetInstance()->mirror_mean(); + + OperatorName operator_name = MIRROR_OPERATOR; + ValuePtr attr0_value = MakeValue(group_name); + ValuePtr attr1_value = MakeValue(SizeToInt(dev_num)); + ValuePtr attr2_value = MakeValue(mean_flag); + + Attr attr0 = std::make_pair(GROUP, attr0_value); + Attr attr1 = std::make_pair(DEV_NUM, attr1_value); + Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); + + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + operator_attrs.push_back(attr2); + + OperatorParams operator_param; + OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_args); + + op_for_weight.push_back(op); + MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is " + << mean_flag; + return op_for_weight; +} + +Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { + if (group == nullptr) { + MS_LOG(ERROR) << "The group is null."; + return FAILED; + } + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) { + return FAILED; + } + + if (group_devices.size() == 1) { + MS_LOG(INFO) << "The dev size is 1, no need to create group."; + return SUCCESS; + } + + Group g = g_device_manager->CreateGroup(group_devices); + group->push_back(g); + return SUCCESS; +} + +Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { + if (group == nullptr) { + MS_LOG(ERROR) << "The group is null."; + return FAILED; + } + CheckGlobalDeviceManager(); + int32_t rank = g_device_manager->global_rank(); + DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); + RankList group_devices; + if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { + return FAILED; + } + + if (group_devices.size() == 1) { + MS_LOG(INFO) << "The dev size is 1, no need to create group."; + return SUCCESS; + } + + Group g = g_device_manager->CreateGroup(group_devices); + group->push_back(g); + return SUCCESS; +} + +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { + Shape slice_shape; + if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { + 
MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; + return slice_shape; + } + for (size_t i = 0; i < strategy.size(); ++i) { + slice_shape.push_back(tensor_shape.at(i) / strategy.at(i)); + } + return slice_shape; +} + +Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { + if (slice_shapes == nullptr) { + MS_LOG(ERROR) << "The slice_shapes is null."; + return FAILED; + } + if (strategys.size() != shapes.size()) { + MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size(); + return FAILED; + } + + for (size_t i = 0; i < strategys.size(); ++i) { + if (strategys.at(i).size() != shapes.at(i).size()) { + MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension " + << shapes.at(i).size(); + slice_shapes->clear(); + return FAILED; + } + + for (size_t j = 0; j < shapes.at(i).size(); ++j) { + if (strategys.at(i).at(j) <= 0) { + MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i]) + << " the element is less than or equal to 0."; + slice_shapes->clear(); + return FAILED; + } + if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) { + MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : " + << strategys.at(i).at(j); + slice_shapes->clear(); + return FAILED; + } + } + Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i)); + slice_shapes->push_back(slice_shape); + } + + return SUCCESS; +} + +Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { + if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { + MS_LOG(ERROR) << "The slice_shape is null."; + return FAILED; + } + + if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << "Infer inputs slice shape error."; + return FAILED; + } + + if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) { + MS_LOG(ERROR) << "Infer outputs slice shape error."; + inputs_slice_shape->clear(); + return FAILED; + } + + return SUCCESS; +} + +// method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 +Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAttrs failed."; + return FAILED; + } + + // must be after InferAttrs() + if (CheckStrategy(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": CheckStrategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; + } + return FAILED; + } + + // need to clear queues before Init(), + // because Init() may be called multiple times by cost model + ResetQueueMember(); + + strategy_ = strategy; + SetDeviceListByStrategy(); + + if (InferDevMatrixShape() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; + return FAILED; + } + + used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); + + // must be after InferDevMatrixShape + if (InferRepeatedCalcInfo() != SUCCESS) { + MS_LOG(ERROR) << ": InferRepeatedCalcInfo failed."; + return FAILED; + } + + // if repeated calculation, need 
to set the repeated_calc_num as the first dimension of dev-matrix for layout + SetRepeatedCalcDevMatrix(); + + if (InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; + return FAILED; + } + + if (InferTensorInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; + return FAILED; + } + + return SUCCESS; +} + +// method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape +Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InferAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAttrs failed."; + return FAILED; + } + + // must be after InferAttrs() + if (CheckStrategy(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; + return FAILED; + } + + // need to clear queues before Init(), + // because Init() may be called multiple times by cost model + ResetQueueMember(); + + strategy_ = strategy; + SetDeviceListByStrategy(); + + if (InferDevMatrixShape() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; + return FAILED; + } + + // must be after InferDevMatrixShape + if (InferRepeatedCalcInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed."; + return FAILED; + } + + if (InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; + return FAILED; + } + + if (InferTensorInfo() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; + return FAILED; + } + + return SUCCESS; +} + +Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + return FAILED; + } + + if (InferForwardCommunication() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; + return FAILED; + } + + if (InferMirrorOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; + return FAILED; + } + + if (InferVirtualDivOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; + return FAILED; + } + + return SUCCESS; +} + +Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { + if (strategy == nullptr) { + MS_LOG(ERROR) << name_ << ": The strategy is null."; + return FAILED; + } + + if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { + return FAILED; + } + + if (InferForwardCommunication() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; + return FAILED; + } + + if (InferMirrorOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; + return FAILED; + } + + if (InferVirtualDivOps() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; + return FAILED; + } + + return SUCCESS; +} + +std::vector> OperatorInfo::GetAliveSuccEdges() { + std::vector> ret; + for (auto &edge : succ_edges_) { + if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { + ret.push_back(edge); + } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) { + // CAST is ordered in front of L2NORMALIZE + ret.push_back(edge); + } + } + for (auto &edge : succ_edges_) { + if ((edge->next_operator()->is_alive()) && 
(edge->next_operator()->name().find(RELU) == std::string::npos) &&
+        (edge->next_operator()->name().find(CAST) == std::string::npos)) {
+      ret.push_back(edge);
+    }
+  }
+  return ret;
+}
+
+std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() {
+  std::vector<std::shared_ptr<Edge>> ret;
+  for (auto &edge : prev_edges_) {
+    if (edge->prev_operator()->is_alive()) {
+      ret.push_back(edge);
+    }
+  }
+  return ret;
+}
+
+void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
+  if (op == nullptr) {
+    MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null.";
+    return;
+  }
+  for (auto &edge : prev_edges_) {
+    if (edge->prev_operator() == op) {
+      edge = new_edge;
+      return;
+    }
+  }
+  MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
+}
+
+void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
+  if (op == nullptr) {
+    MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null.";
+    return;
+  }
+  for (auto &edge : succ_edges_) {
+    if (edge->next_operator() == op) {
+      edge = new_edge;
+      return;
+    }
+  }
+  MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
+}
+
+void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
+  if (op == nullptr) {
+    MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null.";
+    return;
+  }
+  std::vector<std::shared_ptr<Edge>> new_pre_edges;
+  for (auto &edge : prev_edges_) {
+    if (edge->prev_operator() != op) {
+      new_pre_edges.push_back(edge);
+    }
+  }
+  new_pre_edges.push_back(new_edge);
+  prev_edges_ = new_pre_edges;
+}
+
+void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
+  if (op == nullptr) {
+    MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null";
+    return;
+  }
+  std::vector<std::shared_ptr<Edge>> new_succ_edges;
+  for (auto &edge : succ_edges_) {
+    if (edge->next_operator() != op) {
+      new_succ_edges.push_back(edge);
+    }
+  }
+  new_succ_edges.push_back(new_edge);
+  succ_edges_ = new_succ_edges;
+}
+
+std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySplitFlag(
+  const Shapes &shapes, const std::vector<bool> &split_flag_list) {
+  if (shapes.size() != split_flag_list.size()) {
+    MS_LOG(ERROR) << "Split_flag_list does not have the same size as inputs shape, " << split_flag_list.size()
+                  << " : " << shapes.size();
+    return nullptr;
+  }
+  CheckGlobalDeviceManager();
+  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+  std::vector<std::vector<int32_t>> strategy_v;
+  for (size_t i = 0; i != shapes.size(); i++) {
+    if (shapes[i].empty()) {
+      MS_LOG(INFO) << "Elements of shapes is empty.";
+      std::vector<int32_t> empty_element;
+      strategy_v.push_back(empty_element);
+    } else {
+      std::vector<int32_t> element(shapes[i].size(), 1);
+      if (split_flag_list[i]) {
+        element[0] = dev_num;
+      }
+      strategy_v.push_back(element);
+    }
+  }
+  return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
+}
+
+void OperatorInfo::ReComputeBatchSplitFlagList() {
+  if (!inputs_shape_.empty()) {
+    split_flag_list_[0] = true;
+  }
+}
+
+void OperatorInfo::ComputeBatchSplitFlagList() {
+  split_flag_list_.clear();
+  for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) {
+    split_flag_list_.push_back(false);
+  }
+  ReComputeBatchSplitFlagList();
+}
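A usage-style restatement of GenerateBatchStrategiesBySplitFlag above (editorial; BatchStrategies and the example shapes are hypothetical):

    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int32_t>;

    // Split dim 0 by dev_num when the flag is set; keep every other dimension whole.
    std::vector<Shape> BatchStrategies(const std::vector<Shape> &shapes, const std::vector<bool> &flags,
                                       int32_t dev_num) {
      std::vector<Shape> out;
      for (size_t i = 0; i < shapes.size(); ++i) {
        Shape s(shapes[i].size(), 1);
        if (!s.empty() && flags[i]) s[0] = dev_num;
        out.push_back(s);
      }
      return out;
    }
    // BatchStrategies({{32, 10}, {10}}, {true, false}, 8) -> {{8, 1}, {1}}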
+
+// This is a common method for checking whether the generated strategy has the correct number of devices.
+Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
+  if (sp == nullptr) {
+    MS_LOG(ERROR) << "The strategy is null.";
+    return FAILED;
+  }
+  int32_t product = 1;
+
+  for (auto &input_partition : inputs_partitions) {
+    product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int32_t>());
+  }
+  if (!FULLY_USE_DEVICES) {
+    if (IntToSize(product) > dev_num) {
+      return FAILED;
+    }
+  } else {
+    if ((product != 1) && (IntToSize(product) != dev_num)) {
+      return FAILED;
+    }
+  }
+  std::vector<Dimensions> stras(inputs_partitions);
+  (*sp) = std::make_shared<Strategy>(stage_id, stras);
+  return SUCCESS;
+}
+
+std::shared_ptr<std::vector<std::vector<int32_t>>> OperatorInfo::GenerateBatchStrategies() {
+  ComputeBatchSplitFlagList();
+  return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
+}
+
+void PrintStrategy(const StrategyPtr &strategy) {
+  if (strategy == nullptr) {
+    return;
+  }
+  std::string all_strategy = "";
+  for (size_t i = 0; i < strategy->GetInputNumber(); ++i) {
+    all_strategy += "[";
+    for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) {
+      all_strategy += std::to_string(strategy->GetInputDim()[i][j]);
+      if (j != strategy->GetInputDim()[i].size() - 1) {
+        all_strategy += ", ";
+      }
+    }
+    all_strategy += "]";
+    if (i != strategy->GetInputNumber() - 1) {
+      all_strategy += ", ";
+    }
+  }
+  MS_LOG(INFO) << "The strategy is: " << all_strategy;
+}
+
+// generate strategies for the case that each dimension of input0 and input1 is relevant, such as:
+// ([a, b, c, d], [a, b, c, d])
+Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape,
+                                           const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
+  if (sp_vector == nullptr) {
+    MS_LOG(ERROR) << "The sp_vector is null.";
+    return FAILED;
+  }
+
+  if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
+    MS_LOG(ERROR) << "The inputs size is wrong.";
+    return FAILED;
+  }
+
+  if ((inputs_shape[0].size() != inputs_shape[1].size()) ||
+      (splittable_inputs[0].size() != splittable_inputs[1].size())) {
+    MS_LOG(ERROR) << "The sizes of the two inputs are not equal.";
+    return FAILED;
+  }
+
+  Shapes input0_shape = {inputs_shape[0]};
+  Shapes input0_splittable = {splittable_inputs[0]};
+  if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) {
+    return FAILED;
+  }
+
+  for (auto &sp : *sp_vector) {
+    sp->ExpandInputDimFromOneToTwo();
+  }
+
+  return SUCCESS;
+}
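GenerateStrategiesForTwoEqualInputs above generates a strategy once for input0 and then duplicates it for the second input, which appears to be what ExpandInputDimFromOneToTwo does. An editorial sketch under that reading (ExpandToTwoInputs is hypothetical):

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int32_t>;

    // For element-wise ops, both inputs must be split identically.
    std::vector<Dimensions> ExpandToTwoInputs(const Dimensions &one_input_strategy) {
      return {one_input_strategy, one_input_strategy};
    }
    // ExpandToTwoInputs({4, 2}) -> {{4, 2}, {4, 2}}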
+
+// generate strategies for the case that input0 and input1 have relevant dimensions and input0 needs broadcasting,
+// such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
+Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
+                                          std::vector<StrategyPtr> *const sp_vector) {
+  if (sp_vector == nullptr) {
+    MS_LOG(ERROR) << "The sp_vector is null.";
+    return FAILED;
+  }
+
+  if (inputs_shape[0].size() >= inputs_shape[1].size()) {
+    MS_LOG(ERROR) << "Invalid inputs shape.";
+    return FAILED;
+  }
+
+  // first, generate strategy for input0 the same as input1
+  Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]};
+  Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]};
+  if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
+    return FAILED;
+  }
+
+  // second, get the correct strategy for input0
+  for (auto &sp : *sp_vector) {
+    std::vector<Dimensions> tmp_strategy;
+    Dimensions input0_strategy = sp->GetInputDim()[0];
+    size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
+
+    // erase the unnecessary part
+    (void)input0_strategy.erase(input0_strategy.begin(),
+                                input0_strategy.begin() + static_cast<different_type>(size_diff));
+
+    // handle cases like ([1, c, d], [a, b, c, d])
+    for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
+      if (inputs_shape[0][i] == 1) {
+        input0_strategy[i] = 1;
+      } else {
+        break;
+      }
+    }
+
+    // reset the strategy
+    tmp_strategy.push_back(input0_strategy);       // input0
+    tmp_strategy.push_back(sp->GetInputDim()[1]);  // input1
+    sp->ResetInputs(tmp_strategy);
+  }
+  return SUCCESS;
+}
+
+// generate strategies for the case that input0 and input1 have relevant dimensions and input1 needs broadcasting,
+// such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
+Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape,
+                                           const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
+  if (sp_vector == nullptr) {
+    MS_LOG(ERROR) << "The sp_vector is null.";
+    return FAILED;
+  }
+
+  if (inputs_shape[0].size() <= inputs_shape[1].size()) {
+    MS_LOG(ERROR) << "Invalid inputs shape.";
+    return FAILED;
+  }
+
+  // first, generate strategy for input1 the same as input0
+  Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]};
+  Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]};
+  if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
+    return FAILED;
+  }
+
+  // second, get the correct strategy for input1
+  for (auto &sp : *sp_vector) {
+    std::vector<Dimensions> tmp_strategy;
+    tmp_strategy.push_back(sp->GetInputDim()[0]);  // input0
+
+    Dimensions input1_strategy = sp->GetInputDim()[1];
+    size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size();
+
+    // erase the unnecessary part
+    (void)input1_strategy.erase(input1_strategy.begin(),
+                                input1_strategy.begin() + static_cast<different_type>(size_diff));
+
+    // handle cases like ([a, b, c, d], [1, c, d])
+    for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
+      if (inputs_shape[1][i] == 1) {
+        input1_strategy[i] = 1;
+      } else {
+        break;
+      }
+    }
+
+    // reset the strategy
+    tmp_strategy.push_back(input1_strategy);  // input1
+    sp->ResetInputs(tmp_strategy);
+  }
+  return SUCCESS;
+}
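The two broadcast helpers above share one alignment rule: drop the leading dimensions the shorter input lacks, then force leading size-1 (broadcast) dimensions to a split of 1. Editorial sketch (AlignBroadcastStrategy is hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int32_t>;
    using Shape = std::vector<int32_t>;

    Dimensions AlignBroadcastStrategy(Dimensions full, const Shape &small_shape) {
      size_t diff = full.size() - small_shape.size();
      full.erase(full.begin(), full.begin() + static_cast<std::ptrdiff_t>(diff));
      for (size_t i = 0; i < small_shape.size(); ++i) {
        if (small_shape[i] == 1) {
          full[i] = 1;  // cannot split a broadcast dimension
        } else {
          break;  // only leading broadcast dims are handled, as in the code above
        }
      }
      return full;
    }
    // AlignBroadcastStrategy({2, 4, 2, 1}, {1, 2, 1}) -> {1, 2, 1}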
+
+// generate strategies for the case that input0 and input1 have the same size, and input0 or input1 needs broadcasting,
+// such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
+Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
+                                          std::vector<StrategyPtr> *const sp_vector) {
+  if (sp_vector == nullptr) {
+    MS_LOG(ERROR) << "The sp_vector is null.";
+    return FAILED;
+  }
+
+  if (inputs_shape[0].size() != inputs_shape[1].size()) {
+    MS_LOG(ERROR) << "Invalid inputs shape.";
+    return FAILED;
+  }
+
+  // step1: ([a, 1], [1, b]) -> [a, b]
+  Shape max_shape, splittable_vector;
+  for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
+    if (inputs_shape[0][i] >= inputs_shape[1][i]) {
+      max_shape.push_back(inputs_shape[0][i]);
+      splittable_vector.push_back(splittable_inputs[0][i]);
+    } else {
+      max_shape.push_back(inputs_shape[1][i]);
+      splittable_vector.push_back(splittable_inputs[1][i]);
+    }
+  }
+
+  // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b])
+  Shapes tmp_inputs_shape = {max_shape, max_shape};
+  Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector};
+  if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
+    return FAILED;
+  }
+
+  // step3: reset the strategy if the dimension is 1
+  for (auto &sp : *sp_vector) {
+    Dimensions input0_strategy = sp->GetInputDim()[0];
+    Dimensions input1_strategy = sp->GetInputDim()[1];
+    for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
+      if (inputs_shape[0][i] == 1) {
+        input0_strategy[i] = 1;
+      }
+
+      if (inputs_shape[1][i] == 1) {
+        input1_strategy[i] = 1;
+      }
+    }
+    sp->ResetInputs({input0_strategy, input1_strategy});
+  }
+
+  return SUCCESS;
+}
+
+// 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
+// the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding
+// dimension is splittable. 'inputs_partitions' is the result of partitions.
+// NOTE: This implementation partitions all splittable dimensions in all inputs. Operators that require
+// identical partitions for specific input dimensions should provide their own implementation.
+Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape,
+                                              const Shapes &splittable_inputs,
+                                              std::vector<StrategyPtr> *const sp_vector) {
+  if (sp_vector == nullptr) {
+    MS_LOG(ERROR) << "The sp_vector is null.";
+    return FAILED;
+  }
+  if (splittable_inputs.size() != inputs_shape.size()) {
+    MS_LOG(ERROR) << "Splittable_inputs does not have the same input number as inputs shape, "
+                  << splittable_inputs.size() << " : " << inputs_shape.size();
+    return FAILED;
+  }
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
+
+  Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions;
+  for (size_t j = 0; j < inputs_shape.size(); ++j) {
+    (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end());
+    (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(),
+                                            splittable_inputs[j].end());
+  }
+  std::function<void(uint32_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape,
+                                                     &combined_splittable_inputs, &combined_partitions, &recursive,
+                                                     &inputs_shape](uint32_t current_index, size_t n) {
+    if (current_index == combined_inputs_shape.size()) {
+      MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size();
+      Shapes inputs_partitions;
+      size_t global_index = 0;
+      for (auto &shape : inputs_shape) {
+        Shape tmp_partition;
+        for (size_t j = 0; j < shape.size(); ++j) {
+          tmp_partition.push_back(combined_partitions[global_index]);
+          global_index++;
+        }
+        inputs_partitions.push_back(tmp_partition);
+      }
+      StrategyPtr sp;
+      if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) {
+        sp_vector->push_back(sp);
+      }
+      return;
+    } else {
+      MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size();
+      if (combined_splittable_inputs[current_index] == 0) {
+        combined_partitions.push_back(MIN_SLICE_NUM);
+        recursive(current_index + 1, n / MIN_SLICE_NUM);
+        combined_partitions.pop_back();
+      } else if (combined_splittable_inputs[current_index] == 1) {
+        for (uint32_t i = 1; i <= n; i *= 2) {
+          if (n % i == 0 && IntToSize(combined_inputs_shape[current_index]) % i == 0) {
+            combined_partitions.push_back(i);
+            recursive(current_index + 1, n / i);
combined_partitions.pop_back(); + } + } + } + } + }; + recursive(0, dev_num); + if (sp_vector->empty()) { + MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo."; + } + return SUCCESS; +} + +// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, +// and the corresponding dimensions that are not broadcast are all relevant dimensions +// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) +// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) +// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { + if (sp_vector == nullptr) { + MS_LOG(ERROR) << "The sp_vector is null."; + return FAILED; + } + + if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { + MS_LOG(ERROR) << "The inputs' size is wrong."; + return FAILED; + } + + if (inputs_shape[0] == inputs_shape[1]) { + // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy + if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success."; + } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) { + // ([a, b, c, d], []) or ([], [a, b, c, d]) + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "Generate strategies for scalar case failed."; + return FAILED; + } + MS_LOG(INFO) << "Generate strategies for scalar case success."; + } else if (inputs_shape[0].size() > inputs_shape[1].size()) { + // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) + if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success."; + } else if (inputs_shape[0].size() < inputs_shape[1].size()) { + // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) + if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success."; + } else { // same size, but different value + // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) + if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { + MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed."; + return FAILED; + } + MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success."; + } + return SUCCESS; +} + +Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { + if (InitForCostModel(strategy) == FAILED) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed."; + } + return FAILED; + } + int32_t stage_id = strategy->GetInputStage(); + double computation_cost = + operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + double communication_cost = 
operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + std::shared_ptr result = std::make_shared(computation_cost, communication_cost); + result->communication_without_parameter_ = + operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); + result->communication_with_partial_para_ = + result->communication_without_parameter_ + + COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); + + // Breaking ties for preferring data parallelization + BreakingTiesForPerferringDataParallel(strategy, result); + // refine communication cost calculation for practice + RefineForPracticalCost(result, false); + result->communication_forward_ = result->communication_without_parameter_; + + std::shared_ptr swc = + std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); + swc->cost_list.push_back(result); + strategy_cost_.emplace_back(swc); + + return SUCCESS; +} + +int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { + if (is_output_parameter_involve_ != -1) { + return is_output_parameter_involve_; + } + is_parameter_involve_ = is_parameter_; + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { + auto input_index = p_edge->next_op_input_index(); + auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); + if (input_index >= is_parameter_involve_.size()) { + MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size() + << ", but got wrong input_index: " << input_index; + } + if (prev_op_para == 0) { + is_parameter_involve_[input_index] = false; + } else if (prev_op_para == 1) { + is_parameter_involve_[input_index] = true; + } else { + MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index; + } + p_edge->set_parameter_involve(prev_op_para); + } + if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) { + // If anyone of the input is a parameter_involved, the output is parameter_involved. + is_output_parameter_involve_ = 1; + } else { + is_output_parameter_involve_ = 0; + } + + return is_output_parameter_involve_; +} + +Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { + if (is_parameter.size() != inputs_shape_.size()) { + MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() + << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); + return FAILED; + } + is_parameter_ = is_parameter; + operator_cost()->set_is_parameter(is_parameter); + return SUCCESS; +} + +Status OperatorInfo::CalculateMemoryCost() { + // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary to + // calculate memory cost. + if (is_parameter_involve_.size() != is_parameter_.size()) { + MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'."; + return FAILED; + } + operator_cost()->set_is_parameter_involve(is_parameter_involve_); + operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); + // Set the memory cost in the 'strategy_cost_' + for (auto &swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + +Status OperatorInfo::CalculateMemoryCostForInference() { + // First, set the 'is_outputs_critical_' flag into OperatorCost. 
+ if (is_output_critical_ == -1) { + MS_LOG(EXCEPTION) << "The critical flag is not set."; + return FAILED; + } + operator_cost()->set_output_critical(is_output_critical_); + // Set the memory cost in the 'strategy_cost_' + for (auto &swc : strategy_cost_) { + auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr); + swc->cost_list[0]->memory_with_reuse_ = mem_cost; + } + return SUCCESS; +} + +Status OperatorInfo::CorrectMemoryCost(size_t input_index) { + for (auto &swc : strategy_cost_) { + double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * + static_cast(operator_cost()->inputs_type_lengths()[input_index]); + swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; + if (swc->cost_list[0]->memory_with_reuse_ < 0) { + MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_ + << ", the parameter memory cost is: " << parameter_mem_cost; + return FAILED; + } + } + return SUCCESS; +} + +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { + int32_t ret = -1; + + // The number of repetitions is equal to the number of all devices divided by the number of devices use for + // tensor map. + int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); + for (auto &element : tensor_map) { + // -1 means the corresponding dimension is not split. + if (element == MAP_NONE) { + continue; + } else if ((element < 0) || (IntToSize(element) >= dev_matrix_shape.size())) { + MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is " + << ShapeToString(dev_matrix_shape); + return ret; + } else { + size_t index = dev_matrix_shape.size() - IntToSize(element) - 1; + if (dev_matrix_shape[index] <= 0) { + MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape); + return ret; + } + device_num /= dev_matrix_shape[index]; + } + } + + return device_num; +} + +Status OperatorInfo::InferAsLossDivisor() { + if (!ParallelContext::GetInstance()->loss_repeated_mean()) { + as_loss_divisor_ = 1; + return SUCCESS; + } + + if (outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; + return FAILED; + } + + if (outputs_tensor_map_.size() > 1) { + MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size() + << ", need to override this function "; + return FAILED; + } + + if (outputs_tensor_map_[0].empty()) { + as_loss_divisor_ = SizeToInt(global_device_list_.size()); + MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor."; + return SUCCESS; + } + + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_) + << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is " + << as_loss_divisor_; + return SUCCESS; +} + +// If the operator is used as a loss, a div node is inserted for the grad of all its inputs. 
+Status OperatorInfo::InferVirtualDivOps() { + if (InferAsLossDivisor() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed."; + return FAILED; + } + + if (as_loss_divisor_ <= 0) { + MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_; + return FAILED; + } else if (as_loss_divisor_ == 1) { + MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op."; + return SUCCESS; + } + + virtual_div_op_.clear(); + // if loss is repeated calculation, insert div op + Operator op = CreateVirtualDivOp(as_loss_divisor_); + virtual_div_op_.push_back(op); + return SUCCESS; +} + +Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { + if (input_lengths.size() != inputs_shape_.size()) { + MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() + << " do not have the same number of inputs shape: " << inputs_shape_.size(); + return FAILED; + } + if (output_lengths.size() != outputs_shape_.size()) { + MS_LOG(ERROR) << "Output_lengths: " << output_lengths.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + return FAILED; + } + inputs_type_lengths_ = input_lengths; + outputs_type_lengths_ = output_lengths; + operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); + return SUCCESS; +} + +double OperatorInfo::GetOutputsTotalSize() { + if (is_calculated_outputs_size_) { + return outputs_total_size_; + } + if (outputs_type_lengths_.size() != outputs_shape_.size()) { + MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + } + double sum = 0.0; + for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { + auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast(1.0), + std::multiplies()); + sum += size * static_cast(outputs_type_lengths_[i]); + } + is_calculated_outputs_size_ = true; + outputs_total_size_ = sum; + return outputs_total_size_; +} + +Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { + if (outputs_type.size() != outputs_shape_.size()) { + MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() + << " do not have the same number of outputs shape: " << outputs_shape_.size(); + return FAILED; + } + outputs_type_ = outputs_type; + return SUCCESS; +} + +void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { + if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { + CheckGlobalDeviceManager(); + auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); + if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { + if (cost->computation_cost_ > 1.0) { + cost->computation_cost_ -= 1.0; + } + if (cost->communication_cost_ > 1.0) { + cost->communication_cost_ -= 1.0; + } + if (cost->communication_with_partial_para_ > 1.0) { + cost->communication_with_partial_para_ -= 1.0; + } + if (cost->communication_without_parameter_ > 1.0) { + cost->communication_without_parameter_ -= 1.0; + } + } + } +} + +double OperatorInfo::GetForwardMemoryCostFromCNode() { + return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); +} + +void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { + MS_EXCEPTION_IF_NULL(s_strategy); + if (!s_strategy->IsEqual(selected_strategy_)) { + MS_LOG(INFO) << name() << "'s strategy may cause 
suboptimal performance, the determined strategy:";
+    PrintStrategy(selected_strategy_);
+    MS_LOG(INFO) << "The minimal strategy:";
+    PrintStrategy(s_strategy);
+  }
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
new file mode 100644
index 0000000000..8641c47491
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -0,0 +1,289 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "common/utils.h"
+#include "base/base.h"
+#include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/group_manager.h"
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_info.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+using ForwardOp = OperatorVector;
+using MirrorOps = std::vector<OperatorVector>;
+using Ops = std::vector<OperatorVector>;
+using VirtualDivOp = OperatorVector;
+using TensorMaps = std::vector<std::vector<int32_t>>;
+using TensorLayouts = std::vector<TensorLayout>;
+using different_type = std::vector<int32_t>::difference_type;
+using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>;
+using Strategys = std::vector<Dimensions>;
+using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>;
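To make the aliases concrete: a Strategys value holds one Dimensions entry per operator input, with one split factor per dimension. A tiny editorial example (not part of the patch):

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int32_t>;
    using Strategys = std::vector<Dimensions>;  // mirrors the alias above

    int main() {
      // input0 [N, C] split 4x2 across 8 devices; input1 replicated.
      Strategys strategy = {{4, 2}, {1, 1}};
      (void)strategy;
      return 0;
    }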
+  Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
+  const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
+  virtual Status Init(const StrategyPtr &strategy) = 0;
+  virtual Status InitForCostModel(const StrategyPtr &strategy) = 0;  // only init the necessary parts
+
+  // Given the stage_id (which indicates the number of devices),
+  // generate all strategies for this operator
+  virtual Status GenerateStrategies(int32_t stage_id) = 0;
+  const OperatorCostPtr &operator_cost() const { return operator_cost_; }
+  void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; }
+  virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0;
+
+  virtual std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies();
+  virtual void ReComputeBatchSplitFlagList();
+  void ComputeBatchSplitFlagList();
+
+  double GetForwardMemoryCostFromCNode();
+  // This is a common method for setting the operator cost for a given strategy; the validity of the strategy
+  // is checked here as well
+  Status SetCostUnderStrategyBase(const StrategyPtr &strategy);
+  std::vector<std::shared_ptr<StrategyWithCost>> GetStrategyCost() { return strategy_cost_; }
+  // In the training phase, when an input of an operator is a WEIGHT, or an output of another operator that involves
+  // a WEIGHT, that input must stay in memory until it is used in the backward phase, so it is still held in memory
+  // at the end of the forward phase.
+  Status CalculateMemoryCost();
+  // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated
+  // by the output
+  Status CalculateMemoryCostForInference();
+  int ComputeOpAndPrevEdgeParameterInvolved();
+
+  ForwardOp forward_op() const { return forward_op_; }
+  ForwardOp replace_op() const { return replace_op_; }
+  OutPutInfoVector replace_op_info() const { return replace_op_info_; }
+  virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; }
+  MirrorOps mirror_ops() const { return mirror_ops_; }
+  Ops sub_ops() const { return sub_ops_; }
+  VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
+  Shape dev_matrix_shape() const { return dev_matrix_shape_; }
+  std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
+  std::vector<TensorInfo> outputs_tensor_info() const { return outputs_tensor_info_; }
+  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost() const { return strategy_cost_; }
+  const std::string &name() const { return name_; }
+  void set_name(const std::string &name) { name_ = name; }
+  RankList global_device_list() const { return global_device_list_; }
+
+  void AddSuccEdge(const std::shared_ptr<Edge> &e) { succ_edges_.push_back(e); }
+  void AddPrevEdge(const std::shared_ptr<Edge> &e) { prev_edges_.push_back(e); }
+  std::vector<std::shared_ptr<Edge>> succ_edges() const { return succ_edges_; }
+  std::vector<std::shared_ptr<Edge>> prev_edges() const { return prev_edges_; }
+  std::vector<std::shared_ptr<Edge>> GetAliveSuccEdges();
+  std::vector<std::shared_ptr<Edge>> GetAlivePrevEdges();
+  void ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
+  void ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
+  void ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
+  void ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge);
+  std::vector<size_t> GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); }
+  void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) {
+    selected_strategy_ = s_strategy;
+    selected_cost_ = cost;
+  }
+  StrategyPtr selected_strategy() const { return selected_strategy_; }
+  CostPtr selected_cost() const { return selected_cost_; }
+  void CheckSelectedStrategy(const StrategyPtr &);
+  Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
+  void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
+  const std::vector<ValuePtr> &input_value() const { return input_value_; }
+  void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
+  void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
+  bool is_alive() const { return is_alive_; }
+  void SetNotAlive() { is_alive_ = false; }
+  StrategyPtr strategy() const { return strategy_; }
+  void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; }
+  void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
+  const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
+  // When the output of a Parameter (with requires_grad) is used by multiple operators, the Parameter's cost is
+  // counted multiple times. This method corrects that, ensuring the cost is calculated only once.
+  Status CorrectMemoryCost(size_t input_index);
+  int is_output_parameter_involve() const { return is_output_parameter_involve_; }
+  int is_output_critical() const { return is_output_critical_; }
+  void mark_output_critical() { is_output_critical_ = 1; }
+  void mark_output_not_critical() { is_output_critical_ = 0; }
+  int used_devices() const { return used_devices_; }
+  // needed by rec_parser
+  void set_type(const std::string &type) { type_ = type; }
+  const std::string &type() const { return type_; }
+  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
+
+ protected:
+  // needed by rec_parser
+  std::string type_;
+  virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
+  virtual Status InferTensorMap() = 0;
+  virtual Status InferForwardCommunication() = 0;
+  virtual Status InferMirrorOps() = 0;
+  virtual Status GetAttrs() = 0;
+  virtual Status InferTensorInfo() = 0;
+  virtual Status InferDevMatrixShape() = 0;
+  void SetDeviceListByStrategy();
+  void SetRepeatedCalcDevMatrix();
+  Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
+  Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
+  Status InferAttrs();
+  void ResetQueueMember();
+  Status InitWithAutoRepeatCalc(const StrategyPtr &strategy);
+  Status InitWithManualRepeatCalc(const StrategyPtr &strategy);
+  Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy);
+  Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy);
+  Status InferRepeatedCalcInfo();
+  Status InferVirtualDivOps();
+
+  // Calculate the number of repeated calculations for the output from the number of devices and the output tensor
+  // map. The tensor map of Outputs[0] is used by default. If there are multiple outputs, identify which output is
+  // used for grad and overload this function. If the output is a scalar, the function also needs to be overridden.
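+  // ArgMaxWithValueInfo::InferAsLossDivisor is one such override: it has two outputs and always derives the
+  // divisor from the tensor map of outputs[0].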
+  virtual Status InferAsLossDivisor();
+  Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
+                         Shapes *inputs_slice_shape, Shapes *outputs_slice_shape);
+  void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &);
+
+  std::string name_;
+  Shapes inputs_shape_;
+  Shapes outputs_shape_;
+  std::unordered_map<std::string, ValuePtr> attrs_;
+  std::vector<ValuePtr> input_value_;
+  TypePtr outputs_dtype_;
+
+  StrategyPtr strategy_;
+  std::vector<TensorInfo> inputs_tensor_info_;
+  std::vector<TensorInfo> outputs_tensor_info_;
+  Shape dev_matrix_shape_;  // if there is repeated calculation, repeated_calc_num_ is prepended as the first dimension
+  int32_t repeated_calc_num_ = 1;
+  int32_t as_loss_divisor_ = 1;
+  TensorMaps inputs_tensor_map_;
+  TensorMaps outputs_tensor_map_;
+  ForwardOp forward_op_;
+  Ops sub_ops_;
+  ForwardOp replace_op_;
+  OutPutInfoVector replace_op_info_;
+  ReplaceGraphPtr replace_graph_;
+  MirrorOps mirror_ops_;
+  VirtualDivOp virtual_div_op_;
+  RankList global_device_list_;  // its size equals the number of devices in the stage
+  RankList local_device_list_;   // its size equals global_device_list_.size() / repeated_calc_num_
+  bool infer_attrs_completed_ = false;
+
+  bool is_auto_parallel_ = false;  // false: semi_auto_parallel; true: auto_parallel
+  // 'corrected_input_indices_' is used to store the indices of inputs that have ALREADY been corrected.
+  std::vector<size_t> corrected_input_indices_;
+  // Given a parallelization strategy, there is a cost.
+  std::vector<std::shared_ptr<StrategyWithCost>> strategy_cost_;
+  // For each input in 'inputs_', a bool variable indicates whether the corresponding input is a parameter
+  std::vector<bool> is_parameter_;
+  // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or an output of
+  // a preceding operator that has parameters as input.
+  std::vector<bool> is_parameter_involve_;
+  // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating
+  // peak memory cost in the training phase.
+  // -1: unset; 0: not parameter_involved; 1: parameter_involved
+  int is_output_parameter_involve_ = -1;
+  // Whether this output is critical, which means that this output is included in calculating peak memory cost
+  // in the inference phase.
+  // -1: unset; 0: not critical; 1: critical
+  int is_output_critical_ = -1;
+  double outputs_total_size_ = 0.0;
+  bool is_calculated_outputs_size_ = false;
+  // for each input and output, the following record the number of bytes of each element
+  std::vector<size_t> inputs_type_lengths_;
+  std::vector<size_t> outputs_type_lengths_;
+  std::vector<std::shared_ptr<Edge>> prev_edges_;
+  std::vector<std::shared_ptr<Edge>> succ_edges_;
+  StrategyPtr selected_strategy_;
+  // Used in DP algorithm
+  bool is_alive_;
+  CostPtr selected_cost_;
+  std::vector<bool> split_flag_list_;
+  std::string refkey_parameter_name_;
+  CNodePtr cnode_;
+  int32_t used_devices_ = -1;
+
+ private:
+  OperatorCostPtr operator_cost_;
+  std::vector<TypePtr> outputs_type_;
+};
+
+Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy);
+Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool);
+Operator CreateVirtualDivOp(int32_t div_num);
+Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group);
+Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group);
+Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout);
+OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num);
+int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map);
+std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySplitFlag(
+  const Shapes &shapes, const std::vector<bool> &split_flag_list);
+
+void PrintStrategy(const StrategyPtr &strategy);
+// generate strategies for an operator whose inputs' dimensions are all independent, such as: ([a, b, c, d])
+Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape,
+                                              const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector);
+// generate strategies for operators with two inputs, where input0 or input1 may be broadcast,
+// and the corresponding dimensions that are not broadcast are all relevant dimensions,
+// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
+// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
+// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
+Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
                                       std::vector<StrategyPtr> *sp_vector);
+
+Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
new file mode 100644
index 0000000000..bc732ed234
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
+
+#include "frontend/parallel/ops_info/activation_info.h"
+#include "frontend/parallel/ops_info/arithmetic_info.h"
+#include "frontend/parallel/ops_info/batch_parallel_info.h"
+#include "frontend/parallel/ops_info/bias_add_info.h"
+#include "frontend/parallel/ops_info/comparison_function_info.h"
+#include "frontend/parallel/ops_info/dropout_do_mask_info.h"
+#include "frontend/parallel/ops_info/elementary_function_info.h"
+#include "frontend/parallel/ops_info/gather_v2_info.h"
+#include "frontend/parallel/ops_info/gather_v2_p_info.h"
+#include "frontend/parallel/ops_info/get_next_info.h"
+#include "frontend/parallel/ops_info/l2_normalize_info.h"
+#include "frontend/parallel/ops_info/layer_norm_info.h"
+#include "frontend/parallel/ops_info/loss_info.h"
+#include "frontend/parallel/ops_info/matmul_info.h"
+#include "frontend/parallel/ops_info/onehot_info.h"
+#include "frontend/parallel/ops_info/prelu_info.h"
+#include "frontend/parallel/ops_info/reduce_method_info.h"
+#include "frontend/parallel/ops_info/reshape_info.h"
+#include "frontend/parallel/ops_info/transpose_info.h"
+#include "frontend/parallel/ops_info/virtual_dataset_info.h"
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
similarity index 100%
rename from mindspore/ccsrc/parallel/ops_info/ops_utils.h
rename to mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
new file mode 100644
index 0000000000..57b35b69f7
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
@@ -0,0 +1,253 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/prelu_info.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/step_parallel.h"
+#include "utils/convert_utils.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+/*
+ * PReLU has two inputs:
+ * A: a float tensor of shape [NCHW], the output of the previous layer.
+ * w: a float tensor with w > 0; only two shapes are legitimate: 1, or the number of input channels.
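+ *    e.g. if A uses strategy (s_n, s_c, s_h, s_w), a per-channel w must use strategy (s_c), while a w of
+ *    shape [1] keeps strategy (1); CheckStrategy below enforces exactly this.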
+ * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 + */ +Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + std::vector stra = strategy->GetInputDim(); + if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy size."; + } + return FAILED; + } + if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid channel strategy."; + } + return FAILED; + } + return SUCCESS; +} + +/* + * device matrix is same with the strategy matrix + */ +Status PReLUInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + input_strategy_ = input_strategy; + dev_matrix_shape_ = input_strategy; + return SUCCESS; +} + +Status PReLUInfo::InferMirrorOps() { + Shape param_tensor_map = inputs_tensor_map_[1]; + std::vector param_group; + if (CreateGroupByTensorMap(param_tensor_map, ¶m_group) != SUCCESS) { + return FAILED; + } else if (param_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } + OperatorVector op_for_param; + op_for_param = CreateMirrorOps(param_group[0].name(), param_group[0].GetDevNum()); + // op_for_inputs is empty + OperatorVector op_for_inputs; + mirror_ops_.push_back(op_for_inputs); + mirror_ops_.push_back(op_for_param); + std::string group_name = param_group[0].name(); + MS_LOG(INFO) << name_ << ": The mirror ops group is " << group_name; + return SUCCESS; +} + +Status PReLUInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * the output tensor map is the same as the input tensor map + */ +Status PReLUInfo::InferTensorMap() { + TensorMap input_tensor_map; + // such as 4: input_tensor_map [3,2,1,0] + for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { + input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); + } + + TensorMap param_tensor_map; + if (inputs_shape_[1][0] == 1) { + param_tensor_map.push_back(-1); + } else { + param_tensor_map.push_back(input_tensor_map.at(1)); + } + inputs_tensor_map_.push_back(input_tensor_map); + inputs_tensor_map_.push_back(param_tensor_map); + outputs_tensor_map_.push_back(input_tensor_map); + return SUCCESS; +} + +Dimensions PReLUInfo::GetOutputStrategy() { + Dimensions output_strategy = input_strategy_; + return output_strategy; +} + +Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if (inputs_layout == nullptr || outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + TensorLayout input_layout, param_layout, output_layout; + if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || + (param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || + (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { + return FAILED; + } + inputs_layout->push_back(input_layout); + 
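+  // the layouts are appended in input order, matching inputs_tensor_map_: A first, then the weight w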
inputs_layout->push_back(param_layout); + outputs_layout->push_back(output_layout); + return SUCCESS; +} + +Status PReLUInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape param_shape = inputs_shape_.at(1); + Shape output_shape = outputs_shape_.at(0); + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Dimensions output_strategy = GetOutputStrategy(); + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = {output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape param_slice_shape = inputs_slice_shape.at(1); + Shape output_slice_shape = outputs_slice_shape.at(0); + + // infer tensor layout + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + + TensorLayout input_layout = inputs_layout.at(0); + TensorLayout param_layout = inputs_layout.at(1); + TensorLayout output_layout = outputs_layout.at(0); + TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape); + TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape); + TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape); + + inputs_tensor_info_.push_back(input_tensor_info); + inputs_tensor_info_.push_back(param_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status PReLUInfo::GetAttrs() { + if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size " + << outputs_shape_.size() << " is wrong."; + return FAILED; + } + return SUCCESS; +} + +Status PReLUInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status PReLUInfo::GenerateStrategies(int32_t stage_id) { + if (inputs_shape_.size() != PRELU_INPUTS_SIZE) { + return FAILED; + } + if (inputs_shape_[1].size() != PRELU_SECOND_INPUT_SIZE) { + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split; + input0_split.emplace_back(1); + input0_split.emplace_back(0); + (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 2, 1); + Shape input1_split(inputs_shape_[1].size(), 0); + Shapes splittable_inputs = {input0_split, input1_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + 
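GenerateStrategies above encodes which axes may be partitioned through `splittable_inputs`: a 1 marks a splittable dimension, a 0 pins the dimension to strategy 1 (here the channel axis of A and the whole of w). A minimal sketch of how that convention drives strategy generation, not part of the patch; the shapes are hypothetical and stage 0 is assumed:

    // Sketch only: exercises the splittable-inputs convention used by PReLUInfo::GenerateStrategies.
    // GenerateStrategiesForIndependentInputs is the helper declared in operator_info.h.
    Shapes inputs_shape = {{32, 16, 8, 8}, {16}};  // A: [N, C, H, W], w: [C] (hypothetical)
    Shape input0_split = {1, 0, 1, 1};             // N/H/W splittable, C pinned
    Shape input1_split = {0};                      // w is never split
    Shapes splittable_inputs = {input0_split, input1_split};
    std::vector<StrategyPtr> sp_vector;
    if (GenerateStrategiesForIndependentInputs(0, inputs_shape, splittable_inputs, &sp_vector) == SUCCESS) {
      for (auto &sp : sp_vector) {
        PrintStrategy(sp);  // every candidate keeps C and w un-partitioned
      }
    }

Pinning the channel axis keeps every generated candidate consistent with the channel-matching rule that CheckStrategy validates, so no candidate is rejected later for mismatched A/w strategies.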
+Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h new file mode 100644 index 0000000000..e6e5e23bac --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for PReLU Primitive + */ +class PReLUInfo : public OperatorInfo { + public: + PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~PReLUInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Dimensions GetOutputStrategy(); + + private: + Dimensions input_strategy_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc new file mode 100644 index 0000000000..0488dceeca --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -0,0 +1,571 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/reduce_method_info.h" + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions input_strategy = stra.at(0); + + dev_matrix_shape_ = input_strategy; + + return SUCCESS; +} + +std::vector ReduceMethod::reduce_dim() { + std::vector dim_list; + if (input_value_.size() < 2) { + MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2."; + } + if (input_value_.back() == nullptr) { + MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr."; + } + MS_ASSERT(inputs_shape_.size() == 1); + auto input_dim = inputs_shape_.at(0).size(); + if (input_value_.back()->isa()) { + auto attr_axis = GetValue>(input_value_.back()); + // axis is (), reduce all dim + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (input_value_.back()->isa()) { + int axis = GetValue(input_value_.back()); + axis < 0 ? 
dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + + return dim_list; +} + +Status ReduceMethod::GetAttrs() { + // get attr cross_batch and keep_dims + auto keep_dims_iter = attrs_.find(KEEP_DIMS); + if (keep_dims_iter == attrs_.end()) { + MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims."; + return FAILED; + } + + if (keep_dims_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(keep_dims_iter->second); + if (!keep_dims_iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool."; + return FAILED; + } + keepdims_ = keep_dims_iter->second->cast()->value(); + } + + auto cross_batch_iter = attrs_.find(CROSS_BATCH); + if (cross_batch_iter != attrs_.end()) { + MS_EXCEPTION_IF_NULL(cross_batch_iter->second); + if (!cross_batch_iter->second->isa()) { + MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; + return FAILED; + } + cross_batch_ = cross_batch_iter->second->cast()->value(); + } + auto reducemethodcost = std::dynamic_pointer_cast(operator_cost()); + if (reducemethodcost == nullptr) { + MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; + return FAILED; + } + reducemethodcost->set_cross_batch(cross_batch_); + return SUCCESS; +} + +Status ReduceMethod::InferTensorMap() { + std::vector tensor_map_index, dim_list, output_tensor_map; + size_t size = inputs_shape_.at(0).size(); + // such as 4: tensor_map_index [3,2,1,0] + for (size_t i = 0; i < size; ++i) { + tensor_map_index.push_back((int32_t)(size - 1 - i)); + } + dim_list = reduce_dim(); + for (size_t i = 0; i < size; ++i) { + if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { + if (keepdims_) { + output_tensor_map.push_back(-1); + } else { + continue; + } + } else { + output_tensor_map.push_back(tensor_map_index[i]); + } + } + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(output_tensor_map); + + return SUCCESS; +} + +bool IsDataParallelStrategy(const Dimensions &strategy) { + CheckGlobalDeviceManager(); + size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); + if (strategy.empty()) { + MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; + } + + return (IntToSize(strategy[0]) == total_dev_num); +} + +Status ReduceMethod::InferForwardCommunication() { + Dimensions stra = strategy_->GetInputDim().at(0); + if (cross_batch_ && IsDataParallelStrategy(stra)) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + if (cross_batch_) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + forward_op_.clear(); + std::vector dim_list = reduce_dim(); + size_t size = stra.size(); + // judge if the reduce dim is partitioned. 
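+  // group_creat_map keeps the tensor-map entries of every dimension that is not a partitioned reduce dimension
+  // (plus the repeated-calculation dimension, if any), so CreateGroupByTensorMap builds the communication group
+  // exactly across the partitioned reduce axes.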
+ Shape group_creat_map; + if (dev_matrix_shape_.size() > size) { + group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); + } + for (size_t index = 0; index < size; ++index) { + auto pos = + std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); + if (pos != dim_list.end() && stra[index] != 1) { + continue; + } + group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); + } + std::vector forward_group; + if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; + return FAILED; + } + if (!forward_group.empty()) { + Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name()); + forward_op_.push_back(op); + std::string group_name = forward_group[0].name(); + MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name; + } + + return SUCCESS; +} + +ForwardOp CreatReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) { + // Creat AllReduceSum op + Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name()); + std::string group_name = forward_group[0].name(); + MS_LOG(INFO) << "The group of forward all reduce is " << group_name; + + // Creat RealDiv op + OperatorName operator1_name = REAL_DIV; + std::vector device_list = forward_group[0].GetDevicesList(); + auto divisor = static_cast(device_list.size()); + std::vector tensor_data = {divisor}; + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, dtype); + ValuePtr op1_param_value = MakeValue(tensor_ptr); + Attr op1_param = std::make_pair("divisor", op1_param_value); + OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; + OperatorAttrs operator1_attrs; + OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params); + Operator op1 = std::make_pair(operator1_name, operator1_args); + ForwardOp forward_op = {op0, op1}; + + std::string dtype_name = dtype->ToString(); + MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name; + return forward_op; +} + +Status ReduceMeanInfo::InferForwardCommunication() { + Dimensions stra = strategy_->GetInputDim().at(0); + if (cross_batch_ && IsDataParallelStrategy(stra)) { + MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; + return SUCCESS; + } + forward_op_.clear(); + std::vector dim_list = reduce_dim(); + size_t size = stra.size(); + // judge if the reduce dim is partitioned. 
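+  // the group inference below mirrors ReduceMethod::InferForwardCommunication; the difference is at the end:
+  // ReduceMean must rescale the AllReduce(SUM) result by 1/group_size, which CreatReduceMeanForwardOp does by
+  // appending a RealDiv op.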
+ Shape group_creat_map; + if (dev_matrix_shape_.size() > size) { + group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); + } + for (size_t index = 0; index < size; ++index) { + auto pos = + std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); + if (pos != dim_list.end() && stra[index] != 1) { + continue; + } + group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); + } + std::vector forward_group; + if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; + return FAILED; + } + if (!forward_group.empty()) { + if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa()) { + MS_LOG(ERROR) << name_ << ": The dtype of output is not Array"; + return FAILED; + } + + auto element_type = outputs_dtype_->cast()->element(); + forward_op_ = CreatReduceMeanForwardOp(forward_group, element_type); + } + + return SUCCESS; +} + +Status ReduceMethod::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << " Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_weight; + OperatorVector op_for_reduce_axis; // helper node + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + mirror_ops_.push_back(op_for_reduce_axis); + std::string group_name = input_group[0].name(); + MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name; + } + + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = inputs_tensor_map_.at(0); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_weight; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } else { + op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + mirror_ops_.push_back(op_for_weight); + MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success."; + } + + return SUCCESS; +} + +Dimensions ReduceMethod::InferOutputStrategy() { + std::vector dim_list = reduce_dim(); + Dimensions output_strategy; + Dimensions stra = strategy_->GetInputDim().at(0); + // if keepdims_ is true,then output strategy is same with input. 
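+  // otherwise the reduced dimensions are dropped, so the output strategy matches the squeezed output shape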
+ for (size_t i = 0; i < stra.size(); ++i) { + if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { + if (keepdims_) { + output_strategy.push_back(1); + } + } else { + output_strategy.push_back(stra[i]); + } + } + return output_strategy; +} + +Status ReduceMethod::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Dimensions output_strategy = InferOutputStrategy(); + + Strategys outputs_strategy = {output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape output_slice_shape = outputs_slice_shape.at(0); + + TensorLayout input_tensor_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { + return FAILED; + } + + std::vector dim_list = reduce_dim(); + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + input_tensor_info.set_reduce_dim(dim_list); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + + return SUCCESS; +} + +Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + is_auto_parallel_ = true; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} + +Status ReduceMethod::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + + return SUCCESS; +} + +Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed"; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed"; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success"; + return SUCCESS; +} + +std::vector ArgMaxWithValueInfo::reduce_dim() { + std::vector dim_list; + auto iter = 
attrs_.find(AXIS); + if (iter == attrs_.end()) { + MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis."; + } + + MS_ASSERT(inputs_shape_.size() == 1); + auto input_dim = inputs_shape_.at(0).size(); + MS_EXCEPTION_IF_NULL(iter->second); + if (iter->second->isa()) { + auto attr_axis = GetValue>(iter->second); + if (attr_axis.empty()) { + for (size_t i = 0; i < input_dim; ++i) { + dim_list.push_back(SizeToInt(i)); + } + } else { + for (auto &axis : attr_axis) { + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } + } + } else if (iter->second->isa()) { + int axis = GetValue(iter->second); + axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); + } else { + MS_LOG(EXCEPTION) << "Axis type is invalid."; + } + + return dim_list; +} + +Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { + if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; + } else { + MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; + } + return FAILED; + } + std::vector dim_list = reduce_dim(); + MS_ASSERT(dim_list.size() == 1); + + std::vector stra = strategy->GetInputDim(); + MS_ASSERT(stra.size() == 1); + Shape input_strategy = stra.at(0); + MS_ASSERT(dim_list.at(0) < input_strategy.size()); + if (input_strategy.at(IntToSize(dim_list.at(0))) != 1) { + MS_LOG(WARNING) + << name_ + << " CheckStrategy for ArgMaxWithValueInfo, the strategy corresponding to axis is not one, real strategy " + "is " + << input_strategy.at(IntToSize(dim_list.at(0))) + << ", the output index may be not compatible with the stand alone Primitive"; + } + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferTensorMap() { + if (ReduceMethod::InferTensorMap() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed"; + return FAILED; + } + MS_ASSERT(outputs_tensor_map_.size() == 1); + outputs_tensor_map_.push_back(outputs_tensor_map_[0]); + return SUCCESS; +} + +Status ArgMaxWithValueInfo::InferTensorInfo() { + // infer tensor shape + Shape input_shape = inputs_shape_.at(0); + Shape output_shape = outputs_shape_.at(0); + + // infer slice shape + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Dimensions output_strategy = InferOutputStrategy(); + + Strategys outputs_strategy = {output_strategy, output_strategy}; + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + Shape input_slice_shape = inputs_slice_shape.at(0); + Shape output_slice_shape = outputs_slice_shape.at(0); + + TensorLayout input_tensor_layout, output_tensor_layout; + if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || + (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { + return FAILED; + } + + std::vector dim_list = reduce_dim(); + TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); + TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); + input_tensor_info.set_reduce_dim(dim_list); + + inputs_tensor_info_.push_back(input_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + outputs_tensor_info_.push_back(output_tensor_info); + return SUCCESS; +} + +Status 
ArgMaxWithValueInfo::InferAsLossDivisor() { + if (outputs_tensor_map_.empty()) { + MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; + return FAILED; + } + + MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer"; + if (outputs_tensor_map_[0].empty()) { + as_loss_divisor_ = SizeToInt(global_device_list_.size()); + MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor."; + return SUCCESS; + } + + as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); + + std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_); + std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]); + MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_shape_str + << ", " << output_tensor_map_str << ", " << as_loss_divisor_; + return SUCCESS; +} + +Status ArgMaxWithValueInfo::GenerateStrategies(int32_t stage_id) { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 2)) { + MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + is_auto_parallel_ = true; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h new file mode 100644 index 0000000000..ed9ab0721d --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h @@ -0,0 +1,141 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ + +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/value.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class ReduceMethod : public OperatorInfo { + public: + ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} + ~ReduceMethod() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + std::string reduce_method_; + bool keepdims_ = false; + bool cross_batch_ = false; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override; + Dimensions InferOutputStrategy(); + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferMirrorOps() override; + virtual std::vector reduce_dim(); + Status InferForwardCommunication() override; + Status InferDevMatrixShape() override; +}; + +class ReduceMaxInfo : public ReduceMethod { + public: + ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MAX; + } + + ~ReduceMaxInfo() override = default; +}; + +class ArgMaxWithValueInfo : public ReduceMethod { + public: + ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MAX; + } + + ~ArgMaxWithValueInfo() override = default; + + Status GenerateStrategies(int32_t stage_id) override; + + protected: + std::vector reduce_dim() override; + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferAsLossDivisor() override; +}; + +class ArgMinWithValueInfo : public ArgMaxWithValueInfo { + public: + ArgMinWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ArgMaxWithValueInfo(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MIN; + } + + ~ArgMinWithValueInfo() override = default; +}; + +class ReduceMeanInfo : public ReduceMethod { + public: + ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + set_cost(std::make_shared()); + } + + ~ReduceMeanInfo() override = default; + + protected: + Status InferForwardCommunication() override; +}; + +class ReduceSumInfo : public ReduceMethod { + public: + ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_SUM; + } + + ~ReduceSumInfo() override = default; +}; + +class 
ReduceMinInfo : public ReduceMethod { + public: + ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { + reduce_method_ = REDUCE_OP_MIN; + } + + ~ReduceMinInfo() override = default; +}; +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc new file mode 100644 index 0000000000..fb62c1d02c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc @@ -0,0 +1,507 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/reshape_info.h" + +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + size_t strategy_size = strategy->GetInputNumber(); + if (strategy_size != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy size " << strategy_size; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy size " << strategy_size; + } + return FAILED; + } + return SUCCESS; +} + +/* + * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of + * device matrix + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Status ReshapeInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + input_strategy_ = stra.at(0); + dev_matrix_shape_.push_back(input_strategy_[0]); + return SUCCESS; +} + +/* + * there is no Parameter for Reshape Primitive, so no need to do allreduce + */ +Status ReshapeInfo::InferMirrorOps() { + mirror_ops_.clear(); + Shape input_tensor_map = input_layout_.tensor_map().array(); + std::vector input_group; + if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; + return FAILED; + } + + OperatorVector op_for_input; + if (input_group.empty()) { + MS_LOG(INFO) << name_ << ": The mirror ops is empty."; + return SUCCESS; + } + if (!input_group.empty()) { + op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); + std::string group_name = input_group[0].name(); + MS_LOG(INFO) << name_ << ": Create 
the mirror ops for input_a success, group is " << group_name; + } + mirror_ops_.push_back(op_for_input); + OperatorVector op_for_input_empty; + mirror_ops_.push_back(op_for_input_empty); + + return SUCCESS; +} + +/* + * there is no reduction dimension for forward computation of Reshape Primitive, so no need to do allreduce + */ +Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * get shape input of Reshape Primitive + * the result is saved in parameter_input_v_ + * not support -1 + */ +Status ReshapeInfo::GetParameterInput() { + if (input_value_[1] == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; + return FAILED; + } + std::vector elements; + ValueTuplePtr dim_tuple = input_value_[1]->cast(); + if (dim_tuple == nullptr) { + MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr."; + return FAILED; + } + elements = dim_tuple->value(); + if (elements.size() != outputs_shape_[0].size()) { + MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size."; + return FAILED; + } + + for (auto &element : elements) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int32_t axis = element->cast()->value(); + parameter_input_v_.push_back(axis); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; + return FAILED; + } + } + return SUCCESS; +} + +Status ReshapeInfo::ComputeReplaceOp() { + RankList dev_list = global_device_list(); + TensorRedistribution tensor_redistribution(!is_generating_costs_, true); + if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { + if (is_generating_costs_) { + MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed."; + } else { + MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; + } + return FAILED; + } + MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); + RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); + if (redistribution_oplist_ptr == nullptr) { + if (is_generating_costs_) { + MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; + } else { + MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; + } + return FAILED; + } + replace_op_ = redistribution_oplist_ptr->first; + replace_op_info_ = redistribution_oplist_ptr->second; + MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); + return SUCCESS; +} + +/* + * the first dimension of input tensor map and output tensor map is set to the last dimension of device arrangement, + * all other dimension is set to None + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Status ReshapeInfo::InferTensorMap() { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": inputs shape and outputs shape size must be 1. 
inputs shape and outputs shape are " + << inputs_shape_.size() << " and " << outputs_shape_.size(); + return FAILED; + } + + std::vector tensor_map_index_input; + tensor_map_index_input.push_back(0); + + for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { + tensor_map_index_input.push_back(MAP_NONE); + } + inputs_tensor_map_.push_back(tensor_map_index_input); + + std::vector tensor_map_index_output; + tensor_map_index_output.push_back(0); + + for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + tensor_map_index_output.push_back(MAP_NONE); + } + outputs_tensor_map_.push_back(tensor_map_index_output); + return SUCCESS; +} + +/* + * the output tensor strategy is the same as input tensor strategy + * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) + */ +Strategys ReshapeInfo::GetOutputsStrategy() { + Strategys outputs_strategy; + std::vector strategy; + strategy.push_back(input_strategy_[0]); + for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { + strategy.push_back(1); + } + outputs_strategy.push_back(strategy); + return outputs_strategy; +} + +Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if (inputs_layout == nullptr || outputs_layout == nullptr) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + Arrangement dev_matrix; + Status status = dev_matrix.Init(dev_matrix_shape_); + if (status != Status::SUCCESS) { + return status; + } + // infer input tensor info + Shape shape_array_in = inputs_shape_.at(0); + TensorMap tensor_map_array_in = inputs_tensor_map_.at(0); + TensorLayout tensor_layout_in; + Map tensor_map_in; + status = tensor_map_in.Init(tensor_map_array_in); + if (status != Status::SUCCESS) { + return status; + } + Arrangement shape_in; + status = shape_in.Init(shape_array_in); + if (status != Status::SUCCESS) { + return status; + } + (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in); + inputs_layout->push_back(tensor_layout_in); + // infer output tensor info + Shape shape_array_out = outputs_shape_.at(0); + + TensorMap tensor_map_array_out = outputs_tensor_map_.at(0); + TensorLayout tensor_layout_out; + Map tensor_map_out; + status = tensor_map_out.Init(tensor_map_array_out); + if (status != Status::SUCCESS) { + return status; + } + Arrangement shape_out; + status = shape_out.Init(shape_array_out); + if (status != Status::SUCCESS) { + return status; + } + (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out); + outputs_layout->push_back(tensor_layout_out); + + input_layout_ = tensor_layout_in; + output_layout_ = tensor_layout_out; + return SUCCESS; +} + +Status ReshapeInfo::InferTensorInfo() { + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = GetOutputsStrategy(); + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + TensorLayout tensor_layout_in = inputs_layout.at(0); + TensorLayout tensor_layout_out = outputs_layout.at(0); + Shape shape_array_in = inputs_shape_.at(0); + Shape slice_shape_in = inputs_slice_shape.at(0); + Shape shape_array_out = outputs_shape_.at(0); + Shape slice_shape_out = outputs_slice_shape.at(0); + TensorInfo tensor_info_in(tensor_layout_in, 
+
+Status ReshapeInfo::InferTensorInfo() {
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  Strategys inputs_strategy = strategy_->GetInputDim();
+  Strategys outputs_strategy = GetOutputsStrategy();
+  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    return FAILED;
+  }
+
+  TensorLayouts inputs_layout, outputs_layout;
+  if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
+    return FAILED;
+  }
+  TensorLayout tensor_layout_in = inputs_layout.at(0);
+  TensorLayout tensor_layout_out = outputs_layout.at(0);
+  Shape shape_array_in = inputs_shape_.at(0);
+  Shape slice_shape_in = inputs_slice_shape.at(0);
+  Shape shape_array_out = outputs_shape_.at(0);
+  Shape slice_shape_out = outputs_slice_shape.at(0);
+  TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in);
+  TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out);
+  inputs_tensor_info_.push_back(tensor_info_in);
+  outputs_tensor_info_.push_back(tensor_info_out);
+  return SUCCESS;
+}
+
+void ReshapeInfo::InferTensorInfoByLayout() {
+  TensorInfo tensor_info_in(input_layout_);
+  TensorInfo tensor_info_out(output_layout_);
+  inputs_tensor_info_.push_back(tensor_info_in);
+  outputs_tensor_info_.push_back(tensor_info_out);
+}
+
+/*
+ * parameter_input_v_ is computed in this method
+ */
+Status ReshapeInfo::GetAttrs() { return GetParameterInput(); }
+
+void ReshapeInfo::device_number(const StrategyPtr &strategy) {
+  int32_t stage = 0;
+  if (strategy != nullptr) {
+    stage = strategy->GetInputStage();
+  }
+  CheckGlobalDeviceManager();
+  global_device_list_ = g_device_manager->GetDeviceListByStageId(stage);
+  dev_num_ = SizeToInt(global_device_list_.size());
+  MS_ASSERT(dev_num_ > 0);
+}
+
+Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) {
+  std::vector<int32_t> tensor_map_index;
+  for (size_t i = 0; i < shape.size(); i++) {
+    tensor_map_index.push_back(MAP_NONE);
+  }
+  Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape);
+  if (status != Status::SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed.";
+    return status;
+  }
+  return Status::SUCCESS;
+}
+
+Status ReshapeInfo::Init(const StrategyPtr &strategy) {
+  ResetQueueMember();
+  device_number(strategy);
+  if (strategy) {
+    if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Init failed.";
+      return FAILED;
+    }
+  } else {
+    if (!input_layout_set_flag_) {
+      MS_ASSERT(inputs_shape_.size() == 1);
+      Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_);
+      if (status != SUCCESS) {
+        MS_LOG(ERROR) << name_ << ": infer input default layout failed.";
+        return status;
+      }
+    }
+    if (!output_layout_set_flag_) {
+      MS_ASSERT(outputs_shape_.size() == 1);
+      Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_);
+      if (status != SUCCESS) {
+        MS_LOG(ERROR) << name_ << ": infer output default layout failed.";
+        return status;
+      }
+    }
+    inputs_tensor_map_.push_back(input_layout_.tensor_map().array());
+    outputs_tensor_map_.push_back(output_layout_.tensor_map().array());
+    InferTensorInfoByLayout();
+    // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps
+    dev_matrix_shape_ = input_layout_.device_arrangement().array();
+    if (InferMirrorOps() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
+      return FAILED;
+    }
+    // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps
+    dev_matrix_shape_ = output_layout_.device_arrangement().array();
+    if (InferVirtualDivOps() != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
+      return FAILED;
+    }
+  }
+  Status status = ComputeReplaceOp();
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
+    return status;
+  }
+  return SUCCESS;
+}
+
+Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+
+Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+void ReshapeInfo::SetCostForReshapeWithParameter() {
+  size_t success = 0;
+  for (auto &sp : sp_vector_) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+}
+
+void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  int32_t stage_id = strategy->GetInputStage();
+  double computation_cost =
+    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
+  result->communication_without_parameter_ =
+    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
+  result->communication_with_partial_para_ =
+    result->communication_without_parameter_ +
+    COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
+
+  // break ties by preferring data parallelism
+  BreakingTiesForPerferringDataParallel(strategy, result);
+  // refine the communication cost calculation for practice
+  RefineForPracticalCost(result, false);
+
+  std::shared_ptr<StrategyWithCost> swc =
+    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
+  swc->cost_list.push_back(result);
+  strategy_cost_.emplace_back(swc);
+}
+
+Status ReshapeInfo::GenerateStrategies(int32_t stage_id) {
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GetAttrs failed.";
+    return FAILED;
+  }
+  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
+    MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  is_auto_parallel_ = true;
+  Shape input0_split;
+  (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1);
+  Shapes splittable_inputs = {input0_split};
+  // this strategy is used only when the input node is a parameter;
+  // otherwise, the input node's output_layout is used as the input_layout.
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
+                                          const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs,
+                                          int32_t out_index, int32_t in_index, bool is_prev_param) {
+  is_generating_costs_ = true;
+  for (auto pre_stra_cost : pre_stra_costs) {
+    std::vector<TensorInfo> pre_out_tensor_infos;
+    if (is_prev_param) {
+      pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
+    } else {
+      pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
+    }
+    if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
+      MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
+      return FAILED;
+    }
+    TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
+    SetInputLayout(pre_out_tensor_info.tensor_layout());
+    // infer pre_node output strategy from output_layout (see the sketch below).
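The call that follows recovers a strategy from a TensorInfo. Conceptually, InferStrategy just reads off the per-dimension split degree, i.e. the ratio of the full shape to the slice shape; a hedged standalone sketch (names are illustrative):

#include <cstdint>
#include <vector>

std::vector<int32_t> InferStrategyFromShapes(const std::vector<int64_t> &full_shape,
                                             const std::vector<int64_t> &slice_shape) {
  std::vector<int32_t> strategy;
  for (size_t i = 0; i < full_shape.size(); ++i) {
    strategy.push_back(static_cast<int32_t>(full_shape[i] / slice_shape[i]));
  }
  return strategy;
}
// e.g. full {32, 128} with slice {8, 128} -> strategy {4, 1}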
+ Dimensions stra = pre_out_tensor_info.InferStrategy(); + if (stra.empty()) { + MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; + return FAILED; + } + std::vector stra_inputs = {stra}; + StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); + if (next_stra_costs.empty()) { + if (Init(nullptr) == FAILED) { + MS_LOG(ERROR) << "Failure:operator reshape init failed"; + return FAILED; + } + SetCostForReshape(reshape_stra); + continue; + } + for (auto next_stra_cost : next_stra_costs) { + std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; + if (next_in_tensor_infos.size() <= IntToSize(in_index)) { + MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; + return FAILED; + } + TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; + SetOutputLayout(next_in_tensor_info.tensor_layout()); + if (Init(nullptr) == FAILED) { + MS_LOG(DEBUG) << "Failure:operator reshape init failed"; + continue; + } + SetCostForReshape(reshape_stra); + } + } + is_generating_costs_ = false; + if (strategy_cost_.empty()) { + return FAILED; + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h new file mode 100644 index 0000000000..2463b440f8 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h @@ -0,0 +1,107 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ + +#include + +#include +#include +#include +#include + +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for Reshape Primitive + */ +class ReshapeInfo : public OperatorInfo { + public: + ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), + dev_num_(0), + pre_operator_index_(0), + next_operator_index_(0), + input_layout_set_flag_(false), + output_layout_set_flag_(false) {} + ~ReshapeInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + void SetInputLayout(const TensorLayout &input_layout) { + input_layout_ = input_layout; + input_layout_set_flag_ = true; + } + void SetOutputLayout(const TensorLayout &output_layout) { + output_layout_ = output_layout; + output_layout_set_flag_ = true; + } + void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); + void SetCostForReshapeWithParameter(); + void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } + void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } + void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } + void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } + Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, + const std::vector> &next_stra_costs, int32_t out_index, + int32_t in_index, bool is_prev_param); + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + std::string pre_operator_name() const { return pre_operator_name_; } + std::string next_operator_name() const { return next_operator_name_; } + int32_t pre_operator_index() const { return pre_operator_index_; } + int32_t next_operator_index() const { return next_operator_index_; } + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorMap() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + + private: + Status GetParameterInput(); + Status ComputeReplaceOp(); + void InferTensorInfoByLayout(); + void device_number(const StrategyPtr &strategy); + Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); + + int32_t dev_num_; + int32_t pre_operator_index_; + int32_t next_operator_index_; + std::vector parameter_input_v_; + std::vector sp_vector_; + Dimensions input_strategy_; + TensorLayout input_layout_; + TensorLayout output_layout_; + bool input_layout_set_flag_; + bool output_layout_set_flag_; + bool is_generating_costs_; + std::string pre_operator_name_; + std::string next_operator_name_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc 
b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
new file mode 100644
index 0000000000..ed6eaa89f1
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
@@ -0,0 +1,147 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/tmp_identity_info.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
+  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": invalid strategy.";
+    }
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  Dimensions input_strategy = stra.at(0);
+  dev_matrix_shape_ = input_strategy;
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::InferTensorMap() {
+  std::vector<int32_t> tensor_map_index;
+  size_t size = inputs_shape_[0].size();
+  // e.g. size 4: tensor_map_index is [3, 2, 1, 0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_index.push_back((int32_t)(size - 1 - i));
+  }
+
+  inputs_tensor_map_.push_back(tensor_map_index);
+  outputs_tensor_map_.push_back(tensor_map_index);
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::InferTensorInfo() {
+  // infer tensor shape
+  Shape input_shape = inputs_shape_.at(0);
+
+  // infer slice shape
+  Shapes inputs_slice_shape, outputs_slice_shape;
+  Strategys inputs_strategy = strategy_->GetInputDim();
+  Strategys outputs_strategy = {inputs_strategy.at(0)};
+  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
+    return FAILED;
+  }
+  Shape input_slice_shape = inputs_slice_shape.at(0);
+
+  TensorLayout input_tensor_layout;
+  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
+    return FAILED;
+  }
+
+  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  outputs_tensor_info_.push_back(input_tensor_info);  // the same as the input
+
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
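GenerateStrategies below hands GenerateStrategiesForIndependentInputs a "splittable" mask: a 1 per dimension the search may shard, a 0 per dimension that must stay whole. As a hedged, self-contained sketch of that idea only (the enumeration policy and all names here are illustrative assumptions, not the library's actual search):

#include <cstdint>
#include <vector>

using Dim = std::vector<int32_t>;

// Enumerate strategies that shard one splittable dimension by a power of
// two dividing both the dimension size and the device count.
std::vector<Dim> EnumerateSingleDimStrategies(const std::vector<int64_t> &shape,
                                              const Dim &splittable, int32_t dev_num) {
  std::vector<Dim> result;
  result.push_back(Dim(shape.size(), 1));  // the all-replicated strategy
  for (size_t d = 0; d < shape.size(); ++d) {
    if (splittable[d] != 1) continue;  // pinned dimension: never sharded
    for (int32_t s = 2; s <= dev_num; s *= 2) {
      if (shape[d] % s != 0 || dev_num % s != 0) break;
      Dim stra(shape.size(), 1);
      stra[d] = s;
      result.push_back(stra);
    }
  }
  return result;
}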
+
+Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) {
+  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
+    MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  is_auto_parallel_ = true;
+  Shape input0_split(inputs_shape_[0].size(), 1);
+  Shapes splittable_inputs = {input0_split};
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
new file mode 100644
index 0000000000..7f73f81180
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class TmpIdentityInfo : public OperatorInfo {
+  // This operator is only used when a parameter tensor is shared by multiple operators; we model such a
+  // parameter tensor as the TmpIdentityInfo operator. TmpIdentityInfo takes a tensor as input
+  // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor.
+ public: + TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, + const std::string &name = IDENTITY_INFO) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TmpIdentityInfo() override = default; + + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status GetAttrs() override { return SUCCESS; } + Status InferMirrorOps() override { return SUCCESS; } + Status InferForwardCommunication() override { return SUCCESS; } + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc new file mode 100644 index 0000000000..b6bb875abc --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc @@ -0,0 +1,247 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/ops_info/transpose_info.h" + +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "utils/convert_utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + return SUCCESS; +} + +Status TransposeInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + input_strategy_ = stra.at(0); + for (auto &iter : input_strategy_) { + dev_matrix_shape_.push_back(iter); + } + return SUCCESS; +} + +// there is no Parameter for Transpose Primitive, so no need to do all reduce +Status TransposeInfo::InferMirrorOps() { return SUCCESS; } + +// there is no reduction dimension for forward computation of Transpose Primitive, so no need to do all reduce +Status TransposeInfo::InferForwardCommunication() { return SUCCESS; } + +/* + * get perm input of Transpose Primitive + * perm is a permutation of the dimensions of input + * the result is saved in axis_v_ + */ +Status TransposeInfo::ComputeAxis() { + if (input_value_[1] == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; + return FAILED; + } + std::vector elements; + ValueTuplePtr dim_tuple = input_value_[1]->cast(); + if (dim_tuple == nullptr) { + MS_LOG(ERROR) << name_ << ": input_value_[1] must be ValueTuplePtr."; + return FAILED; + } + elements = dim_tuple->value(); + if (elements.size() != inputs_shape_[0].size()) { + MS_LOG(ERROR) << name_ << ": elements size must equal to inputs shape 0 size."; + return FAILED; + } + axis_v_.clear(); + for (auto &element : elements) { + MS_EXCEPTION_IF_NULL(element); + if (element->isa()) { + int32_t axis = element->cast()->value(); + axis_v_.push_back(axis); + } else { + MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; + return FAILED; + } + } + + for (int32_t i = 0; i < SizeToInt(axis_v_.size()); i++) { + auto iter = std::find(axis_v_.begin(), axis_v_.end(), i); + if (iter == axis_v_.end()) { + MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation."; + } + } + return SUCCESS; +} + +// the output tensor map is the permutation of input tensor map, the permutation is axis_v +Status TransposeInfo::InferTensorMap() { + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": inputs_shape_ and outputs_shape_ size must be 1, inputs shape and outputs shape is " + << inputs_shape_.size() << ", " << outputs_shape_.size(); + return FAILED; + } + + std::vector tensor_map_index_input; + for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { + tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); + } + inputs_tensor_map_.push_back(tensor_map_index_input); + + std::vector tensor_map_index_output = tensor_map_index_input; + for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { + tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; + } + outputs_tensor_map_.push_back(tensor_map_index_output); + return SUCCESS; +} + +// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v +Strategys TransposeInfo::GetOutputsStrategy() { + Strategys 
outputs_strategy; + std::vector strategy = input_strategy_; + for (uint32_t i = 0; i < strategy.size(); i++) { + strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; + } + outputs_strategy.push_back(strategy); + return outputs_strategy; +} + +Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { + if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { + MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; + return FAILED; + } + Shape shape_in = inputs_shape_.at(0); + TensorMap tensor_map_in = inputs_tensor_map_.at(0); + Shape shape_out = outputs_shape_.at(0); + TensorMap tensor_map_out = outputs_tensor_map_.at(0); + + TensorLayout tensor_layout_in, tensor_layout_out; + if ((tensor_layout_in.InitFromVector(dev_matrix_shape_, tensor_map_in, shape_in) != SUCCESS) || + (tensor_layout_out.InitFromVector(dev_matrix_shape_, tensor_map_out, shape_out) != SUCCESS)) { + return FAILED; + } + + inputs_layout->push_back(tensor_layout_in); + outputs_layout->push_back(tensor_layout_out); + return SUCCESS; +} + +Status TransposeInfo::InferTensorInfo() { + Shapes inputs_slice_shape, outputs_slice_shape; + Strategys inputs_strategy = strategy_->GetInputDim(); + Strategys outputs_strategy = GetOutputsStrategy(); + if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { + return FAILED; + } + + TensorLayouts inputs_layout, outputs_layout; + if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { + return FAILED; + } + TensorLayout tensor_layout_in = inputs_layout.at(0); + TensorLayout tensor_layout_out = outputs_layout.at(0); + Shape shape_array_in = inputs_shape_.at(0); + Shape slice_shape_in = inputs_slice_shape.at(0); + Shape shape_array_out = outputs_shape_.at(0); + Shape slice_shape_out = outputs_slice_shape.at(0); + TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); + TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_out); + return SUCCESS; +} + +// compute axis_v_ during this method +Status TransposeInfo::GetAttrs() { return ComputeAxis(); } + +Status TransposeInfo::Init(const StrategyPtr &strategy) { + if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": Init success."; + return SUCCESS; +} + +Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status TransposeInfo::GenerateStrategies(int32_t stage_id) { + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GetAttrs failed."; + return FAILED; + } + if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { + MS_LOG(ERROR) << name_ << ": 
inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " + << outputs_shape_.size(); + return FAILED; + } + is_auto_parallel_ = true; + Shape input0_split(inputs_shape_[0].size(), 1); + Shapes splittable_inputs = {input0_split}; + std::vector sp_vector; + if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; + return FAILED; + } + size_t success = 0; + for (auto &sp : sp_vector) { + if (SetCostUnderStrategy(sp) == SUCCESS) { + success++; + MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; + PrintStrategy(sp); + } + } + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h new file mode 100644 index 0000000000..d3b62dc234 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +/* + * parallel class for Transpose Primitive + */ +class TransposeInfo : public OperatorInfo { + public: + TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~TransposeInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + Status GetAttrs() override; + Strategys GetOutputsStrategy(); + + private: + Status ComputeAxis(); + std::vector axis_v_; + Dimensions input_strategy_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc new file mode 100644 index 0000000000..3b89d7c84c --- /dev/null +++ 
b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc @@ -0,0 +1,229 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ops_info/virtual_dataset_info.h" + +#include +#include +#include + +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { + if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy."; + } else { + MS_LOG(ERROR) << name_ << ": Invalid strategy."; + } + return FAILED; + } + + std::vector stra = strategy->GetInputDim(); + if (stra.size() < 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; + } else { + MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1."; + } + return FAILED; + } + if (stra.size() == 1) { + MS_LOG(WARNING) << name_ << ": Strategy size is 1."; + return SUCCESS; + } + Dimensions strategy_first = stra.at(1); + for (auto iter_strategy = stra.begin() + 1; iter_strategy != stra.end(); ++iter_strategy) { + if (iter_strategy->empty()) { + MS_LOG(ERROR) << name_ << ": iter_strategy size is zero."; + } + if (strategy_first.at(0) != *(iter_strategy->begin())) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": The first dimension of each strategy must be the same."; + } else { + MS_LOG(ERROR) << name_ << ": The first dimension of each strategy must be the same."; + } + return FAILED; + } + + for (auto iter_element = iter_strategy->begin() + 1; iter_element != iter_strategy->end(); ++iter_element) { + if (*iter_element != 1) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": All dimension except the first dimension of each strategy must be 1."; + } else { + MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of each strategy must be 1."; + } + return FAILED; + } + } + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferDevMatrixShape() { + std::vector stra = strategy_->GetInputDim(); + Dimensions strategy_first = stra.at(0); + int32_t stage = strategy_->GetInputStage(); + CheckGlobalDeviceManager(); + int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); + int32_t batch_split_num = strategy_first.at(0); + dev_matrix_shape_.push_back(batch_split_num); + if (dev_num > batch_split_num) { + dev_matrix_shape_.push_back(dev_num / batch_split_num); + } + + return SUCCESS; +} + +Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } + +Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } + +Status VirtualDatasetInfo::InferTensorMap() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = 
ParallelContext::GetInstance()->full_batch(); + + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { + std::vector tensor_map_index; + if (full_batch) { + tensor_map_index.push_back(MAP_NONE); + } else { + tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); + } + for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { + tensor_map_index.push_back(MAP_NONE); + } + inputs_tensor_map_.push_back(tensor_map_index); + outputs_tensor_map_.push_back(tensor_map_index); + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferTensorInfo() { + for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { + MS_LOG(INFO) << name_ << ": InferTensorInfo " << i << ", size " << strategy_->GetInputNumber(); + TensorLayout tensor_layout_in; + if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { + return FAILED; + } + TensorInfo tensor_info_in(tensor_layout_in); + inputs_tensor_info_.push_back(tensor_info_in); + outputs_tensor_info_.push_back(tensor_info_in); + } + return SUCCESS; +} + +Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } + +Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { + if (InitWithManualRepeatCalc(strategy) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": Init failed."; + return FAILED; + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { + if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; + } else { + MS_LOG(ERROR) << name_ << ": Init for cost model failed."; + } + return FAILED; + } + + MS_LOG(INFO) << name_ << ": Init for cost model success."; + return SUCCESS; +} + +void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { + for (size_t i = 0; i < inputs_shape_.size(); i++) { + split_flag_list_[i] = true; + } +} + +Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { + if (SetCostUnderStrategyBase(strategy) != SUCCESS) { + if (is_auto_parallel_) { + MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; + } + return FAILED; + } + + return SUCCESS; +} + +Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + size_t total_dev_num; + + if (GetAttrs() != SUCCESS) { + MS_LOG(ERROR) << name_ << ": GetAttrs failed"; + return FAILED; + } + + CheckGlobalDeviceManager(); + is_auto_parallel_ = true; + if (full_batch) { + total_dev_num = 1; + } else { + total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); + } + StrategyPtr sp; + std::vector strategy; + for (auto &shape : inputs_shape_) { + Shape temp; + temp.emplace_back(SizeToInt(total_dev_num)); + (void)temp.insert(temp.end(), shape.size() - 1, 1); + strategy.push_back(temp); + } + sp = std::make_shared(stage_id, strategy); + + if (SetCostUnderStrategy(sp) == SUCCESS) { + if (full_batch) { + MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; + } else { + MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; + } + PrintStrategy(sp); + } else { + if (full_batch) { + MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; + } else { + MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; + } + return 
FAILED; + } + return SUCCESS; +} + +Status VirtualDatasetInfo::InferAsLossDivisor() { + // no need to insert div op + as_loss_divisor_ = 1; + return SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h new file mode 100644 index 0000000000..fe54954be0 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARALLEL_OPS_INFO_DATASET_INFO_H_ +#define PARALLEL_OPS_INFO_DATASET_INFO_H_ + +#include +#include +#include +#include + +#include "ir/value.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/strategy.h" + +namespace mindspore { +namespace parallel { +class VirtualDatasetInfo : public OperatorInfo { + public: + VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) + : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} + ~VirtualDatasetInfo() override = default; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; + + Status GenerateStrategies(int32_t stage_id) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + void ReComputeBatchSplitFlagList() override; + + protected: + Status CheckStrategy(const StrategyPtr &strategy) override; + Status InferMirrorOps() override; + Status InferForwardCommunication() override; + Status InferTensorInfo() override; + Status InferDevMatrixShape() override; + Status InferTensorMap() override; + Status GetAttrs() override; + Status InferAsLossDivisor() override; +}; +} // namespace parallel +} // namespace mindspore + +#endif // PARALLEL_OPS_INFO_VIRTUAL_DATASET_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ps/common.h b/mindspore/ccsrc/frontend/parallel/ps/common.h similarity index 100% rename from mindspore/ccsrc/parallel/ps/common.h rename to mindspore/ccsrc/frontend/parallel/ps/common.h diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc new file mode 100644 index 0000000000..e16c713e3c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/ps/optimizer_info.h" +#include + +namespace mindspore { +namespace parallel { +namespace ps { +void OptimizerInfo::AddWorkspace(const AddressPtr &workspace) { workspaces_.push_back(workspace); } + +const std::vector &OptimizerInfo::inputs() { return inputs_; } + +const std::vector &OptimizerInfo::workspaces() { return workspaces_; } + +const std::vector &OptimizerInfo::outputs() { return outputs_; } + +bool OptimizerInfo::IsSparse() const { return false; } + +size_t OptimizerInfo::grad_index() { return 0; } + +size_t OptimizerInfo::indices_index() { return 0; } + +void OptimizerInfo::UpdateWeight(const WeightPtr &weight) { + AddressPtr weight_addr = std::make_shared(); + weight_addr->addr = weight->data(); + weight_addr->size = weight->size(); + inputs_[0] = weight_addr; +} + +void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { + float *accum_grad_data = reinterpret_cast(gradient()->addr); + size_t size = gradient()->size / sizeof(float); + size_t grad_index = this->grad_index(); + size_t grad_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += lengths[i]; + } + float *grad_data = values.data() + grad_offset; + CHECK_EQ(size, static_cast(lengths[grad_index])); + + for (size_t i = 0; i < size; i++) { + accum_grad_data[i] += grad_data[i]; + } +} + +void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { + // Append grad data to the end + float *accum_grad_data = reinterpret_cast(gradient()->addr); + + size_t grad_index = this->grad_index(); + size_t grad_offset = 0; + for (size_t i = 0; i < grad_index; i++) { + grad_offset += lengths[i]; + } + float *incr_grad_data = values.data() + grad_offset; + size_t incr_grad_size = lengths[grad_index] * sizeof(float); + + auto ret = memcpy_s(accum_grad_data + grads_offset_, incr_grad_size, incr_grad_data, incr_grad_size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + grads_offset_ += incr_grad_size; + gradient()->size += incr_grad_size; + + // Append indice data to the end + int *accum_indices_data = reinterpret_cast(indices()->addr); + + size_t indices_index = this->indices_index(); + size_t indice_offset = 0; + for (size_t i = 0; i < indices_index; i++) { + indice_offset += lengths[i]; + } + int *incr_indice_data = reinterpret_cast(values.data() + indice_offset); + size_t incr_indice_size = lengths[indices_index] * sizeof(float); + + auto ret2 = memcpy_s(accum_indices_data + indices_offset_, incr_indice_size, incr_indice_data, incr_indice_size); + if (ret2 != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; + } + indices_offset_ += incr_indice_size; + indices()->size += incr_indice_size; +} + +void SparseOptimInfo::Reset() { + auto &gradient = this->gradient(); + gradient->size = 0; + auto &indices = this->indices(); + indices->size = 0; + grads_offset_ = 0; + indices_offset_ = 0; +} + +MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, + const AddressPtr &learning_rate, const AddressPtr &gradient, + const AddressPtr &momentum) { + inputs_.push_back(weight); + inputs_.push_back(accumulate); + inputs_.push_back(learning_rate); + inputs_.push_back(gradient); + inputs_.push_back(momentum); +} + +const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } + +const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; 
} + +SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, + const AddressPtr &beta1_power, const AddressPtr &beta2_power, + const AddressPtr &learning_rate, const AddressPtr &beta1, + const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, + const AddressPtr &indices, size_t grads_offset, size_t indices_offset) { + inputs_.push_back(weight); + inputs_.push_back(m); + inputs_.push_back(v); + inputs_.push_back(beta1_power); + inputs_.push_back(beta2_power); + inputs_.push_back(learning_rate); + inputs_.push_back(beta1); + inputs_.push_back(beta2); + inputs_.push_back(epsilon); + inputs_.push_back(grad); + inputs_.push_back(indices); + grads_offset_ = grads_offset; + indices_offset_ = indices_offset; +} + +void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { + void *data_ptr = values.data(); + AddressPtr beta1_power = inputs_[3]; + size_t size = values.size() * sizeof(float); + auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } +} + +const AddressPtr &SparseAdamOptimInfo::gradient() { return inputs_[9]; } + +const AddressPtr &SparseAdamOptimInfo::indices() { return inputs_[10]; } + +bool SparseAdamOptimInfo::IsSparse() const { return true; } + +size_t SparseAdamOptimInfo::grad_index() { return 6; } + +size_t SparseAdamOptimInfo::indices_index() { return 7; } + +SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, + const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, + size_t indices_offset) { + inputs_.push_back(weight); + inputs_.push_back(accum); + inputs_.push_back(linear); + inputs_.push_back(grad); + inputs_.push_back(indices); + grads_offset_ = grads_offset; + indices_offset_ = indices_offset; +} + +const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; } + +const AddressPtr &SparseFtrlOptimInfo::indices() { return inputs_[4]; } + +bool SparseFtrlOptimInfo::IsSparse() const { return true; } + +size_t SparseFtrlOptimInfo::grad_index() { return 0; } + +size_t SparseFtrlOptimInfo::indices_index() { return 1; } +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h new file mode 100644 index 0000000000..bb9a64acdb --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ + +#include +#include "backend/kernel_compiler/kernel.h" +#include "frontend/parallel/ps/common.h" + +namespace mindspore { +namespace parallel { +namespace ps { +using mindspore::kernel::AddressPtr; +class OptimizerInfo { + public: + OptimizerInfo() = default; + virtual ~OptimizerInfo() = default; + + virtual void Update(const Values &values, const Lengths &lengths) {} + virtual void UpdateWeight(const WeightPtr &weight); + virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; + virtual void Reset() {} + void AddWorkspace(const AddressPtr &workspace); + + virtual const AddressPtr &gradient() = 0; + virtual const AddressPtr &indices() = 0; + const std::vector &inputs(); + const std::vector &workspaces(); + const std::vector &outputs(); + + virtual bool IsSparse() const; + virtual size_t grad_index(); + virtual size_t indices_index(); + + protected: + std::vector inputs_; + std::vector workspaces_; + std::vector outputs_; +}; + +class DenseOptimInfo : public OptimizerInfo { + public: + DenseOptimInfo() = default; + ~DenseOptimInfo() override = default; + + void Accumulate(const Values &values, const Lengths &lens) override; +}; + +class SparseOptimInfo : public OptimizerInfo { + public: + SparseOptimInfo() = default; + ~SparseOptimInfo() override = default; + + void Accumulate(const Values &values, const Lengths &lens) override; + void Reset() override; + + protected: + size_t grads_offset_{0}; + size_t indices_offset_{0}; +}; + +class MomentumOptimInfo : public DenseOptimInfo { + public: + MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, const AddressPtr &learning_rate, + const AddressPtr &gradient, const AddressPtr &momentum); + ~MomentumOptimInfo() override = default; + + const AddressPtr &gradient(); + const AddressPtr &indices(); +}; + +class SparseAdamOptimInfo : public SparseOptimInfo { + public: + SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, + const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, + const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, + const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + ~SparseAdamOptimInfo() override = default; + + void Update(const Values &values, const Lengths &lens) override; + const AddressPtr &gradient(); + const AddressPtr &indices(); + bool IsSparse() const override; + size_t grad_index() override; + size_t indices_index() override; +}; + +class SparseFtrlOptimInfo : public SparseOptimInfo { + public: + SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, + const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, size_t indices_offset); + ~SparseFtrlOptimInfo() override = default; + + const AddressPtr &gradient(); + const AddressPtr &indices(); + bool IsSparse() const override; + size_t grad_index() override; + size_t indices_index() override; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc new file mode 100644 index 0000000000..159a50793e --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc @@ 
-0,0 +1,184 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ps/optimizer_info_builder.h"
+#include <vector>
+#include <memory>
+#include <functional>
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr<PServerKernel> &pserver_kernel,
+                                           const WeightPtr &weight, const Keys &keys, const Values &values,
+                                           const Lengths &lens, const InputsShapePtr &inputs_shape,
+                                           size_t worker_num) {
+  OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num);
+  std::vector<size_t> ws_sizes = pserver_kernel->workspace_sizes();
+  BuildWorkspaces(optim_info, ws_sizes, worker_num);
+  BuildOutputs(optim_info, worker_num);
+  return optim_info;
+}
+
+void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes,
+                                           size_t worker_num) {
+  for (size_t i = 0; i < ws_sizes.size(); i++) {
+    size_t size = ws_sizes[i];
+    AddressPtr workspace = std::make_shared<kernel::Address>();
+    workspace->addr = new float[size];
+    workspace->size = size;
+    info->AddWorkspace(workspace);
+  }
+}
+
+OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
+                                                     const Lengths &lens, const InputsShapePtr &inputs_shape,
+                                                     size_t worker_num) {
+  AddressPtr weight_addr = std::make_shared<kernel::Address>();
+  weight_addr->addr = weight->data();
+  weight_addr->size = weight->size();
+  void *data_ptr = values.data();
+  AddressPtr accumulate = std::make_shared<kernel::Address>();
+  accumulate->addr = new float[weight->size()];
+  accumulate->size = weight->size();
+  AddressPtr learning_rate = std::make_shared<kernel::Address>();
+  learning_rate->addr = data_ptr;
+  learning_rate->size = lens[0];
+  AddressPtr gradient = std::make_shared<kernel::Address>();
+  gradient->addr = reinterpret_cast<float *>(learning_rate->addr) + lens[0];
+  gradient->size = lens[1];
+  AddressPtr momentum = std::make_shared<kernel::Address>();
+  momentum->addr = reinterpret_cast<float *>(gradient->addr) + lens[1];
+  momentum->size = lens[2];
+
+  return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum);
+}
+
+OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
+                                                       const Lengths &lens, const InputsShapePtr &inputs_shape,
+                                                       size_t worker_num) {
+  AddressPtr weight_addr = std::make_shared<kernel::Address>();
+  weight_addr->addr = weight->data();
+  weight_addr->size = weight->size();
+  AddressPtr m = std::make_shared<kernel::Address>();
+  m->addr = new float[weight->size()];
+  m->size = weight->size() * sizeof(float);
+  AddressPtr v = std::make_shared<kernel::Address>();
+  v->addr = new float[weight->size()];
+  v->size = weight->size() * sizeof(float);
+
+  void *data_ptr = values.data();
+  void *copy_data_ptr = new float[values.size()];
+  auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
+  }
+
+  AddressPtr beta1_power = std::make_shared<kernel::Address>();
+  beta1_power->addr = copy_data_ptr;
+  beta1_power->size = lens[0] * sizeof(float);
+  AddressPtr beta2_power = std::make_shared<kernel::Address>();
+  beta2_power->addr = reinterpret_cast<float *>(beta1_power->addr) + lens[0];
+  beta2_power->size = lens[1] * sizeof(float);
+
+  AddressPtr learning_rate = std::make_shared<kernel::Address>();
+  learning_rate->addr = reinterpret_cast<float *>(beta2_power->addr) + lens[1];
+  learning_rate->size = lens[2] * sizeof(float);
+
+  AddressPtr beta1 = std::make_shared<kernel::Address>();
+  beta1->addr = reinterpret_cast<float *>(learning_rate->addr) + lens[2];
+  beta1->size = lens[3] * sizeof(float);
+
+  AddressPtr beta2 = std::make_shared<kernel::Address>();
+  beta2->addr = reinterpret_cast<float *>(beta1->addr) + lens[3];
+  beta2->size = lens[4] * sizeof(float);
+
+  AddressPtr epsilon = std::make_shared<kernel::Address>();
+  epsilon->addr = reinterpret_cast<float *>(beta2->addr) + lens[4];
+  epsilon->size = lens[5] * sizeof(float);
+
+  const std::shared_ptr<std::vector<size_t>> &grad_shape = (*inputs_shape)[9];
+  size_t total_grad_size =
+    std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies<size_t>());
+  AddressPtr grad = std::make_shared<kernel::Address>();
+  grad->addr = new float[total_grad_size * worker_num];
+  auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast<float *>(epsilon->addr) + lens[5],
+                       lens[6] * sizeof(float));
+  if (ret2 != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret2 << ")";
+  }
+  grad->size = lens[6] * sizeof(float);
+
+  const std::shared_ptr<std::vector<size_t>> &indices_shape = (*inputs_shape)[10];
+  size_t total_indice_size =
+    std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies<size_t>());
+  AddressPtr indices = std::make_shared<kernel::Address>();
+  indices->addr = new float[total_indice_size * worker_num];
+  auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float),
+                       reinterpret_cast<float *>(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float));
+  if (ret3 != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret3 << ")";
+  }
+  indices->size = lens[7] * sizeof(float);
+
+  return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
+                                 grad, indices, total_grad_size, total_indice_size);
+}
+
+OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
+                                                       const Lengths &lens, const InputsShapePtr &inputs_shape,
+                                                       size_t worker_num) {
+  AddressPtr weight_addr = std::make_shared<kernel::Address>();
+  weight_addr->addr = weight->data();
+  weight_addr->size = weight->size();
+  AddressPtr accum = std::make_shared<kernel::Address>();
+  accum->addr = new float[weight->size()];
+  accum->size = weight->size() * sizeof(float);
+  for (size_t i = 0; i < weight->size(); i++) {
+    float *tmp = reinterpret_cast<float *>(accum->addr);
+    tmp[i] = 1.0;
+  }
+  AddressPtr linear = std::make_shared<kernel::Address>();
+  linear->addr = new float[weight->size()];
+  memset_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float));
+  linear->size = weight->size() * sizeof(float);
+
+  const std::shared_ptr<std::vector<size_t>> &grad_shape = (*inputs_shape)[3];
+  size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies<size_t>());
+  AddressPtr grad = std::make_shared<kernel::Address>();
+  grad->addr = new float[total_grad_size * worker_num];
+  auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float));
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret << ")";
+  }
+  grad->size = lens[0] * sizeof(float);
+
+  const std::shared_ptr<std::vector<size_t>> &indices_shape = (*inputs_shape)[4];
+  size_t total_indice_size =
+    std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies<size_t>());
+  AddressPtr indices = std::make_shared<kernel::Address>();
+  indices->addr = new float[total_indice_size * worker_num];
+  auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast<float *>(values.data()) + lens[0],
+                       lens[1] * sizeof(float));
+  if (ret2 != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errno(" << ret2 << ")";
+  }
+  indices->size = lens[1] * sizeof(float);
+
+  return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, total_grad_size, total_indice_size);
+}
+} // namespace ps
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
new file mode 100644
index 0000000000..c5aae32921
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_
+#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_
+
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/kernel.h"
+#include "backend/kernel_compiler/ps/pserver_kernel.h"
+#include "frontend/parallel/ps/optimizer_info.h"
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+using mindspore::kernel::KernelMod;
+using mindspore::kernel::ps::PServerKernel;
+class OptimizerInfoBuilder {
+ public:
+  OptimizerInfoBuilder() = default;
+  virtual ~OptimizerInfoBuilder() = default;
+
+  OptimizerInfo *Build(const std::shared_ptr<PServerKernel> &pserver_kernel, const WeightPtr &weight, const Keys &keys,
+                       const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape,
+                       size_t worker_num);
+
+  virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values,
+                                     const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) = 0;
+
+  virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector<size_t> &ws_sizes, size_t worker_num);
+  virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {}
+};
+
+class MomentumOptimInfoBuilder : public OptimizerInfoBuilder {
+ public:
+  OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
+                             const InputsShapePtr &inputs_shape, size_t worker_num) override;
+};
+
+class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder {
+ public:
+  OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
+                             const InputsShapePtr &inputs_shape, size_t worker_num) override;
+};
+
+class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder {
+ public:
+  OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens,
+                             const InputsShapePtr &inputs_shape, size_t worker_num) override;
+};
+} // namespace ps
+} // namespace parallel
+} // namespace mindspore
+#endif //
MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h new file mode 100755 index 0000000000..1afb4c9fa6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -0,0 +1,559 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ir/func_graph.h" +#include "backend/session/session_basic.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/session_factory.h" +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/optimizer_info.h" +#include "frontend/parallel/ps/optimizer_info_builder.h" +#include "frontend/parallel/ps/util.h" +#include "runtime/device/cpu/kernel_select_cpu.h" +#include "utils/context/ms_context.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/ps/pserver_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "backend/kernel_compiler/ps/sparse_apply_adam_ps_kernel.h" +#include "backend/kernel_compiler/ps/sparse_apply_ftrl_ps_kernel.h" +#include "backend/kernel_compiler/ps/apply_momentum_ps_kernel.h" +#include "backend/kernel_compiler/ps/embedding_look_up_ps_kernel.h" + +namespace mindspore { +namespace parallel { +namespace ps { +using mindspore::kernel::ps::PServerKernel; +template +class ParameterServer { + public: + static ParameterServer &GetInstance() { + static ParameterServer instance; + return instance; + } + + void Run(const FuncGraphPtr &func_graph); + + private: + ParameterServer() + : pserver_num_(0), + worker_num_(0), + rank_id_(0), + grad_accum_count_(0), + ps_(new ::ps::KVServer(0)), + handler_(nullptr), + func_graph_(nullptr), + kernel_graph_(nullptr), + sess_(nullptr), + thread_(nullptr) {} + ~ParameterServer() = default; + ParameterServer(const ParameterServer &) = delete; + ParameterServer &operator=(const ParameterServer &) = delete; + + struct ServerHandler { + explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} + void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); + void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); + void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleInitWeights(const ::ps::KVPairs &req_data); + void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); + void HandleInitInputsShape(const ::ps::KVPairs &req_data); + void HandleInitEmbeddings(const ::ps::KVPairs &req_data); + void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, 
::ps::KVPairs *res); + ParameterServer *ps_; + }; + + bool Init(const FuncGraphPtr &func_graph); + void InitOptimInfoBuilders(); + void InitWeightKeyToOptims(const Key &key, const int &optim_id); + void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); + void InitWeight(const Key &key, const WeightPtr &weight); + void InitGrad(const Key &key, const GradPtr &grad); + void InitEmbeddingTable(const Key &key, + const std::shared_ptr>>> &shapes); + void UpdateWeights(); + void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); + WeightPtr weight(const Key &key); + void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); + int SumOfShapes(const std::vector &shapes) const; + size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); + bool ReadyForUpdateWeights(); + bool ReadyForAccumGrads(); + void ResetGradAccumCount(); + + size_t pserver_num_; + size_t worker_num_; + size_t rank_id_; + size_t grad_accum_count_; + std::unique_ptr<::ps::KVServer> ps_; + std::unique_ptr handler_; + FuncGraphPtr func_graph_; + std::shared_ptr kernel_graph_; + std::shared_ptr sess_; + + std::unordered_map> optimizers_; + std::unordered_map optim_inputs_shape_; + std::unordered_map> optim_infos_; + std::unordered_map> optim_info_builders_; + std::unordered_map weight_key_to_optims_; + std::unordered_map weights_; + std::unordered_map grads_; + std::unordered_map grads_accum_counter_; + // std::unordered_map embeddings_; + std::unordered_map> embedding_lookup_ops_; + std::unordered_map embedding_row_lens_; + + T learning_rate_; + T momentum_; + + std::mutex mutex_; + std::condition_variable apply_grads_cv_; + std::condition_variable accum_grads_cv_; + + std::unique_ptr thread_; + + friend struct ServerHandler; +}; + +class FuncGraph; +template +void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVServer *server) { + ::ps::KVPairs res; + if (req_meta.cmd == kInitWeightsCmd) { + MS_LOG(ERROR) << "handle init weights cmd" << std::endl; + HandleInitWeights(req_data); + } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { + MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; + HandleInitWeightToOptimId(req_data); + } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { + MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; + HandleInitInputsShape(req_data); + } else if (req_meta.cmd == kInitEmbeddingsCmd) { + MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; + HandleInitEmbeddings(req_data); + } else if (req_meta.cmd == kEmbeddingLookupCmd) { + MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; + HandleEmbeddingLookup(req_meta, req_data, &res); + } else if (req_meta.push) { + MS_LOG(ERROR) << "handle push req cmd" << std::endl; + HandlePushReq(req_meta, req_data); + } else { + MS_LOG(ERROR) << "handle pull req cmd" << std::endl; + HandlePullReq(req_meta, req_data, &res); + } + server->Response(req_meta, res); +} + +template +void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { + ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); +} + +template +void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + res->keys = req_data.keys; + ::ps::Key key = req_data.keys[0]; + res->vals = *(ps_->weight(key)); +} + +template +void 
ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { + size_t key_num = req_data.keys.size(); + T *data_ptr = req_data.vals.data(); + size_t pos = 0; + for (size_t i = 0; i < key_num; i++) { + Key key = req_data.keys[i]; + size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; + + WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); + weight_ptr->CopyFrom(data_ptr + pos, data_len); + ps_->InitWeight(key, weight_ptr); + + GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); + ps_->InitGrad(key, grad_ptr); + pos += data_len; + } +} + +template +void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { + size_t key_num = req_data.keys.size(); + for (size_t i = 0; i < key_num; i++) { + Key key = req_data.keys[i]; + T val = req_data.vals[i]; + ps_->InitWeightKeyToOptims(key, val); + } +} + +template +void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { + ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); +} + +template +void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { + std::shared_ptr>>> shapes = + std::make_shared>>>(); + std::shared_ptr> input_shape = std::make_shared>(); + std::shared_ptr> indices_shape = std::make_shared>(); + std::shared_ptr> output_shape = std::make_shared>(); + shapes->push_back(input_shape); + shapes->push_back(indices_shape); + shapes->push_back(output_shape); + + const Key &key = req_data.keys[0]; + const Lengths &lens = req_data.lens; + size_t index = 0; + for (int i = 0; i < lens[0]; i++) { + input_shape->push_back(static_cast(req_data.vals[index++])); + } + for (int j = 0; j < lens[1]; j++) { + indices_shape->push_back(static_cast(req_data.vals[index++])); + } + for (int k = 0; k < lens[2]; k++) { + output_shape->push_back(static_cast(req_data.vals[index++])); + } + ps_->InitEmbeddingTable(key, shapes); +} + +template +void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { + const Key &key = req_data.keys[0]; + ps_->DoEmbeddingLookup(key, req_data.vals, res); + for (size_t i = 0; i < req_data.vals.size(); i++) { + res->keys->push_back(req_data.vals[i]); + } +} + +template +bool ParameterServer::Init(const FuncGraphPtr &func_graph) { + const char *server_num = getenv(kEnvPServerNum); + const char *worker_num = getenv(kEnvWorkerNum); + if (server_num != nullptr) { + pserver_num_ = *server_num - '0'; + } + if (worker_num != nullptr) { + worker_num_ = *worker_num - '0'; + } + func_graph_ = func_graph; + rank_id_ = ::ps::MyRank(); + handler_.reset(new ServerHandler(this)); + + InitOptimInfoBuilders(); + + ps_->set_request_handle(*handler_); + thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); + return true; +} + +template +void ParameterServer::InitOptimInfoBuilders() { + std::shared_ptr momentum_info_builder = std::make_shared(); + std::shared_ptr sparse_adam_info_builder = std::make_shared(); + std::shared_ptr sparse_ftrl_info_builder = std::make_shared(); + optim_info_builders_[kApplyMomentum] = momentum_info_builder; + optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; + optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; +} + +template +void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_id) { + if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { + return; 
+ } + weight_key_to_optims_[key] = Util::optimizer_name(optim_id); +} + +template +void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { + InputsShapePtr inputs_shape = std::make_shared(); + int val_idx = 0; + const Key &key = keys[0]; + + if (optim_inputs_shape_.count(key) == 0) { + optim_inputs_shape_[key] = inputs_shape; + } + for (size_t i = 0; i < keys.size(); i++) { + auto shape = std::make_shared>(); + inputs_shape->push_back(shape); + + int len = lengths[i]; + for (int j = 0; j < len; j++) { + shape->push_back(values[val_idx++]); + } + } + if (weight_key_to_optims_.count(key) > 0) { + const std::string &optim_name = weight_key_to_optims_[key]; + if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { + if (optim_name == kSparseAdam) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } else if (optim_name == kApplyMomentum) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } else if (optim_name == kSparseFtrl) { + std::shared_ptr optimizer = + std::make_shared(rank_id_, pserver_num_); + optimizer->InitKernel(optim_inputs_shape_[key]); + optimizers_[optim_name] = optimizer; + } + } + } +} + +template +void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { + if (weights_.count(key) == 0) { + weights_[key] = weight; + } +} + +template +void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { + if (grads_.count(key) == 0) { + grads_[key] = grad; + grads_accum_counter_[key] = 0; + } +} + +template +void ParameterServer::InitEmbeddingTable( + const Key &key, const std::shared_ptr>>> &shapes) { + // Init embedding lookup kernel + std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); + lookup->InitKernel(shapes); + embedding_lookup_ops_[key] = lookup; + + // Init embedding weight + const std::vector &input_shapes = lookup->input_sizes(); + size_t total_dims = 1; + for (auto shape : input_shapes) { + total_dims *= shape; + } + WeightPtr embedding = std::make_shared(total_dims, 0.01); + weights_[key] = embedding; + + grads_accum_counter_[key] = 0; +} + +template +void ParameterServer::UpdateWeights() { + while (true) { + std::unique_lock lock(mutex_); + apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); + + for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { + Key key = iter->first; + WeightPtr weight_ptr = iter->second; + + std::shared_ptr optimizer = nullptr; + if (weight_key_to_optims_.count(key) > 0) { + const std::string &optim_name = weight_key_to_optims_[key]; + optimizer = optimizers_[optim_name]; + } + MS_EXCEPTION_IF_NULL(optimizer); + + std::shared_ptr optim_info = optim_infos_[key]; + if (optim_info == nullptr) { + continue; + } + const WeightPtr &weight = weights_[key]; + optim_info->UpdateWeight(weight); + const std::vector &inputs = optim_info->inputs(); + const std::vector &workspaces = optim_info->workspaces(); + const std::vector &outputs = optim_info->outputs(); + + optimizer->Execute(inputs, workspaces, outputs); + optim_info->Reset(); + } + ResetGradAccumCount(); + accum_grads_cv_.notify_all(); + } +} + +template +void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { + std::unique_lock lock(mutex_); + accum_grads_cv_.wait(lock, 
[this] { return this->ReadyForAccumGrads(); });
+
+  const Key &key = keys[0];
+  std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
+
+  // Create or update the optimizer info
+  if (optim_info == nullptr) {
+    const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
+    std::shared_ptr<PServerKernel> pserver_kernel = optimizers_[weight_key_to_optims_[key]];
+    if (pserver_kernel == nullptr) {
+      MS_LOG(EXCEPTION) << "No optimizer found for key " << key << " with optimizer name "
+                        << weight_key_to_optims_[key];
+    }
+    MS_EXCEPTION_IF_NULL(pserver_kernel);
+    OptimizerInfo *optim =
+      builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_);
+    optim_info.reset(optim);
+    optim_infos_[key] = optim_info;
+  } else {
+    optim_info->Update(values, lengths);
+  }
+  MS_EXCEPTION_IF_NULL(optim_info);
+
+  optim_info->Accumulate(values, lengths);
+
+  grads_accum_counter_[key] += 1;
+  if (grads_accum_counter_[key] == worker_num_) {
+    grad_accum_count_++;
+  }
+  if (ReadyForUpdateWeights()) {
+    apply_grads_cv_.notify_one();
+  }
+}
+
+template <typename T>
+WeightPtr ParameterServer<T>::weight(const Key &key) {
+  std::unique_lock<std::mutex> lock(mutex_);
+
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid weight key " << key;
+    return nullptr;
+  }
+  WeightPtr weight_ptr = weights_[key];
+  WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
+  copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
+  return copy_weight_ptr;
+}
+
+template <typename T>
+void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res) {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if (weights_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding table key " << key;
+    return;
+  }
+  if (embedding_lookup_ops_.count(key) == 0) {
+    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
+    return;
+  }
+  WeightPtr table_ptr = weights_[key];
+  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
+
+  // Update shapes of lookup operator
+  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
+    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
+  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
+  indices_shape->emplace_back(lookup_ids.size());
+  shapes->push_back(indices_shape);
+  table_lookup_op->ReInit(shapes);
+
+  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
+  std::vector<kernel::AddressPtr> inputs;
+  AddressPtr embedding_table = std::make_shared<kernel::Address>();
+  AddressPtr indices = std::make_shared<kernel::Address>();
+  inputs.push_back(embedding_table);
+  inputs.push_back(indices);
+  embedding_table->addr = table_ptr->data();
+  embedding_table->size = table_ptr->size() * sizeof(T);
+  indices->addr = lookup_ids.data();
+  indices->size = lookup_ids.size() * sizeof(T);
+
+  std::vector<kernel::AddressPtr> workspaces;
+  std::vector<kernel::AddressPtr> outputs;
+  AddressPtr output = std::make_shared<kernel::Address>();
+  std::shared_ptr<::ps::SArray<T>> addr = std::make_shared<::ps::SArray<T>>(output_shapes[0] / sizeof(T), 0);
+
+  output->addr = addr->data();
+  output->size = output_shapes[0];
+  outputs.push_back(output);
+
+  table_lookup_op->Execute(inputs, workspaces, outputs);
+  res->vals = *addr;
+  res->lens.push_back(res->vals.size());
+}
+
+template <typename T>
+int ParameterServer<T>::SumOfShapes(const std::vector<int> &shapes) const {
+  int sum = 1;
+  for (auto shape : shapes) {
+    sum *= shape;
+  }
+  return sum;
+}
+
+template <typename T>
+size_t ParameterServer<T>::PreComputeCapacity(const Keys &keys, const Lengths &lens) {
+  size_t capacity = 0;
+  for (size_t i = 0; i < keys.size(); i++) {
+    Key key = keys[i];
+    if (embedding_row_lens_.count(key) > 0) {
+      capacity += embedding_row_lens_[key] * lens[i];
+    } else {
+      MS_LOG(ERROR) << "Invalid embedding lookup id " << key;
+    }
+  }
+  return capacity;
+}
+
+template <typename T>
+inline bool ParameterServer<T>::ReadyForUpdateWeights() {
+  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
+}
+
+template <typename T>
+inline bool ParameterServer<T>::ReadyForAccumGrads() {
+  return grad_accum_count_ < weights_.size();
+}
+
+template <typename T>
+inline void ParameterServer<T>::ResetGradAccumCount() {
+  grad_accum_count_ = 0;
+  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
+    grads_accum_counter_[iter->first] = 0;
+  }
+}
+
+template <typename T>
+void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
+  ::ps::Start(0);
+  if (!::ps::IsServer()) {
+    std::cout << "This is not the Server" << std::endl;
+    return;
+  }
+  Init(func_graph);
+  thread_->join();
+}
+} // namespace ps
+} // namespace parallel
+} // namespace mindspore
+#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc
new file mode 100755
index 0000000000..274b7259b0
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/scheduler.cc
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ps/scheduler.h"
+#include <unistd.h>
+#include "ps/ps.h"
+
+namespace mindspore {
+namespace parallel {
+namespace ps {
+void Scheduler::Run() {
+  ::ps::Start(0);
+  while (true) {
+    sleep(1);
+  }
+}
+} // namespace ps
+} // namespace parallel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ps/scheduler.h b/mindspore/ccsrc/frontend/parallel/ps/scheduler.h
similarity index 100%
rename from mindspore/ccsrc/parallel/ps/scheduler.h
rename to mindspore/ccsrc/frontend/parallel/ps/scheduler.h
diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.cc b/mindspore/ccsrc/frontend/parallel/ps/util.cc
new file mode 100644
index 0000000000..fc63e88901
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/util.cc
@@ -0,0 +1,128 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "frontend/parallel/ps/util.h" +#include +#include "frontend/parallel/ps/common.h" +#include "common/utils.h" + +namespace mindspore { +namespace parallel { +namespace ps { +std::unordered_map Util::optimizer_to_ids{ + {kApplyMomentum, 0}, + {kSparseAdam, 1}, + {kSparseFtrl, 2}, +}; + +std::unordered_map Util::id_to_optimizers{ + {0, kApplyMomentum}, + {1, kSparseAdam}, + {2, kSparseFtrl}, +}; +bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); } + +bool Util::IsRoleOfWorker() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) { + return true; + } else { + return false; + } +} + +bool Util::IsRoleOfPServer() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) { + return true; + } else { + return false; + } +} + +bool Util::IsRoleOfScheduler() { + auto role = common::GetEnv(kEnvRole); + if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) { + return true; + } else { + return false; + } +} + +void Util::SetInternalEnvVar() { + if (IsParamServerMode()) { + auto comm_type = common::GetEnv(kEnvCommType); + if (comm_type.size() > 0) { + (void)common::SetEnv(kDmlcCommType, comm_type.c_str()); + } + auto interface = common::GetEnv(kEnvInterface); + if (interface.size() > 0) { + (void)common::SetEnv(kDmlcInterface, interface.c_str()); + } + auto server_num = common::GetEnv(kEnvPServerNum); + if (server_num.size() > 0) { + (void)common::SetEnv(kDmlcPServerNum, server_num.c_str()); + } + auto worker_num = common::GetEnv(kEnvWorkerNum); + if (worker_num.size() > 0) { + (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str()); + } + if (IsRoleOfScheduler()) { + (void)common::SetEnv(kDmlcRole, kRoleOfScheduler); + } else if (IsRoleOfPServer()) { + (void)common::SetEnv(kDmlcRole, kRoleOfPServer); + } else if (IsRoleOfWorker()) { + (void)common::SetEnv(kDmlcRole, kRoleOfWorker); + } + auto scheduler_host = common::GetEnv(kEnvSchedulerHost); + if (scheduler_host.size() > 0) { + (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str()); + } + auto scheduler_port = common::GetEnv(kEnvSchedulerPort); + if (scheduler_port.size() > 0) { + (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str()); + } + } +} + +int Util::optimizer_id(std::string name) { + if (optimizer_to_ids.count(name) > 0) { + return optimizer_to_ids[name]; + } + return -1; +} + +std::string Util::optimizer_name(int id) { + if (id_to_optimizers.count(id) > 0) { + return id_to_optimizers[id]; + } + return ""; +} + +bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } + +int Util::LocalShard(int first_dim, int rank_id, int server_num) { + int shard_size = std::round((static_cast(first_dim)) / server_num); + int remain_size = first_dim % server_num; + if (remain_size == 0 || rank_id < server_num - 1) { + return shard_size; + } else { + return first_dim - (shard_size * (server_num - 1)); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/ps/util.h b/mindspore/ccsrc/frontend/parallel/ps/util.h new file mode 100644 index 0000000000..8947ad36de --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/util.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
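The row sharding used throughout the embedding-table code is concentrated in Util::LocalShard above: every server receives round(first_dim / server_num) rows, and the last server absorbs whatever remainder is left. The snippet below is a self-contained re-implementation of that arithmetic for illustration only (it is not the MindSpore API):

```cpp
#include <cmath>
#include <iostream>

// Mirror of Util::LocalShard: per-server row count for a first dimension
// of `first_dim` split across `server_num` servers.
int LocalShard(int first_dim, int rank_id, int server_num) {
  int shard_size = std::round(static_cast<float>(first_dim) / server_num);
  int remain_size = first_dim % server_num;
  if (remain_size == 0 || rank_id < server_num - 1) {
    return shard_size;
  }
  return first_dim - shard_size * (server_num - 1);
}

int main() {
  // 10 embedding rows over 3 servers: ranks 0 and 1 own 3 rows, rank 2 owns 4.
  for (int rank = 0; rank < 3; ++rank) {
    std::cout << "rank " << rank << " owns " << LocalShard(10, rank, 3) << " rows" << std::endl;
  }
  return 0;
}
```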
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ + +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace parallel { +namespace ps { +class Util { + public: + static bool IsParamServerMode(); + static bool IsRoleOfWorker(); + static bool IsRoleOfPServer(); + static bool IsRoleOfScheduler(); + static void SetInternalEnvVar(); + static int optimizer_id(std::string name); + static std::string optimizer_name(int id); + static bool is_optimizer(std::string name); + static int LocalShard(int first_dim, int rank_id, int server_num); + + private: + static std::unordered_map optimizer_to_ids; + static std::unordered_map id_to_optimizers; +}; +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h new file mode 100644 index 0000000000..9ecbc28fc5 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -0,0 +1,259 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
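Worker::Push, defined below, flattens several tensors into one value buffer and records each tensor's element count in lens, which is the ps-lite KV payload layout. Here is a minimal standalone sketch of that packing, using plain std::vector in place of ::ps::SArray:

```cpp
#include <iostream>
#include <vector>

// Pack several float buffers into one flat payload plus per-segment lengths,
// the same layout Worker::Push sends to the server (illustrative sketch).
int main() {
  std::vector<std::vector<float>> tensors = {{1.f, 2.f}, {3.f, 4.f, 5.f}};
  std::vector<float> payload;
  std::vector<int> lens;
  for (const auto &t : tensors) {
    payload.insert(payload.end(), t.begin(), t.end());
    lens.push_back(static_cast<int>(t.size()));
  }
  // The receiver rebuilds segment i from lens: its offset is sum(lens[0..i)).
  size_t offset = 0;
  for (size_t i = 0; i < lens.size(); ++i) {
    std::cout << "segment " << i << " starts at " << offset << ", length " << lens[i] << std::endl;
    offset += lens[i];
  }
  return 0;
}
```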
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ + +#include +#include +#include +#include +#include +#include "ps/ps.h" +#include "utils/log_adapter.h" +#include "frontend/parallel/ps/util.h" +#include "frontend/parallel/ps/common.h" +#include "frontend/parallel/ps/worker_proxy.h" + +namespace mindspore { +namespace parallel { +namespace ps { +template +class Worker { + public: + static Worker &GetInstance() { + static Worker instance; + return instance; + } + + void Run(); + void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes); + void Pull(const size_t key, void *dev_addr, const size_t size); + size_t SetParamKey(const std::string ¶m_name); + void SetKeyOptimId(size_t key, const std::string &optimizer_name); + void SetOptimInputShapes(size_t key, const std::vector &shape); + void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); + void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); + void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); + void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); + + private: + Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} + ~Worker() { ::ps::Finalize(0, true); } + Worker(const Worker &) = delete; + Worker &operator=(const Worker &) = delete; + + bool IsKeyInit(const size_t key); + size_t GetParamKey(const std::string ¶m_name); + void InitPSOptimId(const size_t param_key); + void InitPSOptimInputShapes(const size_t key); + void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); + static void EmbeddingLookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &ranges, + std::vector>> *sliced) {} + + std::shared_ptr> kv_worker_; + bool running_; + size_t key_cnt_; + std::map param_to_key_; + std::map init_keys_; + std::map key_to_optimId_; + std::map>> key_to_optim_shapes_; +}; + +template +void Worker::Run() { + if (running_) { + MS_LOG(INFO) << "'Worker is already running."; + return; + } + + ::ps::Start(0); + if (!::ps::IsWorker()) { + MS_LOG(EXCEPTION) << "The role is not worker."; + } + kv_worker_ = std::make_shared>(0, 0, 1); + running_ = true; +} + +template +void Worker::Push(const std::vector &keys, std::vector addrs, const std::vector &sizes) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + ::ps::SArray total_buffer(total_size, 0); + size_t offset = 0; + for (size_t i = 0; i < sizes.size(); i++) { + memcpy(total_buffer.data() + offset / sizeof(T), addrs[i], sizes[i] * sizeof(T)); + offset += sizes[i] * sizeof(T); + } + kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); +} + +template +void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { + ::ps::SArray variables(size / sizeof(T), 0); + kv_worker_->Wait(kv_worker_->ZPull({key}, &variables)); + memcpy(dev_addr, variables.data(), size); +} + +template +void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd) { + kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, &lookup_result, cmd); +} + +template +void Worker::InitPSParamData(const std::vector &keys, void *origin_addr, size_t size) { + ::ps::SArray addr(reinterpret_cast(origin_addr), size / sizeof(T)); 
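+  // The ZPush below ships the whole parameter once under kInitWeightsCmd; the server
+  // copies it into a fresh weight buffer and allocates a zero-filled gradient buffer of
+  // the same length (see HandleInitWeights), after which the worker marks the key initialized.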
+ ::ps::SArray<::ps::Key> key(keys); + ::ps::SArray lens; + lens.push_back(addr.size()); + kv_worker_->Wait(kv_worker_->ZPush(key, addr, lens, kInitWeightsCmd)); + init_keys_[key[0]] = true; +} + +template +void Worker::SetOptimInputShapes(size_t key, const std::vector &shape) { + if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { + key_to_optim_shapes_[key] = {shape}; + } else { + key_to_optim_shapes_[key].push_back(shape); + } +} + +template +void Worker::InitPSOptimInputShapes(const size_t key) { + ::ps::SArray<::ps::Key> keys; + ::ps::SArray shape_len; + ::ps::SArray all_shape; + std::vector> shapes = key_to_optim_shapes_[key]; + for (auto shape : shapes) { + keys.push_back(key); + if (shape.size() == 0) { + shape_len.push_back(1); + all_shape.push_back(1); + } else { + shape_len.push_back(SizeToInt(shape.size())); + for (auto dim : shape) { + all_shape.push_back(static_cast(dim)); + } + } + } + MS_LOG(ERROR) << "keys:" << keys; + MS_LOG(ERROR) << "shape_len:" << shape_len; + MS_LOG(ERROR) << "all_shape:" << all_shape; + if (!init_keys_[key]) { + init_keys_[key] = true; + } + kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); +} + +template +bool Worker::IsKeyInit(const size_t key) { + if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) { + return false; + } + return true; +} + +template +size_t Worker::SetParamKey(const std::string ¶m_name) { + size_t key = UINT64_MAX; + if (param_to_key_.count(param_name)) { + key = param_to_key_[param_name]; + MS_LOG(INFO) << param_name << " key is already set: key value is " << key; + } else { + key = key_cnt_++; + param_to_key_[param_name] = key; + MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name; + } + return key; +} + +template +size_t Worker::GetParamKey(const std::string ¶m_name) { + size_t key = kInvalidKey; + if (param_to_key_.find(param_name) != param_to_key_.end()) { + key = param_to_key_[param_name]; + MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key; + } + return key; +} + +template +void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) { + key_to_optimId_[key] = Util::optimizer_id(optimizer_name); +} + +template +void Worker::InitPSOptimId(const size_t param_key) { + if (key_to_optimId_.count(param_key) == 0) { + MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key; + } + int optim_id = key_to_optimId_[param_key]; + + ::ps::SArray<::ps::Key> keys = {param_key}; + ::ps::SArray optim_id_vals = {static_cast(optim_id)}; + ::ps::SArray optim_id_lens = {optim_id_vals.size()}; + kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd); +} + +template +void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, + const std::vector &sizes) { + bool has_init = IsKeyInit(keys[0]); + if (has_init) { + MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; + return; + } + ::ps::SArray shapes_val; + for (auto dim : shapes) { + shapes_val.push_back(static_cast(dim)); + } + kv_worker_->Wait(kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray(sizes))); +} + +template +// Initialize parameters and optimizer kernels of Parameter Server. 
+void Worker<T>::InitPSParamAndOptim(const std::string &param_name, void *param_data, size_t param_size) {
+  size_t param_key = GetParamKey(param_name);
+  if (param_key == kInvalidKey) {
+    MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned.";
+    return;
+  }
+  bool init = IsKeyInit(param_key);
+  if (!init) {
+    MS_LOG(INFO) << "Init parameter and optimizer on the parameter server side for " << param_name;
+    // No need to push embedding table data to the Parameter Server.
+    if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) {
+      InitPSParamData({param_key}, param_data, param_size);
+    }
+    InitPSOptimId(param_key);
+    InitPSOptimInputShapes(param_key);
+  }
+}
+
+template <typename T>
+void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
+  kv_worker_->AddEmbeddingTable(key, row_count);
+}
+
+} // namespace ps
+} // namespace parallel
+} // namespace mindspore
+#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
new file mode 100644
index 0000000000..a0f58d39a4
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h
@@ -0,0 +1,311 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
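WorkerProxy, declared below, collects one partial reply per server for an embedding lookup and merges them by element-wise addition (see AddLookupCB further down). A minimal sketch of that merge, under the reading that each server replies with a full-length vector that is zero outside the rows it owns:

```cpp
#include <iostream>
#include <vector>

// Element-wise accumulation of per-server partial lookup replies,
// as in WorkerProxy::AddLookupCB (sketch under the stated assumption).
int main() {
  std::vector<std::vector<float>> replies = {
      {1.f, 2.f, 0.f, 0.f},   // server 0 owns rows 0-1
      {0.f, 0.f, 3.f, 4.f}};  // server 1 owns rows 2-3
  std::vector<float> merged(4, 0.f);
  for (const auto &reply : replies) {
    for (size_t i = 0; i < reply.size(); ++i) {
      merged[i] += reply[i];
    }
  }
  for (float v : merged) {
    std::cout << v << " ";  // prints: 1 2 3 4
  }
  std::cout << std::endl;
  return 0;
}
```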
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ +#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ + +#include +#include +#include +#include +#include +#include "ps/ps.h" +#include "frontend/parallel/ps/util.h" + +namespace mindspore { +namespace parallel { +namespace ps { +template +class WorkerProxy : public ::ps::KVWorker { + public: + using Worker = ::ps::KVWorker; + using Callback = std::function; + using SlicedKVs = std::vector>>; + using Slicer = + std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; + using ::ps::SimpleApp::obj_; + explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { + using _1 = std::placeholders::_1; + using _2 = std::placeholders::_2; + using _3 = std::placeholders::_3; + lookup_customer_ = std::unique_ptr<::ps::Customer>( + new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); + lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); + init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); + push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); + broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); + } + ~WorkerProxy() override = default; + + void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); + void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, + int priority = 0); + int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); + void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, + int cmd = 0, int priority = 0); + + private: + template + int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, + const Callback &cb); + void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced); + void ProcessLookupResult(const ::ps::Message &msg); + void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, + const Slicer &slicer); + + std::unique_ptr<::ps::Customer> lookup_customer_; + std::unordered_map<::ps::Key, std::shared_ptr>> embedding_table_ranges_; + std::unordered_map>> lookup_results_; + std::mutex mutex_; + Slicer lookup_slicer_; + Slicer init_embedding_slicer_; + Slicer push_slicer_; + Slicer broadcast_slicer_; + std::unordered_map lookup_callbacks_; +}; + +template +void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { + uint64_t begin = 0; + uint64_t end = 0; + int server_num = ::ps::NumServers(); + for (int i = 0; i < server_num; i++) { + int local_row_cnt = Util::LocalShard(row_count, i, server_num); + if (i == 0) { + end = local_row_cnt - 1; + } else { + begin = end + 1; + end += local_row_cnt; + } + ::ps::Range range(begin, end); + if (embedding_table_ranges_.count(key) == 0) { + 
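+      // First shard recorded for this key: lazily create the per-server range list.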
embedding_table_ranges_[key] = std::make_shared>(); + } + embedding_table_ranges_[key]->push_back(range); + } +} + +template +void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, + int priority) { + int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = lookup_ids; + kvs.lens = lens; + kvs.priority = priority; + Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); + lookup_customer_->WaitRequest(ts); +} + +template +int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, const Callback &cb, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); + return ts; +} + +template +void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, + const ::ps::SArray &lens, int cmd, int priority) { + int ts = obj_->NewRequest(::ps::kServerGroup); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.vals = vals; + kvs.lens = lens; + kvs.priority = priority; + Send(obj_, ts, true, false, cmd, kvs, push_slicer_); + obj_->WaitRequest(ts); +} + +template +template +int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + C *lookup_result, int cmd, const Callback &cb) { + int ts = lookup_customer_->NewRequest(::ps::kServerGroup); + const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { + mutex_.lock(); + auto &kvs = lookup_results_[ts]; + mutex_.unlock(); + + size_t total_len = 0; + const auto &s = kvs[0]; + for (size_t i = 0; i < s.lens.size(); i++) { + total_len += s.lens[i]; + } + lookup_result->resize(total_len, 0); + T *result_addr = lookup_result->data(); + + for (const auto &s : kvs) { + size_t offset = 0; + for (size_t i = 0; i < s.vals.size(); i++) { + result_addr[offset++] += s.vals[i]; + } + } + + mutex_.lock(); + lookup_results_.erase(ts); + mutex_.unlock(); + if (cb) cb(); + }; + lookup_callbacks_[ts] = callback; + return ts; +} + +template +void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + int *data = send.lens.data(); + size_t size = send.lens.size(); + std::vector lookup_ids(data, data + size); + std::sort(lookup_ids.begin(), lookup_ids.end()); + + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + + size_t index = 0; + for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + auto &kvs = sliced->at(i).second; + + auto lookup_id = static_cast(lookup_ids[index]); + while (lookup_id >= begin && lookup_id <= end) { + kvs.vals.push_back(lookup_id); + if (++index >= lookup_ids.size()) { + break; + } + lookup_id = static_cast(lookup_ids[index]); + } + kvs.keys.push_back(key); + kvs.lens.push_back(kvs.vals.size()); + + if (kvs.vals.size() == 0) { + sliced->at(i).first = false; + } else { + sliced->at(i).first = true; + } + } +} + +template +void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + const Key &key = 
send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + for (size_t i = 0; i < ranges.size(); i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced) { + auto server_num = ::ps::Postoffice::Get()->num_servers(); + sliced->resize(server_num); + for (int i = 0; i < server_num; i++) { + sliced->at(i).first = true; + sliced->at(i).second = send; + } +} + +template +void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { + int ts = msg.meta.timestamp; + if (msg.meta.pull) { + CHECK_GE(msg.data.size(), (size_t)2); + ::ps::KVPairs kvs; + kvs.keys = msg.data[0]; + kvs.vals = msg.data[1]; + if (msg.data.size() > (size_t)2) { + kvs.lens = msg.data[2]; + } + mutex_.lock(); + lookup_results_[ts].push_back(kvs); + mutex_.unlock(); + } + if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { + const auto &cb = lookup_callbacks_[ts]; + cb(); + lookup_callbacks_.erase(ts); + } +} + +template +void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, + const ::ps::KVPairs &kvs, const Slicer &slicer) { + SlicedKVs sliced; + slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); + + for (size_t i = 0; i < sliced.size(); i++) { + const auto &s = sliced[i]; + if (!s.first) continue; + ::ps::Message msg; + msg.meta.app_id = customer->app_id(); + msg.meta.customer_id = customer->customer_id(); + msg.meta.request = true; + msg.meta.push = push; + msg.meta.pull = pull; + msg.meta.head = cmd; + msg.meta.timestamp = timestamp; + msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); + msg.meta.priority = kvs.priority; + const auto &kvs = s.second; + if (kvs.keys.size()) { + msg.AddData(kvs.keys); + msg.AddData(kvs.vals); + if (kvs.lens.size()) { + msg.AddData(kvs.lens); + } + } + ::ps::Postoffice::Get()->van()->Send(msg); + } +} +} // namespace ps +} // namespace parallel +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ diff --git a/mindspore/ccsrc/parallel/status.h b/mindspore/ccsrc/frontend/parallel/status.h similarity index 100% rename from mindspore/ccsrc/parallel/status.h rename to mindspore/ccsrc/frontend/parallel/status.h diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc new file mode 100644 index 0000000000..8d54eb454a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -0,0 +1,1187 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
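LookupIdSlicer above walks the sorted lookup ids once and assigns each contiguous run to the server whose row range contains it. The following self-contained sketch reproduces that bucketing with plain STL types (illustrative only, not the ps-lite API):

```cpp
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<int> ids = {7, 0, 3, 9, 4};
  std::sort(ids.begin(), ids.end());
  // Per-server inclusive [begin, end] row ranges, e.g. 10 rows over 2 servers.
  std::vector<std::pair<int, int>> ranges = {{0, 4}, {5, 9}};
  size_t index = 0;
  for (size_t s = 0; s < ranges.size(); ++s) {
    std::vector<int> bucket;
    // Sorted ids let us consume each server's slice in a single pass.
    while (index < ids.size() && ids[index] >= ranges[s].first && ids[index] <= ranges[s].second) {
      bucket.push_back(ids[index++]);
    }
    std::cout << "server " << s << " gets " << bucket.size() << " ids" << std::endl;
  }
  return 0;
}
```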
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/step_auto_parallel.h"
+
+#include <inttypes.h>
+#include <sys/time.h>
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "ir/anf.h"
+#include "ir/param_value.h"
+#include "ir/tensor.h"
+#include "frontend/optimizer/opt.h"
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
+#include "frontend/parallel/auto_parallel/edge_costmodel.h"
+#include "frontend/parallel/auto_parallel/graph_costmodel.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
+#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
+#include "frontend/parallel/context.h"
+#include "frontend/parallel/ops_info/tmp_identity_info.h"
+#include "frontend/parallel/ops_info/reshape_info.h"
+#include "frontend/parallel/step_parallel.h"
+#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "pipeline/jit/pipeline.h"
+
+namespace mindspore {
+namespace parallel {
+bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  // assume no change to graph
+  bool changes = false;
+  // control whether use model_parallel mode
+  if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
+      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
+    return changes;
+  }
+  // check whether strategy_search_mode is valid
+  std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
+  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
+    // Setting searching mode: dynamic programming as default.
+    strategy_search_mode = DYNAMIC_PROGRAMMING;
+    MS_LOG(INFO) << "No strategy searching mode indicated, using DP searching mode as default";
+  }
+
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
+
+  if (MsContext::GetInstance()->save_graphs_flag()) {
+    draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
+  }
+  MS_LOG(INFO) << "Now entering step auto parallel";
+  TOTAL_OPS = 0;
+  AnfNodePtr ret = root->get_return();
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+
+  if (ParallelInit() != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Parallel init failed";
+  }
+
+  // mark the forward cnodes; parallel only cares about these nodes
+  MarkForwardCNode(root);
+
+  if (FindCommunicationOp(all_nodes)) {
+    MS_LOG(EXCEPTION) << "The graph contains a communication op";
+  }
+
+  // search parallelization strategy
+  if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
+    if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
+    }
+  } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
+    if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
+  }
+
+  (void)gettimeofday(&end_time, nullptr);
+  uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
+
+  root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
+  return changes;
+}
+
+// Given the node, return whether each input is a parameter or an output of an operator.
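+// (An input is marked true only when it is a Parameter whose default value requires gradients;
+// CNode and tensor/RefKey value-node inputs are recorded as false.)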
+// The returned boolean vector should be the same order of the inputs, thus its implementation +// is closely consistent with ExtractShape() in step_parallel.cc +std::vector ExtractInputParameterByNode(const CNodePtr &node) { + std::vector is_parameter; + std::vector node_inputs{node->inputs()}; + for (size_t i = 1; i < node_inputs.size(); ++i) { + auto input = node_inputs[i]; + + if (input->isa()) { + auto input_parameter = input->cast(); + if (input_parameter->has_default()) { + bool requires_grad = input_parameter->default_param()->requires_grad(); + is_parameter.push_back(requires_grad); + } else { + is_parameter.push_back(false); + } + } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { + is_parameter.push_back(false); + } + } + return is_parameter; +} + +// Given the type, return the number of bytes to represent this type +size_t GetLengthOfDataType(const TypePtr &type) { + switch (type->type_id()) { + case kNumberTypeBool: + return sizeof(bool); + case kNumberTypeInt8: + return sizeof(int8_t); + case kNumberTypeInt16: + return sizeof(int16_t); + case kNumberTypeInt32: + return sizeof(int32_t); + case kNumberTypeInt64: + return sizeof(int64_t); + case kNumberTypeUInt8: + return sizeof(uint8_t); + case kNumberTypeUInt16: + return sizeof(uint16_t); + case kNumberTypeUInt32: + return sizeof(uint32_t); + case kNumberTypeUInt64: + return sizeof(uint64_t); + case kNumberTypeFloat16: + return sizeof(float) / 2; + case kNumberTypeFloat32: + return sizeof(float); + case kNumberTypeFloat64: + return sizeof(double); + case kNumberTypeInt: + return sizeof(int); + case kNumberTypeUInt: + return sizeof(unsigned int); + case kNumberTypeFloat: + return sizeof(float); + default: + MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); + } +} + +size_t GetInputsTypeLen(const AnfNodePtr &input) { + MS_EXCEPTION_IF_NULL(input); + if (!input->isa() && !input->isa() && !IsValueNode(input)) { + MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; + } + + size_t input_type_len = 0; + auto type = input->Type(); + MS_EXCEPTION_IF_NULL(type); + if (type->isa()) { + auto input_element_type = type->cast()->element(); + input_type_len = GetLengthOfDataType(input_element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); + } + return input_type_len; +} + +std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector inputs_type_len; + std::vector node_inputs{node->inputs()}; + + // extract input element length + for (auto &input : node_inputs) { + if (IsValueNode(input)) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + std::vector parameters = FindParameterByRefKeyNode(input, func_graph); + if (parameters.size() != 1) { + MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; + } + inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); + } else if (input->isa() || input->isa() || IsValueNode(input)) { + // extract input shape from parameter and apply node + inputs_type_len.push_back(GetInputsTypeLen(input)); + } + } + return inputs_type_len; +} + +std::vector ExtractOutputTypeByNode(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + std::vector outputs_type; + // extract output element type + auto primary_output_type = node->Type(); + MS_EXCEPTION_IF_NULL(primary_output_type); + if (primary_output_type->isa()) { + // in this case, the output is a tuple + auto tuple_output_type = primary_output_type->cast(); + auto elements = 
tuple_output_type->elements(); + for (auto &ele : elements) { + if (ele->isa()) { + auto ele_element_type = ele->cast()->element(); + outputs_type.push_back(ele_element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); + } + } + } else { + // in this case, the output is a single tensor + if (primary_output_type->isa()) { + auto element_type = primary_output_type->cast()->element(); + outputs_type.push_back(element_type); + } else { + MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); + } + } + return outputs_type; +} + +bool IsElementWiseOperator(const std::string &op_name) { + static const std::set elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, + SQRT, CAST, POW, EXP, LOG, COS, + ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID}; + auto iter = elementwise_op.find(op_name); + return (iter != elementwise_op.end()); +} + +bool IsSplittableOperator(const std::string &op_name) { + // clang-format off + static const std::set splittable_op = + {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU, + FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK, + REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING, + MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP, + LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, + STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, + SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS}; + // clang-format on + + auto iter = splittable_op.find(op_name); + return (iter != splittable_op.end()); +} + +bool IsAutoParallelCareNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + ValueNodePtr prim_node = cnode->input(0)->cast(); + if (prim_node == nullptr) { + return false; + } + PrimitivePtr prim = GetValueNode(prim_node); + if (prim == nullptr) { + return false; + } + bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name()); + if (bool_result) { + MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name(); + } else if (prim->name() == CAST) { + if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) { + // Do not care CASTs from optimizer + return false; + } + return true; + } + return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name()); +} + +OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(cnode); + auto attrs = prim->attrs(); + std::vector shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape"; + } + // Create an OperatorInfo instance + OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list); + MS_EXCEPTION_IF_NULL(operator_info); + // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not) + std::vector parameter_info = ExtractInputParameterByNode(cnode); + if (operator_info->set_is_parameter(parameter_info) != SUCCESS) { + MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name(); + return nullptr; + } + // Set the data 
+OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
+  MS_EXCEPTION_IF_NULL(prim);
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto attrs = prim->attrs();
+  std::vector<Shapes> shape_list = ExtractShape(cnode);
+  if (shape_list.empty()) {
+    MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
+  }
+  // Create an OperatorInfo instance
+  OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
+  MS_EXCEPTION_IF_NULL(operator_info);
+  // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
+  std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
+  if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
+    MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
+    return nullptr;
+  }
+  // Set the data type for inputs and outputs of this OperatorInfo
+  auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
+  auto outputs_type = ExtractOutputTypeByNode(cnode);
+  std::vector<size_t> outputs_type_length;
+  outputs_type_length.reserve(outputs_type.size());
+  std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
+                 GetLengthOfDataType);
+  if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
+    MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
+    return nullptr;
+  }
+  if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
+    MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
+    return nullptr;
+  }
+  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
+  // the ANF graph
+  auto &inputs = cnode->inputs();
+  std::vector<ValuePtr> input_value;
+  for (size_t index = 1; index < inputs.size(); ++index) {
+    if (inputs[index]->isa<ValueNode>()) {
+      input_value.push_back(GetValueNode(inputs[index]));
+    } else {
+      input_value.emplace_back(nullptr);
+    }
+  }
+  operator_info->set_input_value(input_value);
+  operator_info->set_outputs_dtype(cnode->Type());
+  operator_info->set_cnode(cnode);
+  // key of the strategy map
+  std::string strategy_key_name = NodeParameterName(cnode);
+  bool load_strategy_from_ckpt =
+    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
+  // If no strategy has been configured for this operator, candidate strategies are generated for
+  // the auto-strategy search; if this primitive is CAST, we ignore the user-specified strategy.
+  // If the strategy is configured to be loaded from a checkpoint, loading it from the checkpoint is preferred.
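+  // Illustrative example: on 8 devices a user may attach strategy ((4, 2), (2, 1)) to a
+  // MatMul, sharding the two input matrices into 4x2 and 2x1 blocks; StrategyFound(attrs)
+  // is then true and the branch below costs that fixed strategy instead of searching.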
+  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
+    // Compute split_flag_list_, indicating which input has the batch dimension. This is ONLY used in
+    // preparation for the BatchParallelInfo operator
+    operator_info->ComputeBatchSplitFlagList();
+    if (operator_info->GenerateStrategies(0) != SUCCESS) {
+      MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed.";
+      return nullptr;
+    }
+  } else {
+    // In this case, the configured strategy should be extracted to help set the cost
+    StrategyPtr strategyPtr;
+    if (load_strategy_from_ckpt) {
+      strategyPtr = (*stra_map)[strategy_key_name];
+    } else {
+      strategyPtr = parallel::ExtractStrategy(attrs);
+    }
+    if (strategyPtr != nullptr) {
+      if (prim->name() == RESHAPE) {
+        MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
+      }
+      // Set the cost for this configured strategy
+      if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed";
+      } else if (FULLY_USE_DEVICES) {
+        // If configured to fully use devices, check the user-specified strategy
+        int32_t used_devices = operator_info->used_devices();
+        MS_EXCEPTION_IF_NULL(g_device_manager);
+        auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size();
+        // 'used_devices == 1' means the all-ones strategy, which is valid in auto-parallel
+        if (used_devices == 1) {
+          return operator_info;
+        }
+        // 'used_devices == -1' means that 'used_devices_' is not set
+        if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) {
+          MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, "
+                            << "but the specified strategy uses devices: " << used_devices
+                            << ", total devices: " << total_device_num;
+        }
+      }
+    }
+  }
+  return operator_info;
+}
+
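+// Sanity example for the FULLY_USE_DEVICES check (illustrative): with 8 devices in stage 0,
+// strategy ((4, 2), (2, 1)) occupies all 8 devices and passes; ((2, 2), (2, 1)) uses only
+// 4 of 8 and raises the exception; the all-ones strategy is tolerated because
+// 'used_devices == 1' returns early.
+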
+// Using CNode's UniqueIds to construct nodes
+Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
+  entire_costgraph = std::make_shared<CostGraph>();
+  entire_costgraph->SetDeviceMemoryAndCostParameter();
+  // The map from a CNode's UniqueId to its OperatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract the strategy from the checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
+  // Step 1
+  for (auto &node : all_nodes) {
+    // NOTE: we only care about splittable Primitive operators
+    auto cnode = node->cast<CNodePtr>();
+    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
+    if (bool_result) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
+        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
+        if (prev_cnode != nullptr) {
+          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
+        }
+      }
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    MS_EXCEPTION_IF_NULL(prim);
+
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
+    if (search_cnode == from_cnode_to_info.end()) {
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
+      if (operator_info == nullptr) {
+        return FAILED;
+      }
+      // Needed by rec_parser
+      operator_info->set_type(prim->name());
+      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
+
+      entire_costgraph->AddOperator(operator_info);
+      (void)cnode->set_operator_info(operator_info);
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
+      // Needed by rec_parser
+      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
+    } else {
+      // Two CNodes' UniqueIds should not be equal
+      MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
+                        << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                        << " is already set OperatorInfo: " << search_cnode->second->name()
+                        << ", Primitive: " << prim->name();
+    }
+  }
+
+  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
+  return SUCCESS;
+}
+
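+// Note on the two constructors (informal): every CNode has a distinct UniqueId, while copies
+// of a CNode produced when a graph is specialized several times share one UniqueIdThroughCopy.
+// The *ByUniqueIdTC variant below keys its map on the latter, so all copies of a CNode share
+// a single OperatorInfo, which is what the multi-subgraph case needs.
+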
+// Using CNode's UniqueIdThroughCopys to construct nodes
+Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
+  entire_costgraph = std::make_shared<CostGraph>();
+  entire_costgraph->SetDeviceMemoryAndCostParameter();
+  // The map from a CNode's UniqueIdThroughCopy to its OperatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  // extract the strategy from the checkpoint for multi-train
+  StrategyMap stra_map;
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
+    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
+    }
+  }
+  for (auto &node : all_nodes) {
+    // NOTE: we only care about splittable Primitive operators
+    auto cnode = node->cast<CNodePtr>();
+    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
+    if (bool_result) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    if (!IsAutoParallelCareNode(cnode)) {
+      // Needed by rec_parser
+      if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) {
+        auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node);
+        if (prev_cnode != nullptr) {
+          entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId()));
+        }
+      }
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+
+    // Find the OperatorInfo if it already exists
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
+    if (search_cnode == from_cnode_to_info.end()) {
+      // In this case, the corresponding OperatorInfo has not been created yet: create a new one.
+      auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map);
+      if (operator_info == nullptr) {
+        return FAILED;
+      }
+      // Needed by rec_parser
+      operator_info->set_type(prim->name());
+      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
+
+      entire_costgraph->AddOperator(operator_info);
+      (void)cnode->set_operator_info(operator_info);
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
+      // Needed by rec_parser
+      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
+    } else {
+      auto current_op_ptr = search_cnode->second;
+      if (current_op_ptr == nullptr) {
+        MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
+      } else {
+        bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) &&
+                             (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) &&
+                             (current_op_ptr->name().find(prim->name()) == std::string::npos);
+        if (is_find_wrong) {
+          MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
+                            << " does not match the Prim: " << prim->name();
+        }
+        (void)cnode->set_operator_info(current_op_ptr);
+        MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                     << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                     << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
+      }
+    }
+  }
+
+  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
+  return SUCCESS;
+}
+
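+// Edge bookkeeping (informal): an edge is identified by the producer's OperatorInfo name
+// joined to the consumer's by OPERATOR_TO_OPERATOR_CONNECTOR, plus the producer's output
+// index and the consumer's input index (i - 1, since input 0 of a CNode is the primitive
+// itself); IsEdgeInCostGraph() uses this triple to avoid adding duplicate edges.
+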
+void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
+  // Step 2
+  MS_LOG(INFO) << "Constructing edges for cost graph begins.";
+  for (auto &node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    bool bool_result_cnode = (cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0));
+    if (bool_result_cnode) {
+      continue;
+    }
+    auto &inputs = cnode->inputs();
+    ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>();
+    if (!IsAutoParallelCareNode(cnode)) {
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    size_t edge_count = 0;
+
+    for (size_t i = 1; i < inputs.size(); ++i) {
+      auto prev_cnode = inputs[i]->cast<CNodePtr>();
+      bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
+      if (bool_result_prev_cnode) {
+        continue;
+      }
+      ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
+      PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
+      size_t output_index = 0;
+
+      bool bool_result =
+        (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
+      while (bool_result) {
+        if (IsAutoParallelCareNode(prev_cnode)) {
+          std::string edge_name =
+            prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
+          // If the edge between these two operators has already been added, it will not be added again.
+          if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
+            break;
+          }
+          EdgePtr edge_ptr;
+          MS_LOG(INFO) << "Creating edge: " << edge_name;
+
+          bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
+                                 (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
+          if (follow_strategy) {
+            // Redistribution is not allowed on the edge.
+            // Elementwise operators have the same strategy as their previous operators.
+            edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
+                                              output_index, i - 1, false, true);
+          } else {
+            edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
+                                              output_index, i - 1, false);
+          }
+
+          // Init costs for this edge
+          if (edge_ptr->InitEdgeCost() != SUCCESS) {
+            MS_LOG(EXCEPTION) << "Edge cost initialization failed";
+          }
+          cnode->operator_info()->AddPrevEdge(edge_ptr);
+          prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
+          entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
+          MS_LOG(INFO) << "Successfully added the edge between " << prev_cnode->operator_info()->name() << " and "
+                       << cnode->operator_info()->name();
+          edge_count++;
+
+          break;
+        } else if (prev_prim->name() == TUPLE_GETITEM) {
+          // In this case, 'prev_anf_node' is 'tuple_getitem'; the actual predecessor is the node before
+          // this 'tuple_getitem'
+          MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator.";
+          output_index = IntToSize(GetValue<int>(GetValueNode(prev_cnode->input(2))));
+          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
+          bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
+          if (bool_result_tuple) {
+            break;
+          }
+          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
+          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
+          if (!IsAutoParallelCareNode(prev_cnode)) {
+            MS_LOG(EXCEPTION) << "Did not create OperatorInfo for: " << prev_prim->name();
+          }
+          MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, "
+                       << "and will create an edge between the Operator before "
+                       << "'tuple_getitem' and the Operator after 'tuple_getitem'.";
+        } else if (prev_prim->name() == DEPEND) {
+          // In this case, 'prev_anf_node' is 'depend'; the actual predecessor is the node before
+          // this 'depend'
+          MS_LOG(INFO) << "Jumping the 'depend' operator.";
+          prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
+          bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
+          if (bool_result_depend) {
+            break;
+          }
+          prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>();
+          prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>();
+          MS_LOG(INFO) << "Jumped the 'depend' operator, "
+                       << "and will create an edge between the Operator before "
+                       << "'depend' and the Operator after 'depend'.";
+        }
+        bool_result = (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) ||
+                      (prev_prim->name() == DEPEND);
+      }
+    }
+    MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
+  }
+
+  MS_LOG(INFO) << "Constructing edges for cost graph ends.";
+}
+
+std::pair<AnfNodePtr, std::vector<AnfNodePtr>> CNodeWithRefKeys(const AnfNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> refkeys;
+  if (cnode->isa<CNode>()) {
+    auto cnode_ptr = cnode->cast<CNodePtr>();
+    auto inputs = cnode_ptr->inputs();
+    for (auto &one_input : inputs) {
+      if (IsValueNode<RefKey>(one_input)) {
+        refkeys.push_back(one_input);
+      }
+    }
+    if (refkeys.size() >= 1) {
+      return std::make_pair(cnode, refkeys);
+    }
+  }
+  return {nullptr, refkeys};
+}
+
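+// Why augmentation is needed (informal): when one Parameter feeds several operators, e.g. a
+// shared weight consumed by two different MatMuls, their searched layouts must agree on that
+// Parameter. A TmpIdentityInfo pseudo-operator is introduced as the single producer of the
+// Parameter and one edge is added per consumer, so the DP algorithm prices any layout
+// disagreement as ordinary redistribution cost on those edges.
+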
+void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
+  // Step 3
+  for (auto &node : all_nodes) {
+    auto cnode_with_refkeys = CNodeWithRefKeys(node);
+    if ((!node->isa<Parameter>()) && (cnode_with_refkeys.first == nullptr)) {
+      continue;
+    }
+    std::string parameter_name;
+    AnfNodePtr target_parameter = nullptr;
+    AnfNodeIndexSet target_set;
+
+    if (cnode_with_refkeys.first != nullptr) {
+      // Dealing with the RefKey case
+      auto refkeys = cnode_with_refkeys.second;
+      auto cnode = cnode_with_refkeys.first;
+
+      auto cnode_ptr = cnode->cast<CNodePtr>();
+      if (cnode_ptr == nullptr || !IsValueNode<Primitive>(cnode_ptr->input(0))) {
+        continue;
+      }
+      if (!IsAutoParallelCareNode(cnode_ptr)) {
+        continue;
+      }
+
+      if (refkeys.size() > 1) {
+        MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << "'s inputs have more than 1 RefKey.";
+      }
+      MS_EXCEPTION_IF_NULL(cnode->func_graph());
+      auto cnode_func_graph = cnode->func_graph();
+      MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager());
+
+      // Find the RefKey being used
+      auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]];
+      for (auto &candidate : candidate_set_by_refkey) {
+        auto candidate_node = candidate.first;
+        auto c = candidate_node->cast<CNodePtr>();
+        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
+          continue;
+        }
+        if (!IsAutoParallelCareNode(c)) {
+          continue;
+        }
+        target_set.add(candidate);
+      }
+
+      // Find the corresponding Parameter being used
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      parameter_name = parameters[0]->cast<ParameterPtr>()->name();
+      target_parameter = parameters[0];
+      auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]];
+      for (auto &candidate : candidate_set_by_para) {
+        auto candidate_node = candidate.first;
+        auto c = candidate_node->cast<CNodePtr>();
+        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
+          continue;
+        }
+        if (!IsAutoParallelCareNode(c)) {
+          continue;
+        }
+        (void)target_set.insert(candidate);
+      }
+    } else if (node->isa<Parameter>()) {
+      // Dealing with the Parameter case
+      MS_EXCEPTION_IF_NULL(node->func_graph());
+      MS_EXCEPTION_IF_NULL(node->func_graph()->manager());
+      auto candidate_set = node->func_graph()->manager()->node_users()[node];
+      for (auto &candidate : candidate_set) {
+        auto candidate_node = candidate.first;
+        auto c = candidate_node->cast<CNodePtr>();
+        if (c == nullptr || !IsValueNode<Primitive>(c->input(0))) {
+          continue;
+        }
+        if (!IsAutoParallelCareNode(c)) {
+          continue;
+        }
+        (void)target_set.insert(candidate);
+      }
+      // In this case, node is a Parameter
+      parameter_name = node->cast<ParameterPtr>()->name();
+      target_parameter = node;
+    }
+    if (target_set.size() <= 1) {
+      continue;
+    }
+
+    // Rule out the case where a Parameter is used by an Operator that appears in multiple CNodes
+    std::set<std::string> target_without_duplicate;
+    for (auto &target : target_set) {
+      auto target_cnode = target.first->cast<CNodePtr>();
+      auto input_index = target.second;
+      (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
+    }
+    if (target_without_duplicate.size() <= 1) {
+      continue;
+    }
+
+    // At this point, it is certain that this Parameter (RefKey) is used by multiple Operators.
+    OperatorInfoPtr tmp_identity_ptr;
+    bool new_identity = false;
+    std::string tmp_identity_name;
+    auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name);
+    if (returned_identity != nullptr) {
+      // In this case, the TmpIdentityInfo instance has already been created
+      new_identity = false;
+      tmp_identity_ptr = returned_identity;
+      tmp_identity_name = tmp_identity_ptr->name();
+    } else {
+      // In this case, the TmpIdentityInfo instance has NOT been created yet, so a new one is created.
+      new_identity = true;
+      // 1) extract the input shape from this Parameter
+      MS_EXCEPTION_IF_NULL(target_parameter);
+      AbstractBasePtr abstract = target_parameter->abstract();
+      if (abstract == nullptr) {
+        MS_LOG(EXCEPTION) << "Failure: abstract is nullptr";
+      }
+      auto input_shape = dyn_cast<abstract::Shape>(abstract->GetShapeTrack());
+      if (input_shape == nullptr) {
+        MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr";
+      }
+      std::vector<int> shape_int = input_shape->shape();
+      Shape shape;
+      (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape),
+                           [](int sub_shape) { return static_cast<int32_t>(sub_shape); });
+      Shapes inputs_shape = {shape};
+      Shapes outputs_shape = {shape};
+      // 2) init the attr
+      std::unordered_map<std::string, ValuePtr> attr = {};
+
+      // Create the TmpIdentity instance
+      tmp_identity_ptr = std::make_shared<TmpIdentityInfo>(inputs_shape, outputs_shape, attr);
+      tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS));
+      TOTAL_OPS++;
+      tmp_identity_ptr->set_refkey_parameter_name(parameter_name);
+      // Set the parameter and type lengths for inputs and outputs
+      std::vector<bool> is_parameter;
+      auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
+      MS_EXCEPTION_IF_NULL(casted_target_parameter);
+      if (casted_target_parameter->has_default()) {
+        bool requires_grad = casted_target_parameter->default_param()->requires_grad();
+        is_parameter.push_back(requires_grad);
+      } else {
+        is_parameter.push_back(false);
+      }
+      if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
+      }
+      auto node_type = target_parameter->Type();
+      if (node_type->isa<mindspore::TensorType>()) {
+        auto input_element_type = node_type->cast<mindspore::TensorTypePtr>()->element();
+        std::vector<size_t> type_length = {GetLengthOfDataType(input_element_type)};
+        if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) {
+          MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed";
+        }
+      } else {
+        MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name();
+      }
+
+      // Generate strategies for this TmpIdentityInfo instance
+      if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Strategy search for Operator failed: " << tmp_identity_ptr->name();
+      }
+    }
+    // A flag recording whether new edges have been created or not
+    bool add_identity_edge = false;
+
+    // Create edges between this TmpIdentityInfo instance and subsequent Operator instances
+    for (auto &target : target_set) {
+      auto target_cnode = target.first->cast<CNodePtr>();
+      auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
+      auto input_index = target.second;
+
+      std::string edge_name =
+        std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
+      // If the edge between these two operators has already been added, it will not be added again.
+      if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
+        continue;
+      }
+      std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
+        edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
+
+      if (edge_ptr->InitEdgeCost() != SUCCESS) {
+        MS_LOG(EXCEPTION) << "Edge cost initialization failed";
+      }
+      target_cnode->operator_info()->AddPrevEdge(edge_ptr);
+      tmp_identity_ptr->AddSuccEdge(edge_ptr);
+      entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
+      MS_LOG(INFO) << "Successfully added the edge between " << tmp_identity_ptr->name() << " and "
+                   << target_cnode->operator_info()->name();
+      add_identity_edge = true;
+    }
+    if (new_identity && add_identity_edge) {
+      // Add the TmpIdentityInfo to the CostGraph only if BOTH conditions are satisfied
+      entire_costgraph->AddOperator(tmp_identity_ptr);
+    }
+  }
+}
+
+bool FindReshape(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
+    return false;
+  }
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  OperatorInfoPtr operator_info = cnode->operator_info();
+  if (operator_info == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: Primitive " << prim->ToString() << "'s OperatorInstance is nullptr";
+  }
+  if (prim->name() != RESHAPE) {
+    return false;
+  }
+  return true;
+}
+
+// Find the previous node and obtain its strategy_cost_ vector to get its layout vector.
+bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
+  // If the previous node is a parameter, handle it outside.
+  if (node->isa<Parameter>()) {
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
+    *pre_operator_info = cnode->operator_info();
+    *out_index = 0;
+    return true;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+  if (prim->name() == TUPLE_GETITEM) {
+    *out_index = GetTupleGetItemIndex(cnode);
+    // find tuple_getitem's previous node
+    auto pre_node = cnode->input(1);
+    if (!pre_node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "tuple_getitem's second input is not a cnode";
+    }
+    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
+    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
+      *pre_operator_info = pre_cnode->operator_info();
+      return true;
+    }
+    return false;
+  }
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
+      continue;
+    }
+    return true;
+  }
+  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; unless reshape is the first primitive, this indicates an error";
+  return false;
+}
+
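+// Traversal sketch (illustrative): for a chain MatMul -> tuple_getitem -> Reshape, the search
+// above returns MatMul's OperatorInfo with out_index read from the tuple_getitem constant;
+// Depend nodes are looked through on input 1 only, because their remaining inputs carry
+// control dependencies rather than data.
+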
+// Find the next node and obtain its strategy_cost_ vector to get its layout vector.
+// If reshape's output connects to several primitives, return the first layout found.
+bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[cnode];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
+      *next_operator_info = use_apply->operator_info();
+      *in_index = node_pair.second - 1;
+      return true;
+    }
+    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+                  << " " << (use_apply->operator_info() != nullptr);
+
+    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (!FindReshape(cnode)) {
+      continue;
+    }
+    MS_ASSERT(cnode->inputs().size() == 3);
+    // get the previous node's strategy_cost_
+    auto pre_node = cnode->input(1);
+    int32_t out_index = 0;
+    OperatorInfoPtr pre_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
+    if (pre_node->isa<Parameter>()) {
+      OperatorInfoPtr operator_info = cnode->operator_info();
+      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info->SetCostForReshapeWithParameter();
+      pre_operator_info = reshape_info;
+      pre_stra_costs = reshape_info->strategy_cost();
+    } else {
+      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      }
+      pre_stra_costs = pre_operator_info->strategy_cost();
+    }
+    // get the next node's strategy_cost_
+    int32_t in_index = 0;
+    OperatorInfoPtr next_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
+    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    if (!find_next_node) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+    }
+    // Set the input_layout and output_layout for reshape;
+    // init reshape and set the cost for each (input_layout, output_layout) pair.
+    OperatorInfoPtr operator_info = cnode->operator_info();
+    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+    reshape_info->set_pre_operator_name(pre_operator_info->name());
+    reshape_info->set_pre_operator_index(out_index);
+    if (find_next_node) {
+      next_stra_costs = next_operator_info->strategy_cost();
+      reshape_info->set_next_operator_name(next_operator_info->name());
+      reshape_info->set_next_operator_index(in_index);
+    }
+    bool is_prev_param = pre_node->isa<Parameter>();
+    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
+        SUCCESS) {
+      MS_LOG(EXCEPTION) << "reshape generate strategy_costs failed!";
+    }
+  }
+}
+
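+// Reshape costing in a nutshell (informal): for pre -> Reshape -> next, the candidate
+// (input_layout, output_layout) pairs of Reshape are drawn from the strategy_cost_ vectors
+// of its neighbours, and each pair is priced by the redistribution it implies; when the
+// input is a Parameter, SetCostForReshapeWithParameter() supplies the costs instead.
+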
+Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
+  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
+  // Step 1: Traverse the ANF graph, and create NODEs for the costgraph:
+  //         create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
+  //         for each OperatorInfo;
+  // Step 1.1: Deal with 'Reshape':
+  //         'Reshape' takes its previous operator's layout as its input layout, and takes its next operator's
+  //         layout as its output layout.
+  // Step 2: Traverse the ANF graph, and create EDGEs for the costgraph:
+  //         create the Edge object for each pair of OperatorInfos, and enumerate the parallelization strategies
+  //         for each edge, based on the strategies of the two OperatorInfos;
+  // Step 3: Augment the costgraph:
+  //         taking care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity
+  //         operator for this Parameter, and add an edge for the use of this Parameter by each
+  //         subsequent operator;
+  // Step 3.1: Calculate memory usage:
+  //         note that the memory usage calculation differs between the training phase and the inference phase.
+  // Step 4: Run the Dynamic Programming algorithm:
+  //         in this process, the cost is calculated based not only on the operators, but also on the edges. Here,
+  //         the edge cost is caused by the redistribution of an operator's output tensor layout to the next
+  //         operator's input tensor layout. Note that there may be several connected components in the costgraph,
+  //         and the DP algorithm runs on each of them.
+  //
+  // OUTPUT: the determined strategy for each operator.
+
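+  // Miniature walk-through (illustrative): for a two-operator graph MatMul -> Softmax on
+  // 4 devices, Step 1 builds two OperatorInfos with their candidate strategies, Step 2 adds
+  // a single edge whose cost covers any layout redistribution between them, Step 3 adds
+  // nothing (no shared Parameter), and Step 4's DP pass picks the strategy pair with the
+  // smallest total computation plus communication cost under the memory constraint.
+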
+  // Step 1
+  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
+    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
+  } else {
+    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
+  }
+  // Step 1.1
+  ReshapeCostCompute(all_nodes);
+  // Step 2
+  ConstructCostGraphEdges(all_nodes);
+  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
+               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
+
+  // Step 3: Augment the costgraph.
+  AugmentCostGraph(all_nodes);
+  MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size()
+               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
+
+  // Step 3.1: Calculate the memory usage
+  if (entire_costgraph->CalculateMemoryCost() != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Calculating memory cost failed.";
+  }
+
+  // Step 4: run the DP algorithm on the costgraph.
+  if (GetStrategy(entire_costgraph) != SUCCESS) {
+    MS_LOG(ERROR) << "Strategy search for the cost graph failed";
+    return FAILED;
+  }
+  MS_LOG(INFO) << "Searching strategy succeeded.";
+
+  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
+    MS_LOG(INFO) << "Init selected strategy succeeded.";
+  } else {
+    MS_LOG(EXCEPTION) << "Init selected strategy failed.";
+  }
+
+  // print the selected strategy
+  for (auto &op : entire_costgraph->GetOperators()) {
+    StrategyPtr s_strategy = op->selected_strategy();
+    MS_LOG(INFO) << op->name() << " : The strategy is:";
+    PrintStrategy(s_strategy);
+  }
+
+  return SUCCESS;
+}
+
+std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::string, std::string>::iterator &it,
+                                                          std::vector<std::vector<std::string>> input_tensor_names) {
+  for (size_t j = 0; j < input_tensor_names.size(); j++) {
+    for (size_t k = 0; k < input_tensor_names[j].size(); k++) {
+      if (it->first == input_tensor_names[j][k]) {
+        input_tensor_names[j][k] = it->second;
+        break;
+      }
+    }
+  }
+  return input_tensor_names;
+}
+
+CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) {
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) {
+    auto prev_cnode = cnode->input(1)->cast<CNodePtr>();
+    if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
+      return nullptr;
+    }
+    auto prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
+    while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) {
+      prev_cnode = prev_cnode->input(1)->cast<CNodePtr>();
+      if (prev_cnode == nullptr || !IsValueNode<Primitive>(prev_cnode->input(0))) {
+        return nullptr;
+      }
+      prev_prim = prev_cnode->input(0)->cast<ValueNodePtr>()->value()->cast<PrimitivePtr>();
+    }
+    return prev_cnode;
+  }
+  return nullptr;
+}
+
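+// Note (informal): the recursive-programming search below reuses Step 1 and Step 1.1, but
+// replaces the edge/DP machinery: operator types and input tensor names stand in for explicit
+// edges, and RecInputTensorNames() rewrites the recorded tuple_getitem pairs so that each
+// consumer's input name points at the real producer.
+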
+Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
+  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
+    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
+  } else {
+    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
+  }
+  ReshapeCostCompute(all_nodes);
+
+  auto ops = entire_costgraph->GetOperators();
+  std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
+  auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list();
+  for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) {
+    input_tensor_names = RecInputTensorNames(it++, input_tensor_names);
+  }
+  std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
+
+  std::shared_ptr<std::vector<std::vector<size_t>>> eli_list(new std::vector<std::vector<size_t>>);
+  std::shared_ptr<std::vector<size_t>> index_list(new std::vector<size_t>);
+  graph = EliminateGraph(graph, eli_list, index_list);
+
+  size_t num_device = g_device_manager->DeviceNum();
+  double device_memory = entire_costgraph->GetDeviceMemory();
+  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
+    MS_LOG(INFO) << "Partition succeeded with " << num_device << " devices.";
+  } else {
+    MS_LOG(ERROR) << "PartitionForAllDevices failed.";
+    return FAILED;
+  }
+
+  GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list);
+
+  if (entire_costgraph->InitSelectedStrategy() == SUCCESS) {
+    MS_LOG(INFO) << "Init selected strategy succeeded.";
+  } else {
+    MS_LOG(ERROR) << "Init selected strategy failed.";
+    return FAILED;
+  }
+
+  // print the selected strategy
+  for (auto &op : entire_costgraph->GetOperators()) {
+    StrategyPtr s_strategy = op->selected_strategy();
+    MS_LOG(INFO) << op->name() << " : The strategy is:";
+    PrintStrategy(s_strategy);
+  }
+
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
new file mode 100644
index 0000000000..f87d49b736
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_ +#define PARALLEL_STEP_AUTO_PARALLEL_H_ + +#include +#include +#include +#include +#include "ir/anf.h" +#include "frontend/optimizer/opt.h" +#include "frontend/parallel/status.h" +#include "pipeline/jit/pipeline.h" + +namespace mindspore { +namespace parallel { +bool IsSplittableOperator(const std::string &); + +bool IsAutoParallelCareNode(const CNodePtr &); + +// main step of Auto-parallel +bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); + +size_t GetLengthOfDataType(const TypePtr &type); + +std::vector ExtractInputParameterByNode(const CNodePtr &node); + +std::vector ExtractInputTypeLengthByNode(const CNodePtr &node); + +std::vector ExtractOutputTypeByNode(const CNodePtr &node); + +Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root); + +Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root); + +void ConstructCostGraphEdges(const std::vector &all_nodes); + +void AugmentCostGraph(const std::vector &all_nodes); + +Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root); + +Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root); + +std::vector> RecInputTensorNames(const std::map::iterator &it, + std::vector> input_tensor_names); + +CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node); +} // namespace parallel +} // namespace mindspore +#endif // PARALLEL_STEP_AUTO_PARALLEL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc new file mode 100644 index 0000000000..e9ff347fa3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -0,0 +1,2362 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/step_parallel.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/dynamic_creator.h" +#include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/parallel/graph_util/graph_info.h" +#include "frontend/parallel/graph_util/node_info.h" +#include "frontend/parallel/node_check.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" +#include "utils/comm_manager.h" +#include "utils/symbolic.h" +#include "pipeline/jit/static_analysis/prim.h" + +using mindspore::tensor::Tensor; + +namespace mindspore { +namespace parallel { +static const std::set COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; +static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; +// g_RefMap, for CNode B input i is a RefKey[Parameter C], +// it will be one item in map with key: C, and value: (B, i) +static std::map> g_RefMap; + +void SetCommunicationOpGroupLabel(std::vector new_node_input) { + if (new_node_input.empty()) { + return; + } + + ValueNodePtr prim_anf_node = new_node_input[0]->cast(); + PrimitivePtr prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + + auto attrs = prim->attrs(); + auto iter = attrs.find(GROUP); + if (iter != attrs.end()) { + auto value = iter->second; + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + std::string hash_name = value->cast()->value(); + MS_EXCEPTION_IF_NULL(g_device_manager); + std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name); + (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name)); + } + } +} + +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { + MS_EXCEPTION_IF_NULL(node); + OperatorArgs arg_forward = op.second; + ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); + MS_EXCEPTION_IF_NULL(pyop_instance); + OperatorParams params = arg_forward.second; + + std::vector new_node_input = {NewValueNode(pyop_instance), node}; + if (!params.empty()) { + for (auto ¶m : params) { + AnfNodePtr val = NewValueNode(param.first.second); + MS_EXCEPTION_IF_NULL(val); + int32_t position = param.second; + (void)new_node_input.insert(new_node_input.begin() + position, val); + } + } + + // if the op have 'group' attr, set the rank list name for the op + SetCommunicationOpGroupLabel(new_node_input); + return new_node_input; +} + +void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name) { + // insert new node before the node + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + std::vector node_input = CreateInput(op, pre_node, instance_name); + CNodePtr new_node = func_graph->NewCNode(node_input); + MS_EXCEPTION_IF_NULL(new_node); + if (instance_name.find(SPLIT_SENS) == std::string::npos) { + new_node->set_in_forward_flag(true); // mark forward flag + } + auto new_node_value = node_input[0]->cast(); + 
+void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
+                const FuncGraphPtr &func_graph, const std::string &instance_name) {
+  // insert a new node before the given node
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ScopePtr scope = node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  std::vector<AnfNodePtr> node_input = CreateInput(op, pre_node, instance_name);
+  CNodePtr new_node = func_graph->NewCNode(node_input);
+  MS_EXCEPTION_IF_NULL(new_node);
+  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
+    new_node->set_in_forward_flag(true);  // mark the forward flag
+  }
+  auto new_node_value = node_input[0]->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(new_node_value);
+  PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
+  new_node_prim->set_instance_name(instance_name);
+  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  new_node->set_scope(scope);
+  node_input[0]->set_scope(scope);
+  manager->SetEdge(node, SizeToInt(index), new_node);
+}
+
+std::string CreateInstanceName(const CNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsValueNode<Primitive>(node->input(0))) {
+    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have a primitive";
+  }
+  std::string name_base = node->fullname_with_scope();
+  std::string name = name_base + "_" + std::to_string(index);
+  std::string instance_name = HashInstanceName(name);
+  return instance_name;
+}
+
+void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  // step1: get the graph manager and distribute_operator
+  FuncGraphPtr func_graph = node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto uses_set = manager->node_users()[node];
+  CNodePtr node_to_insert = node;
+  for (auto &uses_pair : uses_set) {
+    auto uses_cnode = uses_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(uses_cnode);
+    if (!IsValueNode<Primitive>(uses_cnode->input(0))) {
+      break;
+    }
+    PrimitivePtr value_node_prim = GetValueNode<PrimitivePtr>(uses_cnode->input(0));
+    MS_EXCEPTION_IF_NULL(value_node_prim);
+    if (value_node_prim->name() == TUPLE_GETITEM) {
+      if (uses_set.size() > 1) {
+        MS_LOG(EXCEPTION) << "Now, only one output is supported, but got " << uses_set.size();
+      }
+      node_to_insert = uses_cnode;
+    }
+  }
+  MS_EXCEPTION_IF_NULL(node_to_insert);
+  std::reverse(forward_op.begin(), forward_op.end());
+
+  // step2: traverse op_list and insert nodes
+  for (size_t index = 0; index < forward_op.size(); ++index) {
+    std::string instance_name_base = FORWARD_OP;
+    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
+    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
+    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // use NewCNode to create the anfnode
+    MS_EXCEPTION_IF_NULL(forward_node);
+    ScopePtr scope = node->scope();
+    MS_EXCEPTION_IF_NULL(scope);
+    forward_node->set_scope(scope);
+    forward_node->set_in_forward_flag(true);
+    forward_input[0]->set_scope(scope);
+    (void)manager->Replace(node_to_insert, forward_node);  // use the Replace function to insert the node
+  }
+}
+
+CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(prev);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
+  for (uint32_t i = 0; i < num; i++) {
+    std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev,
+                                                  CreatInt32Imm(UintToInt(i))};
+    auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs);
+    MS_EXCEPTION_IF_NULL(tuple_get_item);
+    make_tuple_inputs.push_back(tuple_get_item);
+  }
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  MS_EXCEPTION_IF_NULL(make_tuple);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  (void)manager->Replace(prev, make_tuple);
+  return make_tuple;
+}
+
+void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
+                          const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(pre_node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) {
+    MS_LOG(EXCEPTION) << "The sizes of OperatorVector and OutPutInfoVector must be the same!";
+  }
+  for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) {
+    if (pos >= SizeToInt(node->inputs().size())) {
+      MS_LOG(EXCEPTION) << "InsertRedistribution: pos can't be larger than the size of node's inputs";
+    }
+    // Create the new node
+    AnfNodePtr target_node = node->input(IntToSize(pos));
+    MS_EXCEPTION_IF_NULL(target_node);
+    // Create the instance_name
+    auto op = (redistribution_oplist_ptr->first)[index];
+    std::string op_name = (redistribution_oplist_ptr->first)[index].first;
+    std::string instance_name_base = REDISTRIBUTION_OP;
+    std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name;
+    InsertNode(op, node, IntToSize(pos), target_node, func_graph, instance_name);
+    if ((redistribution_oplist_ptr->second)[index].first) {
+      target_node = node->input(IntToSize(pos));
+      MS_EXCEPTION_IF_NULL(target_node);
+      (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph);
+    }
+  }
+}
+
+void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos,
+                            const std::string &instance_name) {
+  if (func_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name;
+  }
+
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  if (pos >= SizeToInt(node->inputs().size())) {
+    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than the size of node's inputs, the instance "
+                      << "name is " << instance_name;
+  }
+  // Create the new node
+  AnfNodePtr pre_node = node->input(IntToSize(pos));
+  MS_EXCEPTION_IF_NULL(pre_node);
+  InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name);
+}
+
+TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim,
+                               const OperatorInfoPtr &distribute_operator) {
+  TensorInfo tensorinfo_in;
+  if (middle_prim->name() == TUPLE_GETITEM) {
+    auto value_node = middle_node->input(2)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    size_t index_s = IntToSize(GetValue<int>(value_node->value()));
+    if (index_s >= distribute_operator->outputs_tensor_info().size()) {
+      MS_LOG(EXCEPTION) << "The index is out of range, index: " << index_s
+                        << ", vector size: " << distribute_operator->outputs_tensor_info().size();
+    }
+    tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s];
+  } else {
+    if (distribute_operator->outputs_tensor_info().empty()) {
+      MS_LOG(EXCEPTION) << "The outputs tensor info is empty";
+    }
+    tensorinfo_in = distribute_operator->outputs_tensor_info()[0];
+  }
+  return tensorinfo_in.tensor_layout();
+}
+
+OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsParallelCareNode(node)) {
+    return nullptr;
+  }
+  OperatorInfoPtr distribute_operator = node->operator_info();
+  if (distribute_operator == nullptr) {
+    MS_LOG(EXCEPTION) << "GetDistributeOperator: distribute_operator is nullptr";
+  }
+  return distribute_operator;
+}
+
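+// Redistribution sketch (illustrative): when the producer's selected layout and the consumer's
+// expected layout disagree, e.g. row-sharded versus column-sharded, the inferred operator list
+// re-materializes the tensor accordingly (for instance an AllGather followed by a local slice),
+// and those ops are spliced in front of the consumer's input by InsertRedistribution().
+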
+void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorInfoPtr &distribute_operator,
+                    const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution,
+                    const CNodePtr &pre_node) {
+  FuncGraphPtr func_graph = middle_node->func_graph();
+  if (func_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "Redistribution: get graph failed";
+  }
+  CNodePtr next_node = node_pair.first->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(next_node);
+  auto middle_value = middle_node->input(0)->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(middle_value);
+  PrimitivePtr middle_prim = middle_value->value()->cast<PrimitivePtr>();
+  MS_EXCEPTION_IF_NULL(middle_prim);
+  OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node);
+  if (next_distribute_operator == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed";
+  }
+  RankList dev_list = distribute_operator->global_device_list();
+  std::string next_prim_name = GetValueNode<PrimitivePtr>(next_node->input(0))->name();
+  MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name;
+  MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
+                << next_node->ToString();
+  // extract the tensor layouts in and out
+  if (distribute_operator->outputs_tensor_info().empty()) {
+    MS_LOG(EXCEPTION) << "Failure: pre_node's tensorinfo_in is empty";
+  }
+
+  if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) {
+    MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
+                      << next_distribute_operator->inputs_tensor_info().size();
+  }
+  TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)];
+  TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
+  TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator);
+  if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) {
+    MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim: " << next_prim_name;
+    MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node "
+                  << next_node->ToString();
+    DumpGraph(func_graph, "redistribution_error");
+    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed";
+  }
+  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
+  if (redistribution_oplist_ptr == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: InferTensorRedistribution failed";
+  }
+  MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size();
+  if (!redistribution_oplist_ptr->first.empty()) {
+    // insert the nodes before the next node
+    InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node);
+  }
+}
+
+bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs) {
+  auto iter = attrs.find(STRATEGY);
+  return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
+}
+
+bool HasStrategy(const FuncGraphPtr &root) {
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+
+  for (auto &node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    auto attrs = prim->attrs();
+    if (StrategyFound(attrs)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool IsCommunicationOp(const PrimitivePtr &prim) {
+  MS_EXCEPTION_IF_NULL(prim);
+  return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end());
+}
+
+bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (!IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_value_node = cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_value_node);
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_value_node);
+    MS_EXCEPTION_IF_NULL(prim);
+
+    if (IsCommunicationOp(prim) && cnode->in_forward_flag()) {
+      MS_EXCEPTION_IF_NULL(prim_value_node->scope());
+      MS_LOG(INFO) << "The graph contains the communication op: " << prim->name() << ", scope name is "
+                   << prim_value_node->scope()->name();
+      return true;
+    }
+  }
+  return false;
+}
+
+bool IsParallelCareNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
+  if (prim_node == nullptr) {
+    return false;
+  }
+  PrimitivePtr prim = prim_node->value()->cast<PrimitivePtr>();
+  if (prim == nullptr) {
+    return false;
+  }
+  if (IsInBlackList(prim)) {
+    MS_LOG(INFO) << "Parallel doesn't care about node: " << prim->name();
+    return false;
+  }
+  // get_next is not in the forward graph, so we need to mark get_next as a forward node
+  if (prim->name() == GET_NEXT) {
+    return true;
+  }
+  if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
+    return false;
+  }
+
+  return cnode->in_forward_flag();
+}
+
+void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
+                        const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) {
+  MS_EXCEPTION_IF_NULL(node->func_graph());
+  FuncGraphManagerPtr manager = node->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[node];
+  CNodePtr insert_node_new;
+  if (IsValueNode<Primitive>(node->input(0))) {
+    auto current_value = node->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(current_value);
+    PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(current_prim);
+    insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node);
+  } else {
+    insert_node_new = insert_node;
+  }
+  MS_EXCEPTION_IF_NULL(insert_node_new);
+  for (auto &node_pair : node_set) {
+    CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(use_cnode);
+    if (!IsValueNode<Primitive>(use_cnode->input(0))) {
+      StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
+    } else {
+      ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(prim_anf_node);
+      PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+      MS_EXCEPTION_IF_NULL(node_prim);
+      if (node_prim->name() == DEPEND && node_pair.second != 1) {
+        continue;
+      }
+      if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) {
+        Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
+                       pre_node);
+      } else {
+        StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node);
+      }
+    }
+  }
+}
+
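+// What splitting means here (illustrative): a constant tensor of shape [8, 32] consumed under
+// a layout that shards dimension 0 across 4 devices is wrapped in a _GetTensorSlice op, so
+// each rank materializes only its own [2, 32] slice instead of the full tensor.
+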
+void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(next_node);
+  OperatorInfoPtr op_info = next_node->operator_info();
+  MS_EXCEPTION_IF_NULL(op_info);
+
+  // If the shape of the tensor is [] or [1], there is no need to split it.
+  Shapes shapes = GetNodeShape(node);
+  if (shapes.size() != 1) {
+    MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name()
+                      << ": GetNodeShape for tensor_node, output size is not 1";
+  }
+  Shape shape = shapes[0];
+  std::string shape_str = ShapeToString(shape);
+  if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) {
+    MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str
+                 << ", no need to split it.";
+    return;
+  }
+
+  MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of the tensor is " << shape_str;
+
+  // extract the tensor layout
+  if (IntToSize(index - 1) >= op_info->inputs_tensor_info().size()) {
+    MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is "
+                      << op_info->inputs_tensor_info().size();
+  }
+  TensorInfo tensor_info = op_info->inputs_tensor_info()[IntToSize(index - 1)];
+  TensorLayout tensor_layout = tensor_info.tensor_layout();
+
+  // Use the _GetTensorSlice operator to split the tensor
+  FuncGraphPtr func_graph = next_node->func_graph();  // only a cnode can get the graph
+  MS_EXCEPTION_IF_NULL(func_graph);
+  Operator op = CreateGetTensorSliceOp(tensor_layout);
+  InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
+  if (!op_info->sub_ops().empty()) {
+    auto sub_ops = op_info->sub_ops();
+    for (size_t i = 0; i < sub_ops.size(); i++) {
+      if (!sub_ops.at(i).empty()) {
+        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
+      }
+    }
+  }
+}
+
+void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[node];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_cnode = node_pair.first->cast<CNodePtr>();
+    if (use_cnode == nullptr || !IsValueNode<Primitive>(use_cnode->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(use_cnode_prim);
+    if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(use_cnode)) {
+      SplitTensor(node, use_cnode, node_pair.second);
+    }
+  }
+}
+
position = param.second; + (void)replace_input.insert(replace_input.begin() + position, val); + } + } + + return replace_input; +} + +void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + std::string instance_name = CreateInstanceName(node, 0); + std::vector replace_input; + replace_input = ReplaceOpInput(replace_op, instance_name, node); + CNodePtr replace_node = func_graph->NewCNode(replace_input); + MS_EXCEPTION_IF_NULL(replace_node); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + replace_node->set_scope(scope); + replace_node->set_in_forward_flag(true); + replace_input[0]->set_scope(scope); + (void)manager->Replace(node, replace_node); +} + +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { + // step1:get graph manager distribute_operator + OperatorInfoPtr distribute_operator = node->operator_info(); + if (distribute_operator == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; + } + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + // step2:traverse op_list and insert node + std::reverse(replace_op.begin(), replace_op.end()); + auto replace_op_info = distribute_operator->replace_op_info(); + std::reverse(replace_op_info.begin(), replace_op_info.end()); + if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) { + MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!"; + } + bool replace_op_info_flag = !replace_op_info.empty(); + for (size_t index = 0; index < replace_op.size(); ++index) { + std::string instance_name = CreateInstanceName(node, index); + std::vector replace_input; + if (index != replace_op.size() - 1) { + replace_input = CreateInput(replace_op[index], node, instance_name); + } else { + replace_input = ReplaceOpInput(replace_op[index], instance_name, node); + } + CNodePtr replace_node = func_graph->NewCNode(replace_input); + MS_EXCEPTION_IF_NULL(replace_node); + ScopePtr scope = node->scope(); + MS_EXCEPTION_IF_NULL(scope); + replace_node->set_scope(scope); + if (index == replace_op.size() - 1) { + (void)replace_node->set_operator_info(node->operator_info()); + } + replace_node->set_in_forward_flag(true); + replace_input[0]->set_scope(scope); + if (replace_op_info_flag && replace_op_info[index].first) { + auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); + (void)manager->Replace(node, new_cnode); // using Replace function to insert node + } else { + (void)manager->Replace(node, replace_node); // using Replace function to insert node + } + } + MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); +} + +bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { + ValueNodePtr anf_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(anf_node); + PrimitivePtr prim = anf_node->value()->cast(); + return (prim->name() == name); +} + +void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(replace_graph); + MS_EXCEPTION_IF_NULL(node); + 
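+  // replace_graph is a pair: .first holds (replacement-input-cnode, original-input-index) pairs
+  // whose edges are rewired below, and .second is the new output node that takes this cnode's
+  // place in the graph.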
MS_EXCEPTION_IF_NULL(replace_graph->second); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + if (manager == nullptr) { + MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; + } + for (auto &replace_input : replace_graph->first) { + auto pre_node = node->input(IntToSize(replace_input.second)); + manager->SetEdge(replace_input.first, 1, pre_node); + } + // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called + auto replace_output = replace_graph->second; + MS_EXCEPTION_IF_NULL(replace_output); + (void)manager->Replace(node, replace_output); +} + +int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->inputs().size() != 3) { + MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; + } + + if (!cnode->input(2)->isa()) { + MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node"; + } + + ValuePtr tuple_index_value = GetValueNode(cnode->input(2)); + MS_EXCEPTION_IF_NULL(tuple_index_value); + if (!tuple_index_value->isa()) { + MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32"; + } + return tuple_index_value->cast()->value(); +} + +// Judge whether the node is a loss, and if there are multiple outputs, +// get which output is a grad according to the tuple getitem. +// Currently, it is not supported that the sens is a tuple. +LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { + MS_EXCEPTION_IF_NULL(loss_node); + FuncGraphPtr sub_graph = loss_node->func_graph(); + MS_EXCEPTION_IF_NULL(sub_graph); + CNodePtr return_node = sub_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; + } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); + + LossNodeInfo node_info; + + // return -> cast + auto pre_cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + auto pre_prim = GetValueNode(pre_cnode->input(0)); + if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_node = pre_cnode->input(1); + } + + // return -> loss + if (pre_node == loss_node) { + node_info.has_tuple_getitem = false; + node_info.dout_index = 0; + return node_info; + } + + // return -> tuple_getitem -> loss + auto cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto current_value = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(current_value); + PrimitivePtr current_prim = current_value->value()->cast(); + MS_EXCEPTION_IF_NULL(current_prim); + // size of common cnode is larger than 1 + if (cnode->inputs().size() < 2) { + MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2"; + } + + if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) { + // size of tuple_getitem cnode is 3 + auto tuple_index = GetTupleGetItemIndex(cnode); + node_info.has_tuple_getitem = true; + node_info.dout_index = tuple_index; + return node_info; + } + + MS_LOG(EXCEPTION) << "Invalid loss"; +} + +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + size_t node_size = node->inputs().size(); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (size_t index 
= 1; index < node_size; ++index) { + AnfNodePtr input = node->input(index); + MS_EXCEPTION_IF_NULL(input); + if (!input->isa() && !input->isa()) { // if it is not a tensor, continue + MS_LOG(INFO) << "insert div op: the index " << index << " is not tensor, skip"; + continue; + } + + for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) { + std::string instance_name = CreateInstanceName(node, pos); + InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name); + } + MS_LOG(INFO) << "insert div op for input index " << index << " of node"; + } +} + +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { + if (!node->isa() && !node->isa() && !node->isa()) { + return std::make_pair(nullptr, false); + } else if (node->isa()) { + return std::make_pair(node, false); + } else if (node->isa()) { + if (IsValueNode(node)) { + std::vector param_v = FindParameterByRefKeyNode(node, func_graph); + if (param_v.size() != 1) { + MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " + << param_v.size(); + } + return std::make_pair(node, true); + } + return std::make_pair(nullptr, false); + } else { + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + if (!FindParameter(cnode->input(index), func_graph).first) { + continue; + } + return FindParameter(cnode->input(index), func_graph); + } + } else { + if (IsParallelCareNode(cnode)) { + return std::make_pair(nullptr, false); + } else { + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + for (size_t index = 0; index < cnode->inputs().size(); ++index) { + PrimitivePtr prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == DEPEND && index != 1) { + continue; + } + if (!FindParameter(cnode->input(index), func_graph).first) { + continue; + } + return FindParameter(cnode->input(index), func_graph); + } + } + } + } + return std::make_pair(nullptr, false); +} + +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(anode); + MS_EXCEPTION_IF_NULL(anode->func_graph()); + FuncGraphManagerPtr manager = anode->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[anode]; + bool result = false; + CNodePtr cnode_return = nullptr; + for (auto &node_pair : node_set) { + CNodePtr use_apply = node_pair.first->cast(); + if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + if (node_prim->name() == name && node_pair.second == 1) { + if (use_apply->func_graph() == func_graph) { + result = true; + cnode_return = use_apply; + MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph"; + continue; + } + MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph"; + } + } + return std::make_pair(result, cnode_return); +} + +bool IsCastBeforMirror(const CNodePtr &node, size_t index) { + // only if cast_before_mirror is true, pre node is cast and type is not float32 return true + if (!ParallelContext::GetInstance()->cast_before_mirror()) { + return false; + } + auto pre_node = node->input(index); + 
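+  // Only a Cast cnode whose element type is not float32 qualifies; every other input falls
+  // through the checks below and returns false.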
MS_EXCEPTION_IF_NULL(pre_node); + auto cnode = pre_node->cast(); + if (cnode == nullptr || !IsValueNode(cnode->input(0))) { + return false; + } + auto pre_value_node = cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(pre_value_node); + auto pre_prim = pre_value_node->value()->cast(); + MS_EXCEPTION_IF_NULL(pre_prim); + if (pre_prim->name() != CAST) { + return false; + } + auto node_type = pre_node->Type(); + MS_EXCEPTION_IF_NULL(node_type); + if (!node_type->isa()) { + MS_LOG(EXCEPTION) << "Unknown type."; + } + auto input_element_type = node_type->cast()->element(); + MS_EXCEPTION_IF_NULL(input_element_type); + auto type_id = input_element_type->type_id(); + + return (type_id != kNumberTypeFloat32); +} + +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + size_t node_size = node->inputs().size(); + FuncGraphPtr func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + if (mirror_ops.size() != node_size - 1) { + MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() + << ", node_size is " << node_size; + } + for (size_t index = 1; index < node_size; ++index) { + OperatorVector backward_op = mirror_ops[index - 1]; + if (backward_op.empty()) { + continue; + } + std::pair param_node_pair = FindParameter(node->input(index), func_graph); + if (!param_node_pair.first) { + continue; + } + // not a RefKey + if (!param_node_pair.second) { + auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); + // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead + if (next_cnode.first) { + MS_EXCEPTION_IF_NULL(next_cnode.second); + manager->SetEdge(node, SizeToInt(index), next_cnode.second); + continue; + } + } + // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp + // only one MirrorOp in backward_op + if (backward_op.size() != 1) { + MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size(); + } + std::string instance_name = MIRROR_OP; + if (IsCastBeforMirror(node, index)) { + for (auto &op : backward_op) { + // insert new node before the node + CNodePtr cnode = node->input(index)->cast(); + MS_EXCEPTION_IF_NULL(cnode); + AnfNodePtr pre_node = cnode->input(1); + InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); + } + } else { + for (auto &op : backward_op) { + AnfNodePtr pre_node = node->input(index); + InsertNode(op, node, index, pre_node, func_graph, instance_name); + } + } + } +} + +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, + const std::vector> &sens_loss_pairs) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(node); + + bool is_loss_cnode = + std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), + [node](const std::pair &element) { return element.second == node; }); + + MirrorOps mirror_ops = distribute_operator->mirror_ops(); + VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); + // insert mirror op + if (!mirror_ops.empty()) { + MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); + InsertMirrorOps(mirror_ops, node); + } + // insert virtual div op + if (!virtual_div_op.empty() && is_loss_cnode) { + MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name(); + InsertVirtualDivOp(virtual_div_op, node); + } +} + +std::string 
GetDisOpName(const std::string &prim_name) {
+  std::string op_name = prim_name;
+  if (!prim_name.empty() && (prim_name[0] == '_')) {
+    op_name = prim_name.substr(1);
+  }
+  return op_name + "Info";
+}
+
+OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs,
+                                       const std::vector<Shapes> &shape_list) {
+  if (shape_list.size() != 2) {
+    MS_LOG(ERROR) << "The size of shape list is not 2";
+    return nullptr;
+  }
+  if (name.length() == 0) {
+    MS_LOG(EXCEPTION) << "Length of name is zero!";
+  }
+  std::string distribute_opname = GetDisOpName(name);
+  if (name == GATHERV2) {
+    distribute_opname = name + "PInfo";
+    auto data_parallel_iter = attrs.find(DATA_PARALLEL);
+    if (data_parallel_iter != attrs.end()) {
+      MS_EXCEPTION_IF_NULL(data_parallel_iter->second);
+      if (!data_parallel_iter->second->isa<BoolImm>()) {
+        MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool.";
+      }
+      bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value();
+      if (data_parallel) {
+        distribute_opname = name + "Info";
+      }
+    }
+  }
+  OperatorInfoPtr operator_ =
+    (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
+  if (operator_ == nullptr) {
+    MS_LOG(INFO) << "Create " << name << " failed";
+    return nullptr;
+  }
+  std::string origin_name = operator_->name();
+  operator_->set_name(origin_name + std::to_string(TOTAL_OPS));
+  MS_LOG(INFO) << "Successfully created operator " << origin_name;
+  ++TOTAL_OPS;
+  return operator_;
+}
+
+OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
+                                 const std::vector<Shapes> &shape_list) {
+  MS_EXCEPTION_IF_NULL(prim);
+  OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
+  if (operator_ == nullptr) {
+    MS_LOG(INFO) << "Create " << prim->name() << " failed, using batch parallel";
+    operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
+    MS_EXCEPTION_IF_NULL(operator_);
+  }
+  return operator_;
+}
+
+OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
+                                    std::vector<Shapes> shape_list) {
+  OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
+  for (size_t i = 0; i < shape_list[0].size(); ++i) {
+    MS_LOG(INFO) << "No: " << i << " input's shape: " << ShapeToString(shape_list[0][i]);
+  }
+  return operator_;
+}
+
+StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
+  ValueTuplePtr var = attrs[STRATEGY]->cast<ValueTuplePtr>();
+  StrategyPtr strategyPtr;
+  MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString();
+  if (var == nullptr) {
+    MS_LOG(EXCEPTION) << "Strategy value is nullptr";
+  }
+  if (var->size() > 0) {
+    std::vector<ValuePtr> elements = var->value();
+    std::vector<Dimensions> strategy;
+    for (uint32_t index = 0; index < elements.size(); ++index) {
+      Dimensions dim;
+      if (elements[index]->isa<ValueSequeue>()) {
+        ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
+        std::vector<ValuePtr> value_vector = value_tuple->value();
+        (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
+                             [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
+        strategy.push_back(dim);
+      } else {
+        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequeue";
+      }
+    }
+    if (strategy.empty()) {
+      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
+    }
+    strategyPtr = NewStrategy(0, strategy);
+  }
+
+  return strategyPtr;
+}
+
+Shapes GetNodeShape(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  Shapes shapes;
+  BaseShapePtr base_shape_ptr = node->Shape();
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    if (IsValueNode<Primitive>(cnode->input(0))) {
+      PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+      MS_EXCEPTION_IF_NULL(prim);
+      if (prim->name() == MAKEREF) {
+        AnfNodePtr ref_node = cnode->input(1);
+        auto func_graph = cnode->func_graph();
+        MS_EXCEPTION_IF_NULL(ref_node);
+        MS_EXCEPTION_IF_NULL(func_graph);
+        return GetRefKeyNodeShape(ref_node, func_graph);
+      }
+    }
+    if (cnode->input(0)->isa<CNode>()) {
+      if (cnode->inputs().size() < 2) {
+        MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
+      }
+      base_shape_ptr = cnode->input(1)->Shape();
+    }
+  }
+  if (base_shape_ptr == nullptr) {
+    MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is "
+                      << node->fullname_with_scope();
+  }
+  auto tuple_shape_ptr = dyn_cast<abstract::SequeueShape>(base_shape_ptr);
+  if (tuple_shape_ptr != nullptr) {
+    auto tuple_shape = tuple_shape_ptr->shape();
+    for (auto &shape : tuple_shape) {
+      auto each_shape = dyn_cast<abstract::Shape>(shape);
+      MS_EXCEPTION_IF_NULL(each_shape);
+      shapes.push_back(each_shape->shape());
+    }
+  } else {
+    auto shape_ptr = dyn_cast<abstract::Shape>(base_shape_ptr);
+    MS_EXCEPTION_IF_NULL(shape_ptr);
+    shapes.push_back(shape_ptr->shape());
+  }
+  return shapes;
+}
+
+std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> parameters;
+  if (!IsValueNode<RefKey>(node)) {
+    MS_LOG(ERROR) << "The node is not a ref key";
+    return parameters;
+  }
+
+  auto ref_key = GetValueNode<RefKeyPtr>(node);
+  MS_EXCEPTION_IF_NULL(ref_key);
+  auto name = ref_key->tag();
+
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto roots = manager->roots();
+  if (roots.size() != 1) {
+    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
+    return parameters;
+  }
+
+  FuncGraphPtr root_g = roots.back();
+  MS_EXCEPTION_IF_NULL(root_g);
+  for (auto &param_node : root_g->parameters()) {
+    auto param = param_node->cast<ParameterPtr>();
+    if (param && (name == param->name())) {
+      parameters.push_back(param_node);
+      MS_LOG(INFO) << "The name of the ref key is: " << name;
+      return parameters;
+    }
+  }
+
+  MS_LOG(ERROR) << "The name of the ref key is: " << name << ", but the parameter was not found";
+  return parameters;
+}
+
+Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+
+  std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(node, func_graph);
+  if (parameters.size() != 1) {
+    MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+  }
+
+  Shapes input_shapes;
+  input_shapes = GetNodeShape(parameters[0]);
+  if (input_shapes.size() != 1) {
+    MS_LOG(EXCEPTION) << "Get input shape failed";
+  }
+
+  MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]);
+  return input_shapes;
+}
+
+std::vector<Shapes> ExtractShape(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  Shapes shape_inputs, shape_outputs;
+  std::vector<Shapes> shape_all;
+  std::vector<AnfNodePtr> all_inputs = node->inputs();
+  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
+
+  size_t inputs_size = all_inputs.size();
+  for (size_t i = 1; i < inputs_size; ++i) {
+    Shapes input_shapes;
+    AnfNodePtr input = all_inputs[i];
+    if (IsValueNode<RefKey>(input)) {
+      auto func_graph = node->func_graph();
+      MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      std::pair<AnfNodePtr, int> node_pair = std::make_pair(node, SizeToInt(i));
+      g_RefMap[parameters[0]] = node_pair;
+      input_shapes = GetRefKeyNodeShape(input, func_graph);
+    } else if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
+      input_shapes = GetNodeShape(input);
+    } else {
+      continue;
+    }
+    if (input_shapes.size() != 1) {
+      MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
+    }
+    shape_inputs.push_back(input_shapes[0]);
+  }
+  shape_all.push_back(shape_inputs);
+  // extract out shape
+  shape_outputs = GetNodeShape(node);
+  shape_all.push_back(shape_outputs);
+  return shape_all;
+}
+
+std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  FuncGraphPtr func_graph = node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[node];
+  for (auto &node_pair : node_set) {
+    CNodePtr cnode = node_pair.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (!IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_node_anf = cnode->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_node_anf);
+    PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
+      return node_pair;
+    } else if (FindParallelCareNode(node_pair.first).first != nullptr) {
+      return FindParallelCareNode(node_pair.first);
+    }
+  }
+  return std::make_pair(nullptr, 0);
+}
+
+std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(parameter);
+  FuncGraphManagerPtr manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter);
+  if (prim_anf_node_pair.first != nullptr) {
+    return prim_anf_node_pair;
+  } else {
+    AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
+    for (auto &param_pair : param_sub_set) {
+      CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>();
+      if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) {
+        continue;
+      }
+      CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
+      if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
+        continue;
+      }
+      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
+      auto parameters = graph_sub->parameters();
+      if (IntToSize(param_pair.second - 1) >= parameters.size()) {
+        MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
+                          << parameters.size();
+      }
+      std::pair<AnfNodePtr, int> res = FindSubGraph(graph_sub, parameters[IntToSize(param_pair.second - 1)]);
+      if (res.first != nullptr) {
+        return res;
+      }
+    }
+  }
+  return std::make_pair(nullptr, 0);
+}
+
+void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, int> &res) {
+  MS_EXCEPTION_IF_NULL(parameter);
+  AbstractBasePtr abstract = parameter->abstract();
+  MS_EXCEPTION_IF_NULL(abstract);
+  MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+  CNodePtr cnode = res.first->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  OperatorInfoPtr distribute_operator = cnode->operator_info();
+  if (distribute_operator == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
+  }
+
+  if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) {
+    MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is "
+                      << distribute_operator->inputs_tensor_info().size();
+  }
+  TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
+  Shape slice_shape = tensorinfo_in.slice_shape();
+  MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
+                << MakeValue(slice_shape)->ToString();
+  std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
+  MS_EXCEPTION_IF_NULL(parallel_shape);
+  // Don't modify it in-place, as the pointer of this AbstractValue may be used as a cache key in StaticAnalysis.
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
+  TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
+  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(parameter_ptr);
+  parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout));
+}
+
+void CoverSliceShape(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  auto parameters = root->parameters();
+  for (auto &parameter : parameters) {
+    MS_EXCEPTION_IF_NULL(parameter->Shape());
+    auto iter = g_RefMap.find(parameter);
+    if (iter != g_RefMap.end()) {
+      SetParallelShape(parameter, g_RefMap[parameter]);
+      continue;
+    }
+    std::pair<AnfNodePtr, int> res = FindSubGraph(root, parameter);
+    if (res.first == nullptr) {
+      MS_LOG(INFO) << "Parameter " << parameter->ToString() << " doesn't need to set parallel shape";
+    } else {
+      SetParallelShape(parameter, res);
+      MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
+    }
+  }
+  g_RefMap.clear();
+}
+
+bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_node) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(parameter_node);
+  FuncGraphManagerPtr manager = root->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(cloned_parameter);
+
+  // find the cloned parameter
+  if (!cloned_parameter->has_default()) {
+    return false;
+  }
+
+  bool cloned = cloned_parameter->default_param()->cloned();
+  if (!cloned) {
+    return false;
+  }
+
+  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
+  return true;
+}
+
+void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  for (auto &cloned_parameter_node : root->parameters()) {
+    MS_EXCEPTION_IF_NULL(cloned_parameter_node);
+    auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(cloned_parameter);
+
+    if (!ParameterIsCloned(root, cloned_parameter_node)) {
+      continue;
+    }
+
+    // get the cloned index
+    int32_t cloned_index = cloned_parameter->default_param()->cloned_index();
+
+    // find the parameter it was cloned from
+    bool found_be_cloned_parameter = false;
+    ParameterPtr cloned_from_parameter = nullptr;
+    AnfNodePtr cloned_from_node = nullptr;
+    for (auto &be_cloned_parameter_node : root->parameters()) {
+      MS_EXCEPTION_IF_NULL(be_cloned_parameter_node);
+      auto be_cloned_parameter = be_cloned_parameter_node->cast<ParameterPtr>();
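+      // A candidate must have a default value and be marked be_cloned; it matches when its
+      // be_cloned_index list contains the cloned_index recorded on the cloned parameter.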
MS_EXCEPTION_IF_NULL(be_cloned_parameter); + if (!be_cloned_parameter->has_default()) { + continue; + } + + const auto ¶m_value_cloned = be_cloned_parameter->default_param(); + if (!param_value_cloned->be_cloned()) { + continue; + } + + // get the be cloned index + auto &be_cloned_index = param_value_cloned->be_cloned_index(); + if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) { + found_be_cloned_parameter = true; + cloned_from_parameter = be_cloned_parameter; + cloned_from_node = be_cloned_parameter_node; + } + } + + if (found_be_cloned_parameter) { + // set the shape and tensor layout for cloned parameter + cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); + MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); + MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); + auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); + MS_EXCEPTION_IF_NULL(cloned_abstract); + cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); + cloned_parameter_node->set_abstract(cloned_abstract); + MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() + << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() + << ", clone index is: " << cloned_index; + } else { + MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is " + << cloned_index << ", but not found the be cloned parameter"; + } + } + std::string env = common::GetEnv("SLICE_ENV"); + if (!env.empty()) { + MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; + } +} + +void SetVirtualDatasetStrategy(const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + bool full_batch = ParallelContext::GetInstance()->full_batch(); + + PrimitivePtr prim = GetValueNode(node->input(0)); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == VIRTUAL_DATA_SET) { + CheckGlobalDeviceManager(); + int32_t dev_num; + if (full_batch) { + dev_num = 1; + } else { + dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); + } + auto attrs_temp = prim->attrs(); + std::vector shape_list = ExtractShape(node); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; + } + std::vector elements; + for (size_t i = 0; i < shape_list[0].size(); i++) { + if (shape_list[0][i].empty()) { + MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; + } + std::vector input_strategy = {dev_num}; + for (size_t j = 1; j < shape_list[0][i].size(); j++) { + input_strategy.push_back(1); + } + elements.push_back(MakeValue(input_strategy)); + } + ValueTuplePtr strategy = std::make_shared(elements); + attrs_temp[STRATEGY] = strategy; + (void)prim->SetAttrs(attrs_temp); + } +} + +void ExtractInformation(const std::vector &all_nodes) { + // load strategy map from checkpoint + StrategyMap stra_map; + if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { + if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; + } + } + for (auto &node : all_nodes) { + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + SetVirtualDatasetStrategy(cnode); + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); + PrimitivePtr prim = GetValueNode(prim_anf_node); + auto attrs = prim->attrs(); + MS_LOG(INFO) << "extract information: node: " << node->ToString() << " 
prim " << prim->name(); + if (IsParallelCareNode(cnode)) { + std::vector shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; + } + OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); + if (operator_ == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; + } + auto &inputs = cnode->inputs(); + std::vector input_value; + for (size_t index = 1; index < inputs.size(); ++index) { + if (inputs[index]->isa()) { + input_value.push_back(GetValueNode(inputs[index])); + } else { + input_value.emplace_back(nullptr); + } + } + StrategyPtr strategyPtr = nullptr; + (*operator_).set_input_value(input_value); + (*operator_).set_outputs_dtype(cnode->Type()); + (*operator_).set_cnode(cnode); + if (prim->name() == RESHAPE) { + (void)cnode->set_operator_info(operator_); + continue; + } + // load strategy checkpoint + // key of strategy map + std::string strategy_key_name = NodeParameterName(cnode); + bool load_strategy_from_ckpt = + StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); + if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { + MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() + << " is empty, using batch parallel"; + std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies(); + if (strategy_v_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; + } + std::vector elements; + for (size_t i = 0; i < strategy_v_ptr->size(); i++) { + elements.push_back(MakeValue((*strategy_v_ptr)[i])); + } + ValueTuplePtr strategy = std::make_shared(elements); + // display the strategy generated by batch parallel + attrs[GEN_STRATEGY] = strategy; + (void)prim->SetAttrs(attrs); + MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " + << attrs[GEN_STRATEGY]->ToString(); + strategyPtr = NewStrategy(0, *strategy_v_ptr); + } else if (load_strategy_from_ckpt) { + strategyPtr = stra_map[strategy_key_name]; + } else { + strategyPtr = ExtractStrategy(attrs); + } + if (strategyPtr != nullptr) { + if (operator_->Init(strategyPtr) == FAILED) { + MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; + } + (void)cnode->set_operator_info(operator_); + } else { + MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; + } + } + } +} + +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { + CNodePtr cnode = node_pair.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + MS_EXCEPTION_IF_NULL(distribute_operator); + int index = node_pair.second; + if (index > SizeToInt(distribute_operator->inputs_tensor_info().size())) { + MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << index - 1 << ", the vector size is " + << distribute_operator->inputs_tensor_info().size(); + } + TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; + TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout(); + return tensorlayout_in; +} + +// if reshape's output connect to several primitive, return the first layout found +std::shared_ptr FindNextLayout(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + FuncGraphManagerPtr manager = cnode->func_graph()->manager(); + 
MS_EXCEPTION_IF_NULL(manager); + AnfNodeIndexSet node_set = manager->node_users()[cnode]; + for (auto &node_pair : node_set) { + CNodePtr use_apply = node_pair.first->cast(); + if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { + continue; + } + ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr node_prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(node_prim); + MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); + if (node_prim->name() == DEPEND && node_pair.second != 1) { + continue; + } + if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); + auto layout = GetInputLayoutFromCNode(node_pair); + return std::make_shared(layout); + } + MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) + << " " << (use_apply->operator_info() != nullptr); + + auto layout_ptr = FindNextLayout(use_apply); + if (layout_ptr) { + return layout_ptr; + } + } + MS_LOG(WARNING) << "FindNextLayout return nullptr, if reshape is not the last primitive, there must be some error"; + return nullptr; +} + +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); + MS_EXCEPTION_IF_NULL(distribute_operator); + if (distribute_operator->outputs_tensor_info().size() < output_index) { + MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size() + << ", must be less than output_index " << output_index; + } + TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index]; + TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); + return std::make_shared(tensorlayout_out); +} + +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { + if (!node->isa()) { + return nullptr; + } + CNodePtr cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + return nullptr; + } + if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); + if (!layout_ptr) { + MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; + } + return layout_ptr; + } + return nullptr; +} + +std::shared_ptr CreateParameterLayout(const AnfNodePtr &node) { + // Create DataParallel tensor layout for parameter(support WideDeep). 
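+  // Shard only dimension 0 across all devices: device matrix [dev_num] and tensor map
+  // [0, -1, ..., -1]; e.g. 8 devices and a [1024, 64] parameter give a [128, 64] slice.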
+  CheckGlobalDeviceManager();
+  int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size());
+  TensorLayout input_tensor_layout;
+  // create input_shape
+  Shapes inputs_shape = GetNodeShape(node);
+  Shape input_shape_array = inputs_shape[0];
+  if (input_shape_array.empty()) {
+    MS_LOG(EXCEPTION) << "Reshaping a scalar parameter is not supported.";
+  }
+  // create tensor_map
+  size_t shape_size = input_shape_array.size();
+  TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1);
+  input_tensor_map_array.insert(input_tensor_map_array.begin(), 0);
+  // create dev_matrix
+  Shape dev_matrix_array = {dev_num};
+  if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed.";
+  }
+  return std::make_shared<TensorLayout>(input_tensor_layout);
+}
+
+std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
+  if (node->isa<Parameter>()) {
+    return CreateParameterLayout(node);
+  }
+  if (!node->isa<CNode>()) {
+    return nullptr;
+  }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    return nullptr;
+  }
+  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
+    auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
+    if (!layout_ptr) {
+      MS_LOG(EXCEPTION) << "Failure: GetLayoutFromCNode failed";
+    }
+    return layout_ptr;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+  if (prim->name() == TUPLE_GETITEM) {
+    auto tuple_index = GetTupleGetItemIndex(cnode);
+    auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), IntToSize(tuple_index));
+    if (!layout_ptr) {
+      MS_LOG(EXCEPTION)
+        << " Failure: FindPrevLayout failed, tuple_getitem before reshape, but there does not exist a parallel care "
+           "node before tuple_getitem!";
+    }
+    return layout_ptr;
+  }
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    auto layout_ptr = FindPrevLayout(cnode->inputs()[index]);
+    if (!layout_ptr) {
+      continue;
+    }
+    return layout_ptr;
+  }
+  MS_LOG(WARNING) << "FindPrevLayout returned nullptr; if reshape is not the first primitive, there must be some error";
+  return nullptr;
+}
+
+void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto &node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    OperatorInfoPtr operator_info = cnode->operator_info();
+    if (operator_info == nullptr) {
+      MS_LOG(EXCEPTION) << "Failure: Primitive " << prim->ToString() << " OperatorInstance is nullptr";
+    }
+    if (prim->name() != RESHAPE) {
+      continue;
+    }
+    auto attrs = prim->attrs();
+    if (StrategyFound(attrs)) {
+      MS_LOG(EXCEPTION) << "Setting a strategy for Reshape has no effect!";
+    }
+    MS_ASSERT(cnode->inputs().size() == 3);
+    auto prev_layout_ptr = FindPrevLayout(cnode->input(1));
+    if (prev_layout_ptr) {
+      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
+    }
+    auto next_layout_ptr = FindNextLayout(cnode);
+    if (next_layout_ptr) {
+      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
+    }
+    if
(operator_info->Init(nullptr) == FAILED) { + MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed"; + } + } +} + +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; + } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); + + auto pre_cnode = pre_node->cast(); + if (pre_cnode == nullptr) { + return nullptr; + } + + auto current_prim = GetValueNode(pre_cnode->input(0)); + // return -> cast + if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_cnode = pre_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + current_prim = GetValueNode(pre_cnode->input(0)); + } + + // notice: the GetNext op has not input + if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(INFO) << "The loss is: " << current_prim->name(); + return pre_cnode; + } + + // size of common cnode is larger than 1 + if (pre_cnode->size() < 2) { + MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; + } + + // return -> tuple_getitem -> loss + if (current_prim->name() == TUPLE_GETITEM) { + AnfNodePtr pre_pre_node = pre_cnode->input(1); + MS_EXCEPTION_IF_NULL(pre_pre_node); + + auto pre_pre_cnode = pre_pre_node->cast(); + auto value = pre_pre_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(value); + PrimitivePtr prim = value->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(DEBUG) << "The loss name is " << prim->name(); + return pre_pre_cnode; + } + + // return -> make_tuple + if (current_prim->name() == MAKE_TUPLE) { + MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; + } + + // return -> loss + MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); + return pre_cnode; +} + +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { + TensorLayouts ret; + MS_EXCEPTION_IF_NULL(loss_cnode); + AnfNodePtr node = loss_cnode->cast(); + MS_EXCEPTION_IF_NULL(node); + + LossNodeInfo node_info = GetLossNodeInfo(node); + ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(prim_anf_node); + PrimitivePtr prim = prim_anf_node->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; + return ret; + } + + OperatorInfoPtr operator_info = loss_cnode->operator_info(); + MS_EXCEPTION_IF_NULL(operator_info); + TensorInfo loss_grad_tensor_info; + size_t op_output_size = operator_info->outputs_tensor_info().size(); + MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " + << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is " + << node_info.dout_index; + + if ((op_output_size == 0) || (op_output_size <= IntToSize(node_info.dout_index))) { + MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size; + } + + if (!node_info.has_tuple_getitem && (op_output_size > 1)) { + MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple."; + } + + loss_grad_tensor_info = operator_info->outputs_tensor_info()[IntToSize(node_info.dout_index)]; + 
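+  // Only the dout_index output takes part in sens splitting, so a single layout is returned.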
ret.push_back(loss_grad_tensor_info.tensor_layout()); + return ret; +} + +void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { + MS_EXCEPTION_IF_NULL(grad_sens_node); + if (grad_sens_node->size() <= 1) { + MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2"; + } + AnfNodePtr sens_tensor_node = grad_sens_node->input(1); + MS_EXCEPTION_IF_NULL(sens_tensor_node); + Shapes sens_shapes = GetNodeShape(sens_tensor_node); + if (sens_shapes.size() != 1) { + MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1"; + } + // If the shape of sens tensor is [] or [1], no need to split it. + Shape sens_shape = sens_shapes[0]; + if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) { + if (sens_tensor_node->isa()) { + auto sens_tensor_param = sens_tensor_node->cast(); + MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); + sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + } + MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; + return; + } + auto loss_shape = loss_grad_layout.tensor_shape().array(); + if (loss_shape != sens_shape) { + MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is " + << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape); + } + MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it."; + + if (!IsValueNode(sens_tensor_node)) { + if (sens_tensor_node->isa()) { + MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); + AbstractBasePtr abstract = sens_tensor_node->abstract(); + MS_EXCEPTION_IF_NULL(abstract); + auto slice_shape = loss_grad_layout.slice_shape().array(); + std::shared_ptr parallel_shape = std::make_shared(slice_shape); + MS_EXCEPTION_IF_NULL(parallel_shape); + auto cloned_abstract = abstract->Clone(); + MS_EXCEPTION_IF_NULL(cloned_abstract); + cloned_abstract->set_shape(parallel_shape); + sens_tensor_node->set_abstract(cloned_abstract); + auto sens_tensor_param = sens_tensor_node->cast(); + sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + return; + } + MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; + } + + // Use _GetTensorSlice operator to split the sens tensor + FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph + MS_EXCEPTION_IF_NULL(func_graph); + Operator op = CreateGetTensorSliceOp(loss_grad_layout); + InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS); +} + +void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(cnode); + OperatorVector forward_op = distribute_operator->forward_op(); + if (!forward_op.empty()) { + MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name(); + ForwardCommunication(forward_op, cnode); + } +} + +void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(distribute_operator); + MS_EXCEPTION_IF_NULL(cnode); + // StepReplaceOp + OperatorVector replace_op = distribute_operator->replace_op(); + if (!replace_op.empty()) { + MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString(); + StepReplaceOp(replace_op, cnode); + } + + // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore. 
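+  // replace_op and replace_graph are mutually exclusive rewrites of the same cnode; the check
+  // below raises if an operator supplies both.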
+  ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode);
+  if (!replace_op.empty() && replace_graph) {
+    MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used";
+  }
+  if (replace_graph) {
+    MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString();
+    StepReplaceGraph(replace_graph, cnode);
+  }
+}
+
+void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(distribute_operator);
+  MS_EXCEPTION_IF_NULL(cnode);
+
+  std::string op_name = distribute_operator->name();
+  if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) {
+    return;
+  }
+
+  DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
+  MS_EXCEPTION_IF_NULL(dropout_do_mask);
+  std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
+  if (replace_op.empty()) {
+    MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
+    return;
+  }
+  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << "The size of the dropout_do_mask cnode's inputs is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
+}
+
+void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
+  HandleDropoutNode(distribute_operator, cnode);
+}
+
+std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
+  // J->CNode->Graph
+  std::set<FuncGraphPtr> graph_set;
+  for (auto &node : root_all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
+      continue;
+    }
+    auto expect_j_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (expect_j_prim->name() != J) {
+      continue;
+    }
+    if (IsValueNode<FuncGraph>(cnode->input(1))) {
+      auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      MS_LOG(DEBUG) << "Find the forward graph success";
+      graph_set.insert(graph);
+    }
+  }
+  return graph_set;
+}
+
+void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
+  CNodePtr sens_node = sens_loss_pair.first;
+  CNodePtr loss_node = sens_loss_pair.second;
+  auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
+  if (!loss_grad_layout.empty()) {
+    SplitSens(sens_node, loss_grad_layout[0]);
+  }
+}
+
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
+  for (auto &node : root->nodes()) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)
+    auto sens_cnode = node->cast<CNodePtr>();
+    AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+    if (!expect_tuple_getitem->isa<CNode>()) {
+      continue;
+    }
+
+    auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode
+    AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(expect_anonymous);
+    if (!expect_anonymous->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+    auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+    AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_j);
+    if (!expect_j->isa<CNode>()) {
+      continue;
+    }
+    auto expect_j_cnode = expect_j->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_j_cnode, J)) {
+      continue;
+    }
+
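+    // The J cnode wraps the forward graph as its second input; fetch that graph to locate the
+    // loss cnode paired with this sens node.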
+    if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
+      MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
+    auto loss_cnode = FindLossCNode(func_graph);
+    if (loss_cnode == nullptr) {
+      MS_LOG(WARNING) << "Can not find the loss cnode";
+      continue;
+    }
+    std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
+    sens_loss_pairs.push_back(sens_loss_pair);
+  }
+  return sens_loss_pairs;
+}
+
+void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
+                           const FuncGraphManagerPtr &manager) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(manager);
+  TensorRedistribution tensor_redistribution;
+
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
+  bool has_backward = !sens_loss_pairs.empty();
+  // sens must be split before inserting the operators.
+  for (auto &pair : sens_loss_pairs) {
+    // If the shape of the grad-sens tensor is not [] or [1], use get tensor slice to handle it.
+    // If the type of the sens node is not Tensor, it is unsupported now; do nothing by default.
+    StepSplitSens(pair);
+  }
+
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        continue;
+      }
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      if (distribute_operator == nullptr) {
+        continue;
+      }
+
+      // insert forward ops
+      InsertForwardOps(distribute_operator, cnode);
+
+      // insert redistribution ops
+      StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
+
+      // insert backward ops
+      if (has_backward) {
+        BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
+      }
+
+      HandleSpecialNode(distribute_operator, cnode);
+    } else if (IsValueNode<Tensor>(node)) {
+      StepSplitTensor(node, manager);
+    }
+  }
+
+  for (auto &node : all_nodes) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      if (!IsValueNode<Primitive>(cnode->input(0))) {
+        continue;
+      }
+      OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
+      if (distribute_operator == nullptr) {
+        continue;
+      }
+      // StepReplace
+      StepReplace(distribute_operator, cnode);
+    }
+  }
+}
+
+namespace {
+void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(node);
+  auto symbolic_key = GetValueNode<SymbolicKeyInstancePtr>(node);
+  MS_EXCEPTION_IF_NULL(symbolic_key);
+  auto all_upstream_node = root->manager()->node_users()[node];
+  for (auto &upstream_node : all_upstream_node) {
+    FuncGraphPtr fg = upstream_node.first->func_graph();
+    if (symbolic_key->node()->isa<Parameter>()) {
+      for (auto &param : root->parameters()) {
+        if (*param == *symbolic_key->node()) {
+          AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param});
+          MS_EXCEPTION_IF_NULL(reverted_node);
+          MS_LOG(DEBUG) << "before replace " << node->ToString() << " to node " << reverted_node->DebugString();
+          (void)fg->manager()->Replace(node, reverted_node);
+          MS_LOG(DEBUG) << "revert node " << node->ToString() << " to node " << reverted_node->DebugString();
+        }
+      }
+    }
+  }
+}
+}  // namespace
+
+void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
+  MS_EXCEPTION_IF_NULL(root);
+  for (auto &node : all_nodes) {
+    // revert SymbolicKeyInstance back to the embed() primitive
+    if (IsValueNode<SymbolicKeyInstance>(node)) {
+      RevertSymbolicKeyInstance(root, node);
+      continue;
+    }
+  }
+}
+
+std::string NodeParameterName(const CNodePtr &node) {
+  std::vector<AnfNodePtr>
node_inputs{node->inputs()}; + for (auto input : node_inputs) { + if (input->isa()) { + auto input_parameter = input->cast(); + if (input_parameter->has_default()) { + const auto ¶m_value = input_parameter->default_param(); + if (param_value->requires_grad()) { + return param_value->name(); + } + } + } + } + return ""; +} + +void CheckpointStrategy(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; + StrategyMap stra_map; + auto ret = func_graph->get_return(); + auto all_nodes = DeepScopedGraphSearch(ret); + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + std::string param_name = NodeParameterName(cnode); + if (param_name.empty()) { + continue; + } + PrimitivePtr prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + OperatorInfoPtr operator_info = cnode->operator_info(); + if (operator_info) { + if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { + continue; + } + StrategyPtr strategyPtr = operator_info->strategy(); + MS_EXCEPTION_IF_NULL(node->scope()); + stra_map[param_name] = strategyPtr; + } + } + if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { + MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; + } +} + +void SetForwardFlag(const std::vector &all_nodes) { + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + + // CNode is globally unique. + MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << "."; + cnode->set_in_forward_flag(true); + } +} + +void SetForwardFlag(const AnfNodeSet &all_nodes) { + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (!IsValueNode(cnode->input(0))) { + continue; + } + + // CNode is globally unique. 
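+    // in_forward_flag later gates IsParallelCareNode, so only cnodes on the forward pass
+    // receive parallel handling.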
+ cnode->set_in_forward_flag(true); + } +} + +std::set ForwardGraph(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + return graph_set; +} + +std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { + MS_EXCEPTION_IF_NULL(graph); + std::vector root_forward_nodes; + auto loss_cnode = FindLossCNode(graph); + if (loss_cnode == nullptr) { + MS_LOG(WARNING) << "Can not find the loss cnode"; + return root_forward_nodes; + } + + auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); + for (auto &node : all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + auto root_node_id = node->UniqueIdThroughCopy(); + if (loss_cnode_id == root_node_id) { + root_forward_nodes = DeepLinkedGraphSearch(cnode); + break; + } + } + return root_forward_nodes; +} + +void MarkForwardCNode(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + auto all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + + if (graph_set.empty()) { + MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; + SetForwardFlag(all_nodes); + } else { + for (auto &func_graph : graph_set) { + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); + auto return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); + SetForwardFlag(all_dfs_nodes); + auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); + if (root_forward_nodes.empty()) { + continue; + } + // Mark forward flag for the nodes in root graph. + SetForwardFlag(root_forward_nodes); + } + } +} + +Status ParallelInit() { + MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); + int32_t device_num = ParallelContext::GetInstance()->device_num(); + int32_t global_rank = ParallelContext::GetInstance()->global_rank(); + std::string backend = ParallelContext::GetInstance()->communication_backend(); + std::string world_group; + + if (backend == HCCL_BACKEND) { + world_group = HCCL_WORLD_GROUP; + } else if (backend == NCCL_BACKEND) { + world_group = NCCL_WORLD_GROUP; + } else { + MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; + } + + uint32_t world_rank_size = 0; + if (!ParallelContext::GetInstance()->device_num_is_set()) { + if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { + MS_LOG(EXCEPTION) << "Get rank size failed"; + } + device_num = UintToInt(world_rank_size); + MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num; + } + + uint32_t rank_id = 0; + if (!ParallelContext::GetInstance()->global_rank_is_set()) { + if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { + MS_LOG(EXCEPTION) << "Get rank id failed"; + } + global_rank = UintToInt(rank_id); + MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; + } + + if (!InitDevice(device_num, global_rank, backend)) { + MS_LOG(ERROR) << "Init device failed"; + return FAILED; + } + + MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank + << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() + << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); + return SUCCESS; +} + +bool StepParallel(const 
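// Editor's sketch (not part of the patch): the backend -> world-group selection
// in ParallelInit() restated standalone. The literal group-name strings are
// assumptions for illustration; the real code uses the HCCL_WORLD_GROUP /
// NCCL_WORLD_GROUP constants defined elsewhere in frontend/parallel.
#include <stdexcept>
#include <string>

inline std::string WorldGroupFor(const std::string &backend) {
  if (backend == "hccl") {
    return "hccl_world_group";  // assumed value of HCCL_WORLD_GROUP
  }
  if (backend == "nccl") {
    return "nccl_world_group";  // assumed value of NCCL_WORLD_GROUP
  }
  throw std::invalid_argument("Invalid communication backend: " + backend);
}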
+bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(optimizer);
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  // assume no change to the graph
+  bool changes = false;
+  // controls whether the model_parallel mode is used
+  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+      (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
+    if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) {
+      if (HasStrategy(root)) {
+        MS_LOG(INFO) << "Strategies ignored in " << parallel_mode
+                     << ", set_strategy() is only valid in [semi_]auto_parallel.";
+      }
+      root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
+    }
+
+    return changes;
+  }
+
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
+
+  MS_LOG(INFO) << "Now entering step parallel";
+  DumpGraph(root, std::string(STEP_PARALLEL_BEGIN));
+
+  pipeline::ResourceBasePtr res = optimizer->resource();
+  MS_EXCEPTION_IF_NULL(res);
+
+  FuncGraphManagerPtr manager = res->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodePtr ret = root->get_return();
+  MS_EXCEPTION_IF_NULL(ret);
+  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
+  std::reverse(all_nodes.begin(), all_nodes.end());
+  if (parallel_mode != AUTO_PARALLEL) {
+    TOTAL_OPS = 0;
+    if (ParallelInit() != SUCCESS) {
+      MS_LOG(EXCEPTION) << "Parallel init failed";
+    }
+
+    // mark the forward cnodes; parallel only cares about these nodes
+    MarkForwardCNode(root);
+
+    if (FindCommunicationOp(all_nodes)) {
+      MS_LOG(EXCEPTION) << "The graph contains a communication op";
+    }
+
+    // extract shape and strategy, set operator_info
+    ExtractInformation(all_nodes);
+    ReshapeInit(all_nodes);
+  }
+  // save the strategy as a checkpoint for multi-train
+  if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
+    CheckpointStrategy(root);
+  }
+
+  HandleSymbolicKeyInstance(root, all_nodes);
+
+  // cover the parallel shape
+  CoverSliceShape(root);
+
+  // set the shape for the optimizer's cloned tensors
+  SetClonedTensorShapeForOptimizer(root);
+
+  // ForwardCommunication BackwardCommunication TensorRedistribution
+  ParallelCommunication(root, all_nodes, manager);
+
+  DumpGraph(root, std::string(STEP_PARALLEL_END));
+
+  // step parallel only runs once
+  root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true);
+  res->results()[pipeline::kStepParallelGraph] = root;
+
+  // in auto-parallel mode, there is no need to check whether strategies are set
+  root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true);
+
+  (void)gettimeofday(&end_time, nullptr);
+  uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us";
+  return changes;
+}
+
+// Needed by rec_parser
+std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
+  std::vector<std::string> name_inputs;
+  std::vector<AnfNodePtr> all_inputs = node->inputs();
+  std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
+
+  std::string node_id = node->UniqueId();
+  name_inputs.push_back(node_id);
+  for (auto &input : node_inputs) {
+    std::string name = input->UniqueId();
+    name_inputs.push_back(name);
+  }
+
+  return name_inputs;
+}
+}  // namespace parallel
+}  // namespace mindspore
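// Editor's note (not part of the patch), illustrating the ordering contract of
// ExtractInputsTensorName with hypothetical UniqueId() values: for a CNode whose
// own id is "12" and whose inputs after the primitive have ids "7" and "9", the
// function returns {"12", "7", "9"} — the node's id first, then one id per data
// input, preserving input order for the rec_parser.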
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
new file mode 100644
index 0000000000..f9fe67ea6b
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -0,0 +1,155 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_
+#define MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_
+
+#include <sys/time.h>
+
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "./common.h"
+#include "frontend/optimizer/opt.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+
+using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
+
+namespace mindspore {
+namespace parallel {
+const uint64_t kUSecondInSecond = 1000000;
+
+struct LossNodeInfo {
+  bool has_tuple_getitem = false;
+  int dout_index = 0;  // a tuple-typed sens is not supported yet
+};
+
+std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
+std::string CreateInstanceName(const CNodePtr &node, size_t index);
+void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);
+
+void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
+                          const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node);
+
+TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim,
+                               const OperatorInfoPtr &distribute_operator_pre);
+
+OperatorInfoPtr GetDistributeOperator(const CNodePtr &node);
+
+void Redistribution(const std::pair<AnfNodePtr, int> &node_pair, const OperatorInfoPtr &distribute_operator,
+                    const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution,
+                    const CNodePtr &pre_node);
+
+bool StrategyFound(std::unordered_map<std::string, ValuePtr> attrs);
+
+bool IsParallelCareNode(const CNodePtr &cnode);
+
+void MarkForwardCNode(const FuncGraphPtr &root);
+
+bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
+
+void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
+                        const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node);
+
+std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
+                                       const CNodePtr &node);
+
+void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
+
+void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);
+
+std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
+
+std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph);
+
+void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
+
+void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
+
+// Generate and initialize a parallel operator
+OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
+                                 const std::vector<Shapes> &shape_list);
+
+// Generate a parallel operator without initializing it
+OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
+                                    std::vector<Shapes> shape_list);
+
+// Extract the strategy from the attrs
+StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs);
+
+Shapes GetNodeShape(const AnfNodePtr &node);
+
+std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
+
+// Extract the shape from an anfnode
+std::vector<Shapes> ExtractShape(const CNodePtr &node);
+
+std::pair<OperatorInfoPtr, int> FindParallelCareNode(const AnfNodePtr &node);
+
+// Find the final sub-graph
+std::pair<OperatorInfoPtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter);
+
+// Set the distributed shape for the parameters' abstract
+void SetParallelShape(const AnfNodePtr &parameter, const std::pair<OperatorInfoPtr, int> &res);
+
+// Change the parameters' shape in the resource
+void CoverSliceShape(const FuncGraphPtr &root);
+
+void SetVirtualDatasetStrategy(const CNodePtr &node);
+
+// Create a parallel operator for a primitive node (which has a strategy)
+void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes);
+
+TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int> &node_pair);
+
+std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &node);
+
+std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index);
+
+std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index);
+
+std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node);
+
+void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
+
+// Add nodes for the whole graph
+void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
+                           const FuncGraphManagerPtr &manager);
+
+std::string NodeParameterName(const CNodePtr &node);
+
+void CheckpointStrategy(const FuncGraphPtr &func_graph);
+
+// Main step of parallel
+bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
+
+int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
+
+Status ParallelInit();
+
+std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
+
+std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_
diff --git a/mindspore/ccsrc/frontend/parallel/strategy.h b/mindspore/ccsrc/frontend/parallel/strategy.h
new file mode 100644
index 0000000000..ca01164a6a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/strategy.h
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_
+#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+#define MIN_SLICE_NUM 1
+
+using Dimensions = std::vector<int32_t>;
+
+class Strategy;
+using StrategyPtr = std::shared_ptr<Strategy>;
+
+class Strategy {
+ public:
+  Strategy(int32_t stage, std::vector<Dimensions> inputs) : stage_(stage), inputs_(std::move(inputs)) {}
+  ~Strategy() = default;
+  size_t GetInputNumber() const { return inputs_.size(); }
+  std::vector<Dimensions> GetInputDim() const { return inputs_; }
+  int32_t GetInputStage() const { return stage_; }
+  void ExpandInputDimFromOneToTwo() {
+    if (inputs_.size() == 1) {
+      inputs_.push_back(inputs_[0]);
+    }
+  }
+  void ResetInputs(const std::vector<Dimensions> &input) { inputs_ = input; }
+
+  bool IsEqual(const StrategyPtr &another_stra) {
+    if (another_stra == nullptr) {
+      return false;
+    }
+    if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) {
+      return false;
+    }
+    return true;
+  }
+
+ private:
+  const int32_t stage_;
+
+  // Each Dimensions entry must have one element per dimension of the corresponding input tensor.
+  std::vector<Dimensions> inputs_;
+};
+
+inline StrategyPtr NewStrategy(const int32_t stage, const std::vector<Dimensions> &inputs) {
+  return std::make_shared<Strategy>(stage, inputs);
+}
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_
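As a usage note: a strategy is just a stage id plus one splitting vector per input. A minimal editor's sketch (illustrative values; the wrapper function is hypothetical, not part of the patch):

// Editor's sketch: building and comparing strategies from strategy.h above.
#include "frontend/parallel/strategy.h"

void StrategyExample() {
  using mindspore::parallel::Dimensions;
  using mindspore::parallel::NewStrategy;
  // One input tensor, split 2-ways on dim 0 and 4-ways on dim 1.
  Dimensions input0 = {2, 4};
  auto s1 = NewStrategy(0, {input0});  // stage 0
  auto s2 = NewStrategy(0, {input0});
  bool equal = s1->IsEqual(s2);  // true: same stage and same input dims
  (void)equal;
}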
diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
new file mode 100644
index 0000000000..bf7c4e29ab
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
+
+#include <fstream>
+#include <memory>
+#include <vector>
+
+#include "common/utils.h"
+#include "utils/convert_utils.h"
+#include "utils/log_adapter.h"
+#include "proto/node_strategy.pb.h"
+
+namespace mindspore {
+namespace parallel {
+StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
+  static StrategyCheckpoint instance = StrategyCheckpoint();
+  if (ParallelContext::GetInstance() != nullptr) {
+    instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file();
+    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
+    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
+    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+  }
+  return instance;
+}
+
+bool StrategyCheckpoint::CheckpointExists(const std::string &path) const {
+  std::ifstream fin(path);
+  if (fin) {
+    return true;
+  }
+  return false;
+}
+
+Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
+  if (strategy_map == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: strategy_map is nullptr";
+  }
+  if (!CheckpointExists(load_file_)) {
+    MS_LOG(EXCEPTION) << "Checkpoint file is not found";
+  }
+  straspb::ParallelStrategyMap parallel_strategy_map;
+  std::fstream input(load_file_, std::ios::in | std::ios::binary);
+  if (!parallel_strategy_map.ParseFromIstream(&input)) {
+    MS_LOG(ERROR) << "Load strategy file failed";
+    return FAILED;
+  }
+  size_t node_num = IntToSize(parallel_strategy_map.parallel_strategy_item_size());
+  for (size_t i = 0; i < node_num; i++) {
+    straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i));
+    std::string node_name = parallel_strategy_item.node_name();
+    straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys();
+    auto stage = (int32_t)parallel_strategys.stage();
+    size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size());
+    std::vector<Dimensions> strategy_inputs;
+    for (size_t j = 0; j < strategys_num; j++) {
+      straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j));
+      std::vector<int32_t> dimension;
+      size_t dim_num = IntToSize(parallel_strategy.dim_size());
+      for (size_t k = 0; k < dim_num; k++) {
+        dimension.push_back(parallel_strategy.dim(SizeToInt(k)));
+      }
+      strategy_inputs.push_back(dimension);
+    }
+
+    StrategyPtr strategy = NewStrategy(stage, strategy_inputs);
+    (*strategy_map)[node_name] = strategy;
+    current_stage_ = (int32_t)parallel_strategy_map.current_stage();
+  }
+  return SUCCESS;
+}
+
+Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
+  straspb::ParallelStrategyMap parallel_strategy_map;
+  parallel_strategy_map.set_current_stage(IntToUint(++current_stage_));
+  for (auto &node_stra : strategy_map) {
+    straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item();
+    MS_EXCEPTION_IF_NULL(parallel_strategy_item);
+    parallel_strategy_item->set_node_name(node_stra.first);
+    straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
+    MS_EXCEPTION_IF_NULL(parallel_strategys);
+    MS_EXCEPTION_IF_NULL(node_stra.second);
+    parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage()));
+    for (auto &dims : node_stra.second->GetInputDim()) {
+      straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
+      MS_EXCEPTION_IF_NULL(parallel_strategy);
+      for (auto dim : dims) {
+        parallel_strategy->add_dim(IntToUint(dim));
+      }
+    }
+  }
+  std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
+  if (!parallel_strategy_map.SerializeToOstream(&output)) {
+    MS_LOG(ERROR) << "Save strategy file failed";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
new file mode 100644
index 0000000000..67cbb92ee2
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHECKPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
+#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHECKPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
+
+#include <string>
+#include <unordered_map>
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/context.h"
+
+namespace mindspore {
+namespace parallel {
+using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
+class StrategyCheckpoint {
+ public:
+  StrategyCheckpoint() {
+    current_stage_ = 0;
+    load_file_ = "";
+    load_checkpoint_on_ = false;
+    save_file_ = "";
+    save_checkpoint_on_ = false;
+  }
+  ~StrategyCheckpoint() = default;
+
+  Status Load(StrategyMap *strategy_map);
+  Status Save(const StrategyMap &strategy_map);
+
+  static StrategyCheckpoint &GetInstance();
+  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
+  bool SaveCheckPointOn() const { return save_checkpoint_on_; }
+
+ private:
+  std::string load_file_;
+  std::string save_file_;
+  bool load_checkpoint_on_;
+  bool save_checkpoint_on_;
+  bool CheckpointExists(const std::string &path) const;
+  int32_t current_stage_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHECKPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_
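The intended save/load round trip, as an editor's sketch (not part of the patch): the file paths come from ParallelContext (strategy_ckpt_save_file / strategy_ckpt_load_file), so this assumes both have been configured; "fc1.weight" is a hypothetical parameter name.

// Editor's sketch: protobuf round trip through StrategyCheckpoint.
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

void CheckpointRoundTrip() {
  using mindspore::parallel::NewStrategy;
  using mindspore::parallel::StrategyCheckpoint;
  using mindspore::parallel::StrategyMap;
  using mindspore::parallel::SUCCESS;
  StrategyMap to_save;
  to_save["fc1.weight"] = NewStrategy(0, {{2, 4}});  // stage 0, one input split [2, 4]
  if (StrategyCheckpoint::GetInstance().Save(to_save) != SUCCESS) {
    return;
  }
  StrategyMap loaded;
  (void)StrategyCheckpoint::GetInstance().Load(&loaded);  // refills the map from disk
}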
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc
new file mode 100644
index 0000000000..cff3d53a88
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.cc
@@ -0,0 +1,248 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/arrangement.h"
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include "common/utils.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/shape_util.h"
+#include "utils/convert_utils.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status Arrangement::Init(const std::vector<int32_t> &array) {
+  Status status = Array::Init(array);
+  if (status != Status::SUCCESS) {
+    return Status::FAILED;
+  }
+  if (!IsValidArrangement()) {
+    MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
+    return Status::FAILED;
+  }
+  ComputeSize();
+  return Status::SUCCESS;
+}
+
+bool Arrangement::IsValidArrangement() {
+  return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; });
+}
+
+void Arrangement::ComputeSize() {
+  size_ = 1;
+  for (auto &value : array_) {
+    size_ *= value;
+  }
+}
+
+/*
+ * if GetDimSize() == 0, return []
+ * if value <= array_[0], return [value]
+ * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
+ *   where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1]
+ * if value > size_, return []
+ */
+std::vector<int32_t> Arrangement::GetFrontElementByValue(int32_t value) const {
+  std::vector<int32_t> out;
+  if (GetDimSize() == 0) {
+    return out;
+  }
+  if (value <= size_) {
+    int32_t size = 1;
+    uint32_t shape_list_idx = 0;
+    while (size < value) {
+      size *= array_[shape_list_idx];
+      if (size <= value) {
+        out.push_back(array_[shape_list_idx]);
+      } else {
+        if (size == 0) {
+          MS_LOG(ERROR) << "The size is 0";
+          out.clear();
+          return out;
+        }
+        out.push_back(value * array_[shape_list_idx] / size);
+      }
+      shape_list_idx++;
+    }
+  }
+  return out;
+}
+
+std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
+  const std::vector<Arrangement> &expand_list) const {
+  if (expand_list.size() != GetDimSize()) {
+    return nullptr;
+  }
+  std::vector<int32_t> new_shape;
+  for (uint32_t i = 0; i < expand_list.size(); i++) {
+    std::vector<int32_t> expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
+    if (expand_shape.empty()) {
+      new_shape.push_back(GetDimByIdx(i));
+    } else {
+      (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end());
+    }
+  }
+  Arrangement arrangement_new;
+  (void)arrangement_new.Init(new_shape);
+  return std::make_shared<Arrangement>(arrangement_new);
+}
+
+/*
+ * example:
+ * expand_shape = [4, 2, 2, 2]
+ * array_ = [8, 4],
+ * arrangement_list = [[4, 2], [2, 2]]
+ */
+std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
+  int32_t size = 1;
+  uint32_t ind = 0;
+  std::vector<Arrangement> arrangement_list;
+  std::vector<int32_t> shape;
+  for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) {
+    size *= expand_shape.GetDimByIdx(i);
+    if (size > GetDimByIdx(ind)) {
+      MS_LOG(ERROR) << "invalid expand_shape";
+      return nullptr;
+    } else if (size < GetDimByIdx(ind)) {
+      shape.push_back(expand_shape.GetDimByIdx(i));
+      continue;
+    } else {
+      shape.push_back(expand_shape.GetDimByIdx(i));
+      Arrangement arrangement;
+      (void)arrangement.Init(shape);
+      arrangement_list.push_back(arrangement);
+      shape.clear();
+      ind++;
+      size = 1;
+    }
+  }
+  if (ind != GetDimSize()) {
+    MS_LOG(ERROR) << "invalid expand_shape";
+    return nullptr;
+  }
+  auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
+  return arrangement_new;
+}
+
+std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
+  const Arrangement &expand_shape) const {
+  std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
+  if (expand_shape_list_ptr == nullptr) {
+    return nullptr;
+  }
+  std::vector<int32_t> expand_num_list_shape;
+  (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
+                       std::back_inserter(expand_num_list_shape),
+                       [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); });
+  Arrangement expand_num_list;
+  Status status = expand_num_list.Init(expand_num_list_shape);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
+  return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
+}
+
+std::vector<int32_t> Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
+  std::vector<int32_t> shape_accum;
+  int32_t size = 0;
+  for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
+    shape_accum.push_back(size);
+    size += *iter;
+  }
+  return shape_accum;
+}
+
+std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
+  const std::vector<Arrangement> &expand_list) const {
+  if (expand_list.size() != GetDimSize()) {
+    return nullptr;
+  }
+  std::vector<int32_t> new_shape;
+  for (uint32_t i = 0; i < expand_list.size(); i++) {
+    if (expand_list[i].GetDimSize() >= 1) {
+      int32_t size = 1;
+      for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
+        new_shape.push_back(expand_list[i].GetDimByIdx(k));
+        size *= expand_list[i].GetDimByIdx(k);
+      }
+      new_shape.push_back(GetDimByIdx(i) / size);
+    } else {
+      new_shape.push_back(GetDimByIdx(i));
+    }
+  }
+  Arrangement arrangement_new;
+  (void)arrangement_new.Init(new_shape);
+  return std::make_shared<Arrangement>(arrangement_new);
+}
+
+std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
+  std::vector<int64_t> in1_accum;
+  Status status = ShapeToAccumulateProduct(array_, &in1_accum);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  std::vector<int64_t> in2_accum;
+  status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  std::vector<int64_t> out_accum;
+  status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  std::vector<int32_t> out_shape;
+  status = AccumulateProductToShape(out_accum, &out_shape);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  Arrangement out;
+  status = out.Init(out_shape);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  return std::make_shared<Arrangement>(out);
+}
+
+std::vector<size_t> Arrangement::GetSqueezeIdx() const {
+  std::vector<size_t> out;
+  for (size_t i = 0; i < GetDimSize(); i++) {
+    if (GetDimByIdx(SizeToUint(i)) == 1) {
+      out.push_back(i);
+    }
+  }
+  return out;
+}
+
+Arrangement Arrangement::GetSqueezeArrangement() const {
+  std::vector<int32_t> out_shape(array_.size());
+  auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; });
+  out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
+
+  // if all elements are 1, out_shape = {1}
+  if (out_shape.empty()) {
+    MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
+    out_shape.push_back(1);
+  }
+  Arrangement out;
+  (void)out.Init(out_shape);
+  return out;
+}
+}  // namespace parallel
+}  // namespace mindspore
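A worked illustration of GetFrontElementByValue, traced by the editor from the code above (not part of the patch):

// Editor's worked example: GetFrontElementByValue on array_ = [8, 4] (size_ = 32).
//   value = 8  -> [8]     (8 divides the running product exactly)
//   value = 16 -> [8, 2]  (push 8, then 16 * 4 / 32 = 2 from the next dimension)
//   value = 33 -> []      (value > size_, so nothing is returned)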
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h
new file mode 100644
index 0000000000..ab807fb20a
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/arrangement.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_
+
+#include <stdint.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/array.h"
+
+namespace mindspore {
+namespace parallel {
+class Arrangement : public Array {
+ public:
+  Arrangement() : size_(1) {}
+  ~Arrangement() override = default;
+  Status Init(const std::vector<int32_t> &array) override;
+  int32_t size() const { return size_; }
+  std::vector<int32_t> GetFrontElementByValue(int32_t value) const;
+  std::shared_ptr<std::vector<Arrangement>> GetExpandShapeList(const Arrangement &expand_shape) const;
+  std::vector<int32_t> ComputeReverseAccumulateSumInReverseOrder() const;
+  std::shared_ptr<Arrangement> GetExpandedShapeByExpandListReserveLeft(
+    const std::vector<Arrangement> &expand_list) const;
+  std::shared_ptr<Arrangement> GetExpandedShapeByExpandListRemoveLeft(
+    const std::vector<Arrangement> &expand_list) const;
+  std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> GetExpandShapeListPair(
+    const Arrangement &expand_shape) const;
+  std::shared_ptr<Arrangement> GetUnifiedShape(const Arrangement &in2) const;
+  std::vector<size_t> GetSqueezeIdx() const;
+  Arrangement GetSqueezeArrangement() const;
+
+ private:
+  bool IsValidArrangement();
+  void ComputeSize();
+  int32_t size_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc
new file mode 100644
index 0000000000..4e1f467793
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/array.h"
+#include <sstream>
+#include "frontend/parallel/status.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+std::string Array::ToString() const {
+  std::ostringstream buffer;
+  buffer << "[ ";
+  for (auto &element : array_) {
+    buffer << std::to_string(element) + " ";
+  }
+  buffer << "]";
+  return buffer.str();
+}
+
+Status Array::Init(const std::vector<int32_t> &array) {
+  array_ = array;
+  return IsValidArray() ? Status::SUCCESS : Status::FAILED;
+}
+
+// The base class accepts any contents; subclasses refine validity.
+bool Array::IsValidArray() const { return true; }
+
+int32_t Array::GetDimByIdx(uint32_t idx) const {
+  size_t mod_idx = idx;
+  if (idx >= GetDimSize()) {
+    MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize();
+  }
+  return array_[mod_idx];
+}
+
+int32_t Array::GetDimByReverseIdx(uint32_t idx) const {
+  size_t mod_idx = idx;
+  if (idx >= GetDimSize()) {
+    MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize();
+  }
+  return array_[GetDimSize() - 1 - mod_idx];
+}
+
+bool Array::operator==(const Array &shape) const {
+  if (GetDimSize() != shape.GetDimSize()) {
+    return false;
+  }
+  for (uint32_t i = 0; i < GetDimSize(); i++) {
+    if (GetDimByIdx(i) != shape.GetDimByIdx(i)) {
+      return false;
+    }
+  }
+  return true;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h
new file mode 100644
index 0000000000..13b3982a18
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/array.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+class Array {
+ public:
+  Array() = default;
+  virtual ~Array() = default;
+  std::string ToString() const;
+  virtual Status Init(const std::vector<int32_t> &array);
+  bool IsValidArray() const;
+  std::vector<int32_t> array() const { return array_; }
+  size_t GetDimSize() const { return array_.size(); }
+  int32_t GetDimByIdx(uint32_t idx) const;
+  int32_t GetDimByReverseIdx(uint32_t idx) const;
+  bool operator==(const Array &a1) const;
+
+ protected:
+  std::vector<int32_t> array_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc
new file mode 100644
index 0000000000..9395d3df89
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.cc
@@ -0,0 +1,254 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/construct_operator.h"
+
+#include <functional>
+#include <numeric>
+
+namespace mindspore {
+namespace parallel {
+Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) {
+  dev_size_ = dev_matrix_shape.size();
+  dev_matrix_shape_ = dev_matrix_shape;
+  dev_list_ = dev_list;
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::ReshapeOP(Shape shape) {
+  int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int32_t>());
+  int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies<int32_t>());
+  if (prod != prod_expect) {
+    ValuePtr ptr = MakeValue(shape);
+    MS_EXCEPTION_IF_NULL(ptr);
+    MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << " when constructing the Reshape operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+  OperatorAttrs attrs;
+  ValuePtr param_value = MakeValue(shape);
+  Attr param = std::make_pair(SHAPE, param_value);
+  OperatorParams params = {std::make_pair(param, 2)};
+  OperatorArgs args = std::make_pair(attrs, params);
+  op_ = std::make_pair(RESHAPE, args);
+  return Status::SUCCESS;
+}
+
+Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) {
+  ValuePtr attr_value = MakeValue(value);
+  Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value);
+  Attr attr_end_mask = std::make_pair(END_MASK, attr_value);
+  Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value);
+  Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value);
+  Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value);
+  OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask,
+                         attr_shrink_axis_mask};
+
+  ValuePtr param_begin_value = MakeValue(begin);
+  Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2);
+  ValuePtr param_end_value = MakeValue(end);
+  Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3);
+
+  ValuePtr param_strides_value = MakeValue(strides);
+  Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4);
+  OperatorParams params = {param_begin, param_end, param_strides};
+  OperatorArgs op_args = std::make_pair(attrs, params);
+
+  return std::make_pair(STRIDED_SLICE, op_args);
+}
+
+Status ConstructOperator::StridedSliceOP(Args args) {
+  if (args.size() < 3) {
+    MS_LOG(ERROR) << "args size should not be less than 3!";
+    return Status::FAILED;
+  }
+  int32_t split_count = args[0];
+  if (split_count <= 0) {
+    MS_LOG(ERROR) << "split_count must be greater than 0!";
+    return Status::FAILED;
+  }
+  int32_t split_dim = args[1];
+  int32_t dev_dim = args[2];
+  std::vector<Group> group_list;
+
+  if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) {
+    MS_LOG(ERROR) << "StridedSlice op: create group failed";
+    return FAILED;
+  } else if (group_list.empty()) {  // this group has only one device, so StridedSlice is unnecessary
+    MS_LOG(INFO) << "no need stride slice op";
+    return SUCCESS;
+  }
+
+  Group group = group_list[0];
+  size_t rank;
+  if (group.GetIndex(&rank) == Status::FAILED) {
+    return Status::FAILED;
+  }
+  size_t size = tensor_shape_.size();
+  Shape begin(size);
+  Shape end(size);
+  Shape strides(size, 1);
+  size_t index = 0;
+  for (auto num : tensor_shape_) {
+    if (index != IntToSize(split_dim)) {
+      begin[index] = 0;
+      end[index] = num;
+    } else {
+      if (num % split_count != 0) {
+        MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
+                      << " when constructing the StridedSlice operator!";
+        return Status::INVALID_ARGUMENT;
+      }
+      int32_t count = num / split_count;
+      begin[index] = SizeToInt(rank) * count;
+      end[index] = (SizeToInt(rank) + 1) * count;
+    }
+    index++;
+  }
+
+  op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides);
+
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::AllGatherOP(int32_t dev_dim) {
+  if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
+    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AllGather operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+
+  std::vector<Group> group_list;
+  if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) {
+    MS_LOG(ERROR) << "AllGather op: create group failed";
+    return FAILED;
+  } else if (group_list.empty()) {  // this group has only one device, so AllGather is unnecessary
+    MS_LOG(INFO) << "no need all gather op";
+    return SUCCESS;
+  }
+
+  std::string group_name = group_list[0].name();
+  ValuePtr attr_value = MakeValue(group_name);
+  Attr attr = std::make_pair(GROUP, attr_value);
+  OperatorAttrs attrs = {attr};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  op_ = std::make_pair(ALL_GATHER, args);
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::ConcatOP(int32_t concat_dim) {
+  if (IntToSize(concat_dim) >= tensor_shape_.size()) {
+    MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when constructing the Concat operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+  ValuePtr attr_value = MakeValue(concat_dim);
+  Attr attr = std::make_pair(AXIS, attr_value);
+  OperatorAttrs attrs = {attr};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  op_ = std::make_pair(CONCAT, args);
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::SplitOP(int32_t split_count) {
+  if (split_count <= 0) {
+    MS_LOG(ERROR) << "Invalid split count when constructing the Split operator!";
+    return Status::FAILED;
+  }
+  OperatorAttrs attrs;
+  ValuePtr attr_value_axis = MakeValue(DEFAULT);
+  Attr attr_axis = std::make_pair(AXIS, attr_value_axis);
+  ValuePtr attr_value_split = MakeValue(split_count);
+  Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split);
+  attrs = {attr_axis, attr_split};
+  OperatorParams params;
+  OperatorArgs args = std::make_pair(attrs, params);
+  op_ = std::make_pair(SPLIT, args);
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::AlltoAllOP(Args args) {
+  if (args.size() < 4) {
+    MS_LOG(ERROR) << "args size should not be less than 4!";
+    return Status::FAILED;
+  }
+  int32_t split_count = args[0];
+  int32_t split_dim = args[1];
+  int32_t concat_dim = args[2];
+  int32_t dev_dim = args[3];
+  if (split_count <= 0) {
+    MS_LOG(ERROR) << "Invalid split count when constructing the AlltoAll operator!";
+    return Status::FAILED;
+  }
+  if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) {
+    MS_LOG(ERROR) << "Tensor can not be split into " << split_count << " slices in the dimension " << split_dim
+                  << " when constructing the AlltoAll operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+  if (IntToSize(concat_dim) >= tensor_shape_.size()) {
+    MS_LOG(ERROR) << "Invalid concat dimension " << concat_dim << " when constructing the AlltoAll operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+  if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) {
+    MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AlltoAll operator!";
+    return Status::INVALID_ARGUMENT;
+  }
+
+  std::vector<Group> group_list;
+  if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) {
+    MS_LOG(ERROR) << "AlltoAll op: create group failed";
+    return FAILED;
+  } else if (group_list.empty()) {  // this group has only one device, so AlltoAll is unnecessary
+    MS_LOG(INFO) << "no need all to all op";
+    return SUCCESS;
+  }
+
+  std::string group_name = group_list[0].name();
+  ValuePtr attr_value_group = MakeValue(group_name);
+  Attr attr_group = std::make_pair(GROUP, attr_value_group);
+  ValuePtr attr_value_split_count = MakeValue(split_count);
+  Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count);
+  ValuePtr attr_value_split_dim = MakeValue(split_dim);
+  Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim);
+  ValuePtr attr_value_concat_dim = MakeValue(concat_dim);
+  Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim);
+  OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group};
+  OperatorParams params;
+  OperatorArgs op_args = std::make_pair(attrs, params);
+  op_ = std::make_pair(ALL_TO_ALL, op_args);
+  return Status::SUCCESS;
+}
+
+Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
+  MS_EXCEPTION_IF_NULL(group);
+  CheckGlobalDeviceManager();
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  int32_t rank = g_device_manager->global_rank();
+  DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) {
+    return FAILED;
+  }
+  // the group has only one device, so no group needs to be created
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "The group has only one device, no need to create it";
+    return SUCCESS;
+  }
+
+  Group g = g_device_manager->CreateGroup(group_devices);
+  group->push_back(g);
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
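A worked illustration of the slice computed by StridedSliceOP, traced by the editor from the code above (not part of the patch):

// Editor's worked example: StridedSliceOP({2, 0, d}) on tensor_shape_ = [8, 16]
// splits dimension 0 into 2 slices, so count = 8 / 2 = 4.
//   rank-in-group 0: begin = [0, 0], end = [4, 16], strides = [1, 1]
//   rank-in-group 1: begin = [4, 0], end = [8, 16], strides = [1, 1]
// Non-split dimensions always take the full range [0, num).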
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
new file mode 100644
index 0000000000..b06d70af36
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/construct_operator.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+using Args = std::vector<int32_t>;
+
+class ConstructOperator {
+ public:
+  const int32_t DEFAULT = 0;
+  ConstructOperator() : dev_size_(0) {}
+  ~ConstructOperator() = default;
+  Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
+  Status ReshapeOP(Shape shape);
+  Status StridedSliceOP(Args args);
+  Status AllGatherOP(int32_t dev_dim);
+  Status SplitOP(int32_t split_count);
+  Status ConcatOP(int32_t concat_dim);
+  Status AlltoAllOP(Args args);
+  Operator GetOperator() const { return op_; }
+  void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; }
+
+ private:
+  Operator op_;
+  size_t dev_size_;
+  Shape tensor_shape_;
+  RankList dev_list_;
+  Shape dev_matrix_shape_;
+  Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc
new file mode 100644
index 0000000000..d5d34a484f
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.cc
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/layout_transfer.h"
+#include "common/utils.h"
+#include "frontend/parallel/status.h"
+
+namespace mindspore {
+namespace parallel {
+std::string LayoutTransfer::ToString() const {
+  std::ostringstream buffer;
+  buffer << std::endl << std::string("from_in_ tensor layout: " + from_in_.ToString());
+  buffer << std::endl << std::string("to_in_ tensor layout: " + to_in_.ToString());
+  return buffer.str();
+}
+
+LayoutTransfer::~LayoutTransfer() = default;
+
+Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) {
+  from_in_ = from_in;
+  to_in_ = to_in;
+  MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString();
+  Status status = CheckValidTransfer();
+  return status;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h
new file mode 100644
index 0000000000..01c56fc7cf
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/layout_transfer.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_
+
+#include <string>
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+
+namespace mindspore {
+namespace parallel {
+class LayoutTransfer {
+ public:
+  LayoutTransfer() = default;
+  virtual ~LayoutTransfer() = 0;
+  std::string ToString() const;
+  Status Init(const TensorLayout &from_in, const TensorLayout &to_in);
+  TensorLayout from_in() const { return from_in_; }
+  TensorLayout to_in() const { return to_in_; }
+
+ protected:
+  bool IsSameTensorShape() const { return from_in_.IsSameTensorShape(to_in_); }
+  bool IsSameDeviceArrangement() const { return from_in_.IsSameDeviceArrangement(to_in_); }
+
+  TensorLayout from_in_;
+  TensorLayout to_in_;
+
+ private:
+  virtual Status CheckValidTransfer() = 0;
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc
new file mode 100644
index 0000000000..184f0c7530
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.cc
@@ -0,0 +1,171 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/map.h"
+#include <algorithm>
+#include <memory>
+#include <utility>
+#include "common/utils.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/shape_util.h"
+#include "utils/convert_utils.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status Map::Init(const std::vector<int32_t> &array) {
+  Status status = Array::Init(array);
+  if (status != Status::SUCCESS) {
+    return Status::FAILED;
+  }
+  if (!IsValidMap()) {
+    MS_LOG(ERROR) << "invalid map " << this->ToString();
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
+bool Map::IsValidMap() {
+  if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
+    return false;
+  }
+  // check that all values in array_ other than MAP_NONE are distinct
+  std::vector<int32_t> sorted_array = array_;
+  std::sort(sorted_array.begin(), sorted_array.end());
+  int32_t value = MAP_NONE;
+  for (auto &element : sorted_array) {
+    if (element == MAP_NONE) {
+      continue;
+    }
+    if (element == value) {
+      return false;
+    }
+    value = element;
+  }
+  return true;
+}
+
+int32_t Map::GetMaxItem() const {
+  if (!array_.empty()) {
+    return *std::max_element(array_.begin(), array_.end());
+  } else {
+    return MAP_NONE;
+  }
+}
+
+int32_t Map::GetIndexByValue(int32_t value) const {
+  auto iter = find(array_.begin(), array_.end(), value);
+  if (iter != array_.end()) {
+    return static_cast<int32_t>(std::distance(array_.begin(), iter));
+  } else {
+    return MAP_NONE;
+  }
+}
+
+/*
+ * expand_num_list.GetDimSize() should be equal to GetDimSize()
+ */
+std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
+  if (expand_num_list.GetDimSize() != GetDimSize()) {
+    return nullptr;
+  }
+  std::vector<int32_t> new_shape;
+  for (uint32_t i = 0; i != GetDimSize(); i++) {
+    if (GetDimByIdx(i) == MAP_NONE) {
+      for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
+        new_shape.push_back(MAP_NONE);
+      }
+    } else {
+      new_shape.push_back(GetDimByIdx(i));
+      int32_t j = 1;
+      while (j < expand_num_list.GetDimByIdx(i)) {
+        new_shape.push_back(MAP_NONE);
+        j++;
+      }
+    }
+  }
+  auto map_new = std::make_shared<Map>();
+  (void)map_new->Init(new_shape);
+  return map_new;
+}
+
+/*
+ * every map value must be smaller than expand_num_list.GetDimSize()
+ */
+std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
+  if (GetMaxItem() >= static_cast<int32_t>(expand_num_list.GetDimSize())) {
+    return nullptr;
+  }
+  std::vector<int32_t> new_shape;
+  for (uint32_t i = 0; i < GetDimSize(); i++) {
+    if (GetDimByIdx(i) == MAP_NONE) {
+      new_shape.push_back(MAP_NONE);
+    } else {
+      int32_t start_map =
+        expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<uint32_t>(GetDimByIdx(i))];
+      for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast<uint32_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
+        new_shape.push_back(k + start_map);
+      }
+    }
+  }
+  auto map_new = std::make_shared<Map>();
+  (void)map_new->Init(new_shape);
+  return map_new;
+}
+
+std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
+  if (GetMaxItem() >= static_cast<int32_t>(input_vector.size())) {
+    return nullptr;
+  }
+  std::vector<Arrangement> out;
+  Arrangement empty_arrangement;
+  for (uint32_t i = 0; i < GetDimSize(); i++) {
+    if (GetDimByIdx(i) == MAP_NONE) {
+      out.push_back(empty_arrangement);
+    } else {
+      out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]);
+    }
+  }
+  return std::make_shared<std::vector<Arrangement>>(out);
+}
+
+bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
+  for (auto &value : idx_list) {
+    if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const {
+  std::vector<int32_t> out_shape;
+  for (size_t i = 0; i < GetDimSize(); i++) {
+    auto it = std::find(idx_list.begin(), idx_list.end(), i);
+    if (it == idx_list.end()) {
+      out_shape.push_back(GetDimByIdx(SizeToUint(i)));
+    }
+  }
+  if (out_shape.empty()) {
+    MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
+    out_shape.push_back(MAP_NONE);
+  }
+  Map out;
+  (void)out.Init(out_shape);
+  return out;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h
new file mode 100644
index 0000000000..3d299d4b90
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/map.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/arrangement.h"
+#include "frontend/parallel/tensor_layout/array.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr int32_t MAP_NONE = -1;
+
+class Map : public Array {
+ public:
+  Map() = default;
+  ~Map() override = default;
+  Status Init(const std::vector<int32_t> &array) override;
+  int32_t GetMaxItem() const;
+  int32_t GetIndexByValue(int32_t value) const;
+  std::shared_ptr<Map> ExpandMapByNone(const Arrangement &expand_num_list) const;
+  std::shared_ptr<Map> ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const;
+  std::shared_ptr<std::vector<Arrangement>> ReMapVector(const std::vector<Arrangement> &input_vector) const;
+  bool CheckNoneByIdxList(std::vector<size_t> idx_list) const;
+  Map SqueezeMapByIdxList(std::vector<size_t> idx_list) const;
+
+ private:
+  bool IsValidMap();
+};
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
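A worked illustration of the two expansion modes of Map, traced by the editor from the code above (not part of the patch):

// Editor's worked example: with map = [1, -1] and expand_num_list = [2, 2]
// (each original dimension widens into two):
//   ExpandMapByNone           -> [1, -1, -1, -1]  (each entry kept, padded with MAP_NONE)
//   ExpandMapByDecreaseNumber -> [3, 2, -1]       (device dim 1 splits into new dims 3 and 2;
//                                                  MAP_NONE stays a single MAP_NONE)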
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +Status RedistributionLayoutTransfer::CheckValidTransfer() { return Status::SUCCESS; } + +/* + * unify device arrangement between in_layout and out_layout + * after this function is called, + * in_step1_layout.device_arrangement and out_step1_layout.device_arrangement will be the same + */ +std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangement() const { + Arrangement in_arrangement; + Arrangement out_arrangement; + in_arrangement = from_in_.device_arrangement(); + out_arrangement = to_in_.device_arrangement(); + std::shared_ptr unify_arrangement_ptr = in_arrangement.GetUnifiedShape(out_arrangement); + if (unify_arrangement_ptr == nullptr) { + return nullptr; + } + std::shared_ptr from_out_ptr = from_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); + if (from_out_ptr == nullptr) { + return nullptr; + } + std::shared_ptr to_out_ptr = to_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); + if (to_out_ptr == nullptr) { + return nullptr; + } + ReshapeLayoutTransfer out; + Status status = out.Init(*from_out_ptr, *to_out_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +/* + * unify tensor shape between in_step1_layout.tensor_shape and out_step1_layout.tensor_shape + * after this function is called, + * in_step2_layout.tensor_shape and out_step2_layout.tensor_shape will be the same + */ +std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { + std::shared_ptr unified_device_arrangement_ptr = UnifyDeviceArrangement(); + if (unified_device_arrangement_ptr == nullptr) { + return nullptr; + } + return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape(); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h new file mode 100644 index 0000000000..0347b6423a --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_layout_transfer.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ + +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/layout_transfer.h" +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" + +namespace mindspore { +namespace parallel { +class RedistributionLayoutTransfer : public LayoutTransfer { + public: + RedistributionLayoutTransfer() = default; + ~RedistributionLayoutTransfer() override = default; + std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; + + private: + Status CheckValidTransfer() override; + std::shared_ptr UnifyDeviceArrangement() const; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc new file mode 100644 index 0000000000..6ac24418b7 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc @@ -0,0 +1,289 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h" + +#include + +#include "frontend/parallel/device_manager.h" + +namespace mindspore { +namespace parallel { +Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, + RankList dev_list, bool is_cost_model) { + in_tensor_map_ = tensor_layout.tensor_map(); + dev_mat_ = tensor_layout.device_arrangement(); + + if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) { + MS_LOG(ERROR) << "Invalid input when initialize RedistributionOperatorInfer!"; + return Status::FAILED; + } + + cur_tensor_layout_ = tensor_layout; + out_tensor_map_ = out_tensor_map; + dev_list_ = std::move(dev_list); + + operator_list_.clear(); + operator_vector_.clear(); + output_info_vector_.clear(); + + if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) { + MS_LOG(ERROR) << "Init constructor failed"; + return Status::FAILED; + } + constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); + + size_t key = 0; + std::vector map = in_tensor_map_.array(); + for (int32_t item : map) { + map_[key++] = item; + } + + is_cost_model_ = is_cost_model; + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferRedistributionOperator() { + while (!map_.empty()) { + size_t len_global = operator_list_.size(); + + while (!map_.empty()) { + size_t len_split_by_axis = operator_list_.size(); + // split_by_axis operation + if (InferSplitByAxis() == Status::FAILED) { + return Status::FAILED; + } + // permute_by_axis operation + while (!map_.empty()) { + size_t len_permute_by_axis = operator_list_.size(); + if (InferPermuteByAxis() == Status::FAILED) { + return Status::FAILED; + } + if (len_permute_by_axis == operator_list_.size()) break; + } + if (len_split_by_axis == operator_list_.size()) break; + } + // concat_by_axis operation + if (InferConcatByAxis() == Status::FAILED) { + return Status::FAILED; + } + // break loop structure with concat_by_axis + if (len_global == operator_list_.size() && !map_.empty()) { + size_t index = map_.begin()->first; + int32_t in_dim = map_[index]; + map_[index] = NONE; + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { + return Status::FAILED; + } + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferSplitByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = iter->second; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim == out_dim) { + (void)map_.erase(iter++); + continue; + } + if (in_dim == NONE && + !std::any_of(map_.begin(), map_.end(), + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { + Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; + if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } + (void)map_.erase(iter++); + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferPermuteByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = map_[index]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim == out_dim) { + (void)map_.erase(iter++); + continue; + } + if (in_dim == NONE && + 
std::any_of(map_.begin(), map_.end(), + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { + int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); + int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); + if (is_cost_model_) { + int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); + Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, + dev_num}; + if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { + MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; + return Status::FAILED; + } + } else { + Args args_allconcat = {cat_dim, out_dim, dev_num}; + Args args_allsplit = {dev_num, UintToInt(index), out_dim}; + if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } + } + (void)map_.erase(iter++); + map_[IntToSize(cat_dim)] = NONE; + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::InferConcatByAxis() { + for (auto iter = map_.begin(); iter != map_.end();) { + uint32_t index = iter->first; + int32_t in_dim = map_[index]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { + Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; + if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (out_dim == NONE) { + (void)map_.erase(iter++); + } else { + map_[index] = NONE; + (void)++iter; + } + } else { + (void)++iter; + } + } + return Status::SUCCESS; +} + +// Transfer communicative operators into primitives and insert them into vector +Status RedistributionOperatorInfer::InsertOperator(OperatorName name, Args args) { + OperatorR op = std::make_pair(name, args); + OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array()); + operator_list_.push_back(op_cost); + if (construct_op_flag_) { + if (name == SPLIT_BY_AXIS) { + if (TransferSplitByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } else if (name == PERMUTE_BY_AXIS) { + if (TransferPermuteByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } else { + if (TransferConcatByAxis(args) == Status::FAILED) { + return Status::FAILED; + } + } + constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + uint32_t index = IntToUint(args[1]); + if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + if (cur_tensor_layout_.UpdateTensorMap(index, args[2]) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + if (constructor_.AlltoAllOP(args) != Status::SUCCESS) { + return Status::FAILED; 
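+    // (editor note, added commentary) in the cost-model path, args carries the five
+    // values built in InferPermuteByAxis: {device-dim size, tensor index, cat_dim,
+    // dev_dim, dev_num}; AlltoAllOP folds the whole permutation into one AllToAll.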
+ } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + uint32_t index = IntToUint(args[1]); + int32_t val = args[2]; + int32_t out_dim = out_tensor_map_.GetDimByIdx(index); + + if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { + return Status::FAILED; + } + if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { + if (args.size() < 3) { + MS_LOG(ERROR) << "args size should not be less than 3!"; + return Status::FAILED; + } + int32_t tensor_dim = args[0]; + int32_t dev_dim = args[1]; + int32_t split_count = args[2]; + if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + if (tensor_dim != 0) { + if (constructor_.SplitOP(split_count) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(true, split_count)); + } + if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) { + return Status::FAILED; + } else { + operator_vector_.push_back(constructor_.GetOperator()); + output_info_vector_.push_back(std::make_pair(false, 0)); + } + } + if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { + return Status::FAILED; + } + return Status::SUCCESS; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h new file mode 100644 index 0000000000..66cdb3f925 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/tensor_layout/construct_operator.h" +#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace parallel { +using DeviceArrangement = std::vector; +using TensorMap = std::vector; +using TensorShape = std::vector; +using RedistributionOperatorMap = std::unordered_map; +using OperatorR = std::pair; +using OperatorC = std::pair; +using OperatorList = std::vector; + +class RedistributionOperatorInfer { + public: + const int NONE = -1; + explicit RedistributionOperatorInfer(bool construct_op_flag = true) + : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} + Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, + bool is_cost_model = false); + ~RedistributionOperatorInfer() = default; + OperatorList operator_list() const { return operator_list_; } + OperatorVector operator_vector() const { return operator_vector_; } + OutPutInfoVector output_info_vector() const { return output_info_vector_; } + Status InferRedistributionOperator(); + + private: + Status InferSplitByAxis(); + Status InferPermuteByAxis(); + Status InferConcatByAxis(); + Status TransferSplitByAxis(Args args); + Status TransferPermuteByAxis(Args args); + Status TransferConcatByAxis(Args args); + Status InsertOperator(OperatorName name, Args args); + + OperatorList operator_list_; + OperatorVector operator_vector_; + OutPutInfoVector output_info_vector_; + Arrangement dev_mat_; + RedistributionOperatorMap map_; + Map in_tensor_map_; + Map out_tensor_map_; + TensorLayout cur_tensor_layout_; + ConstructOperator constructor_; + RankList dev_list_; + bool construct_op_flag_; + bool is_cost_model_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc new file mode 100644 index 0000000000..98f7cf78fa --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.cc @@ -0,0 +1,142 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/shape_util.h" + +namespace mindspore { +namespace parallel { +Status ReshapeLayoutTransfer::CheckValidTransfer() { + if (!IsSameDeviceArrangement()) { + return Status::FAILED; + } + return Status::SUCCESS; +} + +std::shared_ptr ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { + bool is_unified = IsSameTensorShape(); + std::shared_ptr out_layout_ptr = std::make_shared(*this); + if (out_layout_ptr == nullptr) { + return nullptr; + } + while (!is_unified) { + std::shared_ptr temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); + if (temp_layout_ptr == nullptr) { + return nullptr; + } + out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); + if (out_layout_ptr == nullptr) { + return nullptr; + } + is_unified = out_layout_ptr->IsSameTensorShape(); + } + return out_layout_ptr; +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const { + std::shared_ptr out_ptr = std::make_shared(*this); + bool is_expanded = FromTensorShapeCanBeExpandByTo(); + while (!is_expanded) { + out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape(); + if (out_ptr == nullptr) { + return nullptr; + } + is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo(); + } + return out_ptr; +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const { + std::shared_ptr out_ptr = std::make_shared(*this); + bool is_expanded = ToTensorShapeCanBeExpandByFrom(); + while (!is_expanded) { + out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape(); + if (out_ptr == nullptr) { + return nullptr; + } + is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom(); + } + return out_ptr; +} + +bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const { + return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape()); +} + +bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const { + return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape()); +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const { + std::shared_ptr expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo(); + if (expanded_shape_ptr == nullptr) { + return nullptr; + } + return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); +} + +std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const { + std::shared_ptr exchanged_from_and_to_ptr = ExchangeFromAndTo(); + if (exchanged_from_and_to_ptr == nullptr) { + return nullptr; + } + std::shared_ptr expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo(); + if (expanded_shape_ptr == nullptr) { + return nullptr; + } + std::shared_ptr exchanged_out = + exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); + if (exchanged_out == nullptr) { + return nullptr; + } + return exchanged_out->ExchangeFromAndTo(); +} + +std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo() const { + ReshapeLayoutTransfer out; + Status status = out.Init(to_in_, from_in_); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( + const Arrangement &expand_shape) const { + std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); + if (extend_tensor_shape_from_ptr == nullptr) { + 
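+    // (editor note, added commentary) ExpandTensorShape returns nullptr when
+    // expand_shape is not a valid refinement of from_in_'s tensor shape, so the
+    // failure is propagated to the caller here.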
return nullptr; + } + Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement(); + std::shared_ptr extend_device_arrangement_to_ptr = + to_in_.ExpandDeviceArrangement(unified_device_arrangement); + if (extend_device_arrangement_to_ptr == nullptr) { + return nullptr; + } + ReshapeLayoutTransfer out; + Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(out); +} + +std::shared_ptr ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const { + return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape()); +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h new file mode 100644 index 0000000000..f9ebe9e32b --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/reshape_layout_transfer.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ + +#include +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/layout_transfer.h" + +namespace mindspore { +namespace parallel { +class ReshapeLayoutTransfer : public LayoutTransfer { + public: + ReshapeLayoutTransfer() = default; + ~ReshapeLayoutTransfer() override = default; + std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; + std::shared_ptr ExtendFromTensorShapeByTo() const; + std::shared_ptr ExtendToTensorShapeByFrom() const; + std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; + std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; + std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( + const Arrangement &expand_shape) const; + std::shared_ptr ExchangeFromAndTo() const; + + private: + Status CheckValidTransfer() override; + std::shared_ptr ComputeExpandedFromTensorShapeByTo() const; + bool FromTensorShapeCanBeExpandByTo() const; + bool ToTensorShapeCanBeExpandByFrom() const; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc new file mode 100644 index 0000000000..83282d16b3 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.cc @@ -0,0 +1,263 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/shape_util.h"
+#include <cstdint>
+#include "frontend/parallel/status.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+/*
+ * example:
+ * shape = [2, 8, 32]
+ * shape_accum = [2, 2 * 8, 2 * 8 * 32]
+ */
+Status ShapeToAccumulateProduct(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum) {
+  MS_EXCEPTION_IF_NULL(shape_accum);
+  shape_accum->clear();
+  int64_t size = 1;
+  for (auto iter = shape.begin(); iter < shape.end(); ++iter) {
+    size *= *iter;
+    if (size <= 0) {
+      MS_LOG(ERROR) << "element of shape must be larger than zero";
+      return Status::FAILED;
+    }
+    shape_accum->push_back(size);
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example:
+ * shape = [2, 8, 32]
+ * shape_accum = [2 * 8 * 32, 8 * 32, 32]
+ */
+Status ShapeToAccumulateProductReverse(const std::vector<int32_t> &shape, std::vector<int64_t> *shape_accum) {
+  MS_EXCEPTION_IF_NULL(shape_accum);
+  shape_accum->clear();
+  int64_t size = 1;
+  // reverse iterators avoid stepping before begin() on an empty shape
+  for (auto riter = shape.rbegin(); riter != shape.rend(); ++riter) {
+    size *= *riter;
+    if (size <= 0) {
+      MS_LOG(ERROR) << "element of shape must be larger than zero";
+      return Status::FAILED;
+    }
+    (void)shape_accum->insert(shape_accum->begin(), size);
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example:
+ * shape_accum = [2, 2 * 8, 2 * 8 * 32]
+ * shape = [2, 8, 32]
+ */
+Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::vector<int32_t> *shape) {
+  MS_EXCEPTION_IF_NULL(shape);
+  shape->clear();
+  int64_t value = 1;
+  for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) {
+    if ((*iter) == 0) {
+      MS_LOG(ERROR) << "element of shape_accum should not be zero";
+      return Status::FAILED;
+    }
+    if ((*iter) % value != 0) {
+      MS_LOG(ERROR) << "shape_accum is not an accumulate product in ascending order";
+      return Status::FAILED;
+    }
+    shape->push_back(static_cast<int32_t>((*iter) / value));
+    value = (*iter);
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example:
+ * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
+ * shape = [2, 8, 32]
+ */
+Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_reverse, std::vector<int32_t> *shape) {
+  MS_EXCEPTION_IF_NULL(shape);
+  shape->clear();
+  int64_t value = 1;
+  for (auto riter = shape_accum_reverse.rbegin(); riter != shape_accum_reverse.rend(); ++riter) {
+    if (*riter == 0) {
+      MS_LOG(ERROR) << "element of shape_accum should not be zero";
+      return Status::FAILED;
+    }
+    if ((*riter) % value != 0) {
+      MS_LOG(ERROR) << "shape_accum_reverse is not an accumulate product in descending order";
+      return Status::FAILED;
+    }
+    (void)shape->insert(shape->begin(), static_cast<int32_t>((*riter) / value));
+    value = *riter;
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example1:
+ * in1 = [2, 8]
+ * in2 = [4, 8]
+ * *out = [2, 4, 8]
+ *
+ * example2:
+ * in1 = [2, 4, 16]
+ * in2 = [8, 16]
+ * *out = [2, 4, 8, 16]
+ */
+Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::vector<int64_t> &in2_accum,
+                              std::vector<int64_t> *out_accum) {
+  MS_EXCEPTION_IF_NULL(out_accum);
+  out_accum->clear();
+  auto in1_iter = in1_accum.begin();
+  auto in2_iter = in2_accum.begin();
+  // note: && (not ||) so that neither iterator is dereferenced past its end;
+  // mismatched inputs are reported by the check after the loop
+  while ((in1_iter < in1_accum.end()) && (in2_iter < in2_accum.end())) {
+    if ((*in1_iter <= 0) || (*in2_iter <= 0)) {
+      MS_LOG(ERROR) << "element of in1 and in2 must be larger than zero";
+      return Status::FAILED;
+    }
+    if (*in1_iter < *in2_iter) {
+      out_accum->push_back(*in1_iter);
+      ++in1_iter;
+      continue;
+    } else if (*in1_iter == *in2_iter) {
+      out_accum->push_back(*in1_iter);
+      ++in1_iter;
+      ++in2_iter;
+    } else {
+      out_accum->push_back(*in2_iter);
+      ++in2_iter;
+    }
+  }
+  if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) {
+    MS_LOG(ERROR) << "last element of in1 and in2 must be equal";
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example:
+ * in1 = [8, 4]
+ * in2 = [2, 16]
+ * out = [2, 4, 4]
+ */
+Status UnifyShape(const std::vector<int32_t> &in1, const std::vector<int32_t> &in2, std::vector<int32_t> *out) {
+  MS_EXCEPTION_IF_NULL(out);
+  std::vector<int64_t> in1_accum;
+  Status status = ShapeToAccumulateProduct(in1, &in1_accum);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  std::vector<int64_t> in2_accum;
+  status = ShapeToAccumulateProduct(in2, &in2_accum);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  std::vector<int64_t> out_accum;
+  status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  status = AccumulateProductToShape(out_accum, out);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  return status;
+}
+
+/*
+ * example1:
+ * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
+ * expand_accum_reverse = [2 * 8 * 32, 32, 8]
+ * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8]
+ *
+ * example2:
+ * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
+ * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
+ * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
+ */
+Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse,
+                               const std::vector<int64_t> &expand_accum_reverse,
+                               std::vector<int64_t> *out_accum_reverse) {
+  MS_EXCEPTION_IF_NULL(out_accum_reverse);
+  out_accum_reverse->clear();
+  auto in_riter = in_accum_reverse.rbegin();
+  auto expand_riter = expand_accum_reverse.rbegin();
+  while (expand_riter != expand_accum_reverse.rend()) {
+    if (in_riter == in_accum_reverse.rend()) {
+      MS_LOG(ERROR) << "invalid ExpandAccumProd inputs";
+      return Status::FAILED;
+    }
+    if (*in_riter > *expand_riter) {
+      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
+      ++expand_riter;
+    } else if (*in_riter == *expand_riter) {
+      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
+      ++in_riter;
+      ++expand_riter;
+    } else {
+      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
+      ++in_riter;
+    }
+  }
+  while (in_riter != in_accum_reverse.rend()) {
+    (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
+    ++in_riter;
+  }
+  return Status::SUCCESS;
+}
+
+/*
+ * example1:
+ * in = [2, 8, 32]
+ * expand = [16, 4, 8]
+ * out = [2, 8, 4, 8]
+ *
+ * example2:
+ * in = [2, 8, 32]
+ * expand = [2, 4, 8]
+ * out = [2, 4, 2, 4, 8]
+ */
+Status ExpandShape(const std::vector<int32_t> &in, const std::vector<int32_t> &expand, std::vector<int32_t> *out) {
+  MS_EXCEPTION_IF_NULL(out);
+  std::vector<int64_t> in_accum_reverse;
+  Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  std::vector<int64_t> expand_accum_reverse;
+  status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse);
+  if (status != Status::SUCCESS) {
+    return status;
+  }
+  std::vector<int64_t> out_accum_reverse;
+  status = ExpandAccumulateProduct(in_accum_reverse,
expand_accum_reverse, &out_accum_reverse); + if (status != Status::SUCCESS) { + return status; + } + status = AccumulateProductReverseToShape(out_accum_reverse, out); + if (status != Status::SUCCESS) { + return status; + } + return status; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h new file mode 100644 index 0000000000..49dd39ffd6 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/shape_util.h @@ -0,0 +1,172 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "frontend/parallel/status.h" + +namespace mindspore { +namespace parallel { +/* + * compute the accumulating product of all the values in shape from left to right, + * the accumulating results are saved in shape_accum from left to right + * + * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), + * then *shape_accum = [d_n-1, d_n-1 * d_n-2, d_n-1 * d_n-2 * d_n-3, ..., d_n-1 * d_n-2 * ... *d_0] + * + * example: + * shape = [2, 8, 32] + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + * + */ +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); + +/* + * compute the accumulating product of all the values in shape from right to left, + * the accumulating results are saved in shape_accum from right to left + * + * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), + * then *shape_accum = [d_n-1 * d_n-2 * ... *d_0, d_n-2 * d_n-3 * ... 
*d_0, ..., d_0] + * + * example: + * shape = [2, 8, 32] + * shape_accum = [2 * 8 * 32, 8 * 32, 32] + * + */ +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); + +/* + * compute the original shape from the accumulating product shape_accum, + * elements of shape_accum is saved from left to right, + * given shape_accum = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] + * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), + * (accum_i-1 % accum_i == 0, i=1,...,n-1) + * then *shape = [accum_n-2/accum_n-1, accum_n-3/accum_n-2, ..., accum_0/accum_1] + * + * example: + * shape_accum = [2, 2 * 8, 2 * 8 * 32] + * shape = [2, 8, 32] + * + */ +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); + +/* + * compute the original shape from the accumulating product shape_accum, + * elements of shape_accum is saved from right to left, + * given shape_accum_reverse = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] + * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), + * (accum_i % accum_i-1 == 0, i=1,...,n-1) + * then *shape = [accum_n-1/accum_n-2, accum_n-2/accum_n-1, ..., accum_1/accum_0] + * + * example: + * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * shape = [2, 8, 32] + * + */ +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); + +/* + * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, + * results are saved in out. + * i.e. *out_accum = in1_accum U in2_accum + * elements of out are saved in increasing order + * + * example1: + * in1_accum = [2, 8] + * in2_accum = [4, 8] + * out_accum = [2, 4, 8] + * + * example2: + * in1_accum = [2, 4, 16] + * in2_accum = [8, 16] + * out_accum = [2, 4, 8, 16] + */ +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum); + +/* + * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] + * size = din1_n-1 * din1n-2 * ... * din1_0 = din2_m-1 * din2_m-2 * ... * din2_0 + * find *out = [dout_k-1, dout_k-2, ..., dout_0], s.t. dout_k-1 * dout_k-2 * ... * dout_0 = size and + * suppose in1_accum, in2_accum, and *out_accum is the ShapeToAccumulateProduct result of in1, in2, and *out + * then for each din1_i in in1_accum, din1_i is in *out_accumulate, + * for each din2_i in in2_accum, din2_i is in *out_accumulate + * + * example: + * in1 = [8, 4] + * in2 = [2, 16] + * out = [2, 4, 4] + */ +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); + +/* + * given two accumulate product in reverse order of in and expand, + * in_accum_reverse = [din_n-1, din_n-2, ..., din_0] and expand_pos_reverse = [dexp_n-1, dexp_n-2, ..., dexp_0], + * i.e. in_accum_reverse is the ShapeToAccumulateProductReverse result of a shape in, + * expand_accum_reverse is the ShapeToAccumulateProductReverse result of a shape expand, + * compute the accumulate product in reverse order out_accum_reverse = [dout_k-1, dout_k-2, ..., dout_0], + * s.t. elements in out_accum_reverse are union of elements in in_accum_reverse and expand_accum_reverse + * (out_accum_reverse = in_accum_reverse U expand_accum_reverse), and + * out_accum_reverse is the ShapeToAccumulateProductReverse result of shape expand, + * i.e. 
dout_i > 0, i=0,1,...,k-1, elements of out_accum_reverse must be larger than zero, + * dout_i-1 % dout_i == 0, i=1,...,k-1 + * + * example1: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 8 * 32, 32, 8] + * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] + * + * example2: + * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] + * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] + * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] + */ +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse); + +/* + * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], + * compute the expended shape out = [dout_k-1, dout_k-2, ..., dout_0], + * s.t. dout_k-1 * dout_k-2 * ...* dout_0 = din_n-1 * din_n-2 * ... * d_0 + * suppose in_accum_reverse is the ShapeToAccumulateProductReverse result of in, + * expand_accum_reverse is the ShapeToAccumulateProductReverse result of expand, + * out_accum_reverse is the ShapeToAccumulateProductReverse result of out, + * then out_accum_reverse is the union of in_accum_reverse and expand_accum_reverse + * (out_accum_reverse = in_accum_reverse U expand_accum_reverse) + * + * example1: + * in = [2, 8, 32] + * expand = [16, 4, 8] + * out = [2, 8, 4, 8] + * + * example2: + * in = [2, 8, 32] + * expand = [2, 4, 8] + * out = [2, 4, 2, 4, 8] + */ +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h new file mode 100644 index 0000000000..fc78b1f59c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_info.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ +#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ + +#include +#include +#include +#include + +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" + +namespace mindspore { +namespace parallel { +using Shapes = std::vector; + +class TensorInfo { + public: + TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) + : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} + explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { + shape_ = tensor_layout.tensor_shape().array(); + slice_shape_ = tensor_layout.slice_shape().array(); + } + // trivial default constructor will not initialize c language types. 
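+  // (editor note, added commentary) concretely: a default-constructed TensorInfo
+  // holds empty shape_/slice_shape_ vectors and must be assigned from a real
+  // TensorLayout before shape(), slice_shape() or InferStrategy() are meaningful.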
+ TensorInfo() = default; + ~TensorInfo() = default; + TensorLayout tensor_layout() const { return tensor_layout_; } + Shape slice_shape() const { return slice_shape_; } + Shape shape() const { return shape_; } + void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } + std::vector reduce_dim() const { return reduce_dim_; } + Dimensions InferStrategy() const { + Dimensions stra; + for (size_t i = 0; i < shape_.size(); ++i) { + if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { + return stra; + } + int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); + stra.push_back(dim); + } + return stra; + } + + private: + TensorLayout tensor_layout_; + Shape shape_; + Shape slice_shape_; + // reduce method's reduce dim + std::vector reduce_dim_; +}; +} // namespace parallel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc new file mode 100644 index 0000000000..b9c6cc78de --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.cc @@ -0,0 +1,394 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include +#include +#include "common/utils.h" +#include "ir/value.h" +#include "frontend/parallel/device_matrix.h" +#include "frontend/parallel/status.h" +#include "frontend/parallel/tensor_layout/array.h" +#include "frontend/parallel/tensor_layout/shape_util.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parallel { +std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); } + +std::string TensorLayout::StandardToString() const { + std::ostringstream buffer; + buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString()); + buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString()); + buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString()); + return buffer.str(); +} + +std::string TensorLayout::OriginToString() const { + std::ostringstream buffer; + buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString()); + buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString()); + buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString()); + return buffer.str(); +} + +Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, + const Arrangement &tensor_shape) { + device_arrangement_origin_ = device_arrangement; + tensor_map_origin_ = tensor_map; + tensor_shape_origin_ = tensor_shape; + device_arrangement_ = device_arrangement; + tensor_map_ = tensor_map; + tensor_shape_ = tensor_shape; + if (IsValidTensorLayout()) { + MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString(); + RemoveElementEqualToOneInDeviceArrangement(); + MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString(); + return Status::SUCCESS; + } else { + MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString(); + return Status::FAILED; + } +} + +Status TensorLayout::InitFromVector(const std::vector &device_arrangement, + const std::vector &tensor_map, const std::vector &tensor_shape) { + if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { + return FAILED; + } + if (tensor_map_origin_.Init(tensor_map) != SUCCESS) { + return FAILED; + } + if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) { + return FAILED; + } + if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) { + return FAILED; + } + return SUCCESS; +} + +bool TensorLayout::IsValidTensorLayout() const { + if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { + MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; + return false; + } + if (tensor_map_origin_.GetDimSize() != tensor_shape_origin_.GetDimSize()) { + MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size!"; + return false; + } + if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) { + MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!"; + return false; + } + return true; +} + +bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { + for (uint32_t i = 0; i < tensor_map_.GetDimSize(); i++) { + if (tensor_map_.GetDimByIdx(i) != -1) { + int32_t divisor = GetSliceNumByTensorDimensionIndex(i); + if (divisor == 0) { + MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0"; + return false; + } 
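+      // (editor note, added example) with device_arrangement = [8, 4], tensor_map = [1, 0]
+      // and tensor_shape = [512, 1024]: dim 0 is split across 8 devices (512 % 8 == 0)
+      // and dim 1 across 4 devices (1024 % 4 == 0), so the layout passes this check.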
+ if (tensor_shape_.GetDimByIdx(i) % divisor != 0) { + return false; + } + } + } + return true; +} + +void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { + std::vector device_arrangement_shape; + std::vector tensor_map_shape = tensor_map_origin_.array(); + uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); + int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); + for (uint32_t i = 0; i < dev_num; i++) { + if (device_arrangement_origin_.GetDimByIdx(i) == 1) { + int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast(dev_num - 1 - i)); + if (idx != -1) { + tensor_map_shape[static_cast(idx)] = -1; + } + for (auto &value : tensor_map_shape) { + if (value >= dev_num_left - 1 - static_cast(i)) { + value--; + } + } + continue; + } + device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i)); + } + (void)device_arrangement_.Init(device_arrangement_shape); + (void)tensor_map_.Init(tensor_map_shape); + tensor_shape_ = tensor_shape_origin_; +} + +// if idx is not in tensor_map, return -1 +int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { + return tensor_map_.GetIndexByValue(idx); +} + +// tensor_map_.GetDimByIdx(idx) should not be -1 +int32_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const { + return static_cast(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx); +} + +// tensor_map_.GetDimByIdx(idx) should not be -1 +int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { + return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); +} + +std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { + std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); + if (expanded_arrangement_ptr == nullptr) { + return nullptr; + } + std::shared_ptr temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr); + if (temp_tensor_layout_ptr == nullptr) { + return nullptr; + } + return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape); +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_tensor_shape = [128, 4, 2, 512], + * => + * out_device_arrangement = [8, 2, 2] + */ +std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { + std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); + if (expand_list_ptr == nullptr) { + return nullptr; + } + std::vector re_map_expand_list; + Arrangement empty_arrangement; + for (int32_t i = static_cast(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) { + if (tensor_map_.GetIndexByValue(i) < 0) { + re_map_expand_list.push_back(empty_arrangement); + } else { + re_map_expand_list.push_back((*expand_list_ptr)[IntToUint(tensor_map_.GetIndexByValue(i))]); + } + } + std::shared_ptr new_arrangement_ptr = + device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list); + return new_arrangement_ptr; +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_tensor_shape = [8, 64, 4, 256] + * => + * out_device_arrangement = [8, 4], + * out_tensor_map = [1, -1, 0, -1], + */ +std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( + const Arrangement &expanded_shape) 
const { + std::shared_ptr, Arrangement>> expand_list_pair_ptr = + tensor_shape_.GetExpandShapeListPair(expanded_shape); + if (expand_list_pair_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second); + if (tensor_map_new_ptr == nullptr) { + return nullptr; + } + TensorLayout tensor_layout_new; + Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(tensor_layout_new); +} + +/* + * example1: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, 0], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [3, 2, 1, 0], + * out_tensor_shape = [4, 128, 2, 512] + * + * example2: + * in_device_arrangement = [8, 4], + * in_tensor_map = [0, 1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [1, 0, 3, 2], + * out_tensor_shape = [2, 256, 4, 256] + * + * example3: + * in_device_arrangement = [8, 4], + * in_tensor_map = [1, -1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 2, 2] + * => + * out_tensor_map = [3, 2, -1], + * out_tensor_shape = [4, 128, 1024] + * + * example4: + * in_device_arrangement = [8, 4], + * in_tensor_map = [0, 1], + * in_tensor_shape = [512, 1024], + * out_device_arrangement = [4, 2, 4] + * => + * out_tensor_map = [0, 2, 1], + * out_tensor_shape = [512, 4, 256] + */ +std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { + std::shared_ptr, Arrangement>> expand_list_pair_ptr = + device_arrangement_.GetExpandShapeListPair(expanded_arrangement); + if (expand_list_pair_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second); + if (tensor_map_new_ptr == nullptr) { + return nullptr; + } + std::shared_ptr> re_map_shape_list_ptr = + tensor_map_.ReMapVector(expand_list_pair_ptr->first); + if (re_map_shape_list_ptr == nullptr) { + return nullptr; + } + std::shared_ptr tensor_shape_new_ptr = + tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr); + if (tensor_shape_new_ptr == nullptr) { + return nullptr; + } + TensorLayout tensor_layout_new; + Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(tensor_layout_new); +} + +bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { + std::vector in_expand_shape_shape; + Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); + if (status != Status::SUCCESS) { + return false; + } + return (in_expand_shape_shape == tensor_shape_.array()); +} + +std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { + std::vector in_expand_shape_shape; + Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + Arrangement expanded_shape; + status = expanded_shape.Init(in_expand_shape_shape); + if (status != Status::SUCCESS) { + return nullptr; + } + return std::make_shared(expanded_shape); +} + +Arrangement TensorLayout::slice_shape() const { + std::vector shape; + for (uint32_t index = 0; index < tensor_map_.GetDimSize(); 
index++) { + int32_t dim = tensor_map_.GetDimByIdx(index); + int32_t num = tensor_shape_.GetDimByIdx(index); + if (dim == -1) { + shape.push_back(num); + } else { + int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); + shape.push_back(num / divisor); + } + } + Arrangement new_tensor_shape; + if (new_tensor_shape.Init(shape) == Status::FAILED) { + ValuePtr ptr = MakeValue(shape); + MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString(); + } else { + return new_tensor_shape; + } +} + +Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { + if (index >= tensor_map_.GetDimSize()) { + MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; + return Status::FAILED; + } + auto shape = tensor_map_.array(); + shape[index] = value; + if (tensor_map_.Init(shape) == Status::FAILED) { + MS_LOG(ERROR) << "Update tensor map failed!"; + return Status::FAILED; + } + return Status::SUCCESS; +} + +bool TensorLayout::operator==(const TensorLayout &t1) const { + return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); +} + +/* + * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ] + * example 1: + * original tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ 0 -1 -1 -1 ] + * tensor shape = [ 128 64 1 1 ] + * return tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ 0 -1 ] + * tensor shape = [ 128 64 ] + * + * example 2: + * device arrangement = [ 8 ] + * tensor map = [ -1 -1 -1 -1 ] + * tensor shape = [ 1 1 1 1 ] + * return tensor layout: + * device arrangement = [ 8 ] + * tensor map = [ -1 ] + * tensor shape = [ 1 ] + */ +TensorLayout TensorLayout::SqueezeShape() const { + TensorLayout out; + Map out_map; + Arrangement out_shape; + if (tensor_shape_.size() == 1) { + (void)out_map.Init({MAP_NONE}); + (void)out_shape.Init({1}); + (void)out.Init(device_arrangement_, out_map, out_shape); + return out; + } + std::vector squeeze_list = tensor_shape_.GetSqueezeIdx(); + if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) { + MS_LOG(ERROR) << "CheckNoneByIdxList failed, this may not happen under current situation"; + return *this; + } + out_shape = tensor_shape_.GetSqueezeArrangement(); + out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list); + (void)out.Init(device_arrangement_, out_map, out_shape); + return out; +} +} // namespace parallel +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h new file mode 100644 index 0000000000..a9fdc9610c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -0,0 +1,99 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h new file mode 100644 index 0000000000..a9fdc9610c --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -0,0 +1,99 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
+
+#include <cstdint>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/arrangement.h"
+#include "frontend/parallel/tensor_layout/map.h"
+#include "utils/convert_utils.h"
+
+namespace mindspore {
+namespace parallel {
+class TensorLayout {
+ public:
+  TensorLayout() = default;
+  ~TensorLayout() = default;
+  std::string ToString() const;
+  std::string StandardToString() const;
+  std::string OriginToString() const;
+  Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape);
+  Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map,
+                        const std::vector<int32_t> &tensor_shape);
+
+  Arrangement device_arrangement() const { return device_arrangement_; }
+
+  Map tensor_map() const { return tensor_map_; }
+
+  Arrangement tensor_shape() const { return tensor_shape_; }
+
+  Map origin_tensor_map() const { return tensor_map_origin_; }
+
+  std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const;
+
+  std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const;
+
+  bool IsSameTensorShape(const TensorLayout &tensor_layout) const {
+    return (tensor_shape_ == tensor_layout.tensor_shape());
+  }
+
+  bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const {
+    return (device_arrangement_ == tensor_layout.device_arrangement());
+  }
+
+  bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); }
+
+  bool operator==(const TensorLayout &t1) const;
+
+  bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const;
+
+  std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const;
+
+  Arrangement slice_shape() const;
+
+  Status UpdateTensorMap(uint32_t index, int32_t value);
+
+  TensorLayout SqueezeShape() const;
+
+ private:
+  std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
+    const Arrangement &expanded_shape) const;
+  std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const;
+  bool IsValidTensorLayout() const;
+  void RemoveElementEqualToOneInDeviceArrangement();
+  int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const;
+  int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const;
+  bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const;
+  int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const;
+
+  Arrangement device_arrangement_origin_;
+  Map tensor_map_origin_;
+  Arrangement tensor_shape_origin_;
+  Arrangement device_arrangement_;
+  Map tensor_map_;
+  Arrangement tensor_shape_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_
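To make the division in slice_shape() concrete, a minimal sketch (assuming the
example-1 layout from tensor_layout.cc and only the API above):

  TensorLayout layout;
  // device arrangement [8, 4], tensor map [1, 0], tensor shape [512, 1024]
  if (layout.InitFromVector({8, 4}, {1, 0}, {512, 1024}) == Status::SUCCESS) {
    Arrangement slice = layout.slice_shape();
    // dim 0 is mapped to the device dimension of size 8 -> 512 / 8 = 64
    // dim 1 is mapped to the device dimension of size 4 -> 1024 / 4 = 256
    // slice.array() is {64, 256}
  }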
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc new file mode 100644 index 0000000000..43bb330787 --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.cc @@ -0,0 +1,209 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include <functional>
+#include <memory>
+#include <numeric>
+#include "common/utils.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/shape_util.h"
+
+namespace mindspore {
+namespace parallel {
+Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) {
+  from_origin_ = from;
+  to_origin_ = to;
+  if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) {
+    MS_LOG(ERROR) << "'from' shape size must be equal to 'to' shape size!";
+    MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString();
+    MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString();
+    return Status::FAILED;
+  }
+
+  dev_list_ = dev_list;
+  from_ = from_origin_.SqueezeShape();
+  to_ = to_origin_.SqueezeShape();
+  return Status::SUCCESS;
+}
+
+RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
+  // Step 1: Match device arrangement between from_ and to_
+  RedistributionLayoutTransfer layout_transfer;
+  Status status = layout_transfer.Init(from_, to_);
+  if (status != Status::SUCCESS) {
+    return nullptr;
+  }
+  std::shared_ptr<ReshapeLayoutTransfer> ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape();
+  if (ptr == nullptr) {
+    MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
+    return nullptr;
+  }
+  TensorLayout from_layout = ptr->from_in();
+  TensorLayout to_layout = ptr->to_in();
+  MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
+  MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString();
+  MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
+  MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
+  MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
+  MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
+  // Step 2: Infer redistribution and insert operators
+  RedistributionOperatorInfer operator_infer(construct_op_flag_);
+  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
+    MS_LOG(ERROR) << "Init operatorInfer failed!";
+    return nullptr;
+  }
+  OperatorVector operator_vector;
+  OutPutInfoVector output_info_vector;
+  if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
+    MS_LOG(ERROR) << "Infer redistribution failed!";
+    return nullptr;
+  } else {
+    operator_vector = operator_infer.operator_vector();
+    output_info_vector = operator_infer.output_info_vector();
+    operator_list_ = operator_infer.operator_list();
+  }
+
+  // Step 3: Infer reshape and insert operators
+  if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
+    MS_LOG(ERROR) << "Construct Reshape operator failed!";
+    return nullptr;
+  }
+
+  return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
+    std::make_pair(operator_vector, output_info_vector));
+}
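+
+// A worked case for the reshape handling below (illustrative numbers only): if
+// from_origin_ slices to [64, 256] while the unified from_layout slices to
+// [64, 16, 16], InferReshape() prepends a Reshape operator mapping [64, 256] to
+// [64, 16, 16]; symmetrically, a trailing Reshape is appended when to_layout and
+// to_origin_ disagree on the slice shape.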
+Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
+                                          OperatorVector *const operator_vector,
+                                          OutPutInfoVector *const output_info_vector) {
+  MS_EXCEPTION_IF_NULL(operator_vector);
+  MS_EXCEPTION_IF_NULL(output_info_vector);
+  ConstructOperator constructor;
+  if (operator_list_.empty()) {
+    if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) {
+      reshape_flag_ = true;
+      constructor.UpdateTensorShape(from_origin_.slice_shape().array());
+      Arrangement shape = to_origin_.slice_shape();
+      MS_LOG(DEBUG) << "reshape " << shape.ToString();
+      if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
+        return Status::FAILED;
+      } else {
+        (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
+        (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
+      }
+    }
+    return Status::SUCCESS;
+  }
+
+  if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) {
+    reshape_flag_ = true;
+    constructor.UpdateTensorShape(from_origin_.slice_shape().array());
+    Arrangement shape = from_layout.slice_shape();
+    MS_LOG(DEBUG) << "reshape " << shape.ToString();
+    if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
+      return Status::FAILED;
+    } else {
+      (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
+      (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
+    }
+  }
+
+  if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) {
+    reshape_flag_ = true;
+    constructor.UpdateTensorShape(to_layout.slice_shape().array());
+    Arrangement shape = to_origin_.slice_shape();
+    MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString();
+    if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
+      return Status::FAILED;
+    } else {
+      (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator());
+      (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0));
+    }
+  }
+  return Status::SUCCESS;
+}
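+
+// A worked cost case for ComputeCost() below (illustrative numbers only): for a
+// CONCAT_BY_AXIS operator over a slice of shape [64, 256] (prod = 16384) with
+// dev_num = 4 and concat_dim = 0, the loop adds
+//   forward_comm_cost_  += 16384 * 4 * 0.5 = 32768
+//   backward_comm_cost_ += 16384 * 0.5     = 8192
+//   computation_cost_   += 16384
+//   memory_cost_        += 16384 * 4       = 65536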
+Status TensorRedistribution::ComputeCost() {
+  RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
+  if (redistribution_oplist_ptr == nullptr) {
+    MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
+    return Status::FAILED;
+  }
+  // Compute redistribution communication cost and computation cost
+  for (auto &op_cost : operator_list_) {
+    OperatorR op = op_cost.first;
+    Shape slice_shape = op_cost.second;
+    double prod =
+      std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
+    std::string str = op.first;
+    if (str == PERMUTE_BY_AXIS) {
+      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
+      int32_t concat_dim = op.second[2];
+      if (concat_dim == 0) {
+        // memory cost = all_gather
+        computation_cost_ += prod;
+        memory_cost_ += prod;
+      } else {
+        // memory cost = all_gather + split + concat
+        int32_t dev_num = op.second[4];
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
+        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
+      }
+    } else if (str == CONCAT_BY_AXIS) {
+      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+      // computation cost = before_slice_shape
+      if (op.second.size() < 3) {
+        MS_LOG(ERROR) << "op.second size should not be less than 3!";
+        return Status::FAILED;
+      }
+      double dev_num = op.second[2];
+      // here, communication cost = all_gather + reduce_scatter
+      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      int32_t concat_dim = op.second[0];
+      if (concat_dim == 0) {
+        // computation cost = all_gather
+        computation_cost_ += prod;
+        memory_cost_ += prod * dev_num;
+      } else {
+        // computation cost = all_gather + split + concat
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
+        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
+      }
+    } else {
+      // There is only computation cost in SplitByAxis.
+      // computation cost = before_slice_shape
+      computation_cost_ += prod;
+      // This addition may be erroneous
+      memory_cost_ += prod;
+    }
+  }
+  if (reshape_flag()) {
+    Shape prev_slice_shape = from_.slice_shape().array();
+    double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
+    computation_cost_ += 2.0 * prev_prod;
+    memory_cost_ += 2.0 * prev_prod;
+  }
+  return Status::SUCCESS;
+}
+} // namespace parallel
+} // namespace mindspore
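A minimal end-to-end sketch of the class declared in the header below (assuming
`from` and `to` are initialized TensorLayout objects and `dev_list` a valid RankList):

  TensorRedistribution redistribution;
  if (redistribution.Init(from, to, dev_list) == Status::SUCCESS) {
    RedistributionOpListPtr ops = redistribution.InferTensorRedistributionOperatorList();
    if (ops != nullptr && redistribution.ComputeCost() == Status::SUCCESS) {
      MS_LOG(INFO) << "redistribution comm cost: " << redistribution.comm_cost();
    }
  }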
diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h new file mode 100644 index 0000000000..df4bd1570f --- /dev/null +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_redistribution.h @@ -0,0 +1,90 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_
+#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_
+
+#include <cfloat>
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "ir/value.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/status.h"
+#include "frontend/parallel/tensor_layout/construct_operator.h"
+#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr double ALLTOALL_SCALE_FACTOR = 2.0;
+constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5;
+class TensorRedistribution {
+ public:
+  explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false)
+      : reshape_flag_(false),
+        comm_cost_(0.0),
+        forward_comm_cost_(0.0),
+        backward_comm_cost_(0.0),
+        computation_cost_(0.0),
+        memory_cost_(0.0),
+        construct_op_flag_(construct_op_flag),
+        keep_reshape_(keep_reshape) {}
+  Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list);
+  ~TensorRedistribution() = default;
+  RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
+  OperatorList operator_list() const { return operator_list_; }
+  bool reshape_flag() const { return reshape_flag_; }
+  Status ComputeCost();
+  double comm_cost() const { return comm_cost_; }
+  double computation_cost() const { return computation_cost_; }
+  double forward_comm_cost() const { return forward_comm_cost_; }
+  double backward_comm_cost() const { return backward_comm_cost_; }
+  double memory_cost() const { return memory_cost_; }
+
+ private:
+  Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
+                      OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector);
+
+  TensorLayout from_origin_;
+  TensorLayout to_origin_;
+  TensorLayout from_;
+  TensorLayout to_;
+  RankList dev_list_;
+  OperatorList operator_list_;
+  bool reshape_flag_;
+  // communication cost, which is the sum of forward communication cost and backward communication cost
+  double comm_cost_;
+  // forward communication cost
+  double forward_comm_cost_;
+  // backward communication cost
+  double backward_comm_cost_;
+  // computation_cost models the time spent on computation in this tensor redistribution, which is calculated from
+  // the inputs. This is calculated ONLY for the forward phase.
+  double computation_cost_;
+  // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which
+  // is calculated from the outputs.
+  double memory_cost_;
+  bool construct_op_flag_;
+  bool keep_reshape_;
+};
+} // namespace parallel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_
diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc deleted file mode 100644 index 45cce7b473..0000000000 --- a/mindspore/ccsrc/ir/anf.cc +++ /dev/null @@ -1,221 +0,0 @@
-/**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/anf.h" - -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" - -namespace mindspore { -// namespace to support intermediate representation definition -CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) - : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} - -// Check if CNode is an apply with the specific Primitive. -bool CNode::IsApply(const PrimitivePtr &value) const { - if (value == nullptr) { - return false; - } - - if (inputs_.size() != 0 && IsValueNode(inputs_[0])) { - PrimitivePtr fn_value = GetValueNode(inputs_[0]); - if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { - return true; - } - } - - return false; -} - -void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } - -std::string CNode::DebugString(int recursive_level) const { - std::ostringstream buffer; - if (recursive_level > 0) { - if (func_graph() != nullptr) { - buffer << func_graph()->ToString() << ":"; - } - buffer << ToString() << "{"; - bool is_first_node = true; - int idx = 0; - for (auto &node : inputs_) { - MS_EXCEPTION_IF_NULL(node); - if (is_first_node) { - is_first_node = false; - } else { - buffer << ", "; - } - buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1); - idx++; - } - buffer << "}"; - } else { - buffer << ToString(); - } - return buffer.str(); -} - -std::string ValueNode::ToString() const { - MS_EXCEPTION_IF_NULL(value_); - if (value_->isa()) { - return value_->cast()->ToString(); - } - std::ostringstream buffer; - buffer << AnfNode::ToString(); - buffer << "(" << value_->ToString() << ")"; - return buffer.str(); -} - -std::string ValueNode::DebugString(int) const { - MS_EXCEPTION_IF_NULL(value_); - std::ostringstream buffer; - buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString(); - return buffer.str(); -} - -std::string ValueNode::fullname_with_scope() { - if (!fullname_with_scope_.empty()) { - return fullname_with_scope_; - } - - MS_EXCEPTION_IF_NULL(scope()); - fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base()); - return fullname_with_scope_; -} - -bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { - return false; - } - if (value != nullptr) { - return cnode->IsApply(value); - } - const auto &prim = GetValueNode(cnode->input(0)); - return prim != nullptr; -} - -PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { - if (node == nullptr) { - return nullptr; - } - auto cnode = node->cast(); - if (cnode != nullptr) { - if (cnode->size() > 0) { - auto prim = GetValueNode(cnode->input(0)); - return prim; - } - } - return nullptr; -} - -std::string GetCNodeFuncName(const CNodePtr cnode) { - if (cnode->inputs().empty()) { - return ""; - } - - AnfNodePtr valuenode = cnode->input(0); - if (valuenode->isa()) { - auto value = GetValueNode(valuenode); - // check whether 
the valuenode is primitive - if (value->isa()) { - return value->cast()->name(); - } - return value->ToString(); - } - return ""; -} - -bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { - if (IsValueNode(node)) { - PrimitivePtr fn_value = GetValueNode(node); - MS_EXCEPTION_IF_NULL(value); - if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { - return true; - } - } - return false; -} - -size_t NewSeenGeneration() { - static size_t seen_generation = 0; - return ++seen_generation; -} - -namespace id_generator { -static std::unordered_map node_ids; -std::string get_id(const AnfNodePtr &node) { - auto type_name = node->type_name(); - if (node_ids.find(type_name) == node_ids.end()) { - node_ids[type_name] = 0; - } else { - node_ids[type_name]++; - } - return std::to_string(node_ids[type_name]); -} - -void reset_id() { node_ids.clear(); } -} // namespace id_generator - -std::string GetCNodeTarget(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->device_target(); - if (!node->isa()) { - return default_target; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = cnode->input(0); - if (attr_input == nullptr) { - return default_target; - } - auto value_node = attr_input->cast(); - if (value_node == nullptr) { - return default_target; - } - auto value = value_node->value(); - if (value == nullptr) { - return default_target; - } - if (!value->isa()) { - return default_target; - } - auto primitive = value->cast(); - auto att_target = primitive->GetAttr("primitive_target"); - if (att_target != nullptr) { - if (!att_target->isa()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - auto target = GetValue(att_target); - if (kTargetSet.find(target) == kTargetSet.end()) { - MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; - } - return target; - } - return default_target; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/anf_extends.cc b/mindspore/ccsrc/ir/anf_extends.cc deleted file mode 100644 index 432ffdb606..0000000000 --- a/mindspore/ccsrc/ir/anf_extends.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/anf.h" - -#include -#include -#include -#include - -#include "ir/visitor.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "operator/ops.h" -#include "parallel/ops_info/ops_utils.h" -#include "debug/label.h" - -namespace mindspore { -// namespace to support intermediate representation definition -// Methods of AnfNode -TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); } -BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? 
nullptr : abstract_->BuildShape(); } - -std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); -} - -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { - if (operator_info_ != nullptr) { - MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() - << ", using the new one: " << operator_info->name(); - auto old_ptr = operator_info_; - operator_info_ = operator_info; - return old_ptr; - } - operator_info_ = operator_info; - return nullptr; -} - -std::string CNode::fullname_with_scope() { - // if full name is set, return its name immediately - if (!fullname_with_scope_.empty()) { - return fullname_with_scope_; - } - - if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || - IsApply(prim::kPrimHistogramSummary)) { - std::string tag = GetValue(GetValueNode(input(1))); - std::string name; - if (IsApply(prim::kPrimScalarSummary)) { - name = tag + "[:Scalar]"; - } else if (IsApply(prim::kPrimImageSummary)) { - name = tag + "[:Image]"; - } else if (IsApply(prim::kPrimHistogramSummary)) { - name = tag + "[:Histogram]"; - } else { - name = tag + "[:Tensor]"; - } - fullname_with_scope_ = name; - } else { - // cnode input 0 should be primitive ptr or funcgraph ptr - auto value_ptr = input(0)->cast(); - if (value_ptr == nullptr) { - MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; - fullname_with_scope_ = id_generator::get_id(shared_from_base()); - return fullname_with_scope_; - } - auto input_value = value_ptr->value(); - if (input_value == nullptr) { - MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; - fullname_with_scope_ = id_generator::get_id(shared_from_base()); - return fullname_with_scope_; - } - - auto prim = input_value->cast(); - MS_EXCEPTION_IF_NULL(scope()); - fullname_with_scope_ = scope()->name() + "/"; - if (prim != nullptr) { - fullname_with_scope_ += prim->name(); - } else { - auto func_graph = input_value->cast(); - MS_EXCEPTION_IF_NULL(func_graph); - auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_flag != nullptr) { - auto fg_name = GetValue(fg_flag); - fullname_with_scope_ += "GraphKernel_" + fg_name; - } else { - fullname_with_scope_ += func_graph->ToString(); - } - } - fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base()); - } - - return fullname_with_scope_; -} - -void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc deleted file mode 100644 index b0d0910304..0000000000 --- a/mindspore/ccsrc/ir/func_graph.cc +++ /dev/null @@ -1,628 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph.h" - -#include -#include -#include - -#include "debug/trace.h" -#include "ir/manager.h" -#include "operator/ops.h" -#include "utils/ordered_set.h" -#include "utils/convert_utils_base.h" - -namespace mindspore { -/* - * Methods of Graph - */ -FuncGraph::FuncGraph() - : attrs_(), - transforms_(), - parameter_default_value_(), - seen_(0), - parameters_(), - has_vararg_(false), - has_kwarg_(false), - kwonlyargs_count_(0), - hyper_param_count_(0), - is_generated_(false), - return_(nullptr), - manager_(std::weak_ptr()), - stub_(false) { - debug_info_ = std::make_shared(); -} - -AnfNodePtr FuncGraph::output() const { - // If return value is set, return should have two inputs. - if (return_ != nullptr && return_->inputs().size() == 2) { - return return_->input(1); - } else { - // If not set yet, return nullptr. - return nullptr; - } -} - -ParameterPtr FuncGraph::add_parameter() { - FuncGraphPtr this_func_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_func_graph); - add_parameter(p); - return p; -} - -void FuncGraph::add_parameter(const ParameterPtr &p) { - if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); - } else { - parameters_.push_back(p); - } -} - -ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { - FuncGraphPtr this_graph = shared_from_base(); - ParameterPtr p = std::make_shared(this_graph); - p->set_name(name); - p->debug_info()->set_name(name); - - if (manager_.lock()) { - manager_.lock()->AddParameter(shared_from_base(), p); - } else { - parameters_.push_back(p); - } - hyper_param_count_++; - return p; -} - -bool FuncGraph::has_flag(const std::string &key) { - auto iter = attrs_.find(key); - if (iter != attrs_.cend()) { - if (iter->second->isa()) { - return GetValue(iter->second); - } - MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; - } - return false; -} - -bool FuncGraph::has_attr(const std::string &key) { - auto iter = attrs_.find(key); - return !(iter == attrs_.cend()); -} - -ValuePtr FuncGraph::get_attr(const std::string &key) { - auto iter = attrs_.find(key); - return iter == attrs_.cend() ? 
nullptr : iter->second; -} - -CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { - CNodePtr cnode = std::make_shared(inputs, shared_from_base()); - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - order_.push_back(cnode); - MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order."; - } - return cnode; -} - -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { - CNodePtr app = NewCNode(inputs); - app->set_scope(scope); - return app; -} - -void FuncGraph::DumpCNodeList() { - MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; - for (const auto &cnode : order_) { - MS_LOG(INFO) << cnode->DebugString(); - } -} - -std::string FuncGraph::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); -} - -GraphDebugInfoPtr FuncGraph::debug_info() { - MS_EXCEPTION_IF_NULL(this->debug_info_); - if (this->debug_info_->get_graph() == nullptr) { - this->debug_info_->set_graph(shared_from_base()); - } - return this->debug_info_; -} - -const AnfNodeSet &FuncGraph::nodes() { return nodes_; } - -void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } - -void FuncGraph::ClearNodes() { nodes_.clear(); } - -void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } - -void FuncGraph::DropNode(AnfNodePtr node) { - nodes_.erase(node); - auto graph = node->func_graph(); - // Remove the node from order list. - if (graph) { - graph->EraseUnusedNodeInOrder(node); - } -} - -const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } - -void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { - auto &others = source->value_nodes(); - for (auto it = others.begin(); it != others.end(); it++) { - AddValueNode(it->first, it->second); - } -} - -void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } - -void FuncGraph::AddValueNode(AnfNodePtr node, int count) { - if (value_nodes_.count(node) == 0) { - value_nodes_[node] = count; - } else { - value_nodes_[node] += count; - } -} - -void FuncGraph::DropValueNode(AnfNodePtr node) { - if (value_nodes_.count(node) != 0) { - if (value_nodes_[node] == 1) { - (void)value_nodes_.erase(node); - } else { - value_nodes_[node]--; - if (value_nodes_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of ValueNode '" << node - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } - -void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { - auto &others = source->free_variables(); - for (auto it = others.begin(); it != others.end(); it++) { - if (it->first->func_graph().get() != this) { - (void)AddFreeVariable(it->first, it->second); - } - } -} - -void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } - -bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { - if (free_variables_.count(node) == 0) { - free_variables_[node] = count; - return true; - } else { - free_variables_[node] += count; - return false; - } -} - -bool FuncGraph::DropFreeVariable(AnfNodePtr node) { - if (free_variables_.count(node) != 0) { - if (free_variables_[node] == 1) { - (void)free_variables_.erase(node); - return true; - } else { - free_variables_[node]--; - if (free_variables_[node] < 0) { - MS_LOG(EXCEPTION) << "Count of free variable '" << node - << "' dec from 0. 
NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } - return false; -} - -const BaseRefCounterMap &FuncGraph::free_variables_total() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &fv_total = mng->free_variables_total(); - return fv_total[shared_from_base()]; -} - -std::vector FuncGraph::free_variables_nodes() { - std::vector nodes; - const auto &fv_total = this->free_variables_total(); - for (auto &p : fv_total) { - auto key = p.first; - if (utils::isa(key)) { - nodes.push_back(utils::cast(key)); - } - } - - return nodes; -} - -std::vector FuncGraph::free_variables_func_graphs() { - std::vector func_graphs; - const auto &fv_total = this->free_variables_total(); - for (auto &p : fv_total) { - auto key = p.first; - if (utils::isa(key)) { - func_graphs.push_back(utils::cast(key)); - } - } - - return func_graphs; -} - -const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } - -void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { - auto &others = source->func_graphs_used(); - for (auto it = others.begin(); it != others.end(); it++) { - (void)AddFuncGraphUsed(it->first, it->second); - } - func_graphs_used_.erase(source); -} - -void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } - -bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { - if (func_graphs_used_.count(fg) == 0) { - func_graphs_used_[fg] = count; - return true; - } else { - func_graphs_used_[fg] += count; - return false; - } -} - -bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { - if (func_graphs_used_.count(fg) != 0) { - if (func_graphs_used_[fg] == 1) { - (void)func_graphs_used_.erase(fg); - return true; - } else { - func_graphs_used_[fg]--; - if (func_graphs_used_[fg] < 0) { - MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } - return false; -} - -const FuncGraphSet &FuncGraph::func_graphs_used_total() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - auto &used = mng->func_graphs_used_total(shared_from_base()); - return used; -} - -const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } - -void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { - auto &others = source->func_graph_cnodes_index(); - for (auto it = others.begin(); it != others.end(); it++) { - // Ignore the user graph who may own itself. - auto fg = it->first->first->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - if (fg.get() != this) { - AddFuncGraphCNodeIndex(it->first, it->second); - } - } -} - -void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } - -void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { - if (func_graph_cnodes_index_.count(pair) == 0) { - func_graph_cnodes_index_[pair] = count; - } else { - func_graph_cnodes_index_[pair] += count; - } -} - -void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { - if (func_graph_cnodes_index_.count(pair) != 0) { - if (func_graph_cnodes_index_[pair] == 1) { - (void)func_graph_cnodes_index_.erase(pair); - } else { - func_graph_cnodes_index_[pair]--; - if (func_graph_cnodes_index_[pair] < 0) { - MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second - << "' dec from 0. 
NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } - -void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { - auto &others = source->j_func_graphs(); - for (auto it = others.begin(); it != others.end(); it++) { - AddJFuncGraph(it->first, it->second); - } -} - -void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } - -void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { - if (j_func_graphs_.count(fg) == 0) { - j_func_graphs_[fg] = count; - } else { - j_func_graphs_[fg] += count; - } -} - -void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { - if (j_func_graphs_.count(fg) != 0) { - if (j_func_graphs_[fg] == 1) { - (void)j_func_graphs_.erase(fg); - } else { - j_func_graphs_[fg]--; - if (j_func_graphs_[fg] < 0) { - MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg - << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - } - } -} - -FuncGraphPtr FuncGraph::parent() { - // report the bug early. - if (manager_.lock() == nullptr) { - MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() - << " NodeInfo: " << trace::GetDebugInfo(debug_info()); - } - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->parent(shared_from_base()); -} - -const FuncGraphSet &FuncGraph::children() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->children(shared_from_base()); -} - -const FuncGraphSet &FuncGraph::scope() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->scopes(shared_from_base()); -} - -bool FuncGraph::recursive() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->recursive(shared_from_base()); -} - -std::shared_ptr> FuncGraph::recursive_graphs() { - auto mng = manager_.lock(); - MS_EXCEPTION_IF_NULL(mng); - return mng->recursive_graphs(shared_from_base()); -} - -AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { - auto itr = this->parameter_default_value_.find(name); - if (itr == parameter_default_value_.end()) { - return nullptr; - } - auto default_value = itr->second; - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist"; - } - if (IsValueNode(default_value)) { - return nullptr; - } - return default_value; -} - -// set the default values -void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { - auto all_is_null = - std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode(node); }); - if (value_list.empty()) { - all_is_null = true; - } - for (size_t i = 0; i < name_list.size(); ++i) { - if (!all_is_null) { - this->parameter_default_value_[name_list[i]] = value_list[i]; - } - } -} - -void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } - -size_t FuncGraph::GetDefaultValueCount() { - int null_count = - std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), - [](const std::pair &pair) { return IsValueNode(pair.second); }); - return parameter_default_value_.size() - IntToSize(null_count); -} - -AnfNodePtr FuncGraph::GetVariableArgParameter() { - if (!has_vararg_) { - return nullptr; - } - - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 2) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; - } - return 
parameters_[parameters_.size() - hyper_param_count_ - 2]; - } - - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]; -} - -std::string FuncGraph::GetVariableArgName() { - if (!has_vararg_) { - return ""; - } - - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 2) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast()->name(); - } - - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); -} - -AnfNodePtr FuncGraph::GetVariableKwargParameter() { - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]; - } - return nullptr; -} - -std::string FuncGraph::GetVariableKwargName() { - if (has_kwarg_) { - if (parameters_.size() < hyper_param_count_ + 1) { - MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " - << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; - } - return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); - } - return ""; -} - -int FuncGraph::GetPositionalArgsCount() const { - int count = SizeToInt(parameters_.size()); - if (has_kwarg_) { - count--; - } - if (has_vararg_) { - count--; - } - return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); -} - -AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { - for (size_t i = 0; i < parameters_.size(); ++i) { - MS_EXCEPTION_IF_NULL(parameters_[i]); - auto param_cast = parameters_[i]->cast(); - MS_EXCEPTION_IF_NULL(param_cast); - if (param_cast->name() == name) { - return parameters_[i]; - } - } - return nullptr; -} - -void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } - -std::list FuncGraph::GetOrderedCnodes() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Return ordered cnodes."; - return order_; - } else { - auto this_ptr = shared_from_base(); - auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); - auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); - - std::list cnodes; - auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto &node : nodes) { - auto cnode = dyn_cast(node); - if (cnode) { - cnodes.push_back(cnode); - } - } - return cnodes; - } -} - -void FuncGraph::EraseUnusedNodeInOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - auto mng = manager_.lock(); - if (mng) { - auto &all_nodes = nodes(); - // Erase unused cnode. 
- for (auto it = order_.begin(); it != order_.end();) { - if (all_nodes.count(*it)) { - (void)it++; - } else { - MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; - it = order_.erase(it); - } - } - } - } -} - -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { - if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { - order_.remove(n->cast()); - MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; - } -} - -void FuncGraph::CheckOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Check graph " << ToString(); - for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto &input_node : (*it)->inputs()) { - if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { - // Need to reorder the wrong order node. - auto found = std::find(order_.begin(), it, input_node); - if (found == it) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString() - << " doesn't obey the input dependency, " - << "as input " << input_node->DebugString() << " is not ahead of itself."; - } - } - } - } - auto mng = manager_.lock(); - if (mng != nullptr) { - const auto &all_nodes = nodes(); - if (all_nodes.size() != (order_.size() + parameters_.size())) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " - << all_nodes.size() - parameters_.size() << "."; - } - } - MS_LOG(DEBUG) << "Check order okay."; - } -} - -size_t NewFgSeenGeneration() { - static size_t fg_seen_generation = 0; - return ++fg_seen_generation; -} - -const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); -const char kFuncGraphFlagUndetermined[] = "Undeterminate"; -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc deleted file mode 100644 index f720913b98..0000000000 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ /dev/null @@ -1,650 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph_cloner.h" - -#include - -#include "ir/manager.h" -#include "ir/param_value.h" -#include "operator/ops.h" -#include "utils/convert_utils_base.h" -#include "utils/log_adapter.h" -#include "utils/profile.h" -#include "utils/context/ms_context.h" - -// namespace to support intermediate representation definition -namespace mindspore { -Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, - bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) - : clone_all_valuenodes_(clone_all_valuenodes), - clone_all_child_graphs_(clone_all_child_graphs), - clone_all_used_graphs_(clone_all_used_graphs), - relation_(relation), - target_relation_(target_relation == nullptr ? 
relation : target_relation) { - for (auto &func_graph : func_graphs) { - AddClone(func_graph); - } - scope_ = kDefaultScope; - type_ = kBasic; -} - -void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, - const AnfNodePtrList ¶ms, CloneType type) { - if (func_graph != nullptr) { - todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); - type_ = type; - } -} - -void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - if (repl_node_.find(node) != repl_node_.end() || node->isa()) { - return; - } - if (node->isa()) { - CloneParameter(node, target); - } else if (node->isa()) { - CloneCNode(node, target); - } -} - -void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - auto new_param = (is_add) ? target->add_parameter() : std::make_shared(target); - auto old_param = node->cast(); - new_param->set_abstract(old_param->abstract()); - new_param->set_name(old_param->name()); - if (old_param->has_default()) { - // Default parameter can be shared since it is readonly. - new_param->set_default_param(old_param->default_param()); - } - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_param->set_scope(scope); - repl_node_[node] = new_param; - TraceManager::EndTrace(); -} - -void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); - auto old_node = node->cast(); - new_node->set_abstract(old_node->abstract()); - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_node->set_scope(scope); - new_node->set_kernel_info(old_node->kernel_info_ptr()); - repl_node_[old_node] = new_node; - nodes_.emplace_back(old_node, new_node); - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TraceManager::DebugTrace(node->debug_info(), relation_); - ValueNodePtr new_const = NewValueNode(GetValueNode(node)); - ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); - new_const->set_scope(scope); - new_const->set_abstract(node->abstract()); - repl_node_[node] = new_const; - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(target); - TraceManager::DebugTrace(node->debug_info(), relation_); - ValueNodePtr new_const = NewValueNode(target); - ScopePtr scope = (node->scope() != kDefaultScope) ? 
node->scope() : this->scope(); - new_const->set_scope(scope); - new_const->set_abstract(node->abstract()); - repl_node_[node] = new_const; - TraceManager::EndTrace(); -} - -void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_valuenodes_) { - return; - } - auto &value_nodes = func_graph->value_nodes(); - for (auto &value_node : value_nodes) { - auto old_node = value_node.first; - MS_EXCEPTION_IF_NULL(old_node); - if (repl_node_.count(old_node) == 0) { - CloneValueNode(old_node); - } - } -} - -void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_child_graphs_) { - return; - } - auto &scopes = manager_->scopes(func_graph); - for (auto &graph : scopes) { - if (graph != func_graph) { - todo_.push_back({graph, nullptr, {}}); - } - } -} - -void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(manager_); - if (!clone_all_used_graphs_) { - return; - } - auto &used = func_graph->func_graphs_used(); - for (auto &fg : used) { - todo_.push_back({fg.first, nullptr, {}}); - } -} - -void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - for (auto &item : func_graph->parameter_default_value()) { - auto nodes = DeepLinkedGraphSearch(item.second); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - CloneNode(node, target_func_graph); - } else if (node->isa()) { - CloneValueNode(node); - } - } - } -} - -void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - MS_EXCEPTION_IF_NULL(manager_); - auto return_node = repl_node_[func_graph->get_return()]->cast(); - if (return_node == nullptr) { - MS_LOG(EXCEPTION) << "Can't find replicate node for return."; - } - target_func_graph->set_return(return_node); - - auto &cnodes = func_graph->func_graph_cnodes_index(); - for (auto &cnode : cnodes) { - auto parent = cnode.first->first->cast(); - auto valuenode = parent->input(cnode.first->second); - CloneValueNode(valuenode, target_func_graph); - } -} - -void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { - MS_EXCEPTION_IF_NULL(func_graph); - auto &old_params = func_graph->parameters(); - if (old_params.size() != params.size()) { - MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; - return; - } - for (size_t i = 0; i < old_params.size(); ++i) { - repl_node_[old_params[i]] = params[i]; - } -} - -void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); - *target_func_graph = std::make_shared(); - (*target_func_graph)->set_attrs(func_graph->attrs()); - (*target_func_graph)->set_transforms(func_graph->transforms()); - (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); - (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); - (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - 
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); - (*target_func_graph)->set_is_generate(func_graph->is_generated()); - (*target_func_graph)->set_stub(func_graph->stub()); - TraceManager::EndTrace(); -} - -void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - auto ¶ms = func_graph->parameters(); - for (auto ¶m : params) { - CloneParameter(param, target_func_graph, true); - } - repl_func_graph_[func_graph] = target_func_graph; -} - -void Cloner::GenParameters(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto &free_vars = manager_->free_variables_total(); - auto iter = free_vars.find(func_graph); - if (iter == free_vars.end()) { - return; - } - - for (auto &fv_map : iter->second) { - auto &free_var = fv_map.first; - if (utils::isa(free_var)) { - repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); - } - } -} - -void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { - param->set_abstract(node->abstract()); - if (node->isa()) { - ParameterPtr old_param = dyn_cast(node); - if (old_param->has_default()) { - // Default parameter can be shared since it is readonly. - param->set_default_param(old_param->default_param()); - } - param->set_name(old_param->name()); - } -} - -ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - ParameterPtr param = std::make_shared(func_graph); - TraceManager::EndTrace(); - CloneParameter(param, node); - if (is_add) { - func_graph->add_parameter(param); - } - repl_node_[param] = node; - repl_map_node_[func_graph][node] = param; - return param; -} - -void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, - AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { - AnfNodePtrList parameters; - std::unordered_set old_params; - for (auto ¶m : func_graph->parameters()) { - auto iter = repl_node_.find(param); - if (iter != repl_node_.end()) { - (void)old_params.insert(iter->second); - parameters.push_back(param); - } else { - parameters.push_back(AddParameter(func_graph, param, false)); - (void)old_params.insert(param); - } - } - AnfNodePtr new_param = nullptr; - for (auto ¶m : params) { - auto old_param = repl_node_[param]; - if (old_param->isa() && old_param->func_graph() == func_graph) { - repl_node_[old_param] = old_param; - repl_map_node_[func_graph][old_param] = old_param; - input_params->push_back(old_param); - continue; - } - if (old_params.find(old_param) != old_params.end()) { - new_param = repl_map_node_[func_graph][old_param]; - input_params->push_back(new_param); - continue; - } - new_param = AddParameter(func_graph, old_param, false); - parameters.push_back(new_param); - lift_params->push_back(new_param); - input_params->push_back(new_param); - } - func_graph->set_parameters(parameters); -} - -void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, - const AnfNodePtrList ¶ms) { - AnfNodePtr node = nullptr; - auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; - auto iter = repl_func_graph.find(func_graph); - if (iter == repl_func_graph.end()) { - node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); - repl_func_graph[func_graph] = node; - } else { - node = 
iter->second; - } - if (node == nullptr || !node->isa()) { - return; - } - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); - cnode->set_inputs(inputs); - OrderParameters(func_graph, inputs); -} - -void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { - std::unordered_set old_params; - for (auto ¶m : func_graph->parameters()) { - (void)old_params.insert(repl_node_[param]); - } - std::unordered_set new_params; - AnfNodePtrList parameters; - // Ignore the 1st and 2nd param of inputs(such as. partial graph) - for (size_t i = 2; i < inputs.size(); ++i) { - auto input = inputs[i]; - auto param = repl_node_[input]; - if (old_params.find(param) != old_params.end()) { - auto new_param = repl_map_node_[func_graph][param]; - parameters.push_back(new_param); - (void)new_params.insert(new_param); - } - } - for (auto ¶m : func_graph->parameters()) { - if (new_params.find(param) == new_params.end()) { - parameters.push_back(param); - } - } - func_graph->set_parameters(parameters); -} - -void Cloner::SetEdges(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - for (auto &node : func_graph->nodes()) { - if (node == nullptr) { - continue; - } - // Only cnode needed to be handled - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - for (size_t i = 0; i < inputs.size(); i++) { - auto &input = inputs[i]; - if (IsValueNode(input)) { - auto graph = GetValueNode(input); - auto &repl_func_graph = repl_map_func_graph_[func_graph]; - if (repl_func_graph.find(graph) != repl_func_graph.end()) { - transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); - } - } else { - auto &repl_node = repl_map_node_[func_graph]; - if (repl_node.find(input) != repl_node.end()) { - transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); - } - } - } - } -} - -void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, - const AnfNodePtrList ¶ms) { - AnfNodePtrList lift_params; - AnfNodePtrList input_params; - AddParameters(func_graph_user, params, &lift_params, &input_params); - AddInputs(func_graph_user, func_graph, input_params); - if (lift_params.empty()) { - return; - } - for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); - } -} - -void Cloner::Lift() { - for (auto &func_graph_params : repl_func_graph_params_) { - auto &func_graph = func_graph_params.first; - auto ¶ms = func_graph_params.second; - for (auto &cnode : func_graph->func_graph_cnodes_index()) { - LiftParameters(cnode.first->first->func_graph(), func_graph, params); - } - } -} - -void Cloner::LiftParameters() { - MS_EXCEPTION_IF_NULL(manager_); - transaction_ = manager_->Transact(); - const FuncGraphSet &func_graphs = manager_->func_graphs(); - for (auto &func_graph : func_graphs) { - GenParameters(func_graph); - } - Lift(); - for (auto &func_graph : func_graphs) { - SetEdges(func_graph); - } - transaction_.Commit(); -} - -bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { - MS_EXCEPTION_IF_NULL(func_graph); - // Make sure only inline once - if (status_.count(func_graph) != 0) { - if (is_inline == status_[func_graph]) { - return false; - } - if (clone_all_used_graphs_) { - MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False."; - return false; - } - } - return true; -} - -void 
Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet &nodes = func_graph->nodes(); - for (auto &node : nodes) { - CloneNode(node, target_func_graph); - } -} - -void Cloner::Run() { - if (todo_.empty()) { - return; - } - - if (type_ < kLifting) { - // Basic and Inline Clone - FuncGraphPtrList func_graphs; - (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), - [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); - manager_ = Manage(func_graphs, false); - CloneNodes(); - LinkEdges(); - SetDefaults(); - } else { - // Lifting Clone - CloneInfo item = todo_.back(); - manager_ = Manage(item.origin); - LiftParameters(); - } -} - -void Cloner::CloneNodes() { - while (!todo_.empty()) { - CloneInfo item = todo_.back(); - todo_.pop_back(); - - bool is_inline = (item.target != nullptr); - FuncGraphPtr func_graph = item.origin; - FuncGraphPtr target_func_graph = item.target; - (void)graph_set_.insert(func_graph); - - if (!CheckStatus(func_graph, is_inline)) { - continue; - } - - if (is_inline) { - InlineCloneParameters(func_graph, item.params); - CloneAllNodes(func_graph, target_func_graph); - } else { - SetFuncGraphInfo(func_graph, &target_func_graph); - CloneParameters(func_graph, target_func_graph); - CloneAllNodes(func_graph, target_func_graph); - CloneFuncGraphValueNodes(func_graph, target_func_graph); - CloneFuncGraphDefaultValues(func_graph, target_func_graph); - } - - CloneValueNodes(func_graph); - AddChildGraphs(func_graph); - AddTotalGraphs(func_graph); - status_[func_graph] = is_inline; - } -} - -void Cloner::LinkEdges() { - for (auto &node_pair : nodes_) { - CNodePtr old_node = node_pair.first; - CNodePtr new_node = node_pair.second; - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - for (auto &input : old_node->inputs()) { - auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; - new_node->add_input(new_input); - } - } -} - -// For the graphs cloned, update its default value map to the cloned nodes -void Cloner::SetDefaults() { - for (auto &item : graph_set_) { - MS_EXCEPTION_IF_NULL(item); - if (repl_func_graph_.count(item) != 0) { - for (auto ¶m_def : item->parameter_default_value()) { - MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); - if (repl_node_.count(param_def.second) != 0) { - repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); - } else { - repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second); - } - } - } - } -} - -AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { - MS_EXCEPTION_IF_NULL(root); - if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { - MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; - } - CloneNode(root, repl_func_graph_[root->func_graph()]); - auto iter = repl_node_.find(root); - if (iter != repl_node_.end()) { - return iter->second; - } - MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; -} - -AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time); -#endif - return ((repl_node_.count(node) == 0) ? 
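Cloner::Run above dispatches on the clone type: anything below kLifting is a basic or inline clone driven by CloneNodes, LinkEdges, and SetDefaults, while a lifting clone reuses the manager's free-variable analysis and only rewires parameters. A hedged usage sketch of the entry points defined just below (BasicClone, InlineClone, LiftingClone); the graphs and arguments are assumed context, so the lines are left commented:

// FuncGraphPtr fg = ...;                 // some traced graph
// FuncGraphPtr copy = BasicClone(fg);    // deep copy, structure unchanged
// FuncGraphPtr flat = LiftingClone(fg);  // free variables become parameters
// AnfNodePtr out = InlineClone(fg, target, args, scope);  // body spliced into target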
node : repl_node_[node]); -} - -FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time); -#endif - return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); -} - -FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); - return cloner[func_graph]; -} - -AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, - const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(target_func_graph); - Cloner cloner({}, false); - if (scope != nullptr) { - cloner.set_scope(scope); - } - cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline); - return cloner[func_graph->output()]; -} - -FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - Cloner cloner({}, false); - cloner.AddClone(func_graph, nullptr, {}, kLifting); - return cloner[func_graph]; -} - -ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphPtrList func_graphs = {func_graph}; - ClonerPtr cloner = - std::make_shared(func_graphs, false, false, false, std::make_shared(), relation); -#ifdef ENABLE_PROFILE - double time = GetTime(); -#endif - cloner->Run(); -#ifdef ENABLE_PROFILE - MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time); -#endif - return cloner; -} - -FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { - MS_EXCEPTION_IF_NULL(func_graph); - TraceManager::DebugTrace(func_graph->debug_info(), relation); - auto new_func_graph = std::make_shared(); - TraceManager::EndTrace(); - - auto ¶meters = func_graph->parameters(); - (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { - MS_EXCEPTION_IF_NULL(param); - TraceManager::DebugTrace(std::make_shared(param->debug_info())); - (void)new_func_graph->add_parameter(); - TraceManager::EndTrace(); - }); - - Cloner cloner = Cloner(); - cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters()); - AnfNodePtr output = cloner[func_graph->output()]; - new_func_graph->set_output(output); - new_func_graph->set_has_vararg(func_graph->has_vararg()); - new_func_graph->set_has_kwarg(func_graph->has_kwarg()); - new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); - new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); - new_func_graph->set_is_generate(func_graph->is_generated()); - new_func_graph->set_stub(func_graph->stub()); - for (auto &item : func_graph->parameter_default_value()) { - new_func_graph->set_param_default_value(item.first, cloner[item.second]); - } - - if (MsContext::GetInstance()->is_multi_graph_sink()) { - if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - } - } - - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - } - - return new_func_graph; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_extends.cc 
b/mindspore/ccsrc/ir/func_graph_extends.cc deleted file mode 100644 index 02f37f343d..0000000000 --- a/mindspore/ccsrc/ir/func_graph_extends.cc +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/func_graph.h" - -#include -#include -#include - -#include "ir/manager.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" -#include "utils/ordered_set.h" -#include "abstract/abstract_value.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/abstract_function.h" - -#include "debug/anf_ir_dump.h" -#include "debug/trace.h" -#include "debug/draw.h" -#include "debug/label.h" - -namespace mindspore { -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractFunctionPtr; -using mindspore::abstract::AnalysisContextPtr; -using mindspore::abstract::PrimitiveAbstractClosure; -using mindspore::abstract::VirtualAbstractClosure; - -AbstractFunctionPtr FuncGraph::abstract() { - AbstractBasePtrList args_spec_list; - - for (auto &p : parameters_) { - MS_EXCEPTION_IF_NULL(p); - if (p->abstract() == nullptr) { - MS_LOG(ERROR) << "Error!!"; - return nullptr; - } - args_spec_list.push_back(p->abstract()); - } - - if (nullptr == output()) { - MS_LOG(ERROR) << "Error func graph no output"; - return nullptr; - } - - return std::make_shared(args_spec_list, output()->abstract()); -} - -abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { - AnalysisContextPtr temp_context = context; - if (temp_context == nullptr) { - temp_context = abstract::AnalysisContext::DummyContext(); - } - return std::make_shared(shared_from_base(), temp_context); -} - -void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { - if (force_new_ret || return_ == nullptr) { - std::vector params({NewValueNode(prim::kPrimReturn), value}); - FuncGraphPtr this_graph = shared_from_base(); - return_ = this_graph->NewCNode(params); - } else { - if (manager_.lock()) { - manager_.lock()->SetEdge(return_, 1, value); - } else { - return_->set_input(1, value); - } - } - - return_->set_abstract(value->abstract()); - - AnfNodePtr input0 = return_->input(0); - - PrimitivePtr return_prim = prim::kPrimReturn; - auto f = std::make_shared(return_prim, input0); - input0->set_abstract(f); -} - -void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } - -void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, - std::vector *specialized_parameter_list, - std::unordered_map *repl_nodes, int variable_args_count, - int pos_args_input_count) { - // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple - if (specialized_graph->has_vararg()) { - TraceManager::DebugTrace( - std::make_shared(specialized_graph->GetVariableArgParameter()->debug_info())); - std::vector var_param_tuple_nodes; - 
var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - - if (variable_args_count < 0) { - MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count - << " were given."; - } - // for Python variable argument input, there is no upper limit - for (int i = 0; i < variable_args_count; ++i) { - ParameterPtr p = std::make_shared<Parameter>(specialized_graph); - std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); - p->set_name(param_name); - MS_EXCEPTION_IF_NULL(p->debug_info()); - p->debug_info()->set_name(param_name); - var_param_tuple_nodes.push_back(p); - MS_EXCEPTION_IF_NULL(specialized_parameter_list); - specialized_parameter_list->push_back(p); - } - auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes); - (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param); - TraceManager::EndTrace(); - } else if (variable_args_count > 0) { - MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount() - << " positional arguments, but " << pos_args_input_count << " were given."; - } -} - -void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, - std::vector<AnfNodePtr> *specialized_parameter_list, - const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, - std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) { - std::vector<AnfNodePtr> kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - std::vector<AnfNodePtr> kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - - for (const auto &kwarg : kwarg_list) { - MS_EXCEPTION_IF_NULL(kwarg); - std::string kw_param_name = kwarg->get_key(); - MS_EXCEPTION_IF_NULL(specialized_graph); - AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name); - // if no corresponding parameter node is found - if (param_node == nullptr) { - if (!has_kwarg()) { - MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name; - } else { - ParameterPtr p = std::make_shared<Parameter>(specialized_graph); - std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; - MS_EXCEPTION_IF_NULL(specialized_parameter_list); - auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), - [param_name](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto param = node->cast<ParameterPtr>(); - return param != nullptr && param->name() == param_name; - }); - if (find_kw_arg_in_list) { - MS_LOG(EXCEPTION) << "Multiple values for keyword argument:" << kw_param_name; - } - p->set_name(param_name); - p->debug_info()->set_name(param_name); - kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name)); - auto extract_node = - specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p}); - kwarg_values_tuple_nodes.push_back(extract_node); - specialized_parameter_list->push_back(p); - } - } else { - auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); - // multiple values given for the same parameter - if (node_itr != specialized_parameter_list->end()) { - MS_LOG(EXCEPTION) << "Multiple values for specific argument:" << kw_param_name; - } else { - specialized_parameter_list->push_back(param_node); - auto extract_node = specialized_graph->NewCNode( - {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); - (void)repl_nodes->emplace(param_node, extract_node); - } - } - } - - GenerateKwargReplNode(specialized_graph, repl_nodes,
kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); -} - -void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, - std::unordered_map *repl_nodes, - const std::vector &kwarg_keys_tuple_nodes, - const std::vector &kwarg_values_tuple_nodes) { - if (has_kwarg()) { - MS_EXCEPTION_IF_NULL(specialized_graph); - TraceManager::DebugTrace( - std::make_shared(specialized_graph->GetVariableKwargParameter()->debug_info())); - auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes); - auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes); - auto make_dict_node = - specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values}); - MS_EXCEPTION_IF_NULL(repl_nodes); - (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node); - TraceManager::EndTrace(); - } -} - -bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { - // if the function does not have any vararg/kwarg/kwonly/default value/kw args input - // return the original graph - if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { - return false; - } - - // if the graph is generated for specific input, do not need to generate again - if (is_generated()) { - return false; - } - return true; -} - -void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, - const std::vector &specialized_parameter_list, - std::unordered_map *repl_nodes) { - MS_EXCEPTION_IF_NULL(specialized_graph); - for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { - auto param_node = specialized_graph->parameters()[i]; - MS_EXCEPTION_IF_NULL(param_node); - auto param_name = param_node->cast()->name(); - auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node); - if (node_itr != specialized_parameter_list.end()) { - continue; - } - if (param_name == specialized_graph->GetVariableArgName() || - param_name == specialized_graph->GetVariableKwargName()) { - continue; - } - auto default_value = specialized_graph->GetDefaultValueByName(param_name); - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name; - } - MS_EXCEPTION_IF_NULL(repl_nodes); - (void)repl_nodes->emplace(param_node, default_value); - } -} - -FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { - std::vector kwarg_list; - size_t arguments_count = args_spec_list.size(); - for (const auto &arg : args_spec_list) { - // if it is a keyword argument - MS_EXCEPTION_IF_NULL(arg); - if (arg->isa()) { - kwarg_list.push_back(dyn_cast(arg)); - } - } - if (!NeedGenerate(kwarg_list)) { - return shared_from_base(); - } - FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); - size_t kwarg_count = kwarg_list.size(); - int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); - int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); - int variable_args_count = pos_args_input_count - pos_args_count; - std::vector specialized_parameter_list; - std::unordered_map repl_nodes; - // the parameters that has arg input, copy from original parameters - for (size_t i = 0; i < IntToSize(pos_args_count); ++i) { - specialized_parameter_list.push_back(specialized_graph->parameters()[i]); - } - - GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count, - 
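The arithmetic at the top of GenerateGraph above decides how many inputs bind to declared positional parameters and how many get packed into the *args tuple; keyword inputs are then routed by GenerateKwParams through ExtractKeywordArg into named parameters or a MakeDict for **kwargs. A worked example of the counting, with assumed numbers mirroring a call like f(1, 3, 4, k=5) against def f(a, b=2, *args, **kw):

#include <algorithm>
#include <cstdio>

int main() {
  int arguments_count = 4;    // inputs: 1, 3, 4, k=5
  int kwarg_count = 1;        // k=5
  int hyper_param_count = 0;  // assume no hyper-parameters appended
  int positional_params = 2;  // a, b

  int pos_args_input = arguments_count - kwarg_count - hyper_param_count;  // 3
  int pos_args = std::min(pos_args_input, positional_params);              // 2: a=1, b=3
  int variable_args = pos_args_input - pos_args;                           // 1: *args = (4,)
  std::printf("positional=%d packed_into_varargs=%d\n", pos_args, variable_args);
  return 0;
}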
pos_args_input_count); - - GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, &repl_nodes); - - GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes); - - // append hyper parameter to specialized_parameter_list - MS_EXCEPTION_IF_NULL(specialized_graph); - auto params = specialized_graph->parameters(); - (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); - - std::shared_ptr manager = mindspore::Manage(specialized_graph, false); - auto tr = manager->Transact(); - for (auto &node_pair : repl_nodes) { - MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" - << node_pair.second->DebugString(); - (void)tr.Replace(node_pair.first, node_pair.second); - } - tr.SetParameters(specialized_graph, specialized_parameter_list); - tr.Commit(); - specialized_graph->set_has_kwarg(false); - specialized_graph->set_has_vararg(false); - specialized_graph->set_kwonlyargs_count(0); - specialized_graph->ClearDefaultValues(); - specialized_graph->set_is_generate(true); - return specialized_graph; -} - -const char kPrimHasEffect[] = "_side_effect_flag"; - -bool FuncGraph::HasEffect(const CNodePtr &cnode) { - auto prim = GetCNodePrimitive(cnode); - if (prim != nullptr && prim->isa()) { - auto do_sig = prim->cast(); - auto prim_val = do_sig->function(); - if (prim_val != nullptr && prim_val->isa()) { - prim = prim_val->cast(); - } else { - prim = nullptr; - } - } - if (prim != nullptr) { - auto effect_val = prim->GetAttr(kPrimHasEffect); - if (effect_val && effect_val->isa()) { - auto effect_bool = GetValue(effect_val); - return effect_bool; - } - } - return false; -} - -std::shared_ptr> FindRoots(const std::vector &segment) { - std::shared_ptr> roots = std::make_shared>(segment); - for (const auto &node : segment) { - if (roots->size() == 1) { - return roots; - } - auto input_size = node->size(); - for (size_t i = 0; i < input_size; i++) { - auto in_node = node->input(i); - auto in_cnode = in_node->cast(); - if (in_cnode != nullptr) { - (void)roots->erase(in_cnode); - } - } - } - return roots; -} - -std::shared_ptr> FindLeaves(const std::vector &segment) { - std::shared_ptr> nodes = std::make_shared>(segment); - for (const auto &node : segment) { - if (nodes->size() == 1) { - return nodes; - } - if (IsPrimitiveCNode(node, prim::kPrimSwitch)) { - (void)nodes->erase(node); - continue; - } - auto input_size = node->size(); - for (size_t i = 0; i < input_size; i++) { - auto in_node = node->input(i); - if (!in_node->isa()) { - continue; - } - auto in_cnode = in_node->cast(); - if (in_cnode != nullptr) { - if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) { - (void)nodes->erase(node); - break; - } - } - } - } - return nodes; -} - -void FuncGraph::ReleaseFullOrderToEffectOrder() { - MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << "."; - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - std::list depends_order; - std::vector segment; - for (const auto &cnode : order_) { - if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { - continue; - } - if (HasEffect(cnode)) { - MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << "."; - if (segment.size() > 0) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - segment.clear(); - depends_order.push_back(cnode); - } else { - 
MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << "."; - segment.push_back(cnode); - } - } - if (segment.size() > 1) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - std::vector depend_inputs; - auto old_ret = output(); - for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) { - if (*iter != old_ret) { - depend_inputs.push_back(*iter); - } - } - set_flag(GRAPH_FLAG_HAS_EFFECT, false); - set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); - if (!depend_inputs.empty()) { - SetEffectDepends(depend_inputs); - } - } -} - -void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { - auto old_ret = output(); - std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; - (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); - auto new_ret = NewCNode(inputs); - auto mng = manager(); - if (mng) { - (void)mng->Replace(old_ret, new_ret); - } else { - return_->set_input(1, new_ret); - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc deleted file mode 100644 index cf56500aea..0000000000 --- a/mindspore/ccsrc/ir/manager.cc +++ /dev/null @@ -1,914 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "ir/manager.h" - -#include -#include -#include - -#include "debug/trace_base.h" -#include "ir/func_graph.h" -#include "utils/profile.h" -#include "utils/convert_utils_base.h" -#include "operator/ops.h" - -namespace mindspore { - -FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { - auto m = std::make_shared(func_graphs, manage); - m->Init(); - return m; -} - -FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { - FuncGraphManagerPtr m = nullptr; - bool root = false; - - for (auto &fg : func_graphs) { - if (fg == nullptr) { - continue; - } - if (fg->manager() != nullptr) { - m = fg->manager(); - break; - } - } - - if (m == nullptr) { - std::vector tmp; - m = MakeManager(tmp, manage); - root = true; - } - - for (auto &fg : func_graphs) { - if (fg == nullptr) { - continue; - } - m->AddFuncGraph(fg, root); - } - return m; -} - -FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { - std::vector func_graphs = {func_graph}; - return Manage(func_graphs, manage); -} - -FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) - : roots_(roots), is_manage_(manage) { - Reset(); -} - -void FuncGraphManager::Reset() { - func_graphs_ = FuncGraphSet(); - all_nodes_ = AnfNodeSet(); - node_users_ = NodeUsersMap(); - - signals_ = std::make_shared(); - - func_graph_parents_total_ = std::make_shared(this); - func_graph_parent_ = std::make_shared(this); - children_ = std::make_shared(this); - scopes_ = std::make_shared(this); - free_variables_total_ = std::make_shared(this); - func_graphs_used_total_ = std::make_shared(this); - recursive_ = std::make_shared(this); - j_total_ = std::make_shared(this); - - limit_ = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); -} - -void FuncGraphManager::Init() { - auto roots = roots_; - roots_ = FuncGraphSet(); - - for (auto &fg : roots) { - AddFuncGraph(fg, true); - } -} - -FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); - func_graph_parents_total_->Recompute(fg); - MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString(); - return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; -} - -FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(func_graph_parent_); - MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); - func_graph_parent_->Recompute(fg); - if (func_graph_parent_->parent_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString(); - return nullptr; - } - MS_LOG(DEBUG) << "End parents func graph " << fg->ToString(); - return func_graph_parent_->parent_analysis()[fg]; -} - -FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(children_); - MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); - children_->Recompute(fg); - return children_->children_analysis()[fg]; -} - -FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(scopes_); - MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); - scopes_->Recompute(fg); - MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString(); - return scopes_->scope_analysis()[fg]; -} - -FVTotalMap &FuncGraphManager::free_variables_total() const { - 
MS_EXCEPTION_IF_NULL(free_variables_total_); - free_variables_total_->Recompute(); - return free_variables_total_->fv_total_analysis(); -} - -FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(func_graphs_used_total_); - func_graphs_used_total_->Recompute(fg); - return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; -} - -bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - recursive_->Recompute(fg); - if (recursive_->recursive_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return false; - } - return recursive_->recursive_analysis()[fg]; -} - -std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(fg); - if (recursive(fg)) { - if (!recursive_->recursive_map().count(fg)) { - auto trace = std::list(); - recursive_->CheckRecursiveGraphs(fg, &trace); - } - if (recursive_->recursive_map().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return nullptr; - } - return recursive_->recursive_map()[fg]; - } else { - return nullptr; - } -} - -bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { - MS_EXCEPTION_IF_NULL(j_total_); - MS_EXCEPTION_IF_NULL(fg); - j_total_->Recompute(fg); - if (j_total_->j_total_analysis().count(fg) == 0) { - MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); - return false; - } - return j_total_->j_total_analysis()[fg]; -} - -// add a func graph to this manager, optionally as a root func graph. -void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { - MS_EXCEPTION_IF_NULL(func_graph); - if (is_root) { - roots_.add(func_graph); - } - if (func_graphs_.contains(func_graph)) { - return; - } - AddIntoManaged(func_graph); - std::vector para = func_graph->parameters(); - AcquireNodes(para); - std::vector return_vec({func_graph->get_return()}); - AcquireNodes(return_vec); -} - -// clear the all information in manager -void FuncGraphManager::Clear() { - func_graphs_.clear(); - all_nodes_.clear(); - node_users_.clear(); - roots_.clear(); - - signals_->InvalidateComputer(); -} - -void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { - MS_LOG(DEBUG) << "Start keep roots"; - bool root_exist = false; - for (auto &item : func_graphs) { - if (roots_.contains(item)) { - root_exist = true; - break; - } - } - - // if the new_root in roots_, we add new_root first, then calculate the func_graphs - // relation to new_root, remove the func_graphs not relation to new_root - // if the new_root not in roots_, we clear the all func_graphs in manager - // then add the new_root - if (root_exist || func_graphs.empty()) { - FuncGraphSet roots(func_graphs); - if (roots.empty()) { - roots = roots_; - } else { - roots_.clear(); - for (auto &item : roots) { - AddFuncGraph(item, true); - } - } - - FuncGraphSet keep; - for (auto &item : roots) { - MS_LOG(DEBUG) << "roots: " << item->ToString(); - keep.update(func_graphs_used_total(item)); -#ifdef DEBUG - for (auto &k : keep) { - MS_LOG(DEBUG) << "keep: " << k->ToString(); - } -#endif - } - MaybeDropFuncGraphs(func_graphs_ - keep, true); - } else { - Clear(); - FuncGraphSet roots(func_graphs); - for (auto &item : roots) { - AddFuncGraph(item, true); - } - } -} - -void FuncGraphManager::RemoveRoots() { - MS_LOG(DEBUG) << "Start remove roots"; - roots_.clear(); - MaybeDropFuncGraphs(func_graphs_, true); -} - 
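AddFuncGraph, KeepRoots, and RemoveRoots above define the manager's ownership model: root graphs are pinned, and every other graph survives only while some root transitively uses it. A hedged usage sketch, assuming the surrounding MindSpore headers and an existing graph, so the lines are left commented:

// FuncGraphPtr fg = ...;
// FuncGraphManagerPtr mgr = Manage(fg);  // fg becomes a root; its nodes are acquired
// mgr->Replace(old_node, new_node);      // rewires every (user, index) of old_node
// mgr->KeepRoots({fg});                  // drops graphs fg no longer reaches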
-void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(fg); - if (is_manage_) { - if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { - MS_LOG(WARNING) << "A func graph can only have one manager."; - } - FuncGraphManagerPtr this_manager = shared_from_this(); - fg->set_manager(this_manager); - } - func_graphs_.add(fg); -} - -void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { - FuncGraphSet todo(func_graphs); - std::set dropped; - // int count = 0; - while (!todo.empty()) { - FuncGraphPtr func_graph = todo.pop(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString(); - if (roots_.contains(func_graph)) { - MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); - continue; - } - auto &users_cnode_index = func_graph->func_graph_cnodes_index(); - if (!users_cnode_index.empty() && !ignore_users) { - MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); - continue; - } - if (dropped.find(func_graph) != dropped.end()) { - MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString(); - continue; - } - (void)dropped.insert(func_graph); - std::vector return_vec = {func_graph->get_return()}; - todo.update(MaybeDropNodes(return_vec)); - } - for (auto &fg : dropped) { - MS_EXCEPTION_IF_NULL(fg); - all_nodes_.difference_update(fg->parameters()); - (void)func_graphs_.erase(fg); - if (fg->manager().get() == this) { - fg->set_manager(nullptr); - } - MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString(); - } -} - -void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(inp); - if (direction == kDecEdge) { - MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - auto &users_node = node_users_[inp]; - if (!users_node.contains(make_pair(node, index))) { - return; - } - (void)users_node.erase(make_pair(node, index)); - DropEdge(node, index, inp); - } else { - MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - if (IsValueNode(inp)) { - MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); - AddFuncGraph(GetValueNode(inp)); - } - auto &users_node = node_users_[inp]; - users_node.add(make_pair(node, index)); - AddEdge(node, index, inp); - } -} - -void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - int index = 0; - for (auto &inp : cnode->inputs()) { - ProcessEdge(cnode, index, inp, direction); - ++index; - } - } -} - -IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { - if (all_nodes_.contains(node)) { - return EXCLUDE; - } else { - return FOLLOW; - } -} - -void FuncGraphManager::AcquireNodes(const std::vector &nodes) { - AnfNodeSet acq; - for (auto &node : nodes) { - AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit_)); - - all_nodes_.update(new_nodes); - acq.update(new_nodes); - } - - for (auto &node : acq) { - MS_EXCEPTION_IF_NULL(node); - auto fg = node->func_graph(); - if (fg != nullptr) { - fg->AddNode(node); - } - ProcessInputs(node, kIncEdge); - } -} - -FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { - AnfNodeSet nodes_ordered(nodes); - FuncGraphSetPtr func_graphs_to_check = std::make_shared(); - while 
(!nodes_ordered.empty()) { - AnfNodePtr node = nodes_ordered.pop(); - MS_EXCEPTION_IF_NULL(node); - if (!all_nodes_.contains(node)) { - continue; - } - AnfNodeIndexSet &users = node_users_[node]; - - std::vector parameters; - if (!users.empty() || - (node->isa() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) { - continue; - } - if (IsValueNode(node)) { - auto fg = GetValueNode(node); - func_graphs_to_check->add(fg); - MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString() - << " to null"; - } - ProcessInputs(node, kDecEdge); - (void)all_nodes_.erase(node); - if (node->func_graph() != nullptr) { - node->func_graph()->DropNode(node); - } - - if (node->isa()) { - auto cnode = node->cast(); - nodes_ordered.update(cnode->inputs()); - } - (void)node_users_.erase(node); - } - return func_graphs_to_check; -} - -void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { - auto tr = Transact(); - tr.SetParameters(fg, parameters); - tr.Commit(); -} - -void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { - auto tr = Transact(); - tr.AddParameter(fg, parameter); - tr.Commit(); -} - -bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { - auto tr = Transact(); - bool success = tr.Replace(old_node, new_node); - if (success) { - tr.Commit(); - } - return success; -} - -void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { - auto tr = Transact(); - tr.SetEdge(node, index, value); - tr.Commit(); -} - -void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { - AnfNodePtr source_return = source->get_return(); - AnfNodePtr source_output = source->output(); - AnfNodePtr source_prim = source_return->cast()->input(0); - - int index = 0; - (void)node_users_[source_prim].erase(make_pair(source_return, index)); - DropEdge(source_return, index, source_prim); - index = 1; - (void)node_users_[source_output].erase(make_pair(source_return, index)); - DropEdge(source_return, index, source_output); - (void)all_nodes_.erase(source_return); - (void)node_users_.erase(source_return); - source->DropNode(source_return); - for (auto &node : source->nodes()) { - node->set_func_graph(target); - if (node->scope() == kDefaultScope) { - node->set_scope(scope); - } - } - - MoveAllNodes(source, target); - all_nodes_.difference_update(source->parameters()); - (void)func_graphs_.erase(source); - if (source->manager().get() == this) { - source->set_manager(nullptr); - } -} - -void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { - auto fg = node->func_graph(); - if (input->isa()) { - fg->AddValueNode(input); - if (IsValueNode(input)) { - auto used = GetValueNode(input); - used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); - if (fg->AddFuncGraphUsed(used)) { - signals_->InvalidateComputer(); - } - if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->AddJFuncGraph(used); - } - } - } else if (fg != nullptr && fg != input->func_graph()) { - if (fg->AddFreeVariable(input)) { - signals_->InvalidateComputer(); - } - } -} - -void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { - auto fg = node->func_graph(); - if (input->isa()) { - fg->DropValueNode(input); - if (IsValueNode(input)) { - auto used = GetValueNode(input); - used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); - if 
(fg->DropFuncGraphUsed(used)) { - signals_->InvalidateComputer(); - } - if (IsPrimitiveCNode(node, prim::kPrimJ)) { - fg->DropJFuncGraph(used); - } - } - } else if (fg != nullptr && fg != input->func_graph()) { - if (fg->DropFreeVariable(input)) { - signals_->InvalidateComputer(); - } - } -} - -void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { - target->CopyNodes(source); - target->CopyValueNodes(source); - target->CopyFuncGraphCNodesIndex(source); - target->CopyFreeVariables(source); - target->CopyFuncGraphsUsed(source); - target->CopyJFuncGraphs(source); - signals_->InvalidateComputer(); - source->ClearNodes(); - source->ClearValueNodes(); - source->ClearFuncGraphCNodesIndex(); - source->ClearFreeVariables(); - source->ClearFuncGraphsUsed(); - source->ClearJFuncGraphs(); -} - -FuncGraphTransaction FuncGraphManager::Transact() { - auto tr = FuncGraphTransaction(this); - return tr; -} - -void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, - EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { - for (auto &iter : changes) { - auto operation = iter.op; - auto args = iter.args; - switch (operation) { - case Change::kTxSetEdge: { - auto edge = args.cast(); - auto old_node = edge.root_node->input(edge.index); - (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; - (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; - (*rms)[old_node] += 1; - (*adds)[edge.new_node] += 1; - edge.root_node->set_input(edge.index, edge.new_node); - } break; - case Change::kTxSetParams: { - auto param = args.cast(); - MS_EXCEPTION_IF_NULL(param.func_graph); - auto old_parameters = param.func_graph->parameters(); - for (auto &p : param.params) { - (*adds)[p] += 1; - } - for (auto &p : old_parameters) { - (*rms)[p] += 1; - } - param.func_graph->set_parameters(param.params); - } break; - case Change::kTxAddParam: { - auto param = args.cast(); - MS_EXCEPTION_IF_NULL(param.func_graph); - (*adds)[param.param] += 1; - auto param_node = param.param->cast(); - param.func_graph->append_parameter(param_node); - } break; - default: - break; - } - } -} - -void FuncGraphManager::CommitChanges(const std::vector &changes) { - EdgeTupleCounter add_edges; - EdgeTupleCounter rm_edges; - Counter adds; - Counter rms; - ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); - - auto sub_edges = add_edges - rm_edges; - for (auto &iter : sub_edges) { - auto root_node = iter.first.first; - int index = iter.first.second.first; - auto new_node = iter.first.second.second; - ProcessEdge(root_node, index, new_node, kIncEdge); - } - - auto sub_nodes = adds - rms; - std::vector nodes; - (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), - [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); - - AcquireNodes(nodes); - - auto sub_edges_reverse = rm_edges - add_edges; - for (auto &iter : sub_edges_reverse) { - auto root_node = iter.first.first; - int index = iter.first.second.first; - auto old_node = iter.first.second.second; - ProcessEdge(root_node, index, old_node, kDecEdge); - } - - auto sub_nodes_reverse = rms - adds; - std::vector nodes_reverse; - - (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), - [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); - - auto drop_func_graphs = MaybeDropNodes(nodes_reverse); - MaybeDropFuncGraphs(*drop_func_graphs); -} - -void 
FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { - changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); -} - -void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { - changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); -} - -bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - FuncGraphPtr old_func_graph = old_node->func_graph(); - if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) { - MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString(); - return false; - } - auto users = manager_->node_users()[old_node]; - for (auto &node : users) { - SetEdge(node.first, node.second, new_node); - } - - return true; -} - -void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { - if (k < 0) { - MS_LOG(EXCEPTION) << "Invalid value k = " << k; - } - MS_EXCEPTION_IF_NULL(src_node); - auto cnode = src_node->cast(); - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed."; - } - changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)}); -} - -void FuncGraphTransaction::Commit() { - std::vector changes; - changes_.swap(changes); - manager_->CommitChanges(changes); -} - -DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { - MS_EXCEPTION_IF_NULL(manager_); - manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); - validate_ = false; -} - -void DepComputer::Recompute() { - if (!validate_) { - RealRecompute(); - validate_ = true; - } -} - -void DepComputer::Recompute(const FuncGraphPtr &fg) { - if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { - RealRecompute(fg); - func_graphs_validate_[fg] = true; - } -} - -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { - if (fg->seen_ == seen_num) { - return std::make_shared(); - } - FuncGraphSetPtr parents = std::make_shared(); - - // Append all the fvs in fg. - auto &fvs = fg->free_variables(); - for (auto fv : fvs) { - parents->add(fv.first->func_graph()); - } - - // Search the fv in fg's child func graph. - auto &fgs = fg->func_graphs_used(); - for (auto &item : fgs) { - fg->seen_ = seen_num; - auto gt = item.first; - parents->update(SeekParents(gt, seen_num)); - } - (void)parents->erase(fg); - return parents; -} - -void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(fg); - func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); -} - -bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { - auto l1 = lhs.second.size(); - auto l2 = rhs.second.size(); - return l1 < l2; -} - -void ParentComputer::RealRecompute(FuncGraphPtr fg) { - this->parent_analysis_[fg] = nullptr; - // Note: must be a copy other than reference as it is modified thereafter. 
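ParseChanges and CommitChanges above apply only the net effect of a transaction: an edge added and removed inside the same transaction cancels before any ProcessEdge call, which is what the counter subtraction expresses. A self-contained sketch of that difference with toy string edges instead of the real EdgeTupleCounter:

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, int> add{{"ret->x", 1}, {"ret->tmp", 1}};
  std::map<std::string, int> rm{{"ret->tmp", 1}};  // cancels the tmp edge
  for (const auto &[edge, n] : add) {
    auto it = rm.find(edge);
    int net = n - (it == rm.end() ? 0 : it->second);
    if (net > 0) std::cout << "apply " << edge << "\n";  // only net additions land
  }
  return 0;
}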
- auto deps = this->manager_->func_graph_parents_total(fg); - - if (deps.empty()) { - this->parent_analysis_[fg] = nullptr; - return; - } else if (deps.size() == 1) { - this->parent_analysis_[fg] = deps.pop(); - return; - } else { - // return nearest parent as parent - FuncGraphSet deps_copy(deps); - for (auto &dep : deps) { - auto parent_deps = this->manager_->func_graph_parents_total(dep); - for (auto &p_d : parent_deps) { - if (deps_copy.count(p_d)) { - (void)deps_copy.erase(p_d); - } - } - if (deps_copy.size() == 1) { - this->parent_analysis_[fg] = deps_copy.pop(); - return; - } - } - } -} - -void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - auto used_fg_total = manager_->func_graphs_used_total(fg); - for (auto &used_fg : used_fg_total) { - if (manager_->parent(used_fg) == fg) { - children_analysis_[fg].add(used_fg); - } - } -} - -void ScopeComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - auto &children = manager_->children(fg); - - scope_analysis_[fg] = FuncGraphSet(); - scope_analysis_[fg].add(fg); - for (auto &child : children) { - scope_analysis_[fg].add(child); - } -} - -void FVTotalComputer::RealRecompute() { - auto manager = DepComputer::manager_; - MS_EXCEPTION_IF_NULL(manager); - - for (auto &fg : manager->func_graphs()) { - fv_total_analysis_[fg] = OrderedMap(); - } - - for (auto &fg : manager->func_graphs()) { - // add all free variable nodes - AnfNodeCounterMap items = fg->free_variables(); - for (auto &iter : items) { - auto curr = fg; - while (curr != nullptr) { - fv_total_analysis_[curr][iter.first] = iter.second; - curr = manager->parent(curr); - if (curr != nullptr) { - const AnfNodeSet &all_nodes = curr->nodes(); - if (all_nodes.contains(iter.first)) { - break; - } - } - } - } - - // add all FGs of free variables - auto &used = fg->func_graphs_used(); - for (auto &iter : used) { - auto p = manager->parent(iter.first); - if (p == nullptr) { - continue; - } - auto curr = fg; - while (curr != p) { - fv_total_analysis_[curr][iter.first] = iter.second; - curr = manager->parent(curr); - } - } - } -} - -void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { - MS_EXCEPTION_IF_NULL(manager_); - std::vector todo; - std::vector todo_new; - - todo.push_back(fg); - while (!todo.empty()) { - todo_new.clear(); - for (auto > : todo) { - for (auto &item : gt->func_graphs_used()) { - auto used_fg = item.first; - if (used_fg == fg) { - func_graph_used_total_analysis_[fg].add(used_fg); - continue; - } - if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) { - todo_new.push_back(used_fg); - } - MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString(); - func_graph_used_total_analysis_[fg].add(used_fg); - } - } - todo = todo_new; - } -} - -bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { - MS_EXCEPTION_IF_NULL(manager); - std::vector todo; - std::vector todo_new; - todo.push_back(fg); - FuncGraphSet used_total; - while (!todo.empty()) { - todo_new.clear(); - for (auto > : todo) { - for (auto &item : gt->func_graphs_used()) { - auto used_g = item.first; - if (used_g == fg) { - return true; - } - if (used_total.count(used_g) == 0) { - todo_new.push_back(used_g); - } - used_total.add(used_g); - } - } - todo = todo_new; - } - return false; -} - -void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { - this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); -} - -void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, 
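FuncGraphsUsedTotalComputer and CheckRecursive above are the same worklist sweep over the graphs-used relation: a graph is recursive exactly when the sweep, started from it, reaches it again. A self-contained sketch with a toy graph type:

#include <set>
#include <vector>

struct G {
  std::vector<G *> used;  // graphs referenced by this graph's value nodes
};

bool Recursive(G *root) {
  std::set<G *> seen;
  std::vector<G *> todo{root};
  while (!todo.empty()) {
    G *g = todo.back();
    todo.pop_back();
    for (G *u : g->used) {
      if (u == root) return true;  // cycle back to the start
      if (seen.insert(u).second) todo.push_back(u);
    }
  }
  return false;
}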
std::list *trace) { - MS_EXCEPTION_IF_NULL(trace); - auto res = std::find(trace->begin(), trace->end(), fg); - // find recursive - if (res != trace->end()) { - auto recur_ptr = std::make_shared>(res, trace->end()); - for (auto iter = res; iter != trace->end(); (void)iter++) { - MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString(); - recursive_map_[*iter] = recur_ptr; - } - } else { - trace->push_back(fg); - auto &items = fg->func_graphs_used(); - for (auto iter = items.begin(); iter != items.end(); (void)iter++) { - CheckRecursiveGraphs(iter->first, trace); - } - trace->pop_back(); - if (!recursive_map_.count(fg)) { - recursive_map_[fg] = nullptr; - } - } -} - -bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) { - if (fg->seen_ == seen_num) { - MS_LOG(DEBUG) << fg->ToString() << " had been checked"; - return false; - } - auto &j_fgs = fg->j_func_graphs(); - if (!j_fgs.empty()) { - // check g1->J(fg)->g2->g cycle; - auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair iter) { - return iter.first->seen_ != seen_num; - }); - if (contains_j != j_fgs.end()) { - MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; - return true; - } - } - fg->seen_ = seen_num; - - // check if func graphs used contains J(func_graph); - for (auto &item : fg->func_graphs_used()) { - auto used_g = item.first; - if (SeekJ(used_g, seen_num)) { - MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; - return true; - } - } - MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)"; - return false; -} - -void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) { - this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration()); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.cc b/mindspore/ccsrc/ir/meta_func_graph.cc deleted file mode 100644 index 3b2704613a..0000000000 --- a/mindspore/ccsrc/ir/meta_func_graph.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
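SeekJ above, like SeekParents earlier, avoids revisiting graphs with the seen_ generation stamp: each traversal takes a fresh number from NewFgSeenGeneration() and marks graphs by assignment, so no visited set has to be cleared between runs. A minimal sketch of the idiom with hypothetical names:

#include <cstddef>

struct G {
  std::size_t seen_ = 0;
};

std::size_t NewGeneration() {
  static std::size_t gen = 0;
  return ++gen;  // one global counter replaces per-traversal visited sets
}

bool MarkVisited(G *g, std::size_t gen) {
  if (g->seen_ == gen) return false;  // already seen in this traversal
  g->seen_ = gen;
  return true;
}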
- */ - -#include "ir/meta_func_graph.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/abstract_function.h" - -// namespace to support intermediate representation definition -namespace mindspore { -abstract::AbstractBasePtr MetaFuncGraph::MakeAbstractClosure(const AnfNodePtr &anf_node) { - abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn; - if (anf_node == nullptr) { - meta_func_graph_fn = std::make_shared(shared_from_base()); - } else { - meta_func_graph_fn = - std::make_shared(shared_from_base(), anf_node->scope()); - } - return meta_func_graph_fn; -} - -FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { - TypePtrList types; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), - [](const AbstractBasePtr &arg) -> TypePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->BuildType(); - }); - // filter unsafe characters in log print since name_ is from outside - auto iter = cache_.find(types); - if (iter == cache_.end()) { - FuncGraphPtr fg = GenerateFromTypes(types); - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(INFO) << "MetaFuncgraph: cache miss for types: " << mindspore::ToString(args_spec_list) - << ", g: " << fg->ToString(); - cache_[types] = fg; - return fg; - } else { - MS_LOG(DEBUG) << "MetaFuncgraph: cache hit for types: " << mindspore::ToString(args_spec_list) - << ", g: " << iter->second->ToString(); - return iter->second; - } -} -} // namespace mindspore diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h deleted file mode 100644 index 6605b9ce4c..0000000000 --- a/mindspore/ccsrc/ir/pattern_matcher.h +++ /dev/null @@ -1,310 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ -#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ - -#include -#include - -#include "ir/anf.h" -#include "operator/ops.h" - -namespace mindspore { - -/// -/// Base class for all recognizable patterns. -/// We implement an Expression Template approach using static polymorphism based on -/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect -/// to the use of virtual functions without the costs..." as described in: -/// https://en.wikipedia.org/wiki/Expression_templates and -/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern -/// The TryCapture function tries to capture the pattern with the given node. -/// The GetNode function builds a new node using the captured values. 
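The header comment above describes the matcher's CRTP layout; the effect is that the base class can call the derived class's TryCapture_ without any virtual dispatch. A minimal toy with the same shape (names hypothetical, unrelated to the classes below):

#include <iostream>

template <typename T>
struct PBaseToy {
  const T &self() const { return *static_cast<const T *>(this); }
  bool TryCapture(int v) const { return self().TryCapture_(v); }  // static dispatch
};

struct PEven : PBaseToy<PEven> {
  bool TryCapture_(int v) const { return v % 2 == 0; }
};

int main() {
  PEven p;
  std::cout << p.TryCapture(4) << "\n";  // prints 1: the pattern matched
  return 0;
}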
-/// - -template -class PBase { - public: - bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) { - return func(get_object().GetNode(node)); - } - - const T &get_object() const { return *static_cast(this); } - - template - bool TryCapture(const TN &value) const { - get_object().Reset(); - return get_object().TryCapture_(value); - } - - using Internal = T; -}; - -template -class PIsEqual { - public: - bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } -}; - -template -class PatternNode : public PBase > { - public: - T GetNode(const AnfNodePtr &node) const { - if (!captured_) { - MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode."; - } - return captured_node_; - } - - bool TryCapture_(const T &node) const { - if (!captured_) { - captured_node_ = node; - captured_ = true; - return true; - } - return PIsEqual()(captured_node_, node); - } - - void Reset() const { captured_ = false; } - using Internal = const PatternNode &; - - protected: - mutable T captured_node_; - mutable bool captured_{false}; -}; - -template -class PBinOperation : public PBase > { - public: - PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} - - AnfNodePtr GetNode(const AnfNodePtr &node) const { - AnfNodePtr lhs = x_.GetNode(node->func_graph()); - AnfNodePtr rhs = y_.GetNode(node->func_graph()); - AnfNodePtrList list = {prim_->cast(), lhs, rhs}; - return NewCNode(list, node->func_graph()); - } - - bool TryCapture_(const AnfNodePtr &node) const { - if (IsPrimitiveCNode(node, prim_)) { - auto cnode = node->cast(); - auto inputs = cnode->inputs(); - if (inputs.size() == 3) { - // Binary Prim assumes only two inputs - if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { - return false; - } - return true; - } - } - return false; - } - - void Reset() const { - x_.Reset(); - y_.Reset(); - } - - private: - const PrimitivePtr prim_; - typename T::Internal x_; - typename T2::Internal y_; -}; - -/// -/// Helper functions to apply a pattern function on all elements of a tuple -/// -namespace tuple_utils { -template -struct apply_func_tuple_item { - template - static void apply(Func *func, const TTuple &tuple) { - (*func)(Index, std::get(tuple)); - apply_func_tuple_item<(Index + 1) == std::tuple_size::value, (Index + 1), Func>::apply(func, tuple); - } -}; - -template -struct apply_func_tuple_item { - template - static void apply(Func *func, const TTuple &tuple) {} -}; - -template -inline void apply_func_tuple(Func *func, const TTuple &tuple) { - apply_func_tuple_item::value == 0, 0, Func>::apply(func, tuple); -} - -struct PTupleResetCapture { - template - void operator()(size_t i, const T &pattern) const { - pattern.Reset(); - } -}; - -struct PTupleCapture { - explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {} - - template - void operator()(size_t i, const TPattern &pattern) { - // Check if the first node is a Primitive - if (i == 0 && tuple_[i]->isa()) { - auto prim = tuple_[i]->cast(); - if (tuple_[i] != pattern.GetNode(tuple_[i])) { - captured_ = false; - } - } else { - captured_ = captured_ && pattern.TryCapture_(tuple_[i]); - } - } - - const AnfNodePtrList tuple_; - bool captured_{true}; -}; - -struct PTupleGetNode { - explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {} - - template - void operator()(size_t, const TPattern &pattern) { - args_.push_back(pattern.GetNode(node_)); - } - - const AnfNodePtr &node_; - std::vector args_; -}; -} // namespace 
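apply_func_tuple above walks a std::tuple of patterns at compile time through recursive template instantiation, the pre-C++17 way. For comparison, a minimal equivalent using an index-sequence fold (assumes C++17, unlike the header's C++11-era recursion):

#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

template <typename F, typename Tuple, std::size_t... I>
void ApplyFuncTuple(F &&f, const Tuple &t, std::index_sequence<I...>) {
  (f(I, std::get<I>(t)), ...);  // visit every element with its index
}

int main() {
  auto print = [](std::size_t i, const auto &v) { std::cout << i << ": " << v << "\n"; };
  std::tuple<int, const char *> t{42, "pattern"};
  ApplyFuncTuple(print, t, std::make_index_sequence<2>{});
  return 0;
}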
-
-template <typename... TArgs>
-class PCNode : public PBase<PCNode<TArgs...> > {
- public:
-  explicit PCNode(const TArgs &... args) : args_(args...) {}
-
-  AnfNodePtr GetNode(const AnfNodePtr &node) const {
-    tuple_utils::PTupleGetNode get_node(node);
-    tuple_utils::apply_func_tuple(&get_node, args_);
-    return NewCNode(get_node.args_, node->func_graph());
-  }
-
-  bool TryCapture_(const AnfNodePtr &node) const {
-    if (node->isa<CNode>()) {
-      auto cnode = node->cast<CNodePtr>();
-      auto inputs = cnode->inputs();
-      if (inputs.size() != sizeof...(TArgs)) {
-        return false;
-      }
-      tuple_utils::PTupleCapture capture_func(inputs);
-      tuple_utils::apply_func_tuple(&capture_func, args_);
-      return capture_func.captured_;
-    }
-
-    return false;
-  }
-
-  void Reset() const {
-    tuple_utils::PTupleResetCapture reset;
-    tuple_utils::apply_func_tuple(&reset, args_);
-  }
-
- private:
-  std::tuple<typename TArgs::Internal...> args_;
-};
-
-template <typename... TArgs>
-class PPrimitive : public PBase<PPrimitive<TArgs...> > {
- public:
-  explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
-
-  AnfNodePtr GetNode(const AnfNodePtr &node) const {
-    tuple_utils::PTupleGetNode get_node(node);
-    tuple_utils::apply_func_tuple(&get_node, args_);
-    auto prim_cnode = get_node.args_;
-    prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
-    return NewCNode(prim_cnode, node->func_graph());
-  }
-
-  bool TryCapture_(const AnfNodePtr &node) const {
-    if (IsPrimitiveCNode(node, prim_)) {
-      auto cnode = node->cast<CNodePtr>();
-      auto inputs = cnode->inputs();
-      if ((inputs.size() - 1) != sizeof...(TArgs)) {
-        return false;
-      }
-
-      AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
-      tuple_utils::PTupleCapture capture_func(rest);
-      tuple_utils::apply_func_tuple(&capture_func, args_);
-
-      return capture_func.captured_;
-    }
-
-    return false;
-  }
-
-  void Reset() const {
-    tuple_utils::PTupleResetCapture reset;
-    tuple_utils::apply_func_tuple(&reset, args_);
-  }
-
- private:
-  const PrimitivePtr prim_;
-  std::tuple<typename TArgs::Internal...> args_;
-};
-
-// Macro for binary operation functions
-#define BIN_OPERATION_PATTERN(Operator, MSPrimitive)                               \
-  template <typename T, typename T2>                                              \
-  inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) {   \
-    return PBinOperation<T, T2>(MSPrimitive, x.get_object(), y.get_object());     \
-  }
-
-// Arithmetic operations
-BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
-BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);
-
-// Macros for match and replace
-#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \
-  if ((CaptureNode).TryCapture(OrigNode)) {               \
-    return (ReplaceWith).GetNode(OrigNode);               \
-  }
-
-#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \
-  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {              \
-    return (ReplaceWith).GetNode(OrigNode);                             \
-  }
-
-#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \
-  if ((CaptureNode).TryCapture(OrigNode)) {                                            \
-    if ((Condition)) {                                                                 \
-      return (ReplaceWith).GetNode(OrigNode);                                          \
-    }                                                                                  \
-    return (ElseNode).GetNode(OrigNode);                                               \
-  }
-
-#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \
-  if ((CaptureNode).TryCapture(OrigNode)) {                 \
-    return (Lambda)();                                      \
-  }
-
-#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \
-  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {                \
-    return (Lambda)();                                                    \
-  }
-
-}  // namespace mindspore
-
-#endif  // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
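For context, this is roughly how optimizer rewrites consumed the header above. The simplifier function and the IsOne predicate here are hypothetical; the pattern API is the one just defined:

```cpp
// Sketch: fold mul(x, one) back to x when the second input is a constant one.
AnfNodePtr TrySimplifyMul(const AnfNodePtr &node) {
  PatternNode<AnfNodePtr> x;
  PatternNode<AnfNodePtr> one;
  // Matches a CNode whose first input is kPrimMul and captures its two arguments.
  PPrimitive<PatternNode<AnfNodePtr>, PatternNode<AnfNodePtr>> pattern(prim::kPrimMul, x, one);
  // Expands to: if (pattern.TryCapture(node) && IsOne(...)) return x.GetNode(node);
  MATCH_REPLACE_IF(node, pattern, x, IsOne(one.GetNode(node)));  // IsOne: hypothetical predicate
  return nullptr;  // no rewrite applies; caller keeps the original node
}
```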
diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h
deleted file mode 100644
index 2a4d689ae9..0000000000
--- a/mindspore/ccsrc/ir/primitive.h
+++ /dev/null
@@ -1,152 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_
-#define MINDSPORE_CCSRC_IR_PRIMITIVE_H_
-
-#include <unordered_map>
-#include <vector>
-#include <memory>
-#include <string>
-#include <tuple>
-
-#include "ir/dtype/type.h"
-#include "abstract/abstract_value.h"
-#include "parallel/ops_info/operator_info.h"
-#include "utils/base_ref_extends.h"
-
-namespace mindspore {
-// Supported meta type
-enum PrimType {
-  kPrimTypeUnknown = 0,
-  kPrimTypeBegin = kTypeUnknown,
-  kPrimTypeBuiltIn,        // Built-in primitive operator
-  kPrimTypePyInferShape,   // Primitive operator defined by custom
-  kPrimTypePyInferTensor,  // Primitive operator defined by custom
-  kPrimTypeUserCustom
-};
-
-class Primitive : public Named {
- public:
-  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
-      : Named(name),
-        is_base_(is_base),
-        has_signature_(false),
-        prim_type_(prim_type),
-        record_evaluate_add_attr_(false) {}
-
-  Primitive(const Primitive &prim)
-      : Named(prim),
-        attrs_(prim.attrs_),
-        instance_name_(prim.instance_name_),
-        is_base_(prim.is_base_),
-        has_signature_(prim.has_signature_),
-        prim_type_(prim.prim_type_),
-        record_evaluate_add_attr_(false) {}
-
-  MS_DECLARE_PARENT(Primitive, Named);
-
-  abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
-  std::string ToString() const override { return name(); }
-  void BeginRecordAddAttr() {
-    evaluate_added_attrs_.clear();
-    record_evaluate_add_attr_ = true;
-  }
-  void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
-  Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
-    attrs_[name] = attr;
-    if (record_evaluate_add_attr_) {
-      evaluate_added_attrs_[name] = attr;
-    }
-    return *this;
-  }
-
-  Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
-    for (auto &attr : attrs) {
-      attrs_[attr.first] = attr.second;
-    }
-    return *this;
-  }
-
-  void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
-  void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
-
-  ValuePtr GetAttr(const std::string &attrName) const {
-    auto iter = attrs_.find(attrName);
-    return iter == attrs_.cend() ? nullptr : iter->second;
-  }
-
-  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
-  const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
-
-  // if Primitive has any attribute; Primitives like scalar_add, return, etc. don't have any attribute.
-  bool HasAttr() const { return !attrs_.empty(); }
-  bool HasAttr(const std::string &attrName) const {
-    auto iter = attrs_.find(attrName);
-    return !(iter == attrs_.cend());
-  }
-  void set_prim_type(const PrimType t) { prim_type_ = t; }
-  void set_instance_name(const std::string s) { instance_name_ = s; }
-  bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
-  bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
-  bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
-
-  PrimType prim_type() const { return prim_type_; }
-  std::string instance_name() const { return instance_name_; }
-  std::string GetAttrsText() const;
-  bool operator==(const Value &other) const override;
-  bool operator==(const Primitive &other) const;
-  ~Primitive() override = default;
-
-  void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
-  bool has_signature() const { return has_signature_; }
-  bool is_base() const { return is_base_; }
-  virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call an empty function!"; }
-  virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call an empty function!"; }
-
- protected:
-  std::unordered_map<std::string, ValuePtr> attrs_;
-  std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
-
- private:
-  std::string instance_name_;
-  bool is_base_;
-  bool has_signature_;
-  PrimType prim_type_;
-  bool record_evaluate_add_attr_;
-};
-
-inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
-  os << *p;
-  return os;
-}
-
-struct PrimitiveEqual {
-  bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
-    MS_EXCEPTION_IF_NULL(t1);
-    MS_EXCEPTION_IF_NULL(t2);
-    return t1->name() == t2->name();
-  }
-};
-
-struct PrimitiveHasher {
-  std::size_t operator()(PrimitivePtr const &prim) const {
-    MS_EXCEPTION_IF_NULL(prim);
-    return prim->Hash();
-  }
-};
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_IR_PRIMITIVE_H_
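A small usage sketch of the attribute interface declared above (the operator name and attribute are arbitrary examples; MakeValue is MindSpore's existing ValuePtr helper):

```cpp
// Assumes ir/primitive.h and ir/value.h from the tree above.
void AttrRoundTrip() {
  auto prim = std::make_shared<Primitive>("HypotheticalOp");
  prim->AddAttr("stride", MakeValue(1));  // stored in attrs_, recorded if BeginRecordAddAttr() is active
  if (prim->HasAttr("stride")) {
    ValuePtr v = prim->GetAttr("stride");  // returns nullptr when the key is absent
    (void)v;
  }
  prim->EraseAttr("stride");
}
```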
diff --git a/mindspore/ccsrc/ir/primitive_extends.cc b/mindspore/ccsrc/ir/primitive_extends.cc
deleted file mode 100644
index 9df46920bf..0000000000
--- a/mindspore/ccsrc/ir/primitive_extends.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/primitive.h"
-#include "pipeline/static_analysis/abstract_function.h"
-
-namespace mindspore {
-abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
-  auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
-  return prim_func;
-}
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/primitive_py.cc b/mindspore/ccsrc/ir/primitive_py.cc
deleted file mode 100644
index b672f470c9..0000000000
--- a/mindspore/ccsrc/ir/primitive_py.cc
+++ /dev/null
@@ -1,195 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/primitive_py.h"
-#include <map>
-#include <utility>
-#include "ir/signature.h"
-#include "operator/ops.h"
-#include "./common.h"
-#include "pipeline/parse/python_adapter.h"
-#include "pipeline/parse/data_converter.h"
-#include "pybind11/pytypes.h"
-#include "utils/convert_utils_base.h"
-#include "utils/primitive_utils.h"
-#include "utils/base_ref_py.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"
-
-namespace mindspore {
-namespace {
-constexpr auto kBpropAttrName = "bprop";
-constexpr auto kCellHookAttrName = "cell_hook";
-constexpr auto kCellIDAttrName = "cell_id";
-void SyncData(const py::object &arg) {
-  if (py::isinstance<py::tuple>(arg)) {
-    py::tuple arg_list = py::cast<py::tuple>(arg);
-    for (size_t i = 0; i < arg_list.size(); i++) {
-      SyncData(arg_list[i]);
-    }
-  }
-  if (py::isinstance<tensor::Tensor>(arg)) {
-    auto tensor = py::cast<tensor::TensorPtr>(arg);
-    (void)tensor->data_sync();
-  }
-}
-}  // namespace
-std::map<std::string, py::object> PrimitivePy::hook_grad_;
-static ValuePtr PyArgToValue(const py::object &arg) {
-  if (py::isinstance<SignatureEnumKind>(arg) &&
-      py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
-    return nullptr;
-  }
-  return parse::data_converter::PyDataToValue(arg);
-}
-
-void PrimitivePy::set_signatures(
-  std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
-    signatures) {
-  signatures_.clear();
-  for (auto &signature : signatures) {
-    auto [name, rw, kind, arg_default, dtype] = signature;
-    auto default_value = PyArgToValue(arg_default);
-    signatures_.emplace_back(name, rw, kind, default_value, dtype);
-  }
-  set_has_signature(true);
-}
-
-py::function PrimitivePy::GetBpropFunction() {
-  static const char *const get_bprop_func_name = "get_bprop";
-  if (py::hasattr(python_obj_, get_bprop_func_name)) {
-    py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
-    return fn;
-  } else {
-    auto fn = GetBpropFunctionByObj(python_obj_);
-    return fn;
-  }
-}
-
-BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
-  auto py_args = py::tuple(args.size());
-  size_t i = 0;
-  for (auto &arg : args) {
-    py_args[i] = BaseRefToPyData(arg);
-    MS_LOG(DEBUG) << "arg:" << i << ":";
-    i++;
-  }
-  py::object obj;
-  bool is_bprop = this->HasAttr(kBpropAttrName);
-  if (is_bprop) {
-    SyncData(py_args);
-    obj = hook_(*py_args);
-    return std::make_shared<PyObjectRef>(obj);
-  }
-  SyncData(py_args[2]);
-  bool is_cell = this->HasAttr(kCellHookAttrName);
-  if (is_cell) {
-    auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
-    auto iter = hook_grad_.find(cell_id);
-    if (iter != hook_grad_.end()) {
-      auto hook_args = py::tuple(3);
-      hook_args[0] = cell_id;
-      hook_args[1] = py::make_tuple(iter->second);
-      hook_args[2] = py::make_tuple(py_args[2]);
-      obj = hook_(*hook_args);
-      if (py::isinstance<py::none>(obj)) {
-        obj = py_args[2];
-      }
-      hook_grad_.erase(cell_id);
-    } else {
-      hook_grad_[cell_id] = py_args[2];
-      obj = py_args[2];
-    }
-  } else {
-    // Hook operator for execute variable hook function
-    obj = hook_(py::make_tuple(py_args[2]));
-    if (py::isinstance<py::none>(obj)) {
-      obj = py_args[2];
-    }
-  }
-  obj = py::make_tuple(obj);
-  return std::make_shared<PyObjectRef>(obj);
-}
-
-py::function PrimitivePy::GetComputeFunction() {
-  static const char *const compute_func_name = "vm_impl";
-
-  if (py::hasattr(python_obj_, compute_func_name)) {
-    MS_LOG(INFO) << name() << " compute_func_name";
-    py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
-    return fn;
-  }
-
-  static const std::string vm_module = "mindspore.ops.vm_impl_registry";
-  static const std::string get_vm_impl_fn = "get_vm_impl_fn";
-  MS_LOG(INFO) << name() << ": get_vm_impl_fn";
-  py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
-  py::function vm_fn = get_fn(python_obj_);
-
-  if (py::isinstance<py::none>(vm_fn)) {
-    MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
-    vm_fn = mindspore::GetComputeFunction(Primitive::name());
-  }
-  return vm_fn;
-}
-
-void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
-  std::string attr_name = name;
-  ValuePtr converted_ret = nullptr;
-  if (py::isinstance<py::module>(obj)) {
-    MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
-  }
-  bool converted = parse::ConvertData(obj, &converted_ret);
-  if (!converted) {
-    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
-  }
-  (void)this->AddAttr(attr_name, converted_ret);
-}
-
-py::dict PrimitivePy::GetAttrDict() {
-  py::dict attr_dict;
-  for (auto &attr : attrs_) {
-    attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
-  }
-  return attr_dict;
-}
-
-void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (!primitive->isa<PrimitivePy>()) {
-    MS_LOG(EXCEPTION) << "Cannot copy a primitive which is not a python primitive hook function to python primitive!";
-  }
-  auto primitive_py = primitive->cast<PrimitivePyPtr>();
-  MS_EXCEPTION_IF_NULL(primitive_py);
-  this->set_hook(primitive_py->hook());
-}
-
-REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
-                         (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
-                           .value("unknown", PrimType::kPrimTypeUnknown)
-                           .value("builtin", PrimType::kPrimTypeBuiltIn)
-                           .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
-                           .value("user_custom", PrimType::kPrimTypeUserCustom);
-                         (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
-                           .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
-                           .def(py::init<py::str, py::object>())
-                           .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
-                           .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
-                           .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
-                           .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
-                           .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
-                           .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
-                       }));
-}  // namespace mindspore
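The cell-hook branch of RunHookFunction above is a two-phase exchange keyed by cell id: the first call stores the incoming gradient and passes it through; the second call pairs it with the stored one, invokes the Python hook, and clears the slot. A stripped-down sketch of that state machine (types simplified to std::string purely for illustration):

```cpp
#include <map>
#include <optional>
#include <string>
#include <utility>

std::map<std::string, std::string> pending;  // mirrors the role of PrimitivePy::hook_grad_

// Returns both gradients once the second one for `cell_id` arrives.
std::optional<std::pair<std::string, std::string>> Exchange(const std::string &cell_id, const std::string &grad) {
  auto it = pending.find(cell_id);
  if (it == pending.end()) {
    pending[cell_id] = grad;  // first phase: remember the gradient and let it flow through
    return std::nullopt;
  }
  auto stored = it->second;  // second phase: consume the stored gradient
  pending.erase(it);
  return std::make_pair(stored, grad);
}
```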
diff --git a/mindspore/ccsrc/ir/primitive_py.h b/mindspore/ccsrc/ir/primitive_py.h
deleted file mode 100644
index 7dc26d1561..0000000000
--- a/mindspore/ccsrc/ir/primitive_py.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
-#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <tuple>
-#include <utility>
-#include <vector>
-
-#include "abstract/abstract_value.h"
-#include "utils/misc.h"
-#include "pybind11/pybind11.h"
-#include "utils/log_adapter.h"
-#include "ir/primitive.h"
-#include "ir/signature.h"
-#include "parallel/ops_info/operator_info.h"
-
-namespace py = pybind11;
-namespace mindspore {
-class PrimitivePy : public Primitive {
- public:
-  PrimitivePy(const py::str &name, const py::object &python_obj)
-      : Primitive(name, false), python_obj_(python_obj), signatures_() {}
-  ~PrimitivePy() override = default;
-  MS_DECLARE_PARENT(PrimitivePy, Primitive);
-  py::function GetBpropFunction();
-  py::function GetComputeFunction();
-
-  void set_signatures(
-    std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
-      signatures);
-
-  const std::vector<Signature> &signatures() const { return signatures_; }
-
-  void CopyHookFunction(const PrimitivePtr &primitive) override;
-
-  void AddPyAttr(const py::str &name, const py::object &obj);
-
-  py::dict GetAttrDict();
-  void set_hook(const py::function &hook) { hook_ = hook; }
-  py::function hook() const { return hook_; }
-  BaseRef RunHookFunction(const VectorRef &args) const override;
-  const bool parse_info_ = true;
-  const py::object &GetPyObj() const { return python_obj_; }
-  bool is_tuple_input_ = false;
-
- private:
-  py::object python_obj_;
-  py::function hook_;
-  std::vector<Signature> signatures_;
-  static std::map<std::string, py::object> hook_grad_;
-};
-
-using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
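PrimitivePy above is the C++ face of a Python-defined operator: it owns the originating py::object and resolves methods such as bprop and vm_impl on it lazily. A generic sketch of this ownership pattern (standalone pybind11, hypothetical names):

```cpp
#include <utility>
#include <pybind11/pybind11.h>
namespace py = pybind11;

class PyBacked {
 public:
  explicit PyBacked(py::object obj) : obj_(std::move(obj)) {}
  // Look up an attribute on the wrapped Python object on demand.
  py::object Method(const char *name) const { return obj_.attr(name); }

 private:
  py::object obj_;  // keeps the Python side alive as long as the C++ wrapper lives
};
```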
diff --git a/mindspore/ccsrc/ir/signature_py.cc b/mindspore/ccsrc/ir/signature_py.cc
deleted file mode 100644
index 2b01b3e579..0000000000
--- a/mindspore/ccsrc/ir/signature_py.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/signature.h"
-#include "pybind11/operators.h"
-#include "pybind_api/api_register.h"
-#include "pipeline/parse/data_converter.h"
-
-namespace py = pybind11;
-
-namespace mindspore {
-// Bind SignatureEnumRW as a python class.
-REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
-                         (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
-                           .value("RW_READ", SignatureEnumRW::kRWRead)
-                           .value("RW_WRITE", SignatureEnumRW::kRWWrite)
-                           .value("RW_REF", SignatureEnumRW::kRWRef)
-                           .value("RW_EMPTY_DEFAULT_VALUE", SignatureEnumRW::kRWEmptyDefaultValue);
-                         (void)py::enum_<SignatureEnumKind>(*m, "signature_kind", py::arithmetic())
-                           .value("KIND_POSITIONAL_KEYWORD", SignatureEnumKind::kKindPositionalKeyword)
-                           .value("KIND_VAR_POSITIONAL", SignatureEnumKind::kKindVarPositional)
-                           .value("KIND_KEYWORD_ONLY", SignatureEnumKind::kKindKeywordOnly)
-                           .value("KIND_VAR_KEYWARD", SignatureEnumKind::kKindVarKeyword)
-                           .value("KIND_EMPTY_DEFAULT_VALUE", SignatureEnumKind::kKindEmptyDefaultValue);
-                         (void)py::enum_<SignatureEnumDType>(*m, "signature_dtype", py::arithmetic())
-                           .value("T", SignatureEnumDType::kDType)
-                           .value("T1", SignatureEnumDType::kDType1)
-                           .value("T2", SignatureEnumDType::kDType2)
-                           .value("T3", SignatureEnumDType::kDType3)
-                           .value("T4", SignatureEnumDType::kDType4)
-                           .value("T5", SignatureEnumDType::kDType5)
-                           .value("T6", SignatureEnumDType::kDType6)
-                           .value("T7", SignatureEnumDType::kDType7)
-                           .value("T8", SignatureEnumDType::kDType8)
-                           .value("T9", SignatureEnumDType::kDType9)
-                           .value("T_EMPTY_DEFAULT_VALUE", SignatureEnumDType::kDTypeEmptyDefaultValue);
-                       }));
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc
deleted file mode 100644
index b2a2f38915..0000000000
--- a/mindspore/ccsrc/ir/tensor.cc
+++ /dev/null
@@ -1,506 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/tensor.h"
-
-#include <atomic>
-#include <functional>
-#include <numeric>
-#include <vector>
-#include <sstream>
-#include <string>
-#include <utility>
-#include <iomanip>
-#include <algorithm>
-#include <type_traits>
-#include <typeinfo>
-
-#include "device/device_address.h"
-#include "abstract/abstract_value.h"
-
-namespace mindspore {
-namespace tensor {
-constexpr auto kEllipsis = "...";
-constexpr auto kThreshold = 6;
-
-constexpr auto kThreshold1DFloat = kThreshold * 2;
-constexpr auto kThreshold1DInt = kThreshold * 4;
-constexpr auto kThreshold1DBool = kThreshold * 2;
-
-static std::string MakeId() {
-  // Use atomic to make id generator thread safe.
-  static std::atomic<uint64_t> last_id{1};
-  return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
-}
-
-static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
-  return data_type ? data_type->type_id() : defaultTypeId;
-}
-
-static size_t SizeOf(const std::vector<int> &shape) {
-  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
-}
-
-template <typename T>
-std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) {
-  const size_t count = SizeOf(shape);
-  switch (data_type) {
-    case kNumberTypeBool:
-    case kNumberTypeUInt8: {
-      auto buf = static_cast<uint8_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeInt8: {
-      auto buf = static_cast<int8_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeInt16: {
-      auto buf = static_cast<int16_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeInt32: {
-      auto buf = static_cast<int32_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeInt64: {
-      auto buf = static_cast<int64_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeUInt16: {
-      auto buf = static_cast<uint16_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeUInt32: {
-      auto buf = static_cast<uint32_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeUInt64: {
-      auto buf = static_cast<uint64_t *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeFloat16: {
-      auto buf = static_cast<float16 *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeFloat32: {
-      const float *buf = static_cast<float *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    case kNumberTypeFloat64: {
-      auto buf = static_cast<double *>(data);
-      return std::vector<T>(buf, buf + count);
-    }
-    default:
-      break;
-  }
-  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
-}
-
-template <typename T>
-std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_len) {
-  size_t size = SizeOf(shape);
-  if (size * sizeof(T) != data_len) {
-    MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
-                      << " item size " << sizeof(T);
-  }
-  auto buf = static_cast<T *>(data);
-  return {buf, buf + size};
-}
-
-// Tensor data implementation.
-template <typename T>
-class TensorDataImpl : public TensorData {
- public:
-  explicit TensorDataImpl(const std::vector<int> &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
-  ~TensorDataImpl() = default;
-
-  TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len)
-      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
-
-  TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type)
-      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}
-
-  template <typename InputIt>
-  TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last)
-      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {}
-
-  template <typename Scalar>
-  TensorDataImpl(const std::vector<int> &shape, Scalar scalar)
-      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast<T>(scalar)}) {}
-
-  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
-
-  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
-
-  ssize_t nbytes() const override { return size() * itemsize(); }
-
-  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
-
-  void *data() override {
-    static std::vector<T> empty_data(1);
-    if (data_size_ == 0) {
-      // Prevent null pointer for empty shape.
-      return empty_data.data();
-    }
-    // Lazy allocation.
-    if (data_.empty()) {
-      data_.resize(data_size_);
-    }
-    return data_.data();
-  }
-
-  bool equals(const TensorData &other) const override {
-    auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
-    if (ptr) {
-      return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_));
-    }
-    return false;
-  }
-
-  std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
-    constexpr auto valid =
-      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
-      std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
-      std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
-      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
-    static_assert(valid, "Type is invalid");
-    if (data_size_ == 0) {
-      return "";
-    }
-    if (data_.empty()) {
-      return "<uninitialized>";
-    }
-
-    std::ostringstream ss;
-    ssize_t cursor = 0;
-    SummaryStringRecursive(ss, type, shape, &cursor, 0);
-    return ss.str();
-  }
-
- private:
-  void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
-    int linefeedThreshold;
-    constexpr auto isFloat =
-      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
-    for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
-      const auto value = data_[cursor + i];
-      if constexpr (isFloat) {
-        ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
-           << value;
-        linefeedThreshold = kThreshold1DFloat;
-      } else if (type == kNumberTypeBool) {
-        ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
-        linefeedThreshold = kThreshold1DBool;
-      } else {
-        constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
-                                  std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
-        if constexpr (isSigned) {
-          if (static_cast<int64_t>(value) >= 0) {
-            ss << ' ';
-          }
-        }
-        if constexpr (std::is_same<T, int8_t>::value) {
-          ss << static_cast<int16_t>(value);
-        } else if constexpr (std::is_same<T, uint8_t>::value) {
-          ss << static_cast<uint16_t>(value);
-        } else {
-          ss << value;
-        }
-        linefeedThreshold = kThreshold1DInt;
-      }
-      if (i != end - 1) {
-        ss << ' ';
-      }
-      if (ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {  // Add a line feed every {threshold of type} for 1D tensor.
-        ss << '\n' << ' ';
-      }
-    }
-  }
-
-  void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector<int> &shape,
-                              ssize_t *cursor, ssize_t depth) const {
-    if (depth >= static_cast<ssize_t>(ndim_)) {
-      return;
-    }
-    ss << '[';
-    if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
-      ssize_t num = shape[depth];
-      if (num > kThreshold && ndim_ > 1) {
-        OutputDataString(ss, type, *cursor, 0, kThreshold / 2);
-        ss << ' ' << kEllipsis << ' ';
-        OutputDataString(ss, type, *cursor, num - kThreshold / 2, num);
-      } else {
-        OutputDataString(ss, type, *cursor, 0, num);
-      }
-      *cursor += num;
-    } else {  // Middle dimension
-      ssize_t num = shape[depth];
-      // Handle the first half.
-      for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
-        if (i > 0) {
-          ss << '\n';
-          ss << std::setw(depth + 1) << ' ';  // Add the indent.
-        }
-        SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
-      }
-      // Handle the ignored part.
-      if (num > kThreshold) {
-        ss << '\n';
-        ss << std::setw(depth + 1) << ' ';  // Add the indent.
-        ss << kEllipsis;
-        // Ignored at this layer.
-        ssize_t ignored = shape[depth + 1];
-        for (ssize_t i = depth + 2; i < static_cast<ssize_t>(ndim_); i++) {
-          ignored *= shape[i];
-        }
-        // Multiple with ignored layers number.
-        ignored *= num - kThreshold;
-
-        *cursor += ignored;
-      }
-      // Handle the second half.
-      if (num > kThreshold / 2) {
-        for (ssize_t i = num - kThreshold / 2; i < num; i++) {
-          ss << '\n';
-          ss << std::setw(depth + 1) << ' ';  // Add the indent.
-          SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
-        }
-      }
-    }
-    ss << ']';
-  }
-
-  size_t ndim_{0};
-  size_t data_size_{0};
-  std::vector<T> data_;
-};
-
-template <typename... Args>
-TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, const Args... args) {
-  switch (data_type) {
-    case kNumberTypeBool:
-    case kNumberTypeUInt8:
-      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
-    case kNumberTypeInt8:
-      return std::make_shared<TensorDataImpl<int8_t>>(shape, args...);
-    case kNumberTypeInt16:
-      return std::make_shared<TensorDataImpl<int16_t>>(shape, args...);
-    case kNumberTypeInt32:
-      return std::make_shared<TensorDataImpl<int32_t>>(shape, args...);
-    case kNumberTypeInt64:
-      return std::make_shared<TensorDataImpl<int64_t>>(shape, args...);
-    case kNumberTypeUInt16:
-      return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...);
-    case kNumberTypeUInt32:
-      return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...);
-    case kNumberTypeUInt64:
-      return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...);
-    case kNumberTypeFloat16:
-      return std::make_shared<TensorDataImpl<float16>>(shape, args...);
-    case kNumberTypeFloat32:
-      return std::make_shared<TensorDataImpl<float>>(shape, args...);
-    case kNumberTypeFloat64:
-      return std::make_shared<TensorDataImpl<double>>(shape, args...);
-    default:
-      break;
-  }
-  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
-}
-
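MakeTensorData above is the single dispatch point from a runtime TypeId to the statically typed TensorDataImpl<T>; everything downstream holds the result behind the type-erased TensorDataPtr. For instance (shape values arbitrary; a sketch, not code from the file):

```cpp
// Both calls return TensorDataPtr, but the concrete objects differ:
auto f32 = MakeTensorData(kNumberTypeFloat32, std::vector<int>{2, 3});  // TensorDataImpl<float>
auto i8 = MakeTensorData(kNumberTypeInt8, std::vector<int>{2, 3});      // TensorDataImpl<int8_t>
```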
-Tensor::Tensor(const Tensor &tensor)
-    : MetaTensor(tensor),
-      init_flag_(tensor.init_flag_),
-      data_(tensor.data_),
-      dirty_(tensor.dirty_),
-      id_(tensor.id_),
-      device_address_(tensor.device_address_) {}
-
-Tensor::Tensor(const Tensor &tensor, TypeId data_type)
-    : MetaTensor(data_type, tensor.shape_),
-      init_flag_(tensor.init_flag_),
-      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
-      dirty_(tensor.dirty_),
-      id_(tensor.id_),
-      device_address_(tensor.device_address_) {}
-
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
-    : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
-
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape)
-    : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}
-
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len)
-    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}
-
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type)
-    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}
-
-Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
-    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
-      data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())),
-      id_(MakeId()) {}
-
-Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
-    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
-      data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())),
-      id_(MakeId()) {}
-
-Tensor::Tensor(int64_t input, const TypePtr &data_type)
-    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
-      data_(MakeTensorData(data_type_, {}, input)),
-      id_(MakeId()) {}
-
-Tensor::Tensor(double input, const TypePtr &data_type)
-    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}),
-      data_(MakeTensorData(data_type_, {}, input)),
-      id_(MakeId()) {}
-
-bool Tensor::operator==(const Tensor &tensor) const {
-  return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
-}
-
-bool Tensor::ValueEqual(const Tensor &tensor) const {
-  return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
-}
-// assign value to this tensor
-Tensor &Tensor::AssignValue(const Tensor &tensor) {
-  if (this != &tensor) {
-    MetaTensor::operator=(tensor);
-    dirty_ = tensor.is_dirty();
-    device_address_ = tensor.device_address();
-    data_ = tensor.data_;
-    id_ = tensor.id();
-  }
-  return *this;
-}
-abstract::AbstractBasePtr Tensor::ToAbstract() {
-  auto tens = shared_from_base<Tensor>();
-  auto dtype = tens->Dtype();
-  if (!IsSubType(dtype, kNumber)) {
-    MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << ".";
-  }
-  auto tensor_shape = tens->shape();
-  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
-  abs_tensor->set_value(shared_from_base<Tensor>());
-  return abs_tensor;
-}
-
-std::string Tensor::GetShapeAndDataTypeInfo() const {
-  std::ostringstream buf;
-  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
-  return buf.str();
-}
-
-std::string Tensor::ToString() const {
-  const int small_tensor_size = 30;
-  std::ostringstream buf;
-  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
-  // only print small tensor
-  if (DataSize() < small_tensor_size) {
-    buf << ", value:" << data().ToString(data_type_, shape());
-  }
-  return buf.str();
-}
-
-std::string Tensor::ToStringRepr() const {
-  std::ostringstream buf;
-  auto type_ptr = this->Dtype();
-  MS_EXCEPTION_IF_NULL(type_ptr);
-  buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString();
-  buf << "\nvalue:" << data().ToString(data_type_, shape());
-  return buf.str();
-}
-
-void Tensor::data_sync() const {
-  if (device_address_ != nullptr) {
-    if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
-      MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy.";
-    }
-  }
-}
-
-TypeId Tensor::set_data_type(const TypeId data_type) {
-  if (data_type != data_type_) {
-    data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_);
-    return MetaTensor::set_data_type(data_type);
-  }
-  return data_type;
-}
-}  // namespace tensor
-
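Worth noting from the implementations above: operator== compares the shared data pointer (identity), while ValueEqual compares elements through TensorData::equals. A sketch of the difference (expected results follow from the constructors above):

```cpp
Tensor a(kNumberTypeFloat32, std::vector<int>{2});
Tensor b(kNumberTypeFloat32, std::vector<int>{2});
// a == b is false: distinct TensorData objects, even with identical contents.
// a.ValueEqual(b) is true: same type, shape and element values.
Tensor c(a);  // the copy constructor shares data_
// a == c is true: same underlying TensorData pointer.
```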
-namespace inference {
-MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
-  return new Tensor(data_type, shape);
-}
-
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
-  this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
-}
-
-Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
-
-TypeId Tensor::data_type() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->data_type();
-}
-
-TypeId Tensor::set_data_type(TypeId data_type) {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->set_data_type(data_type);
-}
-
-std::vector<int> Tensor::shape() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->shape();
-}
-
-size_t Tensor::set_shape(const std::vector<int> &shape) {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->set_shape(shape);
-}
-
-int Tensor::DimensionSize(size_t index) const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->DimensionSize(index);
-}
-
-int Tensor::ElementsNum() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->ElementsNum();
-}
-
-std::size_t Tensor::hash() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->hash();
-}
-
-std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_;
-}
-
-size_t Tensor::Size() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->data().nbytes();
-}
-
-void *Tensor::MutableData() const {
-  MS_ASSERT(this->tensor_impl_ != nullptr);
-  return this->tensor_impl_->data_c();
-}
-
-}  // namespace inference
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/tensor.h b/mindspore/ccsrc/ir/tensor.h
deleted file mode 100644
index 8230780d02..0000000000
--- a/mindspore/ccsrc/ir/tensor.h
+++ /dev/null
@@ -1,278 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_IR_TENSOR_H_
-#define MINDSPORE_CCSRC_IR_TENSOR_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include <numeric>
-
-#include "Eigen/Core"
-#include "device/device_address.h"
-#include "ir/meta_tensor.h"
-#include "include/ms_tensor.h"
-#include "utils/log_adapter.h"
-
-using float16 = Eigen::half;
-
-using mindspore::device::DeviceAddress;
-using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;
-// brief mindspore namespace.
-//
-// mindspore namespace is the top level namespace of MindSpore project.
-// Other namespace should be a sub namespace of mindspore namespace in the ME project.
-namespace mindspore {
-// brief mindspore::tensor namespace
-//
-// A sub namespace in ME to support tensor related definition.
-namespace tensor {
-// Tensor data interface.
-class TensorData {
- public:
-  /// Total number of elements.
-  virtual ssize_t size() const = 0;
-  /// Byte size of a single element.
-  virtual ssize_t itemsize() const = 0;
-  /// Total number of bytes.
-  virtual ssize_t nbytes() const = 0;
-  /// Number of dimensions.
-  virtual ssize_t ndim() const = 0;
-  /// Data pointer.
-  virtual void *data() = 0;
-  /// Is data equals.
-  virtual bool equals(const TensorData &other) const = 0;
-  /// To string.
-  virtual std::string ToString(const TypeId type, const std::vector<int> &shape) const = 0;
-};
-
-using TensorDataPtr = std::shared_ptr<TensorData>;
-
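The interface above is the seam that lets Tensor stay agnostic about storage. A minimal conforming implementation sketch (hypothetical; not the TensorDataImpl shipped in tensor.cc):

```cpp
class ZeroData : public TensorData {  // a constant-zero float buffer, for illustration only
 public:
  explicit ZeroData(ssize_t n) : n_(n), buf_(static_cast<size_t>(n), 0.0f) {}
  ssize_t size() const override { return n_; }
  ssize_t itemsize() const override { return sizeof(float); }
  ssize_t nbytes() const override { return n_ * itemsize(); }
  ssize_t ndim() const override { return 1; }
  void *data() override { return buf_.data(); }
  bool equals(const TensorData &other) const override { return &other == this; }
  std::string ToString(const TypeId, const std::vector<int> &) const override { return "zeros"; }

 private:
  ssize_t n_;
  std::vector<float> buf_;
};
```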
-// Tensor entity class
-class Tensor : public MetaTensor {
- public:
-  abstract::AbstractBasePtr ToAbstract() override;
-
-  // brief Create tensor from another tensor, data is shared.
-  //
-  // param tensor [Tensor] The input tensor.
-  explicit Tensor(const Tensor &tensor);
-
-  // brief Create tensor with given data type from another tensor.
-  //
-  // param tensor [Tensor] The input tensor.
-  // param data_type [TypeId] The new tensor data type.
-  Tensor(const Tensor &tensor, TypeId data_type);
-
-  // brief Create tensor with the given shared tensor data.
-  //
-  // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector<int> of the tensor.
-  // param data The shared tensor data.
-  Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data);
-
-  // brief Create an all zero tensor.
-  //
-  // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector<int> of the tensor.
-  Tensor(TypeId data_type, const std::vector<int> &shape);
-
-  // brief Create a tensor with input data buffer.
-  //
-  // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector<int> of the tensor.
-  // param data The input data to be copied into tensor.
-  // param data_len The length of data in bytes.
-  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len);
-
-  // brief Create a tensor with input data buffer and given source data type.
-  //
-  // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector<int> of the tensor.
-  // param data The input data to be copied into tensor.
-  // param src_data_type The source data type.
-  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type);
-
-  // brief Create 1 dimension tensor from an int vector.
-  //
-  // param input [std::vector<int64_t>] the data for tensor
-  // param data_type [TypeId] data type
-  explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
-
-  // brief Create 1 dimension tensor from a float vector.
-  //
-  // param input [std::vector<double>] the data for tensor
-  // param data_type [TypeId] data type
-  explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
-
-  // brief Create 0 dimension tensor from an int scalar.
-  //
-  // param input [int64] the data for tensor
-  // param data_type [TypeId] data type
-  explicit Tensor(int64_t input, const TypePtr &data_type = nullptr);
-
-  // brief Create 0 dimension tensor from a float scalar.
-  //
-  // param input [double] the data for tensor
-  // param data_type [TypeId] data type
-  explicit Tensor(double input, const TypePtr &data_type = nullptr);
-
-  ~Tensor() override = default;
-
-  MS_DECLARE_PARENT(Tensor, MetaTensor);
-
-  // brief Compares two Tensor objects.
-  //
-  // Compare two tensor objects to see if they have same data type, shape and data address.
-  //
-  // param tensor The Tensor object to be compared.
-  // return true: If having same type, shape and data address, return true, or return false.
-  bool operator==(const Tensor &tensor) const;
-
-  // It is different from 'operator==' which just compares shape/type/address;
-  // it does real value comparison.
-  bool ValueEqual(const Tensor &tensor) const;
-
-  // assign value to this tensor
-  Tensor &AssignValue(const Tensor &tensor);
-
-  bool operator==(const Value &other) const override {
-    if (other.isa<Tensor>()) {
-      auto &other_ = static_cast<const Tensor &>(other);
-      return *this == other_;
-    }
-    return false;
-  }
-
-  // brief Gets tensor's dimension
-  //
-  // return The number of dimensions of the tensor data.
-  int DataDim() const { return static_cast<int>(data().ndim()); }
-
-  // brief Getting tensor data size
-  //
-  // return The total number of elements of the tensor data.
-  int DataSize() const { return static_cast<int>(data().size()); }
-
-  // brief Get the data type of the tensor for C++
-  //
-  // return [int] The tensor's data type will be cast to int to return.
-  int data_type_c() const { return static_cast<int>(data_type_); }
-
-  // brief Get the tensor's shape for C++
-  //
-  // return [std::vector<int>]
-  std::vector<int> shape_c(void) const { return shape(); }
-
-  // brief Get Tensor data pointer for c++ type
-  //
-  // return The pointer to the object
-  void *data_c() { return data().data(); }
-
-  // brief Get Tensor data byte-size for c++ type
-  //
-  // return byte size of Tensor data
-  size_t Size() const { return data().nbytes(); }
-
-  void *data_c() const { return data_->data(); }
-
-  // brief Sync data with device.
-  void data_sync() const;
-
-  // brief Get the internal data object.
-  //
-  // return The reference to internal data object.
-  TensorData &data() { return *data_; }
-
-  // brief Get the internal data shared pointer.
-  //
-  // return The reference to internal data object.
-  const TensorDataPtr &data_ptr() const { return data_; }
-
-  // brief Get the internal data object.
-  //
-  // return The reference to internal data object.
-  const TensorData &data() const { return *data_; }
-
-  TypeId set_data_type(const TypeId data_type) override;
-
-  std::string GetShapeAndDataTypeInfo() const;
-
-  std::string ToString() const override;
-
-  std::string ToStringRepr() const;
-
-  bool is_init() const { return init_flag_; }
-  void set_init_flag(bool flag) { init_flag_ = flag; }
-
-  bool is_dirty() const { return dirty_; }
-  void set_dirty(const bool dirty) { dirty_ = dirty; }
-
-  DeviceAddressPtr device_address() const { return device_address_; }
-  void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
-
-  std::string id() const { return id_; }
-
-  const bool parse_info_ = true;
-
- private:
-  bool init_flag_{false};
-  TensorDataPtr data_{nullptr};
-  bool dirty_{true};
-  std::string id_{""};
-  DeviceAddressPtr device_address_{nullptr};
-};
-using TensorPtr = std::shared_ptr<Tensor>;
-using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
-}  // namespace tensor
-
-namespace inference {
-class Tensor : public MSTensor {
- public:
-  Tensor(TypeId data_type, const std::vector<int> &shape);
-
-  explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
-
-  ~Tensor() = default;
-
-  TypeId data_type() const override;
-
-  TypeId set_data_type(const TypeId data_type) override;
-
-  std::vector<int> shape() const override;
-
-  size_t set_shape(const std::vector<int> &shape) override;
-
-  int DimensionSize(size_t index) const override;
-
-  int ElementsNum() const override;
-
-  std::size_t hash() const override;
-
-  std::shared_ptr<tensor::Tensor> tensor() const;
-
-  size_t Size() const override;
-
-  void *MutableData() const override;
-
- protected:
-  std::shared_ptr<tensor::Tensor> tensor_impl_;
-};
-}  // namespace inference
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_IR_TENSOR_H_
diff --git a/mindspore/ccsrc/ir/tensor_py.cc b/mindspore/ccsrc/ir/tensor_py.cc
deleted file mode 100644
index 25339cff5b..0000000000
--- a/mindspore/ccsrc/ir/tensor_py.cc
+++ /dev/null
@@ -1,390 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "ir/tensor_py.h"
-
-#include <functional>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "device/device_address.h"
-#include "pybind_api/api_register.h"
-#include "pybind_api/export_flags.h"
-#include "abstract/abstract_value.h"
-
-namespace mindspore {
-namespace tensor {
-
-static TypeId GetDataType(const py::buffer_info &buf) {
-  if (buf.format.size() == 1) {
-    switch (buf.format.front()) {
-      case 'e':
-      case 'f':
-      case 'd':
-        switch (buf.itemsize) {
-          case 2:
-            return TypeId::kNumberTypeFloat16;
-          case 4:
-            return TypeId::kNumberTypeFloat32;
-          case 8:
-            return TypeId::kNumberTypeFloat64;
-        }
-        break;
-      case 'b':
-      case 'h':
-      case 'i':
-      case 'l':
-      case 'q':
-        switch (buf.itemsize) {
-          case 1:
-            return TypeId::kNumberTypeInt8;
-          case 2:
-            return TypeId::kNumberTypeInt16;
-          case 4:
-            return TypeId::kNumberTypeInt32;
-          case 8:
-            return TypeId::kNumberTypeInt64;
-        }
-        break;
-      case 'B':
-      case 'H':
-      case 'I':
-      case 'L':
-      case 'Q':
-        switch (buf.itemsize) {
-          case 1:
-            return TypeId::kNumberTypeUInt8;
-          case 2:
-            return TypeId::kNumberTypeUInt16;
-          case 4:
-            return TypeId::kNumberTypeUInt32;
-          case 8:
-            return TypeId::kNumberTypeUInt64;
-        }
-        break;
-      case '?':
-        return TypeId::kNumberTypeBool;
-    }
-  }
-  MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
-  return TypeId::kTypeUnknown;
-}
-
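GetDataType above keys on the numpy buffer-protocol format character plus the item size, because several format characters alias the same width across platforms (for example, int64 may arrive as 'l' or 'q'). A usage sketch, assuming the same translation unit as the function above:

```cpp
py::array_t<float> arr({2, 3});          // float32 ndarray
py::buffer_info info = arr.request();    // info.format == "f", info.itemsize == 4
TypeId t = GetDataType(info);            // expected: kNumberTypeFloat32
```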
-static std::string GetPyTypeFormat(TypeId data_type) {
-  switch (data_type) {
-    case TypeId::kNumberTypeFloat16:
-      return "e";
-    case TypeId::kNumberTypeFloat32:
-      return py::format_descriptor<float>::format();
-    case TypeId::kNumberTypeFloat64:
-      return py::format_descriptor<double>::format();
-    case TypeId::kNumberTypeUInt8:
-      return py::format_descriptor<uint8_t>::format();
-    case TypeId::kNumberTypeUInt16:
-      return py::format_descriptor<uint16_t>::format();
-    case TypeId::kNumberTypeUInt32:
-      return py::format_descriptor<uint32_t>::format();
-    case TypeId::kNumberTypeUInt64:
-      return py::format_descriptor<uint64_t>::format();
-    case TypeId::kNumberTypeInt8:
-      return py::format_descriptor<int8_t>::format();
-    case TypeId::kNumberTypeInt16:
-      return py::format_descriptor<int16_t>::format();
-    case TypeId::kNumberTypeInt32:
-      return py::format_descriptor<int32_t>::format();
-    case TypeId::kNumberTypeInt64:
-      return py::format_descriptor<int64_t>::format();
-    case TypeId::kNumberTypeBool:
-      return py::format_descriptor<bool>::format();
-    default:
-      MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
-      return "";
-  }
-}
-
-static bool IsCContiguous(const py::array &input) {
-  auto flags = static_cast<unsigned int>(input.flags());
-  return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
-}
-
-TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
-  // Get input buffer info.
-  py::buffer_info buf = input.request();
-  // Check data types.
-  auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
-  auto buf_type = GetDataType(buf);
-  if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
-    MS_LOG(EXCEPTION) << "Unsupported tensor type!";
-  }
-  // Use buf type as data type if type_ptr not set.
-  if (data_type == TypeId::kTypeUnknown) {
-    data_type = buf_type;
-  }
-  // Convert input array to C contiguous if needed.
-  std::unique_ptr<char[]> tmp_buf;
-  if (!IsCContiguous(input)) {
-    Py_buffer pybuf;
-    if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) {
-      MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
-    }
-    tmp_buf = std::make_unique<char[]>(pybuf.len);
-    if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) {
-      MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
-    }
-    PyBuffer_Release(&pybuf);
-    buf.ptr = tmp_buf.get();
-  }
-  // Get tensor shape.
-  std::vector<int> shape(buf.shape.begin(), buf.shape.end());
-  if (data_type == buf_type) {
-    // Use memory copy if input data type is same as the required type.
-    return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
-  }
-  // Create tensor with data type converted.
-  return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
-}
-
-static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
-  std::vector<ssize_t> strides;
-  strides.reserve(shape.size());
-  const auto ndim = shape.size();
-  for (size_t i = 0; i < ndim; ++i) {
-    auto stride = item_size;
-    for (size_t j = i + 1; j < ndim; ++j) {
-      stride *= shape[j];
-    }
-    strides.push_back(stride);
-  }
-  return strides;
-}
-
-static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
-  std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
-  std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
-  return py::buffer_info{
-    tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
-}
-
-py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
-  auto &shape = tensor.shape();
-  py::tuple dims(shape.size());
-  for (size_t i = 0; i < dims.size(); ++i) {
-    dims[i] = py::int_(shape[i]);
-  }
-  return dims;
-}
-
-py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
-  tensor.data_sync();
-  auto info = GetPyBufferInfo(tensor);
-  py::object self = py::cast(&tensor);
-  return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
-}
-
-py::array TensorPy::AsNumpy(const Tensor &tensor) {
-  auto info = GetPyBufferInfo(tensor);
-  py::object self = py::cast(&tensor);
-  return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
-}
-
-static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
-  std::vector<int> shape;
-  const size_t size = tuple.size();
-  shape.reserve(tuple.size());
-  for (size_t i = 0; i < size; ++i) {
-    shape.push_back(py::int_(tuple[i]));
-  }
-  return shape;
-}
-
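GetStrides above computes dense C-order (row-major) strides in bytes: each dimension's stride is the product of all inner dimensions times the item size. A worked example (assuming the same translation unit):

```cpp
auto strides = GetStrides({2, 3}, 4);  // -> {12, 4}: rows are 12 bytes apart, elements 4
```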
-REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
-                         // Define python MetaTensor class.
-                         (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
-                           .def(py::init<TypePtr, const py::tuple>(), py::arg("dtype"), py::arg("shape"))
-                           .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
-                           .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
-                           .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
-                           .def(py::pickle(
-                             [](const MetaTensor &t) {  // __getstate__
-                               /* Return a tuple that fully encodes the state of the object */
-                               return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
-                             },
-                             [](const py::tuple &t) {  // __setstate__
-                               if (t.size() != 2) {
-                                 throw std::runtime_error("Invalid state!");
-                               }
-                               /* Create a new C++ instance */
-                               MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
-                               return tensor;
-                             }));
-                         // Define python Tensor class.
-                         // dtype should be defined before Tensor, because Tensor initialization depends on dtype.
-                         (void)py::class_<Tensor, MetaTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
-                           .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }),
-                                py::arg("input"))
-                           .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
-                                  TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
-                                  if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
-                                    return std::make_shared<Tensor>(tensor);
-                                  }
-                                  return std::make_shared<Tensor>(tensor, data_type);
-                                }),
-                                py::arg("input"), py::arg("dtype"))
-                           .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
-                                  auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
-                                  return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
-                                }),
-                                py::arg("dtype"), py::arg("shape"))
-                           .def(py::init([](const py::array &input, const TypePtr &type_ptr) {
-                                  return TensorPy::MakeTensor(input, type_ptr);
-                                }),
-                                py::arg("input"), py::arg("dtype") = nullptr)
-                           .def(py::init([](py::float_ input, const TypePtr &type_ptr) {
-                                  return TensorPy::MakeTensor(py::array(input), type_ptr);
-                                }),
-                                py::arg("input"), py::arg("dtype") = nullptr)
-                           .def(py::init([](py::int_ input, const TypePtr &type_ptr) {
-                                  return TensorPy::MakeTensor(py::array(input), type_ptr);
-                                }),
-                                py::arg("input"), py::arg("dtype") = nullptr)
-                           .def(py::init([](py::list input, const TypePtr &type_ptr) {
-                                  return TensorPy::MakeTensor(py::array(input), type_ptr);
-                                }),
-                                py::arg("input"), py::arg("dtype") = nullptr)
-                           .def(py::init([](py::tuple input, const TypePtr &type_ptr) {
-                                  return TensorPy::MakeTensor(py::array(input), type_ptr);
-                                }),
-                                py::arg("input"), py::arg("dtype") = nullptr)
-                           .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_)
-                           .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
-                           .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
-                             Get the tensor's data type.
-
-                             Returns:
-                                 type, the data type of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
-                                 >>> data.dtype
-                                 Int32
-                             )mydelimiter")
-                           .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter(
-                             Get the tensor's shape.
-
-                             Returns:
-                                 tuple[int], the shape of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((3, 3)))
-                                 >>> data.shape
-                                 (3, 3)
-                             )mydelimiter")
-                           .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
-                             Convert tensor to numpy.ndarray.
-
-                             Returns:
-                                 numpy.ndarray.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> array = data.asnumpy()
-                                 >>> array
-                                 array([[1., 1., 1.],
-                                        [1., 1., 1.]])
-                             )mydelimiter")
-                           .def("size", &Tensor::DataSize, R"mydelimiter(
-                             Get tensor's data size.
-
-                             Returns:
-                                 int, the size of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.size()
-                                 6
-                             )mydelimiter")
-                           .def("is_init", &Tensor::is_init, R"mydelimiter(
-                             Get tensor init_flag.
-
-                             Returns:
-                                 bool, whether the tensor init.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.is_init()
-                                 False
-                             )mydelimiter")
-                           .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
-                             Set tensor init_flag.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.set_init_flag(True)
-                             )mydelimiter")
-                           .def("dim", &Tensor::DataDim, R"mydelimiter(
-                             Get tensor's data dimension.
-
-                             Returns:
-                                 int, the dimension of tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((2, 3)))
-                                 >>> data.dim()
-                                 2
-                             )mydelimiter")
-                           .def("assign_value", &Tensor::AssignValue, R"mydelimiter(
-                             Assign another tensor value to this.
-
-                             Arg:
-                                 value (:class:`mindspore.tensor`): The value tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
-                                 >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
-                                 >>> data.assign_value(data2)
-                                 >>> data.shape
-                                 (2, 2)
-                             )mydelimiter")
-                           .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
-                             Set the tensor's data type.
-
-                             Arg:
-                                 dtype (:class:`mindspore.dtype`): The type of output tensor.
-
-                             Examples:
-                                 >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
-                                 >>> data.set_dtype(mindspore.int32)
-                                 mindspore.int32
-                             )mydelimiter")
-                           .def("__str__", &Tensor::ToString)
-                           .def("__repr__", &Tensor::ToStringRepr)
-                           .def(py::pickle(
-                             [](const Tensor &t) {  // __getstate__
-                               /* Return a tuple that fully encodes the state of the object */
-                               return py::make_tuple(TensorPy::AsNumpy(t));
-                             },
-                             [](const py::tuple &t) {  // __setstate__
-                               if (t.size() != 1) {
-                                 throw std::runtime_error("Invalid state!");
-                               }
-                               /* Create a new C++ instance */
-                               return TensorPy::MakeTensor(t[0].cast<py::array>());
-                             }));
-                       }));
-}  // namespace tensor
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/CMakeLists.txt b/mindspore/ccsrc/kernel/CMakeLists.txt deleted file mode 100644 index 9f460425e1..0000000000 --- a/mindspore/ccsrc/kernel/CMakeLists.txt +++ /dev/null @@ -1,66 +0,0 @@ -file(GLOB_RECURSE KERNEL_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_build_info.cc" - "kash/*.cc" - "common_utils.cc" - "oplib/*.cc" -) - -if (ENABLE_D) - file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_query.cc" - "kernel_fusion.cc" - "akg/ascend/*.cc" - "akg/akg_kernel_build.cc" - "akg/akg_kernel_attrs_process.cc" - "akg/akg_kernel_metadata.cc" - "tbe/*.cc" - "aicpu/*.cc" - "rts/*.cc" - "hccl/*.cc" - ) - add_compile_definitions(ENABLE_D) -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "cpu/*.cc" - ) - - list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/push_kernel.cc" - "cpu/ps/pull_kernel.cc" - "cpu/ps/embedding_look_up_ps_kernel.cc" - "cpu/ps/embedding_look_up_proxy_kernel.cc" - "cpu/ps/apply_momentum_ps_kernel.cc" - "cpu/ps/sparse_apply_adam_ps_kernel.cc" - "cpu/ps/sparse_apply_ftrl_ps_kernel.cc") - - if (NOT ENABLE_MPI) - list(REMOVE_ITEM CPU_SRC_LIST "cpu/allgather_cpu_kernel.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/reduce_scatter_cpu_kernel.cc") - list(REMOVE_ITEM CPU_SRC_LIST "cpu/embedding_look_up_comm_grad_cpu_kernel.cc") - endif () -endif () - -if (ENABLE_GPU) - file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "gpu/*.cu" - "akg/gpu/*.cc" - "akg/akg_kernel_build.cc" - "akg/akg_kernel_attrs_process.cc" - ) - - file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc") - list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_gpu_kernel.cc") - - if (ENABLE_MPI) - include(ExternalProject) - file(GLOB_RECURSE GPU_NCCL_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/nccl/*.cc") - list(APPEND GPU_SRC_LIST ${GPU_NCCL_LIST}) - endif () - - # add_library(_mindspore_kernel_cuda_obj OBJECT ${CUDA_SRC_LIST}) -endif() - -set_property(SOURCE ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST} - PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_KERNEL) -add_library(_mindspore_kernel_obj OBJECT ${KERNEL_SRC_LIST} ${CPU_SRC_LIST} ${GPU_SRC_LIST} ${D_SRC_LIST}) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc deleted file mode 100644 index 99e792216f..0000000000 --- 
a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ /dev/null @@ -1,312 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/aicpu/aicpu_kernel_build.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "device/kernel_runtime.h" -#include "kernel/aicpu/aicpu_kernel_mod.h" -#include "kernel/akg/akg_kernel_build.h" -#include "proto/tensor.pb.h" -#include "proto/tensor_shape.pb.h" -#include "proto/attr.pb.h" -#include "proto/node_def.pb.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "kernel/aicpu/aicpu_util.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -using FNodeAttrHandle = std::function<void(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto)>; - -bool SetIOIputSize(const std::shared_ptr<AnfNode> &anf_node, const size_t &input_num, - std::vector<size_t> *input_size_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_size_list); - for (size_t i = 0; i < input_num; i++) { - std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); - if (AnfAlgo::GetInputDeviceDataType(anf_node, i) == kObjectTypeString) { - if (!anf_node->isa<CNode>()) { - MS_LOG(EXCEPTION) << "anf_node is not CNode."; - } - auto cnode = anf_node->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < (i + 1)) { - MS_LOG(ERROR) << "cnode inputs size " << cnode->inputs().size() << " is smaller than " << i + 1; - return false; - } - auto input_node = cnode->inputs()[i + 1]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa<ValueNode>()) { - auto value_ptr = GetValueNode(input_node); - auto value = GetValue<std::string>(value_ptr); - input_size_list->push_back(value.size()); - } - } else { - auto type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); - MS_EXCEPTION_IF_NULL(type_ptr); - int64_t size_i = 1; - for (size_t j = 0; j < shape_i.size(); j++) { - size_i = LongMulWithOverflowCheck(size_i, static_cast<int64_t>(shape_i[j])); - } - size_t type_byte = GetTypeByte(type_ptr); - if (type_byte == 0) { - return false; - } - size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); - input_size_list->push_back(LongToSize(size_i)); - } - } - return true; -} - -bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - std::vector<size_t> input_size_list; - std::vector<size_t> output_size_list; - size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); - - if (!SetIOIputSize(anf_node, input_num, &input_size_list)) { - return false; - } - kernel_mod_ptr->SetInputSizeList(input_size_list); - - for (size_t i = 0; i < output_num; i++) { - std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); - TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); - MS_EXCEPTION_IF_NULL(type_ptr); - int64_t size_i = 1; - for (size_t j = 
0; j < shape_i.size(); j++) { - size_i = LongMulWithOverflowCheck(size_i, static_cast<int64_t>(shape_i[j])); - } - size_t type_byte = GetTypeByte(type_ptr); - if (type_byte == 0) { - return false; - } - size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); - output_size_list.push_back(LongToSize(size_i)); - } - kernel_mod_ptr->SetOutputSizeList(output_size_list); - return true; -} - -void ParseAttrValue(const std::string &type, const std::string &attr_name, const mindspore::ValuePtr &value, - ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr) { - MS_EXCEPTION_IF_NULL(node_attr); - MS_EXCEPTION_IF_NULL(value); - if (type == "int") { - auto attr_value = GetValue<int>(value); - (*node_attr)[attr_name].set_i(attr_value); - } else if (type == "str") { - auto attr_value = GetValue<std::string>(value); - (*node_attr)[attr_name].set_s(attr_value); - } else if (type == "bool") { - auto attr_value = GetValue<bool>(value); - (*node_attr)[attr_name].set_b(attr_value); - } else if (type == "float") { - auto attr_value = GetValue<float>(value); - (*node_attr)[attr_name].set_f(attr_value); - } else if (type == "listInt") { - std::vector<int> attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == "Int32") { - int data = GetValue<int>(value); - attr_value.push_back(data); - } else { - attr_value = GetValue<std::vector<int>>(value); - } - mindspore::AttrValue input_shape_attr; - mindspore::AttrValue_ArrayValue *input_shape_attr_list = input_shape_attr.mutable_array(); - MS_EXCEPTION_IF_NULL(input_shape_attr_list); - for (const auto shape : attr_value) { - input_shape_attr_list->add_i(shape); - } - (*node_attr)[attr_name] = input_shape_attr; - } else { - MS_LOG(EXCEPTION) << "type: " << type << " not supported"; - } -} - -void SetNodeAttr(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(proto); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - if (op_name == kPrint) { - return; - } - - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); - MS_EXCEPTION_IF_NULL(op_info_ptr); - auto attrs_ptr = op_info_ptr->attrs_ptr(); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); - for (const auto &attr_ptr : attrs_ptr) { - MS_EXCEPTION_IF_NULL(attr_ptr); - std::string attr_name = attr_ptr->name(); - auto value = primitive->GetAttr(attr_name); - if (value != nullptr) { - if (attr_name == kQueueName || attr_name == kSharedName) { - attr_name = kChannelName; - } else if (attr_name == kSeed0) { - attr_name = kSeed; - } else if (attr_name == kSeed1) { - attr_name = kSeed2; - } - std::string type = attr_ptr->type(); - ParseAttrValue(type, attr_name, value, node_attr); - } - } - MS_LOG(INFO) << "Set node attr end!"; -} - -void SetNodeInputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); - MS_EXCEPTION_IF_NULL(anf_node); - size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); - if (input_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; - return; - } - - for (size_t input_index = 0; input_index < input_num; input_index++) { - ::mindspore::Tensor *node_inputs = proto->add_inputs(); - MS_EXCEPTION_IF_NULL(node_inputs); - TypeId input_type = 
AnfAlgo::GetInputDeviceDataType(anf_node, input_index); - std::vector<size_t> input_shape; - int32_t input_data_type; - if (input_type == kObjectTypeString) { - auto cnode = anf_node->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = cnode->inputs()[input_index + 1]; - auto value_ptr = GetValueNode(input_node); - auto value = GetValue<std::string>(value_ptr); - input_shape.push_back(1); - input_shape.push_back(value.size()); - input_data_type = AicpuOpUtil::MsTypeToProtoType(kTypeUnknown); - } else { - input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); - input_data_type = AicpuOpUtil::MsTypeToProtoType(input_type); - } - - mindspore::TensorShape *tensorShape = node_inputs->mutable_tensor_shape(); - for (auto item : input_shape) { - mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - dim->set_size((::google::protobuf::int64)item); - } - node_inputs->set_tensor_type((mindspore::DataType)input_data_type); - node_inputs->set_mem_device("HBM"); - } -} - -void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); - MS_EXCEPTION_IF_NULL(anf_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); - if (output_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. "; - return; - } - - for (size_t output_index = 0; output_index < output_num; output_index++) { - ::mindspore::Tensor *node_outputs = proto->add_outputs(); - MS_EXCEPTION_IF_NULL(node_outputs); - std::vector<size_t> output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); - mindspore::TensorShape *tensorShape = node_outputs->mutable_tensor_shape(); - MS_EXCEPTION_IF_NULL(tensorShape); - for (auto item : output_shape) { - mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - MS_EXCEPTION_IF_NULL(dim); - dim->set_size((::google::protobuf::int64)item); - } - TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, output_index); - int32_t output_data_type = AicpuOpUtil::MsTypeToProtoType(output_type); - node_outputs->set_tensor_type((mindspore::DataType)output_data_type); - node_outputs->set_mem_device("HBM"); - } -} - -void SetNodedefProto(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(proto); - MS_LOG(INFO) << "SetNodedefProto entry"; - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - // set op name - proto->set_op(op_name); - // set inputs tensor - SetNodeInputs(anf_node, proto); - // set outputs tensor - SetNodeOutputs(anf_node, proto); - // set node attr - SetNodeAttr(anf_node, proto); - MS_LOG(INFO) << "SetNodedefProto end!"; -} - -bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node, - const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) { - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "CreateNodeDefBytes entry"; - - mindspore::NodeDef proto; - SetNodedefProto(anf_node, &proto); - std::string nodeDefStr; - if (!proto.SerializeToString(&nodeDefStr)) { - MS_LOG(ERROR) << "Serialize nodeDef to string failed."; - return false; - } - kernel_mod_ptr->SetNodeDef(nodeDefStr); - MS_LOG(INFO) << "CreateNodeDefBytes end!"; - return true; -} - -KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - auto kernel_mod_ptr = std::make_shared<AicpuOpKernelMod>(); - 
MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - kernel_mod_ptr->SetAnfNode(anf_node); - kernel_mod_ptr->SetNodeName(op_name); - if (!CreateNodeDefBytes(anf_node, kernel_mod_ptr)) { - MS_LOG(EXCEPTION) << "Create nodeDefBytes failed!"; - } - if (!SetIOSize(anf_node, kernel_mod_ptr)) { - MS_LOG(EXCEPTION) << "Set input output size list failed."; - } - return kernel_mod_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h deleted file mode 100644 index a3c24ae49e..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.h +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ -#include <memory> -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AicpuOpBuild(const std::shared_ptr<AnfNode> &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc deleted file mode 100644 index 3670a2d76f..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
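The deleted SetIOSize above derives each buffer size by folding the device shape through LongMulWithOverflowCheck and then multiplying by the element width from GetTypeByte. A minimal sketch of that computation; `MulWithOverflowCheck` is a hypothetical stand-in for MindSpore's helper and assumes non-negative operands:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for LongMulWithOverflowCheck: multiply two
// non-negative sizes, throwing instead of silently wrapping around.
int64_t MulWithOverflowCheck(int64_t a, int64_t b) {
  if (a != 0 && b > std::numeric_limits<int64_t>::max() / a) {
    throw std::runtime_error("tensor size overflow");
  }
  return a * b;
}

// Byte size of one tensor: product of all dims times the element width.
size_t TensorByteSize(const std::vector<size_t> &shape, size_t type_byte) {
  int64_t size = 1;
  for (size_t dim : shape) {
    size = MulWithOverflowCheck(size, static_cast<int64_t>(dim));
  }
  return static_cast<size_t>(MulWithOverflowCheck(size, static_cast<int64_t>(type_byte)));
}

int main() {
  std::cout << TensorByteSize({2, 3, 4}, sizeof(float)) << "\n";  // prints 96
  return 0;
}
```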
- */ - -#include "kernel/aicpu/aicpu_kernel_metadata.h" -#include -#include -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" -#include "kernel/aicpu/aicpu_util.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - MS_LOG(INFO) << "AicpuMetadataInfo."; - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - if (op_name == kInitDataSetQueue) { - op_name = kInitData; - } - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAICPU); - if (op_info_ptr == nullptr) { - MS_LOG(DEBUG) << "Aicpu does not have op [" << op_name << "]"; - return; - } - // For compatibility with the current framework - if (op_name == kPrint || op_name == kGetNext || op_name == kPack) { - std::vector inputs_format{}; - std::vector inputs_type{}; - if (op_name == kPrint || op_name == kPack) { - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(kOpFormat_DEFAULT); - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); - } - } - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(kOpFormat_DEFAULT); - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); - } - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetProcessor(AICPU); - builder.SetKernelType(AICPU_KERNEL); - builder.SetFusionType(OPAQUE); - kernel_info_list->push_back(builder.Build()); - return; - } - if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) { - MS_LOG(WARNING) << "Aicpu parsed metadata op [" << op_name << "] failed"; - return; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h deleted file mode 100644 index 74e667856e..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
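AicpuMetadataInfo above assembles kernel build metadata through KernelBuildInfoBuilder, a classic fluent builder. A reduced sketch of the idiom with hypothetical `BuildInfo`/`BuildInfoBuilder` types, not MindSpore's actual API:

```cpp
#include <string>
#include <vector>

// Hypothetical product type standing in for KernelBuildInfo.
struct BuildInfo {
  std::vector<std::string> input_formats;
  std::vector<int> input_types;
};

// The builder accumulates fields piecewise and emits the product at the end,
// so call sites can set only the fields relevant to a given kernel.
class BuildInfoBuilder {
 public:
  BuildInfoBuilder &SetInputsFormat(std::vector<std::string> formats) {
    info_.input_formats = std::move(formats);
    return *this;
  }
  BuildInfoBuilder &SetInputsDeviceType(std::vector<int> types) {
    info_.input_types = std::move(types);
    return *this;
  }
  BuildInfo Build() const { return info_; }

 private:
  BuildInfo info_;
};

int main() {
  auto info = BuildInfoBuilder()
                  .SetInputsFormat({"DefaultFormat"})
                  .SetInputsDeviceType({43})  // e.g. a numeric type id
                  .Build();
  return info.input_types.size() == 1 ? 0 : 1;
}
```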
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ - -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc deleted file mode 100644 index c6d8a101cd..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/aicpu/aicpu_kernel_mod.h" - -#include -#include -#include -#include - -#include "runtime/mem.h" -#include "runtime/rt.h" -#include "kernel/aicpu/aicpu_kernel_build.h" -#include "utils/convert_utils.h" -#include "kernel/aicpu/aicpu_util.h" -#include "utils/context/ms_context.h" - -using AicpuTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -constexpr auto AICPU_OPS_SO_NAME = "libaicpu_kernels.so"; - -AicpuOpKernelMod::AicpuOpKernelMod() : anf_node_(nullptr) {} - -AicpuOpKernelMod::~AicpuOpKernelMod() { - args_.clear(); - inputList_.clear(); - outputList_.clear(); - anf_node_ = nullptr; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); -} - -void AicpuOpKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetInputSizeList() const { return input_size_list_; } -void AicpuOpKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetOutputSizeList() const { return output_size_list_; } -void AicpuOpKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } -const std::vector &AicpuOpKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } -void AicpuOpKernelMod::SetInputList(const std::vector &inputList) { inputList_ = inputList; } -void AicpuOpKernelMod::SetOutputList(const std::vector &outputList) { outputList_ = outputList; } -void AicpuOpKernelMod::SetNodeDef(const std::string &nodeDef) { (void)node_def_str_.assign(nodeDef); } -void AicpuOpKernelMod::SetNodeName(const std::string &node_name) { node_name_ = node_name; } -void AicpuOpKernelMod::SetAnfNode(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - anf_node_ = anf_node; -} - -void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs, - const std::vector &outputs) { - MS_LOG(INFO) << "CreateCpuKernelInfoOffline start"; - - node_so_ = AICPU_OPS_SO_NAME; - - // InputOutputAddr - vector io_addrs; - (void)std::transform(std::begin(inputs), std::end(inputs), 
std::back_inserter(io_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(io_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - auto io_addrs_num = io_addrs.size(); - // calculate paramLen: AicpuParamHead.len + ioAddrsSize + notifyId.len + customizedAttr.len - auto param_len = sizeof(AicpuParamHead); - - // get input and output addrs size, no need to check overflow - auto io_addrs_size = io_addrs_num * sizeof(uint64_t); - // refresh paramLen, no need to check overflow - param_len += io_addrs_size; - - auto node_def_len = node_def_str_.length(); - param_len += node_def_len; - - // Create taskArgs: AicpuParamHead + ioAddrs + notifyId + customizedAttr - AicpuParamHead paramHead = {static_cast(param_len), static_cast(io_addrs_num)}; - args_.clear(); - (void)args_.append(reinterpret_cast(¶mHead), sizeof(AicpuParamHead)); - // TaskArgs append ioAddrs - if (io_addrs_size != 0) { - (void)args_.append(reinterpret_cast(io_addrs.data()), io_addrs_size); - } - - // When it's aicpu customized ops, taskArgs should append customized attr - if (node_def_len != 0) { - (void)args_.append(reinterpret_cast(node_def_str_.data()), node_def_len); - } - - MS_LOG(INFO) << "CreateCpuKernelInfoOffline end"; -} - -bool AicpuOpKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - CreateCpuKernelInfo(inputs, outputs); - if (node_name_ == kTopK) { - node_name_ = kTopKV2; - } - MS_LOG(INFO) << "Aicpu launch, node_so_:" << node_so_ << ", node name:" << node_name_ - << ", args_size:" << args_.length(); - if (rtCpuKernelLaunch(reinterpret_cast(node_so_.c_str()), - reinterpret_cast(node_name_.c_str()), 1, - reinterpret_cast(args_.data()), static_cast(args_.length()), nullptr, - stream_ptr) != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Aicpu op launch failed!"; - - return false; - } - return true; -} - -std::vector AicpuOpKernelMod::GenTask(const std::vector &inputs, - const std::vector &, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "AicpuOpKernelMod GenTask start"; - - stream_id_ = stream_id; - node_so_ = AICPU_OPS_SO_NAME; - std::vector input_data_addrs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - - std::vector output_data_addrs; - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - if (node_name_ == kTopK) { - node_name_ = kTopKV2; - } - - AicpuTaskInfoPtr task_info_ptr = make_shared( - kernel_name_, stream_id, node_so_, node_name_, node_def_str_, input_data_addrs, output_data_addrs, NeedDump()); - - MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h deleted file mode 100644 index 3ee9bd2a15..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_mod.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
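CreateCpuKernelInfo above packs the launch arguments as a flat byte string: a packed AicpuParamHead, then the raw input/output addresses, then the serialized NodeDef. A sketch of that layout, assuming 64-bit pointers as the original code does:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Mirrors the packed header declared in the deleted aicpu_util.h.
struct AicpuParamHead {
  uint32_t length;
  uint32_t ioAddrNum;
  uint32_t extInfoLength;
  uint64_t extInfoAddr;
} __attribute__((packed));

// Task-args layout: [AicpuParamHead][io addresses...][serialized NodeDef].
// Appending io_addrs.data() as raw bytes assumes sizeof(void *) == 8.
std::string PackTaskArgs(const std::vector<void *> &io_addrs, const std::string &node_def) {
  uint32_t param_len = static_cast<uint32_t>(
      sizeof(AicpuParamHead) + io_addrs.size() * sizeof(uint64_t) + node_def.size());
  AicpuParamHead head = {param_len, static_cast<uint32_t>(io_addrs.size()), 0, 0};

  std::string args;
  args.append(reinterpret_cast<const char *>(&head), sizeof(head));
  if (!io_addrs.empty()) {
    args.append(reinterpret_cast<const char *>(io_addrs.data()),
                io_addrs.size() * sizeof(uint64_t));
  }
  args.append(node_def);  // empty for ops with no customized attrs
  return args;
}
```

Because the consumer on the device side walks this buffer by offset, the header must stay packed; that is what the `__attribute__((packed))` on AicpuParamHead guarantees.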
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/aicpu/aicpu_util.h" -namespace mindspore { -namespace kernel { -class AicpuOpKernelMod : public AscendKernelMod { - public: - AicpuOpKernelMod(); - ~AicpuOpKernelMod() override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - void SetInputList(const std::vector &inputList); - void SetOutputList(const std::vector &outputList); - void SetAnfNode(const AnfNodePtr &anf_node); - void SetNodeDef(const std::string &nodeDef); - void SetNodeName(const std::string &node_name); - - /** - * @brief Build AICPU Engine kernel structure, and allocate device memory for offline task generate - * @return SUCCESS - * @return FAIL - * - */ - void CreateCpuKernelInfo(const std::vector &inputs, const std::vector &outputs); - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - void SetWorkspaceSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - - private: - std::string args_; - std::string node_def_str_; - std::string node_name_; - std::string node_so_; - std::vector inputList_; - std::vector outputList_; - AnfNodePtr anf_node_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using AicpuOpKernelModPtr = std::shared_ptr; -using AicputOpKernelModPtrList = std::vector; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc deleted file mode 100644 index a617f56f8f..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/aicpu/aicpu_util.h" -#include -#include -#include "proto/types.pb.h" -#include "runtime/mem.h" -#include "runtime/rt.h" -#include "utils/convert_utils.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -static std::map MS_PROTO_DATA_TYPE_MAP = { - {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, - {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, - {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, - {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8}, - {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16}, - {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32}, - {mindspore::TypeId::kNumberTypeInt64, mindspore::DataType::MS_INT64}, - {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32}, - {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8}, - {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16}, - {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32}, - {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64}, - {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16}, - {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, - {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, - {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, -}; - -int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { - auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type); - if (iter != MS_PROTO_DATA_TYPE_MAP.end()) { - return MS_PROTO_DATA_TYPE_MAP[ms_type]; - } else { - MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast(ms_type); - return -1; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/kernel/aicpu/aicpu_util.h deleted file mode 100644 index bf8025de2c..0000000000 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ - -#include -#include -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -constexpr auto kInitDataSetQueue = "InitDataSetQueue"; -constexpr auto kInitData = "InitData"; -constexpr auto kGetNext = "GetNext"; -constexpr auto kPrint = "Print"; -constexpr auto kPack = "Pack"; -constexpr auto kOutputTypes = "output_types"; -constexpr auto kOutputShapes = "output_shapes"; -constexpr auto kChannelName = "channel_name"; -constexpr auto kSharedName = "shared_name"; -constexpr auto kShapes = "shapes"; -constexpr auto kTypes = "types"; -constexpr auto kQueueName = "queue_name"; -constexpr auto kSeed = "seed"; -constexpr auto kSeed0 = "Seed0"; -constexpr auto kSeed1 = "Seed1"; -constexpr auto kSeed2 = "seed2"; -constexpr auto kTopK = "TopK"; -constexpr auto kTopKV2 = "TopKV2"; - -struct AicpuParamHead { - uint32_t length; // Total length: includes custom message - uint32_t ioAddrNum; // Input and output address number - uint32_t extInfoLength; // extInfo struct Length - uint64_t extInfoAddr; // extInfo address -} __attribute__((packed)); - -class AicpuOpUtil { - public: - static int MsTypeToProtoType(TypeId ms_type); - - private: - // kernel id - static uint64_t KernelId_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_AICPU_AICPU_UTIL_H_ diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc deleted file mode 100644 index 018fbe4f2a..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.cc +++ /dev/null @@ -1,180 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/akg/akg_kernel_attrs_process.h" - -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace kernel { -void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // The x and output are akg op input and output param. 
- std::vector input_names = {"x"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); - - TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - std::string dst_type; - if (dst_type_id == kFloat32->type_id()) { - dst_type = "float32"; - } else if (dst_type_id == kFloat16->type_id()) { - dst_type = "float16"; - } - AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names = {"x"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr("input_names", MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr("output_names", MakeValue(output_names), anf_node); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(anf_node, 0); - if (origin_shape.size() != kShape4dDims) { - MS_LOG(EXCEPTION) << "The dim of origin_shape is not equal to 4, but it's dim is " << origin_shape.size() << "."; - } - std::vector shape_transform; - (void)std::transform(origin_shape.begin(), origin_shape.end(), std::back_inserter(shape_transform), - [](const int &origin_shape) { return static_cast(origin_shape); }); - AnfAlgo::SetNodeAttr("shape4d", MakeValue(shape_transform), anf_node); - AnfAlgo::SetNodeAttr("output_format", MakeValue(kOpFormat_NCHW), anf_node); - - TypeId dst_type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - std::string dst_type; - if (dst_type_id == kFloat32->type_id()) { - dst_type = "float32"; - } else if (dst_type_id == kFloat16->type_id()) { - dst_type = "float16"; - } - AnfAlgo::SetNodeAttr("dstType", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForCast(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // The x and output are akg op input and output param. 
- std::vector input_names = {"x", "dst_type"}; - std::vector output_names = {"output"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); - - std::string dst_type; - TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); - if (output_type == kFloat32->type_id()) { - dst_type = "float32"; - } else if (output_type == kFloat16->type_id()) { - dst_type = "float16"; - } else if (output_type == kInt32->type_id()) { - dst_type = "int32"; - } else { - MS_LOG(WARNING) << "Unknown cast_to type: " << TypeIdToType(output_type)->ToString(); - } - AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); -} - -void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dy", "data", "mean"}; - std::vector output_names{"dgamma_red_hw", "dbeta_red_hw", "data_minus_mean"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); -} - -void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node) { - const size_t kBNGrad2InputSize = 5; - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dgamma_red_hw", "dbeta_red_hw", "variance", "gamma"}; - std::vector output_names{"bn_scale", "bn_bias", "rs", "dgamma_dx", "dbeta_dx"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < kBNGrad2InputSize) { - MS_LOG(EXCEPTION) << "The inputs size of BNGrad2 is less then " << kBNGrad2InputSize; - } - auto input1 = cnode->input(1); - MS_EXCEPTION_IF_NULL(input1); - auto tuple_getitem = input1->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->inputs().size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The inputs size of tuple_getitem is less then " << kTupleGetItemInputSize; - } - auto bn_grad1 = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - std::vector data_shape = AnfAlgo::GetInputDeviceShape(bn_grad1, 0); - AnfAlgo::SetNodeAttr(kAttrDataShape, MakeValue(opt::Convert2Int(data_shape)), anf_node); -} - -void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector input_names{"dy", "rs", "dgamma_dx", "dbeta_dx", "data_minus_mean"}; - std::vector output_names{"dx"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), anf_node); -} - -void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn1 - std::vector fused_bn1_input_names{"data"}; - std::vector fused_bn1_output_names{"mean", "var_part"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn1_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn1_output_names), anf_node); -} - -void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn2 - std::vector fused_bn2_input_names{"mean", "var_part", "running_mean", "running_var"}; - std::vector fused_bn2_output_names{"variance", "running_mean", "running_variance"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn2_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn2_output_names), anf_node); -} - -void 
SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - // Set attr for fused_bn3 - std::vector fused_bn3_input_names{"data", "mean", "variance", "gamma", "beta"}; - std::vector fused_bn3_output_names{"y"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(fused_bn3_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(fused_bn3_output_names), anf_node); -} - -void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector conv_bn1_output_names{"data", "var_part", "mean"}; - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(conv_bn1_output_names), anf_node); -} - -void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector bn2_add_relu_input_names{"data", "var_part", "mean", "other_branch_data", - "gamma", "beta", "running_mean", "running_var"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_add_relu_input_names), anf_node); - std::vector bn2_add_relu_output_names{"output", "running_mean", "running_variance", "save_inv_variance"}; - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_add_relu_output_names), anf_node); -} - -void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector bn2_input_names{"data", "var_part", "mean", "gamma", "beta", "running_mean", "running_var"}; - std::vector bn2_output_names{"y", "running_mean", "running_variance", "save_inv_variance"}; - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(bn2_input_names), anf_node); - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h b/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h deleted file mode 100644 index 9d15d4f9e9..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_attrs_process.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H - -#include -#include -#include -#include -#include "ir/anf.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace kernel { -void SetAkgAttrsForFour2Five(const AnfNodePtr &anf_node); -void SetAkgAttrsForFive2Four(const AnfNodePtr &anf_node); -void SetAkgAttrsForCast(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad1(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad2(const AnfNodePtr &anf_node); -void SetAkgAttrsForBNGrad3(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN1(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN2(const AnfNodePtr &anf_node); -void SetAkgAttrsForFusedBN3(const AnfNodePtr &anf_node); -void SetAkgAttrsForConvBN1(const AnfNodePtr &anf_node); -void SetAkgAttrsForBN2AddRelu(const AnfNodePtr &anf_node); -void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node); - -const std::unordered_map> kAkgKernelAttrsProcessMap = { - {kFour2FiveOpName, SetAkgAttrsForFour2Five}, - {kFive2FourOpName, SetAkgAttrsForFive2Four}, - {"Cast", SetAkgAttrsForCast}, - {kBNGrad1OpName, SetAkgAttrsForBNGrad1}, - {kBNGrad2OpName, SetAkgAttrsForBNGrad2}, - {kBNGrad3OpName, SetAkgAttrsForBNGrad3}, - {kFusedBN1OpName, SetAkgAttrsForFusedBN1}, - {kFusedBN2OpName, SetAkgAttrsForFusedBN2}, - {kFusedBN3OpName, SetAkgAttrsForFusedBN3}, - {kConvBN1OpName, SetAkgAttrsForConvBN1}, - {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, - {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_ATTRS_PROCESS_H diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc deleted file mode 100644 index 0e8d93d47f..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_build.cc +++ /dev/null @@ -1,623 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
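The deleted kAkgKernelAttrsProcessMap above maps op names to attribute handlers so the build path can patch node attributes before JSON generation. A self-contained sketch of that dispatch idiom; `Node` and the handler names here are hypothetical stand-ins for AnfNodePtr and the SetAkgAttrsFor* functions:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for AnfNodePtr.
struct Node {
  std::string name;
};

void SetCastAttrs(const Node &n) { std::cout << "cast attrs for " << n.name << "\n"; }
void SetFour2FiveAttrs(const Node &n) { std::cout << "four2five attrs for " << n.name << "\n"; }

// Op name -> handler table, mirroring kAkgKernelAttrsProcessMap.
const std::unordered_map<std::string, std::function<void(const Node &)>> kAttrsProcessMap = {
    {"Cast", SetCastAttrs},
    {"Four2Five", SetFour2FiveAttrs},
};

void ProcessAttrs(const std::string &op_name, const Node &node) {
  auto it = kAttrsProcessMap.find(op_name);
  if (it != kAttrsProcessMap.end()) {
    it->second(node);  // only ops with a registered handler get extra attrs
  }
}

int main() {
  ProcessAttrs("Cast", Node{"n1"});
  ProcessAttrs("Add", Node{"n2"});  // no handler registered: silently skipped
  return 0;
}
```

The table keeps the per-op attribute logic out of the generic build loop; adding support for a new op is one entry plus one handler, with no change to the caller.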
- */ - -#include "kernel/akg/akg_kernel_build.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/any.h" -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/akg/akg_kernel_attrs_process.h" - -namespace mindspore { -namespace kernel { -constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200; -constexpr int32_t ARGS_SIZE = 1; -constexpr auto kCompileWithJsonFunc = "compilewithjson"; - -// json key -constexpr auto kOpDesc = "op_desc"; -constexpr auto kInputDesc = "input_desc"; -constexpr auto kShape = "shape"; -constexpr auto kDataType = "data_type"; -constexpr auto kOutputDesc = "output_desc"; -constexpr auto kName = "name"; -constexpr auto kTensorName = "tensor_name"; -constexpr auto kValue = "value"; -constexpr auto KDynInputSizes = "dyn_input_sizes"; -constexpr auto KInputNames = "input_names"; -constexpr auto KInput = "input"; -constexpr auto KDtype = "dtype"; -namespace { -template -std::string Vector2Str(const std::vector &inputs) { - if (!inputs.empty()) { - std::ostringstream oss; - (void)std::copy(inputs.begin(), inputs.end() - 1, std::ostream_iterator(oss, ", ")); - oss << inputs.back(); - return oss.str(); - } - return ""; -} -} // namespace - -std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - -std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, - const std::pair &position) { - if (node_json.count(tag) == 0) { - MS_LOG(ERROR) << "Node [" << node_json.dump() << "] has no key [" << tag << "]."; - return ""; - } - - auto const &tag_desc = node_json[tag]; - nlohmann::json first_index; - if (tag == kOutputDesc) { - first_index = tag_desc; - } else if (!tag_desc.is_array() || tag_desc.size() <= position.first) { - MS_LOG(ERROR) << "Node [" << tag_desc.dump() << "] has no enough value [" << position.first << "]."; - return ""; - } else { - first_index = tag_desc[position.first]; - } - - if (!first_index.is_array() || first_index.size() <= position.second) { - MS_LOG(ERROR) << "Node [" << first_index.dump() << "] has no enough value [" << position.second << "]."; - return ""; - } - auto const &second_index = first_index[position.second]; - if (second_index.count(kTensorName) == 0) { - MS_LOG(ERROR) << "Node [" << second_index.dump() << "] has no key [" << kTensorName << "]."; - return ""; - } - - return second_index[kTensorName]; -} - -void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, - nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(node_json); - if (node_json->count(tag) == 0) { - MS_LOG(ERROR) << "Node [" << node_json->dump() << "] has no key [" << tag << "]."; - return; - } - - nlohmann::json *tag_desc = &((*node_json)[tag]); - nlohmann::json *first_index; - if (tag == kOutputDesc) { - first_index = tag_desc; - } else if (!tag_desc->is_array() || tag_desc->size() <= position.first) { - MS_LOG(ERROR) << "Node [" << tag_desc->dump() << "] has no enough value [" << position.first << 
"]."; - return; - } else { - first_index = &((*tag_desc)[position.first]); - } - - if (!first_index->is_array() || first_index->size() <= position.second) { - MS_LOG(ERROR) << "Node [" << first_index->dump() << "] has no enough value [" << position.second << "]."; - return; - } - nlohmann::json *second_index = &((*first_index)[position.second]); - if (second_index->count(kTensorName) == 0) { - MS_LOG(ERROR) << "Node [" << second_index->dump() << "] has no key [" << kTensorName << "]."; - return; - } - (*second_index)[kTensorName] = new_name; - return; -} - -int AkgKernelBuild::op_cnt_ = 0; -std::mutex AkgKernelBuild::op_cnt_mtx_; - -std::string AkgKernelBuild::GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - - case Processor::AICPU: - device = kProcessorAiCpu; - break; - - case Processor::CUDA: - device = kProcessorCuda; - break; - - default: - MS_LOG(ERROR) << "Unknown processor type."; - break; - } - - return device; -} - -bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, - std::vector *const output_size) { - if (input_size == nullptr || output_size == nullptr) { - MS_LOG(ERROR) << "input size or output size is nullptr"; - return false; - } - input_size->clear(); - output_size->clear(); - - for (size_t i = 0; i < node_json[kInputDesc].size(); i++) { - for (size_t m = 0; m < node_json[kInputDesc][i].size(); m++) { - std::string dtype = node_json[kInputDesc][i][m][kDataType]; - size_t nbyte = GetDtypeNbyte(dtype); - size_t size_i = std::accumulate(node_json[kInputDesc][i][m][kShape].begin(), - node_json[kInputDesc][i][m][kShape].end(), nbyte, std::multiplies()); - input_size->push_back(size_i); - } - } - - for (size_t i = 0; i < node_json[kOutputDesc].size(); i++) { - std::string dtype = node_json[kOutputDesc][i][kDataType]; - size_t nbyte = GetDtypeNbyte(dtype); - size_t size_i = std::accumulate(node_json[kOutputDesc][i][kShape].begin(), node_json[kOutputDesc][i][kShape].end(), - nbyte, std::multiplies()); - output_size->push_back(size_i); - } - - return true; -} - -int AkgKernelBuild::GetOpCntInc() { - op_cnt_mtx_.lock(); - int cnt = op_cnt_++; - op_cnt_mtx_.unlock(); - return cnt; -} - -bool AkgKernelBuild::CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(inputs_json); - - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - if (op_info == nullptr) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op_info is nullptr"; - return false; - } - - std::vector> inputs_ptr = op_info->inputs_ptr(); - if (inputs_ptr.empty()) { - MS_LOG(INFO) << "Apply kernel [" << op_name << "] regist info has no input info"; - return true; - } - auto op_info_input_num = inputs_ptr.size(); - - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. 
- std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - size_t real_input_index = 0; - std::vector input_list; - for (size_t i = 0; i < op_info_input_num; i++) { - size_t input_tensor_num; - std::shared_ptr input_ptr = inputs_ptr[i]; - std::string op_input_name; - if (input_ptr == nullptr) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] regist input[" << i << "] is nullptr"; - return false; - } - - op_input_name = input_ptr->name(); - if (dyn_input_sizes.empty()) { - input_tensor_num = 1; - } else { - input_tensor_num = IntToSize(dyn_input_sizes[i]); - } - - input_list.clear(); - for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { - // dtype : float16 - auto type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index); - std::string dtype = TypeId2String(type_id); - if (dtype.empty()) { - MS_LOG(ERROR) << "Op [" << op_name << "] input [" << input_i << "] data type is null. "; - return false; - } - nlohmann::json input_desc_json; - input_desc_json[kDataType] = dtype; - input_desc_json[kName] = op_input_name; - input_desc_json[kTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index)); - auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index); - if (anf_node->func_graph() != nullptr && anf_node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && - GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) { - MS_LOG(WARNING) << "we take input[" << real_input_index << "] of [" << anf_node->DebugString(2) - << "] as const tensor, shape: [" << Vector2Str(input_shape) - << "], value: " << input_desc_json[kValue]; - - input_shape.clear(); - } - if (input_shape.empty()) { - input_shape.push_back(1); - } - input_desc_json[kShape] = input_shape; - input_list.emplace_back(input_desc_json); - real_input_index++; - } - inputs_json->emplace_back(input_list); - } - return true; -} - -bool AkgKernelBuild::CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(outputs_json); - size_t output_tensor_num = AnfAlgo::GetOutputTensorNum(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - auto outputs = op_info_ptr->outputs_ptr(); - for (size_t i = 0; i < output_tensor_num; i++) { - nlohmann::json output_json; - auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, i); - std::string dtype = TypeId2String(type_id); - if (dtype.empty()) { - MS_LOG(ERROR) << "Op [" << op_name << "] output [" << i << "] data type is null. 
"; - return false; - } - - std::string output_name = outputs[i]->name(); - output_json[kDataType] = dtype; - output_json[kName] = output_name; - output_json[kTensorName] = "output_" + std::to_string(i) + "_" + std::to_string(GetOutputTensorIdxInc()); - output_json[kShape] = AnfAlgo::GetOutputDeviceShape(anf_node, i); - outputs_json->push_back(output_json); - } - return true; -} - -void GetJson(const AnfNodePtr &anf_node, const std::vector &dyn_input_sizes, - const std::shared_ptr &op_attr, nlohmann::json *const attr_json, const ValuePtr &attr_value) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_attr); - MS_EXCEPTION_IF_NULL(attr_json); - std::string type = op_attr->type(); - if (type == "int") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "str") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "bool") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "float") { - (*attr_json)[kValue] = GetValue(attr_value); - } else if (type == "listInt") { - (*attr_json)[kValue] = GetValue>(attr_value); - } else if (type == "listStr") { - std::vector data_format; - if (op_attr->name() == kArgDataformat) { - size_t tensor_args_num = !dyn_input_sizes.empty() ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); - for (size_t format_i = 0; format_i < tensor_args_num; format_i++) { - auto input_format = AnfAlgo::GetInputFormat(anf_node, format_i); - data_format.push_back(input_format); - } - } else { - data_format = GetValue>(attr_value); - } - (*attr_json)[kValue] = data_format; - } else { - MS_LOG(WARNING) << "attr type:" << type; - } -} - -bool AkgKernelBuild::CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, - const std::shared_ptr &op_info, nlohmann::json *const attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - MS_EXCEPTION_IF_NULL(op_info); - std::vector> attrs = op_info->attrs_ptr(); - if (attrs.empty()) { - MS_LOG(INFO) << "Apply kernel [" << op_name << "] op info attrs is empty"; - return true; - } - std::vector> inputs = op_info->inputs_ptr(); - - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - if (inputs.empty()) { - MS_LOG(ERROR) << "Apply kernel [" << op_name << "] op info inputs is empty"; - return false; - } - - // create input name list for atch "x_shape" in att with "x" in primitive. - std::map op_info_shape_name; - for (size_t op_info_input_i = 0; op_info_input_i < inputs.size(); op_info_input_i++) { - std::string input_name = inputs[op_info_input_i]->name(); - std::string x_shape_name = input_name + "_shape"; - (void)op_info_shape_name.insert(make_pair(op_info_input_i, x_shape_name)); - } - - for (const auto &op_attr : attrs) { - nlohmann::json attr_json; - ValuePtr attr_value = primitive->GetAttr(op_attr->name()); - if (attr_value == nullptr && op_attr->name() != kArgDataformat) { - if (op_attr->param_type() == "required") { - // match "x_shape" in att with "x" in primitive. 
- std::string attr_name = op_attr->name(); - auto find_item = std::find_if( - op_info_shape_name.begin(), op_info_shape_name.end(), - [attr_name](const std::map::value_type item) { return item.second == attr_name; }); - if (find_item != op_info_shape_name.end()) { - if (!dyn_input_sizes.empty()) { - if (find_item->first >= dyn_input_sizes.size() - 1) { - MS_LOG(EXCEPTION) << "dyn_input_sizes list index:" << find_item->first - << " is out of range:" << dyn_input_sizes.size() - 1 << "."; - return false; - } - size_t tensor_idx = IntToSize(std::accumulate(&dyn_input_sizes[0], &dyn_input_sizes[find_item->first], 0)); - for (int input_i = 0; input_i < dyn_input_sizes[find_item->first]; input_i++) { - attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, tensor_idx); - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - tensor_idx++; - } - } else { - attr_json[kValue] = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, find_item->first); - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - } - } else { - MS_LOG(ERROR) << "op [" << op_name << "] should have attr :" << op_attr->name(); - return false; - } - } - continue; - } - - GetJson(anf_node, dyn_input_sizes, op_attr, &attr_json, attr_value); - - attr_json[kName] = op_attr->name(); - attrs_json->push_back(attr_json); - } - return true; -} - -bool AkgKernelBuild::GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, - nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(node_json); - int op_cnt = GetOpCntInc(); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - MS_EXCEPTION_IF_NULL(op_info_ptr); - - // get basic params from currentNodeOpDesc - (*node_json)[kName] = op_name; - (*node_json)["impl_path"] = op_info_ptr->impl_path(); - (*node_json)["process"] = AkgKernelBuild::GetProcessor(anf_node); - (*node_json)["composite"] = false; - - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - ValuePtr input_names_v = primitive->GetAttr(KInputNames); - if (input_names_v == nullptr) { - MS_LOG(ERROR) << "ApplyKernel has no input_names, op[" << op_name << "]."; - return false; - } - std::vector prim_input_names = GetValue>(input_names_v); - std::string inputs_name; - for (const auto &prim_input_name : prim_input_names) { - (void)inputs_name.append("_input_").append(prim_input_name).append("_"); - } - - // input desc - nlohmann::json inputs_json; - if (!CreateInputDescJson(anf_node, &inputs_json)) { - MS_LOG(ERROR) << "Create input desc json failed, op[" << op_name << "]."; - return false; - } - (*node_json)[kInputDesc] = inputs_json; - MS_LOG(INFO) << "Akg create input desc json success."; - std::string inputs_shape = "inputs_shape_"; - for (auto &i : inputs_json) { - for (auto &m : i) { - std::string data_type = m[kDataType]; - (void)inputs_shape.append("_").append(data_type).append("_"); - for (auto &j : m[kShape]) { - size_t n = j; - (void)inputs_shape.append(std::to_string(n)).append("_"); - } - } - } - - // output desc - nlohmann::json outputs_json; - if (!CreateOutputDescJson(anf_node, &outputs_json)) { - MS_LOG(ERROR) << "Create output desc json failed, op[" << op_name << "]."; - return false; - } - - (*node_json)[kOutputDesc] = outputs_json; - MS_LOG(INFO) << "Akg create output desc json success."; - std::string outputs_shape = "outputs_shape_"; - for (auto &i : outputs_json) { - std::string data_type = i[kDataType]; - 
(void)outputs_shape.append("_").append(data_type).append("_"); - for (auto &j : i[kShape]) { - size_t m = j; - (void)outputs_shape.append(std::to_string(m)).append("_"); - } - } - - // attribute desc - nlohmann::json attrs_json; - if (!CreateAttrDescJson(anf_node, op_name, op_info_ptr, &attrs_json)) { - MS_LOG(ERROR) << "Create attr desc json failed, op[" << op_name << "]."; - return false; - } - (*node_json)["attr"] = attrs_json; - std::string json_str = node_json->dump(); - size_t hash_id = std::hash()(json_str); - json_name_ = op_name + "_"; - (void)json_name_.append(std::to_string(hash_id)); - MS_LOG(INFO) << "full scope name is : " << anf_node->fullname_with_scope() << ", json info name is : " << json_name_; - json_info_ = json_str; - (*node_json)["id"] = op_cnt; - (*node_json)["op"] = json_name_; - MS_LOG(INFO) << "Akg create node desc json success."; - return true; -} - -KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto processor = AkgKernelBuild::GetProcessor(anf_node); - auto cached_kernel_pack = SearchCache(json_name_, processor); - if (cached_kernel_pack != nullptr) { - MS_LOG(INFO) << "Use cached kernel, json_name_[" << json_name_ << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return cached_kernel_pack; - } - - PyObject *pModule = nullptr; - PyObject *pFunc = nullptr; - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - - pModule = PyImport_ImportModule(kAkgModule); - if (pModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "]."; - return nullptr; - } - - pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc); - pArg = PyTuple_New(ARGS_SIZE); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str())); - - (void)alarm(AUTODIFF_COMPILE_OVERTIME); - pRes = PyEval_CallObject(pFunc, pArg); - (void)alarm(0); - if (pRes == nullptr) { - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - if (PyObject_IsTrue(pRes) != 1) { - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(pArg) << ")."; - return nullptr; - } - - auto new_kernel_pack = InsertCache(json_name_, processor); - kernel::SaveJsonInfo(json_name_, json_info_); - if (new_kernel_pack == nullptr) { - MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name_ << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return nullptr; - } - return new_kernel_pack; -} - -KernelPackPtr AkgKernelBuild::BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, - std::vector *const output_size) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; - } - - std::string json_str = node_json.dump(); - auto kernel_pack = OpBuild(json_str, anf_node); - if (kernel_pack == nullptr) { - MS_LOG(ERROR) << "Akg build failed op[" << op_name << "], json:" << json_str; - return nullptr; - } - - if 
(!GetIOSize(node_json, input_size, output_size)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return nullptr; - } - MS_LOG(INFO) << "Akg compile success, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) - << "]"; - return kernel_pack; -} - -size_t AkgKernelBuild::GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(anf_node); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" - << cnode->inputs().size() - 1 << "][" << cnode->DebugString() << "]"; - } - - auto input_node = cnode->input(input_idx + 1); - if (input_tensor_idx_.find(input_node) == input_tensor_idx_.end()) { - size_t index = input_tensor_idx_.size(); - input_tensor_idx_[input_node] = index; - } - - return input_tensor_idx_[input_node]; -} - -size_t AkgKernelBuild::GetOutputTensorIdxInc() { - size_t idx = output_tensor_idx_++; - return idx; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_build.h b/mindspore/ccsrc/kernel/akg/akg_kernel_build.h deleted file mode 100644 index 15fa03f45b..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_build.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "ir/dtype.h" -#include -#include "kernel/common_utils.h" -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace kernel { -class AkgKernelBuild { - public: - AkgKernelBuild() { - input_tensor_idx_ = {}; - output_tensor_idx_ = 0; - } - ~AkgKernelBuild() = default; - - KernelPackPtr BuildByJson(const AnfNodePtr &anf_node, std::vector *const input_size, - std::vector *const output_size); - static std::string GetProcessor(const AnfNodePtr &anf_node); - static std::string PyObjectToStr(PyObject *const PyObj); - - protected: - bool CreateInputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const inputs_json); - bool CreateOutputDescJson(const AnfNodePtr &anf_node, nlohmann::json *const outputs_json); - bool CreateAttrDescJson(const AnfNodePtr &anf_node, const std::string &op_name, - const std::shared_ptr &op_info, nlohmann::json *const attrs_json); - KernelPackPtr OpBuild(const std::string &node_json, const AnfNodePtr &anf_node); - int GetOpCntInc(); - size_t GetInputTensorIdxInc(const AnfNodePtr &anf_node, size_t input_idx); - size_t GetOutputTensorIdxInc(); - bool GenerateSingleKernelJson(const AnfNodePtr &anf_node, const std::string &op_name, - nlohmann::json *const node_json); - - static int op_cnt_; - // lock for variable fusionOpCnt in singleton mode - static std::mutex op_cnt_mtx_; - std::string json_name_; - std::string json_info_; - std::unordered_map input_tensor_idx_; - size_t output_tensor_idx_; -}; - -bool GetIOSize(const nlohmann::json &node_json, std::vector *const input_size, - std::vector *const output_size); -void SetTensorName(const std::string &tag, const std::string &new_name, const std::pair &position, - nlohmann::json *const node_json); -std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag, - const std::pair &position); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKGKERNELBUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc b/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc deleted file mode 100644 index 3515add1e0..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/akg/akg_kernel_metadata.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -void AkgMetadataInfo(const CNodePtr &kernel_node, - std::vector> *const kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - for (size_t i = 0; i < support_devices.size(); i++) { - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kAKG); - if (op_info_ptr == nullptr) { - continue; - } - - if (!ParseMetadata(kernel_node, op_info_ptr, Processor(i), kernel_info_list)) { - MS_LOG(WARNING) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "] failed."; - } else { - MS_LOG(DEBUG) << "Akg parsed metadata of op[" << op_name << "], device[" << support_devices[i] << "]."; - break; - } - } - - if (kernel_info_list->empty()) { - MS_LOG(WARNING) << "Akg dose not has metadata of op[" << op_name << "]."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h b/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h deleted file mode 100644 index 5e329f0080..0000000000 --- a/mindspore/ccsrc/kernel/akg/akg_kernel_metadata.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ - -#include -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void AkgMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_AKG_AKG_KERNEL_METADATA_H_ diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc deleted file mode 100644 index 7200a91ac0..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.cc +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/akg/ascend/akg_ascend_kernel_build.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "ir/func_graph.h" -#include "kernel/kernel.h" -#include "kernel/common_utils.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/akg/ascend/akg_ascend_kernel_mod.h" -#include "kernel/akg/akg_kernel_attrs_process.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -constexpr int32_t PARALLEL_ARGS_SIZE = 3; -constexpr int32_t PROCESS_NUM = 16; -constexpr int32_t TIME_OUT = 300; - -constexpr auto kOpDesc = "op_desc"; -constexpr auto kShape = "shape"; -constexpr auto kDataType = "data_type"; -constexpr auto kInputDesc = "input_desc"; -constexpr auto kOutputDesc = "output_desc"; -constexpr auto kTensorName = "tensor_name"; -constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; -constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; -namespace { -void UpdateTensorNameInJson(const std::vector &anf_nodes, - std::map *node_json_map) { - for (auto const &anf_node : anf_nodes) { - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - bool is_dynamic_input = !dyn_input_sizes.empty(); - size_t input_num = is_dynamic_input ? dyn_input_sizes.size() : AnfAlgo::GetInputTensorNum(anf_node); - size_t real_input_index = 0; - for (size_t i = 0; i < input_num; ++i) { - size_t input_tensor_num = is_dynamic_input ? IntToSize(dyn_input_sizes[i]) : 1; - for (size_t j = 0; j < input_tensor_num; ++j) { - auto tmp_input = GetKernelInput(anf_node, real_input_index); - std::string tensor_name = GetTensorName((*node_json_map)[anf_node], kInputDesc, std::make_pair(i, j)); - if (node_json_map->find(tmp_input.first) != node_json_map->end()) { - std::string new_tensor_name = - GetTensorName((*node_json_map)[tmp_input.first], kOutputDesc, std::make_pair(0, tmp_input.second)); - SetTensorName(kInputDesc, new_tensor_name, std::make_pair(i, j), &((*node_json_map)[anf_node])); - MS_LOG(DEBUG) << "Update [" << real_input_index << "] input [" << tensor_name << "] of [" - << anf_node->fullname_with_scope() << "] to [" << tmp_input.second << "] output [" - << new_tensor_name << "] of [" << tmp_input.first->fullname_with_scope() << "]."; - } else { - MS_LOG(DEBUG) << "[" << real_input_index << "] input " << tensor_name << "] of [" - << anf_node->fullname_with_scope() << "] is out input."; - } - real_input_index++; - } - } - } -} - -nlohmann::json GetInputsJson(const std::vector &anf_nodes, const std::vector &input_list, - std::map *node_json_map) { - nlohmann::json inputs_json; - auto input_index = GetInputIndex(anf_nodes, input_list); - for (size_t i = 0; i < input_index.size(); ++i) { - auto tmp_input = input_index[i]; - auto type_id = AnfAlgo::GetInputDeviceDataType(tmp_input.first, tmp_input.second.first); - std::string dtype = TypeId2String(type_id); - nlohmann::json input_desc_json; - input_desc_json[kTensorName] = GetTensorName((*node_json_map)[tmp_input.first], kInputDesc, tmp_input.second); - input_desc_json[kDataType] = dtype; - input_desc_json[kShape] = AnfAlgo::GetInputDeviceShape(tmp_input.first, tmp_input.second.first); - inputs_json.emplace_back(std::vector{input_desc_json}); - } - - return inputs_json; -} - 
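- // For orientation: the helpers above and below assemble the fused-kernel descriptor that is handed
- // to AKG as a single JSON document. A rough sketch of its layout (field names are the constants
- // defined at the top of this file; the concrete values shown here are illustrative only):
- //
- //   {
- //     "op_desc":     [ <per-node JSON from GenerateSingleKernelJson>, ... ],
- //     "input_desc":  [ [ { "tensor_name": "input_0", "data_type": "float16", "shape": [16, 16] } ], ... ],
- //     "output_desc": [ { "tensor_name": "output_0_0", "data_type": "float16", "shape": [16, 16] } ],
- //     "op": "Fused_..._<hash>", "platform": "AKG", "process": "aicore", "composite": true
- //   }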
-nlohmann::json GetOutputsJson(const std::vector &anf_nodes, const std::vector &input_list, - const std::vector &output_list, const nlohmann::json &inputs_json, - std::map *node_json_map) { - nlohmann::json outputs_json; - auto output_index = GetOutputIndex(anf_nodes, input_list, output_list); - for (size_t i = 0; i < output_index.size(); ++i) { - auto tmp_output = output_index[i]; - bool found = false; - nlohmann::json output_desc_json; - for (size_t input_i = 0; input_i < input_list.size(); ++input_i) { - if (tmp_output.first == input_list[input_i]) { - output_desc_json = inputs_json[input_i][0]; - found = true; - break; - } - } - if (!found) { - auto type_id = AnfAlgo::GetOutputDeviceDataType(tmp_output.first, tmp_output.second); - std::string dtype = TypeId2String(type_id); - output_desc_json[kTensorName] = - GetTensorName((*node_json_map)[tmp_output.first], kOutputDesc, std::make_pair(0, tmp_output.second)); - output_desc_json[kDataType] = dtype; - auto output_shape = AnfAlgo::GetOutputDeviceShape(tmp_output.first, tmp_output.second); - if (output_shape.empty()) { - output_shape.push_back(1); - } - output_desc_json[kShape] = output_shape; - } - outputs_json.emplace_back(output_desc_json); - } - - return outputs_json; -} - -std::pair, std::vector>> PreProcessJsonForBuild( - const std::vector> &build_args) { - // Remove cached nodes, gether unique nodes, and collect repeated nodes which need postprecess. - std::vector jsons; - std::vector> repeat_nodes; - std::unordered_set json_name_set; - for (const auto &[builder, anf_node] : build_args) { - MS_EXCEPTION_IF_NULL(anf_node); - auto json_name = builder.json_name(); - MS_LOG(DEBUG) << "Akg start compile op: " << json_name; - auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (cached_kernel_pack != nullptr) { - MS_LOG(DEBUG) << "Use cached kernel, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - continue; - } - - if (json_name_set.count(json_name) != 0) { - repeat_nodes.push_back({builder, anf_node}); - continue; - } - json_name_set.insert(json_name); - auto node_json = builder.kernel_json(); - kernel::SaveJsonInfo(json_name, node_json); - jsons.push_back(node_json); - } - - return std::make_pair(jsons, repeat_nodes); -} - -bool PostProcessAfterCompile(const std::vector> &build_args, - const std::vector> &repeat_nodes) { - for (const auto &[builder, anf_node] : build_args) { - auto json_name = builder.json_name(); - auto new_kernel_pack = tbe::TbeUtils::InsertCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (new_kernel_pack == nullptr) { - MS_LOG(ERROR) << "Insert to cache failed, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - return false; - } - auto kernel_mod_ptr = std::make_shared(new_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - MS_LOG(DEBUG) << "Akg compile " << json_name << " kernel and insert cache successfully!"; - } - - for (const auto &[builder, anf_node] : repeat_nodes) { - auto node_json = builder.kernel_json(); - auto json_name = builder.json_name(); 
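- // A repeated node reuses the kernel pack that its identical twin has just compiled and cached.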
- auto cached_kernel_pack = tbe::TbeUtils::SearchCache(json_name, AkgKernelBuild::GetProcessor(anf_node)); - if (cached_kernel_pack == nullptr) { - return false; - } - MS_LOG(INFO) << "Use just compiled kernel, json_name_[" << json_name << "], fullname_with_scope[" - << anf_node->fullname_with_scope() << "]."; - auto kernel_mod_ptr = std::make_shared(cached_kernel_pack); - kernel_mod_ptr->SetInputSizeList(builder.input_size_list()); - kernel_mod_ptr->SetOutputSizeList(builder.output_size_list()); - AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); - } - - return true; -} -} // namespace - -bool AkgAscendKernelBuilder::CollectJson(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "AKG start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - MS_LOG(INFO) << "Akg start compile, op[" << op_name << "], device[" << AkgKernelBuild::GetProcessor(anf_node) << "]"; - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op[" << op_name << "] create single kernel json failed."; - } - - kernel_json_ = node_json.dump(); - - if (!GetIOSize(node_json, &input_size_list_, &output_size_list_)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return false; - } - - return true; -} - -bool AkgAscendKernelBuilder::GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, - std::map *node_json_map) { - for (auto const &anf_node : anf_nodes) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (!AnfAlgo::IsRealKernel(anf_node)) { - MS_LOG(ERROR) << "Invalid anf node to build [" << anf_node->fullname_with_scope() << "]."; - return false; - } - auto it = kAkgKernelAttrsProcessMap.find(op_name); - if (it != kAkgKernelAttrsProcessMap.end()) { - it->second(anf_node); - } - - nlohmann::json node_json; - if (!GenerateSingleKernelJson(anf_node, op_name, &node_json)) { - MS_LOG(ERROR) << "Op [" << op_name << "] create single kernel json failed."; - return false; - } - // No need for composite op. 
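-    // Strip the per-node bookkeeping fields; equivalents are regenerated for the fused op as a whole.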
- node_json.erase("id"); - node_json.erase("op"); - node_json.erase("composite"); - - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - - if (primitive->GetAttr("fusion") != nullptr) { - node_json["fusion"] = primitive->GetAttr("fusion")->ToString(); - } - - (*node_json_map)[anf_node] = node_json; - } - return true; -} - -bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf_nodes, - const std::vector &input_list, - const std::vector &output_list) { - if (anf_nodes.empty() || input_list.empty()) { - MS_LOG(ERROR) << "Invalid input size, anf_nodes [" << anf_nodes.size() << "], input_list [" << input_list.size() - << "]."; - return false; - } - MS_LOG(INFO) << "anf_nodes [" << output_list.size() << "], input_list [" << anf_nodes.size() << "], output_list [" - << input_list.size() << "]."; - - std::map node_json_map; - if (!GenJsonAndPreprocess4Fused(anf_nodes, &node_json_map)) { - return false; - } - - UpdateTensorNameInJson(anf_nodes, &node_json_map); - - nlohmann::json fused_node_json; - std::vector node_json_desc; - std::transform(anf_nodes.begin(), anf_nodes.end(), std::back_inserter(node_json_desc), - [&node_json_map](const AnfNodePtr &anf_node) { return node_json_map[anf_node]; }); - fused_node_json[kOpDesc] = node_json_desc; - fused_node_json[kInputDesc] = GetInputsJson(anf_nodes, input_list, &node_json_map); - fused_node_json[kOutputDesc] = - GetOutputsJson(anf_nodes, input_list, output_list, fused_node_json[kInputDesc], &node_json_map); - - size_t hash_id = std::hash()(fused_node_json.dump()); - json_name_ = "Fused_"; - auto fg = anf_nodes[0]->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (attr_val != nullptr) { - auto fg_attr = GetValue(attr_val); - (void)json_name_.append(fg_attr).append("_"); - } - (void)json_name_.append(std::to_string(hash_id)); - fused_node_json["composite_graph"] = fg->ToString(); - fused_node_json["op"] = json_name_; - fused_node_json["platform"] = "AKG"; - fused_node_json["process"] = "aicore"; - fused_node_json["composite"] = true; - - kernel_json_ = fused_node_json.dump(); - - if (!GetIOSize(fused_node_json, &input_size_list_, &output_size_list_)) { - MS_LOG(ERROR) << "Cal mem size failed."; - return false; - } - - return true; -} - -void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { - MS_EXCEPTION_IF_NULL(p_args); - *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); - - PyObject *arg1 = PyList_New(kernel_jsons.size()); - for (int i = 0; i < PyList_Size(arg1); ++i) { - PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); - } - PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); - PyObject *arg3 = Py_BuildValue("i", TIME_OUT); - - (void)PyTuple_SetItem(*p_args, 0, arg1); - (void)PyTuple_SetItem(*p_args, 1, arg2); - (void)PyTuple_SetItem(*p_args, 2, arg3); -} - -bool AkgOpParallelBuild(const std::vector> &build_args) { - auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); - if (jsons.empty()) { - return true; - } - - // Try to call python method to compile nodes parallely. 
- PyObject *p_module = nullptr; - PyObject *p_func = nullptr; - PyObject *p_arg = nullptr; - PyObject *p_res = nullptr; - - p_module = PyImport_ImportModule(kMultiProcModule); - if (p_module == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; - return false; - } - - p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); - GenParallelCompileFuncArgs(jsons, &p_arg); - MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() - << " Akg kernels parallelly."; - p_res = PyEval_CallObject(p_func, p_arg); - if (p_res == nullptr) { - PyErr_Print(); - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; - return false; - } - if (PyObject_IsTrue(p_res) != 1) { - PyErr_Print(); - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; - return false; - } - - if (!PostProcessAfterCompile(build_args, repeat_nodes)) { - return false; - } - - return true; -} - -bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes) { - std::vector> json_and_node; - for (const auto &anf_node : anf_nodes) { - MS_EXCEPTION_IF_NULL(anf_node); - AkgAscendKernelBuilder akg_cce_kernel_builder; - KernelPackPtr kernel_pack = nullptr; - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsGraphKernel(cnode)) { - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - auto mng = func_graph->manager(); - if (mng == nullptr) { - mng = Manage(func_graph, true); - func_graph->set_manager(mng); - } - MS_EXCEPTION_IF_NULL(func_graph); - std::vector node_list; - std::vector input_list; - std::vector output_list; - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "Akg start compile composite op[" << op_name << "]"; - GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - if (!akg_cce_kernel_builder.CollectFusedJson(node_list, input_list, output_list)) { - MS_EXCEPTION(UnknownError) << "Akg build failed composite op[" << op_name << "]."; - } - } else { - if (!akg_cce_kernel_builder.CollectJson(anf_node)) { - MS_EXCEPTION(UnknownError) << "Akg build failed op[" << AnfAlgo::GetCNodeName(anf_node) << "]."; - } - } - json_and_node.push_back({akg_cce_kernel_builder, anf_node}); - } - - if (json_and_node.empty()) { - MS_LOG(DEBUG) << "There is no kernel needed to be compiled."; - return true; - } - - return AkgOpParallelBuild(json_and_node); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h deleted file mode 100644 index 01752911ed..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_build.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" - -namespace mindspore { -namespace kernel { -class AkgAscendKernelBuilder : public AkgKernelBuild { - public: - AkgAscendKernelBuilder() = default; - ~AkgAscendKernelBuilder() = default; - - bool CollectJson(const AnfNodePtr &anf_node); - bool CollectFusedJson(const std::vector &anf_nodes, const std::vector &input_list, - const std::vector &output_list); - std::string json_name() const { return json_name_; } - std::string kernel_json() const { return kernel_json_; } - const std::vector &input_size_list() const { return input_size_list_; } - const std::vector &output_size_list() const { return output_size_list_; } - - private: - bool GenJsonAndPreprocess4Fused(const std::vector &anf_nodes, - std::map *node_json_map); - - std::string kernel_json_; - std::vector input_size_list_; - std::vector output_size_list_; -}; - -bool AkgAscendKernelParallelBuild(const std::vector &anf_nodes); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc deleted file mode 100644 index 101a9f79b6..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/akg/ascend/akg_ascend_kernel_mod.h" -#include -#include -#include -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "runtime/rt.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -using std::fstream; -using std::map; -using std::mutex; -using std::string; -using TbeTaskInfoPtr = std::shared_ptr; -using tbe::KernelManager; -constexpr uint32_t DEFAULT_BLOCK_DIM = 1; -/** - * @brief infotable contain func_stub\blockdim\kernel file buffer - */ -AkgKernelMod::AkgKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} - -void AkgKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - -void AkgKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - -void AkgKernelMod::SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } - -const std::vector &AkgKernelMod::GetInputSizeList() const { return input_size_list_; } - -const std::vector &AkgKernelMod::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &AkgKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool AkgKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - - uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); - if (func_stub == 0) { - MS_LOG(ERROR) << "GenFuncStub failed."; - return false; - } - - // pack all addresses into a vector. - std::vector runtime_args; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtime_args), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtime_args), - [](const AddressPtr &output) -> void * { return output->addr; }); - - rtL2Ctrl_t *l2ctrl = nullptr; - auto stream = reinterpret_cast(stream_ptr); - if (RT_ERROR_NONE != rtKernelLaunch(reinterpret_cast(func_stub), block_dim, runtime_args.data(), - SizeToUint(sizeof(void *) * runtime_args.size()), l2ctrl, stream)) { - MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; - return false; - } - - return true; -} - -std::vector AkgKernelMod::GenTask(const std::vector &inputs, const std::vector &, - const std::vector &outputs, uint32_t stream_id) { - if (kernel_pack_ == nullptr) { - MS_LOG(EXCEPTION) << "kernel pack should not be nullptr."; - } - - std::vector args; - const uint32_t args_size = 0; - std::vector sm_desc; - void *binary = nullptr; - const uint32_t binary_size = 0; - std::vector meta_data; - std::vector input_data_addrs; - std::vector output_data_addrs; - std::vector workspace_addrs; - - // pack all addresses into a vector. 
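- // Inputs first, then outputs; the order must match the compiled kernel's parameter layout.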
- (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - - uint32_t block_dim = DEFAULT_BLOCK_DIM; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); - if (func_stub == 0) { - MS_LOG(EXCEPTION) << "GenFuncStub failed."; - } - - std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); - - MS_LOG(DEBUG) << "The block_dim is:" << block_dim; - - TbeTaskInfoPtr task_info_ptr = make_shared( - kernel_name_, stream_id, stub_func, block_dim, args, args_size, sm_desc, binary, binary_size, meta_data, - input_data_addrs, output_data_addrs, workspace_addrs, NeedDump()); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h b/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h deleted file mode 100644 index 18d342f629..0000000000 --- a/mindspore/ccsrc/kernel/akg/ascend/akg_ascend_kernel_mod.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -class AkgKernelMod : public AscendKernelMod { - public: - explicit AkgKernelMod(const KernelPackPtr &kernel_pack); - ~AkgKernelMod() final {} - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - void SetWorkspaceSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using AkgKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_ASCEND_AKG_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc deleted file mode 100644 index 534e355802..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.cc +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/akg/gpu/akg_gpu_kernel_build.h" -#include -#include -#include "kernel/kernel.h" -#include "kernel/akg/akg_kernel_build.h" -#include "kernel/akg/gpu/akg_gpu_kernel_mod.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - AkgKernelBuild akg_kernel_build; - - std::vector input_size_list; - std::vector output_size_list; - KernelPackPtr kernel_pack = akg_kernel_build.BuildByJson(anf_node, &input_size_list, &output_size_list); - MS_EXCEPTION_IF_NULL(kernel_pack); - - auto kernel_mod_ptr = std::make_shared(kernel_pack); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); - kernel_mod_ptr->SetInputSizeList(input_size_list); - kernel_mod_ptr->SetOutputSizeList(output_size_list); - return kernel_mod_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h deleted file mode 100644 index d615890737..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_build.h +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ -#include "kernel/kernel.h" -#include "base/base.h" - -namespace mindspore { -namespace kernel { -KernelModPtr AkgGpuKernelBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc deleted file mode 100644 index 64590cd9b8..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.cc +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/akg/gpu/akg_gpu_kernel_mod.h" -#include -#include -#include "nlohmann/json.hpp" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -using std::fstream; -using std::string; -using std::vector; - -GpuKernelManagerPtr GpuKernelMod::kernelmanager_ = std::make_shared(); -GpuKernelManager::GpuKernelManager() {} - -CUresult GpuKernelManager::GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, - vector *thread_info, CUfunction *func) { - if (kernel_pack->GetJson() == nullptr || kernel_pack->GetJson()->contents == nullptr || - kernel_pack->GetKernel() == nullptr || kernel_pack->GetKernel()->contents == nullptr) { - MS_LOG(ERROR) << "GPU:Invalid kernel pack, json or kernel is nullptr."; - return CUDA_ERROR_INVALID_IMAGE; - } - auto js = nlohmann::json::parse(kernel_pack->GetJson()->contents, - kernel_pack->GetJson()->contents + kernel_pack->GetJson()->len); - string fn = js["kernelName"]; - if (!force_reload) { - auto iter = infotable_.find(fn); - if (iter != infotable_.end()) { - auto kernelmeta = iter->second; - *thread_info = kernelmeta->thread_info_; - *func = kernelmeta->func_addr_; - return CUDA_SUCCESS; - } - } - thread_info->emplace_back(js["blockIdx.x"]); - thread_info->emplace_back(js["blockIdx.y"]); - thread_info->emplace_back(js["blockIdx.z"]); - thread_info->emplace_back(js["threadIdx.x"]); - thread_info->emplace_back(js["threadIdx.y"]); - thread_info->emplace_back(js["threadIdx.z"]); - CUmodule module; - CUresult result = cuModuleLoadData(&module, kernel_pack->GetKernel()->contents); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "cuModuleLoadData failed."; - return result; - } - result = cuModuleGetFunction(func, module, fn.c_str()); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "cuModuleGetFunction failed."; - return result; - } - infotable_[fn] = std::make_shared(*func, module, *thread_info); - return result; -} - -GpuKernelMod::GpuKernelMod(const KernelPackPtr &kernel_pack) : kernel_pack_(kernel_pack) {} - -void GpuKernelMod::SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - -void GpuKernelMod::SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - -const std::vector &GpuKernelMod::GetInputSizeList() const { return input_size_list_; } - -const std::vector &GpuKernelMod::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &GpuKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool GpuKernelMod::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == 0) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - vector thread_info; - CUfunction kernel_addr; - CUresult result = kernelmanager_->GetFunction(kernel_pack_, false, &thread_info, &kernel_addr); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "GetFunction failed."; - return false; - } - std::vector runtimeargs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), - [](const AddressPtr &input) -> void * { return reinterpret_cast(&(input->addr)); }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), - [](const AddressPtr &output) -> void * { return reinterpret_cast(&(output->addr)); }); - result = cuLaunchKernel(kernel_addr, thread_info[0], thread_info[1], 
thread_info[2], thread_info[3], thread_info[4], - thread_info[5], 0, reinterpret_cast(stream_ptr), - reinterpret_cast(&runtimeargs[0]), 0); - if (result != CUDA_SUCCESS) { - MS_LOG(ERROR) << "Launch Kernel failed."; - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h b/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h deleted file mode 100644 index df9cb069f7..0000000000 --- a/mindspore/ccsrc/kernel/akg/gpu/akg_gpu_kernel_mod.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ -#include -#include -#include -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -struct GpuKernelMeta { - CUfunction func_addr_; - CUmodule module_; - std::vector thread_info_; - GpuKernelMeta(CUfunction funcAddr, CUmodule module, const std::vector &thread_info) - : func_addr_(funcAddr), module_(module), thread_info_(thread_info) {} -}; -using GpuKernelMetaPtr = std::shared_ptr; - -class GpuKernelManager { - public: - GpuKernelManager(); - virtual ~GpuKernelManager() { - for (auto iter = infotable_.begin(); iter != infotable_.end(); ++iter) { - CUresult ret = cuModuleUnload(iter->second->module_); - if (ret != CUDA_SUCCESS && ret != CUDA_ERROR_DEINITIALIZED) { - MS_LOG(ERROR) << "Unload GPU Module failed."; - } - } - } - CUresult GetFunction(const KernelPackPtr &kernel_pack, bool force_reload, std::vector *thread_info, - CUfunction *func); - - private: - std::unordered_map infotable_; -}; -using GpuKernelManagerPtr = std::shared_ptr; - -class GpuKernelMod : public KernelMod { - public: - explicit GpuKernelMod(const KernelPackPtr &kernel_pack); - virtual ~GpuKernelMod() {} - - void SetInputSizeList(const std::vector &size_list); - void SetOutputSizeList(const std::vector &size_list); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - static GpuKernelManagerPtr kernelmanager_; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using GpuKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_AKG_GPU_AKG_GPU_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/ascend_kernel_mod.h b/mindspore/ccsrc/kernel/ascend_kernel_mod.h deleted file mode 100644 index 1ca1dbacc8..0000000000 --- a/mindspore/ccsrc/kernel/ascend_kernel_mod.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 
2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ - -#include -#include -#include "framework/ge_runtime/task_info.h" -#include "kernel/kernel.h" -#ifdef ENABLE_DATA_DUMP -#include "debug/data_dump_parser.h" -#endif - -using TaskInfoPtr = std::shared_ptr; -namespace mindspore { -namespace kernel { -class AscendKernelMod : public KernelMod { - public: - virtual std::vector GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t) = 0; - uint32_t block_dim() { return block_dim_; } - uint32_t stream_id() { return stream_id_; } - virtual bool NeedDump() { -#ifdef ENABLE_DATA_DUMP - return DataDumpParser::GetInstance().NeedDump(kernel_name_); -#else - return false; -#endif - } - - protected: - uint32_t block_dim_{1}; - uint32_t stream_id_{0}; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_ASCEND_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/common_utils.cc b/mindspore/ccsrc/kernel/common_utils.cc deleted file mode 100644 index d42e887bbc..0000000000 --- a/mindspore/ccsrc/kernel/common_utils.cc +++ /dev/null @@ -1,1029 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/common_utils.h" -#include -#include -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "ir/manager.h" -#include "ir/meta_tensor.h" -#include "ir/func_graph.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" - -namespace mindspore { -namespace kernel { -constexpr char kAxis[] = "axis"; -constexpr char kTypeInt32[] = "Int32"; -const std::unordered_map type_id_maps = { - {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, - {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, - {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, - {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, - {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, - {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, - {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, - {"bool", TypeId::kNumberTypeBool}, -}; - -const std::map type_id_str_map = { - {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, - {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, - {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, - {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, - {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, - {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, - {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, - {TypeId::kNumberTypeBool, "bool"}, -}; - -const std::unordered_map dtype_shortdtype_map_ = { - {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"}, - {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"}, -}; - -const std::unordered_map dtype_nbyte_map = { - {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, - {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, - {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, - {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, -}; - -const std::unordered_map fusion_type_maps = { - {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, - {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE}, -}; - -void KernelMeta::Initialize() { - kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/"; - // remove old kernel cache - RemoveKernelCache(); - -#if defined(_WIN32) || defined(_WIN64) - auto ret = mkdir(kernel_meta_path_.c_str()); -#else - auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU); -#endif - if (ret != 0) { - MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later"; - } - initialized_ = true; -} - -void KernelMeta::RemoveKernelCache() { - DIR *dir = opendir(kernel_meta_path_.c_str()); - if (dir == nullptr) { - return; - } - struct dirent *entry; - while ((entry = readdir(dir)) != nullptr) { - std::string kernel_file = entry->d_name; - std::string kernel_file_realpath = kernel_meta_path_ + kernel_file; - (void)remove(kernel_file_realpath.c_str()); - } - (void)closedir(dir); - (void)rmdir(kernel_meta_path_.c_str()); -} - -std::string 
KernelMeta::Search(const std::string &kernel_name) const {
-  if (!initialized_) {
-    return "";
-  }
-
-  auto iter = kernel_meta_map_.find(kernel_name);
-  if (iter == kernel_meta_map_.end()) {
-    return "";
-  } else {
-    return iter->second;
-  }
-}
-
-bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
-  if (!initialized_) {
-    return false;
-  }
-  kernel_meta_map_[kernel_name] = kernel_json;
-  return true;
-}
-
-bool CheckCache(const std::string &kernel_name) {
-  // check cache.
-  KernelMeta *bin_map = KernelMeta::GetInstance();
-  if (bin_map == nullptr) {
-    MS_LOG(DEBUG) << "kernel cache is invalid.";
-    return false;
-  }
-  std::string kernel_json = bin_map->Search(kernel_name);
-  bool ret = (!kernel_json.empty());
-  if (ret) {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has been registered.";
-  } else {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will be registered.";
-  }
-  return ret;
-}
-
-KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
-  // search cache.
-  KernelMeta *bin_map = KernelMeta::GetInstance();
-  if (bin_map == nullptr) {
-    MS_LOG(DEBUG) << "kernel cache is invalid.";
-    return nullptr;
-  }
-
-  std::string kernel_json = bin_map->Search(kernel_name);
-  if (!kernel_json.empty()) {
-    KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
-    // just a temporary solution.
-    if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
-      MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
-      return nullptr;
-    } else {
-      return kernel_pack;
-    }
-  } else {
-    MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
-    return nullptr;
-  }
-}
-
-KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
-  MS_LOG(INFO) << "kernel name:" << kernel_name << ", processor:" << processor;
-  KernelMeta *bin_map = KernelMeta::GetInstance();
-  std::string kernel_json;
-  if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
-    kernel_json = kCceKernelMeta;
-  } else {
-    kernel_json = bin_map->GetKernelMetaPath();
-  }
-  (void)kernel_json.append(kernel_name).append(kJsonSuffix);
-  KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
-  if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
-    MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
-    return nullptr;
-  }
-
-  if (bin_map == nullptr) {
-    MS_LOG(DEBUG) << "kernel cache is invalid.";
-    return nullptr;
-  }
-  if (bin_map->Insert(kernel_name, kernel_json)) {
-    MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
-  }
-  return kernel_pack;
-}
-
-TypeId DtypeToTypeId(const std::string &dtypes) {
-  auto iter = type_id_maps.find(dtypes);
-  if (iter != type_id_maps.end()) {
-    return iter->second;
-  } else {
-    MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
-  }
-}
-
-std::string TypeId2String(TypeId type_id) {
-  auto iter = type_id_str_map.find(type_id);
-  if (iter == type_id_str_map.end()) {
-    return std::string(TypeIdLabel(type_id));
-  }
-  return iter->second;
-}
-
-std::string Dtype2ShortType(const std::string &dtypes) {
-  auto iter = dtype_shortdtype_map_.find(dtypes);
-  if (iter != dtype_shortdtype_map_.end()) {
-    return iter->second;
-  } else {
-    MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
-  }
-}
-
-size_t GetDtypeNbyte(const std::string &dtypes) {
-  auto iter = dtype_nbyte_map.find(dtypes);
-  if (iter != dtype_nbyte_map.end()) {
-    return iter->second;
-  } else {
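The deleted cache helpers (CheckCache / SearchCache / InsertCache) all reduce to a probe-then-register flow around the KernelMeta singleton. A minimal standalone sketch of that flow, using a plain map and hypothetical names instead of the real KernelMeta/KernelPack API:

// Sketch of the probe-then-register kernel cache flow. Hypothetical standalone
// rewrite; the real code keys a KernelMeta singleton and loads .json/.bin
// kernel artifacts from disk.
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

class KernelJsonCache {
 public:
  std::optional<std::string> Search(const std::string &name) const {
    auto it = cache_.find(name);
    if (it == cache_.end()) return std::nullopt;
    return it->second;
  }
  void Insert(const std::string &name, const std::string &json_path) { cache_[name] = json_path; }

 private:
  std::unordered_map<std::string, std::string> cache_;
};

int main() {
  KernelJsonCache cache;
  if (!cache.Search("Add_fp32")) {
    // Cache miss: build the kernel, then register the path of its json artifact.
    cache.Insert("Add_fp32", "./kernel_meta/Add_fp32.json");
  }
  std::cout << *cache.Search("Add_fp32") << std::endl;  // ./kernel_meta/Add_fp32.json
  return 0;
}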
MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes; - } -} - -bool SetInputKernelBuilderInfo(const std::vector> &inputs, size_t real_input_num, - size_t builder_idex, const std::vector &dyn_input_sizes, - const std::shared_ptr &builder) { - MS_EXCEPTION_IF_NULL(builder); - - std::vector inputs_device_type; - std::vector inputs_format; - size_t dyn_input_idx = 0; - size_t kernel_info_index = 0; - MS_EXCEPTION_IF_NULL(inputs[0]); - size_t kernel_info_cnt = inputs[0]->dtypes().size(); - - for (const auto &input : inputs) { - MS_EXCEPTION_IF_NULL(input); - std::string param_type = input->param_type(); - std::vector dtypes = input->dtypes(); - std::vector formats = input->formats(); - if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { - MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size."; - return false; - } - - if (param_type == "dynamic") { - if (dyn_input_sizes.empty()) { - MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; - return false; - } - - for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) { - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } - dyn_input_idx++; - } else if (param_type == "required") { - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } else { - if (kernel_info_index < real_input_num) { - MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index; - kernel_info_index++; - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - inputs_device_type.push_back(type_id); - inputs_format.push_back(formats[builder_idex]); - } - } - } - - builder->SetInputsDeviceType(inputs_device_type); - builder->SetInputsFormat(inputs_format); - return true; -} - -bool SetOutputKernelBuilderInfo(const std::vector> &outputs, size_t builder_idex, - const size_t &real_output_num, - const std::shared_ptr &builder) { - // not now but in the next we need to support dynamic output case - MS_EXCEPTION_IF_NULL(builder); - - size_t output_idx = 0; - std::vector outputs_device_type; - std::vector outputs_format; - MS_EXCEPTION_IF_NULL(outputs[0]); - size_t kernel_info_cnt = outputs[0]->dtypes().size(); - - for (const auto &output : outputs) { - MS_EXCEPTION_IF_NULL(output); - if (output_idx >= real_output_num) { - MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!"; - continue; - } - size_t output_num = 0; - if (output->param_type() == "dynamic") { - if (outputs.size() > 1) { - MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!"; - } - output_num = real_output_num; - } else if (output->param_type() == "required") { - output_num = 1; - } else { - if (output_idx < real_output_num) { - MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx; - output_num = 1; - } - } - - for (size_t i = 0; i < output_num; i++) { - std::vector dtypes = output->dtypes(); - std::vector formats = output->formats(); - if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) { - MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size."; - return false; - } - auto type_id = DtypeToTypeId(dtypes[builder_idex]); - 
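The dynamic-input handling in SetInputKernelBuilderInfo boils down to this: a formal input whose param_type is "dynamic" stands for dyn_input_sizes[k] real inputs, so its (dtype, format) pair is repeated that many times. A standalone sketch under simplified assumed types, not the real OpIOInfo/KernelBuildInfo API:

// Standalone sketch of the dyn_input_sizes expansion: a "dynamic" formal input
// covers dyn_input_sizes[k] real inputs, so its dtype/format are repeated.
#include <iostream>
#include <string>
#include <vector>

struct FormalInput {
  std::string param_type;  // "dynamic" or "required"
  std::string dtype;
  std::string format;
};

int main() {
  std::vector<FormalInput> formals = {{"dynamic", "float32", "NCHW"}, {"required", "int32", "DefaultFormat"}};
  std::vector<int> dyn_input_sizes = {3};  // the one dynamic formal covers 3 real inputs

  std::vector<std::string> dtypes, formats;
  size_t dyn_idx = 0;
  for (const auto &f : formals) {
    int repeat = (f.param_type == "dynamic") ? dyn_input_sizes[dyn_idx++] : 1;
    for (int i = 0; i < repeat; ++i) {
      dtypes.push_back(f.dtype);
      formats.push_back(f.format);
    }
  }
  for (size_t i = 0; i < dtypes.size(); ++i) {
    std::cout << i << ": " << dtypes[i] << " / " << formats[i] << std::endl;
  }
  return 0;  // prints 4 entries: three float32/NCHW, one int32/DefaultFormat
}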
outputs_device_type.push_back(type_id); - outputs_format.push_back(formats[builder_idex]); - output_idx++; - } - } - - builder->SetOutputsFormat(outputs_format); - builder->SetOutputsDeviceType(outputs_device_type); - return true; -} - -void SetKernelBuildInfo(const std::shared_ptr &builder, Processor processor, - const std::shared_ptr &op_info_ptr) { - MS_EXCEPTION_IF_NULL(builder); - MS_EXCEPTION_IF_NULL(op_info_ptr); - - auto imply_type = op_info_ptr->imply_type(); - builder->SetProcessor(processor); - std::string fusion_type = op_info_ptr->fusion_type(); - auto iter = fusion_type_maps.find(fusion_type); - if (iter != fusion_type_maps.end()) { - builder->SetFusionType(iter->second); - } else { - if (imply_type == kAKG) { - MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type; - } - } - - if (imply_type == kAKG) { - builder->SetKernelType(AKG_KERNEL); - } else if (imply_type == kAICPU) { - builder->SetKernelType(AICPU_KERNEL); - } else { - builder->SetKernelType(TBE_KERNEL); - } -} - -bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, - std::vector> *const kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_info_list); - size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - std::vector> inputs = op_info_ptr->inputs_ptr(); - std::vector> outputs = op_info_ptr->outputs_ptr(); - std::vector dyn_input_sizes; - auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(primitive); - if (primitive->GetAttr("dyn_input_sizes") != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr("dyn_input_sizes")); - } - if (inputs.size() > 0) { - MS_EXCEPTION_IF_NULL(inputs[0]); - size_t kernel_info_cnt = inputs[0]->dtypes().size(); - for (size_t j = 0; j < kernel_info_cnt; j++) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - - if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed."; - return false; - } - - if (outputs.size() > 0) { - if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; - return false; - } - } - - kernel_info_list->push_back(builder->Build()); - } - } else if (outputs.size() > 0) { - MS_EXCEPTION_IF_NULL(outputs[0]); - size_t kernel_info_cnt = outputs[0]->dtypes().size(); - for (size_t j = 0; j < kernel_info_cnt; j++) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - - if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) { - MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed."; - return false; - } - - kernel_info_list->push_back(builder->Build()); - } - } else { - if (processor == AICPU) { - auto builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(builder); - SetKernelBuildInfo(builder, processor, op_info_ptr); - kernel_info_list->push_back(builder->Build()); - } - } - return true; -} - -void SaveJsonInfo(const std::string &json_name, const std::string &info) { - char real_path[PATH_MAX] = {0}; - std::string path = kCceKernelMeta + json_name + kInfoSuffix; - if (path.size() > PATH_MAX) { - MS_LOG(DEBUG) << "file path " << path << " is too 
long."; - return; - } - std::ofstream filewrite; - filewrite.open(path); - if (!filewrite.is_open()) { - return; - } - filewrite << info << std::endl; - filewrite.close(); -#if defined(_WIN32) || defined(_WIN64) - if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) { - MS_LOG(DEBUG) << "dir " << path << " does not exit."; - return; - } -#else - if (nullptr == realpath(path.c_str(), real_path)) { - MS_LOG(DEBUG) << "dir " << path << " does not exit."; - return; - } -#endif - MS_LOG(INFO) << "real path is :" << real_path; - if (chmod(real_path, S_IRUSR) == -1) { - MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail."; - } -} - -std::string GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - - case Processor::AICPU: - device = kProcessorAiCpu; - break; - - case Processor::CUDA: - device = kProcessorCuda; - break; - - default: - MS_LOG(DEBUG) << "Unknown processor type."; - break; - } - return device; -} - -bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b) { - if (shape_a.size() != shape_b.size()) { - return false; - } - for (size_t i = 0; i < shape_a.size(); ++i) { - if (shape_a[i] != shape_b[i]) { - return false; - } - } - return true; -} - -int Sign(float x) { - if (x > 0) { - return 1; - } - if (x < 0) { - return -1; - } - return 0; -} - -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::unordered_map index_map; - size_t unique_indices_size = 0; - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= first_dim) { - continue; - } - auto iter = index_map.find(index); - if (iter == index_map.end()) { - index_map[index] = unique_indices_size; - unique_grad->indices_[unique_indices_size] = index; - size_t start_index = unique_indices_size * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] = origin_sparse_grad.value_[k]; - } - unique_indices_size++; - } else { - size_t first_index = iter->second; - size_t start_index = first_index * outer_dim; - size_t end_index = start_index + outer_dim; - for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) { - unique_grad->value_[j] += origin_sparse_grad.value_[k]; - } - } - } - unique_grad->indices_size_ = unique_indices_size; -} - -struct WorkerParamsForReduceSparseGradient { - size_t slice_start_{0}; - size_t slice_end_{0}; - size_t max_length_{0}; - size_t outer_dim_{0}; - std::vector> *sorted_indices_{nullptr}; - std::vector *slice_positions_{nullptr}; - float *src_value_{nullptr}; - SparseGradient *unique_grad_{nullptr}; -}; - -void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) { - MS_EXCEPTION_IF_NULL(param.sorted_indices_); - MS_EXCEPTION_IF_NULL(param.slice_positions_); - MS_EXCEPTION_IF_NULL(param.src_value_); - MS_EXCEPTION_IF_NULL(param.unique_grad_); - auto outer_dim = param.outer_dim_; - auto &sorted_indices = *(param.sorted_indices_); - auto &slice_positions = 
*(param.slice_positions_); - auto unique_grad = param.unique_grad_; - for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) { - size_t cur_pos = slice_positions[slice_id]; - int index = sorted_indices[cur_pos].first; - unique_grad->indices_[slice_id] = index; - size_t start_index = slice_id * outer_dim; - auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float), - param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - cur_pos++; - size_t end_pos; - if (slice_id + 1 < slice_positions.size()) { - end_pos = slice_positions[slice_id + 1]; - } else { - end_pos = sorted_indices.size(); - } - while (cur_pos < end_pos) { - for (size_t i = 0; i < outer_dim; ++i) { - unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i]; - } - cur_pos++; - } - } -} - -void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, - size_t outer_dim, std::vector> *sorted_indices, - std::vector *slice_positions) { - MS_LOG(DEBUG) << "Start"; - size_t thread_num = 24; - if (slice_positions->size() < thread_num) { - thread_num = slice_positions->size(); - } - size_t stride = (slice_positions->size() + thread_num - 1) / thread_num; - thread_num = (slice_positions->size() + stride - 1) / stride; - std::vector threads; - size_t max_length = sorted_indices->size() * outer_dim; - for (size_t i = 0; i < thread_num; ++i) { - size_t slice_start = i * stride; - size_t slice_end = 0; - if (i == thread_num - 1) { - slice_end = slice_positions->size(); - } else { - slice_end = slice_start + stride; - } - WorkerParamsForReduceSparseGradient params{ - slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_, - unique_grad}; - threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params)); - } - for (size_t i = 0; i < thread_num; ++i) { - threads[i].join(); - } - MS_LOG(DEBUG) << "End"; -} - -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim, bool use_multi_threads) { - MS_LOG(DEBUG) << "Start"; - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - std::vector> sorted_indices; - sorted_indices.reserve(origin_sparse_grad.indices_size_); - for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) { - int index = origin_sparse_grad.indices_[i]; - if (index >= 0 && IntToSize(index) < first_dim) { - sorted_indices.emplace_back(std::pair(index, i * outer_dim)); - } - } - std::sort( - sorted_indices.begin(), sorted_indices.end(), - [](const std::pair &left, const std::pair &right) { return left.first < right.first; }); - int last_index = 0; - std::vector slice_positions; - slice_positions.reserve(sorted_indices.size()); - for (size_t i = 0; i < sorted_indices.size(); ++i) { - if (i == 0 || last_index != sorted_indices[i].first) { - slice_positions.emplace_back(i); - } - last_index = sorted_indices[i].first; - } - if (use_multi_threads) { - RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions); - } else { - size_t max_length = sorted_indices.size() * outer_dim; - 
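slice_positions in ReduceSparseGradient records where each run of equal indices begins in the sorted (index, offset) list; everything between two successive positions is accumulated into one output row. The core of that computation, as a runnable sketch over toy data:

// Minimal sketch of the slice_positions computation: record the position where
// each run of equal sorted indices starts.
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // (index, value-offset) pairs, already filtered and sorted by index.
  std::vector<std::pair<int, size_t>> sorted_indices = {{0, 8}, {0, 2}, {3, 0}, {7, 4}, {7, 6}};
  std::vector<size_t> slice_positions;
  int last_index = 0;
  for (size_t i = 0; i < sorted_indices.size(); ++i) {
    if (i == 0 || last_index != sorted_indices[i].first) {
      slice_positions.emplace_back(i);
    }
    last_index = sorted_indices[i].first;
  }
  for (auto p : slice_positions) {
    std::cout << p << " ";  // 0 2 3 -> one output row each for indices 0, 3, 7
  }
  std::cout << std::endl;
  return 0;
}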
WorkerParamsForReduceSparseGradient params{0, - slice_positions.size(), - max_length, - outer_dim, - &sorted_indices, - &slice_positions, - origin_sparse_grad.value_, - unique_grad}; - WorkerForReduceSparseGradient(params); - } - unique_grad->indices_size_ = slice_positions.size(); - MS_LOG(DEBUG) << "End"; -} - -void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, - SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim) { - MS_LOG(DEBUG) << "Start"; - if (unique_slice_grads.empty()) { - return; - } - size_t index_data_size = outer_dim * sizeof(float); - size_t unique_indices_size = 0; - for (size_t i = 0; i < unique_slice_grads.size(); ++i) { - auto &slice_grad = unique_slice_grads[i]; - auto ret_code = memcpy_s(tmp_grad->value_ + unique_indices_size * outer_dim, - (tmp_grad->indices_size_ - unique_indices_size) * index_data_size, slice_grad->value_, - slice_grad->indices_size_ * index_data_size); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - ret_code = - memcpy_s(tmp_grad->indices_ + unique_indices_size, (tmp_grad->indices_size_ - unique_indices_size) * sizeof(int), - slice_grad->indices_, slice_grad->indices_size_ * sizeof(int)); - if (ret_code != EOK) { - MS_LOG(EXCEPTION) << "Failed to copy data!"; - } - unique_indices_size += slice_grad->indices_size_; - } - tmp_grad->indices_size_ = unique_indices_size; - ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim); - MS_LOG(DEBUG) << "End"; -} - -void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, - SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) { - MS_LOG(DEBUG) << "Start"; - MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_); - MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_); - MS_EXCEPTION_IF_NULL(unique_grad); - MS_EXCEPTION_IF_NULL(unique_grad->value_); - MS_EXCEPTION_IF_NULL(unique_grad->indices_); - MS_EXCEPTION_IF_NULL(tmp_grad); - MS_EXCEPTION_IF_NULL(tmp_grad->value_); - MS_EXCEPTION_IF_NULL(tmp_grad->indices_); - size_t thread_num = 24; - if (origin_sparse_grad.indices_size_ < thread_num) { - thread_num = origin_sparse_grad.indices_size_; - } - size_t thread_indices_size = origin_sparse_grad.indices_size_ / thread_num; - size_t left_indices_size = origin_sparse_grad.indices_size_ % thread_num; - std::vector threads; - threads.reserve(thread_num); - std::vector> unique_slice_grads; - for (size_t i = 0; i < thread_num; ++i) { - size_t indices_size = thread_indices_size; - if (i == thread_num - 1) { - indices_size = thread_indices_size + left_indices_size; - } - size_t value_offset = i * thread_indices_size * outer_dim; - size_t indices_offset = i * thread_indices_size; - auto slice_grad = SparseGradient( - {origin_sparse_grad.value_ + value_offset, origin_sparse_grad.indices_ + indices_offset, indices_size}); - unique_slice_grads.emplace_back(std::make_shared()); - unique_slice_grads[i]->value_ = unique_grad->value_ + value_offset; - unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset; - unique_slice_grads[i]->indices_size_ = indices_size; - threads.emplace_back( - std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false)); - } - for (size_t i = 0; i < thread_num; ++i) { - threads[i].join(); - } - ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim); - MS_LOG(DEBUG) << "End"; -} - -std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index) { - 
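TwoLevelReduceSparseGradient splits the index range evenly across up to 24 threads, folding the division remainder into the last slice. A sketch of just that offset arithmetic, with assumed toy sizes:

// Sketch of the per-thread slice partitioning in TwoLevelReduceSparseGradient:
// even split, remainder to the last thread. Standalone rewrite of the math only.
#include <cstddef>
#include <iostream>

int main() {
  size_t indices_size = 100, thread_num = 24;
  if (indices_size < thread_num) thread_num = indices_size;
  size_t per_thread = indices_size / thread_num;
  size_t remainder = indices_size % thread_num;
  for (size_t i = 0; i < thread_num; ++i) {
    size_t offset = i * per_thread;
    size_t count = per_thread + (i == thread_num - 1 ? remainder : 0);
    std::cout << "thread " << i << ": [" << offset << ", " << offset + count << ")\n";
  }
  return 0;  // threads 0..22 take 4 indices each, thread 23 takes 4 + 4 = 8
}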
MS_EXCEPTION_IF_NULL(anf_node); - - if (index >= AnfAlgo::GetInputTensorNum(anf_node)) { - MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs."; - } - - auto cnode = anf_node->cast(); - if (cnode == nullptr) { - return AnfAlgo::VisitKernel(anf_node, 0); - } else { - return AnfAlgo::VisitKernel(anf_node->cast()->input(index + 1), 0); - } -} - -std::vector>> GetInputIndex(const std::vector &node_list, - const std::vector &input_list) { - std::vector>> input_index; - for (size_t i = 0; i < input_list.size(); ++i) { - auto const &input = input_list[i]; - MS_EXCEPTION_IF_NULL(input); - bool found = false; - // using NodeUsersMap = std::unordered_map>>; - auto mng = input->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(mng); - const NodeUsersMap &users = mng->node_users(); - auto input_users = users.find(input); - if (input_users == users.end() || input_users->second.empty()) { - MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" - << input->func_graph()->ToString() << "] has no users."; - } - - for (auto const &input_user : input_users->second) { - for (auto const &anf_node : node_list) { - if (anf_node != input_user.first) { - continue; - } - - std::vector dyn_input_sizes; - auto prim = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(prim); - if (prim->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(prim->GetAttr(kAttrDynInputSizes)); - } - - if (dyn_input_sizes.empty()) { - input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0))); - found = true; - break; - } else { - int used_as_idx = input_user.second - 1; - int accum_idx = 0; - size_t dyn_i = 0; - for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) { - accum_idx += dyn_input_sizes[dyn_i]; - if (used_as_idx < accum_idx) { - input_index.push_back(std::make_pair( - anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i]))))); - break; - } - } - if (dyn_i != dyn_input_sizes.size()) { - found = true; - break; - } - } - } - if (found) { - break; - } - } - - if (!found) { - MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of [" - << input->func_graph()->ToString() << "] found no related kernel info."; - } - } - return input_index; -} - -std::vector> GetOutputIndex(const std::vector &node_list, - const std::vector &input_list, - const std::vector &output_list) { - std::vector> output_index; - for (size_t i = 0; i < output_list.size(); ++i) { - auto const &output = output_list[i]; - MS_EXCEPTION_IF_NULL(output); - bool found = false; - auto pree_node = AnfAlgo::VisitKernel(output, 0); - auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first); - if (pos != std::end(node_list)) { - output_index.push_back(pree_node); - continue; - } - auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first); - if (ret != std::end(input_list)) { - output_index.push_back(std::make_pair(pree_node.first, 0)); - found = true; - } - if (!found) { - MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of [" - << output->func_graph()->ToString() << "] found no related kernel info."; - } - } - return output_index; -} - -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list) { - MS_EXCEPTION_IF_NULL(node_list); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector node_lists = TopoSort(func_graph->get_return()); - for (auto const &node : node_lists) { - if 
(!AnfAlgo::IsRealKernel(node) || !node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (IsValueNode(cnode->input(kAnfPrimitiveIndex))) { - node_list->push_back(node); - } - } -} - -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, - std::vector *input_list, std::vector *output_list) { - MS_EXCEPTION_IF_NULL(node_list); - MS_EXCEPTION_IF_NULL(input_list); - MS_EXCEPTION_IF_NULL(output_list); - MS_EXCEPTION_IF_NULL(func_graph); - - GetValidKernelNodes(func_graph, node_list); - - auto parameters = func_graph->parameters(); - input_list->insert(input_list->begin(), parameters.begin(), parameters.end()); - - auto func_output = func_graph->output(); - MS_EXCEPTION_IF_NULL(func_output); - if (func_output->isa()) { - // multi output. - auto cnode = func_output->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) { - auto input_node = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(input_node); - output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first); - } - } else { - // single output. - output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); - } - } else { - // single output. - output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first); - } -} - -bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(node_json); - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->size()) { - MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of [" - << cnode->inputs().size() << "][" << cnode->DebugString() << "]"; - } - - auto input_node = cnode->input(input_idx + 1); - if (!IsValueNode(input_node)) { - return false; - } - - auto tensor = GetValueNode(input_node); - if (tensor == nullptr) { - return false; - } - - auto type_id = tensor->data_type(); - auto *data = tensor->data_c(); - MS_EXCEPTION_IF_NULL(data); - if (tensor->DataDim() > 1 || tensor->DataSize() != 1) { - // not const tensor. 
- MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]"; - } - - if (type_id == kFloat32->type_id()) { - float *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = val[0]; - MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "]."; - return true; - } else if (type_id == kFloat16->type_id()) { - float16 *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = static_cast(val[0]); - MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "]."; - return true; - } else if (type_id == kInt32->type_id()) { - int *val = static_cast(data); - MS_EXCEPTION_IF_NULL(val); - (*node_json)["value"] = val[0]; - MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "]."; - return true; - } - MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]"; - return false; -} - -void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node_list); - auto output = func_graph->output(); - MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::IsRealKernel(output)) { - // single output. - node_list->push_back(std::make_pair(output, 0)); - return; - } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - // multi output. - auto &inputs = output_cnode->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0); - node_list->push_back(in_with_idx); - } - return; - } - MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2) - << " of graph: " << func_graph->ToString(); -} - -bool IsWeightBoundary(const AnfNodePtr &node) { - if (node->isa()) { - return true; - } - if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { - return true; - } - return false; -} - -void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, - size_t total_compute_size) { - const size_t kThreadNum = 24; - std::vector threads; - threads.reserve(kThreadNum); - size_t start = 0; - size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum; - while (start < total_compute_size) { - size_t end = (start + once_compute_size) > total_compute_size ? 
total_compute_size : (start + once_compute_size);
-    threads.emplace_back(std::thread(func, params, start, end));
-    start += once_compute_size;
-  }
-  for (size_t i = 0; i < threads.size(); ++i) {
-    threads[i].join();
-  }
-}
-
-std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
-  if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
-      AnfAlgo::GetInputTensorNum(cnode) != 1) {
-    MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString()
-                      << "] must have a single input and a single output.";
-  }
-  std::vector<int> axis;
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto axis_attr = primitive->GetAttr(kAxis);
-  if (axis_attr == nullptr) {
-    MS_LOG(ERROR) << "This node doesn't have the axis attr.";
-    return std::vector<int>();
-  }
-  auto type = axis_attr->type();
-  MS_EXCEPTION_IF_NULL(type);
-  std::vector<int> axis_list;
-  if (type->ToString() == kTypeInt32) {
-    axis_list.emplace_back(GetValue<int>(axis_attr));
-  } else {
-    axis_list = GetValue<std::vector<int>>(axis_attr);
-  }
-  for (const auto &elem : axis_list) {
-    if (elem < 0) {
-      axis.emplace_back(input_shape.size() + elem);
-    } else {
-      axis.emplace_back(elem);
-    }
-  }
-  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
-  return axis;
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/common_utils.h b/mindspore/ccsrc/kernel/common_utils.h
deleted file mode 100644
index b0ffb4ccb8..0000000000
--- a/mindspore/ccsrc/kernel/common_utils.h
+++ /dev/null
@@ -1,145 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
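GetReduceAttrAxis normalizes negative reduce axes against the input rank, so axis -1 on a rank-4 shape becomes 3. A standalone sketch of that normalization:

// Sketch of the negative-axis normalization in GetReduceAttrAxis.
#include <iostream>
#include <vector>

std::vector<int> NormalizeAxes(const std::vector<int> &axes, size_t rank) {
  std::vector<int> out;
  out.reserve(axes.size());
  for (int a : axes) {
    out.push_back(a < 0 ? a + static_cast<int>(rank) : a);
  }
  return out;
}

int main() {
  for (int a : NormalizeAxes({-1, 0, -2}, 4)) {
    std::cout << a << " ";  // 3 0 2
  }
  std::cout << std::endl;
  return 0;
}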
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "kernel/oplib/opinfo.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -constexpr auto kCceKernelMeta = "./kernel_meta/"; -constexpr auto kGpuKernelMeta = "./cuda_meta"; -constexpr auto kProcessorAiCore = "aicore"; -constexpr auto kProcessorAiCpu = "aicpu"; -constexpr auto kProcessorCuda = "cuda"; -constexpr auto kJsonSuffix = ".json"; -constexpr auto kInfoSuffix = ".info"; -constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600; -constexpr auto kAkgModule = "_akg"; -constexpr auto kArgDataformat = "data_format"; - -const std::vector support_devices = {"aicore", "aicpu", "cuda"}; - -struct KernelMetaInfo { - uintptr_t func_stub_; - uint32_t block_dim_; -}; -using KernelMetaPtr = std::shared_ptr; - -class KernelMeta { - public: - KernelMeta() = default; - void Initialize(); - void RemoveKernelCache(); - std::string Search(const std::string &kernel_name) const; - bool Insert(const std::string &kernel_name, const std::string &kernel_json); - std::string GetKernelMetaPath() { return kernel_meta_path_; } - - static KernelMeta *GetInstance() { - static KernelMeta kernel_meta; - return &kernel_meta; - } - ~KernelMeta() = default; - - private: - bool initialized_ = false; - std::string kernel_meta_path_; - std::unordered_map kernel_meta_map_; -}; - -struct SparseGradient { - float *value_; - int *indices_; - size_t indices_size_; -}; - -struct MultiThreadComputeParams { - float *var_; - float *accum_; - float *linear_; - float *m_; - float *m_t_; - float *v_; - float lr_; - float l1_; - float l2_; - float lr_power_; - float beta1_; - float beta2_; - float epsilon_; - SparseGradient sparse_grad_; - size_t var_first_dim_size_; - size_t var_outer_dim_size_; - bool use_nesterov_; -}; -using MultiThreadComputeFunc = std::function; - -bool CheckCache(const std::string &kernel_name); -KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor); -KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor); -TypeId DtypeToTypeId(const std::string &dtypes); -std::string Dtype2ShortType(const std::string &dtypes); -std::string TypeId2String(TypeId type_id); -size_t GetDtypeNbyte(const std::string &dtypes); -bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr &op_info_ptr, Processor processor, - std::vector> *const kernel_info_list); -void SaveJsonInfo(const std::string &json_name, const std::string &info); -std::string GetProcessor(const AnfNodePtr &anf_node); -bool IsSameShape(const std::vector &shape_a, const std::vector &shape_b); -int Sign(float x); -void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim, bool use_multi_threads = true); -std::pair GetKernelInput(const AnfNodePtr &anf_node, size_t index); -std::vector>> GetInputIndex(const std::vector &node_list, - const std::vector &input_list); -std::vector> GetOutputIndex(const std::vector &node_list, - const std::vector &input_list, - const std::vector &output_list); -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list, - std::vector *input_list, std::vector 
*output_list); -void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector *node_list); -bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json); -void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector> *node_list); -bool IsWeightBoundary(const AnfNodePtr &node); -void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params, - size_t total_compute_size); -void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, - size_t outer_dim, std::vector> *sorted_indices, - std::vector *slice_positions); -void ReduceMultiSparseGradient(const std::vector> &unique_slice_grads, - SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim, - size_t outer_dim); -void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad, - SparseGradient *unique_grad, size_t first_dim, size_t outer_dim); -std::vector GetReduceAttrAxis(const CNodePtr &cnode); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc deleted file mode 100644 index 021b49e20c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/cpu/addn_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_num_ = AnfAlgo::GetInputTensorNum(kernel_node); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool AddNCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - - size_t offset = 0; - for (size_t i = 0; i < output_shape_[0]; ++i) { - for (size_t j = 0; j < output_shape_[1]; ++j) { - for (size_t k = 0; k < output_shape_[2]; ++k) { - for (size_t m = 0; m < output_shape_[3]; ++m) { - float sum = 0; - for (size_t index = 0; index < input_num_; ++index) { - auto input_addr = reinterpret_cast(inputs[index]->addr); - sum += input_addr[offset]; - } - output_addr[offset++] = sum; - } - } - } - } - - return true; -} - -void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but AddNCPUKernel olny support 4d or lower."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h deleted file mode 100644 index 1a1a9157d9..0000000000 --- a/mindspore/ccsrc/kernel/cpu/addn_cpu_kernel.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class AddNCPUKernel : public CPUKernel { - public: - AddNCPUKernel() : input_num_(0) {} - ~AddNCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - size_t input_num_; - std::vector output_shape_; -}; - -MS_REG_CPU_KERNEL(AddN, - KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AddNCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ADDN_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc deleted file mode 100644 index 811ea3ea16..0000000000 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.cc +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
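MS_REG_CPU_KERNEL pairs an op name and a KernelAttr with the kernel class via the factory. A generic sketch of the static-registrar pattern such macros commonly expand to; this is an assumption for illustration, not MindSpore's actual expansion:

// Generic static-registration sketch: a file-scope object whose constructor
// runs at load time and records a factory for the op name.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Kernel {
  virtual ~Kernel() = default;
};

std::map<std::string, std::function<std::unique_ptr<Kernel>()>> &Registry() {
  static std::map<std::string, std::function<std::unique_ptr<Kernel>()>> registry;
  return registry;
}

struct Registrar {
  Registrar(const std::string &name, std::function<std::unique_ptr<Kernel>()> creator) {
    Registry()[name] = std::move(creator);
  }
};

struct AddNKernel : Kernel {};
static Registrar g_addn_registrar("AddN", [] { return std::make_unique<AddNKernel>(); });

int main() {
  auto kernel = Registry()["AddN"]();  // look up and construct by op name
  std::cout << "created: " << (kernel != nullptr) << std::endl;  // created: 1
  return 0;
}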
- */ -#include "kernel/cpu/allgather_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto kRanksGroup = "group"; -constexpr auto kAllGatherInputNum = 1; -} // namespace - -void AllGatherCPUKernel::InitKernel(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != kAllGatherInputNum) { - MS_LOG(EXCEPTION) << "allgather input num:" << input_num; - } - - auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); - if (ranks_group != nullptr) { - ranks_group_ = GetValue>(ranks_group); - } else { - MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; - } -} - -bool AllGatherCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto input_data_num = inputs[0]->size / sizeof(float); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - return mpi_instance->AllGather(input_addr, output_addr, ranks_group_, input_data_num); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h deleted file mode 100644 index 1dddf810ef..0000000000 --- a/mindspore/ccsrc/kernel/cpu/allgather_cpu_kernel.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class AllGatherCPUKernel : public CPUKernel { - public: - AllGatherCPUKernel() = default; - ~AllGatherCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector ranks_group_; -}; - -MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AllGatherCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc deleted file mode 100644 index 3cd6c57413..0000000000 --- a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.cc +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/apply_momentum_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ApplyMomentumCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} - -bool ApplyMomentumCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/) { - if (inputs.size() < 5) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - if (inputs[0]->size != inputs[1]->size || inputs[0]->size != inputs[3]->size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - auto weight = reinterpret_cast(inputs[0]->addr); - auto accumulate = reinterpret_cast(inputs[1]->addr); - float learning_rate = reinterpret_cast(inputs[2]->addr)[0]; - auto gradient = reinterpret_cast(inputs[3]->addr); - float moment = reinterpret_cast(inputs[4]->addr)[0]; - size_t elem_num = inputs[0]->size / sizeof(float); - for (size_t i = 0; i < elem_num; ++i) { - accumulate[i] = accumulate[i] * moment + gradient[i]; - weight[i] -= accumulate[i] * learning_rate; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h deleted file mode 100644 index c0ca581974..0000000000 --- a/mindspore/ccsrc/kernel/cpu/apply_momentum_cpu_kernel.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
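The ApplyMomentum loop above implements the classic momentum update: accum = accum * momentum + grad, then weight -= accum * lr. A standalone sketch with toy buffers in place of the kernel's AddressPtr inputs:

// Standalone sketch of the ApplyMomentum element-wise update.
#include <cstddef>
#include <iostream>

void ApplyMomentum(float *weight, float *accum, const float *grad, float lr, float momentum, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    accum[i] = accum[i] * momentum + grad[i];
    weight[i] -= accum[i] * lr;
  }
}

int main() {
  float w[2] = {1.0f, 1.0f}, acc[2] = {0.0f, 0.0f}, g[2] = {0.5f, -0.5f};
  ApplyMomentum(w, acc, g, /*lr=*/0.1f, /*momentum=*/0.9f, 2);
  std::cout << w[0] << " " << w[1] << std::endl;  // 0.95 1.05
  return 0;
}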
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ApplyMomentumCPUKernel : public MKLCPUKernel { - public: - ApplyMomentumCPUKernel() = default; - ~ApplyMomentumCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - ApplyMomentumCPUKernel); -MS_REG_CPU_KERNEL(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - ApplyMomentumCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc deleted file mode 100644 index ee328df721..0000000000 --- a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.cc +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/argmax_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void ArgmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (shape.size() != 2) { - MS_LOG(EXCEPTION) << "argmax kernel dims invalid " << shape.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis != -1 && axis != 1) { - MS_LOG(EXCEPTION) << "argmax kernel not support axis " << axis; - } -} - -bool ArgmaxCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspaces*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "invalid input or output data size!"; - } - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - size_t row_start = 0; - for (size_t i = 0; i < batch_size_; ++i) { - size_t max_index = 0; - float max_value = input[row_start]; - for (size_t j = 1; j < class_num_; ++j) { - size_t index = row_start + j; - if (input[index] > max_value) { - max_value = input[index]; - max_index = j; - } - } - output[i] = SizeToInt(max_index); - row_start += class_num_; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h deleted file mode 100644 index aae7435c5c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/argmax_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ArgmaxCPUKernel : public CPUKernel { - public: - ArgmaxCPUKernel() = default; - ~ArgmaxCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t class_num_{0}; - size_t batch_size_{0}; -}; - -MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - ArgmaxCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ARGMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc deleted file mode 100644 index 00f3017231..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/bias_add_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - if (input_shape_.size() == 4) { - data_shape_ = 4; - } else if (input_shape_.size() == 2) { - data_shape_ = 2; - } else { - MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC"; - } - if (input_shape_.size() != 2 && input_shape_.size() != 4) { - MS_LOG(EXCEPTION) << "bias add input shape nchw or nc"; - } - if (bias_shape_.size() != 1) { - MS_LOG(EXCEPTION) << "bias shape invalid"; - } - if (input_shape_[1] != bias_shape_[0]) { - MS_LOG(EXCEPTION) << "bias shape not match"; - } -} - -bool BiasAddCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() != 2 || outputs.size() != 1) { - MS_LOG(EXCEPTION) << "inputs outputs size not supoort"; - } - - auto src_addr = reinterpret_cast(inputs[0]->addr); - auto bias_addr = reinterpret_cast(inputs[1]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - if (data_shape_ == 4) { - size_t h_size = input_shape_[3]; - size_t c_size = input_shape_[2] * h_size; - size_t n_size = input_shape_[1] * c_size; - size_t hw_size = input_shape_[2] * input_shape_[3]; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - size_t c_offset = 0; - for (size_t c = 0; c < input_shape_[1]; ++c) { - for (size_t hw = 0; hw < hw_size; ++hw) { - size_t offset = n_offset + c_offset + hw; - output_addr[offset] = src_addr[offset] + bias_addr[c]; - } - c_offset += c_size; - } - n_offset += n_size; - } - } else { - size_t 
n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[n_offset + c] = src_addr[n_offset + c] + bias_addr[c]; - } - n_offset += input_shape_[1]; - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h deleted file mode 100644 index 516a21147b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_cpu_kernel.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class BiasAddCPUKernel : public CPUKernel { - public: - BiasAddCPUKernel() = default; - ~BiasAddCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - uint8_t data_shape_{0}; - std::vector input_shape_; - std::vector bias_shape_; -}; -MS_REG_CPU_KERNEL( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc deleted file mode 100644 index 1d9c7d076e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/cpu/bias_add_grad_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (input_shape_.size() != 4 && input_shape_.size() != 2) { - MS_LOG(EXCEPTION) << "input data format not support"; - } -} - -bool BiasAddGradCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() != 1 || outputs.size() != 1) { - MS_LOG(EXCEPTION) << "input output size not support"; - } - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto input_addr = reinterpret_cast(inputs[0]->addr); - - if (input_shape_.size() == 4) { - size_t h_size = input_shape_[3]; - size_t c_size = h_size * input_shape_[2]; - size_t n_size = c_size * input_shape_[1]; - size_t hw_size = input_shape_[2] * input_shape_[3]; - size_t c_offset = 0; - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[c] = 0; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - for (size_t hw = 0; hw < hw_size; ++hw) { - size_t offset = c_offset + n_offset + hw; - output_addr[c] += input_addr[offset]; - } - n_offset += n_size; - } - c_offset += c_size; - } - } else if (input_shape_.size() == 2) { - for (size_t c = 0; c < input_shape_[1]; ++c) { - output_addr[c] = 0; - size_t n_offset = 0; - for (size_t n = 0; n < input_shape_[0]; ++n) { - output_addr[c] += input_addr[c + n_offset]; - n_offset += input_shape_[1]; - } - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h deleted file mode 100644 index e3ac896096..0000000000 --- a/mindspore/ccsrc/kernel/cpu/bias_add_grad_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
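
BiasAddGrad, removed just above, is the matching reduction: the incoming gradient is summed over every axis except the channel axis. The same computation without the framework plumbing (a sketch; names are illustrative):

#include <cstddef>
#include <vector>

// grad_bias[c] = sum over n, h, w of dout[n][c][h][w], for shape = {N, C, H, W}.
void BiasAddGradNCHW(const std::vector<size_t> &shape, const float *dout, float *grad_bias) {
  size_t hw = shape[2] * shape[3];
  size_t chw = shape[1] * hw;
  for (size_t c = 0; c < shape[1]; ++c) {
    float acc = 0.0f;
    for (size_t n = 0; n < shape[0]; ++n) {
      const float *plane = dout + n * chw + c * hw;
      for (size_t i = 0; i < hw; ++i) {
        acc += plane[i];
      }
    }
    grad_bias[c] = acc;
  }
}
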
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class BiasAddGradCPUKernel : public CPUKernel { - public: - BiasAddGradCPUKernel() = default; - ~BiasAddGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector input_shape_; -}; -MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGradCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc deleted file mode 100644 index dac382f447..0000000000 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/cpu/concat_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - - axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_1_shape.size()); - } - axis_ += 4 - input_1_shape.size(); - - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; i++) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - CPUKernelUtils::ExpandDimsTo4(&input_shape); - input_shape_list_.push_back(input_shape); - } - - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool ConcatCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto buff_size = outputs[0]->size; - size_t dim0 = output_shape_[0]; - size_t dim1 = output_shape_[1]; - size_t dim2 = output_shape_[2]; - - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); - } - } else if (axis_ == 0) { - CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); - } - return true; -} - -void ConcatCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr, size_t *buff_size) { - for (size_t i = 0; i < input_shape_list_.size(); ++i) { - auto input_i_shape = input_shape_list_[i]; - auto input_i_addr = reinterpret_cast(inputs[i]->addr); - - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_); - num *= input_i_shape[axis_]; - auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0); - auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed."; - } - *output_addr += num; - *buff_size -= num * sizeof(float); - } -} - -void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel olny support 4d or lower."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h deleted file mode 100644 index 46f9078178..0000000000 --- a/mindspore/ccsrc/kernel/cpu/concat_cpu_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
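
A detail worth calling out in the Concat kernel above: shapes are padded to rank 4 by ExpandDimsTo4, so a (possibly negative) axis must be remapped twice, once against the original rank and once against the padding. A self-contained sketch of that convention (illustrative names; the real kernel splits this between InitKernel and ExpandDimsTo4):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Pad `shape` to rank 4 by prepending 1s and remap `axis` to the padded shape.
int NormalizeAxisTo4D(std::vector<size_t> *shape, int axis) {
  int rank = static_cast<int>(shape->size());
  if (rank > 4) throw std::invalid_argument("rank > 4 not supported");
  if (axis < 0) axis += rank;  // e.g. -1 means the last dimension
  if (axis < 0 || axis >= rank) throw std::out_of_range("axis out of range");
  axis += 4 - rank;  // account for the prepended 1s
  shape->insert(shape->begin(), 4 - rank, 1);
  return axis;
}
// Example: shape {8, 16} with axis -1 becomes shape {1, 1, 8, 16} with axis 3.
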
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ConcatCPUKernel : public CPUKernel { - public: - ConcatCPUKernel() : axis_(0) {} - ~ConcatCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr, size_t *buff_size); - int axis_; - std::vector> input_shape_list_; - std::vector output_shape_; -}; - -MS_REG_CPU_KERNEL(Concat, - KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConcatCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONCAT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc deleted file mode 100644 index 2be05038d6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/cpu_kernel.h" - -namespace mindspore { -namespace kernel { -void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t type_size = sizeof(float); - for (size_t input_index = 0; input_index < input_num; ++input_index) { - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); - size_t tensor_size = - shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - input_size_list_.emplace_back(tensor_size); - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t output_index = 0; output_index < output_num; ++output_index) { - std::vector shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); - size_t tensor_size = - shape.empty() ? 
type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - output_size_list_.emplace_back(tensor_size); - } -} - -void CPUKernel::Init(const CNodePtr &kernel_node) { - InitKernel(kernel_node); - InitInputOutputSize(kernel_node); -} - -void CPUKernelUtils::ExpandDimsTo4(std::vector *shape) { - auto len = shape->size(); - if (len < 4) { - for (size_t i = 0; i < 4 - len; ++i) { - shape->insert(shape->begin(), 1); - } - } -} - -size_t CPUKernelUtils::CalcOffset(const std::vector &shape, size_t dim0, size_t dim1, size_t dim2, - size_t dim3) { - size_t offset = dim0 * shape[1] * shape[2] * shape[3] + dim1 * shape[2] * shape[3] + dim2 * shape[3] + dim3; - return offset; -} - -size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector &shape, int axis) { - if (axis < 0) { - axis = axis + SizeToInt(shape.size()); - } - size_t result = 1; - for (int j = 3; j > axis; --j) { - result *= shape[j]; - } - return result; -} - -void CPUKernelUtils::GetElementNumEveryDim(const std::vector &shape, std::vector *element_num) { - size_t accumulation = 1; - element_num->emplace_back(1); - for (size_t i = shape.size() - 1; i > 0; --i) { - accumulation *= shape[i]; - element_num->emplace_back(accumulation); - } - std::reverse(element_num->begin(), element_num->end()); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel.h deleted file mode 100644 index 5837f922b5..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel.h +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
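
The CPUKernelUtils helpers defined above are the backbone of every index computation in these kernels: CalcOffset is the row-major flattening of a rank-4 index, and GetElementNumOnAxis counts the contiguous elements to the right of an axis. Restated standalone, with a small sanity check (hypothetical free functions mirroring the originals):

#include <cassert>
#include <cstddef>
#include <vector>

// Row-major flat offset into a rank-4 shape {d0, d1, d2, d3}.
size_t CalcOffset(const std::vector<size_t> &s, size_t i0, size_t i1, size_t i2, size_t i3) {
  return ((i0 * s[1] + i1) * s[2] + i2) * s[3] + i3;
}

// Number of contiguous elements covered by one step on axes to the right of `axis`.
size_t ElementNumOnAxis(const std::vector<size_t> &s, int axis) {
  size_t n = 1;
  for (int j = 3; j > axis; --j) n *= s[j];
  return n;
}

int main() {
  std::vector<size_t> shape{2, 3, 4, 5};
  assert(CalcOffset(shape, 1, 2, 3, 4) == 1 * 60 + 2 * 20 + 3 * 5 + 4);
  assert(ElementNumOnAxis(shape, 1) == 4 * 5);  // one step on axis 1 spans a 4x5 plane
  return 0;
}
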
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
-
-#include <functional>
-#include <memory>
-#include <numeric>
-#include <string>
-#include <vector>
-#include "kernel/kernel.h"
-#include "ir/anf.h"
-#include "session/anf_runtime_algorithm.h"
-
-using mindspore::kernel::Address;
-using mindspore::kernel::AddressPtr;
-namespace mindspore {
-namespace kernel {
-const char KSIZE[] = "ksize";
-const char STRIDE[] = "stride";
-const char STRIDES[] = "strides";
-const char DILATION[] = "dilation";
-const char PAD[] = "pad";
-const char PAD_MODE[] = "pad_mode";
-const char PADDING[] = "padding";
-const char PAD_MODE_LOWER_SAME[] = "same";
-const char PAD_MODE_LOWER_VALID[] = "valid";
-const char PAD_MODE_UPPER_SAME[] = "SAME";
-const char PAD_MODE_UPPER_VALID[] = "VALID";
-const char TRANSPOSE_A[] = "transpose_a";
-const char TRANSPOSE_B[] = "transpose_b";
-const char IS_GRAD[] = "is_grad";
-const char TRANSPOSE_NO = 'N';
-const char TRANSPOSE_YES = 'T';
-const char AXIS[] = "axis";
-const char BEGIN[] = "begin";
-const char END[] = "end";
-const char SIZE[] = "size";
-const char USE_NESTEROV[] = "use_nesterov";
-
-class CPUKernel : public kernel::KernelMod {
- public:
-  CPUKernel() = default;
-  ~CPUKernel() override = default;
-  virtual void Init(const CNodePtr &kernel_node);
-  virtual void InitKernel(const CNodePtr &kernel_node) = 0;
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {
-    return Launch(inputs, workspace, outputs);
-  };
-  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-                      const std::vector<AddressPtr> &outputs) = 0;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
- protected:
-  virtual void InitInputOutputSize(const CNodePtr &kernel_node);
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-
-class CPUKernelUtils {
- public:
-  static void ExpandDimsTo4(std::vector<size_t> *shape);
-  static size_t CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3);
-  static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
-  static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc
deleted file mode 100644
index bcda7af9fd..0000000000
--- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc
+++ /dev/null
@@ -1,104 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "kernel/cpu/cpu_kernel_factory.h" - -#include -#include -#include - -#include "device/kernel_info.h" - -namespace mindspore { -namespace kernel { -CPUKernelFactory &CPUKernelFactory::GetInstance() { - static CPUKernelFactory instance; - return instance; -} - -void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, - CPUKernelCreator &&kernel_creator) { - (void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator); -#if !defined(_WIN32) && !defined(_WIN64) - MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; -#endif -} - -std::shared_ptr CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(kernel_build_Info); - std::pair ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info); - if (ret_pair.first) { - return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second(); - } - return nullptr; -} - -std::pair CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name, - const KernelBuildInfo &kernel_info) { - auto iter = name_to_attr_creator_.find(kernel_name); - if (iter == name_to_attr_creator_.end()) { - MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!"; - return std::make_pair(false, 0); - } - auto creators = iter->second; - for (size_t index = 0; index < creators.size(); ++index) { - auto attr_creator = creators[index]; - if (CPUKernelSingleAttrCheck(attr_creator.first, kernel_info)) { - return std::make_pair(true, index); - } - } - return std::make_pair(false, 0); -} - -bool CPUKernelFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) { - for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) { - auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first; - if (kernel_info.GetInputDeviceType(i) != dtype) { - MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i) - << ", register type:" << dtype; - return false; - } - } - for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) { - auto dtype = kernel_attr.GetAllSame() ? 
kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first; - if (kernel_info.GetOutputDeviceType(i) != dtype) { - MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i) - << ", register type:" << dtype; - return false; - } - } - return true; -} - -std::vector CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) { - std::vector result; - auto iter = name_to_attr_creator_.find(kernel_name); - if (iter == name_to_attr_creator_.end()) { - MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; - return result; - } - auto creators = iter->second; - for (size_t index = 0; index < creators.size(); ++index) { - auto attr_creator = creators[index]; - result.push_back(attr_creator.first); - } - return result; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h b/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h deleted file mode 100644 index aebcc15d6a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ - -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "kernel/cpu/cpu_kernel.h" -#include "device/cpu/kernel_select_cpu.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::cpu::KernelAttr; -using CPUKernelCreator = std::function()>; -class CPUKernelFactory { - public: - static CPUKernelFactory &GetInstance(); - void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); - std::shared_ptr Create(const std::string &kernel_name, const CNodePtr &apply_kernel); - std::vector GetSupportedKernelAttrList(const std::string &kernel_name); - - private: - CPUKernelFactory() = default; - ~CPUKernelFactory() = default; - DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) - std::pair CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); - bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info); - std::map>> name_to_attr_creator_; -}; - -class CPUKernelRegistrar { - public: - CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { - CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); - } - ~CPUKernelRegistrar() = default; -}; - -#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) -#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) -#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of CPUKernel"); \ - static const 
CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ - []() { return std::make_shared(); }); - -#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) -#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) -#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ - #OPNAME, ATTR, []() { return std::make_shared>(); }); - -#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ - static_assert(std::is_base_of>::value, " must be base of CPUKernel"); \ - static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ - #OPNAME, ATTR, []() { return std::make_shared>(); }); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CPU_KERNEL_FACTORY_H_ diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc deleted file mode 100644 index a1dcaca3f3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/debug_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -namespace kernel { -void DebugCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); } - -bool DebugCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 1 || outputs.empty()) { - MS_LOG(EXCEPTION) << " input or output empty!"; - } - auto val = reinterpret_cast(inputs[0]->addr); - MS_LOG(DEBUG) << " launch DebugCountCPUKernel val " << *val; - - auto output = reinterpret_cast(outputs[0]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - for (size_t i = 0; i < elem_num; i++) { - output[i] = val[i]; - } - -#ifdef ENABLE_DEBUGGER - // debugger will suspend execution is neccessary - Debugger::GetInstance()->PostDebugOp(); -#endif - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h deleted file mode 100644 index da9f3286b9..0000000000 --- a/mindspore/ccsrc/kernel/cpu/debug_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
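
The factory and MS_REG_CPU_KERNEL machinery above boil down to a classic static-registrar idiom: a file-scope const object whose constructor runs before main() and deposits a creator into a singleton map (the __COUNTER__ indirection only keeps the object names unique). A cut-down, compilable illustration of the pattern, with attribute matching omitted and all names invented for the sketch:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Kernel {
  virtual ~Kernel() = default;
  virtual void Run() = 0;
};

class Factory {
 public:
  static Factory &Instance() {
    static Factory f;
    return f;
  }
  void Register(const std::string &name, std::function<std::unique_ptr<Kernel>()> make) {
    makers_[name] = std::move(make);
  }
  std::unique_ptr<Kernel> Create(const std::string &name) {
    auto it = makers_.find(name);
    return it == makers_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<Kernel>()>> makers_;
};

struct Registrar {
  Registrar(const std::string &name, std::function<std::unique_ptr<Kernel>()> make) {
    Factory::Instance().Register(name, std::move(make));
  }
};

// One static object per registration; its constructor runs at program start-up.
#define REG_KERNEL(NAME, CLASS) \
  static const Registrar g_##CLASS##_reg(#NAME, [] { return std::make_unique<CLASS>(); });

struct ArgmaxKernel : Kernel {
  void Run() override { std::cout << "argmax\n"; }
};
REG_KERNEL(Argmax, ArgmaxKernel)

int main() {
  if (auto k = Factory::Instance().Create("Argmax")) k->Run();
  return 0;
}
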
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DebugCPUKernel : public CPUKernel { - public: - DebugCPUKernel() = default; - ~DebugCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), DebugCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_DEBUG_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc deleted file mode 100644 index c9e60f0f4c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include "kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" - -namespace mindspore { -namespace kernel { -void EmbeddingLookUpCommGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); - MS_LOG(INFO) << "split_num: " << split_num_; - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape[0] % split_num_ != 0) { - MS_LOG(EXCEPTION) << "Input shape[0] is " << input_shape[0] << ", but it must be multiple of split_num."; - } -} - -bool EmbeddingLookUpCommGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - size_t input_size = inputs[0]->size; - size_t output_size = outputs[0]->size; - MS_LOG(DEBUG) << "input addr: " << input_addr << "input size: " << input_size; - MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << output_size; - memset_s(output_addr, output_size, 0, output_size); - const std::vector &rank_group = {0, 1, 2, 3, 4, 5, 6, 7}; - size_t input_split_lens = input_size / split_num_ / sizeof(float_t); - size_t output_split_lens = output_size / split_num_ / sizeof(float_t); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - for (int i = 0; i < split_num_; i++) { - mpi_instance->AllGather(input_addr + i * input_split_lens, output_addr + i * output_split_lens, rank_group, - input_split_lens); - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "EmbeddingLookUpCommGradCPUKernel, used time: " << time << " us"; -#endif - return true; -} - -void EmbeddingLookUpCommGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCommGradCPUKernel needs 1."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h deleted file mode 100644 index 7222bd9be1..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_comm_grad_cpu_kernel.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
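
EmbeddingLookUpCommGradCPUKernel above issues one AllGather per split: the input and output buffers are cut into split_num equal segments and each pair of segments is handed to the collective. The offset arithmetic, isolated from MPI (a sketch; the callback stands in for MPIAdapter::AllGather, and split_num is assumed to divide both byte sizes, which InitKernel checks):

#include <cstddef>
#include <functional>

// Visit `split_num` equal segments of an input/output buffer pair.
// Sizes are in bytes; the buffers hold floats, as in the kernel.
void ForEachSplit(const float *in, float *out, size_t in_bytes, size_t out_bytes, int split_num,
                  const std::function<void(const float *, float *, size_t)> &collective) {
  size_t in_seg = in_bytes / split_num / sizeof(float);    // elements per input segment
  size_t out_seg = out_bytes / split_num / sizeof(float);  // elements per output segment
  for (int i = 0; i < split_num; ++i) {
    collective(in + i * in_seg, out + i * out_seg, in_seg);
  }
}
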
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EmbeddingLookUpCommGradCPUKernel : public CPUKernel { - public: - EmbeddingLookUpCommGradCPUKernel() : split_num_(1) {} - ~EmbeddingLookUpCommGradCPUKernel() override{}; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CheckParam(const CNodePtr &kernel_node); - int split_num_; -}; - -MS_REG_CPU_KERNEL(EmbeddingLookupCommGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EmbeddingLookUpCommGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_COMM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc deleted file mode 100644 index f2fd7fc650..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc +++ /dev/null @@ -1,212 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include "kernel/cpu/embedding_look_up_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_lens_ = 1; - for (auto shape : input_shape_) { - input_lens_ = input_lens_ * shape; - } - indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - indices_lens_ = 1; - for (auto shape : indices_shape_) { - indices_lens_ = indices_lens_ * shape; - } - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - axis_ = 4 - input_shape_.size(); - if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) { - reduce_scatter_flag_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrReduceScatterFlag); - } -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t gatherv2_out_lens = 1; - for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { - if (i == 0) { - for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { - gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; - } - } else { - gatherv2_out_lens = gatherv2_out_lens * input_shape_[i]; - } - } - gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); - gather_v2_out_ = malloc(gatherv2_out_lens_); - if (gather_v2_out_ == nullptr) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; - } - auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_); - if (ret != 0) { - MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; - } - split_num_ = AnfAlgo::GetNodeAttr(kernel_node, "split_num"); - } -#else - if (reduce_scatter_flag_) { - MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; - } -#endif - if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) { - offset_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrOffset); - } - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - float *gather_out_addr = reduce_scatter_flag_ ? 
reinterpret_cast(gather_v2_out_) : output_addr; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - LookUpTable(inputs, i, j, k, &gather_out_addr); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - LookUpTable(inputs, i, j, 0, &gather_out_addr); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - LookUpTable(inputs, i, 0, 0, &gather_out_addr); - } - } else if (axis_ == 0) { - LookUpTable(inputs, 0, 0, 0, &gather_out_addr); - } -#ifdef ENABLE_MPI - if (reduce_scatter_flag_) { - size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); - size_t reduce_scatter_out_lens = one_split_lens / 8; - const std::vector &group = {0, 1, 2, 3, 4, 5, 6, 7}; - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - for (int i = 0; i < split_num_; i++) { - mpi_instance->ReduceScatter(reinterpret_cast(gather_v2_out_) + i * one_split_lens, - output_addr + i * reduce_scatter_out_lens, group, one_split_lens / 8, "sum"); - } - } -#endif - return true; -} - -void LookUpTable_task(const float *input_addr, float *output_addr, const int *indices_addr, size_t indices_lens, - size_t num, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, - std::vector input_shape, size_t input_lens) { - size_t lens = num * sizeof(float); - for (size_t i = 0; i < indices_lens; ++i) { - int indices = indices_addr[i] - offset; - if (indices >= 0) { - size_t index = IntToSize(indices); - if (index < input_shape[axis]) { - size_t pos = 0; - if (axis == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index); - } else if (axis == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0); - } else if (axis == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0); - } else if (axis == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0); - } - if (pos + num <= input_lens) { - auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - } else { - auto ret = memset_s(output_addr, lens, 0, lens); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; - } - } - output_addr += num; - } -} - -void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - float *task_out_addr = *output_addr; - const size_t thread_num = 8; - std::thread threads[8]; - size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; - size_t i; - size_t task_offset = 0; - MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens; - for (i = 0; i < thread_num; i++) { - if (task_offset >= indices_lens_) { - break; - } - MS_LOG(DEBUG) << "task_offset: " << task_offset 
<< " task_proc_lenss:" << task_proc_lens; - threads[i] = - std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset, - task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_); - task_offset += task_proc_lens; - if (task_offset + task_proc_lens > indices_lens_) { - task_proc_lens = indices_lens_ - task_offset; - } - } - for (size_t j = 0; j < i; j++) { - threads[j].join(); - } - *output_addr += num * indices_lens_; -} - -void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() - << ", but EmbeddingLookUpCPUKernel olny support 4d or lower."; - } - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but EmbeddingLookUpCPUKernel needs 2."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h deleted file mode 100644 index d839571caa..0000000000 --- a/mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EmbeddingLookUpCPUKernel : public CPUKernel { - public: - EmbeddingLookUpCPUKernel() { - axis_ = 0; - offset_ = 0; - split_num_ = 0; - input_lens_ = 0; - indices_lens_ = 0; - gatherv2_out_lens_ = 0; - reduce_scatter_flag_ = false; - gather_v2_out_ = nullptr; - } - ~EmbeddingLookUpCPUKernel() override { - if (gather_v2_out_ != nullptr) { - free(gather_v2_out_); - gather_v2_out_ = nullptr; - } - } - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void LookUpTable(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr); - void CheckParam(const CNodePtr &kernel_node); - std::vector input_shape_; - std::vector indices_shape_; - std::vector output_shape_; - int axis_; - int offset_; - int split_num_; - size_t input_lens_; - size_t indices_lens_; - size_t gatherv2_out_lens_; - bool reduce_scatter_flag_; - - void *gather_v2_out_; -}; - -MS_REG_CPU_KERNEL( - EmbeddingLookup, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - EmbeddingLookUpCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc deleted file mode 100644 index 60e7eafa78..0000000000 --- a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
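
The eight-way std::thread fan-out in EmbeddingLookUpCPUKernel::LookUpTable amounts to chunking the index array with ceiling division and joining every worker before returning. The same scheme in isolation (a sketch assuming threads >= 1; the real kernel fixes the count at 8 and feeds each chunk to LookUpTable_task):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Split `total` items into at most `threads` contiguous chunks and run
// worker(begin, len) on each chunk in its own thread.
void ParallelChunks(size_t total, size_t threads, const std::function<void(size_t, size_t)> &worker) {
  size_t chunk = (total + threads - 1) / threads;  // ceiling division
  std::vector<std::thread> pool;
  for (size_t begin = 0; begin < total; begin += chunk) {
    size_t len = std::min(chunk, total - begin);
    pool.emplace_back(worker, begin, len);
  }
  for (auto &t : pool) t.join();
}
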
- */ -#include "kernel/cpu/equal_count_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void EqualCountCPUKernel::InitKernel(const CNodePtr & /*kernel_node*/) {} - -bool EqualCountCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - if (inputs[0]->size != inputs[1]->size) { - MS_LOG(EXCEPTION) << "input or output size!"; - } - int count = 0; - auto left = reinterpret_cast(inputs[0]->addr); - auto right = reinterpret_cast(inputs[1]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - for (size_t i = 0; i < elem_num; i++) { - if (left[i] == right[i]) { - count++; - } - } - auto output = reinterpret_cast(outputs[0]->addr); - output[0] = count; - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h deleted file mode 100644 index 13083889d0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/equal_count_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class EqualCountCPUKernel : public CPUKernel { - public: - EqualCountCPUKernel() = default; - ~EqualCountCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - EqualCountCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EQUAL_COUNT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc deleted file mode 100644 index 8aad9d19e6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.cc +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
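
EqualCountCPUKernel above counts the positions at which two equally sized int buffers agree. The loop is equivalent to a single std::inner_product call, which makes the reduction structure explicit (an alternative formulation, not the code being moved):

#include <cstddef>
#include <functional>
#include <numeric>

// Number of positions where two int buffers of `n` elements hold equal values.
int EqualCount(const int *a, const int *b, size_t n) {
  return std::inner_product(a, a + n, b, 0, std::plus<int>(),
                            [](int x, int y) { return x == y ? 1 : 0; });
}
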
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/gather_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - axis_ = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis_ < 0) { - axis_ = axis_ + SizeToInt(input_shape_.size()); - } - axis_ += 4 - input_shape_.size(); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -bool GatherV2CPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto buff_size = outputs[0]->size; - size_t dim0 = input_shape_[0]; - size_t dim1 = input_shape_[1]; - size_t dim2 = input_shape_[2]; - if (axis_ == 3) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - for (size_t k = 0; k < dim2; ++k) { - CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size); - } - } - } - } else if (axis_ == 2) { - for (size_t i = 0; i < dim0; ++i) { - for (size_t j = 0; j < dim1; ++j) { - CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size); - } - } - } else if (axis_ == 1) { - for (size_t i = 0; i < dim0; ++i) { - CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size); - } - } else if (axis_ == 0) { - CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); - } - return true; -} - -void GatherV2CPUKernel::CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, - size_t dim2, float **output_addr, size_t *buff_size) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto indices_addr = reinterpret_cast(inputs[1]->addr); - size_t elem_num = inputs[1]->size / 4; - size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); - for (size_t i = 0; i < elem_num; ++i) { - if (indices_addr[i] < 0) { - MS_LOG(EXCEPTION) << "The indices value is less than 0."; - } - size_t index = IntToSize(indices_addr[i]); - if (index >= input_shape_[IntToSize(axis_)]) { - auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memset failed."; - } - } else { - size_t pos = 0; - if (axis_ == 3) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); - } else if (axis_ == 2) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); - } else if (axis_ == 1) { - pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); - } else if (axis_ == 0) { - pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); - } - auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed."; - } - } - *output_addr += num; - *buff_size -= num * sizeof(float); - } -} // namespace kernel - -void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - 
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h deleted file mode 100644 index 2ffd7df4d4..0000000000 --- a/mindspore/ccsrc/kernel/cpu/gather_cpu_kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class GatherV2CPUKernel : public CPUKernel { - public: - GatherV2CPUKernel() : axis_(0) {} - ~GatherV2CPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void CopyDataToOutput(const std::vector &inputs, size_t dim0, size_t dim1, size_t dim2, - float **output_addr, size_t *buff_size); - void CheckParam(const CNodePtr &kernel_node); - std::vector input_shape_; - std::vector indices_shape_; - std::vector output_shape_; - int axis_; -}; - -MS_REG_CPU_KERNEL( - GatherV2, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - GatherV2CPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_GATHER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc deleted file mode 100644 index 657c85dc48..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
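
Throughout these kernels, raw copies go through the bounds-checked securec routines (memcpy_s, memset_s) and any nonzero return is escalated to MS_LOG(EXCEPTION). For readers without the securec headers, a minimal stand-in that shows the contract the call sites rely on (illustrative only, not the securec implementation):

#include <cstddef>
#include <cstring>

// Checked copy in the spirit of memcpy_s: never overflow the destination.
// Returns 0 (EOK) on success; callers here treat any nonzero value as fatal.
int CheckedCopy(void *dst, size_t dst_cap, const void *src, size_t n) {
  if (dst == nullptr || src == nullptr || n > dst_cap) {
    return -1;
  }
  std::memcpy(dst, src, n);
  return 0;
}
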
- */ -#include "kernel/cpu/mkldnn/conv2d_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << "conv2d only support nchw input!"; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 4 || stride_ori[2] != stride_ori[3]) { - MS_LOG(EXCEPTION) << "conv2d only support equal stride, and stride must be 4d!"; - } - if (stride_ori[0] != 1 || stride_ori[1] != 1) { - MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "conv2d dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[2]; - int dilation = dilation_ori[2]; - - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - std::vector int_padding_l; - std::vector int_padding_r; - - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_WEIGHTS, weights_desc); - AddArgument(DNNL_ARG_DST, dst_desc); -} - -bool Conv2dCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h deleted file mode 100644 index 1cb100299e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * 
Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dCPUKernel : public MKLCPUKernel { - public: - Conv2dCPUKernel() = default; - ~Conv2dCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2D, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc deleted file mode 100644 index fbfebaf56e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
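Each deleted header also carries its registration: MS_REG_CPU_KERNEL binds an operator name plus a KernelAttr dtype signature to the implementing class. The entries in these headers all follow the shape below; "MyOp" and "MyOpCPUKernel" are placeholders used only to illustrate the pattern, not MindSpore operators:

// Hypothetical registration: two float32 inputs, one float32 output.
MS_REG_CPU_KERNEL(
  MyOp,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  MyOpCPUKernel);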
- */ -#include "kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector weight_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << ("conv2d grad filter only support nchw input!"); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel only support equal stride, and stride must be 2d!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradFilterCPUKernel dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[0]; - int dilation = dilation_ori[2]; - - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - std::vector int_padding_l; - std::vector int_padding_r; - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc forward_desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::convolution_backward_weights::desc backward_desc = dnnl::convolution_backward_weights::desc( - dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto backward_prim_desc = dnnl::convolution_backward_weights::primitive_desc( - backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, dst_desc); - AddArgument(DNNL_ARG_DIFF_WEIGHTS, weights_desc); -} - -bool Conv2dGradFilterCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS, 
outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h deleted file mode 100644 index 49559f452b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_filter_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dGradFilterCPUKernel : public MKLCPUKernel { - public: - Conv2dGradFilterCPUKernel() = default; - ~Conv2dGradFilterCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dGradFilterCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_FILTER_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc deleted file mode 100644 index ff0b8633d4..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
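The filter-gradient kernel above and the input-gradient kernel below share one oneDNN idiom: a forward primitive_desc is built first and passed as a "hint" when constructing the backward primitive_desc, so the backward primitive matches the implementation chosen for the forward pass. A condensed sketch, assuming oneDNN 1.x; the descriptors and dims are taken as parameters here, where the kernels derive them from node shapes and attrs:

#include <dnnl.hpp>

static void BuildConvBackwardPds(const dnnl::engine &eng, const dnnl::memory::desc &src_md,
                                 const dnnl::memory::desc &wei_md, const dnnl::memory::desc &dst_md,
                                 const dnnl::memory::dims &strides, const dnnl::memory::dims &dilates,
                                 const dnnl::memory::dims &pad_l, const dnnl::memory::dims &pad_r) {
  // Forward descriptor, used only as the hint for the backward ones.
  dnnl::convolution_forward::desc fwd_desc(dnnl::prop_kind::forward_training,
                                           dnnl::algorithm::convolution_auto, src_md, wei_md, dst_md,
                                           strides, dilates, pad_l, pad_r);
  auto fwd_pd = dnnl::convolution_forward::primitive_desc(fwd_desc, eng);

  // Gradient w.r.t. the filter (Conv2DBackpropFilter above).
  dnnl::convolution_backward_weights::desc bww_desc(dnnl::algorithm::convolution_auto, src_md, wei_md,
                                                    dst_md, strides, dilates, pad_l, pad_r);
  auto bww_pd = dnnl::convolution_backward_weights::primitive_desc(bww_desc, eng, fwd_pd);

  // Gradient w.r.t. the input (Conv2DBackpropInput below).
  dnnl::convolution_backward_data::desc bwd_desc(dnnl::algorithm::convolution_auto, src_md, wei_md,
                                                 dst_md, strides, dilates, pad_l, pad_r);
  auto bwd_pd = dnnl::convolution_backward_data::primitive_desc(bwd_desc, eng, fwd_pd);
  (void)bww_pd;
  (void)bwd_pd;
}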
- */ -#include "kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h" -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 || weight_shape.size() != 4) { - MS_LOG(EXCEPTION) << "conv2d grad filter only support nchw input!"; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - - int kernel_size = SizeToInt(weight_shape[3]); - auto stride_ori = AnfAlgo::GetNodeAttr>(kernel_node, STRIDE); - auto dilation_ori = AnfAlgo::GetNodeAttr>(kernel_node, DILATION); - if (stride_ori.size() != 2 || stride_ori[0] != stride_ori[1]) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel only support equal stride, and stride must be 2d!"; - } - if (dilation_ori.size() != 4 || dilation_ori[2] != 1 || dilation_ori[3] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1, and dilation must be 4d!"; - } - if (dilation_ori[0] != 1 || dilation_ori[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2dGradInputCPUKernel dilation only support 1 in N axis and C axis!"; - } - int stride = stride_ori[0]; - int dilation = dilation_ori[2]; - dnnl::memory::dims strides{stride, stride}; - dnnl::memory::dims dilates{dilation - 1, dilation - 1}; - std::vector int_padding_l; - std::vector int_padding_r; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PAD_MODE); - GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "conv2d grad get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::convolution_forward::desc forward_desc = - dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc, - weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto forward_prim_desc = dnnl::convolution_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::convolution_backward_data::desc backward_desc = dnnl::convolution_backward_data::desc( - dnnl::algorithm::convolution_auto, src_desc, weights_desc, dst_desc, strides, dilates, padding_l, padding_r); - - auto backward_prim_desc = - dnnl::convolution_backward_data::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_DIFF_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, dst_desc); - AddArgument(DNNL_ARG_WEIGHTS, weights_desc); -} - -bool Conv2dGradInputCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC, 
outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h deleted file mode 100644 index 9fb024a279..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/conv2d_grad_input_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class Conv2dGradInputCPUKernel : public MKLCPUKernel { - public: - Conv2dGradInputCPUKernel() = default; - ~Conv2dGradInputCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - Conv2dGradInputCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_CONV2D_GRAD_INPUT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc deleted file mode 100644 index 0a343785f7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
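The LSTM kernels that follow size a single flat weight buffer covering input-projection weights, recurrent weights, and bias. Per layer, the input projection takes 4*H * (I for layer 0, D*H afterwards) floats and the recurrent part takes 4*H * H; both totals are then multiplied by the direction count D. A plain-arithmetic sketch of that layout math, with illustrative sizes:

#include <cstdio>

int main() {
  const int I = 16, H = 32, L = 2, D = 2;  // input, hidden, layers, directions
  const int gate = 4 * H;                  // i/f/g/o gates packed together
  int w = 0, wh = 0;
  for (int i = 0; i < L; ++i) {
    w += gate * (i == 0 ? I : H * D);      // layer input projection
    wh += gate * H;                        // recurrent projection
  }
  w *= D;
  wh *= D;
  std::printf("weight_size=%d weight_h_size=%d\n", w, wh);  // 20480 and 16384
  return 0;
}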
- */ -#include "kernel/cpu/mkldnn/lstm_cpu_kernel.h" -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { -#ifdef PLATFORM_86 - _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); -#endif - MS_EXCEPTION_IF_NULL(kernel_node); - using tag = dnnl::memory::format_tag; - using dim = dnnl::memory::dims; - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; - auto eng = MKLKernelEngine::Get().engine(); - dnnl::stream s(eng); - dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; - if (bidirectional_) { - direction = dnnl::rnn_direction::bidirectional_concat; - } - dim src_dims = {seq_len_, batch_size_, input_size_}; - dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; - dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; - dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); - dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); - dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); - dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); - dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); - dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - auto desc = std::make_shared(dnnl::prop_kind::forward_training, direction, src_desc, - src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), - formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, - dst_h_desc, dst_c_desc); - prim_desc_ = 
dnnl::lstm_forward::primitive_desc(*desc, eng); - primitive_ = std::make_shared(prim_desc_); - AddArgument(DNNL_ARG_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); - AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); - AddArgument(DNNL_ARG_BIAS, bias_desc); - AddArgument(DNNL_ARG_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); -} - -bool LstmCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - using dt = dnnl::memory::data_type; - using tag = dnnl::memory::format_tag; - auto eng = MKLKernelEngine::Get().engine(); - auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); - user_weights_memory.set_data_handle(inputs[3]->addr); - user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - Reorder(&user_weights_memory, &weights_memory); - Reorder(&user_weights_h_memory, &weights_h_memory); - auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); - if (has_bias_) { - bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - } else { - auto ret = - memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, prim_desc_.bias_desc().get_size()); - if (ret != 0) { - MS_LOG(EXCEPTION) << "bias memset error"; - } - } - // set handle - SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h deleted file mode 100644 index d42ff803f0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_cpu_kernel.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
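The Launch above cannot hand the user weights straight to the primitive: they arrive in the fixed ldgoi layout while the primitive_desc may prefer a blocked one, hence the Reorder calls. The reorder idiom itself, reduced to a 2x3 matrix flipped from row-major to column-major, assuming oneDNN 1.x:

#include <dnnl.hpp>
#include <vector>

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);
  // Same logical 2x3 tensor, two physical layouts: ab (row-major), ba (col-major).
  std::vector<float> src = {1, 2, 3, 4, 5, 6}, dst(6);
  dnnl::memory src_m({{2, 3}, dt::f32, tag::ab}, eng, src.data());
  dnnl::memory dst_m({{2, 3}, dt::f32, tag::ba}, eng, dst.data());
  dnnl::reorder(src_m, dst_m).execute(s, src_m, dst_m);
  s.wait();  // dst = {1, 4, 2, 5, 3, 6}
  return 0;
}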
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H_ -#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) -#define PLATFORM_86 -#endif -#ifdef PLATFORM_86 -#include -#endif -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" -namespace mindspore { -namespace kernel { -class LstmCPUKernel : public MKLCPUKernel { - public: - LstmCPUKernel() = default; - ~LstmCPUKernel() override = default; - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int weight_size_ = 0; - int weight_h_size_ = 0; - int input_size_; - int hidden_size_; - int num_layers_; - int batch_size_; - int seq_len_; - int num_directions_; - bool bidirectional_; - bool has_bias_; - dnnl::memory::dims weights_dims_; - dnnl::memory::dims weights_h_dims_; - dnnl::memory::dims bias_dims_; - dnnl::lstm_forward::primitive_desc prim_desc_; -}; - -MS_REG_CPU_KERNEL(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc deleted file mode 100644 index d7e7701d85..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
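The PLATFORM_86 block defined in the header above enables flush-to-zero and denormals-are-zero before running the LSTM, since denormal cell states can slow SSE arithmetic down badly. A standalone x86 sketch of the same two intrinsics:

#include <pmmintrin.h>  // _MM_SET_DENORMALS_ZERO_MODE
#include <xmmintrin.h>  // _MM_SET_FLUSH_ZERO_MODE

int main() {
  // Treat denormal inputs as zero (DAZ) and flush denormal results to zero (FTZ).
  _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
  _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
  volatile float tiny = 1e-40f;    // denormal in fp32
  volatile float r = tiny * 1.0f;  // with DAZ/FTZ this computes as 0.0f
  return r == 0.0f ? 0 : 1;
}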
- */ -#include "kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h" -#include -#include -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - using tag = dnnl::memory::format_tag; - using dim = dnnl::memory::dims; - auto eng = MKLKernelEngine::Get().engine(); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); - bidirectional_ = AnfAlgo::GetNodeAttr(kernel_node, "bidirectional"); - input_size_ = AnfAlgo::GetNodeAttr(kernel_node, "input_size"); - hidden_size_ = AnfAlgo::GetNodeAttr(kernel_node, "hidden_size"); - num_layers_ = AnfAlgo::GetNodeAttr(kernel_node, "num_layers"); - has_bias_ = AnfAlgo::GetNodeAttr(kernel_node, "has_bias"); - batch_size_ = SizeToInt(src_shape[1]); - seq_len_ = SizeToInt(src_shape[0]); - num_directions_ = 1; - if (bidirectional_) { - num_directions_ = 2; - } - if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { - MS_LOG(EXCEPTION) << "error iteration shape!"; - } - if (num_layers_ <= 0) { - MS_LOG(EXCEPTION) << "layers must be greater than zero!"; - } - if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { - MS_LOG(EXCEPTION) << "conv2d only support 3-D input!"; - } - const int gate_size = 4 * hidden_size_; - for (int i = 0; i < num_layers_; ++i) { - weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); - weight_h_size_ += gate_size * hidden_size_; - } - weight_size_ = weight_size_ * num_directions_; - weight_h_size_ = weight_h_size_ * num_directions_; - dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; - if (bidirectional_) { - direction = dnnl::rnn_direction::bidirectional_concat; - } - dim src_dims = {seq_len_, batch_size_, input_size_}; - dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; - weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; - bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; - dim dst_dims = {seq_len_, batch_size_, hidden_size_ * num_directions_}; - dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; - dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); - dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); - dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); - dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); - dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); - dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); - dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); - auto forward_desc = std::make_shared( - dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, - formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, - dst_c_desc); - auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng); - auto backward_desc = 
std::make_shared( - dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), - formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, - src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, - dst_h_desc, dst_c_desc); - prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); - primitive_ = std::make_shared(prim_backward_desc_); - - AddArgument(DNNL_ARG_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc()); - AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc()); - AddArgument(DNNL_ARG_BIAS, bias_desc); - AddArgument(DNNL_ARG_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); - AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc); - AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc); - AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc); - AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc()); - AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc()); - AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc); - AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc); - AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc); - AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc); -} - -bool LSTMGradCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace /*workspace*/, - const std::vector &outputs) { - using dt = dnnl::memory::data_type; - using tag = dnnl::memory::format_tag; - auto eng = MKLKernelEngine::Get().engine(); - // construct fw memory - auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); - auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); - auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); - user_weights_memory.set_data_handle(inputs[3]->addr); - user_weights_h_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_); - Reorder(&user_weights_memory, &weights_memory); - Reorder(&user_weights_h_memory, &weights_h_memory); - if (has_bias_) { - bias_memory.set_data_handle(reinterpret_cast(inputs[3]->addr) + weight_size_ + weight_h_size_); - } else { - if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, - prim_backward_desc_.bias_desc().get_size())) { - MS_LOG(EXCEPTION) << "bias memset error"; - } - } - // construct bw memory - auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); - auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); - auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); - auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); - auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); - user_diff_weights_memory.set_data_handle(outputs[3]->addr); - 
user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - if (memset_s(user_diff_weights_memory.get_data_handle(), user_diff_weights_memory.get_desc().get_size(), 0, - user_diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights grad memset error"; - } - if (memset_s(user_diff_weights_h_memory.get_data_handle(), user_diff_weights_h_memory.get_desc().get_size(), 0, - user_diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "user weights iter grad memset error"; - } - if (has_bias_) { - diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); - } - if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, - prim_backward_desc_.diff_bias_desc().get_size())) { - MS_LOG(EXCEPTION) << "bias grad memset error"; - } - if (memset_s(diff_weights_memory.get_data_handle(), diff_weights_memory.get_desc().get_size(), 0, - diff_weights_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights grad memset error"; - } - if (memset_s(diff_weights_h_memory.get_data_handle(), diff_weights_h_memory.get_desc().get_size(), 0, - diff_weights_h_memory.get_desc().get_size())) { - MS_LOG(EXCEPTION) << "weights iter grad memset error"; - } - SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); - SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr); - SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr); - SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[10]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle()); - SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr); - ExecutePrimitive(); - Reorder(&diff_weights_memory, &user_diff_weights_memory); - Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h deleted file mode 100644 index 1f3fb824c0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
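The grad kernel above zero-fills every gradient buffer with memset_s before oneDNN accumulates into it, raising on any nonzero return. A small checked helper in the same spirit; memset_s here is the bounded memset from the securec library these kernels link against, and the header name below is an assumption:

#include <cstddef>
#include "securec.h"  // assumed securec header providing memset_s

// Zero the whole buffer; memset_s returns 0 on success, so false means
// the fill was refused and the caller should raise.
static bool ZeroFill(void *buf, std::size_t size) {
  return memset_s(buf, size, 0, size) == 0;
}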
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class LSTMGradCPUKernel : public MKLCPUKernel { - public: - LSTMGradCPUKernel() = default; - ~LSTMGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int weight_size_ = 0; - int weight_h_size_ = 0; - int input_size_; - int hidden_size_; - int num_layers_; - int batch_size_; - int seq_len_; - int num_directions_; - bool bidirectional_; - bool has_bias_; - dnnl::memory::dims weights_dims_; - dnnl::memory::dims weights_h_dims_; - dnnl::memory::dims bias_dims_; - dnnl::lstm_backward::primitive_desc prim_backward_desc_; -}; - -MS_REG_CPU_KERNEL(LSTMGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LSTMGradCPUKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_LSTM_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc deleted file mode 100644 index 28266f2aa0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/mkldnn/matmul_cpu_kernel.h" -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "common/utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - - if (src_shape.size() != 2 || weight_shape.size() != 2 || dst_shape.size() != 2) { - MS_LOG(EXCEPTION) << "matmul invalid input size"; - } - bool trans_a = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_A); - bool trans_b = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_B); - if (trans_a) { - trans_a_ = TRANSPOSE_YES; - dim_m_ = static_cast(src_shape[1]); - dim_k_ = static_cast(src_shape[0]); - } else { - dim_m_ = static_cast(src_shape[0]); - dim_k_ = static_cast(src_shape[1]); - } - if (trans_b) { - trans_b_ = TRANSPOSE_YES; - } - dim_n_ = static_cast(dst_shape[1]); -} - -bool MatMulCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "matmul error input output size!"; - } - dnnl_dim_t lda = dim_m_; - if (trans_a_ == TRANSPOSE_NO) { - lda = dim_k_; - } - dnnl_dim_t ldb = dim_k_; - if (trans_b_ == TRANSPOSE_NO) { - ldb = dim_n_; - } - auto input_a = reinterpret_cast(inputs[0]->addr); - auto input_b = reinterpret_cast(inputs[1]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - (void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, 1.f, input_a, lda, input_b, ldb, 0.f, output, dim_n_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h deleted file mode 100644 index 10276d01fa..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/matmul_cpu_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class MatMulCPUKernel : public MKLCPUKernel { - public: - MatMulCPUKernel() = default; - ~MatMulCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - char trans_a_{TRANSPOSE_NO}; - char trans_b_{TRANSPOSE_NO}; - dnnl_dim_t dim_m_{0}; - dnnl_dim_t dim_n_{0}; - dnnl_dim_t dim_k_{0}; -}; - -MS_REG_CPU_KERNEL( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MATMUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc deleted file mode 100644 index a38470e3a3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.cc +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" -#include -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" - -namespace mindspore { -namespace kernel { -void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, - const std::vector &src_shape, int kernel_size, int stride, - std::vector *padding_l, std::vector *padding_r) { - MS_EXCEPTION_IF_NULL(kernel_node); - if (src_shape.size() < 2) { - MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!"; - } - std::vector weight_height; - weight_height.emplace_back(src_shape[src_shape.size() - 2]); - weight_height.emplace_back(src_shape[src_shape.size() - 1]); - int rad = kernel_size / 2; - int need_pad = kernel_size - 1; - MS_LOG(INFO) << "pad mode " << pad_mode; - if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) { - for (auto wh : weight_height) { - int re = (wh - 1) % stride; - int pad = std::max(rad - (re / 2), 0); - padding_r->emplace_back(pad); - pad = std::max(need_pad - pad - re, 0); - padding_l->emplace_back(pad); - } - } else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) { - MS_LOG(INFO) << "pad valid"; - padding_l->emplace_back(0); - padding_l->emplace_back(0); - padding_r->emplace_back(0); - padding_r->emplace_back(0); - } else { - std::vector pad = AnfAlgo::GetNodeAttr>(kernel_node, PAD); - if (pad.size() != 4) { - MS_LOG(EXCEPTION) << "wrong pad size in max pooling " << pad.size(); - } - padding_l->emplace_back(pad[0]); - padding_l->emplace_back(pad[1]); - padding_r->emplace_back(pad[2]); - padding_r->emplace_back(pad[3]); - } -} - -dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { - dnnl::memory::format_tag mem_tag; - auto dim_size = dims.size(); - if (dim_size == 4) { - mem_tag = dnnl::memory::format_tag::abcd; - } else if (dim_size == 3) { - mem_tag = dnnl::memory::format_tag::abc; - } else if (dim_size == 2) { - mem_tag = dnnl::memory::format_tag::ab; - } else if (dim_size == 1) { - mem_tag = dnnl::memory::format_tag::a; - } else { - MS_LOG(EXCEPTION) << "kernel dims invalid " << dim_size; - } - return mem_tag; -} - -dnnl::memory::desc MKLCPUKernel::GetDefaultMemDesc(const std::vector &shape) { - dnnl::memory::dims dims; - dims.insert(dims.end(), shape.begin(), shape.end()); - dnnl::memory::format_tag mem_tag = GetDefaultFormatTag(dims); - dnnl::memory::desc mem_desc(dims, dnnl::memory::data_type::f32, mem_tag); - return mem_desc; -} - -void MKLCPUKernel::AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc) { - arguments_[arg_key] = MKLKernelEngine::Get().CreateMemory(mem_desc, alloc); -} - -void MKLCPUKernel::SetArgumentHandle(int arg_key, void *ptr) { - auto arg_iter = arguments_.find(arg_key); - if (arg_iter != arguments_.end()) { - arg_iter->second.set_data_handle(ptr); - } -} - -void MKLCPUKernel::ExecutePrimitive() { MKLKernelEngine::Get().Execute(primitive_, arguments_); } - -void MKLCPUKernel::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - MKLKernelEngine::Get().Reorder(src_mem, dst_mem); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h deleted file mode 100644 index 10a860afff..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_cpu_kernel.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the 
"License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ - -#include -#include -#include -#include -#include "dnnl.hpp" -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MKLCPUKernel : public CPUKernel { - public: - MKLCPUKernel() = default; - ~MKLCPUKernel() override = default; - - protected: - void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector &src_shape, - int kernel_size, int stride, std::vector *padding_l, std::vector *padding_r); - void AddArgument(int arg_key, const dnnl::memory::desc &mem_desc, bool alloc = false); - void SetArgumentHandle(int arg_key, void *ptr); - dnnl::memory::format_tag GetDefaultFormatTag(const dnnl::memory::dims &dims) const; - dnnl::memory::desc GetDefaultMemDesc(const std::vector &shape); - void ExecutePrimitive(); - std::unordered_map arguments_; - std::shared_ptr primitive_{nullptr}; - inline dnnl::memory::desc formatted_md(const dnnl::memory::dims &dimensions, dnnl::memory::format_tag layout) { - return dnnl::memory::desc{{dimensions}, dnnl::memory::data_type::f32, layout}; - } - void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem); -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MKL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc deleted file mode 100644 index 5ae9791b12..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mkl_kernel_engine.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "utils/log_adapter.h" -#include "dnnl.hpp" - -namespace mindspore { -namespace kernel { -void MKLKernelEngine::Execute(const std::shared_ptr &primitive, - const std::unordered_map &arguments) { - MS_EXCEPTION_IF_NULL(primitive); - primitive->execute(stream_, arguments); - (void)stream_.wait(); -} - -dnnl::memory MKLKernelEngine::CreateMemory(const dnnl::memory::desc &mem_desc, bool alloc) { - if (alloc) { - return dnnl::memory(mem_desc, engine_); - } else { - return dnnl::memory(mem_desc, engine_, nullptr); - } -} -void MKLKernelEngine::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) { - dnnl::reorder(*src_mem, *dst_mem).execute(stream_, *src_mem, *dst_mem); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc deleted file mode 100644 index 4f77508004..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.cc +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/mul_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) { - MS_LOG(EXCEPTION) << "mul only support same dim input or tensor * scalar " << src0_shape.size() << " vs " - << src1_shape.size(); - } - if (src1_shape.size() < src0_shape.size()) { - for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { - src1_shape.emplace_back(1); - } - } - dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape); - dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape); - dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape); - dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_mul, src0_mem_desc, src1_mem_desc, dst_mem_desc); - auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC_0, src0_mem_desc); - AddArgument(DNNL_ARG_SRC_1, src1_mem_desc); - AddArgument(DNNL_ARG_DST, dst_mem_desc); -} - -bool MulCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "mul error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - 
ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h deleted file mode 100644 index 1131fd594c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/mul_cpu_kernel.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class MulCPUKernel : public MKLCPUKernel { - public: - MulCPUKernel() = default; - ~MulCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MulCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_MUL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc deleted file mode 100644 index 5225050dc1..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
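MulCPUKernel above handles tensor-times-scalar by right-padding the second operand's shape with 1s until the ranks match, then letting oneDNN's binary_mul broadcast. A minimal sketch of that broadcast, assuming oneDNN 1.x:

#include <dnnl.hpp>
#include <vector>

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream s(eng);
  // src0 is 2x3; src1 is a scalar padded to shape {1, 1} so it broadcasts.
  std::vector<float> a = {1, 2, 3, 4, 5, 6}, b = {10}, c(6);
  dnnl::memory::desc a_md({2, 3}, dt::f32, tag::ab);
  dnnl::memory::desc b_md({1, 1}, dt::f32, tag::ab);
  dnnl::binary::desc desc(dnnl::algorithm::binary_mul, a_md, b_md, a_md);
  auto pd = dnnl::binary::primitive_desc(desc, eng);
  dnnl::memory a_m(a_md, eng, a.data()), b_m(b_md, eng, b.data()), c_m(a_md, eng, c.data());
  dnnl::binary(pd).execute(s, {{DNNL_ARG_SRC_0, a_m}, {DNNL_ARG_SRC_1, b_m}, {DNNL_ARG_DST, c_m}});
  s.wait();  // c = {10, 20, 30, 40, 50, 60}
  return 0;
}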
- */ -#include "kernel/cpu/mkldnn/pooling_cpu_kernel.h" -#include -#include -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape); - std::vector kernel_sizes = AnfAlgo::GetNodeAttr>(kernel_node, KSIZE); - std::vector strides = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - if (kernel_sizes.size() != 4 || strides.size() != 4) { - MS_LOG(EXCEPTION) << "invalid kernel size " << kernel_sizes.size() << " or stride size " << strides.size(); - } - dnnl::memory::dims strides_dims{strides[2], strides[3]}; - dnnl::memory::dims kernels_dims{kernel_sizes[2], kernel_sizes[3]}; - const std::string pad_mode = AnfAlgo::GetNodeAttr(kernel_node, PADDING); - std::vector int_padding_l; - std::vector int_padding_r; - GetPadding(kernel_node, pad_mode, src_shape, kernel_sizes[3], strides[3], &int_padding_l, &int_padding_r); - if (int_padding_l.size() != 2 || int_padding_r.size() != 2) { - MS_LOG(EXCEPTION) << "pooling get padding failed"; - } - dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]}; - dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]}; - dnnl::pooling_forward::desc desc = - dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc, - strides_dims, kernels_dims, padding_l, padding_r); - auto prim_desc = dnnl::pooling_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, dst_desc); - AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc()); -} - -bool PoolingCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h deleted file mode 100644 index 4993d0834d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ - -#include <vector> -#include <memory> -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class PoolingCPUKernel : public MKLCPUKernel { - public: - PoolingCPUKernel() = default; - ~PoolingCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs) override; -}; - -MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc deleted file mode 100644 index c0459de790..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h" -#include <string> -#include <utility> -#include <algorithm> -#include "common/utils.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void PoolingGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - src_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dst_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); - std::vector<int> kernel_sizes = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, KSIZE); - std::vector<int> strides = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, STRIDES); - if (kernel_sizes.size() != 4 || strides.size() != 4 || src_shape_.size() != 4 || dst_shape_.size() != 4) { - MS_LOG(EXCEPTION) << "pooling grad invalid input size"; - } - std::vector<int> padding_r; - const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PADDING); - kernel_size_ = kernel_sizes[3]; - stride_ = strides[3]; - GetPadding(kernel_node, pad_mode, src_shape_, kernel_size_, stride_, &padding_l_, &padding_r); -} - -void PoolingGradCPUKernel::RowPoolingGrad(const float *input, float *output, float diff, - const std::vector<std::pair<size_t, size_t>> &box, - std::vector<std::pair<size_t, float>> *row_max_pair) { - float max_value = 0; - size_t max_index = box[1].second; - size_t src_width = src_shape_[3]; - size_t index_start; - size_t index; - for (size_t i = box[1].first; i < box[1].second; ++i) { - if ((*row_max_pair)[i].first == 0) { - index_start = box[0].first * src_width; - for (size_t j = box[0].first; j < box[0].second; ++j) { - index = index_start + i; - if (input[index] > (*row_max_pair)[i].second || j == box[0].first) { - (*row_max_pair)[i].second = input[index]; - (*row_max_pair)[i].first = index; - } - index_start += src_width; - } - } - if ((*row_max_pair)[i].second > max_value || max_index == box[1].second) { - max_value = (*row_max_pair)[i].second; - max_index = i; - } - } - - output[(*row_max_pair)[max_index].first] += diff; -} - -void PoolingGradCPUKernel::ChannelPoolingGrad(const float *input, const float *diff, float *output) { - int src_width = SizeToInt(src_shape_[3]); - int src_height = SizeToInt(src_shape_[2]); - std::vector<std::pair<size_t, float>> row_max_pair(src_shape_[3]); - std::vector<std::pair<size_t, size_t>> box(2); - int h_start = -padding_l_[0]; - size_t diff_index = 0; - for (size_t h = 0; h < dst_shape_[2]; ++h) { - box[0].first = IntToSize(std::max(h_start, 0)); - box[0].second = IntToSize(std::min(h_start + kernel_size_, src_height)); - for (size_t w = 0; w < src_shape_[3]; ++w) { - row_max_pair[w].first = 0; - row_max_pair[w].second = 0; - } - int w_start = -padding_l_[1]; - for (size_t w = 0; w < dst_shape_[3]; ++w) { - box[1].first = IntToSize(std::max(w_start, 0)); - box[1].second = IntToSize(std::min(w_start + kernel_size_, src_width)); - RowPoolingGrad(input, output, diff[diff_index], box, &row_max_pair); - diff_index += 1; - w_start += stride_; - } - h_start += stride_; - } -} - -bool PoolingGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, - const std::vector<kernel::AddressPtr> & /*workspace*/, - const std::vector<kernel::AddressPtr> &outputs) { - if (inputs.size() < 3 || outputs.empty()) { - MS_LOG(EXCEPTION) << "pooling grad error input output size!"; - } - - auto input = reinterpret_cast<float *>(inputs[0]->addr); - auto diff = reinterpret_cast<float *>(inputs[2]->addr); - auto output = reinterpret_cast<float *>(outputs[0]->addr); - auto ret = memset_s(output, outputs[0]->size, 0, outputs[0]->size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "pooling grad memset error"; - } - size_t src_wh = src_shape_[2] * src_shape_[3]; - size_t dst_wh = dst_shape_[2] * dst_shape_[3]; - for (size_t n = 0; n < src_shape_[0]; ++n) { - for (size_t c = 0; c < src_shape_[1]; ++c) { - ChannelPoolingGrad(input, diff, output); - input = input + src_wh; - output = output + src_wh; - diff = diff + dst_wh; - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h deleted file mode 100644 index cdb2c69ef0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/pooling_grad_cpu_kernel.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
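PoolingGradCPUKernel implements max-pool backward by hand rather than via dnnl::pooling_backward: for every output gradient it re-locates the argmax of its input window (RowPoolingGrad caches per-column maxima in row_max_pair) and accumulates the incoming gradient at that single position. The routing idea in isolation, as a hedged sketch for one channel with valid padding:

#include <cstddef>
#include <vector>

// Minimal max-pool backward: route each output gradient to the argmax of its
// input window. h, w = input height/width; k = window size; s = stride.
void MaxPoolGrad2D(const std::vector<float> &x, const std::vector<float> &dy,
                   std::vector<float> *dx, size_t h, size_t w, size_t k, size_t s) {
  dx->assign(x.size(), 0.0f);
  size_t oh = (h - k) / s + 1, ow = (w - k) / s + 1;
  for (size_t i = 0; i < oh; ++i) {
    for (size_t j = 0; j < ow; ++j) {
      size_t best = (i * s) * w + j * s;  // argmax within the k x k window
      for (size_t di = 0; di < k; ++di) {
        for (size_t dj = 0; dj < k; ++dj) {
          size_t idx = (i * s + di) * w + (j * s + dj);
          if (x[idx] > x[best]) best = idx;
        }
      }
      (*dx)[best] += dy[i * ow + j];  // only the max position receives gradient
    }
  }
}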
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class PoolingGradCPUKernel : public MKLCPUKernel { - public: - PoolingGradCPUKernel() = default; - ~PoolingGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void RowPoolingGrad(const float *input, float *output, float diff, const std::vector> &box, - std::vector> *row_max_pair); - void ChannelPoolingGrad(const float *input, const float *diff, float *output); - int stride_{0}, kernel_size_{0}; - std::vector padding_l_; - std::vector src_shape_; - std::vector dst_shape_; -}; - -MS_REG_CPU_KERNEL(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_POOLING_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc deleted file mode 100644 index d5ef20a25e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/mkldnn/relu_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ReluCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 && src_shape.size() != 2) { - MS_LOG(EXCEPTION) << "relu kernel dims invalid " << src_shape.size(); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - - dnnl::eltwise_forward::desc desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); - auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, src_desc); -} - -bool ReluCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h deleted file mode 100644 index 26905e267d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_cpu_kernel.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ReluCPUKernel : public MKLCPUKernel { - public: - ReluCPUKernel() = default; - ~ReluCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc deleted file mode 100644 index 4a6213ddf2..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/relu_grad_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void ReluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - if (src_shape.size() != 4 && src_shape.size() != 2) { - MS_LOG(EXCEPTION) << "relu grad kernel dims invalid " << src_shape.size(); - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - - dnnl::eltwise_forward::desc forward_desc = - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0); - auto forward_prim_desc = dnnl::eltwise_forward::primitive_desc(forward_desc, MKLKernelEngine::Get().engine()); - - dnnl::eltwise_backward::desc backward_desc = - dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu, src_desc, src_desc, 0.0, 0.0); - auto backward_prim_desc = - dnnl::eltwise_backward::primitive_desc(backward_desc, MKLKernelEngine::Get().engine(), forward_prim_desc); - primitive_ = std::make_shared(backward_prim_desc); - - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_SRC, src_desc); - AddArgument(DNNL_ARG_DIFF_DST, src_desc); -} - -bool ReluGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 2 || outputs.empty()) { - MS_LOG(EXCEPTION) << "relu grad error input output size!"; - } - if (inputs[0]->size != outputs[0]->size) { - MS_LOG(EXCEPTION) << "relu grad error input output data size!"; - } - - SetArgumentHandle(DNNL_ARG_SRC, inputs[1]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DIFF_DST, inputs[0]->addr); - ExecutePrimitive(); - size_t mem_bits = outputs[0]->size; - auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret; - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h deleted file mode 100644 index f0a77ee282..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/relu_grad_cpu_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
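ReluGradCPUKernel builds eltwise_backward from a forward primitive_desc hint, binds DIFF_SRC and DIFF_DST to the same input buffer (an in-place backward), then copies the result into the output. The same hint-based construction in a minimal sketch, using separate buffers for clarity:

#include <dnnl.hpp>
#include <vector>

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);
  dnnl::memory::desc md{{1, 4}, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc};
  auto fwd_pd = dnnl::eltwise_forward::primitive_desc(
      {dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, md, 0.0f}, eng);
  // The backward primitive_desc takes the forward one as a hint, as in InitKernel above.
  auto bwd_pd = dnnl::eltwise_backward::primitive_desc(
      {dnnl::algorithm::eltwise_relu, md, md, 0.0f, 0.0f}, eng, fwd_pd);
  std::vector<float> x = {-1, 2, -3, 4}, dy = {1, 1, 1, 1}, dx(4);
  dnnl::memory src(md, eng, x.data()), diff_dst(md, eng, dy.data()), diff_src(md, eng, dx.data());
  dnnl::eltwise_backward(bwd_pd).execute(
      strm, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DIFF_DST, diff_dst}, {DNNL_ARG_DIFF_SRC, diff_src}});
  strm.wait();  // dx = {0, 1, 0, 1}: the gradient passes only where x > 0
}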
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class ReluGradCPUKernel : public MKLCPUKernel { - public: - ReluGradCPUKernel() = default; - ~ReluGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReluGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RELU_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc deleted file mode 100644 index 7fa740cfc0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/mkldnn/softmax_cpu_kernel.h" -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - std::vector axis_list = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); - if (axis_list.size() != 1) { - MS_LOG(EXCEPTION) << "cpu softmax only support input axis size 1"; - } - int axis = axis_list[0]; - if (axis == -1 || axis >= SizeToInt(src_shape.size())) { - axis = SizeToInt(src_shape.size()) - 1; - } - dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape); - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - AddArgument(DNNL_ARG_SRC, src_desc); - AddArgument(DNNL_ARG_DST, src_desc); -} - -bool SoftmaxCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "softmax error input output size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); - ExecutePrimitive(); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h deleted file mode 100644 index 6acb9e5b9b..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cpu_kernel.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 
Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SoftmaxCPUKernel : public MKLCPUKernel { - public: - SoftmaxCPUKernel() = default; - ~SoftmaxCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc deleted file mode 100644 index 05b1a79924..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
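SoftmaxCPUKernel above forwards its single-element axis attribute straight into dnnl::softmax_forward::desc(..., axis) after normalizing it; only the normalization is MindSpore logic. That step in isolation:

#include <cstddef>

// Axis normalization as done in SoftmaxCPUKernel::InitKernel: -1 or any
// out-of-range axis is clamped to the last dimension.
int NormalizeSoftmaxAxis(int axis, size_t rank) {
  if (axis == -1 || axis >= static_cast<int>(rank)) return static_cast<int>(rank) - 1;
  return axis;
}
// Example: for a [batch, class] input, both axis = -1 and axis = 1 select the
// class dimension, which is what the DNNL primitive then softmaxes over.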
- */ -#include "kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h" -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - size_t type_size = sizeof(float); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - workspace_size_list_.emplace_back(tensor_size); -} - -void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dnnl::memory::dims mem_dims; - mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); - if (mem_dims.size() != 2) { - MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - if (batch_size_ == 0 || class_num_ == 0) { - MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; - } - dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); - - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, mem_desc); - AddArgument(DNNL_ARG_DST, mem_desc); -} - -void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels, - float *output1, float *output2) const { - float epsilon = 1e-6; - for (size_t i = 0; i < batch_size_; ++i) { - output1[i] = 0; - float loss = 0.0; - for (size_t j = 0; j < class_num_; ++j) { - float logit = logf(logits[i * class_num_ + j] <= 0.0 ? 
epsilon : logits[i * class_num_ + j]); - output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j]; - loss += labels[i * class_num_ + j] * logit; - } - output1[i] = -loss; - } -} - -bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (inputs.empty() || workspace.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || - inputs[1]->size != batch_class_float_size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "error output data size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); - ExecutePrimitive(); - auto labels = reinterpret_cast(inputs[1]->addr); - auto logits = reinterpret_cast(workspace[0]->addr); - auto output1 = reinterpret_cast(outputs[0]->addr); - auto output2 = reinterpret_cast(outputs[1]->addr); - ForwardPostExecute(logits, labels, output1, output2); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h deleted file mode 100644 index f663508059..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
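ForwardPostExecute above computes, per batch row i, loss_i = -sum_j y_ij * log(p_ij) with p clamped below by 1e-6, and the gradient p_ij - y_ij; note the local named logits actually holds the softmax probabilities that Launch staged in workspace[0]. The same post-processing as a standalone sketch:

#include <cmath>
#include <cstddef>
#include <vector>

// loss[i]    = -sum_j labels[i][j] * log(max(probs[i][j], eps))
// grad[i][j] =  probs[i][j] - labels[i][j]
void CrossEntropyPost(const std::vector<float> &probs, const std::vector<float> &labels,
                      size_t batch, size_t classes,
                      std::vector<float> *loss, std::vector<float> *grad) {
  const float eps = 1e-6f;
  loss->assign(batch, 0.0f);
  grad->assign(batch * classes, 0.0f);
  for (size_t i = 0; i < batch; ++i) {
    float acc = 0.0f;
    for (size_t j = 0; j < classes; ++j) {
      size_t k = i * classes + j;
      acc += labels[k] * logf(probs[k] <= 0.0f ? eps : probs[k]);
      (*grad)[k] = probs[k] - labels[k];
    }
    (*loss)[i] = -acc;
  }
}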
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { - public: - SoftmaxCrossEntropyWithLogitsCPUKernel() = default; - ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - void InitInputOutputSize(const CNodePtr &kernel_node) override; - - private: - void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; - size_t class_num_{0}; - size_t batch_size_{0}; -}; -MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SoftmaxCrossEntropyWithLogitsCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc deleted file mode 100644 index c33fcd246f..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h" -#include -#include -#include -#include "kernel/cpu/mkldnn/mkl_kernel_engine.h" -#include "device/cpu/cpu_device_address.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - size_t type_size = sizeof(float); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - workspace_size_list_.emplace_back(tensor_size); -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - dnnl::memory::dims mem_dims; - mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); - if (mem_dims.size() != 2) { - MS_LOG(EXCEPTION) << "SparseSoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); - } - batch_size_ = shape[0]; - class_num_ = shape[1]; - if (batch_size_ == 0 || class_num_ == 0) { - MS_LOG(EXCEPTION) << "invalid batch size or class num input!"; - } - is_grad_ = AnfAlgo::GetNodeAttr(kernel_node, IS_GRAD); - dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); - - dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); - auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); - primitive_ = std::make_shared(prim_desc); - - AddArgument(DNNL_ARG_SRC, mem_desc); - AddArgument(DNNL_ARG_DST, mem_desc); -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses, - float *output) const { - float total_loss = 0; - for (size_t i = 0; i < batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = IntToSize(labels[i]); - if (label > class_num_) { - MS_LOG(EXCEPTION) << "error label input!"; - } - total_loss -= logf(losses[i * class_num_ + label]); - } - output[0] = total_loss / batch_size_; -} - -void SparseSoftmaxCrossEntropyWithLogitsCPUKernel::GradPostExecute(const int *labels, const float *losses, - float *output) const { - size_t row_start = 0; - for (size_t i = 0; i < batch_size_; ++i) { - if (labels[i] < 0) { - MS_LOG(EXCEPTION) << "label value must >= 0"; - } - size_t label = IntToSize(labels[i]); - if (label > class_num_) { - MS_LOG(EXCEPTION) << "error label input!"; - } - for (size_t j = 0; j < class_num_; ++j) { - size_t index = row_start + j; - if (j == label) { - output[index] = (losses[index] - 1) / batch_size_; - } else { - output[index] = losses[index] / batch_size_; - } - } - row_start += class_num_; - } -} - -bool SparseSoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs) { - if (inputs.empty() || workspace.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - size_t batch_float_size = batch_size_ * sizeof(float); - size_t batch_class_float_size = class_num_ * batch_float_size; - if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || - inputs[1]->size != batch_float_size) { - MS_LOG(EXCEPTION) << "error input data size!"; - } - if 
(is_grad_ && outputs[0]->size != batch_class_float_size) { - MS_LOG(EXCEPTION) << "error output data size!"; - } else if (!is_grad_ && outputs[0]->size != sizeof(float)) { - MS_LOG(EXCEPTION) << "error output data size!"; - } - SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); - SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); - ExecutePrimitive(); - auto labels = reinterpret_cast(inputs[1]->addr); - auto losses = reinterpret_cast(workspace[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - if (is_grad_) { - GradPostExecute(labels, losses, output); - } else { - ForwardPostExecute(labels, losses, output); - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h deleted file mode 100644 index 6391b27de6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/mkldnn/sparse_softmax_cross_entropy_with_logits_cpu_kernel.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { - public: - SparseSoftmaxCrossEntropyWithLogitsCPUKernel() = default; - ~SparseSoftmaxCrossEntropyWithLogitsCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - void InitInputOutputSize(const CNodePtr &kernel_node) override; - - private: - void ForwardPostExecute(const int *labels, const float *losses, float *output) const; - void GradPostExecute(const int *labels, const float *losses, float *output) const; - bool is_grad_{false}; - size_t class_num_{0}; - size_t batch_size_{0}; -}; - -MS_REG_CPU_KERNEL( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc deleted file mode 100644 index 00dfe73f28..0000000000 --- a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
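The sparse variant above takes int32 class ids instead of one-hot rows: the forward pass returns the batch mean of -log p[i, label_i], and the backward pass writes (p - one_hot(label)) / batch_size, which is exactly what GradPostExecute does. A compact sketch of both:

#include <cmath>
#include <cstddef>
#include <vector>

// Sparse softmax cross entropy over softmax probabilities `probs`:
// returns the mean of -log(p[i][label_i]) and fills the gradient
// grad[i][j] = (p[i][j] - (j == label_i)) / batch.
float SparseCEForward(const std::vector<float> &probs, const std::vector<int> &labels,
                      size_t batch, size_t classes, std::vector<float> *grad) {
  float total = 0.0f;
  grad->assign(batch * classes, 0.0f);
  for (size_t i = 0; i < batch; ++i) {
    size_t label = static_cast<size_t>(labels[i]);
    total -= logf(probs[i * classes + label]);
    for (size_t j = 0; j < classes; ++j) {
      float one_hot = (j == label) ? 1.0f : 0.0f;
      (*grad)[i * classes + j] = (probs[i * classes + j] - one_hot) / batch;
    }
  }
  return total / batch;
}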
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/one_hot_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void OneHotCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - if (output_shape.size() < 2) { - MS_LOG(EXCEPTION) << "invalid output shape size: " << output_shape.size(); - } - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis != -1 && IntToSize(axis) >= output_shape.size()) { - MS_LOG(EXCEPTION) << "invalid axis: " << axis; - } - if (axis == -1) { - axis_ = output_shape.size() - 1; - } else { - axis_ = IntToSize(axis); - } - depth_ = output_shape[axis_]; - stride_ = 1; - for (size_t i = axis_ + 1; i < output_shape.size(); ++i) { - stride_ *= output_shape[i]; - } -} - -bool OneHotCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (inputs.size() < 3 || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output invalid!"; - } - auto indices = reinterpret_cast(inputs[0]->addr); - auto on_value = reinterpret_cast(inputs[1]->addr)[0]; - auto off_value = reinterpret_cast(inputs[2]->addr)[0]; - auto output = reinterpret_cast(outputs[0]->addr); - size_t elem_num = inputs[0]->size / sizeof(int); - - for (size_t i = 0; i < elem_num; i++) { - size_t stride_num = i / stride_; - size_t output_index = stride_num * depth_ * stride_ + i % stride_; - size_t index = IntToSize(indices[i]); - for (size_t j = 0; j < depth_; j++) { - if (index == j) { - output[output_index] = on_value; - } else { - output[output_index] = off_value; - } - output_index += stride_; - } - } - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h deleted file mode 100644 index ef13047343..0000000000 --- a/mindspore/ccsrc/kernel/cpu/one_hot_cpu_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
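OneHotCPUKernel above supports expansion along any axis by viewing the output as [outer, depth, stride], where stride is the product of the dimensions after the one-hot axis; flat input element i owns output positions (i / stride) * depth * stride + i % stride + j * stride for j in [0, depth). The indexing in isolation:

#include <cstddef>
#include <vector>

// One-hot expansion along an interior axis, mirroring OneHotCPUKernel::Launch:
// initialize everything to off_value, then set one on_value per input element.
void OneHot(const std::vector<int> &indices, size_t depth, size_t stride,
            float on_value, float off_value, std::vector<float> *output) {
  output->assign(indices.size() * depth, off_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    size_t outer = i / stride;
    size_t base = outer * depth * stride + i % stride;
    size_t idx = static_cast<size_t>(indices[i]);
    if (idx < depth) (*output)[base + idx * stride] = on_value;
  }
}
// With stride == 1 (axis = -1) this reduces to the familiar row-per-index one-hot.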
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class OneHotCPUKernel : public CPUKernel { - public: - OneHotCPUKernel() = default; - ~OneHotCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t depth_; - size_t stride_; - size_t axis_; -}; - -MS_REG_CPU_KERNEL(OneHot, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - OneHotCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_ONE_HOT_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.cc deleted file mode 100644 index ecbf407610..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/ps/apply_momentum_ps_kernel.h" - -namespace mindspore { -namespace kernel { -namespace ps { -bool ApplyMomentumPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - return Launch(inputs, workspace, outputs); -} - -const std::vector &ApplyMomentumPSKernel::input_sizes() const { return GetInputSizeList(); } - -const std::vector &ApplyMomentumPSKernel::output_sizes() const { return GetOutputSizeList(); } - -const std::vector &ApplyMomentumPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } -} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.h deleted file mode 100644 index 43992abc87..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/apply_momentum_ps_kernel.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ - -#include -#include -#include "kernel/cpu/ps/pserver_kernel.h" -#include "kernel/cpu/apply_momentum_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace ps { -class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel { - public: - ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} - ~ApplyMomentumPSKernel() override = default; - - bool Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - const std::vector &input_sizes() const override; - const std::vector &output_sizes() const override; - const std::vector &workspace_sizes() const override; -}; -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_APPLY_MOMENTUM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.cc deleted file mode 100644 index 01dad83f98..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
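The PS kernels in this directory follow one adapter pattern, visible in ApplyMomentumPSKernel above: inherit the existing CPU kernel for the math and PServerKernel for the server-side Execute/size-query contract, then forward Execute to Launch. The shape of that idiom with simplified stand-in types (nothing below is the real MindSpore API):

#include <cstddef>
#include <vector>

struct Address { void *addr; size_t size; };

struct CpuMomentum {  // stand-in for the existing device kernel (the math)
  bool Launch(const std::vector<Address> &in, const std::vector<Address> &ws,
              const std::vector<Address> &out) { return true; }
};

struct ServerInterface {  // stand-in for the PServerKernel contract
  virtual bool Execute(const std::vector<Address> &, const std::vector<Address> &,
                       const std::vector<Address> &) = 0;
  virtual ~ServerInterface() = default;
};

struct MomentumServerKernel : CpuMomentum, ServerInterface {
  bool Execute(const std::vector<Address> &in, const std::vector<Address> &ws,
               const std::vector<Address> &out) override {
    return Launch(in, ws, out);  // pure forwarding adapter, as in the deleted file
  }
};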
- */ -#include "kernel/cpu/ps/embedding_look_up_proxy_kernel.h" -#include -#include "parallel/ps/worker.h" - -namespace mindspore { -namespace kernel { -namespace ps { -void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { - EmbeddingLookUpCPUKernel::InitKernel(kernel_node); - - for (auto dim : input_shape_) { - input_dims_ *= dim; - } - - if (mindspore::parallel::ps::Util::IsRoleOfWorker()) { - key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); - } - std::vector keys{key_, key_, key_}; - std::vector values; - values.insert(values.end(), input_shape_.begin(), input_shape_.end()); - values.insert(values.end(), indices_shape_.begin(), indices_shape_.end()); - values.insert(values.end(), output_shape_.begin(), output_shape_.end()); - std::vector lens{SizeToInt(input_shape_.size()), SizeToInt(indices_shape_.size()), - SizeToInt(output_shape_.size())}; - const char *env_role = getenv(mindspore::parallel::ps::kEnvRole); - if (env_role != nullptr && strcmp(env_role, mindspore::parallel::ps::kEnvRoleOfWorker) == 0) { - parallel::ps::Worker::GetInstance().AddEmbeddingTable(key_, input_shape_[axis_]); - parallel::ps::Worker::GetInstance().InitPSEmbeddingTable(keys, values, lens); - } -} - -bool EmbeddingLookUpProxyKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto indices_addr = reinterpret_cast(inputs[1]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - size_t input_size = inputs[1]->size; - size_t output_size = outputs[0]->size; - - size_t size = input_size / sizeof(float); - ::ps::SArray lookup_ids(size, 0); - ::ps::SArray lengths{size}; - ::ps::SArray lookup_result; - - auto ret = memcpy_s(lookup_ids.data(), input_size, indices_addr, input_size); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; - } - parallel::ps::Worker::GetInstance().DoPSEmbeddingLookup({key_}, lookup_ids, lengths, lookup_result, - parallel::ps::kEmbeddingLookupCmd); - - auto ret2 = memcpy_s(output_addr, output_size, lookup_result.data(), output_size); - if (ret2 != EOK) { - MS_LOG(EXCEPTION) << "Lookup result memcpy failed."; - } - return true; -} -} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.h deleted file mode 100644 index 1ce9154ac0..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_proxy_kernel.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ - -#include "kernel/cpu/embedding_look_up_cpu_kernel.h" -#include -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -namespace ps { -class EmbeddingLookUpProxyKernel : public EmbeddingLookUpCPUKernel { - public: - EmbeddingLookUpProxyKernel() = default; - ~EmbeddingLookUpProxyKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t key_{0}; - size_t input_dims_{1}; -}; - -MS_REG_CPU_KERNEL( - EmbeddingLookupProxy, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - EmbeddingLookUpProxyKernel); -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PROXY_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.cc deleted file mode 100644 index efabb49550..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/cpu/ps/embedding_look_up_ps_kernel.h" -#include -#include -#include -#include "kernel/common_utils.h" -#include "parallel/ps/util.h" - -namespace mindspore { -namespace kernel { -namespace ps { -using mindspore::parallel::ps::Util; -void EmbeddingLookUpPSKernel::InitKernel( - const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - input_shape_ = *(shape_vec[0]); - input_lens_ = 1; - for (auto shape : input_shape_) { - input_lens_ = input_lens_ * shape; - } - indices_shape_ = *(shape_vec[1]); - indices_lens_ = 1; - for (auto shape : indices_shape_) { - indices_lens_ = indices_lens_ * shape; - } - output_shape_ = *(shape_vec[2]); - axis_ = 2; - reduce_scatter_flag_ = false; - - size_t offset = 0; - for (size_t i = 0; i < rank_id_; i++) { - offset += Util::LocalShard(input_shape_[axis_], i, pserver_num_); - } - offset_ = offset; - split_num_ = pserver_num_; - - // input shape should be sharded after computing offset_; - Shard(input_shape_, axis_); - - size_t output_size = - std::accumulate(output_shape_.begin(), output_shape_.end(), sizeof(float), std::multiplies()); - output_size_list_.emplace_back(output_size); - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - CPUKernelUtils::ExpandDimsTo4(&output_shape_); -} - -void EmbeddingLookUpPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - const auto &indices_shape_ = *(shape_vec[0]); - indices_lens_ = indices_shape_[0]; - - size_t output_size = sizeof(float) * indices_lens_; - for (size_t i = axis_ + 1; i < input_shape_.size(); i++) { - output_size *= input_shape_[i]; - } - output_size_list_.clear(); - output_size_list_.emplace_back(output_size); -} - -bool EmbeddingLookUpPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - return Launch(inputs, workspace, outputs); -} - -const std::vector &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } - -const std::vector &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } - -const std::vector &EmbeddingLookUpPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } -} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.h deleted file mode 100644 index 11850b2fa6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/embedding_look_up_ps_kernel.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
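EmbeddingLookUpPSKernel above shards the embedding table by rows across parameter servers: offset_ accumulates the shard sizes of all lower-ranked servers, then Shard shrinks the local first dimension. A sketch of the arithmetic, assuming an even split with the remainder on the last server; the actual policy lives in Util::LocalShard in parallel/ps/util.cc and may differ:

#include <cstddef>

// Assumed split: each server gets rows / server_num rows, remainder to the last.
size_t LocalShardSize(size_t rows, size_t rank, size_t server_num) {
  size_t base = rows / server_num;
  return (rank == server_num - 1) ? rows - base * (server_num - 1) : base;
}

// First global row owned by this server, mirroring the offset_ loop in InitKernel.
size_t ShardOffset(size_t rows, size_t rank, size_t server_num) {
  size_t offset = 0;
  for (size_t i = 0; i < rank; ++i) offset += LocalShardSize(rows, i, server_num);
  return offset;
}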
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ - -#include -#include -#include "kernel/cpu/embedding_look_up_cpu_kernel.h" -#include "kernel/cpu/ps/pserver_kernel.h" - -namespace mindspore { -namespace kernel { -namespace ps { -class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel { - public: - EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} - ~EmbeddingLookUpPSKernel() override = default; - - void InitKernel(const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; - - bool Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - const std::vector &input_sizes() const override; - const std::vector &output_sizes() const override; - const std::vector &workspace_sizes() const override; -}; -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_EMBEDDING_LOOK_UP_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.cc deleted file mode 100644 index d6a7725a8d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/ps/pserver_kernel.h" -#include "parallel/ps/util.h" - -namespace mindspore { -namespace kernel { -namespace ps {} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.h deleted file mode 100644 index 527ee2c7fe..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/pserver_kernel.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ - -#include <memory> -#include <vector> -#include "kernel/kernel.h" -#include "parallel/ps/util.h" - -namespace mindspore { -namespace kernel { -namespace ps { -using mindspore::parallel::ps::Util; -class PServerKernel { - public: - PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {} - ~PServerKernel() = default; - PServerKernel(const PServerKernel &) = delete; - PServerKernel &operator=(const PServerKernel &) = delete; - - virtual void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {} - virtual void ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {} - virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs) = 0; - - virtual const std::vector<size_t> &input_sizes() const = 0; - virtual const std::vector<size_t> &output_sizes() const = 0; - virtual const std::vector<size_t> &workspace_sizes() const = 0; - - protected: - virtual void ReInit(const std::vector<AddressPtr> &) {} - void Shard(std::vector<size_t> *shape, int axis) { - (*shape)[axis] = Util::LocalShard((*shape)[axis], rank_id_, pserver_num_); - } - - size_t rank_id_; - size_t pserver_num_; -}; -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_PS_PSERVER_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.cc deleted file mode 100644 index 90b5e2e64d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/ps/pull_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_CPU_KERNEL_T( - Pull, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PullKernel, float); -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h deleted file mode 100644 index 5cde005617..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
diff --git a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h
deleted file mode 100644
index 5cde005617..0000000000
--- a/mindspore/ccsrc/kernel/cpu/ps/pull_kernel.h
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
-
-#include <vector>
-#include <string>
-#include "parallel/ps/worker.h"
-#include "parallel/ps/util.h"
-#include "kernel/cpu/cpu_kernel.h"
-#include "kernel/cpu/cpu_kernel_factory.h"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class PullKernel : public CPUKernel {
- public:
-  PullKernel() : keys_size_(sizeof(size_t)), var_size_(sizeof(size_t)) {}
-  ~PullKernel() override = default;
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &) {
-    // If the paramter is embedding table, don't Pull from PServer.
-    if (param_name_.find("embedding") == std::string::npos && param_name_.find("wide_w") == std::string::npos) {
-      parallel::ps::Worker<T>::GetInstance().Pull(key_, inputs[1]->addr, inputs[1]->size);
-    }
-    return true;
-  }
-  void Init(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but pull needs 2 inputs.";
-      return;
-    }
-
-    auto key_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    for (size_t i = 0; i < key_shape.size(); i++) {
-      keys_size_ *= key_shape[i];
-    }
-    auto var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    for (size_t i = 0; i < var_shape.size(); i++) {
-      var_size_ *= var_shape[i];
-    }
-    auto param_node = AnfAlgo::GetInputNode(kernel_node, 1);
-    MS_EXCEPTION_IF_NULL(param_node);
-    param_name_ = param_node->fullname_with_scope();
-
-    if (mindspore::parallel::ps::Util::IsRoleOfWorker()) {
-      key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);
-    }
-    InitSizeLists();
-    return;
-  }
-  void InitKernel(const CNodePtr &kernel_node) { return; }
-
- protected:
-  void InitSizeLists() {
-    input_size_list_.push_back(keys_size_);
-    input_size_list_.push_back(var_size_);
-    output_size_list_.push_back(0);
-  }
-
- private:
-  size_t key_;
-  size_t keys_size_;
-  size_t var_size_;
-  std::string param_name_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_PS_PULL_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/cpu/ps/push_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/push_kernel.cc
deleted file mode 100644
index a49c7e9207..0000000000
--- a/mindspore/ccsrc/kernel/cpu/ps/push_kernel.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "kernel/cpu/ps/push_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_CPU_KERNEL_T(Push, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeUInt64), - PushKernel, float); - -MS_REG_CPU_KERNEL_T( - Push, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), - PushKernel, float); -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/push_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/push_kernel.h deleted file mode 100644 index 436bebd388..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/push_kernel.h +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ - -#include -#include -#include "parallel/ps/worker.h" -#include "parallel/ps/util.h" -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class PushKernel : public CPUKernel { - public: - PushKernel() : key_(UINT64_MAX) {} - ~PushKernel() override = default; - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs) { - std::vector keys; - std::vector addrs; - std::vector sizes; - for (auto input : inputs) { - keys.push_back(key_); - addrs.push_back(reinterpret_cast(input->addr)); - sizes.push_back(SizeToInt(input->size) / sizeof(T)); - } - parallel::ps::Worker::GetInstance().Push(keys, addrs, sizes); - memcpy(outputs[0]->addr, &key_, sizeof(size_t)); - return true; - } - - void Init(const CNodePtr &kernel_node) { - key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); - auto optim_input_shapes = AnfAlgo::GetNodeAttr>>(kernel_node, "optim_input_shapes"); - std::vector only_shape_indices = AnfAlgo::GetNodeAttr>(kernel_node, "only_shape_indices"); - MS_LOG(INFO) << "Key " << key_ << " optimizer input shapes are:" << optim_input_shapes; - MS_LOG(INFO) << "Only init shape indices are " << only_shape_indices; - for (size_t i = 0; i < optim_input_shapes.size(); i++) { - auto shape = optim_input_shapes[i]; - mindspore::parallel::ps::Worker::GetInstance().SetOptimInputShapes(key_, shape); - if (std::count(only_shape_indices.begin(), only_shape_indices.end(), i) == 0) { - size_t size = sizeof(T); - for (size_t j = 0; j < shape.size(); j++) { - size *= shape[j]; - } - input_size_list_.push_back(size); - } - } - - output_size_list_.push_back(sizeof(size_t)); - return; - } - - void InitKernel(const CNodePtr &kernel_node) { return; } - - private: - size_t key_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // 
MINDSPORE_CCSRC_KERNEL_PS_PUSH_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.cc deleted file mode 100644 index 947f379f5d..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/ps/sparse_apply_adam_ps_kernel.h" -#include -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" -#include "parallel/ps/util.h" - -namespace mindspore { -namespace kernel { -namespace ps { -void SparseApplyAdamPSKernel::InitKernel( - const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - std::vector &var_shape = *(shape_vec[0]); - std::vector &m_shape = *(shape_vec[1]); - std::vector &v_shape = *(shape_vec[2]); - const std::vector &grad_shape = *(shape_vec[9]); - const std::vector &indices_shape = *(shape_vec[10]); - - Shard(&var_shape, 0); - Shard(&m_shape, 0); - Shard(&v_shape, 0); - - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(EXCEPTION) << "var and m should have the same shape"; - } - if (!IsSameShape(var_shape, v_shape)) { - MS_LOG(EXCEPTION) << "var and v should have the same shape"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be 1D"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(ERROR) << "The first dimension of grad shape must be equal to indices"; - } - /* - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); - } - */ - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); -} - -void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - const std::vector &indices_shape = *(shape_vec[0]); - indices_size_ = indices_shape[0]; - workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); - workspace_size_list_[1] = indices_size_ * sizeof(int); -} - -void SparseApplyAdamPSKernel::ReInit(const std::vector &inputs) { - const auto &indices_addr = inputs[10]; - indices_size_ = indices_addr->size; - workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); - workspace_size_list_[1] = indices_size_ * sizeof(int); -} - -bool SparseApplyAdamPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - ReInit(inputs); - int *indices = 
reinterpret_cast(inputs[10]->addr); - for (size_t i = 0; i < inputs[10]->size / sizeof(int); i++) { - indices[i] -= rank_id_ * var_first_dim_size_; - } - return Launch(inputs, workspace, outputs); -} - -const std::vector &SparseApplyAdamPSKernel::input_sizes() const { return GetInputSizeList(); } - -const std::vector &SparseApplyAdamPSKernel::output_sizes() const { return GetOutputSizeList(); } - -const std::vector &SparseApplyAdamPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } -} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.h deleted file mode 100644 index df49ccc889..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_adam_ps_kernel.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ - -#include -#include -#include "kernel/cpu/ps/pserver_kernel.h" -#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace ps { -using mindspore::kernel::SparseApplyAdamCPUKernel; -class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel { - public: - SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} - ~SparseApplyAdamPSKernel() override = default; - - void InitKernel(const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; - bool Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - const std::vector &input_sizes() const override; - const std::vector &output_sizes() const override; - const std::vector &workspace_sizes() const override; - - protected: - void ReInit(const std::vector &) override; -}; -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc deleted file mode 100644 index 26cc42685f..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace ps { -void SparseApplyFtrlPSKernel::InitKernel( - const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - std::vector var_shape = *(shape_vec[0]); - std::vector accum_shape = *(shape_vec[1]); - std::vector linear_shape = *(shape_vec[2]); - std::vector grad_shape = *(shape_vec[3]); - std::vector indices_shape = *(shape_vec[4]); - - Shard(&var_shape, 0); - Shard(&accum_shape, 0); - Shard(&linear_shape, 0); - - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be a 1D vector"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - lr_ = 0.01; - l1_ = 1e-8; - l2_ = 1e-8; - lr_power_ = -0.5; - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr>>> &shapes) { - const std::vector>> &shape_vec = *shapes; - std::vector indices_shape = *(shape_vec[0]); - indices_size_ = indices_shape[0]; - workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); - workspace_size_list_[1] = indices_size_ * sizeof(int); -} - -void SparseApplyFtrlPSKernel::ReInit(const std::vector &inputs) { - const auto &indices_addr = inputs[4]; - indices_size_ = indices_addr->size; - workspace_size_list_[0] = indices_size_ * var_outer_dim_size_ * sizeof(float); - workspace_size_list_[1] = indices_size_ * sizeof(int); -} - -bool SparseApplyFtrlPSKernel::Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) { - ReInit(inputs); - int *indices = reinterpret_cast(inputs[4]->addr); - for (size_t i = 0; i < inputs[4]->size / sizeof(int); i++) { - indices[i] -= rank_id_ * var_first_dim_size_; - } - return Launch(inputs, workspace, outputs); -} - -const std::vector &SparseApplyFtrlPSKernel::input_sizes() const { return GetInputSizeList(); } - -const std::vector &SparseApplyFtrlPSKernel::output_sizes() const { return GetOutputSizeList(); } - -const std::vector &SparseApplyFtrlPSKernel::workspace_sizes() const { return GetWorkspaceSizeList(); } -} // namespace ps -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.h deleted file mode 100644 index b1afcaf87e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/ps/sparse_apply_ftrl_ps_kernel.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ - -#include -#include -#include "kernel/cpu/ps/pserver_kernel.h" -#include "kernel/cpu/sparse_apply_ftrl_cpu_kernel.h" - -namespace mindspore { -namespace kernel { -namespace ps { -using mindspore::kernel::SparseApplyFtrlCPUKernel; -class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel { - public: - SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} - ~SparseApplyFtrlPSKernel() override = default; - - void InitKernel(const std::shared_ptr>>> &) override; - void ReInit(const std::shared_ptr>>> &) override; - - bool Execute(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - const std::vector &input_sizes() const override; - const std::vector &output_sizes() const override; - const std::vector &workspace_sizes() const override; - - protected: - void ReInit(const std::vector &) override; -}; -} // namespace ps -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_PS_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc deleted file mode 100644 index e56f2af8c7..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.cc +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include -#include "kernel/cpu/reduce_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -const size_t kReduceTypeMax = 0; -const size_t kReduceTypeMean = 1; -const size_t kReduceTypeSum = 2; -const size_t kMaxDim = 100; -void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "ReduceMax") { - reduce_type_ = kReduceTypeMax; - } else if (kernel_name == "ReduceMean") { - reduce_type_ = kReduceTypeMean; - } else if (kernel_name == "ReduceSum") { - reduce_type_ = kReduceTypeSum; - } else { - MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; - } - shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); - if (axis_addr->isa()) { - auto attr_axis = AnfAlgo::GetNodeAttr>(kernel_node, AXIS); - if (attr_axis.size() > shape_.size()) { - MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size(); - } else if (attr_axis.empty()) { - axis_.push_back(shape_.size() - 1); - } else { - for (auto axis : attr_axis) { - if (IntToSize(axis) >= (shape_.size())) { - MS_LOG(EXCEPTION) << "axis value is oversize."; - } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); - } - } - } else if (axis_addr->isa()) { - int axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); - if (axis >= 0 && IntToSize(axis) >= shape_.size()) { - MS_LOG(EXCEPTION) << "axis value is oversize."; - } - axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; - } - for (size_t i = 0; i < shape_.size(); ++i) { - if (shape_[i] <= 0) { - MS_LOG(EXCEPTION) << "shape value is invalid."; - } - left_dims_ *= shape_[i]; - } - for (size_t i = 0; i < axis_.size(); ++i) { - stride_ *= shape_[axis_[i]]; - } - if (stride_ <= 0) { - MS_LOG(EXCEPTION) << "stride_ must greater than zero."; - } - left_dims_ = left_dims_ / stride_; -} -bool ReduceCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspaces*/, - const std::vector &outputs) { - if (inputs.empty() || outputs.empty()) { - MS_LOG(EXCEPTION) << "input or output empty!"; - } - size_t out_float_size = left_dims_ * sizeof(float); - size_t in_float_size = stride_ * out_float_size; - if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) { - MS_LOG(EXCEPTION) << "invalid input or output data size!"; - } - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - int size = inputs[0]->size / sizeof(float); - std::vector new_input(IntToSize(size), 0.0); - std::vector transpose_axis; - for (size_t i = 0; i < shape_.size(); ++i) { - bool insert = true; - for (size_t j = 0; j < axis_.size(); ++j) { - if (axis_[j] == i) { - insert = false; - break; - } - } - if (insert) { - transpose_axis.push_back(i); - } - } - (void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end()); - Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]); - if (reduce_type_ == kReduceTypeMax) { - for (size_t i = 0; i < left_dims_; ++i) { - float value = new_input[i * stride_]; - for (size_t k = 0; k < stride_; ++k) { - if (value < new_input[i * stride_ + k]) { - value = new_input[i * stride_ + k]; - } - } - output[i] = value; - } - } else { - for (size_t i = 0; i < left_dims_; ++i) { - 
float value = 0.0; - for (size_t k = 0; k < stride_; ++k) { - value += new_input[i * stride_ + k]; - } - if (reduce_type_ == kReduceTypeMean) { - output[i] = value / stride_; - } else { - output[i] = value; - } - } - } - return true; -} -void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector &input_shape, - const std::vector &input_axis, const int shape_size, float *output) { - int pos_array[kMaxDim]; - int size_offset[kMaxDim]; - size_offset[0] = size / SizeToInt(input_shape[0]); - for (int i = 1; i < shape_size; i++) { - size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]); - } - for (int position = 0; position < size; position += 1) { - int temp_position = position; - pos_array[0] = temp_position / size_offset[0]; - for (int i = 1; i < shape_size; i++) { - temp_position -= pos_array[i - 1] * size_offset[i - 1]; - pos_array[i] = temp_position / size_offset[i]; - } - int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])]; - int new_position_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]); - new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size; - } - output[new_position] = input[position]; - } - return; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h deleted file mode 100644 index 3317ec72ed..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_cpu_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
-#include <vector>
-#include <memory>
-#include <string>
-#include "kernel/cpu/cpu_kernel.h"
-#include "kernel/cpu/cpu_kernel_factory.h"
-
-namespace mindspore {
-namespace kernel {
-class ReduceCPUKernel : public CPUKernel {
- public:
-  ReduceCPUKernel() = default;
-  ~ReduceCPUKernel() override = default;
-  void InitKernel(const CNodePtr &kernel_node) override;
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
-
- private:
-  void Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
-                 const std::vector<size_t> &input_axis, const int shape_size, float *output);
-  size_t reduce_type_;
-  std::vector<size_t> axis_;
-  std::vector<size_t> shape_;
-  size_t left_dims_ = 1;
-  size_t stride_ = 1;
-};
-MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ReduceCPUKernel);
-MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ReduceCPUKernel);
-MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ReduceCPUKernel);
-}  // namespace kernel
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
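ReduceCPUKernel's Launch first permutes the reduced axes to the innermost positions (the Transpose pass above), so that every output element then reduces one contiguous run of stride_ values. A minimal standalone sketch of that second phase, with the permutation already done and the shapes hard-coded for illustration:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Reduce a 2x3 tensor over axis 1: stride = 3, left_dims = 2.
  std::vector<float> data = {1, 2, 3, 4, 5, 6};
  const size_t left_dims = 2, stride = 3;
  std::vector<float> sums(left_dims, 0.0f);
  for (size_t i = 0; i < left_dims; ++i) {
    for (size_t k = 0; k < stride; ++k) {
      sums[i] += data[i * stride + k];  // contiguous because axis 1 is innermost
    }
  }
  std::cout << sums[0] << " " << sums[1] << "\n";  // prints: 6 15
  return 0;
}

When the reduced axis is not already innermost, the index arithmetic in Transpose computes exactly the permutation that makes this contiguous access possible.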
- */ -#include "kernel/cpu/reduce_scatter_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "device/cpu/mpi/mpi_adapter.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr auto kRanksGroup = "group"; -} // namespace - -ReduceScatterCPUKernel::ReduceScatterCPUKernel() : op_type_(device::cpu::kOpTypeSum) {} - -void ReduceScatterCPUKernel::InitKernel(const CNodePtr &kernel_node) { - auto op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); - if (op != nullptr) { - op_type_ = GetValue(op); - } - - auto ranks_group = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kRanksGroup); - if (ranks_group != nullptr) { - ranks_group_ = GetValue>(ranks_group); - } else { - MS_LOG(EXCEPTION) << "Miss attribute " << kRanksGroup; - } -} - -bool ReduceScatterCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto output_data_num = outputs[0]->size / sizeof(float); - auto mpi_instance = device::cpu::MPIAdapter::Instance(); - MS_EXCEPTION_IF_NULL(mpi_instance); - return mpi_instance->ReduceScatter(input_addr, output_addr, ranks_group_, output_data_num, op_type_); -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h deleted file mode 100644 index 5c6907602a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
diff --git a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h
deleted file mode 100644
index 5c6907602a..0000000000
--- a/mindspore/ccsrc/kernel/cpu/reduce_scatter_cpu_kernel.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
-#include <vector>
-#include <string>
-#include "kernel/cpu/cpu_kernel.h"
-#include "kernel/cpu/cpu_kernel_factory.h"
-
-namespace mindspore {
-namespace kernel {
-class ReduceScatterCPUKernel : public CPUKernel {
- public:
-  ReduceScatterCPUKernel();
-  ~ReduceScatterCPUKernel() override = default;
-
-  void InitKernel(const CNodePtr &kernel_node) override;
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs) override;
-
- private:
-  std::string op_type_;
-  std::vector<int> ranks_group_;
-};
-
-MS_REG_CPU_KERNEL(_HostReduceScatter, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ReduceScatterCPUKernel);
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_SCATTER_CPU_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc
deleted file mode 100644
index 7342a19e99..0000000000
--- a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.cc
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/cpu/reshape_cpu_kernel.h"
-#include "device/cpu/cpu_device_address.h"
-
-namespace mindspore {
-namespace kernel {
-void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); }
-
-bool ReshapeCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
-                              const std::vector<AddressPtr> & /*workspace*/,
-                              const std::vector<AddressPtr> &outputs) {
-  if (inputs.empty() || outputs.empty()) {
-    MS_LOG(EXCEPTION) << "input or output empty!";
-  }
-  if (inputs[0]->size != outputs[0]->size) {
-    return false;
-  }
-
-  if (inputs[0]->addr == outputs[0]->addr) {
-    return true;
-  }
-
-  size_t mem_bits = outputs[0]->size;
-  auto ret = memcpy_s(outputs[0]->addr, mem_bits, inputs[0]->addr, mem_bits);
-  if (ret != 0) {
-    MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
-    return false;
-  }
-  return true;
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h
deleted file mode 100644
index 6ca746f4ac..0000000000
--- a/mindspore/ccsrc/kernel/cpu/reshape_cpu_kernel.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class ReshapeCPUKernel : public CPUKernel { - public: - ReshapeCPUKernel() = default; - ~ReshapeCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; -}; - -MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); - -MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); - -MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ReshapeCPUKernel); -MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ReshapeCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_RESHAPE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc deleted file mode 100644 index afb3e6a247..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.cc +++ /dev/null @@ -1,179 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/slice_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - - begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - } - auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto strides = prim->GetAttr(STRIDES); - if (strides != nullptr) { - strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); - if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) { - MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; - } - for (size_t i = 0; i < strides_.size(); ++i) { - if (strides_[i] < 0) { - strides_[i] = (strides_[i] + input_shape_[i]) > 0 ? (strides_[i] + input_shape_[i]) : 0; - } - if (end_[i] < 0) { - end_[i] = (end_[i] + input_shape_[i]) > 0 ? 
(end_[i] + input_shape_[i]) : 0; - } - } - } else { - auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); - if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) { - MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; - } - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] < 0) { - sizes[i] = (sizes[i] + input_shape_[i]) > 0 ? (sizes[i] + input_shape_[i]) : 0; - } - strides_.emplace_back(1); - end_.emplace_back(begin_[i] + sizes[i]); - } - } - - ExpandAllMemberDims(); - CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); - CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); -} - -void SliceCPUKernel::ExpandAllMemberDims() { - CPUKernelUtils::ExpandDimsTo4(&output_shape_); - - auto input_len = input_shape_.size(); - if (input_len < 4) { - for (size_t i = 0; i < 4 - input_len; ++i) { - input_shape_.insert(input_shape_.begin(), 1); - begin_.insert(begin_.begin(), 0); - strides_.insert(strides_.begin(), 1); - end_.insert(end_.begin(), 1); - } - } -} - -bool SliceCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; - size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1], - begin_[2] * input_element_num_[2]}; - size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1], - strides_[2] * input_element_num_[2]}; - - auto in_n_offset = in_start_offset[0]; - auto out_n_offset = 0; - for (int i = begin_[0]; i < end_[0]; - i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) { - if (can_copy_memory[0]) { - CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); - continue; - } - auto in_c_offset = in_start_offset[1]; - auto out_c_offset = 0; - for (int j = begin_[1]; j < end_[1]; - j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) { - if (can_copy_memory[1]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, - input_element_num_[1]); - continue; - } - auto in_h_offset = in_start_offset[2]; - auto out_h_offset = 0; - for (int k = begin_[2]; k < end_[2]; - k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) { - if (can_copy_memory[2]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, - out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); - continue; - } - for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { - *output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m]; - } - } - } - } - - return true; -} - -bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { - for (size_t i = dim + 1; i < 4; ++i) { - if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) { - return false; - } - } - return true; -} - -void SliceCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, - size_t copy_num) const { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto in_buff_size = inputs[0]->size; - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto out_buff_size = outputs[0]->size; - - if ((in_offset + copy_num) * 
sizeof(float) > in_buff_size) { - MS_LOG(EXCEPTION) << "input memory out of bounds."; - } - if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { - MS_LOG(EXCEPTION) << "output memory out of bounds."; - } - - auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, - copy_num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; - } -} - -void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs."; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower."; - } - if (input_shape.size() == 0) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h deleted file mode 100644 index 913c993d7a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_cpu_kernel.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SliceCPUKernel : public CPUKernel { - public: - SliceCPUKernel() = default; - ~SliceCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void ExpandAllMemberDims(); - bool CanCopyMemoryOnAxis(size_t dim) const; - void CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, size_t copy_num) const; - void CheckParam(const CNodePtr &kernel_node) const; - std::vector begin_; - std::vector end_; - std::vector strides_; - std::vector input_shape_; - std::vector input_element_num_; - std::vector output_shape_; - std::vector output_element_num_; -}; - -MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceCPUKernel); -MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc deleted file mode 100644 index 92eaffe8c6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.cc +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "kernel/cpu/slice_grad_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -#include "ir/primitive.h" - -namespace mindspore { -namespace kernel { -void SliceGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - CheckParam(kernel_node); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - - begin_ = AnfAlgo::GetNodeAttr>(kernel_node, BEGIN); - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + output_shape_[i]; - } - } - - auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); - MS_EXCEPTION_IF_NULL(prim); - auto strides = prim->GetAttr(STRIDES); - if (strides != nullptr) { - strides_ = AnfAlgo::GetNodeAttr>(kernel_node, STRIDES); - end_ = AnfAlgo::GetNodeAttr>(kernel_node, END); - if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) { - MS_LOG(EXCEPTION) << "stride|end|input size must be equal"; - } - for (size_t i = 0; i < strides_.size(); ++i) { - if (strides_[i] < 0) { - strides_[i] = (strides_[i] + output_shape_[i]) > 0 ? (strides_[i] + output_shape_[i]) : 0; - } - if (end_[i] < 0) { - end_[i] = (end_[i] + output_shape_[i]) > 0 ? 
(end_[i] + output_shape_[i]) : 0; - } - } - } else { - auto sizes = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); - if (sizes.size() != output_shape_.size() || begin_.size() != output_shape_.size()) { - MS_LOG(EXCEPTION) << "begin|size|input size must be equal"; - } - for (size_t i = 0; i < sizes.size(); ++i) { - if (sizes[i] < 0) { - sizes[i] = (sizes[i] + output_shape_[i]) > 0 ? (sizes[i] + output_shape_[i]) : 0; - } - strides_.emplace_back(1); - end_.emplace_back(begin_[i] + sizes[i]); - } - } - - ExpandAllMemberDims(); - CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_); - CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); -} - -void SliceGradCPUKernel::ExpandAllMemberDims() { - CPUKernelUtils::ExpandDimsTo4(&input_shape_); - - auto output_len = output_shape_.size(); - if (output_len < 4) { - for (size_t i = 0; i < 4 - output_len; ++i) { - output_shape_.insert(output_shape_.begin(), 1); - begin_.insert(begin_.begin(), 0); - strides_.insert(strides_.begin(), 1); - end_.insert(end_.begin(), 1); - } - } -} - -bool SliceGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - - auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size); - if (ret != EOK) { - MS_LOG(ERROR) << "output buff memset fail. ret:" << ret; - return false; - } - - bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)}; - size_t out_start_offset[3] = {begin_[0] * output_element_num_[0], begin_[1] * output_element_num_[1], - begin_[2] * output_element_num_[2]}; - size_t out_step_size[3] = {strides_[0] * output_element_num_[0], strides_[1] * output_element_num_[1], - strides_[2] * output_element_num_[2]}; - - auto in_n_offset = 0; - auto out_n_offset = out_start_offset[0]; - for (int i = begin_[0]; i < end_[0]; - i += strides_[0], in_n_offset += input_element_num_[0], out_n_offset += out_step_size[0]) { - if (can_copy_memory[0]) { - CopyDataToOutput(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0]); - continue; - } - auto in_c_offset = 0; - auto out_c_offset = out_start_offset[1]; - for (int j = begin_[1]; j < end_[1]; - j += strides_[1], in_c_offset += input_element_num_[1], out_c_offset += out_step_size[1]) { - if (can_copy_memory[1]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset, - input_element_num_[1]); - continue; - } - auto in_h_offset = 0; - auto out_h_offset = out_start_offset[2]; - for (int k = begin_[2]; k < end_[2]; - k += strides_[2], in_h_offset += input_element_num_[2], out_h_offset += out_step_size[2]) { - if (can_copy_memory[2]) { - CopyDataToOutput(inputs, in_n_offset + in_c_offset + in_h_offset, outputs, - out_n_offset + out_c_offset + out_h_offset, input_element_num_[2]); - continue; - } - for (int m = begin_[3]; m < end_[3]; m += strides_[3]) { - output_addr[out_n_offset + out_c_offset + out_h_offset + m] = *input_addr++; - } - } - } - } - return true; -} - -bool SliceGradCPUKernel::CanCopyMemoryOnAxis(size_t dim) const { - for (size_t i = dim + 1; i < 4; ++i) { - if (begin_[i] != 0 || end_[i] != SizeToInt(output_shape_[i]) || strides_[i] != 1) { - return false; - } - } - return true; -} - -void SliceGradCPUKernel::CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, - size_t copy_num) const { - auto 
input_addr = reinterpret_cast(inputs[0]->addr); - auto in_buff_size = inputs[0]->size; - auto output_addr = reinterpret_cast(outputs[0]->addr); - auto out_buff_size = outputs[0]->size; - - if ((in_offset + copy_num) * sizeof(float) > in_buff_size) { - MS_LOG(EXCEPTION) << "input memory out of bounds."; - } - if ((out_offset + copy_num) * sizeof(float) > out_buff_size) { - MS_LOG(EXCEPTION) << "output memory out of bounds."; - } - - auto ret = memcpy_s(output_addr + out_offset, out_buff_size - out_offset * sizeof(float), input_addr + in_offset, - copy_num * sizeof(float)); - if (ret != EOK) { - MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret; - } -} - -void SliceGradCPUKernel::CheckParam(const CNodePtr &kernel_node) const { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; - } - if (input_shape.size() == 0) { - MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h deleted file mode 100644 index 1e42c8ac68..0000000000 --- a/mindspore/ccsrc/kernel/cpu/slice_grad_cpu_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SliceGradCPUKernel : public CPUKernel { - public: - SliceGradCPUKernel() = default; - ~SliceGradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - void ExpandAllMemberDims(); - bool CanCopyMemoryOnAxis(size_t dim) const; - void CopyDataToOutput(const std::vector &inputs, size_t in_offset, - const std::vector &outputs, size_t out_offset, size_t copy_num) const; - void CheckParam(const CNodePtr &kernel_node) const; - std::vector begin_; - std::vector end_; - std::vector strides_; - std::vector input_shape_; - std::vector input_element_num_; - std::vector output_shape_; - std::vector output_element_num_; -}; - -MS_REG_CPU_KERNEL( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradCPUKernel); -MS_REG_CPU_KERNEL(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SLICE_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc deleted file mode 100644 index ef3db78275..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc +++ /dev/null @@ -1,177 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyAdamInputSize = 11; - -void ComputeAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto m = input_params->m_; - auto m_t = input_params->m_t_; - auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - m[j] += (1 - beta1) * summed_grad; - v[j] += (1 - beta2) * summed_grad * summed_grad; - if (use_nesterov) { - m_t[j] = m[j] * beta1 + (1 - beta1) * summed_grad; - } - } - } -} - -void ComputeMomentum(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto m = input_params->m_; - auto v = input_params->v_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - for (size_t i = start; i < end; ++i) { - m[i] *= beta1; - v[i] *= beta2; - } -} - -void ComputeWeight(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto m = input_params->m_; - auto v = input_params->v_; - auto lr = input_params->lr_; - auto epsilon = input_params->epsilon_; - for (size_t i = start; i < end; ++i) { - var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); - } -} -} // namespace - -void SparseApplyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); -} - -void SparseApplyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(EXCEPTION) << "var and m should have the same shape"; - } - if (!IsSameShape(var_shape, v_shape)) { - MS_LOG(EXCEPTION) << "var and v should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { 
- MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be 1D"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = AnfAlgo::GetNodeAttr(kernel_node, "use_nesterov"); - } -} - -bool SparseApplyAdamCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyAdamInputSize) { - MS_LOG(EXCEPTION) << "Error input size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto m = reinterpret_cast(inputs[1]->addr); - auto v = reinterpret_cast(inputs[2]->addr); - auto beta1_power = reinterpret_cast(inputs[3]->addr)[0]; - if (beta1_power == 1) { - MS_LOG(EXCEPTION) << "The beta1_power should not be 1"; - } - auto beta2_power = reinterpret_cast(inputs[4]->addr)[0]; - auto lr = reinterpret_cast(inputs[5]->addr)[0]; - auto beta1 = reinterpret_cast(inputs[6]->addr)[0]; - auto beta2 = reinterpret_cast(inputs[7]->addr)[0]; - auto epsilon = reinterpret_cast(inputs[8]->addr)[0]; - auto grad = reinterpret_cast(inputs[9]->addr); - auto indices = reinterpret_cast(inputs[10]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - auto m_t = reinterpret_cast(workspace[2]->addr); - - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - size_t total_dim_size = var_first_dim_size_ * var_outer_dim_size_; - lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power); - - MultiThreadComputeParams input_params; - input_params.m_ = m; - input_params.v_ = v; - input_params.beta1_ = beta1; - input_params.beta2_ = beta2; - MultiThreadCompute(ComputeMomentum, &input_params, total_dim_size); - - input_params.m_t_ = m_t; - input_params.use_nesterov_ = use_nesterov_; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeAdam, &input_params, unique_sparse_grad.indices_size_); - - if (use_nesterov_) { - input_params.m_ = input_params.m_t_; - } - input_params.var_ = var; - input_params.lr_ = lr; - input_params.epsilon_ = epsilon; - MultiThreadCompute(ComputeWeight, &input_params, total_dim_size); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h deleted file mode 100644 index 05bcad16f6..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyAdamCPUKernel : public CPUKernel { - public: - SparseApplyAdamCPUKernel() = default; - ~SparseApplyAdamCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - bool use_nesterov_{false}; -}; - -MS_REG_CPU_KERNEL(SparseApplyAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyAdamCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc deleted file mode 100644 index 03fb1d303f..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/sparse_apply_ftrl_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyFtrlInputSize = 5; - -void ComputeFtrl(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto accum = input_params->accum_; - auto linear = input_params->linear_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2_plus = 2 * input_params->l2_; - auto lr_power = input_params->lr_power_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - auto accum_new = accum[j] + summed_grad * summed_grad; - float y; - if (lr_power == -0.5) { - y = std::sqrt(accum_new); - linear[j] += summed_grad - (y - std::sqrt(accum[j])) / lr * var[j]; - } else { - y = std::pow(accum_new, -lr_power); - linear[j] += summed_grad - (y - std::pow(accum[j], -lr_power)) / lr * var[j]; - } - accum[j] = accum_new; - auto x = Sign(linear[j]) * l1 - linear[j]; - y = y / lr + l2_plus; - var[j] = std::fabs(linear[j]) > l1 ? x / y : 0; - } - } -} -} // namespace - -void SparseApplyFtrlCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyFtrlCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - if (!IsSameShape(var_shape, accum_shape)) { - MS_LOG(EXCEPTION) << "var and accum should have the same shape"; - } - if (!IsSameShape(var_shape, linear_shape)) { - MS_LOG(EXCEPTION) << "var and linear should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be a 1D vector"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - 
MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - lr_ = AnfAlgo::GetNodeAttr(kernel_node, "lr"); - if (lr_ <= 0) { - MS_LOG(EXCEPTION) << "lr should be a positive scalar"; - } - l1_ = AnfAlgo::GetNodeAttr(kernel_node, "l1"); - if (l1_ < 0) { - MS_LOG(EXCEPTION) << "l1 should be a non-negative scalar"; - } - l2_ = AnfAlgo::GetNodeAttr(kernel_node, "l2"); - if (l2_ < 0) { - MS_LOG(EXCEPTION) << "l2 should be a non-negative scalar"; - } - lr_power_ = AnfAlgo::GetNodeAttr(kernel_node, "lr_power"); - if (lr_power_ > 0) { - MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; - } -} - -bool SparseApplyFtrlCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyFtrlInputSize) { - MS_LOG(EXCEPTION) << "error input output size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto accum = reinterpret_cast(inputs[1]->addr); - auto linear = reinterpret_cast(inputs[2]->addr); - auto grad = reinterpret_cast(inputs[3]->addr); - auto indices = reinterpret_cast(inputs[4]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - auto tmp_grad = reinterpret_cast(workspace[2]->addr); - auto tmp_indices = reinterpret_cast(workspace[3]->addr); - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_}); - TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad, - var_first_dim_size_, var_outer_dim_size_); - - MultiThreadComputeParams input_params; - input_params.var_ = var; - input_params.accum_ = accum; - input_params.linear_ = linear; - input_params.lr_ = lr_; - input_params.l1_ = l1_; - input_params.l2_ = l2_; - input_params.lr_power_ = lr_power_; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeFtrl, &input_params, unique_sparse_grad.indices_size_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h deleted file mode 100644 index dd218294e3..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ - -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyFtrlCPUKernel : public CPUKernel { - public: - SparseApplyFtrlCPUKernel() = default; - ~SparseApplyFtrlCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - protected: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - float lr_{0}; - float l1_{0}; - float l2_{0}; - float lr_power_{0}; -}; - -MS_REG_CPU_KERNEL(SparseApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyFtrlCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyFtrlCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_FTRL_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc deleted file mode 100644 index ed5438a318..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyLazyAdamInputSize = 11; - -void ComputeLazyAdam(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto m = input_params->m_; - auto v = input_params->v_; - auto lr = input_params->lr_; - auto beta1 = input_params->beta1_; - auto beta2 = input_params->beta2_; - auto epsilon = input_params->epsilon_; - auto use_nesterov = input_params->use_nesterov_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - m[j] = beta1 * m[j] + (1 - beta1) * summed_grad; - v[j] = beta2 * v[j] + (1 - beta2) * summed_grad * summed_grad; - if (use_nesterov) { - var[j] -= lr * (m[j] * beta1 + (1 - beta1) * summed_grad) / (std::sqrt(v[j]) + epsilon); - } else { - var[j] -= lr * m[j] / (std::sqrt(v[j]) + epsilon); - } - } - } -} -} // namespace - -void SparseApplyLazyAdamCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyLazyAdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 10); - if (!IsSameShape(var_shape, m_shape)) { - MS_LOG(EXCEPTION) << "var and m should have the same shape"; - } - if (!IsSameShape(var_shape, v_shape)) { - MS_LOG(EXCEPTION) << "var and v should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be 1D"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (AnfAlgo::HasNodeAttr(USE_NESTEROV, kernel_node)) { - use_nesterov_ = 
AnfAlgo::GetNodeAttr<bool>(kernel_node, "use_nesterov");
-  }
-}
-
-bool SparseApplyLazyAdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                                          const std::vector<kernel::AddressPtr> &workspace,
-                                          const std::vector<kernel::AddressPtr> & /*outputs*/) {
-  if (inputs.size() < kSparseApplyLazyAdamInputSize) {
-    MS_LOG(EXCEPTION) << "Error input size!";
-  }
-
-  auto var = reinterpret_cast<float *>(inputs[0]->addr);
-  auto m = reinterpret_cast<float *>(inputs[1]->addr);
-  auto v = reinterpret_cast<float *>(inputs[2]->addr);
-  auto beta1_power = reinterpret_cast<float *>(inputs[3]->addr)[0];
-  if (beta1_power == 1) {
-    MS_LOG(EXCEPTION) << "The beta1_power should not be 1";
-  }
-  auto beta2_power = reinterpret_cast<float *>(inputs[4]->addr)[0];
-  auto lr = reinterpret_cast<float *>(inputs[5]->addr)[0];
-  auto beta1 = reinterpret_cast<float *>(inputs[6]->addr)[0];
-  auto beta2 = reinterpret_cast<float *>(inputs[7]->addr)[0];
-  auto epsilon = reinterpret_cast<float *>(inputs[8]->addr)[0];
-  auto grad = reinterpret_cast<float *>(inputs[9]->addr);
-  auto indices = reinterpret_cast<int *>(inputs[10]->addr);
-  auto new_grad = reinterpret_cast<float *>(workspace[0]->addr);
-  auto new_indices = reinterpret_cast<int *>(workspace[1]->addr);
-  auto tmp_grad = reinterpret_cast<float *>(workspace[2]->addr);
-  auto tmp_indices = reinterpret_cast<int *>(workspace[3]->addr);
-
-  SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_});
-  SparseGradient tmp_sparse_grad({tmp_grad, tmp_indices, indices_size_});
-  TwoLevelReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &tmp_sparse_grad, &unique_sparse_grad,
-                               var_first_dim_size_, var_outer_dim_size_);
-
-  lr = lr * std::sqrt(1 - beta2_power) / (1 - beta1_power);
-  MultiThreadComputeParams input_params;
-  input_params.var_ = var;
-  input_params.m_ = m;
-  input_params.v_ = v;
-  input_params.lr_ = lr;
-  input_params.beta1_ = beta1;
-  input_params.beta2_ = beta2;
-  input_params.epsilon_ = epsilon;
-  input_params.use_nesterov_ = use_nesterov_;
-  input_params.sparse_grad_ = unique_sparse_grad;
-  input_params.var_first_dim_size_ = var_first_dim_size_;
-  input_params.var_outer_dim_size_ = var_outer_dim_size_;
-  MultiThreadCompute(ComputeLazyAdam, &input_params, unique_sparse_grad.indices_size_);
-  return true;
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h
deleted file mode 100644
index 795568a64d..0000000000
--- a/mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
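
For contrast with SparseApplyAdam above: ComputeLazyAdam fuses decay, moment update and weight step into a single pass over only the rows named in `indices`; rows that never appear keep stale m/v and are not stepped. A hedged per-row sketch, with `LazyAdamRow` and its signature illustrative and `lr` assumed pre-scaled as in `Launch`:

#include <cmath>
#include <cstddef>

// Illustrative fused lazy-Adam update for one deduplicated row (not the API).
void LazyAdamRow(float *var, float *m, float *v, const float *grad_row, size_t row_len,
                 float lr, float beta1, float beta2, float epsilon, bool use_nesterov) {
  for (size_t j = 0; j < row_len; ++j) {
    m[j] = beta1 * m[j] + (1 - beta1) * grad_row[j];
    v[j] = beta2 * v[j] + (1 - beta2) * grad_row[j] * grad_row[j];
    float m_hat = use_nesterov ? m[j] * beta1 + (1 - beta1) * grad_row[j] : m[j];
    var[j] -= lr * m_hat / (std::sqrt(v[j]) + epsilon);
  }
}
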
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyLazyAdamCPUKernel : public CPUKernel { - public: - SparseApplyLazyAdamCPUKernel() = default; - ~SparseApplyLazyAdamCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; - bool use_nesterov_{false}; -}; - -MS_REG_CPU_KERNEL(SparseApplyLazyAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyLazyAdamCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_LAZY_ADAM_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc deleted file mode 100644 index 6069fb708e..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h" -#include "kernel/common_utils.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -namespace { -constexpr size_t kSparseApplyProximalAdagradInputSize = 7; - -void ComputeProximalAdagrad(MultiThreadComputeParams *input_params, size_t start, size_t end) { - MS_EXCEPTION_IF_NULL(input_params); - auto var = input_params->var_; - auto accum = input_params->accum_; - auto lr = input_params->lr_; - auto l1 = input_params->l1_; - auto l2 = input_params->l2_; - auto unique_sparse_grad = input_params->sparse_grad_; - auto var_first_dim_size = input_params->var_first_dim_size_; - auto var_outer_dim_size = input_params->var_outer_dim_size_; - for (size_t i = start; i < end; ++i) { - int index = unique_sparse_grad.indices_[i]; - if (index < 0 || IntToSize(index) >= var_first_dim_size) { - MS_LOG(EXCEPTION) << "Index " << index << " in indices is out of range after unique process"; - } - size_t start_index = var_outer_dim_size * index; - size_t end_index = start_index + var_outer_dim_size; - for (size_t j = start_index, k = var_outer_dim_size * i; j < end_index; ++j, ++k) { - auto summed_grad = unique_sparse_grad.value_[k]; - accum[j] += summed_grad * summed_grad; - auto learning_rate = lr * (1 / std::sqrt(accum[j])); - auto prox_v = var[j]; - prox_v -= summed_grad * learning_rate; - if (l1 > 0) { - var[j] = Sign(prox_v) * std::fmax(std::fabs(prox_v) - learning_rate * l1, static_cast(0.0)) / - (1 + l2 * learning_rate); - } else { - var[j] = prox_v / (1 + l2 * learning_rate); - } - } - } -} -} // namespace - -void SparseApplyProximalAdagradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { - CPUKernel::InitInputOutputSize(kernel_node); - MS_EXCEPTION_IF_NULL(kernel_node); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); -} - -void SparseApplyProximalAdagradCPUKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::vector var_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - std::vector accum_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - std::vector lr_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - std::vector l1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - std::vector l2_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 5); - std::vector indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 6); - if (!IsSameShape(var_shape, accum_shape)) { - MS_LOG(EXCEPTION) << "var and accum should have the same shape"; - } - if (var_shape.empty()) { - MS_LOG(EXCEPTION) << "var must be at least 1D"; - } - var_first_dim_size_ = var_shape[0]; - for (size_t i = 1; i < var_shape.size(); ++i) { - if (var_shape[i] != grad_shape[i]) { - MS_LOG(EXCEPTION) << "The shape of var and grad must equal in dimension " << i; - } - var_outer_dim_size_ *= var_shape[i]; - } - if (indices_shape.size() != 1) { - MS_LOG(EXCEPTION) << "indices must be a 1D vector"; - } - indices_size_ = indices_shape[0]; - if (grad_shape[0] != indices_size_) { - MS_LOG(EXCEPTION) << "The first dimension of grad shape must be equal to indices"; - } - if (!lr_shape.empty()) { - MS_LOG(EXCEPTION) << "lr is not a scalar"; - } - if (!l1_shape.empty()) { - MS_LOG(EXCEPTION) << "l1 is not a scalar"; - } - if (!l2_shape.empty()) { - 
MS_LOG(EXCEPTION) << "l2 is not a scalar"; - } -} - -bool SparseApplyProximalAdagradCPUKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector & /*outputs*/) { - if (inputs.size() < kSparseApplyProximalAdagradInputSize) { - MS_LOG(EXCEPTION) << "Wrong input size!"; - } - - auto var = reinterpret_cast(inputs[0]->addr); - auto accum = reinterpret_cast(inputs[1]->addr); - auto lr = reinterpret_cast(inputs[2]->addr)[0]; - auto l1 = reinterpret_cast(inputs[3]->addr)[0]; - auto l2 = reinterpret_cast(inputs[4]->addr)[0]; - auto grad = reinterpret_cast(inputs[5]->addr); - auto indices = reinterpret_cast(inputs[6]->addr); - auto new_grad = reinterpret_cast(workspace[0]->addr); - auto new_indices = reinterpret_cast(workspace[1]->addr); - SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size_}); - ReduceSparseGradient(SparseGradient({grad, indices, indices_size_}), &unique_sparse_grad, var_first_dim_size_, - var_outer_dim_size_); - - MultiThreadComputeParams input_params; - input_params.var_ = var; - input_params.accum_ = accum; - input_params.lr_ = lr; - input_params.l1_ = l1; - input_params.l2_ = l2; - input_params.sparse_grad_ = unique_sparse_grad; - input_params.var_first_dim_size_ = var_first_dim_size_; - input_params.var_outer_dim_size_ = var_outer_dim_size_; - MultiThreadCompute(ComputeProximalAdagrad, &input_params, unique_sparse_grad.indices_size_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h deleted file mode 100644 index ff7da7966c..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ - -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SparseApplyProximalAdagradCPUKernel : public CPUKernel { - public: - SparseApplyProximalAdagradCPUKernel() = default; - ~SparseApplyProximalAdagradCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - void InitInputOutputSize(const CNodePtr &kernel_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - size_t indices_size_{0}; - size_t var_first_dim_size_{0}; - size_t var_outer_dim_size_{1}; -}; - -MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyProximalAdagradCPUKernel); - -MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SparseApplyProximalAdagradCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SPARSE_APPLY_PROXIMAL_ADAGRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc deleted file mode 100644 index 543f0e5cdd..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include "kernel/cpu/sub_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" - -namespace mindspore { -namespace kernel { -void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - if (shape.size() == 1) { - if (shape[0] != 1) { - MS_LOG(EXCEPTION) << "input 1 only support scalar"; - } - } else { - MS_LOG(EXCEPTION) << "input 1 only support scalar"; - } -} - -void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) { - for (size_t i = 0; i < lens; i++) { - out_addr[i] = in_addr[i] - offset; - } -} - -bool SubCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - auto input_addr = reinterpret_cast(inputs[0]->addr); - auto output_addr = reinterpret_cast(outputs[0]->addr); - offset_ = *reinterpret_cast(inputs[1]->addr); - MS_LOG(INFO) << "offset: " << offset_; - auto lens = inputs[0]->size / sizeof(int); - if (lens < 10000) { - for (size_t i = 0; i < lens; i++) { - output_addr[i] = input_addr[i] - offset_; - } - } else { - const size_t thread_num = 4; - std::thread threads[4]; - size_t process_lens = (lens + thread_num - 1) / thread_num; - size_t process_offset = 0; - for (size_t i = 0; i < thread_num; i++) { - threads[i] = - std::thread(sub_task, input_addr + process_offset, output_addr + process_offset, process_lens, offset_); - if (process_offset + process_lens > lens) { - process_lens = lens - process_offset; - process_offset = lens; - } else { - process_offset += process_lens; - } - } - for (size_t i = 0; i < thread_num; i++) { - threads[i].join(); - } - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "SubscaleCPUKernel, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "SubCPUKernel, used time: " << time << " us"; -#endif - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h deleted file mode 100644 index 54b2c8951a..0000000000 --- a/mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class SubCPUKernel : public CPUKernel { - public: - SubCPUKernel() : offset_(0) {} - ~SubCPUKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - int offset_; -}; - -MS_REG_CPU_KERNEL( - Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SubCPUKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_CPU_SUB_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc b/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc deleted file mode 100644 index f2ac9350cb..0000000000 --- a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/cpu/transpose_cpu_kernel.h" -#include "device/cpu/cpu_device_address.h" -namespace mindspore { -namespace kernel { -const size_t kMaxDim = 100; -void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - axis_ = AnfAlgo::GetNodeAttr>(kernel_node, "perm"); - if (shape_.size() != axis_.size()) { - MS_LOG(EXCEPTION) << "The size of input shape and transpose axis shape must be equal."; - } -} -bool TransposeCPUFwdKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - auto input = reinterpret_cast(inputs[0]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - size_t size = IntToSize(inputs[0]->size / sizeof(float)); - size_t shape_size = IntToSize(shape_.size()); - if (shape_size > kMaxDim) { - MS_LOG(EXCEPTION) << "Input is " << shape_size << "-D, but transpose supports max " << kMaxDim << "-D inputs."; - } - size_t pos_array[kMaxDim]; - size_t size_offset[kMaxDim]; - size_offset[0] = size / shape_[0]; - for (size_t i = 1; i < shape_size; i++) { - size_offset[i] = size_offset[SizeToInt(i) - 1] / shape_[i]; - } - for (size_t position = 0; position < size; position += 1) { - size_t temp_position = position; - pos_array[0] = temp_position / size_offset[0]; - for (size_t i = 1; i < shape_size; i++) { - temp_position -= pos_array[SizeToInt(i) - 1] * size_offset[i - 1]; - pos_array[i] = temp_position / size_offset[i]; - } - size_t new_position = pos_array[axis_[SizeToInt(shape_size) - 1]]; - size_t new_position_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - new_position_size *= shape_[axis_[j + 1]]; - new_position += pos_array[axis_[j]] * new_position_size; - } - output[new_position] = input[position]; - 
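
The index arithmetic in the transpose loop above decomposes each flat input position into per-axis coordinates via the precomputed `size_offset` strides, then reassembles a row-major offset over the permuted shape. A hedged restatement (`PermutedIndex` is an illustrative name):

#include <cstddef>
#include <vector>

// Illustrative equivalent of the kernel's pos_array / new_position math.
size_t PermutedIndex(size_t position, const std::vector<size_t> &shape, const std::vector<int> &perm) {
  size_t n = shape.size();
  std::vector<size_t> coord(n);
  for (size_t i = n; i-- > 0;) {   // peel coordinates from the innermost axis
    coord[i] = position % shape[i];
    position /= shape[i];
  }
  size_t out = 0;
  for (size_t i = 0; i < n; ++i) {  // row-major offset over the permuted shape
    out = out * shape[perm[i]] + coord[perm[i]];
  }
  return out;
}
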
} - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h b/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h deleted file mode 100644 index d882f4fa51..0000000000 --- a/mindspore/ccsrc/kernel/cpu/transpose_cpu_kernel.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ -#include -#include -#include -#include "kernel/cpu/cpu_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" -namespace mindspore { -namespace kernel { -class TransposeCPUFwdKernel : public CPUKernel { - public: - TransposeCPUFwdKernel() = default; - ~TransposeCPUFwdKernel() override = default; - - void InitKernel(const CNodePtr &kernel_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override; - - private: - std::vector shape_; - std::vector axis_; -}; - -MS_REG_CPU_KERNEL(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransposeCPUFwdKernel); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_CPU_TRANSPOSE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc deleted file mode 100644 index 71f612d07c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/arrays/argmax_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), - ArgmaxGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), - ArgmaxGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h deleted file mode 100644 index 3df70d0960..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmax_gpu_kernel.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/argmax_impl.cuh" -namespace mindspore { -namespace kernel { -#define ARGMAX_MAX_DIMENSION 2 -template -class ArgmaxGpuKernel : public GpuKernel { - public: - ArgmaxGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0), batch_size_(0), channel_size_(0), axis_(0) {} - ~ArgmaxGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - int *output = GetDeviceAddress(outputs, 0); - CalArgmax(input, SizeToInt(batch_size_), SizeToInt(channel_size_), axis_, output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but argmax needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but argmax needs 1 output."; - return false; - } - auto output_type = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("output_type")); - if (output_type->type_id() != TypeId::kNumberTypeInt32) { - MS_LOG(EXCEPTION) << "Argmax only supports int32 output type."; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > ARGMAX_MAX_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << input_shape.size() << "-D, but argmax supports max " << ARGMAX_MAX_DIMENSION - << "-D inputs."; - } - - axis_ = GetAttr(kernel_node, "axis"); - if (axis_ < 0) { - axis_ += SizeToInt(input_shape.size()); - } - if (input_shape.size() == 1) { - 
batch_size_ = 0; - channel_size_ = input_shape[0]; - input_size_ = sizeof(T) * channel_size_; - output_size_ = sizeof(int); - } else { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - input_size_ = sizeof(T) * batch_size_ * channel_size_; - output_size_ = (axis_ == 1) ? sizeof(int) * batch_size_ : sizeof(int) * channel_size_; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t batch_size_; - size_t channel_size_; - int axis_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc deleted file mode 100644 index 24c8a9a730..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - ArgMaxWithValue, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - ArgmaxWithValueGpuKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - ArgMaxWithValue, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - ArgmaxWithValueGpuKernel, half, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h deleted file mode 100644 index 304f0ab161..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/argmaxwithvalue_gpu_kernel.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/argmaxwithvalue_impl.cuh" -namespace mindspore { -namespace kernel { -template -class ArgmaxWithValueGpuKernel : public GpuKernel { - public: - ArgmaxWithValueGpuKernel() : input_size_(0), output_size_(0), bound_(0), outerSize_(0), innerSize_(0) {} - ~ArgmaxWithValueGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 1); - S *index = GetDeviceAddress(outputs, 0); - CalArgmaxWithValue(input, bound_, outerSize_, innerSize_, index, output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - std::vector shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1); - int dims = shape.size(); - int axis = GetAttr(kernel_node, "axis"); - if (axis < 0) { - axis += dims; - } - input_size_ = sizeof(T); - for (auto x : shape) { - input_size_ *= x; - } - output_size_ = sizeof(S); - for (auto x : output_shape) { - output_size_ *= x; - } - bound_ = shape[axis]; - outerSize_ = 1; - for (int i = axis - 1; i >= 0; i--) { - outerSize_ *= shape[i]; - } - - innerSize_ = 1; - for (int i = axis + 1; i < dims; i++) { - innerSize_ *= shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_ / sizeof(S) * sizeof(T)); - } - - private: - size_t input_size_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - int bound_; - int outerSize_; - int innerSize_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARGMAXWITHVALUEGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc deleted file mode 100644 index f378604624..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
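
ArgmaxWithValueGpuKernel::Init above flattens the tensor around the reduction axis into [outerSize, bound, innerSize], with bound = shape[axis] and outer/inner the products of the dimensions before and after it. A hedged host-side reference for that factorization (`ArgMaxWithValueRef` is illustrative; the device code lives in argmaxwithvalue_impl.cuh):

#include <cstddef>

// Reference semantics of argmax-with-value over the middle axis of
// an [outer, bound, inner] view of the input.
void ArgMaxWithValueRef(const float *input, int outer, int bound, int inner,
                        int *index, float *value) {
  for (int o = 0; o < outer; ++o) {
    for (int in = 0; in < inner; ++in) {
      int best = 0;
      float best_v = input[(o * bound) * inner + in];
      for (int b = 1; b < bound; ++b) {
        float v = input[(o * bound + b) * inner + in];
        if (v > best_v) { best_v = v; best = b; }
      }
      index[o * inner + in] = best;
      value[o * inner + in] = best_v;
    }
  }
}
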
- */ - -#include "kernel/gpu/arrays/array_reduce_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ArrayReduceGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ArrayReduceGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h deleted file mode 100644 index 4a52439305..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h +++ /dev/null @@ -1,237 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -const std::map kReduceTypeMap = { - {"ReduceMax", CUDNN_REDUCE_TENSOR_MAX}, - {"ReduceMean", CUDNN_REDUCE_TENSOR_AVG}, - {"ReduceSum", CUDNN_REDUCE_TENSOR_ADD}, -}; -template -class ArrayReduceGpuKernel : public GpuKernel { - public: - ArrayReduceGpuKernel() - : cudnn_handle_(nullptr), - reduce_tensor_op_(CUDNN_REDUCE_TENSOR_ADD), - data_type_(CUDNN_DATA_FLOAT), - nan_prop_(CUDNN_NOT_PROPAGATE_NAN), - reduce_indices_(CUDNN_REDUCE_TENSOR_NO_INDICES), - reduce_tensor_descriptor_(nullptr), - inputA_descriptor_(nullptr), - outputC_descriptor_(nullptr), - keep_dims_(false), - all_match_(false), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0) {} - ~ArrayReduceGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - T *workspace_addr = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - if (all_match_) { - MS_LOG(WARNING) - << "The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel."; - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, input_addr, inputs[0]->size, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha, - inputA_descriptor_, input_addr, &beta, outputC_descriptor_, output_addr), - "cudnnReduceTensor failed."); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but reduce op needs 1 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but reduce op needs 1 output."; - return false; - } - int input_dim_length = SizeToInt(AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0).size()); - - if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa() || - AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { - auto attr_axis = GetAttr>(kernel_node, "axis"); - if (attr_axis.empty()) { - axis_.push_back(-1); - } else { - for (auto axis : attr_axis) { - axis < 0 ? axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); - } - } - } else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("axis")->isa()) { - int axis = GetAttr(kernel_node, "axis"); - axis < 0 ? 
axis_.push_back(axis + input_dim_length) : axis_.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Attribute axis type is invalid."; - } - keep_dims_ = GetAttr(kernel_node, "keep_dims"); - - auto inputA_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto outputC_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(inputA_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ArrayReduceGpuKernel input is null"; - InitSizeLists(); - return true; - } - InferInAndOutDesc(inputA_shape, outputC_shape); - InferArrayReduceType(kernel_node); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&reduce_tensor_descriptor_), - "cudnnCreateReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&outputC_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - void InitSizeLists() override { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed."); - input_size_list_.push_back(input_size_); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(outputC_descriptor_, &output_size_), - "cudnnGetTensorSizeInBytes failed."); - output_size_list_.push_back(output_size_); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionWorkspaceSize(cudnn_handle_, reduce_tensor_descriptor_, inputA_descriptor_, outputC_descriptor_, - &workspace_size_), - "cudnnGetReductionWorkspaceSize failed."); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyReduceTensorDescriptor(reduce_tensor_descriptor_), - "cudnnDestroyReduceTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(outputC_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - void InferArrayReduceType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kReduceTypeMap.find(kernel_name); - if (iter == kReduceTypeMap.end()) { - MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported."; - } else { - reduce_tensor_op_ = iter->second; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(reduce_tensor_descriptor_, reduce_tensor_op_, CUDNN_DATA_FLOAT, nan_prop_, - reduce_indices_, CUDNN_32BIT_INDICES), - "cudnnSetReduceTensorDescriptor failed"); - return; - } - void InferInAndOutDesc(const std::vector &input_shape, const std::vector &output_shape) { - std::vector inputA; - std::vector outputC_shape = output_shape; - ShapeNdTo4d(input_shape, &inputA); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0], - inputA[1], inputA[2], inputA[3]), - "cudnnSetTensor4dDescriptor failed"); - - if (axis_[0] == -1) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1), - "cudnnSetTensor4dDescriptor failed"); - if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) { - all_match_ = true; - } - return; - } - if 
(!keep_dims_) { - for (auto i : axis_) { - (void)(outputC_shape.insert(outputC_shape.begin() + i, 1)); - } - } - std::vector outputC; - ShapeNdTo4d(outputC_shape, &outputC); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, - outputC[0], outputC[1], outputC[2], outputC[3]), - "cudnnSetTensor4dDescriptor failed"); - if (inputA == outputC) { - all_match_ = true; - } - return; - } - - cudnnHandle_t cudnn_handle_; - cudnnReduceTensorOp_t reduce_tensor_op_; - cudnnDataType_t data_type_; - cudnnNanPropagation_t nan_prop_; - cudnnReduceTensorIndices_t reduce_indices_; - cudnnReduceTensorDescriptor_t reduce_tensor_descriptor_; - cudnnTensorDescriptor_t inputA_descriptor_; - cudnnTensorDescriptor_t outputC_descriptor_; - - std::vector axis_; - bool keep_dims_; - bool all_match_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ARRAYREDUCE_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc deleted file mode 100644 index 3bca6a69d3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/concatv2_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConcatV2GpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Concat, - KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ConcatV2GpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE( - Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConcatV2GpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h deleted file mode 100644 index a91c50ce69..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
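Note: ArrayReduceGpuKernel above is a thin wrapper over cuDNN's generic reduction API. For readers unfamiliar with that API, here is a minimal standalone sketch of the same call sequence (descriptor setup, reduction, teardown); the function name, the float32/NCHW choices, and the pre-allocated d_workspace are illustrative assumptions, not part of the deleted file. The real kernel additionally sizes the workspace with cudnnGetReductionWorkspaceSize in InitSizeLists().

// Minimal cuDNN reduction sketch (ReduceSum over all axes), assuming a created
// cudnnHandle_t `handle`, device buffers `d_in` (n*c*h*w floats) and `d_out`
// (1 float), and a workspace sized via cudnnGetReductionWorkspaceSize.
// Error handling is elided for brevity.
#include <cudnn.h>

void ReduceSumAll(cudnnHandle_t handle, const float *d_in, float *d_out,
                  int n, int c, int h, int w, void *d_workspace, size_t ws_bytes) {
  cudnnReduceTensorDescriptor_t reduce_desc;
  cudnnTensorDescriptor_t in_desc, out_desc;
  cudnnCreateReduceTensorDescriptor(&reduce_desc);
  cudnnCreateTensorDescriptor(&in_desc);
  cudnnCreateTensorDescriptor(&out_desc);
  cudnnSetReduceTensorDescriptor(reduce_desc, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT,
                                 CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES,
                                 CUDNN_32BIT_INDICES);
  cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w);
  // Reducing an axis means giving it extent 1 in the output descriptor.
  cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, 1);
  const float alpha = 1.0f, beta = 0.0f;
  cudnnReduceTensor(handle, reduce_desc, nullptr, 0, d_workspace, ws_bytes,
                    &alpha, in_desc, d_in, &beta, out_desc, d_out);
  cudnnDestroyReduceTensorDescriptor(reduce_desc);
  cudnnDestroyTensorDescriptor(in_desc);
  cudnnDestroyTensorDescriptor(out_desc);
}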
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc
deleted file mode 100644
index 3bca6a69d3..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/concatv2_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  ConcatV2GpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(Concat,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      ConcatV2GpuFwdKernel, int)
-MS_REG_GPU_KERNEL_ONE(
-  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  ConcatV2GpuFwdKernel, half)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h
deleted file mode 100644
index a91c50ce69..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h
+++ /dev/null
@@ -1,128 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/concatv2_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class ConcatV2GpuFwdKernel : public GpuKernel {
- public:
-  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
-  ~ConcatV2GpuFwdKernel() override = default;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (inputs.size() == 2) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-
-    if (inputs.size() == 3) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-
-    if (inputs.size() == 4) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *input_3 = GetDeviceAddress<T>(inputs, 3);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    if (!CheckParam(kernel_node)) {
-      return false;
-    }
-
-    axis_ = GetAttr<int>(kernel_node, "axis");
-    if (axis_ < 0) {
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-      axis_ += SizeToInt(input_shape.size());
-    }
-
-    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    for (size_t i = 0; i < input_num; i++) {
-      auto input_size = sizeof(T);
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-      for (size_t j = 0; j < input_shape.size(); j++) {
-        input_size *= SizeToInt(input_shape[j]);
-        if (j >= IntToSize(axis_)) {
-          w_[i] *= SizeToInt(input_shape[j]);
-        }
-      }
-      input_size_list_.push_back(input_size);
-    }
-
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    output_size_ = sizeof(T);
-    for (size_t i = 0; i < output_shape.size(); i++) {
-      output_size_ *= output_shape[i];
-    }
-    output_size_list_.push_back(output_size_);
-
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {}
-
- private:
-  bool CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num < 2 || input_num > 4) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs between 2 and 4 inputs.";
-      return false;
-    }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
-      return false;
-    }
-    return true;
-  }
-  int w_[4] = {1, 1, 1, 1};
-  int axis_;
-  size_t output_size_;
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H
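Note: the w_[i] values computed in Init() above are each input's element count from the concat axis inward, so an N-D concat reduces to interleaving fixed-width row segments. A CPU reference with the same decomposition, for the two-input case (a hypothetical helper for illustration, not from the deleted file):

// Reference concat of two tensors along an axis, using the kernel's
// outer-size / inner-width decomposition. `outer` is the product of
// dimensions before the axis; w0/w1 are each input's element count from
// the axis inward (the kernel's w_[0] and w_[1]).
#include <cstddef>

void ConcatRef(const float *in0, const float *in1, float *out,
               size_t outer, size_t w0, size_t w1) {
  for (size_t row = 0; row < outer; ++row) {
    for (size_t j = 0; j < w0; ++j) out[row * (w0 + w1) + j] = in0[row * w0 + j];
    for (size_t j = 0; j < w1; ++j) out[row * (w0 + w1) + w0 + j] = in1[row * w1 + j];
  }
}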
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/gather_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
-  GatherV2,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  GatherGpuFwdKernel, float, int)
-MS_REG_GPU_KERNEL_TWO(
-  GatherV2,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  GatherGpuFwdKernel, half, int)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h
deleted file mode 100644
index 72a05b0915..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/gather_gpu_kernel.h
+++ /dev/null
@@ -1,130 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
-#define MINDSPORE_GATHER_GPU_KERNEL_H
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/gather.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T, typename S>
-class GatherGpuFwdKernel : public GpuKernel {
- public:
-  GatherGpuFwdKernel() : axis_(0), handle_(nullptr) {}
-  ~GatherGpuFwdKernel() = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    VARIABLE_NOT_USED(workspace);
-    T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
-    T *output_addr = GetDeviceAddress<T>(outputs, 0);
-
-    auto input_dim1 = input_shapes_[IntToSize(axis_)];
-    Gather(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
-           reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    InitResource();
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuFwdKernel needs 2 inputs.";
-    }
-    input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-
-    axis_ = GetAttr<int>(kernel_node, "axis");
-    if (axis_ < 0) {
-      axis_ = axis_ + SizeToInt(input_shapes_.size());
-    }
-
-    Reshape();
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitResource() override { handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
-  void InitSizeLists() override {
-    size_t size = GetSize(input_shapes_);
-    input_size_list_.push_back(size);
-
-    size = GetSize(indices_shapes_);
-    input_size_list_.push_back(size);
-
-    size = GetSize(output_shapes_);
-    output_size_list_.push_back(size);
-  }
-
- private:
-  void Reshape() {
-    size_t dim_before_axis = 1;
-    for (size_t i = 0; i < IntToSize(axis_); i++) {
-      dim_before_axis *= output_shapes_[i];
-    }
-
-    size_t dim_of_indices = 1;
-    for (size_t i = 0; i < indices_shapes_.size(); i++) {
-      dim_of_indices *= indices_shapes_[i];
-    }
-
-    size_t dim_after_indices = 1;
-    for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
-      dim_after_indices *= output_shapes_[i];
-    }
-
-    dims_[0] = dim_before_axis;
-    dims_[1] = dim_of_indices;
-    dims_[2] = dim_after_indices;
-    return;
-  }
-  size_t GetSize(const std::vector<size_t> &shape) const {
-    if (shape.size() == 0) {
-      return 0;
-    }
-    size_t result = sizeof(T);
-    for (size_t i = 0; i < shape.size(); i++) {
-      result *= shape[i];
-    }
-    return result;
-  }
-
-  std::vector<size_t> input_shapes_;
-  std::vector<size_t> indices_shapes_;
-  std::vector<size_t> output_shapes_;
-
-  size_t dims_[3] = {};
-  int axis_;
-  cudnnHandle_t handle_;
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_GATHER_GPU_KERNEL_H
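Note: Reshape() above collapses the problem to three extents: dims_[0] rows before the gather axis, dims_[1] gathered indices, and dims_[2] trailing elements per gathered slice. A CPU reference of the GatherV2 contract under that collapse (a hypothetical helper, float/int instantiation assumed):

// Reference GatherV2 with the kernel's (dim_before_axis, dim_of_indices,
// dim_after_indices) collapse: output[i][j][k] = input[i][indices[j]][k],
// where `axis_size` is the input's extent along the gather axis
// (the kernel's input_dim1).
#include <cstddef>

void GatherRef(const float *input, const int *indices, float *output,
               size_t d0, size_t d1, size_t d2, size_t axis_size) {
  for (size_t i = 0; i < d0; ++i)
    for (size_t j = 0; j < d1; ++j)
      for (size_t k = 0; k < d2; ++k)
        output[(i * d1 + j) * d2 + k] = input[(i * axis_size + indices[j]) * d2 + k];
}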
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc
deleted file mode 100644
index 7c160f8f58..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.cc
+++ /dev/null
@@ -1,36 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/one_hot_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_TWO(OneHot,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      OneHotGpuFwdKernel, float, int)
-MS_REG_GPU_KERNEL_TWO(OneHot,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      OneHotGpuFwdKernel, half, int)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h
deleted file mode 100644
index c8b64e7243..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/one_hot_gpu_kernel.h
+++ /dev/null
@@ -1,105 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/one_hot_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T, typename S>
-class OneHotGpuFwdKernel : public GpuKernel {
- public:
-  OneHotGpuFwdKernel() : input_size_(1), output_size_(1), depth_(0), left_dim_size_(1), right_dim_size_(1) {}
-  ~OneHotGpuFwdKernel() = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    VARIABLE_NOT_USED(workspace);
-    const S *indices = GetDeviceAddress<S>(inputs, 0);
-    const T *on_value = GetDeviceAddress<T>(inputs, 1);
-    const T *off_value = GetDeviceAddress<T>(inputs, 2);
-    T *output = GetDeviceAddress<T>(outputs, 0);
-    OneHot(indices, depth_, on_value, off_value, left_dim_size_, right_dim_size_, output,
-           reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    int axis = GetAttr<int>(kernel_node, "axis");
-    auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    int input_size = SizeToInt(input.size());
-    const int default_axis = -1;
-
-    // Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims).
-    for (int i = 0; i < input_size; i++) {
-      auto dim_size = input[IntToSize(i)];
-      if (axis == default_axis || i < axis) {
-        left_dim_size_ *= dim_size;
-      }
-      if (axis != default_axis && i >= axis) {
-        right_dim_size_ *= dim_size;
-      }
-    }
-    for (auto size : input) {
-      input_size_ *= size;
-    }
-    for (auto size : output) {
-      output_size_ *= size;
-    }
-    if (axis >= input_size) {
-      MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size();
-      return false;
-    }
-    if (axis == default_axis) {
-      depth_ = output[output.size() - 1];
-    } else {
-      depth_ = output[IntToSize(axis)];
-    }
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    // inputs: indices, depth
-    input_size_list_.push_back((input_size_ + 1) * sizeof(S));
-    output_size_list_.push_back(output_size_ * sizeof(T));
-  }
-
- private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-
-  size_t input_size_;
-  size_t output_size_;
-
-  size_t depth_;
-  size_t left_dim_size_;
-  size_t right_dim_size_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_ONEHOT_GPU_KERNEL_H
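Note: the same left/depth/right compression drives the OneHot kernel: the output is viewed as [left][depth][right], and position (l, indices[l][r], r) receives on_value while everything else gets off_value. A CPU reference of that contract (a hypothetical helper; the range guard here is an assumption about out-of-range index handling, the deleted kernel's actual behavior lives in one_hot_impl.cuh):

// Reference OneHot under the kernel's (left_dims, depth, right_dims) collapse.
#include <cstddef>

void OneHotRef(const int *indices, float on_value, float off_value, float *output,
               size_t left, size_t depth, size_t right) {
  for (size_t l = 0; l < left; ++l)
    for (size_t d = 0; d < depth; ++d)
      for (size_t r = 0; r < right; ++r) {
        int idx = indices[l * right + r];
        output[(l * depth + d) * right + r] =
            (idx >= 0 && static_cast<size_t>(idx) == d) ? on_value : off_value;
      }
}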
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc
deleted file mode 100644
index 41c9c2243f..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.cc
+++ /dev/null
@@ -1,43 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/select_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Select,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeBool)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      SelectGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Select,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeBool)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      SelectGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(Select,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeBool)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddInputAttr(kNumberTypeInt32)
-                        .AddOutputAttr(kNumberTypeInt32),
-                      SelectGpuKernel, int)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h
deleted file mode 100644
index f1b6c5853a..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/select_gpu_kernel.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/select_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class SelectGpuKernel : public GpuKernel {
- public:
-  SelectGpuKernel() : input_size_(0), output_size_(0) {}
-  ~SelectGpuKernel() override = default;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    bool *input_cond = GetDeviceAddress<bool>(inputs, 0);
-    T *input_x = GetDeviceAddress<T>(inputs, 1);
-    T *input_y = GetDeviceAddress<T>(inputs, 2);
-    T *output = GetDeviceAddress<T>(outputs, 0);
-    CalSelect(output_size_ / sizeof(T), input_cond, input_x, input_y, output,
-              reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-
-  bool Init(const CNodePtr &kernel_node) override {
-    if (!CheckParam(kernel_node)) {
-      return false;
-    }
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    input_size_ = sizeof(bool);
-    output_size_ = sizeof(T);
-    for (size_t x : shape) {
-      input_size_ = input_size_ * x;
-      output_size_ = output_size_ * x;
-    }
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    input_size_list_.push_back(output_size_);
-    input_size_list_.push_back(output_size_);
-    output_size_list_.push_back(output_size_);
-  }
-
- private:
-  bool CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 3) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but SelectGpuKernel needs 3 inputs.";
-      return false;
-    }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but SelectGpuKernel needs 1 output.";
-      return false;
-    }
-    return true;
-  }
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-
-  size_t input_size_;
-  size_t output_size_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SELECT_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc
deleted file mode 100644
index 53161c29c2..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/slice_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      SliceGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      SliceGpuFwdKernel, int)
-MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      SliceGpuFwdKernel, half)
-MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      SliceGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      SliceGpuFwdKernel, half)
-MS_REG_GPU_KERNEL_ONE(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      SliceGpuFwdKernel, int)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
deleted file mode 100644
index 7f71e548ad..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
+++ /dev/null
@@ -1,162 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/slice_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class SliceGpuFwdKernel : public GpuKernel {
- public:
-  SliceGpuFwdKernel()
-      : is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
-  ~SliceGpuFwdKernel() override = default;
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
-    T *input = GetDeviceAddress<T>(inputs, 0);
-    T *output = GetDeviceAddress<T>(outputs, 0);
-    if (is_strided_slice_) {
-      CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3],
-                    input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], input, output,
-                    reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    if (!CheckParam(kernel_node)) {
-      return false;
-    }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    ShapeNdTo4d(input_shape, &input_shape_);
-    auto strides = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides");
-    if (strides) {
-      strides_ = GetAttr<std::vector<int>>(kernel_node, "strides");
-      for (auto i = strides_.size(); i < 4; i++) {
-        (void)strides_.insert(strides_.begin(), 1);
-      }
-      size_ = GetAttr<std::vector<int>>(kernel_node, "end");
-      is_strided_slice_ = true;
-    } else {
-      size_ = GetAttr<std::vector<int>>(kernel_node, "size");
-    }
-    for (auto i = begin_.size(); i < 4; i++) {
-      (void)begin_.insert(begin_.begin(), 0);
-    }
-    for (size_t i = size_.size(); i < 4; i++) {
-      (void)size_.insert(size_.begin(), 1);
-    }
-    for (size_t i = 0; i < begin_.size(); i++) {
-      if (begin_[i] < 0) {
-        begin_[i] = begin_[i] + input_shape_[i];
-      }
-    }
-    for (size_t i = 0; i < size_.size(); i++) {
-      if (size_[i] < 0) {
-        size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
-      }
-      if (begin_[i] == size_[i] && is_strided_slice_) {
-        MS_LOG(WARNING) << "Output is null.";
-        is_null_input_ = true;
-      }
-      if (size_[i] == 0 && strides_[i] > 0) {
-        size_[i] = begin_[i] + 1;
-      }
-    }
-
-    input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T);
-    auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-
-    output_size_ = sizeof(T);
-    for (size_t x : out_shape) {
-      output_size_ = output_size_ * x;
-    }
-
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    output_size_list_.push_back(output_size_);
-  }
-
- private:
-  bool CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but SliceGpuFwdKernel needs 1 input.";
-      return false;
-    }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGpuFwdKernel needs 1 output.";
-      return false;
-    }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (input_shape.size() > 4) {
-      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGpuFwdKernel only supports 4-D or lower.";
-      return false;
-    }
-    if (input_shape.size() == 0) {
-      MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
-      return false;
-    }
-    begin_ = GetAttr<std::vector<int>>(kernel_node, "begin");
-    for (size_t i = 0; i < input_shape.size(); i++) {
-      if ((begin_[i] > 0 && (begin_[i] > SizeToInt(input_shape[i]))) ||
-          (begin_[i] < 0 && (std::abs(begin_[i]) > SizeToInt(input_shape[i])))) {
-        MS_LOG(INFO) << "Input out of bounds " << input_shape[i] << " in axis " << i << ".";
-        begin_[i] = 0;
-      }
-    }
-    return true;
-  }
-  std::vector<int> begin_;
-  std::vector<int> size_;
-  std::vector<int> strides_;
-  std::vector<int> input_shape_;
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-
-  bool is_strided_slice_;
-  bool is_null_input_;
-  size_t input_size_;
-  size_t output_size_;
-  size_t workspace_size_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GPU_KERNEL_H
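Note: most of the work in SliceGpuFwdKernel::Init() above is normalizing the begin/size attributes before launch: pad both on the left to 4-D, wrap negative begins, and clamp negative sizes against the padded input shape. A CPU restatement of just that normalization (a hypothetical helper; the strided-slice special cases, the null-output check and the size==0 adjustment, stay in the kernel above):

#include <vector>

// Mirrors the begin/size normalization in SliceGpuFwdKernel::Init():
// left-pad to rank 4, wrap negative begins, clamp negative sizes to >= 0.
void NormalizeSlice(std::vector<int> *begin, std::vector<int> *size,
                    const std::vector<int> &shape4d) {
  while (begin->size() < 4) begin->insert(begin->begin(), 0);
  while (size->size() < 4) size->insert(size->begin(), 1);
  for (size_t i = 0; i < 4; ++i) {
    if ((*begin)[i] < 0) (*begin)[i] += shape4d[i];
    if ((*size)[i] < 0) {
      int s = (*size)[i] + shape4d[i];
      (*size)[i] = s > 0 ? s : 0;
    }
  }
}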
- */ - -#include "kernel/gpu/arrays/slice_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE( - SliceGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGradGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SliceGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - SliceGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE(StridedSliceGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SliceGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h deleted file mode 100644 index bf24272d93..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_grad_gpu_kernel.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/slice_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SliceGradGpuKernel : public GpuKernel { - public: - SliceGradGpuKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {} - ~SliceGradGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *dy = GetDeviceAddress(inputs, 0); - T *dx = GetDeviceAddress(outputs, 0); - FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast(stream_ptr)); - if (is_strided_slice_) { - CalStridedSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, strides_, dx, - reinterpret_cast(stream_ptr)); - } else { - CalSliceGrad(output_size_ / sizeof(T), dy, input_shape_, begin_, size_, dx, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "StridedSliceGrad") { - is_strided_slice_ = true; - input_shape_ = GetAttr>(kernel_node, "shapex"); - for (auto i = input_shape_.size(); i < 4; i++) { - (void)input_shape_.insert(input_shape_.begin(), 1); - } - strides_ = GetAttr>(kernel_node, "strides"); - for (auto i = strides_.size(); i < 4; i++) { - (void)strides_.insert(strides_.begin(), 1); - } - size_ = GetAttr>(kernel_node, "end"); - } else { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - ShapeNdTo4d(input_shape, &input_shape_); - size_ = GetAttr>(kernel_node, "size"); - } - - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - ShapeNdTo4d(dy_shape, &dy_shape_); - begin_ = GetAttr>(kernel_node, "begin"); - DealParam(); - input_size_ = IntToSize(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3]) * sizeof(T); - - output_size_ = sizeof(T); - for (auto x : dy_shape_) { - output_size_ = output_size_ * IntToSize(x); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(output_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but SliceGradGpuKernel needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() > 4) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", but SliceGradGpuKernel only support 4d or lower."; - return false; - } - if (input_shape.size() == 0) { - MS_LOG(ERROR) << "Input dims is " << input_shape.size() << ", scalar is not supported."; - return false; - } - return true; - } - void DealParam() { - for (auto i = begin_.size(); i < 4; i++) { - (void)begin_.insert(begin_.begin(), 0); - } - for (auto i = 
size_.size(); i < 4; i++) { - (void)size_.insert(size_.begin(), 1); - } - for (size_t i = 0; i < begin_.size(); i++) { - if (begin_[i] < 0) { - begin_[i] = begin_[i] + input_shape_[i]; - } - } - for (size_t i = 0; i < size_.size(); i++) { - if (size_[i] < 0) { - size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0; - } - } - } - std::vector begin_; - std::vector size_; - std::vector strides_; - std::vector input_shape_; - std::vector dy_shape_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - bool is_strided_slice_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_SLICE_GRAD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc deleted file mode 100644 index 338e7a4093..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.cc +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/arrays/transpose_gpu_kernel.h" -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransposeGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - TransposeGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h deleted file mode 100644 index 61be9b68fe..0000000000 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" -namespace mindspore { -namespace kernel { -template -class TransposeGpuFwdKernel : public GpuKernel { - public: - TransposeGpuFwdKernel() : shape_size_(0), input_size_(0), output_size_(0), workspace_size_(0) {} - ~TransposeGpuFwdKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - int *input_shape = GetDeviceAddress(workspace, 0); - int *input_axis = GetDeviceAddress(workspace, 1); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - shape_size_ = input_shape.size(); - if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION - << "-D inputs."; - } - - input_size_ = 1; - for (size_t i = 0; i < shape_size_; i++) { - input_size_ *= input_shape[i]; - input_shape_.push_back(input_shape[i]); - } - input_size_ *= sizeof(T); - output_size_ = input_size_; - auto perm = GetAttr>(kernel_node, "perm"); - for (size_t j = 0; j < perm.size(); j++) { - input_axis_.push_back(perm[j]); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_ = shape_size_ * sizeof(int); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - std::vector input_shape_; - std::vector input_axis_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t shape_size_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_TRANSPOSE_H_ diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc 
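Note: the transpose kernel above ships the shape and permutation to device workspace and computes, per output element, the matching input offset. A CPU reference of that index mapping (a hypothetical helper, float assumed; the deleted CUDA implementation lives in transpose_impl.cuh):

#include <vector>

// Reference transpose: out_shape[i] = shape[perm[i]], and the output index
// along dimension i equals the input index along dimension perm[i].
void TransposeRef(const float *in, float *out, const std::vector<int> &shape,
                  const std::vector<int> &perm) {
  size_t rank = shape.size(), total = 1;
  std::vector<int> out_shape(rank);
  for (size_t i = 0; i < rank; ++i) {
    out_shape[i] = shape[perm[i]];
    total *= shape[i];
  }
  for (size_t o = 0; o < total; ++o) {
    // Decompose the flat output index into per-dimension coordinates.
    size_t rem = o, in_pos = 0;
    std::vector<int> out_idx(rank);
    for (size_t i = rank; i-- > 0;) { out_idx[i] = rem % out_shape[i]; rem /= out_shape[i]; }
    // Re-linearize against the input strides of the permuted dimensions.
    for (size_t i = 0; i < rank; ++i) {
      size_t stride = 1;
      for (size_t j = static_cast<size_t>(perm[i]) + 1; j < rank; ++j) stride *= shape[j];
      in_pos += out_idx[i] * stride;
    }
    out[o] = in[in_pos];
  }
}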
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc
deleted file mode 100644
index 9962d55988..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
-  UnsortedSegmentSum,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  UnsortedSegmentSumGpuKernel, float, int)
-
-MS_REG_GPU_KERNEL_TWO(
-  UnsortedSegmentSum,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  UnsortedSegmentSumGpuKernel, float, int64_t)
-
-MS_REG_GPU_KERNEL_TWO(
-  UnsortedSegmentSum,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  UnsortedSegmentSumGpuKernel, int, int)
-
-MS_REG_GPU_KERNEL_TWO(
-  UnsortedSegmentSum,
-  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  UnsortedSegmentSumGpuKernel, int, int64_t)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h
deleted file mode 100644
index a20375ee29..0000000000
--- a/mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T, typename S>
-class UnsortedSegmentSumGpuKernel : public GpuKernel {
- public:
-  UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
-  ~UnsortedSegmentSumGpuKernel() override = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    S *indices_addr = GetDeviceAddress<S>(inputs, 1);
-    T *output_addr = GetDeviceAddress<T>(outputs, 0);
-
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      cudaMemsetAsync(output_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)),
-      "cudaMemsetAsync failed.");
-    UnsortedSegmentSum(input_dim0_, input_dim1_, output_dim0_, output_dim1_, input_addr, indices_addr, output_addr,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-
-  bool Init(const CNodePtr &kernel_node) override {
-    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-
-    auto axis = ids_shapes.size();
-    for (size_t i = 0; i < input_shapes.size(); i++) {
-      if (i < axis) {
-        input_dim0_ *= input_shapes[i];
-      } else {
-        input_dim1_ *= input_shapes[i];
-      }
-    }
-
-    output_dim0_ = output_shapes[0];
-    for (size_t j = 1; j < output_shapes.size(); j++) {
-      output_dim1_ *= output_shapes[j];
-    }
-
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_dim0_ * input_dim1_ * sizeof(T));
-    input_size_list_.push_back(input_dim0_ * sizeof(S));
-    output_size_list_.push_back(output_dim0_ * output_dim1_ * sizeof(T));
-  }
-
- private:
-  size_t input_dim0_;
-  size_t input_dim1_;
-  size_t output_dim0_;
-  size_t output_dim1_;
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_UNSORT_SEGMENT_SUM_H_
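Note: the kernel above views the input as a 2-D array [input_dim0][input_dim1] (rows indexed like the ids tensor, columns being everything after), zero-fills the output with cudaMemsetAsync, and accumulates row-by-row. A CPU reference of that contract (a hypothetical helper; the guard against out-of-range segment ids is an assumption here, the deleted device logic lives in unsorted_segment_sum.cuh):

#include <cstddef>

// Reference UnsortedSegmentSum: output[ids[i]][j] += input[i][j].
void UnsortedSegmentSumRef(const float *input, const int *ids, float *output,
                           size_t input_dim0, size_t input_dim1, size_t num_segments) {
  for (size_t i = 0; i < num_segments * input_dim1; ++i) output[i] = 0.0f;  // the memset step
  for (size_t i = 0; i < input_dim0; ++i) {
    int seg = ids[i];
    if (seg < 0 || static_cast<size_t>(seg) >= num_segments) continue;
    for (size_t j = 0; j < input_dim1; ++j)
      output[seg * input_dim1 + j] += input[i * input_dim1 + j];
  }
}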
- */ - -#include "kernel/gpu/control/recv_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h deleted file mode 100644 index 12b4eed132..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class RecvGpuKernel : public GpuKernel { - public: - RecvGpuKernel() {} - ~RecvGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &, const std::vector &, const std::vector &, - void *) override { - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - wait_stream_ = reinterpret_cast(GetAttr(kernel_node, "wait_event_stream")); - wait_event_ = reinterpret_cast(GetAttr(kernel_node, "wait_event")); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - return; - } - - private: - cudaStream_t wait_stream_{nullptr}; - cudaEvent_t wait_event_{nullptr}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc deleted file mode 100644 index c417c30bb3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc
deleted file mode 100644
index c417c30bb3..0000000000
--- a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc
+++ /dev/null
@@ -1,23 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/control/send_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h
deleted file mode 100644
index a26e41aa1e..0000000000
--- a/mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-
-namespace mindspore {
-namespace kernel {
-class SendGpuKernel : public GpuKernel {
- public:
-  SendGpuKernel() {}
-  ~SendGpuKernel() override = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
-              void *) override {
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed.");
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
-    record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-    return;
-  }
-
- private:
-  cudaStream_t record_stream_{nullptr};
-  cudaEvent_t record_event_{nullptr};
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
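Send and Recv above carry no data; together they split one CUDA event across two streams: Send records the event on the producer stream, Recv makes the consumer stream wait on it. A minimal host-side sketch of the same pattern (assumed usage; buf and size are placeholders):

cudaStream_t producer, consumer;
cudaEvent_t ev;
cudaStreamCreate(&producer);
cudaStreamCreate(&consumer);
cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);

cudaMemsetAsync(buf, 0, size, producer);  // work queued on the producer stream
cudaEventRecord(ev, producer);            // what SendGpuKernel::Launch does
cudaStreamWaitEvent(consumer, ev, 0);     // what RecvGpuKernel::Launch does
cudaMemsetAsync(buf, 1, size, consumer);  // now ordered after the producer work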
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu
deleted file mode 100644
index 3ec63ee03a..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/cuda_impl/adam_impl.cuh"
-
-template <typename T>
-__device__ __forceinline__ T SqrtFunc(T input) {
-  return sqrt(input);
-}
-
-template <>
-__device__ __forceinline__ half SqrtFunc(half input) {
-  return hsqrt(input);
-}
-
-template <typename T>
-__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
-                                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable,
-                                T *m, T *v) {
-  const T one = static_cast<T>(1.0);
-  const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]);
-
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
-    m[i] += (gradient[i] - m[i]) * (one - beta1[0]);
-    v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]);
-    variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]);
-  }
-}
-
-template <typename T>
-void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
-               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
-               cudaStream_t cuda_stream) {
-  ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
-    size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
-}
-
-template void ApplyAdam(const size_t size, const float *gradient, const float *beta1_power,
-                        const float *beta2_power, const float *learning_rate, const float *beta1,
-                        const float *beta2, const float *epsilon, float *variable, float *m, float *v,
-                        cudaStream_t cuda_stream);
-template void ApplyAdam(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power,
-                        const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon,
-                        half *variable, half *m, half *v, cudaStream_t cuda_stream);
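For reference, the update fused in ApplyAdamKernel is the standard Adam step with the bias correction folded into the learning rate:

\[
\hat{\alpha} = \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}, \qquad
m \leftarrow m + (1 - \beta_1)(g - m), \qquad
v \leftarrow v + (1 - \beta_2)(g^2 - v), \qquad
\theta \leftarrow \theta - \hat{\alpha}\,\frac{m}{\sqrt{v} + \epsilon}
\]

where \(g\) is the gradient and \(\beta_1^t, \beta_2^t\) are the power terms passed in as beta1_power and beta2_power.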
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
-
-#include "device/gpu/cuda_common.h"
-template <typename T>
-void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
-               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
-               cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu
deleted file mode 100644
index dfadaa09d6..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_weight_decay_impl.cu
+++ /dev/null
@@ -1,50 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "adam_weight_decay_impl.cuh"
-#include "device/gpu/cuda_common.h"
-
-template <typename T>
-__global__ void AdamWeightDecayKernel(const int element_num_, const bool need_decay, const float *beta1,
-                                      const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
-                                      const float *epsilon, const float *lr, const float *weight_decay, T *m, T *v,
-                                      T *param, T *gradient) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num_; i += blockDim.x * gridDim.x) {
-    float next_m = beta1[0] * m[i] + one_sub_beta1[0] * gradient[i];
-    float next_v = beta2[0] * v[i] + one_sub_beta2[0] * gradient[i] * gradient[i];
-    float update = next_m / (sqrt(next_v) + epsilon[0]);
-    if (need_decay && weight_decay != nullptr) {
-      update += weight_decay[0] * param[i];
-    }
-    param[i] -= lr[0] * update;
-    m[i] = next_m;
-    v[i] = next_v;
-  }
-}
-
-template <typename T>
-void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1, const float *one_sub_beta1,
-                     const float *beta2, const float *one_sub_beta2, const float *epsilon, const float *lr,
-                     const float *weight_decay, T *m, T *v, T *param, T *gradient, cudaStream_t stream) {
-  AdamWeightDecayKernel<<<GET_BLOCKS(element_num_), GET_THREADS, 0, stream>>>(
-    element_num_, need_decay, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v, param,
-    gradient);
-}
-
-template void AdamWeightDecay(const int &element_num_, const bool &need_decay, const float *beta1,
-                              const float *one_sub_beta1, const float *beta2, const float *one_sub_beta2,
-                              const float *epsilon, const float *lr, const float *weight_decay, float *m, float *v,
-                              float *param, float *gradient, cudaStream_t stream);
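AdamWeightDecayKernel differs from ApplyAdamKernel in two ways: the caller passes precomputed \(1 - \beta\) terms, and weight decay is decoupled, i.e. added after the adaptive step rather than folded into the gradient:

\[
m' = \beta_1 m + (1 - \beta_1) g, \qquad
v' = \beta_2 v + (1 - \beta_2) g^2, \qquad
\theta \leftarrow \theta - \alpha \left( \frac{m'}{\sqrt{v'} + \epsilon} + \lambda\,\theta \right)
\]

with the \(\lambda\,\theta\) term applied only when need_decay is set and a weight_decay pointer is supplied.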
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu
deleted file mode 100755
index e8fab27dda..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/argmax_impl.cu
+++ /dev/null
@@ -1,88 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "argmax_impl.cuh"
-#include "device/gpu/cuda_common.h"
-#include "include/cuda_fp16.h"
-template <typename T>
-__global__ void Argmax1D(const T* input, const int channel_size, int* output) {
-  int max_index = 0;
-  T max = input[0];
-  for (int pos = 1; pos < channel_size; pos++) {
-    if (max < input[pos]) {
-      max = input[pos];
-      max_index = pos;
-    }
-  }
-  output[0] = max_index;
-  return;
-}
-template <typename T>
-__global__ void ArgmaxDefault2D(const T* input, const int batch_size, const int channel_size, int* output) {
-  int pos;
-  int max_index;
-  T max;
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += blockDim.x * gridDim.x) {
-    max = input[i * channel_size];
-    max_index = 0;
-    for (int j = 1; j < channel_size; j++) {
-      pos = i * channel_size + j;
-      if (max < input[pos]) {
-        max = input[pos];
-        max_index = j;
-      }
-    }
-
-    output[i] = max_index;
-  }
-  return;
-}
-template <typename T>
-__global__ void ArgmaxAxis2D(const T* input, const int batch_size, const int channel_size, int* output) {
-  int pos;
-  int max_index;
-  T max;
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
-    max = input[i];
-    max_index = 0;
-    for (int j = 1; j < batch_size; j++) {
-      pos = j * channel_size + i;
-      if (max < input[pos]) {
-        max = input[pos];
-        max_index = j;
-      }
-    }
-    output[i] = max_index;
-  }
-  return;
-}
-template <typename T>
-void CalArgmax(const T* input, const int batch_size, const int channel_size, const int axis, int* output,
-               cudaStream_t cuda_stream) {
-  if (batch_size == 0) {
-    Argmax1D<<<1, 1, 0, cuda_stream>>>(input, channel_size, output);
-  } else if (axis == 1) {
-    ArgmaxDefault2D<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
-  } else {
-    ArgmaxAxis2D<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(input, batch_size, channel_size, output);
-  }
-  return;
-}
-
-template void CalArgmax(const float* input, const int batch_size, const int channel_size, const int axis,
-                        int* output, cudaStream_t cuda_stream);
-template void CalArgmax(const half* input, const int batch_size, const int channel_size, const int axis,
-                        int* output, cudaStream_t cuda_stream);
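CalArgmax folds any input rank into a 2-D [batch_size, channel_size] view: batch_size == 0 selects the single-thread 1-D path, axis == 1 reduces over channels per row, and any other axis reduces over rows per column. A hypothetical call reducing a 32x10 batch of logits over the class dimension (logits, preds are device pointers, stream is a cudaStream_t):

CalArgmax(logits, 32, 10, 1, preds, stream);  // preds[i] = argmax_j logits[i][j]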
- */
-
-#include "argmaxwithvalue_impl.cuh"
-#include "device/gpu/cuda_common.h"
-#include "include/cuda_fp16.h"
-template <typename T, typename S>
-__global__ void ArgmaxWithValue(const T* input, const int bound, int outerSize, int innerSize, S* index,
-                                T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (outerSize); pos += blockDim.x * gridDim.x) {
-    int inputOutterOffset = pos * innerSize * bound;
-    int outputOutterOffset = pos * innerSize;
-    for (int j = 0; j < innerSize; j++) {
-      auto outputInnerOffset = outputOutterOffset + j;
-      S idx = 0;
-      T maxData = input[j + inputOutterOffset];
-      for (S c = 0; c < bound; c++) {
-        int offset = j + c * innerSize;
-        auto inputData = input[inputOutterOffset + offset];
-        idx = inputData > maxData ? c : idx;
-        maxData = inputData > maxData ? inputData : maxData;
-      }
-      output[outputInnerOffset] = maxData;
-      index[outputInnerOffset] = idx;
-    }
-  }
-  return;
-}
-
-template <typename T, typename S>
-void CalArgmaxWithValue(const T* input, const int bound_, const int outerSize_, const int innerSize_,
-                        S* index, T* output, cudaStream_t cuda_stream) {
-  ArgmaxWithValue<<<GET_BLOCKS(outerSize_), GET_THREADS, 0, cuda_stream>>>(input, bound_, outerSize_, innerSize_,
-                                                                           index, output);
-  return;
-}
-
-template void CalArgmaxWithValue(const float* input, const int bound_, const int outerSize_,
-                                 const int innerSize_, int* index, float* output,
-                                 cudaStream_t cuda_stream);
-template void CalArgmaxWithValue(const half* input, const int bound_, const int outerSize_,
-                                 const int innerSize_, int* index, half* output,
-                                 cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu
deleted file mode 100644
index d44ad99202..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/assign_add_impl.cu
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "assign_add_impl.cuh"
-#include "device/gpu/cuda_common.h"
-#include "include/cuda_fp16.h"
-template <typename T>
-__global__ void AssignAdd(const size_t size, T* ref, const T* value, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    output[pos] = ref[pos] + value[pos];
-    ref[pos] = output[pos];
-  }
-  return;
-}
-
-template <typename T>
-void CalAssignAdd(const size_t size, T* ref, const T* value, T* output, cudaStream_t cuda_stream) {
-  AssignAdd<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, ref, value, output);
-
-  return;
-}
-
-template void CalAssignAdd(const size_t size, float* ref, const float* value, float* output,
-                           cudaStream_t cuda_stream);
-template void CalAssignAdd(const size_t size, half* ref, const half* value, half* output,
-                           cudaStream_t cuda_stream);
-template void CalAssignAdd(const size_t size, int* ref, const int* value, int* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh
deleted file mode 100644
index c3ce08dfd0..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
-
-#include "device/gpu/cuda_common.h"
-template <typename T>
-void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
-                           const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
-                           size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
-template <typename T>
-void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
-                                    const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
-                                    T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
-template <typename T>
-void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
-                                 const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
-                                 T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
-template <typename T>
-void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
-                              size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
-
-template <typename T>
-void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H,
-                                         size_t W, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu
deleted file mode 100755
index ddc2803f56..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/batchnorm_fold_impl.cu
+++ /dev/null
@@ -1,88 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <thrust/device_ptr.h>
-#include <thrust/fill.h>
-#include <thrust/execution_policy.h>
-#include "batchnorm_fold_impl.cuh"
-#include "device/gpu/cuda_common.h"
-
-template <typename T>
-__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
-    running_std[i] = sqrtf(running_std[i] + epsilon);
-  }
-  return;
-}
-
-template <typename T>
-__global__ void UpdateBatchStd(int channel_size, T* batch_std) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
-    batch_std[i] = 1 / batch_std[i];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std,
-                      int batch_size, int channel_size, int height, int width, T* dx) {
-  int n = batch_size * channel_size * height * width;
-  int normal_size = batch_size * height * width;
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
-    int channel_index = i / (height * width) % channel_size;
-    dx[i] = d_batch_mean[channel_index] / normal_size +
-            d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size;
-  }
-  return;
-}
-
-template <typename T>
-void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) {
-  UpdateRunningStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, epsilon, running_std);
-  return;
-}
-
-template void CalUpdateRunningStd(int channel_size, double epsilon, float* running_std,
-                                  cudaStream_t cuda_stream);
-
-template <typename T>
-void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) {
-  UpdateBatchStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, batch_std);
-  return;
-}
-
-template void CalUpdateBatchStd(int channel_size, float* batch_std, cudaStream_t cuda_stream);
-
-template <typename T>
-void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean,
-                          const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx,
-                          cudaStream_t cuda_stream) {
-  CalDx<<<GET_BLOCKS(batch_size * channel_size * height * width), GET_THREADS, 0, cuda_stream>>>(
-    d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx);
-}
-
-template void CalBatchNormFoldGrad(const float* d_batch_mean, const float* d_batch_std, const float* x,
-                                   const float* batch_mean, const float* batch_std, int batch_size,
-                                   int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream);
-
-template <typename T>
-void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) {
-  thrust::device_ptr<T> dev_ptr(array);
-  thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill);
-}
-
-template void ThrustFillWith(float* array, int size, float tofill, cudaStream_t cuda_stream);
-
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu
deleted file mode 100644
index 5aa087e7f5..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cu
+++ /dev/null
@@ -1,122 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh"
-#include "device/gpu/cuda_common.h"
-
-template <typename T>
-struct MinimumGradFunc {
-  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
-    if (x1 < x2) {
-      atomicAdd(dx1, dy);
-    } else {
-      atomicAdd(dx2, dy);
-    }
-  }
-};
-
-template <typename T>
-struct MaximumGradFunc {
-  __device__ __forceinline__ void operator()(const T &x1, const T &x2, const T &dy, T *dx1, T *dx2) {
-    if (x1 > x2) {
-      atomicAdd(dx1, dy);
-    } else {
-      atomicAdd(dx2, dy);
-    }
-  }
-};
-
-__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
-
-template <typename T, typename Func>
-__device__ __forceinline__ void BroadcastGradOperator(const int &l0, const int &l1, const int &l2, const int &l3,
-                                                      const int &r0, const int &r1, const int &r2, const int &r3,
-                                                      const int &d0, const int &d1, const int &d2, const int &d3,
-                                                      const T *x1, const T *x2, const T *dy, T *dx1, T *dx2) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
-    int i = pos / (d1 * d2 * d3) % d0;
-    int j = pos / (d2 * d3) % d1;
-    int k = pos / d3 % d2;
-    int l = pos % d3;
-
-    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
-    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
-    Func()(x1[l_index], x2[r_index], dy[pos], dx1 + l_index, dx2 + r_index);
-  }
-}
-
-template <typename T>
-__global__ void BroadcastGradKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
-                                    const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
-                                    enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                                    T *dx2) {
-  switch (op) {
-    case BROADCAST_GRAD_TYPE_MINIMUM:
-      return BroadcastGradOperator<T, MinimumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
-                                                          dx1, dx2);
-    case BROADCAST_GRAD_TYPE_MAXIMUM:
-      return BroadcastGradOperator<T, MaximumGradFunc<T>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, x1, x2, dy,
-                                                          dx1, dx2);
-  }
-}
-
-template <typename T>
-void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
-                   cudaStream_t stream) {
-  int size = d0 * d1 * d2 * d3;
-  BroadcastGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
-                                                                    x1, x2, dy, dx1, dx2);
-}
-
-template <typename T, typename Func>
-__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *x1, const T *x2, const T *dy, T *dx1,
-                                                    T *dx2) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
-    Func()(x1[pos], x2[pos], dy[pos], dx1 + pos, dx2 + pos);
-  }
-}
-
-template <typename T>
-__global__ void NoBroadcastGradKernel(const int nums, enum BroadcastGradOpType op, const T *x1, const T *x2,
-                                      const T *dy, T *dx1, T *dx2) {
-  switch (op) {
-    case BROADCAST_GRAD_TYPE_MINIMUM:
-      return NoBroadcastOperator<T, MinimumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
-    case BROADCAST_GRAD_TYPE_MAXIMUM:
-      return NoBroadcastOperator<T, MaximumGradFunc<T>>(nums, x1, x2, dy, dx1, dx2);
-  }
-}
-
-template <typename T>
-void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                     T *dx2, cudaStream_t stream) {
-  NoBroadcastGradKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, x1, x2, dy, dx1, dx2);
-}
-
-template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const float *x1, const float *x2,
-                              const float *dy, float *dx1, float *dx2, cudaStream_t stream);
-template void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const int *x1, const int *x2,
-                              const int *dy, int *dx1, int *dx2, cudaStream_t stream);
-template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                            enum BroadcastGradOpType op, const float *x1, const float *x2, const float *dy, float *dx1,
-                            float *dx2, cudaStream_t stream);
-template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                            enum BroadcastGradOpType op, const int *x1, const int *x2, const int *dy, int *dx1,
-                            int *dx2, cudaStream_t stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh
deleted file mode 100644
index d154eddd4c..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_grad_impl.cuh
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
-
-#include "device/gpu/cuda_common.h"
-
-enum BroadcastGradOpType {
-  BROADCAST_GRAD_TYPE_MAXIMUM = 0,
-  BROADCAST_GRAD_TYPE_MINIMUM = 1,
-  BROADCAST_GRAD_TYPE_INVALID = 0xffffffff,
-};
-
-template <typename T>
-void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                   const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                   enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1, T *dx2,
-                   cudaStream_t stream);
-
-template <typename T>
-void NoBroadcastGrad(const int &nums, enum BroadcastGradOpType op, const T *x1, const T *x2, const T *dy, T *dx1,
-                     T *dx2, cudaStream_t stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_GRAD_H_
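The two functors route the incoming gradient to whichever operand won the elementwise min/max, and they accumulate with atomicAdd because, under broadcasting, several output positions can map onto the same input element. A worked micro-example for BROADCAST_GRAD_TYPE_MINIMUM with x1 = [1, 5] and x2 = [3] broadcast to length 2: at position 0, x1[0] < x2[0], so dy[0] flows into dx1[0]; at position 1, x2[0] < x1[1], so dy[1] flows into dx2[0], the single x2 element both positions share. Only float and int are instantiated, consistent with the types atomicAdd supports here.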
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu
deleted file mode 100644
index afa94fc56c..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu
+++ /dev/null
@@ -1,208 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/cuda_impl/broadcast_impl.cuh"
-#include "device/gpu/cuda_common.h"
-
-template <typename T, typename S>
-struct GreaterFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
-};
-
-template <typename T, typename S>
-struct LessFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
-};
-
-template <typename T, typename S>
-struct MinimumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
-};
-
-template <typename T, typename S>
-struct MaximumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
-};
-
-template <typename T, typename S>
-struct PowerFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
-};
-
-template <>
-struct PowerFunc<half, half> {
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
-    return __float2half(pow(__half2float(lhs), __half2float(rhs)));
-  }
-};
-
-template <typename T, typename S>
-struct RealDivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
-};
-
-template <typename T, typename S>
-struct MulFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
-};
-
-template <typename T, typename S>
-struct SubFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
-};
-
-template <typename T, typename S>
-struct AddFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
-};
-
-template <>
-struct PowerFunc<half, bool> {
-  // invalid branch
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
-};
-
-__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
-
-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
-                                                  const int &r0, const int &r1, const int &r2, const int &r3,
-                                                  const int &d0, const int &d1, const int &d2, const int &d3,
-                                                  const T *input0, const T *input1, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
-    int i = pos / (d1 * d2 * d3) % d0;
-    int j = pos / (d2 * d3) % d1;
-    int k = pos / d3 % d2;
-    int l = pos % d3;
-
-    int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
-    int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
-    output[pos] = Func()(input0[l_index], input1[r_index]);
-  }
-}
-
-template <typename T, typename S>
-__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
-                                const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
-                                enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
-  switch (op) {
-    case BROADCAST_TYPE_GREATER:
-      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
-    case BROADCAST_TYPE_LESS:
-      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                     output);
-    case BROADCAST_TYPE_MINIMUM:
-      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
-    case BROADCAST_TYPE_MAXIMUM:
-      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
-    case BROADCAST_TYPE_POWER:
-      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                      output);
-    case BROADCAST_TYPE_REALDIV:
-      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                        output);
-    case BROADCAST_TYPE_MUL:
-      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
-    case BROADCAST_TYPE_SUB:
-      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
-    case BROADCAST_TYPE_ADD:
-      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
-                                                    output);
-  }
-}
-
-template <typename T, typename S>
-void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
-               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
-               const T *input0, const T *input1, S *output, cudaStream_t stream) {
-  int size = d0 * d1 * d2 * d3;
-  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
-                                                                input0, input1, output);
-}
-
-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
-    output[pos] = Func()(input0[pos], input1[pos]);
-  }
-}
-
-template <typename T, typename S>
-__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
-                                  S *output) {
-  switch (op) {
-    case BROADCAST_TYPE_GREATER:
-      return NoBroadcastOperator<T, S, GreaterFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_LESS:
-      return NoBroadcastOperator<T, S, LessFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MINIMUM:
-      return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MAXIMUM:
-      return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_POWER:
-      return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_REALDIV:
-      return NoBroadcastOperator<T, S, RealDivFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MUL:
-      return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_SUB:
-      return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_ADD:
-      return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
-  }
-}
-
-template <typename T, typename S>
-void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream) {
-  NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
-}
-
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const float *input0, const float *input1, float *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const half *input0, const half *input1, half *output,
-                        cudaStream_t stream);
-template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
-                        const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
-                        enum BroadcastOpType op, const int *input0, const int *input1, int *output,
-                        cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          float *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          half *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
-                          int *output, cudaStream_t stream);
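Broadcast hard-codes rank 4: l*, r*, and d* are the padded shapes of the two inputs and the output, and Index() collapses an output coordinate back onto a size-1 input dimension. A hypothetical call adding a [1,3,1,4] tensor to a [2,1,5,1] tensor into a [2,3,5,4] output (in0, in1, out are device pointers):

Broadcast(1, 3, 1, 4,   // input0 shape
          2, 1, 5, 1,   // input1 shape
          2, 3, 5, 4,   // output shape
          BROADCAST_TYPE_ADD, in0, in1, out, stream);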
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh
deleted file mode 100644
index 5f6992511d..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh
+++ /dev/null
@@ -1,44 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
-
-#include "device/gpu/cuda_common.h"
-
-enum BroadcastOpType {
-  BROADCAST_TYPE_GREATER = 0,
-  BROADCAST_TYPE_LESS = 1,
-  BROADCAST_TYPE_MAXIMUM = 2,
-  BROADCAST_TYPE_MINIMUM = 3,
-  BROADCAST_TYPE_POWER = 4,
-  BROADCAST_TYPE_REALDIV = 5,
-  BROADCAST_TYPE_MUL = 6,
-  BROADCAST_TYPE_SUB = 7,
-  BROADCAST_TYPE_ADD = 8,
-  BROADCAST_TYPE_INVALID = 0xffffffff,
-};
-
-template <typename T, typename S>
-void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
-               const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
-               const T *input0, const T *input1, S *output, cudaStream_t stream);
-
-template <typename T, typename S>
-void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu
deleted file mode 100755
index 5cccf183ea..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu
+++ /dev/null
@@ -1,108 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <stdio.h>
-#include <stdint.h>
-#include <cuda_runtime.h>
-#include "kernel/gpu/cuda_impl/concatv2_impl.cuh"
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2);
-    int m = pos % (w1 + w2);
-    output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
-                       const T* input_1, const T* input_2, const T* input_3, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3);
-    int m = pos % (w1 + w2 + w3);
-    output[pos] = m < w1 ? input_1[n * w1 + m] :
-                  m < w1 + w2 ? input_2[n * w2 + m - w1] :
-                  input_3[n * w3 + m - w1 - w2];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                       const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3 + w4);
-    int m = pos % (w1 + w2 + w3 + w4);
-    output[pos] = m < w1 ? input_1[n * w1 + m] :
-                  m < w1 + w2 ? input_2[n * w2 + m - w1]:
-                  m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]:
-                  input_4[n * w4 + m - w1 - w2 - w3];
-  }
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                  const T* input_1, const T* input_2, const T* input_3, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
-                                                            input_2, input_3, input_4, output);
-  return;
-}
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
-                           half* output, cudaStream_t cuda_stream);
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const float* input_1, const float* input_2, const float* input_3,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const int* input_1, const int* input_2, const int* input_3,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                           const half* input_1, const half* input_2, const half* input_3,
-                           half* output, cudaStream_t cuda_stream);
-
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const float* input_1, const float* input_2, const float* input_3, const float* input_4,
-                           float* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const int* input_1, const int* input_2, const int* input_3, const int* input_4,
-                           int* output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const half* input_1, const half* input_2, const half* input_3, const half* input_4,
-                           half* output, cudaStream_t cuda_stream);
-
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh
deleted file mode 100755
index b6932aa4a1..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
-
-#include "device/gpu/cuda_common.h"
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
-                  cudaStream_t cuda_stream);
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
-                  const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                  const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
-                  cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
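ConcatKernel views every input as [n, w_i], where w_i is the product of the dimensions at and after the concat axis, so concatenation reduces to interleaving row blocks; only the two-, three-, and four-input arities are provided. A hypothetical call joining an [8, 16] and an [8, 24] tensor along the last axis (in0, in1, out are device pointers):

ConcatKernel(8 * (16 + 24), 16, 24, in0, in1, out, stream);  // out holds 8 * 40 elements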
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu
deleted file mode 100755
index ac2f99ed9a..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/correction_mul_impl.cu
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <thrust/reduce.h>
-#include "correction_mul_impl.cuh"
-#include "device/gpu/cuda_common.h"
-
-template <typename T>
-__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw,
-                              T* output) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) {
-    int n = i / chw;
-    output[i] = weight[i] * gamma[n] / running_std[n];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Mul(int N, const T* a, const T* b, T* c) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
-    c[i] = a[i] * b[i];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
-    d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus<T>());
-    d_gamma[i] = d_gamma[i] / running_std[i];
-  }
-  return;
-}
-
-template <typename T>
-void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output,
-                      cudaStream_t cuda_stream) {
-  CorrectionMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(weight, gamma, running_std, N, C * H * W,
-                                                                            output);
-}
-
-template void CalCorrectionMul(const float* weight, const float* gamma, const float* running_std, int N, int C,
-                               int H, int W, float* output, cudaStream_t cuda_stream);
-
-template <typename T>
-void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma,
-                          T* tmp, cudaStream_t cuda_stream) {
-  Mul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N * C * H * W, d_out, weight, tmp);
-  Reduce<<<GET_BLOCKS(N), GET_THREADS, 0, cuda_stream>>>(N, C * H * W, tmp, running_std, d_gamma);
-}
-
-template void CalCorrectionMulGrad(const float* d_out, const float* weight, const float* running_std, int N,
-                                   int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream);
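In batch-norm folding terms, CorrectionMul scales each weight slice by its channel's \(\gamma/\sigma\), and the backward pass reduces d_out times weight per channel before dividing by \(\sigma\):

\[
\text{output}_{c,j} = w_{c,j} \cdot \frac{\gamma_c}{\sigma_c}, \qquad
\frac{\partial L}{\partial \gamma_c} = \frac{1}{\sigma_c} \sum_j \frac{\partial L}{\partial \text{output}_{c,j}}\, w_{c,j}
\]

where \(c\) is the leading index (n = i / chw in the kernel) and \(\sigma\) is running_std.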
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
deleted file mode 100644
index 54ae072892..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
-
-#include "device/gpu/cuda_common.h"
-
-template <typename T, typename S>
-void CrossEntropyWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *loss,
-                            cudaStream_t cuda_stream);
-
-template <typename T, typename S>
-void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                                T *grad, cudaStream_t cuda_stream);
-
-template <typename T, typename S>
-void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
-                  T *dlogits, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
deleted file mode 100644
index f89d42ce49..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/dropout_impl.cuh
+++ /dev/null
@@ -1,27 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
-
-#include "device/gpu/cuda_common.h"
-template <typename T>
-void DropoutForward(const T *input, T *mask, T *output, float *mask_f, size_t num_count, float keep_prob,
-                    cudaStream_t cuda_stream);
-template <typename T>
-void DropoutBackward(const T *dy, const T *mask, T *dx, size_t num_count, float keep_prob, cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu
deleted file mode 100755
index 38dd79c441..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/equalcount_impl.cu
+++ /dev/null
@@ -1,43 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "equalcount_impl.cuh"
-#include "device/gpu/cuda_common.h"
-template <typename T>
-__global__ void EqualCount(const int size, const T* input1, const T* input2, T* output) {
-  T equal_count = 0;
-
-  for (int i = 0; i < size; i++) {
-    if (input1[i] == input2[i]) {
-      equal_count++;
-    }
-  }
-
-  output[0] = equal_count;
-  return;
-}
-template <typename T>
-void CalEqualCount(const int size, const T* input1, const T* input2, T* output, cudaStream_t cuda_stream) {
-  EqualCount<<<1, 1, 0, cuda_stream>>>(size, input1, input2, output);
-  return;
-}
-
-template void CalEqualCount(const int size, const int* input1, const int* input2, int* output,
-                            cudaStream_t cuda_stream);
-template void CalEqualCount(const int size, const float* input1, const float* input2, float* output,
-                            cudaStream_t cuda_stream);
-template void CalEqualCount(const int size, const half* input1, const half* input2, half* output,
-                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
deleted file mode 100644
index ad2e387b08..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
+++ /dev/null
@@ -1,34 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
-
-#include "device/gpu/cuda_common.h"
-
-void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
-                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
-                        cudaStream_t cuda_stream);
-
-void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
-                            const float *nudge_min, const float *nudge_max, const float *scale,
-                            cudaStream_t cuda_stream);
-
-void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
-                                const int channel_num, const float *nudge_min, const float *nudge_max,
-                                cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
deleted file mode 100644
index dda95ed781..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
+++ /dev/null
@@ -1,31 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ - -#include "device/gpu/cuda_common.h" - -void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream); - -void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale, cudaStream_t cuda_stream); - -void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu deleted file mode 100644 index c2fd5ecd70..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
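The CalNudgePerLayer/CalNudgePerChannel launchers declared above turn raw observed min/max into "nudged" bounds plus a scale before fake quantization. Their kernel bodies are not part of this hunk, so the host-side sketch below is an assumption about the usual nudging scheme these signatures suggest (snap the zero point onto the integer grid), not a copy of the deleted implementation; the `symmetric` flag is not modeled.

#include <cmath>

void NudgeMinMax(float min, float max, float quant_min, float quant_max,
                 float *nudge_min, float *nudge_max, float *scale) {
  *scale = (max - min) / (quant_max - quant_min);
  const float zp_from_min = quant_min - min / *scale;
  float nudged_zp;  // clamp the zero point into the representable range, then round it
  if (zp_from_min <= quant_min) {
    nudged_zp = quant_min;
  } else if (zp_from_min >= quant_max) {
    nudged_zp = quant_max;
  } else {
    nudged_zp = std::round(zp_from_min);
  }
  *nudge_min = (quant_min - nudged_zp) * (*scale);
  *nudge_max = (quant_max - nudged_zp) * (*scale);
}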
- */
-
-#include "include/cuda_runtime.h"
-#include "kernel/gpu/cuda_impl/float_status_impl.cuh"
-
-template <typename T>
-__global__ void IsNan(const size_t size, const T* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (isnan(input[pos])) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-template <>
-__global__ void IsNan(const size_t size, const half* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (__hisnan(input[pos])) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-
-template <typename T>
-__global__ void IsInf(const size_t size, const T* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (isinf(input[pos]) != 0) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-template <>
-__global__ void IsInf(const size_t size, const half* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (__hisinf(input[pos]) != 0) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-
-template <typename T>
-__global__ void IsFinite(const size_t size, const T* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (isinf(input[pos]) == 0 && !isnan(input[pos])) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-template <>
-__global__ void IsFinite(const size_t size, const half* input, bool* out) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) {
-      out[pos] = true;
-    } else {
-      out[pos] = false;
-    }
-  }
-  return;
-}
-
-template <typename T>
-__global__ void FloatStatus(const size_t size, const T* input, T* out) {
-  out[0] = 0;
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (isinf(input[pos]) != 0 || isnan(input[pos])) {
-      out[0] = 1;
-    }
-  }
-  return;
-}
-template <>
-__global__ void FloatStatus(const size_t size, const half* input, half* out) {
-  out[0] = 0;
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) {
-      out[0] = 1;
-    }
-  }
-  return;
-}
-
-template <typename T>
-void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) {
-  FloatStatus<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
-  return;
-}
-template <typename T>
-void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
-  IsNan<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
-  return;
-}
-template <typename T>
-void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
-  IsInf<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
-  return;
-}
-template <typename T>
-void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) {
-  IsFinite<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output);
-  return;
-}
-
-template void CalFloatStatus<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
-template void CalFloatStatus<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream);
-template void CalIsInf<float>(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream);
-template void CalIsInf<half>(const size_t size, const half* input, bool* output, cudaStream_t
cuda_stream); -template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); -template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); -template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); -template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh deleted file mode 100644 index da488ff937..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ -#include "device/gpu/cuda_common.h" -template -void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); -template -void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); -template -void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); -template -void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu deleted file mode 100644 index ea6ffdbbdc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" - -template -__device__ __forceinline__ T PowFunc(T x, T y) { - return pow(x, y); -} - -template <> -__device__ __forceinline__ half PowFunc(half x, half y) { - return __float2half(pow(__half2float(x), __half2float(y))); -} - -template -__device__ __forceinline__ bool CompareFunc(T x, T y) { - return abs(x) > y; -} - -template <> -__device__ __forceinline__ bool CompareFunc(half x, half y) { - return abs(__half2float(x)) > __half2float(y); -} - -template -__device__ __forceinline__ T Sgn(T x) { - return static_cast(x != 0 ? (x > 0 ? 1 : -1) : 0); -} - -template <> -__device__ __forceinline__ half Sgn(half x) { - return __float2half(__half2float(x) != 0 ? 
(__half2float(x) > 0 ? 1 : -1) : 0); -} - -template -__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate, - const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power, - T *variable, T *accumulation, T *linear) { - const T two = static_cast(2.0); - const T learning_rate_power_val = -learning_rate_power[0]; - - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i]; - const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val); - const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val); - const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0]; - - linear[i] += gradient[i] - sigma * variable[i]; - variable[i] = CompareFunc(linear[i], l1_regularization[0]) - ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) / - (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0])) - : static_cast(0); - accumulation[i] = cur_accumulation; - } -} - -template -void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, - const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, - cudaStream_t cuda_stream) { - ApplyFtrlKernel<<>>(size, gradient, learning_rate, l1_regularization, - l2_regularization, learning_rate_power, variable, - accumulation, linear); -} - -template void ApplyFtrl(const size_t size, const float *gradient, const float *learning_rate, - const float *l1_regularization, const float *l2_regularization, - const float *learning_rate_power, float *variable, float *accumulation, float *linear, - cudaStream_t cuda_stream); -template void ApplyFtrl(const size_t size, const half *gradient, const half *learning_rate, - const half *l1_regularization, const half *l2_regularization, - const half *learning_rate_power, half *variable, half *accumulation, half *linear, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh deleted file mode 100644 index ba4a8fa816..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
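Per element, ApplyFtrlKernel above reduces to the standard FTRL-proximal update. A scalar host restatement (float only; `FtrlStep` is a name invented for this sketch) is useful when unit-testing the GPU kernel against a reference:

#include <cmath>

void FtrlStep(float g, float lr, float l1, float l2, float lr_power,
              float *var, float *acc, float *linear) {
  const float new_acc = *acc + g * g;
  const float acc_pow = std::pow(*acc, -lr_power);  // the kernel negates learning_rate_power[0]
  const float new_acc_pow = std::pow(new_acc, -lr_power);
  const float sigma = (new_acc_pow - acc_pow) / lr;
  *linear += g - sigma * (*var);
  const float sgn = (*linear > 0.0f) - (*linear < 0.0f);  // matches the kernel's Sgn()
  *var = std::fabs(*linear) > l1 ? (l1 * sgn - *linear) / (new_acc_pow / lr + 2.0f * l2) : 0.0f;
  *acc = new_acc;
}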
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, - const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu deleted file mode 100755 index 6bde359d9b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gather.cu +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "kernel/gpu/cuda_impl/gather.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void GatherKernel(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1) { - int num = output_dim0 * output_dim1 * output_dim2; - int i, j, k; - for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num; - write_index += blockDim.x * gridDim.x) { - i = write_index / (output_dim1 * output_dim2) % output_dim0; - j = write_index / output_dim2 % output_dim1; - k = write_index % output_dim2; - - if ((indices[j] >= 0) && (indices[j] < input_dim1)) { - int read_index = i * input_dim1 * output_dim2 + indices[j] * output_dim2 + k; - output[write_index] = input[read_index]; - } else { - output[write_index] = 0; - } - } - - return; -} -template -void Gather(T *input, S *indices, T *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, - size_t input_dim1, cudaStream_t stream) { - int size = output_dim0 * output_dim1 * output_dim2; - GatherKernel<<>>(input, indices, output, output_dim0, output_dim1, - output_dim2, input_dim1); - return; -} - -template void Gather(float *input, int *indices, float *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1, cudaStream_t stream); - -template void Gather(half *input, int *indices, half *output, size_t output_dim0, size_t output_dim1, - size_t output_dim2, size_t input_dim1, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu deleted file mode 100644 index e460caec9e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
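The index arithmetic in GatherKernel above is easier to follow in nested-loop form: output[i][j][k] = input[i][indices[j]][k], with out-of-range indices producing 0. A CPU mirror, for illustration only:

#include <cstddef>

void GatherRef(const float *input, const int *indices, float *output,
               size_t dim0, size_t dim1, size_t dim2, size_t input_dim1) {
  for (size_t i = 0; i < dim0; ++i) {
    for (size_t j = 0; j < dim1; ++j) {
      for (size_t k = 0; k < dim2; ++k) {
        const size_t write_index = (i * dim1 + j) * dim2 + k;
        const int idx = indices[j];
        output[write_index] = (idx >= 0 && static_cast<size_t>(idx) < input_dim1)
                                  ? input[(i * input_dim1 + idx) * dim2 + k]
                                  : 0.0f;  // out-of-range index writes zero, as in the kernel
      }
    }
  }
}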
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) { - // formula: - // gelu(x) = 0.5 * x * (1.0 + tanh(y)) - // tanh(y) = 2 / (1 + exp(-2y)) - 1) - // y = sqrt(2/pi) * (x + 0.044715 * x^3) - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - float x = input_addr[pos]; - float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); - output_addr[pos] = 0.5 * x * (1.0 + tanh_res); - } -} - -template <> -__global__ void GeluKernel(size_t size, half *input_addr, half *output_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - half x = input_addr[pos]; - float tanh_res = tanh(__half2float(half(0.7978845608) * (x + half(0.044715) * x * x * x))); - output_addr[pos] = half(0.5) * x * (half(1.0) + __float2half(tanh_res)); - } -} - -template <> -__global__ void GeluKernel(size_t size, half2 *input_addr, half2 *output_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - half2 x = input_addr[pos]; - float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); - float2 tanh_res; - tanh_res.x = tanh(tanh_param.x); - tanh_res.y = tanh(tanh_param.y); - output_addr[pos] = half2(0.5, 0.5) * x * (half2(1.0, 1.0) + __float22half2_rn(tanh_res)); - } -} - -template -void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) { - GeluKernel<<>>(size, input_addr, output_addr); - return; -} - -template <> -void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream) { - if (size % 2 == 0) { - GeluKernel<<>>( - size / 2, reinterpret_cast(input_addr), reinterpret_cast(output_addr)); - } else { - GeluKernel<<>>(size, input_addr, output_addr); - } - return; -} - -template -__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) { - // formula: - // dx = dy * y' - // y' = 0.5 * (1 + tanh(tanh_para)) + - // 0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right - // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3) - // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)) - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - T x = x_addr[pos]; - T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x)); - T mul_right = 0.7978845608 + 0.1070322244 * x * x; - T y_res = 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res * tanh_res) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -__global__ void GeluGradKernel(size_t size, half2 *dy_addr, half2 *x_addr, half2 *dx_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - half2 x = x_addr[pos]; - float2 tanh_param = __half22float2(half2(0.7978845608, 0.7978845608) * (x + half2(0.044715, 0.044715) * x * x * x)); - float2 tanh_res; - tanh_res.x = tanh(tanh_param.x); - tanh_res.y = tanh(tanh_param.y); - 
half2 tanh_res_half = __float22half2_rn(tanh_res); - half2 mul_right = half2(0.7978845608, 0.7978845608) + half2(0.1070322244, 0.1070322244) * x * x; - half2 y_res = half2(0.5, 0.5) * (half2(1.0, 1.0) + tanh_res_half) + - half2(0.5, 0.5) * x * (half2(1.0, 1.0) - tanh_res_half * tanh_res_half) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -__global__ void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - half x = x_addr[pos]; - half tanh_param = half(0.7978845608) * (x + half(0.044715) * x * x * x); - half tanh_res = __float2half_rn(tanh(__half2float(tanh_param))); - half mul_right = half(0.7978845608) + half(0.1070322244) * x * x; - half y_res = half(0.5) * (half(1.0) + tanh_res) + half(0.5) * x * (half(1.0) - tanh_res * tanh_res) * mul_right; - dx_addr[pos] = dy_addr[pos] * y_res; - } -} - -template -void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) { - GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); -} - -template <> -void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream) { - if (size % 2 == 0) { - GeluGradKernel<<>>( - size / 2, reinterpret_cast(dy_addr), reinterpret_cast(x_addr), - reinterpret_cast(dx_addr)); - } else { - GeluGradKernel<<>>(size, dy_addr, x_addr, dx_addr); - } - return; -} - -template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream); -template void Gelu(size_t size, half *input_addr, half *output_addr, cudaStream_t cuda_stream); -template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr, cudaStream_t cuda_stream); -template void GeluGradKernel(size_t size, half *dy_addr, half *x_addr, half *dx_addr, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh deleted file mode 100644 index 7a8e1fae8a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
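The fp32, fp16, and vectorized half2 paths above all implement the same scalar tanh approximation spelled out in the kernel comments. A host reference (function names invented) for checking the reduced-precision variants:

#include <cmath>

float GeluRef(float x) {
  // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
  const float t = std::tanh(0.7978845608f * (x + 0.044715f * x * x * x));
  return 0.5f * x * (1.0f + t);
}

float GeluGradRef(float dy, float x) {
  const float t = std::tanh(0.7978845608f * (x + 0.044715f * x * x * x));
  const float mul_right = 0.7978845608f + 0.1070322244f * x * x;  // sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
  return dy * (0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * mul_right);
}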
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ - -#include "device/gpu/cuda_common.h" -template -void Gelu(size_t input_size, T* input_addr, T* output_addr, cudaStream_t cuda_stream); - -template -void GeluGradKernel(size_t size, T* dy_addr, T* x_addr, T* dx_addr, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu deleted file mode 100644 index e887b98eca..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -constexpr int NUM_PER_THREAD_REDUCE = 4; -constexpr int WARP_SIZE = 32; - -template -inline __device__ T my_pow(T a, double b) { - return pow(a, static_cast(b)); -} - -template <> -inline __device__ half my_pow(half a, double b) { - return __float2half(pow(__half2float(a), static_cast(b))); -} - -template -inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim, - const T& epsilon, const T* dy, const T* x, const T* mean, const T* var, - T* dg, T* db) { - int loop_num = (row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int row = NUM_PER_THREAD_REDUCE * i + j; - if (row >= row_dim) { - return; - } - - int pos = row * col_dim + col; - dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]); - db[0] += dy[pos]; - } - } -} - -template -inline __device__ void GammaAndBetaWarpReduce(T* dg, T* db) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - dg[0] += __shfl_down_sync(0xffffffff, dg[0], delta); - db[0] += __shfl_down_sync(0xffffffff, db[0], delta); - } -} - -template -inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_dim, T* dg, T* db, T* dg_addr, - T* db_addr) { - if (threadIdx.x >= row_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) 
keep the data - DynamicSharedMem share_mem; - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 2; - share_mem.addr()[offset] = dg[0]; - share_mem.addr()[offset + 1] = db[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 2; - share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset]; - share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1]; - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - dg_addr[col] = share_mem.addr()[0]; - db_addr[col] = share_mem.addr()[1]; - } -} - -template -__global__ void GammaAndBetaPropKernel(const int row_dim, const int col_dim, const T epsilon, const T* dy, const T* x, - const T* mean_addr, const T* var_addr, T* dg_addr, T* db_addr) { - // row: [0:param_axis] - // col: [param_axis:] - // dg[i][j] = dy[i][j] * (var[i] + epsilon, -0.5) * (x[i][j] - mean[i]) - // dg[j] = \Sigma_{j}dg[i][j] - for (int col = blockIdx.x; col < col_dim; col += gridDim.x) { - T dg = 0; - T db = 0; - GammaAndBetaThreadReduce(col, row_dim, col_dim, epsilon, dy, x, mean_addr, var_addr, &dg, &db); - GammaAndBetaWarpReduce(&dg, &db); - GammaAndBetaBlockReduce(col, row_dim, &dg, &db, dg_addr, db_addr); - } -} - -template -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - T* sum1, T* sum2, T* sum3, const T* dy, const T* x, const T* mean, - const T* var, const T* gamma) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int col = NUM_PER_THREAD_REDUCE * i + j; - if (col >= col_dim) { - return; - } - - int pos = row * col_dim + col; - int gamma_offset = pos % param_dim; - T v1 = dy[pos] * gamma[gamma_offset]; - T v2 = x[pos] - mean[row]; - - sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5); - sum2[0] += v1; - sum3[0] += -2.0 * v2; - } - } -} - -template <> -inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - half* sum1, half* sum2, half* sum3, const half* dy, const half* x, - const half* mean, const half* var, const half* gamma) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int col = NUM_PER_THREAD_REDUCE * i + j; - if (col >= col_dim) { - return; - } - - int pos = row * col_dim + col; - int gamma_offset = pos % param_dim; - half v1 = dy[pos] * gamma[gamma_offset]; - half v2 = x[pos] - mean[row]; - - sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5); - sum2[0] += v1; - sum3[0] += __float2half(-2.0) * v2; - } - } -} - -template -inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - sum1[0] += __shfl_down_sync(0xffffffff, sum1[0], delta); - sum2[0] += __shfl_down_sync(0xffffffff, sum2[0], delta); - sum3[0] += __shfl_down_sync(0xffffffff, sum3[0], delta); - } -} - -template -inline __device__ void InputBlockReduce(const int& col_dim, T* sum1, T* sum2, T* sum3, T* share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) 
keep the data - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 3; - share_mem[offset] = sum1[0]; - share_mem[offset + 1] = sum2[0]; - share_mem[offset + 2] = sum3[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 3; - share_mem[threadIdx.x * 3] += share_mem[offset]; - share_mem[threadIdx.x * 3 + 1] += share_mem[offset + 1]; - share_mem[threadIdx.x * 3 + 2] += share_mem[offset + 2]; - } - } - __syncthreads(); -} - -template -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const T& epsilon, - const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx, - const T* share_mem) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = (row * col_dim + col); - int gamma_offset = pos % param_dim; - T v1 = dy[pos] * gamma[gamma_offset]; - T v2 = x[pos] - mean[row]; - T v3 = my_pow(var[row] + epsilon, -0.5); - dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 + - (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim); - } -} - -template <> -inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, const half* share_mem) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = (row * col_dim + col); - int gamma_offset = pos % param_dim; - half v1 = dy[pos] * gamma[gamma_offset]; - half v2 = x[pos] - mean[row]; - half v3 = my_pow(var[row] + epsilon, -0.5); - dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 + - (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2])\ - * __float2half(1.0 / col_dim); - } -} - -template -__global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx) { - for (int row = blockIdx.x; row < row_dim; row += gridDim.x) { - T sum1 = 0; - T sum2 = 0; - T sum3 = 0; - DynamicSharedMem share_mem; - InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma); - InputWarpReduce(&sum1, &sum2, &sum3); - InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr()); - InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr()); - } -} - -template -void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) { - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - InputPropKernel<<>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, - gamma, dx); - - share_mem_size = - ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T); - GammaAndBetaPropKernel<<>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db); -} - -template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon, - const float* dy, const float* x, const float* mean, const float* var, const float* gamma, - float* dx, float* dg, float* db, cudaStream_t stream); -template void 
LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon, - const half* dy, const half* x, const half* mean, const half* var, const half* gamma, - half* dx, half* dg, half* db, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh deleted file mode 100644 index 9f7d57cdb9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ - -#include "device/gpu/cuda_common.h" - -template -void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy, - const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_GRAD_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu deleted file mode 100644 index cfb60f0ba6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
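The dynamic shared-memory sizing in LayerNormGrad above keeps one (sum1, sum2, sum3) triple per warp after the shuffle reduction, hence the factor of 3. A worked check of that expression:

// For col_dim = 1024, NUM_PER_THREAD_REDUCE = 4, WARP_SIZE = 32, T = float:
//   reducing threads: ceil(1024 / 4)         = 256
//   warps:            ceil(256 / 32)         = 8
//   bytes:            8 warps * 3 sums * 4 B = 96
static_assert(((1024 + 4 - 1) / 4 + 32 - 1) / 32 * 3 * sizeof(float) == 96,
              "matches the share_mem_size expression used for InputPropKernel");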
- */ - -#include -#include -#include -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -constexpr int NUM_PER_THREAD_REDUCE = 4; -constexpr int WARP_SIZE = 32; - -template -inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &val) { - // Welford Algorithm: - // \mu_k = \mu_{k-1} + (x_k - \mu_{k-1})/k - // \sigma_k^2 = \sigma_{k-1}^2 + (x_k - \mu_{k-1}) * (x_k - \mu_k) - num[0]++; - T mean_new = mean[0] + (val - mean[0]) / num[0]; - var[0] = var[0] + (val - mean[0]) * (val - mean_new); - mean[0] = mean_new; -} - -template -inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) { - T zero = 0; - if (n2 == zero) { - return; - } - - T count = n1[0] + n2; - v1[0] = v1[0] + v2 + (m1[0] - m2) * (m1[0] - m2) * n1[0] * n2 / count; - m1[0] = (n1[0] * m1[0] + n2 * m2) / count; - n1[0] = count; -} - -template -inline __device__ void ThreadReduce(const int &col_dim, const T *block_addr, T *mean, T *var, T *num) { - int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE; - for (int i = threadIdx.x; i < loop_num; i += blockDim.x) { - for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) { - int pos = NUM_PER_THREAD_REDUCE * i + j; - if (pos >= col_dim) { - return; - } - MeanAndVarAccumulation(mean, var, num, block_addr[pos]); - } - } -} - -template -inline __device__ void WarpReduce(T *mean, T *var, T *num) { - for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) { - T mean_other = __shfl_down_sync(0xffffffff, mean[0], delta); - T var_other = __shfl_down_sync(0xffffffff, var[0], delta); - T num_other = __shfl_down_sync(0xffffffff, num[0], delta); - MeanAndVarMerge(mean, var, num, mean_other, var_other, num_other); - } -} - -template -inline __device__ void BlockReduce(const int &col_dim, T *mean, T *var, T *num, T *mean_addr, T *var_addr, - T *share_mem) { - if (threadIdx.x >= col_dim) { - return; - } - - // load data to share memory - // thread(0, 32, 64, 96, ...) 
keep the data - if (threadIdx.x % WARP_SIZE == 0) { - int offset = threadIdx.x / WARP_SIZE * 3; - share_mem[offset] = mean[0]; - share_mem[offset + 1] = var[0]; - share_mem[offset + 2] = num[0]; - } - __syncthreads(); - - for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) { - if (threadIdx.x < stride) { - int offset = (threadIdx.x + stride) * 3; - MeanAndVarMerge(&share_mem[threadIdx.x * 3], &share_mem[threadIdx.x * 3 + 1], &share_mem[threadIdx.x * 3 + 2], - share_mem[offset], share_mem[offset + 1], share_mem[offset + 2]); - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - mean_addr[blockIdx.x] = share_mem[0]; - share_mem[1] /= col_dim; - var_addr[blockIdx.x] = share_mem[1]; - } -} - -template -inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const T *x, - const T *share_mem, const T *gamma, const T *beta, const T epsilon, T *y) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = row * col_dim + col; - int i = pos % param_dim; - y[pos] = (x[pos] - share_mem[0]) / sqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; - } -} - -template <> -inline __device__ void LayerNorm(const int &row, const int &col_dim, const int ¶m_dim, const half *x, - const half *share_mem, const half *gamma, const half *beta, const half epsilon, - half *y) { - for (int col = threadIdx.x; col < col_dim; col += blockDim.x) { - int pos = row * col_dim + col; - int i = pos % param_dim; - y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i]; - } -} - -template -__global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x, - const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) { - for (auto row = blockIdx.x; row < row_dim; row += gridDim.x) { - T mean = 0; - T var = 0; - T num = 0; - const T *block_addr = x + row * col_dim; - DynamicSharedMem share_mem; - - ThreadReduce(col_dim, block_addr, &mean, &var, &num); - WarpReduce(&mean, &var, &num); - BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr()); - - __syncthreads(); - LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y); - } -} - -template -void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const T &epsilon, const T *x, - const T *gamma, const T *beta, T *y, T *mean, T *var, cudaStream_t stream) { - const dim3 block(row_dim); - const dim3 thread(256); - // keep the mean/var/num after warp reduce - int share_mem_size = - ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T); - LayerNormKernel<<>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, - mean, var); -} - -template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const float &epsilon, - const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var, - cudaStream_t stream); -template void LayerNorm(const int &row_dim, const int &col_dim, const int ¶m_dim, const half &epsilon, - const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var, - cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh deleted file mode 100644 index c06a698384..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache 
License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ - -#include "device/gpu/cuda_common.h" - -template -struct DynamicSharedMem; -template<> -struct DynamicSharedMem { - __device__ float *addr() { - extern __shared__ float addr_float[]; - return addr_float; - } -}; -template<> -struct DynamicSharedMem { - __device__ half *addr() { - extern __shared__ half addr_half[]; - return addr_half; - } -}; - -template -void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma, - const T* beta, T* y, T* mean, T* var, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_LAYER_NORM_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu deleted file mode 100644 index 27b2cb0232..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include "minmax_update_impl.cuh" -#include "device/gpu/cuda_common.h" - -__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, - float *output_max, const float min, const float max, - const float decay) { - output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); - output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; - output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); - output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; - return; -} - -__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) { - output_min[0] = min > 0 ? 0 : min; - output_max[0] = max < 0 ? 
0 : max; - return; -} - -__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, - float *output_max, int channels, int per_channel_nums, bool ema, - float ema_decay) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { - thrust::pair sum = - thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); - if (ema) { - output_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; - output_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i]; - } else { - output_min[i] = sum.first[0]; - output_max[i] = sum.second[0]; - } - output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; - output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; - } - return; -} - -void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const int channel_num, const float ema_decay, const bool ema, - cudaStream_t cuda_stream) { - int per_channel_num = total_num / channel_num; - UpdateInputMinMaxPerChannel<<>>( - input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay); - return; -} - -void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { - float minel = 0.f; - float maxel = 0.f; - auto policy = thrust::cuda::par.on(cuda_stream); - thrust::pair, thrust::device_ptr> tuple; - tuple = - thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + total_num); - minel = tuple.first[0]; - maxel = tuple.second[0]; - - if (ema) { - UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, - maxel, ema_decay); - } else { - UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel); - } - return; -} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh deleted file mode 100644 index 5e9becab38..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
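One detail of the per-layer EMA path above is worth restating: as written, UpdateInputMinMaxPerLayerWithEMA computes the EMA and then immediately overwrites it with a clamp of the previous input_min/input_max. The host sketch below clamps the EMA result instead, which appears to be the intent; treat that difference as this sketch's assumption rather than code that exists in the tree.

#include <algorithm>

void EmaMinMaxRef(float batch_min, float batch_max, float decay, float *state_min, float *state_max) {
  const float m = decay * batch_min + (1.0f - decay) * (*state_min);
  const float M = decay * batch_max + (1.0f - decay) * (*state_max);
  *state_min = std::min(m, 0.0f);  // the running min must stay <= 0
  *state_max = std::max(M, 0.0f);  // the running max must stay >= 0
}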
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ - -#include "device/gpu/cuda_common.h" - -void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int total_num, const int channel_num, const float ema_decay, const bool ema, - cudaStream_t cuda_stream); - -void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, - const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh deleted file mode 100755 index 5405f5ef1d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient, - const S *momentum, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu deleted file mode 100644 index cf5dc7ecd0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/one_hot_impl.cu +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "one_hot_impl.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void OneHotKernel(size_t size, const S *indices, size_t depth, const T *on_value, const T *off_value, - size_t left_dim_size, size_t right_dim_size, T *output) { - T on_v = *on_value; - T off_v = *off_value; - for (int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; thread_idx < size; - thread_idx += blockDim.x * gridDim.x) { - if (thread_idx < size) { - int left_idx = (thread_idx / (depth * right_dim_size)) % left_dim_size; - int d_idx = thread_idx / right_dim_size % depth; - int right_idx = thread_idx % right_dim_size; - int input_idx = left_idx * right_dim_size + right_idx; - int output_idx = left_idx * depth * right_dim_size + d_idx * right_dim_size + right_idx; - if (indices[input_idx] == d_idx) { - output[output_idx] = on_v; - } else { - output[output_idx] = off_v; - } - } - } -} -template -void OneHot(const S *indices, size_t depth, const T *on_value, const T *off_value, size_t left_dim_size, - size_t right_dim_size, T *output, cudaStream_t cuda_stream) { - size_t size = left_dim_size * depth * right_dim_size; - OneHotKernel<<>>(size, indices, depth, on_value, off_value, - left_dim_size, right_dim_size, output); - return; -} -template void OneHot(const int *indices, size_t depth, const float *on_value, const float *off_value, - size_t left_dim_size, size_t right_dim_size, float *output, cudaStream_t cuda_stream); -template void OneHot(const int *indices, size_t depth, const half *on_value, const half *off_value, - size_t left_dim_size, size_t right_dim_size, half *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu deleted file mode 100755 index ddc615d94b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cu +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
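OneHotKernel above reads indices with shape [left_dim, right_dim] and inserts the new depth axis in the middle of the output, giving [left_dim, depth, right_dim]. A CPU mirror of the index math:

#include <cstddef>

void OneHotRef(const int *indices, size_t depth, float on_value, float off_value,
               size_t left_dim, size_t right_dim, float *output) {
  for (size_t l = 0; l < left_dim; ++l) {
    for (size_t d = 0; d < depth; ++d) {
      for (size_t r = 0; r < right_dim; ++r) {
        output[(l * depth + d) * right_dim + r] =
            (indices[l * right_dim + r] == static_cast<int>(d)) ? on_value : off_value;
      }
    }
  }
}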
- */ - -#include -#include -#include "kernel/gpu/cuda_impl/pad_impl.cuh" - -template -__global__ void Pad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, float pad_value, T* output) { - T pad_value_ = static_cast(pad_value); - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int block_num = pos / padded_width / padded_height; - const int padded_w = pos % padded_width; - const int padded_h = pos / padded_width % padded_height; - if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height || - padded_w - pad_left >= old_width) { - output[pos] = pad_value_; - } else { - output[pos] = input[(block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left]; - } - } - return; -} - -template -__global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - int block_num = pos / old_width / old_height; - const int padded_w = pos % old_width + pad_left; - const int padded_h = pos / old_width % old_height + pad_top; - dx[pos] = dy[(block_num * padded_height + padded_h) * padded_width + padded_w]; - } - return; -} - -template -void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - const float pad_value, T* output, cudaStream_t cuda_stream) { - Pad<<>>(size, input, num, channels, old_height, old_width, - padded_height, padded_width, pad_top, pad_left, pad_value, - output); - return; -} - -template -void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx, cudaStream_t cuda_stream) { - PadGrad<<>>(size, dy, num, channels, old_height, old_width, - padded_height, padded_width, pad_top, pad_left, dx); - return; -} - -template void CalPad(const size_t size, const float* input, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, float* output, - cudaStream_t cuda_stream); -template void CalPadGrad(const size_t size, const float* dy, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, - const int padded_width, const int pad_top, const int pad_left, float* dx, - cudaStream_t cuda_stream); -template void CalPad(const size_t size, const half* input, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, const int padded_width, - const int pad_top, const int pad_left, float pad_value, half* output, - cudaStream_t cuda_stream); -template void CalPadGrad(const size_t size, const half* dy, const int num, const int channels, - const int old_height, const int old_width, const int padded_height, - const int padded_width, const int pad_top, const int pad_left, half* dx, - cudaStream_t cuda_stream); diff --git 
a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh deleted file mode 100755 index dc3036b8b6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/pad_impl.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ -#include -#include "device/gpu/cuda_common.h" - -template -void CalPad(const size_t size, const T* input, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, const int pad_left, - float pad_value, T* output, cudaStream_t cuda_stream); -template -void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height, - const int old_width, const int padded_height, const int padded_width, const int pad_top, - const int pad_left, T* dx, cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh deleted file mode 100644 index 5e9110a1bc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/random_op_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ - -#include -#include "device/gpu/cuda_common.h" - -template -void StandardNormal(int seed, int seed2, curandState *globalState, - T *output, size_t count, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu deleted file mode 100644 index 913aaa3b8d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
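The CalPad/CalPadGrad launchers declared above implement constant 2-D padding over a fused N*C leading dimension (the kernels never separate num from channels). A CPU mirror of the forward mapping, for reference:

#include <cstddef>

void PadRef(const float *input, size_t nc, size_t old_h, size_t old_w,
            size_t padded_h, size_t padded_w, size_t pad_top, size_t pad_left,
            float pad_value, float *output) {
  for (size_t b = 0; b < nc; ++b) {  // fused N*C index, as in the kernel's block_num
    for (size_t y = 0; y < padded_h; ++y) {
      for (size_t x = 0; x < padded_w; ++x) {
        const bool inside = y >= pad_top && y < pad_top + old_h && x >= pad_left && x < pad_left + old_w;
        output[(b * padded_h + y) * padded_w + x] =
            inside ? input[(b * old_h + (y - pad_top)) * old_w + (x - pad_left)] : pad_value;
      }
    }
  }
}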
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void RmsPropKernel(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, - T* mean_square, T*moment, T* gradients, const size_t size) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { - mean_square[i] = decay * mean_square[i] + (1.0 - decay) * gradients[i] * gradients[i]; - moment[i] = momentum * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon) * gradients[i]; - variable[i] -= moment[i]; - } -} - -template -void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, - T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { - RmsPropKernel<<>>(learning_rate, decay, momentum, epsilon, - variable, mean_square, moment, gradients, size); -} - -template -__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, - T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, - const size_t size) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { - mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; - mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; - moment[i] = momentum[0] * moment[i] + learning_rate[0] * - rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; - variable[i] -= moment[i]; - } -} - -template -void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, - T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, - cudaStream_t cuda_stream) { - RmsPropCenterKernel<<>>(learning_rate, decay, momentum, epsilon, - variable, mean_gradients, mean_square, - moment, gradients, size); -} - -template -void RmsProp(const float* learning_rate, const float decay, const float momentum, const float epsilon, - float* variable, float* mean_square, float* moment, float* gradients, const size_t size, - cudaStream_t cuda_stream); - -template -void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, - float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, - const size_t size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh deleted file mode 100644 index b5802dbb67..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
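
The rmsprop_impl.cu deleted just above applies, per element, the familiar RMSProp recurrences; the centered variant additionally tracks a running gradient mean and uses the variance estimate mean_square − mean_gradients² under the square root. A scalar C++ reference, a sketch assuming the same math (names mine), convenient for unit-checking the relocated kernel:

```cpp
#include <cmath>

// Scalar CPU reference for one element of the deleted RmsPropKernel.
void RmsPropStep(float lr, float decay, float momentum, float epsilon,
                 float *variable, float *mean_square, float *moment, float grad) {
  *mean_square = decay * *mean_square + (1.0f - decay) * grad * grad;
  *moment = momentum * *moment + lr * grad / std::sqrt(*mean_square + epsilon);
  *variable -= *moment;
}

// Centered variant: also track the running gradient mean and divide by the
// centered second moment, mean_square - mean_gradients^2.
void RmsPropCenterStep(float lr, float decay, float momentum, float epsilon, float *variable,
                       float *mean_gradients, float *mean_square, float *moment, float grad) {
  *mean_gradients = decay * *mean_gradients + (1.0f - decay) * grad;
  *mean_square = decay * *mean_square + (1.0f - decay) * grad * grad;
  *moment = momentum * *moment +
            lr * grad / std::sqrt(*mean_square - *mean_gradients * *mean_gradients + epsilon);
  *variable -= *moment;
}
```

The kernel writes the same updates with rsqrt(x) * g in place of g / sqrt(x).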
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
-#include "device/gpu/cuda_common.h"
-
-template <typename T>
-void RmsProp(const T* learning_rate, const T decay, const T momentum, const T epsilon, T* variable, T* mean_square,
-             T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream);
-
-template <typename T>
-void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable,
-                   T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size,
-                   cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu
deleted file mode 100644
index f07a820e75..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-#include
-#include "kernel/gpu/cuda_impl/select_impl.cuh"
-
-template <typename T>
-__global__ void Select(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    output[pos] = cond[pos] ? input_x[pos] : input_y[pos];
-  }
-  return;
-}
-
-template <typename T>
-void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output,
-               cudaStream_t cuda_stream) {
-  Select<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, cond, input_x, input_y, output);
-  return;
-}
-
-template void CalSelect<float>(const size_t size, const bool* cond, const float* input_X, const float* input_y,
-                               float* output, cudaStream_t cuda_stream);
-template void CalSelect<int>(const size_t size, const bool* cond, const int* input_X, const int* input_y, int* output,
-                             cudaStream_t cuda_stream);
-template void CalSelect<half>(const size_t size, const bool* cond, const half* input_X, const half* input_y,
-                              half* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh
deleted file mode 100644
index da2d7d9a7f..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/select_impl.cuh
+++ /dev/null
@@ -1,25 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
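
One detail worth calling out in select_impl.cu and its siblings: each .cu ends with a block of explicit template instantiations. The kernel and wrapper definitions are only visible inside the nvcc-compiled translation unit, so every element type that C++ callers may request through the .cuh declarations has to be instantiated there by hand. A minimal illustration of the pattern (the Fill function is hypothetical, not MindSpore code):

```cuda
#include <cstddef>
#include <cuda_runtime.h>

// lib.cuh side - what C++ callers would include: declarations only.
template <typename T>
void Fill(T *dst, size_t n, T value, cudaStream_t stream);

// lib.cu side - nvcc-compiled definitions plus the instantiation list.
template <typename T>
__global__ void FillKernel(T *dst, size_t n, T value) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    dst[i] = value;  // same grid-stride loop as the kernels above
  }
}

template <typename T>
void Fill(T *dst, size_t n, T value, cudaStream_t stream) {
  FillKernel<<<(n + 255) / 256, 256, 0, stream>>>(dst, n, value);
}

// Without these lines the definitions would be invisible outside this .cu and
// callers would fail to link - hence the template void CalSelect<float>(...)
// style tails on every deleted implementation file.
template void Fill<float>(float *, size_t, float, cudaStream_t);
template void Fill<int>(int *, size_t, int, cudaStream_t);
```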
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ - -#include "device/gpu/cuda_common.h" - -template -void CalSelect(const size_t size, const bool* cond, const T* input_x, const T* input_y, T* output, - cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SELECT_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu deleted file mode 100644 index a0082b84c8..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" - -template -__global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, - T *outputs) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - if (logits[i] >= 0) { - outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; - } else { - const T exp_val = exp(logits[i]); - outputs[i] = exp_val / (1. + exp_val) - labels[i]; - } - } -} - -template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream) { - SigmoidCrossEntropyWithLogitsGradKernel<<>>(size, logits, labels, - outputs); -} - -template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, - const float *labels, float *outputs, - cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh deleted file mode 100644 index 2cd4922d25..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
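
The gradient kernel above is the numerically careful form of d/dx sigmoid_ce(x, z) = sigmoid(x) − z: exp() is only ever applied to a non-positive argument, so it cannot overflow for large |x|. A scalar C++ sketch of the same branch (naming mine):

```cpp
#include <cmath>

// CPU reference for the deleted gradient kernel: the derivative of sigmoid
// cross entropy w.r.t. the logit is sigmoid(x) - z. The two branches keep
// exp() applied only to non-positive arguments, so it never overflows.
float SigmoidCeGrad(float logit, float label) {
  if (logit >= 0.0f) {
    return 1.0f / (1.0f + std::exp(-logit)) - label;
  }
  const float e = std::exp(logit);  // logit < 0, so e <= 1
  return e / (1.0f + e) - label;
}
```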
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu deleted file mode 100644 index 3766f367db..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cu +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" - -template -__global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { - const T reverse_factor = static_cast(logits[i] >= 0); - outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); - } -} - -template -void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream) { - SigmoidCrossEntropyWithLogitsKernel<<>>(size, logits, labels, outputs); -} - -template void SigmoidCrossEntropyWithLogits(const size_t size, const float *logits, const float *labels, - float *outputs, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh deleted file mode 100644 index 575605bde0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
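
The forward kernel's expression log1p(exp(x − 2·r·x)) − x·(z − r), with r = (x ≥ 0), is algebraically the standard overflow-safe loss max(x, 0) − x·z + log1p(exp(−|x|)), since x − 2·r·x = −|x| and x·r = max(x, 0). A small self-checking C++ sketch of that equivalence (label value arbitrary):

```cpp
#include <cassert>
#include <cmath>
#include <initializer_list>

// The deleted forward kernel evaluates, per element,
//   log1p(exp(x - 2*r*x)) - x*(z - r),   r = (x >= 0),
// which equals the textbook overflow-safe form
//   max(x, 0) - x*z + log1p(exp(-|x|)).
float SigmoidCeLoss(float x, float z) {
  const float r = x >= 0.0f ? 1.0f : 0.0f;
  return std::log1p(std::exp(x - 2.0f * r * x)) - x * (z - r);
}

int main() {
  const float z = 0.3f;  // arbitrary label
  for (float x : {-50.0f, -1.0f, 0.0f, 1.0f, 50.0f}) {
    float textbook = std::fmax(x, 0.0f) - x * z + std::log1p(std::exp(-std::fabs(x)));
    assert(std::fabs(SigmoidCeLoss(x, z) - textbook) < 1e-6f);
  }
  return 0;
}
```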
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ - -#include "device/gpu/cuda_common.h" -template -void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu deleted file mode 100755 index e49a22bb46..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include "kernel/gpu/cuda_impl/slice_impl.cuh" - -template -__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { - int i = pos / (l2 * l3 * l4) % l1; - int j = pos / (l3 * l4) % l2; - int k = pos / l4 % l3; - int o = pos % l4; - - int offset = (i + s1) * (d2 * d3 * d4) + - (j + s2) * (d3 * d4) + - (k + s3) * d4 + - (o + s4); - output[pos] = input[offset]; - } -} -template -__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { - output[start + pos] = dy[p + pos]; - } - return; -} -template -__global__ void StridedSlice(const T* input, int p, int start, int begin, int stride, int ended, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - output[p + pos] = input[start + pos * stride]; - } - return; -} -template -__global__ void StridedSliceGrad(const T* dy, int p, int start, int begin, int stride, int ended, T* dx) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < std::ceil(static_cast(ended - begin) / stride); - pos += blockDim.x * gridDim.x) { - dx[start + pos * stride] = dy[p + pos]; - } - return; -} -template -__global__ void FillArray(T* addr, const size_t len, const float value) { - T value_ = static_cast(value); - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < len; pos += blockDim.x * gridDim.x) { - addr[pos] = value_; - } - return; -} -template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream) { - FillArray<<>>(addr, input_size, value); - return; -} -template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, 
const int d3, const int d4, - const T *input, T *output, cudaStream_t stream) { - Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, - d1, d2, d3, d4, input, output); -} -template -void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, - const std::vector size, T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int length = size[3]; - int p = 0; - for (int i = begin[0]; i < size[0] + begin[0]; i++) { - for (int j = begin[1]; j < size[1] + begin[1]; j++) { - for (int k = begin[2]; k < size[2] + begin[2]; k++) { - SliceGrad<<>>( - dy, p, i * block + j * map + k * w + begin[3], length, output); - p = p + size[3]; - } - } - } -} -template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0])); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1])); j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2])); k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? k : 2 * begin[2] - k) * w + begin[3]; - StridedSlice<<>>(input, p, start, begin[3], strides[3], - ended, output); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } - } -} -template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int ended = end[3]; - int p = 0; - int start = 0; - for (int i = begin[0]; i < ((end[0] > begin[0]) ? end[0] : (2 * begin[0] - end[0] + 1)); i += std::abs(strides[0])) { - for (int j = begin[1]; j < ((end[1] > begin[1]) ? end[1] : (2 * begin[1] - end[1] + 1)); - j += std::abs(strides[1])) { - for (int k = begin[2]; k < ((end[2] > begin[2]) ? end[2] : (2 * begin[2] - end[2] + 1)); - k += std::abs(strides[2])) { - start = (strides[0] > 0 ? i : 2 * begin[0] - i) * block + (strides[1] > 0 ? j : 2 * begin[1] - j) * map + - (strides[2] > 0 ? 
k : 2 * begin[2] - k) * w + begin[3]; - StridedSliceGrad<<>>(dy, p, start, begin[3], strides[3], - ended, dx); - p = p + std::ceil(static_cast(end[3] - begin[3]) / strides[3]); - } - } - } -} - -template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const float *input, float *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, float* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const float* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, float* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const half *input, half *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, half* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const half* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, half* dx, cudaStream_t cuda_stream); -template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); -template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const int *input, int *output, cudaStream_t stream); -template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector size, int* output, - cudaStream_t cuda_stream); -template void CalStridedSlice(const size_t input_size, const int* input, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* output, cudaStream_t cuda_stream); -template void CalStridedSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, - const std::vector strides, int* dx, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh deleted file mode 100755 index 9513d6ed24..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the 
Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ - -#include -#include -#include "device/gpu/cuda_common.h" - - -template -void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, - const int l1, const int l2, const int l3, const int l4, - const int d1, const int d2, const int d3, const int d4, - const T *input, T *output, cudaStream_t stream); -template -void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); -template -void CalStridedSlice(const size_t input_size, const T* input, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* output, cudaStream_t cuda_stream); -template -void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, - const std::vector begin, const std::vector end, const std::vector strides, - T* dx, cudaStream_t cuda_stream); -template -void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SLICEIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu deleted file mode 100644 index bebcd50a0f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "smooth_l1_loss_impl.cuh" -#include "device/gpu/cuda_common.h" - -template -__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target, - T *loss) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T value = (prediction[i] - target[i]) > 0 ? 
(prediction[i] - target[i]) : (target[i] - prediction[i]); - if (value < sigma) { - loss[i] = static_cast(0.5) * value * value; - } else { - loss[i] = value - static_cast(0.5); - } - } -} - -template -void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, - cudaStream_t stream) { - SmoothL1LossKernel<<>>(input_size, sigma, prediction, target, loss); -} - -template -__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target, - const T *dloss, T *dx) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { - T value = prediction[i] - target[i]; - if (value > static_cast(sigma)) { - dx[i] = dloss[i]; - } else if (value < static_cast(-sigma)) { - dx[i] = -dloss[i]; - } else { - dx[i] = value * dloss[i]; - } - } -} - -template -void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, - T *dx, cudaStream_t stream) { - SmoothL1LossGradKernel<<>>(input_size, sigma, prediction, target, - dloss, dx); -} - -template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target, - float *loss, cudaStream_t stream); -template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target, - const float *dloss, float *dx, cudaStream_t stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh deleted file mode 100755 index d16131470c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/sparse_cross_entropy_cuda_impl.cuh +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ - -#include "device/gpu/cuda_common.h" - -template -void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss, - cudaStream_t cuda_stream); - -template -void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu deleted file mode 100755 index a0fea90136..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/transpose_impl.cu +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
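
The smooth-L1 kernels deleted above use sigma as the quadratic/linear switch point, with the linear branch hard-coding the −0.5 offset of the sigma = 1 form; the gradient is the signed difference clipped to [−sigma, sigma] and scaled by the incoming dloss. A CPU reference sketch (names mine):

```cpp
#include <cmath>

// CPU reference for the deleted SmoothL1Loss kernels. Forward: quadratic
// inside the |diff| < sigma window, linear outside (as written, the linear
// branch subtracts the constant 0.5, i.e. the sigma = 1 form).
float SmoothL1(float prediction, float target, float sigma) {
  float value = std::fabs(prediction - target);
  return value < sigma ? 0.5f * value * value : value - 0.5f;
}

// Backward: the signed difference clipped to [-sigma, sigma], scaled by dloss.
float SmoothL1Grad(float prediction, float target, float sigma, float dloss) {
  float value = prediction - target;
  if (value > sigma) return dloss;
  if (value < -sigma) return -dloss;
  return value * dloss;
}
```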
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "transpose_impl.cuh" -#include "device/gpu/cuda_common.h" -template -__global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis, - const int shape_size, T* output) { - int pos_size; - int temp_pos; - int newpos; - int newpos_size; - int pos_array[TRANSPOSE_MAX_DIMENSION]; - - // for example 4-D: pos = posArray[0] * input_shape[1] * input_shape[2] * input_shape[3] + - // posArray[1] * input_shape[2] * input_shape[3] + - // posArray[2] * input_shape[3] + - // posArray[3] - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { - temp_pos = pos; - pos_size = size / input_shape[0]; - pos_array[0] = temp_pos / pos_size; - for (int i = 1; i < shape_size; i++) { - temp_pos -= pos_array[i - 1] * pos_size; - pos_size = pos_size / input_shape[i]; - pos_array[i] = temp_pos / pos_size; - } - - newpos = pos_array[input_axis[shape_size - 1]]; - newpos_size = 1; - for (int j = shape_size - 2; j >= 0; j--) { - newpos_size *= input_shape[input_axis[j + 1]]; - newpos += pos_array[input_axis[j]] * newpos_size; - } - - output[newpos] = input[pos]; - } - return; -} -template -void CalTranspose(const int size, const T* input, const int* input_shape, const int* input_axis, const int shape_size, - T* output, cudaStream_t cuda_stream) { - Transpose<<>>(size, input, input_shape, input_axis, shape_size, - output); - return; -} - -template void CalTranspose(const int size, const float* input, const int* input_shape, const int* input_axis, - const int shape_size, float* output, cudaStream_t cuda_stream); -template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis, - const int shape_size, half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh deleted file mode 100755 index 623b1a8c03..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
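
The Transpose kernel deleted just above does a general N-D permutation in two passes over the dimensions: it peels one coordinate per axis off the flat input index by repeatedly shrinking a stride, then rebuilds the flat output index walking the permuted shape from the innermost axis outward. A host-side C++ mirror of exactly that arithmetic (container choices mine):

```cpp
#include <vector>

// CPU reference for the deleted Transpose kernel's index math.
template <typename T>
void TransposeReference(const std::vector<T> &input, const std::vector<int> &shape,
                        const std::vector<int> &axis, std::vector<T> *output) {
  int shape_size = static_cast<int>(shape.size());
  int size = 1;
  for (int d : shape) size *= d;
  output->resize(size);
  std::vector<int> pos_array(shape_size);
  for (int pos = 0; pos < size; ++pos) {
    // Peel coordinates off the flat input index, outermost axis first.
    int temp_pos = pos;
    int pos_size = size / shape[0];
    pos_array[0] = temp_pos / pos_size;
    for (int i = 1; i < shape_size; ++i) {
      temp_pos -= pos_array[i - 1] * pos_size;
      pos_size /= shape[i];
      pos_array[i] = temp_pos / pos_size;
    }
    // Rebuild the flat output index along the permuted shape.
    int newpos = pos_array[axis[shape_size - 1]];
    int newpos_size = 1;
    for (int j = shape_size - 2; j >= 0; --j) {
      newpos_size *= shape[axis[j + 1]];
      newpos += pos_array[axis[j]] * newpos_size;
    }
    (*output)[newpos] = input[pos];
  }
}
```

The TRANSPOSE_MAX_DIMENSION bound in the kernel exists only because the per-thread pos_array lives in a fixed-size register array; the host mirror has no such limit.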
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
-
-#include "device/gpu/cuda_common.h"
-template <typename T>
-void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Logarithm(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
-template <typename T>
-void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu
deleted file mode 100644
index a7affd4705..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cu
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/cuda_impl/unsorted_segment_sum.cuh"
-
-template <typename T, typename S>
-__global__ void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                                   T* input_addr, S* ids_addr, T* output_addr) {
-  for (int input_index = blockIdx.x * blockDim.x + threadIdx.x; input_index < input_dim0 * input_dim1;
-       input_index += blockDim.x * gridDim.x) {
-    size_t j = input_index / input_dim1;
-    size_t k = input_index % input_dim1;
-
-    S i = ids_addr[j];
-    if (i < 0 || i >= output_dim0) {
-      continue;
-    }
-    size_t output_index = i * output_dim1 + k;
-    atomicAdd(output_addr + output_index, input_addr[input_index]);
-  }
-}
-
-template <typename T, typename S>
-void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                        T* input_addr, S* ids_addr, T* output_addr, cudaStream_t stream) {
-  int size = input_dim0 * input_dim1;
-  UnsortedSegmentSum<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input_dim0, input_dim1,
-    output_dim0, output_dim1, input_addr, ids_addr, output_addr);
-  return;
-}
-
-template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                                 float* input_addr, int* ids_addr, float* output_addr, cudaStream_t stream);
-template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                                 float* input_addr, int64_t* ids_addr, float* output_addr, cudaStream_t stream);
-
-template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                                 int* input_addr, int* ids_addr, int* output_addr, cudaStream_t stream);
-template void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1,
-                                 int* input_addr, int64_t* ids_addr, int* output_addr, cudaStream_t stream);
-
-
-
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh
deleted file mode 100644
index ef95032996..0000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unsorted_segment_sum.cuh
+++ /dev/null
@@ -1,27 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
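
unsorted_segment_sum.cu scatters rows of an input_dim0 × input_dim1 matrix into output rows chosen by ids, silently skipping out-of-range ids; atomicAdd is required because several input rows can target the same output row concurrently. A sequential C++ reference, assuming output_dim1 == input_dim1 as the kernel's indexing implies (names mine):

```cpp
#include <cstddef>
#include <vector>

// Sequential CPU reference for the deleted UnsortedSegmentSum kernel: row j
// of the input is accumulated into output row ids[j]; ids outside
// [0, output_dim0) are skipped, matching the kernel's bounds check.
template <typename T, typename S>
void UnsortedSegmentSumRef(const std::vector<T> &input, const std::vector<S> &ids,
                           size_t input_dim0, size_t input_dim1, size_t output_dim0,
                           std::vector<T> *output) {
  output->assign(output_dim0 * input_dim1, T(0));
  for (size_t j = 0; j < input_dim0; ++j) {
    S i = ids[j];
    if (i < 0 || static_cast<size_t>(i) >= output_dim0) continue;
    for (size_t k = 0; k < input_dim1; ++k) {
      (*output)[static_cast<size_t>(i) * input_dim1 + k] += input[j * input_dim1 + k];
    }
  }
}
```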
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ - -#include -#include "device/gpu/cuda_common.h" - -template -void UnsortedSegmentSum(size_t input_dim0, size_t input_dim1, size_t output_dim0, size_t output_dim1, - T* input_addr, S* ids, T* output_addr, cudaStream_t stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNSORT_SEGMENT_SUM_H_ diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc deleted file mode 100644 index 777310cebc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/data/dataset_init_kernel.h" -#include "kernel/gpu/data/dataset_utils.h" -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_memory_allocator.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::GpuBufferMgr; - -DatasetInitKernel::DatasetInitKernel() : total_bytes_(0) {} - -const std::vector &DatasetInitKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &DatasetInitKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &DatasetInitKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool DatasetInitKernel::Init(const CNodePtr &kernel_node) { - queue_name_ = GetAttr(kernel_node, "queue_name"); - auto shapes = GetAttr>>(kernel_node, "shapes"); - auto types = GetAttr>(kernel_node, "types"); - if (shapes.size() != types.size()) { - MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; - } - - for (size_t i = 0; i < shapes.size(); i++) { - int unit = UnitSizeInBytes(types[i]->type_id()); - int nums = ElementNums(shapes[i]); - int bytes = unit * nums; - shapes_.push_back(bytes); - total_bytes_ += bytes; - } - return true; -} - -void DatasetInitKernel::InitSizeLists() { return; } - -bool DatasetInitKernel::Launch(const std::vector &, const std::vector &, - const std::vector &, void *) { - void *addr = nullptr; - size_t len = total_bytes_ * buffer_q_capacity_; - - if (!device::gpu::GPUMemoryAllocator::GetInstance().AllocBufferQueueMem(len, &addr)) { - MS_LOG(EXCEPTION) << "Memory not enough: failed to allocate GPU buffer queue memory[" << len << "]."; - } - - auto status = GpuBufferMgr::GetInstance().Create(0, queue_name_, addr, shapes_, buffer_q_capacity_); - if (status) { - MS_LOG(EXCEPTION) << "Init Dataset Failed. 
len: " << len << ", status:" << status; - } - - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h b/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h deleted file mode 100644 index 318049f4ad..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_init_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_DATASET_INIT_KERNEL_H -#define MINDSPORE_DATASET_INIT_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DatasetInitKernel : public GpuKernel { - public: - DatasetInitKernel(); - ~DatasetInitKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - std::string queue_name_; - std::vector shapes_; - size_t total_bytes_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - // The capacity of buffer Q. - size_t buffer_q_capacity_{2}; -}; - -MS_REG_GPU_KERNEL(InitDataSetQueue, DatasetInitKernel) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc deleted file mode 100644 index 13ca191b0b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/data/dataset_iterator_kernel.h" -#include -#include -#include -#include "device/gpu/gpu_buffer_mgr.h" -#include "device/gpu/gpu_common.h" -#include "kernel/gpu/data/dataset_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::GpuBufferMgr; -using mindspore::device::HandleMgr; - -DatasetIteratorKernel::DatasetIteratorKernel() : handle_(HandleMgr::INVALID_HANDLE), total_bytes_(0) {} - -DatasetIteratorKernel::~DatasetIteratorKernel() { GpuBufferMgr::GetInstance().Close(handle_); } - -const std::vector &DatasetIteratorKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &DatasetIteratorKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &DatasetIteratorKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool DatasetIteratorKernel::Init(const CNodePtr &kernel_node) { - queue_name_ = GetAttr(kernel_node, "shared_name"); - auto shapes = GetAttr>>(kernel_node, "shapes"); - auto types = GetAttr>(kernel_node, "types"); - if (shapes.size() != types.size()) { - MS_LOG(EXCEPTION) << "Invalid shapes: " << shapes << ", types: " << types; - } - - for (size_t i = 0; i < shapes.size(); i++) { - int unit = UnitSizeInBytes(types[i]->type_id()); - int nums = ElementNums(shapes[i]); - int bytes = unit * nums; - output_size_list_.push_back(bytes); - total_bytes_ += bytes; - } - - handle_ = GpuBufferMgr::GetInstance().Open(0, queue_name_, output_size_list_); - if (handle_ == HandleMgr::INVALID_HANDLE) { - MS_LOG(EXCEPTION) << "Gpu Queue(" << queue_name_ << ") Open Failed"; - } - - return true; -} - -void DatasetIteratorKernel::InitSizeLists() { return; } - -bool DatasetIteratorKernel::Launch(const std::vector &, const std::vector &, - const std::vector &outputs, void *stream) { - void *addr = nullptr; - size_t len = 0; - - int repeat = 0; - while (true) { - auto ret = GpuBufferMgr::GetInstance().Front(handle_, &addr, &len); - if (ret == device::SUCCESS) { - break; - } - - if (ret == device::TIMEOUT) { - repeat++; - if (repeat < 10) { - MS_LOG(INFO) << "Waiting for data...(" << repeat << " / 10)"; - continue; - } else { - MS_LOG(ERROR) << "Get data timeout"; - return false; - } - } - - MS_LOG(ERROR) << "Get data failed, errcode " << ret; - return false; - } - - if (total_bytes_ != len) { - MS_LOG(ERROR) << "Dataset front error. read: " << len << ", expect: " << total_bytes_ << ", "; - return false; - } - - for (size_t i = 0; i < output_size_list_.size(); i++) { - void *output_addr = GetDeviceAddress(outputs, i); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(output_addr, addr, output_size_list_[i], cudaMemcpyDeviceToDevice, - reinterpret_cast(stream)), - "Cuda Memcpy Failed"); - addr = reinterpret_cast(addr) + output_size_list_[i]; - } - - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream)), - "cudaStreamSynchronize failed"); - (void)GpuBufferMgr::GetInstance().Pop(handle_); - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h b/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h deleted file mode 100644 index cdd7a47e7b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_iterator_kernel.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
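
DatasetIteratorKernel::Launch, shown above, polls the buffer manager's Front with a bounded retry: TIMEOUT is retried up to ten times with a progress log, any other failure aborts immediately, and success is followed by per-tensor async device copies and a Pop. A self-contained C++ sketch of that polling shape; FetchFront is a stub standing in for GpuBufferMgr::Front, and the status names mirror the deleted code's device::SUCCESS/TIMEOUT:

```cpp
#include <cstdio>

enum Status { SUCCESS, TIMEOUT, ERROR };

// Stub standing in for GpuBufferMgr::Front: times out twice, then delivers.
static int g_calls = 0;
Status FetchFront(void **addr, size_t *len) {
  static char buf[16];
  if (++g_calls < 3) return TIMEOUT;
  *addr = buf;
  *len = sizeof(buf);
  return SUCCESS;
}

// Bounded-retry poll mirroring DatasetIteratorKernel::Launch: retry only on
// TIMEOUT, at most 10 times; anything else is a hard failure.
bool PollFront(void **addr, size_t *len) {
  for (int repeat = 1; repeat <= 10; ++repeat) {
    Status ret = FetchFront(addr, len);
    if (ret == SUCCESS) return true;
    if (ret != TIMEOUT) return false;
    std::printf("Waiting for data...(%d / 10)\n", repeat);
  }
  return false;
}

int main() {
  void *addr = nullptr;
  size_t len = 0;
  return PollFront(&addr, &len) ? 0 : 1;
}
```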
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_GET_NEXT_KERNEL_H -#define MINDSPORE_GET_NEXT_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class DatasetIteratorKernel : public GpuKernel { - public: - DatasetIteratorKernel(); - ~DatasetIteratorKernel(); - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - std::string queue_name_; - unsigned int handle_; - size_t total_bytes_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -MS_REG_GPU_KERNEL(GetNext, DatasetIteratorKernel) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_QUEUE_CPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc b/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc deleted file mode 100644 index 846a63f84f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/data/dataset_utils.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/data/dataset_utils.h" - -namespace mindspore { -namespace kernel { -size_t UnitSizeInBytes(const mindspore::TypeId &t) { - size_t bytes = 0; - switch (t) { - case kNumberTypeBool: - case kNumberTypeInt8: - case kNumberTypeUInt8: - bytes = 1; - break; - case kNumberTypeInt16: - case kNumberTypeUInt16: - case kNumberTypeFloat16: - bytes = 2; - break; - case kNumberTypeInt: - case kNumberTypeUInt: - case kNumberTypeInt32: - case kNumberTypeUInt32: - case kNumberTypeFloat: - case kNumberTypeFloat32: - bytes = 4; - break; - case kNumberTypeUInt64: - case kNumberTypeInt64: - case kNumberTypeFloat64: - bytes = 8; - break; - default: - MS_LOG(EXCEPTION) << "Invalid types " << t; - break; - } - - return bytes; -} - -int ElementNums(const std::vector &shape) { - if (shape.size() == 0) { - return 0; - } - - int nums = 1; - for (size_t i = 0; i < shape.size(); i++) { - nums *= shape[i]; - } - - return nums; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/gpu_kernel.h deleted file mode 100644 index c935798f06..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ - -#include -#include -#include -#include -#include "kernel/kernel.h" -#include "kernel/gpu/kernel_constants.h" -#include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_common.h" -#include "session/anf_runtime_algorithm.h" -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; - -namespace mindspore { -namespace kernel { -class GpuKernel : public KernelMod { - public: - virtual ~GpuKernel() = default; - virtual bool Init(const CNodePtr &kernel_node) = 0; - - protected: - virtual void InitResource() {} - virtual void InitSizeLists() = 0; - - template - inline T *GetDeviceAddress(const std::vector &addr_list, size_t index) { - if (index >= addr_list.size()) { - MS_LOG(EXCEPTION) << "Address index(" << index << ") out of range(" << addr_list.size() << ")"; - } - // Kernels may run normally without workspace, the addr_list[index] maybe nullptr. 
- if ((addr_list[index] == nullptr) || (addr_list[index]->size == 0)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(addr_list[index]->addr); - return reinterpret_cast(addr_list[index]->addr); - } - - template - inline T GetAttr(const CNodePtr &kernel_node, const std::string &key) const { - const PrimitivePtr &prim = AnfAlgo::GetCNodePrimitive(kernel_node); - const ValuePtr &attr = prim->GetAttr(key); - if (attr == nullptr) { - const std::string &prim_name = AnfAlgo::GetCNodeName(kernel_node); - MS_LOG(EXCEPTION) << "The attr(" << key << ") of kernel(" << prim_name << ") not exist"; - } - return GetValue(attr); - } - // expand Nd Shape to 4d (N in [0,4]) - void ShapeNdTo4d(const std::vector &src, std::vector *dst) { - if (src.size() > 4) { - MS_EXCEPTION(ValueError) << src.size() << "-D data is not supported!"; - } - dst->push_back(src.size() < 4 ? 1 : SizeToInt(src[src.size() - 4])); - dst->push_back(src.size() < 3 ? 1 : SizeToInt(src[src.size() - 3])); - dst->push_back(src.size() < 2 ? 1 : SizeToInt(src[src.size() - 2])); - dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1])); - } - - inline void CheckBroadcast4TensorOp(const std::vector &A, const std::vector &B, - const std::vector &Out) { - if (A != Out && B != Out) { - MS_EXCEPTION(ValueError) - << "Double-sided broadcast was not supported in cudnn of cudnnOpTensor:\n" - "InputA must match the corresponding dimension of the destination tensor outC, and each " - "dimension of the inputB " - "must match the corresponding dimension of outC or must be equal to 1."; - } - } - - // choose the suitable datatype for cudnn/cublas - inline cudnnDataType_t GetCudnnDataType(const std::string &Type) { - auto type = kCudnnDtypeMap.find(Type); - if (type == kCudnnDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; - } - inline cudaDataType_t GetCudaDataType(const std::string &Type) { - auto type = kCudaDtypeMap.find(Type); - if (type == kCudaDtypeMap.end()) { - MS_EXCEPTION(TypeError) << Type << " is not supported."; - } - return type->second; - } -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc deleted file mode 100644 index b00b5c263d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
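
Among the gpu_kernel.h helpers above, ShapeNdTo4d is the one every cuDNN-backed kernel leans on: shapes of rank below four are left-padded with 1s so an NCHW descriptor can always be built, and anything past 4-D raises ValueError. A host-only C++ copy with a quick self-check (assert replaces the MS_EXCEPTION path):

```cpp
#include <cassert>
#include <vector>

// CPU copy of the deleted GpuKernel::ShapeNdTo4d helper: left-pad shapes of
// rank < 4 with 1s so the result is always a 4-tuple usable as NCHW.
void ShapeNdTo4d(const std::vector<size_t> &src, std::vector<int> *dst) {
  assert(src.size() <= 4);  // the original raises ValueError past 4-D
  dst->push_back(src.size() < 4 ? 1 : static_cast<int>(src[src.size() - 4]));
  dst->push_back(src.size() < 3 ? 1 : static_cast<int>(src[src.size() - 3]));
  dst->push_back(src.size() < 2 ? 1 : static_cast<int>(src[src.size() - 2]));
  dst->push_back(src.empty() ? 1 : static_cast<int>(src[src.size() - 1]));
}

int main() {
  std::vector<int> dst;
  ShapeNdTo4d({5, 7}, &dst);                       // 2-D input...
  assert((dst == std::vector<int>{1, 1, 5, 7}));   // ...becomes 1x1x5x7
  return 0;
}
```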
- */ - -#include "kernel/gpu/gpu_kernel_factory.h" - -#include -#include - -#include "common/utils.h" -#include "device/kernel_info.h" -#include "device/gpu/cuda_common.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -GpuKernelFactory &GpuKernelFactory::GetInstance() { - static GpuKernelFactory instance; - return instance; -} - -void GpuKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr, - GpuKernelCreater &&creater) { - map_kernel_name_to_creater_[kernel_name].emplace_back(kernel_attr, creater); -} - -void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, - std::vector> *iter_second, - size_t attr_index) { - if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddInputAttr(dtype); - } - } else { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; - } - } - if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { - if (iter_second->at(attr_index).first.GetAllSame()) { - auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; - for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); - } - } else { - MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; - } - } -} - -std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name) { - std::string type_lists = ""; - auto iter = map_kernel_name_to_creater_.find(kernel_name); - if (map_kernel_name_to_creater_.end() == iter) { - return type_lists; - } - for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - std::string type_list = "in["; - auto attr = (iter->second)[attr_index].first; - for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) { - type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) + - ((input_index == (attr.GetInputSize() - 1)) ? "" : " "); - } - type_list = type_list + "], out["; - for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) { - type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) + - ((input_index == (attr.GetOutputSize() - 1)) ? 
"" : " "); - } - type_lists = type_lists + type_list + "]; "; - } - return type_lists; -} - -std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string &kernel_name, - const KernelBuildInfo *kernel_info) { - auto iter = map_kernel_name_to_creater_.find(kernel_name); - const int marjor_sm = GET_MAJOR_SM; - if (map_kernel_name_to_creater_.end() == iter) { - MS_LOG(INFO) << "Not registered GPU kernel: op[" << kernel_name << "]!"; - return std::make_pair(false, 0); - } - if ((iter->second).size() == 1 && (iter->second)[0].first.GetInputSize() == 0) { - return std::make_pair(true, 0); - } - - for (size_t attr_index = 0; attr_index < (iter->second).size(); ++attr_index) { - CheckIOParam(kernel_name, kernel_info, &(iter->second), attr_index); - bool flag = true; - // data type matching check of all input parameters of kernel - for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { - if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { - if (marjor_sm < MINIUM_SM) { - MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM - << ", but the current device's computing capacity is " << marjor_sm; - } - MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM - << ", but the current device's computing capacity is " << marjor_sm; - } - if (kernel_info->GetInputDeviceType(input_index) != - (iter->second)[attr_index].first.GetInputAttr(input_index).first) { - flag = false; - break; - } - } - if (!flag) { - continue; - } - // data type matching check of all output parameters of kernel - for (size_t output_index = 0; output_index < kernel_info->GetOutputNum(); output_index++) { - if (kernel_info->GetOutputDeviceType(output_index) != - (iter->second)[attr_index].first.GetOutputAttr(output_index).first) { - flag = false; - break; - } - } - // finish data type matching check and return a pair maintain the whether matching is success, - // if first is true, second is index of matching KernelAttr and creater pair in vector; - if (flag) { - size_t match_index = attr_index; - return std::make_pair(true, match_index); - } - } - return std::make_pair(false, 0); -} - -GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(kernel_build_Info); - std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_Info); - if (ret_pair.first) { - return (map_kernel_name_to_creater_.find(kernel_name)->second)[ret_pair.second].second(); - } - return nullptr; -} - -bool GpuKernelFactory::SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_build_info) { - std::pair ret_pair = GpuKernelAttrCheck(kernel_name, kernel_build_info.get()); - return ret_pair.first; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h deleted file mode 100644 index dc5f61a315..0000000000 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ - -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "device/gpu/kernel_info_setter.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -using mindspore::device::gpu::KernelAttr; -using GpuKernelCreater = std::function; -class GpuKernelFactory { - public: - ~GpuKernelFactory() = default; - - static GpuKernelFactory &GetInstance(); - - void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater); - - GpuKernel *Create(const std::string &kernel_name, const CNodePtr &apply_kernel); - - bool SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_info); - - std::string SupportedTypeList(const std::string &kernel_name); - - private: - GpuKernelFactory() = default; - - GpuKernelFactory(GpuKernelFactory const &); - - GpuKernelFactory &operator=(const GpuKernelFactory &); - - std::pair GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); - void CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, - std::vector> *iter_second, size_t attr_index); - // map to maintain kernel and creater, KernelAttr object and creater must be registered as a pair. 
- std::map>> map_kernel_name_to_creater_; -}; - -class GpuKernelRegister { - public: - GpuKernelRegister(const std::string &kernel_name, const KernelAttr &kernel_attr, GpuKernelCreater &&creater) { - GpuKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(creater)); - } -}; - -#define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, KernelAttr(), []() { return new OPCLASS(); }); - -// regular register of fixed accuracy kernels -#define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \ - static_assert(std::is_base_of::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain one typename, ignore input num -#define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain one typename -#define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_gpu_kernel_reg(#OPNAME, ATTR, []() { return new OPCLASS(); }); - -// register of mixed accuracy kernels which use template and maintain two typename -#define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \ - static_assert(std::is_base_of>::value, " must be base of GpuKernel"); \ - static const GpuKernelRegister g_##OPNAME##_##T##_##S##_gpu_kernel_reg(#OPNAME, ATTR, \ - []() { return new OPCLASS(); }); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_GPUKERNELFACTORY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc deleted file mode 100644 index 4683f015ae..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.cc +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
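The registration macros above lost their template arguments in extraction. Judging from the registrar they construct, MS_REG_GPU_KERNEL_ONE for the float AddN kernel registered in the next file plausibly expands to roughly the following (a reconstruction; the angle-bracketed arguments are not verbatim from the diff):

// Plausible expansion of MS_REG_GPU_KERNEL_ONE(AddN, <attr>, AddNGpuFwdKernel, float):
static_assert(std::is_base_of<GpuKernel, AddNGpuFwdKernel<float>>::value, " must be base of GpuKernel");
static const GpuKernelRegister g_AddN_float_gpu_kernel_reg(
  "AddN", KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  []() { return new AddNGpuFwdKernel<float>(); });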
- */ - -#include "kernel/gpu/math/addn_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AddNGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - AddN, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AddNGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(AddN, - KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AddNGpuFwdKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h deleted file mode 100644 index 41930d3d7b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/addn_gpu_kernel.h +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/math/broadcast_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/slice_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class AddNGpuFwdKernel : public GpuKernel { - public: - AddNGpuFwdKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0), - output_size_(0), - workspace_size_(0), - is_null_input_(false), - num_input_(0) {} - ~AddNGpuFwdKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *output_addr = GetDeviceAddress(outputs, 0); - if (cudnn_data_type_ == CUDNN_DATA_INT32) { - FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast(stream_ptr)); - } - const float alpha = 1; - const float beta = 0; - for (size_t i = 0; i < IntToSize(num_input_); i++) { - T *input_addr = GetDeviceAddress(inputs, i); - if (cudnn_data_type_ == CUDNN_DATA_INT32) { - NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr, - &(i > 0 ? 
alpha : beta), input_descriptor_, output_addr), - "cudnnAddTensor failed"); - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - num_input_ = GetAttr(kernel_node, "n"); - if (IntToSize(num_input_) != input_num) { - MS_LOG(ERROR) << "Input number is " << num_input_ << " in attr, but got " << input_num << "input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "AddNGpuFwdKernel input is null"; - InitSizeLists(); - return true; - } - for (size_t i = input_shape.size(); i < 4; i++) { - (void)input_shape.insert(input_shape.begin(), 1); - } - int dimA[4]; - for (size_t i = 0; i < input_shape.size(); i++) { - dimA[i] = SizeToInt(input_shape[i]); - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - SizeToInt(input_shape.size()), dimA), - "cudnnSetTensorNdDescriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - for (int i = 0; i < num_input_; i++) { - input_size_list_.push_back(input_size_); - } - output_size_list_.push_back(input_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnDataType_t cudnn_data_type_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - bool is_null_input_; - int num_input_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ADDN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc deleted file mode 100644 index 2ae1728ca3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
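The `&(i > 0 ? alpha : beta)` argument in the AddN Launch above is an accumulation trick: cudnnAddTensor computes C = alpha*A + blend*C, so the first input overwrites the uninitialized output (blend = beta = 0) and every later input accumulates into it (blend = alpha = 1). A scalar analogue of the loop:

#include <vector>

// Scalar analogue of the AddN launch loop above (illustrative only).
float AddN(const std::vector<float> &inputs) {
  const float alpha = 1.0f, beta = 0.0f;
  float out = 0.0f;  // contents irrelevant; the first pass overwrites it
  for (size_t i = 0; i < inputs.size(); ++i) {
    const float blend = (i > 0) ? alpha : beta;
    out = alpha * inputs[i] + blend * out;  // C = alpha*A + blend*C
  }
  return out;
}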
- */ - -#include "kernel/gpu/math/assign_add_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AssignAddGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE( - AssignAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AssignAddGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - AssignAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AssignAddGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h deleted file mode 100644 index db69fd7be6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/assign_add_gpu_kernel.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/assign_add_impl.cuh" -namespace mindspore { -namespace kernel { -template -class AssignAddGpuFwdKernel : public GpuKernel { - public: - AssignAddGpuFwdKernel() : is_null_input_(false), input_size_(0) {} - ~AssignAddGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *input_addr2 = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - CalAssignAdd(input_size_ / sizeof(T), input_addr, input_addr2, output_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "AssignAddGpuFwdKernel input is null"; - InitSizeLists(); - return true; - } - input_size_ = sizeof(T); - for (size_t i 
: input_shape) { - input_size_ = i * input_size_; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGNADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc deleted file mode 100644 index 5684f0c424..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/bias_add_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - BiasAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BiasAddGpuKernel, float16) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h deleted file mode 100644 index 5a664db2e1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/bias_add_gpu_kernel.h +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
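CalAssignAdd above launches a CUDA kernel; a host-side reference of what it computes, under the assumption that AssignAdd both writes the sum to the output and updates its first (variable) operand in place:

#include <cstddef>

// Host-side reference for CalAssignAdd (assumed semantics:
// out[i] = in1[i] + in2[i], with in1 also updated in place).
template <typename T>
void AssignAddRef(size_t n, T *in1, const T *in2, T *out) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = in1[i] + in2[i];
    in1[i] = out[i];  // AssignAdd mutates its variable operand
  }
}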
- */ - -#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H -#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H -#include -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class BiasAddGpuKernel : public GpuKernel { - public: - BiasAddGpuKernel() - : cudnn_handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - x_desc_(nullptr), - b_desc_(nullptr), - op_desc_(nullptr), - is_null_input_(false) {} - ~BiasAddGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - - T *x_addr = GetDeviceAddress(inputs, 0); - T *b_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - - try { - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_, - b_addr, &beta, x_desc_, output_addr), - "cudnnOpTensor failed"); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor"; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto num_dims = x_shape.size(); - is_null_input_ = CHECK_NULL_INPUT(x_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "input is null"; - InitSizeLists(); - return true; - } - - if (num_dims < 2) { - MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; - } - - std::string format = GetAttr(kernel_node, "data_format"); - string::size_type pos = format.find("C"); - if (pos == std::string::npos || pos >= num_dims) { - MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; - } - - // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. - auto cudnn_dims = std::max(num_dims, 4UL); - std::unique_ptr x_dims = std::make_unique(cudnn_dims); - std::unique_ptr b_dims = std::make_unique(cudnn_dims); - for (size_t i = 0; i < cudnn_dims; i++) { - x_dims[i] = (i < num_dims) ? SizeToInt(x_shape[i]) : 1; - b_dims[i] = (i == pos) ? 
SizeToInt(x_shape[i]) : 1; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(b_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), - "cudnnSetOpTensorDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&b_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&op_desc_), "cudnnCreateOpTensorDescriptor failed"); - } - void InitSizeLists() override { - size_t x_size, b_size; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &x_size), "cudnnGetTensorSizeInBytes failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(b_desc_, &b_size), "cudnnGetTensorSizeInBytes failed."); - input_size_list_.push_back(x_size); - input_size_list_.push_back(b_size); - output_size_list_.push_back(x_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(op_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(b_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyOpTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnDataType_t cudnn_data_type_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t b_desc_; - cudnnOpTensorDescriptor_t op_desc_; - bool is_null_input_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_BIAS_ADD_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc deleted file mode 100644 index 96d51b704c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
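BiasAdd above builds the bias descriptor so that every dimension is 1 except the channel axis located via data_format; cudnnOpTensor then broadcasts the bias across the remaining axes. A standalone sketch of that descriptor-shape computation (helper name invented for illustration):

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Compute cudnn-style x/bias dims for BiasAdd: the bias shape is 1 everywhere
// except the channel axis, whose position comes from data_format (e.g. "NCHW").
std::pair<std::vector<int>, std::vector<int>> BiasAddDims(const std::vector<size_t> &x_shape,
                                                          const std::string &format) {
  const size_t pos = format.find('C');
  if (pos == std::string::npos || pos >= x_shape.size()) {
    throw std::invalid_argument("format '" + format + "' invalid");
  }
  const size_t cudnn_dims = std::max<size_t>(x_shape.size(), 4);  // cudnn wants >= 4 dims
  std::vector<int> x_dims(cudnn_dims, 1), b_dims(cudnn_dims, 1);
  for (size_t i = 0; i < cudnn_dims; ++i) {
    if (i < x_shape.size()) x_dims[i] = static_cast<int>(x_shape[i]);
    if (i == pos) b_dims[i] = static_cast<int>(x_shape[i]);
  }
  return {x_dims, b_dims};
}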
- */ - -#include "kernel/gpu/math/broadcast_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -// fp32 -MS_REG_GPU_KERNEL_TWO( - Greater, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, float, bool) -MS_REG_GPU_KERNEL_TWO( - Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, float, bool) -MS_REG_GPU_KERNEL_TWO( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Minimum, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - RealDiv, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO( - TensorAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGpuKernel, float, float) - -// fp16 -MS_REG_GPU_KERNEL_TWO( - Greater, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, half, bool) -MS_REG_GPU_KERNEL_TWO( - Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - BroadcastOpGpuKernel, half, bool) -MS_REG_GPU_KERNEL_TWO( - Maximum, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Minimum, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - RealDiv, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO( - TensorAdd, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BroadcastOpGpuKernel, half, half) - -// int32 -MS_REG_GPU_KERNEL_TWO( - TensorAdd, 
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -MS_REG_GPU_KERNEL_TWO( - Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - BroadcastOpGpuKernel, int, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h deleted file mode 100644 index be7d3a19d4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/broadcast_impl.cuh" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template -class BroadcastOpGpuKernel : public GpuKernel { - public: - BroadcastOpGpuKernel() - : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} - ~BroadcastOpGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *lhs = GetDeviceAddress(inputs, 0); - T *rhs = GetDeviceAddress(inputs, 1); - S *output = GetDeviceAddress(outputs, 0); - - if (need_broadcast_) { - Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2], - rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs, - rhs, output, reinterpret_cast(stream_ptr)); - } else { - NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast(stream_ptr)); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); - need_broadcast_ = IsBroadcast(shape1, shape2); - if (need_broadcast_ && shape1.size() > 4) { - 
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; - } - - for (size_t i = 0; i < shape3.size(); i++) { - output_shape_[i] = shape3[i]; - output_num_ *= shape3[i]; - } - int lhs_offset = shape3.size() - shape1.size(); - for (size_t j = 0; j < shape1.size(); j++) { - lhs_shape_[j + lhs_offset] = shape1[j]; - input1_num_ *= shape1[j]; - } - int rhs_offset = shape3.size() - shape2.size(); - for (size_t k = 0; k < shape2.size(); k++) { - rhs_shape_[k + rhs_offset] = shape2[k]; - input2_num_ *= shape2[k]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { return; } - void InitSizeLists() override { - input_size_list_.push_back(input1_num_ * sizeof(T)); - input_size_list_.push_back(input2_num_ * sizeof(T)); - output_size_list_.push_back(output_num_ * sizeof(S)); - } - - private: - void GetOpType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - - static std::map kBroadcastTypeMap = { - {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, - {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, - {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, - }; - - auto iter = kBroadcastTypeMap.find(kernel_name); - if (iter == kBroadcastTypeMap.end()) { - MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; - } else { - op_type_ = iter->second; - } - } - - bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { - if (lhs.size() != rhs.size()) { - return true; - } - for (size_t i = 0; i < lhs.size(); i++) { - if (lhs[i] != rhs[i]) { - return true; - } - } - return false; - } - - BroadcastOpType op_type_; - bool need_broadcast_; - int input1_num_; - int input2_num_; - int output_num_; - int lhs_shape_[4] = {1, 1, 1, 1}; - int rhs_shape_[4] = {1, 1, 1, 1}; - int output_shape_[4] = {1, 1, 1, 1}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc deleted file mode 100644 index 85598cf940..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/math/broadcast_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MinimumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MaximumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BroadcastOpGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MinimumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - BroadcastOpGradGpuKernel, int) -MS_REG_GPU_KERNEL_ONE(MaximumGrad, - KernelAttr() - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeInt32), - BroadcastOpGradGpuKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h deleted file mode 100644 index f1eb5fecf9..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/broadcast_grad_gpu_kernel.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BROADCAST_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/broadcast_grad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template -class BroadcastOpGradGpuKernel : public GpuKernel { - public: - BroadcastOpGradGpuKernel() - : op_type_(BROADCAST_GRAD_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {} - ~BroadcastOpGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *x1 = GetDeviceAddress(inputs, 0); - T *x2 = GetDeviceAddress(inputs, 1); - T *dy = GetDeviceAddress(inputs, 2); - T *dx1 = GetDeviceAddress(outputs, 0); - T *dx2 = GetDeviceAddress(outputs, 1); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx1, 0, outputs[0]->size, reinterpret_cast(stream_ptr)), - "cudaMemSet Failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemsetAsync(dx2, 0, outputs[1]->size, reinterpret_cast(stream_ptr)), - "cudaMemSet Failed"); - if (need_broadcast_) { - BroadcastGrad(x1_shape_[0], x1_shape_[1], x1_shape_[2], x1_shape_[3], x2_shape_[0], x2_shape_[1], x2_shape_[2], - x2_shape_[3], dy_shape_[0], dy_shape_[1], dy_shape_[2], dy_shape_[3], op_type_, x1, x2, dy, dx1, - dx2, reinterpret_cast(stream_ptr)); - } else { - NoBroadcastGrad(output_num_, op_type_, x1, x2, dy, dx1, dx2, reinterpret_cast(stream_ptr)); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - need_broadcast_ = IsBroadcast(shape1, shape2); - if (need_broadcast_ && shape1.size() > 4) { - MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4"; - } - - for (size_t i = 0; i < shape3.size(); i++) { - dy_shape_[i] = shape3[i]; - output_num_ *= shape3[i]; - } - int x1_offset = shape3.size() - shape1.size(); - for (size_t i = 0; i < shape1.size(); i++) { - x1_shape_[i + x1_offset] = shape1[i]; - input1_num_ *= shape1[i]; - } - int x2_offset = shape3.size() - shape2.size(); - for (size_t i = 0; i < shape2.size(); i++) { - x2_shape_[i + x2_offset] = shape2[i]; - input2_num_ *= shape2[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { return; } - void InitSizeLists() override { - input_size_list_.push_back(input1_num_ * sizeof(T)); - input_size_list_.push_back(input2_num_ * sizeof(T)); - input_size_list_.push_back(output_num_ * sizeof(T)); - output_size_list_.push_back(input1_num_ * sizeof(T)); - output_size_list_.push_back(input2_num_ * sizeof(T)); - } - - private: - void GetOpType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - - static std::map kBroadcastTypeMap = { - {"MaximumGrad", BROADCAST_GRAD_TYPE_MAXIMUM}, - {"MinimumGrad", BROADCAST_GRAD_TYPE_MINIMUM}, - }; - - auto iter = kBroadcastTypeMap.find(kernel_name); - if 
(iter == kBroadcastTypeMap.end()) { - MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported."; - } else { - op_type_ = iter->second; - } - } - - bool IsBroadcast(const std::vector &lhs, const std::vector &rhs) { - if (lhs.size() != rhs.size()) { - return true; - } - for (size_t i = 0; i < lhs.size(); i++) { - if (lhs[i] != rhs[i]) { - return true; - } - } - return false; - } - - BroadcastGradOpType op_type_; - bool need_broadcast_; - int input1_num_; - int input2_num_; - int output_num_; - int x1_shape_[4] = {1, 1, 1, 1}; - int x2_shape_[4] = {1, 1, 1, 1}; - int dy_shape_[4] = {1, 1, 1, 1}; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BINARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc deleted file mode 100644 index f3c3b6164d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/equalcount_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - EqualCountGpuKernel, int) -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EqualCountGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - EqualCount, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - EqualCountGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h deleted file mode 100644 index 7d3f74970f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/equalcount_gpu_kernel.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
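EqualCount, registered above and defined in the header that follows, reduces two same-shaped tensors to a single element: the number of positions where they agree (note the header's output_size_ of one element). A host-side reference of those semantics:

#include <cstddef>

// Host-side reference for EqualCount: out[0] = number of positions where the
// two inputs are equal (the GPU version performs this reduction in CalEqualCount).
template <typename T>
void EqualCountRef(size_t n, const T *a, const T *b, T *out) {
  T count = T(0);
  for (size_t i = 0; i < n; ++i) {
    if (a[i] == b[i]) count += T(1);
  }
  out[0] = count;
}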
- */ - -#ifndef MINDSPORE_EQUALCOUNT_GPU_KERNEL_H -#define MINDSPORE_EQUALCOUNT_GPU_KERNEL_H - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/equalcount_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class EqualCountGpuKernel : public GpuKernel { - public: - EqualCountGpuKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} - ~EqualCountGpuKernel() = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input1 = GetDeviceAddress(inputs, 0); - T *input2 = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); - int size = SizeToInt(input_size_ / sizeof(T)); - CalEqualCount(size, input1, input2, output, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but equalcount needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but equalcount needs 1 output."; - return false; - } - - output_size_ = sizeof(T); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - return; - } - - private: - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc deleted file mode 100644 index 374644eaf5..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/math/float_status_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), - FloatStatusGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h deleted file mode 100644 index 1aa9b18684..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H - -#include -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/float_status_impl.cuh" - -namespace mindspore { -namespace kernel { -enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; -static const std::map kOpTypeMap = { - {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; -template -class FloatStatusGpuKernel : public GpuKernel { - public: - FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} - ~FloatStatusGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - - switch (kernel_name_) { - case OP_STATUS: { - T *output = GetDeviceAddress(outputs, 0); - CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_INF: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_NAN: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - case OP_FINITE: { - bool *output = GetDeviceAddress(outputs, 0); - CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); - break; - } - default: { - MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; - } - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t x : shape) { - input_size_ = input_size_ * x; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kOpTypeMap.find(kernel_name); - if (iter == kOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; - } else { - kernel_name_ = iter->second; - } - if (kernel_name_ == OP_STATUS) { - output_size_ = sizeof(T); - } else { - output_size_ = input_size_ / sizeof(T) * sizeof(bool); - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; - return false; - } - return true; - } - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - Optype kernel_name_; - size_t input_size_; - size_t output_size_; -}; -} // namespace kernel -} // namespace 
mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc deleted file mode 100644 index 808d599853..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/matmul_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - MatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - MatMulGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - BatchMatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - MatMulGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - BatchMatMul, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - MatMulGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h deleted file mode 100644 index 3ee3493ed6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
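The MatMul kernel defined next feeds cuBLAS, which is column-major, by computing C^T = B^T * A^T: note in its Launch how the operand order and (n, m, k) are swapped. The shape and stride bookkeeping reduces to the plain row-major reference below (illustrative only):

#include <cstddef>
#include <vector>

// Reference semantics for the strided-batched MatMul set up in Launch below:
// for each batch b, C[b] = A[b] (m x k) * B[b] (k x n), all row-major, no transposes.
void BatchedMatMulRef(size_t batch, size_t m, size_t n, size_t k,
                      const std::vector<float> &A, const std::vector<float> &B, std::vector<float> *C) {
  for (size_t b = 0; b < batch; ++b) {
    const float *a = &A[b * m * k];    // stride_a = m * k
    const float *bp = &B[b * k * n];   // stride_b = k * n
    float *c = &(*C)[b * m * n];       // stride_c = m * n
    for (size_t i = 0; i < m; ++i)
      for (size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (size_t p = 0; p < k; ++p) acc += a[i * k + p] * bp[p * n + j];
        c[i * n + j] = acc;
      }
  }
}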
- */ - -#ifndef MINDSPORE_MATMUL_GPU_KERNEL_H -#define MINDSPORE_MATMUL_GPU_KERNEL_H - -#include <cublas_v2.h> -#include <cuda_runtime_api.h> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class MatMulGpuKernel : public GpuKernel { - public: - MatMulGpuKernel() - : batch_(0), - m_(0), - n_(0), - k_(0), - is_null_input_(false), - transpose_x1_(CUBLAS_OP_N), - transpose_x2_(CUBLAS_OP_N), - handle_(nullptr), - dtype_a_(CUDA_R_32F), - dtype_b_(CUDA_R_32F), - dtype_c_(CUDA_R_32F), - algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} - ~MatMulGpuKernel() = default; - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - auto input1_addr = GetDeviceAddress<T>(inputs, 0); - auto input2_addr = GetDeviceAddress<T>(inputs, 1); - auto output_addr = GetDeviceAddress<T>(outputs, 0); - - const float alpha = 1; - const float beta = 0; - const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_); - const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_); - const int ldc = n_; - - auto stride_a = SizeToInt(m_ * k_); - auto stride_b = SizeToInt(k_ * n_); - auto stride_c = SizeToInt(m_ * n_); - - try { - CHECK_CUBLAS_RET_WITH_EXCEPT( - cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), - &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a, - &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_), - "cublasGemmStridedBatchedEx call failed"); - } catch (const std::exception &e) { - MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublasGemmStridedBatchedEx"; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle(); - dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); - dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(output_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "input is null"; - InitSizeLists(); - return true; - } - auto dims = output_shape.size(); - if (dims < 2) { - MS_LOG(EXCEPTION) << "Output dims " << dims << " not support."; - } - - m_ = output_shape[dims - 2]; - n_ = output_shape[dims - 1]; - batch_ = 1; - for (size_t i = 0; i < dims - 2; i++) { - batch_ *= output_shape[i]; - } - - bool transpose = GetAttr<bool>(kernel_node, "transpose_x1"); - transpose_x1_ = transpose ? CUBLAS_OP_T : CUBLAS_OP_N; - auto input1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - k_ = transpose ? input1_shape[dims - 2] : input1_shape[dims - 1]; - - transpose = GetAttr<bool>(kernel_node, "transpose_x2"); - transpose_x2_ = transpose ?
CUBLAS_OP_T : CUBLAS_OP_N; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t unit_size = sizeof(T); - - size_t input_size = batch_ * m_ * k_ * unit_size; - input_size_list_.push_back(input_size); - - input_size = batch_ * n_ * k_ * unit_size; - input_size_list_.push_back(input_size); - - size_t output_size = batch_ * m_ * n_ * unit_size; - output_size_list_.push_back(output_size); - } - - private: - size_t batch_; - size_t m_; - size_t n_; - size_t k_; - bool is_null_input_; - - cublasOperation_t transpose_x1_; - cublasOperation_t transpose_x2_; - cublasHandle_t handle_; - cudaDataType_t dtype_a_; - cudaDataType_t dtype_b_; - cudaDataType_t dtype_c_; - cublasGemmAlgo_t algo_; - - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc deleted file mode 100644 index d54fe285c2..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/random_op_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - RandomOpGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h deleted file mode 100644 index 3767cd9fc8..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/random_op_gpu_kernel.h +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ - -#include <curand_kernel.h> -#include <cuda_runtime_api.h> -#include <map> -#include <string> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/random_op_impl.cuh" - -namespace mindspore { -namespace kernel { -enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 }; - -const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}}; -template <typename T> -class RandomOpGpuKernel : public GpuKernel { - public: - RandomOpGpuKernel() - : random_op_type_(RANDOM_OP_INVALID_TYPE), - input_size_0_(0), - output_size_(sizeof(T)), - workspace_size_(sizeof(curandState)) {} - ~RandomOpGpuKernel() override = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - void *workspace_addr = GetDeviceAddress<void>(workspace, 0); - curandState *devStates = reinterpret_cast<curandState *>(workspace_addr); - T *output_addr = GetDeviceAddress<T>(outputs, 0); - - switch (random_op_type_) { - case RANDOM_OP_NORMAL: { - StandardNormal(seed_, seed2_, devStates, output_addr, outputs[0]->size / sizeof(T), - reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - default: { - MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported."; - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kRandomOpTypeMap.find(kernel_name); - if (iter == kRandomOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "Random operation " << kernel_name << " is not supported."; - } else { - random_op_type_ = iter->second; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output."; - return false; - } - auto input_shape_0 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape_0.size(); i++) { - input_size_0_ += input_shape_0[i]; - } - input_size_0_ *= sizeof(int); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < output_shape.size(); i++) { - output_size_ *= output_shape[i]; - workspace_size_ *= output_shape[i]; - } - seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")); - seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_0_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); - } - - private: - RandomOptype random_op_type_; - size_t input_size_0_; - size_t output_size_; - size_t workspace_size_; - int seed_; - int seed2_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_RANDOMOP_GPU_KERNEL_H_ diff --git
a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc deleted file mode 100644 index 77f53fc417..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/math/unary_op_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Log, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - UnaryOpGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - UnaryOpGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h deleted file mode 100644 index 4503b805f6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ - -#include <cuda_runtime_api.h> -#include <map> -#include <string> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/unary_op_impl.cuh" - -namespace mindspore { -namespace kernel { -enum UnaryOptype { - UNARY_OP_EXP = 0, - UNARY_OP_LOG, - UNARY_OP_NEG, - UNARY_OP_RECIPROCAL, - UNARY_OP_ZEROSLIKE, - UNARY_OP_SQUARE, - UNARY_OP_SQRT, - UNARY_OP_RSQRT, - UNARY_OP_INVALID_TYPE = 255 -}; -static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, - {"Log", UNARY_OP_LOG}, - {"Neg", UNARY_OP_NEG}, - {"Reciprocal", UNARY_OP_RECIPROCAL}, - {"ZerosLike", UNARY_OP_ZEROSLIKE}, - {"Square", UNARY_OP_SQUARE}, - {"Sqrt", UNARY_OP_SQRT}, - {"Rsqrt", UNARY_OP_RSQRT}}; -template <typename T> -class UnaryOpGpuKernel : public GpuKernel { - public: - UnaryOpGpuKernel() - : unary_op_type_(UNARY_OP_INVALID_TYPE), - input_size_(sizeof(T)), - output_size_(sizeof(T)), - workspace_size_(0), - is_null_input_(false) {} - ~UnaryOpGpuKernel() override = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - T *input_addr = GetDeviceAddress<T>(inputs, 0); - T *output_addr = GetDeviceAddress<T>(outputs, 0); - - switch (unary_op_type_) { - case UNARY_OP_EXP: { - Exponential(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_LOG: { - Logarithm(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_NEG: { - Negative(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_RECIPROCAL: { - Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_SQUARE: { - Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_SQRT: { - Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_RSQRT: { - Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - break; - } - case UNARY_OP_ZEROSLIKE: { - Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - default: { - MS_LOG(EXCEPTION) << "Unary operation " << unary_op_type_ << " is not supported."; - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kUnaryOpTypeMap.find(kernel_name); - if (iter == kUnaryOpTypeMap.end()) { - MS_LOG(EXCEPTION) << "Unary operation " << kernel_name << " is not supported."; -
} else { - unary_op_type_ = iter->second; - } - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "UnaryOpGpuKernel input is null"; - InitSizeLists(); - return true; - } - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - output_size_ = input_size_; - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - UnaryOptype unary_op_type_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - bool is_null_input_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_UNARYOP_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc deleted file mode 100644 index 6993085a75..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "kernel/gpu/nccl/nccl_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - NcclGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - NcclGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h deleted file mode 100644 index b5ab46a67d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ - -#include <dlfcn.h> -#include <stdint.h> -#include <nccl.h> -#include <map> -#include <string> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "device/gpu/distribution/collective_init.h" - -namespace mindspore { -namespace kernel { -enum NcclKernelType { NCCL_ALL_REDUCE = 0, NCCL_ALL_GATHER, NCCL_REDUCE_SCATTER, NCCL_INVALID_TYPE = 255 }; -const std::map<std::string, NcclKernelType> kNcclTypeMap = { - {"AllReduce", NCCL_ALL_REDUCE}, - {"AllGather", NCCL_ALL_GATHER}, - {"ReduceScatter", NCCL_REDUCE_SCATTER}, -}; - -static std::map<std::string, ncclDataType_t> kNcclDtypeMap = { - {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; - -typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); -typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t); -typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); - -template <typename T> -class NcclGpuKernel : public GpuKernel { - public: - NcclGpuKernel() - : nccl_kernel_type_(NCCL_INVALID_TYPE), - nccl_reduce_type_(ncclSum), - input_size_(0), - output_size_(0), - collective_handle_(nullptr), - comm_stream_(nullptr) {} - ~NcclGpuKernel() override = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress<T>(inputs, 0); - T *output_addr = GetDeviceAddress<T>(outputs, 0); - - cudaStream_t stream = comm_stream_ ?
comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr); - switch (nccl_kernel_type_) { - case NCCL_ALL_REDUCE: { - auto all_reduce_funcptr = - reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce")); - MS_EXCEPTION_IF_NULL(all_reduce_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream), - "ncclAllReduce failed"); - break; - } - case NCCL_ALL_GATHER: { - auto all_gather_funcptr = - reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather")); - MS_EXCEPTION_IF_NULL(all_gather_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT( - (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream), - "ncclAllGather failed"); - break; - } - case NCCL_REDUCE_SCATTER: { - auto reduce_scatter_funcptr = - reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter")); - MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); - CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), - nccl_data_type_, nccl_reduce_type_, stream), - "ncclReduceScatter failed"); - break; - } - default: { - MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported."; - } - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - nccl_data_type_ = kNcclDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - for (size_t i = 0; i < input_num; ++i) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); - size_t size = sizeof(T); - for (size_t j = 0; j < shape.size(); j++) { - size *= IntToSize(shape[j]); - } - input_size_list_.push_back(size); - input_size_ += size; - } - for (size_t i = 0; i < output_num; ++i) { - auto shape = AnfAlgo::GetOutputInferShape(kernel_node, i); - size_t size = sizeof(T); - for (size_t j = 0; j < shape.size(); j++) { - size *= IntToSize(shape[j]); - } - output_size_list_.push_back(size); - output_size_ += size; - } - InferCommType(kernel_node); - collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); - MS_EXCEPTION_IF_NULL(collective_handle_); - - auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); - if (comm_stream_attr) { - comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr)); - MS_EXCEPTION_IF_NULL(comm_stream_); - } - return true; - } - - protected: - void InitSizeLists() override { return; } - - private: - void InferCommType(const CNodePtr &kernel_node) { - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kNcclTypeMap.find(kernel_name); - if (iter == kNcclTypeMap.end()) { - MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported."; - } else { - nccl_kernel_type_ = iter->second; - } - - auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); - if (reduce_op) { - std::string type = GetValue<std::string>(reduce_op); - if (type == "sum") { - nccl_reduce_type_ = ncclSum; - } else if (type == "max") { - nccl_reduce_type_ = ncclMax; - } else if (type == "min") { - nccl_reduce_type_ = ncclMin; - } else if (type == "prod") { - nccl_reduce_type_ = ncclProd; - } else { - MS_LOG(EXCEPTION) << "Nccl reduce type " << type << " is not supported."; - } - } - return; - } - - NcclKernelType nccl_kernel_type_; - ncclRedOp_t nccl_reduce_type_; - ncclDataType_t nccl_data_type_; - size_t input_size_; - size_t output_size_;
- std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - const void *collective_handle_; - cudaStream_t comm_stream_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NCCL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc deleted file mode 100644 index 5e80cccd75..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/activation_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) - -MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Tanh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) - -MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Sigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h deleted file mode 100644 index bf6cfa7b23..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_gpu_kernel.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ - -#include <map> -#include <string> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class ActivationGpuFwdKernel : public GpuKernel { - public: - ActivationGpuFwdKernel() - : cudnn_handle_(nullptr), - activation_desc_(nullptr), - mode_(CUDNN_ACTIVATION_RELU), - data_descriptor_(nullptr), - is_null_input_(false), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0), - output_size_(0), - workspace_size_(0) {} - ~ActivationGpuFwdKernel() override { DestroyResource(); } - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *) override { - if (is_null_input_) { - return true; - } - T *input = GetDeviceAddress<T>(inputs, 0); - T *output = GetDeviceAddress<T>(outputs, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnActivationForward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, input, - &beta, data_descriptor_, output), - "cudnnActivationForward failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kernel_map.find(node_name); - if (iter == kernel_map.end()) { - MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; - } - mode_ = iter->second; - - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGpuFwdKernel needs 1."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ActivationGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - std::vector<int> shape; - ShapeNdTo4d(input_shape, &shape); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0), - "cudnnSetActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - shape[0], shape[1], shape[2], shape[3]), - "cudnnSetTensor4dDescriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), - "cudnnCreateActivationDescriptor failed"); - } - - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - output_size_ = input_size_; - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - } - - private: - void DestroyResource() noexcept { -
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), - "cudnnDestroyActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReLU", CUDNN_ACTIVATION_RELU}, - {"Tanh", CUDNN_ACTIVATION_TANH}, - {"ELU", CUDNN_ACTIVATION_ELU}, - {"Sigmoid", CUDNN_ACTIVATION_SIGMOID}}; - - cudnnHandle_t cudnn_handle_; - cudnnActivationDescriptor_t activation_desc_; - cudnnActivationMode_t mode_; - cudnnTensorDescriptor_t data_descriptor_; - bool is_null_input_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc deleted file mode 100644 index 35d11f8b47..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/activation_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - ReluGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE( - TanhGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - TanhGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) - -MS_REG_GPU_KERNEL_ONE( - SigmoidGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ActivationGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - SigmoidGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ActivationGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h deleted file mode 100644 index 38e34eb752..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/activation_grad_kernel.h +++ /dev/null @@ -1,146 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ - -#include <map> -#include <string> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class ActivationGradGpuKernel : public GpuKernel { - public: - ActivationGradGpuKernel() - : cudnn_handle_(nullptr), - activation_desc_(nullptr), - mode_(CUDNN_ACTIVATION_RELU), - data_descriptor_(nullptr), - is_null_input_(false), - cudnn_data_type_(CUDNN_DATA_FLOAT), - input_size_(0) {} - ~ActivationGradGpuKernel() override { DestroyResource(); } - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *) override { - if (is_null_input_) { - return true; - } - T *dy = nullptr; - T *y = nullptr; - if (mode_ == CUDNN_ACTIVATION_RELU || mode_ == CUDNN_ACTIVATION_ELU) { - dy = GetDeviceAddress<T>(inputs, 0); - y = GetDeviceAddress<T>(inputs, 1); - } else { - y = GetDeviceAddress<T>(inputs, 0); - dy = GetDeviceAddress<T>(inputs, 1); - } - T *dx = GetDeviceAddress<T>(outputs, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy, - data_descriptor_, y, &beta, data_descriptor_, dx), - "cudnnActivationBackward failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - auto iter = kernel_map.find(node_name); - if (iter == kernel_map.end()) { - MS_LOG(EXCEPTION) << "Kernel: " << node_name << " not support."; - } - mode_ = iter->second; - - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but ActivationGradGpuKernel needs 2."; - return false; - } - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ActivationGradGpuKernel input is null."; - InitSizeLists(); - return true; - } - std::vector<int> shape; - ShapeNdTo4d(input_shape, &shape); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0), - "SetActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - shape[0], shape[1], shape[2], shape[3]), - "SetTensor4dDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); -
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&data_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_), - "cudnnCreateActivationDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(data_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_), - "cudnnDestroyActivationDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(data_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - std::map<std::string, cudnnActivationMode_t> kernel_map = {{"ReluGrad", CUDNN_ACTIVATION_RELU}, - {"TanhGrad", CUDNN_ACTIVATION_TANH}, - {"ELUGrad", CUDNN_ACTIVATION_ELU}, - {"SigmoidGrad", CUDNN_ACTIVATION_SIGMOID}}; - cudnnHandle_t cudnn_handle_; - cudnnActivationDescriptor_t activation_desc_; - cudnnActivationMode_t mode_; - cudnnTensorDescriptor_t data_descriptor_; - bool is_null_input_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_RELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc deleted file mode 100644 index 049a5cc280..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "kernel/gpu/nn/adam_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Adam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - AdamGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Adam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - AdamGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h deleted file mode 100644 index 93c6381ab3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ - -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/adam_impl.cuh" -namespace mindspore { -namespace kernel { -template <typename T> -class AdamGpuKernel : public GpuKernel { - public: - AdamGpuKernel() - : variable_size_(0), - m_size_(0), - v_size_(0), - beta1_power_size_(0), - beta2_power_size_(0), - learning_rate_size_(0), - beta1_size_(0), - beta2_size_(0), - epsilon_size_(0), - gradient_size_(0) {} - - ~AdamGpuKernel() override = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, - void *stream_ptr) override { - T *variable = GetDeviceAddress<T>(inputs, 0); - T *m = GetDeviceAddress<T>(inputs, 1); - T *v = GetDeviceAddress<T>(inputs, 2); - T *beta1_power = GetDeviceAddress<T>(inputs, 3); - T *beta2_power = GetDeviceAddress<T>(inputs, 4); - T *learning_rate = GetDeviceAddress<T>(inputs, 5); - T *beta1 = GetDeviceAddress<T>(inputs, 6); - T *beta2 = GetDeviceAddress<T>(inputs, 7); - T *epsilon = GetDeviceAddress<T>(inputs, 8); - T *gradient = GetDeviceAddress<T>(inputs, 9); - ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, - variable, m, v, reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 10) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but Adam needs 10 inputs."; - return false; - } - - variable_size_ = sizeof(T); - m_size_ = sizeof(T); - v_size_ = sizeof(T); - beta1_power_size_ = sizeof(T); - beta2_power_size_ = sizeof(T); - learning_rate_size_ = sizeof(T); - beta1_size_ = sizeof(T); - beta2_size_ = sizeof(T); - epsilon_size_ = sizeof(T); - gradient_size_ = sizeof(T); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - - auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < m_shape.size(); i++) { - m_size_ *= m_shape[i]; - } - - auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - for (size_t i = 0; i < v_shape.size(); i++) { - v_size_ *= v_shape[i]; - } - - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(m_size_); - input_size_list_.push_back(v_size_); - input_size_list_.push_back(beta1_power_size_); - input_size_list_.push_back(beta2_power_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(beta1_size_); - input_size_list_.push_back(beta2_size_); - input_size_list_.push_back(epsilon_size_); - input_size_list_.push_back(gradient_size_); - output_size_list_.push_back(0); - output_size_list_.push_back(0); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t m_size_; - size_t v_size_; - size_t
beta1_power_size_; - size_t beta2_power_size_; - size_t learning_rate_size_; - size_t beta1_size_; - size_t beta2_size_; - size_t epsilon_size_; - size_t gradient_size_; - - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc deleted file mode 100644 index ce6c9beeb7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/bias_add_grad_gpu_kenel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - BiasAddGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - BiasAddGradGpuKernel, float16) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h deleted file mode 100644 index 9b4f18d24c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ - -#include <algorithm> -#include <map> -#include <memory> -#include <string> -#include <utility> -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class BiasAddGradGpuKernel : public GpuKernel { - public: - BiasAddGradGpuKernel() - : same_dims_(true), - cudnn_handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - dy_desc_(nullptr), - db_desc_(nullptr), - op_desc_(nullptr) {} - ~BiasAddGradGpuKernel() override { DestroyResource(); } - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - T *dy_addr = GetDeviceAddress<T>(inputs, 0); - T *db_addr = GetDeviceAddress<T>(outputs, 0); - T *indices_addr = GetDeviceAddress<T>(workspace, 0); - T *workspace_addr = GetDeviceAddress<T>(workspace, 1); - - const float alpha = 1; - const float beta = 0; - if (same_dims_) { - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(db_addr, dy_addr, output_size_list_[0], cudaMemcpyDeviceToDevice, - reinterpret_cast<cudaStream_t>(stream_ptr)), - "cudaMemcpyAsync failed."); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnReduceTensor(cudnn_handle_, op_desc_, indices_addr, workspace_size_list_[0], workspace_addr, - workspace_size_list_[1], &alpha, dy_desc_, dy_addr, &beta, db_desc_, db_addr), - "cudnnReduceTensor failed"); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto num_dims = dy_shape.size(); - if (num_dims < 2) { - MS_LOG(EXCEPTION) << "input dims must be at least 2, but got " << num_dims; - } - - std::string format = GetAttr<std::string>(kernel_node, "data_format"); - string::size_type pos = format.find("C"); - if (pos == std::string::npos || pos >= num_dims) { - MS_LOG(EXCEPTION) << "format '" << format << "' invalid"; - } - - // Expand to 4 dims for cudnnSetTensorNdDescriptorEx. - auto cudnn_dims = std::max(num_dims, 4UL); - std::unique_ptr<int[]> dy_dims = std::make_unique<int[]>(cudnn_dims); - std::unique_ptr<int[]> db_dims = std::make_unique<int[]>(cudnn_dims); - for (size_t i = 0; i < cudnn_dims; i++) { - dy_dims[i] = (i < num_dims) ? SizeToInt(dy_shape[i]) : 1; - db_dims[i] = (i == pos) ?
SizeToInt(dy_shape[i]) : 1; - - if (dy_dims[i] != db_dims[i]) { - same_dims_ = false; - } - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), dy_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), - "cudnnSetTensorNdDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, - CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), - "cudnnSetReduceTensorDescriptor failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&db_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateReduceTensorDescriptor(&op_desc_), "cudnnCreateReduceTensorDescriptor failed"); - } - void InitSizeLists() override { - size_t dy_size, db_size; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(db_desc_, &db_size), "cudnnGetTensorSizeInBytes failed"); - input_size_list_.push_back(dy_size); - output_size_list_.push_back(db_size); - - size_t indices_size, workspace_size; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionIndicesSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &indices_size), - "cudnnGetReductionIndicesSize failed") - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetReductionWorkspaceSize(cudnn_handle_, op_desc_, dy_desc_, db_desc_, &workspace_size), - "cudnnGetReductionWorkspaceSize failed") - workspace_size_list_.push_back(indices_size); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDestroyReduceTensorDescriptor(op_desc_), - "cudnnDestroyReduceTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(db_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); - } - - bool same_dims_; - cudnnHandle_t cudnn_handle_; - cudnnDataType_t cudnn_data_type_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t db_desc_; - cudnnReduceTensorDescriptor_t op_desc_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BIAS_ADD_GRAD_GPU_KENEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc deleted file mode 100644 index df6825e079..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/nn/conv2d_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  Conv2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  Conv2dGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-  Conv2D,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  Conv2dGpuFwdKernel, half)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h
deleted file mode 100644
index f51cbfef33..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_gpu_kernel.h
+++ /dev/null
@@ -1,320 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
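Each deleted .cc file in this series is pure registration: MS_REG_GPU_KERNEL_ONE binds an op name plus a KernelAttr dtype signature to one template instantiation of the kernel, so Conv2D dispatches to Conv2dGpuFwdKernel with float or half depending on the inferred device dtypes. The general idiom behind such macros is a static registrar object whose constructor populates a global factory before main() runs. A generic sketch under assumed names (KernelFactory, REG_GPU_KERNEL; this is not the real MindSpore macro machinery, which also keys on the KernelAttr):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <string>

    struct GpuKernelBase { virtual ~GpuKernelBase() = default; };

    // Global op-name -> creator map, populated by static registrars at load time.
    class KernelFactory {
     public:
      static KernelFactory &Instance() {
        static KernelFactory factory;
        return factory;
      }
      void Register(const std::string &op, std::function<std::unique_ptr<GpuKernelBase>()> creator) {
        creators_[op] = std::move(creator);
      }
      std::unique_ptr<GpuKernelBase> Create(const std::string &op) const {
        auto it = creators_.find(op);
        return it == creators_.end() ? nullptr : it->second();
      }
     private:
      std::map<std::string, std::function<std::unique_ptr<GpuKernelBase>()>> creators_;
    };

    // One static registrar per registration line; its constructor runs before main().
    struct KernelRegistrar {
      KernelRegistrar(const std::string &op, std::function<std::unique_ptr<GpuKernelBase>()> c) {
        KernelFactory::Instance().Register(op, std::move(c));
      }
    };

    #define REG_GPU_KERNEL(OP, CLASS, T)                      \
      static KernelRegistrar g_##CLASS##_##T##_reg(#OP, [] {  \
        return std::unique_ptr<GpuKernelBase>(new CLASS<T>);  \
      })

    template <typename T>
    struct DemoConvKernel : GpuKernelBase {};

    REG_GPU_KERNEL(Conv2D, DemoConvKernel, float);  // cf. MS_REG_GPU_KERNEL_ONE above

    int main() {
      std::cout << (KernelFactory::Instance().Create("Conv2D") ? "registered" : "missing") << "\n";
      return 0;
    }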
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class Conv2dGpuFwdKernel : public GpuKernel { - public: - Conv2dGpuFwdKernel() - : cudnn_handle_(nullptr), - input_desc_(nullptr), - output_desc_(nullptr), - filter_desc_(nullptr), - conv_desc_(nullptr), - padded_desc_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - input_size_(0), - filter_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~Conv2dGpuFwdKernel() override { DestroyResource(); } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *filter_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - T *workspace_addr = nullptr; - if (workspace_size_ != 0) { - workspace_addr = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded_addr = GetDeviceAddress(workspace, 1); - CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_, - conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), - "cudnnConvolutionForward failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionForward(cudnn_handle_, &alpha, input_desc_, input_addr, filter_desc_, filter_addr, conv_desc_, - conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr), - "cudnnConvolutionForward failed"); - } - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(in_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - Set4DDesc(in_shape, filter_shape, output_shape); - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - 
cudnnTensorDescriptor_t input_descriptor_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(in_shape, kernel_node); - input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - input_descriptor_real = input_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(input_descriptor_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&filter_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(input_desc_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(filter_desc_, reinterpret_cast(&filter_size_)), - "cudnnGetFilterSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(output_desc_, reinterpret_cast(&output_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_desc_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - input_size_list_.push_back(filter_size_); - output_size_list_.push_back(output_size_); - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, padded_desc_, filter_desc_, conv_desc_, output_desc_, - conv_algorithm_, &workspace_size_), - "cudnnGetConvolutionForwardWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle_, input_desc_, filter_desc_, conv_desc_, output_desc_, - conv_algorithm_, &workspace_size_), - "cudnnGetConvolutionForwardWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(filter_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_desc_), 
"cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - - n_ = SizeToInt(in_shape[0]); - c_ = SizeToInt(in_shape[1]); - old_height_ = SizeToInt(in_shape[2]); - old_width_ = SizeToInt(in_shape[3]); - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - - // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad. - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_, - old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - - void Set4DDesc(const std::vector &in_shape, const std::vector &filter_shape, - const std::vector &output_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), - SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]), - SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), - "cudnnSetFilter4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionForwardAlgorithm( - cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_), - "cudnnGetConvolutionForwardAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionFwdAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, - output_desc_, requested_algo_count, &returned_algo_count, &perf_results), - 
"cudnnGetConvolutionForwardAlgorithm_v7 failed"); - conv_algorithm_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - } - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 4) { - MS_LOG(EXCEPTION) << "Conv2d's' stride must be 4d!"; - } - if (stride_[0] != 1 || stride_[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2d stride only support 1 in N axis and C axis!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "Conv2d's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "Conv2d dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_desc_; - cudnnTensorDescriptor_t output_desc_; - cudnnFilterDescriptor_t filter_desc_; - cudnnConvolutionFwdAlgo_t conv_algorithm_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t padded_desc_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - const float pad_value_ = 0.0; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t input_size_; - size_t filter_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2DGPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc deleted file mode 100644 index 28e9a10ccc..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConvGradFilterGpuBkwKernel, float) -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropFilter, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConvGradFilterGpuBkwKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h deleted file mode 100644 index 0d7be25772..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_filter_gpu_kernel.h +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ConvGradFilterGpuBkwKernel : public GpuKernel { - public: - ConvGradFilterGpuBkwKernel() - : cudnn_handle_(nullptr), - dw_desc_(nullptr), - conv_desc_(nullptr), - dy_desc_(nullptr), - x_desc_(nullptr), - padded_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - input_size_(0), - dy_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~ConvGradFilterGpuBkwKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *dy = GetDeviceAddress(inputs, 0); - T *x = GetDeviceAddress(inputs, 1); - T *dw = GetDeviceAddress(outputs, 0); - T *work_space = nullptr; - if (workspace_size_ != 0) { - work_space = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 1); - CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, - reinterpret_cast(stream_ptr)); - - 
CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_, - algo_, work_space, workspace_size_, &beta, dw_desc_, dw), - "ConvolutionBackwardFilter failed"); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, x_desc_, x, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, dw_desc_, dw), - "ConvolutionBackwardFilter failed"); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null."; - InitSizeLists(); - return true; - } - std::vector filter_shape; - GetFilterShape(kernel_node, &filter_shape); - Set4DDesc(dy_shape, filter_shape, in_shape); - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - cudnnTensorDescriptor_t x_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(in_shape, kernel_node); - x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "GetConvolution2dDescriptor failed"); - x_desc_real = x_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(x_desc_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "cudnnCreateFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, reinterpret_cast(&dy_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(dw_desc_, reinterpret_cast(&output_size_)), - "cudnnGetFilterSizeInBytes 
failed"); - } - input_size_list_.push_back(dy_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, padded_descriptor_, dy_desc_, conv_desc_, - dw_desc_, algo_, reinterpret_cast(&workspace_size_)), - "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle_, x_desc_, dy_desc_, conv_desc_, dw_desc_, algo_, - reinterpret_cast(&workspace_size_)), - "cudnnGetConvolutionBackwardFilterWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "cudnnDestroyFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradFilter needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradFilter needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &in_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - n_ = SizeToInt(in_shape[0]); - c_ = SizeToInt(in_shape[1]); - old_height_ = SizeToInt(in_shape[2]); - old_width_ = SizeToInt(in_shape[3]); - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 
0 : pad_left_, stride_[0], stride_[1], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), - "GetConvolutionBackwardFilterAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionBwdFilterAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_, - requested_algo_count, &returned_algo_count, &perf_results), - "GetConvolutionBackwardFilterAlgorithm failed"); - algo_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - } - } - void GetFilterShape(const CNodePtr &kernel_node, std::vector *filter_shape) { - auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast()->value(); - (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape), - [](const ValuePtr &e) -> int { return e->cast()->value(); }); - } - void Set4DDesc(const std::vector &dy_shape, const std::vector &filter_shape, - const std::vector &in_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), - SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), - "SetTensor4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1], - filter_shape[2], filter_shape[3]), - "SetFilter4dDescriptor failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]), - SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])), - "SetTensor4dDescriptor failed"); - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 2) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's stride must be 2d!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "ConvGradFilterGpuBkwKernel dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnFilterDescriptor_t dw_desc_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnConvolutionBwdFilterAlgo_t algo_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - const float pad_value_ = 0.0; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t input_size_; - size_t dy_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - 
bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc deleted file mode 100644 index 12b6f91537..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - ConvGradInputGpuBkwKernel, float) -MS_REG_GPU_KERNEL_ONE( - Conv2DBackpropInput, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - ConvGradInputGpuBkwKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h deleted file mode 100644 index a33ea5b4da..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/conv2d_grad_input_gpu_kernel.h +++ /dev/null @@ -1,315 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
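ConvGradFilterGpuBkwKernel's SelectAlgorithm has three regimes: grouped convolution or cuDNN older than 7 falls back to the legacy workspace-limited query, cuDNN 7+ takes the top result of the _v7 heuristic, and fp16 overrides both with ALGO_1, which is deterministic and Tensor-Core eligible. A compilable sketch of the same decision, under the assumed helper name PickBwdFilterAlgo; the legacy cudnnGetConvolutionBackwardFilterAlgorithm branch is left out because cuDNN 8 removed that entry point.

    #include <cudnn.h>

    cudnnConvolutionBwdFilterAlgo_t PickBwdFilterAlgo(cudnnHandle_t handle,
                                                      cudnnTensorDescriptor_t x_desc,
                                                      cudnnTensorDescriptor_t dy_desc,
                                                      cudnnConvolutionDescriptor_t conv_desc,
                                                      cudnnFilterDescriptor_t dw_desc,
                                                      bool is_half) {
      if (is_half) {
        // Mirrors the deleted kernel's fp16 override: deterministic, Tensor-Core friendly.
        return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
      }
      cudnnConvolutionBwdFilterAlgoPerf_t perf{};
      int returned = 0;
      cudnnGetConvolutionBackwardFilterAlgorithm_v7(handle, x_desc, dy_desc, conv_desc, dw_desc,
                                                    /*requestedAlgoCount=*/1, &returned, &perf);
      return perf.algo;
    }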
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class ConvGradInputGpuBkwKernel : public GpuKernel { - public: - ConvGradInputGpuBkwKernel() - : cudnn_handle_(nullptr), - w_desc_(nullptr), - conv_desc_(nullptr), - dy_desc_(nullptr), - dx_desc_(nullptr), - padded_descriptor_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - dy_size_(0), - w_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~ConvGradInputGpuBkwKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *dy = GetDeviceAddress(inputs, 0); - T *w = GetDeviceAddress(inputs, 1); - T *dx = GetDeviceAddress(outputs, 0); - T *work_space = nullptr; - if (workspace_size_ != 0) { - work_space = GetDeviceAddress(workspace, 0); - } - - const float alpha = 1; - const float beta = 0; - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 1); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, padded_descriptor_, padded), - "ConvolutionBackwardData failed"); - CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, - workspace_size_, &beta, dx_desc_, dx), - "ConvolutionBackwardData failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(dy_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null."; - InitSizeLists(); - return true; - } - std::vector input_shape; - GetInputShape(kernel_node, &input_shape); - Set4DDesc(dy_shape, input_shape, filter_shape); - - group_ = GetAttr(kernel_node, "group"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed"); - - pad_height_ = GetAttr(kernel_node, "pad"); - pad_width_ = pad_height_; - pad_mode_ = GetAttr(kernel_node, "pad_mode"); - SetStrideAndDilation(kernel_node); - 
cudnnTensorDescriptor_t dx_desc_real = nullptr; - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(input_shape, kernel_node); - dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_; - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2], - dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - dx_desc_real = dx_desc_; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH), - "cudnnSetConvolutionMathType failed.") - } - SelectAlgorithm(dx_desc_real); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "cudnnCreateFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateConvolutionDescriptor(&conv_desc_), - "cudnnCreateConvolutionDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_desc_, &dy_size_), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetFilterSizeInBytes(w_desc_, &w_size_), "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_desc_, &output_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(dy_size_); - input_size_list_.push_back(w_size_); - output_size_list_.push_back(output_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), - "cudnnGetTensorSizeInBytes failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, padded_descriptor_, - algo_, &workspace_size_), - "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); - workspace_size_list_.push_back(padded_size_); - } else { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_, algo_, &workspace_size_), - "cudnnGetConvolutionBackwardDataWorkspaceSize failed"); - } - } - (void)workspace_size_list_.insert(workspace_size_list_.begin(), workspace_size_); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc_), - "cudnnDestroyConvolutionDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "cudnnDestroyFilterDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "cudnnDestroyTensorDescriptor failed"); - 
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "cudnnDestroyTensorDescriptor failed"); - } - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConvGradInput needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but ConvGradInput needs 1 output."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const CNodePtr &kernel_node) { - auto pad_list = GetAttr>(kernel_node, "pad_list"); - n_ = input_shape[0]; - c_ = input_shape[1]; - old_height_ = input_shape[2]; - old_width_ = input_shape[3]; - pad_height_ = pad_list[0] + pad_list[1]; - pad_width_ = pad_list[2] + pad_list[3]; - pad_top_ = pad_list[0]; - pad_left_ = pad_list[2]; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor( - conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1], - dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT), - "cudnnSetConvolution2dDescriptor failed"); - } - void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) { - if (group_ > 1 || CUDNN_MAJOR < 7) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_), - "cudnnGetConvolutionBackwardDataAlgorithm failed"); - } else { - constexpr int requested_algo_count = 1; - int returned_algo_count; - cudnnConvolutionBwdDataAlgoPerf_t perf_results; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real, - requested_algo_count, &returned_algo_count, &perf_results), - "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed"); - algo_ = perf_results.algo; - } - if (cudnn_data_type_ == CUDNN_DATA_HALF) { - algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - } - } - void GetInputShape(const CNodePtr &kernel_node, std::vector *input_shape) { - auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast()->value(); - (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape), - [](const ValuePtr &e) -> int { return e->cast()->value(); }); - } - void Set4DDesc(const std::vector &dy_shape, const std::vector &input_shape, - const std::vector &filter_shape) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), - SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])), - "SetFilter4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]), - SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])), - "SetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1], - input_shape[2], input_shape[3]), - 
"SetTensor4dDescriptor failed"); - } - void SetStrideAndDilation(const CNodePtr &kernel_node) { - stride_ = AnfAlgo::GetNodeAttr>(kernel_node, "stride"); - dilation_ = AnfAlgo::GetNodeAttr>(kernel_node, "dilation"); - if (stride_.size() != 2) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's stride must be 2d!"; - } - if (dilation_.size() != 4) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel's dilation must be 4d!"; - } - if (dilation_[0] != 1 || dilation_[1] != 1) { - MS_LOG(EXCEPTION) << "ConvGradInputGpuBkwKernel dilation only support 1 in N axis and C axis!"; - } - } - cudnnHandle_t cudnn_handle_; - cudnnFilterDescriptor_t w_desc_; - cudnnConvolutionDescriptor_t conv_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t dx_desc_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnConvolutionBwdDataAlgo_t algo_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - std::vector stride_; - std::vector dilation_; - int group_; - bool is_null_input_; - size_t dy_size_; - size_t w_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc deleted file mode 100644 index 355d238ab4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/ctcloss_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CTCLossV2, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CtcLossGpuKernel, float) - -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h deleted file mode 100644 index 2bd83b3176..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ctcloss_gpu_kernel.h +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "device/gpu/gpu_memory_allocator.h" - -namespace mindspore { -namespace kernel { -template -class CtcLossGpuKernel : public GpuKernel { - public: - CtcLossGpuKernel() - : cudnn_handle_(nullptr), - probs_desc_(nullptr), - ctcloss_desc_(nullptr), - label_size_(0), - input_lengths_size_(0), - label_lengths_size_(0) {} - ~CtcLossGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - float *probs = GetDeviceAddress(inputs, 0); - int *labels = GetDeviceAddress(inputs, 1); - int *input_lengths = GetDeviceAddress(inputs, 2); - int *label_lengths = GetDeviceAddress(inputs, 3); - float *costs = GetDeviceAddress(outputs, 0); - float *grads = GetDeviceAddress(outputs, 1); - - // Copy labels/input_lengths/label_length to host as cudnn7.x.x requires - void *labels_host = nullptr; - void *input_lengths_host = nullptr; - void *label_lengths_host = nullptr; - CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&labels_host, inputs[1]->size), "cudaMallocHost failed."); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&input_lengths_host, inputs[2]->size), "cudaMallocHost failed."); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(&label_lengths_host, inputs[3]->size), "cudaMallocHost failed."); - cudaStream_t stream = reinterpret_cast(stream_ptr); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(labels_host, labels, inputs[1]->size, cudaMemcpyDeviceToHost, stream), - "cudaMemcpyAsync failed."); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(input_lengths_host, input_lengths, inputs[2]->size, cudaMemcpyDeviceToHost, stream), - "cudaMemcpyAsync failed."); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(label_lengths_host, label_lengths, inputs[3]->size, cudaMemcpyDeviceToHost, stream), - "cudaMemcpyAsync failed."); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, reinterpret_cast(labels_host), - reinterpret_cast(label_lengths_host), - reinterpret_cast(input_lengths_host), CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, - ctcloss_desc_, &workspace_size), - "cudnnGetCTCLossWorkspaceSize failed."); - void *workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size); - if (workspace == nullptr) { - MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, reinterpret_cast(labels_host), - 
reinterpret_cast(label_lengths_host), reinterpret_cast(input_lengths_host), costs, - probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace, workspace_size), - "cudnnCtcLoss failed."); - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed."); - - device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace); - CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed."); - CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed."); - CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (probs_shape.size() != 3) { - MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support."; - } - probs_dims_[0] = probs_shape[0]; - probs_dims_[1] = probs_shape[1]; - probs_dims_[2] = probs_shape[2]; - - auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - if (labels_dims.size() != 1 && labels_dims.size() != 2) { - MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support."; - } - label_size_ = sizeof(int); - for (auto i : labels_dims) { - label_size_ *= i; - } - - auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - input_lengths_size_ = input_length_dims[0] * sizeof(int); - auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - label_lengths_size_ = label_length_dims[0] * sizeof(int); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_), - "cudnnSetTensorNdDescriptorEx failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT, - CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN), - "cudnnSetCTCLossDescriptorEx failed."); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed."); - } - - void InitSizeLists() override { - input_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); - input_size_list_.push_back(label_size_); - input_size_list_.push_back(input_lengths_size_); - input_size_list_.push_back(label_lengths_size_); - - output_size_list_.push_back(probs_dims_[1] * sizeof(float)); - output_size_list_.push_back(probs_dims_[0] * probs_dims_[1] * probs_dims_[2] * sizeof(float)); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed."); - } - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t probs_desc_; - cudnnCTCLossDescriptor_t ctcloss_desc_; - int probs_dims_[3] = {0}; - int label_size_; - int input_lengths_size_; - int label_lengths_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_CTCLOSS_GPU_KERNEL_H_ diff --git 
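The subtle part of CtcLossGpuKernel above is that cuDNN 7.x's CTC API consumes labels and sequence lengths from host memory, so the kernel stages them device-to-host through pinned buffers and synchronizes before both the workspace query and the loss call; it also allocates the workspace per launch because the size depends on the actual lengths. A sketch of that staging, with the helper name and the trimmed error handling as assumptions:

    #include <cuda_runtime.h>
    #include <cudnn.h>

    size_t QueryCtcWorkspace(cudnnHandle_t handle, cudnnTensorDescriptor_t probs_desc,
                             cudnnCTCLossDescriptor_t ctc_desc,
                             const int *labels_dev, size_t labels_bytes,
                             const int *label_len_dev, size_t label_len_bytes,
                             const int *input_len_dev, size_t input_len_bytes,
                             cudaStream_t stream) {
      void *labels_h = nullptr, *label_len_h = nullptr, *input_len_h = nullptr;
      cudaMallocHost(&labels_h, labels_bytes);  // pinned, so the async copies are truly async
      cudaMallocHost(&label_len_h, label_len_bytes);
      cudaMallocHost(&input_len_h, input_len_bytes);
      cudaMemcpyAsync(labels_h, labels_dev, labels_bytes, cudaMemcpyDeviceToHost, stream);
      cudaMemcpyAsync(label_len_h, label_len_dev, label_len_bytes, cudaMemcpyDeviceToHost, stream);
      cudaMemcpyAsync(input_len_h, input_len_dev, input_len_bytes, cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // lengths must be readable on the host before the query
      size_t bytes = 0;
      cudnnGetCTCLossWorkspaceSize(handle, probs_desc, /*gradientsDesc=*/probs_desc,
                                   static_cast<const int *>(labels_h),
                                   static_cast<const int *>(label_len_h),
                                   static_cast<const int *>(input_len_h),
                                   CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_desc, &bytes);
      cudaFreeHost(labels_h);
      cudaFreeHost(label_len_h);
      cudaFreeHost(input_len_h);
      return bytes;
    }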
a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
deleted file mode 100644
index 459010e9e9..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.cc
+++ /dev/null
@@ -1,30 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/nn/dropout_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-  Dropout,
-  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  DropoutGpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-  Dropout,
-  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  DropoutGpuFwdKernel, half)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
deleted file mode 100644
index 4dfacb7ca1..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/dropout_gpu_kernel.h
+++ /dev/null
@@ -1,118 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/dropout_impl.cuh" -#include "include/curand.h" - -namespace mindspore { -namespace kernel { -template -class DropoutGpuFwdKernel : public GpuKernel { - public: - DropoutGpuFwdKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - num_count_(0), - keep_prob_(0.0), - states_init_(false), - mask_generator_(nullptr) {} - - ~DropoutGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - T *mask = GetDeviceAddress(outputs, 1); - float *mask_f = GetDeviceAddress(workspace, 0); - - if (!states_init_) { - curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL)); - states_init_ = true; - } - // curandGen only support float or double for mask. - curandGenerateUniform(mask_generator_, mask_f, num_count_); - DropoutForward(input, mask, output, mask_f, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); - - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1."; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - InitSizeLists(); - return true; - } - - num_count_ = 1; - for (size_t x : input_shape) { - num_count_ *= x; - } - keep_prob_ = GetAttr(kernel_node, "keep_prob"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = num_count_ * sizeof(T); - input_size_list_.push_back(input_size); - output_size_list_.push_back(input_size); // output size: the same with input size - output_size_list_.push_back(input_size); // mask size: the same with input size - workspace_size_list_.push_back(num_count_ * sizeof(float)); // temp mask_f for curandGen - } - - private: - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t num_count_; - float keep_prob_; - bool states_init_; - curandGenerator_t mask_generator_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc deleted file mode 100644 index 2fd21c96ee..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, 
Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/dropout_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - DropoutGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - DropoutGradGpuBwdKernel, float) -MS_REG_GPU_KERNEL_ONE( - DropoutGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - DropoutGradGpuBwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h deleted file mode 100644 index e6683e15dd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/dropout_grad_kernel.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/dropout_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class DropoutGradGpuBwdKernel : public GpuKernel { - public: - DropoutGradGpuBwdKernel() : cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), keep_prob_(0.0) {} - ~DropoutGradGpuBwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - T *dy = GetDeviceAddress(inputs, 0); - T *mask = GetDeviceAddress(inputs, 1); - T *dx = GetDeviceAddress(outputs, 0); - - DropoutBackward(dy, mask, dx, num_count_, keep_prob_, reinterpret_cast(stream_ptr)); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuBwdKernel needs 2."; - return false; - } - - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - InitSizeLists(); - return true; - } - - num_count_ = 1; - for (size_t x : input_shape) { - num_count_ *= x; - } - keep_prob_ = GetAttr(kernel_node, "keep_prob"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - void InitSizeLists() override { - size_t dy_size = num_count_ * sizeof(T); - size_t mask_size = dy_size; - size_t dx_size = dy_size; - - input_size_list_.push_back(dy_size); - input_size_list_.push_back(mask_size); - output_size_list_.push_back(dx_size); - } - - private: - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t num_count_; - float keep_prob_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc deleted file mode 100644 index f9c993d31d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
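DropoutGrad is the mirror image: it rescales the incoming gradient by the saved mask. Again a sketch of what DropoutBackward plausibly computes per element (assumed form, not the dropout_impl.cuh source):

    template <typename T>
    __global__ void DropoutBackwardSketch(const T *dy, const T *mask, T *dx,
                                          size_t n, float keep_prob) {
      float scale = 1.0f / keep_prob;  // undo the forward pass's kept-value scaling on the gradient path
      for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
        dx[i] = static_cast<T>(scale * static_cast<float>(mask[i]) * static_cast<float>(dy[i]));
      }
    }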
- */ - -#include "kernel/gpu/nn/flatten_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - FlattenGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - FlattenGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - FlattenGpuFwdKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h deleted file mode 100644 index 3b0ad8c946..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template -class FlattenGpuFwdKernel : public GpuKernel { - public: - FlattenGpuFwdKernel() : input_size_(0), output_size_(0), workspace_size_(0) {} - ~FlattenGpuFwdKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - cudaError_t ret = - cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)); - if (ret) { - MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGpuFwdKernel::Launch, error code is " << ret; - return false; - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t i = 0; i < shape.size(); ++i) { - input_size_ *= shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_ = input_size_; - output_size_list_.push_back(output_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc deleted file mode 100644 index 0e079d137b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.cc +++ /dev/null @@ -1,28 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
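Flatten, Reshape and ExpandDims all register this one FlattenGpuFwdKernel because none of them moves data: only the shape metadata changes, so Launch degenerates to a single device-to-device copy on the kernel's stream. The whole launch is effectively:

    // Equivalent of FlattenGpuFwdKernel::Launch; input_size_ is the byte count
    // computed in Init, stream_ptr is the stream the framework schedules on.
    cudaError_t ret = cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
                                      reinterpret_cast<cudaStream_t>(stream_ptr));
    if (ret != cudaSuccess) {
      MS_LOG(ERROR) << "cudaMemcpyAsync error, error code is " << ret;
    }

A copy rather than an alias is needed because the framework hands each op its own output buffer, as the separate GetDeviceAddress calls show.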
- */
-
-#include "kernel/gpu/nn/flatten_grad_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      FlattenGardGpuBkwKernel, float)
-MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      FlattenGardGpuBkwKernel, half)
-MS_REG_GPU_KERNEL_ONE(FlattenGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      FlattenGardGpuBkwKernel, int)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h
deleted file mode 100644
index 0748dc77db..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/flatten_grad_gpu_kernel.h
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_
-
-#include <vector>
-#include
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class FlattenGardGpuBkwKernel : public GpuKernel {
- public:
-  FlattenGardGpuBkwKernel() : input_size_(0), output_size_(0), workspace_size_(0) {}
-  ~FlattenGardGpuBkwKernel() override = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    VARIABLE_NOT_USED(workspace);
-    T *input = GetDeviceAddress<T>(inputs, 0);
-    T *output = GetDeviceAddress<T>(outputs, 0);
-    cudaError_t ret =
-      cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr));
-    if (ret) {
-      MS_LOG(ERROR) << "cudaMemcpyAsync error in FlattenGardGpuBkwKernel::Launch, error code is " << ret;
-      return false;
-    }
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(ERROR) << "Argument number is " << input_num << ", but FlattenGardGpuBkwKernel needs 1.";
-      return false;
-    }
-
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    for (size_t i = 0; i < shape.size(); ++i) {
-      if (input_size_ == 0) {
-        input_size_ = 1;
-      }
-      input_size_ *= shape[i];
-    }
-    input_size_ = input_size_ * sizeof(T);
-
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    output_size_ = input_size_;
-    output_size_list_.push_back(output_size_);
-  }
-
- private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t>
output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; - size_t output_size_; - size_t workspace_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FLATTEN_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc deleted file mode 100644 index 4d30130931..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/ftrl_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(ApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FtrlGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(ApplyFtrl, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FtrlGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h deleted file mode 100644 index 9e2153965b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" -namespace mindspore { -namespace kernel { -template -class FtrlGpuKernel : public GpuKernel { - public: - FtrlGpuKernel() - : variable_size_(0), - accumulation_size_(0), - linear_size_(0), - gradient_size_(0), - learning_rate_size_(0), - l1_regularization_size_(0), - l2_regularization_size_(0), - learning_rate_power_size_(0) {} - - ~FtrlGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, - void *stream_ptr) override { - T *variable = GetDeviceAddress(inputs, 0); - T *accumulation = GetDeviceAddress(inputs, 1); - T *linear = GetDeviceAddress(inputs, 2); - T *gradient = GetDeviceAddress(inputs, 3); - T *learning_rate = GetDeviceAddress(inputs, 4); - T *l1_regularization = GetDeviceAddress(inputs, 5); - T *l2_regularization = GetDeviceAddress(inputs, 6); - T *learning_rate_power = GetDeviceAddress(inputs, 7); - ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, - learning_rate_power, variable, accumulation, linear, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; - return false; - } - - variable_size_ = sizeof(T); - accumulation_size_ = sizeof(T); - linear_size_ = sizeof(T); - gradient_size_ = sizeof(T); - learning_rate_size_ = sizeof(T); - l1_regularization_size_ = sizeof(T); - l2_regularization_size_ = sizeof(T); - learning_rate_power_size_ = sizeof(T); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - - auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < accumulation_shape.size(); i++) { - accumulation_size_ *= accumulation_shape[i]; - } - - auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - for (size_t i = 0; i < linear_shape.size(); i++) { - linear_size_ *= linear_shape[i]; - } - - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(accumulation_size_); - input_size_list_.push_back(linear_size_); - input_size_list_.push_back(gradient_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(l1_regularization_size_); - input_size_list_.push_back(l2_regularization_size_); - input_size_list_.push_back(learning_rate_power_size_); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t accumulation_size_; - size_t linear_size_; - size_t gradient_size_; - size_t learning_rate_size_; - size_t 
l1_regularization_size_; - size_t l2_regularization_size_; - size_t learning_rate_power_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc deleted file mode 100644 index 99af1add46..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_adam_weight_decay.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedAdamWeightDecay, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedAdamWeightDecayGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedAdam, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedAdamWeightDecayGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h b/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h deleted file mode 100644 index f13f6ed59f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_adam_weight_decay.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
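For reference, ApplyFtrl is the FTRL-proximal optimizer step. The CUDA source in ftrl_impl.cuh is not shown in this diff, so the per-element sketch below is the textbook update under that assumption, with var, accum and linear updated in place:

    float accum_new = accum + grad * grad;
    linear += grad - (powf(accum_new, -lr_power) - powf(accum, -lr_power)) / lr * var;
    float quadratic = powf(accum_new, -lr_power) / lr + 2.0f * l2;
    var = fabsf(linear) > l1 ? (copysignf(l1, linear) - linear) / quadratic : 0.0f;
    accum = accum_new;

The eight input sizes pushed in InitSizeLists line up with exactly these operands (variable, accumulation, linear, gradient, learning_rate, l1, l2, lr_power), and output_size_list_ gets a zero-byte entry because the op updates its parameters in place rather than producing a fresh tensor.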
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/kernel_constants.h"
-#include "kernel/gpu/cuda_impl/adam_weight_decay_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class FusedAdamWeightDecayGpuKernel : public GpuKernel {
- public:
-  FusedAdamWeightDecayGpuKernel() : element_nums_(0), weight_decay_(false) {}
-  ~FusedAdamWeightDecayGpuKernel() override = default;
-
-  bool Init(const CNodePtr &kernel_node) override {
-    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
-    if (node_name == "FusedAdamWeightDecay") {
-      weight_decay_ = true;
-    }
-
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 7);
-    element_nums_ = 1;
-    for (auto i : shape) {
-      element_nums_ *= i;
-    }
-
-    InitSizeLists();
-    return true;
-  }
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    float *beta1 = GetDeviceAddress<float>(inputs, 0);
-    float *one_sub_beta1 = GetDeviceAddress<float>(inputs, 1);
-    float *beta2 = GetDeviceAddress<float>(inputs, 2);
-    float *one_sub_beta2 = GetDeviceAddress<float>(inputs, 3);
-    float *epsilon = GetDeviceAddress<float>(inputs, 4);
-    float *lr = GetDeviceAddress<float>(inputs, 5);
-    T *param = GetDeviceAddress<T>(inputs, 6);
-    T *m = GetDeviceAddress<T>(inputs, 7);
-    T *v = GetDeviceAddress<T>(inputs, 8);
-    T *gradient = GetDeviceAddress<T>(inputs, 9);
-    float *weight_decay = nullptr;
-    if (weight_decay_) {
-      weight_decay = GetDeviceAddress<float>(inputs, 10);
-    }
-    AdamWeightDecay(element_nums_, true, beta1, one_sub_beta1, beta2, one_sub_beta2, epsilon, lr, weight_decay, m, v,
-                    param, gradient, reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
-
- protected:
-  void InitResource() override{};
-  void InitSizeLists() override {
-    input_size_list_.push_back(sizeof(float));              // beta1
-    input_size_list_.push_back(sizeof(float));              // one_sub_beta1
-    input_size_list_.push_back(sizeof(float));              // beta2
-    input_size_list_.push_back(sizeof(float));              // one_sub_beta2
-    input_size_list_.push_back(sizeof(float));              // epsilon
-    input_size_list_.push_back(sizeof(float));              // lr
-    input_size_list_.push_back(element_nums_ * sizeof(T));  // param
-    input_size_list_.push_back(element_nums_ * sizeof(T));  // m
-    input_size_list_.push_back(element_nums_ * sizeof(T));  // v
-    input_size_list_.push_back(element_nums_ * sizeof(T));  // gradient
-    if (weight_decay_) {
-      input_size_list_.push_back(sizeof(float));            // weight_decay
-    }
-    output_size_list_.push_back(element_nums_ * sizeof(T));
-  }
-
- private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-
-  int element_nums_;
-  bool weight_decay_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_ADAM_WEIGHT_DECAY_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
deleted file mode 100644
index 91747d24d8..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
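The FusedAdamWeightDecayGpuKernel just removed folds the whole Adam step into one launch instead of a dozen small element-wise ops. Per element it computes the standard decoupled-weight-decay (AdamW) update; adam_weight_decay_impl.cuh may differ in detail, so treat this as a sketch with the scalar inputs already dereferenced:

    m = beta1 * m + one_sub_beta1 * grad;
    v = beta2 * v + one_sub_beta2 * grad * grad;
    float update = m / (sqrtf(v) + epsilon);
    if (has_weight_decay) {
      update += weight_decay * param;  // decoupled decay (AdamW), not folded into grad
    }
    param -= lr * update;

Note that beta1, one_sub_beta1 and the other hyperparameters arrive as sizeof(float) device buffers rather than host attrs, which presumably lets the fused kernel be enqueued without synchronizing to read scalar values back to the host.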
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_batch_norm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedBatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(BatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(BatchNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h deleted file mode 100644 index b0a898209b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/kernel_constants.h"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class FusedBatchNormGpuKernel : public GpuKernel {
- public:
-  FusedBatchNormGpuKernel()
-      : batch_(0),
-        channel_(0),
-        height_(0),
-        width_(0),
-        mode_(CUDNN_BATCHNORM_SPATIAL),
-        epsilon_(10e-5),
-        exp_avg_factor_(0.1),
-        is_train_(false),
-        is_null_input_(false),
-        x_desc_(nullptr),
-        y_desc_(nullptr),
-        scale_bias_mean_var_desc_(nullptr),
-        handle_(nullptr),
-        cudnn_data_type_(CUDNN_DATA_FLOAT) {}
-  ~FusedBatchNormGpuKernel() override { DestroyResource(); }
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    VARIABLE_NOT_USED(workspace);
-    VARIABLE_NOT_USED(stream_ptr);
-    if (is_null_input_) {
-      return true;
-    }
-    auto x = GetDeviceAddress<T>(inputs, 0);
-    auto scale = GetDeviceAddress<float>(inputs, 1);
-    auto bias = GetDeviceAddress<float>(inputs, 2);
-    auto running_mean = GetDeviceAddress<float>(inputs, 3);
-    auto running_variance = GetDeviceAddress<float>(inputs, 4);
-    auto y = GetDeviceAddress<T>(outputs, 0);
-
-    const float alpha = 1;
-    const float beta = 0;
-    if (is_train_) {
-      auto save_mean = GetDeviceAddress<float>(outputs, 3);
-      auto save_variance = GetDeviceAddress<float>(outputs, 4);
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnBatchNormalizationForwardTraining(handle_, mode_, &alpha, &beta, x_desc_, x, y_desc_, y,
-                                               scale_bias_mean_var_desc_, scale, bias, exp_avg_factor_, running_mean,
-                                               running_variance, epsilon_, save_mean, save_variance),
-        "Kernel launch failed");
-    } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardInference(handle_, mode_, &alpha, &beta, x_desc_, x,
-                                                                          y_desc_, y, scale_bias_mean_var_desc_, scale,
-                                                                          bias, running_mean, running_variance, epsilon_),
-                                  "Kernel launch failed");
-    }
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    InitResource();
-    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 5) {
-      MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGpuKernel should be 5";
-    }
-
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    if (shape.size() != 4) {
-      MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGpuKernel should be 4";
-    }
-    is_null_input_ = CHECK_NULL_INPUT(shape);
-    if (is_null_input_) {
-      MS_LOG(WARNING) << "FusedBatchNormGpuKernel input is null";
-      InitSizeLists();
-      return true;
-    }
-    batch_ = SizeToInt(shape[0]);
-    channel_ = SizeToInt(shape[1]);
-    height_ = SizeToInt(shape[2]);
-    width_ = SizeToInt(shape[3]);
-
-    mode_ = CUDNN_BATCHNORM_SPATIAL;
-    epsilon_ = GetAttr<float>(kernel_node, "epsilon");
-    // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
-    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
-    if (node_name == "FusedBatchNorm") {
-      is_train_ = true;
-      exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
-    }
-
-    CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set x desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set y desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); - } - void InitSizeLists() override { - size_t input_size = 0; - size_t para_size = 0; - size_t output_size = 0; - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_mean_var_desc_, ¶_size), - "Get para size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_desc_, &output_size), "Get para size failed"); - } - input_size_list_.push_back(input_size); - input_size_list_.push_back(para_size); // scale - input_size_list_.push_back(para_size); // bias - input_size_list_.push_back(para_size); // mean - input_size_list_.push_back(para_size); // variance - - output_size_list_.push_back(output_size); - output_size_list_.push_back(para_size); // running mean - output_size_list_.push_back(para_size); // running variance - output_size_list_.push_back(para_size); // save mean - output_size_list_.push_back(para_size); // save variance - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); - } - - int batch_; - int channel_; - int height_; - int width_; - cudnnBatchNormMode_t mode_; - double epsilon_; - double exp_avg_factor_; - bool is_train_; - bool is_null_input_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t y_desc_; - cudnnTensorDescriptor_t scale_bias_mean_var_desc_; - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc deleted file mode 100644 index 3947aaea9a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
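Both op names (FusedBatchNorm and BatchNorm) resolve to cuDNN batch normalization in CUDNN_BATCHNORM_SPATIAL mode: one mean/variance pair per channel of the NCHW input, with scale, bias and statistics kept in float even for half tensors, which is why scale_bias_mean_var_desc_ is pinned to CUDNN_DATA_FLOAT. Conceptually, per channel c:

    // Training (ForwardTraining): mu/var are batch statistics over N*H*W, and the
    // running stats are blended with exp_avg_factor_ as cuDNN defines the factor:
    //   running_mean[c] = (1 - factor) * running_mean[c] + factor * mu[c];
    // Inference (ForwardInference): mu/var come straight from the running-stat inputs.
    y = gamma[c] * (x - mu[c]) / sqrtf(var[c] + epsilon) + beta[c];

save_mean and save_variance are the extra training outputs that cache the batch statistics so the backward kernel below does not have to recompute them.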
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - FusedBatchNormGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(FusedBatchNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - FusedBatchNormGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h deleted file mode 100644 index 712354b17c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batchnorm_grad_gpu_kernel.h +++ /dev/null @@ -1,178 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class FusedBatchNormGradGpuKernel : public GpuKernel { - public: - FusedBatchNormGradGpuKernel() - : batch_(0), - channel_(0), - height_(0), - width_(0), - mode_(CUDNN_BATCHNORM_SPATIAL), - epsilon_(10e-5), - is_null_input_(false), - x_desc_(nullptr), - dy_desc_(nullptr), - dx_desc_(nullptr), - scale_bias_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~FusedBatchNormGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(workspace); - VARIABLE_NOT_USED(stream_ptr); - if (is_null_input_) { - return true; - } - auto dy = GetDeviceAddress(inputs, 0); - auto x = GetDeviceAddress(inputs, 1); - auto scale = GetDeviceAddress(inputs, 2); - auto save_mean = GetDeviceAddress(inputs, 3); - auto save_variance = GetDeviceAddress(inputs, 4); - auto dx = GetDeviceAddress(outputs, 0); - auto bn_scale = GetDeviceAddress(outputs, 1); - auto bn_bias = GetDeviceAddress(outputs, 2); - - const float alpha_data_diff = 1; - const float beta_data_diff = 0; - const float alpha_param_diff = 1; - const float beta_param_diff = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnBatchNormalizationBackward(handle_, mode_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff, - &beta_param_diff, x_desc_, x, dy_desc_, dy, dx_desc_, dx, scale_bias_desc_, scale, - bn_scale, bn_bias, epsilon_, save_mean, save_variance), - "Kernel Launch Failed."); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", FusedBatchNormGradGpuKernel should be 5"; - } - - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (shape.size() != 4) { - MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradGpuKernel should be 4"; - return false; - } - is_null_input_ = CHECK_NULL_INPUT(shape); - if (is_null_input_) { - MS_LOG(WARNING) << "FusedBatchNormGradGpuKernel input is null"; - InitSizeLists(); - return true; - } - batch_ = SizeToInt(shape[0]); - channel_ = SizeToInt(shape[1]); - height_ = SizeToInt(shape[2]); - width_ = SizeToInt(shape[3]); - - mode_ = CUDNN_BATCHNORM_SPATIAL; - epsilon_ = GetAttr(kernel_node, "epsilon"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), - "Set dy desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, 
channel_, height_, width_), - "Set dx desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_desc_), "Create para desc failed"); - } - - void InitSizeLists() override { - size_t input_size = 0; - size_t para_size = 0; - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &input_size), "Get input size failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_desc_, ¶_size), "Get input size failed"); - } - - input_size_list_.push_back(input_size); - input_size_list_.push_back(input_size); - input_size_list_.push_back(para_size); - input_size_list_.push_back(para_size); - input_size_list_.push_back(para_size); - - output_size_list_.push_back(input_size); - output_size_list_.push_back(para_size); - output_size_list_.push_back(para_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_desc_), "Destroy para desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - } - - int batch_; - int channel_; - int height_; - int width_; - - cudnnBatchNormMode_t mode_; - double epsilon_; - bool is_null_input_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t dy_desc_; - cudnnTensorDescriptor_t dx_desc_; - cudnnTensorDescriptor_t scale_bias_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FUSED_BATCHNORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc deleted file mode 100644 index 32d91be80a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
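cudnnBatchNormalizationBackward, used by the grad kernel above, returns all three gradients of the forward pass in one call. A host-side reference for what it computes in spatial mode (a sketch: cuDNN actually consumes the saved mean and inverse standard deviation produced by the forward kernel, assumed here as mean/inv_std):

    #include <cstddef>
    // NCHW layout; dgamma/dbeta are per channel, dx matches x.
    void BatchNormBackwardRef(const float *x, const float *dy, const float *gamma, const float *mean,
                              const float *inv_std, float *dx, float *dgamma, float *dbeta,
                              size_t N, size_t C, size_t HW) {
      for (size_t c = 0; c < C; ++c) {
        double sum_dy = 0.0, sum_dy_xhat = 0.0;
        const double m = static_cast<double>(N * HW);
        for (size_t n = 0; n < N; ++n)
          for (size_t s = 0; s < HW; ++s) {
            const size_t i = (n * C + c) * HW + s;
            const double xhat = (x[i] - mean[c]) * inv_std[c];
            sum_dy += dy[i];
            sum_dy_xhat += dy[i] * xhat;
          }
        dbeta[c] = static_cast<float>(sum_dy);    // d(bias)
        dgamma[c] = static_cast<float>(sum_dy_xhat);  // d(scale)
        for (size_t n = 0; n < N; ++n)
          for (size_t s = 0; s < HW; ++s) {
            const size_t i = (n * C + c) * HW + s;
            const double xhat = (x[i] - mean[c]) * inv_std[c];
            dx[i] = static_cast<float>(gamma[c] * inv_std[c] *
                                       (dy[i] - sum_dy / m - xhat * sum_dy_xhat / m));
          }
      }
    }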
- */ - -#include "kernel/gpu/nn/gelu_grad_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(GeluGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - GeLUGpuGradKernel, float) -MS_REG_GPU_KERNEL_ONE(GeluGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - GeLUGpuGradKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h deleted file mode 100644 index 6415349012..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class GeLUGpuGradKernel : public GpuKernel { - public: - GeLUGpuGradKernel() : input_size_(0) {} - ~GeLUGpuGradKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *dy_addr = GetDeviceAddress(inputs, 0); - T *x_addr = GetDeviceAddress(inputs, 1); - T *dx_addr = GetDeviceAddress(outputs, 0); - - GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (auto dim : input_shape) { - input_size_ *= dim; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc deleted file mode 100644 index ca54ff68ad..0000000000 --- 
a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/gelu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - GeluGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - GeluGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h deleted file mode 100644 index 60968d109b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/gelu_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class GeluGpuKernel : public GpuKernel { - public: - GeluGpuKernel() : input_size_(0) {} - ~GeluGpuKernel() override = default; - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - - Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - input_size_ = sizeof(T); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (auto dim : input_shape) { - input_size_ *= dim; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc deleted file mode 100644 index 19e4dc17a6..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
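The Gelu/GeluGrad pair wraps the usual tanh approximation of GeLU. gelu_impl.cuh is not part of this hunk, so the reference below assumes the standard constants:

    #include <cmath>
    // y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inline float GeluRef(float x) {
      const float c = 0.7978845608f;  // sqrt(2 / pi)
      return 0.5f * x * (1.0f + tanhf(c * (x + 0.044715f * x * x * x)));
    }
    // dx = dy * d/dx gelu(x), using sech^2(u) = 1 - tanh^2(u)
    inline float GeluGradRef(float dy, float x) {
      const float c = 0.7978845608f;
      const float u = c * (x + 0.044715f * x * x * x);
      const float t = tanhf(u);
      const float du = c * (1.0f + 3.0f * 0.044715f * x * x);
      return dy * (0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * du);
    }

This lines up with the GeluGrad kernel shown earlier: its Launch reads only dy and the original x and reevaluates the derivative directly instead of caching intermediates (the third registered input is sized in InitSizeLists but never touched by Launch).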
- */ - -#include "kernel/gpu/nn/layer_norm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LayerNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LayerNormGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LayerNorm, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LayerNormGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h deleted file mode 100644 index d5ec3ff8f2..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.h +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class LayerNormGpuKernel : public GpuKernel { - public: - LayerNormGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} - ~LayerNormGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto x = GetDeviceAddress(inputs, 0); - auto gamma = GetDeviceAddress(inputs, 1); - auto beta = GetDeviceAddress(inputs, 2); - auto y = GetDeviceAddress(outputs, 0); - auto mean = GetDeviceAddress(outputs, 1); - auto variance = GetDeviceAddress(outputs, 2); - - const T epsilon = 10e-12; - LayerNorm(input_row_, input_col_, param_dim_, epsilon, x, gamma, beta, y, mean, variance, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - int begin_norm_axis = GetAttr(kernel_node, "begin_norm_axis"); - int begin_params_axis = GetAttr(kernel_node, "begin_params_axis"); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (begin_norm_axis < 0) { - begin_norm_axis += input_shape.size(); - } - - if (begin_params_axis < 0) { - begin_params_axis += input_shape.size(); - } - - for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { - input_row_ *= input_shape[i]; - } - - for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { - 
input_col_ *= input_shape[i]; - } - - for (size_t i = begin_params_axis; i < input_shape.size(); i++) { - param_dim_ *= input_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - - output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - output_size_list_.push_back(input_row_ * sizeof(T)); - output_size_list_.push_back(input_row_ * sizeof(T)); - return; - } - - private: - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - - int input_row_; - int input_col_; - int param_dim_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc deleted file mode 100644 index 7991d42499..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/layer_norm_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LayerNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LayerNormGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LayerNormGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LayerNormGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h deleted file mode 100644 index 83bdedb9b3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ - -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template <typename T> -class LayerNormGradGpuKernel : public GpuKernel { - public: - LayerNormGradGpuKernel() : input_row_(1), input_col_(1), param_dim_(1) {} - ~LayerNormGradGpuKernel() override = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - auto x = GetDeviceAddress<T>(inputs, 0); - auto dy = GetDeviceAddress<T>(inputs, 1); - auto var = GetDeviceAddress<T>(inputs, 2); - auto mean = GetDeviceAddress<T>(inputs, 3); - auto gamma = GetDeviceAddress<T>(inputs, 4); - auto dx = GetDeviceAddress<T>(outputs, 0); - auto dg = GetDeviceAddress<T>(outputs, 1); - auto db = GetDeviceAddress<T>(outputs, 2); - - const T epsilon = 10e-12; - LayerNormGrad(input_row_, input_col_, param_dim_, epsilon, dy, x, mean, var, gamma, dx, dg, db, - reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - int begin_norm_axis = GetAttr<int>(kernel_node, "begin_norm_axis"); - int begin_params_axis = GetAttr<int>(kernel_node, "begin_params_axis"); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (begin_norm_axis < 0) { - begin_norm_axis += input_shape.size(); - } - - if (begin_params_axis < 0) { - begin_params_axis += input_shape.size(); - } - - for (size_t i = 0; i < IntToSize(begin_norm_axis); i++) { - input_row_ *= input_shape[i]; - } - - for (size_t i = begin_norm_axis; i < input_shape.size(); i++) { - input_col_ *= input_shape[i]; - } - - for (size_t i = begin_params_axis; i < input_shape.size(); i++) { - param_dim_ *= input_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - input_size_list_.push_back(input_row_ * sizeof(T)); - input_size_list_.push_back(input_row_ * sizeof(T)); - input_size_list_.push_back(param_dim_ * sizeof(T)); - - output_size_list_.push_back(input_row_ * input_col_ * sizeof(T)); - output_size_list_.push_back(param_dim_ * sizeof(T)); - output_size_list_.push_back(param_dim_ * sizeof(T)); - return; - } - - private: - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - - int input_row_; - int input_col_; - int param_dim_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LAYER_NORM_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc deleted file mode 100644 index c745c216f7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/lstm_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTM, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h deleted file mode 100644 index 42eda96b02..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_gpu_kernel.h +++ /dev/null @@ -1,247 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ - -#include <cuda_runtime_api.h> -#include <vector> -#include <memory> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class LstmGpuKernel : public GpuKernel { - public: - LstmGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - x_desc_(nullptr), - hx_desc_(nullptr), - cx_desc_(nullptr), - w_desc_(nullptr), - dropout_desc_(nullptr), - y_desc_(nullptr), - hy_desc_(nullptr), - cy_desc_(nullptr), - rnn_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGpuKernel() override { DestroyResource(); } - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto x_addr = GetDeviceAddress<T>(inputs, 0); - auto hx_addr = GetDeviceAddress<T>(inputs, 1); - auto cx_addr = GetDeviceAddress<T>(inputs, 2); - auto w_addr = GetDeviceAddress<T>(inputs, 3); - auto y_addr = GetDeviceAddress<T>(outputs, 0); - auto hy_addr = GetDeviceAddress<T>(outputs, 1); - auto cy_addr = GetDeviceAddress<T>(outputs, 2); - auto reserved_addr = GetDeviceAddress<T>(outputs, 3); - auto states_addr = GetDeviceAddress<T>(outputs, 4); - void *workspace_addr = GetDeviceAddress<T>(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, output_size_list_[4], 0), - "set dropout_desc failed"); - states_init_ = true; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNForwardTraining(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, - w_desc_, w_addr, y_desc_.get(), y_addr, hy_desc_, hy_addr, cy_desc_, cy_addr, - workspace_addr, workspace_size_list_[0], reserved_addr, reserved_size_), - "launch lstm kernel failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); - input_size_ = SizeToInt(input_shape[2]); - - input_size_ = GetAttr<int>(kernel_node, "input_size"); - hidden_size_ = GetAttr<int>(kernel_node, "hidden_size"); - num_layers_ = GetAttr<int>(kernel_node, "num_layers"); - has_bias_ = GetAttr<bool>(kernel_node, "has_bias"); - bidirectional_ = GetAttr<bool>(kernel_node, "bidirectional"); - dropout_ = GetAttr<float>(kernel_node, "dropout"); - - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - CreateTensorDescGrp();
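- // hx/cx (and hy/cy) are 3-D state tensors laid out as {num_layers * num_directions, batch, hidden_size},
- // the shape cuDNN expects for LSTM hidden and cell state; the next statement builds exactly those dims.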
- int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed"); - cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), - "get reserve size failed"); - InitSizeLists(); - return true; - } - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1};
- int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; - - x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - } - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hy_desc_), "create hy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cy_desc_), "create cy_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - void InitSizeLists() override { - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - input_size_list_.push_back(x_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(weight_size_);
- size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - output_size_list_.push_back(y_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - output_size_list_.push_back(state_size); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cy_desc_), "destroy cy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hy_desc_), "destroy hy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc failed"); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); - } - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - // input desc - std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_; - cudnnTensorDescriptor_t hx_desc_; - cudnnTensorDescriptor_t cx_desc_; - cudnnFilterDescriptor_t w_desc_; - cudnnDropoutDescriptor_t dropout_desc_; - std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_; - cudnnTensorDescriptor_t hy_desc_; - cudnnTensorDescriptor_t cy_desc_; - cudnnRNNDescriptor_t rnn_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc deleted file mode 100644 index ab88308d4e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "kernel/gpu/nn/lstm_grad_data_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTMGradData, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGradDataGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTMGradData, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGradDataGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h deleted file mode 100644 index 6eeefa262c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_data_gpu_kernel.h +++ /dev/null @@ -1,284 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ - -#include <cuda_runtime_api.h> -#include <vector> -#include <memory> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class LstmGradDataGpuKernel : public GpuKernel { - public: - LstmGradDataGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - rnn_desc_(nullptr), - y_desc_(nullptr), - dy_desc_(nullptr), - dhy_desc_(nullptr), - dcy_desc_(nullptr), - w_desc_(nullptr), - hx_desc_(nullptr), - cx_desc_(nullptr), - dropout_desc_(nullptr), - dx_desc_(nullptr), - dhx_desc_(nullptr), - dcx_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGradDataGpuKernel() override { DestroyResource(); } - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto y_addr = GetDeviceAddress<T>(inputs, 0); - auto dy_addr = GetDeviceAddress<T>(inputs, 1); - auto dhy_addr = GetDeviceAddress<T>(inputs, 2); - auto dcy_addr = GetDeviceAddress<T>(inputs, 3); - auto w_addr = GetDeviceAddress<T>(inputs, 4); - auto hx_addr = GetDeviceAddress<T>(inputs, 5); - auto cx_addr = GetDeviceAddress<T>(inputs, 6); - auto reserved_addr = GetDeviceAddress<T>(inputs, 7); - auto states_addr = GetDeviceAddress<T>(inputs, 8); - auto dx_addr = GetDeviceAddress<T>(outputs, 0); - auto dhx_addr = GetDeviceAddress<T>(outputs, 1); - auto dcx_addr = GetDeviceAddress<T>(outputs, 2); - void *workspace_addr = GetDeviceAddress<T>(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[8], 0), - "restore dropout state failed"); - states_init_ = true; - } - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNBackwardData(handle_, rnn_desc_, seq_len_, y_desc_.get(), y_addr, dy_desc_.get(), dy_addr, dhy_desc_, - dhy_addr, dcy_desc_, dcy_addr, w_desc_, w_addr, hx_desc_, hx_addr, cx_desc_, cx_addr, - dx_desc_.get(), dx_addr, dhx_desc_, dhx_addr, dcx_desc_, dcx_addr, workspace_addr, - workspace_size_list_[0], reserved_addr, reserved_size_), - "launch lstm back data kernel failed"); - - CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr)), - "stream synchronize failed."); - return true; - } - void GetAttrs(const CNodePtr &kernel_node) { - input_size_ = GetAttr<int>(kernel_node, "input_size"); - hidden_size_ = GetAttr<int>(kernel_node, "hidden_size"); - num_layers_ = GetAttr<int>(kernel_node, "num_layers"); - has_bias_ = GetAttr<bool>(kernel_node, "has_bias"); - bidirectional_ = GetAttr<bool>(kernel_node, "bidirectional"); - dropout_ = GetAttr<float>(kernel_node, "dropout"); - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); -
GetAttrs(kernel_node); - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - CreateTensorDescGrp(); - int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dhy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dcy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(cx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dhx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dcx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), "set dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed"); - cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, dx_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set w_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &reserved_size_), "get size failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhy_desc_), "create dhy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcy_desc_), "create dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&cx_desc_), "create cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&w_desc_), "create w_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dhx_desc_), "create dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dcx_desc_), "create dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - 
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - - void InitSizeLists() override { - size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - input_size_list_.push_back(y_size); - input_size_list_.push_back(y_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(weight_size_); - input_size_list_.push_back(h_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - input_size_list_.push_back(state_size); - - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - output_size_list_.push_back(x_size); - output_size_list_.push_back(h_size); - output_size_list_.push_back(h_size); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, dx_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcx_desc_), "destroy dcx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhx_desc_), "destroy dhx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(w_desc_), "destroy w_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(cx_desc_), "destroy cx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dcy_desc_), "destroy dcy_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dhy_desc_), "destroy dhy_desc_ failed"); - DestroyTensorDescGrp(); - } - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1}; - int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 
2 : 1), 1}; - - dx_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - dy_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dx_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), - "set dx_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_[i]), "create dy_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(dy_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), - "set dy_desc_ failed"); - } - } - - void DestroyTensorDescGrp() { - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_[i]), "destroy dy_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_[i]), "destroy x_desc failed"); - } - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - cudnnRNNDescriptor_t rnn_desc_; - - // input desc - std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_; - std::unique_ptr<cudnnTensorDescriptor_t[]> dy_desc_; - cudnnTensorDescriptor_t dhy_desc_; - cudnnTensorDescriptor_t dcy_desc_; - cudnnFilterDescriptor_t w_desc_; - cudnnTensorDescriptor_t hx_desc_; - cudnnTensorDescriptor_t cx_desc_; - - cudnnDropoutDescriptor_t dropout_desc_; - - // output desc - std::unique_ptr<cudnnTensorDescriptor_t[]> dx_desc_; - cudnnTensorDescriptor_t dhx_desc_; - cudnnTensorDescriptor_t dcx_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_DATA_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc deleted file mode 100644 index 856a986e07..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.cc +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - LstmGradWeightGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(LSTMGradWeight, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - LstmGradWeightGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h deleted file mode 100644 index a1a4852c84..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/lstm_grad_weight_gpu_kernel.h +++ /dev/null @@ -1,231 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ - -#include <cuda_runtime_api.h> -#include <vector> -#include <memory> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -namespace mindspore { -namespace kernel { -template <typename T> -class LstmGradWeightGpuKernel : public GpuKernel { - public: - LstmGradWeightGpuKernel() - : batch_size_(0), - seq_len_(0), - input_size_(0), - hidden_size_(0), - num_layers_(0), - has_bias_(false), - bidirectional_(false), - states_init_(false), - dropout_(0), - weight_size_(0), - reserved_size_(0), - rnn_desc_(nullptr), - dropout_desc_(nullptr), - x_desc_(nullptr), - hx_desc_(nullptr), - y_desc_(nullptr), - dw_desc_(nullptr), - handle_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT) {} - ~LstmGradWeightGpuKernel() override { DestroyResource(); } - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - VARIABLE_NOT_USED(stream_ptr); - auto x_addr = GetDeviceAddress<T>(inputs, 0); - auto hx_addr = GetDeviceAddress<T>(inputs, 1); - auto y_addr = GetDeviceAddress<T>(inputs, 2); - auto reserved_addr = GetDeviceAddress<T>(inputs, 3); - auto states_addr = GetDeviceAddress<T>(inputs, 4); - auto dw_addr = GetDeviceAddress<T>(outputs, 0); - void *workspace_addr = GetDeviceAddress<T>(workspace, 0); - - if (!states_init_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRestoreDropoutDescriptor(dropout_desc_, handle_, dropout_, states_addr, input_size_list_[4], 0), - "restore dropout state failed"); -
states_init_ = true; - } - - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemsetAsync(dw_addr, 0, outputs[0]->size, reinterpret_cast<cudaStream_t>(stream_ptr)), "cudaMemSet Failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnRNNBackwardWeights(handle_, rnn_desc_, seq_len_, x_desc_.get(), x_addr, hx_desc_, hx_addr, y_desc_.get(), - y_addr, workspace_addr, workspace_size_list_[0], dw_desc_, dw_addr, reserved_addr, - reserved_size_), - "launch lstm back weight kernel failed"); - - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - seq_len_ = SizeToInt(input_shape[0]); - batch_size_ = SizeToInt(input_shape[1]); - - input_size_ = GetAttr<int>(kernel_node, "input_size"); - hidden_size_ = GetAttr<int>(kernel_node, "hidden_size"); - num_layers_ = GetAttr<int>(kernel_node, "num_layers"); - has_bias_ = GetAttr<bool>(kernel_node, "has_bias"); - bidirectional_ = GetAttr<bool>(kernel_node, "bidirectional"); - dropout_ = GetAttr<float>(kernel_node, "dropout"); - - cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t direction = bidirectional_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t rnn_mode = CUDNN_LSTM; - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - - CreateTensorDescGrp(); - int hx_dims[3]{num_layers_ * (bidirectional_ ? 2 : 1), batch_size_, hidden_size_}; - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptorEx(hx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, hx_dims), - "set hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0), - "set dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, - input_mode, direction, rnn_mode, algo, cudnn_data_type_), - "set rnn_desc failed");
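- // cuDNN's "double bias" mode keeps two bias vectors per gate (input-projection and recurrent),
- // so has_bias toggles between CUDNN_RNN_DOUBLE_BIAS and no bias at all rather than a single-bias mode.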
- cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed"); - - auto weight_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNParamsSize(handle_, rnn_desc_, x_desc_[0], &weight_size_, cudnn_data_type_), - "get weight_size_ failed"); - if (weight_size != weight_size_) { - MS_LOG(EXCEPTION) << "weight size: " << weight_size << " error, expect: " << weight_size_ << " ."; - } - int w_dims[3] = {SizeToInt(weight_size_ / 4), 1, 1}; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, 3, w_dims), - "set dw_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &reserved_size_), - "get reserve size failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&hx_desc_), "create hx_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateFilterDescriptor(&dw_desc_), "create dw_desc_ failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateDropoutDescriptor(&dropout_desc_), "create dropout_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateRNNDescriptor(&rnn_desc_), "create rnn_desc failed"); - } - void InitSizeLists() override { - size_t x_size = IntToSize(seq_len_ * batch_size_ * input_size_) * sizeof(T); - - size_t h_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(hx_desc_, &h_size), "get h size failed"); - - size_t y_size = IntToSize(seq_len_ * batch_size_ * hidden_size_ * (bidirectional_ ? 2 : 1)) * sizeof(T); - input_size_list_.push_back(x_size); - input_size_list_.push_back(h_size); - input_size_list_.push_back(y_size); - input_size_list_.push_back(reserved_size_); - size_t state_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnDropoutGetStatesSize(handle_, &state_size), "get dropout states size failed"); - input_size_list_.push_back(state_size); - - output_size_list_.push_back(weight_size_); - - size_t workspace_size = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetRNNWorkspaceSize(handle_, rnn_desc_, seq_len_, x_desc_.get(), &workspace_size), - "get workspace size failed"); - workspace_size_list_.push_back(workspace_size); - } - - private: - void CreateTensorDescGrp() { - int x_dims[3]{batch_size_, input_size_, 1};
- int y_dims[3]{batch_size_, hidden_size_ * (bidirectional_ ? 2 : 1), 1}; - - x_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - y_desc_ = std::make_unique<cudnnTensorDescriptor_t[]>(seq_len_); - - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_[i]), "create x_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(x_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, x_dims), "set x_desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_[i]), "create y_desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensorNdDescriptorEx(y_desc_[i], CUDNN_TENSOR_NCHW, cudnn_data_type_, 3, y_dims), "set y_desc failed"); - } - } - void DestroyTensorDescGrp() { - for (size_t i = 0; i < IntToSize(seq_len_); ++i) { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_[i]), "destroy y_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_[i]), "destroy x_desc failed"); - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyRNNDescriptor(rnn_desc_), "destroy rnn_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyDropoutDescriptor(dropout_desc_), "destroy dropout_desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyFilterDescriptor(dw_desc_), "destroy dw_desc_ failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(hx_desc_), "destroy hx_desc_ failed"); - DestroyTensorDescGrp(); - } - - int batch_size_; - int seq_len_; - int input_size_; - int hidden_size_; - int num_layers_; - - bool has_bias_; - bool bidirectional_; - bool states_init_; - float dropout_; - - size_t weight_size_; - size_t reserved_size_; - - cudnnRNNDescriptor_t rnn_desc_; - cudnnDropoutDescriptor_t dropout_desc_; - - // input desc - std::unique_ptr<cudnnTensorDescriptor_t[]> x_desc_; - cudnnTensorDescriptor_t hx_desc_; - std::unique_ptr<cudnnTensorDescriptor_t[]> y_desc_; - - // output desc - cudnnFilterDescriptor_t dw_desc_; - - cudnnHandle_t handle_; - cudnnDataType_t cudnn_data_type_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_LSTM_GRAD_WEIGHT_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc deleted file mode 100644 index e8b2b17706..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "kernel/gpu/nn/momentum_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - MomentumGpuKernel, float, float) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, half) -MS_REG_GPU_KERNEL_TWO(ApplyMomentum, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat16), - MomentumGpuKernel, half, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h deleted file mode 100644 index 5abfb9e97b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ - -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/momentum_impl.cuh" -namespace mindspore { -namespace kernel { -template <typename T, typename S> -class MomentumGpuKernel : public GpuKernel { - public: - MomentumGpuKernel() - : variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), momentum_size_(0) {} - ~MomentumGpuKernel() override = default; - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, - void *stream_ptr) override { - T *variable = GetDeviceAddress<T>(inputs, 0); - T *accumulation = GetDeviceAddress<T>(inputs, 1); - S *learning_rate = GetDeviceAddress<S>(inputs, 2); - T *gradient = GetDeviceAddress<T>(inputs, 3); - S *momentum = GetDeviceAddress<S>(inputs, 4); - MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum, - reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 5) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but momentum needs 5 inputs."; - return false; - } - - variable_size_ = sizeof(T); - accumulation_size_ = sizeof(T); - learning_rate_size_ = sizeof(S); - gradient_size_ = sizeof(T); - momentum_size_ = sizeof(S); - - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < variable_shape.size(); i++) { - variable_size_ *= variable_shape[i]; - } - auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < accumulation_shape.size(); i++) { - accumulation_size_ *= accumulation_shape[i]; - } - auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); - for (size_t i = 0; i < gradient_shape.size(); i++) { - gradient_size_ *= gradient_shape[i]; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(variable_size_); - input_size_list_.push_back(accumulation_size_); - input_size_list_.push_back(learning_rate_size_); - input_size_list_.push_back(gradient_size_); - input_size_list_.push_back(momentum_size_); - output_size_list_.push_back(0); - } - - private: - size_t variable_size_; - size_t accumulation_size_; - size_t learning_rate_size_; - size_t gradient_size_; - size_t momentum_size_; - - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_MOMENTUM_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc deleted file mode 100644 index e871af360a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/pooling_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - PoolingGpuFwdKernel, half) -MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - PoolingGpuFwdKernel, float) -MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - PoolingGpuFwdKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h deleted file mode 100644 index 0dda1e8998..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_gpu_kernel.h +++ /dev/null @@ -1,252 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ - -#include <vector> -#include <string> -#include <algorithm> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class PoolingGpuFwdKernel : public GpuKernel { - public: - PoolingGpuFwdKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - output_descriptor_(nullptr), - pooling_descriptor_(nullptr), - padded_descriptor_(nullptr), - pooling_mode_(CUDNN_POOLING_MAX), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - pad_value_(0), - is_null_input_(false), - input_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~PoolingGpuFwdKernel() override { DestroyResource(); } - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, - const std::vector<AddressPtr> &outputs, void *stream_ptr) { - if (is_null_input_) { - return true; - } - T *input_addr = reinterpret_cast<T *>(inputs[0]->addr); - T *output_addr = reinterpret_cast<T *>(outputs[0]->addr); - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded_addr = reinterpret_cast<T *>(workspace[0]->addr); - CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, - reinterpret_cast<cudaStream_t>(stream_ptr)); - - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_, - padded_addr, &beta, output_descriptor_, output_addr), - "cudnnPoolingForward failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, - input_addr, &beta, output_descriptor_, output_addr), - "cudnnPoolingForward failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null."; - InitSizeLists(); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
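- // "ksize" and "strides" come in NCHW order, so indices 2 and 3 hold the spatial (height, width)
- // window and stride values used below; indices 0 and 1 (batch, channel) are not used here.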
GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize")); - int window_height = window[2]; - int window_width = window[3]; - stride_ = GetValue>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides")); - SetPoolingMode(kernel_node); - if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) { - SetPad(input_shape, window_height, window_width); - } else { - pad_height_ = 0; - pad_width_ = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, - window_width, pad_height_, pad_width_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - - InitSizeLists(); - return true; - } - - protected: - void InitResource() { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), - "cudnnCreatePoolingDescriptor failed"); - } - void InitSizeLists() { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast(&input_size_)), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast(&output_size_)), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast(&padded_size_)), - "cudnnGetTensorSizeInBytes failed"); - workspace_size_list_.push_back(padded_size_); - if (padded_size_ == 0) { - MS_LOG(EXCEPTION) << "Padded size is 0."; - } - } - return; - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 inputs."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { - n_ = SizeToInt(input_shape[0]); - c_ = SizeToInt(input_shape[1]); - old_height_ = SizeToInt(input_shape[2]); - old_width_ = SizeToInt(input_shape[3]); - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? 
(old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - window_height, window_width, use_pad_ ? 0 : pad_top_, - use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - void SetPoolingMode(const CNodePtr &kernel_node) { - pad_mode_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding")); - mode_ = AnfAlgo::GetCNodeName(kernel_node); - if (mode_ == "AvgPool") { - pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - pad_value_ = 0.0; - } else { - pooling_mode_ = CUDNN_POOLING_MAX; - pad_value_ = kSignedMinFloat; - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), - "cudnnDestroyPoolingDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnTensorDescriptor_t output_descriptor_; - cudnnPoolingDescriptor_t pooling_descriptor_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; - std::vector stride_; - std::string mode_; - std::string pad_mode_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - cudnnDataType_t cudnn_data_type_; - - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - float pad_value_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc deleted file mode 100644 index c3d4a44943..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
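A standalone sketch of the SAME-padding arithmetic that SetPad() above implements: keep ceil(old/stride) output positions and pad whatever is missing so the last window still fits. cuDNN only accepts a symmetric pad, which is why the kernel routes odd pad totals through the CalPad workspace path. SamePadTotal is a hypothetical helper for illustration, not MindSpore API.

#include <algorithm>
#include <cstdio>

int SamePadTotal(int old_size, int window, int stride) {
  int out = (old_size + stride - 1) / stride;               // ceil division, same as the ternary in SetPad()
  return std::max(0, (out - 1) * stride + window - old_size);
}

int main() {
  int pad_h = SamePadTotal(/*old_height=*/7, /*window=*/3, /*stride=*/2);
  // total 2 -> pad_top 1, pad_bottom 1; an odd total would split asymmetrically
  std::printf("total pad: %d, top: %d, bottom: %d\n", pad_h, pad_h / 2, pad_h - pad_h / 2);
  return 0;
}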
- */ - -#include "kernel/gpu/nn/pooling_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(MaxPoolGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - PoolingGradGpuKernel, half) -MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - PoolingGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE(AvgPoolGradGpu, - KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - PoolingGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h deleted file mode 100644 index e8f1ebc1af..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/pooling_grad_gpu_kernel.h +++ /dev/null @@ -1,296 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ - -#include -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/pad_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class PoolingGradGpuKernel : public GpuKernel { - public: - PoolingGradGpuKernel() - : cudnn_handle_(nullptr), - pooling_descriptor_(nullptr), - y_descriptor_(nullptr), - dy_descriptor_(nullptr), - x_descriptor_(nullptr), - dx_descriptor_(nullptr), - padded_descriptor_(nullptr), - pooling_mode_(CUDNN_POOLING_MAX), - cudnn_data_type_(CUDNN_DATA_FLOAT), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - pad_value_(0), - is_null_input_(false), - input_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} - ~PoolingGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *x_data = GetDeviceAddress(inputs, 0); - T *y = GetDeviceAddress(inputs, 1); - T *dy = GetDeviceAddress(inputs, 2); - T *dx = GetDeviceAddress(outputs, 0); - - const float alpha = 1; - const float beta = 0; - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { - T *padded = GetDeviceAddress(workspace, 0); - T *padded_dx = GetDeviceAddress(workspace, 1); - - CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, - reinterpret_cast(stream_ptr)); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, - padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx), - "cudnnPoolingBackward failed"); - - CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_, - old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy, - x_descriptor_, x_data, &beta, dx_descriptor_, dx), - "cudnnPoolingBackward failed"); - } - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - if (!CheckParam(kernel_node)) { - return false; - } - auto window = GetAttr>(kernel_node, "ksize"); - int window_height = window[2]; - int window_width = window[3]; - SetPoolingMode(kernel_node); - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); - if (is_null_input_) { - MS_LOG(WARNING) << "PoolingGradGpuKernel input is null."; - InitSizeLists(); - return true; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, 
SizeToInt(input_mask[0]), - SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])), - "cudnnSetTensor4dDescriptor"); - - auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]), - SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])), - "cudnnSetTensor4dDescriptor"); - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]), - SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])), - "cudnnSetTensor4dDescriptor failed"); - if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) { - SetPad(input_shape, window_height, window_width); - } else { - if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) { - pad_height_ = 0; - pad_width_ = 0; - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height, - window_width, pad_height_, pad_width_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])), - "cudnnSetTensor4dDescriptor"); - } - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_), - "cudnnCreatePoolingDescriptor failed"); - } - void InitSizeLists() override { - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - - if (!is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_), - "cudnnGetTensorSizeInBytes failed"); - } - input_size_list_.push_back(input_size_); - - if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_), - "cudnnGetTensorSizeInBytes failed"); - if (padded_size_ == 0) { - MS_LOG(EXCEPTION) << 
"Padded size is 0."; - } - workspace_size_list_.push_back(padded_size_); - workspace_size_list_.push_back(padded_size_); - } - return; - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but PoolingGradGpuKernel needs 3 inputs."; - return false; - } - return true; - } - void SetPad(const std::vector &input_shape, const int &window_height, const int &window_width) { - n_ = SizeToInt(input_shape[0]); - c_ = SizeToInt(input_shape[1]); - old_height_ = SizeToInt(input_shape[2]); - old_width_ = SizeToInt(input_shape[3]); - pad_height_ = - std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2]) - : (old_height_ / stride_[2]) + 1) - - 1) * - stride_[2] + - window_height - old_height_); - pad_width_ = - std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3]) - : (old_width_ / stride_[3]) + 1) - - 1) * - stride_[3] + - window_width - old_width_); - pad_top_ = pad_height_ / 2; - pad_left_ = pad_width_ / 2; - if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) { - use_pad_ = false; - } - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, - c_, old_height_ + pad_height_, old_width_ + pad_width_), - "cudnnSetTensor4dDescriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]), - SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0), - SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)), - "cudnnSetTensor4dDescriptor"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, - window_height, window_width, use_pad_ ? 0 : pad_top_, - use_pad_ ? 
0 : pad_left_, stride_[2], stride_[3]), - "cudnnSetPooling2dDescriptor failed"); - } - void SetPoolingMode(const CNodePtr &kernel_node) { - pad_mode_ = GetAttr(kernel_node, "padding"); - stride_ = GetAttr>(kernel_node, "strides"); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - mode_ = AnfAlgo::GetCNodeName(kernel_node); - if (mode_ == "AvgPoolGradGpu") { - pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - pad_value_ = 0.0; - } else { - pooling_mode_ = CUDNN_POOLING_MAX; - pad_value_ = kSignedMinFloat; - } - } - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_), - "cudnnDestroyPoolingDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_descriptor_), "cudnnDestroyTensorDescriptor failed"); - } - - cudnnHandle_t cudnn_handle_; - cudnnPoolingDescriptor_t pooling_descriptor_; - cudnnTensorDescriptor_t y_descriptor_; - cudnnTensorDescriptor_t dy_descriptor_; - cudnnTensorDescriptor_t x_descriptor_; - cudnnTensorDescriptor_t dx_descriptor_; - cudnnTensorDescriptor_t padded_descriptor_; - cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX; - std::vector stride_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - std::string mode_; - std::string pad_mode_; - cudnnDataType_t cudnn_data_type_; - int old_height_; - int old_width_; - int pad_height_; - int pad_width_; - int pad_top_; - int pad_left_; - int n_; - int c_; - float pad_value_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t padded_size_; - size_t workspace_size_; - bool use_pad_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc deleted file mode 100644 index 032e8eeec4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
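PoolingGradGpuKernel mirrors the forward pad trick: it pads x into workspace[0], runs cudnnPoolingBackward against the padded descriptors into workspace[1], then crops the padded gradient back to the original shape via CalPadGrad. A host-side sketch of that final crop, assuming NCHW layout; the real work is the CUDA kernel in pad_impl.cuh, and CropPad here is illustrative only.

#include <cstddef>
#include <vector>

std::vector<float> CropPad(const std::vector<float> &padded, int n, int c, int h, int w,
                           int padded_h, int padded_w, int pad_top, int pad_left) {
  std::vector<float> out(static_cast<size_t>(n) * c * h * w);
  for (int i = 0; i < n * c; ++i) {        // every (batch, channel) plane
    for (int y = 0; y < h; ++y) {
      for (int x = 0; x < w; ++x) {
        // read from the padded plane, offset by the top/left pad
        out[(static_cast<size_t>(i) * h + y) * w + x] =
            padded[(static_cast<size_t>(i) * padded_h + y + pad_top) * padded_w + x + pad_left];
      }
    }
  }
  return out;
}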
- */
-
-#include "kernel/gpu/nn/rmsprop_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(ApplyRMSProp,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      RMSPropGpuKernel, float)
-
-MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      RMSPropGpuKernel, float)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h
deleted file mode 100644
index 9e148b690d..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h
+++ /dev/null
@@ -1,121 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_
-
-#include <vector>
-#include "kernel/gpu/gpu_kernel.h"
-#include "kernel/gpu/gpu_kernel_factory.h"
-#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-template <typename T>
-class RMSPropGpuKernel : public GpuKernel {
- public:
-  RMSPropGpuKernel() : size_(1), use_center_(false), decay_(0.0), momentum_(0.9), epsilon_(1e-12) {}
-  ~RMSPropGpuKernel() override = default;
-
-  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
-  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
-  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream) override {
-    if (!use_center_) {
-      T *variable = GetDeviceAddress<T>(inputs, 0);
-      T *mean_square = GetDeviceAddress<T>(inputs, 1);
-      T *moment = GetDeviceAddress<T>(inputs, 2);
-      T *learning_rate = GetDeviceAddress<T>(inputs, 3);
-      T *gradients = GetDeviceAddress<T>(inputs, 4);
-
-      RmsProp(learning_rate, decay_, momentum_, epsilon_, variable, mean_square, moment, gradients, size_,
-              reinterpret_cast<cudaStream_t>(stream));
-    } else {
-      T *variable = GetDeviceAddress<T>(inputs, 0);
-      T *mean_gradients = GetDeviceAddress<T>(inputs, 1);
-      T *mean_square = GetDeviceAddress<T>(inputs, 2);
-      T *moment = GetDeviceAddress<T>(inputs, 3);
-      T *gradients = GetDeviceAddress<T>(inputs, 4);
-      T *learning_rate = GetDeviceAddress<T>(inputs, 5);
-      T *decay = GetDeviceAddress<T>(inputs, 6);
-      T *momentum = GetDeviceAddress<T>(inputs, 7);
-      T *epsilon = GetDeviceAddress<T>(inputs, 8);
-
-      RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients,
-                    size_, reinterpret_cast<cudaStream_t>(stream));
-    }
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
-    if (node_name == "ApplyCenteredRMSProp") {
-      use_center_ = true;
-    }
-
-    if (node_name == "ApplyRMSProp") {
-      decay_ = GetAttr<float>(kernel_node, "rho");
-      momentum_ = GetAttr<float>(kernel_node, "momentum");
-      epsilon_ = GetAttr<float>(kernel_node, "epsilon");
-    }
-    auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    for (auto &dim : input_shape) {
-      size_ *= dim;
-    }
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitSizeLists() override {
-    size_t input_size = size_ * sizeof(T);
-    if (!use_center_) {
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(sizeof(T));
-      input_size_list_.push_back(input_size);
-      output_size_list_.push_back(input_size);
-    } else {
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(input_size);
-      input_size_list_.push_back(sizeof(T));
-      input_size_list_.push_back(sizeof(T));
-      input_size_list_.push_back(sizeof(T));
-      input_size_list_.push_back(sizeof(T));
-      output_size_list_.push_back(input_size);
-    }
-  }
-
- private:
-  size_t size_;
-  bool use_center_;
-  float decay_;
-  float momentum_;
-  float epsilon_;
-
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif
diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc
deleted file mode 100644
index 1e650811fd..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.cc
+++ /dev/null
@@ -1,26 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
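For reference, the element-wise updates the two RMSProp registrations dispatch to, written out on the host for one element. This is the conventional (Centered)RMSProp rule; the device code in rmsprop_impl.cuh should compute the same quantities, though the exact operation ordering there is not visible in this diff.

#include <cmath>

void RmsPropStep(float lr, float rho, float momentum, float eps,
                 float *var, float *ms, float *mom, float grad) {
  *ms = rho * *ms + (1.0f - rho) * grad * grad;        // running mean of squared gradients
  *mom = momentum * *mom + lr * grad / std::sqrt(*ms + eps);
  *var -= *mom;
}

void CenteredRmsPropStep(float lr, float rho, float momentum, float eps,
                         float *var, float *mg, float *ms, float *mom, float grad) {
  *mg = rho * *mg + (1.0f - rho) * grad;               // running mean of gradients
  *ms = rho * *ms + (1.0f - rho) * grad * grad;
  // centering subtracts the squared mean, estimating the gradient variance
  *mom = momentum * *mom + lr * grad / std::sqrt(*ms - *mg * *mg + eps);
  *var -= *mom;
}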
- */ - -#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - SigmoidCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SigmoidCrossEntropyWithLogitsGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 8d0efe90b4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SigmoidCrossEntropyWithLogitsGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} - - ~SigmoidCrossEntropyWithLogitsGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *outputs_addr = GetDeviceAddress(outputs, 0); - - SigmoidCrossEntropyWithLogits(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs."; - return false; - } - logits_size_ = sizeof(T); - labels_size_ = sizeof(S); - outputs_size_ = sizeof(T); - - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < logits_shape.size(); i++) { - logits_size_ *= logits_shape[i]; - } - - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < labels_shape.size(); i++) { - labels_size_ *= labels_shape[i]; - } - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < output_shape.size(); i++) { - outputs_size_ *= output_shape[i]; - } - - 
InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(outputs_size_); - } - - private: - size_t logits_size_; - size_t labels_size_; - size_t outputs_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc deleted file mode 100644 index dabc4df850..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h deleted file mode 100644 index 01f416f6b7..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { - public: - SigmoidCrossEntropyWithLogitsGradGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} - ~SigmoidCrossEntropyWithLogitsGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *outputs_addr = GetDeviceAddress(outputs, 0); - - SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs."; - return false; - } - logits_size_ = sizeof(T); - labels_size_ = sizeof(S); - outputs_size_ = sizeof(T); - - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < logits_shape.size(); i++) { - logits_size_ *= logits_shape[i]; - } - - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - for (size_t i = 0; i < labels_shape.size(); i++) { - labels_size_ *= labels_shape[i]; - } - - auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < output_shape.size(); i++) { - outputs_size_ *= output_shape[i]; - } - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(outputs_size_); - } - - private: - size_t logits_size_; - size_t labels_size_; - size_t outputs_size_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc deleted file mode 100644 index dec1d23663..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
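The element-wise math behind the two sigmoid-cross-entropy kernels, in the standard numerically stable form. Whether the .cuh implementations use exactly this rearrangement is an assumption; the function computed is the same.

#include <cmath>

// loss(x, z) = max(x, 0) - x*z + log(1 + exp(-|x|)), stable for large |x|
float SigmoidCeLoss(float logit, float label) {
  return std::fmax(logit, 0.0f) - logit * label + std::log1p(std::exp(-std::fabs(logit)));
}

// d loss / d logit = sigmoid(logit) - label, chained through the incoming gradient
float SigmoidCeGrad(float logit, float label, float dloss) {
  float sig = 1.0f / (1.0f + std::exp(-logit));
  return (sig - label) * dloss;
}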
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - SmoothL1Loss, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SmoothL1LossGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h deleted file mode 100644 index 1317e7a6a0..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh" -namespace mindspore { -namespace kernel { -template -class SmoothL1LossGpuKernel : public GpuKernel { - public: - SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {} - ~SmoothL1LossGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *prediction = GetDeviceAddress(inputs, 0); - T *target = GetDeviceAddress(inputs, 1); - T *loss = GetDeviceAddress(outputs, 0); - - SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - - sigma_ = GetAttr(kernel_node, "sigma"); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_ * sizeof(T)); - input_size_list_.push_back(input_size_ * sizeof(T)); - output_size_list_.push_back(input_size_ * sizeof(T)); - } - - private: - size_t input_size_; - float sigma_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // 
MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc deleted file mode 100644 index c4acd1fb45..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(SmoothL1LossGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SmoothL1LossGradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h deleted file mode 100644 index 5319e0496c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
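Smooth-L1 in its usual Huber form: quadratic near zero, linear in the tails. How the kernels' "sigma" attribute maps onto the threshold is defined in smooth_l1_loss_impl.cuh and is not visible in this diff, so beta below is the generic knob rather than a claim about that mapping.

#include <cmath>

float SmoothL1(float pred, float target, float beta) {
  float d = pred - target;
  // quadratic inside the beta band, linear outside; continuous at |d| == beta
  return std::fabs(d) < beta ? 0.5f * d * d / beta : std::fabs(d) - 0.5f * beta;
}

float SmoothL1Grad(float pred, float target, float dloss, float beta) {
  float d = pred - target;
  float g = std::fabs(d) < beta ? d / beta : (d > 0.0f ? 1.0f : -1.0f);
  return g * dloss;  // chain rule through the incoming gradient
}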
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh" -namespace mindspore { -namespace kernel { -template -class SmoothL1LossGradGpuKernel : public GpuKernel { - public: - SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {} - ~SmoothL1LossGradGpuKernel() override = default; - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - T *prediction = GetDeviceAddress(inputs, 0); - T *target = GetDeviceAddress(inputs, 1); - T *dloss = GetDeviceAddress(inputs, 2); - T *dx = GetDeviceAddress(outputs, 0); - - SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - - sigma_ = GetAttr(kernel_node, "sigma"); - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_ * sizeof(T)); - input_size_list_.push_back(input_size_ * sizeof(T)); - output_size_list_.push_back(input_size_ * sizeof(T)); - } - - private: - size_t input_size_; - float sigma_; - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc deleted file mode 100644 index 160a26d200..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO(SoftmaxCrossEntropyWithLogits, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - SoftmaxCrossEntropyWithLogitsGpuKernel, float, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 8256174bcb..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,205 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/cross_entropy_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SoftmaxCrossEntropyWithLogitsGpuKernel() - : cudnn_handle_(nullptr), - logits_descriptor_(nullptr), - softmax_output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - logits_size_(0), - labels_size_(0), - output1_size_(0), - output2_size_(0), - softmax_output_logits_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *loss_addr = GetDeviceAddress(outputs, 0); - T *dlogits_addr = GetDeviceAddress(outputs, 1); - T *softmax_output_logits = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, - softmax_output_descriptor_, softmax_output_logits), - "cudnnSoftmaxForward failed."); - - CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, - reinterpret_cast(stream_ptr)); - return 
true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    InitResource();
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(ERROR) << "Input number is " << input_num
-                    << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs.";
-      return false;
-    }
-    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 2) {
-      MS_LOG(ERROR) << "Output number is " << output_num
-                    << ", but SoftmaxCrossEntropyWithLogitsGpuKernel needs 2 outputs.";
-      return false;
-    }
-    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-
-    InferInputOutputSize(kernel_node);
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           batch_size_, channel_size_, height_, width_),
-                                "cudnnSetTensor4dDescriptor failed.");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_,
-                                 channel_size_, height_, width_),
-      "cudnnSetTensor4dDescriptor failed.");
-    InitSizeLists();
-    return true;
-  }
-
- protected:
-  void InitResource() override {
-    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_),
-                                "cudnnCreateTensorDescriptor failed.");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_),
-                                "cudnnCreateTensorDescriptor failed.");
-  }
-  void InitSizeLists() override {
-    input_size_list_.push_back(logits_size_);
-    input_size_list_.push_back(labels_size_);
-    output_size_list_.push_back(output1_size_);
-    output_size_list_.push_back(output2_size_);
-    workspace_size_list_.push_back(softmax_output_logits_size_);
-  }
-
- private:
-  void DestroyResource() noexcept {
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed.");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed.");
-  }
-  void InferInputOutputSize(const CNodePtr &kernel_node) {
-    auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    is_null_input_ = CHECK_NULL_INPUT(logits_shape);
-    if (is_null_input_) {
-      MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input1 is null";
-      InitSizeLists();
-      return;
-    }
-    auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    is_null_input_ = CHECK_NULL_INPUT(labels_shape);
-    if (is_null_input_) {
-      MS_LOG(WARNING) << "SoftmaxCrossEntropyWithLogitsGpuKernel input2 is null";
-      InitSizeLists();
-      return;
-    }
-    CheckShapeValidation(logits_shape, labels_shape);
-
-    size_t logits_dims = logits_shape.size();
-    batch_size_ = 1;
-    for (size_t i = 0; i < logits_dims - 1; i++) {
-      batch_size_ *= logits_shape[i];
-    }
-    channel_size_ = logits_shape[logits_dims - 1];
-    height_ = 1;
-    width_ = 1;
-    logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_;
-
-    labels_size_ = 1;
-    size_t labels_dims = labels_shape.size();
-    for (size_t i = 0; i < labels_dims; i++) {
-      labels_size_ *= labels_shape[i];
-    }
-    labels_size_ *= sizeof(S);
-
-    output1_size_ = logits_size_ / logits_shape[logits_dims - 1];
-    output2_size_ = logits_size_;
-    softmax_output_logits_size_ = logits_size_;
-    return;
-  }
-  void CheckShapeValidation(const std::vector<size_t> &logits_shape, const std::vector<size_t> &labels_shape) {
-    size_t logits_dim_length = logits_shape.size();
-    size_t labels_dim_length = labels_shape.size();
-    if (labels_dim_length != logits_dim_length) {
-      MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length for "
-                           "SoftmaxCrossEntropyWithLogits, but got Labels "
-                           "shape length:"
-                        << labels_dim_length << ", Logits shape length:" << logits_dim_length;
-    }
-    if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) {
-      MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last dimension.";
-    }
-    return;
-  }
-
-  cudnnHandle_t cudnn_handle_;
-  cudnnTensorDescriptor_t logits_descriptor_;
-  cudnnTensorDescriptor_t softmax_output_descriptor_;
-  cudnnSoftmaxAlgorithm_t algo_;
-  cudnnSoftmaxMode_t mode_;
-  cudnnDataType_t cudnn_data_type_;
-  bool is_null_input_;
-
-  size_t logits_size_;
-  size_t labels_size_;
-  size_t output1_size_;
-  size_t output2_size_;
-  size_t softmax_output_logits_size_;
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-  size_t batch_size_;
-  size_t channel_size_;
-  size_t height_;
-  size_t width_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc
deleted file mode 100644
index b9667ed85b..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.cc
+++ /dev/null
@@ -1,30 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/nn/softmax_gpu_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      SoftmaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      SoftmaxGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      SoftmaxGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(LogSoftmax, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      SoftmaxGpuKernel, half)
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h
deleted file mode 100644
index 9d5a2a24e1..0000000000
--- a/mindspore/ccsrc/kernel/gpu/nn/softmax_gpu_kernel.h
+++ /dev/null
@@ -1,252 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
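The fused kernel above first runs cudnnSoftmaxForward into the workspace buffer, then a single CrossEntropy launch produces both registered outputs: a per-row loss and dlogits = softmax - labels. A host-side sketch of that second step for one row, assuming dense float labels; the epsilon guard against log(0) is illustrative, not a claim about cross_entropy_impl.cuh.

#include <cmath>

void CrossEntropyRow(const float *softmax, const float *labels, int channels,
                     float *loss, float *dlogits) {
  float l = 0.0f;
  for (int c = 0; c < channels; ++c) {
    l -= labels[c] * std::log(softmax[c] + 1e-20f);  // -sum z * log p
    dlogits[c] = softmax[c] - labels[c];             // gradient of the fused op
  }
  *loss = l;
}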
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SoftmaxGpuKernel : public GpuKernel { - public: - SoftmaxGpuKernel() - : cudnn_handle_(nullptr), - input_descriptor_(nullptr), - output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0), - axis_(0), - shape_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input_addr = GetDeviceAddress(inputs, 0); - T *output_addr = GetDeviceAddress(outputs, 0); - const float alpha = 1; - const float beta = 0; - - if (axis_ == 1) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, - input_addr, &beta, output_descriptor_, output_addr), - "cudnnSoftmaxForward failed"); - } else { - T *transpose_input_addr = GetDeviceAddress(workspace, 0); - T *transpose_output_addr = GetDeviceAddress(workspace, 1); - int *input_shape = GetDeviceAddress(workspace, 2); - int *transpose_shape = GetDeviceAddress(workspace, 3); - int *transpose_axis = GetDeviceAddress(workspace, 4); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, input_addr, input_shape, transpose_axis, shape_size_, transpose_input_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, input_descriptor_, transpose_input_addr, &beta, - output_descriptor_, transpose_output_addr), - "cudnnSoftmaxForward failed"); - CalTranspose(size, transpose_output_addr, transpose_shape, transpose_axis, shape_size_, output_addr, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool 
Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax needs 1 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxGpuKernel input is null"; - InitSizeLists(); - return true; - } - shape_size_ = SizeToInt(input_shape.size()); - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "LogSoftmax") { - algo_ = CUDNN_SOFTMAX_LOG; - auto axis = GetAttr(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis); - } else { - algo_ = CUDNN_SOFTMAX_ACCURATE; - auto axis = GetAttr>(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis[0]); - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set input_descriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set output_descriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "create input_descriptor failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "create output_descriptor failed"); - } - - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "destroy output_descriptor failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "destroy input_descriptor failed"); - } - - void InitSizeByAxis(const std::vector &input_shape, const int &axis) { - if (input_shape.size() == 2) { - InitSizeByAxis2D(input_shape, axis); - } else { - InitSizeByAxisLastDim(input_shape, axis); - } - } - - void InitSizeByAxis2D(const std::vector &input_shape, const int &axis) { - axis_ = axis; - if (axis_ < 0) { - axis_ += shape_size_; - } - if (axis_ == 1) { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - } else if (axis_ == 0) { - batch_size_ = input_shape[1]; - channel_size_ = input_shape[0]; - input_shape_.push_back(input_shape[0]); - input_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[0]); - transpose_axis_.push_back(1); - transpose_axis_.push_back(0); - } else { - MS_LOG(EXCEPTION) << "Input is " << 
shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - workspace_size_ = IntToSize(shape_size_) * sizeof(int); - } - - void InitSizeByAxisLastDim(const std::vector &input_shape, const int &axis) { - int axis_pos = axis; - if (axis_pos < 0) { - axis_pos += input_shape.size(); - } - // axis should be -1 with ND - if (axis_pos != SizeToInt(input_shape.size() - 1)) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - // squeeze to 2d, then invoke cudnn - size_t n = 1; - for (size_t i = 0; i < input_shape.size() - 1; i++) { - n *= input_shape[i]; - } - axis_ = 1; - batch_size_ = n; - channel_size_ = input_shape[axis_pos]; - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - input_shape_.push_back(batch_size_); - input_shape_.push_back(channel_size_); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t input_descriptor_; - cudnnTensorDescriptor_t output_descriptor_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - std::vector input_shape_; - std::vector transpose_shape_; - std::vector transpose_axis_; - int axis_; - int shape_size_; - - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc deleted file mode 100644 index 5b07136522..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
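// The Launch path of SoftmaxGpuKernel above supports softmax over an axis other than the
// cuDNN channel axis by transposing the input so the reduced axis becomes axis 1, calling
// cudnnSoftmaxForward, and transposing the result back. A CPU sketch of that trick for
// the 2-D, axis-0 case (row-major float data assumed; an illustration, not the device code):
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

void SoftmaxRows(std::vector<float> *m, size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    float *row = m->data() + r * cols;
    float mx = *std::max_element(row, row + cols);  // subtract the max for stability
    float sum = 0.f;
    for (size_t c = 0; c < cols; ++c) sum += (row[c] = std::exp(row[c] - mx));
    for (size_t c = 0; c < cols; ++c) row[c] /= sum;
  }
}

std::vector<float> SoftmaxAxis0(const std::vector<float> &m, size_t rows, size_t cols) {
  std::vector<float> t(m.size());
  for (size_t r = 0; r < rows; ++r)  // transpose so axis 0 becomes the last axis
    for (size_t c = 0; c < cols; ++c) t[c * rows + r] = m[r * cols + c];
  SoftmaxRows(&t, cols, rows);       // reduce along what is now the last axis
  std::vector<float> out(m.size());
  for (size_t c = 0; c < cols; ++c)  // transpose back to the original layout
    for (size_t r = 0; r < rows; ++r) out[r * cols + c] = t[c * rows + r];
  return out;
}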
- */ - -#include "kernel/gpu/nn/softmax_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - LogSoftmaxGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - SoftmaxGradGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - LogSoftmaxGrad, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - SoftmaxGradGpuKernel, half) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h deleted file mode 100644 index d73503d5a5..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_grad_gpu_kernel.h +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/transpose_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class SoftmaxGradGpuKernel : public GpuKernel { - public: - SoftmaxGradGpuKernel() - : cudnn_handle_(nullptr), - y_desc_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_null_input_(false), - input_size_(0), - output_size_(0), - workspace_size_(0), - axis_(0), - shape_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SoftmaxGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *y_addr = GetDeviceAddress(inputs, 0); - T *dy_addr = GetDeviceAddress(inputs, 1); - T *dx_addr = GetDeviceAddress(outputs, 0); - - T *transpose_y_addr = GetDeviceAddress(workspace, 0); - T *transpose_dy_addr = GetDeviceAddress(workspace, 1); - T *transpose_dx_addr = GetDeviceAddress(workspace, 2); - int *input_shape = GetDeviceAddress(workspace, 3); - int *transpose_shape = GetDeviceAddress(workspace, 4); - int *transpose_axis = GetDeviceAddress(workspace, 5); - const float alpha = 1; - const float beta = 0; - - if (axis_ == 1) { - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, y_addr, y_desc_, - dy_addr, &beta, y_desc_, dx_addr), - "cudnnSoftmaxBackward failed"); - } else { - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, 
cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_shape, &transpose_shape_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(transpose_axis, &transpose_axis_[0], workspace_size_, - cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - int size = SizeToInt(input_size_ / sizeof(T)); - CalTranspose(size, y_addr, input_shape, transpose_axis, shape_size_, transpose_y_addr, - reinterpret_cast(stream_ptr)); - CalTranspose(size, dy_addr, input_shape, transpose_axis, shape_size_, transpose_dy_addr, - reinterpret_cast(stream_ptr)); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSoftmaxBackward(cudnn_handle_, algo_, mode_, &alpha, y_desc_, transpose_y_addr, - y_desc_, transpose_dy_addr, &beta, y_desc_, transpose_dx_addr), - "cudnnSoftmaxBackward failed"); - CalTranspose(size, transpose_dx_addr, transpose_shape, transpose_axis, shape_size_, dx_addr, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but softmax grad needs 2 input."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but softmax grad needs 1 output."; - return false; - } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SoftmaxGradGpuKernel input is null"; - InitSizeLists(); - return true; - } - shape_size_ = SizeToInt(input_shape.size()); - if (shape_size_ != 2) { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but softmax grad only supports 2-D inputs."; - } - auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "LogSoftmaxGrad") { - algo_ = CUDNN_SOFTMAX_LOG; - auto axis = GetAttr(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis); - } else { - algo_ = CUDNN_SOFTMAX_ACCURATE; - auto axis = GetAttr>(kernel_node, "axis"); - InitSizeByAxis(input_shape, axis[0]); - } - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(batch_size_), - SizeToInt(channel_size_), SizeToInt(height_), SizeToInt(width_)), - "set input_descriptor failed"); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "create input_descriptor failed"); - } - - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(input_size_); - workspace_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), 
"destroy output_descriptor failed"); - } - - void InitSizeByAxis(const std::vector input_shape, const int axis) { - axis_ = axis; - if (axis_ < 0) { - axis_ += shape_size_; - } - if (axis_ == 1) { - batch_size_ = input_shape[0]; - channel_size_ = input_shape[1]; - } else if (axis_ == 0) { - batch_size_ = input_shape[1]; - channel_size_ = input_shape[0]; - input_shape_.push_back(input_shape[0]); - input_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[1]); - transpose_shape_.push_back(input_shape[0]); - transpose_axis_.push_back(1); - transpose_axis_.push_back(0); - } else { - MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but axis(" << axis << ") is invalid."; - } - - height_ = 1; - width_ = 1; - input_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - output_size_ = input_size_; - workspace_size_ = IntToSize(shape_size_) * sizeof(int); - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t y_desc_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_null_input_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - std::vector input_shape_; - std::vector transpose_shape_; - std::vector transpose_axis_; - int axis_; - int shape_size_; - - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SOFTMAX_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc deleted file mode 100644 index 537eeb5726..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.cc +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_TWO( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int) -MS_REG_GPU_KERNEL_TWO( - SparseSoftmaxCrossEntropyWithLogits, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - SparseSoftmaxCrossEntropyWithLogitsGpuKernel, float, int64_t) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h deleted file mode 100644 index 6950f0e308..0000000000 --- a/mindspore/ccsrc/kernel/gpu/nn/sparse_softmax_cross_entropy_with_logits_gpu_kernel.h +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ - -#include -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/cross_entropy_impl.cuh" -#include "kernel/gpu/kernel_constants.h" - -namespace mindspore { -namespace kernel { -template -class SparseSoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { - public: - SparseSoftmaxCrossEntropyWithLogitsGpuKernel() - : cudnn_handle_(nullptr), - logits_descriptor_(nullptr), - softmax_output_descriptor_(nullptr), - algo_(CUDNN_SOFTMAX_ACCURATE), - mode_(CUDNN_SOFTMAX_MODE_INSTANCE), - cudnn_data_type_(CUDNN_DATA_FLOAT), - is_grad_(false), - is_null_input_(false), - logits_size_(0), - labels_size_(0), - output_size_(0), - softmax_output_logits_size_(0), - batch_size_(0), - channel_size_(0), - height_(0), - width_(0) {} - ~SparseSoftmaxCrossEntropyWithLogitsGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *logits_addr = GetDeviceAddress(inputs, 0); - S *labels_addr = GetDeviceAddress(inputs, 1); - T *output_addr = GetDeviceAddress(outputs, 0); - T *softmax_output_logits = GetDeviceAddress(workspace, 0); - - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSoftmaxForward(cudnn_handle_, algo_, mode_, &alpha, logits_descriptor_, logits_addr, &beta, - 
softmax_output_descriptor_, softmax_output_logits), - "cudnnSoftmaxForward failed."); - - is_grad_ ? CrossEntropyGradWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, - reinterpret_cast<cudaStream_t>(stream_ptr)) - : CrossEntropyWithSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output_addr, - reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num - << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 2 inputs."; - return false; - } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num - << ", but SparseSoftmaxCrossEntropyWithLogitsGpuKernel needs 1 output."; - return false; - } - is_grad_ = GetAttr<bool>(kernel_node, "is_grad"); - cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - - InferInputOutputSize(kernel_node); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(logits_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, - batch_size_, channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(softmax_output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_size_, - channel_size_, height_, width_), - "cudnnSetTensor4dDescriptor failed."); - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { - cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&logits_descriptor_), - "cudnnCreateTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&softmax_output_descriptor_), - "cudnnCreateTensorDescriptor failed."); - } - void InitSizeLists() override { - input_size_list_.push_back(logits_size_); - input_size_list_.push_back(labels_size_); - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(softmax_output_logits_size_); - return; - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(softmax_output_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(logits_descriptor_), - "cudnnDestroyTensorDescriptor failed."); - } - void InferInputOutputSize(const CNodePtr &kernel_node) { - auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(logits_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SparseSoftmaxCrossEntropyWithLogitsGpuKernel input1 is null"; - InitSizeLists(); - return; - } - auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - is_null_input_ = CHECK_NULL_INPUT(labels_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "SparseSoftmaxCrossEntropyWithLogitsGpuKernel input2 is null"; - InitSizeLists(); - return; - } - CheckShapeValidation(logits_shape, labels_shape); - - size_t logits_dims = logits_shape.size(); - batch_size_ = 1; - for (size_t i = 0; i < logits_dims - 1; i++) { - batch_size_ *= logits_shape[i]; - } - channel_size_ = logits_shape[logits_dims - 1]; - height_ = 1; - width_ = 1; - logits_size_ = sizeof(T) * batch_size_ * channel_size_ * height_ * width_; - - labels_size_ = 1; - size_t labels_dims = labels_shape.size(); - for (size_t i = 0; i < labels_dims;
i++) { - labels_size_ *= labels_shape[i]; - } - labels_size_ *= sizeof(S); - - output_size_ = is_grad_ ? logits_size_ : sizeof(T); - softmax_output_logits_size_ = logits_size_; - return; - } - void CheckShapeValidation(const std::vector<size_t> &logits_shape, const std::vector<size_t> &labels_shape) { - size_t logits_dim_length = logits_shape.size(); - size_t labels_dim_length = labels_shape.size(); - if (labels_dim_length != logits_dim_length - 1) { - MS_LOG(EXCEPTION) << "Labels shape length should be equal to Logits shape length minus 1 for " - "SparseSoftmaxCrossEntropyWithLogits, " - "but got Labels shape length:" - << labels_dim_length << ", Logits shape length:" << logits_dim_length; - } - if (!std::equal(labels_shape.begin(), labels_shape.end(), logits_shape.begin())) { - MS_LOG(EXCEPTION) << "The shape of labels should be the same as the shape of logits except its last dimension."; - } - return; - } - - cudnnHandle_t cudnn_handle_; - cudnnTensorDescriptor_t logits_descriptor_; - cudnnTensorDescriptor_t softmax_output_descriptor_; - cudnnSoftmaxAlgorithm_t algo_; - cudnnSoftmaxMode_t mode_; - cudnnDataType_t cudnn_data_type_; - bool is_grad_; - bool is_null_input_; - - size_t logits_size_; - size_t labels_size_; - size_t output_size_; - size_t softmax_output_logits_size_; - std::vector<size_t> input_size_list_; - std::vector<size_t> output_size_list_; - std::vector<size_t> workspace_size_list_; - size_t batch_size_; - size_t channel_size_; - size_t height_; - size_t width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc deleted file mode 100644 index 0f3e0c95f4..0000000000 --- a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
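// For the sparse kernel above, when is_grad is false the op reduces to a single scalar
// (output_size_ is sizeof(T)): -log(p[i, label_i]) reduced over the batch, where p is the
// softmax produced by cudnnSoftmaxForward. A CPU sketch of that reduction, assuming a
// mean reduction (an illustration of the formula, not the CrossEntropyWithSparse CUDA code):
#include <cmath>
#include <cstddef>
#include <vector>

float SparseCrossEntropy(const std::vector<float> &softmax_probs,  // batch x classes
                         const std::vector<int> &labels, size_t batch, size_t classes) {
  float loss = 0.f;
  for (size_t i = 0; i < batch; ++i)
    loss += -std::log(softmax_probs[i * classes + static_cast<size_t>(labels[i])]);
  return loss / static_cast<float>(batch);
}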
- */ - -#include "kernel/gpu/other/assign_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE( - Assign, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - AssignGpuKernel, float) -MS_REG_GPU_KERNEL_ONE( - Assign, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - AssignGpuKernel, half) -MS_REG_GPU_KERNEL_ONE( - Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - AssignGpuKernel, int) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h deleted file mode 100644 index b41d583a43..0000000000 --- a/mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H - -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -template <typename T> -class AssignGpuKernel : public GpuKernel { - public: - AssignGpuKernel() : input_size_(0) {} - ~AssignGpuKernel() override = default; - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - T *var = GetDeviceAddress<T>(inputs, 0); - T *value = GetDeviceAddress<T>(inputs, 1); - T *output = GetDeviceAddress<T>(outputs, 0); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(var, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), - "cudaMemcpyAsync failed."); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(output, value, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), - "cudaMemcpyAsync failed."); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - if (!CheckParam(kernel_node)) { - return false; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_size_ = sizeof(T); - for (size_t x : shape) { - input_size_ = input_size_ * x; - } - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - input_size_list_.push_back(input_size_); - output_size_list_.push_back(input_size_); - } - - private: - bool CheckParam(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but AssignGpuKernel needs 2 inputs."; - return false;
- } - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but AssignGpuKernel needs 1 output."; - return false; - } - return true; - } - - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - size_t input_size_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc deleted file mode 100644 index af95767407..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.cc +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold2, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFold2GpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h deleted file mode 100644 index b898f34689..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
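// The Assign kernel above is just two stream-ordered device-to-device copies: the new
// value overwrites the variable buffer in place, and the same bytes are also written to
// the op's output. A standalone CUDA sketch of that pattern (error handling elided):
#include <cuda_runtime.h>

void AssignOnStream(void *var, void *output, const void *value, size_t bytes,
                    cudaStream_t stream) {
  // Both copies are enqueued on the same stream, so they execute in order and any
  // caller that synchronizes the stream observes the updated variable and output.
  cudaMemcpyAsync(var, value, bytes, cudaMemcpyDeviceToDevice, stream);
  cudaMemcpyAsync(output, value, bytes, cudaMemcpyDeviceToDevice, stream);
}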
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFold2GpuKernel : public GpuKernel { - public: - BatchNormFold2GpuKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - batch_size_(0), - channel_(0), - height_(0), - width_(0), - freeze_bn_(0) {} - - ~BatchNormFold2GpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - auto *input = GetDeviceAddress(inputs, 0); - auto *beta = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *batch_std = GetDeviceAddress(inputs, 3); - auto *batch_mean = GetDeviceAddress(inputs, 4); - auto *running_std = GetDeviceAddress(inputs, 5); - auto *running_mean = GetDeviceAddress(inputs, 6); - auto *global_step = GetDeviceAddress(inputs, 7); - auto *output = GetDeviceAddress(outputs, 0); - - BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output, - freeze_bn_, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null"; - InitSizeLists(); - return true; - } - - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = channel_ * sizeof(T); - input_size_list_.push_back(input_size); - input_size_list_.push_back(weight_size); // beta - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // batch_std - input_size_list_.push_back(weight_size); // batch_mean - input_size_list_.push_back(weight_size); // running_std - input_size_list_.push_back(weight_size); // running_mean - input_size_list_.push_back(sizeof(int32_t)); // global_step - output_size_list_.push_back(input_size); - } - - private: - void DestroyResource() noexcept {} - - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t batch_size_; - size_t channel_; - size_t 
height_; - size_t width_; - size_t freeze_bn_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc deleted file mode 100644 index 93862aeedd..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFold2GradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h deleted file mode 100644 index e0bafdb96a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
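// Every kernel in this change reports its buffer needs from Init() through three size
// lists, which the runtime uses to allocate device memory before Launch() runs. The
// arithmetic for the BatchNormFold2 kernel above, restated as a standalone sketch
// (struct and function names are local to this illustration):
#include <cstddef>
#include <cstdint>
#include <vector>

struct KernelSizes {
  std::vector<size_t> inputs, outputs, workspaces;
};

KernelSizes BatchNormFold2Sizes(size_t n, size_t c, size_t h, size_t w, size_t elem) {
  KernelSizes s;
  const size_t input_size = n * c * h * w * elem;  // x, and the single output
  const size_t weight_size = c * elem;             // each per-channel parameter tensor
  s.inputs = {input_size,  weight_size, weight_size, weight_size,       // x, beta, gamma, batch_std
              weight_size, weight_size, weight_size, sizeof(int32_t)};  // batch_mean, running_std, running_mean, global_step
  s.outputs = {input_size};
  return s;
}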
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFold2GradGpuKernel : public GpuKernel { - public: - BatchNormFold2GradGpuKernel() - : cudnn_handle_(nullptr), - is_null_input_(false), - batch_size_(0), - channel_(0), - height_(0), - width_(0), - freeze_bn_(0) {} - - ~BatchNormFold2GradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - - auto *dout = GetDeviceAddress(inputs, 0); - auto *x = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *batch_std = GetDeviceAddress(inputs, 3); - auto *batch_mean = GetDeviceAddress(inputs, 4); - auto *running_std = GetDeviceAddress(inputs, 5); - auto *running_mean = GetDeviceAddress(inputs, 6); - auto *global_step = GetDeviceAddress(inputs, 7); - auto *d_batch_std = GetDeviceAddress(outputs, 0); - auto *d_batch_mean = GetDeviceAddress(outputs, 1); - auto *d_beta = GetDeviceAddress(outputs, 2); - auto *d_gamma = GetDeviceAddress(outputs, 3); - auto *d_x = GetDeviceAddress(outputs, 4); - auto *tmp = GetDeviceAddress(workspace, 0); - auto *tmp2 = GetDeviceAddress(workspace, 1); - auto *reduce_x = GetDeviceAddress(workspace, 2); - auto *tmp_x = GetDeviceAddress(workspace, 3); - - int32_t current_step_host[1]; - size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR( - cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - - BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - if (current_step_host[0] < freeze_bn_) { - CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_, - reinterpret_cast(stream_ptr)); - CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, - d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); - } else { - CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, - d_batch_mean, d_batch_std, channel_, reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 8) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shape); - if (is_null_input_) { - MS_LOG(WARNING) << 
"BatchNormFold2GradGpuKernel input is null"; - InitSizeLists(); - return true; - } - - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - InitSizeLists(); - return true; - } - - protected: - void InitResource() override { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); } - - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = channel_ * sizeof(T); - size_t workspace_size = batch_size_ * channel_ * sizeof(T); - input_size_list_.push_back(input_size); // dout - input_size_list_.push_back(input_size); // x - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // batch_std - input_size_list_.push_back(weight_size); // batch_mean - input_size_list_.push_back(weight_size); // running_std - input_size_list_.push_back(weight_size); // running_mean - input_size_list_.push_back(sizeof(int32_t)); // global_step - - output_size_list_.push_back(weight_size); // d_batch_std - output_size_list_.push_back(weight_size); // d_batch_mean - output_size_list_.push_back(weight_size); // d_beta - output_size_list_.push_back(weight_size); // d_gamma - output_size_list_.push_back(input_size); // d_x - - workspace_size_list_.push_back(workspace_size); // tmp - workspace_size_list_.push_back(workspace_size); // tmp2 - workspace_size_list_.push_back(weight_size); // reduce_x - workspace_size_list_.push_back(input_size); // tmp_x - } - - private: - void DestroyResource() noexcept {} - - cudnnHandle_t cudnn_handle_; - bool is_null_input_; - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - int32_t freeze_bn_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc deleted file mode 100644 index 4f968a0fa3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/quant/batchnorm_fold_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFold, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFoldGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h deleted file mode 100644 index 6cd001fd2e..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/kernel_constants.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class BatchNormFoldGpuKernel : public GpuKernel { - public: - BatchNormFoldGpuKernel() - : input_size_(0), - output_size_(0), - exp_avg_factor_(0.9), - epsilon_(1e-12), - is_training_(true), - freeze_bn_(0), - batch_(0), - channel_(0), - height_(0), - width_(0), - mode_(CUDNN_BATCHNORM_SPATIAL), - x_desc_(nullptr), - scale_bias_mean_var_desc_(nullptr), - handle_(nullptr) {} - - ~BatchNormFoldGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - (void)workspace; - auto x = GetDeviceAddress(inputs, 0); - auto mean = GetDeviceAddress(inputs, 1); - auto variance = GetDeviceAddress(inputs, 2); - int *current_step = GetDeviceAddress(inputs, 3); - int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, - reinterpret_cast(stream_ptr)), - "Copy gpu memoy failed."); - if (x == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null."; - return false; - } - if (mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null."; - return false; - } - if (variance == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null."; - return false; - } - if (current_step == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null."; - return false; - } - auto batch_mean = GetDeviceAddress(outputs, 0); - 
auto batch_std = GetDeviceAddress(outputs, 1); - auto running_mean = GetDeviceAddress(outputs, 2); - auto running_std = GetDeviceAddress(outputs, 3); - auto y = GetDeviceAddress(workspace, 0); - - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Failed to copy gpu memory."); - CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast(stream_ptr)); - if (!is_training_ || current_step_host[0] >= freeze_bn_) { - CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory."); - ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast(stream_ptr)); - return true; - } - const T alpha = 1; - const T beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardTraining( - handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y, scale_bias_mean_var_desc_, - mean, mean, exp_avg_factor_, mean, variance, epsilon_, batch_mean, batch_std), - "Failed to launch kernel.") - CalUpdateBatchStd(channel_, batch_std, reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(ERROR) << "Input number is " << input_num << " but BatchNormFold GpuKernel OP needs 4 input."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 4) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output."; - return false; - } - - T momentum = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum")); - exp_avg_factor_ = 1.0 - momentum; - epsilon_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); - is_training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); - freeze_bn_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "Input shape is " << input_shape.size() - << ", but BatchNormFold GpuKernel OP needs 4DTensor input."; - return false; - } - batch_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; - output_size_ = sizeof(T) * channel_; - - cudnnDataType_t cudnnDataType = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_), - "Set x desc failed"); - - CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1), - "Set para desc failed"); - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - // x, mean, variance, current_step - input_size_list_.push_back(input_size_); - input_size_list_.push_back(output_size_); - input_size_list_.push_back(output_size_); - input_size_list_.push_back(sizeof(int)); - - // batch_mean, batch_std, running_mean, running_std - output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - 
output_size_list_.push_back(output_size_); - output_size_list_.push_back(output_size_); - - // store y - workspace_size_list_.push_back(input_size_); - } - - void InitResource() override { - handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed"); - CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed"); - } - - private: - void DestroyResource() noexcept { - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed"); - CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed"); - } - - size_t input_size_; - size_t output_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - double exp_avg_factor_; - double epsilon_; - bool is_training_; - int freeze_bn_; - int batch_; - int channel_; - int height_; - int width_; - - cudnnBatchNormMode_t mode_; - cudnnTensorDescriptor_t x_desc_; - cudnnTensorDescriptor_t scale_bias_mean_var_desc_; - - cudnnHandle_t handle_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc deleted file mode 100644 index 93ea66258d..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeInt32) - .AddOutputAttr(kNumberTypeFloat32), - BatchNormFoldGradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h deleted file mode 100644 index 7a3ed7ef91..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h +++ /dev/null @@ -1,166 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ - -#include <vector> -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh" - -namespace mindspore { -namespace kernel { -template <typename T> -class BatchNormFoldGradGpuKernel : public GpuKernel { - public: - BatchNormFoldGradGpuKernel() - : input_size_(0), - channel_size_(0), - workspace_size_(0), - momentum_(0.1), - epsilon_(1e-12), - is_training_(true), - freeze_bn_(0), - current_step_(0), - batch_(0), - channel_(0), - height_(0), - width_(0) {} - ~BatchNormFoldGradGpuKernel() = default; - - const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } - const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } - const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, - const std::vector<AddressPtr> &outputs, void *stream_ptr) override { - // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' - T *d_batch_mean = GetDeviceAddress<T>(inputs, 0); - T *d_batch_std = GetDeviceAddress<T>(inputs, 1); - T *x = GetDeviceAddress<T>(inputs, 2); - T *batch_mean = GetDeviceAddress<T>(inputs, 3); - T *batch_std = GetDeviceAddress<T>(inputs, 4); - int *current_step = GetDeviceAddress<int>(inputs, 5); - int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, - reinterpret_cast<cudaStream_t>(stream_ptr)), - "Copy gpu memory failed."); - if (d_batch_mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null."; - return false; - } - if (d_batch_std == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null."; - return false; - } - if (x == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null."; - return false; - } - if (batch_mean == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null."; - return false; - } - if (batch_std == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null."; - return false; - } - if (current_step == nullptr) { - MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null."; - return false; - } - T *dx = GetDeviceAddress<T>(outputs, 0); - - if (!is_training_ || current_step_host[0] >= freeze_bn_) { - ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx, - reinterpret_cast<cudaStream_t>(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 6) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 inputs."; - return false; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 1 output."; - return false; - } - - epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon")); - is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training")); - freeze_bn_ =
GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn")); - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "Input shape is " << input_shape.size() - << ", but BatchNormFoldGrad GpuKernel OP needs 4DTensor input."; - return false; - } - batch_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_; - channel_size_ = sizeof(T) * channel_; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - // 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step' - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(input_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(channel_size_); - input_size_list_.push_back(sizeof(int)); - // 'dx' - output_size_list_.push_back(input_size_); - } - - private: - size_t input_size_; - size_t channel_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - T momentum_; - T epsilon_; - bool is_training_; - int freeze_bn_; - int current_step_; - int batch_; - int channel_; - int height_; - int width_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc deleted file mode 100644 index a914b6ec14..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020、 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/correction_mul_gpu_kernel.h" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CorrectionMul, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CorrectionMulGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h deleted file mode 100644 index 29aeabb03a..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_gpu_kernel.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class CorrectionMulGpuKernel : public GpuKernel { - public: - CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} - ~CorrectionMulGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - auto *weight = GetDeviceAddress(inputs, 0); - auto *gamma = GetDeviceAddress(inputs, 1); - auto *running_std = GetDeviceAddress(inputs, 2); - auto *output = GetDeviceAddress(outputs, 0); - - CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output, - reinterpret_cast(stream_ptr)); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = batch_size_ * sizeof(T); - input_size_list_.push_back(input_size); // weight - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // running_std - output_size_list_.push_back(input_size); - } - - void InitResource() override {} - - private: - void DestroyResource() noexcept {} - - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc deleted file mode 100644 index 28b5d56e68..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.cc +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this 
file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/gpu/quant/correction_mul_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - CorrectionMulGradGpuKernel, float) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h deleted file mode 100644 index 3feffa586b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/correction_mul_grad_gpu_kernel.h +++ /dev/null @@ -1,105 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
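
The CorrectionMul kernels deleted above implement the batch-norm-fold weight correction: each weight element is scaled by gamma / running_std of its output channel, which is why InitSizeLists sizes gamma and running_std by the first weight dimension only. A minimal CPU sketch of the forward computation, assuming gamma and running_std are indexed by that first dimension (cpu_correction_mul is an illustrative name, not part of the deleted files):

#include <cstddef>

// CPU reference for the correction multiply (sketch, not the CUDA kernel):
// out[n][c][h][w] = weight[n][c][h][w] * gamma[n] / running_std[n],
// where n indexes the first weight dimension, matching the per-channel
// sizing used in InitSizeLists above.
void cpu_correction_mul(const float *weight, const float *gamma, const float *running_std,
                        size_t batch, size_t channel, size_t height, size_t width, float *out) {
  size_t stride = channel * height * width;
  for (size_t n = 0; n < batch; ++n) {
    for (size_t i = 0; i < stride; ++i) {
      out[n * stride + i] = weight[n * stride + i] * gamma[n] / running_std[n];
    }
  }
}

The grad kernel registered just above follows the same indexing, producing d_weight with the forward scale and reducing d_out * weight into a per-channel d_gamma via the tmp workspace.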
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" -#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh" - -namespace mindspore { -namespace kernel { -template -class CorrectionMulGradGpuKernel : public GpuKernel { - public: - CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {} - ~CorrectionMulGradGpuKernel() override { DestroyResource(); } - - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override { - auto *d_out = GetDeviceAddress(inputs, 0); - auto *weight = GetDeviceAddress(inputs, 1); - auto *gamma = GetDeviceAddress(inputs, 2); - auto *running_std = GetDeviceAddress(inputs, 3); - auto *d_weight = GetDeviceAddress(outputs, 0); - auto *d_gamma = GetDeviceAddress(outputs, 1); - auto *tmp = GetDeviceAddress(workspace, 0); - - CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight, - reinterpret_cast(stream_ptr)); - CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp, - reinterpret_cast(stream_ptr)); - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - InitResource(); - - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4."; - return false; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - if (input_shape.size() != 4) { - MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W)."; - return false; - } - batch_size_ = input_shape[0]; - channel_ = input_shape[1]; - height_ = input_shape[2]; - width_ = input_shape[3]; - - InitSizeLists(); - return true; - } - - protected: - void InitSizeLists() override { - size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - size_t weight_size = batch_size_ * sizeof(T); - input_size_list_.push_back(input_size); // d_out - input_size_list_.push_back(input_size); // weight - input_size_list_.push_back(weight_size); // gamma - input_size_list_.push_back(weight_size); // running_std - output_size_list_.push_back(input_size); // d_weight - output_size_list_.push_back(weight_size); // d_gamma - workspace_size_list_.push_back(input_size); // tmp d_out * weight - } - void InitResource() override {} - - private: - void DestroyResource() noexcept {} - - size_t batch_size_; - size_t channel_; - size_t height_; - size_t width_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc deleted file mode 100644 index 8db6ddd848..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * 
Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
-#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
-#include <thrust/extrema.h>
-#include <thrust/pair.h>
-#include <thrust/device_vector.h>
-#include <cuda_runtime_api.h>
-
-namespace mindspore {
-namespace kernel {
-FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
-    : input_size_(0),
-      num_channels_(0),
-      num_bits_(0),
-      training_(false),
-      symmetric_(false),
-      narrow_range_(false),
-      quant_delay_(0),
-      quant_min_(0),
-      quant_max_(0),
-      global_step_(0) {}
-
-const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
-
-const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
-
-const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
-
-bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num != 3) {
-    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input.";
-    return false;
-  }
-
-  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  if (output_num != 1) {
-    MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
-    return false;
-  }
-
-  // get attributes
-  num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
-  training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
-  symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
-  narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
-  quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
-
-  if (num_bits_ <= 2 || num_bits_ >= 16) {
-    MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
-    return false;
-  }
-
-  if (quant_delay_ < 0) {
-    MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << quant_delay_ << " is less than 0, but it must be at least 0.";
-    return false;
-  }
-
-  // quant min and max value
-  quant_min_ = 0;
-  quant_max_ = (1 << num_bits_) - 1;
-  if (narrow_range_) {
-    quant_min_++;
-  }
-
-  // shape info for gpu
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  num_channels_ = SizeToInt(input_shape[0]);
-  input_size_ = sizeof(float);
-  for (size_t i = 0; i < input_shape.size(); i++) {
-    input_size_ *= input_shape[i];
-  }
-  InitSizeLists();
-  return true;
-}
-
-void FakeQuantPerChannelGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);                        // input tensor
-  input_size_list_.push_back(sizeof(float) * num_channels_);      // min, one value per channel
-  input_size_list_.push_back(sizeof(float) * num_channels_);      // max, one value per channel
-  output_size_list_.push_back(input_size_);                       // output tensor
-  workspace_size_list_.push_back(sizeof(float) * num_channels_);  // scale, per channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, - float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) { - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - symmetric_, reinterpret_cast(stream_ptr)); - CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); -} - -bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - (void)workspace; - float *output = GetDeviceAddress(outputs, 0); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null."; - } - - if (training_) { - if (global_step_ >= quant_delay_) { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - } else { - CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h deleted file mode 100755 index 122fe96af3..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
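
The quantization window computed in Init above is easy to check by hand: with num_bits = 8 the integer range is [0, 255], and narrow_range shifts it to [1, 255]. CalNudgePerChannel then snaps the float [min, max] window so that zero is exactly representable before CalFakeQuantPerChannel quantizes against it. The sketch below shows the usual TF-style nudging for one channel; it is an assumed restatement of the CUDA implementation (and omits the symmetric_ special case), not a transcription of it:

#include <algorithm>
#include <cmath>

// Assumed semantics of the per-channel nudge: pick a scale, round the
// zero point onto the integer grid, then rebuild the float window from it.
void nudge_min_max(float min, float max, float quant_min, float quant_max,
                   float *nudge_min, float *nudge_max, float *scale) {
  *scale = (max - min) / (quant_max - quant_min);
  float zp_from_min = quant_min - min / *scale;                          // ideal zero point
  float nudged_zp = std::round(std::clamp(zp_from_min, quant_min, quant_max));
  *nudge_min = (quant_min - nudged_zp) * (*scale);
  *nudge_max = (quant_max - nudged_zp) * (*scale);
}

With num_bits = 8 and a raw window of [-0.7, 1.3], this yields scale = 2.0 / 255 and a zero point that lands exactly on an integer code, which is what keeps zero lossless under fake quantization.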
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerChannelGpuKernel : public GpuKernel { - public: - FakeQuantPerChannelGpuKernel(); - ~FakeQuantPerChannelGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min, - float *nudge_max, float *scale, void *stream_ptr); - - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_channels_; - int num_bits_; - bool training_; - bool symmetric_; - bool narrow_range_; - int quant_delay_; - float quant_min_; - float quant_max_; - int global_step_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc deleted file mode 100644 index 5c774c05ed..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
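
All of the quant kernels removed in this patch follow the same GpuKernel contract visible in the class declaration above: Init parses attributes and inferred shapes from the CNode, InitSizeLists reports the byte sizes the allocator must stage for inputs, outputs, and workspace, and Launch consumes the resolved device addresses. A bare-bones illustration of that contract (a standalone sketch; the real base class lives in kernel/gpu/gpu_kernel.h and is not reproduced here):

#include <cstddef>
#include <vector>

// Illustrative skeleton of the Init / InitSizeLists / Launch lifecycle.
class ExampleGpuKernel {
 public:
  bool Init(const std::vector<size_t> &shape) {   // called once at graph compile time
    bytes_ = sizeof(float);
    for (size_t d : shape) bytes_ *= d;
    InitSizeLists();
    return true;
  }
  const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
  bool Launch(void * /*stream*/) { return true; }  // enqueue CUDA work per step

 private:
  void InitSizeLists() { input_size_list_.push_back(bytes_); }
  size_t bytes_ = 0;
  std::vector<size_t> input_size_list_;
};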
- */ - -#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh" - -namespace mindspore { -namespace kernel { -FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel() - : input_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - num_channels_(0), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const { - return workspace_size_list_; -} - -bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; - } - - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - num_channels_ = SizeToInt(input_shape[0]); - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float) * num_channels_); // min - input_size_list_.push_back(sizeof(float) * num_channels_); // max - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel - workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel -} - -bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - (void)workspace; - float *output = GetDeviceAddress(outputs, 0); - float *gradient = GetDeviceAddress(inputs, 0); - float *input = GetDeviceAddress(inputs, 1); - float *input_min = GetDeviceAddress(inputs, 2); - float *input_max = GetDeviceAddress(inputs, 3); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (gradient == nullptr) { - 
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; - } - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null"; - } - if (input_min == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null"; - } - if (input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null"; - } - - int total_size = input_size_ / sizeof(float); - if (global_step_ >= quant_delay_) { - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - symmetric_, reinterpret_cast(stream_ptr)); - CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed."); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h deleted file mode 100644 index d863a2c99f..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
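
The grad kernel above implements the usual straight-through estimator: after the window is nudged, the incoming gradient is passed through wherever the input falls inside [nudge_min, nudge_max] and zeroed elsewhere. A CPU sketch of that rule (assumed semantics of CalFakeQuantPerChannelGrad, shown here for a flat buffer with a single window):

#include <cstddef>

// Straight-through-estimator gradient (sketch): pass the gradient where the
// input lies inside the nudged window, zero it outside.
void fake_quant_grad(const float *input, const float *gradient, float *output,
                     size_t n, float nudge_min, float nudge_max) {
  for (size_t i = 0; i < n; ++i) {
    output[i] = (input[i] >= nudge_min && input[i] <= nudge_max) ? gradient[i] : 0.0f;
  }
}

The per-channel variant applies the same rule with a per-channel window, which is why the workspace carries num_channels_ copies of nudge_min, nudge_max, and scale.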
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerChannelGradGpuKernel : public GpuKernel { - public: - FakeQuantPerChannelGradGpuKernel(); - ~FakeQuantPerChannelGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_bits_; - float quant_min_; - float quant_max_; - int num_channels_; - int quant_delay_; - int global_step_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc deleted file mode 100644 index 44869983eb..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel() - : input_size_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - global_step_(0), - num_bits_(0), - quant_delay_(0), - training_(false), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0."; - } - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerLayerGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // x - input_size_list_.push_back(sizeof(float)); // min - input_size_list_.push_back(sizeof(float)); // max - output_size_list_.push_back(input_size_); // y - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *output = GetDeviceAddress(outputs, 0); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null."; - } - if (input_min == nullptr || 
input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null."; - } - - if (training_) { - // control flow for quant_delay - if (global_step_ >= quant_delay_) { - // real launch - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - } else { - // real launch - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - } - - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h deleted file mode 100755 index 38810e06df..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
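
The per-layer kernel above shows the same quant_delay control flow as the per-channel one: during training, the first quant_delay steps copy input to output with a device-to-device memcpy, and fake quantization only starts once global_step_ reaches the threshold; outside training it always quantizes. The gate reduces to a few lines (a host-side restatement of the branch in Launch, for illustration):

// Host-side restatement of the Launch control flow (sketch).
bool should_fake_quantize(bool training, int global_step, int quant_delay) {
  if (!training) return true;          // inference: always quantize
  return global_step >= quant_delay;   // training: wait out the delay window
}
// e.g. quant_delay = 1000: steps 0..999 pass activations through unchanged,
// and fake quantization begins at step 1000.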
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerLayerGpuKernel : public GpuKernel { - public: - FakeQuantPerLayerGpuKernel(); - ~FakeQuantPerLayerGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - float quant_min_; - float quant_max_; - int quant_num_; - int global_step_; - int num_bits_; - int quant_delay_; - bool training_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc deleted file mode 100644 index c8d57b2bb1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh" - -namespace mindspore { -namespace kernel { -FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel() - : input_size_(0), - workspace_size_(0), - num_bits_(0), - quant_min_(0), - quant_max_(0), - quant_num_(1), - quant_delay_(0), - global_step_(0), - narrow_range_(false), - symmetric_(false) {} - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 4) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output."; - } - - num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); - if (num_bits_ <= 2 || num_bits_ >= 16) { - MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; - } - - quant_delay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay")); - if (quant_delay_ < 0) { - MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0."; - } - - symmetric_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); - narrow_range_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); - - // quant min and max value - quant_min_ = 0; - quant_max_ = (1 << num_bits_) - 1; - if (narrow_range_) { - quant_min_++; - } - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void FakeQuantPerLayerGradGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // gradient - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float)); // min - input_size_list_.push_back(sizeof(float)); // max - output_size_list_.push_back(input_size_); // output - workspace_size_list_.push_back(sizeof(float)); // scale - workspace_size_list_.push_back(sizeof(float)); // nudge_min - workspace_size_list_.push_back(sizeof(float)); // nudge_max -} - -bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - float *output = GetDeviceAddress(outputs, 0); - float *gradient = GetDeviceAddress(inputs, 0); - float *input = GetDeviceAddress(inputs, 1); - float *input_min = GetDeviceAddress(inputs, 2); - float *input_max = GetDeviceAddress(inputs, 3); - float *scale = GetDeviceAddress(workspace, 0); - float *nudge_min = GetDeviceAddress(workspace, 1); - float *nudge_max = GetDeviceAddress(workspace, 2); - - if (gradient == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient 
is null"; - } - if (input == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null."; - } - - if (global_step_ >= quant_delay_) { - CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); - CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); - } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, - reinterpret_cast(stream_ptr)), - "Copy gpu memory failed"); - } - global_step_++; - return true; -} - -MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h deleted file mode 100644 index ae2ea5bfac..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
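
Init derives every entry of these size lists from the inferred input shape, so the memory the framework must stage is fully determined before the first Launch. For example, a float tensor of shape (32, 64, 28, 28) gives input_size_ = 32 * 64 * 28 * 28 * 4 = 6,422,528 bytes for both the gradient and input slots, while the per-layer min, max, scale, nudge_min, and nudge_max entries stay at 4 bytes each. A sketch mirroring the accounting loop in Init (names illustrative):

#include <cstddef>
#include <vector>

// Mirror of the size accounting in Init/InitSizeLists (illustrative).
size_t bytes_for(const std::vector<size_t> &shape) {
  size_t bytes = sizeof(float);
  for (size_t d : shape) bytes *= d;
  return bytes;
}
// bytes_for({32, 64, 28, 28}) == 6422528  (about 6.4 MB)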
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class FakeQuantPerLayerGradGpuKernel : public GpuKernel { - public: - FakeQuantPerLayerGradGpuKernel(); - ~FakeQuantPerLayerGradGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel_node) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - size_t workspace_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int num_bits_; - float quant_min_; - float quant_max_; - int quant_num_; - int quant_delay_; - int global_step_; - bool narrow_range_; - bool symmetric_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc deleted file mode 100644 index a8ce72148b..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
-#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
-#include <thrust/extrema.h>
-#include <thrust/pair.h>
-#include <thrust/device_vector.h>
-#include <cuda_runtime_api.h>
-
-namespace mindspore {
-namespace kernel {
-MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
-    : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {}
-
-const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
-
-const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
-
-const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
-  return workspace_size_list_;
-}
-
-bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
-  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  if (input_num != 3) {
-    MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MinMaxUpdatePerChannel GpuKernel OP needs 3 input.";
-  }
-
-  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  if (output_num != 2) {
-    MS_LOG(EXCEPTION) << "Output number is " << output_num
-                      << ", but MinMaxUpdatePerChannel GpuKernel OP needs 2 output.";
-  }
-
-  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
-  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
-
-  // init size
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  num_channels_ = SizeToInt(input_shape[0]);
-  for (size_t i = 0; i < input_shape.size(); ++i) {
-    quant_num_ *= SizeToInt(input_shape[i]);
-  }
-  input_size_ = sizeof(float);
-  for (size_t i = 0; i < input_shape.size(); i++) {
-    input_size_ *= input_shape[i];
-  }
-  InitSizeLists();
-  return true;
-}
-
-void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);                     // input
-  input_size_list_.push_back(sizeof(float) * num_channels_);   // min
-  input_size_list_.push_back(sizeof(float) * num_channels_);   // max
-  output_size_list_.push_back(sizeof(float) * num_channels_);  // output min
-  output_size_list_.push_back(sizeof(float) * num_channels_);  // output max
-}
-
-bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-                                             const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  float *output_min = GetDeviceAddress<float>(outputs, 0);
-  float *output_max = GetDeviceAddress<float>(outputs, 1);
-  float *input = GetDeviceAddress<float>(inputs, 0);
-  float *input_min = GetDeviceAddress<float>(inputs, 1);
-  float *input_max = GetDeviceAddress<float>(inputs, 2);
-
-  if (input == nullptr) {
-    MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
-  }
-  if (input_min == nullptr || input_max == nullptr) {
-    MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
-  }
-
-  // calculate the input min and max according to the parameters ema and ema_decay.
- CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, - ema_decay_, ema_, reinterpret_cast(stream_ptr)); - return true; -} - -MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h deleted file mode 100644 index 563a583ca1..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { - public: - MinMaxUpdatePerChannelGpuKernel(); - ~MinMaxUpdatePerChannelGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int quant_num_; - bool ema_; - float ema_decay_; - int num_channels_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc deleted file mode 100644 index 3659665b23..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h" -#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh" -#include -#include -#include -#include - -namespace mindspore { -namespace kernel { -MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() - : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {} - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { - size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 3) { - MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output."; - } - - size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 2) { - MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; - } - - ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); - - // init size - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - for (size_t i = 0; i < input_shape.size(); ++i) { - quant_num_ *= SizeToInt(input_shape[i]); - } - input_size_ = sizeof(float); - for (size_t i = 0; i < input_shape.size(); i++) { - input_size_ *= input_shape[i]; - } - InitSizeLists(); - return true; -} - -void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(sizeof(float)); // input min - input_size_list_.push_back(sizeof(float)); // input max - output_size_list_.push_back(sizeof(float)); // output min - output_size_list_.push_back(sizeof(float)); // output max -} - -bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) { - float *output_min = GetDeviceAddress(outputs, 0); - float *output_max = GetDeviceAddress(outputs, 1); - float *input = GetDeviceAddress(inputs, 0); - float *input_min = GetDeviceAddress(inputs, 1); - float *input_max = GetDeviceAddress(inputs, 2); - - if (input == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null."; - } - if (input_min == nullptr || input_max == nullptr) { - MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; - } - - CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, - reinterpret_cast(stream_ptr)); - - return true; -} - -MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel) -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h deleted file mode 100644 index a237b6dc26..0000000000 --- a/mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
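
Both MinMaxUpdate kernels feed CalMinMaxPerChannel/CalMinMaxPerLayer, which track the observed activation range used by the FakeQuant kernels above. With ema disabled, the outputs are simply the batch min and max; with ema enabled, the usual exponential moving average is applied. The sketch below states that update rule explicitly; the exact formula inside the CUDA implementation, including the zero-inclusion clamp, is an assumption here:

#include <algorithm>

// Assumed EMA update used by the MinMaxUpdate kernels (per layer):
//   min_out = ema ? min_in * decay + batch_min * (1 - decay) : batch_min
void update_min_max(float batch_min, float batch_max, float min_in, float max_in,
                    bool ema, float decay, float *min_out, float *max_out) {
  if (ema) {
    *min_out = min_in * decay + batch_min * (1.0f - decay);
    *max_out = max_in * decay + batch_max * (1.0f - decay);
  } else {
    *min_out = batch_min;
    *max_out = batch_max;
  }
  *min_out = std::min(*min_out, 0.0f);  // keep zero inside the range (common convention; an assumption)
  *max_out = std::max(*max_out, 0.0f);
}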
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ - -#include -#include "kernel/gpu/gpu_kernel.h" -#include "kernel/gpu/gpu_kernel_factory.h" - -namespace mindspore { -namespace kernel { -class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { - public: - MinMaxUpdatePerLayerGpuKernel(); - ~MinMaxUpdatePerLayerGpuKernel() = default; - - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - bool Init(const CNodePtr &kernel) override; - - protected: - void InitSizeLists() override; - - private: - size_t input_size_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; - - int quant_num_; - bool ema_; - float ema_decay_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_ diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc deleted file mode 100644 index d5d6e55698..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.cc +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/hccl/hccl_kernel.h" -#include "device/ascend/tasksink/runtime_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" - -using HcclTaskInfoPtr = std::shared_ptr; -using ge::model_runner::HcclTaskInfo; -using mindspore::device::ascend::tasksink::RuntimeUtils; - -namespace mindspore { -namespace kernel { -void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { - hcclKernelMap_.emplace(name, std::move(fun)); -} - -std::shared_ptr HcclKernelFactory::Get(const std::string &name) { - const auto &map = Get().hcclKernelMap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -HcclKernelFactory &HcclKernelFactory::Get() { - static HcclKernelFactory _this; - return _this; -} - -HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {} - -HcclKernel::~HcclKernel() { - hccl_kernel_input_shape_list_.clear(); - hccl_kernel_output_shape_list_.clear(); - hccl_data_type_list_.clear(); - hccl_count_ = 0; - op_type_ = HCCL_REP_OP_SUM; - root_id_ = 0; - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - anf_node_ = nullptr; -} - -bool HcclKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - op_name_ = AnfAlgo::GetCNodeName(anf_node); - - if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) { - MS_LOG(ERROR) << "GetKernelInputShape fail!"; - return false; - } - if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) { - MS_LOG(ERROR) << "GetKernelOutputShape fail!"; - return false; - } - if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) { - MS_LOG(ERROR) << "GetHcomDataType fail!"; - return false; - } - if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) { - MS_LOG(ERROR) << "GetHcomCount fail!"; - return false; - } - if (op_name_ == kAllReduce || op_name_ == kReduceScatter) { - if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) { - MS_LOG(ERROR) << "GetHcomOperationType fail!"; - return false; - } - } - if (op_name_ == kBroadcast) { - if (!HcomUtil::GetHcomRootId(anf_node, &root_id_)) { - MS_LOG(ERROR) << "GetHcomRootId fail!"; - return false; - } - } - HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_)); - anf_node_ = anf_node; - return true; -} - -const std::vector &HcclKernel::GetInputSizeList() const { - size_t size = 0; - if (!input_size_list_.empty()) { - return input_size_list_; - } - for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { - if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_input_shape_list_[i], &size)) { - MS_LOG(ERROR) << "GetHcclOpInputSize failed"; - } - input_size_list_.push_back(size); - } - return input_size_list_; -} - -const std::vector &HcclKernel::GetOutputSizeList() const { - size_t size = 0; - if (!output_size_list_.empty()) { - return output_size_list_; - } - for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) { - if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) { - MS_LOG(ERROR) << "GetHcclOpOutputSize failed"; - } - output_size_list_.push_back(size); - } - return output_size_list_; -} - -const std::vector &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } - -std::vector HcclKernel::GenTask(const std::vector &inputs, - const std::vector &workspace, - const 
-                                             const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
-  if (inputs.empty() || outputs.empty()) {
-    MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
-  }
-  stream_id_ = stream_id;
-  std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
-  MS_EXCEPTION_IF_NULL(inputs.at(0));
-  auto input_data_addr = inputs.at(0)->addr;
-  MS_EXCEPTION_IF_NULL(outputs.at(0));
-  auto output_data_addr = outputs.at(0)->addr;
-  void *workspace_address = nullptr;
-  const int64_t workspace_num = 0;
-  std::vector<uint8_t> private_def;
-  hcclDataType_t data_type = hccl_data_type_list_[0];
-
-  MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_
-               << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_)
-               << ", data_type=" << static_cast<int>(data_type);
-
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
-    kernel_name_, stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0,
-    private_def, nullptr, hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel,
-    RuntimeUtils::HcomUnbindModel, RuntimeUtils::HcomDistribute, NeedDump());
-  MS_EXCEPTION_IF_NULL(task_info_ptr);
-  return {task_info_ptr};
-}
-}  // namespace kernel
-}  // namespace mindspore
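hccl_kernel.cc above resolves op names through HcclKernelFactory, which is populated at program load: each expansion of the MS_HCCL_REG_KERNEL macro (defined in hccl_kernel.h below) creates a static registrar object whose constructor inserts a creator lambda into the factory map. A compilable, self-contained sketch of that self-registration idiom, with simplified stand-in names rather than the real API:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

struct Kernel { virtual ~Kernel() = default; virtual std::string Name() const = 0; };
using Creator = std::function<std::shared_ptr<Kernel>()>;

class Factory {
 public:
  static Factory &Get() { static Factory f; return f; }   // Meyers singleton, like HcclKernelFactory::Get()
  void Register(const std::string &name, Creator c) { map_.emplace(name, std::move(c)); }
  std::shared_ptr<Kernel> Create(const std::string &name) {
    auto it = map_.find(name);
    return it == map_.end() ? nullptr : it->second();
  }
 private:
  std::map<std::string, Creator> map_;
};

// Constructed before main() for every registered kernel type.
struct Registrar {
  Registrar(const std::string &name, Creator c) { Factory::Get().Register(name, std::move(c)); }
};

#define REG_KERNEL(NAME, KLASS) \
  static Registrar g_##NAME##_reg(#NAME, [] { return std::make_shared<KLASS>(); })

struct AllReduceKernel : Kernel { std::string Name() const override { return "AllReduce"; } };
REG_KERNEL(AllReduce, AllReduceKernel);

int main() {
  auto k = Factory::Get().Create("AllReduce");
  std::cout << (k ? k->Name() : "not found") << "\n";
  return 0;
}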
diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel.h
deleted file mode 100644
index 72e202591f..0000000000
--- a/mindspore/ccsrc/kernel/hccl/hccl_kernel.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_
-#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_H_
-
-#include <functional>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "kernel/ascend_kernel_mod.h"
-#include "kernel/hccl/hcom_util.h"
-#include "hccl/hcom.h"
-#include "common/utils.h"
-
-namespace mindspore {
-namespace kernel {
-class HcclKernel : public AscendKernelMod {
- public:
-  HcclKernel();
-  ~HcclKernel() override;
-  virtual bool Init(const AnfNodePtr &anf_node);
-  const std::vector<size_t> &GetInputSizeList() const override;
-  const std::vector<size_t> &GetOutputSizeList() const override;
-  const std::vector<size_t> &GetWorkspaceSizeList() const override;
-  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-                                   const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
-
- protected:
-  std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_;
-  std::vector<std::vector<size_t>> hccl_kernel_output_shape_list_;
-  std::vector<hcclDataType_t> hccl_data_type_list_;
-  std::vector<std::string> hccl_format_list_;
-  uint64_t hccl_count_;
-  hcclRedOp_t op_type_;
-  uint32_t root_id_;
-  mutable std::vector<size_t> input_size_list_;
-  mutable std::vector<size_t> output_size_list_;
-  mutable std::vector<size_t> workspace_size_list_;
-  AnfNodePtr anf_node_;
-  std::string op_name_;
-  std::string group_;
-};
-
-using HcclKernelCreater = std::function<std::shared_ptr<HcclKernel>()>;
-
-class HcclKernelFactory {
-  HcclKernelFactory() = default;
-  ~HcclKernelFactory() = default;
-
- public:
-  static HcclKernelFactory &Get();
-  void Registe(const string &name, HcclKernelCreater &&fun);
-  static std::shared_ptr<HcclKernel> Get(const string &name);
-
- private:
-  std::map<string, HcclKernelCreater> hcclKernelMap_;
-};
-
-class _HcclKernelRegister {
- public:
-  _HcclKernelRegister(const string &name, HcclKernelCreater &&fun) {
-    HcclKernelFactory::Get().Registe(name, std::move(fun));
-  }
-  ~_HcclKernelRegister() = default;
-};
-
-#define _MS_HCCL_REG_KERNEL_REG(KNAME, clazz)                                              \
-  static_assert(std::is_base_of<HcclKernel, clazz>::value, " must be base of HcclKernel"); \
-  static const _HcclKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() {               \
-    std::shared_ptr<clazz> ptr = nullptr;                                                  \
-    ptr = std::make_shared<clazz>();                                                       \
-    MS_EXCEPTION_IF_NULL(ptr);                                                             \
-    return ptr;                                                                            \
-  });
-
-#define MS_HCCL_REG_KERNEL(KNAME, clazz) _MS_HCCL_REG_KERNEL_REG(KNAME, clazz)
-}  // namespace kernel
-}  // namespace mindspore
-#endif
diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc
deleted file mode 100644
index d6e4aa09b9..0000000000
--- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "kernel/hccl/hccl_kernel_build.h" - -#include -#include -#include - -#include "kernel/hccl/hccl_kernel.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string opname = AnfAlgo::GetCNodeName(anf_node); - MS_LOG(INFO) << "Hccl op [" << opname << "]"; - auto kerPtr = HcclKernelFactory::Get(opname); - if (kerPtr == nullptr) { - MS_LOG(ERROR) << "Hccl can't find Kernel[" << opname << "]"; - return nullptr; - } - if (!kerPtr->Init(anf_node)) { - MS_LOG(ERROR) << "Kernel initialize failed!"; - return nullptr; - } - return kerPtr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h deleted file mode 100644 index f20760a3eb..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_build.h +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_BUILD_H_ - -#include -#include -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -KernelModPtr HcclOpBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc deleted file mode 100755 index bfd1327548..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/hccl/hccl_kernel_metadata.h" -#include -#include -#include "utils/utils.h" -#include "kernel/hccl/hcom_util.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -namespace { -std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { - const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; - auto op_name = AnfAlgo::GetCNodeName(kernel_node); - auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); - if (op_name != kReduceScatter && op_name != kAllGatherOpName) { - return format; - } - if (format == kOpFormat_FRAC_NZ && AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= 2) { - return kOpFormat_DEFAULT; - } - if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) { - return kOpFormat_DEFAULT; - } - return format; -} -} // namespace -void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { - const std::vector kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, - kNumberTypeFloat32, kNumberTypeInt16}; - MS_EXCEPTION_IF_NULL(kernel_info_list); - MS_EXCEPTION_IF_NULL(kernel_node); - std::string op_name = AnfAlgo::GetCNodeName(kernel_node); - if (op_name != kAllGather && op_name != kAllReduce && op_name != kBroadcast && op_name != kReduceScatter) { - MS_LOG(DEBUG) << "Hccl does not have op [" << op_name << "]"; - return; - } - for (const auto &type : kHcclSupportTypes) { - std::vector inputs_format{}; - std::vector inputs_type{}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { - inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index)); - inputs_type.push_back(type); - } - std::vector outputs_format; - std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { - outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index)); - outputs_type.push_back(type); - } - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetKernelType(HCCL_KERNEL); - kernel_info_list->push_back(builder.Build()); - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h b/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h deleted file mode 100755 index b13393d3bd..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hccl_kernel_metadata.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ -#include -#include -#include -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace kernel { -void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_HCCL_HCCL_KERNEL_METADATA_ANFALGO_H_ diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc deleted file mode 100644 index 9dbe708ef9..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_broadcast.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllBroadCastKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "BroadCast param is empty"; - return false; - } - const char *tag = "Hccl-BroadCast"; - MS_EXCEPTION_IF_NULL(inputs[0]); - hcclResult_t ret = - hcom_broadcast(tag, inputs[0]->addr, hccl_count_, hccl_data_type_list_[0], root_id_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomBroadcastOp : hcom_broadcast fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h b/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h deleted file mode 100644 index ca8eba91af..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_broadcast.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_BROADCAST_H_ - -#include -#include -#include "hccl/hcom.h" -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllBroadCastKernel : public HcclKernel { - public: - HcomAllBroadCastKernel() = default; - ~HcomAllBroadCastKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; -MS_HCCL_REG_KERNEL(Broadcast, HcomAllBroadCastKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc deleted file mode 100644 index 6494f7fd12..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_gather.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllGatherKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "AllGather param is empty"; - return false; - } - const char *tag = "Hccl-AllGather"; - hcclResult_t ret = - hcom_all_gather(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomAllGatherKernelOp : hcom_all_gather fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h b/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h deleted file mode 100644 index 5de2c513cf..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_gather.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_GATHER_H_ - -#include -#include -#include "hccl/hcom.h" -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllGatherKernel : public HcclKernel { - public: - HcomAllGatherKernel() = default; - ~HcomAllGatherKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; -MS_HCCL_REG_KERNEL(AllGather, HcomAllGatherKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc deleted file mode 100644 index 35a058e766..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_reduce.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllReduceKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "AllReduce param is empty"; - return false; - } - const char *tag = "Hccl-AllReduce"; - hcclResult_t ret = hcom_all_reduce(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomAllReduceKernelOp : hcom_all_reduce fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h deleted file mode 100644 index 939abd9de7..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ -#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_H_ - -#include -#include -#include "kernel/hccl/hccl_kernel.h" - -namespace mindspore { -namespace kernel { -class HcomAllReduceKernel : public HcclKernel { - public: - HcomAllReduceKernel() = default; - ~HcomAllReduceKernel() override = default; - - /* Inherit from kernelmod */ - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - - private: -}; - -MS_HCCL_REG_KERNEL(AllReduce, HcomAllReduceKernel); -} // namespace kernel -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc deleted file mode 100644 index dea516885d..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/hccl/hcom_all_reduce_scatter.h" - -#include -#include -#include - -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -bool HcomAllReduceScatterKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs, void *stream_ptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink()) { - return true; - } - if (inputs.empty() || outputs.empty() || hccl_data_type_list_.empty()) { - MS_LOG(ERROR) << "ReduceScatter param is empty"; - return false; - } - const char *tag = "Hccl-ReduceScatter"; - hcclResult_t ret = hcom_reduce_scatter(tag, inputs[0]->addr, outputs[0]->addr, hccl_count_, hccl_data_type_list_[0], - op_type_, nullptr, stream_ptr); - if (ret != HCCL_SUCCESS) { - MS_LOG(ERROR) << "HcomReduceScatterOp : hcom_reduce_scatter fail, return: " << static_cast(ret); - return false; - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h b/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h deleted file mode 100644 index c734b517c6..0000000000 --- a/mindspore/ccsrc/kernel/hccl/hcom_all_reduce_scatter.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_
-#define MINDSPORE_CCSRC_KERNEL_HCCL_HCOM_ALL_REDUCE_SCATTER_H_
-
-#include <memory>
-#include <vector>
-#include "hccl/hcom.h"
-#include "kernel/hccl/hccl_kernel.h"
-
-namespace mindspore {
-namespace kernel {
-class HcomAllReduceScatterKernel : public HcclKernel {
- public:
-  HcomAllReduceScatterKernel() = default;
-  ~HcomAllReduceScatterKernel() override = default;
-
-  /* Inherit from kernelmod */
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
-
- private:
-};
-
-MS_HCCL_REG_KERNEL(ReduceScatter, HcomAllReduceScatterKernel);
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif
diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc
deleted file mode 100644
index 088dbe59d5..0000000000
--- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc
+++ /dev/null
@@ -1,198 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/hccl/hcom_util.h"
-
-#include <algorithm>
-
-#include "kernel/common_utils.h"
-#include "session/anf_runtime_algorithm.h"
-#include "utils/utils.h"
-
-namespace mindspore {
-bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
-    std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
-    hccl_kernel_intput_shape_list->emplace_back(shape_i);
-  }
-
-  return true;
-}
-
-bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
-    std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
-    hccl_kernel_output_shape_list->emplace_back(shape_i);
-  }
-
-  return true;
-}
-
-bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<hcclDataType_t> *data_type_list) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(data_type_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
-    auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
-    auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
-    if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
-      MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr;
-    }
-    data_type_list->emplace_back(iter->second);
-  }
-  auto type_base = *(std::begin(*data_type_list));
-  if (std::any_of(data_type_list->begin(), data_type_list->end(),
-                  [&type_base](hcclDataType_t type) { return type != type_base; })) {
-    MS_LOG(ERROR) << "hccl have different data type";
-    return false;
-  }
-  return true;
-}
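GetHcclOpSize below folds every dimension of a shape into a byte count through SizetMulWithOverflowCheck, so a huge shape fails loudly instead of wrapping around. A hedged sketch of what such a checked multiply looks like; the real helper lives elsewhere in the tree and may differ in its error handling:

#include <cstddef>
#include <cstdio>
#include <limits>
#include <stdexcept>

// Overflow-safe size_t multiply: plain a * b would wrap silently.
std::size_t CheckedMul(std::size_t a, std::size_t b) {
  if (a != 0 && b > std::numeric_limits<std::size_t>::max() / a) {
    throw std::overflow_error("size_t multiplication overflow");
  }
  return a * b;
}

int main() {
  // Shape {16, 1024} of fp32 -> 65536 bytes, checked at each step.
  std::size_t bytes = CheckedMul(CheckedMul(16, 1024), sizeof(float));
  std::printf("%zu\n", bytes);
  return 0;
}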
-bool HcomUtil::GetHcclOpSize(const hcclDataType_t &data_type, const vector<size_t> &shape, size_t *size) {
-  MS_EXCEPTION_IF_NULL(size);
-  size_t tmp_size = 1;
-  uint32_t type_size = 4;
-  for (size_t i = 0; i < shape.size(); i++) {
-    tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]);
-  }
-
-  if (!GetHcomTypeSize(data_type, &type_size)) {
-    return false;
-  }
-
-  *size = SizetMulWithOverflowCheck(tmp_size, type_size);
-
-  MS_LOG(INFO) << "size[" << *size << "]";
-  return true;
-}
-
-bool HcomUtil::GetHcomTypeSize(const hcclDataType_t &data_type, uint32_t *size) {
-  MS_EXCEPTION_IF_NULL(size);
-  auto iter = CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.find(data_type);
-  if (iter == CONST_OP_HCOM_DATA_TYPE_SIZE_MAP.end()) {
-    MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!";
-    return false;
-  }
-  *size = iter->second;
-  return true;
-}
-
-bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<hcclDataType_t> &data_type_list,
-                            const vector<vector<size_t>> &shape_list, uint64_t *total_count) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(total_count);
-  const uint32_t align_size = 512;
-  const uint32_t filled_size = 32;
-  uint64_t total_size = 0;
-  uint64_t block_size;
-  size_t input_size;
-  uint32_t type_size = 4;
-
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
-    if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
-      return false;
-    }
-
-    if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) {
-      MS_LOG(ERROR) << "Get GetHcclOpSize failed";
-      return false;
-    }
-
-    if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) {
-      int32_t rank_size;
-      auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-      MS_EXCEPTION_IF_NULL(primitive);
-      if (primitive->GetAttr("rank_size") != nullptr) {
-        rank_size = GetValue<int32_t>(primitive->GetAttr("rank_size"));
-      } else {
-        MS_LOG(ERROR) << "Get rank size failed";
-        return false;
-      }
-      block_size = input_size / IntToSize(rank_size);
-      total_size = total_size + block_size;
-    } else {
-      if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
-        block_size = input_size;
-      } else {
-        block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
-      }
-      total_size = total_size + block_size;
-    }
-  }
-
-  if (type_size == 0 || total_size % type_size != 0) {
-    MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!";
-    return false;
-  }
-  *total_count = total_size / type_size;
-  return true;
-}
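GetHcomCount above pads each non-AllGather block up to a 512-byte boundary (with a 32-byte fill) before converting the total back from bytes to an element count. The same arithmetic restated as a runnable snippet; the constants come from the deleted code, while the sample tensor size is made up for illustration:

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t align_size = 512;
  const uint32_t filled_size = 32;
  const uint32_t type_size = 4;                       // float32
  uint64_t input_size = 4096 * type_size;             // e.g. a 4096-element fp32 tensor
  // Round up to the next align_size multiple after adding the fill.
  uint64_t block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
  uint64_t count = block_size / type_size;            // bytes back to elements
  std::printf("block=%llu bytes, count=%llu elements\n",
              (unsigned long long)block_size, (unsigned long long)count);
  return 0;
}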
-bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, hcclRedOp_t *op_type) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(op_type);
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("op") == nullptr) {
-    MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!";
-    return false;
-  }
-  auto hcom_op_type_get = GetValue<const char *>(primitive->GetAttr("op"));
-  string hcom_op_type(hcom_op_type_get);
-  if (hcom_op_type == "min") {
-    *op_type = HCCL_REP_OP_MIN;
-  } else if (hcom_op_type == "max") {
-    *op_type = HCCL_REP_OP_MAX;
-  } else if (hcom_op_type == "prod") {
-    *op_type = HCCL_REP_OP_PROD;
-  } else if (hcom_op_type == "sum") {
-    *op_type = HCCL_REP_OP_SUM;
-  } else {
-    MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
-    return false;
-  }
-  return true;
-}
-
-bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(root_id);
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  if (primitive->GetAttr("root_rank") != nullptr) {
-    *root_id = (uint32_t)GetValue<int32_t>(primitive->GetAttr("root_rank"));
-  } else {
-    MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
-    return false;
-  }
-  return true;
-}
-
-void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto attr = primitive->GetAttr("group");
-  if (attr != nullptr) {
-    *group = GetValue<std::string>(attr);
-  } else {
-    MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
-  }
-}
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/kash/kernel_pack.cc b/mindspore/ccsrc/kernel/kash/kernel_pack.cc
deleted file mode 100644
index a87441031b..0000000000
--- a/mindspore/ccsrc/kernel/kash/kernel_pack.cc
+++ /dev/null
@@ -1,249 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <fstream>
-#include "mindspore/ccsrc/kernel/kernel.h"
-#include "kernel/kernel.h"
-#include "kernel/akg/akg_kernel_build.h"
-#include "nlohmann/json.hpp"
-#include "securec/include/securec.h"
-#include "pipeline/parse/python_adapter.h"
-#include "utils/log_adapter.h"
-#include "utils/convert_utils.h"
-namespace mindspore {
-namespace kernel {
-constexpr auto kUtilsModule = "mindspore._extends.utils";
-constexpr auto kCalSha256Func = "cal_sha256";
-
-namespace {
-bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) {
-  if (js.find("sha256") == js.end()) {
-    MS_LOG(ERROR) << "No sha256 found in " << json_file;
-    return false;
-  }
-  std::string sha256_str = js["sha256"];
-  py::object ret = parse::python_adapter::CallPyFn(kUtilsModule, kCalSha256Func, bin_file);
-  std::string sha256_cal = py::cast<std::string>(ret);
-  if (sha256_cal.empty()) {
-    MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed.";
-    return false;
-  }
-  if (sha256_cal != sha256_str) {
-    MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed.";
-    return false;
-  }
-  return true;
-}
-}  // namespace
-
-const std::string KernelPack::Serialize() const {
-  MS_EXCEPTION_IF_NULL(json_);
-  MS_EXCEPTION_IF_NULL(kernel_);
-  std::string buffer;
-  (void)buffer.append((const char *)json_, json_->len + sizeof(json_->len));
-  (void)buffer.append((const char *)kernel_, kernel_->len + sizeof(kernel_->len));
-  return buffer;
-}
-
-bool KernelPack::ReadFromJsonFileHelper(std::ifstream &kernelbin) {
-  size_t binsize = LongToSize(kernelbin.seekg(0, std::ios::end).tellg());
-  // free old data
-  if (kernel_ != nullptr) {
-    delete[] kernel_;
-    kernel_ = nullptr;
-  }
-
-  void *ptr = static_cast<void *>(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]);
-  if (ptr != nullptr) {
-    kernel_ = static_cast<FlexArray *>(ptr);
-  }
-  if (kernel_ == nullptr) {
-    MS_LOG(ERROR) << "memory malloc failed.";
-    kernelbin.close();
-    return false;
-  }
-  if (memset_s(kernel_, sizeof(KernelPack) + binsize, 0, sizeof(KernelPack) + binsize) != EOK) {
-    MS_LOG(ERROR) << "memset kernel_ failed.";
-    delete[] kernel_;
-    kernel_
= nullptr; - kernelbin.close(); - return false; - } - kernel_->len = binsize; - MS_LOG(INFO) << "kernel len:" << kernel_->len; - (void)kernelbin.seekg(0, std::ios::beg); - (void)kernelbin.read(kernel_->contents, SizeToLong(kernel_->len)); - return true; -} - -bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &processor) { - if (json_f.length() <= strlen(kJsonSuffix)) { - MS_LOG(ERROR) << "please check json path."; - return false; - } - - std::ifstream kerneljson(json_f); - if (!kerneljson.is_open()) { - MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; - return false; - } - nlohmann::json js; - kerneljson >> js; - - size_t binsize = LongToSize(kerneljson.seekg(0, std::ios::end).tellg()); - void *ptr = static_cast(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]); - if (ptr != nullptr) { - json_ = static_cast(ptr); - } - if (json_ == nullptr) { - MS_LOG(ERROR) << "memory malloc failed."; - kerneljson.close(); - return false; - } - json_->len = binsize; - (void)kerneljson.seekg(0, std::ios::beg); - (void)kerneljson.read(json_->contents, SizeToLong(json_->len)); - - if (processor == kProcessorCuda) { - std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx"; - std::ifstream kernelbin(bin_f); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta."; - kerneljson.close(); - return false; - } - - if (ReadFromJsonFileHelper(kernelbin) == false) { - delete[] json_; - json_ = nullptr; - kerneljson.close(); - return false; - } - kerneljson.close(); - if (!CheckHash(json_f, bin_f, js)) { - return false; - } - return true; - } - - std::string binfilesuffix = js["binFileSuffix"]; - std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfilesuffix; - if (binfilesuffix.compare(".so") == 0) { - // change "xx/xx.so" -> "xx/libxx.so" - auto sp = bin_f.rfind('/'); - if (sp == std::string::npos) { - MS_LOG(ERROR) << "illegal bin file path " << bin_f; - kerneljson.close(); - return false; - } - bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); - } - - std::ifstream kernelbin(bin_f, std::ios::binary); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; - kerneljson.close(); - delete[] json_; - json_ = nullptr; - return false; - } - - MS_LOG(INFO) << "kernelbin_name:" << bin_f; - if (ReadFromJsonFileHelper(kernelbin) == false) { - delete[] json_; - json_ = nullptr; - kerneljson.close(); - return false; - } - kerneljson.close(); - - if (!CheckHash(json_f, bin_f, js)) { - return false; - } - - return true; -} - -void KernelPack::ParseKernelJson(const nlohmann::json &js) { - kernel_json_info_.bin_file_name = js["binFileName"]; - kernel_json_info_.bin_file_suffix = js["binFileSuffix"]; - kernel_json_info_.block_dim = js["blockDim"]; - kernel_json_info_.kernel_name = js["kernelName"]; - kernel_json_info_.magic = js["magic"]; - if (js.find("parameters") != js.end()) { - if (!js.at("parameters").is_array()) { - MS_LOG(DEBUG) << "Format error!,parameters should be array."; - } - std::vector sizes = js.at("parameters"); - for (auto size : sizes) { - MS_LOG(INFO) << "parameter " << size; - kernel_json_info_.parameters.push_back(size); - } - } - if (js.find("workspace") != js.end()) { - auto workspace = js.at("workspace"); - std::vector sizes = workspace.at("size"); - for (auto size : sizes) { - MS_LOG(INFO) << "workspace_size_list " << size; - kernel_json_info_.workspaces.push_back(size); - } - } - 
kernel_json_info_.sha256 = js["sha256"]; -} - -bool KernelPack::LoadKernelMeta(const std::string &json_f, const std::string &processor) { - if (json_f.length() <= strlen(kJsonSuffix)) { - MS_LOG(ERROR) << "please check json path."; - return false; - } - std::ifstream kernel_json(json_f); - if (!kernel_json.is_open()) { - MS_LOG(DEBUG) << "read json file error, please check kernelmeta."; - return false; - } - nlohmann::json js; - kernel_json >> js; - ParseKernelJson(js); - kernel_json.close(); - - std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix; - if (kernel_json_info_.bin_file_suffix == ".so") { - // change "xx/xx.so" -> "xx/libxx.so" - auto sp = bin_f.rfind('/'); - if (sp == std::string::npos) { - MS_LOG(ERROR) << "illegal bin file path " << bin_f; - return false; - } - bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1); - } - - std::ifstream kernelbin(bin_f, std::ios::binary); - if (!kernelbin.is_open()) { - MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta."; - return false; - } - - MS_LOG(INFO) << "kernelbin_name:" << bin_f; - if (!ReadFromJsonFileHelper(kernelbin)) { - return false; - } - - return CheckHash(json_f, bin_f, js); -} - -KernelJsonInfo KernelPack::kernel_json_info() const { return kernel_json_info_; } -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc deleted file mode 100644 index bb7ce75ac4..0000000000 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/kernel_build_info.h" -#include -#include "utils/log_adapter.h" -#include "debug/anf_ir_dump.h" -namespace mindspore { -namespace kernel { -std::string KernelBuildInfo::GetInputFormat(size_t input_index) const { - if (input_index >= inputs_format_.size()) { - MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node"; - return kInvalidFormat; - } - return inputs_format_[input_index]; -} - -std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const { - if (output_index >= outputs_format_.size()) { - MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node"; - return kInvalidFormat; - } - return outputs_format_[output_index]; -} - -TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const { - if (input_index >= inputs_device_type_.size()) { - MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input"; - return TypeId::kNumberTypeEnd; - } - return inputs_device_type_[input_index]; -} - -TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const { - if (output_index >= outputs_device_type_.size()) { - MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output"; - return TypeId::kNumberTypeEnd; - } - return outputs_device_type_[output_index]; -} - -std::vector KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; } - -std::vector KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; } - -std::vector KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; } - -std::vector KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; } - -size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } - -size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } - -std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { - if (input_index >= input_reshape_type_.size()) { - MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " - << input_reshape_type_.size(); - } - return input_reshape_type_[input_index]; -} - -std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { - if (output_index >= output_reshape_type_.size()) { - MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " - << output_reshape_type_.size(); - } - return output_reshape_type_[output_index]; -} - -std::string KernelBuildInfo::ToString() const { - std::ostringstream output_buffer; - output_buffer << "("; - for (size_t index = 0; index < GetInputNum(); ++index) { - if (index != 0) { - output_buffer << ", "; - } - output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">"; - } - output_buffer << ") -> ("; - for (size_t index = 0; index < GetOutputNum(); ++index) { - if (index != 0) { - output_buffer << ", "; - } - output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">"; - } - output_buffer << ")"; - return output_buffer.str(); -} - -bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { - if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) { - return false; - } - if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) { - if (op_pattern_ != kFormatAgnosticPattern) { - return false; - } else { - MS_LOG(INFO) << "this kernel build info:" 
<< this->ToString() - << ", other kernel build info: " << other.ToString(); - } - } - return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); -} - -bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); } - -bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); } - -bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); } - -void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->kernel_type_ = kernel_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector &inputs_format) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->inputs_format_ = inputs_format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector &outputs_format) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->outputs_format_ = outputs_format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector &inputs_device_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->inputs_device_type_ = inputs_device_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector &outputs_device_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->outputs_device_type_ = outputs_device_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->fusion_type_ = fusion_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->processor_ = processor; -} - -std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } - -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( - const std::vector> &input_reshape_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->input_reshape_type_ = input_reshape_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( - const std::vector> &output_reshape_type) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->output_reshape_type_ = output_reshape_type; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - kernel_build_info_->op_pattern_ = pattern; -} -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - if (index >= kernel_build_info_->inputs_format_.size()) { - MS_LOG(EXCEPTION) << "index outof range!"; - } - kernel_build_info_->inputs_format_[index] = format; -} - -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { - MS_EXCEPTION_IF_NULL(kernel_build_info_); - if (index >= kernel_build_info_->outputs_format_.size()) { - MS_LOG(EXCEPTION) << "index outof range!"; - } - kernel_build_info_->outputs_format_[index] = format; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h deleted file mode 100644 index 45ac45f98f..0000000000 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ /dev/null @@ -1,147 +0,0 
@@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "kernel/kernel.h" - -namespace mindspore { -namespace kernel { -class KernelBuildInfo { - public: - class KernelBuildInfoBuilder; - - KernelBuildInfo() { - kernel_type_ = TBE_KERNEL; - fusion_type_ = OPAQUE; - processor_ = AICORE; - op_pattern_ = kCommonPattern; - input_reshape_type_ = {}; - output_reshape_type_ = {}; - inputs_format_ = {}; - outputs_format_ = {}; - inputs_device_type_ = {}; - outputs_device_type_ = {}; - } - - ~KernelBuildInfo() = default; - - KernelType kernel_type() const { return kernel_type_; } - - std::string GetInputFormat(size_t input_index) const; - - std::string GetOutputFormat(size_t output_index) const; - - TypeId GetInputDeviceType(size_t input_index) const; - - TypeId GetOutputDeviceType(size_t output_index) const; - - std::vector GetInputReshapeType(size_t input_index) const; - - bool IsInputDefaultPadding() const; - - bool IsOutputDefaultPadding() const; - - std::vector GetOutputReshapeType(size_t input_index) const; - - std::vector GetAllInputFormats() const; - - std::vector GetAllOutputFormats() const; - - std::vector GetAllInputDeviceTypes() const; - - std::vector GetAllOutputDeviceTypes() const; - - OpPattern op_pattern() const { return op_pattern_; } - - FusionType fusion_type() const { return fusion_type_; } - - Processor processor() const { return processor_; } - - size_t GetInputNum() const; - - size_t GetOutputNum() const; - - std::string ToString() const; - - bool operator==(const KernelBuildInfo &other) const; - - bool operator!=(const KernelBuildInfo &other) const; - - public: - static auto constexpr kInvalidFormat = "InvalidFormat"; - - private: - KernelType kernel_type_; - std::vector inputs_format_; - OpPattern op_pattern_; - std::vector outputs_format_; - std::vector> input_reshape_type_; - std::vector> output_reshape_type_; - std::vector inputs_device_type_; - std::vector outputs_device_type_; - FusionType fusion_type_; - Processor processor_; -}; -using KernelBuildInfoPtr = std::shared_ptr; - -class KernelBuildInfo::KernelBuildInfoBuilder { - public: - KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared(); } - - explicit KernelBuildInfoBuilder(std::shared_ptr kernel_build_info) - : kernel_build_info_(std::move(kernel_build_info)) {} - - ~KernelBuildInfoBuilder() = default; - - void SetKernelType(const KernelType &kernel_type); - - void SetInputsFormat(const std::vector &inputs_format); - - void SetOutputsFormat(const std::vector &outputs_format); - - void SetInputsDeviceType(const std::vector &inputs_device_type); - - void SetOutputsDeviceType(const std::vector &outputs_device_type); - - void SetInputReshapeType(const std::vector> &input_reshape_type); - - void SetOutputReshapeType(const std::vector> &output_reshape_type); - 
- void SetFusionType(FusionType fusion_type); - - void SetProcessor(Processor processor); - - void SetOpPattern(OpPattern pattern); - - void SetInputFormat(const std::string &format, size_t index); - - void SetOutputFormat(const std::string &format, size_t index); - - std::shared_ptr Build(); - - private: - std::shared_ptr kernel_build_info_; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_KERNEL_BUILD_INFO_H_ diff --git a/mindspore/ccsrc/kernel/kernel_fusion.cc b/mindspore/ccsrc/kernel/kernel_fusion.cc deleted file mode 100644 index be79eca15a..0000000000 --- a/mindspore/ccsrc/kernel/kernel_fusion.cc +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/kernel_fusion.h" - -#include -#include -#include -#include - -#include "common/utils.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_parallel_build.h" -#include "kernel/tbe/tbe_utils.h" -#include "kernel/tbe/tbe_convert_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -static bool GenPreBuildKernelJson(const std::vector &compute_nodes, - std::vector *prebuild_op_list) { - MS_EXCEPTION_IF_NULL(prebuild_op_list); - TbeKernelJsonCreator creator(PREBUILD); - for (const auto &anf_node : compute_nodes) { - nlohmann::json prebuild; - if (!creator.GenTbeSingleKernelJson(anf_node, &prebuild)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - (*prebuild_op_list).push_back(prebuild); - } - return true; -} - -std::map KernelFusion(const std::vector &fusion_scopes) { - MS_LOG(INFO) << "kernel fusion build start, scope size:" << fusion_scopes.size(); - std::map kernel_mod_ret; - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - for (const auto &fusion_scope_iter : fusion_scopes) { - auto scope_id = fusion_scope_iter.scope_id; - nlohmann::json fusion_op; - string fusion_kernel = "te_fusion"; - if (!TbeKernelBuild::GenFusionScopeJson(fusion_scope_iter.input_nodes, fusion_scope_iter.compute_nodes, &fusion_op, - &fusion_kernel)) { - continue; - } - // gen kernel_name & check cache - std::string json_str = fusion_op.dump(); - size_t hash_id = std::hash()(json_str); - auto json_name = fusion_kernel.append("_").append(std::to_string(hash_id)); - fusion_op["fusion_op_name"] = json_name; - // gen json for prebuild - std::vector prebuild_op_list; - if (!GenPreBuildKernelJson(fusion_scope_iter.compute_nodes, &prebuild_op_list)) { - continue; - } - // get io size - std::vector input_size_list; - std::vector output_size_list; - if (!TbeKernelBuild::GetIOSize(fusion_op["op_list"], fusion_scope_iter.output_nodes, &input_size_list, - &output_size_list)) { - continue; - } - // search cache - auto kernel_pack = TbeUtils::SearchCache(json_name, tbe::kProcessorAiCore); - if (kernel_pack != nullptr) { - MS_LOG(INFO) << "Use cached kernel, kernel json name: " << json_name; - auto kernel_mod = - 
build_manger->GenKernelMod(json_name, tbe::kProcessorAiCore, input_size_list, output_size_list, kernel_pack); - if (kernel_mod != nullptr) { - kernel_mod_ret[scope_id] = kernel_mod; - continue; - } - } - // fusion build - nlohmann::json fusion_json; - fusion_json["fusion_op"] = fusion_op; - fusion_json["prebuild_ops"] = prebuild_op_list; - auto task_id = build_manger->StartCompileOp(fusion_json); - TbeUtils::SaveJsonInfo(json_name, fusion_json.dump()); - if (task_id < 0) { - MS_EXCEPTION(ArgumentError) << "start compile failed."; - } - build_manger->SaveTaskInfo(task_id, nullptr, json_name, input_size_list, output_size_list, scope_id); - } - - int build_failed_num = 0; - while (!build_manger->IsAllTaskFinish()) { - int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); - if (!ret) { - MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; - } - - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { - MS_LOG(INFO) << "Fusion warning: Fuison op build failed, err log: " << task_result - << " change to single op build."; - build_failed_num++; - } - auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); - if (kernel_mod_item.second != nullptr) { - (void)kernel_mod_ret.emplace(kernel_mod_item); - } - } - MS_LOG(INFO) << "Build Fusion Kernel Failed Num: " << build_failed_num; - return kernel_mod_ret; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_fusion.h b/mindspore/ccsrc/kernel/kernel_fusion.h deleted file mode 100644 index 8ded21787c..0000000000 --- a/mindspore/ccsrc/kernel/kernel_fusion.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ -#define MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ -#include -#include -#include "kernel/kernel.h" -namespace mindspore { -namespace kernel { -/* - * @brief fuse op and return a callable mod - */ -struct FusionScopeInfo { - int32_t scope_id; - std::vector input_nodes; - std::vector compute_nodes; - std::vector output_nodes; -}; - -std::map KernelFusion(const std::vector &fusion_scopes); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_KERNELFUSION_H_ diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc deleted file mode 100755 index 4a8ae81afa..0000000000 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/kernel_query.h"
-#include
-#include
-#include "kernel/aicpu/aicpu_kernel_metadata.h"
-#include "kernel/rts/rt_kernel_info.h"
-#include "kernel/hccl/hccl_kernel_metadata.h"
-#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
-#include "kernel/akg/akg_kernel_metadata.h"
-#include "session/anf_runtime_algorithm.h"
-#include "utils/context/ms_context.h"
-
-namespace mindspore {
-namespace kernel {
-namespace {
-void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
-                             std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
-  (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
-                     [&kernel_node](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
-                       return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
-                              AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
-                     });
-  if (!filtered_list.empty()) {
-    kernel_info_list->clear();
-    (void)std::copy(filtered_list.begin(), filtered_list.end(), std::back_inserter(*kernel_info_list));
-  } else {
-    MS_LOG(INFO) << "No kernel info in the candidate list matches the kernel node";
-    for (size_t index = 0; index < kernel_info_list->size(); ++index) {
-      std::ostringstream buffer;
-      auto kernel_info = kernel_info_list->at(index);
-      MS_EXCEPTION_IF_NULL(kernel_info);
-      if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) {
-        buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
-               << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
-      } else {
-        buffer << "Kernel node's input size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
-               << " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
-      }
-      MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
-    }
-    kernel_info_list->clear();
-    MS_LOG(INFO) << "node " << kernel_node->DebugString() << "'s output size : ["
-                 << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
-                 << ", input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernel info!";
-  }
-}
-}  // namespace
-
-void KernelQueryAll(const CNodePtr &kernel_node,
-                    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-
-  TbeMetadataInfo(kernel_node, kernel_info_list);
-
-  if (kernel_info_list->empty()) {
-    AicpuMetadataInfo(kernel_node, kernel_info_list);
-    if (!kernel_info_list->empty()) {
-      MS_LOG(INFO) << "The node [" << kernel_node->DebugString()
-                   << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
-      AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
-    }
-  }
-
-  if (kernel_info_list->empty()) {
-    GetRtKelInfo(kernel_node, kernel_info_list);
-  }
-
-  if (kernel_info_list->empty()) {
-    HcclMetadataInfo(kernel_node, kernel_info_list);
-  }
-  if (kernel_info_list->empty()) {
-    MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << " kernel query failed!";
-  }
-}
-
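
KernelQueryAll above is a priority chain: TBE is asked first, then AICPU (tagging the node when that succeeds), then the runtime kernels, then HCCL, and an exception fires if every backend comes back empty. A reduced sketch of that control flow, with std::string standing in for the build-info pointer and the providers supplied by the caller (all names here are illustrative, not MindSpore's):

#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

using KernelInfo = std::string;  // stand-in for std::shared_ptr<KernelBuildInfo>
using Provider = std::function<void(std::vector<KernelInfo> *)>;

// Ask each backend in priority order; the first one that yields candidates wins.
void QueryAll(const std::vector<Provider> &providers, std::vector<KernelInfo> *out) {
  for (const auto &provider : providers) {
    provider(out);
    if (!out->empty()) {
      return;
    }
  }
  throw std::runtime_error("kernel query failed");  // mirrors MS_LOG(EXCEPTION)
}
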
-void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
-                 KernelType kernel_type) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-
-  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
-
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
-    kernel_type = KernelType::AKG_KERNEL;
-  }
-
-  switch (kernel_type) {
-    case KernelType::AKG_KERNEL:
-      AkgMetadataInfo(kernel_node, kernel_info_list);
-      break;
-    default:
-      KernelQueryAll(kernel_node, kernel_info_list);
-      break;
-  }
-
-  if (kernel_info_list->empty()) {
-    MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query failed!";
-  }
-  // check output
-  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
-}
-
-void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(kernel_info_list);
-  kernel_info_list->clear();
-  AicpuMetadataInfo(kernel_node, kernel_info_list);
-  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
-}
-bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  auto cnode = kernel_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  AICPUQuery(cnode, &kernel_info_list);
-  return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
-                     [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
-                       MS_EXCEPTION_IF_NULL(item);
-                       return *item == *select_kernel_build_info;
-                     });
-}
-
-bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
-  MS_EXCEPTION_IF_NULL(kernel_node);
-  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  auto cnode = kernel_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  TbeMetadataInfo(cnode, &kernel_info_list);
-  return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
-                     [&select_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
-                       MS_EXCEPTION_IF_NULL(item);
-                       return *item == *select_kernel_build_info;
-                     });
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/kernel_query.h b/mindspore/ccsrc/kernel/kernel_query.h
deleted file mode 100644
index 257b0cf073..0000000000
--- a/mindspore/ccsrc/kernel/kernel_query.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
-#define MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
-
-#include
-#include
-#include
-#include "kernel/kernel.h"
-#include "kernel/kernel_build_info.h"
-
-namespace mindspore {
-namespace kernel {
-void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
-                 KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
-void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
-bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
-bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info);
-}  // namespace kernel
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_KERNEL_KERNEL_QUERY_H_
diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h
deleted file mode 100644
index 990702d100..0000000000
--- a/mindspore/ccsrc/kernel/oplib/opinfo.h
+++ /dev/null
@@ -1,175 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_
-#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_
-#include
-#include
-#include
-#include
-#include "ir/dtype.h"
-#include "kernel/kernel.h"
-
-namespace mindspore {
-namespace kernel {
-enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU };
-enum OpIOType { kInput = 0, kOutput };
-
-class OpAttr {
- public:
-  OpAttr() = default;
-  ~OpAttr() = default;
-
-  std::string name() const { return name_; }
-  std::string param_type() const { return param_type_; }
-  std::string type() const { return type_; }
-  std::string value() const { return value_; }
-  std::string default_value() const { return default_value_; }
-
-  void set_name(const std::string &name) { name_ = name; }
-  void set_param_type(const std::string &param_type) { param_type_ = param_type; }
-  void set_type(const std::string &type) { type_ = type; }
-  void set_value(const std::string &value) { value_ = value; }
-  void set_default_value(const std::string &default_value) { default_value_ = default_value; }
-
- private:
-  std::string name_;
-  std::string param_type_;
-  std::string type_;
-  std::string value_;
-  std::string default_value_;
-};
-
-class OpIOInfo {
- public:
-  OpIOInfo() = default;
-  ~OpIOInfo() = default;
-
-  int index() const { return index_; }
-  std::string name() const { return name_; }
-  bool need_compile() const { return need_compile_; }
-  std::string param_type() const { return param_type_; }
-  std::string reshape_type() const { return reshape_type_; }
-  std::string shape() const { return shape_; }
-  std::vector<std::string> dtypes() const { return dtypes_; }
-  std::vector<std::string> formats() const { return formats_; }
-
-  void set_index(const int index) { index_ = index; }
-  void set_name(const std::string &name) { name_ = name; }
-  void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
-  void set_param_type(const std::string &param_type) { param_type_ = param_type; }
-  void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
-  void set_shape(const std::string &shape) { shape_ = shape; }
-  void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
-  void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
-
- private:
-  int index_ = 0;
-  std::string name_;
-  bool need_compile_ = false;
-  std::string param_type_;
-  std::string reshape_type_;
-  std::string shape_;
-  std::vector<std::string> dtypes_;
-  std::vector<std::string> formats_;
-};
-
-class OpInfo {
- public:
-  OpInfo() = default;
-  OpInfo(const OpInfo &opinfo) {
-    op_name_ = opinfo.op_name();
-    imply_type_ = opinfo.imply_type();
-
-    impl_path_ = opinfo.impl_path();
-    fusion_type_ = opinfo.fusion_type();
-    async_flag_ = opinfo.async_flag_;
-    binfile_name_ = opinfo.binfile_name_;
-    compute_cost_ = opinfo.compute_cost_;
-    kernel_name_ = opinfo.kernel_name();
-    partial_flag_ = opinfo.partial_flag_;
-    dynamic_format_ = opinfo.dynamic_format_;
-    op_pattern_ = opinfo.op_pattern();
-    processor_ = opinfo.processor_;
-    for (const auto &attr : opinfo.attrs_ptr()) {
-      attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
-    }
-    for (const auto &input : opinfo.inputs_ptr()) {
-      inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input));
-    }
-    for (const auto &output : opinfo.outputs_ptr()) {
-      outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output));
-    }
-    ref_infos_ = opinfo.ref_infos();
-  }
-  ~OpInfo() = default;
-  std::string op_name() const { return op_name_; }
-  OpImplyType imply_type() const { return imply_type_; }
-  std::string impl_path() const { return impl_path_; }
-  std::string fusion_type() const { return fusion_type_; }
-  std::string kernel_name() const { return kernel_name_; }
-  OpPattern op_pattern() const { return op_pattern_; }
-  std::string processor() const { return processor_; }
-  std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
-  const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
-
-  void set_op_name(const std::string &op_name) { op_name_ = op_name; }
-  void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
-  void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
-  void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; }
-  void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
-  void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; }
-  void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
-  void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
-  void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
-  void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
-  void set_processor(const std::string &processor) { processor_ = processor; }
-  void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
-  void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
-  void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
-  bool is_ref() const { return !ref_infos_.empty(); }
-  bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
-  void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
-  void ClearInputs() { (void)inputs_ptr_.clear(); }
-  void ClearOutputs() { (void)outputs_ptr_.clear(); }
-  bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
-    return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
-           this->processor_ == other_info->processor_;
-  }
-
- private:
-  std::string op_name_;
-  OpImplyType imply_type_ = kTBE;
-  std::string impl_path_;
-  std::string fusion_type_;
-  bool async_flag_ = false;
-  std::string binfile_name_;
-  int compute_cost_ = 0;
-  std::string kernel_name_;
-  bool partial_flag_ = false;
-  bool dynamic_format_ = false;
-  OpPattern op_pattern_ = kCommonPattern;
-  std::string processor_;
-  std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
-  std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
-  std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
-  std::unordered_map<size_t, size_t> ref_infos_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_KERNEL_OPLIB_OPINFO_H_
diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc
deleted file mode 100644
index 5b322c12a4..0000000000
--- a/mindspore/ccsrc/kernel/oplib/oplib.cc
+++ /dev/null
@@ -1,390 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/oplib/oplib.h"
-#include
-#include
-#include
-#include
-#include
-#include "utils/log_adapter.h"
-#include "utils/overload.h"
-#include "utils/context/ms_context.h"
-
-namespace mindspore {
-namespace kernel {
-constexpr auto kImplyType = "imply_type";
-constexpr auto kOpName = "op_name";
-constexpr auto kFusionType = "fusion_type";
-constexpr auto kAsyncFlag = "async_flag";
-constexpr auto kBinfileName = "binfile_name";
-constexpr auto kComputeCost = "compute_cost";
-constexpr auto kKernelName = "kernel_name";
-constexpr auto kPartialFlag = "partial_flag";
-constexpr auto kReshapeType = "reshape_type";
-constexpr auto kOpPattern = "op_pattern";
-constexpr auto kDynamicFormat = "dynamicFormat";
-constexpr auto kFormatAgnostic = "formatAgnostic";
-constexpr auto kBroadcast = "broadcast";
-constexpr auto kReduce = "reduce";
-constexpr auto kDtypeFormat = "dtype_format";
-constexpr auto kAttr = "attr";
-constexpr auto kIputs = "inputs";
-constexpr auto kOutputs = "outputs";
-constexpr auto kAiCPU = "AiCPU";
-constexpr auto kAiCore = "AiCore";
-constexpr auto kCUDA = "CUDA";
-constexpr auto kTbe = "TBE";
-constexpr auto kAkg = "AKG";
-constexpr auto kName = "name";
-constexpr auto kParamType = "param_type";
-constexpr auto kDtype = "dtype";
-constexpr auto kType = "type";
-constexpr auto kValue = "value";
-constexpr auto kDefaultValue = "default_value";
-constexpr auto kIndex = "index";
-constexpr auto kFormat = "format";
-constexpr auto kNeedCompile = "need_compile";
-constexpr auto kShape = "shape";
-constexpr auto kProcessor = "processor";
-std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
-
-static std::string ImplTypeToStr(OpImplyType impl_type) {
-  switch (impl_type) {
-    case kTBE:
-      return kTbe;
-    case kAKG:
-      return kAkg;
-    case kAICPU:
-      return kAiCPU;
-    default:
-      return "unknown";
-  }
-}
-bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) {
-  bool ret = false;
-  try {
-    auto op_json = nlohmann::json::parse(json_string);
-    std::string imply_type_string = op_json.at(kImplyType);
-    std::string op_name = op_json.at(kOpName);
-    if (imply_type_string == kTbe) {
-      OpImplyType imply_type = kTBE;
-      ret = DecodeOpInfo(op_json, imply_type, impl_path);
-    } else if (imply_type_string == kAkg) {
-      OpImplyType imply_type = kAKG;
-      ret = DecodeOpInfo(op_json, imply_type, impl_path);
-    } else if (imply_type_string == kAiCPU) {
-      OpImplyType imply_type = kAICPU;
-      ret = DecodeOpInfo(op_json, imply_type, impl_path);
-    } else {
-      MS_LOG(ERROR) << "Unsupported imply_type";
-    }
-    if (!ret) {
-      MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string;
-    }
-  } catch (const std::exception &e) {
-    MS_LOG(ERROR) << "Get op json elements failed: " << e.what();
-  }
-  return ret;
-}
-
-void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
-  const std::map<std::string, OpPattern> kOpPatternMap = {{kFormatAgnostic, kFormatAgnosticPattern},
-                                                          {kBroadcast, kBroadcastPattern},
-                                                          {kReduce, kReducePattern},
-                                                          {kDynamicFormat, kDynamicFormatPattern}};
-  MS_EXCEPTION_IF_NULL(op_info);
-  op_info->set_async_flag(obj.at(kAsyncFlag));
-  op_info->set_binfile_name(obj.at(kBinfileName));
-  op_info->set_compute_cost(obj.at(kComputeCost));
-  op_info->set_kernel_name(obj.at(kKernelName));
-  op_info->set_partial_flag(obj.at(kPartialFlag));
-
-  if (obj.find(kOpPattern) != obj.end()) {
-    std::string op_pattern = obj.at(kOpPattern);
-    auto find_iter = kOpPatternMap.find(op_pattern);
-    if (find_iter == kOpPatternMap.end()) {
-      if (!op_pattern.empty()) {
-        MS_LOG(WARNING) << "Invalid op pattern value: " << op_pattern;
-      }
-      op_info->set_op_pattern(kCommonPattern);
-    } else {
-      op_info->set_op_pattern(find_iter->second);
-    }
-  }
-}
-
-void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
-  MS_EXCEPTION_IF_NULL(op_info);
-  op_info->set_processor(obj.at(kProcessor));
-}
-
-bool OpLib::RegOpFromLocalInfo() {
-  MS_LOG(INFO) << "Start";
-  static bool has_load = false;
-  if (has_load) {
-    return true;
-  }
-  has_load = true;
-  std::string dir = common::GetEnv("MINDSPORE_OP_INFO_PATH");
-  if (dir.empty()) {
-    MS_LOG(INFO) << "MindSpore op info path has not been set; use op info from the Python pass.";
-    return true;
-  }
-  char real_path[PATH_MAX] = {0};
-  if (dir.size() >= PATH_MAX) {
-    MS_LOG(ERROR) << "Op info path is invalid: " << dir;
-    return false;
-  }
-#if defined(_WIN32) || defined(_WIN64)
-  if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) {
-    MS_LOG(ERROR) << "Op info path is invalid: " << dir;
-    return false;
-  }
-#else
-  if (realpath(common::SafeCStr(dir), real_path) == nullptr) {
-    MS_LOG(ERROR) << "Op info path is invalid: " << dir;
-    return false;
-  }
-#endif
-  MS_LOG(INFO) << "Start to read op info from local file.";
-  std::ifstream file(real_path);
-  if (!file.is_open()) {
-    MS_LOG(ERROR) << "Find op info file failed.";
-    return false;
-  }
-  std::string line;
-  while (getline(file, line)) {
-    if (!line.empty()) {
-      (void)OpLib::RegOp(line, "");
-    }
-  }
-  MS_LOG(INFO) << "End";
-  return true;
-}
-
-bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type,
-                         const std::string &impl_path) {
-  std::shared_ptr<OpInfo> op_info = std::make_shared<OpInfo>();
-  MS_EXCEPTION_IF_NULL(op_info);
-  op_info->set_op_name(obj.at(kOpName));
-  op_info->set_impl_path(impl_path);
-  op_info->set_imply_type(imply_type);
-  op_info->set_fusion_type(obj.at(kFusionType));
-  if (imply_type == kTBE) {
-    DecodeTBESpecificInfo(obj, op_info);
-  } else if (imply_type == kAKG) {
-    DecodeAKGSpecificInfo(obj, op_info);
-  }
-  auto attrs = obj.at(kAttr);
-  for (const auto &attr : attrs) {
-    if (!DecodeAttr(attr, imply_type, op_info)) {
-      MS_LOG(ERROR) << "DecodeAttr failed";
-      return false;
-    }
-  }
-  nlohmann::json dtype_format;
-  if (obj.find(kDtypeFormat) != obj.end()) {
-    dtype_format = obj.at(kDtypeFormat);
-  }
-  auto inputs = obj.at(kIputs);
-  for (const auto &input : inputs) {
-    if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) {
-      MS_LOG(ERROR) << "DecodeInputOutput failed";
-      return false;
-    }
-  }
-  auto outputs = obj.at(kOutputs);
-  for (const auto &output : outputs) {
-    if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) {
-      MS_LOG(ERROR) << "DecodeInputOutput failed";
-      return false;
-    }
-  }
-  if (CheckRepetition(op_info)) {
op name: " << op_info->op_name() - << ", impl type: " << ImplTypeToStr(op_info->imply_type()) - << ", impl path: " << op_info->impl_path(); - return true; - } - if (!GetRefInfo(op_info)) { - MS_LOG(ERROR) << "GetRefInfo Failed"; - return false; - } - op_info_.push_back(op_info); - return true; -} - -bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - bool ret = true; - try { - std::shared_ptr op_attr = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_attr); - op_attr->set_name(obj.at(kName)); - if (imply_type != kAICPU) { - op_attr->set_param_type(obj.at(kParamType)); - } - op_attr->set_type(obj.at(kType)); - if (imply_type == kTBE) { - op_attr->set_value(obj.at(kValue)); - } - if (obj.find(kDefaultValue) != obj.end()) { - op_attr->set_default_value(obj.at(kDefaultValue)); - } - op_info->add_attrs_ptr(op_attr); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); - ret = false; - } - return ret; -} - -bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, - size_t index) { - MS_EXCEPTION_IF_NULL(op_io); - bool ret = true; - try { - std::vector dtype; - std::vector format; - for (const auto &it : dtype_format) { - dtype.emplace_back(it[index][0]); - format.emplace_back(it[index][1]); - } - op_io->set_dtypes(dtype); - op_io->set_formats(format); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); - ret = false; - } - return ret; -} - -bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { - MS_EXCEPTION_IF_NULL(op_info); - bool ret = true; - try { - std::shared_ptr op_io = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_io); - op_io->set_index(obj.at(kIndex)); - op_io->set_name(obj.at(kName)); - if (!dtype_format.empty()) { - if (!DecodeDtypeFormat(dtype_format, op_io, op_info->inputs_ptr().size() + op_info->outputs_ptr().size())) { - MS_LOG(ERROR) << "Decode dtype format failed"; - return false; - } - } else { - op_io->set_dtypes(obj.at(kDtype)); - op_io->set_formats(obj.at(kFormat)); - } - if (op_io->dtypes().size() != op_io->formats().size()) { - MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() - << " is not equal to format size: " << op_io->formats(); - return false; - } - if (obj.find(kParamType) != obj.end()) { - op_io->set_param_type(obj.at(kParamType)); - } - if (imply_type == kTBE) { - if (obj.find(kNeedCompile) != obj.end()) { - op_io->set_need_compile(obj.at(kNeedCompile)); - } - if (obj.find(kShape) != obj.end()) { - op_io->set_shape(obj.at(kShape)); - } - if (obj.find(kReshapeType) != obj.end()) { - op_io->set_reshape_type(obj.at(kReshapeType)); - } - } - - if (io_type == kInput) { - op_info->add_inputs_ptr(op_io); - } else if (io_type == kOutput) { - op_info->add_outputs_ptr(op_io); - } - } catch (const std::exception &e) { - MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); - ret = false; - } - return ret; -} - -std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { - if (!OpLib::RegOpFromLocalInfo()) { - MS_LOG(INFO) << "Warning reg local op info failed."; - } - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool is_gpu = (context->device_target() == kGPUDevice); - if (is_gpu && (imply_type == kTBE || imply_type == kAICPU)) { - MS_LOG(ERROR) << 
"FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) - << ", current op num: " << op_info_.size(); - return nullptr; - } - for (const auto &op_info : op_info_) { - MS_EXCEPTION_IF_NULL(op_info); - if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { - auto akg_processor_match = [&]() { - return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore; - }; - if (imply_type != kAKG || akg_processor_match()) { - return op_info; - } - } - } - MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) - << ", current op num: " << op_info_.size(); - return nullptr; -} - -bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - const auto &output_infos = op_info->outputs_ptr(); - const auto &input_infos = op_info->inputs_ptr(); - for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { - MS_EXCEPTION_IF_NULL(output_infos[out_index]); - const auto &out_name = output_infos[out_index]->name(); - for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { - MS_EXCEPTION_IF_NULL(input_infos[in_index]); - const auto &in_name = input_infos[in_index]->name(); - if (out_name == in_name) { - if (op_info->has_ref_index(out_index)) { - MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; - return false; - } - op_info->add_ref_pair(out_index, in_index); - MS_LOG(INFO) << "add ref info, op name is " << op_info->op_name() << ", outindex is " << out_index - << ", in_index is " << in_index; - } - } - } - return true; -} - -bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - for (const auto &exist_op_info : op_info_) { - MS_EXCEPTION_IF_NULL(exist_op_info); - if (exist_op_info->equals_to(op_info)) { - return true; - } - } - return false; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h deleted file mode 100644 index 742b0977c7..0000000000 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ -#define MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ -#include -#include -#include -#include -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace kernel { -class OpLib { - public: - OpLib() = default; - virtual ~OpLib() = default; - static bool RegOp(const std::string &json_string, const std::string &impl_path); - static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace_back(opinfo); } - static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); - static const std::vector> &GetAllOpsInfo() { return op_info_; } - - protected: - static std::vector> op_info_; - - private: - static bool RegOpFromLocalInfo(); - static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); - static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, - const std::shared_ptr &op_info); - static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, - size_t index); - static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); - static void DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); - static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr &op_info, const nlohmann::json &dtype_format); - static bool GetRefInfo(const std::shared_ptr &op_info); - static bool CheckRepetition(const std::shared_ptr &op_info); -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_OPLIB_OPLIB_H_ diff --git a/mindspore/ccsrc/kernel/oplib/oploader.h b/mindspore/ccsrc/kernel/oplib/oploader.h deleted file mode 100644 index dd4c37e80b..0000000000 --- a/mindspore/ccsrc/kernel/oplib/oploader.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_OPLOADER_H -#define MINDSPORE_OPLOADER_H - -#include -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace kernel { -class OpInfoLoaderPy { - public: - OpInfoLoaderPy() = default; - - ~OpInfoLoaderPy() = default; - - size_t GetAllOpsInfo() { - auto ops = OpLib::GetAllOpsInfo(); - auto op_infos = new std::vector(); - for (auto op_info : ops) { - auto new_op_info = new OpInfo(*op_info); - op_infos->emplace_back(new_op_info); - } - return (size_t)op_infos; - } -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_OPLOADER_H diff --git a/mindspore/ccsrc/kernel/rts/assign.cc b/mindspore/ccsrc/kernel/rts/assign.cc deleted file mode 100644 index 7038004898..0000000000 --- a/mindspore/ccsrc/kernel/rts/assign.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
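
OpInfoLoaderPy above hands a heap allocation across the Python binding by casting the pointer to size_t; after that cast nothing in the type system tracks ownership, so whoever receives the integer must cast it back and free both the vector and its elements or everything leaks. A hedged sketch of both halves of that round trip, with a stub type standing in for OpInfo:

#include <cstddef>
#include <vector>

struct OpInfoStub { int id; };

// Ownership leaves the type system here, exactly as in GetAllOpsInfo.
size_t Export() {
  auto *infos = new std::vector<OpInfoStub *>();
  infos->push_back(new OpInfoStub{1});
  return reinterpret_cast<size_t>(infos);
}

// The receiving side must reverse the cast and delete everything.
void Free(size_t handle) {
  auto *infos = reinterpret_cast<std::vector<OpInfoStub *> *>(handle);
  for (auto *p : *infos) delete p;
  delete infos;
}
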
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/rts/assign.h"
-
-#include
-
-#include "runtime/mem.h"
-#include "common/utils.h"
-
-using ge::model_runner::MemcpyAsyncTaskInfo;
-using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
-
-namespace mindspore {
-namespace kernel {
-AssignKernel::AssignKernel() {}
-
-AssignKernel::~AssignKernel() {}
-
-bool AssignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
-                          const std::vector<AddressPtr> & /*outputs*/, void *stream_ptr) {
-  if (inputs.size() != 2) {
-    MS_LOG(ERROR) << "inputs size is not two";
-    return false;
-  }
-
-  if (inputs[0]->addr == inputs[1]->addr) {
-    MS_LOG(INFO) << "first addr is the same as second addr, no need to assign";
-    return true;
-  }
-  rtError_t status = rtMemcpyAsync(inputs[0]->addr, inputs[0]->size, inputs[1]->addr, inputs[1]->size,
-                                   RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
-  if (status != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "Assign op rtMemcpyAsync failed!";
-    return false;
-  }
-  return true;
-}
-
-std::vector<TaskInfoPtr> AssignKernel::GenTask(const std::vector<AddressPtr> &inputs,
-                                               const std::vector<AddressPtr> &workspace,
-                                               const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
-  if (inputs.size() != 2) {
-    MS_LOG(EXCEPTION) << "inputs size is not two";
-  }
-  stream_id_ = stream_id;
-
-  std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr =
-    std::make_shared<MemcpyAsyncTaskInfo>(kernel_name_, stream_id, inputs[0]->addr, inputs[0]->size, inputs[1]->addr,
-                                          inputs[1]->size, RT_MEMCPY_DEVICE_TO_DEVICE, false);
-  MS_EXCEPTION_IF_NULL(task_info_ptr);
-  return {task_info_ptr};
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/rts/assign.h b/mindspore/ccsrc/kernel/rts/assign.h
deleted file mode 100644
index 0e7e52d48f..0000000000
--- a/mindspore/ccsrc/kernel/rts/assign.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
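
AssignKernel::Launch above is a guarded device-to-device copy: it early-outs when source and destination alias and fails when it is not handed exactly two addresses. The same control flow on host memory, with memcpy standing in for rtMemcpyAsync purely for illustration:

#include <cstddef>
#include <cstring>

// Host analogue of the assign launch path: alias check first, then a bounds
// check, then the copy. Returns false on failure, like the kernel does.
bool Assign(void *dst, size_t dst_size, const void *src, size_t src_size) {
  if (dst == src) return true;             // same address: nothing to do
  if (dst_size < src_size) return false;   // destination too small
  std::memcpy(dst, src, src_size);
  return true;
}
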
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H -#define MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H - -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class AssignKernel : public RtKernel { - public: - AssignKernel(); - ~AssignKernel() override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; -}; - -MS_REG_RTKERNEL(assign, AssignKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_ASSIGN_H diff --git a/mindspore/ccsrc/kernel/rts/label_goto.cc b/mindspore/ccsrc/kernel/rts/label_goto.cc deleted file mode 100644 index 1d29bb4f35..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_goto.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/label_goto.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelGotoTaskInfo; -using LabelGotoTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelGotoKernel::LabelGotoKernel() { label_ = 0; } - -LabelGotoKernel::~LabelGotoKernel() {} - -bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelGotoKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); - MS_LOG(INFO) << "LabelGotoKernel get attr label:" << label_; - return true; -} - -bool LabelGotoKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelGotoKernel launch"; - return true; -} - -std::vector LabelGotoKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "LabelGotoKernel GenTask label:" << label_ << ", stream id:" << stream_id; - std::vector task_info_list; - std::shared_ptr task_info_ptr = - std::make_shared(kernel_name_, stream_id, label_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_goto.h b/mindspore/ccsrc/kernel/rts/label_goto.h deleted file mode 100644 index efccc12d6f..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_goto.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * 
- * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelGotoKernel : public RtKernel { - public: - LabelGotoKernel(); - ~LabelGotoKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t label_; -}; - -MS_REG_RTKERNEL(labelgoto, LabelGotoKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_GOTO_H diff --git a/mindspore/ccsrc/kernel/rts/label_set.cc b/mindspore/ccsrc/kernel/rts/label_set.cc deleted file mode 100644 index 4266e2b0af..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_set.cc +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
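
A pattern shared by the label kernels above and below: Launch is a host no-op, and the kernel's entire effect is the TaskInfo it emits for the Ascend task sink, parameterized by the label index pulled from the node's attributes at Init time. A reduced sketch of that split, with all types simplified to stand-ins:

#include <cstdint>
#include <memory>
#include <vector>

struct TaskInfoStub {
  uint32_t stream_id;
  uint32_t label;
};

struct LabelKernelSketch {
  uint32_t label_ = 0;               // filled in from kAttrLabelIndex at Init
  bool Launch() { return true; }     // host side does nothing
  std::vector<std::shared_ptr<TaskInfoStub>> GenTask(uint32_t stream_id) {
    // The generated task, not the Launch call, carries the control transfer.
    return {std::make_shared<TaskInfoStub>(TaskInfoStub{stream_id, label_})};
  }
};
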
- */ - -#include "kernel/rts/label_set.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelSetTaskInfo; -using LabelSetTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelSetKernel::LabelSetKernel() { label_ = 0; } - -LabelSetKernel::~LabelSetKernel() {} - -bool LabelSetKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelSetKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { - MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_ = GetValue(primitive->GetAttr(kAttrLabelIndex)); - MS_LOG(INFO) << "LabelSetKernel get attr label:" << label_; - return true; -} - -bool LabelSetKernel::Launch(const std::vector & /*inputs*/, const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelSetKernel launch"; - return true; -} - -std::vector LabelSetKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "LabelSetKernel GenTask label:" << label_ << ", stream id:" << stream_id; - std::vector task_info_list; - std::shared_ptr task_info_ptr = std::make_shared(kernel_name_, stream_id, label_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_set.h b/mindspore/ccsrc/kernel/rts/label_set.h deleted file mode 100644 index d05d81f898..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_set.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelSetKernel : public RtKernel { - public: - LabelSetKernel(); - ~LabelSetKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t label_; -}; - -MS_REG_RTKERNEL(labelset, LabelSetKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SET_H diff --git a/mindspore/ccsrc/kernel/rts/label_switch.cc b/mindspore/ccsrc/kernel/rts/label_switch.cc deleted file mode 100644 index bc5282b4af..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_switch.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/label_switch.h" -#include -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::LabelSwitchTaskInfo; -using LabelSwitchTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -LabelSwitchKernel::LabelSwitchKernel() { - label_list_ = {}; - cond_ = nullptr; - label_size_ = 0; -} - -LabelSwitchKernel::~LabelSwitchKernel() {} - -bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "LabelSwitchKernel init"; - auto cnode = anf_node->cast(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { - MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - label_list_ = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); - label_size_ = label_list_.size(); - MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_; - for (auto label : label_list_) { - MS_LOG(INFO) << "label: " << label; - } - return true; -} - -bool LabelSwitchKernel::Launch(const std::vector & /*inputs*/, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - MS_LOG(INFO) << "LabelSwitchKernel launch"; - return true; -} - -std::vector LabelSwitchKernel::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; - std::vector task_info_list; - cond_ = inputs[0]->addr; - auto task_info_ptr = std::make_shared(kernel_name_, stream_id, label_size_, label_list_, cond_); - 
MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - return task_info_list; -} - -std::vector> LabelSwitchDesc::GetKernelInfo() { - std::vector> label_switch_build_info{}; - vector input_format{kOpFormat_DEFAULT}; - vector input_type{kNumberTypeInt32}; - if (input_format.size() != input_type.size()) { - MS_LOG(EXCEPTION) << "Invalid param num, input_format size " << input_format.size() << " input_type size " - << input_type.size(); - } - for (size_t i = 0; i < input_format.size(); ++i) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - builder.SetInputsFormat({input_format[i]}); - builder.SetInputsDeviceType({input_type[i]}); - builder.SetProcessor(AICORE); - builder.SetKernelType(RT_KERNEL); - builder.SetFusionType(OPAQUE); - label_switch_build_info.emplace_back(builder.Build()); - } - return label_switch_build_info; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/label_switch.h b/mindspore/ccsrc/kernel/rts/label_switch.h deleted file mode 100644 index 858f851b2a..0000000000 --- a/mindspore/ccsrc/kernel/rts/label_switch.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class LabelSwitchKernel : public RtKernel { - public: - LabelSwitchKernel(); - ~LabelSwitchKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - std::vector label_list_; - uint32_t label_size_; - void *cond_; -}; - -class LabelSwitchDesc : public RtKerDesc { - public: - LabelSwitchDesc() = default; - ~LabelSwitchDesc() override = default; - std::vector> GetKernelInfo() override; -}; - -MS_REG_RTKERNEL_DESC(labelswitch, LabelSwitchDesc); -MS_REG_RTKERNEL(labelswitch, LabelSwitchKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_LABEL_SWITCH_H diff --git a/mindspore/ccsrc/kernel/rts/memcpy_async.cc b/mindspore/ccsrc/kernel/rts/memcpy_async.cc deleted file mode 100644 index ea33c4dd8b..0000000000 --- a/mindspore/ccsrc/kernel/rts/memcpy_async.cc +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
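
LabelSwitchDesc::GetKernelInfo above emits one build info per input format, and MemCpyAsyncDesc::GetKernelInfo below crosses a format list with a dtype list, building one KernelBuildInfo per pair. The same enumeration, reduced to plain pairs:

#include <string>
#include <utility>
#include <vector>

// One entry per (format, dtype) combination, mirroring the nested loops in
// MemCpyAsyncDesc::GetKernelInfo; int stands in for TypeId.
std::vector<std::pair<std::string, int>> EnumerateBuildInfos(
    const std::vector<std::string> &formats, const std::vector<int> &dtypes) {
  std::vector<std::pair<std::string, int>> infos;
  for (const auto &format : formats) {
    for (int type : dtypes) {
      infos.emplace_back(format, type);
    }
  }
  return infos;
}
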
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "kernel/rts/memcpy_async.h"
-
-#include
-#include
-
-#include "runtime/mem.h"
-#include "common/utils.h"
-#include "session/anf_runtime_algorithm.h"
-#include "common/trans.h"
-#include "utils/context/ms_context.h"
-
-using ge::model_runner::MemcpyAsyncTaskInfo;
-using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
-
-namespace mindspore {
-namespace kernel {
-MemCpyAsyncKernel::MemCpyAsyncKernel() {}
-
-MemCpyAsyncKernel::~MemCpyAsyncKernel() {}
-
-bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
-                               const std::vector<AddressPtr> &outputs, void *stream_ptr) {
-  if (inputs.size() != 1) {
-    MS_LOG(ERROR) << "inputs size is not one";
-    return false;
-  }
-  if (outputs.size() != 1) {
-    MS_LOG(ERROR) << "outputs size is not one";
-    return false;
-  }
-
-  if (inputs[0]->addr == outputs[0]->addr) {
-    MS_LOG(INFO) << "input addr is the same as output addr, no need to execute memcpy async";
-    return true;
-  }
-  if (outputs[0]->size < inputs[0]->size) {
-    MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
-  }
-  // input x -> memcpy_async -> AllReduce
-  if (outputs[0]->size > inputs[0]->size) {
-    MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
-  }
-  rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
-                                   RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
-  if (status != RT_ERROR_NONE) {
-    MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!";
-    return false;
-  }
-  return true;
-}
-
-bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  GetInputOutputDataType(anf_node);
-  GetInputOutputTotalCount(anf_node);
-  return true;
-}
-
-void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
-  if (input_size != 1) {
-    MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
-  }
-  input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
-}
-
-void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
-  if (input_size != 1) {
-    MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
-  }
-  size_t type_size = trans::TypeIdSize(input_type_id_);
-  std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
-  size_t total_size = 1;
-  for (size_t i = 0; i < shape_i.size(); i++) {
-    total_size = total_size * shape_i[i];
-  }
-  total_size *= type_size;
-  MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]";
-  input_size_list_.emplace_back(total_size);
-  output_size_list_.emplace_back(total_size);
-}
-
-std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr> &inputs,
-                                                    const std::vector<AddressPtr> &,
-                                                    const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
-  if (inputs.size() != 1) {
-    MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one";
-  }
-
-  if (outputs.size() != 1) {
-    MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
-  }
-
-  if (outputs[0]->size < inputs[0]->size) {
size"; - } - // input x -> memcpy_async -> AllReduce - if (outputs[0]->size > inputs[0]->size) { - MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; - } - - stream_id_ = stream_id; - std::shared_ptr task_info_ptr = - std::make_shared(kernel_name_, stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, - inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE, NeedDump()); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} - -const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, - kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, - kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16, - kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool}; -const std::vector format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, - kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, - kOpFormat_C1HWNCoC0}; - -MemCpyAsyncDesc::MemCpyAsyncDesc() {} - -MemCpyAsyncDesc::~MemCpyAsyncDesc() {} - -std::vector> MemCpyAsyncDesc::GetKernelInfo() { - std::vector> memcpy_build_info{}; - for (const auto &format : format_list) { - for (const auto &type : data_type_list) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - vector input_format{format}; - vector input_type{type}; - vector output_format{format}; - vector output_type{type}; - builder.SetInputsFormat(input_format); - builder.SetInputsDeviceType(input_type); - builder.SetOutputsFormat(output_format); - builder.SetOutputsDeviceType(output_type); - builder.SetProcessor(AICORE); - builder.SetKernelType(RT_KERNEL); - builder.SetFusionType(OPAQUE); - memcpy_build_info.emplace_back(builder.Build()); - } - } - return memcpy_build_info; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/memcpy_async.h b/mindspore/ccsrc/kernel/rts/memcpy_async.h deleted file mode 100644 index 94bbf1ca1c..0000000000 --- a/mindspore/ccsrc/kernel/rts/memcpy_async.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H -#define MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class MemCpyAsyncKernel : public RtKernel { - public: - MemCpyAsyncKernel(); - ~MemCpyAsyncKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - void GetInputOutputDataType(const AnfNodePtr &anf_node); - void GetInputOutputTotalCount(const AnfNodePtr &anf_node); - TypeId input_type_id_{}; -}; - -class MemCpyAsyncDesc : public RtKerDesc { - public: - MemCpyAsyncDesc(); - ~MemCpyAsyncDesc() override; - std::vector> GetKernelInfo() override; -}; - -MS_REG_RTKERNEL_DESC(memcpy_async, MemCpyAsyncDesc); -MS_REG_RTKERNEL(memcpy_async, MemCpyAsyncKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_MEMCPY_ASYNC_H diff --git a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc deleted file mode 100644 index 0161e8562a..0000000000 --- a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
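
MS_REG_RTKERNEL and MS_REG_RTKERNEL_DESC, used throughout these files, follow the usual static-registration idiom: a namespace-scope object whose constructor inserts a factory into a global registry, so merely linking the translation unit registers the kernel. A sketch of that idiom under our own names (the macro and registry here are illustrative, not MindSpore's actual implementation):

#include <functional>
#include <map>
#include <memory>
#include <string>

struct Kernel { virtual ~Kernel() = default; };

// Meyers-singleton registry avoids static-initialization-order problems.
std::map<std::string, std::function<std::unique_ptr<Kernel>()>> &Registry() {
  static std::map<std::string, std::function<std::unique_ptr<Kernel>()>> r;
  return r;
}

struct Registrar {
  Registrar(const std::string &name, std::function<std::unique_ptr<Kernel>()> f) {
    Registry()[name] = std::move(f);
  }
};

#define REG_KERNEL(NAME, CLASS) \
  static Registrar g_reg_##CLASS(NAME, [] { return std::unique_ptr<Kernel>(new CLASS); })

struct MemCpyAsyncStub : Kernel {};
REG_KERNEL("memcpy_async", MemCpyAsyncStub);  // registered at load time
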
- */ - -#include "kernel/rts/profiling_kernel_mod.h" - -#include -#include -#include - -#include "framework/ge_runtime/task_info.h" -#include "device/ascend/profiling/profiling_utils.h" -#include "session/anf_runtime_algorithm.h" - -using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; -using mindspore::device::ascend::ProfilingUtils; - -namespace mindspore { -namespace kernel { -bool ProfilingKernelMod::Init(const AnfNodePtr &anf_node) { - MS_LOG(INFO) << "[profiling] init profiling kernel mod"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - - ValuePtr notify_ptr = primitive->GetAttr(ProfilingUtils::kNotify); - MS_EXCEPTION_IF_NULL(notify_ptr); - - ValuePtr log_id_ptr = primitive->GetAttr(ProfilingUtils::kProfilerTraceId); - MS_EXCEPTION_IF_NULL(log_id_ptr); - - ValuePtr flags_ptr = primitive->GetAttr(ProfilingUtils::kFlags); - MS_EXCEPTION_IF_NULL(flags_ptr); - - notify_ = GetValue(notify_ptr); - log_id_ = GetValue(log_id_ptr); - flags_ = GetValue(flags_ptr); - MS_LOG(INFO) << "[profiling] profiling kernel notify_:" << notify_ << ", log_id_:" << log_id_ - << ", flags_:" << flags_; - return true; -} - -bool ProfilingKernelMod::Launch(const std::vector & /*inputs*/, - const std::vector & /*workspace*/, - const std::vector & /*outputs*/, void * /*stream_ptr*/) { - return true; -} - -std::vector ProfilingKernelMod::GenTask(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) { - MS_LOG(INFO) << "gen task inputs size:" << inputs.size() << ", workspace size:" << workspace.size() - << ", outputs size:" << outputs.size(); - stream_id_ = stream_id; - std::shared_ptr task_info_ptr = - std::make_shared(kernel_name_, stream_id, log_id_, notify_, flags_); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h b/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h deleted file mode 100644 index f77f3b5c67..0000000000 --- a/mindspore/ccsrc/kernel/rts/profiling_kernel_mod.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ -#define MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ -#include -#include "kernel/rts/rt_kernel.h" -namespace mindspore { -namespace kernel { -class ProfilingKernelMod : public RtKernel { - public: - ProfilingKernelMod() = default; - ~ProfilingKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - bool Init(const AnfNodePtr &anf_node) override; - - private: - uint64_t log_id_{0}; - bool notify_{true}; - uint32_t flags_{0}; -}; -MS_REG_RTKERNEL(profiling, ProfilingKernelMod); -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_RTS_PROFILING_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/rts/recv.cc b/mindspore/ccsrc/kernel/rts/recv.cc deleted file mode 100644 index 3fb2fd6bb5..0000000000 --- a/mindspore/ccsrc/kernel/rts/recv.cc +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/recv.h" -#include -#include "runtime/stream.h" -#include "utils/context/ms_context.h" -#include "device/ascend/ascend_stream_assign.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -using ge::model_runner::EventWaitTaskInfo; -using mindspore::device::ascend::AscendStreamAssign; -using EventWaitTaskInfoPtr = std::shared_ptr; - -RecvKernel::RecvKernel() { event_id_ = 0; } - -RecvKernel::~RecvKernel() {} - -bool RecvKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { - MS_LOG(EXCEPTION) << "RecvKernel has no attr kAttrEventId"; - } - event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); - MS_LOG(INFO) << "recv op event_id_:" << event_id_; - return true; -} - -bool RecvKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - rtEvent_t stream_event{}; - auto status = rtStreamWaitEvent(stream_ptr, stream_event); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Recv rtStreamWaitEvent failed!"; - return false; - } - return true; -} - -std::vector RecvKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "RecvKernel GenTask event_id_:" << event_id_ << ", stream_id_:" << stream_id; - stream_id_ = stream_id; - EventWaitTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace 
mindspore diff --git a/mindspore/ccsrc/kernel/rts/recv.h b/mindspore/ccsrc/kernel/rts/recv.h deleted file mode 100644 index 68f0b69cc5..0000000000 --- a/mindspore/ccsrc/kernel/rts/recv.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RECV_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RECV_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class RecvKernel : public RtKernel { - public: - RecvKernel(); - ~RecvKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t event_id_; -}; - -MS_REG_RTKERNEL(recv, RecvKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RECV_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel.cc b/mindspore/ccsrc/kernel/rts/rt_kernel.cc deleted file mode 100644 index 9e81372383..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
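// Note: RecvKernel above waits on an event that a matching SendKernel records (send.cc appears
// later in this diff); the pair is tied together by the shared kAttrEventId attribute. The
// following sketch illustrates the record/wait semantics with a condition variable in place of
// rtEventRecord/rtStreamWaitEvent; it is illustrative only, not how the device runtime works.
#include <condition_variable>
#include <mutex>

class EventSketch {
 public:
  void Record() {  // Send side: mark the event as fired
    { std::lock_guard<std::mutex> lock(mutex_); fired_ = true; }
    cv_.notify_all();
  }
  void Wait() {  // Recv side: block until the event fires
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return fired_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool fired_ = false;
};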
- */ - -#include "kernel/rts/rt_kernel.h" - -namespace mindspore { -namespace kernel { -void RtKernelFactory::Registe(const std::string &name, RtKernelCreater &&fun) { - (void)fmap_.emplace(name, std::move(fun)); -} - -std::shared_ptr RtKernelFactory::Create(const std::string &name) { - const auto &map = Get().fmap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -RtKernelFactory &RtKernelFactory::Get() { - static RtKernelFactory _this; - return _this; -} - -RtKernel::RtKernel() {} - -RtKernel::~RtKernel() {} - -bool RtKernel::Init(const mindspore::AnfNodePtr & /*anf_node*/) { return true; } - -const std::vector &RtKernel::GetInputSizeList() const { return input_size_list_; } - -const std::vector &RtKernel::GetOutputSizeList() const { return output_size_list_; } - -const std::vector &RtKernel::GetWorkspaceSizeList() const { return workspace_size_list_; } -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel.h b/mindspore/ccsrc/kernel/rts/rt_kernel.h deleted file mode 100644 index 44d55dca31..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H - -#include -#include -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/task_stream.h" - -namespace mindspore { -namespace kernel { -class RtKernel : public AscendKernelMod { - public: - RtKernel(); - ~RtKernel() override; - virtual bool Init(const AnfNodePtr &anf_node); - const std::vector &GetInputSizeList() const override; - const std::vector &GetOutputSizeList() const override; - const std::vector &GetWorkspaceSizeList() const override; - - protected: - mutable std::vector input_size_list_; - mutable std::vector output_size_list_; - mutable std::vector workspace_size_list_; -}; - -using RTKernelPtr = std::shared_ptr; - -using RtKernelCreater = std::function()>; -class RtKernelFactory { - RtKernelFactory() = default; - ~RtKernelFactory() = default; - - public: - static RtKernelFactory &Get(); - void Registe(const std::string &name, RtKernelCreater &&fun); - static std::shared_ptr Create(const std::string &name); - - private: - std::map fmap_; -}; - -class _RtKernelRegister { - public: - _RtKernelRegister(const std::string &name, RtKernelCreater &&fun) { - RtKernelFactory::Get().Registe(name, std::move(fun)); - } - ~_RtKernelRegister() = default; -}; - -#define _MS_REG_RTKERNEL_REG(KNAME, clazz) \ - static_assert(std::is_base_of::value, " must be base of RtKernel"); \ - static const _RtKernelRegister g_##KNAME##_##_RtKernel_reg(#KNAME, []() { return std::make_shared(); }); - -#define MS_REG_RTKERNEL(KNAME, clazz) _MS_REG_RTKERNEL_REG(KNAME, clazz) -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc b/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc deleted file mode 100644 index 164605fe9b..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_build.cc +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
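// Note: rt_kernel.{h,cc} above implement a self-registering factory: each MS_REG_RTKERNEL
// expansion defines a static _RtKernelRegister whose constructor inserts a creator lambda
// into the factory map before main() runs, so new kernels need no central registration list.
// A minimal sketch of the same pattern with generic names (FactorySketch etc. are stand-ins):
#include <functional>
#include <map>
#include <memory>
#include <string>

class KernelBase {
 public:
  virtual ~KernelBase() = default;
};

class FactorySketch {
 public:
  using Creator = std::function<std::shared_ptr<KernelBase>()>;
  static FactorySketch &Get() {
    static FactorySketch instance;
    return instance;
  }
  void Register(const std::string &name, Creator creator) { (void)fmap_.emplace(name, std::move(creator)); }
  std::shared_ptr<KernelBase> Create(const std::string &name) const {
    auto it = fmap_.find(name);
    return (it != fmap_.end() && it->second) ? it->second() : nullptr;
  }

 private:
  std::map<std::string, Creator> fmap_;
};

struct RegistrarSketch {
  RegistrarSketch(const std::string &name, FactorySketch::Creator creator) {
    FactorySketch::Get().Register(name, std::move(creator));
  }
};

#define REG_KERNEL_SKETCH(NAME, CLAZZ) \
  static const RegistrarSketch g_##NAME##_reg(#NAME, [] { return std::make_shared<CLAZZ>(); })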
- */ - -#include "kernel/rts/rt_kernel_build.h" - -#include -#include -#include -#include - -#include "kernel/rts/rt_kernel.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -KernelModPtr RtOpBuild(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - (void)std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); - MS_LOG(INFO) << "Op Name(tolower)[" << op_name << "]"; - auto ker_ptr = RtKernelFactory::Create(op_name); - MS_EXCEPTION_IF_NULL(ker_ptr); - if (!ker_ptr->Init(anf_node)) { - MS_LOG(ERROR) << "Rt Op initialize failed!"; - return nullptr; - } - - return ker_ptr; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_build.h b/mindspore/ccsrc/kernel/rts/rt_kernel_build.h deleted file mode 100644 index cbd674b751..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_build.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H - -#include -#include -#include "kernel/kernel.h" -namespace mindspore { -namespace kernel { -KernelModPtr RtOpBuild(const AnfNodePtr &anf_node); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_BUILD_H diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc b/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc deleted file mode 100755 index 14f5a60a07..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_info.cc +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/rts/rt_kernel_info.h" -#include -#include -#include "utils/convert_utils.h" -#include "utils/utils.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace kernel { -void RtKerDescFactory::Register(const std::string &name, RtKerDescCreater &&fun) { - if (fmap_.find(name) == fmap_.end()) { - (void)fmap_.emplace(name, std::move(fun)); - } -} - -std::shared_ptr RtKerDescFactory::Create(const std::string &name) { - const auto &map = Get().fmap_; - auto it = map.find(name); - if (it != map.end() && it->second) { - return (it->second)(); - } - return nullptr; -} - -RtKerDescFactory &RtKerDescFactory::Get() { - static RtKerDescFactory _this; - return _this; -} - -static bool IsDefaultKernelInfo(const std::string &name) { - static const std::set white_list = {kStreamSwitchOpName, kStreamActiveOpName, kLabelSetOpName, - kLabelGotoOpName}; - return white_list.find(name) != white_list.end(); -} - -void GetRtKelInfo(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - MS_EXCEPTION_IF_NULL(kernel_info_list); - MS_EXCEPTION_IF_NULL(kernel_node); - std::string opNameLower = AnfAlgo::GetCNodeName(kernel_node); - (void)std::transform(opNameLower.begin(), opNameLower.end(), opNameLower.begin(), ::tolower); - - auto ker_desc_ptr = RtKerDescFactory::Create(opNameLower); - if (ker_desc_ptr != nullptr && !ker_desc_ptr->GetKernelInfo().empty()) { - *kernel_info_list = ker_desc_ptr->GetKernelInfo(); - return; - } - // if can't find kernel info in kernel info database, use the default kernel info - auto node_name = AnfAlgo::GetCNodeName(kernel_node); - if (IsDefaultKernelInfo(node_name)) { - auto kernel_build_info_builder = std::make_shared(); - // set input infos - auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); - kernel_build_info_builder->SetInputsFormat(std::vector(input_num, kOpFormat_DEFAULT)); - std::vector input_types = {}; - for (size_t i = 0; i < input_num; i++) { - input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i)); - } - kernel_build_info_builder->SetInputsDeviceType(input_types); - // set output info - auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); - kernel_build_info_builder->SetOutputsFormat(std::vector(output_num, kOpFormat_DEFAULT)); - kernel_build_info_builder->SetOutputsDeviceType(std::vector(output_num, TypeId::kTypeUnknown)); - // set ohter info - kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); - kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); - kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); - kernel_info_list->push_back(kernel_build_info_builder->Build()); - return; - } - MS_LOG(DEBUG) << "Rt dose not have op [" << opNameLower << "]."; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/rt_kernel_info.h b/mindspore/ccsrc/kernel/rts/rt_kernel_info.h deleted file mode 100644 index ae3753b4c8..0000000000 --- a/mindspore/ccsrc/kernel/rts/rt_kernel_info.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H -#define MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "kernel/kernel_build_info.h" -#include "kernel/kernel.h" -#include "utils/utils.h" - -namespace mindspore { -namespace kernel { -class RtKerDesc { - public: - virtual ~RtKerDesc() {} - virtual std::vector> GetKernelInfo() { - return std::vector>{}; - } -}; - -using RtKerDescCreater = std::function()>; -class RtKerDescFactory { - RtKerDescFactory() = default; - ~RtKerDescFactory() = default; - - public: - static RtKerDescFactory &Get(); - void Register(const std::string &name, RtKerDescCreater &&fun); - static std::shared_ptr Create(const std::string &name); - - private: - std::map fmap_; -}; - -class _RtKerDescRegister { - public: - _RtKerDescRegister(const std::string &name, RtKerDescCreater &&fun) { - RtKerDescFactory::Get().Register(name, std::move(fun)); - } - ~_RtKerDescRegister() = default; -}; - -#define _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) \ - static_assert(std::is_base_of::value, " must be base of RtKerDesc"); \ - static const _RtKerDescRegister g_##KNAME##_##_rtkernel_desc_reg(#KNAME, []() { return std::make_shared(); }); - -#define MS_REG_RTKERNEL_DESC(KNAME, clazz) _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz) - -void GetRtKelInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_RT_KERNEL_INFO_H diff --git a/mindspore/ccsrc/kernel/rts/send.cc b/mindspore/ccsrc/kernel/rts/send.cc deleted file mode 100644 index 298d75befd..0000000000 --- a/mindspore/ccsrc/kernel/rts/send.cc +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
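// Note: GetRtKelInfo (in rt_kernel_info.cc above) falls back to a synthesized "default" build
// info for a small whitelist of control-flow ops when no RtKerDesc is registered for the name.
// A sketch of the whitelist check and default-format fill; the string literals here stand in
// for the real kStreamSwitchOpName/kOpFormat_DEFAULT constants and may not match their values.
#include <cstddef>
#include <set>
#include <string>
#include <vector>

bool IsDefaultKernelInfoSketch(const std::string &name) {
  static const std::set<std::string> white_list = {"StreamSwitch", "StreamActive", "LabelSet", "LabelGoto"};
  return white_list.count(name) != 0;
}

std::vector<std::string> DefaultFormats(size_t tensor_num) {
  return std::vector<std::string>(tensor_num, "DefaultFormat");  // one default format per tensor
}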
- */ - -#include "kernel/rts/send.h" -#include -#include "runtime/event.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::EventRecordTaskInfo; -using EventRecordTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -SendKernel::SendKernel() { event_id_ = 0; } - -SendKernel::~SendKernel() {} - -bool SendKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrEventId, anf_node->cast())) { - MS_LOG(EXCEPTION) << "SendKernel has no attr kAttrEventId"; - } - event_id_ = GetValue(primitive->GetAttr(kAttrEventId)); - MS_LOG(INFO) << "send op event id:" << event_id_; - return true; -} - -bool SendKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - rtEvent_t event{}; - rtError_t status = rtEventRecord(event, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Send op rtEventRecord failed!"; - return false; - } - return true; -} - -std::vector SendKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "SendKernel GenTask event id:" << event_id_ << ", stream id:" << stream_id; - stream_id_ = stream_id; - EventRecordTaskInfoPtr task_info_ptr = std::make_shared(kernel_name_, stream_id, event_id_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/send.h b/mindspore/ccsrc/kernel/rts/send.h deleted file mode 100644 index 5c5b7cf09e..0000000000 --- a/mindspore/ccsrc/kernel/rts/send.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_SEND_H -#define MINDSPORE_CCSRC_KERNEL_RTS_SEND_H -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class SendKernel : public RtKernel { - public: - SendKernel(); - ~SendKernel() override; - bool Init(const AnfNodePtr &anf_node) override; - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - uint32_t event_id_; -}; - -MS_REG_RTKERNEL(send, SendKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_SEND_H diff --git a/mindspore/ccsrc/kernel/rts/stream_active.cc b/mindspore/ccsrc/kernel/rts/stream_active.cc deleted file mode 100644 index b573964868..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_active.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/rts/stream_active.h" -#include -#include -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::StreamActiveTaskInfo; -using StreamActiveTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -StreamActiveKernel::StreamActiveKernel() { active_streams_index_ = {}; } - -StreamActiveKernel::~StreamActiveKernel() {} - -bool StreamActiveKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "stream active op init start"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrActiveStreamList, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamActiveKernel has no attr kAttrActiveStreamList"; - } - active_streams_index_ = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); - return true; -} - -bool StreamActiveKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - MS_LOG(INFO) << "Stream active op launch start"; - - if (active_streams_index_.empty()) { - MS_LOG(ERROR) << "activeStreamList_ is empty!"; - return false; - } - - rtStream_t act_stream; - rtError_t status; - for (auto index : active_streams_index_) { - act_stream = kernel::TaskStream::GetInstance()->gen_stream_list()[index]; - status = rtStreamActive(act_stream, stream_ptr); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Stream active failed!"; - return false; - } - } - return true; -} - -std::vector StreamActiveKernel::GenTask(const std::vector &, const std::vector &, - const std::vector &, uint32_t stream_id) { - MS_LOG(INFO) << "StreamActiveKernel GenTask active stream size:" << active_streams_index_.size() - << ", stream id:" << stream_id; - 
stream_id_ = stream_id; - std::vector task_info_list; - for (auto &index : active_streams_index_) { - std::shared_ptr task_info_ptr = - std::make_shared(kernel_name_, stream_id, index); - MS_EXCEPTION_IF_NULL(task_info_ptr); - task_info_list.emplace_back(task_info_ptr); - MS_LOG(INFO) << "StreamActiveKernel GenTask: streamId:" << stream_id << ", Active streamId:" << index; - } - return task_info_list; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/stream_active.h b/mindspore/ccsrc/kernel/rts/stream_active.h deleted file mode 100644 index 68c422e7c2..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_active.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class StreamActiveKernel : public RtKernel { - public: - StreamActiveKernel(); - ~StreamActiveKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - std::vector active_streams_index_; -}; - -MS_REG_RTKERNEL(streamactive, StreamActiveKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_ACTIVE_H diff --git a/mindspore/ccsrc/kernel/rts/stream_switch.cc b/mindspore/ccsrc/kernel/rts/stream_switch.cc deleted file mode 100644 index 44b0a1ef86..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_switch.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
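// Note: StreamActiveKernel::Launch above resolves each stored index through the TaskStream
// singleton's stream list and activates that stream. A sketch of the same loop; unlike the
// original, this version bounds-checks the index before use, which is an addition here.
#include <cstdint>
#include <vector>

bool ActivateStreamsSketch(const std::vector<void *> &stream_list, const std::vector<uint32_t> &active_indexes) {
  if (active_indexes.empty()) {
    return false;  // mirrors the empty-list error path above
  }
  for (uint32_t index : active_indexes) {
    if (index >= stream_list.size()) {
      return false;  // invalid index; the real code would fail the launch
    }
    void *stream = stream_list[index];
    (void)stream;  // this is where the real code calls rtStreamActive(stream, current_stream)
  }
  return true;
}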
- */ - -#include "kernel/rts/stream_switch.h" - -#include -#include - -#include "runtime/stream.h" -#include "framework/ge_runtime/task_info.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -using ge::model_runner::StreamSwitchTaskInfo; -using StreamSwitchTaskInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace kernel { -StreamSwitchKernel::StreamSwitchKernel() { - cond_ = RT_EQUAL; - true_stream_index_ = 0; - data_type_ = RT_SWITCH_INT32; -} - -StreamSwitchKernel::~StreamSwitchKernel() {} - -bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "stream switch op init start"; - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition"; - } - cond_ = tagRtCondition(GetValue(primitive->GetAttr(kAttrSwitchCondition))); - if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream"; - } - true_stream_index_ = GetValue(primitive->GetAttr(kAttrTrueBranchStream)); - if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast())) { - MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType"; - } - data_type_ = tagRtSwitchDataType(GetValue(primitive->GetAttr(kAttrDataType))); - MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ - << ", data_type_:" << static_cast(data_type_); - return true; -} - -bool StreamSwitchKernel::Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - MS_LOG(INFO) << "stream switch op launch start"; - if (inputs.size() != 2) { - MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2"; - } - - void *loop_cnt = inputs[0]->addr; - void *ites_per_loop = inputs[1]->addr; - rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_]; - rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream_ptr, data_type_); - if (status != RT_ERROR_NONE) { - MS_LOG(ERROR) << "Stream switch failed!"; - return false; - } - return true; -} - -std::vector StreamSwitchKernel::GenTask(const std::vector &inputs, - const std::vector &, const std::vector &, - uint32_t stream_id) { - MS_LOG(INFO) << "StreamSwitchKernel GenTask start"; - if (inputs.size() != 2) { - MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two"; - } - stream_id_ = stream_id; - MS_EXCEPTION_IF_NULL(inputs[0]); - MS_EXCEPTION_IF_NULL(inputs[1]); - auto loop_cnt = inputs[0]->addr; - auto ites_per_loop = inputs[1]->addr; - MS_LOG(INFO) << "cond_:" << static_cast(cond_) << ", true_stream_index_:" << true_stream_index_ - << ", stream_id:" << stream_id; - std::shared_ptr task_info_ptr = std::make_shared( - kernel_name_, stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_); - MS_EXCEPTION_IF_NULL(task_info_ptr); - return {task_info_ptr}; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/rts/stream_switch.h b/mindspore/ccsrc/kernel/rts/stream_switch.h deleted file mode 100644 index 4e927f3059..0000000000 --- a/mindspore/ccsrc/kernel/rts/stream_switch.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, 
Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H -#define MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H - -#include -#include -#include "kernel/rts/rt_kernel.h" -#include "kernel/rts/rt_kernel_info.h" - -namespace mindspore { -namespace kernel { -class StreamSwitchKernel : public RtKernel { - public: - StreamSwitchKernel(); - ~StreamSwitchKernel() override; - - bool Init(const AnfNodePtr &anf_node) override; - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, uint32_t stream_id) override; - - private: - rtCondition_t cond_; - uint32_t true_stream_index_; - rtSwitchDataType_t data_type_; -}; - -MS_REG_RTKERNEL(streamswitch, StreamSwitchKernel); -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_RTS_STREAM_SWITCH_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc deleted file mode 100644 index 052b7eb2df..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ /dev/null @@ -1,424 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/tbe/tbe_adapter.h" - -#include -#include -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/opinfo.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -static std::map tbe_func_adapter_map = { - {"softmax", "softmax_v2"}, - {"log_softmax", "log_softmax_v2"}, - {"apply_momentum", "apply_momentum_d"}, - {"apply_ftrl", "apply_ftrl_d"}, - {"re_lu6", "relu6"}, - {"re_lu6_grad", "relu6_grad"}, - {"re_lu", "relu"}, - {"re_luv2", "relu_v2"}, - {"p_re_lu", "prelu"}, - {"p_re_lu_grad", "prelu_grad"}, - {"tensor_add", "add"}, - {"reduce_mean", "reduce_mean_d"}, - {"reduce_max", "reduce_max_d"}, - {"reduce_min", "reduce_min_d"}, - {"avg_pool_grad", "avg_pool_grad_d"}, - {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, - {"conv2d_backprop_input", "conv2d_backprop_input_d"}, - {"depthwise_conv2d_native", "depthwise_conv2d"}, - {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"}, - {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"}, - {"scatter_nd", "scatter_nd_d"}, - {"tile", "tile_d"}, - {"gather_v2", "gather_v2_d"}, - {"sparse_gather_v2", "gather_v2_d"}, - {"batch_mat_mul", "batch_matmul"}, - {"b_n_training_reduce", "bn_training_reduce"}, - {"b_n_training_update", "bn_training_update"}, - {"b_n_training_update_v2", "bn_training_update_v2"}, - {"b_n_training_update_v3", "bn_training_update_v3"}, - {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, - {"b_n_training_update_grad", "bn_training_update_grad"}, - {"b_n_infer", "bn_infer"}, - {"b_n_infer_grad", "bn_infer_grad"}, - {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, - {"n_pu_get_float_status", "n_p_u_get_float_status"}, - {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, - {"dropout_do_mask", "drop_out_do_mask"}, - {"strided_slice", "strided_slice_d"}, - {"strided_slice_grad", "strided_slice_grad_d"}, - {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, - {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, - {"apply_ada_max", "apply_ada_max_d"}, - {"apply_adadelta", "apply_adadelta_d"}, - {"apply_adagrad", "apply_adagrad_d"}, - {"apply_adagrad_v2", "apply_adagradv2_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, - {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, - {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, - {"apply_add_sign", "apply_add_sign_d"}, - {"apply_power_sign", "apply_power_sign_d"}, - {"transpose", "transpose_d"}, - {"fill", "fill_d"}, - {"unsorted_segment_sum", "unsorted_segment_sum_d"}, - {"unsorted_segment_prod", "unsorted_segment_prod_d"}, - {"concat", "concat_d"}, - {"slice", "slice_d"}, - {"reduce_sum", "reduce_sum_d"}, - {"inplace_add", "inplace_add_d"}, - {"inplace_sub", "inplace_sub_d"}, - {"one_hot", "one_hot_d"}, - {"sum", "reduce_sum_d"}, - {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, - {"lamb_next_mv", "lamb_next_m_v"}, - {"split", "split_d"}, - {"split_v", "split_v_d"}, - {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, - {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, - {"pad", "pad_d"}, - {"argmax", "arg_max_d"}, - {"argmin", "arg_min_d"}, - {"space_to_batch", "space_to_batch_d"}, - {"batch_to_space", "batch_to_space_d"}, - {"space_to_batch_nd", "space_to_batch_nd_d"}, - {"batch_to_space_nd", "batch_to_space_nd_d"}, - {"resize_bilinear", "resize_bilinear_v2_d"}, - {"resize_bilinear_grad", 
"resize_bilinear_v2_grad"}, - {"adam", "apply_adam_d"}, - {"r_oi_align", "roi_align"}, - {"r_oi_align_grad", "roi_align_grad"}, - {"i_ou", "iou"}, - {"s_gd", "sgd"}, - {"l_rn", "lrn"}, - {"l_rn_grad", "lrn_grad"}, - {"l_ars_update", "lars_v2_update"}, - {"n_ms_with_mask", "nms_with_mask"}, - {"square_sum_all", "square_sum_all"}, - {"cum_sum", "cumsum_d"}, - {"range", "range_d"}, - {"lin_space", "lin_space_d"}, - {"inv_grad", "inv_grad"}, - {"apply_rms_prop", "apply_rms_prop_d"}, - {"cum_prod", "cumprod_d"}, - {"reduce_all", "reduce_all_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"unsorted_segment_min", "unsorted_segment_min_d"}, - {"reduce_prod", "reduce_prod_d"}, - {"a_cos", "acos"}, - {"a_cos_grad", "acos_grad"}, - {"histogram_fixed_width", "histogram_fixed_width_d"}, - {"broadcast_to", "broadcast_to_d"}, - {"inplace_update", "inplace_update_d"}, - {"matrix_diag", "matrix_diag_d"}, - {"matrix_diag_part", "matrix_diag_part_d"}, - {"matrix_set_diag", "matrix_set_diag_d"}}; - -void TbeAdapter::NormalizeFuncName(std::string *func_name) { - if (func_name == nullptr) { - MS_LOG(EXCEPTION) << "func_name is null"; - } - std::string name_tmp; - bool sub_head = false; - for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) { - if (islower(*iter)) { - sub_head = false; - } - if (isdigit(*iter)) { - sub_head = true; - } - if (isupper(*iter) && iter != func_name->begin()) { - if (!sub_head) { - (void)name_tmp.insert(name_tmp.end(), '_'); - sub_head = true; - } else { - string::iterator iter_next = iter + 1; - if (iter_next != func_name->end()) { - if (islower(*iter_next)) { - (void)name_tmp.insert(name_tmp.end(), '_'); - } - } - } - } - (void)name_tmp.insert(name_tmp.end(), *iter); - } - (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower); - *func_name = name_tmp; - auto iter = tbe_func_adapter_map.find(*func_name); - if (iter != tbe_func_adapter_map.end()) { - MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second; - *func_name = iter->second; - } -} - -void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) { - std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0); - std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0); - if (input_format == kOpFormat_DEFAULT) { - input_format = kOpFormat_NCHW; - } - if (output_format == kOpFormat_DEFAULT) { - output_format = kOpFormat_NCHW; - } - AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node); - AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node); - } -} - -std::unordered_set input_order_adjusted_ops = { - "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", - "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; - -void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, - nlohmann::json *inputs_json) { - MS_EXCEPTION_IF_NULL(inputs_json); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - inputs_json->push_back(inputs_list[2]); - inputs_json->push_back(inputs_list[0]); - inputs_json->push_back(inputs_list[1]); - for (size_t i = 3; i < inputs_list.size(); ++i) { - 
inputs_json->push_back(inputs_list[i]); - } - } else if (op_name == "ApplyCenteredRMSProp") { - // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map - // TBE parameter to correspond python API parameter by latter's index using hardcode - inputs_json->push_back(inputs_list[0]); - inputs_json->push_back(inputs_list[1]); - inputs_json->push_back(inputs_list[2]); - inputs_json->push_back(inputs_list[3]); - inputs_json->push_back(inputs_list[5]); - inputs_json->push_back(inputs_list[6]); - inputs_json->push_back(inputs_list[7]); - inputs_json->push_back(inputs_list[8]); - inputs_json->push_back(inputs_list[4]); - } else { - inputs_json->push_back(inputs_list[1]); - inputs_json->push_back(inputs_list[0]); - for (size_t i = 2; i < inputs_list.size(); ++i) { - inputs_json->push_back(inputs_list[i]); - } - } - } -} - -void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, - std::vector *inputs_json) { - MS_EXCEPTION_IF_NULL(inputs_json); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - inputs_json->emplace_back(inputs_list[2]); - inputs_json->emplace_back(inputs_list[0]); - inputs_json->emplace_back(inputs_list[1]); - for (size_t i = 3; i < inputs_list.size(); ++i) { - inputs_json->emplace_back(inputs_list[i]); - } - } else { - inputs_json->emplace_back(inputs_list[1]); - inputs_json->emplace_back(inputs_list[0]); - for (size_t i = 2; i < inputs_list.size(); ++i) { - inputs_json->emplace_back(inputs_list[i]); - } - } - } -} - -void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, - std::vector *reorder_data_layer) { - MS_EXCEPTION_IF_NULL(reorder_data_layer); - if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) { - (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer))); - } else { - if (op_name == "MinimumGrad" || op_name == "MaximumGrad") { - reorder_data_layer->emplace_back(data_layer[2]); - reorder_data_layer->emplace_back(data_layer[0]); - reorder_data_layer->emplace_back(data_layer[1]); - for (size_t i = 3; i < data_layer.size(); ++i) { - reorder_data_layer->emplace_back(data_layer[i]); - } - } else { - reorder_data_layer->emplace_back(data_layer[1]); - reorder_data_layer->emplace_back(data_layer[0]); - for (size_t i = 2; i < data_layer.size(); ++i) { - reorder_data_layer->emplace_back(data_layer[i]); - } - } - } -} - -std::map TbeAdapter::build_json_attr_pass_map_ = { - {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, - {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, - {"Cast", TbeAdapter::CastAttrJsonPass}}; - -bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(attrs_json); - auto cnode_name = AnfAlgo::GetCNodeName(anf_node); - auto FPass = build_json_attr_pass_map_.find(cnode_name); - if (FPass != build_json_attr_pass_map_.end()) { - FPass->second(anf_node, op_info_attrs, attrs_json); - return true; - } - return false; -} - -void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attr_num = 
op_info_attrs.size();
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  for (size_t i = 0; i < attr_num; i++) {
-    nlohmann::json attr_obj;
-    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
-    std::string attr_name = op_info_attrs[i]->name();
-    auto value = primitive->GetAttr(attr_name);
-    if (value != nullptr) {
-      bool attr_value = GetValue<bool>(value);
-      attr_obj["value"] = attr_value;
-      attr_obj["valid"] = true;
-    } else {
-      attr_obj["valid"] = false;
-    }
-    attr_obj["name"] = attr_name;
-    attrs_json->push_back(attr_obj);
-  }
-  MS_LOG(INFO) << "MaximumGradAttrJsonPass done.";
-}
-
-void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
-                                         const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
-                                         nlohmann::json *attrs_json) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(attrs_json);
-  auto attr_num = op_info_attrs.size();
-  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
-  MS_EXCEPTION_IF_NULL(primitive);
-  for (size_t i = 0; i < attr_num; i++) {
-    nlohmann::json attr_obj;
-    MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
-    std::string attr_name = op_info_attrs[i]->name();
-    auto value = primitive->GetAttr(attr_name);
-    if (value != nullptr) {
-      bool attr_value = GetValue<bool>(value);
-      attr_obj["value"] = attr_value;
-      attr_obj["valid"] = true;
-    } else {
-      attr_obj["valid"] = false;
-    }
-    attr_obj["name"] = attr_name;
-    attrs_json->push_back(attr_obj);
-  }
-  MS_LOG(INFO) << "MinimumGradAttrJsonPass done.";
-}
-
-static int TypeStrToDstType(const std::string &type_str) {
-  int ret = -1;
-  if (type_str == "Float" || type_str == "Float32") {
-    ret = 0;
-  } else if (type_str == "Float16") {
-    ret = 1;
-  } else if (type_str == "Int8") {
-    ret = 2;
-  } else if (type_str == "Int32") {
-    ret = 3;
-  } else if (type_str == "UInt8") {
-    ret = 4;
-  } else if (type_str == "UInt64") {
-    ret = 10;
-  } else if (type_str == "Bool") {
-    ret = 12;
-  } else {
-    MS_LOG(INFO) << "Error type str is invalid: " << type_str;
-  }
-  return ret;
-}
-
-void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
-                                  const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
-                                  nlohmann::json *attrs_json) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(attrs_json);
-  if (op_info_attrs.size() != 1) {
-    MS_LOG(INFO) << "cast node should have dst_type attr";
-    return;
-  }
-  auto attr_name = op_info_attrs[0]->name();
-  auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0)));
-  MS_EXCEPTION_IF_NULL(type_ptr);
-  auto type_element = type_ptr->element();
-  MS_EXCEPTION_IF_NULL(type_element);
-  auto dtype = type_element->ToString();
-  auto dst_type_value = TypeStrToDstType(dtype);
-  nlohmann::json attr_obj;
-  attr_obj["value"] = dst_type_value;
-  attr_obj["valid"] = true;
-  attr_obj["name"] = attr_name;
-  attrs_json->push_back(attr_obj);
-  MS_LOG(INFO) << "CastAttrJsonPass done.";
-}
-
-void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr<AnfNode> &anf_node,
-                                            size_t real_input_index, std::vector<nlohmann::json> *input_list,
-                                            mindspore::kernel::kCreaterType creater_type) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(input_list);
-  auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0);
-  size_t last_dim = input_x_shape[input_x_shape.size() - 1];
-  std::vector<size_t> tensor_shape = {last_dim};
-  std::vector<size_t> tensor_origin_shape = {last_dim};
-  std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast<size_t>(real_input_index));
-  if (tensor_format == kOpFormat_DEFAULT) {
-    tensor_format = kOpFormat_NCHW;
-  }
-  std::string tensor_origin_format =
kOpFormat_NCHW; - std::string tensor_dtype = "float16"; - nlohmann::json input_desc_json; - input_desc_json["dtype"] = tensor_dtype; - input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node); - input_desc_json["ori_shape"] = tensor_origin_shape; - input_desc_json["ori_format"] = tensor_origin_format; - input_desc_json["shape"] = tensor_shape; - if (creater_type == OP_SELECT_FORMAT) { - input_desc_json["format"] = tensor_origin_format; - } else { - input_desc_json["format"] = tensor_format; - } - input_desc_json["valid"] = true; - input_list->emplace_back(input_desc_json); -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h b/mindspore/ccsrc/kernel/tbe/tbe_adapter.h deleted file mode 100644 index 354bcb3ebd..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H - -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "base/base.h" -#include "kernel/oplib/opinfo.h" -// Note: This file is mainly used to adapt the ME front-end operator description and -// the TBE back-end operator implementation difference -namespace mindspore { -namespace kernel { -enum kCreaterType : int { SINGLE_BUILD = 0, PREBUILD, OP_SELECT_FORMAT, CHECK_SUPPORTED, OP_PRE_COMPILE }; -namespace tbe { -using FAttrsPass = void (*)(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); -class TbeAdapter { - public: - TbeAdapter() = default; - ~TbeAdapter() = default; - static void NormalizeFuncName(std::string *func_name); - static void SetTbeAttrsForTransDataOp(const AnfNodePtr &anf_node); - static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, - nlohmann::json *inputs_json); - static bool RunAttrPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - static void GenTopKV2IndicesTensorInfo(const std::shared_ptr &anf_node, size_t real_input_index, - std::vector *input_list, kCreaterType creater_type); - - static void FusionInputOrderPass(const std::string &op_name, const std::vector &inputs_list, - std::vector *inputs_json); - static void FusionDataOrderPass(const std::string &op_name, const std::vector &data_layer, - std::vector *reorder_data_layer); - - private: - static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - static void MinimumGradAttrJsonPass(const AnfNodePtr &anf_node, - const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - - static void CastAttrJsonPass(const AnfNodePtr &anf_node, const std::vector> &op_info_attrs, - nlohmann::json *attrs_json); - - static std::map build_json_attr_pass_map_; -}; -} // namespace tbe -} // namespace kernel -} // namespace 
mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_ADAPTER_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc deleted file mode 100644 index 90c5557253..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_convert_utils.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -const std::unordered_map type_str_id_maps = { - {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16}, - {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64}, - {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8}, - {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32}, - {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt}, - {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16}, - {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64}, - {"bool", TypeId::kNumberTypeBool}, -}; - -const std::map type_id_str_maps = { - {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"}, - {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"}, - {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"}, - {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"}, - {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"}, - {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"}, - {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"}, - {TypeId::kNumberTypeBool, "int8"}, -}; - -const std::map type_str_maps = { - {"Float32", "float32"}, {"Float16", "float16"}, {"Int8", "int8"}, {"Int16", "int16"}, - {"UInt16", "uint16"}, {"UInt8", "uint8"}, {"Int32", "int32"}, {"UInt32", "uint32"}, - {"Int64", "int64"}, {"UInt64", "uint64"}, {"Bool", "int8"}, {"Float64", "float64"}, -}; - -const std::unordered_map type_nbyte_maps = { - {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2}, - {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)}, - {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2}, - {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)}, -}; - -const std::unordered_map fusion_type_maps = { - {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE}, - {"SEGMENT", FusionType::SEGMENT}, {"DYNAMIC", FusionType::DYNAMIC}, {"OPAQUE", FusionType::OPAQUE}, -}; - -TypeId DtypeToTypeId(const std::string &dtypes) { - auto iter = type_str_id_maps.find(dtypes); - if (iter == type_str_id_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input device 
dtype: " << dtypes; - } - return iter->second; -} - -std::string TypeIdToString(TypeId type_id) { - auto iter = type_id_str_maps.find(type_id); - if (iter == type_id_str_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input dtype: " << TypeIdLabel(type_id); - } - return iter->second; -} - -size_t GetDtypeNbyte(const std::string &dtypes) { - auto iter = type_nbyte_maps.find(dtypes); - if (iter == type_nbyte_maps.end()) { - MS_LOG(EXCEPTION) << "Illegal input dtype: " << dtypes; - } - return iter->second; -} - -FusionType GetFusionType(const std::string &pattern) { - auto iter = fusion_type_maps.find(pattern); - if (iter == fusion_type_maps.end()) { - MS_LOG(INFO) << "Illegal fusion pattern: " << pattern; - return UNKNOWN_FUSION_TYPE; - } - return iter->second; -} - -std::string GetProcessor(const AnfNodePtr &anf_node) { - MS_EXCEPTION_IF_NULL(anf_node); - std::string device; - switch (AnfAlgo::GetProcessor(anf_node)) { - case Processor::AICORE: - device = kProcessorAiCore; - break; - default: - MS_LOG(INFO) << "Unknown processor type." << anf_node->fullname_with_scope(); - break; - } - return device; -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h b/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h deleted file mode 100644 index 3fc52becc2..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_convert_utils.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ - -#include -#include "kernel/kernel.h" -#include "base/base.h" -#include "ir/dtype/type.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -constexpr auto kProcessorAiCore = "aicore"; -TypeId DtypeToTypeId(const std::string &dtypes); - -std::string TypeIdToString(TypeId type_id); - -size_t GetDtypeNbyte(const std::string &dtypes); - -FusionType GetFusionType(const std::string &pattern); - -std::string GetProcessor(const AnfNodePtr &anf_node); -} // namespace tbe -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_COMMON_UTILS_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc deleted file mode 100644 index 645a195f5e..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ /dev/null @@ -1,1019 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_build.h" -#include -#include -#include -#include "operator/ops.h" -#include "parallel/ops_info/ops_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/tbe/tbe_adapter.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeAdapter; -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kFusionOpList = "op_list"; -constexpr auto kFusionKernelNamePrfix = "te_fusion"; -constexpr auto kOptional = "optional_"; -constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z"; -constexpr auto kPlatform = "platform"; -constexpr auto kPlatTBE = "TBE"; -constexpr auto kGenModel = "gen_model"; -constexpr auto kSingle = "single"; -constexpr auto kImplPath = "impl_path"; -constexpr auto kJInputs = "inputs"; -constexpr auto kJOutputs = "outputs"; -constexpr auto kJAttrs = "attrs"; -constexpr auto kJKernelName = "kernel_name"; -constexpr auto kJOpInfo = "op_info"; -constexpr auto kJDtype = "dtype"; -constexpr auto kJtype = "type"; -constexpr auto kJName = "name"; -constexpr auto kJOriShape = "ori_shape"; -constexpr auto kJOriFormat = "ori_format"; -constexpr auto kJShape = "shape"; -constexpr auto kJFormat = "format"; -constexpr auto kJValid = "valid"; -constexpr auto kJParamType = "param_type"; -constexpr auto kParamDynamic = "dynamic"; -constexpr auto kParamRequred = "required"; -constexpr auto kJDataType = "data_type"; -constexpr auto kJOutputIndex = "output_index"; -constexpr auto kJOutputDesc = "output_desc"; -constexpr auto kJInputDesc = "input_desc"; -constexpr auto kVTypeInt = "int"; -constexpr auto kVTypeStr = "str"; -constexpr auto kVTypeBool = "bool"; -constexpr auto kVTypeFloat = "float"; -constexpr auto kVTypeListInt = "listInt"; -constexpr auto kVTypeInt32 = "Int32"; -constexpr auto kVTypeListUInt64 = "listUInt64"; -constexpr auto kVTypeListFloat = "listFloat"; -constexpr auto kVTypeListListInt = "listListInt"; -constexpr auto kJValue = "value"; -constexpr auto kJDynIndex = "dyn_index"; -constexpr auto kJFuncName = "func_name"; - -std::string NormalizeFullScopeName(const string &full_scope_name) { - // exp:Default/ReLU-op0 -->Default_ReLU_op0 - string normal_ret = full_scope_name; - std::replace(normal_ret.begin(), normal_ret.end(), '/', '_'); - std::replace(normal_ret.begin(), normal_ret.end(), '-', '_'); - return normal_ret; -} - -bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, - nlohmann::json *kernel_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(kernel_json); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); - MS_EXCEPTION_IF_NULL(op_info_ptr); - (*kernel_json)[kPlatform] = kPlatTBE; - (*kernel_json)[kGenModel] = kSingle; - (*kernel_json)[kImplPath] = op_info_ptr->impl_path(); - nlohmann::json op_info_json; - if (op_info_ptr->impl_path().empty()) { - tbe::TbeAdapter::NormalizeFuncName(&op_name); - } else { - op_name = op_info_ptr->kernel_name(); - } - op_info_json[kJName] = op_name; - // generate inputs json - nlohmann::json inputs_json; - if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) { - MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed"; - return false; - } - op_info_json[kJInputs] = inputs_json; - // generate 
outputs json - nlohmann::json outputs_json; - if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) { - MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed"; - return false; - } - op_info_json[kJOutputs] = outputs_json; - // generate attrs json - nlohmann::json attrs_json; - (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json); - op_info_json[kJAttrs] = attrs_json; - std::string json_str = op_info_json.dump(); - size_t hash_id = std::hash()(json_str); - json_name_ = op_name + "_" + std::to_string(hash_id); - json_info_ = json_str; - if (creater_type_ == PREBUILD) { - op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope()); - } else { - op_info_json[kJKernelName] = json_name_; - } - (*kernel_json)[kJOpInfo] = op_info_json; - if (creater_type_ == SINGLE_BUILD) { - TbeUtils::SaveJsonInfo(json_name_, json_info_); - } - - MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope() - << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump(); - - return true; -} - -bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, - bool value, const std::shared_ptr &input_ptr, - const string &op_input_name, size_t input_i, - std::vector *input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(input_list); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) { - TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_); - } else { - auto dtype = GetDeviceInputType(anf_node, real_input_index); - auto format = GetDeviceInputFormat(anf_node, real_input_index); - auto shape = GetDeviceInputShape(anf_node, real_input_index); - auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - nlohmann::json input_desc_json; - input_desc_json[kJDtype] = dtype; - input_desc_json[kJName] = op_input_name + std::to_string(input_i); - input_desc_json[kJOriShape] = ori_shape; - input_desc_json[kJOriFormat] = kOpFormat_NCHW; - input_desc_json[kJShape] = shape; - input_desc_json[kJFormat] = format; - input_desc_json[kJValid] = value; - input_desc_json[kJParamType] = input_ptr->param_type(); - input_list->emplace_back(input_desc_json); - } - return true; -} - -bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, - const std::shared_ptr &input_ptr, size_t *real_input_index, - string *op_input_name, std::vector *input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(real_input_index); - MS_EXCEPTION_IF_NULL(op_input_name); - MS_EXCEPTION_IF_NULL(input_list); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node); - bool value = true; - for (size_t input_i = 0; input_i < input_tensor_num; input_i++) { - if (*real_input_index >= real_input_num) { - if (input_ptr->param_type() == "optional") { - *op_input_name = input_ptr->name() + "_optional_"; - nlohmann::json input_desc_json; - input_desc_json[kJValid] = false; - input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index); - input_list->emplace_back(input_desc_json); - continue; - } - MS_LOG(ERROR) << "Input num: " << 
*real_input_index << " is not match op inputs"; - return false; - } - if (op_name == "BatchNorm") { - if (input_ptr->name() == "mean" || input_ptr->name() == "variance") { - auto attr = primitive->GetAttr("is_training"); - MS_EXCEPTION_IF_NULL(attr); - bool is_training = GetValue(attr); - MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training " - << is_training; - if (is_training) { - (*real_input_index)++; - break; - } - } - } - bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list); - (*real_input_index)++; - if (!ret) { - return false; - } - } - return true; -} - -bool GetInputNameAndRealNum(const std::shared_ptr &anf_node, const std::shared_ptr &input_ptr, - size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(input_ptr); - MS_EXCEPTION_IF_NULL(dyn_input_index); - MS_EXCEPTION_IF_NULL(input_num); - MS_EXCEPTION_IF_NULL(op_input_name); - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. - std::vector dyn_input_sizes; - if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) { - dyn_input_sizes = GetValue>(primitive->GetAttr(kAttrDynInputSizes)); - } - - if (input_ptr->param_type() == kParamDynamic) { - if (*dyn_input_index >= dyn_input_sizes.size()) { - MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size(); - return false; - } - *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]); - *op_input_name = input_ptr->name() + "_dynamic_"; - (*dyn_input_index)++; - // if optional input is exist - } else { - *input_num = 1; - *op_input_name = input_ptr->name() + "_"; - } - return true; -} - -bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *inputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(inputs_json); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kAtomicAddrCleanOpName) { - return true; - } - std::vector> inputs_ptr = op_info->inputs_ptr(); - if (inputs_ptr.empty()) { - MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info"; - return true; - } - auto op_info_input_num = inputs_ptr.size(); - size_t dyn_input_index = 0; - size_t real_input_index = 0; - std::vector> inputs_list; - for (size_t i = 0; i < op_info_input_num; i++) { - size_t input_tensor_num; - std::shared_ptr input_ptr = inputs_ptr[i]; - std::string op_input_name; - MS_EXCEPTION_IF_NULL(input_ptr); - if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) { - return false; - } - std::vector input_list; - if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) { - return false; - } - inputs_list.emplace_back(input_list); - } - - TbeAdapter::InputOrderPass(op_name, inputs_list, inputs_json); - return true; -} - -bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *outputs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(outputs_json); - auto op_name = AnfAlgo::GetCNodeName(anf_node); - if (op_name == kAtomicAddrCleanOpName) { - return true; - } - auto outputs_ptr = op_info->outputs_ptr(); - return 
GenOutputDescJson(anf_node, outputs_ptr, outputs_json); -} - -bool TbeKernelJsonCreator::GenOutputDescJson( - const std::shared_ptr &anf_node, - const std::vector> &outputs_ptr, nlohmann::json *outputs_json) { - MS_EXCEPTION_IF_NULL(outputs_json); - size_t output_idx = 0; - auto op_name = AnfAlgo::GetCNodeName(anf_node); - size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node); - - for (const auto &output_ptr : outputs_ptr) { - size_t output_obj_num = 0; - if (output_ptr->param_type() == kParamRequred) { - output_obj_num = 1; - } else if (output_ptr->param_type() == kParamDynamic) { - if (outputs_ptr.size() > 1) { - MS_LOG(ERROR) << "Dynamic output is unsupported multi output!"; - return false; - } - output_obj_num = real_output_num; - } else { - if (output_idx >= real_output_num) { - MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none."; - std::vector output_list; - nlohmann::json output_obj; - output_obj[kJName] = output_ptr->name(); - output_obj[kJValid] = false; - output_list.emplace_back(output_obj); - (*outputs_json).push_back(output_list); - continue; - } else { - output_obj_num = 1; - } - } - std::vector output_list; - GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list); - (*outputs_json).push_back(output_list); - } - return true; -} - -void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, - const std::shared_ptr &output_ptr, size_t *output_idx, - std::vector *output_list) { - MS_EXCEPTION_IF_NULL(output_idx); - MS_EXCEPTION_IF_NULL(output_list); - for (size_t i = 0; i < output_obj_num; i++) { - auto dtype = GetDeviceOutputType(anf_node, *output_idx); - auto format = GetDeviceOutputFormat(anf_node, *output_idx); - auto shape = GetDeviceOutputShape(anf_node, *output_idx); - std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); - if (ori_shape.empty()) { - ori_shape.emplace_back(1); - } - nlohmann::json output_obj; - output_obj[kJDtype] = dtype; - output_obj[kJShape] = shape; - output_obj[kJFormat] = format; - output_obj[kJOriShape] = ori_shape; - output_obj[kJOriFormat] = kOpFormat_NCHW; - output_obj[kJName] = output_ptr->name(); - output_obj[kJValid] = true; - output_obj[kJParamType] = output_ptr->param_type(); - output_list->emplace_back(output_obj); - (*output_idx)++; - } -} - -bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_node, - const std::shared_ptr &op_info, nlohmann::json *attrs_json) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(attrs_json); - auto attrs_ptr = op_info->attrs_ptr(); - std::string op_name = AnfAlgo::GetCNodeName(anf_node); - if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) { - return true; - } - auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); - MS_EXCEPTION_IF_NULL(primitive); - for (const auto &attr_ptr : attrs_ptr) { - std::string attr_name = attr_ptr->name(); - nlohmann::json attr_obj; - attr_obj[kJName] = attr_name; - if (op_name == parallel::LAYER_NORM && attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) { - continue; - } - if (primitive->GetAttr(attr_name) != nullptr) { - auto value = primitive->GetAttr(attr_name); - std::string type = attr_ptr->type(); - ParseAttrValue(type, value, &attr_obj); - attr_obj[kJValid] = true; - } else { - if (op_info->impl_path().empty()) { - attr_obj[kJValid] = false; - } else { - if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) { - 
MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name - << " is required, but not set."; - } else { - attr_obj[kJValid] = false; - } - } - } - (*attrs_json).push_back(attr_obj); - } - return true; -} - -void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, - nlohmann::json *attr_obj) { - MS_EXCEPTION_IF_NULL(value); - MS_EXCEPTION_IF_NULL(attr_obj); - if (type == kVTypeInt) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeStr) { - auto attr_value = GetValue(value); - if (attr_value == kOpFormat_FRAC_Z) { - attr_value = kOpFormat_FRACTAL_Z; - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeBool) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeFloat) { - auto attr_value = GetValue(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListInt) { - std::vector attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == kVTypeInt32) { - int data = GetValue(value); - attr_value.push_back(data); - } else { - attr_value = GetValue>(value); - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListFloat) { - std::vector attr_value; - auto value_type = value->type(); - MS_EXCEPTION_IF_NULL(value_type); - auto value_type_str = value_type->ToString(); - if (value_type_str == kVTypeFloat) { - auto data = GetValue(value); - attr_value.push_back(data); - } else { - attr_value = GetValue>(value); - } - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListUInt64) { - auto attr_value = GetValue>(value); - (*attr_obj)[kJValue] = attr_value; - } else if (type == kVTypeListListInt) { - auto attr_value = GetValue>>(value); - (*attr_obj)[kJValue] = attr_value; - } else { - MS_LOG(EXCEPTION) << "Type: " << type << "not support"; - } -} - -std::vector TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector shape; - if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) { - shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index); - } else { - shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index); - } - if (shape.empty()) { - shape.emplace_back(1); - } - return shape; -} - -std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - TypeId type_id; - if (creater_type_ == OP_SELECT_FORMAT) { - type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index); - } else { - type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index); - } - return tbe::TypeIdToString(type_id); -} - -std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::string format = kOpFormat_NCHW; - if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { - format = AnfAlgo::GetInputFormat(anf_node, real_index); - if (format == kOpFormat_FRAC_Z) { - format = kOpFormat_FRACTAL_Z; - } else if (format == kOpFormat_DEFAULT) { - format = kOpFormat_NCHW; - } - } - return format; -} - -std::vector TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const { - MS_EXCEPTION_IF_NULL(anf_node); - std::vector shape; - if (creater_type_ == OP_SELECT_FORMAT || 
creater_type_ == CHECK_SUPPORTED) {
-    shape = AnfAlgo::GetOutputInferShape(anf_node, real_index);
-  } else {
-    shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
-  }
-  if (shape.empty()) {
-    shape.emplace_back(1);
-  }
-  return shape;
-}
-
-std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  TypeId type_id;
-  if (creater_type_ == OP_SELECT_FORMAT) {
-    type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index);
-  } else {
-    type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index);
-  }
-  return tbe::TypeIdToString(type_id);
-}
-
-std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  std::string format = kOpFormat_NCHW;
-  if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
-    format = AnfAlgo::GetOutputFormat(anf_node, real_index);
-    if (format == kOpFormat_FRAC_Z) {
-      format = kOpFormat_FRACTAL_Z;
-    } else if (format == kOpFormat_DEFAULT) {
-      format = kOpFormat_NCHW;
-    }
-  }
-  return format;
-}
-
-bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
-                               std::vector<size_t> *output_size_list) {
-  if (input_size_list == nullptr || output_size_list == nullptr) {
-    MS_LOG(ERROR) << "Input size or output size is nullptr";
-    return false;
-  }
-  input_size_list->clear();
-  output_size_list->clear();
-  for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) {
-    for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) {
-      size_t size_i = 1;
-      if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) {
-        std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName];
-        MS_LOG(INFO) << "Input name: " << input_name << " is optional, valid is false.";
-        continue;
-      }
-      for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) {
-        size_i *= static_cast<size_t>(j);
-      }
-      std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype];
-      size_t nbyte = tbe::GetDtypeNbyte(dtype);
-      size_i *= nbyte;
-      input_size_list->push_back(size_i);
-    }
-  }
-  for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) {
-    for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) {
-      size_t size_i = 1;
-      if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) {
-        std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName];
-        MS_LOG(INFO) << "Output name: " << output_name << " is optional, valid is false.";
-        continue;
-      }
-      for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) {
-        size_i *= static_cast<size_t>(j);
-      }
-      std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype];
-      size_t nbyte = tbe::GetDtypeNbyte(dtype);
-      size_i *= nbyte;
-      output_size_list->push_back(size_i);
-    }
-  }
-  return true;
-}
-
-bool TbeKernelBuild::GenFusionScopeJson(const std::vector<mindspore::AnfNodePtr> &input_nodes,
-                                        const std::vector<mindspore::AnfNodePtr> &compute_nodes,
-                                        nlohmann::json *fusion_str, std::string *fusion_kernel) {
-  MS_EXCEPTION_IF_NULL(fusion_str);
-  MS_EXCEPTION_IF_NULL(fusion_kernel);
-  // get input layer info
-  std::vector<std::vector<mindspore::AnfNodePtr>> input_layers;
-  std::map<const AnfNodePtr, FusionDataType> spec_data_input;
-  if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) {
-    return false;
-  }
-  // gen fusion scope_op json
-  std::vector<nlohmann::json> compute_list;
-  (*fusion_kernel) = kFusionKernelNamePrfix;
-  // index: records the optional input order for the fusion build, starting from 0
-  static size_t index = 0;
-  auto layer_iter = input_layers.begin();
-  auto compute_op_iter = compute_nodes.begin();
-  for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) {
-    nlohmann::json compute_op_str;
-    (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index);
-    compute_list.push_back(compute_op_str);
-  }
-  index = 0;
-  // gen data input json
-  std::vector<nlohmann::json> data_list;
-  for (const auto &layer : input_layers) {
-    for (const auto &data_input : layer) {
-      nlohmann::json data_str;
-      if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) {
-        MS_LOG(INFO) << "Fusion error: gen fusion data input json failed.";
-        return false;
-      }
-      data_list.push_back(data_str);
-    }
-  }
-  index = 0;
-  data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
-  (*fusion_str)[kFusionOpList] = data_list;
-  return true;
-}
-
-void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
-                                 size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) {
-  std::string output_desc_name = anf_node->fullname_with_scope();
-  if (node_out_idx > 0) {
-    output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
-  }
-  (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
-  auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
-  (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
-  auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
-  if (ori_shape.empty()) {
-    ori_shape.emplace_back(1);
-  }
-  (*output_desc)[kJOriShape] = ori_shape;
-  auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx);
-  if (shape.empty()) {
-    shape.emplace_back(1);
-  }
-  (*output_desc)[kJShape] = shape;
-  auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
-  if (format == kOpFormat_DEFAULT) {
-    format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
-  }
-  (*output_desc)[kJFormat] = format;
-  (*output_desc)[kJOriFormat] = kOpFormat_NCHW;
-  (*output_desc)[kJOutputIndex] = desc_output_idx;
-  if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
-    std::vector<size_t> spec_shape = {};
-    spec_shape.emplace_back(shape[0]);
-    spec_shape.emplace_back(shape[1]);
-    spec_shape.emplace_back(shape[2] * shape[3]);
-    spec_shape.emplace_back(shape[4]);
-    (*output_desc)[kJShape] = spec_shape;
-  } else if (fusion_data_type == kFusionReLUGradV2) {
-    std::vector<size_t> spec_shape = {};
-    spec_shape.emplace_back(shape[0]);
-    spec_shape.emplace_back(shape[1]);
-    spec_shape.emplace_back(shape[2] * shape[3]);
-    spec_shape.emplace_back(16);
-    (*output_desc)[kJShape] = spec_shape;
-    (*output_desc)[kJDataType] = kVTypeBool;
-  }
-}
-
-void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
-                                         size_t output_index, nlohmann::json *output_desc) {
-  std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
-  (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
-  (*output_desc)[kJOutputIndex] = output_index;
-  std::vector<size_t> shape;
-  (*output_desc)[kJShape] = shape;
-}
-
-bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
-                                        const std::vector<mindspore::AnfNodePtr> &reorder_layer,
-                                        std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
-  if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) {
-    MS_LOG(INFO) << "Fusion error: node(" << op_name << ")'s input is null.";
"; - return false; - } - MS_LOG(INFO) << "Fusion info: op_name: " << op_name << "input layer size: " << reorder_layer.size(); - if (op_name == kReluGradV2OpName) { - (*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2; - } else if (op_name == kAddNOpName) { - for (const auto &it : reorder_layer) { - (*spec_data_input)[it] = kFusionAddN; - } - } - return true; -} - -bool TbeKernelBuild::GetInputLayers(const std::vector &input_nodes, - const std::vector &compute_nodes, - std::vector> *input_layers, - std::map *spec_data_input) { - MS_EXCEPTION_IF_NULL(input_layers); - MS_EXCEPTION_IF_NULL(spec_data_input); - auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) { - auto op_name = AnfAlgo::GetCNodeName(it); - return op_name == kConv2DBackpropInputOpName; - }); - bool need_spec = (result != compute_nodes.end()); - size_t input_size = 0; - for (const auto &compute_node : compute_nodes) { - std::vector layer = {}; - std::vector reorder_layer = {}; - MS_EXCEPTION_IF_NULL(compute_node); - auto op_name = AnfAlgo::GetCNodeName(compute_node); - auto ccompute_node = compute_node->cast(); - if (ccompute_node == nullptr) { - MS_LOG(INFO) << "Fusion error: fusion compute node must be cnode"; - return false; - } - MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope(); - for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) { - auto input = ccompute_node->input(i); - auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input); - if (find_iter != input_nodes.end()) { - MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope(); - layer.emplace_back((*find_iter)); - } else { - MS_LOG(INFO) << "Fusion warnig: this input [" << i << "] may be pre compute(" << input->fullname_with_scope() - << ") node's output."; - } - } - TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer); - if (need_spec) { - MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... 
patten."; - if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) { - return false; - } - } - input_size += reorder_layer.size(); - input_layers->emplace_back(reorder_layer); - } - if (input_nodes.size() != input_size) { - MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size - << ", input_node:" << input_nodes.size(); - return false; - } - return true; -} - -bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr &data_input, - const std::map &spec_data_input, - nlohmann::json *data_str, size_t *index) { - MS_EXCEPTION_IF_NULL(data_str); - MS_EXCEPTION_IF_NULL(index); - std::vector output_desc_list; - if (!data_input) { - MS_LOG(INFO) << "Data input is optional node"; - auto name = std::string(kOptional) + std::to_string(*index); - (*data_str)[kJName] = name; - nlohmann::json output_desc; - output_desc[kJName] = name; - output_desc[kJShape] = "NULL"; - output_desc_list.push_back(output_desc); - (*index)++; - } else { - FusionDataType fusion_data_type = kFusionNormal; - if (spec_data_input.find(data_input) != spec_data_input.end()) { - fusion_data_type = spec_data_input.at(data_input); - } - auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx; - // kJOutputDesc - nlohmann::json output_desc; - GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type); - output_desc_list.push_back(output_desc); - (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope()); - } - (*data_str)[kJOutputDesc] = output_desc_list; - (*data_str)[kJtype] = "Data"; - return true; -} - -bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input. 
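The comment above is the whole contract of IsDynamicInput. As a compilable reference, here is a minimal standalone sketch of the check it performs on kAttrDynInputSizes; the function name and parameters below are hypothetical illustrations, not code from the deleted file:

#include <cstddef>
#include <vector>

// Hedged sketch: assumes dyn_input_sizes carries one entry per declared op
// input, and that the fusion build only supports a single dynamic input
// that must cover every real input edge of the node.
bool IsDynamicInputSketch(const std::vector<int> &dyn_input_sizes, size_t real_input_size) {
  if (dyn_input_sizes.empty()) {
    return false;  // node has no kAttrDynInputSizes attr: not a dynamic-input node
  }
  if (dyn_input_sizes.size() != 1) {
    return false;  // mirrors "fusion build not support dyn_input_sizes > 1" above
  }
  // the single dynamic input must account for all real inputs
  return static_cast<size_t>(dyn_input_sizes[0]) == real_input_size;
}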
- bool ret = false; - std::vector dyn_input_sizes; - auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes); - if (dynamic_input_attr != nullptr) { - dyn_input_sizes = GetValue>(dynamic_input_attr); - auto real_input_size = cnode->inputs().size() - 1; - auto dyn_input_size = dyn_input_sizes.size(); - if (dyn_input_size != 1) { - MS_LOG(INFO) << "Fusion error: fusion build not support dyn_input_sizes > 1"; - return ret; - } - if (IntToSize(dyn_input_sizes[0]) != real_input_size) { - MS_LOG(INFO) << "Fusion error: dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size" - << real_input_size; - return ret; - } - ret = true; - } - return ret; -} - -size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) { - MS_EXCEPTION_IF_NULL(cnode); - if (is_dynamic_input) { - return 0; - } - MS_EXCEPTION_IF_NULL(cnode); - auto node_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = OpLib::FindOp(node_name, kTBE); - MS_EXCEPTION_IF_NULL(cnode); - if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { - MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); - } - return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); -} - -std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { - static std::map buffer_fussion_op_map = { - {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}}; - string result = origin_type; - auto iter = buffer_fussion_op_map.find(origin_type); - if (iter != buffer_fussion_op_map.end()) { - result = iter->second; - } - return result; -} - -bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(input_desc_list); - std::vector input_desc_list_tmp = {}; - bool is_dynamic_input = IsDynamicInput(cnode); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto input = cnode->input(i); - auto kernel_idx = AnfAlgo::VisitKernel(input, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - MS_LOG(INFO) << "Real name" << real_node->fullname_with_scope() << "index:" << real_idx; - nlohmann::json input_desc; - GenDescJson(real_node, real_idx, real_idx, &input_desc); - if (is_dynamic_input) { - MS_LOG(INFO) << "Node has dynamic input."; - input_desc[kJDynIndex] = (i - 1); - } - input_desc_list_tmp.emplace_back(input_desc); - } - size_t optional_num = GetOptionalInput(cnode, is_dynamic_input); - if (optional_num > 0) { - MS_LOG(INFO) << "Node has optional input."; - for (size_t i = 0; i < optional_num; ++i) { - nlohmann::json optional_input_desc; - optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index); - (*index)++; - (*layer_iter)->emplace_back(nullptr); - input_desc_list_tmp.emplace_back(optional_input_desc); - } - } - auto op_name = AnfAlgo::GetCNodeName(cnode); - TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list); - return true; -} - -std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &output_used_nums) { - std::vector desc_output_index = {}; - for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { - auto output_use_num_item = output_used_nums[idx]; - MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item; - desc_output_index.emplace_back(idx); - if (output_use_num_item > 1) { - desc_output_index.emplace_back(idx); - } - 
}
-  return desc_output_index;
-}
-
-bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
-                                                std::vector<nlohmann::json> *output_desc_list) {
-  MS_EXCEPTION_IF_NULL(output_desc_list);
-  auto output_size = AnfAlgo::GetOutputTensorNum(cnode);
-  if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) {
-    auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
-    MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope();
-    if (output_used_nums.size() != output_size) {
-      MS_LOG(INFO) << "Fusion error: output tensor num(" << output_size << ")"
-                   << " does not match output used num(" << output_used_nums.size() << ")";
-      return false;
-    }
-    auto desc_output_index = GetDescOutputIndex(output_used_nums);
-    for (size_t i = 0; i < output_size; ++i) {
-      MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i];
-      nlohmann::json output_desc;
-      GenDescJson(cnode, i, desc_output_index[i], &output_desc);
-      output_desc_list->emplace_back(output_desc);
-    }
-    for (size_t j = output_size; j < desc_output_index.size(); ++j) {
-      MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j];
-      nlohmann::json output_desc;
-      GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc);
-      output_desc_list->emplace_back(output_desc);
-    }
-  } else {
-    for (size_t i = 0; i < output_size; ++i) {
-      nlohmann::json output_desc;
-      GenDescJson(cnode, i, i, &output_desc);
-      output_desc_list->push_back(output_desc);
-    }
-  }
-  return true;
-}
-
-bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
-                                          std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
-                                          nlohmann::json *compute_op_str, std::string *fusion_kernel_name,
-                                          size_t *index) {
-  MS_EXCEPTION_IF_NULL(compute_node);
-  auto cnode = compute_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  // gen input desc
-  std::vector<nlohmann::json> input_desc_list;
-  (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index);
-  (*compute_op_str)[kJInputDesc] = input_desc_list;
-  // gen output desc
-  std::vector<nlohmann::json> output_desc_list;
-  if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) {
-    MS_LOG(INFO) << "Fusion Error: gen fusion output desc failed, node full name: " << cnode->fullname_with_scope();
-    return false;
-  }
-  (*compute_op_str)[kJOutputDesc] = output_desc_list;
-  // gen others
-  auto origin_type = AnfAlgo::GetCNodeName(cnode);
-  // replace special op type for buffer fusion op
-  auto type = GetRealOpType(origin_type);
-  (*compute_op_str)[kJtype] = type;
-  tbe::TbeAdapter::NormalizeFuncName(&type);
-  (*compute_op_str)[kJFuncName] = type;
-  (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope());
-  (void)(*fusion_kernel_name).append("_");
-  (void)(*fusion_kernel_name).append(type);
-  return true;
-}
-
-size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
-  size_t ret = 1;
-  for (const auto &shape_item : desc[kJShape]) {
-    ret *= static_cast<size_t>(shape_item);
-  }
-  std::string data_type = desc[kJDataType];
-  size_t nbyte = tbe::GetDtypeNbyte(data_type);
-  ret *= nbyte;
-  return ret;
-}
-
-bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list,
-                               const std::vector<mindspore::AnfNodePtr> &output_nodes,
-                               std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list) {
-  MS_EXCEPTION_IF_NULL(input_size_list);
-  MS_EXCEPTION_IF_NULL(output_size_list);
-  input_size_list->clear();
-  output_size_list->clear();
-
-  for (const auto &op : fusion_op_list) {
-    if (op[kJtype] == "Data") {
-      const auto &data_output_desc =
op[kJOutputDesc]; - for (const auto &data_output : data_output_desc) { - if (data_output[kJShape] == "NULL") { - break; - } - auto ret = GetIOSizeImpl(data_output); - input_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope input name: " << op[kJName] << ", size: " << ret; - } - } - } - - for (const auto &output_node : output_nodes) { - auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0); - auto real_node = kernel_idx.first; - size_t real_idx = kernel_idx.second; - auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope()); - MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx; - for (const auto &op : fusion_op_list) { - if (op[kJName] == normal_name) { - auto op_output_desces = op[kJOutputDesc]; - if (output_node != real_node) { - // tuple_get item - MS_LOG(INFO) << "Output is a tuple getitem node"; - auto output_desc = op_output_desces[real_idx]; - if (output_desc[kJShape].empty()) { - MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx; - return false; - } - auto ret = GetIOSizeImpl(output_desc); - output_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret; - } else { - for (const auto &output_desc : op_output_desces) { - if (output_desc[kJShape].empty()) { - MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output"; - continue; - } - auto ret = GetIOSizeImpl(output_desc); - output_size_list->push_back(ret); - MS_LOG(INFO) << "Fusion info: scope output size: " << ret; - } - } - } - } - } - return true; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h deleted file mode 100644 index eef02efa87..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "ir/dtype.h" -#include "kernel/kernel.h" -#include "pybind11/stl.h" -#include "kernel/oplib/oplib.h" -#include "kernel/tbe/tbe_adapter.h" - -namespace mindspore { -namespace kernel { -// kernel operate type used for generate json - -class TbeKernelBuild { - enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 }; - - public: - static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, - std::vector *output_size_list); - // Ub Fuison - static bool GenFusionScopeJson(const std::vector &input_nodes, - const std::vector &compute_nodes, nlohmann::json *fusion_str, - std::string *fusion_kernel); - static bool GetIOSize(const nlohmann::json &fusion_op_list, const std::vector &output_nodes, - std::vector *input_size_list, std::vector *output_size_list); - - private: - TbeKernelBuild() = default; - ~TbeKernelBuild() = default; - static bool GenFusionDataInputJson(const std::shared_ptr &data_input, - const std::map &spec_data_input, - nlohmann::json *data_str, size_t *index); - static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, - std::vector>::iterator *layer_iter, - nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); - static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index); - static std::vector GetDescOutputIndex(const std::vector &output_used_nums); - static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, - std::vector *output_desc_list); - static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, - size_t desc_output_idx, nlohmann::json *output_desc, - FusionDataType fusion_data_type = kFusionNormal); - static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, - size_t output_index, nlohmann::json *output_desc); - static size_t GetIOSizeImpl(const nlohmann::json &desc); - static bool GetSpecInputLayers(const std::string &op_name, const std::vector &reorder_layer, - std::map *spec_data_input); - static bool GetInputLayers(const std::vector &input_nodes, - const std::vector &compute_nodes, - std::vector> *input_layers, - std::map *spec_data_input); - static bool IsDynamicInput(const CNodePtr &cnode); - static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); - static std::string GetRealOpType(const std::string &origin_type); -}; - -class TbeKernelJsonCreator { - public: - explicit TbeKernelJsonCreator(kCreaterType creater_type = SINGLE_BUILD) : creater_type_(creater_type) {} - ~TbeKernelJsonCreator() = default; - bool GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json); - std::string json_name() { return json_name_; } - - private: - bool GenTbeInputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *inputs_json); - bool GenTbeOutputsJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *outputs_json); - bool GenTbeAttrJson(const std::shared_ptr &anf_node, const std::shared_ptr &op_info, - nlohmann::json *attrs_json); - static void ParseAttrValue(const std::string &type, const ValuePtr &value, nlohmann::json *attr_obj); - bool GenInputDescJson(const std::shared_ptr &anf_node, size_t real_input_index, bool value, - const 
std::shared_ptr &input_ptr, const string &op_input_name, size_t input_i, - std::vector *input_list); - bool GenOutputDescJson(const std::shared_ptr &anf_node, - const std::vector> &outputs_ptr, nlohmann::json *outputs_json); - bool GenInputList(const std::shared_ptr &anf_node, size_t input_tensor_num, - const std::shared_ptr &input_ptr, size_t *real_input_index, string *op_input_name, - std::vector *input_list); - void GenOutputList(const std::shared_ptr &anf_node, const size_t &output_obj_num, - const std::shared_ptr &output_ptr, size_t *output_idx, - std::vector *output_list); - std::vector GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const; - std::vector GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const; - std::string GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const; - - kCreaterType creater_type_; - std::string json_name_; - std::string json_info_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc deleted file mode 100644 index 9d5222659a..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_kernel_mod.h" -#include -#include "runtime/rt.h" -#include "utils/context/ms_context.h" -#include "graphengine/inc/framework/ge_runtime/task_info.h" - -namespace mindspore { -namespace kernel { -using TbeTaskInfoPtr = std::shared_ptr; -using tbe::KernelManager; -bool TbeKernelMod::Launch(const std::vector &inputs, - const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) { - if (stream_ptr == nullptr) { - MS_LOG(ERROR) << "stream_ptr should not be nullptr."; - return false; - } - - if (kernel_pack_ == nullptr) { - MS_LOG(ERROR) << "kernel pack should not be nullptr."; - return false; - } - - uint32_t blockdim = 1; // default blockdim equal to 1. - auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim); - if (func_stub == 0) { - MS_LOG(ERROR) << "GenFuncStub failed."; - return false; - } - - // pack all addresses into a vector. 
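The packing step that follows flattens the input, output, and workspace address lists into one contiguous void* array, since rtKernelLaunch receives a single args buffer. A self-contained sketch of that flattening; the Address/AddressPtr definitions and the function name below are simplified stand-ins, not the runtime's actual types:

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <memory>
#include <vector>

// Simplified stand-ins for the kernel address types used above.
struct Address {
  void *addr = nullptr;
  size_t size = 0;
};
using AddressPtr = std::shared_ptr<Address>;

// Flatten inputs, then outputs, then workspaces: the order in which the
// generated TBE kernel expects its runtime arguments.
std::vector<void *> PackRuntimeArgs(const std::vector<AddressPtr> &inputs,
                                    const std::vector<AddressPtr> &outputs,
                                    const std::vector<AddressPtr> &workspaces) {
  std::vector<void *> args;
  auto append = [&args](const std::vector<AddressPtr> &list) {
    (void)std::transform(list.begin(), list.end(), std::back_inserter(args),
                         [](const AddressPtr &item) -> void * { return item->addr; });
  };
  append(inputs);
  append(outputs);
  append(workspaces);
  return args;
}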
- std::vector runtimeargs; - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs), - [](const AddressPtr &output) -> void * { return output->addr; }); - if (!workspace.empty()) { - (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs), - [](const AddressPtr &addr) -> void * { return addr->addr; }); - } - rtL2Ctrl_t *l2ctrl = nullptr; - const void *stubFunc = reinterpret_cast(func_stub); - auto argsSize = static_cast(UlongToUint(sizeof(void *)) * runtimeargs.size()); - if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) { - MS_LOG(ERROR) << "Call runtime rtKernelLaunch error."; - return false; - } - - return true; -} - -std::vector TbeKernelMod::GenTask(const std::vector &inputs, - const std::vector &workspaces, - const std::vector &outputs, uint32_t stream_id) { - if (kernel_pack_ == nullptr) { - MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr."; - } - - std::vector args; - std::vector sm_desc; - std::vector meta_data; - std::vector input_data_addrs; - std::vector output_data_addrs; - std::vector workspace_addrs; - - // pack all addresses into a vector. - (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs), - [](const AddressPtr &input) -> void * { return input->addr; }); - (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs), - [](const AddressPtr &output) -> void * { return output->addr; }); - if (!workspaces.empty()) { - (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs), - [](const AddressPtr &workspace) -> void * { return workspace->addr; }); - } - - stream_id_ = stream_id; - auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim_); - if (funcstub == 0) { - MS_EXCEPTION(ArgumentError) << "GenFuncStub failed."; - } - - std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_); - - MS_LOG(INFO) << "block_dim is:" << block_dim_; - - TbeTaskInfoPtr task_info_ptr = make_shared( - kernel_name_, stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0, meta_data, input_data_addrs, - output_data_addrs, workspace_addrs, NeedDump()); - return {task_info_ptr}; -} - -vector TbeKernelMod::GenParameters() { - auto kernel_json_info = kernel_pack_->kernel_json_info(); - return kernel_json_info.parameters; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h deleted file mode 100644 index e0e7ab4646..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_mod.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ - -#include -#include -#include -#include -#include "kernel/ascend_kernel_mod.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -class TbeKernelMod : public AscendKernelMod { - public: - explicit TbeKernelMod(KernelPackPtr kernel_pack) : kernel_pack_(std::move(kernel_pack)) {} - ~TbeKernelMod() override = default; - - void SetInputSizeList(const std::vector &size_list) { input_size_list_ = size_list; } - void SetOutputSizeList(const std::vector &size_list) { output_size_list_ = size_list; } - void SetWorkspaceSizeList(const std::vector &size_list) { workspace_size_list_ = size_list; } - const std::vector &GetInputSizeList() const override { return input_size_list_; } - const std::vector &GetOutputSizeList() const override { return output_size_list_; } - const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs, void *stream_ptr) override; - std::vector GenTask(const std::vector &inputs, const std::vector &workspaces, - const std::vector &outputs, uint32_t stream_id) override; - std::vector GenParameters() override; - - private: - KernelPackPtr kernel_pack_; - std::vector input_size_list_; - std::vector output_size_list_; - std::vector workspace_size_list_; -}; - -using TbeKernelModPtr = std::shared_ptr; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc deleted file mode 100644 index 43d492f397..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.cc +++ /dev/null @@ -1,326 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "kernel/tbe/tbe_kernel_parallel_build.h" - -#include -#include -#include -#include -#include -#include - -#include "utils/context/ms_context.h" -#include "kernel/tbe/tbe_adapter.h" -#include "kernel/tbe/tbe_kernel_build.h" -#include "kernel/tbe/tbe_kernel_mod.h" -#include "session/anf_runtime_algorithm.h" -#include "./common.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "kernel/tbe/tbe_utils.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kParallelCompileModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateParallelCompiler = "create_tbe_parallel_compiler"; -constexpr auto kStartCompileOp = "start_compile_op"; -constexpr auto kWaitOne = "wait_one"; -constexpr auto kResetTaskInfo = "reset_task_info"; - -bool TbeOpParallelPreBuild(const std::vector &anf_nodes) { - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - for (const auto &anf_node : anf_nodes) { - // gen kernel json - MS_EXCEPTION_IF_NULL(anf_node); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(OP_PRE_COMPILE); - if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - kernel_json["compile_type"] = "pre_build"; - // op build - auto task_id = build_manger->StartCompileOp(kernel_json); - build_manger->SavePreTaskInfo(task_id, anf_node); - } - while (!build_manger->IsAllPreTaskFinish()) { - int task_id = -1; - char *task_result = nullptr; - char *pre_build_result = nullptr; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); - if (!ret) { - MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id; - } - - if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { - MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result; - } - - build_manger->PreTaskFinishProcess(task_id, pre_build_result); - } - return true; -} - -bool TbeOpParallelBuild(const std::vector &anf_nodes) { - auto build_manger = std::make_shared(); - MS_EXCEPTION_IF_NULL(build_manger); - set processed_kernel; - for (const auto &anf_node : anf_nodes) { - // gen kernel json - tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node); - if (AnfAlgo::GetKernelMod(anf_node) != nullptr) { - continue; - } - const std::string &processor = tbe::GetProcessor(anf_node); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(SINGLE_BUILD); - if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) { - MS_LOG(ERROR) << "GenTbeSingleKernelJson failed"; - return false; - } - // get size - std::vector input_size_list; - std::vector output_size_list; - (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); - // search cache - const std::string &json_name = creator.json_name(); - if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { - MS_LOG(INFO) << "Use cached kernel, kernel json name:." 
-bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
-  auto build_manger = std::make_shared<ParallelBuildManager>();
-  MS_EXCEPTION_IF_NULL(build_manger);
-  std::set<std::string> processed_kernel;
-  for (const auto &anf_node : anf_nodes) {
-    // gen kernel json
-    tbe::TbeAdapter::SetTbeAttrsForTransDataOp(anf_node);
-    if (AnfAlgo::GetKernelMod(anf_node) != nullptr) {
-      continue;
-    }
-    const std::string &processor = tbe::GetProcessor(anf_node);
-    nlohmann::json kernel_json;
-    TbeKernelJsonCreator creator(SINGLE_BUILD);
-    if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) {
-      MS_LOG(ERROR) << "GenTbeSingleKernelJson failed";
-      return false;
-    }
-    // get size
-    std::vector<size_t> input_size_list;
-    std::vector<size_t> output_size_list;
-    (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
-    // search cache
-    const std::string &json_name = creator.json_name();
-    if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) {
-      MS_LOG(INFO) << "Use cached kernel, kernel json name: " << json_name;
-      continue;
-    }
-    // The same op does not need a second build, but it must wait for the running build to finish
-    // so that its kernel mod can be set.
-    if (processed_kernel.find(json_name) != processed_kernel.end()) {
-      build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list);
-      continue;
-    }
-    (void)processed_kernel.insert(json_name);
-    // op build
-    auto task_id = build_manger->StartCompileOp(kernel_json);
-    build_manger->SaveTaskInfo(task_id, anf_node, json_name, input_size_list, output_size_list);
-  }
-  while (!build_manger->IsAllTaskFinish()) {
-    int task_id = -1;
-    char *task_result = nullptr;
-    char *pre_build_result = nullptr;
-    auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
-    if (!ret) {
-      MS_EXCEPTION(ArgumentError) << "Build failed, wait_one ret: " << ret << ", task id: " << task_id;
-    }
-
-    if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
-      MS_EXCEPTION(ArgumentError) << "Task compile failed, task id: " << task_id << ", cause: " << task_result;
-    }
-    (void)build_manger->TaskFinishProcess(task_id);
-  }
-  return build_manger->GenSameOpKernelMod();
-}
-
-ParallelBuildManager::ParallelBuildManager() { tbe_parallel_compiler_ = TbePythonFuncs::TbeParallelCompiler(); }
-
-ParallelBuildManager::~ParallelBuildManager() { ResetTaskInfo(); }
-
-int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) const {
-  PyObject *pRes = nullptr;
-  PyObject *pArgs = PyTuple_New(1);
-  std::string json_str = kernel_json.dump();
-  PyObject *arg1 = Py_BuildValue("s", json_str.c_str());
-  (void)PyTuple_SetItem(pArgs, 0, arg1);
-  pRes = PyObject_CallMethod(tbe_parallel_compiler_, kStartCompileOp, "O", pArgs);
-  if (pRes == nullptr) {
-    PyErr_Print();
-    MS_EXCEPTION(ArgumentError) << "Failed to call function start_compile_op";
-  }
-  int task_id;
-  (void)PyArg_Parse(pRes, "i", &task_id);
-  MS_LOG(INFO) << "Start compile, task id: " << task_id;
-  return task_id;
-}
-
-bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const {
-  MS_LOG(INFO) << "Wait task start.";
-  MS_EXCEPTION_IF_NULL(task_id);
-  MS_EXCEPTION_IF_NULL(task_result);
-  PyObject *pRes = nullptr;
-  PyObject *pArg = Py_BuildValue("()");
-  pRes = PyObject_CallMethod(tbe_parallel_compiler_, kWaitOne, "O", pArg);
-  if (pRes == nullptr) {
-    PyErr_Print();
-    MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one";
-    return false;
-  }
-  (void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result);
-  return true;
-}
-
-void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) {
-  MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id;
-  pre_task_map_[task_id] = anf_node;
-}
-
-void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node,
-                                        const std::string &json_name, const std::vector<size_t> &input_size_list,
-                                        const std::vector<size_t> &output_size_list, int32_t scope_id) {
-  MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id;
-  struct KernelBuildTaskInfo task_info;
-  task_info.node = anf_node.get();
-  task_info.json_name = json_name;
-  if (anf_node == nullptr) {
-    task_info.processor = tbe::kProcessorAiCore;
-  } else {
-    task_info.processor = tbe::GetProcessor(anf_node);
-  }
-  task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end());
-  task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end());
-  task_info.scope_id = scope_id;
-  task_map_[task_id] = task_info;
-}
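// Hedged sketch of the per-task lifecycle these helpers implement, assuming a
// prepared `kernel_json` and node bookkeeping values (`node`, `json_name`,
// `in_sizes`, `out_sizes`) from the caller: one request is pushed to the
// Python-side compiler pool, then finished tasks are drained one at a time.
ParallelBuildManager mgr;
int32_t id = mgr.StartCompileOp(kernel_json);                // enqueue on the Python compiler pool
mgr.SaveTaskInfo(id, node, json_name, in_sizes, out_sizes);  // remember bookkeeping for this task
while (!mgr.IsAllTaskFinish()) {
  int task_id = -1;
  char *result = nullptr;
  char *pre_result = nullptr;
  (void)mgr.WaitOne(&task_id, &result, &pre_result);  // blocks until one task reports back
  (void)mgr.TaskFinishProcess(task_id);               // cache the kernel pack and attach a KernelMod
}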
-bool ParallelBuildManager::IsAllPreTaskFinish() const {
-  MS_LOG(INFO) << "Wait pre build process, remaining task_num: " << pre_task_map_.size();
-  return pre_task_map_.empty();
-}
-
-bool ParallelBuildManager::IsAllTaskFinish() const {
-  MS_LOG(INFO) << "Wait process, remaining task_num: " << task_map_.size();
-  return task_map_.empty();
-}
-
-void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) {
-  auto task_iter = pre_task_map_.find(task_id);
-  if (task_iter == pre_task_map_.end()) {
-    MS_EXCEPTION(ArgumentError) << "Can't find pre task_id: " << task_id;
-  }
-  auto node = task_iter->second;
-  auto builder =
-    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
-  std::string start_flag = "fusion_pattern_start";
-  std::string end_flag = "fusion_pattern_end";
-  auto start = pre_build_result.find(start_flag);
-  auto end = pre_build_result.find(end_flag);
-  if (start != std::string::npos && end != std::string::npos && end >= start) {
-    std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size());
-    if (result == "") {
-      (void)pre_task_map_.erase(task_iter);
-      return;
-    }
-    (void)transform(result.begin(), result.end(), result.begin(), ::toupper);
-    FusionType fusion_type = tbe::GetFusionType(result);
-    builder->SetFusionType(fusion_type);
-    AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
-  }
-  (void)pre_task_map_.erase(task_iter);
-}
-
-std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) {
-  auto task_iter = task_map_.find(task_id);
-  if (task_iter == task_map_.end()) {
-    MS_EXCEPTION(ArgumentError) << "Can't find task_id: " << task_id;
-  }
-  auto json_name = task_iter->second.json_name;
-  auto processor = task_iter->second.processor;
-  auto kernel_pack = TbeUtils::InsertCache(json_name, processor);
-  if (kernel_pack == nullptr) {
-    if (set_kernel_mod) {
-      MS_EXCEPTION(ArgumentError) << "Build kernel name: " << task_iter->second.json_name << " failed.";
-    } else {
-      MS_LOG(INFO) << "Fusion build kernel name: " << task_iter->second.json_name << " failed.";
-      auto ret = std::make_pair(task_iter->second.scope_id, nullptr);
-      (void)task_map_.erase(task_iter);
-      return ret;
-    }
-  }
-  auto kernel_mod = GenKernelMod(json_name, processor, task_iter->second.input_size_list,
-                                 task_iter->second.output_size_list, kernel_pack);
-  MS_EXCEPTION_IF_NULL(kernel_mod);
-  if (set_kernel_mod) {
-    AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
-  }
-  auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
-  (void)task_map_.erase(task_iter);
-  MS_LOG(INFO) << "Wait process, remaining task_num: " << task_map_.size();
-  return ret;
-}
-
-void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, const std::string &json_name,
-                                          const std::vector<size_t> &input_size_list,
-                                          const std::vector<size_t> &output_size_list) {
-  struct KernelBuildTaskInfo task_info;
-  task_info.node = anf_node.get();
-  task_info.json_name = json_name;
-  task_info.processor = tbe::GetProcessor(anf_node);
-  task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end());
-  task_info.output_size_list.assign(output_size_list.begin(), output_size_list.end());
-  same_op_list_.push_back(task_info);
-}
-
-bool ParallelBuildManager::GenSameOpKernelMod() const {
-  for (const auto &task_info : same_op_list_) {
-    bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list,
-                             task_info.output_size_list, task_info.node);
-    if (!ret) {
-      MS_LOG(INFO) << "Can't find " << task_info.json_name << " in cache.";
-      return false;
-    }
-  }
-  return true;
-}
-
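// Worked example of the marker parsing in PreTaskFinishProcess (a sketch; the
// exact blob format comes from the Python-side pre-build and is assumed here):
std::string pre_build_result = "...fusion_pattern_startelemwisefusion_pattern_end...";
// find() locates the two markers, substr() extracts "elemwise", which is then
// upper-cased to "ELEMWISE" and mapped to a FusionType via tbe::GetFusionType().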
-bool ParallelBuildManager::SearchInCache(const std::string &json_name, const std::string &processor,
-                                         const std::vector<size_t> &input_size_list,
-                                         const std::vector<size_t> &output_size_list, mindspore::AnfNode *node) const {
-  auto cached_kernel_pack = TbeUtils::SearchCache(json_name, processor);
-  if (cached_kernel_pack != nullptr) {
-    MS_LOG(INFO) << "Find cached kernel, kernel json name: " << json_name;
-    auto kernel_mod_ptr = GenKernelMod(json_name, processor, input_size_list, output_size_list, cached_kernel_pack);
-    MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
-    AnfAlgo::SetKernelMod(kernel_mod_ptr, node);
-    return true;
-  } else {
-    return false;
-  }
-}
-
-KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const string &processor,
-                                                const vector<size_t> &input_size_list,
-                                                const vector<size_t> &output_size_list,
-                                                const mindspore::kernel::KernelPackPtr &kernel_pack) const {
-  MS_EXCEPTION_IF_NULL(kernel_pack);
-  auto kernel_json_info = kernel_pack->kernel_json_info();
-  auto kernel_mod_ptr = std::make_shared<TbeKernelMod>(kernel_pack);
-  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
-  kernel_mod_ptr->SetInputSizeList(input_size_list);
-  kernel_mod_ptr->SetOutputSizeList(output_size_list);
-  kernel_mod_ptr->SetWorkspaceSizeList(kernel_json_info.workspaces);
-  return kernel_mod_ptr;
-}
-
-void ParallelBuildManager::ResetTaskInfo() {
-  if (task_map_.empty()) {
-    MS_LOG(INFO) << "All tasks compiled successfully.";
-    return;
-  }
-  task_map_.clear();
-  same_op_list_.clear();
-  if (tbe_parallel_compiler_ != nullptr) {
-    PyObject *pArg = Py_BuildValue("()");
-    (void)PyObject_CallMethod(tbe_parallel_compiler_, kResetTaskInfo, "O", pArg);
-  }
-}
-}  // namespace kernel
-}  // namespace mindspore
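// Hedged sketch of how a built or cached kernel gets attached to a node: the
// input/output size lists come from the build task, while workspace sizes come
// from the compiled kernel's JSON info. Names mirror the deleted code; the
// values (`mgr`, `json_name`, `processor`, `in_sizes`, `out_sizes`,
// `kernel_pack`, `node`) are illustrative.
auto kernel_mod = mgr.GenKernelMod(json_name, processor, in_sizes, out_sizes, kernel_pack);
AnfAlgo::SetKernelMod(kernel_mod, node.get());  // the graph node now owns a launchable TbeKernelMod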
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h
deleted file mode 100644
index 637c03bce3..0000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_parallel_build.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_
-#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_
-
-#include <map>
-#include <string>
-#include <utility>
-#include <vector>
-#include "kernel/kernel.h"
-#include "pybind11/stl.h"
-#include <nlohmann/json.hpp>
-namespace mindspore {
-namespace kernel {
-bool TbeOpParallelPreBuild(const std::vector<AnfNodePtr> &anf_nodes);
-bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);
-
-struct KernelBuildTaskInfo {
-  AnfNode *node;
-  std::string processor;
-  std::string json_name;
-  std::vector<size_t> input_size_list;
-  std::vector<size_t> output_size_list;
-  int32_t scope_id;
-};
-
-class ParallelBuildManager {
- public:
-  ParallelBuildManager();
-  ~ParallelBuildManager();
-  int32_t StartCompileOp(const nlohmann::json &kernel_json) const;
-  void SavePreTaskInfo(int32_t task_id, const AnfNodePtr &anf_node);
-  void SaveTaskInfo(int32_t task_id, const AnfNodePtr &anf_node, const std::string &json_name,
-                    const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
-                    int32_t scope_id = 0);
-  void SaveSameOpInfo(const AnfNodePtr &anf_node, const std::string &json_name,
-                      const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list);
-  bool GenSameOpKernelMod() const;
-  bool SearchInCache(const std::string &json_name, const std::string &processor,
-                     const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
-                     AnfNode *node) const;
-
-  bool WaitOne(int *task_id, char **task_result, char **pre_build_result) const;
-  bool IsAllPreTaskFinish() const;
-  bool IsAllTaskFinish() const;
-  void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result);
-  std::pair<int32_t, KernelModPtr> TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true);
-  KernelModPtr GenKernelMod(const string &json_name, const string &processor,
-                            const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
-                            const KernelPackPtr &kernel_pack) const;
-  void ResetTaskInfo();
-
- private:
-  PyObject *tbe_parallel_compiler_;
-  std::map<int32_t, AnfNodePtr> pre_task_map_;
-  std::map<int32_t, KernelBuildTaskInfo> task_map_;
-  std::vector<KernelBuildTaskInfo> same_op_list_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_TBE_TBE_KERNEL_PARALLEL_BUILD_H_
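// Hedged illustration of the payload StartCompileOp ships to the Python pool.
// The real schema is produced by TbeKernelJsonCreator; only "compile_type" is
// taken from the deleted sources, the other fields are assumptions.
nlohmann::json kernel_json;
kernel_json["op_info"] = {{"name", "Add"}, {"kernel_name", "add_fp16"}};  // illustrative only
kernel_json["compile_type"] = "pre_build";  // set by TbeOpParallelPreBuild for pre-compile tasks
int32_t task_id = build_manger->StartCompileOp(kernel_json);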
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
deleted file mode 100644
index 8050f02f95..0000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
+++ /dev/null
@@ -1,318 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
-#include "utils/utils.h"
-#include "session/anf_runtime_algorithm.h"
-#include "kernel/tbe/tbe_kernel_select/common_utils.h"
-
-namespace mindspore {
-namespace kernel {
-constexpr size_t kInputIndex_0 = 0;
-constexpr size_t kChannelN = 0;
-constexpr size_t kChannelC = 1;
-constexpr size_t kAlignmented16 = 16;
-// Broadcast support falls into three cases:
-// 1. no input is scalar and all shapes are the same
-// 2. partly scalar: the non-scalar inputs qualify (shape size > xxx && aligned to xxx)
-// 3. no input is scalar but the shapes differ (broadcast on xxx dim)
-bool TbeKernelBroadCastSelecter::GetShapeInfo(SupportFormat *support_format) {
-  MS_EXCEPTION_IF_NULL(support_format);
-  input_num_ = 0;
-  output_num_ = 0;
-  input_shapes_.clear();
-  output_shapes_.clear();
-  if (AnfAlgo::HasNodeAttr(kAttrDynInputSizes, cnode_ptr_)) {
-    MS_LOG(INFO) << "This broadcast node has dynamic input.";
-    auto dynamic_size_vec = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kAttrDynInputSizes);
-    if (dynamic_size_vec.empty() || dynamic_size_vec[0] < 2) {
-      MS_LOG(EXCEPTION) << "The dynamic input attr is set incorrectly, please check.";
-    }
-    auto dynamic_input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0);
-    PadScalarShape(&dynamic_input_shape0_);
-    input_shapes_.emplace_back(dynamic_input_shape0_);
-    input_num_ = 1;
-  } else {
-    input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
-    for (size_t i = 0; i < input_num_; ++i) {
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
-      PadScalarShape(&input_shape);
-      input_shapes_.emplace_back(input_shape);
-    }
-  }
-
-  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
-  for (size_t i = 0; i < output_num_; ++i) {
-    auto output = AnfAlgo::GetOutputInferShape(cnode_ptr_, i);
-    PadScalarShape(&output);
-    output_shapes_.emplace_back(output);
-  }
-  AssignSupportFormat(kOpFormat_DEFAULT, support_format);
-  return true;
-}
-
-bool TbeKernelBroadCastSelecter::IsBroadCastSupport5HD(SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  if (IsSameShape()) {
-    if (!HasScalarInput()) {
-      AssignSupportFormat(kOpFormat_NC1HWC0, support_format);
-      return true;
-    } else {
-      return false;
-    }
-  }
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  if (HasScalarInput()) {
-    for (const auto &shape : input_shapes_) {
-      if (IsScalarShape(shape)) {
-        input_support_format.emplace_back(kOpFormat_DEFAULT);
-      } else {
-        if (!Is4DShape(shape)) {
-          return false;
-        }
-        if (shape[kChannelC] % kAlignmented16 != 0) {
-          return false;
-        }
-        input_support_format.emplace_back(kOpFormat_NC1HWC0);
-      }
-    }
-  } else {
-    for (const auto &shape : input_shapes_) {
-      if (!Is4DShape(shape)) {
-        return false;
-      }
-    }
-    auto shape_tmp = input_shapes_[0];
-    auto broadcast_c_axis = std::any_of(
-      input_shapes_.begin(), input_shapes_.end(),
-      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
-    if (broadcast_c_axis) {
-      MS_LOG(INFO) << "This node broadcasts on the C channel.";
-      return false;
-    }
-    input_support_format.assign(input_num_, kOpFormat_NC1HWC0);
-  }
-  GenOutputSupportFormat(kOpFormat_NC1HWC0, &output_support_format);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-  return true;
-}
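// Worked example for the 5HD check above (illustrative shapes): a scalar input
// stays in DefaultFormat, while a 4D input qualifies for NC1HWC0 only when its
// C axis is 16-aligned.
std::vector<size_t> ok_shape = {2, 32, 4, 4};   // C = 32 is 16-aligned, so NC1HWC0 is offered
std::vector<size_t> bad_shape = {2, 20, 4, 4};  // C = 20 is not aligned, so 5HD is rejected
bool ok = ok_shape.size() == 4 && ok_shape[kChannelC] % kAlignmented16 == 0;  // true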
-bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracZ(SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  if (IsSameShape()) {
-    if (!HasScalarInput()) {
-      AssignSupportFormat(kOpFormat_FRAC_Z, support_format);
-      return true;
-    } else {
-      return false;
-    }
-  }
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  if (HasScalarInput()) {
-    for (const auto &shape : input_shapes_) {
-      if (IsScalarShape(shape)) {
-        input_support_format.emplace_back(kOpFormat_DEFAULT);
-      } else {
-        if (!Is4DShape(shape)) {
-          return false;
-        }
-        if (shape[kChannelN] % kAlignmented16 != 0 || shape[kChannelC] % kAlignmented16 != 0) {
-          return false;
-        }
-        input_support_format.emplace_back(kOpFormat_FRAC_Z);
-      }
-    }
-  } else {
-    return false;
-  }
-  GenOutputSupportFormat(kOpFormat_FRAC_Z, &output_support_format);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-  return true;
-}
-
-bool TbeKernelBroadCastSelecter::IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  if (IsSameShape()) {
-    if (!HasScalarInput()) {
-      AssignSupportFormat(kOpFormat_C1HWNCoC0, support_format);
-      return true;
-    } else {
-      return false;
-    }
-  }
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  if (HasScalarInput()) {
-    for (const auto &shape : input_shapes_) {
-      if (IsScalarShape(shape)) {
-        input_support_format.emplace_back(kOpFormat_DEFAULT);
-      } else {
-        if (!Is4DShape(shape)) {
-          return false;
-        }
-        if (shape[kChannelN] % kAlignmented16 != 0) {
-          return false;
-        }
-        input_support_format.emplace_back(kOpFormat_C1HWNCoC0);
-      }
-    }
-  } else {
-    for (const auto &shape : input_shapes_) {
-      if (!Is4DShape(shape)) {
-        return false;
-      }
-    }
-    auto shape_tmp = input_shapes_[0];
-    auto broadcast_nc_axis =
-      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
-        return (shape_tmp.at(kChannelC) != elem.at(kChannelC) || shape_tmp.at(kChannelN) != elem.at(kChannelN));
-      });
-    if (broadcast_nc_axis) {
-      MS_LOG(INFO) << "This node broadcasts on the N or C channel.";
-      return false;
-    }
-    input_support_format.assign(input_num_, kOpFormat_C1HWNCoC0);
-  }
-  GenOutputSupportFormat(kOpFormat_C1HWNCoC0, &output_support_format);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-  return true;
-}
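// Hedged note on why FRAC_Z and C1HWNCoC0 demand 16-aligned axes: both are
// fractal layouts tiled in 16x16 blocks, so an axis that does not divide by 16
// cannot be re-tiled without padding.
std::vector<size_t> weight_shape = {32, 48, 3, 3};
bool frac_z_ok = weight_shape[kChannelN] % kAlignmented16 == 0 &&
                 weight_shape[kChannelC] % kAlignmented16 == 0;  // true: 32 and 48 divide by 16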
-bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  if (IsSameShape()) {
-    if (!HasScalarInput()) {
-      AssignSupportFormat(kOpFormat_FRAC_NZ, support_format);
-      return true;
-    } else {
-      return false;
-    }
-  }
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  if (HasScalarInput()) {
-    for (const auto &shape : input_shapes_) {
-      if (IsScalarShape(shape)) {
-        input_support_format.emplace_back(kOpFormat_DEFAULT);
-      } else {
-        if (shape.size() < kShape2dDims) {
-          return false;
-        }
-        if (shape[shape.size() - 1] % kAlignmented16 != 0 || shape[shape.size() - 2] % kAlignmented16 != 0) {
-          return false;
-        }
-        input_support_format.emplace_back(kOpFormat_FRAC_NZ);
-      }
-    }
-  } else {
-    auto less_2dims = std::any_of(input_shapes_.begin(), input_shapes_.end(),
                                  [](const std::vector<size_t> &elem) { return elem.size() < kShape2dDims; });
-    if (less_2dims) {
-      MS_LOG(INFO) << "This node has an input with fewer than 2 dims.";
-      return false;
-    }
-
-    auto shape_tmp = input_shapes_[0];
-    auto broadcast_last_dim =
-      std::any_of(input_shapes_.begin(), input_shapes_.end(), [&shape_tmp](const std::vector<size_t> &elem) {
-        return (shape_tmp.at(shape_tmp.size() - 1) != elem.at(elem.size() - 1)) ||
-               (shape_tmp.at(shape_tmp.size() - 2) != elem.at(elem.size() - 2));
-      });
-    if (broadcast_last_dim) {
-      MS_LOG(INFO) << "This node broadcasts on the last two dims.";
-      return false;
-    }
-
-    input_support_format.assign(input_num_, kOpFormat_FRAC_NZ);
-  }
-  GenOutputSupportFormat(kOpFormat_FRAC_NZ, &output_support_format);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-  return true;
-}
-
-bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  return false;
-}
-
-bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
-  return shape.size() == kShape4dDims;
-}
-
-bool TbeKernelBroadCastSelecter::IsSameShape() const {
-  auto shape = input_shapes_.begin();
-  for (const auto &item : input_shapes_) {
-    if (shape->size() != item.size()) {
-      return false;
-    }
-    for (size_t i = 0; i < shape->size(); ++i) {
-      if (shape->at(i) != item.at(i)) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
-void TbeKernelBroadCastSelecter::PadScalarShape(std::vector<size_t> *shape) const {
-  MS_EXCEPTION_IF_NULL(shape);
-  if (shape->empty()) {
-    shape->emplace_back(1);
-  }
-}
-
-bool TbeKernelBroadCastSelecter::IsScalarShape(const std::vector<size_t> &shape) const {
-  return (shape.size() == 1 && shape[0] == 1);
-}
-
-bool TbeKernelBroadCastSelecter::HasScalarInput() const {
-  bool ret = false;
-  for (const auto &shape : input_shapes_) {
-    if (IsScalarShape(shape)) {
-      ret = true;
-      break;
-    }
-  }
-  return ret;
-}
-
-void TbeKernelBroadCastSelecter::GenOutputSupportFormat(const std::string &support_format,
-                                                        SupportFormatItem *output_support_item) const {
-  MS_EXCEPTION_IF_NULL(output_support_item);
-  for (const auto &shape : output_shapes_) {
-    if (IsScalarShape(shape)) {
-      output_support_item->emplace_back(kOpFormat_DEFAULT);
-    } else {
-      output_support_item->emplace_back(support_format);
-    }
-  }
-}
-
-void TbeKernelBroadCastSelecter::AssignSupportFormat(const std::string &support_format_str,
-                                                     mindspore::kernel::SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  input_support_format.assign(input_num_, support_format_str);
-  output_support_format.assign(output_num_, support_format_str);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
deleted file mode 100644
index af711ddf29..0000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
-#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
-
-#include <string>
-#include <utility>
-#include <vector>
-#include "ir/anf.h"
-#include "kernel/tbe/tbe_kernel_select/common_utils.h"
-
-namespace mindspore {
-namespace kernel {
-class TbeKernelBroadCastSelecter {
- public:
-  explicit TbeKernelBroadCastSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {}
-  ~TbeKernelBroadCastSelecter() = default;
-  bool GetShapeInfo(SupportFormat *support_format);
-  bool IsBroadCastSupport5HD(SupportFormat *support_format) const;
-  bool IsBroadCastSupportFracZ(SupportFormat *support_format) const;
-  bool IsBroadCastSupportC1HWNCoC0(SupportFormat *support_format) const;
-  bool IsBroadCastSupportFracNZ(SupportFormat *support_format) const;
-  bool IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const;
-
- private:
-  bool IsSameShape() const;
-  void PadScalarShape(std::vector<size_t> *shape) const;
-  bool Is4DShape(const std::vector<size_t> &shape) const;
-  bool IsScalarShape(const std::vector<size_t> &shape) const;
-  bool HasScalarInput() const;
-  void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
-  void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
-  // broadcast
-  CNodePtr cnode_ptr_;
-  size_t input_num_{};
-  size_t output_num_{};
-  std::vector<std::vector<size_t>> input_shapes_;
-  std::vector<std::vector<size_t>> output_shapes_;
-};
-}  // namespace kernel
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_BROADCAST_SELECTER_H_
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
deleted file mode 100644
index 84f3fc29e3..0000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" -#include -#include -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace kernel { -constexpr size_t kInputIndex_0 = 0; -constexpr size_t kOutputIndex_0 = 0; -constexpr size_t kChannelN = 0; -constexpr size_t kChannelC = 1; -constexpr size_t kReduceNZMinDim = 3; - -bool TbeKernelReduceSelecter::GetShapeInfo(SupportFormat *support_format) { - MS_EXCEPTION_IF_NULL(support_format); - input_shape_.clear(); - output_shape_.clear(); - axis_.clear(); - auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); - auto output_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_); - if (input_num != 1 || output_num != 1) { - MS_LOG(EXCEPTION) << "Reduce operator only support one input/output, input num: " << input_num - << ", output num: " << output_num; - } - // get input/output shape - input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, kInputIndex_0); - PadScalarShape(&input_shape_); - output_shape_ = AnfAlgo::GetOutputInferShape(cnode_ptr_, kOutputIndex_0); - PadScalarShape(&output_shape_); - // get keep dim attr - GetReduceAttrKeepDim(); - // get axis attr - axis_ = GetReduceAttrAxis(cnode_ptr_); - AssignSupportFormat(kOpFormat_DEFAULT, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (!Is4DShape(input_shape_)) { - return false; - } - if (!keep_dims_ || axis_.empty()) { - return false; - } - auto reduce_c_axis = std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); }); - if (reduce_c_axis) { - return false; - } - AssignSupportFormat(kOpFormat_NC1HWC0, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - // like to 5HD - return false; -} - -bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const { - return IsFracZAndC1HWNCoC0Common(kOpFormat_FRAC_Z, support_format); -} - -bool TbeKernelReduceSelecter::IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const { - return IsFracZAndC1HWNCoC0Common(kOpFormat_C1HWNCoC0, support_format); -} - -bool TbeKernelReduceSelecter::IsReduceSupportFracNZ(SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (input_shape_.size() < kReduceNZMinDim) { - return false; - } - if (axis_.empty()) { - return false; - } - auto reduce_last_axis = std::any_of(axis_.begin(), axis_.end(), [this](const size_t &elem) { - return (elem == (this->input_shape_.size() - 1) || elem == (this->input_shape_.size() - 2)); - }); - if (reduce_last_axis) { - return false; - } - AssignSupportFormat(kOpFormat_FRAC_NZ, support_format); - return true; -} - -bool TbeKernelReduceSelecter::IsFracZAndC1HWNCoC0Common(const std::string &format, - mindspore::kernel::SupportFormat *support_format) const { - MS_EXCEPTION_IF_NULL(support_format); - if (!Is4DShape(input_shape_)) { - return false; - } - if (!keep_dims_ || axis_.empty()) { - return false; - } - auto reduce_n_c_axis = std::any_of(axis_.begin(), axis_.end(), - [](const size_t &elem) { return (elem == kChannelC || elem == kChannelN); }); - if (reduce_n_c_axis) { - return false; - } - AssignSupportFormat(format, support_format); - return true; -} - -void 
TbeKernelReduceSelecter::GetReduceAttrKeepDim() {
-  if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode_ptr_)) {
-    MS_LOG(INFO) << "This node doesn't have the keep_dims attr.";
-    keep_dims_ = false;
-    return;
-  }
-  keep_dims_ = AnfAlgo::GetNodeAttr<bool>(cnode_ptr_, kAttrKeepDims);
-}
-
-void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
-                                                  mindspore::kernel::SupportFormat *support_format) const {
-  MS_EXCEPTION_IF_NULL(support_format);
-  SupportFormatItem input_support_format;
-  SupportFormatItem output_support_format;
-  input_support_format.emplace_back(support_format_str);
-  output_support_format.emplace_back(support_format_str);
-  support_format->input_format.emplace_back(input_support_format);
-  support_format->output_format.emplace_back(output_support_format);
-}
-
-bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; }
-
-void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
-  MS_EXCEPTION_IF_NULL(shape);
-  if (shape->empty()) {
-    shape->emplace_back(1);
-  }
-}
-}  // namespace kernel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
deleted file mode 100644
index 4cff87d60f..0000000000
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
+++ /dev/null
@@ -1,51 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ -#define MINDSPORE_CCSRC_KERNEL_TBE_KERNEL_REDUCE_SELECTER_H_ -#include -#include -#include -#include "ir/anf.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" -namespace mindspore { -namespace kernel { -class TbeKernelReduceSelecter { - public: - explicit TbeKernelReduceSelecter(CNodePtr cnode_ptr) : cnode_ptr_(std::move(cnode_ptr)) {} - ~TbeKernelReduceSelecter() = default; - bool GetShapeInfo(SupportFormat *support_format); - bool IsReduceSupport5HD(SupportFormat *support_format) const; - bool IsReduceSupportNDC1HWC0(SupportFormat *support_format) const; - bool IsReduceSupportFracZ(SupportFormat *support_format) const; - bool IsReduceSupportC1HWNCoC0(SupportFormat *support_format) const; - bool IsReduceSupportFracNZ(SupportFormat *support_format) const; - - private: - bool IsFracZAndC1HWNCoC0Common(const std::string &format, SupportFormat *support_format) const; - void GetReduceAttrKeepDim(); - void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const; - bool Is4DShape(const std::vector &shape) const; - void PadScalarShape(std::vector *shape) const; - CNodePtr cnode_ptr_; - std::vector input_shape_{}; - std::vector output_shape_{}; - std::vector axis_{}; - bool keep_dims_ = false; -}; -} // namespace kernel -} // namespace mindspore -#endif // MINDSPORE_TBE_KERNEL_REDUCE_SELECTER_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc deleted file mode 100644 index 5ef5d50e9c..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ /dev/null @@ -1,623 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "kernel/tbe/tbe_kernel_select/tbe_kernel_select.h"
-#include <memory>
-#include <set>
-#include <string>
-#include <vector>
-#include "session/anf_runtime_algorithm.h"
-#include "kernel/oplib/oplib.h"
-#include "kernel/tbe/tbe_kernel_build.h"
-#include "nlohmann/json.hpp"
-#include "utils/context/ms_context.h"
-#include "kernel/tbe/tbe_python_funcs.h"
-#include "pre_activate/common/helper.h"
-#include "kernel/tbe/tbe_convert_utils.h"
-#include "parallel/ops_info/ops_utils.h"
-#include "kernel/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
-#include "kernel/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
-#include "kernel/tbe/tbe_kernel_select/common_utils.h"
-
-namespace mindspore {
-namespace kernel {
-constexpr auto kName = "name";
-constexpr auto kDtype = "dtype";
-constexpr auto kFormat = "format";
-constexpr auto kPrefixInput = "input";
-constexpr auto kPrefixOutput = "output";
-constexpr char kParamTypeDynamic[] = "dynamic";
-constexpr char kParamTypeRequre[] = "required";
-constexpr char kParamTypeOptional[] = "optional";
-void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
-  auto tbe_selecter = TbeKernelSelect(kernel_node, kernel_info_list);
-  tbe_selecter.TbeMetadataInfoEx();
-}
-
-TbeKernelSelect::TbeKernelSelect(CNodePtr kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list)
-    : cnode_ptr_(std::move(kernel_node)), kernel_info_list_(kernel_info_list) {}
-
-void TbeKernelSelect::TbeMetadataInfoEx() {
-  MS_EXCEPTION_IF_NULL(cnode_ptr_);
-  MS_EXCEPTION_IF_NULL(kernel_info_list_);
-  node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
-  auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
-  if (!op_info_ptr) {
-    MS_LOG(INFO) << "Warning: Can't find tbe core opinfo, node type: " << node_name_;
-    return;
-  }
-  MS_LOG(INFO) << "Start to tbe metadata info. node type: " << node_name_
-               << ", node name: " << cnode_ptr_->fullname_with_scope();
-  OpPattern pattern = op_info_ptr->op_pattern();
-  if (pattern == kCommonPattern) {
-    GetCommonPatternKernelInfo(*op_info_ptr);
-  } else if (pattern == kDynamicFormatPattern) {
-    GetDynamicFormatPatternKernelInfo(*op_info_ptr);
-  } else if (pattern == kFormatAgnosticPattern) {
-    GetAgnosticPatternKernelInfo(*op_info_ptr);
-  } else if (pattern == kBroadcastPattern) {
-    GetBroadcastPatternKernelInfo(*op_info_ptr);
-  } else if (pattern == kReducePattern) {
-    GetReducePatternKernelInfo(*op_info_ptr);
-  } else {
-    MS_LOG(INFO) << "Warning: op pattern is invalid.";
-  }
-  // check support
-  FilterInVaildKernelInfo();
-  MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
-}
-
-void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
-  MS_LOG(INFO) << "start.";
-  // get dynamic inputs
-  auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr_);
-  MS_EXCEPTION_IF_NULL(primitive);
-  std::vector<int> dyn_input_sizes;
-  if (primitive->HasAttr(kAttrDynInputSizes)) {
-    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
-  }
-  // get real input/output num
-  size_t real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
-  const auto inputs_info = op_info.inputs_ptr();
-  size_t real_output_tensor_num = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
-  const auto outputs_info = op_info.outputs_ptr();
-  if (inputs_info.empty() && outputs_info.empty()) {
-    MS_LOG(EXCEPTION) << "Op info input & output is null, please check.";
-  }
-  // create kernel build info from opinfo
-  size_t kernel_build_info_num =
-    inputs_info.empty() ?
outputs_info[0]->dtypes().size() : inputs_info[0]->dtypes().size(); - for (size_t kernel_build_info_index = 0; kernel_build_info_index < kernel_build_info_num; ++kernel_build_info_index) { - auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); - SetTbeBuildCommonInfo(op_info, &builder); - std::vector inputs_format; - std::vector inputs_device_type; - std::vector> inputs_reshape_type; - // input - if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, - &inputs_format, &inputs_device_type, &inputs_reshape_type)) { - break; - } - builder.SetInputsDeviceType(inputs_device_type); - builder.SetInputsFormat(inputs_format); - builder.SetInputReshapeType(inputs_reshape_type); - // output - std::vector outputs_format; - std::vector outputs_device_type; - std::vector> outputs_reshape_type; - if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, - &outputs_format, &outputs_device_type, &outputs_reshape_type)) { - break; - } - builder.SetOutputsDeviceType(outputs_device_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputReshapeType(outputs_reshape_type); - kernel_info_list_->emplace_back(builder.Build()); - } - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetDynamicFormatPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - // - OpInfo op_info_new; - CreateNewOpInfo(op_info, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetAgnosticPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - if (op_info.inputs_ptr().size() != 1) { - MS_LOG(EXCEPTION) << "AgnosticPattern only support one input."; - } - auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode_ptr_, 0); - if (kOpFormatList.find(format) == kOpFormatList.end()) { - MS_LOG(INFO) << "Got the unknown format " << format; - format = kOpFormat_DEFAULT; - } - SupportFormat support_format; - SupportFormatItem input_item; - SupportFormatItem output_item; - input_item.assign(op_info.inputs_ptr().size(), format); - output_item.assign(op_info.outputs_ptr().size(), format); - support_format.input_format.emplace_back(input_item); - support_format.output_format.emplace_back(output_item); - PrintSupportedFormat(support_format); - OpInfo op_info_new; - CreateNewOpInfo(op_info, support_format, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) { - MS_LOG(INFO) << "start."; - auto broadcast_selecter = TbeKernelBroadCastSelecter(cnode_ptr_); - SupportFormat support_format; - broadcast_selecter.GetShapeInfo(&support_format); - if (!broadcast_selecter.IsBroadCastSupport5HD(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support 5HD."; - } - if (!broadcast_selecter.IsBroadCastSupportFracZ(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracZ."; - } - if (!broadcast_selecter.IsBroadCastSupportC1HWNCoC0(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support C1HWNCoC0."; - } - if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) { - MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ."; - } - PrintSupportedFormat(support_format); - OpInfo op_info_new; - CreateNewOpInfo(op_info, support_format, &op_info_new); - GetCommonPatternKernelInfo(op_info_new); - MS_LOG(INFO) << "end."; -} - -void 
TbeKernelSelect::GetReducePatternKernelInfo(const OpInfo &op_info) {
-  MS_LOG(INFO) << "start.";
-  auto reduce_selecter = TbeKernelReduceSelecter(cnode_ptr_);
-  SupportFormat support_format;
-  reduce_selecter.GetShapeInfo(&support_format);
-  if (!reduce_selecter.IsReduceSupport5HD(&support_format)) {
-    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support 5HD.";
-  }
-  if (!reduce_selecter.IsReduceSupportFracZ(&support_format)) {
-    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracZ.";
-  }
-  if (!reduce_selecter.IsReduceSupportC1HWNCoC0(&support_format)) {
-    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support C1HWNCoC0.";
-  }
-  if (!reduce_selecter.IsReduceSupportFracNZ(&support_format)) {
-    MS_LOG(INFO) << "Node (" << node_name_ << ") reduce does not support FracNZ.";
-  }
-  PrintSupportedFormat(support_format);
-  OpInfo op_info_new;
-  CreateNewOpInfo(op_info, support_format, &op_info_new);
-  GetCommonPatternKernelInfo(op_info_new);
-  MS_LOG(INFO) << "end.";
-}
-
-void TbeKernelSelect::FilterInVaildKernelInfo() {
-  if (kernel_info_list_->empty()) {
-    MS_LOG(INFO) << "Warning: get kernel build info failed.";
-    return;
-  }
-  auto kernel_build_info_iter = kernel_info_list_->begin();
-  while (kernel_build_info_iter != kernel_info_list_->end()) {
-    if (!FilterInVaildShape(kernel_build_info_iter)) {
-      MS_LOG(INFO) << "Filter invalid shape, filter item info: " << (*kernel_build_info_iter)->ToString();
-      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
-      continue;
-    }
-    if (!TbeCheckSupported(kernel_build_info_iter)) {
-      MS_LOG(INFO) << "Check support failed, filter item info: " << (*kernel_build_info_iter)->ToString();
-      kernel_build_info_iter = kernel_info_list_->erase(kernel_build_info_iter);
-      continue;
-    }
-    kernel_build_info_iter++;
-  }
-}
-
-bool TbeKernelSelect::FilterInVaildShape(
-  const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) {
-  MS_EXCEPTION_IF_NULL((*kernel_build_info_iter));
-  auto kernel_build_info_inputs_format = (*kernel_build_info_iter)->GetAllInputFormats();
-  for (size_t i = 0; i < kernel_build_info_inputs_format.size(); ++i) {
-    auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, i);
-    auto format = kernel_build_info_inputs_format.at(i);
-    if (!IsShapeMatchFormat(shape, format)) {
-      MS_LOG(INFO) << "The " << i << "th input check failed.";
-      return false;
-    }
-  }
-  auto kernel_build_info_outputs_format = (*kernel_build_info_iter)->GetAllOutputFormats();
-  for (size_t j = 0; j < kernel_build_info_outputs_format.size(); ++j) {
-    auto shape = AnfAlgo::GetOutputInferShape(cnode_ptr_, j);
-    auto format = kernel_build_info_outputs_format.at(j);
-    if (!IsShapeMatchFormat(shape, format)) {
-      MS_LOG(INFO) << "The " << j << "th output check failed.";
-      return false;
-    }
-  }
-  return true;
-}
-
-bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
-  // the default format matches any shape
-  if (format == kOpFormat_DEFAULT) {
-    return true;
-  }
-  static std::set<std::string> kServerNotSupportFormat = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
-  if (kOpFormatList.find(format) == kOpFormatList.end()) {
-    MS_LOG(EXCEPTION) << "Got the unknown format " << format;
-  }
-  // server does not support formats with the C04 suffix
-  if (std::find(kServerNotSupportFormat.begin(), kServerNotSupportFormat.end(), format) !=
-      kServerNotSupportFormat.end()) {
-    MS_LOG(INFO) << "Warning: Server does not support formats with the C04 suffix.";
-    return false;
-  }
-  // not
support format: - // 1 NDHWC with shape size != 5 - // 2 FRAC_NZ with shape size < 2 - // 3 !NDHWC with shape size > 4 - if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || - (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || - (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { - MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); - return false; - } - return true; -} - -bool TbeKernelSelect::TbeCheckSupported( - const mindspore::kernel::TbeKernelSelect::KernelBuildInfoIter &kernel_build_info_iter) { - MS_EXCEPTION_IF_NULL((*kernel_build_info_iter)); - static const std::set kCheckSupportedOpType = {parallel::MATMUL, - parallel::BATCHMATMUL, - parallel::TOPK, - parallel::IN_TOPK, - parallel::PACK, - parallel::GATHER_ND, - parallel::UNSORTEF_SEGMENT_MIND, - parallel::UNSORTEF_SEGMENT_PRODD, - parallel::CAST}; - auto iter = std::find(kCheckSupportedOpType.begin(), kCheckSupportedOpType.end(), node_name_); - if (iter == kCheckSupportedOpType.end()) { - return true; - } - MS_LOG(INFO) << "Check support start."; - // replace kernel_info with current kernel info - auto kernel_build_info_tmp = AnfAlgo::GetSelectKernelBuildInfo(cnode_ptr_); - AnfAlgo::SetSelectKernelBuildInfo(*kernel_build_info_iter, cnode_ptr_.get()); - nlohmann::json kernel_json; - TbeKernelJsonCreator creator(CHECK_SUPPORTED); - bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); - if (!ret) { - MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed."; - } - ret = TbePythonFuncs::CheckSupported(kernel_json); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get()); - return ret; -} - -void TbeKernelSelect::SetTbeBuildCommonInfo(const mindspore::kernel::OpInfo &op_info, - mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder *builder) { - MS_EXCEPTION_IF_NULL(builder); - builder->SetProcessor(AICORE); - std::string fusion_type = op_info.fusion_type(); - if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) { - builder->SetFusionType(tbe::GetFusionType(fusion_type)); - } - builder->SetOpPattern(op_info.op_pattern()); - builder->SetKernelType(TBE_KERNEL); -} - -bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, - const std::vector> &ios_info, - const std::vector &dyn_input_sizes, std::vector *formats, - std::vector *device_types, std::vector> *reshape_types) { - MS_EXCEPTION_IF_NULL(formats); - MS_EXCEPTION_IF_NULL(device_types); - MS_EXCEPTION_IF_NULL(reshape_types); - size_t dynamic_input_index = 0; - size_t real_io_tensor_index = 0; - size_t io_info_index = 0; - size_t io_info_num = ios_info.size(); - for (; io_info_index < io_info_num && real_io_tensor_index < real_io_tensor_num; io_info_index++) { - std::shared_ptr io_info_item = ios_info[io_info_index]; - auto kernel_build_info_dtype = io_info_item->dtypes().at(kernel_build_info_index); - std::string kernel_build_info_format; - if (!io_info_item->formats().empty()) { - kernel_build_info_format = io_info_item->formats().at(kernel_build_info_index); - } - std::string io_param_type = io_info_item->param_type(); - std::vector reshape_type; - StringToAxisVector(io_info_item->reshape_type(), &reshape_type); - if (io_param_type == kParamTypeDynamic) { - // dynamic io - if (is_input) { - if (dynamic_input_index >= dyn_input_sizes.size()) { - MS_LOG(EXCEPTION) << "dyn_input_sizes attr set error, dynamic_input_index: " << dynamic_input_index - << ", dyn_input_sizes size: 
" << dyn_input_sizes.size(); - } - int dynamic_input_size = dyn_input_sizes[dynamic_input_index]; - for (int i = 0; i < dynamic_input_size; ++i) { - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - } - dynamic_input_index++; - real_io_tensor_index += dynamic_input_size; - } else { - if (ios_info.size() != 1) { - MS_LOG(EXCEPTION) << "if output is dynamic, so output must has one output."; - } - for (size_t i = 0; i < real_io_tensor_num; ++i) { - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - } - real_io_tensor_index += real_io_tensor_num; - } - } else if (io_param_type == kParamTypeRequre || io_param_type == kParamTypeOptional) { - // requre or optional io - device_types->emplace_back(tbe::DtypeToTypeId(kernel_build_info_dtype)); - formats->emplace_back(kernel_build_info_format); - reshape_types->emplace_back(reshape_type); - real_io_tensor_index++; - } else { - MS_LOG(EXCEPTION) << "op info's param type is not match: " << io_param_type; - } - } - - if (io_info_index != io_info_num) { - MS_LOG(INFO) << "Warning: io_info_index(" << io_info_index << ") != io_info_num(" << io_info_num - << "), this node may has optional input/output."; - } - if (real_io_tensor_index != real_io_tensor_num) { - std::string io_type = is_input ? "inputs " : "outputs"; - MS_LOG(INFO) << node_name_ << "'s " << io_type << "op io info num: " << io_info_num - << ", real io tensor num:" << real_io_tensor_num << "real_io_tensor_index(" << real_io_tensor_index - << ") != real_io_tensor_num(" << real_io_tensor_num << ")"; - return false; - } - return true; -} - -void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec) { - MS_EXCEPTION_IF_NULL(reshape_type_vec); - for (const auto &c : reshape_type_str) { - switch (c) { - case 'N': - reshape_type_vec->push_back(kernel::N); - break; - case 'C': - reshape_type_vec->push_back(kernel::C); - break; - case 'H': - reshape_type_vec->push_back(kernel::H); - break; - case 'W': - reshape_type_vec->push_back(kernel::W); - break; - default: - MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; - } - } -} - -void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, - const std::vector> &support_format_item, size_t index, - mindspore::kernel::OpIOInfo *op_io_info_new) { - MS_EXCEPTION_IF_NULL(op_io_info_new); - op_io_info_new->set_index(op_io_info.index()); - op_io_info_new->set_name(op_io_info.name()); - op_io_info_new->set_param_type(op_io_info.param_type()); - op_io_info_new->set_need_compile(op_io_info.need_compile()); - op_io_info_new->set_reshape_type(op_io_info.reshape_type()); - op_io_info_new->set_shape(op_io_info.shape()); - // dtype - std::vector dtype_new; - auto dtype = op_io_info.dtypes(); - for (size_t i = 0; i < support_format_item.size(); ++i) { - dtype_new.insert(dtype_new.end(), dtype.begin(), dtype.end()); - } - op_io_info_new->set_dtypes(dtype_new); - // format - std::vector format_new; - for (const auto &formats : support_format_item) { - auto format = formats.at(index); - for (size_t j = 0; j < dtype.size(); ++j) { - format_new.emplace_back(format); - } - } - op_io_info_new->set_formats(format_new); -} - -std::vector TbeKernelSelect::SplitStrToVec(const std::string &op_select_json_item) { - const std::map kDynamicFormatMap = { - 
{"NCHW", "DefaultFormat"}, {"ND", "DefaultFormat"}, {"FRACTAL_Z", "FracZ"}}; - if (op_select_json_item.empty()) { - MS_LOG(EXCEPTION) << "Op select ret item is null."; - } - const char space = ' '; - const char sep = ','; - std::string op_select_tmp = op_select_json_item + ","; - std::vector ret; - auto begin = op_select_tmp.find_first_not_of(space, 0); - auto sep_pos = op_select_tmp.find(sep); - if (begin >= sep_pos) { - MS_LOG(EXCEPTION) << "Select ret json is error."; - } - while (sep_pos != std::string::npos) { - auto obj = op_select_tmp.substr(begin, sep_pos - begin); - if (kDynamicFormatMap.find(obj) != kDynamicFormatMap.end()) { - obj = kDynamicFormatMap.at(obj); - } - ret.emplace_back(obj); - begin = op_select_tmp.find_first_not_of(space, sep_pos + 1); - sep_pos = op_select_tmp.find(sep, begin); - } - return ret; -} - -std::string TbeKernelSelect::OpSelectFormat() { - nlohmann::json kernel_json; - std::string res_json_str; - TbeKernelJsonCreator creator(OP_SELECT_FORMAT); - bool ret = creator.GenTbeSingleKernelJson(cnode_ptr_, &kernel_json); - if (!ret) { - MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed."; - } - res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json); - if (res_json_str.empty()) { - MS_LOG(EXCEPTION) << "op select format error."; - } - MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str; - return res_json_str; -} - -void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, const SupportFormat &support_format, - mindspore::kernel::OpInfo *op_info_new) { - MS_EXCEPTION_IF_NULL(op_info_new); - if (op_info.inputs_ptr().size() != support_format.input_format[0].size() || - op_info.outputs_ptr().size() != support_format.output_format[0].size()) { - MS_LOG(EXCEPTION) << "BroadCast input/output size not match, op info input size:" << op_info.inputs_ptr().size() - << ", input support size: " << support_format.input_format[0].size() - << ", op info output size: " << op_info.outputs_ptr().size() - << ", output support size: " << support_format.output_format[0].size(); - } - *op_info_new = op_info; - op_info_new->ClearInputs(); - op_info_new->ClearOutputs(); - for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { - auto input = op_info.inputs_ptr().at(i); - auto input_new = std::make_shared(); - CreateNewOpIOInfo(*input, support_format.input_format, i, input_new.get()); - op_info_new->add_inputs_ptr(input_new); - } - for (size_t j = 0; j < op_info.outputs_ptr().size(); ++j) { - auto output = op_info.outputs_ptr().at(j); - auto output_new = std::make_shared(); - CreateNewOpIOInfo(*output, support_format.output_format, j, output_new.get()); - op_info_new->add_outputs_ptr(output_new); - } -} - -struct SelectOpIOInfo { - std::string name; - std::vector dtypes; - std::vector formats; -}; - -void TbeKernelSelect::CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, - mindspore::kernel::OpInfo *op_info_new) { - MS_EXCEPTION_IF_NULL(op_info_new); - auto op_seclect_json = OpSelectFormat(); - if (!op_seclect_json.empty()) { - nlohmann::json json_obj = nlohmann::json::parse(op_seclect_json); - if (!json_obj.is_object()) { - MS_LOG(EXCEPTION) << "JsonStr is not an object, the jsonStr is:" << op_seclect_json; - } - std::vector inputs; - std::vector outputs; - for (const auto &item : json_obj.items()) { - const std::string &item_name = item.key(); - bool is_input = (item_name.find(kPrefixInput) != std::string::npos); - bool is_output = (item_name.find(kPrefixOutput) != std::string::npos); - if (!is_input && !is_output) { - 
MS_LOG(EXCEPTION) << "op select ret json is error."; - } - if (is_input) { - SelectOpIOInfo select_input; - select_input.name = item.value().at(kName); - std::string input_dtype_item = item.value().at(kDtype); - select_input.dtypes = SplitStrToVec(input_dtype_item); - std::string input_format_item = item.value().at(kFormat); - select_input.formats = SplitStrToVec(input_format_item); - inputs.emplace_back(select_input); - } else if (is_output) { - SelectOpIOInfo select_output; - select_output.name = item.value().at(kName); - std::string input_dtype_item = item.value().at(kDtype); - select_output.dtypes = SplitStrToVec(input_dtype_item); - std::string input_format_item = item.value().at(kFormat); - select_output.formats = SplitStrToVec(input_format_item); - outputs.emplace_back(select_output); - } - } - - if (op_info.inputs_ptr().size() != inputs.size() || op_info.outputs_ptr().size() != outputs.size()) { - MS_LOG(EXCEPTION) << "select format input/output size not equal, please check register."; - } - - *op_info_new = op_info; - op_info_new->ClearInputs(); - op_info_new->ClearOutputs(); - for (size_t i = 0; i < op_info.inputs_ptr().size(); ++i) { - auto input_new = std::make_shared(); - CreateNewOpIOInfo(*op_info.inputs_ptr().at(i), inputs.at(i).dtypes, inputs.at(i).formats, input_new.get()); - op_info_new->add_inputs_ptr(input_new); - } - for (size_t i = 0; i < op_info.outputs_ptr().size(); ++i) { - auto output_new = std::make_shared(); - CreateNewOpIOInfo(*op_info.outputs_ptr().at(i), outputs.at(i).dtypes, outputs.at(i).formats, output_new.get()); - op_info_new->add_outputs_ptr(output_new); - } - } -} - -void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, - const std::vector &support_dtype, - const std::vector &support_format, - mindspore::kernel::OpIOInfo *op_io_info_new) { - MS_EXCEPTION_IF_NULL(op_io_info_new); - op_io_info_new->set_index(op_io_info.index()); - op_io_info_new->set_name(op_io_info.name()); - op_io_info_new->set_param_type(op_io_info.param_type()); - op_io_info_new->set_need_compile(op_io_info.need_compile()); - op_io_info_new->set_reshape_type(op_io_info.reshape_type()); - op_io_info_new->set_shape(op_io_info.shape()); - // dtype && format - op_io_info_new->set_dtypes(support_dtype); - op_io_info_new->set_formats(support_format); -} - -void TbeKernelSelect::PrintSupportedFormat(const SupportFormat &support_format) { - if (support_format.input_format.size() != support_format.output_format.size()) { - MS_LOG(EXCEPTION) << "Input(" << support_format.input_format.size() << ")Output(" - << support_format.output_format.size() << ") size not match."; - } - for (size_t i = 0; i < support_format.input_format.size(); ++i) { - auto input_items = support_format.input_format.at(i); - auto output_items = support_format.output_format.at(i); - std::string print_str = "["; - for (const auto &input : input_items) { - print_str.append(input); - print_str.append(", "); - } - print_str.append("] -->"); - for (const auto &output : output_items) { - print_str.append(output); - print_str.append(", "); - } - MS_LOG(INFO) << "Support format: " << print_str; - } -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h deleted file mode 100644 index c400bdbb6f..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - 
* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_TBE_KERNEL_SELECT_H -#define MINDSPORE_TBE_KERNEL_SELECT_H - -#include -#include -#include -#include "kernel/oplib/opinfo.h" -#include "kernel/kernel_build_info.h" -#include "kernel/tbe/tbe_kernel_select/common_utils.h" - -namespace mindspore { -namespace kernel { -void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); - -class TbeKernelSelect { - using OpInfoPtr = std::shared_ptr; - using KernelBuildInfoIter = std::vector>::iterator; - - public: - TbeKernelSelect(CNodePtr kernel_node, std::vector> *kernel_info_list); - ~TbeKernelSelect() = default; - void TbeMetadataInfoEx(); - - private: - void GetCommonPatternKernelInfo(const OpInfo &op_info); - void GetDynamicFormatPatternKernelInfo(const OpInfo &op_info); - void GetAgnosticPatternKernelInfo(const OpInfo &op_info); - void GetBroadcastPatternKernelInfo(const OpInfo &op_info); - void GetReducePatternKernelInfo(const OpInfo &op_info); - void FilterInVaildKernelInfo(); - bool FilterInVaildShape(const KernelBuildInfoIter &kernel_build_info_iter); - static bool IsShapeMatchFormat(const std::vector &shape, const std::string &format); - bool TbeCheckSupported(const KernelBuildInfoIter &kernel_build_info_iter); - static void SetTbeBuildCommonInfo(const OpInfo &op_info, KernelBuildInfo::KernelBuildInfoBuilder *builder); - bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, - const std::vector> &ios_info, const std::vector &dyn_input_sizes, - std::vector *formats, std::vector *device_types, - std::vector> *reshape_types); - static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); - static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); - static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, - const std::vector> &support_format_item, size_t index, - OpIOInfo *op_io_info_new); - // op select(dynamic) - void CreateNewOpInfo(const mindspore::kernel::OpInfo &op_info, mindspore::kernel::OpInfo *op_info_new); - static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector &support_dtype, - const std::vector &support_format, OpIOInfo *op_io_info_new); - static std::vector SplitStrToVec(const std::string &op_select_json_item); - std::string OpSelectFormat(); - - static void PrintSupportedFormat(const SupportFormat &support_format); - - private: - CNodePtr cnode_ptr_; - std::vector> *kernel_info_list_; - std::string node_name_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_TBE_KERNEL_SELECT_H diff --git a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc deleted file mode 100644 index 7204fb7f96..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you 
may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_python_funcs.h" -#include "kernel/tbe/tbe_utils.h" -#include "common/utils.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace kernel { -using mindspore::kernel::tbe::TbeUtils; -constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_compiler.tbe_process"; -constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; -constexpr auto kOpSelectFormatFunc = "op_select_format"; -constexpr auto kCheckSupportedFunc = "check_supported"; -constexpr auto kTBEException = "TBEException"; - -PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; -PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; -PyObject *TbePythonFuncs::pOpSelectFormatFunc_ = nullptr; -PyObject *TbePythonFuncs::pCheckSupportedFunc_ = nullptr; -bool TbePythonFuncs::Init() { - static bool initialized = false; - if (initialized) { - return true; - } - // Initialize cache - TbeUtils::LoadCache(); - - // tbe_process - PyObject *pTbeProcessModule = nullptr; - pTbeProcessModule = PyImport_ImportModule(kTbeProcessModule); - if (pTbeProcessModule == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kTbeProcessModule << "] module."; - return false; - } - - pCreateTbeParallelCompilerFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCreateTbeParallelCompilerFunc); - if (pCreateTbeParallelCompilerFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCreateTbeParallelCompilerFunc << "]."; - return false; - } - - pTbeCompiler_ = PyEval_CallObject(pCreateTbeParallelCompilerFunc_, nullptr); - if (pTbeCompiler_ == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function : create_parallel_compiler."; - return false; - } - - pOpSelectFormatFunc_ = PyObject_GetAttrString(pTbeProcessModule, kOpSelectFormatFunc); - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kOpSelectFormatFunc << "]."; - return false; - } - - pCheckSupportedFunc_ = PyObject_GetAttrString(pTbeProcessModule, kCheckSupportedFunc); - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "Failed to transform opModule and FuncName to PyObject, opModule:[" << kTbeProcessModule - << "], FuncName:[" << kCheckSupportedFunc << "]."; - return false; - } - initialized = true; - MS_LOG(INFO) << "TbePythonFuncs initialized Success."; - return true; -} - -std::string TbePythonFuncs::PyObjectToStr(PyObject *PyObj) { - char *pChar = nullptr; - std::string str_res; - if (PyObj == nullptr) { - MS_LOG(ERROR) << "Input parameter is nullptr."; - return str_res; - } - PyObject *strArgs = PyObject_Str(PyObj); - if (strArgs != nullptr) { - (void)PyArg_Parse(strArgs, "s", &pChar); - } - if (pChar == nullptr) { - MS_LOG(ERROR) << "pChar is nullptr."; - return str_res; - } - str_res = pChar; - return str_res; -} - -std::string TbePythonFuncs::OpSelectFormat(const 
nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRet = nullptr; - std::string res_json_str; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return res_json_str; - } - - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - (void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", json_str.c_str())); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return res_json_str; - } - - // call functions - if (pOpSelectFormatFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return res_json_str; - } - - pRet = PyEval_CallObject(pOpSelectFormatFunc_, pArg); - if (pRet == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc - << "], function args:" << PyObjectToStr(pArg); - } - - char *pstr = nullptr; - (void)PyArg_Parse(pRet, "s", &pstr); - res_json_str = pstr; - if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str - << " ,function args:" << PyObjectToStr(pArg); - } - return res_json_str; -} - -bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { - PyObject *pArg = nullptr; - PyObject *pRes = nullptr; - bool ret = false; - - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return ret; - } - // assembly Args - pArg = PyTuple_New(1); - std::string json_str = kernel_json.dump(); - PyObject *arg1 = Py_BuildValue("s", json_str.c_str()); - (void)PyTuple_SetItem(pArg, 0, arg1); - if (pArg == nullptr) { - MS_LOG(ERROR) << "Failed to generate parameter from kernel_json to PyObject."; - return ret; - } - - // call functions - if (pCheckSupportedFunc_ == nullptr) { - MS_LOG(ERROR) << "function is nullptr."; - return ret; - } - - pRes = PyEval_CallObject(pCheckSupportedFunc_, pArg); - if (pRes == nullptr) { - PyErr_Print(); - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc - << "], function args: " << PyObjectToStr(pArg); - } - if (PyBool_Check(pRes)) { - ret = PyObject_IsTrue(pRes) != 0; - } else { - char *pstr = nullptr; - (void)PyArg_Parse(pRes, "s", &pstr); - std::string res_str = pstr; - if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { - MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str - << ", function args: " << PyObjectToStr(pArg); - } - } - - return ret; -} - -PyObject *TbePythonFuncs::TbeParallelCompiler() { - if (!Init()) { - MS_LOG(ERROR) << "TbePythonFuncs Initialize Failed !"; - return nullptr; - } - return pTbeCompiler_; -} -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_utils.cc b/mindspore/ccsrc/kernel/tbe/tbe_utils.cc deleted file mode 100644 index ae7e5cb6d5..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_utils.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel/tbe/tbe_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "runtime/kernel.h" -#include "kernel/oplib/oplib.h" -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "device/kernel_info.h" -#include "ir/dtype/type.h" -#include "kernel/tbe/tbe_convert_utils.h" -#include "securec/include/securec.h" -#include "operator/ops.h" - -namespace mindspore { -namespace kernel { -namespace tbe { -constexpr auto kCceKernelMeta = "./kernel_meta/"; -constexpr auto kJsonSuffix = ".json"; -constexpr auto kInfoSuffix = ".info"; - -uintptr_t KernelManager::kernel_stub_gen_ = 0; -std::unordered_map KernelManager::info_table_ = {}; - -void TbeUtils::SaveJsonInfo(const std::string &json_name, const std::string &info) { - char real_path[PATH_MAX] = {0}; - std::string path = kCceKernelMeta + json_name + kInfoSuffix; - if (path.size() > PATH_MAX) { - MS_LOG(ERROR) << "file path: " << path << "is too long."; - return; - } - std::ifstream fin(path); - if (fin) { - MS_LOG(INFO) << "json file exist, no need to create."; - return; - } - std::ofstream file_write; - file_write.open(path); - if (!file_write.is_open()) { - return; - } - file_write << info << std::endl; - file_write.close(); - if (realpath(path.c_str(), real_path) == nullptr) { - MS_LOG(INFO) << "dir: " << path << "does not exit."; - return; - } - MS_LOG(INFO) << "real path is: " << real_path; - if (chmod(real_path, S_IRUSR) == -1) { - MS_LOG(INFO) << "modify file: " << real_path << "to read only fail."; - } -} - -void TbeUtils::LoadCache() { - static bool has_load = false; - if (!has_load) { - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map != nullptr && !bin_map->ReadIndex(kCceKernelMeta)) { - MS_LOG(INFO) << "Cache initialize failed[" << kCceKernelMeta << "]"; - } else { - MS_LOG(INFO) << "Cache initialize to " << kCceKernelMeta; - } - has_load = true; - } -} - -KernelPackPtr TbeUtils::SearchCache(const std::string &kernel_name, const std::string &processor) { - // search cache. - KernelMeta *bin_map = KernelMeta::GetInstance(); - if (bin_map == nullptr) { - MS_LOG(INFO) << "kernel cache is invalid."; - return nullptr; - } - return bin_map->GetKernelPack(kernel_name, processor); -} - -KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::string &processor) { - MS_LOG(INFO) << "kernel name: " << kernel_name << ", processr:" << processor; - if (processor != kProcessorAiCore) { - MS_LOG(EXCEPTION) << "process type should be aicore, actually is: " << processor; - } - return SearchCache(kernel_name, processor); -} - -int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, - const string &magic) { - static std::map magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF}, - {"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN}, - {"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU}, - {"RT_DEV_BINARY_MAGIC_ELF_AICPU", RT_DEV_BINARY_MAGIC_ELF_AICPU}}; - // object for device register. 
- rtDevBinary_t dev_bin; - dev_bin.data = kernel_buffer.contents; - auto iter = magic_maps.find(magic); - if (iter == magic_maps.end()) { - MS_LOG(INFO) << "Invalid magic number: " << magic; - return -1; - } - dev_bin.magic = iter->second; - dev_bin.length = kernel_buffer.len; - dev_bin.version = 2; - if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) { - MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error."; - return -1; - } - return 0; -} - -uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload, - uint32_t *block_dim) { - auto kernel = kernel_pack.GetKernel(); - if (kernel == nullptr) { - MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr."; - } - auto kernel_contents = kernel->contents; - if (kernel_contents == nullptr) { - MS_LOG(EXCEPTION) << "Invalid kernel context, json or kernel is nullptr."; - } - auto kernel_json_info = kernel_pack.kernel_json_info(); - - *block_dim = kernel_json_info.block_dim; - string func_name = kernel_json_info.kernel_name; - string magic = kernel_json_info.magic; - - if (!force_reload) { - // use the cached object. - auto iter = info_table_.find(func_name); - if (iter != info_table_.end()) { - auto kernelmeta = iter->second; - *block_dim = kernelmeta->block_dim_; - return kernelmeta->func_stub_; - } - } - void *module = nullptr; - if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) { - MS_LOG(INFO) << "Call runtime BinaryRegister error."; - return 0; - } - // to diff different funcs. - uintptr_t func_stub = ++kernel_stub_gen_; - if (RT_ERROR_NONE != - rtFunctionRegister(module, reinterpret_cast(func_stub), func_name.c_str(), func_name.c_str(), 0)) { - MS_LOG(INFO) << "Call runtime rtFunctionRegister error."; - return 0; - } - // cache the registered kernelmeta. - info_table_[func_name] = std::make_shared(KernelMetaInfo{func_stub, *block_dim}); - return func_stub; -} - -std::string KernelManager::GetStubFuncName(const KernelPackPtr &kernel_pack) { - MS_EXCEPTION_IF_NULL(kernel_pack); - auto kernel_json_info = kernel_pack->kernel_json_info(); - return kernel_json_info.kernel_name; -} - -KernelMeta *KernelMeta::GetInstance() { - static KernelMeta inst; - return &inst; -} - -bool KernelMeta::ReadIndex(const std::string &bin_dir) { - DIR *dir = opendir(bin_dir.c_str()); - if (dir == nullptr) { - auto ret = mkdir(bin_dir.c_str(), S_IRWXG | S_IRWXU); - if (ret != 0) { - MS_LOG(INFO) << "kernel dir: " << bin_dir << "not exist"; - return false; - } - dir = opendir(bin_dir.c_str()); - } - struct dirent *entry; - while ((entry = readdir(dir)) != nullptr) { - string bin_dir_tmp = bin_dir; - std::string cce_json = entry->d_name; - if (cce_json.length() <= 5) { - continue; - } - std::string suffix = cce_json.substr(cce_json.length() - 5); - if (suffix != kJsonSuffix) { - continue; - } - auto sp = cce_json.rfind('/'); - if (sp != std::string::npos) { - continue; - } - sp = cce_json.rfind('.'); - if (sp == std::string::npos) { - continue; - } - auto kernel_name = cce_json.substr(0, sp); - (void)bin_dir_tmp.append("/"); - (void)bin_dir_tmp.append(cce_json); - kernel_index_map_[kernel_name] = bin_dir_tmp; - } - (void)closedir(dir); - - MS_LOG(INFO) << "Cache kernel initialized, kernel size: " << kernel_index_map_.size(); - return true; -} - -KernelPackPtr KernelMeta::GetKernelPack(const std::string &kernel_name, const std::string &processor) { - KernelPackPtr ret = nullptr; - // 1. 
pack has been created - auto kernel_pack_iter = kernel_pack_map_.find(kernel_name); - if (kernel_pack_iter != kernel_pack_map_.end()) { - MS_LOG(INFO) << "kernel pack [" << kernel_name << "]has been created."; - ret = kernel_pack_iter->second; - } else { - // 2. kernel file has been create, but pack does not been created. - std::string cce_json = kCceKernelMeta; - (void)cce_json.append(kernel_name).append(kJsonSuffix); - ret = std::make_shared(); - if (!ret->LoadKernelMeta(cce_json, processor)) { - MS_LOG(INFO) << "Read cache json and bin file failed[" << cce_json << "]"; - return nullptr; - } - kernel_pack_map_[kernel_name] = ret; - auto iter = kernel_index_map_.find(kernel_name); - if (iter == kernel_index_map_.end()) { - MS_LOG(INFO) << "kernel name [" << kernel_name << "] has been ceated first."; - kernel_index_map_[kernel_name] = cce_json; - } - } - return ret; -} -} // namespace tbe -} // namespace kernel -} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/tbe/tbe_utils.h b/mindspore/ccsrc/kernel/tbe/tbe_utils.h deleted file mode 100644 index 56fbe7967a..0000000000 --- a/mindspore/ccsrc/kernel/tbe/tbe_utils.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_
-#define MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "session/kernel_graph.h"
-#include "ir/anf.h"
-#include "kernel/kernel.h"
-
-namespace mindspore {
-namespace kernel {
-namespace tbe {
-using std::string;
-using std::vector;
-
-class TbeUtils {
- public:
-  TbeUtils() = default;
-
-  ~TbeUtils() = default;
-
-  static void SaveJsonInfo(const std::string &json_name, const std::string &info);
-
-  static void LoadCache();
-
-  static KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor);
-
-  static KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
-};
-
-struct KernelMetaInfo {
-  uintptr_t func_stub_;
-  uint32_t block_dim_;
-};
-using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>;
-
-class KernelManager {
- public:
-  static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim);
-  static std::string GetStubFuncName(const KernelPackPtr &kernel_pack);
-
- private:
-  KernelManager() = default;
-  ~KernelManager() = default;
-  static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic);
-  static std::unordered_map<string, KernelMetaPtr> info_table_;
-  static uintptr_t kernel_stub_gen_;
-};
-
-class KernelMeta {
- public:
-  static KernelMeta *GetInstance();
-  bool ReadIndex(const std::string &bin_dir);
-  KernelPackPtr GetKernelPack(const std::string &kernel_name, const std::string &processor);
-
- private:
-  KernelMeta() = default;
-  ~KernelMeta() = default;
-  std::unordered_map<std::string, std::string> kernel_index_map_{};
-  std::unordered_map<std::string, KernelPackPtr> kernel_pack_map_{};
-};
-}  // namespace tbe
-}  // namespace kernel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_KERNEL_TBE_TBE_UTILS_H_
diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt
new file mode 100644
index 0000000000..df9729c4ee
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt
@@ -0,0 +1,159 @@
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-reorder")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-switch")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sequence-point")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-variable")
+
+if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-uninitialized")
+else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-maybe-uninitialized")
+endif()
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes")
+
+############################# Options ################################
+if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
+    add_definitions(-D _CRT_RAND_S)
+endif ()
+if (ENABLE_GPUQUE)
+    add_definitions(-D ENABLE_GPUQUE)
+    message(STATUS "GPU queue is enabled")
+endif ()
+if (ENABLE_TDTQUE)
+    add_definitions(-D ENABLE_TDTQUE)
+    message(STATUS "TDT queue is enabled")
+endif ()
+
+# code coverage
+# option(ENABLE_COVERAGE "Enable code coverage report" OFF)
+# if (ENABLE_COVERAGE)
+#     include(${CMAKE_SOURCE_DIR}/cmake/CodeCoverage.cmake)
+#     append_coverage_compiler_flags()
+# endif ()
+
+########### Set up the include directories ###########################
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc)
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/runtime/device/ascend/platform)
+
+include_directories(${CMAKE_BINARY_DIR})  # for protobuf generated .h
+
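+# NOTE: the flatbuffer schema compiled below via ms_build_flatbuffers appears to emit
+# its generated headers under ${CMAKE_BINARY_DIR} as well, so the include entry above
+# covers both the protobuf and flatbuffer outputs.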
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/mindrecord/include) +include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include) +###################################################################### + +####################### Flags ######################################## +# compile flags +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default") + +ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR}) + +################## Include sub-modules ############################### +add_subdirectory(util) +add_subdirectory(core) +add_subdirectory(kernels) +add_subdirectory(engine) +add_subdirectory(api) +add_subdirectory(text) +###################################################################### +add_dependencies(utils core) +add_dependencies(kernels-image core) +add_dependencies(kernels-data core) +add_dependencies(kernels core) +add_dependencies(engine-datasetops-source core) +add_dependencies(engine-datasetops-source-sampler core) +add_dependencies(engine-datasetops core) +add_dependencies(engine-opt core) +add_dependencies(engine-perf core) +add_dependencies(engine-gnn core) +add_dependencies(engine core) +add_dependencies(text core) +add_dependencies(text-kernels core) +add_dependencies(cpp-API core) +if (ENABLE_PYTHON) + add_dependencies(APItoPython core) +endif() +if (ENABLE_TDTQUE) + add_dependencies(engine-tdt core) +endif () +################### Create _c_dataengine Library ###################### +set(submodules + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + ) + +if (ENABLE_PYTHON) + set(submodules + ${submodules} + $) +endif() + +if (ENABLE_TDTQUE) + add_library(_c_dataengine SHARED ${submodules} $) +else () + add_library(_c_dataengine SHARED ${submodules}) +endif () + +add_dependencies(_c_dataengine generated_engine_files) + +set_target_properties(_c_dataengine PROPERTIES + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" + ) + +###################################################################### + +################# Link with external libraries ######################## +target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY}) +else() + set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n) + target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY}) +endif() +target_link_libraries(_c_dataengine PUBLIC mindspore::jpeg_turbo mindspore::opencv_core mindspore::opencv_imgcodecs + mindspore::opencv_imgproc mindspore::tinyxml2 ${ICU_LIB}) +if (ENABLE_GPUQUE) + target_link_libraries(_c_dataengine PRIVATE gpu_queue + ${CUDNN_PATH}/lib64/libcudnn.so + ${CUDA_PATH}/lib64/libcudart.so + ${CUDA_PATH}/lib64/stubs/libcuda.so) +endif () + +if (ENABLE_TDTQUE) + target_link_libraries(_c_dataengine PRIVATE ${TSDCLIENT}) +endif () + +add_dependencies(_c_dataengine _c_mindrecord) +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + set(MINDRECORD_LINK_OBJECT ${CMAKE_BINARY_DIR}/mindspore/ccsrc/minddata/mindrecord/CMakeFiles/_c_mindrecord.dir/objects.a) + target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) +else() + target_link_libraries(_c_dataengine PRIVATE 
_c_mindrecord)
+endif()
+
+if (USE_GLOG)
+    target_link_libraries(_c_dataengine PRIVATE mindspore::glog)
+else()
+    if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+        target_link_options(_c_dataengine PRIVATE -Wl,-init,mindspore_log_init)
+    elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin")
+        set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
+    endif ()
+endif()
diff --git a/mindspore/ccsrc/dataset/api/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/api/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/api/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
new file mode 100644
index 0000000000..3072a62dc9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc
@@ -0,0 +1,446 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include "minddata/dataset/include/datasets.h"
+#include "minddata/dataset/include/transforms.h"
+#include "minddata/dataset/include/samplers.h"
+#include "minddata/dataset/engine/dataset_iterator.h"
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
+#include "minddata/dataset/engine/datasetops/batch_op.h"
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
+#include "minddata/dataset/engine/datasetops/shuffle_op.h"
+#include "minddata/dataset/engine/datasetops/project_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+namespace api {
+
+#define RETURN_NULL_IF_ERROR(_s) \
+  do {                           \
+    Status __rc = (_s);          \
+    if (__rc.IsError()) {        \
+      return nullptr;            \
+    }                            \
+  } while (false)
+
+// Function to create the iterator, which will build and launch the execution tree.
+std::shared_ptr<Iterator> Dataset::CreateIterator() {
+  std::shared_ptr<Iterator> iter;
+  try {
+    iter = std::make_shared<Iterator>();
+    Status rc = iter->BuildAndLaunchTree(shared_from_this());
+    if (rc.IsError()) {
+      MS_LOG(ERROR) << "CreateIterator failed.";
+      return nullptr;
+    }
+
+    return iter;
+  } catch (const std::exception &err) {
+    MS_LOG(ERROR) << "CreateIterator: Iterator exception caught: " << err.what();
+    return nullptr;
+  }
+
+  return iter;
+}
+
+// Constructor
+Dataset::Dataset() {
+  // Fetch some default values from the config manager
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  num_workers_ = cfg->num_parallel_workers();
+  rows_per_buffer_ = cfg->rows_per_buffer();
+  connector_que_size_ = cfg->op_connector_size();
+}
+
+// Function to create an ImageFolderDataset.
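+// Every factory below returns nullptr when ValidateParams() fails, so a (purely
+// illustrative) caller would check the result before chaining, e.g.:
+//   auto ds = ImageFolder("/path/to/images", true, nullptr)->Shuffle(4)->Repeat(2)->Batch(32);
+//   if (ds != nullptr) {
+//     auto iter = ds->CreateIterator();
+//   }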
+std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
+                                                std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
+                                                std::map<std::string, int32_t> class_indexing) {
+  // This arg exists in ImageFolderOp, but is not externalized (in the Python API). The default value is false.
+  bool recursive = false;
+
+  // Create logical representation of ImageFolderDataset.
+  auto ds = std::make_shared<ImageFolderDataset>(dataset_dir, decode, sampler, recursive, extensions, class_indexing);
+
+  // Call derived class validation method.
+  return ds->ValidateParams() ? ds : nullptr;
+}
+
+// Function to create a MnistDataset.
+std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) {
+  auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
+
+  // Call derived class validation method.
+  return ds->ValidateParams() ? ds : nullptr;
+}
+
+// Function to create a Cifar10Dataset.
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
+                                        std::shared_ptr<SamplerObj> sampler) {
+  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, num_samples, sampler);
+
+  // Call derived class validation method.
+  return ds->ValidateParams() ? ds : nullptr;
+}
+
+// Function to create a Batch dataset
+std::shared_ptr<BatchDataset> Dataset::Batch(int32_t batch_size, bool drop_remainder) {
+  // Default values
+  std::vector<std::string> cols_to_map = {};
+  std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map;
+  bool pad = false;
+  auto ds = std::make_shared<BatchDataset>(batch_size, drop_remainder, pad, cols_to_map, pad_map);
+
+  if (!ds->ValidateParams()) {
+    return nullptr;
+  }
+
+  ds->children.push_back(shared_from_this());
+
+  return ds;
+}
+
+// Function to create a Repeat dataset.
+std::shared_ptr<Dataset> Dataset::Repeat(int32_t count) {
+  // Workaround for repeat == 1, do not inject repeat.
+  if (count == 1) {
+    return shared_from_this();
+  }
+
+  auto ds = std::make_shared<RepeatDataset>(count);
+
+  if (!ds->ValidateParams()) {
+    return nullptr;
+  }
+
+  ds->children.push_back(shared_from_this());
+
+  return ds;
+}
+
+// Function to create a Map dataset.
+std::shared_ptr<MapDataset> Dataset::Map(std::vector<std::shared_ptr<TensorOperation>> operations,
+                                         std::vector<std::string> input_columns,
+                                         std::vector<std::string> output_columns,
+                                         const std::vector<std::string> &project_columns) {
+  auto ds = std::make_shared<MapDataset>(operations, input_columns, output_columns, project_columns);
+
+  if (!ds->ValidateParams()) {
+    return nullptr;
+  }
+
+  ds->children.push_back(shared_from_this());
+
+  return ds;
+}
+
+// Function to create a ShuffleOp
+std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
+  // Pass in reshuffle_each_epoch as true
+  auto ds = std::make_shared<ShuffleDataset>(shuffle_size, true);
+
+  if (!ds->ValidateParams()) {
+    return nullptr;
+  }
+
+  ds->children.push_back(shared_from_this());
+
+  return ds;
+}
+
+// Function to create a ProjectDataset.
+std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string> &columns) {
+  auto ds = std::make_shared<ProjectDataset>(columns);
+  // Call derived class validation method.
+  if (!ds->ValidateParams()) {
+    return nullptr;
+  }
+
+  ds->children.push_back(shared_from_this());
+
+  return ds;
+}
+
+// Helper function to create default RandomSampler.
+std::shared_ptr<SamplerObj> CreateDefaultSampler() {
+  int32_t num_samples = 0;  // 0 means to sample all ids.
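+  // Sampling is without replacement (set just below), i.e. each id appears at most once per pass.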
+  bool replacement = false;
+  return std::make_shared<RandomSamplerObj>(replacement, num_samples);
+}
+
+/* ####################################### Derived Dataset classes ################################# */
+
+ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
+                                       bool recursive, std::set<std::string> extensions,
+                                       std::map<std::string, int32_t> class_indexing)
+    : dataset_dir_(dataset_dir),
+      decode_(decode),
+      sampler_(sampler),
+      recursive_(recursive),
+      class_indexing_(class_indexing),
+      exts_(extensions) {}
+
+bool ImageFolderDataset::ValidateParams() {
+  if (dataset_dir_.empty()) {
+    MS_LOG(ERROR) << "No dataset path is specified.";
+    return false;
+  }
+
+  return true;
+}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ImageFolderDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  // If the user does not specify a Sampler, create a default sampler, i.e., RandomSampler.
+  if (sampler_ == nullptr) {
+    sampler_ = CreateDefaultSampler();
+  }
+
+  // Do internal Schema generation.
+  // This arg exists in ImageFolderOp, but is not externalized (in the Python API).
+  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+  TensorShape scalar = TensorShape::CreateScalar();
+  RETURN_NULL_IF_ERROR(
+    schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+  RETURN_NULL_IF_ERROR(
+    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
+  node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
+                                                     recursive_, decode_, exts_, class_indexing_, std::move(schema),
+                                                     std::move(sampler_->Build())));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+
+bool MnistDataset::ValidateParams() {
+  if (dataset_dir_.empty()) {
+    MS_LOG(ERROR) << "No dataset path is specified.";
+    return false;
+  }
+
+  return true;
+}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MnistDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  // If the user does not specify a Sampler, create a default sampler, i.e., RandomSampler.
+  if (sampler_ == nullptr) {
+    sampler_ = CreateDefaultSampler();
+  }
+
+  // Do internal Schema generation.
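+  // The Mnist schema is fixed: a uint8 "image" column and a uint32 scalar "label" column.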
+  auto schema = std::make_unique<DataSchema>();
+  RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
+  TensorShape scalar = TensorShape::CreateScalar();
+  RETURN_NULL_IF_ERROR(
+    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
+
+  node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
+                                               std::move(schema), std::move(sampler_->Build())));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+BatchDataset::BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
+                           std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
+    : batch_size_(batch_size),
+      drop_remainder_(drop_remainder),
+      pad_(pad),
+      cols_to_map_(cols_to_map),
+      pad_map_(pad_map) {}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> BatchDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+#ifdef ENABLE_PYTHON
+  py::function noop;
+  node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
+                                               cols_to_map_, noop, noop, pad_map_));
+#else
+  node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
+                                               cols_to_map_, pad_map_));
+#endif
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+bool BatchDataset::ValidateParams() {
+  if (batch_size_ <= 0) {
+    return false;
+  }
+
+  return true;
+}
+
+RepeatDataset::RepeatDataset(uint32_t count) : repeat_count_(count) {}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> RepeatDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  node_ops.push_back(std::make_shared<RepeatOp>(repeat_count_));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+bool RepeatDataset::ValidateParams() {
+  if (repeat_count_ <= 0) {
+    return false;
+  }
+
+  return true;
+}
+MapDataset::MapDataset(std::vector<std::shared_ptr<TensorOperation>> operations,
+                       std::vector<std::string> input_columns, std::vector<std::string> output_columns,
+                       const std::vector<std::string> &project_columns)
+    : operations_(operations),
+      input_columns_(input_columns),
+      output_columns_(output_columns),
+      project_columns_(project_columns) {}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> MapDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  // Currently the default is true, and this is not exposed to the user.
+  bool perf_mode = true;
+
+  std::vector<std::shared_ptr<TensorOp>> tensor_ops;
+
+  // Build a TensorOp from each TensorOperation in the vector.
+  // This is to ensure each iterator holds its own copy of the TensorOp objects.
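+  // (A shared TensorOp instance may carry mutable state, so rebuilding from the
+  // TensorOperations keeps concurrently running iterators independent.)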
+  (void)std::transform(
+    operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
+    [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });
+
+  // This parameter will be removed with the next rebase
+  std::vector<std::string> col_orders;
+  auto map_op =
+    std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_, perf_mode);
+  if (!project_columns_.empty()) {
+    auto project_op = std::make_shared<ProjectOp>(project_columns_);
+    node_ops.push_back(project_op);
+  }
+
+  node_ops.push_back(map_op);
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+bool MapDataset::ValidateParams() {
+  if (operations_.empty()) {
+    return false;
+  }
+
+  return true;
+}
+
+// Constructor for ShuffleDataset
+ShuffleDataset::ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch)
+    : shuffle_size_(shuffle_size), shuffle_seed_(GetSeed()), reset_every_epoch_(reset_every_epoch) {}
+
+// Function to build the ShuffleOp
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ShuffleDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  node_ops.push_back(std::make_shared<ShuffleOp>(shuffle_size_, shuffle_seed_, connector_que_size_, reset_every_epoch_,
+                                                 rows_per_buffer_));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+// Function to validate the parameters for ShuffleDataset
+bool ShuffleDataset::ValidateParams() {
+  if (shuffle_size_ <= 1) {
+    MS_LOG(ERROR) << "ShuffleDataset: Invalid input, shuffle_size: " << shuffle_size_;
+    return false;
+  }
+
+  return true;
+}
+
+// Constructor for Cifar10Dataset
+Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples,
+                               std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), num_samples_(num_samples), sampler_(sampler) {}
+
+bool Cifar10Dataset::ValidateParams() {
+  if (dataset_dir_.empty()) {
+    MS_LOG(ERROR) << "No dataset path is specified.";
+    return false;
+  }
+  if (num_samples_ < 0) {
+    MS_LOG(ERROR) << "Number of samples cannot be negative";
+    return false;
+  }
+  return true;
+}
+
+// Function to build CifarOp
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Cifar10Dataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  // If the user does not specify a Sampler, create a default sampler based on the shuffle variable.
+  if (sampler_ == nullptr) {
+    sampler_ = CreateDefaultSampler();
+  }
+
+  // Do internal Schema generation.
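+  // Cifar10 uses the same two-column layout as Mnist above: uint8 "image" plus uint32 scalar "label".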
+  auto schema = std::make_unique<DataSchema>();
+  RETURN_NULL_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
+  TensorShape scalar = TensorShape::CreateScalar();
+  RETURN_NULL_IF_ERROR(
+    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
+
+  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
+                                               dataset_dir_, connector_que_size_, std::move(schema),
+                                               std::move(sampler_->Build())));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+// Function to build ProjectOp
+ProjectDataset::ProjectDataset(const std::vector<std::string> &columns) : columns_(columns) {}
+
+bool ProjectDataset::ValidateParams() {
+  if (columns_.empty()) {
+    MS_LOG(ERROR) << "No columns are specified.";
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build() {
+  // A vector containing shared pointers to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  node_ops.push_back(std::make_shared<ProjectOp>(columns_));
+  return std::make_shared<std::vector<std::shared_ptr<DatasetOp>>>(node_ops);
+}
+
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
new file mode 100644
index 0000000000..2a6166f868
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
@@ -0,0 +1,1605 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "minddata/dataset/api/de_pipeline.h" + +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/kernels/py_func_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr *, std::shared_ptr *); + +static std::unordered_map g_parse_op_func_ = { + {kShuffle, &DEPipeline::ParseShuffleOp}, + {kMindrecord, &DEPipeline::ParseMindRecordOp}, + {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, + {kBatch, &DEPipeline::ParseBatchOp}, + {kBucketBatch, &DEPipeline::ParseBucketBatchByLengthOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, + {kRepeat, &DEPipeline::ParseRepeatOp}, + {kSkip, &DEPipeline::ParseSkipOp}, + {kZip, &DEPipeline::ParseZipOp}, + {kConcat, &DEPipeline::ParseConcatOp}, + {kRename, &DEPipeline::ParseRenameOp}, + {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp}, + {kGenerator, &DEPipeline::ParseGeneratorOp}, + {kTfReader, &DEPipeline::ParseTFReaderOp}, + {kProject, &DEPipeline::ParseProjectOp}, + {kTake, &DEPipeline::ParseTakeOp}, + {kImageFolder, &DEPipeline::ParseImageFolderOp}, + {kMnist, &DEPipeline::ParseMnistOp}, + {kManifest, &DEPipeline::ParseManifestOp}, + {kVoc, &DEPipeline::ParseVOCOp}, + {kCoco, &DEPipeline::ParseCocoOp}, + {kCifar10, &DEPipeline::ParseCifar10Op}, + {kCifar100, &DEPipeline::ParseCifar100Op}, + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kRandomData, &DEPipeline::ParseRandomDataOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}, + {kBuildVocab, &DEPipeline::ParseBuildVocabOp}, + {kClue, &DEPipeline::ParseClueOp}}; + +DEPipeline::DEPipeline() : iterator_(nullptr) { + try { + // One time init + (void)GlobalInit(); + + // Instantiate the execution tree + tree_ = std::make_shared(); + repeat_num_ = 1; + batch_size_ = 1; + num_rows_ = 0; + num_classes_ = 0; + temp_batch_size_ = 1; + temp_drop_remainder_ = false; + } catch (const std::exception &err) { + MS_LOG(ERROR) << "Dataset pipeline exception caught on init: " << err.what() << "."; + return; + } +} + +DEPipeline::~DEPipeline() { + 
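+  // The nested scope keeps the GIL released only for the duration of the tree teardown.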
+  {
+    // Release GIL before joining all threads
+    py::gil_scoped_release gil_release;
+    // Release tree
+    tree_.reset();
+  }
+}
+
+// Function to add a Node to the Execution Tree.
+Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) {
+  // For each operator, parse through the list of arguments, then call the respective builder/constructor.
+  // Note that each call to the parse function may result in building more than one dataset operator.
+  // For example, one call to ParseNNNOp may result in multiple internal C nodes:
+  //   nodeA
+  //     |
+  //   nodeB
+  //     |
+  //   nodeC
+  // However, the python side dataset is more abstract, and it does not know about the potential subtree that
+  // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the
+  // python side needs to know about nodeA and nodeC to be able to appropriately hook up parents and children
+  // to this subtree.
+  // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse
+  // function.
+  DsOpPtr top = nullptr;
+  DsOpPtr bottom = nullptr;
+  auto iter = g_parse_op_func_.find(op_name);
+  if (iter != g_parse_op_func_.end()) {
+    pFunction func = iter->second;
+    RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom));
+
+    if (top == nullptr) {
+      RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node.");
+    }
+
+    // It is not required that the parse function always produces the bottom pointer. If it's still null,
+    // then set top and bottom to be the same operator.
+    if (bottom == nullptr) bottom = top;
+
+    // Pack these pointers into a py dict so that we can return both back to python.
+    (*output)["top"] = top;
+    (*output)["bottom"] = bottom;
+  } else {
+    RETURN_STATUS_UNEXPECTED("No such Op");
+  }
+  // Associate current dataset op node with the tree.
+  RETURN_IF_NOT_OK(tree_->AssociateNode(top));
+  return Status::OK();
+}
+// Function to add a child and parent relationship.
+Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op) {
+  // Link this relationship.
+  // Note: the parent node takes ownership of the child.
+  return (parent_op->AddChild(child_op));
+}
+
+// Function to assign the node as root.
+Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
+
+// Function to launch the tree execution.
+Status DEPipeline::LaunchTreeExec() {
+  RETURN_IF_NOT_OK(tree_->Prepare());
+  RETURN_IF_NOT_OK(tree_->Launch());
+  iterator_ = std::make_unique<DatasetIterator>(tree_);
+  if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
+  return Status::OK();
+}
+
+void DEPipeline::PrintTree() {
+  for (auto itr = tree_->begin(); itr != tree_->end(); ++itr) {
+    std::stringstream ss;
+    ss << *itr;
+    MS_LOG(DEBUG) << "Operator ID is " << itr->id() << ". 
Details: " << ss.str().c_str() << "."; + } +} + +Status DEPipeline::GetNextAsMap(py::dict *output) { + TensorMap row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetNextAsMap(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python dict as return + for (auto el : row) { + (*output)[common::SafeCStr(el.first)] = el.second; + } + return Status::OK(); +} + +Status DEPipeline::GetNextAsList(py::list *output) { + TensorRow row; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->FetchNextTensorRow(&row); + } + RETURN_IF_NOT_OK(s); + // Generate Python list as return + for (auto el : row) { + output->append(el); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputShapes(py::list *output) { + std::vector shapes; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputShapes(&shapes); + } + RETURN_IF_NOT_OK(s); + for (auto el : shapes) { + py::list shape; + for (auto dim : el.AsVector()) { + shape.append(dim); + } + output->append(shape); + } + return Status::OK(); +} + +Status DEPipeline::GetOutputTypes(py::list *output) { + std::vector types; + Status s; + { + py::gil_scoped_release gil_release; + s = iterator_->GetOutputTypes(&types); + } + RETURN_IF_NOT_OK(s); + for (auto el : types) { + output->append(el.AsNumpyType()); + } + return Status::OK(); +} + +int DEPipeline::GetDatasetSize() const { return num_rows_ / batch_size_; } + +int DEPipeline::GetBatchSize() const { return batch_size_; } + +int DEPipeline::GetRepeatCount() const { return repeat_num_; } + +float ToFloat(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +int ToInt(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +bool ToBool(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::string ToString(const py::handle &handle) { return py::reinterpret_borrow(handle); } + +std::vector ToStringVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) + vector.push_back(py::str(l)); + else + vector.emplace_back(""); + } + return vector; +} + +std::set ToStringSet(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::set set; + for (auto l : list) { + if (!l.is_none()) { + (void)set.insert(py::str(l)); + } + } + return set; +} + +std::map ToStringMap(const py::handle handle) { + py::dict dict = py::reinterpret_borrow(handle); + std::map map; + for (auto p : dict) { + (void)map.insert(std::make_pair(ToString(p.first), ToInt(p.second))); + } + return map; +} + +std::vector ToIntVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (!l.is_none()) { + vector.push_back(ToInt(l)); + } + } + return vector; +} + +std::vector ToTypeVector(const py::handle handle) { + py::list list = py::reinterpret_borrow(handle); + std::vector vector; + for (auto l : list) { + if (l.is_none()) { + vector.emplace_back(DataType()); + } else { + vector.push_back(l.cast()); + } + } + return vector; +} + +Status DEPipeline::SetBatchParameters(const py::dict &args) { + if (args["batch_size"].is_none()) { + std::string err_msg = "Error: batchSize is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + temp_batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(temp_batch_size_ > 0, "Error: batchSize is invalid."); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = 
arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + temp_drop_remainder_ = ToBool(value); + } + } + } + + return Status::OK(); +} + +Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + if (!args["buffer_size"].is_none()) { + (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); + } else { + std::string err_msg = "Error: Shuffle buffer size is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "reshuffle_each_epoch") { + (void)builder->SetReshuffleEachEpoch(ToBool(args["reshuffle_each_epoch"])); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded) { + auto sampler = py::reinterpret_borrow(handle); + auto create = sampler.attr("create_for_minddataset"); + auto op = create().cast>(); + std::stack> stack_ops; + while (op != nullptr) { + auto sampler_op = std::dynamic_pointer_cast(op); + if (sampler_op && num_padded > 0) { + sampler_op->SetNumPaddedSamples(num_padded); + stack_ops.push(sampler_op); + } else { + stack_ops.push(op); + } + op = op->GetChildOp(); + } + while (!stack_ops.empty()) { + operators->push_back(stack_ops.top()); + stack_ops.pop(); + } + return Status::OK(); +} + +Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: at least one of dataset_files is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + bool load_dataset = ToBool(args["load_dataset"]); + if (load_dataset == true) { + (void)builder->SetDatasetFile({ToString(args["dataset_file"])}); + } else { + (void)builder->SetDatasetFile(ToStringVector(args["dataset_file"])); + } + (void)builder->SetLoadDataset(load_dataset); + std::vector in_col_names; + if (!args["columns_list"].is_none()) { + in_col_names = ToStringVector(args["columns_list"]); + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: columns_list is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetColumnsToLoad(in_col_names); + } + + if (!args["padded_sample"].is_none()) { + (void)builder->SetPaddedSample(args["padded_sample"]); + (void)builder->SetNumToPadSamples(ToInt(args["num_padded"])); + } + std::vector> operators; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumMindRecordWorkers(ToInt(value)); + } else if (key == "block_reader" && ToBool(value) == true) { + (void)builder->SetBlockReader(); + } else if (key == "sampler") { + int num_padded = 0; + if (!args["num_padded"].is_none()) { + num_padded = ToInt(args["num_padded"]); + } + RETURN_IF_NOT_OK(BuildMindrecordSamplerChain(value, &operators, num_padded)); + } + } + } + + if (!operators.empty()) { + (void)builder->SetOperators(operators); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + num_rows_ = op->num_rows(); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, + 
std::shared_ptr *bottom) { + MapOp::Builder map_builder; + std::vector> tensor_op_list; + std::vector project_columns; + std::shared_ptr cache_client = nullptr; + int num_workers = 0; + + if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)map_builder.SetInColNames(in_col_names); + } else if (key == "output_columns") { + (void)map_builder.SetOutColNames(ToStringVector(value)); + } else if (key == "columns_order") { + project_columns = ToStringVector(value); + } else if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)map_builder.SetNumWorkers(num_workers); + } else if (key == "prefetch_size") { + (void)map_builder.SetOpConnectorSize(ToInt(value)); + } else if (key == "operations") { + py::handle tensor_ops = args["operations"]; + // operation can be a list of TensorOps or a single TensorOp. + if (py::isinstance(tensor_ops)) { + for (auto op : tensor_ops) { + std::shared_ptr tensor_op; + if (py::isinstance(op)) { + tensor_op = op.cast>(); + } else if (py::isinstance(op)) { + tensor_op = std::make_shared(op.cast()); + } else { + RETURN_STATUS_UNEXPECTED("Error: tensor_op is not recognised (not TensorOp and not pyfunc)."); + } + tensor_op_list.push_back(tensor_op); + } + } + if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); + (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr map_op; + RETURN_IF_NOT_OK(map_builder.Build(&map_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); + *top = map_op; + + // Add a project op over top of the map if the user wanted to reposition the columns + if (!project_columns.empty()) { + ProjectOp::Builder proj_builder(project_columns); + std::shared_ptr proj_op; + RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); + RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); + *top = proj_op; + *bottom = map_op; + } + + // Additionally, add a cache if required. This will go over top of the project op if one + // was created, otherwise it goes over top of the map op + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op)); + *top = cache_op; + *bottom = map_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. 
\n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + repeat_num_ = ToInt(args["count"]); + std::shared_ptr op; + RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "source") { + py::object obj = py::cast(&value); + if (!py::isinstance(obj)) { + std::string err_msg = "Error: generator is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetGeneratorFunction(obj.cast()); + } else if (key == "column_names") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "column_types") { + (void)builder->SetColumnTypes(ToTypeVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder; + if (py::isinstance(args["batch_size"])) { + batch_size_ = ToInt(args["batch_size"]); + CHECK_FAIL_RETURN_UNEXPECTED(batch_size_ > 0, "Error: batch_size is invalid."); + builder = std::make_shared(ToInt(args["batch_size"])); + } else if (py::isinstance(args["batch_size"])) { + builder = std::make_shared(1); + (void)builder->SetBatchSizeFunc(args["batch_size"].cast()); + } else { + std::string err_msg = "Error: batch_size is neither an Integer nor a python function"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "drop_remainder") { + (void)builder->SetDrop(ToBool(value)); + } + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } + if (key == "per_batch_map") { + (void)builder->SetBatchMapFunc(value.cast()); + } + if (key == "input_columns") { + 
(void)builder->SetColumnsToMap(ToStringVector(value)); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPaddingMap(pad_info, true); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", + "bucket_batch_sizes"}; + for (auto name : mandatory_arguments) { + if (args[name.c_str()].is_none()) { + std::string err_msg = "Error: " + name + " is not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + std::shared_ptr builder = std::make_shared( + ToStringVector(args[mandatory_arguments[0].c_str()]), ToIntVector(args[mandatory_arguments[1].c_str()]), + ToIntVector(args[mandatory_arguments[2].c_str()])); + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "length_dependent_columns") { + (void)builder->SetLengthDependentColumns(ToStringVector(value)); + } + if (key == "bucket_boundaries") { + (void)builder->SetBucketBoundaries(ToIntVector(value)); + } + if (key == "bucket_batch_sizes") { + (void)builder->SetBucketBatchSizes(ToIntVector(value)); + } + if (key == "element_length_function") { + (void)builder->SetElementLengthFunction(value.cast()); + } + if (key == "pad_info") { + PadInfo pad_info; + RETURN_IF_NOT_OK(ParsePadInfo(value, &pad_info)); + (void)builder->SetPadInfo(pad_info); + } + if (key == "pad_to_bucket_boundary") { + (void)builder->SetPadToBucketBoundary(ToBool(value)); + } + if (key == "drop_remainder") { + (void)builder->SetDropRemainder(ToBool(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + int32_t prefetch_size = 0; + if (args.contains("prefetch_size")) { + if (args["prefetch_size"].is_none()) { + prefetch_size = 16; + } else { + prefetch_size = ToInt(args["prefetch_size"]); + } + } + std::shared_ptr builder = std::make_shared(prefetch_size); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "queue_name") { + (void)builder->SetChannelName(ToString(value)); + } else if (key == "device_type") { + (void)builder->SetDeviceType(ToString(value)); + } else if (key == "device_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "num_batch") { + (void)builder->SetNumBatch(ToInt(value)); + } 
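+      // Note: unlike ParseMapOp and ParseFilterOp, unrecognized keys are silently ignored here
+      // instead of failing with an "Unhandled key" error.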
+ } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector in_col_names; + std::vector out_col_names; + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "input_columns") { + in_col_names = ToStringVector(value); + } else if (key == "output_columns") { + out_col_names = ToStringVector(value); + } + } + } + if (in_col_names.empty() || in_col_names[0].empty()) { + std::string err_msg = "Error: input_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (out_col_names.empty() || out_col_names[0].empty()) { + std::string err_msg = "Error: output_column_names is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + (void)builder->SetInColNames(in_col_names); + (void)builder->SetOutColNames(out_col_names); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["count"].is_none()) { + std::string err_msg = "Error: count is invalid or not set."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr op; + RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetDatasetFilesList(files_list); + } else { + std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + bool shuffle_required = false; + int64_t num_devices = 0; + int64_t total_rows = 0; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + (void)builder->SetColumnsToLoad(columns_to_load); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "num_samples") { + total_rows = ToInt(value); + 
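+        // total_rows is kept locally as well: it caps the row count fed to ComputeShuffleSize() below.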
(void)builder->setTotalRows(total_rows);
+      } else if (key == "num_shards") {
+        num_devices = ToInt(value);
+        (void)builder->SetNumDevices(num_devices);
+      } else if (key == "shard_id") {
+        (void)builder->SetDeviceId(ToInt(value));
+      } else if (key == "shard_equal_rows") {
+        (void)builder->SetShardEqualRows(ToBool(value));
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
+      } else if (key == "sampler") {
+        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
+        sampler = create().cast<std::shared_ptr<Sampler>>();
+      }
+    }
+  }
+  if (schema_exists) {
+    std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
+    if (args.contains("schema_file_path")) {
+      RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load));
+    } else {
+      RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load));
+    }
+    (void)builder->SetDataSchema(std::move(schema));
+  }
+
+  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
+  // because TFReaderOp is a non-mappable dataset that does not support sampling.
+  // However, if a cache operator is injected at some other place higher in the tree, that cache can
+  // inherit this sampler from the leaf, providing sampling support from the caching layer.
+  // That is why we save the sampler here in a leaf node that does not use sampling.
+  if (sampler) {
+    (void)builder->SetSampler(std::move(sampler));
+  } else if (cache_client) {
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    (void)builder->SetSampler(std::move(sampler));
+  }
+
+  std::shared_ptr<TFReaderOp> tf_op;
+  RETURN_IF_NOT_OK(builder->Build(&tf_op));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
+  *top = tf_op;
+
+  if (!cache_client && shuffle_required) {
+    const bool estimate = true;
+    const int64_t workers = 8;
+    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
+    int64_t shuffle_size = 0;
+    int64_t num_rows = 0;
+
+    // First, get the number of rows in the dataset via estimate and then compute the shuffle size
+    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate));
+    RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size));
+
+    // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller
+    RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op));
+    *top = shuffle_op;
+    *bottom = tf_op;
+  }
+
+  // Add a cache op over this op if required and update the output subtree (top/bottom)
+  if (cache_client) {
+    // Note, it is not allowed to have both shuffle and cache
+    std::shared_ptr<DatasetOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
+    *top = cache_op;
+    *bottom = tf_op;
+  }
+
+  return Status::OK();
+}
+
+Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+                                  std::shared_ptr<DatasetOp> *bottom) {
+  if (args["columns"].is_none()) {
+    std::string err_msg = "Error: columns is missing";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::vector<std::string> columns_to_project = ToStringVector(args["columns"]);
+  std::shared_ptr<ProjectOp::Builder> builder = std::make_shared<ProjectOp::Builder>(columns_to_project);
+  std::shared_ptr<ProjectOp> op;
+  RETURN_IF_NOT_OK(builder->Build(&op));
+  *top = op;
+  return Status::OK();
+}
+
+Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
+                                      std::shared_ptr<DatasetOp> *bottom) {
+  // Required arguments
+  if (args["dataset_dir"].is_none()) {
+    std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg); + } + int num_workers = 0; + std::shared_ptr cache_client = nullptr; + std::shared_ptr builder = std::make_shared(); + (void)builder->SetImageFolderDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder->SetNumWorkers(num_workers); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } + } + } + std::shared_ptr if_op; + RETURN_IF_NOT_OK(builder->Build(&if_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(if_op)); + *top = if_op; + + // Additionally, add a cache if required. + // Note that this cache op is only acting as a place holder for the caching position + // within the tree. Later, a pre-pass will execute a tree transform to set up the actual + // caching logic in the tree. + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op)); + *top = cache_op; + *bottom = if_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_file"].is_none()) { + std::string err_msg = "Error: No dataset files specified for manifest"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::shared_ptr builder = std::make_shared(); + (void)builder->SetManifestFile(ToString(args["dataset_file"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "usage") { + (void)builder->SetUsage(ToString(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["mode"].is_none()) { + std::string err_msg = "Error: No mode specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetTask(ToString(args["task"])); + (void)builder->SetMode(ToString(args["mode"])); + for (auto arg : args) { + std::string key = 
py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "class_indexing") { + (void)builder->SetClassIndex(ToStringMap(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + + return Status::OK(); +} + +Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["annotation_file"].is_none()) { + std::string err_msg = "Error: No annotation_file specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + if (args["task"].is_none()) { + std::string err_msg = "Error: No task specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + (void)builder->SetFile(ToString(args["annotation_file"])); + (void)builder->SetTask(ToString(args["task"])); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(true); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetCifarDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); 
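+      // The "sampler" value handled below is the Python-side sampler object; its create() method
+      // returns the matching C++ sampler, which every mappable leaf parser casts the same way:
+      //   auto create = py::reinterpret_borrow<py::object>(value).attr("create");
+      //   std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();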
+ } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + + (void)builder->SetCifarType(false); + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + RandomDataOp::Builder builder; + std::shared_ptr cache_client = nullptr; + std::shared_ptr sampler = nullptr; + int num_workers = 0; + + if (args["total_rows"].is_none()) { + std::string err_msg = "Error: total_rows is a required argument"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::vector columns_to_load; + bool schema_exists = false; + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + num_workers = ToInt(value); + (void)builder.SetNumWorkers(num_workers); + } else if (key == "schema_file_path" || key == "schema_json_string") { + schema_exists = true; + } else if (key == "columns_list") { + columns_to_load = ToStringVector(value); + } else if (key == "total_rows") { + // This is not sampling here. The random data op needs to know how much data to generate. + (void)builder.SetTotalRows(ToInt(value)); + } else if (key == "cache") { + cache_client = value.cast>(); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + sampler = create().cast>(); + } + } + } + if (schema_exists) { + std::unique_ptr schema = std::make_unique(); + if (args.contains("schema_file_path")) { + RETURN_IF_NOT_OK(schema->LoadSchemaFile(ToString(args["schema_file_path"]), columns_to_load)); + } else { + RETURN_IF_NOT_OK(schema->LoadSchemaString(ToString(args["schema_json_string"]), columns_to_load)); + } + (void)builder.SetDataSchema(std::move(schema)); + } + + // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed + // because RandomDataOp is a non-mappable dataset that does not support sampling. + // However, if a cache operator is injected at some other place higher in the tree, that cache can + // inherit this sampler from the leaf, providing sampling support from the caching layer. + // That is why we save the sampler here in a leaf node that does not use sampling. 
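+  // Fallback below: with no user sampler but a cache present, a SequentialSampler(0, 0) is attached
+  // purely so the cache layer above has a sampler to inherit (num_samples = 0 and start_index = 0
+  // being, by convention here, "sample everything from the beginning").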
+ if (sampler) { + (void)builder.SetSampler(std::move(sampler)); + } else if (cache_client) { + int64_t num_samples = 0; + int64_t start_index = 0; + sampler = std::make_shared(num_samples, start_index); + (void)builder.SetSampler(std::move(sampler)); + } + + std::shared_ptr random_op = nullptr; + RETURN_IF_NOT_OK(builder.Build(&random_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(random_op)); + *top = random_op; + + // Add a cache op over this op if required and update the output subtree (top/bottom) + if (cache_client) { + std::shared_ptr cache_op = nullptr; + RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op)); + *top = cache_op; + *bottom = random_op; + } + + return Status::OK(); +} + +int32_t DEPipeline::GetNumClasses() const { return num_classes_; } + +Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::shared_ptr builder = std::make_shared(); + (void)builder->SetDir(ToString(args["dataset_dir"])); + + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + if (args["dataset_dir"].is_none()) { + std::string err_msg = "Error: No dataset path specified"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + + std::shared_ptr builder = std::make_shared(); + if (builder == nullptr) { + std::string err_msg = "Create celebaop builder failed"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); + } + (void)builder->SetCelebADir(ToString(args["dataset_dir"])); + for (const auto &arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "sampler") { + auto create = py::reinterpret_borrow(value).attr("create"); + std::shared_ptr sampler = create().cast>(); + (void)builder->SetSampler(std::move(sampler)); + } else if (key == "decode") { + (void)builder->SetDecode(ToBool(value)); + } else if (key == "extensions") { + (void)builder->SetExtensions(ToStringSet(value)); + } else if (key == "dataset_type") { + (void)builder->SetDatasetType(ToString(value)); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + // Required arguments + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetTextFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = 
false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetTotalRows(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + + std::shared_ptr txt_op; + RETURN_IF_NOT_OK(builder->Build(&txt_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); + *top = txt_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); + *top = shuffle_op; + *bottom = txt_op; + } + + return Status::OK(); +} + +Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + auto tp = py::reinterpret_borrow(p.second); + CHECK_FAIL_RETURN_UNEXPECTED(tp.size() == 2, "tuple in pad_info must be (list,int) or (list,float)"); + TensorShape shape = tp[0].is_none() ? TensorShape::CreateUnknownRankShape() : TensorShape(tp[0]); + std::shared_ptr pad_val = nullptr; + if (py::isinstance(tp[1])) { + std::string pad_val_string = tp[1].is_none() ? "" : ToString(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED( + Tensor::CreateTensor(&pad_val, std::vector{pad_val_string}, TensorShape::CreateScalar()), + "Cannot create pad_value Tensor"); + } else { + float pad_val_float = tp[1].is_none() ? 
0 : ToFloat(tp[1]); + CHECK_FAIL_RETURN_UNEXPECTED(Tensor::CreateTensor(&pad_val, TensorImpl::kFlexible, TensorShape::CreateScalar(), + DataType(DataType::DE_FLOAT32)), + "Cannot create pad_value Tensor"); + pad_val->SetItemAt({}, pad_val_float); + } + (void)pad_info->insert({ToString(p.first), {shape, pad_val}}); + } else { // tuple is None + (void)pad_info->insert({ToString(p.first), {TensorShape({}), nullptr}}); + } + } + return Status::OK(); +} + +Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::shared_ptr builder = std::make_shared(); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "freq_range") { + py::tuple tp = py::reinterpret_borrow(value); + if (!tp[0].is_none()) (void)builder->SetMinFreq(py::reinterpret_borrow(tp[0])); + if (!tp[1].is_none()) (void)builder->SetMaxFreq(py::reinterpret_borrow(tp[1])); + } else if (key == "top_k") { + builder->SetTopK(py::reinterpret_borrow(value)); + } else if (key == "columns") { + (void)builder->SetColumnNames(ToStringVector(value)); + } else if (key == "vocab") { + (void)builder->SetVocab(value.cast>()); + } else if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "special_first") { + (void)builder->SetSpecialFirst(ToBool(value)); + } else if (key == "special_tokens") { + (void)builder->SetSpecialTokens(ToStringVector(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *top = op; + return Status::OK(); +} + +Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetClueFilesList(files_list); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "shuffle_global") { + shuffle_required = ToBool(value); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } else if (key == "cols_to_keyword") { + std::map map_dict; + for (auto p : py::reinterpret_borrow(value)) { + if (!p.second.is_none()) { + map_dict.insert({ToString(p.first), ToString(p.second)}); + } else { + map_dict.insert({ToString(p.first), ToString(p.first)}); + } + } + (void)builder->SetColsKeyMap(map_dict); + } + } + } + + std::shared_ptr clue_op; + RETURN_IF_NOT_OK(builder->Build(&clue_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); + *top = clue_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); + 
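+    // Worked example for ComputeShuffleSize (defined near the end of this file); note ClueOp passes
+    // 0 for total_rows, so no cap applies. Assuming 4 files, 2 shards, and 100000 counted rows:
+    //   rows per shard    = 100000 / 2             = 50000
+    //   avg rows per file = 50000 / 4              = 12500
+    //   shuffle_size      = max(12500 * 4, 10000)  = 50000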
RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); + *top = shuffle_op; + *bottom = clue_op; + } + + return Status::OK(); +} + +// Helper function to inject the cache operator over top of the current operation being built. +Status DEPipeline::AddCacheOp(std::shared_ptr cache_client, int num_workers, + std::shared_ptr input_op, std::shared_ptr *cache_op) { + std::shared_ptr new_cache_op = nullptr; + CacheOp::Builder cache_builder; + // use the same number of workers as the leaf. We need some optimization here, the user does not + // give the cache op number of workers directly. + if (num_workers != 0) { + (void)cache_builder.SetNumWorkers(num_workers); + } + (void)cache_builder.SetClient(cache_client); + RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op)); + RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op)); + // We have now created: + // + // CacheOp + // | + // input_op + // + *cache_op = new_cache_op; + + return Status::OK(); +} + +// Helper function to inject a shuffle operator over top of the current operation being built. +Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op) { + std::shared_ptr new_shuffle_op = nullptr; + ShuffleOp::Builder shuffle_builder; + + (void)shuffle_builder.SetShuffleSize(shuffle_size); + RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); + RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); + // We have now created: + // + // ShuffleOp + // | + // input_op + // + *shuffle_op = new_shuffle_op; + + return Status::OK(); +} + +// Common code for computing a default shuffle size +Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size) { + const int64_t average_files_multiplier = 4; + const int64_t shuffle_max = 10000; + int64_t avg_rows_per_file = 0; + + // Adjust the num rows per shard if sharding was given + if (num_devices > 0) { + if (num_rows % num_devices == 0) { + num_rows = num_rows / num_devices; + } else { + num_rows = (num_rows / num_devices) + 1; + } + } + + // Cap based on total rows directive. Some ops do not have this and give value of 0. + if (total_rows > 0) { + num_rows = std::min(num_rows, total_rows); + } + + // get the average per file + avg_rows_per_file = num_rows / num_files; + + *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h new file mode 100644 index 0000000000..755e827ef2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/de_pipeline.h @@ -0,0 +1,225 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_API_DE_PIPELINE_H_ +#define DATASET_API_DE_PIPELINE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/client.h" // DE client +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/util/status.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +using DsOpPtr = std::shared_ptr; + +class CacheClient; + +// enum for the dataset operator names +enum OpName { + kShuffle, + kMindrecord, + kBatch, + kBucketBatch, + kBarrier, + kCache, + kRepeat, + kSkip, + kTake, + kZip, + kConcat, + kMap, + kFilter, + kDeviceQueue, + kGenerator, + kRename, + kTfReader, + kProject, + kImageFolder, + kMnist, + kManifest, + kVoc, + kCoco, + kCifar10, + kCifar100, + kCelebA, + kRandomData, + kTextFile, + kBuildVocab, + kClue +}; + +// The C++ binder class that we expose to the python script. +class DEPipeline { + public: + DEPipeline(); + + ~DEPipeline(); + + // Function to add a Node to the Execution Tree. + Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); + + // Function to add a child and parent relationship. + static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); + + // Function to assign the node as root. + Status AssignRootNode(const DsOpPtr &dataset_op); + + // Function to launch the tree execution. + Status LaunchTreeExec(); + + // Get a row of data as dictionary of column name to the value. + Status GetNextAsMap(py::dict *output); + + // Get a row of data as list. 
+ Status GetNextAsList(py::list *output); + + Status GetOutputShapes(py::list *output); + + Status GetOutputTypes(py::list *output); + + int GetDatasetSize() const; + + int GetBatchSize() const; + + int GetRepeatCount() const; + + Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status BuildMindrecordSamplerChain(const py::handle &handle, + std::vector> *operators, + int num_padded); + + Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom); + + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + void PrintTree(); + + int32_t GetNumClasses() const; + + Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status SetBatchParameters(const py::dict &args); + + Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); + + private: + // Execution tree that links the dataset operators. + std::shared_ptr tree_; + + std::unique_ptr iterator_; + + static Status ParsePadInfo(py::handle value, PadInfo *pad_info); + + /// \brief Helper function to inject a cache operator over top of the current operation being built. 
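+  /// A sketch of the call sites in de_pipeline.cc: given a freshly built leaf op (leaf_op below is a
+  /// placeholder), the helper yields the two-node subtree CacheOp -> leaf_op:
+  ///   std::shared_ptr<DatasetOp> cache_op = nullptr;
+  ///   RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, leaf_op, &cache_op));
+  ///   *top = cache_op;
+  ///   *bottom = leaf_op;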
+  /// \param[in] cache_client The client to use for caching
+  /// \param[in] num_workers The number of workers to use in the cache op
+  /// \param[in] input_op The operator to build the cache on top of
+  /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
+  ///     the cache operator
+  /// \return Status return code
+  Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
+                    std::shared_ptr<DatasetOp> *cache_op);
+
+  /// \brief Helper function to inject a shuffle operator over top of the current operation being built.
+  /// \param[in] shuffle_size The size to use in the shuffle buffer
+  /// \param[in] input_op The operator to build shuffle on top of
+  /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). In this case it will be
+  ///     the shuffle operator
+  /// \return Status return code
+  Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
+                      std::shared_ptr<DatasetOp> *shuffle_op);
+
+  /// \brief Helper function to compute the shuffle size
+  /// \param[in] num_files The number of files in the dataset
+  /// \param[in] num_devices The number of devices in the dataset
+  /// \param[in] num_rows The number of rows in the dataset
+  /// \param[in] total_rows An upper bound on the total rows in the dataset
+  /// \param[out] shuffle_size The resultant computed shuffle size
+  /// \return Status return code
+  Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
+                            int64_t *shuffle_size);
+
+  int batch_size_;
+  int repeat_num_;
+  int num_rows_;
+  int num_classes_;
+
+  int temp_batch_size_;
+  bool temp_drop_remainder_;
+};
+} // namespace dataset
+} // namespace mindspore
+#endif // DATASET_API_DE_PIPELINE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/api/iterator.cc b/mindspore/ccsrc/minddata/dataset/api/iterator.cc
new file mode 100644
index 0000000000..068bcfaa04
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/api/iterator.cc
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/include/iterator.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/include/datasets.h"
+
+namespace mindspore {
+namespace dataset {
+namespace api {
+
+// Get the next row from the data pipeline.
+void Iterator::GetNextRow(TensorMap *row) {
+  Status rc = iterator_->GetNextAsMap(row);
+  if (rc.IsError()) {
+    MS_LOG(ERROR) << "GetNextRow: Failed to get next row.";
+    row->clear();
+  }
+}
+
+// Shut down the data pipeline.
+void Iterator::Stop() {
+  // Release the iterator_ unique_ptr. This should trigger the destructor of iterator_.
+  iterator_.reset();
+
+  // Release ownership of tree_ shared pointer. This will decrement the ref count.
+  tree_.reset();
+}
+
+// Function to build and launch the execution tree.
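+// The api::Dataset tree built by the user is converted here, breadth-first, into the runtime
+// ExecutionTree. A usage sketch (assuming the ImageFolder factory and the Dataset::Batch /
+// Dataset::CreateIterator methods declared in minddata/dataset/include/datasets.h; the path is
+// illustrative):
+//   std::shared_ptr<Dataset> ds = ImageFolder("/path/to/images")->Batch(32);
+//   std::shared_ptr<Iterator> itr = ds->CreateIterator();  // calls BuildAndLaunchTree(ds)
+//   TensorMap row;
+//   itr->GetNextRow(&row);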
+Status Iterator::BuildAndLaunchTree(std::shared_ptr ds) { + // One time init + Status rc; + rc = GlobalInit(); + RETURN_IF_NOT_OK(rc); + + // Instantiate the execution tree + tree_ = std::make_shared(); + + // Iterative BFS converting Dataset tree into runtime Execution tree. + std::queue, std::shared_ptr>> q; + + if (ds != nullptr) { + // Convert the current root node. + auto root_op = ds->Build()->front(); + RETURN_UNEXPECTED_IF_NULL(root_op); + + RETURN_IF_NOT_OK(tree_->AssociateNode(root_op)); + + q.push(std::make_pair(ds, root_op)); + + // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes) + while (!q.empty()) { + auto node_pair = q.front(); + q.pop(); + // Iterate through all the direct children of the first element in our BFS queue + for (auto child : node_pair.first->children) { + auto child_ops = child->Build(); + RETURN_UNEXPECTED_IF_NULL(child_ops); + auto node_op = node_pair.second; + // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them + // with the execution tree and add the child and parent relationship between the nodes + // Note that some Dataset objects might return more than one DatasetOps + // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset + for (auto child_op : *child_ops) { + RETURN_IF_NOT_OK(tree_->AssociateNode(child_op)); + RETURN_IF_NOT_OK(node_op->AddChild(child_op)); + node_op = child_op; + } + // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current + // execution tree) to the BFS queue + q.push(std::make_pair(child, child_ops->back())); + } + } + RETURN_IF_NOT_OK(tree_->AssignRoot(root_op)); + } + + // Launch the execution tree. + RETURN_IF_NOT_OK(tree_->Prepare()); + RETURN_IF_NOT_OK(tree_->Launch()); + iterator_ = std::make_unique(tree_); + RETURN_UNEXPECTED_IF_NULL(iterator_); + + return rc; +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc new file mode 100644 index 0000000000..145291ec3b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -0,0 +1,954 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include "minddata/dataset/api/de_pipeline.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/manifest_op.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/engine/datasetops/source/voc_op.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" +#include "minddata/dataset/kernels/data/duplicate_op.h" +#include "minddata/dataset/kernels/data/fill_op.h" +#include "minddata/dataset/kernels/data/mask_op.h" +#include "minddata/dataset/kernels/data/one_hot_op.h" +#include "minddata/dataset/kernels/data/pad_end_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" +#include "minddata/dataset/kernels/data/to_float16_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_resize_op.h" +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" +#include "minddata/dataset/kernels/image/rescale_op.h" +#include 
"minddata/dataset/kernels/image/resize_bilinear_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" +#include "minddata/dataset/kernels/no_op.h" +#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h" +#include "minddata/dataset/text/kernels/lookup_op.h" +#include "minddata/dataset/text/kernels/ngram_op.h" +#include "minddata/dataset/text/kernels/to_number_op.h" +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/util/random.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_sequential_sample.h" +#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "pybind11/stl_bind.h" + +#ifdef ENABLE_ICU4C +#include "minddata/dataset/text/kernels/basic_tokenizer_op.h" +#include "minddata/dataset/text/kernels/bert_tokenizer_op.h" +#include "minddata/dataset/text/kernels/case_fold_op.h" +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" +#endif + +namespace py = pybind11; + +namespace mindspore { +namespace dataset { +#define THROW_IF_ERROR(s) \ + do { \ + Status rc = std::move(s); \ + if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ + } while (false) + +void bindDEPipeline(py::module *m) { + (void)py::class_(*m, "DEPipeline") + .def(py::init<>()) + .def( + "AddNodeToTree", + [](DEPipeline &de, const OpName &op_name, const py::dict &args) { + py::dict out; + THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); + return out; + }, + py::return_value_policy::reference) + .def_static("AddChildToParentNode", + [](const DsOpPtr &child_op, const DsOpPtr &parent_op) { + THROW_IF_ERROR(DEPipeline::AddChildToParentNode(child_op, parent_op)); + }) + .def("AssignRootNode", + [](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); }) + .def("SetBatchParameters", + [](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); }) + .def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); }) + .def("GetNextAsMap", + [](DEPipeline &de) { + py::dict out; + THROW_IF_ERROR(de.GetNextAsMap(&out)); + return out; + }) + .def("GetNextAsList", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetNextAsList(&out)); + return out; + }) + .def("GetOutputShapes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputShapes(&out)); + return out; + }) + .def("GetOutputTypes", + [](DEPipeline &de) { + py::list out; + THROW_IF_ERROR(de.GetOutputTypes(&out)); + return out; + }) + .def("GetDatasetSize", &DEPipeline::GetDatasetSize) + .def("GetBatchSize", &DEPipeline::GetBatchSize) + .def("GetNumClasses", &DEPipeline::GetNumClasses) + .def("GetRepeatCount", &DEPipeline::GetRepeatCount); +} +void bindDatasetOps(py::module *m) { + 
(void)py::class_>(*m, "TFReaderOp") + .def_static("get_num_rows", [](const py::list &files, int64_t numParallelWorkers, bool estimate = false) { + int64_t count = 0; + std::vector filenames; + for (auto l : files) { + !l.is_none() ? filenames.push_back(py::str(l)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TFReaderOp::CountTotalRows(&count, filenames, numParallelWorkers, estimate)); + return count; + }); + + (void)py::class_>(*m, "CifarOp") + .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { + int64_t count = 0; + THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); + return count; + }); + + (void)py::class_>(*m, "ImageFolderOp") + .def_static("get_num_rows_and_classes", [](const std::string &path) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set{}, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }); + + (void)py::class_>(*m, "MindRecordOp") + .def_static("get_num_rows", [](const std::vector &paths, bool load_dataset, const py::object &sampler, + const int64_t num_padded) { + int64_t count = 0; + std::shared_ptr op; + if (py::hasattr(sampler, "create_for_minddataset")) { + auto create = sampler.attr("create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(paths, load_dataset, op, &count, num_padded)); + return count; + }); + + (void)py::class_>(*m, "ManifestOp") + .def_static("get_num_rows_and_classes", + [](const std::string &file, const py::dict &dict, const std::string &usage) { + int64_t count = 0, num_classes = 0; + THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); + return py::make_tuple(count, num_classes); + }) + .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { + std::map output_class_indexing; + THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); + return output_class_indexing; + }); + + (void)py::class_>(*m, "MnistOp") + .def_static("get_num_rows", [](const std::string &dir) { + int64_t count = 0; + THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); + return count; + }); + + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); + + (void)py::class_>(*m, "ClueOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + file.is_none() ? 
(void)filenames.emplace_back("") : filenames.push_back(py::str(file)); + } + THROW_IF_ERROR(ClueOp::CountAllFileRows(filenames, &count)); + return count; + }); + + (void)py::class_>(*m, "VOCOp") + .def_static("get_num_rows", + [](const std::string &dir, const std::string &task_type, const std::string &task_mode, + const py::dict &dict, int64_t numSamples) { + int64_t count = 0; + THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); + return count; + }) + .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, + const std::string &task_mode, const py::dict &dict) { + std::map output_class_indexing; + THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); + return output_class_indexing; + }); + (void)py::class_>(*m, "CocoOp") + .def_static("get_class_indexing", + [](const std::string &dir, const std::string &file, const std::string &task) { + std::vector>> output_class_indexing; + THROW_IF_ERROR(CocoOp::GetClassIndexing(dir, file, task, &output_class_indexing)); + return output_class_indexing; + }) + .def_static("get_num_rows", [](const std::string &dir, const std::string &file, const std::string &task) { + int64_t count = 0; + THROW_IF_ERROR(CocoOp::CountTotalRows(dir, file, task, &count)); + return count; + }); +} +void bindTensor(py::module *m) { + (void)py::class_(*m, "GlobalContext") + .def_static("config_manager", &GlobalContext::config_manager, py::return_value_policy::reference); + + (void)py::class_>(*m, "ConfigManager") + .def("__str__", &ConfigManager::ToString) + .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer) + .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers) + .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size) + .def("set_op_connector_size", &ConfigManager::set_op_connector_size) + .def("set_seed", &ConfigManager::set_seed) + .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval) + .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer) + .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers) + .def("get_worker_connector_size", &ConfigManager::worker_connector_size) + .def("get_op_connector_size", &ConfigManager::op_connector_size) + .def("get_seed", &ConfigManager::seed) + .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval) + .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); }); + + (void)py::class_>(*m, "Tensor", py::buffer_protocol()) + .def(py::init([](py::array arr) { + std::shared_ptr out; + THROW_IF_ERROR(Tensor::CreateTensor(&out, arr)); + return out; + })) + .def_buffer([](Tensor &tensor) { + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return info; + }) + .def("__str__", &Tensor::ToString) + .def("shape", &Tensor::shape) + .def("type", &Tensor::type) + .def("as_array", [](py::object &t) { + auto &tensor = py::cast(t); + if (tensor.type() == DataType::DE_STRING) { + py::array res; + tensor.GetDataAsNumpyStrings(&res); + return res; + } + py::buffer_info info; + THROW_IF_ERROR(Tensor::GetBufferInfo(&tensor, &info)); + return py::array(pybind11::dtype(info), info.shape, info.strides, info.ptr, t); + }); + + (void)py::class_(*m, "TensorShape") + .def(py::init()) + .def("__str__", &TensorShape::ToString) + .def("as_list", &TensorShape::AsPyList) + .def("is_known", &TensorShape::known); + + (void)py::class_(*m, "DataType") + .def(py::init()) + .def(py::self 
== py::self)
+    .def("__str__", &DataType::ToString)
+    .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
+}
+
+void bindTensorOps1(py::module *m) {
+  (void)py::class_<TensorOp, std::shared_ptr<TensorOp>>(*m, "TensorOp")
+    .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; });
+
+  (void)py::class_<NormalizeOp, TensorOp, std::shared_ptr<NormalizeOp>>(
+    *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.")
+    .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
+         py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
+
+  (void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
+    *m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
+    .def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
+
+  (void)py::class_<CenterCropOp, TensorOp, std::shared_ptr<CenterCropOp>>(
+    *m, "CenterCropOp", "Tensor operation to crop an image in the middle. Takes height and width (optional).")
+    .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
+
+  (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
+    *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode.")
+    .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
+         py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
+
+  (void)py::class_<ResizeWithBBoxOp, TensorOp, std::shared_ptr<ResizeWithBBoxOp>>(
+    *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.")
+    .def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
+         py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth,
+         py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
+
+  (void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
+    *m, "RandomResizeWithBBoxOp",
+    "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
+    .def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
+         py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth);
+
+  (void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
+    *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
+    .def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
+         py::arg("NumOps") = UniformAugOp::kDefNumOps);
+
+  (void)py::class_<BoundingBoxAugmentOp, TensorOp, std::shared_ptr<BoundingBoxAugmentOp>>(
+    *m, "BoundingBoxAugmentOp", "Tensor operation to apply a transformation on a random choice of bounding boxes.")
+    .def(py::init<std::shared_ptr<TensorOp>, float>(), py::arg("transform"),
+         py::arg("ratio") = BoundingBoxAugmentOp::kDefRatio);
+
+  (void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
+    *m, "ResizeBilinearOp",
+    "Tensor operation to resize an image using "
+    "Bilinear mode. 
Takes height and width.") + .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeBilinearOp::kDefWidth); + + (void)py::class_>(*m, "DecodeOp", + "Tensor operation to decode a jpg image") + .def(py::init<>()) + .def(py::init(), py::arg("rgb_format") = DecodeOp::kDefRgbFormat); + + (void)py::class_>( + *m, "RandomHorizontalFlipOp", "Tensor operation to randomly flip an image horizontally.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipOp::kDefProbability); + + (void)py::class_>( + *m, "RandomHorizontalFlipWithBBoxOp", + "Tensor operation to randomly flip an image horizontally, while flipping bounding boxes.") + .def(py::init(), py::arg("probability") = RandomHorizontalFlipWithBBoxOp::kDefProbability); +} + +void bindTensorOps2(py::module *m) { + (void)py::class_>( + *m, "RandomVerticalFlipOp", "Tensor operation to randomly flip an image vertically.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipOp::kDefProbability); + + (void)py::class_>( + *m, "RandomVerticalFlipWithBBoxOp", + "Tensor operation to randomly flip an image vertically" + " and adjust bounding boxes.") + .def(py::init(), py::arg("probability") = RandomVerticalFlipWithBBoxOp::kDefProbability); + + (void)py::class_>(*m, "RandomCropOp", + "Gives random crop of specified size " + "Takes crop size") + .def(py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropOp::kDefPadTop, + py::arg("padBottom") = RandomCropOp::kDefPadBottom, py::arg("padLeft") = RandomCropOp::kDefPadLeft, + py::arg("padRight") = RandomCropOp::kDefPadRight, py::arg("borderType") = RandomCropOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropOp::kDefPadIfNeeded, py::arg("fillR") = RandomCropOp::kDefFillR, + py::arg("fillG") = RandomCropOp::kDefFillG, py::arg("fillB") = RandomCropOp::kDefFillB); + (void)py::class_>(*m, "ChannelSwapOp").def(py::init<>()); + + (void)py::class_>(*m, "RandomCropWithBBoxOp", + "Gives random crop of given " + "size + adjusts bboxes " + "Takes crop size") + .def(py::init(), + py::arg("cropHeight"), py::arg("cropWidth"), py::arg("padTop") = RandomCropWithBBoxOp::kDefPadTop, + py::arg("padBottom") = RandomCropWithBBoxOp::kDefPadBottom, + py::arg("padLeft") = RandomCropWithBBoxOp::kDefPadLeft, + py::arg("padRight") = RandomCropWithBBoxOp::kDefPadRight, + py::arg("borderType") = RandomCropWithBBoxOp::kDefBorderType, + py::arg("padIfNeeded") = RandomCropWithBBoxOp::kDefPadIfNeeded, + py::arg("fillR") = RandomCropWithBBoxOp::kDefFillR, py::arg("fillG") = RandomCropWithBBoxOp::kDefFillG, + py::arg("fillB") = RandomCropWithBBoxOp::kDefFillB); + + (void)py::class_>( + *m, "OneHotOp", "Tensor operation to apply one hot encoding. 
Takes number of classes.") + .def(py::init()); + + (void)py::class_>( + *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") + .def(py::init>()); + + (void)py::class_>(*m, "SliceOp", "Tensor slice operation.") + .def(py::init()) + .def(py::init([](const py::list &py_list) { + std::vector c_list; + for (auto l : py_list) { + if (!l.is_none()) { + c_list.push_back(py::reinterpret_borrow(l)); + } + } + return std::make_shared(c_list); + })) + .def(py::init([](const py::tuple &py_slice) { + if (py_slice.size() != 3) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + Slice c_slice; + if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), + py::reinterpret_borrow(py_slice[2])); + } else if (py_slice[0].is_none() && py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[1])); + } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); + } + + if (!c_slice.valid()) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + return std::make_shared(c_slice); + })); + + (void)py::enum_(*m, "RelationalOp", py::arithmetic()) + .value("EQ", RelationalOp::kEqual) + .value("NE", RelationalOp::kNotEqual) + .value("LT", RelationalOp::kLess) + .value("LE", RelationalOp::kLessEqual) + .value("GT", RelationalOp::kGreater) + .value("GE", RelationalOp::kGreaterEqual) + .export_values(); + + (void)py::class_>(*m, "MaskOp", + "Tensor mask operation using relational comparator") + .def(py::init, DataType>()); + + (void)py::class_>(*m, "DuplicateOp", "Duplicate tensor.") + .def(py::init<>()); + + (void)py::class_>( + *m, "TruncateSequencePairOp", "Tensor operation to truncate two tensors to a max_length") + .def(py::init()); + + (void)py::class_>(*m, "ConcatenateOp", + "Tensor operation concatenate tensors.") + .def(py::init, std::shared_ptr>(), py::arg("axis"), + py::arg("prepend").none(true), py::arg("append").none(true)); + + (void)py::class_>( + *m, "RandomRotationOp", + "Tensor operation to apply RandomRotation." + "Takes a range for degrees and " + "optional parameters for rotation center and image expand") + .def(py::init(), + py::arg("startDegree"), py::arg("endDegree"), py::arg("centerX") = RandomRotationOp::kDefCenterX, + py::arg("centerY") = RandomRotationOp::kDefCenterY, + py::arg("interpolation") = RandomRotationOp::kDefInterpolation, + py::arg("expand") = RandomRotationOp::kDefExpand, py::arg("fillR") = RandomRotationOp::kDefFillR, + py::arg("fillG") = RandomRotationOp::kDefFillG, py::arg("fillB") = RandomRotationOp::kDefFillB); + + (void)py::class_>( + *m, "PadEndOp", "Tensor operation to pad end of tensor with a pad value.") + .def(py::init>()); +} + +void bindTensorOps3(py::module *m) { + (void)py::class_>( + *m, "RandomCropAndResizeOp", + "Tensor operation to randomly crop an image and resize to a given size." 
+ "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeOp::kDefMaxIter); + + (void)py::class_>( + *m, "RandomCropAndResizeWithBBoxOp", + "Tensor operation to randomly crop an image (with BBoxes) and resize to a given size." + "Takes output height and width and" + "optional parameters for lower and upper bound for aspect ratio (h/w) and scale," + "interpolation mode, and max attempts to crop") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropAndResizeWithBBoxOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropAndResizeWithBBoxOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropAndResizeWithBBoxOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropAndResizeWithBBoxOp::kDefAspectUb, + py::arg("interpolation") = RandomCropAndResizeWithBBoxOp::kDefInterpolation, + py::arg("maxIter") = RandomCropAndResizeWithBBoxOp::kDefMaxIter); + + (void)py::class_>( + *m, "RandomColorAdjustOp", + "Tensor operation to adjust an image's color randomly." + "Takes range for brightness, contrast, saturation, hue and") + .def(py::init(), py::arg("bright_factor_start"), + py::arg("bright_factor_end"), py::arg("contrast_factor_start"), py::arg("contrast_factor_end"), + py::arg("saturation_factor_start"), py::arg("saturation_factor_end"), py::arg("hue_factor_start"), + py::arg("hue_factor_end")); + + (void)py::class_>( + *m, "RandomResizeOp", + "Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeOp::kDefTargetWidth); + + (void)py::class_>( + *m, "CutOutOp", "Tensor operation to randomly erase a portion of the image. 
Takes height and width.") + .def(py::init(), py::arg("boxHeight"), + py::arg("boxWidth"), py::arg("numPatches"), py::arg("randomColor") = CutOutOp::kDefRandomColor, + py::arg("fillR") = CutOutOp::kDefFillR, py::arg("fillG") = CutOutOp::kDefFillG, + py::arg("fillB") = CutOutOp::kDefFillB); +} + +void bindTensorOps4(py::module *m) { + (void)py::class_>( + *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); + + (void)py::class_>(*m, "NoOp", + "TensorOp that does nothing, for testing purposes only.") + .def(py::init<>()); + + (void)py::class_>( + *m, "ToFloat16Op", py::dynamic_attr(), "Tensor operator to type cast float32 data to a float16 type.") + .def(py::init<>()); + + (void)py::class_>( + *m, "RandomCropDecodeResizeOp", "equivalent to RandomCropAndResize but crops before decoding") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth"), py::arg("scaleLb") = RandomCropDecodeResizeOp::kDefScaleLb, + py::arg("scaleUb") = RandomCropDecodeResizeOp::kDefScaleUb, + py::arg("aspectLb") = RandomCropDecodeResizeOp::kDefAspectLb, + py::arg("aspectUb") = RandomCropDecodeResizeOp::kDefAspectUb, + py::arg("interpolation") = RandomCropDecodeResizeOp::kDefInterpolation, + py::arg("maxIter") = RandomCropDecodeResizeOp::kDefMaxIter); + + (void)py::class_>( + *m, "PadOp", + "Pads image with specified color, default black, " + "Takes amount to pad for top, bottom, left, right of image, boarder type and color") + .def(py::init(), py::arg("padTop"), + py::arg("padBottom"), py::arg("padLeft"), py::arg("padRight"), py::arg("borderTypes") = PadOp::kDefBorderType, + py::arg("fillR") = PadOp::kDefFillR, py::arg("fillG") = PadOp::kDefFillG, py::arg("fillB") = PadOp::kDefFillB); + (void)py::class_>(*m, "ToNumberOp", + "TensorOp to convert strings to numbers.") + .def(py::init(), py::arg("data_type")) + .def(py::init(), py::arg("data_type")); +} + +void bindTokenizerOps(py::module *m) { + (void)py::class_>(*m, "JiebaTokenizerOp", "") + .def(py::init(), py::arg("hmm_path"), + py::arg("mp_path"), py::arg("mode") = JiebaMode::kMix, + py::arg("with_offsets") = JiebaTokenizerOp::kDefWithOffsets) + .def("add_word", + [](JiebaTokenizerOp &self, const std::string word, int freq) { THROW_IF_ERROR(self.AddWord(word, freq)); }); + (void)py::class_>( + *m, "UnicodeCharTokenizerOp", "Tokenize a scalar tensor of UTF-8 string to Unicode characters.") + .def(py::init(), py::arg("with_offsets") = UnicodeCharTokenizerOp::kDefWithOffsets); + (void)py::class_>(*m, "LookupOp", + "Tensor operation to LookUp each word.") + .def(py::init([](std::shared_ptr vocab, const py::object &py_word) { + if (vocab == nullptr) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, "vocab object type is incorrect or null.")); + } + if (py_word.is_none()) { + return std::make_shared(vocab, Vocab::kNoTokenExists); + } + std::string word = py::reinterpret_borrow(py_word); + WordIdType default_id = vocab->Lookup(word); + if (default_id == Vocab::kNoTokenExists) { + THROW_IF_ERROR( + Status(StatusCode::kUnexpectedError, "default unknown token:" + word + " doesn't exist in vocab.")); + } + return std::make_shared(vocab, default_id); + })); + (void)py::class_>(*m, "NgramOp", "TensorOp performs ngram mapping.") + .def(py::init &, int32_t, int32_t, const std::string &, const std::string &, + const std::string &>(), + py::arg("ngrams"), py::arg("l_pad_len"), py::arg("r_pad_len"), py::arg("l_pad_token"), py::arg("r_pad_token"), + 
py::arg("separator")); + (void)py::class_>( + *m, "WordpieceTokenizerOp", "Tokenize scalar token or 1-D tokens to subword tokens.") + .def( + py::init &, const std::string &, const int &, const std::string &, const bool &>(), + py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); +} + +void bindDependIcuTokenizerOps(py::module *m) { +#ifdef ENABLE_ICU4C + (void)py::class_>( + *m, "WhitespaceTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on ICU defined whitespaces.") + .def(py::init(), py::arg("with_offsets") = WhitespaceTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "UnicodeScriptTokenizerOp", "Tokenize a scalar tensor of UTF-8 string on Unicode script boundaries.") + .def(py::init<>()) + .def(py::init(), + py::arg("keep_whitespace") = UnicodeScriptTokenizerOp::kDefKeepWhitespace, + py::arg("with_offsets") = UnicodeScriptTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "CaseFoldOp", "Apply case fold operation on utf-8 string tensor") + .def(py::init<>()); + (void)py::class_>( + *m, "NormalizeUTF8Op", "Apply normalize operation on utf-8 string tensor.") + .def(py::init<>()) + .def(py::init(), py::arg("normalize_form") = NormalizeUTF8Op::kDefNormalizeForm); + (void)py::class_>( + *m, "RegexReplaceOp", "Replace utf-8 string tensor with 'replace' according to regular expression 'pattern'.") + .def(py::init(), py::arg("pattern"), py::arg("replace"), + py::arg("replace_all")); + (void)py::class_>( + *m, "RegexTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by regex expression pattern.") + .def(py::init(), py::arg("delim_pattern"), + py::arg("keep_delim_pattern"), py::arg("with_offsets") = RegexTokenizerOp::kDefWithOffsets); + (void)py::class_>( + *m, "BasicTokenizerOp", "Tokenize a scalar tensor of UTF-8 string by specific rules.") + .def(py::init(), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = BasicTokenizerOp::kDefWithOffsets); + (void)py::class_>(*m, "BertTokenizerOp", + "Tokenizer used for Bert text process.") + .def(py::init &, const std::string &, const int &, const std::string &, const bool &, + const bool &, const NormalizeForm &, const bool &, const bool &>(), + py::arg("vocab"), py::arg("suffix_indicator") = std::string(WordpieceTokenizerOp::kDefSuffixIndicator), + py::arg("max_bytes_per_token") = WordpieceTokenizerOp::kDefMaxBytesPerToken, + py::arg("unknown_token") = std::string(WordpieceTokenizerOp::kDefUnknownToken), + py::arg("lower_case") = BasicTokenizerOp::kDefLowerCase, + py::arg("keep_whitespace") = BasicTokenizerOp::kDefKeepWhitespace, + py::arg("normalization_form") = BasicTokenizerOp::kDefNormalizationForm, + py::arg("preserve_unused_token") = BasicTokenizerOp::kDefPreserveUnusedToken, + py::arg("with_offsets") = WordpieceTokenizerOp::kDefWithOffsets); +#endif +} + +void bindSamplerOps(py::module *m) { + (void)py::class_>(*m, "Sampler") + .def("set_num_rows", [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); }) + .def("set_num_samples", [](Sampler &self, int64_t 
samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); }) + .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); }) + .def("get_indices", + [](Sampler &self) { + py::array ret; + THROW_IF_ERROR(self.GetAllIdsThenReset(&ret)); + return ret; + }) + .def("add_child", + [](std::shared_ptr self, std::shared_ptr child) { THROW_IF_ERROR(self->AddChild(child)); }); + + (void)py::class_>(*m, "ShardOperator") + .def("add_child", [](std::shared_ptr self, + std::shared_ptr child) { self->SetChildOp(child); }); + + (void)py::class_>(*m, "DistributedSampler") + .def(py::init()); + + (void)py::class_>(*m, "PKSampler") + .def(py::init()); + + (void)py::class_>(*m, "RandomSampler") + .def(py::init()); + + (void)py::class_>(*m, "SequentialSampler") + .def(py::init()); + + (void)py::class_>(*m, "SubsetRandomSampler") + .def(py::init>()); + + (void)py::class_>( + *m, "MindrecordSubsetRandomSampler") + .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { + if (shuffle == true) { + return std::make_shared(kColumn, kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared(kColumn, kVal); + } + })); + + (void)py::class_>(*m, "MindrecordDistributedSampler") + .def(py::init()); + + (void)py::class_>( + *m, "MindrecordRandomSampler") + .def(py::init([](int64_t num_samples, bool replacement, bool reshuffle_each_epoch) { + return std::make_shared(GetSeed(), num_samples, replacement, reshuffle_each_epoch); + })); + + (void)py::class_>(*m, "MindrecordSequentialSampler") + .def(py::init([](int num_samples, int start_index) { + return std::make_shared(num_samples, start_index); + })); + + (void)py::class_>(*m, "WeightedRandomSampler") + .def(py::init, bool>()); + + (void)py::class_>(*m, "PythonSampler") + .def(py::init()); +} + +void bindInfoObjects(py::module *m) { + (void)py::class_(*m, "CBatchInfo") + .def(py::init()) + .def("get_epoch_num", &BatchOp::CBatchInfo::get_epoch_num) + .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num); +} + +void bindCacheClient(py::module *m) { + (void)py::class_>(*m, "CacheClient") + .def(py::init()); +} + +void bindVocabObjects(py::module *m) { + (void)py::class_>(*m, "Vocab") + .def(py::init<>()) + .def_static("from_list", + [](const py::list &words, const py::list &special_tokens, bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyList(words, special_tokens, special_first, &v)); + return v; + }) + .def_static("from_file", + [](const std::string &path, const std::string &dlm, int32_t vocab_size, const py::list &special_tokens, + bool special_first) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromFile(path, dlm, vocab_size, special_tokens, special_first, &v)); + return v; + }) + .def_static("from_dict", [](const py::dict &words) { + std::shared_ptr v; + THROW_IF_ERROR(Vocab::BuildFromPyDict(words, &v)); + return v; + }); +} + +void bindGraphData(py::module *m) { + (void)py::class_>(*m, "Graph") + .def(py::init([](std::string dataset_file, int32_t num_workers) { + std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); + THROW_IF_ERROR(g_out->Init()); + return g_out; + })) + .def("get_all_nodes", + [](gnn::Graph &g, gnn::NodeType node_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); + return out; + }) + .def("get_all_edges", + [](gnn::Graph &g, gnn::EdgeType edge_type) { + std::shared_ptr out; + 
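+               // Same out-parameter convention as the other graph bindings: the C++
+               // call returns a Status and fills 'out'; THROW_IF_ERROR (defined above)
+               // re-raises a failed Status as a Python RuntimeError.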
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); + return out; + }) + .def("get_nodes_from_edges", + [](gnn::Graph &g, std::vector edge_list) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); + return out; + }) + .def("get_all_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); + return out; + }) + .def("get_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, + std::vector neighbor_types) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + return out; + }) + .def("get_neg_sampled_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, + gnn::NodeType neg_neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); + return out; + }) + .def("get_node_feature", + [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); + return out.getRow(); + }) + .def("get_edge_feature", + [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); + return out.getRow(); + }) + .def("graph_info", + [](gnn::Graph &g) { + py::dict out; + THROW_IF_ERROR(g.GraphInfo(&out)); + return out; + }) + .def("random_walk", [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + float step_home_param, float step_away_param, gnn::NodeIdType default_node) { + std::shared_ptr out; + THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); + return out; + }); +} + +// This is where we externalize the C logic as python modules +PYBIND11_MODULE(_c_dataengine, m) { + m.doc() = "pybind11 for _c_dataengine"; + (void)py::class_>(m, "DatasetOp"); + + (void)py::enum_(m, "OpName", py::arithmetic()) + .value("SHUFFLE", OpName::kShuffle) + .value("BATCH", OpName::kBatch) + .value("BUCKETBATCH", OpName::kBucketBatch) + .value("BARRIER", OpName::kBarrier) + .value("MINDRECORD", OpName::kMindrecord) + .value("CACHE", OpName::kCache) + .value("REPEAT", OpName::kRepeat) + .value("SKIP", OpName::kSkip) + .value("TAKE", OpName::kTake) + .value("ZIP", OpName::kZip) + .value("CONCAT", OpName::kConcat) + .value("MAP", OpName::kMap) + .value("FILTER", OpName::kFilter) + .value("DEVICEQUEUE", OpName::kDeviceQueue) + .value("GENERATOR", OpName::kGenerator) + .export_values() + .value("RENAME", OpName::kRename) + .value("TFREADER", OpName::kTfReader) + .value("PROJECT", OpName::kProject) + .value("IMAGEFOLDER", OpName::kImageFolder) + .value("MNIST", OpName::kMnist) + .value("MANIFEST", OpName::kManifest) + .value("VOC", OpName::kVoc) + .value("COCO", OpName::kCoco) + .value("CIFAR10", OpName::kCifar10) + .value("CIFAR100", OpName::kCifar100) + .value("RANDOMDATA", OpName::kRandomData) + .value("BUILDVOCAB", OpName::kBuildVocab) + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile) + .value("CLUE", OpName::kClue); + + (void)py::enum_(m, "JiebaMode", py::arithmetic()) + .value("DE_JIEBA_MIX", JiebaMode::kMix) + .value("DE_JIEBA_MP", JiebaMode::kMp) + .value("DE_JIEBA_HMM", JiebaMode::kHmm) + .export_values(); + +#ifdef ENABLE_ICU4C + (void)py::enum_(m, "NormalizeForm", py::arithmetic()) + .value("DE_NORMALIZE_NONE", 
NormalizeForm::kNone) + .value("DE_NORMALIZE_NFC", NormalizeForm::kNfc) + .value("DE_NORMALIZE_NFKC", NormalizeForm::kNfkc) + .value("DE_NORMALIZE_NFD", NormalizeForm::kNfd) + .value("DE_NORMALIZE_NFKD", NormalizeForm::kNfkd) + .export_values(); +#endif + + (void)py::enum_(m, "InterpolationMode", py::arithmetic()) + .value("DE_INTER_LINEAR", InterpolationMode::kLinear) + .value("DE_INTER_CUBIC", InterpolationMode::kCubic) + .value("DE_INTER_AREA", InterpolationMode::kArea) + .value("DE_INTER_NEAREST_NEIGHBOUR", InterpolationMode::kNearestNeighbour) + .export_values(); + + (void)py::enum_(m, "BorderType", py::arithmetic()) + .value("DE_BORDER_CONSTANT", BorderType::kConstant) + .value("DE_BORDER_EDGE", BorderType::kEdge) + .value("DE_BORDER_REFLECT", BorderType::kReflect) + .value("DE_BORDER_SYMMETRIC", BorderType::kSymmetric) + .export_values(); + bindDEPipeline(&m); + bindTensor(&m); + bindTensorOps1(&m); + bindTensorOps2(&m); + bindTensorOps3(&m); + bindTensorOps4(&m); + bindTokenizerOps(&m); + bindSamplerOps(&m); + bindDatasetOps(&m); + bindInfoObjects(&m); + bindCacheClient(&m); + bindVocabObjects(&m); + bindGraphData(&m); + bindDependIcuTokenizerOps(&m); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/samplers.cc b/mindspore/ccsrc/minddata/dataset/api/samplers.cc new file mode 100644 index 0000000000..91421f0ff8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/samplers.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/samplers.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" + +namespace mindspore { +namespace dataset { +namespace api { + +SamplerObj::SamplerObj() {} + +/// Function to create a Distributed Sampler. +std::shared_ptr DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle, + int64_t num_samples, uint32_t seed) { + auto sampler = std::make_shared(num_shards, shard_id, shuffle, num_samples, seed); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a PK Sampler. +std::shared_ptr PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) { + auto sampler = std::make_shared(num_val, shuffle, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Random Sampler. 
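+/// A purely illustrative caller-side sketch (types as declared in samplers.h): each
+/// factory below validates its arguments and returns nullptr on failure, so callers
+/// should check the result:
+///   std::shared_ptr<SamplerObj> sampler = RandomSampler(false, 10);
+///   if (sampler == nullptr) { /* arguments rejected by ValidateParams() */ }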
+std::shared_ptr RandomSampler(bool replacement, int64_t num_samples) { + auto sampler = std::make_shared(replacement, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Sequential Sampler. +std::shared_ptr SequentialSampler(int64_t start_index, int64_t num_samples) { + auto sampler = std::make_shared(start_index, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Subset Random Sampler. +std::shared_ptr SubsetRandomSampler(const std::vector &indices, int64_t num_samples) { + auto sampler = std::make_shared(indices, num_samples); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/// Function to create a Weighted Random Sampler. +std::shared_ptr WeightedRandomSampler(const std::vector &weights, int64_t num_samples, + bool replacement) { + auto sampler = std::make_shared(weights, num_samples, replacement); + // Input validation + if (!sampler->ValidateParams()) { + return nullptr; + } + return sampler; +} + +/* ####################################### Derived Sampler classes ################################# */ + +// DistributedSampler +DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples, + uint32_t seed) + : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {} + +bool DistributedSamplerObj::ValidateParams() { + if (num_shards_ <= 0) { + MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_; + return false; + } + + if (shard_id_ < 0 || shard_id_ >= num_shards_) { + MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_; + return false; + } + + if (num_samples_ < 0) { + MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_; + return false; + } + + return true; +} + +std::shared_ptr DistributedSamplerObj::Build() { + return std::make_shared(num_samples_, num_shards_, shard_id_, shuffle_, seed_); +} + +// PKSampler +PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples) + : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {} + +bool PKSamplerObj::ValidateParams() { + if (num_val_ <= 0) { + MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_; + return false; + } + + if (num_samples_ < 0) { + MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr PKSamplerObj::Build() { + return std::make_shared(num_samples_, num_val_, shuffle_); +} + +// RandomSampler +RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples) + : replacement_(replacement), num_samples_(num_samples) {} + +bool RandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr RandomSamplerObj::Build() { + bool reshuffle_each_epoch = true; + auto sampler = std::make_shared(num_samples_, replacement_, reshuffle_each_epoch); + return sampler; +} + +// SequentialSampler +SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples) + : start_index_(start_index), num_samples_(num_samples) {} + +bool SequentialSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "SequentialSampler: 
invalid num_samples: " << num_samples_; + return false; + } + + if (start_index_ < 0) { + MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_; + return false; + } + + return true; +} + +std::shared_ptr SequentialSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, start_index_); + return sampler; +} + +// SubsetRandomSampler +SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector &indices, int64_t num_samples) + : indices_(indices), num_samples_(num_samples) {} + +bool SubsetRandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_; + return false; + } + + return true; +} + +std::shared_ptr SubsetRandomSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, indices_); + return sampler; +} + +// WeightedRandomSampler +WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector &weights, int64_t num_samples, + bool replacement) + : weights_(weights), num_samples_(num_samples), replacement_(replacement) {} + +bool WeightedRandomSamplerObj::ValidateParams() { + if (num_samples_ < 0) { + MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_; + return false; + } + return true; +} + +std::shared_ptr WeightedRandomSamplerObj::Build() { + auto sampler = std::make_shared(num_samples_, weights_, replacement_); + return sampler; +} + +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc new file mode 100644 index 0000000000..59a25ef9f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -0,0 +1,491 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/random_crop_op.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/kernels/image/uniform_aug_op.h" +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/random_color_adjust_op.h" +#include "minddata/dataset/kernels/image/pad_op.h" + +namespace mindspore { +namespace dataset { +namespace api { + +TensorOperation::TensorOperation() {} + +// Transform operations for computer vision. +namespace vision { + +// Function to create NormalizeOperation. 
+std::shared_ptr Normalize(std::vector mean, std::vector std) { + auto op = std::make_shared(mean, std); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create DecodeOperation. +std::shared_ptr Decode(bool rgb) { + auto op = std::make_shared(rgb); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create ResizeOperation. +std::shared_ptr Resize(std::vector size, InterpolationMode interpolation) { + auto op = std::make_shared(size, interpolation); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomCropOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding, + bool pad_if_needed, std::vector fill_value) { + auto op = std::make_shared(size, padding, pad_if_needed, fill_value); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create CenterCropOperation. +std::shared_ptr CenterCrop(std::vector size) { + auto op = std::make_shared(size); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create UniformAugOperation. +std::shared_ptr UniformAugment(std::vector> operations, + int32_t num_ops) { + auto op = std::make_shared(operations, num_ops); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomHorizontalFlipOperation. +std::shared_ptr RandomHorizontalFlip(float prob) { + auto op = std::make_shared(prob); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomVerticalFlipOperation. +std::shared_ptr RandomVerticalFlip(float prob) { + auto op = std::make_shared(prob); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomRotationOperation. +std::shared_ptr RandomRotation(std::vector degrees, InterpolationMode resample, + bool expand, std::vector center, + std::vector fill_value) { + auto op = std::make_shared(degrees, resample, expand, center, fill_value); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create PadOperation. +std::shared_ptr Pad(std::vector padding, std::vector fill_value, + BorderType padding_mode) { + auto op = std::make_shared(padding, fill_value, padding_mode); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create CutOutOp. +std::shared_ptr CutOut(int32_t length, int32_t num_patches) { + auto op = std::make_shared(length, num_patches); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + +// Function to create RandomColorAdjustOperation. 
+std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness,
+                                                              std::vector<float> contrast,
+                                                              std::vector<float> saturation, std::vector<float> hue) {
+  auto op = std::make_shared<RandomColorAdjustOperation>(brightness, contrast, saturation, hue);
+  // Input validation
+  if (!op->ValidateParams()) {
+    return nullptr;
+  }
+  return op;
+}
+
+/* ####################################### Derived TensorOperation classes ################################# */
+
+// NormalizeOperation
+NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
+
+bool NormalizeOperation::ValidateParams() {
+  if (mean_.size() != 3) {
+    MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
+    return false;
+  }
+
+  if (std_.size() != 3) {
+    MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
+    return false;
+  }
+
+  return true;
+}
+
+std::shared_ptr<TensorOp> NormalizeOperation::Build() {
+  return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
+}
+
+// DecodeOperation
+DecodeOperation::DecodeOperation(bool rgb) : rgb_(rgb) {}
+
+bool DecodeOperation::ValidateParams() { return true; }
+
+std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
+
+// ResizeOperation
+ResizeOperation::ResizeOperation(std::vector<int32_t> size, InterpolationMode interpolation)
+    : size_(size), interpolation_(interpolation) {}
+
+bool ResizeOperation::ValidateParams() {
+  if (size_.empty() || size_.size() > 2) {
+    MS_LOG(ERROR) << "Resize: size vector has incorrect size: " << size_.size();
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<TensorOp> ResizeOperation::Build() {
+  int32_t height = size_[0];
+  int32_t width = 0;
+
+  // User specified the width value.
+  if (size_.size() == 2) {
+    width = size_[1];
+  }
+
+  return std::make_shared<ResizeOp>(height, width, interpolation_);
+}
+
+// RandomCropOperation
+RandomCropOperation::RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding, bool pad_if_needed,
+                                         std::vector<uint8_t> fill_value)
+    : size_(size), padding_(padding), pad_if_needed_(pad_if_needed), fill_value_(fill_value) {}
+
+bool RandomCropOperation::ValidateParams() {
+  if (size_.empty() || size_.size() > 2) {
+    MS_LOG(ERROR) << "RandomCrop: size vector has incorrect size: " << size_.size();
+    return false;
+  }
+
+  if (padding_.empty() || padding_.size() != 4) {
+    MS_LOG(ERROR) << "RandomCrop: padding vector has incorrect size: " << padding_.size();
+    return false;
+  }
+
+  if (fill_value_.empty() || fill_value_.size() != 3) {
+    MS_LOG(ERROR) << "RandomCrop: fill_value vector has incorrect size: " << fill_value_.size();
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<TensorOp> RandomCropOperation::Build() {
+  int32_t crop_height = size_[0];
+  int32_t crop_width = 0;
+
+  int32_t pad_top = padding_[0];
+  int32_t pad_bottom = padding_[1];
+  int32_t pad_left = padding_[2];
+  int32_t pad_right = padding_[3];
+
+  uint8_t fill_r = fill_value_[0];
+  uint8_t fill_g = fill_value_[1];
+  uint8_t fill_b = fill_value_[2];
+
+  // User has specified the crop_width value.
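+  // (When only a single size value was given, crop_width is left at 0 here; the
+  // underlying RandomCropOp is then presumed to fall back to a square crop. This
+  // reading of the zero-width convention is an assumption, not stated in this file.)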
+ if (size_.size() == 2) { + crop_width = size_[1]; + } + + auto tensor_op = std::make_shared(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, + BorderType::kConstant, pad_if_needed_, fill_r, fill_g, fill_b); + return tensor_op; +} + +// CenterCropOperation +CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} + +bool CenterCropOperation::ValidateParams() { + if (size_.empty() || size_.size() > 2) { + MS_LOG(ERROR) << "CenterCrop: size vector has incorrect size."; + return false; + } + return true; +} + +std::shared_ptr CenterCropOperation::Build() { + int32_t crop_height = size_[0]; + int32_t crop_width = 0; + + // User has specified crop_width. + if (size_.size() == 2) { + crop_width = size_[1]; + } + + std::shared_ptr tensor_op = std::make_shared(crop_height, crop_width); + return tensor_op; +} + +// UniformAugOperation +UniformAugOperation::UniformAugOperation(std::vector> operations, int32_t num_ops) + : operations_(operations), num_ops_(num_ops) {} + +bool UniformAugOperation::ValidateParams() { return true; } + +std::shared_ptr UniformAugOperation::Build() { + std::vector> tensor_ops; + (void)std::transform(operations_.begin(), operations_.end(), std::back_inserter(tensor_ops), + [](std::shared_ptr op) -> std::shared_ptr { return op->Build(); }); + std::shared_ptr tensor_op = std::make_shared(tensor_ops, num_ops_); + return tensor_op; +} + +// RandomHorizontalFlipOperation +RandomHorizontalFlipOperation::RandomHorizontalFlipOperation(float probability) : probability_(probability) {} + +bool RandomHorizontalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomHorizontalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// RandomVerticalFlipOperation +RandomVerticalFlipOperation::RandomVerticalFlipOperation(float probability) : probability_(probability) {} + +bool RandomVerticalFlipOperation::ValidateParams() { return true; } + +std::shared_ptr RandomVerticalFlipOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(probability_); + return tensor_op; +} + +// Function to create RandomRotationOperation. 
+RandomRotationOperation::RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode,
+                                                 bool expand, std::vector<float> center,
+                                                 std::vector<uint8_t> fill_value)
+    : degrees_(degrees),
+      interpolation_mode_(interpolation_mode),
+      expand_(expand),
+      center_(center),
+      fill_value_(fill_value) {}
+
+bool RandomRotationOperation::ValidateParams() {
+  if (degrees_.empty() || degrees_.size() != 2) {
+    MS_LOG(ERROR) << "RandomRotation: degrees vector has incorrect size: " << degrees_.size();
+    return false;
+  }
+  if (center_.empty() || center_.size() != 2) {
+    MS_LOG(ERROR) << "RandomRotation: center vector has incorrect size: " << center_.size();
+    return false;
+  }
+  if (fill_value_.empty() || fill_value_.size() != 3) {
+    MS_LOG(ERROR) << "RandomRotation: fill_value vector has incorrect size: " << fill_value_.size();
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<TensorOp> RandomRotationOperation::Build() {
+  std::shared_ptr<RandomRotationOp> tensor_op =
+    std::make_shared<RandomRotationOp>(degrees_[0], degrees_[1], center_[0], center_[1], interpolation_mode_, expand_,
+                                       fill_value_[0], fill_value_[1], fill_value_[2]);
+  return tensor_op;
+}
+
+// PadOperation
+PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
+    : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
+
+bool PadOperation::ValidateParams() {
+  if (padding_.empty() || padding_.size() == 3 || padding_.size() > 4) {
+    MS_LOG(ERROR) << "Pad: padding vector has incorrect size: " << padding_.size();
+    return false;
+  }
+
+  if (fill_value_.empty() || (fill_value_.size() != 1 && fill_value_.size() != 3)) {
+    MS_LOG(ERROR) << "Pad: fill_value vector has incorrect size: " << fill_value_.size();
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<TensorOp> PadOperation::Build() {
+  int32_t pad_top, pad_bottom, pad_left, pad_right;
+  switch (padding_.size()) {
+    case 1:
+      pad_left = padding_[0];
+      pad_top = padding_[0];
+      pad_right = padding_[0];
+      pad_bottom = padding_[0];
+      break;
+    case 2:
+      pad_left = padding_[0];
+      pad_top = padding_[1];
+      pad_right = padding_[0];
+      pad_bottom = padding_[1];
+      break;
+    default:
+      pad_left = padding_[0];
+      pad_top = padding_[1];
+      pad_right = padding_[2];
+      pad_bottom = padding_[3];
+  }
+  uint8_t fill_r, fill_g, fill_b;
+
+  fill_r = fill_value_[0];
+  fill_g = fill_value_[0];
+  fill_b = fill_value_[0];
+
+  if (fill_value_.size() == 3) {
+    fill_r = fill_value_[0];
+    fill_g = fill_value_[1];
+    fill_b = fill_value_[2];
+  }
+
+  std::shared_ptr<PadOp> tensor_op =
+    std::make_shared<PadOp>(pad_top, pad_bottom, pad_left, pad_right, padding_mode_, fill_r, fill_g, fill_b);
+  return tensor_op;
+}
+
+// CutOutOperation
+CutOutOperation::CutOutOperation(int32_t length, int32_t num_patches) : length_(length), num_patches_(num_patches) {}
+
+bool CutOutOperation::ValidateParams() {
+  if (length_ < 0) {
+    MS_LOG(ERROR) << "CutOut: length cannot be negative";
+    return false;
+  }
+  if (num_patches_ < 0) {
+    MS_LOG(ERROR) << "CutOut: number of patches cannot be negative";
+    return false;
+  }
+  return true;
+}
+
+std::shared_ptr<TensorOp> CutOutOperation::Build() {
+  std::shared_ptr<CutOutOp> tensor_op = std::make_shared<CutOutOp>(length_, length_, num_patches_, false, 0, 0, 0);
+  return tensor_op;
+}
+
+// RandomColorAdjustOperation
+RandomColorAdjustOperation::RandomColorAdjustOperation(std::vector<float> brightness, std::vector<float> contrast,
+                                                       std::vector<float> saturation, std::vector<float> hue)
+    : brightness_(brightness), contrast_(contrast), saturation_(saturation), hue_(hue) {}
+
+bool RandomColorAdjustOperation::ValidateParams() {
+  // Do some input validation.
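+  // Each factor vector may hold one value (used as both lower and upper bound, as
+  // Build() below shows) or a {min, max} pair; e.g. a purely illustrative call:
+  //   RandomColorAdjust({0.8, 1.2}, {1.0}, {1.0}, {0.0});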
+ if (brightness_.empty() || brightness_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: brightness must be a vector of one or two values"; + return false; + } + if (contrast_.empty() || contrast_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: contrast must be a vector of one or two values"; + return false; + } + if (saturation_.empty() || saturation_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: saturation must be a vector of one or two values"; + return false; + } + if (hue_.empty() || hue_.size() > 2) { + MS_LOG(ERROR) << "RandomColorAdjust: hue must be a vector of one or two values"; + return false; + } + return true; +} + +std::shared_ptr RandomColorAdjustOperation::Build() { + float brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub; + + brightness_lb = brightness_[0]; + brightness_ub = brightness_[0]; + + if (brightness_.size() == 2) brightness_ub = brightness_[1]; + + contrast_lb = contrast_[0]; + contrast_ub = contrast_[0]; + + if (contrast_.size() == 2) contrast_ub = contrast_[1]; + + saturation_lb = saturation_[0]; + saturation_ub = saturation_[0]; + + if (saturation_.size() == 2) saturation_ub = saturation_[1]; + + hue_lb = hue_[0]; + hue_ub = hue_[0]; + + if (hue_.size() == 2) hue_ub = hue_[1]; + + std::shared_ptr tensor_op = std::make_shared( + brightness_lb, brightness_ub, contrast_lb, contrast_ub, saturation_lb, saturation_ub, hue_lb, hue_ub); + return tensor_op; +} + +} // namespace vision +} // namespace api +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/core/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/core/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/core/client.cc b/mindspore/ccsrc/minddata/dataset/core/client.cc new file mode 100644 index 0000000000..e3fd844e66 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/client.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/sig_handler.h" + +namespace mindspore { +namespace dataset { +// This is a one-time global initializer which includes the call to instantiate singletons. +// It is external api call and not a member of the GlobalContext directly. 
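+// A minimal caller-side sketch (illustrative only): invoke GlobalInit() once before
+// building any pipeline, and inspect the returned Status:
+//   Status rc = mindspore::dataset::GlobalInit();
+//   if (rc.IsError()) { /* services (logger, task, bufferpool) failed to start */ }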
+Status GlobalInit() { + // Bring up all the services (logger, task, bufferpool) + return (Services::CreateInstance()); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/client.h b/mindspore/ccsrc/minddata/dataset/core/client.h new file mode 100644 index 0000000000..78b298e616 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/client.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_CLIENT_H_ +#define DATASET_CORE_CLIENT_H_ + +// client.h +// Include file for DE client functions + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" + +#ifdef ENABLE_PYTHON +#include "minddata/dataset/engine/datasetops/barrier_op.h" +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include "minddata/dataset/engine/datasetops/build_vocab_op.h" +#endif + +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/datasetops/map_op.h" +#include "minddata/dataset/engine/datasetops/project_op.h" +#include "minddata/dataset/engine/datasetops/rename_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/skip_op.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/take_op.h" +#include "minddata/dataset/engine/datasetops/zip_op.h" +#include "minddata/dataset/engine/datasetops/concat_op.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// This is a one-time global initializer that needs to be called at the +// start of any minddata applications. +extern Status GlobalInit(); +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc new file mode 100644 index 0000000000..e1fc7f29ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/core/config_manager.h"
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include "minddata/dataset/util/system_pool.h"
+
+namespace mindspore {
+namespace dataset {
+// A print method typically used for debugging
+void ConfigManager::Print(std::ostream &out) const {
+  // Don't show the test/internal ones. Only display the main ones here.
+  // fyi, boolalpha tells the output stream to write "true" and "false" for bools
+  out << "\nClient config settings :"
+      << "\nDataCache Rows per buffer : " << rows_per_buffer_
+      << "\nParallelOp workers : " << num_parallel_workers_
+      << "\nParallelOp worker connector size : " << worker_connector_size_
+      << "\nSize of each Connector : " << op_connector_size_ << std::endl;
+}
+
+// Private helper function that takes a nlohmann json object and populates the settings
+Status ConfigManager::FromJson(const nlohmann::json &j) {
+  set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_));
+  set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_));
+  set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_));
+  set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
+  set_seed(j.value("seed", seed_));
+  set_monitor_sampling_interval(j.value("monitorSamplingInterval", monitor_sampling_interval_));
+  return Status::OK();
+}
+
+// Loads a json file with the default settings and populates all the settings
+Status ConfigManager::LoadFile(const std::string &settingsFile) {
+  Status rc;
+  if (!Path(settingsFile).Exists()) {
+    RETURN_STATUS_UNEXPECTED("File is not found.");
+  }
+  // Some settings are mandatory, others are not (with default). If a setting
+  // is optional it will set a default value if the config is missing from the file.
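+  // An illustrative settings file for the keys read by FromJson() above (all keys
+  // are optional; the current value is kept when a key is absent; values here are
+  // examples only):
+  //   {
+  //     "rowsPerBuffer": 32,
+  //     "numParallelWorkers": 4,
+  //     "workerConnectorSize": 16,
+  //     "opConnectorSize": 16,
+  //     "seed": 5489,
+  //     "monitorSamplingInterval": 10
+  //   }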
+  try {
+    std::ifstream in(settingsFile);
+    nlohmann::json js;
+    in >> js;
+    rc = FromJson(js);
+  } catch (const nlohmann::json::type_error &e) {
+    std::ostringstream ss;
+    ss << "Client file failed to load:\n" << e.what();
+    std::string err_msg = ss.str();
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } catch (const std::exception &err) {
+    RETURN_STATUS_UNEXPECTED("Client file failed to load.");
+  }
+  return rc;
+}
+
+// Setter function
+void ConfigManager::set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; }
+
+// Setter function
+void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) {
+  num_parallel_workers_ = num_parallel_workers;
+}
+
+// Setter function
+void ConfigManager::set_worker_connector_size(int32_t connector_size) { worker_connector_size_ = connector_size; }
+
+// Setter function
+void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector_size_ = connector_size; }
+
+uint32_t ConfigManager::seed() const { return seed_; }
+
+void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
+
+void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.h b/mindspore/ccsrc/minddata/dataset/core/config_manager.h
new file mode 100644
index 0000000000..a8e1907c41
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.h
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_CORE_CONFIG_MANAGER_H_
+#define DATASET_CORE_CONFIG_MANAGER_H_
+
+#include <iostream>
+#include <sstream>
+#include <string>
+
+#include <nlohmann/json.hpp>
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
+
+// Config settings for the client-side
+// example config file:
+// {
+//   "rowsPerBuffer": 3
+// }
+//
+
+namespace mindspore {
+namespace dataset {
+// The ConfigManager is a class for managing default values. When a user is constructing any objects
+// in the framework, often they may choose to omit some settings instead of overriding them.
+// This class manages some of the default values, for cases when the user does not manually specify
+// those values.
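// Usage sketch (not part of this patch): ConfigManager, declared below, reads its settings
// from a JSON file via LoadFile()/FromJson(). The file name here is hypothetical; any key
// omitted from the file keeps its compiled-in default (kCfgRowsPerBuffer and friends).
//
// de_config.json:
//   {
//     "rowsPerBuffer": 3,
//     "numParallelWorkers": 4,
//     "workerConnectorSize": 16,
//     "opConnectorSize": 16,
//     "seed": 5489,
//     "monitorSamplingInterval": 10
//   }
#include <iostream>
#include "minddata/dataset/core/config_manager.h"

void DemoConfig() {  // hypothetical helper, for illustration only
  mindspore::dataset::ConfigManager cfg;
  mindspore::dataset::Status rc = cfg.LoadFile("de_config.json");
  if (rc.IsOk()) {
    std::cout << cfg;  // operator<< routes to Print(): workers, buffer rows, queue sizes
  }
}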
+class ConfigManager {
+ public:
+  ConfigManager() = default;
+
+  // destructor
+  ~ConfigManager() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  void Print(std::ostream &out) const;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param cS - reference to the ConfigManager to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const ConfigManager &cS) {
+    cS.Print(out);
+    return out;
+  }
+
+  // Another debug print helper. Converts the print info to a string for you.
+  // @return The string version of the debug print
+  std::string ToString() {
+    std::stringstream ss;
+    ss << *this;
+    return ss.str();
+  }
+
+  // Loads a json file with the default settings and populates all the settings
+  // @param settingsFile - A json file with a set of default settings
+  // @return Status error code
+  Status LoadFile(const std::string &settingsFile);
+
+  // getter function
+  // @return The rows per buffer setting
+  int32_t rows_per_buffer() const { return rows_per_buffer_; }
+
+  // getter function
+  // @return The number of workers setting
+  int32_t num_parallel_workers() const { return num_parallel_workers_; }
+
+  // getter function
+  // @return The queue size of the operator's output connector
+  int32_t op_connector_size() const { return op_connector_size_; }
+
+  // getter function
+  // @return The internal worker-to-master connector queue size
+  int32_t worker_connector_size() const { return worker_connector_size_; }
+
+  // setter function
+  // @param rows_per_buffer - The setting to apply to the config
+  void set_rows_per_buffer(int32_t rows_per_buffer);
+
+  // setter function
+  // @param num_parallel_workers - The setting to apply to the config
+  void set_num_parallel_workers(int32_t num_parallel_workers);
+
+  // setter function
+  // @param connector_size - The setting to apply to the config
+  void set_worker_connector_size(int32_t connector_size);
+
+  // setter function
+  // @param connector_size - The setting to apply to the config
+  void set_op_connector_size(int32_t connector_size);
+
+  uint32_t seed() const;
+
+  // setter function
+  // @param seed - The default seed to use
+  void set_seed(uint32_t seed);
+
+  // setter function
+  // @param interval - The setting to apply to the config
+  void set_monitor_sampling_interval(uint32_t interval);
+
+  // getter function
+  // @return The interval of monitor sampling
+  int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
+
+ private:
+  int32_t rows_per_buffer_{kCfgRowsPerBuffer};
+  int32_t num_parallel_workers_{kCfgParallelWorkers};
+  int32_t worker_connector_size_{kCfgWorkerConnectorSize};
+  int32_t op_connector_size_{kCfgOpConnectorSize};
+  uint32_t seed_{kCfgDefaultSeed};
+  uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};
+
+  // Private helper function that takes a nlohmann json object and populates the settings
+  // @param j - The nlohmann json info
+  Status FromJson(const nlohmann::json &j);
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_CORE_CONFIG_MANAGER_H_
diff --git a/mindspore/ccsrc/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h
similarity index 100%
rename from mindspore/ccsrc/dataset/core/constants.h
rename to mindspore/ccsrc/minddata/dataset/core/constants.h
diff
--git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc new file mode 100644 index 0000000000..5af748b5de --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/cv_tensor.h" + +#include +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { +CVTensor::CVTensor(const TensorShape &shape, const DataType &type) : Tensor(shape, type) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +CVTensor::CVTensor(const TensorShape &shape, const DataType &type, const uchar *data) : Tensor(shape, type, data) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +CVTensor::CVTensor(std::shared_ptr tensor) : Tensor(std::move(*tensor)) { + (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_); +} + +std::pair, int> CVTensor::IsValidImage(const TensorShape &shape, const DataType &type) { + std::array size = {1, 1}; + if (shape.Rank() <= 2 || (shape.Rank() == 3 && shape[2] <= CV_CN_MAX)) { + uint8_t ch = 1; + if (shape.Rank() == 3) { + ch = static_cast(shape[2]); + } + if (shape.Rank() > 0) size[0] = static_cast(shape[0]); + if (shape.Rank() > 1) size[1] = static_cast(shape[1]); + if (type.AsCVType() == kCVInvalidType) return std::make_pair(size, -1); + + int cv_type = CV_MAKETYPE(type.AsCVType(), ch); + return std::make_pair(size, cv_type); + } + return std::make_pair(size, -1); +} + +std::shared_ptr CVTensor::AsCVTensor(std::shared_ptr t) { + std::shared_ptr cv_t = std::dynamic_pointer_cast(t); + if (cv_t != nullptr) { + return cv_t; + } else { + return std::make_shared(t); + } +} + +Status CVTensor::MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat) { + std::pair, int> cv_shape_type = IsValidImage(shape, type); + if (cv_shape_type.second == -1) { + std::vector sizes = shape.AsVector(); + std::vector sizes32(sizes.begin(), sizes.end()); // convert long to int for usage with OpenCV + if (static_cast(shape.Rank()) != shape.Rank()) { + RETURN_STATUS_UNEXPECTED("Error in creating CV mat. Wrong shape."); + } + + uint8_t cv_type = type.AsCVType(); + if (cv_type == kCVInvalidType) { + RETURN_STATUS_UNEXPECTED("Error in creating CV mat. 
Invalid type.");
+    }
+    *mat = cv::Mat(static_cast<int>(shape.Rank()), &sizes32[0], cv_type, data);
+  } else {
+    *mat = cv::Mat(2, &(cv_shape_type.first[0]), cv_shape_type.second, data);
+  }
+  return Status::OK();
+}
+
+Status CVTensor::Reshape(const TensorShape &shape) {
+  RETURN_IF_NOT_OK(Tensor::Reshape(shape));
+  RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_));
+  return Status::OK();
+}
+
+Status CVTensor::ExpandDim(const dsize_t &axis) {
+  RETURN_IF_NOT_OK(Tensor::ExpandDim(axis));
+  RETURN_IF_NOT_OK(this->MatInit(GetMutableBuffer(), shape_, type_, &mat_));
+  return Status::OK();
+}
+
+void CVTensor::Squeeze() {
+  Tensor::Squeeze();
+  (void)this->MatInit(GetMutableBuffer(), shape_, type_, &mat_);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h
new file mode 100644
index 0000000000..a614418be6
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/cv_tensor.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_CORE_CV_TENSOR_H_
+#define DATASET_CORE_CV_TENSOR_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include <opencv2/core/mat.hpp>
+
+#include "./securec.h"
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor.h"
+
+namespace mindspore {
+namespace dataset {
+class CVTensor : public Tensor {
+ public:
+  // Create an empty CVTensor of shape `shape` and type `type`.
+  // @note The shape and type information should be known and valid.
+  // @param shape TensorShape
+  // @param type DataType
+  CVTensor(const TensorShape &shape, const DataType &type);
+
+  // Create a CVTensor from a given buffer, shape and type.
+  // @note This constructor allocates a new space in the memory and copies the buffer into it.
+  // @note The buffer should be valid and the shape and type information should be known and valid.
+  // @param shape TensorShape
+  // @param type DataType
+  // @param data unsigned char*, pointer to the data.
+  CVTensor(const TensorShape &shape, const DataType &type, const uchar *data);
+
+  // Create a CVTensor from a given CV::Mat.
+  // @note This constructor allocates a new space in the memory and copies the CV::Mat buffer into it.
+  // @param mat CV::Mat
+  explicit CVTensor(const cv::Mat &mat)
+      : CVTensor(TensorShape(mat.size, mat.type()), DataType::FromCVType(mat.type()), mat.data) {}
+
+  ~CVTensor() = default;
+
+  // Static function to cast a given Tensor as CVTensor. If the input tensor is already of type CVTensor,
+  // this function would be treated as a no-op. For other tensor types, a new CVTensor is created based on the data
+  // provided. The passed Tensor will be invalidated.
+  // @note There is no memory copying here; the buffer will be assigned to the constructed tensor.
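// Worked example (sketch, not part of this patch) of the shape/type mapping implemented by
// IsValidImage()/MatInit() above: a rank-3 HWC tensor of shape {512, 512, 3} with DE_UINT8
// maps onto a 512x512 cv::Mat of type CV_8UC3 (CV_MAKETYPE(CV_8U, 3)), while a tensor of
// rank > 3 falls through to the generic n-dimensional cv::Mat constructor instead.
#include "minddata/dataset/core/cv_tensor.h"

void DemoCVTensor() {  // hypothetical helper, for illustration only
  using namespace mindspore::dataset;
  CVTensor img(TensorShape({512, 512, 3}), DataType(DataType::DE_UINT8));  // MatInit runs in the ctor
  cv::Mat m = img.mat();  // m.type() == CV_8UC3, m.rows == 512, m.cols == 512
}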
+ // @param tensor + // @return CVTensor + static std::shared_ptr AsCVTensor(std::shared_ptr tensor); + + // Create a CVTensor from a given tensor. The input tensor will be invalidated (i.e., the shape and type will be + // set to unknown and the data buffer will point to null. + // @note there is no memory copying here, the buffer will be assigned to the constructed tensor. + // @param tensor + explicit CVTensor(std::shared_ptr tensor); + + // Getter function for the CV::Mat + // @return + cv::Mat mat() const { return mat_; } + + // Static function to check if the passed information (shape and type) can be treated as a valid description + // of an image in OpenCV. Moreover, it returns OpenCV shape and type + // For example, if the shape is <512,512,3> and type is DE_UINT8, the output would be [512,512] and CV_8UC3. + // In case of invalid shape or type, the function will return pair + // @param shape TensorShape + // @param type DataType + // @return std::pair of OpenCV shape and type + std::pair, int> IsValidImage(const TensorShape &shape, const DataType &type); + + Status Reshape(const TensorShape &shape) override; + + Status ExpandDim(const dsize_t &axis) override; + + void Squeeze() override; + + Status Mat(const std::vector &index, cv::Mat *mat) { + uchar *start = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(this->StartAddrOfIndex(index, &start, &remaining)); + RETURN_IF_NOT_OK(this->MatInit(start, remaining, type_, mat)); + return Status::OK(); + } + + private: + cv::Mat mat_; + + // Initialize CV::Mat with the data_, shape_ and type_ + Status MatInit(uchar *data, const TensorShape &shape, const DataType &type, cv::Mat *mat); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_CV_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.cc b/mindspore/ccsrc/minddata/dataset/core/data_type.cc new file mode 100644 index 0000000000..b5641e3105 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/core/data_type.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +#endif + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { + +uint8_t DataType::SizeInBytes() const { + if (type_ < DataType::NUM_OF_TYPES) + return kTypeInfo[type_].sizeInBytes_; + else + return 0; +} + +#ifdef ENABLE_PYTHON +py::dtype DataType::AsNumpyType() const { + if (type_ < DataType::NUM_OF_TYPES) + return py::dtype(kTypeInfo[type_].pybindType_); + else + return py::dtype("unknown"); +} +#endif + +uint8_t DataType::AsCVType() const { + uint8_t res = kCVInvalidType; + if (type_ < DataType::NUM_OF_TYPES) { + res = kTypeInfo[type_].cvType_; + } + + if (res == kCVInvalidType) { + MS_LOG(ERROR) << "Cannot convert to OpenCV type. 
Return invalid type!"; + } + + return res; +} // namespace dataset + +DataType DataType::FromCVType(int cv_type) { + auto depth = static_cast(cv_type) & static_cast(CV_MAT_DEPTH_MASK); + switch (depth) { + case CV_8S: + return DataType(DataType::DE_INT8); + case CV_8U: + return DataType(DataType::DE_UINT8); + case CV_16S: + return DataType(DataType::DE_INT16); + case CV_16U: + return DataType(DataType::DE_UINT16); + case CV_32S: + return DataType(DataType::DE_INT32); + case CV_16F: + return DataType(DataType::DE_FLOAT16); + case CV_32F: + return DataType(DataType::DE_FLOAT32); + case CV_64F: + return DataType(DataType::DE_FLOAT64); + default: + MS_LOG(ERROR) << "Cannot convert from OpenCV type, unknown CV type. Unknown data type is returned!"; + return DataType(DataType::DE_UNKNOWN); + } +} + +DataType::DataType(const std::string &type_str) { + if (type_str == "bool") + type_ = DE_BOOL; + else if (type_str == "int8") + type_ = DE_INT8; + else if (type_str == "uint8") + type_ = DE_UINT8; + else if (type_str == "int16") + type_ = DE_INT16; + else if (type_str == "uint16") + type_ = DE_UINT16; + else if (type_str == "int32") + type_ = DE_INT32; + else if (type_str == "uint32") + type_ = DE_UINT32; + else if (type_str == "int64") + type_ = DE_INT64; + else if (type_str == "uint64") + type_ = DE_UINT64; + else if (type_str == "float16") + type_ = DE_FLOAT16; + else if (type_str == "float32") + type_ = DE_FLOAT32; + else if (type_str == "float64") + type_ = DE_FLOAT64; + else if (type_str == "string") + type_ = DE_STRING; + else + type_ = DE_UNKNOWN; +} + +std::string DataType::ToString() const { + if (type_ < DataType::NUM_OF_TYPES) + return kTypeInfo[type_].name_; + else + return "unknown"; +} + +#ifdef ENABLE_PYTHON +DataType DataType::FromNpArray(const py::array &arr) { + if (py::isinstance>(arr)) { + return DataType(DataType::DE_BOOL); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT8); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT8); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_INT64); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_UINT64); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT16); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT32); + } else if (py::isinstance>(arr)) { + return DataType(DataType::DE_FLOAT64); + } else if (arr.dtype().kind() == 'S' || arr.dtype().kind() == 'U') { + return DataType(DataType::DE_STRING); + } else { + MS_LOG(ERROR) << "Cannot convert from numpy type. 
Unknown data type is returned!"; + return DataType(DataType::DE_UNKNOWN); + } +} + +std::string DataType::GetPybindFormat() const { + std::string res; + if (type_ < DataType::NUM_OF_TYPES) { + res = kTypeInfo[type_].pybindFormatDescriptor_; + } + + if (res.empty()) { + MS_LOG(ERROR) << "Cannot convert from data type to pybind format descriptor!"; + } + return res; +} +#endif + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/core/data_type.h new file mode 100644 index 0000000000..db4834cae2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/data_type.h @@ -0,0 +1,350 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_DATA_TYPE_H_ +#define DATASET_CORE_DATA_TYPE_H_ + +#include + +#include +#ifdef ENABLE_PYTHON +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "minddata/dataset/core/pybind_support.h" +namespace py = pybind11; +#else +#include "Eigen/Core" +using float16 = Eigen::half; +#endif +#include "minddata/dataset/core/constants.h" +namespace mindspore { +namespace dataset { + +// Class that represents basic data types in DataEngine. 
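// Usage sketch (not part of this patch) for the class declared below: DataType round-trips
// between string names, byte sizes and OpenCV depth constants via its kTypeInfo table.
#include <opencv2/core/core.hpp>
#include "minddata/dataset/core/data_type.h"

void DemoDataType() {  // hypothetical helper, for illustration only
  using mindspore::dataset::DataType;
  DataType t("float32");  // string constructor defined in data_type.cc above
  // t.SizeInBytes() == 4, t.ToString() == "float32", t.AsCVType() == CV_32F
  DataType u = DataType::FromCVType(CV_8UC3);  // only the depth matters: CV_8U -> DE_UINT8
  (void)t;
  (void)u;
}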
+class DataType { + public: + enum Type : uint8_t { + DE_UNKNOWN = 0, + DE_BOOL, + DE_INT8, + DE_UINT8, + DE_INT16, + DE_UINT16, + DE_INT32, + DE_UINT32, + DE_INT64, + DE_UINT64, + DE_FLOAT16, + DE_FLOAT32, + DE_FLOAT64, + DE_STRING, + NUM_OF_TYPES + }; + + struct TypeInfo { + const char *name_; // name to be represent the type while printing + const uint8_t sizeInBytes_; // number of bytes needed for this type + const char *pybindType_; // Python matching type, used in get_output_types + const std::string pybindFormatDescriptor_; // pybind format used for numpy types + const uint8_t cvType_; // OpenCv matching type + }; + +#ifdef ENABLE_PYTHON + static inline const TypeInfo kTypeInfo[] = { + // name, sizeInBytes, pybindTypem formatDescriptor, openCV + {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN + {"bool", 1, "bool", py::format_descriptor::format(), CV_8U}, // DE_BOOL + {"int8", 1, "int8", py::format_descriptor::format(), CV_8S}, // DE_INT8 + {"uint8", 1, "uint8", py::format_descriptor::format(), CV_8U}, // DE_UINT8 + {"int16", 2, "int16", py::format_descriptor::format(), CV_16S}, // DE_INT16 + {"uint16", 2, "uint16", py::format_descriptor::format(), CV_16U}, // DE_UINT16 + {"int32", 4, "int32", py::format_descriptor::format(), CV_32S}, // DE_INT32 + {"uint32", 4, "uint32", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT32 + {"int64", 8, "int64", py::format_descriptor::format(), kCVInvalidType}, // DE_INT64 + {"uint64", 8, "uint64", py::format_descriptor::format(), kCVInvalidType}, // DE_UINT64 + {"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16 + {"float32", 4, "float32", py::format_descriptor::format(), CV_32F}, // DE_FLOAT32 + {"float64", 8, "double", py::format_descriptor::format(), CV_64F}, // DE_FLOAT64 + {"string", 0, "bytes", "S", kCVInvalidType} // DE_STRING + }; +#else + static inline const TypeInfo kTypeInfo[] = { + // name, sizeInBytes, pybindTypem formatDescriptor, openCV + {"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN + {"bool", 1, "bool", "", CV_8U}, // DE_BOOL + {"int8", 1, "int8", "", CV_8S}, // DE_INT8 + {"uint8", 1, "uint8", "", CV_8U}, // DE_UINT8 + {"int16", 2, "int16", "", CV_16S}, // DE_INT16 + {"uint16", 2, "uint16", "", CV_16U}, // DE_UINT16 + {"int32", 4, "int32", "", CV_32S}, // DE_INT32 + {"uint32", 4, "uint32", "", kCVInvalidType}, // DE_UINT32 + {"int64", 8, "int64", "", kCVInvalidType}, // DE_INT64 + {"uint64", 8, "uint64", "", kCVInvalidType}, // DE_UINT64 + {"float16", 2, "float16", "", CV_16F}, // DE_FLOAT16 + {"float32", 4, "float32", "", CV_32F}, // DE_FLOAT32 + {"float64", 8, "double", "", CV_64F}, // DE_FLOAT64 + {"string", 0, "bytes", "", kCVInvalidType} // DE_STRING + }; +#endif + + // No arg constructor to create an unknown shape + DataType() : type_(DE_UNKNOWN) {} + + // Create a type from a given string + /// \param type_str + explicit DataType(const std::string &type_str); + + // Default destructor + ~DataType() = default; + + // Create a type from a given enum + /// \param d + constexpr explicit DataType(Type d) : type_(d) {} + + constexpr bool operator==(const DataType a) const { return type_ == a.type_; } + + constexpr bool operator==(const Type a) const { return type_ == a; } + + constexpr bool operator!=(const DataType a) const { return type_ != a.type_; } + + constexpr bool operator!=(const Type a) const { return type_ != a; } + + // Disable this usage `if(d)` where d is of type DataType + /// \return + operator bool() = delete; + + // To be used in Switch/case + /// \return + operator Type() 
const { return type_; } + + // The number of bytes needed to store one value of this type + /// \return + uint8_t SizeInBytes() const; + + // Convert from DataType to OpenCV type + /// \return + uint8_t AsCVType() const; + + // Convert from OpenCV type to DataType + /// \param cv_type + /// \return + static DataType FromCVType(int cv_type); + + // Returns a string representation of the type + /// \return + std::string ToString() const; + + // returns true if the template type is the same as the Tensor type_ + /// \tparam T + /// \return true or false + template + bool IsCompatible() const { + return type_ == FromCType(); + } + + // returns true if the template type is the same as the Tensor type_ + /// \tparam T + /// \return true or false + template + bool IsLooselyCompatible() const; + + // << Stream output operator overload + /// \notes This allows you to print the info using stream operators + /// \param out - reference to the output stream being overloaded + /// \param rO - reference to the DataType to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DataType &so) { + out << so.ToString(); + return out; + } + + template + static DataType FromCType(); + +#ifdef ENABLE_PYTHON + // Convert from DataType to Pybind type + /// \return + py::dtype AsNumpyType() const; + + // Convert from NP type to DataType + /// \param type + /// \return + static DataType FromNpType(const py::dtype &type); + + // Convert from NP array to DataType + /// \param py array + /// \return + static DataType FromNpArray(const py::array &arr); +#endif + + // Get the buffer string format of the current type. Used in pybind buffer protocol. + /// \return + std::string GetPybindFormat() const; + + bool IsSignedInt() const { + return type_ == DataType::DE_INT8 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT32 || + type_ == DataType::DE_INT64; + } + + bool IsUnsignedInt() const { + return type_ == DataType::DE_UINT8 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT32 || + type_ == DataType::DE_UINT64; + } + + bool IsInt() const { return IsSignedInt() || IsUnsignedInt(); } + + bool IsFloat() const { + return type_ == DataType::DE_FLOAT16 || type_ == DataType::DE_FLOAT32 || type_ == DataType::DE_FLOAT64; + } + + bool IsBool() const { return type_ == DataType::DE_BOOL; } + + bool IsNumeric() const { return type_ != DataType::DE_STRING; } + + Type value() const { return type_; } + + private: + Type type_; +}; + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_BOOL); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_FLOAT16); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT64); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT32); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_INT16); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT16); +} + +template <> +inline DataType DataType::FromCType() { + 
return DataType(DataType::DE_INT8); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_UINT8); +} + +template <> +inline DataType DataType::FromCType() { + return DataType(DataType::DE_STRING); +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_BOOL; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT64 || type_ == DataType::DE_FLOAT32; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT32; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_FLOAT16; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT64 || type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || + type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT64 || type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || + type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT32 || type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT32 || type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT16 || type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT16 || type_ == DataType::DE_UINT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_INT8; +} + +template <> +inline bool DataType::IsLooselyCompatible() const { + return type_ == DataType::DE_UINT8; +} +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_DATA_TYPE_H_ diff --git a/mindspore/ccsrc/dataset/core/example.proto b/mindspore/ccsrc/minddata/dataset/core/example.proto similarity index 100% rename from mindspore/ccsrc/dataset/core/example.proto rename to mindspore/ccsrc/minddata/dataset/core/example.proto diff --git a/mindspore/ccsrc/dataset/core/feature.proto b/mindspore/ccsrc/minddata/dataset/core/feature.proto similarity index 100% rename from mindspore/ccsrc/dataset/core/feature.proto rename to mindspore/ccsrc/minddata/dataset/core/feature.proto diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.cc b/mindspore/ccsrc/minddata/dataset/core/global_context.cc new file mode 100644 index 0000000000..eb76382ab2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/global_context.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
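// Sketch (not part of this patch) of how the specializations above behave: FromCType<T>()
// names the one DE type that matches a C++ type exactly, while IsLooselyCompatible<T>()
// also accepts narrower types of the same signedness, so an int8 tensor can be read
// through a wider int32_t accessor.
#include <cstdint>
#include "minddata/dataset/core/data_type.h"

void DemoCompat() {  // hypothetical helper, for illustration only
  using mindspore::dataset::DataType;
  DataType t(DataType::DE_INT8);
  bool exact = t.IsCompatible<int32_t>();         // false: DE_INT8 != FromCType<int32_t>()
  bool loose = t.IsLooselyCompatible<int32_t>();  // true: int8 values always fit in an int32_t
  (void)exact;
  (void)loose;
}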
+ */ +#include "minddata/dataset/core/global_context.h" + +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +// Global static pointer for the singleton GlobalContext +std::unique_ptr GlobalContext::global_context_ = nullptr; +std::once_flag GlobalContext::init_instance_flag_; + +constexpr int GlobalContext::kArenaSize; +constexpr int GlobalContext::kMaxSize; +constexpr bool GlobalContext::kInitArena; + +// Singleton initializer +GlobalContext *GlobalContext::Instance() { + // If the single global context is not created yet, then create it. Otherwise the + // existing one is returned. + std::call_once(init_instance_flag_, []() { + global_context_.reset(new GlobalContext()); + Status rc = global_context_->Init(); + if (rc.IsError()) { + std::terminate(); + } + }); + return global_context_.get(); +} + +Status GlobalContext::Init() { + config_manager_ = std::make_shared(); + mem_pool_ = std::make_shared(); + // For testing we can use Dummy pool instead + + // Create some tensor allocators for the different types and hook them into the pool. + tensor_allocator_ = std::make_unique>(mem_pool_); + cv_tensor_allocator_ = std::make_unique>(mem_pool_); + int_allocator_ = std::make_unique(mem_pool_); + return Status::OK(); +} + +// A print method typically used for debugging +void GlobalContext::Print(std::ostream &out) const { + out << "GlobalContext contains the following default config: " << *config_manager_ << "\n"; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/global_context.h b/mindspore/ccsrc/minddata/dataset/core/global_context.h new file mode 100644 index 0000000000..fe0847f639 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/global_context.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_CORE_GLOBAL_CONTEXT_H_ +#define DATASET_CORE_GLOBAL_CONTEXT_H_ + +#include +#include + +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// forward declare +class MemoryPool; +class ConfigManager; +class Tensor; +class CVTensor; + +using TensorAlloc = Allocator; // An allocator for Tensors +using CVTensorAlloc = Allocator; // An allocator CVTensors +using IntAlloc = Allocator; + +class GlobalContext { + // some consts for pool config + static constexpr int kArenaSize = 128; + static constexpr int kMaxSize = -1; + static constexpr bool kInitArena = true; + + public: + // Singleton pattern. 
This method either: + // - creates the single version of the GlobalContext for the first time and returns it + // OR + // - returns the already existing single instance of the GlobalContext + // @return the single global context + static GlobalContext *Instance(); + + // Destructor + ~GlobalContext() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + void Print(std::ostream &out) const; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param g_c - reference to the GlobalContext to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const GlobalContext &g_c) { + g_c.Print(out); + return out; + } + + // Getter method + // @return the client config as raw const pointer + static std::shared_ptr config_manager() { return Instance()->config_manager_; } + + // Getter method + // @return the mem pool + std::shared_ptr mem_pool() const { return mem_pool_; } + + // Getter method + // @return the tensor allocator as raw pointer + const TensorAlloc *tensor_allocator() const { return tensor_allocator_.get(); } + + // Getter method + // @return the CVTensor allocator as raw pointer + const CVTensorAlloc *cv_tensor_allocator() const { return cv_tensor_allocator_.get(); } + + // Getter method + // @return the integer allocator as raw pointer + const IntAlloc *int_allocator() const { return int_allocator_.get(); } + + private: + // Constructor. + // @note Singleton. Instantiation flows through instance() + // @return This is a constructor. + GlobalContext() = default; + + Status Init(); + + static std::once_flag init_instance_flag_; + static std::unique_ptr global_context_; // The instance of the singleton (global) + std::shared_ptr mem_pool_; // A global memory pool + std::shared_ptr config_manager_; // The configs + std::unique_ptr tensor_allocator_; // An allocator for Tensors + std::unique_ptr cv_tensor_allocator_; // An allocator for CV Tensors + std::unique_ptr int_allocator_; // An allocator for ints +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_CORE_GLOBAL_CONTEXT_H_ diff --git a/mindspore/ccsrc/dataset/core/pybind_support.h b/mindspore/ccsrc/minddata/dataset/core/pybind_support.h similarity index 100% rename from mindspore/ccsrc/dataset/core/pybind_support.h rename to mindspore/ccsrc/minddata/dataset/core/pybind_support.h diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.cc b/mindspore/ccsrc/minddata/dataset/core/tensor.cc new file mode 100644 index 0000000000..842615f9e1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.cc @@ -0,0 +1,1034 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
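// Sketch (not part of this patch): GlobalContext, implemented above, is how code in this
// module reaches the shared ConfigManager and the pooled Tensor allocators. Instance() is
// thread-safe through std::call_once and terminates the process if Init() fails.
#include "minddata/dataset/core/global_context.h"

void DemoGlobalContext() {  // hypothetical helper, for illustration only
  using namespace mindspore::dataset;
  int32_t workers = GlobalContext::config_manager()->num_parallel_workers();
  const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
  // CreateTensor() hands this allocator to std::allocate_shared<Tensor>(*alloc, ...)
  (void)workers;
  (void)alloc;
}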
+ */ +#include "minddata/dataset/core/tensor.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/global_context.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +namespace py = pybind11; +#endif +#include "minddata/dataset/core/tensor_shape.h" + +namespace mindspore { +namespace dataset { +// Helper macros for printing tensor elements +#define CASE_PRINT(de_type, native_type) \ + case de_type: { \ + native_type o; \ + rc = GetItemAt(&o, index); \ + out << o; \ + break; \ + } + +#define CASE_PRINT_HEX(de_type, native_type) \ + case de_type: { \ + native_type o; \ + rc = GetItemAt(&o, index); \ + out << std::hex << std::setw(2) << std::setfill('0') << o << std::dec << std::setfill(' '); \ + break; \ + } + +Tensor::Tensor(const TensorShape &shape, const DataType &type) : shape_(shape), type_(type), data_(nullptr) { + // grab the mem pool from global context and create the allocator for char data area + std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); + data_allocator_ = std::make_unique>(global_pool); +} + +Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data) : Tensor(shape, type) { + if (type.IsNumeric()) { + // If the data pointer was given, then we can also populate the tensor with data + if (data != nullptr) { + // Given the shape/type of this tensor, compute the data size and copy in the input bytes. + int64_t byte_size = this->SizeInBytes(); + Status s = this->AllocateBuffer(byte_size); // Allocates data_ inside itself + if (s.IsOk() && data_ != nullptr) { + int ret_code = memcpy_s(data_, byte_size, data, byte_size); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data into Tensor!"; + } + } else { + MS_LOG(ERROR) << "Failed to create memory for Tensor!"; + } + } + } else { + MS_LOG(ERROR) << "Type should be numeric to use this constructor."; + } +} + +Tensor::Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length) + : Tensor(shape, type) { + // If the data pointer was given, then we can also populate the tensor with data + if (data != nullptr) { + // Allocates data_ inside itself + Status s = AllocateBuffer(length); + if (s.IsError()) { + MS_LOG(ERROR) << "Failed to create memory for Tensor!"; + } + if (data_ != nullptr) { + int ret_code = memcpy_s(data_, length, data, length); + if (ret_code != 0) { + MS_LOG(ERROR) << "Failed to copy data into Tensor!"; + } + } + } +} + +Tensor::Tensor(Tensor &&other) noexcept + : shape_(other.shape()), + type_(other.type()), + data_(other.GetMutableBuffer()), + data_allocator_(std::move(other.data_allocator_)) { + other.Invalidate(); +} + +Tensor &Tensor::operator=(Tensor &&other) noexcept { + if (&other != this) { + shape_ = other.shape(); + type_ = other.type(); + data_ = other.GetMutableBuffer(); + data_end_ = other.data_end_; + data_allocator_ = std::move(other.data_allocator_); + other.Invalidate(); + } + return *this; +} + +Tensor::Tensor(const std::vector &strings, const TensorShape &shape) + : Tensor(TensorShape({static_cast(strings.size())}), DataType(DataType::DE_STRING)) { + auto length_sum = [](dsize_t sum, const std::string &s) { return s.length() + sum; }; + dsize_t total_length = std::accumulate(strings.begin(), strings.end(), 0, length_sum); + + // total bytes needed = offset array + strings + // offset array 
needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize + 1) * shape_.NumOfElements() + kOffsetSize + total_length; + + data_ = data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast(data_); + uchar *buf = GetStringsBuffer(); + + offset_t offset = buf - data_; // the first string will start here + uint32_t i = 0; + for (const auto &str : strings) { + // insert the start index of the string. + offset_arr[i++] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) MS_LOG(ERROR) << "Cannot copy string into Tensor"; + // next string will be stored right after the current one. + offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + this->data_end_ = data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) Tensor::Reshape(shape); +} + +Tensor::Tensor(const dataengine::BytesList &bytes_list, const TensorShape &shape) + : Tensor(TensorShape({static_cast(bytes_list.value_size())}), DataType(DataType::DE_STRING)) { + // total bytes needed = offset array + strings + // offset array needs to store one offset var per element + 1 extra to get the length of the last string. + // strings will be null-terminated --> need 1 extra byte per element + dsize_t num_bytes = (kOffsetSize)*shape_.NumOfElements() + kOffsetSize + bytes_list.ByteSizeLong(); + + data_ = data_allocator_->allocate(num_bytes); + + auto offset_arr = reinterpret_cast(data_); + uchar *buf = GetStringsBuffer(); + + offset_t offset = buf - data_; // the first string will start here + uint32_t i = 0; + for (; i < bytes_list.value_size(); i++) { + const std::string &str = bytes_list.value(i); + // insert the start index of the string. + offset_arr[i] = offset; + // total bytes are reduced by kOffsetSize + num_bytes -= kOffsetSize; + // insert actual string + int ret_code = memcpy_s(data_ + offset, num_bytes, common::SafeCStr(str), str.length() + 1); + if (ret_code != 0) { + MS_LOG(ERROR) << "Cannot copy string into Tensor"; + } + // next string will be stored right after the current one. 
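// Worked example of the layout this loop builds (a sketch, not part of this patch),
// assuming a 4-byte offset_t and input strings {"ab", "c"}: the buffer starts with
// three offsets, followed by the null-terminated payloads:
//   offset_arr[0] = 12 -> "ab\0" occupies bytes 12..14
//   offset_arr[1] = 15 -> "c\0"  occupies bytes 15..16
//   offset_arr[2] = 17 -> one past the end, so length(i) = offset_arr[i+1] - offset_arr[i] - 1
// and num_bytes = (kOffsetSize + 1) * 2 + kOffsetSize + 3 = 17, matching the formula above.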
+ offset = offset + str.length() + 1; + // total bytes are reduced by the length of the string + num_bytes -= str.length() + 1; + } + // store one more offset value so we can get the length of the last string + // length[last_element] = offset_arr[last_element + 1] - offset_arr[last_element] + offset_arr[i] = offset; + + data_end_ = data_ + offset_arr[i]; + + MS_ASSERT(num_bytes == 0); + if (shape.known()) Tensor::Reshape(shape); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, TensorImpl tensor_impl, const TensorShape &shape, + DataType type, const unsigned char *data) { + if (!shape.known()) { + RETURN_STATUS_UNEXPECTED("Invalid shape."); + } + if (type == DataType::DE_UNKNOWN) { + RETURN_STATUS_UNEXPECTED("Invalid data type."); + } + + switch (tensor_impl) { + case TensorImpl::kFlexible: { + // The flex tensor is really just the base class tensor implementation + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, shape, type, data); + break; + } + case TensorImpl::kCv: { + const CVTensorAlloc *alloc = GlobalContext::Instance()->cv_tensor_allocator(); + *ptr = std::allocate_shared(*alloc, shape, type, data); + break; + } + default: { + std::string err_msg("Invalid tensor implementation type."); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); // returns base-class shared_ptr +} + +#ifdef ENABLE_PYTHON +Status Tensor::CreateTensorFromNumpyString(std::shared_ptr *ptr, py::array arr) { + std::vector shape; + for (dsize_t i = 0; i < arr.ndim(); i++) { + shape.push_back(static_cast(arr.shape()[i])); + } + arr.resize({arr.size()}); // flatten the py::array so we can iterate once + std::vector strings; + + if (arr.dtype().kind() == 'U') { + std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); + } else { + std::for_each(arr.begin(), arr.end(), [&strings](const auto &s) { strings.emplace_back(py::cast(s)); }); + } + + arr.resize(shape); // resize arr back to the original shape + + return CreateTensor(ptr, strings, TensorShape{shape}); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, py::array arr) { + if (DataType::FromNpArray(arr) == DataType::DE_STRING) { + return CreateTensorFromNumpyString(ptr, arr); + } + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, TensorShape({}), DataType(DataType::DE_UNKNOWN)); + + std::vector shape; + for (dsize_t i = 0; i < arr.ndim(); i++) { + shape.push_back(static_cast(arr.shape()[i])); + } + + (*ptr)->shape_ = TensorShape(shape); + (*ptr)->type_ = DataType::FromNpArray(arr); + if (!(*ptr)->shape_.known()) RETURN_STATUS_UNEXPECTED("Invalid shape."); + + if ((*ptr)->type_ == DataType::DE_UNKNOWN) RETURN_STATUS_UNEXPECTED("Invalid data type."); + + std::shared_ptr global_pool = GlobalContext::Instance()->mem_pool(); + (*ptr)->data_allocator_ = std::make_unique>(global_pool); + int64_t byte_size = (*ptr)->SizeInBytes(); + RETURN_IF_NOT_OK((*ptr)->AllocateBuffer(byte_size)); + + unsigned char *data = static_cast(arr.request().ptr); + if ((*ptr)->data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor."); + } + + std::vector strides; + for (dsize_t i = 0; i < arr.ndim(); i++) { + strides.push_back(static_cast(arr.strides()[i])); + } + + // check if strides are contiguous + bool is_strided = false; + dsize_t count = (*ptr)->shape_.NumOfElements(); + for (size_t i = 0; i < shape.size(); i++) { + count /= shape[i]; + if (strides[i] != 
(*ptr)->type_.SizeInBytes() * count) { + is_strided = true; + break; + } + } + + if (is_strided) { + RETURN_IF_NOT_OK(CopyStridedArray((*ptr)->data_, data, shape, strides, (*ptr)->type_.SizeInBytes())); + } else { + int ret_code = memcpy_s((*ptr)->data_, byte_size, data, byte_size); + if (ret_code != 0) { + RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); + } + } + + return Status::OK(); // returns base-class shared_ptr +} +#endif + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::vector &strings, + const TensorShape &shape) { + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, strings, shape); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape) { + const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator(); + *ptr = std::allocate_shared(*alloc, bytes_list, shape); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const std::string &file_path) { + std::ifstream fs; + fs.open(file_path, std::ios::binary | std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(!fs.fail(), "Fail to open file: " + file_path); + int64_t num_bytes = fs.seekg(0, std::ios::end).tellg(); + CHECK_FAIL_RETURN_UNEXPECTED(fs.seekg(0, std::ios::beg).good(), "Fail to find size of file"); + RETURN_IF_NOT_OK( + Tensor::CreateTensor(ptr, TensorImpl::kFlexible, TensorShape{num_bytes}, DataType(DataType::DE_UINT8))); + int64_t written_bytes = fs.read(reinterpret_cast((*ptr)->GetMutableBuffer()), num_bytes).gcount(); + CHECK_FAIL_RETURN_UNEXPECTED(written_bytes == num_bytes && fs.good(), "Error in writing to tensor"); + fs.close(); + return Status::OK(); +} + +Status Tensor::CreateTensor(std::shared_ptr *ptr, const dataengine::BytesList &bytes_list, + const TensorShape &shape, const DataType &type, dsize_t pad_size) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(ptr, TensorImpl::kFlexible, shape, type)); + + unsigned char *current_tensor_addr = (*ptr)->GetMutableBuffer(); + int64_t tensor_bytes_remaining = bytes_list.value_size() * pad_size; + + for (int i = 0; i < bytes_list.value_size(); i++) { + // read string data into tensor + const std::string ¤t_element = bytes_list.value(i); + int return_code = + memcpy_s(current_tensor_addr, tensor_bytes_remaining, common::SafeCStr(current_element), current_element.size()); + + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when reading bytesList element into Tensor"); + + current_tensor_addr += current_element.size(); + tensor_bytes_remaining -= current_element.size(); + + // pad + int64_t chars_to_pad = pad_size - current_element.size(); + return_code = memset_s(current_tensor_addr, tensor_bytes_remaining, static_cast(' '), chars_to_pad); + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed when padding Tensor"); + + current_tensor_addr += chars_to_pad; + tensor_bytes_remaining -= chars_to_pad; + } + + return Status::OK(); +} + +// Memcpy the given strided array's used part to consecutive memory +// Consider a 3-d array +// A[(i * shape[1] + j) * shape[2] + k] = B[i][j][k] = C[i * strides[0] + j * strides[1] + k * strides[2]] +// Here we convert array C to array A, by memcpy index by index (Note that not all elements in C is copied) +Status Tensor::CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector shape, + std::vector strides, uint8_t type_size) { + dsize_t size = std::accumulate(shape.begin(), shape.end(), 1, 
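// Worked example for the conversion below (a sketch, not part of this patch): take a 2x3
// int32 view with byte strides {4, 8}, i.e. a transposed, non-contiguous view. For flat
// index i = 4 the peeling loop yields innermost idx 4 % 3 = 1 and outer idx (4 / 3) % 2 = 1,
// so offset = 1*8 + 1*4 = 12 bytes and that element is copied to dst + 4 * type_size.
// Contiguous arrays never reach this function: the stride check above handles them with a
// single memcpy_s.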
std::multiplies()); + for (dsize_t i = 0; i < size; ++i) { + dsize_t offset = 0; + dsize_t count = i; + for (size_t j = 0; j < shape.size(); ++j) { + // convert 1d array's index to 3d array's index (A -> B) + dsize_t idx = count % shape[shape.size() - 1 - j]; + count /= shape[shape.size() - 1 - j]; + // calculate the raw data offset based on strides (B -> C) + offset += idx * strides[shape.size() - 1 - j]; + // once count = 0, the following idxes are all zero, skip them + if (count == 0) break; + } + // strides already consider byte size of the data type, but dst doesn't. + // dst[i] = dst + i * type_size = src + offset + int ret_code = memcpy_s(dst + i * type_size, type_size, src + offset, type_size); + if (ret_code != 0) { + RETURN_STATUS_UNEXPECTED("Failed to copy data into Tensor."); + } + } + return Status::OK(); +} + +// Name: Destructor +// Description: Destructor +Tensor::~Tensor() { + if (data_ != nullptr) { + if (data_allocator_ != nullptr) { + data_allocator_->deallocate(data_); + data_ = nullptr; + data_end_ = nullptr; + } else { + // If we didn't have an allocator, but data_ is not null then it must + // be a stand-alone tensor that used malloc directly. + free(data_); + data_ = nullptr; + data_end_ = nullptr; + } + } +} + +bool Tensor::operator==(const Tensor &rhs) const { + // 1. different shape 2. different type 3. one data_ is nullptr and the other is not + if (shape_ != rhs.shape() || type_ != rhs.type_ || (data_ == nullptr && rhs.data_ != nullptr) || + (data_ != nullptr && rhs.data_ == nullptr)) { + return false; + } + if (data_ == nullptr && rhs.data_ == nullptr) { + return true; + } + // use mem compare to compare the two data, size are already verified + return memcmp(data_, rhs.data_, SizeInBytes()) == 0; +} + +// Name: PrintItemAt() +// Description: A function that print the value as specified by its index +void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) const { + Status rc; + MS_ASSERT(data_); + + switch (type_.value()) { + CASE_PRINT_HEX(DataType::DE_BOOL, bool); + + CASE_PRINT_HEX(DataType::DE_INT8, int8_t); + + CASE_PRINT_HEX(DataType::DE_UINT8, uint8_t); + + CASE_PRINT(DataType::DE_INT16, int16_t); + + CASE_PRINT(DataType::DE_UINT16, uint16_t); + + CASE_PRINT(DataType::DE_INT32, int32_t); + + CASE_PRINT(DataType::DE_UINT32, uint32_t); + + CASE_PRINT(DataType::DE_INT64, int64_t); + + CASE_PRINT(DataType::DE_UINT64, uint64_t); + + CASE_PRINT(DataType::DE_FLOAT16, float16); + + CASE_PRINT(DataType::DE_FLOAT32, float); + + CASE_PRINT(DataType::DE_FLOAT64, double); + + case DataType::DE_STRING: { + std::string_view o{""}; + GetItemAt(&o, index); + out << "\"" << o << "\""; + break; + } + default: { + out << "?"; + break; + } + } + if (rc.IsError()) { + out << rc.ToString(); + } +} + +// Name: PrintRecursive() +// Description: A function that prints Tensor recursively, first called by print +void Tensor::PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector &cur_index) const { + if (cur_index.size() == shape_.Rank()) { + PrintItemAt(cur_index, out); + } else { + out << "["; + for (dsize_t i = 0; i < shape_[cur_dim]; i++) { + std::vector new_index = cur_index; + new_index.push_back(i); + PrintRecursive(out, cur_dim + 1, new_index); + if (i < shape_[cur_dim] - 1) { + out << ","; + } + } + out << "]"; + } +} + +// Name: Print() +// Description: A function that prints info about the tensor +void Tensor::Print(std::ostream &out) const { + out << "Tensor (shape: "; + out << shape_; + out << ", Type: " << type_ << ")\n"; + if (data_) { + 
PrintRecursive(out, 0, std::vector{}); + } else { + out << "[Data area is null]"; + } +} +Status Tensor::AllocateBuffer(const dsize_t &length) { + if (data_ == nullptr) { + if (data_allocator_ != nullptr) { + data_ = data_allocator_->allocate(length); + RETURN_UNEXPECTED_IF_NULL(data_); + data_end_ = data_ + length; + } else { + data_ = static_cast(malloc(length)); + data_end_ = data_ + length; + RETURN_UNEXPECTED_IF_NULL(data_); + } + } + return Status::OK(); +} +const unsigned char *Tensor::GetBuffer() const { + // This version cannot modify anything. data_ could possibly be null. + return data_; +} + +// check for empty +bool Tensor::HasData() const { + if (data_ == nullptr) { + return true; + } else { + return false; + } +} + +unsigned char *Tensor::GetMutableBuffer() { + if (!shape_.known() || type_ == DataType::DE_UNKNOWN) { + return nullptr; + } + // If the data area is already created, return the pointer to it + if (data_ != nullptr) { + return data_; + } else { + // If the data area is not created, then identify the memory size based + // on the shape and type and allocate it. + if (this->AllocateBuffer(this->SizeInBytes()).IsOk()) { + return data_; + } else { + return nullptr; + } + } +} + +Status Tensor::Reshape(const TensorShape &shape) { + if (shape.NumOfElements() == shape_.NumOfElements()) { + shape_ = shape; + return Status::OK(); + } else { + std::string err = "Cannot reshape, Number of elements do not match"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +void Tensor::Invalidate() { + shape_ = TensorShape::CreateUnknownRankShape(); + type_ = DataType(DataType::DE_UNKNOWN); + data_ = nullptr; + data_end_ = nullptr; + data_allocator_ = nullptr; +} + +template +Status Tensor::GetItemPtr(T **ptr, const std::vector &index) const { + if (type_.IsCompatible()) { + if (data_ == nullptr) { + std::string err = "Data is not allocated yet"; + RETURN_STATUS_UNEXPECTED(err); + } + dsize_t flat_idx; + RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); + *ptr = reinterpret_cast(data_ + flat_idx * type_.SizeInBytes()); + + return Status::OK(); + } else { + std::string err = "data type not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +Status Tensor::GetItemPtr(uchar **ptr, const std::vector &index, offset_t *length) const { + if (type_ == DataType::DE_STRING) { + if (data_ == nullptr) { + std::string err = "Data is not allocated yet"; + RETURN_STATUS_UNEXPECTED(err); + } + dsize_t flat_idx; + RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx)); + offset_t length_temp = 0; + RETURN_IF_NOT_OK(GetStringAt(flat_idx, ptr, &length_temp)); + if (length != nullptr) *length = length_temp; + return Status::OK(); + } else { + std::string err = "data type not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } +} + +Status Tensor::StartAddrOfIndex(std::vector ind, uchar **start_addr_of_index, TensorShape *remaining) { + if (type() == DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("StartAddrOfIndex does not support string tensors yet."); + } + + dsize_t flat_ind; + std::vector t_shape = shape().AsVector(); + std::vector r(t_shape.begin() + ind.size(), t_shape.end()); + *remaining = TensorShape(r); + ind.resize(this->Rank(), 0); // same as -> while (ind.size() < this->Rank()) ind.push_back(0); + + RETURN_IF_NOT_OK(shape_.ToFlatIndex(ind, &flat_ind)); + // check if GetBuffer() returns null, we should flag this as an error, this sanity check will only + // be true is the tensor failed to allocate memory. 
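// Sketch (not part of this patch): StartAddrOfIndex() is the splice primitive behind
// batching. For a destination tensor of shape {8, 28, 28} and an element of shape
// {28, 28}, InsertTensor({i}, element) resolves index {i} to
// data_ + i * 28 * 28 * type_size and memcpy_s()es element->SizeInBytes() bytes there,
// after the shape/type/rank checks performed in InsertTensor() below.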
+  if (GetMutableBuffer() == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Invalid GetBuffer in Tensor, got nullptr");
+  }
+  *start_addr_of_index = GetMutableBuffer() + flat_ind * this->type().SizeInBytes();
+  return Status::OK();
+}
+
+Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_ptr<Tensor> &tensor) {
+  std::string err_msg;
+  err_msg += (this->type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : "";
+  err_msg += (!this->shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
+  err_msg += (ind.size() + tensor->Rank() != this->Rank()) ? "[Tensor] incorrect index\n" : "";
+  err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
+  uchar *start_addr_of_ind = nullptr;
+  TensorShape remaining_shape({-1});
+  err_msg += (!StartAddrOfIndex(ind, &start_addr_of_ind, &remaining_shape).IsOk()) ? "[Tensor] incorrect index\n" : "";
+  err_msg += !(remaining_shape == tensor->shape()) ? "[Tensor] memory error\n" : "";
+  if (!err_msg.empty()) {
+    MS_LOG(DEBUG) << "Insert tensor message: " << err_msg;
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  } else {
+    if (start_addr_of_ind != nullptr) {
+      int ret_code =
+        memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes());
+      if (ret_code == 0) {
+        return Status::OK();
+      } else {
+        err_msg += "[Tensor] error in memcpy_s when inserting tensor\n";
+        MS_LOG(DEBUG) << "Tensor message: " << err_msg;
+        RETURN_STATUS_UNEXPECTED(err_msg);
+      }
+    } else {
+      RETURN_STATUS_UNEXPECTED("Failed to create memory for Tensor.");
+    }
+  }
+}
+
+Status Tensor::Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &tensor) {
+  std::string err_msg;
+  err_msg += (index.size() != 1) ? "[Tensor] only supports 1d concatenation\n" : "";
+  err_msg += (type() == DataType::DE_STRING) ? "[Tensor] Cannot batch tensors of type string\n" : "";
+  err_msg += (!shape().known() || !tensor->shape().known()) ? "[Tensor] unknown shape\n" : "";
+
+  err_msg +=
+    (index.at(0) + tensor->shape().NumOfElements() > this->shape().NumOfElements()) ? "[Tensor] incorrect index\n" : "";
+  err_msg += tensor->type().SizeInBytes() != this->type().SizeInBytes() ? "[Tensor] incorrect datatype\n" : "";
+  uchar *start_addr_of_ind = nullptr;
+
+  TensorShape remaining_shape = tensor->shape();
+  StartAddrOfIndex(index, &start_addr_of_ind, &remaining_shape);
+  err_msg += (start_addr_of_ind == nullptr) ? "Failed to create memory for Tensor.\n" : "";
"Failed to create memory for Tensor.\n" : ""; + + if (!err_msg.empty()) { + MS_LOG(DEBUG) << "Insert tensor message: " << err_msg; + + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + int ret_code = + memcpy_s(start_addr_of_ind, tensor->SizeInBytes(), tensor->GetMutableBuffer(), tensor->SizeInBytes()); + + if (ret_code == 0) { + return Status::OK(); + } else { + err_msg += "[Tensor] error in memcpy_s when inserting tensor\n"; + MS_LOG(DEBUG) << "Tensor message: " << err_msg; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } +} + +Status Tensor::ExpandDim(const dsize_t &axis) { + if (axis > Rank()) { + std::string err = "Axis is out of bound"; + RETURN_STATUS_UNEXPECTED(err); + } + if (axis == Rank()) { + shape_ = shape_.AppendDim(1); + } else { + shape_ = shape_.InsertDim(axis, 1); + } + return Status::OK(); +} + +std::vector Tensor::Strides() { + std::vector strides = shape_.Strides(); + uint8_t size = type_.SizeInBytes(); + std::transform(strides.begin(), strides.end(), strides.begin(), [&size](const auto &c) { return c * size; }); + return strides; +} + +#ifdef ENABLE_PYTHON +Status Tensor::GetBufferInfo(Tensor *t, py::buffer_info *out) { + RETURN_UNEXPECTED_IF_NULL(t); + CHECK_FAIL_RETURN_UNEXPECTED(t->type().IsNumeric(), "Cannot use GetBufferInfo on tensor of strings."); + + std::string format_desc = t->type().GetPybindFormat(); + if (format_desc.empty()) { + RETURN_STATUS_UNEXPECTED("Cannot convert DE type tp pybind format"); + } + *out = py::buffer_info(t->GetMutableBuffer(), /* Pointer to buffer */ + t->type().SizeInBytes(), /* Size of one scalar */ + format_desc, /* Python struct-style format descriptor */ + t->Rank(), /* Number of dimensions */ + t->shape().AsVector(), /* Buffer dimensions */ + t->Strides()); + return Status::OK(); +} +#endif + +template +Status Tensor::GetItemAt(T *o, const std::vector &index) const { + if (data_ == nullptr) { + RETURN_STATUS_UNEXPECTED("Data is not allocated yet"); + } + if (!type_.IsLooselyCompatible()) { + std::string err = "Template type and Tensor type are not compatible"; + RETURN_STATUS_UNEXPECTED(err); + } + if (type_.IsUnsignedInt()) { + RETURN_IF_NOT_OK(GetUnsignedIntAt(o, index)); + } else if (type_.IsSignedInt()) { + RETURN_IF_NOT_OK(GetSignedIntAt(o, index)); + } else if (type_.IsFloat()) { + RETURN_IF_NOT_OK(GetFloatAt(o, index)); + } else if (type_.IsBool()) { + bool *ptr = nullptr; + RETURN_IF_NOT_OK(GetItemPtr(&ptr, index)); + *o = static_cast(*ptr); + } else { + std::string err = "Tensor Type is unknown"; + RETURN_STATUS_UNEXPECTED(err); + } + return Status::OK(); +} + +Status Tensor::GetItemAt(std::string_view *o, const std::vector &index) const { + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(o); + CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Tensor type is not a string"); + + uchar *start = nullptr; + offset_t length = 0; + RETURN_IF_NOT_OK(GetItemPtr(&start, index, &length)); + std::string_view sv{reinterpret_cast(start)}; + o->swap(sv); + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +// return data as numpy, should return status +Status Tensor::GetDataAsNumpy(py::array *data) { + RETURN_UNEXPECTED_IF_NULL(data_); + RETURN_UNEXPECTED_IF_NULL(data); + if (type_ == DataType::DE_BOOL) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT8) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == DataType::DE_INT16) { + *data = py::array_t(shape_.AsVector(), reinterpret_cast(data_)); + } else if (type_ == 
+    *data = py::array_t<int32_t>(shape_.AsVector(), reinterpret_cast<int32_t *>(data_));
+  } else if (type_ == DataType::DE_INT64) {
+    *data = py::array_t<int64_t>(shape_.AsVector(), reinterpret_cast<int64_t *>(data_));
+  } else if (type_ == DataType::DE_UINT8) {
+    *data = py::array_t<uint8_t>(shape_.AsVector(), reinterpret_cast<uint8_t *>(data_));
+  } else if (type_ == DataType::DE_UINT16) {
+    *data = py::array_t<uint16_t>(shape_.AsVector(), reinterpret_cast<uint16_t *>(data_));
+  } else if (type_ == DataType::DE_UINT32) {
+    *data = py::array_t<uint32_t>(shape_.AsVector(), reinterpret_cast<uint32_t *>(data_));
+  } else if (type_ == DataType::DE_UINT64) {
+    *data = py::array_t<uint64_t>(shape_.AsVector(), reinterpret_cast<uint64_t *>(data_));
+  } else if (type_ == DataType::DE_FLOAT16) {
+    *data = py::array_t<float16>(shape_.AsVector(), reinterpret_cast<float16 *>(data_));
+  } else if (type_ == DataType::DE_FLOAT32) {
+    *data = py::array_t<float>(shape_.AsVector(), reinterpret_cast<float *>(data_));
+  } else if (type_ == DataType::DE_FLOAT64) {
+    *data = py::array_t<double>(shape_.AsVector(), reinterpret_cast<double *>(data_));
+  } else if (type_ == DataType::DE_STRING) {
+    RETURN_IF_NOT_OK(GetDataAsNumpyStrings(data));
+  } else {
+    RETURN_STATUS_UNEXPECTED("Got unexpected type when returning numpy");
+  }
+  return Status::OK();
+}
+Status Tensor::GetDataAsNumpyStrings(py::array *data) {
+  auto itr = begin<std::string_view>();
+  uint64_t max = 0;
+  for (; itr != end<std::string_view>(); itr++) {
+    max = std::max((*itr).length(), max);
+  }
+  // if all strings are empty, numpy stores a byte for each string |S1
+  max = (max == 0 ? 1 : max);
+  uint64_t total_size = shape_.NumOfElements() * max;
+  char *tmp_data = reinterpret_cast<char *>(data_allocator_->allocate(total_size));
+  if (tmp_data == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create temp array.");
+  int ret_code = memset_s(tmp_data, total_size, 0, total_size);
+  CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to initialize temp memory");
+
+  itr = begin<std::string_view>();
+  uint64_t i = 0;
+  for (; itr != end<std::string_view>(); itr++, i++) {
+    if (!(*itr).empty()) {
+      ret_code = memcpy_s(tmp_data + i * max, total_size, (*itr).data(), (*itr).length());
+      CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy string data.");
+    }
+  }
+  auto strides = shape_.Strides();
+  std::transform(strides.begin(), strides.end(), strides.begin(), [&max](const auto &s) { return s * max; });
+  *data = py::array(py::dtype("S" + std::to_string(max)), shape_.AsVector(), strides, tmp_data);
+  data_allocator_->deallocate(reinterpret_cast<uchar *>(tmp_data));
+  return Status::OK();
+}
+#endif
+
+void Tensor::Squeeze() { shape_ = shape_.Squeeze(); }
+
+template <typename T>
+Status Tensor::GetUnsignedIntAt(T *o, const std::vector<dsize_t> &index) const {
+  if (data_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Data is not allocated yet");
+  }
+  if (!type_.IsLooselyCompatible<T>()) {
+    std::string err = "Template type and Tensor type are not compatible";
+    RETURN_STATUS_UNEXPECTED(err);
+  }
+  switch (type_.value()) {
+    case DataType::DE_UINT8: {
+      uint8_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_UINT16: {
+      uint16_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_UINT32: {
+      uint32_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_UINT64: {
+      uint64_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    default:
+      std::string err = "Tensor Type is not an unsigned Integer";
+      RETURN_STATUS_UNEXPECTED(err);
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status Tensor::GetSignedIntAt(T *o, const std::vector<dsize_t> &index) const {
+  if (data_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Data is not allocated yet");
+  }
+  if (!type_.IsLooselyCompatible<T>()) {
+    std::string err = "Template type and Tensor type are not compatible";
+    RETURN_STATUS_UNEXPECTED(err);
+  }
+  switch (type_.value()) {
+    case DataType::DE_INT8: {
+      int8_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_INT16: {
+      int16_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_INT32: {
+      int32_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_INT64: {
+      int64_t *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    default:
+      std::string err = "Tensor Type is not a signed Integer";
+      RETURN_STATUS_UNEXPECTED(err);
+  }
+  return Status::OK();
+}
+
+template <typename T>
+Status Tensor::GetFloatAt(T *o, const std::vector<dsize_t> &index) const {
+  if (data_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Data is not allocated yet");
+  }
+  if (!type_.IsLooselyCompatible<T>()) {
+    std::string err = "Template type and Tensor type are not compatible";
+    RETURN_STATUS_UNEXPECTED(err);
+  }
+  switch (type_.value()) {
+    case DataType::DE_FLOAT16: {
+      float16 *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_FLOAT32: {
+      float *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    case DataType::DE_FLOAT64: {
+      double *ptr = nullptr;
+      RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+      *o = static_cast<T>(*ptr);
+      break;
+    }
+    default:
+      std::string err = "Tensor Type is not a float/double";
+      RETURN_STATUS_UNEXPECTED(err);
+  }
+  return Status::OK();
+}
+Status Tensor::GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const {
+  CHECK_FAIL_RETURN_UNEXPECTED(type_ == DataType::DE_STRING, "Type is not string");
+  RETURN_UNEXPECTED_IF_NULL(data_);
+  RETURN_UNEXPECTED_IF_NULL(string_start);
+  RETURN_UNEXPECTED_IF_NULL(length);
+  auto *offset_ptr = reinterpret_cast<offset_t *>(data_);  // offsets start here
+  offset_t start = offset_ptr[index];
+  *string_start = data_ + start;
+  *length = offset_ptr[index + 1] - start - 1;  // -1 to skip the \0 from the string length
+  return Status::OK();
+}
+Status Tensor::CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index) {
+  CHECK_FAIL_RETURN_UNEXPECTED(src->type() == type_, "Source Tensor has a different type");
+  CHECK_FAIL_RETURN_UNEXPECTED(index.back() == 0, "Last dim in index should be 0");
+
+  uint8_t type_size = type_.SizeInBytes();
+  size_t len = std::min(src->shape()[-1], shape_[-1]) * type_size;
+  dsize_t src_flat_ind = 0, dst_flat_ind = 0;
+  RETURN_IF_NOT_OK(src->shape().ToFlatIndex(index, &src_flat_ind));
+  RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &dst_flat_ind));
+
+  const unsigned char *src_addr = src->GetBuffer() + src_flat_ind * type_size;
+  unsigned char *dst_addr = GetMutableBuffer() + dst_flat_ind * type_size;
+  CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error");
+  return Status::OK();
+}
+Status Tensor::Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices) {
+  CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice works with rank-1 tensors only.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty.");
empty."); + if (type_.IsNumeric()) { + return SliceNumeric(out, indices); + } else { + return SliceString(out, indices); + } +} +Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { + RETURN_IF_NOT_OK( + CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); + (*out)->GetMutableBuffer(); + dsize_t out_index = 0; + dsize_t dim_length = shape_[0]; + dsize_t type_size = type_.SizeInBytes(); + dsize_t src_start = HandleNeg(indices[0], dim_length); + uchar *dst_addr = (*out)->data_; + dsize_t count = 1; + + for (dsize_t i = 0; i < indices.size(); i++) { + dsize_t cur_index = HandleNeg(indices[i], dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + if (i < indices.size() - 1) { + dsize_t next_index = HandleNeg(indices[i + 1], dim_length); + if (next_index == cur_index + 1) { + count++; + continue; + } + } + int return_code = memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, + count * type_size); + CHECK_FAIL_RETURN_UNEXPECTED(return_code == 0, "memcpy_s failed in SliceNumeric"); + out_index += count; + if (i < indices.size() - 1) { + src_start = HandleNeg(indices[i + 1], dim_length); // next index + } + count = 1; + } + return Status::OK(); +} +Status Tensor::SliceString(std::shared_ptr *out, const std::vector &indices) { + dsize_t dim_length = shape_[0]; + std::vector strings; + for (dsize_t index : indices) { + dsize_t cur_index = HandleNeg(index, dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + std::string_view sv; + GetItemAt(&sv, {cur_index}); + strings.emplace_back(sv); + } + return CreateTensor(out, strings); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor.h b/mindspore/ccsrc/minddata/dataset/core/tensor.h new file mode 100644 index 0000000000..b0b173e9c3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/core/tensor.h @@ -0,0 +1,668 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef DATASET_CORE_TENSOR_H_
+#define DATASET_CORE_TENSOR_H_
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "./securec.h"
+#include "utils/log_adapter.h"
+#if defined(_WIN32) || defined(_WIN64)
+#undef HAVE_STDDEF_H
+#undef HAVE_STDLIB_H
+#endif
+
+#ifdef ENABLE_PYTHON
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#endif
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/util/status.h"
+#include "proto/example.pb.h"
+
+#ifdef ENABLE_PYTHON
+namespace py = pybind11;
+#endif
+namespace mindspore {
+namespace dataset {
+class Tensor;
+template <typename T>
+class Allocator;
+
+using CharAllocPtr = std::unique_ptr<Allocator<unsigned char>>;
+using TensorAllocPtr = std::shared_ptr<Allocator<Tensor>>;  // An allocator shared_ptr for Tensors
+
+class Tensor {
+ public:
+  Tensor() = delete;
+
+  // Create a new tensor, does not internally allocate storage. This constructor is protected, use CreateTensor.
+  // @note The shape and type information should be known and valid.
+  // @param shape TensorShape
+  // @param type DataType
+  Tensor(const TensorShape &shape, const DataType &type);
+
+  // Create a new tensor, allocates storage and copies in data. This constructor is protected, use CreateTensor.
+  // @note The buffer should be valid and the shape and type information should be known and valid.
+  // @param shape TensorShape
+  // @param type DataType
+  // @param data unsigned char*, pointer to the data.
+  Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data);
+
+  Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length);
+
+  Tensor(const Tensor &other) = delete;
+
+  Tensor &operator=(const Tensor &other) = delete;
+
+  Tensor(Tensor &&other) noexcept;
+
+  Tensor &operator=(Tensor &&other) noexcept;
+
+  Status AllocateBuffer(const dsize_t &length);
+
+  // type of the offset values used to store string information
+  using offset_t = uint32_t;
+  // size in bytes of an offset value
+  static constexpr uint8_t kOffsetSize = sizeof(offset_t);
+  // Tensor base class which holds the data in an unsigned char* buffer.
+
+  // Construct a scalar string Tensor
+  explicit Tensor(const std::string &str) : Tensor(std::vector<std::string>{str}, TensorShape::CreateScalar()) {}
+
+  // Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is
+  // the size of the vector `strings`.
+  // The memory layout of a Tensor of strings consists of the Offset_array followed by the strings.
+  // The offset array will store one extra value to find the length of the last string.
+  // OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn
+  // The value of each offset is the start index of the corresponding string.
+  // Offsets are of type offset_t.
+  // Strings are null-terminated.
+  // example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING)
+  // |------------------------------------------------------------------------|
+  // |            OFFSET ARRAY              |            STRINGS              |
+  // | bytes 0-3 | bytes 4-7  | bytes 8-11  | bytes 12-15    | bytes 16-18    |
+  // |    12     |    16      |    19       |     abc\0      |     de\0       |
+  // |------------------------------------------------------------------------|
+  explicit Tensor(const std::vector<std::string> &strings,
+                  const TensorShape &shape = TensorShape::CreateUnknownRankShape());
+
+  // Same as Tensor(vector<string>) but the input is protobuf bytelist
+  explicit Tensor(const dataengine::BytesList &bytes_list,
+                  const TensorShape &shape = TensorShape::CreateUnknownRankShape());
+
+  // A static factory method to create the given flavour of derived Tensor
+  // Returns the base class reference for the Tensor.
+  // @param ptr output argument to hold the created Tensor of given tensor_impl
+  // @param tensor_impl - which implementation of Tensor
+  // @param shape - shape of the tensor
+  // @param type - datatype of the tensor
+  // @param data - data to be copied to Tensor new allocation
+  // @return Status Code
+  static Status CreateTensor(std::shared_ptr<Tensor> *, TensorImpl tensor_impl, const TensorShape &shape,
+                             DataType type, const unsigned char *data = nullptr);
+
+  // Create a copy of the input tensor
+  // @param out [out] output tensor to be generated
+  // @param in [in] original tensor to be copied
+  // @return Status
+  static Status CreateTensor(std::shared_ptr<Tensor> *out, const std::shared_ptr<Tensor> &in) {
+    const TensorAlloc *alloc = GlobalContext::Instance()->tensor_allocator();
+    *out = std::allocate_shared<Tensor>(*alloc, in->shape(), in->type(), in->GetBuffer(), in->SizeInBytes());
+    return Status::OK();
+  }
+
+#ifdef ENABLE_PYTHON
+  // A static factory method to create a Tensor from a given py::array.
+  // @param ptr output argument to hold the created Tensor
+  // @param arr py::array
+  // @return Status Code
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, py::array arr);
+
+  // Helper function to create a tensor from Numpy of strings
+  static Status CreateTensorFromNumpyString(std::shared_ptr<Tensor> *ptr, py::array arr);
+#endif
+
+  // A static factory method to create a Tensor from a given list of strings.
+  // @param ptr output argument to hold the created Tensor
+  // @param strings elements of the tensor
+  // @param shape shape of the tensor
+  // @return Status Code
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<std::string> &strings,
+                             const TensorShape &shape = TensorShape::CreateUnknownRankShape());
+
+  // create tensor from protobuf bytelist with strings
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
+                             const TensorShape &shape);
+
+  // A static factory method to create a Tensor from a given list of numbers.
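+  // For example (illustrative):
+  //   std::shared_ptr<Tensor> t;
+  //   Status s = Tensor::CreateTensor(&t, std::vector<uint32_t>{1, 2, 3});  // shape <3>, type DE_UINT32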
+  // @param ptr output argument to hold the created Tensor
+  // @param items elements of the tensor
+  // @param shape shape of the tensor
+  // @return Status Code
+  template <typename T>
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<T> &items,
+                             const TensorShape &shape_req = TensorShape::CreateUnknownRankShape()) {
+    DataType type = DataType::FromCType<T>();
+    auto items_ptr = reinterpret_cast<const unsigned char *>(&items[0]);
+    TensorShape shape = shape_req;
+    if (!shape.known()) {
+      shape = TensorShape({static_cast<dsize_t>(items.size())});
+    }
+    return CreateTensor(ptr, TensorImpl::kFlexible, shape, type, items_ptr);
+  }
+
+  // A static factory method to create a Tensor from a given number.
+  // @param ptr output argument to hold the created Tensor
+  // @param item value
+  // @return Status Code
+  template <typename T>
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const T &item) {
+    return CreateTensor<T>(ptr, {item}, TensorShape::CreateScalar());
+  }
+
+  // Create tensor from protobuf bytelist with uint8 or int8 types
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
+                             const TensorShape &shape, const DataType &type, dsize_t pad_size);
+
+  static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::string &path);
+
+  // Copy raw data of an array based on shape and strides to the destination pointer
+  // @param dst Pointer to the destination array where the content is to be copied
+  // @param src Pointer to the source of strided array to be copied
+  // @param shape - shape of the source array
+  // @param strides - strides of the source array
+  // @param type_size - number of bytes needed to store one array element's type
+  // @return Status Code
+  static Status CopyStridedArray(unsigned char *dst, unsigned char *src, std::vector<dsize_t> shape,
+                                 std::vector<dsize_t> strides, uint8_t type_size);
+
+  // Release the memory using the allocator
+  virtual ~Tensor();
+
+  // compare the tensor shape and data
+  bool operator==(const Tensor &rhs) const;
+
+  bool operator!=(const Tensor &rhs) const { return !((*this) == rhs); }
+
+  // Get item located at `index`, caller needs to provide the type.
+  // @tparam T
+  // @param index vector<dsize_t>
+  // @return return the item specified at index
+  template <typename T>
+  Status GetItemAt(T *o, const std::vector<dsize_t> &index) const;
+
+  // Get string located at `index`.
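+  //    e.g. (illustrative, for some string tensor t):
+  //      std::string_view sv;
+  //      RETURN_IF_NOT_OK(t->GetItemAt(&sv, {0}));  // view of the first string; no copy is made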
+  // @param index vector<dsize_t>
+  // @return return std::string_view specified at index
+  Status GetItemAt(std::string_view *o, const std::vector<dsize_t> &index) const;
+
+  template <typename T>
+  Status GetUnsignedIntAt(T *o, const std::vector<dsize_t> &index) const;
+
+  template <typename T>
+  Status GetSignedIntAt(T *o, const std::vector<dsize_t> &index) const;
+
+  template <typename T>
+  Status GetFloatAt(T *o, const std::vector<dsize_t> &index) const;
+
+  // set item at location specified by index
+  // @tparam `T`
+  // @param index
+  // @param value of type `T`
+  template <typename T>
+  Status SetItemAt(const std::vector<dsize_t> &index, const T &value) {
+    RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes()));
+    T *ptr = nullptr;
+    RETURN_IF_NOT_OK(GetItemPtr(&ptr, index));
+    *ptr = value;
+    return Status::OK();
+  }
+
+  // set string item at location specified by index
+  // @param index
+  // @param value of type std::string
+  Status SetItemAt(const std::vector<dsize_t> &index, const std::string &value) {
+    RETURN_UNEXPECTED_IF_NULL(data_);
+    uchar *ptr = nullptr;
+    offset_t length = 0;
+    RETURN_IF_NOT_OK(GetItemPtr(&ptr, index, &length));
+    if (value.length() != length) {
+      RETURN_STATUS_UNEXPECTED("Length of the new string does not match the item.");
+    }
+    CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(reinterpret_cast<char *>(ptr), length, value.c_str(), length) == 0,
+                                 "memcpy error");
+
+    return Status::OK();
+  }
+  // fill tensor with Zeros. Does not support strings.
+  Status Zero() {
+    CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use Zero on tensor of strings.");
+    dsize_t size = SizeInBytes();
+    CHECK_FAIL_RETURN_UNEXPECTED(memset_s(GetMutableBuffer(), size, 0, size) == 0,
+                                 "Failed to fill tensor with zeroes.");
+    return Status::OK();
+  }
+
+  // Fill all elements in the Tensor with the given value of type `T`. Does not support strings.
+  // @tparam T
+  // @param value
+  template <typename T>
+  Status Fill(const T &value) {
+    CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings.");
+    RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes()));
+    int64_t cellSize = type_.SizeInBytes();
+    if ((data_ != nullptr) && type_.IsCompatible<T>()) {
+      for (dsize_t i = 0; i < Size(); i++) {
+        CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s((data_ + i * cellSize), cellSize, &value, cellSize) == 0, "memcpy err");
+      }
+      return Status::OK();
+    } else {
+      std::string err;
+      err += (data_ == nullptr) ? "data_ is nullptr \t" : "";
+      err += !type_.IsCompatible<T>() ? "data type not compatible\t" : "";
+      return Status(StatusCode::kUnexpectedError, err);
+    }
+  }
+
+  // Getter function for shape
+  // @return
+  const TensorShape &shape() const { return shape_; }
+
+  /// Check if tensor is empty
+  /// \return bool - true if tensor has no data
+  bool HasData() const;
+
+  // Reshape the tensor. The given shape should have the same number of elements in the Tensor
+  // @param shape
+  virtual Status Reshape(const TensorShape &shape);
+
+  // @return number of elements in this tensor
+  dsize_t Size() const { return shape().NumOfElements(); }
+
+  // @return the number of bytes this tensor needs
+  dsize_t SizeInBytes() const {
+    if (data_end_ == nullptr) return type_.SizeInBytes() * shape_.NumOfElements();
+    return data_end_ - data_;
+  }
+
+  // @return the rank of the tensor
+  dsize_t Rank() const { return shape().Rank(); }
+
+  // Get the starting memory address as a constant for the data of the tensor.
+  // Note: this const version does not allocate; it returns nullptr if no data has been allocated yet.
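+  // e.g. (illustrative, for some tensor t): const unsigned char *raw = t->GetBuffer();  // may be nullptr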
+  // @return const unsigned char*
+  const unsigned char *GetBuffer() const;
+
+  // Getter of the type
+  // @return
+  DataType type() const { return type_; }
+
+  // Provide stream operator for displaying it
+  // @param output stream
+  // @param so the Tensor object to be printed
+  // @return output stream
+  friend std::ostream &operator<<(std::ostream &out, const Tensor &so) {
+    so.Print(out);
+    return out;
+  }
+
+  // Invalidate this Tensor by setting the type and shape to unknown and the data to null.
+  // Calling this method will make the Tensor and its data inaccessible, use it with caution.
+  void Invalidate();
+
+  // Copy input tensor into self at the location index.
+  // Index is a vector of axes which can be incomplete:
+  // Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
+  // @param index
+  // @param input
+  // @return Status code
+  Status InsertTensor(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
+
+  // Find the address of the given index. Used in InsertTensor.
+  // Example:
+  //      Tensor t= [[1,2],[3,4]] , StartAddrOfIndex({0}) -> &1
+  // @param index incomplete index
+  // @param output: startAddrofIndex
+  // @param output: remaining
+  // @return Status code
+  Status StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_index, TensorShape *remaining);
+
+  // Expand the shape of the Tensor with one extra dimension.
+  // For example, if the shape is <512,512,3>:
+  //     *- ExpandDim(0) gives: <1,512,512,3>
+  //     *- ExpandDim(1) gives: <512,1,512,3>
+  //     *- ExpandDim(3) gives: <512,512,3,1>
+  // @param axis location of the dim
+  virtual Status ExpandDim(const dsize_t &axis);
+
+  virtual void Squeeze();
+
+  // Calculates the strides of the Tensor
+  // Ex: Tensor of shape <4,2,2> and type DE_UINT8 (1 byte)
+  // The strides will be {4,2,1}.
+  // Ex: Tensor of shape <4,2,2> and type DE_UINT32 (4 bytes)
+  // The strides will be {16,8,4}.
+  // @return vector of integers
+  std::vector<dsize_t> Strides();
+
+  std::string ToString() {
+    std::stringstream ss;
+    this->Print(ss);
+    return ss.str();
+  }
+
+  // Handle negative indices.
+  static inline dsize_t HandleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; }
+
+  // Slice the tensor based on the given indices. Copy the sliced data into the out tensor. Only rank-1 tensors
+  // are supported. Based on the type of tensor, SliceNumeric or SliceString will be called.
+  // @param out Tensor
+  // @param indices vector of indices
+  // @return Status error code
+  Status Slice(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
+
+  // Slice numeric tensors.
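+  //    e.g. (illustrative): slicing {10,11,12,13} with indices {1,2} copies the two adjacent
+  //    elements in a single memcpy_s, since runs of consecutive indices are batched together.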
+  Status SliceNumeric(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
+
+  // Slice string tensors
+  Status SliceString(std::shared_ptr<Tensor> *out, const std::vector<dsize_t> &indices);
+
+#ifdef ENABLE_PYTHON
+  // Constructs numpy array from input tensor
+  // @param data this data is the location of python data
+  // @return Status code
+  Status GetDataAsNumpy(py::array *data);
+
+  Status GetDataAsNumpyStrings(py::array *data);
+
+  static Status GetBufferInfo(Tensor *t, py::buffer_info *out);
+#endif
+
+  // Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
+  Status Concatenate(const std::vector<dsize_t> &index, const std::shared_ptr<Tensor> &input);
+
+  // TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
+  // The order of elements follows the memory layout (i.e., row-major): [[1,2,3],[4,5,6]] --> 1,2,3,4,5,6
+  // @tparam T type of values in the Tensor Iterator
+  template <typename T, bool = true>
+  class TensorIterator {
+   public:
+    using iterator_category = std::random_access_iterator_tag;
+    using value_type = T;
+    using difference_type = ptrdiff_t;
+    using pointer = T *;
+    using reference = T &;
+
+    explicit TensorIterator(uchar *ptr = nullptr) { ptr_ = reinterpret_cast<T *>(ptr); }
+
+    TensorIterator(const TensorIterator &raw_iterator) { ptr_ = raw_iterator.ptr_; }
+
+    ~TensorIterator() = default;
+
+    TensorIterator &operator=(const TensorIterator &rhs) {
+      ptr_ = rhs.ptr_;
+      return *this;
+    }
+
+    TensorIterator &operator=(T *rhs) {
+      ptr_ = rhs;
+      return *this;
+    }
+
+    bool operator==(const TensorIterator &rhs) { return ptr_ == rhs.ptr_; }
+
+    bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); }
+
+    operator bool() const { return ptr_ != nullptr; }
+
+    T &operator*() { return *ptr_; }
+
+    const T &operator*() const { return *ptr_; }
+
+    T *operator->() { return ptr_; }
+
+    TensorIterator &operator+=(const ptrdiff_t &inc) {
+      ptr_ += inc;
+      return *this;
+    }
+
+    TensorIterator &operator-=(const ptrdiff_t &inc) {
+      ptr_ -= inc;
+      return *this;
+    }
+
+    TensorIterator &operator++() {
+      ++ptr_;
+      return *this;
+    }
+
+    TensorIterator &operator--() {
+      --ptr_;
+      return *this;
+    }
+
+    TensorIterator operator++(int) {
+      auto temp(*this);
+      ++ptr_;
+      return temp;
+    }
+
+    TensorIterator operator--(int) {
+      auto temp(*this);
+      --ptr_;
+      return temp;
+    }
+
+    TensorIterator operator+(const ptrdiff_t &inc) {
+      auto oldPtr = ptr_;
+      ptr_ += inc;
+      auto temp(*this);
+      ptr_ = oldPtr;
+      return temp;
+    }
+
+    TensorIterator operator-(const ptrdiff_t &inc) {
+      auto oldPtr = ptr_;
+      ptr_ -= inc;
+      auto temp(*this);
+      ptr_ = oldPtr;
+      return temp;
+    }
+
+   protected:
+    T *ptr_;
+  };
+
+  // Specialization of TensorIterator for strings. It returns std::string_view for every item.
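+  // e.g. (illustrative, for some string tensor t): iterating over all items:
+  //   for (auto it = t->begin<std::string_view>(); it != t->end<std::string_view>(); ++it) out << *it;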
+  // @tparam DUMMY, used to be able to specialize the inner class
+  template <bool DUMMY>
+  class TensorIterator<std::string_view, DUMMY> {
+   public:
+    using iterator_category = std::random_access_iterator_tag;
+    using value_type = std::string_view;
+    using difference_type = ptrdiff_t;
+    using pointer = std::string_view *;
+    using reference = std::string_view &;
+
+    explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) {
+      data_ = reinterpret_cast<const char *>(data);
+      index_ = index;
+    }
+
+    TensorIterator(const TensorIterator &raw_iterator) {
+      data_ = raw_iterator.data_;
+      index_ = raw_iterator.index_;
+    }
+
+    ~TensorIterator() = default;
+
+    bool operator==(const TensorIterator &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; }
+
+    bool operator!=(const TensorIterator &rhs) { return !(*this == rhs); }
+
+    operator bool() const { return data_ != nullptr; }
+
+    std::string_view operator*() const {
+      auto offset_ = reinterpret_cast<const offset_t *>(data_);
+      offset_t start = offset_[index_];
+      return std::string_view{data_ + start};
+    }
+
+    TensorIterator &operator+=(const dsize_t &inc) {
+      index_ += inc;
+      return *this;
+    }
+
+    TensorIterator &operator-=(const dsize_t &inc) {
+      index_ -= inc;
+      return *this;
+    }
+
+    TensorIterator &operator++() {
+      ++index_;
+      return *this;
+    }
+
+    TensorIterator &operator--() {
+      --index_;
+      return *this;
+    }
+
+    TensorIterator operator++(int) {
+      auto temp(*this);
+      ++index_;
+      return temp;
+    }
+
+    TensorIterator operator--(int) {
+      auto temp(*this);
+      --index_;
+      return temp;
+    }
+
+    TensorIterator operator+(const dsize_t &inc) {
+      auto oldPtr = index_;
+      index_ += inc;
+      auto temp(*this);
+      index_ = oldPtr;
+      return temp;
+    }
+
+    TensorIterator operator-(const dsize_t &inc) {
+      auto oldPtr = index_;
+      index_ -= inc;
+      auto temp(*this);
+      index_ = oldPtr;
+      return temp;
+    }
+
+   protected:
+    dsize_t index_;
+    const char *data_;
+  };
+
+  // Return a TensorIterator that points to the start of the Tensor.
+  // It's the user's responsibility to use the correct type that matches the Tensor type
+  // @param T The type of values in the Tensor
+  // @return TensorIterator
+  template <typename T>
+  TensorIterator<T> begin() {
+    AllocateBuffer(SizeInBytes());
+    return TensorIterator<T>(data_);
+  }
+
+  // Return a linear iterator that points to the place after the last element of the Tensor.
+  // @tparam T The type of values in the Tensor
+  // @return TensorIterator
+  template <typename T>
+  TensorIterator<T> end() {
+    return TensorIterator<T>(data_end_);
+  }
+
+  // Copies the last dimension at `index` from Tensor `src` to this Tensor.
+  // @param src Tensor
+  // @param index vector to the start of the dimension. The last dim should be 0
+  // @return Status
+  Status CopyLastDimAt(const std::shared_ptr<Tensor> &src, const std::vector<dsize_t> &index);
+
+ protected:
+  // Get the starting memory address for the data of the tensor. This potentially
+  // drives an allocation if the data is null.
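+  // e.g. (illustrative): a first call on an empty numeric tensor allocates SizeInBytes() bytes and returns data_.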
+  // @return unsigned char*
+  unsigned char *GetMutableBuffer();
+
+  // A function that prints Tensor recursively, first called by print
+  // @param out
+  // @param cur_dim
+  // @param cur_index
+  void PrintRecursive(std::ostream &out, int32_t cur_dim, const std::vector<dsize_t> &cur_index) const;
+
+  // A function that prints info about the tensor
+  // @param out output stream
+  void Print(std::ostream &out) const;
+
+  // A function that prints the value as specified by its index
+  // @param index vector representing the index
+  // @param out
+  void PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) const;
+
+  // Get pointer to item located at `index`, caller needs to provide the type.
+  // @tparam T
+  // @param index vector<dsize_t>
+  // @return return a pointer to the item specified at index of type `T`
+  template <typename T>
+  Status GetItemPtr(T **, const std::vector<dsize_t> &index) const;
+
+  // Get pointer to string located at `index` and the length of string
+  // @param index vector<dsize_t>
+  // @return return a pointer to the string specified at index and the length of the string
+  Status GetItemPtr(uchar **, const std::vector<dsize_t> &index, offset_t *length = nullptr) const;
+
+  // Given a flat index of an item string, return the start and length of the item
+  // @param index flat index of the item
+  // @return start address of the string
+  // @return length of the string
+  Status GetStringAt(dsize_t index, uchar **string_start, offset_t *length) const;
+
+  // Skips the offsets and returns the start of the buffer where the real strings are stored. Caller needs to check
+  // that the tensor's type is a string, otherwise an undefined address would be returned.
+  // @return address of the first string of the tensor.
+  uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
+
+  // all access to shape_ should be via shape
+  TensorShape shape_;
+  // data type of tensor
+  DataType type_;
+  // pointer to the start of the physical data
+  unsigned char *data_;
+  // An allocator for data_
+  CharAllocPtr data_allocator_;
+  // pointer to the end of the physical data
+  unsigned char *data_end_ = nullptr;
+};
+template <>
+inline Tensor::TensorIterator<std::string_view> Tensor::end<std::string_view>() {
+  return TensorIterator<std::string_view>(data_, shape_.NumOfElements());
+}
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_CORE_TENSOR_H_
diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc
new file mode 100644
index 0000000000..5d75730a4c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.cc
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <utility>
+
+#include "minddata/dataset/core/tensor_row.h"
+
+namespace mindspore {
+namespace dataset {
+
+TensorRow::TensorRow() noexcept : id_(kDefaultRowId) {}
+
+TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept : id_(kDefaultRowId), row_(n, t) {}
+
+TensorRow::TensorRow(const TensorRow::vector_type &v) : id_(kDefaultRowId), row_(v) {}
+
+TensorRow::TensorRow(row_id_type id, const std::initializer_list<value_type> &lst) : id_(id), row_(lst) {}
+
+TensorRow::TensorRow(const TensorRow &tr) : id_(tr.id_), row_(tr.row_) {}
+
+TensorRow &TensorRow::operator=(const TensorRow &tr) {
+  if (this == &tr) {
+    return *this;
+  }
+  row_ = tr.row_;
+  id_ = tr.id_;
+  return *this;
+}
+
+TensorRow &TensorRow::operator=(const std::initializer_list<TensorRow::value_type> &lst) {
+  row_ = lst;
+  return *this;
+}
+
+TensorRow::TensorRow(TensorRow::vector_type &&v) noexcept : id_(kDefaultRowId), row_(std::move(v)) {}
+
+TensorRow::TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept
+    : id_(id), row_(std::move(lst)) {}
+
+TensorRow::TensorRow(TensorRow &&tr) noexcept {
+  id_ = tr.id_;
+  row_ = std::move(tr.row_);
+}
+
+TensorRow &TensorRow::operator=(TensorRow &&tr) noexcept {
+  if (this == &tr) {
+    return *this;
+  }
+  row_ = std::move(tr.row_);
+  id_ = tr.id_;
+  tr.id_ = kDefaultRowId;
+  return *this;
+}
+
+TensorRow &TensorRow::operator=(std::initializer_list<TensorRow::value_type> &&lst) noexcept {
+  row_ = std::move(lst);
+  return *this;
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_row.h b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h
new file mode 100644
index 0000000000..e8f066c87b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/tensor_row.h
@@ -0,0 +1,131 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_CORE_TENSOR_ROW_H_
+#define DATASET_CORE_TENSOR_ROW_H_
+
+#include <deque>
+#include <memory>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+
+namespace mindspore {
+namespace dataset {
+
+class TensorRow;                             // A set of Tensor pointers with an id
+using TensorTable = std::vector<TensorRow>;  // The table of tensors is a vector of rows
+using TensorQTable = std::deque<TensorRow>;  // A different flavour of tensor table, this one has queue functionality
+
+class TensorRow {
+ public:
+  static constexpr row_id_type kDefaultRowId = -1;  // Default row id
+
+  // Type definitions
+  using size_type = dsize_t;
+  using value_type = std::shared_ptr<Tensor>;
+  using reference = std::shared_ptr<Tensor> &;
+  using const_reference = const std::shared_ptr<Tensor> &;
+  using vector_type = std::vector<std::shared_ptr<Tensor>>;
+  using iterator = std::vector<std::shared_ptr<Tensor>>::iterator;
+  using const_iterator = std::vector<std::shared_ptr<Tensor>>::const_iterator;
+
+  TensorRow() noexcept;
+
+  TensorRow(size_type n, value_type t) noexcept;
+
+  // Copy Constructors
+  explicit TensorRow(const vector_type &v);
+
+  TensorRow(row_id_type id, const std::initializer_list<value_type> &lst);
+
+  TensorRow(const TensorRow &tr);
+
+  TensorRow &operator=(const TensorRow &tr);
+
+  TensorRow &operator=(const std::initializer_list<value_type> &lst);
+
+  // Move Constructors
+  explicit TensorRow(vector_type &&v) noexcept;
+
+  TensorRow(row_id_type id, std::initializer_list<value_type> &&lst) noexcept;
+
+  TensorRow(TensorRow &&tr) noexcept;
+
+  TensorRow &operator=(TensorRow &&tr) noexcept;
+
+  TensorRow &operator=(std::initializer_list<value_type> &&lst) noexcept;
+
+  // Destructor
+  ~TensorRow() = default;
+
+  // Functions to fetch/set id/vector
+  row_id_type getId() const { return id_; }
+
+  void setId(row_id_type id) { id_ = id; }
+
+  const vector_type &getRow() const { return row_; }
+
+  // Wrapper functions to support vector operations
+  void emplace_back(value_type t) { row_.emplace_back(t); }
+
+  void push_back(value_type t) { row_.push_back(t); }
+
+  void clear() noexcept { row_.clear(); }
+
+  size_type size() const noexcept { return row_.size(); }
+
+  void reserve(size_type size) { row_.reserve(size); }
+
+  void resize(size_type size) { row_.resize(size); }
+
+  bool empty() { return row_.empty(); }
+
+  void insert(iterator position, iterator first, iterator last) { row_.insert(position, first, last); }
+
+  // Wrapper functions to support vector element access
+  reference at(size_type index) { return row_.at(index); }
+
+  const_reference at(size_type index) const { return row_.at(index); }
+
+  reference front() { return row_.front(); }
+
+  const_reference front() const { return row_.front(); }
+
+  reference back() { return row_.back(); }
+
+  const_reference back() const { return row_.back(); }
+
+  reference operator[](size_type index) { return row_[index]; }
+
+  const_reference operator[](size_type index) const { return row_[index]; }
+
+  // Wrapper functions to support vector iteration
+  iterator begin() { return row_.begin(); }
+
+  const_iterator begin() const { return row_.begin(); }
+
+  iterator end() { return row_.end(); }
+
+  const_iterator end() const { return row_.end(); }
+
+ protected:
+  row_id_type id_;
+  std::vector<std::shared_ptr<Tensor>> row_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_CORE_TENSOR_ROW_H_
diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc
new file mode 100644
index 0000000000..ff40062d37
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.cc
@@ -0,0 +1,235 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#define MAX_INTEGER_DTYPE 9223372036854775807
+
+#include "minddata/dataset/core/tensor_shape.h"
+
+#include <limits>
+
+#include "common/utils.h"
+#include "utils/log_adapter.h"
+#include "minddata/dataset/core/constants.h"
+
+namespace mindspore {
+namespace dataset {
+constexpr dsize_t TensorShape::kDimUnknown;
+
+bool multi_ok(dsize_t x, dsize_t y) {
+  dsize_t p = x * y;
+  if (x == 0) {
+    return true;
+  }
+  return p / x == y;
+}
+
+dsize_t TensorShape::NumOfElements() const {
+  if (!known()) {
+    return 0;
+  }
+  return strides_[0];
+}
+
+void TensorShape::Print(std::ostream &out) const {
+  if (!known() && raw_shape_.empty()) {
+    out << "<kUnknown>";
+  } else {
+    out << "<";
+    for (auto i = 0; i < this->Rank(); i++) {
+      if (raw_shape_[i] == kDimUnknown) {
+        out << "*";
+      } else {
+        out << raw_shape_[i];
+      }
+      if (i != this->Rank() - 1) {
+        out << ",";
+      }
+    }
+    out << ">";
+  }
+}
+
+TensorShape::TensorShape(const std::initializer_list<dsize_t> &list)
+    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
+  AddListToShape(list);
+}
+
+TensorShape::TensorShape(const std::vector<dsize_t> &list)
+    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
+  AddListToShape(list);
+}
+
+TensorShape::TensorShape(const TensorShape &shape)
+    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
+  AddListToShape(shape.AsVector());
+  known_ = shape.known_;  // override with the input shape in case of unknown-rank tensor shape.
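+  // (Illustrative) AddListToShape above also computed the strides: e.g. a shape {4,2,2}
+  // yields strides_ = {16,4,2,1}, and NumOfElements() simply returns strides_[0].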
+}
+
+#ifdef ENABLE_PYTHON
+TensorShape::TensorShape(py::list l)
+    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
+  std::vector<dsize_t> list_c;
+  for (auto &i : l) {
+    if (!i.is_none()) {
+      list_c.push_back(i.cast<dsize_t>());
+    } else {
+      list_c.push_back(TensorShape::kDimUnknown);
+    }
+  }
+  AddListToShape(list_c);
+}
+#endif
+
+TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type)
+    : raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
+  for (int i = 0; i < cv_size.dims(); i++) {
+    raw_shape_.push_back(cv_size[i]);
+  }
+  auto channels = static_cast<dsize_t>(1 + (type >> static_cast<uint32_t>(CV_CN_SHIFT)));
+  if (channels != 1) {
+    raw_shape_.push_back(channels);
+  }
+  known_ = true;
+}
+
+TensorShape TensorShape::CreateUnknownRankShape() {
+  TensorShape s({});
+  s.known_ = false;
+  return s;
+}
+
+TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const {
+  std::vector<dsize_t> tmp = AsVector();
+  (void)tmp.insert(tmp.begin() + axis, dim);
+  return TensorShape(tmp);
+}
+
+std::vector<dsize_t> TensorShape::AsVector() const {
+  return std::vector<dsize_t>(raw_shape_.begin(), raw_shape_.end());
+}
+
+bool TensorShape::IsValidIndex(const std::vector<dsize_t> &index) const {
+  dsize_t s_rank = Rank();
+  if (index.size() != s_rank) {
+    return false;
+  }
+  for (dsize_t i = 0; i < s_rank; i++) {
+    if (index[i] < 0 || raw_shape_[i] <= index[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <typename T>
+void TensorShape::AddListToShape(const T &list) {
+  raw_shape_.resize(list.size());
+  strides_.resize(list.size() + 1);
+  strides_[list.size()] = 1;
+  known_ = true;
+  dsize_t size = 0;
+  auto itr = std::rbegin(list);  // iterate over the list in reverse order
+  auto s = list.size() - 1;      // to compute strides while adding dims
+  for (; itr != std::rend(list); itr++, s--) {
+    dsize_t dim = *itr;
+    if (dim > 0) {
+      if (strides_[s + 1] > std::numeric_limits<dsize_t>::max() / dim) {
+        MS_LOG(ERROR) << "Invalid shape data, overflow occurred!";
+        known_ = false;
+        raw_shape_.clear();
+        return;
+      }
+      strides_[s] = dim * strides_[s + 1];
+    }
+    if (dim < 0) {
+      known_ = false;
+    }
+    if (dim > kDeMaxDim) {
+      std::stringstream ss;
+      ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!";
+      MS_LOG(ERROR) << ss.str().c_str();
+      known_ = false;
+      raw_shape_.clear();
+      return;
+    }
+    raw_shape_[s] = dim;
+    size++;
+  }
+  if (size > kDeMaxRank) {
+    std::stringstream ss;
+    ss << "Invalid shape data, rank (" << size << ") is larger than the maximum rank size(" << kDeMaxRank << ").";
+    MS_LOG(ERROR) << ss.str().c_str();
+    known_ = false;
+    raw_shape_.clear();
+    return;
+  }
+}
+
+TensorShape TensorShape::CreateUnknownShapeWithRank(dsize_t rank) {
+  TensorShape s({});
+  for (dsize_t i = 0; i < rank; i++) {
+    s.raw_shape_.push_back(kDimUnknown);
+  }
+  s.known_ = false;
+  return s;
+}
+
+TensorShape TensorShape::PrependDim(dsize_t dim) const {
+  if (Size() == 0) {
+    return TensorShape({dim});
+  }
+  return InsertDim(0, dim);
+}
+
+TensorShape TensorShape::AppendDim(dsize_t dim) const {
+  auto vec = AsVector();
+  vec.push_back(dim);
+  return TensorShape(vec);
+}
+
+#ifdef ENABLE_PYTHON
+py::list TensorShape::AsPyList() {
+  py::list list;
+  for (auto i : raw_shape_) {
+    list.append(i);
+  }
+  return list;
+}
+#endif
+
+TensorShape TensorShape::Squeeze() const {
+  std::vector<dsize_t> new_shape;
+  for (auto s : AsVector()) {
+    if (s != 1) {
+      new_shape.push_back(s);
+    }
+  }
+  return TensorShape(new_shape);
+}
+
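+// Worked example (illustrative): for a shape <2,3>, strides_ = {6,3,1}. Strides() below
+// returns {3,1}, and ToFlatIndex({1,2}, &i) computes i = 1*3 + 2 = 5.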
+std::vector<dsize_t> TensorShape::Strides() const {
+  return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()};
+}
+
+// Name: ToFlatIndex()
+// Description: convert a vector-style index to a number, used to access memory internal use only
+Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const {
+  *flat_index = 0;
+  for (size_t k = 0; k < index.size(); k++) {
+    *flat_index += index[k] * strides_[k + 1];  // skip the first element of strides_ which is numOfElements
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index");
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h
new file mode 100644
index 0000000000..4944f9e32c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/core/tensor_shape.h
@@ -0,0 +1,196 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_CORE_TENSOR_SHAPE_H_
+#define DATASET_CORE_TENSOR_SHAPE_H_
+
+#include <cstdint>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <opencv2/core/mat.hpp>
+
+#ifdef ENABLE_PYTHON
+#include "pybind11/pybind11.h"
+namespace py = pybind11;
+#endif
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/util/allocator.h"
+
+namespace mindspore {
+namespace dataset {
+// Class that represents a shape of a Tensor. A shape can be:
+// -# Known shape (mKnown = true)
+//        -# Scalar --> empty vector --> <>
+//        -# n-Dim --> not empty vector --> <d1, d2, ..., dn> where di is >= 0\n
+//           Example: <1,2>, <1>, <1,13,10,11,1>
+// -# Unknown shape (mKnown = false)
+//        -# Rank is unknown --> empty vector --> <>
+//        -# one or more dim is unknown --> not empty vector --> <d1, d2, ..., dn> where di is unknown\n
+//           Example: <3,?> (the 1st dim is unknown)\n
+//           <2,?,?,?> (all dims but the 0th dim are unknown)
+
+/// \brief  TensorShape supports any dim > 0 and < 2^31-1
+class TensorShape {
+ public:
+  static constexpr dsize_t kDimUnknown = -1;  // constant for an unknown dimension
+
+  // Force the compiler to not create a no-arg constructor
+  TensorShape() = delete;
+
+  /// \brief Create a Shape from an initialization list (e.g., TensorShape s = {2,2}).
+  ///    If one of the dims is set to DIM_UNKNOWN, the shape will be flagged as unknown
+  /// \param[in] list
+  explicit TensorShape(const std::initializer_list<dsize_t> &list);
+
+  /// \brief Create a Shape from a vector (e.g., TensorShape s = std::vector<dsize_t>({2,2}) ).
+  ///    If one of the dims is set to DIM_UNKNOWN, the shape will be flagged as unknown
+  /// \param[in] list
+  explicit TensorShape(const std::vector<dsize_t> &list);
+
+  /// \brief Copy constructor
+  /// \param[in] shape
+  TensorShape(const TensorShape &shape);
+
+#ifdef ENABLE_PYTHON
+  /// \brief construct a TensorShape via a python list
+  /// \param[in] py::list l - a list object from python
+  explicit TensorShape(py::list l);
+#endif
+
+  ~TensorShape() = default;
+
+  /// \brief Create a scalar Shape (i.e., empty shape with mKnown = true)
+  /// \return TensorShape
+  static TensorShape CreateScalar() { return TensorShape({}); }
+
+  /// \brief Create a shape with an unknown rank.
+  /// \return TensorShape
+  static TensorShape CreateUnknownRankShape();
+
+  /// \brief Create a shape with a known rank.
+  /// \return TensorShape
+  static TensorShape CreateUnknownShapeWithRank(dsize_t rank);
+
+  /// \brief Insert a new dim into a copy of the current shape.
+  /// \param[in] axis the index where dim should be added
+  /// \param[in] dim to be added
+  /// \return New modified shape
+  TensorShape InsertDim(dsize_t axis, dsize_t dim) const;
+
+  /// \brief Insert new dim at index 0. For example, <2,4> --> PrependDim(4) --> <4,2,4>
+  /// \param[in] dim
+  /// \return
+  TensorShape PrependDim(dsize_t dim) const;
+
+  /// \brief Insert a new dim at the end of the shape. For example, <2,4> --> AppendDim(4) --> <2,4,4>
+  /// \param[in] dim
+  /// \return
+  TensorShape AppendDim(dsize_t dim) const;
+
+  /// \brief Create a shape based on OpenCV shape and type
+  /// \param[in] cv_size
+  /// \param[in] type int that represents the type in OpenCV, example CV_8U, CV_64S
+  TensorShape(cv::MatSize cv_size, uint32_t type);
+
+  dsize_t Size() const { return raw_shape_.size(); }
+
+  dsize_t Rank() const { return raw_shape_.size(); }
+
+  bool known() const { return known_; }
+
+  bool empty() const { return raw_shape_.empty(); }
+
+  dsize_t NumOfElements() const;
+
+  bool operator==(const TensorShape &rhs) const { return known_ == rhs.known_ && raw_shape_ == rhs.raw_shape_; }
+
+  bool operator!=(const TensorShape &rhs) const { return !(rhs == *this); }
+
+  dsize_t operator[](const dsize_t index) const {
+    if (index < 0) return raw_shape_[raw_shape_.size() + index];
+    return raw_shape_[index];
+  }
+
+  /// \brief Return the Shape as a vector
+  /// \return
+  std::vector<dsize_t> AsVector() const;
+
+  /// \brief Returns the class info as a string
+  /// \return
+  std::string ToString() const {
+    std::stringstream ss;
+    ss << *this;
+    return ss.str();
+  }
+
+  /// \brief Actual print function used by operator<<
+  /// \param out output string stream
+  void Print(std::ostream &out) const;
+
+  /// \brief << Stream output operator overload
+  ///    This allows you to print the info using stream operators
+  /// \param[in] out - reference to the output stream being overloaded
+  /// \param[in] so - reference to the TensorShape to display
+  /// \return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const TensorShape &so) {
+    so.Print(out);
+    return out;
+  }
+
+#ifdef ENABLE_PYTHON
+  py::list AsPyList();
+#endif
+
+  /// \brief Checks if the given index is a valid index for this tensor.
+  ///    For example: Tensor<3,4> Index<1,1> is valid. But Index<4,1> or <1> are not.
+  /// \param[in] index
+  /// \return bool
+  bool IsValidIndex(const std::vector<dsize_t> &index) const;
+
+  TensorShape Squeeze() const;
+
+  std::vector<dsize_t> Strides() const;
+
+  /// \brief Returns the location of the item assuming row major memory layout.
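+  ///    e.g. (illustrative): for shape <2,3>, ToFlatIndex({1,2}, &i) sets i = 1*3 + 2 = 5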
+  /// \brief Return the location of the item assuming row-major memory layout.
+  /// \param[in] index
+  /// \param[out] flat_index
+  /// \return Status object
+  Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const;
+
+ private:
+  // True if known and valid shape, false otherwise
+  bool known_;
+  // Vector to keep the dims of the shape.
+  std::vector<dsize_t> raw_shape_;
+  // Vector to keep the strides of the shape. The size is rank+1.
+  std::vector<dsize_t> strides_;
+
+  /// \brief Internal utility function to iterate over a list,
+  ///     check that each dim is valid, and then insert it into the shape;
+  ///     the shape is flagged as unknown if a dim is invalid or the element count would overflow.
+  /// \param[in] list Iterable list
+  template <typename T>
+  void AddListToShape(const T &list);
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_CORE_TENSOR_SHAPE_H_
diff --git a/mindspore/ccsrc/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt
diff --git a/mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
new file mode 100644
index 0000000000..04746131bb
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc
@@ -0,0 +1,208 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/bit.h" + +namespace mindspore { +namespace dataset { + +// Constructor +CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill) + : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {} + +// print method for display cache details +void CacheClient::Print(std::ostream &out) const { + out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_ + << "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_ + << "\n Spilling: " << std::boolalpha << spill_; +} + +Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const { + CacheRowRequest rq(server_connection_id_, cookie()); + RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row)); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + if (row_id_from_server != nullptr) { + *row_id_from_server = rq.GetRowIdAfterCache(); + } + return Status::OK(); +} + +Status CacheClient::WriteBuffer(std::unique_ptr &&in) const { + std::unique_ptr db_ptr = std::move(in); + auto num_rows = db_ptr->NumRows(); + std::vector all_rows; + if (num_rows > 0) { + all_rows.reserve(num_rows); + // Break down the DataBuffer into TensorRow. We will send the requests async + // and then do a final wait. + MemGuard rq_arr; + RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie())); + CacheServer &cs = CacheServer::GetInstance(); + for (auto i = 0; i < num_rows; ++i) { + TensorRow row; + auto rq = rq_arr[i]; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row)); + RETURN_IF_NOT_OK(cs.PushRequest(rq)); + // We can't let row go out of scope. Otherwise it will free all the tensor memory. + // So park it in the vector. When this function go out of scope, its memory + // will be freed. + all_rows.push_back(std::move(row)); + } + // Now we wait for the requests to be done. + for (auto i = 0; i < num_rows; ++i) { + auto rq = rq_arr[i]; + RETURN_IF_NOT_OK(rq->Wait()); + } + } + return Status::OK(); +} + +Status CacheClient::GetRows(const std::vector &row_id, TensorTable *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + BatchFetchRequest rq(server_connection_id_, row_id); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + RETURN_IF_NOT_OK(rq.Wait()); + RETURN_IF_NOT_OK(rq.RestoreRows(out)); + return Status::OK(); +} + +Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) { + UniqueLock lck(&mux_); + // To create a cache, we identify ourself at the client by: + // - the shared session id + // - a crc for the tree nodes from the cache downward + // Pack these 2 into a single 64 bit request id + // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch + // These are different trees in a single session, but the user wants to share the cache. + // This is not allowed because the data of these caches are different. 
+ // + // Consider this example: + // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch + // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch + // These are different trees in the same session, but the cached data is the same, so it is okay + // to allow the sharing of this cache between these pipelines. + + // The CRC is computed by the tree prepare phase and passed to this function when creating the cache. + // If we already have a server_connection_id_, then it means this same cache client has already been used + // to create a cache and some other tree is trying to use the same cache. + // That is allowed, however the crc better match! + if (server_connection_id_) { + if (cache_crc_ != tree_crc) { + RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!"); + } + // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should + // skip the build phase. + lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock. + CacheClient::ServiceStat stat{}; + RETURN_IF_NOT_OK(GetStat(&stat)); + if (stat.cache_service_state == static_cast(CacheService::State::kFetchPhase)) { + return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase"); + } + } else { + cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client + // Combine the session and crc. This will form our client cache identifier. + connection_id_type connection_identification = (static_cast(session_id_) << 32) | cache_crc_; + // Now execute the cache create request using this identifier and other configs + BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone; + if (spill_) { + createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk; + } + if (generate_id) { + createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId; + } + CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag); + RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq)); + Status rc = rq.Wait(); + if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) { + server_connection_id_ = rq.GetServerConnectionId(); + if (rc.IsOk()) { + // The 1st guy creating the cache will get a cookie back. + // But this object may be shared among pipelines and we don't want + // overwrite it. + cookie_ = rq.cookie(); + } + } + // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the + // CacheOp to bypass the build phase. 
+    return rc;
+  }
+  return Status::OK();
+}
+
+Status CacheClient::PurgeCache() {
+  UniqueLock lck(&mux_);
+  PurgeCacheRequest rq(server_connection_id_);
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  return rq.Wait();
+}
+
+Status CacheClient::DestroyCache() {
+  UniqueLock lck(&mux_);
+  DestroyCacheRequest rq(server_connection_id_);
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  return rq.Wait();
+}
+
+Status CacheClient::GetStat(ServiceStat *stat) {
+  SharedLock lck(&mux_);
+  RETURN_UNEXPECTED_IF_NULL(stat);
+  GetStatRequest rq(server_connection_id_);
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  RETURN_IF_NOT_OK(rq.Wait());
+  stat->num_disk_cached = rq.GetNumDiskCached();
+  stat->num_mem_cached = rq.GetNumMemCached();
+  stat->min_row_id = rq.GetMinRowId();
+  stat->max_row_id = rq.GetMaxRowId();
+  stat->cache_service_state = rq.GetState();
+  return Status::OK();
+}
+
+Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
+  SharedLock lck(&mux_);
+  CacheSchemaRequest rq(server_connection_id_);
+  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  RETURN_IF_NOT_OK(rq.Wait());
+  return Status::OK();
+}
+
+Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
+  SharedLock lck(&mux_);
+  RETURN_UNEXPECTED_IF_NULL(map);
+  FetchSchemaRequest rq(server_connection_id_);
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  RETURN_IF_NOT_OK(rq.Wait());
+  *map = rq.GetColumnMap();
+  return Status::OK();
+}
+
+Status CacheClient::BuildPhaseDone() const {
+  SharedLock lck(&mux_);
+  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
+  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
+  RETURN_IF_NOT_OK(rq.Wait());
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
new file mode 100644
index 0000000000..f25db87578
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
@@ -0,0 +1,141 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
+#define DATASET_ENGINE_CACHE_CLIENT_H_
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "./de_tensor_generated.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/cache/cache_server.h"
+#include "minddata/dataset/util/lock.h"
+
+namespace mindspore {
+namespace dataset {
+/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
+/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, and restoring
+/// previously cached rows.
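+// Illustrative usage (a sketch, not part of the original change; assumes a
+// tree crc `tree_crc` from the prepare phase and a TensorRow `row` in hand):
+//
+//   CacheClient cc(/*session_id=*/1, /*cache_mem_sz=*/0, /*spill=*/false);
+//   Status rc = cc.CreateCache(tree_crc, /*generate_id=*/true);
+//   row_id_type id;
+//   if (rc.IsOk()) rc = cc.WriteRow(row, &id);   // cache one row
+//   TensorTable out;
+//   if (rc.IsOk()) rc = cc.GetRows({id}, &out);  // fetch it back; a miss yields an empty row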
+class CacheClient { + public: + /// \brief Constructor + /// \param session_id A user assigned session id for the current pipeline + /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited + /// \param spill Spill to disk if out of memory + CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill); + + /// \brief Destructor + ~CacheClient() = default; + + /// \brief Getter function for returning the current session id + /// \return session id + uint64_t session_id() const { return session_id_; } + + /// \brief Send a TensorRow to the cache server + /// \param[in] row + /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset + /// \return return code + Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const; + + /// \brief Send a DataBuffer to the cache server + /// \param in Unique pointer of the DataBuffer to be cached + /// \return return code + Status WriteBuffer(std::unique_ptr &&in) const; + + /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is + /// any cache miss + /// \param row_id A vector of row id's + /// \param out A TensorTable of TensorRows. + /// \return return code + Status GetRows(const std::vector &row_id, TensorTable *out) const; + + /// \brief Create a cache. + /// \param tree_crc A crc that was generated during tree prepare phase + /// \param generate_id Let the cache service generate row id + /// \return Status object + Status CreateCache(uint32_t tree_crc, bool generate_id); + + /// \brief Purge a cache. Cache can be reused after reset. + /// \return Status object + Status PurgeCache(); + + /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused. + /// \return Status object + Status DestroyCache(); + + /// \brief Get the statistics from a cache. + /// \param[in/out] Pointer to a pre-allocated ServiceStat object + /// \return Status object + struct ServiceStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + row_id_type min_row_id; + row_id_type max_row_id; + int8_t cache_service_state; + }; + Status GetStat(ServiceStat *); + + /// \brief Cache the schema at the cache server + /// \param map The unordered map of the schema + /// \return Status object + Status CacheSchema(const std::unordered_map &map); + + /// \brief Fetch the schema from the cache server + /// \param map Pointer to pre-allocated map object + /// \return Status object. + Status FetchSchema(std::unordered_map *map); + + /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache + /// client that holds cookie can be allowed to make this request + /// \return Status object + Status BuildPhaseDone() const; + + /// \brief A print method typically used for debugging + /// \param out The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief Stream output operator overload + /// \return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) { + cc.Print(out); + return out; + } + + /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it. + /// \return Cookie + std::string cookie() const { return cookie_; } + + private: + mutable RWLock mux_; + uint64_t cache_mem_sz_; + bool spill_; + // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow + // sharing of the cache. 
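+  // For illustration: CreateCache() packs these two 32-bit values into the single
+  // 64-bit connection id, e.g. session_id_ = 1 and cache_crc_ = 0x2B67 give
+  //   (static_cast<connection_id_type>(1) << 32) | 0x2B67 == 0x0000000100002B67.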
+ uint32_t session_id_; + uint32_t cache_crc_; + // The server_connection_id_ is the actual id we use for operations after the cache is built + connection_id_type server_connection_id_; + // Some magic cookie returned from the cache server. + std::string cookie_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_CACHE_CLIENT_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc new file mode 100644 index 0000000000..3b7fc057a2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_request.h" + +namespace mindspore { +namespace dataset { + +Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) { + buffers_.reserve(row.size() + 1); + RETURN_IF_NOT_OK(SerializeTensorRowHeader(row)); + buffers_.push_back(fbb_->GetBufferPointer()); + for (const auto &ts : row) { + buffers_.push_back(ts->GetBuffer()); + } + return Status::OK(); +} + +Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) { + try { + fbb_ = std::make_shared(); + std::vector> v; + std::vector tensor_sz; + v.reserve(row.size()); + tensor_sz.reserve(row.size()); + // We will go through each column in the row. + for (const std::shared_ptr &ts_ptr : row) { + flatbuffers::Offset ts_off; + RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off)); + v.push_back(ts_off); + tensor_sz.push_back(ts_ptr->SizeInBytes()); + } + auto column_off = fbb_->CreateVector(v); + auto data_sz_off = fbb_->CreateVector(tensor_sz); + TensorRowHeaderMsgBuilder row_builder(*fbb_); + row_builder.add_column(column_off); + row_builder.add_data_sz(data_sz_off); + // Pass the row_id even if it may not be known. + row_builder.add_row_id(row.getId()); + row_builder.add_size_of_this(-1); // fill in later after we call Finish. + auto out = row_builder.Finish(); + fbb_->Finish(out); + // Now go back to fill in size_of_this in the flat buffer. + auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer()); + auto success = msg->mutate_size_of_this(fbb_->GetSize()); + if (!success) { + RETURN_STATUS_UNEXPECTED("Unable to set size_of_this"); + } + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, + flatbuffers::Offset *out_off) { + RETURN_UNEXPECTED_IF_NULL(out_off); + const Tensor *ts = ts_ptr.get(); + auto shape_off = fbb_->CreateVector(ts->shape().AsVector()); + const auto ptr = ts->GetBuffer(); + if (ptr == nullptr) { + RETURN_STATUS_UNEXPECTED("Tensor buffer is null"); + } + auto src = ts->type().value(); + TensorType dest; +#define CASE(t) \ + case DataType::t: \ + dest = TensorType::TensorType_##t; \ + break + // Map the type to fill in the flat buffer. 
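+  // For example (illustrative), CASE(DE_INT8) expands to:
+  //   case DataType::DE_INT8:
+  //     dest = TensorType::TensorType_DE_INT8;
+  //     break;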
+ switch (src) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + default: + MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts; + RETURN_STATUS_UNEXPECTED("Unknown type"); + } +#undef CASE + + TensorMetaMsgBuilder ts_builder(*fbb_); + ts_builder.add_dims(shape_off); + ts_builder.add_type(dest); + auto ts_off = ts_builder.Finish(); + *out_off = ts_off; + return Status::OK(); +} + +Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, + std::shared_ptr *out) { + RETURN_UNEXPECTED_IF_NULL(col_ts); + auto shape_in = col_ts->dims(); + auto type_in = col_ts->type(); + std::vector v; + v.reserve(shape_in->size()); + v.assign(shape_in->begin(), shape_in->end()); + TensorShape shape(v); + DataType::Type dest = DataType::DE_UNKNOWN; +#define CASE(t) \ + case TensorType_##t: \ + dest = DataType::Type::t; \ + break + + switch (type_in) { + CASE(DE_BOOL); + CASE(DE_INT8); + CASE(DE_UINT8); + CASE(DE_INT16); + CASE(DE_UINT16); + CASE(DE_INT32); + CASE(DE_UINT32); + CASE(DE_INT64); + CASE(DE_UINT64); + CASE(DE_FLOAT16); + CASE(DE_FLOAT32); + CASE(DE_FLOAT64); + CASE(DE_STRING); + } +#undef CASE + + DataType type(dest); + std::shared_ptr ts = + std::make_shared(shape, type, static_cast(data.GetPointer()), data.GetSize()); + // Next we restore the real data which can be embedded or stored separately. + if (ts->SizeInBytes() != data.GetSize()) { + MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n" + << "Dumping tensor\n" + << *ts << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. 
See log file for details."); + } + *out = std::move(ts); + return Status::OK(); +} + +Status BatchFetchRequest::RestoreRows(TensorTable *out) { + RETURN_UNEXPECTED_IF_NULL(out); + auto num_elements = row_id_.size(); + auto *offset_array = reinterpret_cast(mem_.GetPointer()); + TensorTable tbl; + tbl.reserve(num_elements); + ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto len = offset_array[i + 1] - offset_array[i]; + TensorRow row; + row.setId(row_id_.at(i)); + if (len > 0) { + ReadableSlice row_data(all, offset_array[i], len); + // Next we de-serialize flat buffer to get back each column + auto msg = GetTensorRowHeaderMsg(row_data.GetPointer()); + auto msg_sz = msg->size_of_this(); + // Start of the tensor data + auto ts_offset = msg_sz; + row.reserve(msg->column()->size()); + for (auto k = 0; k < msg->column()->size(); ++k) { + auto col_ts = msg->column()->Get(k); + std::shared_ptr ts; + ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k)); + RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts)); + row.push_back(ts); + ts_offset += data.GetSize(); + } + } + tbl.push_back(std::move(row)); + } + *out = std::move(tbl); + return Status::OK(); +} + +Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map &map) { + try { + fbb_ = std::make_shared(); + std::vector> v; + v.reserve(map.size()); + for (auto &column : map) { + auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second); + v.push_back(c); + } + auto v_off = fbb_->CreateVector(v); + auto final_off = CreateSchemaMsg(*fbb_, v_off); + fbb_->Finish(final_off); + buf_ = fbb_->GetBufferPointer(); + len_of_buf_ = fbb_->GetSize(); + return Status::OK(); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } +} + +std::unordered_map FetchSchemaRequest::GetColumnMap() { + if (column_name_id_map_.empty()) { + auto *map_msg = flatbuffers::GetRoot(mem_.GetPointer()); + auto v = map_msg->column(); + for (auto i = 0; i < v->size(); ++i) { + auto col = map_msg->column()->Get(i); + column_name_id_map_.emplace(col->name()->str(), col->id()); + } + } + return column_name_id_map_; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h new file mode 100644 index 0000000000..3d0edc6dd8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h @@ -0,0 +1,225 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#ifndef DATASET_ENGINE_CACHE_REQ_H_ +#define DATASET_ENGINE_CACHE_REQ_H_ + +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +/// \brief CacheClient communicates with CacheServer using Requests. +class BaseRequest { + public: + // Request types + enum class RequestType : int16_t { + kCacheRow = 0, + kBatchFetchRows = 1, + kCreateCache = 2, + kPurgeCache = 3, + kDestroyCache = 4, + kGetStat = 5, + kCacheSchema = 6, + kFetchSchema = 7, + kBuildPhaseDone = 8, + // Add new request before it. + kRequestUnknown = 32767 + }; + // For kCreateCache + enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L }; + friend class CacheServer; + /// \brief Base class of a cache server request + /// \param connection_id A combination of session id and crc that uniquely identifies a connection. + /// \param type Type of the request + explicit BaseRequest(connection_id_type connection_id, RequestType type) + : type_(type), connection_id_(connection_id) {} + virtual ~BaseRequest() = default; + /// \brief Wait for the completion of a request + /// \return Status returned from the cache server + Status Wait() { + RETURN_IF_NOT_OK(wp_.Wait()); + return rc_; + } + + /// \brief Getter function of the current connection id + /// \return Connection id + connection_id_type GetServerConnectionId() const { return connection_id_; } + + private: + RequestType type_; + connection_id_type connection_id_; + Status rc_; + WaitPost wp_; +}; +/// \brief Request to cache a single TensorRow +class CacheRowRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie) + : BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {} + ~CacheRowRequest() = default; + + /// \brief Serialize a TensorRow for streaming to the cache server + /// \param row TensorRow + /// \return Status object + Status SerializeCacheRowRequest(const TensorRow &row); + /// \brief Return the row id assigned to this row for non-mappable dataset + /// \return row id of the cached row + row_id_type GetRowIdAfterCache() { return row_id_from_server_; } + + private: + std::shared_ptr fbb_; + row_id_type row_id_from_server_; + std::vector buffers_; + std::string cookie_; + + /// \brief Private function to serialize one TensorRow + /// \param row TensorRow + /// \return Status object + Status SerializeTensorRowHeader(const TensorRow &row); + /// \brief Private function to serialize one Tensor + /// \param ts_ptr Tensor + /// \return Status object + Status SerializeOneTensorMeta(const std::shared_ptr &ts_ptr, flatbuffers::Offset *out_off); +}; +/// \brief Request to fetch rows in batch +class BatchFetchRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + BatchFetchRequest(connection_id_type connection_id, const std::vector &row_id) + : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {} + Status RestoreRows(TensorTable *out); + + private: + std::vector row_id_; + MemGuard mem_; + Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr *out); +}; +/// \brief Request to create a cache for the current connection +class CreationCacheRequest : public BaseRequest { + 
public: + friend class CacheServer; + /// \brief Constructor + /// \param connection_id + /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited + /// \param flag Attributes of the cache. + explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz, + CreateCacheFlag flag = CreateCacheFlag::kNone) + : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {} + + std::string cookie() const { return cookie_; } + + private: + uint64_t cache_mem_sz; + CreateCacheFlag flag_; + std::string cookie_; +}; +/// \brief Request to purge a cache. +class PurgeCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {} +}; +/// \brief Request to destroy a cache +class DestroyCacheRequest : public BaseRequest { + public: + friend class CacheServer; + explicit DestroyCacheRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kDestroyCache) {} +}; +/// \brief Obtain the statistics of the current connection +class GetStatRequest : public BaseRequest { + public: + friend class CacheServer; + friend class CacheService; + explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {} + row_id_type GetMinRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->min_row_id(); + } + row_id_type GetMaxRowId() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->max_row_id(); + } + int64_t GetNumMemCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_mem_cached(); + } + int64_t GetNumDiskCached() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->num_disk_cached(); + } + uint8_t GetState() const { + auto *msg = flatbuffers::GetRoot(mem_.GetPointer()); + return msg->state(); + } + + private: + MemGuard mem_; +}; +/// \brief Request to cache a schema +class CacheSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit CacheSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {} + ~CacheSchemaRequest() = default; + + Status SerializeCacheSchemaRequest(const std::unordered_map &map); + const void *GetBuffer() const { return buf_; } + + private: + std::shared_ptr fbb_; + const void *buf_; + int64_t len_of_buf_; +}; +/// \brief Request to fetch a schema +class FetchSchemaRequest : public BaseRequest { + public: + friend class CacheServer; + explicit FetchSchemaRequest(connection_id_type connection_id) + : BaseRequest(connection_id, RequestType::kFetchSchema) {} + ~FetchSchemaRequest() = default; + + std::unordered_map GetColumnMap(); + + private: + MemGuard mem_; + std::unordered_map column_name_id_map_; +}; +/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only. 
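+// A note on CreateCacheFlag (declared in BaseRequest above): because it is an
+// enum class, the bitwise composition used by CacheClient::CreateCache(), e.g.
+//   createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
+// relies on operator overloads. A minimal sketch of one such overload is shown
+// below; the real definitions are assumed to live in minddata/dataset/util/bit.h,
+// which is not part of this diff:
+//   inline BaseRequest::CreateCacheFlag &operator|=(BaseRequest::CreateCacheFlag &a,
+//                                                   BaseRequest::CreateCacheFlag b) {
+//     a = static_cast<BaseRequest::CreateCacheFlag>(static_cast<uint32_t>(a) |
+//                                                   static_cast<uint32_t>(b));
+//     return a;
+//   }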
+class BuildPhaseDoneRequest : public BaseRequest { + public: + friend class CacheServer; + BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie) + : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {} + + private: + std::string cookie_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc new file mode 100644 index 0000000000..c9fb6ecab1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc @@ -0,0 +1,252 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/bit.h" + +namespace mindspore { +namespace dataset { +Status CacheServer::DoServiceStart() { + if (!top_.empty()) { + Path spill(top_); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + MS_LOG(INFO) << "CacheServer will use disk folder: " << top_; + } + RETURN_IF_NOT_OK(vg_.ServiceStart()); + cache_q_ = std::make_shared>(1024); + RETURN_IF_NOT_OK(cache_q_->Register(&vg_)); + auto f = std::bind(&CacheServer::ServerRequest, this); + // Spawn a a few threads to serve the request. + for (auto i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f)); + } + return Status::OK(); +} + +Status CacheServer::DoServiceStop() { + Status rc; + Status rc2; + // First stop all the threads. + RETURN_IF_NOT_OK(vg_.ServiceStop()); + // Clean up all the caches if any. 
+ UniqueLock lck(&rwLock_); + auto it = all_caches_.begin(); + while (it != all_caches_.end()) { + auto cs = std::move(it->second); + rc2 = cs->ServiceStop(); + if (rc2.IsError()) { + rc = rc2; + } + ++it; + } + return rc; +} + +CacheService *CacheServer::GetService(connection_id_type id) const { + SharedLock lck(&rwLock_); + auto it = all_caches_.find(id); + if (it != all_caches_.end()) { + return it->second.get(); + } + return nullptr; +} + +Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, + BaseRequest::CreateCacheFlag flag, std::string *out_cookie) { + // We can't do spilling unless this server is setup with a spill path in the first place + bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk; + bool generate_id = + (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId; + if (spill && top_.empty()) { + RETURN_STATUS_UNEXPECTED("Server is not set up with spill support."); + } + RETURN_UNEXPECTED_IF_NULL(out_cookie); + *out_cookie = ""; + // Before creating the cache, first check if this is a request for a shared usage of an existing cache + // If two CreateService come in with identical connection_id, we need to serialize the create. + // The first create will be successful and be given a special cookie. + UniqueLock lck(&rwLock_); + auto end = all_caches_.end(); + auto it = all_caches_.find(connection_id); + if (it == end) { + std::unique_ptr cs; + try { + cs = std::make_unique(cache_mem_sz, spill ? top_ : "", generate_id); + RETURN_IF_NOT_OK(cs->ServiceStart()); + *out_cookie = cs->cookie(); + all_caches_.emplace(connection_id, std::move(cs)); + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } + } else { + MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service"; + // We can return OK but we will return a duplicate key so user can act accordingly to either ignore it + // treat it as OK. + return Status(StatusCode::kDuplicateKey); + } + return Status::OK(); +} + +/// This is the main loop the cache server thread(s) are running. +/// Each thread will pop a request and save the result in the same request. +/// The sender will wait on the wait post in the request. Once the request +/// is fulfilled, the server thread will do a post signalling the request is +/// is processed. +/// \return +Status CacheServer::ServerRequest() { + TaskManager::FindMe()->Post(); + // Loop forever until we are interrupted. + while (true) { + BaseRequest *base_rq = nullptr; + RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq)); + auto cs = GetService(base_rq->connection_id_); + // Except for creating a new session, we expect cs is not null. 
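+    // Illustrative round trip (inferred from the Wait()/Set() pairing in BaseRequest):
+    //   client thread:  cs.PushRequest(&rq);  rq.Wait();             // blocks on wp_
+    //   this thread:    <case below fills base_rq->rc_>;  base_rq->wp_.Set();  // wakes the client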
+ switch (base_rq->type_) { + case BaseRequest::RequestType::kCacheRow: { + if (cs == nullptr) { + std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + // Only if the cookie matches, we can accept insert into this cache that has a build phase + if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) { + rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); + } + } + break; + } + case BaseRequest::RequestType::kBatchFetchRows: { + if (cs == nullptr) { + std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_); + } + break; + } + case BaseRequest::RequestType::kCreateCache: { + // If the cache is already created we still need to run the creation so that we do sanity checks on the + // client id and return the cache id back to the user. + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_); + break; + } + case BaseRequest::RequestType::kPurgeCache: { + if (cs != nullptr) { + base_rq->rc_ = cs->Purge(); + } else { + // it is already purged. Ignore it. + base_rq->rc_ = Status::OK(); + } + break; + } + case BaseRequest::RequestType::kDestroyCache: { + if (cs != nullptr) { + // We need a strong lock to protect the map. + connection_id_type id = base_rq->connection_id_; + UniqueLock lck(&rwLock_); + // std::map will invoke the constructor of CacheService. So we don't need to do anything here. + auto n = all_caches_.erase(id); + if (n == 0) { + // It has been destroyed by another duplicate request. + MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service"; + } + base_rq->rc_ = Status::OK(); + } else { + // it is already destroyed. Ignore it. 
+ base_rq->rc_ = Status::OK(); + } + break; + } + case BaseRequest::RequestType::kGetStat: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + CacheService::ServiceStat svc_stat; + rq->rc_ = cs->GetStat(&svc_stat); + if (rq->rc_.IsOk()) { + flatbuffers::FlatBufferBuilder fbb; + ServiceStatMsgBuilder bld(fbb); + bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached); + bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached); + bld.add_max_row_id(svc_stat.max_); + bld.add_min_row_id(svc_stat.min_); + bld.add_state(svc_stat.state_); + auto offset = bld.Finish(); + fbb.Finish(offset); + rq->rc_ = rq->mem_.allocate(fbb.GetSize()); + if (rq->rc_.IsOk()) { + WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize()); + ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize()); + RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src)); + } + } + } + break; + } + case BaseRequest::RequestType::kCacheSchema: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_); + } + break; + } + case BaseRequest::RequestType::kFetchSchema: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + rq->rc_ = cs->FetchSchema(&rq->mem_); + } + break; + } + case BaseRequest::RequestType::kBuildPhaseDone: { + if (cs == nullptr) { + std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found"; + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg); + } else { + auto *rq = reinterpret_cast(base_rq); + // We can only allow to switch phase is the cookie match. + if (rq->cookie_ == cs->cookie()) { + rq->rc_ = cs->BuildPhaseDone(); + } else { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch"); + } + } + break; + } + default: + base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type"); + } + // Notify it is done, and move on to the next request. + base_rq->wp_.Set(); + } + return Status::OK(); +} +CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers) + : top_(spill_path), num_workers_(num_workers) {} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h new file mode 100644 index 0000000000..13b68c4389 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef DATASET_ENGINE_CACHE_SERVER_H_ +#define DATASET_ENGINE_CACHE_SERVER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +class BaseRequest; +/// \brief A server which provides CacheService services. +class CacheServer : public Service { + public: + friend class Services; + using cache_index = std::map>; + + CacheServer(const CacheServer &) = delete; + CacheServer &operator=(const CacheServer &) = delete; + CacheServer(CacheServer &&) = delete; + CacheServer &operator=(CacheServer &) = delete; + static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); } + Status DoServiceStart() override; + Status DoServiceStop() override; + ~CacheServer() { (void)ServiceStop(); } + + /// \brief For the current demonstration, a cache client contacts cache server using a Queue. + /// \param rq + /// \return Status object + Status PushRequest(BaseRequest *rq) { + RETURN_UNEXPECTED_IF_NULL(rq); + RETURN_IF_NOT_OK(cache_q_->Add(rq)); + return Status::OK(); + } + + private: + mutable RWLock rwLock_; + std::string top_; + cache_index all_caches_; + std::shared_ptr> cache_q_; + TaskGroup vg_; + int32_t num_workers_; + + /// \brief Constructor + /// \param spill_path Top directory for spilling buffers to. + /// \param num_workers Number of threads for handling requests. + explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3); + + /// \brief Locate a cache service from connection id. + /// \return Pointer to cache service. Null if not found + CacheService *GetService(connection_id_type id) const; + + /// \brief Create a cache service. We allow multiple clients to create the same cache service. + /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given + /// a special unique cookie. + /// \param[in] connection_id This is from a Cache client. + /// \param[in] cache_mem_sz + /// \param[in] flag + /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator + /// \return Status object + Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag, + std::string *out_cookie); + + /// \brief Entry point for all server threads. + Status ServerRequest(); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_CORE_CACHE_TENSOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc new file mode 100644 index 0000000000..4e1208d173 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc @@ -0,0 +1,265 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id) + : root_(root), + cache_mem_sz_(mem_sz), + cp_(nullptr), + map_(nullptr), + next_id_(0), + generate_id_(generate_id), + schema_key_(-1), + st_(generate_id ? State::kBuildPhase : State::kNone) {} +CacheService::~CacheService() { (void)ServiceStop(); } +bool CacheService::UseArena() { + // If fixed size, use Arena instead of the pool from global context. + return (cache_mem_sz_ > 0); +} +Status CacheService::DoServiceStart() { + std::shared_ptr mp_; + if (UseArena()) { + // Create a fixed size arena based on the parameter. + std::shared_ptr arena; + RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_)); + mp_ = std::move(arena); + } else { + // Unlimited size. Simply use a system pool. Another choice is CircularPool. + mp_ = std::make_shared(); + } + // Put together a CachePool for backing up the Tensor + cp_ = std::make_shared(CachePool::value_allocator(mp_), root_); + RETURN_IF_NOT_OK(cp_->ServiceStart()); + // Set up the B+ tree as well. But use the system pool instead. + map_ = std::make_shared(); + // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name. + cookie_ = cp_->MyName(); + return Status::OK(); +} +Status CacheService::DoServiceStop() { + if (cp_ != nullptr) { + RETURN_IF_NOT_OK(cp_->ServiceStop()); + } + return Status::OK(); +} +Status CacheService::CacheRow(const std::vector &buf, row_id_type *row_id_generated) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(row_id_generated); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + try { + // The first buffer is a flatbuffer which describes the rest of the buffers follow + auto fb = buf.front(); + RETURN_UNEXPECTED_IF_NULL(fb); + auto msg = GetTensorRowHeaderMsg(fb); + // If the server side is designed to ignore incoming row id, we generate row id. + if (generate_id_) { + *row_id_generated = GetNextRowId(); + // Some debug information on how many rows we have generated so far. + if ((*row_id_generated) % 1000 == 0) { + MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated; + } + } else { + if (msg->row_id() < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + *row_id_generated = msg->row_id(); + } + auto size_of_this = msg->size_of_this(); + auto column_hdr = msg->column(); + // Number of tensor buffer should match the number of columns plus one. + if (buf.size() != column_hdr->size() + 1) { + std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) + + " but get " + std::to_string(buf.size()); + RETURN_STATUS_UNEXPECTED(errMsg); + } + // Next we store in either memory or on disk. 
Low level code will consolidate everything in one piece. + std::vector all_data; + all_data.reserve(column_hdr->size() + 1); + all_data.emplace_back(fb, size_of_this); + for (auto i = 0; i < column_hdr->size(); ++i) { + all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i)); + } + // Now we cache the flat buffer. + CachePool::key_type key; + RETURN_IF_NOT_OK(cp_->Insert(all_data, &key)); + Status rc = map_->DoInsert(*row_id_generated, key); + if (rc == Status(StatusCode::kDuplicateKey)) { + MS_LOG(DEBUG) << "Ignoring duplicate key."; + } else { + RETURN_IF_NOT_OK(rc); + } + return Status::OK(); + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } +} +std::ostream &operator<<(std::ostream &out, const CacheService &cs) { + // Then show any custom derived-internal stuff + out << "\nCache memory size: " << cs.cache_mem_sz_; + out << "\nSpill path: "; + if (cs.root_.empty()) { + out << "None"; + } else { + out << cs.GetSpillPath(); + } + return out; +} +Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); } +Status CacheService::Purge() { + // First we must lock exclusively. No one else can cache/restore anything. + UniqueLock rw(&rw_lock_); + RETURN_IF_NOT_OK(cp_->ServiceStop()); + auto new_map = std::make_shared(); + map_.reset(); + map_ = std::move(new_map); + next_id_ = 0; + RETURN_IF_NOT_OK(cp_->ServiceStart()); + return Status::OK(); +} +Status CacheService::GetStat(CacheService::ServiceStat *out) { + SharedLock rw(&rw_lock_); + RETURN_UNEXPECTED_IF_NULL(out); + if (st_ == State::kNone || st_ == State::kFetchPhase) { + out->stat_ = cp_->GetStat(); + out->state_ = static_cast(st_); + auto it = map_->begin(); + if (it != map_->end()) { + out->min_ = it.key(); + auto end_it = map_->end(); + --end_it; + out->max_ = end_it.key(); + } + } else { + out->state_ = static_cast(st_); + } + return Status::OK(); +} +Status CacheService::BatchFetch(const std::vector &v, MemGuard *out) const { + RETURN_UNEXPECTED_IF_NULL(out); + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + const auto num_elements = v.size(); + int64_t mem_sz = (num_elements + 1) * sizeof(int64_t); + int64_t data_offset = mem_sz; + std::vector sz_v; + std::vector keys; + sz_v.reserve(num_elements); + keys.reserve(num_elements); + for (auto row_id : v) { + auto r = map_->Search(row_id); + if (r.second) { + auto &it = r.first; + CachePool::key_type key = it.value(); + auto sz = cp_->GetSize(key); + if (sz == 0) { + std::string errMsg = "Key not found: "; + errMsg += std::to_string(key); + RETURN_STATUS_UNEXPECTED(errMsg); + } + keys.push_back(key); + sz_v.push_back(sz); + mem_sz += sz; + } else { + keys.push_back(-1); + sz_v.push_back(0); + } + } + MemGuard mem; + RETURN_IF_NOT_OK(mem.allocate(mem_sz)); + auto *offset_array = reinterpret_cast(mem.GetMutablePointer()); + offset_array[0] = data_offset; + WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes()); + for (auto i = 0; i < num_elements; ++i) { + auto sz = sz_v.at(i); + offset_array[i + 1] = offset_array[i] + sz; + if (sz > 0) { + WritableSlice row_data(all, offset_array[i], sz); + auto key = keys.at(i); + size_t bytesRead = 0; + RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead)); + if (bytesRead != sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "." 
+ << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details."); + } + } + } + *out = std::move(mem); + return Status::OK(); +} +Status CacheService::CacheSchema(const void *buf, int64_t len) { + SharedLock rw(&rw_lock_); + if (st_ == State::kFetchPhase) { + // For this kind of cache service, once we are done with the build phase into fetch phase, we can't + // allow other to cache more rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + // This is a special request and we need to remember where we store it. + // In case we are calling the same function from multiple threads, only + // the first one is considered. Rest is ignored. + CachePool::key_type cur_key = schema_key_; + CachePool::key_type key; + if (cur_key < 0) { + RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key)); + auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key); + MS_LOG(DEBUG) << "Caching Schema. Result = " << result; + } else { + MS_LOG(DEBUG) << "Caching Schema already done"; + } + return Status::OK(); +} +Status CacheService::FetchSchema(MemGuard *out) const { + SharedLock rw(&rw_lock_); + if (st_ == State::kBuildPhase) { + // For this kind of cache service, we can't fetch yet until we are done with caching all the rows. + RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase"); + } + RETURN_UNEXPECTED_IF_NULL(out); + MemGuard mem; + if (schema_key_ >= 0) { + auto len = cp_->GetSize(schema_key_); + RETURN_IF_NOT_OK(mem.allocate(len)); + auto slice = WritableSlice(mem.GetMutablePointer(), len); + RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice)); + *out = std::move(mem); + } else { + return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached"); + } + return Status::OK(); +} +Status CacheService::BuildPhaseDone() { + if (HasBuildPhase()) { + // Exclusive lock to switch phase + UniqueLock rw(&rw_lock_); + st_ = State::kFetchPhase; + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase"); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h new file mode 100644 index 0000000000..bf324e82e3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +#ifndef DATASET_ENGINE_CACHE_SERVICE_H_ +#define DATASET_ENGINE_CACHE_SERVICE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "./de_tensor_generated.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/cache/cache_request.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +struct CacheStat; +/// \brief A cache service for storing/fetching buffers to in memory cache and may spill to disk the cache service is +/// created to support spilling +class CacheService : public Service { + public: + friend class CacheServer; + using row_map = BPlusTree; + + enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase }; + + /// \brief Constructor + /// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited + /// \param root Spill path. Empty string means no spilling + /// \param generate_id If the cache service should generate row id for buffer that is cached. + /// For non-mappable dataset, this should be set to true. + CacheService(uint64_t mem_sz, const std::string &root, bool generate_id); + ~CacheService(); + + /// \brief For fixed size memory, we will create an Arena. + /// \return false if unlimited memory. + bool UseArena(); + + Status DoServiceStart() override; + Status DoServiceStop() override; + + /// \brief Main function to cache a row which is in form a series of buffers. + /// The first buffer is a Google flatbuffer which describes the rest of the buffers followed. + /// \param[in] buf Vector of buffer + /// \param[out] row_id_generated The row id assigned to this row if any + /// \return Status object + Status CacheRow(const std::vector &buf, row_id_type *row_id_generated); + /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded + /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row. + /// \param[in] v A vector of row id. + /// \param[out] out A contiguous memory buffer that holds the requested rows. + /// \return Status object + Status BatchFetch(const std::vector &v, MemGuard *out) const; + + /// \brief Getter function + /// \return Spilling path + Path GetSpillPath() const; + /// \brief A structure returned from the cache server for statistics request. + class ServiceStat { + public: + using state_type = std::underlying_type::type; + ServiceStat() : min_(0), max_(0), state_(0) {} + CachePool::CacheStat stat_{}; + row_id_type min_; + row_id_type max_; + state_type state_; + }; + /// \brief Statistics for the current service + /// \param[in/out] A pointer to a pre-allocated ServiceStat structure + /// \return Status Object + Status GetStat(ServiceStat *); + /// \brief Cache schema + /// \param buf A Google Flatbuffer that contains the schema + /// \param len size of the buffer + /// \return Status object + Status CacheSchema(const void *buf, int64_t len); + /// \brief Fetch schema + /// \param out A contiguous memory that contains the serialized form of schema. 
+ /// \return Status object + Status FetchSchema(MemGuard *out) const; + /// \brief Purge the content of a cache + /// \return Status object + Status Purge(); + /// \brief Overload the << operator to print a cache service + /// \param out std::ostream + /// \param cs A cache service + /// \return std::ostream + friend std::ostream &operator<<(std::ostream &out, const CacheService &cs); + /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient + /// is the creator + /// \return Cookie + std::string cookie() const { return cookie_; } + /// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and + /// a read phase. + /// \return True if has two phases. + bool HasBuildPhase() const { return generate_id_; } + /// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call. + /// \return Status object + Status BuildPhaseDone(); + + private: + mutable RWLock rw_lock_; + std::string root_; + uint64_t cache_mem_sz_; + std::shared_ptr cp_; + std::shared_ptr map_; + std::atomic next_id_; + bool generate_id_; + std::atomic schema_key_; + std::string cookie_; + State st_; + + /// \brief Private function to generate a row id + /// \return Row id assigned. + row_id_type GetNextRowId() { return next_id_.fetch_add(1); } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_CACHE_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs b/mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs similarity index 100% rename from mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs rename to mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs diff --git a/mindspore/ccsrc/minddata/dataset/engine/connector.h b/mindspore/ccsrc/minddata/dataset/engine/connector.h new file mode 100644 index 0000000000..a91d8e68e9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/connector.h @@ -0,0 +1,211 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_CONNECTOR_H_ +#define DATASET_ENGINE_CONNECTOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/cond_var.h" + +namespace mindspore { +namespace dataset { +// Connector is a communication data structure between two group of threads that +// preserve the order. +// +// Example use case: +// An initial tasks-list of [1,2,3,4,5,6,7,8,9] with 5 threads getting/processing elements from that list, +// and pushing the processed elements to a Connector in any order whoever finishes processing first. +// If the consumer of the Connector is single threaded, when the consumer pop() the +// element from the Connector one by one, it will get [1,2,3,4,5,6,7,8,9]. +// +// Requirements: +// 1. 
Each thread in the group of consumer or producer threads must be assigned ids starting from 0.
+// 2. If your multi-threaded program is not reading from a Connector but
+//    wants to push to one, you must follow a round-robin element distribution:
+//    thread id 0 holds the first element, thread id 1 the second element,
+//    and so on; each of these workers can then push to the Connector asynchronously in parallel.
+//
+// Blocking conditions:
+// 1. Connector.Push(int, T) can block when the internal queue it's trying to push to is full.
+// 2. Connector.Pop(int) can block when
+//    - The internal queue it's trying to pop from is empty.
+//    - The caller thread of Pop() is not equal to expect_consumer_. This is to enforce
+//      the ordering.
+//
+// Future improvement:
+// 1. Fault tolerance: Right now, if one of the workers dies, the Connector will not work
+//    properly.
+template <class T>
+class Connector {
+ public:
+  // Constructor of Connector.
+  // Initializes private members with the given input arguments.
+  // expect_consumer_ and pop_from_ are initialized to 0 as part of
+  // our requirements. We instantiate n_producers internal
+  // queues so that each producer thread can push to its queue without
+  // any sync overhead.
+  // @param n_producers The number of threads producing data into this Connector.
+  // @param n_consumers The number of threads consuming data from this Connector.
+  // @param queue_capacity The number of elements (e.g. DataBuffers) each queue can hold.
+  Connector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
+      : num_producers_(n_producers), num_consumers_(n_consumers) {
+    MS_LOG(DEBUG) << "A connector is created with " << n_producers << " producers and " << n_consumers
+                  << " consumers.";
+    my_name_ = Services::GetUniqueID();
+    // We require the consumers to have ids sequentially from 0 to num_consumers_-1;
+    // otherwise an ordered list of consumer ids would have to be passed here. (not implemented yet)
+    expect_consumer_ = 0;
+
+    // Round-robin pop starts from index 0 of the queues_.
+    pop_from_ = 0;
+
+    // Initialize the queues_ to have num_producers_ queues.
+    // Each queue is a blocking queue and has the same queue_capacity.
+    queues_.Init(num_producers_, queue_capacity);
+  }
+
+  // Destructor of Connector
+  virtual ~Connector() = default;
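+
+  // A minimal usage sketch (illustrative only; the thread bodies are hypothetical
+  // and error handling is elided). Producer thread i (i = 0..3) pushes elements
+  // i, i+4, i+8, ... in that order; the single consumer then pops the elements
+  // back in the original global order:
+  //
+  //   Connector<int32_t> con(4 /* producers */, 1 /* consumers */, 8 /* capacity */);
+  //   con.Push(i, next_processed_element);  // in producer thread i
+  //   int32_t v;
+  //   con.Pop(0, &v);                       // in the consumer thread; yields 0, 1, 2, ...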
+
+  // Get an element from the Connector.
+  // @note A call to Pop() can block the caller thread; see the blocking conditions at the top of this file.
+  // @param worker_id The id of a worker thread calling this method.
+  // @param result The address of an object where the popped element will be placed.
+  virtual Status Pop(int32_t worker_id,  // The worker-id of the caller. See the requirement at the top of this file.
+                     T *result) noexcept {
+    {
+      MS_ASSERT(worker_id < num_consumers_);
+      std::unique_lock<std::mutex> lk(m_);
+      RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return expect_consumer_ == worker_id; }));
+      RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
+      pop_from_ = (pop_from_ + 1) % num_producers_;
+      out_buffers_count_++;
+      expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
+    }
+
+    cv_.NotifyAll();
+    return Status::OK();
+  }
+
+  // Add an element into the Connector without the overhead of synchronization.
+  // It may block when the internal queue is full.
+  // The element passed to this function will be copied into the internal queue.
+  // @param worker_id The id of a worker thread calling this method.
+  // @param el A const lvalue element to be passed/added/pushed.
+  Status Push(int32_t worker_id, const T &el) noexcept {
+    MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
+    MS_ASSERT(queues_[worker_id] != nullptr);
+    return (queues_[worker_id]->Add(el));
+  }
+
+  auto out_buffers_count() const { return out_buffers_count_.load(); }
+
+  // Add an element into the Connector without the overhead of synchronization.
+  // It may block when the internal queue is full.
+  // The element passed to this function will be forwarded into the internal queue.
+  // @param worker_id The id of a worker thread calling this method.
+  // @param el An element to be passed/added/pushed.
+  virtual Status Push(int32_t worker_id, T &&el) noexcept {
+    MS_ASSERT(worker_id < static_cast<int32_t>(queues_.size()));
+    MS_ASSERT(queues_[worker_id] != nullptr);
+    return (queues_[worker_id]->Add(std::forward<T>(el)));
+  }
+
+  // Resets the internal index tracking of the queue so that it can be used again with new inputs,
+  // starting from the beginning.
+  void Reset() {
+    for (size_t i = 0; i < queues_.size(); ++i) {
+      queues_[i]->ResetQue();
+    }
+    expect_consumer_ = 0;
+    pop_from_ = 0;
+    out_buffers_count_ = 0;
+    MS_LOG(DEBUG) << "Connector counters reset.";
+  }
+
+  void Print(std::ostream &out, bool showAll) const {
+    out << "\n--------- Connector ------------"
+        << "\nConnector Name : " << my_name_ << "\nNumber of consumers : " << num_consumers_
+        << "\nNumber of producers : " << num_producers_ << "\n";
+  }
+
+  friend std::ostream &operator<<(std::ostream &out, const Connector &con) {
+    con.Print(out, false);
+    return out;
+  }
+
+  // Get current size of connector.
+  int32_t size() const {
+    int32_t size = 0;
+    for (size_t i = 0; i < queues_.size(); ++i) {
+      size += queues_[i]->size();
+    }
+    return size;
+  }
+
+  int32_t capacity() const {
+    int32_t capacity = 0;
+    for (size_t i = 0; i < queues_.size(); ++i) {
+      capacity += queues_[i]->capacity();
+    }
+    return capacity;
+  }
+
+  // Register the internal resources with the Task group for the interruption service.
+  // @param vg The TaskGroup to register with.
+  // @return Status object.
+  Status Register(TaskGroup *vg) {
+    Status rc = queues_.Register(vg);
+    if (rc.IsOk()) {
+      rc = cv_.Register(vg->GetIntrpService());
+    }
+    return rc;
+  }
+
+ protected:
+  std::string my_name_;
+
+  // A list of Queues that are thread safe.
+  QueueList<T> queues_;
+
+  // The consumer that we allow to get the next data from Pop().
+  int32_t expect_consumer_;
+
+  // The index into queues_ from which the next data should be popped.
+  int32_t pop_from_;
+
+  int32_t num_producers_;
+  int32_t num_consumers_;
+
+  // Used in Pop(): a thread that calls Pop() while it is not the expect_consumer_
+  // waits on this condition variable until its turn comes.
+  std::mutex m_;
+  CondVar cv_;
+  std::atomic<std::int64_t> out_buffers_count_ = 0;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_CONNECTOR_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc
new file mode 100644
index 0000000000..b36aae6837
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.cc
@@ -0,0 +1,89 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/util/allocator.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/core/tensor.h"
+
+namespace mindspore {
+namespace dataset {
+// Name: Constructor #1
+// Description: This is the main constructor that is used for making a buffer
+DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
+
+// A method for debug printing of the buffer
+void DataBuffer::Print(std::ostream &out, bool show_all) const {
+  out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
+
+  // If the column counts are set then it means that data has been set into
+  // the tensor table. Display the tensor table here.
+  if (this->NumCols() > 0) {
+    out << "Tensor table:\n";
+    for (int32_t row = 0; row < DataBuffer::NumRows(); ++row) {
+      out << "Row # : " << row << "\n";
+      TensorRow currRow = (*tensor_table_)[row];
+      for (int32_t col = 0; col < this->NumCols(); ++col) {
+        out << "Column #: " << col << "\n";  // Should we add the column name here as well?
+        // Call the tensor display
+        out << *(currRow[col]) << "\n";
+      }
+    }
+  }
+}
+
+// Remove me!! Callers should fetch rows via pop
+Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
+  if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
+    *ptr = (tensor_table_->at(row_id)).at(col_id);
+  } else {
+    std::string err_msg =
+      "indices for tensor table out of range: (" + std::to_string(row_id) + "," + std::to_string(col_id) + ").";
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// Remove me!!
Callers should fetch rows via pop +Status DataBuffer::GetRow(int32_t row_id, TensorRow *ptr) const { + if (tensor_table_ && !tensor_table_->empty() && row_id < tensor_table_->size()) { + *ptr = tensor_table_->at(row_id); + } else { + std::string err_msg = "rowId for mTensorTable out of range: " + std::to_string(row_id); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +Status DataBuffer::PopRow(TensorRow *ptr) { + if (tensor_table_ && !tensor_table_->empty()) { + *ptr = std::move(tensor_table_->front()); + tensor_table_->pop_front(); + } + + return Status::OK(); +} + +Status DataBuffer::SliceOff(int64_t number_of_rows) { + while (number_of_rows > 0) { + tensor_table_->pop_back(); + number_of_rows--; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h new file mode 100644 index 0000000000..5fcb4c21a5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_buffer.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATA_BUFFER_H_ +#define DATASET_ENGINE_DATA_BUFFER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between +/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format +/// where n TensorRows may consist of m tensors (columns). 
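+///
+/// A minimal construction/draining sketch (illustrative only; `t1` and `t2` are
+/// assumed to be previously built std::shared_ptr<Tensor> values, and TensorRow is
+/// assumed constructible from an initializer list):
+///
+///   auto table = std::make_unique<TensorTable>();
+///   table->push_back(TensorRow{t1, t2});         // one row with two columns
+///   DataBuffer buf(0, DataBuffer::kDeBFlagNone);
+///   buf.set_tensor_table(std::move(table));
+///   TensorRow row;
+///   buf.PopRow(&row);                            // rows are drained one at a time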
+class DataBuffer {
+ public:
+  // Buffer flags
+  enum BufferFlags : uint32_t {
+    kDeBFlagNone = 0,
+    kDeBFlagEOF = 1,       // The buffer is an eof end-of-data msg
+    kDeBFlagEOE = 1u << 1  // The buffer is an eoe end-of-epoch msg
+  };
+
+  // Name: Constructor #1
+  // Description: This is the main constructor that is used for making a buffer
+  DataBuffer(int32_t id, BufferFlags flags);
+
+  /// \brief default destructor
+  ~DataBuffer() = default;
+
+  /// \brief A method for debug printing of the buffer
+  /// \param[inout] out The stream to write to
+  /// \param[in] show_all A boolean to toggle between details and summary printing
+  void Print(std::ostream &out, bool show_all) const;
+
+  // Provide stream operator for displaying it
+  friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
+    cb.Print(out, false);
+    return out;
+  }
+
+  // Convenience getter functions for flag checking
+  bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
+
+  bool eoe() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOE)); }
+
+  // Simple getter funcs
+  int32_t id() const { return buffer_id_; }
+
+  void set_id(int32_t id) { buffer_id_ = id; }
+
+  int32_t NumRows() const { return ((tensor_table_) ? tensor_table_->size() : 0); }
+
+  int32_t NumCols() const {
+    return (tensor_table_ == nullptr || tensor_table_->empty()) ? 0 : tensor_table_->at(0).size();
+  }
+
+  BufferFlags buffer_flags() const { return buffer_flags_; }
+
+  // Remove me!! Callers should fetch rows via pop
+  Status GetTensor(std::shared_ptr<Tensor> *, int32_t row_id, int32_t col_id) const;
+
+  // Remove me!! Callers should drain rows via pop.
+  Status GetRow(int32_t row_id, TensorRow *) const;
+
+  // Get a row from the TensorTable
+  Status PopRow(TensorRow *);
+
+  Status SliceOff(int64_t number_of_rows);
+
+  // Replaces the tensor table; the unique_ptr assignment will release the old TensorTable.
+  void set_tensor_table(std::unique_ptr<TensorTable> new_table) { tensor_table_ = std::move(new_table); }
+
+  void set_flag(BufferFlags in_flag) {
+    buffer_flags_ = static_cast<BufferFlags>(static_cast<uint32_t>(buffer_flags_) | static_cast<uint32_t>(in_flag));
+  }
+
+  void Shuffle() {}  // does nothing right now. possibly remove later
+
+ protected:
+  int32_t buffer_id_;                          // An id for the buffer.
+  std::unique_ptr<TensorTable> tensor_table_;  // A table (row major) of Tensors
+  BufferFlags buffer_flags_;                   // bit mask for various buffer properties
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATA_BUFFER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc
new file mode 100644
index 0000000000..50d910251d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.cc
@@ -0,0 +1,451 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "minddata/dataset/engine/data_schema.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +// A macro for converting an input string representing the column type to it's actual +// numeric column type. +#define STR_TO_TENSORIMPL(in_col_str, out_type) \ + do { \ + if (in_col_str == "cvmat") { \ + out_type = TensorImpl::kCv; \ + } else if (in_col_str == "flex") { \ + out_type = TensorImpl::kFlexible; \ + } else if (in_col_str == "np") { \ + out_type = TensorImpl::kNP; \ + } else { \ + out_type = TensorImpl::kNone; \ + } \ + } while (false) + +// Constructor 1: Simple constructor that leaves things uninitialized. +ColDescriptor::ColDescriptor() + : type_(DataType::DE_UNKNOWN), rank_(0), tensor_impl_(TensorImpl::kNone), tensor_shape_(nullptr) {} + +// Constructor 2: Main constructor +ColDescriptor::ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, + const TensorShape *in_shape) + : type_(col_type), rank_(rank), tensor_impl_(tensor_impl), col_name_(col_name) { + // If a shape was provided, create unique pointer for it and copy construct it into + // our shape. Otherwise, set our shape to be empty. + if (in_shape != nullptr) { + // Create a shape and copy construct it into our column's shape. + tensor_shape_ = std::make_unique(*in_shape); + } else { + tensor_shape_ = nullptr; + } + // If the user input a shape, then the rank of the input shape needs to match + // the input rank + if (in_shape != nullptr && in_shape->known() && in_shape->Size() != rank_) { + rank_ = in_shape->Size(); + MS_LOG(WARNING) << "Rank does not match the number of dimensions in the provided shape." + << " Overriding rank with the number of dimensions in the provided shape."; + } +} + +// Explicit copy constructor is required +ColDescriptor::ColDescriptor(const ColDescriptor &in_cd) + : type_(in_cd.type_), rank_(in_cd.rank_), tensor_impl_(in_cd.tensor_impl_), col_name_(in_cd.col_name_) { + // If it has a tensor shape, make a copy of it with our own unique_ptr. + tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; +} + +// Assignment overload +ColDescriptor &ColDescriptor::operator=(const ColDescriptor &in_cd) { + if (&in_cd != this) { + type_ = in_cd.type_; + rank_ = in_cd.rank_; + tensor_impl_ = in_cd.tensor_impl_; + col_name_ = in_cd.col_name_; + // If it has a tensor shape, make a copy of it with our own unique_ptr. + tensor_shape_ = in_cd.hasShape() ? std::make_unique(in_cd.shape()) : nullptr; + } + return *this; +} + +// Destructor +ColDescriptor::~ColDescriptor() = default; + +// A print method typically used for debugging +void ColDescriptor::Print(std::ostream &out) const { + out << " Name : " << col_name_ << "\n Type : " << type_ << "\n Rank : " << rank_ + << "\n Shape : ("; + if (tensor_shape_) { + out << *tensor_shape_ << ")\n"; + } else { + out << "no shape provided)\n"; + } +} + +// Given a number of elements, this function will compute what the actual Tensor shape would be. +// If there is no starting TensorShape in this column, or if there is a shape but it contains +// an unknown dimension, then the output shape returned shall resolve dimensions as needed. 
+Status ColDescriptor::MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const {
+  if (out_shape == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Unexpected null output shape argument.");
+  }
+
+  // If no shape is given in this column, then we assume the shape will be: {num_elements}
+  if (tensor_shape_ == nullptr) {
+    if (this->rank() == 0 && num_elements == 1) {
+      *out_shape = TensorShape::CreateScalar();
+      return Status::OK();
+    }
+    *out_shape = TensorShape({num_elements});
+    return Status::OK();
+  }
+
+  // Build the real TensorShape based on the requested shape and the number of elements in the data.
+  // If there are unknown dimensions, then the unknown dimension needs to be filled in.
+  // Example: requested shape: {?,4,3}.
+  //          If num_elements is 24, then the output shape can be computed to: {2,4,3}
+  std::vector<dsize_t> requested_shape = tensor_shape_->AsVector();
+  int64_t num_elements_of_shape = 1;  // init to 1 as a starting multiplier.
+
+  // The unknown_dim_position variable is overloaded to carry 2 meanings:
+  // 1) If it remains set to kDimUnknown, it tells us there are no unknown dimensions
+  //    in the requested shape.
+  // 2) If it's set to a numeric value, then this is the vector index position within the shape
+  //    where the single unknown dimension can be found.
+  int64_t unknown_dim_position = TensorShape::kDimUnknown;  // Assume there are no unknown dims to start
+
+  for (size_t i = 0; i < requested_shape.size(); ++i) {
+    // If we already had an unknown dimension, then we cannot have a second unknown dimension.
+    // We only support computing a single unknown dim.
+    if (requested_shape[i] == TensorShape::kDimUnknown && unknown_dim_position != TensorShape::kDimUnknown) {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                    "Requested shape has more than one unknown dimension!");
+    }
+
+    // If the current dimension in the requested shape is a known value, then compute the number of
+    // elements so far.
+    if (requested_shape[i] != TensorShape::kDimUnknown) {
+      num_elements_of_shape *= requested_shape[i];
+    } else {
+      // This dimension is unknown so track which dimension position has it.
+      unknown_dim_position = i;
+    }
+  }
+
+  // Sanity check that the computed element count divides evenly into the input element count
+  if (num_elements < num_elements_of_shape || num_elements_of_shape == 0 || num_elements % num_elements_of_shape != 0) {
+    RETURN_STATUS_UNEXPECTED("Requested shape has an invalid element count!");
+  }
+
+  // If there was an unknown dimension, then update the requested shape to fill in the unknown
+  // dimension with the correct value. If there were no unknown dims, then the output shape
+  // remains the same as the requested shape.
+  if (unknown_dim_position != TensorShape::kDimUnknown) {
+    requested_shape[unknown_dim_position] = (num_elements / num_elements_of_shape);
+  }
+
+  // Any unknown dimension is filled in now. Set the output shape
+  *out_shape = TensorShape(requested_shape);
+  return Status::OK();
+}
+
+// Getter function for the shape
+TensorShape ColDescriptor::shape() const {
+  if (tensor_shape_ != nullptr) {
+    return *tensor_shape_;  // copy construct a shape to return
+  } else {
+    return TensorShape::CreateUnknownRankShape();  // empty shape to return
+  }
+}
+
+const char DataSchema::DEFAULT_DATA_SCHEMA_FILENAME[] = "datasetSchema.json";
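+
+// For reference, the two "columns" layouts accepted by AnyOrderLoad/ColumnOrderLoad
+// below look roughly like this (a hypothetical, trimmed-down schema):
+//
+//   { "numRows": 10,
+//     "columns": { "image": {"type": "uint8", "rank": 1},
+//                  "label": {"type": "int32", "rank": 1, "shape": [1]} } }
+//
+//   { "numRows": 10,
+//     "columns": [ {"name": "image", "type": "uint8", "rank": 1},
+//                  {"name": "label", "type": "int32", "rank": 1, "shape": [1]} ] }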
+// Constructor 1: Simple constructor that leaves things uninitialized.
+DataSchema::DataSchema() : num_rows_(0) {}
+
+// Internal helper function. Parses the json schema file in any order and produces a schema that
+// does not follow any particular order (the json standard does not enforce any ordering protocol).
+// This one produces a schema that contains all of the columns from the schema file.
+Status DataSchema::AnyOrderLoad(nlohmann::json column_tree) {
+  // Iterate over the json file. Each parent json node is the column name,
+  // followed by the column properties in the child tree under the column.
+  // The outer loop here iterates over the parents (i.e. the column names).
+  if (!column_tree.is_array()) {
+    for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) {
+      std::string col_name = it.key();
+      nlohmann::json column_child_tree = it.value();
+      RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name));
+    }
+  } else {
+    // Case where the schema is a list of columns, not a dict
+    for (nlohmann::json::iterator it = column_tree.begin(); it != column_tree.end(); ++it) {
+      nlohmann::json column_child_tree = it.value();
+      RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, ""));
+    }
+  }
+  return Status::OK();
+}
+
+// Internal helper function. For each input column name, perform a lookup in the json document to
+// find the matching column. When the match is found, process that column to build the column
+// descriptor and add it to the schema in the order in which the input column names are given.
+Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector<std::string> &columns_to_load) {
+  if (!column_tree.is_array()) {
+    // the json file is a dict (e.g., {image: ...})
+    // Loop over the column name list
+    for (const auto &curr_col_name : columns_to_load) {
+      // Find the column in the json document
+      auto column_info = column_tree.find(common::SafeCStr(curr_col_name));
+      if (column_info == column_tree.end()) {
+        RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name);
+      }
+      // At this point, column_info.value() is the subtree in the json document that contains
+      // all of the data for a given column. This data will form our schema column.
+      const std::string &col_name = column_info.key();
+      nlohmann::json column_child_tree = column_info.value();
+      RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, col_name));
+    }
+  } else {
+    // the json file is an array (e.g., [{name: image, ...}, ...])
+    // Loop over the column name list
+    for (const auto &curr_col_name : columns_to_load) {
+      // Find the column in the json document
+      int32_t index = -1;
+      int32_t i = 0;
+      for (const auto &it_child : column_tree.items()) {
+        auto name = it_child.value().find("name");
+        if (name == it_child.value().end()) {
+          RETURN_STATUS_UNEXPECTED("Name field is missing for this column.");
+        }
+        if (name.value() == curr_col_name) {
+          index = i;
+          break;
+        }
+        i++;
+      }
+      if (index == -1) {
+        RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name);
+      }
+      nlohmann::json column_child_tree = column_tree[index];
+      RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, curr_col_name));
+    }
+  }
+  return Status::OK();
+}
+
+// Internal helper function for parsing shape info and building a vector for the shape construction.
+static Status buildShape(const nlohmann::json &shapeVal, std::vector *outShape) { + if (outShape == nullptr) { + RETURN_STATUS_UNEXPECTED("null output shape"); + } + if (shapeVal.empty()) return Status::OK(); + + // Iterate over the integer list and add those values to the output shape tensor + auto items = shapeVal.items(); + using it_type = decltype(items.begin()); + (void)std::transform(items.begin(), items.end(), std::back_inserter(*outShape), [](it_type j) { return j.value(); }); + return Status::OK(); +} + +// Internal helper function. Given the json tree for a given column, load it into our schema. +Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name) { + int32_t rank_value = -1; + TensorImpl t_impl_value = TensorImpl::kFlexible; + std::string name, type_str; + std::vector tmp_shape = {}; + bool shape_field_exists = false; + // Iterate over this column's attributes. + // Manually iterating each of the child nodes/trees here so that we can provide our own error handling. + for (const auto &it_child : column_child_tree.items()) { + // Save the data for each of the attributes into variables. We'll use these to construct later. + if (it_child.key() == "name") { + name = it_child.value(); + } else if (it_child.key() == "type") { + type_str = it_child.value(); + } else if (it_child.key() == "rank") { + rank_value = it_child.value(); + } else if (it_child.key() == "t_impl") { + STR_TO_TENSORIMPL(it_child.value(), t_impl_value); + } else if (it_child.key() == "shape") { + shape_field_exists = true; + RETURN_IF_NOT_OK(buildShape(it_child.value(), &tmp_shape)); + } else { + std::string err_msg = "Unexpected column attribute " + it_child.key() + " for column " + col_name; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + if (!name.empty()) { + if (!col_name.empty() && col_name != name) { + std::string err_msg = + "json schema file for column " + col_name + " has column name that does not match columnsToLoad"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } else { + if (col_name.empty()) { + std::string err_msg = "json schema file for column " + col_name + " has invalid or missing column name."; + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + name = col_name; + } + } + // data type is mandatory field + if (type_str.empty()) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "json schema file for column " + col_name + " has invalid or missing column type."); + + // rank number is mandatory field + if (rank_value <= -1) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "json schema file for column " + col_name + " must define a positive rank value."); + + // Create the column descriptor for this column from the data we pulled from the json file + TensorShape col_shape = TensorShape(tmp_shape); + if (shape_field_exists) + (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value, &col_shape)); + else + // Create a column descriptor that doesn't have a shape + (void)this->AddColumn(ColDescriptor(name, DataType(type_str), t_impl_value, rank_value)); + return Status::OK(); +} + +// Parses a schema json file and populates the columns and meta info. 
+Status DataSchema::LoadSchemaFile(const std::string &schema_file_path, + const std::vector &columns_to_load) { + try { + std::ifstream in(schema_file_path); + + nlohmann::json js; + in >> js; + RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); + try { + num_rows_ = js.at("numRows").get(); + } catch (nlohmann::json::out_of_range &e) { + num_rows_ = 0; + } catch (nlohmann::json::exception &e) { + RETURN_STATUS_UNEXPECTED("Unable to parse \"numRows\" from schema"); + } + nlohmann::json column_tree = js.at("columns"); + if (column_tree.empty()) { + RETURN_STATUS_UNEXPECTED("columns is null"); + } + if (columns_to_load.empty()) { + // Parse the json tree and load the schema's columns in whatever order that the json + // layout decides + RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); + } else { + RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Schema file failed to load"); + } + return Status::OK(); +} + +// Parses a schema json string and populates the columns and meta info. +Status DataSchema::LoadSchemaString(const std::string &schema_json_string, + const std::vector &columns_to_load) { + try { + nlohmann::json js = nlohmann::json::parse(schema_json_string); + RETURN_IF_NOT_OK(PreLoadExceptionCheck(js)); + num_rows_ = js.value("numRows", 0); + nlohmann::json column_tree = js.at("columns"); + if (column_tree.empty()) { + RETURN_STATUS_UNEXPECTED("columns is null"); + } + if (columns_to_load.empty()) { + // Parse the json tree and load the schema's columns in whatever order that the json + // layout decides + RETURN_IF_NOT_OK(this->AnyOrderLoad(column_tree)); + } else { + RETURN_IF_NOT_OK(this->ColumnOrderLoad(column_tree, columns_to_load)); + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Schema file failed to load"); + } + return Status::OK(); +} + +// Destructor +DataSchema::~DataSchema() = default; + +// Getter for the ColDescriptor by index +const ColDescriptor &DataSchema::column(int32_t idx) const { + MS_ASSERT(idx < static_cast(col_descs_.size())); + return col_descs_[idx]; +} + +// A print method typically used for debugging +void DataSchema::Print(std::ostream &out) const { + out << "Dataset schema: ("; + for (const auto &col_desc : col_descs_) { + out << col_desc << "\n"; + } +} + +// Adds a column descriptor to the schema +Status DataSchema::AddColumn(const ColDescriptor &cd) { + // Sanity check there's not a duplicate name before adding the column + for (int32_t i = 0; i < col_descs_.size(); ++i) { + if (col_descs_[i].name() == cd.name()) { + std::ostringstream ss; + ss << "column name '" << cd.name() << "' already exists in schema."; + std::string err_msg = ss.str(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + col_descs_.push_back(cd); + return Status::OK(); +} + +// Internal helper function. Performs sanity checks on the json file setup. +Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) { + // Check if columns node exists. It is required for building schema from file. + if (js.find("columns") == js.end()) + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "\"columns\" node is required in the schema json file."); + return Status::OK(); +} + +// Loops through all columns in the schema and returns a map with the column +// name to column index number. 
+Status DataSchema::GetColumnNameMap(std::unordered_map *out_column_name_map) { + if (out_column_name_map == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "unexpected null output column name map."); + } + + for (int32_t i = 0; i < col_descs_.size(); ++i) { + if (col_descs_[i].name().empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Constructing column name map from schema, but found empty column name."); + } + (*out_column_name_map)[col_descs_[i].name()] = i; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/data_schema.h b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h new file mode 100644 index 0000000000..96f6f2b118 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/data_schema.h @@ -0,0 +1,208 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATA_SCHEMA_H_ +#define DATASET_ENGINE_DATA_SCHEMA_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +/// \class ColDescriptor data_schema.h +/// \brief A simple class to provide meta info about a column. +class ColDescriptor { + public: + /// \brief Constructor 1: Simple constructor that leaves things uninitialized. + ColDescriptor(); + + /// \brief Constructor 2: Main constructor + /// \param[in] col_name - The name of the column + /// \param[in] col_type - The DE Datatype of the column + /// \param[in] tensor_impl - The (initial) type of tensor implementation for the column + /// \param[in] rank - The number of dimension of the data + /// \param[in] in_shape - option argument for input shape + ColDescriptor(const std::string &col_name, DataType col_type, TensorImpl tensor_impl, int32_t rank, + const TensorShape *in_shape = nullptr); + + /// \brief Explicit copy constructor is required + /// \param[in] in_cd - the source ColDescriptor + ColDescriptor(const ColDescriptor &in_cd); + + /// \brief Assignment overload + /// \param in_cd - the source ColDescriptor + ColDescriptor &operator=(const ColDescriptor &in_cd); + + /// \brief Destructor + ~ColDescriptor(); + + /// \brief A print method typically used for debugging + /// \param out - The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief Given a number of elements, this function will compute what the actual Tensor shape would be. + /// If there is no starting TensorShape in this column, or if there is a shape but it contains + /// an unknown dimension, then the output shape returned shall resolve dimensions as needed. 
+ /// \param[in] num_elements - The number of elements in the data for a Tensor + /// \param[inout] out_shape - The materialized output Tensor shape + /// \return Status - The error code return + Status MaterializeTensorShape(int32_t num_elements, TensorShape *out_shape) const; + + /// \brief << Stream output operator overload + /// This allows you to write the debug print info using stream operators + /// \param[in] out - reference to the output stream being overloaded + /// \param[in] cd - reference to the ColDescriptor to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const ColDescriptor &cd) { + cd.Print(out); + return out; + } + + /// \brief getter function + /// \return The column's DataType + DataType type() const { return type_; } + + /// \brief getter function + /// \return The column's rank + int32_t rank() const { return rank_; } + + /// \brief getter function + /// \return The column's name + std::string name() const { return col_name_; } + + /// \brief getter function + /// \return The column's shape + TensorShape shape() const; + + /// \brief getter function + /// \return TF if the column has an assigned fixed shape. + bool hasShape() const { return tensor_shape_ != nullptr; } + + /// \brief getter function + /// \return The column's tensor implementation type + TensorImpl tensorImpl() const { return tensor_impl_; } + + private: + DataType type_; // The columns type + int32_t rank_; // The rank for this column (number of dimensions) + TensorImpl tensor_impl_; // The initial flavour of the tensor for this column + std::unique_ptr tensor_shape_; // The fixed shape (if given by user) + std::string col_name_; // The name of the column +}; + +/// \class DataSchema data_schema.h +/// \brief A list of the columns. +class DataSchema { + public: + /// \brief Constructor + DataSchema(); + + /// \brief Destructor + ~DataSchema(); + + /// \brief Parses a schema json file and populates the columns and meta info. + /// \param[in] schema_file_path - the schema file that has the column's info to load + /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. + /// \return Status - The error code return + Status LoadSchemaFile(const std::string &schema_file_path, const std::vector &columns_to_load); + + /// \brief Parses a schema JSON string and populates the columns and meta info. + /// \param[in] schema_json_string - the schema file that has the column's info to load + /// \param[in] columns_to_load - list of strings for columns to load. if empty, assumes all columns. + /// \return Status - The error code return + Status LoadSchemaString(const std::string &schema_json_string, const std::vector &columns_to_load); + + /// \brief A print method typically used for debugging + /// \param[in] out - The output stream to write output to + void Print(std::ostream &out) const; + + /// \brief << Stream output operator overload. 
This allows you to write the debug print info using stream operators + /// \param[in] out - reference to the output stream being overloaded + /// \param[in] ds - reference to the DataSchema to display + /// \return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const DataSchema &ds) { + ds.Print(out); + return out; + } + + /// \brief Adds a column descriptor to the schema + /// \param[in] cd - The ColDescriptor to add + /// \return Status - The error code return + Status AddColumn(const ColDescriptor &cd); + + /// \brief getter + /// \return The reference to a ColDescriptor to get (const version) + const ColDescriptor &column(int32_t idx) const; + + /// \brief getter + /// \return The number of columns in the schema + int32_t NumColumns() const { return col_descs_.size(); } + + bool Empty() const { return NumColumns() == 0; } + + /// \brief getter + /// \return The number of rows read from schema + int64_t num_rows() const { return num_rows_; } + + static const char DEFAULT_DATA_SCHEMA_FILENAME[]; + + /// \brief Loops through all columns in the schema and returns a map with the column name to column index number. + /// \param[inout] out_column_name_map - The output map of columns names to column index + /// \return Status - The error code return + Status GetColumnNameMap(std::unordered_map *out_column_name_map); + + private: + /// \brief Internal helper function. Parses the json schema file in any order and produces a schema that + /// does not follow any particular order (json standard does not enforce any ordering protocol). + /// This one produces a schema that contains all of the columns from the schema file. + /// \param[in] column_tree - The nlohmann tree from the json file to parse + /// \return Status - The error code return + Status AnyOrderLoad(nlohmann::json column_tree); + + /// \brief Internal helper function. For each input column name, perform a lookup to the json document to + /// find the matching column. When the match is found, process that column to build the column + /// descriptor and add to the schema in the order in which the input column names are given. + /// \param[in] column_tree - The nlohmann tree from the json file to parse + /// \param[in] columns_to_load - list of strings for the columns to add to the schema + /// \return Status - The error code return + Status ColumnOrderLoad(nlohmann::json column_tree, const std::vector &columns_to_load); + + /// \brief Internal helper function. Given the json tree for a given column, load it into our schema. + /// \param[in] columnTree - The nlohmann child tree for a given column to load. + /// \param[in] col_name - The string name of the column for that subtree. + /// \return Status - The error code return + Status ColumnLoad(nlohmann::json column_child_tree, const std::string &col_name); + + /// \brief Internal helper function. Performs sanity checks on the json file setup. 
+ /// \param[in] js - The nlohmann tree for the schema file + /// \return Status - The error code return + Status PreLoadExceptionCheck(const nlohmann::json &js); + + std::vector col_descs_; // Vector of column descriptors + int64_t num_rows_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATA_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc new file mode 100644 index 0000000000..f75ca5d097 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc @@ -0,0 +1,268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/dataset_iterator.h" +#include +#include +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" + +namespace mindspore { +namespace dataset { +// Constructor of the IteratorBase +IteratorBase::IteratorBase() : curr_buffer_(nullptr), eof_handled_(false) {} + +IteratorBase::~IteratorBase() = default; + +// Fetches one row of data from the iterator as a column map. +Status IteratorBase::GetNextAsMap(TensorMap *out_map) { + if (out_map == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output map in iterator!"); + } + + out_map->clear(); + + TensorRow curr_row; + RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row)); + + // Return empty map if there's no data + if (curr_row.empty()) { + return Status::OK(); + } + + // The column name mapping is needed to be able to produce the tensor map output. + // The column name mapping comes from the source operator that is producing the data into the iterator. + // To avoid having to fetch this for every time, we'll take a local copy of the column name id mapping + // and save in the iterator. We only have to do this once. All subsequent iterations use the same mapping. + if (col_name_id_map_.empty()) { + // Determine the column name map by calling the derived class method to retrieve the column + // name map + col_name_id_map_ = this->GetColumnNameMap(); + } + + // Populate the out map from the row and return it + for (auto colMap : col_name_id_map_) { + (*out_map)[colMap.first] = std::move(curr_row[colMap.second]); + } + + return Status::OK(); +} + +// Fetches one row of data from the iterator. +// The base class version simply performs error handling and returns empty row. Actual +// functionality exists in the derived versions of this function. 
+Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) { + if (out_row == nullptr) { + RETURN_STATUS_UNEXPECTED("Null output row in iterator!"); + } + + // clear the old tensor row + out_row->clear(); + + return Status::OK(); +} + +// Constructor of the DatasetIterator +DatasetIterator::DatasetIterator(std::shared_ptr exe_tree) + : IteratorBase(), + root_(exe_tree->root()), + tracing_(nullptr), + cur_batch_num_(0), + cur_connector_size_(0), + cur_connector_capacity_(0) { + std::shared_ptr node; + Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node); + if (s.IsOk()) { + tracing_ = std::dynamic_pointer_cast(node); + } +} + +DatasetIterator::~DatasetIterator() = default; + +// Fetches one row of data from the iterator. Overrides the base class. This one fetches +// from the tree root node directly. +Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) { + // Common code init and error checking in the base class. + RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row)); + + // Once eof is handled, always return empty row. Class must be destroyed and recreated if you + // want to iterate again. + if (eof_handled_) { + return Status::OK(); + } + + // Check if we need to get a new DataBuffer to iterate. + if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) { + if (tracing_ != nullptr) { + cur_connector_size_ = root_->ConnectorSize(); + cur_connector_capacity_ = root_->ConnectorCapacity(); + } + RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); + + // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually + // handle eoe and eof messages here. + // + // An eoe buffer means we have iterated fully to the end of the tree. + // An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of + // all operators. + if (curr_buffer_->eoe()) { + MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row."; + + // Before returning the last empty vector, fetch the eof buffer which should be the last + // buffer, and then free it. + RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_)); + + if (!curr_buffer_->eof()) { + RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!"); + } + eof_handled_ = true; + curr_buffer_.reset(); // explicitly free the eof buffer + // Set tree to Finished state + root_->Tree()->SetFinished(); + + return Status::OK(); + } + + if (curr_buffer_->eof()) { + // An eof by itself, without being preceded by an eoe, is possible if a repeat operator + // exists below us in the stack. Repeat operator eats eoe's but eventually allows the + // flow of an eof up the pipeline by itself. 
+      eof_handled_ = true;
+      curr_buffer_.reset();  // explicitly free the eof buffer
+      // Set the tree to the Finished state
+      root_->Tree()->SetFinished();
+      return Status::OK();
+    }
+  }
+
+  // If we got this far, now it's time to pop the next row to return to the caller
+  RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
+  if (tracing_ != nullptr) {
+    cur_batch_num_++;
+    tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
+  }
+  return Status::OK();
+}
+
+Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
+  if (out_shapes == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Null output shape argument");
+  }
+  if (device_queue_row_.empty()) {
+    RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
+  }
+  for (const auto &ts : device_queue_row_) {
+    out_shapes->push_back(ts->shape());
+  }
+
+  return Status::OK();
+}
+
+Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
+  if (out_types == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Null output type argument");
+  }
+  if (device_queue_row_.empty()) {
+    RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
+  }
+  for (const auto &ts : device_queue_row_) {
+    out_types->push_back(ts->type());
+  }
+  return Status::OK();
+}
+
+// Getter
+std::unordered_map<std::string, int32_t> DatasetIterator::GetColumnNameMap() const {
+  return root_->column_name_id_map();
+}
+
+// Constructor of the ChildIterator
+ChildIterator::ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx)
+    : IteratorBase(), current_op_(current_op), child_idx_(child_idx), worker_id_(worker_id), end_epoch_(false) {}
+
+ChildIterator::~ChildIterator() { current_op_ = nullptr; }
+
+// Fetches one row of data from the iterator. Overrides the base class. This one fetches
+// only from the child/worker id as given in the constructor.
+Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
+  // Common init and error checking code is in the base class.
+  RETURN_IF_NOT_OK(IteratorBase::FetchNextTensorRow(out_row));
+
+  // Once eof is handled, always return an empty row. The class must be destroyed and recreated if you
+  // want to iterate again.
+  if (eof_handled_) {
+    return Status::OK();
+  }
+
+  // Check if we need to get a new DataBuffer to iterate.
+  if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
+    RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
+
+    // Unlike the DatasetIterator, this child iterator does not quit after eoe.
+    // Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
+    // caller to decide what it wants to do next.
+    if (curr_buffer_->eoe()) {
+      MS_LOG(DEBUG) << "Child iterator picked up EOE.";
+      end_epoch_ = true;
+      return Status::OK();
+    }
+
+    if (curr_buffer_->eof()) {
+      MS_LOG(DEBUG) << "Child iterator picked up EOF.";
+      eof_handled_ = true;
+      return Status::OK();
+    }
+  }
+
+  // If we got this far, now it's time to pop the next row to return to the caller
+  RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
+
+  return Status::OK();
+}
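+
+// A typical per-epoch consumption pattern for an operator reading from its child
+// (illustrative sketch; `op` is a hypothetical DatasetOp pointer, usually `this`,
+// and row handling is elided):
+//
+//   ChildIterator child_iter(op, worker_id, 0);
+//   TensorRow row;
+//   RETURN_IF_NOT_OK(child_iter.FetchNextTensorRow(&row));
+//   while (!row.empty()) {
+//     // ... consume the row ...
+//     RETURN_IF_NOT_OK(child_iter.FetchNextTensorRow(&row));
+//   }
+//   // The eoe has been reached, so Drain() below is a no-op. Calling Drain()
+//   // mid-epoch instead would discard the remaining rows up to the eoe.
+//   RETURN_IF_NOT_OK(child_iter.Drain());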
+// Drain until the next eoe
+Status ChildIterator::Drain() {
+  if (end_epoch_ == true) {
+    // Calling drain against a child that is already at its eoe state will not result in any action.
+    // This allows you to do:
+    // - fetch until empty row
+    // - drain (will not actually drain because you are already at the end of the iteration)
+    // However, the next time after that, it will perform its normal draining activities.
+    end_epoch_ = false;
+    MS_LOG(DEBUG) << "No operation drain, already at end of epoch.";
+    return Status::OK();
+  }
+  MS_LOG(DEBUG) << "Child draining buffers until eoe.";
+  // else we drain until eoe or eof; eof here is for a sanity check
+  while (!curr_buffer_->eoe() && !curr_buffer_->eof()) {
+    RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
+  }
+  if (curr_buffer_->eof()) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Child iterator picked up EOF in drain.");
+  }
+  return Status::OK();
+}
+
+// Getter
+std::unordered_map<std::string, int32_t> ChildIterator::GetColumnNameMap() const {
+  return current_op_->child(child_idx_)->column_name_id_map();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h
new file mode 100644
index 0000000000..253d1604e2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.h
@@ -0,0 +1,156 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASET_ITERATOR_H_
+#define DATASET_ENGINE_DATASET_ITERATOR_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
+
+namespace mindspore {
+namespace dataset {
+using TensorMap = std::unordered_map<std::string, std::shared_ptr<Tensor>>;
+
+// forward declare
+class ExecutionTree;
+
+class DataBuffer;
+
+// The IteratorBase class is used to iterate data from an ExecutionTree one row at a time.
+// The base class provides the general interface, whereas derived classes provide slightly
+// different implementations.
+class IteratorBase {
+ public:
+  // Constructor of IteratorBase
+  IteratorBase();
+
+  // Destructor
+  virtual ~IteratorBase();
+
+  // Fetches one row of data from the iterator.
+  // The base class version simply performs error handling and returns an empty row. Actual
+  // functionality exists in the derived versions of this function.
+  // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the out-of-data
+  //   messages are encountered (such as eoe or eof), then an empty TensorRow is returned back.
+  // @return Status - The error code return
+  // @note The position of a Tensor/column might be different from the initial column order
+  //   in the corresponding Dataset Op. The user must be aware that MapOp, ZipOp, and others might change
+  //   the column ordering.
+  virtual Status FetchNextTensorRow(TensorRow *out_row);
+
+  // Fetches one row of data from the iterator as a column map.
+  // @return An unordered map from column name to shared pointer to Tensor.
+ Status GetNextAsMap(TensorMap *out_map); + + // Getter + // @return T/F if this iterator is completely done after getting an eof + bool eof_handled() const { return eof_handled_; } + + // Getter + // @return The string to column id mapping. + virtual std::unordered_map<std::string, int32_t> GetColumnNameMap() const = 0; + + protected: + std::unique_ptr<DataBuffer> curr_buffer_; // holds the current buffer + bool eof_handled_; // T/F if this op got an eof + std::unordered_map<std::string, int32_t> col_name_id_map_; +}; + +// The DatasetIterator derived class is for fetching rows off the end/root of the execution tree. +class DatasetIterator : public IteratorBase { + public: + // Constructor of the DatasetIterator + // @param exe_tree The execution tree we want to pull/iterate the data from using its root node. + explicit DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree); + + // Destructor + ~DatasetIterator(); + + // Fetches one row of data from the iterator. Overrides the base class. This one fetches + // from the tree root node directly. + // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data + // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. + // @return Status - The error code return + Status FetchNextTensorRow(TensorRow *out_row) override; + + // Fetches the next tensor row into the device row, and returns its shapes. + // @param out_shapes - A vector of tensor shapes (one shape per column) + // @return Status - The error code return + Status GetOutputShapes(std::vector<TensorShape> *out_shapes); + + // Fetches the next tensor row into the device row, and returns its types. + // @param out_types - A vector of tensor types (one type per column) + // @return Status - The error code return + Status GetOutputTypes(std::vector<DataType> *out_types); + + // Getter + // @return The string to column id mapping. + std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; + + private: + std::shared_ptr<DatasetOp> root_; // saves the root of the executionTree + TensorRow device_queue_row_; + std::shared_ptr<DatasetIteratorTracing> tracing_; // trace profiling data + int32_t cur_batch_num_; // current batch number, used for profiling + int32_t cur_connector_size_; // current connector size of root op, used for profiling + int32_t cur_connector_capacity_; // current connector capacity of root op, used for profiling +}; + +// The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree. +// This one should only be used by internal Dataset operators, rather than an end-user. +class ChildIterator : public IteratorBase { + public: + // Constructor of the ChildIterator + // @param current_op - The parent op from which we'll fetch from its children. + // @param worker_id - The worker id to use when fetching from the children. + // @param child_idx - The index to the child to fetch from. + ChildIterator(DatasetOp *current_op, int32_t worker_id, int32_t child_idx); + + // Destructor + ~ChildIterator(); + + // Fetches one row of data from the iterator. Overrides the base class. This one fetches + // only from the child/worker id as given from the constructor. + // @param out_row - A TensorRow (vector of shared pointers to Tensors). If any of the end-of-data + // messages are encountered (such as eoe or eof), then an empty TensorRow is returned back. + // @return Status - The error code return + Status FetchNextTensorRow(TensorRow *out_row) override; + + // This function drains buffers until the next eoe has been received. + // It will be a no-op if the previous row returned is empty.
+ // @return Status - The error code return + Status Drain(); + + // Getter + // @return The string to column id mapping. + std::unordered_map<std::string, int32_t> GetColumnNameMap() const override; + + private: + DatasetOp *current_op_; // The parent operator. We consume from its children. + int32_t child_idx_; // The specific child this iterator will fetch from. + int32_t worker_id_; // The worker id used for fetching the child data. + bool end_epoch_; // the flag used when an empty row has been returned. +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASET_ITERATOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/datasetops/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000..51ea232e68 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,242 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/barrier_op.h" +#include <iomanip> +#include <utility> +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods.
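+ // For illustration (not part of this change), a typical construction would look like the following, where + // py_cond stands for an assumed py::function that returns a bool: + // std::shared_ptr<BarrierOp> op; + // RETURN_IF_NOT_OK(BarrierOp::Builder().SetConditionName("sync").SetConditionFunc(py_cond).Build(&op));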
+ + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + const int32_t worker_id = 0; + const int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. 
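+ // (Note: eof, unlike the per-epoch eoe above, is terminal; it is forwarded once so that downstream operators + // can shut down, and the main loop exits.)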
+ MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + + // the update code below shouldn't do anything bad if the column name already exists. + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. +Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. 
+ if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as the first line, regardless of whether this is a summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCondition: " << condition_name_ << "\n\n"; + } +} + +// override function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// override function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000..a3ac843272 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,169 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from the python layer. The current barrier design respects the +// rows-per-buffer design and will only output a buffer of rows once it has received rows-per-buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method.
+ // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. + Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr<BarrierOp> *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. + // The reason is that other values would complicate how the pipeline behaves with other operators. + // One example of such a case is having batch after barrier: batch would be waiting for data, and having + // more rows per buffer in this case can result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function takes a table and repeatedly adds rows to it.
+ // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return incomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // iterator to pull new rows, we only have one child + std::unique_ptr<ChildIterator> child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc new file mode 100644 index 0000000000..844d054307 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.cc @@ -0,0 +1,446 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/batch_op.h" + +#include <iomanip> +#include <utility> + +#include "common/utils.h" +#ifdef ENABLE_PYTHON +#include "minddata/dataset/core/pybind_support.h" +#endif +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +using float16 = Eigen::half; + +namespace mindspore { +namespace dataset { +BatchOp::Builder::Builder(int32_t batch_size) : builder_drop_(false), builder_pad_(false), builder_pad_map_({}) { + builder_batch_size_ = batch_size; + std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BatchOp::Builder::Build(std::shared_ptr<BatchOp> *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); +#ifdef ENABLE_PYTHON + *ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, + builder_num_workers_, builder_cols_to_map_, builder_batch_size_func_, + builder_batch_map_func_, builder_pad_map_); +#else + *ptr = std::make_shared<BatchOp>(builder_batch_size_, builder_drop_, builder_pad_, builder_op_connector_size_, + builder_num_workers_, builder_cols_to_map_, builder_pad_map_); +#endif + return Status::OK(); +} + +Status BatchOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_batch_size_ <= 0 ? "batch size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "batch num_parallel_workers <= 0\n" : ""; + return err.empty() ?
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +#ifdef ENABLE_PYTHON +BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector<std::string> &cols_to_map, py::function batch_size_func, py::function batch_map_func, + PadInfo pad_map) + : ParallelOp(num_workers, op_queue_size), + start_batch_size_(batch_size), + drop_(drop), + pad_(pad), + pyfunc_column_names_(cols_to_map), + batch_size_func_(batch_size_func), + batch_map_func_(batch_map_func), + pad_info_(pad_map) { + worker_queues_.Init(num_workers, op_queue_size); +} +#else +BatchOp::BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector<std::string> &cols_to_map, PadInfo pad_map) + : ParallelOp(num_workers, op_queue_size), + start_batch_size_(batch_size), + drop_(drop), + pad_(pad), + pyfunc_column_names_(cols_to_map), + pad_info_(pad_map) { + worker_queues_.Init(num_workers, op_queue_size); +} +#endif + +Status BatchOp::operator()() { + Status rc = LaunchThreadsAndInitOp(); + // Synchronize with TaskManager + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + int64_t epoch_num = 0, batch_num = 0, cnt = 0; + TensorRow new_row; + std::unique_ptr<TensorQTable> table = std::make_unique<TensorQTable>(); + child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + int32_t cur_batch_size = 0; + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(0, 0, 0))); + while (child_iterator_->eof_handled() == false) { + while (new_row.empty() == false) { + table->emplace_back(new_row); + // if # of rows is enough to make 1 batch (1 batch is buffer), send it to worker_queue + if (table->size() == static_cast<size_t>(cur_batch_size)) { + RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( + std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); + table = std::make_unique<TensorQTable>(); + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); + } + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + // Remainder logic, execute only when there is a remainder (table is non-empty) and don't drop + if (drop_ == false && table->empty() == false) { + RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack( + std::make_pair(std::move(table), CBatchInfo(epoch_num, batch_num++, cnt - epoch_num)))); + } + table = std::make_unique<TensorQTable>(); // this drops when drop == true + // end of the current epoch, batch_num should start from 0 again + batch_num = 0; + epoch_num++; + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOE)))); + RETURN_IF_NOT_OK(GetBatchSize(&cur_batch_size, CBatchInfo(epoch_num, batch_num, cnt - epoch_num))); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } // end of eof_handled() == false + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kEOF)))); + // EOF received, send quit signal (an empty buffer) to all workers + for (int32_t ind = 0; ind < num_workers_; ind++) { + RETURN_IF_NOT_OK( + worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::make_pair(nullptr, CBatchInfo(batchCtrl::kQuit)))); + } + return Status::OK(); +} + +void BatchOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as the first line, regardless of whether this is a summary or detailed print + out << "(" << std::setw(2) <<
operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [batch size: " << start_batch_size_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nStart batch size: " << start_batch_size_ << "\nDrop remainder: " << (drop_ ? "yes" : "no") << "\n\n"; + } +} + +Status BatchOp::BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size) { + if ((*src)->size() != batch_size) { + RETURN_STATUS_UNEXPECTED("[Internal Batch ERROR] Source table size does not match the batch_size"); + } + + if (batch_size == 1) { + TensorRow row = std::move((*src)->front()); + (*src)->pop_front(); + (*dest)->push_back(row); + for (const auto &tensor : (*dest)->front()) { + RETURN_IF_NOT_OK(tensor->ExpandDim(0)); + } + return Status::OK(); + } + + TensorRow batched_row; + auto num_columns = (*src)->front().size(); + for (size_t i = 0; i < num_columns; i++) { + std::shared_ptr first_tensor = (*src)->at(0).at(i); // first row, column i + TensorShape first_shape = first_tensor->shape(); + DataType first_type = first_tensor->type(); + TensorShape new_shape = first_shape.PrependDim(static_cast(batch_size)); + + std::shared_ptr new_tensor; + if (first_type.IsNumeric()) { // numeric tensor + RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, TensorImpl::kFlexible, new_shape, first_type)); + dsize_t j = 0; + for (auto row : **src) { + std::shared_ptr old_tensor = row.at(i); // row j, column i + if (old_tensor->shape() == first_shape) { // check the newly popped rows have the same dim as the first + RETURN_IF_NOT_OK(new_tensor->InsertTensor({j++}, old_tensor)); + } else { + RETURN_STATUS_UNEXPECTED("[Batch ERROR] Inconsistent TensorShapes of Column " + std::to_string(i)); + } + } + } else { // handle string column differently + std::vector strings; + for (dsize_t j = 0; j < batch_size; j++) { + std::shared_ptr old_tensor = (*src)->at(j).at(i); + for (auto itr = old_tensor->begin(); itr != old_tensor->end(); itr++) { + strings.emplace_back(*itr); + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&new_tensor, strings, new_shape)); + } + batched_row.emplace_back(new_tensor); + } + + (*dest)->emplace_back(batched_row); + + return Status::OK(); +} + +Status BatchOp::WorkerEntry(int32_t workerId) { + TaskManager::FindMe()->Post(); + std::pair, CBatchInfo> table_pair; + RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); + while (table_pair.second.ctrl_ != batchCtrl::kQuit) { + if (table_pair.second.ctrl_ == batchCtrl::kEOE) { + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + } else if (table_pair.second.ctrl_ == batchCtrl::kEOF) { + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) { + std::unique_ptr db = nullptr; + RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &db)); + RETURN_IF_NOT_OK(out_connector_->Add(workerId, std::move(db))); + } + RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair)); + } + return Status::OK(); +} + +Status BatchOp::MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, + std::unique_ptr *db) { + RETURN_UNEXPECTED_IF_NULL(table_pair.first); +#ifdef ENABLE_PYTHON + if (!pyfunc_column_names_.empty()) 
RETURN_IF_NOT_OK(MapColumns(&table_pair)); // pass it through pyfunc +#endif + if (pad_) RETURN_IF_NOT_OK(PadColumns(&table_pair.first, pad_info_, column_name_id_map_)); // do padding if needed + (*db) = std::make_unique(table_pair.second.batch_num_, DataBuffer::kDeBFlagNone); + std::unique_ptr dest_table = std::make_unique(); + RETURN_IF_NOT_OK(BatchRows(&table_pair.first, &dest_table, table_pair.first->size())); + (*db)->set_tensor_table(std::move(dest_table)); + return Status::OK(); +} + +Status BatchOp::LaunchThreadsAndInitOp() { + RETURN_UNEXPECTED_IF_NULL(tree_); + RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&BatchOp::WorkerEntry, this, std::placeholders::_1))); + return Status::OK(); +} + +Status BatchOp::EofReceived(int32_t) { return Status::OK(); } + +Status BatchOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status BatchOp::MapColumns(std::pair, CBatchInfo> *table_pair) { + TensorBatchTable input_table; + input_table.reserve(pyfunc_column_names_.size()); + for (std::string col_name : pyfunc_column_names_) { + if (column_name_id_map_.find(col_name) == column_name_id_map_.end()) { + RETURN_STATUS_UNEXPECTED("column : '" + col_name + "' does not exist\n"); + } + TensorBatch tensor_batch; + tensor_batch.reserve(table_pair->first->size()); + size_t col_idx = static_cast(column_name_id_map_[col_name]); + for (size_t row_idx = 0; row_idx < table_pair->first->size(); row_idx++) { + tensor_batch.push_back(std::move(table_pair->first->at(row_idx)[col_idx])); + } + input_table.push_back(std::move(tensor_batch)); + } + + // Perform batch map + TensorBatchTable output_table; + RETURN_IF_NOT_OK(InvokeBatchMapFunc(&input_table, &output_table, table_pair->second)); + + // Write back to TensorQTable + for (size_t input_idx = 0; input_idx < pyfunc_column_names_.size(); input_idx++) { + size_t col_idx = static_cast(column_name_id_map_[pyfunc_column_names_[input_idx]]); + size_t row_id = 0; + for (TensorRow &row : *(table_pair->first)) { + row[col_idx] = std::move(output_table[input_idx][row_id++]); + } + } + return Status::OK(); +} +#endif + +Status BatchOp::GetBatchSize(int32_t *batch_size, CBatchInfo info) { +#ifdef ENABLE_PYTHON + if (batch_size_func_ != nullptr) { + RETURN_IF_NOT_OK(InvokeBatchSizeFunc(batch_size, info)); + } else { + (*batch_size) = start_batch_size_; + } +#else + (*batch_size) = start_batch_size_; +#endif + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status BatchOp::InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info) { + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object size = batch_size_func_(info); + *batch_size = size.cast(); + if (*batch_size <= 0) { + return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Batch size function should return an integer > 0"); + } + } + return Status(StatusCode::kOK, "Batch size func call succeed"); +} + +Status BatchOp::InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info) { + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 
0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Prepare batch map call back parameters + py::tuple input_args(input->size() + 1); + for (size_t i = 0; i < input->size(); i++) { + std::vector<py::array> np_batch; + for (std::shared_ptr<Tensor> t : input->at(i)) { + py::array np_array; + RETURN_IF_NOT_OK(t->GetDataAsNumpy(&np_array)); + np_batch.push_back(std::move(np_array)); + } + input_args[i] = np_batch; + } + input_args[input->size()] = info; + // Invoke batch map func + py::object ret_py_obj = batch_map_func_(*input_args); + // Parse batch map return value + py::tuple ret_tuple = py::cast<py::tuple>(ret_py_obj); + if (ret_tuple.size() != pyfunc_column_names_.size() || !py::isinstance<py::tuple>(ret_tuple)) { + return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple"); + } + for (size_t i = 0; i < ret_tuple.size(); i++) { + TensorBatch output_batch; + py::list output_list = py::cast<py::list>(ret_tuple[i]); + for (size_t j = 0; j < output_list.size(); j++) { + std::shared_ptr<Tensor> out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, py::cast<py::array>(output_list[j]))); + output_batch.push_back(std::move(out)); + } + output->push_back(std::move(output_batch)); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Batch map function should return a tuple of lists of numpy arrays"); + } + } + return Status(StatusCode::kOK); +} +#endif + +Status BatchOp::PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info, + const std::unordered_map<std::string, int32_t> &column_name_id_map) { + RETURN_UNEXPECTED_IF_NULL(table); // placeholder for now, might need this in the future + CHECK_FAIL_RETURN_UNEXPECTED((*table)->front().size() == column_name_id_map.size(), "col_name_map mismatch"); + std::vector<std::shared_ptr<Tensor>> pad_vals(column_name_id_map.size(), + 0); // value to pad each column's tensor with, default 0 + std::set<int32_t> pad_cols; + // padded_shape provided by user, maximum shapes of current batch of tensors + std::vector<std::vector<dsize_t>> pad_shapes(column_name_id_map.size()), max_shapes(column_name_id_map.size()); + RETURN_IF_NOT_OK(UnpackPadInfo(pad_info, column_name_id_map, &pad_cols, &pad_vals, &pad_shapes)); + + // init each shape in max_shape to {-1,-1...} init each unspecified shape in pad_shape to -1 as well + for (size_t col_id : pad_cols) { + max_shapes[col_id] = std::vector<dsize_t>((*table)->front()[col_id]->Rank(), -1); + if (pad_shapes[col_id].empty()) pad_shapes[col_id] = max_shapes[col_id]; // fill pad shape with -1 + CHECK_FAIL_RETURN_UNEXPECTED(pad_shapes[col_id].size() == max_shapes[col_id].size(), "wrong rank in pad_shape"); + } + + // calculate maximum shape for each column that needs to be padded + for (const TensorRow &row : **table) { // iterate over each row in a batch + for (size_t col_id : pad_cols) { // iterate over each tensor in a row + CHECK_FAIL_RETURN_UNEXPECTED(row[col_id]->Rank() == max_shapes[col_id].size(), + "Tensors to be padded together need to have the same rank"); + for (size_t dim = 0; dim < row[col_id]->Rank(); dim++) { // pick the largest number in each dimension + max_shapes[col_id][dim] = std::max(max_shapes[col_id][dim], row[col_id]->shape()[dim]); + } + } + } + + // if user sets a dimension to -1 (None in python), use the max value for current dimension + for (size_t col_id : pad_cols) { + for (size_t dim = 0; dim < pad_shapes[col_id].size(); dim++) { + if (pad_shapes[col_id][dim] < 0) pad_shapes[col_id][dim] = max_shapes[col_id][dim]; + } + } + + // call
pad on each tensor that needs to be padded + for (TensorRow &row : **table) { + for (size_t col_id : pad_cols) { + std::shared_ptr pad_tensor; + RETURN_IF_NOT_OK(PadEnd(row[col_id], &pad_tensor, pad_shapes[col_id], pad_vals[col_id])); + row[col_id] = pad_tensor; + } + } + return Status::OK(); +} + +Status BatchOp::UnpackPadInfo(const PadInfo &pad_info, + const std::unordered_map &column_name_id_map, + std::set *pad_cols, std::vector> *pad_vals, + std::vector> *pad_shapes) { + if (pad_info.empty()) { // if pad_info empty, pad every columns automatically + for (dsize_t col_id = 0; col_id < column_name_id_map.size(); col_id++) { + pad_cols->insert(col_id); + } + } else { + for (const auto &p : pad_info) { + auto location = column_name_id_map.find(p.first); + CHECK_FAIL_RETURN_UNEXPECTED(location != column_name_id_map.end(), "no column exists with name:" + p.first); + auto col_id = static_cast(location->second); + CHECK_FAIL_RETURN_UNEXPECTED(col_id < pad_vals->size() && col_id < pad_shapes->size(), "col_id out of bound"); + pad_cols->insert(col_id); + (*pad_vals)[col_id] = p.second.second; // set pad values + (*pad_shapes)[col_id] = p.second.first.AsVector(); // empty vector if shape is unknown + } + } + return Status::OK(); +} + +// Visitor accept method for NodePass +Status BatchOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h new file mode 100644 index 0000000000..0c042433f7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/batch_op.h @@ -0,0 +1,287 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +using TensorBatch = TensorRow; +using TensorBatchTable = std::vector<TensorBatch>; +using PadInfo = std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>>; + +class BatchOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor for Batch, batch size needs to be specified + // @param int32_t batch_size + explicit Builder(int32_t batch_size); + + // Default destructor + ~Builder() = default; + + // set number of parallel Workers on batch + // @param int32_t num_workers + // @return Builder & reference to builder class object + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // set drop for batch op, default false + // @param bool drop + // @return Builder & reference to builder class object + Builder &SetDrop(bool drop) { + builder_drop_ = drop; + return *this; + } + + Builder &SetPaddingMap(const PadInfo &pad_map, bool pad = true) { + builder_pad_ = pad; + builder_pad_map_ = pad_map; + return *this; + } + + // set connector size for batch + // @param int32_t op_conn_size + // @return Builder & reference to builder class object + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = (op_connector_size == 0 ? builder_op_connector_size_ : op_connector_size); + return *this; + } + + // set columns to perform map on + // @param const std::vector<std::string> & cols_to_map - name of columns to perform map on + // @return Builder & reference to builder class object + Builder &SetColumnsToMap(const std::vector<std::string> &cols_to_map) { + builder_cols_to_map_ = cols_to_map; + return *this; + } + +#ifdef ENABLE_PYTHON + // set the per-batch map function + // @param py::function batch_map_func - python function to call per batch, GIL required before calling + // @return Builder & reference to builder class object + Builder &SetBatchMapFunc(py::function batch_map_func) { + builder_batch_map_func_ = batch_map_func; + return *this; + } + + // SetBatchSizeFunc, a function that calls into python after every batch is made + // @param py::function batch_size_func - python function to call, GIL required before calling + // @return Builder & reference to builder class object + Builder &SetBatchSizeFunc(py::function batch_size_func) { + builder_batch_size_func_ = batch_size_func; + return *this; + } +#endif + + // @param std::shared_ptr<BatchOp> *ptr pointer to shared_ptr, actual return arg + // @return Status - The error code return + Status Build(std::shared_ptr<BatchOp> *); + + private: + // Sanity check for builder class args + // @return Status - The error code return + Status SanityCheck(); + + bool builder_drop_; + bool builder_pad_; + int32_t builder_batch_size_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + std::vector<std::string> builder_cols_to_map_; + PadInfo builder_pad_map_; +#ifdef ENABLE_PYTHON + py::function builder_batch_size_func_; + py::function builder_batch_map_func_; +#endif + }; + + enum batchCtrl : int8_t { kNoCtrl = 0, kEOE = 1, kEOF = 2, kQuit = 3 }; + + // Parameters associated with one batch. + // This struct is used for both internal control and python callback.
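+ // As a rough Python-side illustration (assumed usage, given the read-only binding described below): a batch-size + // callable registered through SetBatchSizeFunc could be written as lambda info: (info.get_batch_num() + 1) * 2, + // growing the batch size as the epoch progresses.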
+ // This struct is bound to python with read-only access. + struct CBatchInfo { + CBatchInfo(int64_t ep, int64_t bat, int64_t cur, batchCtrl ctrl) + : epoch_num_(ep), batch_num_(bat), total_batch_num_(cur), ctrl_(ctrl) {} + CBatchInfo(int64_t ep, int64_t bat, int64_t cur) : CBatchInfo(ep, bat, cur, batchCtrl::kNoCtrl) {} + CBatchInfo() : CBatchInfo(0, 0, 0, batchCtrl::kNoCtrl) {} + explicit CBatchInfo(batchCtrl ctrl) : CBatchInfo(0, 0, 0, ctrl) {} + int64_t epoch_num_; // i-th epoch. i starts from 0 + int64_t batch_num_; // i-th batch since the start of current epoch. i starts from 0 + int64_t total_batch_num_; // i-th batch since the start of first epoch. i starts from 0 + batchCtrl ctrl_; // No control=0, EOE=1, EOF=2, Quit=3 + const int64_t get_batch_num() const { return batch_num_; } + const int64_t get_epoch_num() const { return epoch_num_; } + }; + +#ifdef ENABLE_PYTHON + // BatchOp constructor + // @param int32_t batch_size + // @param bool drop + // @param int32_t op_queue_size + // @param int32_t rows_per_buf + // @param int32_t num_workers + BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &, py::function batch_size_func, py::function batch_map_func, PadInfo pad_map); +#else + BatchOp(int32_t batch_size, bool drop, bool pad, int32_t op_queue_size, int32_t num_workers, + const std::vector &, PadInfo pad_map); +#endif + + // BatchOp destructor + ~BatchOp() {} + + // @param int32_t workerId + // @return Status - The error code return + Status EofReceived(int32_t) override; + + // @param int32_t workerId + // @return Status - The error code return + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sO - reference to the BatchOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BatchOp &bo) { + bo.Print(out, false); + return out; + } + + // Main loop of batch + // @return Status - The error code return + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. 
+ Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "BatchOp"; } + + // batch the rows in src table then put it to dest table + // @param const std::unique_ptr *src - table that has the rows for batching + // @param const std::unique_ptr *dest - dest_table to hold batched rows + // @param int32_t size - batch_size + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status BatchRows(const std::unique_ptr *src, const std::unique_ptr *dest, + dsize_t batch_size); + + // @param table + // @param const PadInfo &pad_info pad info + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @return Status - The error code return + static Status PadColumns(std::unique_ptr *table, const PadInfo &pad_info, + const std::unordered_map &column_name_id_map); + + private: + // Worker thread for doing the memcpy of batch + // @param int32_t param workerId + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Generate buffer with batched tensors + // @return Status - The error code return + Status MakeBatchedBuffer(std::pair, CBatchInfo> table_pair, + std::unique_ptr *db); + +#ifdef ENABLE_PYTHON + // Function that calls pyfunc to perform map on batch + // @param (std::pair, batch_stats> *table_pair - contains un-batched tensor + // @return Status - The error code return + Status MapColumns(std::pair, CBatchInfo> *table_pair); +#endif + + // @param const PadInfo &pad_info pad info to unpack + // @param const std::unordered_map& column_name_id_map - column names to index mapping + // @param std::set *cols, col ids to perform pad on + // @param std::vector *vals, default padding value for each column + // @param std::vector> *shapes, padding shape specified by user + // @return Status - The error code return + static Status UnpackPadInfo(const PadInfo &pad_info, + const std::unordered_map &column_name_id_map, + std::set *pad_cols, std::vector> *pad_vals, + std::vector> *pad_shapes); + + // the number of thread pulling from the mOutConnector of the Op below + // @return int32_t, 1 + int32_t num_consumers() const override { return 1; } + + // get the batch size for next batch + // @return Status - The error code return + Status GetBatchSize(int32_t *batch_size, CBatchInfo info); + + // Do the initialization of all queues then start all worker threads + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + +#ifdef ENABLE_PYTHON + // Invoke batch size function with current BatchInfo to generate batch size. + // @return Status - The error code return + Status InvokeBatchSizeFunc(int32_t *batch_size, CBatchInfo info); + + // Invoke batch map function with current BatchInfo to generate tensors to batch. 
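+ // (Both Invoke* helpers acquire the Python GIL themselves via py::gil_scoped_acquire in the .cc implementation, + // so the caller is assumed not to be holding it already.)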
+ // @return Status - The error code return + Status InvokeBatchMapFunc(TensorBatchTable *input, TensorBatchTable *output, CBatchInfo info); +#endif + + int32_t start_batch_size_; + bool drop_; // bool for whether to drop remainder or not + bool pad_; // bool for whether to perform padding on tensor + std::vector<std::string> pyfunc_column_names_; // Name of the columns to perform map op on + PadInfo pad_info_; // column names to perform padding on + std::unique_ptr<ChildIterator> child_iterator_; // child iterator for fetching TensorRows 1 by 1 + QueueList<std::pair<std::unique_ptr<TensorQTable>, CBatchInfo>> worker_queues_; // internal queues for syncing workers +#ifdef ENABLE_PYTHON + py::function batch_size_func_; // Function pointer of batch size function + py::function batch_map_func_; // Function pointer of per batch map function +#endif +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BATCH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc new file mode 100644 index 0000000000..138bb7980b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.cc @@ -0,0 +1,240 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h" + +#include +#include +#include +#include +#include + +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "minddata/dataset/core/pybind_support.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/status.h" + +namespace py = pybind11; +namespace mindspore { +namespace dataset { +BucketBatchByLengthOp::Builder::Builder(std::vector length_dependent_columns, + std::vector bucket_boundaries, std::vector bucket_batch_sizes) + : builder_length_dependent_columns_(length_dependent_columns), + builder_bucket_boundaries_(bucket_boundaries), + builder_bucket_batch_sizes_(bucket_batch_sizes), + builder_pad_info_({}), + builder_pad_to_bucket_boundary_(false), + builder_drop_remainder_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_op_connector_size_ = config_manager->op_connector_size(); +} + +Status BucketBatchByLengthOp::Builder::SanityCheck() { + std::string error_message; + + if (builder_length_dependent_columns_.empty()) { + error_message += "At least 1 column must be specified for element length calculation.\n"; + } + + if (builder_bucket_boundaries_.empty()) { + error_message += "At least 1 bucket boundary must be specified.\n"; + } + + if (builder_bucket_batch_sizes_.size() != builder_bucket_boundaries_.size() + 1) { + error_message += "There must be exactly one bucket batch size specified for each bucket boundary.\n"; + } + + CHECK_FAIL_RETURN_UNEXPECTED(error_message.empty(), error_message); + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Builder::Build(std::shared_ptr *new_bucket_batch_by_length_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + // insert 0 for the first bucket + builder_bucket_boundaries_.insert(builder_bucket_boundaries_.begin(), 0); + + *new_bucket_batch_by_length_op = std::make_shared( + builder_length_dependent_columns_, builder_bucket_boundaries_, builder_bucket_batch_sizes_, + builder_element_length_function_, builder_pad_info_, builder_pad_to_bucket_boundary_, builder_drop_remainder_, + builder_op_connector_size_); + + return Status::OK(); +} + +BucketBatchByLengthOp::BucketBatchByLengthOp(std::vector length_dependent_columns, + std::vector bucket_boundaries, + std::vector bucket_batch_sizes, + py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, + int32_t op_connector_size) + : PipelineOp(op_connector_size), + length_dependent_columns_(length_dependent_columns), + bucket_boundaries_(bucket_boundaries), + bucket_batch_sizes_(bucket_batch_sizes), + element_length_function_(element_length_function), + pad_info_(pad_info), + pad_to_bucket_boundary_(pad_to_bucket_boundary), + drop_remainder_(drop_remainder), + batch_count_(0) { + for (int i = 0; i < bucket_batch_sizes_.size(); i++) { + buckets_.push_back(std::make_unique()); + } +} + +Status BucketBatchByLengthOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +void BucketBatchByLengthOp::Print(std::ostream &out, bool show_all) const { out << "BucketBatchByLengthOp\n"; } + +Status BucketBatchByLengthOp::operator()() { + TaskManager::FindMe()->Post(); + + TensorRow current_row; + 
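+ // Note on the bucket lookup below: the builder prepends 0 to bucket_boundaries_, so scanning down from the last + // bucket finds the first boundary <= element_length; e.g. user boundaries {10, 20} become {0, 10, 20}, mapping to + // the ranges [0, 10), [10, 20), and [20, inf).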
child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + while (!child_iterator_->eof_handled()) { + while (!current_row.empty()) { + int32_t element_length; + RETURN_IF_NOT_OK(ObtainElementLength(&element_length, current_row)); + + int bucket_index = bucket_boundaries_.size() - 1; + while (element_length < bucket_boundaries_[bucket_index]) { + bucket_index--; + } + + buckets_[bucket_index]->push_back(current_row); + + if (buckets_[bucket_index]->size() == bucket_batch_sizes_[bucket_index]) { + RETURN_IF_NOT_OK(PadAndBatchBucket(bucket_index, bucket_batch_sizes_[bucket_index])); + } + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + } + + // got EOE, do what we need to do with remainders in each bucket + if (!drop_remainder_) { + for (int i = 0; i < bucket_boundaries_.size(); i++) { + if (!buckets_[i]->empty()) { + RETURN_IF_NOT_OK(PadAndBatchBucket(i, buckets_[i]->size())); + } + } + } + + // need to send EOE manually since we set state to idle in EoeReceived() + std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row)); + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::ObtainElementLength(int32_t *out_element_length, TensorRow element) { + // call pyfunc here if given pyfunc, otherwise return 0th dimension of shape of + // the single column specified in length_dependent_columns_ + if (element_length_function_) { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + size_t number_of_arguments = length_dependent_columns_.size(); + py::tuple input_arguments(number_of_arguments); + for (size_t i = 0; i < number_of_arguments; i++) { + py::array argument_value; + int32_t column_index = column_name_id_map_[length_dependent_columns_[i]]; + RETURN_IF_NOT_OK(element[column_index]->GetDataAsNumpy(&argument_value)); + input_arguments[i] = argument_value; + } + + py::object length = element_length_function_(*input_arguments); + *out_element_length = length.cast<int32_t>(); + if (*out_element_length < 0) { + return Status(StatusCode::kPyFuncException, "Element length function should return a non-negative integer."); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Could not cast output of element length function to int32_t."); + } + } else { + *out_element_length = element[0]->shape()[0]; + } + + return Status::OK(); +} + +Status BucketBatchByLengthOp::PadAndBatchBucket(int32_t bucket_index, int32_t batch_size) { + std::unique_ptr<TensorQTable> *bucket = &buckets_[bucket_index]; + + PadInfo pad_info_copy = pad_info_; + if (pad_to_bucket_boundary_) { + for (auto &pair : pad_info_copy) { + std::vector<dsize_t> pad_shape = pair.second.first.AsVector(); + + for (size_t i = 0; i < pad_shape.size(); i++) { + if (pad_shape[i] == TensorShape::kDimUnknown) { + if (bucket_index + 1 >= bucket_boundaries_.size()) { + std::string error_message = "Requested to pad to bucket boundary, element falls in last bucket"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, error_message); + } + + pad_shape[i] = bucket_boundaries_[bucket_index + 1] - 1; + } + } + + pair.second.first = TensorShape(pad_shape); + } + } + + // PadColumns will change
the data in bucket + RETURN_IF_NOT_OK(BatchOp::PadColumns(bucket, pad_info_copy, column_name_id_map_)); + + std::unique_ptr batched_bucket = std::make_unique(); + RETURN_IF_NOT_OK(BatchOp::BatchRows(bucket, &batched_bucket, batch_size)); + (*bucket)->clear(); + + std::unique_ptr batched_buffer = std::make_unique(batch_count_, DataBuffer::kDeBFlagNone); + batched_buffer->set_tensor_table(std::move(batched_bucket)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(batched_buffer))); + + batch_count_++; + + return Status::OK(); +} + +Status BucketBatchByLengthOp::Reset() { + batch_count_ = 0; + + for (int i = 0; i < buckets_.size(); i++) { + buckets_[i] = std::make_unique(); + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h new file mode 100644 index 0000000000..332ff4bb22 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/datasetops/batch_op.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class DataBuffer; + +class BucketBatchByLengthOp : public PipelineOp { + public: + class Builder { + public: + Builder(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes); + + ~Builder() = default; + + Builder &SetLengthDependentColumns(std::vector length_dependent_columns) { + builder_length_dependent_columns_ = length_dependent_columns; + return *this; + } + + Builder &SetBucketBoundaries(std::vector bucket_boundaries) { + builder_bucket_boundaries_ = bucket_boundaries; + return *this; + } + + Builder &SetBucketBatchSizes(std::vector bucket_batch_sizes) { + builder_bucket_batch_sizes_ = bucket_batch_sizes; + return *this; + } + + Builder &SetElementLengthFunction(py::function element_length_function) { + builder_element_length_function_ = element_length_function; + return *this; + } + + Builder &SetPadInfo(PadInfo pad_info) { + builder_pad_info_ = pad_info; + return *this; + } + + Builder &SetPadToBucketBoundary(bool pad_to_bucket_boundary) { + builder_pad_to_bucket_boundary_ = pad_to_bucket_boundary; + return *this; + } + + Builder &SetDropRemainder(bool drop_remainder) { + builder_drop_remainder_ = drop_remainder; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + 
builder_op_connector_size_ = op_connector_size; + return *this; + } + + Status Build(std::shared_ptr *new_bucket_batch_by_length_op); + + private: + Status SanityCheck(); + + std::vector builder_length_dependent_columns_; + std::vector builder_bucket_boundaries_; + std::vector builder_bucket_batch_sizes_; + py::function builder_element_length_function_; + PadInfo builder_pad_info_; + bool builder_pad_to_bucket_boundary_; + bool builder_drop_remainder_; + int32_t builder_op_connector_size_; + }; + + BucketBatchByLengthOp(std::vector length_dependent_columns, std::vector bucket_boundaries, + std::vector bucket_batch_sizes, py::function element_length_function, PadInfo pad_info, + bool pad_to_bucket_boundary, bool drop_remainder, int32_t op_connector_size); + + // Destructor + ~BucketBatchByLengthOp() = default; + + // Might need to batch remaining buckets after receiving eoe, so override this method. + // @param int32_t workerId + // @return Status - The error code returned + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param sO - reference to the BucketBatchByLengthOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const BucketBatchByLengthOp &bo) { + bo.Print(out, false); + return out; + } + + // Main loop of batch + // @return Status - The error code returned + Status operator()() override; + + // Function that is called by ResetOp at the end of every epoch + // @return Status - The error code returned + Status Reset() override; + + private: + Status ObtainElementLength(int32_t *out_element_length, TensorRow element); + + Status PadAndBatchBucket(int32_t bucket_index, int32_t batch_size); + + std::vector length_dependent_columns_; + std::vector bucket_boundaries_; + std::vector bucket_batch_sizes_; + py::function element_length_function_; + PadInfo pad_info_; + bool pad_to_bucket_boundary_; + bool drop_remainder_; + + int32_t batch_count_; + std::unique_ptr child_iterator_; + std::vector> buckets_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BUCKET_BATCH_BY_LENGTH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc new file mode 100644 index 0000000000..8ed51ebbb6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.cc @@ -0,0 +1,206 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/engine/datasetops/build_vocab_op.h" + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" + +namespace mindspore { +namespace dataset { + +BuildVocabOp::BuildVocabOp(std::shared_ptr vocab, std::vector col_names, + std::pair freq_r, int64_t top_k, const std::vector &tokens, + bool prepend, int32_t num_workers, int32_t op_conn_size) + : ParallelOp(num_workers, op_conn_size), + interval_(op_conn_size * num_workers), + vocab_(vocab), + col_names_(col_names), + freq_range_(freq_r), + top_k_(top_k), + special_tokens_(tokens), + special_first_(prepend) { + // init two queues for thread sync + distributor_queue_ = std::make_unique>(num_workers * op_conn_size); + collector_queue_ = + std::make_unique>>>(num_workers * op_conn_size); +} + +Status BuildVocabOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + TensorRow new_row; + RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); + std::unique_ptr> wrkr_map = + std::make_unique>(); + int32_t row_cnt = 0; + while (!new_row.empty()) { + for (int32_t col : col_ids_) { + CHECK_FAIL_RETURN_UNEXPECTED(!new_row[col]->type().IsNumeric(), "from_dataset only works on string columns"); + for (auto itr = new_row[col]->begin(); itr != new_row[col]->end(); itr++) { + (*wrkr_map)[std::string(*itr)] += 1; + } + } + row_cnt++; // row is processed by this point + if ((row_cnt % interval_ == 0) && ((row_cnt / interval_) % num_workers_ == worker_id) && (!wrkr_map->empty())) { + RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); + wrkr_map = std::make_unique>(); + } + RETURN_IF_NOT_OK(distributor_queue_->PopFront(&new_row)); + } + // clean up + if (!wrkr_map->empty()) { + RETURN_IF_NOT_OK(collector_queue_->Add(std::move(wrkr_map))); + } + // empty map as quit signal + RETURN_IF_NOT_OK(collector_queue_->Add(std::make_unique>())); + return Status::OK(); +} + +Status BuildVocabOp::operator()() { + // launch the collector thread + RETURN_UNEXPECTED_IF_NULL(tree_); + RETURN_IF_NOT_OK(distributor_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(collector_queue_->Register(tree_->AllTasks())); + // launch worker threads and collector thread + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&BuildVocabOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("collector", std::bind(&BuildVocabOp::CollectorThread, this))); + TaskManager::FindMe()->Post(); + child_iterator_ = std::make_unique(this, 0, 0); + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + if (!col_names_.empty()) { + col_ids_.reserve(col_names_.size()); + for (std::string col : col_names_) { + auto itr = column_name_id_map_.find(col); + CHECK_FAIL_RETURN_UNEXPECTED(itr != column_name_id_map_.end(), col + " column doesn't exist"); + col_ids_.push_back(itr->second); + } + } else { + col_ids_.reserve(column_name_id_map_.size()); + for (const auto &p : column_name_id_map_) { + col_ids_.push_back(p.second); + } + } + bool eoe_warning = false; // give out warning if receive more than 1 eoe + while (child_iterator_->eof_handled() == false) { + while (new_row.empty() == false) { + RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(new_row)); + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + CHECK_FAIL_RETURN_UNEXPECTED(!eoe_warning, "no op should be after from_dataset (repeat detected)"); + eoe_warning = true; + } + + // tell all workers to quit + for (int32_t wrkr_id = 0; 
wrkr_id < num_workers_; wrkr_id++) {
+    RETURN_IF_NOT_OK(distributor_queue_->EmplaceBack(TensorRow()));
+  }
+  return Status::OK();
+}
+
+Status BuildVocabOp::CollectorThread() {
+  TaskManager::FindMe()->Post();
+  int32_t num_quited_worker = 0;
+  std::unique_ptr<std::unordered_map<std::string, int64_t>> wrkr_map;
+  while (num_quited_worker != num_workers_) {
+    RETURN_IF_NOT_OK(collector_queue_->PopFront(&wrkr_map));
+    RETURN_UNEXPECTED_IF_NULL(wrkr_map);
+    if (!wrkr_map->empty()) {
+      for (const auto &wd : *wrkr_map) word_cnt_[wd.first] += wd.second;
+    } else {
+      ++num_quited_worker;
+    }
+  }  // all frequencies are obtained
+  CHECK_FAIL_RETURN_UNEXPECTED(!word_cnt_.empty(), "word_cnt is empty");
+  std::vector<std::string> words;
+  // make sure enough is reserved, this will become a partially sorted list eventually
+  words.reserve(wrkr_map->size());
+
+  for (auto it = word_cnt_.begin(); it != word_cnt_.end();) {
+    if (it->second >= freq_range_.first && it->second <= freq_range_.second) {
+      words.push_back(it->first);
+      it++;
+    } else {
+      it = word_cnt_.erase(it);
+    }
+  }
+  std::string err_msg;
+
+  for (const std::string &sp_tk : special_tokens_) {
+    // if a special word exists in dataset, warn user about this
+    err_msg += (word_cnt_.find(sp_tk) != word_cnt_.end() ? sp_tk + "\t" : "");
+  }
+
+  CHECK_FAIL_RETURN_UNEXPECTED(err_msg.empty(), "These special words are already in the dataset: " + err_msg + ".");
+
+  int64_t num_words = std::min(static_cast<int64_t>(words.size()), top_k_);
+  if (num_words == 0) {
+    MS_LOG(WARNING) << "No word falls in the frequency range: (" << freq_range_.first << "," << freq_range_.second
+                    << ") vocab would be empty (except for special tokens).";
+  }
+
+  // this would take the top-k most frequent words
+  std::partial_sort(words.begin(), words.begin() + num_words, words.end(),
+                    [this](const std::string &w1, const std::string &w2) {
+                      int64_t f1 = word_cnt_[w1], f2 = word_cnt_[w2];
+                      return f1 == f2 ? w1 < w2 : f1 > f2;
+                    });
+
+  if (special_first_) {
+    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
+  }
+
+  for (int64_t i = 0; i < num_words; i++) {
+    vocab_->append_word(words[i]);
+  }
+
+  if (!special_first_) {
+    for (const std::string &sp_tk : special_tokens_) vocab_->append_word(sp_tk);
+  }
+
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
+  return Status::OK();
+}
+
+Status BuildVocabOp::Builder::Build(std::shared_ptr<BuildVocabOp> *op) {
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_num_workers_ > 0, "builder num_workers needs to be greater than 0");
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_top_k_ > 0, "top_k needs to be a positive number");
+  CHECK_FAIL_RETURN_UNEXPECTED(builder_max_freq_ >= builder_min_freq_ && builder_min_freq_ >= 0,
+                               "frequency range [a,b] should satisfy 0 <= a <= b (a,b are inclusive)");
+  (*op) = std::make_shared<BuildVocabOp>(
+    builder_vocab_, builder_col_names_, std::make_pair(builder_min_freq_, builder_max_freq_), builder_top_k_,
+    builder_speical_tokens_, builder_special_first_, builder_num_workers_, builder_connector_size_);
+  return Status::OK();
+}
+
+BuildVocabOp::Builder::Builder()
+    : builder_top_k_(std::numeric_limits<int64_t>::max()),
+      builder_min_freq_(0),
+      builder_max_freq_(std::numeric_limits<int64_t>::max()),
+      builder_special_first_(true) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_num_workers_ = cfg->num_parallel_workers();
+  builder_connector_size_ = cfg->op_connector_size();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h
new file mode 100644
index 0000000000..42ea0deb5c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/build_vocab_op.h
@@ -0,0 +1,174 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
+#define DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/dataset_iterator.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/text/vocab.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class BuildVocabOp : public ParallelOp {
+ public:
+  class Builder {
+   public:
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method
+    // @param int32_t size
+    // @return Builder setter method returns reference to the builder.
+ Builder &SetOpConnectorSize(int32_t size) { + builder_connector_size_ = size; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param int64_t top_k + // @return Builder setter method returns reference to the builder. + Builder &SetTopK(int64_t top_k) { + builder_top_k_ = top_k; + return *this; + } + + // Setter method + // @param int64_t min_freq + // @return Builder setter method returns reference to the builder. + Builder &SetMinFreq(int64_t min_freq) { + builder_min_freq_ = min_freq; + return *this; + } + + // Setter method + // @param int64_t max_freq + // @return Builder setter method returns reference to the builder. + Builder &SetMaxFreq(int64_t max_freq) { + builder_max_freq_ = max_freq; + return *this; + } + + // set columns names + // @param const std::vector & col_names - name of columns to get words + // @return Builder & reference to builder class object + Builder &SetColumnNames(const std::vector &col_names) { + builder_col_names_ = col_names; + return *this; + } + + // set special tokens + // @param const std::vector & col_names - name of columns to get words + // @return Builder & reference to builder class object + Builder &SetSpecialTokens(const std::vector &tokens) { + builder_speical_tokens_ = tokens; + return *this; + } + + // set vocab object + Builder &SetVocab(std::shared_ptr vocab) { + builder_vocab_ = vocab; + return *this; + } + + // set special tokens first (or last) + Builder &SetSpecialFirst(bool prepend) { + builder_special_first_ = prepend; + return *this; + } + + // The builder "build" method creates the final object. 
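+    // For example, a minimal usage sketch (the vocab object, column name and
+    // top-k value here are illustrative assumptions, not part of this patch):
+    //   std::shared_ptr<BuildVocabOp> op;
+    //   RETURN_IF_NOT_OK(BuildVocabOp::Builder()
+    //                      .SetVocab(vocab)
+    //                      .SetColumnNames({"text"})
+    //                      .SetTopK(5000)
+    //                      .Build(&op));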
+ // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + int32_t builder_num_workers_; + int32_t builder_connector_size_; + int64_t builder_min_freq_; + int64_t builder_max_freq_; + bool builder_special_first_; + std::vector builder_col_names_; + std::vector builder_speical_tokens_; + std::shared_ptr builder_vocab_; + int64_t builder_top_k_; + }; + + BuildVocabOp(std::shared_ptr vocab, std::vector col_names, std::pair freq_range, + int64_t top_k, const std::vector &tokens, bool prepend, int32_t num_workers, + int32_t op_connector_size); + + ~BuildVocabOp() = default; + + Status WorkerEntry(int32_t worker_id) override; + + // collect the work product from each worker + Status CollectorThread(); + + Status EofReceived(int32_t) override { return Status::OK(); } + + Status EoeReceived(int32_t) override { return Status::OK(); } + + Status operator()() override; + + // Getter + // @return the number of workers + int32_t num_producers() const override { return 1; } + + // Getter + // @return the number of threads consuming from the previous Connector + int32_t num_consumers() const override { return 1; } + + Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); } + + private: + const int32_t interval_; + bool special_first_; + std::shared_ptr vocab_; + std::vector col_names_; + std::vector col_ids_; + std::vector special_tokens_; + // pair = {min_f, max_f} + // make sure that 0<= min_f < max_f <= int32_max in the builder + std::pair freq_range_; + + int64_t top_k_; // every thing means top_k_ == int32_max + std::unique_ptr child_iterator_; // child iterator for fetching TensorRows 1 by 1 + std::unique_ptr> distributor_queue_; // master thread assigns each worker TensorRow via this + std::unique_ptr>>> collector_queue_; + std::unordered_map word_cnt_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_BUILD_VOCAB_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc new file mode 100644 index 0000000000..1b0890686f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/cache_base_op.h" +#include +#include +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +// A print method typically used for debugging +void CacheBase::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nCache client:\n" << *cache_client_ << "\n\n"; + } +} +// Overrides base class reset method. When an operator does a reset, it cleans up any state +// info from it's previous execution and then initializes itself so that it can be executed +// again. +Status CacheBase::Reset() { + if (sampler_ != nullptr) { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + } + // Wake up the workers to get them going again in a new epoch + MS_LOG(DEBUG) << Name() << " resetting."; + epoch_sync_.Set(); + return Status::OK(); +} +CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, sampler), + cache_client_(cache_client), + rows_per_buffer_(rows_per_buf), + // We can cause deadlock if this internal Connector size is too small. + keys_miss_(num_workers_, 1, connector_capacity_) { + io_block_queues_.Init(num_workers, op_connector_size); +} +// Common function to fetch samples from the sampler and send them using the io_block_queues to +// the parallel workers +Status CacheBase::FetchSamplesToWorkers() { + int64_t buf_cnt = 0; + int64_t wait_cnt = 0; + do { + epoch_sync_.Clear(); + std::vector keys; + int64_t row_cnt = 0; + keys.reserve(rows_per_buffer_); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (!sampler_buffer->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { + keys.push_back(*itr); + ++row_cnt; + if (row_cnt % rows_per_buffer_ == 0) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (!keys.empty()) { + auto blk = std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)); + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk))); + } + // send the eoe + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + // If repeat but the not last repeat, wait for reset. + if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt; + RETURN_IF_NOT_OK(epoch_sync_.Wait()); + } else { + // We can break out from the loop. 
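+      // No further reset is coming on this path (not repeated, or this is the
+      // last repeat), so one pass over the sampler suffices; the eof flow-through
+      // below shuts the workers down.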
+ break; + } + } while (true); + // Flow the eof before exit + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + // Ask all the workers to quit. + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); +} +Status CacheBase::FetchFromCache(int32_t worker_id) { + int64_t buffer_id = worker_id; + std::unique_ptr blk; + do { + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk)); + if (blk->eof()) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else if (blk->eoe()) { + if (AllowCacheMiss()) { + // This code path is for CacheLookupOp acting as a sampler. If we get a eoe from + // a sampler, send a eoe to physical leaf op as well. + std::vector eoe; + eoe.push_back(eoe_row_id); + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(blk->GetKeys(&keys)); + if (keys.empty()) { + // empty key is a quit signal for workers + break; + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + std::unique_ptr que = std::make_unique(); + TensorTable ttbl; + RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl)); + auto row_it = ttbl.begin(); + std::vector cache_miss; + cache_miss.reserve(keys.size()); + for (auto row_id : keys) { + auto &row = *row_it; + if (row.empty()) { + if (AllowCacheMiss()) { + cache_miss.push_back(row_id); + } else { + std::string errMsg = "Row id " + std::to_string(row_id) + " not found."; + RETURN_STATUS_UNEXPECTED(errMsg); + } + } + que->push_back(std::move(row)); + ++row_it; + } + db->set_tensor_table(std::move(que)); + if (AllowCacheMiss()) { + // Because of the way connector works, we push unconditionally even cache_miss can be empty. + RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss)); + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + } while (true); + return Status::OK(); +} +Status CacheBase::RegisterResources() { + RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + return Status::OK(); +} +CacheBase::~CacheBase() {} +Status CacheBase::UpdateColumnMapFromCache() { + Status rc; + // Get the schema from the server. It may not be there yet. So tolerate the error. + if (column_name_id_map_.empty()) { + rc = cache_client_->FetchSchema(&column_name_id_map_); + if (rc == Status(StatusCode::kFileNotExist)) { + MS_LOG(DEBUG) << "Schema not in the server yet."; + rc = Status::OK(); + } + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h new file mode 100644 index 0000000000..fb3e999b76 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/cache/cache_service.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/repeat_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/datasetops/cache_base_op.h" +namespace mindspore { +namespace dataset { +/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities. +/// \see CacheOp +/// \see CacheLookupOp +class CacheBase : public ParallelOp { + public: + /// \brief Base class constructor + /// \param num_workers Number of parallel workers + /// \param op_connector_size Connector size + /// \param rows_per_buf Number of rows per buffer + /// \param cache_client CacheClient for communication to the CacheServer + /// \param sampler Sampler which is mandatory + CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler); + /// \brief Destructor + ~CacheBase(); + + /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state + /// info from it's previous execution and then initializes itself so that it can be executed + /// again. 
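+  /// \note For CacheBase this amounts to rewinding the sampler and setting epoch_sync_,
+  /// which wakes the main thread blocked in FetchSamplesToWorkers() for the next epoch.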
+ /// \return Status - The error code return + Status Reset() override; + + /// \brief A print method typically used for debugging + /// \param out The output stream to write output to + /// \param show_all A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + /// \brief << Stream output operator overload + /// \notes This allows you to write the debug print info using stream operators + /// \param out reference to the output stream being overloaded + /// \param mo reference to the CacheOp to display + /// \return the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) { + mo.Print(out, false); + return out; + } + + /// \brief Getter for the cache client + /// \return shared ptr to the cache client + std::shared_ptr cache_client() { return cache_client_; } + /// \brief Setter for the cache client + void SetCacheClient(std::shared_ptr cache_client) { cache_client_ = std::move(cache_client); } + /// \brief Derived class must implement this method if a cache miss is treated as error + virtual bool AllowCacheMiss() = 0; + + protected: + constexpr static int32_t eoe_row_id = -1; + std::shared_ptr cache_client_; + WaitPost epoch_sync_; + int32_t rows_per_buffer_; + Connector> keys_miss_; + + /// \brief Common function to register resources for interrupt + /// \note Derived should override this function for extra resources to be registered + virtual Status RegisterResources(); + /// \brief This function is called by main thread to send samples to the worker thread. + /// \note It is a non-virtual function + /// \return Status object + Status FetchSamplesToWorkers(); + /// \brief This function is called by each worker to fetch rows from the cache server for a given set of + /// sample row id's + /// \return Status object + Status FetchFromCache(int32_t worker_id); + /// \brief Get the column map from cache server + Status UpdateColumnMapFromCache(); + + private: + constexpr static int32_t connector_capacity_ = 1024; + QueueList> io_block_queues_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc new file mode 100644 index 0000000000..0a9b7544ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "utils/log_adapter.h" +#include "utils/system/crc32c.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_size_ = cfg->op_connector_size(); +} + +// Check if the required parameters are set by the builder. +Status CacheLookupOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheLookupOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheLookupOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, rows_per_buffer_, + build_cache_client_, build_sampler_); + return Status::OK(); +} +Status CacheLookupOp::operator()() { + if (!sampler_) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "CacheLookupOp requires a sampler before it can be executed!"); + } + RETURN_IF_NOT_OK(RegisterResources()); + // Kick off the workers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1))); + // required task group sync after launching workers + TaskManager::FindMe()->Post(); + // We have to wait until the leaf op has handshake with us. + RETURN_IF_NOT_OK(leaf_op_wp_.Wait()); + RETURN_IF_NOT_OK(FetchSamplesToWorkers()); + return Status::OK(); +} +Status CacheLookupOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(FetchFromCache(worker_id)); + return Status::OK(); +} +Status CacheLookupOp::ResetSampler() { return Status::OK(); } +Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) { + // We act like a sampler and as a dataset op. During handshake with leaf op, + // We must wait until the leaf op has indexed everything. + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op)); + // Now we notify the main thread handshake has finished. + leaf_op_wp_.Set(); + return Status::OK(); +} +Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); } +void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); } +Status CacheLookupOp::GetNextSample(std::unique_ptr *out_buffer) { + std::vector cache_miss; + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + // Ignore the case we have no cache miss, we can't return empty samples. 
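+  // (FetchFromCache pushes to keys_miss_ unconditionally, once per buffer, so an
+  // empty vector just means that buffer had no misses; keep popping until we get
+  // real row ids or the eoe_row_id sentinel.)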
+ while (cache_miss.empty()) { + RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss)); + } + // Special code for eoe + if (cache_miss.at(0) == eoe_row_id) { + *out_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ts; + RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size())); + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + auto idPtr = sample_ts->begin(); + for (auto i = 0; i < cache_miss.size(); ++i) { + *idPtr = cache_miss.at(i); + ++idPtr; + } + TensorRow row; + row.push_back(sample_ts); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} +Status CacheLookupOp::RegisterResources() { + RETURN_IF_NOT_OK(CacheBase::RegisterResources()); + RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks())); + return Status::OK(); +} +Status CacheLookupOp::ComputeColMap() { + // We don't know the column map at this point unless we contact the cache server + // to fetch the schema but the cache server may not have it at this point either. + // So we will just return OK and let MergeOp (our parent) to handle it. + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CacheLookupOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h new file mode 100644 index 0000000000..46a58c5d02 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_lookup_op.h @@ -0,0 +1,122 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/cache_base_op.h" + +namespace mindspore { +namespace dataset { +/// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset. +/// \note For non-mappable dataset, please see CacheOp +/// \see CacheOp +class CacheLookupOp : public CacheBase, public Sampler { + public: + class Builder { + public: + /// \brief Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \treturn Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. 
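+    /// For example, a minimal usage sketch (cache_client and sampler here are
+    /// assumed to already exist; they are illustrative, not part of this patch):
+    ///   std::shared_ptr<CacheLookupOp> lookup;
+    ///   RETURN_IF_NOT_OK(CacheLookupOp::Builder()
+    ///                      .SetClient(cache_client)
+    ///                      .SetSampler(std::move(sampler))
+    ///                      .Build(&lookup));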
+ Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheLookupOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t rows_per_buffer_; + int32_t build_op_connector_size_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + // Check if the required parameters are set by the builder. + // \return Status The error code return + Status SanityCheck() const; + }; + /// \brief Constructor + /// \note It takes the same argument as the base class. + /// \see CacheBase + CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf, + std::shared_ptr cache_client, std::shared_ptr sampler) + : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {} + ~CacheLookupOp() = default; + // As a parallel op, we override these two functions + Status operator()() override; + Status WorkerEntry(int32_t worker_id) override; + // As a sampler, we override the following functions + Status ResetSampler() override; + Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; + Status InitSampler() override; + Status GetNextSample(std::unique_ptr *out_buffer) override; + void Print(std::ostream &out, bool show_all) const override; + bool AllowCacheMiss() override { return true; } + std::string Name() const override { return "CacheLookupOp"; } + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p The node to visit + /// \param[out] modified Indicator if the node was modified + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + protected: + Status ComputeColMap() override; + + private: + WaitPost leaf_op_wp_; + + Status RegisterResources() override; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc new file mode 100644 index 0000000000..75579dc3a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc @@ -0,0 +1,302 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" + +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +CacheMergeOp::~CacheMergeOp() = default; +void CacheMergeOp::Print(std::ostream &out, bool show_all) + const { // Always show the id and name as first line regardless if this is summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\n\n"; + } +} +CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners, + std::shared_ptr cache_client, const std::shared_ptr &sampler) + : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {} +Status CacheMergeOp::operator()() { + // A queue of row id to let cleaner send cache miss rows to the cache server + // We don't want a small queue as this will block the parallel op workers. + // A row id is 8 byte integer. So bigger size doesn't consume a lot of memory. + static const int32_t queue_sz = 512; + io_que_ = std::make_unique>(queue_sz); + RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1))); + // One dedicated thread to move TensorRow from the pool to the cache server + for (auto i = 0; i < num_cleaners_; ++i) { + RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this))); + } + TaskManager::FindMe()->Post(); + return Status::OK(); +} +// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait +// until it shows up in the pool. +Status CacheMergeOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::shared_ptr cache_hit_stream = child_[kCacheHitChildIdx]; + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + db_ptr.reset(); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } else { + // See if there is any missing row + auto tbl = std::make_unique(); + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + if (row.empty()) { + auto row_id = row.getId(); + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + // Block until the row shows up in the pool. 
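+          // Handshake sketch: the cache-miss worker calls rq->WakeUpAny(row), which
+          // queues the row and V()s use_count_; Wait() P()s it and pops the row out.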
+ RETURN_IF_NOT_OK(rq->Wait(&row)); + } + tbl->push_back(std::move(row)); + } + db_ptr->set_tensor_table(std::move(tbl)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id)); + } + } + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr))); + return Status::OK(); +} +Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) { + TaskManager::FindMe()->Post(); + // We will simply pop TensorRow from the stream and insert them into the pool and + // wake up any worker that is awaiting on the missing TensorRow. + // If we see an eoe, ignore it. For eof, we exit. + std::shared_ptr cache_missing_stream = child_[kCacheMissChildIdx]; + // Before we start, cache the schema at the server. Pick one of the workers + // do it. The schema should have been done at prepare time. + if (workerId == 0) { + RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map())); + } + std::unique_ptr db_ptr; + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + while (!db_ptr->eof()) { + if (db_ptr->eoe()) { + // Ignore it. + MS_LOG(DEBUG) << "Ignore eoe"; + } else { + while (db_ptr->NumRows() > 0) { + TensorRow row; + RETURN_IF_NOT_OK(db_ptr->PopRow(&row)); + row_id_type row_id = row.getId(); + if (row_id < 0) { + std::string errMsg = "Expect positive row id: " + std::to_string(row_id); + RETURN_STATUS_UNEXPECTED(errMsg); + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + rq->WakeUpAny(std::move(row)); + // Let the cleaner to flush out this row (async) to the cache server. + RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id)); + } + } + RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId)); + } + return Status::OK(); +} +Status CacheMergeOp::Cleaner() { + TaskManager::FindMe()->Post(); + while (true) { + row_id_type row_id; + RETURN_IF_NOT_OK(io_que_->PopFront(&row_id)); + if (row_id < 0) { + break; + } + TensorRowRequest *rq = nullptr; + RETURN_IF_NOT_OK(GetRq(row_id, &rq)); + if (rq->GetState() == TensorRowRequest::State::kClean) { + // If already flushed, move on to the next one. + continue; + } + TensorRow row; + RETURN_IF_NOT_OK(rq->Release(&row)); + CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error."); + Status rc = cache_client_->WriteRow(row); + // Bad rc should not bring down the pipeline + if (rc.IsError()) { + MS_LOG(WARNING) << "Cache not successful." << rc.ToString(); + } + rq->SetState(TensorRowRequest::State::kClean); + } + return Status::OK(); +} + +Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) { + RETURN_UNEXPECTED_IF_NULL(out); + std::unique_lock lck(mux_); + auto it = cache_miss_map_.find(row_id); + if (it != cache_miss_map_.end()) { + *out = it->second.GetMutablePointer(); + } else { + // We will create a new one. 
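+    // The request lives in cache_miss_map_ under MemGuard, so the hit-stream worker
+    // and the miss-stream worker that race on GetRq() end up sharing one object.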
+ auto alloc = Services::GetAllocator(); + auto r = cache_miss_map_.emplace(row_id, MemGuard>(alloc)); + if (r.second) { + auto &mem = r.first->second; + RETURN_IF_NOT_OK(mem.allocate(1, row_id)); + *out = mem.GetMutablePointer(); + } else { + RETURN_STATUS_UNEXPECTED("Map insert fail."); + } + } + return Status::OK(); +} +Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own + // specific logic + CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children"); + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + // Get the computed check sum from all ops in the cache miss class + uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]); + // This is a mappable cache op so the id's need to be generated. + // Construct the cache + const bool generate_ids = false; + Status rc = cache_client_->CreateCache(cache_crc, generate_ids); + if (rc.get_code() == StatusCode::kDuplicateKey) { + // We are told the cache has been created already. + MS_LOG(INFO) << "Cache created already"; + rc = Status::OK(); + } + RETURN_IF_NOT_OK(rc); + return Status::OK(); +} +Status CacheMergeOp::ComputeColMap() { + CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty"); + if (column_name_id_map().empty()) { + column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map(); + } + CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected"); + return Status::OK(); +} +Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) { + RETURN_UNEXPECTED_IF_NULL(out); + // Block until the missing row is in the pool. + RETURN_IF_NOT_OK(use_count_.P()); + std::unique_lock lck(dq_mux_); + CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error"); + *out = std::move(row_.front()); + row_.pop_front(); + return Status::OK(); +} +void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) { + std::unique_lock lck(dq_mux_); + // Technically number of this row shows up in the cache miss stream is equal to the number + // of P() call. However the cleaner wants it too. So we need an extra copy. + if (GetState() == State::kEmpty) { + // We will do a deep copy + for (auto &ts : row) { + auto out_ts = std::make_shared(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes()); + cleaner_copy_.push_back(out_ts); + } + cleaner_copy_.setId(row.getId()); + // Change the state to dirty + SetState(State::kDirty); + } + row_.push_back(std::move(row)); + // Bump up the use count by 1. This wake up any parallel worker which is waiting + // for this row. + use_count_.V(); +} +Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) { + RETURN_UNEXPECTED_IF_NULL(out); + // We are not holding any mutex here because the cleaner isn't really touching the deque row_. + // In case we have multiple cleaners and they all see the copy, only one of them will + // get it. + auto expected = State::kDirty; + if (st_.compare_exchange_strong(expected, State::kClean)) { + *out = std::move(cleaner_copy_); + } + return Status::OK(); +} +// Builder constructor. Creates the builder object. +CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_workers_ = cfg->num_parallel_workers(); + build_op_connector_size_ = cfg->op_connector_size(); + build_num_cleaners_ = 1; +} + +// Check if the required parameters are set by the builder. 
+Status CacheMergeOp::Builder::SanityCheck() const { + if (build_cache_client_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient"); + } + // Make sure the cache client has a valid session + if (!build_cache_client_->session_id()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Cache client for CacheMergeOp is missing session id"); + } + return Status::OK(); +} + +// The builder "build" method creates the final object and does some init on it +Status CacheMergeOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_num_workers_, build_op_connector_size_, build_num_cleaners_, + build_cache_client_, build_sampler_); + return Status::OK(); +} + +// Pre-Visitor accept method for NodePass +Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) { + // Downcast shared pointer then call the pre-visitation + return p->PreRunOnNode(shared_from_base(), modified); +} + +// Visitor accept method for NodePass +Status CacheMergeOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CacheMergeOp::EoeReceived(int32_t worker_id) { + // If we are in a repeat path, send the eoe up. + // Otherwise ignore it. + if (BitTest(op_ctrl_flags_, kDeOpRepeated)) { + return DatasetOp::EoeReceived(worker_id); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h new file mode 100644 index 0000000000..df37465fc4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/semaphore.h" + +namespace mindspore { +namespace dataset { +/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single +/// stream +class CacheMergeOp : public ParallelOp { + public: + // Some handshake structures among the main thread, cleaner threads and parallel op threads. + class TensorRowRequest { + public: + enum class State : uint8_t { + kEmpty = 0, // No row in the deque + kDirty = 1, // Cleaner hasn't flushed it to the cache server yet. + kClean = 2 // The row has been flushed already. 
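+      // Lifecycle: kEmpty -> kDirty when WakeUpAny() stores the cleaner's deep copy,
+      // then kDirty -> kClean once a Cleaner thread flushes the copy to the cache server.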
+ }; + explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {} + ~TensorRowRequest() = default; + State GetState() const { return st_; } + void SetState(State newState) { st_ = newState; } + Status Wait(TensorRow *out); + void WakeUpAny(TensorRow &&row); + Status Release(TensorRow *out); + + private: + std::mutex dq_mux_; + std::atomic st_; + Semaphore use_count_; + std::deque row_; + TensorRow cleaner_copy_; + }; + + constexpr static int kCacheHitChildIdx = 0; // Cache hit stream + constexpr static int kCacheMissChildIdx = 1; // Cache miss stream + + /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of + /// the arguments for constructing it. Use the builder by setting each argument + /// with the provided set methods, and then finally call the build method to execute + /// the actual construction. + class Builder { + public: + /// Builder constructor. Creates the builder object. + /// \note No default args + Builder(); + + /// Default destructor + ~Builder() = default; + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + build_num_workers_ = num_workers; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + build_op_connector_size_ = connector_size; + return *this; + } + + /// Setter method. + /// \return Builder setter method returns reference to the builder. + Builder &SetClient(std::shared_ptr cache_client) { + build_cache_client_ = cache_client; + return *this; + } + + /// \brief Setter method + /// \param sampler + /// \return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + build_sampler_ = std::move(sampler); + return *this; + } + + /// \brief Setter method + /// \param num_cleaners + /// \return Builder setter method returns reference to the builder. + Builder &SetNumCleaner(int32_t num_cleaners) { + build_num_cleaners_ = num_cleaners; + return *this; + } + + /// The builder "build" method creates the final object and does some init on it. + /// \param ptr The shared_ptr to the new CacheMergeOp object + /// \return Status + Status Build(std::shared_ptr *ptr); + + private: + int32_t build_num_workers_; + int32_t build_op_connector_size_; + int32_t build_num_cleaners_; + std::shared_ptr build_cache_client_; + std::shared_ptr build_sampler_; + + /// Check if the required parameters are set by the builder. 
+
+  /// \brief Constructor
+  /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
+  /// \param opConnectorSize Connector size as a derived class of ParallelOp
+  /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
+  /// \param cache_client CacheClient to communicate with the cache server
+  /// \param sampler The sampler, as a derived class of ParallelOp
+  CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
+               std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
+  ~CacheMergeOp();
+  void Print(std::ostream &out, bool show_all) const override;
+  friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
+    mo.Print(out, false);
+    return out;
+  }
+  /// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and
+  ///     the threads for the cleaners.
+  /// \return Status object
+  Status operator()() override;
+  /// \brief Entry function for worker thread that fetches rows from the CacheLookupOp
+  /// \param workerId
+  /// \return Status object
+  Status WorkerEntry(int32_t workerId) override;
+  Status PrepareNodePostAction() override;
+  /// \brief Entry function for worker thread that fetches rows from the cache miss stream
+  /// \param workerId
+  /// \return Status object
+  Status CacheMissWorkerEntry(int32_t workerId);
+  Status GetRq(row_id_type row_id, TensorRowRequest **);
+
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for eoe handling
+  /// \param worker_id
+  /// \return Status object
+  Status EoeReceived(int32_t worker_id) override;
+
+ protected:
+  Status ComputeColMap() override;
+
+ private:
+  std::mutex mux_;
+  std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
+  std::unique_ptr<Queue<row_id_type>> io_que_;
+  std::shared_ptr<CacheClient> cache_client_;
+  int32_t num_cleaners_;
+
+  /// \brief These are the entry functions for the cleaner threads. Each cleaner is responsible for
+  ///     moving cache miss TensorRow into the CacheServer.
+  /// \return Status object
+  Status Cleaner();
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
new file mode 100644
index 0000000000..143c45b2dc
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
@@ -0,0 +1,219 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+
+#include <memory>
+#include <vector>
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  build_num_workers_ = cfg->num_parallel_workers();
+  rows_per_buffer_ = cfg->rows_per_buffer();
+  build_op_connector_size_ = cfg->op_connector_size();
+}
+
+// Check if the required parameters are set by the builder.
+Status CacheOp::Builder::SanityCheck() const {
+  if (build_cache_client_ == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
+  }
+  // Make sure the cache client has a valid session
+  if (!build_cache_client_->session_id()) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object and does some init on it
+Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
+                                   build_cache_client_, build_sampler_);
+  RETURN_IF_NOT_OK((*ptr)->InitCache());
+
+  return Status::OK();
+}
+
+// Constructor of CacheOp
+CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
+                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
+    : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
+      num_guys_in_(0),
+      phase_(Phase::kBuildPhase) {}
+
+// Destructor
+CacheOp::~CacheOp() = default;
+
+// Private function for cache setup/init work just after construction
+Status CacheOp::InitCache() { return Status::OK(); }
+
+// This class functor will provide the master loop that drives the logic for performing the work
+Status CacheOp::operator()() {
+  if (!sampler_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                  "CacheOp requires a sampler before it can be executed!");
+  }
+  RETURN_IF_NOT_OK(RegisterResources());
+  // Kick off the workers
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
+  // Required task group sync after launching workers
+  TaskManager::FindMe()->Post();
+  // Wait for the workers to finish caching the rows.
+  RETURN_IF_NOT_OK(WaitForCachingAllRows());
+  RETURN_IF_NOT_OK(FetchSamplesToWorkers());
+  return Status::OK();
+}
+
+Status CacheOp::CacheAllRows(int32_t worker_id) {
+  // If the current phase is to fill the cache, do it then.
+  if (phase_ == Phase::kBuildPhase) {
+    // We will take the chance to cache the schema at the server.
+    // Just do it once and pick one worker to do it.
+    if (worker_id == 0) {
+      RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
+    }
+    MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
+    // SAVE mode loop
+    std::unique_ptr<DataBuffer> db_ptr;
+    RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
+    while (!db_ptr->eof()) {
+      if (!db_ptr->eoe()) {
+        RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
+      } else {
+        // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
+        // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got
+        // the eoe to indicate the end of the epoch, we should next expect to get the eof.
+        // Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch
+        // from again.
+        RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
+        if (!db_ptr->eof()) {
+          RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
+        }
+      }
+      RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
+    }
+  }
+  // Let the main guy know we are done.
+  auto last_guy_in = num_guys_in_.fetch_add(1);
+  if ((last_guy_in + 1) == num_workers_) {
+    rows_cache_done_.Set();
+  } else {
+    // Let's do a sync up here.
+    RETURN_IF_NOT_OK(rows_cache_done_.Wait());
+  }
+  return Status::OK();
+}
+
+Status CacheOp::WaitForCachingAllRows() {
+  // Wait for the workers to finish caching the rows.
+  RETURN_IF_NOT_OK(rows_cache_done_.Wait());
+  // Move from build phase to fetch phase if we are the one to fill the cache
+  if (phase_ == Phase::kBuildPhase) {
+    RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
+    // Move to the next phase
+    phase_ = Phase::kFetchPhase;
+  }
+  // Get statistics from the server, and if we are not the one to create the cache,
+  // wait until the state has changed from build phase to fetch phase.
+  CacheClient::ServiceStat stat{};
+  bool BuildPhaseDone = true;
+  do {
+    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
+    BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
+    if (!BuildPhaseDone) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+  } while (!BuildPhaseDone);
+  const row_id_type min_key = stat.min_row_id;
+  const row_id_type max_key = stat.max_row_id;
+  num_rows_ = max_key - min_key + 1;
+  MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
+  MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
+  MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
+  // Now all rows are cached and we have done a sync point check up. The next phase
+  // is to pick up fetch input from the sampler and pass it up to the caller.
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
+  return Status::OK();
+}
+
+Status CacheOp::WorkerEntry(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(CacheAllRows(worker_id));
+  RETURN_IF_NOT_OK(FetchFromCache(worker_id));
+  return Status::OK();
+}
+
+Status CacheOp::RegisterResources() {
+  RETURN_IF_NOT_OK(CacheBase::RegisterResources());
+  RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
+  return Status::OK();
+}
+
+// Base-class override for setting specific CacheOp configurations. This code will be called
+// during the execution tree prepare phase BEFORE traversing down to child operators.
+uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
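+// How a pipeline might stack this op (illustrative sketch only, not part of the original
+// patch; leaf_op, cache_client, sampler and tree are hypothetical names the caller provides):
+//   std::shared_ptr<CacheOp> cache_op;
+//   RETURN_IF_NOT_OK(CacheOp::Builder()
+//                      .SetNumWorkers(4)
+//                      .SetClient(cache_client)
+//                      .SetSampler(sampler)
+//                      .Build(&cache_op));
+//   RETURN_IF_NOT_OK(tree->AssociateNode(cache_op));
+//   RETURN_IF_NOT_OK(cache_op->AddChild(leaf_op));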
+// Base-class override for special eoe handler.
+// CacheOp must override this because it shall not perform default handling of eoe. Instead
+// the CacheOp manages actions related to the end of the epoch.
+Status CacheOp::EoeReceived(int32_t worker_id) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+// Base-class override for handling cases when an eof is received.
+Status CacheOp::EofReceived(int32_t worker_id) {
+  // EofReceived is overridden because we want to manually handle this eof.
+  // Specifically, the default behaviour is to pack it and flow it up to the next connection.
+  // In this case, we want a no-op behaviour so that we can perform the correct action.
+  return Status::OK();
+}
+
+// Pre-Visitor accept method for NodePass
+Status CacheOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call the pre-visitation
+  return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
+}
+
+// Visitor accept method for NodePass
+Status CacheOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CacheOp>(), modified);
+}
+
+// A public wrapper for creating the cache through the client
+Status CacheOp::CreateCache(uint32_t cache_crc) {
+  // This is a non-mappable cache op so the id's need to be generated.
+  // Construct the cache
+  const bool generate_ids = true;
+  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
+  if (rc.get_code() == StatusCode::kDuplicateKey) {
+    // We are told the cache has been created already. So we skip the build phase.
+    phase_ = Phase::kFetchPhase;
+    rc = Status::OK();
+  }
+  RETURN_IF_NOT_OK(rc);
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h
new file mode 100644
index 0000000000..dd34d54973
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.h
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/cache_base_op.h"
+
+namespace mindspore {
+namespace dataset {
+/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
+/// \note For mappable dataset, please see CacheLookupOp.
+/// \see CacheLookupOp
+class CacheOp : public CacheBase, public RandomAccessOp {
+ public:
+  // This CacheOp is for the non-mappable case where it is divided into two phases.
+  // The first phase is where we cache all the rows from the child (and let the cache server
+  // assign row ids). There is no read access in the first phase. Once the cache is fully built,
+  // we switch to the second phase and fetch requests from the sampler.
+  enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
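+  // Descriptive note (not from the original patch): the op starts in kBuildPhase; once every
+  // worker has pushed its rows to the server, WaitForCachingAllRows() flips it to kFetchPhase,
+  // and from then on WorkerEntry() serves rows back out of the cache via FetchFromCache().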
+
+  /// \brief The nested builder class inside of the CacheOp is used to help manage all of
+  ///     the arguments for constructing it. Use the builder by setting each argument
+  ///     with the provided set methods, and then finally call the build method to execute
+  ///     the actual construction.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    /// \brief Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      build_num_workers_ = num_workers;
+      return *this;
+    }
+
+    /// \brief Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t connector_size) {
+      build_op_connector_size_ = connector_size;
+      return *this;
+    }
+
+    /// Setter method.
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
+      build_cache_client_ = cache_client;
+      return *this;
+    }
+
+    /// \brief Setter method
+    /// \param rows_per_buffer
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    /// \brief Setter method
+    /// \param sampler
+    /// \return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      build_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    /// \brief The builder "build" method creates the final object and does some init on it.
+    /// \param ptr The shared_ptr to the new CacheOp object
+    /// \return Status
+    Status Build(std::shared_ptr<CacheOp> *ptr);
+
+   private:
+    int32_t build_num_workers_;
+    int32_t rows_per_buffer_;
+    int32_t build_op_connector_size_;
+    std::shared_ptr<CacheClient> build_cache_client_;
+    std::shared_ptr<Sampler> build_sampler_;
+
+    /// \brief Check if the required parameters are set by the builder.
+    /// \return Status The error code return
+    Status SanityCheck() const;
+  };
+
+  /// \brief Constructor of CacheOp
+  /// \note The builder class should be used to call it.
+  /// \param num_workers The number of worker threads.
+  /// \param op_connector_size The size of each queue in the connector.
+  CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
+          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
+
+  // Destructor
+  ~CacheOp();
+
+  /// \brief Base-class override for setting specific CacheOp configurations. This code will be called
+  ///     during the execution tree prepare phase BEFORE traversing down to child operators.
+  uint32_t PrepareFlags() const override;
+  /// \brief Base-class override for special eoe handler.
+  ///     CacheOp must override this because it shall not perform default handling of eoe. Instead
+  ///     the CacheOp manages actions related to the end of the epoch.
+  /// \return Status - The error code return
+  Status EoeReceived(int32_t worker_id) override;
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+  /// \brief Base-class override for handling cases when an eof is received.
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  Status EofReceived(int32_t worker_id) override;
+  Status operator()() override;
+  Status WorkerEntry(int32_t worker_id) override;
+  /// \brief Base-class override for handling cases if we allow cache miss
+  bool AllowCacheMiss() override { return false; }
+  /// \brief Base-class override for the name of this operator
+  std::string Name() const override { return "CacheOp"; }
+  /// \brief A public wrapper for creating the cache through the client
+  /// \param[in] cache_crc The crc that identifies the cache
+  /// \see cache_pass.cc
+  /// \return Status return code
+  Status CreateCache(uint32_t cache_crc);
+
+ private:
+  WaitPost rows_cache_done_;
+  std::atomic<int32_t> num_guys_in_;
+  Phase phase_;
+  /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler.
+  /// \return Status object
+  Status WaitForCachingAllRows();
+  /// \brief For non-mappable dataset, there is a build phase where we cache all the rows.
+  /// \return Status object
+  Status CacheAllRows(int32_t worker_id);
+  Status RegisterResources() override;
+  /// \brief Private function for cache setup/init work just after construction
+  /// \return Status The error code return
+  Status InitCache();
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
new file mode 100644
index 0000000000..7acb68350b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.cc
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iomanip>
+#include <utility>
+
+#include "common/utils.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/datasetops/concat_op.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+ConcatOp::Builder::Builder() {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+// The builder "build" method creates the final object.
+Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
+  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
+  return Status::OK();
+}
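+// Typical usage (illustrative sketch only, not part of the original patch; dataset_a and
+// dataset_b are hypothetical child ops): concatenate two datasets by building the op and
+// attaching both children in order:
+//   std::shared_ptr<ConcatOp> concat_op;
+//   RETURN_IF_NOT_OK(ConcatOp::Builder().Build(&concat_op));
+//   RETURN_IF_NOT_OK(concat_op->AddChild(dataset_a));
+//   RETURN_IF_NOT_OK(concat_op->AddChild(dataset_b));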
+// Constructor of the ConcatOp.
+ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
+
+// A function that prints info about the Operator
+void ConcatOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as first line regardless if this is summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") :";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nDatasets: " << children_num_ << "\n\n";
+  }
+}
+
+// Main entry point for Concat
+Status ConcatOp::operator()() {
+  // The children_num_ parameter needs to be put here
+  children_num_ = static_cast<int32_t>(child_.size());
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<DataBuffer> buf;
+  int eof_count = 0;
+  while (eof_count == 0) {
+    for (int i = 0; i < children_num_; i++) {
+      // 1. Read the first buffer
+      RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
+      if (buf->eof()) {
+        eof_count++;
+        continue;
+      }
+      // 2. Verify the column name, column data type and rank of the column data
+      if (!buf->eoe()) {
+        RETURN_IF_NOT_OK(Verify(i, buf));
+      }
+      // 3. Put the data into the output connector
+      while (!buf->eoe() && !buf->eof()) {
+        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
+        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
+      }
+    }
+    // 4. Add an eoe buffer after getting buffers from all children
+    if (eof_count == 0) {
+      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
+    }
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
+                               "Something went wrong, eof count does not match the number of children.");
+  // 5. Add the eof buffer at the end manually
+  MS_LOG(DEBUG) << "Add the eof buffer manually at the end.";
+  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
+  return Status::OK();
+}
+
+Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
+  TensorRow new_row;
+  buf->GetRow(0, &new_row);
+
+  if (id == 0) {
+    // Obtain the data type and data rank in child[0]
+    for (auto item : new_row) {
+      data_type_.push_back(item->type());
+      data_rank_.push_back(item->Rank());
+    }
+  } else {
+    // Compare the data type and data rank with those in child[0]
+    int32_t index = 0;
+    for (auto item : new_row) {
+      if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
+        RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same as in the previous dataset.");
+      }
+    }
+  }
+  return Status::OK();
+}
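+// Descriptive note (not from the original patch): Verify() memorizes the per-column type and
+// rank of the first child (e.g. a float32 rank-3 image column plus an int32 rank-1 label
+// column), and every later child must match column-for-column or the concat returns an error.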
+
+// We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
+Status ConcatOp::ComputeColMap() {
+  if (column_name_id_map_.empty()) {
+    // Obtain columns_name_id_map from child_[0]
+    column_name_id_map_ = child_[0]->column_name_id_map();
+    if (column_name_id_map_.empty()) {
+      RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
+    }
+    // Verify all children have the same column name map
+    for (int32_t i = 0; i < child_.size(); ++i) {
+      if (child_[i]->column_name_id_map() != column_name_id_map_) {
+        RETURN_STATUS_UNEXPECTED("The column name or column order is not the same as in the previous dataset.");
+      }
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
new file mode 100644
index 0000000000..3d3d9df71c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/concat_op.h
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
+#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class ConcatOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the ConcatOp is used to help manage all of the arguments
+  // for constructing it. This Concat op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new ConcatOp object
+    Status Build(std::shared_ptr<ConcatOp> *);
+
+   private:
+    int32_t builder_op_connector_size_;
+  };
+
+  // Constructor of the ConcatOp.
+  // @note The builder class should be used to call it
+  // @param op_connector_size - connector size
+  explicit ConcatOp(int32_t op_connector_size);
+
+  // Destructor
+  ~ConcatOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param ro - reference to the ConcatOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ConcatOp"; }
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+ private:
+  Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
+
+  int32_t children_num_;                                     // The number of children of the parent node.
+  std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
+  std::vector<DataType> data_type_;
+  std::vector<dsize_t> data_rank_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
new file mode 100644
index 0000000000..9254141308
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
@@ -0,0 +1,391 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+#include <algorithm>
+#include <iomanip>
+#include <memory>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <utility>
+
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/datasetops/device_queue_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "utils/system/crc32c.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Constructor
+DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : oc_queue_size_(op_connector_size),
+      sampler_(sampler),
+      operator_id_(kInvalidOperatorId),
+      tree_(nullptr),
+      state_(OpState::kDeOpIdle),
+      op_ctrl_flags_(kDeOpNone),
+      out_connector_(nullptr) {
+  // The operator starts out with an invalid operator id. The only way to
+  // get it out of the invalid state is to assign the operator to an execution tree.
+}
+
+// Adds an operator to become our child.
+Status DatasetOp::AddChild(std::shared_ptr<DatasetOp> child) {
+  if (std::dynamic_pointer_cast<DeviceQueueOp>(child) != nullptr) {
+    std::string err_msg("DeviceQueueOp cannot be added as a child, DeviceQueueOp must be a root node");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  if (operator_id_ == kInvalidOperatorId) {
+    std::string err_msg(
+      "Cannot add child node. Tree node connections can only "
+      "be made if the node belongs to a tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // disallow relationships with other trees
+  if (tree_ != child->tree_) {
+    std::string err_msg(
+      "Cannot add child node. Tree node connections can only be made if both nodes belong to the same tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  child_.push_back(child);
+  child->AddParent(this);
+  return Status::OK();
+}
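+// Illustrative sketch (not part of the original patch; map_op and leaf_op are hypothetical
+// ops the caller built): wiring a two-level tree, where both ops must already belong to the
+// same ExecutionTree for the checks above to pass:
+//   RETURN_IF_NOT_OK(tree->AssociateNode(map_op));
+//   RETURN_IF_NOT_OK(tree->AssociateNode(leaf_op));
+//   RETURN_IF_NOT_OK(map_op->AddChild(leaf_op));
+//   RETURN_IF_NOT_OK(tree->AssignRoot(map_op));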
+
+Status DatasetOp::RemoveChild(std::shared_ptr<DatasetOp> child) {
+  if (operator_id_ == kInvalidOperatorId) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only "
+      "be made if the node belongs to a tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // disallow relationships with other trees
+  if (tree_ != child->tree_) {
+    std::string err_msg(
+      "Cannot remove child node. Tree node connections can only be made if both nodes belong to the same tree.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.erase(std::remove(child_.begin(), child_.end(), child), child_.end());
+  child->RemoveParent(this);
+  return Status::OK();
+}
+
+Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
+  for (auto &prev_parent : this->parent_) {
+    RETURN_IF_NOT_OK(prev_parent->RemoveChild(shared_from_this()));
+    RETURN_IF_NOT_OK(prev_parent->AddChild(to_add));
+  }
+  RETURN_IF_NOT_OK(to_add->AddChild(shared_from_this()));
+  if (tree_->root()->id() == this->id()) {
+    tree_->AssignRoot(to_add);
+  }
+  return Status::OK();
+}
+
+// Adds a parent operator to this operator
+void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
+
+// Removes a parent operator from this operator
+void DatasetOp::RemoveParent(const DatasetOp *parent) {
+  parent_.erase(std::remove(parent_.begin(), parent_.end(), parent), parent_.end());
+}
+
+// Removes this node from the tree and connects its parent/child together
+Status DatasetOp::Remove() {
+  if (parent_.size() > 1) {
+    std::string err_msg("No support for op removal if the operator has more than one parent");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  if (child_.size() > 1) {
+    std::string err_msg("No support for op removal if the operator has more than one child");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Scenarios when removing node B:
+  // A -> B -> C
+  // A -> B
+  // B -> C
+  //
+  // If we remove B, then first take our child C and update its parent to be A.
+  // It's possible the parent is null if we are the root node being removed.
+  if (!child_.empty()) {
+    // If we have a parent, then assign the child's parent to point to our parent.
+    if (!parent_.empty()) {
+      child_[0]->parent_[0] = parent_[0];
+    } else {
+      // We don't have a parent, so we are the root node being removed.
+      // Clear the parent list of our child so that it becomes the new root.
+      child_[0]->parent_.clear();
+      tree_->AssignRoot(child_[0]);
+    }
+  }
+
+  // Next, if we had a parent, then set its child to be our child.
+  if (!parent_.empty()) {
+    // If we have a child, then point our parent's child at it.
+    if (!child_.empty()) {
+      parent_[0]->child_[0] = child_[0];
+    } else {
+      // We don't have a child, so clear the child list of the current
+      // parent because it will be empty once we are removed.
+      parent_[0]->child_.clear();
+    }
+  }
+
+  // Finally, clear "this" op's parent and child pointers since we have just
+  // disconnected it from the tree and invalidate its fields.
+  child_.clear();
+  parent_.clear();
+  operator_id_ = kInvalidOperatorId;
+  tree_ = nullptr;
+
+  return Status::OK();
+}
+
+// Getter function to get a shared pointer to our child
+std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
+  std::shared_ptr<DatasetOp> return_op = nullptr;
+  if (child_.empty()) {
+    return return_op;
+  }
+  MS_ASSERT(child_index < static_cast<int32_t>(child_.size()));
+  // Return a shared pointer
+  return child_[child_index];
+}
+
+// Getter function to get the parent pointer
+void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
+  if (parent_.empty()) {
+    // common case if this is a root node
+    *parent = nullptr;
+  } else {
+    MS_ASSERT(parent_index < static_cast<int32_t>(parent_.size()));
+    *parent = parent_[parent_index];
+  }
+}
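+// Descriptive note (not from the original patch): for a chain A -> B -> C, calling
+// B->Remove() re-links the tree to A -> C; removing the root instead promotes its only
+// child to become the new root via AssignRoot().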
+
+// Creates the connector within this operator
+void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
+  MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
+                << ". Consumer: " << num_consumers << ".";
+  if (oc_queue_size_ > 0) {
+    out_connector_ = std::make_unique<DbConnector>(num_producers,  // The number of producers
+                                                   num_consumers,  // Only one consumer (the training App)
+                                                   oc_queue_size_);
+  } else {
+    // Some ops may choose not to have an output connector
+    MS_LOG(DEBUG) << "Bypassed connector creation for tree operator: " << operator_id_ << ".";
+    out_connector_ = nullptr;
+  }
+}
+
+// A print method typically used for debugging. showAll of true will recursively descend to child prints
+void DatasetOp::Print(std::ostream &out, bool show_all) const {
+  // When show_all is false, we display a 1 liner piece of text for the op.
+  // When show_all is true, we display more detailed output for the op.
+  // Derived printers should show their own header info, then call base class printer, followed by
+  // derived-specific items.
+  // For now, the base class doesn't have any summary info to show so it's a no-op in that case.
+  if (show_all) {
+    // The detailed display will show common base class info of the op. Allow the derived class to print
+    // its own id and name though as the first line.
+    out << "\nNumber of children     : " << child_.size();
+    for (size_t i = 0; i < child_.size(); i++) {
+      out << "\n  Child[" << i << "] id: " << child_[i]->id();
+    }
+    out << "\nNumber of parents      : " << parent_.size();
+    for (size_t i = 0; i < parent_.size(); i++) {
+      out << "\n  Parent[" << i << "] id: " << parent_[i]->id();
+    }
+    out << "\nConnector queue size   : " << oc_queue_size_ << "\nOperator control flags : 0x" << std::hex
+        << std::setw(8) << std::setfill('0') << op_ctrl_flags_ << std::dec << std::setfill(' ');
+    if (sampler_) {
+      sampler_->Print(out, show_all);
+    }
+  }
+}
+
+// Gets the next buffer from the given child
+Status DatasetOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
+#if defined(_WIN32) || defined(_WIN64)
+  RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), p_buffer, retry_if_eoe));
+#else
+  std::unique_ptr<DataBuffer> next_buff;
+  // pop is a blocking call and will throw an interruption if the whole group shuts down.
+  RETURN_IF_NOT_OK(out_connector_->PopWithRetry(static_cast<int>(worker_id), &next_buff, retry_if_eoe));
+
+  *p_buffer = std::move(next_buff);
+#endif
+  return Status::OK();
+}
+
+// Gets the next buffer from the given child. This function also has built-in eoe and eof
+// message handling so that child classes don't have to manually code pass-through logic when
+// those messages are received.
+Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, int32_t child_index) {
+  if (child_.size() == 0) {
+    return this->GetNextBuffer(p_buffer, worker_id);
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(child_index < child_.size(), "Child index too big : " + std::to_string(child_index));
+  std::shared_ptr<DatasetOp> child = child_[child_index];
+  std::unique_ptr<DataBuffer> buf;
+  RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
+  // Loop until a non-eoe buffer is received
+  while (buf->eoe()) {
+    RETURN_IF_NOT_OK(EoeReceived(worker_id));
+    if (state_ == OpState::kDeOpIdle) {
+      *p_buffer = std::move(buf);
+      return Status::OK();
+    }
+    RETURN_IF_NOT_OK(child->GetNextBuffer(&buf, worker_id));
+  }
+  // Check if the last buffer is an eof
+  if (buf->eof()) {
+    RETURN_IF_NOT_OK(EofReceived(worker_id));
+  }
+  *p_buffer = std::move(buf);
+  return Status::OK();
+}
+
+// Performs handling for when an eoe message is received.
+// The base class implementation simply flows the eoe message to output. Derived classes
+// may override if they need to perform special eoe handling.
+Status DatasetOp::EoeReceived(int32_t worker_id) {
+  std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+  return (out_connector_->Add(static_cast<int>(worker_id), std::move(eoe_buffer)));
+}
+
+// Performs handling for when an eof message is received.
+// The base class implementation simply flows the eof message to output. Derived classes
+// may override if they need to perform special eof handling.
+Status DatasetOp::EofReceived(int32_t worker_id) {
+  std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+  return (out_connector_->Add(static_cast<int>(worker_id), std::move(eof_buffer)));
+}
+
+// During tree prepare phase, operators may have specific pre-operations to perform depending on
+// their role.
+Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
+
+// During tree prepare phase, operators may have specific post-operations to perform depending on
+// their role.
+Status DatasetOp::PrepareNodePostAction() {
+  // Creating Connector object for each op.
+  // The consumer of the root node is assumed to be one thread.
+  // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
+  if (parent_.empty()) {
+    this->CreateConnector(num_producers(), 1);
+  } else {
+    this->CreateConnector(num_producers(), parent_[0]->num_consumers());
+  }
+  if (out_connector_) {
+    RETURN_IF_NOT_OK(out_connector_->Register(tree_->AllTasks()));
+  }
+  RETURN_IF_NOT_OK(this->RegisterWorkerConnectors());
+
+  // Generate the column name map for the current op.
+  RETURN_IF_NOT_OK(this->ComputeColMap());
+
+  return Status::OK();
+}
+
+// Getter function. Base class does not have any special flags setting.
+uint32_t DatasetOp::PrepareFlags() const { return ExecutionTree::kDePrepNone; }
+
+// Derived classes may implement the reset function if the operator is stateful and needs
+// specific reset handling that is not contained in this common code version of the reset.
+Status DatasetOp::Reset() {
+  state_ = OpState::kDeOpRunning;
+  return Status::OK();
+}
+
+// Gives a string output for the column map for handy debug printing
+std::string DatasetOp::ColumnNameMapAsString() const {
+  std::string outStr = "Column name id map: ";
+  for (auto &it : column_name_id_map_) {
+    outStr += (" " + it.first + ":" + std::to_string(it.second));
+  }
+  return outStr;
+}
+
+// Computing the assignment of the column name map.
+// This just inherits the column map from its first child, can only be used if the number of children is 1.
+// Operations changing the column map must overwrite this function.
+Status DatasetOp::ComputeColMap() {
+  if (child_.size() > 1) {
+    RETURN_STATUS_UNEXPECTED("Assigning column name map from child only works for single-child operators.");
+  }
+  if (column_name_id_map_.empty()) {
+    column_name_id_map_ = child_[0]->column_name_id_map();
+    if (column_name_id_map_.empty()) {
+      RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
+    }
+    MS_LOG(DEBUG) << "Setting column map:\n" << DatasetOp::ColumnNameMapAsString();
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
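+// Descriptive note (not from the original patch): Accept/PreAccept below implement the
+// visitor double dispatch. Each concrete op calls RunOnNode(shared_from_base<Self>(), ...)
+// so the pass sees the most-derived type; these base versions only fire when a subclass did
+// not provide an override, handing the pass a plain DatasetOp pointer instead.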
+
+Status DatasetOp::PreAccept(NodePass *p, bool *modified) {
+  // DatasetOp is the base class of visitor target pre-visit.
+  // This method will only be called if its derived class does not implement one.
+  return p->PreRunOnNode(shared_from_this(), modified);
+}
+
+Status DatasetOp::Accept(NodePass *p, bool *modified) {
+  // DatasetOp is the base class of visitor target.
+  // This method will only be called if its derived class does not implement one.
+  return p->RunOnNode(shared_from_this(), modified);
+}
+
+// Getter for the sampler, and it also removes the sampler from the op
+Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
+  *sampler = sampler_;  // It's okay if sampler_ points to nullptr
+  sampler_.reset();     // Clear our member-copy of this pointer. We no longer have this sampler
+  return Status::OK();
+}
+
+uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
+  std::stringstream ss;
+  op->tree_->Print(ss, op);
+  std::string ss_str = ss.str();
+
+  // Filter out the Operator control flags field when generating the check sum
+  ss_str = std::regex_replace(ss_str, std::regex("Operator control flags.*\n"), "");
+
+  // Filter out the Device id field to allow cache sharing for a distributed run of the same pipeline
+  ss_str = std::regex_replace(ss_str, std::regex("Device id.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("device_id.*\n"), "");
+
+  // The Cache crc and Server cache id fields are different when creating a new cache_client and re-using the same
+  // cache_client later. So we filter out these two fields to allow cache sharing.
+  ss_str = std::regex_replace(ss_str, std::regex("Cache crc.*\n"), "");
+  ss_str = std::regex_replace(ss_str, std::regex("Server cache id.*\n"), "");
+
+  uint32_t cache_crc = system::Crc32c::GetMaskCrc32cValue(ss_str.c_str(), ss_str.length());
+  return cache_crc;
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
new file mode 100644
index 0000000000..b4630c1652
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.h
@@ -0,0 +1,363 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
+#define DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
+
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Forward declare
+class ExecutionTree;
+
+class DataBuffer;
+
+class NodePass;
+
+class Sampler;
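+
+// Illustrative sketch (not part of the original patch; NoopOp is a hypothetical name): the
+// smallest useful subclass overrides the pure-virtual functor plus the three thread-count
+// getters declared below:
+//   class NoopOp : public DatasetOp {
+//    public:
+//     explicit NoopOp(int32_t connector_size) : DatasetOp(connector_size, nullptr) {}
+//     Status operator()() override { return Status::OK(); }
+//     int32_t num_workers() const override { return 1; }
+//     int32_t num_consumers() const override { return 1; }
+//     int32_t num_producers() const override { return 1; }
+//   };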
+
+/// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
+/// the actual implementation of the operators will be derived from here.
+class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
+  // Allow execution tree to access internal members
+  friend class ExecutionTree;
+
+ public:
+  static constexpr int32_t kInvalidOperatorId = -1;
+
+  // Operator control flags
+  enum OpControlFlags {
+    kDeOpNone = 0,
+    kDeOpRepeated = 1,        // Operator is a node in a repeat path
+    kDeOpLastRepeat = 1 << 1  // We are in the last repeat loop
+  };
+
+  // Flags that control operator runtime behaviours
+  enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };
+
+  /// Constructor
+  /// \param op_connector_size - The size for the output connector of this operator.
+  /// \param sampler - The sampler for the op
+  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
+
+  /// Destructor
+  virtual ~DatasetOp() { tree_ = nullptr; }
+
+  /// Adds an operator to become our child.
+  /// \param child - shared pointer to the child to add.
+  Status AddChild(std::shared_ptr<DatasetOp> child);
+
+  /// Removes an operator from our children.
+  /// \param child - shared pointer to the child to remove.
+  Status RemoveChild(std::shared_ptr<DatasetOp> child);
+
+  /// \brief Removes this node from the tree and connects its parent/child together
+  /// \return Status error code returned
+  Status Remove();
+
+  /// \brief Getter function to get a shared pointer to our child
+  /// \param[in] child_index An operator can have n children. Indicates which child to return.
+  /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
+  std::shared_ptr<DatasetOp> child(int32_t child_index) const;
+
+  /// \brief Getter function to get the pointer to our parent
+  ///     If there are no parents, it returns null regardless of the given index
+  /// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
+  void Parent(DatasetOp **parent, int32_t parent_index) const;
+
+  // Inserts an operator as the parent of the current op.
+  // The inserted op will become the sole parent of the current op.
+  // The existing parent of the current op will be transferred to the inserted op.
+  Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
+
+  /// \brief Creates the connector within this operator
+  /// \param num_producers - number of threads that write into this connector
+  /// \param num_consumers - number of threads that read from this connector
+  void CreateConnector(int32_t num_producers, int32_t num_consumers);
+
+  /// \brief A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  /// \param show_all - A bool to control if you want to show all info or just a summary
+  virtual void Print(std::ostream &out, bool show_all) const;
+
+  /// \brief << Stream output operator overload
+  /// \notes This allows you to write the debug print info using stream operators
+  /// \param out - reference to the output stream being overloaded
+  /// \param dO - reference to the DatasetOp to display
+  /// \return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const DatasetOp &dO) {
+    dO.Print(out, false);
+    return out;
+  }
+
+  /// \brief Class functor operator ().
+  ///     DatasetOps operate by launching a thread (see ExecutionTree).
+  ///     This pure virtual version makes the requirement that derived classes must provide a functor
+  ///     that will execute their main runtime loop code.
+  /// \return Status - The error code return
+  virtual Status operator()() = 0;
+
+  /// \brief Gets the next buffer from the given child
+  /// \notes See GetNextInput for similar function that has built-in message handling
+  /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
+    return GetNextBuffer(p_buffer, worker_id, false);
+  }
+
+  /// \brief Gets the next buffer from the given child
+  /// \notes See GetNextInput for similar function that has built-in message handling
+  /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
+  /// \return Status - The error code return
+  virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer) { return GetNextBuffer(p_buffer, 0, false); }
+
+  /// \brief Gets the next buffer from the given child
+  /// \notes See GetNextInput for similar function that has built-in message handling
+  /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
+  /// \param worker_id - The worker id
+  /// \param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
+  /// \return Status - The error code return
+  virtual Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe);
+
+  /// \brief Gets the next buffer from the given child. This function also has built-in eoe and eof
+  ///     message handling so that child classes don't have to manually code pass-through logic when
+  ///     those messages are received.
+  /// \param p_buffer - The shared pointer for the fetched buffer to return (by reference)
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
+
+  /// \brief Performs handling for when an eoe message is received.
+  ///     The base class implementation simply flows the eoe message to output. Derived classes
+  ///     may override if they need to perform special eoe handling.
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  virtual Status EoeReceived(int32_t worker_id);
+
+  /// \brief Performs handling for when an eof message is received.
+  ///     The base class implementation simply flows the eof message to output. Derived classes
+  ///     may override if they need to perform special eof handling.
+  /// \param worker_id - The worker id
+  /// \return Status - The error code return
+  virtual Status EofReceived(int32_t worker_id);
+
+  /// \brief Derived classes may implement the reset function if the operator is stateful and needs
+  ///     specific reset handling that is not contained in this common code version of the reset
+  /// \return Status - The error code return
+  virtual Status Reset();
+
+  /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
+  ///     their role.
+  /// \notes Derived versions of this function should always call its superclass version first
+  ///     before providing their own implementations.
+  virtual Status PrepareNodePreAction();
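+  // Illustrative call-super pattern for overrides (sketch only, not from the original patch;
+  // MyOp is a hypothetical subclass):
+  //   Status MyOp::PrepareNodePreAction() {
+  //     RETURN_IF_NOT_OK(DatasetOp::PrepareNodePreAction());  // superclass first
+  //     // ... op-specific preparation ...
+  //     return Status::OK();
+  //   }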
+
+  /// \brief During tree prepare phase, operators may have specific post-operations to perform depending on
+  ///     their role.
+  /// \notes Derived versions of this function should always call its superclass version first
+  ///     before providing their own implementations.
+  virtual Status PrepareNodePostAction();
+
+  /// \brief Getter function
+  /// \return The operator id
+  int32_t id() const { return operator_id_; }
+
+  /// \brief Getter function
+  /// \return The prepare flags
+  virtual uint32_t PrepareFlags() const;
+
+  /// \brief Getter function
+  /// \return The number of workers in this op
+  virtual int32_t num_workers() const = 0;
+
+  /// \brief Getter function
+  /// \return The number of threads consuming from the previous op.
+  virtual int32_t num_consumers() const = 0;
+
+  /// \brief Getter function
+  /// \return The number of threads producing to the output connector.
+  virtual int32_t num_producers() const = 0;
+
+  /// \brief Getter function
+  /// \return T/F if this is an inlined operator
+  bool inlined() const { return (oc_queue_size_ == 0); }
+
+  /// \brief Setter function
+  /// \return Sets the control flags
+  void set_control_flag(uint64_t flag) { BitSet(&op_ctrl_flags_, flag); }
+
+  /// \brief Setter function
+  /// \return Clears the control flags
+  void ClearControlFlag(uint64_t flag) { BitClear(&op_ctrl_flags_, flag); }
+
+  /// \brief Register the internal worker connectors. No-op unless it is a parallel op
+  /// \return Status
+  virtual Status RegisterWorkerConnectors() { return Status::OK(); }
+
+  /// \brief Getter for the column name mapping
+  /// \return The returned map
+  std::unordered_map<std::string, int32_t> column_name_id_map() const { return column_name_id_map_; }
+
+  /// \brief Checks whether the column name map is still empty, i.e., not yet set up for this op
+  /// \return - T/F true if the map has not been set up yet
+  bool HasColumnNameMap() const { return (column_name_id_map_.empty()); }
+
+  /// \brief Gives a string output for the column map for handy debug printing
+  /// \return - the column name map as a string
+  std::string ColumnNameMapAsString() const;
+
+  /// \brief Getter function
+  /// \return connector size of current op
+  int32_t ConnectorSize() const {
+    if (!inlined()) {
+      return out_connector_->size();
+    }
+    // Return child connector size for inlined op
+    return ChildOpConnectorSize();
+  }
+
+  /// \brief Counting number of buffers sent out by a connector
+  int64_t ConnectorOutBufferCount() const {
+    return out_connector_ == nullptr ? int64_t(-1) : static_cast<int64_t>(out_connector_->out_buffers_count());
+  }
+
+  /// \brief Getter function
+  /// \return connector capacity of current op
+  int32_t ConnectorCapacity() const {
+    if (!inlined()) {
+      return out_connector_->capacity();
+    }
+    // Return child connector capacity for inlined op
+    return ChildOpConnectorCapacity();
+  }
+
+  /// \brief Getter function
+  /// \return connector size of child op
+  int32_t ChildOpConnectorSize(int32_t child_index = 0) const { return child_[child_index]->ConnectorSize(); }
+
+  /// \brief Getter function
+  /// \return connector capacity of child op
+  int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); }
+
+  /// \brief Children Getter
+  /// \return Vector of Children
+  std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
+
+  /// \brief Base method for NodePass pre-visit. A tree walk consists of walking down the tree and also walking back up
+  ///     in a depth-first order. PreAccept is the node visit on the way down, whereas the regular Accept is the main
+  ///     visit on the way back up the tree during a post-order traversal. Subclass needs to override this if it
+  ///     requires special node visit access. Check "dataset/engine/opt/pass.h" for more details.
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  virtual Status PreAccept(NodePass *p, bool *modified);
+
+  /// \brief Base method for NodePass visit. Subclass needs to override this if it requires special node visit access.
+  ///     Check "dataset/engine/opt/pass.h" for more details.
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  virtual Status Accept(NodePass *p, bool *modified);
+
+  /// Op name getter
+  /// \return Name of the current Op
+  virtual std::string Name() const { return "DatasetOp"; }
+
+  /// Execution Tree getter
+  /// \return Pointer to the ExecutionTree the current op belongs to, no ownership
+  ExecutionTree *Tree() { return tree_; }
+
+  /// Getter for the sampler
+  /// \return Shared pointer to the sampler (may return nullptr)
+  std::shared_ptr<Sampler> sampler() { return sampler_; }
+
+  /// \brief Getter for the sampler, and it also removes the sampler from the op
+  /// \param[out] sampler A pointer to the output sampler that was removed
+  /// \return Status error code
+  Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
+
+  // Computes a CRC value for the operator
+  static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
+
+  /// \brief A helper templated function for casting "this" pointer to shared_ptr
+  ///     Similar to shared_from_this, except this one will give you the derived class as shared_ptr
+  /// \return A shared_ptr casted to the derived class
+  template <typename Derived>
+  std::shared_ptr<Derived> shared_from_base() {
+    return std::static_pointer_cast<Derived>(shared_from_this());
+  }
+
+  /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
+  void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
+
+  /// \brief Checks if this is a leaf node (0 children)
+  /// \return boolean returns true if it's a leaf
+  bool IsLeaf() { return (child_.empty()); }
+
+ protected:
+  /// \brief Removes a parent operator from this operator
+  /// \notes External callers do not have access to this function
+  /// \param[in] parent The parent node to remove
+  void RemoveParent(const DatasetOp *parent);
+
+  /// \brief Adds a parent operator to this operator
+  /// \notes External callers do not have access to this function
+  /// \param[in] parent The parent node to add
+  void AddParent(DatasetOp *parent);
+
+  /// Compute the current op's column map using its child's column map.
+  /// Gets called during the tree post-prepare phase in PrepareNodePostAction.
+  /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
+  /// Operators that change the column map inherited from the child must override this function.
+  /// \return - Status
+  virtual Status ComputeColMap();
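+  // Descriptive note (not from the original patch): e.g. a one-child op whose child exposes
+  // {"image" -> 0, "label" -> 1} inherits that map verbatim here; an op that renames columns
+  // would instead build its own map with the new names in its override.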
+  std::vector<std::shared_ptr<DatasetOp>> child_;                // Child nodes
+  std::vector<DatasetOp *> parent_;                              // Parent nodes. No ownership
+  std::shared_ptr<Sampler> sampler_;                             // Some leaf ops might have a sampler
+  int32_t oc_queue_size_;                                        // Capacity for each out_connector_
+  int32_t operator_id_;                                          // Generated id for the node
+  ExecutionTree *tree_;                                          // Back pointer to our tree.
+  OpState state_;                                                // The state of the operator: Running, Idle, Terminated
+  uint32_t op_ctrl_flags_;                                       // Flags for the operator
+  std::unique_ptr<DbConnector> out_connector_;                   // Output Connector
+  std::unordered_map<std::string, int32_t> column_name_id_map_;  // Mapping between col index and col name
+  std::mutex column_name_map_mutex_;                             // For protecting shared access to the column map
+
+ private:
+  /// Sets the operator id.
+  /// \notes No public interface. Only the class itself, or its friend the execution tree, can set it.
+  /// \param op_id - the Id value to set into the operator
+  void set_id(int32_t op_id) { operator_id_ = op_id; }
+
+  /// Sets the tree into the op so that the operator has a back pointer to the tree.
+  /// \param tree - the tree to assign to the op.
+  void set_tree(ExecutionTree *tree) { tree_ = tree; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_DATASET_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
new file mode 100644
index 0000000000..4fe779246b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
@@ -0,0 +1,320 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/device_queue_op.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/dataset_iterator.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/engine/perf/profiling.h" +#include "minddata/dataset/engine/perf/device_queue_tracing.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, + int32_t op_connector_size, int64_t num_batch) + : PipelineOp(op_connector_size), + channel_name_(channel_name), + device_type_(device_type), + device_id_(device_id), + prefetch_size_(prefetch_size), + num_batch_(num_batch) {} + +DeviceQueueOp::~DeviceQueueOp() {} + +#ifdef ENABLE_GPUQUE +void ReleaseData(void *addr) { + if (addr != nullptr) { + free(addr); + } +} +#endif + +DeviceQueueOp::Builder::Builder(int32_t prefetch_size) + : builder_prefetch_size_(prefetch_size), + builder_device_id_(0), + builder_device_type_(DeviceType::CPU), + builder_channel_name_(""), + builder_num_batch_(0) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status DeviceQueueOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +Status DeviceQueueOp::operator()() { + TaskManager::FindMe()->Post(); + + if (device_type_ == DeviceType::Ascend) { +#ifdef ENABLE_TDTQUE + RETURN_IF_NOT_OK(SendDataToAscend()); +#endif + } else if (device_type_ == DeviceType::GPU) { +#ifdef ENABLE_GPUQUE + RETURN_IF_NOT_OK(SendDataToGPU()); +#endif + } else if (device_type_ == DeviceType::CPU) { + RETURN_IF_NOT_OK(SendDataToCPU()); + } + + return Status::OK(); +} + +Status DeviceQueueOp::CheckExceptions(const std::unique_ptr &buffer) const { + // this method checks if the buffer meets the conditions to be sent to TDT + if (buffer->NumRows() != 0) { + TensorRow row; + buffer->GetRow(0, &row); + for (const auto &item : row) { + CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device."); + } + } + return Status::OK(); +} + +#ifdef ENABLE_TDTQUE +Status DeviceQueueOp::SendDataToAscend() { + MS_LOG(INFO) << "Device queue, sending data to Ascend."; + int64_t total_batch = 0; + bool is_break_loop = false; + double batch_start_time, end_time; + int32_t batch_cost, tdt_cost; + int32_t connector_size = 0; + int32_t connector_capacity; + std::shared_ptr profiling_node; + bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable(); + if (isProfilingEnable) { + std::shared_ptr node; + RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node)); + profiling_node = std::dynamic_pointer_cast(node); + batch_start_time = ProfilingTime::GetCurMilliSecond(); + connector_capacity = ChildOpConnectorCapacity(); + } + std::unique_ptr current_buffer; + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + + while (!current_buffer->eof() && !is_break_loop) { + while (!current_buffer->eoe() && !is_break_loop) { + RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); + TensorRow currRow; + for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) { + RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow)); + auto status 
= tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost); + if (status == TdtStatus::FAILED) { + return Status(StatusCode::kTDTPushFailure, "TDT Push Failed"); + } + + if (isProfilingEnable) { + end_time = ProfilingTime::GetCurMilliSecond(); + // record push tdt time + profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost); + batch_cost = (int32_t)(end_time - batch_start_time); + // record batch time + profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost); + // record pipeline time + profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost); + batch_start_time = end_time; + // record connector depth + profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size); + } + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + is_break_loop = true; + } + } + if (isProfilingEnable) { + connector_size = ChildOpConnectorSize(); + connector_capacity = ChildOpConnectorCapacity(); + } + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + } + if (isProfilingEnable) { + connector_size = ChildOpConnectorSize(); + connector_capacity = ChildOpConnectorCapacity(); + } + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + } + + tree_->SetFinished(); + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + return Status::OK(); +} +#endif + +#ifdef ENABLE_GPUQUE +Status DeviceQueueOp::SendDataToGPU() { + MS_LOG(INFO) << "Device queue, sending data to GPU."; + int64_t total_batch = 0; + bool is_break_loop = false; + bool is_open = false; + uint32_t handle = INVALID_HANDLE; + + std::unique_ptr current_buffer; + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + + while (!current_buffer->eof() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { + while (!current_buffer->eoe() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed()) { + RETURN_IF_NOT_OK(CheckExceptions(current_buffer)); + TensorRow curr_row; // batch data + for (int row_id = 0; + row_id < current_buffer->NumRows() && !is_break_loop && !GpuBufferMgr::GetInstance().IsClosed(); row_id++) { + RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &curr_row)); + + std::vector data_size; + for (int i = 0; i < curr_row.size(); i++) { + data_size.push_back(static_cast(curr_row[i]->SizeInBytes())); + } + if (!is_open) { + handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, ReleaseData); + if (handle == INVALID_HANDLE) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "open failed"); + } + is_open = true; + } + RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle)); + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + is_break_loop = true; + } + } + if (!TaskManager::FindMe()->Interrupted()) + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + else + is_break_loop = true; + } + if (!TaskManager::FindMe()->Interrupted()) + RETURN_IF_NOT_OK(GetNextInput(¤t_buffer)); + else + is_break_loop = true; + } + + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + GpuBufferMgr::GetInstance().Close(handle); + + GpuBufferMgr::GetInstance().CloseConfirm(); + + return Status::OK(); +} + +Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, + uint32_t handle) { + std::vector items; + for (int i = 0; i < data_size.size(); i++) { + device::DataItemGpu data_item; + data_item.data_len_ = data_size[i]; + data_item.data_ptr_ = 
nullptr; + items.push_back(data_item); + } + + while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { + RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row)); + BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); + if (ret) { + for (int i = 0; i < items.size(); i++) { + free(items[i].data_ptr_); + } + if (ret == BlockQueueStatus_T::ERROR_INPUT) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it."); + } else { + MS_LOG(WARNING) << "Retry pushing data..."; + continue; + } + } else { + break; + } + } + return Status::OK(); +} + +Status DeviceQueueOp::MallocForGPUData(std::vector *items, const TensorRow &curr_row) { + int i = 0; + for (auto &sub_item : *items) { + sub_item.data_ptr_ = (unsigned char *)malloc(sub_item.data_len_); + if (sub_item.data_ptr_ == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memory malloc failed."); + } + (void)memset_s(sub_item.data_ptr_, sub_item.data_len_, 0, sub_item.data_len_); + const unsigned char *column_data = curr_row[i]->GetBuffer(); + if (memcpy_s(sub_item.data_ptr_, sub_item.data_len_, column_data, + static_cast(curr_row[i++]->SizeInBytes())) != 0) { + MS_LOG(ERROR) << "memcpy_s failed!"; + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "memcpy_s failed."); + } + } + + return Status::OK(); +} +#endif + +Status DeviceQueueOp::SendDataToCPU() { + MS_LOG(INFO) << "Device queue, sending data to CPU."; + int64_t total_batch = 0; + + std::unique_ptr child_iterator = std::make_unique(this, 0, 0); + while (!(child_iterator->eof_handled())) { + TensorRow curr_row; + RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&curr_row)); + + if (!curr_row.empty()) { + MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << "."; + MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << "."; + total_batch++; + if (num_batch_ > 0 && total_batch == num_batch_) { + break; + } + } + } + + MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << "."; + + return Status::OK(); +} + +void DeviceQueueOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n"; + } +} + +// Visitor accept method for NodePass +Status DeviceQueueOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h new file mode 100644 index 0000000000..0fb4fb093d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -0,0 +1,175 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * 
you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/status.h" + +#ifdef ENABLE_TDTQUE +#include "minddata/dataset/engine/tdt/tdt_plugin.h" +#endif + +#ifdef ENABLE_GPUQUE +#include "runtime/device/gpu/gpu_buffer_mgr.h" +using mindspore::device::BlockQueueStatus_T; +using mindspore::device::GpuBufferMgr; +#endif + +namespace mindspore { +namespace dataset { +class DeviceQueueOp : public PipelineOp { + public: + static const uint32_t INVALID_HANDLE = 0xffffffffUL; + static const uint32_t WAIT_TIME = 5; + + enum class DeviceType { Ascend = 0, GPU = 1, CPU = 2 }; + + // The nested builder class inside of the DeviceQueueOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + explicit Builder(int32_t prefetch_size); + + // Default destructor + ~Builder() = default; + + Builder &SetPrefetchSize(int32_t prefetch_size) { + builder_prefetch_size_ = prefetch_size; + return *this; + } + + Builder &SetChannelName(const std::string &channel_name) { + builder_channel_name_ = channel_name; + return *this; + } + + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + Builder &SetDeviceType(const std::string &device_type) { + if (device_type == "Ascend") { + builder_device_type_ = DeviceType::Ascend; + } else if (device_type == "GPU") { + builder_device_type_ = DeviceType::GPU; + } else if (device_type == "CPU") { + builder_device_type_ = DeviceType::CPU; + } + return *this; + } + + Builder &SetDeviceId(int32_t device_id) { + builder_device_id_ = device_id; + return *this; + } + + Builder &SetNumBatch(int64_t num_batch) { + builder_num_batch_ = num_batch; + return *this; + } + + // Name: Build() + // Description: The final step for building a DeviceQueueOp via the Builder is + // to call this Build() method. It will instantiate the DeviceQueueOp + // and return it to caller as a shared pointer. 
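For orientation, this is how a caller is expected to drive the builder before Build() is invoked; the channel name, device id, and batch count below are made-up values for illustration:

// Illustrative sketch only, not part of the patch.
std::shared_ptr<DeviceQueueOp> op;
Status rc = DeviceQueueOp::Builder(32)            // prefetch size
              .SetChannelName("example_channel")  // hypothetical channel name
              .SetDeviceType("Ascend")
              .SetDeviceId(0)
              .SetNumBatch(100)
              .Build(&op);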
+ Status Build(std::shared_ptr *ptr) { + *ptr = std::make_shared(builder_channel_name_, builder_device_type_, builder_device_id_, + builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_); + return Status::OK(); + } + + private: + int32_t builder_prefetch_size_; + int32_t builder_device_id_; + DeviceType builder_device_type_; + std::string builder_channel_name_; + int64_t builder_num_batch_; + int32_t builder_op_connector_size_; + }; + + // Name: constructor + // Description + DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size, + int32_t op_connector_size, int64_t num_batch); + + // Name: destructor + // Description + ~DeviceQueueOp(); + + Status EoeReceived(int32_t worker_id) override; + + const int32_t get_prefetch_size() { return prefetch_size_; } + + // Name: Print() + // Description: A function that prints info about the node + void Print(std::ostream &out, // In: The output stream to print to + bool show_all) const override; // In: T/F if it should print everything + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const DeviceQueueOp &to) { + to.Print(out, false); + return out; + } + + Status operator()() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "DeviceQueueOp"; } + + private: + // Name: checkExceptions(DataBuffer); + // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp + Status CheckExceptions(const std::unique_ptr &buffer) const; + +#ifdef ENABLE_TDTQUE + Status SendDataToAscend(); +#endif + +#ifdef ENABLE_GPUQUE + Status SendDataToGPU(); + Status RetryPushGPUData(const std::vector &data_size, const TensorRow &curr_row, uint32_t handle); + Status MallocForGPUData(std::vector *items, const TensorRow &curr_row); +#endif + + Status SendDataToCPU(); + std::string channel_name_; + DeviceType device_type_; + const int32_t device_id_; + const int32_t prefetch_size_; + const int64_t num_batch_; + +#ifdef ENABLE_TDTQUE + std::shared_ptr tdtInstancePtr; +#endif +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_DEVICE_QUEUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc new file mode 100644 index 0000000000..f32648a3df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.cc @@ -0,0 +1,267 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/filter_op.h" +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { + +Status FilterOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +FilterOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status FilterOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, + builder_predicate_func_); + return Status::OK(); +} + +FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func) + : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} + +Status FilterOp::operator()() { + // The operator class just starts off threads by calling the tree_ function. + RETURN_UNEXPECTED_IF_NULL(tree_); + filter_queues_.Init(num_workers_, oc_queue_size_); + RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); + Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)); + // Synchronize with TaskManager. + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + RETURN_IF_NOT_OK(Collector()); + return Status::OK(); +} + +Status FilterOp::EofReceived(int32_t) { return Status::OK(); } + +Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } + +// Validating if each of the input_columns exists in the DataBuffer. +Status FilterOp::ValidateInColumns(const std::vector *input_columns) { + for (const auto &inCol : *input_columns) { + bool found = column_name_id_map_.find(inCol) != column_name_id_map_.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +// A print method typically used for debugging. 
+void FilterOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nInput column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } + out << "\n\n"; + } +} + +Status FilterOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + bool worker_stop = false; + while (worker_stop == false) { + // Getting a databuffer to work on. + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); + if (in_buffer->eoe()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + continue; + } else if (in_buffer->eof()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); + worker_stop = true; + continue; + } + + RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_)); + + // if the databuffer was all filtered, it is marked as kFilterEmpty. + // if the databuffer was partially filtered, it is marked as kFilterPartial. + // if the databuffer was not filtered, it is marked as kFilterFull. + int32_t num_rows = in_buffer->NumRows(); + std::unique_ptr new_tensor_table; + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table)); + + if (new_tensor_table->empty()) { + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); + } else if (new_tensor_table->size() == num_rows) { + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); + } else { // kFilterPartial + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); + } + } + return Status::OK(); +} + +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out) { + *out = std::make_unique(); + int32_t num_rows = in_buffer->NumRows(); + for (int32_t i = 0; i < num_rows; i++) { + TensorRow to_process; + TensorRow cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + if (in_columns_.empty() == true) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + to_process = cur_row; + } else { + (void)std::transform( + in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), + [&cur_row, this](const auto &it) -> std::shared_ptr { return cur_row[column_name_id_map_[it]]; }); + } + bool predicate = true; + RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); + if (predicate) { + (*out)->push_back(std::move(cur_row)); + } + } + return Status::OK(); +} + +// if the filtered DataBuffer is written directly to out_connector_, +// the thread fetching data will block in a queue. +// Collector function will reorder the DataBuffer in order. 
+// for example in two work queues: +// int filter_queues_: +// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) +// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) +// after reorder in out_connector_: +// queue1: DB(data2) DB(data4) DB(eof) +// queue2: DB(eoe) DB(eoe) +Status FilterOp::Collector() { + bool collector_stop = false; + uint64_t task_id_cnt = 0; + uint64_t out_id_cnt = 0; + std::pair, filterCtrl> in_pair; + while (collector_stop == false) { + uint32_t w_id = task_id_cnt % num_workers_; + RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); + if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || + in_pair.second == filterCtrl::kFilterEoe) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + out_id_cnt++; + task_id_cnt++; + } else if (in_pair.second == filterCtrl::kFilterEof) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + collector_stop = true; + } else { // kFilterEmpty + task_id_cnt++; + } + } + return Status::OK(); +} + +// Private function for checking the column legality. +Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { + int32_t num_rows = in_buf->NumRows(); + int32_t num_cols = in_buf->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); + } + // Check if there is invalid column name in the inColumns. + RETURN_IF_NOT_OK(ValidateInColumns(input_columns)); + return Status::OK(); +} + +Status FilterOp::CheckInput(const TensorRow &input) const { + for (auto &item : input) { + if (item == nullptr) { + RETURN_STATUS_UNEXPECTED("input is null."); + } + } + return Status::OK(); +} + +Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { + RETURN_IF_NOT_OK(CheckInput(input)); + // Acquire Python GIL. + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Transform input tensor vector into numpy array vector. + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + input_args[i] = new_data; + } + // Invoke python function. 
+    py::object ret_py_obj = predicate_func_(*input_args);
+    *out_predicate = ret_py_obj.cast<bool>();
+  } catch (const py::error_already_set &e) {
+    std::stringstream ss;
+    ss << e.what() << std::endl;
+    ss << "The return value of the python predicate function is not bool, or cannot be converted to bool.";
+    return Status(StatusCode::kPyFuncException, ss.str());
+  }
+  return Status(StatusCode::kOK, "FilterOp predicate func call succeeded");
+}
+
+// Visitor accept method for NodePass
+Status FilterOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<FilterOp>(), modified);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h
new file mode 100644
index 0000000000..fcc6e577df
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/filter_op.h
@@ -0,0 +1,188 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
+#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/queue.h"
+
+namespace mindspore {
+namespace dataset {
+
+class FilterOp : public ParallelOp {
+ public:
+  // The nested builder class inside of the FilterOp is used to help manage all of
+  // the arguments for constructing it. Use the builder by setting each argument
+  // with the provided set methods, and then finally call the build method to execute
+  // the actual construction.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args.
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetPredicateFunc(py::function func) {
+      builder_predicate_func_ = std::move(func);
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
+      build_in_col_names_ = in_col_names;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t connector_size) {
+      builder_op_connector_size_ = connector_size;
+      return *this;
+    }
+
+    // The builder "build" method creates the final object.
+    // @param ptr The shared_ptr to the new FilterOp object.
+    // @return Status.
+    Status Build(std::shared_ptr<FilterOp> *ptr);
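As a usage sketch (hypothetical column names; py_predicate stands for any py::function returning bool that the caller already holds):

// Illustrative sketch only, not part of the patch.
std::shared_ptr<FilterOp> filter_op;
Status rc = FilterOp::Builder()
              .SetPredicateFunc(py_predicate)
              .SetInColNames({"image", "label"})
              .SetNumWorkers(4)
              .Build(&filter_op);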
+   private:
+    // Sanity check for builder class args.
+    // @return Status - The error code return.
+    Status SanityCheck();
+    std::vector<std::string> build_in_col_names_;
+    py::function builder_predicate_func_;
+    int32_t builder_num_workers_;
+    int32_t builder_op_connector_size_;
+  };
+
+  enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
+
+  // Constructor of FilterOp
+  // @note The builder class should be used to call it.
+  // @param in_col_names A list of input column names. When it is empty, the predicate will be
+  //     applied to all columns in the dataset.
+  // @param num_workers The number of worker threads.
+  // @param op_connector_size The size of each queue in the connector.
+  // @param predicate_func A python callable which returns a boolean value.
+  FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
+           py::function predicate_func);
+
+  // Destructor
+  ~FilterOp() = default;
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work.
+  // @return Status The error code return
+  Status operator()() override;
+
+  // @param int32_t workerId.
+  // @return Status - The error code return.
+  Status EofReceived(int32_t) override;
+
+  // @param int32_t workerId.
+  // @return Status - The error code return.
+  Status EoeReceived(int32_t) override;
+
+  // A print method typically used for debugging.
+  // @param out The output stream to write output to.
+  // @param show_all A bool to control if you want to show all info or just a summary.
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "FilterOp"; }
+
+ private:
+  // predicate_func: a python callable which returns a boolean value.
+  py::function predicate_func_;
+
+  // Variable to store the column names that will feed the predicate function.
+  std::vector<std::string> in_columns_;
+
+  // Internal queues for filter.
+  QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;
+
+  // Private function for worker/thread to loop continuously. It comprises the main
+  // logic of FilterOp: getting the data from the previous Op, validating user specified column names,
+  // applying the predicate to each row, and filtering out rows for which the predicate returns false.
+  // @param worker_id The id assigned to this thread/worker upon creation.
+  // @return Status The error code return.
+  Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_
+
+  // Filter the data by the predicate function.
+  // @param in_buffer The input data buffer.
+  // @param[out] out A tensor table holding only the rows that passed the predicate.
+  // @return Status The error code return.
+  Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr<TensorQTable> *out);
+
+  // Collect the filtered DataBuffers from the worker queues and push them, in order, to out_connector_.
+  // @return Status The error code return.
+  Status Collector();
+
+  // @param input The tensor row to check.
+  // @return Status - The error code return.
+  Status CheckInput(const TensorRow &input) const;
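The ordering contract that WorkerEntry() and Collector() implement between them (see the queue diagram in filter_op.cc above) can be captured in a self-contained toy, assuming the worker queues are already fully populated. This is plain C++ over std::deque, not the MindSpore QueueList API:

#include <cstdint>
#include <deque>
#include <string>
#include <utility>
#include <vector>

enum class Ctrl { kEmpty, kPartial, kFull, kEoe, kEof };
struct Buf {
  std::string payload;
  Ctrl ctrl;
};

// Visit worker queues round-robin in task order: drop fully filtered buffers,
// forward everything else, and stop once eof is seen, preserving global order.
std::vector<std::string> CollectInOrder(std::vector<std::deque<Buf>> *queues) {
  std::vector<std::string> out;
  uint64_t task_id = 0;
  while (true) {
    auto &q = (*queues)[task_id % queues->size()];
    Buf b = std::move(q.front());
    q.pop_front();
    if (b.ctrl == Ctrl::kEof) {
      out.push_back(b.payload);
      break;
    }
    if (b.ctrl != Ctrl::kEmpty) {
      out.push_back(b.payload);  // kFull, kPartial and kEoe are all forwarded
    }
    task_id++;  // kEmpty is consumed silently but still advances the cursor
  }
  return out;
}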
+  // Invoke the python predicate function.
+  // @param input The tensor row to feed to the predicate.
+  // @param[out] out_predicate The boolean result of the predicate.
+  // @return Status - The error code return.
+  Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
+
+  // Private function for validating that each of the user specified input column names
+  // exists in the DataBuffer.
+  // @param input_columns The vector of input column names used in the current thread.
+  // @return Status The error code return.
+  Status ValidateInColumns(const std::vector<std::string> *input_columns);
+
+  // Private function for checking the column legality.
+  // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
+  //     and is not shared with other threads.
+  // @param input_columns The vector of input column names used in the current thread.
+  Status CheckColumns(const DataBuffer *in_buf, const std::vector<std::string> *input_columns);
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc
new file mode 100644
index 0000000000..e5e70dbbdf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.cc
@@ -0,0 +1,373 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include <iomanip>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/core/config_manager.h"
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "utils/log_adapter.h"
+#include "minddata/dataset/util/task_manager.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+MapOp::Builder::Builder() : build_perf_mode_(true) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  build_num_workers_ = cfg->num_parallel_workers();
+  build_op_connector_size_ = cfg->op_connector_size();
+}
+
+// Check if the required parameters are set by the builder.
+Status MapOp::Builder::sanityCheck() const {
+  if (build_tensor_funcs_.empty()) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                  "Building a MapOp that has not provided any function/operation to apply");
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status MapOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(sanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), + std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, + build_perf_mode_); + return Status::OK(); +} + +// Constructor of MapOp +MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, + bool perf_mode) + : ParallelOp(num_workers, op_connector_size), + tfuncs_(std::move(tensor_funcs)), + in_columns_(in_col_names), + out_columns_(out_col_names), + perf_mode_(perf_mode) { + // If caller didn't specify the out_col_names, assume they are same as the in_columns. + if (out_columns_.empty() || out_columns_[0].empty()) { + out_columns_ = in_columns_; + } + MS_LOG(DEBUG) << "Performance Mode in map operator is " << perf_mode_ << "."; +} + +// The number of threads consuming data from previous op's output Connector. +int32_t MapOp::num_consumers() const { + // When Performance Mode is on, there is only one thread consuming from the previous Connector. + return perf_mode_ == true ? 1 : num_workers_; +} + +// A print method typically used for debugging +void MapOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nInput column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } + out << "\n TensorOps:"; + for (size_t i = 0; i < tfuncs_.size(); i++) { + out << " " << *(tfuncs_[i].get()); + } + out << "\n\n"; + } +} + +// This class functor will provide the master loop that drives the logic for performing the work +Status MapOp::operator()() { + if (perf_mode_) { + // Create and register the local queues. + local_queues_.Init(num_workers_, oc_queue_size_); + Status rc = local_queues_.Register(tree_->AllTasks()); + if (rc.IsError()) { + TaskManager::FindMe()->Post(); + return rc; + } + } + + // The operator class just starts off threads by calling the tree_ function + Status rc = tree_->LaunchWorkers(num_workers_, std::bind(&MapOp::WorkerEntry, this, std::placeholders::_1)); + // Synchronize with TaskManager + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(rc); + + if (perf_mode_) { + int64_t que_id = 0; + std::unique_ptr buff; + bool is_eof = false; + // Draining output connector of the previous op and distribute it to local queues. + // Stop when all worker threads are finished (received EOF). + while (!is_eof) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0)); + is_eof = buff->eof(); + RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); + que_id = (que_id + 1) % num_workers_; + } + } + + return Status::OK(); +} + +// Private function for worker/thread to loop continuously. 
It comprises the main +// logic of MapOp: getting the data from previous Op, validating user specified column names, +// applying a list of TensorOps to each of the data, process the results and then +// pushing them back to MapOp's output Connector to be fetched by the next Op. +Status MapOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + + // Getting a databuffer to work on. + // Perform the first fetch here outside of the loop. This allows us to execute one-time only + // initializations that happen after the first fetch. + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + + // Sanity check the databuffer. + // Special case: if there's more threads than buffers, some threads simply get the final control + // messages (eoe/eof), and so they will not perform the check. + if (!in_buffer->eoe() && !in_buffer->eof()) { + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer."); + } + } + + // Now that init work is done, drop into the main fetching loop. + // Map op does not use child iterator, and it needs to manually handle eoe and eof's itself + // rather than use the base-class defaults. + while (true) { + // Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work + // with Performance Mode design. + if (in_buffer->eoe()) { + // Calling base class EoeReceived to forward eoe buffer. + RETURN_IF_NOT_OK(EoeReceived(worker_id)); + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + continue; + } else if (in_buffer->eof()) { + // Calling base class EofReceived to forward eof buffer. + RETURN_IF_NOT_OK(EofReceived(worker_id)); + break; + } + + std::unique_ptr new_tensor_table(std::make_unique()); + // Perform the compute function of TensorOp(s) and store the result in new_tensor_table. + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get())); + + // Replace the TensorTable in DataBuffer with the new one. + in_buffer->set_tensor_table(std::move(new_tensor_table)); + + // Push the buffer onto the connector for next operator to consume. + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(worker_id), std::move(in_buffer))); + + // Fetch the next buffer and loop back to the top. + RETURN_IF_NOT_OK(FetchNextBuffer(&in_buffer, worker_id)); + } + + return Status::OK(); +} + +Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table) { + // Getting number of rows and cols in this buffer. + int32_t num_rows = in_buffer->NumRows(); + int32_t num_cols = in_buffer->NumCols(); + + for (int32_t r = 0; r < num_rows; r++) { + // to_process : A vector of Tensors only holding cols in input_columns. + // result_row; : A vector of Tensors to hold the result after Compute(). + // cur_row : A vector of Tensors holding all the columns from DataBuffer. + TensorRow to_process, result_row, cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + + // Populate the Tensor from the current row to be processed by TensorOp + for (const auto &idx : to_process_indices_) { + to_process.push_back(std::move(cur_row[idx])); + } + + // Looping over multiple TensorOps supplied in to MapOp. + // The assumption is that the result of one TensorOp matches the required input to the next TensorOp. + for (size_t i = 0; i < tfuncs_.size(); i++) { + // TensorOp can operate on single col or multiple cols. 
MapOp always calls Compute() with multiple columns.
+      // The TensorOp base class will dispatch to the single-column Compute() depending on the op.
+      // Note: The columns of the result_row are not preallocated; the compute function of each tensor op is
+      // required to resize/push back the result_row.
+      RETURN_IF_NOT_OK(tfuncs_[i]->Compute(to_process, &result_row));
+
+      // Assign result_row to to_process for the next TensorOp processing, except for the last TensorOp in the list.
+      if (i + 1 < tfuncs_.size()) {
+        to_process = std::move(result_row);
+      }
+    }
+
+    if (out_columns_.size() != result_row.size()) {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                    "Result of a tensorOp doesn't match output column names");
+    }
+
+    if (in_columns_.size() == out_columns_.size()) {
+      for (size_t i = 0; i < result_row.size(); i++) {
+        cur_row[to_process_indices_[i]] = std::move(result_row[i]);
+      }
+      new_tensor_table->push_back(std::move(cur_row));
+    } else {
+      // Add the columns we did not touch to the result_row.
+      for (int32_t i = 0; i < num_cols; i++) {
+        if (keep_input_columns_[i]) {
+          result_row.push_back(std::move(cur_row[i]));
+        }
+      }
+
+      // Add this final result_row to our new TensorTable.
+      new_tensor_table->push_back(std::move(result_row));
+    }
+  }
+
+  return Status::OK();
+}
+
+Status MapOp::ComputeColMap() {
+  // If the map has not been set up yet in the base class, then set it up
+  if (column_name_id_map_.empty()) {
+    std::unordered_map<std::string, int32_t> current_name_id_map = child_[0]->column_name_id_map();
+    // Initialize private variables
+    RETURN_IF_NOT_OK(InitPrivateVariable(&current_name_id_map));
+    // Create the final column name to index mapping in the base class field
+    CreateFinalColMap(&current_name_id_map);
+    MS_LOG(DEBUG) << "Column name map for map op set: " << this->ColumnNameMapAsString();
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+
+// Validate that each of the input_columns exists in the DataBuffer.
+Status MapOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map) {
+  for (const auto &inCol : in_columns_) {
+    bool found = col_name_id_map.find(inCol) != col_name_id_map.end();
+    if (!found) {
+      std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+  }
+  return Status::OK();
+}
+
+Status MapOp::InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map) {
+  // If input_columns is empty(), the column at index 0 will be picked.
+  if (in_columns_.empty()) {
+    for (const auto &pair : *col_name_id_map) {
+      if (pair.second == 0) {
+        MS_LOG(INFO) << "Input columns empty for map op, will apply to the first column in the current table.";
+        in_columns_.push_back(pair.first);
+        break;
+      }
+    }
+
+    // If the caller didn't specify the out_col_names, assume they are the same as the input_columns.
+    // This was done in the constructor, but if input columns was empty to start we have to redo it here.
+    if (out_columns_.empty() || out_columns_[0].empty()) {
+      out_columns_ = in_columns_;
+    }
+  }
+
+  // Before we continue, issue a sanity check to make sure the input columns from the user and the incoming
+  // columns from the child are correct
+  RETURN_IF_NOT_OK(this->ValidateInColumns(*col_name_id_map));
+
+  // initialize keep_input_columns, true means to keep the column.
+ keep_input_columns_.resize(col_name_id_map->size(), true); + for (const auto &col_name : in_columns_) { + int32_t missed = (*col_name_id_map)[col_name]; + keep_input_columns_[missed] = false; + } + + // initialize to_process_indices. + for (const auto &col_name : in_columns_) { + to_process_indices_.push_back((*col_name_id_map)[col_name]); + } + return Status::OK(); +} + +// Create the final column name to index mapping and get indices of the columns this mapop does not use. +void MapOp::CreateFinalColMap(std::unordered_map *col_name_id_map) { + std::unordered_map final_col_name_id_map; + size_t num_cols = col_name_id_map->size(); + std::vector new_ids(num_cols); + if (in_columns_.size() == out_columns_.size()) { + for (size_t i = 0; i < in_columns_.size(); i++) { + int32_t loc = (*col_name_id_map)[in_columns_[i]]; + (void)col_name_id_map->erase(in_columns_[i]); + (*col_name_id_map)[out_columns_[i]] = loc; + } + + // Set the base class final column id map result + column_name_id_map_ = *col_name_id_map; + } else { + int32_t fill_idx = 0; + // First columns of the tables are occupied by the output columns from tensorOp. + for (const auto &col_name : out_columns_) { + final_col_name_id_map[col_name] = fill_idx++; + } + + // Creating new_ids mapping for the columns we keep. + for (size_t i = 0; i < num_cols; i++) { + if (keep_input_columns_[i]) { + new_ids[i] = fill_idx++; + } + } + + // Iterating through the old mapping to update the final mapping for the columns we kept. + std::string name; + for (const auto &pair : *col_name_id_map) { + name = pair.first; + int32_t old_id = pair.second; + if (keep_input_columns_[old_id]) { + final_col_name_id_map[name] = new_ids[old_id]; + } + } + + // Set the base class final column id map result + column_name_id_map_ = final_col_name_id_map; + } +} + +// Visitor accept method for NodePass +Status MapOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h new file mode 100644 index 0000000000..b1cd58010f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/map_op.h @@ -0,0 +1,268 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_MAP_OP_H_ +#define DATASET_ENGINE_DATASETOPS_MAP_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/queue.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// MapOp class implements the Map operator. It will apply a list of operations to each record specified by column names. +// The column order behavior after MapOp is as follows. 
+// [Case 1] If the number of Input Columns == the number of Output Columns, the column ordering after MapOp
+//     is the same as the original column order: the Remainder Columns stay in the same positions,
+//     and the Output Columns are placed in the same positions as the Input Columns.
+//     For example, initially if the dataset has column order |A, B, C, D, E|,
+//     and we apply MapOp() with Input Columns {B, C} and Output Columns {X, Y},
+//     the column order after applying MapOp will be |A, X, Y, D, E|.
+//     Note that in this case, |X, Y| are the Output Columns and |A, D, E|, the Remainder Columns, stay in
+//     their original positions; column B is replaced by column X and column C is replaced by column Y.
+// [Case 2] If the number of Input Columns != the number of Output Columns, the column ordering after MapOp
+//     is Output Columns followed by Remainder Columns.
+//     For example, initially if the dataset has column order |A, B, C, D, E|,
+//     and we apply MapOp() with Input Columns {B, C, A} and Output Columns {X, Y},
+//     the column order after applying MapOp will be |X, Y, D, E|.
+//     Note that in this case, |X, Y| are the Output Columns and |D, E| are the Remainder Columns,
+//     and the Input Columns are gone, replaced by the Output Columns.
+
+// Keywords:
+// Input Columns : a vector of column names (string) passed to MapOp specifying the column names from which
+//     Tensors are taken and passed to the TensorOp Compute().
+// Output Columns : a vector of column names (string) passed to MapOp specifying the column names
+//     for the Tensors produced by TensorOp Compute().
+// Remainder Columns : columns that exist in the dataset but are not mentioned in Input Columns.
+//     These columns will not be passed to TensorOp Compute(), but will be appended to the end of the Output Columns.
+class MapOp : public ParallelOp {
+ public:
+  // The nested builder class inside of the MapOp is used to help manage all of
+  // the arguments for constructing it. Use the builder by setting each argument
+  // with the provided set methods, and then finally call the build method to execute
+  // the actual construction.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
+      build_in_col_names_ = in_col_names;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOutColNames(const std::vector<std::string> &out_col_names) {
+      build_out_col_names_ = out_col_names;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetTensorFuncs(std::vector<std::shared_ptr<TensorOp>> funcs) {
+      build_tensor_funcs_ = std::move(funcs);
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      build_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t connector_size) {
+      build_op_connector_size_ = connector_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetPerformanceMode(bool perf_mode) {
+      build_perf_mode_ = perf_mode;
+      return *this;
+    }
+
+    // The builder "build" method creates the final object.
+    // @param ptr The shared_ptr to the new MapOp object
+    // @return Status
+    Status Build(std::shared_ptr<MapOp> *ptr);
+
+   private:
+    std::vector<std::string> build_in_col_names_;
+    std::vector<std::string> build_out_col_names_;
+    std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
+    int32_t build_num_workers_;
+    int32_t build_op_connector_size_;
+    bool build_perf_mode_;  // Default true.
+
+    // Check if the required parameters are set by the builder.
+    // @return Status The error code return
+    Status sanityCheck() const;
+  };
+
+  // Constructor of MapOp
+  // @note The builder class should be used to call it.
+  // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs).
+  // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs).
+  // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data.
+  // @param num_workers The number of worker threads.
+  // @param op_connector_size The size of each queue in the connector.
+  MapOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
+        std::vector<std::shared_ptr<TensorOp>> tensor_funcs, int32_t num_workers, int32_t op_connector_size,
+        bool perf_mode);
+
+  // Destructor
+  ~MapOp() = default;
+
+  // A print method typically used for debugging
+  // @param out The output stream to write output to
+  // @param show_all A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out reference to the output stream being overloaded
+  // @param mo reference to the MapOp to display
+  // @return the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const MapOp &mo) {
+    mo.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status The error code return
+  Status operator()() override;
+
+  // Getter
+  // @return the number of threads consuming data from the previous op's output Connector.
+  int32_t num_consumers() const override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "MapOp"; }
+
+  // List of tensor ops getter/setter
+  // @Return the vector of tensor ops by non-const reference
+  auto &TFuncs() { return tfuncs_; }
+
+  const auto &TFuncs() const { return tfuncs_; }
+
+ private:
+  // Local queues where worker threads can pop from.
+  // Popping directly from the Connector can block if the previously designated threads haven't popped yet.
+  // Setting the size of these queues to 0 is essentially the same as pulling directly from the Connector.
+  QueueList<std::unique_ptr<DataBuffer>> local_queues_;
+
+  // Static variables to be read by worker threads: no modification, read-only
+  std::vector<std::shared_ptr<TensorOp>> tfuncs_;
+
+  // Variable to store the column names that the tensorOps are consuming
+  std::vector<std::string> in_columns_;
+
+  // Variable to store the column names that the tensorOps are producing
+  std::vector<std::string> out_columns_;
+
+  // Boolean mapping, true means to keep the column.
+  std::vector<bool> keep_input_columns_;
+
+  // Indices of the columns to process.
+  std::vector<size_t> to_process_indices_;
+
+  // Performance mode is when the main thread creates local queues, pulls DataBuffers from the previous
+  // op's Connector and distributes them to the local queues. Workers pull from the local queues.
+  // If this flag is false, each worker pulls directly from the Connector. This uses fewer resources
+  // (threads and memory), but when the computation cost is heavy (e.g. DecodeOp) and fluctuating, it can
+  // cause additional blocking because pop calls to the Connector from the threads are synchronized to enforce the order.
+  bool perf_mode_;
+
+  // Private function for worker/thread to loop continuously. It comprises the main
+  // logic of MapOp: getting the data from the previous Op, validating user-specified column names,
+  // applying a list of TensorOps to the data, processing the results, and then
+  // pushing them back to MapOp's output Connector to be fetched by the next Op.
+  // @param worker_id The id assigned to this thread/worker upon creation.
+  // @return Status The error code return
+  Status WorkerEntry(int32_t worker_id) override;  // In: workerId assigned by tree_
+
+  // Private helper function for getting the next buffer
+  // When PerformanceMode is enabled, workers pop from the local queue.
+  // Otherwise, workers pop from the first child output Connector.
+  // @param p_buffer - the buffer to return
+  // @param worker_id - The worker id
+  // @return Status return code
+  Status FetchNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id) {
+    if (perf_mode_) {
+      RETURN_IF_NOT_OK(local_queues_[worker_id]->PopFront(p_buffer));
+    } else {
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id));
+    }
+    return Status::OK();
+  }
+
+  // Private function for worker thread to perform the TensorOp's compute function and get the result.
+  // @param in_buffer A raw pointer to the DataBuffer. A raw pointer is fine because this function does not
+  //     manage memory and is not shared with other threads.
+  // @param[out] new_tensor_table A new Tensor Table to be populated in this function.
+  Status WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_table);
+
+  // Private function that creates the final column name to index mapping and
+  // gets the indices of the columns this MapOp does not use.
+  // @param col_name_id_map The column name to index mapping obtained from the child operator
+  void CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name_id_map);
+
+  // Validate that each of the input_columns exists in the DataBuffer.
+  // @param - the column map to check
+  // @return - status return code
+  Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map);
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
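+  // Illustrative sketch only of what perf_mode_ implies for the master loop in operator()():
+  // roughly a round-robin distribution like the following, where "buffer" and "i" are
+  // hypothetical locals and Add() is assumed to be the queue's push method:
+  //
+  //   std::unique_ptr<DataBuffer> buffer;
+  //   for (int64_t i = 0; /* until EOF arrives from the child */; ++i) {
+  //     RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buffer, 0));
+  //     RETURN_IF_NOT_OK(local_queues_[i % num_workers_]->Add(std::move(buffer)));
+  //   }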
+  // Private function for initializing private variables such as in_columns_, out_columns_.
+  // @return - Status
+  Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_MAP_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc
new file mode 100644
index 0000000000..abb827aea8
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.cc
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+
+#include <iostream>
+#include <memory>
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/util/task_manager.h"
+
+namespace mindspore {
+namespace dataset {
+// Constructor
+ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler),
+      num_workers_(num_workers),
+      num_producers_(num_workers),
+      worker_connector_size_(1),
+      worker_connector_(nullptr) {}
+
+// Creates the internal worker connector for the parallel op if the derived class wants to use it
+Status ParallelOp::CreateWorkerConnector(int32_t worker_connector_size) {
+  if (worker_connector_size == 0) {
+    RETURN_STATUS_UNEXPECTED("Worker connector size 0 is invalid.");
+  }
+  num_producers_ = 1;
+  worker_connector_size_ = worker_connector_size;
+  // Instantiate the worker connector. This is the internal connector, not the operator's
+  // output connector. It has a single master consuming from it (num producers is 1), and the number
+  // of workers is the defined count from the op.
+  worker_connector_ = std::make_unique<DbConnector>(num_workers_, num_producers_, worker_connector_size);
+
+  return Status::OK();
+}
+
+// A print method typically used for debugging
+void ParallelOp::Print(std::ostream &out, bool show_all) const {
+  // Summary 1-liner print
+  if (!show_all) {
+    out << " [workers: " << num_workers_ << "]";
+    // Call super class printer
+    DatasetOp::Print(out, show_all);
+  } else {
+    // Detailed print
+    DatasetOp::Print(out, show_all);
+    out << "\nNum workers: " << num_workers_;
+  }
+}
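+// Illustrative sketch only: a derived operator that wants the master/worker queue layout
+// established by CreateWorkerConnector() would typically call it during its own setup,
+// e.g. (hypothetical subclass and member):
+//
+//   Status MyParallelOp::Setup() {
+//     RETURN_IF_NOT_OK(CreateWorkerConnector(my_queue_depth_));
+//     return Status::OK();
+//   }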
+// Override base class reset to provide reset actions specific to the ParallelOp class.
+Status ParallelOp::Reset() {
+  RETURN_IF_NOT_OK(DatasetOp::Reset());  // Perform any super class reset work
+
+  // ParallelOp is abstract, but we do own the connector between workers and master
+  // (if the parallel op is configured for this). Reset that connector here.
+  if (worker_connector_) {
+    worker_connector_->Reset();
+  }
+
+  return Status::OK();
+}
+
+// Register the internal worker connectors
+Status ParallelOp::RegisterWorkerConnectors() {
+  if (worker_connector_) {
+    return (worker_connector_->Register(tree_->AllTasks()));
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
new file mode 100644
index 0000000000..da54ce1331
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/parallel_op.h
@@ -0,0 +1,126 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
+#define DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
+
+#include <memory>
+#include <string>
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// global const in our namespace
+constexpr int32_t kEndOfActions = -1;
+
+// Forward declares
+class DataBuffer;
+
+class DbConnector;
+
+// A ParallelOp provides a multi-threaded DatasetOp
+class ParallelOp : public DatasetOp {
+ public:
+  // Constructor
+  // @param num_workers - The number of worker threads
+  // @param op_connector_size - size of the output connector for this operator
+  // @param sampler - The sampler for the op
+  ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
+
+  // Destructor
+  ~ParallelOp() = default;
+
+  // Creates the internal worker connector for the parallel op if the derived class wants to use it.
+  // @notes This changes the number of producers of this op to 1, since it establishes a master/worker
+  // relationship within the op, making all production flow through a single master.
+  // @return Status - The error return code
+  Status CreateWorkerConnector(int32_t worker_connector_size);
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param po - reference to the ParallelOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const ParallelOp &po) {
+    po.Print(out, false);
+    return out;
+  }
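+  // Illustrative sketch only: the smallest possible concrete subclass looks roughly like the
+  // hypothetical class below; the only hard requirement is a WorkerEntry() override
+  // (see the protected section further down).
+  //
+  //   class MyParallelOp : public ParallelOp {
+  //    public:
+  //     MyParallelOp() : ParallelOp(4 /* workers */, 32 /* connector size */) {}
+  //     Status WorkerEntry(int32_t worker_id) override {
+  //       // fetch work, compute, push results to out_connector_ ...
+  //       return Status::OK();
+  //     }
+  //   };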
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  // @return Status - The error return code
+  Status PrepareNodePreAction() override {
+    // Run common code from super class before adding ParallelOp specific logic
+    return (DatasetOp::PrepareNodePreAction());
+  }
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  // @return Status - The error return code
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding ParallelOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
+  }
+
+  // Override base class reset to provide reset actions specific to the ParallelOp class.
+  // @return Status - The error code return
+  Status Reset() override;
+
+  // Getter
+  // @return the number of workers
+  int32_t num_workers() const override { return num_workers_; }
+
+  // Getter
+  // @return the number of threads consuming from the previous Connector
+  int32_t num_consumers() const override { return num_workers_; }
+
+  // Getter
+  // @return the number of producers pushing to the output Connector
+  // @notes The number of producers is commonly the same as the number of workers, except in the case
+  // when a worker connector is set up. In that case, there are n workers and a single master,
+  // such that only 1 thread is a producer rather than the n workers.
+  int32_t num_producers() const override { return num_producers_; }
+
+  // Register the internal worker connectors.
+  // @return Status
+  Status RegisterWorkerConnectors() override;
+
+ protected:
+  // Interface for derived classes to implement. All derived classes must provide the entry
+  // function with the main execution loop for worker threads.
+  // @return Status - The error code return
+  virtual Status WorkerEntry(int32_t workerId) = 0;
+
+  int32_t num_workers_;    // The number of worker threads
+  int32_t num_producers_;  // The number of threads pushing to the out_connector_
+  int32_t worker_connector_size_;                  // The size of the worker connector queues
+  std::unique_ptr<DbConnector> worker_connector_;  // The internal connector for worker threads
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_PARALLEL_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc
new file mode 100644
index 0000000000..fff5ba19e7
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include <iostream>
+#include <memory>
+
+namespace mindspore {
+namespace dataset {
+// Constructor
+PipelineOp::PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+    : DatasetOp(op_connector_size, sampler) {}
+
+// A print method typically used for debugging
+void PipelineOp::Print(std::ostream &out, bool show_all) const {
+  // Summary 1-liner print
+  if (!show_all) {
+    out << " [workers: ";
+    if (this->inlined()) {
+      out << "0 (inlined)]";
+    } else {
+      out << "1]";  // Pipeline ops only have 1 worker
+    }
+    // Call super class printer
+    DatasetOp::Print(out, show_all);
+  } else {
+    // Detailed print
+    DatasetOp::Print(out, show_all);
+    out << "\nNum workers: ";
+    if (this->inlined()) {
+      out << "0 (inlined)";
+    } else {
+      out << "1";  // Pipeline ops only have 1 worker
+    }
+  }
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h
new file mode 100644
index 0000000000..0538349f48
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/pipeline_op.h
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_
+
+#include <memory>
+#include <string>
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+namespace mindspore {
+namespace dataset {
+// forward declare
+class ExecutionTree;
+
+class DataBuffer;
+
+class PipelineOp : public DatasetOp {
+ public:
+  // Constructor
+  // @param op_connector_size - size of the output connector
+  // @param sampler - The sampler for the op
+  explicit PipelineOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler = nullptr);
+
+  // Destructor
+  ~PipelineOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param po - reference to the PipelineOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const PipelineOp &po) {
+    po.Print(out, false);
+    return out;
+  }
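+  // Note (illustrative, not from this patch): whether a pipeline op is "inlined" drives the
+  // worker count reported above. An inlined op such as ProjectOp or RepeatOp runs no thread of
+  // its own, since its parent pulls through it directly, so Print() shows "0 (inlined)"; a
+  // queued pipeline op like ShuffleOp runs exactly one master thread and reports "1".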
+  // Getter
+  // @return The number of workers inside this op. Pipeline ops only have a single worker.
+  int32_t num_workers() const override { return 1; }
+
+  // Getter
+  // @return the number of threads consuming from the previous Connector
+  int32_t num_consumers() const override { return 1; }
+
+  // Getter
+  // @return The number of threads that push data to the output connector
+  int32_t num_producers() const override { return 1; }
+
+  // During tree prepare phase, operators may have specific pre-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePreAction() override {
+    // Run common code from super class before adding PipelineOp specific logic
+    return (DatasetOp::PrepareNodePreAction());
+  }
+
+  // During tree prepare phase, operators may have specific post-operations to perform depending on
+  // their role.
+  // @notes Derived versions of this function should always call its superclass version first
+  // before providing their own implementations.
+  Status PrepareNodePostAction() override {
+    // Run common code from super class before adding PipelineOp specific logic
+    return (DatasetOp::PrepareNodePostAction());
+  }
+
+ protected:
+  // *******************************************************************************
+  // I'm predicting there will be common arguments or functionality for pipeline ops,
+  // just not sure yet what those are. Perhaps this intermediate class between
+  // DatasetOp and the actual ops is not needed at all? For example, if there's no
+  // common code for all of the non-parallel ops, then they can just inherit from
+  // DatasetOp directly and we can put this class into the trash.
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_PIPELINE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc
new file mode 100644
index 0000000000..e232a64164
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.cc
@@ -0,0 +1,159 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/datasetops/project_op.h"
+#include <algorithm>
+#include <iomanip>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+ProjectOp::Builder::Builder(const std::vector<std::string> &columns_to_project)
+    : builder_columns_to_project_(columns_to_project) {}
+
+Status ProjectOp::Builder::SanityCheck() const {
+  if (builder_columns_to_project_.empty()) {
+    std::string err_msg("Columns to project is empty.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+Status ProjectOp::Builder::Build(std::shared_ptr<ProjectOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<ProjectOp>(builder_columns_to_project_);
+  return Status::OK();
+}
+
+ProjectOp::ProjectOp(const std::vector<std::string> &columns_to_project)
+    : PipelineOp(0), columns_to_project_(columns_to_project) {}
+
+void ProjectOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <ProjectOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nColumns that are projected:";
+    for (size_t i = 0; i < columns_to_project_.size(); i++) {
+      out << "\n" << columns_to_project_[i];
+    }
+    out << "\n\n";
+  }
+}
+
+// Gets a buffer from the child operator and projects the buffer.
+Status ProjectOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(p_buffer, worker_id, retry_if_eoe));
+  if (!((*p_buffer)->eoe()) && !((*p_buffer)->eof())) {
+    RETURN_IF_NOT_OK(Project(p_buffer));
+  }
+  return Status::OK();
+}
+
+Status ProjectOp::Project(std::unique_ptr<DataBuffer> *data_buffer) {
+  std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
+  while ((*data_buffer)->NumRows() > 0) {
+    TensorRow current_row;
+    RETURN_IF_NOT_OK((*data_buffer)->PopRow(&current_row));
+    TensorRow new_row;
+    (void)std::transform(projected_column_indices_.begin(), projected_column_indices_.end(),
+                         std::back_inserter(new_row), [&current_row](uint32_t x) { return current_row[x]; });
+    new_tensor_table->push_back(new_row);
+  }
+  (*data_buffer)->set_tensor_table(std::move(new_tensor_table));
+  return Status::OK();
+}
+
+// Class functor operator () override.
+// Most dataset ops operate by launching a thread (see ExecutionTree).
+// However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the
+// functor since this op runs inlined inside another operator. The function is overloaded to
+// ensure that it is not called by mistake (it will generate an error).
+Status ProjectOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ProjectOp is an inlined operator."); }
+
+int32_t ProjectOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Project operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Project operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+int32_t ProjectOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Project operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
+
+Status ProjectOp::EoeReceived(int32_t worker_id) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+
+Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
+
+// Visitor accept method for NodePass
+Status ProjectOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<ProjectOp>(), modified);
+}
+
+// Compute the column map and save it into our own column name map.
+// We cannot use the super class ComputeColMap here because we're making a modification of the
+// map from the child map.
+Status ProjectOp::ComputeColMap() {
+  if (column_name_id_map_.empty()) {
+    std::unordered_map<std::string, int32_t> child_column_name_mapping = child_[0]->column_name_id_map();
+    for (size_t i = 0; i < columns_to_project_.size(); i++) {
+      std::string &current_column = columns_to_project_[i];
+      if (child_column_name_mapping.find(current_column) == child_column_name_mapping.end()) {
+        std::string err_msg = "ProjectOp: column " + current_column + " does not exist in child operator.";
+        RETURN_STATUS_UNEXPECTED(err_msg);
+      }
+      // Setup the new column name mapping for ourselves (base class field)
+      column_name_id_map_[current_column] = i;
+      projected_column_indices_.push_back(child_column_name_mapping[current_column]);
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h
new file mode 100644
index 0000000000..c2f14d34b7
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/project_op.h
@@ -0,0 +1,127 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_
+#define DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class ProjectOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the ProjectOp is used to help manage all of the arguments
+  // for constructing it. This projection op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @param columns_to_project - The list of column names to project
+    // @return This is a constructor.
+    explicit Builder(const std::vector<std::string> &columns_to_project);
+
+    // Builder destructor.
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @param ptr - output shared_ptr to the new ProjectOp object.
+    // @return Status
+    Status Build(std::shared_ptr<ProjectOp> *ptr);
+
+   private:
+    std::vector<std::string> builder_columns_to_project_;
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the ProjectOp.
+  // @param columns_to_project - The list of column names to project
+  explicit ProjectOp(const std::vector<std::string> &columns_to_project);
+
+  // Destructor.
+  ~ProjectOp() = default;
+
+  // A print method typically used for debugging.
+  // @param out - The output stream to write output to.
+  // @param show_all - A bool to control if you want to show all info or just a summary.
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload.
+  // @notes This allows you to write the debug print info using stream operators.
+  // @param out - reference to the output stream being overloaded.
+  // @param project_op - reference to the ProjectOp to display.
+  // @return - the output stream must be returned.
+  friend std::ostream &operator<<(std::ostream &out, const ProjectOp &project_op) {
+    project_op.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // Most dataset ops operate by launching a thread (see ExecutionTree).
+  // However, the ProjectOp is defined as an inlined operator, so it is invalid to launch the
+  // functor since this op runs inlined inside another operator. The function is overloaded to
+  // ensure that it is not called by mistake (it will generate an error).
+  // @return Status - The error code returned.
+  Status operator()() override;
+
+  // Gets a buffer from the child node and projects that buffer. The caller is typically our parent node.
+  // @param p_buffer - output pointer to the projected buffer.
+  // @param worker_id - The worker id
+  // @param retry_if_eoe - Set this flag to true to allow calling pop() again after the first pop() returns EOE.
+  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
+
+  // Base-class override. Return the number of consumers in the first parent.
+  int32_t num_consumers() const override;
+
+  // Base-class override. Return the number of producers in the first child.
+  int32_t num_producers() const override;
+
+  // Base-class override for special eoe handler.
+  // Inline operators must override this because there is no connector to push eoe onto.
+  // @return Status - The error code returned.
+  Status EoeReceived(int32_t worker_id) override;
+
+  // Base-class override for special eof handler.
+  // Inline operators must override this because there is no connector to push eof onto.
+  // @return Status - The error code returned.
+  Status EofReceived(int32_t worker_id) override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ProjectOp"; }
+
+ private:
+  std::vector<std::string> columns_to_project_;
+  std::vector<int32_t> projected_column_indices_;
+
+  Status Project(std::unique_ptr<DataBuffer> *data_buffer);
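+  // Illustrative walk-through (hypothetical columns): if the child exposes
+  // {"A":0, "B":1, "C":2} and columns_to_project_ is {"C", "A"}, then ComputeColMap()
+  // below produces column_name_id_map_ = {"C":0, "A":1} and
+  // projected_column_indices_ = {2, 0}; Project() then copies child row slots {2, 0}
+  // into each outgoing row, in that order.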
+  // Computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_PROJECT_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc
new file mode 100644
index 0000000000..d12660e6f9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.cc
@@ -0,0 +1,182 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/rename_op.h"
+#include <algorithm>
+#include <iomanip>
+#include <memory>
+#include <utility>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor
+RenameOp::Builder::Builder() {
+  // Some arguments to the RenameOp constructor have a default argument that is taken
+  // from the client config.
+  // The user may choose to change these values for the construction of the RenameOp by
+  // using the various builder set methods.
+
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status RenameOp::Builder::SanityCheck() const { return Status::OK(); }
+
+// build method for RenameOp
+Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_, builder_op_connector_size_);
+  return Status::OK();
+}
+
+// constructor
+RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
+                   int32_t op_connector_size)
+    : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}
+
+// destructor
+RenameOp::~RenameOp() {}
+
+// main entry point for rename
+Status RenameOp::operator()() {
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<DataBuffer> curr_buffer;
+  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+  if (curr_buffer->buffer_flags() != DataBuffer::kDeBFlagNone) {
+    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+    std::string err_msg = "Rename operator got a control signal (EOE/EOF) as its first buffer";
+    // if the first buffer is an eoe or eof, pass it on and then return
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  while (curr_buffer->eof() == false) {
+    while (curr_buffer->eoe() == false) {
+      // push the renamed input buffer
+      MS_LOG(DEBUG) << "Rename operator pushing next buffer.";
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+    }  // end of while eoe loop
+
+    // we got an eoe, now try again until we get an eof
+    MS_LOG(DEBUG) << "Rename operator EOE received.";
+    RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
+    MS_LOG(DEBUG) << "Rename operator fetching buffer after EOE.";
+    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+  }  // end of while eof loop
+
+  MS_LOG(DEBUG) << "Rename operator EOF received.";
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
+  return Status::OK();
+}
+
+// Rename core functionality to compute the new column name id map.
+// We need to overwrite the super class ComputeColMap here because we're making a modification of the
+// map from the child map.
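+// Illustrative walk-through (hypothetical columns): with a child map of
+// {"col1":0, "col2":1}, in_columns_ = {"col1"} and out_columns_ = {"new_col1"},
+// the loop below yields column_name_id_map_ = {"new_col1":0, "col2":1}: the column id
+// is preserved and only the key is renamed.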
+Status RenameOp::ComputeColMap() {
+  if (column_name_id_map_.empty()) {
+    column_name_id_map_ = child_[0]->column_name_id_map();
+    // iterate over my index in the input vector, find the corresponding position
+    std::unordered_map<std::string, int32_t> new_col_name_id_map = {};
+    // parameter for input check
+    size_t found = 0;
+
+    // iterate over all the pairs; if there is a name match with rename, rename the column and add it to the new map.
+    // By doing it this way we recreate a new ColNameIdMap and allow for switching.
+    for (const auto &pair : column_name_id_map_) {
+      std::string name = pair.first;
+      int32_t id = pair.second;
+      // find name
+      std::vector<std::string>::iterator it;
+      it = std::find(in_columns_.begin(), in_columns_.end(), name);
+      // for input checking, we count how many of the in_columns_ names we actually find,
+      // because we iterate over the whole column map
+      if (it != in_columns_.end()) {
+        // found
+        found += 1;
+        int index = std::distance(in_columns_.begin(), it);
+        MS_LOG(DEBUG) << "Rename operator index found " << index << " value " << id << ".";
+
+        new_col_name_id_map[out_columns_[index]] = id;
+      } else {
+        // not found
+        MS_LOG(DEBUG) << "Rename operator index not found: " << id << " is the column id.";
+        new_col_name_id_map[name] = id;
+      }
+    }
+    // this input check only verifies that all of the renamed columns were found; it doesn't check everything
+    if (found != in_columns_.size()) {
+      MS_LOG(DEBUG) << "Rename operator column names found: " << found << " out of " << in_columns_.size() << ".";
+      std::string err_msg = "Renamed column doesn't exist in dataset";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+
+    // Now, overwrite our column map with the new renamed columns/id's
+    column_name_id_map_ = new_col_name_id_map;
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+
+// prints rename
+void RenameOp::Print(std::ostream &out,      // In: The output stream to print to
+                     bool show_all) const {  // In: T/F if it should print everything
+  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <RenameOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nIn columns:";
+    for (size_t i = 0; i < in_columns_.size(); ++i) {
+      out << "\n  " << in_columns_[i];
+    }
+    out << "\nOut columns:";
+    for (size_t i = 0; i < out_columns_.size(); ++i) {
+      out << "\n  " << out_columns_[i];
+    }
+    out << "\n\n";
+  }
+}
+
+Status RenameOp::EofReceived(int32_t) {
+  MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now.";
+  return Status::OK();
+}
+
+Status RenameOp::EoeReceived(int32_t) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status RenameOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<RenameOp>(), modified);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h
new file mode 100644
index 0000000000..d846bb1b40
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/rename_op.h
@@ -0,0 +1,138 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_RENAME_OP_H_
+#define DATASET_ENGINE_DATASETOPS_RENAME_OP_H_
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// forward declare
+class DataBuffer;
+
+class RenameOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the RenameOp is used to help manage all of
+  // the arguments for constructing it. Use the builder by setting each argument
+  // with the provided set methods, and then finally call the build method to execute
+  // the actual construction.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
+      builder_in_columns_ = in_col_names;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOutColNames(const std::vector<std::string> &out_col_names) {
+      builder_out_columns_ = out_col_names;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // The builder "build" method creates the RenameOp dataset operator.
+    // @return shared_ptr to the new RenameOp object
+    Status Build(std::shared_ptr<RenameOp> *);
+
+   private:
+    std::vector<std::string> builder_in_columns_;
+    std::vector<std::string> builder_out_columns_;
+    int32_t builder_op_connector_size_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor for RenameOp
+  // @param in_col_names names of columns to rename
+  // @param out_col_names names of columns after rename
+  // @param op_connector_size connector size
+  RenameOp(const std::vector<std::string> &in_col_names,   // In: Col names to consume
+           const std::vector<std::string> &out_col_names,  // In: Col names to produce
+           int32_t op_connector_size);
+
+  // Destructor
+  ~RenameOp();
+
+  Status EofReceived(int32_t) override;
+
+  Status EoeReceived(int32_t) override;
+
+  // Print function for Rename
+  // @param out output stream to print to
+  // @param show_all if it should print everything
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Provide stream operator for displaying it
+  friend std::ostream &operator<<(std::ostream &out, const RenameOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "RenameOp"; }
+
+ protected:
+  // Rename core functionality
+  // Computing the assignment of the new column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  // Variable to store the input column names
+  std::vector<std::string> in_columns_;
+
+  // Variable to store the output column names
+  std::vector<std::string> out_columns_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_RENAME_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
new file mode 100644
index 0000000000..6d3dc91ed3
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.cc
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iomanip>
+#include <memory>
+#include <utility>
+
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/opt/pass.h"
+
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+RepeatOp::Builder::Builder(int32_t count) : build_max_repeats_(count) {}
+
+Status RepeatOp::Builder::SanityCheck() const {
+  if (build_max_repeats_ < kInfiniteRepeat || build_max_repeats_ == 0) {
+    std::string err_msg("Repeat count must be > 0 or -1.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status RepeatOp::Builder::Build(std::shared_ptr<RepeatOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<RepeatOp>(build_max_repeats_);
+  return Status::OK();
+}
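+// Illustrative only (hypothetical usage): a count of 3 lets the child pipeline play through
+// three times before the EOE is finally propagated upward, while a count of
+// RepeatOp::kInfiniteRepeat (-1) repeats forever:
+//
+//   std::shared_ptr<RepeatOp> repeat_op;
+//   Status rc = RepeatOp::Builder(3).Build(&repeat_op);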
+// Constructor of the RepeatOp.
+RepeatOp::RepeatOp(int32_t count) : PipelineOp(0), max_repeats_(count), repeat_count_(0) {}
+
+// Destructor
+RepeatOp::~RepeatOp() {}
+
+// A print method typically used for debugging
+void RepeatOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <RepeatOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << " [repeats: " << max_repeats_ << "]\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nCurrent repeat count: " << repeat_count_ << "\nMax repeat count: " << max_repeats_
+        << "\nLeaf Nodes in execution path:";
+    if (!eoe_ops_.empty()) {
+      for (size_t i = 0; i < eoe_ops_.size(); i++) {
+        out << "\n  Operator: " << eoe_ops_[i]->id();
+      }
+    } else {
+      out << " None.";
+    }
+    out << "\n\n";
+  }
+}
+
+// This function returns the buffer that is at the top of our output connector. The caller is
+// typically our parent node, when the parent is asking us to provide the next buffer of data.
+// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
+// a buffer from our child.
+// This function sets the `retryIfEoe` flag when popping from the child connector. This way,
+// this function will retry to pop the connector again and will get the non-EOE buffer if any.
+Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
+  if (child_.empty()) {
+    RETURN_STATUS_UNEXPECTED("RepeatOp can't be the leaf node.");
+  }
+
+  std::unique_ptr<DataBuffer> buf;
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+  // Loop until a non-EOE buffer is received
+  while (buf->eoe()) {
+    RETURN_IF_NOT_OK(EoeReceived(worker_id));
+    if (state_ == OpState::kDeOpIdle) {
+      *p_buffer = std::move(buf);
+      return Status::OK();
+    }
+    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
+  }
+  // Check if the last buffer is an EOF
+  if (buf->eof()) {
+    RETURN_IF_NOT_OK(EofReceived(worker_id));
+  }
+  *p_buffer = std::move(buf);
+  return Status::OK();
+}
+
+// Base-class override for handling cases when an eoe is received.
+Status RepeatOp::EoeReceived(int32_t worker_id) {
+  repeat_count_++;
+  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
+                << ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
+  bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
+  bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
+  // If we've reached the requested repeat count, then flag the eoe nodes
+  // to tell them they've got one more epoch to perform. When they reach the end
+  // of the last epoch, they quit rather than loop again. This happens in two cases:
+  // 1- We are also repeated (by another repeat op) and we are at the last repetition. Or,
+  // 2- We are not repeated
+  if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1) && (!repeated || last_repeat)) {
+    for (auto &eoe_op : eoe_ops_) {
+      eoe_op->set_control_flag(kDeOpLastRepeat);
+    }
+  }
+  if (repeat_count_ == max_repeats_) {
+    repeat_count_ = 0;
+    state_ = OpState::kDeOpIdle;
+    return Status::OK();
+  }
+
+  // Invoke a reset against the eoe nodes only.
+  for (auto &eoe_op : eoe_ops_) {
+    RETURN_IF_NOT_OK(eoe_op->Reset());
+  }
+
+  return Status::OK();
+}
+
+// Class functor operator () override.
+// Most dataset ops operate by launching a thread (see ExecutionTree).
+// However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the
+// functor since this op runs inlined inside another operator. The function is overloaded to
+// ensure that it is not called by mistake (it will generate an error).
+Status RepeatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RepeatOp is an inlined operator."); }
+
+// Base-class override for handling cases when an eof is received.
+Status RepeatOp::EofReceived(int32_t worker_id) {
+  MS_LOG(DEBUG) << "Repeat operator EOF received, do nothing now.";
+  return Status::OK();
+}
+
+int32_t RepeatOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Repeat operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Repeat operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+// Drive reset actions if needed
+Status RepeatOp::Reset() {
+  // If there are nested repeats, an ancestor repeat may have us listed as one of its eoe ops.
+  // In that case, we now have to bounce the reset down to our own eoe ops.
+  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
+  for (auto &eoe_op : eoe_ops_) {
+    RETURN_IF_NOT_OK(eoe_op->Reset());
+  }
+  state_ = OpState::kDeOpRunning;
+  return Status::OK();
+}
+
+int32_t RepeatOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
+
+// Pre-Visitor accept method for NodePass
+Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call the pre-visitation
+  return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
+}
+
+// Visitor accept method for NodePass
+Status RepeatOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
new file mode 100644
index 0000000000..f5259de30e
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/repeat_op.h
@@ -0,0 +1,146 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
+#define DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class RepeatOp : public PipelineOp {
+ public:
+  static constexpr int32_t kInfiniteRepeat = -1;
+
+  // The nested builder class inside of the RepeatOp is used to help manage all of the arguments
+  // for constructing it. This repeat op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @param count - The number of repeats to do
+    // @return This is a constructor.
+    explicit Builder(int32_t count);
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new RepeatOp object
+    Status Build(std::shared_ptr<RepeatOp> *);
+
+   private:
+    int32_t build_max_repeats_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the RepeatOp.
+  // @note The builder class should be used to call it
+  // @param count - The number of repeats to do
+  explicit RepeatOp(int32_t count);
+
+  // Destructor
+  ~RepeatOp();
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param ro - reference to the RepeatOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const RepeatOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // Most dataset ops operate by launching a thread (see ExecutionTree).
+  // However, the RepeatOp is defined as an inlined operator, so it is invalid to launch the
+  // functor since this op runs inlined inside another operator. The function is overloaded to
+  // ensure that it is not called by mistake (it will generate an error).
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // This function returns the buffer that is at the top of our output connector. The caller is
+  // typically our parent node, when the parent is asking us to provide the next buffer of data.
+  // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
+  // a buffer from our child.
+  // @note This function sets the `retryIfEoe` flag when popping from the child connector. This way,
+  // this function will retry to pop the connector again and will get the non-EOE buffer if any.
+  // @param p_buffer - output pointer to the buffer that it will fetch.
+  // @param worker_id - The worker id
+  // @param retry_if_eoe - Set this flag to true to allow calling pop() again after the first pop() returns EOE.
+  // @return Status - The error code return
+  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;
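+  // Illustrative trace (hypothetical run): with max_repeats_ = 2, the parent's calls to
+  // GetNextBuffer() see row, row, ..., row, row, ..., EOE. The first EOE coming up from the
+  // child is swallowed here (the eoe ops are reset for a second pass), and only the second
+  // EOE escapes upward.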
+  // Base-class override for handling cases when an eoe is received.
+  // @param worker_id - The worker id
+  Status EoeReceived(int32_t worker_id) override;
+
+  // Base-class override for handling cases when an eof is received.
+  // @param worker_id - The worker id
+  Status EofReceived(int32_t worker_id) override;
+
+  /// \brief Reset the op
+  /// \return Status - The error code return
+  Status Reset() override;
+
+  // Base-class override. Return the number of consumers in the first parent.
+  int32_t num_consumers() const override;
+
+  // Base-class override. Return the number of producers in the first child.
+  int32_t num_producers() const override;
+
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "RepeatOp"; }
+
+  /// \brief Adds an operator to the repeat op's list of tracked leaf/eoe nodes
+  /// \param[in] eoe_op The input leaf/eoe operator to add to the list
+  void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
+
+ private:
+  int32_t max_repeats_;   // The number of repeats that the user requested
+  int32_t repeat_count_;  // A counter for the current number of executed repeats
+  std::vector<std::shared_ptr<DatasetOp>> eoe_ops_;  // List of operators that can generate EOE underneath this repeat.
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_REPEAT_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc
new file mode 100644
index 0000000000..0eb5f29eaf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.cc
@@ -0,0 +1,304 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#if defined(_WIN32) || defined(_WIN64)
+#include <stdlib.h>
+#endif
+#include <algorithm>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/datasetops/shuffle_op.h"
+#include "minddata/dataset/engine/dataset_iterator.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+constexpr int32_t ShuffleOp::kShuffleStateInit;
+constexpr int32_t ShuffleOp::kShuffleStateActive;
+constexpr int32_t ShuffleOp::kShuffleStateDrain;
+
+// Builder constructor. Creates the builder object.
+ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  build_op_connector_size_ = cfg->op_connector_size();
+  build_rows_per_buffer_ = cfg->rows_per_buffer();
+  build_shuffle_seed_ = GetSeed();
+}
+
+Status ShuffleOp::Builder::SanityCheck() const {
+  if (build_shuffle_size_ < 2) {
+    RETURN_STATUS_UNEXPECTED("Shuffle buffer size must be greater than 1.");
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status ShuffleOp::Builder::Build(std::shared_ptr<ShuffleOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<ShuffleOp>(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_,
+                                     build_reshuffle_each_epoch_, build_rows_per_buffer_);
+  return Status::OK();
+}
+
+// Constructor of the ShuffleOp
+ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
+                     int32_t rows_per_buffer)
+    : PipelineOp(op_connector_size),
+      shuffle_size_(shuffle_size),
+      shuffle_seed_(shuffle_seed),
+      reshuffle_each_epoch_(reset_every_epoch),
+      rng_(shuffle_seed),
+      buffer_counter_(0),
+      rows_per_buffer_(rows_per_buffer),
+      shuffle_buffer_(std::make_unique<TensorQTable>()),
+      shuffle_last_row_idx_(0),
+      shuffle_buffer_state_(kShuffleStateInit) {}
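+// Illustrative only (hypothetical usage, assuming the usual Set... setter naming used by the
+// other builders in this patch, e.g. a SetShuffleSize() setter): a 10000-row shuffle window,
+// with the seed and connector size falling back to the client config defaults above:
+//
+//   std::shared_ptr<ShuffleOp> shuffle_op;
+//   Status rc = ShuffleOp::Builder().SetShuffleSize(10000).Build(&shuffle_op);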
+ // If reshuffle_each_epoch is true, then the first epoch uses the given seed, + // and all subsequent epochs will then keep on using the rng_ without resetting it + if (!reshuffle_each_epoch_) { + rng_ = std::mt19937_64(shuffle_seed_); + } + + shuffle_buffer_ = std::make_unique(); + buffer_counter_ = 0; + shuffle_last_row_idx_ = 0; + shuffle_buffer_state_ = kShuffleStateInit; + return Status::OK(); +} + +// A print method typically used for debugging +void ShuffleOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [shuffle size: " << shuffle_size_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_ + << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n"; + } +} + +// Private function to add a new row to the shuffle buffer. +Status ShuffleOp::AddRowToShuffleBuffer(TensorRow new_shuffle_row) { + // If the last slot of our shuffle buffer was not the full size of the shuffle buffer then we are + // filling it during the initial fill codepath and thus growing it's size. In that case, we push + // back the new row to grow our shuffle buffer size by 1. + // If we are already at the full size, then we overwrite the last slot with our row (and the last + // slot better be empty because it should already have been swapped out during the random row + // selection that was done previously!) + if (shuffle_last_row_idx_ < (shuffle_size_ - 1)) { + shuffle_buffer_->push_back(std::move(new_shuffle_row)); + shuffle_last_row_idx_ = (shuffle_buffer_->size()) - 1; + } else { + if (!(*shuffle_buffer_)[shuffle_last_row_idx_].empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Last row of shuffle buffer should not be occupied!"); + } + (*shuffle_buffer_)[shuffle_last_row_idx_] = std::move(new_shuffle_row); + } + return Status::OK(); +} + +// Class functor operator () override. +// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work +Status ShuffleOp::operator()() { + std::unique_ptr new_buffer_table; // A tensor table to be used for output. + + // Synchronize with TaskManager once the thread is launched. + TaskManager::FindMe()->Post(); + + // Shuffle op does not have workers, and only consumes from child 0. + // Create the child iterator to fetch our data from. + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Main operator loop + while (true) { + // Do an initial populate of the shuffle buffer + RETURN_IF_NOT_OK(InitShuffleBuffer()); + + // This is our main loop exit condition, when the iterator has no more data completely. + if (child_iterator_->eof_handled()) { + break; + } + + // Next, enter into the main execution loop of the shuffle op. + // When the tail index position of our shuffle buffer goes negative it means that we've + // fully drained the data from the shuffle buffer and we're done. 
+ while (shuffle_last_row_idx_ >= 0) { + // Step 1) + // Create an output tensor table if one is not created yet. + if (!new_buffer_table) { + new_buffer_table = std::make_unique(); + } + + // Step 2) + // Randomly select a slot from our shuffle buffer and copy that row into the output + // tensor table. We remove the data from the shuffle buffer, leaving that slot + // in the table as an empty vector + int64_t random_slot = rng_() % (shuffle_last_row_idx_ + 1); + new_buffer_table->push_back(std::move((*shuffle_buffer_)[random_slot])); + + // Step 3) + // If the output tensor table is at the requested size, then create a buffer for it + // and send this buffer on it's way up the pipeline. Special case is if this is the + // last row then we also send it. + if (new_buffer_table->size() == rows_per_buffer_ || shuffle_last_row_idx_ == 0) { + auto new_buffer = std::make_unique(buffer_counter_, DataBuffer::kDeBFlagNone); + new_buffer->set_tensor_table(std::move(new_buffer_table)); + buffer_counter_++; + MS_LOG(DEBUG) << "Shuffle operator sending a buffer to output."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(new_buffer))); + } + + // Step 4) + // Take the last row from shuffle buffer, and swap it into the row position that was + // just vacated. This makes the shuffle buffer contiguous, with an empty slot at the + // tail of the shuffle buffer. + if (random_slot != shuffle_last_row_idx_) { + (*shuffle_buffer_)[random_slot] = std::move((*shuffle_buffer_)[shuffle_last_row_idx_]); + } + + // Step 5) + // Refill the last slot of the shuffle buffer with the next row from input if we are in the + // active state. + // If we are in the draining state, we do not need to fetch another row to replace the one we + // just drained. + if (shuffle_buffer_state_ == kShuffleStateActive) { + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + + if (!new_row.empty()) { + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + } else { + shuffle_buffer_state_ = kShuffleStateDrain; + } + } + + // If we are draining, reposition (decrement) our tail index in the shuffle buffer since we + // just drained a row from it. + if (shuffle_buffer_state_ == kShuffleStateDrain) { + shuffle_last_row_idx_--; + } + } + + // Since we overloaded eoeReceived function, we are responsible to flow the EOE up the + // pipepline manually now that we are done draining the shuffle buffer + MS_LOG(DEBUG) << "Shuffle operator sending EOE."; + auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + // Do not wait for any reset to be flown down from operators above us. + // Instead, manually update ourselves and then go reloop to start fetching from child operator + // right away. Any Reset() from the parent will still perform common reset actions. + RETURN_IF_NOT_OK(this->SelfReset()); + } + return Status::OK(); +} + +// Private function populate the shuffle buffer initially by fetching from the child output +// connector until the shuffle buffer is full (or there is no more data coming). +Status ShuffleOp::InitShuffleBuffer() { + MS_LOG(DEBUG) << "Shuffle operator initializing the shuffle buffer."; + + // The first phase of this operator is to read incoming buffers and then drain those + // rows from the buffers, putting them into our own local table of tensors (the shuffle + // buffer). 
+ // This shuffle buffer initialization phase stops when we've either filled up the + // shuffle buffer to it's max size, or the dataset below us is not providing any more + // rows. + if (shuffle_buffer_state_ != kShuffleStateInit) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Invalid shuffle buffer state (SHUFFLE_STATE_INIT expected)"); + } + + // Before we drop into the fetching loop, call the fetch once for the first time + // to fill the first row and grab the first buffer. + TensorRow new_row; + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + + if (child_iterator_->eof_handled()) { + MS_LOG(DEBUG) << "Shuffle operator init picked up EOF. No more epochs."; + return Status::OK(); + } + + if (new_row.empty()) { + RETURN_STATUS_UNEXPECTED("Unable to fetch a single row for shuffle buffer."); + } + + // Now fill the rest of the shuffle buffer until we are unable to get the next row or we reached + // the desired shuffle buffer size. + while (!new_row.empty() && shuffle_buffer_->size() < static_cast(shuffle_size_ - 1)) { + // Add the previously fetched row + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + + // Fetch the next row + RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row)); + } + + // If we quit the loop due to being at the shuffle size, still need to add the last row here. + if (!new_row.empty()) { + RETURN_IF_NOT_OK(AddRowToShuffleBuffer(std::move(new_row))); + shuffle_buffer_state_ = kShuffleStateActive; // Transition to the active state + } else { + // If init phase doesn't have more rows, then skip the active state and jump straight to the + // shuffle buffer draining state + shuffle_buffer_state_ = kShuffleStateDrain; + } + + MS_LOG(DEBUG) << "Shuffle operator finished intializing the shuffle buffer."; + return Status::OK(); +} + +Status ShuffleOp::EoeReceived(int32_t worker_id) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ShuffleOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h new file mode 100644 index 0000000000..86bea7cc77 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h @@ -0,0 +1,204 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
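ShuffleOp::operator()() above implements a uniform shuffle over a bounded buffer: pick a random slot, emit that row, swap the tail row into the hole, then refill the tail (active state) or shrink it (draining state). A minimal standalone sketch of the same swap-with-tail loop over plain integers, illustrative only:

#include <iostream>
#include <random>
#include <vector>

int main() {
  std::mt19937_64 rng(42);     // plays the role of shuffle_seed_ / rng_
  const int shuffle_size = 4;  // plays the role of shuffle_size_
  std::vector<int> stream = {0, 1, 2, 3, 4, 5, 6, 7};  // rows from the child op
  size_t next = 0;

  // Initial fill (InitShuffleBuffer).
  std::vector<int> buffer;
  while (next < stream.size() && buffer.size() < static_cast<size_t>(shuffle_size)) {
    buffer.push_back(stream[next++]);
  }

  int last = static_cast<int>(buffer.size()) - 1;  // shuffle_last_row_idx_
  while (last >= 0) {
    int slot = static_cast<int>(rng() % (last + 1));  // Step 2: random slot
    std::cout << buffer[slot] << ' ';                 // emit the selected row
    if (slot != last) buffer[slot] = buffer[last];    // Step 4: swap tail into hole
    if (next < stream.size()) {
      buffer[last] = stream[next++];                  // Step 5: refill while active
    } else {
      --last;                                         // draining: shrink the tail
    }
  }
  std::cout << '\n';
  return 0;
}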
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h
new file mode 100644
index 0000000000..86bea7cc77
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/shuffle_op.h
@@ -0,0 +1,204 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
+
+#include <deque>
+#include <memory>
+#include <ostream>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/engine/dataset_iterator.h"
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declare
+class ExecutionTree;
+
+class DbConnector;
+
+class DataBuffer;
+
+class ShuffleOp : public PipelineOp {
+  // Shuffle buffer state flags
+  //
+  // Shuffle buffer is in a state of being initialized
+  static constexpr int32_t kShuffleStateInit = 0;
+
+  // Shuffle buffer is in a state of being actively drained from, but refilling as well
+  static constexpr int32_t kShuffleStateActive = 1;
+
+  // Shuffle buffer is in a state of being drained
+  static constexpr int32_t kShuffleStateDrain = 2;
+
+ public:
+  // The nested builder class inside of the ShuffleOp is used to help manage all of the arguments
+  // for constructing it. The shuffle op is fairly simple, but the builder provides a
+  // consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @return This is a constructor.
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetShuffleSize(int32_t shuffle_size) {
+      build_shuffle_size_ = shuffle_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetShuffleSeed(uint32_t shuffle_seed) {
+      build_shuffle_seed_ = shuffle_seed;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      build_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetReshuffleEachEpoch(bool reshuffle_each_epoch) {
+      build_reshuffle_each_epoch_ = reshuffle_each_epoch;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      build_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new ShuffleOp object
+    Status Build(std::shared_ptr<ShuffleOp> *);
+
+   private:
+    // The builder saves all ShuffleOp construction arguments internally.
+    // The following are the arguments.
+    int32_t build_shuffle_size_;
+    uint32_t build_shuffle_seed_;
+    int32_t build_rows_per_buffer_;
+    bool build_reshuffle_each_epoch_;
+    int32_t build_op_connector_size_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the ShuffleOp
+  // @note The builder class should be used to call it
+  // @param shuffle_size - The size for the shuffle buffer
+  // @param shuffle_seed - The seed to use for random number generation
+  // @param op_connector_size - The output connector queue size
+  // @param reset_every_epoch - Whether to re-seed the rng for every epoch
+  // @param rows_per_buffer - The requested number of rows per buffer
+  ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
+            int32_t rows_per_buffer);
+
+  // Destructor
+  ~ShuffleOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param so - reference to the ShuffleOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const ShuffleOp &so) {
+    so.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work.
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Base-class override for special eoe handler.
+  // ShuffleOp must override this because it shall not perform default handling of eoe. Instead
+  // the ShuffleOp needs to manage actions related to the end of the epoch itself.
+  // @return Status - The error code return
+  Status EoeReceived(int32_t worker_id) override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ShuffleOp"; }
+
+ private:
+  // Private function to add a new row to the shuffle buffer.
+  // @return Status - The error code return
+  Status AddRowToShuffleBuffer(TensorRow new_shuffle_row);
+
+  // Private function to populate the shuffle buffer initially by fetching from the child output
+  // connector until the shuffle buffer is full (or there is no more data coming).
+  // @return Status - The error code return
+  Status InitShuffleBuffer();
+
+  // Private function to re-init the shuffle op for another epoch. Shuffle op calls this by
+  // itself rather than waiting for the reset driven from operators above it in the pipeline.
+  // @return Status - The error code return
+  Status SelfReset();
+
+  int32_t shuffle_size_;  // User config for the size of the shuffle buffer (number of rows)
+  uint32_t shuffle_seed_;
+  bool reshuffle_each_epoch_;
+  // rng_ is seeded initially with shuffle_seed_. mt19937 is used for its large period.
+  // Specifically, mt19937_64 is used to generate larger random numbers, which reduces bias when
+  // modding to fit within our desired range. We don't use a distribution
+  // (i.e. std::uniform_int_distribution) because we would need to create up to |dataset| instances
+  // of the distribution object in the common case of a perfect shuffle.
+  std::mt19937_64 rng_;
+  int32_t buffer_counter_;   // For creating new buffer ids
+  int32_t rows_per_buffer_;  // Number of rows to pack into output buffer
+  // A single (potentially large) buffer of tensor rows for performing shuffling.
+  std::unique_ptr<TensorQTable> shuffle_buffer_;
+  int32_t shuffle_last_row_idx_;  // Internal tracking of the last slot of our shuffle buffer
+  int32_t shuffle_buffer_state_;  // State tracking for the shuffle buffer phases of work
+
+  std::unique_ptr<ChildIterator> child_iterator_;  // An iterator for fetching.
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SHUFFLE_OP_H_
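A hedged usage sketch of the nested Builder above. This is a fragment, not a standalone program: it assumes the shuffle_op.h header and made-up setter values, and it elides error handling beyond the returned Status.

// Assemble a ShuffleOp through its Builder; Build() runs SanityCheck() first,
// which rejects shuffle sizes below 2.
std::shared_ptr<ShuffleOp> shuffle_op;
Status rc = ShuffleOp::Builder()
              .SetShuffleSize(10000)        // rows held in the shuffle buffer
              .SetShuffleSeed(GetSeed())    // the Builder constructor defaults to this anyway
              .SetReshuffleEachEpoch(true)  // re-seed only for the first epoch
              .SetRowsPerBuffer(32)
              .Build(&shuffle_op);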
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
new file mode 100644
index 0000000000..2fe8cbeaa6
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.cc
@@ -0,0 +1,136 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iomanip>
+#include <memory>
+#include <utility>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/datasetops/skip_op.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+SkipOp::Builder::Builder(int32_t count) : build_max_skips_(count) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status SkipOp::Builder::SanityCheck() const {
+  if (build_max_skips_ < 0) {
+    std::string err_msg("Skip count must be a positive integer or 0.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
+  return Status::OK();
+}
+
+// Constructor of the SkipOp.
+SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
+    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
+
+// Destructor
+SkipOp::~SkipOp() {}
+
+// A print method typically used for debugging
+void SkipOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <SkipOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << " [skips: " << max_skips_ << "]\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nSkip count: " << skip_count_ << "\nMax skips: " << max_skips_ << "\n\n";
+  }
+}
+
+// Base-class override for handling cases when an eoe is received.
+Status SkipOp::EoeReceived(int32_t worker_id) {
+  skip_count_ = 0;
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+
+// Main entry point for skip
+Status SkipOp::operator()() {
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<DataBuffer> curr_buffer;
+  RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+
+  while (curr_buffer->eof() == false) {
+    // Reset count
+    skip_count_ = 0;
+    while (curr_buffer->eoe() == false) {
+      // Drop the first max_skips_ rows
+      while (skip_count_ < max_skips_) {
+        if (curr_buffer->eoe() || curr_buffer->eof()) {
+          break;
+        }
+        // A buffer may hold more than one row, so drop rows in bulk
+        TensorRow drop_row;
+        int row_num = curr_buffer->NumRows();
+        int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_;
+        skip_count_ += drop_num;
+        for (int i = 0; i < drop_num; i++) {
+          RETURN_IF_NOT_OK(curr_buffer->PopRow(&drop_row));
+        }
+        if (curr_buffer->NumRows() == 0) {
+          RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+        }
+      }
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+      RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+    }
+    // We got an EOE; now try again until we get an EOF
+    MS_LOG(DEBUG) << "Skip operator EOE Received.";
+    RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
+    RETURN_IF_NOT_OK(GetNextInput(&curr_buffer));
+  }
+
+  MS_LOG(DEBUG) << "Skip operator EOF Received.";
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
+  return Status::OK();
+}
+
+// Base-class override for handling cases when an eof is received.
+Status SkipOp::EofReceived(int32_t worker_id) {
+  MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status SkipOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<SkipOp>(), modified);
+}
+}  // namespace dataset
+}  // namespace mindspore
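The bulk-drop arithmetic in SkipOp::operator()() is easy to check in isolation: a whole buffer is dropped while it still fits under the skip budget, otherwise only the remainder. A standalone sketch with made-up buffer sizes:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  const int max_skips = 5;  // the user-requested skip count
  int skip_count = 0;
  const std::vector<int> buffer_sizes = {3, 3, 3};  // rows in each incoming buffer

  for (int rows : buffer_sizes) {
    // Same expression as in the operator: take the whole buffer if it fits
    // under the budget, otherwise only what remains of the budget.
    int drop_num = (rows + skip_count < max_skips) ? rows : max_skips - skip_count;
    drop_num = std::max(drop_num, 0);  // buffers after the budget pass through untouched
    skip_count += drop_num;
    std::cout << "dropped " << drop_num << ", forwarded " << rows - drop_num << '\n';
  }
  // With a budget of 5: drops 3, then 2 (forwarding 1), then 0 (forwarding 3).
  return 0;
}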
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h
new file mode 100644
index 0000000000..a717d0efa4
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/skip_op.h
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class SkipOp : public PipelineOp {
+ public:
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @param count - The number of rows to skip
+    // @return This is a constructor.
+    explicit Builder(int32_t count);
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new SkipOp object
+    Status Build(std::shared_ptr<SkipOp> *);
+
+   private:
+    int32_t build_max_skips_;
+    int32_t builder_op_connector_size_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the SkipOp.
+  // @note The builder class should be used to call it
+  // @param count - The number of skips to do
+  explicit SkipOp(int32_t count, int32_t op_connector_size);
+
+  // Destructor
+  ~SkipOp();
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work.
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Base-class override for handling cases when an eoe is received.
+  // @param worker_id - The worker id
+  Status EoeReceived(int32_t worker_id) override;
+
+  // Base-class override for handling cases when an eof is received.
+  // @param worker_id - The worker id
+  Status EofReceived(int32_t worker_id) override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "SkipOp"; }
+
+ private:
+  int32_t max_skips_;   // The number of skips that the user requested
+  int32_t skip_count_;  // A counter for the current number of executed skips
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
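The two source operators that follow (CelebAOp, CifarOp) share the same master/worker contract: the master thread batches sampler ids into IOBlocks and round-robins them across per-worker queues, and each worker turns one block of keys into one DataBuffer. A minimal standalone sketch of just the distribution step, with illustrative values:

#include <iostream>
#include <vector>

int main() {
  const int num_workers = 2;           // plays the role of num_workers_
  const size_t rows_per_buffer = 2;    // plays the role of rows_per_buffer_
  const std::vector<int> sample_ids = {4, 0, 3, 1, 2};  // ids from the sampler

  // One inner vector of keys == one IOBlock; outer index == worker queue.
  std::vector<std::vector<std::vector<int>>> queues(num_workers);

  std::vector<int> keys;
  int buff_count = 0;
  for (int id : sample_ids) {
    keys.push_back(id);
    if (keys.size() == rows_per_buffer) {              // block is full
      queues[buff_count++ % num_workers].push_back(keys);  // round-robin dispatch
      keys.clear();
    }
  }
  if (!keys.empty()) queues[buff_count++ % num_workers].push_back(keys);  // remainder

  for (int w = 0; w < num_workers; ++w) {
    std::cout << "worker " << w << " gets " << queues[w].size() << " block(s)\n";
  }
  return 0;
}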
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
new file mode 100644
index 0000000000..9d7d5622a6
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.cc
@@ -0,0 +1,430 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
+
+#include <fstream>
+#include <iomanip>
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+
+namespace mindspore {
+namespace dataset {
+CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_num_workers_ = cfg->num_parallel_workers();
+  builder_rows_per_buffer_ = cfg->rows_per_buffer();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
+  MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << ".";
+  MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
+  RETURN_IF_NOT_OK(SanityCheck());
+  if (builder_sampler_ == nullptr) {
+    const int64_t num_samples = 0;
+    const int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
+  }
+
+  builder_schema_ = std::make_unique<DataSchema>();
+  RETURN_IF_NOT_OK(
+    builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+  // The attr row looks like: 0 1 0 0 1 ...
+  RETURN_IF_NOT_OK(
+    builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
+  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
+                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
+                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
+  if (*op == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
+  }
+
+  return Status::OK();
+}
+
+Status CelebAOp::Builder::SanityCheck() {
+  Path dir(builder_dir_);
+  std::string err_msg;
+  err_msg += dir.IsDirectory() ? "" : "CelebA path is invalid or not set\n";
+  err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is smaller than 1\n" : "";
+  return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
+}
+
+CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
+                   bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
+                   std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
+      rows_per_buffer_(rows_per_buffer),
+      folder_path_(dir),
+      decode_(decode),
+      extensions_(exts),
+      data_schema_(std::move(schema)),
+      num_rows_in_attr_file_(0),
+      dataset_type_(dataset_type) {
+  attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
+  io_block_queues_.Init(num_workers_, queue_size);
+}
+
+Status CelebAOp::LaunchThreadsAndInitOp() {
+  if (tree_ == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "tree_ not set");
+  }
+
+  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(attr_info_queue_->Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
+
+  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Walking attr file", std::bind(&CelebAOp::ParseAttrFile, this)));
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CelebAOp::WorkerEntry, this, std::placeholders::_1)));
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(ParseImageAttrInfo());
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
+
+  return Status::OK();
+}
+
+Status CelebAOp::ParseAttrFile() {
+  TaskManager::FindMe()->Post();
+  Path folder_path(folder_path_);
+  std::ifstream attr_file((folder_path / "list_attr_celeba.txt").toString());
+  if (!attr_file.is_open()) {
+    return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "Celeba attr file does not exist");
+  }
+
+  const auto PushBackToQueue = [this](std::vector<std::string> &vec, std::ifstream &attr_file,
+                                      std::ifstream &partition_file) {
+    Status s = attr_info_queue_->EmplaceBack(vec);
+    if (s.IsError()) {
+      CLOSE_FILE(attr_file, partition_file);
+      return s;
+    }
+    return Status::OK();
+  };
+
+  std::string rows_num;
+  std::string attr_name;
+  (void)getline(attr_file, rows_num);
+  try {
+    num_rows_in_attr_file_ = static_cast<int64_t>(std::stoul(rows_num));  // First line is the row count of the attr file
+  } catch (std::invalid_argument &e) {
+    RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, invalid argument.");
+  } catch (std::out_of_range &e) {
+    RETURN_STATUS_UNEXPECTED("Conversion to ulong failed, out of range.");
+  }
+
+  (void)getline(attr_file, attr_name);  // Second line is the attribute names, ignore it
+  std::string image_info;
+  std::vector<std::string> image_infos;
+  image_infos.reserve(oc_queue_size_);
+  while (getline(attr_file, image_info)) {
+    if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
+      continue;
+    }
+    image_infos.push_back(image_info);
+    if (image_infos.size() % oc_queue_size_ == 0) {
+      RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_));
+      image_infos.clear();
+    }
+  }
+  if (!image_infos.empty()) {
+    RETURN_IF_NOT_OK(PushBackToQueue(image_infos, attr_file, partition_file_));
+  }
+  std::vector<std::string> end_indicator = std::vector<std::string>(0);
+  RETURN_IF_NOT_OK(PushBackToQueue(end_indicator, attr_file, partition_file_));  // end indicator
+  CLOSE_FILE(attr_file, partition_file_);
+  return Status::OK();
+}
+
+bool CelebAOp::CheckDatasetTypeValid() {
+  if (!partition_file_.is_open()) {
+    Path folder_path(folder_path_);
+    partition_file_.open((folder_path / "list_eval_partition.txt").toString());
+    if (!partition_file_.is_open()) {
+      MS_LOG(ERROR) << "Celeba partition file does not exist!";
+      return false;
+    }
+  }
+  std::string line;
+  (void)getline(partition_file_, line);
+  std::vector<std::string> vec = Split(line);
+  if (vec.size() != 2) {
+    return false;
+  }
+  int32_t type;
+  try {
+    type = std::stoi(vec[1]);
+  } catch (std::invalid_argument &e) {
+    MS_LOG(WARNING) << "Conversion to int failed, invalid argument, " << vec[1] << ".";
+    return false;
+  } catch (std::out_of_range &e) {
+    MS_LOG(WARNING) << "Conversion to int failed, out of range, " << vec[1] << ".";
+    return false;
+  }
+  // train: 0, valid: 1, test: 2
+  if (dataset_type_ == "train" && (type == 0)) {
+    return true;
+  } else if (dataset_type_ == "valid" && (type == 1)) {
+    return true;
+  } else if (dataset_type_ == "test" && (type == 2)) {
+    return true;
+  }
+
+  return false;
+}
+
+Status CelebAOp::ParseImageAttrInfo() {
+  std::vector<std::string> image_infos;
+  bool needMoreData = true;
+  RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
+  while (!image_infos.empty() && needMoreData) {
+    for (uint32_t index = 0; index < image_infos.size(); index++) {
+      std::string image_info = image_infos[index];
+      std::vector<std::string> split = Split(image_info);
+      std::pair<std::string, std::vector<int32_t>> image_labels;
+
+      Path path(folder_path_);
+      Path file_path = path / split[0];
+      if (!extensions_.empty() && extensions_.find(file_path.Extension()) == extensions_.end()) {
+        MS_LOG(WARNING) << "Unsupported file found at " << file_path.toString().c_str() << ", its extension is "
+                        << file_path.Extension().c_str() << ".";
+        continue;
+      }
+      image_labels.first = split[0];
+      for (uint32_t label_index = 1; label_index < split.size(); label_index++) {
+        int32_t value;
+        try {
+          value = std::stoi(split[label_index]);
+        } catch (std::invalid_argument &e) {
+          RETURN_STATUS_UNEXPECTED("Conversion to int failed, invalid argument.");
+        } catch (std::out_of_range &e) {
+          RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range.");
+        }
+        image_labels.second.push_back(value);
+      }
+
+      image_labels_vec_.push_back(image_labels);
+    }
+
+    RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos));
+  }
+
+  num_rows_ = image_labels_vec_.size();
+  if (num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED(
+      "There is no valid data matching the dataset API CelebADataset. Please check the file path or dataset API "
+      "validation first.");
+  }
+  MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << ".";
+  return Status::OK();
+}
+
+std::vector<std::string> CelebAOp::Split(const std::string &line) {
+  std::string str = line;
+  std::string::size_type pos;
+  std::vector<std::string> split;
+  str += " ";
+  uint32_t size = str.size();
+  for (uint32_t index = 0; index < size;) {
+    pos = str.find(" ", index);
+    if (pos != index) {  // skip space
+      std::string s = str.substr(index, pos - index);
+      split.push_back(s);
+    }
+    index = pos + 1;
+  }
+
+  return split;
+}
+
+// Main logic: register queues with the TaskGroup, launch all threads, and do the functor's work
+Status CelebAOp::operator()() {
+  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
+  std::unique_ptr<DataBuffer> data_buffer;
+  RETURN_IF_NOT_OK(sampler_->GetNextSample(&data_buffer));
+  RETURN_IF_NOT_OK(AddIOBlock(&data_buffer));
+  return Status::OK();
+}
+
+Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) {
+  int64_t buff_count = 0;
+  while (true) {
+    std::vector<int64_t> keys;
+    keys.reserve(rows_per_buffer_);
+    int64_t row_count = 0;
+    while (!(*data_buffer)->eoe()) {
+      TensorRow sample_row;
+      RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row));
+      std::shared_ptr<Tensor> sample_ids = sample_row[0];
+      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
+        if ((*itr) >= num_rows_) {
+          MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << ".";
+          continue;
+        }
+        keys.push_back(*itr);
+        row_count++;
+        if (row_count % rows_per_buffer_ == 0) {
+          RETURN_IF_NOT_OK(io_block_queues_[buff_count++ % num_workers_]->Add(
+            std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
+          keys.clear();
+        }
+      }
+      RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
+    }
+
+    if (!keys.empty()) {
+      RETURN_IF_NOT_OK(io_block_queues_[(buff_count++) % num_workers_]->Add(
+        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
+    }
+    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
+      for (int32_t i = 0; i < num_workers_; i++) {
+        RETURN_IF_NOT_OK(
+          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
+      }
+      return Status::OK();
+    } else {  // Not the last repeat: the master thread sleeps and waits for the wake-up from reset
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buff_count++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
+      wp_.Clear();
+      RETURN_IF_NOT_OK(sampler_->GetNextSample(data_buffer));
+    }
+  }
+}
+
+Status CelebAOp::WorkerEntry(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  int64_t buffer_id = worker_id;
+  std::unique_ptr<IOBlock> io_block;
+  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
+  while (io_block != nullptr) {
+    if (io_block->eoe() == true) {
+      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
+      buffer_id = worker_id;
+    } else if (io_block->eof() == true) {
+      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
+    } else {
+      std::vector<int64_t> keys;
+      RETURN_IF_NOT_OK(io_block->GetKeys(&keys));
+      if (keys.empty()) {
+        return Status::OK();  // empty key is a quit signal for workers
+      }
+      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
+      RETURN_IF_NOT_OK(LoadBuffer(keys, &db));
+      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
+      buffer_id += num_workers_;
+    }
+    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block));
+  }
+  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unexpected nullptr received in worker");
+}
+
+Status CelebAOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
+  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
+  for (const auto &key : keys) {
+    TensorRow row;
+    RETURN_IF_NOT_OK(LoadTensorRow(key, image_labels_vec_[key], &row));
+    deq->push_back(std::move(row));
+  }
+
+  (*db)->set_tensor_table(std::move(deq));
+  return Status::OK();
+}
+
+Status CelebAOp::LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label,
+                               TensorRow *row) {
+  std::shared_ptr<Tensor> image;
+  std::shared_ptr<Tensor> label;
+
+  Path path(folder_path_);
+  Path image_path = path / image_label.first;
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, image_path.toString()));
+  if (decode_ == true) {
+    Status rc = Decode(image, &image);
+    if (rc.IsError()) {
+      image = nullptr;
+      std::string err_msg = "Fail to decode image: " + image_path.toString();
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
+    }
+  }
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(),
+                                        TensorShape({1, (uint32_t)image_label.second.size()}),
+                                        data_schema_->column(1).type()));
+  RETURN_IF_NOT_OK(label->Zero());
+  for (uint32_t index = 0; index < image_label.second.size(); index++) {
+    if (image_label.second[index] == 1) {
+      label->SetItemAt<uint32_t>({0, static_cast<dsize_t>(index)}, 1);
+    } else {
+      label->SetItemAt<uint32_t>({0, static_cast<dsize_t>(index)}, 0);
+    }
+  }
+  label->Squeeze();
+
+  (*row) = TensorRow(row_id, {std::move(image), std::move(label)});
+  return Status::OK();
+}
+
+void CelebAOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <CelebAOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nNumber of rows:" << num_rows_ << "\nCelebA dir: " << folder_path_ << "\n\n";
+  }
+}
+
+// Reset Sampler and wakeup Master thread (functor)
+Status CelebAOp::Reset() {
+  RETURN_IF_NOT_OK(sampler_->ResetSampler());
+  wp_.Set();  // wake up master thread after reset is done
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status CelebAOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CelebAOp>(), modified);
+}
+
+Status CelebAOp::ComputeColMap() {
+  // Set the column name map (base class field)
+  if (column_name_id_map_.empty()) {
+    for (int32_t index = 0; index < data_schema_->NumColumns(); index++) {
+      column_name_id_map_[data_schema_->column(index).name()] = index;
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
new file mode 100644
index 0000000000..ef183f8e65
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/celeba_op.h
@@ -0,0 +1,240 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
+#define DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
+
+#include <fstream>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+
+#define CLOSE_FILE(attr_file, partition_file) \
+  do {                                        \
+    attr_file.close();                        \
+    if (partition_file.is_open()) {           \
+      partition_file.close();                 \
+    }                                         \
+  } while (false)
+
+namespace mindspore {
+namespace dataset {
+class CelebAOp : public ParallelOp, RandomAccessOp {
+ public:
+  class Builder {
+   public:
+    // Constructor for Builder class of CelebAOp
+    // @return This is a constructor.
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method
+    // @param int32_t rows_per_buffer
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t size) {
+      builder_op_connector_size_ = size;
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::set<std::string> &exts, file extensions to be read
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetExtensions(const std::set<std::string> &exts) {
+      builder_extensions_ = exts;
+      return *this;
+    }
+
+    // Setter method
+    // @param bool decode
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDecode(bool decode) {
+      builder_decode_ = decode;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string &dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetCelebADir(const std::string &dir) {
+      builder_dir_ = dir;
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string dataset_type: type to be read
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDatasetType(const std::string &dataset_type) {
+      builder_dataset_type_ = dataset_type;
+      return *this;
+    }
+    // Check validity of input args
+    // @return - The error code return
+    Status SanityCheck();
+
+    // The builder "build" method creates the final object.
+    // @param std::shared_ptr<CelebAOp> *op - DatasetOp
+    // @return - The error code return
+    Status Build(std::shared_ptr<CelebAOp> *op);
+
+   private:
+    bool builder_decode_;
+    std::string builder_dir_;
+    int32_t builder_num_workers_;
+    int32_t builder_rows_per_buffer_;
+    int32_t builder_op_connector_size_;
+    std::set<std::string> builder_extensions_;
+    std::shared_ptr<Sampler> builder_sampler_;
+    std::unique_ptr<DataSchema> builder_schema_;
+    std::string builder_dataset_type_;
+  };
+
+  // Constructor
+  // @param int32_t - num_workers - Num of workers reading images in parallel
+  // @param int32_t - rows_per_buffer - Number of images (rows) in each buffer
+  // @param std::string - dir - directory of the celeba dataset
+  // @param int32_t queue_size - connector queue size
+  // @param std::shared_ptr<Sampler> sampler - sampler tells CelebAOp what to read
+  CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
+           const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
+           std::shared_ptr<Sampler> sampler);
+
+  ~CelebAOp() override = default;
+
+  // Main Loop of CelebAOp
+  // Master thread: Fill IOBlockQueue, then go to sleep
+  // Worker thread: pulls an IOBlock from IOBlockQueue, works on it, then puts a buffer to out_connector_
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Worker thread pulls a number of IOBlocks from the IOBlock Queue, makes a buffer and pushes it to the Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code return
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Method in operator(), to fill IOBlockQueue
+  // @param std::unique_ptr<DataBuffer> *data_buffer - to fill IOBlockQueue
+  // @return Status - The error code return
+  Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "CelebAOp"; }
+
+ private:
+  // Called first when the functor is invoked: registers queues and launches all threads
+  // @return Status - The error code return
+  Status LaunchThreadsAndInitOp();
+
+  // Parse attribute file
+  // @return Status - The error code return
+  Status ParseAttrFile();
+
+  // Parse each image line in the attribute file
+  // @return Status - The error code return
+  Status ParseImageAttrInfo();
+
+  // Split attribute info with space
+  // @param std::string - line - Line from the attr or partition file
+  // @return std::vector<std::string> - strings after the split
+  std::vector<std::string> Split(const std::string &line);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> db
+  // @return Status - The error code return
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // Load a tensor row according to a pair
+  // @param row_id_type row_id - id for this tensor row
+  // @param std::pair<std::string, std::vector<int32_t>> - <image_file, attr_labels>
+  // @param TensorRow row - image & label read into this tensor row
+  // @return Status - The error code return
+  Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<int32_t>> &image_label,
+                       TensorRow *row);
+
+  // Check if a row needs to be read, according to the dataset type
+  // @return bool - whether it needs to be read
+  bool CheckDatasetTypeValid();
+
+  // reset Op
+  // @return Status - The error code return
+  Status Reset() override;
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  int32_t rows_per_buffer_;
+  std::string folder_path_;  // directory of celeba folder
+  bool decode_;
+  std::set<std::string> extensions_;  // extensions allowed
+  std::unique_ptr<DataSchema> data_schema_;
+  std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_;
+  int64_t num_rows_in_attr_file_;  // rows number specified in attr file
+  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
+  WaitPost wp_;
+  std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
+  std::string dataset_type_;
+  std::ifstream partition_file_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CELEBA_OP_H
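The reader that follows depends on the CIFAR binary layout: each CIFAR-10 record is one label byte followed by 32x32x3 image bytes, CIFAR-100 records carry two label bytes (coarse, then fine), and ReadCifar10BlockData/ReadCifar100BlockData pull kCifarBlockImageNum records per read. A quick standalone check of that block arithmetic:

#include <cstdint>
#include <iostream>

int main() {
  const uint32_t height = 32, width = 32, channels = 3;     // kCifarImage* constants
  const uint32_t image_size = height * width * channels;    // 3072 bytes per image
  const uint32_t block_images = 5;                          // kCifarBlockImageNum
  // CIFAR-10: 1 label byte per record; CIFAR-100: 2 label bytes per record.
  std::cout << "cifar10 block:  " << (image_size + 1) * block_images << " bytes\n"
            << "cifar100 block: " << (image_size + 2) * block_images << " bytes\n";
  return 0;
}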
+ */ +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" + +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +constexpr uint32_t kCifarImageHeight = 32; +constexpr uint32_t kCifarImageWidth = 32; +constexpr uint32_t kCifarImageChannel = 3; +constexpr uint32_t kCifarBlockImageNum = 5; +constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel; + +CifarOp::Builder::Builder() : sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + num_workers_ = cfg->num_parallel_workers(); + rows_per_buffer_ = cfg->rows_per_buffer(); + op_connect_size_ = cfg->op_connector_size(); + cifar_type_ = kCifar10; +} + +Status CifarOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + sampler_ = std::make_shared(start_index, num_samples); + } + schema_ = std::make_unique(); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + if (cifar_type_ == kCifar10) { + RETURN_IF_NOT_OK( + schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + } else { + RETURN_IF_NOT_OK(schema_->AddColumn( + ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + TensorShape another_scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(schema_->AddColumn( + ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar))); + } + + *ptr = std::make_shared(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, + std::move(schema_), std::move(sampler_)); + return Status::OK(); +} + +Status CifarOp::Builder::SanityCheck() { + Path dir(dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : ""; + err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : ""; + return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, + int32_t queue_size, std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_works, queue_size, std::move(sampler)), + cifar_type_(type), + rows_per_buffer_(rows_per_buf), + folder_path_(file_dir), + data_schema_(std::move(data_schema)), + row_cnt_(0), + buf_cnt_(0) { + constexpr uint64_t kUtilQueueSize = 512; + cifar_raw_data_block_ = std::make_unique>>(kUtilQueueSize); + io_block_queues_.Init(num_workers_, queue_size); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status CifarOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); itr++) { + keys.push_back(*itr); + row_cnt_++; + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { // not the last repeat. Acquire lock, sleeps master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +Status CifarOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK( + tree_->AllTasks()->CreateAsyncTask("Get cifar data block", std::bind(&CifarOp::ReadCifarBlockDataAsync, this))); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CifarOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + // The order of the following 2 functions must not be changed! 
+ RETURN_IF_NOT_OK(ParseCifarData()); // Parse cifar data and get num rows, blocking + RETURN_IF_NOT_OK(InitSampler()); // Pass numRows to Sampler + return Status::OK(); +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status CifarOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) { + return Status::OK(); // empty key is a quit signal for workers + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label). 1 function call produces 1 TensorTow in a DataBuffer +Status CifarOp::LoadTensorRow(uint64_t index, TensorRow *trow) { + std::shared_ptr label; + std::shared_ptr fine_label; + std::shared_ptr ori_image = cifar_image_label_pairs_[index].first; + std::shared_ptr copy_image = + std::make_shared(ori_image->shape(), ori_image->type(), ori_image->GetBuffer()); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), + reinterpret_cast(&cifar_image_label_pairs_[index].second[0]))); + if (cifar_image_label_pairs_[index].second.size() > 1) { + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &fine_label, data_schema_->column(2).tensorImpl(), data_schema_->column(2).shape(), + data_schema_->column(2).type(), reinterpret_cast(&cifar_image_label_pairs_[index].second[1]))); + (*trow) = TensorRow(index, {copy_image, std::move(label), std::move(fine_label)}); + } else { + (*trow) = TensorRow(index, {copy_image, std::move(label)}); + } + + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 
+// Loops over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer
+Status CifarOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
+  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
+  for (const int64_t &key : keys) {
+    TensorRow trow;
+    RETURN_IF_NOT_OK(LoadTensorRow(key, &trow));
+    deq->push_back(std::move(trow));
+  }
+  (*db)->set_tensor_table(std::move(deq));
+  return Status::OK();
+}
+
+void CifarOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <CifarOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nNumber of rows:" << num_rows_ << "\nCifar directory: " << folder_path_ << "\n\n";
+  }
+}
+
+// Reset Sampler and wake up the master thread (functor)
+Status CifarOp::Reset() {
+  RETURN_IF_NOT_OK(sampler_->ResetSampler());
+  row_cnt_ = 0;
+  wp_.Set();  // wake up master thread after reset is done
+  return Status::OK();
+}
+
+// Handshake with Sampler, allowing the Sampler to call RandomAccessOp's functions to get NumRows
+Status CifarOp::InitSampler() {
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
+  return Status::OK();
+}
+
+Status CifarOp::ReadCifarBlockDataAsync() {
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(GetCifarFiles());
+  if (cifar_type_ == kCifar10) {
+    RETURN_IF_NOT_OK(ReadCifar10BlockData());
+  } else {
+    RETURN_IF_NOT_OK(ReadCifar100BlockData());
+  }
+
+  return Status::OK();
+}
+
+Status CifarOp::ReadCifar10BlockData() {
+  constexpr uint32_t num_cifar10_records = 10000;
+  uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum;  // about 2M
+  std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
+  for (auto &file : cifar_files_) {
+    std::ifstream in(file, std::ios::binary);
+    if (!in.is_open()) {
+      std::string err_msg = file + " can not be opened.";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+
+    for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
+      (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
+      if (in.fail()) {
+        RETURN_STATUS_UNEXPECTED("Failed to read cifar file " + file);
+      }
+      (void)cifar_raw_data_block_->EmplaceBack(image_data);
+    }
+    in.close();
+  }
+  (void)cifar_raw_data_block_->EmplaceBack(std::vector<unsigned char>());  // end block
+
+  return Status::OK();
+}
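Both block readers follow the same pattern: compute a fixed block size from the per-record byte count, pull `num_records / kCifarBlockImageNum` blocks per file with one `read()` each, and treat a short read as a corrupt file. A self-contained sketch of that loop; the block constant and the file name are assumed stand-ins, not values taken from the op.

```cpp
// Sketch of fixed-size block reads over a CIFAR-10 batch file: 10000 records
// per file, each record 1 label byte + 32*32*3 image bytes, consumed in
// chunks of kBlockImages records.
#include <cstdint>
#include <fstream>
#include <iostream>
#include <vector>

int main() {
  const uint32_t kImageBytes = 32 * 32 * 3;                      // kCifarImageSize analogue
  const uint32_t kBlockImages = 32;                              // stand-in for kCifarBlockImageNum
  const uint32_t kRecords = 10000;                               // records per CIFAR-10 batch file
  const uint32_t block_size = (kImageBytes + 1) * kBlockImages;  // +1 label byte per record

  std::ifstream in("data_batch_1.bin", std::ios::binary);
  if (!in.is_open()) return 1;

  std::vector<char> block(block_size);
  for (uint32_t i = 0; i < kRecords / kBlockImages; ++i) {
    if (!in.read(block.data(), block.size())) return 1;  // short read == corrupt file
    // ... hand the block to a consumer queue, as EmplaceBack does in the op ...
  }
  std::cout << "read " << kRecords / kBlockImages << " blocks\n";
  return 0;
}
```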
+Status CifarOp::ReadCifar100BlockData() {
+  uint32_t num_cifar100_records = 0;                                  // test:10000, train:50000
+  uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum;  // about 2M
+  std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
+  for (auto &file : cifar_files_) {
+    size_t pos = file.find_last_of('/');
+    if (pos == std::string::npos) {
+      RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
+    }
+    std::string file_name(file.substr(pos + 1));
+    if (file_name.find("test") != std::string::npos) {
+      num_cifar100_records = 10000;
+    } else if (file_name.find("train") != std::string::npos) {
+      num_cifar100_records = 50000;
+    } else {
+      RETURN_STATUS_UNEXPECTED("Cifar 100 file not found!");
+    }
+
+    std::ifstream in(file, std::ios::binary);
+    if (!in.is_open()) {
+      RETURN_STATUS_UNEXPECTED(file + " can not be opened.");
+    }
+
+    for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
+      (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
+      if (in.fail()) {
+        RETURN_STATUS_UNEXPECTED("Failed to read cifar file " + file);
+      }
+      (void)cifar_raw_data_block_->EmplaceBack(image_data);
+    }
+    in.close();
+  }
+  (void)cifar_raw_data_block_->EmplaceBack(std::vector<unsigned char>());  // block end
+  return Status::OK();
+}
+
+Status CifarOp::GetCifarFiles() {
+  // Gather the names of all .bin files under the dataset directory
+  const std::string kExtension = ".bin";
+  Path dataset_directory(folder_path_);
+  auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
+  if (dirIt) {
+    while (dirIt->hasNext()) {
+      Path file = dirIt->next();
+      std::string filename = file.toString();
+      if (filename.find(kExtension) != std::string::npos) {
+        cifar_files_.push_back(filename);
+        MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
+      }
+    }
+  } else {
+    std::string err_msg = "Unable to open directory " + dataset_directory.toString();
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::sort(cifar_files_.begin(), cifar_files_.end());
+  return Status::OK();
+}
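For comparison, the scan-and-sort that GetCifarFiles performs with the dataset's own Path/DirIterator utilities can also be written against C++17 `<filesystem>`. This is only an illustrative equivalent (the directory name is assumed, and it matches on the extension rather than a substring), not what the op itself uses.

```cpp
// Standard-library equivalent of the GetCifarFiles scan: collect every .bin
// file under a directory and sort for a deterministic processing order.
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> files;
  for (const auto &entry : std::filesystem::directory_iterator("./cifar-10-batches-bin")) {
    if (entry.is_regular_file() && entry.path().extension() == ".bin") {
      files.push_back(entry.path().string());
    }
  }
  std::sort(files.begin(), files.end());  // same deterministic order as the op
  for (const auto &f : files) std::cout << f << "\n";
  return 0;
}
```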
"Cifar10Dataset" : "Cifar100Dataset"; + std::string err_msg = "There is no valid data matching the dataset API " + api + + ".Please check file path or dataset API validation first."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + cifar_raw_data_block_->Reset(); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status CifarOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + + for (uint64_t index = 0; index < cifar_image_label_pairs_.size(); ++index) { + uint32_t label = (cifar_image_label_pairs_[index].second)[0]; + (*cls_ids)[label].push_back(index); + } + + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) { + // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block() + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op)); + RETURN_IF_NOT_OK(op->GetCifarFiles()); + if (op->cifar_type_ == kCifar10) { + constexpr int64_t num_cifar10_records = 10000; + for (auto &file : op->cifar_files_) { + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + std::string err_msg = file + " can not be opened."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + *count = *count + num_cifar10_records; + } + return Status::OK(); + } else { + int64_t num_cifar100_records = 0; + for (auto &file : op->cifar_files_) { + size_t pos = file.find_last_of('/'); + if (pos == std::string::npos) { + std::string err_msg = "Invalid cifar100 file path"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + std::string file_name; + if (file.size() > 0) + file_name = file.substr(pos + 1); + else + RETURN_STATUS_UNEXPECTED("Invalid string length!"); + if (file_name.find("test") != std::string::npos) { + num_cifar100_records = 10000; + } else if (file_name.find("train") != std::string::npos) { + num_cifar100_records = 50000; + } + std::ifstream in(file, std::ios::binary); + if (!in.is_open()) { + std::string err_msg = file + " can not be opened."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + *count = num_cifar100_records; + return Status::OK(); + } +} + +// Visitor accept method for NodePass +Status CifarOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CifarOp::ComputeColMap() { + // set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (uint32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h new file mode 100644 index 0000000000..60169f32bf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h @@ -0,0 +1,236 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
new file mode 100644
index 0000000000..60169f32bf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/cifar_op.h
@@ -0,0 +1,236 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/util/services.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/wait_post.h"
+
+namespace mindspore {
+namespace dataset {
+class CifarOp : public ParallelOp, public RandomAccessOp {
+ public:
+  enum CifarType { kCifar10, kCifar100 };
+
+  class Builder {
+   public:
+    // Constructor for Builder class of CifarOp
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method
+    // @param int32_t rows_per_buffer
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t size) {
+      op_connect_size_ = size;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string &dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetCifarDir(const std::string &dir) {
+      dir_ = dir;
+      return *this;
+    }
+
+    // Setter method
+    // @param bool cifar10 - true for CIFAR-10, false for CIFAR-100
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetCifarType(const bool cifar10) {
+      if (cifar10) {
+        cifar_type_ = kCifar10;
+      } else {
+        cifar_type_ = kCifar100;
+      }
+      return *this;
+    }
+
+    // Check validity of input args
+    // @return - The error code return
+    Status SanityCheck();
+
+    // The builder "build" method creates the final object.
+    // @param std::shared_ptr<CifarOp> *op - DatasetOp
+    // @return - The error code return
+    Status Build(std::shared_ptr<CifarOp> *op);
+
+   private:
+    std::string dir_;
+    int32_t num_workers_;
+    int32_t rows_per_buffer_;
+    int32_t op_connect_size_;
+    std::shared_ptr<Sampler> sampler_;
+    std::unique_ptr<DataSchema> schema_;
+    CifarType cifar_type_;
+  };
+
+  // Constructor
+  // @param CifarType type - Cifar10 or Cifar100
+  // @param int32_t num_works - Num of workers reading images in parallel
+  // @param int32_t rows_per_buf - Number of images (rows) in each buffer
+  // @param std::string file_dir - directory of cifar dataset
+  // @param int32_t queue_size - connector queue size
+  // @param std::shared_ptr<Sampler> sampler - sampler tells CifarOp what to read
+  CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+
+  // Destructor.
+  ~CifarOp() = default;
+
+  // Worker thread pulls a number of IOBlock from IOBlock Queue, makes a buffer and pushes it to Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code return
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Main Loop of CifarOp
+  // Master thread: fills the IOBlockQueue, then goes to sleep
+  // Worker thread: pulls an IOBlock from the IOBlockQueue, works on it, then puts a buffer to out_connector_
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Function to count the number of samples in the CIFAR dataset
+  // @param dir path to the CIFAR directory
+  // @param isCIFAR10 true if CIFAR10 and false if CIFAR100
+  // @param count output arg that will hold the actual dataset size
+  // @return
+  static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "CifarOp"; }
+
+ private:
+  // Initialize Sampler, calls sampler->Init() within
+  // @return Status - The error code return
+  Status InitSampler();
+
+  // Load a tensor row according to a pair
+  // @param uint64_t index - index need to load
+  // @param TensorRow row - image & label read into this tensor row
+  // @return Status - The error code return
+  Status LoadTensorRow(uint64_t index, TensorRow *row);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> db
+  // @return Status - The error code return
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // Read block data from cifar file
+  // @return
+  Status ReadCifarBlockDataAsync();
+
+  // Called first when function is called
+  // @return
+  Status LaunchThreadsAndInitOp();
+
+  // reset Op
+  // @return Status - The error code return
+  Status Reset() override;
+
+  // Get cifar files in dir
+  // @return
+  Status GetCifarFiles();
+
+  // Read cifar10 data as block
+  // @return
+  Status ReadCifar10BlockData();
+
+  // Read cifar100 data as block
+  // @return
+  Status ReadCifar100BlockData();
+
+  // Parse cifar data
+  // @return
+  Status ParseCifarData();
+
+  // Method derived from RandomAccessOp, enables the Sampler to get all ids for each class
+  // @param cls_ids - key is the label, value is all ids for this class
+  // @return Status - The error code return
+  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  CifarType cifar_type_;
+  int32_t rows_per_buffer_;
+  std::string folder_path_;
+  std::unique_ptr<DataSchema> data_schema_;
+  int64_t row_cnt_;
+  int64_t buf_cnt_;
+
+  WaitPost wp_;
+  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
+  std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
+  std::vector<std::string> cifar_files_;
+  std::vector<std::pair<std::shared_ptr<Tensor>, std::vector<uint32_t>>> cifar_image_label_pairs_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_CIFAR_OP_H_
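The header above follows the engine's usual fluent-builder convention: setters return `*this` so calls chain, and `Build()` validates via SanityCheck before constructing the op. A self-contained miniature of that pattern; all names here are invented and deliberately unrelated to the real classes.

```cpp
// Miniature of the fluent-builder pattern: chained setters, then a Build()
// that validates before constructing the final object.
#include <iostream>
#include <memory>
#include <string>

class Op {
 public:
  Op(std::string dir, int workers) : dir_(std::move(dir)), workers_(workers) {}
  void Describe() const { std::cout << dir_ << " / " << workers_ << " workers\n"; }

 private:
  std::string dir_;
  int workers_;
};

class Builder {
 public:
  Builder &SetDir(const std::string &d) { dir_ = d; return *this; }
  Builder &SetNumWorkers(int n) { workers_ = n; return *this; }
  bool Build(std::shared_ptr<Op> *out) {
    if (dir_.empty() || workers_ <= 0) return false;  // SanityCheck analogue
    *out = std::make_shared<Op>(dir_, workers_);
    return true;
  }

 private:
  std::string dir_;
  int workers_ = 1;
};

int main() {
  std::shared_ptr<Op> op;
  if (Builder().SetDir("/data/cifar").SetNumWorkers(4).Build(&op)) op->Describe();
  return 0;
}
```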
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
new file mode 100644
index 0000000000..958514583a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.cc
@@ -0,0 +1,555 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/clue_op.h"
+
+#include <algorithm>
+#include <cmath>
+#include <fstream>
+#include <iomanip>
+#include <sstream>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "minddata/dataset/engine/jagged_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+ClueOp::Builder::Builder()
+    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
+  builder_num_workers_ = config_manager->num_parallel_workers();
+  builder_op_connector_size_ = config_manager->op_connector_size();
+  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
+  builder_worker_connector_size_ = config_manager->worker_connector_size();
+}
+
+Status ClueOp::Builder::ValidateInputs() const {
+  std::string err;
+  err += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : "";
+  err += (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) ? "Wrong sharding configs\n" : "";
+  return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err);
+}
+
+Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
+  RETURN_IF_NOT_OK(ValidateInputs());
+
+  // Throttle the number of workers if we have more workers than files!
+ if (static_cast(builder_num_workers_) > builder_clue_files_list_.size()) { + builder_num_workers_ = builder_clue_files_list_.size(); + MS_LOG(WARNING) << "ClueOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + ColKeyMap ck_map; + for (auto &p : builder_cols_to_keyword_) { + ck_map.insert({p.first, split(p.second, '/')}); + } + + std::shared_ptr clue_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, + builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, + builder_device_id_); + RETURN_IF_NOT_OK(clue_op->Init()); + *op = std::move(clue_op); + + return Status::OK(); +} + +std::vector ClueOp::Builder::split(const std::string &s, char delim) { + std::vector res; + std::stringstream ss(s); + std::string item; + + while (getline(ss, item, delim)) { + res.push_back(item); + } + return res; +} + +ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + rows_per_buffer_(rows_per_buffer), + num_rows_per_shard_(0), + all_num_rows_(0), + num_samples_(num_samples), + filename_index_(std::make_unique()), + clue_files_list_(std::move(clue_files_list)), + load_jagged_connector_(true), + cols_to_keyword_(cols_to_keyword), + shuffle_files_(shuffle_files), + finished_reading_dataset_(false), + num_devices_(num_device), + device_id_(device_id), + load_io_block_queue_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status ClueOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(clue_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + + return Status::OK(); +} + +Status ClueOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status ClueOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status ClueOp::GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t) { + nlohmann::json cursor = js; + for (int i = 0; i < key_chain.size(); i++) { + if (cursor.find(key_chain[i]) != cursor.end()) { + cursor = cursor[key_chain[i]]; + } else { + RETURN_STATUS_UNEXPECTED("Failed to find key: " + key_chain[i]); + } + } + std::string final_str = key_chain.back(); + switch (cursor.type()) { + case nlohmann::detail::value_t::string: + RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get()}, TensorShape::CreateScalar())); + break; + + case nlohmann::detail::value_t::number_integer: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case 
nlohmann::detail::value_t::number_unsigned: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_INT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case nlohmann::detail::value_t::number_float: + RETURN_IF_NOT_OK( + Tensor::CreateTensor(t, TensorImpl::kFlexible, TensorShape::CreateScalar(), DataType(DataType::DE_FLOAT32))); + (*t)->SetItemAt({0}, cursor.get()); + break; + case nlohmann::detail::value_t::array: + RETURN_IF_NOT_OK(Tensor::CreateTensor(t, {cursor.get>()}, TensorShape::CreateScalar())); + break; + default: + break; + } + return Status::OK(); +} + +Status ClueOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + if (line.empty()) { + continue; + } + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + + try { + nlohmann::json js = nlohmann::json::parse(line); + int cols_count = cols_to_keyword_.size(); + TensorRow tRow(cols_count, nullptr); + tensor_table->push_back(std::move(tRow)); + + int cout = 0; + for (auto &p : cols_to_keyword_) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(GetValue(js, p.second, &tensor)); + (*tensor_table)[rows_each_buffer][cout] = std::move(tensor); + cout++; + } + } catch (const std::exception &err) { + // Catch any exception and convert to Status return code + RETURN_STATUS_UNEXPECTED("Failed to load json file"); + } + + // RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + return Status::OK(); +} + +Status ClueOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&ClueOp::WaitToFillIOBlockQueue, this))); + + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&ClueOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. 
+ TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + return Status::OK(); +} + +Status ClueOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// A print method typically used for debugging +void ClueOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ + << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ + << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") << "\nClue files list:\n"; + for (int i = 0; i < clue_files_list_.size(); ++i) { + out << " " << clue_files_list_[i]; + } + out << "\n\n"; + } +} + +// Pops an element from a queue in io_block_queues +Status ClueOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status ClueOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +Status ClueOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +Status ClueOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair((*filename_index_)[*it], *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +void ClueOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +bool ClueOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; 
+ } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status ClueOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +Status ClueOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API CLUEDataset. Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +int64_t ClueOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + if (!line.empty()) { + count++; + } + } + + return count; +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status ClueOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +Status ClueOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetClueFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +Status ClueOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) { + int count = 0; + for (auto &p : cols_to_keyword_) { + column_name_id_map_[p.first] = count; + count++; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h new file mode 100644 index 0000000000..ab429561ec --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/clue_op.h @@ -0,0 +1,277 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; +using ColKeyMap = std::map>; + +class JaggedConnector; + +class ClueOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetClueFilesList(const std::vector &files_list) { + builder_clue_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. 
+ Builder &SetColsKeyMap(const std::map &cols_to_key) { + builder_cols_to_keyword_ = cols_to_key; + return *this; + } + + // Split string based on a character delimiter + // @return - the a string vector + std::vector split(const std::string &s, char delim); + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_clue_files_list_; + bool builder_shuffle_files_; + std::map builder_cols_to_keyword_; + }; + + // Constructor of ClueOp + ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~ClueOp() = default; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all clue files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return clue_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a clue file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. 
+ Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - clue file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // @return Status - the error code returned. + Status GetValue(const nlohmann::json &js, std::vector key_chain, std::shared_ptr *t); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t device_id_; + bool shuffle_files_; + bool finished_reading_dataset_; + int32_t num_devices_; + int64_t rows_per_buffer_; + bool load_io_block_queue_; + int64_t num_rows_per_shard_; + int64_t all_num_rows_; + int64_t num_samples_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + std::vector clue_files_list_; + WaitPost io_block_queue_wait_post_; + std::unique_ptr jagged_buffer_connector_; + QueueList> io_block_queues_; + bool load_jagged_connector_; + ColKeyMap cols_to_keyword_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc new file mode 100644 index 0000000000..daef2f284b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.cc @@ -0,0 +1,646 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/coco_op.h" + +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +const char kColumnImage[] = "image"; +const char kJsonImages[] = "images"; +const char kJsonImagesFileName[] = "file_name"; +const char kJsonId[] = "id"; +const char kJsonAnnotations[] = "annotations"; +const char kJsonAnnoSegmentation[] = "segmentation"; +const char kJsonAnnoCounts[] = "counts"; +const char kJsonAnnoSegmentsInfo[] = "segments_info"; +const char kJsonAnnoIscrowd[] = "iscrowd"; +const char kJsonAnnoBbox[] = "bbox"; +const char kJsonAnnoArea[] = "area"; +const char kJsonAnnoImageId[] = "image_id"; +const char kJsonAnnoNumKeypoints[] = "num_keypoints"; +const char kJsonAnnoKeypoints[] = "keypoints"; +const char kJsonAnnoCategoryId[] = "category_id"; +const char kJsonCategories[] = "categories"; +const char kJsonCategoriesIsthing[] = "isthing"; +const char kJsonCategoriesName[] = "name"; +const float kDefaultPadValue = -1.0; +const unsigned int kPadValueZero = 0; + +CocoOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); + builder_task_type_ = TaskType::Detection; +} + +Status CocoOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + switch (builder_task_type_) { + case TaskType::Detection: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Stuff: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoSegmentation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Keypoint: + 
RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoKeypoints), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoNumKeypoints), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + case TaskType::Panoptic: + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoBbox), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoCategoryId), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoIscrowd), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kJsonAnnoArea), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); + break; + default: + RETURN_STATUS_UNEXPECTED("Invalid task type"); + } + *ptr = std::make_shared(builder_task_type_, builder_dir_, builder_file_, builder_num_workers_, + builder_rows_per_buffer_, builder_op_connector_size_, builder_decode_, + std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status CocoOp::Builder::SanityCheck() { + Path dir(builder_dir_); + Path file(builder_file_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "Coco image folder path is invalid or not set\n" : ""; + err_msg += file.Exists() == false ? "Coco annotation json path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path, + int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size), + decode_(decode), + row_cnt_(0), + buf_cnt_(0), + task_type_(task_type), + image_folder_path_(image_folder_path), + annotation_path_(annotation_path), + rows_per_buffer_(rows_per_buffer), + sampler_(std::move(sampler)), + data_schema_(std::move(data_schema)) { + io_block_queues_.Init(num_workers_, queue_size); +} + +Status CocoOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) > num_rows_) continue; + keys->push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); + keys->clear(); + } + } + return Status::OK(); +} + +Status CocoOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); + if (sample_ids->type() != DataType(DataType::DE_INT64)) { + RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + } + RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == 
false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +void CocoOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows: " << num_rows_ << "\nCOCO Directory: " << image_folder_path_ << "\n\n"; + } +} + +Status CocoOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); + return Status::OK(); +} + +Status CocoOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { + std::shared_ptr image, coordinate; + auto itr = coordinate_map_.find(image_id); + if (itr == coordinate_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::string kImageFile = image_folder_path_ + image_id; + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + + auto bboxRow = itr->second; + std::vector bbox_row; + dsize_t bbox_row_num = static_cast(bboxRow.size()); + dsize_t bbox_column_num = 0; + for (auto bbox : bboxRow) { + if (static_cast(bbox.size()) > bbox_column_num) { + bbox_column_num = static_cast(bbox.size()); + } + } + + for (auto bbox : bboxRow) { + bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end()); + dsize_t pad_len = bbox_column_num - static_cast(bbox.size()); + if (pad_len > 0) { + for (dsize_t i = 0; i < pad_len; i++) { + bbox_row.push_back(kDefaultPadValue); + } + } + } + + std::vector bbox_dim = {bbox_row_num, bbox_column_num}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&coordinate, data_schema_->column(1).tensorImpl(), TensorShape(bbox_dim), + data_schema_->column(1).type(), + reinterpret_cast(&bbox_row[0]))); + if (task_type_ == TaskType::Detection) { + RETURN_IF_NOT_OK(LoadDetectionTensorRow(row_id, image_id, image, coordinate, trow)); + } else if (task_type_ == TaskType::Stuff || task_type_ == TaskType::Keypoint) { + RETURN_IF_NOT_OK(LoadSimpleTensorRow(row_id, image_id, image, coordinate, trow)); + } else if (task_type_ == TaskType::Panoptic) { + RETURN_IF_NOT_OK(LoadMixTensorRow(row_id, image_id, image, coordinate, trow)); + } else { + RETURN_STATUS_UNEXPECTED("Invalid task type."); + } + + return 
Status::OK(); +} + +// When task is Detection, user can get data with four columns: +// column ["image"] with datatype=uint8 +// column ["bbox"] with datatype=float32 +// column ["category_id"] with datatype=uint32 +// column ["iscrowd"] with datatype=uint32 +// By the way, column ["iscrowd"] is used for some testcases, like fasterRcnn. +// If "iscrowd" is not existed, user will get default value 0. +Status CocoOp::LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr category_id, iscrowd; + std::vector category_id_row; + std::vector iscrowd_row; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::vector annotation = itr_item->second; + for (int64_t i = 0; i < annotation.size(); i++) { + if (i % 2 == 0) { + category_id_row.push_back(annotation[i]); + } else if (i % 2 == 1) { + iscrowd_row.push_back(annotation[i]); + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), + data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), + data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd)}); + return Status::OK(); +} + +// When task is "Stuff"/"Keypoint", user can get data with three columns: +// column ["image"] with datatype=uint8 +// column ["segmentation"]/["keypoints"] with datatype=float32 +// column ["iscrowd"]/["num_keypoints"] with datatype=uint32 +Status CocoOp::LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr item; + std::vector item_queue; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + item_queue = itr_item->second; + std::vector bbox_dim = {static_cast(item_queue.size()), 1}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&item, data_schema_->column(2).tensorImpl(), TensorShape(bbox_dim), + data_schema_->column(2).type(), + reinterpret_cast(&item_queue[0]))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(coordinate), std::move(item)}); + return Status::OK(); +} + +// When task is "Panoptic", user can get data with five columns: +// column ["image"] with datatype=uint8 +// column ["bbox"] with datatype=float32 +// column ["category_id"] with datatype=uint32 +// column ["iscrowd"] with datatype=uint32 +// column ["area"] with datattype=uint32 +Status CocoOp::LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr image, + std::shared_ptr coordinate, TensorRow *trow) { + std::shared_ptr category_id, iscrowd, area; + std::vector category_id_row; + std::vector iscrowd_row; + std::vector area_row; + auto itr_item = simple_item_map_.find(image_id); + if (itr_item == simple_item_map_.end()) RETURN_STATUS_UNEXPECTED("Invalid image_id found :" + image_id); + + std::vector annotation = itr_item->second; + for (int64_t i = 0; i < annotation.size(); i++) { + if (i % 3 == 0) { + category_id_row.push_back(annotation[i]); + } else if (i % 3 == 
1) { + iscrowd_row.push_back(annotation[i]); + } else if (i % 3 == 2) { + area_row.push_back(annotation[i]); + } + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &category_id, data_schema_->column(2).tensorImpl(), TensorShape({static_cast(category_id_row.size()), 1}), + data_schema_->column(2).type(), reinterpret_cast(&category_id_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &iscrowd, data_schema_->column(3).tensorImpl(), TensorShape({static_cast(iscrowd_row.size()), 1}), + data_schema_->column(3).type(), reinterpret_cast(&iscrowd_row[0]))); + + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &area, data_schema_->column(4).tensorImpl(), TensorShape({static_cast(area_row.size()), 1}), + data_schema_->column(4).type(), reinterpret_cast(&area_row[0]))); + (*trow) = TensorRow( + row_id, {std::move(image), std::move(coordinate), std::move(category_id), std::move(iscrowd), std::move(area)}); + return Status::OK(); +} + +Status CocoOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +Status CocoOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +template +Status CocoOp::SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node) { + auto node = input_tree.find(node_name); + if (node == input_tree.end()) RETURN_STATUS_UNEXPECTED("Invalid node found in json : " + node_name); + (*output_node) = *node; + return Status::OK(); +} + +Status CocoOp::ParseAnnotationIds() { + std::ifstream in(annotation_path_); + nlohmann::json js; + in >> js; + + std::vector image_que; + nlohmann::json image_list; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonImages), &image_list)); + RETURN_IF_NOT_OK(ImageColumnLoad(image_list, &image_que)); + if (task_type_ == TaskType::Detection || task_type_ == TaskType::Panoptic) { + nlohmann::json node_categories; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonCategories), &node_categories)); + RETURN_IF_NOT_OK(CategoriesColumnLoad(node_categories)); + } + nlohmann::json annotations_list; + RETURN_IF_NOT_OK(SearchNodeInJson(js, std::string(kJsonAnnotations), &annotations_list)); + for (auto annotation : annotations_list) { + int32_t image_id = 0, id = 0; + std::string file_name; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonAnnoImageId), &image_id)); + auto itr_file = 
image_index_.find(image_id); + if (itr_file == image_index_.end()) + RETURN_STATUS_UNEXPECTED("Invalid image id of annotations : " + std::to_string(image_id)); + file_name = itr_file->second; + switch (task_type_) { + case TaskType::Detection: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(DetectionColumnLoad(annotation, file_name, id)); + break; + case TaskType::Stuff: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(StuffColumnLoad(annotation, file_name, id)); + break; + case TaskType::Keypoint: + RETURN_IF_NOT_OK(SearchNodeInJson(annotation, std::string(kJsonId), &id)); + RETURN_IF_NOT_OK(KeypointColumnLoad(annotation, file_name, id)); + break; + case TaskType::Panoptic: + RETURN_IF_NOT_OK(PanopticColumnLoad(annotation, file_name, image_id)); + break; + default: + RETURN_STATUS_UNEXPECTED("Invalid task type"); + } + } + for (auto img : image_que) { + if (coordinate_map_.find(img) != coordinate_map_.end()) image_ids_.push_back(img); + } + num_rows_ = image_ids_.size(); + return Status::OK(); +} + +Status CocoOp::ImageColumnLoad(nlohmann::json image_tree, std::vector *image_vec) { + if (image_tree.size() == 0) { + RETURN_STATUS_UNEXPECTED("No images found in " + annotation_path_); + } + for (auto img : image_tree) { + std::string file_name; + int32_t id = 0; + RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonImagesFileName), &file_name)); + RETURN_IF_NOT_OK(SearchNodeInJson(img, std::string(kJsonId), &id)); + + image_index_[id] = file_name; + image_vec->push_back(file_name); + } + return Status::OK(); +} + +Status CocoOp::DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + std::vector bbox; + nlohmann::json node_bbox; + uint32_t category_id = 0, iscrowd = 0; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoBbox), &node_bbox)); + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoCategoryId), &category_id)); + auto search_category = category_set_.find(category_id); + if (search_category == category_set_.end()) + RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + std::to_string(category_id)); + auto node_iscrowd = annotation_tree.find(kJsonAnnoIscrowd); + if (node_iscrowd != annotation_tree.end()) iscrowd = *node_iscrowd; + bbox.insert(bbox.end(), node_bbox.begin(), node_bbox.end()); + coordinate_map_[image_file].push_back(bbox); + simple_item_map_[image_file].push_back(category_id); + simple_item_map_[image_file].push_back(iscrowd); + return Status::OK(); +} + +Status CocoOp::StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + uint32_t iscrowd = 0; + std::vector bbox; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoIscrowd), &iscrowd)); + simple_item_map_[image_file].push_back(iscrowd); + nlohmann::json segmentation; + RETURN_IF_NOT_OK(SearchNodeInJson(annotation_tree, std::string(kJsonAnnoSegmentation), &segmentation)); + if (iscrowd == 0) { + for (auto item : segmentation) { + if (bbox.size() > 0) bbox.clear(); + bbox.insert(bbox.end(), item.begin(), item.end()); + coordinate_map_[image_file].push_back(bbox); + } + } else if (iscrowd == 1) { + nlohmann::json segmentation_count; + RETURN_IF_NOT_OK(SearchNodeInJson(segmentation, std::string(kJsonAnnoCounts), &segmentation_count)); + bbox.insert(bbox.end(), segmentation_count.begin(), segmentation_count.end()); + 
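// iscrowd == 1: the segmentation is run-length encoded, so the raw RLE "counts"
+    // values are stored as one coordinate row instead of polygon points.
+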
coordinate_map_[image_file].push_back(bbox); + } + return Status::OK(); +} + +Status CocoOp::KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &unique_id) { + auto itr_num_keypoint = annotation_tree.find(kJsonAnnoNumKeypoints); + if (itr_num_keypoint == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No num_keypoint found in annotations where id: " + std::to_string(unique_id)); + simple_item_map_[image_file].push_back(*itr_num_keypoint); + auto itr_keypoint = annotation_tree.find(kJsonAnnoKeypoints); + if (itr_keypoint == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No keypoint found in annotations where id: " + std::to_string(unique_id)); + coordinate_map_[image_file].push_back(*itr_keypoint); + return Status::OK(); +} + +Status CocoOp::PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, + const int32_t &image_id) { + auto itr_segments = annotation_tree.find(kJsonAnnoSegmentsInfo); + if (itr_segments == annotation_tree.end()) + RETURN_STATUS_UNEXPECTED("No segments_info found in annotations where image_id: " + std::to_string(image_id)); + for (auto info : *itr_segments) { + std::vector bbox; + uint32_t category_id = 0; + auto itr_bbox = info.find(kJsonAnnoBbox); + if (itr_bbox == info.end()) + RETURN_STATUS_UNEXPECTED("No bbox found in segments_info where image_id: " + std::to_string(image_id)); + bbox.insert(bbox.end(), itr_bbox->begin(), itr_bbox->end()); + coordinate_map_[image_file].push_back(bbox); + + RETURN_IF_NOT_OK(SearchNodeInJson(info, std::string(kJsonAnnoCategoryId), &category_id)); + auto search_category = category_set_.find(category_id); + if (search_category == category_set_.end()) + RETURN_STATUS_UNEXPECTED("category_id can't find in categories where category_id: " + + std::to_string(category_id)); + auto itr_iscrowd = info.find(kJsonAnnoIscrowd); + if (itr_iscrowd == info.end()) + RETURN_STATUS_UNEXPECTED("No iscrowd found in segments_info where image_id: " + std::to_string(image_id)); + auto itr_area = info.find(kJsonAnnoArea); + if (itr_area == info.end()) + RETURN_STATUS_UNEXPECTED("No area found in segments_info where image_id: " + std::to_string(image_id)); + simple_item_map_[image_file].push_back(category_id); + simple_item_map_[image_file].push_back(*itr_iscrowd); + simple_item_map_[image_file].push_back(*itr_area); + } + return Status::OK(); +} + +Status CocoOp::CategoriesColumnLoad(nlohmann::json categories_tree) { + if (categories_tree.size() == 0) RETURN_STATUS_UNEXPECTED("No categories found in " + annotation_path_); + for (auto category : categories_tree) { + int32_t id = 0; + std::string name; + std::vector label_info; + auto itr_id = category.find(kJsonId); + if (itr_id == category.end()) RETURN_STATUS_UNEXPECTED("No id found in categories of " + annotation_path_); + id = *itr_id; + label_info.push_back(id); + category_set_.insert(id); + + auto itr_name = category.find(kJsonCategoriesName); + if (itr_name == category.end()) + RETURN_STATUS_UNEXPECTED("No name found in categories where id: " + std::to_string(id)); + name = *itr_name; + + if (task_type_ == TaskType::Panoptic) { + auto itr_isthing = category.find(kJsonCategoriesIsthing); + if (itr_isthing == category.end()) + RETURN_STATUS_UNEXPECTED("No isthing found in categories of " + annotation_path_); + label_info.push_back(*itr_isthing); + } + label_index_.emplace_back(std::make_pair(name, label_info)); + } + return Status::OK(); +} + +Status CocoOp::InitSampler() { + 
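// Hand shake with the Sampler; this lets the Sampler call back into this RandomAccessOp to get num_rows_.
+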
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +Status CocoOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CocoOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(this->ParseAnnotationIds()); + RETURN_IF_NOT_OK(this->InitSampler()); + return Status::OK(); +} + +Status CocoOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr *tensor) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path)); + + if (decode_ == true) { + Status rc = Decode(*tensor, tensor); + if (rc.IsError()) { + RETURN_STATUS_UNEXPECTED("fail to decode file: " + path); + } + } + return Status::OK(); +} + +Status CocoOp::CountTotalRows(const std::string &dir, const std::string &file, const std::string &task, + int64_t *count) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + *count = static_cast(op->image_ids_.size()); + return Status::OK(); +} + +Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file, const std::string &task, + std::vector>> *output_class_indexing) { + std::shared_ptr op; + RETURN_IF_NOT_OK(Builder().SetDir(dir).SetFile(file).SetTask(task).Build(&op)); + RETURN_IF_NOT_OK(op->ParseAnnotationIds()); + *output_class_indexing = op->label_index_; + return Status::OK(); +} + +// Visitor accept method for NodePass +Status CocoOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status CocoOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h new file mode 100644 index 0000000000..31070c26f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/coco_op.h @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_COCO_OP_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/wait_post.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declares
+template <typename T>
+class Queue;
+
+using CoordinateRow = std::vector<std::vector<float>>;
+
+class CocoOp : public ParallelOp, public RandomAccessOp {
+ public:
+  enum class TaskType { Detection = 0, Stuff = 1, Panoptic = 2, Keypoint = 3 };
+
+  class Builder {
+   public:
+    // Constructor for Builder class of CocoOp
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method.
+    // @param const std::string &build_dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDir(const std::string &build_dir) {
+      builder_dir_ = build_dir;
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::string &build_file
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetFile(const std::string &build_file) {
+      builder_file_ = build_file;
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::string &task_type
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetTask(const std::string &task_type) {
+      if (task_type == "Detection") {
+        builder_task_type_ = TaskType::Detection;
+      } else if (task_type == "Stuff") {
+        builder_task_type_ = TaskType::Stuff;
+      } else if (task_type == "Panoptic") {
+        builder_task_type_ = TaskType::Panoptic;
+      } else if (task_type == "Keypoint") {
+        builder_task_type_ = TaskType::Keypoint;
+      }
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t op_connector_size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t rows_per_buffer
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method.
+    // @param bool do_decode
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDecode(bool do_decode) {
+      builder_decode_ = do_decode;
+      return *this;
+    }
+
+    // Check validity of input args
+    // @return Status - The error code return
+    Status SanityCheck();
+
+    // The builder "Build" method creates the final object.
+    // @param std::shared_ptr<CocoOp> *op - DatasetOp
+    // @return Status - The error code return
+    Status Build(std::shared_ptr<CocoOp> *op);
+
+   private:
+    bool builder_decode_;
+    std::string builder_dir_;
+    std::string builder_file_;
+    TaskType builder_task_type_;
+    int32_t builder_num_workers_;
+    int32_t builder_op_connector_size_;
+    int32_t builder_rows_per_buffer_;
+    std::shared_ptr<Sampler> builder_sampler_;
+    std::unique_ptr<DataSchema> builder_schema_;
+  };
+
+  // Constructor
+  // @param TaskType task_type - task type of Coco
+  // @param std::string image_folder_path - image folder path of Coco
+  // @param std::string annotation_path - annotation json path of Coco
+  // @param int32_t num_workers - number of workers reading images in parallel
+  // @param int32_t rows_per_buffer - number of images (rows) in each buffer
+  // @param int32_t queue_size - connector queue size
+  // @param bool decode - whether to decode images
+  // @param std::unique_ptr<DataSchema> data_schema - the schema of the Coco dataset
+  // @param std::shared_ptr<Sampler> sampler - sampler tells CocoOp what to read
+  CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
+         int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
+         std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+
+  // Destructor
+  ~CocoOp() = default;
+
+  // Worker thread pulls a number of IOBlock from IOBlock Queue, makes a buffer and pushes it to Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code return
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Main Loop of CocoOp
+  // Master thread: fills IOBlockQueue, then goes to sleep
+  // Worker thread: pulls an IOBlock from IOBlockQueue, works on it, then pushes the buffer to the out connector
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // @param const std::string &dir - Coco image dir path
+  // @param const std::string &file - Coco json file path
+  // @param const std::string &task - task mode of Coco task
+  // @param int64_t *count - output rows number of CocoDataset
+  static Status CountTotalRows(const std::string &dir, const std::string &file, const std::string &task,
+                               int64_t *count);
+
+  // @param const std::string &dir - Coco image dir path
+  // @param const std::string &file - Coco json file path
+  // @param const std::string &task - task mode of Coco task
+  // @param std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing - output class index
+  //     of CocoDataset
+  static Status GetClassIndexing(const std::string &dir, const std::string &file, const std::string &task,
+                                 std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+ private:
+  // Initialize Sampler, calls
sampler->Init() within
+  // @return Status - The error code return
+  Status InitSampler();
+
+  // Load a tensor row according to image id
+  // @param row_id_type row_id - id for this tensor row
+  // @param std::string image_id - image id
+  // @param TensorRow *row - image & target read into this tensor row
+  // @return Status - The error code return
+  Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
+
+  // Load a tensor row whose annotation vector is split into category_id and iscrowd tensors
+  // @param row_id_type row_id - id for this tensor row
+  // @param const std::string &image_id - image id
+  // @param std::shared_ptr<Tensor> image - image tensor
+  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
+  // @param TensorRow *trow - image & target read into this tensor row
+  // @return Status - The error code return
+  Status LoadDetectionTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
+                                std::shared_ptr<Tensor> coordinate, TensorRow *trow);
+
+  // Load a tensor row whose annotation vector becomes a single target tensor
+  // @param row_id_type row_id - id for this tensor row
+  // @param const std::string &image_id - image id
+  // @param std::shared_ptr<Tensor> image - image tensor
+  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
+  // @param TensorRow *trow - image & target read into this tensor row
+  // @return Status - The error code return
+  Status LoadSimpleTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
+                             std::shared_ptr<Tensor> coordinate, TensorRow *trow);
+
+  // Load a tensor row whose annotation vector is split into multiple target tensors
+  // @param row_id_type row_id - id for this tensor row
+  // @param const std::string &image_id - image id
+  // @param std::shared_ptr<Tensor> image - image tensor
+  // @param std::shared_ptr<Tensor> coordinate - coordinate tensor
+  // @param TensorRow *trow - image & target read into this tensor row
+  // @return Status - The error code return
+  Status LoadMixTensorRow(row_id_type row_id, const std::string &image_id, std::shared_ptr<Tensor> image,
+                          std::shared_ptr<Tensor> coordinate, TensorRow *trow);
+
+  // @param const std::string &path - path to the image file
+  // @param const ColDescriptor &col - contains tensor implementation and datatype
+  // @param std::shared_ptr<Tensor> *tensor - return
+  // @return Status - The error code return
+  Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> *db
+  // @return Status - The error code return
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // Read annotations from the annotation json file
+  // @return Status - The error code return
+  Status ParseAnnotationIds();
+
+  // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
+  // @param std::vector<int64_t> *keys - image id
+  // @return Status - The error code return
+  Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
+
+  // Called first by operator() to launch worker threads and initialize the op
+  // @return Status - The error code return
+  Status LaunchThreadsAndInitOp();
+
+  // Reset dataset state
+  // @return Status - The error code return
+  Status Reset() override;
+
+  // @param nlohmann::json image_tree - image tree of json
+  // @param std::vector<std::string> *image_vec - image file name list from the json
+  // @return Status - The error code return
+  Status ImageColumnLoad(nlohmann::json image_tree, std::vector<std::string> *image_vec);
+
+  // @param nlohmann::json categories_tree - categories tree of json
+  // @return Status - The error code return
+  Status
CategoriesColumnLoad(nlohmann::json categories_tree); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status DetectionColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status StuffColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &id - current unique id of annotation + // @return Status - The error code return + Status KeypointColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &id); + + // @param nlohmann::json categories_tree - categories tree of json + // @param const std::string &image_file - current image name in annotation + // @param const int32_t &image_id - current unique id of annotation + // @return Status - The error code return + Status PanopticColumnLoad(nlohmann::json annotation_tree, const std::string &image_file, const int32_t &image_id); + + template + Status SearchNodeInJson(nlohmann::json input_tree, std::string node_name, T *output_node); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + bool decode_; + int64_t row_cnt_; + int64_t buf_cnt_; + std::string image_folder_path_; + std::string annotation_path_; + TaskType task_type_; + int32_t rows_per_buffer_; + std::shared_ptr sampler_; + std::unique_ptr data_schema_; + + WaitPost wp_; + std::vector image_ids_; + std::map image_index_; + QueueList> io_block_queues_; + std::vector>> label_index_; + std::map coordinate_map_; + std::map> simple_item_map_; + std::set category_set_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_Coco_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc new file mode 100644 index 0000000000..773dfc78b6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.cc @@ -0,0 +1,267 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/source/generator_op.h" +#include +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +GeneratorOp::Builder::Builder() { + // Some arguments to the GeneratorOp constructor have a default argument that is taken + // from the client config. + build_buffer_size_ = kCfgRowsPerBuffer; + build_op_connector_size_ = kCfgOpConnectorSize; +} + +Status GeneratorOp::Builder::SanityCheck() { + // Update queue size to fit the prefetch requirement + MS_LOG(DEBUG) << "Generator operator sanity check, prefetch size is " << build_prefetch_size_ << "."; + if (build_prefetch_size_ > 0) { + build_op_connector_size_ = (build_prefetch_size_ + build_buffer_size_ - 1) / build_buffer_size_; + } + return Status::OK(); +} + +Status GeneratorOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(build_generator_function_, build_column_names_, build_column_types_, + build_prefetch_size_, build_buffer_size_, build_op_connector_size_); + return (*ptr)->Init(); +} + +GeneratorOp::GeneratorOp(py::function generator_function, std::vector column_names, + std::vector column_types, int32_t prefetch_size, int32_t buffer_size, + int32_t connector_size) + : PipelineOp(connector_size), + generator_function_(generator_function), + column_names_(column_names), + column_types_(column_types), + prefetch_size_(prefetch_size), + buffer_size_(buffer_size), + buffer_id_(0) {} + +GeneratorOp::~GeneratorOp() { this->Dealloc(); } + +void GeneratorOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + PipelineOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nColumn names:\n"; + for (int i = 0; i < column_names_.size(); ++i) { + out << "\n " << column_names_[i]; + } + out << "\n\n"; + } +} + +void GeneratorOp::Dealloc() noexcept { + // Setup GIL state + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + // GC the generator object within GIL + (void)generator_.dec_ref(); + // Release GIL + PyGILState_Release(gstate); +} + +// Reentrant init method. 
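+// Called once when the op is built and again from Reset() at the start of each new epoch: it
+// re-invokes the user's generator function under the GIL to obtain a fresh generator object.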
+Status GeneratorOp::Init() { + // Reset BufferID + buffer_id_ = 0; + Status ret; + { + // Acquire Python GIL + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // Invoke the generatorFunction to get generator object + try { + generator_ = generator_function_(); + } catch (const py::error_already_set &e) { + ret = Status(StatusCode::kPyFuncException, e.what()); + } + } + return ret; +} + +Status GeneratorOp::PyRowToTensorRow(py::object py_data, TensorRow *tensor_row) { + if (!py::isinstance(py_data)) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator should return a tuple of numpy arrays."); + } + py::tuple py_row = py_data.cast(); + // Check if returned number of columns matches with column names + if (py_row.size() != column_names_.size()) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, + "Generator should return same number of numpy arrays as specified in column names."); + } + // Iterate over two containers simultaneously for memory copy + for (int i = 0; i < py_row.size(); ++i) { + py::object ret_py_ele = py_row[i]; + if (!py::isinstance(ret_py_ele)) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, + "Generator should return a tuple of numpy arrays."); + } + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, ret_py_ele.cast())); + if ((!column_types_.empty()) && (column_types_[i] != DataType::DE_UNKNOWN) && + (column_types_[i] != tensor->type())) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, "Generator type check failed."); + } + tensor_row->push_back(tensor); + } + return Status(StatusCode::kOK, ""); +} + +Status GeneratorOp::FillBuffer(TensorQTable *tt) { + for (int i = 0; i < buffer_size_; i++) { + TensorRow row; + RETURN_IF_NOT_OK(PyRowToTensorRow(generator_.attr("__next__")(), &row)); + tt->push_back(std::move(row)); + } + return Status::OK(); +} + +// Entry point for Generator, called by launch() +// Note that this function is very easy to break because of the Python GIL mechanism +// The master thread has the following workflow +// +// while !eof: +// Try: +// Prepare one data buffer GIL, Can throw +// Catch: +// Fetch Python Exception GIL +// Check if Exception is StopIteration (EOE) GIL +// Restore Python Exception GIL +// If not StopIteration: +// Return Status PyFuncException +// +// Push data buffer to connector Block +// +// if EOE +// Push EOE Block +// if more epoch: +// Block until next epoch Block +// else: +// Push EOF Block +// eof = true +// Return Status OK +// +// Note that any modification of this function need to guarantee: +// 1. All "Require GIL" operations are protected by GIL +// SegFault / Deadlock will occur if this condition is not fulfilled. +// 2. All "Block" operations are free from GIL, all block target are registered with tree. +// Deadlock will occur if this condition is not fulfilled +// 3. No Python GC should be triggered outside of GIL. 
+// SegFault will occur is this condition is not fulfilled +// +Status GeneratorOp::operator()() { + // Handshake with TaskManager to synchronize thread creation + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + std::unique_ptr fetched_buffer; + bool eof = false; + while (!eof) { + // Create new buffer each iteration + fetched_buffer = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::unique_ptr fetched_table = std::make_unique(); + bool eoe = false; + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + RETURN_IF_NOT_OK(FillBuffer(fetched_table.get())); + } catch (py::error_already_set &e) { + eoe = e.matches(PyExc_StopIteration); + // Restore exception to python + e.restore(); + // Pop up non StopIteration Python Exception + if (!eoe) { + return Status(StatusCode::kPyFuncException, __LINE__, __FILE__, e.what()); + } + } + } + if (fetched_table->size() > 0) { + fetched_buffer->set_tensor_table(std::move(fetched_table)); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer))); + } + if (eoe) { + // Push out EOE upon StopIteration exception from generator + MS_LOG(DEBUG) << "Generator operator sends out EOE."; + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + // If last repeat or not repeated, push out EOF and exit master loop + MS_LOG(DEBUG) << "Generator operator sends out EOF."; + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + MS_LOG(DEBUG) << "Generator operator main execution loop complete."; + eof = true; + } else { + // Waiting for repeatOp to start new epoch + // If Reset() is called first by repeat op, this wait() will return right away. + // If Reset() is not called yet, this wait() will block until reset. + RETURN_IF_NOT_OK(wp_.Wait()); + // Clear the status of the wait post + wp_.Clear(); + } + } + } + return Status::OK(); +} + +Status GeneratorOp::Reset() { + // Reset Op state + RETURN_IF_NOT_OK(this->Init()); + // Wake up master thread + wp_.Set(); + return Status(StatusCode::kOK, "GeneratorOp Reset Succeed"); +} + +// Visitor accept method for NodePass +Status GeneratorOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status GeneratorOp::ComputeColMap() { + // Setup column names map (base class field) + if (column_name_id_map_.empty()) { + for (int i = 0; i < column_names_.size(); ++i) { + column_name_id_map_[column_names_[i]] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h new file mode 100644 index 0000000000..d09bfc3d71 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/generator_op.h @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/pipeline_op.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +#pragma GCC visibility push(hidden) + +class GeneratorOp : public PipelineOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetGeneratorFunction(py::function generator_function) { + build_generator_function_ = generator_function; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetColumnNames(const std::vector &column_names) { + build_column_names_ = column_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetColumnTypes(const std::vector &column_types) { + build_column_types_ = column_types; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPrefetchSize(int32_t prefetch_size) { + build_prefetch_size_ = prefetch_size; + return *this; + } + + // The builder "build" method creates the final object. + // @return shared_ptr to the new GeneratorOp object + Status Build(std::shared_ptr *); + + private: + // The builder saves all GeneratorOp construction arguments internally. + // The following are the arguments. + py::function build_generator_function_; + std::vector build_column_names_; + std::vector build_column_types_; + + int32_t build_prefetch_size_ = 0; + int32_t build_buffer_size_; + int32_t build_op_connector_size_; + + Status SanityCheck(); + }; + + GeneratorOp(py::function generator_function, std::vector column_names, + std::vector column_types, int32_t prefetch_size, int32_t buffer_size, int32_t connector_size); + + ~GeneratorOp(); + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param generator_op - reference to the GeneratorOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const GeneratorOp &generator_op) { + generator_op.Print(out, false); + return out; + } + + // Class functor operator () override. + // All DatasetOps operate by launching a thread (see ExecutionTree). 
This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status - The error code return + Status operator()() override; + + // Overrides base class reset method. When an operator does a reset, it cleans up any state + // info from it's previous execution and then initializes itself so that it can be executed + // again. + // @return Status - The error code return + Status Reset() override; + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "GeneratorOp"; } + + private: + py::function generator_function_; + std::vector column_names_; + std::vector column_types_; + int32_t prefetch_size_; + int32_t buffer_size_; + + py::object generator_; + int32_t buffer_id_; + + WaitPost wp_; + + Status Init(); + + void Dealloc() noexcept; + + Status PyRowToTensorRow(py::object py_data, TensorRow *tensor_row); + + Status FillBuffer(TensorQTable *tt); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; +}; + +#pragma GCC visibility pop +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_GENERATOR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc new file mode 100644 index 0000000000..85839303db --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.cc @@ -0,0 +1,429 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status ImageFolderOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, builder_recursive_, builder_decode_, + builder_extensions_, builder_labels_to_read_, std::move(builder_schema_), + std::move(builder_sampler_)); + return Status::OK(); +} + +Status ImageFolderOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "ImageFolder path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, + bool recursive, bool do_decode, const std::set &exts, + const std::map &map, std::unique_ptr data_schema, + std::shared_ptr sampler) + : ParallelOp(num_wkrs, queue_size, std::move(sampler)), + rows_per_buffer_(rows_per_buffer), + folder_path_(file_dir), + recursive_(recursive), + decode_(do_decode), + extensions_(exts), + class_index_(map), + data_schema_(std::move(data_schema)), + row_cnt_(0), + buf_cnt_(0), + sampler_ind_(0), + dirname_offset_(0) { + folder_name_queue_ = std::make_unique>(num_wkrs * queue_size); + image_name_queue_ = std::make_unique>(num_wkrs * queue_size); + io_block_queues_.Init(num_workers_, queue_size); +} + +// Master thread that pulls the prescan worker's results. 
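+// (each prescan worker pushes one sorted FolderImagesPair per folder it walks, then a nullptr end signal).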
+// Keep collecting results until all prescan workers quit +// Then consolidate 2 level shuffles together into 1 giant vector +// calculate numRows then return +Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { + std::vector v; + int64_t cnt = 0; + while (cnt != num_workers_) { // count number of end signals + FolderImagesPair p; + RETURN_IF_NOT_OK(image_name_queue_->PopFront(&p)); + if (p == nullptr) { + cnt++; + } else { + v.push_back(p); + } + } + std::sort(v.begin(), v.end(), + [](const FolderImagesPair &lhs, const FolderImagesPair &rhs) { return lhs->first < rhs->first; }); + // following loop puts the 2 level of shuffles together into 1 vector + for (size_t ind = 0; ind < v.size(); ++ind) { + while (v[ind]->second.empty() == false) { + MS_ASSERT(!(v[ind]->first.empty())); // make sure that v[ind]->first.substr(1) is not out of bound + v[ind]->second.front()->second = class_index_.empty() ? ind : class_index_[v[ind]->first.substr(1)]; + image_label_pairs_.push_back(v[ind]->second.front()); + v[ind]->second.pop(); + } + } + image_label_pairs_.shrink_to_fit(); + num_rows_ = image_label_pairs_.size(); + if (num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset " + "API validation first."); + } + // free memory of two queues used for pre-scan + folder_name_queue_->Reset(); + image_name_queue_->Reset(); + return Status::OK(); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status ImageFolderOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + TensorRow sample_row; + RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + keys.push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK( + io_block_queues_[buf_cnt_++ % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(keys, IOBlock::kDeIoBlockNone))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { // not the last repeat. 
Sleep master thread, wait for the wake-up from reset + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status ImageFolderOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); // empty key is a quit signal for workers + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer +Status ImageFolderOp::LoadTensorRow(row_id_type row_id, ImageLabelPair pairPtr, TensorRow *trow) { + std::shared_ptr image, label; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), + reinterpret_cast(&pairPtr->second))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, folder_path_ + (pairPtr->first))); + + if (decode_ == true) { + Status rc = Decode(image, &image); + if (rc.IsError()) { + std::string err = "Fail to decode image:" + folder_path_ + (pairPtr->first); + RETURN_STATUS_UNEXPECTED(err); + } + } + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 
1 function call produces 1 buffer +Status ImageFolderOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void ImageFolderOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nImageFolder directory: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status ImageFolderOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status ImageFolderOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status ImageFolderOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + for (size_t i = 0; i < image_label_pairs_.size(); ++i) { + (*cls_ids)[image_label_pairs_[i]->second].push_back(i); + } + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +// Worker Entry for pre-scanning all the folders and do the 1st level shuffle +// Worker pull a file name from mFoldernameQueue (which is a Queue), walks all the images under that foldername +// After walking is complete, sort all the file names (relative path to all jpeg files under the same directory ) +// (Sort is automatically conducted using a set which is implemented using a Red-Black Tree) +// Add the sorted filenames in to a queue. The make a pair (foldername, queue*), +// foldername is used for 2nd level sorting. +// FYI: 1st level sorting: sort all images under the same directory. 
+// FYI: 2nd level sorting: sort all folder names +// push this pair to mImagenameQueue (which is again a Queue) +Status ImageFolderOp::PrescanWorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::string folder_name; + RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); + while (folder_name.empty() == false) { + Path folder(folder_path_ + folder_name); + std::shared_ptr dirItr = Path::DirIterator::OpenDirectory(&folder); + if (folder.Exists() == false || dirItr == nullptr) { + RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_name); + } + std::set imgs; // use this for ordering + while (dirItr->hasNext()) { + Path file = dirItr->next(); + if (extensions_.empty() || extensions_.find(file.Extension()) != extensions_.end()) { + (void)imgs.insert(file.toString().substr(dirname_offset_)); + } else { + MS_LOG(WARNING) << "Image folder operator unsupported file found: " << file.toString() + << ", extension: " << file.Extension() << "."; + } + } + FolderImagesPair p = std::make_shared>>(); + p->first = folder_name; + for (const std::string &img : imgs) { + p->second.push(std::make_shared>(img, 0)); + } + RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(p)); + RETURN_IF_NOT_OK(folder_name_queue_->PopFront(&folder_name)); + } + RETURN_IF_NOT_OK(image_name_queue_->EmplaceBack(nullptr)); // end signal + return Status::OK(); +} + +// This helper function recursively walks all foldernames, and send each foldername to mFoldernameQueue +// if mRecursive == false, don't go into folder of folders +Status ImageFolderOp::RecursiveWalkFolder(Path *dir) { + std::shared_ptr dir_itr = Path::DirIterator::OpenDirectory(dir); + RETURN_UNEXPECTED_IF_NULL(dir_itr); + while (dir_itr->hasNext()) { + Path subdir = dir_itr->next(); + if (subdir.IsDirectory()) { + if (class_index_.empty() || + class_index_.find(subdir.toString().substr(dirname_offset_ + 1)) != class_index_.end()) { + RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack(subdir.toString().substr(dirname_offset_))); + } + if (recursive_ == true) { + RETURN_IF_NOT_OK(RecursiveWalkFolder(&subdir)); + } + } + } + return Status::OK(); +} + +// A thread that calls RecursiveWalkFolder +Status ImageFolderOp::startAsyncWalk() { + TaskManager::FindMe()->Post(); + Path dir(folder_path_); + if (dir.Exists() == false || dir.IsDirectory() == false) { + RETURN_STATUS_UNEXPECTED("Error unable to open: " + folder_path_); + } + dirname_offset_ = folder_path_.length(); + RETURN_IF_NOT_OK(RecursiveWalkFolder(&dir)); + // send out num_workers_ end signal to mFoldernameQueue, 1 for each worker. + // Upon receiving end Signal, worker quits and set another end Signal to mImagenameQueue. + for (int32_t ind = 0; ind < num_workers_; ++ind) { + RETURN_IF_NOT_OK(folder_name_queue_->EmplaceBack("")); // end signal + } + return Status::OK(); +} + +Status ImageFolderOp::LaunchThreadsAndInitOp() { + RETURN_UNEXPECTED_IF_NULL(tree_); + // Registers QueueList and individual Queues for interrupt services + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + // The following code launch 3 threads group + // 1) A thread that walks all folders and push the folder names to a util:Queue mFoldernameQueue. 
+
+Status ImageFolderOp::LaunchThreadsAndInitOp() {
+  RETURN_UNEXPECTED_IF_NULL(tree_);
+  // Register the QueueList and individual Queues for interrupt services
+  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(folder_name_queue_->Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(image_name_queue_->Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
+  // The following code launches three groups of threads:
+  // 1) A thread that walks all folders and pushes the folder names to folder_name_queue_ (a util Queue).
+  // 2) Workers that pull folder names from folder_name_queue_, walk the folders and push the sorted image
+  //    names to image_name_queue_.
+  // 3) The main workers that load DataBuffers by reading all images.
+  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("walk dir", std::bind(&ImageFolderOp::startAsyncWalk, this)));
+  RETURN_IF_NOT_OK(
+    tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::PrescanWorkerEntry, this, std::placeholders::_1)));
+  RETURN_IF_NOT_OK(
+    tree_->LaunchWorkers(num_workers_, std::bind(&ImageFolderOp::WorkerEntry, this, std::placeholders::_1)));
+  TaskManager::FindMe()->Post();
+  // The order of the following 2 functions must not be changed!
+  RETURN_IF_NOT_OK(this->PrescanMasterEntry(folder_path_));  // Master thread of pre-scan workers; blocking
+  RETURN_IF_NOT_OK(this->InitSampler());                     // pass NumRows to the Sampler
+  return Status::OK();
+}
+
+Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts,
+                                          int64_t *num_rows, int64_t *num_classes, int64_t dev_id, int64_t num_dev) {
+  Path dir(path);
+  std::string err_msg = "";
+  int64_t row_cnt = 0;
+  err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : "";
+  err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : "";
+  err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : "";
+  if (err_msg.empty() == false) {
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::queue<std::string> foldernames;
+  std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir);
+  while (dir_itr->hasNext()) {
+    Path subdir = dir_itr->next();
+    if (subdir.IsDirectory()) {
+      foldernames.push(subdir.toString());
+    }
+  }
+  (*num_classes) = foldernames.size();
+  while (foldernames.empty() == false) {
+    Path subdir(foldernames.front());
+    dir_itr = Path::DirIterator::OpenDirectory(&subdir);
+    while (dir_itr->hasNext()) {
+      Path file = dir_itr->next();  // advance the iterator and test the file's extension, not the folder's
+      if (exts.empty() || exts.find(file.Extension()) != exts.end()) {
+        ++row_cnt;
+      }
+    }
+    foldernames.pop();
+  }
+  // Ceiling division: each shard receives at most ceil(row_cnt / num_dev) rows
+  (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<ImageFolderOp>(), modified);
+}
+
+Status ImageFolderOp::ComputeColMap() {
+  // Set the column name map (base class field)
+  if (column_name_id_map_.empty()) {
+    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
+      column_name_id_map_[data_schema_->column(i).name()] = i;
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
new file mode 100644
index 0000000000..153751d3c5
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/image_folder_op.h
@@ -0,0 +1,274 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; + +using ImageLabelPair = std::shared_ptr>; +using FolderImagesPair = std::shared_ptr>>; + +class ImageFolderOp : public ParallelOp, public RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of ImageFolderOp + // @param int32_t numWrks - number of parallel workers + // @param dir - directory folder got ImageNetFolder + Builder(); + + // Destructor. + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_op_connector_size_ = size; + return *this; + } + + // Setter method + // @param std::set & exts, file extensions to be read + // @return Builder setter method returns reference to the builder. + Builder &SetExtensions(const std::set &exts) { + builder_extensions_ = exts; + return *this; + } + + // Setter method + // @paramconst std::map& map - a class name to label map + // @return + Builder &SetClassIndex(const std::map &map) { + builder_labels_to_read_ = map; + return *this; + } + + // Setter method + // @param bool do_decode + // @return Builder setter method returns reference to the builder. + Builder &SetDecode(bool do_decode) { + builder_decode_ = do_decode; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. 
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string &dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetImageFolderDir(const std::string &dir) {
+      builder_dir_ = dir;
+      return *this;
+    }
+
+    // Whether dirs are walked recursively
+    // @param bool recursive - if set to false, only get dirs in the top level dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRecursive(bool recursive) {
+      builder_recursive_ = recursive;
+      return *this;
+    }
+
+    // Check validity of input args
+    // @return Status - The error code returned
+    Status SanityCheck();
+
+    // The builder "build" method creates the final object.
+    // @param std::shared_ptr<ImageFolderOp> *op - DatasetOp
+    // @return Status - The error code returned
+    Status Build(std::shared_ptr<ImageFolderOp> *op);
+
+   private:
+    bool builder_decode_;
+    bool builder_recursive_;
+    std::string builder_dir_;
+    int32_t builder_num_workers_;
+    int32_t builder_rows_per_buffer_;
+    int32_t builder_op_connector_size_;
+    std::set<std::string> builder_extensions_;
+    std::shared_ptr<Sampler> builder_sampler_;
+    std::unique_ptr<DataSchema> builder_schema_;
+    std::map<std::string, int32_t> builder_labels_to_read_;
+  };
+
+  // Constructor
+  // @param int32_t num_wkrs - Number of workers reading images in parallel
+  // @param int32_t rows_per_buffer - Number of images (rows) in each buffer
+  // @param std::string file_dir - directory of the ImageNetFolder
+  // @param int32_t queue_size - connector queue size
+  // @param std::set<std::string> exts - set of file extensions to read; if empty, read everything under the dir
+  // @param std::shared_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
+  ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive,
+                bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
+                std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+
+  // Destructor.
+  ~ImageFolderOp() = default;
+
+  // Initialize ImageFolderOp-related variables; calls the function to walk all files
+  // @param std::string dir - file directory of the ImageNetFolder
+  // @return Status - The error code returned
+  Status PrescanMasterEntry(const std::string &dir);
+
+  // Worker thread pulls a number of IOBlocks from the IOBlock Queue, makes a buffer and pushes it to the Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code returned
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Worker thread for the pre-scan phase: pulls folder names and produces sorted image lists
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code returned
+  Status PrescanWorkerEntry(int32_t worker_id);
+
+  // Main Loop of ImageFolderOp
+  // Master thread: fills the IOBlockQueue, then goes to sleep
+  // Worker thread: pulls an IOBlock from the IOBlockQueue, works on it, then puts a buffer on the out connector
+  // @return Status - The error code returned
+  Status operator()() override;
+
+  // Method derived from RandomAccessOp; enables the Sampler to get all ids for each class
+  // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key is label, value is all row ids of this class
+  // @return Status - The error code returned
+  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
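
// Editor's aside (illustrative, not part of this patch): a hypothetical construction of an
// ImageFolderOp through this Builder; the directory and parameter values are made up.
//
//   std::shared_ptr<ImageFolderOp> op;
//   Status rc = ImageFolderOp::Builder()
//                   .SetImageFolderDir("/data/train")
//                   .SetNumWorkers(4)
//                   .SetRowsPerBuffer(16)
//                   .SetExtensions({".jpg", ".jpeg"})
//                   .SetDecode(true)
//                   .Build(&op);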
+
+  // This function is a hack! It returns the num_classes and num_rows. The result
+  // may not be consistent with what image_folder_op is actually going to return.
+  // Use this at your own risk!
+  static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows,
+                                    int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1);
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ImageFolderOp"; }
+
+ private:
+  // Initialize the Sampler; calls sampler->HandshakeRandomAccessOp() within
+  // @return Status - The error code returned
+  Status InitSampler();
+
+  // Load a tensor row according to an <image name, label> pair
+  // @param row_id_type row_id - id for this tensor row
+  // @param ImageLabelPair pair - <image name, label>
+  // @param TensorRow row - image & label read into this tensor row
+  // @return Status - The error code returned
+  Status LoadTensorRow(row_id_type row_id, ImageLabelPair pair, TensorRow *row);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> db
+  // @return Status - The error code returned
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // @param Path *dir - dir to walk all images under
+  // @return Status - The error code returned
+  Status RecursiveWalkFolder(Path *dir);
+
+  // Start the walk of all dirs
+  // @return Status - The error code returned
+  Status startAsyncWalk();
+
+  // Called first when the functor is entered
+  // @return Status - The error code returned
+  Status LaunchThreadsAndInitOp();
+
+  // Reset Op
+  // @return Status - The error code returned
+  Status Reset() override;
+
+  // Private function for computing the assignment of the column name map.
+  // @return Status - The error code returned
+  Status ComputeColMap() override;
+
+  int32_t rows_per_buffer_;
+  std::string folder_path_;  // directory of image folder
+  bool recursive_;
+  bool decode_;
+  std::set<std::string> extensions_;  // extensions allowed
+  std::map<std::string, int32_t> class_index_;
+  std::unique_ptr<DataSchema> data_schema_;
+  int64_t row_cnt_;
+  int64_t buf_cnt_;
+  int64_t sampler_ind_;
+  int64_t dirname_offset_;
+  WaitPost wp_;
+  std::vector<ImageLabelPair> image_label_pairs_;
+  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;  // queues of IOBlocks
+  std::unique_ptr<Queue<std::string>> folder_name_queue_;
+  std::unique_ptr<Queue<FolderImagesPair>> image_name_queue_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_IMAGE_FOLDER_OP_H_
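
// Editor's aside (illustrative, not part of this patch): when no class_index_ map is given,
// labels follow the sorted folder order produced by the pre-scan. A self-contained sketch of
// that assignment rule (assumed behavior; folder names are made up):
#include <cstdint>
#include <map>
#include <set>
#include <string>

std::map<std::string, int32_t> AssignLabels(const std::set<std::string> &folders) {
  std::map<std::string, int32_t> labels;
  int32_t next = 0;
  for (const auto &f : folders) {
    labels[f] = next++;  // e.g. {"cat" -> 0, "dog" -> 1}; std::set iterates in sorted order
  }
  return labels;
}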
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc
new file mode 100644
index 0000000000..2b2542430b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.cc
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+
+#include <string>
+#include <vector>
+
+namespace mindspore {
+namespace dataset {
+// IOBlock Class //
+
+// Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key.
+IOBlock::IOBlock(int64_t inKey, IOBlockFlags io_block_flags) : index_keys_(1, inKey), io_block_flags_(io_block_flags) {}
+
+// Constructor of the IOBlock (2)
+IOBlock::IOBlock(const std::vector<int64_t> &in_keys, IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) {
+  index_keys_.insert(index_keys_.end(), in_keys.begin(), in_keys.end());
+}
+
+// Constructor of the IOBlock (3). A special IOBlock that is used for control messaging.
+IOBlock::IOBlock(IOBlockFlags io_block_flags) : io_block_flags_(io_block_flags) {}
+
+// Fetches the first key from this block
+Status IOBlock::GetKey(int64_t *out_key) const {
+  if (out_key == nullptr || index_keys_.empty()) {
+    RETURN_STATUS_UNEXPECTED("Failed to get the key from IOBlock");
+  }
+  *out_key = index_keys_[0];
+  return Status::OK();
+}
+
+// Fetches the list of keys from this block.
+Status IOBlock::GetKeys(std::vector<int64_t> *out_keys) const {
+  if (out_keys == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Output arg for GetKeys is null");
+  }
+  *out_keys = index_keys_;  // vector copy assign
+  return Status::OK();
+}
+
+// FilenameBlock derived class //
+
+// Constructor of the FilenameBlock (1)
+FilenameBlock::FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags)
+    : IOBlock(key, io_block_flags), start_offset_(start_offset), end_offset_(end_offset) {}
+
+// Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging.
+FilenameBlock::FilenameBlock(IOBlockFlags io_block_flags)
+    : IOBlock(io_block_flags), start_offset_(kInvalidOffset), end_offset_(kInvalidOffset) {}
+
+// Gets the filename from the block using the provided index container
+Status FilenameBlock::GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const {
+  if (out_filename == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Failed to get filename from FilenameBlock");
+  }
+
+  // A FilenameBlock only has one key. Call the base class method to fetch that key.
+  int64_t fetched_key;
+  RETURN_IF_NOT_OK(IOBlock::GetKey(&fetched_key));
+
+  // Do an index lookup using that key to get the filename.
+  auto r = index.Search(fetched_key);
+  if (r.second) {
+    auto &it = r.first;
+    *out_filename = it.value();
+  } else {
+    RETURN_STATUS_UNEXPECTED("Could not find filename from index");
+  }
+
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
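
// Editor's aside (illustrative, not part of this patch): how a master thread typically feeds
// these blocks to a worker queue, one payload block per batch of row keys followed by a
// control block. The queue variable is hypothetical:
//
//   std::vector<int64_t> keys = {0, 1, 2, 3};
//   auto work = std::make_unique<IOBlock>(keys, IOBlock::kDeIoBlockNone);  // payload
//   auto eoe = std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe);      // end of epoch
//   RETURN_IF_NOT_OK(io_block_queue->Add(std::move(work)));
//   RETURN_IF_NOT_OK(io_block_queue->Add(std::move(eoe)));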
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h
new file mode 100644
index 0000000000..df26aa1fc1
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/io_block.h
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_
+
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/util/auto_index.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// The IOBlock class is used to describe a "unit of work" that a storage leaf operator worker thread
+// is responsible for acting on.
+// IOBlock and its derived classes abstract a key-store and key-lookup interface where each
+// block contains 1 to n keys, and the keys are used in conjunction with an index to provide the meta
+// information for satisfying an IO request.
+class IOBlock {
+ public:
+  enum IOBlockFlags : uint32_t {
+    kDeIoBlockNone = 0,
+    kDeIoBlockFlagEoe = 1u,      // end of IOBlocks for one epoch
+    kDeIoBlockFlagEof = 1u << 1  // end of IOBlocks for the entire program
+  };
+
+  // Constructor of the IOBlock (1). A simpler one for the case when the block only has 1 key.
+  // @param inKey - A single key to add into the block
+  // @param io_block_flags - The flag setting for the block
+  IOBlock(int64_t inKey, IOBlockFlags io_block_flags);
+
+  // Constructor of the IOBlock (2).
+  // @param in_keys - A vector of keys to add into the block
+  // @param io_block_flags - The flag setting for the block
+  IOBlock(const std::vector<int64_t> &in_keys, IOBlockFlags io_block_flags);
+
+  // Constructor of the IOBlock (3). A special IOBlock that is used for control messaging.
+  // @param io_block_flags - The flag setting for the block
+  explicit IOBlock(IOBlockFlags io_block_flags);
+
+  // Destructor
+  virtual ~IOBlock() = default;
+
+  // Fetches the first key from the block.
+  // @note Only useful if you know the block only has 1 key.
+  // @param out_key - A copy of the first key from the block
+  // @return Status - The error code returned
+  Status GetKey(int64_t *out_key) const;
+
+  // Fetches the list of keys from this block.
+  // @param out_keys - A copy of the vector of keys from the block.
+  // @return Status - The error code returned
+  Status GetKeys(std::vector<int64_t> *out_keys) const;
+
+  // Does this block have the eoe flag turned on?
+  // @return T/F if the IOBlock is eoe
+  bool eoe() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagEoe); }
+
+  // Does this block have the eof flag turned on?
+  // @return T/F if the IOBlock is eof
+  bool eof() const { return static_cast<uint32_t>(io_block_flags_) & static_cast<uint32_t>(kDeIoBlockFlagEof); }
+
+  // Adds a key to this block
+  // @param key - The key to add to this block
+  void AddKey(int64_t key) { index_keys_.push_back(key); }
+
+ protected:
+  std::vector<int64_t> index_keys_;  // keys used for lookups to the meta info for the data
+  IOBlockFlags io_block_flags_;
+};  // class IOBlock
+
+const int64_t kInvalidOffset = -1;
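
// Editor's aside (illustrative, not part of this patch): eoe()/eof() above are plain bit
// tests on the flags enum. A self-contained sketch of the same test:
#include <cstdint>

enum IOFlags : uint32_t { kNone = 0, kEoe = 1u, kEof = 1u << 1 };

inline bool HasFlag(uint32_t flags, IOFlags f) { return (flags & static_cast<uint32_t>(f)) != 0; }
// HasFlag(kEoe | kEof, kEof) == true; HasFlag(kNone, kEoe) == false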
+
+// The FilenameBlock derived class implements a style of IO block where each block contains only a
+// single key that maps to a filename.
+class FilenameBlock : public IOBlock {
+ public:
+  // Constructor of the FilenameBlock (1)
+  // @param key - The key identifier that can be used to find the data for this block
+  // @param start_offset - Start offset
+  // @param end_offset - End offset
+  // @param io_block_flags - The flag setting for the block
+  FilenameBlock(int64_t key, int64_t start_offset, int64_t end_offset, IOBlockFlags io_block_flags);
+
+  // Constructor of the FilenameBlock (2). A special IOBlock that is used for control messaging.
+  // @param io_block_flags - The flag setting for the block
+  explicit FilenameBlock(IOBlockFlags io_block_flags);
+
+  // Destructor
+  ~FilenameBlock() = default;
+
+  // Gets the filename from the block using the provided index container
+  // @param out_filename - The filename fetched for this block
+  // @param index - The index to perform the lookup against
+  // @return Status - The error code returned
+  Status GetFilename(std::string *out_filename, const AutoIndexObj<std::string> &index) const;
+
+  // Get the start offset of the file
+  // @return int64_t - Start offset
+  int64_t GetStartOffset() const { return start_offset_; }
+
+  // Get the end offset of the file
+  // @return int64_t - End offset
+  int64_t GetEndOffset() const { return end_offset_; }
+
+ private:
+  int64_t start_offset_;
+  int64_t end_offset_;
+};  // class FilenameBlock
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_IO_BLOCK_H_
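
// Editor's aside (illustrative, not part of this patch): pairing a FilenameBlock with the
// AutoIndexObj lookup used by GetFilename(). The insertion call is an assumed API; only
// Search() is visible in this patch:
//
//   AutoIndexObj<std::string> index;                  // key -> filename container
//   (void)index.insert("part-00000.mindrecord");      // assumed to auto-assign key 0
//   FilenameBlock blk(/*key=*/0, /*start_offset=*/0, /*end_offset=*/kInvalidOffset,
//                     IOBlock::kDeIoBlockNone);
//   std::string fname;
//   Status rc = blk.GetFilename(&fname, index);       // fname == "part-00000.mindrecord"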
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
new file mode 100644
index 0000000000..0476baf56f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.cc
@@ -0,0 +1,438 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
+
+#include <algorithm>
+#include <fstream>
+#include <iomanip>
+#include <map>
+
+#include "common/utils.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_num_workers_ = cfg->num_parallel_workers();
+  builder_rows_per_buffer_ = cfg->rows_per_buffer();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  if (builder_sampler_ == nullptr) {
+    const int64_t num_samples = 0;
+    const int64_t start_index = 0;
+    builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples);
+  }
+  builder_schema_ = std::make_unique<DataSchema>();
+  RETURN_IF_NOT_OK(
+    builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+  RETURN_IF_NOT_OK(
+    builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
+  *ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_,
+                                      builder_op_connector_size_, builder_decode_, builder_labels_to_read_,
+                                      std::move(builder_schema_), std::move(builder_sampler_), builder_usage_);
+  return Status::OK();
+}
+
+Status ManifestOp::Builder::SanityCheck() {
+  std::string err_msg;
+  err_msg += builder_file_.empty() ? "Manifest file is not set\n" : "";
+  err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller than 1\n" : "";
+  return err_msg.empty() ?
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, + const std::map &class_index, std::unique_ptr data_schema, + std::shared_ptr sampler, std::string usage) + : ParallelOp(num_works, queue_size, std::move(sampler)), + rows_per_buffer_(rows_per_buffer), + io_block_pushed_(0), + row_cnt_(0), + sampler_ind_(0), + data_schema_(std::move(data_schema)), + file_(file), + class_index_(class_index), + decode_(decode), + usage_(usage), + buf_cnt_(0) { + io_block_queues_.Init(num_workers_, queue_size); + (void)std::transform(usage_.begin(), usage_.end(), usage_.begin(), ::tolower); +} + +// Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work +Status ManifestOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + return AddIoBlock(&sampler_buffer); +} + +Status ManifestOp::AddIoBlock(std::unique_ptr *sampler_buffer) { + while (true) { // each iterator is 1 epoch + std::vector keys; + keys.reserve(rows_per_buffer_); + while (!(*sampler_buffer)->eoe()) { + TensorRow sample_row; + RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row)); + std::shared_ptr sample_ids = sample_row[0]; + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) >= num_rows_) continue; // index out of bound, skipping + keys.push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + keys.clear(); + } + } + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEof))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); // Master thread goes to sleep after it has made all the IOBlocks + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(sampler_buffer)); + } + } +} + +Status ManifestOp::LaunchThreadsAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks())); + + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&ManifestOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(ParseManifestFile()); + RETURN_IF_NOT_OK(CountDatasetInfo()); + RETURN_IF_NOT_OK(InitSampler()); + return Status::OK(); +} + +// contains the main logic of pulling a IOBlock from IOBlockQueue, load a buffer and push the buffer to out_connector_ +// IMPORTANT: 1 IOBlock produces 1 DataBuffer +Status 
ManifestOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOF))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty()) { + return Status::OK(); // empty key is a quit signal for workers + } + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 ImageLabelPair. 1 function call produces 1 TensorTow in a DataBuffer +Status ManifestOp::LoadTensorRow(row_id_type row_id, const std::pair> &data, + TensorRow *trow) { + std::shared_ptr image; + std::shared_ptr label; + std::vector label_index(data.second.size()); + (void)std::transform(data.second.begin(), data.second.end(), label_index.begin(), + [this](const std::string &label_name) { return label_index_[label_name]; }); + if (label_index.size() == 1) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), TensorShape({}), + data_schema_->column(1).type(), + reinterpret_cast(&label_index[0]))); + } else { + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &label, data_schema_->column(1).tensorImpl(), TensorShape(std::vector(1, label_index.size())), + data_schema_->column(1).type(), reinterpret_cast(&label_index[0]))); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data.first)); + if (decode_ == true) { + Status rc = Decode(image, &image); + if (rc.IsError()) { + std::string err = "Fail to decode image:" + data.first; + RETURN_STATUS_UNEXPECTED(err); + } + } + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 
1 function call produces 1 buffer
+Status ManifestOp::LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db) {
+  std::unique_ptr<TensorQTable> deq = std::make_unique<TensorQTable>();
+  for (const auto &key : keys) {
+    TensorRow trow;
+    RETURN_IF_NOT_OK(LoadTensorRow(key, image_labelname_[static_cast<size_t>(key)], &trow));
+    deq->push_back(std::move(trow));
+  }
+  (*db)->set_tensor_table(std::move(deq));
+  return Status::OK();
+}
+
+void ManifestOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as the first line, regardless of whether this is a summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") <ManifestOp>:";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    ParallelOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nNumber of rows:" << num_rows_ << "\nManifest file: " << file_ << "\n\n";
+  }
+}
+
+// Reset Sampler and wake up the Master thread (functor)
+Status ManifestOp::Reset() {
+  RETURN_IF_NOT_OK(sampler_->ResetSampler());
+  row_cnt_ = 0;
+  wp_.Set();  // wake up master thread after reset is done
+  return Status::OK();
+}
+
+// Handshake with the Sampler; allows the Sampler to call RandomAccessOp's functions to get NumRows
+Status ManifestOp::InitSampler() {
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
+  return Status::OK();
+}
+
+// Derived from RandomAccessOp
+Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
+  if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) {
+    RETURN_STATUS_UNEXPECTED("Class indexing is invalid.");
+  }
+
+  for (size_t i = 0; i < image_labelname_.size(); i++) {
+    size_t image_index = i;
+    for (size_t j = 0; j < image_labelname_[image_index].second.size(); j++) {
+      std::string label_name = (image_labelname_[image_index].second)[j];
+      int32_t label_index = label_index_.at(label_name);
+      (*cls_ids)[label_index].emplace_back(image_index);
+    }
+  }
+
+  for (auto &pair : (*cls_ids)) {
+    pair.second.shrink_to_fit();
+  }
+  return Status::OK();
+}
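
// Editor's aside (illustrative, not part of this patch): ParseManifestFile() below reads one
// JSON object per line with nlohmann::json, which this file already depends on. A self-contained
// sketch of parsing a single line; the sample content is made up:
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

int main() {
  std::string line = R"({"source": "/path/img.jpg", "usage": "train", "annotation": [{"name": "cat"}]})";
  nlohmann::json js = nlohmann::json::parse(line);
  std::string source = js.value("source", "");  // "" when the key is missing
  std::string usage = js.value("usage", "");
  for (const auto &anno : js.at("annotation")) {
    std::cout << source << " [" << usage << "] label=" << anno.value("name", "") << "\n";
  }
  return 0;
}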
+
+// Manifest file content, one JSON object per line:
+// {"source": "/path/to/image1.jpg", "usage": "train", "annotation": ...}
+// {"source": "/path/to/image2.jpg", "usage": "eval", "annotation": ...}
+Status ManifestOp::ParseManifestFile() {
+  std::ifstream file_handle(file_);
+  if (!file_handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Manifest file " + file_ + " cannot be opened.");
+  }
+  std::string line;
+  while (getline(file_handle, line)) {
+    try {
+      nlohmann::json js = nlohmann::json::parse(line);
+      std::string image_file_path = js.value("source", "");
+      // If the image is not JPEG/PNG/GIF/BMP, drop it
+      bool valid = false;
+      RETURN_IF_NOT_OK(CheckImageType(image_file_path, &valid));
+      if (!valid) {
+        continue;
+      }
+      std::string usage = js.value("usage", "");
+      (void)std::transform(usage.begin(), usage.end(), usage.begin(), ::tolower);
+      if (usage != usage_) {
+        continue;
+      }
+      std::vector<std::string> labels;
+      nlohmann::json annotations = js.at("annotation");
+      for (nlohmann::json::iterator it = annotations.begin(); it != annotations.end(); ++it) {
+        nlohmann::json annotation = it.value();
+        std::string label_name = annotation.value("name", "");
+        if (label_name.empty()) {
+          file_handle.close();
+          RETURN_STATUS_UNEXPECTED("Label name is not found in manifest file for " + image_file_path);
+        }
+        if (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) {
+          if (label_index_.find(label_name) == label_index_.end()) {
+            label_index_[label_name] = 0;
+          }
+          labels.emplace_back(label_name);
+        }
+      }
+      if (!labels.empty()) {
+        image_labelname_.emplace_back(std::make_pair(image_file_path, labels));
+      }
+    } catch (const std::exception &err) {
+      file_handle.close();
+      RETURN_STATUS_UNEXPECTED("Parse manifest file failed");
+    }
+  }
+  file_handle.close();
+
+  return Status::OK();
+}
+
+// Only JPEG/PNG/GIF/BMP are supported
+Status ManifestOp::CheckImageType(const std::string &file_name, bool *valid) {
+  std::ifstream file_handle;
+  constexpr int read_num = 3;
+  *valid = false;
+  file_handle.open(file_name, std::ios::binary | std::ios::in);
+  if (!file_handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Cannot open image file " + file_name);
+  }
+  unsigned char file_type[read_num];
+  (void)file_handle.read(reinterpret_cast<char *>(file_type), read_num);
+
+  if (file_handle.fail()) {
+    file_handle.close();
+    RETURN_STATUS_UNEXPECTED("Read image file failed " + file_name);
+  }
+  file_handle.close();
+  if (file_type[0] == 0xff && file_type[1] == 0xd8 && file_type[2] == 0xff) {
+    // Normal JPEGs start with \xff\xd8\xff\xe0;
+    // JPEGs with EXIF start with \xff\xd8\xff\xe1.
+    // Use \xff\xd8\xff to cover both.
+    *valid = true;
+  } else if (file_type[0] == 0x89 && file_type[1] == 0x50 && file_type[2] == 0x4e) {
+    // It's a PNG
+    *valid = true;
+  } else if (file_type[0] == 0x47 && file_type[1] == 0x49 && file_type[2] == 0x46) {
+    // It's a GIF
+    *valid = true;
+  } else if (file_type[0] == 0x42 && file_type[1] == 0x4d) {
+    // It's a BMP
+    *valid = true;
+  }
+  return Status::OK();
+}
+
+Status ManifestOp::CountDatasetInfo() {
+  int32_t index = 0;
+  for (auto &label : label_index_) {
+    label.second = class_index_.empty() ? index : class_index_[label.first];
+    index++;
+  }
+
+  num_rows_ = static_cast<int64_t>(image_labelname_.size());
+  if (num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED(
+      "There is no valid data matching the dataset API ManifestDataset. Please check the file path or dataset API "
+      "validation first.");
+  }
+  return Status::OK();
+}
+
+Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage,
+                                  int64_t *count, int64_t *numClasses) {
+  // the logic of counting the number of samples is copied from ParseManifestFile()
+  std::map<std::string, int32_t> map;
+  for (auto p : dict) {
+    (void)map.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
+                                                     py::reinterpret_borrow<py::int_>(p.second)));
+  }
+
+  std::shared_ptr<ManifestOp> op;
+  *count = 0;
+  RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op));
+  RETURN_IF_NOT_OK(op->ParseManifestFile());
+  *numClasses = static_cast<int64_t>(op->label_index_.size());
+  *count = static_cast<int64_t>(op->image_labelname_.size());
+  return Status::OK();
+}
+
+Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
+                                    std::map<std::string, int32_t> *output_class_indexing) {
+  std::map<std::string, int32_t> input_class_indexing;
+  for (auto p : dict) {
+    (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first),
+                                                                      py::reinterpret_borrow<py::int_>(p.second)));
+  }
+
+  if (!input_class_indexing.empty()) {
+    *output_class_indexing = input_class_indexing;
+  } else {
+    std::shared_ptr<ManifestOp> op;
+    RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseManifestFile());
+    RETURN_IF_NOT_OK(op->CountDatasetInfo());
+    uint32_t count = 0;
+    for (const auto &label : op->label_index_) {
(*output_class_indexing).insert(std::make_pair(label.first, count)); + count++; + } + } + + return Status::OK(); +} + +// Visitor accept method for NodePass +Status ManifestOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status ManifestOp::ComputeColMap() { + // Set the column name map (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h new file mode 100644 index 0000000000..bac8f04c94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/manifest_op.h @@ -0,0 +1,250 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +class ManifestOp : public ParallelOp, public RandomAccessOp { + public: + class Builder { + public: + // Constructor for Builder class of ManifestOp + Builder(); + + // Destructor + ~Builder() = default; + + // Setter method + // @param int32_t rows_per_buffer + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method + // @param int32_t size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t size) { + builder_op_connector_size_ = size; + return *this; + } + + // Setter method + // @param const std::map& map - a class name to label map + // @return + Builder &SetClassIndex(const std::map &map) { + builder_labels_to_read_ = map; + return *this; + } + + // Setter method + // @param bool do_decode + // @return Builder setter method returns reference to the builder. 
+    Builder &SetDecode(bool do_decode) {
+      builder_decode_ = do_decode;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string &file
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetManifestFile(const std::string &file) {
+      builder_file_ = file;
+      return *this;
+    }
+
+    // Setter method
+    // @param const std::string &usage
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetUsage(const std::string &usage) {
+      builder_usage_ = usage;
+      return *this;
+    }
+
+    // Check validity of input args
+    // @return Status - The error code returned
+    Status SanityCheck();
+
+    // The builder "build" method creates the final object.
+    // @param std::shared_ptr<ManifestOp> *op - DatasetOp
+    // @return Status - The error code returned
+    Status Build(std::shared_ptr<ManifestOp> *op);
+
+   private:
+    std::shared_ptr<Sampler> builder_sampler_;
+    bool builder_decode_;
+
+    std::string builder_file_;
+    int32_t builder_num_workers_;
+    int32_t builder_rows_per_buffer_;
+    int32_t builder_op_connector_size_;
+    std::unique_ptr<DataSchema> builder_schema_;
+    std::string builder_usage_;
+    std::map<std::string, int32_t> builder_labels_to_read_;
+  };
+
+  // Constructor
+  // @param int32_t num_works - Number of workers reading images in parallel
+  // @param int32_t rows_per_buffer - Number of images (rows) in each buffer
+  // @param std::string file - the Manifest file list
+  // @param int32_t queue_size - connector queue size
+  // @param std::shared_ptr<Sampler> sampler - sampler tells ManifestOp what to read
+  ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode,
+             const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema,
+             std::shared_ptr<Sampler> sampler, std::string usage);
+
+  // Destructor.
+  ~ManifestOp() = default;
+
+  // Worker thread pulls a number of IOBlocks from the IOBlock Queue, makes a buffer and pushes it to the Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code returned
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Main Loop of ManifestOp
+  // Master thread: fills the IOBlockQueue, then goes to sleep
+  // Worker thread: pulls an IOBlock from the IOBlockQueue, works on it, then puts a buffer on the out connector
+  // @return Status - The error code returned
+  Status operator()() override;
+
+  // Method derived from RandomAccessOp; enables the Sampler to get all ids for each class
+  // @param std::map<int32_t, std::vector<int64_t>> *cls_ids - key is label, value is all row ids of this class
+  // @return Status - The error code returned
+  Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
+
+  static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count,
+                               int64_t *numClasses);
+
+  // Get str-to-int mapping from label name to index
+  static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
+                                 std::map<std::string, int32_t> *output_class_indexing);
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ManifestOp"; }
+
+ private:
+  // Initialize the Sampler; calls sampler->HandshakeRandomAccessOp() within
+  // @return Status - The error code returned
+  Status InitSampler();
+
+  // Method in operator(), to fill the IOBlockQueue
+  // @param std::unique_ptr<DataBuffer> *sampler_buffer - to fill the IOBlockQueue
+  // @return Status - The error code returned
+  Status AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer);
+
+  // Load a tensor row according to an <image name, labels> pair
+  // @param row_id_type row_id - id for this tensor row
+  // @param std::pair<std::string, std::vector<std::string>> data - <image name, labels>
+  // @param TensorRow row - image & label read into this tensor row
+  // @return Status - The error code returned
+  Status LoadTensorRow(row_id_type row_id, const std::pair<std::string, std::vector<std::string>> &data,
+                       TensorRow *row);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> db
+  // @return Status - The error code returned
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // Parse the manifest file to get the image paths, labels and so on.
+  // @return Status - The error code returned
+  Status ParseManifestFile();
+
+  // Called first when the functor is entered
+  // @return Status - The error code returned
+  Status LaunchThreadsAndInitOp();
+
+  // Reset Op
+  // @return Status - The error code returned
+  Status Reset() override;
+
+  // Check if the image is valid. Only JPEG/PNG/GIF/BMP are supported.
+  // @return Status - The error code returned
+  Status CheckImageType(const std::string &file_name, bool *valid);
+
+  // Count label index, num rows and num samples
+  // @return Status - The error code returned
+  Status CountDatasetInfo();
+
+  // Private function for computing the assignment of the column name map.
+ // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; + int64_t io_block_pushed_; + int64_t row_cnt_; + int64_t sampler_ind_; + std::unique_ptr data_schema_; + std::string file_; // file that store the information of images + std::map class_index_; + bool decode_; + std::string usage_; + int64_t buf_cnt_; + + WaitPost wp_; + QueueList> io_block_queues_; + std::map label_index_; + std::vector>> image_labelname_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MANIFEST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc new file mode 100644 index 0000000000..cf1493eb78 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -0,0 +1,513 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h" + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/engine/datasetops/dataset_op.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using mindrecord::kInt64Len; +using mindrecord::MSRStatus; +using mindrecord::Schema; +using mindrecord::ShardOperator; +using mindrecord::ShardReader; + +// Builder constructor. Creates the builder object. +MindRecordOp::Builder::Builder() : build_dataset_file_({}) { + // Some arguments to the MindRecordOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the MindRecordOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + build_num_mind_record_workers_ = kDefaultMindRecordWorkers; + build_rows_per_buffer_ = cfg->rows_per_buffer(); + build_op_connector_queue_size_ = cfg->op_connector_size(); + build_block_reader_ = false; + builder_num_workers_ = 0; + build_num_padded_ = 0; + build_sample_ = nullptr; +} + +// The builder "build" method creates the final object. 
+Status MindRecordOp::Builder::Build(std::shared_ptr *ptr) { + std::shared_ptr new_mind_record_op; + + if (build_dataset_file_.empty()) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "Building a MindRecordOp that has not provided a file."); + } + mindrecord::json sample_json; + if (build_num_padded_ > 0) { + sample_json = ToJson(build_sample_); + } + new_mind_record_op = std::make_shared( + build_num_mind_record_workers_, build_rows_per_buffer_, build_dataset_file_, build_load_dataset_, + build_op_connector_queue_size_, build_columns_to_load_, build_operators_, build_block_reader_, build_num_padded_, + sample_json, build_sample_bytes_); + + RETURN_IF_NOT_OK(new_mind_record_op->Init()); + *ptr = std::move(new_mind_record_op); + return Status::OK(); +} + +Status MindRecordOp::Builder::SanityCheck() const { return Status::OK(); } + +mindrecord::json MindRecordOp::Builder::ToJson(const py::handle &obj) { + if (obj.is_none()) { + return nullptr; + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { + return obj.cast(); + } + if (py::isinstance(obj)) { // also catch py::bytes + return obj.cast(); + } + if (py::isinstance(obj)) { + auto out = mindrecord::json::object(); + for (const py::handle &key : obj) { + if (py::isinstance(obj[key])) { + build_sample_bytes_[py::str(key).cast()] = obj[key].cast(); + } else { + out[py::str(key).cast()] = ToJson(obj[key]); + } + } + return out; + } + MS_LOG(ERROR) << "Python object convert to json failed, object is: " << py::cast(obj); + return mindrecord::json(); +} + +// Constructor of the MindRecordOp. +MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, + std::vector dataset_file, bool load_dataset, int32_t op_connector_queue_size, + const std::vector &columns_to_load, + const std::vector> &operators, const bool &block_reader, + int64_t num_padded, const mindrecord::json &sample_json, + const std::map &sample_bytes) + : ParallelOp(num_mind_record_workers, op_connector_queue_size), + rows_per_buffer_(rows_per_buffer), + dataset_file_(dataset_file), + load_dataset_(load_dataset), + columns_to_load_(columns_to_load), + operators_(operators), + num_mind_record_workers_(num_mind_record_workers), + block_reader_(block_reader), + num_rows_(0), + buffers_needed_(0), + buf_cnt_(0), + ended_worker_(0), + buffer_water_mark_(0), + num_padded_(num_padded), + sample_json_(sample_json), + sample_bytes_(sample_bytes) { + io_blk_queues_.Init(num_workers_, op_connector_queue_size); + if (!block_reader_) return; + for (int32_t i = 0; i < num_workers_; ++i) { + block_buffer_.emplace_back(std::make_unique>(std::vector{})); + } +} + +// Private helper method to encapsulate some common construction/reset tasks +Status MindRecordOp::Init() { + shard_reader_ = std::make_unique(); + auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, + block_reader_, num_padded_); + + CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, + "MindRecordOp init failed. 
Error message: " + ErrnoToMessage(rc)); + + data_schema_ = std::make_unique(); + + std::vector col_names = shard_reader_->GetShardColumn()->GetColumnName(); + CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); + std::vector col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType(); + std::vector> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape(); + + bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything + std::map colname_to_ind; + for (uint32_t i = 0; i < col_names.size(); i++) { + std::string colname = col_names[i]; + ColDescriptor col_desc; + + TensorShape t_shape = TensorShape::CreateUnknownRankShape(); // shape of tensor, default unknown + std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]]; + DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"} + + if (col_data_types[i] == mindrecord::ColumnBytes) { // rank = 1 + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1); + } else if (col_data_types[i] == mindrecord::ColumnString) { // rank = 0 + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 0); + } else if (col_shapes[i].size() > 0) { + std::vector vec(col_shapes[i].size()); // temporary vector to hold shape + (void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin()); + t_shape = TensorShape(vec); + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); + } else { // unknown shape + // create colDesc and add it to schema + col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, t_shape.Rank(), &t_shape); + } + + colname_to_ind[colname] = data_schema_->NumColumns(); + RETURN_IF_NOT_OK(data_schema_->AddColumn(col_desc)); + + if (load_all_cols) { + columns_to_load_.emplace_back(colname); + } + } + + if (!load_all_cols) { + std::unique_ptr tmp_schema = std::make_unique(); + for (std::string colname : columns_to_load_) { + CHECK_FAIL_RETURN_UNEXPECTED(colname_to_ind.find(colname) != colname_to_ind.end(), colname + ": doesn't exist"); + RETURN_IF_NOT_OK(tmp_schema->AddColumn(data_schema_->column(colname_to_ind[colname]))); + } + data_schema_ = std::move(tmp_schema); + } + + return Status::OK(); +} + +// Destructor +MindRecordOp::~MindRecordOp() {} + +// A print method typically used for debugging +void MindRecordOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\n Dataset file : "; + for (auto &file : dataset_file_) { + out << file << " "; + } + out << "\nNumber of rows : " << num_rows_ << "\nRows per buffer : " << rows_per_buffer_ + << "\nNumber of buffers : " << buffers_needed_ + << "\nNumber of ShardReader workers : " << num_mind_record_workers_ << "\n\n"; + } +} + +Status MindRecordOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe()) 
{ + RETURN_IF_NOT_OK( + out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + if (io_block->eof()) { + RETURN_IF_NOT_OK( + out_connector_->Add(worker_id, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + + // load data buffer + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) { + { + std::unique_lock lock(ended_worker_mutex_); + ended_worker_++; + if (ended_worker_ == num_workers_) shard_reader_->Close(); + } + return Status::OK(); // empty key is a quit signal for workers + } + + const uint64_t buffer_id = keys[0]; + std::unique_ptr fetched_buffer; + + // Get the next buffer. Push it up to the output connector. + if (buffer_id % LOG_INTERVAL == 0) { + MS_LOG(DEBUG) << "MindRecord operator consumed buffer " << buffer_id << " by worker " << worker_id << "."; + } + RETURN_IF_NOT_OK(GetBufferFromReader(&fetched_buffer, buffer_id, worker_id)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(fetched_buffer))); + if (!block_reader_) { + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + continue; + } + + // update block-reader buffer + block_buffer_[buffer_id % num_workers_]->clear(); + { + std::unique_lock lck(mtx_block_reader_); + if (buffer_id == buffer_water_mark_) { + buffer_water_mark_++; + while (block_set_.count(buffer_water_mark_) > 0) (void)block_set_.erase(buffer_water_mark_++); + } else { + (void)block_set_.insert(buffer_id); + } + } + cv_reader_.notify_one(); + RETURN_IF_NOT_OK(io_blk_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +Status MindRecordOp::GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, + int32_t worker_id) { + *fetched_buffer = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + for (int32_t i = 0; i < rows_per_buffer_; ++i) { + ShardTuple tupled_buffer; + mindrecord::TaskType task_type = mindrecord::TaskType::kCommonTask; + if (block_reader_) { + if (i >= block_buffer_[buffer_id % num_workers_]->size()) break; + tupled_buffer = block_buffer_[buffer_id % num_workers_]->at(i); + } else { + int32_t row_id = buffer_id * rows_per_buffer_ + i; + auto rc = shard_reader_->GetNextById(row_id, worker_id); + task_type = rc.first; + tupled_buffer = rc.second; + if (task_type == mindrecord::TaskType::kPaddedTask) { + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, {}, mindrecord::json(), task_type)); + tensor_table->push_back(std::move(tensor_row)); + } + if (tupled_buffer.empty()) break; + } + if (task_type == mindrecord::TaskType::kCommonTask) { + for (const auto &tupled_row : tupled_buffer) { + std::vector columns_blob = std::get<0>(tupled_row); + mindrecord::json columns_json = std::get<1>(tupled_row); + TensorRow tensor_row; + RETURN_IF_NOT_OK(LoadTensorRow(&tensor_row, columns_blob, columns_json, task_type)); + tensor_table->push_back(std::move(tensor_row)); + } + } + } + + // Replace the TensorTable in DataBuffer with the new one. 
+ (*fetched_buffer)->set_tensor_table(std::move(tensor_table)); + return Status::OK(); +} + +Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json, const mindrecord::TaskType task_type) { + for (uint32_t i_col = 0; i_col < columns_to_load_.size(); i_col++) { + auto column_name = columns_to_load_[i_col]; + + // Initialize column parameters + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + mindrecord::ColumnDataType column_data_type = mindrecord::ColumnNoDataType; + uint64_t column_data_type_size = 1; + std::vector column_shape; + + // Get column data + auto shard_column = shard_reader_->GetShardColumn(); + if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { + auto rc = + shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); + if (rc.first != MSRStatus::SUCCESS) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data type."); + } + if (rc.second == mindrecord::ColumnInRaw) { + auto has_column = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve raw data from padding sample."); + } + } else if (rc.second == mindrecord::ColumnInBlob) { + if (sample_bytes_.find(column_name) == sample_bytes_.end()) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve blob data from padding sample."); + } + std::string ss(sample_bytes_[column_name]); + n_bytes = ss.size(); + data_ptr = std::make_unique(n_bytes); + std::copy(ss.begin(), ss.end(), data_ptr.get()); + } else { + RETURN_STATUS_UNEXPECTED("Retrieved data type is unknown."); + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + } else { + auto has_column = + shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, + &column_data_type, &column_data_type_size, &column_shape); + if (has_column == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("Failed to retrieve data from mindrecord reader."); + } + } + + std::shared_ptr tensor; + const ColDescriptor &column = data_schema_->column(i_col); + DataType type = column.type(); + + // Set shape + auto num_elements = n_bytes / column_data_type_size; + if (type == DataType::DE_STRING) { + std::string s{data, data + n_bytes}; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {s}, TensorShape::CreateScalar())); + } else if (column.hasShape()) { + auto new_shape = TensorShape(column.shape()); + RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast(num_elements), &new_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + } else { + std::vector shapeDetails = {static_cast(num_elements)}; + auto new_shape = TensorShape(shapeDetails); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data)); + } + tensor_row->push_back(std::move(tensor)); + } + return Status::OK(); +} + +Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { + { + std::unique_lock lck(mtx_block_reader_); + cv_reader_.wait(lck, [buffer_id, this] { return buffer_id < buffer_water_mark_ + num_workers_; }); + } + for (int32_t i = 0; i < rows_per_buffer_; i++) { + // Block reader does NOT care about argument + auto rc = shard_reader_->GetNextById(i, i); + ShardTuple tuple_buffer = rc.second; + if (tuple_buffer.empty()) break; + block_buffer_[buffer_id % 
num_workers_]->push_back(std::move(tuple_buffer));
+  }
+  return Status::OK();
+}
+
+// Class functor operator () override.
+// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+// provide the master loop that drives the logic for performing the work.
+// Main logic: register the queues with the TaskGroup, launch all threads, and do the functor's work
+Status MindRecordOp::operator()() {
+  RETURN_IF_NOT_OK(LaunchThreadAndInitOp());
+  num_rows_ = shard_reader_->GetNumRows();
+  // Compute how many buffers we need to cover num_rows_ at rows_per_buffer_ rows each (ceiling division)
+  buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_;
+
+  while (true) {  // each iteration is one epoch
+    for (int32_t i = 0; i < buffers_needed_; ++i) {
+      if (block_reader_) RETURN_IF_NOT_OK(FetchBlockBuffer(i));
+      std::vector<int64_t> keys(1, i);
+      RETURN_IF_NOT_OK(io_blk_queues_[buf_cnt_++ % num_workers_]->Add(
+        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
+    }
+    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
+      RETURN_IF_NOT_OK(
+        io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      RETURN_IF_NOT_OK(
+        io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
+      for (int32_t i = 0; i < num_workers_; i++) {
+        RETURN_IF_NOT_OK(io_blk_queues_[i]->Add(
+          std::move(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone))));
+      }
+      return Status::OK();
+    } else {  // not the last repeat. Acquire the lock, put the master thread to sleep, and wait for the wake-up from Reset
+      RETURN_IF_NOT_OK(
+        io_blk_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+
+      // reset our buffer count and go to loop again.
+      RETURN_IF_NOT_OK(shard_reader_wait_post_.Wait());
+      shard_reader_wait_post_.Clear();
+    }
+  }
+}
+
+// Overrides base class reset method. When an operator does a reset, it cleans up any state
+// info from its previous execution and then initializes itself so that it can be executed
+// again.
+Status MindRecordOp::Reset() {
+  RETURN_IF_NOT_OK(ParallelOp::Reset());  // Call our super class reset first.
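+
+  // The buffer math in operator() above is a plain ceiling division. A worked example
+  // with hypothetical numbers: num_rows_ = 100 and rows_per_buffer_ = 32 gives
+  //   buffers_needed_ = (100 + 32 - 1) / 32 = 131 / 32 = 4,
+  // i.e. three full buffers of 32 rows plus one final buffer holding the remaining 4 rows.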
+ + if (block_reader_) { + shard_reader_->Reset(); + buffer_water_mark_ = 0; + } else { + shard_reader_->ShuffleTask(); + } + shard_reader_wait_post_.Set(); + + return Status::OK(); +} + +Status MindRecordOp::LaunchThreadAndInitOp() { + if (tree_ == nullptr) { + RETURN_STATUS_UNEXPECTED("tree_ not set"); + } + + RETURN_IF_NOT_OK(io_blk_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(shard_reader_wait_post_.Register(tree_->AllTasks())); + if (shard_reader_->Launch(!block_reader_) == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); + } + // Launch main workers that load DataBuffers by reading all images + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1))); + TaskManager::FindMe()->Post(); + return Status::OK(); +} + +Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count, int64_t num_padded) { + std::unique_ptr shard_reader = std::make_unique(); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); + if (rc == MSRStatus::FAILED) { + RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); + } + return Status::OK(); +} + +// Visitor accept method for NodePass +Status MindRecordOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status MindRecordOp::ComputeColMap() { + if (column_name_id_map_.empty()) { + for (int i = 0; i < static_cast(columns_to_load_.size()); i++) { + column_name_id_map_[columns_to_load_[i]] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h new file mode 100644 index 0000000000..367505b172 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.h @@ -0,0 +1,276 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/dataset/util/wait_post.h" + +namespace mindspore { +namespace dataset { +// Forward declares +template +class Queue; +class DataBuffer; + +using mindrecord::ShardOperator; +using mindrecord::ShardReader; +using ShardTuple = std::vector, mindrecord::json>>; // Row of data from ShardReader + +const int32_t LOG_INTERVAL = 19; + +class MindRecordOp : public ParallelOp { + public: + // The nested builder class inside of the MindRecordOp is used to help manage all of the arguments + // for constructing it. Use the builder by setting each argument with the provided set methods, + // and then finally call the build method to execute the actual construction. + class Builder { + public: + Builder(); + + ~Builder() = default; + + Status Build(std::shared_ptr *); + + Builder &SetRowsPerBuffer(int rows_per_buffer) { + build_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + Builder &SetNumMindRecordWorkers(int32_t num_mind_record_workers) { + build_num_mind_record_workers_ = num_mind_record_workers; + return *this; + } + + Builder &SetOpConnectorQueueSize(int32_t queue_size) { + build_op_connector_queue_size_ = queue_size; + return *this; + } + + Builder &SetDatasetFile(const std::vector &files) { + build_dataset_file_ = files; + return *this; + } + + Builder &SetColumnsToLoad(const std::vector &columns) { + build_columns_to_load_ = columns; + return *this; + } + + Builder &SetOperators(const std::vector> &operators) { + build_operators_ = operators; + return *this; + } + + Builder &SetBlockReader() { + build_block_reader_ = true; + return *this; + } + + Builder &SetLoadDataset(bool load_dataset) { + build_load_dataset_ = load_dataset; + return *this; + } + + Builder &SetNumToPadSamples(int64_t num_padded) { + build_num_padded_ = num_padded; + return *this; + } + + Builder &SetPaddedSample(const py::handle &sample) { + build_sample_ = sample; + return *this; + } + + Status SanityCheck() const; + + static int32_t num_mind_record_workers() { return kDefaultMindRecordWorkers; } + + mindrecord::json ToJson(const py::handle &obj); + + private: + static constexpr int32_t kDefaultMindRecordWorkers = 4; + // The builder saves all MindRecordOp construction arguments internally. + // The following are the arguments. + int32_t build_num_mind_record_workers_; + int32_t builder_num_workers_; + int32_t build_rows_per_buffer_; + int32_t build_op_connector_queue_size_; + std::vector build_dataset_file_; + bool build_load_dataset_; + std::vector build_columns_to_load_; + std::vector> build_operators_; + bool build_block_reader_; + int64_t build_num_padded_; + py::handle build_sample_; + std::map build_sample_bytes_; + }; + + // Constructor of the MindRecordOp. 
+ // @note The builder class should be used to call it + // @param num_mind_record_workers - The number of workers for the op (run by ShardReader) + // @param rows_per_buffer - The requested number of rows per buffer + // @param dataset_file - dataset files + // @param op_connector_queue_size - The output connector queue size + // @param columns_to_load - The list of columns to use (column name) + // @param operators - ShardOperators for Shuffle, Category, Sample + MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buffer, std::vector dataset_file, + bool load_dataset, int32_t op_connector_queue_size, const std::vector &columns_to_load, + const std::vector> &operators, const bool &block_reader, + int64_t num_padded_, const mindrecord::json &sample_json, + const std::map &sample_bytes_); + + // Destructor + ~MindRecordOp() override; + + // A print method typically used for debugging + // @param out - The output stream to write output to + // @param show_all - A bool to control if you want to show all info or just a summary + void Print(std::ostream &out, bool show_all) const override; + + // << Stream output operator overload + // @notes This allows you to write the debug print info using stream operators + // @param out - reference to the output stream being overloaded + // @param op - reference to the MindRecordOp to display + // @return - the output stream must be returned + friend std::ostream &operator<<(std::ostream &out, const MindRecordOp &op) { + op.Print(out, false); + return out; + } + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t workerId - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Class functor operator () override. + // All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status - The error code return + Status operator()() override; + + // Called first when function is called + // @return + Status LaunchThreadAndInitOp(); + + // Overrides base class reset method. When an operator does a reset, it cleans up any state + // info from it's previous execution and then initializes itself so that it can be executed + // again. + // @return Status - The error code return + Status Reset() override; + + // Getter method + int32_t num_rows() const { return num_rows_; } + + static Status CountTotalRows(const std::vector dataset_path, bool load_dataset, + const std::shared_ptr &op, int64_t *count, int64_t num_padded); + + // Getter method + int32_t rows_per_buffer() const { return rows_per_buffer_; } + + // Getter method + std::vector dataset_file() const { return dataset_file_; } + + // Getter method + std::vector columns_to_load() const { return columns_to_load_; } + + bool block_reader() const { return block_reader_; } + + bool load_dataset() const { return load_dataset_; } + + Status Init(); + + // Base-class override for NodePass visitor acceptor. + // @param p - Pointer to the NodePass to be accepted. + // @param modified - Whether this node visit modified the pipeline. + // @return - Status of the node visit. 
+ Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "MindRecordOp"; } + + private: + Status GetBufferFromReader(std::unique_ptr *fetched_buffer, int64_t buffer_id, int32_t worker_id); + + // Parses a single cell and puts the data into a tensor + // @param tensor_row - the tensor row to put the parsed data in + // @param columns_blob - the blob data received from the reader + // @param columns_json - the data for fields received from the reader + Status LoadTensorRow(TensorRow *tensor_row, const std::vector &columns_blob, + const mindrecord::json &columns_json, const mindrecord::TaskType task_type); + + Status FetchBlockBuffer(const int32_t &buffer_id); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t rows_per_buffer_; // The number of requested rows per buffer. + std::vector dataset_file_; // dataset files + bool load_dataset_; // load dataset from single file or not + std::vector columns_to_load_; // Columns to load from dataset + std::vector> operators_; // ShardOperators to use + int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader + bool block_reader_; // block reader switch + int32_t buffers_needed_; // Counter for the buffers that were fetched + int64_t buf_cnt_; // Buffer counter + int32_t num_rows_; // One more than the last row id in the range for this cache + std::atomic ended_worker_; + std::atomic buffer_water_mark_; + + int64_t num_padded_; + mindrecord::json sample_json_; + std::map sample_bytes_; + + std::unique_ptr data_schema_; // Data schema for column typing + std::vector columns_blob_; // Blob Columns to load from dataset + std::vector columns_blob_index_; // Blob Columns to load from dataset + + std::unique_ptr shard_reader_; + WaitPost shard_reader_wait_post_; + QueueList> io_blk_queues_; + + // For block reader + std::mutex mtx_block_reader_; + std::condition_variable cv_reader_; + std::vector>> block_buffer_; + std::unordered_set block_set_; + + std::mutex ended_worker_mutex_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MINDRECORD_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc new file mode 100644 index 0000000000..11ad18865e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.cc @@ -0,0 +1,450 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/source/mnist_op.h" + +#include +#include +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +const int32_t kMnistImageFileMagicNumber = 2051; +const int32_t kMnistLabelFileMagicNumber = 2049; +const int32_t kMnistImageRows = 28; +const int32_t kMnistImageCols = 28; + +MnistOp::Builder::Builder() : builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status MnistOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1))); + TensorShape scalar = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); + *ptr = std::make_shared(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, + builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status MnistOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
+}
+
+MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size,
+                 std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
+      buf_cnt_(0),
+      row_cnt_(0),
+      folder_path_(folder_path),
+      rows_per_buffer_(rows_per_buffer),
+      data_schema_(std::move(data_schema)) {
+  io_block_queues_.Init(num_workers, queue_size);
+}
+
+Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) {
+  for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) {
+    if ((*itr) >= num_rows_) continue;  // index out of bound, skipping
+    keys->push_back(*itr);
+    row_cnt_++;
+    if (row_cnt_ % rows_per_buffer_ == 0) {
+      RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add(
+        std::make_unique<IOBlock>(IOBlock(*keys, IOBlock::kDeIoBlockNone))));
+      keys->clear();
+    }
+  }
+  return Status::OK();
+}
+
+// functor that contains the main logic of MNIST op
+Status MnistOp::operator()() {
+  RETURN_IF_NOT_OK(LaunchThreadsAndInitOp());
+  std::unique_ptr<DataBuffer> sampler_buffer;
+  RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
+  while (true) {  // each iteration is one epoch
+    std::vector<int64_t> keys;
+    keys.reserve(rows_per_buffer_);
+    while (sampler_buffer->eoe() == false) {
+      std::shared_ptr<Tensor> sample_ids;
+      RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0));
+      if (sample_ids->type() != DataType(DataType::DE_INT64)) {
+        RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't INT64");
+      }
+      RETURN_IF_NOT_OK(TraversalSampleIds(sample_ids, &keys));
+      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
+    }
+    if (keys.empty() == false) {
+      RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(
+        std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone))));
+    }
+    if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
+      for (int32_t i = 0; i < num_workers_; ++i) {
+        RETURN_IF_NOT_OK(
+          io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
+      }
+      return Status::OK();
+    } else {
+      RETURN_IF_NOT_OK(
+        io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
+      RETURN_IF_NOT_OK(wp_.Wait());  // Master thread goes to sleep after it has made all the IOBlocks
+      wp_.Clear();
+      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
+    }
+  }
+}
+
+// contains the logic of pulling an IOBlock from the IOBlockQueue, loading a buffer, and pushing it to out_connector_
+Status MnistOp::WorkerEntry(int32_t worker_id) {
+  TaskManager::FindMe()->Post();
+  int64_t buffer_id = worker_id;
+  std::unique_ptr<IOBlock> iOBlock;
+  RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock));
+  while (iOBlock != nullptr) {
+    if (iOBlock->eoe() == true) {
+      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
+      buffer_id = worker_id;
+    } else if (iOBlock->eof() == true) {
+      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
+    } else {
+      std::vector<int64_t> keys;
+      RETURN_IF_NOT_OK(iOBlock->GetKeys(&keys));
+      if (keys.empty() == true) return Status::OK();  // empty key is a quit signal for workers
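+      // The IOBlock stream is a small control protocol between the master loop in
+      // operator() and the workers; the cases handled above are (key values illustrative):
+      //   eoe block         -> forward an EOE DataBuffer downstream and keep running
+      //   eof block         -> forward an EOF DataBuffer downstream and keep running
+      //   empty key list    -> quit signal, the worker returns Status::OK()
+      //   keys {4, 5, 6, 7} -> load those rows into one DataBuffer and push it (below)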
std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&iOBlock)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +// Load 1 TensorRow (image,label) using 1 MnistLabelPair. +Status MnistOp::LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *trow) { + std::shared_ptr image, label; + int32_t l = mnist_pair.second; + // make a copy of cached tensor + RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), mnist_pair.first->shape(), + mnist_pair.first->type(), mnist_pair.first->GetBuffer())); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&label, data_schema_->column(1).tensorImpl(), data_schema_->column(1).shape(), + data_schema_->column(1).type(), reinterpret_cast(&l))); + (*trow) = TensorRow(row_id, {std::move(image), std::move(label)}); + return Status::OK(); +} + +// Looping over LoadTensorRow to make 1 DataBuffer. 1 function call produces 1 buffer +Status MnistOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const int64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_label_pairs_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +void MnistOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows:" << num_rows_ << "\nMNIST Directory: " << folder_path_ << "\n\n"; + } +} + +// Reset Sampler and wakeup Master thread (functor) +Status MnistOp::Reset() { + RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); // wake up master thread after reset is done + return Status::OK(); +} + +// hand shake with Sampler, allow Sampler to call RandomAccessOp's functions to get NumRows +Status MnistOp::InitSampler() { + RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this)); + return Status::OK(); +} + +// Derived from RandomAccessOp +Status MnistOp::GetClassIds(std::map> *cls_ids) const { + if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { + RETURN_STATUS_UNEXPECTED("ImageLabelPair not set"); + } + for (size_t i = 0; i < image_label_pairs_.size(); ++i) { + (*cls_ids)[image_label_pairs_[i].second].push_back(i); + } + for (auto &pair : (*cls_ids)) { + pair.second.shrink_to_fit(); + } + return Status::OK(); +} + +Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { + uint32_t res = 0; + reader->read(reinterpret_cast(&res), 4); + if (reader->fail()) { + RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file"); + } + *result = SwapEndian(res); + return Status::OK(); +} + +uint32_t MnistOp::SwapEndian(uint32_t val) const { + val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); + return (val << 16) | (val >> 16); +} + +Status 
MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) {
+  if (image_reader->is_open() == false) {
+    RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name);
+  }
+  int64_t image_len = image_reader->seekg(0, std::ios::end).tellg();
+  (void)image_reader->seekg(0, std::ios::beg);
+  // The first 16 bytes of the image file are type, number, row and column
+  if (image_len < 16) {
+    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
+  }
+  uint32_t magic_number;
+  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number));
+  CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber,
+                               "This is not the mnist image file: " + file_name);
+
+  uint32_t num_items;
+  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &num_items));
+  uint32_t rows;
+  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &rows));
+  uint32_t cols;
+  RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols));
+  // The image size of the Mnist dataset is fixed at [28,28]
+  if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) {
+    RETURN_STATUS_UNEXPECTED("Wrong shape of image.");
+  }
+  if ((image_len - 16) != num_items * rows * cols) {
+    RETURN_STATUS_UNEXPECTED("Wrong number of images.");
+  }
+  *num_images = num_items;
+  return Status::OK();
+}
+
+Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) {
+  if (label_reader->is_open() == false) {
+    RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name);
+  }
+  int64_t label_len = label_reader->seekg(0, std::ios::end).tellg();
+  (void)label_reader->seekg(0, std::ios::beg);
+  // The first 8 bytes of the label file are type and number
+  if (label_len < 8) {
+    RETURN_STATUS_UNEXPECTED("Mnist file is corrupted.");
+  }
+  uint32_t magic_number;
+  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number));
+  CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber,
+                               "This is not the mnist label file: " + file_name);
+  uint32_t num_items;
+  RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items));
+  if ((label_len - 8) != num_items) {
+    RETURN_STATUS_UNEXPECTED("Wrong number of labels!");
+  }
+  *num_labels = num_items;
+  return Status::OK();
+}
+
+Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index) {
+  uint32_t num_images, num_labels;
+  RETURN_IF_NOT_OK(CheckImage(image_names_[index], image_reader, &num_images));
+  RETURN_IF_NOT_OK(CheckLabel(label_names_[index], label_reader, &num_labels));
+  CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "num_images != num_labels");
+  // The image size of the Mnist dataset is fixed at [28,28]
+  int64_t size = kMnistImageRows * kMnistImageCols;
+  auto images_buf = std::make_unique<char[]>(size * num_images);
+  auto labels_buf = std::make_unique<char[]>(num_images);
+  if (images_buf == nullptr || labels_buf == nullptr) {
+    std::string err_msg = "Failed to allocate memory for MNIST buffer.";
+    MS_LOG(ERROR) << err_msg.c_str();
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  (void)image_reader->read(images_buf.get(), size * num_images);
+  if (image_reader->fail()) {
+    RETURN_STATUS_UNEXPECTED("Failed to read: " + image_names_[index] + " size: " + std::to_string(size * num_images));
+  }
+  (void)label_reader->read(labels_buf.get(), num_images);
+  if (label_reader->fail()) {
+    RETURN_STATUS_UNEXPECTED("Failed to read: " + label_names_[index] + " size: " + std::to_string(num_images));
+  }
+  TensorShape img_tensor_shape = TensorShape({kMnistImageRows, kMnistImageCols, 1});
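+
+  // For reference, the headers validated by CheckImage/CheckLabel above follow the
+  // standard MNIST IDX layout (all fields big-endian uint32):
+  //   images (idx3-ubyte): magic 2051 | count | rows (28) | cols (28) | pixel bytes ...
+  //   labels (idx1-ubyte): magic 2049 | count | label bytes ...
+  // The loop below also binarizes each pixel: 0 stays 0, any non-zero value becomes 255.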
+  for (int64_t j = 0; j != num_images; ++j) {
+    auto pixels = &images_buf[j * size];
+    for (int64_t m = 0; m < size; ++m) {
+      pixels[m] = (pixels[m] == 0) ? 0 : 255;
+    }
+    std::shared_ptr<Tensor> image;
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&image, data_schema_->column(0).tensorImpl(), img_tensor_shape,
+                                          data_schema_->column(0).type(), reinterpret_cast<unsigned char *>(pixels)));
+    image_label_pairs_.emplace_back(std::make_pair(image, labels_buf[j]));
+  }
+  return Status::OK();
+}
+
+Status MnistOp::ParseMnistData() {
+  for (size_t i = 0; i < image_names_.size(); ++i) {
+    std::ifstream image_reader, label_reader;
+    image_reader.open(image_names_[i], std::ios::binary);
+    label_reader.open(label_names_[i], std::ios::binary);
+
+    Status s = ReadImageAndLabel(&image_reader, &label_reader, i);
+    // Close the readers
+    image_reader.close();
+    label_reader.close();
+    RETURN_IF_NOT_OK(s);
+  }
+  image_label_pairs_.shrink_to_fit();
+  num_rows_ = image_label_pairs_.size();
+  if (num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED(
+      "There is no valid data matching the dataset API MnistDataset. Please check the file path or dataset API "
+      "validation first.");
+  }
+  return Status::OK();
+}
+
+Status MnistOp::WalkAllFiles() {
+  const std::string kImageExtension = "idx3-ubyte";
+  const std::string kLabelExtension = "idx1-ubyte";
+
+  Path dir(folder_path_);
+  auto dir_it = Path::DirIterator::OpenDirectory(&dir);
+  if (dir_it != nullptr) {
+    while (dir_it->hasNext()) {
+      Path file = dir_it->next();
+      std::string filename = file.toString();
+      if (filename.find(kImageExtension) != std::string::npos) {
+        image_names_.push_back(filename);
+        MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
+      } else if (filename.find(kLabelExtension) != std::string::npos) {
+        label_names_.push_back(filename);
+        MS_LOG(INFO) << "Mnist operator found label file at " << filename << ".";
+      }
+    }
+  } else {
+    MS_LOG(WARNING) << "Mnist operator unable to open directory " << dir.toString() << ".";
+  }
+
+  std::sort(image_names_.begin(), image_names_.end());
+  std::sort(label_names_.begin(), label_names_.end());
+
+  if (image_names_.size() != label_names_.size()) {
+    RETURN_STATUS_UNEXPECTED("Number of images does not equal number of labels");
+  }
+
+  return Status::OK();
+}
+
+Status MnistOp::LaunchThreadsAndInitOp() {
+  if (tree_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("tree_ not set");
+  }
+  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&MnistOp::WorkerEntry, this, std::placeholders::_1)));
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(this->WalkAllFiles());
+  RETURN_IF_NOT_OK(this->ParseMnistData());
+  RETURN_IF_NOT_OK(this->InitSampler());  // handshake with sampler
+  return Status::OK();
+}
+
+Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
+  // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckImage()/CheckLabel()
+  std::shared_ptr<MnistOp> op;
+  *count = 0;
+  RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op));
+
+  RETURN_IF_NOT_OK(op->WalkAllFiles());
+
+  for (size_t i = 0; i < op->image_names_.size(); ++i) {
+    std::ifstream image_reader;
+    image_reader.open(op->image_names_[i], std::ios::binary);
+    std::ifstream label_reader;
+    label_reader.open(op->label_names_[i], std::ios::binary);
+
+    uint32_t num_images;
+    RETURN_IF_NOT_OK(op->CheckImage(op->image_names_[i], &image_reader, &num_images));
+    uint32_t num_labels;
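+    // Each (image, label) file pair contributes its item count to the total. For the
+    // standard MNIST release that would be 60000 from the train pair plus 10000 from
+    // the test pair, i.e. *count = 70000 when both pairs sit in the same directory.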
RETURN_IF_NOT_OK(op->CheckLabel(op->label_names_[i], &label_reader, &num_labels));
+    CHECK_FAIL_RETURN_UNEXPECTED((num_images == num_labels), "Number of images does not equal number of labels");
+    *count = *count + num_images;
+
+    // Close the readers
+    image_reader.close();
+    label_reader.close();
+  }
+
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status MnistOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<MnistOp>(), modified);
+}
+
+Status MnistOp::ComputeColMap() {
+  // set the column name map (base class field)
+  if (column_name_id_map_.empty()) {
+    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
+      column_name_id_map_[data_schema_->column(i).name()] = i;
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
new file mode 100644
index 0000000000..039f6b112f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mnist_op.h
@@ -0,0 +1,252 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_
+
+#include <fstream>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/wait_post.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declares
+template <typename T>
+class Queue;
+
+using MnistLabelPair = std::pair<std::shared_ptr<Tensor>, int32_t>;
+
+class MnistOp : public ParallelOp, public RandomAccessOp {
+ public:
+  class Builder {
+   public:
+    // Constructor for Builder class of MnistOp
+    // @param uint32_t numWrks - number of parallel workers
+    // @param dir - directory of the MNIST dataset
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method
+    // @param int32_t rows_per_buffer
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method
+    // @param int32_t op_connector_size
+    // @return Builder setter method returns reference to the builder.
+ Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method + // @param int32_t num_workers + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + // Setter method + // @param const std::string & dir + // @return + Builder &SetDir(const std::string &dir) { + builder_dir_ = dir; + return *this; + } + + // Check validity of input args + // @return - The error code return + Status SanityCheck(); + + // The builder "Build" method creates the final object. + // @param std::shared_ptr *op - DatasetOp + // @return - The error code return + Status Build(std::shared_ptr *op); + + private: + std::string builder_dir_; + int32_t builder_num_workers_; + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::shared_ptr builder_sampler_; + std::unique_ptr builder_schema_; + }; + + // Constructor + // @param int32_t num_workers - number of workers reading images in parallel + // @param int32_t rows_per_buffer - number of images (rows) in each buffer + // @param std::string folder_path - dir directory of mnist + // @param int32_t queue_size - connector queue size + // @param std::unique_ptr data_schema - the schema of the mnist dataset + // @param td::unique_ptr sampler - sampler tells MnistOp what to read + MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, + std::unique_ptr data_schema, std::shared_ptr sampler); + + // Destructor. 
+ ~MnistOp() = default; + + // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector + // @param int32_t worker_id - id of each worker + // @return Status - The error code return + Status WorkerEntry(int32_t worker_id) override; + + // Main Loop of MnistOp + // Master thread: Fill IOBlockQueue, then goes to sleep + // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector + // @return Status - The error code return + Status operator()() override; + + // Method derived from RandomAccess Op, enable Sampler to get all ids for each class + // @param (std::map> * map - key label, val all ids for this class + // @return Status - The error code return + Status GetClassIds(std::map> *cls_ids) const override; + + // A print method typically used for debugging + // @param out + // @param show_all + void Print(std::ostream &out, bool show_all) const override; + + // Function to count the number of samples in the MNIST dataset + // @param dir path to the MNIST directory + // @param count output arg that will hold the minimum of the actual dataset size and numSamples + // @return + static Status CountTotalRows(const std::string &dir, int64_t *count); + + /// \brief Base-class override for NodePass visitor acceptor + /// \param[in] p Pointer to the NodePass to be accepted + /// \param[out] modified Indicator if the node was changed at all + /// \return Status of the node visit + Status Accept(NodePass *p, bool *modified) override; + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "MnistOp"; } + + private: + // Initialize Sampler, calls sampler->Init() within + // @return Status - The error code return + Status InitSampler(); + + // Load a tensor row according to a pair + // @param row_id_type row_id - id for this tensor row + // @param ImageLabelPair pair - + // @param TensorRow row - image & label read into this tensor row + // @return Status - The error code return + Status LoadTensorRow(row_id_type row_id, const MnistLabelPair &mnist_pair, TensorRow *row); + + // @param const std::vector &keys - keys in ioblock + // @param std::unique_ptr db + // @return Status - The error code return + Status LoadBuffer(const std::vector &keys, std::unique_ptr *db); + + // Iterate through all members in sampleIds and fill them into IOBlock. + // @param std::shared_ptr sample_ids - + // @param std::vector *keys - keys in ioblock + // @return Status - The error code return + Status TraversalSampleIds(const std::shared_ptr &sample_ids, std::vector *keys); + + // Check image file stream. + // @param const std::string *file_name - image file name + // @param std::ifstream *image_reader - image file stream + // @param uint32_t num_images - returns the number of images + // @return Status - The error code return + Status CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images); + + // Check label stream. + // @param const std::string &file_name - label file name + // @param std::ifstream *label_reader - label file stream + // @param uint32_t num_labels - returns the number of labels + // @return Status - The error code return + Status CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels); + + // Read 4 bytes of data from a file stream. 
+ // @param std::ifstream *reader - file stream to read + // @return uint32_t - read out data + Status ReadFromReader(std::ifstream *reader, uint32_t *result); + + // Swap endian + // @param uint32_t val - + // @return uint32_t - swap endian data + uint32_t SwapEndian(uint32_t val) const; + + // Read the specified number of images and labels from the file stream + // @param std::ifstream *image_reader - image file stream + // @param std::ifstream *label_reader - label file stream + // @param int64_t read_num - number of image to read + // @return Status - The error code return + Status ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *label_reader, size_t index); + + // Parse all mnist dataset files + // @return Status - The error code return + Status ParseMnistData(); + + // Read all files in the directory + // @return Status - The error code return + Status WalkAllFiles(); + + // Called first when function is called + // @return Status - The error code return + Status LaunchThreadsAndInitOp(); + + // reset Op + // @return Status - The error code return + Status Reset() override; + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int64_t buf_cnt_; + int64_t row_cnt_; + WaitPost wp_; + std::string folder_path_; // directory of image folder + int32_t rows_per_buffer_; + std::unique_ptr data_schema_; + std::vector image_label_pairs_; + std::vector image_names_; + std::vector label_names_; + QueueList> io_block_queues_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_MNIST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc new file mode 100644 index 0000000000..46f3adfa62 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.cc @@ -0,0 +1,426 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include +#include +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { +// Builder constructor. Creates the builder object. +RandomDataOp::Builder::Builder() + : builder_data_schema_(nullptr), + builder_num_workers_(0), + builder_op_connector_size_(0), + builder_rows_per_buffer_(0), + builder_total_rows_(0), + builder_sampler_(nullptr) { + // Some arguments to the RandomDataOp have a default argument that is taken from the config. + // The user may override these defaults by using the builder set methods. 
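+  // A minimal usage sketch (the setter names here are assumed to mirror the builder_*
+  // fields declared in the header; check the class declaration for the exact API):
+  //
+  //   std::shared_ptr<RandomDataOp> op;
+  //   RETURN_IF_NOT_OK(RandomDataOp::Builder()
+  //                      .SetTotalRows(128)      // 0 means "pick a random row count"
+  //                      .SetRowsPerBuffer(32)
+  //                      .Build(&op));           // no schema given -> GenerateSchema() runs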
+ std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +// The build method that produces the instantiated RandomDataOp as a shared pointer +Status RandomDataOp::Builder::Build(std::shared_ptr *out_op) { + RETURN_IF_NOT_OK(SanityCheck()); + + *out_op = + std::make_shared(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_, + builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_)); + + // If the user did not provide a schema, then we will ask the op to generate a pseudo-random + // schema. + // See details of generateSchema function to learn what type of schema it will create. + if ((*out_op)->data_schema_ == nullptr) { + RETURN_IF_NOT_OK((*out_op)->GenerateSchema()); + } + + return Status::OK(); +} + +// Check if the required parameters are set by the builder. +Status RandomDataOp::Builder::SanityCheck() const { + // There actually is no required arguments for the random data op at all. + // Some arguments are preset with global values from config, and if they are not given by the user + // then we create them randomly. Leaving this function here for consistency with other operators. + return Status::OK(); +} + +// Constructor for RandomDataOp +RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows, + std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + buffer_id_(0), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_rows), + epoch_buffers_sent_(0), + guys_in_(0), + guys_out_(num_workers_), + eoe_worker_id_(0), + data_schema_(std::move(data_schema)) { + rand_gen_.seed(GetSeed()); // seed the random generator + // If total rows was not given, then randomly pick a number + if (total_rows_ == 0) { + total_rows_ = GenRandomInt(1, kMaxTotalRows); + } + // Everyone is already out from the sync area. + all_out_.Set(); +} + +// A print method typically used for debugging +void RandomDataOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << " [total rows: " << total_rows_ << "]\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nTotal_rows: " << total_rows_ << "\nRows per buffer: " << rows_per_buffer_ << "\nSchema:\n" + << *data_schema_ << "\n\n"; + } +} + +// Helper function to produce a default/random schema if one didn't exist +Status RandomDataOp::GenerateSchema() { + if (data_schema_ != nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Generating a schema but one already exists!"); + } + + // To randomly create a schema, we need to choose: + // a) how many columns + // b) the type of each column + // c) the shape of each column (number of dimensions i.e. 
rank)
+  //    d) the shape of each column (dimension values)
+  data_schema_ = std::make_unique<DataSchema>();
+  std::unique_ptr<TensorShape> newShape;
+  std::unique_ptr<ColDescriptor> newCol;
+
+  // Loop over the number of chosen columns
+  int32_t numColumns = GenRandomInt(1, kMaxNumColumns);
+  for (int32_t i = 0; i < numColumns; i++) {
+    // For each column:
+    // - choose a datatype
+    // - generate a shape that randomly chooses the number of dimensions and the dimension values.
+    DataType::Type newType = static_cast<DataType::Type>(GenRandomInt(1, DataType::NUM_OF_TYPES - 2));
+    int32_t rank = GenRandomInt(1, kMaxRank);
+    std::vector<dsize_t> dims;
+    for (int32_t d = 0; d < rank; d++) {
+      // 0 is not a valid dimension value. However, we can support "*" or unknown, so map the random
+      // 0 value to the unknown attribute if 0 is chosen
+      dsize_t dim_value = static_cast<dsize_t>(GenRandomInt(0, kMaxDimValue));
+      if (dim_value == 0) dim_value = TensorShape::kDimUnknown;
+      dims.push_back(dim_value);
+    }
+    newShape = std::make_unique<TensorShape>(dims);
+
+    // Create the column descriptor
+    std::string colName = "c" + std::to_string(i);
+    newCol = std::make_unique<ColDescriptor>(colName, DataType(newType), TensorImpl::kFlexible, rank, newShape.get());
+
+    data_schema_->AddColumn(*newCol);
+  }
+
+  return Status::OK();
+}
+
+// Class functor operator () override.
+// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
+// provide the master loop that drives the logic for performing the work.
+Status RandomDataOp::operator()() {
+  // First, compute how many buffers we'll need to satisfy the total row count.
+  // The only reason we do this is for the purpose of throttling worker count if needed.
+  int64_t buffers_needed = total_rows_ / rows_per_buffer_;
+  if (total_rows_ % rows_per_buffer_ != 0) {
+    buffers_needed++;
+  }
+
+  // If the number of workers we have exceeds the number of buffers to produce, then we'll have
+  // idle workers doing nothing. In that case, let's throttle the worker count.
+  if (num_workers_ > buffers_needed) {
+    MS_LOG(INFO) << "RandomDataOp throttling worker count from " << num_workers_ << " to " << buffers_needed;
+    num_workers_ = buffers_needed;
+    num_producers_ = num_workers_;
+    guys_out_ = num_workers_;
+    // The output connector was already created with a different worker count. We have to drop and recreate
+    // that connector.
+    DatasetOp::CreateConnector(num_producers_, num_workers_);
+  }
+
+  // Assign the number of rows to each worker in a round robin fashion.
+  worker_max_rows_.reserve(num_workers_);
+  worker_rows_packed_.reserve(num_workers_);
+  // init the counts to zero to start.
+  for (int32_t w = 0; w < num_workers_; w++) {
+    worker_max_rows_.push_back(0);
+    worker_rows_packed_.push_back(0);
+  }
+  // then assign round robin row counts
+  int32_t currentWorker = 0;
+  for (int64_t r = 0; r < total_rows_; r++) {
+    worker_max_rows_[currentWorker]++;
+    currentWorker = (currentWorker + 1) % num_workers_;
+  }
+
+  // Next, compute the total buffer count. This stat is needed during reset logic
+  for (int32_t w = 0; w < num_workers_; w++) {
+    int64_t worker_buffers = 0;
+    worker_buffers = worker_max_rows_[w] / rows_per_buffer_;
+    if (worker_max_rows_[w] % rows_per_buffer_ != 0) worker_buffers++;
+    epoch_buffers_sent_ += worker_buffers;
+  }
+
+  // For the connector to work, we need to target the correct worker channel for the eoe.
+  // This will initialize it for the first epoch. reset() handles it for the rest of the epochs.
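+  // Worked example (hypothetical numbers): total_rows_ = 10, rows_per_buffer_ = 4,
+  // num_workers_ = 2. The round robin gives each worker 5 rows, so each worker sends
+  // ceil(5 / 4) = 2 buffers and epoch_buffers_sent_ = 4. The eoe is then targeted at
+  // worker 4 % 2 = 0, and the counter becomes 5 once the eoe itself is counted below.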
+ eoe_worker_id_ = epoch_buffers_sent_ % num_workers_; + epoch_buffers_sent_++; // Add the eoe buffer to the count for subsequent epochs + + // RandomDataOp doesn't need the master thread to stay around. Kick off the workers and then master exits. + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&RandomDataOp::WorkerEntry, this, std::placeholders::_1))); + + // required task group setup after launching workers + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(epoch_sync_wait_post_.Register(tree_->AllTasks())); + + return Status::OK(); +} + +// Performs a synchronization between workers at the end of an epoch +Status RandomDataOp::EpochSync(int32_t worker_id, bool *quitting) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " syncing at end of epoch"; + + // Sync on the guys_in counter + // We have to wait the last guy is out. + all_out_.Wait(); + // If we are not in a repeat loop, or that was the last repeat already, then setup our exit + // condition from the master loop. + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + *quitting = true; + } + + auto prev = guys_in_.fetch_add(1); + bool last_guy_in = (prev + 1) == num_workers_; + // If we are the last worker to hit this sync point, we have some extra tasks + if (last_guy_in) { + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is the last one to sync. eoe sent as worker " + << eoe_worker_id_; + // Prepare for sync + all_out_.Clear(); + // Always flow eoe at the end + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(eoe_worker_id_, std::move(eoe_buffer))); + // If we're done then also flow the eof + if (*quitting) { + // The eof needs to be sent from the next sender in the round robin, so +1 + int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_; + MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs. 
+    // If we're done then also flow the eof
+    if (*quitting) {
+      // The eof needs to be sent from the next sender in the round robin, so +1
+      int32_t eof_worker_id = (eoe_worker_id_ + 1) % num_workers_;
+      MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " has no more epochs, sending eof as worker "
+                   << eof_worker_id;
+      std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+      RETURN_IF_NOT_OK(out_connector_->Add(eof_worker_id, std::move(eof_buffer)));
+    }
+  }
+
+  // Wait for the reset to wake us up if we're not quitting
+  if (!(*quitting)) {
+    MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entering sync wait.";
+    RETURN_IF_NOT_OK(epoch_sync_wait_post_.Wait());
+    prev = guys_out_.fetch_add(1);
+    bool last_guy_out = (prev + 1) == num_workers_;
+    // Last guy out will clear the wait post and set the row counts
+    if (last_guy_out) {
+      MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " last guy out clearing wait post.";
+      epoch_sync_wait_post_.Clear();
+      guys_in_ = 0;
+      all_out_.Set();
+    }
+  }
+
+  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " epoch sync complete.";
+  return Status::OK();
+}
+
+// The entry point code for when workers are launched
+Status RandomDataOp::WorkerEntry(int32_t worker_id) {
+  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " entry";
+
+  // handshake with the master first to tell it we're alive
+  TaskManager::FindMe()->Post();
+
+  bool quitting = false;
+  std::unique_ptr<TensorQTable> new_tensor_table = nullptr;
+
+  // Loop until the quitting variable gets set to true
+  do {
+    // If we have not yet reached the row count for this worker then produce another record
+    if (worker_rows_packed_[worker_id] < worker_max_rows_[worker_id]) {
+      TensorRow new_row;
+
+      // Start a new tensor table if needed
+      if (new_tensor_table == nullptr) {
+        new_tensor_table = std::make_unique<TensorQTable>();
+      }
+
+      // Create the data for the row
+      RETURN_IF_NOT_OK(CreateRandomRow(worker_id, &new_row));
+
+      // Add the row to our table
+      new_tensor_table->push_back(std::move(new_row));
+      worker_rows_packed_[worker_id]++;
+
+      // If the tensor table is at capacity then it's time to send it to output
+      if (new_tensor_table->size() == rows_per_buffer_) {
+        RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table)));
+      }
+    } else {
+      // We've reached the total row count for this worker, so it's time for epoch sync.
+      // There are likely some records built but not sent yet, so take care of those first
+      // (this buffer will be smaller than rows_per_buffer)
+      if (new_tensor_table != nullptr && new_tensor_table->size() > 0) {
+        RETURN_IF_NOT_OK(PackAndSend(worker_id, std::move(new_tensor_table)));
+      }
+
+      // Now, let's enter the epoch sync
+      RETURN_IF_NOT_OK(EpochSync(worker_id, &quitting));
+    }
+  } while (!quitting);
+
+  MS_LOG(INFO) << "RandomDataOp worker " << worker_id << " is now quitting.";
+
+  return Status::OK();
+}
+
+// A helper function to stuff the tensor table into a buffer and send it to output connector
+Status RandomDataOp::PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table) {
+  auto new_buffer = std::make_unique<DataBuffer>(GetNextBufferId(), DataBuffer::kDeBFlagNone);
+  new_buffer->set_tensor_table(std::move(in_table));
+  RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(new_buffer)));
+  return Status::OK();
+}
+
+// A helper function to create random data for the row
+Status RandomDataOp::CreateRandomRow(int32_t worker_id, TensorRow *new_row) {
+  if (new_row == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Missing tensor row output");
+  }
+
+  // Create a tensor for each column, then add the tensor to the row
+  for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
+    const ColDescriptor current_col = data_schema_->column(i);
+    std::vector<dsize_t> current_shape = current_col.shape().AsVector();
+    std::unique_ptr<TensorShape> new_shape = nullptr;
+    std::unique_ptr<unsigned char[]> buf = nullptr;
+    std::shared_ptr<Tensor> new_tensor = nullptr;
+
+    // We need to resolve the shape to fill in any unknown dimensions with random
+    // values, then use that as our shape for this tensor.
+    for (size_t j = 0; j < current_shape.size(); ++j) {
+      if (current_shape[j] == TensorShape::kDimUnknown) {
+        current_shape[j] = static_cast<dsize_t>(GenRandomInt(1, kMaxDimValue));
+      }
+    }
+
+    new_shape = std::make_unique<TensorShape>(current_shape);
+    int64_t size_in_bytes = new_shape->NumOfElements() * current_col.type().SizeInBytes();
+
+    // Generate a random byte of data.  This may cause some funny data for things like doubles, floats, bools;
+    // however the random data op is not too concerned about the physical data itself.
+    std::uniform_int_distribution<uint16_t> uniDist(0, 255);
+    uint8_t random_byte = static_cast<uint8_t>(uniDist(rand_gen_));
+
+    // Now, create a chunk of memory for the entire tensor and copy this byte in repeatedly.
+    buf = std::make_unique<unsigned char[]>(size_in_bytes);
+    int ret_code = memset_s(buf.get(), size_in_bytes, random_byte, size_in_bytes);
+    if (ret_code != 0) {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Failed to set random bytes for a tensor.");
+    }
+
+    RETURN_IF_NOT_OK(
+      Tensor::CreateTensor(&new_tensor, current_col.tensorImpl(), *new_shape, current_col.type(), buf.get()));
+
+    // Add this tensor to the tensor row for output
+    (*new_row).push_back(std::move(new_tensor));
+  }
+  return Status::OK();
+}
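+
+// Worked example for CreateRandomRow (editorial, illustrative only): a column of type FLOAT32 with
+// resolved shape (2, 3) yields size_in_bytes = 6 * 4 = 24, and all 24 bytes are set to the same
+// random byte, so the resulting float values are filler rather than meaningful numbers.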
+
+// Overrides base class reset method.  When an operator does a reset, it cleans up any state
+// info from its previous execution and then initializes itself so that it can be executed
+// again.
+Status RandomDataOp::Reset() {
+  MS_LOG(INFO) << "RandomDataOp resetting.";
+
+  // Ensure all guys are in the waitpost
+  if (guys_in_ != num_workers_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                  "Issuing a reset, but some workers are missing from epochSync!");
+  }
+
+  // reset the row counters for all workers
+  for (int32_t w = 0; w < num_workers_; w++) {
+    worker_rows_packed_[w] = 0;
+    worker_max_rows_[w] = 0;
+  }
+  buffer_id_ = 0;
+
+  // Re-assign round robin row counts, starting from the worker after the one that gave
+  // the eoe last time
+  int32_t currentWorker = (eoe_worker_id_ + 1) % num_workers_;
+  for (int64_t r = 0; r < total_rows_; r++) {
+    worker_max_rows_[currentWorker]++;
+    currentWorker = (currentWorker + 1) % num_workers_;
+  }
+
+  // Compute which worker should get the eoe for the next epoch
+  eoe_worker_id_ = ((epoch_buffers_sent_ % num_workers_) + eoe_worker_id_) % num_workers_;
+
+  // Wake up the workers to get them going again in a new epoch
+  guys_out_ = 0;
+  epoch_sync_wait_post_.Set();
+
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status RandomDataOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<RandomDataOp>(), modified);
+}
+
+Status RandomDataOp::ComputeColMap() {
+  // Extract the column name mapping from the schema and save it in the class.
+  if (column_name_id_map_.empty()) {
+    RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_)));
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
new file mode 100644
index 0000000000..c77695439d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/random_data_op.h
@@ -0,0 +1,291 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
+
+#include <atomic>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/util/wait_post.h"
+
+namespace mindspore {
+namespace dataset {
+// The RandomDataOp is a leaf node storage operator that generates random data based
+// on the schema specifications.  Typically, it's used for testing and demonstrating
+// various dataset operator pipelines.  It is not "real" data to train with.
+// The data that is created is just random, repeated bytes; there is no "meaning"
+// behind what these bytes are.
+class RandomDataOp : public ParallelOp {
+ public:
+  // Some constants to provide limits to random generation.
+  static constexpr int32_t kMaxNumColumns = 4;
+  static constexpr int32_t kMaxRank = 4;
+  static constexpr int32_t kMaxDimValue = 32;
+  static constexpr int32_t kMaxTotalRows = 1024;
+
+  // A nested builder class to aid in the construction of a RandomDataOp
+  class Builder {
+   public:
+    /**
+     * Builder constructor.  Creates the builder object.
+     * @note No default args.
+     * @return This is a constructor.
+     */
+    Builder();
+
+    /**
+     * Default destructor
+     */
+    ~Builder() = default;
+
+    /**
+     * The build method that produces the instantiated RandomDataOp as a shared pointer
+     * @param out_op - The output RandomDataOperator that was constructed
+     * @return Status - The error code return
+     */
+    Status Build(std::shared_ptr<RandomDataOp> *out_op);
+
+    /**
+     * Builder set method
+     * @param data_schema - A user-provided schema
+     * @return Builder - The modified builder by reference
+     */
+    Builder &SetDataSchema(std::unique_ptr<DataSchema> data_schema) {
+      builder_data_schema_ = std::move(data_schema);
+      return *this;
+    }
+
+    /**
+     * Builder set method
+     * @param num_workers - The number of workers
+     * @return Builder - The modified builder by reference
+     */
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    /**
+     * Builder set method
+     * @param op_connector_size - The size of the output connector
+     * @return Builder - The modified builder by reference
+     */
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    /**
+     * Builder set method
+     * @param rows_per_buffer - The number of rows in each DataBuffer
+     * @return Builder - The modified builder by reference
+     */
+    Builder &SetRowsPerBuffer(int64_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    /**
+     * Builder set method
+     * @param total_rows - The total number of rows in the dataset
+     * @return Builder - The modified builder by reference
+     */
+    Builder &SetTotalRows(int64_t total_rows) {
+      builder_total_rows_ = total_rows;
+      return *this;
+    }
+
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+   private:
+    /**
+     * Check if the required parameters are set by the builder.
+     * @return Status - The error code return
+     */
+    Status SanityCheck() const;
+
+    std::unique_ptr<DataSchema> builder_data_schema_;
+    std::shared_ptr<Sampler> builder_sampler_;
+    int32_t builder_num_workers_;
+    int32_t builder_op_connector_size_;
+    int64_t builder_rows_per_buffer_;
+    int64_t builder_total_rows_;
+  };  // class Builder
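+
+  // Illustrative builder usage (editorial sketch, not part of the original header):
+  //   std::shared_ptr<RandomDataOp> op;
+  //   Status rc = RandomDataOp::Builder().SetNumWorkers(2).SetTotalRows(64).Build(&op);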
+
+  /**
+   * Constructor for RandomDataOp
+   * @note Private constructor.  Must use builder to construct.
+   * @param num_workers - The number of workers
+   * @param op_connector_size - The size of the output connector
+   * @param rows_per_buffer - The number of rows in each DataBuffer
+   * @param data_schema - A user-provided schema
+   * @param total_rows - The total number of rows in the dataset
+   * @param sampler - allow a sampler.  Only valid if a cache exists in ancestor tree nodes
+   */
+  RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
+               std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+
+  /**
+   * Destructor
+   */
+  ~RandomDataOp() = default;
+
+  /**
+   * A print method typically used for debugging
+   * @param out - The output stream to write output to
+   * @param show_all - A bool to control if you want to show all info or just a summary
+   */
+  void Print(std::ostream &out, bool show_all) const override;
+
+  /**
+   * << Stream output operator overload
+   * @notes This allows you to write the debug print info using stream operators
+   * @param out - reference to the output stream being overloaded
+   * @param op - reference to the RandomDataOp to display
+   * @return - the output stream must be returned
+   */
+  friend std::ostream &operator<<(std::ostream &out, const RandomDataOp &op) {
+    op.Print(out, false);
+    return out;
+  }
+
+  /**
+   * Class functor operator () override.
+   * All DatasetOps operate by launching a thread (see ExecutionTree).  This class functor will
+   * provide the master loop that drives the logic for performing the work.
+   * @return Status - The error code return
+   */
+  Status operator()() override;
+
+  /**
+   * Overrides base class reset method.  When an operator does a reset, it cleans up any state
+   * info from its previous execution and then initializes itself so that it can be executed
+   * again.
+   * @return Status - The error code return
+   */
+  Status Reset() override;
+
+  /**
+   * Quick getter for total rows.
+   */
+  int64_t GetTotalRows() const { return total_rows_; }
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "RandomDataOp"; }
+
+ private:
+  /**
+   * The entry point code for when workers are launched
+   * @param worker_id - The worker id
+   * @return Status - The error code return
+   */
+  Status WorkerEntry(int32_t worker_id) override;
+
+  /**
+   * Helper function to produce a default/random schema if one didn't exist
+   * @return Status - The error code return
+   */
+  Status GenerateSchema();
+
+  /**
+   * Performs a synchronization between workers at the end of an epoch
+   * @param worker_id - The worker id
+   * @return Status - The error code return
+   */
+  Status EpochSync(int32_t worker_id, bool *quitting);
+
+  /**
+   * A helper function to stuff the tensor table into a buffer and send it to output connector
+   * @param worker_id - The worker id
+   * @param in_table - The tensor table to pack and send
+   * @return Status - The error code return
+   */
+  Status PackAndSend(int32_t worker_id, std::unique_ptr<TensorQTable> in_table);
+
+  /**
+   * A helper function to create random data for the row
+   * @param worker_id - The worker id
+   * @param new_row - The output row to produce
+   * @return Status - The error code return
+   */
+  Status CreateRandomRow(int32_t worker_id, TensorRow *new_row);
+
+  /**
+   * A quick inline for producing a random number between (and including) min/max
+   * @param min - minimum number that can be generated
+   * @param max - maximum number that can be generated
+   * @return - The generated random number
+   */
+  inline int32_t GenRandomInt(int32_t min, int32_t max) {
+    std::uniform_int_distribution<int32_t> uniDist(min, max);
+    return uniDist(rand_gen_);
+  }
+
+  /**
+   * A quick inline for producing the next buffer id in sequence, threadsafe
+   * @return - The next buffer id.
+   */
+  inline int32_t GetNextBufferId() {
+    std::unique_lock<std::mutex> lock(buffer_id_mutex_);
+    return ++buffer_id_;
+  }
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  int32_t buffer_id_;
+  int64_t rows_per_buffer_;
+  int64_t total_rows_;
+  int64_t epoch_buffers_sent_;
+  std::atomic<int32_t> guys_in_;
+  std::atomic<int32_t> guys_out_;
+  int32_t eoe_worker_id_;
+  std::unique_ptr<DataSchema> data_schema_;
+  std::vector<int64_t> worker_max_rows_;
+  std::vector<int64_t> worker_rows_packed_;
+  std::mt19937 rand_gen_;
+  WaitPost epoch_sync_wait_post_;
+  WaitPost all_out_;
+  std::mutex buffer_id_mutex_;
+};  // class RandomDataOp
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_RANDOM_DATA_OP_
diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
new file mode 100644
index 0000000000..2b5e7c67c8
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.cc
@@ -0,0 +1,119 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
+
+#include <limits>
+#include <memory>
+
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
+                                       uint32_t seed)
+    : Sampler(num_samples, std::numeric_limits<int64_t>::max()),
+      cnt_(0),
+      seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed),
+      device_id_(dev_id),
+      num_devices_(num_dev),
+      shuffle_(shuffle) {}
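+
+// Note (editorial): seed_ falls back to the global GetSeed() value when the caller passes the
+// uint32_t max sentinel.  InitSampler() seeds the generator with seed_++, and ResetSampler()
+// below re-seeds the same way when shuffling, so the shuffle order differs from epoch to epoch.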
+
+Status DistributedSampler::InitSampler() {
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n");
+  CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0,
+                               "fail to init DistributedSampler");
+  rnd_.seed(seed_++);
+  samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_;  // equal to ceil(num_rows / num_devices)
+  samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
+  if (shuffle_ == true) {
+    shuffle_vec_.reserve(num_rows_);
+    for (int64_t i = 0; i < num_rows_; i++) {
+      shuffle_vec_.push_back(i);
+    }
+    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
+  }
+  return Status::OK();
+}
+
+Status DistributedSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
+  if (cnt_ > samples_per_buffer_) {
+    RETURN_STATUS_UNEXPECTED("Distributed Sampler Error");
+  } else if (cnt_ == samples_per_buffer_) {
+    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+  } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
+    }
+
+    (*out_buffer) = std::make_unique<DataBuffer>(cnt_, DataBuffer::kDeBFlagNone);
+    std::shared_ptr<Tensor> sample_ids;
+    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, samples_per_buffer_));
+    auto id_ptr = sample_ids->begin<int64_t>();
+    while (cnt_ < samples_per_buffer_ && id_ptr != sample_ids->end<int64_t>()) {
+      int64_t sampled_id = (num_devices_ * cnt_ + device_id_) % num_rows_;
+      if (shuffle_) {
+        sampled_id = shuffle_vec_[static_cast<size_t>(sampled_id)];
+      }
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *id_ptr = sampled_id;
+      id_ptr++;
+      cnt_++;
+    }
+    TensorRow row(1, sample_ids);
+    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
+  }
+  return Status::OK();
+}
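+
+// Worked example (editorial, illustrative only): with num_rows_ = 10 and num_devices_ = 4,
+// samples_per_buffer_ = ceil(10 / 4) = 3; device 1 then draws ids (4 * cnt_ + 1) % 10 = 1, 5, 9
+// for cnt_ = 0, 1, 2, optionally remapped through shuffle_vec_ when shuffle is enabled.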
+
+Status DistributedSampler::ResetSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(cnt_ == samples_per_buffer_, "ERROR Reset() called early/late");
+  cnt_ = 0;
+
+  if (shuffle_ == true) {
+    rnd_.seed(seed_);
+    seed_++;
+    std::shuffle(shuffle_vec_.begin(), shuffle_vec_.end(), rnd_);
+  }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
+  }
+
+  return Status::OK();
+}
+
+void DistributedSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: DistributedSampler";
+  if (show_all) {
+    Sampler::Print(out, show_all);
+    out << "\nseed: " << seed_ << "\ndevice_id: " << device_id_ << "\nnum_devices: " << num_devices_
+        << "\nshuffle: " << shuffle_;
+  }
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
new file mode 100644
index 0000000000..76bcf052f9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
+
+#include <limits>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+namespace mindspore {
+namespace dataset {
+class DistributedSampler : public Sampler {
+ public:
+  // @param num_samples - the number of samples to draw; a value of 0 means sample everything
+  // @param int64_t num_dev
+  // @param int64_t dev_id
+  // @param bool shuffle
+  DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle,
+                     uint32_t seed = std::numeric_limits<uint32_t>::max());
+
+  // default destructor
+  ~DistributedSampler() = default;
+
+  // Op calls this to get the next buffer of sample ids
+  // @param std::unique_ptr<DataBuffer> *out_buffer - buffer to be returned to the corresponding Dataset Op
+  // @return - The error code return
+  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
+
+  // Init sampler, called by base class or python
+  Status InitSampler() override;
+
+  // for next epoch of sampleIds
+  // @return - The error code return
+  Status ResetSampler() override;
+
+  void Print(std::ostream &out, bool show_all) const override;
+
+ private:
+  int64_t cnt_;  // number of samples that have already been filled into the buffer
+  uint32_t seed_;
+  int64_t device_id_;
+  int64_t num_devices_;
+  bool shuffle_;
+  std::mt19937 rnd_;
+  std::vector<int64_t> shuffle_vec_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_DISTRIBUTED_SAMPLER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc
new file mode 100644
index 0000000000..770c24c8c5
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.cc
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
+#include <algorithm>
+#include <memory>
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer),
+      shuffle_(shuffle),
+      seed_(GetSeed()),
+      next_id_(0),
+      samples_per_class_(val) {}
+
+Status PKSampler::InitSampler() {
+  labels_.reserve(label_to_ids_.size());
+  for (const auto &pair : label_to_ids_) {
+    if (pair.second.empty() == false) {
+      labels_.push_back(pair.first);
+    }
+  }
+  rnd_.seed(seed_++);
+
+  // The special handshake gives us the list of classes and ids, but it did not set num_rows_ to
+  // capture the total number of possible sample ids.
+  // Compute that here for this case to find the total number of samples that are available to return.
+  // (in this case, samples per class * total classes).
+  num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size());
+
+  // The user may have chosen to sample less than the total amount.
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+
+  samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_;
+  if (shuffle_ == true) {
+    std::shuffle(labels_.begin(), labels_.end(), rnd_);
+  } else {
+    std::sort(labels_.begin(), labels_.end());
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive");
+  return Status::OK();
+}
+
+Status PKSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
+  if (next_id_ > num_samples_ || num_samples_ == 0) {
+    RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler");
+  } else if (next_id_ == num_samples_) {
+    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+  } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
+    }
+
+    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
+    std::shared_ptr<Tensor> sample_ids;
+    int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
+    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
+    auto id_ptr = sample_ids->begin<int64_t>();
+    while (next_id_ < last_id && id_ptr != sample_ids->end<int64_t>()) {
+      int64_t cls_id = next_id_++ / samples_per_class_;
+      const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
+      int64_t rnd_ind = std::uniform_int_distribution<int64_t>(0, samples.size() - 1)(rnd_);
+      int64_t sampled_id = samples[rnd_ind];
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *id_ptr = sampled_id;
+      id_ptr++;
+    }
+
+    TensorRow row(1, sample_ids);
+    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
+  }
+  return Status::OK();
+}
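+
+// Worked example (editorial, illustrative only): with labels_ = {1, 2, 3} and samples_per_class_ = 2,
+// num_rows_ becomes 6; ids 0 and 1 are random picks from the id list of class labels_[0], ids 2 and 3
+// from labels_[1], and ids 4 and 5 from labels_[2].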
+
+Status PKSampler::ResetSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
+  next_id_ = 0;
+  rnd_.seed(seed_++);
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
+  }
+
+  return Status::OK();
+}
+
+Status PKSampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+  RETURN_UNEXPECTED_IF_NULL(op);
+  RETURN_IF_NOT_OK(op->GetClassIds(&label_to_ids_));
+  RETURN_IF_NOT_OK(InitSampler());
+  return Status::OK();
+}
+
+void PKSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PKSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
new file mode 100644
index 0000000000..aed61fa273
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
+
+#include <limits>
+#include <map>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+namespace mindspore {
+namespace dataset {
+class PKSampler : public Sampler {  // NOT YET FINISHED
+ public:
+  // @param num_samples - the number of samples to draw.  value of 0 means to take the full amount
+  // @param int64_t val - the number of samples to draw per class (the K in PK sampling)
+  // @param bool shuffle - shuffle all classIds or not; if true, classes may be 5,1,4,3,2
+  // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
+  explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle,
+                     int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
+
+  // default destructor
+  ~PKSampler() = default;
+
+  // Op calls this to get the next buffer of sample ids
+  // @param std::unique_ptr<DataBuffer> *out_buffer - buffer to be returned to the corresponding Dataset Op
+  // @return - The error code return
+  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
+
+  // first handshake between leaf source op and Sampler.  This func will determine the amount of data
+  // in the dataset that we can sample from.
+  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
+  // @return
+  Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
+
+  // init sampler, to be called by python or Handshake
+  Status InitSampler() override;
+
+  // for next epoch of sampleIds
+  // @return - The error code return
+  Status ResetSampler() override;
+
+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+ private:
+  bool shuffle_;
+  uint32_t seed_;
+  int64_t next_id_;
+  int64_t samples_per_class_;
+  std::mt19937 rnd_;
+  std::vector<int64_t> labels_;
+  std::map<int32_t, std::vector<int64_t>> label_to_ids_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PK_SAMPLER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc
new file mode 100644
index 0000000000..50c67bca6c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.cc
@@ -0,0 +1,116 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/python_sampler.h"
+
+#include <memory>
+
+namespace mindspore {
+namespace dataset {
+
+PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
+
+Status PythonSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
+  if (need_to_reset_) {
+    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+  } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
+    }
+
+    std::shared_ptr<Tensor> sample_ids;
+    {
+      py::gil_scoped_acquire gil_acquire;
+      (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
+      if (Py_IsInitialized() == 0) {
+        return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+      }
+      try {
+        py::object py_ret = py_sampler_instance.attr("_get_indices")();
+        py::array np_sample_ids = py_ret.cast<py::array>();
+        RETURN_IF_NOT_OK(Tensor::CreateTensor(&sample_ids, np_sample_ids));  // copy numpy to tensor
+
+        if (HasChildSampler()) {
+          for (auto it = sample_ids->begin<int64_t>(); it != sample_ids->end<int64_t>(); ++it) {
+            int64_t associated_child_id = 0;
+            RETURN_IF_NOT_OK(GetAssociatedChildId(&associated_child_id, *it));
+            *it = associated_child_id;
+          }
+        }
+      } catch (const py::error_already_set &e) {
+        return Status(StatusCode::kPyFuncException, e.what());
+      } catch (const py::cast_error &e) {
+        return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index");
+      }
+    }
+    TensorRow row(1, sample_ids);
+    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
+    need_to_reset_ = true;
+  }
+  return Status::OK();
+}
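+
+// Assumed Python-side contract (editorial note, inferred from the calls in this file): the wrapped
+// sampler object exposes _get_indices(), returning a numpy array of integer ids for one epoch
+// (e.g. np.array([3, 1, 2, 0]) for a 4-row dataset), plus the _handshake(num_rows, num_samples)
+// and reset() hooks used below.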
+
+Status PythonSampler::InitSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0");
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+  {
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    }
+    try {
+      py_sampler_instance.attr("_handshake")(num_rows_, num_samples_);
+    } catch (const py::error_already_set &e) {
+      return Status(StatusCode::kPyFuncException, e.what());
+    }
+  }
+  return Status::OK();
+}
+
+Status PythonSampler::ResetSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch");
+  need_to_reset_ = false;
+  py::gil_scoped_acquire gil_acquire;
+  if (Py_IsInitialized() == 0) {
+    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+  }
+  try {
+    py_sampler_instance.attr("reset")();
+  } catch (const py::error_already_set &e) {
+    return Status(StatusCode::kPyFuncException, e.what());
+  }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
+  }
+
+  return Status::OK();
+}
+
+void PythonSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: PythonSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h
new file mode 100644
index 0000000000..61716feb94
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/python_sampler.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
+
+#include <limits>
+#include <memory>
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+namespace mindspore {
+namespace dataset {
+class PythonSampler : public Sampler {
+ public:
+  // Constructor
+  // @param num_samples - the number of samples to draw.  Value of 0 means to sample all of the
+  //     data from the dataset.
+  // @param py_sampler_instance - the python instance of the sampler
+  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
+  explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance,
+                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
+
+  // Destructor.
+  ~PythonSampler() = default;
+
+  // Initialize the sampler.
+  // @return Status
+  Status InitSampler() override;
+
+  // for next epoch of sampleIds
+  // @return - The error code return
+  Status ResetSampler() override;
+
+  // Op calls this to get next Buffer that contains all the sampleIds
+  // @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to corresponding Dataset Op
+  // @return - The error code return
+  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
+
+  // Printer for debugging purposes.
+  // @param out - output stream to write to
+  // @param show_all - bool to show detailed vs summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+ private:
+  bool need_to_reset_;  // Whether Reset() should be called before calling GetNextBuffer()
+
+  py::object py_sampler_instance;  // The handle to the py_sampler python object
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc
new file mode 100644
index 0000000000..998dee2a07
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.cc
@@ -0,0 +1,124 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+
+#include <algorithm>
+#include <memory>
+#include <random>
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
+                             int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer),
+      seed_(GetSeed()),
+      replacement_(replacement),
+      next_id_(0),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
+      dist(nullptr) {}
+
+Status RandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
+  if (next_id_ > num_samples_) {
+    RETURN_STATUS_UNEXPECTED("RandomSampler Internal Error");
+  } else if (next_id_ == num_samples_) {
+    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+  } else {
+    if (HasChildSampler()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_));
+    }
+    (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone);
+
+    std::shared_ptr<Tensor> sampleIds;
+    int64_t last_id = std::min(samples_per_buffer_ + next_id_, num_samples_);
+    RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, last_id - next_id_));
+    auto id_ptr = sampleIds->begin<int64_t>();
+
+    for (int64_t i = 0; i < (last_id - next_id_); i++) {
+      int64_t sampled_id = 0;
+      if (replacement_) {
+        sampled_id = (*dist)(rnd_);
+      } else {
+        sampled_id = shuffled_ids_[static_cast<size_t>(i + next_id_)];
+      }
+
+      if (HasChildSampler()) {
+        RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id));
+      }
+
+      *(id_ptr + i) = sampled_id;
+    }
+    next_id_ = last_id;
+    TensorRow row(1, sampleIds);
+    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
+  }
+  return Status::OK();
+}
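+
+// Behavior note (editorial, illustrative only): with replacement, each id is an independent uniform
+// draw over [0, num_rows_ - 1], so duplicates are possible; without replacement, the ids come from
+// a pre-shuffled permutation (shuffled_ids_) consumed in order.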
+
+Status RandomSampler::InitSampler() {
+  // Special value of 0 for num_samples means that the user wants to sample the entire set of data.
+  // If the user asked to sample more rows than exist in the dataset, adjust num_samples accordingly.
+  if (num_samples_ == 0 || num_samples_ > num_rows_) {
+    num_samples_ = num_rows_;
+  }
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive");
+  samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
+  rnd_.seed(seed_);
+
+  if (replacement_ == false) {
+    shuffled_ids_.reserve(num_rows_);
+    for (int64_t i = 0; i < num_rows_; i++) {
+      shuffled_ids_.push_back(i);
+    }
+    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
+  } else {
+    dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1);
+  }
+
+  return Status::OK();
+}
+
+Status RandomSampler::ResetSampler() {
+  CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late");
+  next_id_ = 0;
+
+  if (reshuffle_each_epoch_) {
+    seed_++;
+  }
+
+  rnd_.seed(seed_);
+
+  if (replacement_ == false && reshuffle_each_epoch_) {
+    std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
+  }
+
+  if (HasChildSampler()) {
+    RETURN_IF_NOT_OK(child_[0]->ResetSampler());
+  }
+
+  return Status::OK();
+}
+
+void RandomSampler::Print(std::ostream &out, bool show_all) const {
+  out << "\nSampler: RandomSampler";
+  if (show_all) {
+    // Call the super class for displaying any common detailed info
+    Sampler::Print(out, show_all);
+    // Then add our own info if any
+  }
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h
new file mode 100644
index 0000000000..6e21b088b9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/random_sampler.h
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
+
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomSampler : public Sampler {
+ public:
+  // Constructor
+  // @param int64_t num_samples - number of samples to draw
+  // @param bool replacement - put the id back (or not) after a sample
+  // @param reshuffle_each_epoch - T/F to reshuffle after epoch
+  // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
+  explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
+                         int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
+
+  // Destructor.
+  ~RandomSampler() = default;
+
+  // Op calls this to get next Buffer that contains all the sampleIds
+  // @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to StorageOp
+  // @return - The error code return
+  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
+
+  // meant to be called by base class or python
+  Status InitSampler() override;
+
+  // for next epoch of sampleIds
+  // @return - The error code return
+  Status ResetSampler() override;
+
+  void Print(std::ostream &out, bool show_all) const override;
+
+ private:
+  uint32_t seed_;
+  bool replacement_;
+  std::vector<int64_t> shuffled_ids_;  // only used for NO REPLACEMENT
+  int64_t next_id_;
+  std::mt19937 rnd_;
+  std::unique_ptr<std::uniform_int_distribution<int64_t>> dist;
+  bool reshuffle_each_epoch_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_RANDOM_SAMPLER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
new file mode 100644
index 0000000000..60d75d2eec
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.cc
@@ -0,0 +1,178 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+#include <string>
+
+namespace mindspore {
+namespace dataset {
+Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
+  // The sampler base class itself does not compute its own num_rows_ value.
+  // Instead, this value is computed by the derived leaf op during its own initialization
+  // after it has interacted with its storage layers.
+  // Here, it is just a getter method to return the value.  However, it is invalid if there is
+  // not a value set for this count, so generate a failure if that is the case.
+  if (num == nullptr || num_rows_ == 0) {
+    RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed its num rows yet.");
+  }
+  (*num) = num_rows_;
+  return Status::OK();
+}
+
+Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer)
+    : num_rows_(0), num_samples_(num_samples), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {}
+
+Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) {
+  std::shared_ptr<Sampler> child_sampler;
+  if (HasChildSampler()) {
+    child_sampler = std::dynamic_pointer_cast<Sampler>(child_[0]);
+    if (!child_sampler) {
+      std::string err_msg("Cannot handshake, child is not a sampler object.");
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+
+    // Handshake and init child first.
+    RETURN_IF_NOT_OK(child_sampler->HandshakeRandomAccessOp(op));
+  }
+
+  CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n");
+
+  // If there's a child sampler, set the row count to be its sample count
+  if (HasChildSampler()) {
+    num_rows_ = child_sampler->num_samples_;
+  } else {
+    RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_));
+  }
+
+  // It's up to the derived class to check the validity of the two args,
+  // because some samplers only need one of them (weighted_random_sampler).
+  RETURN_IF_NOT_OK(InitSampler());  // init sampler after callback
+
+  return Status::OK();
+}
+
+Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements) {
+  if (num_elements == 0) {
+    RETURN_STATUS_UNEXPECTED("num of elements is 0");
+  }
+  if (col_desc_ == nullptr) {
+    // a ColDescriptor for Tensor that holds SampleIds
+    col_desc_ = std::make_unique<ColDescriptor>("sampleIds", DataType(DataType::DE_INT64), TensorImpl::kFlexible, 1);
+  }
+  TensorShape shape(std::vector<dsize_t>(1, num_elements));
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type()));
+  RETURN_IF_NOT_OK(
+    (*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes()));  // allocate memory in case user forgets!
+  return Status::OK();
+}
+
+void Sampler::Print(std::ostream &out, bool show_all) const {
+  // Sampler printing is usually only called in the show_all mode.
+  // Derived classes will display the name, then call back to this base
+  // for common info.
+  // No-op in the summary mode.
+  if (show_all) {
+    out << "\nnum_rows_: " << num_rows_ << "\nnum_samples_: " << num_samples_;
+  }
+}
+
+#ifdef ENABLE_PYTHON
+Status Sampler::GetAllIdsThenReset(py::array *data) {
+  std::unique_ptr<DataBuffer> db;
+  std::shared_ptr<Tensor> sample_ids;
+  TensorRow sample_row;
+
+  // A call to derived class to get sample ids wrapped inside a buffer
+  RETURN_IF_NOT_OK(GetNextSample(&db));
+  // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
+  RETURN_IF_NOT_OK(db->GetRow(0, &sample_row));
+  sample_ids = sample_row[0];
+
+  // check this buffer is not a ctrl buffer
+  CHECK_FAIL_RETURN_UNEXPECTED(db->buffer_flags() == DataBuffer::kDeBFlagNone, "ERROR ctrl buffer received");
+  {
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+    }
+    try {
+      RETURN_IF_NOT_OK(sample_ids->GetDataAsNumpy(data));
+    } catch (const std::runtime_error &e) {
+      return Status(StatusCode::kPyFuncException, e.what());
+    }
+  }
+  // perform error checking!  The next buffer is supposed to be EOE since the last one already
+  // contained all the ids for the current epoch.
+  RETURN_IF_NOT_OK(GetNextSample(&db));
+  CHECK_FAIL_RETURN_UNEXPECTED(db->eoe(), "ERROR Non EOE received");
+  // Reset Sampler since this is the end of the epoch
+  RETURN_IF_NOT_OK(ResetSampler());
+  return Status::OK();
+}
+#endif
+
+Status Sampler::SetNumSamples(int64_t num_samples) {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative");
+  num_samples_ = num_samples;
+  return Status::OK();
+}
+
+Status Sampler::SetNumRowsInDataset(int64_t num_rows) {
+  CHECK_FAIL_RETURN_UNEXPECTED(num_rows > 0, "num_rows is negative or 0");
+  num_rows_ = num_rows;
+  return Status::OK();
+}
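+
+// Illustrative chaining (editorial note): AddChild below lets e.g. a DistributedSampler shard the
+// ids produced by a child SequentialSampler; the parent treats the child's sampled id list as its
+// input space and re-maps its own ids through GetAssociatedChildId().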
+
+Status Sampler::AddChild(std::shared_ptr<Sampler> child) {
+  if (child == nullptr) {
+    return Status::OK();
+  }
+
+  // Only samplers can be added, not any other DatasetOp.
+  std::shared_ptr<Sampler> sampler = std::dynamic_pointer_cast<Sampler>(child);
+  if (!sampler) {
+    std::string err_msg("Cannot add child, child is not a sampler object.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Samplers can have at most 1 child.
+  if (!child_.empty()) {
+    std::string err_msg("Cannot add child sampler, this sampler already has a child.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  child_.push_back(child);
+
+  // doesn't work, protected?
+  // child->AddParent(this);
+  return Status::OK();
+}
+
+bool Sampler::HasChildSampler() { return !child_.empty(); }
+
+Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) {
+  if (child_ids_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Trying to get associated child id, but there are no child ids!");
+  }
+
+  TensorRow sample_row;
+  RETURN_IF_NOT_OK(child_ids_->GetRow(0, &sample_row));
+  std::shared_ptr<Tensor> sample_ids = sample_row[0];
+  RETURN_IF_NOT_OK(sample_ids->GetItemAt<int64_t>(out_associated_id, {id}));
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
new file mode 100644
index 0000000000..4cae935a42
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sampler.h
@@ -0,0 +1,161 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
+
+#include <limits>
+#include <map>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+namespace mindspore {
+namespace dataset {
+// RandomAccessOp is a base class that all data-producing leaf operators
+// must inherit from if those leaf operators wish to support sampling.
+class RandomAccessOp {
+ public:
+  // Sampler gets the number of rows in the dataset
+  // @param int64_t num - return number of rows for this dataset
+  // @return - The error code return
+  Status GetNumRowsInDataset(int64_t *num_rows) const;
+
+  // Sampler gets the labels and imageIds from the corresponding Dataset Op; this function is unique to PK
+  // @param std::map<int32_t, std::vector<int64_t>> *map
+  // @return - The error code return
+  virtual Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *map) const {
+    RETURN_STATUS_UNEXPECTED("GetClassIds needs to be overridden to support PK");
+  }
+
+  // default destructor
+  virtual ~RandomAccessOp() = default;
+
+ protected:
+  // The amount of rows in the dataset itself.  This is the before-sampling value, the
+  // total count of rows.  A sampler may choose to sample less than this amount.
+  int64_t num_rows_;
+};
+
+class Sampler {
+ public:
+  // Constructor
+  // @param int64_t num_samples: the user-requested number of sample ids to generate.  A value of 0
+  //     indicates that the sampler should produce the complete set of ids.
+  // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
+  explicit Sampler(int64_t num_samples, int64_t samples_per_buffer);
+
+  Sampler(const Sampler &s) : Sampler(s.num_samples_, s.samples_per_buffer_) {}
+
+  // default destructor
+  virtual ~Sampler() = default;
+
+  // Get a list of sample ids.
+  // @note It is the Sampler's responsibility to make sure that the id is not out of bound.
+  // @param std::unique_ptr<DataBuffer> *out_buffer - Buffer to be returned to StorageOp
+  // @return - The error code return
+  virtual Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) = 0;
+
+// This function is only called by the python layer.  Not needed by Android.
+#ifdef ENABLE_PYTHON
+  // return all ids in one epoch as a numpy array, then call reset
+  Status GetAllIdsThenReset(py::array *data);
+#endif
+
+  // for next epoch of sampleIds
+  // @return - The error code return
+  virtual Status ResetSampler() = 0;
+
+  // first handshake between leaf source op and Sampler.  This func will determine the amount of data
+  // in the dataset that we can sample from.
+  // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is
+  // @return
+  virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op);
+
+  // initialize sampler and perform checks on certain vars
+  virtual Status InitSampler() { return Status::OK(); }
+
+  // setter for num samples
+  // @param num_samples - the number of samples to assign.
+  // @return status error code
+  Status SetNumSamples(int64_t num_samples);
+
+  // setter for num of records in the dataset
+  // @param num_rows - the number of records
+  // @return status error code
+  Status SetNumRowsInDataset(int64_t num_rows);
+
+  // Adds a sampler to become our child.
+  // @param std::shared_ptr<Sampler> - The sampler to add as a child.
+  // @return - The error code returned.
+  Status AddChild(std::shared_ptr<Sampler> child);
+
+  // A helper function to create an int64_t 1-D Tensor specifically used to hold sampleIds for Sampler
+  // @param std::shared_ptr<Tensor> *sample_ids
+  // @param int64_t num_elements - must be a non 0 number
+  // @return - The error code returned.
+  Status CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements);
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  virtual void Print(std::ostream &out, bool show_all) const;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param sampler - reference to the sampler to print
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const Sampler &sampler) {
+    sampler.Print(out, false);
+    return out;
+  }
+
+  // Checks if this sampler has a child sampler.
+  // @return - true if there is a child sampler, false otherwise.
+  bool HasChildSampler();
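+
+  // Illustrative mapping (editorial note): if the child sampler produced the ids [7, 3, 9], then
+  // GetAssociatedChildId(&out, 1) below yields out == 3.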
+
+  // Uses id as an index for the list of ids generated by the child sampler, and gets the
+  // associated id.
+  // @param int64_t *out_associated_id - Out parameter, contains the associated id.
+  // @param int64_t id - The id used as an index to get the associated child id.
+  // @return - The error code returned.
+  Status GetAssociatedChildId(int64_t *out_associated_id, int64_t id);
+
+ protected:
+  // Number of rows of data from the place this sampler is sampling from.  If this sampler
+  // has a child sampler, num_rows_ is the number of ids the child sampler will
+  // output.  Otherwise, num_rows_ is the number of rows in the dataset.
+  int64_t num_rows_;
+
+  // The user may want to sample less than the full amount of data.  num_samples_ reduces the number
+  // of ids returned as requested by the user.  Derived classes will choose how to sample the smaller
+  // amount.
+  int64_t num_samples_;
+
+  int64_t samples_per_buffer_;
+  std::unique_ptr<ColDescriptor> col_desc_;
+  std::vector<std::shared_ptr<Sampler>> child_;  // Child nodes
+  std::unique_ptr<DataBuffer> child_ids_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SAMPLER_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
new file mode 100644
index 0000000000..1cc4ac831a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
@@ -0,0 +1,102 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+
+#include <algorithm>
+#include <memory>
+
+namespace mindspore {
+namespace dataset {
+SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
+    : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
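+
+// Illustrative (editorial note): with start_index_ = 2 and num_samples_ = 3 over a 10-row dataset,
+// the produced id sequence is 2, 3, 4; InitSampler() below caps num_samples_ at
+// num_rows_ - start_index_.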
+ int64_t remaining_ids = num_samples_ - id_count_; + int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); + + RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); + auto idPtr = sampleIds->begin(); + for (int64_t i = 0; i < num_elements; i++) { + int64_t sampled_id = current_id_; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *idPtr = sampled_id; + current_id_++; // Move the current id to the next one in the sequence + idPtr++; + } + + id_count_ += num_elements; // Count the packed ids towards our overall sample count + + TensorRow row(1, sampleIds); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + } + return Status::OK(); +} + +Status SequentialSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); + CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); + // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample + // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. + int64_t available_row_count = num_rows_ - start_index_; + if (num_samples_ == 0 || num_samples_ > available_row_count) { + num_samples_ = available_row_count; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); + samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; + return Status::OK(); +} + +Status SequentialSampler::ResetSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); + current_id_ = start_index_; + id_count_ = 0; + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +void SequentialSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SequentialSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info + out << "\nStart index: " << start_index_; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h new file mode 100644 index 0000000000..c6ccd0d1eb --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ + +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class SequentialSampler : public Sampler { + public: + // Constructor + // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the + // full amount of ids from the dataset + // @param start_index - The starting index value + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit SequentialSampler(int64_t num_samples, int64_t start_index, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~SequentialSampler() = default; + + // init sampler, called by python + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status ResetSampler() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to corresponding Dataset Op + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. + // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + int64_t current_id_; // The id sequencer. Each new id increments from this + int64_t start_index_; // The starting id. current_id_ begins from here. + int64_t id_count_; // An internal counter that tracks how many ids have been produced +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SEQUENTIAL_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc new file mode 100644 index 0000000000..db2078795e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" + +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +// Constructor. +SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector &indices, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} + +// Initialized this Sampler. 
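+// Note: a num_samples of 0 follows the shared "sample everything" convention; here it resolves
+// to indices_.size(), and only the first num_samples_ entries of the shuffled index list are
+// ever handed out.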
+Status SubsetRandomSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); + + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // In this case, the id's are provided by the user. Cap the num_samples on the number of id's given. + if (num_samples_ == 0 || num_samples_ > static_cast(indices_.size())) { + num_samples_ = static_cast(indices_.size()); + } + // Initialize random generator with seed from config manager + rand_gen_.seed(GetSeed()); + + if (samples_per_buffer_ > num_samples_) { + samples_per_buffer_ = num_samples_; + } + + // num_samples_ could be smaller than the total number of input id's. + // We will shuffle the full set of id's, but only select the first num_samples_ of them later. + std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + + return Status::OK(); +} + +// Reset the internal variable to the initial state. +Status SubsetRandomSampler::ResetSampler() { + // Reset the internal counters. + sample_id_ = 0; + buffer_id_ = 0; + + // Randomized the indices again. + rand_gen_.seed(GetSeed()); + std::shuffle(indices_.begin(), indices_.end(), rand_gen_); + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +// Get the sample ids. +Status SubsetRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { + // All samples have been drawn + if (sample_id_ == num_samples_) { + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::shared_ptr outputIds; + + int64_t last_id = sample_id_ + samples_per_buffer_; + // Handling the return all samples at once, and when last draw is not a full batch. + if (last_id > num_samples_) { + last_id = num_samples_; + } + + // Allocate tensor + RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); + + // Initialize tensor + auto id_ptr = outputIds->begin(); + while (sample_id_ < last_id) { + if (indices_[sample_id_] >= num_rows_) { + std::string err_msg = + "Generated id is bigger than numRows (out of bound). 
indices_: " + std::to_string(indices_[sample_id_]) + + " num_rows_: " + std::to_string(num_rows_); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + int64_t sampled_id = indices_[sample_id_]; + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); + } + + *id_ptr = sampled_id; + id_ptr++; + sample_id_++; + } + + // Create a TensorTable from that single tensor and push into DataBuffer + (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); + } + + return Status::OK(); +} + +void SubsetRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: SubsetRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h new file mode 100644 index 0000000000..fccc15e57b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ + +#include +#include +#include + +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +// Randomly samples elements from a given list of indices, without replacement. +class SubsetRandomSampler : public Sampler { + public: + // Constructor. + // @param num_samples The number of samples to draw. 0 for the full amount. + // @param indices List of indices from where we will randomly draw samples. + // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). + // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. + explicit SubsetRandomSampler(int64_t num_samples, const std::vector &indices, + std::int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~SubsetRandomSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // Reset the internal variable to the initial state and reshuffle the indices. + // @return Status + Status ResetSampler() override; + + // Get the sample ids. + // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. + // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer. + Status GetNextSample(std::unique_ptr *out_buffer) override; + + // Printer for debugging purposes. 
+ // @param out - output stream to write to + // @param show_all - bool to show detailed vs summary + void Print(std::ostream &out, bool show_all) const override; + + private: + // A list of indices (already randomized in constructor). + std::vector indices_; + + // Current sample id. + int64_t sample_id_; + + // Current buffer id. + int64_t buffer_id_; + + // A random number generator. + std::mt19937 rand_gen_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_RANDOM_SAMPLER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc new file mode 100644 index 0000000000..13863143c0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +// Constructor. +WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement, + int64_t samples_per_buffer) + : Sampler(num_samples, samples_per_buffer), + weights_(weights), + replacement_(replacement), + sample_id_(0), + buffer_id_(0) {} + +// Initialized this Sampler. +Status WeightedRandomSampler::InitSampler() { + // Special value of 0 for num_samples means that the user wants to sample the entire set of data. + // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. + if (num_samples_ == 0 || num_samples_ > num_rows_) { + num_samples_ = num_rows_; + } + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); + CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); + + // Initialize random generator with seed from config manager + rand_gen_.seed(GetSeed()); + + samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; + + if (!replacement_) { + exp_dist_ = std::make_unique>(1); + InitOnePassSampling(); + } else { + discrete_dist_ = std::make_unique>(weights_.begin(), weights_.end()); + } + + return Status::OK(); +} + +// Initialized the computation for generating weighted random numbers without replacement using onepass method. +void WeightedRandomSampler::InitOnePassSampling() { + exp_dist_->reset(); + onepass_ids_.clear(); + std::vector> val_idx; + for (size_t i = 0; i < weights_.size(); i++) { + val_idx.emplace_back(std::make_pair((*exp_dist_)(rand_gen_) / weights_[i], i)); + } + + // Partial sort the first `numSamples` elements. 
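+ // Rationale: each id i was assigned the key Exp(1) / weights_[i] above, so a larger weight makes
+ // a smaller key more likely; keeping the num_samples_ smallest keys yields a weighted sample
+ // without replacement (the one-pass method cited in the header). For example, with hypothetical
+ // weights {1, 3}, id 1 ends up ahead of id 0 with probability 3/4. A full sort is unnecessary,
+ // hence partial_sort.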
+ std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); + for (int64_t i = 0; i < num_samples_; i++) { + onepass_ids_.push_back(val_idx[i].second); + } +} + +// Reset the internal variable to the initial state and reshuffle the indices. +Status WeightedRandomSampler::ResetSampler() { + sample_id_ = 0; + buffer_id_ = 0; + rand_gen_.seed(GetSeed()); + if (!replacement_) { + InitOnePassSampling(); + } else { + discrete_dist_->reset(); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->ResetSampler()); + } + + return Status::OK(); +} + +// Get the sample ids. +Status WeightedRandomSampler::GetNextSample(std::unique_ptr *out_buffer) { + if (weights_.size() > static_cast(num_rows_)) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); + } + + if (!replacement_ && (weights_.size() < static_cast(num_samples_))) { + RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); + } + + if (sample_id_ == num_samples_) { + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagEOE); + } else { + if (HasChildSampler()) { + RETURN_IF_NOT_OK(child_[0]->GetNextSample(&child_ids_)); + } + + (*out_buffer) = std::make_unique(buffer_id_++, DataBuffer::kDeBFlagNone); + std::shared_ptr outputIds; + + int64_t last_id = sample_id_ + samples_per_buffer_; + // Handling the return all samples at once, and when last draw is not a full batch. + if (last_id > num_samples_) { + last_id = num_samples_; + } + + // Allocate tensor. + RETURN_IF_NOT_OK(CreateSamplerTensor(&outputIds, last_id - sample_id_)); + + // Initialize tensor. + auto id_ptr = outputIds->begin(); + // Assign the data to tensor element. + while (sample_id_ < last_id) { + int64_t genId; + if (replacement_) { + genId = (*discrete_dist_)(rand_gen_); + } else { + // Draw sample without replacement. + genId = onepass_ids_.front(); + onepass_ids_.pop_front(); + } + + if (genId >= num_rows_) { + RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound)."); + } + + if (HasChildSampler()) { + RETURN_IF_NOT_OK(GetAssociatedChildId(&genId, genId)); + } + + *id_ptr = genId; + id_ptr++; + sample_id_++; + } + + // Create a TensorTable from that single tensor and push into DataBuffer + (*out_buffer)->set_tensor_table(std::make_unique(1, TensorRow(1, outputIds))); + } + + return Status::OK(); +} + +void WeightedRandomSampler::Print(std::ostream &out, bool show_all) const { + out << "\nSampler: WeightedRandomSampler"; + if (show_all) { + // Call the super class for displaying any common detailed info + Sampler::Print(out, show_all); + // Then add our own info if any + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h new file mode 100644 index 0000000000..b1a531abe9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_WEIGHTED_RANDOM_SAMPLER_H_
+
+#include
+#include
+#include
+#include
+
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+
+namespace mindspore {
+namespace dataset {
+// Samples elements from id `0, 1, ..., weights.size()-1` with given probabilities (weights).
+class WeightedRandomSampler : public Sampler {
+ public:
+ // Constructor.
+ // @param num_samples Number of samples to be drawn.
+ // @param weights A list of sample weights.
+ // @param replacement Determines whether samples are drawn with/without replacement.
+ // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
+ // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
+ WeightedRandomSampler(int64_t num_samples, const std::vector &weights, bool replacement,
+ int64_t samples_per_buffer = std::numeric_limits::max());
+
+ // Destructor.
+ ~WeightedRandomSampler() = default;
+
+ // Initialize the sampler.
+ // @return Status
+ Status InitSampler() override;
+
+ // Reset the internal variable to the initial state and reshuffle the indices.
+ Status ResetSampler() override;
+
+ // Get the sample ids.
+ // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed.
+ // @note the sample ids (int64_t) will be placed in one Tensor and be placed into pBuffer.
+ Status GetNextSample(std::unique_ptr *out_buffer) override;
+
+ // Printer for debugging purposes.
+ // @param out - output stream to write to
+ // @param show_all - bool to show detailed vs summary
+ void Print(std::ostream &out, bool show_all) const override;
+
+ private:
+ // A list of weights for each sample.
+ std::vector weights_;
+
+ // A flag indicating if samples are drawn with/without replacement.
+ bool replacement_;
+
+ // Current sample id.
+ int64_t sample_id_;
+
+ // Current buffer id.
+ int64_t buffer_id_;
+
+ // Random engine and device
+ std::mt19937 rand_gen_;
+
+ // Discrete distribution for generating weighted random numbers with replacement.
+ std::unique_ptr> discrete_dist_;
+
+ // Exponential distribution for generating weighted random numbers without replacement.
+ // Based on "Accelerating weighted random sampling without replacement" by Kirill Muller.
+ std::unique_ptr> exp_dist_;
+
+ // Initializes the computation for generating weighted random numbers without replacement
+ // using the onepass method.
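+ // (Each reset re-runs the race: O(n) fresh draws plus an O(n log k) partial sort for
+ // n weights and k = num_samples_; see InitOnePassSampling in the .cc file.)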
+ void InitOnePassSampling(); + + // Store the random weighted ids generated by onepass method in `InitOnePassSampling` + std::deque onepass_ids_; +}; +} // namespace dataset +} // namespace mindspore + +#endif diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc new file mode 100644 index 0000000000..c1f5b13a94 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.cc @@ -0,0 +1,498 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "minddata/dataset/engine/datasetops/source/text_file_op.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +TextFileOp::Builder::Builder() + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_shuffle_files_(false), + builder_sampler_(nullptr) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status TextFileOp::Builder::ValidateInputs() const { + std::string err_msg; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greater than 0\n" : ""; + err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TextFileOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! 
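+ // Workers are assigned whole files, so with (hypothetically) 8 configured workers but only
+ // 3 input files, 5 workers would sit permanently idle; cap parallelism at the file count instead.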
+ if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { + builder_num_workers_ = builder_text_files_list_.size(); + MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + + std::shared_ptr text_file_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, + std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, + builder_num_devices_, builder_device_id_, std::move(builder_sampler_)); + RETURN_IF_NOT_OK(text_file_op->Init()); + *op = std::move(text_file_op); + + return Status::OK(); +} + +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, + std::unique_ptr schema, std::vector text_files_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id, + std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_rows), + text_files_list_(std::move(text_files_list)), + shuffle_files_(shuffle_files), + data_schema_(std::move(schema)), + all_num_rows_(0), + num_rows_per_shard_(0), + filename_index_(std::make_unique()), + finished_reading_dataset_(false), + load_io_block_queue_(true), + load_jagged_connector_(true) { + worker_connector_size_ = worker_connector_size; +} + +// A print method typically used for debugging +void TextFileOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ + << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") + << "\nText files list:\n"; + for (int i = 0; i < text_files_list_.size(); ++i) { + out << " " << text_files_list_[i]; + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} + +Status TextFileOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + return Status::OK(); +} + +Status TextFileOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {line}, TensorShape::CreateScalar())); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + if (line.empty()) { + continue; + } + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. 
+ if (rows_total < start_offset) { + rows_total++; + continue; + } + + RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + + return Status::OK(); +} + +Status TextFileOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// Pops an element from a queue in io_block_queues +Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TextFileOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. 
+Status TextFileOp::PostEndOfEpoch(int32_t queue_index) {
+ for (int i = 0; i < num_workers_; ++i) {
+ std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe);
+ RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe)));
+ }
+
+ return Status::OK();
+}
+
+static void ShuffleKeys(std::vector *i_keys, uint32_t seed) {
+ std::mt19937 rng(seed);
+ std::shuffle(i_keys->begin(), i_keys->end(), rng);
+}
+
+bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
+ const int64_t &pre_count) {
+ *start_offset = 0;
+ *end_offset = 0;
+ bool push = false;
+ int64_t start_index = device_id_ * num_rows_per_shard_;
+ if (device_id_ + 1 < 0) {
+ MS_LOG(ERROR) << "Device id is invalid";
+ return false;
+ }
+
+ int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_;
+ if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) {
+ *start_offset = start_index - pre_count;
+ push = true;
+ if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) {
+ *end_offset = end_index - pre_count;
+ } else {
+ *end_offset = filename_numrows_[file_name];
+ }
+ }
+
+ if (pre_count >= start_index && pre_count < end_index) {
+ *start_offset = 0;
+ push = true;
+ if (pre_count + filename_numrows_[file_name] >= end_index) {
+ *end_offset = end_index - pre_count;
+ } else {
+ *end_offset = filename_numrows_[file_name];
+ }
+ }
+
+ return push;
+}
+
+Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) {
+ int32_t queue_index = 0;
+ int64_t pre_count = 0;
+ int64_t start_offset = 0;
+ int64_t end_offset = 0;
+ bool finish = false;
+ while (!finish) {
+ std::vector> file_index;
+ if (!i_keys.empty()) {
+ for (auto it = i_keys.begin(); it != i_keys.end(); ++it) {
+ {
+ if (!load_io_block_queue_) {
+ break;
+ }
+ }
+ file_index.emplace_back(std::pair((*filename_index_)[*it], *it));
+ }
+ } else {
+ for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
+ {
+ if (!load_io_block_queue_) {
+ break;
+ }
+ }
+ file_index.emplace_back(std::pair(it.value(), it.key()));
+ }
+ }
+ for (auto file_info : file_index) {
+ if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) {
+ auto ioBlock =
+ std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone);
+ RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
+ queue_index = (queue_index + 1) % num_workers_;
+ }
+
+ pre_count += filename_numrows_[file_info.first];
+ }
+
+ if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) {
+ finish = false;
+ } else {
+ finish = true;
+ }
+ }
+
+ RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index));
+ return Status::OK();
+}
+
+Status TextFileOp::WaitToFillIOBlockQueue() {
+ // must be called first if called by a worker spawned by the task group
+ TaskManager::FindMe()->Post();
+
+ std::vector i_keys;
+ if (shuffle_files_) {
+ for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
+ i_keys.push_back(it.key());
+ }
+ }
+ uint32_t seed = 0;
+ while (true) {
+ RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait());
+ io_block_queue_wait_post_.Clear();
+
+ if (finished_reading_dataset_) {
+ break;
+ }
+
+ if (shuffle_files_) {
+ ShuffleKeys(&i_keys, num_devices_ == 1 ?
GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +Status TextFileOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); + + // Read data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks())); + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (total_rows_ == 0 || rows_read < total_rows_) { + if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { + int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +int64_t TextFileOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + if (!line.empty()) { + count++; + } + } + + return count; +} + +Status TextFileOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API TextFileDataset.Please check file path or dataset API " + "validation first."); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} + +Status TextFileOp::ComputeColMap() { + // Set the column name mapping (base class field) + if (column_name_id_map_.empty()) 
{ + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h new file mode 100644 index 0000000000..68c226ab80 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/text_file_op.h @@ -0,0 +1,289 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; + +class TextFileOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. 
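+ // @param files_list - the list of text file paths to read (copied into the builder).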
+ Builder &SetTextFilesList(const std::vector &files_list) {
+ builder_text_files_list_ = files_list;
+ return *this;
+ }
+
+ // Setter method.
+ // @return Builder - setter method returns reference to the builder.
+ Builder &SetShuffleFiles(bool shuffle_files) {
+ builder_shuffle_files_ = shuffle_files;
+ return *this;
+ }
+
+ // Setter method.
+ // @return Builder - setter method returns reference to the builder.
+ Builder &SetTotalRows(int64_t total_rows) {
+ builder_total_rows_ = total_rows;
+ return *this;
+ }
+
+ // Setter method.
+ // @param std::shared_ptr sampler
+ // @return Builder - setter method returns reference to the builder.
+ Builder &SetSampler(std::shared_ptr sampler) {
+ builder_sampler_ = std::move(sampler);
+ return *this;
+ }
+
+ private:
+ int32_t builder_device_id_;
+ int32_t builder_num_devices_;
+ int32_t builder_num_workers_;
+ int32_t builder_op_connector_size_;
+ int64_t builder_rows_per_buffer_;
+ int64_t builder_total_rows_;
+ int32_t builder_worker_connector_size_;
+ std::vector builder_text_files_list_;
+ bool builder_shuffle_files_;
+ std::unique_ptr builder_schema_;
+ std::shared_ptr builder_sampler_;
+ };
+
+ // Constructor of TextFileOp
+ // @note The builder class should be used to call this constructor.
+ // @param num_workers - number of worker threads reading data from the text files.
+ // @param rows_per_buffer - number of rows that a full buffer will contain.
+ // @param total_rows - number of rows to read.
+ // @param worker_connector_size - size of the internal connector queue used by each worker.
+ // @param schema - the data schema object.
+ // @param text_files_list - list of filepaths for the dataset files.
+ // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
+ // @param shuffle_files - whether or not to shuffle the files before reading data.
+ // @param num_devices - number of devices (shards) the rows are divided across.
+ // @param device_id - id of the shard this op reads for.
+ // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes
+ TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
+ std::unique_ptr, std::vector text_files_list, int32_t op_connector_size,
+ bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr sampler);
+
+ // Default destructor
+ ~TextFileOp() = default;
+
+ // A print method typically used for debugging
+ // @param out - The output stream to write output to
+ // @param show_all - A bool to control if you want to show all info or just a summary
+ void Print(std::ostream &out, bool show_all) const override;
+
+ // Instantiates the internal queues and connectors
+ // @return Status - the error code returned
+ Status Init();
+
+ // Class functor operator () override.
+ // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
+ // provide the master loop that drives the logic for performing the work
+ // @return Status - the error code returned.
+ Status operator()() override;
+
+ // Overrides base class reset method. Cleans up any state info from its previous execution and
+ // reinitializes itself so that it can be executed again, as if it was just created.
+ // @return Status - the error code returned.
+ Status Reset() override;
+
+ // Get total rows in files.
+ // @param files - all text files.
+ // @param count - number of rows.
+ // @return Status - the error code returned.
+ static Status CountAllFileRows(const std::vector &files, int64_t *count); + + // Op name getter + // @return Name of the current Op + std::string Name() const override { return "TextFileOp"; } + + // File names getter + // @return Vector of the input file names + std::vector FileNames() { return text_files_list_; } + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a text file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - text file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. 
When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + // Private function for computing the assignment of the column name map. + // @return - Status + Status ComputeColMap() override; + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t total_rows_; + std::vector text_files_list_; + bool shuffle_files_; + std::unique_ptr data_schema_; + int64_t all_num_rows_; + int64_t num_rows_per_shard_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + bool finished_reading_dataset_; + bool load_io_block_queue_; + bool load_jagged_connector_; + std::unique_ptr jagged_buffer_connector_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc new file mode 100644 index 0000000000..ae7907b5ce --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.cc @@ -0,0 +1,1054 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "proto/example.pb.h" +#include "./securec.h" +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/source/io_block.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/jagged_connector.h" +#include "minddata/dataset/engine/opt/pass.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/queue.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/wait_post.h" +#include "utils/system/crc32c.h" + +namespace mindspore { +namespace dataset { +TFReaderOp::Builder::Builder() + : builder_device_id_(0), + builder_num_devices_(1), + builder_total_rows_(0), + builder_equal_rows_per_shard_(false), + builder_sampler_(nullptr) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_shuffle_files_ = false; + builder_data_schema_ = std::make_unique(); +} + +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + +Status TFReaderOp::Builder::ValidateInputs() const { + std::string err_msg; + + if (builder_num_workers_ <= 0) { + err_msg += "Number of parallel workers is smaller or equal to 0\n"; + } + + if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { + err_msg += "Wrong sharding configs\n"; + } + + std::vector invalid_files(builder_dataset_files_list_.size()); + auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + + if (!invalid_files.empty()) { + err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; + } + + return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_dataset_files_list_.size()) { + builder_num_workers_ = builder_dataset_files_list_.size(); + MS_LOG(WARNING) << "TFReader operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + std::shared_ptr new_tf_reader_op = std::make_shared( + builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, + builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, + builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_, + std::move(builder_sampler_)); + + RETURN_IF_NOT_OK(new_tf_reader_op->Init()); + *out_tf_reader_op = std::move(new_tf_reader_op); + return Status::OK(); +} + +TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, + int64_t total_num_rows, std::vector dataset_files_list, + std::unique_ptr data_schema, int32_t op_connector_size, + std::vector columns_to_load, bool shuffle_files, int32_t num_device, + int32_t device_id, bool equal_rows_per_shard, std::shared_ptr sampler) + : ParallelOp(num_workers, op_connector_size, std::move(sampler)), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + total_rows_(total_num_rows), + dataset_files_list_(std::move(dataset_files_list)), + columns_to_load_(std::move(columns_to_load)), + finished_reading_dataset_(false), + shuffle_files_(shuffle_files), + data_schema_(std::move(data_schema)), + filename_index_(std::make_unique()), + load_io_block_queue_(true), + load_jagged_connector_(true), + num_rows_(0), + num_rows_per_shard_(0), + equal_rows_per_shard_(equal_rows_per_shard) { + worker_connector_size_ = worker_connector_size; +} + +// A print method typically used for debugging +void TFReaderOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ + << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? 
"yes" : "no") + << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; + for (int i = 0; i < dataset_files_list_.size(); ++i) { + out << " " << dataset_files_list_[i]; + } + if (!columns_to_load_.empty()) { + out << "\nColumns to load:\n"; + for (int i = 0; i < columns_to_load_.size(); ++i) { + out << " " << columns_to_load_[i]; + } + } + out << "\nData Schema:\n"; + out << *data_schema_ << "\n\n"; + } +} + +Status TFReaderOp::Init() { + if (data_schema_->Empty()) { + RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_)); + } + + if (total_rows_ == 0) { + total_rows_ = data_schema_->num_rows(); + } + if (total_rows_ < 0) { + RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); + } + + // Build the index with our files such that each file corresponds to a key id. + RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); + + // The creation of the internal connector has been delayed until now, since we may have adjusted the + // number of workers. Now that the worker count is established, create the connector now in the + // parallel op base. + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + + // temporary: make size large enough to hold all files + EOE to avoid hangs + int32_t safe_queue_size = static_cast(std::ceil(dataset_files_list_.size() / num_workers_)) + 1; + io_block_queues_.Init(num_workers_, safe_queue_size); + + return Status::OK(); +} + +Status TFReaderOp::CalculateNumRowsPerShard() { + if (!equal_rows_per_shard_) { + return Status::OK(); + } + + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + std::vector file(1, it.value()); + int64_t num = CountTotalRowsSectioned(file, 0, 1); + filename_numrows_[it.value()] = num; + num_rows_ += num; + } + num_rows_per_shard_ = static_cast(std::ceil(num_rows_ * 1.0 / num_devices_)); + if (num_rows_per_shard_ == 0) { + RETURN_STATUS_UNEXPECTED( + "There is no valid data matching the dataset API TFRecordDataset.Please check file path or dataset API " + "validation first."); + } + return Status::OK(); +} +// Class functor operator () override. +// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will +// provide the master loop that drives the logic for performing the work +Status TFReaderOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling mIOBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TFReaderOp::WaitToFillIOBlockQueue, this))); + + // launch num_workers_ worker threads, responsible for pulling from the IOBlockQueue and reading + // data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TFReaderOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. 
workers can't be spawned after this post,
+ // so workers have to be kept alive until the end of the program
+ TaskManager::FindMe()->Post();
+
+ RETURN_IF_NOT_OK(io_block_queue_wait_post_.Register(tree_->AllTasks()));
+
+ NotifyToFillIOBlockQueue();
+ while (!finished_reading_dataset_) {
+ int64_t buffer_id = 0;
+ int32_t workers_done = 0;
+ int64_t rows_read = 0;
+ {
+ std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
+ load_io_block_queue_ = true;
+ }
+
+ while (workers_done < num_workers_) {
+ std::unique_ptr<DataBuffer> fetched_buffer;
+ RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &fetched_buffer));
+ if (fetched_buffer->eoe()) {
+ workers_done++;
+ } else if (total_rows_ == 0 || rows_read < total_rows_) {
+ // we need to push a buffer
+ if (total_rows_ > 0 && rows_read + fetched_buffer->NumRows() > total_rows_) {
+ // this is the last buffer we need, and we only need a part of it
+ int64_t rows_to_remove = fetched_buffer->NumRows() - (total_rows_ - rows_read);
+ RETURN_IF_NOT_OK(fetched_buffer->SliceOff(rows_to_remove));
+ }
+
+ rows_read += fetched_buffer->NumRows();
+ fetched_buffer->set_id(buffer_id);
+ buffer_id++;
+ RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
+ } else {
+ // the user specified the number of rows they want, and we have read enough rows
+ //
+ // IOBlockQueue thread needs to:
+ // -stop pushing IO blocks to the IOBlockQueue
+ // -call PostEndOfEpoch (will send EOE)
+ // -wait for reset
+ //
+ // Worker threads need to:
+ // -stop reading the file they are currently reading and throw it away
+ // -keep pulling, but don't read other files (eventually skips all IOBlocks and will get EOE)
+ //
+ // Master thread needs to:
+ // -tell the IOBlockQueue thread to stop pushing
+ // -tell the worker threads to stop reading the file they are currently reading
+ // -keep pulling until EOE
+
+ // don't think we need a lock for now
+ load_jagged_connector_ = false;
+
+ std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
+ load_io_block_queue_ = false;
+ }
+ }
+
+ // all workers finished reading for this epoch, and we have read all the data from all workers
+ std::unique_ptr<DataBuffer> eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+ RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
+
+ if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
+ finished_reading_dataset_ = true;
+ NotifyToFillIOBlockQueue();
+ } else {
+ jagged_buffer_connector_->DoReset();
+ buffer_id = 0;
+ }
+ }
+
+ std::unique_ptr<DataBuffer> eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+ RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
+
+ RETURN_IF_NOT_OK(PostEndOfData());
+
+ return Status::OK();
+}
+
+// static local-only helper function
+static void shuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed) {
+ std::mt19937 rng(seed);
+ std::shuffle(i_keys->begin(), i_keys->end(), rng);
+}
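A note on shuffleKeys above: FillIOBlockShuffle assigns files to devices by position in the shuffled key list, so every shard must compute the identical permutation each epoch. That works only because std::shuffle driven by an explicitly seeded std::mt19937 is a pure function of the seed. A standalone sketch of the property (the key values and seed below are illustrative, not taken from this patch):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

int main() {
  std::vector<int64_t> shard0_keys = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int64_t> shard1_keys = shard0_keys;

  // Each shard seeds its own std::mt19937 with the same value, exactly as
  // shuffleKeys() does, so std::shuffle yields the same permutation on both.
  const uint32_t epoch_seed = 17;  // illustrative; the op uses GetSeed() or an epoch counter
  std::mt19937 rng0(epoch_seed);
  std::mt19937 rng1(epoch_seed);
  std::shuffle(shard0_keys.begin(), shard0_keys.end(), rng0);
  std::shuffle(shard1_keys.begin(), shard1_keys.end(), rng1);

  // Identical order is what lets "key_index % num_devices == device_id"
  // partition the files into disjoint shards.
  assert(shard0_keys == shard1_keys);
  return 0;
}
```

+// The entry point for when workers are launched. 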
+Status TFReaderOp::WorkerEntry(int32_t worker_id) { + // must be called first if called by worker spawned by taskgroup + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + MS_LOG(DEBUG) << "TFReader operator worker " << worker_id << " loaded file " << filename << "."; + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(1, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TFReaderOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TFReaderOp::FillIOBlockShuffle(const std::vector &i_keys) { + int32_t queue_index = 0; + int32_t key_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + bool end_of_epoch = false; + while (!finish) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + std::unique_lock lock(load_io_block_queue_mutex_); + if (load_io_block_queue_ == false) { + end_of_epoch = true; + break; + } + } + if (!equal_rows_per_shard_) { + if (key_index++ % num_devices_ == device_id_) { + auto ioBlock = std::make_unique(*it, kInvalidOffset, 
kInvalidOffset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + } else { + // Do an index lookup using that key to get the filename. + std::string file_name = (*filename_index_)[*it]; + if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + auto ioBlock = std::make_unique(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset; + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_name]; + } + } + if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && + !end_of_epoch) { + finish = false; + } else { + finish = true; + } + } + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TFReaderOp::FillIOBlockNoShuffle() { + int32_t queue_index = 0; + int32_t key_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + bool end_of_epoch = false; + while (!finish) { + // Iterate over all the keys and add one key to each block. + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + std::unique_lock lock(load_io_block_queue_mutex_); + if (load_io_block_queue_ == false) { + end_of_epoch = true; + break; + } + } + if (!equal_rows_per_shard_) { + if (key_index++ % num_devices_ == device_id_) { + auto ioBlock = + std::make_unique(it.key(), kInvalidOffset, kInvalidOffset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + } else { + std::string file_name = it.value(); + if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) { + auto ioBlock = std::make_unique(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_name]; + } + } + if (equal_rows_per_shard_ && pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_ && + !end_of_epoch) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. +Status TFReaderOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spawned by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + // Generate a vector of keys that we can shuffle + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + shuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + RETURN_IF_NOT_OK(FillIOBlockShuffle(i_keys)); + } else { // shuffle_files_ == false + RETURN_IF_NOT_OK(FillIOBlockNoShuffle()); + } + } + + return Status::OK(); +} + +// Notifies the thread which called WaitToFillIOBlockQueue to resume execution. 
+void TFReaderOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +// Pops an element from a queue in io_block_queues +Status TFReaderOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TFReaderOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Reads a tf_file file and loads the data into multiple buffers. +Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, + const int32_t &worker_id) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + RETURN_STATUS_UNEXPECTED("failed to open file: " + filename); + } + + int64_t rows_read = 0; + int64_t rows_total = 0; + std::unique_ptr current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + std::unique_ptr new_tensor_table = std::make_unique(); + + while (reader.peek() != EOF) { + if (!load_jagged_connector_) { + break; + } + + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // read serialized Example + std::string serialized_example; + serialized_example.resize(record_length); + (void)reader.read(&serialized_example[0], static_cast(record_length)); + if (start_offset == kInvalidOffset || (rows_total >= start_offset && rows_total < end_offset)) { + dataengine::Example tf_file; + if (!tf_file.ParseFromString(serialized_example)) { + std::string errMsg = "parse tfrecord failed"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); + rows_read++; + } + + // ignore crc footer + (void)reader.ignore(static_cast(sizeof(int32_t))); + rows_total++; + + if (rows_read == rows_per_buffer_) { + current_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); + + current_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + new_tensor_table = std::make_unique(); + rows_read = 0; + } + } + + if (rows_read > 0) { + current_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(current_buffer))); + } + + return Status::OK(); +} + +// Parses a single row and puts the data into a tensor table. +Status TFReaderOp::LoadExample(const dataengine::Example *tf_file, std::unique_ptr *tensor_table, + int64_t row) { + int32_t num_columns = data_schema_->NumColumns(); + TensorRow newRow(num_columns, nullptr); + (*tensor_table)->push_back(std::move(newRow)); + + for (int32_t col = 0; col < num_columns; ++col) { + const ColDescriptor current_col = data_schema_->column(col); + const dataengine::Features &example_features = tf_file->features(); + const google::protobuf::Map &feature_map = example_features.feature(); + const dataengine::Feature &column_values_list = feature_map.at(current_col.name()); + RETURN_IF_NOT_OK(LoadFeature(tensor_table, column_values_list, current_col, row, col)); + } + + return Status::OK(); +} + +// Parses a single cell and puts the data into a tensor table. 
+Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorTable> *tensor_table,
+ const dataengine::Feature &column_values_list, const ColDescriptor &current_col,
+ int64_t row, int32_t col) {
+ const dataengine::Feature::KindCase column_list_type = column_values_list.kind_case();
+ std::unique_ptr<float[]> float_array; // For staging data from protobuf deserialization
+ const unsigned char *data_ptr = nullptr; // Generic pointer used for populating the Tensor
+
+ // This variable will point into the above staging variables.
+ // Also used for creating shape attributes.
+ int32_t num_elements = 0;
+
+ // where possible we build the tensor first and read directly into it; the float list path
+ // stages into a local array instead because its values need a copy
+ std::shared_ptr<Tensor> ts;
+
+ // Depending on the type of data from the tf_file, we want to extract 2 things:
+ // 1) A pointer to the data as a const unsigned char *
+ // 2) The number of elements of the data
+ // After those are determined, we can then build the tensor to represent this data.
+ switch (column_list_type) {
+ case dataengine::Feature::KindCase::kBytesList: {
+ RETURN_IF_NOT_OK(LoadBytesList(current_col, column_values_list, &num_elements, &ts));
+
+ break;
+ }
+ case dataengine::Feature::KindCase::kFloatList: {
+ RETURN_IF_NOT_OK(LoadFloatList(current_col, column_values_list, &num_elements, &float_array));
+
+ data_ptr = reinterpret_cast<const unsigned char *>(float_array.get());
+
+ // only the float list needs to create the tensor here; the other two lists read directly
+ // into the tensor
+ TensorShape current_shape = TensorShape::CreateUnknownRankShape();
+ RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(num_elements, &current_shape));
+ RETURN_IF_NOT_OK(
+ Tensor::CreateTensor(&ts, current_col.tensorImpl(), current_shape, current_col.type(), data_ptr));
+ break;
+ }
+ case dataengine::Feature::KindCase::kInt64List: {
+ RETURN_IF_NOT_OK(LoadIntListSwitch(current_col, column_values_list, &num_elements, &ts));
+ break;
+ }
+ case dataengine::Feature::KindCase::KIND_NOT_SET: {
+ std::string err_msg = "tf_file column list type enum is KIND_NOT_SET";
+ RETURN_STATUS_UNEXPECTED(err_msg);
+ }
+ default: {
+ std::string err_msg = "tf_file column list type enum does not match any known DE type";
+ RETURN_STATUS_UNEXPECTED(err_msg);
+ }
+ }
+
+ (**tensor_table)[row][col] = std::move(ts);
+
+ return Status::OK();
+}
+
+// Overrides the base class reset method. Cleans up any state info from its previous execution and
+// reinitializes itself so that it can be executed again, as if it was just created.
+Status TFReaderOp::Reset() {
+ // start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
+ load_jagged_connector_ = true;
+
+ {
+ std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
+ load_io_block_queue_ = true;
+ }
+
+ RETURN_IF_NOT_OK(ParallelOp::Reset());
+ NotifyToFillIOBlockQueue();
+
+ return Status::OK();
+}
+
+Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
+ int32_t *num_elements, std::shared_ptr<Tensor> *tensor) {
+ // kBytesList can map to the following DE types ONLY!
+ // DE_UINT8, DE_INT8
+ // Must be a single-byte type for each element! 
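+ // (DE_STRING is also accepted and is handled separately below; for DE_UINT8/DE_INT8 every byte
+ // of a bytes_list value becomes one tensor element, e.g. "abc" -> {0x61, 0x62, 0x63}.)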
+ if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8 && + current_col.type() != DataType::DE_STRING) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::BytesList &bytes_list = column_values_list.bytes_list(); + + *num_elements = bytes_list.value_size(); + + if (current_col.type() == DataType::DE_STRING) { + TensorShape shape = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape)); + return Status::OK(); + } + + uint64_t max_size = 0; + for (uint32_t i = 0; i < bytes_list.value_size(); ++i) max_size = std::max(max_size, bytes_list.value(i).size()); + + int64_t pad_size = max_size; + + // if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn + if (current_col.hasShape()) { + TensorShape cur_shape = current_col.shape(); + if (cur_shape.Size() >= 2 && cur_shape[0] == TensorShape::kDimUnknown) { + int64_t new_pad_size = 1; + for (int i = 1; i < cur_shape.Size(); ++i) { + if (cur_shape[i] == TensorShape::kDimUnknown) { + std::string err_msg = "More than one unknown dimension in the shape of column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + new_pad_size *= cur_shape[i]; + } + pad_size = new_pad_size; + } + } + + // know how many elements there are and the total bytes, create tensor here: + TensorShape current_shape = TensorShape::CreateScalar(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape((*num_elements) * pad_size, ¤t_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, current_shape, current_col.type(), pad_size)); + + return Status::OK(); +} + +Status TFReaderOp::LoadFloatList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::unique_ptr *float_array) { + // KFloatList can only map to DE types: + // DE_FLOAT32 + if (current_col.type() != DataType::DE_FLOAT32) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::FloatList &float_list = column_values_list.float_list(); + + // Identify how many values we have and then create a local array of these + // to deserialize into + *num_elements = float_list.value_size(); + *float_array = std::make_unique(*num_elements); + for (int i = 0; i < float_list.value_size(); ++i) { + (*float_array)[i] = float_list.value(i); + } + + return Status::OK(); +} + +// Determines which template type to use and calls LoadIntList +Status TFReaderOp::LoadIntListSwitch(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor) { + if (current_col.type() == DataType::DE_UINT64) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT64) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT32) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT32) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT16) { + RETURN_IF_NOT_OK(LoadIntList(current_col, 
column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT16) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_UINT8) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else if (current_col.type() == DataType::DE_INT8) { + RETURN_IF_NOT_OK(LoadIntList(current_col, column_values_list, num_elements, tensor)); + } else { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +// Reads values from a bytes list and casts the value to type T, must be an integral type +// compatible with int64_t +template +Status TFReaderOp::LoadIntList(const ColDescriptor ¤t_col, const dataengine::Feature &column_values_list, + int32_t *num_elements, std::shared_ptr *tensor) { + if (!(current_col.type().IsInt())) { + std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + const dataengine::Int64List &int64_list = column_values_list.int64_list(); + + // Identify how many values we have and then create a local array of these + // to deserialize into + *num_elements = int64_list.value_size(); + + // know how many elements there are, create tensor here: + TensorShape current_shape = TensorShape::CreateUnknownRankShape(); + RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, ¤t_shape)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type())); + + // Tensors are lazily allocated, this eagerly allocates memory for the tensor. + RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes())); + + int64_t i = 0; + auto it = (*tensor)->begin(); + for (; it != (*tensor)->end(); i++, ++it) { + T element = static_cast(int64_list.value(i)); + *it = element; + } + + return Status::OK(); +} + +Status TFReaderOp::CreateSchema(const std::string tf_file, std::vector columns_to_load) { + std::ifstream reader; + reader.open(tf_file); + + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // read serialized Example + std::string serialized_example; + serialized_example.resize(record_length); + (void)reader.read(&serialized_example[0], static_cast(record_length)); + + dataengine::Example example; + if (!example.ParseFromString(serialized_example)) RETURN_STATUS_UNEXPECTED("parse tf_file failed"); + + const dataengine::Features &example_features = example.features(); + const google::protobuf::Map &feature_map = example_features.feature(); + + if (columns_to_load.empty()) { + (void)std::transform(feature_map.begin(), feature_map.end(), std::back_inserter(columns_to_load), + [](const auto &it) -> std::string { return it.first; }); + std::sort(columns_to_load.begin(), columns_to_load.end()); + } + + for (const auto &curr_col_name : columns_to_load) { + auto it = feature_map.find(curr_col_name); + if (it == feature_map.end()) { + RETURN_STATUS_UNEXPECTED("Failed to find column " + curr_col_name); + } + std::string column_name = it->first; + + std::string column_type; + + const dataengine::Feature &feature = it->second; + const dataengine::Feature::KindCase kind_case = feature.kind_case(); + switch (kind_case) { + case dataengine::Feature::KindCase::kBytesList: + 
column_type = "uint8"; + break; + + case dataengine::Feature::KindCase::kFloatList: + column_type = "float32"; + break; + + case dataengine::Feature::KindCase::kInt64List: + column_type = "int64"; + break; + + case dataengine::Feature::KindCase::KIND_NOT_SET: + RETURN_STATUS_UNEXPECTED("trying to make schema, tf_file column list type enum is KIND_NOT_SET"); + + default: + RETURN_STATUS_UNEXPECTED( + "trying to make schema, tf_file column list type enum does not match any known DE type"); + } + + RETURN_IF_NOT_OK( + data_schema_->AddColumn(ColDescriptor(column_name, DataType(column_type), TensorImpl::kFlexible, 1))); + } + + return Status::OK(); +} + +Status TFReaderOp::CountTotalRows(int64_t *out_total_rows, const std::vector &filenames, int64_t threads, + bool estimate) { + try { + if (threads > filenames.size()) { + threads = filenames.size(); + } + + std::vector> async_results; + + int64_t chunk_size = filenames.size() / threads; + int64_t remainder = filenames.size() % threads; + + int64_t begin = 0; + int64_t end = begin; + for (int i = 0; i < threads; i++) { + end += chunk_size; + if (remainder > 0) { + end++; + remainder--; + } + + if (estimate) { + // Parse a single file for each chunk with estimate mode on + async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, begin + 1)); + } else { + // Parse the whole chunk with estimate mode off + async_results.push_back(std::async(std::launch::async, &CountTotalRowsSectioned, filenames, begin, end)); + } + + begin = end; + } + + int64_t total_rows = 0; + for (int i = 0; i < async_results.size(); i++) { + total_rows += async_results[i].get(); + } + + if (estimate) { + // Each thread only scans 1 file + // Estimated total rows = Average rows * total number of files + total_rows = total_rows / threads * filenames.size(); + } + + *out_total_rows = total_rows; + } catch (const std::exception &e) { + std::string err_msg = "Unexpected error occurred: "; + err_msg += e.what(); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + return Status::OK(); +} + +int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector &filenames, int64_t begin, int64_t end) { + int64_t rows_read = 0; + for (int i = begin; i < end; i++) { + std::ifstream reader; + reader.open(filenames[i]); + if (!reader) { + MS_LOG(DEBUG) << "TFReader operator failed to open file " << filenames[i] << "."; + } + + while (reader.peek() != EOF) { + // read length + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // ignore crc header + (void)reader.ignore(static_cast(sizeof(int32_t))); + + // ignore tf_file contents + (void)reader.ignore(static_cast(record_length)); + + // ignore crc footer + (void)reader.ignore(static_cast(sizeof(int32_t))); + + rows_read++; + } + } + + return rows_read; +} + +// Visitor accept method for NodePass +Status TFReaderOp::Accept(NodePass *p, bool *modified) { + // Downcast shared pointer then call visitor + return p->RunOnNode(shared_from_base(), modified); +} + +Status TFReaderOp::ComputeColMap() { + // Construct the column name map for this operator (base class field) + if (column_name_id_map_.empty()) { + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + column_name_id_map_[data_schema_->column(i).name()] = i; + } + } else { + MS_LOG(WARNING) << "Column name map is already set!"; + } + return Status::OK(); +} + +// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing +// a sampler for 
fetching the data. As such, any options in the tf reader need to be reset to its defaults so +// that this tf reader will produce the full set of data into the cache. +void TFReaderOp::MakeSimpleProducer() { + device_id_ = 0; + num_devices_ = 1; + total_rows_ = 0; + shuffle_files_ = false; + equal_rows_per_shard_ = false; +} + +// During tree prepare phase, operators may have specific post-operations to perform depending on +// their role. +Status TFReaderOp::PrepareNodePostAction() { + // Run common code from super class before adding TFReaderOp specific handling + RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction()); + + // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into + // a simpler producer of all data (no shuffling or sharding or anything) + if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) { + // This sanity check had been delayed until now in the prepare loop. + // If we are not in a cache path, then we can validate the file-based sharding config. + // If we are in a cache path, there is no file-based sharding so the check is not correct in that + // situation. + if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast(num_devices_)) { + RETURN_STATUS_UNEXPECTED("Not enough tfrecord files provided\n"); + } + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h new file mode 100644 index 0000000000..c03f3957e9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/tf_reader_op.h @@ -0,0 +1,420 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/util/wait_post.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/data_schema.h" +#include "minddata/dataset/engine/datasetops/parallel_op.h" + +namespace dataengine { +class Example; +class Feature; +class BytesList; +} // namespace dataengine + +namespace mindspore { +namespace dataset { +template +class Queue; + +template +class Connector; + +class JaggedConnector; +class FilenameBlock; + +using StringIndex = AutoIndexObj; + +class TFReaderOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. 
+ Status ValidateInputs() const; + + Status Build(std::shared_ptr *out_tf_reader_op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDataSchema(std::unique_ptr data_schema) { + builder_data_schema_ = std::move(data_schema); + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetWorkerConnectorSize(int32_t size) { + builder_worker_connector_size_ = size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &setTotalRows(int64_t total_rows) { + builder_total_rows_ = total_rows; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDatasetFilesList(const std::vector &dataset_files_list) { + builder_dataset_files_list_ = dataset_files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetColumnsToLoad(const std::vector &columns_to_load) { + builder_columns_to_load_ = columns_to_load; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShardEqualRows(bool shard_equal_rows) { + builder_equal_rows_per_shard_ = shard_equal_rows; + return *this; + } + + // Setter method + // @param std::shared_ptr sampler + // @return Builder setter method returns reference to the builder. + Builder &SetSampler(std::shared_ptr sampler) { + builder_sampler_ = std::move(sampler); + return *this; + } + + private: + std::unique_ptr builder_data_schema_; + std::shared_ptr builder_sampler_; + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_worker_connector_size_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_total_rows_; + std::vector builder_dataset_files_list_; + std::vector builder_columns_to_load_; + bool builder_shuffle_files_; + bool builder_equal_rows_per_shard_; + }; + + // Constructor of TFReaderOp (2) + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. 
+ // @param worker_connector_size - size of each internal queue.
+ // @param rows_per_buffer - number of rows that a full buffer will contain.
+ // @param total_num_rows - number of rows to read.
+ // @param dataset_files_list - list of filepaths for the dataset files.
+ // @param data_schema - the data schema object.
+ // @param op_connector_size - size of each queue in the connector that the child operator pulls from.
+ // @param columns_to_load - the names of the columns to load data from.
+ // @param shuffle_files - whether or not to shuffle the files before reading data.
+ // @param equal_rows_per_shard - whether or not to get equal rows for each process.
+ // @param sampler - allow a sampler. Only valid if a cache exists in ascendant tree nodes.
+ TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
+ std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
+ int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
+ int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<Sampler> sampler);
+
+ // Default destructor
+ ~TFReaderOp() = default;
+
+ // A print method typically used for debugging
+ // @param out - The output stream to write output to
+ // @param show_all - A bool to control if you want to show all info or just a summary
+ void Print(std::ostream &out, bool show_all) const override;
+
+ // Instantiates the internal queues and connectors.
+ // @return Status - the error code returned.
+ Status Init();
+
+ // Class functor operator () override.
+ // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
+ // provide the master loop that drives the logic for performing the work.
+ // @return Status - the error code returned.
+ Status operator()() override;
+
+ // Overrides the base class reset method. Cleans up any state info from its previous execution and
+ // reinitializes itself so that it can be executed again, as if it was just created.
+ // @return Status - the error code returned.
+ Status Reset() override;
+
+ // Getter method
+ int64_t rows_per_buffer() const { return rows_per_buffer_; }
+
+ // Reads all the provided tf_file files and counts the total number of rows. Filenames will
+ // first be sectioned into equal parts, then the sections are read in parallel. If threads is
+ // greater than the number of files, threads will be clamped to the number of files.
+ // @param out_total_rows - output parameter which contains the total number of rows
+ // @param filenames - a list of tf_file filenames.
+ // @param threads - number of threads to use to read the tf_file files.
+ // @param estimate - estimate mode; under this mode each thread will sample a single file from each chunk
+ // @return Status - the error code returned.
+ static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
+ bool estimate = false);
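CountTotalRows (implemented earlier in the .cc) splits the file list into near-equal chunks and counts them concurrently, one std::async task per chunk; in estimate mode each task reads only its chunk's first file and the total is scaled by the file count. A compilable sketch of just the partitioning scheme, with a stubbed per-file counter (CountRowsInFile, CountRange, and the file names are placeholders, not part of this patch):

```cpp
#include <cstdint>
#include <future>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for the per-file counter; the real op walks the TFRecord framing.
static int64_t CountRowsInFile(const std::string & /*filename*/) { return 100; }

static int64_t CountRange(const std::vector<std::string> &files, int64_t begin, int64_t end) {
  int64_t rows = 0;
  for (int64_t i = begin; i < end; ++i) rows += CountRowsInFile(files[i]);
  return rows;
}

int64_t CountTotalRowsSketch(const std::vector<std::string> &files, int64_t threads) {
  if (threads > static_cast<int64_t>(files.size())) threads = files.size();  // clamp to file count
  const int64_t chunk = files.size() / threads;
  int64_t remainder = files.size() % threads;

  std::vector<std::future<int64_t>> results;
  int64_t begin = 0;
  for (int64_t t = 0; t < threads; ++t) {
    // first `remainder` chunks get one extra file, so sections stay near-equal
    int64_t end = begin + chunk + (remainder > 0 ? 1 : 0);
    if (remainder > 0) --remainder;
    results.push_back(std::async(std::launch::async, &CountRange, files, begin, end));
    begin = end;
  }

  int64_t total = 0;
  for (auto &r : results) total += r.get();
  return total;
}

int main() {
  std::vector<std::string> files = {"a.tfrecord", "b.tfrecord", "c.tfrecord"};  // placeholders
  std::cout << CountTotalRowsSketch(files, 2) << " rows\n";
  return 0;
}
```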
+ // Base-class override for NodePass visitor acceptor.
+ // @param p - Pointer to the NodePass to be accepted.
+ // @param modified - Whether this node visit modified the pipeline.
+ // @return - Status of the node visit.
+ Status Accept(NodePass *p, bool *modified) override;
+
+ // Op name getter
+ // @return Name of the current Op
+ std::string Name() const override { return "TFReaderOp"; }
+
+ // File names getter
+ // @return Vector of the input file names
+ std::vector<std::string> FileNames() { return dataset_files_list_; }
+
+ /// \brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
+ /// a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so
+ /// that this tf reader will produce the full set of data into the cache.
+ void MakeSimpleProducer();
+
+ // During tree prepare phase, operators may have specific post-operations to perform depending on
+ // their role.
+ // @note Derived versions of this function should always call the superclass version first
+ // before providing their own implementations.
+ Status PrepareNodePostAction() override;
+
+ private:
+ // The entry point for when workers are launched.
+ // @param worker_id - the id of the worker that is executing this function.
+ // @return Status - the error code returned.
+ Status WorkerEntry(int32_t worker_id) override;
+
+ // Pushes a control indicator onto the IOBlockQueue for each worker to consume.
+ // When the worker pops this control indicator, it will shut itself down gracefully.
+ // @return Status - the error code returned.
+ Status PostEndOfData();
+
+ // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
+ // pops this control indicator, it will wait until the next epoch starts and then resume execution.
+ // @return Status - the error code returned.
+ Status PostEndOfEpoch(int32_t queue_index);
+
+ // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
+ // @return Status - the error code returned.
+ Status WaitToFillIOBlockQueue();
+
+ // Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
+ void NotifyToFillIOBlockQueue();
+
+ // Pops an element from a queue in io_block_queues_.
+ // @param index - the index of the queue to pop from.
+ // @param out_block - the popped element.
+ // @return Status - the error code returned.
+ Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
+
+ // Pushes an element to a queue in io_block_queues_.
+ // @param index - the index of the queue to push to.
+ // @param io_block - the element to push onto the queue.
+ // @return Status - the error code returned.
+ Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
+
+ // Reads a tf_file file and loads the data into multiple buffers.
+ // @param filename - the tf_file file to read.
+ // @param start_offset - the start offset of the file.
+ // @param end_offset - the end offset of the file.
+ // @param worker_id - the id of the worker that is executing this function.
+ // @return Status - the error code returned.
+ Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset,
+ const int32_t &worker_id);
+
+ // Parses a single row and puts the data into a tensor table.
+ // @param tf_file - the row to be parsed.
+ // @param tensor_table - the tensor table to put the parsed data in.
+ // @param row - the id of the row filled in the tensor table.
+ // @return Status - the error code returned.
+ Status LoadExample(const dataengine::Example *tf_file, std::unique_ptr<TensorTable> *tensor_table, int64_t row);
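For reference while reading LoadFile, CreateSchema, and CountTotalRowsSectioned: all three walk the same TFRecord container framing of an int64 record length, a 4-byte masked CRC32C of that length, the serialized Example payload, then a 4-byte CRC32C of the payload. A minimal standalone record counter in the same spirit (the file name is a placeholder, and CRC validation is skipped exactly as CountTotalRowsSectioned skips it):

```cpp
#include <cstdint>
#include <fstream>
#include <iostream>

int main() {
  std::ifstream reader("example.tfrecord", std::ios::binary);  // placeholder path
  if (!reader) return 1;

  int64_t rows = 0;
  while (reader.peek() != EOF) {
    int64_t record_length = 0;
    reader.read(reinterpret_cast<char *>(&record_length), sizeof(int64_t));  // record length
    reader.ignore(sizeof(int32_t));  // masked CRC32C of the length field
    reader.ignore(record_length);    // serialized dataengine::Example payload
    reader.ignore(sizeof(int32_t));  // masked CRC32C of the payload
    ++rows;
  }
  std::cout << rows << " records\n";
  return 0;
}
```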
+ // Parses a single cell and puts the data into a tensor table.
+ // @param tensor_table - the tensor table to put the parsed data in.
+ // @param column_values_list - the cell to parse.
+ // @param current_col - the column descriptor containing the expected shape and type of the data.
+ // @return Status - the error code returned.
+ Status LoadFeature(const std::unique_ptr<TensorTable> *tensor_table, const dataengine::Feature &column_values_list,
+ const ColDescriptor &current_col, int64_t row, int32_t col);
+
+ // Reads values from a bytes list.
+ // @param current_col - the column descriptor containing the expected shape and type of the data.
+ // @param column_values_list - the cell that contains the bytes list to read from.
+ // @param num_elements - number of elements in the bytes list.
+ // @param tensor - the tensor we read the values into.
+ // @return Status - the error code returned.
+ static Status LoadBytesList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
+ int32_t *num_elements, std::shared_ptr<Tensor> *tensor);
+
+ // Reads values from a float list.
+ // @param current_col - the column descriptor containing the expected shape and type of the data.
+ // @param column_values_list - the cell that contains the float list to read from.
+ // @param num_elements - number of values in the float list.
+ // @param float_array - the array we read the values into.
+ // @return Status - the error code returned.
+ Status LoadFloatList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
+ int32_t *num_elements, std::unique_ptr<float[]> *float_array);
+
+ // Reads values from an int list and casts each value to type T, which must be an integral
+ // type compatible with int64_t.
+ // @param current_col - the column descriptor containing the expected shape and type of the data.
+ // @param column_values_list - the cell that contains the int list to read from.
+ // @param num_elements - number of values in the int list.
+ // @param tensor - the tensor we read the values into.
+ // @return Status - the error code returned.
+ template <typename T>
+ Status LoadIntList(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
+ int32_t *num_elements, std::shared_ptr<Tensor> *tensor);
+
+ // Determines which template type to use and calls LoadIntList.
+ // @param current_col - the column descriptor containing the expected shape and type of the data.
+ // @param column_values_list - the cell that contains the int list to read from.
+ // @param num_elements - number of values in the int list.
+ // @param tensor - the tensor we read the values into.
+ // @return Status - the error code returned.
+ Status LoadIntListSwitch(const ColDescriptor &current_col, const dataengine::Feature &column_values_list,
+ int32_t *num_elements, std::shared_ptr<Tensor> *tensor);
+
+ // Reads one row of data from a tf file and creates a schema based on that row.
+ // @return Status - the error code returned.
+ Status CreateSchema(const std::string tf_file, std::vector<std::string> columns_to_load);
+
+ // Meant to be called async. Will read files in the range [begin, end) and return the total rows.
+ // @param filenames - a list of tf data filenames.
+ // @param begin - index of the first file to read.
+ // @param end - one greater than the index of the last file to read.
+ // @return int64_t - the total number of rows of files read.
+ static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin,
+ const int64_t end);
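LoadIntListSwitch exists because a column's element type is known only at run time (from the ColDescriptor), while LoadIntList is a compile-time template; the switch maps each integral DE_* type to one template instantiation. A minimal sketch of that enum-to-template dispatch pattern (the enum and Fill function are illustrative stand-ins, not MindSpore types):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

enum class ElemType { kInt8, kInt16, kInt32, kInt64 };  // stand-in for the DataType enum

// Compile-time-typed worker, analogous to LoadIntList<T>.
template <typename T>
void Fill(std::vector<T> *out, const std::vector<int64_t> &values) {
  for (int64_t v : values) out->push_back(static_cast<T>(v));  // per-element narrowing cast
}

// Run-time dispatcher, analogous to LoadIntListSwitch.
void FillSwitch(ElemType type, const std::vector<int64_t> &values) {
  switch (type) {
    case ElemType::kInt8: { std::vector<int8_t> v; Fill(&v, values); break; }
    case ElemType::kInt16: { std::vector<int16_t> v; Fill(&v, values); break; }
    case ElemType::kInt32: { std::vector<int32_t> v; Fill(&v, values); break; }
    case ElemType::kInt64: { std::vector<int64_t> v; Fill(&v, values); break; }
  }
}

int main() {
  FillSwitch(ElemType::kInt16, {1, 2, 300});
  std::cout << "dispatched\n";
  return 0;
}
```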
+ // Fill the IO block queue if shuffle is true.
+ // @param i_keys - shuffled keys.
+ // @return Status - the error code returned.
+ Status FillIOBlockShuffle(const std::vector<int64_t> &i_keys);
+
+ // Fill the IO block queue if shuffle is false.
+ // @return Status - the error code returned.
+ Status FillIOBlockNoShuffle();
+
+ // Determines whether a file (and which row range within it) belongs to the current shard, and
+ // therefore whether it should be pushed to the block queue.
+ // @param file_name - file name.
+ // @param start_offset - output, offset of the first row of this file that falls inside the shard.
+ // @param end_offset - output, offset one past the last row of this file that falls inside the shard.
+ // @param pre_count - total rows of all previous files.
+ // @return bool - whether the file should be pushed to the block queue.
+ bool NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
+ const int64_t &pre_count);
+
+ // Calculate the number of rows in each shard.
+ // @return Status - the error code returned.
+ Status CalculateNumRowsPerShard();
+
+ // Private function for computing the assignment of the column name map.
+ // @return - Status
+ Status ComputeColMap() override;
+
+ int32_t device_id_;
+ int32_t num_devices_;
+ int64_t rows_per_buffer_;
+ int64_t total_rows_;
+ std::vector<std::string> dataset_files_list_;
+ std::vector<std::string> columns_to_load_;
+ bool finished_reading_dataset_;
+ bool shuffle_files_;
+ std::unique_ptr<DataSchema> data_schema_;
+ std::unique_ptr<StringIndex> filename_index_;
+ bool load_io_block_queue_;
+ bool load_jagged_connector_;
+
+ std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
+ QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
+ WaitPost io_block_queue_wait_post_;
+ std::mutex load_io_block_queue_mutex_;
+ std::map<std::string, int64_t> filename_numrows_;
+ int64_t num_rows_;
+ int64_t num_rows_per_shard_;
+ bool equal_rows_per_shard_;
+};
+} // namespace dataset
+} // namespace mindspore
+#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TF_READER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc new file mode 100644 index 0000000000..e90d423ef4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.cc @@ -0,0 +1,471 @@ +/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ +#include "minddata/dataset/engine/datasetops/source/voc_op.h" + +#include +#include +#include +#include "./tinyxml2.h" +#include "common/utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/db_connector.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/opt/pass.h" + +using tinyxml2::XMLDocument; +using tinyxml2::XMLElement; +using tinyxml2::XMLError; +namespace mindspore { +namespace dataset { +const char kColumnImage[] = "image"; +const char kColumnTarget[] = "target"; +const char kColumnAnnotation[] = "annotation"; +const char kJPEGImagesFolder[] = "/JPEGImages/"; +const char kSegmentationClassFolder[] = "/SegmentationClass/"; +const char kAnnotationsFolder[] = "/Annotations/"; +const char kImageSetsSegmentation[] = "/ImageSets/Segmentation/"; +const char kImageSetsMain[] = "/ImageSets/Main/"; +const char kImageExtension[] = ".jpg"; +const char kSegmentationExtension[] = ".png"; +const char kAnnotationExtension[] = ".xml"; +const char kImageSetsExtension[] = ".txt"; + +VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); + builder_task_type_ = TaskType::Segmentation; +} + +Status VOCOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + if (builder_sampler_ == nullptr) { + const int64_t num_samples = 0; + const int64_t start_index = 0; + builder_sampler_ = std::make_shared(start_index, num_samples); + } + builder_schema_ = std::make_unique(); + if (builder_task_type_ == TaskType::Segmentation) { + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnTarget), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + } else if (builder_task_type_ == TaskType::Detection) { + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnImage), DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + RETURN_IF_NOT_OK(builder_schema_->AddColumn( + ColDescriptor(std::string(kColumnAnnotation), DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); + } + *ptr = std::make_shared(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, + builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, + builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); + return Status::OK(); +} + +Status VOCOp::Builder::SanityCheck() { + Path dir(builder_dir_); + std::string err_msg; + err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; + err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; + return err_msg.empty() ? 
Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, + const std::map &class_index, int32_t num_workers, int32_t rows_per_buffer, + int32_t queue_size, bool decode, std::unique_ptr data_schema, std::shared_ptr sampler) + : ParallelOp(num_workers, queue_size, std::move(sampler)), + decode_(decode), + row_cnt_(0), + buf_cnt_(0), + task_type_(task_type), + task_mode_(task_mode), + folder_path_(folder_path), + class_index_(class_index), + rows_per_buffer_(rows_per_buffer), + data_schema_(std::move(data_schema)) { + io_block_queues_.Init(num_workers_, queue_size); +} + +Status VOCOp::TraverseSampleIds(const std::shared_ptr &sample_ids, std::vector *keys) { + for (auto itr = sample_ids->begin(); itr != sample_ids->end(); ++itr) { + if ((*itr) > num_rows_) continue; + keys->push_back(*itr); + row_cnt_++; + if (row_cnt_ % rows_per_buffer_ == 0) { + RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( + std::make_unique(IOBlock(*keys, IOBlock::kDeIoBlockNone)))); + keys->clear(); + } + } + return Status::OK(); +} + +Status VOCOp::operator()() { + RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); + std::unique_ptr sampler_buffer; + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + while (true) { + std::vector keys; + keys.reserve(rows_per_buffer_); + while (sampler_buffer->eoe() == false) { + std::shared_ptr sample_ids; + RETURN_IF_NOT_OK(sampler_buffer->GetTensor(&sample_ids, 0, 0)); + if (sample_ids->type() != DataType(DataType::DE_INT64)) { + RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); + } + RETURN_IF_NOT_OK(TraverseSampleIds(sample_ids, &keys)); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + if (keys.empty() == false) { + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add( + std::make_unique(IOBlock(keys, IOBlock::kDeIoBlockNone)))); + } + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + std::unique_ptr eoe_block = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + std::unique_ptr eof_block = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eoe_block))); + RETURN_IF_NOT_OK(io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::move(eof_block))); + for (int32_t i = 0; i < num_workers_; i++) { + RETURN_IF_NOT_OK( + io_block_queues_[i]->Add(std::make_unique(std::vector(), IOBlock::kDeIoBlockNone))); + } + return Status::OK(); + } else { + RETURN_IF_NOT_OK( + io_block_queues_[(buf_cnt_++) % num_workers_]->Add(std::make_unique(IOBlock::kDeIoBlockFlagEoe))); + RETURN_IF_NOT_OK(wp_.Wait()); + wp_.Clear(); + RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer)); + } + } +} + +void VOCOp::Print(std::ostream &out, bool show_all) const { + // Always show the id and name as first line regardless if this summary or detailed print + out << "(" << std::setw(2) << operator_id_ << ") :"; + if (!show_all) { + // Call the super class for displaying any common 1-liner info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal 1-liner info for this op + out << "\n"; + } else { + // Call the super class for displaying any common detailed info + ParallelOp::Print(out, show_all); + // Then show any custom derived-internal stuff + out << "\nNumber of rows: " << num_rows_ << "\nVOC Directory: " << folder_path_ << "\n\n"; + } +} + +Status VOCOp::Reset() { + 
RETURN_IF_NOT_OK(sampler_->ResetSampler()); + row_cnt_ = 0; + wp_.Set(); + return Status::OK(); +} + +Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *trow) { + if (task_type_ == TaskType::Segmentation) { + std::shared_ptr image, target; + const std::string kImageFile = + folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); + const std::string kTargetFile = + folder_path_ + std::string(kSegmentationClassFolder) + image_id + std::string(kSegmentationExtension); + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + RETURN_IF_NOT_OK(ReadImageToTensor(kTargetFile, data_schema_->column(1), &target)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(target)}); + } else if (task_type_ == TaskType::Detection) { + std::shared_ptr image, annotation; + const std::string kImageFile = + folder_path_ + std::string(kJPEGImagesFolder) + image_id + std::string(kImageExtension); + const std::string kAnnotationFile = + folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension); + RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image)); + RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, data_schema_->column(1), &annotation)); + (*trow) = TensorRow(row_id, {std::move(image), std::move(annotation)}); + } + return Status::OK(); +} + +Status VOCOp::LoadBuffer(const std::vector &keys, std::unique_ptr *db) { + std::unique_ptr deq = std::make_unique(); + TensorRow trow; + for (const uint64_t &key : keys) { + RETURN_IF_NOT_OK(this->LoadTensorRow(key, image_ids_[key], &trow)); + deq->push_back(std::move(trow)); + } + (*db)->set_tensor_table(std::move(deq)); + return Status::OK(); +} + +Status VOCOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + int64_t buffer_id = worker_id; + std::unique_ptr io_block; + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + while (io_block != nullptr) { + if (io_block->eoe() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique(0, DataBuffer::kDeBFlagEOE))); + buffer_id = worker_id; + } else if (io_block->eof() == true) { + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, (std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + } else { + std::vector keys; + RETURN_IF_NOT_OK(io_block->GetKeys(&keys)); + if (keys.empty() == true) return Status::OK(); + std::unique_ptr db = std::make_unique(buffer_id, DataBuffer::kDeBFlagNone); + RETURN_IF_NOT_OK(LoadBuffer(keys, &db)); + RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db))); + buffer_id += num_workers_; + } + RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&io_block)); + } + RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker"); +} + +Status VOCOp::ParseImageIds() { + std::string image_sets_file; + if (task_type_ == TaskType::Segmentation) { + image_sets_file = + folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension); + } else if (task_type_ == TaskType::Detection) { + image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension); + } + std::ifstream in_file; + in_file.open(image_sets_file); + if (in_file.fail()) { + RETURN_STATUS_UNEXPECTED("Fail to open file: " + image_sets_file); + } + std::string id; + while (getline(in_file, id)) { + if (id.size() > 0 && id[id.size() - 1] == '\r') { + image_ids_.push_back(id.substr(0, id.size() - 1)); + } else { + image_ids_.push_back(id); + } 
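+ // (a trailing '\r' appears when the image-set file was written with CRLF line
+ // endings; it is stripped above so that ids resolve on every platform)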
+
+Status VOCOp::ParseImageIds() {
+  std::string image_sets_file;
+  if (task_type_ == TaskType::Segmentation) {
+    image_sets_file =
+      folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension);
+  } else if (task_type_ == TaskType::Detection) {
+    image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension);
+  }
+  std::ifstream in_file;
+  in_file.open(image_sets_file);
+  if (in_file.fail()) {
+    RETURN_STATUS_UNEXPECTED("Failed to open file: " + image_sets_file);
+  }
+  std::string id;
+  while (getline(in_file, id)) {
+    if (id.size() > 0 && id[id.size() - 1] == '\r') {
+      image_ids_.push_back(id.substr(0, id.size() - 1));
+    } else {
+      image_ids_.push_back(id);
+    }
+  }
+  in_file.close();
+  image_ids_.shrink_to_fit();
+  num_rows_ = image_ids_.size();
+  return Status::OK();
+}
+
+Status VOCOp::ParseAnnotationIds() {
+  std::vector<std::string> new_image_ids;
+  for (auto id : image_ids_) {
+    const std::string kAnnotationName =
+      folder_path_ + std::string(kAnnotationsFolder) + id + std::string(kAnnotationExtension);
+    RETURN_IF_NOT_OK(ParseAnnotationBbox(kAnnotationName));
+    if (label_map_.find(kAnnotationName) != label_map_.end()) {
+      new_image_ids.push_back(id);
+    }
+  }
+
+  if (image_ids_.size() != new_image_ids.size()) {
+    image_ids_.clear();
+    image_ids_.insert(image_ids_.end(), new_image_ids.begin(), new_image_ids.end());
+  }
+  uint32_t count = 0;
+  for (auto &label : label_index_) {
+    label.second = count++;
+  }
+
+  num_rows_ = image_ids_.size();
+  return Status::OK();
+}
+
+Status VOCOp::ParseAnnotationBbox(const std::string &path) {
+  if (!Path(path).Exists()) {
+    RETURN_STATUS_UNEXPECTED("File not found: " + path);
+  }
+  Bbox bbox;
+  XMLDocument doc;
+  XMLError e = doc.LoadFile(common::SafeCStr(path));
+  if (e != XMLError::XML_SUCCESS) {
+    RETURN_STATUS_UNEXPECTED("XML load failed");
+  }
+  XMLElement *root = doc.RootElement();
+  if (root == nullptr) {
+    RETURN_STATUS_UNEXPECTED("XML load root element error");
+  }
+  XMLElement *object = root->FirstChildElement("object");
+  if (object == nullptr) {
+    RETURN_STATUS_UNEXPECTED("No object found in " + path);
+  }
+  while (object != nullptr) {
+    std::string label_name;
+    float xmin = 0.0, ymin = 0.0, xmax = 0.0, ymax = 0.0, truncated = 0.0, difficult = 0.0;
+    XMLElement *name_node = object->FirstChildElement("name");
+    if (name_node != nullptr && name_node->GetText() != nullptr) label_name = name_node->GetText();
+    XMLElement *truncated_node = object->FirstChildElement("truncated");
+    if (truncated_node != nullptr) truncated = truncated_node->FloatText();
+    XMLElement *difficult_node = object->FirstChildElement("difficult");
+    if (difficult_node != nullptr) difficult = difficult_node->FloatText();
+
+    XMLElement *bbox_node = object->FirstChildElement("bndbox");
+    if (bbox_node != nullptr) {
+      XMLElement *xmin_node = bbox_node->FirstChildElement("xmin");
+      if (xmin_node != nullptr) xmin = xmin_node->FloatText();
+      XMLElement *ymin_node = bbox_node->FirstChildElement("ymin");
+      if (ymin_node != nullptr) ymin = ymin_node->FloatText();
+      XMLElement *xmax_node = bbox_node->FirstChildElement("xmax");
+      if (xmax_node != nullptr) xmax = xmax_node->FloatText();
+      XMLElement *ymax_node = bbox_node->FirstChildElement("ymax");
+      if (ymax_node != nullptr) ymax = ymax_node->FloatText();
+    } else {
+      RETURN_STATUS_UNEXPECTED("bndbox mismatch in " + path);
+    }
+    if (label_name != "" && (class_index_.empty() || class_index_.find(label_name) != class_index_.end()) &&
+        xmin > 0 && ymin > 0 && xmax > xmin && ymax > ymin) {
+      std::vector<float> bbox_list = {xmin, ymin, xmax - xmin, ymax - ymin, truncated, difficult};
+      bbox.emplace_back(std::make_pair(label_name, bbox_list));
+      label_index_[label_name] = 0;
+    }
+    object = object->NextSiblingElement("object");
+  }
+  if (bbox.size() > 0) label_map_[path] = bbox;
+  return Status::OK();
+}
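
For readers unfamiliar with tinyxml2, the following standalone sketch (not part of the patch; the file name is a placeholder) shows the same traversal pattern ParseAnnotationBbox uses: get the root element, then walk sibling <object> nodes and pull typed values out of their children.

    #include <iostream>
    #include <tinyxml2.h>

    using tinyxml2::XMLDocument;
    using tinyxml2::XMLElement;
    using tinyxml2::XMLError;

    int main() {
      XMLDocument doc;
      // "annotation.xml" is a placeholder; any VOC-style annotation file works.
      if (doc.LoadFile("annotation.xml") != XMLError::XML_SUCCESS) {
        std::cerr << "XML load failed\n";
        return 1;
      }
      XMLElement *root = doc.RootElement();
      if (root == nullptr) return 1;
      // Walk sibling <object> nodes exactly as ParseAnnotationBbox does.
      for (XMLElement *object = root->FirstChildElement("object"); object != nullptr;
           object = object->NextSiblingElement("object")) {
        const XMLElement *name = object->FirstChildElement("name");
        const XMLElement *bnd = object->FirstChildElement("bndbox");
        if (name == nullptr || name->GetText() == nullptr || bnd == nullptr) continue;
        float xmin = 0.0, ymin = 0.0;
        if (const XMLElement *e = bnd->FirstChildElement("xmin")) xmin = e->FloatText();
        if (const XMLElement *e = bnd->FirstChildElement("ymin")) ymin = e->FloatText();
        std::cout << name->GetText() << ": (" << xmin << ", " << ymin << ")\n";
      }
      return 0;
    }

FloatText() returns 0 when the element text is not numeric, which is why the op also range-checks the parsed coordinates before keeping a box.
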
+Status VOCOp::InitSampler() {
+  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
+  return Status::OK();
+}
+
+Status VOCOp::LaunchThreadsAndInitOp() {
+  if (tree_ == nullptr) {
+    RETURN_STATUS_UNEXPECTED("tree_ not set");
+  }
+  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(wp_.Register(tree_->AllTasks()));
+  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&VOCOp::WorkerEntry, this, std::placeholders::_1)));
+  TaskManager::FindMe()->Post();
+  RETURN_IF_NOT_OK(this->ParseImageIds());
+  if (task_type_ == TaskType::Detection) {
+    RETURN_IF_NOT_OK(this->ParseAnnotationIds());
+  }
+  RETURN_IF_NOT_OK(this->InitSampler());
+  return Status::OK();
+}
+
+Status VOCOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor) {
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, path));
+  if (decode_ == true) {
+    Status rc = Decode(*tensor, tensor);
+    if (rc.IsError()) {
+      RETURN_STATUS_UNEXPECTED("Failed to decode file: " + path);
+    }
+  }
+  return Status::OK();
+}
+
+Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col,
+                                     std::shared_ptr<Tensor> *tensor) {
+  Bbox bbox_info = label_map_[path];
+  std::vector<float> bbox_row;
+  dsize_t bbox_column_num = 0, bbox_num = 0;
+  for (auto box : bbox_info) {
+    if (label_index_.find(box.first) != label_index_.end()) {
+      std::vector<float> bbox;
+      bbox.insert(bbox.end(), box.second.begin(), box.second.end());
+      if (class_index_.find(box.first) != class_index_.end()) {
+        bbox.push_back(static_cast<float>(class_index_[box.first]));
+      } else {
+        bbox.push_back(static_cast<float>(label_index_[box.first]));
+      }
+      bbox_row.insert(bbox_row.end(), bbox.begin(), bbox.end());
+      if (bbox_column_num == 0) {
+        bbox_column_num = static_cast<dsize_t>(bbox.size());
+      }
+      bbox_num++;
+    }
+  }
+
+  std::vector<dsize_t> bbox_dim = {bbox_num, bbox_column_num};
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, col.tensorImpl(), TensorShape(bbox_dim), col.type(),
+                                        reinterpret_cast<unsigned char *>(&bbox_row[0])));
+  return Status::OK();
+}
+
+Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                             const py::dict &dict, int64_t *count) {
+  if (task_type == "Detection") {
+    std::map<std::string, int32_t> input_class_indexing;
+    for (auto p : dict) {
+      (void)input_class_indexing.insert(std::pair<std::string, int32_t>(
+        py::reinterpret_borrow<py::str>(p.first), py::reinterpret_borrow<py::int_>(p.second)));
+    }
+
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(
+      Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    RETURN_IF_NOT_OK(op->ParseAnnotationIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  } else if (task_type == "Segmentation") {
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    *count = static_cast<int64_t>(op->image_ids_.size());
+  }
+
+  return Status::OK();
+}
+
+Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                               const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) {
+  std::map<std::string, int32_t> input_class_indexing;
+  for (auto p : dict) {
+    (void)input_class_indexing.insert(std::pair<std::string, int32_t>(
+      py::reinterpret_borrow<py::str>(p.first), py::reinterpret_borrow<py::int_>(p.second)));
+  }
+
+  if (!input_class_indexing.empty()) {
+    *output_class_indexing = input_class_indexing;
+  } else {
+    std::shared_ptr<VOCOp> op;
+    RETURN_IF_NOT_OK(
+      Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op));
+    RETURN_IF_NOT_OK(op->ParseImageIds());
+    RETURN_IF_NOT_OK(op->ParseAnnotationIds());
+    for (const auto &label : op->label_index_) {
+      (*output_class_indexing).insert(std::make_pair(label.first, label.second));
+    }
+  }
+
+  return Status::OK();
+}
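
The dict-to-map conversion above relies on pybind11's reinterpret_borrow casts. A minimal standalone sketch (not part of the patch; DictToMap is a hypothetical helper, and the embedded interpreter is only needed to make the demo self-contained):

    #include <cstdint>
    #include <map>
    #include <string>
    #include <pybind11/embed.h>

    namespace py = pybind11;

    // Hypothetical helper mirroring the conversion loop in CountTotalRows.
    std::map<std::string, int32_t> DictToMap(const py::dict &dict) {
      std::map<std::string, int32_t> out;
      for (auto p : dict) {
        std::string key = py::reinterpret_borrow<py::str>(p.first);
        int32_t val = py::reinterpret_borrow<py::int_>(p.second);
        out[key] = val;
      }
      return out;
    }

    int main() {
      py::scoped_interpreter guard{};  // standalone demo only
      py::dict d;
      d["cat"] = 0;
      d["dog"] = 1;
      return DictToMap(d).size() == 2 ? 0 : 1;
    }

reinterpret_borrow does no type checking, so as in the op itself, callers are expected to pass a dict of str to int.
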
+
+// Visitor accept method for NodePass
+Status VOCOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<VOCOp>(), modified);
+}
+
+Status VOCOp::ComputeColMap() {
+  // Set the column name map (base class field)
+  if (column_name_id_map_.empty()) {
+    for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) {
+      column_name_id_map_[data_schema_->column(i).name()] = i;
+    }
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
new file mode 100644
index 0000000000..e0c46c7a94
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/voc_op.h
@@ -0,0 +1,294 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_
+#define DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/engine/datasetops/parallel_op.h"
+#include "minddata/dataset/engine/datasetops/source/io_block.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/queue.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/wait_post.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declares
+template <typename T>
+class Queue;
+
+using Bbox = std::vector<std::pair<std::string, std::vector<float>>>;
+
+class VOCOp : public ParallelOp, public RandomAccessOp {
+ public:
+  enum class TaskType { Segmentation = 0, Detection = 1 };
+
+  class Builder {
+   public:
+    // Constructor for Builder class of VOCOp
+    Builder();
+
+    // Destructor.
+    ~Builder() = default;
+
+    // Setter method.
+    // @param const std::string &build_dir
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDir(const std::string &build_dir) {
+      builder_dir_ = build_dir;
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::map<std::string, int32_t> &map - a class name to label map
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetClassIndex(const std::map<std::string, int32_t> &map) {
+      builder_labels_to_read_ = map;
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::string &task_type
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetTask(const std::string &task_type) {
+      if (task_type == "Segmentation") {
+        builder_task_type_ = TaskType::Segmentation;
+      } else if (task_type == "Detection") {
+        builder_task_type_ = TaskType::Detection;
+      }
+      return *this;
+    }
+
+    // Setter method.
+    // @param const std::string &task_mode
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetMode(const std::string &task_mode) {
+      builder_task_mode_ = task_mode;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t num_workers
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetNumWorkers(int32_t num_workers) {
+      builder_num_workers_ = num_workers;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t op_connector_size
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // Setter method.
+    // @param int32_t rows_per_buffer
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
+    // Setter method.
+    // @param bool do_decode
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetDecode(bool do_decode) {
+      builder_decode_ = do_decode;
+      return *this;
+    }
+
+    // Check validity of input args
+    // @return Status - The error code return
+    Status SanityCheck();
+
+    // The builder "Build" method creates the final object.
+    // @param std::shared_ptr<VOCOp> *op - the constructed VOCOp
+    // @return Status - The error code return
+    Status Build(std::shared_ptr<VOCOp> *op);
+
+   private:
+    bool builder_decode_;
+    std::string builder_dir_;
+    TaskType builder_task_type_;
+    std::string builder_task_mode_;
+    int32_t builder_num_workers_;
+    int32_t builder_op_connector_size_;
+    int32_t builder_rows_per_buffer_;
+    std::shared_ptr<Sampler> builder_sampler_;
+    std::unique_ptr<DataSchema> builder_schema_;
+    std::map<std::string, int32_t> builder_labels_to_read_;
+  };
+
+  // Constructor
+  // @param TaskType task_type - task type of VOC
+  // @param std::string task_mode - task mode of VOC
+  // @param std::string folder_path - directory of the VOC dataset
+  // @param std::map<std::string, int32_t> class_index - input class-to-index map for annotations
+  // @param int32_t num_workers - number of workers reading images in parallel
+  // @param int32_t rows_per_buffer - number of images (rows) in each buffer
+  // @param int32_t queue_size - connector queue size
+  // @param bool decode - whether to decode images
+  // @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset
+  // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read
+  VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path,
+        const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer,
+        int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+
+  // Destructor
+  ~VOCOp() = default;
+
+  // Worker thread pulls a number of IOBlock from the IOBlock Queue, makes a buffer and pushes it to the Connector
+  // @param int32_t worker_id - id of each worker
+  // @return Status - The error code return
+  Status WorkerEntry(int32_t worker_id) override;
+
+  // Main Loop of VOCOp
+  // Master thread: fills the IOBlockQueue, then goes to sleep
+  // Worker thread: pulls an IOBlock from the IOBlockQueue, works on it, then puts the buffer into the out connector
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // A print method typically used for debugging
+  // @param out
+  // @param show_all
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // @param const std::string &dir - VOC dir path
+  // @param const std::string &task_type - task type of reading voc job
+  // @param const std::string &task_mode - task mode of reading voc job
+  // @param const py::dict &dict - input dict of class index
+  // @param int64_t *count - output rows number of VOCDataset
+  static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                               const py::dict &dict, int64_t *count);
+
+  // @param const std::string &dir - VOC dir path
+  // @param const std::string &task_type - task type of reading voc job
+  // @param const std::string &task_mode - task mode of reading voc job
+  // @param const py::dict &dict - input dict of class index
+  // @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset
+  static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
+                                 const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "VOCOp"; }
+
+ private:
+  // Initialize Sampler, calls sampler->Init() within
+  // @return Status - The error code return
+  Status InitSampler();
+
+  // Load a tensor row according to image id
+  // @param row_id_type row_id - id for this tensor row
+  // @param std::string image_id - image id
+  // @param TensorRow row - image & target read into this tensor row
+  // @return Status - The error code return
+  Status LoadTensorRow(row_id_type row_id, const std::string &image_id, TensorRow *row);
+
+  // @param const std::string &path - path to the image file
+  // @param const ColDescriptor &col - contains tensor implementation and datatype
+  // @param std::shared_ptr<Tensor> *tensor - return
+  // @return Status - The error code return
+  Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
+
+  // @param const std::string &path - path to the annotation file
+  // @param const ColDescriptor &col - contains tensor implementation and datatype
+  // @param std::shared_ptr<Tensor> *tensor - return
+  // @return Status - The error code return
+  Status ReadAnnotationToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
+
+  // @param const std::vector<int64_t> &keys - keys in ioblock
+  // @param std::unique_ptr<DataBuffer> db
+  // @return Status - The error code return
+  Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
+
+  // Read image list from ImageSets
+  // @return Status - The error code return
+  Status ParseImageIds();
+
+  // Read annotation from Annotation folder
+  // @return Status - The error code return
+  Status ParseAnnotationIds();
+
+  // @param const std::string &path - path to annotation xml
+  // @return Status - The error code return
+  Status ParseAnnotationBbox(const std::string &path);
+
+  // @param const std::shared_ptr<Tensor> &sample_ids - sample ids of tensor
+  // @param std::vector<int64_t> *keys - image id
+  // @return Status - The error code return
+  Status TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys);
+
+  // Called first when the functor is invoked
+  // @return Status - The error code return
+  Status LaunchThreadsAndInitOp();
+
+  // Reset dataset state
+  // @return Status - The error code return
+  Status Reset() override;
+
+  // Private function for computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  bool decode_;
+  int64_t row_cnt_;
+  int64_t buf_cnt_;
+  std::string folder_path_;
+  TaskType task_type_;
+  std::string task_mode_;
+  int32_t rows_per_buffer_;
+  std::unique_ptr<DataSchema> data_schema_;
+
+  WaitPost wp_;
+  std::vector<std::string> image_ids_;
+  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
+  std::map<std::string, int32_t> class_index_;
+  std::map<std::string, int32_t> label_index_;
+  std::map<std::string, Bbox> label_map_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_DATASETOPS_SOURCE_VOC_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
new file mode 100644
index 0000000000..d1f07983f7
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.cc
@@ -0,0 +1,136 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iomanip>
+#include <utility>
+
+#include "common/utils.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/datasetops/take_op.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+// Builder constructor. Creates the builder object.
+TakeOp::Builder::Builder(int32_t count) : build_max_takes_(count) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status TakeOp::Builder::SanityCheck() const {
+  if (build_max_takes_ <= 0) {
+    std::string err_msg("Take count must be greater than 0.");
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  return Status::OK();
+}
+
+// The builder "build" method creates the final object.
+Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
+  return Status::OK();
+}
+
+// Constructor of the TakeOp.
+TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
+    : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
+
+// A print method typically used for debugging
+void TakeOp::Print(std::ostream &out, bool show_all) const {
+  // Always show the id and name as first line regardless if this summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") :";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << " [takes: " << max_takes_ << "]\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nTake count: " << take_count_ << "\nMax takes: " << max_takes_ << "\n\n";
+  }
+}
+
+// Main entry point for Take
+Status TakeOp::operator()() {
+  TaskManager::FindMe()->Post();
+  std::unique_ptr<DataBuffer> buf;
+  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
+
+  while (buf->eof() == false) {
+    if (take_count_ == max_takes_) {
+      // Do drain operation
+      while (!buf->eoe() && !buf->eof()) {
+        RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
+      }
+    }
+
+    // Loop until non EOE is received
+    if (buf->eoe()) {
+      take_count_ = 0;
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
+      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
+      continue;
+    }
+
+    // Get buffer and push back when take_count is still small
+    if (take_count_ < max_takes_) {
+      std::unique_ptr<DataBuffer> p_buffer;
+      RETURN_IF_NOT_OK(FillBuffer(&buf, &p_buffer));
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(p_buffer)));
+    }
+    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
+  }
+
+  take_count_ = 0;
+  MS_LOG(DEBUG) << "Reached the end, pushing back EOF buffer.";
+  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
+  return Status::OK();
+}
+
+// Function FillBuffer mainly prepares the buffer for returning
+Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
+  int32_t buffer_size = (*buffer)->NumRows();
+  if (take_count_ + buffer_size < max_takes_) {
+    *data_buffer = std::move(*buffer);
+    take_count_ = take_count_ + buffer_size;
+  } else {
+    MS_LOG(DEBUG) << "In last buffer: Push one buffer.";
+    std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
+    while (take_count_ < max_takes_) {
+      TensorRow new_row;
+      RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
+      take_count_++;
+      new_tensor_table->push_back(new_row);
+    }
+    (*buffer)->set_tensor_table(std::move(new_tensor_table));
+    *data_buffer = std::move(*buffer);
+  }
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status TakeOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<TakeOp>(), modified);
+}
+}  // namespace dataset
+}  // namespace mindspore
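
The counting rule FillBuffer implements is simple: pass whole buffers through while they still fit under max_takes, and trim only the final buffer row-by-row so exactly max_takes rows are emitted. A standalone sketch (not part of the patch; the buffer sizes are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int32_t max_takes = 10;
      const std::vector<int32_t> buffer_sizes = {4, 4, 4, 4};  // rows per incoming buffer
      int32_t take_count = 0;
      for (int32_t rows : buffer_sizes) {
        if (take_count == max_takes) break;  // drain: remaining buffers are dropped
        int32_t emit = std::min(rows, max_takes - take_count);
        take_count += emit;
        std::cout << "emit buffer with " << emit << " rows (total " << take_count << ")\n";
      }
      return 0;
    }

Here the first two buffers pass through whole, the third is trimmed to 2 rows, and the fourth is drained, which matches the eoe-reset behavior in operator() above on a per-epoch basis.
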
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h
new file mode 100644
index 0000000000..7f3f821bd8
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/take_op.h
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+
+namespace mindspore {
+namespace dataset {
+class TakeOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the TakeOp is used to help manage all of the arguments
+  // for constructing it. This take op is very simple though, so this builder is really just
+  // provided for a consistent look and feel for creators of Dataset operators overall.
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    // @param count - The number of takes to do
+    explicit Builder(int32_t count);
+
+    // Default destructor
+    ~Builder() = default;
+
+    // The builder "build" method creates the final object.
+    // @return shared_ptr to the new TakeOp object
+    Status Build(std::shared_ptr<TakeOp> *);
+
+   private:
+    int32_t build_max_takes_;
+    int32_t builder_op_connector_size_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor of the TakeOp.
+  // @note The builder class should be used to call it
+  // @param count - The number of takes to do
+  explicit TakeOp(int32_t count, int32_t op_connector_size);
+
+  // Destructor
+  ~TakeOp() = default;
+
+  // A print method typically used for debugging
+  // @param out - The output stream to write output to
+  // @param show_all - A bool to control if you want to show all info or just a summary
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param ro - reference to the TakeOp to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, const TakeOp &ro) {
+    ro.Print(out, false);
+    return out;
+  }
+
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "TakeOp"; }
+
+ private:
+  int32_t max_takes_;   // The number of takes that the user requested
+  int32_t take_count_;  // A counter for the current number of executed takes
+
+  Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc
new file mode 100644
index 0000000000..88019c30fc
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.cc
@@ -0,0 +1,268 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/datasetops/zip_op.h"
+#include <iomanip>
+#include <utility>
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/db_connector.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/global_context.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+ZipOp::Builder::Builder() {
+  // Some arguments to the ZipOp constructor have a default argument that is taken
+  // from the client config.
+  // The user may choose to change these values for the construction of the ZipOp by
+  // using the various builder set methods.
+
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  builder_rows_per_buffer_ = cfg->rows_per_buffer();
+  builder_op_connector_size_ = cfg->op_connector_size();
+}
+
+Status ZipOp::Builder::SanityCheck() const { return Status::OK(); }
+
+Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
+  RETURN_IF_NOT_OK(SanityCheck());
+  *ptr = std::make_shared<ZipOp>(builder_rows_per_buffer_, builder_op_connector_size_);
+  return Status::OK();
+}
+
+// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
+ZipOp::ZipOp(int32_t rows_per_buffer, int32_t op_connector_size)
+    : PipelineOp(op_connector_size),
+      children_num_(0),
+      rows_per_buffer_(rows_per_buffer),
+      buffer_id_(0),
+      draining_(false),
+      eof_(false) {}
+
+// Destructor
+ZipOp::~ZipOp() {}
+
+// Entry point for Zip, called by launch()
+Status ZipOp::operator()() {
+  // The children_num_ parameter needs to be put here
+  children_num_ = child_.size();
+  // Synchronize with TaskManager once the thread is created.
+  TaskManager::FindMe()->Post();
+
+  // initialize the iterators
+  for (int32_t i = 0; i < children_num_; ++i) {
+    // magic number 0 since Zip is not a parallel Op
+    child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
+  }
+
+  // Loop until eof is true
+  while (!eof_) {
+    // Create tensor table and prepare it by fetching and packing the first zipped row into it.
+    std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
+    RETURN_IF_NOT_OK(prepare(curr_table.get()));
+
+    // If an eof got picked up during the above prepare, then we're done
+    if (eof_) {
+      break;
+    }
+    while (!draining_) {
+      // 1. If a previous loop iteration sent the current table out, then create a new one.
+      if (curr_table == nullptr) {
+        curr_table = std::make_unique<TensorQTable>();
+      }
+
+      // 2. Fill the table. Note: draining mode might get turned on if any of the child inputs were done
+      RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
+
+      // 3. Create and update buffer and send it to the out connector
+      if (!curr_table->empty()) {
+        std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
+        curr_buffer->set_tensor_table(std::move(curr_table));
+        MS_LOG(DEBUG) << "Zip operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
+                      << curr_buffer->NumCols() << ", map " << column_name_id_map_.size() << ".";
+        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
+        buffer_id_++;
+      }
+    }
+
+    // 4. Handle drain state.
+    if (draining_) {
+      MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
+      RETURN_IF_NOT_OK(drainPipeline());
+      // Now that we have drained child inputs, send the eoe up.
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
+    }
+  }
+
+  // 5. Handle eof: propagate eof here.
+  MS_LOG(DEBUG) << "Zip operator got EOF, propagating.";
+  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
+  return Status::OK();
+}
+
+// Handles preprocessing of the main loop, used when starting new epoch
+Status ZipOp::prepare(TensorQTable *const table) {
+  MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
+  draining_ = false;
+  buffer_id_ = 0;
+  if (table == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase requires a tensor table.");
+  }
+  // fill initial row
+  TensorRow new_row;
+  RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
+
+  // If the first row fetching resulted in eof, then we are done.
+  if (eof_) {
+    return Status::OK();
+  }
+  if (new_row.empty()) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
+  }
+
+  // Pack this first row into our tensor table
+  table->push_back(std::move(new_row));
+
+  return Status::OK();
+}
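
The zip semantics at work here: rows are merged column-wise until the shortest child runs out, which is the moment getNextTensorRow (below) flips draining_. A standalone sketch (not part of the patch; the string vectors stand in for child tensor rows):

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Three "children" of unequal length; the shortest decides the epoch length.
      std::vector<std::vector<std::string>> children = {{"1", "2", "3"}, {"a", "b"}, {"T", "F", "T"}};
      for (size_t n = 0;; ++n) {
        std::vector<std::string> zipped;
        bool drained = false;
        for (const auto &child : children) {
          if (n >= child.size()) {  // empty row from a child ends the epoch
            drained = true;
            break;
          }
          zipped.push_back(child[n]);  // append this child's columns to the merged row
        }
        if (drained) break;
        for (const auto &col : zipped) std::cout << col << ' ';
        std::cout << '\n';
      }
      return 0;
    }

This prints "1 a T" and "2 b F"; the third row of the longer children is dropped, exactly as the drain logic above discards leftovers before forwarding the eoe.
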
+
+// fillBuffer always expects a new table to fill
+Status ZipOp::fillBuffer(TensorQTable *const table) {
+  if (table == nullptr) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp fillBuffer null table pointer.");
+  }
+  TensorRow new_row;
+  while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
+    RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
+    // Early exit the loop if we got an empty row from any of our child iterations
+    if (new_row.empty()) {
+      return Status::OK();
+    }
+    // else we got a row so pack it into the tensor table.
+    table->push_back(std::move(new_row));
+  }
+  return Status::OK();
+}
+
+// Fetches the next zip buffer row (merged row)
+Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
+  // iterate over all iterators and generate a row
+  for (int32_t i = 0; i < children_num_; ++i) {
+    TensorRow new_row = {};
+    RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
+    // If the row from any child iterator is empty, return an empty row.
+    if (new_row.empty()) {
+      // If we did not get a row from any of the children, then it's the end of an epoch and we can move
+      // to drain state.
+      MS_LOG(DEBUG) << "Zip operator child iterator produced empty row.";
+      draining_ = true;
+      new_zip_row->clear();
+      // If we picked up an eof here, then we are completely done.
+      if ((child_iterators_[i])->eof_handled()) {
+        MS_LOG(DEBUG) << "Zip operator iterator got EOF.";
+        eof_ = true;
+      }
+      return Status::OK();
+    } else {
+      MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
+      // if row isn't empty then we can append the fetched row to new_zip_row
+      new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
+    }
+  }
+  MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
+  return Status::OK();
+}
+
+// Drain end-of-epoch messages from each child iterator for this epoch
+Status ZipOp::drainPipeline() {
+  // we don't need to drain if we reached eof
+  if (eof_) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                  "ZipOp draining should not be done if already at eof!");
+  }
+  for (int32_t con = 0; con < children_num_; ++con) {
+    MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
+    RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
+  }
+  // at this point all connectors don't contain end of epoch messages. next iteration should be clean
+  return Status::OK();
+}
+
+// A function that prints info about the Operator
+void ZipOp::Print(std::ostream &out,      // In: The output stream to print to
+                  bool show_all) const {  // In: T/F if it should print everything
+  // Always show the id and name as first line regardless if this is summary or detailed print
+  out << "(" << std::setw(2) << operator_id_ << ") :";
+  if (!show_all) {
+    // Call the super class for displaying any common 1-liner info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal 1-liner info for this op
+    out << "\n";
+  } else {
+    // Call the super class for displaying any common detailed info
+    PipelineOp::Print(out, show_all);
+    // Then show any custom derived-internal stuff
+    out << "\nDatasets: " << children_num_ << "\n\n";
+  }
+}
+
+// Override function and handle eof
+Status ZipOp::EofReceived(int32_t) {
+  MS_LOG(DEBUG) << "Zip operator EOF received, do nothing now.";
+  return Status::OK();
+}
+
+// Override function and handle eoe
+Status ZipOp::EoeReceived(int32_t) {
+  state_ = OpState::kDeOpIdle;
+  return Status::OK();
+}
+
+// Visitor accept method for NodePass
+Status ZipOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<ZipOp>(), modified);
+}
+
+Status ZipOp::ComputeColMap() {
+  if (column_name_id_map_.empty()) {
+    column_name_id_map_ = {};
+    for (size_t i = 0; i < child_.size(); ++i) {
+      // Initializing col_name_id_map from the child.
+      const std::unordered_map<std::string, int32_t> col_name_id_map = child_[i]->column_name_id_map();
+      int32_t cols_current = column_name_id_map_.size();
+      // the update code below shouldn't do anything bad if the column name already exists.
+      for (const auto &pair : col_name_id_map) {
+        std::string name = pair.first;
+        int32_t old_id = pair.second;
+        // check if name already exists in column name descriptor
+        if (column_name_id_map_.count(name) == 1) {
+          RETURN_STATUS_UNEXPECTED("key already exists when zipping datasets");
+        }
+        column_name_id_map_[name] = old_id + cols_current;
+      }
+    }
+    MS_LOG(DEBUG) << "Setting column map:\n" << this->ColumnNameMapAsString();
+  } else {
+    MS_LOG(WARNING) << "Column name map is already set!";
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h
new file mode 100644
index 0000000000..c9466e26e2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/zip_op.h
@@ -0,0 +1,158 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DATASETOPS_ZIP_OP_H_
+#define DATASET_ENGINE_DATASETOPS_ZIP_OP_H_
+
+#include <memory>
+#include <queue>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/dataset_iterator.h"
+#include "minddata/dataset/engine/datasetops/pipeline_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// forward declare
+class DataBuffer;
+
+class ZipOp : public PipelineOp {
+ public:
+  // The nested builder class inside of the ZipOp is used to help manage all of
+  // the arguments for constructing it. Use the builder by setting each argument
+  // with the provided set methods, and then finally call the build method to execute
+  // the actual construction.
+  // NOTE: a rows-per-buffer value of 0 means to default to the number of rows from the first child
+
+  class Builder {
+   public:
+    // Builder constructor. Creates the builder object.
+    // @note No default args
+    Builder();
+
+    // Default destructor
+    ~Builder() = default;
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
+      builder_rows_per_buffer_ = rows_per_buffer;
+      return *this;
+    }
+
+    // Setter method.
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetOpConnectorSize(int32_t op_connector_size) {
+      builder_op_connector_size_ = op_connector_size;
+      return *this;
+    }
+
+    // The builder "build" method creates the ZipOp dataset Operator.
+    // @return shared_ptr to the new ZipOp object
+    Status Build(std::shared_ptr<ZipOp> *);
+
+   private:
+    int32_t builder_rows_per_buffer_;
+    int32_t builder_op_connector_size_;
+
+    Status SanityCheck() const;
+  };
+
+  // Constructor for ZipOp
+  // @param rows_per_buffer - number of rows in output buffer
+  // @param op_connector_size - connector size
+  ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);
+
+  // Destructor
+  ~ZipOp();
+
+  Status EofReceived(int32_t) override;
+
+  Status EoeReceived(int32_t) override;
+
+  // Print function for Zip
+  // @param out - output stream to print to
+  // @param show_all - if it should print everything
+  void Print(std::ostream &out, bool show_all) const override;
+
+  // Provide stream operator for displaying it
+  friend std::ostream &operator<<(std::ostream &out, const ZipOp &zo) {
+    zo.Print(out, false);
+    return out;
+  }
+
+  // Class functor operator () override.
+  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
+  // provide the master loop that drives the logic for performing the work
+  // @return Status - The error code return
+  Status operator()() override;
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ZipOp"; }
+
+ private:
+  // Handles preprocessing of the main loop, used when starting new epoch
+  Status prepare(TensorQTable *const table);
+
+  // This function takes a table and repeatedly adds rows to it.
+  // @param table a table of tensors to be moved into a buffer
+  Status fillBuffer(TensorQTable *const table);
+
+  // Special handling for the case where an empty row has been received from a child iterator
+  // @note - we need to drain eoe signals from all children connectors.
+  // @details - when this function is called, we have encountered eoe at a child iterator;
+  //     we have to drain rows from the other child iterators until we hit eoe from all of them
+  Status drainPipeline();
+
+  // Merges 1 row from each child iterator together
+  // @param new_zip_row - input and output, will be a non-empty row if all rows from the child connectors are non-empty
+  // @details merge rows from the iterators together. This is the main functionality for ZipOp:
+  //     this function takes one row and fills it with tensors from rows fetched
+  //     from the child iterators.
+  // @example:
+  //   Zips multiple rows at a time, the output is stored in new_zip_row
+  //       1    a    T
+  //       \    |    /
+  //        1, a, T
+  Status getNextTensorRow(TensorRow *const new_zip_row);
+
+  // Computing the assignment of the column name map.
+  // @return - Status
+  Status ComputeColMap() override;
+
+  int32_t children_num_;
+  int32_t rows_per_buffer_;
+  int32_t buffer_id_;
+  bool draining_;
+  bool eof_;
+  std::vector<std::unique_ptr<ChildIterator>> child_iterators_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DATASETOPS_ZIP_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/db_connector.h b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h
new file mode 100644
index 0000000000..4a5c20bc12
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/db_connector.h
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_DB_CONNECTOR_H_
+#define DATASET_ENGINE_DB_CONNECTOR_H_
+
+#include <memory>
+#include <utility>
+#include "minddata/dataset/engine/connector.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/core/constants.h"
+
+namespace mindspore {
+namespace dataset {
+// DbConnector is a derived class from Connector with added logic to handle EOE and EOF.
+// The Connector class itself is responsible to ensure deterministic order on every run.
+class DbConnector : public Connector<std::unique_ptr<DataBuffer>> {
+ public:
+  // Constructor of DbConnector
+  // @note DbConnector will create internally N blocking queues, where N = n_producers.
+  //     See Connector.h for more details.
+  // @param n_producers The number of threads producing data into this DbConnector.
+  // @param n_consumers The number of threads consuming data from this DbConnector.
+  // @param queue_capacity The number of elements (DataBuffer) for each internal queue.
+  DbConnector(int32_t n_producers, int32_t n_consumers, int32_t queue_capacity)
+      : Connector<std::unique_ptr<DataBuffer>>(n_producers, n_consumers, queue_capacity), end_of_file_(false) {}
+
+  // Destructor of DbConnector
+  ~DbConnector() = default;
+
+  // Add a unique_ptr<DataBuffer> into the DbConnector.
+  // @note The caller of this add method should use std::move to pass the ownership to DbConnector.
+  // @param worker_id The id of a worker thread calling this method.
+  // @param el A rvalue reference to an element to be passed/added/pushed.
+  Status Add(int32_t worker_id, std::unique_ptr<DataBuffer> &&el) noexcept {
+    return (Connector<std::unique_ptr<DataBuffer>>::Push(worker_id, std::move(el)));
+  }
+
+  // Get a unique_ptr<DataBuffer> from the DbConnector.
+  // @note After the first EOF Buffer is encountered, subsequent pop()s will return EOF Buffer.
+  //     This will provide/propagate the EOF to all consumer threads of this Connector.
+  //     Thus, when num_consumers < num_producers, there will be extra EOF messages in some of the internal queues
+  //     and reset() must be called before reusing DbConnector.
+  // @param worker_id The id of a worker thread calling this method.
+  // @param result The address of a unique_ptr<DataBuffer> where the popped element will be placed.
+  // @param retry_if_eoe A flag to allow the same thread to invoke pop() again if the current pop returns an eoe buffer.
+  Status PopWithRetry(int32_t worker_id, std::unique_ptr<DataBuffer> *result, bool retry_if_eoe = false) noexcept {
+    if (result == nullptr) {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                    "[ERROR] nullptr detected when getting data from db connector");
+    } else {
+      std::unique_lock<std::mutex> lk(m_);
+      RETURN_IF_NOT_OK(cv_.Wait(&lk, [this, worker_id]() { return (expect_consumer_ == worker_id) || end_of_file_; }));
+      // Once an EOF message is encountered this flag will be set and we can return early.
+      if (end_of_file_) {
+        *result = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
+      } else {
+        RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
+        if (*result == nullptr) {
+          return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
+                        "[ERROR] nullptr detected when getting data from db connector");
+        }
+        // Setting the internal flag once the first EOF is encountered.
+        if ((*result)->eof()) {
+          end_of_file_ = true;
+        }
+        pop_from_ = (pop_from_ + 1) % num_producers_;
+      }
+      // Do not increment expect_consumer_ when result is eoe and retry_if_eoe is set.
+      if (!((*result)->eoe() && retry_if_eoe)) {
+        expect_consumer_ = (expect_consumer_ + 1) % num_consumers_;
+      }
+    }
+    out_buffers_count_++;
+    cv_.NotifyAll();
+    return Status::OK();
+  }
+
+ private:
+  // A flag to indicate the end of stream has been encountered.
+  bool end_of_file_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_DB_CONNECTOR_H_
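
The fairness mechanism PopWithRetry relies on is worth isolating: every consumer waits on one shared condition variable until expect_consumer equals its own id, so pops come out in a fixed rotation regardless of thread scheduling. A heavily simplified standalone sketch (not part of the patch; the queue itself is replaced by a print statement):

    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>
    #include <vector>

    int main() {
      const int num_consumers = 3, pops_per_consumer = 2;
      std::mutex m;
      std::condition_variable cv;
      int expect_consumer = 0;
      auto consumer = [&](int id) {
        for (int k = 0; k < pops_per_consumer; ++k) {
          std::unique_lock<std::mutex> lk(m);
          cv.wait(lk, [&] { return expect_consumer == id; });  // gate on our turn
          std::cout << "consumer " << id << " pops\n";         // stand-in for PopFront
          expect_consumer = (expect_consumer + 1) % num_consumers;
          cv.notify_all();  // wake the next consumer in the rotation
        }
      };
      std::vector<std::thread> threads;
      for (int id = 0; id < num_consumers; ++id) threads.emplace_back(consumer, id);
      for (auto &t : threads) t.join();
      return 0;
    }

The output order is always 0, 1, 2, 0, 1, 2, which is the deterministic-replay property the dataset pipeline needs; DbConnector layers the EOE/EOF special cases on top of this rotation.
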
diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
new file mode 100644
index 0000000000..55dec24e79
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.cc
@@ -0,0 +1,312 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/execution_tree.h"
+#include <iostream>
+#include <string>
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/engine/datasetops/shuffle_op.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/engine/opt/pre/removal_pass.h"
+#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
+#include "minddata/dataset/engine/opt/post/repeat_pass.h"
+#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
+#include "minddata/dataset/engine/perf/profiling.h"
+#include "minddata/dataset/engine/perf/monitor.h"
+
+namespace mindspore {
+namespace dataset {
+// Constructor
+ExecutionTree::ExecutionTree() : id_count_(0) {
+  tg_ = std::make_unique<TaskGroup>();
+  tree_state_ = kDeTStateInit;
+  prepare_flags_ = kDePrepNone;
+  perf_monitor_ = std::make_unique<Monitor>(this);
+  profiling_manager_ = std::make_unique<ProfilingManager>(this);
+  optimize_ = common::GetEnv("OPTIMIZE") == "true" ? true : false;
+}
+
+// Destructor
+ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); }
+
+// Associates a DatasetOp with this tree. This assigns a valid node id to the operator and
+// provides it with a link to the tree. A node cannot form any relationships (parent/child) with
+// other nodes unless they are associated with the same tree.
+Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
+  // If we are already a part of the tree, no-op
+  if (op->tree_ == this) {
+    return Status::OK();
+  }
+  if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
+    std::string err_msg =
+      "Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
+      " Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
+      std::to_string(static_cast<int>(kDeTStateBuilding));
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // Enter the building state if we were not already there
+  tree_state_ = kDeTStateBuilding;
+
+  // Assign an id to the operator
+  op->set_id(id_count_);
+  id_count_++;
+
+  // Assign our tree into the op so that each op has a link back to the tree
+  op->set_tree(this);
+  return Status::OK();
+}
+
+// Sets the root node of the tree
+Status ExecutionTree::AssignRoot(const std::shared_ptr<DatasetOp> &op) {
+  // Tree must be in building state before we can assign root to it
+  if (tree_state_ != kDeTStateBuilding) {
+    std::string err_msg =
+      "Invalid tree state for assigning a root node. Current state: " +
+      std::to_string(static_cast<int>(tree_state_)) +
+      " Expected state: " + std::to_string(static_cast<int>(kDeTStateBuilding));
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+
+  // If they didn't already call AssociateNode for this node before calling AssignRoot,
+  // then do so now.
+  if (op->operator_id_ == DatasetOp::kInvalidOperatorId) {
+    RETURN_IF_NOT_OK(this->AssociateNode(op));
+  }
+
+  // Then add it as the root.
+  root_ = op;
+
+  return Status::OK();
+}
+
+// A print method typically used for debugging
+void ExecutionTree::Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op) const {
+  out << "Execution tree summary:\n"
+      << "-----------------------\n";
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, false);
+  out << "\nExecution tree operator details:\n"
+      << "--------------------------------\n";
+  this->PrintNode(out, op == nullptr ? root_ : op, "", true, true);
+}
+
+// A helper function for doing the recursive printing
+void ExecutionTree::PrintNode(std::ostream &out, const std::shared_ptr<DatasetOp> &dataset_op, std::string indent,
+                              bool last, bool detailed) const {
+  // Decide which printer to use based on detailed arg.
+  if (!detailed) {
+    out << indent << "+- " << *dataset_op;
+    indent += (last ? "    " : "|   ");
+  } else {
+    dataset_op->Print(out, detailed);
+  }
+
+  // Descend to children
+  for (size_t i = 0; i < dataset_op->child_.size(); ++i) {
+    this->PrintNode(out, dataset_op->child_[i], indent, (i == (dataset_op->child_.size() - 1)), detailed);
+  }
+}
+
+// Start the execution of the tree
+Status ExecutionTree::Launch() {
+  // Tree must be built and prepared before it can be launched!
+  if (tree_state_ != kDeTStateReady) {
+    std::string err_msg =
+      "Invalid tree state for launching tree. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
+      " Expected state: " + std::to_string(static_cast<int>(kDeTStateReady));
+    RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+  std::ostringstream ss;
+  ss << *this;
+
+  // Profiling infrastructures need to be initialized before Op launching
+  if (profiling_manager_->IsProfilingEnable()) {
+    // Setup profiling manager
+    RETURN_IF_NOT_OK(profiling_manager_->Initialize());
+    // Launch Monitor Thread
+    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Monitor Thread launched", std::ref(*perf_monitor_)));
+  }
+
+  MS_LOG(DEBUG) << "Printing the tree before launch tasks:\n" << ss.str();
+  for (auto itr = this->begin(); itr != this->end(); ++itr) {
+    // An inlined operator is one that has an output connector size of 0, and it does not
+    // require a thread to execute. Instead, the work of this operator is executed inlined
+    // from the tree node directly above it (or in the case of a root node, it runs from within
+    // the launching tree/user thread). Do not exec any thread for an inlined op.
+    itr->state_ = DatasetOp::OpState::kDeOpRunning;
+    if (!itr->inlined()) {
+      RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Op launched, OperatorId:" + std::to_string(itr->id()), std::ref(*itr)));
+      // Set the state of the Operator as running. This only matters in Leaf ops, CacheOp and TakeOp
+    }
+  }
+
+  tree_state_ = kDeTStateExecuting;
+
+  return Status::OK();
+}
+
+// A function that traverses the tree in postorder and saves the results in nodes_
+void ExecutionTree::Iterator::PostOrderTraverse(const std::shared_ptr<DatasetOp> &node) {
+  if (node == nullptr) {
+    return;
+  }
+  for (size_t i = 0; i < node->child_.size(); ++i) {
+    PostOrderTraverse(node->child_[i]);
+  }
+  nodes_.push_back(node);
+}
+
+ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_(0) {
+  // post-order traversal of the tree; if root is null, nothing is collected
+  PostOrderTraverse(root);
+  nodes_.emplace_back(nullptr);
+}
+
+// Given the number of workers, launches the worker entry function for each. Essentially a
+// wrapper for the TaskGroup handling that is stored inside the execution tree.
+Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func) {
+  // Launch the workers
+  for (int32_t i = 0; i < num_workers; ++i) {
+    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("Parallel Op Worker", std::bind(func, i)));
+  }
+  return Status::OK();
+}
+
+// The driver of the prepare phase of the execution tree.
+// The prepare phase consists of three sub-phases:
+//
+// 1. PrepareTreePreAction()
+//    Compulsory transformation/action pre optimization.
+//    For example, CacheOp insertion
+//
+// 2. Optimize()
+//    Optimization transformation/action, optional.
+//    For example, MapOp fusion
+//
+// 3. PrepareTreePostAction()
+//    Compulsory transformation/action post optimization.
+//    For example, RepeatOp inlining
+//
+// @return Status - The error code return
+Status ExecutionTree::Prepare() {
+  // Pre optimization compulsory transformation
+  RETURN_IF_NOT_OK(this->PrepareTreePreAction());
+
+  // If optional optimizations are enabled
+  if (optimize_) {
+    RETURN_IF_NOT_OK(this->Optimize());
+  }
+
+  // Post optimization compulsory transformation
+  RETURN_IF_NOT_OK(this->PrepareTreePostAction());
+
+  // Existing transformation implementation, will be removed later
+  RETURN_IF_NOT_OK(this->PrepareDeprecated());
+  return Status::OK();
+}
+
+Status ExecutionTree::PrepareTreePreAction() {
+  bool modified = false;
+  std::vector<std::unique_ptr<Pass>> pre_actions;
+  // Construct pre actions
+  MS_LOG(INFO) << "Running pre pass loops.";
+  pre_actions.push_back(std::make_unique<RemovalPass>());
+  pre_actions.push_back(std::make_unique<CacheTransformPass>());
+  // Apply pre action passes
+  for (auto &pass : pre_actions) {
+    RETURN_IF_NOT_OK(pass->Run(this, &modified));
+  }
+  MS_LOG(INFO) << "Pre passes complete.";
+  return Status::OK();
+}
+
+Status ExecutionTree::PrepareTreePostAction() {
+  // The tree is ready to be prepared.
+  tree_state_ = kDeTStatePrepare;
+
+  bool modified = false;
+  std::vector<std::unique_ptr<Pass>> post_actions;
+  // Construct post actions
+  MS_LOG(INFO) << "Running post pass loops.";
+  post_actions.push_back(std::make_unique<RepeatPass>());
+
+  // Apply post action passes
+  for (auto &pass : post_actions) {
+    RETURN_IF_NOT_OK(pass->Run(this, &modified));
+  }
+  MS_LOG(INFO) << "Post passes complete.";
+
+  return Status::OK();
+}
+
+Status ExecutionTree::Optimize() {
+  // Vector of optimizations, currently only 1, add more as necessary
+  std::vector<std::unique_ptr<NodePass>> optimizations;
+  optimizations.push_back(std::make_unique<TensorOpFusionPass>());
+  // vector of flags for each optimization
+  std::vector<bool> modified(optimizations.size(), false);
+  for (size_t i = 0; i < optimizations.size(); i++) {
+    auto m = false;
+    optimizations[i]->Run(this, &m);
+    modified[i] = m;
+  }
+  return Status::OK();
+}
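
The pre/optimize/post structure above is a classic pass-pipeline pattern: each phase is an ordered list of passes, each pass reports whether it changed the tree. A standalone sketch (not part of the patch; Tree, Pass and the two pass names are stand-ins for ExecutionTree, Pass/NodePass and the concrete passes):

    #include <iostream>
    #include <memory>
    #include <vector>

    struct Tree {};  // stand-in for ExecutionTree

    struct Pass {
      virtual ~Pass() = default;
      virtual bool Run(Tree *tree) = 0;  // returns whether the tree was modified
    };

    struct RemovalLikePass : Pass {
      bool Run(Tree *) override { std::cout << "pre: removal-style pass\n"; return false; }
    };
    struct FusionLikePass : Pass {
      bool Run(Tree *) override { std::cout << "opt: fusion-style pass\n"; return true; }
    };

    int main() {
      Tree tree;
      std::vector<std::unique_ptr<Pass>> phase;
      phase.push_back(std::make_unique<RemovalLikePass>());
      phase.push_back(std::make_unique<FusionLikePass>());
      bool modified = false;
      for (auto &pass : phase) modified |= pass->Run(&tree);  // run passes in order
      std::cout << "tree modified: " << std::boolalpha << modified << '\n';
      return 0;
    }

Keeping the phases as separate vectors, as the real code does, makes the ordering constraints explicit: compulsory rewrites always run, while the middle phase is gated on the OPTIMIZE flag.
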
+  uint32_t op_prep_flags = dataset_op->PrepareFlags();
+  BitSet(&prepare_flags_, op_prep_flags);
+
+  // Now, descend to children
+  for (const auto &i : dataset_op->child_) {
+    RETURN_IF_NOT_OK(this->PrepareNode(i));
+  }
+
+  // No more children; now we execute any prepare actions before going back up the
+  // tree in the recursive function
+  RETURN_IF_NOT_OK(dataset_op->PrepareNodePostAction());
+
+  // Then clear the flags from this op now that we have prepared it.
+  BitClear(&prepare_flags_, op_prep_flags);
+
+  return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
new file mode 100644
index 0000000000..b62bf8e85d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/execution_tree.h
@@ -0,0 +1,257 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_EXECUTION_TREE_H_
+#define DATASET_ENGINE_EXECUTION_TREE_H_
+
+#include <functional>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/util/status.h"
+#include "mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h"
+
+namespace mindspore {
+namespace dataset {
+// Forward declares
+class TaskGroup;
+class DatasetOp;
+class Monitor;
+
+class ExecutionTree {
+ public:
+  // Prepare flags used during tree prepare phase
+  enum PrepareFlags {
+    kDePrepNone = 0,
+    kDePrepRepeat = 1,  // Processing a repeat operation
+    kDePrepCache = 2    // Processing a cache operation
+  };
+
+  // State flags for the lifecycle of the tree
+  enum TreeState {
+    kDeTStateInit = 0,   // The freshly initialized state after construction
+    kDeTStateBuilding,   // The tree is being built, nodes are being added
+    kDeTStatePrepare,    // The tree has been assigned a root node and is pending prepare
+    kDeTStateReady,      // The tree has been prepared and is ready to be launched
+    kDeTStateExecuting,  // The tree has been launched and is executing
+    kDeTStateFinished    // The tree has been drained, dataset iterator received EOF
+  };
+
+  class Iterator {
+   public:
+    // Constructor
+    // @param root The root node to start iterating from
+    explicit Iterator(const std::shared_ptr<DatasetOp> &root = nullptr);
+
+    // Destructor
+    ~Iterator() {}
+
+    Iterator &operator++() {
+      ++ind_;
+      return *this;
+    }  // prefix ++ overload
+    Iterator operator++(int) {
+      Iterator it = *this;
+      it.ind_ = ind_;
+      ind_++;
+      return it;
+    }  // post-fix ++ overload
+    Iterator &operator--() {
+      --ind_;
+      return *this;
+    }  // prefix -- overload
+    Iterator operator--(int) {
+      Iterator it = *this;
+      it.ind_ = ind_;
+      ind_--;
+      return it;
+    }  // post-fix -- overload
+    DatasetOp &operator*() { return *nodes_[ind_]; }  // dereference operator
+    std::shared_ptr<DatasetOp> operator->() { return nodes_[ind_]; }
+
+    // getter function
+    // @return Shared pointer to the current operator
+    std::shared_ptr<DatasetOp> get() { return nodes_[ind_]; }
+
+    bool operator==(const Iterator &rhs) { return nodes_[ind_] == rhs.nodes_[rhs.ind_]; }
+    bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; }
+
+    int32_t NumNodes() { return nodes_.size(); }
+
+   private:
+    int32_t ind_;                                    // the current node our Iterator points to
+    std::vector<std::shared_ptr<DatasetOp>> nodes_;  // stores the nodes in post order
+    void PostOrderTraverse(const std::shared_ptr<DatasetOp> &);
+  };
+
+  // Constructor
+  ExecutionTree();
+
+  // Destructor
+  ~ExecutionTree();
+
+  // Associates a DatasetOp with this tree. This assigns a valid node id to the operator and
+  // provides it with a link to the tree. A node cannot form any relationships (parent/child) with
+  // other nodes unless they are associated with the same tree.
+  // @param op - The operator to associate
+  // @return Status - The error code return
+  Status AssociateNode(const std::shared_ptr<DatasetOp> &op);
+
+  // Sets the root node of the tree
+  // @param op - The operator to assign as root
+  // @return Status - The error code return
+  Status AssignRoot(const std::shared_ptr<DatasetOp> &op);
+
+  // Start the execution of the tree
+  // @return Status - The error code return
+  Status Launch();
+
+  /// A print method typically used for debugging
+  /// \param out - The output stream to write output to
+  void Print(std::ostream &out, const std::shared_ptr<DatasetOp> &op = nullptr) const;
+
+  // Returns an iterator positioned at the start
+  // @return Iterator - The iterator
+  ExecutionTree::Iterator begin(const std::shared_ptr<DatasetOp> &root = nullptr) const {
+    return Iterator(root == nullptr ? root_ : root);
+  }
+
+  // Returns an iterator positioned at the end
+  // @return Iterator - The iterator
+  ExecutionTree::Iterator end() const { return Iterator(nullptr); }
+
+  // << Stream output operator overload
+  // @notes This allows you to write the debug print info using stream operators
+  // @param out - reference to the output stream being overloaded
+  // @param exe_tree - reference to the execution tree to display
+  // @return - the output stream must be returned
+  friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) {
+    exe_tree.Print(out);
+    return out;
+  }
+
+  // Given the number of workers, launches the worker entry function for each. Essentially a
+  // wrapper for the TaskGroup handling that is stored inside the execution tree.
+  // @param num_workers - The number of workers to launch
+  // @param func - The function entry point that workers will execute
+  // @return Status - The error code return
+  Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func);
+
+  // Getter method
+  // @return shared_ptr to the root operator
+  std::shared_ptr<DatasetOp> root() const { return root_; }
+
+  // Getter method
+  // @return the prepare flags
+  uint32_t PrepareFlags() const { return prepare_flags_; }
+
+  // The driver of the prepare phase of the execution tree.
+  // The prepare phase consists of three sub-phases:
+  //
+  // 1. PrepareTreePreAction()
+  //    Compulsory transformation/action pre optimization.
+  //    For example, CacheOp insertion
+  //
+  // 2. Optimize()
+  //    Optional optimization transformation/action.
+  //    For example, MapOp fusion
+  //
+  // 3. PrepareTreePostAction()
+  //    Compulsory transformation/action post optimization.
+  //    For example, RepeatOp inlining
+  //
+  // @return Status - The error code return
+  Status Prepare();
+
+  // Compulsory transformation/action pre optimization.
+  // @return Status - The error code return
+  Status PrepareTreePreAction();
+
+  // Compulsory transformation/action post optimization.
+  // @return Status - The error code return
+  Status PrepareTreePostAction();
+
+  // Optimization transformation/action, optional.
+  // @return Status - The error code return
+  Status Optimize();
+
+  // The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will
+  // recursively walk the tree to perform modifications to the tree or to specific nodes within
+  // the tree to get it ready for execution.
+  // @return Status - The error code return
+  Status PrepareDeprecated();
+
+  // Recursive function used during the prepare phase to visit a node and drive any pre- and post-
+  // node actions during a tree walk.
+  // @param op - The dataset op to work on
+  // @return Status - The error code return
+  Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
+
+  // Return the pointer to the TaskGroup
+  // @return raw pointer to the TaskGroup
+  TaskGroup *AllTasks() const { return tg_.get(); }
+
+  // Return whether the ExecutionTree is finished (iterator receives EOF).
+  // @return Bool - true if the ExecutionTree is finished
+  bool isFinished() const { return tree_state_ == TreeState::kDeTStateFinished; }
+
+  // Set the ExecutionTree to the Finished state.
+  void SetFinished() { tree_state_ = TreeState::kDeTStateFinished; }
+
+  // Getter for profiling manager; no ownership is transferred
+  ProfilingManager *GetProfilingManager() { return profiling_manager_.get(); }
+
+  // Set optional optimization if tree has not been prepared yet
+  Status SetOptimize(bool value) {
+    if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
+      std::string optimize = optimize_ ? "true" : "false";
+      std::string msg = "Tree has already been prepared with OPTIMIZE set to " + optimize;
+      RETURN_STATUS_UNEXPECTED(msg);
+    } else {
+      optimize_ = value;
+      return Status::OK();
+    }
+  }
+
+  // Optional optimizations status
+  bool OptimizationEnabled() const { return optimize_; }
+
+ private:
+  // A helper function for doing the recursive printing
+  // @param dataset_op - The dataset op to print
+  // @param indent - an indent string for aligning child levels in output
+  // @param last - an indicator if it's the last child or not
+  // @param detailed - should it display the detailed node output or the summary line
+  void PrintNode(std::ostream &out, const std::shared_ptr<DatasetOp> &dataset_op, std::string indent, bool last,
+                 bool detailed) const;
+
+  std::unique_ptr<TaskGroup> tg_;                        // Class for worker management
+  std::shared_ptr<DatasetOp> root_;                      // The root node of the tree
+  int32_t id_count_;                                     // Counter for generating operator ids
+  uint32_t prepare_flags_;                               // Flags used during tree prepare
+  TreeState tree_state_;                                 // Tracking the current tree state
+  std::unique_ptr<Monitor> perf_monitor_;                // Performance Monitor
+  std::unique_ptr<ProfilingManager> profiling_manager_;  // Profiling manager
+  bool optimize_;                                        // Flag to enable optional optimizations
+};
+
+inline bool operator==(const ExecutionTree::Iterator &lhs, const ExecutionTree::Iterator &rhs) {
+  // Delegate to the member comparison; a bare `lhs == rhs` here would recurse into this
+  // same free function forever.
+  return const_cast<ExecutionTree::Iterator &>(lhs).operator==(rhs);
+}
+} // namespace dataset
+} // namespace mindspore
+
+#endif // DATASET_ENGINE_EXECUTION_TREE_H_
diff --git a/mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h
new file mode 100644
index 0000000000..c62c088bab
--- /dev/null
+++ 
b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_EDGE_H_ +#define DATASET_ENGINE_GNN_EDGE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using EdgeType = int8_t; +using EdgeIdType = int32_t; + +class Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} + + virtual ~Edge() = default; + + // @return NodeIdType - Returned edge id + EdgeIdType id() const { return id_; } + + // @return NodeIdType - Returned edge type + EdgeType type() const { return type_; } + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get nodes on the edge + // @param std::pair, std::shared_ptr> *out_node - Source and destination nodes returned + Status GetNode(std::pair, std::shared_ptr> *out_node) { + *out_node = std::make_pair(src_node_, dst_node_); + return Status::OK(); + } + + // Set node to edge + // @param const std::pair, std::shared_ptr> &in_node - + Status SetNode(const std::pair, std::shared_ptr> &in_node) { + src_node_ = in_node.first; + dst_node_ = in_node.second; + return Status::OK(); + } + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + EdgeIdType id_; + EdgeType type_; + std::shared_ptr src_node_; + std::shared_ptr dst_node_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc new file mode 100644 index 0000000000..dba4a6fa60 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h new file mode 100644 index 0000000000..0d7eba1009 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_FEATURE_H_ +#define DATASET_ENGINE_GNN_FEATURE_H_ + +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using FeatureType = int16_t; + +class Feature { + public: + // Constructor + // @param FeatureType type_name - feature type + // @param std::shared_ptr value - feature value + Feature(FeatureType type_name, std::shared_ptr value); + + ~Feature() = default; + + // Get feature value + // @return std::shared_ptr *out_value - feature value + const std::shared_ptr Value() const { return value_; } + + // @return NodeIdType - Returned feature type + FeatureType type() const { return type_name_; } + + private: + FeatureType type_name_; + std::shared_ptr value_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc new file mode 100644 index 0000000000..9083eb4c4b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc @@ -0,0 +1,681 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/gnn/graph.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Graph::Graph(std::string dataset_file, int32_t num_workers) + : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { + rnd_.seed(GetSeed()); + MS_LOG(INFO) << "num_workers:" << num_workers; +} + +Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { + auto itr = node_type_map_.find(node_type); + if (itr == node_type_map_.end()) { + std::string err_msg = "Invalid node type:" + std::to_string(node_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } + return Status::OK(); +} + +template +Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, + std::shared_ptr *out) { + if (!type.IsCompatible()) { + RETURN_STATUS_UNEXPECTED("Data type not compatible"); + } + if (data.empty()) { + RETURN_STATUS_UNEXPECTED("Input data is empty"); + } + std::shared_ptr tensor; + size_t m = data.size(); + size_t n = data[0].size(); + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); + auto ptr = tensor->begin(); + for (const auto &id_m : data) { + CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); + for (const auto &id_n : id_m) { + *ptr = id_n; + ptr++; + } + } + tensor->Squeeze(); + *out = std::move(tensor); + return Status::OK(); +} + +template +Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { + if (!data || data->empty()) { + RETURN_STATUS_UNEXPECTED("Input data is empty"); + } + for (std::vector &vec : *data) { + size_t size = vec.size(); + if (size > max_size) { + RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal"); + } else { + for (size_t i = 0; i < (max_size - size); ++i) { + vec.push_back(default_value); + } + } + } + return Status::OK(); +} + +Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { + auto itr = edge_type_map_.find(edge_type); + if (itr == edge_type_map_.end()) { + std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } + return Status::OK(); +} + +Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { + if (edge_list.empty()) { + RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); + } + + std::vector> node_list; + node_list.reserve(edge_list.size()); + for (const auto &edge_id : edge_list) { + auto itr = edge_id_map_.find(edge_id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(edge_id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + std::pair, std::shared_ptr> nodes; + RETURN_IF_NOT_OK(itr->second->GetNode(&nodes)); + node_list.push_back({nodes.first->id(), nodes.second->id()}); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(node_list, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + 
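+  // Illustrative note on the output layout (the node ids below are made up for the example):
+  // neighbor counts differ per node, so the rows are padded to a rectangle before the tensor
+  // is created. With node_list {1, 2}, neighbors {10, 11} for node 1 and {12} for node 2, and
+  // kDefaultNodeId being -1, the returned 2 x 2 tensor would be
+  //   [[10, 11],
+  //    [12, -1]]
+  // ComplementVector below appends the default id until every row reaches max_neighbor_num.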
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); + + std::vector> neighbors; + size_t max_neighbor_num = 0; + neighbors.resize(node_list.size()); + for (size_t i = 0; i < node_list.size(); ++i) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); + max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); + } + + RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); + + return Status::OK(); +} + +Status Graph::CheckSamplesNum(NodeIdType samples_num) { + NodeIdType all_nodes_number = + std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, + [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); + if ((samples_num < 1) || (samples_num > all_nodes_number)) { + std::string err_msg = "Wrong samples number, should be between 1 and " + std::to_string(all_nodes_number) + + ", got " + std::to_string(samples_num); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +Status Graph::CheckNeighborType(NodeType neighbor_type) { + if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + return Status::OK(); +} + +Status Graph::GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), + "The sizes of neighbor_nums and neighbor_types are inconsistent."); + for (const auto &num : neighbor_nums) { + RETURN_IF_NOT_OK(CheckSamplesNum(num)); + } + for (const auto &type : neighbor_types) { + RETURN_IF_NOT_OK(CheckNeighborType(type)); + } + std::vector> neighbors_vec(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr input_node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &input_node)); + neighbors_vec[node_idx].emplace_back(node_list[node_idx]); + std::vector input_list = {node_list[node_idx]}; + for (size_t i = 0; i < neighbor_nums.size(); ++i) { + std::vector neighbors; + neighbors.reserve(input_list.size() * neighbor_nums[i]); + for (const auto &node_id : input_list) { + if (node_id == kDefaultNodeId) { + for (int32_t j = 0; j < neighbor_nums[i]; ++j) { + neighbors.emplace_back(kDefaultNodeId); + } + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); + std::vector out; + RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); + neighbors.insert(neighbors.end(), out.begin(), out.end()); + } + } + neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end()); + input_list = std::move(neighbors); + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors_vec, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples) { + CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); + std::vector shuffled_id(data.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + 
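+  // Explanatory note: this samples without replacement by shuffling an index permutation
+  // (std::iota above, std::shuffle below) and scanning it once, skipping every id found in
+  // exclude_data, until samples_num ids are collected. When the exclusion set is large the
+  // scan can yield fewer than samples_num results, which is why GetNegSampledNeighbors
+  // calls this in a loop until enough samples (or kDefaultNodeId fillers) are gathered.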
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + for (const auto &index : shuffled_id) { + if (exclude_data.find(data[index]) != exclude_data.end()) { + continue; + } + out_samples->emplace_back(data[index]); + if (out_samples->size() >= samples_num) { + break; + } + } + return Status::OK(); +} + +Status Graph::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); + RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); + + std::vector> neg_neighbors_vec; + neg_neighbors_vec.resize(node_list.size()); + for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node)); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors)); + std::unordered_set exclude_nodes; + std::transform(neighbors.begin(), neighbors.end(), + std::insert_iterator>(exclude_nodes, exclude_nodes.begin()), + [](const NodeIdType node) { return node; }); + const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; + neg_neighbors_vec[node_idx].emplace_back(node->id()); + if (all_nodes.size() > exclude_nodes.size()) { + while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { + RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), + &neg_neighbors_vec[node_idx])); + } + } else { + MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() + << " neg_neighbor_type:" << neg_neighbor_type; + // If there are no negative neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); + } + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { + RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); + std::vector> walks; + RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); + RETURN_IF_NOT_OK(CreateTensorByVector({walks}, DataType(DataType::DE_INT32), out)); + return Status::OK(); +} + +Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_node_feature_map_.find(feature_type); + if (itr == default_node_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_edge_feature_map_.find(feature_type); + if (itr == default_edge_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out) { + if (!nodes || nodes->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input nodes is empty"); + } + 
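+  // Explanatory note: for each requested feature type, one output tensor shaped
+  // [number of input node ids, ...feature shape] is assembled below. The default feature
+  // (an all-zero tensor registered at load time) is substituted both for kDefaultNodeId
+  // padding entries and for nodes that do not carry the feature, so every row stays
+  // aligned with the input node tensor.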
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + TensorRow tensors; + for (const auto &f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = nodes->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { + std::shared_ptr feature; + if (*node_itr == kDefaultNodeId) { + feature = default_feature; + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); + if (!node->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(nodes->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); + return Status::OK(); +} + +Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) { + if (!edges || edges->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input edges is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + TensorRow tensors; + for (const auto &f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetEdgeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = edges->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { + std::shared_ptr edge; + RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); + std::shared_ptr feature; + if (!edge->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(edges->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); + return Status::OK(); +} + +Status Graph::Init() { + RETURN_IF_NOT_OK(LoadNodeAndEdge()); + return Status::OK(); +} + +Status Graph::GetMetaInfo(MetaInfo *meta_info) { + meta_info->node_type.resize(node_type_map_.size()); + std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->node_type.begin(), meta_info->node_type.end()); + + meta_info->edge_type.resize(edge_type_map_.size()); + 
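+  // Explanatory note: the extraction below mirrors the node_type block above; std::transform
+  // pulls the keys out of each map, and the feature-type lists are deduplicated further down
+  // with the usual sort + std::unique + erase idiom before being returned to the caller.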
std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(), + [](auto itr) { return itr.first; }); + std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end()); + + for (const auto &node : node_type_map_) { + meta_info->node_num[node.first] = node.second.size(); + } + + for (const auto &edge : edge_type_map_) { + meta_info->edge_num[edge.first] = edge.second.size(); + } + + for (const auto &node_feature : node_feature_map_) { + for (auto type : node_feature.second) { + meta_info->node_feature_type.emplace_back(type); + } + } + std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end()); + meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end()); + + for (const auto &edge_feature : edge_feature_map_) { + for (const auto &type : edge_feature.second) { + meta_info->edge_feature_type.emplace_back(type); + } + } + std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end()); + meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end()); + return Status::OK(); +} + +#ifdef ENABLE_PYTHON +Status Graph::GraphInfo(py::dict *out) { + MetaInfo meta_info; + RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); + (*out)["node_type"] = py::cast(meta_info.node_type); + (*out)["edge_type"] = py::cast(meta_info.edge_type); + (*out)["node_num"] = py::cast(meta_info.node_num); + (*out)["edge_num"] = py::cast(meta_info.edge_num); + (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); + (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); + return Status::OK(); +} +#endif + +Status Graph::LoadNodeAndEdge() { + GraphLoader gl(dataset_file_, num_workers_); + // ask graph_loader to load everything into memory + RETURN_IF_NOT_OK(gl.InitAndLoad()); + // get all maps + RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, + &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, + &default_edge_feature_map_)); + return Status::OK(); +} + +Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { + auto itr = node_id_map_.find(id); + if (itr == node_id_map_.end()) { + std::string err_msg = "Invalid node id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *node = itr->second; + } + return Status::OK(); +} + +Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { + auto itr = edge_id_map_.find(id); + if (itr == edge_id_map_.end()) { + std::string err_msg = "Invalid edge id:" + std::to_string(id); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *edge = itr->second; + } + return Status::OK(); +} + +Graph::RandomWalkBase::RandomWalkBase(Graph *graph) + : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} + +Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, const NodeIdType default_node, + int32_t num_walks, int32_t num_workers) { + CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); + node_list_ = node_list; + if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { + std::string err_msg = "Failed, meta path required between 1 and " + 
std::to_string(kMaxNumWalks) + + ". The size of input path is " + std::to_string(meta_path.size()); + RETURN_STATUS_UNEXPECTED(err_msg); + } + for (const auto &type : meta_path) { + RETURN_IF_NOT_OK(graph_->CheckNeighborType(type)); + } + meta_path_ = meta_path; + if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) { + std::string err_msg = "Failed, step_home_param and step_away_param required greater than " + + std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) + + ", step_away_param: " + std::to_string(step_away_param); + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (default_node < -1) { + std::string err_msg = "Failed, default_node required to be greater or equal to -1."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (num_walks <= 0) { + std::string err_msg = "Failed, num_walks parameter required to be greater than 0"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + if (num_workers <= 0) { + std::string err_msg = "Failed, num_workers parameter required to be greater than 0"; + RETURN_STATUS_UNEXPECTED(err_msg); + } + step_home_param_ = step_home_param; + step_away_param_ = step_away_param; + default_node_ = default_node; + num_walks_ = num_walks; + num_workers_ = num_workers; + return Status::OK(); +} + +Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { + // Simulate a random walk starting from start node. + auto walk = std::vector(1, start_node); // walk is an vector + // walk simulate + while (walk.size() - 1 < meta_path_.size()) { + // current nodE + auto cur_node_id = walk.back(); + std::shared_ptr cur_node; + RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node)); + + // current neighbors + std::vector cur_neighbors; + RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true)); + std::sort(cur_neighbors.begin(), cur_neighbors.end()); + + // break if no neighbors + if (cur_neighbors.empty()) { + break; + } + + // walk by the fist node, then by the previous 2 nodes + std::shared_ptr stochastic_index; + if (walk.size() == 1) { + RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index)); + } else { + NodeIdType prev_node_id = walk[walk.size() - 2]; + RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index)); + } + NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)]; + walk.push_back(next_node_id); + } + + while (walk.size() - 1 < meta_path_.size()) { + walk.push_back(default_node_); + } + + *walk_path = std::move(walk); + return Status::OK(); +} + +Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { + for (int32_t i = 0; i < num_walks_; i++) { + for (const auto &node : node_list_) { + std::vector walk; + RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); + walks->push_back(walk); + } + } + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability) { + // Generate alias nodes + std::shared_ptr node; + graph_->GetNodeByNodeId(node_id, &node); + std::vector neighbors; + RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true)); + std::sort(neighbors.begin(), neighbors.end()); + auto non_normalized_probability = std::vector(neighbors.size(), 1.0); + *node_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +Status Graph::RandomWalkBase::GetEdgeProbability(const 
NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability) { + // Get the alias edge setup lists for a given edge. + std::shared_ptr src_node; + graph_->GetNodeByNodeId(src, &src_node); + std::vector src_neighbors; + RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true)); + + std::shared_ptr dst_node; + graph_->GetNodeByNodeId(dst, &dst_node); + std::vector dst_neighbors; + RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true)); + + std::sort(dst_neighbors.begin(), dst_neighbors.end()); + std::vector non_normalized_probability; + for (const auto &dst_nbr : dst_neighbors) { + if (dst_nbr == src) { + non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + continue; + } + auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr); + if (it != src_neighbors.end()) { + // stay close, this node connect both src and dst + non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight'] + } else { + // step far away + non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight'] + } + } + + *edge_probability = + std::make_shared(GenerateProbability(Normalize(non_normalized_probability))); + return Status::OK(); +} + +StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { + uint32_t K = probability.size(); + std::vector switch_to_large_index(K, 0); + std::vector weight(K, .0); + std::vector smaller; + std::vector larger; + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon); + float accumulate_threshold = 0.0; + for (uint32_t i = 0; i < K; i++) { + float threshold_one = distribution(random_device); + accumulate_threshold += threshold_one; + weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold; + weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i); + } + + while ((!smaller.empty()) && (!larger.empty())) { + uint32_t small = smaller.back(); + smaller.pop_back(); + uint32_t large = larger.back(); + larger.pop_back(); + switch_to_large_index[small] = large; + weight[large] = weight[large] + weight[small] - 1.0; + weight[large] < 1.0 ? 
smaller.push_back(large) : larger.push_back(large); + } + return StochasticIndex(switch_to_large_index, weight); +} + +uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { + auto switch_to_large_index = stochastic_index.first; + auto weight = stochastic_index.second; + const uint32_t size_of_index = switch_to_large_index.size(); + + auto random_device = GetRandomDevice(); + std::uniform_real_distribution<> distribution(0.0, 1.0); + + // Generate random integer between [0, K) + uint32_t random_idx = std::floor(distribution(random_device) * size_of_index); + + if (distribution(random_device) < weight[random_idx]) { + return random_idx; + } + return switch_to_large_index[random_idx]; +} + +template +std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { + float sum_probability = + 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); + if (sum_probability < kGnnEpsilon) { + sum_probability = 1.0; + } + std::vector normalized_probability; + std::transform(non_normalized_probability.begin(), non_normalized_probability.end(), + std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; }); + return normalized_probability; +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h new file mode 100644 index 0000000000..76930d91f2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h @@ -0,0 +1,267 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_GRAPH_H_ +#define DATASET_ENGINE_GNN_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/gnn/graph_loader.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +const float kGnnEpsilon = 0.0001; +const uint32_t kMaxNumWalks = 80; +using StochasticIndex = std::pair, std::vector>; + +struct MetaInfo { + std::vector node_type; + std::vector edge_type; + std::map node_num; + std::map edge_num; + std::vector node_feature_type; + std::vector edge_feature_type; +}; + +class Graph { + public: + // Constructor + // @param std::string dataset_file - + // @param int32_t num_workers - number of parallel threads + Graph(std::string dataset_file, int32_t num_workers); + + ~Graph() = default; + + // Get all nodes from the graph. 
+ // @param NodeType node_type - type of node + // @param std::shared_ptr *out - Returned nodes id + // @return Status - The error code return + Status GetAllNodes(NodeType node_type, std::shared_ptr *out); + + // Get all edges from the graph. + // @param NodeType edge_type - type of edge + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The error code return + Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); + + // Get the node id from the edge. + // @param std::vector edge_list - List of edges + // @param std::shared_ptr *out - Returned node ids + // @return Status - The error code return + Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); + + // All neighbors of the acquisition node. + // @param std::vector node_list - List of nodes + // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is + // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors + // is not enough, fill in tensor as -1. + // @return Status - The error code return + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out); + + // Get sampled neighbors. + // @param std::vector node_list - List of nodes + // @param std::vector neighbor_nums - Number of neighbors sampled per hop + // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param std::shared_ptr *out - Returned neighbor's id. + // @return Status - The error code return + Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out); + + // Get negative sampled neighbors. + // @param std::vector node_list - List of nodes + // @param NodeIdType samples_num - Number of neighbors sampled + // @param NodeType neg_neighbor_type - The type of negative neighbor. + // @param std::shared_ptr *out - Returned negative neighbor's id. + // @return Status - The error code return + Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out); + + // Node2vec random walk. + // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out); + + // Get the feature of a node + // @param std::shared_ptr nodes - List of nodes + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. 
+ // @param TensorRow *out - Returned features + // @return Status - The error code return + Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out); + + // Get the feature of a edge + // @param std::shared_ptr edget - List of edges + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param Tensor *out - Returned features + // @return Status - The error code return + Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, + TensorRow *out); + + // Get meta information of graph + // @param MetaInfo *meta_info - Returned meta information + // @return Status - The error code return + Status GetMetaInfo(MetaInfo *meta_info); + +#ifdef ENABLE_PYTHON + // Return meta information to python layer + Status GraphInfo(py::dict *out); +#endif + + Status Init(); + + private: + class RandomWalkBase { + public: + explicit RandomWalkBase(Graph *graph); + + Status Build(const std::vector &node_list, const std::vector &meta_path, + float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, + int32_t num_walks = 1, int32_t num_workers = 1); + + ~RandomWalkBase() = default; + + Status SimulateWalk(std::vector> *walks); + + private: + Status Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path); + + Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability); + + Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, + std::shared_ptr *edge_probability); + + static StochasticIndex GenerateProbability(const std::vector &probability); + + static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index); + + template + std::vector Normalize(const std::vector &non_normalized_probability); + + Graph *graph_; + std::vector node_list_; + std::vector meta_path_; + float step_home_param_; // Return hyper parameter. Default is 1.0 + float step_away_param_; // Inout hyper parameter. Default is 1.0 + NodeIdType default_node_; + + int32_t num_walks_; // Number of walks per source. Default is 1 + int32_t num_workers_; // The number of worker threads. 
Default is 1 + }; + + // Load graph data from mindrecord file + // @return Status - The error code return + Status LoadNodeAndEdge(); + + // Create Tensor By Vector + // @param std::vector> &data - + // @param DataType type - + // @param std::shared_ptr *out - + // @return Status - The error code return + template + Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); + + // Complete vector + // @param std::vector> *data - To be completed vector + // @param size_t max_size - The size of the completed vector + // @param T default_value - Filled default + // @return Status - The error code return + template + Status ComplementVector(std::vector> *data, size_t max_size, T default_value); + + // Get the default feature of a node + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + // Get the default feature of a edge + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + // Find node object using node id + // @param NodeIdType id - + // @param std::shared_ptr *node - Returned node object + // @return Status - The error code return + Status GetNodeByNodeId(NodeIdType id, std::shared_ptr *node); + + // Find edge object using edge id + // @param EdgeIdType id - + // @param std::shared_ptr *edge - Returned edge object + // @return Status - The error code return + Status GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge); + + // Negative sampling + // @param std::vector &input_data - The data set to be sampled + // @param std::unordered_set &exclude_data - Data to be excluded + // @param int32_t samples_num - + // @param std::vector *out_samples - Sampling results returned + // @return Status - The error code return + Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples); + + Status CheckSamplesNum(NodeIdType samples_num); + + Status CheckNeighborType(NodeType neighbor_type); + + std::string dataset_file_; + int32_t num_workers_; // The number of worker threads + std::mt19937 rnd_; + RandomWalkBase random_walk_; + + std::unordered_map> node_type_map_; + std::unordered_map> node_id_map_; + + std::unordered_map> edge_type_map_; + std::unordered_map> edge_id_map_; + + std::unordered_map> node_feature_map_; + std::unordered_map> edge_feature_map_; + + std::unordered_map> default_node_feature_map_; + std::unordered_map> default_edge_feature_map_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc new file mode 100644 index 0000000000..9d2c6211f4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -0,0 +1,260 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "minddata/dataset/engine/gnn/graph_loader.h" +#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" +#include "minddata/dataset/engine/gnn/local_edge.h" +#include "minddata/dataset/engine/gnn/local_node.h" +#include "minddata/dataset/util/task_manager.h" + +using ShardTuple = std::vector, mindspore::mindrecord::json>>; + +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::MSRStatus; + +GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) + : mr_path_(mr_filepath), + num_workers_(num_workers), + row_id_(0), + shard_reader_(nullptr), + keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} + +Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, + EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, + EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { + for (std::deque> &dq : n_deques_) { + while (dq.empty() == false) { + std::shared_ptr node_ptr = dq.front(); + n_id_map->insert({node_ptr->id(), node_ptr}); + (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); + dq.pop_front(); + } + } + + for (std::deque> &dq : e_deques_) { + while (dq.empty() == false) { + std::shared_ptr edge_ptr = dq.front(); + std::pair, std::shared_ptr> p; + RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); + auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); + CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); + CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); + RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); + RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); + e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ + (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); + dq.pop_front(); + } + } + + for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); + for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); + + MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); + return Status::OK(); +} + +Status GraphLoader::InitAndLoad() { + CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n"); + CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n"); + n_deques_.resize(num_workers_); + e_deques_.resize(num_workers_); + n_feature_maps_.resize(num_workers_); + e_feature_maps_.resize(num_workers_); + default_node_feature_maps_.resize(num_workers_); + default_edge_feature_maps_.resize(num_workers_); + TaskGroup vg; + + shard_reader_ = std::make_unique(); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, + "Fail to open" + mr_path_); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); + 
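+  // Explanatory note: each worker thread appends only into its own slot of n_deques_ /
+  // e_deques_ and its own feature maps (all resized to num_workers_ above), so the load
+  // runs without locks; GetNodesAndEdges() later drains the per-worker deques and
+  // MergeFeatureMaps() folds the per-worker maps into the caller's single view.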
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); + + mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; + for (const std::string &key : keys_) { + if (schema.find(key) == schema.end()) { + RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); + } + } + + // launching worker threads + for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { + RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); + } + // wait for threads to finish and check its return code + vg.join_all(Task::WaitFlag::kBlocking); + RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny()); + return Status::OK(); +} + +Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *node, NodeFeatureMap *feature_map, + DefaultNodeFeatureMap *default_feature) { + NodeIdType node_id = col_jsn["first_id"]; + NodeType node_type = static_cast(col_jsn["type"]); + (*node) = std::make_shared(node_id, node_type); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); + + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[node_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *edge, EdgeFeatureMap *feature_map, + DefaultEdgeFeatureMap *default_feature) { + EdgeIdType edge_id = col_jsn["first_id"]; + EdgeType edge_type = static_cast(col_jsn["type"]); + NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; + std::shared_ptr src = std::make_shared(src_id, -1); + std::shared_ptr dst = std::make_shared(dst_id, -1); + (*edge) = std::make_shared(edge_id, edge_type, src, dst); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[edge_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::shared_ptr *tensor) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, 
&col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, + std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); + return Status::OK(); +} + +Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::vector *indices) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); + + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + + for (int i = 0; i < n_bytes; i += col_type_size) { + int32_t feature_ind = -1; + if (col_type == mindrecord::ColumnInt32) { + feature_ind = *(reinterpret_cast(data + i)); + } else if (col_type == mindrecord::ColumnInt64) { + feature_ind = *(reinterpret_cast(data + i)); + } else { + RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); + } + if (feature_ind >= 0) indices->push_back(feature_ind); + } + return Status::OK(); +} + +Status GraphLoader::WorkerEntry(int32_t worker_id) { + // Handshake + TaskManager::FindMe()->Post(); + auto ret = shard_reader_->GetNextById(row_id_++, worker_id); + ShardTuple rows = ret.second; + while (rows.empty() == false) { + RETURN_IF_INTERRUPTED(); + for (const auto &tupled_row : rows) { + std::vector col_blob = std::get<0>(tupled_row); + mindrecord::json col_jsn = std::get<1>(tupled_row); + std::string attr = col_jsn["attribute"]; + if (attr == "n") { + std::shared_ptr node_ptr; + RETURN_IF_NOT_OK(LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), + &default_node_feature_maps_[worker_id])); + n_deques_[worker_id].emplace_back(node_ptr); + } else if (attr == "e") { + std::shared_ptr edge_ptr; + RETURN_IF_NOT_OK(LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), + &default_edge_feature_maps_[worker_id])); + e_deques_[worker_id].emplace_back(edge_ptr); + } else { + MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; + } + } + auto rc = shard_reader_->GetNextById(row_id_++, worker_id); + rows = rc.second; + } + return Status::OK(); +} + +void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, + DefaultNodeFeatureMap *default_node_feature_map, + DefaultEdgeFeatureMap *default_edge_feature_map) { + for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { + for (auto &m : n_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); + } + for (auto &m : e_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); + } + for (auto &m : default_node_feature_maps_[wkr_id]) { + (*default_node_feature_map)[m.first] = m.second; + } + for (auto &m : default_edge_feature_maps_[wkr_id]) { + (*default_edge_feature_map)[m.first] = m.second; + } + } + n_feature_maps_.clear(); + e_feature_maps_.clear(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h new file mode 100644 index 0000000000..f7f9245b8a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -0,0 +1,129 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_reader.h" +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::ShardReader; +using NodeIdMap = std::unordered_map>; +using EdgeIdMap = std::unordered_map>; +using NodeTypeMap = std::unordered_map>; +using EdgeTypeMap = std::unordered_map>; +using NodeFeatureMap = std::unordered_map>; +using EdgeFeatureMap = std::unordered_map>; +using DefaultNodeFeatureMap = std::unordered_map>; +using DefaultEdgeFeatureMap = std::unordered_map>; + +// this class interfaces with the underlying storage format (mindrecord) +// it returns raw nodes and edges via GetNodesAndEdges +// it is then the responsibility of graph to construct itself based on the nodes and edges +// if needed, this class could become a base where each derived class handles a specific storage format +class GraphLoader { + public: + explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); + + ~GraphLoader() = default; + // Init mindrecord and load everything into memory multi-threaded + // @return Status - the status code + Status InitAndLoad(); + + // this function will query mindrecord and construct all nodes and edges + // nodes and edges are added to map without any connection. That's because there nodes and edges are read in + // random order. src_node and dst_node in Edge are node_id only with -1 as type. 
+ // features attached to each node and edge are expected to be filled correctly + Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, + DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + + private: + // + // worker thread that reads mindrecord file + // @param int32_t worker_id - id of each worker + // @return Status - the status code + Status WorkerEntry(int32_t worker_id); + + // Load a node based on 1 row of mindrecord, returns a shared_ptr + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *node - return value + // @param NodeFeatureMap *feature_map - + // @param DefaultNodeFeatureMap *default_feature - + // @return Status - the status code + Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, + NodeFeatureMap *feature_map, DefaultNodeFeatureMap *default_feature); + + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected + // @param FeatureMap *feature_map + // @param DefaultEdgeFeatureMap *default_feature - + // @return Status - the status code + Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, + EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); + + // @param std::string key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::vector *ind - return value, list of feature index in int32_t + // @return Status - the status code + Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, + std::vector *ind); + + // @param std::string &key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *tensor - return value feature tensor + // @return Status - the status code + Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, + std::shared_ptr *tensor); + + // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 + void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + + const int32_t num_workers_; + std::atomic_int row_id_; + std::string mr_path_; + std::unique_ptr shard_reader_; + std::vector>> n_deques_; + std::vector>> e_deques_; + std::vector n_feature_maps_; + std::vector e_feature_maps_; + std::vector default_node_feature_maps_; + std::vector default_edge_feature_maps_; + const std::vector keys_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc new file mode 100644 index 0000000000..642c73eed3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
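The per-worker deques declared above (n_deques_, e_deques_) are the reason WorkerEntry needs no locks: every worker writes only its own slot, and GetNodesAndEdges merges the slots on a single thread afterwards. A minimal standalone sketch of that scheme, with plain std::thread standing in for the TaskGroup workers and int standing in for the node pointers:

// Standalone sketch (not part of the patch): per-worker accumulation with a
// single-threaded merge. std::thread stands in for the TaskGroup workers and
// int stands in for std::shared_ptr<Node>.
#include <deque>
#include <iostream>
#include <thread>
#include <unordered_set>
#include <vector>

int main() {
  const int num_workers = 4;
  std::vector<std::deque<int>> per_worker(num_workers);  // analogue of n_deques_

  std::vector<std::thread> workers;
  for (int w = 0; w < num_workers; ++w) {
    workers.emplace_back([w, num_workers, &per_worker]() {
      // Each worker writes only its own deque, so no lock is needed.
      for (int id = w; id < 20; id += num_workers) per_worker[w].push_back(id);
    });
  }
  for (auto &t : workers) t.join();

  // Single-threaded merge, mirroring GetNodesAndEdges.
  std::unordered_set<int> id_map;
  for (auto &dq : per_worker) {
    while (!dq.empty()) {
      id_map.insert(dq.front());
      dq.pop_front();
    }
  }
  std::cout << "loaded " << id_map.size() << " ids\n";  // prints: loaded 20 ids
  return 0;
}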
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/local_edge.h" + +#include + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : Edge(id, type, src_node, dst_node) {} + +Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h new file mode 100644 index 0000000000..d112972f8f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
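LocalEdge::UpdateFeature above deliberately rejects a second write for an already-registered feature type. A standalone sketch of the same insert-once map semantics, using a simplified Feature struct and std::unordered_map::try_emplace (the patch spells the check out with find instead):

// Standalone sketch (not part of the patch): insert-once feature registration.
// Feature here is a simplified stand-in for the dataset::gnn::Feature class.
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

using FeatureType = int16_t;

struct Feature {
  explicit Feature(FeatureType type) : type_(type) {}
  FeatureType type() const { return type_; }
  FeatureType type_;
};

// Returns false when the feature type is already present, mirroring the
// "Feature already exists" branch of LocalEdge::UpdateFeature.
bool UpdateFeature(std::unordered_map<FeatureType, std::shared_ptr<Feature>> *features,
                   const std::shared_ptr<Feature> &feature) {
  return features->try_emplace(feature->type(), feature).second;
}

int main() {
  std::unordered_map<FeatureType, std::shared_ptr<Feature>> features;
  std::cout << UpdateFeature(&features, std::make_shared<Feature>(1)) << "\n";  // 1: inserted
  std::cout << UpdateFeature(&features, std::make_shared<Feature>(1)) << "\n";  // 0: rejected
  return 0;
}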
+ */ +#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalEdge : public Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); + + ~LocalEdge() = default; + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + std::unordered_map> features_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc new file mode 100644 index 0000000000..8eaf9bb716 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
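LocalEdge instances are first created against placeholder endpoints: LoadEdge attaches nodes that carry only an id and type -1, and GraphLoader::GetNodesAndEdges later rebinds them to the fully-loaded node objects. A standalone sketch of that two-phase wiring, with simplified Node/Edge structs in place of the gnn classes:

// Standalone sketch (not part of the patch): placeholder endpoints rebound to
// real nodes once the node-id map is complete.
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>

using NodeIdType = int32_t;

struct Node {
  Node(NodeIdType id, int8_t type) : id(id), type(type) {}
  NodeIdType id;
  int8_t type;
};

struct Edge {
  std::shared_ptr<Node> src, dst;
};

int main() {
  // Phase 1: the edge row is read with id-only stubs (type -1), like LoadEdge.
  Edge e{std::make_shared<Node>(7, -1), std::make_shared<Node>(9, -1)};

  // Phase 2: rebind against the fully-loaded nodes, like GetNodesAndEdges.
  std::unordered_map<NodeIdType, std::shared_ptr<Node>> n_id_map = {
    {7, std::make_shared<Node>(7, 0)}, {9, std::make_shared<Node>(9, 1)}};
  e.src = n_id_map.at(e.src->id);
  e.dst = n_id_map.at(e.dst->id);
  std::cout << "src type: " << static_cast<int>(e.src->type)
            << " dst type: " << static_cast<int>(e.dst->type) << "\n";  // 0 and 1
  return 0;
}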
+ */ +#include "minddata/dataset/engine/gnn/local_node.h" + +#include +#include +#include + +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } + +Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, bool exclude_itself) { + std::vector neighbors; + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + if (exclude_itself) { + neighbors.resize(itr->second.size()); + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + [](const std::shared_ptr node) { return node->id(); }); + } else { + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); + } + } else { + MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + if (!exclude_itself) { + neighbors.emplace_back(id_); + } + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out) { + std::vector shuffled_id(neighbors.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + int32_t num = std::min(samples_num, static_cast(neighbors.size())); + for (int32_t i = 0; i < num; ++i) { + out->emplace_back(neighbors[shuffled_id[i]]->id()); + } + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) { + std::vector neighbors; + neighbors.reserve(samples_num); + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + while (neighbors.size() < samples_num) { + RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); + } + } else { + MS_LOG(DEBUG) << "There are no neighbors. 
node_id:" << id_ << " neighbor_type:" << neighbor_type; + // If there are no neighbors, they are filled with kDefaultNodeId + for (int32_t i = 0; i < samples_num; ++i) { + neighbors.emplace_back(kDefaultNodeId); + } + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::AddNeighbor(const std::shared_ptr &node) { + auto itr = neighbor_nodes_.find(node->type()); + if (itr != neighbor_nodes_.end()) { + itr->second.push_back(node); + } else { + neighbor_nodes_[node->type()] = {node}; + } + return Status::OK(); +} + +Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h new file mode 100644 index 0000000000..9c122931e7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalNode : public Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + LocalNode(NodeIdType id, NodeType type); + + ~LocalNode() = default; + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Get the all neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) override; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) override; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + Status AddNeighbor(const std::shared_ptr &node) override; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out); + + std::mt19937 rnd_; + std::unordered_map> features_; + std::unordered_map>> neighbor_nodes_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h new file mode 100644 index 0000000000..a7c803fee2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_GNN_NODE_H_ +#define DATASET_ENGINE_GNN_NODE_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using NodeType = int8_t; +using NodeIdType = int32_t; + +constexpr NodeIdType kDefaultNodeId = -1; + +class Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} + + virtual ~Node() = default; + + // @return NodeIdType - Returned node id + NodeIdType id() const { return id_; } + + // @return NodeIdType - Returned node type + NodeType type() const { return type_; } + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get the all neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector *out_neighbors, + bool exclude_itself = false) = 0; + + // Get the sampled neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + std::vector *out_neighbors) = 0; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + virtual Status AddNeighbor(const std::shared_ptr &node) = 0; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + NodeIdType id_; + NodeType type_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h new file mode 100644 index 0000000000..cee0b7abf3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/jagged_connector.h @@ -0,0 +1,88 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
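One contract worth noting in the Node interface above: unless exclude_itself is set, GetAllNeighbors places the node's own id at position 0 and the neighbor ids after it (see LocalNode::GetAllNeighbors). A standalone sketch of that output layout:

// Standalone sketch (not part of the patch) of the GetAllNeighbors output
// layout: self id first unless exclude_itself is set.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> AllNeighbors(int32_t self_id, const std::vector<int32_t> &neighbor_ids,
                                  bool exclude_itself) {
  std::vector<int32_t> out;
  if (!exclude_itself) out.push_back(self_id);  // self id leads the list
  out.insert(out.end(), neighbor_ids.begin(), neighbor_ids.end());
  return out;
}

int main() {
  for (int32_t id : AllNeighbors(3, {5, 8}, /*exclude_itself=*/false)) std::cout << id << " ";
  std::cout << "\n";  // prints: 3 5 8
  return 0;
}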
+ */ +#ifndef DATASET_ENGINE_JAGGED_CONNECTOR_H_ +#define DATASET_ENGINE_JAGGED_CONNECTOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/engine/data_buffer.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { +class JaggedConnector : public Connector> { + public: + JaggedConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity) + : Connector>(num_producers, num_consumers, queue_capacity) { + for (int i = 0; i < num_producers; i++) { + is_queue_finished_.push_back(false); + } + } + + ~JaggedConnector() = default; + + Status Add(int32_t worker_d, std::unique_ptr &&element) noexcept { + return Connector>::Push(worker_d, std::move(element)); + } + + Status Pop(int32_t worker_id, std::unique_ptr *result) noexcept override { + { + MS_ASSERT(worker_id < num_consumers_); + std::unique_lock lock(m_); + RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; })); + if (is_queue_finished_[pop_from_]) { + std::string errMsg = "ERROR: popping from a finished queue in JaggedConnector"; + RETURN_STATUS_UNEXPECTED(errMsg); + } + + RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result)); + if ((*result)->eoe()) { + is_queue_finished_[pop_from_] = true; + } + + for (int offset = 1; offset <= num_producers_; offset++) { + int32_t nextQueueIndex = (pop_from_ + offset) % num_producers_; + if (is_queue_finished_[nextQueueIndex] == false) { + pop_from_ = nextQueueIndex; + break; + } + } + + expect_consumer_ = (expect_consumer_ + 1) % num_consumers_; + } + + cv_.NotifyAll(); + return Status::OK(); + } + + void DoReset() { + for (int i = 0; i < is_queue_finished_.size(); i++) { + is_queue_finished_[i] = false; + } + + Connector>::Reset(); + } + + private: + std::vector is_queue_finished_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_JAGGED_CONNECTOR_H_ diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/opt/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc new file mode 100644 index 0000000000..d8ce2dd863 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
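The Pop above is what makes the connector "jagged": producers may finish at different times, and the queue cursor skips any producer whose EOE has already been seen instead of blocking on it. A standalone sketch of just that advance rule:

// Standalone sketch (not part of the patch) of the cursor-advance rule in
// JaggedConnector::Pop: wrap around the producer queues and park on the first
// one that has not yet delivered its EOE.
#include <iostream>
#include <vector>

int NextUnfinishedQueue(int pop_from, const std::vector<bool> &finished) {
  int n = static_cast<int>(finished.size());
  for (int offset = 1; offset <= n; ++offset) {
    int candidate = (pop_from + offset) % n;
    if (!finished[candidate]) return candidate;
  }
  return pop_from;  // every queue has finished; nothing left to serve
}

int main() {
  std::vector<bool> finished = {false, true, false};
  std::cout << NextUnfinishedQueue(0, finished) << "\n";  // 2: queue 1 is skipped
  std::cout << NextUnfinishedQueue(2, finished) << "\n";  // 0: wraps past queue 1
  return 0;
}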
+ */
+
+#include
+#include "minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
+#include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status TensorOpFusionPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
+  if (modified == nullptr) {
+    RETURN_STATUS_UNEXPECTED("modified is nullptr");
+  }
+  // Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp.
+  // This could be abstracted into a more general member function that finds any
+  // pattern, expressed by regular expressions for instance, backed by a list of
+  // optimisation policies. For now, just this lambda.
+  auto FindPattern = [](auto &tfuncs) {
+    auto it =
+      std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; });
+    if (it != tfuncs.end()) {
+      auto next = it + 1;
+      if (next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) {
+        return it;
+      }
+    }
+    return tfuncs.end();
+  };
+
+  auto &tfuncs = node->TFuncs();
+  auto it = FindPattern(tfuncs);
+  if (it != tfuncs.end()) {
+    auto next = it + 1;
+    auto op = static_cast<RandomCropAndResizeOp *>(next->get());
+    *it = std::static_pointer_cast<TensorOp>(std::make_shared<RandomCropDecodeResizeOp>(*op));
+    tfuncs.erase(next);
+    *modified = true;
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h
new file mode 100644
index 0000000000..a109af396c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
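The fusion rule above is a plain adjacency test over MapOp's tensor-function list: find a decode op, check that the very next op is a random-crop-and-resize, and splice a fused op over the pair. A standalone sketch of the same test and splice, with op names as strings in place of TensorOp pointers:

// Standalone sketch (not part of the patch): the Decode + RandomCropAndResize
// adjacency test and splice over a flat op list.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ops = {"Decode", "RandomCropAndResize", "Normalize"};

  auto it = std::find(ops.begin(), ops.end(), "Decode");
  if (it != ops.end() && std::next(it) != ops.end() && *std::next(it) == "RandomCropAndResize") {
    *it = "RandomCropDecodeResize";  // fused op replaces the first of the pair
    ops.erase(std::next(it));        // the second op is removed
  }

  for (const auto &name : ops) std::cout << name << " ";
  std::cout << "\n";  // prints: RandomCropDecodeResize Normalize
  return 0;
}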
+ */
+#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_
+#define DATASET_TENSOR_OP_FUSION_PASS_H_
+
+#include
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+/// \class TensorOpFusionPass tensor_op_fusion_pass.h
+/// \brief An optional optimization pass identifying and fusing
+///     tensor ops within MapOp
+class TensorOpFusionPass : public NodePass {
+  /// \brief Identifies and fuses tensor ops within MapOp
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicates whether the node was modified
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_TENSOR_OP_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
new file mode 100644
index 0000000000..4a8bbaf38f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.cc
@@ -0,0 +1,248 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/engine/datasetops/batch_op.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
+#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/engine/datasetops/device_queue_op.h"
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include "minddata/dataset/engine/datasetops/project_op.h"
+#include "minddata/dataset/engine/datasetops/rename_op.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
+#include "minddata/dataset/engine/datasetops/skip_op.h"
+#include "minddata/dataset/engine/datasetops/shuffle_op.h"
+#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
+#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
+#include "minddata/dataset/engine/datasetops/source/coco_op.h"
+#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
+#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
+#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
+#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+#include "minddata/dataset/engine/datasetops/source/voc_op.h"
+#ifdef ENABLE_PYTHON
+#include "minddata/dataset/engine/datasetops/filter_op.h"
+#include "minddata/dataset/engine/datasetops/source/generator_op.h"
+#endif
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/engine/datasetops/take_op.h"
+#include "minddata/dataset/engine/datasetops/zip_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Driver method for TreePass
+Status TreePass::Run(ExecutionTree *tree, bool *modified) {
+  if (tree == nullptr || modified ==
nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to TreePass"); + } + return this->RunOnTree(tree, modified); +} + +// Driver method for NodePass +Status NodePass::Run(ExecutionTree *tree, bool *modified) { + if (tree == nullptr || modified == nullptr) { + return Status(StatusCode::kUnexpectedError, "Null pointer passed to NodePass"); + } + std::shared_ptr root = tree->root(); + if (traversalOrder_ == Order::DFS) { + // DFS + return DFSNodeVisit(root, modified); + } else if (traversalOrder_ == Order::BFS) { + // BFS + return BFSNodeVisit(root, modified); + } + return Status::OK(); +} + +// Helper function to perform DFS visit +Status NodePass::DFSNodeVisit(std::shared_ptr node, bool *modified) { + RETURN_IF_NOT_OK(node->PreAccept(this, modified)); + for (const auto &c : node->Children()) { + RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified)); + } + return node->Accept(this, modified); +} + +// Helper function to perform BFS visit +Status NodePass::BFSNodeVisit(std::shared_ptr root, bool *modified) { + // Initialize bfs queue with root + std::queue> bfsQueue; + bfsQueue.push(root); + + // BFS loop + while (!bfsQueue.empty()) { + // Pop the front of the bfs queue + auto curNode = bfsQueue.front(); + bfsQueue.pop(); + + // Run node pass + RETURN_IF_NOT_OK(curNode->Accept(this, modified)); + + // Push children into bfs queue + for (const auto &c : curNode->Children()) { + bfsQueue.push(c); + } + } + return Status::OK(); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +#ifdef ENABLE_PYTHON +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} +#endif + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, 
bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::RunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return RunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} + +Status NodePass::PreRunOnNode(std::shared_ptr node, bool *modified) { + // Fallback to base class visitor by default + return PreRunOnNode(std::static_pointer_cast(node), modified); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h new file mode 100644 index 0000000000..845ab34d66 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pass.h @@ -0,0 +1,213 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_H_ +#define DATASET_ENGINE_OPT_PASS_H_ + +#include +#include + +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class BatchOp; + +class MapOp; + +class ProjectOp; + +class RenameOp; + +class SkipOp; + +class ShuffleOp; + +class MindRecordOp; + +class TFReaderOp; + +#ifdef ENABLE_PYTHON +class FilterOp; + +class GeneratorOp; +#endif + +class RandomDataOp; + +class RepeatOp; + +class TakeOp; + +class ZipOp; + +class DeviceQueueOp; + +class ImageFolderOp; + +class CacheOp; + +class MnistOp; + +class ManifestOp; + +class CifarOp; + +class VOCOp; + +class CocoOp; + +class CelebAOp; + +class CacheMergeOp; + +class CacheLookupOp; + +// The base class Pass is the basic unit of tree transformation. +// The actual implementation of the passes will be derived from here. +class Pass : public std::enable_shared_from_this { + public: + // Run the transformation pass against the execution tree. + // @param tree - Pointer to the execution tree to be transformed. + // @param modified - Pointer to the modified flag, + virtual Status Run(ExecutionTree *tree, bool *modified) = 0; +}; + +// TreePass is a basic Pass class which performs transformation on ExecutionTree directly. +class TreePass : public Pass { + public: + /// \brief Run the transformation pass against the execution tree. + /// \param[inout] tree Pointer to the execution tree to be transformed. + /// \param[inout] modified Indicate if the tree was modified + Status Run(ExecutionTree *tree, bool *modified) final; + + /// \brief Derived classes may implement the runOnTree function to implement tree transformation. + /// "modified" flag needs to be set to true if tree is modified during the pass execution. + /// \param[inout] tree The tree to operate on. + /// \param[inout] Indicate of the tree was modified. + /// \return Status The error code return + virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); } +}; + +// NodePass is a basic Pass class which performs transformation on Node visiting. +// NodePass implements Visitor design pattern. +class NodePass : public Pass { + public: + // Tree traversal order + enum Order { DFS, BFS }; + + // Constructor + // Default DFS traversal + explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; } + + ~NodePass() = default; + + /// \brief Run the transformation pass against the execution tree + /// \param[inout] tree Pointer to the execution tree to be transformed + /// \param[inout] modified Indicator if the tree was changed + Status Run(ExecutionTree *tree, bool *modified) final; + + /// \brief Derived classes may implement the PreRunOnNode function to implement any initial visit work on the way down + /// a tree traversal. 
"modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all + /// \return Status The error code return + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + /// \brief Derived classes may implement the RunOnNode function to implement node level tree transformation + /// "modified" flag needs to be set to true if tree is modified during the pass execution + /// \param[in] node The node being visited + /// \param[out] modified Indicator if the node was changed at all. + /// \return Status The error code return + virtual Status RunOnNode(std::shared_ptr node, bool *modified) { return Status::OK(); } + + // Visit methods to be overridden. + // Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode + // of its own type and override "Accept" from DatasetOp. + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + +#ifdef ENABLE_PYTHON + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); +#endif + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status RunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + virtual Status PreRunOnNode(std::shared_ptr node, bool *modified); + + private: + // Helper function to perform DFS visit + Status DFSNodeVisit(std::shared_ptr node, bool *modified); + + // Helper function to perform BFS visit + Status BFSNodeVisit(std::shared_ptr root, bool *modified); + + // Tree traversal order of the NodePass + Order traversalOrder_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc new file mode 100644 index 0000000000..59a3f71c53 --- /dev/null +++ 
b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.cc
@@ -0,0 +1,161 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "minddata/dataset/engine/opt/post/repeat_pass.h"
+#include "minddata/dataset/engine/datasetops/repeat_op.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
+#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+RepeatPass::RepeatPass() : is_repeated_(false), is_merge_(false), nested_repeats_(0), cache_lookup_(nullptr) {}
+
+// Identifies the subtree below this node as being in a repeated path of the tree.
+Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
+  // If we are already repeated, then this is a nested repeat.
+  if (is_repeated_) {
+    nested_repeats_++;
+  }
+  is_repeated_ = true;
+  return Status::OK();
+}
+
+// Identifies the subtree below this node as being in a cache merge path
+Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
+  // Turn on the flag that we're under a merge op
+  is_merge_ = true;
+  return Status::OK();
+}
+
+// Hooks up any identified eoe nodes under this repeat.
+Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
+  // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
+  std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
+  while (leaf_op != nullptr) {
+    node->AddToEoeList(leaf_op);
+    leaf_op = PopFromEOEOpStack();
+  }
+
+  // If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup op
+  // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
+  if (is_merge_ && cache_lookup_) {
+    cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
+    node->AddToEoeList(std::move(cache_lookup_));
+  }
+
+  // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
+  // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
+  if (nested_repeats_ > 0) {
+    node->set_control_flag(DatasetOp::kDeOpRepeated);
+    AddToEOEOpStack(node);
+    nested_repeats_--;
+  }
+
+  // If we are not nested, or we were the top-most repeat, now we clear the flag
+  if (nested_repeats_ == 0) {
+    is_repeated_ = false;
+  }
+
+  return Status::OK();
+}
+
+// CacheOp removes previous leaf ops and replaces them with itself
+Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
+  if (is_repeated_) {
+    node->set_control_flag(DatasetOp::kDeOpRepeated);
+    // If we are a cache within a repeat path of the tree, then there will be
+    // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
+    // repeat or epoch ctrl operators can work with them for repeat activity during runtime.
+ // However, since a cache is present: + // - unflag those ops as being repeated ops + // - remove them from the eoe op stack so that repeat op above in the tree won't know about them + // - add ourself (the cache op), as an eoe op + // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead + // the repeating behaviours shall be invoked against the cache op. + std::shared_ptr leaf_op = PopFromEOEOpStack(); + while (leaf_op != nullptr) { + leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat); + leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated); + leaf_op = PopFromEOEOpStack(); + } + AddToEOEOpStack(std::static_pointer_cast(node)); + } + + return Status::OK(); +} + +// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up +// for use with a controlling repeat above it. +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // If we are in a repeat path, then set our repeated flag + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + + // if we are a leaf node then save ourself in a stack for the repeat operator above us + if (node->IsLeaf()) { + AddToEOEOpStack(node); + } + } + return Status::OK(); +} + +// Turns off the tracking for operations under merge op +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + // Setting the flag is needed since we didn't call the base class DatasetOp version + if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated); + is_merge_ = false; + cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed + return Status::OK(); +} + +// Saves the lookup up in case it needs to be referenced by a repeat +Status RepeatPass::RunOnNode(std::shared_ptr node, bool *modified) { + if (!node->IsLeaf()) { + // By definition, the CacheLookup must be a leaf op. Make that clear here. + RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!"); + } + + // If we are in a repeat path already, then there must be a repeat above the merge op + // In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here. + if (is_repeated_) { + node->set_control_flag(DatasetOp::kDeOpRepeated); + AddToEOEOpStack(node); + } else { + // save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we + // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself + // into the pass so that the decision can be made during the processing of the cache miss leg of the merge. 
+ cache_lookup_ = std::static_pointer_cast(node); + } + return Status::OK(); +} + +// Adds an operator to the eoe operator stack save area +void RepeatPass::AddToEOEOpStack(std::shared_ptr dataset_op) { eoe_stack_.push(dataset_op); } + +// Pops an operator from the eoe operator stack save area +std::shared_ptr RepeatPass::PopFromEOEOpStack() { + std::shared_ptr top_op = nullptr; + if (!eoe_stack_.empty()) { + top_op = eoe_stack_.top(); + eoe_stack_.pop(); + } + return top_op; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h new file mode 100644 index 0000000000..9b733e2329 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/opt/post/repeat_pass.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ +#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_ + +#include +#include +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +/// \class RepeatPass repeat_pass.h +/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references +/// to the eoe-producing (typically leaf) nodes underneath it. +class RepeatPass : public NodePass { + public: + /// \brief Constructor + RepeatPass(); + + /// \brief Identifies the subtree below this node as being in a repeated path of the tree. + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Identifies the subtree below this node as being in a cache merge path + /// \param[in] node The node being visited + /// \param[inout] modified Indicator if the node was changed at all + /// \return Status The error code return + Status PreRunOnNode(std::shared_ptr node, bool *modified) override; + + /// \brief Hooks up any identified eoe nodes under this repeat. 
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
+
+  /// \brief CacheOp removes previous leaf ops and replaces them with itself
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
+
+  /// \brief Turns off the tracking for operations under merge op
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
+
+  /// \brief Saves the lookup op in case it needs to be referenced by a repeat
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
+
+  /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
+  ///     for use with a controlling repeat above it.
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
+
+ private:
+  /// \brief Adds an operator to the eoe operator stack save area
+  /// \param dataset_op The dataset op to add to the eoe stack
+  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
+
+  /// \brief Pops an operator from the eoe operator stack save area
+  /// \return shared_ptr to the popped operator
+  std::shared_ptr<DatasetOp> PopFromEOEOpStack();
+
+  bool is_repeated_;        // T/F if we are processing under a repeat
+  bool is_merge_;           // T/F if we are processing under a cache merge op
+  int32_t nested_repeats_;  // A counter for nested repeats
+  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;  // A save area for leaf/eoe ops
+  std::shared_ptr<CacheLookupOp> cache_lookup_;       // A save area for a cache lookup op
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
new file mode 100644
index 0000000000..09b5f14a17
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
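RepeatPass above is essentially a stack discipline: leaf (eoe-producing) ops encountered inside a repeated subtree are pushed into eoe_stack_, and the RepeatOp visit that closes the subtree pops them all into its eoe list. A standalone sketch of that hand-off, with op names as strings in place of DatasetOp pointers:

// Standalone sketch (not part of the patch): the eoe-op save-area hand-off
// between leaf visits and the enclosing RepeatOp visit.
#include <iostream>
#include <stack>
#include <string>

int main() {
  std::stack<std::string> eoe_stack;

  // DFS under a RepeatOp reaches two eoe-producing leaves; each is pushed.
  eoe_stack.push("TFReaderOp");
  eoe_stack.push("ImageFolderOp");

  // The RepeatOp visit drains the stack into its eoe list, mirroring the
  // PopFromEOEOpStack() loop in RepeatPass::RunOnNode.
  while (!eoe_stack.empty()) {
    std::cout << "AddToEoeList(" << eoe_stack.top() << ")\n";
    eoe_stack.pop();
  }
  return 0;
}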
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
new file mode 100644
index 0000000000..09b5f14a17
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.cc
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "minddata/dataset/engine/opt/pre/cache_pass.h"
+#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
+#include "minddata/dataset/engine/datasetops/source/generator_op.h"
+#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
+#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/voc_op.h"
+#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
+#include "minddata/dataset/engine/datasetops/source/coco_op.h"
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
+#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Constructor
+CachePass::CachePass(CacheTransformPass *transform_pass)
+    : transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}
+
+// Identifies the subtree below this node as a cached descendant tree.
+Status CachePass::PreRunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
+  if (is_caching_) {
+    RETURN_STATUS_UNEXPECTED("Nested cache operations are not supported!");
+  }
+  is_caching_ = true;
+  return Status::OK();
+}
+
+// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
+// transformation
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  is_caching_ = false;  // We are no longer in a cache subtree; clear the flag.
+  if (leaf_op_) {
+    MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
+    // Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
+    // using base class pointers.
+    transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
+  } else {
+    // If there was no leaf_op set, then this is a non-mappable scenario.
+
+    if (sampler_) {
+      // Grab the sampler that was saved from the leaf and plug it into the cache op
+      node->SetSampler(std::move(sampler_));
+      MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
+    } else {
+      // We're a cache op but no sampler was saved from the leaf, so create a default sampler
+      int64_t num_samples = 0;
+      int64_t start_index = 0;
+      sampler_ = std::make_shared(num_samples, start_index);
+      node->SetSampler(std::move(sampler_));
+      MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
+    }
+
+    // Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
+    uint32_t cache_crc = DatasetOp::GenerateCRC(node);
+    RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
+  }
+
+  return Status::OK();
+}
+
+// Common code for mappable leaf setup.
+Status CachePass::MappableCacheLeafSetup(std::shared_ptr leaf_op) {
+  // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
+  if (is_caching_ && leaf_op_) {
+    RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
+  }
+
+  // If we are a leaf in the caching path, then save this leaf.
+  if (is_caching_) {
+    MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
+    leaf_op_ = std::move(leaf_op);
+  }
+  return Status::OK();
+}
+
+// Common code for non-mappable leaf setup.
+Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr leaf_op) {
+  // If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
+  if (is_caching_ && leaf_op_) {
+    RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
+  }
+
+  // A sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf
+  // and save it for use by the cache op in the ascendant tree.
+  if (is_caching_) {
+    RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
+    MS_LOG(DEBUG) << "Cache transform pass: Non-mappable leaf in a cache descendant tree detected";
+  } else {
+    // If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
+    // remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based)
+    std::shared_ptr sampler_from_leaf;
+    RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
+  }
+  return Status::OK();
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  if (is_caching_) {
+    // If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
+    // TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
+    node->MakeSimpleProducer();
+  }
+  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return NonMappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+
+// Perform leaf node cache transform identifications
+Status CachePass::RunOnNode(std::shared_ptr node, bool *modified) {
+  return MappableCacheLeafSetup(std::static_pointer_cast(node));
+}
+}  // namespace dataset
+}  // namespace mindspore
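The two leaf-setup helpers above encode a single decision: a mappable (random-access) leaf is remembered whole so the transform pass can rewire it later, while a non-mappable (streaming) leaf only surrenders its sampler to the downstream cache. A condensed sketch of that dispatch, using hypothetical stand-in types (`Leaf` and `Sampler` are not the real MindSpore classes):

#include <memory>
#include <utility>

struct Sampler {};
struct Leaf {
  bool mappable;                     // true: random-access source (e.g. an image-folder style leaf)
  std::shared_ptr<Sampler> sampler;  // present on both kinds of leaves
};

struct CachePassState {
  bool is_caching = false;
  std::shared_ptr<Leaf> saved_leaf;        // mappable case: leaf is rewired later
  std::shared_ptr<Sampler> saved_sampler;  // non-mappable case: sampler moves to the cache

  void VisitLeaf(const std::shared_ptr<Leaf> &leaf) {
    if (!is_caching) return;  // leaves outside a cache subtree are untouched
    if (leaf->mappable) {
      saved_leaf = leaf;  // remembered for the lookup/merge rewrite
    } else {
      saved_sampler = std::move(leaf->sampler);  // cache will drive row selection
    }
  }
};

int main() {
  CachePassState pass;
  pass.is_caching = true;
  auto leaf = std::make_shared<Leaf>(Leaf{false, std::make_shared<Sampler>()});
  pass.VisitLeaf(leaf);
  return pass.saved_sampler ? 0 : 1;  // the sampler was handed off to the cache
}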
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
new file mode 100644
index 0000000000..cbc805cd3e
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_pass.h
@@ -0,0 +1,138 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
+#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
+
+#include
+#include
+#include
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+class CacheTransformPass;
+
+/// \class CachePass cache_pass.h
+/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache
+///     transformation. It works in conjunction with the CacheTransformPass
+class CachePass : public NodePass {
+ public:
+  /// \brief Constructor
+  /// \param[in] transform_pass Raw pointer back to controlling tree pass
+  explicit CachePass(CacheTransformPass *transform_pass);
+
+  /// \brief Identifies the subtree below this node as a cached descendant tree.
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a
+  ///     cache transformation
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+  /// \brief Perform leaf node cache transform identifications
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+ private:
+  /// \brief Common code for mappable leaf setup.
+  /// \param[in] leaf_op The leaf node performing the setup work.
+  /// \return Status The error code return
+  Status MappableCacheLeafSetup(std::shared_ptr leaf_op);
+
+  /// \brief Common code for non-mappable leaf setup.
+  /// \param[in] leaf_op The leaf node performing the setup work.
+  /// \return Status The error code return
+  Status NonMappableCacheLeafSetup(std::shared_ptr leaf_op);
+
+  bool is_caching_;
+  std::shared_ptr leaf_op_;
+  std::shared_ptr sampler_;
+  CacheTransformPass *transform_pass_;  // Back pointer to the owning transform pass
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
new file mode 100644
index 0000000000..033150e8f4
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.cc
@@ -0,0 +1,108 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "minddata/dataset/engine/opt/pre/cache_pass.h"
+#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
+#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
+#include "minddata/dataset/engine/datasetops/cache_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+// constructor
+CacheTransformPass::CacheTransformPass() {}
+
+// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
+Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
+  MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
+  // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
+  // use to execute a transform.
+  std::unique_ptr cache_pass = std::make_unique(this);
+  RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
+
+  // Then, execute the transform for each pair
+  for (auto cache_pair : cache_pairs_) {
+    MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
+    RETURN_IF_NOT_OK(
+      ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
+  }
+  MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
+  return Status::OK();
+}
+
+// Helper function to execute the cache transformation.
+Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op,
+                                                 std::shared_ptr cache_op,
+                                                 std::shared_ptr cache_client) {
+  // Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was
+  // the root node. It is also possible that cache_child == leaf_op
+  std::shared_ptr cache_child = cache_op->child(0);
+  DatasetOp *cache_parent = nullptr;
+  cache_op->Parent(&cache_parent, 0);  // fetch the cache op's parent
+
+  // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
+  std::shared_ptr leaf_sampler = leaf_op->sampler();
+
+  // Construct the merge op with defaults
+  std::shared_ptr merge_op;
+  CacheMergeOp::Builder merge_builder;
+  RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
+  RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
+
+  // Construct the cache lookup op with defaults
+  std::shared_ptr cache_lookup_op;
+  CacheLookupOp::Builder lookup_builder;
+  RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
+  RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
+
+  // Overwrite the old sampler in this leaf op to become the lookup op
+  leaf_op->SetSampler(cache_lookup_op);
+
+  // If the cache had a parent, then go into that parent to remove the cache from its child list and then
+  // replace it with the merge op.
+  if (cache_parent != nullptr) {
+    RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
+    RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
+  } else {
+    // If we didn't have a parent, then the merge op is the root node
+    RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
+  }
+
+  // Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op.
+  // We maintain a local pointer to the old child though.
+  RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
+
+  // Connect the merge op
+  RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
+  RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
+
+  // At this point, the cache op has already had its children and parents taken away. Calling remove
+  // on it at this point will not do any node hookups, and instead set internal fields to invalid.
+  RETURN_IF_NOT_OK(cache_op->Remove());
+
+  return Status::OK();
+}
+
+// Assigns the leaf and cache operators that are involved in a cache transformation
+void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr leaf_op,
+                                                   std::shared_ptr cache_op) {
+  cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
new file mode 100644
index 0000000000..02c22c4472
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/cache_transform_pass.h
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
+#define DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
+
+#include
+#include
+#include
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+class DatasetOp;
+
+class CacheClient;
+
+/// \class CacheTransformPass cache_transform_pass.h
+/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
+///     operations
+class CacheTransformPass : public TreePass {
+ public:
+  /// \brief Constructor
+  CacheTransformPass();
+
+  /// \brief Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
+  /// \param[inout] tree The tree to operate on.
+  /// \param[inout] modified Indicator if the tree was modified.
+  /// \return Status The error code return
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+
+  /// \brief Assigns the leaf and cache operators that are involved in a cache transformation
+  /// \param[in] leaf_op The leaf operator involved in the cache transform
+  /// \param[in] cache_op The cache operator involved in the cache transform
+  void AddMappableCacheOperators(std::shared_ptr leaf_op, std::shared_ptr cache_op);
+
+ private:
+  /// \brief Helper function to execute the cache transformation.
+  ///
+  ///     Input:
+  ///       Sampler
+  ///         |
+  ///       LeafOp --> OtherOps --> CacheOp
+  ///
+  ///     Transformed:
+  ///       Sampler --> CacheLookupOp ---------------->
+  ///                         |                       |
+  ///                         |                    MergeOp
+  ///                         |                       |
+  ///                       LeafOp --> OtherOps ----->
+  ///
+  /// \param[in] tree The tree being transformed
+  /// \param[in] leaf_op The leaf node in the transform
+  /// \param[in] cache_op The cache op in the transform (will get removed)
+  /// \param[in] cache_client The cache client
+  /// \return Status The error code return
+  Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr leaf_op,
+                               std::shared_ptr cache_op, std::shared_ptr cache_client);
+
+  // The two operators that work together to establish the cache transform
+  std::vector, std::shared_ptr>> cache_pairs_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_CACHE_TRANSFORM_PASS_H_
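Stripped of the Status plumbing and tree bookkeeping (AssociateNode/AssignRoot), the transform in ExecuteCacheTransform is a handful of pointer moves done in a careful order: disconnect the cache first, then splice in the merge op. A toy sketch under the assumption that a node is just a child list (`Node` here is hypothetical, not DatasetOp):

#include <algorithm>
#include <memory>
#include <vector>

struct Node {
  std::vector<std::shared_ptr<Node>> children;
};

// Mirrors the rewiring order in ExecuteCacheTransform: replace the cache in its
// parent with the merge op, fully disconnect the cache, then hand the merge op
// the lookup plus the subtree that used to feed the cache.
void RewireForCache(const std::shared_ptr<Node> &parent, const std::shared_ptr<Node> &cache,
                    const std::shared_ptr<Node> &lookup, const std::shared_ptr<Node> &merge) {
  auto child = cache->children.front();  // subtree that used to feed the cache
  auto &siblings = parent->children;
  *std::find(siblings.begin(), siblings.end(), cache) = merge;  // swap cache for merge
  cache->children.clear();                                      // disconnect the old cache op
  merge->children = {lookup, child};  // merge consumes the lookup and the original subtree
}

int main() {
  auto parent = std::make_shared<Node>();
  auto cache = std::make_shared<Node>();
  auto subtree = std::make_shared<Node>();  // stands in for LeafOp --> OtherOps
  auto lookup = std::make_shared<Node>();
  auto merge = std::make_shared<Node>();
  parent->children = {cache};
  cache->children = {subtree};
  RewireForCache(parent, cache, lookup, merge);
  return (parent->children[0] == merge && merge->children.size() == 2) ? 0 : 1;
}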
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
new file mode 100644
index 0000000000..f04d7bc07d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
+#include "minddata/dataset/engine/opt/pre/removal_pass.h"
+#include "minddata/dataset/engine/datasetops/shuffle_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
+
+// Identifies the subtree below this node as a cached descendant tree.
+Status RemovalNodes::PreRunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
+  is_caching_ = true;
+  return Status::OK();
+}
+
+// Resets the tracking of the cache within the tree
+Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
+  is_caching_ = false;
+  return Status::OK();
+}
+
+// Perform ShuffleOp removal check.
+Status RemovalNodes::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  // If we are in a cache descendant tree, then this shuffle op needs to be removed
+  if (is_caching_) {
+    MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
+    if (removal_pass_) {
+      removal_pass_->AddToRemovalList(std::static_pointer_cast(node));
+    } else {
+      return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
+    }
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
new file mode 100644
index 0000000000..32025cd597
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_nodes.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
+#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
+
+#include
+#include "minddata/dataset/engine/opt/pass.h"
+#include "minddata/dataset/engine/opt/pre/removal_pass.h"
+
+namespace mindspore {
+namespace dataset {
+/// \class RemovalNodes removal_nodes.h
+/// \brief This is a NodePass whose job is to identify which nodes should be removed.
+///     It works in conjunction with the removal_pass.
+class RemovalNodes : public NodePass {
+ public:
+  /// \brief Constructor
+  /// \param[in] removal_pass Raw pointer back to controlling tree pass
+  explicit RemovalNodes(RemovalPass *removal_pass);
+
+  /// \brief Identifies the subtree below this node as a cached descendant tree.
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Resets the tracking of the cache within the tree
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+  /// \brief Destructor
+  ~RemovalNodes() = default;
+
+  /// \brief Perform ShuffleOp removal check
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr node, bool *modified) override;
+
+ private:
+  bool is_caching_;
+  RemovalPass *removal_pass_;  // Back pointer to the owning removal pass
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
new file mode 100644
index 0000000000..0db422a7c2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.cc
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
+#include "minddata/dataset/engine/opt/pre/removal_pass.h"
+#include "minddata/dataset/engine/execution_tree.h"
+
+namespace mindspore {
+namespace dataset {
+
+// constructor
+RemovalPass::RemovalPass() {}
+
+// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
+Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
+  MS_LOG(INFO) << "Pre pass: removal pass started.";
+  // Create the removal node pass which can identify which nodes need to be removed.
+  std::unique_ptr removal_nodes = std::make_unique(this);
+  RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
+
+  // Then, execute the removal of any nodes that were set up for removal
+  for (auto node : removal_nodes_) {
+    RETURN_IF_NOT_OK(node->Remove());
+  }
+  MS_LOG(INFO) << "Pre pass: removal pass complete.";
+  return Status::OK();
+}
+
+// Adds an operator to the list of operators to be removed
+void RemovalPass::AddToRemovalList(std::shared_ptr dataset_op) { removal_nodes_.push_back(dataset_op); }
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
new file mode 100644
index 0000000000..bcab7cf08c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/pre/removal_pass.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
+#define DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
+
+#include
+#include
+#include "minddata/dataset/engine/opt/pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+class DatasetOp;
+
+/// \class RemovalPass removal_pass.h
+/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
+///     nodes should be removed, and then removes them.
+class RemovalPass : public TreePass {
+ public:
+  /// \brief Constructor
+  RemovalPass();
+
+  /// \brief Destructor
+  ~RemovalPass() = default;
+
+  /// \brief Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
+  /// \param[inout] tree The tree to operate on.
+  /// \param[inout] modified Indicator if the tree was modified.
+  /// \return Status The error code return
+  Status RunOnTree(ExecutionTree *tree, bool *modified) override;
+
+  /// \brief Adds an operator to the list of operators to be removed
+  /// \param[in] dataset_op The operator to add to the removal list
+  void AddToRemovalList(std::shared_ptr dataset_op);
+
+ private:
+  std::vector> removal_nodes_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_PASS_H_
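RemovalPass deliberately splits identification from mutation: RemovalNodes only records victims during the traversal, and the tree is edited after the walk finishes, so the traversal is never invalidated by its own edits. The same two-phase shape in a standalone sketch (hypothetical `Node`, not DatasetOp):

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

struct Node {
  bool removable = false;  // the real pass records DatasetOps and calls Remove() on them
};

int main() {
  std::vector<std::shared_ptr<Node>> tree = {std::make_shared<Node>(), std::make_shared<Node>()};
  tree[1]->removable = true;

  // Phase 1: walk and record only -- no mutation while iterating.
  std::vector<std::shared_ptr<Node>> removal_list;
  for (const auto &n : tree) {
    if (n->removable) removal_list.push_back(n);
  }

  // Phase 2: mutate after the walk is complete.
  for (const auto &n : removal_list) {
    tree.erase(std::remove(tree.begin(), tree.end(), n), tree.end());
  }
  std::cout << tree.size() << '\n';  // 1
}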
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc
new file mode 100644
index 0000000000..eb74d8fcc3
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.cc
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include "minddata/dataset/engine/opt/util/printer_pass.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting DatasetOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting BatchOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting MapOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting ProjectOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting RenameOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting SkipOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting ShuffleOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting MindRecordOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting TFReaderOp" << '\n';
+  return Status::OK();
+}
+
+#ifdef ENABLE_PYTHON
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting FilterOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting GeneratorOp" << '\n';
+  return Status::OK();
+}
+#endif
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting TakeOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting ZipOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting DeviceQueueOp" << '\n';
+  return Status::OK();
+}
+
+Status PrinterPass::RunOnNode(std::shared_ptr node, bool *modified) {
+  *modified = false;
+  std::cout << "Visiting ImageFolderOp" << '\n';
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h
new file mode 100644
index 0000000000..527df3ccc9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/opt/util/printer_pass.h
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H +#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H + +#include +#include "minddata/dataset/engine/opt/pass.h" + +namespace mindspore { +namespace dataset { + +class PrinterPass : public NodePass { + public: + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + +#ifdef ENABLE_PYTHON + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; +#endif + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; + + Status RunOnNode(std::shared_ptr node, bool *modified) override; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H diff --git a/mindspore/ccsrc/dataset/engine/perf/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/perf/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/engine/perf/CMakeLists.txt rename to mindspore/ccsrc/minddata/dataset/engine/perf/CMakeLists.txt diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc new file mode 100644 index 0000000000..20b4908030 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.cc @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include "minddata/dataset/engine/perf/connector_size.h"
+#include
+#include
+#include
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/util/path.h"
+
+using json = nlohmann::json;
+namespace mindspore {
+namespace dataset {
+using Qrow = std::vector;
+
+// Sample action
+Status ConnectorSize::Sample() {
+  Qrow cur_row;
+  std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row),
+                 [](DatasetOp &op) { return op.ConnectorSize(); });
+  // Push new row of sample
+  sample_table_.push_back(cur_row);
+  return Status::OK();
+}
+
+// JSON serializer helper function
+json ConnectorSize::ParseOpInfo(const DatasetOp &node, const std::vector &size) {
+  auto children = node.Children();
+  std::vector children_id;
+  std::transform(children.begin(), children.end(), std::back_inserter(children_id),
+                 [](std::shared_ptr op) -> int32_t { return op->id(); });
+  json json_node;
+  json_node["op_id"] = node.id();
+  json_node["op_type"] = node.Name();
+  json_node["num_workers"] = node.num_workers();
+  json metrics;
+  // DeviceQueueOp is a special op: it is not inlined, but its output queue is invalid,
+  // so we should not output its queue size.
+  if (!node.inlined() && node.Name() != "DeviceQueueOp") {
+    metrics["output_queue"] = {{"size", size}, {"length", node.ConnectorCapacity()}};
+  }
+  json_node["metrics"] = metrics;
+  if (!children_id.empty()) {
+    json_node["children"] = children_id;
+  }
+
+  return json_node;
+}
+
+// Save profiling data to file
+Status ConnectorSize::SaveToFile() {
+  std::ofstream os(file_path_, std::ios::trunc);
+  if (!os.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Profiling file cannot be opened.");
+  }
+  uint32_t idx = 0;
+  json output;
+  std::shared_ptr cfg = GlobalContext::config_manager();
+  output["sampling_interval"] = cfg->monitor_sampling_interval();
+  // Traverse the ExecutionTree for JSON node generation
+  for (auto &node : *tree_) {
+    std::vector cur_queue_size;
+    std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size),
+                   [&](const ConnectorSizeSample &sample) { return sample[idx]; });
+    json json_node = ParseOpInfo(node, cur_queue_size);
+    output["op_info"].push_back(json_node);
+    idx++;
+  }
+  os << output;
+  return Status::OK();
+}
+
+Status ConnectorSize::Init(const std::string &dir_path, const std::string &device_id) {
+  file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + device_id + ".json")).toString();
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
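For reference, each element of the op_info array that SaveToFile emits follows the shape built by ParseOpInfo. A small sketch that assembles the same structure with nlohmann::json, the library this file already uses (all values below are illustrative, not real profiler output):

#include <iostream>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  // Mirrors the fields ParseOpInfo writes for one op.
  json node;
  node["op_id"] = 2;
  node["op_type"] = "BatchOp";
  node["num_workers"] = 4;
  // One sampled queue size per sampling tick, plus the connector capacity.
  node["metrics"]["output_queue"] = {{"size", std::vector<int>{3, 5, 4}}, {"length", 16}};
  node["children"] = std::vector<int>{1};
  std::cout << node.dump(2) << '\n';
  return 0;
}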
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h
new file mode 100644
index 0000000000..61ba06a76f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_size.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_CONNECTOR_SIZE_H
+#define DATASET_CONNECTOR_SIZE_H
+
+#include
+#include
+#include
+#include "minddata/dataset/engine/perf/profiling.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+
+using json = nlohmann::json;
+
+namespace mindspore {
+namespace dataset {
+class ExecutionTree;
+
+// Connector size sampling samples the output connector size of each op in the pipeline.
+// It supports JSON serialization for external usage.
+class ConnectorSize : public Sampling {
+  // Connector size sampling data is stored as a 2D vector
+  //            op_0      ...  op_m
+  // sample_0   size_0_0  ...  size_m_0
+  // ...        ...       ...  ...
+  // sample_n   size_0_n  ...  size_m_n
+  //
+  // A circular buffer will be implemented in the future to make this table more flexible.
+  using ConnectorSizeSample = std::vector;
+  using ConnectorSizeSampleTable = std::vector;
+
+ public:
+  explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}
+
+  ~ConnectorSize() override = default;
+
+  // Driver function for connector size sampling.
+  // This function samples the connector size of every node within the ExecutionTree
+  Status Sample() override;
+
+  std::string Name() const override { return kConnectorSizeSamplingName; }
+
+  // Save sampling data to file
+  // @return Status - The error code return
+  Status SaveToFile() override;
+
+  Status Init(const std::string &dir_path, const std::string &device_id) override;
+
+  // Parse op information and transform to json format
+  json ParseOpInfo(const DatasetOp &node, const std::vector &size);
+
+ private:
+  ExecutionTree *tree_ = nullptr;          // ExecutionTree pointer
+  ConnectorSizeSampleTable sample_table_;  // Dataset structure to store all samples of connector size sampling
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_CONNECTOR_SIZE_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc
new file mode 100644
index 0000000000..b5e2efaf73
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.cc
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "minddata/dataset/engine/perf/connector_throughput.h"
+#include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/util/path.h"
+
+namespace mindspore {
+namespace dataset {
+
+// temporary helper
+int ConnectorThroughput::InitNodes() {
+  auto it = (*tree_).begin();
+  return it.NumNodes();
+}
+
+// Sample action
+Status ConnectorThroughput::Sample() {
+  std::vector out_buffer_count_row(n_nodes_);
+  std::vector throughput_row(n_nodes_);
+  TimePoint cur_time;  // initialised inside the loop, used outside the loop to update prev sample time.
+  auto col = 0;
+  for (const auto &node : *tree_) {
+    auto cur_out_buffer_count = node.ConnectorOutBufferCount();
+    out_buffer_count_row[col] = cur_out_buffer_count;
+    auto sz = timestamps_.size();
+    cur_time = std::chrono::steady_clock::now();
+    auto _dt = std::chrono::duration_cast(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]);
+    auto dt = std::chrono::duration(_dt).count();
+    auto prev_out_buffer_count = out_buffer_count_table_[col][out_buffer_count_table_.size() - 1];
+    if (dt != 0) {
+      auto thr = (cur_out_buffer_count - prev_out_buffer_count) / (1000 * dt);
+      throughput_row[col] = thr;
+    } else {
+      throughput_row[col] = -1;
+    }
+    col++;
+  }
+  std::vector v = {cur_time};  // temporary fix
+  timestamps_.AddSample(v);
+  // Push new row of sample
+  out_buffer_count_table_.AddSample(out_buffer_count_row);
+  throughput_.AddSample(throughput_row);
+  return Status::OK();
+}
+
+json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector &thr) {
+  auto children = node.Children();
+  std::vector children_id;
+  std::transform(children.begin(), children.end(), std::back_inserter(children_id),
+                 [](std::shared_ptr op) -> int32_t { return op->id(); });
+  json json_node;
+  json_node["op_id"] = node.id();
+  json_node["op_type"] = node.Name();
+  json_node["num_workers"] = node.num_workers();
+  json metrics;
+  metrics["output_queue"] = {{"throughput", thr}};
+
+  json_node["metrics"] = metrics;
+  if (!children_id.empty()) {
+    json_node["children"] = children_id;
+  }
+
+  return json_node;
+}
+
+// Save profiling data to file
+Status ConnectorThroughput::SaveToFile() {
+  std::ofstream os(file_path_);
+  json output;
+  output["sampling_interval"] = 10;
+  // Traverse the ExecutionTree for JSON node generation
+  int col = 0;
+  for (auto &node : *tree_) {
+    std::vector throughput;
+    for (auto i = 0; i < throughput_.size(); i++) {
+      throughput.push_back(throughput_[col][i]);
+    }
+    json json_node = ParseOpInfo(node, throughput);
+    output["op_info"].push_back(json_node);
+    col++;
+  }
+  os << output;
+  return Status::OK();
+}
+
+Status ConnectorThroughput::Init(const std::string &dir_path, const std::string &device_id) {
+  file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + Name() + "_" + device_id + ".json")).toString();
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
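The per-op throughput above is plain delta-count over delta-time; the only subtlety is guarding the zero-interval case, which Sample() records as -1. The exact unit in the diff depends on the duration_cast target (not visible in the stripped template arguments), so the self-contained version below simply computes buffers per second with explicit seconds:

#include <chrono>
#include <iostream>
#include <thread>

int main() {
  using clock = std::chrono::steady_clock;
  // Two consecutive samples of one op's cumulative output-buffer count.
  auto t0 = clock::now();
  int64_t buffers0 = 120;
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  auto t1 = clock::now();
  int64_t buffers1 = 180;

  // Throughput = delta buffers / delta time; guard against a zero interval,
  // mirroring the -1 sentinel used by ConnectorThroughput::Sample().
  double dt = std::chrono::duration<double>(t1 - t0).count();  // seconds
  double throughput = (dt != 0.0) ? (buffers1 - buffers0) / dt : -1.0;
  std::cout << throughput << " buffers/s\n";  // roughly 1200
}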
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h
new file mode 100644
index 0000000000..9cf387230a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/connector_throughput.h
@@ -0,0 +1,103 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_CONNECTOR_THROUGHPUT_H
+#define DATASET_CONNECTOR_THROUGHPUT_H
+
+#include
+#include
+#include
+#include
+#include
+#include "minddata/dataset/engine/perf/profiling.h"
+#include "minddata/dataset/engine/perf/perf_data.h"
+#include "minddata/dataset/engine/perf/cyclic_array.h"
+#include "minddata/dataset/engine/datasetops/dataset_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
+
+using json = nlohmann::json;
+namespace mindspore {
+namespace dataset {
+// Connector throughput sampling monitors the output buffer count of each op in the pipeline and derives
+// a per-op throughput from consecutive samples.
+// For the description of the data structure see perf_data.h
+// It supports JSON serialization for external usage.
+class ConnectorThroughput : public Sampling {
+  using OutBufferCount = PerfData>;
+  using Throughput = PerfData>;
+  using TimePoint = std::chrono::time_point;
+  using TimeStamps = PerfData>;
+
+ public:
+  explicit ConnectorThroughput(ExecutionTree *tree, int64_t max_rows = 1000000)
+      : tree_(tree),
+        max_rows_(max_rows),
+        n_nodes_(InitNodes()),
+        out_buffer_count_table_(OutBufferCount(max_rows_, n_nodes_)),
+        throughput_(Throughput(max_rows_, n_nodes_)),
+        timestamps_(TimeStamps(max_rows_, 1)) {
+    timestamps_.AddSample(std::vector(1));
+    out_buffer_count_table_.AddSample(std::vector(n_nodes_));
+  }
+
+  /// \brief Destructor
+  ~ConnectorThroughput() = default;
+
+  // Driver function for connector throughput sampling.
+  // This function samples the connector throughput of every node within the ExecutionTree
+  Status Sample() override;
+
+  // Traverse the tree nodes and count them
+  int InitNodes();
+
+  std::string Name() const override { return name_; }
+
+  // Save sampling data to file
+  // @return Status - The error code return
+  Status SaveToFile() override;
+
+  Status Init(const std::string &dir_path, const std::string &device_id) override;
+
+  json ParseOpInfo(const DatasetOp &node, const std::vector &thr);
+
+ private:
+  ExecutionTree *tree_ = nullptr;  // ExecutionTree pointer
+  int64_t max_rows_;
+  int32_t n_nodes_;
+  OutBufferCount out_buffer_count_table_;
+  Throughput throughput_;
+  TimeStamps timestamps_;
+  std::string name_ = kConnectorThroughputSamplingName;
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_CONNECTOR_THROUGHPUT_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h
new file mode 100644
index 0000000000..2dfc3fd99d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/cyclic_array.h
@@ -0,0 +1,197 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_CYCLIC_ARRAY_H
+#define DATASET_CYCLIC_ARRAY_H
+
+#include
+#include
+#include
+#include
+#include "minddata/dataset/core/constants.h"
+
+namespace mindspore {
+namespace dataset {
+
+/// \class CyclicArray cyclic_array.h
+/// \brief This is a container with a contiguous memory layout that only keeps the last N entries
+///     when the number of entries exceeds the capacity. Must be preallocated.
+template
+class CyclicArray {
+ public:
+  using value_type = T;
+  class Iterator {
+    // TODO: Add operator[] and make fully compliant with random access iterator;
+    // add a const iterator; add resize() and empty().
+   public:
+    using iterator_category = std::random_access_iterator_tag;
+    using value_type = CyclicArray::value_type;
+    using difference_type = std::ptrdiff_t;
+    using pointer = CyclicArray::value_type *;
+    using reference = CyclicArray::value_type &;
+
+    Iterator() = default;
+
+    Iterator(dsize_t idx, pointer ptr, dsize_t capacity, dsize_t head)
+        : cur_idx_(idx), ptr_(ptr), capacity_(capacity), head_(head) {}
+
+    Iterator(const Iterator &rhs) = default;
+
+    ~Iterator() = default;
+
+    Iterator &operator++() {
+      cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1);
+      return *this;
+    }
+
+    Iterator operator++(int) {
+      Iterator tmp(*this);
+      cur_idx_ = (cur_idx_ + 1) % (capacity_ + 1);
+      return tmp;
+    }
+
+    Iterator &operator--() {
+      cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1);
+      return *this;
+    }
+
+    Iterator operator--(int) {
+      Iterator tmp(*this);
+      cur_idx_ = (cur_idx_ + capacity_) % (capacity_ + 1);
+      return tmp;
+    }
+
+    Iterator operator+(dsize_t x) { return Iterator((cur_idx_ + x) % (capacity_ + 1), ptr_, capacity_, head_); }
+
+    Iterator operator-(dsize_t x) {
+      return Iterator((cur_idx_ + (capacity_ + 1 - x)) % (capacity_ + 1), ptr_, capacity_, head_);
+    }
+
+    bool operator<(const Iterator &rhs) {
+      return (head_ + cur_idx_) % (capacity_ + 1) < (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1);
+    }
+
+    bool operator>(const Iterator &rhs) {
+      return (head_ + cur_idx_) % (capacity_ + 1) > (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1);
+    }
+
+    bool operator>=(const Iterator &rhs) {
+      return (head_ + cur_idx_) % (capacity_ + 1) >= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1);
+    }
+
+    bool operator<=(const Iterator &rhs) {
+      return (head_ + cur_idx_) % (capacity_ + 1) <= (rhs.head_ + rhs.cur_idx_) % (capacity_ + 1);
+    }
+
+    difference_type operator-(const Iterator &rhs) {
+      return (cur_idx_ - rhs.cur_idx_ + capacity_ + 1) % (capacity_ + 1);
+    }
+
+    reference operator*() { return ptr_[cur_idx_]; }
+
+    pointer operator->() { return &(ptr_[cur_idx_]); }
+
+    bool operator==(const Iterator &rhs) { return cur_idx_ == rhs.cur_idx_; }
+
+    bool operator!=(const Iterator &rhs) { return cur_idx_ != rhs.cur_idx_; }
+
+   private:
+    dsize_t cur_idx_;
+    pointer ptr_;
+    dsize_t capacity_;
+    dsize_t head_;
+  };
+
+  /// \brief Default constructor
+  CyclicArray() : buf_(nullptr), head_(0), tail_(0), size_(0), capacity_(0) {}
+
+  /// \brief Constructor
+  /// \param[in] capacity
+  explicit CyclicArray(dsize_t capacity)
+      : buf_(std::make_unique(capacity + 1)), head_(0), tail_(0), size_(0), capacity_(capacity) {}
+
+  CyclicArray(const CyclicArray &rhs)
+      : buf_(std::make_unique(rhs.capacity_ + 1)),
+        head_(rhs.head_),
+        tail_(rhs.tail_),
+        size_(rhs.size_),
+        capacity_(rhs.capacity_) {
+    std::copy(rhs.begin(), rhs.end(), begin());
+  }
+
+  CyclicArray(CyclicArray &&rhs) = default;
+
+  ~CyclicArray() = default;
+
+  /// \brief Iterator begin()
+  Iterator begin() { return Iterator(head_, buf_.get(), capacity_, head_); }
+
+  /// \brief Iterator end()
+  Iterator end() { return Iterator(tail_, buf_.get(), capacity_, head_); }
+
+  // Note: not truly const; these return a mutable Iterator over the underlying buffer.
+  Iterator begin() const { return Iterator(head_, buf_.get(), capacity_, head_); }
+
+  Iterator end() const { return Iterator(tail_, buf_.get(), capacity_, head_); }
+
+  /// \brief Clears the array. Does not deallocate memory; capacity remains the same.
+  void clear() {
+    head_ = 0;
+    tail_ = 0;
+    size_ = 0;
+  }
+
+  /// \brief Returns the current size
+  dsize_t size() { return size_; }
+
+  /// \brief Returns the capacity
+  dsize_t capacity() { return capacity_; }
+
+  /// \brief Pushes a value
+  /// \param[in] val value
+  void push_back(T val) {
+    buf_[tail_] = val;
+    if (size_ >= capacity_) {
+      (tail_ != capacity_) ? tail_++ : tail_ = 0;
+      (head_ != capacity_) ? head_++ : head_ = 0;
+    } else {
+      tail_++;
+      size_++;
+    }
+  }
+
+  /// \brief Returns a const reference to an element of the array
+  /// \param[in] idx index of the element
+  /// \return const T& reference to an element of the array
+  const T &operator[](dsize_t idx) const { return buf_[(head_ + idx) % (capacity_ + 1)]; }
+
+  /// \brief Returns a non-const reference to an element of the array
+  /// \param[in] idx index of the element
+  /// \return T& reference to an element of the array
+  T &operator[](dsize_t idx) { return buf_[(head_ + idx) % (capacity_ + 1)]; }
+
+ private:
+  std::unique_ptr buf_;
+  dsize_t head_;
+  dsize_t tail_;
+  dsize_t size_;
+  dsize_t capacity_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_CYCLIC_ARRAY_H
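The push_back contract is the whole point of the container: once size reaches capacity, both head and tail advance, silently dropping the oldest entry. A short usage example, assuming this header is on the include path (the expected output follows from the modular index arithmetic above):

#include <iostream>
#include "minddata/dataset/engine/perf/cyclic_array.h"

using mindspore::dataset::CyclicArray;

int main() {
  CyclicArray<int> arr(3);  // keeps only the last 3 entries
  for (int v = 1; v <= 5; ++v) arr.push_back(v);
  // The two oldest entries were overwritten: logical contents are now 3 4 5.
  for (int i = 0; i < arr.size(); ++i) std::cout << arr[i] << ' ';
  std::cout << '\n' << arr.size() << '/' << arr.capacity() << '\n';  // 3/3
}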
+  std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) +
+                     " " + std::to_string(value);
+  value_.emplace_back(data);
+  return Status::OK();
+}
+
+Status DatasetIteratorTracing::SaveToFile() {
+  if (value_.empty()) {
+    return Status::OK();
+  }
+
+  std::ofstream handle(file_path_, std::ios::trunc);
+  if (!handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Profiling file can not be opened.");
+  }
+  for (const auto &value : value_) {
+    handle << value << "\n";
+  }
+  handle.close();
+
+  return Status::OK();
+}
+
+Status DatasetIteratorTracing::Init(const std::string &dir_path, const std::string &device_id) {
+  file_path_ = (Path(dir_path) / Path("dataset_iterator_profiling_" + device_id + ".txt")).toString();
+  return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h
new file mode 100644
index 0000000000..e7ba237a0a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/dataset_iterator_tracing.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_DATASET_ITERATOR_TRACING_H
+#define MINDSPORE_DATASET_ITERATOR_TRACING_H
+
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/perf/profiling.h"
+
+namespace mindspore {
+namespace dataset {
+class DatasetIteratorTracing : public Tracing {
+ public:
+  // Constructor
+  DatasetIteratorTracing() = default;
+
+  // Destructor
+  ~DatasetIteratorTracing() override = default;
+
+  // Record tracing data
+  // @return Status - The error code returned
+  Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
+
+  std::string Name() const override { return kDatasetIteratorTracingName; }
+
+  // Save tracing data to file
+  // @return Status - The error code returned
+  Status SaveToFile() override;
+
+  Status Init(const std::string &dir_path, const std::string &device_id) override;
+
+ private:
+  std::vector<std::string> value_;
+};
+} // namespace dataset
+} // namespace mindspore
+
+#endif // MINDSPORE_DATASET_ITERATOR_TRACING_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc
new file mode 100644
index 0000000000..776b483b79
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.cc
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <fstream>
+#include <string>
+#include "minddata/dataset/engine/perf/device_queue_tracing.h"
+#include "minddata/dataset/util/path.h"
+namespace mindspore {
+namespace dataset {
+
+Status DeviceQueueTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num,
+                                  const int32_t value) {
+  // Format: "type extra-info batch-num value"
+  // type: 0: time, 1: connector size
+  // extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time
+  //             if type is 1 - connector capacity
+  // batch-num: batch number
+  // value: if type is 0 - value is time (ms)
+  //        if type is 1 - value is connector size
+  // Examples:
+  // 0 0 20 10 - The 20th batch took 10 ms to get data from the pipeline.
+  // 1 64 20 5 - The connector size was 5 when fetching the 20th batch; the connector capacity is 64.
+  std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) +
+                     " " + std::to_string(value);
+  value_.emplace_back(data);
+  return Status::OK();
+}
+
+Status DeviceQueueTracing::SaveToFile() {
+  if (value_.empty()) {
+    return Status::OK();
+  }
+
+  std::ofstream handle(file_path_, std::ios::trunc);
+  if (!handle.is_open()) {
+    RETURN_STATUS_UNEXPECTED("Profiling file can not be opened.");
+  }
+  for (const auto &value : value_) {
+    handle << value << "\n";
+  }
+  handle.close();
+
+  return Status::OK();
+}
+
+Status DeviceQueueTracing::Init(const std::string &dir_path, const std::string &device_id) {
+  file_path_ = (Path(dir_path) / Path("device_queue_profiling_" + device_id + ".txt")).toString();
+  return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore
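+
+// Example of the resulting device_queue_profiling_<device_id>.txt content (illustrative only,
+// derived from the format documented in Record() above):
+//   0 0 20 10   -> batch 20 took 10 ms of pipeline time
+//   1 64 20 5   -> connector size was 5 (capacity 64) when batch 20 was fetched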
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h
new file mode 100644
index 0000000000..32f9d2d8c2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/device_queue_tracing.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_DEVICE_QUEUE_TRACING_H
+#define MINDSPORE_DEVICE_QUEUE_TRACING_H
+
+#include <string>
+#include <vector>
+#include "minddata/dataset/engine/perf/profiling.h"
+
+namespace mindspore {
+namespace dataset {
+class DeviceQueueTracing : public Tracing {
+ public:
+  // Constructor
+  DeviceQueueTracing() = default;
+
+  // Destructor
+  ~DeviceQueueTracing() override = default;
+
+  // Record tracing data
+  // @return Status - The error code returned
+  Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value);
+
+  std::string Name() const override { return kDeviceQueueTracingName; }
+
+  // Save tracing data to file
+  // @return Status - The error code returned
+  Status SaveToFile() override;
+
+  Status Init(const std::string &dir_path, const std::string &device_id) override;
+
+ private:
+  std::vector<std::string> value_;
+};
+} // namespace dataset
+} // namespace mindspore
+
+#endif // MINDSPORE_DEVICE_QUEUE_TRACING_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc
new file mode 100644
index 0000000000..7fa7e6fc78
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.cc
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <chrono>
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/engine/perf/monitor.h"
+#include "minddata/dataset/engine/execution_tree.h"
+
+namespace mindspore {
+namespace dataset {
+
+Monitor::Monitor(ExecutionTree *tree) : tree_(tree) {
+  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
+  sampling_interval_ = cfg->monitor_sampling_interval();
+  max_samples_ = 0;
+  cur_row_ = 0;
+}
+
+Status Monitor::operator()() {
+  // Register this thread with TaskManager to receive proper interrupt signals.
+  TaskManager::FindMe()->Post();
+
+  // Keep sampling if
+  // 1) the Monitor task is not interrupted by TaskManager AND
+  // 2) the iterator has not received EOF
+  while (!this_thread::is_interrupted() && !(tree_->isFinished())) {
+    for (auto &node : tree_->GetProfilingManager()->GetSamplingNodes()) {
+      RETURN_IF_NOT_OK(node.second->Sample());
+      std::this_thread::sleep_for(std::chrono::milliseconds(sampling_interval_));
+    }
+  }
+
+  // Output all profiling data once sampling finishes.
+  tree_->GetProfilingManager()->SaveProfilingData();
+  return Status::OK();
+}
+
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h
new file mode 100644
index 0000000000..1e669dad71
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/monitor.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MONITOR_H
+#define MINDSPORE_MONITOR_H
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/perf/profiling.h"
+
+namespace mindspore {
+namespace dataset {
+class ExecutionTree;
+class Monitor {
+ public:
+  // Monitor object constructor
+  explicit Monitor(ExecutionTree *tree);
+
+  Monitor() = default;
+
+  ~Monitor() = default;
+
+  // Functor for the perf monitor main loop.
+  // This function will be the entry point of mindspore::Dataset::Task
+  Status operator()();
+
+  int64_t GetSamplingInterval() { return sampling_interval_; }
+
+ private:
+  int64_t cur_row_;
+  int64_t max_samples_;
+  int64_t sampling_interval_;
+  ExecutionTree *tree_;
+  std::vector<std::shared_ptr<Sampling>> sampling_list_;
+};
+} // namespace dataset
+} // namespace mindspore
+
+#endif // MINDSPORE_MONITOR_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h
new file mode 100644
index 0000000000..8f215fd8df
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/perf_data.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_PERF_DATA_H
+#define DATASET_PERF_DATA_H
+
+#include <vector>
+#include "minddata/dataset/core/constants.h"
+
+namespace mindspore {
+namespace dataset {
+
+// PerfData is a convenience class to record and store the data produced by Monitor.
+// It represents a 2D column-major table in which every column stores the samples
+// for one operator. The number of rows equals the number of samples and
+// the number of columns equals the number of operators.
+// The capacity is determined on construction and cannot be changed.
+// ColumnType can be std::vector or CyclicArray. In the case of the latter, data can be added
+// indefinitely without the risk of overflowing; otherwise, the capacity must not be exceeded.
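+//
+// A minimal usage sketch (illustrative only, assuming CyclicArray as the column type):
+//   PerfData<CyclicArray<int32_t>> pd(2, 3);          // keep 2 rows for 3 operators
+//   pd.AddSample(std::vector<int32_t>{10, 20, 30});
+//   pd.AddSample(std::vector<int32_t>{11, 21, 31});
+//   auto row = pd.Row<int32_t>(1);                    // copies {11, 21, 31}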
+// Given PerfData pd(n_rows, n_cols), an element in column i and row j can be accessed as
+// pd[i][j]
+
+template <typename ColumnType>
+class PerfData {
+ public:
+  PerfData() = default;
+  ~PerfData() = default;
+  PerfData(dsize_t max_rows, dsize_t n_cols) : counter_(0), max_rows_(max_rows), n_cols_(n_cols) {
+    for (auto i = 0; i < n_cols_; i++) {
+      data_.push_back(ColumnType(max_rows_));
+    }
+  }
+  PerfData(const PerfData &rhs) = default;
+  PerfData(PerfData &&rhs) = default;
+
+  // Adds a row of data
+  // T must be any container working with range based loops
+  template <typename T>
+  void AddSample(const T &row) {
+    auto i = 0;
+    for (const auto &e : row) {
+      data_[i++].push_back(e);
+    }
+    counter_++;
+  }
+
+  // Fetches a row of data by copy
+  template <typename T>
+  auto Row(dsize_t idx) {
+    std::vector<T> row(n_cols_);
+    for (auto i = 0; i < n_cols_; i++) {
+      row[i] = data_[i][idx];
+    }
+    return row;
+  }
+
+  // Returns a column of data
+  ColumnType &operator[](size_t idx) { return data_[idx]; }
+
+  const ColumnType &operator[](size_t idx) const { return data_[idx]; }
+
+  dsize_t size() { return counter_ < max_rows_ ? counter_ : max_rows_; }
+
+  dsize_t capacity() { return max_rows_; }
+
+ private:
+  std::vector<ColumnType> data_;
+  dsize_t counter_;
+  dsize_t max_rows_;
+  int n_cols_;
+};
+
+} // namespace dataset
+} // namespace mindspore
+#endif // DATASET_PERF_DATA_H
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc
new file mode 100644
index 0000000000..f5c018c03b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.cc
@@ -0,0 +1,156 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/perf/profiling.h"
+#include <chrono>
+#include <climits>
+#include <cstdlib>
+#include "common/utils.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/engine/perf/monitor.h"
+#include "minddata/dataset/engine/perf/device_queue_tracing.h"
+#include "minddata/dataset/engine/perf/connector_size.h"
+#include "minddata/dataset/engine/perf/connector_throughput.h"
+#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+
+bool ProfilingManager::IsProfilingEnable() const {
+  auto profiling = common::GetEnv("PROFILING_MODE");
+  if (profiling.empty() || profiling != "true") {
+    return false;
+  }
+  return true;
+}
+
+Status ProfilingManager::Initialize() {
+  // Register nodes based on config
+  std::string dir = common::GetEnv("MINDDATA_PROFILING_DIR");
+  if (dir.empty()) {
+    RETURN_STATUS_UNEXPECTED("Profiling dir is not set.");
+  }
+  char real_path[PATH_MAX] = {0};
+  if (dir.size() >= PATH_MAX) {
+    RETURN_STATUS_UNEXPECTED("Profiling dir is invalid.");
+  }
+#if defined(_WIN32) || defined(_WIN64)
+  if (_fullpath(real_path, common::SafeCStr(dir), PATH_MAX) == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Profiling dir is invalid.");
+  }
+#else
+  if (realpath(common::SafeCStr(dir), real_path) == nullptr) {
+    RETURN_STATUS_UNEXPECTED("Profiling dir is invalid.");
+  }
+#endif
+  dir_path_ = real_path;
+
+  // If DEVICE_ID is not set, the default value is 0
+  device_id_ = common::GetEnv("DEVICE_ID");
+  if (device_id_.empty()) {
+    device_id_ = "0";
+  }
+
+  // Register all profiling nodes.
+  // The device_queue node is used for graph mode
+  std::shared_ptr<Tracing> device_queue_tracing = std::make_shared<DeviceQueueTracing>();
+  RETURN_IF_NOT_OK(RegisterTracingNode(device_queue_tracing));
+  // The dataset_iterator node is used for feed mode
+  std::shared_ptr<Tracing> dataset_iterator_tracing = std::make_shared<DatasetIteratorTracing>();
+  RETURN_IF_NOT_OK(RegisterTracingNode(dataset_iterator_tracing));
+
+  std::shared_ptr<Sampling> connector_size_sampling = std::make_shared<ConnectorSize>(tree_);
+  RETURN_IF_NOT_OK(RegisterSamplingNode(connector_size_sampling));
+
+  std::shared_ptr<Sampling> connector_thr_sampling = std::make_shared<ConnectorThroughput>(tree_);
+  RETURN_IF_NOT_OK(RegisterSamplingNode(connector_thr_sampling));
+  return Status::OK();
+}
+
+// Profiling node registration
+Status ProfilingManager::RegisterTracingNode(std::shared_ptr<Tracing> node) {
+  // Check if a node with the same name has already been registered.
+  auto exist = tracing_nodes_.find(node->Name());
+  if (exist != tracing_nodes_.end()) {
+    return Status(StatusCode::kProfilingError, "Profiling node already exists: " + node->Name());
+  }
+  // Register the node with its name as key.
+  RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_));
+  tracing_nodes_[node->Name()] = node;
+  return Status::OK();
+}
+
+// Profiling node getter
+Status ProfilingManager::GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node) {
+  // Check if a node with the given name has been registered.
+  auto exist = tracing_nodes_.find(name);
+  if (exist == tracing_nodes_.end()) {
+    return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name);
+  }
+  // Fetch the node.
+  *node = tracing_nodes_[name];
+  return Status::OK();
+}
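+
+// A minimal usage sketch (hypothetical caller, assuming an ExecutionTree *tree):
+//   ProfilingManager mgr(tree);
+//   RETURN_IF_NOT_OK(mgr.Initialize());
+//   std::shared_ptr<Tracing> node;
+//   RETURN_IF_NOT_OK(mgr.GetTracingNode(kDeviceQueueTracingName, &node));
+//   // type 0 (time), extra-info 1 (push tdt time), batch 20, 10 ms:
+//   std::static_pointer_cast<DeviceQueueTracing>(node)->Record(0, 1, 20, 10);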
+// Profiling node registration
+Status ProfilingManager::RegisterSamplingNode(std::shared_ptr<Sampling> node) {
+  // Check if a node with the same name has already been registered.
+  auto exist = sampling_nodes_.find(node->Name());
+  if (exist != sampling_nodes_.end()) {
+    return Status(StatusCode::kProfilingError, "Profiling node already exists: " + node->Name());
+  }
+  // Register the node with its name as key.
+  RETURN_IF_NOT_OK(node->Init(dir_path_, device_id_));
+  sampling_nodes_[node->Name()] = node;
+  return Status::OK();
+}
+
+// Profiling node getter
+Status ProfilingManager::GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node) {
+  // Check if a node with the given name has been registered.
+  auto exist = sampling_nodes_.find(name);
+  if (exist == sampling_nodes_.end()) {
+    return Status(StatusCode::kProfilingError, "Profiling node does not exist: " + name);
+  }
+  // Fetch the node.
+  *node = sampling_nodes_[name];
+  return Status::OK();
+}
+
+Status ProfilingManager::SaveProfilingData() {
+  if (!IsProfilingEnable()) {
+    return Status::OK();
+  }
+  MS_LOG(INFO) << "Start to save profiling data.";
+  for (auto node : tracing_nodes_) {
+    RETURN_IF_NOT_OK(node.second->SaveToFile());
+  }
+  for (auto node : sampling_nodes_) {
+    RETURN_IF_NOT_OK(node.second->SaveToFile());
+  }
+  MS_LOG(INFO) << "Save profiling data end.";
+  return Status::OK();
+}
+
+int64_t ProfilingTime::GetCurMilliSecond() {
+  // because cpplint does not allow "using namespace"
+  using std::chrono::duration_cast;
+  using std::chrono::milliseconds;
+  using std::chrono::steady_clock;
+  return duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count();
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h
new file mode 100644
index 0000000000..24f7f2efe8
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/perf/profiling.h
@@ -0,0 +1,144 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_UTIL_PROFILE_H_
+#define DATASET_UTIL_PROFILE_H_
+
+#include <chrono>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class Monitor;
+class ExecutionTree;
+
+const char kDeviceQueueTracingName[] = "Device_Queue_Tracing";
+const char kDatasetIteratorTracingName[] = "Dataset_Iterator_Tracing";
+const char kConnectorSizeSamplingName[] = "Connector_Size_Sampling";
+const char kConnectorThroughputSamplingName[] = "Connector_Throughput_Sampling";
+
+// Profiling is the basic unit of a profiling action.
+// This base class encapsulates the serialization output logic.
+class Profiling : std::enable_shared_from_this<Profiling> {
+ public:
+  // Constructor
+  Profiling() = default;
+
+  // Destructor
+  virtual ~Profiling() = default;
+
+  virtual Status Init(const std::string &dir_path, const std::string &device_id) = 0;
+
+  // Default serialization file generator
+  virtual Status SaveToFile() = 0;
+
+  // Profiling name
+  virtual std::string Name() const = 0;
+
+ protected:
+  std::string file_path_;
+};
+
+// Sampling is a kind of profiling node which generates samples periodically.
+class Sampling : public Profiling {
+ public:
+  // Sampling action function. This function will be invoked by the performance monitor thread.
+  virtual Status Sample() = 0;
+  // virtual Status TestPrint() = 0;
+  virtual ~Sampling() = default;
+};
+
+// Tracing is a kind of profiling node which records samples upon request.
+class Tracing : public Profiling {
+  // Tracing does not define a fixed interface, to allow flexibility in data recording.
+};
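+
+// Sketch of a custom sampling node (hypothetical, for illustration only):
+//   class MySampling : public Sampling {
+//    public:
+//     std::string Name() const override { return "My_Sampling"; }
+//     Status Init(const std::string &dir_path, const std::string &device_id) override { return Status::OK(); }
+//     Status Sample() override { /* take one measurement here */ return Status::OK(); }
+//     Status SaveToFile() override { /* serialize the collected samples */ return Status::OK(); }
+//   };
+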
+// ProfilingManager is a class that manages all profiling infrastructure.
+// It serves the following purposes:
+// 1) Fetch profiling configs from global contexts
+// 2) Set up all profiling nodes based on the config
+// 3) Provide access to profiling nodes for profiling actions
+// 4) Manage the profiling data serialization process
+class ProfilingManager {
+ public:
+  explicit ProfilingManager(ExecutionTree *tree) : tree_(tree) {}
+
+  ~ProfilingManager() = default;
+
+  Status Initialize();
+
+  // Save profile data to file
+  // @return Status - The error code returned
+  Status SaveProfilingData();
+
+  // Sampling node getter
+  // @param name - The name of the requested node
+  // @param node - Pointer to the shared pointer for the Sampling node
+  // @return Status - The error code returned
+  Status GetSamplingNode(const std::string &name, std::shared_ptr<Sampling> *node);
+
+  // Tracing node getter
+  // @param name - The name of the requested node
+  // @param node - Pointer to the shared pointer for the Tracing node
+  // @return Status - The error code returned
+  Status GetTracingNode(const std::string &name, std::shared_ptr<Tracing> *node);
+
+  // Returns whether profiling is enabled.
+  bool IsProfilingEnable() const;
+
+  const std::unordered_map<std::string, std::shared_ptr<Sampling>> &GetSamplingNodes() { return sampling_nodes_; }
+
+ private:
+  std::unordered_map<std::string, std::shared_ptr<Tracing>> tracing_nodes_;
+
+  std::unordered_map<std::string, std::shared_ptr<Sampling>> sampling_nodes_;
+
+  // Register a tracing node
+  // @param node - Profiling node
+  // @return Status - The error code returned
+  Status RegisterTracingNode(std::shared_ptr<Tracing> node);
+
+  // Register a sampling node
+  // @param node - Profiling node
+  // @return Status - The error code returned
+  Status RegisterSamplingNode(std::shared_ptr<Sampling> node);
+
+  ExecutionTree *tree_ = nullptr;  // ExecutionTree pointer
+  std::string dir_path_;           // where to create the profiling file
+  std::string device_id_;          // used when creating the profiling file: filename_deviceid.suffix
+};
+
+enum ProfilingType { TIME, CONNECTOR_DEPTH };
+
+enum ProfilingTimeSubType {
+  PIPELINE_TIME,
+  TDT_PUSH_TIME,
+  BATCH_TIME,
+  INVALID_TIME,
+};
+
+class ProfilingTime {
+ public:
+  static int64_t GetCurMilliSecond();
+};
+
+} // namespace dataset
+} // namespace mindspore
+#endif // DATASET_UTIL_PROFILE_H_
diff --git a/mindspore/ccsrc/dataset/engine/tdt/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/tdt/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/engine/tdt/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/engine/tdt/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
new file mode 100644
index 0000000000..126291179a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
@@ -0,0 +1,131 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/engine/tdt/tdt_plugin.h"
+#include "common/utils.h"
+#include "utils/log_adapter.h"
+#include "minddata/dataset/engine/perf/profiling.h"
+
+namespace mindspore {
+namespace dataset {
+static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr;
+
+std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
+  if (instance_ptr_ == nullptr) {
+    instance_ptr_ = std::shared_ptr<TdtPlugin>(new TdtPlugin);
+  }
+  return instance_ptr_;
+}
+
+TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling,
+                              int32_t &time) {
+  MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
+  std::vector<DataItem> items;
+  double start_time = 0;
+  auto ret = translate(ts_row, items);
+  if (ret != SUCCESS) {
+    MS_LOG(ERROR) << "TDT converting tensor failed!";
+    return FAILED;
+  }
+  if (profiling) {
+    start_time = ProfilingTime::GetCurMilliSecond();
+  }
+  if (tdt::TdtHostPushData(channel_name, items) != 0) {
+    MS_LOG(ERROR) << "TDT pushing data failed!";
+    return FAILED;
+  }
+  if (profiling) {
+    double end_time = ProfilingTime::GetCurMilliSecond();
+    time = static_cast<int32_t>(end_time - start_time);
+  }
+  return SUCCESS;
+}
+
+TdtStatus TdtPlugin::getTdtType(DataType d_type, std::string &datatype) {
+  switch (d_type.value()) {
+    case DataType::DE_BOOL:
+      datatype = "bool";
+      break;
+    case DataType::DE_INT8:
+      datatype = "int8";
+      break;
+    case DataType::DE_UINT8:
+      datatype = "uint8";
+      break;
+    case DataType::DE_INT16:
+      datatype = "int16";
+      break;
+    case DataType::DE_UINT16:
+      datatype = "uint16";
+      break;
+    case DataType::DE_INT32:
+      datatype = "int32";
+      break;
+    case DataType::DE_UINT32:
+      datatype = "uint32";
+      break;
+    case DataType::DE_FLOAT16:
+      datatype = "float16";
+      break;
+    case DataType::DE_FLOAT32:
+      datatype = "float32";
+      break;
+    case DataType::DE_FLOAT64:
+      datatype = "float64";
+      break;
+    case DataType::DE_INT64:
+      datatype = "int64";
+      break;
+    case DataType::DE_UINT64:
+      datatype = "uint64";
+      break;
+    default:
+      return FAILED;
+  }
+  return SUCCESS;
+}
+
+TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &items) {
+  if (ts_row.size() == 0) {
+    MS_LOG(ERROR) << "TDT: the size of the row is zero.";
+    return SUCCESS;
+  }
+  for (auto ts : ts_row) {
+    std::string datatype;
+    TdtStatus status = getTdtType(ts->type(), datatype);
+    if (status != SUCCESS) {
+      return status;
+    }
+    TensorShape tsShape = ts->shape();
+    std::string dataShapes = "[";
+    for (auto dim : tsShape.AsVector()) {
+      (void)dataShapes.append(std::to_string(dim)).append(",");
+    }
+    dataShapes.pop_back();
+    (void)dataShapes.append("]");
+    DataItem data_item;
+    data_item.dataType_ = tdt::TDT_TENSOR;
+    data_item.tensorShape_ = dataShapes;
+    data_item.tensorType_ = datatype;
+    data_item.dataLen_ = ts->SizeInBytes();
+    data_item.dataPtr_ =
+        std::shared_ptr<void>(reinterpret_cast<void *>(&(*ts->begin())), [](const void *elem) {});
+    items.emplace_back(data_item);
+    MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
+                  << ts->Size() << ".";
+  }
+  return SUCCESS;
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h
new file mode 100644
index 0000000000..a7db08b7f5
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_TDT_TDT_PLUGIN_H_
+#define DATASET_ENGINE_TDT_TDT_PLUGIN_H_
+
+#include <functional>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include "tdt/tdt_host_interface.h"
+
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_row.h"
+
+namespace mindspore {
+namespace dataset {
+enum TdtStatus { SUCCESS, FAILED };
+
+using tdt::DataItem;
+
+class TdtPlugin {
+ public:
+  static std::shared_ptr<TdtPlugin> GetInstance();
+
+  TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time);
+
+ private:
+  TdtPlugin() {}
+
+  TdtStatus getTdtType(DataType d_type, std::string &datatype);
+
+  TdtStatus translate(const TensorRow &ts_row, std::vector<DataItem> &items);
+
+  void *tdt_handle_ = nullptr;
+};
+} // namespace dataset
+} // namespace mindspore
+#endif // DATASET_ENGINE_TDT_TDT_PLUGIN_H_
diff --git a/mindspore/ccsrc/dataset/include/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h
similarity index 100%
rename from mindspore/ccsrc/dataset/include/dataset/core/constants.h
rename to mindspore/ccsrc/minddata/dataset/include/dataset/core/constants.h
diff --git a/mindspore/ccsrc/dataset/include/dataset/core/data_type.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h
similarity index 100%
rename from mindspore/ccsrc/dataset/include/dataset/core/data_type.h
rename to mindspore/ccsrc/minddata/dataset/include/dataset/core/data_type.h
diff --git a/mindspore/ccsrc/dataset/include/dataset/core/tensor_shape.h b/mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h
similarity index 100%
rename from mindspore/ccsrc/dataset/include/dataset/core/tensor_shape.h
rename to mindspore/ccsrc/minddata/dataset/include/dataset/core/tensor_shape.h
diff --git a/mindspore/ccsrc/dataset/include/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h
similarity index 100%
rename from mindspore/ccsrc/dataset/include/dataset/util/status.h
rename to mindspore/ccsrc/minddata/dataset/include/dataset/util/status.h
diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h
new file mode 100644
index 0000000000..6f38f5ea16
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -0,0 +1,357 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_INCLUDE_DATASETS_H_
+#define DATASET_INCLUDE_DATASETS_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/include/tensor.h"
+#include "minddata/dataset/include/iterator.h"
+#include "minddata/dataset/include/samplers.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Forward declares
+class DatasetOp;
+class DataSchema;
+class Tensor;
+class TensorShape;
+
+namespace api {
+
+class TensorOperation;
+class SamplerObj;
+class ImageFolderDataset;
+class MnistDataset;
+class BatchDataset;
+class RepeatDataset;
+class MapDataset;
+class ShuffleDataset;
+class Cifar10Dataset;
+class ProjectDataset;
+
+/// \brief Function to create an ImageFolderDataset
+/// \notes A source dataset that reads images from a tree of directories.
+///     All images within one folder have the same label.
+///     The generated dataset has two columns ['image', 'label']
+/// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] decode A flag to decode in ImageFolder
+/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
+///     a `RandomSampler` will be used to randomly iterate the entire dataset
+/// \param[in] extensions File extensions to be read
+/// \param[in] class_indexing A map from class names to labels
+/// \return Shared pointer to the current ImageFolderDataset
+std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode = false,
+                                                std::shared_ptr<SamplerObj> sampler = nullptr,
+                                                std::set<std::string> extensions = {},
+                                                std::map<std::string, int32_t> class_indexing = {});
+
+/// \brief Function to create a MnistDataset
+/// \notes The generated dataset has two columns ['image', 'label']
+/// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
+///     a `RandomSampler` will be used to randomly iterate the entire dataset
+/// \return Shared pointer to the current MnistDataset
+std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
+
+/// \brief Function to create a Cifar10 Dataset
+/// \notes The generated dataset has two columns ['image', 'label']
+/// \param[in] dataset_dir Path to the root directory that contains the dataset
+/// \param[in] num_samples The number of images to be included in the dataset
+/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
+///     will be used to randomly iterate the entire dataset
+/// \return Shared pointer to the current Dataset
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, int32_t num_samples,
+                                        std::shared_ptr<SamplerObj> sampler);
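+
+// A minimal end-to-end sketch (illustrative only; the path is a placeholder):
+//   std::shared_ptr<Dataset> ds = ImageFolder("/path/to/dataset", true);
+//   ds = ds->Shuffle(4)->Batch(2);
+//   std::shared_ptr<Iterator> iter = ds->CreateIterator();
+//   TensorMap row;
+//   iter->GetNextRow(&row);  // fetch the first batch as a map of column name to tensor
+//   iter->Stop();
+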
+/// \class Dataset datasets.h
+/// \brief A base class to represent a dataset in the data pipeline.
+class Dataset : public std::enable_shared_from_this<Dataset> {
+ public:
+  friend class Iterator;
+
+  /// \brief Constructor
+  Dataset();
+
+  /// \brief Destructor
+  ~Dataset() = default;
+
+  /// \brief Pure virtual function to convert a Dataset class into a runtime dataset object
+  /// \return shared pointer to the list of newly created DatasetOps
+  virtual std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() = 0;
+
+  /// \brief Pure virtual function for derived classes to implement parameter validation
+  /// \return bool true if all the params are valid
+  virtual bool ValidateParams() = 0;
+
+  /// \brief Setter function for the runtime number of workers
+  /// \param[in] num_workers The number of threads in this operator
+  /// \return Shared pointer to the original object
+  std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers) {
+    num_workers_ = num_workers;
+    return shared_from_this();
+  }
+
+  /// \brief Function to create an Iterator over the Dataset pipeline
+  /// \return Shared pointer to the Iterator
+  std::shared_ptr<Iterator> CreateIterator();
+
+  /// \brief Function to create a BatchDataset
+  /// \notes Combines batch_size number of consecutive rows into batches
+  /// \param[in] batch_size The number of rows each batch is created with
+  /// \param[in] drop_remainder Determines whether or not to drop the last possibly incomplete
+  ///     batch. If true, and if there are fewer than batch_size rows
+  ///     available to make the last batch, then those rows will
+  ///     be dropped and not propagated to the next node
+  /// \return Shared pointer to the current BatchDataset
+  std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
+
+  /// \brief Function to create a RepeatDataset
+  /// \notes Repeats this dataset count times. Repeats indefinitely if count is -1
+  /// \param[in] count Number of times the dataset should be repeated
+  /// \return Shared pointer to the current Dataset
+  /// \note Repeat will return a shared pointer to `Dataset` instead of `RepeatDataset`
+  ///     due to a limitation in the current implementation
+  std::shared_ptr<Dataset> Repeat(int32_t count = -1);
+
+  /// \brief Function to create a MapDataset
+  /// \notes Applies each operation in operations to this dataset
+  /// \param[in] operations Vector of operations to be applied on the dataset. Operations are
+  ///     applied in the order they appear in this list
+  /// \param[in] input_columns Vector of the names of the columns that will be passed to the first
+  ///     operation as input. The size of this list must match the number of
+  ///     input columns expected by the first operator. The default input_columns
+  ///     is the first column
+  /// \param[in] output_columns Vector of names assigned to the columns output by the last operation.
+  ///     This parameter is mandatory if len(input_columns) != len(output_columns).
+  ///     The size of this list must match the number of output columns of the
+  ///     last operation. The default output_columns will have the same
+  ///     name as the input columns, i.e., the columns will be replaced
+  /// \param[in] project_columns A list of column names to project
+  /// \return Shared pointer to the current MapDataset
+  std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorOperation>> operations,
+                                  std::vector<std::string> input_columns = {},
+                                  std::vector<std::string> output_columns = {},
+                                  const std::vector<std::string> &project_columns = {});
+
+  /// \brief Function to create a ShuffleDataset
+  /// \notes Randomly shuffles the rows of this dataset
+  /// \param[in] shuffle_size The size of the buffer (must be larger than 1) for shuffling
+  /// \return Shared pointer to the current ShuffleDataset
+  std::shared_ptr<ShuffleDataset> Shuffle(int32_t shuffle_size);
+
+  /// \brief Function to create a ProjectDataset
+  /// \notes Applies project to the dataset
+  /// \param[in] columns The names of the columns to project
+  /// \return Shared pointer to the current Dataset
+  std::shared_ptr<ProjectDataset> Project(const std::vector<std::string> &columns);
+
+ protected:
+  std::vector<std::shared_ptr<Dataset>> children;
+  std::shared_ptr<Dataset> parent;
+
+  int32_t num_workers_;
+  int32_t rows_per_buffer_;
+  int32_t connector_que_size_;
+};
+
+/* ####################################### Derived Dataset classes ################################# */
+
+/// \class ImageFolderDataset
+/// \brief A Dataset derived class to represent an ImageFolder dataset
+class ImageFolderDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  ImageFolderDataset(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
+                     std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
+
+  /// \brief Destructor
+  ~ImageFolderDataset() = default;
+
+  /// \brief A base class override function to create the required runtime dataset op objects for this class
+  /// \return shared pointer to the list of newly created DatasetOps
+  std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return bool true if all the params are valid
+  bool ValidateParams() override;
+
+ private:
+  std::string dataset_dir_;
+  bool decode_;
+  bool recursive_;
+  std::shared_ptr<SamplerObj> sampler_;
+  std::map<std::string, int32_t> class_indexing_;
+  std::set<std::string> exts_;
+};
+
+class MnistDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler);
+
+  /// \brief Destructor
+  ~MnistDataset() = default;
+
+  /// \brief A base class override function to create the required runtime dataset op objects for this class
+  /// \return shared pointer to the list of newly created DatasetOps
+  std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return bool true if all the params are valid
+  bool ValidateParams() override;
+
+ private:
+  std::string dataset_dir_;
+  std::shared_ptr<SamplerObj> sampler_;
+};
+
+class BatchDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  BatchDataset(int32_t batch_size, bool drop_remainder, bool pad, std::vector<std::string> cols_to_map,
+               std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
+
+  /// \brief Destructor
+  ~BatchDataset() = default;
+
+  /// \brief A base class override function to create the required runtime dataset op objects for this class
+  /// \return shared pointer to the list of newly created DatasetOps
+  std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return bool true if all the params are valid
+  bool ValidateParams() override;
+
+ private:
+  int32_t batch_size_;
+  bool drop_remainder_;
+  bool pad_;
+  std::vector<std::string> cols_to_map_;
+  std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
+};
+
+class RepeatDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  explicit RepeatDataset(uint32_t count);
+ + /// \brief Destructor + ~RepeatDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + uint32_t repeat_count_; +}; + +class ShuffleDataset : public Dataset { + public: + ShuffleDataset(int32_t shuffle_size, bool reset_every_epoch); + + ~ShuffleDataset() = default; + + std::shared_ptr>> Build() override; + + bool ValidateParams() override; + + private: + int32_t shuffle_size_; + uint32_t shuffle_seed_; + bool reset_every_epoch_; +}; + +class MapDataset : public Dataset { + public: + /// \brief Constructor + MapDataset(std::vector> operations, std::vector input_columns = {}, + std::vector output_columns = {}, const std::vector &columns = {}); + + /// \brief Destructor + ~MapDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector> operations_; + std::vector input_columns_; + std::vector output_columns_; + std::vector project_columns_; +}; + +class Cifar10Dataset : public Dataset { + public: + /// \brief Constructor + Cifar10Dataset(const std::string &dataset_dir, int32_t num_samples, std::shared_ptr sampler); + + /// \brief Destructor + ~Cifar10Dataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::string dataset_dir_; + int32_t num_samples_; + std::shared_ptr sampler_; +}; + +class ProjectDataset : public Dataset { + public: + /// \brief Constructor + explicit ProjectDataset(const std::vector &columns); + + /// \brief Destructor + ~ProjectDataset() = default; + + /// \brief a base class override function to create the required runtime dataset op objects for this class + /// \return shared pointer to the list of newly created DatasetOps + std::shared_ptr>> Build() override; + + /// \brief Parameters validation + /// \return bool true if all the params are valid + bool ValidateParams() override; + + private: + std::vector columns_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_INCLUDE_DATASETS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/iterator.h b/mindspore/ccsrc/minddata/dataset/include/iterator.h new file mode 100644 index 0000000000..c3784821a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/iterator.h @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_INCLUDE_ITERATOR_H_ +#define DATASET_INCLUDE_ITERATOR_H_ + +#include +#include +#include +#include +#include "minddata/dataset/include/status.h" + +namespace mindspore { +namespace dataset { + +// Forward declare +class ExecutionTree; +class DatasetIterator; +class DatasetOp; +class Tensor; + +namespace api { + +class Dataset; + +using TensorMap = std::unordered_map>; + +// Abstract class for iterating over the dataset. +class Iterator { + public: + /// \brief Constructor + Iterator() = default; + + /// \brief Destructor + ~Iterator() = default; + + /// \brief Method for building and launching the pipeline. + /// \param[in] ops - a vector of DatasetOp in the data pipeline. + /// \return - a Status error code, returns OK if no error encountered. + Status BuildAndLaunchTree(std::shared_ptr ds); + + /// \brief Function to get the next row from the data pipeline. + /// \param[out] row - the output tensor row. + void GetNextRow(TensorMap *row); + + /// \brief Function to shut down the data pipeline. + void Stop(); + + class _Iterator { + public: + explicit _Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr} { + if (lt_) { + cur_row_ = new TensorMap(); + lt_->GetNextRow(cur_row_); + } + } + + // Destructor + ~_Iterator() { + if (cur_row_) { + delete cur_row_; + } + } + + _Iterator &operator++() { + if (lt_) { + ++ind_; + lt_->GetNextRow(cur_row_); + } + if (cur_row_ && cur_row_->size() == 0) { + delete cur_row_; + cur_row_ = nullptr; + } + return *this; + } // prefix ++ overload + TensorMap &operator*() { return *cur_row_; } // dereference operator + TensorMap *operator->() { return cur_row_; } + + bool operator!=(const _Iterator &rhs) { return cur_row_ != rhs.cur_row_; } + + private: + int ind_; // the cur node our Iterator points to + Iterator *lt_; + TensorMap *cur_row_; + }; + + _Iterator begin() { return _Iterator(this); } + + _Iterator end() { return _Iterator(nullptr); } + + private: + // Runtime tree. + // Use shared_ptr instead of unique_ptr because the DatasetIterator constructor takes in a shared_ptr type. 
+ std::shared_ptr tree_; + + // Runtime iterator + std::unique_ptr iterator_; +}; +} // namespace api +} // namespace dataset +} // namespace mindspore +#endif // DATASET_INCLUDE_ITERATOR_H_ diff --git a/mindspore/ccsrc/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h similarity index 100% rename from mindspore/ccsrc/dataset/include/samplers.h rename to mindspore/ccsrc/minddata/dataset/include/samplers.h diff --git a/mindspore/ccsrc/dataset/include/status.h b/mindspore/ccsrc/minddata/dataset/include/status.h similarity index 100% rename from mindspore/ccsrc/dataset/include/status.h rename to mindspore/ccsrc/minddata/dataset/include/status.h diff --git a/mindspore/ccsrc/dataset/include/tensor.h b/mindspore/ccsrc/minddata/dataset/include/tensor.h similarity index 100% rename from mindspore/ccsrc/dataset/include/tensor.h rename to mindspore/ccsrc/minddata/dataset/include/tensor.h diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h new file mode 100644 index 0000000000..31531a20af --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -0,0 +1,380 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_API_TRANSFORMS_H_ +#define DATASET_API_TRANSFORMS_H_ + +#include +#include +#include "minddata/dataset/core/constants.h" + +namespace mindspore { +namespace dataset { + +class TensorOp; + +namespace api { +// Abstract class to represent a dataset in the data pipeline. +class TensorOperation : public std::enable_shared_from_this { + public: + /// \brief Constructor + TensorOperation(); + + /// \brief Destructor + ~TensorOperation() = default; + + /// \brief Pure virtual function to convert a TensorOperation class into a runtime TensorOp object. + /// \return shared pointer to the newly created TensorOp. + virtual std::shared_ptr Build() = 0; + + virtual bool ValidateParams() = 0; +}; + +// Transform operations for performing computer vision. +namespace vision { + +class NormalizeOperation; +class DecodeOperation; +class ResizeOperation; +class RandomCropOperation; +class CenterCropOperation; +class UniformAugOperation; +class RandomHorizontalFlipOperation; +class RandomVerticalFlipOperation; +class RandomRotationOperation; +class PadOperation; +class CutOutOperation; +class RandomColorAdjustOperation; + +/// \brief Function to create a Normalize TensorOperation. +/// \notes Normalize the input image with respect to mean and standard deviation. +/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. +/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Normalize(std::vector mean, std::vector std); + +/// \brief Function to create a Decode TensorOperation. +/// \notes Decode the input image in RGB mode. 
+/// \param[in] rgb - a boolean of whether to decode in RGB mode or not. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Decode(bool rgb = true); + +/// \brief Function to create a Resize TensorOperation. +/// \notes Resize the input image to the given size.. +/// \param[in] size - a vector representing the output size of the resized image. +/// If size is a single value, the image will be resized to this value with +/// the same image aspect ratio. If size has 2 values, it should be (height, width). +/// \param[in] interpolation An enum for the mode of interpolation +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr Resize(std::vector size, + InterpolationMode interpolation = InterpolationMode::kLinear); + +/// \brief Function to create a RandomCrop TensorOperation. +/// \notes Crop the input image at a random location. +/// \param[in] size - a vector representing the output size of the cropped image. +/// If size is a single value, a square crop of size (size, size) is returned. +/// If size has 2 values, it should be (height, width). +/// \param[in] padding - a vector with the value of pixels to pad the image. If 4 values are provided, +/// it pads the left, top, right and bottom respectively. +/// \param[in] pad_if_needed - a boolean whether to pad the image if either side is smaller than +/// the given output size. +/// \param[in] fill_value - a vector representing the pixel intensity of the borders, it is used to +/// fill R, G, B channels respectively. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomCrop(std::vector size, std::vector padding = {0, 0, 0, 0}, + bool pad_if_needed = false, + std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a CenterCrop TensorOperation. +/// \notes Crops the input image at the center to the given size. +/// \param[in] size - a vector representing the output size of the cropped image. +/// If size is a single value, a square crop of size (size, size) is returned. +/// If size has 2 values, it should be (height, width). +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr CenterCrop(std::vector size); + +/// \brief Function to create a UniformAugment TensorOperation. +/// \notes Tensor operation to perform randomly selected augmentation. +/// \param[in] operations - a vector of TensorOperation operations. +/// \param[in] num_ops - integer representing the number of OPs to be selected and applied. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr UniformAugment(std::vector> operations, + int32_t num_ops = 2); + +/// \brief Function to create a RandomHorizontalFlip TensorOperation. +/// \notes Tensor operation to perform random horizontal flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr RandomHorizontalFlip(float prob = 0.5); + +/// \brief Function to create a RandomVerticalFlip TensorOperation. +/// \notes Tensor operation to perform random vertical flip. +/// \param[in] prob - float representing the probability of flip. +/// \return Shared pointer to the current TensorOperation. 
+std::shared_ptr RandomVerticalFlip(float prob = 0.5); + +/// \brief Function to create a RandomRotation TensorOp +/// \notes Rotates the image according to parameters +/// \param[in] degrees A float vector size 2, representing the starting and ending degree +/// \param[in] resample An enum for the mode of interpolation +/// \param[in] expand A boolean representing whether the image is expanded after rotation +/// \param[in] center A float vector size 2, representing the x and y center of rotation. +/// \param[in] fill_value A uint8_t vector size 3, representing the rgb value of the fill color +/// \return Shared pointer to the current TensorOp +std::shared_ptr RandomRotation( + std::vector degrees, InterpolationMode resample = InterpolationMode::kNearestNeighbour, bool expand = false, + std::vector center = {-1, -1}, std::vector fill_value = {0, 0, 0}); + +/// \brief Function to create a Pad TensorOp +/// \notes Pads the image according to padding parameters +/// \param[in] padding A vector representing the number of pixels to pad the image +/// If vector has one value, it pads all sides of the image with that value +/// If vector has two values, it pads left and right with the first and +/// top and bottom with the second value +/// If vector has four values, it pads left, top, right, and bottom with +/// those values respectively +/// \param[in] fill_value A vector representing the pixel intensity of the borders if the padding_mode is +/// BorderType.kConstant. If 3 values are provided, +/// it is used to fill R, G, B channels respectively +/// \param[in] padding_mode The method of padding (default=BorderType.kConstant) +/// Can be any of +/// [BorderType.kConstant, BorderType.kEdge, BorderType.kReflect, BorderType.kSymmetric] +/// - BorderType.kConstant, means it fills the border with constant values +/// - BorderType.kEdge, means it pads with the last value on the edge +/// - BorderType.kReflect, means it reflects the values on the edge omitting the last value of edge +/// - BorderType.kSymmetric, means it reflects the values on the edge repeating the last value of edge +/// \return Shared pointer to the current TensorOp +std::shared_ptr Pad(std::vector padding, std::vector fill_value = {0}, + BorderType padding_mode = BorderType::kConstant); + +/// \brief Function to create a CutOut TensorOp +/// \notes Randomly cut (mask) out a given number of square patches from the input image +/// \param[in] length Integer representing the side length of each square patch +/// \param[in] num_patches Integer representing the number of patches to be cut out of an image +/// \return Shared pointer to the current TensorOp +std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); + +/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image +/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values +/// if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1} +/// \param[in] hue Brightness adjustment factor. 
+/// \brief Randomly adjust the brightness, contrast, saturation, and hue of the input image
+/// \param[in] brightness Brightness adjustment factor. Must be a vector of one or two values
+///     if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
+/// \param[in] contrast Contrast adjustment factor. Must be a vector of one or two values
+///     if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
+/// \param[in] saturation Saturation adjustment factor. Must be a vector of one or two values
+///     if it's a vector of two values it needs to be in the form of [min, max]. Default value is {1, 1}
+/// \param[in] hue Hue adjustment factor. Must be a vector of one or two values
+///     if it's a vector of two values it must be in the form of [min, max] where -0.5 <= min <= max <= 0.5
+///     Default value is {0, 0}
+/// \return Shared pointer to the current TensorOp
+std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float> brightness = {1.0, 1.0},
+                                                              std::vector<float> contrast = {1.0, 1.0},
+                                                              std::vector<float> saturation = {1.0, 1.0},
+                                                              std::vector<float> hue = {0.0, 0.0});
+
+/* ####################################### Derived TensorOperation classes ################################# */
+
+class NormalizeOperation : public TensorOperation {
+ public:
+  NormalizeOperation(std::vector<float> mean, std::vector<float> std);
+
+  ~NormalizeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<float> mean_;
+  std::vector<float> std_;
+};
+
+class DecodeOperation : public TensorOperation {
+ public:
+  explicit DecodeOperation(bool rgb = true);
+
+  ~DecodeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  bool rgb_;
+};
+
+class ResizeOperation : public TensorOperation {
+ public:
+  explicit ResizeOperation(std::vector<int32_t> size,
+                           InterpolationMode interpolation_mode = InterpolationMode::kLinear);
+
+  ~ResizeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<int32_t> size_;
+  InterpolationMode interpolation_;
+};
+
+class RandomCropOperation : public TensorOperation {
+ public:
+  RandomCropOperation(std::vector<int32_t> size, std::vector<int32_t> padding = {0, 0, 0, 0},
+                      bool pad_if_needed = false, std::vector<uint8_t> fill_value = {0, 0, 0});
+
+  ~RandomCropOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<int32_t> size_;
+  std::vector<int32_t> padding_;
+  bool pad_if_needed_;
+  std::vector<uint8_t> fill_value_;
+};
+
+class CenterCropOperation : public TensorOperation {
+ public:
+  explicit CenterCropOperation(std::vector<int32_t> size);
+
+  ~CenterCropOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<int32_t> size_;
+};
+
+class UniformAugOperation : public TensorOperation {
+ public:
+  explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> operations, int32_t num_ops = 2);
+
+  ~UniformAugOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<std::shared_ptr<TensorOperation>> operations_;
+  int32_t num_ops_;
+};
+
+class RandomHorizontalFlipOperation : public TensorOperation {
+ public:
+  explicit RandomHorizontalFlipOperation(float probability = 0.5);
+
+  ~RandomHorizontalFlipOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  float probability_;
+};
+
+class RandomVerticalFlipOperation : public TensorOperation {
+ public:
+  explicit RandomVerticalFlipOperation(float probability = 0.5);
+
+  ~RandomVerticalFlipOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  float probability_;
+};
+
+class RandomRotationOperation : public TensorOperation {
+ public:
+  RandomRotationOperation(std::vector<float> degrees, InterpolationMode interpolation_mode, bool expand,
+                          std::vector<float> center, std::vector<uint8_t> fill_value);
+
+  ~RandomRotationOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<float> degrees_;
+  InterpolationMode interpolation_mode_;
+  std::vector<float> center_;
+  bool expand_;
+  std::vector<uint8_t> fill_value_;
+};
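Each factory function returns one of the Operation classes here; consumers validate parameters first, then let Build() materialize the kernel TensorOp. A schematic of that two-phase contract (hypothetical helper, not from this patch):

```cpp
#include <memory>

// Hedged sketch: real callers propagate an error Status instead of nullptr.
std::shared_ptr<TensorOp> BuildChecked(const std::shared_ptr<TensorOperation> &op) {
  if (op == nullptr || !op->ValidateParams()) {
    return nullptr;  // invalid parameters detected before any kernel is built
  }
  return op->Build();
}
```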
+class PadOperation : public TensorOperation {
+ public:
+  PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
+               BorderType padding_mode = BorderType::kConstant);
+
+  ~PadOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<int32_t> padding_;
+  std::vector<uint8_t> fill_value_;
+  BorderType padding_mode_;
+};
+
+class CutOutOperation : public TensorOperation {
+ public:
+  explicit CutOutOperation(int32_t length, int32_t num_patches = 1);
+
+  ~CutOutOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  int32_t length_;
+  int32_t num_patches_;
+};
+
+class RandomColorAdjustOperation : public TensorOperation {
+ public:
+  RandomColorAdjustOperation(std::vector<float> brightness = {1.0, 1.0}, std::vector<float> contrast = {1.0, 1.0},
+                             std::vector<float> saturation = {1.0, 1.0}, std::vector<float> hue = {0.0, 0.0});
+
+  ~RandomColorAdjustOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  bool ValidateParams() override;
+
+ private:
+  std::vector<float> brightness_;
+  std::vector<float> contrast_;
+  std::vector<float> saturation_;
+  std::vector<float> hue_;
+};
+}  // namespace vision
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_API_TRANSFORMS_H_
diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h b/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h
new file mode 120000
index 0000000000..f2c939bc0b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/include/utils/log_adapter.h
@@ -0,0 +1 @@
+../../../../utils/log_adapter.h
\ No newline at end of file
diff --git a/mindspore/ccsrc/minddata/dataset/include/utils/overload.h b/mindspore/ccsrc/minddata/dataset/include/utils/overload.h
new file mode 120000
index 0000000000..7dc313d512
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/include/utils/overload.h
@@ -0,0 +1 @@
+../../../../utils/overload.h
\ No newline at end of file
diff --git a/mindspore/ccsrc/dataset/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/kernels/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/kernels/CMakeLists.txt
diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/kernels/data/CMakeLists.txt
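Taken together, the factory functions above form the C++ counterpart of the Python vision transforms. A minimal usage sketch, assuming the namespace layout declared in this header (the surrounding pipeline plumbing is assumed, not part of this patch):

```cpp
#include <memory>
#include <vector>

#include "minddata/dataset/include/transforms.h"

using mindspore::dataset::api::TensorOperation;
namespace vision = mindspore::dataset::api::vision;

// Compose a small augmentation chain; each factory returns a TensorOperation
// subclass that a dataset Map operation could later Build() into a kernel.
std::vector<std::shared_ptr<TensorOperation>> MakeTrainTransforms() {
  return {vision::Decode(true),
          vision::RandomCrop({32, 32}, {4, 4, 4, 4}),
          vision::RandomHorizontalFlip(0.5),
          vision::CutOut(8, 1)};
}
```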
+ */ +#include "minddata/dataset/kernels/data/concatenate_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status ConcatenateOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + RETURN_IF_NOT_OK(Concatenate(input, output, axis_, prepend_, append_)); + return Status::OK(); +} + +Status ConcatenateOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + + std::vector inputs_copy; + inputs_copy.push_back(inputs[0].Squeeze()); + + CHECK_FAIL_RETURN_UNEXPECTED(inputs.at(0).Rank() == 1, "Only 1D input tensors supported"); + + outputs.clear(); + dsize_t output_shape = 0; + output_shape = output_shape + inputs.at(0).NumOfElements(); + if (prepend_ != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(prepend_->shape().Rank() == 1, "Only 1D prepend tensors supported"); + output_shape = output_shape + prepend_->shape().NumOfElements(); + } + if (append_ != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(append_->shape().Rank() == 1, "Only 1D append tensors supported"); + output_shape = output_shape + append_->shape().NumOfElements(); + } + + outputs.emplace_back(std::vector{output_shape}); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h new file mode 100644 index 0000000000..46cc613049 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_ +#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class ConcatenateOp : public TensorOp { + public: + /// Constructor to ConcatenateOp. + /// @param int8_t axis - axis to concatenate tensors along. + /// @param std::shared_ptr prepend - prepend tensor. + /// @param std::shared_ptr append -append tensor. + explicit ConcatenateOp(int8_t axis, std::shared_ptr prepend, std::shared_ptr append) + : axis_(axis), prepend_(prepend), append_(append) {} + + ~ConcatenateOp() override = default; + + /// Print method to see which tensor Op this is. + /// @param std::ostream &out - output stream object. 
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h
new file mode 100644
index 0000000000..46cc613049
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/concatenate_op.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_DATA_CONCATENATE_OP_H_
+#define DATASET_KERNELS_DATA_CONCATENATE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+class ConcatenateOp : public TensorOp {
+ public:
+  /// Constructor to ConcatenateOp.
+  /// @param int8_t axis - axis to concatenate tensors along.
+  /// @param std::shared_ptr<Tensor> prepend - prepend tensor.
+  /// @param std::shared_ptr<Tensor> append - append tensor.
+  explicit ConcatenateOp(int8_t axis, std::shared_ptr<Tensor> prepend, std::shared_ptr<Tensor> append)
+      : axis_(axis), prepend_(prepend), append_(append) {}
+
+  ~ConcatenateOp() override = default;
+
+  /// Print method to see which tensor Op this is.
+  /// @param std::ostream &out - output stream object.
+  void Print(std::ostream &out) const override { out << "ConcatenateOp"; }
+
+  /// Compute method allowing multiple tensors as inputs
+  /// @param TensorRow &input - input tensor rows
+  /// @param TensorRow *output - output tensor rows
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  /// Compute tensor output shape
+  /// @param std::vector<TensorShape> &inputs - vector of input tensor shapes
+  /// @param std::vector<TensorShape> &outputs - vector of output tensor shapes
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  /// Number of inputs the tensor operation accepts
+  uint32_t NumInput() override { return 0; }
+
+  std::string Name() const override { return kConcatenateOp; }
+
+ private:
+  int8_t axis_;
+  std::shared_ptr<Tensor> prepend_;
+  std::shared_ptr<Tensor> append_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_DATA_CONCATENATE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
new file mode 100644
index 0000000000..b1d51a6c08
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
@@ -0,0 +1,656 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/kernels/data/data_utils.h"
+
+#include <algorithm>
+#include <limits>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/data_type.h"
+#ifdef ENABLE_PYTHON
+#include "minddata/dataset/core/pybind_support.h"
+#endif
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/kernels/data/type_cast_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                              dsize_t num_classes, int64_t index) {
+  uint64_t class_idx;
+  if (input->Rank() == 0) {
+    RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {}));
+  } else {
+    RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index}));
+  }
+  if (class_idx >= static_cast<uint64_t>(num_classes)) {
+    RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
+  }
+  if (input->type() == DataType::DE_UINT64) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<uint64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_UINT32) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<uint32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_UINT16) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<uint16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_UINT8) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<uint8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else {
+    RETURN_STATUS_UNEXPECTED("One hot unsigned only supports unsigned int as input.");
+  }
+  return Status::OK();
+}
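Signed and unsigned inputs are dispatched to separate helpers only so that SetItemAt can use the matching element type; the encoding itself is the familiar one-liner. A standalone sketch outside the Tensor API (not code from this patch):

```cpp
#include <cstddef>
#include <vector>

// Hedged sketch: index 2 with num_classes 3 produces {0, 0, 1}.
std::vector<int> OneHotSketch(size_t class_idx, size_t num_classes) {
  std::vector<int> row(num_classes, 0);
  if (class_idx < num_classes) row[class_idx] = 1;  // out-of-range indices are an error in the real helper
  return row;
}
```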
+
+Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                            dsize_t num_classes, int64_t index) {
+  int64_t class_idx;
+  if (input->Rank() == 0) {
+    RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {}));
+  } else {
+    RETURN_IF_NOT_OK(input->GetItemAt(&class_idx, {index}));
+  }
+  if (class_idx >= static_cast<int64_t>(num_classes)) {
+    RETURN_STATUS_UNEXPECTED("One_hot index values are not in range");
+  }
+  if (input->type() == DataType::DE_INT64) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<int64_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_INT32) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<int32_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_INT16) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<int16_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else if (input->type() == DataType::DE_INT8) {
+    RETURN_IF_NOT_OK((*output)->SetItemAt<int8_t>({index, static_cast<dsize_t>(class_idx)}, 1));
+  } else {
+    RETURN_STATUS_UNEXPECTED("One hot signed only supports signed int as input.");
+  }
+  return Status::OK();
+}
+
+Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
+  input->Squeeze();
+
+  if (input->Rank() > 1) {  // We expect the input to be in the first dimension
+    RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
+  }
+  if (!input->type().IsInt()) {
+    RETURN_STATUS_UNEXPECTED("One hot does not support input of this type.");
+  }
+  try {
+    dsize_t num_elements = 1;
+    if (input->Rank() == 1) num_elements = input->shape()[0];
+    TensorShape out_shape({num_elements, num_classes});
+    std::shared_ptr<Tensor> out;
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, out_shape, input->type()));
+    RETURN_IF_NOT_OK(out->Zero());
+    for (dsize_t i = 0; i < num_elements; ++i) {
+      if (input->type().IsUnsignedInt()) {
+        RETURN_IF_NOT_OK(OneHotEncodingUnsigned(input, &out, num_classes, i));
+      } else {
+        RETURN_IF_NOT_OK(OneHotEncodingSigned(input, &out, num_classes, i));
+      }
+    }
+    out->Squeeze();
+    *output = out;
+    return Status::OK();
+  } catch (const std::exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in OneHotOp");
+  }
+}
+
+Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output,
+            std::shared_ptr<Tensor> fill_value) {
+  const DataType &fill_type = fill_value->type();
+  const DataType &input_type = input->type();
+  const TensorShape &input_shape = input->shape();
+
+  CHECK_FAIL_RETURN_UNEXPECTED(!((fill_type == DataType::DE_STRING) && (input_type != DataType::DE_STRING)),
+                               "Types do not match");
+
+  CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
+
+  std::shared_ptr<Tensor> out, fill_output;
+
+  if (input_type != DataType::DE_STRING && fill_type != DataType::DE_STRING && input_type != fill_type) {
+    auto op = std::make_unique<TypeCastOp>(input_type);
+    RETURN_IF_NOT_OK(op->Compute(fill_value, &fill_output));
+  } else {
+    fill_output = fill_value;
+  }
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input_shape, input_type));
+
+  switch (input_type.value()) {
+    case DataType::DE_BOOL: {
+      bool value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT8: {
+      int8_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT8: {
+      uint8_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT16: {
+      uint16_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_INT16: {
+      int16_t value = 0;
+      RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
+      out->Fill(value);
+      break;
+    }
+    case DataType::DE_UINT32: {
+ uint32_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT32: { + int32_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_UINT64: { + uint64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_INT64: { + int64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT16: { + int64_t value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT32: { + float value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_FLOAT64: { + double value = 0; + RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {})); + out->Fill(value); + break; + } + case DataType::DE_STRING: { + std::vector strings; + std::string_view fill_string_view; + RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {})); + std::string fill_string = std::string(fill_string_view); + for (int i = 0; i < input_shape.NumOfElements(); i++) { + strings.emplace_back(fill_string); + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input_shape)); + break; + } + case DataType::DE_UNKNOWN: { + RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type."); + break; + } + } + + *output = out; + return Status::OK(); +} +template +void Cast(const std::shared_ptr &input, std::shared_ptr *output) { + auto in_itr = input->begin(); + auto out_itr = (*output)->begin(); + auto out_end = (*output)->end(); + + for (; out_itr != out_end; static_cast(in_itr++), static_cast(out_itr++)) + *out_itr = static_cast(*in_itr); +} + +template +void CastFrom(const std::shared_ptr &input, std::shared_ptr *output) { + switch ((*output)->type().value()) { + case DataType::DE_BOOL: + Cast(input, output); + break; + case DataType::DE_INT8: + Cast(input, output); + break; + case DataType::DE_UINT8: + Cast(input, output); + break; + case DataType::DE_INT16: + Cast(input, output); + break; + case DataType::DE_UINT16: + Cast(input, output); + break; + case DataType::DE_INT32: + Cast(input, output); + break; + case DataType::DE_UINT32: + Cast(input, output); + break; + case DataType::DE_INT64: + Cast(input, output); + break; + case DataType::DE_UINT64: + Cast(input, output); + break; + case DataType::DE_FLOAT16: + Cast(input, output); + break; + case DataType::DE_FLOAT32: + Cast(input, output); + break; + case DataType::DE_FLOAT64: + Cast(input, output); + break; + case DataType::DE_UNKNOWN: + MS_LOG(ERROR) << "Unknown data type."; + break; + } +} + +// Type cast operator +Status TypeCast(const std::shared_ptr &input, std::shared_ptr *output, const DataType &data_type) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type)); + + RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + switch (input->type().value()) { + case DataType::DE_BOOL: + CastFrom(input, output); + break; + case DataType::DE_INT8: + CastFrom(input, output); + break; + case DataType::DE_UINT8: + CastFrom(input, output); + break; + case DataType::DE_INT16: + CastFrom(input, output); + break; + case DataType::DE_UINT16: + CastFrom(input, output); + break; + case DataType::DE_INT32: + CastFrom(input, output); + break; + case DataType::DE_UINT32: + CastFrom(input, output); + break; + case DataType::DE_INT64: + 
CastFrom(input, output); + break; + case DataType::DE_UINT64: + CastFrom(input, output); + break; + case DataType::DE_FLOAT16: + CastFrom(input, output); + break; + case DataType::DE_FLOAT32: + CastFrom(input, output); + break; + case DataType::DE_FLOAT64: + CastFrom(input, output); + break; + case DataType::DE_UNKNOWN: + // sanity check, unreachable code. + RETURN_STATUS_UNEXPECTED("TypeCast does not support input of this type."); + } + return Status::OK(); +} + +Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { + // initiate new tensor for type cast + DataType new_type = DataType("float16"); + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type)); + RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes())); + + auto in_itr = input->begin(); + auto out_itr = (*output)->begin(); + auto out_end = (*output)->end(); + + for (; out_itr != out_end; in_itr++, out_itr++) { + float element = *in_itr; + float float16_max = static_cast(std::numeric_limits::max()); + float float16_min = static_cast(std::numeric_limits::lowest()); + if (element > float16_max || element < float16_min) { + RETURN_STATUS_UNEXPECTED("Value " + std::to_string(element) + " is outside of valid float16 range [" + + std::to_string(float16_max) + ", " + std::to_string(float16_min) + "]."); + } + + *out_itr = Eigen::half(*in_itr); + } + + return Status::OK(); +} + +Status PadEnd(const std::shared_ptr &src, std::shared_ptr *dst, const std::vector &pad_shape, + const std::shared_ptr &pad_val) { + if (pad_val == nullptr) { + if (src->type().IsNumeric()) { + return PadEndNumeric(src, dst, pad_shape, 0); + } else { + return PadEndString(src, dst, pad_shape, ""); + } + } + CHECK_FAIL_RETURN_UNEXPECTED(src->type().IsNumeric() == pad_val->type().IsNumeric(), + "Source and pad_value tensors are not of the same type."); + if (pad_val->type().IsNumeric()) { + std::shared_ptr float_pad_value; + RETURN_IF_NOT_OK(TypeCast(pad_val, &float_pad_value, DataType(DataType::DE_FLOAT32))); + float val = 0; + RETURN_IF_NOT_OK(float_pad_value->GetItemAt(&val, {})); + return PadEndNumeric(src, dst, pad_shape, val); + } + std::string_view val; + RETURN_IF_NOT_OK(pad_val->GetItemAt(&val, {})); + return PadEndString(src, dst, pad_shape, std::string(val)); +} + +Status PadEndNumeric(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, float pad_val) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); + if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { + (*dst) = src; // if no padding, copy the pointer + } else { + CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); + RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, TensorImpl::kFlexible, TensorShape(pad_shape), src->type())); + auto tensor_type = src->type().value(); + if (pad_val == 0) { // if pad with zero, don't care what type it is + RETURN_IF_NOT_OK((*dst)->Zero()); + } else if (tensor_type == DataType::DE_INT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_BOOL) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT8) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT16) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT16) { + RETURN_IF_NOT_OK((*dst)->Fill(static_cast(pad_val))); + } else if (tensor_type == DataType::DE_UINT16) { + 
RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_INT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_UINT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT32) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else if (tensor_type == DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK((*dst)->Fill(pad_val)); + } else { + RETURN_STATUS_UNEXPECTED("Incorrect/Unknown tensor type"); + } + std::vector cur_ind(src->Rank(), 0); + RETURN_IF_NOT_OK(PadEndNumericHelper(src, *dst, cur_ind, 0)); + } + return Status::OK(); +} +Status PadEndNumericHelper(const std::shared_ptr &src, std::shared_ptr dst, + std::vector cur_ind, size_t cur_dim) { + if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data + dst->CopyLastDimAt(src, cur_ind); + } else { // not the last dimension, keep doing recursion + dsize_t min_ind = std::min(dst->shape()[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + RETURN_IF_NOT_OK(PadEndNumericHelper(src, dst, cur_ind, cur_dim + 1)); + } + } + return Status::OK(); +} + +Status PadEndString(const std::shared_ptr &src, std::shared_ptr *dst, + const std::vector &pad_shape, const std::string &pad_val) { + CHECK_FAIL_RETURN_UNEXPECTED(src != nullptr && dst != nullptr, "tensor can't be nullptr"); + if (src->Rank() == 0 || src->shape().AsVector() == pad_shape) { + (*dst) = src; // if no padding, copy the pointer + } else { + CHECK_FAIL_RETURN_UNEXPECTED(src->Rank() == pad_shape.size(), "Pad to diff rank not allowed"); + std::vector cur_ind(src->Rank(), 0); + std::vector strings; + RETURN_IF_NOT_OK(PadEndStringHelper(src, &strings, TensorShape(pad_shape), cur_ind, 0, pad_val)); + RETURN_IF_NOT_OK(Tensor::CreateTensor(dst, strings, TensorShape(pad_shape))); + } + return Status::OK(); +} + +Status PadEndStringHelper(const std::shared_ptr &src, std::vector *dst, + const TensorShape &dst_shape, std::vector cur_ind, size_t cur_dim, + const std::string &pad_value) { + if (cur_dim == src->Rank() - 1) { // if this is the last dimension, copy the data + dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + std::string_view item; + RETURN_IF_NOT_OK(src->GetItemAt(&item, cur_ind)); + dst->emplace_back(item); + } + for (dsize_t i = min_ind; i < dst_shape[cur_dim]; i++) { + dst->emplace_back(pad_value); + } + + } else { // not the last dimension, keep doing recursion + dsize_t min_ind = std::min(dst_shape[cur_dim], src->shape()[cur_dim]); + for (dsize_t i = 0; i < min_ind; i++) { + cur_ind[cur_dim] = i; + RETURN_IF_NOT_OK(PadEndStringHelper(src, dst, dst_shape, cur_ind, cur_dim + 1, pad_value)); + } + dsize_t count = (dst_shape[cur_dim] - min_ind) * dst_shape.Strides()[cur_dim]; + for (dsize_t i = 0; i < count; i++) { + dst->emplace_back(pad_value); + } + } + return Status::OK(); +} + +template +Status MaskHelper(const std::shared_ptr &input, const std::shared_ptr &output, + const std::shared_ptr &value_tensor, RelationalOp op) { + T value; + RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {})); + auto in_itr = input->begin(); + auto out_itr = output->begin(); + for (; in_itr != input->end(); in_itr++, out_itr++) { + switch (op) { + case 
RelationalOp::kEqual: + *out_itr = (*in_itr == value); + break; + case RelationalOp::kNotEqual: + *out_itr = (*in_itr != value); + break; + case RelationalOp::kGreater: + *out_itr = (*in_itr > value); + break; + case RelationalOp::kGreaterEqual: + *out_itr = (*in_itr >= value); + break; + case RelationalOp::kLess: + *out_itr = (*in_itr < value); + break; + case RelationalOp::kLessEqual: + *out_itr = (*in_itr <= value); + break; + default: + RETURN_STATUS_UNEXPECTED("Unknown relational operator."); + } + } + return Status::OK(); +} + +Status Mask(const std::shared_ptr &input, std::shared_ptr *output, const std::shared_ptr &value, + RelationalOp op) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type().IsNumeric() == value->type().IsNumeric(), + "Cannot convert constant value to the type of the input tensor."); + CHECK_FAIL_RETURN_UNEXPECTED(value->shape() == TensorShape::CreateScalar(), "Value is not a scalar"); + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), DataType(DataType::DE_BOOL))); + + std::unique_ptr value_cast_op(new TypeCastOp(input->type())); + std::shared_ptr casted_value; + if (input->type().IsNumeric()) { + RETURN_IF_NOT_OK(value_cast_op->Compute(value, &casted_value)); + } else { + casted_value = value; + } + + switch (input->type().value()) { + case DataType::DE_BOOL: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT8: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT8: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UINT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_INT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT16: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT32: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_FLOAT64: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_STRING: + RETURN_IF_NOT_OK(MaskHelper(input, *output, casted_value, op)); + break; + case DataType::DE_UNKNOWN: + RETURN_STATUS_UNEXPECTED("Unsupported input type."); + break; + } + return Status::OK(); +} + +Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, + std::shared_ptr append) { + CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Rank() == 1, "Only 1D tensors supported"); + CHECK_FAIL_RETURN_UNEXPECTED(axis == 0 || axis == -1, "Only concatenation along the last dimension supported"); + + axis = Tensor::HandleNeg(axis, input[0]->shape().Rank()); + CHECK_FAIL_RETURN_UNEXPECTED(axis == 0, "Only axis=0 is supported"); + + std::shared_ptr out; + if (prepend != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(prepend->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(prepend, &out, axis, input[0])); + } else { + out = input[0]; + } + for (dsize_t i = 1; i < input.size(); i++) { + std::shared_ptr 
out_t; + CHECK_FAIL_RETURN_UNEXPECTED(input[i]->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, input[i])); + out = out_t; + } + std::shared_ptr out_t; + if (append != nullptr) { + CHECK_FAIL_RETURN_UNEXPECTED(append->shape().Rank() == 1, "Only 1D tensors supported"); + RETURN_IF_NOT_OK(ConcatenateHelper(out, &out_t, axis, append)); + } else { + out_t = out; + } + output->push_back(out_t); + + return Status::OK(); +} + +Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, + std::shared_ptr append) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == append->type(), "Tensor types do not match"); + + TensorShape t({}); + + for (dsize_t i = 0; i < input->shape().Rank(); i++) { + if (i != axis) { + t = t.AppendDim(input->shape()[i]); + } else { + dsize_t new_shape = input->shape()[i] + append->shape()[i]; + + t = t.AppendDim(new_shape); + } + } + std::shared_ptr out; + + if (input->type().IsNumeric()) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, t, input->type())); + + RETURN_IF_NOT_OK(out->Concatenate({0}, input)); + RETURN_IF_NOT_OK(out->Concatenate({input->shape()[0]}, append)); + *output = out; + } else { + std::vector strings; + + auto itr = input->begin(); + for (; itr != input->end(); itr++) { + strings.emplace_back(*itr); + } + itr = append->begin(); + for (; itr != append->end(); itr++) { + strings.emplace_back(*itr); + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, t)); + + *output = out; + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h new file mode 100644 index 0000000000..141545a583 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_DATA_UTILS_H_ +#define DATASET_KERNELS_DATA_DATA_UTILS_H_ + +#include +#include +#include +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" + +namespace mindspore { +namespace dataset { +// Returns Onehot encoding of the input tensor. +// Example: if input=2 and numClasses=3, the output is [0 0 1]. +// @param input: Tensor has type DE_UINT64, the non-one hot values are stored +// along the first dimensions or rows.. +// If the rank of input is not 1 or the type is not DE_UINT64, +// then it will fail. +// @param output: Tensor. The shape of the output tensor is +// and the type is same as input. +// @param num_classes: Number of classes to. 
+Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes);
+
+Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                              dsize_t num_classes, int64_t index);
+
+Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                            dsize_t num_classes, int64_t index);
+
+// Returns a tensor of shape input filled with the passed fill_value
+// @param input Tensor
+// @param output Tensor. The shape and type of the output tensor is same as input
+// @param fill_value Tensor. A scalar tensor used to fill the output tensor
+
+Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output,
+            std::shared_ptr<Tensor> fill_value);
+
+// Returns a type changed input tensor.
+// Example: if input tensor is float64, the output will be the specified dataType. See DataTypes.cpp
+// @param input Tensor
+// @param output Tensor. The shape of the output tensor is same as input with the type changed.
+// @param data_type: type of data to cast data to
+// @note: this operation will do a memcpy and if the value is truncated then precision will be lost
+
+template <typename T>
+void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+template <typename T, typename To>
+void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type);
+
+// Pad input tensor according to pad_shape, need to have same rank.
+// Based on the type of the input tensor, PadEndNumeric/String will be called.
+// @param std::shared_ptr<Tensor> src - tensor to pad from
+// @param std::shared_ptr<Tensor> *dst - return tensor padded
+// @param std::vector<dsize_t> pad_shape - shape to pad to
+// @param std::shared_ptr<Tensor> pad_val - value to pad with in Tensor format,
+// @return - The error code return
+Status PadEnd(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
+              const std::vector<dsize_t> &pad_shape, const std::shared_ptr<Tensor> &pad_val);
+
+// Pad input numeric tensor according to pad_shape, need to have same rank.
+// @param std::shared_ptr<Tensor> src - tensor to pad from
+// @param std::shared_ptr<Tensor> *dst - return tensor padded
+// @param std::vector<dsize_t> pad_shape - shape to pad to
+// @param float pad_val - value to pad with
+// @return - The error code return
+Status PadEndNumeric(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
+                     const std::vector<dsize_t> &pad_shape, float pad_val);
+
+// recursive helper function for padding numeric tensors. This function could be very expensive if called on a
+// multi-dimensional tensor it is only meant to be called by PadEndNumeric.
+// @tparam T - type of tensor and fill value
+// @param std::shared_ptr<Tensor> src - Tensor to pad from
+// @param std::shared_ptr<Tensor> *dst - Tensor to pad to, return value
+// @param std::vector<dsize_t> cur_ind - recursion helper
+// @param T pad_val - value to pad tensor with
+// @param size_t cur_dim - recursion helper
+// @return Status - The error code return
+Status PadEndNumericHelper(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> dst,
+                           std::vector<dsize_t> cur_ind, size_t cur_dim = 0);
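PadEnd grows a tensor to a target shape, keeping existing values in place and filling the new trailing cells. In one dimension the effect reduces to the following (illustrative sketch only, outside the Tensor API):

```cpp
#include <cstddef>
#include <vector>

// Hedged 1-D sketch of PadEndNumeric semantics:
// PadEnd1D({1, 2}, 4, 0) -> {1, 2, 0, 0}.
std::vector<float> PadEnd1D(const std::vector<float> &src, size_t target_len, float pad_val) {
  std::vector<float> dst(target_len, pad_val);
  for (size_t i = 0; i < src.size() && i < target_len; ++i) dst[i] = src[i];
  return dst;
}
```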
+// Pad input string tensor according to pad_shape, need to have same rank.
+// @param std::shared_ptr<Tensor> src - tensor to pad from
+// @param std::shared_ptr<Tensor> *dst - return tensor padded
+// @param std::vector<dsize_t> pad_shape - shape to pad to
+// @param std::string pad_val - value to pad with
+// @return - The error code return
+Status PadEndString(const std::shared_ptr<Tensor> &src, std::shared_ptr<Tensor> *dst,
+                    const std::vector<dsize_t> &pad_shape, const std::string &pad_val);
+
+// recursive helper function for padding string tensors. This function could be very expensive if called on a
+// multi-dimensional tensor it is only meant to be called by PadEndString.
+// @tparam T - type of tensor and fill value
+// @param std::shared_ptr<Tensor> src - Tensor to pad from
+// @param std::vector<std::string> *dst - Tensor to pad to, return value
+// @param std::vector<dsize_t> cur_ind - recursion helper
+// @param std::string pad_val - value to pad tensor with
+// @param size_t cur_dim - recursion helper
+// @return Status - The error code return
+Status PadEndStringHelper(const std::shared_ptr<Tensor> &src, std::vector<std::string> *dst,
+                          const TensorShape &dst_shape, std::vector<dsize_t> cur_ind, size_t cur_dim,
+                          const std::string &pad_value);
+
+enum class RelationalOp {
+  kEqual = 0,     // ==
+  kNotEqual,      // !=
+  kLess,          // <
+  kLessEqual,     // <=
+  kGreater,       // >
+  kGreaterEqual,  // >=
+};
+
+/// Helper method that masks the input tensor
+/// @tparam T type of the tensor
+/// @param input[in] input tensor
+/// @param output[out] output tensor
+/// @param value_tensor[in] scalar tensor value to compared with
+/// @param op[in] RelationalOp enum
+/// @return Status ok/error
+template <typename T>
+Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &output,
+                  const std::shared_ptr<Tensor> &value_tensor, RelationalOp op);
+
+/// Mask the input tensor
+/// @param input[in] input tensor
+/// @param output[out] output tensor
+/// @param value[in] scalar tensor value to compared with
+/// @param op[in] RelationalOp enum
+/// @return Status ok/error
+Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+            const std::shared_ptr<Tensor> &value, RelationalOp op);
+
+Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
+                   std::shared_ptr<Tensor> append);
+
+// helper for concat, always append to the input, and pass that to the output
+Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
+                         std::shared_ptr<Tensor> append);
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_DATA_DATA_UTILS_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc
new file mode 100644
index 0000000000..57a424704f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.cc
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "minddata/dataset/kernels/data/duplicate_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status DuplicateOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + std::shared_ptr out; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, input[0])); + output->push_back(input[0]); + output->push_back(out); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h new file mode 100644 index 0000000000..60b2d8c33b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/duplicate_op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_DUPLICATE_OP_H_ +#define DATASET_KERNELS_DATA_DUPLICATE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class DuplicateOp : public TensorOp { + public: + DuplicateOp() = default; + + ~DuplicateOp() override = default; + + void Print(std::ostream &out) const override { out << "DuplicateOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + uint32_t NumOutput() override { return 2; } + + std::string Name() const override { return kDuplicateOp; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DUPLICATE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc new file mode 100644 index 0000000000..f8dc746dff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/data/fill_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status FillOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = Fill(input, output, fill_value_); + return s; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h new file mode 100644 index 0000000000..af0d9e7941 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/fill_op.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_DATA_FILL_OP_H_ +#define DATASET_KERNELS_DATA_FILL_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class FillOp : public TensorOp { + public: + explicit FillOp(std::shared_ptr value) : fill_value_(value) {} + + ~FillOp() override = default; + void Print(std::ostream &out) const override { out << "FillOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kFillOp; } + + private: + std::shared_ptr fill_value_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_FILL_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc new file mode 100644 index 0000000000..2dbe501a47 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/data/mask_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +Status MaskOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + std::shared_ptr temp_output; + CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(), "Cannot generate a string mask. 
Type should be numeric."); + + RETURN_IF_NOT_OK(Mask(input, &temp_output, value_, op_)); + + // cast the output to the the required type. Skip casting if type_ is bool. + if (type_ != DataType::DE_BOOL) { + RETURN_IF_NOT_OK(cast_->Compute(temp_output, output)); + } else { + *output = std::move(temp_output); + } + + return Status::OK(); +} + +Status MaskOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + outputs[0] = type_; + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h new file mode 100644 index 0000000000..e6ac8c3964 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/mask_op.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_MASK_OP_H_ +#define DATASET_KERNELS_DATA_MASK_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class MaskOp : public TensorOp { + public: + MaskOp(RelationalOp op, std::shared_ptr value, DataType type = DataType(DataType::DE_BOOL)) + : op_(op), value_(std::move(value)), type_(type), cast_(new TypeCastOp(type)) {} + + ~MaskOp() override = default; + + void Print(std::ostream &out) const override { out << "MaskOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMaskOp; } + + private: + RelationalOp op_; + std::shared_ptr value_; + DataType type_; + std::unique_ptr cast_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_MASK_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc new file mode 100644 index 0000000000..e2b7b74a96 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/data/one_hot_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status OneHotOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = OneHotEncoding(input, output, num_classes_); + return s; +} + +Status OneHotOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + std::vector inputs_copy; + inputs_copy.push_back(inputs[0].Squeeze()); + if (inputs_copy[0].Rank() == 0) outputs.emplace_back(std::vector{num_classes_}); + if (inputs_copy[0].Rank() == 1) outputs.emplace_back(std::vector{inputs_copy[0][0], num_classes_}); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h new file mode 100644 index 0000000000..06a4823573 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/one_hot_op.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_ONE_HOT_OP_H_ +#define DATASET_KERNELS_DATA_ONE_HOT_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class OneHotOp : public TensorOp { + public: + explicit OneHotOp(int num_classes) : num_classes_(num_classes) {} + + ~OneHotOp() override = default; + + void Print(std::ostream &out) const override { out << "OneHotOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kOneHotOp; } + + private: + int num_classes_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc new file mode 100644 index 0000000000..7b83137d88 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/data/pad_end_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status PadEndOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + Status s = PadEnd(input, output, output_shape_.AsVector(), pad_val_); + return s; +} + +Status PadEndOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + for (auto s : inputs) { + outputs.emplace_back(TensorShape(output_shape_.AsVector())); + } + CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "Input has a wrong shape"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h new file mode 100644 index 0000000000..c28f7250e0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/pad_end_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#ifndef DATASET_KERNELS_DATA_PAD_END_OP_H_
+#define DATASET_KERNELS_DATA_PAD_END_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class PadEndOp : public TensorOp {
+ public:
+  explicit PadEndOp(const TensorShape &pad_shape, const std::shared_ptr<Tensor> &pad_value)
+      : output_shape_(pad_shape), pad_val_(pad_value) {}
+
+  ~PadEndOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "PadEndOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kPadEndOp; }
+
+ private:
+  TensorShape output_shape_;
+  std::shared_ptr<Tensor> pad_val_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_DATA_PAD_END_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc
new file mode 100644
index 0000000000..66f48d5c2b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/slice_op.cc
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/data/slice_op.h"
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+Status SliceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now.");
+
+  // if `all` flag is true, output is just the input.
+  if (all_) {
+    *output = input;
+    return Status::OK();
+  }
+
+  // if slice object was provided, indices should be empty. Generate indices from the slice object.
+  if (slice_.valid() && indices_.empty()) {
+    dsize_t len = input->shape()[0];
+    std::vector<dsize_t> indices = slice_.Indices(len);
+    return input->Slice(output, indices);
+  }
+
+  // if indices are not empty, slices should be invalid, use indices_ to slice
+  if (!indices_.empty() && !slice_.valid()) {
+    return input->Slice(output, indices_);
+  }
+  RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid");
+}
+}  // namespace dataset
+}  // namespace mindspore
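The Slice object used above mirrors Python's slice(start, stop, step); materializing explicit indices makes negative steps easy to reason about. An equivalent standalone sketch for non-negative start/stop (the real class also remaps negative values via Tensor::HandleNeg):

```cpp
#include <cstdint>
#include <vector>

// Hedged sketch: SliceIndices(2, 8, 3, 10) -> {2, 5};
// SliceIndices(8, 2, -3, 10) -> {8, 5}, matching Python's [8:2:-3].
std::vector<int64_t> SliceIndices(int64_t start, int64_t stop, int64_t step, int64_t length) {
  std::vector<int64_t> out;
  if (step > 0) {
    for (int64_t i = start; i < stop && i < length; i += step) out.push_back(i);
  } else if (step < 0) {
    for (int64_t i = start; i > stop && i >= 0; i += step) out.push_back(i);
  }
  return out;
}
```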
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_
+#define DATASET_KERNELS_DATA_SLICE_OP_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class Slice {
+ public:
+  Slice() : start_(0), stop_(0), step_(0) {}
+  Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {}
+  Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {}
+  explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {}
+
+  ~Slice() = default;
+
+  std::vector<dsize_t> Indices(dsize_t length) {
+    std::vector<dsize_t> indices;
+    dsize_t index = std::min(Tensor::HandleNeg(start_, length), length);
+    dsize_t end_index = std::min(Tensor::HandleNeg(stop_, length), length);
+    if (step_ > 0) {
+      for (; index < end_index; index += step_) {
+        indices.push_back(index);
+      }
+    } else {
+      for (; index > end_index; index += step_) {
+        indices.push_back(index);
+      }
+    }
+    return indices;
+  }
+
+  bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); }
+
+  dsize_t start_;
+  dsize_t stop_;
+  dsize_t step_;
+};
+
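+// Example of the Python-style semantics above (illustrative only): with
+// length 10, Slice(1, 7, 2).Indices(10) yields {1, 3, 5}, and
+// Slice(-3, 10).Indices(10) yields {7, 8, 9}, since Tensor::HandleNeg maps a
+// negative position to length + position before clamping to [0, length].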
+class SliceOp : public TensorOp {
+ public:
+  explicit SliceOp(std::vector<dsize_t> indices) : indices_(std::move(indices)) {}
+  explicit SliceOp(Slice slice) : slice_(slice) {}
+  explicit SliceOp(bool all) : all_(all) {}
+
+  ~SliceOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "SliceOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kSliceOp; }
+
+ private:
+  // only one of the following will be valid
+  // given indices to slice the Tensor. Empty vector if invalid.
+  std::vector<dsize_t> indices_;
+  // Slice object. All of start, stop and step are 0 if invalid.
+  Slice slice_;
+  // Flag to read all indices in the dim.
+  bool all_ = false;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_DATA_SLICE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc
new file mode 100644
index 0000000000..c52162b1aa
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.cc
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/data/to_float16_op.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/data/data_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+Status ToFloat16Op::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  return ToFloat16(input, output);
+}
+Status ToFloat16Op::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  outputs[0] = DataType(DataType::DE_FLOAT16);
+  return Status::OK();
+}
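+// Note (general IEEE-754 half-precision behaviour, not specific to this
+// implementation): float16 holds magnitudes up to 65504 and about 11 bits of
+// mantissa, so integers above 2048 and out-of-range float32 values lose
+// precision or overflow; rescale inputs first where that matters.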
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h
new file mode 100644
index 0000000000..91f660ca9c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/to_float16_op.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDDATA_TOFLOAT16OP_H
+#define MINDDATA_TOFLOAT16OP_H
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class ToFloat16Op : public TensorOp {
+ public:
+  ToFloat16Op() = default;
+
+  ~ToFloat16Op() override = default;
+
+  // Overrides the base class compute function
+  // Calls the ToFloat16 function in ImageUtils, this function takes an input tensor
+  // and transforms its data to float16, the output memory is manipulated to contain the result
+  // @return Status - The error code returned
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  void Print(std::ostream &out) const override { out << "ToFloat16Op"; }
+
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+  std::string Name() const override { return kToFloat16Op; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDDATA_TOFLOAT16OP_H
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc
new file mode 100644
index 0000000000..5a58745293
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/data/type_cast_op.h"
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/data/data_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+TypeCastOp::TypeCastOp(const DataType &new_type) : type_(new_type) {}
+
+TypeCastOp::TypeCastOp(const std::string &data_type) { type_ = DataType(data_type); }
+
+Status TypeCastOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  return TypeCast(input, output, type_);
+}
+Status TypeCastOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  outputs[0] = type_;
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h
new file mode 100644
index 0000000000..b82bc32342
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_DATA_TYPE_CAST_OP_H_
+#define DATASET_KERNELS_DATA_TYPE_CAST_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class TypeCastOp : public TensorOp {
+ public:
+  // Constructor for TypeCastOp
+  // @param data_type datatype to cast to
+  explicit TypeCastOp(const DataType &data_type);
+
+  // Constructor for TypeCastOp
+  // @param data_type datatype to cast to
+  explicit TypeCastOp(const std::string &data_type);
+
+  ~TypeCastOp() override = default;
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  void Print(std::ostream &out) const override { out << "TypeCastOp"; }
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+  std::string Name() const override { return kTypeCastOp; }
+
+ private:
+  DataType type_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_DATA_TYPE_CAST_OP_H_
diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc
new file mode 100644
index 0000000000..618ed4d356
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.cc
@@ -0,0 +1,76 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <utility>
+#include "minddata/dataset/kernels/image/bounding_box_augment_op.h"
+#include "minddata/dataset/kernels/image/resize_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/core/cv_tensor.h"
+
+namespace mindspore {
+namespace dataset {
+const float BoundingBoxAugmentOp::kDefRatio = 0.3;
+
+BoundingBoxAugmentOp::BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio)
+    : ratio_(ratio), uniform_(0, 1), transform_(std::move(transform)) {
+  rnd_.seed(GetSeed());
+}
+
+Status BoundingBoxAugmentOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);  // check if bounding boxes are valid
+  uint32_t num_of_boxes = input[1]->shape()[0];
+  std::shared_ptr<Tensor> crop_out;
+  std::shared_ptr<Tensor> res_out;
+  std::shared_ptr<CVTensor> input_restore = CVTensor::AsCVTensor(input[0]);
+  for (uint32_t i = 0; i < num_of_boxes; i++) {
+    // using a uniform distribution to ensure the op happens with probability ratio_
+    if (uniform_(rnd_) < ratio_) {
+      float min_x = 0;
+      float min_y = 0;
+      float b_w = 0;
+      float b_h = 0;
+      // get the required items
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&min_x, {i, 0}));
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&min_y, {i, 1}));
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&b_w, {i, 2}));
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&b_h, {i, 3}));
+      RETURN_IF_NOT_OK(Crop(input_restore, &crop_out, static_cast<int>(min_x), static_cast<int>(min_y),
+                            static_cast<int>(b_w), static_cast<int>(b_h)));
+      // transform the cropped bbox region
+      RETURN_IF_NOT_OK(transform_->Compute(crop_out, &res_out));
+      // place the transformed region back in the restored input
+      std::shared_ptr<CVTensor> res_img = CVTensor::AsCVTensor(res_out);
+      // check if the transformed crop no longer matches the box size
+      if (res_img->mat().cols > b_w || res_img->mat().rows > b_h || res_img->mat().cols < b_w ||
+          res_img->mat().rows < b_h) {
+        // if so, resize to fit in the box
+        std::shared_ptr<TensorOp> resize_op =
+          std::make_shared<ResizeOp>(static_cast<int32_t>(b_h), static_cast<int32_t>(b_w));
+        RETURN_IF_NOT_OK(resize_op->Compute(std::static_pointer_cast<Tensor>(res_img), &res_out));
+        res_img = CVTensor::AsCVTensor(res_out);
+      }
+      res_img->mat().copyTo(input_restore->mat()(cv::Rect(min_x, min_y, res_img->mat().cols, res_img->mat().rows)));
+    }
+  }
+  (*output).push_back(std::move(std::static_pointer_cast<Tensor>(input_restore)));
+  (*output).push_back(input[1]);
+  return Status::OK();
+}
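+// As the reads above show, each row of the bounding-box tensor (input[1]) is
+// laid out as {min_x, min_y, width, height} in float32; e.g. a row of
+// {10, 20, 50, 40} selects the 50x40 region whose top-left corner is (10, 20).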
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h
new file mode 100644
index 0000000000..8e30c5738d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/bounding_box_augment_op.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
+#define DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+class BoundingBoxAugmentOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const float kDefRatio;
+
+  // Constructor for BoundingBoxAugmentOp
+  // @param std::shared_ptr<TensorOp> transform: C++ operation to apply on selected bounding boxes
+  // @param float ratio: ratio of bounding boxes to have the transform applied on
+  BoundingBoxAugmentOp(std::shared_ptr<TensorOp> transform, float ratio);
+
+  ~BoundingBoxAugmentOp() override = default;
+
+  // Provide stream operator for displaying it
+  friend std::ostream &operator<<(std::ostream &out, const BoundingBoxAugmentOp &so) {
+    so.Print(out);
+    return out;
+  }
+
+  void Print(std::ostream &out) const override { out << "BoundingBoxAugmentOp"; }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kBoundingBoxAugmentOp; }
+
+ private:
+  float ratio_;
+  std::mt19937 rnd_;
+  std::uniform_real_distribution<float> uniform_;
+  std::shared_ptr<TensorOp> transform_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_BOUNDING_BOX_AUGMENT_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc
new file mode 100644
index 0000000000..35079b05cd
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/center_crop_op.h"
+#include <string>
+#include "common/utils.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+const int32_t CenterCropOp::kDefWidth = 0;
+
+Status CenterCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::string err_msg;
+  dsize_t rank = input->shape().Rank();
+  err_msg += (rank < 2 || rank > 3) ? "Rank received: " + std::to_string(rank) + " Expected: 2 or 3 \t" : "";
+  err_msg += (crop_het_ <= 0 || crop_wid_ <= 0) ? "crop size needs to be positive integers\t" : "";
+
+  if (err_msg.length() != 0) RETURN_STATUS_UNEXPECTED(common::SafeCStr(err_msg));
+
+  int32_t top = crop_het_ - input->shape()[0];  // number of pixels to pad (top and bottom)
+  int32_t left = crop_wid_ - input->shape()[1];
+  std::shared_ptr<Tensor> pad_image;
+  if (top > 0 && left > 0) {  // padding only
+    return Pad(input, output, top / 2 + top % 2, top / 2, left / 2 + left % 2, left / 2, BorderType::kConstant);
+  } else if (top > 0) {
+    RETURN_IF_NOT_OK(Pad(input, &pad_image, top / 2 + top % 2, top / 2, 0, 0, BorderType::kConstant));
+    return Crop(pad_image, output, (static_cast<int32_t>(pad_image->shape()[1]) - crop_wid_) / 2,
+                (static_cast<int32_t>(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_);
+  } else if (left > 0) {
+    RETURN_IF_NOT_OK(Pad(input, &pad_image, 0, 0, left / 2 + left % 2, left / 2, BorderType::kConstant));
+    return Crop(pad_image, output, (static_cast<int32_t>(pad_image->shape()[1]) - crop_wid_) / 2,
+                (static_cast<int32_t>(pad_image->shape()[0]) - crop_het_) / 2, crop_wid_, crop_het_);
+  }
+  return Crop(input, output, (input->shape()[1] - crop_wid_) / 2, (input->shape()[0] - crop_het_) / 2, crop_wid_,
+              crop_het_);
+}
+
+void CenterCropOp::Print(std::ostream &out) const {
+  out << "CenterCropOp: "
+      << "cropWidth: " << crop_wid_ << " cropHeight: " << crop_het_ << "\n";
+}
+Status CenterCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
+  outputs.clear();
+  TensorShape out = TensorShape{crop_het_, crop_wid_};
+  if (inputs[0].Rank() == 2) outputs.emplace_back(out);
+  if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2]));
+  if (!outputs.empty()) return Status::OK();
+  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h
new file mode 100644
index 0000000000..1f8cbcf230
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
+#define DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class CenterCropOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const int32_t kDefWidth;
+
+  explicit CenterCropOp(int32_t het, int32_t wid = kDefWidth) : crop_het_(het), crop_wid_(wid == 0 ? het : wid) {}
+
+  ~CenterCropOp() override = default;
+
+  void Print(std::ostream &out) const override;
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kCenterCropOp; }
+
+ private:
+  int32_t crop_het_;
+  int32_t crop_wid_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_CENTER_CROP_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc
new file mode 100644
index 0000000000..578138d427
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include "minddata/dataset/kernels/image/cut_out_op.h"
+
+#include <random>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+const bool CutOutOp::kDefRandomColor = false;
+const uint8_t CutOutOp::kDefFillR = 0;
+const uint8_t CutOutOp::kDefFillG = 0;
+const uint8_t CutOutOp::kDefFillB = 0;
+
+// constructor
+CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r,
+                   uint8_t fill_g, uint8_t fill_b)
+    : rnd_(GetSeed()),
+      box_height_(box_height),
+      box_width_(box_width),
+      num_patches_(num_patches),
+      random_color_(random_color),
+      fill_r_(fill_r),
+      fill_g_(fill_g),
+      fill_b_(fill_b) {}
+
+// main function call for cut out
+Status CutOutOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::shared_ptr<CVTensor> inputCV = CVTensor::AsCVTensor(input);
+  // cut out will clip the erasing area if the box is near the edge of the image, and the boxes are black
+  RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_,
+                         fill_g_, fill_b_));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h
new file mode 100644
index 0000000000..263cbdb27c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/cut_out_op.h
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#ifndef DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
+#define DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class CutOutOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const bool kDefRandomColor;
+  static const uint8_t kDefFillR;
+  static const uint8_t kDefFillG;
+  static const uint8_t kDefFillB;
+
+  // Constructor for CutOutOp
+  // @param box_height box height
+  // @param box_width box width
+  // @param num_patches how many patches to erase from the image
+  // @param random_color boolean value to indicate filling the patch with a random color
+  // @param fill_r R value for the color to fill the patch with
+  // @param fill_g G value for the color to fill the patch with
+  // @param fill_b B value for the color to fill the patch with
+  // @note maybe using unsigned long int isn't the best here according to our coding rules
+  CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color = kDefRandomColor,
+           uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB);
+
+  ~CutOutOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << "CutOut:: box_height: " << box_height_ << " box_width: " << box_width_ << " num_patches: " << num_patches_;
+  }
+
+  // Overrides the base class compute function
+  // Calls the erase function in ImageUtils, this function takes an input tensor
+  // and overwrites some of its data using openCV, the output memory is manipulated to contain the result
+  // @return Status - The error code returned
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kCutOutOp; }
+
+ private:
+  std::mt19937 rnd_;
+  int32_t box_height_;
+  int32_t box_width_;
+  int32_t num_patches_;
+  bool random_color_;
+  uint8_t fill_r_;
+  uint8_t fill_g_;
+  uint8_t fill_b_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_CUT_OUT_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc
new file mode 100644
index 0000000000..5bc5377de9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/decode_op.h"
+
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+const bool DecodeOp::kDefRgbFormat = true;
+
+DecodeOp::DecodeOp(bool is_rgb_format) : is_rgb_format_(is_rgb_format) {
+  if (is_rgb_format_) {  // RGB colour mode
+    MS_LOG(DEBUG) << "Decode colour mode is RGB.";
+  } else {
+    MS_LOG(DEBUG) << "Decode colour mode is BGR.";
+  }
+}
+
+Status DecodeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  if (is_rgb_format_) {  // RGB colour mode
+    return Decode(input, output);
+  } else {  // BGR colour mode
+    RETURN_STATUS_UNEXPECTED("Decode BGR is deprecated");
+  }
+}
+Status DecodeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
+  outputs.clear();
+  TensorShape out({-1, -1, 3});  // we don't know the output image size, only that it has 3 channels
+  if (inputs[0].Rank() == 1) outputs.emplace_back(out);
+  if (!outputs.empty()) return Status::OK();
+  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
+}
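+// Shape convention note: -1 marks a dimension that is unknown until runtime,
+// so a to-be-decoded image reports {-1, -1, 3} here and the concrete
+// {height, width, 3} only once Compute has run on an actual buffer.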
+
+Status DecodeOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  outputs[0] = DataType(DataType::DE_UINT8);
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h
new file mode 100644
index 0000000000..29bf1d0146
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/decode_op.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_DECODE_OP_H_
+#define DATASET_KERNELS_IMAGE_DECODE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class DecodeOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const bool kDefRgbFormat;
+
+  explicit DecodeOp(bool is_rgb_format = true);
+
+  ~DecodeOp() = default;
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  void Print(std::ostream &out) const override { out << "DecodeOp"; }
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+  std::string Name() const override { return kDecodeOp; }
+
+ private:
+  bool is_rgb_format_ = true;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_DECODE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc
new file mode 100644
index 0000000000..5013958562
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.cc
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
+
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status HwcToChwOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  // input.shape == HWC
+  // output.shape == CHW
+  return HwcToChw(input, output);
+}
+Status HwcToChwOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
+  outputs.clear();
+  TensorShape in = inputs[0];
+  TensorShape out = TensorShape{in[2], in[0], in[1]};
+  if (inputs[0].Rank() == 3) outputs.emplace_back(out);
+  if (!outputs.empty()) return Status::OK();
+  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
+}
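+// Layout example: an HWC image of shape {224, 224, 3} becomes CHW
+// {3, 224, 224}; the channel axis moves to the front while H and W keep
+// their relative order, which is the input layout most networks expect.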
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h
new file mode 100644
index 0000000000..0d5f70f895
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/hwc_to_chw_op.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
+#define DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class HwcToChwOp : public TensorOp {
+ public:
+  void Print(std::ostream &out) const override { out << "HwcToChw"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kHwcToChwOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_CHANNEL_SWAP_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
new file mode 100644
index 0000000000..ddbce3e23a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
@@ -0,0 +1,836 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include <algorithm>
+#include <random>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+#include <opencv2/imgproc/types_c.h>
+#include "common/utils.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/util/random.h"
+
+#define MAX_INT_PRECISION 16777216  // 2^24: the largest integer a float can represent exactly
+namespace mindspore {
+namespace dataset {
+int GetCVInterpolationMode(InterpolationMode mode) {
+  switch (mode) {
+    case InterpolationMode::kLinear:
+      return static_cast<int>(cv::InterpolationFlags::INTER_LINEAR);
+    case InterpolationMode::kCubic:
+      return static_cast<int>(cv::InterpolationFlags::INTER_CUBIC);
+    case InterpolationMode::kArea:
+      return static_cast<int>(cv::InterpolationFlags::INTER_AREA);
+    case InterpolationMode::kNearestNeighbour:
+      return static_cast<int>(cv::InterpolationFlags::INTER_NEAREST);
+    default:
+      return static_cast<int>(cv::InterpolationFlags::INTER_LINEAR);
+  }
+}
+
+int GetCVBorderType(BorderType type) {
+  switch (type) {
+    case BorderType::kConstant:
+      return static_cast<int>(cv::BorderTypes::BORDER_CONSTANT);
+    case BorderType::kEdge:
+      return static_cast<int>(cv::BorderTypes::BORDER_REPLICATE);
+    case BorderType::kReflect:
+      return static_cast<int>(cv::BorderTypes::BORDER_REFLECT101);
+    case BorderType::kSymmetric:
+      return static_cast<int>(cv::BorderTypes::BORDER_REFLECT);
+    default:
+      return static_cast<int>(cv::BorderTypes::BORDER_CONSTANT);
+  }
+}
+
+Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
+
+  std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+  RETURN_UNEXPECTED_IF_NULL(output_cv);
+  RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes()));
+
+  if (input_cv->mat().data) {
+    try {
+      cv::flip(input_cv->mat(), output_cv->mat(), flip_code);
+      *output = std::static_pointer_cast<Tensor>(output_cv);
+      return Status::OK();
+    } catch (const cv::Exception &e) {
+      RETURN_STATUS_UNEXPECTED("Error in flip op.");
+    }
+  } else {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor, the input data is null");
+  }
+}
+
+Status HorizontalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
+  return Flip(std::move(input), output, 1);
+}
+
+Status VerticalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
+  return Flip(std::move(input), output, 0);
+}
+
+Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t output_height,
+              int32_t output_width, double fx, double fy, InterpolationMode mode) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+  if (!input_cv->mat().data) {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+  }
+  if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
+    RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of <H,W,C> or <H,W>");
+  }
+  cv::Mat in_image = input_cv->mat();
+  // reject a resize target that is far too large or is zero
+  if (output_height == 0 || output_height > in_image.rows * 1000 || output_width == 0 ||
+      output_width > in_image.cols * 1000) {
+    std::string err_msg =
+      "The resizing width or height 1) is too big, it's up to "
+      "1000 times the original image; 2) cannot be 0.";
+    return Status(StatusCode::kShapeMisMatch, err_msg);
+  }
+  try {
+    TensorShape shape{output_height, output_width};
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels);
+    std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(shape, input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    auto cv_mode = GetCVInterpolationMode(mode);
+    cv::resize(in_image, output_cv->mat(), cv::Size(output_width, output_height), fx, fy, cv_mode);
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in image resize.");
+  }
+}
+
+bool IsNonEmptyJPEG(const std::shared_ptr<Tensor> &input) {
+  const unsigned char *kJpegMagic = (unsigned char *)"\xFF\xD8\xFF";
+  constexpr size_t kJpegMagicLen = 3;
+  return input->SizeInBytes() > kJpegMagicLen && memcmp(input->GetBuffer(), kJpegMagic, kJpegMagicLen) == 0;
+}
+
+Status Decode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  if (IsNonEmptyJPEG(input)) {
+    return JpegCropAndDecode(input, output);
+  } else {
+    return DecodeCv(input, output);
+  }
+}
+
+Status DecodeCv(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+  if (!input_cv->mat().data) {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+  }
+  try {
+    cv::Mat img_mat = cv::imdecode(input_cv->mat(), cv::IMREAD_COLOR | cv::IMREAD_IGNORE_ORIENTATION);
+    if (img_mat.data == nullptr) {
+      std::string err = "Error in decoding\t";
+      RETURN_STATUS_UNEXPECTED(err);
+    }
+    cv::cvtColor(img_mat, img_mat, static_cast<int>(cv::COLOR_BGR2RGB));
+    std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(img_mat);
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in image Decode");
+  }
+}
+
+static void JpegInitSource(j_decompress_ptr cinfo) {}
+
+static boolean JpegFillInputBuffer(j_decompress_ptr cinfo) {
+  if (cinfo->src->bytes_in_buffer == 0) {
+    ERREXIT(cinfo, JERR_INPUT_EMPTY);
+    return FALSE;
+  }
+  return TRUE;
+}
+
+static void JpegTermSource(j_decompress_ptr cinfo) {}
+
+static void JpegSkipInputData(j_decompress_ptr cinfo, int64_t jump) {
+  if (jump < 0) {
+    return;
+  }
+  if (static_cast<size_t>(jump) > cinfo->src->bytes_in_buffer) {
+    cinfo->src->bytes_in_buffer = 0;
+    return;
+  } else {
+    cinfo->src->bytes_in_buffer -= jump;
+    cinfo->src->next_input_byte += jump;
+  }
+}
+
+void JpegSetSource(j_decompress_ptr cinfo, const void *data, int64_t datasize) {
+  cinfo->src = static_cast<struct jpeg_source_mgr *>(
+    (*cinfo->mem->alloc_small)(reinterpret_cast<j_common_ptr>(cinfo), JPOOL_PERMANENT, sizeof(struct jpeg_source_mgr)));
+  cinfo->src->init_source = JpegInitSource;
+  cinfo->src->fill_input_buffer = JpegFillInputBuffer;
+#if defined(_WIN32) || defined(_WIN64)
+  cinfo->src->skip_input_data = reinterpret_cast<void (*)(j_decompress_ptr, long)>(JpegSkipInputData);
+#else
+  cinfo->src->skip_input_data = JpegSkipInputData;
+#endif
+  cinfo->src->resync_to_restart = jpeg_resync_to_restart;
+  cinfo->src->term_source = JpegTermSource;
+  cinfo->src->bytes_in_buffer = datasize;
+  cinfo->src->next_input_byte = static_cast<const JOCTET *>(data);
+}
+
+static Status JpegReadScanlines(jpeg_decompress_struct *const cinfo, int max_scanlines_to_read, JSAMPLE *buffer,
+                                int buffer_size, int crop_w, int crop_w_aligned, int offset, int stride) {
+  // scanlines will be read to this buffer first, must have the number
+  // of components equal to the number of components in the image
+  int64_t scanline_size = crop_w_aligned * cinfo->output_components;
+  std::vector<JSAMPLE> scanline(scanline_size);
+  JSAMPLE *scanline_ptr = &scanline[0];
+  while (cinfo->output_scanline < static_cast<unsigned int>(max_scanlines_to_read)) {
+    int num_lines_read = jpeg_read_scanlines(cinfo, &scanline_ptr, 1);
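+    // CMYK handling below inverts each ink value and scales it by the key
+    // channel; with an Adobe APP14 marker the components are stored
+    // pre-inverted, so e.g. c=128, k=255 gives r = (255 * 128) / 255 = 128,
+    // while the plain CMYK path gives r = (255 - 128) * (255 - 255) / 255 = 0.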
+    if (cinfo->out_color_space == JCS_CMYK && num_lines_read > 0) {
+      for (int i = 0; i < crop_w; ++i) {
+        int cmyk_pixel = 4 * i + offset;
+        const int c = scanline_ptr[cmyk_pixel];
+        const int m = scanline_ptr[cmyk_pixel + 1];
+        const int y = scanline_ptr[cmyk_pixel + 2];
+        const int k = scanline_ptr[cmyk_pixel + 3];
+        int r, g, b;
+        if (cinfo->saw_Adobe_marker) {
+          r = (k * c) / 255;
+          g = (k * m) / 255;
+          b = (k * y) / 255;
+        } else {
+          r = (255 - c) * (255 - k) / 255;
+          g = (255 - m) * (255 - k) / 255;
+          b = (255 - y) * (255 - k) / 255;
+        }
+        buffer[3 * i + 0] = r;
+        buffer[3 * i + 1] = g;
+        buffer[3 * i + 2] = b;
+      }
+    } else if (num_lines_read > 0) {
+      int copy_status = memcpy_s(buffer, buffer_size, scanline_ptr + offset, stride);
+      if (copy_status != 0) {
+        jpeg_destroy_decompress(cinfo);
+        RETURN_STATUS_UNEXPECTED("memcpy failed");
+      }
+    } else {
+      jpeg_destroy_decompress(cinfo);
+      std::string err_msg = "failed to read scanline";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+    buffer += stride;
+    buffer_size = buffer_size - stride;
+  }
+  return Status::OK();
+}
+
+static Status JpegSetColorSpace(jpeg_decompress_struct *cinfo) {
+  switch (cinfo->num_components) {
+    case 1:
+      // we want to output 3 components if it's grayscale
+      cinfo->out_color_space = JCS_RGB;
+      return Status::OK();
+    case 3:
+      cinfo->out_color_space = JCS_RGB;
+      return Status::OK();
+    case 4:
+      // Need to manually convert to RGB
+      cinfo->out_color_space = JCS_CMYK;
+      return Status::OK();
+    default:
+      jpeg_destroy_decompress(cinfo);
+      std::string err_msg = "wrong number of components";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+  }
+}
+
+void JpegErrorExitCustom(j_common_ptr cinfo) {
+  char jpeg_last_error_msg[JMSG_LENGTH_MAX];
+  (*(cinfo->err->format_message))(cinfo, jpeg_last_error_msg);
+  throw std::runtime_error(jpeg_last_error_msg);
+}
+
+Status JpegCropAndDecode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int crop_x, int crop_y,
+                         int crop_w, int crop_h) {
+  struct jpeg_decompress_struct cinfo;
+  auto DestroyDecompressAndReturnError = [&cinfo](const std::string &err) {
+    jpeg_destroy_decompress(&cinfo);
+    RETURN_STATUS_UNEXPECTED(err);
+  };
+  struct JpegErrorManagerCustom jerr;
+  cinfo.err = jpeg_std_error(&jerr.pub);
+  jerr.pub.error_exit = JpegErrorExitCustom;
+  try {
+    jpeg_create_decompress(&cinfo);
+    JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes());
+    (void)jpeg_read_header(&cinfo, TRUE);
+    RETURN_IF_NOT_OK(JpegSetColorSpace(&cinfo));
+    jpeg_calc_output_dimensions(&cinfo);
+  } catch (std::runtime_error &e) {
+    return DestroyDecompressAndReturnError(e.what());
+  }
+  if (crop_x == 0 && crop_y == 0 && crop_w == 0 && crop_h == 0) {
+    crop_w = cinfo.output_width;
+    crop_h = cinfo.output_height;
+  } else if (crop_w == 0 || static_cast<unsigned int>(crop_w + crop_x) > cinfo.output_width || crop_h == 0 ||
+             static_cast<unsigned int>(crop_h + crop_y) > cinfo.output_height) {
+    return DestroyDecompressAndReturnError("Crop window is not valid");
+  }
+  const int mcu_size = cinfo.min_DCT_scaled_size;
+  unsigned int crop_x_aligned = (crop_x / mcu_size) * mcu_size;
+  unsigned int crop_w_aligned = crop_w + crop_x - crop_x_aligned;
+  try {
+    (void)jpeg_start_decompress(&cinfo);
+    jpeg_crop_scanline(&cinfo, &crop_x_aligned, &crop_w_aligned);
+  } catch (std::runtime_error &e) {
+    return DestroyDecompressAndReturnError(e.what());
+  }
+  JDIMENSION skipped_scanlines = jpeg_skip_scanlines(&cinfo, crop_y);
+  // three output components: always convert to RGB and output
+  constexpr int kOutNumComponents = 3;
+  TensorShape ts = TensorShape({crop_h, crop_w, kOutNumComponents});
+  auto output_tensor = std::make_shared<Tensor>(ts, DataType(DataType::DE_UINT8));
+  const int buffer_size = output_tensor->SizeInBytes();
+  JSAMPLE *buffer = reinterpret_cast<JSAMPLE *>(&(*output_tensor->begin<uint8_t>()));
+  const int max_scanlines_to_read = skipped_scanlines + crop_h;
+  // stride refers to the output tensor, which has 3 components at most
+  const int stride = crop_w * kOutNumComponents;
+  // offset is calculated for scanlines read from the image, therefore
+  // has the same number of components as the image
+  const int offset = (crop_x - crop_x_aligned) * cinfo.output_components;
+  RETURN_IF_NOT_OK(
+    JpegReadScanlines(&cinfo, max_scanlines_to_read, buffer, buffer_size, crop_w, crop_w_aligned, offset, stride));
+  *output = output_tensor;
+  jpeg_destroy_decompress(&cinfo);
+  return Status::OK();
+}
+
+Status Rescale(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rescale, float shift) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+  if (!input_cv->mat().data) {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+  }
+  cv::Mat input_image = input_cv->mat();
+  std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), DataType(DataType::DE_FLOAT32));
+  RETURN_UNEXPECTED_IF_NULL(output_cv);
+  try {
+    input_image.convertTo(output_cv->mat(), CV_32F, rescale, shift);
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in image rescale");
+  }
+  return Status::OK();
+}
+
+Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y, int w, int h) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+  if (!input_cv->mat().data) {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+  }
+  if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
+    RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
+  }
+  try {
+    TensorShape shape{h, w};
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels);
+    std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(shape, input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    cv::Rect roi(x, y, w, h);
+    (input_cv->mat())(roi).copyTo(output_cv->mat());
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in crop.");
+  }
+}
+
+Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    if (input_cv->Rank() == 2) {
+      // If the input tensor is 2D, we assume we have hw dimensions
+      *output = input;
+      return Status::OK();
+    }
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->shape().Size() < 2 || input_cv->shape().Size() > 3 ||
+        (input_cv->shape().Size() == 3 && num_channels != 3 && num_channels != 1)) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3 nor 1");
+    }
+    cv::Mat output_img;
+
+    int height = input_cv->shape()[0];
+    int width = input_cv->shape()[1];
+
+    auto output_cv = std::make_unique<CVTensor>(TensorShape{num_channels, height, width}, input_cv->type());
+    for (int i = 0; i < num_channels; ++i) {
+      cv::Mat mat;
+      RETURN_IF_NOT_OK(output_cv->Mat({i}, &mat));
+      cv::extractChannel(input_cv->mat(), mat, i);
+    }
+    *output = std::move(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in ChannelSwap.");
+  }
+}
+
+Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->shape().Size() != 3 || num_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
+    }
+    auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_BGR2RGB));
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in ChangeMode.");
+  }
+}
+
+Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y,
+                     int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
+      RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
+    }
+    // image too large or too small
+    if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 ||
+        target_width == 0 || target_width > crop_width * 1000) {
+      std::string err_msg =
+        "The resizing width or height 1) is too big, it's up to "
+        "1000 times the original image; 2) cannot be 0.";
+      RETURN_STATUS_UNEXPECTED(err_msg);
+    }
+    cv::Rect roi(x, y, crop_width, crop_height);
+    auto cv_mode = GetCVInterpolationMode(mode);
+    cv::Mat cv_in = input_cv->mat();
+    TensorShape shape{target_height, target_width};
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() == 3) shape = shape.AppendDim(num_channels);
+    std::shared_ptr<CVTensor> cvt_out = std::make_shared<CVTensor>(shape, input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(cvt_out);
+    cv::resize(cv_in(roi), cvt_out->mat(), cv::Size(target_width, target_height), 0, 0, cv_mode);
+    *output = std::static_pointer_cast<Tensor>(cvt_out);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in CropAndResize.");
+  }
+}
+
+Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float fx, float fy, float degree,
+              InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    cv::Mat input_img = input_cv->mat();
+    if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) {
+      RETURN_STATUS_UNEXPECTED("Image too large, center not precise");
+    }
+    // default to center of image
+    if (fx == -1 && fy == -1) {
+      fx = (input_img.cols - 1) / 2.0;
+      fy = (input_img.rows - 1) / 2.0;
+    }
+    cv::Mat output_img;
+    cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r);
+    // maybe don't use uint32 for image dimension here
+    cv::Point2f pc(fx, fy);
+    cv::Mat rot = cv::getRotationMatrix2D(pc, degree, 1.0);
+    std::shared_ptr<CVTensor> output_cv;
+    if (!expand) {
+      // this case means that the shape doesn't change, size stays the same
+      // We may not need this memcpy if it is in place.
+      output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+      RETURN_UNEXPECTED_IF_NULL(output_cv);
+      // using inter_nearest to comply with the python default
+      cv::warpAffine(input_img, output_cv->mat(), rot, input_img.size(), GetCVInterpolationMode(interpolation),
+                     cv::BORDER_CONSTANT, fill_color);
+    } else {
+      // we resize here since the shape changes
+      // create a new bounding box with the rotate
+      cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f();
+      rot.at<double>(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0;
+      rot.at<double>(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0;
+      // use memcpy and don't compute the new shape since openCV has a rounding problem
+      cv::warpAffine(input_img, output_img, rot, bbox.size(), GetCVInterpolationMode(interpolation),
+                     cv::BORDER_CONSTANT, fill_color);
+      output_cv = std::make_shared<CVTensor>(output_img);
+      RETURN_UNEXPECTED_IF_NULL(output_cv);
+    }
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in image rotation");
+  }
+  return Status::OK();
+}
+
+Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                 const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std) {
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+  if (!(input_cv->mat().data && input_cv->Rank() == 3)) {
+    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+  }
+  cv::Mat in_image = input_cv->mat();
+  std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), DataType(DataType::DE_FLOAT32));
+  RETURN_UNEXPECTED_IF_NULL(output_cv);
+  mean->Squeeze();
+  if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
+    std::string err_msg = "Mean tensor should be of size 3 and type float.";
+    return Status(StatusCode::kShapeMisMatch, err_msg);
+  }
+  std->Squeeze();
+  if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
+    std::string err_msg = "Std tensor should be of size 3 and type float.";
+    return Status(StatusCode::kShapeMisMatch, err_msg);
+  }
+  try {
+    // NOTE: We are assuming the input image is in RGB and the mean
+    // and std are in RGB
+    cv::Mat rgb[3];
+    cv::split(in_image, rgb);
+    for (uint8_t i = 0; i < 3; i++) {
+      float mean_c, std_c;
+      RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
+      RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
+      rgb[i].convertTo(rgb[i], CV_32F, 1.0 / std_c, (-mean_c / std_c));
+    }
+    cv::merge(rgb, 3, output_cv->mat());
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+    return Status::OK();
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Unexpected error in Normalize");
+  }
+}
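+// Per-channel formula used above: out = (in - mean_c) / std_c, expressed
+// through convertTo's affine form in * (1 / std_c) + (-mean_c / std_c);
+// e.g. a pixel value 128 with mean 127.5 and std 127.5 maps to about 0.0039.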
+
+Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    cv::Mat input_img = input_cv->mat();
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() != 3 || num_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
+    }
+    auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    output_cv->mat() = input_img * alpha;
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in adjust brightness");
+  }
+  return Status::OK();
+}
+
+Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    cv::Mat input_img = input_cv->mat();
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() != 3 || num_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
+    }
+    cv::Mat gray, output_img;
+    cv::cvtColor(input_img, gray, CV_RGB2GRAY);
+    int mean_img = static_cast<int>(cv::mean(gray).val[0] + 0.5);
+    std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    output_img = cv::Mat::zeros(input_img.rows, input_img.cols, CV_8UC1);
+    output_img = output_img + mean_img;
+    cv::cvtColor(output_img, output_img, CV_GRAY2RGB);
+    output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in adjust contrast");
+  }
+  return Status::OK();
+}
+
+Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    cv::Mat input_img = input_cv->mat();
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() != 3 || num_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
+    }
+    auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    cv::Mat output_img = output_cv->mat();
+    cv::Mat gray;
+    cv::cvtColor(input_img, gray, CV_RGB2GRAY);
+    cv::cvtColor(gray, output_img, CV_GRAY2RGB);
+    output_cv->mat() = output_img * (1.0 - alpha) + input_img * alpha;
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in adjust saturation");
+  }
+  return Status::OK();
+}
+
+Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue) {
+  if (hue > 0.5 || hue < -0.5) {
+    MS_LOG(ERROR) << "Hue factor is not in [-0.5, 0.5].";
+    RETURN_STATUS_UNEXPECTED("hue_factor is not in [-0.5, 0.5].");
+  }
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    cv::Mat input_img = input_cv->mat();
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    int num_channels = input_cv->shape()[2];
+    if (input_cv->Rank() != 3 || num_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
+    }
+    auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
+    RETURN_UNEXPECTED_IF_NULL(output_cv);
+    cv::Mat output_img;
+    cv::cvtColor(input_img, output_img, CV_RGB2HSV_FULL);
+    for (int y = 0; y < output_img.cols; y++) {
+      for (int x = 0; x < output_img.rows; x++) {
+        uint8_t cur1 = output_img.at<cv::Vec3b>(cv::Point(y, x))[0];
+        uint8_t h_hue = 0;
+        h_hue = static_cast<uint8_t>(hue * 255);
+        cur1 += h_hue;
+        output_img.at<cv::Vec3b>(cv::Point(y, x))[0] = cur1;
+      }
+    }
+    cv::cvtColor(output_img, output_cv->mat(), CV_HSV2RGB_FULL);
+    *output = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    RETURN_STATUS_UNEXPECTED("Error in adjust hue");
+  }
+  return Status::OK();
+}
+
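+// Note on the two erase modes implemented next: RandomErasing draws the
+// patch origin from [0, image - box] ("bounded"), so every patch fits fully
+// inside the image, while CutOut draws from a wider range and clips, so
+// patches may straddle the border; both paths share this Erase helper.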
int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, + uint8_t fill_g, uint8_t fill_b) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + int num_channels = input_cv->shape()[2]; + if (input_cv->mat().data == nullptr || input_cv->Rank() != 3 || num_channels != 3) { + RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); + } + cv::Mat input_img = input_cv->mat(); + int32_t image_h = input_cv->shape()[0]; + int32_t image_w = input_cv->shape()[1]; + // check if erase size is bigger than image itself + if (box_height > image_h || box_width > image_w) { + RETURN_STATUS_UNEXPECTED("input box size too large for image erase"); + } + + // for random color + std::normal_distribution normal_distribution(0, 1); + std::uniform_int_distribution height_distribution_bound(0, image_h - box_height); + std::uniform_int_distribution width_distribution_bound(0, image_w - box_width); + std::uniform_int_distribution height_distribution_unbound(0, image_h + box_height); + std::uniform_int_distribution width_distribution_unbound(0, image_w + box_width); + // core logic + // update values based on random erasing or cutout + + for (int32_t i = 0; i < num_patches; i++) { + // rows in cv mat refers to the height of the cropped box + // we determine h_start and w_start using two different distributions as erasing is used by two different + // image augmentations. The bounds are also different in each case. + int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); + int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); + + int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; + int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; + // check for starting range >= 0, here the start range is checked after for cut out, for random erasing + // w_start and h_start will never be less than 0. + h_start = (h_start < 0) ? 0 : h_start; + w_start = (w_start < 0) ? 
0 : w_start; + for (int y = w_start; y < max_width; y++) { + for (int x = h_start; x < max_height; x++) { + if (random_color) { + // fill each box with a random value + input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); + } else { + input_img.at(cv::Point(y, x))[0] = fill_r; + input_img.at(cv::Point(y, x))[1] = fill_g; + input_img.at(cv::Point(y, x))[2] = fill_b; + } + } + } + } + *output = std::static_pointer_cast(input); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in erasing"); + } +} + +Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, + const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + try { + // input image + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + // get the border type in openCV + auto b_type = GetCVBorderType(border_types); + // output image + cv::Mat out_image; + if (b_type == cv::BORDER_CONSTANT) { + cv::Scalar fill_color = cv::Scalar(fill_b, fill_g, fill_r); + cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type, fill_color); + } else { + cv::copyMakeBorder(input_cv->mat(), out_image, pad_top, pad_bottom, pad_left, pad_right, b_type); + } + std::shared_ptr output_cv = std::make_shared(out_image); + RETURN_UNEXPECTED_IF_NULL(output_cv); + // pad the dimension if shape information is only 2 dimensional, this is grayscale + int num_channels = input_cv->shape()[2]; + if (input_cv->Rank() == 3 && num_channels == 1 && output_cv->Rank() == 2) output_cv->ExpandDim(2); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Unexpected error in pad"); + } +} +// -------- BBOX OPERATIONS -------- // +Status UpdateBBoxesForCrop(std::shared_ptr *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax, + int CB_Ymax) { + // PASS LIST, COUNT OF BOUNDING BOXES + // Also PAss X/Y Min/Max of image cropped region - normally obtained from 'GetCropBox' functions + float bb_Xmin = 0.0, bb_Ymin = 0.0, bb_Xmax = 0.0, bb_Ymax = 0.0; + std::vector correct_ind; + std::vector copyVals; + dsize_t bboxDim = (*bboxList)->shape()[1]; + bool retFlag = false; // true unless overlap found + for (int i = 0; i < *bboxCount; i++) { + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmin, {i, 0})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymin, {i, 1})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Xmax, {i, 2})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&bb_Ymax, {i, 3})); + bb_Xmax = bb_Xmin + bb_Xmax; + bb_Ymax = bb_Ymin + bb_Ymax; + // check for image / BB overlap + if (((bb_Xmin > CB_Xmax) || (bb_Ymin > CB_Ymax)) || ((bb_Xmax < CB_Xmin) || (bb_Ymax < CB_Ymin))) { + continue; // no overlap found + } + // Update this bbox and select it to move to the final output tensor + correct_ind.push_back(i); + // adjust BBox corners by bringing into new CropBox if beyond + // Also reseting/adjusting for boxes to lie within CropBox instead of Image - subtract CropBox Xmin/YMin + + bb_Xmin = bb_Xmin - std::min(static_cast(0.0), (bb_Xmin - CB_Xmin)) - CB_Xmin; + bb_Xmax = bb_Xmax - std::max(static_cast(0.0), (bb_Xmax - CB_Xmax)) - CB_Xmin; + bb_Ymin = bb_Ymin - std::min(static_cast(0.0), (bb_Ymin - 
CB_Ymin)) - CB_Ymin; + bb_Ymax = bb_Ymax - std::max(static_cast(0.0), (bb_Ymax - CB_Ymax)) - CB_Ymin; + + // bound check for float values + bb_Xmin = std::max(bb_Xmin, static_cast(0)); + bb_Ymin = std::max(bb_Ymin, static_cast(0)); + bb_Xmax = std::min(bb_Xmax, static_cast(CB_Xmax - CB_Xmin)); // find max value relative to new image + bb_Ymax = std::min(bb_Ymax, static_cast(CB_Ymax - CB_Ymin)); + + // reset min values and calculate width/height from Box corners + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, bb_Xmin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, bb_Ymin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 2}, bb_Xmax - bb_Xmin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 3}, bb_Ymax - bb_Ymin)); + } + // create new tensor and copy over bboxes still valid to the image + // bboxes outside of new cropped region are ignored - empty tensor returned in case of none + *bboxCount = correct_ind.size(); + float temp = 0.0; + for (auto slice : correct_ind) { // for every index in the loop + for (int ix = 0; ix < bboxDim; ix++) { + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&temp, {slice, ix})); + copyVals.push_back(temp); + } + } + std::shared_ptr retV; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&retV, copyVals, TensorShape({static_cast(*bboxCount), bboxDim}))); + (*bboxList) = retV; // reset pointer + return Status::OK(); +} + +Status PadBBoxes(const std::shared_ptr *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left) { + for (int i = 0; i < bboxCount; i++) { + float xMin = 0.0, yMin = 0.0; + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&xMin, {i, 0})); + RETURN_IF_NOT_OK((*bboxList)->GetItemAt(&yMin, {i, 1})); + xMin += pad_left; + yMin += pad_top; + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 0}, xMin)); + RETURN_IF_NOT_OK((*bboxList)->SetItemAt({i, 1}, yMin)); + } + return Status::OK(); +} + +Status UpdateBBoxesForResize(const std::shared_ptr &bboxList, const size_t &bboxCount, int32_t target_width_, + int32_t target_height_, int orig_width, int orig_height) { + float bb_Xmin = 0, bb_Ymin = 0, bb_Xwidth = 0, bb_Ywidth = 0; + // cast to float to preserve fractional + float W_aspRatio = (target_width_ * 1.0) / (orig_width * 1.0); + float H_aspRatio = (target_height_ * 1.0) / (orig_height * 1.0); + for (int i = 0; i < bboxCount; i++) { + // for each bounding box + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Xmin, {i, 0})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Ymin, {i, 1})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Xwidth, {i, 2})); + RETURN_IF_NOT_OK(bboxList->GetItemAt(&bb_Ywidth, {i, 3})); + // update positions and widths + bb_Xmin = bb_Xmin * W_aspRatio; + bb_Ymin = bb_Ymin * H_aspRatio; + bb_Xwidth = bb_Xwidth * W_aspRatio; + bb_Ywidth = bb_Ywidth * H_aspRatio; + // reset bounding box values + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 0}, bb_Xmin)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 1}, bb_Ymin)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 2}, bb_Xwidth)); + RETURN_IF_NOT_OK(bboxList->SetItemAt({i, 3}, bb_Ywidth)); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h new file mode 100644 index 0000000000..f489c7367b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -0,0 +1,259 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
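The per-channel arithmetic in Normalize above folds the usual (pixel - mean) / std into a single OpenCV convertTo pass with scale 1/std and offset -mean/std. A minimal standalone sketch of that identity, using made-up channel statistics (illustrative only, not MindSpore defaults):

// Sketch only: the same per-channel transform Normalize performs, without the
// CVTensor wrappers. The mean/std values here are hypothetical.
#include <opencv2/core.hpp>

void NormalizeChannelSketch() {
  cv::Mat channel = cv::Mat::ones(2, 2, CV_8U) * 128;  // fake 2x2 single channel
  const float mean_c = 121.0f, std_c = 70.0f;
  // convertTo computes dst = src * alpha + beta, so alpha = 1/std and
  // beta = -mean/std give dst = (src - mean) / std in one pass.
  channel.convertTo(channel, CV_32F, 1.0 / std_c, -mean_c / std_c);
  // each pixel is now (128 - 121) / 70 = 0.1
}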
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h new file mode 100644 index 0000000000..f489c7367b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -0,0 +1,259 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
+#define DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
+
+#include <setjmp.h>
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+#if defined(_WIN32) || defined(_WIN64)
+#undef HAVE_STDDEF_H
+#undef HAVE_STDLIB_H
+#endif
+#include "./jpeglib.h"
+#include "./jerror.h"
+#include <opencv2/imgproc/imgproc.hpp>
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+void JpegErrorExitCustom(j_common_ptr cinfo);
+
+struct JpegErrorManagerCustom {
+  // "public" fields
+  struct jpeg_error_mgr pub;
+  // for return to caller
+  jmp_buf setjmp_buffer;
+};
+
+// Returns the interpolation mode in openCV format
+// @param mode: interpolation mode in DE format
+int GetCVInterpolationMode(InterpolationMode mode);
+
+// Returns the openCV equivalent of the border type used for padding.
+// @param type
+// @return
+int GetCVBorderType(BorderType type);
+
+// Returns flipped image
+// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param flip_code: 1 for Horizontal (around y-axis), 0 for Vertical (around x-axis), -1 for both
+// The flipping happens in place.
+Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code);
+
+// Returns Horizontally flipped image
+// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// The flipping happens in place.
+Status HorizontalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
+
+// Returns Vertically flipped image
+// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// The flipping happens in place.
+Status VerticalFlip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
+
+// Returns Resized image.
+// @param input/output: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param output_height: height of output
+// @param output_width: width of output
+// @param fx: horizontal scale
+// @param fy: vertical scale
+// @param InterpolationMode: the interpolation mode
+// @param output: Resized image of shape <outputHeight,outputWidth,C> or <outputHeight,outputWidth>
+//                and same type as input
+Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t output_height,
+              int32_t output_width, double fx = 0.0, double fy = 0.0,
+              InterpolationMode mode = InterpolationMode::kLinear);
+
+// Returns Decoded image
+// Supported images:
+//  BMP JPEG JPG PNG TIFF
+// supported by opencv; if more image analysis capabilities are needed, compile OpenCV with the required modules.
+// @param input: CVTensor containing the not decoded image 1D bytes
+// @param output: Decoded image Tensor of shape <H,W,C> and type DE_UINT8. Pixel order is RGB
+Status Decode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+Status DecodeCv(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+bool IsNonEmptyJPEG(const std::shared_ptr<Tensor> &input);
+
+void JpegSetSource(j_decompress_ptr c_info, const void *data, int64_t data_size);
+
+Status JpegCropAndDecode(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x = 0, int y = 0,
+                         int w = 0, int h = 0);
+// Returns Rescaled image
+// @param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param rescale: rescale parameter
+// @param shift: shift parameter
+// @param output: Rescaled image Tensor of same input shape and type DE_FLOAT32
+Status Rescale(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rescale, float shift);
+
+// Returns cropped ROI of an image
+// @param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param x: starting horizontal position of ROI
+// @param y: starting vertical position of ROI
+// @param w: width of the ROI
+// @param h: height of the ROI
+// @param output: Cropped image Tensor of shape <h,w,C> or <h,w> and same input type.
+Status Crop(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y, int w, int h);
+
+// Swaps the channels in the image, i.e. converts HWC to CHW
+// @param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param output: Tensor of shape <C,H,W> or <H,W> and same input type.
+Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
+
+// Swap the red and blue pixels (RGB <-> BGR)
+// @param input: Tensor of shape <H,W,3> and any OpenCv compatible type, see CVTensor.
+// @param output: Swapped image of same shape and type
+Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output);
+
+// Crops and resizes the image
+// @param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param x: horizontal start point
+// @param y: vertical start point
+// @param crop_height: height of the cropped ROI
+// @param crop_width: width of the cropped ROI
+// @param target_width: width of the final resized image
+// @param target_height: height of the final resized image
+// @param InterpolationMode: the interpolation used in resize operation
+// @param output: Tensor of shape <targetHeight,targetWidth,C> or <targetHeight,targetWidth>
+//                and same type as input
+Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int x, int y,
+                     int crop_height, int crop_width, int target_height, int target_width, InterpolationMode mode);
+
+// Returns rotated image
+// @param input: Tensor of shape <H,W,C> or <H,W> and any OpenCv compatible type, see CVTensor.
+// @param fx: rotation center x coordinate
+// @param fy: rotation center y coordinate
+// @param degree: degree to rotate
+// @param expand: if reshape is necessary
+// @param output: rotated image of same input type.
+Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float fx, float fy, float degree,
+              InterpolationMode interpolation = InterpolationMode::kNearestNeighbour, bool expand = false,
+              uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
+
+// Returns Normalized image
+// @param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
+// @param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
+// @param std:  Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
+// @param output: Normalized image Tensor of same input shape and type DE_FLOAT32
+Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
+                 const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std);
+
+// Returns image with adjusted brightness.
+// @param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
+// @param alpha: Alpha value to adjust brightness by. Should be a positive number.
+//               If user inputs one value in python, the range is [1 - value, 1 + value].
+//               This will output original image multiplied by alpha. 0 gives a black image, 1 gives the
+//               original image while 2 increases the brightness by a factor of 2.
+// @param output: Adjusted image of same shape and type.
+Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha);
+
+// Returns image with adjusted contrast.
+// @param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
+// @param alpha: Alpha value to adjust contrast by. Should be a positive number.
+//               If user inputs one value in python, the range is [1 - value, 1 + value].
+//               0 gives a solid gray image, 1 gives the original image while 2 increases
+//               the contrast by a factor of 2.
+// @param output: Adjusted image of same shape and type.
+Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha);
+
+// Returns image with adjusted saturation.
+// @param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
+// @param alpha: Alpha value to adjust saturation by. Should be a positive number.
+//               If user inputs one value in python, the range is [1 - value, 1 + value].
+//               0 will give a black and white image, 1 will give the original image while
+//               2 will enhance the saturation by a factor of 2.
+// @param output: Adjusted image of same shape and type.
+Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha);
+
+// Returns image with adjusted hue.
+// @param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
+// @param hue: Hue value to adjust by, should be within range [-0.5, 0.5]. 0.5 and -0.5 will reverse the hue channel
+//             completely.
+//             If user inputs one value in python, the range is [-value, value].
+// @param output: Adjusted image of same shape and type.
+Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &hue);
+
+// Masks out a random section from the image with set dimension
+// @param input: input Tensor
+// @param output: cutOut Tensor
+// @param box_height: height of the cropped box
+// @param box_width: width of the cropped box
+// @param num_patches: number of boxes to cut out from the image
+// @param bounded: boolean flag to toggle between random erasing and cutout
+// @param random_color: whether or not random fill value should be used
+// @param fill_r: red fill value for erase
+// @param fill_g: green fill value for erase
+// @param fill_b: blue fill value for erase.
+Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
+             int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd,
+             uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
+
+// Pads the input image and puts the padded image in the output
+// @param input: input Tensor
+// @param output: padded Tensor
+// @param pad_top: amount of padding done in top
+// @param pad_bottom: amount of padding done in bottom
+// @param pad_left: amount of padding done in left
+// @param pad_right: amount of padding done in right
+// @param border_types: the border type used for padding
+// @param fill_r: red fill value for pad
+// @param fill_g: green fill value for pad
+// @param fill_b: blue fill value for pad.
+Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const int32_t &pad_top,
+           const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types,
+           uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0);
+
+// -------- BBOX OPERATIONS -------- //
+// Updates and checks bounding boxes for new cropped region of image
+// @param bboxList: A tensor containing bounding box tensors
+// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop
+// @param CB_Xmin: Image's CropBox Xmin coordinate
+// @param CB_Ymin: Image's CropBox Ymin coordinate
+// @param CB_Xmax: Image's CropBox Xmax coordinate - (Xmin + width)
+// @param CB_Ymax: Image's CropBox Ymax coordinate - (Ymin + height)
+Status UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax,
+                           int CB_Ymax);
+
+// Updates bounding boxes with required Top and Left padding
+// Top and Left padding amounts required to adjust bboxs min X,Y values according to padding 'push'
+// Top/Left since images 0,0 coordinate is taken from top left
+// @param bboxList: A tensor containing bounding box tensors
+// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop
+// @param pad_top: Total amount of padding applied to image top
+// @param pad_left: Total amount of padding applied to image left side
+Status PadBBoxes(const std::shared_ptr<Tensor> *bboxList, const size_t &bboxCount, int32_t pad_top, int32_t pad_left);
+
+// Updates bounding boxes for an Image Resize Operation - Takes in set of valid BBoxes
+// For e.g those that remain after a crop
+// @param bboxList: A tensor containing bounding box tensors
+// @param bboxCount: total Number of bounding boxes - required within caller function to run update loop
+// @param target_width_: required width of image post resize
+// @param target_height_: required height of image post resize
+// @param orig_width: current width of image pre resize
+// @param orig_height: current height of image pre resize
+Status UpdateBBoxesForResize(const std::shared_ptr<Tensor> &bboxList, const size_t &bboxCount, int32_t target_width_,
+                             int32_t target_height_, int orig_width, int orig_height);
+
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_
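As a rough illustration of how these free functions compose, a hypothetical caller (names and target size are assumptions, not part of this patch) might chain decode, resize, and layout conversion like this:

// Sketch only: decode raw JPEG bytes, resize, then convert HWC -> CHW using
// the declarations above. DecodeResizeChw is a hypothetical helper name.
Status DecodeResizeChw(const std::shared_ptr<Tensor> &raw, std::shared_ptr<Tensor> *out) {
  std::shared_ptr<Tensor> decoded, resized;
  RETURN_IF_NOT_OK(Decode(raw, &decoded));                // <H,W,3> uint8, RGB order
  RETURN_IF_NOT_OK(Resize(decoded, &resized, 224, 224));  // kLinear by default
  return HwcToChw(resized, out);                          // <3,224,224>
}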
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc new file mode 100644 index 0000000000..de5deb31ef --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.cc @@ -0,0 +1,55 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/normalize_op.h"
+
+#include <random>
+
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+NormalizeOp::NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b) {
+  int size[] = {3};
+  cv::Mat mean_cv(1, size, CV_32F);
+  mean_cv.at<float>(0) = mean_r;
+  mean_cv.at<float>(1) = mean_g;
+  mean_cv.at<float>(2) = mean_b;
+  mean_ = std::make_shared<CVTensor>(mean_cv);
+  mean_->Squeeze();
+
+  cv::Mat std_cv(1, size, CV_32F);
+  std_cv.at<float>(0) = std_r;
+  std_cv.at<float>(1) = std_g;
+  std_cv.at<float>(2) = std_b;
+  std_ = std::make_shared<CVTensor>(std_cv);
+  std_->Squeeze();
+}
+
+Status NormalizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  // Doing the normalization
+  return Normalize(input, output, mean_, std_);
+}
+
+void NormalizeOp::Print(std::ostream &out) const {
+  out << "NormalizeOp, mean: " << mean_->mat().at<float>(0) << ", " << mean_->mat().at<float>(1) << ", "
+      << mean_->mat().at<float>(2) << ", std: " << std_->mat().at<float>(0) << ", " << std_->mat().at<float>(1) << ", "
+      << std_->mat().at<float>(2) << std::endl;
+}
+}  // namespace dataset
+}  // namespace mindspore
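A usage sketch of the op above; the RGB statistics here are ImageNet-style constants pre-scaled to the 0-255 pixel range, shown only as an illustration (they are not defaults shipped by this patch), and image_tensor is an assumed <H,W,3> uint8 input:

// Hypothetical direct use of NormalizeOp.
NormalizeOp normalize(123.675f, 116.28f, 103.53f,  // mean R, G, B
                      58.395f, 57.12f, 57.375f);   // std  R, G, B
std::shared_ptr<Tensor> normalized;
Status rc = normalize.Compute(image_tensor, &normalized);  // DE_FLOAT32 output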
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h new file mode 100644 index 0000000000..7821869c8f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/normalize_op.h @@ -0,0 +1,48 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_
+#define DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_
+
+#include <memory>
+#include <string>
+
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class NormalizeOp : public TensorOp {
+ public:
+  NormalizeOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b);
+
+  ~NormalizeOp() override = default;
+
+  void Print(std::ostream &out) const override;
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kNormalizeOp; }
+
+ private:
+  std::shared_ptr<CVTensor> mean_;
+  std::shared_ptr<CVTensor> std_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_NORMALIZE_OP_H_
+ */ +#include "minddata/dataset/kernels/image/pad_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const BorderType PadOp::kDefBorderType = BorderType::kConstant; +const uint8_t PadOp::kDefFillR = 0; +const uint8_t PadOp::kDefFillG = 0; +const uint8_t PadOp::kDefFillB = 0; + +PadOp::PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) + : pad_top_(pad_top), + pad_bottom_(pad_bottom), + pad_left_(pad_left), + pad_right_(pad_right), + boarder_type_(border_types), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) {} + +Status PadOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return Pad(input, output, pad_top_, pad_bottom_, pad_left_, pad_right_, boarder_type_, fill_r_, fill_g_, fill_b_); +} + +Status PadOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, -1, 3}); // we don't know what is output image size, but we know it should be 3 channels + if (inputs[0].Rank() == 1) outputs.emplace_back(out); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h new file mode 100644 index 0000000000..9437058406 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_ +#define DATASET_KERNELS_IMAGE_PAD_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class PadOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const BorderType kDefBorderType; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + // Constructor for PadOp. + // @param pad_top number of pixels to pad the top of image with. + // @param pad_bottom number of pixels to pad the bottom of the image with. + // @param pad_left number of pixels to pad the left of the image with. + // @param pad_right number of pixels to pad the right of the image with. + // @param border_types BorderType enum, the type of boarders that we are using. + // @param fill_r R value for the color to pad with. + // @param fill_g G value for the color to pad with. + // @param fill_b B value for the color to pad with. 
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h new file mode 100644 index 0000000000..9437058406 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/pad_op.h @@ -0,0 +1,72 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_PAD_OP_H_
+#define DATASET_KERNELS_IMAGE_PAD_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class PadOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const BorderType kDefBorderType;
+  static const uint8_t kDefFillR;
+  static const uint8_t kDefFillG;
+  static const uint8_t kDefFillB;
+
+  // Constructor for PadOp.
+  // @param pad_top number of pixels to pad the top of image with.
+  // @param pad_bottom number of pixels to pad the bottom of the image with.
+  // @param pad_left number of pixels to pad the left of the image with.
+  // @param pad_right number of pixels to pad the right of the image with.
+  // @param border_types BorderType enum, the type of borders that we are using.
+  // @param fill_r R value for the color to pad with.
+  // @param fill_g G value for the color to pad with.
+  // @param fill_b B value for the color to pad with.
+  PadOp(int32_t pad_top, int32_t pad_bottom, int32_t pad_left, int32_t pad_right, BorderType border_types,
+        uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB);
+
+  ~PadOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "PadOp: "; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kPadOp; }
+
+ private:
+  int32_t pad_top_;
+  int32_t pad_bottom_;
+  int32_t pad_left_;
+  int32_t pad_right_;
+  BorderType boarder_type_;
+  uint8_t fill_r_;
+  uint8_t fill_g_;
+  uint8_t fill_b_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_PAD_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc new file mode 100644 index 0000000000..6dbf30c33e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.cc @@ -0,0 +1,91 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
+
+#include <random>
+
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+RandomColorAdjustOp::RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor,
+                                         float e_contrast_factor, float s_saturation_factor, float e_saturation_factor,
+                                         float s_hue_factor, float e_hue_factor)
+    : bright_factor_start_(s_bright_factor),
+      bright_factor_end_(e_bright_factor),
+      contrast_factor_start_(s_contrast_factor),
+      contrast_factor_end_(e_contrast_factor),
+      saturation_factor_start_(s_saturation_factor),
+      saturation_factor_end_(e_saturation_factor),
+      hue_factor_start_(s_hue_factor),
+      hue_factor_end_(e_hue_factor) {
+  rnd_.seed(GetSeed());
+}
+
+Status RandomColorAdjustOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+
+  // randomly select an augmentation to apply to the input image until all the transformations run
+  std::vector<std::string> params_vector = {"brightness", "contrast", "saturation", "hue"};
+
+  std::shuffle(params_vector.begin(), params_vector.end(), rnd_);
+
+  *output = std::static_pointer_cast<Tensor>(input);
+  // determine if certain augmentation needs to be executed:
+  for (const auto &param : params_vector) {
+    // case switch
+    if (param == "brightness") {
+      if (CmpFloat(bright_factor_start_, bright_factor_end_) && CmpFloat(bright_factor_start_, 1.0f)) {
+        MS_LOG(DEBUG) << "Not running brightness.";
+      } else {
+        // adjust the brightness of an image
+        float random_factor = std::uniform_real_distribution<float>(bright_factor_start_, bright_factor_end_)(rnd_);
+        RETURN_IF_NOT_OK(AdjustBrightness(*output, output, random_factor));
+      }
+    } else if (param == "contrast") {
+      if (CmpFloat(contrast_factor_start_, contrast_factor_end_) && CmpFloat(contrast_factor_start_, 1.0f)) {
+        MS_LOG(DEBUG) << "Not running contrast.";
+      } else {
+        float random_factor = std::uniform_real_distribution<float>(contrast_factor_start_, contrast_factor_end_)(rnd_);
+        RETURN_IF_NOT_OK(AdjustContrast(*output, output, random_factor));
+      }
+    } else if (param == "saturation") {
+      // adjust the Saturation of an image
+      if (CmpFloat(saturation_factor_start_, saturation_factor_end_) && CmpFloat(saturation_factor_start_, 1.0f)) {
+        MS_LOG(DEBUG) << "Not running saturation.";
+      } else {
+        float random_factor =
+          std::uniform_real_distribution<float>(saturation_factor_start_, saturation_factor_end_)(rnd_);
+        RETURN_IF_NOT_OK(AdjustSaturation(*output, output, random_factor));
+      }
+    } else if (param == "hue") {
+      if (CmpFloat(hue_factor_start_, hue_factor_end_) && CmpFloat(hue_factor_start_, 0.0f)) {
+        MS_LOG(DEBUG) << "Not running hue.";
+      } else {
+        // adjust the Hue of an image
+        float random_factor = std::uniform_real_distribution<float>(hue_factor_start_, hue_factor_end_)(rnd_);
+        RETURN_IF_NOT_OK(AdjustHue(*output, output, random_factor));
+      }
+    }
+  }
+  // at this point every selected transformation has been applied to *output
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
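The CmpFloat guards above skip a transform whenever its sampled range degenerates to the identity value (1.0 for brightness/contrast/saturation, 0.0 for hue). A compressed, standalone sketch of that decision (helper name and epsilon are assumptions):

// Sketch of the skip condition: when the range is the single identity point,
// the corresponding Adjust* call would be a no-op, so it is not run.
#include <cmath>

bool IsIdentityRange(float lo, float hi, float identity) {
  auto close = [](float a, float b) { return std::fabs(a - b) < 1e-10f; };
  return close(lo, hi) && close(lo, identity);
}
// IsIdentityRange(1.0f, 1.0f, 1.0f) -> true  => skip the transform
// IsIdentityRange(0.8f, 1.2f, 1.0f) -> false => sample a factor and apply it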
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h new file mode 100644 index 0000000000..fb29b57062 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_color_adjust_op.h @@ -0,0 +1,80 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomColorAdjustOp : public TensorOp {
+ public:
+  static const uint32_t kDefSeed;
+
+  // Constructor for RandomColorAdjustOp.
+  // @param s_bright_factor brightness change range start value.
+  // @param e_bright_factor brightness change range end value.
+  // @param s_contrast_factor contrast change range start value.
+  // @param e_contrast_factor contrast change range end value.
+  // @param s_saturation_factor saturation change range start value.
+  // @param e_saturation_factor saturation change range end value.
+  // @param s_hue_factor hue change factor start value, this should be greater than -0.5.
+  // @param e_hue_factor hue change factor end value, this should be less than 0.5.
+  // @details the randomly chosen degree is uniformly distributed.
+  RandomColorAdjustOp(float s_bright_factor, float e_bright_factor, float s_contrast_factor, float e_contrast_factor,
+                      float s_saturation_factor, float e_saturation_factor, float s_hue_factor, float e_hue_factor);
+
+  ~RandomColorAdjustOp() override = default;
+
+  // Print function for RandomJitter.
+  // @param out output stream to print to.
+  void Print(std::ostream &out) const override { out << "RandomColorAdjustOp: "; }
+
+  // Overrides the base class compute function.
+  // Calls multiple transform functions in ImageUtils, this function takes an input tensor
+  // and transforms its data using openCV, the output memory is manipulated to contain the result.
+  // @return Status - The error code return.
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kRandomColorAdjustOp; }
+
+ private:
+  std::mt19937 rnd_;
+  float bright_factor_start_;
+  float bright_factor_end_;
+  float contrast_factor_start_;
+  float contrast_factor_end_;
+  float saturation_factor_start_;
+  float saturation_factor_end_;
+  float hue_factor_start_;
+  float hue_factor_end_;
+  // Compare two floating point variables. Return true if they are same / very close.
+  inline bool CmpFloat(const float &a, const float &b, float epsilon = 0.0000000001f) const {
+    return (std::fabs(a - b) < epsilon);
+  }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RANDOM_COLOR_ADJUST_OP_H_
+ */ +#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h" +#include + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomCropAndResizeOp::kDefScaleLb = 0.08; +const float RandomCropAndResizeOp::kDefScaleUb = 1.0; +const float RandomCropAndResizeOp::kDefAspectLb = 0.75; +const float RandomCropAndResizeOp::kDefAspectUb = 1.333333; +const InterpolationMode RandomCropAndResizeOp::kDefInterpolation = InterpolationMode::kLinear; +const int32_t RandomCropAndResizeOp::kDefMaxIter = 10; + +RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb, + float scale_ub, float aspect_lb, float aspect_ub, + InterpolationMode interpolation, int32_t max_iter) + : target_height_(target_height), + target_width_(target_width), + rnd_scale_(scale_lb, scale_ub), + rnd_aspect_(log(aspect_lb), log(aspect_ub)), + interpolation_(interpolation), + aspect_lb_(aspect_lb), + aspect_ub_(aspect_ub), + max_iter_(max_iter) { + rnd_.seed(GetSeed()); +} + +Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); + + int h_in = input->shape()[0]; + int w_in = input->shape()[1]; + int x = 0; + int y = 0; + int crop_height = 0; + int crop_width = 0; + (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); + return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); +} +Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out = TensorShape{target_height_, target_width_}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { + *crop_width = w_in; + *crop_height = h_in; + CHECK_FAIL_RETURN_UNEXPECTED(w_in != 0, "Width is 0"); + CHECK_FAIL_RETURN_UNEXPECTED(h_in != 0, "Height is 0"); + CHECK_FAIL_RETURN_UNEXPECTED(aspect_lb_ > 0, "Aspect lower bound must be greater than zero"); + for (int32_t i = 0; i < max_iter_; i++) { + double const sample_scale = rnd_scale_(rnd_); + // In case of non-symmetrical aspect ratios, use uniform distribution on a logarithmic sample_scale. + // Note rnd_aspect_ is already a random distribution of the input aspect ratio in logarithmic sample_scale. 
+ double const sample_aspect = exp(rnd_aspect_(rnd_)); + + *crop_width = static_cast(std::round(std::sqrt(h_in * w_in * sample_scale * sample_aspect))); + *crop_height = static_cast(std::round(*crop_width / sample_aspect)); + if (*crop_width <= w_in && *crop_height <= h_in) { + std::uniform_int_distribution<> rd_x(0, w_in - *crop_width); + std::uniform_int_distribution<> rd_y(0, h_in - *crop_height); + *x = rd_x(rnd_); + *y = rd_y(rnd_); + return Status::OK(); + } + } + double const img_aspect = static_cast(w_in) / h_in; + if (img_aspect < aspect_lb_) { + *crop_width = w_in; + *crop_height = static_cast(std::round(*crop_width / static_cast(aspect_lb_))); + } else { + if (img_aspect > aspect_ub_) { + *crop_height = h_in; + *crop_width = static_cast(std::round(*crop_height * static_cast(aspect_ub_))); + } else { + *crop_width = w_in; + *crop_height = h_in; + } + } + *x = static_cast(std::round((w_in - *crop_width) / 2.0)); + *y = static_cast(std::round((h_in - *crop_height) / 2.0)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h new file mode 100644 index 0000000000..41d775fdf7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
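Numerically, each GetCropBox attempt draws an area fraction and a log-uniform aspect ratio, then solves width = sqrt(H * W * scale * aspect) and height = width / aspect. A worked example with hypothetical draws:

// Worked example (hypothetical sample): 480x640 (H x W) input,
// sample_scale = 0.5, sample_aspect = 1.0.
// crop_width  = round(sqrt(480 * 640 * 0.5 * 1.0)) = round(391.9) = 392
// crop_height = round(392 / 1.0)                   = 392
// 392 <= 640 and 392 <= 480, so the attempt succeeds and the corner is drawn
// uniformly: x from [0, 640 - 392] = [0, 248], y from [0, 480 - 392] = [0, 88].
// If all max_iter_ attempts fail, the fallback below center-crops at the
// nearest valid aspect ratio.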
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h new file mode 100644 index 0000000000..41d775fdf7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_op.h @@ -0,0 +1,78 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomCropAndResizeOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const float kDefScaleLb;
+  static const float kDefScaleUb;
+  static const float kDefAspectLb;
+  static const float kDefAspectUb;
+  static const InterpolationMode kDefInterpolation;
+  static const int32_t kDefMaxIter;
+
+  RandomCropAndResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb,
+                        float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb,
+                        InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter);
+
+  RandomCropAndResizeOp() = default;
+
+  RandomCropAndResizeOp(const RandomCropAndResizeOp &rhs) = default;
+
+  RandomCropAndResizeOp(RandomCropAndResizeOp &&rhs) = default;
+
+  ~RandomCropAndResizeOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << "RandomCropAndResize: " << target_height_ << " " << target_width_;
+  }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  Status GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width);
+
+  std::string Name() const override { return kRandomCropAndResizeOp; }
+
+ protected:
+  int32_t target_height_;
+  int32_t target_width_;
+  std::uniform_real_distribution<float> rnd_scale_;
+  std::uniform_real_distribution<float> rnd_aspect_;
+  std::mt19937 rnd_;
+  InterpolationMode interpolation_;
+  int32_t max_iter_;
+  double aspect_lb_;
+  double aspect_ub_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_OP_H_
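A construction sketch mirroring the defaults declared above (the target size is an assumption; image_tensor is an assumed input): crop between 8% and 100% of the source area at an aspect ratio in [3/4, 4/3], then resize with bilinear interpolation.

// Hypothetical direct use of RandomCropAndResizeOp.
RandomCropAndResizeOp rrc(256, 256, 0.08f, 1.0f, 0.75f, 1.333333f,
                          InterpolationMode::kLinear, 10);
std::shared_ptr<Tensor> out;
Status rc = rrc.Compute(image_tensor, &out);  // 256 x 256 x C on success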
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc new file mode 100644 index 0000000000..98bfe41241 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.cc @@ -0,0 +1,58 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <utility>
+
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status RandomCropAndResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);
+  CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of input is abnormal");
+
+  output->resize(2);
+  (*output)[1] = std::move(input[1]);  // move boxes over to output
+
+  size_t bboxCount = input[1]->shape()[0];  // number of rows in bbox tensor
+  int h_in = input[0]->shape()[0];
+  int w_in = input[0]->shape()[1];
+  int x = 0;
+  int y = 0;
+  int crop_height = 0;
+  int crop_width = 0;
+
+  RETURN_IF_NOT_OK(RandomCropAndResizeOp::GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width));
+
+  int maxX = x + crop_width;  // max dims of selected CropBox on image
+  int maxY = y + crop_height;
+
+  RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &bboxCount, x, y, maxX, maxY));  // IMAGE_UTIL
+  RETURN_IF_NOT_OK(CropAndResize(input[0], &(*output)[0], x, y, crop_height, crop_width, target_height_,
+                                 target_width_, interpolation_));
+
+  RETURN_IF_NOT_OK(
+    UpdateBBoxesForResize((*output)[1], bboxCount, target_width_, target_height_, crop_width, crop_height));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
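A worked example of the bbox bookkeeping performed above, using hypothetical numbers:

// Suppose GetCropBox returns x = 50, y = 40 with a 200x100 (w x h) crop, so
// the crop spans x in [50, 250] and y in [40, 140]. A box stored as
// (xmin=60, ymin=30, w=80, h=40) has corners (60, 30)-(140, 70) and overlaps
// the crop, so UpdateBBoxesForCrop keeps it, clamps it to the crop region,
// and re-expresses it in crop coordinates:
//   xmin = 60 - min(0, 60-50)  - 50 = 10,  ymin = 30 - min(0, 30-40)  - 40 = 0
//   xmax = 140 - max(0, 140-250) - 50 = 90, ymax = 70 - max(0, 70-140) - 40 = 30
// giving (xmin=10, ymin=0, w=80, h=30). UpdateBBoxesForResize then scales all
// four values by target_width/crop_width and target_height/crop_height.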
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h new file mode 100644 index 0000000000..ddaac10fac --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h @@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_
+
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
+#include <string>
+
+namespace mindspore {
+namespace dataset {
+
+class RandomCropAndResizeWithBBoxOp : public RandomCropAndResizeOp {
+ public:
+  // Constructor for RandomCropAndResizeWithBBoxOp, with default value and passing to base class constructor
+  RandomCropAndResizeWithBBoxOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb,
+                                float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb,
+                                float aspect_ub = kDefAspectUb, InterpolationMode interpolation = kDefInterpolation,
+                                int32_t max_iter = kDefMaxIter)
+      : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation,
+                              max_iter) {}
+
+  ~RandomCropAndResizeWithBBoxOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << "RandomCropAndResizeWithBBox: " << RandomCropAndResizeOp::target_height_ << " "
+        << RandomCropAndResizeOp::target_width_;
+  }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kRandomCropAndResizeWithBBoxOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RANDOM_CROP_AND_RESIZE_WITH_BBOX_OP_H_
+ */ +#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h" +#include +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/kernels/image/decode_op.h" + +namespace mindspore { +namespace dataset { +RandomCropDecodeResizeOp::RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb, + float scale_ub, float aspect_lb, float aspect_ub, + InterpolationMode interpolation, int32_t max_iter) + : RandomCropAndResizeOp(target_height, target_width, scale_lb, scale_ub, aspect_lb, aspect_ub, interpolation, + max_iter) {} + +Status RandomCropDecodeResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + if (input == nullptr) { + RETURN_STATUS_UNEXPECTED("input tensor is null"); + } + if (!IsNonEmptyJPEG(input)) { + DecodeOp op(true); + std::shared_ptr decoded; + RETURN_IF_NOT_OK(op.Compute(input, &decoded)); + return RandomCropAndResizeOp::Compute(decoded, output); + } else { + struct jpeg_decompress_struct cinfo {}; + struct JpegErrorManagerCustom jerr {}; + cinfo.err = jpeg_std_error(&jerr.pub); + jerr.pub.error_exit = JpegErrorExitCustom; + try { + jpeg_create_decompress(&cinfo); + JpegSetSource(&cinfo, input->GetBuffer(), input->SizeInBytes()); + (void)jpeg_read_header(&cinfo, TRUE); + jpeg_calc_output_dimensions(&cinfo); + } catch (std::runtime_error &e) { + jpeg_destroy_decompress(&cinfo); + RETURN_STATUS_UNEXPECTED(e.what()); + } + int h_in = cinfo.output_height; + int w_in = cinfo.output_width; + jpeg_destroy_decompress(&cinfo); + + int x = 0; + int y = 0; + int crop_height = 0; + int crop_width = 0; + (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); + + std::shared_ptr decoded; + RETURN_IF_NOT_OK(JpegCropAndDecode(input, &decoded, x, y, crop_width, crop_height)); + return Resize(decoded, output, target_height_, target_width_, 0.0, 0.0, interpolation_); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h new file mode 100644 index 0000000000..863fd48c14 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h new file mode 100644 index 0000000000..863fd48c14 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_decode_resize_op.h @@ -0,0 +1,54 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomCropDecodeResizeOp : public RandomCropAndResizeOp {
+ public:
+  RandomCropDecodeResizeOp(int32_t target_height, int32_t target_width, float scale_lb = kDefScaleLb,
+                           float scale_ub = kDefScaleUb, float aspect_lb = kDefAspectLb, float aspect_ub = kDefAspectUb,
+                           InterpolationMode interpolation = kDefInterpolation, int32_t max_iter = kDefMaxIter);
+
+  explicit RandomCropDecodeResizeOp(const RandomCropAndResizeOp &rhs) : RandomCropAndResizeOp(rhs) {}
+
+  ~RandomCropDecodeResizeOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << "RandomCropDecodeResize: " << RandomCropAndResizeOp::target_height_ << " "
+        << RandomCropAndResizeOp::target_width_;
+  }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kRandomCropDecodeResizeOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RANDOM_CROP_DECODE_RESIZE_OP_H_
+ */
+#include "minddata/dataset/kernels/image/random_crop_op.h"
+#include <random>
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+const int32_t RandomCropOp::kDefPadTop = 0;
+const int32_t RandomCropOp::kDefPadBottom = 0;
+const int32_t RandomCropOp::kDefPadLeft = 0;
+const int32_t RandomCropOp::kDefPadRight = 0;
+const BorderType RandomCropOp::kDefBorderType = BorderType::kConstant;
+const bool RandomCropOp::kDefPadIfNeeded = false;
+const uint8_t RandomCropOp::kDefFillR = 0;
+const uint8_t RandomCropOp::kDefFillG = 0;
+const uint8_t RandomCropOp::kDefFillB = 0;
+
+RandomCropOp::RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top, int32_t pad_bottom,
+                           int32_t pad_left, int32_t pad_right, BorderType border_types, bool pad_if_needed,
+                           uint8_t fill_r, uint8_t fill_g, uint8_t fill_b)
+    : crop_height_(crop_height),
+      crop_width_(crop_width),
+      pad_top_(pad_top),
+      pad_bottom_(pad_bottom),
+      pad_left_(pad_left),
+      pad_right_(pad_right),
+      pad_if_needed_(pad_if_needed),
+      border_type_(border_types),
+      fill_r_(fill_r),
+      fill_g_(fill_g),
+      fill_b_(fill_b) {
+  rnd_.seed(GetSeed());
+}
+
+Status RandomCropOp::ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image,
+                                  int32_t *t_pad_top, int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right,
+                                  int32_t *padded_image_w, int32_t *padded_image_h, bool *crop_further) {
+  *t_pad_top = pad_top_;
+  *t_pad_bottom = pad_bottom_;
+  *t_pad_left = pad_left_;
+  *t_pad_right = pad_right_;
+
+  RETURN_IF_NOT_OK(
+    Pad(input, pad_image, pad_top_, pad_bottom_, pad_left_, pad_right_, border_type_, fill_r_, fill_g_, fill_b_));
+  CHECK_FAIL_RETURN_UNEXPECTED((*pad_image)->shape().Size() >= 2, "Abnormal shape");
+
+  *padded_image_h = (*pad_image)->shape()[0];
+  *padded_image_w = (*pad_image)->shape()[1];
+
+  if (*padded_image_h == crop_height_ && *padded_image_w == crop_width_) {
+    *crop_further = false;  // no need for further crop
+    return Status::OK();
+  } else if (pad_if_needed_) {
+    // check the dimensions of the image for padding, if we do need padding, then we change the pad values
+    if (*padded_image_h < crop_height_) {
+      RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, crop_height_ - *padded_image_h, crop_height_ - *padded_image_h, 0, 0,
+                           border_type_, fill_r_, fill_g_, fill_b_));
+
+      // update pad total above/below (dereference: these are output parameters)
+      *t_pad_top += (crop_height_ - *padded_image_h);
+      *t_pad_bottom += (crop_height_ - *padded_image_h);
+    }
+    if (*padded_image_w < crop_width_) {
+      RETURN_IF_NOT_OK(Pad(*pad_image, pad_image, 0, 0, crop_width_ - *padded_image_w, crop_width_ - *padded_image_w,
+                           border_type_, fill_r_, fill_g_, fill_b_));
+      // update pad total left/right (dereference: these are output parameters)
+      *t_pad_left += (crop_width_ - *padded_image_w);
+      *t_pad_right += (crop_width_ - *padded_image_w);
+    }
+    *padded_image_h = (*pad_image)->shape()[0];
+    *padded_image_w = (*pad_image)->shape()[1];
+  }
+
+  if (*padded_image_h < crop_height_ || *padded_image_w < crop_width_ || crop_height_ == 0 || crop_width_ == 0) {
+    return Status(StatusCode::kShapeMisMatch, __LINE__, __FILE__,
+                  "Crop size is greater than the image dimensions or is zero.");
+  }
+  return Status::OK();
+}
+
+void RandomCropOp::GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h) {
+  // generate the crop points: any origin in this range keeps the crop inside the padded image
+  *x = std::uniform_int_distribution<int>(0, padded_image_w - crop_width_)(rnd_);
+  *y = std::uniform_int_distribution<int>(0, padded_image_h - crop_height_)(rnd_);
+}
+
+Status RandomCropOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+
+  // Apply padding first then crop
+  std::shared_ptr<Tensor> pad_image;
+  int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right;
+  int32_t padded_image_w;
+  int32_t padded_image_h;
+  bool crop_further = true;  // whether image needs further cropping based on new size & requirements
+
+  RETURN_IF_NOT_OK(  // error code sent back directly
+    ImagePadding(input, &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right, &padded_image_w,
+                 &padded_image_h, &crop_further));
+  if (!crop_further) {
+    *output = pad_image;
+    return Status::OK();
+  }
+
+  int x, y;
+  GenRandomXY(&x, &y, padded_image_w, padded_image_h);
+  return Crop(pad_image, output, x, y, crop_width_, crop_height_);
+}
+
+Status RandomCropOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
+  outputs.clear();
+  TensorShape out = TensorShape{crop_height_, crop_width_};
+  if (inputs[0].Rank() == 2) outputs.emplace_back(out);
+  if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2]));
+  if (!outputs.empty()) return Status::OK();
+  return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h
new file mode 100644
index 0000000000..44f1789f9d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
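[Review note] The sampling in GenRandomXY is worth spelling out: with a padded image of (W, H) and a crop of (w, h), any x in [0, W - w] and y in [0, H - h] keeps the crop fully inside the image. A standalone check using only the standard library (not part of the patch; the op itself seeds from GetSeed() rather than a literal):

    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(42);  // stand-in for the op's GetSeed()-based seeding
      const int W = 256, H = 256, w = 224, h = 224;
      int x = std::uniform_int_distribution<int>(0, W - w)(rnd);
      int y = std::uniform_int_distribution<int>(0, H - h)(rnd);
      std::printf("crop origin: (%d, %d)\n", x, y);  // both land in [0, 32]
      return 0;
    }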
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomCropOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const int32_t kDefPadTop;
+  static const int32_t kDefPadBottom;
+  static const int32_t kDefPadLeft;
+  static const int32_t kDefPadRight;
+  static const BorderType kDefBorderType;
+  static const bool kDefPadIfNeeded;
+  static const uint8_t kDefFillR;
+  static const uint8_t kDefFillG;
+  static const uint8_t kDefFillB;
+
+  RandomCropOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop,
+               int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, int32_t pad_right = kDefPadRight,
+               BorderType border_types = kDefBorderType, bool pad_if_needed = kDefPadIfNeeded,
+               uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB);
+
+  RandomCropOp(const RandomCropOp &rhs) = default;
+
+  RandomCropOp(RandomCropOp &&rhs) = default;
+
+  ~RandomCropOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "RandomCropOp: " << crop_height_ << " " << crop_width_; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  // Breaks the compute function's image padding functionality out and makes it available to other Ops.
+  // Using this class as a base - restructured to allow for the RandomCropWithBBox augmentation Op.
+  // @param input: the original input image
+  // @param pad_image: pointer to the new padded image
+  // @param t_pad_top: total top padding - from the input plus any value calculated in this function if required
+  // @param t_pad_bottom: total bottom padding - from the input plus any value calculated in this function if required
+  // @param t_pad_left: total left padding - from the input plus any value calculated in this function if required
+  // @param t_pad_right: total right padding - from the input plus any value calculated in this function if required
+  // @param padded_image_w: final width of 'pad_image'
+  // @param padded_image_h: final height of 'pad_image'
+  // @param crop_further: whether the image requires cropping after padding - false if the padded image already
+  //     matches the required dimensions
+  Status ImagePadding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *pad_image, int32_t *t_pad_top,
+                      int32_t *t_pad_bottom, int32_t *t_pad_left, int32_t *t_pad_right, int32_t *padded_image_w,
+                      int32_t *padded_image_h, bool *crop_further);
+
+  // Breaks the X,Y generation functionality out of the original compute function and makes it available to other Ops
+  void GenRandomXY(int *x, int *y, const int32_t &padded_image_w, const int32_t &padded_image_h);
+
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kRandomCropOp; }
+
+ protected:
+  int32_t crop_height_ = 0;
+  int32_t crop_width_ = 0;
+
+ private:
+  int32_t pad_top_ = 0;
+  int32_t pad_bottom_ = 0;
+  int32_t pad_left_ = 0;
+  int32_t pad_right_ = 0;
+  bool pad_if_needed_ = false;
+  BorderType border_type_;
+  uint8_t fill_r_ = 0;
+  uint8_t fill_g_ = 0;
+  uint8_t fill_b_ = 0;
+  std::mt19937 rnd_;
+};
+} // namespace dataset
+} // namespace mindspore
+
+#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc
new file mode 100644
index 0000000000..08b12b8b70
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.cc
@@ -0,0 +1,66 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <random>
+#include <utility>
+
+#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/random.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status RandomCropWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);
+
+  std::shared_ptr<Tensor> pad_image;
+  int32_t t_pad_top, t_pad_bottom, t_pad_left, t_pad_right;
+  size_t boxCount = input[1]->shape()[0];  // number of rows
+
+  int32_t padded_image_h;
+  int32_t padded_image_w;
+
+  output->resize(2);
+  (*output)[1] = std::move(input[1]);  // since some boxes may be removed
+
+  bool crop_further = true;  // Whether further cropping will be required or not, true unless required size matches
+  RETURN_IF_NOT_OK(  // Error passed back to caller
+    RandomCropOp::ImagePadding(input[0], &pad_image, &t_pad_top, &t_pad_bottom, &t_pad_left, &t_pad_right,
+                               &padded_image_w, &padded_image_h, &crop_further));
+
+  // update bounding boxes with new values based on relevant image padding;
+  // PadBBoxes consumes the left and top offsets, so those are the ones to test
+  if (t_pad_left || t_pad_top) {
+    RETURN_IF_NOT_OK(PadBBoxes(&(*output)[1], boxCount, t_pad_left, t_pad_top));
+  }
+  if (!crop_further) {
+    // no further cropping required; (*output)[1] already holds the (padded) boxes
+    (*output)[0] = pad_image;
+    return Status::OK();
+  }
+
+  int x, y;
+  RandomCropOp::GenRandomXY(&x, &y, padded_image_w, padded_image_h);
+  int maxX = x + RandomCropOp::crop_width_;  // max dims of selected CropBox on image
+  int maxY = y + RandomCropOp::crop_height_;
+  RETURN_IF_NOT_OK(UpdateBBoxesForCrop(&(*output)[1], &boxCount, x, y, maxX, maxY));
+  return Crop(pad_image, &(*output)[0], x, y, RandomCropOp::crop_width_, RandomCropOp::crop_height_);
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h
new file mode 100644
index 0000000000..bfcd1610d3
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_with_bbox_op.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
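[Review note] The box bookkeeping here is delegated to PadBBoxes and UpdateBBoxesForCrop from image_utils. For readers unfamiliar with those helpers, this standalone sketch (not part of the patch; the struct and function names are hypothetical) shows the standard arithmetic they perform:

    #include <algorithm>

    struct Box { float x, y, w, h; };  // top-left corner + size

    // Padding on the left/top shifts every box corner by the pad amount.
    Box PadBox(Box b, float pad_left, float pad_top) {
      return {b.x + pad_left, b.y + pad_top, b.w, b.h};
    }

    // Cropping at (cx, cy) with size (cw, ch) translates the box into crop
    // coordinates and clips it to the crop window; a box left with no area
    // would be dropped by the caller, which is why boxCount can shrink.
    Box CropBox(Box b, float cx, float cy, float cw, float ch) {
      float x0 = std::max(b.x - cx, 0.0f);
      float y0 = std::max(b.y - cy, 0.0f);
      float x1 = std::min(b.x + b.w - cx, cw);
      float y1 = std::min(b.y + b.h - cy, ch);
      return {x0, y0, x1 - x0, y1 - y0};
    }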
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/kernels/image/random_crop_op.h" + +namespace mindspore { +namespace dataset { +class RandomCropWithBBoxOp : public RandomCropOp { + public: + // Constructor for RandomCropWithBBoxOp, with default value and passing to base class constructor + RandomCropWithBBoxOp(int32_t crop_height, int32_t crop_width, int32_t pad_top = kDefPadTop, + int32_t pad_bottom = kDefPadBottom, int32_t pad_left = kDefPadLeft, + int32_t pad_right = kDefPadRight, BorderType border_types = kDefBorderType, + bool pad_if_needed = kDefPadIfNeeded, uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, + uint8_t fill_b = kDefFillB) + : RandomCropOp(crop_height, crop_width, pad_top, pad_bottom, pad_left, pad_right, border_types, pad_if_needed, + fill_r, fill_g, fill_b) {} + + ~RandomCropWithBBoxOp() override = default; + + void Print(std::ostream &out) const override { + out << "RandomCropWithBBoxOp: " << RandomCropOp::crop_height_ << " " << RandomCropOp::crop_width_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomCropWithBBoxOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_CROP_WITH_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc new file mode 100644 index 0000000000..5e8ab8a634 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomHorizontalFlipOp::kDefProbability = 0.5; + +Status RandomHorizontalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (distribution_(rnd_)) { + return HorizontalFlip(input, output); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h new file mode 100644 index 0000000000..9e08929180 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomHorizontalFlipOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomHorizontalFlipOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomHorizontalFlipOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << "RandomHorizontalFlipOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomHorizontalFlipOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc new file mode 100644 index 0000000000..809f564b18 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
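[Review note] All of the random flip ops share the same probability gate: a single Bernoulli draw decides whether the transform fires, and otherwise the input passes through untouched. A minimal standalone illustration (not part of the patch; the literal seed is a stand-in for GetSeed()):

    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(0);                    // the ops seed from GetSeed()
      std::bernoulli_distribution flip(0.5);  // kDefProbability
      int fired = 0;
      for (int i = 0; i < 10000; ++i) fired += flip(rnd);
      std::printf("flipped %d of 10000 samples\n", fired);  // close to 5000
      return 0;
    }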
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <utility>
+#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/core/cv_tensor.h"
+
+namespace mindspore {
+namespace dataset {
+const float RandomHorizontalFlipWithBBoxOp::kDefProbability = 0.5;
+
+Status RandomHorizontalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);
+  if (distribution_(rnd_)) {
+    // flip every bounding box horizontally by reflecting it about the vertical center line of the image
+    size_t num_of_boxes = input[1]->shape()[0];      // number of bboxes (rows in the bbox tensor)
+    float img_center = (input[0]->shape()[1] / 2.);  // get the center of the image
+    for (size_t i = 0; i < num_of_boxes; i++) {
+      float b_w = 0;  // bounding box width
+      float min_x = 0;
+      // get the required items
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&min_x, {i, 0}));
+      RETURN_IF_NOT_OK(input[1]->GetItemAt<float>(&b_w, {i, 2}));
+      // do the flip
+      float diff = img_center - min_x;       // get distance from min_x to center
+      float refl_min_x = diff + img_center;  // get reflection of min_x
+      float new_min_x = refl_min_x - b_w;    // subtract the width from the reflected min_x to get the new one
+      RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 0}, new_min_x));
+    }
+    (*output).resize(2);
+    // move input to output pointer of bounding boxes
+    (*output)[1] = std::move(input[1]);
+    // perform HorizontalFlip on the image
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input[0]));
+    return HorizontalFlip(std::static_pointer_cast<Tensor>(input_cv), &(*output)[0]);
+  }
+  *output = input;
+  return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h
new file mode 100644
index 0000000000..d98669ea13
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
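[Review note] The three-step reflection in that loop collapses to new_min_x = 2*center - min_x - w, which maps the span [min_x, min_x + w] to its mirror image about the center. A quick numeric check, standalone and not part of the patch:

    // Mirroring [10, 30] about center 50 must give [70, 90]:
    // the mirror of 10 is 90 and the mirror of 30 is 70.
    #include <cassert>

    int main() {
      const float center = 50.0f;            // image width 100
      const float min_x = 10.0f, w = 20.0f;  // box spans [10, 30]
      float diff = center - min_x;           // same steps as in Compute()
      float refl_min_x = diff + center;
      float new_min_x = refl_min_x - w;
      assert(new_min_x == 70.0f);            // flipped box spans [70, 90]
      return 0;
    }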
+ */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomHorizontalFlipWithBBoxOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomHorizontalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomHorizontalFlipWithBBoxOp() override = default; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const RandomHorizontalFlipWithBBoxOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << "RandomHorizontalFlipWithBBoxOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomHorizontalFlipWithBBoxOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_HORIZONTAL_FLIP_BBOX_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc new file mode 100644 index 0000000000..8736f0a6a5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/image/random_resize_op.h" + +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t RandomResizeOp::kDefTargetWidth = 0; + +Status RandomResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + // Randomly selects from the following four interpolation methods + // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area + interpolation_ = static_cast(distribution_(random_generator_)); + return ResizeOp::Compute(input, output); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h new file mode 100644 index 0000000000..8b2b067751 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_op.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomResizeOp : public ResizeOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefTargetWidth; + + explicit RandomResizeOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeOp(size_1, size_2) { + random_generator_.seed(GetSeed()); + } + + ~RandomResizeOp() = default; + + // Description: A function that prints info about the node + void Print(std::ostream &out) const override { + out << "RandomResizeOp: " << ResizeOp::size1_ << " " << ResizeOp::size2_; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomResizeOp; } + + private: + std::mt19937 random_generator_; + std::uniform_int_distribution distribution_{0, 3}; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc new file mode 100644 index 0000000000..e099b78a0f --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
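[Review note] RandomResizeOp only overrides Compute to re-draw the interpolation mode before delegating to ResizeOp; the draw itself is a uniform integer in [0, 3] indexing the four modes listed in the comment. A standalone sketch (not part of the patch; the enum here is a stand-in for the dataset InterpolationMode type):

    #include <random>

    enum class Interp { kLinear, kNearestNeighbour, kCubic, kArea };

    // Draw one of the four interpolation modes uniformly at random.
    Interp DrawInterp(std::mt19937 *rnd) {
      static std::uniform_int_distribution<int> pick(0, 3);
      return static_cast<Interp>(pick(*rnd));
    }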
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0; + +Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + // Randomly selects from the following four interpolation methods + // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area + interpolation_ = static_cast(distribution_(random_generator_)); + RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output)); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h new file mode 100644 index 0000000000..6bad0d30fa --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_resize_with_bbox_op.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H +#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/image/resize_with_bbox_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RandomResizeWithBBoxOp : public ResizeWithBBoxOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefTargetWidth; + explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) { + random_generator_.seed(GetSeed()); + } + + ~RandomResizeWithBBoxOp() = default; + + // Description: A function that prints info about the node + void Print(std::ostream &out) const override { + out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_; + } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kRandomResizeWithBBoxOp; } + + private: + std::mt19937 random_generator_; + std::uniform_int_distribution distribution_{0, 3}; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc new file mode 100644 index 0000000000..b2cb4facae --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/image/random_rotation_op.h" + +#include + +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomRotationOp::kDefCenterX = -1; +const float RandomRotationOp::kDefCenterY = -1; +const InterpolationMode RandomRotationOp::kDefInterpolation = InterpolationMode::kNearestNeighbour; +const bool RandomRotationOp::kDefExpand = false; +const uint8_t RandomRotationOp::kDefFillR = 0; +const uint8_t RandomRotationOp::kDefFillG = 0; +const uint8_t RandomRotationOp::kDefFillB = 0; + +// constructor +RandomRotationOp::RandomRotationOp(float start_degree, float end_degree, float center_x, float center_y, + InterpolationMode interpolation, bool expand, uint8_t fill_r, uint8_t fill_g, + uint8_t fill_b) + : degree_start_(start_degree), + degree_end_(end_degree), + center_x_(center_x), + center_y_(center_y), + interpolation_(interpolation), + expand_(expand), + fill_r_(fill_r), + fill_g_(fill_g), + fill_b_(fill_b) { + rnd_.seed(GetSeed()); +} + +// main function call for random rotation : Generate the random degrees +Status RandomRotationOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + float random_double = distribution_(rnd_); + // get the degree rotation range, mod by 360 because full rotation doesn't affect + // the way this op works (uniform distribution) + // assumption here is that mDegreesEnd > mDegreeStart so we always get positive number + // Note: the range technically is greater than 360 degrees, but will be halved + float degree_range = (degree_end_ - degree_start_) / 2; + float mid = (degree_end_ + degree_start_) / 2; + float degree = mid + random_double * degree_range; + + return Rotate(input, output, center_x_, center_y_, degree, interpolation_, expand_, fill_r_, fill_g_, fill_b_); +} +Status RandomRotationOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + int32_t outputH = -1, outputW = -1; + // if expand_, then we cannot know the shape. We need the input image to find the output shape --> set it to + // <-1,-1[,3]> + if (!expand_) { + outputH = inputs[0][0]; + outputW = inputs[0][1]; + } + TensorShape out = TensorShape{outputH, outputW}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h new file mode 100644 index 0000000000..ea679ccb56 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_rotation_op.h @@ -0,0 +1,90 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
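[Review note] The degree sampling above is the midpoint form of a uniform draw: with u ~ U(-1, 1), mid + u * (end - start) / 2 is uniformly distributed over [start, end]. A standalone check, not part of the patch:

    #include <cstdio>
    #include <random>

    int main() {
      std::mt19937 rnd(1);  // the op seeds from GetSeed()
      std::uniform_real_distribution<float> u(-1.0, 1.0);
      const float start = -30.0f, end = 90.0f;
      const float mid = (end + start) / 2;         // 30, as in Compute()
      const float half_range = (end - start) / 2;  // 60
      float degree = mid + u(rnd) * half_range;    // always in [-30, 90]
      std::printf("degree = %.2f\n", degree);
      return 0;
    }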
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { +class RandomRotationOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefCenterX; + static const float kDefCenterY; + static const InterpolationMode kDefInterpolation; + static const bool kDefExpand; + static const uint8_t kDefFillR; + static const uint8_t kDefFillG; + static const uint8_t kDefFillB; + + // Constructor for RandomRotationOp + // @param startDegree starting range for random degree + // @param endDegree ending range for random degree + // @param centerX x coordinate for center of image rotation + // @param centerY y coordinate for center of image rotation + // @param interpolation DE interpolation mode for rotation + // @param expand option for the output image shape to change + // @param fill_r R value for the color to pad with + // @param fill_g G value for the color to pad with + // @param fill_b B value for the color to pad with + // @details the randomly chosen degree is uniformly distributed + // @details the output shape, if changed, will contain the entire rotated image + // @note maybe using unsigned long int isn't the best here according to our coding rules + RandomRotationOp(float start_degree, float end_degree, float center_x = kDefCenterX, float center_y = kDefCenterY, + InterpolationMode interpolation = kDefInterpolation, bool expand = kDefExpand, + uint8_t fill_r = kDefFillR, uint8_t fill_g = kDefFillG, uint8_t fill_b = kDefFillB); + + ~RandomRotationOp() override = default; + + // Print function for RandomRotation + // @param out output stream to print to + void Print(std::ostream &out) const override { out << "RandomRotationOp: "; } + + // Overrides the base class compute function + // Calls the rotate function in ImageUtils, this function takes an input tensor + // and transforms its data using openCV, the output memory is manipulated to contain the result + // @return Status - The error code return + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kRandomRotationOp; } + + private: + float degree_start_; + float degree_end_; + float center_x_; + float center_y_; + InterpolationMode interpolation_; + bool expand_; + uint8_t fill_r_; + uint8_t fill_g_; + uint8_t fill_b_; + std::uniform_real_distribution distribution_{-1.0, 1.0}; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_ROTATION_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc new file mode 100644 index 
0000000000..24d816ef1a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/kernels/image/random_vertical_flip_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const float RandomVerticalFlipOp::kDefProbability = 0.5; + +Status RandomVerticalFlipOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (distribution_(rnd_)) { + return VerticalFlip(input, output); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h new file mode 100644 index 0000000000..cee5869c71 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_op.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ +#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +class RandomVerticalFlipOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const float kDefProbability; + + explicit RandomVerticalFlipOp(float probability = kDefProbability) : distribution_(probability) { + rnd_.seed(GetSeed()); + } + + ~RandomVerticalFlipOp() override = default; + + void Print(std::ostream &out) const override { out << "RandomVerticalFlipOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRandomVerticalFlipOp; } + + private: + std::mt19937 rnd_; + std::bernoulli_distribution distribution_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc new file mode 100644 index 0000000000..7d2fa7bab5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h" + +namespace mindspore { +namespace dataset { +const float RandomVerticalFlipWithBBoxOp::kDefProbability = 0.5; +Status RandomVerticalFlipWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + BOUNDING_BOX_CHECK(input); + + if (distribution_(rnd_)) { + dsize_t imHeight = input[0]->shape()[0]; + size_t boxCount = input[1]->shape()[0]; // number of rows in tensor + + // one time allocation -> updated in the loop + // type defined based on VOC test dataset + for (int i = 0; i < boxCount; i++) { + float boxCorner_y = 0.0, boxHeight = 0.0; + float newBoxCorner_y = 0.0; + RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxCorner_y, {i, 1})); // get min y of bbox + RETURN_IF_NOT_OK(input[1]->GetItemAt(&boxHeight, {i, 3})); // get height of bbox + + // subtract (curCorner + height) from (max) for new Corner position + newBoxCorner_y = (imHeight - 1.0) - ((boxCorner_y + boxHeight) - 1.0); + RETURN_IF_NOT_OK(input[1]->SetItemAt({i, 1}, newBoxCorner_y)); + } + + output->resize(2); + (*output)[1] = std::move(input[1]); + + return VerticalFlip(input[0], &(*output)[0]); + } + *output = input; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h new file mode 100644 index 0000000000..c9f19f5217 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
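[Review note] The vertical-flip box update simplifies nicely: new_y = (H - 1) - ((y + h) - 1) is just H - y - h, mapping the row span [y, y + h] to its mirror about the horizontal center line. A standalone numeric check, not part of the patch:

    #include <cassert>

    int main() {
      const float H = 100.0f, y = 10.0f, h = 20.0f;  // box spans rows [10, 30]
      float new_y = (H - 1.0f) - ((y + h) - 1.0f);   // same formula as Compute()
      assert(new_y == 70.0f);                        // flipped box spans [70, 90]
      return 0;
    }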
+ */
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_
+#define DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+class RandomVerticalFlipWithBBoxOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const float kDefProbability;
+  // Constructor for RandomVerticalFlipWithBBoxOp
+  // @param probability: probability of the image being flipped, 0.5 by default
+  explicit RandomVerticalFlipWithBBoxOp(float probability = kDefProbability) : distribution_(probability) {
+    rnd_.seed(GetSeed());
+  }
+
+  ~RandomVerticalFlipWithBBoxOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "RandomVerticalFlipWithBBoxOp"; }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kRandomVerticalFlipWithBBoxOp; }
+
+ private:
+  std::mt19937 rnd_;
+  std::bernoulli_distribution distribution_;
+};
+} // namespace dataset
+} // namespace mindspore
+
+#endif // DATASET_KERNELS_IMAGE_RANDOM_VERTICAL_FLIP_WITH_BBOX_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc
new file mode 100644
index 0000000000..2a500d6c34
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/rescale_op.h"
+
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status RescaleOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  return Rescale(input, output, rescale_, shift_);
+}
+Status RescaleOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  outputs[0] = DataType(DataType::DE_FLOAT32);
+  return Status::OK();
+}
+} // namespace dataset
+} // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h
new file mode 100644
index 0000000000..c70b7bf6cf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
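[Review note] RescaleOp delegates to the image_utils Rescale() helper, which applies the per-pixel affine map out = in * rescale + shift; that is also why OutputType() reports DE_FLOAT32. A standalone sketch of the same map (not part of the patch; RescalePixel is a hypothetical name), with the common setting rescale = 1/255, shift = 0 that maps uint8 pixels into [0, 1]:

    #include <cstdint>
    #include <cstdio>

    // out = in * rescale + shift, computed in float.
    float RescalePixel(uint8_t in, float rescale, float shift) {
      return in * rescale + shift;
    }

    int main() {
      // 255 maps to approximately 1.0 (up to float rounding of 1/255).
      std::printf("%f\n", RescalePixel(255, 1.0f / 255, 0.0f));
      return 0;
    }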
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESCALE_OP_H_ +#define DATASET_KERNELS_IMAGE_RESCALE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RescaleOp : public TensorOp { + public: + RescaleOp(float rescale_ratio, float shift_ratio) : rescale_(rescale_ratio), shift_(shift_ratio) {} + + ~RescaleOp() override = default; + + void Print(std::ostream &out) const override { + out << "RescaleOp: shift: " << shift_ << ", Rescale: " << rescale_ << std::endl; + } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kRescaleOp; } + + private: + float rescale_; + float shift_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_IMAGE_RESCALE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc new file mode 100644 index 0000000000..48a8fbbc53 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/resize_bilinear_op.h" +#include + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t ResizeBilinearOp::kDefWidth = 0; + +void ResizeBilinearOp::Print(std::ostream &out) const { out << "ResizeBilinearOp: "; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h new file mode 100644 index 0000000000..fd8f940946 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_bilinear_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#define DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ + +#include +#include +#include +#include +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/resize_op.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class ResizeBilinearOp : public ResizeOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefWidth; + + // Name: constructor + // Resizes the image to the output specified size using Bilinear interpolation. + // If only one value is provided, the it will resize the smaller size and maintains + // the aspect ratio. + // @param size1: the first size of output. If only this parameter is provided + // the smaller dimension will be resized to this and then the other dimension changes + // such that the aspect ratio is maintained. + // @param size2: the second size of output. If this is also provided, the output size + // will be (size1, size2) + explicit ResizeBilinearOp(int32_t size1, int32_t size2 = kDefWidth) + : ResizeOp(size1, size2, ResizeOp::kDefInterpolation) {} + + // Name: Destructor + // Description: Destructor + ~ResizeBilinearOp() = default; + + // Name: Print() + // Description: A function that prints info about the node + void Print(std::ostream &out) const override; + + std::string Name() const override { return kResizeBilinearOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc new file mode 100644 index 0000000000..7456f50f32 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/kernels/image/resize_op.h" + +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +const int32_t ResizeOp::kDefWidth = 0; +const InterpolationMode ResizeOp::kDefInterpolation = InterpolationMode::kLinear; + +Status ResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape size " + std::to_string(input->shape().Size()) + + " of input tensor is invalid"); + int32_t output_h, output_w = 0; + int32_t input_h = static_cast(input->shape()[0]); + int32_t input_w = static_cast(input->shape()[1]); + if (size2_ == 0) { + if (input_h < input_w) { + CHECK_FAIL_RETURN_UNEXPECTED(input_h != 0, "The input height is 0"); + output_h = size1_; + output_w = static_cast(std::lround(static_cast(input_w) / input_h * output_h)); + } else { + CHECK_FAIL_RETURN_UNEXPECTED(input_w != 0, "The input width is 0"); + output_w = size1_; + output_h = static_cast(std::lround(static_cast(input_h) / input_w * output_w)); + } + } else { + output_h = size1_; + output_w = size2_; + } + return Resize(input, output, output_h, output_w, 0, 0, interpolation_); +} + +Status ResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + int32_t outputH = -1, outputW = -1; + // if size2_ == 0, then we cannot know the shape. We need the input image to find the output shape --> set it to + // <-1,-1[,3]> + if (size2_ != 0) { + outputH = size1_; + outputW = size2_; + } + TensorShape out = TensorShape{outputH, outputW}; + if (inputs[0].Rank() == 2) outputs.emplace_back(out); + if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h new file mode 100644 index 0000000000..3f847243ff --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h @@ -0,0 +1,68 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_ +#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class ResizeOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + static const int32_t kDefWidth; + static const InterpolationMode kDefInterpolation; + + // Resizes the image to the output specified size. 
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h
new file mode 100644
index 0000000000..3f847243ff
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h
@@ -0,0 +1,68 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RESIZE_OP_H_
+#define DATASET_KERNELS_IMAGE_RESIZE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class ResizeOp : public TensorOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const int32_t kDefWidth;
+  static const InterpolationMode kDefInterpolation;
+
+  // Resizes the image to the specified output size. If only one value is provided,
+  // the smaller dimension is resized to it and the aspect ratio is maintained.
+  // @param size1: the first output size. If only this parameter is provided,
+  //     the smaller dimension will be resized to it and the other dimension changes
+  //     such that the aspect ratio is maintained.
+  // @param size2: the second output size. If this is also provided, the output size
+  //     will be (size1, size2)
+  // @param mInterpolation: the interpolation mode being used.
+  explicit ResizeOp(int32_t size1, int32_t size2 = kDefWidth, InterpolationMode mInterpolation = kDefInterpolation)
+      : size1_(size1), size2_(size2), interpolation_(mInterpolation) {}
+
+  ResizeOp(const ResizeOp &rhs) = default;
+
+  ResizeOp(ResizeOp &&rhs) = default;
+
+  ~ResizeOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "ResizeOp: " << size1_ << " " << size2_; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
+
+  std::string Name() const override { return kResizeOp; }
+
+ protected:
+  int32_t size1_;
+  int32_t size2_;
+  InterpolationMode interpolation_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RESIZE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc
new file mode 100644
index 0000000000..9df2d8a25e
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
+#include <memory>
+#include <utility>
+#include "minddata/dataset/kernels/image/resize_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/pybind_support.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);
+
+  int32_t input_h = input[0]->shape()[0];
+  int32_t input_w = input[0]->shape()[1];
+
+  output->resize(2);
+  (*output)[1] = std::move(input[1]);  // move boxes over to output
+
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input[0]));
+
+  RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast<Tensor>(input_cv), &(*output)[0]));
+
+  int32_t output_h = (*output)[0]->shape()[0];  // output height of ResizeWithBBox
+  int32_t output_w = (*output)[0]->shape()[1];  // output width of ResizeWithBBox
+
+  size_t bboxCount = input[1]->shape()[0];  // number of rows in bbox tensor
+  RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
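UpdateBBoxesForResize is declared in image_utils.h and is not part of this hunk; the following standalone sketch shows the proportional rescale such a helper is expected to apply, assuming boxes are stored as (min_x, min_y, width, height) rows as enforced by BOUNDING_BOX_CHECK (all names here are illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

// A resize must rescale each (min_x, min_y, width, height) box by the same
// ratios that were applied to the image itself.
struct BBox { float x, y, w, h; };

void RescaleBoxes(std::vector<BBox> *boxes, int32_t out_w, int32_t out_h, int32_t in_w, int32_t in_h) {
  const float rx = static_cast<float>(out_w) / in_w;
  const float ry = static_cast<float>(out_h) / in_h;
  for (BBox &b : *boxes) {
    b.x *= rx; b.w *= rx;  // horizontal coordinates scale with the width ratio
    b.y *= ry; b.h *= ry;  // vertical coordinates scale with the height ratio
  }
}

int main() {
  std::vector<BBox> boxes = {{64, 48, 128, 96}};
  RescaleBoxes(&boxes, 320, 240, 640, 480);  // image halved in both dimensions
  std::printf("%.0f %.0f %.0f %.0f\n", boxes[0].x, boxes[0].y, boxes[0].w, boxes[0].h);  // 32 24 64 48
  return 0;
}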
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h
new file mode 100644
index 0000000000..d2b5c96bf3
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_with_bbox_op.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
+#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
+
+#include <memory>
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/kernels/image/resize_op.h"
+
+namespace mindspore {
+namespace dataset {
+class ResizeWithBBoxOp : public ResizeOp {
+ public:
+  // Constructor for ResizeWithBBoxOp, with default values passed to the base class constructor
+  explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth,
+                            InterpolationMode mInterpolation = kDefInterpolation)
+      : ResizeOp(size_1, size_2, mInterpolation) {}
+
+  ~ResizeWithBBoxOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kResizeWithBBoxOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc
new file mode 100644
index 0000000000..95d75af0f2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.cc
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#include <algorithm>
+#include "minddata/dataset/kernels/image/uniform_aug_op.h"
+#include "minddata/dataset/util/random.h"
+
+namespace mindspore {
+namespace dataset {
+const int UniformAugOp::kDefNumOps = 2;
+
+UniformAugOp::UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops)
+    : tensor_op_list_(op_list), num_ops_(num_ops) {
+  rnd_.seed(GetSeed());
+}
+
+// compute method to apply uniformly randomly selected augmentations from a list
+Status UniformAugOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+
+  // randomly select ops to be applied
+  std::vector<std::shared_ptr<TensorOp>> selected_tensor_ops;
+  std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_,
+              rnd_);
+
+  bool first = true;
+  for (const auto &tensor_op : selected_tensor_ops) {
+    // Skip the op when the second random draw is nonzero, i.e. apply it only when the draw is zero
+    if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
+      continue;
+    }
+    // apply C++ ops (note: python OPs are not accepted)
+    if (first) {
+      RETURN_IF_NOT_OK(tensor_op->Compute(input, output));
+      first = false;
+    } else {
+      RETURN_IF_NOT_OK(tensor_op->Compute(std::move(*output), output));
+    }
+  }
+
+  // The case where no tensor op is applied.
+  if (output->empty()) {
+    *output = input;
+  }
+
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h
new file mode 100644
index 0000000000..0ae0fda92b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/uniform_aug_op.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+
+ * http://www.apache.org/licenses/LICENSE-2.0
+
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+*/
+#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
+#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
+
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class UniformAugOp : public TensorOp {
+ public:
+  // Default number of Operations to be applied
+  static const int kDefNumOps;
+
+  // Constructor for UniformAugOp
+  // @param std::vector<std::shared_ptr<TensorOp>> op_list: list of candidate C++ operations
+  // @param int32_t num_ops: number of augmentation operations to be applied
+  UniformAugOp(std::vector<std::shared_ptr<TensorOp>> op_list, int32_t num_ops);
+
+  // Destructor
+  ~UniformAugOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }
+
+  // Overrides the base class compute function
+  // @return Status - The error code returned
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kUniformAugOp; }
+
+ private:
+  int32_t num_ops_;
+  std::vector<std::shared_ptr<TensorOp>> tensor_op_list_;
+  std::mt19937 rnd_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
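Two sources of randomness drive UniformAugOp::Compute above: std::sample draws num_ops_ candidate ops without replacement, and an independent coin flip then decides whether each sampled op actually runs. A standalone sketch of that selection policy on plain functions (all names illustrative):

#include <algorithm>
#include <functional>
#include <iostream>
#include <random>
#include <vector>

int main() {
  std::vector<std::function<int(int)>> ops = {
      [](int x) { return x + 1; },
      [](int x) { return x * 2; },
      [](int x) { return x - 3; },
  };
  std::mt19937 rnd(42);

  // Draw 2 ops without replacement, like std::sample over tensor_op_list_.
  std::vector<std::function<int(int)>> selected;
  std::sample(ops.begin(), ops.end(), std::back_inserter(selected), 2, rnd);

  int value = 10;
  for (const auto &op : selected) {
    // Independent 50% coin flip: the op is applied only when the draw is 0.
    if (std::uniform_int_distribution<int>(0, 1)(rnd)) continue;
    value = op(value);
  }
  std::cout << value << "\n";  // the input passes through unchanged if every op was skipped
  return 0;
}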
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/no_op.h b/mindspore/ccsrc/minddata/dataset/kernels/no_op.h
new file mode 100644
index 0000000000..f5a6a58f2b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/no_op.h
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_NO_OP_H_
+#define DATASET_KERNELS_NO_OP_H_
+
+#include <memory>
+#include <string>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class NoOp : public TensorOp {
+ public:
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override {
+    *output = input;
+    return Status::OK();
+  }
+
+  void Print(std::ostream &out) const override { out << "NoOp"; }
+
+  std::string Name() const override { return kNoOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_KERNELS_NO_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
new file mode 100644
index 0000000000..f501dd4b4f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.cc
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/py_func_op.h"
+
+#include <memory>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  Status ret = Status(StatusCode::kOK, "PyFunc call succeeded");
+  {
+    // Acquire Python GIL
+    py::gil_scoped_acquire gil_acquire;
+    if (Py_IsInitialized() == 0) {
+      ret = Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
+      goto ComputeReturn;
+    }
+    try {
+      // Transform input tensor vector into numpy array vector
+      py::tuple input_args(input.size());
+      for (size_t i = 0; i < input.size(); i++) {
+        py::array new_data;
+        RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
+        // possible memcpy here
+        input_args[i] = new_data;
+      }
+      // Invoke python function
+      py::object ret_py_obj = this->py_func_ptr_(*input_args);
+      // Process the return value
+      if (py::isinstance<py::array>(ret_py_obj)) {
+        // In case of a n-1 mapping, the return value will be a numpy array
+        std::shared_ptr<Tensor> out;
+        RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_obj.cast<py::array>()));
+        output->push_back(out);
+      } else if (py::isinstance<py::tuple>(ret_py_obj)) {
+        // In case of a n-m mapping, the return value will be a tuple of numpy arrays
+        py::tuple ret_py_tuple = ret_py_obj.cast<py::tuple>();
+        // Iterate over two containers simultaneously for memory copy
+        for (size_t i = 0; i < ret_py_tuple.size(); i++) {
+          py::object ret_py_ele = ret_py_tuple[i];
+          if (!py::isinstance<py::array>(ret_py_ele)) {
+            goto ShapeMisMatch;
+          }
+          std::shared_ptr<Tensor> out;
+          RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, ret_py_ele.cast<py::array>()));
+          output->push_back(out);
+        }
+      } else {
+        goto ShapeMisMatch;
+      }
+    } catch (const py::error_already_set &e) {
+      ret = Status(StatusCode::kPyFuncException, e.what());
+    }
+  }
+
+ComputeReturn:
+  return ret;
+
+ShapeMisMatch:
+  ret = Status(StatusCode::kShapeMisMatch, "PyFunc should return a numpy array or a numpy array tuple");
+  goto ComputeReturn;
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h
new file mode 100644
index 0000000000..75d222b433
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/py_func_op.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_PY_FUNC_OP_H_
+#define DATASET_KERNELS_PY_FUNC_OP_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class __attribute__((visibility("hidden"))) PyFuncOp : public TensorOp {
+ public:
+  explicit PyFuncOp(py::function func) : py_func_ptr_(std::move(func)) {}
+
+  ~PyFuncOp() override = default;
+
+  uint32_t NumInput() override { return 0; }
+  uint32_t NumOutput() override { return 0; }
+
+  // Compute function for n-n mapping.
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kPyFuncOp; }
+
+ private:
+  py::function py_func_ptr_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_PY_FUNC_OP_H_
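PyFuncOp::Compute above encodes an n-to-m convention: a single numpy array return is one output column, while a tuple of arrays yields several. A self-contained sketch of the same convention using pybind11's embedding API (only an illustration; the op itself runs inside an already-initialized interpreter rather than embedding one):

#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <iostream>
#include <vector>

namespace py = pybind11;

int main() {
  py::scoped_interpreter guard;  // embedded interpreter, just for this demo
  py::gil_scoped_acquire gil;    // PyFuncOp::Compute acquires the GIL the same way

  // A Python callable returning a tuple of arrays -> several output columns.
  py::object fn = py::eval("lambda a: (a + 1, a * 2)");
  std::vector<int> data{1, 2, 3};
  py::array arg = py::array_t<int>({3}, data.data());

  py::object ret = fn(arg);
  std::vector<py::array> outputs;
  if (py::isinstance<py::array>(ret)) {
    outputs.push_back(ret.cast<py::array>());  // 1-to-1 mapping
  } else if (py::isinstance<py::tuple>(ret)) {
    for (auto item : ret.cast<py::tuple>()) {  // 1-to-m mapping
      outputs.push_back(item.cast<py::array>());
    }
  }
  std::cout << "columns: " << outputs.size() << "\n";  // prints columns: 2
  return 0;
}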
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc
new file mode 100644
index 0000000000..b625e3b532
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/tensor_op.h"
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace mindspore {
+namespace dataset {
+// Name: Compute()
+// Description: This Compute() takes one Tensor and produces one Tensor.
+//              The derived class should override this function; otherwise an error is returned.
+Status TensorOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  if (!OneToOne()) {
+    return Status(StatusCode::kUnexpectedError, "Wrong Compute() function is called. This is not a 1-1 TensorOp.");
+  } else {
+    return Status(StatusCode::kUnexpectedError,
+                  "Is this TensorOp 1-1? If yes, please implement this Compute() in the derived class.");
+  }
+}
+
+// Name: Compute()
+// Description: This Compute() takes multiple Tensors from different columns and produces multiple Tensors too.
+//              The derived class should override this function; otherwise an error is returned.
+Status TensorOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  if (OneToOne()) {
+    output->resize(1);
+    return Compute(input[0], &(*output)[0]);
+  }
+
+  return Status(StatusCode::kUnexpectedError,
+                "Is this TensorOp oneToOne? If no, please implement this Compute() in the derived class.");
+}
+
+void TensorOp::Print(std::ostream &out) const { out << "TensorOp" << std::endl; }
+
+Status TensorOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
+  if (inputs.size() != NumInput())
+    return Status(StatusCode::kUnexpectedError,
+                  "The size of the input argument vector does not match the number of inputs");
+  outputs = inputs;
+  return Status::OK();
+}
+
+Status TensorOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  if (inputs.size() != NumInput())
+    return Status(StatusCode::kUnexpectedError,
+                  "The size of the input argument vector does not match the number of inputs");
+  outputs = inputs;
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
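The two base Compute overloads above cooperate: the row variant detects a 1-1 op and forwards to the single-tensor hook, so most ops only override the latter. A toy model of that dispatch with the Status and Tensor machinery stripped out (every type here is an illustrative stand-in):

#include <iostream>
#include <memory>
#include <vector>

using Cell = std::shared_ptr<int>;  // stand-in for std::shared_ptr<Tensor>
using Row = std::vector<Cell>;      // stand-in for TensorRow

struct Op {
  virtual ~Op() = default;
  virtual unsigned NumInput() { return 1; }
  virtual unsigned NumOutput() { return 1; }
  bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; }

  virtual void Compute(const Cell &in, Cell *out) { *out = in; }  // 1-1 hook

  // Row variant forwards to the 1-1 hook, mirroring TensorOp::Compute(TensorRow...).
  virtual void Compute(const Row &in, Row *out) {
    if (OneToOne()) {
      out->resize(1);
      return Compute(in[0], &(*out)[0]);
    }
    // n-m ops must override this variant themselves.
  }
};

struct PlusOne : Op {
  void Compute(const Cell &in, Cell *out) override { *out = std::make_shared<int>(*in + 1); }
  using Op::Compute;  // keep the row overload visible
};

int main() {
  PlusOne op;
  Row in{std::make_shared<int>(41)}, out;
  op.Compute(in, &out);
  std::cout << *out[0] << "\n";  // 42
  return 0;
}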
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
new file mode 100644
index 0000000000..3bcba4b463
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
@@ -0,0 +1,212 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_TENSOR_OP_H_
+#define DATASET_KERNELS_TENSOR_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_row.h"
+#include "minddata/dataset/util/status.h"
+
+#define IO_CHECK(input, output) \
+  do { \
+    if (input == nullptr || output == nullptr) { \
+      RETURN_STATUS_UNEXPECTED("input or output is null."); \
+    } \
+  } while (false)
+
+#define IO_CHECK_VECTOR(input, output) \
+  do { \
+    if (output == nullptr) { \
+      RETURN_STATUS_UNEXPECTED("output is null."); \
+    } \
+    for (auto &_i : input) { \
+      if (_i == nullptr) { \
+        RETURN_STATUS_UNEXPECTED("input is null."); \
+      } \
+    } \
+  } while (false)
+
+#define BOUNDING_BOX_CHECK(input) \
+  do { \
+    if (input.size() != 2) { \
+      return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \
+                    "Requires Image and Bounding Boxes, likely missed bounding boxes."); \
+    } \
+    if (input[1]->shape().Size() < 2) { \
+      return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \
+                    "Bounding boxes shape should have at least two dimensions."); \
+    } \
+    uint32_t num_of_features = input[1]->shape()[1]; \
+    if (num_of_features < 4) { \
+      return Status(StatusCode::kBoundingBoxInvalidShape, __LINE__, __FILE__, \
+                    "Bounding boxes should have at least 4 features."); \
+    } \
+    uint32_t num_of_boxes = input[1]->shape()[0]; \
+    uint32_t img_h = input[0]->shape()[0]; \
+    uint32_t img_w = input[0]->shape()[1]; \
+    for (uint32_t i = 0; i < num_of_boxes; i++) { \
+      float min_x = 0.0, min_y = 0.0, b_w = 0.0, b_h = 0.0; \
+      bool passing_data_fetch = true; \
+      passing_data_fetch &= input[1]->GetItemAt<float>(&min_x, {i, 0}).IsOk(); \
+      passing_data_fetch &= input[1]->GetItemAt<float>(&min_y, {i, 1}).IsOk(); \
+      passing_data_fetch &= input[1]->GetItemAt<float>(&b_w, {i, 2}).IsOk(); \
+      passing_data_fetch &= input[1]->GetItemAt<float>(&b_h, {i, 3}).IsOk(); \
+      if (!passing_data_fetch) { \
+        return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, \
+                      "Fetching BBox values failed in BOUNDING_BOX_CHECK."); \
+      } \
+      if ((min_x + b_w > img_w) || (min_y + b_h > img_h)) { \
+        return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \
+                      "At least one of the bounding boxes is out of bounds of the image."); \
+      } \
+      if (static_cast<int>(min_x) < 0 || static_cast<int>(min_y) < 0) { \
+        return Status(StatusCode::kBoundingBoxOutOfBounds, __LINE__, __FILE__, \
+                      "At least one of the bounding boxes has negative min_x or min_y."); \
+      } \
+    } \
+  } while (false)
+
+namespace mindspore {
+namespace dataset {
+
+// image
+constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
+constexpr char kDecodeOp[] = "DecodeOp";
+constexpr char kCenterCropOp[] = "CenterCropOp";
+constexpr char kCutOutOp[] = "CutOutOp";
+constexpr char kHwcToChwOp[] = "HwcToChwOp";
+constexpr char kNormalizeOp[] = "NormalizeOp";
+constexpr char kPadOp[] = "PadOp";
+constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
+constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";
+constexpr char kRandomCropAndResizeWithBBoxOp[] = "RandomCropAndResizeWithBBoxOp";
+constexpr char kRandomCropDecodeResizeOp[] = "RandomCropDecodeResizeOp";
+constexpr char kRandomCropOp[] = "RandomCropOp";
+constexpr char kRandomCropWithBBoxOp[] = "RandomCropWithBBoxOp";
+constexpr char kRandomHorizontalFlipWithBBoxOp[] = "RandomHorizontalFlipWithBBoxOp";
+constexpr char kRandomHorizontalFlipOp[] = "RandomHorizontalFlipOp";
+constexpr char kRandomResizeOp[] = "RandomResizeOp";
+constexpr char kRandomResizeWithBBoxOp[] = "RandomResizeWithBBoxOp";
= "RandomResizeWithBBoxOp"; +constexpr char kRandomRotationOp[] = "RandomRotationOp"; +constexpr char kRandomVerticalFlipOp[] = "RandomVerticalFlipOp"; +constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp"; +constexpr char kRescaleOp[] = "RescaleOp"; +constexpr char kResizeBilinearOp[] = "ResizeBilinearOp"; +constexpr char kResizeOp[] = "ResizeOp"; +constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; +constexpr char kUniformAugOp[] = "UniformAugOp"; + +// text +constexpr char kBasicTokenizerOp[] = "BasicTokenizerOp"; +constexpr char kBertTokenizerOp[] = "BertTokenizerOp"; +constexpr char kCaseFoldOp[] = "CaseFoldOp"; +constexpr char kJiebaTokenizerOp[] = "JiebaTokenizerOp"; +constexpr char kLookupOp[] = "LookupOp"; +constexpr char kNgramOp[] = "NgramOp"; +constexpr char kNormalizeUTF8Op[] = "NormalizeUTF8Op"; +constexpr char kRegexReplaceOp[] = "RegexReplaceOp"; +constexpr char kRegexTokenizerOp[] = "RegexTokenizerOp"; +constexpr char kToNumberOp[] = "ToNumberOp"; +constexpr char kTruncateSequencePairOp[] = "TruncateSequencePairOp"; +constexpr char kUnicodeCharTokenizerOp[] = "UnicodeCharTokenizerOp"; +constexpr char kUnicodeScriptTokenizerOp[] = "UnicodeScriptTokenizerOp"; +constexpr char kWhitespaceTokenizerOp[] = "WhitespaceTokenizerOp"; +constexpr char kWordpieceTokenizerOp[] = "WordpieceTokenizerOp"; + +// data +constexpr char kConcatenateOp[] = "kConcatenateOp"; +constexpr char kDuplicateOp[] = "DuplicateOp"; +constexpr char kFillOp[] = "FillOp"; +constexpr char kMaskOp[] = "MaskOp"; +constexpr char kOneHotOp[] = "OneHotOp"; +constexpr char kPadEndOp[] = "PadEndOp"; +constexpr char kSliceOp[] = "SliceOp"; +constexpr char kToFloat16Op[] = "ToFloat16Op"; +constexpr char kTypeCastOp[] = "TypeCastOp"; + +// other +constexpr char kPyFuncOp[] = "PyFuncOp"; +constexpr char kNoOp[] = "NoOp"; + +// A class that does a computation on a Tensor +class TensorOp { + public: + TensorOp() = default; + + virtual ~TensorOp() = default; + + // A function that prints info about the tensor operation + // @param out + virtual void Print(std::ostream &out) const; + + // Provide stream operator for displaying it + // @param output stream + // @param so the TensorOp object to be printed + // @return output stream + friend std::ostream &operator<<(std::ostream &out, const TensorOp &so) { + so.Print(out); + return out; + } + + // Perform an operation on one Tensor and produce one Tensor. This is for 1-to-1 column MapOp + // @param input shares the ownership of the Tensor (increase the ref count). + // @param output the address to a shared_ptr where the result will be placed. + // @return Status + virtual Status Compute(const std::shared_ptr &input, std::shared_ptr *output); + + // Perform an operation on Tensors from multiple columns, and produce multiple Tensors. + // This is for m-to-n column MapOp. + // @param input is a vector of shared_ptr to Tensor (pass by const reference). + // @param output is the address to an empty vector of shared_ptr to Tensor. + // @return Status + virtual Status Compute(const TensorRow &input, TensorRow *output); + + // Returns true oif the TensorOp takes one input and returns one output. + // @return true/false + bool OneToOne() { return NumInput() == 1 && NumOutput() == 1; } + + // Function to determine the number of inputs the TensorOp can take. 0: means undefined. + // @return uint32_t + virtual uint32_t NumInput() { return 1; } + + // Function to determine the number of output the TensorOp generates. 0: means undefined. 
+  // @return uint32_t
+  virtual uint32_t NumOutput() { return 1; }
+
+  // Function to determine the shapes of the output tensor given the input tensors' shapes.
+  // If a subclass did not override this function, it means that the shape does not change.
+  // @param inputs in: vector of the shapes of the input tensors.
+  // @param outputs out: vector of the shapes of the output tensors to be filled.
+  // @return Status
+  virtual Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs);
+
+  // Function to determine the types of the output tensor given the input tensor's types.
+  // If a subclass did not override this function, it means that the type does not change.
+  // @param inputs in: vector of the types of the input tensors.
+  // @param outputs out: vector of the types of the output tensors to be filled.
+  // @return Status
+  virtual Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs);
+
+  virtual std::string Name() const = 0;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_TENSOR_OP_H_
diff --git a/mindspore/ccsrc/dataset/text/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/text/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/text/CMakeLists.txt
diff --git a/mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/dataset/text/kernels/CMakeLists.txt
rename to mindspore/ccsrc/minddata/dataset/text/kernels/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc
new file mode 100644
index 0000000000..6195572944
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.cc
@@ -0,0 +1,173 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
+#include <memory>
+#include <queue>
+#include <string>
+#include <string_view>
+#include <unordered_set>
+#include <utility>
+
+#include "unicode/errorcode.h"
+#include "unicode/normalizer2.h"
+#include "unicode/utypes.h"
+
+namespace mindspore {
+namespace dataset {
+
+const bool BasicTokenizerOp::kDefLowerCase = false;
+const bool BasicTokenizerOp::kDefKeepWhitespace = false;
+const NormalizeForm BasicTokenizerOp::kDefNormalizationForm = NormalizeForm::kNone;
+const bool BasicTokenizerOp::kDefPreserveUnusedToken = true;
+const bool BasicTokenizerOp::kDefWithOffsets = false;
+const char BasicTokenizerOp::kCommonPattern[] =
+  "[!-/]"
+  "|[:-@]"
+  "|[\\[-`]"
+  "|[{-~]"
+  "|[\\p{P}]"
+  "|[\\x{4E00}-\\x{9FFF}]"
+  "|[\\x{3400}-\\x{4DBF}]"
+  "|[\\x{20000}-\\x{2A6DF}]"
+  "|[\\x{2A700}-\\x{2B73F}]"
+  "|[\\x{2B740}-\\x{2B81F}]"
+  "|[\\x{2B820}-\\x{2CEAF}]"
+  "|[\\x{F900}-\\x{FAFF}]"
+  "|[\\x{2F800}-\\x{2FA1F}]";
+const char BasicTokenizerOp::kUnusedPattern[] = "\\[CLS\\]|\\[SEP\\]|\\[UNK\\]|\\[PAD\\]|\\[MASK\\]|\\[unused\\d+\\]|";
+const std::unordered_set<std::string> BasicTokenizerOp::kUnusedWords{"[CLS]", "[SEP]", "[UNK]", "[PAD]", "[MASK]"};
+
+BasicTokenizerOp::BasicTokenizerOp(const bool &lower_case, const bool &keep_whitespace,
+                                   const NormalizeForm &normalization_form, const bool &preserve_unused_token,
+                                   const bool &with_offsets)
+    : lower_case_(lower_case),
+      keep_whitespace_(keep_whitespace),
+      preserve_unused_token_(preserve_unused_token),
+      with_offsets_(with_offsets),
+      case_fold_(std::make_unique<CaseFoldOp>()),
+      nfd_normalize_(std::make_unique<NormalizeUTF8Op>(NormalizeForm::kNfd)),
+      normalization_form_(normalization_form),
+      common_normalize_(std::make_unique<NormalizeUTF8Op>(normalization_form)),
+      replace_accent_chars_(std::make_unique<RegexReplaceOp>("\\p{Mn}", "")),
+      replace_control_chars_(std::make_unique<RegexReplaceOp>("\\p{Cc}|\\p{Cf}", " ")) {
+  std::string delim_pattern = std::string("\\s+|") + kCommonPattern;
+  std::string keep_delim_pattern;
+  if (keep_whitespace_) {
+    keep_delim_pattern = delim_pattern;
+  } else {
+    keep_delim_pattern = kCommonPattern;
+  }
+  if (preserve_unused_token_) {
+    keep_delim_pattern = kUnusedPattern + keep_delim_pattern;
+    delim_pattern = kUnusedPattern + delim_pattern;
+  }
+  regex_tokenizer_ = std::make_unique<RegexTokenizerOp>(delim_pattern, keep_delim_pattern, with_offsets_);
+}
+
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::string_view &text,
+                                                    const std::unordered_set<std::string> &unused_words,
+                                                    std::string *output) {
+  icu::ErrorCode error;
+  const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
+  CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
+  output->clear();
+
+  // 1. get start and end offsets of the substrings that must not be case folded
+  std::queue<std::pair<int, int>> offsets;  // offsets of unused words
+  int start = -1;
+  int len = 0;
+  for (int i = 0; i < text.length(); i++) {
+    if (text[i] == '[') {
+      start = i;
+      ++len;
+    } else if (text[i] == ']' && start >= 0) {
+      ++len;
+      std::string word(text.substr(start, len));
+      if (unused_words.find(word) != unused_words.end()) {
+        offsets.push(std::make_pair(start, start + len - 1));
+      }
+      start = -1;
+      len = 0;
+    } else if (start >= 0) {
+      ++len;
+    }
+  }
+
+  // 2. Do not apply case fold on `unused_words`
+  start = 0;
+  for (int i = 0; i < text.length();) {
+    std::string_view process_text;
+    std::string preserve_token;
+    if (offsets.empty()) {
+      i = text.length();
+      process_text = text.substr(start, i - start);
+    } else {
+      preserve_token = text.substr(offsets.front().first, offsets.front().second - offsets.front().first + 1);
+      process_text = text.substr(start, offsets.front().first - start);
+      i = offsets.front().second + 1;
+      offsets.pop();
+    }
+    std::string temp;
+    icu::StringByteSink<std::string> sink(&temp);
+    nfkc_case_fold->normalizeUTF8(0, icu::StringPiece(process_text.data(), process_text.size()), sink, nullptr,
+                                  error);
+    *output += temp + preserve_token;
+  }
+  return Status::OK();
+}
+
+Status BasicTokenizerOp::CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input,
+                                                    std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::vector<std::string> strs(input->Size());
+  int i = 0;
+  for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
+    RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(*iter, kUnusedWords, &strs[i++]));
+  }
+  *output = std::make_shared<Tensor>(std::move(strs), input->shape());
+  return Status::OK();
+}
+
+Status BasicTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor");
+  if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
+    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor");
+  }
+  std::shared_ptr<Tensor> cur_input;
+  std::shared_ptr<Tensor> processed_tensor;
+  if (lower_case_) {
+    if (!preserve_unused_token_) {
+      // to lower case
+      RETURN_IF_NOT_OK(case_fold_->Compute(input[0], &processed_tensor));
+    } else {
+      // to lower case except words in kUnusedWords
+      RETURN_IF_NOT_OK(CaseFoldWithoutUnusedWords(input[0], &processed_tensor));
+    }
+    cur_input = processed_tensor;
+    // strip accent characters
+    RETURN_IF_NOT_OK(nfd_normalize_->Compute(cur_input, &processed_tensor));
+    cur_input = processed_tensor;
+    RETURN_IF_NOT_OK(replace_accent_chars_->Compute(cur_input, &processed_tensor));
+  } else {
+    RETURN_IF_NOT_OK(common_normalize_->Compute(input[0], &processed_tensor));
+  }
+  // strip control characters
+  cur_input = processed_tensor;
+  RETURN_IF_NOT_OK(replace_control_chars_->Compute(cur_input, &processed_tensor));
+  return regex_tokenizer_->Compute(TensorRow(0, {std::move(processed_tensor)}), output);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h
new file mode 100644
index 0000000000..cbc21273c2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/basic_tokenizer_op.h
@@ -0,0 +1,77 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
+#define DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
+#include <memory>
+#include <string>
+#include <unordered_set>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/case_fold_op.h"
+#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
+#include "minddata/dataset/text/kernels/regex_replace_op.h"
+#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class BasicTokenizerOp : public TensorOp {
+ public:
+  static const bool kDefLowerCase;
+  static const bool kDefKeepWhitespace;
+  static const NormalizeForm kDefNormalizationForm;
+  static const bool kDefPreserveUnusedToken;
+  static const bool kDefWithOffsets;
+
+  explicit BasicTokenizerOp(const bool &lower_case = kDefLowerCase, const bool &keep_whitespace = kDefKeepWhitespace,
+                            const NormalizeForm &normalization_form = kDefNormalizationForm,
+                            const bool &preserve_unused_token = kDefPreserveUnusedToken,
+                            const bool &with_offsets = kDefWithOffsets);
+
+  ~BasicTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "BasicTokenizerOp"; }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+ protected:
+  Status CaseFoldWithoutUnusedWords(const std::string_view &text, const std::unordered_set<std::string> &unused_words,
+                                    std::string *output);
+  Status CaseFoldWithoutUnusedWords(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
+  std::string Name() const override { return kBasicTokenizerOp; }
+
+ private:
+  static const char kCommonPattern[];
+  static const char kUnusedPattern[];
+  static const std::unordered_set<std::string> kUnusedWords;
+  bool with_offsets_;
+  bool lower_case_;
+  bool keep_whitespace_;
+  NormalizeForm normalization_form_;
+  bool preserve_unused_token_;
+  std::unique_ptr<CaseFoldOp> case_fold_;
+  std::unique_ptr<NormalizeUTF8Op> nfd_normalize_;
+  std::unique_ptr<NormalizeUTF8Op> common_normalize_;
+  std::unique_ptr<RegexReplaceOp> replace_accent_chars_;
+  std::unique_ptr<RegexReplaceOp> replace_control_chars_;
+  std::unique_ptr<RegexTokenizerOp> regex_tokenizer_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_TEXT_KERNELS_BASIC_TOKENIZER_OP_H_
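BasicTokenizerOp's Compute is a fixed staging pipeline: optional case fold, NFD, accent stripping (\p{Mn}), control-character stripping (\p{Cc}|\p{Cf}), then a regex split. A rough standalone sketch of the staging idea on plain strings, with std::regex standing in for the ICU-backed ops (illustrative only, not the op's actual Unicode behavior):

#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Each stage consumes the previous stage's output, mirroring how Compute()
// threads cur_input/processed_tensor through the member ops.
std::string StripControlChars(const std::string &s) {
  return std::regex_replace(s, std::regex("[\\x00-\\x1f]"), " ");
}

std::vector<std::string> SplitOnWhitespace(const std::string &s) {
  std::vector<std::string> tokens;
  std::regex ws("\\s+");
  for (std::sregex_token_iterator it(s.begin(), s.end(), ws, -1), end; it != end; ++it) {
    if (it->length() > 0) tokens.push_back(*it);
  }
  return tokens;
}

int main() {
  std::string text = "hello\tworld\none";
  for (const auto &tok : SplitOnWhitespace(StripControlChars(text))) std::cout << tok << "\n";
  return 0;
}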
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc
new file mode 100644
index 0000000000..631597ba24
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.cc
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/text/kernels/bert_tokenizer_op.h"
+namespace mindspore {
+namespace dataset {
+Status BertTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  TensorRow basic_tensor;
+  RETURN_IF_NOT_OK(basic_tokenizer_.Compute(input, &basic_tensor));
+  RETURN_IF_NOT_OK(wordpiece_tokenizer_.Compute(basic_tensor, output));
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h
new file mode 100644
index 0000000000..b281903349
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/bert_tokenizer_op.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_
+#define DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_
+#include <memory>
+#include <string>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class BertTokenizerOp : public TensorOp {
+ public:
+  explicit BertTokenizerOp(const std::shared_ptr<Vocab> &vocab,
+                           const std::string &suffix_indicator = WordpieceTokenizerOp::kDefSuffixIndicator,
+                           const int &max_bytes_per_token = WordpieceTokenizerOp::kDefMaxBytesPerToken,
+                           const std::string &unknown_token = WordpieceTokenizerOp::kDefUnknownToken,
+                           const bool &lower_case = BasicTokenizerOp::kDefLowerCase,
+                           const bool &keep_whitespace = BasicTokenizerOp::kDefKeepWhitespace,
+                           const NormalizeForm &normalization_form = BasicTokenizerOp::kDefNormalizationForm,
+                           const bool &preserve_unused_token = BasicTokenizerOp::kDefPreserveUnusedToken,
+                           const bool &with_offsets = WordpieceTokenizerOp::kDefWithOffsets)
+      : wordpiece_tokenizer_(vocab, suffix_indicator, max_bytes_per_token, unknown_token, with_offsets),
+        basic_tokenizer_(lower_case, keep_whitespace, normalization_form, preserve_unused_token, with_offsets) {}
+
+  ~BertTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "BertTokenizerOp"; }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  std::string Name() const override { return kBertTokenizerOp; }
+
+ private:
+  WordpieceTokenizerOp wordpiece_tokenizer_;
+  BasicTokenizerOp basic_tokenizer_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_TEXT_KERNELS_BERT_TOKENIZER_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc
new file mode 100644
index 0000000000..0ea5cadedb
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/text/kernels/case_fold_op.h"
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "unicode/errorcode.h"
+#include "unicode/normalizer2.h"
+#include "unicode/utypes.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status CaseFoldOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  icu::ErrorCode error;
+  const icu::Normalizer2 *nfkc_case_fold = icu::Normalizer2::getNFKCCasefoldInstance(error);
+  CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCCasefoldInstance failed.");
+  std::vector<std::string> strs(input->Size());
+  int i = 0;
+  for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) {
+    icu::StringByteSink<std::string> sink(&strs[i++]);
+    nfkc_case_fold->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error);
+    CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed.");
+  }
+  *output = std::make_shared<Tensor>(std::move(strs), input->shape());
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
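For reference, the ICU calls used by CaseFoldOp::Compute above also work in a standalone program. A minimal sketch, assuming ICU (icuuc) is available to link against and using an arbitrary sample string:

#include <unicode/bytestream.h>
#include <unicode/errorcode.h>
#include <unicode/normalizer2.h>
#include <iostream>
#include <string>

int main() {
  icu::ErrorCode error;
  // NFKC_Casefold combines compatibility normalization with case folding.
  const icu::Normalizer2 *nfkc_cf = icu::Normalizer2::getNFKCCasefoldInstance(error);
  if (error.isFailure()) return 1;

  std::string out;
  icu::StringByteSink<std::string> sink(&out);
  nfkc_cf->normalizeUTF8(0, icu::StringPiece("Straße MIXED Case"), sink, nullptr, error);
  std::cout << out << "\n";  // "strasse mixed case"
  return 0;
}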
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h
new file mode 100644
index 0000000000..f7a2105269
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/case_fold_op.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_
+#define DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_
+#include <memory>
+#include <string>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+class CaseFoldOp : public TensorOp {
+ public:
+  CaseFoldOp() {}
+
+  ~CaseFoldOp() override = default;
+
+  void Print(std::ostream &out) const override { out << "CaseFoldOp"; }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kCaseFoldOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_TEXT_KERNELS_CASE_FOLD_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc
new file mode 100644
index 0000000000..0a1ae92d14
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.cc
@@ -0,0 +1,94 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "minddata/dataset/util/path.h"
+
+namespace mindspore {
+namespace dataset {
+
+const bool JiebaTokenizerOp::kDefWithOffsets = false;
+
+JiebaTokenizerOp::JiebaTokenizerOp(const std::string &hmm_path, const std::string &dict_path, const JiebaMode &mode,
+                                   const bool &with_offsets)
+    : jieba_mode_(mode), hmm_model_path_(hmm_path), mp_dict_path_(dict_path), with_offsets_(with_offsets) {
+  jieba_parser_ = std::make_unique<cppjieba::Jieba>(mp_dict_path_, hmm_model_path_, "");
+}
+
+Status JiebaTokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor");
+  RETURN_UNEXPECTED_IF_NULL(jieba_parser_);
+
+  if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) {
+    RETURN_STATUS_UNEXPECTED("The input tensor should be a scalar string tensor");
+  }
+
+  std::string_view sentence_v;
+  RETURN_IF_NOT_OK(input[0]->GetItemAt(&sentence_v, {}));
+  std::string sentence{sentence_v};
+  std::vector<std::string> words;
+  std::vector<uint32_t> offsets_start, offsets_limit;
+  std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
+  if (sentence == "") {
+    words.push_back("");
+  } else {
+    std::vector<cppjieba::Word> tmp;
+    if (jieba_mode_ == JiebaMode::kMp) {
+      std::unique_ptr<cppjieba::MPSegment> mp_seg = std::make_unique<cppjieba::MPSegment>(jieba_parser_->GetDictTrie());
+      mp_seg->Cut(sentence, tmp, MAX_WORD_LENGTH);
+    } else if (jieba_mode_ == JiebaMode::kHmm) {
+      std::unique_ptr<cppjieba::HMMSegment> hmm_seg =
+        std::make_unique<cppjieba::HMMSegment>(jieba_parser_->GetHMMModel());
+      hmm_seg->Cut(sentence, tmp);
+    } else {  // Mix
+      std::unique_ptr<cppjieba::MixSegment> mix_seg =
+        std::make_unique<cppjieba::MixSegment>(jieba_parser_->GetDictTrie(), jieba_parser_->GetHMMModel());
+      mix_seg->Cut(sentence, tmp, true);
+    }
+    GetStringsFromWords(tmp, words);
+    for (auto item : tmp) {
+      offsets_start.push_back(static_cast<uint32_t>(item.offset));
+      offsets_limit.push_back(static_cast<uint32_t>(item.offset + item.word.length()));
+    }
+  }
+  token_tensor = std::make_shared<Tensor>(words, TensorShape({(dsize_t)words.size()}));
+  output->push_back(token_tensor);
+  if (with_offsets_) {
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible,
+                                          TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32),
+                                          reinterpret_cast<unsigned char *>(&offsets_start[0])));
+    RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible,
+                                          TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32),
+                                          reinterpret_cast<unsigned char *>(&offsets_limit[0])));
+    output->push_back(offsets_start_tensor);
+    output->push_back(offsets_limit_tensor);
+  }
+  return Status::OK();
+}
+
+Status JiebaTokenizerOp::AddWord(const std::string &word, int freq) {
+  RETURN_UNEXPECTED_IF_NULL(jieba_parser_);
+  if (jieba_parser_->InsertUserWord(word, freq, "") == false) {
+    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "add word error");
+  }
+  return Status::OK();
+}
+}  // namespace dataset
+}  // namespace mindspore
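For context, a standalone sketch of the cppjieba calls wrapped above, mirroring the three-argument construction used by the constructor and the InsertUserWord call wrapped by AddWord (the dictionary paths are placeholders):

#include <iostream>
#include <string>
#include <vector>
#include "cppjieba/Jieba.hpp"

int main() {
  // Paths are placeholders; JiebaTokenizerOp receives them as dict_path/hmm_path.
  cppjieba::Jieba jieba("jieba.dict.utf8", "hmm_model.utf8", "");
  jieba.InsertUserWord("江大桥", 20000, "");  // the same call AddWord() wraps

  std::vector<std::string> words;
  jieba.Cut("江州市长江大桥参加了长江大桥的通车仪式", words, true);  // HMM-assisted cut
  for (const auto &w : words) std::cout << w << "\n";
  return 0;
}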
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h
new file mode 100644
index 0000000000..4e49891c00
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/jieba_tokenizer_op.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_ENGINE_TEXT_JIEBA_OP_H_
+#define DATASET_ENGINE_TEXT_JIEBA_OP_H_
+
+#include <memory>
+#include <string>
+
+#include "cppjieba/Jieba.hpp"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+enum class JiebaMode { kMix = 0, kMp = 1, kHmm = 2 };
+
+class JiebaTokenizerOp : public TensorOp {
+ public:
+  // default constant for the Jieba MPSegment algorithm.
+  static constexpr size_t MAX_WORD_LENGTH = 512;
+  // default const for setting whether Jieba outputs the offsets tensors.
+  static const bool kDefWithOffsets;
+  // Constructor for JiebaTokenizerOp.
+  // @param hmm_path HMM model file.
+  // @param mp_path MP model file.
+  // @mode tokenization mode [Default "MIX"]: "MP" mode tokenizes with the MPSegment algorithm, "HMM" mode
+  //     tokenizes with the Hidden Markov Model Segment algorithm, and "MIX" mode tokenizes with a mix of the
+  //     MPSegment and HMMSegment algorithms.
+  // @with_offsets whether the op outputs the offsets tensors.
+  JiebaTokenizerOp(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode = JiebaMode::kMix,
+                   const bool &with_offsets = kDefWithOffsets);
+  ~JiebaTokenizerOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << "JiebaTokenizerOp: " << jieba_mode_ << "hmm_model_path_ " << hmm_model_path_ << "mp_dict_path_"
+        << mp_dict_path_;
+  }
+
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+  // @word the word to be added to the JiebaTokenizer.
+  // @freq [Default 0] the frequency of the word to be added; the tag is fixed to "".
+  Status AddWord(const std::string &word, int freq = 0);
+
+  std::string Name() const override { return kJiebaTokenizerOp; }
+
+ protected:
+  std::string hmm_model_path_;
+  std::string mp_dict_path_;
+  std::unique_ptr<cppjieba::Jieba> jieba_parser_;
+  JiebaMode jieba_mode_;
+  bool with_offsets_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_ENGINE_TEXT_JIEBA_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc
new file mode 100644
index 0000000000..02b75bc4f9
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.cc
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/text/kernels/lookup_op.h"
+
+#include <string>
+
+namespace mindspore {
+namespace dataset {
+
+LookupOp::LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id)
+    : vocab_(vocab), default_id_(default_id), type_(DataType("int32")) {}
+
+Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  RETURN_UNEXPECTED_IF_NULL(vocab_);
+  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None String Tensor.");
+  std::vector<WordIdType> word_ids;
+  word_ids.reserve(input->Size());
+  for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
+    WordIdType word_id = vocab_->Lookup(std::string(*itr));
+    word_ids.emplace_back(word_id == Vocab::kNoTokenExists ? default_id_ : word_id);
+    CHECK_FAIL_RETURN_UNEXPECTED(
+      word_ids.back() != Vocab::kNoTokenExists,
+      "Lookup Error: token " + std::string(*itr) + " doesn't exist in vocab and no unknown token is specified.");
+  }
+
+  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), type_,
+                                        reinterpret_cast<unsigned char *>(word_ids.data())));
+  return Status::OK();
+}
+Status LookupOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput() && outputs.size() == NumOutput(), "size doesn't match");
+  CHECK_FAIL_RETURN_UNEXPECTED(inputs[0] == DataType::DE_STRING, "None String tensor type");
+  outputs[0] = type_;
+  return Status::OK();
+}
+
+void LookupOp::Print(std::ostream &out) const {
+  out << "LookupOp: "
+      << "type: " << type_ << "\n default lookup id: " << default_id_ << "\n";
+}
+
+}  // namespace dataset
+}  // namespace mindspore
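The lookup rule above reduces to a map probe with a configurable out-of-vocabulary id; the error fires only when a miss still maps to an id equal to kNoTokenExists, i.e. no valid default was configured. A plain-STL sketch of the same rule (all names illustrative):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

constexpr int32_t kNoTokenExists = -1;  // stand-in for Vocab::kNoTokenExists

int32_t Lookup(const std::unordered_map<std::string, int32_t> &vocab, const std::string &word, int32_t default_id) {
  auto it = vocab.find(word);
  int32_t id = (it == vocab.end()) ? default_id : it->second;
  if (id == kNoTokenExists) {
    std::cerr << "token " << word << " doesn't exist and no unknown token is specified\n";
  }
  return id;
}

int main() {
  std::unordered_map<std::string, int32_t> vocab{{"<unk>", 0}, {"hello", 1}};
  std::cout << Lookup(vocab, "hello", 0) << "\n";  // 1
  std::cout << Lookup(vocab, "world", 0) << "\n";  // 0 (mapped to the default id)
  return 0;
}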
diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h
new file mode 100644
index 0000000000..4efc64321b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/lookup_op.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_TEXT_KERNELS_LOOKUP_OP_H_
+#define DATASET_TEXT_KERNELS_LOOKUP_OP_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/text/vocab.h"
+
+namespace mindspore {
+namespace dataset {
+class LookupOp : public TensorOp {
+ public:
+  // constructor for lookup, takes in a vocab object
+  // @param std::shared_ptr<Vocab> vocab - the vocabulary to look tokens up in
+  // @param WordIdType default_id, id to use if a word is not in vocab
+  explicit LookupOp(std::shared_ptr<Vocab> vocab, WordIdType default_id = 1);
+
+  ~LookupOp() = default;
+
+  // perform the actual lookup on each tensor
+  // @param const std::shared_ptr<Tensor> &input
+  // @param std::shared_ptr<Tensor> *output
+  // @return error code
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  // print method
+  // @param std::ostream out
+  void Print(std::ostream &out) const override;
+
+  // @param std::vector<DataType> &inputs -
+  // @param std::vector<DataType> &outputs -
+  // @return error code
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+  std::string Name() const override { return kLookupOp; }
+
+ private:
+  std::shared_ptr<Vocab> vocab_;
+  WordIdType default_id_;
+  DataType type_;  // type of the tensor after lookup
+};
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_TEXT_KERNELS_LOOKUP_OP_H_
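Ahead of the NgramOp implementation below: the op materializes capped left/right padding, joins tokens with a separator, and slides an n-wide window over the result. A compact sketch of that windowing, without the offset bookkeeping the real op keeps (function name is illustrative):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> MakeNgrams(std::vector<std::string> tokens, int n, int pad_len, const std::string &pad,
                                    const std::string &sep) {
  // Padding width is capped at n - 1 on each side, as in NgramOp's start_ind/end_ind math.
  tokens.insert(tokens.begin(), std::min(pad_len, n - 1), pad);
  tokens.insert(tokens.end(), std::min(pad_len, n - 1), pad);
  std::vector<std::string> out;
  for (size_t i = 0; i + n <= tokens.size(); i++) {
    std::string gram = tokens[i];
    for (int j = 1; j < n; j++) gram += sep + tokens[i + j];
    out.push_back(gram);
  }
  return out;
}

int main() {
  for (const auto &g : MakeNgrams({"WildRose", "Country"}, 2, 1, "_", " ")) std::cout << g << "\n";
  // prints: "_ WildRose", "WildRose Country", "Country _"
  return 0;
}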
+ */ + +#include "minddata/dataset/text/kernels/ngram_op.h" + +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +NgramOp::NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, + const std::string &r_pad, const std::string &separator) + : ngrams_(ngrams), + l_len_(l_len), + r_len_(r_len), + l_pad_with_sp_(l_pad + separator), + r_pad_with_sp_(r_pad + separator), + separator_(separator) {} + +Status NgramOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING && input->Rank() == 1, "Not a 1-D str Tensor"); + std::vector offsets; // offsets for each str + std::vector res; // holds the result of ngrams + std::string str_buffer; // concat all pad tokens with string interleaved with separators + res.reserve(input->shape().NumOfElements()); // this should be more than enough + offsets.reserve(1 + l_len_ + r_len_ + input->shape().NumOfElements()); + str_buffer.reserve(l_pad_with_sp_.size() * l_len_ + r_pad_with_sp_.size() * r_len_ + input->SizeInBytes()); + offsets.push_back(str_buffer.size()); // insert 0 as the starting pos + for (int i = 0; i < l_len_; i++) offsets.push_back((str_buffer += l_pad_with_sp_).size()); + + for (auto itr = input->begin(); itr != input->end(); itr++) { + str_buffer += (*itr); + str_buffer += separator_; + offsets.push_back(str_buffer.size()); + } + + for (int i = 0; i < r_len_; i++) offsets.push_back((str_buffer += r_pad_with_sp_).size()); + + for (auto n : ngrams_) { + CHECK_FAIL_RETURN_UNEXPECTED(n > 0, "n gram needs to be a positive number.\n"); + int32_t start_ind = l_len_ - std::min(l_len_, n - 1); + int32_t end_ind = offsets.size() - r_len_ + std::min(r_len_, n - 1); + if (end_ind - start_ind <= n) { + res.emplace_back(std::string()); // push back empty string + } else { + CHECK_FAIL_RETURN_UNEXPECTED(end_ind - n >= 0, "Incorrect loop condition"); + + for (int i = start_ind; i < end_ind - n; i++) { + res.emplace_back(str_buffer.substr(offsets[i], offsets[i + n] - offsets[i] - separator_.size())); + } + } + } + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, res, TensorShape({static_cast(res.size())}))); + return Status::OK(); +} + +void NgramOp::Print(std::ostream &out) const { + out << "NgramOp: " + << "left pad width: " << l_len_ << " left pad token with separator: " << l_pad_with_sp_ << "\n" + << "right pad width: " << r_len_ << " right pad token with separator: " << r_pad_with_sp_ << "\n" + << "separator: " << separator_ << "\n"; +} + +Status NgramOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + CHECK_FAIL_RETURN_UNEXPECTED(inputs.size() == NumInput(), "incorrect num of inputs\n"); + CHECK_FAIL_RETURN_UNEXPECTED(inputs[0].Rank() == 1, "ngram only works with 1-dim data\n"); + dsize_t num_elements = ngrams_.size(); + for (int32_t n : ngrams_) { + // here since rank == 1, NumOfElements == shape[0]. 
add padding length to string + int32_t len_with_padding = inputs[0].NumOfElements() + std::min(n - 1, l_len_) + std::min(n - 1, r_len_); + // if len_with_padding - n < 0, this would return an empty string + num_elements += std::max(len_with_padding - n, 0); + } + outputs.emplace_back(TensorShape({num_elements})); + CHECK_FAIL_RETURN_UNEXPECTED(outputs.size() == NumOutput(), "incorrect num of outputs\n"); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h new file mode 100644 index 0000000000..6ce3881638 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/ngram_op.h @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_NGRAM_OP_H_ +#define DATASET_TEXT_KERNELS_NGRAM_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class NgramOp : public TensorOp { + public: + // Constructor of Ngram model + // @param const std::vector &ngrams + // @param int32_tl_len - padding length on the left + // @param int32_t r_len - padding length on the right + // @param const std::string &l_pad - padding token on the left + // @param const std::string &r_pad - padding token on the right + // @param const std::string &separator - use to join strings + NgramOp(const std::vector &ngrams, int32_t l_len, int32_t r_len, const std::string &l_pad, + const std::string &r_pad, const std::string &separator); + + // perform ngram model on each tensor + // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // destructor + ~NgramOp() override = default; + + // @param std::vector &inputs - shape of input tensors + // @param std::vector &outputs - shape of output tensors + // @return error code + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + // print arg for debugging + // @param std::ostream &out + void Print(std::ostream &out) const override; + + std::string Name() const override { return kNgramOp; } + + private: + std::vector ngrams_; // list of n grams + int32_t l_len_; // left padding length + int32_t r_len_; // right padding length + std::string l_pad_with_sp_; // left padding appended with separator + std::string r_pad_with_sp_; // right padding appended with separator + std::string separator_; // separator +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_NGRAM_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc new file mode 100644 index 0000000000..0c0aa5fa2d --- /dev/null 
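The OutputShape arithmetic above mirrors how Compute builds n-grams: pads and tokens are concatenated into one flat buffer with separators, and cumulative offsets make every n-gram a single substr over a window of n+1 offsets. A standalone sketch of the same windowing, assuming plain std::string tokens and no padding:

#include <string>
#include <vector>

// Build n-grams the same way as the op above: one flat buffer plus cumulative
// offsets, so every n-gram is one substr over [offsets[i], offsets[i + n]).
std::vector<std::string> BuildNgrams(const std::vector<std::string> &tokens, size_t n,
                                     const std::string &sep) {
  std::string buf;
  std::vector<size_t> offsets{0};
  for (const auto &t : tokens) {
    buf += t;
    buf += sep;
    offsets.push_back(buf.size());
  }
  std::vector<std::string> out;
  for (size_t i = 0; n > 0 && i + n < offsets.size(); ++i) {
    // drop the trailing separator of the window
    out.push_back(buf.substr(offsets[i], offsets[i + n] - offsets[i] - sep.size()));
  }
  return out;
}
// BuildNgrams({"a", "b", "c"}, 2, "-") -> {"a-b", "b-c"}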
+++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/normalize_utf8_op.h" +#include +#include +#include +#include +#include + +#include "unicode/errorcode.h" +#include "unicode/normalizer2.h" +#include "unicode/utypes.h" + +namespace mindspore { +namespace dataset { +const NormalizeForm NormalizeUTF8Op::kDefNormalizeForm = NormalizeForm::kNfkc; +Status NormalizeUTF8Op::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + icu::ErrorCode error; + const icu::Normalizer2 *normalize = nullptr; + switch (normalize_form_) { + case NormalizeForm::kNone: { + *output = input; + return Status::OK(); + } + case NormalizeForm::kNfc: { + normalize = icu::Normalizer2::getNFCInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFCInstance failed"); + break; + } + case NormalizeForm::kNfkc: { + normalize = icu::Normalizer2::getNFKCInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKCInstance failed"); + break; + } + case NormalizeForm::kNfd: { + normalize = icu::Normalizer2::getNFDInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFDInstance failed"); + break; + } + case NormalizeForm::kNfkd: { + normalize = icu::Normalizer2::getNFKDInstance(error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "getNFKDInstance failed"); + break; + } + default: { + RETURN_STATUS_UNEXPECTED("unexpected normalize form"); + break; + } + } + std::vector strs(input->Size()); + int i = 0; + for (auto iter = input->begin(); iter != input->end(); iter++) { + icu::StringByteSink sink(&strs[i++]); + normalize->normalizeUTF8(0, icu::StringPiece((*iter).data(), (*iter).size()), sink, nullptr, error); + CHECK_FAIL_RETURN_UNEXPECTED(error.isSuccess(), "normalizeUTF8 failed."); + } + *output = std::make_shared(std::move(strs), input->shape()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h new file mode 100644 index 0000000000..f914be1c58 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/normalize_utf8_op.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
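NormalizeUTF8Op::Compute above is a thin wrapper over ICU's singleton Normalizer2 instances. A minimal self-contained example of the same ICU calls (getNFCInstance, normalize, toUTF8String), assuming libicu is available to link against; the singleton is owned by ICU and must not be deleted:

#include <iostream>
#include <string>

#include "unicode/errorcode.h"
#include "unicode/normalizer2.h"
#include "unicode/unistr.h"

int main() {
  icu::ErrorCode status;
  // NFC composes "e" + U+0301 (combining acute accent) into the single code point U+00E9.
  const icu::Normalizer2 *nfc = icu::Normalizer2::getNFCInstance(status);
  if (status.isFailure()) return 1;
  icu::UnicodeString decomposed = icu::UnicodeString::fromUTF8("e\xcc\x81");
  icu::UnicodeString composed = nfc->normalize(decomposed, status);
  std::string utf8;
  composed.toUTF8String(utf8);
  std::cout << utf8 << "\n";  // prints the single composed character
  return status.isSuccess() ? 0 : 1;
}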
+ */ +#ifndef DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#define DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +enum class NormalizeForm { + kNone = 0, + kNfc, + kNfkc, + kNfd, + kNfkd, +}; + +class NormalizeUTF8Op : public TensorOp { + public: + static const NormalizeForm kDefNormalizeForm; + explicit NormalizeUTF8Op(NormalizeForm normalize_form = kDefNormalizeForm) : normalize_form_(normalize_form) {} + + ~NormalizeUTF8Op() override = default; + + void Print(std::ostream &out) const override { out << "NormalizeUTF8Op"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kNormalizeUTF8Op; } + + private: + NormalizeForm normalize_form_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_NORMALIZE_UTF8_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc new file mode 100644 index 0000000000..c370393e76 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/text/kernels/regex_replace_op.h" +#include +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +Status RegexReplaceOp::RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, + std::string *out) const { + CHECK_FAIL_RETURN_UNEXPECTED((matcher != nullptr && out != nullptr), "Input is null"); + UErrorCode icu_error = U_ZERO_ERROR; + icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(text); + matcher->reset(unicode_text); + icu::UnicodeString unicode_out; + if (replace_all_) { + unicode_out = matcher->replaceAll(replace_, icu_error); + } else { + unicode_out = matcher->replaceFirst(replace_, icu_error); + } + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "RegexReplace failed"); + unicode_out.toUTF8String(*out); + return Status::OK(); +} + +Status RegexReplaceOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { + IO_CHECK(input, output); + UErrorCode icu_error = U_ZERO_ERROR; + icu::RegexMatcher matcher(pattern_, 0, icu_error); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(icu_error), "Create icu RegexMatcher failed, the pattern may be invalid"); + std::vector<std::string> strs(input->Size()); + int i = 0; + for (auto iter = input->begin<std::string_view>(); iter != input->end<std::string_view>(); iter++) { + RETURN_IF_NOT_OK(RegexReplace(&matcher, *iter, &strs[i])); + i++; + } + *output = std::make_shared<Tensor>(std::move(strs), input->shape()); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h new file mode 100644 index 0000000000..ac3d3f7ff0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_replace_op.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
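The ICU calls used above (construct an icu::RegexMatcher from a pattern, reset it onto a haystack, then replaceAll or replaceFirst) can be exercised standalone. A minimal example, assuming libicu:

#include <iostream>
#include <string>

#include "unicode/regex.h"
#include "unicode/unistr.h"

int main() {
  UErrorCode status = U_ZERO_ERROR;
  icu::RegexMatcher matcher(icu::UnicodeString::fromUTF8("\\s+"), 0, status);
  if (U_FAILURE(status)) return 1;  // e.g. a malformed pattern
  icu::UnicodeString text = icu::UnicodeString::fromUTF8("a  b\tc");
  matcher.reset(text);
  icu::UnicodeString replaced = matcher.replaceAll(icu::UnicodeString::fromUTF8("_"), status);
  std::string out;
  replaced.toUTF8String(out);
  std::cout << out << "\n";  // prints "a_b_c"
  return U_SUCCESS(status) ? 0 : 1;
}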
+ */ +#ifndef DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#define DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ +#include +#include + +#include "unicode/regex.h" +#include "unicode/errorcode.h" +#include "unicode/utypes.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class RegexReplaceOp : public TensorOp { + public: + RegexReplaceOp(const std::string &pattern, const std::string &replace, bool replace_all = true) + : pattern_(icu::UnicodeString::fromUTF8(pattern)), + replace_(icu::UnicodeString::fromUTF8(replace)), + replace_all_(replace_all) {} + + ~RegexReplaceOp() override = default; + + void Print(std::ostream &out) const override { out << "RegexReplaceOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRegexReplaceOp; } + + protected: + Status RegexReplace(icu::RegexMatcher *const matcher, const std::string_view &text, std::string *out) const; + + private: + const icu::UnicodeString pattern_; + const icu::UnicodeString replace_; + const bool replace_all_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_REGEX_REPLACE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc new file mode 100644 index 0000000000..7ff1d994be --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.cc @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/text/kernels/regex_tokenizer_op.h" +#include +#include +#include +#include +#include + +namespace mindspore { +namespace dataset { + +const bool RegexTokenizerOp::kDefWithOffsets = false; + +Status RegexTokenizerOp::GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, + std::string *out_utf8, icu::UnicodeString *out_unicode) const { + CHECK_FAIL_RETURN_UNEXPECTED((out_utf8 != nullptr || out_unicode != nullptr), "Wrong input"); + int total_len = input.length(); + int end = start + len; + CHECK_FAIL_RETURN_UNEXPECTED((start >= 0 && len > 0 && end <= total_len), "Out of range"); + icu::UnicodeString temp; + input.extract(start, len, temp); + if (out_utf8 != nullptr) { + temp.toUTF8String(*out_utf8); + } + if (out_unicode != nullptr) { + *out_unicode = temp; + } + return Status::OK(); +} + +Status RegexTokenizerOp::GetRegexTokens(const std::string &text, std::vector<std::string> *out_tokens, + std::vector<uint32_t> *offsets_start, + std::vector<uint32_t> *offsets_limit) const { + UErrorCode status = U_ZERO_ERROR; + out_tokens->clear(); + icu::RegexMatcher token_matcher(delim_pattern_, 0, status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, the delim pattern may be invalid"); + icu::RegexMatcher delim_matcher(keep_delim_pattern_, 0, status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Create icu RegexMatcher failed, the keep_delim pattern may be invalid"); + + icu::UnicodeString utext(icu::UnicodeString::fromUTF8(text)); + token_matcher.reset(utext); + + int text_start_index = 0; + int token_start_index = 0; + status = U_ZERO_ERROR; + while (token_matcher.find(status) && U_SUCCESS(status)) { + int deli_start_index = token_matcher.start(status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched start index failed"); + int deli_end_index = token_matcher.end(status); + CHECK_FAIL_RETURN_UNEXPECTED(U_SUCCESS(status), "Get RegexMatcher matched end index failed"); + + // Add non-empty token + int token_len = deli_start_index - token_start_index; + if (token_len > 0) { + std::string token; + uint32_t token_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, token_len, &token)); + token_offset = token.length(); + out_tokens->emplace_back(std::move(token)); + offsets_start->push_back(static_cast<uint32_t>(text_start_index)); + offsets_limit->push_back(static_cast<uint32_t>(text_start_index + token_offset)); + text_start_index += token_offset; + } + + int delim_len = deli_end_index - deli_start_index; + if (delim_len > 0) { + icu::UnicodeString delim_str; + std::string delim_utf8_str; + uint32_t delim_str_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, deli_start_index, delim_len, &delim_utf8_str, &delim_str)); + delim_matcher.reset(delim_str); + delim_str_offset = delim_utf8_str.length(); + if (keep_delim_ && delim_matcher.matches(status) && U_SUCCESS(status)) { + out_tokens->emplace_back(std::move(delim_utf8_str)); + offsets_start->push_back(static_cast<uint32_t>(text_start_index)); + offsets_limit->push_back(static_cast<uint32_t>(text_start_index + delim_str_offset)); + } + text_start_index += delim_str_offset; + } + token_start_index = deli_end_index; + } + + if (token_start_index < utext.length()) { + std::string temp; + uint32_t temp_offset = 0; + RETURN_IF_NOT_OK(GetUnicodeSubstr(utext, token_start_index, utext.length() - token_start_index, &temp)); + temp_offset = temp.length(); + out_tokens->emplace_back(std::move(temp)); + offsets_start->push_back(static_cast<uint32_t>(text_start_index)); +
offsets_limit->push_back(static_cast(text_start_index + temp_offset)); + } + return Status::OK(); +} + +Status RegexTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view text; + std::vector tokens; + std::vector offsets_start; + std::vector offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&text, {})); + RETURN_IF_NOT_OK(GetRegexTokens(std::string(text.data(), text.size()), &tokens, &offsets_start, &offsets_limit)); + token_tensor = std::make_shared(std::move(tokens), TensorShape({(dsize_t)tokens.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h new file mode 100644 index 0000000000..56271f9551 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/regex_tokenizer_op.h @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
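GetRegexTokens above alternates between text runs and delimiter matches, optionally keeping delimiters that also match keep_delim_pattern_. The same splitting idea can be sketched with std::regex; this is an illustration only, since std::regex lacks ICU's Unicode handling and the byte-offset bookkeeping the op performs:

#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Split on delim_pattern; if keep_delims is true, matched delimiters are
// emitted as tokens too (the op above additionally filters them through a
// second pattern and records UTF-8 byte offsets for each token).
std::vector<std::string> RegexSplit(const std::string &text, const std::string &delim_pattern,
                                    bool keep_delims) {
  std::vector<std::string> tokens;
  std::regex re(delim_pattern);
  size_t last = 0;
  for (auto it = std::sregex_iterator(text.begin(), text.end(), re); it != std::sregex_iterator(); ++it) {
    if (static_cast<size_t>(it->position()) > last) {
      tokens.push_back(text.substr(last, it->position() - last));
    }
    if (keep_delims && it->length() > 0) tokens.push_back(it->str());
    last = it->position() + it->length();
  }
  if (last < text.size()) tokens.push_back(text.substr(last));
  return tokens;
}

int main() {
  for (const auto &t : RegexSplit("Welcome  to China!", "\\s+", false)) std::cout << '[' << t << ']';
  // prints [Welcome][to][China!]
  return 0;
}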
+ */ +#ifndef DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#define DATASET_TEXT_REGEX_TOKENIZER_OP_H_ +#include +#include +#include + +#include "unicode/regex.h" +#include "unicode/errorcode.h" +#include "unicode/utypes.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class RegexTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + RegexTokenizerOp(const std::string &delim_pattern, const std::string &keep_delim_pattern, + const bool &with_offsets = kDefWithOffsets) + : delim_pattern_(icu::UnicodeString::fromUTF8(delim_pattern)), + keep_delim_pattern_(icu::UnicodeString::fromUTF8(keep_delim_pattern)), + with_offsets_(with_offsets), + keep_delim_(!keep_delim_pattern.empty()) {} + + ~RegexTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "RegexTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status GetUnicodeSubstr(const icu::UnicodeString &input, const int &start, const int &len, std::string *out_utf8, + icu::UnicodeString *out_unicode = nullptr) const; + Status GetRegexTokens(const std::string &text, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + + std::string Name() const override { return kRegexTokenizerOp; } + + private: + const icu::UnicodeString delim_pattern_; + const icu::UnicodeString keep_delim_pattern_; + bool with_offsets_; + const bool keep_delim_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_REGEX_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc new file mode 100644 index 0000000000..a6685a2d64 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.cc @@ -0,0 +1,241 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/text/kernels/to_number_op.h" + +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {} + +ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {} + +Status ToNumberOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { + CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensors should have type string."); + + switch (cast_to_type_.value()) { + case DataType::DE_INT8: + RETURN_IF_NOT_OK(ToSignedIntegral<int8_t>(input, output)); + break; + case DataType::DE_INT16: + RETURN_IF_NOT_OK(ToSignedIntegral<int16_t>(input, output)); + break; + case DataType::DE_INT32: + RETURN_IF_NOT_OK(ToSignedIntegral<int32_t>(input, output)); + break; + case DataType::DE_INT64: + RETURN_IF_NOT_OK(ToSignedIntegral<int64_t>(input, output)); + break; + case DataType::DE_UINT8: + RETURN_IF_NOT_OK(ToUnsignedIntegral<uint8_t>(input, output)); + break; + case DataType::DE_UINT16: + RETURN_IF_NOT_OK(ToUnsignedIntegral<uint16_t>(input, output)); + break; + case DataType::DE_UINT32: + RETURN_IF_NOT_OK(ToUnsignedIntegral<uint32_t>(input, output)); + break; + case DataType::DE_UINT64: + RETURN_IF_NOT_OK(ToUnsignedIntegral<uint64_t>(input, output)); + break; + case DataType::DE_FLOAT16: + RETURN_IF_NOT_OK(this->ToFloat16(input, output)); + break; + case DataType::DE_FLOAT32: + RETURN_IF_NOT_OK(ToFloat(input, output)); + break; + case DataType::DE_FLOAT64: + RETURN_IF_NOT_OK(ToDouble(input, output)); + break; + } + + return Status::OK(); +} + +void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << cast_to_type_ << '\n'; } + +Status ToNumberOp::OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) { + (void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes)); + return Status::OK(); +} + +template <typename T> +Status ToNumberOp::ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { + std::vector<T> casted; + + for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) { + bool is_cast_out_of_range = false; + int64_t result = 0; + + try { + result = std::stoll(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number."); + } + + if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if cast to " + + cast_to_type_.ToString() + ".
The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +template +Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + uint64_t result = 0; + + // If there is a - at the start of the string, it is considered by us to + // be out of bounds. If the - is somewhere else in the string, it is + // deemed invalid by std::stoull and will throw std::invalid_argument + for (int i = 0; i < (*it).size(); i++) { + if ((*it)[i] == '-') { + is_cast_out_of_range = true; + break; + } + } + + try { + result = std::stoull(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::min() || is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::min()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + T casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToFloat16(const std::shared_ptr &input, std::shared_ptr *output) { + // special case, float16 does not exist in c++, no native support for + // casting, so cast to float first then use this method, which use Eigen. + std::shared_ptr temp; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32"))); + RETURN_IF_NOT_OK(ToFloat(input, &temp)); + RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output)); + return Status::OK(); +} + +Status ToNumberOp::ToFloat(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + float result = 0; + + try { + result = std::stof(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". 
The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + float casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +Status ToNumberOp::ToDouble(const std::shared_ptr &input, std::shared_ptr *output) { + std::vector casted; + + for (auto it = input->begin(); it != input->end(); ++it) { + bool is_cast_out_of_range = false; + double result = 0; + + try { + result = std::stod(std::string(*it)); + } catch (const std::out_of_range &) { + is_cast_out_of_range = true; + } catch (const std::invalid_argument &) { + RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer."); + } + + if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest() || + is_cast_out_of_range) { + std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " + + cast_to_type_.ToString() + ". The valid range is: [" + + std::to_string(std::numeric_limits::lowest()) + ", " + + std::to_string(std::numeric_limits::max()) + "]."; + + RETURN_STATUS_UNEXPECTED(error_message); + } + + double casted_result = static_cast(result); + casted.push_back(casted_result); + } + + RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape())); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h new file mode 100644 index 0000000000..8582fcf073 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/to_number_op.h @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ +#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class ToNumberOp : public TensorOp { + public: + // Constructor of ToNumberOp + // @param const DataType &cast_to_type - the type to convert string inputs to. + explicit ToNumberOp(const DataType &cast_to_type); + + // Constructor of ToNumberOp + // @param const std::string &cast_to_type - the type in string form to convert string inputs to. + explicit ToNumberOp(const std::string &cast_to_type); + + ~ToNumberOp() override = default; + + // Perform numeric conversion on each string in each tensor. 
+ // @param const std::shared_ptr &input + // @param std::shared_ptr *output + // @return error code + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + // For each input shape, find the output shape + // @param std::vector &inputs - shape of input tensors + // @param std::vector &outputs - shape of output tensors + // @return error code + Status OutputShape(const std::vector &input_shapes, std::vector &output_shapes) override; + + // print arg for debugging + // @param std::ostream &out + void Print(std::ostream &out) const override; + + std::string Name() const override { return kToNumberOp; } + + private: + template + Status ToSignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + template + Status ToUnsignedIntegral(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat16(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToFloat(const std::shared_ptr &input, std::shared_ptr *output); + + Status ToDouble(const std::shared_ptr &input, std::shared_ptr *output); + + DataType cast_to_type_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc new file mode 100644 index 0000000000..53a803c542 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
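All of the converters above share one pattern: parse at the widest width (stoll, stoull, stof, stod), route std::invalid_argument to an error, and treat values outside the target type's limits as out of bounds. A condensed standalone sketch of that pattern, not part of the patch, assuming a bool return in place of Status:

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Bounds-checked narrowing in the same style as ToSignedIntegral<T> above:
// parse at the widest signed width, then verify the value fits the target type.
template <typename T>
bool ToSigned(const std::string &s, T *out) {
  try {
    long long wide = std::stoll(s);
    if (wide > std::numeric_limits<T>::max() || wide < std::numeric_limits<T>::min()) return false;
    *out = static_cast<T>(wide);
    return true;
  } catch (const std::out_of_range &) {
    return false;  // does not even fit in long long
  } catch (const std::invalid_argument &) {
    return false;  // not a number at all
  }
}
// int8_t v; ToSigned<int8_t>("300", &v) == false; ToSigned<int8_t>("-7", &v) == true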
+ */ + +#include "minddata/dataset/text/kernels/truncate_sequence_pair_op.h" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/slice_op.h" + +namespace mindspore { +namespace dataset { + +Status TruncateSequencePairOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 2, "Number of inputs should be two."); + std::shared_ptr seq1 = input[0]; + std::shared_ptr seq2 = input[1]; + CHECK_FAIL_RETURN_UNEXPECTED(seq1->shape().Rank() == 1 && seq2->shape().Rank() == 1, + "Both sequences should be of rank 1"); + dsize_t length1 = seq1->shape()[0]; + dsize_t length2 = seq2->shape()[0]; + dsize_t outLength1 = length1; + dsize_t outLength2 = length2; + + dsize_t total = length1 + length2; + while (total > max_length_) { + if (outLength1 > outLength2) + outLength1--; + else + outLength2--; + total--; + } + std::shared_ptr outSeq1; + if (length1 != outLength1) { + std::unique_ptr slice1(new SliceOp(Slice(outLength1 - length1))); + RETURN_IF_NOT_OK(slice1->Compute(seq1, &outSeq1)); + } else { + outSeq1 = std::move(seq1); + } + + std::shared_ptr outSeq2; + if (length2 != outLength2) { + std::unique_ptr slice2(new SliceOp(Slice(outLength2 - length2))); + RETURN_IF_NOT_OK(slice2->Compute(seq2, &outSeq2)); + } else { + outSeq2 = std::move(seq2); + } + output->push_back(outSeq1); + output->push_back(outSeq2); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h new file mode 100644 index 0000000000..ce82735645 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
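The while loop above implements the usual longest-first truncation for sentence pairs (as in BERT-style preprocessing): drop elements from whichever sequence is currently longer until the combined length fits, then slice the tails off in one pass. A standalone sketch of the same policy on plain vectors:

#include <cstddef>
#include <string>
#include <vector>

// Same policy as TruncateSequencePairOp::Compute above: repeatedly trim one
// element from the tail of the currently longer sequence until the pair fits.
void TruncatePair(std::vector<std::string> *a, std::vector<std::string> *b, size_t max_length) {
  while (a->size() + b->size() > max_length) {
    if (a->size() > b->size()) {
      a->pop_back();
    } else {
      b->pop_back();
    }
  }
}
// a = {"x","y","z"}, b = {"p","q"}, max_length = 3  ->  a = {"x","y"}, b = {"p"}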
+ */ +#ifndef DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ +#define DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { + +class TruncateSequencePairOp : public TensorOp { + public: + explicit TruncateSequencePairOp(dsize_t length) : max_length_(length) {} + + ~TruncateSequencePairOp() override = default; + + void Print(std::ostream &out) const override { out << "TruncateSequencePairOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kTruncateSequencePairOp; } + + private: + dsize_t max_length_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_TRUNCATE_SEQUENCE_PAIR_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc new file mode 100644 index 0000000000..e08f61100b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.cc @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h" +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool UnicodeCharTokenizerOp::kDefWithOffsets = false; + +Status UnicodeCharTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); + + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + std::vector splits(runes.size()); + std::vector offsets_start, offsets_limit; + for (size_t i = 0; i < runes.size(); i++) { + offsets_start.push_back(runes[i].offset); + offsets_limit.push_back(runes[i].offset + runes[i].len); + splits[i] = str.substr(runes[i].offset, runes[i].len); + } + if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h new file mode 100644 index 0000000000..415d99b451 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_char_tokenizer_op.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
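The op above gets code-point boundaries and byte offsets from cppjieba::DecodeRunesInString. An equivalent standalone sketch that derives each UTF-8 sequence length from its lead byte; it performs no validation of continuation bytes, so it is an illustration of the offset bookkeeping only:

#include <cstddef>
#include <string>
#include <vector>

// Character-level split with byte offsets, mirroring what the op receives
// from cppjieba: the length of each UTF-8 sequence follows from its lead byte.
std::vector<std::string> SplitUtf8Chars(const std::string &s, std::vector<size_t> *offsets) {
  std::vector<std::string> chars;
  for (size_t i = 0; i < s.size();) {
    unsigned char lead = static_cast<unsigned char>(s[i]);
    size_t len = (lead < 0x80) ? 1 : (lead < 0xE0) ? 2 : (lead < 0xF0) ? 3 : 4;
    if (i + len > s.size()) len = s.size() - i;  // clamp on truncated input
    offsets->push_back(i);
    chars.push_back(s.substr(i, len));
    i += len;
  }
  return chars;
}
// SplitUtf8Chars("a中b", &offs) -> {"a", "中", "b"} with offs = {0, 1, 4}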
+ */ +#ifndef DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class UnicodeCharTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + explicit UnicodeCharTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {} + + ~UnicodeCharTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "UnicodeCharTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kUnicodeCharTokenizerOp; } + + private: + bool with_offsets_; +}; + +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_UNICODE_CHAR_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc new file mode 100644 index 0000000000..60fe8dd0e4 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h" +#include +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" +#include "unicode/errorcode.h" +#include "unicode/uchar.h" +#include "unicode/uscript.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool UnicodeScriptTokenizerOp::kDefKeepWhitespace = false; +const bool UnicodeScriptTokenizerOp::kDefWithOffsets = false; + +Status UnicodeScriptTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt<std::string_view>(&str, {})); + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + + std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor; + UScriptCode last_script = USCRIPT_INVALID_CODE; + icu::ErrorCode status; + int start = 0; + int len = 0; + std::vector<std::string> splits; + std::vector<uint32_t> offsets_start, offsets_limit; + + bool was_space = false; + for (size_t i = 0; i < runes.size(); i++) { + bool is_space = u_isUWhiteSpace(runes[i].rune); + UScriptCode script = uscript_getScript(runes[i].rune, status); + if (status.isFailure()) { + status.reset(); + script = USCRIPT_INVALID_CODE; + } + // 1) Separate UTF-8 strings of different UScriptCode values + // (such as: "Chinese中国" should be split into ["Chinese", "中国"]) + // 2) Separate whitespace and non-whitespace UTF-8 strings + // (such as: " ." should be split into [" ", "."]) + if (len > 0 && (script != last_script || is_space != was_space)) { + // 3) If keep_whitespace_ is false, all the whitespace characters will be discarded + if (keep_whitespace_ || !was_space) { + offsets_start.push_back(static_cast<uint32_t>(start)); + offsets_limit.push_back(static_cast<uint32_t>(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + start = runes[i].offset; + len = runes[i].len; + } else { + len += runes[i].len; + } + last_script = script; + was_space = is_space; + } + + if (len > 0 && (keep_whitespace_ || !was_space)) { + offsets_start.push_back(static_cast<uint32_t>(start)); + offsets_limit.push_back(static_cast<uint32_t>(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + // 4) If the input is an empty scalar string, the output will be a 1-D empty string.
+ if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h new file mode 100644 index 0000000000..fc3b9e620a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/unicode_script_tokenizer_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class UnicodeScriptTokenizerOp : public TensorOp { + public: + static const bool kDefKeepWhitespace; + static const bool kDefWithOffsets; + + explicit UnicodeScriptTokenizerOp(const bool &keep_whitespace = kDefKeepWhitespace, + const bool &with_offsets = kDefWithOffsets) + : keep_whitespace_(keep_whitespace), with_offsets_(with_offsets) {} + + ~UnicodeScriptTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "UnicodeScriptTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kUnicodeScriptTokenizerOp; } + + private: + bool keep_whitespace_; // If or not keep whitespace tokens + bool with_offsets_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_UNICODE_SCRIPT_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc new file mode 100644 index 0000000000..d3bb32081e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
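The script-change test above rests on two ICU primitives: uscript_getScript to classify a code point and u_isUWhiteSpace to detect Unicode whitespace. A minimal standalone example of both, assuming libicu (the plain C API takes a UErrorCode pointer):

#include <iostream>

#include "unicode/uchar.h"
#include "unicode/uscript.h"

int main() {
  UErrorCode err = U_ZERO_ERROR;
  // 'A' classifies as Latin, U+4E2D ('中') as Han: a script change is a token boundary.
  UScriptCode latin = uscript_getScript(0x41, &err);
  UScriptCode han = uscript_getScript(0x4E2D, &err);
  if (U_FAILURE(err)) return 1;
  std::cout << uscript_getShortName(latin) << " " << uscript_getShortName(han) << "\n";  // Latn Hani
  std::cout << (u_isUWhiteSpace(0x20) ? "space" : "not space") << "\n";  // ASCII space is whitespace
  return 0;
}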
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h" +#include +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" +#include "unicode/errorcode.h" +#include "unicode/uchar.h" +#include "unicode/uscript.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; + +namespace mindspore { +namespace dataset { + +const bool WhitespaceTokenizerOp::kDefWithOffsets = false; + +Status WhitespaceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input.size() == 1, "Input should be one tensor"); + if (input[0]->Rank() != 0 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar string tensor"); + } + std::string_view str; + RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {})); + + RuneStrArray runes; + if (!DecodeRunesInString(str.data(), str.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + std::vector offsets_start, offsets_limit; + std::vector splits; + int start = 0; + int len = 0; + for (size_t i = 0; i < runes.size(); i++) { + if (u_isUWhiteSpace(runes[i].rune)) { + if (len > 0) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + len = 0; + } + } else { + if (len == 0) { + start = runes[i].offset; + } + len += runes[i].len; + } + } + if (len > 0) { + offsets_start.push_back(static_cast(start)); + offsets_limit.push_back(static_cast(start + len)); + std::string temp(str.substr(start, len)); + splits.emplace_back(std::move(temp)); + } + if (splits.empty()) { + splits.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(splits, TensorShape({(dsize_t)splits.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h new file mode 100644 index 0000000000..7cc37fd705 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/whitespace_tokenizer_op.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
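WhitespaceTokenizerOp::Compute above uses the same run-accumulation idiom as the script tokenizer: grow a [start, start + len) window while inside a token, flush it when whitespace is hit. Reduced to ASCII whitespace for illustration (the op itself uses u_isUWhiteSpace over decoded runes):

#include <string>
#include <vector>

// Run-accumulation split: collect [start, start + len) until a separator appears.
std::vector<std::string> SplitOnSpaces(const std::string &str) {
  std::vector<std::string> splits;
  size_t start = 0, len = 0;
  for (size_t i = 0; i < str.size(); ++i) {
    if (str[i] == ' ' || str[i] == '\t' || str[i] == '\n') {
      if (len > 0) splits.push_back(str.substr(start, len));
      len = 0;
    } else {
      if (len == 0) start = i;
      ++len;
    }
  }
  if (len > 0) splits.push_back(str.substr(start, len));
  return splits;
}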
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class WhitespaceTokenizerOp : public TensorOp { + public: + static const bool kDefWithOffsets; + + explicit WhitespaceTokenizerOp(const bool &with_offsets = kDefWithOffsets) : with_offsets_(with_offsets) {} + + ~WhitespaceTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "WhitespaceTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kWhitespaceTokenizerOp; } + + private: + bool with_offsets_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_WHITESPACE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc new file mode 100644 index 0000000000..f0bd448e39 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.cc @@ -0,0 +1,157 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/text/kernels/wordpiece_tokenizer_op.h" +#include +#include + +namespace mindspore { +namespace dataset { + +const char WordpieceTokenizerOp::kDefSuffixIndicator[] = "##"; +const int WordpieceTokenizerOp::kDefMaxBytesPerToken = 100; +const char WordpieceTokenizerOp::kDefUnknownToken[] = "[UNK]"; +const bool WordpieceTokenizerOp::kDefWithOffsets = false; + +WordpieceTokenizerOp::WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator, + const int &max_bytes_per_token, const std::string &unknown_token, + const bool &with_offsets) + : vocab_(vocab), + suffix_indicator_(suffix_indicator), + max_bytes_per_token_(max_bytes_per_token), + unknown_token_(unknown_token), + with_offsets_(with_offsets) {} + +Status WordpieceTokenizerOp::LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, + bool *out_found, int *out_end) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && start < input_token.size(), "Out of range"); + *out_found = false; + for (int i = runes.size() - 1; i >= 0; i--) { + *out_end = runes[i].offset + runes[i].len; + int len = *out_end - start; + std::string word = input_token.substr(start, len); + if (start > 0) { + word = suffix_indicator_ + word; + } + if (vocab_->Lookup(word) != Vocab::kNoTokenExists) { + *out_found = true; + break; + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::FoundNoToken(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + out_tokens->clear(); + offsets_start->push_back(basic_start); + if (unknown_token_.empty()) { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.length()); + } else { + out_tokens->emplace_back(unknown_token_); + offsets_limit->push_back(basic_start + input_token.length()); + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_tokens) const { + CHECK_FAIL_RETURN_UNEXPECTED(start >= 0 && end > start && end <= input_token.size(), "Out of range"); + std::string subword = input_token.substr(start, end - start); + if (start > 0) { + subword = suffix_indicator_ + subword; + } + out_tokens->emplace_back(subword); + return Status::OK(); +} + +Status WordpieceTokenizerOp::GetTokens(const std::string &input_token, const uint32_t &basic_start, + std::vector *out_tokens, std::vector *offsets_start, + std::vector *offsets_limit) const { + if (input_token.size() > max_bytes_per_token_) { + offsets_start->push_back(basic_start); + if (!unknown_token_.empty()) { + offsets_limit->push_back(basic_start + unknown_token_.size()); + out_tokens->emplace_back(unknown_token_); + } else { + out_tokens->emplace_back(input_token); + offsets_limit->push_back(basic_start + input_token.size()); + } + return Status::OK(); + } + RuneStrArray runes; + if (!DecodeRunesInString(input_token.data(), input_token.size(), runes)) { + RETURN_STATUS_UNEXPECTED("Decode utf8 string failed."); + } + int end = 0; + for (int start = 0; start < input_token.size();) { + bool found = false; + RETURN_IF_NOT_OK(LookupWord(input_token, runes, start, &found, &end)); + if (found) { + RETURN_IF_NOT_OK(AddSubword(input_token, start, end, out_tokens)); + offsets_start->push_back(static_cast(basic_start + start)); + offsets_limit->push_back(static_cast(basic_start + end)); + start = end; + } else { + return FoundNoToken(input_token, 
basic_start, out_tokens, offsets_start, offsets_limit); + } + } + return Status::OK(); +} + +Status WordpieceTokenizerOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + if (input[0]->Rank() > 1 || input[0]->type() != DataType::DE_STRING) { + RETURN_STATUS_UNEXPECTED("The input tensor should be scalar or 1-D string tensor"); + } + dsize_t count = 0; + std::vector out_tokens; + std::vector offsets_start, offsets_limit; + std::shared_ptr token_tensor, offsets_start_tensor, offsets_limit_tensor; + for (auto iter = input[0]->begin(); iter != input[0]->end(); iter++) { + uint32_t basic_start = 0; + std::vector temp_tokens; + if (with_offsets_ && input.size() == 3) { + RETURN_IF_NOT_OK(input[1]->GetItemAt(&basic_start, {count, 0})); + } + RETURN_IF_NOT_OK(GetTokens(std::string(*iter), basic_start, &temp_tokens, &offsets_start, &offsets_limit)); + out_tokens.insert(out_tokens.end(), temp_tokens.begin(), temp_tokens.end()); + count++; + } + if (out_tokens.empty()) { + out_tokens.emplace_back(""); + offsets_start.push_back(0); + offsets_limit.push_back(0); + } + token_tensor = std::make_shared(out_tokens, TensorShape({(dsize_t)out_tokens.size()})); + output->push_back(token_tensor); + if (with_offsets_) { + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_start_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_start.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_start[0]))); + RETURN_IF_NOT_OK(Tensor::CreateTensor(&offsets_limit_tensor, TensorImpl::kFlexible, + TensorShape({(dsize_t)offsets_limit.size()}), DataType(DataType::DE_UINT32), + reinterpret_cast(&offsets_limit[0]))); + output->push_back(offsets_start_tensor); + output->push_back(offsets_limit_tensor); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h new file mode 100644 index 0000000000..4f9c76f57e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/kernels/wordpiece_tokenizer_op.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
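// Illustrative sketch (not from the patch): the loop above is greedy
// longest-match-first WordPiece. LookupWord() tries the longest prefix at
// `start` (with the "##" indicator prepended at non-initial positions) and
// shrinks it one rune at a time until a vocab hit. An ASCII-only toy version,
// with a std::set standing in for Vocab::Lookup():
#include <set>
#include <string>
#include <vector>

std::vector<std::string> ToyWordpiece(const std::string &word,
                                      const std::set<std::string> &vocab) {
  std::vector<std::string> pieces;
  size_t start = 0;
  while (start < word.size()) {
    std::string match;
    size_t end = word.size();
    for (; end > start; --end) {  // longest candidate first
      std::string sub = word.substr(start, end - start);
      if (start > 0) sub = "##" + sub;
      if (vocab.count(sub) > 0) {
        match = sub;
        break;
      }
    }
    if (match.empty()) return {"[UNK]"};  // FoundNoToken(): whole word -> [UNK]
    pieces.push_back(match);
    start = end;
  }
  return pieces;
}
// With vocab = {"un", "##aff", "##able"}:
//   ToyWordpiece("unaffable", vocab) == {"un", "##aff", "##able"}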
+ */ +#ifndef DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#define DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ +#include +#include +#include +#include + +#include "cppjieba/Unicode.hpp" + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/text/vocab.h" +#include "minddata/dataset/util/status.h" + +using cppjieba::DecodeRunesInString; +using cppjieba::RuneStrArray; +namespace mindspore { +namespace dataset { + +class WordpieceTokenizerOp : public TensorOp { + public: + static const char kDefSuffixIndicator[]; + static const int kDefMaxBytesPerToken; + static const char kDefUnknownToken[]; + static const bool kDefWithOffsets; + WordpieceTokenizerOp(const std::shared_ptr &vocab, const std::string &suffix_indicator = kDefSuffixIndicator, + const int &max_bytes_per_token = kDefMaxBytesPerToken, + const std::string &unknown_token = kDefUnknownToken, const bool &with_offsets = kDefWithOffsets); + + ~WordpieceTokenizerOp() override = default; + + void Print(std::ostream &out) const override { out << "WordpieceTokenizerOp"; } + + Status Compute(const TensorRow &input, TensorRow *output) override; + + protected: + Status AddSubword(const std::string &input_token, const int &start, const int &end, + std::vector *out_token) const; + Status FoundNoToken(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + Status LookupWord(const std::string &input_token, const RuneStrArray &runes, const int start, bool *out_found, + int *out_end) const; + Status GetTokens(const std::string &input_token, const uint32_t &basic_start, std::vector *out_tokens, + std::vector *offsets_start, std::vector *offsets_limit) const; + + std::string Name() const override { return kWordpieceTokenizerOp; } + + private: + const std::shared_ptr vocab_; + const std::string suffix_indicator_; + const bool with_offsets_; + const int max_bytes_per_token_; + const std::string unknown_token_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_TEXT_KERNELS_WORDPIECE_TOKENIZER_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.cc b/mindspore/ccsrc/minddata/dataset/text/vocab.cc new file mode 100644 index 0000000000..c1b7e6265c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +#include "minddata/dataset/text/vocab.h" + +namespace mindspore { +namespace dataset { +Vocab::Vocab(std::unordered_map word2id) { word2id_ = std::move(word2id); } + +WordIdType Vocab::Lookup(const WordType &word) const { + auto itr = word2id_.find(word); + return itr == word2id_.end() ? 
kNoTokenExists : itr->second; +} + +Status Vocab::BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, + std::shared_ptr *vocab) { + // check of duplication on both words and special_tokens will be performed in python + // special_tokens and words both need to be unique, and shouldn't overlap + std::unordered_map word2id; + // if special is added in front, normal words id will start from number of special tokens + WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; + + for (auto word : words) { + word2id[py::str(word)] = word_id++; + } + + word_id = prepend_special ? 0 : word2id.size(); + + for (auto special_token : special_tokens) { + word2id[py::str(special_token)] = word_id++; + } + + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +Status Vocab::BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, + const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab) { + // python validator checks special_tokens doesn't contain any duplicate words + std::unordered_set specials; + // used to check that words in file don't contain any special token that already exists + for (auto word : special_tokens) { + specials.insert(py::str(word)); + } + WordIdType word_id = prepend_special ? static_cast(special_tokens.size()) : 0; + std::unordered_map word2id; + std::fstream handle(path, std::ios::in); + CHECK_FAIL_RETURN_UNEXPECTED(handle.good() && handle.is_open(), "fail to open:" + path); + std::string word; + while (std::getline(handle, word)) { + if (!delimiter.empty()) { + // if delimiter is not found, find_first_of would return std::string::npos which is -1 + word = word.substr(0, word.find_first_of(delimiter)); + } + CHECK_FAIL_RETURN_UNEXPECTED(word2id.find(word) == word2id.end(), "duplicate word:" + word + "."); + CHECK_FAIL_RETURN_UNEXPECTED(specials.find(word) == specials.end(), word + " is already in special_tokens."); + word2id[word] = word_id++; + // break if enough row is read, if vocab_size is smaller than 0 + if (word2id.size() == vocab_size) break; + } + + word_id = prepend_special ? 0 : word2id.size(); + + for (auto special_token : special_tokens) { + word2id[py::str(special_token)] = word_id++; + } + + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +Status Vocab::BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab) { + std::unordered_map word2id; + for (auto p : words) { + word2id[py::str(p.first)] = py::reinterpret_borrow(p.second); + } + *vocab = std::make_shared(std::move(word2id)); + return Status::OK(); +} + +void Vocab::append_word(const std::string &word) { + if (word2id_.find(word) == word2id_.end()) { + word2id_[word] = word2id_.size(); + } +} + +const WordIdType Vocab::kNoTokenExists = -1; + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/text/vocab.h b/mindspore/ccsrc/minddata/dataset/text/vocab.h new file mode 100644 index 0000000000..6bf6c488c5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/text/vocab.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
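// Illustrative sketch (not from the patch) of the ID assignment implemented
// by BuildFromPyList()/BuildFromFile() above, with plain C++ containers in
// place of the pybind11 lists:
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

std::unordered_map<std::string, int32_t> AssignIds(
    const std::vector<std::string> &words,
    const std::vector<std::string> &specials, bool prepend_special) {
  std::unordered_map<std::string, int32_t> word2id;
  // If specials are prepended they take IDs [0, n); words then start at n.
  int32_t id = prepend_special ? static_cast<int32_t>(specials.size()) : 0;
  for (const auto &w : words) word2id[w] = id++;
  id = prepend_special ? 0 : static_cast<int32_t>(word2id.size());
  for (const auto &s : specials) word2id[s] = id++;
  return word2id;
}
// AssignIds({"cat", "dog"}, {"<pad>", "<unk>"}, true)
//   -> {"<pad>": 0, "<unk>": 1, "cat": 2, "dog": 3}
// AssignIds({"cat", "dog"}, {"<pad>", "<unk>"}, false)
//   -> {"cat": 0, "dog": 1, "<pad>": 2, "<unk>": 3}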
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_TEXT_VOCAB_H_ +#define DATASET_TEXT_VOCAB_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace mindspore { +namespace dataset { +namespace py = pybind11; + +using WordIdType = int32_t; +using WordType = std::string; + +class Vocab { + public: + // Build a vocab from a python dictionary key is each word ,id needs to start from 2, no duplicate and continuous + // @param const py::dict &words - a dictionary containing word, word id pair. + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromPyDict(const py::dict &words, std::shared_ptr *vocab); + + // Build a vocab from a python list, id will be assigned automatically, start from 2 + // @param const py::list &words - a list of string, used to build vocab, id starts from 2 + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromPyList(const py::list &words, const py::list &special_tokens, bool prepend_special, + std::shared_ptr *vocab); + + // Build a vocab from reading a vocab file, id are automatically assigned, start from 2 + // @param std::string &path - path to vocab file , each line is assumed to contain 1 word + // @param std::string &delimiter - delimiter to break each line with + // @param int32_t vocab_size - number of words to read from file + // @param std::shared_ptr *vocab - return value, vocab object + // @return error code + static Status BuildFromFile(const std::string &path, const std::string &delimiter, int32_t vocab_size, + const py::list &special_tokens, bool prepend_special, std::shared_ptr *vocab); + + // Lookup the id of a word, if word doesn't exist in vocab, return default_id + // @param const WordType word - word to look up + // @param WordIdType default_id - word id to return to user when its not in the vocab + // @return WordIdType, word_id + WordIdType Lookup(const WordType &word) const; + + // constructor, shouldn't be called directly, can't be private due to std::make_unique() + // @param std::unordered_map map - sanitized word2id map + explicit Vocab(std::unordered_map map); + + Vocab() = default; + + // add one word to vocab, increment it's index automatically + // @param std::string & word - word to be added will skip if word already exists + void append_word(const std::string &word); + + // destructor + ~Vocab() = default; + + static const WordIdType kNoTokenExists; + + private: + std::unordered_map word2id_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_TEXT_VOCAB_H_ diff --git a/mindspore/ccsrc/dataset/util/.gitignore b/mindspore/ccsrc/minddata/dataset/util/.gitignore similarity index 100% rename from mindspore/ccsrc/dataset/util/.gitignore rename to mindspore/ccsrc/minddata/dataset/util/.gitignore diff --git a/mindspore/ccsrc/dataset/util/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/dataset/util/CMakeLists.txt rename to 
mindspore/ccsrc/minddata/dataset/util/CMakeLists.txt diff --git a/mindspore/ccsrc/dataset/util/README.md b/mindspore/ccsrc/minddata/dataset/util/README.md similarity index 100% rename from mindspore/ccsrc/dataset/util/README.md rename to mindspore/ccsrc/minddata/dataset/util/README.md diff --git a/mindspore/ccsrc/minddata/dataset/util/allocator.h b/mindspore/ccsrc/minddata/dataset/util/allocator.h new file mode 100644 index 0000000000..b5eaed97a6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/allocator.h @@ -0,0 +1,178 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_ALLOCATOR_H_ +#define DATASET_UTIL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" + +namespace mindspore { +namespace dataset { +// The following conforms to the requirements of +// std::allocator. Do not rename/change any needed +// requirements, e.g. function names, typedef etc. +template +class Allocator { + public: + template + friend class Allocator; + + using value_type = T; + using pointer = T *; + using const_pointer = const T *; + using reference = T &; + using const_reference = const T &; + using size_type = uint64_t; + + template + struct rebind { + using other = Allocator; + }; + + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + explicit Allocator(const std::shared_ptr &b) : pool_(b) {} + + ~Allocator() = default; + + template + explicit Allocator(Allocator const &rhs) : pool_(rhs.pool_) {} + + template + bool operator==(Allocator const &rhs) const { + return pool_ == rhs.pool_; + } + + template + bool operator!=(Allocator const &rhs) const { + return pool_ != rhs.pool_; + } + + pointer allocate(std::size_t n) { + void *p; + Status rc = pool_->Allocate(n * sizeof(T), &p); + if (rc.IsOk()) { + return reinterpret_cast(p); + } else if (rc.IsOutofMemory()) { + throw std::bad_alloc(); + } else { + throw std::exception(); + } + } + + void deallocate(pointer p, std::size_t n = 0) noexcept { pool_->Deallocate(p); } + + size_type max_size() { return pool_->get_max_size(); } + + private: + std::shared_ptr pool_; +}; +/// \brief It is a wrapper of unique_ptr with a custom allocator and acts like std::lock_guard such that the memory will +/// be released when the object goes out of scope +/// \tparam T The type of object to be allocated +/// \tparam C Allocator. Default to std::allocator +template > +class MemGuard { + public: + using allocator = C; + MemGuard() : n_(0) {} + explicit MemGuard(allocator a) : n_(0), alloc_(a) {} + // There is no copy constructor nor assignment operator because the memory is solely owned by this object. 
+ MemGuard(const MemGuard &) = delete; + MemGuard &operator=(const MemGuard &) = delete; + // On the other hand, We can support move constructor + MemGuard(MemGuard &&lhs) noexcept : alloc_(std::move(lhs.alloc_)), ptr_(std::move(lhs.ptr_)), n_(lhs.n_) {} + MemGuard &operator=(MemGuard &&lhs) noexcept { + if (this != &lhs) { + this->deallocate(); + n_ = lhs.n_; + alloc_ = std::move(lhs.alloc_); + ptr_ = std::move(lhs.ptr_); + } + return *this; + } + /// \brief Explicitly deallocate the memory if allocated + void deallocate() { + if (ptr_) { + auto *p = ptr_.release(); + if (!std::is_arithmetic::value && std::is_destructible::value) { + for (auto i = 0; i < n_; ++i) { + p[i].~T(); + } + } + alloc_.deallocate(p, n_); + n_ = 0; + } + } + /// \brief Allocate memory (with emplace feature). Previous one will be released. If size is 0, no new memory is + /// allocated. + /// \param n Number of objects of type T to be allocated + /// \tparam Args Extra arguments pass to the constructor of T + template + Status allocate(size_t n, Args &&... args) noexcept { + try { + deallocate(); + if (n > 0) { + T *data = alloc_.allocate(n); + if (!std::is_arithmetic::value) { + for (auto i = 0; i < n; i++) { + std::allocator_traits::construct(alloc_, &(data[i]), std::forward(args)...); + } + } + ptr_ = std::unique_ptr(data); + n_ = n; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory); + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); + } + ~MemGuard() noexcept { deallocate(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetPointer() const { return ptr_.get(); } + /// \brief Getter function + /// \return The pointer to the memory allocated + T *GetMutablePointer() { return ptr_.get(); } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) { return GetMutablePointer() + x; } + /// \brief Overload [] operator to access a particular element + /// \param x index to the element. Must be less than number of element allocated. + /// \return pointer to the x-th element + T *operator[](size_t x) const { return GetPointer() + x; } + /// \brief Return how many bytes are allocated in total + /// \return Number of bytes allocated in total + size_t GetSizeInBytes() const { return n_ * sizeof(T); } + + private: + allocator alloc_; + std::unique_ptr ptr_; + size_t n_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.cc b/mindspore/ccsrc/minddata/dataset/util/arena.cc new file mode 100644 index 0000000000..87a9c614a8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/arena.cc @@ -0,0 +1,256 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
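// Illustrative sketch (not from the patch): typical use of the Allocator and
// MemGuard defined above. The pool type here is an assumption -- SystemPool
// from util/system_pool.h as the malloc-backed MemoryPool; any MemoryPool
// implementation can be plugged in.
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/system_pool.h"

namespace mindspore {
namespace dataset {
Status AllocatorDemo() {
  std::shared_ptr<MemoryPool> pool = std::make_shared<SystemPool>();
  Allocator<int> alloc(pool);
  MemGuard<int, Allocator<int>> guard(alloc);
  // Reserve 16 ints from the pool (arithmetic types are left unconstructed).
  RETURN_IF_NOT_OK(guard.allocate(16));
  *guard[5] = 42;       // operator[] returns a pointer to the i-th element
  return Status::OK();  // guard's destructor returns the memory to the pool
}
}  // namespace dataset
}  // namespace mindspore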
+ */ +#include "minddata/dataset/util/arena.h" +#include +#include +#include "minddata/dataset/util/system_pool.h" +#include "./securec.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +struct MemHdr { + uint32_t sig; + uint64_t addr; + uint64_t blk_size; + MemHdr(uint64_t a, uint64_t sz) : sig(0xDEADBEEF), addr(a), blk_size(sz) {} + static void setHdr(void *p, uint64_t addr, uint64_t sz) { new (p) MemHdr(addr, sz); } + static void getHdr(void *p, MemHdr *hdr) { + auto *tmp = reinterpret_cast(p); + *hdr = *tmp; + } +}; +Status Arena::Init() { + RETURN_IF_NOT_OK(DeMalloc(size_in_MB_ * 1048576L, &ptr_, false)); + // Divide the memory into blocks. Ignore the last partial block. + uint64_t num_blks = size_in_bytes_ / ARENA_BLK_SZ; + MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << "."; + tr_.Insert(0, num_blks); + return Status::OK(); +} + +Status Arena::Allocate(size_t n, void **p) { + if (n == 0) { + *p = nullptr; + return Status::OK(); + } + std::unique_lock lck(mux_); + // Round up n to 1K block + uint64_t req_size = static_cast(n) + ARENA_WALL_OVERHEAD_SZ; + if (req_size > this->get_max_size()) { + return Status(StatusCode::kOutOfMemory); + } + uint64_t reqBlk = SizeToBlk(req_size); + // Do a first fit search + auto blk = tr_.Top(); + if (blk.second && reqBlk <= blk.first.priority) { + uint64_t addr = blk.first.key; + uint64_t size = blk.first.priority; + // Trim to the required size and return the rest to the tree. + tr_.Pop(); + if (size > reqBlk) { + tr_.Insert(addr + reqBlk, size - reqBlk); + } + lck.unlock(); + char *q = static_cast(ptr_) + addr * ARENA_BLK_SZ; + MemHdr::setHdr(q, addr, reqBlk); + *p = get_user_addr(q); + } else { + return Status(StatusCode::kOutOfMemory); + } + return Status::OK(); +} + +void Arena::Deallocate(void *p) { + auto *q = get_base_addr(p); + MemHdr hdr(0, 0); + MemHdr::getHdr(q, &hdr); + MS_ASSERT(hdr.sig == 0xDEADBEEF); + // We are going to insert a free block back to the treap. But first, check if we can combine + // with the free blocks before and after to form a bigger block. + std::unique_lock lck(mux_); + // Query if we have a free block after us. + auto nextBlk = tr_.Search(hdr.addr + hdr.blk_size); + if (nextBlk.second) { + // Form a bigger block + hdr.blk_size += nextBlk.first.priority; + tr_.DeleteKey(nextBlk.first.key); + } + // Next find a block in front of us. + auto result = FindPrevBlk(hdr.addr); + if (result.second) { + // We can combine with this block + hdr.addr = result.first.first; + hdr.blk_size += result.first.second; + tr_.DeleteKey(result.first.first); + } + // Now we can insert the free node + tr_.Insert(hdr.addr, hdr.blk_size); +} + +Status Arena::Reallocate(void **pp, size_t old_sz, size_t new_sz) { + MS_ASSERT(pp); + MS_ASSERT(*pp); + uint64_t actual_size = static_cast(new_sz) + ARENA_WALL_OVERHEAD_SZ; + if (actual_size > this->get_max_size()) { + RETURN_STATUS_UNEXPECTED("Request size too big : " + std::to_string(new_sz)); + } + uint64_t req_blk = SizeToBlk(actual_size); + char *oldAddr = reinterpret_cast(*pp); + auto *oldHdr = get_base_addr(oldAddr); + MemHdr hdr(0, 0); + MemHdr::getHdr(oldHdr, &hdr); + MS_ASSERT(hdr.sig == 0xDEADBEEF); + std::unique_lock lck(mux_); + if (hdr.blk_size > req_blk) { + // Refresh the header with the new smaller size. + MemHdr::setHdr(oldHdr, hdr.addr, req_blk); + // Return the unused memory back to the tree. Unlike allocate, we we need to merge with the block after us. 
+ auto next_blk = tr_.Search(hdr.addr + hdr.blk_size); + if (next_blk.second) { + hdr.blk_size += next_blk.first.priority; + tr_.DeleteKey(next_blk.first.key); + } + tr_.Insert(hdr.addr + req_blk, hdr.blk_size - req_blk); + } else if (hdr.blk_size < req_blk) { + uint64_t addr = hdr.addr; + // Attempt a block enlarge. No guarantee it is always successful. + bool success = BlockEnlarge(&addr, hdr.blk_size, req_blk); + if (success) { + auto *newHdr = static_cast(ptr_) + addr * ARENA_BLK_SZ; + MemHdr::setHdr(newHdr, addr, req_blk); + if (addr != hdr.addr) { + errno_t err = + memmove_s(get_user_addr(newHdr), (req_blk * ARENA_BLK_SZ) - ARENA_WALL_OVERHEAD_SZ, oldAddr, old_sz); + if (err) { + RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); + } + } + *pp = get_user_addr(newHdr); + return Status::OK(); + } + // If we reach here, allocate a new block and simply move the content from the old to the new place. + // Unlock since allocate will grab the lock again. + lck.unlock(); + return FreeAndAlloc(pp, old_sz, new_sz); + } + return Status::OK(); +} + +std::ostream &operator<<(std::ostream &os, const Arena &s) { + for (auto &it : s.tr_) { + os << "Address : " << it.key << ". Size : " << it.priority << "\n"; + } + return os; +} + +Arena::Arena(size_t val_in_MB) : ptr_(nullptr), size_in_MB_(val_in_MB), size_in_bytes_(val_in_MB * 1048576L) {} + +Status Arena::CreateArena(std::shared_ptr *p_ba, size_t val_in_MB) { + if (p_ba == nullptr) { + RETURN_STATUS_UNEXPECTED("p_ba is null"); + } + Status rc; + auto ba = new (std::nothrow) Arena(val_in_MB); + if (ba == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = ba->Init(); + if (rc.IsOk()) { + (*p_ba).reset(ba); + } else { + delete ba; + } + return rc; +} + +int Arena::PercentFree() const { + uint64_t sz = 0; + for (auto &it : tr_) { + sz += it.priority; + } + double ratio = static_cast(sz * ARENA_BLK_SZ) / static_cast(size_in_bytes_); + return static_cast(ratio * 100.0); +} + +uint64_t Arena::get_max_size() const { return (size_in_bytes_ - ARENA_WALL_OVERHEAD_SZ); } + +std::pair, bool> Arena::FindPrevBlk(uint64_t addr) { + for (auto &it : tr_) { + if (it.key + it.priority == addr) { + return std::make_pair(std::make_pair(it.key, it.priority), true); + } else if (it.key > addr) { + break; + } + } + return std::make_pair(std::make_pair(0, 0), false); +} + +bool Arena::BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz) { + uint64_t size = old_sz; + // The logic is very much identical to Deallocate. We will see if we can combine with the blocks before and after. + auto next_blk = tr_.Search(*addr + old_sz); + if (next_blk.second) { + size += next_blk.first.priority; + if (size >= new_sz) { + // In this case, we can just enlarge the block without doing any moving. + tr_.DeleteKey(next_blk.first.key); + // Return unused back to the tree. + if (size > new_sz) { + tr_.Insert(*addr + new_sz, size - new_sz); + } + } + return true; + } + // If we still get here, we have to look at the block before us. + auto result = FindPrevBlk(*addr); + if (result.second) { + // We can combine with this block together with the next block (if any) + size += result.first.second; + *addr = result.first.first; + if (size >= new_sz) { + // We can combine with this block together with the next block (if any) + tr_.DeleteKey(*addr); + if (next_blk.second) { + tr_.DeleteKey(next_blk.first.key); + } + // Return unused back to the tree. 
+ if (size > new_sz) { + tr_.Insert(*addr + new_sz, size - new_sz); + } + return true; + } + } + return false; +} + +Status Arena::FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz) { + MS_ASSERT(pp); + MS_ASSERT(*pp); + void *p = nullptr; + void *q = *pp; + RETURN_IF_NOT_OK(Allocate(new_sz, &p)); + errno_t err = memmove_s(p, new_sz, q, old_sz); + if (err) { + RETURN_STATUS_UNEXPECTED("Error from memmove: " + std::to_string(err)); + } + *pp = p; + // Free the old one. + Deallocate(q); + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/arena.h b/mindspore/ccsrc/minddata/dataset/util/arena.h new file mode 100644 index 0000000000..8887757af1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/arena.h @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_ARENA_H_ +#define DATASET_UTIL_ARENA_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/treap.h" + +#define ARENA_LOG_BLK_SZ (6u) +#define ARENA_BLK_SZ (static_cast(1u << ARENA_LOG_BLK_SZ)) +#define ARENA_WALL_OVERHEAD_SZ 32 +namespace mindspore { +namespace dataset { +// This is a memory arena based on a treap data structure. +// The constructor of the Arena takes the size of the initial memory size (in MB). +// Internally we divide the memory into multiple blocks. Each block is 64 bytes. +// The treap contains all the free blocks with the relative memory address as key +// and the size of the block as priority. +// +// Initially the treap has only one root which is the whole memory piece. +// +// For memory suballocation, we pop the root node of the treap which contains the largest free block. +// We allocate what we need and return the rest back to the treap. We search for the first fit instead +// of the best fit so to give us a constant time in memory allocation. +// +// When a block of memory is freed. It is joined with the blocks before and after (if they are available) to +// form a bigger block. 
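// Illustrative sketch (not from the patch) of the interface declared below
// (a 4 MB arena for the example; the default is 4096 MB):
//
//   Status ArenaDemo() {
//     std::shared_ptr<Arena> arena;
//     RETURN_IF_NOT_OK(Arena::CreateArena(&arena, 4));  // grab 4 MB up front
//     void *p = nullptr;
//     RETURN_IF_NOT_OK(arena->Allocate(1024, &p));          // first fit, 64 B blocks
//     RETURN_IF_NOT_OK(arena->Reallocate(&p, 1024, 4096));  // may move contents
//     arena->Deallocate(p);  // freed block is merged with free neighbours
//     return Status::OK();
//   }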
+class Arena : public MemoryPool { + public: + Arena(const Arena &) = delete; + + Arena &operator=(const Arena &) = delete; + + ~Arena() override { + if (ptr_ != nullptr) { + free(ptr_); + ptr_ = nullptr; + } + } + + Status Allocate(size_t n, void **p) override; + + Status Reallocate(void **, size_t old_sz, size_t new_sz) override; + + void Deallocate(void *) override; + + uint64_t get_max_size() const override; + + static uint64_t SizeToBlk(uint64_t sz) { + uint64_t req_blk = sz / ARENA_BLK_SZ; + if (sz % ARENA_BLK_SZ) { + ++req_blk; + } + return req_blk; + } + + int PercentFree() const override; + + const void *get_base_addr() const { return ptr_; } + + friend std::ostream &operator<<(std::ostream &os, const Arena &s); + + static Status CreateArena(std::shared_ptr *p_ba, size_t val_in_MB = 4096); + + private: + std::mutex mux_; + Treap tr_; + void *ptr_; + size_t size_in_MB_; + size_t size_in_bytes_; + + explicit Arena(size_t val_in_MB = 4096); + + std::pair, bool> FindPrevBlk(uint64_t addr); + + Status Init(); + + bool BlockEnlarge(uint64_t *addr, uint64_t old_sz, uint64_t new_sz); + + Status FreeAndAlloc(void **pp, size_t old_sz, size_t new_sz); + + void *get_user_addr(void *base_addr) const { return reinterpret_cast(base_addr) + ARENA_WALL_OVERHEAD_SZ; } + + void *get_base_addr(void *user_addr) const { return reinterpret_cast(user_addr) - ARENA_WALL_OVERHEAD_SZ; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_ARENA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/auto_index.h b/mindspore/ccsrc/minddata/dataset/util/auto_index.h new file mode 100644 index 0000000000..0fe55159e6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/auto_index.h @@ -0,0 +1,99 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_AUTO_INDEX_H_ +#define DATASET_UTIL_AUTO_INDEX_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/system_pool.h" + +namespace mindspore { +namespace dataset { +/// This is a B+ tree with generated int64_t value as key. +/// Use minKey() function to query the min key. +/// Use maxKey() function to query the max key. +/// @tparam T +template > +class AutoIndexObj : public BPlusTree { + public: + using my_tree = BPlusTree; + using key_type = typename my_tree::key_type; + using value_type = typename my_tree::value_type; + + AutoIndexObj() : my_tree::BPlusTree(), inx_(kMinKey) {} + + explicit AutoIndexObj(const Allocator &alloc) : my_tree::BPlusTree(alloc), inx_(kMinKey) {} + + ~AutoIndexObj() = default; + + // Insert an object into the tree. 
+ // @param val + // @return + Status insert(const value_type &val, key_type *key = nullptr) { + key_type my_inx = inx_.fetch_add(1); + if (key != nullptr) { + *key = my_inx; + } + return my_tree::DoInsert(my_inx, val); + } + + Status insert(std::unique_ptr &&val, key_type *key = nullptr) { + key_type my_inx = inx_.fetch_add(1); + if (key) { + *key = my_inx; + } + return my_tree::DoInsert(my_inx, std::move(val)); + } + + // Insert a vector of objects into the tree. + // @param v + // @return + Status insert(std::vector v) { + uint64_t num_ele = v.size(); + if (num_ele > 0) { + // reserve a range of keys rather than getting it one by one. + key_type my_inx = inx_.fetch_add(num_ele); + for (uint64_t i = 0; i < num_ele; i++) { + RETURN_IF_NOT_OK(my_tree::DoInsert(my_inx + i, v.at(i))); + } + } + return Status::OK(); + } + + // @return the minimum key + key_type min_key() const { + auto it = this->cbegin(); + return it.key(); + } + + // @return the maximum key + key_type max_key() const { + auto it = this->cend(); + --it; + return it.key(); + } + + private: + static constexpr key_type kMinKey = 0; + std::atomic inx_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_AUTO_INDEX_H_ diff --git a/mindspore/ccsrc/dataset/util/bit.h b/mindspore/ccsrc/minddata/dataset/util/bit.h similarity index 100% rename from mindspore/ccsrc/dataset/util/bit.h rename to mindspore/ccsrc/minddata/dataset/util/bit.h diff --git a/mindspore/ccsrc/minddata/dataset/util/btree.h b/mindspore/ccsrc/minddata/dataset/util/btree.h new file mode 100644 index 0000000000..828976a0a1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/btree.h @@ -0,0 +1,459 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INDEX_H_ +#define DATASET_UTIL_INDEX_H_ + +#include +#include +#include +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/list.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Default traits for a B+ tree +struct BPlusTreeTraits { + // This determines the limit of number of keys in a node. + using slot_type = uint16_t; + // Number of slots in each leaf of the tree. 
+ static constexpr slot_type kLeafSlots = 256; + // Number of slots in each inner node of the tree + static constexpr slot_type kInnerSlots = 128; +}; + +/// Implementation of B+ tree +/// @tparam K -- the type of key +/// @tparam V -- the type of value +/// @tparam A -- allocator +/// @tparam C -- comparison class +/// @tparam T -- trait +template , typename C = std::less, + typename T = BPlusTreeTraits> +class BPlusTree { + public: + enum class IndexRc : char { + kOk = 0, + kDuplicateKey = 1, + kSlotFull = 2, + kKeyNotFound = 3, + kNullPointer = 4, + kOutOfMemory = 5, + kRetry = 6, + kUnexpectedError = 127 + }; +#define RETURN_IF_BAD_RC(_s) \ + do { \ + IndexRc __rc = (_s); \ + if (__rc != IndexRc::kOk) { \ + return __rc; \ + } \ + } while (false) + + Status IndexRc2Status(IndexRc rc) { + if (rc == IndexRc::kOk) { + return Status(StatusCode::kOK); + } else if (rc == IndexRc::kOutOfMemory) { + return Status(StatusCode::kOutOfMemory); + } else if (rc == IndexRc::kDuplicateKey) { + return Status(StatusCode::kDuplicateKey); + } else { + RETURN_STATUS_UNEXPECTED(std::to_string(static_cast(rc))); + } + } + + using key_type = K; + using value_type = V; + using key_compare = C; + using slot_type = typename T::slot_type; + using traits = T; + using value_allocator = A; + using key_allocator = typename value_allocator::template rebind::other; + using slot_allocator = typename value_allocator::template rebind::other; + + BPlusTree(); + + explicit BPlusTree(const Allocator &alloc); + + ~BPlusTree() noexcept; + + BPlusTree(const BPlusTree &) = delete; + + BPlusTree(BPlusTree &&) = delete; + + BPlusTree &operator=(const BPlusTree &) = delete; + + BPlusTree &operator=(BPlusTree &&) = delete; + + key_compare key_comp() const { return key_less_; } + + size_t size() const { return stats_.size_; } + + bool empty() const { return (size() == 0); } + + /// @param key + /// @param value + /// @return + Status DoInsert(const key_type &key, const value_type &value); + Status DoInsert(const key_type &key, std::unique_ptr &&value); + + // Update a new value for a given key. + std::unique_ptr DoUpdate(const key_type &key, const value_type &new_value); + std::unique_ptr DoUpdate(const key_type &key, std::unique_ptr &&new_value); + + // Statistics + struct tree_stats { + std::atomic size_; + uint32_t leaves_; + uint32_t inner_nodes_; + uint32_t level_; + + tree_stats() : size_(0), leaves_(0), inner_nodes_(0), level_(0) {} + }; + + private: + // Abstract class of a node (leaf or inner) + class BaseNode { + public: + friend class BPlusTree; + + virtual bool is_leafnode() const = 0; + + virtual bool is_full() const = 0; + + explicit BaseNode(const value_allocator &alloc) : alloc_(alloc) {} + + virtual ~BaseNode() = default; + + protected: + mutable RWLock rw_lock_; + value_allocator alloc_; + + private: + Node lru_; + }; + + // This control block keeps track of all the nodes we traverse on insert. + // To maximize concurrency, internal nodes are latched S. If a node split + // is required, we must releases all the latches and redo it again and change + // the latch mode from S to X. 
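// Hand-over-hand ("crabbing") latching in miniature, independent of the tree
// code below (illustrative sketch, not from the patch): hold the parent latch
// only until the child latch is acquired. LockPathCB generalizes this by
// remembering the whole latched path, so parents are released later via
// UnlockMyParents() once a child is known not to need a split.
#include <shared_mutex>

struct ToyNode {
  std::shared_mutex latch;
  ToyNode *child = nullptr;
};

ToyNode *DescendShared(ToyNode *root) {
  root->latch.lock_shared();
  ToyNode *cur = root;
  while (cur->child != nullptr) {
    cur->child->latch.lock_shared();  // latch the child first...
    cur->latch.unlock_shared();       // ...then let go of the parent
    cur = cur->child;
  }
  return cur;  // comes back still latched in shared mode
}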
+ struct LockPathCB { + enum class LockMode : char { kShared = 0, kExclusive = 1, kNone = 2 }; + + struct path { + BaseNode *node_; + bool locked_; + + path() : node_(nullptr), locked_(false) {} + + path(BaseNode *p, LockMode lockmode) : node_(p), locked_(false) { + if (lockmode == LockMode::kExclusive) { + p->rw_lock_.LockExclusive(); + locked_ = true; + } else if (lockmode == LockMode::kShared) { + p->rw_lock_.LockShared(); + locked_ = true; + } + } + }; + + LockPathCB(BPlusTree *tree, bool retryWithXlock) : self_(tree), latch_shared_(true) { + if (retryWithXlock) { + latch_shared_ = false; + } + if (latch_shared_) { + tree->rw_lock_.LockShared(); + } else { + tree->rw_lock_.LockExclusive(); + } + } + + ~LockPathCB() noexcept { + // Make sure all locks are released. + while (!paths_.empty()) { + path p = paths_.back(); + paths_.pop_back(); + if (p.locked_) { + p.node_->rw_lock_.Unlock(); + } + } + self_->rw_lock_.Unlock(); + self_ = nullptr; + } + + void LockNode(BaseNode *p, LockMode locktype) { paths_.emplace_back(p, locktype); } + + void UnlockMyParents(BaseNode *me) { + path p = paths_.front(); + while (p.node_ != me) { + if (p.locked_) { + p.node_->rw_lock_.Unlock(); + } + paths_.pop_front(); + p = paths_.front(); + } + } + + BPlusTree *self_; + std::deque paths_; + bool latch_shared_; + }; + + // Definition of inner node which fans to either inner node or leaf node. + class InnerNode : public BaseNode { + public: + friend class BPlusTree; + + using alloc_type = typename value_allocator::template rebind::other; + + bool is_leafnode() const override { return false; } + + bool is_full() const override { return (slotuse_ == traits::kInnerSlots); } + + IndexRc Sort(); + + // 50/50 split + IndexRc Split(InnerNode *to, key_type *split_key); + + IndexRc InsertIntoSlot(slot_type slot, const key_type &key, BaseNode *ptr); + + explicit InnerNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} + + ~InnerNode() = default; + + slot_type slot_dir_[traits::kInnerSlots] = {0}; + key_type keys_[traits::kInnerSlots] = {0}; + BaseNode *data_[traits::kInnerSlots + 1] = {nullptr}; + slot_type slotuse_; + }; + + // Definition of a leaf node which contains the key/value pair + class LeafNode : public BaseNode { + public: + friend class BPlusTree; + + using alloc_type = typename value_allocator::template rebind::other; + Node link_; + + bool is_leafnode() const override { return true; } + + bool is_full() const override { return (slotuse_ == traits::kLeafSlots); } + + IndexRc Sort(); + + // 50/50 split + IndexRc Split(LeafNode *to); + + IndexRc InsertIntoSlot(LockPathCB *insCB, slot_type slot, const key_type &key, std::unique_ptr &&value); + + explicit LeafNode(const value_allocator &alloc) : BaseNode::BaseNode(alloc), slotuse_(0) {} + + ~LeafNode() = default; + + slot_type slot_dir_[traits::kLeafSlots] = {0}; + key_type keys_[traits::kLeafSlots] = {0}; + std::unique_ptr data_[traits::kLeafSlots]; + slot_type slotuse_; + }; + + mutable RWLock rw_lock_; + value_allocator alloc_; + // All the leaf nodes. Used by the iterator to traverse all the key/values. + List leaf_nodes_; + // All the nodes (inner + leaf). Used by the destructor to free the memory of all the nodes. + List all_; + // Pointer to the root of the tree. 
+ BaseNode *root_; + // Key comparison object + key_compare key_less_; + // Stat + tree_stats stats_; + + bool LessThan(const key_type &a, const key_type &b) const { return key_less_(a, b); } + + bool EqualOrLessThan(const key_type &a, const key_type &b) const { return !key_less_(b, a); } + + bool Equal(const key_type &a, const key_type &b) const { return !key_less_(a, b) && !key_less_(b, a); } + + IndexRc AllocateInner(InnerNode **p); + + IndexRc AllocateLeaf(LeafNode **p); + + template + slot_type FindSlot(const node_type *node, const key_type &key, bool *duplicate = nullptr) const { + slot_type lo = 0; + while (lo < node->slotuse_ && key_comp()(node->keys_[node->slot_dir_[lo]], key)) { + ++lo; + } + bool keymatch = (lo < node->slotuse_ && Equal(key, node->keys_[node->slot_dir_[lo]])); + if (keymatch && !node->is_leafnode()) { + // For an inner node and we match a key during search, we should look into the next slot. + ++lo; + } + if (duplicate != nullptr) { + *duplicate = keymatch; + } + return lo; + } + + IndexRc LeafInsertKeyValue(LockPathCB *ins_cb, LeafNode *node, const key_type &key, + std::unique_ptr &&value, key_type *split_key, LeafNode **split_node); + + IndexRc InnerInsertKeyChild(InnerNode *node, const key_type &key, BaseNode *ptr, key_type *split_key, + InnerNode **split_node); + + inline BaseNode *FindBranch(InnerNode *inner, slot_type slot) const { + BaseNode *child = nullptr; + if (slot == 0) { + child = inner->data_[0]; + } else { + child = inner->data_[inner->slot_dir_[slot - 1] + 1]; + } + return child; + } + + IndexRc InsertKeyValue(LockPathCB *ins_cb, BaseNode *n, const key_type &key, std::unique_ptr &&value, + key_type *split_key, BaseNode **split_node); + + IndexRc Locate(RWLock *parent_lock, bool forUpdate, BaseNode *top, const key_type &key, LeafNode **ln, + slot_type *s) const; + + public: + class Iterator : public std::iterator { + public: + using reference = BPlusTree::value_type &; + using pointer = BPlusTree::value_type *; + + explicit Iterator(BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + Iterator(LeafNode *leaf, slot_type slot, bool locked = false) : cur_(leaf), slot_(slot), locked_(locked) {} + + ~Iterator(); + + explicit Iterator(const Iterator &); + + Iterator &operator=(const Iterator &lhs); + + Iterator(Iterator &&); + + Iterator &operator=(Iterator &&lhs); + + pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } + + reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } + + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + // Prefix++ + Iterator &operator++(); + + // Postfix++ + Iterator operator++(int); + + // Prefix-- + Iterator &operator--(); + + // Postfix-- + Iterator operator--(int); + + bool operator==(const Iterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } + bool operator!=(const Iterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } + + private: + typename BPlusTree::LeafNode *cur_; + slot_type slot_; + bool locked_; + }; + + class ConstIterator : public std::iterator { + public: + using reference = BPlusTree::value_type &; + using pointer = BPlusTree::value_type *; + + explicit ConstIterator(const BPlusTree *btree) : cur_(btree->leaf_nodes_.head), slot_(0), locked_(false) {} + + ~ConstIterator(); + + ConstIterator(const LeafNode *leaf, slot_type slot, bool locked = false) + : cur_(leaf), 
slot_(slot), locked_(locked) {} + + explicit ConstIterator(const ConstIterator &); + + ConstIterator &operator=(const ConstIterator &lhs); + + ConstIterator(ConstIterator &&); + + ConstIterator &operator=(ConstIterator &&lhs); + + pointer operator->() const { return cur_->data_[cur_->slot_dir_[slot_]].get(); } + + reference operator*() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + const key_type &key() const { return cur_->keys_[cur_->slot_dir_[slot_]]; } + + value_type &value() const { return *(cur_->data_[cur_->slot_dir_[slot_]].get()); } + + // Prefix++ + ConstIterator &operator++(); + + // Postfix++ + ConstIterator operator++(int); + + // Prefix-- + ConstIterator &operator--(); + + // Postfix-- + ConstIterator operator--(int); + + bool operator==(const ConstIterator &x) const { return (x.cur_ == cur_) && (x.slot_ == slot_); } + bool operator!=(const ConstIterator &x) const { return (x.cur_ != cur_) || (x.slot_ != slot_); } + + private: + const typename BPlusTree::LeafNode *cur_; + slot_type slot_; + bool locked_; + }; + + Iterator begin(); + Iterator end(); + + ConstIterator begin() const; + ConstIterator end() const; + + ConstIterator cbegin() const; + ConstIterator cend() const; + + // Locate the entry with key + std::pair Search(const key_type &key) const; + std::pair Search(const key_type &key); + + value_type operator[](key_type key); +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INDEX_H_ + +#include "btree_impl.tpp" +#include "btree_iterator.tpp" diff --git a/mindspore/ccsrc/dataset/util/btree_impl.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp similarity index 100% rename from mindspore/ccsrc/dataset/util/btree_impl.tpp rename to mindspore/ccsrc/minddata/dataset/util/btree_impl.tpp diff --git a/mindspore/ccsrc/dataset/util/btree_iterator.tpp b/mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp similarity index 100% rename from mindspore/ccsrc/dataset/util/btree_iterator.tpp rename to mindspore/ccsrc/minddata/dataset/util/btree_iterator.tpp diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.cc b/mindspore/ccsrc/minddata/dataset/util/buddy.cc new file mode 100644 index 0000000000..d4f5434f81 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.cc @@ -0,0 +1,388 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
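// Illustrative sketch (not from the patch): how the two containers above fit
// together. AutoIndexObj (auto_index.h) layers auto-generated int64_t keys on
// top of the BPlusTree, giving an ordered, concurrently insertable sequence
// of values; signatures below are read off the declarations above.
#include <string>
#include "minddata/dataset/util/auto_index.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
Status BTreeDemo() {
  AutoIndexObj<std::string> index;
  RETURN_IF_NOT_OK(index.insert("first"));      // gets key 0
  RETURN_IF_NOT_OK(index.insert("second"));     // gets key 1
  int64_t k = 0;
  RETURN_IF_NOT_OK(index.insert("third", &k));  // k == 2
  // Iteration visits keys in ascending order: 0, 1, 2.
  for (auto it = index.begin(); it != index.end(); ++it) {
    MS_LOG(DEBUG) << it.key() << " -> " << it.value();
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore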
+ */ +#include "minddata/dataset/util/buddy.h" +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/system_pool.h" +#include "utils/log_adapter.h" +#include "./securec.h" + +inline uint64_t BitLeftShift(uint64_t v, uint64_t n) { return (v << n); } + +inline uint64_t BitRightShift(uint64_t v, uint64_t n) { return (v >> n); } + +inline uint64_t BitOr(uint64_t rhs, uint64_t lhs) { return rhs | lhs; } + +inline uint64_t BitEx(uint64_t rhs, uint64_t lhs) { return rhs ^ lhs; } + +inline uint64_t BitAnd(uint64_t rhs, uint64_t lhs) { return rhs & lhs; } + +namespace mindspore { +namespace dataset { +Status BuddySpace::Init() { + if (log_min_ < 0) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "log_min must be positive : " + std::to_string(log_min_)); + } + if (num_lvl_ < 3 || num_lvl_ > 18) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, + "num_lvl must be between 3 and 18 : " + std::to_string(num_lvl_)); + } + min_ = BitLeftShift(1, log_min_); + max_ = BitLeftShift(1, log_min_ + num_lvl_ - 1); + size_t offset_1 = sizeof(rel_addr_t) * num_lvl_; + size_t offset_2 = sizeof(int) * num_lvl_ + offset_1; + size_t offset_3 = sizeof(char) * BitLeftShift(1, num_lvl_ - 3) + offset_2; + RETURN_IF_NOT_OK(DeMalloc(offset_3, &ptr_, true)); + hint_ = reinterpret_cast(ptr_); + count_ = reinterpret_cast((reinterpret_cast(ptr_) + offset_1)); + map_ = reinterpret_cast(ptr_) + offset_2; + count_[num_lvl_ - 1] = 1; + map_[0] = BitOr(MORE_BIT, num_lvl_ - 3); + return Status::OK(); +} + +Status BuddySpace::Alloc(const uint64_t sz, BSpaceDescriptor *desc, addr_t *p) noexcept { + std::lock_guard lock(mutex_); + addr_t addr = AllocNoLock(sz, desc); + if (addr != NOSPACE) { + *p = addr; + return Status::OK(); + } else { + return Status(StatusCode::kNoSpace, "BuddySpace full. Not an error. 
Please ignore."); + } +} + +addr_t BuddySpace::AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept { + MS_ASSERT(sz <= max_); + uint32_t reqSize = SizeToBlock(sz); + rel_addr_t rel_addr = AllocBuddySeg(reqSize); + if (rel_addr != static_cast(NOSPACE)) { + (void)memset_s(desc, sizeof(BSpaceDescriptor), 0, sizeof(BSpaceDescriptor)); + desc->sig = static_cast(0xDEADBEEF); + desc->addr = rel_addr; + desc->req_size = reqSize; + desc->blk_size = NextPowerOf2(reqSize); + return static_cast(rel_addr * min_); + } else { + return NOSPACE; + } +} + +void BuddySpace::FreeNoLock(const BSpaceDescriptor *desc) { + MS_ASSERT(desc->sig == 0XDEADBEEF); + rel_addr_t rel_addr = desc->addr; + size_t blk_size = desc->blk_size; + size_t req_size = desc->req_size; + FreeBuddySeg(rel_addr, blk_size, req_size); +} + +void BuddySpace::Free(const BSpaceDescriptor *desc) { + std::lock_guard lock(mutex_); + return FreeNoLock(desc); +} + +std::ostream &operator<<(std::ostream &os, const BuddySpace &s) { + os << "1 unit = " << s.GetMinSize() << "\n" + << "Size of buddy space = " << s.GetMaxSize() << "\n" + << "Number of levels = " << s.num_lvl_ << "\n\n" + << "Percent free = " << s.PercentFree() << "\n" + << "Dumping count array : " + << "\n"; + for (int i = 0; i < s.num_lvl_; i++) { + os << "[" << i << "] = " << s.count_[i] << " "; + if (((i + 1) % 4) == 0) { + os << "\n"; + } + } + os << "\n"; + os << "Dumping allocation info:" + << "\n"; + auto max_addr = static_cast(BitLeftShift(1, s.num_lvl_ - 1)); + rel_addr_t addr = 0; + while (addr < max_addr) { + size_t sz = 0; + BuddySpace::STATE st; + s.GetBuddySegState(addr, &sz, &st); + os << "Address : " << std::left << std::setw(8) << addr << " Size : " << std::setw(8) << sz << " State : " + << ((st == BuddySpace::STATE::kAlloc) ? "ALLOC" : ((st == BuddySpace::STATE::kFree) ? "FREE" : "Unkonwn")) + << "\n"; + addr += sz; + } + return os; +} + +void BuddySpace::GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const { + char byte; + int pos; + int offset; + uint64_t val = 0; + int shift; + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + byte = map_[pos]; + switch (offset) { + case 0: + val = byte; + break; + case 1: + case 3: + if (offset == 1) { + val = BitLeftShift(BitAnd(byte, 0x30), shift); + } else { + val = BitLeftShift(BitAnd(byte, 0x03), shift); + } + break; + case 2: + val = BitLeftShift(BitAnd(byte, 0x0F), shift); + break; + } + if (BitAnd(val, ONE_BIT)) { + *rel_sz = 1; + } else if (BitAnd(val, TWO_BIT)) { + *rel_sz = 2; + } else if (BitAnd(val, MORE_BIT)) { + log_t lg = BitAnd(val, 0x0F); + *rel_sz = BitLeftShift(1, lg + 2); + } else { + *st = STATE::kEmpty; + return; + } + *st = BitAnd(val, ALLOC_BIT) ? 
STATE::kAlloc : STATE::kFree; +} + +void BuddySpace::SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st) { + int clr; + int mask; + int pos; + int offset; + int val = 0; + int shift; + auto log_sz = static_cast(Log2(rel_sz)); + pos = BitRightShift(rel_addr, 2); + offset = rel_addr % 4; + shift = offset * 2; + if (rel_sz == 1) { + val = ONE_BIT; + mask = 0xC0; + } else if (rel_sz == 2) { + val = TWO_BIT; + mask = 0xF0; + } else { + val = BitOr(log_sz - 2, MORE_BIT); + mask = 0xFF; + } + if (st == STATE::kAlloc) { + val = BitOr(val, ALLOC_BIT); + } else if (st == STATE::kFree) { + val = BitAnd(val, ~(static_cast(ALLOC_BIT))); + } else if (st == STATE::kEmpty) { + val = 0; + } + clr = static_cast(~(BitRightShift(mask, shift))); + map_[pos] = static_cast(BitAnd(map_[pos], clr)); + map_[pos] = static_cast(BitOr(map_[pos], BitRightShift(val, shift))); + if (st == STATE::kAlloc) { + count_[log_sz]--; + } else if (st == STATE::kFree) { + count_[log_sz]++; + if (rel_addr < hint_[log_sz]) { + hint_[log_sz] = rel_addr; + } + } +} + +void BuddySpace::JoinBuddySeg(rel_addr_t addr, size_t blk_sz) { + while (blk_sz < BitLeftShift(1, num_lvl_)) { + rel_addr_t buddy = BitEx(addr, blk_sz); + size_t sz = 0; + STATE st; + GetBuddySegState(buddy, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + auto log_sz = static_cast(Log2(blk_sz)); + rel_addr_t left = (buddy < addr) ? buddy : addr; + rel_addr_t right = left + blk_sz; + MS_ASSERT(count_[log_sz] >= 2); + count_[log_sz] -= 2; + SetBuddySegState(right, blk_sz, STATE::kEmpty); + SetBuddySegState(left, BitLeftShift(blk_sz, 1), STATE::kFree); + for (int i = 0; i < log_sz; i++) { + if (hint_[i] == right) { + hint_[i] = left; + } + } + addr = left; + blk_sz <<= 1u; + } else { + break; + } + } +} + +void BuddySpace::TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + MS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + count_[i]--; + SetBuddySegState(addr, half_sz, STATE::kFree); + SetBuddySegState(addr + half_sz, half_sz, STATE::kFree); + if (remaining_sz >= half_sz) { + SetBuddySegState(addr, half_sz, STATE::kAlloc); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + break; + } + addr += half_sz; + } + } +} + +void BuddySpace::UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz) { + MS_ASSERT(ask_sz < blk_sz); + uint32_t inx = Log2(blk_sz); + size_t remaining_sz = ask_sz; + for (int i = inx; i > 0; i--) { + size_t b_size = BitLeftShift(1, i); + size_t half_sz = BitRightShift(b_size, 1); + if (remaining_sz >= half_sz) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + MS_ASSERT(sz == half_sz && st == STATE::kAlloc); + } +#endif + SetBuddySegState(addr, half_sz, STATE::kFree); + remaining_sz -= half_sz; + if (remaining_sz == 0) { + JoinBuddySeg(addr, half_sz); + break; + } + addr += half_sz; + } + } +} + +rel_addr_t BuddySpace::AllocBuddySeg(uint32_t req_size) noexcept { + uint32_t blk_size = NextPowerOf2(req_size); + int start_inx = static_cast(Log2(blk_size)); + bool found = false; + rel_addr_t ask_addr = 0; + auto max_addr = static_cast(BitLeftShift(1, num_lvl_ - 1)); + STATE st; + size_t sz = 0; + for (int i = start_inx; !found && i < num_lvl_; i++) { + MS_ASSERT(count_[i] >= 0); + if (count_[i] == 0) { + continue; + } + auto blk_sz = static_cast(BitLeftShift(1, i)); + ask_addr = hint_[i]; + while (ask_addr < max_addr && !found) 
{ + GetBuddySegState(ask_addr, &sz, &st); + if (st == STATE::kFree && sz == blk_sz) { + found = true; + } else { + MS_ASSERT(st != STATE::kEmpty); + ask_addr += ((sz > blk_sz) ? sz : blk_sz); + } + } + } + if (found) { + if (sz > req_size) { + TrimBuddySeg(ask_addr, sz, req_size); + } else { + SetBuddySegState(ask_addr, sz, STATE::kAlloc); + hint_[start_inx] = ask_addr; + } + return ask_addr; + } else { + return static_cast(NOSPACE); + } +} + +void BuddySpace::FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size) { + if (req_size == blk_size) { +#ifdef DEBUG + { + size_t sz = 0; + STATE st; + GetBuddySegState(addr, &sz, &st); + } +#endif + SetBuddySegState(addr, blk_size, STATE::kFree); + JoinBuddySeg(addr, blk_size); + } else { + UnTrimBuddySeg(addr, blk_size, req_size); + } +} + +int BuddySpace::PercentFree() const { + uint64_t total_free_sz = 0; + uint64_t max_sz_in_unit = BitLeftShift(1, num_lvl_ - 1); + // Go through the count array without lock + for (int i = 0; i < num_lvl_; i++) { + int cnt = count_[i]; + if (cnt == 0) { + continue; + } + uint64_t blk_sz = BitLeftShift(1, i); + total_free_sz += (blk_sz * cnt); + } + return static_cast(static_cast(total_free_sz) / static_cast(max_sz_in_unit) * 100); +} + +BuddySpace::BuddySpace(int log_min, int num_lvl) + : hint_(nullptr), + count_(nullptr), + map_(nullptr), + log_min_(log_min), + num_lvl_(num_lvl), + min_(0), + max_(0), + ptr_(nullptr) {} + +BuddySpace::~BuddySpace() { + if (ptr_ != nullptr) { + free(ptr_); + } + hint_ = nullptr; + count_ = nullptr; + map_ = nullptr; +} + +Status BuddySpace::CreateBuddySpace(std::unique_ptr *out_bs, int log_min, int num_lvl) { + Status rc; + auto bs = new (std::nothrow) BuddySpace(log_min, num_lvl); + if (bs == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = bs->Init(); + if (rc.IsOk()) { + (*out_bs).reset(bs); + } else { + delete bs; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/buddy.h b/mindspore/ccsrc/minddata/dataset/util/buddy.h new file mode 100644 index 0000000000..b1bcd3ce41 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/buddy.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_BUDDY_H_ +#define DATASET_UTIL_BUDDY_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/status.h" + +using addr_t = int64_t; +using rel_addr_t = int32_t; +using log_t = int; +#define ALLOC_BIT 0x80 +#define ONE_BIT 0x40 +#define TWO_BIT 0x20 +#define MORE_BIT 0x10 +#define NOSPACE ((addr_t)(-1)) +namespace mindspore { +namespace dataset { +struct BSpaceDescriptor { + int32_t sig; + rel_addr_t addr; + size_t req_size; + size_t blk_size; +}; + +class BuddySpace { + public: + // C++11 feature. Change STATE into a type safe class with + // the keyword. 
Don't take out the keyword 'class' + enum class STATE { kFree, kAlloc, kEmpty }; + + BuddySpace(const BuddySpace &) = delete; + + BuddySpace &operator=(const BuddySpace &) = delete; + + virtual ~BuddySpace(); + + Status Alloc(uint64_t sz, BSpaceDescriptor *desc, addr_t *) noexcept; + + void Free(const BSpaceDescriptor *desc); + + uint64_t GetMinSize() const { return min_; } + + uint64_t GetMaxSize() const { return max_; } + + int PercentFree() const; + + friend std::ostream &operator<<(std::ostream &os, const BuddySpace &s); + + static uint64_t NextPowerOf2(uint64_t n) { + if (n <= 1) { + return 1; + } + n = n - 1; + while (n & (n - 1)) { + n = n & (n - 1); + } + return n << 1; + } + + static uint32_t Log2(uint64_t n) { + uint32_t cnt = 0; + while (n >>= 1) { + cnt++; + } + return cnt; + } + + static Status CreateBuddySpace(std::unique_ptr *out_bs, int log_min = 15, int num_lvl = 18); + + private: + rel_addr_t *hint_; + int *count_; + char *map_; + int log_min_; + int num_lvl_; + uint64_t min_; + uint64_t max_; + void *ptr_; + std::mutex mutex_; + + explicit BuddySpace(int log_min = 15, int num_lvl = 18); + + Status Init(); + + addr_t AllocNoLock(const uint64_t sz, BSpaceDescriptor *desc) noexcept; + + void FreeNoLock(const BSpaceDescriptor *desc); + + uint32_t SizeToBlock(const uint64_t sz) const { + uint32_t reqSize = (sz / min_); + if (sz % min_) { + reqSize++; + } + return reqSize; + } + + void GetBuddySegState(const rel_addr_t rel_addr, size_t *rel_sz, STATE *st) const; + + void SetBuddySegState(rel_addr_t rel_addr, size_t rel_sz, STATE st); + + void JoinBuddySeg(rel_addr_t addr, size_t blk_sz); + + void TrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + void UnTrimBuddySeg(rel_addr_t addr, size_t blk_sz, size_t ask_sz); + + rel_addr_t AllocBuddySeg(uint32_t req_size) noexcept; + + void FreeBuddySeg(rel_addr_t addr, size_t blk_size, size_t req_size); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_BUDDY_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc new file mode 100644 index 0000000000..22fb72eb8a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.cc @@ -0,0 +1,197 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include "common/utils.h" +#include "minddata/dataset/util/cache_pool.h" +#include "minddata/dataset/util/services.h" + +namespace mindspore { +namespace dataset { +CachePool::CachePool(const value_allocator &alloc, const std::string &root) + : alloc_(alloc), root_(root), subfolder_(Services::GetUniqueID()), sm_(nullptr), tree_(nullptr) {} + +Status CachePool::DoServiceStart() { + tree_ = std::make_shared(); + // If we are given a disk path, set up the StorageManager + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + RETURN_IF_NOT_OK(spill.CreateDirectories()); + sm_ = std::make_shared(spill); + RETURN_IF_NOT_OK(sm_->ServiceStart()); + MS_LOG(INFO) << "CachePool will use disk folder: " << common::SafeCStr(spill.toString()); + } + return Status::OK(); +} +Status CachePool::DoServiceStop() { + Status rc; + Status rc2; + if (sm_ != nullptr) { + rc = sm_->ServiceStop(); + if (rc.IsError()) { + rc2 = rc; + } + } + sm_.reset(); + for (auto &bl : *tree_) { + if (bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, bl.sz); + } + } + tree_.reset(); + if (!root_.toString().empty()) { + Path spill = GetSpillPath(); + auto it = Path::DirIterator::OpenDirectory(&spill); + while (it->hasNext()) { + rc = it->next().Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + rc = spill.Remove(); + if (rc.IsError() && rc2.IsOk()) { + rc2 = rc; + } + } + return rc2; +} +CachePool::~CachePool() noexcept { (void)ServiceStop(); } +Status CachePool::Insert(const std::vector &buf, CachePool::key_type *key) { + DataLocator bl; + Status rc; + size_t sz = 0; + // We will consolidate all the slices into one piece. + for (auto &v : buf) { + sz += v.GetSize(); + } + bl.sz = sz; + try { + bl.ptr = alloc_.allocate(sz); + // We will do a piecewise copy. + WritableSlice dest(bl.ptr, bl.sz); + size_t pos = 0; + for (auto &v : buf) { + WritableSlice out(dest, pos); + rc = WritableSlice::Copy(&out, v); + if (rc.IsError()) { + break; + } + pos += v.GetSize(); + } + if (rc.IsError()) { + alloc_.deallocate(bl.ptr, sz); + bl.ptr = nullptr; + return rc; + } + } catch (std::bad_alloc &e) { + if (sm_ != nullptr) { + RETURN_IF_NOT_OK(sm_->Write(&bl.storage_key, buf)); + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + rc = tree_->insert(bl, key); + if (rc.IsError() && bl.ptr != nullptr) { + alloc_.deallocate(bl.ptr, sz); + } + return rc; +} +Status CachePool::Read(CachePool::key_type key, WritableSlice *dest, size_t *bytesRead) const { + RETURN_UNEXPECTED_IF_NULL(dest); + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + if (it->ptr != nullptr) { + ReadableSlice src(it->ptr, it->sz); + RETURN_IF_NOT_OK(WritableSlice::Copy(dest, src)); + } else if (sm_ != nullptr) { + size_t expectedLength = 0; + RETURN_IF_NOT_OK(sm_->Read(it->storage_key, dest, &expectedLength)); + if (expectedLength != it->sz) { + MS_LOG(ERROR) << "Unexpected length. Read " << expectedLength << ". Expected " << it->sz << "." + << " Internal key: " << key << "\n"; + RETURN_STATUS_UNEXPECTED("Length mismatch. 
See log file for details."); + } + } + if (bytesRead != nullptr) { + *bytesRead = it->sz; + } + } else { + RETURN_STATUS_UNEXPECTED("Key not found"); + } + return Status::OK(); +} +const CachePool::value_allocator &CachePool::get_allocator() const { return alloc_; } +Path CachePool::GetSpillPath() const { + auto spill = Path(root_) / subfolder_; + return spill; +} +CachePool::CacheStat CachePool::GetStat() const { + CacheStat cs{0}; + for (auto &it : *tree_) { + if (it.ptr != nullptr) { + ++cs.num_mem_cached; + } else { + ++cs.num_disk_cached; + } + } + return cs; +} +Status CachePool::Spill(CachePool::DataLocator *dl) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to spill"); + } + RETURN_UNEXPECTED_IF_NULL(dl); + RETURN_UNEXPECTED_IF_NULL(dl->ptr); + if (dl->storage_key == 0) { + ReadableSlice data(dl->ptr, dl->sz); + RETURN_IF_NOT_OK(sm_->Write(&dl->storage_key, {data})); + } + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return Status::OK(); +} +Status CachePool::Locate(CachePool::DataLocator *dl) { + RETURN_UNEXPECTED_IF_NULL(dl); + if (dl->ptr == nullptr) { + if (sm_ == nullptr) { + RETURN_STATUS_UNEXPECTED("No disk storage to locate the data"); + } + try { + dl->ptr = alloc_.allocate(dl->sz); + WritableSlice dest(dl->ptr, dl->sz); + Status rc = Read(dl->storage_key, &dest); + if (rc.IsError()) { + alloc_.deallocate(dl->ptr, dl->sz); + dl->ptr = nullptr; + return rc; + } + } catch (const std::bad_alloc &e) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + return Status::OK(); +} +size_t CachePool::GetSize(CachePool::key_type key) const { + auto r = tree_->Search(key); + if (r.second) { + auto &it = r.first; + return it->sz; + } else { + return 0; + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/cache_pool.h b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h new file mode 100644 index 0000000000..cdb6da16b6 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cache_pool.h @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_CACHE_POOL_H_ +#define DATASET_UTIL_CACHE_POOL_H_ + +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/storage_manager.h" +#include "minddata/dataset/util/auto_index.h" + +namespace mindspore { +namespace dataset { +/// \brief A CachePool provides service for backup/restore a buffer. A buffer can be represented in a form of vector of +/// ReadableSlice where all memory blocks will be copied to one contiguous block which can be in memory or spilled to +/// disk (if a disk directory is provided). Every buffer insert will return a generated key which can be used to +/// restore the buffer. 
+/// \see ReadableSlice +class CachePool : public Service { + public: + using base_type = uint8_t; + using pointer = base_type *; + using const_pointer = const base_type *; + using reference = base_type &; + using const_reference = const base_type &; + using value_allocator = Allocator; + + // An internal class to locate the whereabouts of a backed up buffer which can be either in + class DataLocator { + public: + DataLocator() : ptr(nullptr), sz(0), storage_key(0) {} + ~DataLocator() = default; + DataLocator(const DataLocator &other) = default; + DataLocator &operator=(const DataLocator &other) = default; + DataLocator(DataLocator &&other) noexcept { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + DataLocator &operator=(DataLocator &&other) noexcept { + if (&other != this) { + ptr = other.ptr; + sz = other.sz; + storage_key = other.storage_key; + other.ptr = nullptr; + other.sz = 0; + other.storage_key = 0; + } + return *this; + } + pointer ptr; + size_t sz; + StorageManager::key_type storage_key; + }; + + using data_index = AutoIndexObj; + using key_type = data_index::key_type; + using bl_alloc_type = typename value_allocator::template rebind::other; + + /// \brief Simple statistics returned from CachePool like how many elements are cached in memory and + /// how many elements are spilled to disk. + struct CacheStat { + int64_t num_mem_cached; + int64_t num_disk_cached; + }; + + /// \brief Constructor + /// \param alloc Allocator to allocate memory from + /// \param root Optional disk folder to spill + explicit CachePool(const value_allocator &alloc, const std::string &root = ""); + + CachePool(const CachePool &) = delete; + CachePool(CachePool &&) = delete; + CachePool &operator=(const CachePool &) = delete; + CachePool &operator=(CachePool &&) = delete; + ~CachePool() noexcept; + + Status DoServiceStart() override; + Status DoServiceStop() override; + + Path GetSpillPath() const; + + /// \brief Insert a sequence of ReadableSlice objects into the pool. + /// All memory blocks will be consolidated into one contiguous block and be cached in either memory or on disk. + /// \param[in] buf A sequence of ReadableSlice objects. + /// \param[out] key Generated key + /// \return Error code + Status Insert(const std::vector &buf, key_type *key); + /// \brief Restore a cached buffer (from memory or disk) + /// \param[in] key A previous key returned from Insert + /// \param[out] dest The cached buffer will be copied to this destination represented by a WritableSlice + /// \param[out] bytesRead Optional. Number of bytes read. + /// \return Error code + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead = nullptr) const; + + Status Spill(DataLocator *dl); + + Status Locate(DataLocator *dl); + + size_t GetSize(key_type key) const; + + /// \brief Get statistics. 
+ /// \return CacheStat object + CacheStat GetStat() const; + + const value_allocator &get_allocator() const; + + std::string MyName() const { return subfolder_; } + + private: + value_allocator alloc_; + Path root_; + const std::string subfolder_; + std::shared_ptr sm_; + std::shared_ptr tree_; +}; +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc new file mode 100644 index 0000000000..f99e6de2f1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/circular_pool.h" + +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/system_pool.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status CircularPool::AddOneArena() { + Status rc; + std::shared_ptr b; + RETURN_IF_NOT_OK(Arena::CreateArena(&b, arena_size_)); + tail_ = b.get(); + cur_size_in_mb_ += arena_size_; + mem_segments_.push_back(std::move(b)); + return Status::OK(); +} + +ListOfArenas::iterator CircularPool::CircularIterator::Next() { + ListOfArenas::iterator it = dp_->mem_segments_.begin(); + uint32_t size = dp_->mem_segments_.size(); + // This is what we return + it += cur_; + // Prepare for the next round + cur_++; + if (cur_ == size) { + if (start_ == 0) { + has_next_ = false; + } else { + wrap_ = true; + cur_ = 0; + } + } else if (cur_ == start_) { + has_next_ = false; + } + return it; +} + +bool CircularPool::CircularIterator::has_next() const { return has_next_; } + +void CircularPool::CircularIterator::Reset() { + wrap_ = false; + has_next_ = false; + if (!dp_->mem_segments_.empty()) { + // Find the buddy arena that corresponds to the tail. 
+ cur_tail_ = dp_->tail_; + auto list_end = dp_->mem_segments_.end(); + auto it = std::find_if(dp_->mem_segments_.begin(), list_end, + [this](const std::shared_ptr &b) { return b.get() == cur_tail_; }); + MS_ASSERT(it != list_end); + start_ = std::distance(dp_->mem_segments_.begin(), it); + cur_ = start_; + has_next_ = true; + } +} + +CircularPool::CircularIterator::CircularIterator(CircularPool *dp) : dp_(dp) { Reset(); } + +Status CircularPool::Allocate(size_t n, void **p) { + if (p == nullptr) { + RETURN_STATUS_UNEXPECTED("p is null"); + } + Status rc; + void *ptr = nullptr; + do { + SharedLock lock_s(&rw_lock_); + int prevSzInMB = cur_size_in_mb_; + bool move_tail = false; + CircularIterator cirIt(this); + while (cirIt.has_next()) { + auto it = cirIt.Next(); + Arena *ba = it->get(); + if (ba->get_max_size() < n) { + return Status(StatusCode::kOutOfMemory); + } + // If we are asked to move forward the tail + if (move_tail) { + Arena *expected = cirIt.cur_tail_; + (void)atomic_compare_exchange_weak(&tail_, &expected, ba); + move_tail = false; + } + rc = ba->Allocate(n, &ptr); + if (rc.IsOk()) { + *p = ptr; + break; + } else if (rc.IsOutofMemory()) { + // Make the next arena a new tail and continue. + move_tail = true; + } else { + return rc; + } + } + + // Handle the case we have done one round robin search. + if (ptr == nullptr) { + // If we have room to expand. + if (unlimited_ || cur_size_in_mb_ < max_size_in_mb_) { + // lock in exclusively mode. + lock_s.Upgrade(); + // Check again if someone has already expanded. + if (cur_size_in_mb_ == prevSzInMB) { + RETURN_IF_NOT_OK(AddOneArena()); + } + // Re-acquire the shared lock and try again + lock_s.Downgrade(); + } else { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } + } + } while (ptr == nullptr); + return rc; +} + +void CircularPool::Deallocate(void *p) { + // Lock in the chain in shared mode and find out which + // segment it comes from + SharedLock lock(&rw_lock_); + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + char *q = reinterpret_cast(p); + char *base = const_cast(reinterpret_cast(b->get_base_addr())); + return (q > base && q < base + b->get_max_size()); + }); + lock.Unlock(); + it->get()->Deallocate(p); +} + +Status CircularPool::Reallocate(void **pp, size_t old_sz, size_t new_sz) { + // Lock in the chain in shared mode and find out which + // segment it comes from + if (pp == nullptr) { + RETURN_STATUS_UNEXPECTED("pp is null"); + } + void *p = *pp; + SharedLock lock(&rw_lock_); + auto it = std::find_if(mem_segments_.begin(), mem_segments_.end(), [p](std::shared_ptr &b) -> bool { + char *q = reinterpret_cast(p); + char *base = const_cast(reinterpret_cast(b->get_base_addr())); + return (q > base && q < base + b->get_max_size()); + }); + lock.Unlock(); + MS_ASSERT(it != mem_segments_.end()); + Arena *ba = it->get(); + Status rc = ba->Reallocate(pp, old_sz, new_sz); + if (rc.IsOutofMemory()) { + // The current arena has no room for the bigger size. + // Allocate free space from another arena and copy + // the content over. 
+ void *q = nullptr; + rc = this->Allocate(new_sz, &q); + RETURN_IF_NOT_OK(rc); + errno_t err = memcpy_s(q, new_sz, p, old_sz); + if (err) { + this->Deallocate(q); + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + *pp = q; + ba->Deallocate(p); + } + return Status::OK(); +} + +uint64_t CircularPool::get_max_size() const { return mem_segments_.front()->get_max_size(); } + +int CircularPool::PercentFree() const { + int percent_free = 0; + int num_arena = 0; + for (auto const &p : mem_segments_) { + percent_free += p->PercentFree(); + num_arena++; + } + if (num_arena) { + return percent_free / num_arena; + } else { + return 100; + } +} + +CircularPool::CircularPool(int max_size_in_gb, int arena_size) + : unlimited_(max_size_in_gb <= 0), + max_size_in_mb_(unlimited_ ? std::numeric_limits::max() : max_size_in_gb * 1024), + arena_size_(arena_size), + cur_size_in_mb_(0) {} + +Status CircularPool::CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb, int arena_size, + bool createOneArena) { + Status rc; + if (out_pool == nullptr) { + RETURN_STATUS_UNEXPECTED("pPool is null"); + } + auto pool = new (std::nothrow) CircularPool(max_size_in_gb, arena_size); + if (pool == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + if (createOneArena) { + rc = pool->AddOneArena(); + } + if (rc.IsOk()) { + (*out_pool).reset(pool); + } else { + delete pool; + } + return rc; +} + +CircularPool::~CircularPool() = default; +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/circular_pool.h b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h new file mode 100644 index 0000000000..a63afbd691 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/circular_pool.h @@ -0,0 +1,108 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_CIRCULAR_POOL_H_ +#define DATASET_UTIL_CIRCULAR_POOL_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/arena.h" +#include "minddata/dataset/util/lock.h" + +namespace mindspore { +namespace dataset { +using ListOfArenas = std::vector>; + +// This is a dynamic memory pool built on top of memory +// segment each of which is 4G in size. Initially we start +// with one segment, and gradually add segments (not +// guaranteed contiguous) until we reach 32G in size. There +// is an assumption about this kind of memory pool. Allocated +// memory is not held for the whole duration of the pool and +// will be released soon. Based on this assumption, memory is +// obtained from the tail while allocated memory is returned +// to the head of the pool. 
+class CircularPool : public MemoryPool { + public: + class CircularIterator { + friend class CircularPool; + + public: + explicit CircularIterator(CircularPool *dp); + + ~CircularIterator() = default; + + bool has_next() const; + + ListOfArenas::iterator Next(); + + void Reset(); + + private: + CircularPool *dp_; + Arena *cur_tail_{}; + uint32_t start_{}; + uint32_t cur_{}; + bool wrap_{}; + bool has_next_{}; + }; + + CircularPool(const CircularPool &) = delete; + + CircularPool &operator=(const CircularPool &) = delete; + + ~CircularPool() override; + + Status Allocate(size_t n, void **) override; + + Status Reallocate(void **, size_t old_size, size_t new_size) override; + + void Deallocate(void *) override; + + uint64_t get_max_size() const override; + + int PercentFree() const override; + + friend std::ostream &operator<<(std::ostream &os, const CircularPool &s) { + int i = 0; + for (auto it = s.mem_segments_.begin(); it != s.mem_segments_.end(); ++it, ++i) { + os << "Dumping segment " << i << "\n" << *(it->get()); + } + return os; + } + + static Status CreateCircularPool(std::shared_ptr *out_pool, int max_size_in_gb = -1, + int arena_size = 4096, bool create_one_arena = false); + + private: + ListOfArenas mem_segments_; + std::atomic tail_{}; + bool unlimited_; + int max_size_in_mb_; + int arena_size_; + int cur_size_in_mb_; + RWLock rw_lock_; + + // We can take negative or 0 as input which means unlimited. + CircularPool(int max_size_in_gb, int arena_size); + + Status AddOneArena(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_CIRCULAR_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/cond_var.cc b/mindspore/ccsrc/minddata/dataset/util/cond_var.cc new file mode 100644 index 0000000000..b7c7b76cae --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cond_var.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/cond_var.h" +#include +#include +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +CondVar::CondVar() : svc_(nullptr), my_name_(Services::GetUniqueID()) {} + +Status CondVar::Wait(std::unique_lock *lck, const std::function &pred) { + try { + if (svc_ != nullptr) { + // If this cv registers with a global resource tracking, then wait unconditionally. + auto f = [this, &pred]() -> bool { return (pred() || this->Interrupted()); }; + cv_.wait(*lck, f); + // If we are interrupted, override the return value if this is the master thread. + // Master thread is being interrupted mostly because of some thread is reporting error. + RETURN_IF_NOT_OK(Task::OverrideInterruptRc(this->GetInterruptStatus())); + } else { + // Otherwise we wake up once a while to check for interrupt (for this thread). 
+ auto f = [&pred]() -> bool { return (pred() || this_thread::is_interrupted()); }; + while (!f()) { + (void)cv_.wait_for(*lck, std::chrono::milliseconds(1)); + } + RETURN_IF_INTERRUPTED(); + } + } catch (const std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); +} + +CondVar::~CondVar() noexcept { + if (svc_ != nullptr) { + (void)svc_->Deregister(my_name_); + svc_ = nullptr; + } +} + +void CondVar::NotifyOne() noexcept { cv_.notify_one(); } + +void CondVar::NotifyAll() noexcept { cv_.notify_all(); } + +Status CondVar::Register(std::shared_ptr svc) { + Status rc = svc->Register(my_name_, this); + if (rc.IsOk()) { + svc_ = svc; + } + return rc; +} + +void CondVar::Interrupt() { + IntrpResource::Interrupt(); + cv_.notify_all(); +} + +std::string CondVar::my_name() const { return my_name_; } + +Status CondVar::Deregister() { + if (svc_) { + Status rc = svc_->Deregister(my_name_); + svc_ = nullptr; + return rc; + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/cond_var.h b/mindspore/ccsrc/minddata/dataset/util/cond_var.h new file mode 100644 index 0000000000..88fcad24a2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/cond_var.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_COND_VAR_H_ +#define DATASET_UTIL_COND_VAR_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/intrp_resource.h" +#include "minddata/dataset/util/intrp_service.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class CondVar : public IntrpResource { + public: + CondVar(); + + ~CondVar() noexcept; + + Status Wait(std::unique_lock *lck, const std::function &pred); + + void Interrupt() override; + + void NotifyOne() noexcept; + + void NotifyAll() noexcept; + + Status Register(std::shared_ptr svc); + + std::string my_name() const; + + Status Deregister(); + + protected: + std::condition_variable cv_; + std::shared_ptr svc_; + + private: + std::string my_name_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_COND_VAR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h new file mode 100644 index 0000000000..9d78e2cd32 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_resource.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INTRP_RESOURCE_H_ +#define DATASET_UTIL_INTRP_RESOURCE_H_ + +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class IntrpResource { + public: + enum class State : int { kRunning, kInterrupted }; + + IntrpResource() : st_(State::kRunning) {} + + virtual ~IntrpResource() = default; + + virtual void Interrupt() { st_ = State::kInterrupted; } + + virtual void ResetIntrpState() { st_ = State::kRunning; } + + State CurState() const { return st_; } + + bool Interrupted() const { return CurState() == State::kInterrupted; } + + virtual Status GetInterruptStatus() const { + if (Interrupted()) { + return Status(StatusCode::kInterrupted); + } + return Status::OK(); + } + + protected: + std::atomic st_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INTRP_RESOURCE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc new file mode 100644 index 0000000000..a82c82cdc9 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/intrp_service.h" +#include +#include "common/utils.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +IntrpService::IntrpService() : high_water_mark_(0) { (void)ServiceStart(); } + +IntrpService::~IntrpService() noexcept { + MS_LOG(INFO) << "Number of registered resources is " << high_water_mark_ << "."; + if (!all_intrp_resources_.empty()) { + try { + InterruptAll(); + } catch (const std::exception &e) { + // Ignore all error as we can't throw in the destructor. + } + } + (void)ServiceStop(); +} + +Status IntrpService::Register(const std::string &name, IntrpResource *res) { + SharedLock stateLck(&state_lock_); + // Now double check the state + if (ServiceState() != STATE::kRunning) { + return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Interrupt service is shutting down"); + } else { + std::lock_guard lck(mutex_); + try { + std::ostringstream ss; + ss << this_thread::get_id(); + MS_LOG(DEBUG) << "Register resource with name " << name << ". Thread ID " << ss.str() << "."; + auto it = all_intrp_resources_.emplace(name, res); + if (it.second == false) { + return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, name); + } + high_water_mark_++; + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + } + return Status::OK(); +} + +Status IntrpService::Deregister(const std::string &name) noexcept { + std::lock_guard lck(mutex_); + try { + std::ostringstream ss; + ss << this_thread::get_id(); + MS_LOG(DEBUG) << "De-register resource with name " << name << ". 
Thread ID is " << ss.str() << "."; + auto n = all_intrp_resources_.erase(name); + if (n == 0) { + MS_LOG(INFO) << "Key " << name << " not found."; + } + } catch (std::exception &e) { + RETURN_STATUS_UNEXPECTED(e.what()); + } + return Status::OK(); +} + +void IntrpService::InterruptAll() noexcept { + std::lock_guard lck(mutex_); + for (auto const &it : all_intrp_resources_) { + std::string kName = it.first; + try { + it.second->Interrupt(); + } catch (const std::exception &e) { + // continue the clean up. + } + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/intrp_service.h b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h new file mode 100644 index 0000000000..cb6bf30c73 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/intrp_service.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_INTRP_SERVICE_H_ +#define DATASET_UTIL_INTRP_SERVICE_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/intrp_resource.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/status.h" + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +using SvcAllocator = Allocator>; + +class IntrpService : public Service { + public: + IntrpService(); + + ~IntrpService() noexcept override; + + IntrpService(const IntrpService &) = delete; + + IntrpService &operator=(const IntrpService &) = delete; + + Status Register(const std::string &name, IntrpResource *res); + + Status Deregister(const std::string &name) noexcept; + + void InterruptAll() noexcept; + + Status DoServiceStart() override { return Status::OK(); } + + Status DoServiceStop() override { return Status::OK(); } + + private: + int high_water_mark_; + std::mutex mutex_; + std::map all_intrp_resources_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_INTRP_SERVICE_H_ diff --git a/mindspore/ccsrc/dataset/util/list.h b/mindspore/ccsrc/minddata/dataset/util/list.h similarity index 100% rename from mindspore/ccsrc/dataset/util/list.h rename to mindspore/ccsrc/minddata/dataset/util/list.h diff --git a/mindspore/ccsrc/minddata/dataset/util/lock.cc b/mindspore/ccsrc/minddata/dataset/util/lock.cc new file mode 100644 index 0000000000..5302196a46 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/lock.cc @@ -0,0 +1,185 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/lock.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +void SpinLock::Lock() { + while (true) { + int expected = kUnlocked; + if (val_.compare_exchange_weak(expected, kLocked)) { + break; + } + } +} + +bool SpinLock::TryLock() { + int expected = kUnlocked; + return val_.compare_exchange_strong(expected, kLocked); +} + +void SpinLock::Unlock() noexcept { val_.store(kUnlocked); } + +void RWLock::LockShared() { + std::unique_lock lck(mtx_); + waiting_readers_ += 1; + read_cv_.wait(lck, [this]() { return (waiting_writers_ == 0 && status_ >= 0); }); + waiting_readers_ -= 1; + status_ += 1; +} + +void RWLock::Unlock() noexcept { + std::unique_lock lck(mtx_); + if (status_ == -1) { + // I am the writer. By definition, no other writer nor reader. + status_ = 0; + } else if (status_ > 0) { + // One less reader + status_ -= 1; + } + // Wake up writer only if there is no reader. + if (waiting_writers_ > 0) { + if (status_ == 0) { + write_cv_.notify_one(); + } + } else { + read_cv_.notify_all(); + } +} + +void RWLock::Upgrade() { + std::unique_lock lck(mtx_); + MS_ASSERT(status_); + if (status_ == -1) { + // I am a writer already. + return; + } else if (status_ == 1) { + // If I am the only reader. Just change the status. + status_ = -1; + return; + } else { + // In all other cases, let of the shared lock and relock in exclusive. 
+ lck.unlock(); + this->Unlock(); + this->LockExclusive(); + } +} + +void RWLock::Downgrade() { + std::unique_lock lck(mtx_); + MS_ASSERT(status_); + if (status_ == -1) { + // If there are no other writers waiting, just change the status + if (waiting_writers_ == 0) { + status_ = 1; + } else { + // Otherwise just unlock and relock in shared + lck.unlock(); + this->Unlock(); + this->LockShared(); + } + } else if (status_ > 0) { + return; + } +} + +SharedLock::SharedLock(RWLock *rw) : rw_(rw), ownlock_(false) { + rw_->LockShared(); + ownlock_ = true; +} + +SharedLock::~SharedLock() { + if (ownlock_) { + rw_->Unlock(); + ownlock_ = false; + } + rw_ = nullptr; +} + +void SharedLock::Unlock() { + MS_ASSERT(ownlock_ == true); + rw_->Unlock(); + ownlock_ = false; +} + +void SharedLock::Lock() { + MS_ASSERT(ownlock_ == false); + rw_->LockShared(); + ownlock_ = true; +} + +void SharedLock::Upgrade() { + MS_ASSERT(ownlock_ == true); + rw_->Upgrade(); +} + +void SharedLock::Downgrade() { + MS_ASSERT(ownlock_ == true); + rw_->Downgrade(); +} + +UniqueLock::UniqueLock(RWLock *rw) : rw_(rw), ownlock_(false) { + rw_->LockExclusive(); + ownlock_ = true; +} + +UniqueLock::~UniqueLock() { + if (ownlock_) { + rw_->Unlock(); + ownlock_ = false; + } + rw_ = nullptr; +} + +void UniqueLock::Unlock() { + MS_ASSERT(ownlock_ == true); + rw_->Unlock(); + ownlock_ = false; +} + +void UniqueLock::Lock() { + MS_ASSERT(ownlock_ == false); + rw_->LockExclusive(); + ownlock_ = true; +} + +LockGuard::LockGuard(SpinLock *lock) : lck_(lock), own_lock_(false) { + lck_->Lock(); + own_lock_ = true; +} + +LockGuard::~LockGuard() { + if (own_lock_) { + lck_->Unlock(); + own_lock_ = false; + } + lck_ = nullptr; +} + +void LockGuard::Unlock() { + MS_ASSERT(own_lock_); + lck_->Unlock(); + own_lock_ = false; +} + +void LockGuard::Lock() { + MS_ASSERT(own_lock_ == false); + lck_->Lock(); + own_lock_ = true; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/lock.h b/mindspore/ccsrc/minddata/dataset/util/lock.h similarity index 100% rename from mindspore/ccsrc/dataset/util/lock.h rename to mindspore/ccsrc/minddata/dataset/util/lock.h diff --git a/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc b/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc new file mode 100644 index 0000000000..0e1be9d798 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/memory_pool.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/util/memory_pool.h" +#include "./securec.h" + +namespace mindspore { +namespace dataset { +Status DeMalloc(std::size_t s, void **p, bool init_to_zero = false) { + if (p == nullptr) { + RETURN_STATUS_UNEXPECTED("p is null"); + } + void *q = ::malloc(s); + if (q == nullptr) { + return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__); + } else { + *p = q; + if (init_to_zero) { + (void)memset_s(q, s, 0, s); + } + return Status::OK(); + } +} +} // namespace dataset +} // namespace mindspore + +void *operator new(std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { + void *ptr = nullptr; + *rc = b->Allocate(s, &ptr); + return ptr; +} + +void *operator new[](std::size_t s, mindspore::dataset::Status *rc, std::shared_ptr b) { + void *ptr = nullptr; + *rc = b->Allocate(s, &ptr); + return ptr; +} + +void operator delete(void *p, std::shared_ptr b) { + if (p != nullptr) b->Deallocate(p); +} + +void operator delete[](void *p, std::shared_ptr b) { + if (p != nullptr) b->Deallocate(p); +} diff --git a/mindspore/ccsrc/minddata/dataset/util/memory_pool.h b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h new file mode 100644 index 0000000000..c7cc473109 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/memory_pool.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_MEMORY_POOL_H_ +#define DATASET_UTIL_MEMORY_POOL_H_ + +#include +#include +#include +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +// Abstract class of a memory pool +class MemoryPool { + public: + // Allocate a block of size n + virtual Status Allocate(size_t, void **) = 0; + + // Enlarge or shrink a block from oldSz to newSz + virtual Status Reallocate(void **, size_t old_sz, size_t new_sz) = 0; + + // Free a pointer + virtual void Deallocate(void *) = 0; + + // What is the maximum size I can allocate ? + virtual uint64_t get_max_size() const = 0; + + virtual int PercentFree() const = 0; + + // Destructor + virtual ~MemoryPool() {} +}; + +Status DeMalloc(std::size_t s, void **p, bool); +} // namespace dataset +} // namespace mindspore + +void *operator new(std::size_t, mindspore::dataset::Status *, std::shared_ptr); + +void *operator new[](std::size_t, mindspore::dataset::Status *, std::shared_ptr); + +void operator delete(void *, std::shared_ptr); + +void operator delete[](void *, std::shared_ptr); + +#endif // DATASET_UTIL_MEMORY_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/path.cc b/mindspore/ccsrc/minddata/dataset/util/path.cc new file mode 100644 index 0000000000..8740ecb8e0 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/path.cc @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/path.h" + +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +#if defined(_WIN32) || defined(_WIN64) +char Path::separator_ = '\\'; +#else +char Path::separator_ = '/'; +#endif + +Path::Path(const std::string &s) : path_(s) {} + +Path::Path(const char *p) : path_(p) {} + +Path::Path(const Path &p) : path_(p.path_) {} + +Path &Path::operator=(const Path &p) { + if (&p != this) { + this->path_ = p.path_; + } + return *this; +} + +Path &Path::operator=(Path &&p) noexcept { + if (&p != this) { + this->path_ = std::move(p.path_); + } + return *this; +} + +Path::Path(Path &&p) noexcept { this->path_ = std::move(p.path_); } + +Path Path::operator+(const Path &p) { + std::string q = path_ + p.toString(); + return Path(q); +} + +Path Path::operator+(const std::string &p) { + std::string q = path_ + p; + return Path(q); +} + +Path Path::operator+(const char *p) { + std::string q = path_ + p; + return Path(q); +} + +Path &Path::operator+=(const Path &rhs) { + path_ += rhs.toString(); + return *this; +} + +Path &Path::operator+=(const std::string &p) { + path_ += p; + return *this; +} + +Path &Path::operator+=(const char *p) { + path_ += p; + return *this; +} + +Path Path::operator/(const Path &p) { + std::string q = path_ + separator_ + p.toString(); + return Path(q); +} + +Path Path::operator/(const std::string &p) { + std::string q = path_ + separator_ + p; + return Path(q); +} + +Path Path::operator/(const char *p) { + std::string q = path_ + separator_ + p; + return Path(q); +} + +std::string Path::Extension() const { + std::size_t found = path_.find_last_of('.'); + if (found != std::string::npos) { + return path_.substr(found); + } else { + return std::string(""); + } +} + +bool Path::Exists() { + struct stat sb; + int rc = stat(common::SafeCStr(path_), &sb); + if (rc == -1 && errno != ENOENT) { + MS_LOG(WARNING) << "Unable to query the status of " << path_ << ". Errno = " << errno << "."; + } + return (rc == 0); +} + +bool Path::IsDirectory() { + struct stat sb; + int rc = stat(common::SafeCStr(path_), &sb); + if (rc == 0) { + return S_ISDIR(sb.st_mode); + } else { + return false; + } +} + +Status Path::CreateDirectory() { + if (!Exists()) { +#if defined(_WIN32) || defined(_WIN64) + int rc = mkdir(common::SafeCStr(path_)); +#else + int rc = mkdir(common::SafeCStr(path_), S_IRUSR | S_IWUSR | S_IXUSR); +#endif + if (rc) { + std::ostringstream oss; + oss << "Unable to create directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + return Status::OK(); + } else { + if (IsDirectory()) { + return Status::OK(); + } else { + std::ostringstream oss; + oss << "Unable to create directory " << path_ << ". 
It exists but is not a directory"; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } +} + +std::string Path::ParentPath() { + std::string r(""); + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + if (found == 0) { + r += separator_; + } else { + r = std::string(path_.substr(0, found)); + } + } + return r; +} + +Status Path::CreateDirectories() { + if (IsDirectory()) { + MS_LOG(DEBUG) << "Directory " << toString() << " already exists."; + return Status::OK(); + } else { + MS_LOG(DEBUG) << "Creating directory " << toString() << "."; + std::string parent = ParentPath(); + if (!parent.empty()) { + if (Path(parent).CreateDirectories()) { + return CreateDirectory(); + } + } else { + return CreateDirectory(); + } + } + return Status::OK(); +} + +Status Path::Remove() { + if (Exists()) { + if (IsDirectory()) { + errno_t err = rmdir(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete directory " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } else { + errno_t err = unlink(common::SafeCStr(path_)); + if (err == -1) { + std::ostringstream oss; + oss << "Unable to delete file " << path_ << ". Errno = " << errno; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + } + } + return Status::OK(); +} + +Status Path::CreateFile(int *file_descriptor) { return OpenFile(file_descriptor, true); } + +Status Path::OpenFile(int *file_descriptor, bool create) { + int fd; + if (file_descriptor == nullptr) { + RETURN_STATUS_UNEXPECTED("null pointer"); + } + if (IsDirectory()) { + std::ostringstream oss; + oss << "Unable to create file " << path_ << " which is a directory."; + RETURN_STATUS_UNEXPECTED(oss.str()); + } + // Convert to canonical form. + if (strlen(common::SafeCStr(path_)) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + char canonical_path[PATH_MAX + 1] = {0x00}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(path_), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(path_), canonical_path) == nullptr) { +#endif + if (errno == ENOENT && create) { + // File doesn't exist and we are to create it. Let's break it down. 
+ auto file_part = Basename(); + auto parent_part = ParentPath(); +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(canonical_path, common::SafeCStr(parent_part), PATH_MAX) == nullptr) { +#else + if (realpath(common::SafeCStr(parent_part), canonical_path) == nullptr) { +#endif + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto cur_inx = strlen(canonical_path); + if ((cur_inx + file_part.length() + 1) > PATH_MAX) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + canonical_path[cur_inx++] = separator_; + if (strncpy_s(canonical_path + cur_inx, PATH_MAX - cur_inx, common::SafeCStr(file_part), file_part.length()) != + EOK) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + } + if (create) { + fd = open(canonical_path, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP); + } else { + fd = open(canonical_path, O_RDWR); + } + if (fd == -1) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + *file_descriptor = fd; + return Status::OK(); +} + +Status Path::CloseFile(int fd) const { + if (close(fd) < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + return Status::OK(); +} + +Status Path::TruncateFile(int fd) const { + int rc; + rc = ftruncate(fd, 0); + if (rc == 0) { + return Status::OK(); + } else { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } +} + +std::string Path::Basename() { + std::size_t found = path_.find_last_of(separator_); + if (found != std::string::npos) { + return path_.substr(found + 1); + } else { + return path_; + } +} + +std::shared_ptr Path::DirIterator::OpenDirectory(Path *f) { + auto it = new (std::nothrow) DirIterator(f); + + if (it == nullptr) { + return nullptr; + } + + if (it->dp_) { + return std::shared_ptr(it); + } else { + delete it; + return nullptr; + } +} + +Path::DirIterator::~DirIterator() { + if (dp_) { + (void)closedir(dp_); + } + dp_ = nullptr; + dir_ = nullptr; + entry_ = nullptr; +} + +Path::DirIterator::DirIterator(Path *f) : dir_(f), dp_(nullptr), entry_(nullptr) { + MS_LOG(DEBUG) << "Open directory " << f->toString() << "."; + dp_ = opendir(f->toString().c_str()); +} + +bool Path::DirIterator::hasNext() { + do { + entry_ = readdir(dp_); + if (entry_) { + if (strcmp(entry_->d_name, ".") == 0 || strcmp(entry_->d_name, "..") == 0) { + continue; + } + } + break; + } while (true); + return (entry_ != nullptr); +} + +Path Path::DirIterator::next() { return (*(this->dir_) / Path(entry_->d_name)); } + +std::ostream &operator<<(std::ostream &os, const Path &s) { + os << s.path_; + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/path.h b/mindspore/ccsrc/minddata/dataset/util/path.h new file mode 100644 index 0000000000..8bc07ca8f3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/path.h @@ -0,0 +1,114 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_UTIL_PATH_H_ +#define DATASET_UTIL_PATH_H_ + +#include +#include +#include + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class Path { + public: + class DirIterator { + public: + static std::shared_ptr OpenDirectory(Path *f); + + ~DirIterator(); + + bool hasNext(); + + Path next(); + + private: + explicit DirIterator(Path *f); + + Path *dir_; + DIR *dp_; + struct dirent *entry_; + }; + + explicit Path(const std::string &); + + explicit Path(const char *); + + ~Path() = default; + + Path(const Path &); + + Path &operator=(const Path &); + + Path(Path &&) noexcept; + + Path &operator=(Path &&) noexcept; + + std::string toString() const { return path_; } + + Path operator+(const Path &); + + Path operator+(const std::string &); + + Path operator+(const char *); + + Path &operator+=(const Path &rhs); + + Path &operator+=(const std::string &); + + Path &operator+=(const char *); + + Path operator/(const Path &); + + Path operator/(const std::string &); + + Path operator/(const char *); + + bool Exists(); + + bool IsDirectory(); + + Status CreateDirectory(); + + Status CreateDirectories(); + + std::string Extension() const; + + std::string ParentPath(); + + Status Remove(); + + Status CreateFile(int *fd); + + Status OpenFile(int *fd, bool create = false); + + Status CloseFile(int fd) const; + + Status TruncateFile(int fd) const; + + std::string Basename(); + + friend std::ostream &operator<<(std::ostream &os, const Path &s); + + private: + static char separator_; + std::string path_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_PATH_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/queue.h b/mindspore/ccsrc/minddata/dataset/util/queue.h new file mode 100644 index 0000000000..7a0a987499 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/queue.h @@ -0,0 +1,256 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_UTIL_QUEUE_H_ +#define DATASET_UTIL_QUEUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "utils/log_adapter.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/services.h" +#include "minddata/dataset/util/cond_var.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +template +struct is_shared_ptr : public std::false_type {}; + +template +struct is_shared_ptr> : public std::true_type {}; + +template +struct is_unique_ptr : public std::false_type {}; + +template +struct is_unique_ptr> : public std::true_type {}; + +// A simple thread safe queue using a fixed size array +template +class Queue { + public: + using value_type = T; + using pointer = T *; + using const_pointer = const T *; + using reference = T &; + using const_reference = const T &; + + void Init() { + if (sz_ > 0) { + // We allocate a block of memory and then call the default constructor for each slot. Maybe simpler to call + // new[] but we want to control where the memory is allocated from. + arr_ = alloc_.allocate(sz_); + for (uint64_t i = 0; i < sz_; i++) { + std::allocator_traits>::construct(alloc_, &(arr_[i])); + } + } + } + + explicit Queue(int sz) + : sz_(sz), + arr_(nullptr), + head_(0), + tail_(0), + my_name_(Services::GetUniqueID()), + alloc_(Services::GetInstance().GetServiceMemPool()) { + Init(); + MS_LOG(DEBUG) << "Create Q with uuid " << my_name_ << " of size " << sz_ << "."; + } + + virtual ~Queue() { + ResetQue(); + if (arr_) { + // Simply free the pointer. Since there is nothing in the queue. We don't want to invoke the destructor + // of T in each slot. + alloc_.deallocate(arr_); + arr_ = nullptr; + } + } + + int size() const { + int v = tail_ - head_; + return (v >= 0) ? v : 0; + } + + int capacity() const { return sz_; } + + bool empty() const { return head_ == tail_; } + + void Reset() { ResetQue(); } + + // Producer + Status Add(const_reference ele) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + arr_[k] = ele; + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + Status Add(T &&ele) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + arr_[k] = std::forward(ele); + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + template + Status EmplaceBack(Ts &&... args) noexcept { + std::unique_lock _lock(mux_); + // Block when full + Status rc = full_cv_.Wait(&_lock, [this]() -> bool { return (size() != capacity()); }); + if (rc.IsOk()) { + uint32_t k = tail_++ % sz_; + new (&(arr_[k])) T(std::forward(args)...); + empty_cv_.NotifyAll(); + _lock.unlock(); + } else { + empty_cv_.Interrupt(); + } + return rc; + } + + // Consumer + Status PopFront(pointer p) { + std::unique_lock _lock(mux_); + // Block when empty + Status rc = empty_cv_.Wait(&_lock, [this]() -> bool { return !empty(); }); + if (rc.IsOk()) { + uint32_t k = head_++ % sz_; + *p = std::move(arr_[k]); + if (std::is_destructible::value) { + // std::move above only changes arr_[k] from rvalue to lvalue. + // The real implementation of move constructor depends on T. 
+ // It may be compiler generated or user defined. But either case + // the result of arr_[k] is still a valid object of type T, and + // we will not keep any extra copy in the queue. + arr_[k].~T(); + // For gcc 9, an extra fix is needed here to clear the memory content + // of arr_[k] because this slot can be reused by another Add which can + // do another std::move. We have seen SEGV here in this case. + std::allocator_traits>::construct(alloc_, &(arr_[k])); + } + full_cv_.NotifyAll(); + _lock.unlock(); + } else { + full_cv_.Interrupt(); + } + return rc; + } + + void ResetQue() noexcept { + std::unique_lock _lock(mux_); + // If there are elements in the queue, invoke its destructor one by one. + if (!empty() && std::is_destructible::value) { + for (uint64_t i = head_; i < tail_; i++) { + uint32_t k = i % sz_; + arr_[k].~T(); + } + } + for (uint64_t i = 0; i < sz_; i++) { + std::allocator_traits>::construct(alloc_, &(arr_[i])); + } + empty_cv_.ResetIntrpState(); + full_cv_.ResetIntrpState(); + head_ = 0; + tail_ = 0; + } + + Status Register(TaskGroup *vg) { + Status rc1 = empty_cv_.Register(vg->GetIntrpService()); + Status rc2 = full_cv_.Register(vg->GetIntrpService()); + if (rc1.IsOk()) { + return rc2; + } else { + return rc1; + } + } + + private: + uint64_t sz_; + pointer arr_; + uint64_t head_; + uint64_t tail_; + std::string my_name_; + std::mutex mux_; + CondVar empty_cv_; + CondVar full_cv_; + Allocator alloc_; +}; + +// A container of queues with [] operator accessors. Basically this is a wrapper over of a vector of queues +// to help abstract/simplify code that is maintaining multiple queues. +template +class QueueList { + public: + QueueList() {} + + void Init(int num_queues, int capacity) { + queue_list_.reserve(num_queues); + for (int i = 0; i < num_queues; i++) { + queue_list_.emplace_back(std::make_unique>(capacity)); + } + } + + Status Register(TaskGroup *vg) { + if (vg == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Null task group during QueueList registration."); + } + for (int i = 0; i < queue_list_.size(); ++i) { + RETURN_IF_NOT_OK(queue_list_[i]->Register(vg)); + } + return Status::OK(); + } + + int size() const { return queue_list_.size(); } + + std::unique_ptr> &operator[](const int index) { return queue_list_[index]; } + + const std::unique_ptr> &operator[](const int index) const { return queue_list_[index]; } + + ~QueueList() = default; + + private: + // Queue contains non-copyable objects, so it cannot be added to a vector due to the vector + // requirement that objects must have copy semantics. To resolve this, we use a vector of unique + // pointers. This allows us to provide dynamic creation of queues in a container. + std::vector>> queue_list_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_QUEUE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/random.h b/mindspore/ccsrc/minddata/dataset/util/random.h new file mode 100644 index 0000000000..d2658f67ec --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/random.h @@ -0,0 +1,74 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
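A producer/consumer sketch for the bounded Queue above. Note the queue draws its memory from the global Services pool, so Services must already be up (Services::GetInstance() bootstraps it). The loop bounds are arbitrary.

  #include "minddata/dataset/util/queue.h"

  using mindspore::dataset::Queue;
  using mindspore::dataset::Status;

  Status Produce(Queue<int> *q) {
    for (int i = 0; i < 100; ++i) {
      RETURN_IF_NOT_OK(q->Add(i));  // blocks while the queue is full
    }
    return Status::OK();
  }

  Status Consume(Queue<int> *q, int *sum) {
    for (int i = 0; i < 100; ++i) {
      int v = 0;
      RETURN_IF_NOT_OK(q->PopFront(&v));  // blocks while the queue is empty
      *sum += v;
    }
    return Status::OK();
  }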
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_RANDOM_H_ +#define DATASET_UTIL_RANDOM_H_ + +#if defined(_WIN32) || defined(_WIN64) +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/config_manager.h" +#include "minddata/dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +inline std::mt19937 GetRandomDevice() { +#if defined(_WIN32) || defined(_WIN64) + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; +#else + int i = 0; + while (i < 5) { + try { + std::mt19937 random_device{std::random_device("/dev/urandom")()}; + return random_device; + } catch (const std::exception &e) { + MS_LOG(WARNING) << "Get std::random_device failed, retry: " << i << ", error: " << e.what(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + i++; + } + } + std::mt19937 random_device{std::random_device("/dev/urandom")()}; +#endif + return random_device; +} + +inline uint32_t GetNewSeed() { + std::mt19937 random_device = GetRandomDevice(); + std::uniform_int_distribution distribution(0, std::numeric_limits::max()); + return distribution(random_device); +} + +inline uint32_t GetSeed() { + uint32_t seed = GlobalContext::config_manager()->seed(); + if (seed == std::mt19937::default_seed) { + seed = GetNewSeed(); + } + return seed; +} + +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_RANDOM_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.cc b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc new file mode 100644 index 0000000000..5dadd98f3c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
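How a caller would consume GetSeed() above, a minimal sketch: the configured seed is used when set, otherwise a fresh one is drawn through GetRandomDevice().

  #include <random>
  #include "minddata/dataset/util/random.h"

  uint32_t DrawDigit() {
    std::mt19937 rng(mindspore::dataset::GetSeed());
    std::uniform_int_distribution<uint32_t> pick(0, 9);
    return pick(rng);
  }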
+ */
+#include "minddata/dataset/util/semaphore.h"
+#include "minddata/dataset/util/task_manager.h"
+
+namespace mindspore {
+namespace dataset {
+Status Semaphore::P() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  RETURN_IF_NOT_OK(wait_cond_.Wait(&lck, [this]() { return value_ > 0; }));
+  --value_;
+  return Status::OK();
+}
+void Semaphore::V() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  ++value_;
+  wait_cond_.NotifyOne();
+}
+int Semaphore::Peek() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  return value_;
+}
+Status Semaphore::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
+Status Semaphore::Deregister() { return (wait_cond_.Deregister()); }
+void Semaphore::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
+
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/util/semaphore.h b/mindspore/ccsrc/minddata/dataset/util/semaphore.h
new file mode 100644
index 0000000000..d07398acb1
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/semaphore.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_UTIL_SEMAPHORE_H_
+#define DATASET_UTIL_SEMAPHORE_H_
+
+#include "minddata/dataset/util/cond_var.h"
+
+namespace mindspore {
+namespace dataset {
+class TaskGroup;
+
+/// \brief A counting semaphore with two external functions, P and V. P decrements the internal count and
+/// blocks if the count is 0 (zero). V increments the internal count and wakes up one of the waiters.
+class Semaphore {
+ public:
+  /// \brief Constructor
+  /// \param init Initial value of the internal counter.
+  explicit Semaphore(int init) : value_(init) {}
+
+  virtual ~Semaphore() {}
+  /// \brief Decrement the internal counter. Blocks if the value is 0.
+  /// \return Error code. Can return an interrupt.
+  Status P();
+  /// \brief Increment the internal counter. Wakes up one of the waiters, if any.
+  void V();
+  /// \brief Peek at the internal value.
+  /// \return The internal value
+  int Peek();
+  Status Register(TaskGroup *vg);
+  Status Deregister();
+  void ResetIntrpState();
+
+ private:
+  int value_;
+
+  std::mutex mutex_;
+  CondVar wait_cond_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // DATASET_UTIL_SEMAPHORE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/util/service.cc b/mindspore/ccsrc/minddata/dataset/util/service.cc
new file mode 100644
index 0000000000..19d60ab47a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/service.cc
@@ -0,0 +1,71 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/util/service.h"
+#include <thread>
+
+namespace mindspore {
+namespace dataset {
+Status Service::ServiceStart() {
+  do {
+    UniqueLock lck(&state_lock_);
+    // No-op if it is already up or some other thread is
+    // in the process of bringing it up.
+    if (state_ == STATE::kRunning || state_ == STATE::kStartInProg) {
+      return Status::OK();
+    }
+    // If a stop is in progress, we line up after it
+    // is done.
+    if (state_ == STATE::kStopInProg) {
+      std::this_thread::yield();
+    } else {
+      state_ = STATE::kStartInProg;
+      // At this point, we will let go of the lock. This allows others to proceed.
+      lck.Unlock();
+      RETURN_IF_NOT_OK(DoServiceStart());
+      // Lock again to change state.
+      lck.Lock();
+      state_ = STATE::kRunning;
+      return Status::OK();
+    }
+  } while (true);
+}
+
+Status Service::ServiceStop() noexcept {
+  do {
+    UniqueLock lck(&state_lock_);
+    // No-op if it is already stopped or some other thread is
+    // in the process of shutting it down.
+    if (state_ == STATE::kStopped || state_ == STATE::kStopInProg) {
+      return Status::OK();
+    }
+    // If a start is in progress, we line up after it
+    // is done.
+    if (state_ == STATE::kStartInProg) {
+      std::this_thread::yield();
+    } else {
+      state_ = STATE::kStopInProg;
+      // At this point, we will let go of the lock. This allows others to proceed.
+      lck.Unlock();
+      RETURN_IF_NOT_OK(DoServiceStop());
+      // Lock again to change state.
+      lck.Lock();
+      state_ = STATE::kStopped;
+      return Status::OK();
+    }
+  } while (true);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/util/service.h b/mindspore/ccsrc/minddata/dataset/util/service.h
new file mode 100644
index 0000000000..2b9c7197fe
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/service.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
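A sketch of bounding concurrent work with the counting semaphore above. The capacity of four and the function name are assumptions for illustration.

  #include "minddata/dataset/util/semaphore.h"

  using mindspore::dataset::Semaphore;
  using mindspore::dataset::Status;

  Semaphore slots(4);  // at most four callers inside the critical section

  Status DoBoundedWork() {
    RETURN_IF_NOT_OK(slots.P());  // blocks at 0; returns an interrupt status on shutdown
    // ... bounded work here ...
    slots.V();
    return Status::OK();
  }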
+ */ +#ifndef DATASET_UTIL_SERVICE_H_ +#define DATASET_UTIL_SERVICE_H_ + +#include +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class Service { + public: + enum class STATE : int { kStartInProg = 1, kRunning, kStopInProg, kStopped }; + + Service() : state_(STATE::kStopped) {} + + Service(const Service &) = delete; + + Service &operator=(const Service &) = delete; + + virtual ~Service() {} + + STATE ServiceState() const { return state_; } + + virtual Status DoServiceStart() = 0; + + virtual Status DoServiceStop() = 0; + + Status ServiceStart(); + + Status ServiceStop() noexcept; + + protected: + STATE state_; + RWLock state_lock_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SERVICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/services.cc b/mindspore/ccsrc/minddata/dataset/util/services.cc new file mode 100644 index 0000000000..547773e0f1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/services.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/services.h" + +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#else +#include +#endif +#include +#include "minddata/dataset/engine/cache/cache_server.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +std::unique_ptr Services::instance_ = nullptr; +std::once_flag Services::init_instance_flag_; + +#if !defined(_WIN32) && !defined(_WIN64) +std::string Services::GetUserName() { + char user[LOGIN_NAME_MAX]; + (void)getlogin_r(user, sizeof(user)); + return std::string(user); +} + +std::string Services::GetHostName() { + char host[LOGIN_NAME_MAX]; + (void)gethostname(host, sizeof(host)); + return std::string(host); +} + +int Services::GetLWP() { return syscall(SYS_gettid); } +#endif + +std::string Services::GetUniqueID() { + const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; + std::mt19937 gen = GetRandomDevice(); + std::uniform_int_distribution dist(0, kStr.size() - 1); + char buffer[UNIQUEID_LEN]; + for (int i = 0; i < UNIQUEID_LEN; i++) { + buffer[i] = kStr[dist(gen)]; + } + return std::string(buffer, UNIQUEID_LEN); +} + +TaskManager &Services::getTaskMgrInstance() { + Services &sm = GetInstance(); + return *(static_cast(sm.sa_[kSlotTaskMgr_])); +} + +CacheServer &Services::getCacheServer() { + Services &sm = GetInstance(); + return *(static_cast(sm.sa_[kSlotCacheMgr_])); +} + +Status Services::CreateAllInstances() { + // In order, TaskMgr, BufferMgr + Status rc; + sa_[kSlotTaskMgr_] = new (&rc, pool_) TaskManager(); + RETURN_IF_NOT_OK(rc); + rc = sa_[kSlotTaskMgr_]->ServiceStart(); + RETURN_IF_NOT_OK(rc); + // TODO(jesse) : Get the parameters from config file. 
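How a subclass plugs into the Service state machine above, a minimal sketch; EchoService is hypothetical. Callers only ever invoke ServiceStart()/ServiceStop(), and the base class serializes concurrent starts and stops through state_lock_.

  #include "minddata/dataset/util/service.h"
  #include "utils/log_adapter.h"

  using mindspore::dataset::Status;

  class EchoService : public mindspore::dataset::Service {
   public:
    Status DoServiceStart() override {
      MS_LOG(INFO) << "EchoService started";
      return Status::OK();
    }
    Status DoServiceStop() override {
      MS_LOG(INFO) << "EchoService stopped";
      return Status::OK();
    }
  };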
Right now spill to /tmp and spawn 3 workers + sa_[kSlotCacheMgr_] = new (&rc, pool_) CacheServer("/tmp", 3); + RETURN_IF_NOT_OK(rc); + rc = sa_[kSlotCacheMgr_]->ServiceStart(); + return rc; +} + +Services::Services() : pool_(nullptr), sa_{nullptr} { + Status rc = CircularPool::CreateCircularPool(&pool_, -1, 16, true); // each arena 16M + if (rc.IsError()) { + std::terminate(); + } +} + +Services::~Services() noexcept { + try { + // In reverse order + CacheServer *cs = static_cast(sa_[kSlotCacheMgr_]); + if (cs != nullptr) { + (void)cs->ServiceStop(); + cs->~CacheServer(); + pool_->Deallocate(cs); + } + TaskManager *tm = static_cast(sa_[kSlotTaskMgr_]); + if (tm != nullptr) { + (void)tm->ServiceStop(); + tm->~TaskManager(); + pool_->Deallocate(tm); + } + } catch (const std::exception &e) { + // Do nothing. + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/services.h b/mindspore/ccsrc/minddata/dataset/util/services.h new file mode 100644 index 0000000000..c7adea0b6e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/services.h @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SERVICES_H_ +#define DATASET_UTIL_SERVICES_H_ + +#include +#include +#include +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/service.h" + +#define UNIQUEID_LEN 36 +namespace mindspore { +namespace dataset { +class TaskManager; +class CacheServer; +class Services { + public: + static Status CreateInstance() { + std::call_once(init_instance_flag_, [&]() -> Status { + instance_.reset(new Services()); + return (instance_->CreateAllInstances()); + }); + + if (instance_ == nullptr) { + instance_.reset(new Services()); + return (instance_->CreateAllInstances()); + } + + return Status::OK(); + } + + static Services &GetInstance() { + if (instance_ == nullptr) { + if (!CreateInstance()) { + std::terminate(); + } + } + return *instance_; + } + + Services(const Services &) = delete; + + Services &operator=(const Services &) = delete; + + ~Services() noexcept; + + static TaskManager &getTaskMgrInstance(); + + static CacheServer &getCacheServer(); + + std::shared_ptr GetServiceMemPool() { return pool_; } + +#if !defined(_WIN32) && !defined(_WIN64) + static std::string GetUserName(); + + static std::string GetHostName(); + + static int GetLWP(); +#endif + + static std::string GetUniqueID(); + + template + static Allocator GetAllocator() { + return Allocator(Services::GetInstance().GetServiceMemPool()); + } + + private: + static std::once_flag init_instance_flag_; + static std::unique_ptr instance_; + // A small pool used for small objects that last until the + // Services Manager shuts down. Used by all sub-services. + std::shared_ptr pool_; + // We use pointers here instead of unique_ptr because we + // want to have ultimate control on the order of + // construction and destruction. 
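A sketch of the two Services helpers most used elsewhere in this patch: a pool-backed container via GetAllocator, and GetUniqueID (Queue uses the latter for its debug name). Function and variable names are illustrative.

  #include <string>
  #include <vector>
  #include "minddata/dataset/util/services.h"

  using mindspore::dataset::Allocator;
  using mindspore::dataset::Services;

  void PoolBackedVector() {
    // Storage comes from the services-wide circular pool, not the system heap.
    std::vector<int, Allocator<int>> v(Services::GetAllocator<int>());
    v.push_back(42);
    std::string q_name = Services::GetUniqueID();  // 36 random chars from [a-z0-9]
  }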
+ static constexpr int kSlotTaskMgr_ = 0; + static constexpr int kSlotCacheMgr_ = 1; + static constexpr int kNumServices_ = 2; + Service *sa_[kNumServices_]; + + Services(); + + Status CreateAllInstances(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_SERVICES_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc b/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc new file mode 100644 index 0000000000..eed3b4ee4d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/sig_handler.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/sig_handler.h" +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#endif +#include +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +// Register the custom signal handlers +#if !defined(_WIN32) && !defined(_WIN64) +void RegisterHandlers() { + struct sigaction new_int_action; + + // For the interrupt handler, we do not use SA_RESETHAND so this handler remains in play + // permanently, do not use the OS default handler for it. + new_int_action.sa_sigaction = &IntHandler; + (void)sigemptyset(&new_int_action.sa_mask); + new_int_action.sa_flags = SA_RESTART | SA_SIGINFO; + (void)sigaction(SIGINT, &new_int_action, nullptr); +} + +extern void IntHandler(int sig_num, // The signal that was raised + siginfo_t *sig_info, // The siginfo structure. + void *context) { // context info + // Wake up the watchdog which is designed as async-signal-safe. + TaskManager::WakeUpWatchDog(); +} +#endif +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/sig_handler.h b/mindspore/ccsrc/minddata/dataset/util/sig_handler.h similarity index 100% rename from mindspore/ccsrc/dataset/util/sig_handler.h rename to mindspore/ccsrc/minddata/dataset/util/sig_handler.h diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.cc b/mindspore/ccsrc/minddata/dataset/util/slice.cc new file mode 100644 index 0000000000..beff2b3dd2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/slice.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
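RegisterHandlers above is meant to run once, early in process start-up (non-Windows builds only); a hypothetical entry point:

  #include "minddata/dataset/util/sig_handler.h"

  int main() {
    mindspore::dataset::RegisterHandlers();  // SIGINT now wakes the TaskManager watchdog
    // ... run the pipeline ...
    return 0;
  }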
+*/ +#include "minddata/dataset/util/slice.h" + +namespace mindspore { +namespace dataset { +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset, size_t len) : ReadableSlice(src, offset, len) { + mutable_data_ = static_cast(src.mutable_data_) + offset; +} +WritableSlice::WritableSlice(const WritableSlice &src, off64_t offset) + : WritableSlice(src, offset, src.GetSize() - offset) {} +Status WritableSlice::Copy(WritableSlice *dest, const ReadableSlice &src) { + RETURN_UNEXPECTED_IF_NULL(dest); + RETURN_UNEXPECTED_IF_NULL(dest->GetMutablePointer()); + if (dest->GetSize() <= 0) { + RETURN_STATUS_UNEXPECTED("Destination length is non-positive"); + } + auto err = memcpy_s(dest->GetMutablePointer(), dest->GetSize(), src.GetPointer(), src.GetSize()); + if (err) { + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/slice.h b/mindspore/ccsrc/minddata/dataset/util/slice.h new file mode 100644 index 0000000000..1caee0f816 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/slice.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SLICE_H_ +#define DATASET_UTIL_SLICE_H_ + +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/status.h" +namespace mindspore { +namespace dataset { +/// \brief A ReadableSlice wraps a const pointer in memory and its size. 
+/// \see WritableSlice for a non-const version +/// +class ReadableSlice { + public: + ReadableSlice() : ptr_(nullptr), sz_(0) {} + ReadableSlice(const void *ptr, size_t sz) : ptr_(ptr), sz_(sz) {} + + /// \brief Destructor + ~ReadableSlice() = default; + + ReadableSlice(const ReadableSlice &src, off64_t offset, size_t len) { + ptr_ = static_cast(src.GetPointer()) + offset; + sz_ = len; + } + ReadableSlice(const ReadableSlice &src, off64_t offset) : ReadableSlice(src, offset, src.sz_ - offset) {} + ReadableSlice(const ReadableSlice &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + ReadableSlice &operator=(const ReadableSlice &lhs) { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + } + return *this; + } + ReadableSlice(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + } + ReadableSlice &operator=(ReadableSlice &&lhs) noexcept { + if (this != &lhs) { + ptr_ = lhs.ptr_; + sz_ = lhs.sz_; + lhs.ptr_ = nullptr; + lhs.sz_ = 0; + } + return *this; + } + /// \brief Getter function + /// \return Const version of the pointer + const void *GetPointer() const { return ptr_; } + /// \brief Getter function + /// \return Size of the slice + size_t GetSize() const { return sz_; } + bool empty() const { return ptr_ == nullptr; } + + private: + const void *ptr_; + size_t sz_; +}; +/// \brief A WritableSlice inherits from ReadableSlice to allow +/// one to write to the address pointed to by the pointer. +/// +class WritableSlice : public ReadableSlice { + public: + friend class StorageContainer; + /// \brief Default constructor + WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {} + /// \brief This form of a constructor takes a pointer and its size. + WritableSlice(void *ptr, size_t sz) : ReadableSlice(ptr, sz), mutable_data_(ptr) {} + WritableSlice(const WritableSlice &src, off64_t offset, size_t len); + WritableSlice(const WritableSlice &src, off64_t offset); + WritableSlice(const WritableSlice &lhs) : ReadableSlice(lhs) { mutable_data_ = lhs.mutable_data_; } + /// \brief Destructor + ~WritableSlice() = default; + WritableSlice &operator=(const WritableSlice &lhs) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + ReadableSlice::operator=(lhs); + } + return *this; + } + WritableSlice(WritableSlice &&lhs) noexcept : ReadableSlice(std::move(lhs)) { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + } + } + WritableSlice &operator=(WritableSlice &&lhs) noexcept { + if (this != &lhs) { + mutable_data_ = lhs.mutable_data_; + lhs.mutable_data_ = nullptr; + ReadableSlice::operator=(std::move(lhs)); + } + return *this; + } + /// \brief Copy the content from one slice onto another. + static Status Copy(WritableSlice *dest, const ReadableSlice &src); + + private: + void *mutable_data_; + void *GetMutablePointer() { return mutable_data_; } +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_UTIL_SLICE_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/status.cc b/mindspore/ccsrc/minddata/dataset/util/status.cc new file mode 100644 index 0000000000..3fc498b701 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/status.cc @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
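A round-trip sketch for the slice types above; buffers and sizes are made up. WritableSlice::Copy validates the destination pointer and surfaces any memcpy_s error code as a Status.

  #include "minddata/dataset/util/slice.h"

  using mindspore::dataset::ReadableSlice;
  using mindspore::dataset::Status;
  using mindspore::dataset::WritableSlice;

  Status RoundTrip() {
    char src_buf[8] = "hello";
    char dst_buf[8] = {0};
    ReadableSlice src(src_buf, sizeof(src_buf));
    WritableSlice dst(dst_buf, sizeof(dst_buf));
    return WritableSlice::Copy(&dst, src);  // copies src into dst_buf
  }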
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/status.h" +#include +#include "common/utils.h" +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +std::string CodeAsString(const StatusCode c) { + const char *s = nullptr; + if (c == StatusCode::kOK) { + // Optimize the most frequent case + return std::string("OK"); + } else { + switch (c) { + case StatusCode::kOutOfMemory: + s = "Out of memory"; + break; + case StatusCode::kInterrupted: + s = "Interrupted system call"; + break; + case StatusCode::kShapeMisMatch: + s = "Shape is incorrect."; + break; + case StatusCode::kNoSpace: + s = "No space left on device"; + break; + case StatusCode::kPyFuncException: + s = "Exception thrown from PyFunc"; + break; + case StatusCode::kDuplicateKey: + s = "Duplicate key"; + break; + case StatusCode::kProfilingError: + s = "Error encountered while profiling"; + break; + case StatusCode::kUnexpectedError: + default: + s = "Unexpected error"; + break; + } + } + return std::string(s); +} + +Status::Status(StatusCode c) noexcept : code_(c), err_msg_(std::move(CodeAsString(c))) {} + +Status::Status() noexcept : code_(StatusCode::kOK), err_msg_("") {} + +Status::~Status() noexcept {} + +Status::Status(const Status &s) : code_(s.code_), err_msg_(s.err_msg_) {} + +Status &Status::operator=(const Status &s) { + if (this == &s) { + return *this; + } + code_ = s.code_; + err_msg_ = s.err_msg_; + return *this; +} + +Status::Status(Status &&s) noexcept { + code_ = s.code_; + s.code_ = StatusCode::kOK; + err_msg_ = std::move(s.err_msg_); +} + +Status &Status::operator=(Status &&s) noexcept { + if (this == &s) { + return *this; + } + code_ = s.code_; + s.code_ = StatusCode::kOK; + err_msg_ = std::move(s.err_msg_); + return *this; +} + +Status::Status(const StatusCode code, const std::string &msg) : code_(code), err_msg_(msg) {} + +Status::Status(const StatusCode code, int line_of_code, const char *file_name, const std::string &extra) { + code_ = code; + std::ostringstream ss; + ss << "Thread ID " << this_thread::get_id() << " " << CodeAsString(code) << ". 
"; + if (!extra.empty()) { + ss << extra; + } + ss << "\n"; + ss << "Line of code : " << line_of_code << "\n"; + if (file_name != nullptr) { + ss << "File : " << file_name << "\n"; + } + err_msg_ = ss.str(); + MS_LOG(INFO) << err_msg_; +} + +std::ostream &operator<<(std::ostream &os, const Status &s) { + os << s.ToString(); + return os; +} + +std::string Status::ToString() const { return err_msg_; } + +StatusCode Status::get_code() const { return code_; } +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/util/status.h b/mindspore/ccsrc/minddata/dataset/util/status.h similarity index 100% rename from mindspore/ccsrc/dataset/util/status.h rename to mindspore/ccsrc/minddata/dataset/util/status.h diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.cc b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc new file mode 100644 index 0000000000..506495227d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.cc @@ -0,0 +1,163 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/storage_container.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +Status StorageContainer::Create() { + RETURN_IF_NOT_OK(BuddySpace::CreateBuddySpace(&bs_)); + RETURN_IF_NOT_OK(cont_.CreateFile(&fd_)); + is_open_ = true; + MS_LOG(INFO) << "Container " << cont_ << " created"; + return Status::OK(); +} + +Status StorageContainer::Open() noexcept { + std::lock_guard lck(mutex_); + // Check again + if (!is_open_) { + RETURN_IF_NOT_OK(cont_.OpenFile(&fd_)); + is_open_ = true; + } + return Status::OK(); +} + +Status StorageContainer::Close() noexcept { + if (is_open_) { + std::lock_guard lck(mutex_); + // Check again + if (is_open_) { + RETURN_IF_NOT_OK(cont_.CloseFile(fd_)); + is_open_ = false; + fd_ = -1; + } + } + return Status::OK(); +} + +Status StorageContainer::Read(WritableSlice *dest, off64_t offset) const noexcept { + MS_ASSERT(is_open_); + RETURN_UNEXPECTED_IF_NULL(dest); + auto sz = dest->GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pread64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = read(fd_, dest->GetMutablePointer(), sz); +#else + auto r_sz = pread64(fd_, dest->GetMutablePointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? 
EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Write(const ReadableSlice &dest, off64_t offset) const noexcept { + MS_ASSERT(is_open_); + auto sz = dest.GetSize(); +#if defined(_WIN32) || defined(_WIN64) + // Doesn't seem there is any pwrite64 on mingw. + // So we will do a seek and then a read under + // a protection of mutex. + std::lock_guard lck(mutex_); + auto seek_err = lseek(fd_, offset, SEEK_SET); + if (seek_err < 0) { + RETURN_STATUS_UNEXPECTED(strerror(errno)); + } + auto r_sz = write(fd_, dest.GetPointer(), sz); +#else + auto r_sz = pwrite64(fd_, dest.GetPointer(), sz, offset); +#endif + if (r_sz != sz) { + errno_t err = (r_sz == 0) ? EOF : errno; + RETURN_STATUS_UNEXPECTED(strerror(err)); + } + return Status::OK(); +} + +Status StorageContainer::Insert(const std::vector &buf, off64_t *offset) noexcept { + size_t sz = 0; + for (auto &v : buf) { + sz += v.GetSize(); + } + if (sz == 0) { + RETURN_STATUS_UNEXPECTED("Unexpected 0 length"); + } + if (sz > bs_->GetMaxSize()) { + RETURN_STATUS_UNEXPECTED("Request size too big"); + } + BSpaceDescriptor bspd{0}; + addr_t addr = 0; + RETURN_IF_NOT_OK(bs_->Alloc(sz, &bspd, &addr)); + *offset = static_cast(addr); + // We will do piecewise copy of the data to disk. + for (auto &v : buf) { + RETURN_IF_NOT_OK(Write(v, addr)); + addr += v.GetSize(); + } + return Status::OK(); +} + +Status StorageContainer::Truncate() const noexcept { + if (is_open_) { + RETURN_IF_NOT_OK(cont_.TruncateFile(fd_)); + MS_LOG(INFO) << "Container " << cont_ << " truncated"; + } + return Status::OK(); +} + +StorageContainer::~StorageContainer() noexcept { + (void)Truncate(); + (void)Close(); +} + +std::ostream &operator<<(std::ostream &os, const StorageContainer &s) { + os << "File path : " << s.cont_ << "\n" << *(s.bs_.get()); + return os; +} + +Status StorageContainer::CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path) { + Status rc; + auto sc = new (std::nothrow) StorageContainer(path); + if (sc == nullptr) { + return Status(StatusCode::kOutOfMemory); + } + rc = sc->Create(); + if (rc.IsOk()) { + (*out_sc).reset(sc); + } else { + delete sc; + } + return rc; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_container.h b/mindspore/ccsrc/minddata/dataset/util/storage_container.h new file mode 100644 index 0000000000..a304012b60 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_container.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
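A spill-and-reload sketch for StorageContainer above. The file name is made up (StorageManager normally generates it), and the caller must keep out at least len bytes long.

  #include <memory>
  #include "minddata/dataset/util/storage_container.h"

  using mindspore::dataset::ReadableSlice;
  using mindspore::dataset::Status;
  using mindspore::dataset::StorageContainer;
  using mindspore::dataset::WritableSlice;

  Status SpillAndReload(const void *data, size_t len, void *out) {
    std::shared_ptr<StorageContainer> sc;
    RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, "/tmp/IMG00000.LB"));
    off64_t offset = 0;
    RETURN_IF_NOT_OK(sc->Insert({ReadableSlice(data, len)}, &offset));  // buddy-allocated slot
    WritableSlice dest(out, len);
    return sc->Read(&dest, offset);
  }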
+ */ +#ifndef DATASET_UTIL_STORAGE_CONTAINER_H_ +#define DATASET_UTIL_STORAGE_CONTAINER_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/buddy.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class StorageManager; + +class StorageContainer { + public: + friend class StorageManager; + + ~StorageContainer() noexcept; + + StorageContainer(const StorageContainer &) = delete; + + StorageContainer &operator=(const StorageContainer &) = delete; + + friend std::ostream &operator<<(std::ostream &os, const StorageContainer &s); + + Status Open() noexcept; + + Status Close() noexcept; + + Status Insert(const std::vector &buf, off64_t *offset) noexcept; + + Status Write(const ReadableSlice &dest, off64_t offset) const noexcept; + + Status Read(WritableSlice *dest, off64_t offset) const noexcept; + + Status Truncate() const noexcept; + + bool IsOpen() const { return is_open_; } + + static Status CreateStorageContainer(std::shared_ptr *out_sc, const std::string &path); + + private: + mutable std::mutex mutex_; + Path cont_; + int fd_; + bool is_open_; + std::unique_ptr bs_; + + // Use the default value of BuddySpace + // which can map upto 4G of space. + explicit StorageContainer(const std::string &path) : cont_(path), fd_(-1), is_open_(false), bs_(nullptr) {} + + Status Create(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_CONTAINER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc new file mode 100644 index 0000000000..2f85d00a45 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.cc @@ -0,0 +1,166 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/storage_manager.h" + +#include +#include +#include +#include +#include "common/utils.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/services.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +std::string StorageManager::GetBaseName(const std::string &prefix, int32_t file_id) { + std::ostringstream oss; + oss << prefix << std::setfill('0') << std::setw(5) << file_id; + return oss.str(); +} + +std::string StorageManager::ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix) { + std::string base_name = GetBaseName(prefix, file_id); + return (base_name + "." 
+ suffix);
+}
+
+Status StorageManager::AddOneContainer() {
+  const std::string kPrefix = "IMG";
+  const std::string kSuffix = "LB";
+  Path container_name = root_ / ConstructFileName(kPrefix, file_id_, kSuffix);
+  std::shared_ptr<StorageContainer> sc;
+  RETURN_IF_NOT_OK(StorageContainer::CreateStorageContainer(&sc, container_name.toString()));
+  containers_.push_back(sc);
+  file_id_++;
+  return Status::OK();
+}
+
+Status StorageManager::DoServiceStart() {
+  containers_.reserve(1000);
+  if (root_.IsDirectory()) {
+    RETURN_IF_NOT_OK(AddOneContainer());
+  } else {
+    RETURN_STATUS_UNEXPECTED("Not a directory");
+  }
+  return Status::OK();
+}
+
+Status StorageManager::Write(key_type *key, const std::vector<ReadableSlice> &buf) {
+  RETURN_UNEXPECTED_IF_NULL(key);
+  size_t sz = 0;
+  for (auto &v : buf) {
+    sz += v.GetSize();
+  }
+  if (sz == 0) {
+    RETURN_STATUS_UNEXPECTED("Unexpected 0 length");
+  }
+  std::shared_ptr<StorageContainer> cont;
+  key_type out_key;
+  value_type out_value;
+  bool create_new_container = false;
+  do {
+    SharedLock lock_s(&rw_lock_);
+    size_t num_containers = containers_.size();
+    if (create_new_container) {
+      // Upgrade to exclusive lock.
+      lock_s.Upgrade();
+      create_new_container = false;
+      // Check again if someone has already added a
+      // new container after we got the x lock.
+      if (containers_.size() == num_containers) {
+        RETURN_IF_NOT_OK(AddOneContainer());
+      }
+      // Refresh how many containers there are.
+      num_containers = containers_.size();
+      // Downgrade back to shared lock.
+      lock_s.Downgrade();
+    }
+    if (num_containers == 0) {
+      RETURN_STATUS_UNEXPECTED("num_containers is zero");
+    }
+    // Go to the last container to insert.
+    cont = containers_.at(num_containers - 1);
+    off64_t offset;
+    Status rc = cont->Insert(buf, &offset);
+    if (rc.IsNoSpace()) {
+      create_new_container = true;
+    } else if (rc.IsOk()) {
+      out_value = std::make_pair(num_containers - 1, std::make_pair(offset, sz));
+      RETURN_IF_NOT_OK(index_.insert(out_value, &out_key));
+      *key = out_key;
+      break;
+    } else {
+      return rc;
+    }
+  } while (true);
+  return Status::OK();
+}
+
+Status StorageManager::Read(StorageManager::key_type key, WritableSlice *dest, size_t *bytesRead) const {
+  RETURN_UNEXPECTED_IF_NULL(dest);
+  auto r = index_.Search(key);
+  if (r.second) {
+    auto &it = r.first;
+    value_type v = *it;
+    int container_inx = v.first;
+    off_t offset = v.second.first;
+    size_t sz = v.second.second;
+    if (dest->GetSize() < sz) {
+      std::string errMsg = "Destination buffer too small. Expect at least " + std::to_string(sz) +
+                           " but length = " + std::to_string(dest->GetSize());
+      RETURN_STATUS_UNEXPECTED(errMsg);
+    }
+    if (bytesRead != nullptr) {
+      *bytesRead = sz;
+    }
+    auto cont = containers_.at(container_inx);
+    RETURN_IF_NOT_OK(cont->Read(dest, offset));
+  } else {
+    RETURN_STATUS_UNEXPECTED("Key not found");
+  }
+  return Status::OK();
+}
+
+Status StorageManager::DoServiceStop() noexcept {
+  Status rc;
+  Status rc1;
+  for (auto const &p : containers_) {
+    // The destructor of StorageContainer is not called automatically until the use
+    // count drops to 0. But that is not always the case. We will do it ourselves.
+    rc = p.get()->Truncate();
+    if (rc.IsError()) {
+      rc1 = rc;
+    }
+  }
+  containers_.clear();
+  file_id_ = 0;
+  return rc1;
+}
+
+StorageManager::StorageManager(const Path &root) : root_(root), file_id_(0), index_() {}
+
+StorageManager::~StorageManager() { (void)StorageManager::DoServiceStop(); }
+
+std::ostream &operator<<(std::ostream &os, const StorageManager &s) {
+  os << "Dumping all containers ..."
+ << "\n"; + for (auto const &p : s.containers_) { + os << *(p.get()); + } + return os; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/util/storage_manager.h b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h new file mode 100644 index 0000000000..e79e7c6e63 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/storage_manager.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_STORAGE_MANAGER_H_ +#define DATASET_UTIL_STORAGE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/lock.h" +#include "minddata/dataset/util/memory_pool.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/service.h" +#include "minddata/dataset/util/slice.h" +#include "minddata/dataset/util/storage_container.h" + +using ListOfContainers = std::vector>; +namespace mindspore { +namespace dataset { +class StorageManager : public Service { + public: + using storage_index = AutoIndexObj>>; + using key_type = storage_index::key_type; + using value_type = storage_index::value_type; + + explicit StorageManager(const Path &); + + ~StorageManager() override; + + StorageManager(const StorageManager &) = delete; + + StorageManager &operator=(const StorageManager &) = delete; + + Status Write(key_type *out_key, const std::vector &buf); + + Status Read(key_type key, WritableSlice *dest, size_t *bytesRead) const; + + Status DoServiceStart() override; + + Status DoServiceStop() noexcept override; + + friend std::ostream &operator<<(std::ostream &os, const StorageManager &s); + + private: + Path root_; + ListOfContainers containers_; + int file_id_; + RWLock rw_lock_; + storage_index index_; + + std::string GetBaseName(const std::string &prefix, int32_t file_id); + + std::string ConstructFileName(const std::string &prefix, int32_t file_id, const std::string &suffix); + + Status AddOneContainer(); +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_STORAGE_MANAGER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/system_pool.h b/mindspore/ccsrc/minddata/dataset/util/system_pool.h new file mode 100644 index 0000000000..3a7e61d16b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/system_pool.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
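An end-to-end sketch of the StorageManager write/read path above. The spill directory is an assumption and must already exist, since DoServiceStart rejects non-directories; out must hold at least len bytes.

  #include "minddata/dataset/util/storage_manager.h"

  using mindspore::dataset::Path;
  using mindspore::dataset::ReadableSlice;
  using mindspore::dataset::Status;
  using mindspore::dataset::StorageManager;
  using mindspore::dataset::WritableSlice;

  Status SpillViaManager(const void *data, size_t len, void *out) {
    StorageManager mgr(Path("/tmp/spill"));
    RETURN_IF_NOT_OK(mgr.ServiceStart());  // creates IMG00000.LB inside the directory
    StorageManager::key_type key;
    RETURN_IF_NOT_OK(mgr.Write(&key, {ReadableSlice(data, len)}));
    WritableSlice dest(out, len);
    size_t bytes_read = 0;
    RETURN_IF_NOT_OK(mgr.Read(key, &dest, &bytes_read));
    return mgr.ServiceStop();
  }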
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_UTIL_SYSTEM_POOL_H_ +#define DATASET_UTIL_SYSTEM_POOL_H_ + +#include +#include +#include +#include +#include +#include "./securec.h" +#include "minddata/dataset/util/allocator.h" +#include "minddata/dataset/util/memory_pool.h" + +namespace mindspore { +namespace dataset { +// This class demonstrate how to implement a simple MemoryPool +// for minddata/dataset using malloc/free/realloc. We need to +// implement 4 virtual functions. Other MemoryPool +// implementation, e.g., are BuddyArena and CircularPool. All +// these MemoryPool can be used together with Allocator.h for +// C++ STL containers. +class SystemPool : public MemoryPool { + public: + ~SystemPool() override {} + + Status Allocate(size_t n, void **pp) override { return DeMalloc(n, pp, false); } + + void Deallocate(void *p) override { free(p); } + + Status Reallocate(void **p, size_t old_sz, size_t new_sz) override { + if (old_sz >= new_sz) { + // Do nothing if we shrink. + return Status::OK(); + } else { + void *ptr = *p; + void *q = nullptr; + RETURN_IF_NOT_OK(DeMalloc(new_sz, &q, false)); + errno_t err = memcpy_s(q, new_sz, ptr, old_sz); + if (err) { + free(q); + RETURN_STATUS_UNEXPECTED(std::to_string(err)); + } + free(ptr); + *p = q; + return Status::OK(); + } + } + + uint64_t get_max_size() const override { return std::numeric_limits::max(); } + + int PercentFree() const override { return 100; } + + template + static Allocator GetAllocator() { + return Allocator(std::make_shared()); + } +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_UTIL_SYSTEM_POOL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/util/task.cc b/mindspore/ccsrc/minddata/dataset/util/task.cc new file mode 100644 index 0000000000..39d754e806 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/util/task.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/util/task.h" +#include "common/utils.h" +#include "minddata/dataset/util/task_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +thread_local Task *gMyTask = nullptr; + +void Task::operator()() { +#if !defined(_WIN32) && !defined(_WIN64) + gMyTask = this; +#endif + id_ = this_thread::get_id(); + std::stringstream ss; + ss << id_; + MS_LOG(DEBUG) << my_name_ << " Thread ID " << ss.str() << " Started."; + try { + // Previously there is a timing hole where the thread is spawn but hit error immediately before we can set + // the TaskGroup pointer and register. We move the registration logic to here (after we spawn) so we can + // get the thread id. + TaskGroup *vg = MyTaskGroup(); + rc_ = vg->GetIntrpService()->Register(ss.str(), this); + if (rc_.IsOk()) { + // Now we can run the given task. + rc_ = fnc_obj_(); + } + // Some error codes are ignored, e.g. interrupt. Others we just shutdown the group. 
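SystemPool plugs into the same Allocator machinery as the arena-based pools; a one-function sketch, with a made-up buffer size:

  #include <vector>
  #include "minddata/dataset/util/system_pool.h"

  using mindspore::dataset::Allocator;
  using mindspore::dataset::SystemPool;

  void MallocBackedBuffer() {
    // Same container code as the circular-pool case, but malloc/free underneath.
    std::vector<char, Allocator<char>> buf(SystemPool::GetAllocator<char>());
    buf.resize(1024);
  }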
+    if (rc_.IsError() && !rc_.IsInterrupted()) {
+      ShutdownGroup();
+    }
+  } catch (const std::bad_alloc &e) {
+    rc_ = Status(StatusCode::kOutOfMemory, __LINE__, __FILE__, e.what());
+    ShutdownGroup();
+  } catch (const std::exception &e) {
+    rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
+    ShutdownGroup();
+  }
+}
+
+void Task::ShutdownGroup() {  // Wake up watch dog and shutdown the engine.
+  {
+    std::lock_guard<std::mutex> lk(mux_);
+    caught_severe_exception_ = true;
+  }
+  TaskGroup *vg = MyTaskGroup();
+  // If multiple threads hit severe errors in the same group, keep the first one and
+  // discard the rest.
+  if (vg->rc_.IsOk()) {
+    std::unique_lock<std::mutex> rcLock(vg->rc_mux_);
+    // Check again after we get the lock
+    if (vg->rc_.IsOk()) {
+      vg->rc_ = rc_;
+      rcLock.unlock();
+      TaskManager::InterruptMaster(rc_);
+      TaskManager::InterruptGroup(*this);
+    }
+  }
+}
+
+Status Task::GetTaskErrorIfAny() const {
+  std::lock_guard<std::mutex> lk(mux_);
+  if (caught_severe_exception_) {
+    return rc_;
+  } else {
+    return Status::OK();
+  }
+}
+
+Task::Task(const std::string &myName, const std::function<Status()> &f)
+    : my_name_(myName),
+      rc_(),
+      fnc_obj_(f),
+      task_group_(nullptr),
+      is_master_(false),
+      running_(false),
+      caught_severe_exception_(false) {
+  IntrpResource::ResetIntrpState();
+  wp_.ResetIntrpState();
+  wp_.Clear();
+}
+
+Status Task::Run() {
+  Status rc;
+  if (running_ == false) {
+    try {
+      thrd_ = std::async(std::launch::async, std::ref(*this));
+      running_ = true;
+      caught_severe_exception_ = false;
+    } catch (const std::exception &e) {
+      rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, e.what());
+    }
+  }
+  return rc;
+}
+
+Status Task::Join(WaitFlag blocking) {
+  if (running_) {
+    RETURN_UNEXPECTED_IF_NULL(MyTaskGroup());
+    auto interrupt_svc = MyTaskGroup()->GetIntrpService();
+    try {
+      if (blocking == WaitFlag::kBlocking) {
+        // If we are asked to wait, then wait
+        thrd_.get();
+      } else if (blocking == WaitFlag::kNonBlocking) {
+        // There is a race condition in the global resource tracking such that a thread can miss the
+        // interrupt and become blocked on a condition variable forever. As a result, calling
+        // join() will not come back. We need a timeout version of join such that if the thread
+        // doesn't come back in a reasonable amount of time, we will send the interrupt again.
+        while (thrd_.wait_for(std::chrono::seconds(1)) != std::future_status::ready) {
+          // We can't tell which condition variable this thread is waiting on. So we may need
+          // to interrupt everything one more time.
+          MS_LOG(INFO) << "Some threads not responding. Interrupt again";
+          interrupt_svc->InterruptAll();
+        }
+      } else {
+        RETURN_STATUS_UNEXPECTED("Unknown WaitFlag");
+      }
+      std::stringstream ss;
+      ss << get_id();
+      MS_LOG(DEBUG) << MyName() << " Thread ID " << ss.str() << " Stopped.";
+      running_ = false;
+      RETURN_IF_NOT_OK(wp_.Deregister());
+      RETURN_IF_NOT_OK(interrupt_svc->Deregister(ss.str()));
+    } catch (const std::exception &e) {
+      RETURN_STATUS_UNEXPECTED(e.what());
+    }
+  }
+  return Status::OK();
+}
+
+TaskGroup *Task::MyTaskGroup() { return task_group_; }
+
+void Task::set_task_group(TaskGroup *vg) { task_group_ = vg; }
+
+Task::~Task() { task_group_ = nullptr; }
+Status Task::OverrideInterruptRc(const Status &rc) {
+  if (rc.IsInterrupted() && this_thread::is_master_thread()) {
+    // If we are interrupted, override the return value if this is the master thread.
+    // The master thread is interrupted mostly because some thread is reporting an error.
+
+Status Task::OverrideInterruptRc(const Status &rc) {
+  if (rc.IsInterrupted() && this_thread::is_master_thread()) {
+    // If we are interrupted, override the return value if this is the master thread.
+    // The master thread is usually interrupted because some other thread is reporting an error.
+    return TaskManager::GetMasterThreadRc();
+  }
+  return rc;
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/util/task.h b/mindspore/ccsrc/minddata/dataset/util/task.h
new file mode 100644
index 0000000000..9309a3de7b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/task.h
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_UTIL_TASK_H_
+#define DATASET_UTIL_TASK_H_
+
+#include <chrono>
+#include <exception>
+#include <functional>
+#include <future>
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include "minddata/dataset/util/intrp_resource.h"
+#include "minddata/dataset/util/list.h"
+#include "minddata/dataset/util/memory_pool.h"
+#include "minddata/dataset/util/services.h"
+#include "minddata/dataset/util/wait_post.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace dataset {
+class TaskManager;
+
+class Task : public IntrpResource {
+ public:
+  friend class TaskManager;
+  friend class TaskGroup;
+
+  enum class WaitFlag : int { kBlocking, kNonBlocking };
+
+  Task(const std::string &myName, const std::function<Status()> &f);
+
+  // Future objects are not copyable.
+  Task(const Task &) = delete;
+
+  ~Task() override;
+
+  Task &operator=(const Task &) = delete;
+
+  // Move constructor and Assignment are not supported.
+  // Too many things in this class.
+  Task(Task &&) = delete;
+
+  Task &operator=(Task &&) = delete;
+
+  Status GetTaskErrorIfAny() const;
+
+  void ChangeName(const std::string &newName) { my_name_ = newName; }
+
+  // To execute the fnc_obj_
+  void operator()();
+
+  Node<Task> node;
+  Node<Task> group;
+  Node<Task> free;
+
+  // Run the task
+  Status Run();
+
+  Status Join(WaitFlag wf = WaitFlag::kBlocking);
+
+  bool Running() const { return running_; }
+
+  bool CaughtSevereException() const { return caught_severe_exception_; }
+
+  bool IsMasterThread() const { return is_master_; }
+
+  std::thread::id get_id() { return id_; }
+
+  std::string MyName() { return my_name_; }
+
+  // An operator used by std::find
+  bool operator==(const Task &other) const { return (this == &other); }
+
+  bool operator!=(const Task &other) const { return !(*this == other); }
+
+  void Post() { wp_.Set(); }
+
+  Status Wait() { return (wp_.Wait()); }
+
+  static Status OverrideInterruptRc(const Status &rc);
+
+ private:
+  mutable std::mutex mux_;
+  std::string my_name_;
+  Status rc_;
+  WaitPost wp_;
+  // The creator of the Task needs to provide the definition of this function
+  // object. It will be called by the thread entry function.
+  std::function<Status()> fnc_obj_;
+  // Misc fields used by TaskManager.
+  TaskGroup *task_group_;
+  std::future<void> thrd_;
+  std::thread::id id_;
+  bool is_master_;
+  volatile bool running_;
+  volatile bool caught_severe_exception_;
+
+  void ShutdownGroup();
+  TaskGroup *MyTaskGroup();
+  void set_task_group(TaskGroup *vg);
+};
+
+extern thread_local Task *gMyTask;
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_UTIL_TASK_H_
diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.cc b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc
new file mode 100644
index 0000000000..fefea0b97c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.cc
@@ -0,0 +1,353 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <algorithm>
+#include <string>
+#include <utility>
+#include "./securec.h"
+#include "minddata/dataset/util/task_manager.h"
+
+namespace mindspore {
+namespace dataset {
+// This takes the same parameters as the Task constructor.
+Status TaskManager::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg,
+                                    Task **task) {
+  // We need to block the destructor from coming in, otherwise we will deadlock. We grab the
+  // stateLock in shared mode, allowing CreateAsyncTask to run concurrently.
+  SharedLock stateLck(&state_lock_);
+  // Now double check the state
+  if (ServiceState() == STATE::kStopInProg || ServiceState() == STATE::kStopped) {
+    return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "TaskManager is shutting down");
+  }
+  RETURN_IF_NOT_OK(GetFreeTask(my_name, f, task));
+  if (vg == nullptr) {
+    RETURN_STATUS_UNEXPECTED("TaskGroup is null");
+  }
+  // Previously there was a timing hole where the thread is spawned but hits an error immediately,
+  // before we can set the TaskGroup pointer. We now do the set here before we call Run(). Run()
+  // will do the registration.
+  (*task)->set_task_group(vg);
+  // Link to the master lru list.
+  {
+    UniqueLock lck(&lru_lock_);
+    lru_.Append(*task);
+  }
+  // Link to the group list as well before we spawn.
+  {
+    UniqueLock lck(&vg->rw_lock_);
+    vg->grp_list_.Append(*task);
+  }
+  // Track all the TaskGroups. Used for Control-C handling.
+  {
+    LockGuard lck(&tg_lock_);
+    this->grp_list_.insert(vg);
+  }
+  RETURN_IF_NOT_OK((*task)->wp_.Register(vg));
+  RETURN_IF_NOT_OK((*task)->Run());
+  // Wait for the thread to initialize successfully.
+  RETURN_IF_NOT_OK((*task)->Wait());
+  return Status::OK();
+}
+
+Status TaskManager::join_all() {
+  Status rc;
+  Status rc2;
+  SharedLock lck(&lru_lock_);
+  for (Task &tk : lru_) {
+    rc = tk.Join();
+    if (rc.IsError()) {
+      rc2 = rc;
+    }
+  }
+  return rc2;
+}
+
+void TaskManager::interrupt_all() noexcept {
+  global_interrupt_ = 1;
+  LockGuard lck(&tg_lock_);
+  for (TaskGroup *vg : grp_list_) {
+    auto svc = vg->GetIntrpService();
+    if (svc) {
+      // Stop the interrupt service. No new request is accepted.
+      svc->ServiceStop();
+      svc->InterruptAll();
+    }
+  }
+  master_->Interrupt();
+}
+
+Task *TaskManager::FindMe() {
+#if !defined(_WIN32) && !defined(_WIN64)
+  return gMyTask;
+#else
+  TaskManager &tm = TaskManager::GetInstance();
+  SharedLock lock(&tm.lru_lock_);
+  auto id = this_thread::get_id();
+  auto tk = std::find_if(tm.lru_.begin(), tm.lru_.end(), [id](const Task &tk) { return tk.id_ == id; });
+  if (tk != tm.lru_.end()) {
+    return &(*tk);
+  }
+  // If we get here, either I am the watchdog or the master thread.
+  if (tm.master_->id_ == id) {
+    return tm.master_.get();
+  } else if (tm.watchdog_ != nullptr && tm.watchdog_->id_ == id) {
+    return tm.watchdog_;
+  }
+  MS_LOG(ERROR) << "Task not found.";
+  return nullptr;
+#endif
+}
+
+TaskManager::TaskManager() try : global_interrupt_(0),
+                                 lru_(&Task::node),
+                                 free_lst_(&Task::free),
+                                 watchdog_grp_(nullptr),
+                                 watchdog_(nullptr) {
+  auto alloc = Services::GetAllocator<Task>();
+  // Create a dummy Task for the master thread (this thread)
+  master_ = std::allocate_shared<Task>(alloc, "master", []() -> Status { return Status::OK(); });
+  master_->id_ = this_thread::get_id();
+  master_->running_ = true;
+  master_->is_master_ = true;
+#if !defined(_WIN32) && !defined(_WIN64)
+  gMyTask = master_.get();
+  // Initialize the semaphore for the watchdog
+  errno_t rc = sem_init(&sem_, 0, 0);
+  if (rc == -1) {
+    MS_LOG(ERROR) << "Unable to initialize a semaphore. Errno = " << rc << ".";
+    std::terminate();
+  }
+#endif
+} catch (const std::exception &e) {
+  MS_LOG(ERROR) << "MindData initialization failed: " << e.what() << ".";
+  std::terminate();
+}
+
+TaskManager::~TaskManager() {
+  if (watchdog_) {
+    WakeUpWatchDog();
+    watchdog_->Join();
+    // The watchdog_grp_ and watchdog_ pointers come from Services::GetInstance().GetServiceMemPool(),
+    // which we will free on shutdown, so there is no need to free these pointers one by one.
+    watchdog_grp_ = nullptr;
+    watchdog_ = nullptr;
+  }
+#if !defined(_WIN32) && !defined(_WIN64)
+  (void)sem_destroy(&sem_);
+#endif
+}
+
+Status TaskManager::DoServiceStart() {
+  MS_LOG(INFO) << "Starting Task Manager.";
+#if !defined(_WIN32) && !defined(_WIN64)
+  // Create a watchdog for Control-C
+  std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
+  // A dummy group just for the watchdog. We aren't really using it, but most code assumes a thread
+  // must belong to a group.
+  auto f = std::bind(&TaskManager::WatchDog, this);
+  Status rc;
+  watchdog_grp_ = new (&rc, mp) TaskGroup();
+  RETURN_IF_NOT_OK(rc);
+  rc = watchdog_grp_->CreateAsyncTask("Watchdog", f, &watchdog_);
+  if (rc.IsError()) {
+    ::operator delete(watchdog_grp_, mp);
+    watchdog_grp_ = nullptr;
+    return rc;
+  }
+  grp_list_.erase(watchdog_grp_);
+  lru_.Remove(watchdog_);
+#endif
+  return Status::OK();
+}
+
+Status TaskManager::DoServiceStop() {
+  WakeUpWatchDog();
+  interrupt_all();
+  return Status::OK();
+}
+
+Status TaskManager::WatchDog() {
+  TaskManager::FindMe()->Post();
+#if !defined(_WIN32) && !defined(_WIN64)
+  errno_t err = sem_wait(&sem_);
+  if (err == -1) {
+    RETURN_STATUS_UNEXPECTED("Errno = " + std::to_string(errno));
+  }
+  // We are woken up by Control-C and we are going to stop all threads that are running.
+  // In addition, we also want to prevent new threads from being created. This can be done
+  // easily by calling the parent function.
+  RETURN_IF_NOT_OK(ServiceStop());
+#endif
+  return Status::OK();
+}
+
+// Follow the group link and interrupt the other Tasks in the same group.
+// It is used by the watchdog only.
+void TaskManager::InterruptGroup(Task &curTk) {
+  TaskGroup *vg = curTk.MyTaskGroup();
+  vg->interrupt_all();
+}
+
+void TaskManager::InterruptMaster(const Status &rc) {
+  TaskManager &tm = TaskManager::GetInstance();
+  std::shared_ptr<Task> master = tm.master_;
+  std::lock_guard<std::mutex> lck(master->mux_);
+  master->Interrupt();
+  if (rc.IsError() && master->rc_.IsOk()) {
+    master->rc_ = rc;
+    master->caught_severe_exception_ = true;
+  }
+}
+
+Status TaskManager::GetMasterThreadRc() {
+  TaskManager &tm = TaskManager::GetInstance();
+  std::shared_ptr<Task> master = tm.master_;
+  Status rc = tm.master_->GetTaskErrorIfAny();
+  if (rc.IsError()) {
+    // Reset the state once we retrieve the value.
+    std::lock_guard<std::mutex> lck(master->mux_);
+    master->rc_ = Status::OK();
+    master->caught_severe_exception_ = false;
+    master->ResetIntrpState();
+  }
+  return rc;
+}
+
+void TaskManager::ReturnFreeTask(Task *p) noexcept {
+  // Take it out from lru_ if any
+  {
+    UniqueLock lck(&lru_lock_);
+    auto it = std::find(lru_.begin(), lru_.end(), *p);
+    if (it != lru_.end()) {
+      lru_.Remove(p);
+    }
+  }
+  // We need to deallocate the string resources associated with the Task class
+  // before we cache its memory for future use.
+  p->~Task();
+  // Put it back into the free list
+  {
+    LockGuard lck(&free_lock_);
+    free_lst_.Append(p);
+  }
+}
+
+Status TaskManager::GetFreeTask(const std::string &my_name, const std::function<Status()> &f, Task **p) {
+  if (p == nullptr) {
+    RETURN_STATUS_UNEXPECTED("p is null");
+  }
+  Task *q = nullptr;
+  // First try the free list
+  {
+    LockGuard lck(&free_lock_);
+    if (free_lst_.count > 0) {
+      q = free_lst_.head;
+      free_lst_.Remove(q);
+    }
+  }
+  if (q) {
+    new (q) Task(my_name, f);
+  } else {
+    std::shared_ptr<MemoryPool> mp = Services::GetInstance().GetServiceMemPool();
+    Status rc;
+    q = new (&rc, mp) Task(my_name, f);
+    RETURN_IF_NOT_OK(rc);
+  }
+  *p = q;
+  return Status::OK();
+}
+
+Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **ppTask) {
+  auto pMytask = TaskManager::FindMe();
+  // We need to block ~TaskGroup from coming in, otherwise we will deadlock. We grab the
+  // stateLock in shared mode, allowing CreateAsyncTask to run concurrently.
+  SharedLock state_lck(&state_lock_);
+  // Now double check the state
+  if (ServiceState() != STATE::kRunning) {
+    return Status(StatusCode::kInterrupted, __LINE__, __FILE__, "Taskgroup is shutting down");
+  }
+  TaskManager &dm = TaskManager::GetInstance();
+  Task *pTask = nullptr;
+  // If the group is already in error, exit early too.
+  // We can't hold the rc_mux_ throughout, because the thread spawned by CreateAsyncTask may hit an error,
+  // which will try to shut down the group and grab the rc_mux_, and we would deadlock.
+  {
+    std::unique_lock<std::mutex> rcLock(rc_mux_);
+    if (rc_.IsError()) {
+      return pMytask->IsMasterThread() ? rc_ : Status(StatusCode::kInterrupted);
+    }
+  }
+  RETURN_IF_NOT_OK(dm.CreateAsyncTask(my_name, f, this, &pTask));
+  if (ppTask) {
+    *ppTask = pTask;
+  }
+  return Status::OK();
+}
+
+void TaskGroup::interrupt_all() noexcept { intrp_svc_->InterruptAll(); }
+
+Status TaskGroup::join_all(Task::WaitFlag wf) {
+  Status rc;
+  Status rc2;
+  SharedLock lck(&rw_lock_);
+  for (Task &tk : grp_list_) {
+    rc = tk.Join(wf);
+    if (rc.IsError()) {
+      rc2 = rc;
+    }
+  }
+  return rc2;
+}
+
+Status TaskGroup::DoServiceStop() {
+  intrp_svc_->ServiceStop();
+  interrupt_all();
+  return (join_all(Task::WaitFlag::kNonBlocking));
+}
+
+TaskGroup::TaskGroup() : grp_list_(&Task::group), intrp_svc_(nullptr) {
+  auto alloc = Services::GetAllocator<IntrpService>();
+  intrp_svc_ = std::allocate_shared<IntrpService>(alloc);
+  (void)Service::ServiceStart();
+}
+
+TaskGroup::~TaskGroup() {
+  (void)Service::ServiceStop();
+  // The TaskGroup is going out of scope, and we can return the Task list to the free list.
+  Task *cur = grp_list_.head;
+  TaskManager &tm = TaskManager::GetInstance();
+  while (cur) {
+    Task *next = cur->group.next;
+    grp_list_.Remove(cur);
+    tm.ReturnFreeTask(cur);
+    cur = next;
+  }
+  {
+    LockGuard lck(&tm.tg_lock_);
+    (void)tm.grp_list_.erase(this);
+  }
+}
+
+Status TaskGroup::GetTaskErrorIfAny() {
+  SharedLock lck(&rw_lock_);
+  for (Task &tk : grp_list_) {
+    RETURN_IF_NOT_OK(tk.GetTaskErrorIfAny());
+  }
+  return Status::OK();
+}
+
+std::shared_ptr<IntrpService> TaskGroup::GetIntrpService() { return intrp_svc_; }
+}  // namespace dataset
+}  // namespace mindspore
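Note the startup handshake implied by TaskManager::CreateAsyncTask above: it spawns the thread and then blocks in Wait() until the entry functor calls TaskManager::FindMe()->Post() (the Watchdog task in this patch follows exactly this pattern). A hedged sketch of a caller, with the lambda body hypothetical:

    namespace ds = mindspore::dataset;

    // Sketch only; assumes the TaskManager service has been started.
    ds::Status SpawnWorker(ds::TaskGroup *vg) {
      ds::Task *t = nullptr;
      return vg->CreateAsyncTask("worker",
                                 []() -> ds::Status {
                                   ds::TaskManager::FindMe()->Post();  // release the creator blocked in Wait()
                                   // ... hypothetical work ...
                                   return ds::Status::OK();
                                 },
                                 &t);
    }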
diff --git a/mindspore/ccsrc/minddata/dataset/util/task_manager.h b/mindspore/ccsrc/minddata/dataset/util/task_manager.h
new file mode 100644
index 0000000000..3030390bab
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/task_manager.h
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_UTIL_TASK_MANAGER_H_
+#define DATASET_UTIL_TASK_MANAGER_H_
+
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <semaphore.h>
+#include <signal.h>  // for sig_atomic_t
+#endif
+#include <functional>
+#include <memory>
+#include <set>
+#include <string>
+#include <thread>
+#include "minddata/dataset/util/allocator.h"
+#include "minddata/dataset/util/intrp_service.h"
+#include "minddata/dataset/util/lock.h"
+#include "minddata/dataset/util/services.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/util/task.h"
+
+namespace mindspore {
+namespace dataset {
+namespace thread {
+using id = std::thread::id;
+}  // namespace thread
+
+namespace this_thread {
+inline thread::id get_id() { return std::this_thread::get_id(); }
+}  // namespace this_thread
+
+class TaskManager : public Service {
+ public:
+  friend class Services;
+
+  friend class TaskGroup;
+
+  ~TaskManager() override;
+
+  TaskManager(const TaskManager &) = delete;
+
+  TaskManager &operator=(const TaskManager &) = delete;
+
+  static TaskManager &GetInstance() noexcept { return Services::getTaskMgrInstance(); }
+
+  Status DoServiceStart() override;
+
+  Status DoServiceStop() override;
+
+  // A public global interrupt flag for signal handlers
+  volatile sig_atomic_t global_interrupt_;
+
+  // API
+  // This takes the same parameters as the Task constructor. Take a look at
+  // test-thread.cc for usage.
+  Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, Task **);
+
+  // Same usage as a boost thread group
+  Status join_all();
+
+  void interrupt_all() noexcept;
+
+  // Locate a particular Task.
+  static Task *FindMe();
+
+  static void InterruptGroup(Task &);
+
+  static Status GetMasterThreadRc();
+
+  static void InterruptMaster(const Status &rc = Status::OK());
+
+  static void WakeUpWatchDog() {
+#if !defined(_WIN32) && !defined(_WIN64)
+    TaskManager &tm = TaskManager::GetInstance();
+    (void)sem_post(&tm.sem_);
+#endif
+  }
+
+  void ReturnFreeTask(Task *p) noexcept;
+
+  Status GetFreeTask(const std::string &my_name, const std::function<Status()> &f, Task **p);
+
+  Status WatchDog();
+
+ private:
+  RWLock lru_lock_;
+  SpinLock free_lock_;
+  SpinLock tg_lock_;
+  std::shared_ptr<Task> master_;
+  List<Task> lru_;
+  List<Task> free_lst_;
+#if !defined(_WIN32) && !defined(_WIN64)
+  sem_t sem_;
+#endif
+  TaskGroup *watchdog_grp_;
+  std::set<TaskGroup *> grp_list_;
+  Task *watchdog_;
+
+  TaskManager();
+};
+
+// A group of related tasks.
+class TaskGroup : public Service {
+ public:
+  friend class Task;
+  friend class TaskManager;
+
+  Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **pTask = nullptr);
+
+  void interrupt_all() noexcept;
+
+  Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
+
+  int size() const noexcept { return grp_list_.count; }
+
+  Status DoServiceStart() override { return Status::OK(); }
+
+  Status DoServiceStop() override;
+
+  TaskGroup();
+
+  ~TaskGroup() override;
+
+  Status GetTaskErrorIfAny();
+
+  std::shared_ptr<IntrpService> GetIntrpService();
+
+ private:
+  Status rc_;
+  // Can't use rw_lock_ here, as it would lead to deadlock. Create another mutex to serialize access to rc_.
+  std::mutex rc_mux_;
+  RWLock rw_lock_;
+  List<Task> grp_list_;
+  std::shared_ptr<IntrpService> intrp_svc_;
+};
+
+namespace this_thread {
+inline bool is_interrupted() {
+  TaskManager &tm = TaskManager::GetInstance();
+  if (tm.global_interrupt_ == 1) {
+    return true;
+  }
+  Task *my_task = TaskManager::FindMe();
+  return my_task->Interrupted();
+}
+
+inline bool is_master_thread() {
+  Task *my_task = TaskManager::FindMe();
+  return my_task->IsMasterThread();
+}
+
+inline Status GetInterruptStatus() {
+  Task *my_task = TaskManager::FindMe();
+  return my_task->GetInterruptStatus();
+}
+}  // namespace this_thread
+
+#define RETURN_IF_INTERRUPTED()                                               \
+  do {                                                                        \
+    if (mindspore::dataset::this_thread::is_interrupted()) {                  \
+      return Task::OverrideInterruptRc(this_thread::GetInterruptStatus());    \
+    }                                                                         \
+  } while (false)
+
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_UTIL_TASK_MANAGER_H_
diff --git a/mindspore/ccsrc/dataset/util/treap.h b/mindspore/ccsrc/minddata/dataset/util/treap.h
similarity index 100%
rename from mindspore/ccsrc/dataset/util/treap.h
rename to mindspore/ccsrc/minddata/dataset/util/treap.h
diff --git a/mindspore/ccsrc/minddata/dataset/util/wait_post.cc b/mindspore/ccsrc/minddata/dataset/util/wait_post.cc
new file mode 100644
index 0000000000..944d9ca245
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/wait_post.cc
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/util/wait_post.h"
+#include "minddata/dataset/util/task_manager.h"
+
+namespace mindspore {
+namespace dataset {
+WaitPost::WaitPost() : value_(0) {}
+
+Status WaitPost::Wait() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  return (wait_cond_.Wait(&lck, [this]() { return value_ != 0; }));
+}
+
+void WaitPost::Set() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  value_ = 1;
+  wait_cond_.NotifyAll();
+}
+
+void WaitPost::Clear() {
+  std::unique_lock<std::mutex> lck(mutex_);
+  value_ = 0;
+}
+
+Status WaitPost::Register(TaskGroup *vg) { return wait_cond_.Register(vg->GetIntrpService()); }
+
+void WaitPost::ResetIntrpState() { wait_cond_.ResetIntrpState(); }
+
+Status WaitPost::Deregister() { return wait_cond_.Deregister(); }
+}  // namespace dataset
+}  // namespace mindspore
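WaitPost is the one-shot gate used for that startup handshake: Wait() blocks until some other thread calls Set(), and Clear() re-arms the gate for the next round. A hedged sketch of the rendezvous (inside the dataset code the waiting side is always a registered Task, which keeps Wait() interruptible):

    mindspore::dataset::WaitPost wp;
    // waiter:  RETURN_IF_NOT_OK(wp.Wait());  // blocks while value_ == 0
    // poster:  wp.Set();                     // wakes every waiter via NotifyAll()
    //          wp.Clear();                   // re-arm before reusing the gate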
diff --git a/mindspore/ccsrc/minddata/dataset/util/wait_post.h b/mindspore/ccsrc/minddata/dataset/util/wait_post.h
new file mode 100644
index 0000000000..afd3bea38b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/util/wait_post.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_UTIL_WAIT_POST_H_
+#define DATASET_UTIL_WAIT_POST_H_
+
+#include <mutex>
+#include "minddata/dataset/util/cond_var.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class TaskGroup;
+
+class WaitPost {
+ public:
+  WaitPost();
+
+  ~WaitPost() = default;
+
+  Status Wait();
+
+  void Set();
+
+  void Clear();
+
+  Status Register(TaskGroup *vg);
+
+  Status Deregister();
+
+  void ResetIntrpState();
+
+ private:
+  std::mutex mutex_;
+  CondVar wait_cond_;
+  int value_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_UTIL_WAIT_POST_H_
diff --git a/mindspore/ccsrc/mindrecord/CMakeLists.txt b/mindspore/ccsrc/minddata/mindrecord/CMakeLists.txt
similarity index 100%
rename from mindspore/ccsrc/mindrecord/CMakeLists.txt
rename to mindspore/ccsrc/minddata/mindrecord/CMakeLists.txt
diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc
new file mode 100644
index 0000000000..e4d35b8305
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc
@@ -0,0 +1,181 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "minddata/mindrecord/include/shard_error.h" + +namespace mindspore { +namespace mindrecord { +std::string ErrnoToMessage(MSRStatus status) { + switch (status) { + case FAILED: + return "operator failed"; + break; + case SUCCESS: + return "operator success"; + break; + case OPEN_FILE_FAILED: + return "open file failed"; + break; + case CLOSE_FILE_FAILED: + return "close file failed"; + break; + case WRITE_METADATA_FAILED: + return "write metadata failed"; + break; + case WRITE_RAWDATA_FAILED: + return "write rawdata failed"; + break; + case GET_SCHEMA_FAILED: + return "get schema failed"; + break; + case ILLEGAL_RAWDATA: + return "illegal raw data"; + break; + case PYTHON_TO_JSON_FAILED: + return "pybind: python object to json failed"; + break; + case DIR_CREATE_FAILED: + return "directory create failed"; + break; + case OPEN_DIR_FAILED: + return "open directory failed"; + break; + case INVALID_STATISTICS: + return "invalid statistics object"; + break; + case OPEN_DATABASE_FAILED: + return "open database failed"; + break; + case CLOSE_DATABASE_FAILED: + return "close database failed"; + break; + case DATABASE_OPERATE_FAILED: + return "database operate failed"; + break; + case BUILD_SCHEMA_FAILED: + return "build schema failed"; + break; + case DIVISOR_IS_ILLEGAL: + return "divisor is illegal"; + break; + case INVALID_FILE_PATH: + return "file path is invalid"; + break; + case SECURE_FUNC_FAILED: + return "secure function failed"; + break; + case ALLOCATE_MEM_FAILED: + return "allocate memory failed"; + break; + case ILLEGAL_FIELD_NAME: + return "illegal field name"; + break; + case ILLEGAL_FIELD_TYPE: + return "illegal field type"; + break; + case SET_METADATA_FAILED: + return "set metadata failed"; + break; + case ILLEGAL_SCHEMA_DEFINITION: + return "illegal schema definition"; + break; + case ILLEGAL_COLUMN_LIST: + return "illegal column list"; + break; + case SQL_ERROR: + return "sql error"; + break; + case ILLEGAL_SHARD_COUNT: + return "illegal shard count"; + break; + case ILLEGAL_SCHEMA_COUNT: + return "illegal schema count"; + break; + case VERSION_ERROR: + return "data version is not matched"; + break; + case ADD_SCHEMA_FAILED: + return "add schema failed"; + break; + case ILLEGAL_Header_SIZE: + return "illegal header size"; + break; + case ILLEGAL_Page_SIZE: + return "illegal page size"; + break; + case ILLEGAL_SIZE_VALUE: + return "illegal size value"; + break; + case INDEX_FIELD_ERROR: + return "add index fields failed"; + break; + case GET_CANDIDATE_CATEGORYFIELDS_FAILED: + return "get candidate category fields failed"; + break; + case GET_CATEGORY_INFO_FAILED: + return "get category information failed"; + break; + case ILLEGAL_CATEGORY_ID: + return "illegal category id"; + break; + case ILLEGAL_ROWNUMBER_OF_PAGE: + return "illegal row number of page"; + break; + case ILLEGAL_SCHEMA_ID: + return "illegal schema id"; + break; + case DESERIALIZE_SCHEMA_FAILED: + return "deserialize schema failed"; + break; + case DESERIALIZE_STATISTICS_FAILED: + return "deserialize statistics failed"; + break; + case ILLEGAL_DB_FILE: + return "illegal db file"; + break; + case OVERWRITE_DB_FILE: + return "overwrite db file"; + break; + case OVERWRITE_MINDRECORD_FILE: + return "overwrite mindrecord file"; + break; + case ILLEGAL_MINDRECORD_FILE: + return "illegal mindrecord file"; + break; + case PARSE_JSON_FAILED: + return "parse json failed"; + break; + case ILLEGAL_PARAMETERS: + return "illegal parameters"; + break; + case GET_PAGE_BY_GROUP_ID_FAILED: + return "get page by group 
id failed"; + break; + case GET_SYSTEM_STATE_FAILED: + return "get system state failed"; + break; + case IO_FAILED: + return "io operate failed"; + break; + case MATCH_HEADER_FAILED: + return "match header failed"; + break; + default: + return "invalid error no"; + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc new file mode 100644 index 0000000000..d9e51efc4e --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -0,0 +1,230 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_segment.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "nlohmann/json.hpp" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "utils/log_adapter.h" + +namespace py = pybind11; + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +void BindSchema(py::module *m) { + (void)py::class_>(*m, "Schema", py::module_local()) + .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) + .def("get_desc", &Schema::GetDesc) + .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) + .def("get_blob_fields", &Schema::GetBlobFields) + .def("get_schema_id", &Schema::GetSchemaID); +} + +void BindStatistics(const py::module *m) { + (void)py::class_>(*m, "Statistics", py::module_local()) + .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) + .def("get_desc", &Statistics::GetDesc) + .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) + .def("get_statistics_id", &Statistics::GetStatisticsID); +} + +void BindShardHeader(const py::module *m) { + (void)py::class_>(*m, "ShardHeader", py::module_local()) + .def(py::init<>()) + .def("add_schema", &ShardHeader::AddSchema) + .def("add_statistics", &ShardHeader::AddStatistic) + .def("add_index_fields", + (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) + .def("get_meta", &ShardHeader::GetSchemas) + .def("get_statistics", &ShardHeader::GetStatistics) + .def("get_fields", &ShardHeader::GetFields) + .def("get_schema_by_id", &ShardHeader::GetSchemaByID) + .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); +} + +void BindShardWriter(py::module *m) { + (void)py::class_(*m, "ShardWriter", py::module_local()) + .def(py::init<>()) + .def("open", &ShardWriter::Open) + .def("open_for_append", &ShardWriter::OpenForAppend) + 
.def("set_header_size", &ShardWriter::SetHeaderSize) + .def("set_page_size", &ShardWriter::SetPageSize) + .def("set_shard_header", &ShardWriter::SetShardHeader) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) + .def("commit", &ShardWriter::Commit); +} + +void BindShardReader(const py::module *m) { + (void)py::class_>(*m, "ShardReader", py::module_local()) + .def(py::init<>()) + .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, + const std::vector &, + const std::vector> &)) & + ShardReader::OpenPy) + .def("launch", &ShardReader::Launch) + .def("get_header", &ShardReader::GetShardHeader) + .def("get_blob_fields", &ShardReader::GetBlobFields) + .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & + ShardReader::GetNextPy) + .def("finish", &ShardReader::Finish) + .def("close", &ShardReader::Close); +} + +void BindShardIndexGenerator(const py::module *m) { + (void)py::class_(*m, "ShardIndexGenerator", py::module_local()) + .def(py::init()) + .def("build", &ShardIndexGenerator::Build) + .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); +} + +void BindShardSegment(py::module *m) { + (void)py::class_(*m, "ShardSegment", py::module_local()) + .def(py::init<>()) + .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, + const std::vector &, + const std::vector> &)) & + ShardSegment::OpenPy) + .def("get_category_fields", + (std::pair>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) + .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) + .def("read_category_info", (std::pair(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) + .def("read_at_page_by_id", (std::pair, pybind11::object>>>( + ShardSegment::*)(int64_t, int64_t, int64_t)) & + ShardSegment::ReadAtPageByIdPy) + .def("read_at_page_by_name", (std::pair, pybind11::object>>>( + ShardSegment::*)(std::string, int64_t, int64_t)) & + ShardSegment::ReadAtPageByNamePy) + .def("get_header", &ShardSegment::GetShardHeader) + .def("get_blob_fields", + (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); +} + +void BindGlobalParams(py::module *m) { + (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize; + (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize; + (*m).attr("MIN_PAGE_SIZE") = kMinPageSize; + (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; + (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; + (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; + (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; + (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); +} + +PYBIND11_MODULE(_c_mindrecord, m) { + m.doc() = "pybind11 mindrecord plugin"; // optional module docstring + (void)py::enum_(m, "MSRStatus", py::module_local()) + .value("SUCCESS", SUCCESS) + .value("FAILED", FAILED) + .export_values(); + (void)py::enum_(m, "ShardType", py::module_local()).value("NLP", kNLP).value("CV", kCV).export_values(); + BindGlobalParams(&m); + BindSchema(&m); + BindStatistics(&m); + BindShardHeader(&m); + BindShardWriter(&m); + BindShardReader(&m); + BindShardIndexGenerator(&m); + BindShardSegment(&m); +} +} // namespace mindrecord +} // namespace mindspore + +namespace nlohmann { +namespace detail { +py::object FromJsonImpl(const json &j) { + if (j.is_null()) { + return py::none(); + } else if (j.is_boolean()) { + return py::bool_(j.get()); + } else if (j.is_number()) { + double number = j.get(); + if (fabs(number - std::floor(number)) < 
+
+namespace nlohmann {
+namespace detail {
+py::object FromJsonImpl(const json &j) {
+  if (j.is_null()) {
+    return py::none();
+  } else if (j.is_boolean()) {
+    return py::bool_(j.get<bool>());
+  } else if (j.is_number()) {
+    double number = j.get<double>();
+    if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) {
+      return py::int_(j.get<int64_t>());
+    } else {
+      return py::float_(number);
+    }
+  } else if (j.is_string()) {
+    return py::str(j.get<std::string>());
+  } else if (j.is_array()) {
+    py::list obj;
+    for (const auto &el : j) {
+      (void)obj.attr("append")(FromJsonImpl(el));
+    }
+    return std::move(obj);
+  } else {
+    py::dict obj;
+    for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) {
+      obj[py::str(it.key())] = FromJsonImpl(it.value());
+    }
+    return std::move(obj);
+  }
+}
+
+json ToJsonImpl(const py::handle &obj) {
+  if (obj.is_none()) {
+    return nullptr;
+  }
+  if (py::isinstance<py::bool_>(obj)) {
+    return obj.cast<bool>();
+  }
+  if (py::isinstance<py::int_>(obj)) {
+    return obj.cast<int64_t>();
+  }
+  if (py::isinstance<py::float_>(obj)) {
+    return obj.cast<double>();
+  }
+  if (py::isinstance<py::str>(obj)) {
+    return obj.cast<std::string>();
+  }
+  if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) {
+    auto out = json::array();
+    for (const py::handle &value : obj) {
+      out.push_back(ToJsonImpl(value));
+    }
+    return out;
+  }
+  if (py::isinstance<py::dict>(obj)) {
+    auto out = json::object();
+    for (const py::handle &key : obj) {
+      out[py::str(key).cast<std::string>()] = ToJsonImpl(obj[key]);
+    }
+    return out;
+  }
+  MS_LOG(ERROR) << "Python to json failed, obj is: " << py::cast<std::string>(obj);
+  return json();
+}
+}  // namespace detail
+
+py::object adl_serializer<py::object>::FromJson(const json &j) { return detail::FromJsonImpl(j); }
+
+void adl_serializer<py::object>::ToJson(json *j, const py::object &obj) { *j = detail::ToJsonImpl(obj); }
+}  // namespace nlohmann
diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc
new file mode 100644
index 0000000000..b5021802a0
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc
@@ -0,0 +1,204 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "common/utils.h"
+#include "./securec.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::DEBUG;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+// split a string using a character
+std::vector<std::string> StringSplit(const std::string &field, char separator) {
+  std::vector<std::string> res;
+  uint64_t s_pos = 0;
+  while (s_pos < field.length()) {
+    size_t e_pos = field.find_first_of(separator, s_pos);
+    if (e_pos != std::string::npos) {
+      res.push_back(field.substr(s_pos, e_pos - s_pos));
+    } else {
+      res.push_back(field.substr(s_pos, field.length() - s_pos));
+      break;
+    }
+    s_pos = e_pos + 1;
+  }
+  return res;
+}
+
+bool ValidateFieldName(const std::string &str) {
+  std::string::const_iterator it = str.begin();
+  if (it == str.end()) {
+    return false;
+  }
+  for (; it != str.end(); ++it) {
+    if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) ||
+        ((*it >= 'a') && (*it <= 'z'))) {
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
+std::pair<MSRStatus, std::string> GetFileName(const std::string &path) {
+  char real_path[PATH_MAX] = {0};
+  char buf[PATH_MAX] = {0};
+  if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
+    MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
+    return {FAILED, ""};
+  }
+  char tmp[PATH_MAX] = {0};
+#if defined(_WIN32) || defined(_WIN64)
+  if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
+    MS_LOG(ERROR) << "Invalid file path, path: " << buf;
+    return {FAILED, ""};
+  }
+  if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
+    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << " checked successfully";
+  }
+#else
+  if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
+    MS_LOG(ERROR) << "Invalid file path, path: " << buf;
+    return {FAILED, ""};
+  }
+  if (realpath(common::SafeCStr(path), real_path) == nullptr) {
+    MS_LOG(DEBUG) << "Path: " << path << " checked successfully";
+  }
+#endif
+  std::string s = real_path;
+  char sep = '/';
+  size_t i = s.rfind(sep, s.length());
+  if (i != std::string::npos) {
+    if (i + 1 < s.size()) {
+      return {SUCCESS, s.substr(i + 1)};
+    }
+  }
+  return {SUCCESS, s};
+}
+
+std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) {
+  char real_path[PATH_MAX] = {0};
+  char buf[PATH_MAX] = {0};
+  if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
+    MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path;
+    return {FAILED, ""};
+  }
+  char tmp[PATH_MAX] = {0};
+#if defined(_WIN32) || defined(_WIN64)
+  if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) {
+    MS_LOG(ERROR) << "Invalid file path, path: " << buf;
+    return {FAILED, ""};
+  }
+  if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
+    MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << " checked successfully";
+  }
+#else
+  if (realpath(dirname(&(buf[0])), tmp) == nullptr) {
+    MS_LOG(ERROR) << "Invalid file path, path: " << buf;
+    return {FAILED, ""};
+  }
+  if (realpath(common::SafeCStr(path), real_path) == nullptr) {
+    MS_LOG(DEBUG) << "Path: " << path << " checked successfully";
+  }
+#endif
+  std::string s = real_path;
+  if (s.rfind('/') + 1 <= s.size()) {
+    return {SUCCESS, s.substr(0, s.rfind('/') + 1)};
+  }
+  return {SUCCESS, "/"};
+}
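For orientation, what the string and path helpers above produce (the paths are made up; GetParentDir keeps the trailing slash because of the substr(0, rfind('/') + 1) above):

    // StringSplit("a/b/c.mindrecord", '/')        -> {"a", "b", "c.mindrecord"}
    // GetFileName("/data/out/demo.mindrecord0")   -> {SUCCESS, "demo.mindrecord0"}
    // GetParentDir("/data/out/demo.mindrecord0")  -> {SUCCESS, "/data/out/"}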
<= 0x7f) { + n = 0; + } else if ((c & 0xE0) == 0xC0) { + n = 1; + } else if (c == 0xed && i < (ix - 1) && (static_cast(str[i + 1]) & 0xa0) == 0xa0) { + return false; + } else if ((c & 0xF0) == 0xE0) { + n = 2; + } else if ((c & 0xF8) == 0xF0) { + n = 3; + } else { + return false; + } + for (int j = 0; j < n && i < ix; ++j) { + if ((++i == ix) || ((static_cast(str[i]) & 0xC0) != 0x80)) { + return false; + } + } + } + return true; +} + +bool IsLegalFile(const std::string &path) { + struct stat s; + if (stat(common::SafeCStr(path), &s) == 0) { + if (s.st_mode & S_IFDIR) { + return false; + } + return true; + } + return false; +} + +std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { +#if defined(_WIN32) || defined(_WIN64) + return {SUCCESS, 100}; +#else + uint64_t ll_count = 0; + struct statfs disk_info; + if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { + MS_LOG(ERROR) << "Get disk size error"; + return {FAILED, 0}; + } + + switch (disk_type) { + case kTotalSize: + ll_count = disk_info.f_bsize * disk_info.f_blocks; + ll_count = ll_count >> 20; + break; + case kFreeSize: + ll_count = disk_info.f_bsize * disk_info.f_bavail; + ll_count = ll_count >> 20; + break; + default: + ll_count = 0; + break; + } + + return {SUCCESS, ll_count}; +#endif +} + +uint32_t GetMaxThreadNum() { + // define the number of thread + uint32_t thread_num = std::thread::hardware_concurrency(); + if (thread_num == 0) { + thread_num = kMaxConsumerCount; + } + return thread_num; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h new file mode 100644 index 0000000000..3b3698ca68 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_pybind.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_
+#define MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_
+
+#include <string>
+#include <vector>
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "pybind11/pybind11.h"
+
+namespace py = pybind11;
+namespace nlohmann {
+template <>
+struct adl_serializer<py::object> {
+  py::object FromJson(const json &j);
+
+  void ToJson(json *j, const py::object &obj);
+};
+
+namespace detail {
+py::object FromJsonImpl(const json &j);
+
+json ToJsonImpl(const py::handle &obj);
+}  // namespace detail
+}  // namespace nlohmann
+#endif  // MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
new file mode 100644
index 0000000000..bd1cda8a99
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
@@ -0,0 +1,182 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_
+#define MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_
+
+#include <dirent.h>
+#include <libgen.h>
+#include <limits.h>
+#include <sys/stat.h>
+#if !defined(_WIN32) && !defined(_WIN64)
+#include <sys/statfs.h>
+#include <unistd.h>
+#endif
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <set>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+#include "minddata/mindrecord/include/shard_error.h"
+#include "nlohmann/json.hpp"
+#include "./sqlite3.h"
+#include "utils/log_adapter.h"
+
+/* To be used when dlog is ok #include "./slog.h" */
+#ifdef DEBUG
+#define MS_ASSERT(f) assert(f)
+#else
+#define MS_ASSERT(f) ((void)0)
+#endif
+
+namespace mindspore {
+namespace mindrecord {
+using json = nlohmann::json;
+
+const int kInt0 = 0;
+const int kInt1 = 1;
+const int kInt2 = 2;
+const int kInt3 = 3;
+const int kUnsignedInt4 = 4;
+
+enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };
+
+const char kVersion[] = "3.0";
+const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};
+
+enum ShardType {
+  kNLP = 0,
+  kCV = 1,
+};
+
+enum TaskType {
+  kCommonTask = 0,
+  kPaddedTask = 1,
+};
+enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler };
+
+enum ShuffleType { kShuffleCategory, kShuffleSample };
+
+const double kEpsilon = 1e-7;
+
+const int kThreadNumber = 14;
+
+// Shard default parameters
+const uint64_t kDefaultHeaderSize = 1 << 24;  // 16MB
+const uint64_t kDefaultPageSize = 1 << 25;    // 32MB
+
+// HeaderSize [16KB, 128MB]
+const int kMinHeaderSize = 1 << 14;  // 16KB
+const int kMaxHeaderSize = 1 << 27;  // 128MB
+
+// PageSize [32KB, 256MB]
+const int kMinPageSize = 1 << 15;  // 32KB
+const int kMaxPageSize = 1 << 28;  // 256MB
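These bounds gate what a writer may configure. A hedged sketch of picking sizes (assuming, as the ShardWriter bindings earlier in this patch suggest, that SetHeaderSize/SetPageSize reject values outside the documented ranges):

    mindspore::mindrecord::ShardWriter writer;
    // Defaults restated explicitly; both values sit inside the documented ranges.
    (void)writer.SetHeaderSize(1 << 24);  // 16MB, within [kMinHeaderSize, kMaxHeaderSize]
    (void)writer.SetPageSize(1 << 25);    // 32MB, within [kMinPageSize, kMaxPageSize]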
+
+// used by value length / schema id length / statistic id length ...
+const uint64_t kInt64Len = 8;
+
+// Minimum file size
+const uint64_t kMinFileSize = kInt64Len;
+
+const int kMinShardCount = 1;
+const int kMaxShardCount = 1000;
+
+const int kMinConsumerCount = 1;
+const int kMaxConsumerCount = 128;
+
+const int kMaxSchemaCount = 1;
+const int kMaxThreadCount = 32;
+const int kMaxFieldCount = 100;
+
+// Minimum free disk size
+const int kMinFreeDiskSize = 10;  // 10M
+
+// dummy json
+const json kDummyId = R"({"id": 0})"_json;
+
+// translate type in schema to type in sqlite3(NULL, INTEGER, REAL, TEXT, BLOB)
+const std::unordered_map<std::string, std::string> kDbJsonMap = {
+  {"string", "TEXT"},     {"date", "DATE"},       {"date-time", "DATETIME"}, {"null", "NULL"},
+  {"integer", "INTEGER"}, {"boolean", "BOOLEAN"}, {"array", "BLOB"},         {"number", "NUMERIC"},
+  {"int32", "INTEGER"},   {"int64", "INTEGER"},   {"float32", "NUMERIC"},    {"float64", "NUMERIC"},
+  {"bytes", "BLOB"}};
+
+const char kPoint = '.';
+
+// field types used by schema validation
+const std::set<std::string> kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"};
+
+// searchable field types
+const std::set<std::string> kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"};
+
+// number field types
+const std::set<std::string> kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"};
+
+/// \brief split a string using a character
+/// \param[in] field target string
+/// \param[in] separator a character for splitting
+/// \return the split substrings
+std::vector<std::string> StringSplit(const std::string &field, char separator);
+
+/// \brief validate that a field name is composed of '0-9', 'a-z', 'A-Z' or '_'
+/// \param[in] str target string
+/// \return true if the field name is valid
+bool ValidateFieldName(const std::string &str);
+
+/// \brief get the filename from a path
+/// \param s file path
+/// \return pair of status and the file name
+std::pair<MSRStatus, std::string> GetFileName(const std::string &s);
+
+/// \brief get the parent dir
+/// \param path file path
+/// \return pair of status and the parent path
+std::pair<MSRStatus, std::string> GetParentDir(const std::string &path);
+
+bool CheckIsValidUtf8(const std::string &str);
+
+/// \brief judge if a path refers to a legal (regular) file
+/// \param path file path
+/// \return true if the path is a regular file
+bool IsLegalFile(const std::string &path);
+
+enum DiskSizeType { kTotalSize = 0, kFreeSize };
+
+/// \brief get the total or free space of the disk holding str_dir
+/// \param str_dir file path
+/// \param disk_type kTotalSize / kFreeSize
+/// \return size in Megabytes
+std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type);
+
+/// \brief get the max hardware concurrency
+/// \return max concurrency
+uint32_t GetMaxThreadNum();
+}  // namespace mindrecord
+}  // namespace mindspore
+
+#endif  // MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h
new file mode 100644
index 0000000000..ed1e748afe
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
+#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "minddata/mindrecord/include/shard_operator.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardCategory : public ShardOperator {
+ public:
+  explicit ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories,
+                         int64_t num_elements = std::numeric_limits<int64_t>::max(), bool replacement = false);
+
+  ShardCategory(const std::string &category_field, int64_t num_elements,
+                int64_t num_categories = std::numeric_limits<int64_t>::max(), bool replacement = false);
+
+  ~ShardCategory() override{};
+
+  const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; }
+
+  const std::string GetCategoryField() const { return category_field_; }
+
+  int64_t GetNumElements() const { return num_elements_; }
+
+  int64_t GetNumCategories() const { return num_categories_; }
+
+  bool GetReplacement() const { return replacement_; }
+
+  MSRStatus Execute(ShardTask &tasks) override;
+
+  int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
+ private:
+  std::vector<std::pair<std::string, std::string>> categories_;
+  std::string category_field_;
+  int64_t num_elements_;
+  int64_t num_categories_;
+  bool replacement_;
+};
+}  // namespace mindrecord
+}  // namespace mindspore
+
+#endif  // MINDRECORD_INCLUDE_SHARD_CATEGORY_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
new file mode 100644
index 0000000000..f6353ed3ce
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h
@@ -0,0 +1,167 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
+#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "minddata/mindrecord/include/shard_header.h"
+
+namespace mindspore {
+namespace mindrecord {
+const uint64_t kUnsignedOne = 1;
+const uint64_t kBitsOfByte = 8;
+const uint64_t kDataTypeBits = 2;
+const uint64_t kNumDataOfByte = 4;
+const uint64_t kBytesOfColumnLen = 4;
+const uint64_t kDataTypeBitMask = 3;
+const uint64_t kDataTypes = 6;
+
+enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };
+
+enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };
+
+enum ColumnDataType {
+  ColumnBytes = 0,
+  ColumnString = 1,
+  ColumnInt32 = 2,
+  ColumnInt64 = 3,
+  ColumnFloat32 = 4,
+  ColumnFloat64 = 5,
+  ColumnNoDataType = 6
+};
+
+// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
+const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
+
+const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "string",  "int32",
+                                                               "int64", "float32", "float64"};
+
+const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
+  {"bytes", ColumnBytes}, {"string", ColumnString},   {"int32", ColumnInt32},
+  {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};
+
+class ShardColumn {
+ public:
+  explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
+
+  ~ShardColumn() = default;
+
+  /// \brief get column value by column name
+  MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
+                                 const json &columns_json, const unsigned char **data,
+                                 std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
+                                 ColumnDataType *column_data_type, uint64_t *column_data_type_size,
+                                 std::vector<int64_t> *column_shape);
+
+  /// \brief compress blob
+  std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
+
+  /// \brief check if the blob is compressed
+  bool CheckCompressBlob() const { return has_compress_blob_; }
+
+  uint64_t GetNumBlobColumn() const { return num_blob_column_; }
+
+  std::vector<std::string> GetColumnName() { return column_name_; }
+
+  std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
+
+  std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
+
+  /// \brief get column value from blob
+  MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
+                              const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
+                              uint64_t *const n_bytes);
+
+  std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
+                                                           ColumnDataType *column_data_type,
+                                                           uint64_t *column_data_type_size,
+                                                           std::vector<int64_t> *column_shape);
+
+  /// \brief get column value from json
+  MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
+                              std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
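A hedged sketch of reading one column through the accessor declared above (the column name is a placeholder, and shard_column, columns_blob and columns_json are assumed to come from an already-loaded shard; error handling elided):

    const unsigned char *data = nullptr;
    std::unique_ptr<unsigned char[]> data_ptr;
    uint64_t n_bytes = 0;
    ColumnDataType type = ColumnNoDataType;
    uint64_t type_size = 0;
    std::vector<int64_t> shape;
    MSRStatus ret = shard_column.GetColumnValueByName("label", columns_blob, columns_json, &data, &data_ptr,
                                                      &n_bytes, &type, &type_size, &shape);
    // On SUCCESS, data points at n_bytes of raw column memory described by type/type_size/shape.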
+
+ private:
+  /// \brief get float value from json
+  template <typename T>
+  MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
+
+  /// \brief get integer value from json
+  template <typename T>
+  MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
+
+  /// \brief get column offset address and size from blob
+  MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
+                                    uint64_t *num_bytes, uint64_t *shift_idx);
+
+  /// \brief check if column name is available
+  ColumnCategory CheckColumnName(const std::string &column_name);
+
+  /// \brief compress integer column
+  static std::vector<uint8_t> CompressInt(const std::vector<uint8_t> &src_bytes, const IntegerType &int_type);
+
+  /// \brief uncompress integer array column
+  template <typename T>
+  static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
+                                 const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
+
+  /// \brief convert big-endian bytes to unsigned int
+  /// \param bytes_array bytes array
+  /// \param pos shift address in bytes array
+  /// \param i_type integer type
+  /// \return unsigned int
+  static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
+                                   const IntegerType &i_type);
+
+  /// \brief convert unsigned int to big-endian bytes
+  /// \param value integer value
+  /// \param i_type integer type
+  /// \return bytes
+  static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);
+
+  /// \brief convert unsigned int to little-endian bytes
+  /// \param value integer value
+  /// \param i_type integer type
+  /// \return bytes
+  static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);
+
+  /// \brief convert little-endian bytes to the minimal integer type that can hold the value
+  /// \param bytes_array bytes array
+  /// \param pos shift address in bytes array
+  /// \param src_i_type source integer type
+  /// \param dst_i_type (output) destination integer type
+  /// \return integer
+  static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
+                                         const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);
+
+ private:
+  std::vector<std::string> column_name_;                      // column name list
+  std::vector<ColumnDataType> column_data_type_;              // column data type list
+  std::vector<std::vector<int64_t>> column_shape_;            // column shape list
+  std::unordered_map<std::string, int32_t> column_name_id_;   // column name id map
+  std::vector<std::string> blob_column_;                      // blob column list
+  std::unordered_map<std::string, uint64_t> blob_column_id_;  // blob column name id map
+  bool has_compress_blob_;                                    // whether the blob is compressed
+  uint64_t num_blob_column_;                                  // number of blob columns
+};
+}  // namespace mindrecord
+}  // namespace mindspore
+
+#endif  // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
new file mode 100644
index 0000000000..f166ec1e6c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
new file mode 100644
index 0000000000..f166ec1e6c
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_
+#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_
+
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/shard_operator.h"
+#include "minddata/mindrecord/include/shard_shuffle.h"
+#include "minddata/mindrecord/include/shard_sample.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardDistributedSample : public ShardSample {
+ public:
+ ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
+
+ ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
+
+ void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }
+
+ ~ShardDistributedSample() override{};
+
+ MSRStatus PreExecute(ShardTask &tasks) override;
+
+ int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
+ private:
+ bool shuffle_;
+ int no_of_padded_samples_;
+ bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch
+ ShardTask task_; // maintain the input tasks in first epoch
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_
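A short sketch of the arithmetic that the first_epoch_ flag above guards: with padding enabled, (num_sample + num_padded) must divide evenly by num_shards so that every shard owns the same number of rows. RowsPerShard is a hypothetical helper, not part of the class.

#include <cstdint>
#include <iostream>

// With padding, every shard owns an equal slice of the task list.
int64_t RowsPerShard(int64_t num_rows, int64_t num_padded, int64_t num_shards) {
  // The first-epoch check requires this division to be exact.
  return (num_rows + num_padded) / num_shards;
}

int main() {
  // 10 real rows plus 2 padded rows across 4 shards -> 3 rows per shard.
  std::cout << RowsPerShard(10, 2, 4) << '\n';
}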
diff --git a/mindspore/ccsrc/mindrecord/include/shard_error.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h
similarity index 100%
rename from mindspore/ccsrc/mindrecord/include/shard_error.h
rename to mindspore/ccsrc/minddata/mindrecord/include/shard_error.h
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
new file mode 100644
index 0000000000..67169e8696
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h
@@ -0,0 +1,186 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_HEADER_H_
+#define MINDRECORD_INCLUDE_SHARD_HEADER_H_
+
+#include
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "minddata/mindrecord/include/shard_index.h"
+#include "minddata/mindrecord/include/shard_page.h"
+#include "minddata/mindrecord/include/shard_schema.h"
+#include "minddata/mindrecord/include/shard_statistics.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardHeader {
+ public:
+ ShardHeader();
+
+ ~ShardHeader() = default;
+
+ MSRStatus BuildDataset(const std::vector &file_paths, bool load_dataset = true);
+
+ static std::pair BuildSingleHeader(const std::string &file_path);
+ /// \brief add the schema and save it
+ /// \param[in] schema the schema to be added
+ /// \return the last schema's id
+ int AddSchema(std::shared_ptr schema);
+
+ /// \brief add the statistic and save it
+ /// \param[in] statistic the statistic to be added
+ void AddStatistic(std::shared_ptr statistic);
+
+ /// \brief create an index and add the index fields from each schema
+ /// \param[in] fields the index fields to be added
+ /// \return SUCCESS if added successfully, FAILED otherwise
+ MSRStatus AddIndexFields(std::vector> fields);
+
+ MSRStatus AddIndexFields(const std::vector &fields);
+
+ /// \brief get the schemas
+ /// \return the schemas
+ std::vector> GetSchemas();
+
+ /// \brief get the statistics
+ /// \return the statistics
+ std::vector> GetStatistics();
+
+ /// \brief get the fields of the index
+ /// \return the fields of the index
+ std::vector> GetFields();
+
+ /// \brief get the index
+ /// \return the index
+ std::shared_ptr GetIndex();
+
+ /// \brief get the schema by schema id
+ /// \param[in] schema_id the id of the schema to be obtained
+ /// \return the schema obtained by schema_id
+ std::pair, MSRStatus> GetSchemaByID(int64_t schema_id);
+
+ /// \brief get the file path of a shard by shard id
+ /// \param[in] shard_id the id of the shard whose file path is to be obtained
+ /// \return the file path obtained by shard_id
+ std::string GetShardAddressByID(int64_t shard_id);
+
+ /// \brief get the statistic by statistic id
+ /// \param[in] statistic_id the id of the statistic to be obtained
+ /// \return the statistic obtained by statistic_id
+ std::pair, MSRStatus> GetStatisticByID(int64_t statistic_id);
+
+ MSRStatus InitByFiles(const std::vector &file_paths);
+
+ void SetIndex(Index index) { index_ = std::make_shared(index); }
+
+ std::pair, MSRStatus> GetPage(const int &shard_id, const int &page_id);
+
+ MSRStatus SetPage(const std::shared_ptr &new_page);
+
+ MSRStatus AddPage(const std::shared_ptr &new_page);
+
+ int64_t GetLastPageId(const int &shard_id);
+
+ int GetLastPageIdByType(const int &shard_id, const std::string &page_type);
+
+ const std::pair> GetPageByGroupId(const int &group_id, const int &shard_id);
+
+ std::vector GetShardAddresses() const { return shard_addresses_; }
+
+ int GetShardCount() const { return shard_count_; }
+
+ int GetSchemaCount() const { return schema_.size(); }
+
+ uint64_t GetHeaderSize() const { return header_size_; }
+
+ uint64_t GetPageSize() const { return page_size_; }
+
+ void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
+
+ void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
+
+ std::vector SerializeHeader();
+
+ MSRStatus PagesToFile(const std::string dump_file_name);
+ MSRStatus FileToPages(const std::string dump_file_name);
+
+ private:
+ MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset);
+
+ /// \brief get the headers from all the shard data
+ /// \param[in] real_addresses the real paths of the shard files
+ /// \param[in] headers the headers read from the shard data
+ /// \return SUCCESS/FAILED
+ MSRStatus GetHeaders(const vector &real_addresses, std::vector &headers);
+
+ MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id);
+
+ /// \brief check the binary file status
+ static MSRStatus CheckFileStatus(const std::string &path);
+
+ static std::pair ValidateHeader(const std::string &path);
+
+ void ParseHeader(const json &header);
+
+ void GetHeadersOneTask(int start, int end, std::vector &headers, const vector &realAddresses);
+
+ MSRStatus ParseIndexFields(const json &index_fields);
+
+ MSRStatus CheckIndexField(const std::string &field, const json &schema);
+
+ void ParsePage(const json &page, int shard_index, bool load_dataset);
+
+ MSRStatus ParseStatistics(const json &statistics);
+
+ MSRStatus ParseSchema(const json &schema);
+
+ void ParseShardAddress(const json &address);
+
+ std::string SerializeIndexFields();
+
+ std::vector SerializePage();
+
+ std::string SerializeStatistics();
+
+ std::string SerializeSchema();
+
+ std::string SerializeShardAddress();
+
+ std::shared_ptr InitIndexPtr();
+
+ MSRStatus GetAllSchemaID(std::set &bucket_count);
+
+ uint32_t shard_count_;
+ uint64_t header_size_;
+ uint64_t page_size_;
+
+ std::shared_ptr index_;
+ std::vector shard_addresses_;
+ std::vector> schema_;
+ std::vector> statistics_;
+ std::vector>> pages_;
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INCLUDE_SHARD_HEADER_H_
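A hedged usage sketch of the public ShardHeader API declared above; the file name is a placeholder and error handling is reduced to an early return. It assumes only the methods visible in this header.

#include <iostream>
#include <string>
#include <vector>

#include "minddata/mindrecord/include/shard_header.h"

using mindspore::mindrecord::ShardHeader;
using mindspore::mindrecord::SUCCESS;

int main() {
  ShardHeader header;
  // Any single shard of the dataset works; load_dataset defaults to true
  // and pulls in the remaining shards listed in the header.
  std::vector<std::string> files = {"imagenet.mindrecord0"};
  if (header.BuildDataset(files) != SUCCESS) {
    return 1;
  }
  std::cout << "shards: " << header.GetShardCount()
            << ", schemas: " << header.GetSchemaCount() << '\n';
  return 0;
}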
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h
new file mode 100644
index 0000000000..79b10893fb
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index.h
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INDEX_H
+#define MINDRECORD_INDEX_H
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "minddata/mindrecord/include/shard_schema.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace mindrecord {
+using std::cin;
+using std::endl;
+using std::pair;
+using std::string;
+using std::vector;
+
+class Index {
+ public:
+ Index();
+
+ ~Index() {}
+
+ /// \brief Add a field from the schema identified by schemaId
+ /// \param[in] schemaId the id of the schema the field comes from
+ /// \param[in] field the field that needs to be added
+ ///
+ /// add the field to the fields_ vector
+ void AddIndexField(const int64_t &schemaId, const std::string &field);
+
+ /// \brief get stored fields
+ /// \return the stored fields
+ std::vector > GetFields();
+
+ private:
+ std::vector > fields_;
+ string database_name_;
+ string table_name_;
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INDEX_H
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h
new file mode 100644
index 0000000000..fb85d9adbc
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h
@@ -0,0 +1,120 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ +#define MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_header.h" +#include "./sqlite3.h" + +namespace mindspore { +namespace mindrecord { +using INDEX_FIELDS = std::pair>>; +using ROW_DATA = std::pair>>>; +class ShardIndexGenerator { + public: + explicit ShardIndexGenerator(const std::string &file_path, bool append = false); + + MSRStatus Build(); + + static std::pair GenerateFieldName(const std::pair &field); + + ~ShardIndexGenerator() {} + + /// \brief fetch value in json by field name + /// \param[in] field + /// \param[in] input + /// \return pair + std::pair GetValueByField(const string &field, json input); + + /// \brief fetch field type in schema n by field path + /// \param[in] field_path + /// \param[in] schema + /// \return the type of field + static std::string TakeFieldType(const std::string &field_path, json schema); + + /// \brief create databases for indexes + MSRStatus WriteToDatabase(); + + private: + static int Callback(void *not_used, int argc, char **argv, char **az_col_name); + + static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); + + static std::string ConvertJsonToSQL(const std::string &json); + + std::pair CreateDatabase(int shard_no); + + std::pair> GetSchemaDetails(const std::vector &schema_lens, std::fstream &in); + + static std::pair GenerateRawSQL(const std::vector> &fields); + + std::pair CheckDatabase(const std::string &shard_address); + + /// + /// \param shard_no + /// \param blob_id_to_page_id + /// \param raw_page_id + /// \param in + /// \return field name, db type, field value + ROW_DATA GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, + std::fstream &in); + /// + /// \param db + /// \param sql + /// \param data + /// \return + MSRStatus BindParameterExecuteSQL( + sqlite3 *db, const std::string &sql, + const std::vector>> &data); + + INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); + + MSRStatus ExecuteTransaction(const int &shard_no, std::pair &db, + const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); + + MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); + + MSRStatus AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, + std::fstream &in); + + void AddIndexFieldByRawData(const std::vector &schema_detail, + std::vector> &row_data); + + void DatabaseWriter(); // worker thread + + std::string file_path_; + bool append_; + ShardHeader shard_header_; + uint64_t page_size_; + uint64_t header_size_; + int schema_count_; + std::atomic_int task_; + std::atomic_bool write_success_; + std::vector> fields_; +}; +} // namespace mindrecord +} // namespace mindspore +#endif // MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h new file mode 100644 index 0000000000..b5ea53b759 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h @@ -0,0 +1,63 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ +#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ + +#include +#include "minddata/mindrecord/include/shard_task.h" + +namespace mindspore { +namespace mindrecord { +class ShardOperator { + public: + virtual ~ShardOperator() = default; + + MSRStatus operator()(ShardTask &tasks) { + if (SUCCESS != this->PreExecute(tasks)) { + return FAILED; + } + if (SUCCESS != this->Execute(tasks)) { + return FAILED; + } + if (SUCCESS != this->SufExecute(tasks)) { + return FAILED; + } + return SUCCESS; + } + virtual bool HasChildOp() { return child_op_ != nullptr; } + + virtual MSRStatus SetChildOp(std::shared_ptr child_op) { + if (child_op != nullptr) child_op_ = child_op; + return SUCCESS; + } + + virtual std::shared_ptr GetChildOp() { return child_op_; } + + virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } + + virtual MSRStatus Execute(ShardTask &tasks) = 0; + + virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } + + virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } + + private: + std::shared_ptr child_op_ = nullptr; +}; +} // namespace mindrecord +} // namespace mindspore +#endif // MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h new file mode 100644 index 0000000000..01c70acf29 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_page.h @@ -0,0 +1,106 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PAGE_H_ +#define MINDRECORD_INCLUDE_SHARD_PAGE_H_ + +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +const std::string kPageTypeRaw = "RAW_DATA"; +const std::string kPageTypeBlob = "BLOB_DATA"; +const std::string kPageTypeNewColumn = "NEW_COLUMN_DATA"; + +class Page { + public: + Page(const int &page_id, const int &shard_id, const std::string &page_type, const int &page_type_id, + const uint64_t &start_row_id, const uint64_t end_row_id, + const std::vector> &row_group_ids, const uint64_t page_size) + : page_id_(page_id), + shard_id_(shard_id), + page_type_(page_type), + page_type_id_(page_type_id), + start_row_id_(start_row_id), + end_row_id_(end_row_id), + row_group_ids_(row_group_ids), + page_size_(page_size) {} + + ~Page() = default; + + /// \brief get the page and its description + /// \return the json format of the page and its description + json GetPage() const; + + int GetPageID() const { return page_id_; } + + int GetShardID() const { return shard_id_; } + + int GetPageTypeID() const { return page_type_id_; } + + std::string GetPageType() const { return page_type_; } + + uint64_t GetPageSize() const { return page_size_; } + + uint64_t GetStartRowID() const { return start_row_id_; } + + uint64_t GetEndRowID() const { return end_row_id_; } + + void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } + + void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } + + std::pair GetLastRowGroupID() const { return row_group_ids_.back(); } + + std::vector> GetRowGroupIds() const { return row_group_ids_; } + + void SetRowGroupIds(const std::vector> &last_row_group_ids) { + row_group_ids_ = last_row_group_ids; + } + + void DeleteLastGroupId(); + + private: + int page_id_; + int shard_id_; + std::string page_type_; + int page_type_id_; + uint64_t start_row_id_; + uint64_t end_row_id_; + std::vector> row_group_ids_; + uint64_t page_size_; + // JSON page: { + // "page_id":X, + // "shard_id":X, + // "page_type":"XXX", (enum "raw_data", "blob_data", "new_column") + // "page_type_id":X, + // "start_row_id":X, + // "end_row_id":X, + // "row_group_ids":[{"id":X, "offset":X}], + // "page_size":X, +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PAGE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h new file mode 100644 index 0000000000..2d420b563d --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +class ShardPkSample : public ShardCategory { + public: + ShardPkSample(const std::string &category_field, int64_t num_elements); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + + ~ShardPkSample() override{}; + + MSRStatus SufExecute(ShardTask &tasks) override; + + private: + bool shuffle_; + std::shared_ptr shuffle_op_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h new file mode 100644 index 0000000000..b1b0c1397a --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -0,0 +1,366 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_READER_H_ +#define MINDRECORD_INCLUDE_SHARD_READER_H_ + +#include +#include +#if !defined(_WIN32) && !defined(_WIN64) +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_column.h" +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_operator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +using ROW_GROUPS = + std::tuple>>, std::vector>>; +using ROW_GROUP_BRIEF = + std::tuple>, std::vector>; +using TASK_RETURN_CONTENT = + std::pair, json>>>>; +const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode +const int kNumPageInBuffer = 16; // page buffer size in block-reader mode + +class ShardReader { + public: + ShardReader(); + + virtual ~ShardReader(); + + /// \brief open files and initialize reader, c++ API + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] n_consumer number of threads when reading + /// \param[in] selected_columns column list to be populated + /// \param[in] operators operators applied to data, operator type is shuffle, sample or category + /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, + const std::vector &selected_columns = {}, + const std::vector> &operators = {}, const bool &block_reader = false, + const int num_padded = 0); + + /// \brief open files and initialize reader, python API + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] n_consumer number of threads when reading + /// \param[in] selected_columns column list to be populated + /// \param[in] operators operators applied to data, operator type is shuffle, sample or category + /// \return MSRStatus the status of MSRStatus + MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, + const std::vector &selected_columns = {}, + const std::vector> &operators = {}); + + /// \brief close reader + /// \return null + void Close(); + + /// \brief read the file, get schema meta,statistics and index, single-thread mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(); + + /// \brief read the file, get schema meta,statistics and index, multiple-thread mode + /// \return MSRStatus the status of MSRStatus + MSRStatus Open(int n_consumer); + + /// \brief launch threads to get batches + /// \param[in] is_simple_reader trigger threads if false; do nothing if true + /// \return MSRStatus the status of MSRStatus + MSRStatus Launch(bool is_simple_reader = false); + + /// \brief aim to get the meta data + /// \return 
the metadata + std::shared_ptr GetShardHeader() const; + + /// \brief aim to get columns context + /// \return the columns + std::shared_ptr GetShardColumn() const; + + /// \brief get the number of shards + /// \return # of shards + int GetShardCount() const; + + /// \brief get the number of rows in database + /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list + /// \param[in] load_dataset load dataset from single file or not + /// \param[in] op smart pointer refer to ShardCategory or ShardSample object + /// \param[out] count # of rows + /// \return MSRStatus the status of MSRStatus + MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &op, int64_t *count, const int num_padded); + + /// \brief shuffle task with incremental seed + /// \return void + void ShuffleTask(); + + /// \brief get the number of rows in database + /// \return # of rows + int GetNumRows() const; + + /// \brief Read the summary of row groups + /// \return the tuple of 4 elements + /// 1. Sharding ID + /// 2. Row group ID + /// 3. The row ID started in row group + /// 4. # of rows in row group + std::vector> ReadRowGroupSummary(); + + /// \brief Read 1 row group data, excluding images + /// \param[in] groupID row group ID + /// \param[in] shard_id sharding ID + /// \param[in] columns multi-columns retrieved + /// \return the tuple of 5 elements + /// 1. file name where row group is located + /// 2. Actual row group size + /// 3. Offset address of row group in file + /// 4. The list of image offset in page [startOffset, endOffset) + /// 5. The list of columns data + ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, + const std::vector &columns = std::vector()); + + /// \brief Read 1 row group data, excluding images, following an index field criteria + /// \param[in] groupID row group ID + /// \param[in] shard_id sharding ID + /// \param[in] column-value pair of criteria to fulfill + /// \param[in] columns multi-columns retrieved + /// \return the tuple of 5 elements + /// 1. file name where row group is located + /// 2. Actual row group size + /// 3. Offset address of row group in file + /// 4. The list of image offset in page [startOffset, endOffset) + /// 5. 
The list of columns data
+ ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria,
+ const std::vector &columns = std::vector());
+
+ /// \brief join all created threads
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus Finish();
+
+ /// \brief return a batch, given that one is ready
+ /// \return a batch of images and image data
+ std::vector, json>> GetNext();
+
+ /// \brief return a row by id
+ /// \return a batch of images and image data
+ std::pair, json>>> GetNextById(const int64_t &task_id,
+ const int32_t &consumer_id);
+
+ /// \brief return a batch in block-reader mode, given that one is ready
+ /// \return a batch of images and image data
+ std::vector, json>> GetBlockNext();
+
+ /// \brief return a batch, given that one is ready, python API
+ /// \return a batch of images and image data
+ std::vector>, pybind11::object>> GetNextPy();
+
+ /// \brief get blob field list
+ /// \return blob field list
+ std::pair> GetBlobFields();
+
+ /// \brief reset reader
+ /// \return null
+ void Reset();
+
+ /// \brief set flag of all-in-index
+ /// \return null
+ void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
+
+ /// \brief get NLP flag
+ bool GetNlpFlag();
+
+ /// \brief get all classes
+ MSRStatus GetAllClasses(const std::string &category_field, std::set &categories);
+
+ protected:
+ /// \brief sqlite callback function
+ static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
+
+ private:
+ /// \brief wrap up labels to json format
+ MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs,
+ std::vector>> &offsets, int shard_id,
+ const std::vector &columns, std::vector> &column_values);
+
+ /// \brief read all rows for specified columns
+ ROW_GROUPS ReadAllRowGroup(std::vector &columns);
+
+ /// \brief read all rows in one shard
+ MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns,
+ std::vector>> &offsets,
+ std::vector> &column_values);
+
+ /// \brief initialize reader
+ MSRStatus Init(const std::vector &file_paths, bool load_dataset);
+
+ /// \brief validate column list
+ MSRStatus CheckColumnList(const std::vector &selected_columns);
+
+ /// \brief populate one row by task list in row-reader mode
+ MSRStatus ConsumerByRow(int consumer_id);
+
+ /// \brief populate one row by task list in block-reader mode
+ MSRStatus ConsumerByBlock(int consumer_id);
+
+ /// \brief get offset address of images within page
+ std::vector> GetImageOffset(int group_id, int shard_id,
+ const std::pair &criteria = {"", ""});
+
+ /// \brief execute sqlite query with prepared statement
+ MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector> &labels);
+
+ /// \brief get column values
+ std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns,
+ const std::pair &criteria = {"", ""});
+
+ /// \brief get column values from raw data page
+ std::pair> GetLabelsFromPage(int group_id, int shard_id,
+ const std::vector &columns,
+ const std::pair &criteria = {"", ""});
+
+ /// \brief create task list in block-reader mode
+ MSRStatus CreateTasksByBlock(const std::vector> &row_group_summary,
+ const std::vector> &operators);
+
+ /// \brief create category-applied task list
+ MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary,
+ const std::shared_ptr &op);
+
+ /// \brief create task list in row-reader mode
+ MSRStatus CreateTasksByRow(const std::vector> &row_group_summary,
+ const std::vector> &operators);
+
+ /// \brief create task list
+ MSRStatus CreateTasks(const std::vector> &row_group_summary,
+ const std::vector> &operators);
+
+ /// \brief set NLP flag
+ void CheckNlp();
+
+ /// \brief check if all specified columns are in index table
+ void CheckIfColumnInIndex(const std::vector &columns);
+
+ /// \brief open multiple file handles
+ void FileStreamsOperator();
+
+ /// \brief read one row by one task
+ TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id);
+
+ /// \brief get one row from buffer in block-reader mode
+ std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId);
+
+ /// \brief get labels from binary file
+ std::pair> GetLabelsFromBinaryFile(
+ int shard_id, const std::vector &columns, const std::vector> &label_offsets);
+
+ MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id);
+
+ /// \brief get classes in one shard
+ void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories);
+
+ /// \brief get number of classes
+ int64_t GetNumClasses(const std::string &category_field);
+
+ /// \brief get meta of header
+ std::pair> GetMeta(const std::string &file_path, json &meta_data);
+
+ /// \brief extract uncompressed data based on column list
+ std::pair>> UnCompressBlob(const std::vector &raw_blob_data);
+
+ protected:
+ uint64_t header_size_; // header size
+ uint64_t page_size_; // page size
+ int shard_count_; // number of shards
+ std::shared_ptr shard_header_; // shard header
+ std::shared_ptr shard_column_; // shard column
+
+ std::vector database_paths_; // sqlite handle list
+ std::vector file_paths_; // file paths
+ std::vector> file_streams_; // single-file handle list
+ std::vector>> file_streams_random_; // multiple-file handle list
+
+ private:
+ int n_consumer_; // number of workers (threads)
+ std::vector selected_columns_; // columns which will be read
+ std::map column_schema_id_; // column-schema map
+ std::vector> operators_; // data operators, including shuffle, sample and category
+ ShardTask tasks_; // shard task
+ std::mutex shard_locker_; // locker of shard
+
+ // flags
+ bool all_in_index_ = true; // whether all columns are stored in the index table
+ bool interrupt_ = false; // reader interrupted
+
+ int num_padded_; // number of padding samples
+
+ // Delivery/Iterator mode begin
+ const std::string kThreadName = "THRD_ITER_"; // prefix of thread name
+ std::vector thread_set_; // thread list
+ int num_rows_; // number of rows
+ std::mutex mtx_delivery_; // locker for delivery
+ std::condition_variable cv_delivery_; // condition variable for delivery
+ std::condition_variable cv_iterator_; // condition variable for iterator
+ std::atomic task_id_; // task ID which is working
+ std::atomic deliver_id_; // delivery ID which is picked up by iterator
+ // map of delivery
+ std::unordered_map, json>>>> delivery_map_;
+ // Delivery/Iterator mode end
+
+ // Block reader mode begin
+ bool block_reader_; // block-reader mode
+ int row_id_; // row id in one page
+ int num_blocks_; // number of pages
+ // raw data page
+ std::vector>, std::vector>>> delivery_block_;
+ std::unordered_set delivery_block_set_; // set of delivered pages
+ std::vector> buf_; // page buffer
+ // Block reader mode end
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INCLUDE_SHARD_READER_H_
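A hedged sketch of driving the reader in row-reader mode with the methods declared above; the dataset path and column names are placeholders, and the empty-batch termination condition is an assumption rather than a documented contract.

#include <iostream>

#include "minddata/mindrecord/include/shard_reader.h"

using mindspore::mindrecord::ShardReader;
using mindspore::mindrecord::SUCCESS;

int main() {
  ShardReader reader;
  // Row-reader mode with 4 consumer threads, reading two columns.
  if (reader.Open({"imagenet.mindrecord0"}, true, 4, {"file_name", "label"}) != SUCCESS) {
    return 1;
  }
  if (reader.Launch() != SUCCESS) {
    return 1;
  }
  // Assumed: GetNext() returns an empty batch once the dataset is exhausted.
  for (auto batch = reader.GetNext(); !batch.empty(); batch = reader.GetNext()) {
    std::cout << "batch of " << batch.size() << " rows\n";
  }
  reader.Finish();  // join worker threads
  reader.Close();
  return 0;
}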
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
new file mode 100644
index 0000000000..ce813bc4bf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
+#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
+
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/shard_operator.h"
+#include "minddata/mindrecord/include/shard_shuffle.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardSample : public ShardOperator {
+ public:
+ explicit ShardSample(int n);
+
+ ShardSample(int num, int den);
+
+ ShardSample(int num, int den, int par);
+
+ ShardSample(const std::vector &indices, uint32_t seed);
+
+ ~ShardSample() override{};
+
+ MSRStatus Execute(ShardTask &tasks) override;
+
+ MSRStatus SufExecute(ShardTask &tasks) override;
+
+ int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
+
+ protected:
+ int numerator_;
+ int denominator_;
+ int partition_id_;
+ int no_of_samples_;
+ std::shared_ptr shuffle_op_;
+
+ private:
+ std::vector indices_;
+ SamplerType sampler_type_;
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INCLUDE_SHARD_SAMPLE_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h
new file mode 100644
index 0000000000..56eae85e5a
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h
@@ -0,0 +1,90 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ +#define MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_pybind.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "pybind11/pybind11.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace mindrecord { +class Schema { + public: + ~Schema() = default; + + /// \brief obtain the json schema ,its description, its block fields + /// \param[in] desc the description of the schema + /// \param[in] schema the schema's json + static std::shared_ptr Build(std::string desc, const json &schema); + + /// \brief obtain the json schema and its description for python + /// \param[in] desc the description of the schema + /// \param[in] schema the schema's json + static std::shared_ptr Build(std::string desc, pybind11::handle schema); + + /// \brief compare two schema to judge if they are equal + /// \param b another schema to be judged + /// \return true if they are equal,false if not + bool operator==(const Schema &b) const; + + /// \brief get the schema and its description + /// \return the json format of the schema and its description + std::string GetDesc() const; + + /// \brief get the schema and its description + /// \return the json format of the schema and its description + json GetSchema() const; + + /// \brief get the schema and its description for python method + /// \return the python object of the schema and its description + pybind11::object GetSchemaForPython() const; + + /// set the schema id + /// \param[in] id the id need to be set + void SetSchemaID(int64_t id); + + /// get the schema id + /// \return the int64 schema id + int64_t GetSchemaID() const; + + /// get the blob fields + /// \return the vector blob fields + std::vector GetBlobFields() const; + + private: + Schema() = default; + static bool ValidateNumberShape(const json &it_value); + static bool Validate(json schema); + static std::vector PopulateBlobFields(json schema); + + std::string desc_; + json schema_; + std::vector blob_fields_; + int64_t schema_id_ = -1; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h new file mode 100644 index 0000000000..45d9bda338 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h @@ -0,0 +1,102 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
+#define MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
+
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/shard_reader.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardSegment : public ShardReader {
+ public:
+ ShardSegment();
+
+ ~ShardSegment() override = default;
+
+ /// \brief Get candidate category fields
+ /// \return a list of field names that are candidates for the category field
+ std::pair> GetCategoryFields();
+
+ /// \brief Set category field
+ /// \param[in] category_field category name
+ /// \return SUCCESS if the category name exists
+ MSRStatus SetCategoryField(std::string category_field);
+
+ /// \brief Thread-safe implementation of ReadCategoryInfo
+ /// \return statistics data in json format with 2 fields: "key" and "categories".
+ /// The value of "categories" is a list. Each element in the list is {count, id, name}
+ /// count: count of images in category
+ /// id: internal unique identification, persistent
+ /// name: category name
+ /// example:
+ /// { "key": "label",
+ /// "categories": [ { "count": 3, "id": 0, "name": "sport", },
+ /// { "count": 3, "id": 1, "name": "finance", } ] }
+ std::pair ReadCategoryInfo();
+
+ /// \brief Thread-safe implementation of ReadAtPageById
+ /// \param[in] category_id category ID
+ /// \param[in] page_no page number
+ /// \param[in] n_rows_of_page number of rows in one page
+ /// \return images array, image is a vector of uint8_t
+ std::pair>> ReadAtPageById(int64_t category_id, int64_t page_no,
+ int64_t n_rows_of_page);
+
+ /// \brief Thread-safe implementation of ReadAtPageByName
+ /// \param[in] category_name category name
+ /// \param[in] page_no page number
+ /// \param[in] n_rows_of_page number of rows in one page
+ /// \return images array, image is a vector of uint8_t
+ std::pair>> ReadAtPageByName(std::string category_name, int64_t page_no,
+ int64_t n_rows_of_page);
+
+ std::pair, json>>> ReadAllAtPageById(int64_t category_id,
+ int64_t page_no,
+ int64_t n_rows_of_page);
+
+ std::pair, json>>> ReadAllAtPageByName(
+ std::string category_name, int64_t page_no, int64_t n_rows_of_page);
+
+ std::pair, pybind11::object>>> ReadAtPageByIdPy(
+ int64_t category_id, int64_t page_no, int64_t n_rows_of_page);
+
+ std::pair, pybind11::object>>> ReadAtPageByNamePy(
+ std::string category_name, int64_t page_no, int64_t n_rows_of_page);
+
+ std::pair> GetBlobFields();
+
+ private:
+ std::pair>> WrapCategoryInfo();
+
+ std::string ToJsonForCategory(const std::vector> &tri_vec);
+
+ std::string CleanUp(std::string fieldName);
+
+ std::pair> PackImages(int group_id, int shard_id, std::vector offset);
+
+ std::vector candidate_category_fields_;
+ std::string current_category_field_;
+ const uint32_t kStartFieldId = 9;
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_INCLUDE_SHARD_SEGMENT_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h
new file mode 100644
index 0000000000..724be9acaf
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ + +#include +#include +#include +#include +#include "minddata/mindrecord/include/shard_sample.h" + +namespace mindspore { +namespace mindrecord { +class ShardSequentialSample : public ShardSample { + public: + ShardSequentialSample(int n, int offset); + + ShardSequentialSample(float per, float per_offset); + + ~ShardSequentialSample() override{}; + + MSRStatus Execute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + int offset_; + float per_; + float per_offset_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h new file mode 100644 index 0000000000..d7f736b55b --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ +#define MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ + +#include +#include "minddata/mindrecord/include/shard_operator.h" + +namespace mindspore { +namespace mindrecord { +class ShardShuffle : public ShardOperator { + public: + explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); + + ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, + ShuffleType shuffle_type = kShuffleSample); + + ~ShardShuffle() override{}; + + MSRStatus Execute(ShardTask &tasks) override; + + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + + private: + uint32_t shuffle_seed_; + int64_t no_of_samples_; + bool replacement_; + bool reshuffle_each_epoch_; + ShuffleType shuffle_type_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h new file mode 100644 index 0000000000..f100bb9833 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#ifndef MINDRECORD_STATISTICS_H
+#define MINDRECORD_STATISTICS_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include "minddata/mindrecord/include/common/shard_pybind.h"
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "pybind11/pybind11.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace mindrecord {
+class Statistics {
+ public:
+ /// \brief save the statistic and its description
+ /// \param[in] desc the statistic's description
+ /// \param[in] statistics the statistic to be saved
+ static std::shared_ptr Build(std::string desc, const json &statistics);
+
+ /// \brief save the statistic from python and its description
+ /// \param[in] desc the statistic's description
+ /// \param[in] statistics the statistic to be saved
+ static std::shared_ptr Build(std::string desc, pybind11::handle statistics);
+
+ ~Statistics() = default;
+
+ /// \brief compare two statistics to judge if they are equal
+ /// \param b another statistics to be judged
+ /// \return true if they are equal, false if not
+ bool operator==(const Statistics &b) const;
+
+ /// \brief get the description
+ /// \return the description
+ std::string GetDesc() const;
+
+ /// \brief get the statistic
+ /// \return json format of the statistic
+ json GetStatistics() const;
+
+ /// \brief get the statistic for python
+ /// \return the python object of statistics
+ pybind11::object GetStatisticsForPython() const;
+
+ /// \brief set the statistics id
+ /// \param[in] id the id to be set
+ void SetStatisticsID(int64_t id);
+
+ /// \brief get the statistics id
+ /// \return the int64 statistics id
+ int64_t GetStatisticsID() const;
+
+ private:
+ /// \brief validate the statistic
+ /// \return true / false
+ static bool Validate(const json &statistics);
+
+ static bool LevelRecursive(json level);
+
+ Statistics() = default;
+
+ std::string desc_;
+ json statistics_;
+ int64_t statistics_id_ = -1;
+};
+} // namespace mindrecord
+} // namespace mindspore
+
+#endif // MINDRECORD_STATISTICS_H
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
new file mode 100644
index 0000000000..f07da656f2
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_ +#define MINDRECORD_INCLUDE_SHARD_TASK_H_ + +#include +#include +#include +#include +#include +#include "minddata/mindrecord/include/common/shard_utils.h" + +namespace mindspore { +namespace mindrecord { +class ShardTask { + public: + ShardTask(); + + ShardTask(const ShardTask &task); // copy construction + + ShardTask &operator=(const ShardTask &task); // assignment operator + + ~ShardTask() = default; + + void MakePerm(); + + void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, + const json &label); + + void InsertTask(std::tuple, std::vector, json> task); + + void PopBack(); + + uint32_t Size() const; + + uint32_t SizeOfRows() const; + + std::tuple, std::vector, json> &GetTaskByID(size_t id); + + std::tuple, std::vector, json> &GetRandomTask(); + + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); + + uint32_t categories; + + std::vector permutation_; + + std::vector, std::vector, json>> task_list_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h new file mode 100644 index 0000000000..833928773e --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -0,0 +1,257 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef MINDRECORD_INCLUDE_SHARD_WRITER_H_
+#define MINDRECORD_INCLUDE_SHARD_WRITER_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "minddata/mindrecord/include/shard_column.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "minddata/mindrecord/include/shard_header.h"
+#include "minddata/mindrecord/include/shard_index.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace mindrecord {
+class ShardWriter {
+ public:
+ ShardWriter();
+
+ ~ShardWriter();
+
+ /// \brief Open file at the beginning
+ /// \param[in] paths the file names list
+ /// \param[in] append new data at the end of file if true, otherwise overwrite file
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus Open(const std::vector &paths, bool append = false);
+
+ /// \brief Open file for appending
+ /// \param[in] path the file name
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus OpenForAppend(const std::string &path);
+
+ /// \brief Write header to disk
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus Commit();
+
+ /// \brief Set file size
+ /// \param[in] header_size the size of header, only (1<<N) is accepted
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus SetHeaderSize(const uint64_t &header_size);
+
+ /// \brief Set page size
+ /// \param[in] page_size the size of page, only (1<<N) is accepted
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus SetPageSize(const uint64_t &page_size);
+
+ /// \brief Set shard header
+ /// \param[in] header_data the info of header
+ /// \return MSRStatus the status of MSRStatus
+ MSRStatus SetShardHeader(std::shared_ptr header_data);
+
+ /// \brief write raw data by group size
+ /// \param[in] raw_data the vector of raw json data, vector format
+ /// \param[in] blob_data the vector of image data
+ /// \param[in] sign validate data or not
+ /// \return MSRStatus the status of MSRStatus to judge if write successfully
+ MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data,
+ bool sign = true, bool parallel_writer = false);
+
+ /// \brief write raw data by group size for call from python
+ /// \param[in] raw_data the vector of raw json data, python-handle format
+ /// \param[in] blob_data the vector of image data
+ /// \param[in] sign validate data or not
+ /// \return MSRStatus the status of MSRStatus to judge if write successfully
+ MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data,
+ bool sign = true, bool parallel_writer = false);
+
+ /// \brief write raw data by group size for call from python
+ /// \param[in] raw_data the vector of raw json data, python-handle format
+ /// \param[in] blob_data the vector of blob json data, python-handle format
+ /// \param[in] sign validate data or not
+ /// \return MSRStatus the status of MSRStatus to judge if write successfully
+ MSRStatus WriteRawData(std::map> &raw_data,
+ std::map> &blob_data, bool sign = true,
+ bool parallel_writer = false);
+
+ private:
+ /// \brief write shard header data to disk
+ MSRStatus WriteShardHeader();
+
+ /// \brief erase error data
+ void DeleteErrorData(std::map> &raw_data, std::vector> &blob_data);
+
+ /// \brief populate error data
+ void PopulateMutexErrorData(const int &row, const std::string &message, std::map &err_raw_data);
+
+ /// \brief check data
+ void CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data,
+ std::map &err_raw_data);
+
+ /// \brief validate raw data
+ std::tuple ValidateRawData(std::map> &raw_data,
+ std::vector> &blob_data, bool sign);
+
+ /// \brief fill data array in multiple thread run
+ void FillArray(int start, int end, std::map> &raw_data,
+ std::vector> &bin_data);
+
+ /// \brief serialize raw data
+ MSRStatus 
SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
+                             std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count);
+
+  /// \brief write all data parallel
+  MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
+                              const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief write data shard by shard
+  MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
+                         const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief break image data up into multiple row groups
+  MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
+                        std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
+                        const std::shared_ptr<Page> &last_blob_page);
+
+  /// \brief append partial blob data to previous page
+  MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
+                           const std::vector<std::pair<int, int>> &rows_in_group,
+                           const std::shared_ptr<Page> &last_blob_page);
+
+  /// \brief write new blob data page to disk
+  MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
+                        const std::vector<std::pair<int, int>> &rows_in_group,
+                        const std::shared_ptr<Page> &last_blob_page);
+
+  /// \brief shift last row group to next raw page for new appending
+  MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
+                         std::shared_ptr<Page> &last_raw_page);
+
+  /// \brief write raw data page to disk
+  MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
+                         std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief generate empty raw data page
+  void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
+
+  /// \brief append a row group at the end of raw page
+  MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
+                          const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
+                          const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief write blob chunk to disk
+  MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
+                           const std::pair<int, int> &blob_row);
+
+  /// \brief write raw chunk to disk
+  MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out,
+                          const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
+                          const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief break up into tasks by shard
+  std::vector<std::pair<int, int>> BreakIntoShards();
+
+  /// \brief calculate raw data size row by row
+  MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
+
+  /// \brief calculate blob data size row by row
+  MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
+
+  /// \brief populate last raw page pointer
+  void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
+
+  /// \brief populate last blob page pointer
+  void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);
+
+  /// \brief check the data by schema
+  MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
+
+  /// \brief check the data and type
+  MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
+                                  std::map<int, std::string> &err_raw_data);
+
+  /// \brief Lock writer and save pages info
+  int LockWriter(bool parallel_writer = false);
+
+  /// \brief Unlock writer and save pages info
+  MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
+
+  /// \brief Check raw data before writing
+  MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
+                                 bool sign, int *schema_count, int *row_count);
+
+  /// \brief Get full path from file name
+  MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
+
+  /// \brief Open files
+  MSRStatus OpenDataFiles(bool append);
+
+  /// \brief Remove lock file
+  MSRStatus RemoveLockFile();
+
+  /// \brief Init lock file
+  MSRStatus InitLockFile();
+
+ private:
+  const std::string kLockFileSuffix = "_Locker";
+  const std::string kPageFileSuffix = "_Pages";
+  std::string lock_file_;   // lock file for parallel run
+  std::string pages_file_;  // temporary file of pages info for parallel run
+
+  int shard_count_;        // number of files
+  uint64_t header_size_;   // header size
+  uint64_t page_size_;     // page size
+  uint32_t row_count_;     // count of rows
+  uint32_t schema_count_;  // count of schemas
+
+  std::vector<uint64_t> raw_data_size_;   // Raw data size
+  std::vector<uint64_t> blob_data_size_;  // Blob data size
+
+  std::vector<std::string> file_paths_;                      // file paths
+  std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
+  std::shared_ptr<ShardHeader> shard_header_;                // shard header
+  std::shared_ptr<ShardColumn> shard_column_;                // shard columns
+
+  std::map<uint64_t, std::map<int, std::string>> err_mg_;  // used for storing error raw_data info
+
+  std::mutex check_mutex_;  // mutex for data check
+  std::atomic<bool> flag_{false};
+};
+}  // namespace mindrecord
+}  // namespace mindspore
+
+#endif  // MINDRECORD_INCLUDE_SHARD_WRITER_H_
diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc
new file mode 100644
index 0000000000..f9b18a3bf0
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc
@@ -0,0 +1,626 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "common/utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append) + : file_path_(file_path), + append_(append), + page_size_(0), + header_size_(0), + schema_count_(0), + task_(0), + write_success_(true) {} + +MSRStatus ShardIndexGenerator::Build() { + auto ret = ShardHeader::BuildSingleHeader(file_path_); + if (ret.first != SUCCESS) { + return FAILED; + } + auto json_header = ret.second; + + auto ret2 = GetParentDir(file_path_); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + ShardHeader header = ShardHeader(); + if (header.BuildDataset(real_addresses) == FAILED) { + return FAILED; + } + shard_header_ = header; + MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; + return SUCCESS; +} + +std::pair ShardIndexGenerator::GetValueByField(const string &field, json input) { + if (field.empty()) { + MS_LOG(ERROR) << "The input field is None."; + return {FAILED, ""}; + } + + if (input.empty()) { + MS_LOG(ERROR) << "The input json is None."; + return {FAILED, ""}; + } + + // parameter input does not contain the field + if (input.find(field) == input.end()) { + MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; + return {FAILED, ""}; + } + + // schema does not contain the field + auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; + if (schema.find(field) == schema.end()) { + MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; + return {FAILED, ""}; + } + + // field should be scalar type + if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { + MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; + return {FAILED, ""}; + } + + if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { + auto schema_field_options = schema[field]; + if (schema_field_options.find("shape") == schema_field_options.end()) { + return {SUCCESS, input[field].dump()}; + } else { + // field with shape option + MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; + return {FAILED, ""}; + } + } + + // the field type is string in here + return {SUCCESS, input[field].get()}; +} + +std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { + std::vector field_name = StringSplit(field_path, kPoint); + for (uint64_t i = 0; i < field_name.size(); i++) { + if (i != field_name.size() - 1) { + // Get type information from json schema + schema = schema.at(field_name[i]); + schema = schema.at("properties"); + } else { + // standard root layer exist "properties" if type is "object" + if (schema.find("properties") != schema.end()) { + schema = schema.at("properties"); + } + schema = schema.at(field_name[i]); + std::string field_type = schema.at("type").dump(); + if (field_type.length() <= 2) { + return ""; + } else { + return field_type.substr(1, field_type.length() - 2); + } + } + } + return ""; 
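+  // NOTE (illustrative example, field names assumed): for a nested schema such as
+  //   {"type": "object", "properties": {"info": {"type": "object",
+  //     "properties": {"label": {"type": "int32"}}}}}
+  // a field_path of "info.label" descends one "properties" level per path
+  // segment and returns the bare type string "int32"; dump() produces the
+  // quoted token "\"int32\"", hence the substr above that strips the quotes.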
+} + +std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) { + if (kDbJsonMap.find(json) != kDbJsonMap.end()) { + return kDbJsonMap.at(json); + } else { + return "TEXT"; + } +} + +int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) { + for (auto i = 0; i < argc; i++) { + if (argv[i] != nullptr) { + MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr"); + } + } + MS_LOG(INFO) << "\n"; + return 0; +} + +MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { + char *z_err_msg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Sql error: " << z_err_msg; + sqlite3_free(z_err_msg); + return FAILED; + } else { + if (!success_msg.empty()) { + MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; + } + sqlite3_free(z_err_msg); + return SUCCESS; + } +} + +std::pair ShardIndexGenerator::GenerateFieldName( + const std::pair &field) { + // Replaces dots and dashes with underscores for SQL use + std::string field_name = field.second; + // white list to avoid sql injection + std::replace_if( + field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_'); + auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { + return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); + }); + if (pos != field_name.end()) { + MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; + return {FAILED, ""}; + } + return {SUCCESS, field_name + "_" + std::to_string(field.first)}; +} + +std::pair ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { + sqlite3 *db = nullptr; + std::ifstream fin(common::SafeCStr(shard_address)); + if (!append_ && fin.good()) { + MS_LOG(ERROR) << "DB file already exist"; + fin.close(); + return {FAILED, nullptr}; + } + fin.close(); + int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); + if (rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return {FAILED, nullptr}; + } else { + MS_LOG(DEBUG) << "Opened database successfully"; + return {SUCCESS, db}; + } +} + +MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { + // create shard_name table + std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; + if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { + return FAILED; + } + sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; + if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { + return FAILED; + } + sql = "INSERT INTO SHARD_NAME (NAME) VALUES ('" + shard_name + "');"; + if (ExecuteSQL(sql, db, "insert name successfully.") != SUCCESS) { + return FAILED; + } + return SUCCESS; +} + +std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { + std::string shard_address = shard_header_.GetShardAddressByID(shard_no); + if (shard_address.empty()) { + MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; + return {FAILED, nullptr}; + } + + string shard_name = GetFileName(shard_address).second; + shard_address += ".db"; + auto ret1 = CheckDatabase(shard_address); + if (ret1.first != SUCCESS) { + return {FAILED, nullptr}; + } + sqlite3 *db = ret1.second; + std::string sql = "DROP TABLE IF EXISTS INDEXES;"; + if 
(ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { + return {FAILED, nullptr}; + } + sql = + "CREATE TABLE INDEXES(" + " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" + ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL" + ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL" + ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; + + int field_no = 0; + for (const auto &field : fields_) { + uint64_t schema_id = field.first; + auto result = shard_header_.GetSchemaByID(schema_id); + if (result.second != SUCCESS) { + return {FAILED, nullptr}; + } + json json_schema = (result.first->GetSchema())["schema"]; + std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, nullptr}; + } + sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; + } + sql += ", PRIMARY KEY(ROW_ID"; + for (uint64_t i = 0; i < fields_.size(); ++i) sql += ",INC_" + std::to_string(i); + sql += "));"; + if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { + return {FAILED, nullptr}; + } + + if (CreateShardNameTable(db, shard_name) != SUCCESS) { + return {FAILED, nullptr}; + } + return {SUCCESS, db}; +} + +std::pair> ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, + std::fstream &in) { + std::vector schema_details; + if (schema_count_ <= kMaxSchemaCount) { + for (int sc = 0; sc < schema_count_; ++sc) { + std::vector schema_detail(schema_lens[sc]); + + auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return {FAILED, {}}; + } + + schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); + } + } + + return {SUCCESS, schema_details}; +} + +std::pair ShardIndexGenerator::GenerateRawSQL( + const std::vector> &fields) { + std::string sql = + "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," + "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; + + int field_no = 0; + for (const auto &field : fields) { + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, ""}; + } + sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; + } + sql += + ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," + ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; + field_no = 0; + for (const auto &field : fields) { + auto ret = GenerateFieldName(field); + if (ret.first != SUCCESS) { + return {FAILED, ""}; + } + sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; + } + sql += " )"; + return {SUCCESS, sql}; +} + +MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( + sqlite3 *db, const std::string &sql, + const std::vector>> &data) { + sqlite3_stmt *stmt = nullptr; + if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; + return FAILED; + } + for (auto &row : data) { + for (auto &field : row) { + const auto &place_holder = std::get<0>(field); + const auto &field_type = std::get<1>(field); + const auto &field_value = std::get<2>(field); + + int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); + if (field_type == "INTEGER") { + if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { + 
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index + << ", field value: " << std::stoll(field_value); + return FAILED; + } + } else if (field_type == "NUMERIC") { + if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index + << ", field value: " << std::stold(field_value); + return FAILED; + } + } else if (field_type == "NULL") { + if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; + return FAILED; + } + } else { + if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; + return FAILED; + } + } + } + if (sqlite3_step(stmt) != SQLITE_DONE) { + MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; + return FAILED; + } + (void)sqlite3_reset(stmt); + } + (void)sqlite3_finalize(stmt); + return SUCCESS; +} + +MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, + uint64_t &cur_blob_page_offset, std::fstream &in) { + row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); + + // blob data start + row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); + auto &io_seekg_blob = + in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); + if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + in.close(); + return FAILED; + } + + uint64_t image_size = 0; + + auto &io_read = in.read(reinterpret_cast(&image_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return FAILED; + } + + cur_blob_page_offset += (kInt64Len + image_size); + row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); + + return SUCCESS; +} + +void ShardIndexGenerator::AddIndexFieldByRawData( + const std::vector &schema_detail, std::vector> &row_data) { + auto result = GenerateIndexFields(schema_detail); + if (result.first == SUCCESS) { + int index = 0; + for (const auto &field : result.second) { + // assume simple field: string , number etc. 
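+      // NOTE (illustrative, field name assumed): each entry is a
+      // (placeholder, SQL type, value) triple that BindParameterExecuteSQL()
+      // later binds by name; e.g. an int32 field "label" of schema 0 holding 5
+      // is emitted as (":INC_0", "INTEGER", "0") then (":label_0", "INTEGER", "5").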
+ row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); + row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); + } + } +} + +ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, + int raw_page_id, std::fstream &in) { + std::vector>> full_data; + + // current raw data page + std::shared_ptr cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; + + // related blob page + vector> row_group_list = cur_raw_page->GetRowGroupIds(); + + // pair: row_group id, offset in raw data page + for (pair blob_ids : row_group_list) { + // get blob data page according to row_group id + std::shared_ptr cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first; + + // offset in current raw data page + auto cur_raw_page_offset = static_cast(blob_ids.second); + uint64_t cur_blob_page_offset = 0; + for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { + std::vector> row_data; + row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); + row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); + row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); + + // raw data start + row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); + + // calculate raw data end + auto &io_seekg = + in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + in.close(); + return {FAILED, {}}; + } + + std::vector schema_lens; + if (schema_count_ <= kMaxSchemaCount) { + for (int sc = 0; sc < schema_count_; sc++) { + uint64_t schema_size = 0; + + auto &io_read = in.read(reinterpret_cast(&schema_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + in.close(); + return {FAILED, {}}; + } + + cur_raw_page_offset += (kInt64Len + schema_size); + schema_lens.push_back(schema_size); + } + } + row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); + + // Getting schema for getting data for fields + auto st_schema_detail = GetSchemaDetails(schema_lens, in); + if (st_schema_detail.first != SUCCESS) { + return {FAILED, {}}; + } + + // start blob page info + if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { + return {FAILED, {}}; + } + + // start index field + AddIndexFieldByRawData(st_schema_detail.second, row_data); + full_data.push_back(std::move(row_data)); + } + } + return {SUCCESS, full_data}; +} + +INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail) { + std::vector> fields; + // index fields + std::vector> index_fields = shard_header_.GetFields(); + for (const auto &field : index_fields) { + if (field.first >= schema_detail.size()) { + return {FAILED, {}}; + } + auto field_value = GetValueByField(field.second, schema_detail[field.first]); + if (field_value.first != SUCCESS) { + MS_LOG(ERROR) << "Get value from json by field name failed"; + return {FAILED, {}}; + } + + auto result = shard_header_.GetSchemaByID(field.first); + if (result.second != SUCCESS) { + return {FAILED, {}}; + } + + std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); + auto ret = 
GenerateFieldName(field);
+    if (ret.first != SUCCESS) {
+      return {FAILED, {}};
+    }
+
+    fields.emplace_back(ret.second, field_type, field_value.second);
+  }
+  return {SUCCESS, std::move(fields)};
+}
+
+MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
+                                                  const std::vector<int> &raw_page_ids,
+                                                  const std::map<int, int> &blob_id_to_page_id) {
+  // Add index data to database
+  std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
+  if (shard_address.empty()) {
+    MS_LOG(ERROR) << "Shard address is null";
+    return FAILED;
+  }
+
+  std::fstream in;
+  in.open(common::SafeCStr(shard_address), std::ios::in | std::ios::binary);
+  if (!in.good()) {
+    MS_LOG(ERROR) << "File could not be opened";
+    return FAILED;
+  }
+  (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
+  for (int raw_page_id : raw_page_ids) {
+    auto sql = GenerateRawSQL(fields_);
+    if (sql.first != SUCCESS) {
+      MS_LOG(ERROR) << "Generate raw SQL failed";
+      return FAILED;
+    }
+    auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
+    if (data.first != SUCCESS) {
+      MS_LOG(ERROR) << "Generate raw data failed";
+      return FAILED;
+    }
+    if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
+      MS_LOG(ERROR) << "Execute SQL failed";
+      return FAILED;
+    }
+    MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
+  }
+  (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr);
+  in.close();
+
+  // Close database
+  if (sqlite3_close(db.second) != SQLITE_OK) {
+    MS_LOG(ERROR) << "Close database failed";
+    return FAILED;
+  }
+  db.second = nullptr;
+  return SUCCESS;
+}
+
+MSRStatus ShardIndexGenerator::WriteToDatabase() {
+  fields_ = shard_header_.GetFields();
+  page_size_ = shard_header_.GetPageSize();
+  header_size_ = shard_header_.GetHeaderSize();
+  schema_count_ = shard_header_.GetSchemaCount();
+  if (shard_header_.GetShardCount() > kMaxShardCount) {
+    MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count: " << kMaxShardCount;
+    return FAILED;
+  }
+  task_ = 0;  // set two atomic vars to initial value
+  write_success_ = true;
+
+  // spawn half the physical threads, or the total number of shards, whichever is smaller
+  const unsigned int num_workers =
+    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
+
+  std::vector<std::thread> threads;
+  threads.reserve(num_workers);
+
+  for (size_t t = 0; t < threads.capacity(); t++) {
+    threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
+  }
+
+  for (size_t t = 0; t < threads.capacity(); t++) {
+    threads[t].join();
+  }
+  return write_success_ ?
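+                        // NOTE: DatabaseWriter() below claims shard ids from the
+                        // atomic counter task_ until every shard is indexed; any
+                        // worker that fails clears write_success_, surfaced here.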
SUCCESS : FAILED; +} + +void ShardIndexGenerator::DatabaseWriter() { + int shard_no = task_++; + while (shard_no < shard_header_.GetShardCount()) { + auto db = CreateDatabase(shard_no); + if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { + write_success_ = false; + return; + } + + MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; + + // Pre-processing page information + auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; + + std::map blob_id_to_page_id; + std::vector raw_page_ids; + for (uint64_t i = 0; i < total_pages; ++i) { + std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first; + if (cur_page->GetPageType() == "RAW_DATA") { + raw_page_ids.push_back(i); + } else if (cur_page->GetPageType() == "BLOB_DATA") { + blob_id_to_page_id[cur_page->GetPageTypeID()] = i; + } + } + + if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { + write_success_ = false; + return; + } + MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; + shard_no = task_++; + } +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc new file mode 100644 index 0000000000..84d7fddb6f --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -0,0 +1,1449 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/mindrecord/include/shard_distributed_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "common/utils.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +template +// convert the string to exactly number type (int32_t/int64_t/float/double) +Type StringToNum(const std::string &str) { + std::istringstream iss(str); + Type num; + iss >> num; + return num; +} + +ShardReader::ShardReader() { + task_id_ = 0; + deliver_id_ = 0; + shard_count_ = 0; + n_consumer_ = 0; + page_size_ = 0; + header_size_ = 0; + num_rows_ = 0; + row_id_ = 0; + num_blocks_ = 0; + block_reader_ = false; + num_padded_ = 0; +} + +std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { + if (!IsLegalFile(file_path)) { + return {FAILED, {}}; + } + auto ret = ShardHeader::BuildSingleHeader(file_path); + if (ret.first != SUCCESS) { + return {FAILED, {}}; + } + auto header = ret.second; + meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, + {"version", header["version"]}, {"index_fields", header["index_fields"]}, + {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; + return {SUCCESS, header["shard_addresses"]}; +} + +MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { + std::string file_path = file_paths[0]; + json first_meta_data = json(); + auto ret = GetMeta(file_path, first_meta_data); + if (ret.first != SUCCESS) { + return FAILED; + } + if (file_paths.size() == 1 && load_dataset == true) { + auto ret2 = GetParentDir(file_path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : ret.second) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + file_paths_ = real_addresses; + } else if (file_paths.size() >= 1 && load_dataset == false) { + file_paths_ = file_paths; + } else { + MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; + return FAILED; + } + for (const auto &file : file_paths_) { + json meta_data = json(); + auto ret1 = GetMeta(file, meta_data); + if (ret1.first != SUCCESS) { + return FAILED; + } + if (meta_data != first_meta_data) { + MS_LOG(ERROR) << "Mindrecord files meta information is different."; + return FAILED; + } + sqlite3 *db = nullptr; + // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it + int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return FAILED; + } + MS_LOG(DEBUG) << "Opened database successfully"; + + string sql = "select NAME from SHARD_NAME;"; + std::vector> name; + char *errmsg = nullptr; + rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return FAILED; + } else { + MS_LOG(DEBUG) << "Get " << static_cast(name.size()) << " records from index."; + string shardName = GetFileName(file).second; + if (name.empty() || name[0][0] != shardName) { + MS_LOG(ERROR) << "DB file can not match file " << file; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + 
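+        // NOTE: the SHARD_NAME table written by ShardIndexGenerator doubles as
+        // a consistency check here: an index db sitting next to a .mindrecord
+        // file it was not generated from is rejected rather than silently reused.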
return FAILED; + } + } + database_paths_.push_back(db); + } + ShardHeader sh = ShardHeader(); + if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(sh); + header_size_ = shard_header_->GetHeaderSize(); + page_size_ = shard_header_->GetPageSize(); + // version < 3.0 + if (first_meta_data["version"] < kVersion) { + shard_column_ = std::make_shared(shard_header_, false); + } else { + shard_column_ = std::make_shared(shard_header_, true); + } + num_rows_ = 0; + auto row_group_summary = ReadRowGroupSummary(); + for (const auto &rg : row_group_summary) { + num_rows_ += std::get<3>(rg); + } + + MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; + + return SUCCESS; +} + +MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { + vector inSchema(selected_columns.size(), 0); + for (auto &p : GetShardHeader()->GetSchemas()) { + auto schema = p->GetSchema()["schema"]; + for (unsigned int i = 0; i < selected_columns.size(); ++i) { + if (schema.find(selected_columns[i]) != schema.end()) { + inSchema[i] = 1; + } + } + } + if (std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; })) { + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardReader::Open() { + file_streams_.clear(); + + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + MS_LOG(INFO) << "Open shard file successfully."; + file_streams_.push_back(fs); + } + + return SUCCESS; +} + +MSRStatus ShardReader::Open(int n_consumer) { + file_streams_random_ = + std::vector>>(n_consumer, std::vector>()); + for (const auto &file : file_paths_) { + for (int j = 0; j < n_consumer; ++j) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + file_streams_random_[j].push_back(fs); + } + MS_LOG(INFO) << "Open shard file successfully."; + } + + return SUCCESS; +} + +void ShardReader::FileStreamsOperator() { + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; --i) { + if (file_streams_[i] != nullptr) { + file_streams_[i]->close(); + } + } + for (int i = static_cast(file_streams_random_.size()) - 1; i >= 0; --i) { + for (int j = static_cast(file_streams_random_[i].size()) - 1; j >= 0; --j) { + if (file_streams_random_[i][j] != nullptr) { + file_streams_random_[i][j]->close(); + } + } + } + for (int i = static_cast(database_paths_.size()) - 1; i >= 0; --i) { + if (database_paths_[i] != nullptr) { + auto ret = sqlite3_close(database_paths_[i]); + if (ret != SQLITE_OK) { + MS_LOG(ERROR) << "Close db failed. 
Error code: " << ret << "."; + } + database_paths_[i] = nullptr; + } + } +} + +ShardReader::~ShardReader() { Close(); } + +void ShardReader::Close() { + (void)Finish(); // interrupt reading and stop threads + FileStreamsOperator(); +} + +std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } + +std::shared_ptr ShardReader::GetShardColumn() const { return shard_column_; } + +int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } + +int ShardReader::GetNumRows() const { return num_rows_; } + +std::vector> ShardReader::ReadRowGroupSummary() { + std::vector> row_group_summary; + int shard_count = shard_header_->GetShardCount(); + if (shard_count <= 0) { + return row_group_summary; + } + if (shard_count <= kMaxShardCount) { + for (int shard_id = 0; shard_id < shard_count; ++shard_id) { + // return -1 when page's size equals to 0. + auto last_page_id = shard_header_->GetLastPageId(shard_id); + if (static_cast(last_page_id) == -1) { + continue; + } + for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { + const auto &page_t = shard_header_->GetPage(shard_id, page_id); + const auto &page = page_t.first; + if (page->GetPageType() != kPageTypeBlob) continue; + uint64_t start_row_id = page->GetStartRowID(); + if (start_row_id > page->GetEndRowID()) { + return std::vector>(); + } + uint64_t number_of_rows = page->GetEndRowID() - start_row_id; + row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); + } + } + } + return row_group_summary; +} + +MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, + std::shared_ptr fs, + std::vector>> &offsets, int shard_id, + const std::vector &columns, + std::vector> &column_values) { + for (int i = 0; i < static_cast(labels.size()); ++i) { + uint64_t group_id = std::stoull(labels[i][0]); + uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; + uint64_t offset_end = std::stoull(labels[i][2]); + offsets[shard_id].emplace_back( + std::vector{static_cast(shard_id), group_id, offset_start, offset_end}); + if (!all_in_index_) { + int raw_page_id = std::stoi(labels[i][3]); + uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len; + uint64_t label_end = std::stoull(labels[i][5]); + auto len = label_end - label_start; + auto label_raw = std::vector(len); + auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + fs->close(); + return FAILED; + } + + auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fs->close(); + return FAILED; + } + json label_json = json::from_msgpack(label_raw); + json tmp; + if (!columns.empty()) { + for (auto &col : columns) { + if (label_json.find(col) != label_json.end()) { + tmp[col] = label_json[col]; + } + } + } else { + tmp = label_json; + } + column_values[shard_id].emplace_back(tmp); + } else { + json construct_json; + for (unsigned int j = 0; j < columns.size(); ++j) { + // construct json "f1": value + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; + + // convert the string to base type by schema + if (schema[columns[j]]["type"] == "int32") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if (schema[columns[j]]["type"] == "int64") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if 
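+                 // NOTE: index columns come back from sqlite as strings; the
+                 // schema type decides how each value is parsed back into typed
+                 // JSON via StringToNum, while string fields pass through as-is.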
(schema[columns[j]]["type"] == "float32") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else if (schema[columns[j]]["type"] == "float64") { + construct_json[columns[j]] = StringToNum(labels[i][j + 3]); + } else { + construct_json[columns[j]] = std::string(labels[i][j + 3]); + } + } + column_values[shard_id].emplace_back(construct_json); + } + } + + return SUCCESS; +} + +MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, + std::vector>> &offsets, + std::vector> &column_values) { + auto db = database_paths_[shard_id]; + std::vector> labels; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return FAILED; + } + MS_LOG(INFO) << "Get " << static_cast(labels.size()) << " records from shard " << shard_id << " index."; + + std::string file_name = file_paths_[shard_id]; + std::shared_ptr fs = std::make_shared(); + if (!all_in_index_) { + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return FAILED; + } + } + sqlite3_free(errmsg); + return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); +} + +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + std::map index_columns; + for (auto &field : GetShardHeader()->GetFields()) { + index_columns[field.second] = field.first; + } + if (index_columns.find(category_field) == index_columns.end()) { + MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; + return FAILED; + } + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); + if (SUCCESS != ret.first) { + return FAILED; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count_; x++) { + threads[x].join(); + } + return SUCCESS; +} + +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, + std::set &categories) { + if (nullptr == db) { + return; + } + std::vector> columns; + char *errmsg = nullptr; + int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); + if (ret != SQLITE_OK) { + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; + return; + } + MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + std::lock_guard lck(shard_locker_); + for (int i = 0; i < static_cast(columns.size()); ++i) { + categories.emplace(columns[i][0]); + } +} + +ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { + std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; + std::vector>> offsets(shard_count_, std::vector>{}); + std::vector> column_values(shard_count_, std::vector{}); + if (all_in_index_) { + for (unsigned int i = 0; i < columns.size(); ++i) { + fields += ','; + auto ret = 
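+          // NOTE: the reader reuses the writer-side column naming scheme
+          // (<field>_<schema_id>), so this SELECT hits exactly the columns
+          // GenerateRawSQL populated when the index was built.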
ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); + if (ret.first != SUCCESS) { + return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); + } + fields += ret.second; + } + } else { // fetch raw data from Raw page while some field is not index. + fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; + } + + std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;"; + + std::vector thread_read_db = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + thread_read_db[x] = + std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); + } + + for (int x = 0; x < shard_count_; x++) { + thread_read_db[x].join(); + } + return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); +} + +ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + const std::shared_ptr &page = ret.second; + std::string file_name = file_paths_[shard_id]; + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); + + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); + if (status_labels.first != SUCCESS) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), + std::move(status_labels.second)); +} + +ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, + const std::pair &criteria, + const std::vector &columns) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + vector criteria_list{criteria.first}; + if (CheckColumnList(criteria_list) == FAILED) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + const std::shared_ptr &page = ret.second; + std::string file_name = file_paths_[shard_id]; + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); + + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); + if (status_labels.first != SUCCESS) { + return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); + } + + return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), + std::move(status_labels.second)); +} + +int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) { + auto *records = static_cast> *>(p_data); + if (num_fields > 0 && num_fields <= kMaxFieldCount) { + for (int i = 0; i < num_fields; ++i) + if (p_fields[i] == nullptr) p_fields[i] = const_cast(""); + } + records->emplace_back(p_fields, p_fields + num_fields); + return 0; +} + +std::vector> ShardReader::GetImageOffset(int page_id, int shard_id, + const std::pair &criteria) { + auto db = database_paths_[shard_id]; + + std::string sql = + "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + 
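+    // (illustrative, field name assumed) with page_id 3 and criteria
+    // {"label", "5"} on a numeric field of schema 0, the final statement reads:
+    //   SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES
+    //   WHERE PAGE_ID_BLOB = 3 AND label_0 = 5;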
std::to_string(page_id); + + // whether use index search + if (!criteria.first.empty()) { + auto schema = shard_header_->GetSchemas()[0]->GetSchema(); + + // not number field should add '' in sql + if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { + sql += + " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; + } else { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + + criteria.second + "'"; + } + } + sql += ";"; + std::vector> image_offsets; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return std::vector>(); + } else { + MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << "records from index."; + } + std::vector> res; + for (int i = static_cast(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector{0, 0}); + for (int i = 0; i < static_cast(image_offsets.size()); i++) { + const auto &image_offset = image_offsets[i]; + res[i][0] = std::stoull(image_offset[0]) + kInt64Len; + res[i][1] = std::stoull(image_offset[1]); + } + sqlite3_free(errmsg); + return res; +} + +std::pair> ShardReader::GetBlobFields() { + std::vector blob_fields; + for (auto &p : GetShardHeader()->GetSchemas()) { + // assume one schema + const auto &fields = p->GetBlobFields(); + blob_fields.assign(fields.begin(), fields.end()); + break; + } + return std::make_pair(kCV, blob_fields); +} + +void ShardReader::CheckIfColumnInIndex(const std::vector &columns) { + // assume different schemas do not contain same key. 
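+  // NOTE: all_in_index_ stays true only while every requested column is an
+  // indexed field; GetLabels() can then be answered from the sqlite index
+  // alone, otherwise raw pages must be re-read from the .mindrecord file.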
+ if (columns.empty()) { + all_in_index_ = false; + return; + } + for (auto &field : GetShardHeader()->GetFields()) { + column_schema_id_[field.second] = field.first; + } + for (auto &col : columns) { + if (column_schema_id_.find(col) == column_schema_id_.end()) { + all_in_index_ = false; + return; + } + } +} + +MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, + std::vector> &labels) { + sqlite3_stmt *stmt = nullptr; + if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not prepare statement"; + return FAILED; + } + int index = sqlite3_bind_parameter_index(stmt, ":criteria"); + if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) { + MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << criteria; + return FAILED; + } + int rc = sqlite3_step(stmt); + while (rc != SQLITE_DONE) { + vector tmp; + int ncols = sqlite3_column_count(stmt); + for (int i = 0; i < ncols; i++) { + tmp.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, i))); + } + labels.push_back(tmp); + rc = sqlite3_step(stmt); + } + (void)sqlite3_finalize(stmt); + return SUCCESS; +} + +std::pair> ShardReader::GetLabelsFromBinaryFile( + int shard_id, const std::vector &columns, const std::vector> &label_offsets) { + std::string file_name = file_paths_[shard_id]; + std::vector res; + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "File could not opened"; + return {FAILED, {}}; + } + + // init the return + for (unsigned int i = 0; i < label_offsets.size(); ++i) { + res.emplace_back(json{}); + } + + for (unsigned int i = 0; i < label_offsets.size(); ++i) { + const auto &labelOffset = label_offsets[i]; + uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len; + uint64_t label_end = std::stoull(labelOffset[2]); + int raw_page_id = std::stoi(labelOffset[0]); + auto len = label_end - label_start; + auto label_raw = std::vector(len); + auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + fs->close(); + return {FAILED, {}}; + } + + auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fs->close(); + return {FAILED, {}}; + } + + json label_json = json::from_msgpack(label_raw); + json tmp = label_json; + for (auto &col : columns) { + if (label_json.find(col) != label_json.end()) { + tmp[col] = label_json[col]; + } + } + res[i] = tmp; + } + return {SUCCESS, res}; +} + +std::pair> ShardReader::GetLabelsFromPage( + int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria) { + // get page info from sqlite + auto db = database_paths_[shard_id]; + std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + + std::to_string(page_id); + std::vector> label_offsets; + if (!criteria.first.empty()) { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; + if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { + return {FAILED, {}}; + } + } else { + sql += ";"; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), 
SelectCallback, &label_offsets, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, {}}; + } + MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; + sqlite3_free(errmsg); + } + // get labels from binary file + return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); +} + +std::pair> ShardReader::GetLabels(int page_id, int shard_id, + const std::vector &columns, + const std::pair &criteria) { + if (all_in_index_) { + auto db = database_paths_[shard_id]; + std::string fields; + for (unsigned int i = 0; i < columns.size(); ++i) { + if (i > 0) fields += ','; + uint64_t schema_id = column_schema_id_[columns[i]]; + fields += columns[i] + "_" + std::to_string(schema_id); + } + if (fields.empty()) fields = "*"; + std::vector> labels; + std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); + if (!criteria.first.empty()) { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; + if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { + return {FAILED, {}}; + } + } else { + sql += ";"; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, {}}; + } else { + MS_LOG(DEBUG) << "Get " << static_cast(labels.size()) << "records from index."; + } + sqlite3_free(errmsg); + } + std::vector ret; + for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); + for (unsigned int i = 0; i < labels.size(); ++i) { + json construct_json; + for (unsigned int j = 0; j < columns.size(); ++j) { + // construct json "f1": value + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; + + // convert the string to base type by schema + if (schema[columns[j]]["type"] == "int32") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "int64") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "float32") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else if (schema[columns[j]]["type"] == "float64") { + construct_json[columns[j]] = StringToNum(labels[i][j]); + } else { + construct_json[columns[j]] = std::string(labels[i][j]); + } + } + ret[i] = construct_json; + } + return {SUCCESS, ret}; + } + return GetLabelsFromPage(page_id, shard_id, columns, criteria); +} + +bool ResortRowGroups(std::tuple a, std::tuple b) { + return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); +} + +MSRStatus ShardReader::Finish() { + { + std::lock_guard lck(mtx_delivery_); + interrupt_ = true; + } + cv_delivery_.notify_all(); + + // Wait for all threads to finish + for (auto &i_thread : thread_set_) { + if (i_thread.joinable()) { + i_thread.join(); + } + } + return SUCCESS; +} + +int64_t ShardReader::GetNumClasses(const std::string &category_field) { + auto shard_count = file_paths_.size(); + auto index_fields = shard_header_->GetFields(); + + std::map map_schema_id_fields; + for (auto &field : index_fields) { + map_schema_id_fields[field.second] = field.first; + } + + if 
(map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) { + MS_LOG(ERROR) << "Field " << category_field << " does not exist."; + return -1; + } + auto ret = + ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); + if (SUCCESS != ret.first) { + return -1; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count); + std::set categories; + for (int x = 0; x < shard_count; x++) { + sqlite3 *db = nullptr; + int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (SQLITE_OK != rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return -1; + } + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count; x++) { + threads[x].join(); + } + return categories.size(); +} + +MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &ops, int64_t *count, const int num_padded) { + if (SUCCESS != Init(file_paths, load_dataset)) { + return FAILED; + } + int64_t num_samples = num_rows_; + bool root = true; + std::stack> stack_ops; + std::shared_ptr op(ops); + while (op != nullptr) { + stack_ops.push(op); + op = op->GetChildOp(); + } + while (!stack_ops.empty()) { + op = stack_ops.top(); + stack_ops.pop(); + if (std::dynamic_pointer_cast(op)) { + num_samples = op->GetNumSamples(num_samples, 0); + if (num_padded > 0 && root == true) { + num_samples += num_padded; + MS_LOG(DEBUG) << "Padding samples work on shuffle sampler."; + root = false; + } + } else if (std::dynamic_pointer_cast(op)) { + auto category_op = std::dynamic_pointer_cast(op); + std::string category_field = category_op->GetCategoryField(); + auto num_classes = GetNumClasses(category_field); + num_samples = category_op->GetNumSamples(num_samples, num_classes); + } else if (std::dynamic_pointer_cast(op)) { + if (std::dynamic_pointer_cast(op)) { + auto sampler_op = std::dynamic_pointer_cast(op); + if (root == true) { + sampler_op->SetNumPaddedSamples(num_padded); + num_samples = op->GetNumSamples(num_samples, 0); + if (-1 == num_samples) { + MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; + return FAILED; + } + root = false; + } + } else { + num_samples = op->GetNumSamples(num_samples, 0); + } + } else { + if (num_padded > 0) num_samples += num_padded; + } + } + *count = num_samples; + return SUCCESS; +} + +MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, + const std::vector &selected_columns, + const std::vector> &operators, const bool &block_reader, + int num_padded) { + // Open file and set header by ShardReader + auto ret = Init(file_paths, load_dataset); + if (SUCCESS != ret) { + return ret; + } + auto thread_limit = GetMaxThreadNum(); + if (n_consumer > thread_limit) { + n_consumer = thread_limit; + } + if (n_consumer < kMinConsumerCount) { + n_consumer = kMinConsumerCount; + } + vector blob_fields = GetBlobFields().second; + for (unsigned int i = 0; i < selected_columns.size(); ++i) { + if (!std::any_of(blob_fields.begin(), blob_fields.end(), + [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { + selected_columns_.push_back(selected_columns[i]); + } + } + selected_columns_ = selected_columns; + + if (CheckColumnList(selected_columns_) == FAILED) { + 
MS_LOG(ERROR) << "Illegal column list"; + return ILLEGAL_COLUMN_LIST; + } + + // Initialize argument + shard_count_ = static_cast(file_paths_.size()); + n_consumer_ = n_consumer; + num_padded_ = num_padded; + + operators_ = operators; + + if (block_reader) { + block_reader_ = true; + if (Open() == FAILED) { + return FAILED; + } + delivery_block_ = std::vector>, std::vector>>>( + kNumPageInBuffer, std::shared_ptr>, std::vector>>{}); + buf_ = std::vector>(kNumPageInBuffer, std::vector(page_size_)); + } else { + block_reader_ = false; + if (Open(n_consumer) == FAILED) { + return FAILED; + } + } + return SUCCESS; +} + +MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, + const std::vector &selected_columns, + const std::vector> &operators) { + // Open file and set header by ShardReader + if (SUCCESS != Init(file_paths, load_dataset)) { + return FAILED; + } + // should remove blob field from selected_columns when call from python + std::vector columns(selected_columns); + auto blob_fields = GetBlobFields().second; + for (auto &blob_field : blob_fields) { + auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); + if (it != selected_columns.end()) { + columns.erase(columns.begin() + std::distance(selected_columns.begin(), it)); + } + } + if (CheckColumnList(columns) == FAILED) { + MS_LOG(ERROR) << "Illegal column list"; + return FAILED; + } + if (Open(n_consumer) == FAILED) { + return FAILED; + } + // Initialize argument + shard_count_ = static_cast(file_paths_.size()); + n_consumer_ = n_consumer; + + // Initialize columns which will be read + selected_columns_ = selected_columns; + operators_ = operators; + + return SUCCESS; +} + +MSRStatus ShardReader::Launch(bool isSimpleReader) { + // Get all row groups' info + auto row_group_summary = ReadRowGroupSummary(); + + // Sort row group by (group_id, shard_id), prepare for parallel reading + std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); + if (CreateTasks(row_group_summary, operators_) != SUCCESS) { + MS_LOG(ERROR) << "Failed to launch read threads."; + interrupt_ = true; + return FAILED; + } + if (isSimpleReader) return SUCCESS; + // Start provider consumer threads + thread_set_ = std::vector(n_consumer_); + if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { + return FAILED; + } + + for (int x = 0; x < n_consumer_; ++x) { + if (block_reader_) { + thread_set_[x] = std::thread(&ShardReader::ConsumerByBlock, this, x); + } else { + thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); + } + } + + MS_LOG(INFO) << "Launch read thread successfully."; + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, + const std::vector> &operators) { + CheckIfColumnInIndex(selected_columns_); + for (const auto &rg : row_group_summary) { + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto n_Rows = std::get<3>(rg); + tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op) { + CheckIfColumnInIndex(selected_columns_); + auto category_op = std::dynamic_pointer_cast(op); + auto categories = category_op->GetCategories(); + int64_t num_elements = category_op->GetNumElements(); + if (num_elements <= 0) { + MS_LOG(ERROR) << "Parameter num_element is not positive"; + return FAILED; + } + if 
(categories.empty() == true) { + std::string category_field = category_op->GetCategoryField(); + int64_t num_categories = category_op->GetNumCategories(); + if (num_categories <= 0) { + MS_LOG(ERROR) << "Parameter num_categories is not positive"; + return FAILED; + } + std::set categories_set; + auto ret = GetAllClasses(category_field, categories_set); + if (SUCCESS != ret) { + return FAILED; + } + int i = 0; + for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + categories.emplace_back(category_field, *it); + i++; + } + } + // Generate task list, a task will create a batch + std::vector categoryTasks(categories.size()); + for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { + int category_index = 0; + for (const auto &rg : row_group_summary) { + if (category_index >= num_elements) break; + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + + auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); + if (SUCCESS != std::get<0>(details)) { + return FAILED; + } + auto offsets = std::get<4>(details); + + auto number_of_rows = offsets.size(); + for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], + std::get<5>(details)[iStart]); + category_index++; + } + } + } + MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; + } + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); + if (SUCCESS != (*category_op)(tasks_)) { + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators) { + CheckIfColumnInIndex(selected_columns_); + + auto ret = ReadAllRowGroup(selected_columns_); + if (std::get<0>(ret) != SUCCESS) { + return FAILED; + } + auto offsets = std::get<1>(ret); + auto local_columns = std::get<2>(ret); + if (shard_count_ <= kMaxShardCount) { + for (int shard_id = 0; shard_id < shard_count_; shard_id++) { + for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { + tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], + std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, + local_columns[shard_id][i]); + } + } + } else { + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::CreateTasks(const std::vector> &row_group_summary, + const std::vector> &operators) { + if (block_reader_) { + if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + return FAILED; + } + } else { + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } + } + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { + return FAILED; + } + if (num_padded_ > 0) { + for (int i = 0; i < num_padded_; ++i) { + tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); + } + } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; + } + } + } + + for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { + const auto &op = operators[operator_no]; + if (std::dynamic_pointer_cast(op)) continue; + if (block_reader_ && 
std::dynamic_pointer_cast(op)) continue; + if (SUCCESS != (*op)(tasks_)) { + return FAILED; + } + } + + if (tasks_.permutation_.empty()) tasks_.MakePerm(); + num_rows_ = block_reader_ ? tasks_.SizeOfRows() : tasks_.Size(); + num_blocks_ = block_reader_ ? tasks_.Size() : 0; + MS_LOG(INFO) << "Total rows is " << num_rows_; + return SUCCESS; +} + +TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { + // All tasks are done + if (task_id >= static_cast(tasks_.Size())) { + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + + // Pick up task from task list + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); + + // check task type + auto task_type = std::get<0>(task); + if (task_type == TaskType::kPaddedTask) { + return std::make_pair(SUCCESS, + std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); + } + + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); + auto addr = std::get<2>(task); + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + const std::shared_ptr &page = ret.second; + + // Pack image list + std::vector images(addr[1] - addr[0]); + auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; + + auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_random_[consumer_id][shard_id]->close(); + return std::make_pair(FAILED, + std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + } + + auto &io_read = + file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast(&images[0]), addr[1] - addr[0]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_random_[consumer_id][shard_id]->close(); + return std::make_pair(FAILED, + std::pair(TaskType::kCommonTask, std::vector, json>>())); + } + + // Deliver batch data to output map + std::vector, json>> batch; + batch.emplace_back(std::move(images), std::move(std::get<3>(task))); + + return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); +} + +MSRStatus ShardReader::ConsumerByRow(int consumer_id) { + // Set thread name +#if !defined(_WIN32) && !defined(_WIN64) + auto thread_id = kThreadName + std::to_string(consumer_id); + prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); +#endif + + // Loop forever + for (;;) { + int task_id = 0; + + // Get next task ID + task_id = task_id_++; + + // All tasks are done + if (task_id >= static_cast(tasks_.Size())) { + return FAILED; + } + const auto &ret = ConsumerOneTask(task_id, consumer_id); + if (SUCCESS != ret.first) { + return FAILED; + } + const auto &batch = (ret.second).second; + // Hanging if maximum map size exceeded + // otherwise, set batch data in map + { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); + if (interrupt_) { + return SUCCESS; + } + delivery_map_[task_id] = std::make_shared, json>>>(std::move(batch)); + } + cv_iterator_.notify_one(); + } +} + +MSRStatus ShardReader::ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, + const int &buf_id) { + auto &io_seekg = 
file_streams_[shard_id]->seekg(page_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf_[buf_id][0]), page_length); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { + // Set thread name +#if !defined(_WIN32) && !defined(_WIN64) + auto thread_id = kThreadName + std::to_string(consumer_id); + prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); +#endif + + // Loop forever + for (;;) { + int task_id = 0; + + // Get next task ID + task_id = task_id_++; + + // All tasks are done, either quit or repeat again + if (task_id >= num_blocks_) { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [this] { return interrupt_ || task_id_ < num_blocks_; }); + if (interrupt_) { + return SUCCESS; + } + continue; + } + + // Pick up task from task list + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); + + auto shard_id = std::get<0>(std::get<1>(task)); + auto group_id = std::get<1>(std::get<1>(task)); + auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); + if (SUCCESS != std::get<0>(row_group_brief)) { + return FAILED; + } + auto page_length = std::get<2>(row_group_brief); + auto page_offset = std::get<3>(row_group_brief); + + MS_LOG(DEBUG) << "Block task " << task_id << tasks_.permutation_[task_id] << ", shard " << shard_id << ", group " + << group_id << ", page length " << page_length << ", page offset " << page_offset; + + // Deliver block data to output map + auto offset_and_labels = std::make_pair(std::get<4>(row_group_brief), std::get<5>(row_group_brief)); + + int deliver_id = deliver_id_; + // Hanging if maximum map size exceeded otherwise, set batch data in buffer + { + std::unique_lock lck(mtx_delivery_); + cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id < deliver_id_ + kNumPageInBuffer; }); + if (interrupt_) { + return SUCCESS; + } + } + + auto buf_id = task_id % kNumPageInBuffer; + delivery_block_[buf_id] = + std::make_shared>, std::vector>>(offset_and_labels); + + // Read blob + if (ReadBlob(shard_id, page_offset, page_length, buf_id) != SUCCESS) { + return FAILED; + } + + { + std::unique_lock lck(mtx_delivery_); + delivery_block_set_.insert(task_id); + } + cv_iterator_.notify_one(); + } +} + +std::shared_ptr, json>>> ShardReader::GetRowFromBuffer(int buf_id, + int rowId) { + auto &blob_page = buf_[buf_id]; + auto &offsets = (*delivery_block_[buf_id]).first; + auto &labels = (*delivery_block_[buf_id]).second; + auto &addr_start = offsets[rowId][0]; + auto &addr_end = offsets[rowId][1]; + std::vector images(blob_page.begin() + addr_start, blob_page.begin() + addr_end); + std::vector, json>> batch; + batch.emplace_back(std::move(images), std::move(labels[rowId])); + return std::make_shared, json>>>(std::move(batch)); +} + +std::vector, json>> ShardReader::GetBlockNext() { + if (deliver_id_ >= num_blocks_) { + return std::vector, json>>(); + } + + if (row_id_ == 0) { + std::unique_lock lck(mtx_delivery_); + cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_block_set_.count(deliver_id_) > 0); }); + + if (interrupt_) { + return std::vector, json>>(); + } + } + auto buf_id = deliver_id_ % kNumPageInBuffer; + auto res = 
GetRowFromBuffer(buf_id, row_id_); + + row_id_++; + if (row_id_ == (*delivery_block_[buf_id]).first.size()) { + row_id_ = 0; + { + std::unique_lock lck(mtx_delivery_); + delivery_block_set_.erase(deliver_id_++); + } + cv_delivery_.notify_all(); + } + + return *res; +} + +std::vector, json>> ShardReader::GetNext() { + if (interrupt_) { + return std::vector, json>>(); + } + if (block_reader_) return GetBlockNext(); + if (deliver_id_ >= static_cast(tasks_.Size())) { + return std::vector, json>>(); + } + + std::shared_ptr, json>>> res; + { + std::unique_lock lck(mtx_delivery_); + cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); }); + if (interrupt_) { + return std::vector, json>>(); + } + res = delivery_map_[deliver_id_]; + delivery_map_.erase(deliver_id_++); + } + + cv_delivery_.notify_all(); + + return *res; +} + +std::pair, json>>> ShardReader::GetNextById( + const int64_t &task_id, const int32_t &consumer_id) { + if (interrupt_) { + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); + } + if (block_reader_) { + return std::make_pair(TaskType::kCommonTask, GetBlockNext()); + } + const auto &ret = ConsumerOneTask(task_id, consumer_id); + if (SUCCESS != ret.first) { + return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); + } + return std::move(ret.second); +} + +std::pair>> ShardReader::UnCompressBlob( + const std::vector &raw_blob_data) { + auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; + auto blob_fields = GetBlobFields().second; + std::vector> blob_data; + for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { + if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0; + auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); + if (ret != SUCCESS) { + MS_LOG(ERROR) << "Error when getting data from blob, column name is " << loaded_columns[i_col] << "."; + return {FAILED, std::vector>(blob_fields.size(), std::vector())}; + } + if (data == nullptr) { + data = reinterpret_cast(data_ptr.get()); + } + std::vector column(data, data + (n_bytes / sizeof(unsigned char))); + blob_data.push_back(column); + } + return {SUCCESS, blob_data}; +} + +std::vector>, pybind11::object>> ShardReader::GetNextPy() { + auto res = GetNext(); + vector>, pybind11::object>> data; + std::transform(res.begin(), res.end(), std::back_inserter(data), + [this](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + auto ret = UnCompressBlob(std::get<0>(item)); + return std::make_tuple(ret.second, std::move(obj)); + }); + return data; +} + +void ShardReader::Reset() { + { + std::lock_guard lck(mtx_delivery_); + task_id_ = 0; + deliver_id_ = 0; + } + cv_delivery_.notify_all(); +} + +void ShardReader::ShuffleTask() { + if (block_reader_) return; + // if both a shuffle and a distributed sampler exist in the operators, skip the shuffle + bool has_sharding = false; + for (const auto &op : operators_) { + if (std::dynamic_pointer_cast(op)) { + has_sharding = true; + } + } + for (const auto &op : operators_) { + if (std::dynamic_pointer_cast(op) && !has_sharding) { + if (SUCCESS != (*op)(tasks_)) { + MS_LOG(WARNING) << "Redo randomSampler failed."; + } + } else if (std::dynamic_pointer_cast(op)) { + if (SUCCESS != (*op)(tasks_)) { + 
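+ // NOTE: a failed re-run of the distributed sampler is not fatal here; only the warning below is logged and the existing task order is kept.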
MS_LOG(WARNING) << "Redo distributeSampler failed."; + } + } + } + if (tasks_.permutation_.empty()) tasks_.MakePerm(); +} + +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc new file mode 100644 index 0000000000..eda8924e13 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc @@ -0,0 +1,385 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_segment.h" +#include "common/utils.h" + +#include "./securec.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "pybind11/pybind11.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardSegment::ShardSegment() { SetAllInIndex(false); } + +std::pair> ShardSegment::GetCategoryFields() { + // Skip if already populated + if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; + + std::string sql = "PRAGMA table_info(INDEXES);"; + std::vector> field_names; + + char *errmsg = nullptr; + int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(database_paths_[0]); + database_paths_[0] = nullptr; + return {FAILED, vector{}}; + } else { + MS_LOG(INFO) << "Get " << static_cast(field_names.size()) << " records from index."; + } + + uint32_t idx = kStartFieldId; + while (idx < field_names.size()) { + if (field_names[idx].size() < 2) { + sqlite3_free(errmsg); + sqlite3_close(database_paths_[0]); + database_paths_[0] = nullptr; + return {FAILED, vector{}}; + } + candidate_category_fields_.push_back(field_names[idx][1]); + idx += 2; + } + sqlite3_free(errmsg); + return {SUCCESS, candidate_category_fields_}; +} + +MSRStatus ShardSegment::SetCategoryField(std::string category_field) { + if (GetCategoryFields().first != SUCCESS) { + MS_LOG(ERROR) << "Get candidate category field failed"; + return FAILED; + } + category_field = category_field + "_0"; + if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), + [category_field](std::string x) { return x == category_field; })) { + current_category_field_ = category_field; + return SUCCESS; + } + MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; + return FAILED; +} + +std::pair ShardSegment::ReadCategoryInfo() { + MS_LOG(INFO) << "Read category begin"; + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info failed"; + return {FAILED, ""}; + } + // Convert category info to json string + auto category_json_string = ToJsonForCategory(ret.second); + 
+ MS_LOG(INFO) << "Read category end"; + + return {SUCCESS, category_json_string}; +} + +std::pair>> ShardSegment::WrapCategoryInfo() { + std::map counter; + + std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + + ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; + + for (auto &db : database_paths_) { + std::vector> field_count; + + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return {FAILED, std::vector>()}; + } else { + MS_LOG(INFO) << "Get " << static_cast(field_count.size()) << " records from index."; + } + + for (const auto &field : field_count) { + counter[field[0]] += std::stoi(field[1]); + } + sqlite3_free(errmsg); + } + + int idx = 0; + std::vector> category_vec(counter.size()); + (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple item) { + return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); + }); + return {SUCCESS, std::move(category_vec)}; +} + +std::string ShardSegment::ToJsonForCategory(const std::vector> &tri_vec) { + std::vector category_json_vec; + for (auto q : tri_vec) { + json j; + j["id"] = std::get<0>(q); + j["name"] = std::get<1>(q); + j["count"] = std::get<2>(q); + + category_json_vec.emplace_back(j); + } + + json j_vec(category_json_vec); + json category_info; + category_info["key"] = current_category_field_; + category_info["categories"] = j_vec; + return category_info.dump(); +} + +std::pair>> ShardSegment::ReadAtPageById(int64_t category_id, + int64_t page_no, + int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector>{}}; + } + if (category_id >= static_cast(ret.second.size()) || category_id < 0) { + MS_LOG(ERROR) << "Illegal category id, id: " << category_id; + return {FAILED, std::vector>{}}; + } + int total_rows_in_category = std::get<2>(ret.second[category_id]); + // Quit if category not found or page number is out of range + if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || + page_no * n_rows_of_page >= total_rows_in_category) { + MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; + return {FAILED, std::vector>{}}; + } + + std::vector> page; + auto row_group_summary = ReadRowGroupSummary(); + + uint64_t i_start = page_no * n_rows_of_page; + uint64_t i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); + uint64_t idx = 0; + for (const auto &rg : row_group_summary) { + if (idx >= i_end) break; + + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto details = ReadRowGroupCriteria( + group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); + if (SUCCESS != std::get<0>(details)) { + return {FAILED, std::vector>{}}; + } + auto offsets = std::get<4>(details); + uint64_t number_of_rows = offsets.size(); + if (idx + number_of_rows < i_start) { + idx += number_of_rows; + continue; + } + + for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { + if (idx >= i_start && idx < i_end) { + auto ret1 = PackImages(group_id, shard_id, offsets[i]); + if (SUCCESS != ret1.first) { + return {FAILED, std::vector>{}}; + } + 
page.push_back(std::move(ret1.second)); + } + } + } + + return {SUCCESS, std::move(page)}; +} + +std::pair> ShardSegment::PackImages(int group_id, int shard_id, + std::vector offset) { + const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); + if (SUCCESS != ret.first) { + return {FAILED, std::vector()}; + } + const std::shared_ptr &blob_page = ret.second; + + // Pack image list + std::vector images(offset[1] - offset[0]); + auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; + auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_random_[0][shard_id]->close(); + return {FAILED, {}}; + } + + auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast(&images[0]), offset[1] - offset[0]); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_random_[0][shard_id]->close(); + return {FAILED, {}}; + } + + return {SUCCESS, std::move(images)}; +} + +std::pair>> ShardSegment::ReadAtPageByName(std::string category_name, + int64_t page_no, + int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector>{}}; + } + for (const auto &categories : ret.second) { + if (std::get<1>(categories) == category_name) { + auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); + return result; + } + } + + return {FAILED, std::vector>()}; +} + +std::pair, json>>> ShardSegment::ReadAllAtPageById( + int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS || category_id >= static_cast(ret.second.size())) { + MS_LOG(ERROR) << "Illegal category id, id: " << category_id; + return {FAILED, std::vector, json>>{}}; + } + int total_rows_in_category = std::get<2>(ret.second[category_id]); + // Quit if category not found or page number is out of range + if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { + MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; + return {FAILED, std::vector, json>>{}}; + } + + std::vector, json>> page; + auto row_group_summary = ReadRowGroupSummary(); + + int i_start = page_no * n_rows_of_page; + int i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); + int idx = 0; + for (const auto &rg : row_group_summary) { + if (idx >= i_end) break; + + auto shard_id = std::get<0>(rg); + auto group_id = std::get<1>(rg); + auto details = ReadRowGroupCriteria( + group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); + if (SUCCESS != std::get<0>(details)) { + return {FAILED, std::vector, json>>{}}; + } + auto offsets = std::get<4>(details); + auto labels = std::get<5>(details); + + int number_of_rows = offsets.size(); + if (idx + number_of_rows < i_start) { + idx += number_of_rows; + continue; + } + + if (number_of_rows > static_cast(labels.size())) { + MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; + return {FAILED, std::vector, json>>{}}; + } + for (int i = 0; i < number_of_rows; ++i, ++idx) { + if (idx >= i_start && idx < i_end) { + auto ret1 = PackImages(group_id, shard_id, offsets[i]); + if (SUCCESS != ret1.first) { + return {FAILED, std::vector, json>>{}}; + } + 
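+ // Unlike ReadAtPageById, each entry here carries the packed image bytes together with the row's label json.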
page.emplace_back(std::move(ret1.second), std::move(labels[i])); + } + } + } + return {SUCCESS, std::move(page)}; +} + +std::pair, json>>> ShardSegment::ReadAllAtPageByName( + std::string category_name, int64_t page_no, int64_t n_rows_of_page) { + auto ret = WrapCategoryInfo(); + if (ret.first != SUCCESS) { + MS_LOG(ERROR) << "Get category info"; + return {FAILED, std::vector, json>>{}}; + } + + // category_name to category_id + int64_t category_id = -1; + for (const auto &categories : ret.second) { + std::string categories_name = std::get<1>(categories); + + if (categories_name == category_name) { + category_id = std::get<0>(categories); + break; + } + } + + if (category_id == -1) { + return {FAILED, std::vector, json>>{}}; + } + + return ReadAllAtPageById(category_id, page_no, n_rows_of_page); +} + +std::pair, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( + int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { + auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); + if (res.first != SUCCESS) { + return {FAILED, std::vector, pybind11::object>>{}}; + } + + vector, pybind11::object>> json_data; + std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return {SUCCESS, std::move(json_data)}; +} + +std::pair, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( + std::string category_name, int64_t page_no, int64_t n_rows_of_page) { + auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); + if (res.first != SUCCESS) { + return {FAILED, std::vector, pybind11::object>>{}}; + } + vector, pybind11::object>> json_data; + std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return {SUCCESS, std::move(json_data)}; +} + +std::pair> ShardSegment::GetBlobFields() { + std::vector blob_fields; + for (auto &p : GetShardHeader()->GetSchemas()) { + // assume one schema + const auto &fields = p->GetBlobFields(); + blob_fields.assign(fields.begin(), fields.end()); + break; + } + return std::make_pair(kCV, blob_fields); +} + +std::string ShardSegment::CleanUp(std::string field_name) { + while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); + field_name.pop_back(); + return field_name; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc new file mode 100644 index 0000000000..e85229cc34 --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -0,0 +1,1254 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/mindrecord/include/shard_writer.h" +#include "common/utils.h" +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "./securec.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::DEBUG; +using mindspore::MsLogLevel::ERROR; +using mindspore::MsLogLevel::INFO; + +namespace mindspore { +namespace mindrecord { +ShardWriter::ShardWriter() + : shard_count_(1), + header_size_(kDefaultHeaderSize), + page_size_(kDefaultPageSize), + row_count_(0), + schema_count_(1) {} + +ShardWriter::~ShardWriter() { + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } +} + +MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { + // Get full path from file name + for (const auto &path : paths) { + if (!CheckIsValidUtf8(path)) { + MS_LOG(ERROR) << "The filename contains invalid utf-8 data: " << path << "."; + return FAILED; + } + char resolved_path[PATH_MAX] = {0}; + char buf[PATH_MAX] = {0}; + if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { + MS_LOG(ERROR) << "Secure func failed"; + return FAILED; + } +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX) == nullptr) { + MS_LOG(ERROR) << "Invalid file path"; + return FAILED; + } + if (_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX) == nullptr) { + MS_LOG(DEBUG) << "Path " << resolved_path; + } +#else + if (realpath(dirname(&(buf[0])), resolved_path) == nullptr) { + MS_LOG(ERROR) << "Invalid file path"; + return FAILED; + } + if (realpath(common::SafeCStr(path), resolved_path) == nullptr) { + MS_LOG(DEBUG) << "Path " << resolved_path; + } +#endif + file_paths_.emplace_back(string(resolved_path)); + } + return SUCCESS; +} + +MSRStatus ShardWriter::OpenDataFiles(bool append) { + // Open files + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + if (!append) { + // if not appending and the mindrecord file exists, return FAILED + fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); + if (fs->good()) { + MS_LOG(ERROR) << "MindRecord file already exists."; + fs->close(); + return FAILED; + } + fs->close(); + + // open the mindrecord file to write + fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not be opened."; + return FAILED; + } + } else { + // open the mindrecord file to append + fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary); + if (!fs->good()) { + MS_LOG(ERROR) << "MindRecord file could not be opened for appending."; + return FAILED; + } + } + MS_LOG(INFO) << "Opened shard file successfully."; + file_streams_.push_back(fs); + } + return SUCCESS; +} + +MSRStatus ShardWriter::RemoveLockFile() { + // Remove temporary file + int ret = std::remove(pages_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove page file."; + } + + ret = std::remove(lock_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove lock file."; + } + return SUCCESS; +} + +MSRStatus ShardWriter::InitLockFile() { + if (file_paths_.size() == 0) { + MS_LOG(ERROR) << "File path not initialized."; + return FAILED; + } + + lock_file_ = file_paths_[0] + kLockFileSuffix; + pages_file_ = file_paths_[0] + kPageFileSuffix; + + if (RemoveLockFile() == FAILED) 
{ + MS_LOG(ERROR) << "Remove file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The shard count is greater than the max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema count is greater than the max value."; + return FAILED; + } + + // Get full path from file name + if (GetFullPathFromFileName(paths) == FAILED) { + MS_LOG(ERROR) << "Get full path from file name failed."; + return FAILED; + } + + // Open files + if (OpenDataFiles(append) == FAILED) { + MS_LOG(ERROR) << "Open data files failed."; + return FAILED; + } + + // Init lock file + if (InitLockFile() == FAILED) { + MS_LOG(ERROR) << "Init lock file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::OpenForAppend(const std::string &path) { + if (!IsLegalFile(path)) { + return FAILED; + } + auto ret1 = ShardHeader::BuildSingleHeader(path); + if (ret1.first != SUCCESS) { + return FAILED; + } + auto json_header = ret1.second; + auto ret2 = GetParentDir(path); + if (SUCCESS != ret2.first) { + return FAILED; + } + std::vector real_addresses; + for (const auto &path : json_header["shard_addresses"]) { + std::string abs_path = ret2.second + string(path); + real_addresses.emplace_back(abs_path); + } + ShardHeader header = ShardHeader(); + if (header.BuildDataset(real_addresses) == FAILED) { + return FAILED; + } + shard_header_ = std::make_shared(header); + MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); + if (ret == FAILED) { + return FAILED; + } + ret = SetPageSize(shard_header_->GetPageSize()); + if (ret == FAILED) { + return FAILED; + } + ret = Open(real_addresses, true); + if (ret == FAILED) { + MS_LOG(ERROR) << "Open file failed"; + return FAILED; + } + shard_column_ = std::make_shared(shard_header_); + return SUCCESS; +} + +MSRStatus ShardWriter::Commit() { + // Read pages file + std::ifstream page_file(pages_file_.c_str()); + if (page_file.good()) { + page_file.close(); + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return FAILED; + } + } + + if (WriteShardHeader() == FAILED) { + MS_LOG(ERROR) << "Write metadata failed"; + return FAILED; + } + MS_LOG(INFO) << "Wrote metadata successfully."; + + // Remove lock file + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove lock file failed."; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) { + MSRStatus ret = header_data->InitByFiles(file_paths_); + if (ret == FAILED) { + return FAILED; + } + + // auto-generate index fields when none are set + std::vector> fields = header_data->GetFields(); + if (fields.empty()) { + MS_LOG(DEBUG) << "No index fields provided by user; auto-generating index fields."; + std::vector> schemas = header_data->GetSchemas(); + for (const auto &schema : schemas) { + json jsonSchema = schema->GetSchema()["schema"]; + for (const auto &el : jsonSchema.items()) { + if (el.value()["type"] == "string" || + (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || + (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { + 
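+ // Only scalar fields qualify as auto-generated indexes: "string", or numeric types without a "shape" entry; shaped fields are stored as blobs and are skipped.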
fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); + } + } + } + // if fields is still empty here, the schema contains only blob data and there is nothing to index + if (!fields.empty()) { + ret = header_data->AddIndexFields(fields); + if (ret == FAILED) { + MS_LOG(ERROR) << "Add index fields failed"; + return FAILED; + } + } + } + + shard_header_ = header_data; + shard_header_->SetHeaderSize(header_size_); + shard_header_->SetPageSize(page_size_); + shard_column_ = std::make_shared(shard_header_); + return SUCCESS; +} + +MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { + // header_size [16KB, 128MB] + if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { + MS_LOG(ERROR) << "Header size should be between 16KB and 128MB."; + return FAILED; + } + if (header_size % 4 != 0) { + MS_LOG(ERROR) << "Header size should be divisible by four."; + return FAILED; + } + + header_size_ = header_size; + return SUCCESS; +} + +MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { + // PageSize [32KB, 256MB] + if (page_size < kMinPageSize || page_size > kMaxPageSize) { + MS_LOG(ERROR) << "Page size should be between 32KB and 256MB."; + return FAILED; + } + if (page_size % 4 != 0) { + MS_LOG(ERROR) << "Page size should be divisible by four."; + return FAILED; + } + page_size_ = page_size; + return SUCCESS; +} + +void ShardWriter::DeleteErrorData(std::map> &raw_data, + std::vector> &blob_data) { + // get wrong data location + std::set> delete_set; + for (auto &err_mg : err_mg_) { + uint64_t id = err_mg.first; + auto sub_err_mg = err_mg.second; + for (auto &subMg : sub_err_mg) { + int loc = subMg.first; + std::string message = subMg.second; + MS_LOG(ERROR) << "For schema " << id << ", " << (loc + 1) << "-th data is wrong: " << message; + (void)delete_set.insert(loc); + } + } + + auto it = raw_data.begin(); + if (delete_set.size() == it->second.size()) { + raw_data.clear(); + blob_data.clear(); + return; + } + + // delete wrong raw data + for (auto &loc : delete_set) { + // delete row data + for (auto &raw : raw_data) { + (void)raw.second.erase(raw.second.begin() + loc); + } + + // delete blob data + (void)blob_data.erase(blob_data.begin() + loc); + } +} + +void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message, + std::map &err_raw_data) { + std::lock_guard lock(check_mutex_); + (void)err_raw_data.insert(std::make_pair(row, message)); +} + +MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, + std::map &err_raw_data) { + auto data_type = std::string(value["type"].get()); + + if ((data_type == "int32" && !data[key].is_number_integer()) || + (data_type == "int64" && !data[key].is_number_integer()) || + (data_type == "float32" && !data[key].is_number_float()) || + (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) { + std::string message = "field: " + key + " type: " + data_type + " value: " + data[key].dump() + " is not matched"; + PopulateMutexErrorData(i, message, err_raw_data); + return FAILED; + } + + if (data_type == "int32" && data[key].is_number_integer()) { + int64_t temp_value = data[key]; + // reject values outside the int32 range + if (static_cast<int64_t>(temp_value) < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) || + static_cast<int64_t>(temp_value) > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) { + std::string message = + "field: " + key + " type: " + data_type + " value: " + data[key].dump() + " is out of range"; + PopulateMutexErrorData(i, message, err_raw_data); + return FAILED; + } + } + return SUCCESS; +} + +void 
ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, + std::map &err_raw_data) { + if (start_row < 0 || start_row > end_row || end_row > static_cast(sub_raw_data.size())) { + return; + } + for (int i = start_row; i < end_row; i++) { + json data = sub_raw_data[i]; + + for (auto iter = schema.begin(); iter != schema.end(); iter++) { + std::string key = iter.key(); + json value = iter.value(); + if (data.find(key) == data.end()) { + std::string message = "there is no '" + key + "' object in the raw data"; + PopulateMutexErrorData(i, message, err_raw_data); + break; + } + + if (value.size() == kInt2) { + // Skip the check since all shaped data will be stored as blob + continue; + } + + if (CheckDataTypeAndValue(key, value, data, i, err_raw_data) != SUCCESS) { + break; + } + } + } +} + +MSRStatus ShardWriter::CheckData(const std::map> &raw_data) { + auto rawdata_iter = raw_data.begin(); + + // make sure the raw data matches the schema + for (; rawdata_iter != raw_data.end(); ++rawdata_iter) { + // used for storing errors + std::map sub_err_mg; + int schema_id = rawdata_iter->first; + auto result = shard_header_->GetSchemaByID(schema_id); + if (result.second != SUCCESS) { + return FAILED; + } + json schema = result.first->GetSchema()["schema"]; + for (const auto &field : result.first->GetBlobFields()) { + (void)schema.erase(field); + } + std::vector sub_raw_data = rawdata_iter->second; + + // calculate start position and end position for each thread + int batch_size = rawdata_iter->second.size() / shard_count_; + int thread_num = shard_count_; + if (thread_num <= 0) { + return FAILED; + } + if (thread_num > kMaxThreadCount) { + thread_num = kMaxThreadCount; + } + std::vector thread_set(thread_num); + + // start multiple threads + int start_row = 0, end_row = 0; + for (int x = 0; x < thread_num; ++x) { + if (x != thread_num - 1) { + start_row = batch_size * x; + end_row = batch_size * (x + 1); + } else { + start_row = batch_size * x; + end_row = rawdata_iter->second.size(); + } + thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema, + std::ref(sub_raw_data), std::ref(sub_err_mg)); + } + if (thread_num > kMaxThreadCount) { + return FAILED; + } + // Wait for threads to finish + for (int x = 0; x < thread_num; ++x) { + thread_set[x].join(); + } + + (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg)); + } + return SUCCESS; +} + +std::tuple ShardWriter::ValidateRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign) { + auto rawdata_iter = raw_data.begin(); + schema_count_ = raw_data.size(); + std::tuple failed(FAILED, 0, 0); + if (schema_count_ == 0) { + MS_LOG(ERROR) << "Data size is zero"; + return failed; + } + + // keep schema_id + std::set schema_ids; + row_count_ = (rawdata_iter->second).size(); + MS_LOG(DEBUG) << "Schema count is " << schema_count_; + + // Determine if the number of schemas is the same + if (shard_header_->GetSchemas().size() != schema_count_) { + MS_LOG(ERROR) << "Data size does not match the schema count"; + return failed; + } + + // Determine raw_data size == blob_data size + if (raw_data[0].size() != blob_data.size()) { + MS_LOG(ERROR) << "Raw data size is not equal to blob data size"; + return failed; + } + + // Determine whether the number of samples corresponding to each schema is the same + for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { + if (row_count_ != rawdata_iter->second.size()) { + MS_LOG(ERROR) << "Data size is not equal"; + return failed; + 
} + (void)schema_ids.insert(rawdata_iter->first); + } + const std::vector> &schemas = shard_header_->GetSchemas(); + if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr &schema) { + return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); + })) { + // the input raw data does not cover every schema + MS_LOG(ERROR) << "Input raw data schema ids do not match the real schema ids."; + return failed; + } + + if (!sign) { + std::tuple success(SUCCESS, schema_count_, row_count_); + return success; + } + + // check the data according to the schema + if (CheckData(raw_data) != SUCCESS) { + MS_LOG(ERROR) << "Data validation check failed"; + return std::tuple(FAILED, schema_count_, row_count_); + } + + // delete wrong data from raw data + DeleteErrorData(raw_data, blob_data); + + // update raw count + row_count_ = row_count_ - err_mg_.begin()->second.size(); + std::tuple success(SUCCESS, schema_count_, row_count_); + return success; +} + +void ShardWriter::FillArray(int start, int end, std::map> &raw_data, + std::vector> &bin_data) { + // Guard against spawning excess threads that would index out of range + if (start >= end) { + flag_ = true; + return; + } + int schema_count = static_cast(raw_data.size()); + std::map>::const_iterator rawdata_iter; + for (int x = start; x < end; ++x) { + int cnt = 0; + for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { + const json &line = raw_data.at(rawdata_iter->first)[x]; + std::vector bline = json::to_msgpack(line); + + // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2] + bin_data[x * schema_count + cnt] = bline; + cnt++; + } + } +} + +int ShardWriter::LockWriter(bool parallel_writer) { + if (!parallel_writer) { + return 0; + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Lock file done by Python."; + const int fd = 0; +#else + const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); + if (fd >= 0) { + flock(fd, LOCK_EX); + } else { + MS_LOG(ERROR) << "Shard writer failed when locking file"; + return -1; + } +#endif + + // Open files + file_streams_.clear(); + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); + if (fs->fail()) { + MS_LOG(ERROR) << "File could not be opened"; + return -1; + } + file_streams_.push_back(fs); + } + + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return -1; + } + return fd; +} + +MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { + if (!parallel_writer) { + return SUCCESS; + } + + if (shard_header_->PagesToFile(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Write pages to file failed"; + return FAILED; + } + + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Unlock file done by Python."; +#else + flock(fd, LOCK_UN); + close(fd); +#endif + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { + // check the free disk size + auto st_space = GetDiskSize(file_paths_[0], kFreeSize); + if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { + MS_LOG(ERROR) << "IO error / insufficient free disk space"; + return FAILED; + } + + // compress blob + if 
(shard_column_->CheckCompressBlob()) { + for (auto &blob : blob_data) { + blob = shard_column_->CompressBlob(blob); + } + } + + // Add 4-byte dummy blob data if there are no blob fields + if (blob_data.size() == 0 && raw_data.size() > 0) { + blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); + } + + // Add a dummy id if all fields are blob fields + if (blob_data.size() > 0 && raw_data.size() == 0) { + raw_data.insert(std::pair>(0, std::vector(blob_data.size(), kDummyId))); + } + + auto v = ValidateRawData(raw_data, blob_data, sign); + if (std::get<0>(v) == FAILED) { + MS_LOG(ERROR) << "Validate raw data failed"; + return FAILED; + } + *schema_count = std::get<1>(v); + *row_count = std::get<2>(v); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { + // Lock the writer when loading data in parallel + int fd = LockWriter(parallel_writer); + if (fd < 0) { + MS_LOG(ERROR) << "Lock writer failed"; + return FAILED; + } + + // Get the count of schemas and rows + int schema_count = 0; + int row_count = 0; + + // Pre-check raw data + if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { + MS_LOG(ERROR) << "Check raw data failed"; + return FAILED; + } + + if (row_count == kInt0) { + MS_LOG(INFO) << "Raw data size is 0."; + return SUCCESS; + } + + std::vector> bin_raw_data(row_count * schema_count); + + // Serialize raw data + if (SerializeRawData(raw_data, bin_raw_data, row_count) == FAILED) { + MS_LOG(ERROR) << "Serialize raw data failed"; + return FAILED; + } + + // Set row size of raw data + if (SetRawDataSize(bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Set raw data size failed"; + return FAILED; + } + + // Set row size of blob data + if (SetBlobDataSize(blob_data) == FAILED) { + MS_LOG(ERROR) << "Set blob data size failed"; + return FAILED; + } + + // Write data to disk with multiple threads + if (ParallelWriteData(blob_data, bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Parallel write data failed"; + return FAILED; + } + MS_LOG(INFO) << "Wrote " << bin_raw_data.size() << " records successfully."; + + if (UnlockWriter(fd, parallel_writer) == FAILED) { + MS_LOG(ERROR) << "Unlock writer failed"; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::map> &blob_data, bool sign, + bool parallel_writer) { + std::map> raw_data_json; + std::map> blob_data_json; + + (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), + [](const std::pair> &pair) { + auto &py_raw_data = pair.second; + std::vector json_raw_data; + (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(json_raw_data)); + }); + + (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()), + [](const std::pair> &pair) { + auto &py_blob_data = pair.second; + std::vector jsonBlobData; + (void)std::transform(py_blob_data.begin(), py_blob_data.end(), + std::back_inserter(jsonBlobData), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(jsonBlobData)); + }); + + // Serialize blob page + auto blob_data_iter = blob_data.begin(); + auto schema_count = blob_data.size(); + auto row_count = blob_data_iter->second.size(); + + 
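+ // The flattened buffer below uses the same row-major layout over schemas as FillArray: [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2].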
std::vector> bin_blob_data(row_count * schema_count); + // Serialize blob data + if (SerializeRawData(blob_data_json, bin_blob_data, row_count) == FAILED) { + MS_LOG(ERROR) << "Serialize blob data failed in write raw data"; + return FAILED; + } + return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + vector> &blob_data, bool sign, bool parallel_writer) { + std::map> raw_data_json; + (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), + [](const std::pair> &pair) { + auto &py_raw_data = pair.second; + std::vector json_raw_data; + (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(pair.first, std::move(json_raw_data)); + }); + return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); +} + +MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, + const std::vector> &bin_raw_data) { + auto shards = BreakIntoShards(); + // define the number of threads + int thread_num = static_cast(shard_count_); + if (thread_num < 0) { + return FAILED; + } + if (thread_num > kMaxThreadCount) { + thread_num = kMaxThreadCount; + } + int left_thread = shard_count_; + int current_thread = 0; + while (left_thread) { + if (left_thread < thread_num) { + thread_num = left_thread; + } + // Start one thread for one shard + std::vector thread_set(thread_num); + if (thread_num <= kMaxThreadCount) { + for (int x = 0; x < thread_num; ++x) { + int start_row = shards[current_thread + x].first; + int end_row = shards[current_thread + x].second; + thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row, + std::ref(blob_data), std::ref(bin_raw_data)); + } + // Wait for threads to finish + for (int x = 0; x < thread_num; ++x) { + thread_set[x].join(); + } + left_thread -= thread_num; + current_thread += thread_num; + } + } + return SUCCESS; +} + +MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, + const std::vector> &blob_data, + const std::vector> &bin_raw_data) { + MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row + << ", schema size: " << schema_count_; + if (start_row == end_row) { + return SUCCESS; + } + vector> rows_in_group; + std::shared_ptr last_raw_page = nullptr; + std::shared_ptr last_blob_page = nullptr; + SetLastRawPage(shard_id, last_raw_page); + SetLastBlobPage(shard_id, last_blob_page); + + if (CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "Cut row group failed"; + return FAILED; + } + + if (AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "Append blob page failed"; + return FAILED; + } + + if (NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { + MS_LOG(ERROR) << "New blob page failed"; + return FAILED; + } + + if (ShiftRawPage(shard_id, rows_in_group, last_raw_page) == FAILED) { + MS_LOG(ERROR) << "Shift raw page failed"; + return FAILED; + } + + if (WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data) == FAILED) { + MS_LOG(ERROR) << "Write raw page failed"; + return FAILED; + } + + return SUCCESS; +} + +MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, + std::vector> &rows_in_group, + const 
std::shared_ptr &last_raw_page, + const std::shared_ptr &last_blob_page) { + auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; + + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; + auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; + auto n_byte_raw = last_raw_page_size - last_raw_offset; + + int page_start_row = start_row; + if (start_row > end_row) { + return FAILED; + } + if (end_row > static_cast(blob_data_size_.size()) || end_row > static_cast(raw_data_size_.size())) { + return FAILED; + } + for (int i = start_row; i < end_row; ++i) { + // n_byte_blob(0) indicate appendBlobPage + if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ || + n_byte_raw + raw_data_size_[i] > page_size_) { + rows_in_group.emplace_back(page_start_row, i); + page_start_row = i; + n_byte_blob = blob_data_size_[i]; + n_byte_raw = raw_data_size_[i]; + } else { + n_byte_blob += blob_data_size_[i]; + n_byte_raw += raw_data_size_[i]; + } + } + + // Not forget last one + rows_in_group.emplace_back(page_start_row, end_row); + return SUCCESS; +} + +MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { + auto blob_row = rows_in_group[0]; + if (blob_row.first == blob_row.second) return SUCCESS; + + // Write disk + auto page_id = last_blob_page->GetPageID(); + auto bytes_page = last_blob_page->GetPageSize(); + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); + + // Update last blob page + bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); + last_blob_page->SetPageSize(bytes_page); + uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; + last_blob_page->SetEndRowID(end_row); + (void)shard_header_->SetPage(last_blob_page); + return SUCCESS; +} + +MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { + auto page_id = shard_header_->GetLastPageId(shard_id); + auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; + auto current_row = last_blob_page ? 
last_blob_page->GetEndRowID() : 0; + // index(0) indicate appendBlobPage + for (uint32_t i = 1; i < rows_in_group.size(); ++i) { + auto blob_row = rows_in_group[i]; + + // Write 1 blob page to disk + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); + // Create new page info for header + auto page_size = + std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); + std::vector> row_group_ids; + auto start_row = current_row; + auto end_row = start_row + blob_row.second - blob_row.first; + auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size); + (void)shard_header_->AddPage(std::make_shared(page)); + current_row = end_row; + } + return SUCCESS; +} + +MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page) { + auto blob_row = rows_in_group[0]; + if (blob_row.first == blob_row.second) return SUCCESS; + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; + if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + + last_raw_page_size <= + page_size_) { + return SUCCESS; + } + auto page_id = shard_header_->GetLastPageId(shard_id); + auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; + auto last_raw_page_id = last_raw_page->GetPageID(); + auto shift_size = last_raw_page_size - last_row_group_id_offset; + + std::vector buf(shift_size); + + // Read last row group from previous raw data page + if (shard_id < 0 || shard_id >= file_streams_.size()) { + return FAILED; + } + + auto &io_seekg = file_streams_[shard_id]->seekg( + page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + MS_LOG(ERROR) << "File seekg failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf[0]), buf.size()); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + // Merge into new row group at new raw data page + auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&buf[0]), buf.size()); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + last_raw_page->DeleteLastGroupId(); + (void)shard_header_->SetPage(last_raw_page); + + // Refresh page info in header + int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; + std::vector> row_group_ids; + row_group_ids.emplace_back(row_group_id, 0); + int page_type_id = last_raw_page->GetPageID(); + auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); + (void)shard_header_->AddPage(std::make_shared(page)); + + // 
Reset: last raw page + SetLastRawPage(shard_id, last_raw_page); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page, + const std::vector> &bin_raw_data) { + int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; + for (uint32_t i = 0; i < rows_in_group.size(); ++i) { + const auto &blob_row = rows_in_group[i]; + if (blob_row.first == blob_row.second) continue; + auto raw_size = + std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); + if (!last_raw_page) { + EmptyRawPage(shard_id, last_raw_page); + } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { + (void)shard_header_->SetPage(last_raw_page); + EmptyRawPage(shard_id, last_raw_page); + } + if (AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data) != SUCCESS) { + return FAILED; + } + } + (void)shard_header_->SetPage(last_raw_page); + return SUCCESS; +} + +void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { + auto row_group_ids = std::vector>(); + auto page_id = shard_header_->GetLastPageId(shard_id); + auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; + auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); + (void)shard_header_->AddPage(std::make_shared(page)); + SetLastRawPage(shard_id, last_raw_page); +} + +MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, + const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, + const std::vector> &bin_raw_data) { + std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); + auto last_raw_page_id = last_raw_page->GetPageID(); + auto n_bytes = last_raw_page->GetPageSize(); + + // previous raw data page + auto &io_seekp = + file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg); + if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { + MS_LOG(ERROR) << "File seekp failed"; + file_streams_[shard_id]->close(); + return FAILED; + } + + if (chunk_id > 0) row_group_ids.emplace_back(++last_row_group_id, n_bytes); + n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first, + raw_data_size_.begin() + rows_in_group[chunk_id].second, 0); + (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); + + // Update previous raw data page + last_raw_page->SetPageSize(n_bytes); + last_raw_page->SetRowGroupIds(row_group_ids); + (void)shard_header_->SetPage(last_raw_page); + + return SUCCESS; +} + +MSRStatus ShardWriter::FlushBlobChunk(const std::shared_ptr &out, + const std::vector> &blob_data, + const std::pair &blob_row) { + if (blob_row.first > blob_row.second) { + return FAILED; + } + if (blob_row.second > static_cast(blob_data.size()) || blob_row.first < 0) { + return FAILED; + } + for (int j = blob_row.first; j < blob_row.second; ++j) { + // Write the size of blob + uint64_t line_len = blob_data[j].size(); + auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); + if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { + MS_LOG(ERROR) << "File write failed"; + out->close(); + return FAILED; + } + + // Write the data of blob + auto line = blob_data[j]; + auto &io_handle_data = out->write(reinterpret_cast(&line[0]), line_len); + if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) { + 
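+ // A failed write of the 8-byte length prefix aborts the whole chunk: the stream is closed and FAILED propagates to the caller.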
+
+MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr<std::fstream> &out,
+                                     const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
+                                     const std::vector<std::vector<uint8_t>> &bin_raw_data) {
+  for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) {
+    // Write the sizes of the multiple schemas
+    for (uint32_t j = 0; j < schema_count_; ++j) {
+      uint64_t line_len = bin_raw_data[i * schema_count_ + j].size();
+      auto &io_handle = out->write(reinterpret_cast<char *>(&line_len), kInt64Len);
+      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
+        MS_LOG(ERROR) << "File write failed";
+        out->close();
+        return FAILED;
+      }
+    }
+    // Write the data of the multiple schemas
+    for (uint32_t j = 0; j < schema_count_; ++j) {
+      auto line = bin_raw_data[i * schema_count_ + j];
+      auto &io_handle = out->write(reinterpret_cast<char *>(&line[0]), line.size());
+      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
+        MS_LOG(ERROR) << "File write failed";
+        out->close();
+        return FAILED;
+      }
+    }
+  }
+  return SUCCESS;
+}
+
+// Allocate rows to shards evenly
+std::vector<std::pair<int, int>> ShardWriter::BreakIntoShards() {
+  std::vector<std::pair<int, int>> shards;
+  int row_in_shard = row_count_ / shard_count_;
+  int remains = row_count_ % shard_count_;
+
+  std::vector<int> v_list(shard_count_);
+  std::iota(v_list.begin(), v_list.end(), 0);
+  std::random_device rd;
+  std::mt19937 g(rd());
+  std::shuffle(v_list.begin(), v_list.end(), g);
+  std::unordered_set<int> set(v_list.begin(), v_list.begin() + remains);
+
+  if (shard_count_ <= kMaxShardCount) {
+    int start_row = 0;
+    for (int i = 0; i < shard_count_; ++i) {
+      int end_row = start_row + row_in_shard;
+      if (set.count(i)) end_row++;
+      shards.emplace_back(start_row, end_row);
+      start_row = end_row;
+    }
+  }
+  return shards;
+}
+
+MSRStatus ShardWriter::WriteShardHeader() {
+  if (shard_header_ == nullptr) {
+    MS_LOG(ERROR) << "Shard header is null";
+    return FAILED;
+  }
+  auto shard_header = shard_header_->SerializeHeader();
+  // Write the header data to each of the files
+  if (shard_count_ > static_cast<int>(file_streams_.size()) ||
+      shard_count_ > static_cast<int>(shard_header.size())) {
+    return FAILED;
+  }
+  if (shard_count_ <= kMaxShardCount) {
+    for (int shard_id = 0; shard_id < shard_count_; ++shard_id) {
+      auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg);
+      if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) {
+        MS_LOG(ERROR) << "File seekp failed";
+        file_streams_[shard_id]->close();
+        return FAILED;
+      }
+
+      std::vector<uint8_t> bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end());
+      uint64_t line_len = bin_header.size();
+      if (line_len + kInt64Len > header_size_) {
+        MS_LOG(ERROR) << "Shard header is too big";
+        return FAILED;
+      }
+
+      auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast<char *>(&line_len), kInt64Len);
+      if (!io_handle.good() || io_handle.fail() || io_handle.bad()) {
+        MS_LOG(ERROR) << "File write failed";
+        file_streams_[shard_id]->close();
+        return FAILED;
+      }
+
+      auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast<char *>(&bin_header[0]), line_len);
+      if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) {
+        MS_LOG(ERROR) << "File write failed";
+        file_streams_[shard_id]->close();
+        return FAILED;
+      }
+      file_streams_[shard_id]->close();
+    }
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
+                                        std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count) {
+  // Determine the number of worker threads
+  uint32_t thread_num = std::thread::hardware_concurrency();
+  if (thread_num == 0) thread_num = kThreadNumber;
+  // Set the number of samples processed by each thread
+  int group_num = ceil(row_count * 1.0 / thread_num);
+  std::vector<std::thread> thread_set(thread_num);
+  int work_thread_num = 0;
+  for (uint32_t x = 0; x < thread_num; ++x) {
+    int start_num = x * group_num;
+    int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
+    if (start_num >= end_num) {
+      continue;
+    }
+    // Define the run boundary and start the child thread
+    thread_set[x] =
+      std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
+    work_thread_num++;
+  }
+  for (int x = 0; x < work_thread_num; ++x) {
+    // Block the main thread until the child threads finish
+    thread_set[x].join();
+  }
+  return flag_ ? FAILED : SUCCESS;
+}
+
+MSRStatus ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
+  raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
+  for (uint32_t i = 0; i < row_count_; ++i) {
+    raw_data_size_[i] = std::accumulate(
+      bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0,
+      [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
+  }
+  if (*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) > page_size_) {
+    MS_LOG(ERROR) << "Page size is too small to save a row!";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
+  blob_data_size_ = std::vector<uint64_t>(row_count_);
+  (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
+                       [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
+  if (*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) > page_size_) {
+    MS_LOG(ERROR) << "Page size is too small to save a row!";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+void ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
+  // Get the last raw page
+  auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
+  if (last_raw_page_id >= 0) {
+    auto page = shard_header_->GetPage(shard_id, last_raw_page_id);
+    last_raw_page = page.first;
+  }
+}
+
+void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
+  // Get the last blob page
+  auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
+  if (last_blob_page_id >= 0) {
+    auto page = shard_header_->GetPage(shard_id, last_blob_page_id);
+    last_blob_page = page.first;
+  }
+}
+}  // namespace mindrecord
+}  // namespace mindspore
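For orientation, the seek arithmetic used throughout ShardWriter above treats a shard file as a fixed-size header region followed by fixed-size pages, so a page's byte offset is header_size_ + page_size_ * page_id. A minimal standalone sketch of that arithmetic; the function name and the sample sizes are illustrative, not part of this patch:

#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the seekp/seekg targets computed by ShardWriter:
// byte offset of page `page_id` = header region + page_id fixed-size pages.
uint64_t PageOffset(uint64_t header_size, uint64_t page_size, uint64_t page_id) {
  return header_size + page_size * page_id;
}

int main() {
  // e.g. a 16 KB header and 32 MB pages: page 3 starts at 16384 + 3 * 33554432
  std::cout << PageOffset(16384, 33554432, 3) << std::endl;
  return 0;
}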
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc
new file mode 100644
index 0000000000..eb1428a2ad
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_category.h"
+
+namespace mindspore {
+namespace mindrecord {
+ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories, int64_t num_elements,
+                             bool replacement)
+    : categories_(categories),
+      category_field_(""),
+      num_elements_(num_elements),
+      num_categories_(0),
+      replacement_(replacement) {}
+
+ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories,
+                             bool replacement)
+    : categories_({}),
+      category_field_(category_field),
+      num_elements_(num_elements),
+      num_categories_(num_categories),
+      replacement_(replacement) {}
+
+MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; }
+
+int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (dataset_size == 0) return dataset_size;
+  if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
+    return std::min(num_categories_, num_classes) * num_elements_;
+  }
+  return 0;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
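A short note on the arithmetic in GetNumSamples above: the sampler draws num_elements_ per category, over at most min(num_categories_, num_classes) categories. A hedged sketch of that calculation with invented values:

#include <algorithm>
#include <cstdint>
#include <iostream>

// PK-style sampling: num_elements per category, capped by the classes actually present.
int64_t ExpectedSamples(int64_t num_categories, int64_t num_classes, int64_t num_elements) {
  return std::min(num_categories, num_classes) * num_elements;
}

int main() {
  std::cout << ExpectedSamples(10, 4, 2) << std::endl;  // prints 8: only 4 classes exist
  return 0;
}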
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
new file mode 100644
index 0000000000..4cc5e9f413
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc
@@ -0,0 +1,496 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_column.h"
+
+#include "common/utils.h"
+#include "minddata/mindrecord/include/common/shard_utils.h"
+#include "minddata/mindrecord/include/shard_error.h"
+
+namespace mindspore {
+namespace mindrecord {
+ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
+  auto first_schema = shard_header->GetSchemas()[0];
+  auto schema = first_schema->GetSchema()["schema"];
+
+  bool has_integer_array = false;
+  for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
+    const std::string &column_name = it.key();
+    column_name_.push_back(column_name);
+
+    json it_value = it.value();
+
+    std::string str_type = it_value["type"];
+    column_data_type_.push_back(ColumnDataTypeMap.at(str_type));
+    if (it_value.find("shape") != it_value.end()) {
+      std::vector<int64_t> vec(it_value["shape"].size());
+      std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin());
+      column_shape_.push_back(vec);
+      if (str_type == "int32" || str_type == "int64") {
+        has_integer_array = true;
+      }
+    } else {
+      std::vector<int64_t> vec = {};
+      column_shape_.push_back(vec);
+    }
+  }
+
+  for (uint64_t i = 0; i < column_name_.size(); i++) {
+    column_name_id_[column_name_[i]] = i;
+  }
+
+  auto blob_fields = first_schema->GetBlobFields();
+
+  for (const auto &field : blob_fields) {
+    blob_column_.push_back(field);
+  }
+
+  for (uint64_t i = 0; i < blob_column_.size(); i++) {
+    blob_column_id_[blob_column_[i]] = i;
+  }
+
+  has_compress_blob_ = (compress_integer && has_integer_array);
+  num_blob_column_ = blob_column_.size();
+}
+
+std::pair<MSRStatus, ColumnCategory> ShardColumn::GetColumnTypeByName(const std::string &column_name,
+                                                                      ColumnDataType *column_data_type,
+                                                                      uint64_t *column_data_type_size,
+                                                                      std::vector<int64_t> *column_shape) {
+  // Skip if the column is not found
+  auto column_category = CheckColumnName(column_name);
+  if (column_category == ColumnNotFound) {
+    return {FAILED, ColumnNotFound};
+  }
+
+  // Get the data type and size
+  auto column_id = column_name_id_[column_name];
+  *column_data_type = column_data_type_[column_id];
+  *column_data_type_size = ColumnDataTypeSize[*column_data_type];
+  *column_shape = column_shape_[column_id];
+
+  return {SUCCESS, column_category};
+}
+
+MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
+                                            const json &columns_json, const unsigned char **data,
+                                            std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes,
+                                            ColumnDataType *column_data_type, uint64_t *column_data_type_size,
+                                            std::vector<int64_t> *column_shape) {
+  // Skip if the column is not found
+  auto column_category = CheckColumnName(column_name);
+  if (column_category == ColumnNotFound) {
+    return FAILED;
+  }
+
+  // Get the data type and size
+  auto column_id = column_name_id_[column_name];
+  *column_data_type = column_data_type_[column_id];
+  *column_data_type_size = ColumnDataTypeSize[*column_data_type];
+  *column_shape = column_shape_[column_id];
+
+  // Retrieve the value from json
+  if (column_category == ColumnInRaw) {
+    if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) {
+      MS_LOG(ERROR) << "Failed to get data from json, column name is " << column_name << ".";
+      return FAILED;
+    }
+    *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
+    return SUCCESS;
+  }
+
+  // Retrieve the value from the blob
+  if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) {
+    MS_LOG(ERROR) << "Failed to get data from blob, column name is " << column_name << ".";
+    return FAILED;
+  }
+  if (*data == nullptr) {
+    *data = reinterpret_cast<const unsigned char *>(data_ptr->get());
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json,
+                                         std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) {
+  auto column_id = column_name_id_[column_name];
+  auto column_data_type = column_data_type_[column_id];
+
+  // Initialize the number of bytes
+  *n_bytes = ColumnDataTypeSize[column_data_type];
+  auto json_column_value = columns_json[column_name];
+  switch (column_data_type) {
+    case ColumnFloat32: {
+      return GetFloat<float>(data_ptr, json_column_value, false);
+    }
+    case ColumnFloat64: {
+      return GetFloat<double>(data_ptr, json_column_value, true);
+    }
+    case ColumnInt32: {
+      return GetInt<int32_t>(data_ptr, json_column_value);
+    }
+    case ColumnInt64: {
+      return GetInt<int64_t>(data_ptr, json_column_value);
+    }
+    default: {
+      // Convert the string to a C string
+      std::string tmp_string = json_column_value;
+      *n_bytes = tmp_string.size();
+      auto data = reinterpret_cast<const unsigned char *>(common::SafeCStr(tmp_string));
+      *data_ptr = std::make_unique<unsigned char[]>(*n_bytes);
+      for (uint32_t i = 0; i < *n_bytes; i++) {
+        (*data_ptr)[i] = *(data + i);
+      }
+      break;
+    }
+  }
+  return SUCCESS;
+}
+
+template <typename T>
+MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value,
+                                bool use_double) {
+  std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
+  if (!json_column_value.is_string() && !json_column_value.is_number()) {
+    MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
+    return FAILED;
+  }
+  if (json_column_value.is_number()) {
+    array_data[0] = json_column_value;
+  } else {
+    // Convert the string to float
+    try {
+      if (use_double) {
+        array_data[0] = json_column_value.get<double>();
+      } else {
+        array_data[0] = json_column_value.get<float>();
+      }
+    } catch (json::exception &e) {
+      MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ").";
+      return FAILED;
+    }
+  }
+
+  auto data = reinterpret_cast<unsigned char *>(array_data.get());
+  *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
+  for (uint32_t i = 0; i < sizeof(T); i++) {
+    (*data_ptr)[i] = *(data + i);
+  }
+
+  return SUCCESS;
+}
+
+template <typename T>
+MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) {
+  std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1);
+  int64_t temp_value;
+  bool less_than_zero = false;
+
+  if (json_column_value.is_number_integer()) {
+    const json json_zero = 0;
+    if (json_column_value < json_zero) less_than_zero = true;
+    temp_value = json_column_value;
+  } else if (json_column_value.is_string()) {
+    std::string string_value = json_column_value;
+
+    if (!string_value.empty() && string_value[0] == '-') {
+      try {
+        temp_value = std::stoll(string_value);
+        less_than_zero = true;
+      } catch (std::invalid_argument &e) {
+        MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
+        return FAILED;
+      } catch (std::out_of_range &e) {
+        MS_LOG(ERROR) << "Conversion to int failed, out of range.";
+        return FAILED;
+      }
+    } else {
+      try {
+        temp_value = static_cast<int64_t>(std::stoull(string_value));
+      } catch (std::invalid_argument &e) {
+        MS_LOG(ERROR) << "Conversion to int failed, invalid argument.";
+        return FAILED;
+      } catch (std::out_of_range &e) {
+        MS_LOG(ERROR) << "Conversion to int failed, out of range.";
+        return FAILED;
+      }
+    }
+  } else {
+    MS_LOG(ERROR) << "Conversion to int failed.";
+    return FAILED;
+  }
+
+  if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) ||
+      (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) {
+    MS_LOG(ERROR) << "Conversion to int failed, out of range.";
+    return FAILED;
+  }
+  array_data[0] = static_cast<T>(temp_value);
+
+  auto data = reinterpret_cast<unsigned char *>(array_data.get());
+  *data_ptr = std::make_unique<unsigned char[]>(sizeof(T));
+  for (uint32_t i = 0; i < sizeof(T); i++) {
+    (*data_ptr)[i] = *(data + i);
+  }
+
+  return SUCCESS;
+}
+
+MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
+                                         const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
+                                         uint64_t *const n_bytes) {
+  uint64_t offset_address = 0;
+  auto column_id = column_name_id_[column_name];
+  if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) {
+    return FAILED;
+  }
+
+  auto column_data_type = column_data_type_[column_id];
+  if (has_compress_blob_ && column_data_type == ColumnInt32) {
+    if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
+      return FAILED;
+    }
+  } else if (has_compress_blob_ && column_data_type == ColumnInt64) {
+    if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) {
+      return FAILED;
+    }
+  } else {
+    *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address]));
+  }
+
+  return SUCCESS;
+}
+
+ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
+  auto it_column = column_name_id_.find(column_name);
+  if (it_column == column_name_id_.end()) {
+    return ColumnNotFound;
+  }
+  auto it_blob = blob_column_id_.find(column_name);
+  return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
+}
+
+std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
+  // Skip if there are no compressible columns
+  if (!CheckCompressBlob()) return blob;
+
+  std::vector<uint8_t> dst_blob;
+  uint64_t i_src = 0;
+  for (int64_t i = 0; i < num_blob_column_; i++) {
+    // Get the column data type
+    auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]];
+    auto int_type = src_data_type == ColumnInt32 ? kInt32Type : kInt64Type;
+
+    // Compress and return immediately if the blob has only one column
+    if (num_blob_column_ == 1) {
+      return CompressInt(blob, int_type);
+    }
+
+    // Just copy and continue if the column data type is not int32/int64
+    uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
+    if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
+      dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
+      i_src += kInt64Len + num_bytes;
+      continue;
+    }
+
+    // Get the column slice in the source blob
+    std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
+    // Compress the column
+    auto dst_blob_slice = CompressInt(blob_slice, int_type);
+    // Get the new column size
+    auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
+    // Append the new column size
+    dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
+    // Append the new column data
+    dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
+    i_src += kInt64Len + num_bytes;
+  }
+  MS_LOG(DEBUG) << "Compressed all blobs from " << blob.size() << " to " << dst_blob.size() << " bytes.";
+  return dst_blob;
+}
+
+std::vector<uint8_t> ShardColumn::CompressInt(const std::vector<uint8_t> &src_bytes, const IntegerType &int_type) {
+  uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
+  // Get the number of elements
+  uint64_t src_n_int = src_bytes.size() / i_size;
+  // Calculate the bitmap size (in bytes)
+  uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;
+
+  // Initialize the destination blob with more space than needed; it will be resized
+  std::vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);
+
+  // Write the number of elements to the destination blob
+  std::vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
+  for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
+    dst_bytes[n] = size_by_bytes[n];
+  }
+
+  // Write the compressed integers
+  uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
+  for (uint64_t i = 0; i < src_n_int; i++) {
+    // Initialize the destination data type
+    IntegerType dst_int_type = kInt8Type;
+    // Shift to the next int position
+    uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
+    // Narrow down this int
+    int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);
+
+    // Write this int to the destination blob
+    uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
+    auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
+    for (uint64_t j = 0; j < (kUnsignedOne << static_cast<uint8_t>(dst_int_type)); j++) {
+      dst_bytes[i_dst++] = temp_bytes[j];
+    }
+
+    // Update the data type in the bitmap
+    dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
+      (static_cast<uint8_t>(dst_int_type) << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
+  }
+  // Resize the destination blob
+  dst_bytes.resize(i_dst);
+  MS_LOG(DEBUG) << "Compressed blob field from " << src_bytes.size() << " to " << dst_bytes.size() << " bytes.";
+  return dst_bytes;
+}
+
+MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
+                                               uint64_t *num_bytes, uint64_t *shift_idx) {
+  if (num_blob_column_ == 1) {
+    *num_bytes = columns_blob.size();
+    *shift_idx = 0;
+    return SUCCESS;
+  }
+  auto blob_id = blob_column_id_[column_name_[column_id]];
+
+  for (int32_t i = 0; i < blob_id; i++) {
+    *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
+  }
+  *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
+
+  (*shift_idx) += kInt64Len;
+
+  return SUCCESS;
+}
+
+template <typename T>
+MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr,
+                                     const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes,
+                                     uint64_t shift_idx) {
+  auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type);
+  *num_bytes = sizeof(T) * num_elements;
+
+  // Parse the integer array
+  uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte;
+  auto array_data = std::make_unique<T[]>(num_elements);
+
+  for (uint64_t i = 0; i < num_elements; i++) {
+    uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte];
+    uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask;
+    auto mr_int_type = static_cast<IntegerType>(i_type);
+    int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type);
+    i_source += (kUnsignedOne << i_type);
+    array_data[i] = static_cast<T>(i64);
+  }
+
+  auto data = reinterpret_cast<unsigned char *>(array_data.get());
+  *data_ptr = std::make_unique<unsigned char[]>(*num_bytes);
+  int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes);
+  if (ret_code != 0) {
+    MS_LOG(ERROR) << "Failed to copy data!";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
+                                       const IntegerType &i_type) {
+  uint64_t result = 0;
+  for (uint64_t i = 0; i < (kUnsignedOne << static_cast<uint8_t>(i_type)); i++) {
+    result = (result << kBitsOfByte) + bytes_array[pos + i];
+  }
+  return result;
+}
+
+std::vector<uint8_t> ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) {
+  uint64_t n_bytes = kUnsignedOne << static_cast<uint8_t>(i_type);
+  std::vector<uint8_t> result(n_bytes, 0);
+  for (uint64_t i = 0; i < n_bytes; i++) {
+    result[n_bytes - 1 - i] = value & std::numeric_limits<uint8_t>::max();
+    value >>= kBitsOfByte;
+  }
+  return result;
+}
+
+std::vector<uint8_t> ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) {
+  uint64_t n_bytes = kUnsignedOne << static_cast<uint8_t>(i_type);
+  std::vector<uint8_t> result(n_bytes, 0);
+  for (uint64_t i = 0; i < n_bytes; i++) {
+    result[i] = value & std::numeric_limits<uint8_t>::max();
+    value >>= kBitsOfByte;
+  }
+  return result;
+}
+
+int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
+                                             const IntegerType &src_i_type, IntegerType *dst_i_type) {
+  uint64_t u_temp = 0;
+  for (uint64_t i = 0; i < (kUnsignedOne << static_cast<uint8_t>(src_i_type)); i++) {
+    u_temp = (u_temp << kBitsOfByte) +
+             bytes_array[pos + (kUnsignedOne << static_cast<uint8_t>(src_i_type)) - kUnsignedOne - i];
+  }
+
+  int64_t i_out;
+  switch (src_i_type) {
+    case kInt8Type: {
+      i_out = (int8_t)(u_temp & std::numeric_limits<uint8_t>::max());
+      break;
+    }
+    case kInt16Type: {
+      i_out = (int16_t)(u_temp & std::numeric_limits<uint16_t>::max());
+      break;
+    }
+    case kInt32Type: {
+      i_out = (int32_t)(u_temp & std::numeric_limits<uint32_t>::max());
+      break;
+    }
+    case kInt64Type: {
+      i_out = (int64_t)(u_temp & std::numeric_limits<uint64_t>::max());
+      break;
+    }
+    default: {
+      i_out = 0;
+    }
+  }
+
+  if (!dst_i_type) {
+    return i_out;
+  }
+
+  if (i_out >= static_cast<int64_t>(std::numeric_limits<int8_t>::min()) &&
+      i_out <= static_cast<int64_t>(std::numeric_limits<int8_t>::max())) {
+    *dst_i_type = kInt8Type;
+  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int16_t>::min()) &&
+             i_out <= static_cast<int64_t>(std::numeric_limits<int16_t>::max())) {
+    *dst_i_type = kInt16Type;
+  } else if (i_out >= static_cast<int64_t>(std::numeric_limits<int32_t>::min()) &&
+             i_out <= static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
+    *dst_i_type = kInt32Type;
+  } else {
+    *dst_i_type = kInt64Type;
+  }
+  return i_out;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
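For reference, the compressed integer-column layout that CompressInt above produces is: a big-endian element count (assuming kBytesOfColumnLen is 4), then a bitmap with a 2-bit width code per element (assuming kDataTypeBits is 2 and kNumDataOfByte is 4, i.e. four codes per byte, high bits first), then the little-endian narrowed values, where code k means a (1 << k)-byte value. A hedged standalone sketch of decoding the first element under those assumptions; it is an illustration, not the library's API:

#include <cstdint>
#include <vector>

// Decode the first value of a compressed int column (assumed layout, see above).
int64_t DecodeFirstValue(const std::vector<uint8_t> &buf) {
  const uint64_t kCountLen = 4;  // assumed kBytesOfColumnLen
  uint64_t count = (uint64_t(buf[0]) << 24) | (uint64_t(buf[1]) << 16) | (uint64_t(buf[2]) << 8) | buf[3];
  uint64_t bitmap_len = (count + 3) / 4;            // 4 two-bit codes per byte
  uint8_t type_code = (buf[kCountLen] >> 6) & 0x3;  // element 0 sits in the top bits
  uint64_t width = 1ULL << type_code;               // narrowed width in bytes
  uint64_t pos = kCountLen + bitmap_len;            // values start after count + bitmap
  uint64_t u = 0;
  for (uint64_t i = 0; i < width; ++i) {
    u |= uint64_t(buf[pos + i]) << (8 * i);  // little-endian value bytes
  }
  // Sign-extend from the narrowed width back to 64 bits
  uint64_t shift = 64 - 8 * width;
  return static_cast<int64_t>(u << shift) >> shift;
}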
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
new file mode 100644
index 0000000000..4c7abbb4b4
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
@@ -0,0 +1,79 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_distributed_sample.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
+                                               uint32_t seed)
+    : ShardSample(1, num_shards, shard_id),
+      shuffle_(shuffle),
+      no_of_padded_samples_(no_of_padded_samples),
+      first_epoch_(true) {
+  shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
+}
+
+ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
+    : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
+
+int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (no_of_padded_samples_ <= 0) {
+    if (dataset_size % denominator_ == 0) {
+      return dataset_size / denominator_ * numerator_;
+    } else {
+      return dataset_size / denominator_ * numerator_ + 1;
+    }
+  } else {
+    auto padded_size = dataset_size + no_of_padded_samples_;
+    if (padded_size % denominator_ == 0) {
+      return padded_size / denominator_ * numerator_;
+    } else {
+      return -1;
+    }
+  }
+}
+
+MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
+  auto total_no = tasks.Size();
+  if (no_of_padded_samples_ > 0 && first_epoch_) {
+    if (total_no % denominator_ != 0) {
+      MS_LOG(ERROR) << "Dataset size plus the number of padded samples is not divisible by the number of shards. "
+                    << "task size: " << total_no << ", number padded: " << no_of_padded_samples_
+                    << ", denominator: " << denominator_;
+      return FAILED;
+    }
+  }
+  if (first_epoch_) {
+    first_epoch_ = false;
+    task_ = tasks;
+  } else {
+    tasks = task_;
+  }
+  if (shuffle_) {
+    if (SUCCESS != (*shuffle_op_)(tasks)) {
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
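The division in GetNumSamples above is worth making explicit: with numerator_ = 1 and denominator_ = num_shards, as set by the ShardSample(1, num_shards, shard_id) base constructor, the unpadded branch is equivalent to ceiling division, so every shard reports the same length and some rows are read twice. A hedged sketch with invented values:

#include <cstdint>
#include <iostream>

// Per-shard sample count without padding: ceiling division keeps shards equal-sized.
int64_t PerShardSamples(int64_t dataset_size, int64_t num_shards) {
  return (dataset_size + num_shards - 1) / num_shards;
}

int main() {
  std::cout << PerShardSamples(10, 4) << std::endl;  // prints 3; two rows get re-read
  return 0;
}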
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
new file mode 100644
index 0000000000..500037399b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc
@@ -0,0 +1,725 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_header.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "common/utils.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "minddata/mindrecord/include/shard_page.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+std::atomic<bool> thread_status(false);
+ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
+
+MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
+  shard_count_ = headers.size();
+  int shard_index = 0;
+  bool first = true;
+  for (const auto &header : headers) {
+    if (first) {
+      first = false;
+      if (ParseSchema(header["schema"]) != SUCCESS) {
+        return FAILED;
+      }
+      if (ParseIndexFields(header["index_fields"]) != SUCCESS) {
+        return FAILED;
+      }
+      if (ParseStatistics(header["statistics"]) != SUCCESS) {
+        return FAILED;
+      }
+      ParseShardAddress(header["shard_addresses"]);
+      header_size_ = header["header_size"].get<uint64_t>();
+      page_size_ = header["page_size"].get<uint64_t>();
+    }
+    ParsePage(header["page"], shard_index, load_dataset);
+    shard_index++;
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::CheckFileStatus(const std::string &path) {
+  std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary);
+  if (!fin) {
+    MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path;
+    return FAILED;
+  }
+  if (fin.fail()) {
+    MS_LOG(ERROR) << "Failed to open file. path: " << path;
+    return FAILED;
+  }
path: " << path; + return FAILED; + } + + // fetch file size + auto &io_seekg = fin.seekg(0, std::ios::end); + if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { + fin.close(); + MS_LOG(ERROR) << "File seekg failed"; + return FAILED; + } + + size_t file_size = fin.tellg(); + if (file_size < kMinFileSize) { + fin.close(); + MS_LOG(ERROR) << "File size %d is smaller than the minimum value."; + return FAILED; + } + fin.close(); + return SUCCESS; +} + +std::pair ShardHeader::ValidateHeader(const std::string &path) { + if (CheckFileStatus(path) != SUCCESS) { + return {FAILED, {}}; + } + + // read header size + json json_header; + std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); + if (!fin.is_open()) { + MS_LOG(ERROR) << "File seekg failed"; + return {FAILED, json_header}; + } + + uint64_t header_size = 0; + auto &io_read = fin.read(reinterpret_cast(&header_size), kInt64Len); + if (!io_read.good() || io_read.fail() || io_read.bad()) { + MS_LOG(ERROR) << "File read failed"; + fin.close(); + return {FAILED, json_header}; + } + + if (header_size > kMaxHeaderSize) { + fin.close(); + MS_LOG(ERROR) << "Header size is illegal."; + return {FAILED, json_header}; + } + + // read header content + std::vector header_content(header_size); + auto &io_read_content = fin.read(reinterpret_cast(&header_content[0]), header_size); + if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { + MS_LOG(ERROR) << "File read failed"; + fin.close(); + return {FAILED, json_header}; + } + + fin.close(); + std::string raw_header_content = std::string(header_content.begin(), header_content.end()); + // parse json content + try { + json_header = json::parse(raw_header_content); + } catch (json::parse_error &e) { + MS_LOG(ERROR) << "Json parse error: " << e.what(); + return {FAILED, json_header}; + } + return {SUCCESS, json_header}; +} + +std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { + auto ret = ValidateHeader(file_path); + if (SUCCESS != ret.first) { + return {FAILED, json()}; + } + json raw_header = ret.second; + json header = {{"shard_addresses", raw_header["shard_addresses"]}, + {"header_size", raw_header["header_size"]}, + {"page_size", raw_header["page_size"]}, + {"index_fields", raw_header["index_fields"]}, + {"blob_fields", raw_header["schema"][0]["blob_fields"]}, + {"schema", raw_header["schema"][0]["schema"]}, + {"version", raw_header["version"]}}; + return {SUCCESS, header}; +} + +MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { + uint32_t thread_num = std::thread::hardware_concurrency(); + if (thread_num == 0) thread_num = kThreadNumber; + uint32_t work_thread_num = 0; + uint32_t shard_count = file_paths.size(); + int group_num = ceil(shard_count * 1.0 / thread_num); + std::vector thread_set(thread_num); + std::vector headers(shard_count); + for (uint32_t x = 0; x < thread_num; ++x) { + int start_num = x * group_num; + int end_num = ((x + 1) * group_num > shard_count) ? 
+
+  for (uint32_t x = 0; x < work_thread_num; ++x) {
+    thread_set[x].join();
+  }
+  if (thread_status) {
+    thread_status = false;
+    return FAILED;
+  }
+  if (SUCCESS != InitializeHeader(headers, load_dataset)) {
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers,
+                                    const std::vector<std::string> &realAddresses) {
+  if (thread_status || end > static_cast<int>(realAddresses.size())) {
+    return;
+  }
+  for (int x = start; x < end; ++x) {
+    auto ret = ValidateHeader(realAddresses[x]);
+    if (SUCCESS != ret.first) {
+      thread_status = true;
+      return;
+    }
+    json header;
+    header = ret.second;
+    header["shard_addresses"] = realAddresses;
+    if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
+      MS_LOG(ERROR) << "Version mismatch, file version is: " << header["version"].dump()
+                    << ", lib version is: " << kVersion;
+      thread_status = true;
+      return;
+    }
+    headers[x] = header;
+  }
+}
+
+MSRStatus ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) {
+  std::vector<std::string> file_names(file_paths.size());
+  std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](const std::string &fp) -> std::string {
+    auto ret = GetFileName(fp);
+    return ret.first == SUCCESS ? ret.second : std::string();
+  });
+
+  shard_addresses_ = std::move(file_names);
+  shard_count_ = file_paths.size();
+  if (shard_count_ == 0) {
+    return FAILED;
+  }
+  if (shard_count_ <= kMaxShardCount) {
+    pages_.resize(shard_count_);
+  } else {
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+void ShardHeader::ParseHeader(const json &header) {}
+
+MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) {
+  std::vector<std::pair<uint64_t, std::string>> parsed_index_fields;
+  for (auto &index_field : index_fields) {
+    auto schema_id = index_field["schema_id"].get<uint64_t>();
+    std::string field_name = index_field["index_field"].get<std::string>();
+    std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name);
+    parsed_index_fields.push_back(parsed_index_field);
+  }
+  if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) {
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) {
+  // Use shard_index as the slot when load_dataset is false
+  if (pages_.empty() && shard_count_ <= kMaxShardCount) {
+    pages_.resize(shard_count_);
+  }
+  for (auto &page : pages) {
+    int page_id = page["page_id"];
+    int shard_id = page["shard_id"];
+    std::string page_type = page["page_type"];
+    int page_type_id = page["page_type_id"];
+    auto start_row_id = page["start_row_id"].get<uint64_t>();
+    auto end_row_id = page["end_row_id"].get<uint64_t>();
+
+    std::vector<std::pair<int, uint64_t>> row_group_ids(page["row_group_ids"].size());
+    std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(),
+                   [](json rg) { return std::make_pair(rg["id"], rg["offset"].get<uint64_t>()); });
+
+    auto page_size = page["page_size"].get<uint64_t>();
+
+    std::shared_ptr<Page> parsed_page = std::make_shared<Page>(page_id, shard_id, page_type, page_type_id,
+                                                               start_row_id, end_row_id, row_group_ids, page_size);
+    if (load_dataset == true) {
+      pages_[shard_id].push_back(std::move(parsed_page));
+    } else {
+      pages_[shard_index].push_back(std::move(parsed_page));
+    }
+  }
+}
+
+MSRStatus ShardHeader::ParseStatistics(const json &statistics) {
+  for (auto &statistic : statistics) {
+    if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) {
+      MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump();
+      return FAILED;
+    }
+    std::string statistic_description = statistic["desc"].get<std::string>();
+    json statistic_body = statistic["statistics"];
+    std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body);
+    if (!parsed_statistic) {
+      return FAILED;
+    }
+    AddStatistic(parsed_statistic);
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::ParseSchema(const json &schemas) {
+  for (auto &schema : schemas) {
+    // Change how we get the schema body once the design is finalized
+    if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() ||
+        schema.find("schema") == schema.end()) {
+      MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump();
+      return FAILED;
+    }
+    std::string schema_description = schema["desc"].get<std::string>();
+    std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>();
+    json schema_body = schema["schema"];
+    std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body);
+    if (!parsed_schema) {
+      return FAILED;
+    }
+    AddSchema(parsed_schema);
+  }
+  return SUCCESS;
+}
+
+void ShardHeader::ParseShardAddress(const json &address) {
+  std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_));
+}
+
+std::vector<std::string> ShardHeader::SerializeHeader() {
+  std::vector<std::string> header;
+  auto index = SerializeIndexFields();
+  auto stats = SerializeStatistics();
+  auto schema = SerializeSchema();
+  auto pages = SerializePage();
+  auto address = SerializeShardAddress();
+  if (shard_count_ > static_cast<int>(pages.size())) {
+    return std::vector<std::string>{};
+  }
+  if (shard_count_ <= kMaxShardCount) {
+    for (int shardId = 0; shardId < shard_count_; shardId++) {
+      std::string s;
+      s += "{\"header_size\":" + std::to_string(header_size_) + ",";
+      s += "\"index_fields\":" + index + ",";
+      s += "\"page\":" + pages[shardId] + ",";
+      s += "\"page_size\":" + std::to_string(page_size_) + ",";
+      s += "\"schema\":" + schema + ",";
+      s += "\"shard_addresses\":" + address + ",";
+      s += "\"shard_id\":" + std::to_string(shardId) + ",";
+      s += "\"statistics\":" + stats + ",";
+      s += "\"version\":\"" + std::string(kVersion) + "\"";
+      s += "}";
+      header.emplace_back(s);
+    }
+  }
+  return header;
+}
+
+std::string ShardHeader::SerializeIndexFields() {
+  json j;
+  auto fields = index_->GetFields();
+  for (const auto &field : fields) {
+    j.push_back({{"schema_id", field.first}, {"index_field", field.second}});
+  }
+  return j.dump();
+}
+
+std::vector<std::string> ShardHeader::SerializePage() {
+  std::vector<std::string> pages;
+  for (auto &shard_pages : pages_) {
+    json j;
+    for (const auto &p : shard_pages) {
+      j.emplace_back(p->GetPage());
+    }
+    pages.emplace_back(j.dump());
+  }
+  return pages;
+}
+
+std::string ShardHeader::SerializeStatistics() {
+  json j;
+  for (const auto &stats : statistics_) {
+    j.emplace_back(stats->GetStatistics());
+  }
+  return j.dump();
+}
+
+std::string ShardHeader::SerializeSchema() {
+  json j;
+  for (const auto &schema : schema_) {
+    j.emplace_back(schema->GetSchema());
+  }
+  return j.dump();
+}
+
+std::string ShardHeader::SerializeShardAddress() {
+  json j;
+  for (const auto &addr : shard_addresses_) {
+    j.emplace_back(GetFileName(addr).second);
+  }
+  return j.dump();
+}
+
+std::pair<std::shared_ptr<Page>, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) {
+  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
+    return std::make_pair(pages_[shard_id][page_id], SUCCESS);
+  } else {
+    return std::make_pair(nullptr, FAILED);
+  }
+}
+
+MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) {
+  if (new_page == nullptr) {
+    return FAILED;
+  }
+  int shard_id = new_page->GetShardID();
+  int page_id = new_page->GetPageID();
+  if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) {
+    pages_[shard_id][page_id] = new_page;
+    return SUCCESS;
+  } else {
+    return FAILED;
+  }
+}
+
+MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) {
+  if (new_page == nullptr) {
+    return FAILED;
+  }
+  int shard_id = new_page->GetShardID();
+  int page_id = new_page->GetPageID();
+  if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) {
+    pages_[shard_id].push_back(new_page);
+    return SUCCESS;
+  } else {
+    return FAILED;
+  }
+}
+
+int64_t ShardHeader::GetLastPageId(const int &shard_id) {
+  if (shard_id >= static_cast<int>(pages_.size())) {
+    return 0;
+  }
+  return pages_[shard_id].size() - 1;
+}
+
+int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) {
+  if (shard_id >= static_cast<int>(pages_.size())) {
+    return 0;
+  }
+  int last_page_id = -1;
+  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
+    if (pages_[shard_id][i - 1]->GetPageType() == page_type) {
+      last_page_id = pages_[shard_id][i - 1]->GetPageID();
+      return last_page_id;
+    }
+  }
+  return last_page_id;
+}
+
+const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(const int &group_id,
+                                                                                const int &shard_id) {
+  if (shard_id >= static_cast<int>(pages_.size())) {
+    MS_LOG(ERROR) << "Shard id is greater than the number of shards.";
+    return {FAILED, nullptr};
+  }
+  for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) {
+    auto page = pages_[shard_id][i - 1];
+    if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) {
+      return {SUCCESS, page};
+    }
+  }
+  MS_LOG(ERROR) << "Could not get page by group id " << group_id;
+  return {FAILED, nullptr};
+}
+
+int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) {
+  if (schema == nullptr) {
+    MS_LOG(ERROR) << "Schema is illegal";
+    return -1;
+  }
+
+  if (!schema_.empty()) {
+    MS_LOG(ERROR) << "Only one schema is supported";
+    return -1;
+  }
+
+  int64_t schema_id = schema->GetSchemaID();
+  if (schema_id == -1) {
+    schema_id = schema_.size();
+    schema->SetSchemaID(schema_id);
+  }
+  schema_.push_back(schema);
+  return schema_id;
+}
+
+void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) {
+  if (statistic) {
+    int64_t statistics_id = statistic->GetStatisticsID();
+    if (statistics_id == -1) {
+      statistics_id = statistics_.size();
+      statistic->SetStatisticsID(statistics_id);
+    }
+    statistics_.push_back(statistic);
+  }
+}
+
+std::shared_ptr<Index> ShardHeader::InitIndexPtr() {
+  std::shared_ptr<Index> index = index_;
+  if (!index_) {
+    index = std::make_shared<Index>();
+    index_ = index;
+  }
+  return index;
+}
+
+MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) {
+  // Check whether the field name is valid
+  if (schema.find(field) == schema.end()) {
+    MS_LOG(ERROR) << "Schema does not contain the field: " << field << ".";
+    return FAILED;
+  }
+
+  if (schema[field]["type"] == "bytes") {
+    MS_LOG(ERROR) << field << " is of bytes type and can not be a schema index field.";
+    return FAILED;
+  }
+
+  if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) {
+    MS_LOG(ERROR) << field << " is an array and can not be a schema index field.";
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) {
+  // Create the index object
+  std::shared_ptr<Index> index = InitIndexPtr();
+
+  if (fields.size() == kInt0) {
+    MS_LOG(ERROR) << "There are no index fields";
+    return FAILED;
+  }
+
+  if (GetSchemas().empty()) {
+    MS_LOG(ERROR) << "No schema is set";
+    return FAILED;
+  }
+
+  for (const auto &schemaPtr : schema_) {
+    auto result = GetSchemaByID(schemaPtr->GetSchemaID());
+    if (result.second != SUCCESS) {
+      MS_LOG(ERROR) << "Could not get schema by id.";
+      return FAILED;
+    }
+
+    if (result.first == nullptr) {
+      MS_LOG(ERROR) << "Could not get schema by id.";
+      return FAILED;
+    }
+
+    json schema = result.first->GetSchema().at("schema");
+
+    // Check and add the fields for each schema
+    std::set<std::string> field_set;
+    for (const auto &item : index->GetFields()) {
+      field_set.insert(item.second);
+    }
+    for (const auto &field : fields) {
+      if (field_set.find(field) != field_set.end()) {
+        MS_LOG(ERROR) << "The same index field can not be added twice";
+        return FAILED;
+      }
+
+      // Check whether the field name is valid
+      if (CheckIndexField(field, schema) == FAILED) {
+        return FAILED;
+      }
+      field_set.insert(field);
+
+      // Add the field to the index
+      index.get()->AddIndexField(schemaPtr->GetSchemaID(), field);
+    }
+  }
+
+  index_ = index;
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) {
+  // Get all schema ids
+  for (const auto &schema : schema_) {
+    auto bucket_it = bucket_count.find(schema->GetSchemaID());
+    if (bucket_it != bucket_count.end()) {
+      MS_LOG(ERROR) << "Schema duplication";
+      return FAILED;
+    } else {
+      bucket_count.insert(schema->GetSchemaID());
+    }
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) {
+  // Create the index object
+  std::shared_ptr<Index> index = InitIndexPtr();
+
+  if (fields.size() == kInt0) {
+    MS_LOG(ERROR) << "There are no index fields";
+    return FAILED;
+  }
+
+  // Get all schema ids
+  std::set<uint64_t> bucket_count;
+  if (GetAllSchemaID(bucket_count) != SUCCESS) {
+    return FAILED;
+  }
+
+  // Check and add the fields for each schema
+  std::set<std::pair<uint64_t, std::string>> field_set;
+  for (const auto &item : index->GetFields()) {
+    field_set.insert(item);
+  }
+  for (const auto &field : fields) {
+    if (field_set.find(field) != field_set.end()) {
+      MS_LOG(ERROR) << "The same index field can not be added twice";
+      return FAILED;
+    }
+
+    uint64_t schema_id = field.first;
+    std::string field_name = field.second;
+
+    // Check whether the schema id is valid
+    if (bucket_count.find(schema_id) == bucket_count.end()) {
+      MS_LOG(ERROR) << "Illegal schema id: " << schema_id;
+      return FAILED;
+    }
+
+    // Check whether the field name is valid
+    auto result = GetSchemaByID(schema_id);
+    if (result.second != SUCCESS) {
+      MS_LOG(ERROR) << "Could not get schema by id.";
+      return FAILED;
+    }
+    json schema = result.first->GetSchema().at("schema");
+    if (schema.find(field_name) == schema.end()) {
+      MS_LOG(ERROR) << "Schema " << schema_id << " does not contain the field: " << field_name;
+      return FAILED;
+    }
+
+    if (CheckIndexField(field_name, schema) == FAILED) {
+      return FAILED;
+    }
+
+    field_set.insert(field);
+
+    // Add the field to the index
+    index.get()->AddIndexField(schema_id, field_name);
+  }
+  index_ = index;
+  return SUCCESS;
+}
+
+std::string ShardHeader::GetShardAddressByID(int64_t shard_id) {
+  if (shard_id >= static_cast<int64_t>(shard_addresses_.size())) {
+    return "";
+  }
+  return shard_addresses_.at(shard_id);
+}
+
+std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; }
+
+std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; }
+
+std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); }
+
+std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; }
+
+std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) {
+  int64_t schemaSize = schema_.size();
+  if (schema_id < 0 || schema_id >= schemaSize) {
+    MS_LOG(ERROR) << "Illegal schema id";
+    return std::make_pair(nullptr, FAILED);
+  }
+  return std::make_pair(schema_.at(schema_id), SUCCESS);
+}
+
+std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) {
+  int64_t statistics_size = statistics_.size();
+  if (statistic_id < 0 || statistic_id >= statistics_size) {
+    return std::make_pair(nullptr, FAILED);
+  }
+  return std::make_pair(statistics_.at(statistic_id), SUCCESS);
+}
+
+MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) {
+  // Write the page content to the file, truncating whatever was in it before
+  std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out);
+  if (page_out_handle.fail()) {
+    MS_LOG(ERROR) << "Failed to open the page file";
+    return FAILED;
+  }
+
+  auto pages = SerializePage();
+  for (const auto &shard_pages : pages) {
+    page_out_handle << shard_pages << "\n";
+  }
+
+  page_out_handle.close();
+  return SUCCESS;
+}
+
+MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) {
+  for (auto &v : pages_) {  // clean pages
+    v.clear();
+  }
+  // Attempt to open the file that contains the pages as json
+  std::ifstream page_in_handle(dump_file_name.c_str());
+
+  if (!page_in_handle.good()) {
+    MS_LOG(INFO) << "No page file exists.";
+    return SUCCESS;
+  }
+
+  std::string line;
+  while (std::getline(page_in_handle, line)) {
+    ParsePage(json::parse(line), -1, true);
+  }
+
+  page_in_handle.close();
+  return SUCCESS;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
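For reference, ValidateHeader above reads the on-disk layout that ShardWriter::WriteShardHeader produces: an 8-byte length (kInt64Len is assumed to be 8), followed by that many bytes of JSON. A hedged standalone reader sketch, with error handling elided; this is an illustration, not the library's API:

#include <cstdint>
#include <fstream>
#include <string>

// Reads the serialized shard header: 8 bytes of length, then the JSON text.
std::string ReadShardHeaderJson(const std::string &path) {
  std::ifstream fin(path, std::ios::in | std::ios::binary);
  uint64_t header_size = 0;
  fin.read(reinterpret_cast<char *>(&header_size), sizeof(header_size));
  std::string content(header_size, '\0');
  fin.read(&content[0], header_size);
  return content;  // parse with json::parse(content), as ValidateHeader does
}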
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc
new file mode 100644
index 0000000000..73397b5bba
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_index.cc
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_index.h"
+
+namespace mindspore {
+namespace mindrecord {
+// Table name for the index
+const char TABLENAME[] = "index_table";
+
+Index::Index() : database_name_(""), table_name_(TABLENAME) {}
+
+void Index::AddIndexField(const int64_t &schemaId, const std::string &field) {
+  fields_.emplace_back(std::pair<int64_t, std::string>(schemaId, field));
+}
+
+// Get the attribute list
+std::vector<std::pair<int64_t, std::string>> Index::GetFields() { return fields_; }
+}  // namespace mindrecord
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc
new file mode 100644
index 0000000000..ba2292415f
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_page.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_page.h"
+#include "pybind11/pybind11.h"
+
+namespace mindspore {
+namespace mindrecord {
+json Page::GetPage() const {
+  json str_page;
+  str_page["page_id"] = page_id_;
+  str_page["shard_id"] = shard_id_;
+  str_page["page_type"] = page_type_;
+  str_page["page_type_id"] = page_type_id_;
+  str_page["start_row_id"] = start_row_id_;
+  str_page["end_row_id"] = end_row_id_;
+  if (row_group_ids_.size() == 0) {
+    json row_groups = json({});
+    row_groups["id"] = 0;
+    row_groups["offset"] = 0;
+    str_page["row_group_ids"].push_back(row_groups);
+  } else {
+    for (const auto &rg : row_group_ids_) {
+      json row_groups = json({});
+      row_groups["id"] = rg.first;
+      row_groups["offset"] = rg.second;
+      str_page["row_group_ids"].push_back(row_groups);
+    }
+  }
+  str_page["page_size"] = page_size_;
+  return str_page;
+}
+
+void Page::DeleteLastGroupId() {
+  if (!row_group_ids_.empty()) {
+    page_size_ = row_group_ids_.back().second;
+    row_group_ids_.pop_back();
+  }
+}
+}  // namespace mindrecord
+}  // namespace mindspore
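For illustration, the JSON that GetPage above emits can be rebuilt with nlohmann::json as follows; the field values are invented, and the page-type marker string is an assumption (the real constants live in shard_page.h):

#include <nlohmann/json.hpp>
#include <iostream>

int main() {
  // Invented example page; mirrors the field-by-field construction in GetPage.
  nlohmann::json rg = {{"id", 5}, {"offset", 0}};
  nlohmann::json page;
  page["page_id"] = 3;
  page["shard_id"] = 0;
  page["page_type"] = "RAW_DATA";  // assumed marker string
  page["page_type_id"] = 2;
  page["start_row_id"] = 0;
  page["end_row_id"] = 0;
  page["row_group_ids"].push_back(rg);
  page["page_size"] = 1024;
  std::cout << page.dump(2) << std::endl;
  return 0;
}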
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
new file mode 100644
index 0000000000..081a48352d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_pk_sample.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements)
+    : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {}
+
+ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories)
+    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {}
+
+ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
+                             uint32_t seed)
+    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) {
+  shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);  // shuffle with replacement
+}
+
+MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) {
+  if (shuffle_ == true) {
+    if (SUCCESS != (*shuffle_op_)(tasks)) {
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
new file mode 100644
index 0000000000..808ab55bfb
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
@@ -0,0 +1,141 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_sample.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+ShardSample::ShardSample(int n)
+    : numerator_(0),
+      denominator_(0),
+      partition_id_(0),
+      no_of_samples_(n),
+      indices_({}),
+      sampler_type_(kCustomTopNSampler) {}
+
+ShardSample::ShardSample(int num, int den)
+    : numerator_(num),
+      denominator_(den),
+      partition_id_(0),
+      no_of_samples_(0),
+      indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}
+
+ShardSample::ShardSample(int num, int den, int par)
+    : numerator_(num),
+      denominator_(den),
+      partition_id_(par),
+      no_of_samples_(0),
+      indices_({}),
+      sampler_type_(kCustomTopPercentSampler) {}
+
+ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed)
+    : numerator_(0),
+      denominator_(0),
+      partition_id_(0),
+      no_of_samples_(0),
+      indices_(indices),
+      sampler_type_(kSubsetRandomSampler) {
+  shuffle_op_ = std::make_shared<ShardShuffle>(seed);
+}
+
+int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (sampler_type_ == kCustomTopNSampler) {
+    return no_of_samples_;
+  }
+
+  if (sampler_type_ == kCustomTopPercentSampler) {
+    if (dataset_size % denominator_ == 0) {
+      return dataset_size / denominator_ * numerator_;
+    } else {
+      return dataset_size / denominator_ * numerator_ + 1;
+    }
+  }
+  if (sampler_type_ == kSubsetRandomSampler) {
+    return indices_.size();
+  }
+  return 0;
+}
+
+  if (tasks.permutation_.empty()) {
+    ShardTask new_tasks;
+    total_no = static_cast<int>(tasks.Size());
+    if (sampler_type_ == kSubsetRandomSampler) {
+      for (size_t i = 0; i < indices_.size(); ++i) {
+        int index = static_cast<int>(((indices_[i] % total_no) + total_no) % total_no);
+        new_tasks.InsertTask(tasks.GetTaskByID(index));  // different mod result between c and python
+      }
+    } else {
+      for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+        new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));  // rounding up; if overflow, go back to start
+      }
+    }
+    std::swap(tasks, new_tasks);
+  } else {
+    ShardTask new_tasks;
+    if (taking > static_cast<int>(tasks.permutation_.size())) {
+      return FAILED;
+    }
+    total_no = static_cast<int>(tasks.permutation_.size());
+    for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
+      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
+    }
+    std::swap(tasks, new_tasks);
+  }
+  return SUCCESS;
+}
+
+MSRStatus ShardSample::SufExecute(ShardTask &tasks) {
+  if (sampler_type_ == kSubsetRandomSampler) {
+    if (SUCCESS != (*shuffle_op_)(tasks)) {
+      return FAILED;
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
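Aside (not part of the patch): the TopPercent branch above reduces to simple integer math, and the subset branch relies on a Python-style modulo to accept negative indices. A standalone check of both, with illustrative values:

#include <cassert>

// Mirrors the 'taking' computation in ShardSample::Execute (illustrative).
int Taking(int total_no, int numerator, int denominator, int no_of_categories) {
  if (numerator == 1 && denominator > 1) {          // sharding: ceil(total / denominator)
    return (total_no + denominator - 1) / denominator;
  }
  int taking = total_no * numerator / denominator;  // non sharding
  return taking - taking % no_of_categories;        // align down to a multiple of the category count
}

// Python-style modulo: maps negative indices into [0, total_no).
int WrapIndex(int index, int total_no) { return ((index % total_no) + total_no) % total_no; }

int main() {
  assert(Taking(10, 1, 4, 1) == 3);  // shard 10 tasks over 4 partitions -> ceil(10/4)
  assert(Taking(10, 3, 4, 2) == 6);  // 75% of 10 is 7, rounded down to a multiple of 2
  assert(WrapIndex(-1, 10) == 9);    // C's -1 % 10 is -1; Python's is 9
  return 0;
}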
(it_value.find("shape") == it_value.end()) { + MS_LOG(ERROR) << "%s supports shape only." << it_value["type"].dump(); + return false; + } + + auto shape = it_value["shape"]; + if (!shape.is_array()) { + MS_LOG(ERROR) << "%s shape format is wrong." << it_value["type"].dump(); + return false; + } + + int num_negtive_one = 0; + for (const auto &i : shape) { + if (i == 0 || i < -1) { + MS_LOG(ERROR) << "Shape %s, number is wrong." << it_value["shape"].dump(); + return false; + } + if (i == -1) { + num_negtive_one++; + } + } + + if (num_negtive_one > 1) { + MS_LOG(ERROR) << "Shape %s, have at most 1 variable-length dimension." << it_value["shape"].dump(); + return false; + } + + return true; +} + +bool Schema::Validate(json schema) { + if (schema.size() == kInt0) { + MS_LOG(ERROR) << "Schema is null"; + return false; + } + + for (json::iterator it = schema.begin(); it != schema.end(); ++it) { + // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_' + if (!ValidateFieldName(it.key())) { + MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', fieldName: " << it.key(); + return false; + } + + json it_value = it.value(); + if (it_value.find("type") == it_value.end()) { + MS_LOG(ERROR) << "No 'type' field exist: " << it_value.dump(); + return false; + } + + if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) { + MS_LOG(ERROR) << "Wrong type: " << it_value["type"].dump(); + return false; + } + + if (it_value.size() == kInt1) { + continue; + } + + if (it_value["type"] == "bytes" || it_value["type"] == "string") { + MS_LOG(ERROR) << it_value["type"].dump() << " can not 1 field only."; + return false; + } + + if (it_value.size() != kInt2) { + MS_LOG(ERROR) << it_value["type"].dump() << " can have at most 2 fields."; + return false; + } + + if (!ValidateNumberShape(it_value)) { + return false; + } + } + + return true; +} + +bool Schema::operator==(const mindrecord::Schema &b) const { + if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { + return false; + } + return true; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc new file mode 100644 index 0000000000..3aa695e03b --- /dev/null +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
new file mode 100644
index 0000000000..3aa695e03b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc
@@ -0,0 +1,74 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_sequential_sample.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+ShardSequentialSample::ShardSequentialSample(int n, int offset)
+    : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
+
+ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
+    : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}
+
+int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
+    return dataset_size;
+  }
+  if (per_ > kEpsilon && per_ <= 1.0f) {
+    return static_cast<int64_t>(dataset_size * per_);
+  }
+  return no_of_samples_;
+}
+
+MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
+  int total_no = static_cast<int>(tasks.Size());
+  int taking;
+  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
+    taking = total_no;
+  } else if (per_ > kEpsilon && per_ <= 1.0f) {
+    taking = static_cast<int>(total_no * per_);
+  } else {
+    taking = no_of_samples_;
+  }
+
+  if (tasks.permutation_.empty()) {
+    ShardTask new_tasks;
+    total_no = static_cast<int>(tasks.Size());
+    for (int i = offset_; i < taking + offset_; ++i) {
+      new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
+    }
+    std::swap(tasks, new_tasks);
+  } else {  // shuffled
+    ShardTask new_tasks;
+    if (taking > static_cast<int>(tasks.permutation_.size())) {
+      return FAILED;
+    }
+    total_no = static_cast<int>(tasks.permutation_.size());
+    for (size_t i = offset_; i < taking + offset_; ++i) {
+      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
+    }
+    std::swap(tasks, new_tasks);
+  }
+  return SUCCESS;
+}
+
+}  // namespace mindrecord
+}  // namespace mindspore
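Aside (not part of the patch): the original text multiplied by kEpsilon where the percentage per_ was clearly intended (1e-7 of the dataset would always truncate to 0), so the file above uses per_. A standalone sanity check of the selection arithmetic, with illustrative numbers:

#include <cassert>
#include <cmath>

int main() {
  // Mirrors ShardSequentialSample's 'taking' selection (illustrative values).
  const float kEpsilon = 1e-7f;
  const int total_no = 100;       // tasks in the dataset
  const float per = 0.25f;        // take 25% ...
  const int no_of_samples = 0;    // ... rather than an explicit count

  int taking;
  if (no_of_samples == 0 && std::fabs(per) <= kEpsilon) {
    taking = total_no;                          // neither count nor percentage set: take all
  } else if (per > kEpsilon && per <= 1.0f) {
    taking = static_cast<int>(total_no * per);  // percentage of the dataset, truncated
  } else {
    taking = no_of_samples;                     // explicit sample count
  }
  assert(taking == 25);  // the copy loop then walks offset_ .. offset_+24, wrapping modulo total_no
  return 0;
}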
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
new file mode 100644
index 0000000000..7743cabea3
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc
@@ -0,0 +1,88 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_shuffle.h"
+
+#include <algorithm>
+
+namespace mindspore {
+namespace mindrecord {
+ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
+    : shuffle_seed_(seed),
+      no_of_samples_(0),
+      replacement_(false),
+      reshuffle_each_epoch_(true),
+      shuffle_type_(shuffle_type) {}
+
+ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
+                           ShuffleType shuffle_type)
+    : shuffle_seed_(seed),
+      no_of_samples_(no_of_samples),
+      replacement_(replacement),
+      reshuffle_each_epoch_(reshuffle_each_epoch),
+      shuffle_type_(shuffle_type) {}
+
+int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
+  if (replacement_) {
+    return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
+  }
+  return dataset_size;
+}
+
+MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
+  if (reshuffle_each_epoch_) shuffle_seed_++;
+  if (tasks.categories < 1) {
+    return FAILED;
+  }
+  if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
+    if (tasks.permutation_.empty()) {
+      tasks.MakePerm();
+    }
+    if (replacement_) {
+      ShardTask new_tasks;
+      if (no_of_samples_ == 0) {
+        no_of_samples_ = static_cast<int64_t>(tasks.Size());
+      }
+      if (no_of_samples_ <= 0) {
+        MS_LOG(ERROR) << "no_of_samples needs to be positive.";
+        return FAILED;
+      }
+      new_tasks.task_list_.reserve(no_of_samples_);
+      for (int64_t i = 0; i < no_of_samples_; ++i) {
+        new_tasks.InsertTask(tasks.GetRandomTask());
+      }
+      std::swap(tasks, new_tasks);
+    } else {
+      std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
+    }
+  } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
+    uint32_t individual_size = tasks.Size() / tasks.categories;
+    std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
+    for (uint32_t i = 0; i < tasks.categories; i++) {
+      for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
+      std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
+    }
+    tasks.permutation_.clear();
+    for (uint32_t j = 0; j < individual_size; j++) {
+      for (uint32_t i = 0; i < tasks.categories; i++) {
+        tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
+      }
+    }
+  }
+  return SUCCESS;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
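Aside (not part of the patch): the kShuffleCategory branch shuffles within each category and then re-interleaves, so every unit of the permutation still holds one element per category. A standalone sketch with 2 categories of 3 elements, assuming the same category-interleaved task layout the code implies (task id = element * categories + category):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main() {
  const uint32_t categories = 2, individual = 3, seed = 7;  // illustrative sizes
  // Per-category permutations, as in the kShuffleCategory branch.
  std::vector<std::vector<int>> perms(categories, std::vector<int>(individual));
  for (uint32_t i = 0; i < categories; ++i) {
    for (uint32_t j = 0; j < individual; ++j) perms[i][j] = static_cast<int>(j);
    std::shuffle(perms[i].begin(), perms[i].end(), std::default_random_engine(seed));
  }
  // Re-interleave: unit j holds one (shuffled) element of every category.
  std::vector<int> permutation;
  for (uint32_t j = 0; j < individual; ++j)
    for (uint32_t i = 0; i < categories; ++i)
      permutation.push_back(perms[i][j] * static_cast<int>(categories) + static_cast<int>(i));
  // Ids alternate even/odd: category 0 tasks are even, category 1 tasks odd.
  for (int id : permutation) std::cout << id << ' ';
  std::cout << '\n';
}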
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc
new file mode 100644
index 0000000000..7024a2ab06
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc
@@ -0,0 +1,112 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_statistics.h"
+#include "pybind11/pybind11.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+namespace mindspore {
+namespace mindrecord {
+std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &statistics) {
+  // validate check
+  if (!Validate(statistics)) {
+    return nullptr;
+  }
+  Statistics object_statistics;
+  object_statistics.desc_ = std::move(desc);
+  object_statistics.statistics_ = statistics;
+  object_statistics.statistics_id_ = -1;
+  return std::make_shared<Statistics>(object_statistics);
+}
+
+std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) {
+  // validate check
+  json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
+  if (!Validate(statistics_json)) {
+    return nullptr;
+  }
+  Statistics object_statistics;
+  object_statistics.desc_ = std::move(desc);
+  object_statistics.statistics_ = statistics_json;
+  object_statistics.statistics_id_ = -1;
+  return std::make_shared<Statistics>(object_statistics);
+}
+
+std::string Statistics::GetDesc() const { return desc_; }
+
+json Statistics::GetStatistics() const {
+  json str_statistics;
+  str_statistics["desc"] = desc_;
+  str_statistics["statistics"] = statistics_;
+  return str_statistics;
+}
+
+pybind11::object Statistics::GetStatisticsForPython() const {
+  json str_statistics = Statistics::GetStatistics();
+  return nlohmann::detail::FromJsonImpl(str_statistics);
+}
+
+void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
+
+int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
+
+bool Statistics::Validate(const json &statistics) {
+  if (statistics.size() != kInt1) {
+    MS_LOG(ERROR) << "Statistics object is null";
+    return false;
+  }
+  if (statistics.find("level") == statistics.end()) {
+    MS_LOG(ERROR) << "There is no 'level' object in the statistics";
+    return false;
+  }
+  return LevelRecursive(statistics["level"]);
+}
+
+bool Statistics::LevelRecursive(json level) {
+  bool ini = true;
+  for (json::iterator it = level.begin(); it != level.end(); ++it) {
+    json a = it.value();
+    if (a.size() == kInt2) {
+      if ((a.find("key") == a.end()) || (a.find("count") == a.end())) {
+        MS_LOG(ERROR) << "The node has 2 fields, but 'key'/'count' does not exist";
+        return false;
+      }
+    } else if (a.size() == kInt3) {
+      if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) {
+        MS_LOG(ERROR) << "The node has 3 fields, but 'key'/'count'/'level' does not exist";
+        return false;
+      } else {
+        ini = LevelRecursive(a.at("level"));
+      }
+    } else {
+      MS_LOG(ERROR) << "The node must have either 2 or 3 fields";
+      return false;
+    }
+  }
+  return ini;
+}
+
+bool Statistics::operator==(const Statistics &b) const {
+  if (this->GetStatistics() != b.GetStatistics()) {
+    return false;
+  }
+  return true;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
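Aside (not part of the patch): a payload LevelRecursive accepts is a single 'level' tree whose nodes each carry 'key' and 'count', plus an optional nested 'level'. A hedged example assuming nlohmann::json is available; keys and counts are invented:

#include <iostream>
#include "nlohmann/json.hpp"

using json = nlohmann::json;

int main() {
  // One top-level 'level'; each node has key/count (size 2) or key/count/level (size 3).
  json statistics = json::parse(R"({
    "level": [
      {"key": "2018", "count": 811},
      {"key": "2019", "count": 763,
       "level": [{"key": "positive", "count": 400}]}
    ]
  })");
  std::cout << statistics.dump(2) << '\n';
}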
diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc
new file mode 100644
index 0000000000..6f8e440f91
--- /dev/null
+++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task.cc
@@ -0,0 +1,121 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/mindrecord/include/shard_task.h"
+#include "common/utils.h"
+#include "minddata/mindrecord/include/common/shard_utils.h"
+
+using mindspore::LogStream;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::DEBUG;
+
+namespace mindspore {
+namespace mindrecord {
+ShardTask::ShardTask() : categories(1) {}
+
+ShardTask::ShardTask(const ShardTask &other)
+    : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
+
+ShardTask &ShardTask::operator=(const ShardTask &other) {
+  ShardTask tmp(other);
+  std::swap(categories, tmp.categories);
+  permutation_.swap(tmp.permutation_);
+  task_list_.swap(tmp.task_list_);
+  return *this;
+}
+
+void ShardTask::MakePerm() {
+  permutation_ = std::vector<int>(task_list_.size());
+  for (uint32_t i = 0; i < task_list_.size(); i++) {
+    permutation_[i] = static_cast<int>(i);
+  }
+}
+
+void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
+                           const json &label) {
+  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
+                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
+  task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
+}
+
+void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
+  MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
+                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
+                << ", size of task_list_: " << task_list_.size() << ".";
+
+  task_list_.push_back(std::move(task));
+}
+
+void ShardTask::PopBack() { task_list_.pop_back(); }
+
+uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); }
+
+uint32_t ShardTask::SizeOfRows() const {
+  if (task_list_.size() == 0) return static_cast<uint32_t>(0);
+
+  // 1 task is 1 page
+  auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) {
+    return x + std::get<2>(y)[0];
+  };
+  uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
+  return nRows;
+}
+
+std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
+  MS_ASSERT(id < task_list_.size());
+  return task_list_[id];
+}
+
+std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
+  return task_list_[dis(gen)];
+}
+
+ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
+  ShardTask res;
+  if (category_tasks.empty()) return res;
+  auto total_categories = category_tasks.size();
+  res.categories = static_cast<uint32_t>(total_categories);
+  if (!replacement) {
+    auto minTasks = category_tasks[0].Size();
+    for (uint32_t i = 1; i < total_categories; i++) {
+      minTasks = std::min(minTasks, category_tasks[i].Size());
+    }
+    for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
+      for (uint32_t i = 0; i < total_categories; i++) {
+        res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<size_t>(task_no))));
+      }
+    }
+  } else {
+    auto maxTasks = category_tasks[0].Size();
+    for (uint32_t i = 1; i < total_categories; i++) {
+      maxTasks = std::max(maxTasks, category_tasks[i].Size());
+    }
+    if (num_elements != std::numeric_limits<int64_t>::max()) {
+      maxTasks = static_cast<uint32_t>(num_elements);
+    }
+    for (uint32_t i = 0; i < total_categories; i++) {
+      for (uint32_t j = 0; j < maxTasks; j++) {
+        res.InsertTask(category_tasks[i].GetRandomTask());
+      }
+    }
+  }
+  return res;
+}
+}  // namespace mindrecord
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/common/shard_error.cc b/mindspore/ccsrc/mindrecord/common/shard_error.cc
deleted file mode 100644
index ad68aaf92c..0000000000
--- a/mindspore/ccsrc/mindrecord/common/shard_error.cc
+++ /dev/null
@@ -1,181 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "mindrecord/include/shard_error.h"
-
-namespace mindspore {
-namespace mindrecord {
-std::string ErrnoToMessage(MSRStatus status) {
-  switch (status) {
-    case FAILED:
-      return "operator failed";
-      break;
-    case SUCCESS:
-      return "operator success";
-      break;
-    case OPEN_FILE_FAILED:
-      return "open file failed";
-      break;
-    case CLOSE_FILE_FAILED:
-      return "close file failed";
-      break;
-    case WRITE_METADATA_FAILED:
-      return "write metadata failed";
-      break;
-    case WRITE_RAWDATA_FAILED:
-      return "write rawdata failed";
-      break;
-    case GET_SCHEMA_FAILED:
-      return "get schema failed";
-      break;
-    case ILLEGAL_RAWDATA:
-      return "illegal raw data";
-      break;
-    case PYTHON_TO_JSON_FAILED:
-      return "pybind: python object to json failed";
-      break;
-    case DIR_CREATE_FAILED:
-      return "directory create failed";
-      break;
-    case OPEN_DIR_FAILED:
-      return "open directory failed";
-      break;
-    case INVALID_STATISTICS:
-      return "invalid statistics object";
-      break;
-    case OPEN_DATABASE_FAILED:
-      return "open database failed";
-      break;
-    case CLOSE_DATABASE_FAILED:
-      return "close database failed";
-      break;
-    case DATABASE_OPERATE_FAILED:
-      return "database operate failed";
-      break;
-    case BUILD_SCHEMA_FAILED:
-      return "build schema failed";
-      break;
-    case DIVISOR_IS_ILLEGAL:
-      return "divisor is illegal";
-      break;
-    case INVALID_FILE_PATH:
-      return "file path is invalid";
-      break;
-    case SECURE_FUNC_FAILED:
-      return "secure function failed";
-      break;
-    case ALLOCATE_MEM_FAILED:
-      return "allocate memory failed";
-      break;
-    case ILLEGAL_FIELD_NAME:
-      return "illegal field name";
-      break;
-    case ILLEGAL_FIELD_TYPE:
-      return "illegal field type";
-      break;
-    case SET_METADATA_FAILED:
-      return "set metadata failed";
-      break;
-    case ILLEGAL_SCHEMA_DEFINITION:
-      return "illegal schema definition";
-      break;
-    case ILLEGAL_COLUMN_LIST:
-      return "illegal column list";
-      break;
-    case SQL_ERROR:
-      return "sql error";
-      break;
-    case ILLEGAL_SHARD_COUNT:
-      return "illegal shard count";
-      break;
-    case ILLEGAL_SCHEMA_COUNT:
-      return "illegal schema count";
-      break;
-    case VERSION_ERROR:
-      return "data version is not matched";
-      break;
-    case
ADD_SCHEMA_FAILED: - return "add schema failed"; - break; - case ILLEGAL_Header_SIZE: - return "illegal header size"; - break; - case ILLEGAL_Page_SIZE: - return "illegal page size"; - break; - case ILLEGAL_SIZE_VALUE: - return "illegal size value"; - break; - case INDEX_FIELD_ERROR: - return "add index fields failed"; - break; - case GET_CANDIDATE_CATEGORYFIELDS_FAILED: - return "get candidate category fields failed"; - break; - case GET_CATEGORY_INFO_FAILED: - return "get category information failed"; - break; - case ILLEGAL_CATEGORY_ID: - return "illegal category id"; - break; - case ILLEGAL_ROWNUMBER_OF_PAGE: - return "illegal row number of page"; - break; - case ILLEGAL_SCHEMA_ID: - return "illegal schema id"; - break; - case DESERIALIZE_SCHEMA_FAILED: - return "deserialize schema failed"; - break; - case DESERIALIZE_STATISTICS_FAILED: - return "deserialize statistics failed"; - break; - case ILLEGAL_DB_FILE: - return "illegal db file"; - break; - case OVERWRITE_DB_FILE: - return "overwrite db file"; - break; - case OVERWRITE_MINDRECORD_FILE: - return "overwrite mindrecord file"; - break; - case ILLEGAL_MINDRECORD_FILE: - return "illegal mindrecord file"; - break; - case PARSE_JSON_FAILED: - return "parse json failed"; - break; - case ILLEGAL_PARAMETERS: - return "illegal parameters"; - break; - case GET_PAGE_BY_GROUP_ID_FAILED: - return "get page by group id failed"; - break; - case GET_SYSTEM_STATE_FAILED: - return "get system state failed"; - break; - case IO_FAILED: - return "io operate failed"; - break; - case MATCH_HEADER_FAILED: - return "match header failed"; - break; - default: - return "invalid error no"; - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc deleted file mode 100644 index ee923ebc97..0000000000 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ /dev/null @@ -1,230 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_segment.h" -#include "mindrecord/include/shard_writer.h" -#include "nlohmann/json.hpp" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace py = pybind11; - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -void BindSchema(py::module *m) { - (void)py::class_>(*m, "Schema", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) - .def("get_desc", &Schema::GetDesc) - .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) - .def("get_blob_fields", &Schema::GetBlobFields) - .def("get_schema_id", &Schema::GetSchemaID); -} - -void BindStatistics(const py::module *m) { - (void)py::class_>(*m, "Statistics", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) - .def("get_desc", &Statistics::GetDesc) - .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) - .def("get_statistics_id", &Statistics::GetStatisticsID); -} - -void BindShardHeader(const py::module *m) { - (void)py::class_>(*m, "ShardHeader", py::module_local()) - .def(py::init<>()) - .def("add_schema", &ShardHeader::AddSchema) - .def("add_statistics", &ShardHeader::AddStatistic) - .def("add_index_fields", - (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) - .def("get_meta", &ShardHeader::GetSchemas) - .def("get_statistics", &ShardHeader::GetStatistics) - .def("get_fields", &ShardHeader::GetFields) - .def("get_schema_by_id", &ShardHeader::GetSchemaByID) - .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); -} - -void BindShardWriter(py::module *m) { - (void)py::class_(*m, "ShardWriter", py::module_local()) - .def(py::init<>()) - .def("open", &ShardWriter::Open) - .def("open_for_append", &ShardWriter::OpenForAppend) - .def("set_header_size", &ShardWriter::SetHeaderSize) - .def("set_page_size", &ShardWriter::SetPageSize) - .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, - vector> &, bool, bool)) & - ShardWriter::WriteRawData) - .def("commit", &ShardWriter::Commit); -} - -void BindShardReader(const py::module *m) { - (void)py::class_>(*m, "ShardReader", py::module_local()) - .def(py::init<>()) - .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardReader::OpenPy) - .def("launch", &ShardReader::Launch) - .def("get_header", &ShardReader::GetShardHeader) - .def("get_blob_fields", &ShardReader::GetBlobFields) - .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & - ShardReader::GetNextPy) - .def("finish", &ShardReader::Finish) - .def("close", &ShardReader::Close); -} - -void BindShardIndexGenerator(const py::module *m) { - (void)py::class_(*m, "ShardIndexGenerator", py::module_local()) - .def(py::init()) - .def("build", &ShardIndexGenerator::Build) - .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); -} - -void BindShardSegment(py::module *m) { - (void)py::class_(*m, "ShardSegment", py::module_local()) - 
.def(py::init<>()) - .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardSegment::OpenPy) - .def("get_category_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) - .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) - .def("read_category_info", (std::pair(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) - .def("read_at_page_by_id", (std::pair, pybind11::object>>>( - ShardSegment::*)(int64_t, int64_t, int64_t)) & - ShardSegment::ReadAtPageByIdPy) - .def("read_at_page_by_name", (std::pair, pybind11::object>>>( - ShardSegment::*)(std::string, int64_t, int64_t)) & - ShardSegment::ReadAtPageByNamePy) - .def("get_header", &ShardSegment::GetShardHeader) - .def("get_blob_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); -} - -void BindGlobalParams(py::module *m) { - (*m).attr("MIN_HEADER_SIZE") = kMinHeaderSize; - (*m).attr("MAX_HEADER_SIZE") = kMaxHeaderSize; - (*m).attr("MIN_PAGE_SIZE") = kMinPageSize; - (*m).attr("MAX_PAGE_SIZE") = kMaxPageSize; - (*m).attr("MIN_SHARD_COUNT") = kMinShardCount; - (*m).attr("MAX_SHARD_COUNT") = kMaxShardCount; - (*m).attr("MIN_CONSUMER_COUNT") = kMinConsumerCount; - (void)(*m).def("get_max_thread_num", &GetMaxThreadNum); -} - -PYBIND11_MODULE(_c_mindrecord, m) { - m.doc() = "pybind11 mindrecord plugin"; // optional module docstring - (void)py::enum_(m, "MSRStatus", py::module_local()) - .value("SUCCESS", SUCCESS) - .value("FAILED", FAILED) - .export_values(); - (void)py::enum_(m, "ShardType", py::module_local()).value("NLP", kNLP).value("CV", kCV).export_values(); - BindGlobalParams(&m); - BindSchema(&m); - BindStatistics(&m); - BindShardHeader(&m); - BindShardWriter(&m); - BindShardReader(&m); - BindShardIndexGenerator(&m); - BindShardSegment(&m); -} -} // namespace mindrecord -} // namespace mindspore - -namespace nlohmann { -namespace detail { -py::object FromJsonImpl(const json &j) { - if (j.is_null()) { - return py::none(); - } else if (j.is_boolean()) { - return py::bool_(j.get()); - } else if (j.is_number()) { - double number = j.get(); - if (fabs(number - std::floor(number)) < mindspore::mindrecord::kEpsilon) { - return py::int_(j.get()); - } else { - return py::float_(number); - } - } else if (j.is_string()) { - return py::str(j.get()); - } else if (j.is_array()) { - py::list obj; - for (const auto &el : j) { - (void)obj.attr("append")(FromJsonImpl(el)); - } - return std::move(obj); - } else { - py::dict obj; - for (json::const_iterator it = j.cbegin(); it != j.cend(); ++it) { - obj[py::str(it.key())] = FromJsonImpl(it.value()); - } - return std::move(obj); - } -} - -json ToJsonImpl(const py::handle &obj) { - if (obj.is_none()) { - return nullptr; - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj)) { - return obj.cast(); - } - if (py::isinstance(obj) || py::isinstance(obj)) { - auto out = json::array(); - for (const py::handle &value : obj) { - out.push_back(ToJsonImpl(value)); - } - return out; - } - if (py::isinstance(obj)) { - auto out = json::object(); - for (const py::handle &key : obj) { - out[py::str(key).cast()] = ToJsonImpl(obj[key]); - } - return out; - } - MS_LOG(ERROR) << "Python to json failed, obj is: " << py::cast(obj); - return json(); -} -} // namespace detail - -py::object adl_serializer::FromJson(const json 
&j) { return detail::FromJsonImpl(j); } - -void adl_serializer::ToJson(json *j, const py::object &obj) { - *j = detail::ToJsonImpl(obj); -} // namespace detail -} // namespace nlohmann diff --git a/mindspore/ccsrc/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/mindrecord/common/shard_utils.cc deleted file mode 100644 index edeabb3cde..0000000000 --- a/mindspore/ccsrc/mindrecord/common/shard_utils.cc +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/common/shard_utils.h" -#include "common/utils.h" -#include "./securec.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -// split a string using a character -std::vector StringSplit(const std::string &field, char separator) { - std::vector res; - uint64_t s_pos = 0; - while (s_pos < field.length()) { - size_t e_pos = field.find_first_of(separator, s_pos); - if (e_pos != std::string::npos) { - res.push_back(field.substr(s_pos, e_pos - s_pos)); - } else { - res.push_back(field.substr(s_pos, field.length() - s_pos)); - break; - } - s_pos = e_pos + 1; - } - return res; -} - -bool ValidateFieldName(const std::string &str) { - std::string::const_iterator it = str.begin(); - if (it == str.end()) { - return false; - } - for (; it != str.end(); ++it) { - if (*it == '_' || ((*it >= '0') && (*it <= '9')) || ((*it >= 'A') && (*it <= 'Z')) || - ((*it >= 'a') && (*it <= 'z'))) { - continue; - } - return false; - } - return true; -} - -std::pair GetFileName(const std::string &path) { - char real_path[PATH_MAX] = {0}; - char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; - } - char tmp[PATH_MAX] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; - } -#else - if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (realpath(common::SafeCStr(path), real_path) == nullptr) { - MS_LOG(DEBUG) << "Path: " << path << "check successfully"; - } -#endif - std::string s = real_path; - char sep = '/'; - size_t i = s.rfind(sep, s.length()); - if (i != std::string::npos) { - if (i + 1 < s.size()) { - return {SUCCESS, s.substr(i + 1)}; - } - } - return {SUCCESS, s}; -} - -std::pair GetParentDir(const std::string &path) { - char real_path[PATH_MAX] = {0}; - char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) 
<< "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; - } - char tmp[PATH_MAX] = {0}; -#if defined(_WIN32) || defined(_WIN64) - if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; - } -#else - if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; - } - if (realpath(common::SafeCStr(path), real_path) == nullptr) { - MS_LOG(DEBUG) << "Path: " << path << "check successfully"; - } -#endif - std::string s = real_path; - if (s.rfind('/') + 1 <= s.size()) { - return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; - } - return {SUCCESS, "/"}; -} - -bool CheckIsValidUtf8(const std::string &str) { - int n = 0; - int ix = str.length(); - for (int i = 0; i < ix; ++i) { - uint8_t c = static_cast(str[i]); - if (c <= 0x7f) { - n = 0; - } else if ((c & 0xE0) == 0xC0) { - n = 1; - } else if (c == 0xed && i < (ix - 1) && (static_cast(str[i + 1]) & 0xa0) == 0xa0) { - return false; - } else if ((c & 0xF0) == 0xE0) { - n = 2; - } else if ((c & 0xF8) == 0xF0) { - n = 3; - } else { - return false; - } - for (int j = 0; j < n && i < ix; ++j) { - if ((++i == ix) || ((static_cast(str[i]) & 0xC0) != 0x80)) { - return false; - } - } - } - return true; -} - -bool IsLegalFile(const std::string &path) { - struct stat s; - if (stat(common::SafeCStr(path), &s) == 0) { - if (s.st_mode & S_IFDIR) { - return false; - } - return true; - } - return false; -} - -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { -#if defined(_WIN32) || defined(_WIN64) - return {SUCCESS, 100}; -#else - uint64_t ll_count = 0; - struct statfs disk_info; - if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { - MS_LOG(ERROR) << "Get disk size error"; - return {FAILED, 0}; - } - - switch (disk_type) { - case kTotalSize: - ll_count = disk_info.f_bsize * disk_info.f_blocks; - ll_count = ll_count >> 20; - break; - case kFreeSize: - ll_count = disk_info.f_bsize * disk_info.f_bavail; - ll_count = ll_count >> 20; - break; - default: - ll_count = 0; - break; - } - - return {SUCCESS, ll_count}; -#endif -} - -uint32_t GetMaxThreadNum() { - // define the number of thread - uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) { - thread_num = kMaxConsumerCount; - } - return thread_num; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h b/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h deleted file mode 100644 index 86c71a0ea7..0000000000 --- a/mindspore/ccsrc/mindrecord/include/common/shard_pybind.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ - -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" - -namespace py = pybind11; -namespace nlohmann { -template <> -struct adl_serializer { - py::object FromJson(const json &j); - - void ToJson(json *j, const py::object &obj); -}; - -namespace detail { -py::object FromJsonImpl(const json &j); - -json ToJsonImpl(const py::handle &obj); -} // namespace detail -} // namespace nlohmann -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_PYBIND_H_ diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h deleted file mode 100644 index 8aa5bdfbda..0000000000 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ -#define MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ - -#include -#include -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_error.h" -#include "nlohmann/json.hpp" -#include "./sqlite3.h" -#include "utils/log_adapter.h" - -/* To be used when dlog is ok #include "./slog.h" */ -#ifdef DEBUG -#define MS_ASSERT(f) assert(f) -#else -#define MS_ASSERT(f) ((void)0) -#endif - -namespace mindspore { -namespace mindrecord { -using json = nlohmann::json; - -const int kInt0 = 0; -const int kInt1 = 1; -const int kInt2 = 2; -const int kInt3 = 3; -const int kUnsignedInt4 = 4; - -enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel }; - -const char kVersion[] = "3.0"; -const std::vector kSupportedVersion = {"2.0", kVersion}; - -enum ShardType { - kNLP = 0, - kCV = 1, -}; - -enum TaskType { - kCommonTask = 0, - kPaddedTask = 1, -}; -enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; - -enum ShuffleType { kShuffleCategory, kShuffleSample }; - -const double kEpsilon = 1e-7; - -const int kThreadNumber = 14; - -// Shard default parameters -const uint64_t kDefaultHeaderSize = 1 << 24; // 16MB -const uint64_t kDefaultPageSize = 1 << 25; // 32MB - -// HeaderSize [16KB, 128MB] -const int kMinHeaderSize = 1 << 14; // 16KB -const int kMaxHeaderSize = 1 << 27; // 128MB - -// PageSize [32KB, 256MB] -const int kMinPageSize = 1 << 15; // 32KB -const int kMaxPageSize = 1 << 28; // 256MB - -// used by value length / schema id length / statistic id length ... 
-const uint64_t kInt64Len = 8; - -// Minimum file size -const uint64_t kMinFileSize = kInt64Len; - -const int kMinShardCount = 1; -const int kMaxShardCount = 1000; - -const int kMinConsumerCount = 1; -const int kMaxConsumerCount = 128; - -const int kMaxSchemaCount = 1; -const int kMaxThreadCount = 32; -const int kMaxFieldCount = 100; - -// Minimum free disk size -const int kMinFreeDiskSize = 10; // 10M - -// dummy json -const json kDummyId = R"({"id": 0})"_json; - -// translate type in schema to type in sqlite3(NULL, INTEGER, REAL, TEXT, BLOB) -const std::unordered_map kDbJsonMap = { - {"string", "TEXT"}, {"date", "DATE"}, {"date-time", "DATETIME"}, {"null", "NULL"}, - {"integer", "INTEGER"}, {"boolean", "BOOLEAN"}, {"array", "BLOB"}, {"number", "NUMERIC"}, - {"int32", "INTEGER"}, {"int64", "INTEGER"}, {"float32", "NUMERIC"}, {"float64", "NUMERIC"}, - {"bytes", "BLOB"}}; - -const char kPoint = '.'; - -// field type used by check schema validation -const std::set kFieldTypeSet = {"bytes", "string", "int32", "int64", "float32", "float64"}; - -// can be searched field list -const std::set kScalarFieldTypeSet = {"string", "int32", "int64", "float32", "float64"}; - -// number field list -const std::set kNumberFieldTypeSet = {"int32", "int64", "float32", "float64"}; - -/// \brief split a string using a character -/// \param[in] field target string -/// \param[in] separator a character for spliting -/// \return vector type result -std::vector StringSplit(const std::string &field, char separator); - -/// \brief validate field name is composed of '0-9' or 'a-z' or 'A-Z' or '_' or '-' -/// \param[in] str target string -/// \return -bool ValidateFieldName(const std::string &str); - -/// \brief get the filename by the path -/// \param s file path -/// \return -std::pair GetFileName(const std::string &s); - -/// \brief get parent dir -/// \param path file path -/// \return parent path -std::pair GetParentDir(const std::string &path); - -bool CheckIsValidUtf8(const std::string &str); - -/// \brief judge if a path is legal file -/// \param path file path -/// \return parent path -bool IsLegalFile(const std::string &path); - -enum DiskSizeType { kTotalSize = 0, kFreeSize }; - -/// \brief get the free space about the disk -/// \param str_dir file path -/// \param disk_type: kTotalSize / kFreeSize -/// \return size in Megabytes -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); - -/// \brief get the max hardware concurrency -/// \return max concurrency -uint32_t GetMaxThreadNum(); -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_COMMON_SHARD_UTILS_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h deleted file mode 100644 index 618a91b1d8..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ -#define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" - -namespace mindspore { -namespace mindrecord { -class ShardCategory : public ShardOperator { - public: - explicit ShardCategory(const std::vector> &categories, - int64_t num_elements = std::numeric_limits::max(), bool replacement = false); - - ShardCategory(const std::string &category_field, int64_t num_elements, - int64_t num_categories = std::numeric_limits::max(), bool replacement = false); - - ~ShardCategory() override{}; - - const std::vector> &GetCategories() const { return categories_; } - - const std::string GetCategoryField() const { return category_field_; } - - int64_t GetNumElements() const { return num_elements_; } - - int64_t GetNumCategories() const { return num_categories_; } - - bool GetReplacement() const { return replacement_; } - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - std::vector> categories_; - std::string category_field_; - int64_t num_elements_; - int64_t num_categories_; - bool replacement_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_column.h b/mindspore/ccsrc/mindrecord/include/shard_column.h deleted file mode 100644 index 968d82e717..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_column.h +++ /dev/null @@ -1,167 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_ -#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_header.h" - -namespace mindspore { -namespace mindrecord { -const uint64_t kUnsignedOne = 1; -const uint64_t kBitsOfByte = 8; -const uint64_t kDataTypeBits = 2; -const uint64_t kNumDataOfByte = 4; -const uint64_t kBytesOfColumnLen = 4; -const uint64_t kDataTypeBitMask = 3; -const uint64_t kDataTypes = 6; - -enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type }; - -enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound }; - -enum ColumnDataType { - ColumnBytes = 0, - ColumnString = 1, - ColumnInt32 = 2, - ColumnInt64 = 3, - ColumnFloat32 = 4, - ColumnFloat64 = 5, - ColumnNoDataType = 6 -}; - -// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"}; -const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8}; - -const std::vector ColumnDataTypeNameNormalized = {"uint8", "string", "int32", - "int64", "float32", "float64"}; - -const std::unordered_map ColumnDataTypeMap = { - {"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32}, - {"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}}; - -class ShardColumn { - public: - explicit ShardColumn(const std::shared_ptr &shard_header, bool compress_integer = true); - - ~ShardColumn() = default; - - /// \brief get column value by column name - MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape); - - /// \brief compress blob - std::vector CompressBlob(const std::vector &blob); - - /// \brief check if blob compressed - bool CheckCompressBlob() const { return has_compress_blob_; } - - uint64_t GetNumBlobColumn() const { return num_blob_column_; } - - std::vector GetColumnName() { return column_name_; } - - std::vector GeColumnDataType() { return column_data_type_; } - - std::vector> GetColumnShape() { return column_shape_; } - - /// \brief get column value from blob - MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes); - std::pair GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape); - - /// \brief get column value from json - MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes); - - private: - /// \brief get float value from json - template - MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); - - /// \brief get integer value from json - template - MSRStatus GetInt(std::unique_ptr *data_ptr, const json &json_column_value); - - /// \brief get column offset address and size from blob - MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, - uint64_t *num_bytes, uint64_t *shift_idx); - - /// \brief check if column name is available - ColumnCategory CheckColumnName(const std::string &column_name); - - /// \brief compress integer column - static vector CompressInt(const vector &src_bytes, const IntegerType &int_type); - - /// \brief uncompress integer array 
column - template - static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); - - /// \brief convert big-endian bytes to unsigned int - /// \param bytes_array bytes array - /// \param pos shift address in bytes array - /// \param i_type integer type - /// \return unsigned int - static uint64_t BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &i_type); - - /// \brief convert unsigned int to big-endian bytes - /// \param value integer value - /// \param i_type integer type - /// \return bytes - static std::vector UIntToBytesBig(uint64_t value, const IntegerType &i_type); - - /// \brief convert unsigned int to little-endian bytes - /// \param value integer value - /// \param i_type integer type - /// \return bytes - static std::vector UIntToBytesLittle(uint64_t value, const IntegerType &i_type); - - /// \brief convert unsigned int to little-endian bytes - /// \param bytes_array bytes array - /// \param pos shift address in bytes array - /// \param src_i_type source integer typ0e - /// \param dst_i_type (output), destination integer type - /// \return integer - static int64_t BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr); - - private: - std::vector column_name_; // column name list - std::vector column_data_type_; // column data type list - std::vector> column_shape_; // column shape list - std::unordered_map column_name_id_; // column name id map - std::vector blob_column_; // blob column list - std::unordered_map blob_column_id_; // blob column name id map - bool has_compress_blob_; // if has compress blob - uint64_t num_blob_column_; // number of blob columns -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h deleted file mode 100644 index ef0ad738c4..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" -#include "mindrecord/include/shard_sample.h" - -namespace mindspore { -namespace mindrecord { -class ShardDistributedSample : public ShardSample { - public: - ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); - - ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); - - void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } - - ~ShardDistributedSample() override{}; - - MSRStatus PreExecute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - bool shuffle_; - int no_of_padded_samples_; - bool first_epoch_; // check (num_sample + num_padded) % num_shards == 0 in first epoch - ShardTask task_; // maintain the input tasks in first epoch -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_DISTRIBUTED_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h deleted file mode 100644 index e4361c466a..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ /dev/null @@ -1,186 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_HEADER_H_ -#define MINDRECORD_INCLUDE_SHARD_HEADER_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_page.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" - -namespace mindspore { -namespace mindrecord { -class ShardHeader { - public: - ShardHeader(); - - ~ShardHeader() = default; - - MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); - - static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path); - /// \brief add the schema and save it - /// \param[in] schema the schema to be added - /// \return the last schema's id - int AddSchema(std::shared_ptr<Schema> schema); - - /// \brief add the statistic and save it - /// \param[in] statistic the statistic to be added - /// \return the last statistic's id - void AddStatistic(std::shared_ptr<Statistics> statistic); - - /// \brief create the index and add index fields taken from each schema - /// \param[in] fields the index fields to be added - /// \return SUCCESS if added successfully, FAILED if not - MSRStatus AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); - - MSRStatus AddIndexFields(const std::vector<std::string> &fields); - - /// \brief get all schemas - /// \return the schema list - std::vector<std::shared_ptr<Schema>> GetSchemas(); - - /// \brief get all statistics - /// \return the statistics list - std::vector<std::shared_ptr<Statistics>> GetStatistics(); - - /// \brief get the fields of the index - /// \return the fields of the index - std::vector<std::pair<uint64_t, std::string>> GetFields(); - - /// \brief get the index - /// \return the index - std::shared_ptr<Index> GetIndex(); - - /// \brief get the schema by schema id - /// \param[in] schema_id the id of the schema to retrieve - /// \return the schema obtained by schema id - std::pair<std::shared_ptr<Schema>, MSRStatus> GetSchemaByID(int64_t schema_id); - - /// \brief get the file path of a shard by shard id - /// \param[in] shard_id the id of the shard whose file path is needed - /// \return the file path obtained by shard id - std::string GetShardAddressByID(int64_t shard_id); - - /// \brief get the statistic by statistic id - /// \param[in] statistic_id the id of the statistic to retrieve - /// \return the statistics obtained by statistic id - std::pair<std::shared_ptr<Statistics>, MSRStatus> GetStatisticByID(int64_t statistic_id); - - MSRStatus InitByFiles(const std::vector<std::string> &file_paths); - - void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } - - std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); - - MSRStatus SetPage(const std::shared_ptr<Page> &new_page); - - MSRStatus AddPage(const std::shared_ptr<Page> &new_page); - - int64_t GetLastPageId(const int &shard_id); - - int GetLastPageIdByType(const int &shard_id, const std::string &page_type); - - const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); - - std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } - - int GetShardCount() const { return shard_count_; } - - int GetSchemaCount() const { return schema_.size(); } - - uint64_t GetHeaderSize() const { return header_size_; } - - uint64_t GetPageSize() const { return page_size_; } - - void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } - - void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - - std::vector<std::string> SerializeHeader(); - - MSRStatus PagesToFile(const std::string dump_file_name); - - MSRStatus FileToPages(const std::string dump_file_name); - - private:
- MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset); - - /// \brief get the headers from all the shard data - /// \param[in] real_addresses the real paths of the shard data - /// \param[in] headers the headers read from the shard data - /// \return SUCCESS/FAILED - MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); - - MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); - - /// \brief check the binary file status - static MSRStatus CheckFileStatus(const std::string &path); - - static std::pair<MSRStatus, json> ValidateHeader(const std::string &path); - - void ParseHeader(const json &header); - - void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); - - MSRStatus ParseIndexFields(const json &index_fields); - - MSRStatus CheckIndexField(const std::string &field, const json &schema); - - void ParsePage(const json &page, int shard_index, bool load_dataset); - - MSRStatus ParseStatistics(const json &statistics); - - MSRStatus ParseSchema(const json &schema); - - void ParseShardAddress(const json &address); - - std::string SerializeIndexFields(); - - std::vector<std::string> SerializePage(); - - std::string SerializeStatistics(); - - std::string SerializeSchema(); - - std::string SerializeShardAddress(); - - std::shared_ptr<Index> InitIndexPtr(); - - MSRStatus GetAllSchemaID(std::set<uint64_t> &bucket_count); - - uint32_t shard_count_; - uint64_t header_size_; - uint64_t page_size_; - - std::shared_ptr<Index> index_; - std::vector<std::string> shard_addresses_; - std::vector<std::shared_ptr<Schema>> schema_; - std::vector<std::shared_ptr<Statistics>> statistics_; - std::vector<std::vector<std::shared_ptr<Page>>> pages_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_HEADER_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_index.h b/mindspore/ccsrc/mindrecord/include/shard_index.h deleted file mode 100644 index d430c5bdcf..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_index.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
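Putting the ShardHeader declarations above together, a header is typically assembled by registering a schema, choosing index fields, and setting size limits. The sketch below is hypothetical and not runnable outside the original tree; the schema layout and the size constants are illustrative only.

```cpp
#include <memory>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>
// #include "mindrecord/include/shard_header.h"  // the header deleted above

void BuildHeaderSketch() {
  using mindspore::mindrecord::Schema;
  using mindspore::mindrecord::ShardHeader;

  ShardHeader header;
  auto schema = Schema::Build("example schema", nlohmann::json::parse(
      R"({"file_name": {"type": "string"}, "label": {"type": "int64"}})"));
  header.AddSchema(schema);  // returns the new schema's id
  header.AddIndexFields(std::vector<std::string>{"file_name", "label"});
  header.SetHeaderSize(1 << 24);  // 16 MB header region (illustrative)
  header.SetPageSize(1 << 25);    // 32 MB pages (illustrative)
  auto serialized = header.SerializeHeader();  // one serialized string per shard
  (void)serialized;
}
```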
- */ - -#ifndef MINDRECORD_INDEX_H -#define MINDRECORD_INDEX_H -#pragma once - -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_schema.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -using std::cin; -using std::endl; -using std::pair; -using std::string; -using std::vector; - -class Index { - public: - Index(); - - ~Index() {} - - /// \brief Add a field from the schema identified by schemaId - /// \param[in] schemaId the id of the schema the field belongs to - /// \param[in] field the field to be added - /// - /// add the field to the fields_ vector - void AddIndexField(const int64_t &schemaId, const std::string &field); - - /// \brief get stored fields - /// \return fields stored - std::vector<std::pair<uint64_t, std::string> > GetFields(); - - private: - std::vector<std::pair<uint64_t, std::string> > fields_; - string database_name_; - string table_name_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INDEX_H diff --git a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/mindrecord/include/shard_index_generator.h deleted file mode 100644 index b081b7a0a0..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_index_generator.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
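Conceptually, the Index class above is just a list of (schema id, field name) pairs. A self-contained model of that storage (class and variable names here are ours):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for Index: accumulate (schema_id, field) pairs that the
// index generator later turns into sqlite columns.
class IndexModel {
 public:
  void AddIndexField(const int64_t &schema_id, const std::string &field) {
    fields_.emplace_back(schema_id, field);
  }
  const std::vector<std::pair<uint64_t, std::string>> &GetFields() const { return fields_; }

 private:
  std::vector<std::pair<uint64_t, std::string>> fields_;
};

int main() {
  IndexModel index;
  index.AddIndexField(0, "file_name");
  index.AddIndexField(0, "label");
  for (const auto &f : index.GetFields()) {
    std::cout << "schema " << f.first << " indexes field " << f.second << std::endl;
  }
  return 0;
}
```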
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/shard_header.h" -#include "./sqlite3.h" - -namespace mindspore { -namespace mindrecord { -using INDEX_FIELDS = std::pair>>; -using ROW_DATA = std::pair>>>; -class ShardIndexGenerator { - public: - explicit ShardIndexGenerator(const std::string &file_path, bool append = false); - - MSRStatus Build(); - - static std::pair GenerateFieldName(const std::pair &field); - - ~ShardIndexGenerator() {} - - /// \brief fetch value in json by field name - /// \param[in] field - /// \param[in] input - /// \return pair - std::pair GetValueByField(const string &field, json input); - - /// \brief fetch field type in schema n by field path - /// \param[in] field_path - /// \param[in] schema - /// \return the type of field - static std::string TakeFieldType(const std::string &field_path, json schema); - - /// \brief create databases for indexes - MSRStatus WriteToDatabase(); - - private: - static int Callback(void *not_used, int argc, char **argv, char **az_col_name); - - static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); - - static std::string ConvertJsonToSQL(const std::string &json); - - std::pair CreateDatabase(int shard_no); - - std::pair> GetSchemaDetails(const std::vector &schema_lens, std::fstream &in); - - static std::pair GenerateRawSQL(const std::vector> &fields); - - std::pair CheckDatabase(const std::string &shard_address); - - /// - /// \param shard_no - /// \param blob_id_to_page_id - /// \param raw_page_id - /// \param in - /// \return field name, db type, field value - ROW_DATA GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, - std::fstream &in); - /// - /// \param db - /// \param sql - /// \param data - /// \return - MSRStatus BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data); - - INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); - - MSRStatus ExecuteTransaction(const int &shard_no, std::pair &db, - const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); - - MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); - - MSRStatus AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, - std::fstream &in); - - void AddIndexFieldByRawData(const std::vector &schema_detail, - std::vector> &row_data); - - void DatabaseWriter(); // worker thread - - std::string file_path_; - bool append_; - ShardHeader shard_header_; - uint64_t page_size_; - uint64_t header_size_; - int schema_count_; - std::atomic_int task_; - std::atomic_bool write_success_; - std::vector> fields_; -}; -} // namespace mindrecord -} // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_INDEX_GENERATOR_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h deleted file mode 100644 index f33e3db5f4..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
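From the declarations above, the index generator is driven in two steps: Build() parses the shard header, then WriteToDatabase() populates one sqlite database per shard file. A hedged sketch of that driver (function name and error handling are ours):

```cpp
#include <string>
// #include "mindrecord/include/shard_index_generator.h"  // the header deleted above

bool GenerateIndexSketch(const std::string &mindrecord_file) {
  // append=false: build fresh index databases rather than extending existing ones.
  mindspore::mindrecord::ShardIndexGenerator generator(mindrecord_file, /*append=*/false);
  if (generator.Build() != mindspore::mindrecord::SUCCESS) {
    return false;  // shard header could not be parsed
  }
  // Creates the per-shard databases and fills them row by row
  // (DatabaseWriter worker threads run inside).
  return generator.WriteToDatabase() == mindspore::mindrecord::SUCCESS;
}
```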
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ -#define MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ - -#include -#include "mindrecord/include/shard_task.h" - -namespace mindspore { -namespace mindrecord { -class ShardOperator { - public: - virtual ~ShardOperator() = default; - - MSRStatus operator()(ShardTask &tasks) { - if (SUCCESS != this->PreExecute(tasks)) { - return FAILED; - } - if (SUCCESS != this->Execute(tasks)) { - return FAILED; - } - if (SUCCESS != this->SufExecute(tasks)) { - return FAILED; - } - return SUCCESS; - } - virtual bool HasChildOp() { return child_op_ != nullptr; } - - virtual MSRStatus SetChildOp(std::shared_ptr child_op) { - if (child_op != nullptr) child_op_ = child_op; - return SUCCESS; - } - - virtual std::shared_ptr GetChildOp() { return child_op_; } - - virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } - - virtual MSRStatus Execute(ShardTask &tasks) = 0; - - virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } - - virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } - - private: - std::shared_ptr child_op_ = nullptr; -}; -} // namespace mindrecord -} // namespace mindspore -#endif // MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_page.h b/mindspore/ccsrc/mindrecord/include/shard_page.h deleted file mode 100644 index c22acd8d2c..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_page.h +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
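The ShardOperator above is a classic template-method pipeline: operator() always runs PreExecute, then Execute, then SufExecute, so a concrete operator only overrides the hooks it needs. A standalone model of that control flow (ShardTask is stubbed; all names here are illustrative):

```cpp
#include <iostream>

struct ShardTaskStub {};
enum Status { SUCCESS_, FAILED_ };

class OperatorSketch {
 public:
  virtual ~OperatorSketch() = default;
  // Fixed pipeline: pre -> main -> post, aborting on the first failure.
  Status operator()(ShardTaskStub &tasks) {
    if (PreExecute(tasks) != SUCCESS_) return FAILED_;
    if (Execute(tasks) != SUCCESS_) return FAILED_;
    return SufExecute(tasks);
  }
  virtual Status PreExecute(ShardTaskStub &) { return SUCCESS_; }  // optional hook
  virtual Status Execute(ShardTaskStub &) = 0;                     // mandatory step
  virtual Status SufExecute(ShardTaskStub &) { return SUCCESS_; }  // optional hook
};

class PrintOperator : public OperatorSketch {
 public:
  Status Execute(ShardTaskStub &) override {
    std::cout << "transforming task list" << std::endl;
    return SUCCESS_;
  }
};

int main() {
  ShardTaskStub tasks;
  PrintOperator op;
  return op(tasks) == SUCCESS_ ? 0 : 1;
}
```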
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_PAGE_H_ -#define MINDRECORD_INCLUDE_SHARD_PAGE_H_ - -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -const std::string kPageTypeRaw = "RAW_DATA"; -const std::string kPageTypeBlob = "BLOB_DATA"; -const std::string kPageTypeNewColumn = "NEW_COLUMN_DATA"; - -class Page { - public: - Page(const int &page_id, const int &shard_id, const std::string &page_type, const int &page_type_id, - const uint64_t &start_row_id, const uint64_t end_row_id, - const std::vector> &row_group_ids, const uint64_t page_size) - : page_id_(page_id), - shard_id_(shard_id), - page_type_(page_type), - page_type_id_(page_type_id), - start_row_id_(start_row_id), - end_row_id_(end_row_id), - row_group_ids_(row_group_ids), - page_size_(page_size) {} - - ~Page() = default; - - /// \brief get the page and its description - /// \return the json format of the page and its description - json GetPage() const; - - int GetPageID() const { return page_id_; } - - int GetShardID() const { return shard_id_; } - - int GetPageTypeID() const { return page_type_id_; } - - std::string GetPageType() const { return page_type_; } - - uint64_t GetPageSize() const { return page_size_; } - - uint64_t GetStartRowID() const { return start_row_id_; } - - uint64_t GetEndRowID() const { return end_row_id_; } - - void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } - - void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - - std::pair GetLastRowGroupID() const { return row_group_ids_.back(); } - - std::vector> GetRowGroupIds() const { return row_group_ids_; } - - void SetRowGroupIds(const std::vector> &last_row_group_ids) { - row_group_ids_ = last_row_group_ids; - } - - void DeleteLastGroupId(); - - private: - int page_id_; - int shard_id_; - std::string page_type_; - int page_type_id_; - uint64_t start_row_id_; - uint64_t end_row_id_; - std::vector> row_group_ids_; - uint64_t page_size_; - // JSON page: { - // "page_id":X, - // "shard_id":X, - // "page_type":"XXX", (enum "raw_data", "blob_data", "new_column") - // "page_type_id":X, - // "start_row_id":X, - // "end_row_id":X, - // "row_group_ids":[{"id":X, "offset":X}], - // "page_size":X, -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_PAGE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h deleted file mode 100644 index 4f1a1c307a..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
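The trailing comment in the Page class above sketches its JSON form. For concreteness, here is one plausible instance built with nlohmann::json (the library the surrounding code aliases as json); all values are illustrative:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  nlohmann::json row_group = {{"id", 0}, {"offset", 0}};
  nlohmann::json page = {{"page_id", 0},
                         {"shard_id", 0},
                         {"page_type", "BLOB_DATA"},
                         {"page_type_id", 0},
                         {"start_row_id", 0},
                         {"end_row_id", 63},
                         {"row_group_ids", nlohmann::json::array({row_group})},
                         {"page_size", 1 << 25}};
  std::cout << page.dump(2) << std::endl;  // pretty-print the page descriptor
  return 0;
}
```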
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" -#include "mindrecord/include/shard_category.h" - -namespace mindspore { -namespace mindrecord { -class ShardPkSample : public ShardCategory { - public: - ShardPkSample(const std::string &category_field, int64_t num_elements); - - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); - - ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); - - ~ShardPkSample() override{}; - - MSRStatus SufExecute(ShardTask &tasks) override; - - private: - bool shuffle_; - std::shared_ptr shuffle_op_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h deleted file mode 100644 index 1f2138d6d5..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ /dev/null @@ -1,366 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_READER_H_ -#define MINDRECORD_INCLUDE_SHARD_READER_H_ - -#include -#include -#if !defined(_WIN32) && !defined(_WIN64) -#include -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -using ROW_GROUPS = - std::tuple>>, std::vector>>; -using ROW_GROUP_BRIEF = - std::tuple>, std::vector>; -using TASK_RETURN_CONTENT = - std::pair, json>>>>; -const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode -const int kNumPageInBuffer = 16; // page buffer size in block-reader mode - -class ShardReader { - public: - ShardReader(); - - virtual ~ShardReader(); - - /// \brief open files and initialize reader, c++ API - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] n_consumer number of threads when reading - /// \param[in] selected_columns column list to be populated - /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \param[in] block_reader block-reader mode if true, otherwise row-reader mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}, const bool &block_reader = false, - const int num_padded = 0); - - /// \brief open files and initialize reader, python API - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] n_consumer number of threads when reading - /// \param[in] selected_columns column list to be populated - /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \return MSRStatus the status of MSRStatus - MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}); - - /// \brief close reader - /// \return null - void Close(); - - /// \brief read the file, get schema meta,statistics and index, single-thread mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(); - - /// \brief read the file, get schema meta,statistics and index, multiple-thread mode - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(int n_consumer); - - /// \brief launch threads to get batches - /// \param[in] is_simple_reader trigger threads if false; do nothing if true - /// \return MSRStatus the status of MSRStatus - MSRStatus Launch(bool is_simple_reader = false); - - /// \brief aim to get the meta data - /// \return the metadata - std::shared_ptr GetShardHeader() const; - - /// \brief aim to get columns 
context - /// \return the columns - std::shared_ptr GetShardColumn() const; - - /// \brief get the number of shards - /// \return # of shards - int GetShardCount() const; - - /// \brief get the number of rows in database - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] op smart pointer refer to ShardCategory or ShardSample object - /// \param[out] count # of rows - /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &op, int64_t *count, const int num_padded); - - /// \brief shuffle task with incremental seed - /// \return void - void ShuffleTask(); - - /// \brief get the number of rows in database - /// \return # of rows - int GetNumRows() const; - - /// \brief Read the summary of row groups - /// \return the tuple of 4 elements - /// 1. Sharding ID - /// 2. Row group ID - /// 3. The row ID started in row group - /// 4. # of rows in row group - std::vector> ReadRowGroupSummary(); - - /// \brief Read 1 row group data, excluding images - /// \param[in] groupID row group ID - /// \param[in] shard_id sharding ID - /// \param[in] columns multi-columns retrieved - /// \return the tuple of 5 elements - /// 1. file name where row group is located - /// 2. Actual row group size - /// 3. Offset address of row group in file - /// 4. The list of image offset in page [startOffset, endOffset) - /// 5. The list of columns data - ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, - const std::vector &columns = std::vector()); - - /// \brief Read 1 row group data, excluding images, following an index field criteria - /// \param[in] groupID row group ID - /// \param[in] shard_id sharding ID - /// \param[in] column-value pair of criteria to fulfill - /// \param[in] columns multi-columns retrieved - /// \return the tuple of 5 elements - /// 1. file name where row group is located - /// 2. Actual row group size - /// 3. Offset address of row group in file - /// 4. The list of image offset in page [startOffset, endOffset) - /// 5. 
The list of columns data - ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, - const std::vector &columns = std::vector()); - - /// \brief join all created threads - /// \return MSRStatus the status of MSRStatus - MSRStatus Finish(); - - /// \brief return a batch, given that one is ready - /// \return a batch of images and image data - std::vector, json>> GetNext(); - - /// \brief return a row by id - /// \return a batch of images and image data - std::pair, json>>> GetNextById(const int64_t &task_id, - const int32_t &consumer_id); - - /// \brief return a batch in block-reader mode, given that one is ready - /// \return a batch of images and image data - std::vector, json>> GetBlockNext(); - - /// \brief return a batch, given that one is ready, python API - /// \return a batch of images and image data - std::vector>, pybind11::object>> GetNextPy(); - - /// \brief get blob field list - /// \return blob field list - std::pair> GetBlobFields(); - - /// \brief reset reader - /// \return null - void Reset(); - - /// \brief set flag of all-in-index - /// \return null - void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } - - /// \brief get NLP flag - bool GetNlpFlag(); - - /// \brief get all classes - MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); - - protected: - /// \brief sqlite callback function - static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); - - private: - /// \brief wrap up labels to json format - MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, std::vector> &column_values); - - /// \brief read all rows for specified columns - ROW_GROUPS ReadAllRowGroup(std::vector &columns); - - /// \brief read all rows in one shard - MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values); - - /// \brief initialize reader - MSRStatus Init(const std::vector &file_paths, bool load_dataset); - - /// \brief validate column list - MSRStatus CheckColumnList(const std::vector &selected_columns); - - /// \brief populate one row by task list in row-reader mode - MSRStatus ConsumerByRow(int consumer_id); - - /// \brief populate one row by task list in block-reader mode - MSRStatus ConsumerByBlock(int consumer_id); - - /// \brief get offset address of images within page - std::vector> GetImageOffset(int group_id, int shard_id, - const std::pair &criteria = {"", ""}); - - /// \brief execute sqlite query with prepare statement - MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector> &labels); - - /// \brief get column values - std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns, - const std::pair &criteria = {"", ""}); - - /// \brief get column values from raw data page - std::pair> GetLabelsFromPage(int group_id, int shard_id, - const std::vector &columns, - const std::pair &criteria = {"", - ""}); - - /// \brief create task list in block-reader mode - MSRStatus CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators); - - /// \brief create category-applied task list - MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op); - - /// \brief create task list in row-reader mode - MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, - const
std::vector> &operators); - - /// \brief create task list - MSRStatus CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators); - - /// \brief set NLP flag - void CheckNlp(); - - /// \brief check if all specified columns are in index table - void CheckIfColumnInIndex(const std::vector &columns); - - /// \brief open multiple file handle - void FileStreamsOperator(); - - /// \brief read one row by one task - TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); - - /// \brief get one row from buffer in block-reader mode - std::shared_ptr, json>>> GetRowFromBuffer(int bufId, int rowId); - - /// \brief get labels from binary file - std::pair> GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets); - - MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); - - /// \brief get classes in one shard - void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); - - /// \brief get number of classes - int64_t GetNumClasses(const std::string &category_field); - - /// \brief get meta of header - std::pair> GetMeta(const std::string &file_path, json &meta_data); - - /// \brief extract uncompressed data based on column list - std::pair>> UnCompressBlob(const std::vector &raw_blob_data); - - protected: - uint64_t header_size_; // header size - uint64_t page_size_; // page size - int shard_count_; // number of shards - std::shared_ptr shard_header_; // shard header - std::shared_ptr shard_column_; // shard column - - std::vector database_paths_; // sqlite handle list - std::vector file_paths_; // file paths - std::vector> file_streams_; // single-file handle list - std::vector>> file_streams_random_; // multiple-file handle list - - private: - int n_consumer_; // number of workers (threads) - std::vector selected_columns_; // columns which will be read - std::map column_schema_id_; // column-schema map - std::vector> operators_; // data operators, including shuffle, sample and category - ShardTask tasks_; // shard task - std::mutex shard_locker_; // locker of shard - - // flags - bool all_in_index_ = true; // if all columns are stored in index-table - bool interrupt_ = false; // reader interrupted - - int num_padded_; // number of padding samples - - // Delivery/Iterator mode begin - const std::string kThreadName = "THRD_ITER_"; // prefix of thread name - std::vector thread_set_; // thread list - int num_rows_; // number of rows - std::mutex mtx_delivery_; // locker for delivery - std::condition_variable cv_delivery_; // condition variable for delivery - std::condition_variable cv_iterator_; // condition variable for iterator - std::atomic task_id_; // task ID which is working - std::atomic deliver_id_; // delivery ID which is picked up by iterator - // map of delivery - std::unordered_map, json>>>> delivery_map_; - // Delivery/Iterator mode end - - // Block reader mode begin - bool block_reader_; // block-reader mode - int row_id_; // row id in one page - int num_blocks_; // number of pages - // raw data page - std::vector>, std::vector>>> delivery_block_; - std::unordered_set delivery_block_set_; // set of delivered pages - std::vector> buf_; // page buffer - // Block reader mode end -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_READER_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h deleted file mode
100644 index a32acbff6e..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_operator.h" -#include "mindrecord/include/shard_shuffle.h" - -namespace mindspore { -namespace mindrecord { -class ShardSample : public ShardOperator { - public: - explicit ShardSample(int n); - - ShardSample(int num, int den); - - ShardSample(int num, int den, int par); - - ShardSample(const std::vector &indices, uint32_t seed); - - ~ShardSample() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - MSRStatus SufExecute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - protected: - int numerator_; - int denominator_; - int partition_id_; - int no_of_samples_; - std::shared_ptr shuffle_op_; - - private: - std::vector indices_; - SamplerType sampler_type_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_schema.h b/mindspore/ccsrc/mindrecord/include/shard_schema.h deleted file mode 100644 index 4ef134bde2..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_schema.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
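The ShardReader declarations above (Open, Launch, GetNext, Finish, Close) imply a simple consumption loop in row-reader mode. The sketch below is hypothetical and not runnable outside the original tree; in particular, treating an empty batch as end-of-data is an assumption drawn from the iterator-style API, not a documented contract.

```cpp
#include <iostream>
#include <string>
#include <vector>
// #include "mindrecord/include/shard_reader.h"  // the header deleted above

void ReadAllRowsSketch(const std::string &file) {
  mindspore::mindrecord::ShardReader reader;
  // Any one file of the dataset may be passed; load_dataset=true pulls in the rest.
  if (reader.Open({file}, /*load_dataset=*/true, /*n_consumer=*/4) !=
      mindspore::mindrecord::SUCCESS) {
    return;
  }
  reader.Launch();  // start the consumer threads
  while (true) {
    auto batch = reader.GetNext();  // tuples of (blob bytes, label json)
    if (batch.empty()) break;       // assumed end-of-data signal
    std::cout << "got " << batch.size() << " rows" << std::endl;
  }
  reader.Finish();  // join worker threads
  reader.Close();
}
```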
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ -#define MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_pybind.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class Schema { - public: - ~Schema() = default; - - /// \brief build a schema from its description, its json definition and its blob fields - /// \param[in] desc the description of the schema - /// \param[in] schema the schema's json - static std::shared_ptr<Schema> Build(std::string desc, const json &schema); - - /// \brief build a schema from its json definition and its description, for python - /// \param[in] desc the description of the schema - /// \param[in] schema the schema's json - static std::shared_ptr<Schema> Build(std::string desc, pybind11::handle schema); - - /// \brief compare two schemas to judge if they are equal - /// \param b another schema to be judged - /// \return true if they are equal, false if not - bool operator==(const Schema &b) const; - - /// \brief get the schema's description - /// \return the description - std::string GetDesc() const; - - /// \brief get the schema - /// \return the json format of the schema - json GetSchema() const; - - /// \brief get the schema and its description for python method - /// \return the python object of the schema and its description - pybind11::object GetSchemaForPython() const; - - /// set the schema id - /// \param[in] id the id to be set - void SetSchemaID(int64_t id); - - /// get the schema id - /// \return the int64 schema id - int64_t GetSchemaID() const; - - /// get the blob fields - /// \return the vector of blob fields - std::vector<std::string> GetBlobFields() const; - - private: - Schema() = default; - static bool ValidateNumberShape(const json &it_value); - static bool Validate(json schema); - static std::vector<std::string> PopulateBlobFields(json schema); - - std::string desc_; - json schema_; - std::vector<std::string> blob_fields_; - int64_t schema_id_ = -1; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SCHEMA_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_segment.h b/mindspore/ccsrc/mindrecord/include/shard_segment.h deleted file mode 100644 index 12497a5ace..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_segment.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
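A hedged usage sketch of Schema::Build as declared above. The field layout (names mapping to {"type": ...}) follows the usual MindRecord schema shape, and the assumption that a "bytes"-typed field ends up in GetBlobFields() is inferred from PopulateBlobFields, not verified here.

```cpp
#include <iostream>
#include <memory>
#include <nlohmann/json.hpp>
// #include "mindrecord/include/shard_schema.h"  // the header deleted above

void BuildSchemaSketch() {
  nlohmann::json def = {{"file_name", {{"type", "string"}}},
                        {"label", {{"type", "int64"}}},
                        {"data", {{"type", "bytes"}}}};
  auto schema = mindspore::mindrecord::Schema::Build("image classification", def);
  if (schema == nullptr) return;  // Build validates and is assumed to fail with nullptr
  for (const auto &blob : schema->GetBlobFields()) {
    std::cout << "blob field: " << blob << std::endl;  // expected: "data"
  }
}
```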
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ -#define MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_reader.h" - -namespace mindspore { -namespace mindrecord { -class ShardSegment : public ShardReader { - public: - ShardSegment(); - - ~ShardSegment() override = default; - - /// \brief Get candidate category fields - /// \return a list of fields names which are the candidates of category - std::pair> GetCategoryFields(); - - /// \brief Set category field - /// \param[in] category_field category name - /// \return true if category name is existed - MSRStatus SetCategoryField(std::string category_field); - - /// \brief Thread-safe implementation of ReadCategoryInfo - /// \return statistics data in json format with 2 field: "key" and "categories". - /// The value of "categories" is a list. Each Element in list is {count, id, name} - /// count: count of images in category - /// id: internal unique identification, persistent - /// name: category name - /// example: - /// { "key": "label", - /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, - /// { "count": 3, "id": 1, "name": "finance", } ] } - std::pair ReadCategoryInfo(); - - /// \brief Thread-safe implementation of ReadAtPageById - /// \param[in] category_id category ID - /// \param[in] page_no page number - /// \param[in] n_rows_of_page rows number in one page - /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageById(int64_t category_id, int64_t page_no, - int64_t n_rows_of_page); - - /// \brief Thread-safe implementation of ReadAtPageByName - /// \param[in] category_name category Name - /// \param[in] page_no page number - /// \param[in] n_rows_of_page rows number in one page - /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageByName(std::string category_name, int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); - - std::pair, pybind11::object>>> ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page); - - std::pair, pybind11::object>>> ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); - - std::pair> GetBlobFields(); - - private: - std::pair>> WrapCategoryInfo(); - - std::string ToJsonForCategory(const std::vector> &tri_vec); - - std::string CleanUp(std::string fieldName); - - std::pair> PackImages(int group_id, int shard_id, std::vector offset); - - std::vector candidate_category_fields_; - std::string current_category_field_; - const uint32_t kStartFieldId = 9; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h deleted file mode 100644 index a8ee3a36db..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_sequential_sample.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
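Following the comments above, category browsing with ShardSegment is a three-step affair: choose the category field, read the category summary, then page through one class. A hypothetical sketch (Open is inherited from ShardReader; all literal values are illustrative):

```cpp
#include <iostream>
#include <string>
// #include "mindrecord/include/shard_segment.h"  // the header deleted above

void BrowseByCategorySketch(const std::string &file) {
  mindspore::mindrecord::ShardSegment segment;
  if (segment.Open({file}, /*load_dataset=*/true) != mindspore::mindrecord::SUCCESS) return;
  segment.SetCategoryField("label");
  // {"key": "label", "categories": [{"count": ..., "id": ..., "name": ...}, ...]}
  auto info = segment.ReadCategoryInfo();
  std::cout << info.second << std::endl;
  // First page of up to 16 images for category id 0.
  auto images = segment.ReadAtPageById(/*category_id=*/0, /*page_no=*/0, /*n_rows_of_page=*/16);
  std::cout << "rows fetched: " << images.second.size() << std::endl;
}
```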
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ - -#include -#include -#include -#include -#include "mindrecord/include/shard_sample.h" - -namespace mindspore { -namespace mindrecord { -class ShardSequentialSample : public ShardSample { - public: - ShardSequentialSample(int n, int offset); - - ShardSequentialSample(float per, float per_offset); - - ~ShardSequentialSample() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - int offset_; - float per_; - float per_offset_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SEQUENTIAL_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h deleted file mode 100644 index adb172bdcc..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ -#define MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ - -#include -#include "mindrecord/include/shard_operator.h" - -namespace mindspore { -namespace mindrecord { -class ShardShuffle : public ShardOperator { - public: - explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); - - ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch, - ShuffleType shuffle_type = kShuffleSample); - - ~ShardShuffle() override{}; - - MSRStatus Execute(ShardTask &tasks) override; - - int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; - - private: - uint32_t shuffle_seed_; - int64_t no_of_samples_; - bool replacement_; - bool reshuffle_each_epoch_; - ShuffleType shuffle_type_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_SHUFFLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/mindrecord/include/shard_statistics.h deleted file mode 100644 index 7fc2f968cd..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_statistics.h +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#ifndef MINDRECORD_STATISTICS_H -#define MINDRECORD_STATISTICS_H - -#include -#include -#include -#include -#include - -#include "mindrecord/include/common/shard_pybind.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" -#include "pybind11/pybind11.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class Statistics { - public: - /// \brief save the statistic and its description - /// \param[in] desc the statistic's description - /// \param[in] statistics the statistic to be saved - static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics); - - /// \brief save the statistic from python and its description - /// \param[in] desc the statistic's description - /// \param[in] statistics the statistic to be saved - static std::shared_ptr<Statistics> Build(std::string desc, pybind11::handle statistics); - - ~Statistics() = default; - - /// \brief compare two statistics to judge if they are equal - /// \param b another statistics to be judged - /// \return true if they are equal, false if not - bool operator==(const Statistics &b) const; - - /// \brief get the description - /// \return the description - std::string GetDesc() const; - - /// \brief get the statistic - /// \return json format of the statistic - json GetStatistics() const; - - /// \brief get the statistic for python - /// \return the python object of statistics - pybind11::object GetStatisticsForPython() const; - - /// \brief set the statistics id - /// \param[in] id the id to be set - void SetStatisticsID(int64_t id); - - /// \brief get the statistics id - /// \return the int64 statistics id - int64_t GetStatisticsID() const; - - private: - /// \brief validate the statistic - /// \return true / false - static bool Validate(const json &statistics); - - static bool LevelRecursive(json level); - - Statistics() = default; - - std::string desc_; - json statistics_; - int64_t statistics_id_ = -1; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_STATISTICS_H diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h deleted file mode 100644 index 4a12eb9e45..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_TASK_H_ -#define MINDRECORD_INCLUDE_SHARD_TASK_H_ - -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" - -namespace mindspore { -namespace mindrecord { -class ShardTask { - public: - ShardTask(); - - ShardTask(const ShardTask &task); // copy construction - - ShardTask &operator=(const ShardTask &task); // assignment operator - - ~ShardTask() = default; - - void MakePerm(); - - void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector &offset, - const json &label); - - void InsertTask(std::tuple, std::vector, json> task); - - void PopBack(); - - uint32_t Size() const; - - uint32_t SizeOfRows() const; - - std::tuple, std::vector, json> &GetTaskByID(size_t id); - - std::tuple, std::vector, json> &GetRandomTask(); - - static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); - - uint32_t categories; - - std::vector permutation_; - - std::vector, std::vector, json>> task_list_; -}; -} // namespace mindrecord -} // namespace mindspore - -#endif // MINDRECORD_INCLUDE_SHARD_TASK_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h deleted file mode 100644 index 6175180c92..0000000000 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ /dev/null @@ -1,257 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
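The permutation_ member of ShardTask above hints at the standard indirection trick: the task list itself stays in place and shuffle operators reorder a vector of indices instead. A standalone model of MakePerm plus a shuffle (all names illustrative):

```cpp
#include <algorithm>
#include <iostream>
#include <numeric>
#include <random>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> task_list = {"row0", "row1", "row2", "row3"};
  std::vector<int> permutation(task_list.size());
  std::iota(permutation.begin(), permutation.end(), 0);  // identity permutation
  std::shuffle(permutation.begin(), permutation.end(), std::mt19937(42));
  for (int id : permutation) {
    std::cout << task_list[id] << std::endl;  // tasks visited in shuffled order
  }
  return 0;
}
```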
- */ - -#ifndef MINDRECORD_INCLUDE_SHARD_WRITER_H_ -#define MINDRECORD_INCLUDE_SHARD_WRITER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_column.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_index.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace mindrecord { -class ShardWriter { - public: - ShardWriter(); - - ~ShardWriter(); - - /// \brief Open file at the beginning - /// \param[in] paths the file names list - /// \param[in] append new data at the end of file if true, otherwise overwrite file - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &paths, bool append = false); - - /// \brief Open file at the ending - /// \param[in] paths the file names list - /// \return MSRStatus the status of MSRStatus - MSRStatus OpenForAppend(const std::string &path); - - /// \brief Write header to disk - /// \return MSRStatus the status of MSRStatus - MSRStatus Commit(); - - /// \brief Set file size - /// \param[in] header_size the size of header, only (1< header_data); - - /// \brief write raw data by group size - /// \param[in] raw_data the vector of raw json data, vector format - /// \param[in] blob_data the vector of image data - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); - - /// \brief write raw data by group size for call from python - /// \param[in] raw_data the vector of raw json data, python-handle format - /// \param[in] blob_data the vector of image data - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); - - /// \brief write raw data by group size for call from python - /// \param[in] raw_data the vector of raw json data, python-handle format - /// \param[in] blob_data the vector of blob json data, python-handle format - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true, - bool parallel_writer = false); - - private: - /// \brief write shard header data to disk - MSRStatus WriteShardHeader(); - - /// \brief erase error data - void DeleteErrorData(std::map> &raw_data, std::vector> &blob_data); - - /// \brief populate error data - void PopulateMutexErrorData(const int &row, const std::string &message, std::map &err_raw_data); - - /// \brief check data - void CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, - std::map &err_raw_data); - - /// \brief write shard header data to disk - std::tuple ValidateRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign); - - /// \brief fill data array in multiple thread run - void FillArray(int start, int end, std::map> &raw_data, - std::vector> &bin_data); - - /// \brief serialized raw data - MSRStatus SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, 
- private:
-  /// \brief write shard header data to disk
-  MSRStatus WriteShardHeader();
-
-  /// \brief erase error data
-  void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data);
-
-  /// \brief populate error data
-  void PopulateMutexErrorData(const int &row, const std::string &message, std::map<int, std::string> &err_raw_data);
-
-  /// \brief check data
-  void CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
-                      std::map<int, std::string> &err_raw_data);
-
-  /// \brief validate raw data
-  std::tuple<MSRStatus, int, int> ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
-                                                  std::vector<std::vector<uint8_t>> &blob_data, bool sign);
-
-  /// \brief fill data array in multiple thread run
-  void FillArray(int start, int end, std::map<uint64_t, std::vector<json>> &raw_data,
-                 std::vector<std::vector<uint8_t>> &bin_data);
-
-  /// \brief serialize raw data
-  MSRStatus SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data,
-                             std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count);
-
-  /// \brief write all data in parallel
-  MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
-                              const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief write data shard by shard
-  MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
-                         const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief break image data up into multiple row groups
-  MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
-                        std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page,
-                        const std::shared_ptr<Page> &last_blob_page);
-
-  /// \brief append partial blob data to previous page
-  MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
-                           const std::vector<std::pair<int, int>> &rows_in_group,
-                           const std::shared_ptr<Page> &last_blob_page);
-
-  /// \brief write new blob data page to disk
-  MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data,
-                        const std::vector<std::pair<int, int>> &rows_in_group,
-                        const std::shared_ptr<Page> &last_blob_page);
-
-  /// \brief shift last row group to next raw page for new appending
-  MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
-                         std::shared_ptr<Page> &last_raw_page);
-
-  /// \brief write raw data page to disk
-  MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
-                         std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief generate empty raw data page
-  void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
-
-  /// \brief append a row group at the end of raw page
-  MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group,
-                          const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page,
-                          const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief write blob chunk to disk
-  MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data,
-                           const std::pair<int, int> &blob_row);
-
-  /// \brief write raw chunk to disk
-  MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out,
-                          const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id,
-                          const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief break up into tasks by shard
-  std::vector<std::pair<int, int>> BreakIntoShards();
-
-  /// \brief calculate raw data size row by row
-  MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data);
-
-  /// \brief calculate blob data size row by row
-  MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data);
-
-  /// \brief populate last raw page pointer
-  void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page);
-
-  /// \brief populate last blob page pointer
-  void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page);
-
-  /// \brief check the data by schema
-  MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data);
-
-  /// \brief check the data and type
-  MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
-                                  std::map<int, std::string> &err_raw_data);
-
-  /// \brief Lock writer and save pages info
-  int LockWriter(bool parallel_writer = false);
-
-  /// \brief Unlock writer and save pages info
-  MSRStatus UnlockWriter(int fd, bool parallel_writer = false);
-
-  /// \brief Check raw data before writing
-  MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data,
-                                 bool sign, int *schema_count, int *row_count);
-
-  /// \brief Get full path from file name
-  MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths);
-
-  /// \brief Open files
-  MSRStatus OpenDataFiles(bool append);
-
-  /// \brief Remove lock file
-  MSRStatus RemoveLockFile();
-
-  /// \brief Initialize lock file
-  MSRStatus InitLockFile();
-
- private:
-  const std::string kLockFileSuffix = "_Locker";
-  const std::string kPageFileSuffix = "_Pages";
-  std::string lock_file_;   // lock file for parallel run
-  std::string pages_file_;  // temporary file of pages info for parallel run
-
-  int shard_count_;        // number of files
-  uint64_t header_size_;   // header size
-  uint64_t page_size_;     // page size
-  uint32_t row_count_;     // count of rows
-  uint32_t schema_count_;  // count of schemas
-
-  std::vector<uint64_t> raw_data_size_;   // Raw data size
-  std::vector<uint64_t> blob_data_size_;  // Blob data size
-
-  std::vector<std::string> file_paths_;                      // file paths
-  std::vector<std::shared_ptr<std::fstream>> file_streams_;  // file handles
-  std::shared_ptr<ShardHeader> shard_header_;                // shard header
-  std::shared_ptr<ShardColumn> shard_column_;                // shard columns
-
-  std::map<uint64_t, std::map<int, std::string>> err_mg_;  // used for storing error raw_data info
-
-  std::mutex check_mutex_;  // mutex for data check
-  std::atomic<bool> flag_{false};
-};
-}  // namespace mindrecord
-}  // namespace mindspore
-
-#endif  // MINDRECORD_INCLUDE_SHARD_WRITER_H_
diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
deleted file mode 100644
index 16c730bd4c..0000000000
--- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc
+++ /dev/null
@@ -1,626 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include
-
-#include "mindrecord/include/shard_index_generator.h"
-#include "common/utils.h"
-
-using mindspore::LogStream;
-using mindspore::ExceptionType::NoExceptionType;
-using mindspore::MsLogLevel::DEBUG;
-using mindspore::MsLogLevel::ERROR;
-using mindspore::MsLogLevel::INFO;
-
-namespace mindspore {
-namespace mindrecord {
-ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
-    : file_path_(file_path),
-      append_(append),
-      page_size_(0),
-      header_size_(0),
-      schema_count_(0),
-      task_(0),
-      write_success_(true) {}
-
-MSRStatus ShardIndexGenerator::Build() {
-  auto ret = ShardHeader::BuildSingleHeader(file_path_);
-  if (ret.first != SUCCESS) {
-    return FAILED;
-  }
-  auto json_header = ret.second;
-
-  auto ret2 = GetParentDir(file_path_);
-  if (SUCCESS != ret2.first) {
-    return FAILED;
-  }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
-  ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
-    return FAILED;
-  }
-  shard_header_ = header;
-  MS_LOG(INFO) << "Init header from mindrecord file for index successfully.";
-  return SUCCESS;
-}
-
-std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) {
-  if (field.empty()) {
-    MS_LOG(ERROR) << "The input field is None.";
-    return {FAILED, ""};
-  }
-
-  if (input.empty()) {
-    MS_LOG(ERROR) << "The input json is None.";
-    return {FAILED, ""};
-  }
-
-  // parameter input does not contain the field
-  if (input.find(field) == input.end()) {
-    MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input;
-    return {FAILED, ""};
-  }
-
-  // schema does not contain the field
-  auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"];
-  if (schema.find(field) == schema.end()) {
-    MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema;
-    return {FAILED, ""};
-  }
-
-  // field should be of scalar type
-  if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) {
-    MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable";
-    return {FAILED, ""};
-  }
-
-  if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) {
-    auto schema_field_options = schema[field];
-    if (schema_field_options.find("shape") == schema_field_options.end()) {
-      return {SUCCESS, input[field].dump()};
-    } else {
-      // field with shape option
-      MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable";
-      return {FAILED, ""};
-    }
-  }
-
-  // the field type is string here
-  return {SUCCESS, input[field].get<std::string>()};
-}
-
-std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
-  std::vector<std::string> field_name = StringSplit(field_path, kPoint);
-  for (uint64_t i = 0; i < field_name.size(); i++) {
-    if (i != field_name.size() - 1) {
-      // get type information from the json schema
-      schema = schema.at(field_name[i]);
-      schema = schema.at("properties");
-    } else {
-      // the standard root layer has "properties" if its type is "object"
-      if (schema.find("properties") != schema.end()) {
-        schema = schema.at("properties");
-      }
-      schema = schema.at(field_name[i]);
-      std::string field_type = schema.at("type").dump();
-      if (field_type.length() <= 2) {
-        return "";
-      } else {
-        return field_type.substr(1, field_type.length() - 2);
-      }
-    }
-  }
-  return "";
-}
-
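The next function, ConvertJsonToSQL, maps a JSON-schema type name to a SQLite column type through the kDbJsonMap table defined elsewhere in the mindrecord headers, falling back to TEXT for anything unknown. A self-contained sketch of that pattern (the map contents below are illustrative, not the real kDbJsonMap):

    #include <map>
    #include <string>

    // Unknown JSON type names degrade gracefully to TEXT, so index creation
    // never fails on an unmapped type.
    const std::map<std::string, std::string> kDemoDbJsonMap = {
        {"string", "TEXT"},     {"int32", "INTEGER"},   {"int64", "INTEGER"},
        {"float32", "NUMERIC"}, {"float64", "NUMERIC"}, {"number", "NUMERIC"}};

    std::string DemoConvertJsonToSQL(const std::string &json_type) {
      auto it = kDemoDbJsonMap.find(json_type);
      return it != kDemoDbJsonMap.end() ? it->second : "TEXT";
    }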
-std::string ShardIndexGenerator::ConvertJsonToSQL(const std::string &json) { - if (kDbJsonMap.find(json) != kDbJsonMap.end()) { - return kDbJsonMap.at(json); - } else { - return "TEXT"; - } -} - -int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char **az_col_name) { - for (auto i = 0; i < argc; i++) { - if (argv[i] != nullptr) { - MS_LOG(INFO) << az_col_name[i] << " = " << (argv[i] ? argv[i] : "nullptr"); - } - } - MS_LOG(INFO) << "\n"; - return 0; -} - -MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { - char *z_err_msg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Sql error: " << z_err_msg; - sqlite3_free(z_err_msg); - return FAILED; - } else { - if (!success_msg.empty()) { - MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; - } - sqlite3_free(z_err_msg); - return SUCCESS; - } -} - -std::pair ShardIndexGenerator::GenerateFieldName( - const std::pair &field) { - // Replaces dots and dashes with underscores for SQL use - std::string field_name = field.second; - // white list to avoid sql injection - std::replace_if( - field_name.begin(), field_name.end(), [](char x) { return (x == '-' || x == '.'); }, '_'); - auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { - return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); - }); - if (pos != field_name.end()) { - MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; - return {FAILED, ""}; - } - return {SUCCESS, field_name + "_" + std::to_string(field.first)}; -} - -std::pair ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { - sqlite3 *db = nullptr; - std::ifstream fin(common::SafeCStr(shard_address)); - if (!append_ && fin.good()) { - MS_LOG(ERROR) << "DB file already exist"; - fin.close(); - return {FAILED, nullptr}; - } - fin.close(); - int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); - if (rc) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return {FAILED, nullptr}; - } else { - MS_LOG(DEBUG) << "Opened database successfully"; - return {SUCCESS, db}; - } -} - -MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { - // create shard_name table - std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; - if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { - return FAILED; - } - sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return FAILED; - } - sql = "INSERT INTO SHARD_NAME (NAME) VALUES ('" + shard_name + "');"; - if (ExecuteSQL(sql, db, "insert name successfully.") != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { - std::string shard_address = shard_header_.GetShardAddressByID(shard_no); - if (shard_address.empty()) { - MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; - return {FAILED, nullptr}; - } - - string shard_name = GetFileName(shard_address).second; - shard_address += ".db"; - auto ret1 = CheckDatabase(shard_address); - if (ret1.first != SUCCESS) { - return {FAILED, nullptr}; - } - sqlite3 *db = ret1.second; - std::string sql = "DROP TABLE IF EXISTS INDEXES;"; - if (ExecuteSQL(sql, 
db, "drop table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } - sql = - "CREATE TABLE INDEXES(" - " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" - ", PAGE_OFFSET_RAW INT NOT NULL, PAGE_OFFSET_RAW_END INT NOT NULL" - ", ROW_GROUP_ID INT NOT NULL, PAGE_ID_BLOB INT NOT NULL" - ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; - - int field_no = 0; - for (const auto &field : fields_) { - uint64_t schema_id = field.first; - auto result = shard_header_.GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - return {FAILED, nullptr}; - } - json json_schema = (result.first->GetSchema())["schema"]; - std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, nullptr}; - } - sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; - } - sql += ", PRIMARY KEY(ROW_ID"; - for (uint64_t i = 0; i < fields_.size(); ++i) sql += ",INC_" + std::to_string(i); - sql += "));"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } - - if (CreateShardNameTable(db, shard_name) != SUCCESS) { - return {FAILED, nullptr}; - } - return {SUCCESS, db}; -} - -std::pair> ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, - std::fstream &in) { - std::vector schema_details; - if (schema_count_ <= kMaxSchemaCount) { - for (int sc = 0; sc < schema_count_; ++sc) { - std::vector schema_detail(schema_lens[sc]); - - auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return {FAILED, {}}; - } - - schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); - } - } - - return {SUCCESS, schema_details}; -} - -std::pair ShardIndexGenerator::GenerateRawSQL( - const std::vector> &fields) { - std::string sql = - "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," - "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; - - int field_no = 0; - for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; - } - sql += - ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," - ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; - field_no = 0; - for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; - } - sql += " )"; - return {SUCCESS, sql}; -} - -MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data) { - sqlite3_stmt *stmt = nullptr; - if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; - return FAILED; - } - for (auto &row : data) { - for (auto &field : row) { - const auto &place_holder = std::get<0>(field); - const auto &field_type = std::get<1>(field); - const auto &field_value = std::get<2>(field); - - int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder)); - if (field_type == "INTEGER") { - if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL 
error: could not bind parameter, index: " << index - << ", field value: " << std::stoll(field_value); - return FAILED; - } - } else if (field_type == "NUMERIC") { - if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index - << ", field value: " << std::stold(field_value); - return FAILED; - } - } else if (field_type == "NULL") { - if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; - return FAILED; - } - } else { - if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; - return FAILED; - } - } - } - if (sqlite3_step(stmt) != SQLITE_DONE) { - MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; - return FAILED; - } - (void)sqlite3_reset(stmt); - } - (void)sqlite3_finalize(stmt); - return SUCCESS; -} - -MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, - uint64_t &cur_blob_page_offset, std::fstream &in) { - row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); - - // blob data start - row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); - auto &io_seekg_blob = - in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); - if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - in.close(); - return FAILED; - } - - uint64_t image_size = 0; - - auto &io_read = in.read(reinterpret_cast(&image_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return FAILED; - } - - cur_blob_page_offset += (kInt64Len + image_size); - row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); - - return SUCCESS; -} - -void ShardIndexGenerator::AddIndexFieldByRawData( - const std::vector &schema_detail, std::vector> &row_data) { - auto result = GenerateIndexFields(schema_detail); - if (result.first == SUCCESS) { - int index = 0; - for (const auto &field : result.second) { - // assume simple field: string , number etc. 
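// Each indexed field expands into two bound parameters of the INSERT produced
// by GenerateRawSQL: an ":INC_<n>" counter column, fixed to "0" for the simple
// scalar fields handled here, and the ":<field>" value column, each carried as
// a (placeholder, SQLite type, value) tuple.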
- row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); - row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); - } - } -} - -ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, - int raw_page_id, std::fstream &in) { - std::vector>> full_data; - - // current raw data page - std::shared_ptr cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; - - // related blob page - vector> row_group_list = cur_raw_page->GetRowGroupIds(); - - // pair: row_group id, offset in raw data page - for (pair blob_ids : row_group_list) { - // get blob data page according to row_group id - std::shared_ptr cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first; - - // offset in current raw data page - auto cur_raw_page_offset = static_cast(blob_ids.second); - uint64_t cur_blob_page_offset = 0; - for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { - std::vector> row_data; - row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); - row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); - row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); - - // raw data start - row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); - - // calculate raw data end - auto &io_seekg = - in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - in.close(); - return {FAILED, {}}; - } - - std::vector schema_lens; - if (schema_count_ <= kMaxSchemaCount) { - for (int sc = 0; sc < schema_count_; sc++) { - uint64_t schema_size = 0; - - auto &io_read = in.read(reinterpret_cast(&schema_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - in.close(); - return {FAILED, {}}; - } - - cur_raw_page_offset += (kInt64Len + schema_size); - schema_lens.push_back(schema_size); - } - } - row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); - - // Getting schema for getting data for fields - auto st_schema_detail = GetSchemaDetails(schema_lens, in); - if (st_schema_detail.first != SUCCESS) { - return {FAILED, {}}; - } - - // start blob page info - if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { - return {FAILED, {}}; - } - - // start index field - AddIndexFieldByRawData(st_schema_detail.second, row_data); - full_data.push_back(std::move(row_data)); - } - } - return {SUCCESS, full_data}; -} - -INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail) { - std::vector> fields; - // index fields - std::vector> index_fields = shard_header_.GetFields(); - for (const auto &field : index_fields) { - if (field.first >= schema_detail.size()) { - return {FAILED, {}}; - } - auto field_value = GetValueByField(field.second, schema_detail[field.first]); - if (field_value.first != SUCCESS) { - MS_LOG(ERROR) << "Get value from json by field name failed"; - return {FAILED, {}}; - } - - auto result = shard_header_.GetSchemaByID(field.first); - if (result.second != SUCCESS) { - return {FAILED, {}}; - } - - std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); - auto ret = 
GenerateFieldName(field);
-    if (ret.first != SUCCESS) {
-      return {FAILED, {}};
-    }
-
-    fields.emplace_back(ret.second, field_type, field_value.second);
-  }
-  return {SUCCESS, std::move(fields)};
-}
-
-MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db,
-                                                  const std::vector<int> &raw_page_ids,
-                                                  const std::map<int, int> &blob_id_to_page_id) {
-  // Add index data to database
-  std::string shard_address = shard_header_.GetShardAddressByID(shard_no);
-  if (shard_address.empty()) {
-    MS_LOG(ERROR) << "Shard address is null";
-    return FAILED;
-  }
-
-  std::fstream in;
-  in.open(common::SafeCStr(shard_address), std::ios::in | std::ios::binary);
-  if (!in.good()) {
-    MS_LOG(ERROR) << "File could not be opened";
-    return FAILED;
-  }
-  (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
-  for (int raw_page_id : raw_page_ids) {
-    auto sql = GenerateRawSQL(fields_);
-    if (sql.first != SUCCESS) {
-      MS_LOG(ERROR) << "Generate raw SQL failed";
-      return FAILED;
-    }
-    auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in);
-    if (data.first != SUCCESS) {
-      MS_LOG(ERROR) << "Generate raw data failed";
-      return FAILED;
-    }
-    if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
-      MS_LOG(ERROR) << "Execute SQL failed";
-      return FAILED;
-    }
-    MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
-  }
-  (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr);
-  in.close();
-
-  // Close database
-  if (sqlite3_close(db.second) != SQLITE_OK) {
-    MS_LOG(ERROR) << "Close database failed";
-    return FAILED;
-  }
-  db.second = nullptr;
-  return SUCCESS;
-}
-
-MSRStatus ShardIndexGenerator::WriteToDatabase() {
-  fields_ = shard_header_.GetFields();
-  page_size_ = shard_header_.GetPageSize();
-  header_size_ = shard_header_.GetHeaderSize();
-  schema_count_ = shard_header_.GetSchemaCount();
-  if (shard_header_.GetShardCount() > kMaxShardCount) {
-    MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count: " << kMaxShardCount;
-    return FAILED;
-  }
-  task_ = 0;  // set two atomic vars to their initial values
-  write_success_ = true;
-
-  // spawn half the physical threads or the total number of shards, whichever is smaller
-  const unsigned int num_workers =
-    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount()));
-
-  std::vector<std::thread> threads;
-  threads.reserve(num_workers);
-
-  for (size_t t = 0; t < threads.capacity(); t++) {
-    threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
-  }
-
-  for (size_t t = 0; t < threads.capacity(); t++) {
-    threads[t].join();
-  }
-  return write_success_ ? SUCCESS : FAILED;
-}
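WriteToDatabase and DatabaseWriter together form a small lock-free worker pool: an atomic counter hands out shard ids, so each thread simply claims the next unprocessed shard until none remain. A stand-alone sketch of the same pattern (all names here are illustrative, not mindrecord API):

    #include <algorithm>
    #include <atomic>
    #include <iostream>
    #include <thread>
    #include <vector>

    int main() {
      const unsigned int shard_count = 8;
      std::atomic<int> next_shard{0};  // plays the role of task_
      std::atomic<bool> ok{true};      // plays the role of write_success_

      auto worker = [&] {
        // fetch-and-increment hands every thread a unique shard id
        for (int s = next_shard++; s < static_cast<int>(shard_count); s = next_shard++) {
          // ... build the per-shard index here; on failure set ok = false and return
        }
      };

      unsigned int n = std::min(std::thread::hardware_concurrency() / 2 + 1, shard_count);
      std::vector<std::thread> pool;
      for (unsigned int i = 0; i < n; ++i) pool.emplace_back(worker);
      for (auto &t : pool) t.join();
      std::cout << (ok ? "all shards indexed" : "indexing failed") << std::endl;
    }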
-
-void ShardIndexGenerator::DatabaseWriter() {
-  int shard_no = task_++;
-  while (shard_no < shard_header_.GetShardCount()) {
-    auto db = CreateDatabase(shard_no);
-    if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
-      write_success_ = false;
-      return;
-    }
-
-    MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
-
-    // Pre-processing page information
-    auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
-
-    std::map<int, int> blob_id_to_page_id;
-    std::vector<int> raw_page_ids;
-    for (uint64_t i = 0; i < total_pages; ++i) {
-      std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
-      if (cur_page->GetPageType() == "RAW_DATA") {
-        raw_page_ids.push_back(i);
-      } else if (cur_page->GetPageType() == "BLOB_DATA") {
-        blob_id_to_page_id[cur_page->GetPageTypeID()] = i;
-      }
-    }
-
-    if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
-      write_success_ = false;
-      return;
-    }
-    MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
-    shard_no = task_++;
-  }
-}
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc
deleted file mode 100644
index 99fa0c447d..0000000000
--- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc
+++ /dev/null
@@ -1,1449 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "mindrecord/include/shard_distributed_sample.h" -#include "mindrecord/include/shard_reader.h" -#include "common/utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::DEBUG; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -template -// convert the string to exactly number type (int32_t/int64_t/float/double) -Type StringToNum(const std::string &str) { - std::istringstream iss(str); - Type num; - iss >> num; - return num; -} - -ShardReader::ShardReader() { - task_id_ = 0; - deliver_id_ = 0; - shard_count_ = 0; - n_consumer_ = 0; - page_size_ = 0; - header_size_ = 0; - num_rows_ = 0; - row_id_ = 0; - num_blocks_ = 0; - block_reader_ = false; - num_padded_ = 0; -} - -std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { - if (!IsLegalFile(file_path)) { - return {FAILED, {}}; - } - auto ret = ShardHeader::BuildSingleHeader(file_path); - if (ret.first != SUCCESS) { - return {FAILED, {}}; - } - auto header = ret.second; - meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, - {"version", header["version"]}, {"index_fields", header["index_fields"]}, - {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; - return {SUCCESS, header["shard_addresses"]}; -} - -MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { - std::string file_path = file_paths[0]; - json first_meta_data = json(); - auto ret = GetMeta(file_path, first_meta_data); - if (ret.first != SUCCESS) { - return FAILED; - } - if (file_paths.size() == 1 && load_dataset == true) { - auto ret2 = GetParentDir(file_path); - if (SUCCESS != ret2.first) { - return FAILED; - } - std::vector real_addresses; - for (const auto &path : ret.second) { - std::string abs_path = ret2.second + string(path); - real_addresses.emplace_back(abs_path); - } - file_paths_ = real_addresses; - } else if (file_paths.size() >= 1 && load_dataset == false) { - file_paths_ = file_paths; - } else { - MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; - return FAILED; - } - for (const auto &file : file_paths_) { - json meta_data = json(); - auto ret1 = GetMeta(file, meta_data); - if (ret1.first != SUCCESS) { - return FAILED; - } - if (meta_data != first_meta_data) { - MS_LOG(ERROR) << "Mindrecord files meta information is different."; - return FAILED; - } - sqlite3 *db = nullptr; - // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it - int rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return FAILED; - } - MS_LOG(DEBUG) << "Opened database successfully"; - - string sql = "select NAME from SHARD_NAME;"; - std::vector> name; - char *errmsg = nullptr; - rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &name, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } else { - MS_LOG(DEBUG) << "Get " << static_cast(name.size()) << " records from index."; - string shardName = GetFileName(file).second; - if (name.empty() || name[0][0] != shardName) { - MS_LOG(ERROR) << "DB file can not match file " << file; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } 
- } - database_paths_.push_back(db); - } - ShardHeader sh = ShardHeader(); - if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { - return FAILED; - } - shard_header_ = std::make_shared(sh); - header_size_ = shard_header_->GetHeaderSize(); - page_size_ = shard_header_->GetPageSize(); - // version < 3.0 - if (first_meta_data["version"] < kVersion) { - shard_column_ = std::make_shared(shard_header_, false); - } else { - shard_column_ = std::make_shared(shard_header_, true); - } - num_rows_ = 0; - auto row_group_summary = ReadRowGroupSummary(); - for (const auto &rg : row_group_summary) { - num_rows_ += std::get<3>(rg); - } - - MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; - - return SUCCESS; -} - -MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { - vector inSchema(selected_columns.size(), 0); - for (auto &p : GetShardHeader()->GetSchemas()) { - auto schema = p->GetSchema()["schema"]; - for (unsigned int i = 0; i < selected_columns.size(); ++i) { - if (schema.find(selected_columns[i]) != schema.end()) { - inSchema[i] = 1; - } - } - } - if (std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; })) { - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardReader::Open() { - file_streams_.clear(); - - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - MS_LOG(INFO) << "Open shard file successfully."; - file_streams_.push_back(fs); - } - - return SUCCESS; -} - -MSRStatus ShardReader::Open(int n_consumer) { - file_streams_random_ = - std::vector>>(n_consumer, std::vector>()); - for (const auto &file : file_paths_) { - for (int j = 0; j < n_consumer; ++j) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - file_streams_random_[j].push_back(fs); - } - MS_LOG(INFO) << "Open shard file successfully."; - } - - return SUCCESS; -} - -void ShardReader::FileStreamsOperator() { - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; --i) { - if (file_streams_[i] != nullptr) { - file_streams_[i]->close(); - } - } - for (int i = static_cast(file_streams_random_.size()) - 1; i >= 0; --i) { - for (int j = static_cast(file_streams_random_[i].size()) - 1; j >= 0; --j) { - if (file_streams_random_[i][j] != nullptr) { - file_streams_random_[i][j]->close(); - } - } - } - for (int i = static_cast(database_paths_.size()) - 1; i >= 0; --i) { - if (database_paths_[i] != nullptr) { - auto ret = sqlite3_close(database_paths_[i]); - if (ret != SQLITE_OK) { - MS_LOG(ERROR) << "Close db failed. 
Error code: " << ret << "."; - } - database_paths_[i] = nullptr; - } - } -} - -ShardReader::~ShardReader() { Close(); } - -void ShardReader::Close() { - (void)Finish(); // interrupt reading and stop threads - FileStreamsOperator(); -} - -std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } - -std::shared_ptr ShardReader::GetShardColumn() const { return shard_column_; } - -int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } - -int ShardReader::GetNumRows() const { return num_rows_; } - -std::vector> ShardReader::ReadRowGroupSummary() { - std::vector> row_group_summary; - int shard_count = shard_header_->GetShardCount(); - if (shard_count <= 0) { - return row_group_summary; - } - if (shard_count <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count; ++shard_id) { - // return -1 when page's size equals to 0. - auto last_page_id = shard_header_->GetLastPageId(shard_id); - if (static_cast(last_page_id) == -1) { - continue; - } - for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { - const auto &page_t = shard_header_->GetPage(shard_id, page_id); - const auto &page = page_t.first; - if (page->GetPageType() != kPageTypeBlob) continue; - uint64_t start_row_id = page->GetStartRowID(); - if (start_row_id > page->GetEndRowID()) { - return std::vector>(); - } - uint64_t number_of_rows = page->GetEndRowID() - start_row_id; - row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); - } - } - } - return row_group_summary; -} - -MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, - std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, - std::vector> &column_values) { - for (int i = 0; i < static_cast(labels.size()); ++i) { - uint64_t group_id = std::stoull(labels[i][0]); - uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; - uint64_t offset_end = std::stoull(labels[i][2]); - offsets[shard_id].emplace_back( - std::vector{static_cast(shard_id), group_id, offset_start, offset_end}); - if (!all_in_index_) { - int raw_page_id = std::stoi(labels[i][3]); - uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len; - uint64_t label_end = std::stoull(labels[i][5]); - auto len = label_end - label_start; - auto label_raw = std::vector(len); - auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - fs->close(); - return FAILED; - } - - auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fs->close(); - return FAILED; - } - json label_json = json::from_msgpack(label_raw); - json tmp; - if (!columns.empty()) { - for (auto &col : columns) { - if (label_json.find(col) != label_json.end()) { - tmp[col] = label_json[col]; - } - } - } else { - tmp = label_json; - } - column_values[shard_id].emplace_back(tmp); - } else { - json construct_json; - for (unsigned int j = 0; j < columns.size(); ++j) { - // construct json "f1": value - auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; - - // convert the string to base type by schema - if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if 
(schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum(labels[i][j + 3]); - } else { - construct_json[columns[j]] = std::string(labels[i][j + 3]); - } - } - column_values[shard_id].emplace_back(construct_json); - } - } - - return SUCCESS; -} - -MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values) { - auto db = database_paths_[shard_id]; - std::vector> labels; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return FAILED; - } - MS_LOG(INFO) << "Get " << static_cast(labels.size()) << " records from shard " << shard_id << " index."; - - std::string file_name = file_paths_[shard_id]; - std::shared_ptr fs = std::make_shared(); - if (!all_in_index_) { - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return FAILED; - } - } - sqlite3_free(errmsg); - return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); -} - -MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { - std::map index_columns; - for (auto &field : GetShardHeader()->GetFields()) { - index_columns[field.second] = field.first; - } - if (index_columns.find(category_field) == index_columns.end()) { - MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; - return FAILED; - } - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); - if (SUCCESS != ret.first) { - return FAILED; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; - std::vector threads = std::vector(shard_count_); - for (int x = 0; x < shard_count_; x++) { - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); - } - - for (int x = 0; x < shard_count_; x++) { - threads[x].join(); - } - return SUCCESS; -} - -void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, - std::set &categories) { - if (nullptr == db) { - return; - } - std::vector> columns; - char *errmsg = nullptr; - int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); - if (ret != SQLITE_OK) { - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; - return; - } - MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; - std::lock_guard lck(shard_locker_); - for (int i = 0; i < static_cast(columns.size()); ++i) { - categories.emplace(columns[i][0]); - } -} - -ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { - std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; - std::vector>> offsets(shard_count_, std::vector>{}); - std::vector> column_values(shard_count_, std::vector{}); - if (all_in_index_) { - for (unsigned int i = 0; i < columns.size(); ++i) { - fields += ','; - auto ret = 
ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); - if (ret.first != SUCCESS) { - return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); - } - fields += ret.second; - } - } else { // fetch raw data from Raw page while some field is not index. - fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; - } - - std::string sql = "SELECT " + fields + " FROM INDEXES ORDER BY ROW_ID ;"; - - std::vector thread_read_db = std::vector(shard_count_); - for (int x = 0; x < shard_count_; x++) { - thread_read_db[x] = - std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); - } - - for (int x = 0; x < shard_count_; x++) { - thread_read_db[x].join(); - } - return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); -} - -ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; - std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); - - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); -} - -ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, - const std::pair &criteria, - const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - vector criteria_list{criteria.first}; - if (CheckColumnList(criteria_list) == FAILED) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; - std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); - - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); -} - -int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) { - auto *records = static_cast> *>(p_data); - if (num_fields > 0 && num_fields <= kMaxFieldCount) { - for (int i = 0; i < num_fields; ++i) - if (p_fields[i] == nullptr) p_fields[i] = const_cast(""); - } - records->emplace_back(p_fields, p_fields + num_fields); - return 0; -} - -std::vector> ShardReader::GetImageOffset(int page_id, int shard_id, - const std::pair &criteria) { - auto db = database_paths_[shard_id]; - - std::string sql = - "SELECT PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END FROM INDEXES WHERE PAGE_ID_BLOB = " + 
std::to_string(page_id); - - // whether use index search - if (!criteria.first.empty()) { - auto schema = shard_header_->GetSchemas()[0]->GetSchema(); - - // not number field should add '' in sql - if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { - sql += - " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; - } else { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + - criteria.second + "'"; - } - } - sql += ";"; - std::vector> image_offsets; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &image_offsets, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return std::vector>(); - } else { - MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << "records from index."; - } - std::vector> res; - for (int i = static_cast(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector{0, 0}); - for (int i = 0; i < static_cast(image_offsets.size()); i++) { - const auto &image_offset = image_offsets[i]; - res[i][0] = std::stoull(image_offset[0]) + kInt64Len; - res[i][1] = std::stoull(image_offset[1]); - } - sqlite3_free(errmsg); - return res; -} - -std::pair> ShardReader::GetBlobFields() { - std::vector blob_fields; - for (auto &p : GetShardHeader()->GetSchemas()) { - // assume one schema - const auto &fields = p->GetBlobFields(); - blob_fields.assign(fields.begin(), fields.end()); - break; - } - return std::make_pair(kCV, blob_fields); -} - -void ShardReader::CheckIfColumnInIndex(const std::vector &columns) { - // assume different schemas do not contain same key. 
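// all_in_index_ is the reader's fast-path switch: it stays true only when
// every requested column is covered by the SQLite index, in which case the
// values can be served straight from the .db file. Any uncovered column (or
// an empty column list) forces the slow path that seeks into the raw page of
// the .mindrecord file and msgpack-decodes the row (GetLabelsFromBinaryFile).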
- if (columns.empty()) { - all_in_index_ = false; - return; - } - for (auto &field : GetShardHeader()->GetFields()) { - column_schema_id_[field.second] = field.first; - } - for (auto &col : columns) { - if (column_schema_id_.find(col) == column_schema_id_.end()) { - all_in_index_ = false; - return; - } - } -} - -MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, - std::vector> &labels) { - sqlite3_stmt *stmt = nullptr; - if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not prepare statement"; - return FAILED; - } - int index = sqlite3_bind_parameter_index(stmt, ":criteria"); - if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << criteria; - return FAILED; - } - int rc = sqlite3_step(stmt); - while (rc != SQLITE_DONE) { - vector tmp; - int ncols = sqlite3_column_count(stmt); - for (int i = 0; i < ncols; i++) { - tmp.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, i))); - } - labels.push_back(tmp); - rc = sqlite3_step(stmt); - } - (void)sqlite3_finalize(stmt); - return SUCCESS; -} - -std::pair> ShardReader::GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets) { - std::string file_name = file_paths_[shard_id]; - std::vector res; - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file_name), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "File could not opened"; - return {FAILED, {}}; - } - - // init the return - for (unsigned int i = 0; i < label_offsets.size(); ++i) { - res.emplace_back(json{}); - } - - for (unsigned int i = 0; i < label_offsets.size(); ++i) { - const auto &labelOffset = label_offsets[i]; - uint64_t label_start = std::stoull(labelOffset[1]) + kInt64Len; - uint64_t label_end = std::stoull(labelOffset[2]); - int raw_page_id = std::stoi(labelOffset[0]); - auto len = label_end - label_start; - auto label_raw = std::vector(len); - auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - fs->close(); - return {FAILED, {}}; - } - - auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fs->close(); - return {FAILED, {}}; - } - - json label_json = json::from_msgpack(label_raw); - json tmp = label_json; - for (auto &col : columns) { - if (label_json.find(col) != label_json.end()) { - tmp[col] = label_json[col]; - } - } - res[i] = tmp; - } - return {SUCCESS, res}; -} - -std::pair> ShardReader::GetLabelsFromPage( - int page_id, int shard_id, const std::vector &columns, - const std::pair &criteria) { - // get page info from sqlite - auto db = database_paths_[shard_id]; - std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + - std::to_string(page_id); - std::vector> label_offsets; - if (!criteria.first.empty()) { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; - if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { - return {FAILED, {}}; - } - } else { - sql += ";"; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), 
SelectCallback, &label_offsets, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, {}}; - } - MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; - sqlite3_free(errmsg); - } - // get labels from binary file - return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); -} - -std::pair> ShardReader::GetLabels(int page_id, int shard_id, - const std::vector &columns, - const std::pair &criteria) { - if (all_in_index_) { - auto db = database_paths_[shard_id]; - std::string fields; - for (unsigned int i = 0; i < columns.size(); ++i) { - if (i > 0) fields += ','; - uint64_t schema_id = column_schema_id_[columns[i]]; - fields += columns[i] + "_" + std::to_string(schema_id); - } - if (fields.empty()) fields = "*"; - std::vector> labels; - std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); - if (!criteria.first.empty()) { - sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; - if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { - return {FAILED, {}}; - } - } else { - sql += ";"; - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, {}}; - } else { - MS_LOG(DEBUG) << "Get " << static_cast(labels.size()) << "records from index."; - } - sqlite3_free(errmsg); - } - std::vector ret; - for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); - for (unsigned int i = 0; i < labels.size(); ++i) { - json construct_json; - for (unsigned int j = 0; j < columns.size(); ++j) { - // construct json "f1": value - auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; - - // convert the string to base type by schema - if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); - } else { - construct_json[columns[j]] = std::string(labels[i][j]); - } - } - ret[i] = construct_json; - } - return {SUCCESS, ret}; - } - return GetLabelsFromPage(page_id, shard_id, columns, criteria); -} - -bool ResortRowGroups(std::tuple a, std::tuple b) { - return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); -} - -MSRStatus ShardReader::Finish() { - { - std::lock_guard lck(mtx_delivery_); - interrupt_ = true; - } - cv_delivery_.notify_all(); - - // Wait for all threads to finish - for (auto &i_thread : thread_set_) { - if (i_thread.joinable()) { - i_thread.join(); - } - } - return SUCCESS; -} - -int64_t ShardReader::GetNumClasses(const std::string &category_field) { - auto shard_count = file_paths_.size(); - auto index_fields = shard_header_->GetFields(); - - std::map map_schema_id_fields; - for (auto &field : index_fields) { - map_schema_id_fields[field.second] = field.first; - } - - if 
(map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) { - MS_LOG(ERROR) << "Field " << category_field << " does not exist."; - return -1; - } - auto ret = - ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); - if (SUCCESS != ret.first) { - return -1; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; - std::vector threads = std::vector(shard_count); - std::set categories; - for (int x = 0; x < shard_count; x++) { - sqlite3 *db = nullptr; - int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); - if (SQLITE_OK != rc) { - MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); - return -1; - } - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); - } - - for (int x = 0; x < shard_count; x++) { - threads[x].join(); - } - return categories.size(); -} - -MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &ops, int64_t *count, const int num_padded) { - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } - int64_t num_samples = num_rows_; - bool root = true; - std::stack> stack_ops; - std::shared_ptr op(ops); - while (op != nullptr) { - stack_ops.push(op); - op = op->GetChildOp(); - } - while (!stack_ops.empty()) { - op = stack_ops.top(); - stack_ops.pop(); - if (std::dynamic_pointer_cast(op)) { - num_samples = op->GetNumSamples(num_samples, 0); - if (num_padded > 0 && root == true) { - num_samples += num_padded; - MS_LOG(DEBUG) << "Padding samples work on shuffle sampler."; - root = false; - } - } else if (std::dynamic_pointer_cast(op)) { - auto category_op = std::dynamic_pointer_cast(op); - std::string category_field = category_op->GetCategoryField(); - auto num_classes = GetNumClasses(category_field); - num_samples = category_op->GetNumSamples(num_samples, num_classes); - } else if (std::dynamic_pointer_cast(op)) { - if (std::dynamic_pointer_cast(op)) { - auto sampler_op = std::dynamic_pointer_cast(op); - if (root == true) { - sampler_op->SetNumPaddedSamples(num_padded); - num_samples = op->GetNumSamples(num_samples, 0); - if (-1 == num_samples) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; - return FAILED; - } - root = false; - } - } else { - num_samples = op->GetNumSamples(num_samples, 0); - } - } else { - if (num_padded > 0) num_samples += num_padded; - } - } - *count = num_samples; - return SUCCESS; -} - -MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, - const std::vector &selected_columns, - const std::vector> &operators, const bool &block_reader, - int num_padded) { - // Open file and set header by ShardReader - auto ret = Init(file_paths, load_dataset); - if (SUCCESS != ret) { - return ret; - } - auto thread_limit = GetMaxThreadNum(); - if (n_consumer > thread_limit) { - n_consumer = thread_limit; - } - if (n_consumer < kMinConsumerCount) { - n_consumer = kMinConsumerCount; - } - vector blob_fields = GetBlobFields().second; - for (unsigned int i = 0; i < selected_columns.size(); ++i) { - if (!std::any_of(blob_fields.begin(), blob_fields.end(), - [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { - selected_columns_.push_back(selected_columns[i]); - } - } - selected_columns_ = selected_columns; - - if (CheckColumnList(selected_columns_) == FAILED) { - 
MS_LOG(ERROR) << "Illegal column list"; - return ILLEGAL_COLUMN_LIST; - } - - // Initialize argument - shard_count_ = static_cast(file_paths_.size()); - n_consumer_ = n_consumer; - num_padded_ = num_padded; - - operators_ = operators; - - if (block_reader) { - block_reader_ = true; - if (Open() == FAILED) { - return FAILED; - } - delivery_block_ = std::vector>, std::vector>>>( - kNumPageInBuffer, std::shared_ptr>, std::vector>>{}); - buf_ = std::vector>(kNumPageInBuffer, std::vector(page_size_)); - } else { - block_reader_ = false; - if (Open(n_consumer) == FAILED) { - return FAILED; - } - } - return SUCCESS; -} - -MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, - const std::vector &selected_columns, - const std::vector> &operators) { - // Open file and set header by ShardReader - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } - // should remove blob field from selected_columns when call from python - std::vector columns(selected_columns); - auto blob_fields = GetBlobFields().second; - for (auto &blob_field : blob_fields) { - auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); - if (it != selected_columns.end()) { - columns.erase(columns.begin() + std::distance(selected_columns.begin(), it)); - } - } - if (CheckColumnList(columns) == FAILED) { - MS_LOG(ERROR) << "Illegal column list"; - return FAILED; - } - if (Open(n_consumer) == FAILED) { - return FAILED; - } - // Initialize argument - shard_count_ = static_cast(file_paths_.size()); - n_consumer_ = n_consumer; - - // Initialize columns which will be read - selected_columns_ = selected_columns; - operators_ = operators; - - return SUCCESS; -} - -MSRStatus ShardReader::Launch(bool isSimpleReader) { - // Get all row groups' info - auto row_group_summary = ReadRowGroupSummary(); - - // Sort row group by (group_id, shard_id), prepare for parallel reading - std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); - if (CreateTasks(row_group_summary, operators_) != SUCCESS) { - MS_LOG(ERROR) << "Failed to launch read threads."; - interrupt_ = true; - return FAILED; - } - if (isSimpleReader) return SUCCESS; - // Start provider consumer threads - thread_set_ = std::vector(n_consumer_); - if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { - return FAILED; - } - - for (int x = 0; x < n_consumer_; ++x) { - if (block_reader_) { - thread_set_[x] = std::thread(&ShardReader::ConsumerByBlock, this, x); - } else { - thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); - } - } - - MS_LOG(INFO) << "Launch read thread successfully."; - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { - CheckIfColumnInIndex(selected_columns_); - for (const auto &rg : row_group_summary) { - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto n_Rows = std::get<3>(rg); - tasks_.InsertTask(TaskType::kCommonTask, shard_id, group_id, std::vector{n_Rows}, json{}); - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op) { - CheckIfColumnInIndex(selected_columns_); - auto category_op = std::dynamic_pointer_cast(op); - auto categories = category_op->GetCategories(); - int64_t num_elements = category_op->GetNumElements(); - if (num_elements <= 0) { - MS_LOG(ERROR) << "Parameter num_element is not positive"; - return FAILED; - } - if 
(categories.empty() == true) { - std::string category_field = category_op->GetCategoryField(); - int64_t num_categories = category_op->GetNumCategories(); - if (num_categories <= 0) { - MS_LOG(ERROR) << "Parameter num_categories is not positive"; - return FAILED; - } - std::set categories_set; - auto ret = GetAllClasses(category_field, categories_set); - if (SUCCESS != ret) { - return FAILED; - } - int i = 0; - for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { - categories.emplace_back(category_field, *it); - i++; - } - } - // Generate task list, a task will create a batch - std::vector categoryTasks(categories.size()); - for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { - int category_index = 0; - for (const auto &rg : row_group_summary) { - if (category_index >= num_elements) break; - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - - auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); - if (SUCCESS != std::get<0>(details)) { - return FAILED; - } - auto offsets = std::get<4>(details); - - auto number_of_rows = offsets.size(); - for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - if (category_index < num_elements) { - categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); - category_index++; - } - } - } - MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; - } - tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); - if (SUCCESS != (*category_op)(tasks_)) { - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators) { - CheckIfColumnInIndex(selected_columns_); - - auto ret = ReadAllRowGroup(selected_columns_); - if (std::get<0>(ret) != SUCCESS) { - return FAILED; - } - auto offsets = std::get<1>(ret); - auto local_columns = std::get<2>(ret); - if (shard_count_ <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count_; shard_id++) { - for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { - tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], - std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, - local_columns[shard_id][i]); - } - } - } else { - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators) { - if (block_reader_) { - if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { - return FAILED; - } - } else { - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) { - category_operator = static_cast(i); - break; - } - } - if (-1 == category_operator) { - if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { - return FAILED; - } - if (num_padded_ > 0) { - for (int i = 0; i < num_padded_; ++i) { - tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json()); - } - } - } else { - if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { - return FAILED; - } - } - } - - for (uint32_t operator_no = 0; operator_no < operators.size(); operator_no++) { - const auto &op = operators[operator_no]; - if (std::dynamic_pointer_cast(op)) continue; - if (block_reader_ && 
std::dynamic_pointer_cast(op)) continue; - if (SUCCESS != (*op)(tasks_)) { - return FAILED; - } - } - - if (tasks_.permutation_.empty()) tasks_.MakePerm(); - num_rows_ = block_reader_ ? tasks_.SizeOfRows() : tasks_.Size(); - num_blocks_ = block_reader_ ? tasks_.Size() : 0; - MS_LOG(INFO) << "Total rows is " << num_rows_; - return SUCCESS; -} - -TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) { - // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - - // Pick up task from task list - auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - - // check task type - auto task_type = std::get<0>(task); - if (task_type == TaskType::kPaddedTask) { - return std::make_pair(SUCCESS, - std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); - } - - auto shard_id = std::get<0>(std::get<1>(task)); - auto group_id = std::get<1>(std::get<1>(task)); - auto addr = std::get<2>(task); - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - const std::shared_ptr &page = ret.second; - - // Pack image list - std::vector images(addr[1] - addr[0]); - auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; - - auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - - auto &io_read = - file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast(&images[0]), addr[1] - addr[0]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::pair(TaskType::kCommonTask, std::vector, json>>())); - } - - // Deliver batch data to output map - std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(std::get<3>(task))); - - return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); -} - -MSRStatus ShardReader::ConsumerByRow(int consumer_id) { - // Set thread name -#if !defined(_WIN32) && !defined(_WIN64) - auto thread_id = kThreadName + std::to_string(consumer_id); - prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); -#endif - - // Loop forever - for (;;) { - int task_id = 0; - - // Get next task ID - task_id = task_id_++; - - // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { - return FAILED; - } - const auto &ret = ConsumerOneTask(task_id, consumer_id); - if (SUCCESS != ret.first) { - return FAILED; - } - const auto &batch = (ret.second).second; - // Hanging if maximum map size exceeded - // otherwise, set batch data in map - { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id <= deliver_id_ + kNumBatchInMap; }); - if (interrupt_) { - return SUCCESS; - } - delivery_map_[task_id] = std::make_shared, json>>>(std::move(batch)); - } - cv_iterator_.notify_one(); - } -} - -MSRStatus ShardReader::ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, - const int &buf_id) { - auto &io_seekg = 
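// NOTE: a shard file is laid out as [header | page 0 | page 1 | ...], each page being
// page_size_ bytes, so the absolute offset of a record is
//   header_size_ + page_size_ * page_id + addr[0]
// which is exactly the file_offset computed above before seeking.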
file_streams_[shard_id]->seekg(page_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf_[buf_id][0]), page_length); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { - // Set thread name -#if !defined(_WIN32) && !defined(_WIN64) - auto thread_id = kThreadName + std::to_string(consumer_id); - prctl(PR_SET_NAME, common::SafeCStr(thread_id), 0, 0, 0); -#endif - - // Loop forever - for (;;) { - int task_id = 0; - - // Get next task ID - task_id = task_id_++; - - // All tasks are done, either quit or repeat again - if (task_id >= num_blocks_) { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [this] { return interrupt_ || task_id_ < num_blocks_; }); - if (interrupt_) { - return SUCCESS; - } - continue; - } - - // Pick up task from task list - auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); - - auto shard_id = std::get<0>(std::get<1>(task)); - auto group_id = std::get<1>(std::get<1>(task)); - auto row_group_brief = ReadRowGroupBrief(group_id, shard_id, selected_columns_); - if (SUCCESS != std::get<0>(row_group_brief)) { - return FAILED; - } - auto page_length = std::get<2>(row_group_brief); - auto page_offset = std::get<3>(row_group_brief); - - MS_LOG(DEBUG) << "Block task " << task_id << tasks_.permutation_[task_id] << ", shard " << shard_id << ", group " - << group_id << ", page length " << page_length << ", page offset " << page_offset; - - // Deliver block data to output map - auto offset_and_labels = std::make_pair(std::get<4>(row_group_brief), std::get<5>(row_group_brief)); - - int deliver_id = deliver_id_; - // Hanging if maximum map size exceeded otherwise, set batch data in buffer - { - std::unique_lock lck(mtx_delivery_); - cv_delivery_.wait(lck, [task_id, this] { return interrupt_ || task_id < deliver_id_ + kNumPageInBuffer; }); - if (interrupt_) { - return SUCCESS; - } - } - - auto buf_id = task_id % kNumPageInBuffer; - delivery_block_[buf_id] = - std::make_shared>, std::vector>>(offset_and_labels); - - // Read blob - if (ReadBlob(shard_id, page_offset, page_length, buf_id) != SUCCESS) { - return FAILED; - } - - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.insert(task_id); - } - cv_iterator_.notify_one(); - } -} - -std::shared_ptr, json>>> ShardReader::GetRowFromBuffer(int buf_id, - int rowId) { - auto &blob_page = buf_[buf_id]; - auto &offsets = (*delivery_block_[buf_id]).first; - auto &labels = (*delivery_block_[buf_id]).second; - auto &addr_start = offsets[rowId][0]; - auto &addr_end = offsets[rowId][1]; - std::vector images(blob_page.begin() + addr_start, blob_page.begin() + addr_end); - std::vector, json>> batch; - batch.emplace_back(std::move(images), std::move(labels[rowId])); - return std::make_shared, json>>>(std::move(batch)); -} - -std::vector, json>> ShardReader::GetBlockNext() { - if (deliver_id_ >= num_blocks_) { - return std::vector, json>>(); - } - - if (row_id_ == 0) { - std::unique_lock lck(mtx_delivery_); - cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_block_set_.count(deliver_id_) > 0); }); - - if (interrupt_) { - return std::vector, json>>(); - } - } - auto buf_id = deliver_id_ % kNumPageInBuffer; - auto res = 
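// NOTE: delivery_block_ acts as a ring buffer with kNumPageInBuffer slots; the block for
// task deliver_id_ always lives in slot deliver_id_ % kNumPageInBuffer, and the
// cv_delivery_ wait in ConsumerByBlock keeps producers at most kNumPageInBuffer tasks
// ahead of this reader.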
GetRowFromBuffer(buf_id, row_id_); - - row_id_++; - if (row_id_ == (*delivery_block_[buf_id]).first.size()) { - row_id_ = 0; - { - std::unique_lock lck(mtx_delivery_); - delivery_block_set_.erase(deliver_id_++); - } - cv_delivery_.notify_all(); - } - - return *res; -} - -std::vector, json>> ShardReader::GetNext() { - if (interrupt_) { - return std::vector, json>>(); - } - if (block_reader_) return GetBlockNext(); - if (deliver_id_ >= static_cast(tasks_.Size())) { - return std::vector, json>>(); - } - - std::shared_ptr, json>>> res; - { - std::unique_lock lck(mtx_delivery_); - cv_iterator_.wait(lck, [this] { return interrupt_ || (delivery_map_.count(deliver_id_) > 0); }); - if (interrupt_) { - return std::vector, json>>(); - } - res = delivery_map_[deliver_id_]; - delivery_map_.erase(deliver_id_++); - } - - cv_delivery_.notify_all(); - - return *res; -} - -std::pair, json>>> ShardReader::GetNextById( - const int64_t &task_id, const int32_t &consumer_id) { - if (interrupt_) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); - } - if (block_reader_) { - return std::make_pair(TaskType::kCommonTask, GetBlockNext()); - } - const auto &ret = ConsumerOneTask(task_id, consumer_id); - if (SUCCESS != ret.first) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); - } - return std::move(ret.second); -} - -std::pair>> ShardReader::UnCompressBlob( - const std::vector &raw_blob_data) { - auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; - auto blob_fields = GetBlobFields().second; - std::vector> blob_data; - for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { - if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0; - auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; - return {FAILED, std::vector>(blob_fields.size(), std::vector())}; - } - if (data == nullptr) { - data = reinterpret_cast(data_ptr.get()); - } - std::vector column(data, data + (n_bytes / sizeof(unsigned char))); - blob_data.push_back(column); - } - return {SUCCESS, blob_data}; -} - -std::vector>, pybind11::object>> ShardReader::GetNextPy() { - auto res = GetNext(); - vector>, pybind11::object>> data; - std::transform(res.begin(), res.end(), std::back_inserter(data), - [this](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - auto ret = UnCompressBlob(std::get<0>(item)); - return std::make_tuple(ret.second, std::move(obj)); - }); - return data; -} - -void ShardReader::Reset() { - { - std::lock_guard lck(mtx_delivery_); - task_id_ = 0; - deliver_id_ = 0; - } - cv_delivery_.notify_all(); -} - -void ShardReader::ShuffleTask() { - if (block_reader_) return; - // exist shuffle and distributed sampler in ops, skip shuffle - bool has_sharding = false; - for (const auto &op : operators_) { - if (std::dynamic_pointer_cast(op)) { - has_sharding = true; - } - } - for (const auto &op : operators_) { - if (std::dynamic_pointer_cast(op) && has_sharding == false) { - if (SUCCESS != (*op)(tasks_)) { - MS_LOG(WARNING) << "Redo randomSampler failed."; - } - } else if (std::dynamic_pointer_cast(op)) { - if (SUCCESS != (*op)(tasks_)) { - 
MS_LOG(WARNING) << "Redo distributeSampler failed."; - } - } - } - if (tasks_.permutation_.empty()) tasks_.MakePerm(); -} - -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/mindrecord/io/shard_segment.cc deleted file mode 100644 index fb1120b178..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_segment.cc +++ /dev/null @@ -1,385 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_segment.h" -#include "common/utils.h" - -#include "./securec.h" -#include "mindrecord/include/common/shard_utils.h" -#include "pybind11/pybind11.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; -using mindspore::MsLogLevel::INFO; - -namespace mindspore { -namespace mindrecord { -ShardSegment::ShardSegment() { SetAllInIndex(false); } - -std::pair> ShardSegment::GetCategoryFields() { - // Skip if already populated - if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; - - std::string sql = "PRAGMA table_info(INDEXES);"; - std::vector> field_names; - - char *errmsg = nullptr; - int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(database_paths_[0]); - database_paths_[0] = nullptr; - return {FAILED, vector{}}; - } else { - MS_LOG(INFO) << "Get " << static_cast(field_names.size()) << " records from index."; - } - - uint32_t idx = kStartFieldId; - while (idx < field_names.size()) { - if (field_names[idx].size() < 2) { - sqlite3_free(errmsg); - sqlite3_close(database_paths_[0]); - database_paths_[0] = nullptr; - return {FAILED, vector{}}; - } - candidate_category_fields_.push_back(field_names[idx][1]); - idx += 2; - } - sqlite3_free(errmsg); - return {SUCCESS, candidate_category_fields_}; -} - -MSRStatus ShardSegment::SetCategoryField(std::string category_field) { - if (GetCategoryFields().first != SUCCESS) { - MS_LOG(ERROR) << "Get candidate category field failed"; - return FAILED; - } - category_field = category_field + "_0"; - if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), - [category_field](std::string x) { return x == category_field; })) { - current_category_field_ = category_field; - return SUCCESS; - } - MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; - return FAILED; -} - -std::pair ShardSegment::ReadCategoryInfo() { - MS_LOG(INFO) << "Read category begin"; - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info failed"; - return {FAILED, ""}; - } - // Convert category info to json string - auto category_json_string = ToJsonForCategory(ret.second); - - MS_LOG(INFO) << "Read category end"; - 
- return {SUCCESS, category_json_string}; -} - -std::pair>> ShardSegment::WrapCategoryInfo() { - std::map counter; - - std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + - ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; - - for (auto &db : database_paths_) { - std::vector> field_count; - - char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; - sqlite3_free(errmsg); - sqlite3_close(db); - db = nullptr; - return {FAILED, std::vector>()}; - } else { - MS_LOG(INFO) << "Get " << static_cast(field_count.size()) << " records from index."; - } - - for (const auto &field : field_count) { - counter[field[0]] += std::stoi(field[1]); - } - sqlite3_free(errmsg); - } - - int idx = 0; - std::vector> category_vec(counter.size()); - (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple item) { - return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); - }); - return {SUCCESS, std::move(category_vec)}; -} - -std::string ShardSegment::ToJsonForCategory(const std::vector> &tri_vec) { - std::vector category_json_vec; - for (auto q : tri_vec) { - json j; - j["id"] = std::get<0>(q); - j["name"] = std::get<1>(q); - j["count"] = std::get<2>(q); - - category_json_vec.emplace_back(j); - } - - json j_vec(category_json_vec); - json category_info; - category_info["key"] = current_category_field_; - category_info["categories"] = j_vec; - return category_info.dump(); -} - -std::pair>> ShardSegment::ReadAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - if (category_id >= static_cast(ret.second.size()) || category_id < 0) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); - // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || - page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector>{}}; - } - - std::vector> page; - auto row_group_summary = ReadRowGroupSummary(); - - uint64_t i_start = page_no * n_rows_of_page; - uint64_t i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); - uint64_t idx = 0; - for (const auto &rg : row_group_summary) { - if (idx >= i_end) break; - - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector>{}}; - } - auto offsets = std::get<4>(details); - uint64_t number_of_rows = offsets.size(); - if (idx + number_of_rows < i_start) { - idx += number_of_rows; - continue; - } - - for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { - if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector>{}}; - } - page.push_back(std::move(ret1.second)); - } 
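// NOTE: the paging window is [i_start, i_end) over the rows matching the category, e.g.
// page_no = 2 with n_rows_of_page = 10 selects matching rows 20..29; row groups that end
// before i_start are skipped wholesale by the `idx + number_of_rows < i_start` shortcut
// above.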
- } - } - - return {SUCCESS, std::move(page)}; -} - -std::pair> ShardSegment::PackImages(int group_id, int shard_id, - std::vector offset) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return {FAILED, std::vector()}; - } - const std::shared_ptr &blob_page = ret.second; - - // Pack image list - std::vector images(offset[1] - offset[0]); - auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; - auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; - } - - auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast(&images[0]), offset[1] - offset[0]); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; - } - - return {SUCCESS, std::move(images)}; -} - -std::pair>> ShardSegment::ReadAtPageByName(std::string category_name, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - for (const auto &categories : ret.second) { - if (std::get<1>(categories) == category_name) { - auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); - return result; - } - } - - return {FAILED, std::vector>()}; -} - -std::pair, json>>> ShardSegment::ReadAllAtPageById( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS || category_id >= static_cast(ret.second.size())) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector, json>>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); - // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector, json>>{}}; - } - - std::vector, json>> page; - auto row_group_summary = ReadRowGroupSummary(); - - int i_start = page_no * n_rows_of_page; - int i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); - int idx = 0; - for (const auto &rg : row_group_summary) { - if (idx >= i_end) break; - - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector, json>>{}}; - } - auto offsets = std::get<4>(details); - auto labels = std::get<5>(details); - - int number_of_rows = offsets.size(); - if (idx + number_of_rows < i_start) { - idx += number_of_rows; - continue; - } - - if (number_of_rows > static_cast(labels.size())) { - MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; - return {FAILED, std::vector, json>>{}}; - } - for (int i = 0; i < number_of_rows; ++i, ++idx) { - if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector, json>>{}}; - } - page.emplace_back(std::move(ret1.second), 
std::move(labels[i])); - } - } - } - return {SUCCESS, std::move(page)}; -} - -std::pair, json>>> ShardSegment::ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector, json>>{}}; - } - - // category_name to category_id - int64_t category_id = -1; - for (const auto &categories : ret.second) { - std::string categories_name = std::get<1>(categories); - - if (categories_name == category_name) { - category_id = std::get<0>(categories); - break; - } - } - - if (category_id == -1) { - return {FAILED, std::vector, json>>{}}; - } - - return ReadAllAtPageById(category_id, page_no, n_rows_of_page); -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; -} - -std::pair> ShardSegment::GetBlobFields() { - std::vector blob_fields; - for (auto &p : GetShardHeader()->GetSchemas()) { - // assume one schema - const auto &fields = p->GetBlobFields(); - blob_fields.assign(fields.begin(), fields.end()); - break; - } - return std::make_pair(kCV, blob_fields); -} - -std::string ShardSegment::CleanUp(std::string field_name) { - while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); - field_name.pop_back(); - return field_name; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc deleted file mode 100644 index 913caab550..0000000000 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ /dev/null @@ -1,1254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "mindrecord/include/shard_writer.h"
-#include "common/utils.h"
-#include "mindrecord/include/common/shard_utils.h"
-#include "./securec.h"
-
-using mindspore::LogStream;
-using mindspore::ExceptionType::NoExceptionType;
-using mindspore::MsLogLevel::DEBUG;
-using mindspore::MsLogLevel::ERROR;
-using mindspore::MsLogLevel::INFO;
-
-namespace mindspore {
-namespace mindrecord {
-ShardWriter::ShardWriter()
-    : shard_count_(1),
-      header_size_(kDefaultHeaderSize),
-      page_size_(kDefaultPageSize),
-      row_count_(0),
-      schema_count_(1) {}
-
-ShardWriter::~ShardWriter() {
-  for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
-    file_streams_[i]->close();
-  }
-}
-
-MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector<std::string> &paths) {
-  // Get full path from file name
-  for (const auto &path : paths) {
-    if (!CheckIsValidUtf8(path)) {
-      MS_LOG(ERROR) << "The filename contains invalid utf-8 data: " << path << ".";
-      return FAILED;
-    }
-    char resolved_path[PATH_MAX] = {0};
-    char buf[PATH_MAX] = {0};
-    if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) {
-      MS_LOG(ERROR) << "Secure func failed";
-      return FAILED;
-    }
-#if defined(_WIN32) || defined(_WIN64)
-    if (_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX) == nullptr) {
-      MS_LOG(ERROR) << "Invalid file path";
-      return FAILED;
-    }
-    if (_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX) == nullptr) {
-      MS_LOG(DEBUG) << "Path " << resolved_path;
-    }
-#else
-    if (realpath(dirname(&(buf[0])), resolved_path) == nullptr) {
-      MS_LOG(ERROR) << "Invalid file path";
-      return FAILED;
-    }
-    if (realpath(common::SafeCStr(path), resolved_path) == nullptr) {
-      MS_LOG(DEBUG) << "Path " << resolved_path;
-    }
-#endif
-    file_paths_.emplace_back(string(resolved_path));
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::OpenDataFiles(bool append) {
-  // Open files
-  for (const auto &file : file_paths_) {
-    std::shared_ptr<std::fstream> fs = std::make_shared<std::fstream>();
-    if (!append) {
-      // if not append and the mindrecord file exists, return FAILED
-      fs->open(common::SafeCStr(file), std::ios::in | std::ios::binary);
-      if (fs->good()) {
-        MS_LOG(ERROR) << "MindRecord file already exists.";
-        fs->close();
-        return FAILED;
-      }
-      fs->close();
-
-      // open the mindrecord file to write
-      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc);
-      if (!fs->good()) {
-        MS_LOG(ERROR) << "MindRecord file could not be opened.";
-        return FAILED;
-      }
-    } else {
-      // open the mindrecord file to append
-      fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary);
-      if (!fs->good()) {
-        MS_LOG(ERROR) << "MindRecord file could not be opened for append.";
-        return FAILED;
-      }
-    }
-    MS_LOG(INFO) << "Open shard file successfully.";
-    file_streams_.push_back(fs);
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::RemoveLockFile() {
-  // Remove temporary file
-  int ret = std::remove(pages_file_.c_str());
-  if (ret == 0) {
-    MS_LOG(DEBUG) << "Remove page file.";
-  }
-
-  ret = std::remove(lock_file_.c_str());
-  if (ret == 0) {
-    MS_LOG(DEBUG) << "Remove lock file.";
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::InitLockFile() {
-  if (file_paths_.size() == 0) {
-    MS_LOG(ERROR) << "File path not initialized.";
-    return FAILED;
-  }
-
-  lock_file_ = file_paths_[0] + kLockFileSuffix;
-  pages_file_ = file_paths_[0] + kPageFileSuffix;
-
-  if (RemoveLockFile() == FAILED) {
-    MS_LOG(ERROR) << "Remove file failed.";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::Open(const
    std::vector<std::string> &paths, bool append) {
-  shard_count_ = paths.size();
-  if (shard_count_ > kMaxShardCount || shard_count_ == 0) {
-    MS_LOG(ERROR) << "The shard count is greater than the max value or equal to 0.";
-    return FAILED;
-  }
-  if (schema_count_ > kMaxSchemaCount) {
-    MS_LOG(ERROR) << "The schema count is greater than the max value.";
-    return FAILED;
-  }
-
-  // Get full path from file name
-  if (GetFullPathFromFileName(paths) == FAILED) {
-    MS_LOG(ERROR) << "Get full path from file name failed.";
-    return FAILED;
-  }
-
-  // Open files
-  if (OpenDataFiles(append) == FAILED) {
-    MS_LOG(ERROR) << "Open data files failed.";
-    return FAILED;
-  }
-
-  // Init lock file
-  if (InitLockFile() == FAILED) {
-    MS_LOG(ERROR) << "Init lock file failed.";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
-  if (!IsLegalFile(path)) {
-    return FAILED;
-  }
-  auto ret1 = ShardHeader::BuildSingleHeader(path);
-  if (ret1.first != SUCCESS) {
-    return FAILED;
-  }
-  auto json_header = ret1.second;
-  auto ret2 = GetParentDir(path);
-  if (SUCCESS != ret2.first) {
-    return FAILED;
-  }
-  std::vector<std::string> real_addresses;
-  for (const auto &path : json_header["shard_addresses"]) {
-    std::string abs_path = ret2.second + string(path);
-    real_addresses.emplace_back(abs_path);
-  }
-  ShardHeader header = ShardHeader();
-  if (header.BuildDataset(real_addresses) == FAILED) {
-    return FAILED;
-  }
-  shard_header_ = std::make_shared<ShardHeader>(header);
-  MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize());
-  if (ret == FAILED) {
-    return FAILED;
-  }
-  ret = SetPageSize(shard_header_->GetPageSize());
-  if (ret == FAILED) {
-    return FAILED;
-  }
-  ret = Open(real_addresses, true);
-  if (ret == FAILED) {
-    MS_LOG(ERROR) << "Open file failed";
-    return FAILED;
-  }
-  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::Commit() {
-  // Read pages file
-  std::ifstream page_file(pages_file_.c_str());
-  if (page_file.good()) {
-    page_file.close();
-    if (shard_header_->FileToPages(pages_file_) == FAILED) {
-      MS_LOG(ERROR) << "Read pages from file failed";
-      return FAILED;
-    }
-  }
-
-  if (WriteShardHeader() == FAILED) {
-    MS_LOG(ERROR) << "Write metadata failed";
-    return FAILED;
-  }
-  MS_LOG(INFO) << "Write metadata successfully.";
-
-  // Remove lock file
-  if (RemoveLockFile() == FAILED) {
-    MS_LOG(ERROR) << "Remove lock file failed.";
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) {
-  MSRStatus ret = header_data->InitByFiles(file_paths_);
-  if (ret == FAILED) {
-    return FAILED;
-  }
-
-  // set fields in mindrecord when empty
-  std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields();
-  if (fields.empty()) {
-    MS_LOG(DEBUG) << "No index fields provided by user, auto-generating index fields.";
-    std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas();
-    for (const auto &schema : schemas) {
-      json jsonSchema = schema->GetSchema()["schema"];
-      for (const auto &el : jsonSchema.items()) {
-        if (el.value()["type"] == "string" ||
-            (el.value()["type"] == "int32" && el.value().find("shape") == el.value().end()) ||
-            (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) ||
-            (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) ||
-            (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) {
-          fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key()));
-        }
-      }
-    }
-    // skip if the schema holds only blob data
-    if (!fields.empty()) {
-      ret =
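// NOTE: `fields` was auto-generated above from the scalar schema entries (string, plus
// shape-less int32/int64/float32/float64), so even when the caller supplies no index
// fields, usable index columns are registered unless the schema is blob-only.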
header_data->AddIndexFields(fields);
-      if (ret == FAILED) {
-        MS_LOG(ERROR) << "Add index field failed";
-        return FAILED;
-      }
-    }
-  }
-
-  shard_header_ = header_data;
-  shard_header_->SetHeaderSize(header_size_);
-  shard_header_->SetPageSize(page_size_);
-  shard_column_ = std::make_shared<ShardColumn>(shard_header_);
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) {
-  // header_size [16KB, 128MB]
-  if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) {
-    MS_LOG(ERROR) << "Header size should be between 16KB and 128MB.";
-    return FAILED;
-  }
-  if (header_size % 4 != 0) {
-    MS_LOG(ERROR) << "Header size should be divisible by four.";
-    return FAILED;
-  }
-
-  header_size_ = header_size;
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) {
-  // PageSize [32KB, 256MB]
-  if (page_size < kMinPageSize || page_size > kMaxPageSize) {
-    MS_LOG(ERROR) << "Page size should be between 32KB and 256MB.";
-    return FAILED;
-  }
-  if (page_size % 4 != 0) {
-    MS_LOG(ERROR) << "Page size should be divisible by four.";
-    return FAILED;
-  }
-  page_size_ = page_size;
-  return SUCCESS;
-}
-
-void ShardWriter::DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data,
-                                  std::vector<std::vector<uint8_t>> &blob_data) {
-  // get wrong data location
-  std::set<int, std::greater<int>> delete_set;
-  for (auto &err_mg : err_mg_) {
-    uint64_t id = err_mg.first;
-    auto sub_err_mg = err_mg.second;
-    for (auto &subMg : sub_err_mg) {
-      int loc = subMg.first;
-      std::string message = subMg.second;
-      MS_LOG(ERROR) << "For schema " << id << ", the " << loc + 1 << "-th data is wrong: " << message;
-      (void)delete_set.insert(loc);
-    }
-  }
-
-  auto it = raw_data.begin();
-  if (delete_set.size() == it->second.size()) {
-    raw_data.clear();
-    blob_data.clear();
-    return;
-  }
-
-  // delete wrong raw data
-  for (auto &loc : delete_set) {
-    // delete row data
-    for (auto &raw : raw_data) {
-      (void)raw.second.erase(raw.second.begin() + loc);
-    }
-
-    // delete blob data
-    (void)blob_data.erase(blob_data.begin() + loc);
-  }
-}
-
-void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &message,
-                                         std::map<int, std::string> &err_raw_data) {
-  std::lock_guard<std::mutex> lock(check_mutex_);
-  (void)err_raw_data.insert(std::make_pair(row, message));
-}
-
-MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i,
-                                             std::map<int, std::string> &err_raw_data) {
-  auto data_type = std::string(value["type"].get<std::string>());
-
-  if ((data_type == "int32" && !data[key].is_number_integer()) ||
-      (data_type == "int64" && !data[key].is_number_integer()) ||
-      (data_type == "float32" && !data[key].is_number_float()) ||
-      (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) {
-    std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is not matched";
-    PopulateMutexErrorData(i, message, err_raw_data);
-    return FAILED;
-  }
-
-  if (data_type == "int32" && data[key].is_number_integer()) {
-    int64_t temp_value = data[key];
-    if (static_cast<int64_t>(temp_value) < static_cast<int64_t>(std::numeric_limits<int32_t>::min()) ||
-        static_cast<int64_t>(temp_value) > static_cast<int64_t>(std::numeric_limits<int32_t>::max())) {
-      std::string message =
-        "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is out of range";
-      PopulateMutexErrorData(i, message, err_raw_data);
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-
-void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector<json> &sub_raw_data,
-                                 std::map<int, std::string> &err_raw_data) {
-  if (start_row < 0 || start_row > end_row || end_row > static_cast<int>(sub_raw_data.size())) {
-    return;
-  }
-  for (int i = start_row; i < end_row; i++) {
-    json data = sub_raw_data[i];
-
-    for (auto iter = schema.begin(); iter != schema.end(); iter++) {
-      std::string key = iter.key();
-      json value = iter.value();
-      if (data.find(key) == data.end()) {
-        std::string message = "there is no '" + key + "' object in the raw data";
-        PopulateMutexErrorData(i, message, err_raw_data);
-        break;
-      }
-
-      if (value.size() == kInt2) {
-        // Skip the check since all shaped data is stored as blob
-        continue;
-      }
-
-      if (CheckDataTypeAndValue(key, value, data, i, err_raw_data) != SUCCESS) {
-        break;
-      }
-    }
-  }
-}
-
-MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &raw_data) {
-  auto rawdata_iter = raw_data.begin();
-
-  // make sure the raw data matches the schema
-  for (; rawdata_iter != raw_data.end(); ++rawdata_iter) {
-    // used for storing errors
-    std::map<int, std::string> sub_err_mg;
-    int schema_id = rawdata_iter->first;
-    auto result = shard_header_->GetSchemaByID(schema_id);
-    if (result.second != SUCCESS) {
-      return FAILED;
-    }
-    json schema = result.first->GetSchema()["schema"];
-    for (const auto &field : result.first->GetBlobFields()) {
-      (void)schema.erase(field);
-    }
-    std::vector<json> sub_raw_data = rawdata_iter->second;
-
-    // calculate start position and end position for each thread
-    int batch_size = rawdata_iter->second.size() / shard_count_;
-    int thread_num = shard_count_;
-    if (thread_num <= 0) {
-      return FAILED;
-    }
-    if (thread_num > kMaxThreadCount) {
-      thread_num = kMaxThreadCount;
-    }
-    std::vector<std::thread> thread_set(thread_num);
-
-    // start multiple threads
-    int start_row = 0, end_row = 0;
-    for (int x = 0; x < thread_num; ++x) {
-      if (x != thread_num - 1) {
-        start_row = batch_size * x;
-        end_row = batch_size * (x + 1);
-      } else {
-        start_row = batch_size * x;
-        end_row = rawdata_iter->second.size();
-      }
-      thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema,
-                                  std::ref(sub_raw_data), std::ref(sub_err_mg));
-    }
-    if (thread_num > kMaxThreadCount) {
-      return FAILED;
-    }
-    // Wait for threads to finish
-    for (int x = 0; x < thread_num; ++x) {
-      thread_set[x].join();
-    }
-
-    (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg));
-  }
-  return SUCCESS;
-}
-
-std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data,
-                                                             std::vector<std::vector<uint8_t>> &blob_data, bool sign) {
-  auto rawdata_iter = raw_data.begin();
-  schema_count_ = raw_data.size();
-  std::tuple<MSRStatus, int, int> failed(FAILED, 0, 0);
-  if (schema_count_ == 0) {
-    MS_LOG(ERROR) << "Data size is zero";
-    return failed;
-  }
-
-  // keep schema_id
-  std::set<int64_t> schema_ids;
-  row_count_ = (rawdata_iter->second).size();
-  MS_LOG(DEBUG) << "Schema count is " << schema_count_;
-
-  // Determine if the number of schemas is the same
-  if (shard_header_->GetSchemas().size() != schema_count_) {
-    MS_LOG(ERROR) << "Data size is not equal to the schema size";
-    return failed;
-  }
-
-  // Determine raw_data size == blob_data size
-  if (raw_data[0].size() != blob_data.size()) {
-    MS_LOG(ERROR) << "Raw data size is not equal to blob data size";
-    return failed;
-  }
-
-  // Determine whether the number of samples corresponding to each schema is the same
-  for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) {
-    if (row_count_ != rawdata_iter->second.size()) {
-      MS_LOG(ERROR) << "Data size is not equal";
-      return failed;
-    }
-    (void)schema_ids.insert(rawdata_iter->first);
-  }
-  const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas();
-  if (std::any_of(schemas.begin(),
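// NOTE: the any_of check that continues below rejects the input as soon as any schema
// declared in the shard header received no samples, i.e. every declared schema id must
// appear as a key of raw_data.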
schemas.end(), [schema_ids](const std::shared_ptr &schema) { - return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); - })) { - // There is not enough data which is not matching the number of schema - MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; - return failed; - } - - if (!sign) { - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; - } - - // check the data according the schema - if (CheckData(raw_data) != SUCCESS) { - MS_LOG(ERROR) << "Data validate check failed"; - return std::tuple(FAILED, schema_count_, row_count_); - } - - // delete wrong data from raw data - DeleteErrorData(raw_data, blob_data); - - // update raw count - row_count_ = row_count_ - err_mg_.begin()->second.size(); - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; -} - -void ShardWriter::FillArray(int start, int end, std::map> &raw_data, - std::vector> &bin_data) { - // Prevent excessive thread opening and cause cross-border - if (start >= end) { - flag_ = true; - return; - } - int schema_count = static_cast(raw_data.size()); - std::map>::const_iterator rawdata_iter; - for (int x = start; x < end; ++x) { - int cnt = 0; - for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { - const json &line = raw_data.at(rawdata_iter->first)[x]; - std::vector bline = json::to_msgpack(line); - - // Storage form is [Sample1-Schema1, Sample1-Schema2, Sample2-Schema1, Sample2-Schema2] - bin_data[x * schema_count + cnt] = bline; - cnt++; - } - } -} - -int ShardWriter::LockWriter(bool parallel_writer) { - if (!parallel_writer) { - return 0; - } - -#if defined(_WIN32) || defined(_WIN64) - MS_LOG(DEBUG) << "Lock file done by python."; - const int fd = 0; -#else - const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); - if (fd >= 0) { - flock(fd, LOCK_EX); - } else { - MS_LOG(ERROR) << "Shard writer failed when locking file"; - return -1; - } -#endif - - // Open files - file_streams_.clear(); - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return -1; - } - file_streams_.push_back(fs); - } - - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return -1; - } - return fd; -} - -MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { - if (!parallel_writer) { - return SUCCESS; - } - - if (shard_header_->PagesToFile(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Write pages to file failed"; - return FAILED; - } - - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { - file_streams_[i]->close(); - } - -#if defined(_WIN32) || defined(_WIN64) - MS_LOG(DEBUG) << "Unlock file done by python."; -#else - flock(fd, LOCK_UN); - close(fd); -#endif - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, - std::vector> &blob_data, bool sign, int *schema_count, - int *row_count) { - // check the free disk size - auto st_space = GetDiskSize(file_paths_[0], kFreeSize); - if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { - MS_LOG(ERROR) << "IO error / there is no free disk to be used"; - return FAILED; - } - - // compress blob - if (shard_column_->CheckCompressBlob()) { - for (auto &blob : blob_data) { - blob = shard_column_->CompressBlob(blob); - } - } - - // Add 4-bytes dummy blob data if no any blob fields 
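// NOTE: the 4-byte zero filler added below keeps every row's blob slot non-empty, so the
// page layout and offset bookkeeping stay uniform even for schemas without blob fields;
// symmetrically, a dummy raw-data entry (kDummyId) is inserted when all fields are blobs.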
- if (blob_data.size() == 0 && raw_data.size() > 0) { - blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); - } - - // Add dummy id if all are blob fields - if (blob_data.size() > 0 && raw_data.size() == 0) { - raw_data.insert(std::pair>(0, std::vector(blob_data.size(), kDummyId))); - } - - auto v = ValidateRawData(raw_data, blob_data, sign); - if (std::get<0>(v) == FAILED) { - MS_LOG(ERROR) << "Validate raw data failed"; - return FAILED; - } - *schema_count = std::get<1>(v); - *row_count = std::get<2>(v); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign, bool parallel_writer) { - // Lock Writer if loading data parallel - int fd = LockWriter(parallel_writer); - if (fd < 0) { - MS_LOG(ERROR) << "Lock writer failed"; - return FAILED; - } - - // Get the count of schemas and rows - int schema_count = 0; - int row_count = 0; - - // Serialize raw data - if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { - MS_LOG(ERROR) << "Check raw data failed"; - return FAILED; - } - - if (row_count == kInt0) { - MS_LOG(INFO) << "Raw data size is 0."; - return SUCCESS; - } - - std::vector> bin_raw_data(row_count * schema_count); - - // Serialize raw data - if (SerializeRawData(raw_data, bin_raw_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw data failed"; - return FAILED; - } - - // Set row size of raw data - if (SetRawDataSize(bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Set raw data size failed"; - return FAILED; - } - - // Set row size of blob data - if (SetBlobDataSize(blob_data) == FAILED) { - MS_LOG(ERROR) << "Set blob data size failed"; - return FAILED; - } - - // Write data to disk with multi threads - if (ParallelWriteData(blob_data, bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Parallel write data failed"; - return FAILED; - } - MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; - - if (UnlockWriter(fd, parallel_writer) == FAILED) { - MS_LOG(ERROR) << "Unlock writer failed"; - return FAILED; - } - - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign, - bool parallel_writer) { - std::map> raw_data_json; - std::map> blob_data_json; - - (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), - [](const std::pair> &pair) { - auto &py_raw_data = pair.second; - std::vector json_raw_data; - (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(json_raw_data)); - }); - - (void)std::transform(blob_data.begin(), blob_data.end(), std::inserter(blob_data_json, blob_data_json.end()), - [](const std::pair> &pair) { - auto &py_blob_data = pair.second; - std::vector jsonBlobData; - (void)std::transform(py_blob_data.begin(), py_blob_data.end(), - std::back_inserter(jsonBlobData), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(jsonBlobData)); - }); - - // Serialize blob page - auto blob_data_iter = blob_data.begin(); - auto schema_count = blob_data.size(); - auto row_count = blob_data_iter->second.size(); - - std::vector> bin_blob_data(row_count * schema_count); - // Serialize blob data - if (SerializeRawData(blob_data_json, bin_blob_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw 
data failed in write raw data";
-    return FAILED;
-  }
-  return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer);
-}
-
-MSRStatus ShardWriter::WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data,
-                                    vector<vector<uint8_t>> &blob_data, bool sign, bool parallel_writer) {
-  std::map<uint64_t, std::vector<json>> raw_data_json;
-  (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()),
-                       [](const std::pair<uint64_t, std::vector<py::handle>> &pair) {
-                         auto &py_raw_data = pair.second;
-                         std::vector<json> json_raw_data;
-                         (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data),
-                                              [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); });
-                         return std::make_pair(pair.first, std::move(json_raw_data));
-                       });
-  return WriteRawData(raw_data_json, blob_data, sign, parallel_writer);
-}
-
-MSRStatus ShardWriter::ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data,
-                                         const std::vector<std::vector<uint8_t>> &bin_raw_data) {
-  auto shards = BreakIntoShards();
-  // define the number of threads
-  int thread_num = static_cast<int>(shard_count_);
-  if (thread_num < 0) {
-    return FAILED;
-  }
-  if (thread_num > kMaxThreadCount) {
-    thread_num = kMaxThreadCount;
-  }
-  int left_thread = shard_count_;
-  int current_thread = 0;
-  while (left_thread) {
-    if (left_thread < thread_num) {
-      thread_num = left_thread;
-    }
-    // Start one thread for one shard
-    std::vector<std::thread> thread_set(thread_num);
-    if (thread_num <= kMaxThreadCount) {
-      for (int x = 0; x < thread_num; ++x) {
-        int start_row = shards[current_thread + x].first;
-        int end_row = shards[current_thread + x].second;
-        thread_set[x] = std::thread(&ShardWriter::WriteByShard, this, current_thread + x, start_row, end_row,
-                                    std::ref(blob_data), std::ref(bin_raw_data));
-      }
-      // Wait for threads to finish
-      for (int x = 0; x < thread_num; ++x) {
-        thread_set[x].join();
-      }
-      left_thread -= thread_num;
-      current_thread += thread_num;
-    }
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row,
-                                    const std::vector<std::vector<uint8_t>> &blob_data,
-                                    const std::vector<std::vector<uint8_t>> &bin_raw_data) {
-  MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row
-                << ", schema size: " << schema_count_;
-  if (start_row == end_row) {
-    return SUCCESS;
-  }
-  vector<std::pair<int, int>> rows_in_group;
-  std::shared_ptr<Page> last_raw_page = nullptr;
-  std::shared_ptr<Page> last_blob_page = nullptr;
-  SetLastRawPage(shard_id, last_raw_page);
-  SetLastBlobPage(shard_id, last_blob_page);
-
-  if (CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page) == FAILED) {
-    MS_LOG(ERROR) << "Cut row group failed";
-    return FAILED;
-  }
-
-  if (AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) {
-    MS_LOG(ERROR) << "Append blob page failed";
-    return FAILED;
-  }
-
-  if (NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) {
-    MS_LOG(ERROR) << "New blob page failed";
-    return FAILED;
-  }
-
-  if (ShiftRawPage(shard_id, rows_in_group, last_raw_page) == FAILED) {
-    MS_LOG(ERROR) << "Shift raw page failed";
-    return FAILED;
-  }
-
-  if (WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data) == FAILED) {
-    MS_LOG(ERROR) << "Write raw page failed";
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data,
-                                   std::vector<std::pair<int, int>> &rows_in_group,
-                                   const std::shared_ptr<Page> &last_raw_page,
-                                   const std::shared_ptr<Page> &last_blob_page) {
-  auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0;
-
-  auto last_raw_page_size = last_raw_page ?
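// NOTE (rough illustration): rows are packed greedily below; a new group starts whenever
// the accumulated blob bytes or raw bytes would overflow page_size_. With page_size_ of
// 64KB and blob rows of 30KB each, rows 0-1 share a group and row 2 opens the next one,
// since 30KB + 30KB + 30KB > 64KB.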
last_raw_page->GetPageSize() : 0; - auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; - auto n_byte_raw = last_raw_page_size - last_raw_offset; - - int page_start_row = start_row; - if (start_row > end_row) { - return FAILED; - } - if (end_row > static_cast(blob_data_size_.size()) || end_row > static_cast(raw_data_size_.size())) { - return FAILED; - } - for (int i = start_row; i < end_row; ++i) { - // n_byte_blob(0) indicate appendBlobPage - if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ || - n_byte_raw + raw_data_size_[i] > page_size_) { - rows_in_group.emplace_back(page_start_row, i); - page_start_row = i; - n_byte_blob = blob_data_size_[i]; - n_byte_raw = raw_data_size_[i]; - } else { - n_byte_blob += blob_data_size_[i]; - n_byte_raw += raw_data_size_[i]; - } - } - - // Not forget last one - rows_in_group.emplace_back(page_start_row, end_row); - return SUCCESS; -} - -MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { - auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; - - // Write disk - auto page_id = last_blob_page->GetPageID(); - auto bytes_page = last_blob_page->GetPageSize(); - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); - - // Update last blob page - bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); - last_blob_page->SetPageSize(bytes_page); - uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; - last_blob_page->SetEndRowID(end_row); - (void)shard_header_->SetPage(last_blob_page); - return SUCCESS; -} - -MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { - auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; - auto current_row = last_blob_page ? 
last_blob_page->GetEndRowID() : 0; - // index(0) indicate appendBlobPage - for (uint32_t i = 1; i < rows_in_group.size(); ++i) { - auto blob_row = rows_in_group[i]; - - // Write 1 blob page to disk - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); - // Create new page info for header - auto page_size = - std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); - std::vector> row_group_ids; - auto start_row = current_row; - auto end_row = start_row + blob_row.second - blob_row.first; - auto page = Page(++page_id, shard_id, kPageTypeBlob, ++page_type_id, start_row, end_row, row_group_ids, page_size); - (void)shard_header_->AddPage(std::make_shared(page)); - current_row = end_row; - } - return SUCCESS; -} - -MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page) { - auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; - auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; - if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + - last_raw_page_size <= - page_size_) { - return SUCCESS; - } - auto page_id = shard_header_->GetLastPageId(shard_id); - auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; - auto last_raw_page_id = last_raw_page->GetPageID(); - auto shift_size = last_raw_page_size - last_row_group_id_offset; - - std::vector buf(shift_size); - - // Read last row group from previous raw data page - if (shard_id < 0 || shard_id >= file_streams_.size()) { - return FAILED; - } - - auto &io_seekg = file_streams_[shard_id]->seekg( - page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf[0]), buf.size()); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - // Merge into new row group at new raw data page - auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&buf[0]), buf.size()); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - last_raw_page->DeleteLastGroupId(); - (void)shard_header_->SetPage(last_raw_page); - - // Refresh page info in header - int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; - std::vector> row_group_ids; - row_group_ids.emplace_back(row_group_id, 0); - int page_type_id = last_raw_page->GetPageID(); - auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); - (void)shard_header_->AddPage(std::make_shared(page)); - - // 
Reset: last raw page - SetLastRawPage(shard_id, last_raw_page); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page, - const std::vector> &bin_raw_data) { - int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; - for (uint32_t i = 0; i < rows_in_group.size(); ++i) { - const auto &blob_row = rows_in_group[i]; - if (blob_row.first == blob_row.second) continue; - auto raw_size = - std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); - if (!last_raw_page) { - EmptyRawPage(shard_id, last_raw_page); - } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { - (void)shard_header_->SetPage(last_raw_page); - EmptyRawPage(shard_id, last_raw_page); - } - if (AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data) != SUCCESS) { - return FAILED; - } - } - (void)shard_header_->SetPage(last_raw_page); - return SUCCESS; -} - -void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { - auto row_group_ids = std::vector>(); - auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; - auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); - (void)shard_header_->AddPage(std::make_shared(page)); - SetLastRawPage(shard_id, last_raw_page); -} - -MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, - const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, - const std::vector> &bin_raw_data) { - std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); - auto last_raw_page_id = last_raw_page->GetPageID(); - auto n_bytes = last_raw_page->GetPageSize(); - - // previous raw data page - auto &io_seekp = - file_streams_[shard_id]->seekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - if (chunk_id > 0) row_group_ids.emplace_back(++last_row_group_id, n_bytes); - n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first, - raw_data_size_.begin() + rows_in_group[chunk_id].second, 0); - (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); - - // Update previous raw data page - last_raw_page->SetPageSize(n_bytes); - last_raw_page->SetRowGroupIds(row_group_ids); - (void)shard_header_->SetPage(last_raw_page); - - return SUCCESS; -} - -MSRStatus ShardWriter::FlushBlobChunk(const std::shared_ptr &out, - const std::vector> &blob_data, - const std::pair &blob_row) { - if (blob_row.first > blob_row.second) { - return FAILED; - } - if (blob_row.second > static_cast(blob_data.size()) || blob_row.first < 0) { - return FAILED; - } - for (int j = blob_row.first; j < blob_row.second; ++j) { - // Write the size of blob - uint64_t line_len = blob_data[j].size(); - auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - - // Write the data of blob - auto line = blob_data[j]; - auto &io_handle_data = out->write(reinterpret_cast(&line[0]), line_len); - if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) { - 
MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - return SUCCESS; -} - -MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr &out, - const std::vector> &rows_in_group, const int &chunk_id, - const std::vector> &bin_raw_data) { - for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) { - // Write the size of multi schemas - for (uint32_t j = 0; j < schema_count_; ++j) { - uint64_t line_len = bin_raw_data[i * schema_count_ + j].size(); - auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - // Write the data of multi schemas - for (uint32_t j = 0; j < schema_count_; ++j) { - auto line = bin_raw_data[i * schema_count_ + j]; - auto &io_handle = out->write(reinterpret_cast(&line[0]), line.size()); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - out->close(); - return FAILED; - } - } - } - return SUCCESS; -} - -// Allocate data to shards evenly -std::vector> ShardWriter::BreakIntoShards() { - std::vector> shards; - int row_in_shard = row_count_ / shard_count_; - int remains = row_count_ % shard_count_; - - std::vector v_list(shard_count_); - std::iota(v_list.begin(), v_list.end(), 0); - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(v_list.begin(), v_list.end(), g); - std::unordered_set set(v_list.begin(), v_list.begin() + remains); - - if (shard_count_ <= kMaxShardCount) { - int start_row = 0; - for (int i = 0; i < shard_count_; ++i) { - int end_row = start_row + row_in_shard; - if (set.count(i)) end_row++; - shards.emplace_back(start_row, end_row); - start_row = end_row; - } - } - return shards; -} - -MSRStatus ShardWriter::WriteShardHeader() { - if (shard_header_ == nullptr) { - MS_LOG(ERROR) << "Shard header is null"; - return FAILED; - } - auto shard_header = shard_header_->SerializeHeader(); - // Write header data to multi files - if (shard_count_ > static_cast(file_streams_.size()) || shard_count_ > static_cast(shard_header.size())) { - return FAILED; - } - if (shard_count_ <= kMaxShardCount) { - for (int shard_id = 0; shard_id < shard_count_; ++shard_id) { - auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg); - if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - std::vector bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end()); - uint64_t line_len = bin_header.size(); - if (line_len + kInt64Len > header_size_) { - MS_LOG(ERROR) << "Shard header is too big"; - return FAILED; - } - - auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&line_len), kInt64Len); - if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - - auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast(&bin_header[0]), line_len); - if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) { - MS_LOG(ERROR) << "File write failed"; - file_streams_[shard_id]->close(); - return FAILED; - } - file_streams_[shard_id]->close(); - } - } - return SUCCESS; -} - -MSRStatus ShardWriter::SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, uint32_t row_count) { - // define the number of thread - uint32_t thread_num = 
std::thread::hardware_concurrency();
-  if (thread_num == 0) thread_num = kThreadNumber;
-  // Set the number of samples processed by each thread
-  int group_num = ceil(row_count * 1.0 / thread_num);
-  std::vector<std::thread> thread_set(thread_num);
-  int work_thread_num = 0;
-  for (uint32_t x = 0; x < thread_num; ++x) {
-    int start_num = x * group_num;
-    int end_num = ((x + 1) * group_num > row_count) ? row_count : (x + 1) * group_num;
-    if (start_num >= end_num) {
-      continue;
-    }
-    // Define the run boundary and start the child thread
-    thread_set[x] =
-      std::thread(&ShardWriter::FillArray, this, start_num, end_num, std::ref(raw_data), std::ref(bin_data));
-    work_thread_num++;
-  }
-  for (uint32_t x = 0; x < work_thread_num; ++x) {
-    // Block the main thread until every worker has finished
-    thread_set[x].join();
-  }
-  return flag_ == true ? FAILED : SUCCESS;
-}
-
-MSRStatus ShardWriter::SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data) {
-  raw_data_size_ = std::vector<uint64_t>(row_count_, 0);
-  for (uint32_t i = 0; i < row_count_; ++i) {
-    raw_data_size_[i] = std::accumulate(
-      bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0,
-      [](uint64_t accumulator, const std::vector<uint8_t> &row) { return accumulator + kInt64Len + row.size(); });
-  }
-  if (*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) > page_size_) {
-    MS_LOG(ERROR) << "Page size is too small to save a row!";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-MSRStatus ShardWriter::SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data) {
-  blob_data_size_ = std::vector<uint64_t>(row_count_);
-  (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(),
-                       [](const std::vector<uint8_t> &row) { return kInt64Len + row.size(); });
-  if (*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) > page_size_) {
-    MS_LOG(ERROR) << "Page size is too small to save a row!";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-void ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) {
-  // Get last raw page
-  auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw);
-  if (last_raw_page_id >= 0) {
-    auto page = shard_header_->GetPage(shard_id, last_raw_page_id);
-    last_raw_page = page.first;
-  }
-}
-
-void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page) {
-  // Get last blob page
-  auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob);
-  if (last_blob_page_id >= 0) {
-    auto page = shard_header_->GetPage(shard_id, last_blob_page_id);
-    last_blob_page = page.first;
-  }
-}
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc
deleted file mode 100644
index bd427a330a..0000000000
--- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
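(Editor's aside: SerializeRawData above, like ShardHeader::BuildDataset further down,
splits row_count items over the hardware thread count with a ceiling division. The
per-thread range rule, as a stand-alone sketch:)

  #include <algorithm>
  #include <utility>

  // Half-open range of items handled by thread x out of thread_num threads.
  std::pair<int, int> ThreadRange(int x, int item_count, int thread_num) {
    int group_num = (item_count + thread_num - 1) / thread_num;  // ceil division
    int start = x * group_num;
    int end = std::min((x + 1) * group_num, item_count);
    return {start, end};  // start >= end means this thread stays idle
  }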
- */ - -#include "mindrecord/include/shard_category.h" - -namespace mindspore { -namespace mindrecord { -ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, - bool replacement) - : categories_(categories), - category_field_(""), - num_elements_(num_elements), - num_categories_(0), - replacement_(replacement) {} - -ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, - bool replacement) - : categories_({}), - category_field_(category_field), - num_elements_(num_elements), - num_categories_(num_categories), - replacement_(replacement) {} - -MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } - -int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (dataset_size == 0) return dataset_size; - if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) { - return std::min(num_categories_, num_classes) * num_elements_; - } - return 0; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/mindrecord/meta/shard_column.cc deleted file mode 100644 index 28dc243e17..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_column.cc +++ /dev/null @@ -1,496 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
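(Editor's aside, a worked example for ShardCategory::GetNumSamples above: with
num_elements_ = 5, num_categories_ = 10 and num_classes = 4, the sampler yields
min(10, 4) * 5 = 20 samples; a dataset_size of 0 short-circuits to 0, and any
non-positive operand falls through to the final return 0.)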
- */ - -#include "mindrecord/include/shard_column.h" - -#include "common/utils.h" -#include "mindrecord/include/common/shard_utils.h" -#include "mindrecord/include/shard_error.h" - -namespace mindspore { -namespace mindrecord { -ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool compress_integer) { - auto first_schema = shard_header->GetSchemas()[0]; - auto schema = first_schema->GetSchema()["schema"]; - - bool has_integer_array = false; - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - const std::string &column_name = it.key(); - column_name_.push_back(column_name); - - json it_value = it.value(); - - std::string str_type = it_value["type"]; - column_data_type_.push_back(ColumnDataTypeMap.at(str_type)); - if (it_value.find("shape") != it_value.end()) { - std::vector vec(it_value["shape"].size()); - std::copy(it_value["shape"].begin(), it_value["shape"].end(), vec.begin()); - column_shape_.push_back(vec); - if (str_type == "int32" || str_type == "int64") { - has_integer_array = true; - } - } else { - std::vector vec = {}; - column_shape_.push_back(vec); - } - } - - for (uint64_t i = 0; i < column_name_.size(); i++) { - column_name_id_[column_name_[i]] = i; - } - - auto blob_fields = first_schema->GetBlobFields(); - - for (const auto &field : blob_fields) { - blob_column_.push_back(field); - } - - for (uint64_t i = 0; i < blob_column_.size(); i++) { - blob_column_id_[blob_column_[i]] = i; - } - - has_compress_blob_ = (compress_integer && has_integer_array); - num_blob_column_ = blob_column_.size(); -} - -std::pair ShardColumn::GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape) { - // Skip if column not found - auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return {FAILED, ColumnNotFound}; - } - - // Get data type and size - auto column_id = column_name_id_[column_name]; - *column_data_type = column_data_type_[column_id]; - *column_data_type_size = ColumnDataTypeSize[*column_data_type]; - *column_shape = column_shape_[column_id]; - - return {SUCCESS, column_category}; -} - -MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape) { - // Skip if column not found - auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return FAILED; - } - - // Get data type and size - auto column_id = column_name_id_[column_name]; - *column_data_type = column_data_type_[column_id]; - *column_data_type_size = ColumnDataTypeSize[*column_data_type]; - *column_shape = column_shape_[column_id]; - - // Retrieve value from json - if (column_category == ColumnInRaw) { - if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; - return FAILED; - } - *data = reinterpret_cast(data_ptr->get()); - return SUCCESS; - } - - // Retrieve value from blob - if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; - return FAILED; - } - if (*data == nullptr) { - *data = reinterpret_cast(data_ptr->get()); - } - 
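(Editor's aside: to make the constructor's schema traversal above concrete, a
hypothetical schema JSON of the shape it expects,

  {"label": {"type": "int32"},
   "data": {"type": "int64", "shape": [-1]},
   "file_name": {"type": "string"}}

puts "data" among the blob columns: its "shape" key marks it as an array, and its
int64 type sets has_integer_array, enabling compression when compress_integer is on.
"label" and "file_name" have no shape, so they stay retrievable from the raw/JSON
side via GetColumnFromJson.)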
return SUCCESS; -} - -MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes) { - auto column_id = column_name_id_[column_name]; - auto column_data_type = column_data_type_[column_id]; - - // Initialize num bytes - *n_bytes = ColumnDataTypeSize[column_data_type]; - auto json_column_value = columns_json[column_name]; - switch (column_data_type) { - case ColumnFloat32: { - return GetFloat(data_ptr, json_column_value, false); - } - case ColumnFloat64: { - return GetFloat(data_ptr, json_column_value, true); - } - case ColumnInt32: { - return GetInt(data_ptr, json_column_value); - } - case ColumnInt64: { - return GetInt(data_ptr, json_column_value); - } - default: { - // Convert string to c_str - std::string tmp_string = json_column_value; - *n_bytes = tmp_string.size(); - auto data = reinterpret_cast(common::SafeCStr(tmp_string)); - *data_ptr = std::make_unique(*n_bytes); - for (uint32_t i = 0; i < *n_bytes; i++) { - (*data_ptr)[i] = *(data + i); - } - break; - } - } - return SUCCESS; -} - -template -MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, - bool use_double) { - std::unique_ptr array_data = std::make_unique(1); - if (!json_column_value.is_string() && !json_column_value.is_number()) { - MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; - return FAILED; - } - if (json_column_value.is_number()) { - array_data[0] = json_column_value; - } else { - // Convert string to float - try { - if (use_double) { - array_data[0] = json_column_value.get(); - } else { - array_data[0] = json_column_value.get(); - } - } catch (json::exception &e) { - MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; - return FAILED; - } - } - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(sizeof(T)); - for (uint32_t i = 0; i < sizeof(T); i++) { - (*data_ptr)[i] = *(data + i); - } - - return SUCCESS; -} - -template -MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { - std::unique_ptr array_data = std::make_unique(1); - int64_t temp_value; - bool less_than_zero = false; - - if (json_column_value.is_number_integer()) { - const json json_zero = 0; - if (json_column_value < json_zero) less_than_zero = true; - temp_value = json_column_value; - } else if (json_column_value.is_string()) { - std::string string_value = json_column_value; - - if (!string_value.empty() && string_value[0] == '-') { - try { - temp_value = std::stoll(string_value); - less_than_zero = true; - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; - } - } else { - try { - temp_value = static_cast(std::stoull(string_value)); - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; - } - } - } else { - MS_LOG(ERROR) << "Conversion to int failed."; - return FAILED; - } - - if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || - (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { - MS_LOG(ERROR) << "Conversion to int failed. 
Out of range"; - return FAILED; - } - array_data[0] = static_cast(temp_value); - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(sizeof(T)); - for (uint32_t i = 0; i < sizeof(T); i++) { - (*data_ptr)[i] = *(data + i); - } - - return SUCCESS; -} - -MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes) { - uint64_t offset_address = 0; - auto column_id = column_name_id_[column_name]; - if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { - return FAILED; - } - - auto column_data_type = column_data_type_[column_id]; - if (has_compress_blob_ && column_data_type == ColumnInt32) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } - } else if (has_compress_blob_ && column_data_type == ColumnInt64) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } - } else { - *data = reinterpret_cast(&(columns_blob[offset_address])); - } - - return SUCCESS; -} - -ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { - auto it_column = column_name_id_.find(column_name); - if (it_column == column_name_id_.end()) { - return ColumnNotFound; - } - auto it_blob = blob_column_id_.find(column_name); - return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; -} - -std::vector ShardColumn::CompressBlob(const std::vector &blob) { - // Skip if no compress columns - if (!CheckCompressBlob()) return blob; - - std::vector dst_blob; - uint64_t i_src = 0; - for (int64_t i = 0; i < num_blob_column_; i++) { - // Get column data type - auto src_data_type = column_data_type_[column_name_id_[blob_column_[i]]]; - auto int_type = src_data_type == ColumnInt32 ? 
kInt32Type : kInt64Type;
-
-    // Compress and return if the blob has only 1 column
-    if (num_blob_column_ == 1) {
-      return CompressInt(blob, int_type);
-    }
-
-    // Just copy and continue if column data type is not int32/int64
-    uint64_t num_bytes = BytesBigToUInt64(blob, i_src, kInt64Type);
-    if (src_data_type != ColumnInt32 && src_data_type != ColumnInt64) {
-      dst_blob.insert(dst_blob.end(), blob.begin() + i_src, blob.begin() + i_src + kInt64Len + num_bytes);
-      i_src += kInt64Len + num_bytes;
-      continue;
-    }
-
-    // Get column slice in source blob
-    std::vector<uint8_t> blob_slice(blob.begin() + i_src + kInt64Len, blob.begin() + i_src + kInt64Len + num_bytes);
-    // Compress column
-    auto dst_blob_slice = CompressInt(blob_slice, int_type);
-    // Get new column size
-    auto new_blob_size = UIntToBytesBig(dst_blob_slice.size(), kInt64Type);
-    // Append new column size
-    dst_blob.insert(dst_blob.end(), new_blob_size.begin(), new_blob_size.end());
-    // Append new column data
-    dst_blob.insert(dst_blob.end(), dst_blob_slice.begin(), dst_blob_slice.end());
-    i_src += kInt64Len + num_bytes;
-  }
-  MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
-  return dst_blob;
-}
-
-vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type) {
-  uint64_t i_size = kUnsignedOne << static_cast<uint8_t>(int_type);
-  // Get number of elements
-  uint64_t src_n_int = src_bytes.size() / i_size;
-  // Calculate bitmap size (bytes)
-  uint64_t bitmap_size = (src_n_int + kNumDataOfByte - 1) / kNumDataOfByte;
-
-  // Initialize destination blob, more space than needed, will be resized
-  vector<uint8_t> dst_bytes(kBytesOfColumnLen + bitmap_size + src_bytes.size(), 0);
-
-  // Write number of elements to destination blob
-  vector<uint8_t> size_by_bytes = UIntToBytesBig(src_n_int, kInt32Type);
-  for (uint64_t n = 0; n < kBytesOfColumnLen; n++) {
-    dst_bytes[n] = size_by_bytes[n];
-  }
-
-  // Write compressed int
-  uint64_t i_dst = kBytesOfColumnLen + bitmap_size;
-  for (uint64_t i = 0; i < src_n_int; i++) {
-    // Initialize destination data type
-    IntegerType dst_int_type = kInt8Type;
-    // Shift to next int position
-    uint64_t pos = i * (kUnsignedOne << static_cast<uint8_t>(int_type));
-    // Narrow down this int
-    int64_t i_n = BytesLittleToMinIntType(src_bytes, pos, int_type, &dst_int_type);
-
-    // Write this int to destination blob
-    uint64_t u_n = *reinterpret_cast<uint64_t *>(&i_n);
-    auto temp_bytes = UIntToBytesLittle(u_n, dst_int_type);
-    for (uint64_t j = 0; j < (kUnsignedOne << static_cast<uint8_t>(dst_int_type)); j++) {
-      dst_bytes[i_dst++] = temp_bytes[j];
-    }
-
-    // Update data type in bitmap
-    dst_bytes[i / kNumDataOfByte + kBytesOfColumnLen] |=
-      (static_cast<uint8_t>(dst_int_type) << (kDataTypeBits * (kNumDataOfByte - kUnsignedOne - (i % kNumDataOfByte))));
-  }
-  // Resize destination blob
-  dst_bytes.resize(i_dst);
-  MS_LOG(DEBUG) << "Compress blob field from " << src_bytes.size() << " to " << dst_bytes.size() << ".";
-  return dst_bytes;
-}
-
-MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
-                                               uint64_t *num_bytes, uint64_t *shift_idx) {
-  if (num_blob_column_ == 1) {
-    *num_bytes = columns_blob.size();
-    *shift_idx = 0;
-    return SUCCESS;
-  }
-  auto blob_id = blob_column_id_[column_name_[column_id]];
-
-  for (int32_t i = 0; i < blob_id; i++) {
-    *shift_idx += kInt64Len + BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
-  }
-  *num_bytes = BytesBigToUInt64(columns_blob, *shift_idx, kInt64Type);
-
-  (*shift_idx) += kInt64Len;
-
-  return SUCCESS;
-}
-
-template <typename T>
-MSRStatus
ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, - uint64_t shift_idx) { - auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); - *num_bytes = sizeof(T) * num_elements; - - // Parse integer array - uint64_t i_source = shift_idx + kBytesOfColumnLen + (num_elements + kNumDataOfByte - 1) / kNumDataOfByte; - auto array_data = std::make_unique(num_elements); - - for (uint64_t i = 0; i < num_elements; i++) { - uint8_t iBitMap = columns_blob[shift_idx + kBytesOfColumnLen + i / kNumDataOfByte]; - uint64_t i_type = (iBitMap >> ((kNumDataOfByte - 1 - (i % kNumDataOfByte)) * kDataTypeBits)) & kDataTypeBitMask; - auto mr_int_type = static_cast(i_type); - int64_t i64 = BytesLittleToMinIntType(columns_blob, i_source, mr_int_type); - i_source += (kUnsignedOne << i_type); - array_data[i] = static_cast(i64); - } - - auto data = reinterpret_cast(array_data.get()); - *data_ptr = std::make_unique(*num_bytes); - int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data!"; - } - - return SUCCESS; -} - -uint64_t ShardColumn::BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &i_type) { - uint64_t result = 0; - for (uint64_t i = 0; i < (kUnsignedOne << static_cast(i_type)); i++) { - result = (result << kBitsOfByte) + bytes_array[pos + i]; - } - return result; -} - -std::vector ShardColumn::UIntToBytesBig(uint64_t value, const IntegerType &i_type) { - uint64_t n_bytes = kUnsignedOne << static_cast(i_type); - std::vector result(n_bytes, 0); - for (uint64_t i = 0; i < n_bytes; i++) { - result[n_bytes - 1 - i] = value & std::numeric_limits::max(); - value >>= kBitsOfByte; - } - return result; -} - -std::vector ShardColumn::UIntToBytesLittle(uint64_t value, const IntegerType &i_type) { - uint64_t n_bytes = kUnsignedOne << static_cast(i_type); - std::vector result(n_bytes, 0); - for (uint64_t i = 0; i < n_bytes; i++) { - result[i] = value & std::numeric_limits::max(); - value >>= kBitsOfByte; - } - return result; -} - -int64_t ShardColumn::BytesLittleToMinIntType(const std::vector &bytes_array, const uint64_t &pos, - const IntegerType &src_i_type, IntegerType *dst_i_type) { - uint64_t u_temp = 0; - for (uint64_t i = 0; i < (kUnsignedOne << static_cast(src_i_type)); i++) { - u_temp = (u_temp << kBitsOfByte) + - bytes_array[pos + (kUnsignedOne << static_cast(src_i_type)) - kUnsignedOne - i]; - } - - int64_t i_out; - switch (src_i_type) { - case kInt8Type: { - i_out = (int8_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt16Type: { - i_out = (int16_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt32Type: { - i_out = (int32_t)(u_temp & std::numeric_limits::max()); - break; - } - case kInt64Type: { - i_out = (int64_t)(u_temp & std::numeric_limits::max()); - break; - } - default: { - i_out = 0; - } - } - - if (!dst_i_type) { - return i_out; - } - - if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt8Type; - } else if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt16Type; - } else if (i_out >= static_cast(std::numeric_limits::min()) && - i_out <= static_cast(std::numeric_limits::max())) { - *dst_i_type = kInt32Type; - } else { - *dst_i_type = kInt64Type; - } - return i_out; -} -} // namespace mindrecord -} // 
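(Editor's aside, a worked example of the byte layout CompressInt and UncompressInt
above agree on; the concrete numbers assume kBytesOfColumnLen = 4, kNumDataOfByte = 4
and kDataTypeBits = 2, which the bitmap arithmetic implies. Compressing the int32
values {1, -2, 300}, 12 bytes of input, yields:

  00 00 00 03    element count, big-endian, kBytesOfColumnLen bytes
  04             bitmap: 2-bit widths 00, 00, 01 -> int8, int8, int16 (last slot unused)
  01 FE 2C 01    each value in its minimal width, little-endian

9 bytes total. UncompressInt reads the count, walks the bitmap to learn each element's
width, then widens every value back to the requested type T.)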
namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc deleted file mode 100644 index b7e890da7c..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_distributed_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, - uint32_t seed) - : ShardSample(1, num_shards, shard_id), - shuffle_(shuffle), - no_of_padded_samples_(no_of_padded_samples), - first_epoch_(true) { - shuffle_op_ = std::make_shared(seed, kShuffleSample); -} - -ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) - : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} - -int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (no_of_padded_samples_ <= 0) { - if (dataset_size % denominator_ == 0) { - return dataset_size / denominator_ * numerator_; - } else { - return dataset_size / denominator_ * numerator_ + 1; - } - } else { - auto padded_size = dataset_size + no_of_padded_samples_; - if (padded_size % denominator_ == 0) { - return padded_size / denominator_ * numerator_; - } else { - return -1; - } - } - return 0; -} - -MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) { - auto total_no = tasks.Size(); - if (no_of_padded_samples_ > 0 && first_epoch_) { - if (total_no % denominator_ != 0) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " - << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ - << ", denominator: " << denominator_; - return FAILED; - } - } - if (first_epoch_) { - first_epoch_ = false; - task_ = tasks; - } else { - tasks = task_; - } - if (shuffle_ == true) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc deleted file mode 100644 index ec177394ef..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ /dev/null @@ -1,725 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
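(Editor's aside, worked numbers for ShardDistributedSample::GetNumSamples above: with
dataset_size = 10 and num_shards = 4 (numerator_ = 1, denominator_ = 4) and no padding,
the result is 10 / 4 * 1 + 1 = 3 samples per shard, i.e. ceil(10/4); with
no_of_padded_samples_ = 2 the padded size 12 divides evenly and gives 12 / 4 = 3, while
a padded size that is not divisible by the shard count returns -1, which PreExecute
later rejects with the "not divisible" error.)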
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_header.h" - -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_page.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -std::atomic thread_status(false); -ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared(); } - -MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { - shard_count_ = headers.size(); - int shard_index = 0; - bool first = true; - for (const auto &header : headers) { - if (first) { - first = false; - if (ParseSchema(header["schema"]) != SUCCESS) { - return FAILED; - } - if (ParseIndexFields(header["index_fields"]) != SUCCESS) { - return FAILED; - } - if (ParseStatistics(header["statistics"]) != SUCCESS) { - return FAILED; - } - ParseShardAddress(header["shard_addresses"]); - header_size_ = header["header_size"].get(); - page_size_ = header["page_size"].get(); - } - ParsePage(header["page"], shard_index, load_dataset); - shard_index++; - } - return SUCCESS; -} - -MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { - std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); - if (!fin) { - MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; - return FAILED; - } - if (fin.fail()) { - MS_LOG(ERROR) << "Failed to open file. 
path: " << path; - return FAILED; - } - - // fetch file size - auto &io_seekg = fin.seekg(0, std::ios::end); - if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - fin.close(); - MS_LOG(ERROR) << "File seekg failed"; - return FAILED; - } - - size_t file_size = fin.tellg(); - if (file_size < kMinFileSize) { - fin.close(); - MS_LOG(ERROR) << "File size %d is smaller than the minimum value."; - return FAILED; - } - fin.close(); - return SUCCESS; -} - -std::pair ShardHeader::ValidateHeader(const std::string &path) { - if (CheckFileStatus(path) != SUCCESS) { - return {FAILED, {}}; - } - - // read header size - json json_header; - std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); - if (!fin.is_open()) { - MS_LOG(ERROR) << "File seekg failed"; - return {FAILED, json_header}; - } - - uint64_t header_size = 0; - auto &io_read = fin.read(reinterpret_cast(&header_size), kInt64Len); - if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - fin.close(); - return {FAILED, json_header}; - } - - if (header_size > kMaxHeaderSize) { - fin.close(); - MS_LOG(ERROR) << "Header size is illegal."; - return {FAILED, json_header}; - } - - // read header content - std::vector header_content(header_size); - auto &io_read_content = fin.read(reinterpret_cast(&header_content[0]), header_size); - if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { - MS_LOG(ERROR) << "File read failed"; - fin.close(); - return {FAILED, json_header}; - } - - fin.close(); - std::string raw_header_content = std::string(header_content.begin(), header_content.end()); - // parse json content - try { - json_header = json::parse(raw_header_content); - } catch (json::parse_error &e) { - MS_LOG(ERROR) << "Json parse error: " << e.what(); - return {FAILED, json_header}; - } - return {SUCCESS, json_header}; -} - -std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { - auto ret = ValidateHeader(file_path); - if (SUCCESS != ret.first) { - return {FAILED, json()}; - } - json raw_header = ret.second; - json header = {{"shard_addresses", raw_header["shard_addresses"]}, - {"header_size", raw_header["header_size"]}, - {"page_size", raw_header["page_size"]}, - {"index_fields", raw_header["index_fields"]}, - {"blob_fields", raw_header["schema"][0]["blob_fields"]}, - {"schema", raw_header["schema"][0]["schema"]}, - {"version", raw_header["version"]}}; - return {SUCCESS, header}; -} - -MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { - uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) thread_num = kThreadNumber; - uint32_t work_thread_num = 0; - uint32_t shard_count = file_paths.size(); - int group_num = ceil(shard_count * 1.0 / thread_num); - std::vector thread_set(thread_num); - std::vector headers(shard_count); - for (uint32_t x = 0; x < thread_num; ++x) { - int start_num = x * group_num; - int end_num = ((x + 1) * group_num > shard_count) ? 
shard_count : (x + 1) * group_num; - if (start_num >= end_num) { - continue; - } - - thread_set[x] = - std::thread(&ShardHeader::GetHeadersOneTask, this, start_num, end_num, std::ref(headers), file_paths); - work_thread_num++; - } - - for (uint32_t x = 0; x < work_thread_num; ++x) { - thread_set[x].join(); - } - if (thread_status) { - thread_status = false; - return FAILED; - } - if (SUCCESS != InitializeHeader(headers, load_dataset)) { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &headers, - const vector &realAddresses) { - if (thread_status || end > realAddresses.size()) { - return; - } - for (int x = start; x < end; ++x) { - auto ret = ValidateHeader(realAddresses[x]); - if (SUCCESS != ret.first) { - thread_status = true; - return; - } - json header; - header = ret.second; - header["shard_addresses"] = realAddresses; - if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { - MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() - << ", lib version is: " << kVersion; - thread_status = true; - return; - } - headers[x] = header; - } -} - -MSRStatus ShardHeader::InitByFiles(const std::vector &file_paths) { - std::vector file_names(file_paths.size()); - std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { - if (GetFileName(fp).first == SUCCESS) { - return GetFileName(fp).second; - } - }); - - shard_addresses_ = std::move(file_names); - shard_count_ = file_paths.size(); - if (shard_count_ == 0) { - return FAILED; - } - if (shard_count_ <= kMaxShardCount) { - pages_.resize(shard_count_); - } else { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::ParseHeader(const json &header) {} - -MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { - std::vector> parsed_index_fields; - for (auto &index_field : index_fields) { - auto schema_id = index_field["schema_id"].get(); - std::string field_name = index_field["index_field"].get(); - std::pair parsed_index_field(schema_id, field_name); - parsed_index_fields.push_back(parsed_index_field); - } - if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -void ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { - // set shard_index when load_dataset is false - if (pages_.empty() && shard_count_ <= kMaxShardCount) { - pages_.resize(shard_count_); - } - for (auto &page : pages) { - int page_id = page["page_id"]; - int shard_id = page["shard_id"]; - std::string page_type = page["page_type"]; - int page_type_id = page["page_type_id"]; - auto start_row_id = page["start_row_id"].get(); - auto end_row_id = page["end_row_id"].get(); - - std::vector> row_group_ids(page["row_group_ids"].size()); - std::transform(page["row_group_ids"].begin(), page["row_group_ids"].end(), row_group_ids.begin(), - [](json rg) { return std::make_pair(rg["id"], rg["offset"].get()); }); - - auto page_size = page["page_size"].get(); - - std::shared_ptr parsed_page = std::make_shared(page_id, shard_id, page_type, page_type_id, start_row_id, - end_row_id, row_group_ids, page_size); - if (load_dataset == true) { - pages_[shard_id].push_back(std::move(parsed_page)); - } else { - pages_[shard_index].push_back(std::move(parsed_page)); - } - } -} - -MSRStatus ShardHeader::ParseStatistics(const json &statistics) { - for (auto &statistic : statistics) { - if 
(statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { - MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); - return FAILED; - } - std::string statistic_description = statistic["desc"].get(); - json statistic_body = statistic["statistics"]; - std::shared_ptr parsed_statistic = Statistics::Build(statistic_description, statistic_body); - if (!parsed_statistic) { - return FAILED; - } - AddStatistic(parsed_statistic); - } - return SUCCESS; -} - -MSRStatus ShardHeader::ParseSchema(const json &schemas) { - for (auto &schema : schemas) { - // change how we get schemaBody once design is finalized - if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || - schema.find("schema") == schema.end()) { - MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); - return FAILED; - } - std::string schema_description = schema["desc"].get(); - std::vector blob_fields = schema["blob_fields"].get>(); - json schema_body = schema["schema"]; - std::shared_ptr parsed_schema = Schema::Build(schema_description, schema_body); - if (!parsed_schema) { - return FAILED; - } - AddSchema(parsed_schema); - } - return SUCCESS; -} - -void ShardHeader::ParseShardAddress(const json &address) { - std::copy(address.begin(), address.end(), std::back_inserter(shard_addresses_)); -} - -std::vector ShardHeader::SerializeHeader() { - std::vector header; - auto index = SerializeIndexFields(); - auto stats = SerializeStatistics(); - auto schema = SerializeSchema(); - auto pages = SerializePage(); - auto address = SerializeShardAddress(); - if (shard_count_ > static_cast(pages.size())) { - return std::vector{}; - } - if (shard_count_ <= kMaxShardCount) { - for (int shardId = 0; shardId < shard_count_; shardId++) { - string s; - s += "{\"header_size\":" + std::to_string(header_size_) + ","; - s += "\"index_fields\":" + index + ","; - s += "\"page\":" + pages[shardId] + ","; - s += "\"page_size\":" + std::to_string(page_size_) + ","; - s += "\"schema\":" + schema + ","; - s += "\"shard_addresses\":" + address + ","; - s += "\"shard_id\":" + std::to_string(shardId) + ","; - s += "\"statistics\":" + stats + ","; - s += "\"version\":\"" + std::string(kVersion) + "\""; - s += "}"; - header.emplace_back(s); - } - } - return header; -} - -std::string ShardHeader::SerializeIndexFields() { - json j; - auto fields = index_->GetFields(); - for (const auto &field : fields) { - j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); - } - return j.dump(); -} - -std::vector ShardHeader::SerializePage() { - std::vector pages; - for (auto &shard_pages : pages_) { - json j; - for (const auto &p : shard_pages) { - j.emplace_back(p->GetPage()); - } - pages.emplace_back(j.dump()); - } - return pages; -} - -std::string ShardHeader::SerializeStatistics() { - json j; - for (const auto &stats : statistics_) { - j.emplace_back(stats->GetStatistics()); - } - return j.dump(); -} - -std::string ShardHeader::SerializeSchema() { - json j; - for (const auto &schema : schema_) { - j.emplace_back(schema->GetSchema()); - } - return j.dump(); -} - -std::string ShardHeader::SerializeShardAddress() { - json j; - for (const auto &addr : shard_addresses_) { - j.emplace_back(GetFileName(addr).second); - } - return j.dump(); -} - -std::pair, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { - if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { - return 
std::make_pair(pages_[shard_id][page_id], SUCCESS); - } else { - return std::make_pair(nullptr, FAILED); - } -} - -MSRStatus ShardHeader::SetPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } - int shard_id = new_page->GetShardID(); - int page_id = new_page->GetPageID(); - if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { - pages_[shard_id][page_id] = new_page; - return SUCCESS; - } else { - return FAILED; - } -} - -MSRStatus ShardHeader::AddPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } - int shard_id = new_page->GetShardID(); - int page_id = new_page->GetPageID(); - if (shard_id < static_cast(pages_.size()) && page_id == static_cast(pages_[shard_id].size())) { - pages_[shard_id].push_back(new_page); - return SUCCESS; - } else { - return FAILED; - } -} - -int64_t ShardHeader::GetLastPageId(const int &shard_id) { - if (shard_id >= static_cast(pages_.size())) { - return 0; - } - return pages_[shard_id].size() - 1; -} - -int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &page_type) { - if (shard_id >= static_cast(pages_.size())) { - return 0; - } - int last_page_id = -1; - for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { - if (pages_[shard_id][i - 1]->GetPageType() == page_type) { - last_page_id = pages_[shard_id][i - 1]->GetPageID(); - return last_page_id; - } - } - return last_page_id; -} - -const std::pair> ShardHeader::GetPageByGroupId(const int &group_id, - const int &shard_id) { - if (shard_id >= static_cast(pages_.size())) { - MS_LOG(ERROR) << "Shard id is more than sum of shards."; - return {FAILED, nullptr}; - } - for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { - auto page = pages_[shard_id][i - 1]; - if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { - return {SUCCESS, page}; - } - } - MS_LOG(ERROR) << "Could not get page by group id " << group_id; - return {FAILED, nullptr}; -} - -int ShardHeader::AddSchema(std::shared_ptr schema) { - if (schema == nullptr) { - MS_LOG(ERROR) << "Schema is illegal"; - return -1; - } - - if (!schema_.empty()) { - MS_LOG(ERROR) << "Only support one schema"; - return -1; - } - - int64_t schema_id = schema->GetSchemaID(); - if (schema_id == -1) { - schema_id = schema_.size(); - schema->SetSchemaID(schema_id); - } - schema_.push_back(schema); - return schema_id; -} - -void ShardHeader::AddStatistic(std::shared_ptr statistic) { - if (statistic) { - int64_t statistics_id = statistic->GetStatisticsID(); - if (statistics_id == -1) { - statistics_id = statistics_.size(); - statistic->SetStatisticsID(statistics_id); - } - statistics_.push_back(statistic); - } -} - -std::shared_ptr ShardHeader::InitIndexPtr() { - std::shared_ptr index = index_; - if (!index_) { - index = std::make_shared(); - index_ = index; - } - return index; -} - -MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { - // check field name is or is not valid - if (schema.find(field) == schema.end()) { - MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; - return FAILED; - } - - if (schema[field]["type"] == "bytes") { - MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; - return FAILED; - } - - if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { - MS_LOG(ERROR) << field << " array can not be schema index field."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus 
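(Editor's aside: SetPage and AddPage above together keep page IDs dense per shard.
AddPage only accepts a page whose page_id equals pages_[shard_id].size(), i.e. a strict
append, while SetPage only overwrites an ID that already exists; that is why
GetLastPageId can simply return pages_[shard_id].size() - 1.)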
ShardHeader::AddIndexFields(const std::vector &fields) { - // create index Object - std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - - if (GetSchemas().empty()) { - MS_LOG(ERROR) << "No schema is set"; - return FAILED; - } - - for (const auto &schemaPtr : schema_) { - auto result = GetSchemaByID(schemaPtr->GetSchemaID()); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - if (result.first == nullptr) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - json schema = result.first->GetSchema().at("schema"); - - // checkout and add fields for each schema - std::set field_set; - for (const auto &item : index->GetFields()) { - field_set.insert(item.second); - } - for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - - // check field name is or is not valid - if (CheckIndexField(field, schema) == FAILED) { - return FAILED; - } - field_set.insert(field); - - // add field into index - index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); - } - } - - index_ = index; - return SUCCESS; -} - -MSRStatus ShardHeader::GetAllSchemaID(std::set &bucket_count) { - // get all schema id - for (const auto &schema : schema_) { - auto bucket_it = bucket_count.find(schema->GetSchemaID()); - if (bucket_it != bucket_count.end()) { - MS_LOG(ERROR) << "Schema duplication"; - return FAILED; - } else { - bucket_count.insert(schema->GetSchemaID()); - } - } - return SUCCESS; -} - -MSRStatus ShardHeader::AddIndexFields(std::vector> fields) { - // create index Object - std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - - // get all schema id - std::set bucket_count; - if (GetAllSchemaID(bucket_count) != SUCCESS) { - return FAILED; - } - - // check and add fields for each schema - std::set> field_set; - for (const auto &item : index->GetFields()) { - field_set.insert(item); - } - for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - - uint64_t schema_id = field.first; - std::string field_name = field.second; - - // check schemaId is or is not valid - if (bucket_count.find(schema_id) == bucket_count.end()) { - MS_LOG(ERROR) << "Illegal schema id: " << schema_id; - return FAILED; - } - - // check field name is or is not valid - auto result = GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - json schema = result.first->GetSchema().at("schema"); - if (schema.find(field_name) == schema.end()) { - MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; - return FAILED; - } - - if (CheckIndexField(field_name, schema) == FAILED) { - return FAILED; - } - - field_set.insert(field); - - // add field into index - index.get()->AddIndexField(schema_id, field_name); - } - index_ = index; - return SUCCESS; -} - -std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { - if (shard_id >= shard_addresses_.size()) { - return ""; - } - return shard_addresses_.at(shard_id); -} - -std::vector> ShardHeader::GetSchemas() { return schema_; } - -std::vector> ShardHeader::GetStatistics() { return statistics_; } - -std::vector> ShardHeader::GetFields() { 
return index_->GetFields(); } - -std::shared_ptr ShardHeader::GetIndex() { return index_; } - -std::pair, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { - int64_t schemaSize = schema_.size(); - if (schema_id < 0 || schema_id >= schemaSize) { - MS_LOG(ERROR) << "Illegal schema id"; - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(schema_.at(schema_id), SUCCESS); -} - -std::pair, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { - int64_t statistics_size = statistics_.size(); - if (statistic_id < 0 || statistic_id >= statistics_size) { - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(statistics_.at(statistic_id), SUCCESS); -} - -MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { - // write header content to file, dump whatever is in the file before - std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); - if (page_out_handle.fail()) { - MS_LOG(ERROR) << "Failed in opening page file"; - return FAILED; - } - - auto pages = SerializePage(); - for (const auto &shard_pages : pages) { - page_out_handle << shard_pages << "\n"; - } - - page_out_handle.close(); - return SUCCESS; -} - -MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { - for (auto &v : pages_) { // clean pages - v.clear(); - } - // attempt to open the file contains the page in json - std::ifstream page_in_handle(dump_file_name.c_str()); - - if (!page_in_handle.good()) { - MS_LOG(INFO) << "No page file exists."; - return SUCCESS; - } - - std::string line; - while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line), -1, true); - } - - page_in_handle.close(); - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_index.cc b/mindspore/ccsrc/mindrecord/meta/shard_index.cc deleted file mode 100644 index 8b7a3c0342..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_index.cc +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
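(Editor's aside: PagesToFile and FileToPages above round-trip the page tables as "JSON
lines", one JSON array of pages per shard, one line per shard. A minimal sketch of the
dump half, with assumed names:)

  #include <fstream>
  #include <string>
  #include <vector>

  void DumpPages(const std::string &path, const std::vector<std::string> &pages_per_shard) {
    std::ofstream out(path, std::ios_base::trunc | std::ios_base::out);  // drop any old dump
    for (const auto &shard_pages : pages_per_shard) {
      out << shard_pages << "\n";  // one serialized shard page table per line
    }
  }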
- */ - -#include "mindrecord/include/shard_index.h" - -namespace mindspore { -namespace mindrecord { -// table name for index -const char TABLENAME[] = "index_table"; - -Index::Index() : database_name_(""), table_name_(TABLENAME) {} - -void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { - fields_.emplace_back(pair(schemaId, field)); -} - -// Get attribute list -std::vector> Index::GetFields() { return fields_; } -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_page.cc b/mindspore/ccsrc/mindrecord/meta/shard_page.cc deleted file mode 100644 index 6bb849ae1d..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_page.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_page.h" -#include "pybind11/pybind11.h" - -namespace mindspore { -namespace mindrecord { -json Page::GetPage() const { - json str_page; - str_page["page_id"] = page_id_; - str_page["shard_id"] = shard_id_; - str_page["page_type"] = page_type_; - str_page["page_type_id"] = page_type_id_; - str_page["start_row_id"] = start_row_id_; - str_page["end_row_id"] = end_row_id_; - if (row_group_ids_.size() == 0) { - json row_groups = json({}); - row_groups["id"] = 0; - row_groups["offset"] = 0; - str_page["row_group_ids"].push_back(row_groups); - } else { - for (const auto &rg : row_group_ids_) { - json row_groups = json({}); - row_groups["id"] = rg.first; - row_groups["offset"] = rg.second; - str_page["row_group_ids"].push_back(row_groups); - } - } - str_page["page_size"] = page_size_; - return str_page; -} - -void Page::DeleteLastGroupId() { - if (!row_group_ids_.empty()) { - page_size_ = row_group_ids_.back().second; - row_group_ids_.pop_back(); - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc deleted file mode 100644 index fac2fec708..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "mindrecord/include/shard_pk_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) - : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} - -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} - -ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, - uint32_t seed) - : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { - shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement -} - -MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { - if (shuffle_ == true) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc deleted file mode 100644 index c207747194..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "mindrecord/include/shard_sample.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -ShardSample::ShardSample(int n) - : numerator_(0), - denominator_(0), - partition_id_(0), - no_of_samples_(n), - indices_({}), - sampler_type_(kCustomTopNSampler) {} - -ShardSample::ShardSample(int num, int den) - : numerator_(num), - denominator_(den), - partition_id_(0), - no_of_samples_(0), - indices_({}), - sampler_type_(kCustomTopPercentSampler) {} - -ShardSample::ShardSample(int num, int den, int par) - : numerator_(num), - denominator_(den), - partition_id_(par), - no_of_samples_(0), - indices_({}), - sampler_type_(kCustomTopPercentSampler) {} - -ShardSample::ShardSample(const std::vector &indices, uint32_t seed) - : numerator_(0), - denominator_(0), - partition_id_(0), - no_of_samples_(0), - indices_(indices), - sampler_type_(kSubsetRandomSampler) { - shuffle_op_ = std::make_shared(seed); -} - -int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { - if (sampler_type_ == kCustomTopNSampler) { - return no_of_samples_; - } - - if (sampler_type_ == kCustomTopPercentSampler) { - if (dataset_size % denominator_ == 0) { - return dataset_size / denominator_ * numerator_; - } else { - return dataset_size / denominator_ * numerator_ + 1; - } - } - if (sampler_type_ == kSubsetRandomSampler) { - return indices_.size(); - } - return 0; -} - -MSRStatus ShardSample::Execute(ShardTask &tasks) { - int no_of_categories = static_cast(tasks.categories); - int total_no = static_cast(tasks.Size()); // make sure task_size - - int taking = 0; - if (sampler_type_ == kCustomTopNSampler) { // non sharding case constructor #1 - no_of_samples_ = std::min(no_of_samples_, total_no); - taking = no_of_samples_ - no_of_samples_ % no_of_categories; - } else if (sampler_type_ == kSubsetRandomSampler) { - if (indices_.size() > total_no) { - MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; - return FAILED; - } - } else { // constructor TopPercent - if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { - if (numerator_ == 1 && denominator_ > 1) { // sharding - taking = (total_no + denominator_ - 1) / denominator_; - } else { // non sharding - taking = total_no * numerator_ / denominator_; - taking -= (taking % no_of_categories); - } - } else { - MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; - return FAILED; - } - } - - if (tasks.permutation_.empty()) { - ShardTask new_tasks; - total_no = static_cast(tasks.Size()); - if (sampler_type_ == kSubsetRandomSampler) { - for (int i = 0; i < indices_.size(); ++i) { - int index = ((indices_[i] % total_no) + total_no) % total_no; - new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python - } - } else { - for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. 
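(Editor's aside, worked numbers for the sharding branch of Execute above: with
numerator_ = 1, denominator_ = 4 and total_no = 10, taking = (10 + 4 - 1) / 4 = 3, and
shard partition_id_ = p keeps the tasks with indices [3p, 3p + 3) taken modulo 10; the
"% total_no" below is what wraps the last shard back around to the start.)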
if overflow, go back to start - } - } - std::swap(tasks, new_tasks); - } else { - ShardTask new_tasks; - if (taking > static_cast(tasks.permutation_.size())) { - return FAILED; - } - total_no = static_cast(tasks.permutation_.size()); - for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); - } - std::swap(tasks, new_tasks); - } - return SUCCESS; -} - -MSRStatus ShardSample::SufExecute(ShardTask &tasks) { - if (sampler_type_ == kSubsetRandomSampler) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } - } - return SUCCESS; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/mindrecord/meta/shard_schema.cc deleted file mode 100644 index ee0f5afa4a..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mindrecord/include/shard_schema.h" -#include "common/utils.h" - -using mindspore::LogStream; -using mindspore::ExceptionType::NoExceptionType; -using mindspore::MsLogLevel::ERROR; - -namespace mindspore { -namespace mindrecord { -std::shared_ptr Schema::Build(std::string desc, const json &schema) { - // validate check - if (!Validate(schema)) { - return nullptr; - } - - std::vector blob_fields = PopulateBlobFields(schema); - Schema object_schema; - object_schema.desc_ = std::move(desc); - object_schema.blob_fields_ = std::move(blob_fields); - object_schema.schema_ = schema; - object_schema.schema_id_ = -1; - return std::make_shared(object_schema); -} - -std::shared_ptr Schema::Build(std::string desc, pybind11::handle schema) { - // validate check - json schema_json = nlohmann::detail::ToJsonImpl(schema); - return Build(std::move(desc), schema_json); -} - -std::string Schema::GetDesc() const { return desc_; } - -json Schema::GetSchema() const { - json str_schema; - str_schema["desc"] = desc_; - str_schema["schema"] = schema_; - str_schema["blob_fields"] = blob_fields_; - return str_schema; -} - -pybind11::object Schema::GetSchemaForPython() const { - json schema_json = GetSchema(); - pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); - return schema_py; -} - -void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } - -int64_t Schema::GetSchemaID() const { return schema_id_; } - -std::vector Schema::GetBlobFields() const { return blob_fields_; } - -std::vector Schema::PopulateBlobFields(json schema) { - std::vector blob_fields; - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - json it_value = it.value(); - if ((it_value.size() == kInt2 && it_value.find("shape") != it_value.end()) || it_value["type"] == "bytes") { - blob_fields.emplace_back(it.key()); - } - } - return blob_fields; -} - -bool Schema::ValidateNumberShape(const json &it_value) { - if (it_value.find("shape") == it_value.end()) 
{ - MS_LOG(ERROR) << "%s supports shape only." << it_value["type"].dump(); - return false; - } - - auto shape = it_value["shape"]; - if (!shape.is_array()) { - MS_LOG(ERROR) << "%s shape format is wrong." << it_value["type"].dump(); - return false; - } - - int num_negtive_one = 0; - for (const auto &i : shape) { - if (i == 0 || i < -1) { - MS_LOG(ERROR) << "Shape %s, number is wrong." << it_value["shape"].dump(); - return false; - } - if (i == -1) { - num_negtive_one++; - } - } - - if (num_negtive_one > 1) { - MS_LOG(ERROR) << "Shape %s, have at most 1 variable-length dimension." << it_value["shape"].dump(); - return false; - } - - return true; -} - -bool Schema::Validate(json schema) { - if (schema.size() == kInt0) { - MS_LOG(ERROR) << "Schema is null"; - return false; - } - - for (json::iterator it = schema.begin(); it != schema.end(); ++it) { - // make sure schema key name must be composed of '0-9' or 'a-z' or 'A-Z' or '_' - if (!ValidateFieldName(it.key())) { - MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', fieldName: " << it.key(); - return false; - } - - json it_value = it.value(); - if (it_value.find("type") == it_value.end()) { - MS_LOG(ERROR) << "No 'type' field exist: " << it_value.dump(); - return false; - } - - if (kFieldTypeSet.find(it_value["type"]) == kFieldTypeSet.end()) { - MS_LOG(ERROR) << "Wrong type: " << it_value["type"].dump(); - return false; - } - - if (it_value.size() == kInt1) { - continue; - } - - if (it_value["type"] == "bytes" || it_value["type"] == "string") { - MS_LOG(ERROR) << it_value["type"].dump() << " can not 1 field only."; - return false; - } - - if (it_value.size() != kInt2) { - MS_LOG(ERROR) << it_value["type"].dump() << " can have at most 2 fields."; - return false; - } - - if (!ValidateNumberShape(it_value)) { - return false; - } - } - - return true; -} - -bool Schema::operator==(const mindrecord::Schema &b) const { - if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { - return false; - } - return true; -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc deleted file mode 100644 index a7fa4e7343..0000000000 --- a/mindspore/ccsrc/mindrecord/meta/shard_sequential_sample.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "mindrecord/include/shard_sequential_sample.h"
-
-using mindspore::LogStream;
-using mindspore::ExceptionType::NoExceptionType;
-using mindspore::MsLogLevel::ERROR;
-
-namespace mindspore {
-namespace mindrecord {
-ShardSequentialSample::ShardSequentialSample(int n, int offset)
-    : ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
-
-ShardSequentialSample::ShardSequentialSample(float per, float per_offset)
-    : ShardSample(0), offset_(0), per_(per), per_offset_(per_offset) {}
-
-int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
-  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
-    return dataset_size;
-  }
-  if (per_ > kEpsilon && per_ <= 1.0f) {
-    return dataset_size * per_;
-  }
-  return no_of_samples_;
-}
-
-MSRStatus ShardSequentialSample::Execute(ShardTask &tasks) {
-  int total_no = static_cast<int>(tasks.Size());
-  int taking;
-  if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) {
-    taking = total_no;
-  } else if (per_ > kEpsilon && per_ <= 1.0f) {
-    taking = total_no * per_;
-  } else {
-    taking = no_of_samples_;
-  }
-
-  if (tasks.permutation_.empty()) {
-    ShardTask new_tasks;
-    total_no = static_cast<int>(tasks.Size());
-    for (int i = offset_; i < taking + offset_; ++i) {
-      new_tasks.InsertTask(tasks.GetTaskByID(i % total_no));
-    }
-    std::swap(tasks, new_tasks);
-  } else {  // shuffled
-    ShardTask new_tasks;
-    if (taking > static_cast<int>(tasks.permutation_.size())) {
-      return FAILED;
-    }
-    total_no = static_cast<int>(tasks.permutation_.size());
-    for (size_t i = offset_; i < taking + offset_; ++i) {
-      new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
-    }
-    std::swap(tasks, new_tasks);
-  }
-  return SUCCESS;
-}
-
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
deleted file mode 100644
index 5cf49b04f0..0000000000
--- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc
+++ /dev/null
@@ -1,88 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "mindrecord/include/shard_shuffle.h"
-
-#include <algorithm>
-
-namespace mindspore {
-namespace mindrecord {
-ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type)
-    : shuffle_seed_(seed),
-      no_of_samples_(0),
-      replacement_(false),
-      reshuffle_each_epoch_(true),
-      shuffle_type_(shuffle_type) {}
-
-ShardShuffle::ShardShuffle(uint32_t seed, int64_t no_of_samples, bool replacement, bool reshuffle_each_epoch,
-                           ShuffleType shuffle_type)
-    : shuffle_seed_(seed),
-      no_of_samples_(no_of_samples),
-      replacement_(replacement),
-      reshuffle_each_epoch_(reshuffle_each_epoch),
-      shuffle_type_(shuffle_type) {}
-
-int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
-  if (replacement_) {
-    return no_of_samples_ == 0 ? dataset_size : no_of_samples_;
-  }
-  return dataset_size;
-}
-
-MSRStatus ShardShuffle::Execute(ShardTask &tasks) {
-  if (reshuffle_each_epoch_) shuffle_seed_++;
-  if (tasks.categories < 1) {
-    return FAILED;
-  }
-  if (shuffle_type_ == kShuffleSample) {  // shuffle each sample
-    if (tasks.permutation_.empty() == true) {
-      tasks.MakePerm();
-    }
-    if (replacement_ == true) {
-      ShardTask new_tasks;
-      if (no_of_samples_ == 0) {
-        no_of_samples_ = static_cast<int64_t>(tasks.Size());
-      }
-      if (no_of_samples_ <= 0) {
-        MS_LOG(ERROR) << "no_of_samples need to be positive.";
-        return FAILED;
-      }
-      new_tasks.task_list_.reserve(no_of_samples_);
-      for (uint32_t i = 0; i < no_of_samples_; ++i) {
-        new_tasks.InsertTask(tasks.GetRandomTask());
-      }
-      std::swap(tasks, new_tasks);
-    } else {
-      std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_));
-    }
-  } else {  // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn)
-    uint32_t individual_size = tasks.Size() / tasks.categories;
-    std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size));
-    for (uint32_t i = 0; i < tasks.categories; i++) {
-      for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j);
-      std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_));
-    }
-    tasks.permutation_.clear();
-    for (uint32_t j = 0; j < individual_size; j++) {
-      for (uint32_t i = 0; i < tasks.categories; i++) {
-        tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i));
-      }
-    }
-  }
-  return SUCCESS;
-}
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc
deleted file mode 100644
index ca36c50863..0000000000
--- a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "mindrecord/include/shard_statistics.h"
-#include "pybind11/pybind11.h"
-
-using mindspore::LogStream;
-using mindspore::ExceptionType::NoExceptionType;
-using mindspore::MsLogLevel::ERROR;
-
-namespace mindspore {
-namespace mindrecord {
-std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &statistics) {
-  // validate check
-  if (!Validate(statistics)) {
-    return nullptr;
-  }
-  Statistics object_statistics;
-  object_statistics.desc_ = std::move(desc);
-  object_statistics.statistics_ = statistics;
-  object_statistics.statistics_id_ = -1;
-  return std::make_shared<Statistics>(object_statistics);
-}
-
-std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) {
-  // validate check
-  json statistics_json = nlohmann::detail::ToJsonImpl(statistics);
-  if (!Validate(statistics_json)) {
-    return nullptr;
-  }
-  Statistics object_statistics;
-  object_statistics.desc_ = std::move(desc);
-  object_statistics.statistics_ = statistics_json;
-  object_statistics.statistics_id_ = -1;
-  return std::make_shared<Statistics>(object_statistics);
-}
-
-std::string Statistics::GetDesc() const { return desc_; }
-
-json Statistics::GetStatistics() const {
-  json str_statistics;
-  str_statistics["desc"] = desc_;
-  str_statistics["statistics"] = statistics_;
-  return str_statistics;
-}
-
-pybind11::object Statistics::GetStatisticsForPython() const {
-  json str_statistics = Statistics::GetStatistics();
-  return nlohmann::detail::FromJsonImpl(str_statistics);
-}
-
-void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; }
-
-int64_t Statistics::GetStatisticsID() const { return statistics_id_; }
-
-bool Statistics::Validate(const json &statistics) {
-  if (statistics.size() != kInt1) {
-    MS_LOG(ERROR) << "Statistics object is null";
-    return false;
-  }
-  if (statistics.find("level") == statistics.end()) {
-    MS_LOG(ERROR) << "There is not 'level' object in statistic";
-    return false;
-  }
-  return LevelRecursive(statistics["level"]);
-}
-
-bool Statistics::LevelRecursive(json level) {
-  bool ini = true;
-  for (json::iterator it = level.begin(); it != level.end(); ++it) {
-    json a = it.value();
-    if (a.size() == kInt2) {
-      if ((a.find("key") == a.end()) || (a.find("count") == a.end())) {
-        MS_LOG(ERROR) << "The node field is 2, but 'key'/'count' is not existed";
-        return false;
-      }
-    } else if (a.size() == kInt3) {
-      if ((a.find("key") == a.end()) || (a.find("count") == a.end()) || a.find("level") == a.end()) {
-        MS_LOG(ERROR) << "The node field is 3, but 'key'/'count'/'level' is not existed";
-        return false;
-      } else {
-        ini = LevelRecursive(a.at("level"));
-      }
-    } else {
-      MS_LOG(ERROR) << "The node field is not equal 2/3";
-      return false;
-    }
-  }
-  return ini;
-}
-
-bool Statistics::operator==(const Statistics &b) const {
-  if (this->GetStatistics() != b.GetStatistics()) {
-    return false;
-  }
-  return true;
-}
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc
deleted file mode 100644
index 8baa3c26cd..0000000000
--- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "mindrecord/include/shard_task.h"
-#include "common/utils.h"
-#include "mindrecord/include/common/shard_utils.h"
-
-using mindspore::LogStream;
-using mindspore::ExceptionType::NoExceptionType;
-using mindspore::MsLogLevel::DEBUG;
-
-namespace mindspore {
-namespace mindrecord {
-ShardTask::ShardTask() : categories(1) {}
-
-ShardTask::ShardTask(const ShardTask &other)
-    : categories(other.categories), permutation_(other.permutation_), task_list_(other.task_list_) {}
-
-ShardTask &ShardTask::operator=(const ShardTask &other) {
-  ShardTask tmp(other);
-  std::swap(categories, tmp.categories);
-  permutation_.swap(tmp.permutation_);
-  task_list_.swap(tmp.task_list_);
-  return *this;
-}
-
-void ShardTask::MakePerm() {
-  permutation_ = std::vector<int>(task_list_.size());
-  for (uint32_t i = 0; i < task_list_.size(); i++) {
-    permutation_[i] = static_cast<int>(i);
-  }
-}
-
-void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
-                           const json &label) {
-  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
-                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
-  task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
-}
-
-void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
-  MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
-                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
-                << ", size of task_list_: " << task_list_.size() << ".";
-
-  task_list_.push_back(std::move(task));
-}
-
-void ShardTask::PopBack() { task_list_.pop_back(); }
-
-uint32_t ShardTask::Size() const { return static_cast<uint32_t>(task_list_.size()); }
-
-uint32_t ShardTask::SizeOfRows() const {
-  if (task_list_.size() == 0) return static_cast<uint32_t>(0);
-
-  // 1 task is 1 page
-  auto sum_num_rows = [](int x, std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> y) {
-    return x + std::get<2>(y)[0];
-  };
-  uint32_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
-  return nRows;
-}
-
-std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) {
-  MS_ASSERT(id < task_list_.size());
-  return task_list_[id];
-}
-
-std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() {
-  std::random_device rd;
-  std::mt19937 gen(rd());
-  std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
-  return task_list_[dis(gen)];
-}
-
-ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
-  ShardTask res;
-  if (category_tasks.empty()) return res;
-  auto total_categories = category_tasks.size();
-  res.categories = static_cast<uint32_t>(total_categories);
-  if (replacement == false) {
-    auto minTasks = category_tasks[0].Size();
-    for (uint32_t i = 1; i < total_categories; i++) {
-      minTasks = std::min(minTasks, category_tasks[i].Size());
-    }
-    for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
-      for (uint32_t i = 0; i < total_categories; i++) {
-        res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
-      }
-    }
-  } else {
-    auto maxTasks = category_tasks[0].Size();
-    for (uint32_t i = 1; i < total_categories; i++) {
-      maxTasks = std::max(maxTasks, category_tasks[i].Size());
-    }
-    if (num_elements != std::numeric_limits<int64_t>::max()) {
-      maxTasks = static_cast<uint32_t>(num_elements);
-    }
-    for (uint32_t i = 0; i < total_categories; i++) {
-      for (uint32_t j = 0; j < maxTasks; j++) {
-        res.InsertTask(category_tasks[i].GetRandomTask());
-      }
-    }
-  }
-  return res;
-}
-}  // namespace mindrecord
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/onnx/CMakeLists.txt b/mindspore/ccsrc/onnx/CMakeLists.txt
deleted file mode 100644
index a65ea6d450..0000000000
--- a/mindspore/ccsrc/onnx/CMakeLists.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
-set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX)
-add_library(_mindspore_onnx_obj OBJECT ${_ONNX_SRC_FILES})
diff --git a/mindspore/ccsrc/onnx/ir_exporter.cc b/mindspore/ccsrc/onnx/ir_exporter.cc
deleted file mode 100644
index a2a9072090..0000000000
--- a/mindspore/ccsrc/onnx/ir_exporter.cc
+++ /dev/null
@@ -1,618 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "ir/tensor.h"
-#include "ir/param_value.h"
-#include "debug/anf_ir_utils.h"
-#include "operator/ops.h"
-#include "proto/onnx.pb.h"
-
-namespace mindspore {
-using FloatPtr = std::shared_ptr<Float>;
-using IntPtr = std::shared_ptr<Int>;
-
-// anf type to onnx type map
-static std::unordered_map<TypeId, onnx::TensorProto_DataType> g_data_type_map = {
-  {kNumberTypeBool, onnx::TensorProto_DataType_BOOL},     {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
-  {kNumberTypeInt16, onnx::TensorProto_DataType_INT16},   {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
-  {kNumberTypeInt64, onnx::TensorProto_DataType_INT64},   {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
-  {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
-  {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
-  {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
-  {kObjectTypeString, onnx::TensorProto_DataType_STRING},
-};
-
-static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_int_map = {
-  {8, onnx::TensorProto_DataType_INT8},
-  {16, onnx::TensorProto_DataType_INT16},
-  {32, onnx::TensorProto_DataType_INT32},
-  {64, onnx::TensorProto_DataType_INT64},
-};
-
-static std::unordered_map<int, onnx::TensorProto_DataType> g_data_bits_float_map = {
-  {16, onnx::TensorProto_DataType_FLOAT16},
-  {32, onnx::TensorProto_DataType_FLOAT},
-};
-
-// Can build different builder according to format
-class IrExportBuilder;
-using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>;
-
-class IrExporter {
- public:
-  explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
-  virtual ~IrExporter() = default;
-  std::string GetDumpString(const FuncGraphPtr &func_graph);
-
- private:
-  IrExportBuilderPtr builder_;
-};
-
-class IrExportBuilder {
- public:
-  IrExportBuilder() = default;
-  ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
-  std::string GetProtoString(const FuncGraphPtr &func_graph);
-  void BuildModelInfo();
-  void BuildModel(const FuncGraphPtr &func_graph);
-
- private:
-  void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
-  void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
-  void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto);
-  void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto);
-  void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto);
-  std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto);
-
-  void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto);
-  void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto);
-  void SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto);
-  void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto);
-  void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto);
-  void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto);
-  void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto,
-                           std::string suffix = "0");
-  void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
-  void SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
-  void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
-  void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto);
-  void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto);
-  void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto);
-
-  onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id);
-  onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits);
-  onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits);
-  std::string GetNodeName(const AnfNodePtr &node);
-  std::string GetUniqueNodeName(const AnfNodePtr &node);
-  std::string GetOpTypeName(const AnfNodePtr &node);
-  size_t AllocateIndex() { return ++node_index_; }
-  void ResetIndex() { node_index_ = 0; }
-
- private:
-  onnx::ModelProto model_;
-  onnx::NodeProto *last_node_{nullptr};
-  std::list<FuncGraphPtr> todo_;
-  std::map<AnfNodePtr, size_t> node_index_map_;
-  size_t node_index_{0};
-};
-
-using IrExporterPtr = std::shared_ptr<IrExporter>;
-
-std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
-  if ((builder_ == nullptr) || (func_graph == nullptr)) {
-    MS_LOG(EXCEPTION) << "Input params is null.";
-  }
-
-  // Export model info
-  builder_->BuildModelInfo();
-
-  // Export model and return string
-  builder_->BuildModel(func_graph);
-
-  return builder_->GetProtoString(func_graph);
-}
-
-std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
-  MS_LOG(DEBUG) << "BuildModel complete!";
-  return model_.SerializeAsString();
-}
-
-void IrExportBuilder::BuildModelInfo() {
-  model_.set_ir_version(onnx::IR_VERSION_2019_1_22);
-  model_.set_producer_name("MindSpore");
-  model_.set_model_version(1);
-}
-
-void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
-  onnx::GraphProto *graph_proto = model_.mutable_graph();
-  graph_proto->set_name(func_graph->ToString());
-  ResetIndex();
-  todo_.clear();
-  todo_.push_back(func_graph);
-  while (!todo_.empty()) {
-    FuncGraphPtr fg = todo_.back();
-    todo_.pop_back();
-    BuildFuncGraph(fg, graph_proto);
-  }
-}
-
-void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
-  // Export parameters
-  // 1. parameters should be mapped to ValueInfoProto
-  // 2. parameters with default value should be mapped to Initializer
-  BuildParameters(func_graph, graph_proto);
-
-  // Export operator nodes(include output)
-  BuildNodes(func_graph, graph_proto);
-}
-
-void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
-  for (auto &item : func_graph->parameters()) {
-    auto param = item->cast<ParameterPtr>();
-    if (param == nullptr) {
-      MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
-    }
-    onnx::ValueInfoProto *input_proto = graph_proto->add_input();
-    std::string param_name = GetUniqueNodeName(param);
-    input_proto->set_name(param_name);
-    SetValueInfoProto(param, input_proto);
-    if (!param->has_default()) {
-      MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default";
-      continue;
-    }
-
-    // Using ONNX initializer to set parameter's default value
-    onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
-    initializer_proto->set_name(param_name);
-    SetParamToTensorProto(param, initializer_proto);
-    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()->value());
-    if (tensor) {
-      initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
-    }
-  }
-}
-
-onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) {
-  auto iter = g_data_type_map.find(type_id);
-  if (iter == g_data_type_map.end()) {
-    MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
-  }
-  return iter->second;
-}
-
-onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) {
-  auto iter = g_data_bits_int_map.find(bits);
-  if (iter == g_data_bits_int_map.end()) {
-    MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
-  }
-  return iter->second;
-}
-
-onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) {
-  auto iter = g_data_bits_float_map.find(bits);
-  if (iter == g_data_bits_float_map.end()) {
-    MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
-  }
-  return iter->second;
-}
-
-void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) {
-  if (node == nullptr || value_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
-  }
-  MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
-  SetValueInfoProto(node->Type(), node->Shape(), value_proto);
-}
-
-void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape,
-                                        onnx::ValueInfoProto *const value_proto) {
-  onnx::TypeProto *type_proto = value_proto->mutable_type();
-  if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
-    auto tensor = type->cast<TensorTypePtr>();
-    auto elem_type = tensor->element();
-    const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
-    type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id()));
-    for (const auto &dim : dims) {
-      MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
-      type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
-    }
-  } else if (type->isa<Tuple>()) {
-    auto tup_shape = shape->cast<abstract::TupleShapePtr>();
-    type_proto->set_denotation(std::to_string(tup_shape->shape().size()));
-  } else {
-    MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!";
-  }
-}
-
-void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
-  if (value == nullptr || attr_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
-  }
-  attr_proto->set_ref_attr_name("tensor");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  auto data = value->cast<tensor::TensorPtr>();
-  tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
-  auto dtype = data->data_type();
-  auto shape = data->shape_c();
-  tensor_proto->set_data_type(GetOnnxDataType(dtype));
-  for (const auto &dim : shape) {
-    tensor_proto->add_dims(dim);
-  }
-}
-
-void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
-                                     onnx::TensorProto *const tensor_proto) {
-  if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
-    MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
-  }
-  auto tensor = type->cast<TensorTypePtr>();
-  const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
-  tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id()));
-  for (const auto &dim : dims) {
-    tensor_proto->add_dims(dim);
-  }
-}
-
-void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
-  if (param == nullptr || tensor_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
-  }
-  MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
-  SetTensorProto(param->Type(), param->Shape(), tensor_proto);
-}
-
-void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
-  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
-  for (const AnfNodePtr &node : nodes) {
-    if (!node->isa<CNode>()) {
-      MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == func_graph->get_return()) {
-      BuildOutput(cnode, graph_proto);
-    } else {
-      BuildCNode(cnode, graph_proto);
-    }
-  }
-}
-
-void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
-  if (node->size() != 2) {
-    MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
-  }
-  AnfNodePtr arg = node->input(1);
-  // Using make_tuple to set multi-output
-  if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) {
-    auto tuple_node = arg->cast<CNodePtr>();
-    for (size_t i = 1; i < tuple_node->size(); i++) {
-      auto input_node = arg->cast<CNodePtr>()->input(i);
-      onnx::ValueInfoProto *output_proto = graph_proto->add_output();
-      auto output_name = GetUniqueNodeName(tuple_node->input(i));
-      output_proto->set_name(output_name);
-      last_node_->add_output(output_name);
-      SetValueInfoProto(tuple_node->input(i), output_proto);
-    }
-  } else {
-    onnx::ValueInfoProto *output_proto = graph_proto->add_output();
-    std::string output_name = GetUniqueNodeName(node);
-    output_proto->set_name(output_name);
-    last_node_->add_output(output_name);
-    SetValueInfoProto(arg, output_proto);
-  }
-}
-
-std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
-  // May be ValueNode/CNode/Parameter
-  std::string type_name = "";
-  if (IsValueNode<Primitive>(node)) {
-    PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
-    type_name = prim->ToString();
-  } else if (IsValueNode<FuncGraph>(node)) {
-    FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
-    todo_.push_back(fg);
-    type_name = fg->ToString();
-  } else if (node->isa<CNode>() || node->isa<Parameter>()) {
-    type_name = node->ToString();
-  } else {
-    MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
-  }
-  MS_LOG(DEBUG) << "ExportType: " << type_name;
-  return type_name;
-}
-
-void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
-                                          onnx::NodeProto *const node_proto, std::string suffix) {
-  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
-  attr_proto->set_ref_attr_name("shape");
-  if (suffix.compare("0") != 0) {
-    attr_proto->set_name("shape" + suffix);
-  } else {
-    attr_proto->set_name("shape");
-  }
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  SetTensorProto(type, shape, tensor_proto);
-}
-
-void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) {
-  // Get shape of cnode
-  // 1. prim ArgMaxWithValue need to get shape from tuple element
-  // 2. some cnode doesn't has shape, such as LayerNorm
-  // 3. other cnodes have shape
-  if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) {
-    auto type = node->Type();
-    auto shape = node->Shape();
-    if (!type->isa<Tuple>()) {
-      MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name();
-    }
-    auto elements = type->cast<TuplePtr>()->elements();
-    auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
-    for (size_t i = 0; i < elements.size(); i++) {
-      SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i));
-    }
-  } else {
-    auto type = node->Type();
-    auto shape = node->Shape();
-    if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
-      MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString();
-      return;
-    }
-    SetShapeToNodeProto(type, shape, node_proto);
-  }
-}
-
-void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) {
-  auto inputs_size = node->size();
-  if (inputs_size < 1) {
-    MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
-  }
-
-  // Need to build input node before dealing with cnode
-  std::vector<AnfNodePtr> op_inputs;
-  std::vector<std::string> input_names;
-  for (size_t i = 1; i < inputs_size; i++) {
-    auto input = node->input(i);
-    op_inputs.push_back(input);
-    input_names.push_back(BuildInputNode(input, graph_proto));
-  }
-
-  // Build cnode
-  onnx::NodeProto *node_proto = graph_proto->add_node();
-  std::string output_name = GetUniqueNodeName(node);
-  node_proto->add_output(output_name);
-  node_proto->set_name(output_name);
-  node_proto->set_domain(node->fullname_with_scope());
-  AnfNodePtr op = node->input(0);
-  std::string type_name = GetOpTypeName(op);
-  node_proto->set_op_type(type_name);
-  last_node_ = node_proto;
-  SetShapeToNodeProto(node, node_proto);
-  (void)std::for_each(input_names.begin(), input_names.end(),
-                      [&node_proto](const string &name) { node_proto->add_input(name); });
-
-  // Add primitive attrs
-  if (IsValueNode<Primitive>(op)) {
-    auto prim = GetValueNode<PrimitivePtr>(op);
-    for (auto attr : prim->attrs()) {
-      MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
-      onnx::AttributeProto *attr_proto = node_proto->add_attribute();
-      attr_proto->set_name(attr.first);
-      SetValueToAttributeProto(attr.second, attr_proto);
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name();
-  }
-}
-
-std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) {
-  std::string node_name = GetUniqueNodeName(node);
-  if (node->isa<ValueNode>()) {
-    // When node input is a ValueNode, need to create a Constant Node
-    onnx::NodeProto *node_proto = graph_proto->add_node();
-    node_proto->add_output(node_name);
-    SetAttributeProto(node, node_proto);
-  }
-  return node_name;
-}
-
-std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
-  // Naming anfnode
-  // 1. parameter is unique in one func_graph
-  // 2. cnode and valuenode may be reduplicative, so add index to identify.
-  std::string node_name = "";
-  if (node->isa<Parameter>()) {
-    node_name = GetNodeName(node);
-  } else if (node->isa<CNode>() || node->isa<ValueNode>()) {
-    auto iter = node_index_map_.find(node);
-    if (iter != node_index_map_.end()) {
-      node_name = GetNodeName(node) + ":" + std::to_string(iter->second);
-    } else {
-      auto node_idx = AllocateIndex();
-      node_index_map_[node] = node_idx;
-      node_name = GetNodeName(node) + ":" + std::to_string(node_idx);
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString();
-  }
-  MS_LOG(DEBUG) << "Node name: " << node_name;
-  return node_name;
-}
-
-std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
-  std::string node_name = "";
-  if ((node != nullptr) && (node->func_graph() != nullptr)) {
-    node_name = node->func_graph()->ToString() + ":";
-  }
-  node_name += node->ToString();
-  MS_LOG(DEBUG) << "GetNodeName: " << node_name;
-  return node_name;
-}
-
-void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) {
-  if (node == nullptr || node_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
-  }
-  auto value = node->cast<ValueNodePtr>()->value();
-  node_proto->set_op_type("Constant");
-  onnx::AttributeProto *attr_proto = node_proto->add_attribute();
-  attr_proto->set_name("value");
-  MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
-  SetValueToAttributeProto(value, attr_proto);
-}
-
-void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
-  if (value == nullptr || attr_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
-  }
-  attr_proto->set_ref_attr_name("type");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  if (value->isa<Int>()) {
-    auto int_value = value->cast<IntPtr>();
-    tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
-  } else if (value->isa<Float>()) {
-    auto float_value = value->cast<FloatPtr>();
-    tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
-  } else if (value->isa<TensorType>()) {
-    tensor_proto->set_name("tensor");
-    auto elem_type = value->cast<TensorTypePtr>()->element();
-    if (elem_type->isa<Int>()) {
-      auto int_value = elem_type->cast<IntPtr>();
-      tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits()));
-    } else if (elem_type->isa<Float>()) {
-      auto float_value = elem_type->cast<FloatPtr>();
-      tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits()));
-    } else {
-      MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
-  }
-}
-
-void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
-  if (value == nullptr || attr_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
-  }
-  if (value->isa<StringImm>() || value->isa<Scalar>()) {
-    SetScalarToAttributeProto(value, attr_proto);
-  } else if (value->isa<Number>() || value->isa<TensorType>()) {
-    SetTypeToAttributeProto(value, attr_proto);
-  } else if (value->isa<ValueSequeue>()) {
-    SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto);
-  } else if (value->isa<tensor::Tensor>()) {
-    SetTensorToAttributeProto(value, attr_proto);
-  } else {
-    MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
-  }
-}
-
-void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) {
-  if (value == nullptr || attr_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
-  }
-  attr_proto->set_ref_attr_name("scalar");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  SetScalarToProto(value, tensor_proto);
-}
-
-void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) {
-  if (value == nullptr || tensor_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!";
-  }
-  if (value->isa<StringImm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING);
-    tensor_proto->add_string_data(GetValue<std::string>(value));
-  } else if (value->isa<BoolImm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL);
-    tensor_proto->add_int32_data(GetValue<bool>(value));
-  } else if (value->isa<Int8Imm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8);
-    tensor_proto->add_int32_data(value->cast<Int8ImmPtr>()->value());
-  } else if (value->isa<Int16Imm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16);
-    tensor_proto->add_int32_data(value->cast<Int16ImmPtr>()->value());
-  } else if (value->isa<Int32Imm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32);
-    tensor_proto->add_int32_data(value->cast<Int32ImmPtr>()->value());
-  } else if (value->isa<Int64Imm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64);
-    tensor_proto->add_int64_data(value->cast<Int64ImmPtr>()->value());
-  } else if (value->isa<FP32Imm>()) {
-    tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT);
-    tensor_proto->add_float_data(GetValue<float>(value));
-  } else {
-    MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
-  }
-}
-
-void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
-                                                  onnx::AttributeProto *const attr_proto) {
-  if (value == nullptr || attr_proto == nullptr) {
-    MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
-  }
-  attr_proto->set_ref_attr_name("scalar");
-  attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR);
-  onnx::TensorProto *tensor_proto = attr_proto->mutable_t();
-  if (value->isa<ValueTuple>()) {
-    const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
-    if (tuple_value->value().size() == 0) {
-      MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
-      return;
-    }
-    auto type_id = tuple_value->value()[0]->type()->type_id();
-    tensor_proto->set_data_type(GetOnnxDataType(type_id));
-    for (const auto &item : tuple_value->value()) {
-      SetScalarToProto(item, tensor_proto);
-    }
-  } else if (value->isa<ValueList>()) {
-    const ValueListPtr &list_value = value->cast<ValueListPtr>();
-    if (list_value->value().size() == 0) {
-      MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0";
-      return;
-    }
-    auto type_id = list_value->value()[0]->type()->type_id();
-    tensor_proto->set_data_type(GetOnnxDataType(type_id));
-    for (const auto &item : list_value->value()) {
-      SetScalarToProto(item, tensor_proto);
-    }
-  }
-}
-
-std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
-  auto builder = std::make_shared<IrExportBuilder>();
-  if (builder == nullptr) {
-    MS_LOG(ERROR) << "Create ir exporter failed!";
-    return "";
-  }
-  auto exporter = std::make_shared<IrExporter>(builder);
-  if (exporter == nullptr) {
-    return "";
-  }
-  return exporter->GetDumpString(func_graph);
-}
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc
deleted file mode 100644
index 43c5c118c1..0000000000
--- a/mindspore/ccsrc/onnx/onnx_exporter.cc
+++ /dev/null
@@ -1,1207 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "debug/anf_ir_utils.h" -#include "proto/onnx.pb.h" -#include "operator/ops.h" -#include "ir/tensor.h" -#include "ir/param_value.h" - -namespace mindspore { -enum OpMergeMode { - OP_MERGE_UNDEFINED = 0, // undefined behavior - OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list - OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` - OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` - OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` - OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` -}; - -struct OpMergedInfo { - OpMergeMode mode = OP_MERGE_UNDEFINED; - int referred_count = 0; -}; - -using GenAttrFuncType = - std::function; - -template -void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - auto casted_value = dyn_cast(value); - if (casted_value == nullptr) { - MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; - } - auto attr_value = casted_value->value(); - switch (attr_type) { - case onnx::AttributeProto_AttributeType_INT: - attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); - break; - case onnx::AttributeProto_AttributeType_FLOAT: - attr_proto->set_f(static_cast(attr_value)); - break; - case onnx::AttributeProto_AttributeType_INTS: - for (size_t i = 0; i < rep_cnt; ++i) { - attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value)); - } - break; - case onnx::AttributeProto_AttributeType_FLOATS: - for (size_t i = 0; i < rep_cnt; ++i) { - attr_proto->add_floats(static_cast(attr_value)); - } - break; - default: - MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; - } - attr_proto->set_type(attr_type); -} - -template -void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - auto tuple_ptr = dyn_cast(value); - if (tuple_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; - } - switch (attr_type) { - case onnx::AttributeProto_AttributeType_INTS: - for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { - attr_proto->add_ints(GetValue((*tuple_ptr)[i])); - } - break; - case onnx::AttributeProto_AttributeType_FLOATS: - for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { - attr_proto->add_floats(GetValue((*tuple_ptr)[i])); - } - break; - default: - MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; - } - attr_proto->set_type(attr_type); -} - -void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - auto attr_value = 
GetValue(value); - if (attr_value == "VALID") { - attr_proto->set_s("VALID"); - } else { - attr_proto->set_s("SAME_UPPER"); - } -} - -class OpAttrInfo { - public: - OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) - : attr_name_(attr_name), - onnx_attr_name_(onnx_attr_name), - onnx_attr_type_(onnx_attr_type), - fn_gen_attr_(fn_gen_attr) {} - ~OpAttrInfo() {} - - const std::string &attr_name() const { return attr_name_; } - const std::string &onnx_attr_name() const { return onnx_attr_name_; } - onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } - GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } - - private: - std::string attr_name_; // attribute name of MindSpore - std::string onnx_attr_name_; // corresponding attribute name of ONNX - onnx::AttributeProto_AttributeType onnx_attr_type_; // corresponding attribute type of ONNX - GenAttrFuncType fn_gen_attr_; // function used convert -}; - -class OpNameInfo { - public: - OpNameInfo &set_op_type(const std::string &op_type) { - op_type_ = op_type; - return *this; - } - - const std::string &op_type() const { return op_type_; } - - OpNameInfo &set_onnx_type(const std::string &onnx_type) { - onnx_type_ = onnx_type; - return *this; - } - - const std::string &onnx_type() const { return onnx_type_; } - - OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { - op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); - return *this; - } - - const std::vector &op_attrs() const { return op_attrs_; } - - private: - std::string op_type_; // operator type of MindSpore - std::string onnx_type_; // corresponding ONNX operator type - std::vector op_attrs_; // operator attributes map info -}; - -#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ - OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } - -OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) - -OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo()) - -OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, - OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS, - SetAttrTupleValueToProto<0>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - Conv2D, Conv, - OpNameInfo() - .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) - .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) - .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, - [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, - const PrimitivePtr &prim) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - auto attr_value = GetValue(value); - if (attr_value == "valid") { - attr_proto->set_s("VALID"); - } else if (attr_value == "same") { - attr_proto->set_s("SAME_UPPER"); - } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 'pads' - attr_proto->set_name("pads"); - 
SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, - prim); - } - }) - .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) -OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm, - OpNameInfo() - .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto) - .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto)) - -OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization, - OpNameInfo().Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT, - SetAttrValueToProto)) - -OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(ReduceMean, ReduceMean, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, - OpNameInfo() - .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT, - SetAttrValueToProto) - .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, - [](ValuePtr, onnx::AttributeProto_AttributeType, - onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - attr_proto->set_i(0); - })) - -OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE( - MaxPool, MaxPool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - MaxPoolWithArgmax, MaxPool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE( - AvgPool, AveragePool, - OpNameInfo() - .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) - .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) - .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) - -OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) - -#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name - -void RegisterOpConverters(const std::function &fn) { - fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); - fn(OP_CONVERT_FUNCTION_NAME(Mul)()); - - fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); - fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)()); - - fn(OP_CONVERT_FUNCTION_NAME(Conv2D)()); - fn(OP_CONVERT_FUNCTION_NAME(Argmax)()); - - fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); - fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); - fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); - fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); - - fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); - 
fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); - fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); - - fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); - fn(OP_CONVERT_FUNCTION_NAME(Concat)()); - fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); - fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); - fn(OP_CONVERT_FUNCTION_NAME(Sub)()); -} - -class OpConvertRegistry { - public: - ~OpConvertRegistry() { Clear(); } - - static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } - - static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } - - static OpConvertRegistry &GetSingleton() { - static OpConvertRegistry registry = OpConvertRegistry(); - return registry; - } - - static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } - - void Clear() noexcept { op_map_.clear(); } - - private: - OpConvertRegistry() {} - - std::unordered_map op_map_; -}; - -class OnnxExporter { - public: - OnnxExporter() {} - ~OnnxExporter() {} - - std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); - - private: - void InitModelInfo(); - - void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - - size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - const PrimitivePtr &prim, const std::vector &inputs, - onnx::GraphProto *graph_proto); - - static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); - void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); - - void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, - std::unordered_map *op_merged_infos_ptr); - void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - - void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - - void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - - void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, 
- onnx::GraphProto *graph_proto); - void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *graph_proto); - - void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *graph_proto); - std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *const graph_proto); - - void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); - void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); - - size_t AllocateNodeIndex() { return ++onnx_node_index_; } - - void ResetNodeIndex() { onnx_node_index_ = 0; } - - static int GetInt32Value(const AnfNodePtr &node) { - auto value_node_ptr = dyn_cast(node); - MS_EXCEPTION_IF_NULL(value_node_ptr); - return GetValue(value_node_ptr->value()); - } - - onnx::ModelProto model_; - - size_t onnx_node_index_ = 0; -}; - -std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - return ""; - } - ResetNodeIndex(); - OpConvertRegistry::GetSingleton().Clear(); - OpConvertRegistry::RegisterAllOpConverters(); - InitModelInfo(); - onnx::GraphProto *graph_proto = model_.mutable_graph(); - ExportFuncGraph(func_graph, graph_proto); - return model_.SerializeAsString(); -} - -void OnnxExporter::InitModelInfo() { - model_.set_ir_version(onnx::IR_VERSION_2019_1_22); - model_.set_producer_name("MindSpore"); - model_.set_producer_version("1.0"); - onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); - opset_proto->set_version(9); -} - -void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - std::map node_map; - - MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString(); - - onnx_node_index_ = func_graph->parameters().size(); - - // set graph name - graph_proto->set_name(func_graph->ToString()); - - // export parameters - // 1. all parameters (with or without default value) will be mapped to ONNX parameters - // 2. 
parameters with default value will be mapped to ONNX initializers - ExportParameters(func_graph, graph_proto); - - // export computational nodes and output nodes - ExportNodes(func_graph, &node_map, graph_proto); - - MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString(); -} - -void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { - for (auto &param : func_graph->parameters()) { - const ParameterPtr param_ptr = dyn_cast<Parameter>(param); - if (param_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' cannot be cast to a Parameter."; - } - - onnx::ValueInfoProto *input_proto = graph_proto->add_input(); - input_proto->set_name(param_ptr->ToString()); - SetValueInfoType(param_ptr, input_proto); - - if (!param_ptr->has_default()) { - continue; - } - // a parameter with a default value is an ONNX initializer - onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); - initializer_proto->set_name(param_ptr->ToString()); - SetTensorProtoInfo(param_ptr, initializer_proto); - // set value for initializer - auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param()->value()); - if (tensor) { - initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); - } - } -} - -onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { - // clang-format off - static std::unordered_map<TypeId, onnx::TensorProto_DataType> type_map = { - {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, - {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, - {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, - {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, - {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, - {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, - {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, - {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, - {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, - {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, - {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, - {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, - }; - // clang-format on - - auto iter = type_map.find(type_id); - if (iter == type_map.end()) { - MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id; - } - - return iter->second; -} - -void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { - auto dtype = node->Type(); - auto shape = node->Shape(); - onnx::TypeProto *type_proto = value_proto->mutable_type(); - if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) { - auto tensor = dyn_cast<TensorType>(dtype); - auto elem_type = tensor->element(); - const auto &dims = dyn_cast<abstract::Shape>(shape)->shape(); - // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 - auto type = is_output ?
onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); - type_proto->mutable_tensor_type()->set_elem_type(type); - - for (const auto &dim : dims) { - type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); - } - } -} - -void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { - auto dtype = param->Type(); - auto shape = param->Shape(); - if (!dtype->isa() || !shape->isa()) { - MS_LOG(EXCEPTION) << "Parameter " << param->name() << " is not a regular tensor, with value " << param->ToString(); - } - - auto tensor = dyn_cast(dtype); - auto elem_type = tensor->element(); - const auto &dims = dyn_cast(shape)->shape(); - tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); - for (const auto &dim : dims) { - tensor_proto->add_dims(dim); - } -} - -void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, - std::unordered_map *op_merged_infos_ptr) { - std::unordered_map &op_merged_infos = *op_merged_infos_ptr; - - for (auto &node : nodes) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (cnode == func_graph->get_return()) { - // if the key `input` does not exist, just create a new one - op_merged_infos[cnode].referred_count += 1; - } - for (auto &input : cnode->inputs()) { - if (!input->isa()) { - continue; - } - // if the key `input` does not exist, just create a new one - op_merged_infos[input].referred_count += 1; - } - // MindSpore Conv + BiasAdd --> ONNX Conv - if (cnode->IsApply(std::make_shared("BiasAdd")) && - IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) { - op_merged_infos[cnode].mode = OP_MERGE_CONV; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(std::make_shared("BiasAdd")) && - IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) { - op_merged_infos[cnode].mode = OP_MERGE_GEMM; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(prim::kPrimTupleGetItem) && - IsPrimitiveCNode(cnode->input(1), std::make_shared("BatchNorm")) && - GetInt32Value(cnode->input(2)) == 0) { - op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } else if (cnode->IsApply(prim::kPrimTupleGetItem) && - IsPrimitiveCNode(cnode->input(1), std::make_shared("MaxPoolWithArgmax")) && - GetInt32Value(cnode->input(2)) == 0) { - op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; - op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; - op_merged_infos[cnode->input(1)].referred_count -= 1; - } - } -} - -/** - * AnfNode - * +-- CNode - * +-- ANode - * | +-- Parameter - * | `-- ValueNode - */ -void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); - - std::unordered_map op_merged_infos; - MatchAndMark(func_graph, nodes, &op_merged_infos); - - for (const AnfNodePtr &node : nodes) { - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto iter = op_merged_infos.find(cnode); - // the node is not referenced by any other nodes, skip it - if (iter == op_merged_infos.end()) { - continue; - } - auto merged_info = iter->second; - // the op node is merged with other node and not used any 
more, skip it - if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) { - continue; - } - if (cnode == func_graph->get_return()) { - ExportOutput(func_graph, cnode, node_map_ptr, graph_proto); - continue; - } - switch (merged_info.mode) { - case OP_MERGE_CONV: - ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_GEMM: - ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_BATCH_NORM: - ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); - break; - case OP_MERGE_MAXPOOL_WITH_ARGMAX: - ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto); - break; - default: - ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); - break; - } - } -} - -void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_shape = node->input(2); - std::string name_shape; - if (input_shape->isa()) { - auto const_node_idx = AllocateNodeIndex(); - (*node_map_ptr)[input_shape] = const_node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - name_shape = std::to_string(const_node_idx); - node_proto->add_output(name_shape); - - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - ConvertTupleToTensor(dyn_cast(input_shape)->value(), attr_proto->mutable_t()); - } else { - name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto); - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Reshape."; - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimReshape->name()); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_shape); -} - -void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_axis = node->input(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - auto name = prim::kPrimReduceMean->name(); - if (node->IsApply(prim::kPrimReduceSum)) { - name = prim::kPrimReduceSum->name(); - } - node_proto->set_op_type(name); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_data); - - if (input_axis->isa()) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("axes"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); - auto axis_value = dyn_cast(input_axis)->value(); - auto int_ptr = dyn_cast(axis_value); - if (int_ptr == nullptr) { - auto tuple_ptr = dyn_cast(axis_value); - MS_EXCEPTION_IF_NULL(tuple_ptr); - for (size_t i = 0; i < tuple_ptr->size(); ++i) { - attr_proto->add_ints(GetValue((*tuple_ptr)[i])); - } - } else { - attr_proto->add_ints(int_ptr->value()); - } - } else { - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; - } -} - -void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr 
&node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_type = node->input(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimCast->name()); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_data); - - if (input_type->isa()) { - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("to"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - auto type_value = dyn_cast(input_type)->value(); - auto type_ptr = dyn_cast(type_value); - MS_EXCEPTION_IF_NULL(type_ptr); - attr_proto->set_i(GetOnnxDataType(type_ptr->type_id())); - } else { - MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) to ONNX Cast to attribute."; - } -} - -void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - - auto x_shape = dyn_cast(node->input(1)->Shape()); - auto slope_shape = dyn_cast(node->input(2)->Shape()); - MS_EXCEPTION_IF_NULL(x_shape); - MS_EXCEPTION_IF_NULL(slope_shape); - - // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] - if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { - auto node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Unsqueeze"); - node_proto->add_output(std::to_string(node_idx)); - - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); - attr_proto->set_name("axes"); - attr_proto->add_ints(1); - attr_proto->add_ints(2); - - node_proto->add_input(input_slope); - input_slope = std::to_string(node_idx); - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("PRelu"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_x); - node_proto->add_input(input_slope); -} - -void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Clip"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(input_x); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); - attr_proto->set_name("min"); - attr_proto->set_f(0.f); - attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); - attr_proto->set_name("max"); - attr_proto->set_f(6.f); -} - -void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto input_w = 
GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - auto x_shape = dyn_cast(node->input(1)->Shape()); - auto w_shape = dyn_cast(node->input(2)->Shape()); - MS_EXCEPTION_IF_NULL(x_shape); - MS_EXCEPTION_IF_NULL(w_shape); - if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { - MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; - } - if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { - MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; - } - // create w_shape constant node - auto node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto = graph_proto->add_node(); - std::string name_w_shape = std::to_string(node_idx); - node_proto->add_output(name_w_shape); - node_proto->set_op_type("Constant"); - // create Value Tensor - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - // reshape - tensor_proto->add_int64_data(w_shape->shape()[1]); - tensor_proto->add_int64_data(w_shape->shape()[0]); - tensor_proto->add_int64_data(w_shape->shape()[2]); - tensor_proto->add_int64_data(w_shape->shape()[3]); - - // add reshape node - node_idx = AllocateNodeIndex(); - node_proto = graph_proto->add_node(); - node_proto->set_op_type(prim::kPrimReshape->name()); - node_proto->add_input(input_w); - node_proto->add_input(name_w_shape); - input_w = std::to_string(node_idx); - node_proto->add_output(input_w); - - // add conv node - node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - node_proto = graph_proto->add_node(); - node_proto->set_op_type("Conv"); - node_proto->add_input(input_x); - node_proto->add_input(input_w); - node_proto->add_output(std::to_string(node_idx)); - // set attributes - AnfNodePtr op = node->input(0); - auto op_value = dyn_cast(op); - auto prim = dyn_cast(op_value->value()); - // set dilations - onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("dilations"); - SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, - prim); - // set group - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("group"); - onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - onnx_attr_proto->set_i(x_shape->shape()[1]); - // set kernel_shape - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("kernel_shape"); - SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, - prim); - - // set pad - onnx_attr_proto = node_proto->add_attribute(); - auto attr_value = GetValue(prim->GetAttr("pad_mode")); - onnx_attr_proto->set_name("auto_pad"); - onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); - if (attr_value == "valid") { - onnx_attr_proto->set_s("VALID"); - } else if (attr_value == "same") { - onnx_attr_proto->set_s("SAME_UPPER"); - } else { - onnx_attr_proto->set_name("pads"); - SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); - } - // set strides - onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name("strides"); - 
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); -} - -void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto multiples = node->input(2); - std::string name_multiples; - if (multiples->isa()) { - auto const_node_idx = AllocateNodeIndex(); - (*node_map_ptr)[multiples] = const_node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - name_multiples = std::to_string(const_node_idx); - node_proto->add_output(name_multiples); - - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("repeat"); - - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - ConvertTupleToTensor(dyn_cast(multiples)->value(), attr_proto->mutable_t()); - } else { - name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); - MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; - } - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Tile"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_multiples); -} - -void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - std::string name_exponent; - auto const_node_idx = AllocateNodeIndex(); - onnx::NodeProto *node_proto_exp = graph_proto->add_node(); - name_exponent = std::to_string(const_node_idx); - node_proto_exp->add_output(name_exponent); - - node_proto_exp->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - tensor_proto->set_name("exponent"); - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - tensor_proto->add_int64_data(2); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Pow"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_exponent); -} - -void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); - auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); - auto axis = node->input(3)->cast()->value(); - - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->set_op_type("Gather"); - node_proto->add_output(std::to_string(node_idx)); - node_proto->add_input(name_x); - node_proto->add_input(name_indices); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast(axis)->value())); -} - 
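The exporters above (Reshape, Tile, Square, GatherV2) share one recipe: materialize a compile-time operand as an ONNX Constant node, then emit the real operator consuming that constant's output name, with every output named by a monotonically increasing index. Below is a minimal, self-contained sketch of that recipe; FakeNodeProto and ExportWithConstOperand are illustrative stand-ins, not types from this file:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-in for onnx::NodeProto, with just enough fields to show the pattern.
struct FakeNodeProto {
  std::string op_type;
  std::vector<std::string> inputs;
  std::string output;
};

// Mirrors the Reshape/Tile recipe: first a Constant carrying the static operand,
// then the operator consuming the data input plus the constant's output name.
std::vector<FakeNodeProto> ExportWithConstOperand(const std::string &op_type, const std::string &data_name,
                                                  size_t *node_index) {
  FakeNodeProto constant{"Constant", {}, std::to_string(++*node_index)};
  FakeNodeProto node{op_type, {data_name, constant.output}, std::to_string(++*node_index)};
  return {constant, node};
}

int main() {
  size_t index = 0;  // ids are pre-incremented, so names start at "1", as in AllocateNodeIndex()
  for (const auto &n : ExportWithConstOperand("Reshape", "x", &index)) {
    std::cout << n.op_type << " -> output \"" << n.output << "\"\n";
  }
  return 0;
}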
-void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { - // The 2nd input of MindSpore 'Reshape' is a tuple, but ONNX's is a tensor, so it needs converting - if (node->IsApply(prim::kPrimReshape)) { - return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); - } - - if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { - return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) - if (node->IsApply(prim::kPrimCast)) { - return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); - } - - // ONNX PRelu requires unidirectional broadcasting, so some processing is needed here - if (node->IsApply(std::make_shared<Primitive>("PReLU"))) { - return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) - if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) { - return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) - if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) { - return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Tile(x) --> ONNX Tile(x, repeat) - if (node->IsApply(prim::kPrimTile)) { - return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore Square(x) --> ONNX Pow(x, 2) - if (node->IsApply(prim::kPrimSquare)) { - return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); - } - - // MindSpore GatherV2(x, indices, axis) --> ONNX Gather[axis](x, indices) - if (node->IsApply(prim::kPrimGatherV2)) { - return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); - } - - auto inputs = node->inputs(); - if (inputs.size() < 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node are empty"; - } - - AnfNodePtr op = inputs[0]; - std::vector<AnfNodePtr> op_inputs; - // process node inputs 1,2,... first, since a ValueNode input requires creating a Constant operator here - for (size_t i = 1; i < inputs.size(); i++) { - op_inputs.push_back(inputs[i]); - } - auto op_value = dyn_cast<ValueNode>(op); - if (op_value == nullptr) { - MS_LOG(EXCEPTION) << "Need to support node op type " << op->type_name(); - } - auto prim = dyn_cast<Primitive>(op_value->value()); - if (prim == nullptr) { - MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); - } - - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); -} - -size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr, - const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs, - onnx::GraphProto *const graph_proto) { - auto op_map = OpConvertRegistry::GetOpConvertMap(); - auto op_iter = op_map.find(prim->name()); - if (op_iter == op_map.end()) { - MS_LOG(EXCEPTION) << "Cannot find key " << prim->name() << " in convert map"; - } - const OpNameInfo &op_convert_info = op_iter->second; - - auto node_idx = AllocateNodeIndex(); - - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->add_output(std::to_string(node_idx)); - node_proto->set_op_type(op_convert_info.onnx_type()); - - // Set inputs - for (const auto &input : inputs) { - auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); - node_proto->add_input(input_name); - } - - // Set node attributes - for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { -
const std::string &attr_name = attr.attr_name(); - ValuePtr attr_value = nullptr; - if (!attr_name.empty()) { - attr_value = prim->GetAttr(attr_name); - if (attr_value == nullptr) { - MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; - } - } - onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); - onnx_attr_proto->set_name(attr.onnx_attr_name()); - attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); - } - return node_idx; -} - -void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto conv_node = dyn_cast(node->input(1)); - auto input_x = conv_node->input(1); // conv input x - auto input_w = conv_node->input(2); // conv weight(filter) - auto input_b = node->input(2); // conv bias - - PrimitivePtr prim_conv = dyn_cast((dyn_cast(conv_node->input(0)))->value()); - std::vector inputs{input_x, input_w, input_b}; - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - auto matmul_node = dyn_cast(node->input(1)); - auto input_x = matmul_node->input(1); // matmul input x - auto input_y = matmul_node->input(2); // matmul input y - auto input_b = node->input(2); // matmul bias - - PrimitivePtr prim_matmul = dyn_cast((dyn_cast(matmul_node->input(0)))->value()); - std::vector inputs{input_x, input_y, input_b}; - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto batch_norm_node = dyn_cast(node->input(1)); - - PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); - std::vector inputs; - for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { - inputs.push_back(batch_norm_node->input(i)); - } - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); -} - -void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, - std::map *node_map_ptr, - onnx::GraphProto *const graph_proto) { - auto maxpool_with_argmax_node = dyn_cast(node->input(1)); - - PrimitivePtr prim_maxpool_with_argmax = - dyn_cast((dyn_cast(maxpool_with_argmax_node->input(0)))->value()); - std::vector inputs; - for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { - inputs.push_back(maxpool_with_argmax_node->input(i)); - } - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); -} - -void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, - std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { - if (node->inputs().size() != 2) { - MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; - } - AnfNodePtr arg = node->input(1); - std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); - onnx::ValueInfoProto *output_proto = graph_proto->add_output(); - output_proto->set_name(name); - SetValueInfoType(arg, output_proto, false); -} - -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, - onnx::GraphProto *const 
graph_proto) { - if (node->isa<CNode>()) { - auto iter = node_map_ptr->find(node); - if (iter == node_map_ptr->end()) { - MS_LOG(EXCEPTION) << "Cannot find node '" << node->ToString() << "' in node_map"; - } - return std::to_string(iter->second); - } - - if (node->isa<Parameter>()) { - return node->ToString(); - } - - // for a ValueNode input, create a Constant operator - if (node->isa<ValueNode>()) { - auto iter = node_map_ptr->find(node); - if (iter != node_map_ptr->end()) { - return std::to_string(iter->second); - } - // node ids start at 1, so allocate a fresh id for the Constant node being created - auto node_idx = AllocateNodeIndex(); - (*node_map_ptr)[node] = node_idx; - std::string node_name = std::to_string(node_idx); - - onnx::NodeProto *node_proto = graph_proto->add_node(); - node_proto->add_output(node_name); - - SetNodeAttribute(node->cast<ValueNodePtr>()->value(), node_proto); - - return node_name; - } - - MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); -} - -void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { - auto tuple_ptr = dyn_cast<ValueTuple>(value); - MS_EXCEPTION_IF_NULL(tuple_ptr); - if (tuple_ptr->size() == 0) { - MS_LOG(EXCEPTION) << "Convert tuple to tensor failed, the size of the converted tuple is 0."; - } - auto type_id = (*tuple_ptr)[0]->type()->type_id(); - for (size_t i = 1; i < tuple_ptr->size(); ++i) { - if ((*tuple_ptr)[i]->type()->type_id() != type_id) { - MS_LOG(EXCEPTION) << "Convert tuple to tensor failed, the types of the tuple elements are not the same."; - } - } - - tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size())); - tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); - for (size_t i = 0; i < tuple_ptr->size(); ++i) { - ValuePtr elem = (*tuple_ptr)[i]; - if (elem->isa<Int8Imm>()) { - tensor_proto->add_int64_data(dyn_cast<Int8Imm>(elem)->value()); - } else if (elem->isa<Int16Imm>()) { - tensor_proto->add_int64_data(dyn_cast<Int16Imm>(elem)->value()); - } else if (elem->isa<Int32Imm>()) { - tensor_proto->add_int64_data(dyn_cast<Int32Imm>(elem)->value()); - } else if (elem->isa<Int64Imm>()) { - tensor_proto->add_int64_data(dyn_cast<Int64Imm>(elem)->value()); - } else { - MS_LOG(EXCEPTION) << "Convert tuple to tensor failed, unexpected tuple element type " << elem->type()->type_name() - << "."; - } - } -} - -void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { - node_proto->set_op_type("Constant"); - onnx::AttributeProto *attr_proto = node_proto->add_attribute(); - attr_proto->set_name("value"); - if (value->isa<Int32Imm>()) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); - auto casted_value = dyn_cast<Int32Imm>(value); - if (casted_value == nullptr) { - MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type Int32Imm failed."; - } - auto attr_value = casted_value->value(); - attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); - } else if (value->isa<tensor::Tensor>()) { - attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); - onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); - auto data = dyn_cast<tensor::Tensor>(value); - tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes())); - auto dtype = data->data_type(); - auto shape = data->shape_c(); - - tensor_proto->set_data_type(GetOnnxDataType(dtype)); - for (const auto &dim : shape) { - tensor_proto->add_dims(dim); - } - } else { - MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " as attribute for Constant node"; - } -} - -std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) {
- OnnxExporter exporter; - return exporter.GetOnnxProtoString(func_graph); -} -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/CMakeLists.txt b/mindspore/ccsrc/operator/CMakeLists.txt deleted file mode 100644 index 88bcf0e532..0000000000 --- a/mindspore/ccsrc/operator/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _OPERATOR_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPERATOR_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) -add_library(_mindspore_operator_obj OBJECT ${_OPERATOR_SRC_FILES}) diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc deleted file mode 100644 index 52b71f410f..0000000000 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ /dev/null @@ -1,432 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/cc_implementations.h" -#include -#include -#include -#include -#include -#include "utils/misc.h" -#include "utils/log_adapter.h" -#include "utils/convert_utils.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support primitive operators definition -namespace prim { -enum class DataType { kInt, kFloat, kDouble, kUnknown }; - -// Whether has a T type data in AnyPtrList. 
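// (InferType below builds on HasType to form a simple promotion lattice over
// the whole list: any double element makes the result kDouble, otherwise any
// float makes it kFloat, otherwise any int makes it kInt, and anything else
// is kUnknown. For example, a list holding {int, float} infers to kFloat.)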
-template <typename T> -bool HasType(const AnyPtrList &list) { - bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); }); - return ret; -} - -DataType InferType(const AnyPtrList &list) { - if (HasType<double>(list)) { - return DataType::kDouble; - } else if (HasType<float>(list)) { - return DataType::kFloat; - } else if (HasType<int>(list)) { - return DataType::kInt; - } - return DataType::kUnknown; -} - -enum OpType { ADD, SUB, MUL, DIV, MOD }; - -template <typename T> -bool IsSignedIntOverflow(T x, T y, OpType opType) { - auto max = std::numeric_limits<T>::max(); - auto min = std::numeric_limits<T>::min(); - - if (opType == OpType::ADD) { - return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x); - } - - if (opType == OpType::SUB) { - return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x); - } - - if (opType == OpType::MUL) { - return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) || - (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x); - } - - if (opType == OpType::DIV || opType == OpType::MOD) { - return x == min && static_cast<int64_t>(y) == -1; - } - - MS_LOG(EXCEPTION) << "Unsupported operation type."; -} - -template <typename T> -T InnerScalarAdd(T x, T y) { - if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) { - MS_LOG(EXCEPTION) << "Overflow of the sum of two signed numbers, x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x + y; -} - -template <typename T> -T InnerScalarSub(T x, T y) { - if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) { - MS_LOG(EXCEPTION) << "Overflow of the difference of two signed numbers, x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x - y; -} - -template <typename T> -T InnerScalarMul(T x, T y) { - if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) { - MS_LOG(EXCEPTION) << "Overflow of the product of two signed numbers, x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return x * y; -} - -template <typename T> -float InnerScalarDiv(T x, T y) { - if (y == 0) { - MS_LOG(EXCEPTION) << "Divisor cannot be zero."; - } - if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) { - MS_LOG(EXCEPTION) << "Overflow of the quotient of two signed numbers, x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - return static_cast<float>(x) / static_cast<float>(y); -} - -template <typename T> -T InnerScalarFloordiv(T x, T y) { - auto ret = std::floor(InnerScalarDiv(x, y)); - if (std::is_integral<T>::value) { - return static_cast<T>(ret); - } - return ret; -} - -template <typename T> -T InnerScalarMod(T x, T y) { - if (y == 0) { - MS_LOG(EXCEPTION) << "Cannot mod by zero."; - } - if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) { - MS_LOG(EXCEPTION) << "Overflow of the mod of two signed numbers, x: " << std::to_string(x) - << ", y: " << std::to_string(y) << "."; - } - if (std::is_integral<T>::value) { - return static_cast<int>(x) % static_cast<int>(y); - } - int x_int = std::floor(x); - int y_int = std::ceil(y); - int max = x_int / y_int; - float ret = x - y * max; - return ret; -} - -template <typename T, typename U> -T InnerScalarPow(T x, U y) { - return std::pow(x, y); -} - -template <typename T, typename U> -bool InnerScalarEq(T x, U y) { - double error = static_cast<double>(x) - static_cast<double>(y); - error = fabs(error); - return error < DBL_EPSILON; -} - -template <typename T, typename U> -bool InnerScalarLt(T x, U y) { - return x < y; -} - -template <typename T, typename U> -bool InnerScalarGt(T x, U y) { - return x > y; -} - -template <typename T, typename U> -bool
InnerScalarNe(T x, U y) { - return !InnerScalarEq(x, y); -} - -template -bool InnerScalarLe(T x, U y) { - return x <= y; -} - -template -bool InnerScalarGe(T x, U y) { - return x >= y; -} - -#define SCALAR_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList &list) { \ - do { \ - if (list.size() < 2) { \ - MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ - } \ - ValuePtr x = list[0]; \ - ValuePtr y = list[1]; \ - MS_EXCEPTION_IF_NULL(x); \ - MS_EXCEPTION_IF_NULL(y); \ - if (x->isa() && y->isa()) { \ - double sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - int sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(IntToFloat(GetValue(x)), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - float sum = InnerScalar##op_t(GetValue(x), IntToFloat(GetValue(y))); \ - return MakeValue(sum); \ - } \ - MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ - << ", y: " << y->ToString(); \ - } while (0); \ - } - -SCALAR_OP(Add) -SCALAR_OP(Sub) -SCALAR_OP(Mul) -SCALAR_OP(Div) -SCALAR_OP(Mod) -SCALAR_OP(Pow) -SCALAR_OP(Floordiv) - -#define LOGIC_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList &list) { \ - if (list.size() < 2) { \ - MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ - } \ - ValuePtr x = list[0]; \ - ValuePtr y = list[1]; \ - MS_EXCEPTION_IF_NULL(x); \ - MS_EXCEPTION_IF_NULL(y); \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - if (x->isa() && y->isa()) { \ - bool sum = InnerScalar##op_t(GetValue(x), GetValue(y)); \ - return MakeValue(sum); \ - } \ - MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \ - << ", y: " << y->ToString() << "."; \ - } - -LOGIC_OP(Eq) -LOGIC_OP(Lt) -LOGIC_OP(Gt) -LOGIC_OP(Ne) -LOGIC_OP(Le) -LOGIC_OP(Ge) - -ValuePtr ScalarUAdd(const ValuePtrList &list) { - if (list.size() != 1) { - MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - return x; -} - -ValuePtr ScalarUSub(const ValuePtrList &list) { - if (list.size() != 1) { - MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - - if (x->isa()) { - int32_t sum = -1 * GetValue(x); - return MakeValue(sum); - } - if (x->isa()) { - float sum = -1.0f * 
GetValue<float>(x); - return MakeValue(sum); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for ScalarUSub, x: " << x->ToString() << "."; -} - -ValuePtr ScalarLog(const ValuePtrList &list) { - if (list.empty()) { - MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - - if (x->isa<FP64Imm>()) { - double v = log(GetValue<double>(x)); - return MakeValue(v); - } - if (x->isa<FP32Imm>()) { - auto v = static_cast<float>(log(GetValue<float>(x))); - return MakeValue(v); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for ScalarLog, x: " << x->ToString(); -} - -ValuePtr BoolNot(const ValuePtrList &list) { - if (list.empty()) { - MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; - } - ValuePtr x = list[0]; - MS_EXCEPTION_IF_NULL(x); - bool convert = false; - - if (ValueToBool(x, &convert)) { - auto res = !convert; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for BoolNot, x: " << x->ToString(); -} - -ValuePtr BoolAnd(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less than 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b && y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for BoolAnd, x: " << x->ToString() << "."; -} - -ValuePtr BoolOr(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less than 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b || y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for BoolOr, x: " << x->ToString() << "."; -} - -ValuePtr BoolEq(const ValuePtrList &list) { - if (list.size() < 2) { - MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; - } - ValuePtr x = list[0]; - ValuePtr y = list[1]; - MS_EXCEPTION_IF_NULL(x); - MS_EXCEPTION_IF_NULL(y); - bool x_b = false; - bool y_b = false; - - if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) { - auto res = x_b == y_b; - return MakeValue(res); - } - - MS_LOG(EXCEPTION) << "Unsupported Value for BoolEq, x: " << x->ToString() << "."; -} - -std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) { - int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); - if (dlen < 0) { - for (int i = 0; i < -dlen; ++i) { - (void)shpx.insert(shpx.begin(), 1); - } - } else if (dlen > 0) { - for (int i = 0; i < dlen; i++) { - (void)shpy.insert(shpy.begin(), 1); - } - } - if (shpx.size() != shpy.size()) { - MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; - } - std::vector<int> shp; - for (size_t i = 0; i < shpx.size(); i++) { - auto a = shpx[i]; - auto b = shpy[i]; - if (a == 1) { - shp.push_back(b); - } else if (b == 1) { - shp.push_back(a); - } else if (a == -1) { - shp.push_back(b); - } else if (b == -1) { - shp.push_back(a); - } else if (a == b) { - shp.push_back(a); - } else { - return std::vector<int>(); - } - } - return shp; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc deleted file mode 100644 index db3055ad9a..0000000000 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ /dev/null @@ -1,971
+0,0 @@ - -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/composite.h" -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "abstract/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "abstract/dshape.h" -#include "abstract/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "./common.h" -#include "ir/signature.h" -#include "debug/trace.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractTensor = mindspore::abstract::AbstractTensor; -using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; - -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractDictionaryPtr; -using mindspore::abstract::AbstractEllipsis; -using mindspore::abstract::AbstractEllipsisPtr; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractFunctionPtr; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractNone; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractSlice; -using mindspore::abstract::AbstractTuple; - -ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul}, - {"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod}, - {"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt}, - {"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe}, - {"__ge__", kPrimScalarGe}}; - -const MetaFuncGraphPtr kTail = std::make_shared("tail"); - -// copy from python API: reduce. -// Apply a function of two arguments cumulatively to the items of a sequence, -// from left to right, so as to reduce the sequence to a single value.For example, -// reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). 
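For reference, the same left fold written as a minimal, self-contained C++ program; FoldLeft is an invented name used only for illustration:

#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

// Left fold: combine(combine(combine(x0, x1), x2), ...), mirroring Reduce below.
int FoldLeft(const std::function<int(int, int)> &combine, const std::vector<int> &xs) {
  if (xs.size() < 2) {
    throw std::runtime_error("length of inputs of FoldLeft is less than 2");
  }
  int acc = combine(xs[0], xs[1]);
  for (size_t i = 2; i < xs.size(); ++i) {
    acc = combine(acc, xs[i]);
  }
  return acc;
}

int main() {
  std::cout << FoldLeft([](int x, int y) { return x + y; }, {1, 2, 3, 4, 5}) << "\n";  // prints 15
  return 0;
}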
-AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { - std::shared_ptr ret; - size_t size = list.size(); - if (size < 2) { - MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; - } - - AnyPtrList input; - input.push_back(list[0]); - input.push_back(list[1]); - ret = std::make_shared(func(input)); - - for (size_t i = 2; i < size; ++i) { - input.clear(); - input.push_back(ret); - input.push_back(list[i]); - ret = std::make_shared(func(input)); - } - - return ret; -} - -AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { - size_t size = list.size(); - if (size < 2) { - MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; - } - - std::vector input; - input.push_back(list[0]); - input.push_back(list[1]); - AnfNodePtr ret = func(input); - - for (size_t i = 2; i < size; ++i) { - input.clear(); - input.push_back(ret); - input.push_back(list[i]); - ret = func(input); - } - - return ret; -} - -ValuePtr kCompositeHyperMap = std::make_shared(); - -void HyperMap::Init() { - if (fn_leaf_) { - name_ = "hyper_map[" + fn_leaf_->name() + "]"; - } - signatures_ = - // def hypermap(func:read, *args:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); -} - -HyperMap::HyperMap(const std::shared_ptr &fn_leaf) - : MetaFuncGraph("hyper_map"), - fn_leaf_(fn_leaf), - broadcast_(false), - nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { - Init(); -} - -HyperMap::HyperMap(const HyperMap &h) - : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { - Init(); -} - -AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector inputs; - if (fn_arg != nullptr) { - inputs.push_back(fn_arg); - } else { - inputs.push_back(NewValueNode(fn_leaf_)); - } - - (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), - [](const std::pair &item) { return item.first; }); - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "List in HyperMap should have same length"; - } - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. 
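// (fn_rec below therefore wraps a copy of this HyperMap: for a list of length
// n, the generated graph applies fn_rec to the i-th element of every argument
// list and wraps the n results in MakeList, recursing element-wise until a
// leaf type is reached.)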
- auto fn_rec = NewValueNode(std::make_shared(*this)); - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeList)); - - for (int i = 0; i < SizeToInt(size); ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { - auto lhs = std::static_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "tuple in HyperMap should have same length"; - } - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. - auto fn_rec = NewValueNode(std::make_shared(*this)); - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - for (int i = 0; i < SizeToInt(size); ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - MS_EXCEPTION_IF_NULL(type); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); - inputs.push_back(NewValueNode(type)); - - // cannot use shared_from_base() also known as this, as it will make a reference cycle on - // hypermap and graph generated, it will cause memory leak. - auto fn_rec = NewValueNode(std::make_shared(*this)); - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { - std::vector inputs2; - inputs2.push_back(fn_rec); - if (fn_arg) { - inputs2.push_back(fn_arg); - } - - int j = 0; - for (auto item : arg_map) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; - } - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { - bool found = false; - TypeId id = kObjectTypeEnd; - std::pair pair; - for (auto &item : arg_map) { - pair = item; - id = item.second->type_id(); - if (nonleaf_.count(id)) { - found = true; - break; - } - } - - if (found) { - // In a nonleaf situation, all arguments must have the same generic. 
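// For example, passing one tuple and one list in the same hyper_map call
// trips the check below, because their top-level type ids differ even though
// both are sequences.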
- bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair<AnfNodePtr, TypePtr> &item) { - if (item.first != pair.first) { - return item.second->type_id() != pair.second->type_id(); - } - return false; - }); - if (is_not_same) { - std::ostringstream oss; - oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" - << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - int idx = 0; - for (auto &item : arg_map) { - oss << ++idx << ": " << item.second->ToString() << "\n"; - } - MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); - } - } - - switch (id) { - case kObjectTypeList: { - auto type = std::static_pointer_cast<List>(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - case kObjectTypeTuple: { - auto type = std::static_pointer_cast<Tuple>(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - case kObjectTypeClass: { - auto type = std::static_pointer_cast<Class>(pair.second); - return FullMake(type, func_graph, fn_arg, arg_map); - } - default: - return FullMake(pair.second, func_graph, fn_arg, arg_map); - } -} - -ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { - TypePtr type_tensor = std::make_shared<TensorType>(); - bool flag = std::any_of( - args_spec_list.begin(), args_spec_list.end(), - [type_tensor](const std::pair<AnfNodePtr, TypePtr> &item) { return IsSubType(item.second, type_tensor); }); - if (flag && broadcast_) { - ArgsPairList ret; - for (auto &item : args_spec_list) { - if (!IsSubType(item.second, type_tensor)) { - TypePtr type_tensor_ele = std::make_shared<TensorType>(item.second); - ret.push_back( - std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele)); - } else { - ret.push_back(std::make_pair(item.first, item.second)); - } - } - return ret; - } - return args_spec_list; -} - -FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { - FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("hyper_map"); - - AnfNodePtr ptrFnArg = nullptr; - std::size_t i = 0; - ArgsPairList argmap; - ArgsPairList argmap2; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - i = 1; - } - - std::size_t size = args_spec_list.size(); - for (; i < size; ++i) { - argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); - } - - argmap2 = Harmonize(ptrGraph, argmap); - ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2)); - return ptrGraph; -} - -abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - if (fn_leaf_ == nullptr) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - // Assert that hypermap's function param does not contain free variables - if (args_spec_list[0]->isa<FuncGraphAbstractClosure>()) { - auto graph_func = dyn_cast<FuncGraphAbstractClosure>(args_spec_list[0]); - auto func_graph = graph_func->func_graph(); - if (func_graph->parent() != nullptr) { - MS_LOG(EXCEPTION) << "HyperMap doesn't support closures with free variables yet."; - } - } - } - - AbstractBasePtrList broadened; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - return broadened; -} - -REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { - (void)py::class_>(*m, "HyperMap_")
- .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); - })); - -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { - MS_EXCEPTION_IF_NULL(a_tuple); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("tail"); - AnfNodePtr ptrTup = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - - int tuple_size = SizeToInt(a_tuple->size()); - for (int i = 1; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrTup, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { - MS_EXCEPTION_IF_NULL(a_list); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("tail"); - AnfNodePtr ptrList = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeList)); - - int list_size = SizeToInt(a_list->size()); - for (int i = 1; i < list_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), ptrList, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; - } - - AbstractBasePtr a = args_spec_list[0]; - abstract::AbstractTuplePtr a_tuple = dyn_cast(a); - if (a_tuple != nullptr) { - return GenerateTupleFuncGraph(a_tuple); - } - - abstract::AbstractListPtr a_list = dyn_cast(a); - if (a_list != nullptr) { - return GenerateListFuncGraph(a_list); - } - - MS_LOG(EXCEPTION) << "arg0 must be AbstractTuple or AbstractList, but: " << a->ToString(); -} - -REGISTER_PYBIND_DEFINE( - Tail_, ([](const py::module *m) { - (void)py::class_>(*m, "Tail_").def(py::init()); - })); - -FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - int tuple_size = SizeToInt(args_spec_list.size()); - - std::ostringstream ss; - ss << "▶make_tuple_" << tuple_size; - FuncGraphPtr fg = std::make_shared(); - fg->debug_info()->set_name(ss.str()); - - std::vector params; - params.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (int i = 0; i < tuple_size; ++i) { - params.push_back(fg->add_parameter()); - } - - // make fprob first result, maketuple's forward result. - AnfNodePtr out = fg->NewCNode(params); - - // make fprob second result, maketuple's backward function. 
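// The graph assembled here is the fprop of make_tuple: a pair of the forward
// tuple and its backward function. A minimal standalone sketch of that
// contract in plain C++ — `Env` is an illustrative stand-in for the
// EnvInstance slot, not a MindSpore type, and all names here are assumptions:
#include <functional>
#include <utility>
#include <vector>

struct Env {};  // placeholder for "gradients w.r.t. free variables"
using Tuple = std::vector<double>;
using Bprop = std::function<std::pair<Env, Tuple>(const Tuple &)>;

std::pair<Tuple, Bprop> MakeTupleFprop(const Tuple &elems) {
  Bprop bprop = [n = elems.size()](const Tuple &dout) {
    // dout carries one sensitivity per forwarded element; make_tuple's
    // backward simply routes dout[i] back to input i, plus an empty env.
    Tuple grads(dout.begin(), dout.begin() + n);
    return std::make_pair(Env{}, grads);
  };
  return {elems, bprop};
}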
- FuncGraphPtr b = std::make_shared(); - - ss.clear(); - ss << "◀make_tuple_" << tuple_size; - b->debug_info()->set_name(ss.str()); - AnfNodePtr dout = b->add_parameter(); - - std::vector grads; - grads.push_back(NewValueNode(prim::kPrimMakeTuple)); - grads.push_back(NewValueNode(newenv)); - for (int i = 0; i < tuple_size; ++i) { - grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); - } - - b->set_flag(FUNC_GRAPH_FLAG_CORE, true); - b->set_output(b->NewCNode(grads)); - - fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); - (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); - return fg; -} - -GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) - : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { - if (get_by_list) { - signatures_ = - // def grad(func:read, weight_list:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"weight_list", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindDefault}}); - } -} - -FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, - const std::vector ¶ms_list, const std::vector &args, - bool applyJ) { - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - auto weights_node = weights; - if (weights == nullptr && !args.empty()) { - weights_node = ret->NewCNode(args); - } - - ValueNodePtr opsJ = NewValueNode(prim::kPrimJ); - ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem); - - std::vector inputs; - if (applyJ) { - inputs.push_back(opsJ); - inputs.push_back(node); - node = ret->NewCNode(inputs); - } - - std::vector params; - for (size_t i = 0; i < params_list.size(); ++i) { - params.push_back(ret->add_parameter()); - } - - inputs.clear(); - inputs.push_back(node); - (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); - AnfNodePtr cnode = ret->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(opsTupleItem); - inputs.push_back(cnode); - inputs.push_back(NewValueNode(0)); - auto out = ret->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(opsTupleItem); - inputs.push_back(cnode); - inputs.push_back(NewValueNode(1)); - AnfNodePtr ptrBprop = ret->NewCNode(inputs); - - doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem); - return ret; -} - -void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, - ValueNodePtr opsTupleItem) { - MS_EXCEPTION_IF_NULL(func_graph); - - AnfNodePtr ptrBPropArg = nullptr; - if (sens_param_) { - ptrBPropArg = func_graph->add_parameter(); - } else { - auto ones_like = prim::GetPythonOps("ones_like"); - ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out}); - } - - AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg}); - - CNodePtr fv_bprop = nullptr; - if (get_by_list_) { - // python code: grads = hyper_map(F.partial(env_get, env), weights) - AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)}); - AnfNodePtr partial_env_get = - func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); - MetaFuncGraphPtr hyper_map = std::make_shared(); - fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); - } - - CNodePtr inputs_bprop = 
nullptr; - if (get_all_) { - inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp}); - } - - // Gradients wrt inputs and parameters - if (fv_bprop != nullptr && inputs_bprop != nullptr) { - func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); - return; - } - - // Gradients wrt parameters - if (fv_bprop != nullptr) { - func_graph->set_output(fv_bprop); - return; - } - - // Gradients wrt inputs - if (inputs_bprop != nullptr) { - func_graph->set_output(inputs_bprop); - return; - } - - // Gradients wrt first input. - // ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input - func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)})); -} - -// Generate the graph. -FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() < 1) { - MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " - << args_spec_list.size() << "."; - } - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - AbstractFunctionPtr fn = dyn_cast(args_spec_list[0]); - if (fn == nullptr) { - MS_LOG(EXCEPTION) << "GradOperation arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); - } - - // Waiting for implementation. - auto real_fn = dyn_cast(fn); - MS_EXCEPTION_IF_NULL(real_fn); - - FuncGraphPtr ptrGraph = real_fn->func_graph(); - MS_EXCEPTION_IF_NULL(ptrGraph); - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - FuncGraphPtr dfBuilder = std::make_shared(); - TraceManager::EndTrace(); - auto nparam = ptrGraph->parameters().size(); - - std::ostringstream ss; - ss << "grad{" << nparam << "}"; - dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true); - dfBuilder->debug_info()->set_name(ss.str()); - ParameterPtr param_graph = dfBuilder->add_parameter(); - - AnfNodePtr weights = nullptr; - if (get_by_list_) { - weights = dfBuilder->add_parameter(); - } - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimJ)); - inputs.push_back(param_graph); - auto jf = dfBuilder->NewCNode(inputs); - // df is checked in GetGrad - TraceManager::DebugTrace(std::make_shared(ptrGraph->debug_info())); - auto df = GetGrad(jf, weights, ptrGraph->parameters()); - TraceManager::EndTrace(); - dfBuilder->set_output(NewValueNode(df)); - - return dfBuilder; -} - -REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { - (void)py::class_>( - *m, "GradOperation_") - .def(py::init(), py::arg("fn")) - .def(py::init(), py::arg("fn"), py::arg("get_all"), - py::arg("get_by_list"), py::arg("sens_param")); - })); - -// Generate the ListMap func graph. -FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - size_t args_num = args_spec_list.size(); - // args: fn, list1, list2, ... 
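// The cond/body graphs assembled below implement, in dataflow form, roughly
// this eager loop — a minimal plain-C++ sketch, with std::vector standing in
// for MindSpore lists and iterators (all names here are illustrative):
#include <cstddef>
#include <functional>
#include <vector>

using Elem = double;

std::vector<Elem> ListMapEager(const std::function<Elem(const std::vector<Elem> &)> &fn,
                               const std::vector<std::vector<Elem>> &lists) {
  if (lists.empty()) {
    return {};
  }
  // "cond" graph: keep looping while hasnext(iter) holds for every iterator,
  // i.e. up to the length of the shortest list.
  std::size_t n = lists.front().size();
  for (const auto &l : lists) {
    n = (l.size() < n) ? l.size() : n;
  }
  std::vector<Elem> result;
  for (std::size_t i = 0; i < n; ++i) {
    // "next" on each iterator yields (value, rest); eagerly this is just l[i].
    std::vector<Elem> heads;
    heads.reserve(lists.size());
    for (const auto &l : lists) {
      heads.push_back(l[i]);
    }
    // "body" graph: append fn(heads) to the accumulated result list.
    result.push_back(fn(heads));
  }
  return result;
}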
- if (args_num < 2) { - MS_LOG(EXCEPTION) << "list_map takes at least two arguments"; - } - - for (size_t i = 1; i < args_num; ++i) { - if (typeid(args_spec_list[i]) != typeid(AbstractBase)) { - // The function currently not be use - MS_LOG(EXCEPTION) << "list_map requires lists, not {t}'"; - } - } - - FuncGraphPtr fg_ptr = std::make_shared(); - fg_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg_ptr->debug_info()->set_name("list_map"); - AnfNodePtr fn = fg_ptr->add_parameter(); - - std::vector lists; - for (size_t i = 1; i < args_num; ++i) { - lists.push_back(fg_ptr->add_parameter()); - } - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item}); - }); - - std::vector nexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); - }); - - std::vector values; - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item}); - }); - - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); - }); - - (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph}); - - FuncGraphPtr fgnext_ptr = std::make_shared(); - fgnext_ptr->debug_info()->set_name("body"); - - FuncGraphPtr fgcond_ptr = std::make_shared(); - fgcond_ptr->debug_info()->set_name("cond"); - - MakeCond(lists, fgnext_ptr, fgcond_ptr); - MakeNext(lists, fgcond_ptr, fgnext_ptr); - - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); - - auto inputs = output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - output_cnode->set_inputs(inputs); - - fg_ptr->set_output(output_cnode); - return fg_ptr; -} - -void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, - const FuncGraphPtr &fg_ptr) { - MS_EXCEPTION_IF_NULL(fg_ptr); - - AnfNodePtr fn = fg_ptr->add_parameter(); - AnfNodePtr resl = fg_ptr->add_parameter(); - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), - [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); - - std::vector hasnexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item}); - }); - - // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) - FuncGraphPtr fgtrue_ptr = std::make_shared(); - fgtrue_ptr->debug_info()->set_name("ftrue"); - fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); - auto inputs = fgtrue_output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - fgtrue_output_cnode->set_inputs(inputs); - fgtrue_ptr->set_output(fgtrue_output_cnode); - - FuncGraphPtr fgfalse_ptr = std::make_shared(); - fgfalse_ptr->debug_info()->set_name("ffalse"); - fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fgfalse_ptr->set_output(resl); - - AnfNodePtr output_cnode = 
fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), - NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); - fgtrue_ptr->set_output(output_cnode); -} - -void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, - const FuncGraphPtr &fg_ptr) { - MS_EXCEPTION_IF_NULL(fg_ptr); - AnfNodePtr fn = fg_ptr->add_parameter(); - - std::vector iters; - (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), - [fg_ptr](AnfNodePtr) { return fg_ptr->add_parameter(); }); - - std::vector nexts; - (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); - }); - - std::vector values; - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); - }); - - iters.clear(); - (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(1)}); - }); - - (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph}); - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); - - auto inputs = output_cnode->inputs(); - (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); - output_cnode->set_inputs(inputs); - fg_ptr->set_output(output_cnode); -} - -FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // args: tuple1, tuple2 - abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); - AbstractBasePtr abs_a = args_spec_list[0]; - AbstractBasePtr abs_b = args_spec_list[1]; - - abstract::AbstractTuplePtr a_tuple = dyn_cast(abs_a); - abstract::AbstractTuplePtr b_tuple = dyn_cast(abs_b); - if (a_tuple == nullptr || b_tuple == nullptr) { - MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", " - << args_spec_list[1]->ToString(); - } - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - AnfNodePtr p_tup_a = ret->add_parameter(); - AnfNodePtr p_tup_b = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeTuple)); - - int tuple_size = SizeToInt(a_tuple->size()); - for (int i = 0; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); - } - - tuple_size = SizeToInt(b_tuple->size()); - for (int i = 0; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); - } - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { - MS_EXCEPTION_IF_NULL(scalar); - return GetValue(scalar->BuildValue()); -} - -bool CheckIndexInRange(int index, int min, int max) { return (index >= min && index <= max); } - -int GetPositiveIndex(int index, int length) { - if (index < 0) { - index += length; - } - return index; -} - -int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { - MS_EXCEPTION_IF_NULL(member); - - if (member->isa()) { - return GetArgScalarValue(dyn_cast(member), member_name); - } - - if (member->isa()) { - return 
default_value;
-  }
-
-  MS_LOG(EXCEPTION) << member_name << " should be an AbstractScalar or AbstractNone, but got " << member->ToString();
-}
-
-void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index,
-                                 int *stop_index, int *step_value) {
-  MS_EXCEPTION_IF_NULL(tuple);
-  MS_EXCEPTION_IF_NULL(slice);
-  MS_EXCEPTION_IF_NULL(start_index);
-  MS_EXCEPTION_IF_NULL(stop_index);
-  MS_EXCEPTION_IF_NULL(step_value);
-
-  const std::string start_name("Slice start index");
-  const std::string stop_name("Slice stop index");
-  const std::string step_name("Slice step value");
-
-  int tuple_size = SizeToInt(tuple->size());
-  int start_default = 0;
-  int stop_default = tuple_size;
-  int step_default = 1;
-
-  *step_value = CheckSliceMember(slice->step(), step_default, step_name);
-  if (*step_value == 0) {
-    MS_LOG(EXCEPTION) << "TupleSlice requires that the step value not be 0, but got 0.";
-  }
-
-  if (*step_value < 0) {
-    start_default = tuple_size - 1;
-    stop_default = -1;
-  }
-
-  *start_index = CheckSliceMember(slice->start(), start_default, start_name);
-  *stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
-  if (!CheckIndexInRange(*start_index, -tuple_size, tuple_size - 1) ||
-      !CheckIndexInRange(*stop_index, -tuple_size - 1, tuple_size)) {
-    MS_LOG(EXCEPTION) << "TupleSlice the start index " << *start_index << " or end index " << *stop_index
-                      << " out of range, tuple size " << tuple_size << ".";
-  }
-
-  *start_index = GetPositiveIndex(*start_index, tuple_size);
-  if (!slice->stop()->isa<AbstractNone>()) {
-    *stop_index = GetPositiveIndex(*stop_index, tuple_size);
-  }
-}
-
-FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
-  // slice a tuple
-  // args: tuple, start index, end index, step
-  const std::string op_name("TupleSlice");
-  abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
-  AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
-
-  int start_index;
-  int stop_index;
-  int step_value;
-  GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
-
-  FuncGraphPtr ret = std::make_shared<FuncGraph>();
-  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  AnfNodePtr p_tuple = ret->add_parameter();
-  (void)ret->add_parameter();
-
-  std::vector<AnfNodePtr> elems;
-  elems.push_back(NewValueNode(prim::kPrimMakeTuple));
-  if (step_value > 0) {
-    for (int index = start_index; index < stop_index; index = index + step_value) {
-      elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
-    }
-  } else {
-    for (int index = start_index; index > stop_index; index = index + step_value) {
-      elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
-    }
-  }
-
-  ret->set_output(ret->NewCNode(elems));
-  return ret;
-}
-
-FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
-  // select indexed item
-  // args: tuple of items, index
-  const std::string op_name = std::string("TupleGetItemTensor");
-  abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
-  AbstractBasePtrList branches = branches_abs->elements();
-  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
-    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
-    ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-    AnfNodePtr functions =
ret_graph->add_parameter(); - auto index = ret_graph->add_parameter(); - - ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); - return ret_graph; - } - - MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; -} - -REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { - (void)py::class_>(*m, "TupleAdd_") - .def(py::init()); - })); - -REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { - (void)py::class_>(*m, "TupleSlice_") - .def(py::init()); - })); - -REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) { - (void)py::class_>( - *m, "TupleGetItemTensor_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h deleted file mode 100644 index 5944c81fb0..0000000000 --- a/mindspore/ccsrc/operator/composite/composite.h +++ /dev/null @@ -1,192 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "operator/composite/zip_operation.h" -#include "operator/composite/list_append_operation.h" -#include "operator/composite/do_signature.h" -#include "operator/composite/unpack_call.h" -#include "operator/composite/multitype_funcgraph.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractSlicePtr = abstract::AbstractSlicePtr; -using AbstractScalarPtr = abstract::AbstractScalarPtr; -using AbstractTensorPtr = abstract::AbstractTensorPtr; -using ElemwiseMap = std::unordered_map; -using ArgsPairList = std::vector>; - -class HyperMap : public MetaFuncGraph { - public: - explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); - HyperMap(const HyperMap &h); - void Init(); - HyperMap &operator=(const HyperMap &h) { - if (this != &h) { - fn_leaf_ = h.fn_leaf_; - broadcast_ = h.broadcast_; - nonleaf_ = h.nonleaf_; - if (fn_leaf_) { - name_ = "hyper_map[" + fn_leaf_->name() + "]"; - } - } - return *this; - } - ~HyperMap() override = default; - MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) - - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; - MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } - - private: - AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const 
std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_map); - AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); - ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); - - MultitypeFuncGraphPtr fn_leaf_; - bool broadcast_; - std::set nonleaf_; -}; -using HyperMapPtr = std::shared_ptr; - -class HyperMapPy : public HyperMap { - public: - explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} - ~HyperMapPy() override = default; - MS_DECLARE_PARENT(HyperMapPy, HyperMap) -}; -using HyperMapPyPtr = std::shared_ptr; - -extern ValuePtr kCompositeHyperMap; - -class Tail : public MetaFuncGraph { - public: - explicit Tail(const std::string &name) : MetaFuncGraph(name) {} - ~Tail() override = default; - MS_DECLARE_PARENT(Tail, MetaFuncGraph) - - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); - - friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } -}; -using TailPtr = std::shared_ptr; - -class MakeTupleGradient : public MetaFuncGraph { - public: - explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} - ~MakeTupleGradient() override = default; - MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } -}; -using MakeTupleGradientPtr = std::shared_ptr; - -class GradOperation : public MetaFuncGraph { - public: - explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, - bool sens_param = false); - ~GradOperation() override = default; - MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - - FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, - const std::vector &args = {}, bool applyJ = false); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - bool sens_param() const { return sens_param_; } - bool get_all_; - bool get_by_list_; - bool sens_param_; - - private: - void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, - ValueNodePtr opsTupleItem); -}; -using GradOperationPtr = std::shared_ptr; - -class ListMap { - public: - explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } - ~ListMap() = default; - void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); - void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); - - private: - std::string name_; - std::map, FuncGraphPtr> cache_; -}; - -class TupleAdd : public MetaFuncGraph { - public: - explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} - ~TupleAdd() override = default; - 
MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } -}; -using TupleAddPtr = std::shared_ptr; - -class TupleSlice : public MetaFuncGraph { - public: - explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} - ~TupleSlice() override = default; - MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } -}; -using TupleSlicePtr = std::shared_ptr; - -class TupleGetItemTensor : public MetaFuncGraph { - public: - explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} - ~TupleGetItemTensor() override = default; - MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { - return lhs.name_ == rhs.name_; - } -}; -using TupleGetItemTensorPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc deleted file mode 100644 index 90ecfdb9f9..0000000000 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ /dev/null @@ -1,338 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/composite/do_signature.h" -#include -#include - -#include "abstract/abstract_value.h" -#include "ir/anf.h" -#include "abstract/dshape.h" -#include "abstract/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "./common.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -const std::map type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt8, 2}, {kNumberTypeUInt8, 3}, - {kNumberTypeInt16, 4}, {kNumberTypeInt32, 5}, {kNumberTypeInt64, 6}, - {kNumberTypeFloat16, 7}, {kNumberTypeFloat32, 8}, {kNumberTypeFloat64, 9}}; -namespace { -const std::vector &GetSignature(const ValuePtr &function) { - static const auto empty = std::vector(); - if (function->isa() && function->cast()->has_signature()) { - return function->cast()->signatures(); - } else if (function->isa()) { - return function->cast()->signatures(); - } - return empty; -} - -void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, - const std::vector &signature, bool has_var, std::vector *const op_inputs) { - std::size_t sig_size = signature.size(); - auto positional_size = sig_size; - if (has_var) { - positional_size = sig_size - 1; - } - if (args_spec_list.size() < positional_size) { - for (size_t i = args_spec_list.size(); i < sig_size; ++i) { - auto default_value = signature[i].default_value; - if (default_value == nullptr) { - MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; - } else { - (*op_inputs).push_back(NewValueNode(default_value)); - } - } - } -} - -void SetMaxType(TypeId *max_type_id, size_t *max_type_number, const TypeId type_id, const size_t type_number) { - *max_type_id = type_id; - *max_type_number = type_number; -} - -bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, - TypeId *arg_type = nullptr) { - if (arg_value->isa()) { - if (is_write) { - arg_value = arg_value->cast()->ref_origin(); - } else { - arg_value = arg_value->cast()->ref(); - } - } - if (arg_value->isa()) { - auto tensor = arg_value->cast(); - auto tensor_type = tensor->element()->BuildType(); - MS_EXCEPTION_IF_NULL(tensor_type); - *arg_type_id = tensor_type->type_id(); - if (arg_type != nullptr) { - *arg_type = kObjectTypeTensorType; - } - return true; - } - if (arg_value->isa()) { - auto scalar = arg_value->cast(); - auto scalar_type = scalar->BuildType(); - MS_EXCEPTION_IF_NULL(scalar_type); - *arg_type_id = scalar_type->type_id(); - if (arg_type != nullptr) { - *arg_type = kObjectTypeNumber; - } - return true; - } - return false; -} - -TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, - const std::set &write_indices) { - TypeId max_type_id = kTypeUnknown; - size_t max_type_number = 0; - bool has_int8 = false; - bool has_scalar_int32 = false; - bool has_scalar_float32 = false; - for (const auto &index : indices) { - TypeId arg_type_id = kTypeUnknown; - TypeId arg_type = kTypeUnknown; - auto is_write = (write_indices.find(index) != write_indices.end()); - if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { - continue; - } - if (arg_type != kObjectTypeTensorType) { - if (arg_type_id == kNumberTypeInt32) { - has_scalar_int32 = true; - } else if (arg_type_id == kNumberTypeFloat32) { - has_scalar_float32 = true; - } - continue; - } - auto it = 
type_map.find(arg_type_id); - if (it == type_map.end()) { - continue; - } - if (arg_type_id == kNumberTypeInt8) { - has_int8 = true; - } - if (max_type_id == kTypeUnknown) { - SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); - continue; - } - if (it->second > max_type_number) { - SetMaxType(&max_type_id, &max_type_number, arg_type_id, it->second); - } - } - - if (max_type_id == kNumberTypeUInt8 && has_int8 == true) { - max_type_id = kNumberTypeInt16; - } - // if bool is the max type, see if there is scalar input - // if so, it means that max is bool tensor, use scalar type instead. - // for example: Tensor([True, True]) * 2, expect result is Tensor([2, 2]) - if (max_type_id == kNumberTypeBool) { - if (has_scalar_int32) { - max_type_id = kNumberTypeInt32; - } - if (has_scalar_float32) { - max_type_id = kNumberTypeFloat32; - } - } - return max_type_id; -} - -// Get the largest type of index in the same SignatureEnumDType of arguments. -std::map GetMaxDtype(const std::vector &dtypes, - const abstract::AbstractBasePtrList &args_spec_list, - const std::set &write_indices) { - // record index for signature.dtypes of the same type - // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} - std::map> type_indices; - for (size_t i = 0; i < dtypes.size(); ++i) { - auto it = type_indices.find(dtypes[i]); - if (it == type_indices.end()) { - (void)type_indices.insert(std::make_pair(dtypes[i], std::vector{i})); - } else { - it->second.push_back(i); - } - } - std::map dst_type; - for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { - auto type = it->first; - auto indices = it->second; - // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. - if (indices.size() < 2) { - continue; - } - bool has_tensor = false; - for (const auto &index : indices) { - AbstractBasePtr arg_value = args_spec_list[index]; - if (arg_value->isa()) { - arg_value = arg_value->cast()->ref(); - } - if (arg_value->isa()) { - has_tensor = true; - break; - } - } - if (!has_tensor) { - (void)dst_type.insert(std::make_pair(type, kTypeUnknown)); - continue; - } - (void)dst_type.insert(std::make_pair(type, GetMaxTypeId(args_spec_list, indices, write_indices))); - } - return dst_type; -} - -AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGraphPtr &graph) { - auto prim_cast_class = prim::GetPythonOps("Cast", "mindspore.ops.operations"); - MS_EXCEPTION_IF_NULL(prim_cast_class); - auto dtype_node = NewValueNode(TypeIdToType(type_id)); - auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph); - return NewCNode({cast_node, param, dtype_node}, graph); -} - -void DoAutoCast(const std::string &func_name, const std::vector &signature, - const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, - std::vector *const op_inputs, const std::set &write_indices) { - std::vector dtypes; - (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature &sig) { return sig.dtype; }); - int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); - if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { - return; - } - // Stat the index of the arguments with the largest type in the same SignatureEnumDType. 
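// Restating the promotion rule GetMaxTypeId implements above, as a
// self-contained sketch. The enum and ranks below mirror `type_map`; they are
// illustrative stand-ins, not the real MindSpore TypeId values:
#include <map>
#include <vector>

enum TypeIdLite { kBool = 0, kInt8, kUInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64 };

TypeIdLite PromoteTensorTypes(const std::vector<TypeIdLite> &tensor_types, bool has_scalar_int32,
                              bool has_scalar_float32) {
  static const std::map<TypeIdLite, int> rank = {{kBool, 1},    {kInt8, 2},    {kUInt8, 3},
                                                 {kInt16, 4},   {kInt32, 5},   {kInt64, 6},
                                                 {kFloat16, 7}, {kFloat32, 8}, {kFloat64, 9}};
  TypeIdLite max_type = kBool;
  int max_rank = 0;
  bool has_int8 = false;
  for (TypeIdLite t : tensor_types) {
    if (t == kInt8) {
      has_int8 = true;
    }
    if (rank.at(t) > max_rank) {
      max_rank = rank.at(t);
      max_type = t;
    }
  }
  // uint8 mixed with int8 widens to int16, as in the deleted code.
  if (max_type == kUInt8 && has_int8) {
    max_type = kInt16;
  }
  // A bool tensor mixed with a number scalar adopts the scalar type,
  // e.g. Tensor([True, True]) * 2 -> Tensor([2, 2]).
  if (max_type == kBool) {
    if (has_scalar_int32) {
      max_type = kInt32;
    }
    if (has_scalar_float32) {
      max_type = kFloat32;
    }
  }
  return max_type;
}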
- std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); - // Identify which arg requires auto cast - for (size_t i = 0; i < args_spec_list.size(); ++i) { - auto it = dst_type.find(dtypes[i]); - if (it == dst_type.end() || it->second == kTypeUnknown) { - continue; - } - auto rw_it = write_indices.find(i); - auto is_write = (rw_it != write_indices.end()); - - TypeId arg_type_id = kTypeUnknown; - AbstractBasePtr arg_value = args_spec_list[i]; - (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); - auto it_map = type_name_map.find(arg_type_id); - if (it_map == type_name_map.end()) { - continue; - } - if (is_write) { - if (arg_type_id != it->second) { - auto it_name_map = type_name_map.find(it->second); - if (it_name_map == type_name_map.end()) { - continue; - } - RaiseExceptionForConvertRefDtype(func_name, it_map->second, it_name_map->second); - } - continue; - } - if (arg_value->isa() && arg_type_id == it->second) { - continue; - } - (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], it->second, graph); - } -} - -AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const std::vector ¶ms_list) { - // args: original inputs - auto &signature = GetSignature(function); - std::size_t sig_size = signature.size(); - auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); - if (sig_size > 0) { - if (has_var) { - if (sig_size - 1 > args_spec_list.size()) { - MS_LOG(EXCEPTION) << "Function " << func_name - << "'s input length less than PositionalKeyword Signature length."; - } - } else if (args_spec_list.size() > sig_size) { - MS_LOG(EXCEPTION) << "Function " << func_name << "'s input length is not equal to Signature length."; - } - } - std::vector op_inputs; - std::set write_indices; - op_inputs.push_back(NewValueNode(function)); - // Assume, the write input of op is always the first input. We check if any write op, - // and add cast op on other inputs to keep the same type with assigned parameter. - for (size_t i = 0; i < args_spec_list.size(); ++i) { - AnfNodePtr param = params_list[i]; - if (args_spec_list[i] == nullptr) { - op_inputs.push_back(param); - continue; - } - SignatureEnumRW sig = SignatureEnumRW::kRWDefault; - // If sig_size is 0 use defalut. - if (sig_size > 0 && i < sig_size) { - sig = signature[i].rw; - } else if (has_var && i >= sig_size) { - sig = signature[sig_size - 1].rw; - } - - TypePtr type = args_spec_list[i]->GetTypeTrack(); - if (type && type->type_id() == kObjectTypeRef) { - if (sig == SignatureEnumRW::kRWRead) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); - } else if (sig == SignatureEnumRW::kRWWrite) { - param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); - write_indices.insert(i); - } - // If sig is SignatureEnumRW::kRWRef, not do anything. 
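// The per-argument decision DoAutoCast makes above, reduced to a
// self-contained helper. ArgInfo and the int type ids are hypothetical
// stand-ins for the abstract values and TypeId used in the deleted code:
#include <optional>
#include <stdexcept>

struct ArgInfo {
  int type_id;
  bool is_tensor;
};

// Returns the type to cast the argument to, or nullopt to leave it as is.
std::optional<int> CastDecision(const ArgInfo &arg, int target_type, bool is_write) {
  if (is_write) {
    if (arg.type_id != target_type) {
      // Mirrors RaiseExceptionForConvertRefDtype: a writable Parameter must
      // already carry the promoted type; it is never cast implicitly.
      throw std::runtime_error("cannot implicitly convert a writable argument");
    }
    return std::nullopt;
  }
  if (arg.is_tensor && arg.type_id == target_type) {
    return std::nullopt;  // tensor already has the promoted type
  }
  return target_type;  // a Cast to target_type is inserted before the op
}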
- } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { - MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter."; - } - op_inputs.push_back(param); - } - // process default - ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); - DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); - return func_graph->NewCNode(op_inputs); -} -} // namespace - -AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { - auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); - return new_cnode; -} - -FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - FuncGraphPtr func_graph = std::make_shared(); - - for (size_t i = 0; i < args_spec_list.size(); ++i) { - (void)func_graph->add_parameter(); - } - auto new_cnode = BuildNewCNode(func_graph, name_, function_, args_spec_list, func_graph->parameters()); - func_graph->set_output(new_cnode); - func_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - return func_graph; -} - -void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type, - const std::string &target_type) { - MS_LOG(EXCEPTION) << "In op '" << func_name << "', \n" - << "the type of writable argument is '" << ref_type << "', " - << "but the largest type in the same SignatureEumDtype is '" << target_type - << "'. The writable arg type is not equal to the largest type, " - << "so can not cast automatically."; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h deleted file mode 100644 index 97f6d7e7a5..0000000000 --- a/mindspore/ccsrc/operator/composite/do_signature.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -class DoSignatureMetaFuncGraph : public MetaFuncGraph { - public: - explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) - : MetaFuncGraph("S-" + name), function_(function) {} - - ~DoSignatureMetaFuncGraph() override = default; - - MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) - - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; - const ValuePtr function() const { return function_; } - - friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { - return &lhs == &rhs; - } - - private: - ValuePtr function_; -}; -using RWSignaturePtr = std::shared_ptr; - -extern const std::map type_map; - -void RaiseExceptionForConvertRefDtype(const std::string &func_name, const std::string &ref_type, - const std::string &target_type); - -AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, - const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_DO_SIGNATURE_H_ diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.cc b/mindspore/ccsrc/operator/composite/list_append_operation.cc deleted file mode 100644 index 076ae5d41b..0000000000 --- a/mindspore/ccsrc/operator/composite/list_append_operation.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/composite/list_append_operation.h" - -#include -#include -#include - -#include "abstract/param_validator.h" -#include "optimizer/opt.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { - abstract::CheckArgsSize("ListAppend", args_list, 2); - - AbstractBasePtr arg0 = args_list[0]; - abstract::AbstractListPtr arg0_list = dyn_cast(arg0); - MS_EXCEPTION_IF_NULL(arg0_list); - - FuncGraphPtr ret = std::make_shared(); - ret->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ret->debug_info()->set_name("append"); - AnfNodePtr arg0_node = ret->add_parameter(); - - std::vector elems; - elems.push_back(NewValueNode(prim::kPrimMakeList)); - size_t arg0_length = arg0_list->size(); - for (size_t i = 0; i < arg0_length; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToInt(i))})); - } - AnfNodePtr arg1_node = ret->add_parameter(); - elems.push_back(arg1_node); - - ret->set_output(ret->NewCNode(elems)); - return ret; -} - -REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { - (void)py::class_>(*m, "ListAppend_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/map.cc b/mindspore/ccsrc/operator/composite/map.cc deleted file mode 100644 index eb8b4b6df1..0000000000 --- a/mindspore/ccsrc/operator/composite/map.cc +++ /dev/null @@ -1,292 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/composite/map.h" -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "abstract/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "abstract/dshape.h" -#include "pybind_api/api_register.h" -#include "debug/trace.h" -#include "operator/ops.h" -#include "./common.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure; - -AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args) { - MS_LOG(DEBUG) << "Map FullMakeLeaf non recursive.\n"; - MS_EXCEPTION_IF_NULL(func_graph); - std::vector inputs; - if (fn_arg != nullptr) { - inputs.emplace_back(fn_arg); - } else { - inputs.emplace_back(NewValueNode(fn_leaf_)); - } - inputs.insert(inputs.end(), args.begin(), args.end()); - return func_graph->NewCNode(inputs); -} - -FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { - // Generate func for leaf nodes - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("map"); - AnfNodePtr ptrFnArg = nullptr; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - } - AnfNodePtrList args; - for (size_t i = 0; i < args_size; ++i) { - args.emplace_back(ptrGraph->add_parameter()); - } - ptrGraph->set_output(FullMakeLeaf(ptrGraph, ptrFnArg, args)); - return ptrGraph; -} - -AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { - auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "List in Map should have same length"; - } - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeList)); - - for (int i = 0; i < SizeToInt(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the target"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(type); - - std::size_t size = type->elements().size(); - bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [size](const std::pair &item) { - auto lhs = std::dynamic_pointer_cast(item.second); - MS_EXCEPTION_IF_NULL(lhs); - return lhs->elements().size() != size; - }); - if (is_not_same) { - MS_LOG(EXCEPTION) << "tuple in Map should have same length"; - } - - 
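// What the tuple expansion a little further below computes, written eagerly
// in plain C++: apply the leaf function to the i-th element of every argument
// and rebuild the container. A sketch only — Elem and the std containers are
// illustrative stand-ins for graph nodes:
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

using Elem = double;

std::vector<Elem> MapTuple(const std::function<Elem(const std::vector<Elem> &)> &fn,
                           const std::vector<std::vector<Elem>> &args) {
  if (args.empty()) {
    throw std::invalid_argument("map() must have at least one container argument");
  }
  const std::size_t size = args.front().size();
  for (const auto &a : args) {
    if (a.size() != size) {
      throw std::invalid_argument("tuple in Map should have same length");
    }
  }
  std::vector<Elem> out;
  out.reserve(size);
  for (std::size_t i = 0; i < size; ++i) {
    std::vector<Elem> ith;
    ith.reserve(args.size());
    for (const auto &a : args) {
      ith.push_back(a[i]);  // corresponds to TupleGetItem(arg, i)
    }
    out.push_back(fn(ith));  // one leaf-graph call per element
  }
  return out;
}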
std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - for (int i = 0; i < SizeToInt(size); ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th arg of the tuple inputs"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - (void)std::transform( - arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), - [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); - }); - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, - const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - MS_EXCEPTION_IF_NULL(type); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector inputs; - inputs.push_back(NewValueNode(prim::kPrimMakeRecord)); - inputs.push_back(NewValueNode(type)); - - std::size_t attrSize = type->GetAttributes().size(); - for (std::size_t i = 0; i < attrSize; ++i) { - MS_LOG(DEBUG) << "GenerateLeafFunc for the " << i << "th element of the inputs"; - auto ptrGraph = GenerateLeafFunc(arg_pairs.size()); - auto fn = NewValueNode(ptrGraph); - - std::vector inputs2; - inputs2.push_back(fn); - if (fn_arg != nullptr) { - inputs2.push_back(fn_arg); - } - - int j = 0; - for (auto item : arg_pairs) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); - j++; - } - - inputs.push_back(func_graph->NewCNode(inputs2)); - } - return func_graph->NewCNode(inputs); -} - -AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { - if (arg_pairs.empty()) { - MS_EXCEPTION(TypeError) << "map() must have at least two arguments"; - } - bool found = false; - TypeId id = kObjectTypeEnd; - std::pair pair; - for (auto &item : arg_pairs) { - pair = item; - MS_LOG(DEBUG) << "Map " << pair.second->ToString(); - id = item.second->type_id(); - if (nonleaf_.count(id)) { - found = true; - break; - } - } - - if (found) { - // In a nonleaf situation, all arguments must have the same generic. 
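// Skeleton of the dispatch Map::Make performs right after this consistency
// check (illustrative enum; the real code switches on the MindSpore TypeId
// of the first nonleaf argument):
#include <stdexcept>

enum class ContainerKind { kList, kTuple, kClass, kOther };

template <typename ListFn, typename TupleFn, typename ClassFn>
auto DispatchMake(ContainerKind kind, ListFn on_list, TupleFn on_tuple, ClassFn on_class) {
  switch (kind) {
    case ContainerKind::kList:
      return on_list();  // -> FullMakeList
    case ContainerKind::kTuple:
      return on_tuple();  // -> FullMakeTuple
    case ContainerKind::kClass:
      return on_class();  // -> FullMakeClass
    default:
      throw std::invalid_argument("Map can only be applied to list, tuple and class");
  }
}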
- bool is_not_same = - std::any_of(arg_pairs.begin(), arg_pairs.end(), [pair](const std::pair &item) { - if (item.first != pair.first) { - return item.second->type_id() != pair.second->type_id(); - } - return false; - }); - if (is_not_same) { - std::ostringstream oss; - oss << "There are " << arg_pairs.size() << " inputs of `" << name_ << "`, corresponding type info:\n" - << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - int idx = 0; - for (auto &item : arg_pairs) { - oss << ++idx << ": " << item.second->ToString() << "\n"; - } - MS_LOG(EXCEPTION) << "Map cannot match up all input types of arguments.\n" - << oss.str() << pair.second->ToString() << "\n"; - } - } - - switch (id) { - case kObjectTypeList: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeList(type, func_graph, fn_arg, arg_pairs); - } - case kObjectTypeTuple: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeTuple(type, func_graph, fn_arg, arg_pairs); - } - case kObjectTypeClass: { - auto type = std::static_pointer_cast(pair.second); - return FullMakeClass(type, func_graph, fn_arg, arg_pairs); - } - default: - MS_LOG(EXCEPTION) << "Map can only be applied to list, tuple and class " - << ", but got " << pair.second->ToString(); - } -} - -FuncGraphPtr Map::GenerateFromTypes(const TypePtrList &args_spec_list) { - FuncGraphPtr ptrGraph = std::make_shared(); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); - ptrGraph->debug_info()->set_name("map"); - - AnfNodePtr ptrFnArg = nullptr; - std::size_t i = 0; - if (fn_leaf_ == nullptr) { - ptrFnArg = ptrGraph->add_parameter(); - i = 1; - } - ArgsPairList arg_pairs; - std::size_t size = args_spec_list.size(); - for (; i < size; ++i) { - MS_LOG(DEBUG) << "GenerateFromTypes for elements from " << args_spec_list[i]->ToString(); - arg_pairs.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i])); - } - - ptrGraph->set_output(Make(ptrGraph, ptrFnArg, arg_pairs)); - return ptrGraph; -} - -abstract::AbstractBasePtrList Map::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - if (fn_leaf_ == nullptr) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - // Assert that map's function param does not contain free variables - if (args_spec_list[0]->isa()) { - auto graph_func = dyn_cast(args_spec_list[0]); - auto func_graph = graph_func->func_graph(); - if (func_graph->parent() != nullptr) { - MS_LOG(EXCEPTION) << "Map don't support Closure with free variable yet."; - } - } - } - - AbstractBasePtrList broadened; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - return broadened; -} - -REGISTER_PYBIND_DEFINE(Map_, ([](const py::module *m) { - (void)py::class_>(*m, "Map_") - .def(py::init>(), py::arg("leaf")) - .def(py::init<>()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/map.h b/mindspore/ccsrc/operator/composite/map.h deleted file mode 100644 index 02d374214a..0000000000 --- a/mindspore/ccsrc/operator/composite/map.h +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ - -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "operator/composite/multitype_funcgraph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using ArgsPairList = std::vector>; - -class Map : public MetaFuncGraph { - public: - explicit Map(const std::shared_ptr &fn_leaf = nullptr) - : MetaFuncGraph("map"), - fn_leaf_(fn_leaf), - broadcast_(false), - nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { - Init(); - } - Map(const Map &h) : MetaFuncGraph("map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { - Init(); - } - Map &operator=(const Map &h) { - if (this != &h) { - fn_leaf_ = h.fn_leaf_; - broadcast_ = h.broadcast_; - nonleaf_ = h.nonleaf_; - if (fn_leaf_) { - name_ = "map[" + fn_leaf_->name() + "]"; - } - } - return *this; - } - ~Map() override = default; - MS_DECLARE_PARENT(Map, MetaFuncGraph) - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; - MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } - - private: - FuncGraphPtr GenerateLeafFunc(const size_t &args_size); - AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); - AnfNodePtr FullMakeList(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, - const ArgsPairList &arg_pairs); - AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); - void Init() { - if (fn_leaf_ != nullptr) { - name_ = "map[" + fn_leaf_->name() + "]"; - } - signatures_ = - // def map(func:read, *args:ref): - std::vector({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); - } - - MultitypeFuncGraphPtr fn_leaf_; - bool broadcast_; - std::set nonleaf_; -}; -using MapPtr = std::shared_ptr; -class MapPy : public Map { - public: - explicit MapPy(const std::shared_ptr &fn_leaf = nullptr) : Map(fn_leaf) {} - ~MapPy() override = default; - MS_DECLARE_PARENT(MapPy, Map) -}; -using MapPyPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MAP_H_ diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc deleted file mode 100644 index bc51bb6395..0000000000 --- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc +++ /dev/null @@ -1,198 +0,0 @@ - -/** - * This is the C++ adaptation and derivative work of 
Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/composite/multitype_funcgraph.h" -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "abstract/abstract_value.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "abstract/dshape.h" -#include "abstract/param_validator.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "utils/context/ms_context.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "./common.h" -#include "ir/signature.h" -#include "debug/trace.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { - fn_cache_.clear(); - signatures_ = std::vector({// def multitype(*args:ref): - {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); -} - -void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { - MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; - auto fn = fn_cache_.find(types); - if (fn != fn_cache_.end()) { - MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; - } - fn_cache_[types] = s_fn; -} - -void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { - MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; - auto fn = fn_cache_.find(types); - if (fn != fn_cache_.end()) { - MS_LOG(EXCEPTION) << "Cannot register as (" << ::mindspore::ToString(types) << ", already registered."; - } - fn_cache_py_[types] = py_fn; -} - -void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { - TypePtrList types; - for (auto &type_name : types_name) { - auto type_ptr = StringToType(type_name); - if (type_ptr == nullptr) { - MS_LOG(EXCEPTION) << type_name << " convert from string error "; - } - types.push_back(type_ptr); - } - Register(types, py_fn); -} - -void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { - std::vector types_name; - for (size_t it = 0; it < tuple.size(); ++it) { - py::object name_py = tuple[it]; - if (py::isinstance(name_py)) { - types_name.push_back(name_py.cast()); - continue; - } - MS_LOG(EXCEPTION) << "Register must be string"; - } - Register(types_name, py_fn); -} -static TypePtr UnwrapRef(const TypePtr &type) { - if (type->isa()) { - return type->cast()->subtype(); - } - return type; -} - -// Return Exact match if exists, else return non ambiguous sub class match -// Return py::none() if matching is ambiguous -const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) { - // Exact match - for (auto &item : fn_cache_py_) { - TypePtrList sign = item.first; - if (sign.size() != types.size()) { - continue; 
- } - auto match = true; - for (size_t i = 0; i < sign.size(); ++i) { - if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) { - match = false; - break; - } - } - if (!match) { - continue; - } - return item.second; - } - return py::none(); -} - -FuncGraphPtr GenerateStubFunc(const TypePtrList &types) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); - if (!enable_sparse) { - return nullptr; - } - - std::vector parameters; - ParameterPtr undetermined_param = nullptr; - auto stub = std::make_shared(); - for (size_t i = 0; i < types.size(); ++i) { - auto param = stub->add_parameter(); - parameters.push_back(param); - if (types[i]->type_id() == kObjectTypeUndeterminedType) { - undetermined_param = param; - } - } - if (undetermined_param != nullptr) { - std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; - for (size_t i = 0; i < types.size(); ++i) { - if (types[i]->type_id() == kObjectTypeFunction) { - std::vector call_prim{parameters[i], undetermined_param}; - inputs.push_back(stub->NewCNode(call_prim)); - } else { - inputs.push_back(parameters[i]); - } - } - auto stub_output = stub->NewCNode(inputs); - stub->set_output(stub_output); - stub->set_stub(true); - return stub; - } - return nullptr; -} - -FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { - auto py_fn = SignMatch(types); - std::ostringstream buffer; - buffer << types; - if (py_fn != py::none()) { - FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn); - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str(); - } - MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString(); - return func_graph; - } - auto stub = GenerateStubFunc(types); - if (stub != nullptr) { - MS_LOG(DEBUG) << "GenerateStubFunc " << buffer.str() << ", function: " << stub->ToString(); - return stub; - } - std::ostringstream oss; - oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ - << "`, corresponding location info:\n"; - int idx = 0; - for (auto &item : fn_cache_py_) { - FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); - if (func_graph == nullptr) { - MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; - continue; - } - oss << ++idx << ". " << item.first << "\n " << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; - } - MS_LOG(EXCEPTION) << "The '" << name_ << "' operation does not support the type " << buffer.str() << "\n" - << oss.str(); -} - -REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { - (void)py::class_>( - *m, "MultitypeFuncGraph_") - .def(py::init()) - .def("register_fn", &MultitypeFuncGraph::PyRegister); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/operator/composite/multitype_funcgraph.h deleted file mode 100644 index ababf21883..0000000000 --- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
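The deleted multitype_funcgraph.cc above is an overload table keyed on argument types: Register() fills fn_cache_py_, and SignMatch() scans it for an exact signature match before falling back to a non-ambiguous subclass match. A minimal standalone sketch of the exact-match half, with vectors of type names standing in for MindSpore's real type lattice (all names here are illustrative):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Stand-in for fn_cache_py_: overloads keyed by argument-type vectors.
    using TypeKey = std::vector<std::string>;
    using Overload = std::function<int(int, int)>;

    class MiniMultitype {
     public:
      void Register(const TypeKey &types, Overload fn) { cache_[types] = std::move(fn); }
      // Exact match only; the real SignMatch also accepts a unique subclass match.
      const Overload *Match(const TypeKey &types) const {
        auto it = cache_.find(types);
        return it == cache_.end() ? nullptr : &it->second;
      }

     private:
      std::map<TypeKey, Overload> cache_;
    };

    int main() {
      MiniMultitype add;
      add.Register({"Int", "Int"}, [](int a, int b) { return a + b; });
      if (const auto *fn = add.Match({"Int", "Int"})) std::cout << (*fn)(2, 3) << "\n";  // 5
      if (!add.Match({"Int", "Float"})) std::cout << "no overload registered\n";
    }

In the real class the matched entry is a py::function that GenerateFromTypes() parses into a FuncGraph; the sketch simply invokes the callable.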
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_MULTITYPE_FUNCGRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -class MultitypeFuncGraph : public MetaFuncGraph { - public: - explicit MultitypeFuncGraph(const std::string &name); - ~MultitypeFuncGraph() override = default; - MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) - - using specialize_fn = FuncGraph *(*)(TypePtrList); - // Register a method which specialize based on types vectors; - virtual void Register(const TypePtrList &types, specialize_fn s_fn); - virtual void Register(const TypePtrList &types, const py::function &py_fn); - virtual void Register(const std::vector &types_name, const py::function &py_fn); - virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); - - FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; - size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } - const std::unordered_map &GetPyFunctions() const { - return fn_cache_py_; - } - - private: - const py::function SignMatch(const TypePtrList &types); - std::unordered_map fn_cache_; - std::unordered_map fn_cache_py_; -}; -using MultitypeFuncGraphPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_ diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc deleted file mode 100644 index 96298c9250..0000000000 --- a/mindspore/ccsrc/operator/composite/unpack_call.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/composite/unpack_call.h" -#include -#include - -#include "./common.h" -#include "abstract/abstract_value.h" -#include "abstract/dshape.h" -#include "abstract/param_validator.h" -#include "operator/cc_implementations.h" -#include "ir/anf.h" -#include "optimizer/opt.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractDictionaryPtr; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractKeywordArg; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; - -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // slice a tensor - // args: tensor, slice or slice tuple - const std::string op_name = std::string("UnpackCall"); - size_t arg_length = args_spec_list.size(); - if (arg_length < 2) { - MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; - } - - (void)abstract::CheckArg(op_name, args_spec_list, 0); - auto ret_graph = std::make_shared(); - ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - - AnfNodePtr fnNode = ret_graph->add_parameter(); - std::vector elems; - elems.push_back(fnNode); - for (size_t index = 1; index < arg_length; index++) { - MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (args_spec_list[index]->isa()) { - auto arg_tuple = args_spec_list[index]->cast(); - AnfNodePtr para_tuple = ret_graph->add_parameter(); - for (size_t i = 0; i < arg_tuple->size(); ++i) { - elems.push_back( - ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); - } - } else if (args_spec_list[index]->isa()) { - AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast(); - AnfNodePtr para_dict = ret_graph->add_parameter(); - auto dict_elems = arg_dict->elements(); - (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute &item) { - auto dict_get_item = ret_graph->NewCNode( - {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); - return ret_graph->NewCNode( - {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); - }); - } else { - MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " - << args_spec_list[index]->ToString(); - } - } - ret_graph->set_output(ret_graph->NewCNode(elems)); - return ret_graph; -} - -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { - (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h deleted file mode 100644 index 8c055a9386..0000000000 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
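unpack_call.cc above lowers f(*args, **kwargs) at graph level: each tuple parameter is expanded element-wise with tuple_getitem into positional arguments, and each dict parameter becomes a run of make_keyword_arg nodes. A rough value-level analogue of that flattening, with hypothetical names and no graph nodes:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Flatten f(*args, **kwargs): positional values in order, then each dict
    // entry as a keyword argument (printed here; MakeKeywordArg in the graph).
    void CallWithUnpack(const std::vector<int> &args, const std::map<std::string, int> &kwargs) {
      std::size_t i = 0;
      for (int v : args) std::cout << "arg" << i++ << " = " << v << "\n";
      for (const auto &kv : kwargs) std::cout << kv.first << " = " << kv.second << "\n";
    }

    int main() { CallWithUnpack({1, 2}, {{"axis", 0}}); }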
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" -#include "common/utils.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -// Expand the tuple and dict parameters generated when parsing the function call, -// and generate positional parameters and key-value pairs for function. -class UnpackCall : public MetaFuncGraph { - public: - explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} - ~UnpackCall() override = default; - MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } -}; -using UnpackCallPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/operator/composite/zip_operation.cc b/mindspore/ccsrc/operator/composite/zip_operation.cc deleted file mode 100644 index 89118c7b3b..0000000000 --- a/mindspore/ccsrc/operator/composite/zip_operation.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/composite/zip_operation.h" -#include - -#include "abstract/abstract_value.h" -#include "ir/anf.h" -#include "abstract/dshape.h" -#include "operator/cc_implementations.h" -#include "optimizer/opt.h" -#include "pybind_api/api_register.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractSequeue; -using mindspore::abstract::AbstractSequeuePtr; -using mindspore::abstract::AbstractTuple; - -FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { - // zip operation: - // input: tuple arguments - // output: tuple of items of input iterated on every input - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "For 'zip', there is at least one input."; - } - - auto is_all_sequeue = - std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { - MS_EXCEPTION_IF_NULL(abs); - return abs->isa(); - }); - if (!is_all_sequeue) { - MS_LOG(EXCEPTION) << "For 'zip', all inputs must be sequence."; - } - - auto min_abs = std::min_element( - args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &x, const AbstractBasePtr &y) { - return (x->cast()->size() < y->cast()->size()); - }); - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); - for (size_t idx = 0; idx < args_spec_list.size(); idx++) { - (void)ret_graph->add_parameter(); - } - - // generate tuple output of ziped arguments input - std::vector make_tuple_nodes; - make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t idx = 0; idx < (*min_abs)->cast()->size(); idx++) { - std::vector make_tuple_zip_nodes; - make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl"; - ValuePtr op = prim::GetPythonOps("getitem", module_name); - for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); arg_idx++) { - std::vector tuple_get_item_nodes{NewValueNode(op), ret_graph->parameters()[arg_idx], - NewValueNode(SizeToInt(idx))}; - auto tuple_get_item_op = ret_graph->NewCNode(tuple_get_item_nodes); - make_tuple_zip_nodes.push_back(tuple_get_item_op); - } - auto make_tuple_zip_op = ret_graph->NewCNode(make_tuple_zip_nodes); - make_tuple_nodes.push_back(make_tuple_zip_op); - } - ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes)); - return ret_graph; -} - -REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { - (void)py::class_>(*m, - "ZipOperation_") - .def(py::init()); - })); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/zip_operation.h b/mindspore/ccsrc/operator/composite/zip_operation.h deleted file mode 100644 index 1a3fa1f5fe..0000000000 --- a/mindspore/ccsrc/operator/composite/zip_operation.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
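zip_operation.cc sizes its output by the shortest input sequence ((*min_abs)->size()) and emits one inner make_tuple per index, each gathering element idx from every argument. The same behaviour at value level, as a minimal sketch:

    #include <algorithm>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Zip two sequences, truncating to the shorter one, exactly as
    // GenerateFuncGraph sizes its output tuple by the smallest input.
    std::vector<std::pair<int, char>> Zip(const std::vector<int> &xs, const std::vector<char> &ys) {
      std::size_t n = std::min(xs.size(), ys.size());
      std::vector<std::pair<int, char>> out;
      out.reserve(n);
      for (std::size_t i = 0; i < n; ++i) out.emplace_back(xs[i], ys[i]);
      return out;
    }

    int main() {
      for (const auto &[x, y] : Zip({1, 2, 3}, {'a', 'b'})) std::cout << x << y << " ";  // 1a 2b
    }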
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ -#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/misc.h" -#include "utils/any.h" -#include "ir/dtype.h" -#include "ir/meta_func_graph.h" - -namespace mindspore { -// namespace to support composite operators definition -namespace prim { -using AbstractBasePtr = abstract::AbstractBasePtr; -using AbstractBasePtrList = abstract::AbstractBasePtrList; -using AbstractTuplePtr = abstract::AbstractTuplePtr; - -class ZipOperation : public MetaFuncGraph { - public: - explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} - ~ZipOperation() override = default; - MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; - friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { - os << op.name_; - return os; - } - friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } -}; -using ZipOperationPtr = std::shared_ptr; -} // namespace prim -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_ZIP_OPERATION_H_ diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc deleted file mode 100755 index b682847ed7..0000000000 --- a/mindspore/ccsrc/operator/ops.cc +++ /dev/null @@ -1,288 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "operator/ops.h" -#include -#include - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -// Arithmetic -const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); -const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); -const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); -const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); -const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); -const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); -const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); -const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); -const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); -const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); -const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); -const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); -const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); -const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); -const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); -const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); - -// Comparisons -const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); -const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); -const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); -const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); -const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); -const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); -const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); -const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); -const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); -const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); -const PrimitivePtr kPrimGreater = std::make_shared("Greater"); -const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); -const PrimitivePtr kPrimLess = std::make_shared("Less"); -const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); -const PrimitivePtr kPrimEqual = std::make_shared("Equal"); -const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); - -// Type introspection -const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); -const PrimitivePtr kPrimHasType = std::make_shared("hastype"); - -// Statements -const PrimitivePtr kPrimSwitch = std::make_shared("switch"); -const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); -const PrimitivePtr kPrimReturn = std::make_shared("return"); -const PrimitivePtr kPrimAssign = std::make_shared("Assign"); -const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); -const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); -const PrimitivePtr kPrimSelect = std::make_shared("Select"); -const PrimitivePtr kPrimCall = std::make_shared("call"); - -const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); -const PrimitivePtr kPrimDot = std::make_shared("dot"); -const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); -const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); -const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); -const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); - -const PrimitivePtr kPrimResolve = std::make_shared("resolve"); -const PrimitivePtr kPrimEmbed = std::make_shared("embed"); -const 
PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); -const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); - -const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); -const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); -const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); - -// Structure -const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); -const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); -const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); -const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); -const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); -const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); -const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); -const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); -const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); -const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); -const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); -const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); -const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); -const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); -const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); -const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); -const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); -const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); -const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); -const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); -const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); -const PrimitivePtr kPrimListLen = std::make_shared("list_len"); -const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); -const PrimitivePtr kPrimListMap = std::make_shared("list_map"); -const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); -const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); - -const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); -const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); -const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); -const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); -const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); -const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); -const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); -const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); -const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); -const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); -const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); - -// Arrays -const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); -const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); -const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); -const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); -const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); -const PrimitivePtr kPrimShape = std::make_shared("Shape"); -const PrimitivePtr kPrimCast = std::make_shared("Cast"); 
-const PrimitivePtr kPrimConcat = std::make_shared("Concat"); -const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); -const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); -const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); -const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); -const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); -const PrimitivePtr kPrimSize = std::make_shared("Size"); -const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); -const PrimitivePtr kPrimPack = std::make_shared("Pack"); -const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); -const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); -const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); -const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); -const PrimitivePtr kPrimTile = std::make_shared("Tile"); -const PrimitivePtr kPrimAddN = std::make_shared("AddN"); -const PrimitivePtr KPrimTransData = std::make_shared("TransData"); -const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); -const PrimitivePtr kPrimPad = std::make_shared("Pad"); -const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); - -// Maths -const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); -const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); -const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); -const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); -const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); -const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); -const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); -const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); -const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); -const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); -const PrimitivePtr kPrimNeg = std::make_shared("Neg"); -const PrimitivePtr kPrimSub = std::make_shared("Sub"); -const PrimitivePtr kPrimMul = std::make_shared("Mul"); -const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); -const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); -const PrimitivePtr kPrimSquare = std::make_shared("Square"); -const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); -const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); -const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); -const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); -const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); -const PrimitivePtr kPrimPow = std::make_shared("Pow"); -const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); -const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); -const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); -const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); - -// NN -const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); -const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); -const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); -const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); -const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); -const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); -const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); -const PrimitivePtr kPrimPoolingGrad = 
std::make_shared("PoolingGrad"); -const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); -const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); -const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); -const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); -const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); -const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); -const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); -const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); -const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); -const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); -const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); -const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = - std::make_shared("DepthwiseConv2dNativeBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = - std::make_shared("DepthwiseConv2dNativeBackpropInput"); -const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); -const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared("SoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = - std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); -const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); -const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); -const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); -const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); -const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); -const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); -const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); -const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); -const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); -const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); -const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); -const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); -const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); -const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); -const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); -const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); -const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); -const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); - -// Other miscellaneous -const PrimitivePtr kPrimIdentity = std::make_shared("identity"); -const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -const PrimitivePtr kPrimJ = std::make_shared("J"); -const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); -const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); -const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); -const PrimitivePtr kPrimGetRefValue 
= std::make_shared("get_ref_value"); -const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); -const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); -const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); -const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); -const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); -const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -const PrimitivePtr kPrimPrint = std::make_shared("Print"); - -const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); -const PrimitivePtr kPrimDepend = std::make_shared("Depend"); -const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); - -const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); -const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); -const PrimitivePtr kPrimIs_ = std::make_shared("is_"); -const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); -const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); -const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); -const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); -const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); -const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); - -// Comm ops -const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); -const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); -const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); -const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); - -// Debug ops -const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); -const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); -const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); -const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -const PrimitivePtr kPrimDebug = std::make_shared("Debug"); - -// IndexedSlices -const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); -const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); -const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); -const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); -const PrimitivePtr kPrimIsIndexedSlices = std::make_shared("IsIndexedSlices"); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/ops_extends.cc b/mindspore/ccsrc/operator/ops_extends.cc deleted file mode 100755 index d415b45adf..0000000000 --- a/mindspore/ccsrc/operator/ops_extends.cc +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
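A reading note for the ops.cc hunk above: the flattening has eaten the template arguments, so every constant actually reads std::make_shared<Primitive>("...") in the source. The design is a table of shared, named singletons that optimizer passes compare by pointer identity; a self-contained miniature (Primitive here is a stand-in struct, not the real class):

    #include <iostream>
    #include <memory>
    #include <string>

    // Toy version of mindspore::Primitive: a named object shared via
    // shared_ptr so graph nodes can compare operators by pointer identity.
    struct Primitive {
      explicit Primitive(std::string name) : name_(std::move(name)) {}
      const std::string &name() const { return name_; }
      std::string name_;
    };
    using PrimitivePtr = std::shared_ptr<Primitive>;

    // ops.cc defines one such constant per operator, e.g.:
    const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");

    int main() {
      std::cout << kPrimScalarAdd->name() << "\n";              // scalar_add
      std::cout << (kPrimScalarAdd == kPrimScalarAdd) << "\n";  // 1: identity, not name, comparison
    }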
- */
-
-#include "operator/ops.h"
-#include <memory>
-#include <string>
-#include "pipeline/parse/python_adapter.h"
-#include "pipeline/parse/data_converter.h"
-
-namespace mindspore {
-// namespace to support primitive operators
-namespace prim {
-ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name, bool use_signature) {
-  py::object obj = parse::python_adapter::GetPyFn(module_name, op_name);
-  ValuePtr node = nullptr;
-  bool succ = parse::ConvertData(obj, &node, use_signature);
-  if (!succ) {
-    MS_LOG(EXCEPTION) << "get Python op " << op_name << " from " << module_name << " fail";
-  }
-  return node;
-}
-}  // namespace prim
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/prim_arrays.cc b/mindspore/ccsrc/operator/prim_arrays.cc
deleted file mode 100644
index 4e2e2ebd1f..0000000000
--- a/mindspore/ccsrc/operator/prim_arrays.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
-#include "abstract/utils.h"
-#include "operator/cc_implementations.h"
-#include "abstract/param_validator.h"
-
-namespace mindspore {
-namespace abstract {
-AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a scalar.
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  AbstractScalarPtr arg = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
-  return std::make_shared<AbstractTensor>(arg, std::make_shared<Shape>());
-}
-
-AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                       const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a tensor with 0 shape.
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  auto a_shp = arg->shape();
-  if (!a_shp->shape().empty()) {
-    MS_LOG(EXCEPTION) << "array_to_scalar requires zero size shape.";
-  }
-  return arg->element();
-}
-
-AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                        const AbstractBasePtrList &args_spec_list) {
-  // Inputs: two tuples.
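GetPythonOps above resolves an operator implemented in Python by (module, name) and converts it into a graph value. The lookup half in plain pybind11, as a hedged sketch (the embedded interpreter and the 'operator' module are illustrative; MindSpore routes this through its python_adapter and ConvertData instead):

    #include <pybind11/embed.h>
    #include <string>

    namespace py = pybind11;

    // Import a module and fetch a callable by name, mirroring GetPyFn's role.
    py::object GetPyCallable(const std::string &module_name, const std::string &op_name) {
      py::module_ mod = py::module_::import(module_name.c_str());
      return mod.attr(op_name.c_str());
    }

    int main() {
      py::scoped_interpreter guard{};  // this sketch owns its own interpreter
      py::object getitem = GetPyCallable("operator", "getitem");
      py::list lst;
      lst.append(10);
      lst.append(20);
      return getitem(lst, 1).cast<int>() == 20 ? 0 : 1;
    }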
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 2);
-  auto xs = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
-  auto ys = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-
-  auto value_tuple_x = xs->BuildValue()->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(value_tuple_x);
-  auto shp_tuple_x = value_tuple_x->value();
-  std::vector<int> shp_x;
-  (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x),
-                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
-
-  auto value_tuple_y = ys->BuildValue()->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(value_tuple_y);
-  auto shp_tuple_y = value_tuple_y->value();
-  std::vector<int> shp_y;
-  (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
-                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
-
-  std::vector<int> res = prim::BroadcastShape_(shp_x, shp_y);
-  if (res.empty()) {
-    MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
-                      << args_spec_list[1]->ToString();
-  }
-
-  AbstractBasePtrList elems;
-  (void)std::transform(res.begin(), res.end(), std::back_inserter(elems), [](int n) -> AbstractBasePtr {
-    return std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(n), kInt32);
-  });
-
-  return std::make_shared<AbstractTuple>(elems);
-}
-
-AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a tensor.
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString();
-
-  AbstractBasePtrList values;
-  auto shp = arg->shape();
-  for (int entry : shp->shape()) {
-    auto entry_v = MakeValue(entry);
-    values.push_back(std::make_shared<AbstractScalar>(entry_v, entry_v->type()));
-  }
-  return std::make_shared<AbstractTuple>(values);
-}
-
-AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a tensor and a tuple.
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 2);
-  auto arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  auto multiples = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
-
-  ShapePtr input_shape = arg->shape();
-  (void)CheckTensorDType(arg, {kInt16, kFloat16, kInt32, kFloat32}, "Input 0 of Tile should be %s");
-
-  auto mul_shp_value = multiples->BuildValue();
-  if (mul_shp_value->isa<AnyValue>()) {
-    MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString();
-  }
-
-  std::vector<int> mul_shp;
-  auto value_tuple_mul = mul_shp_value->cast<ValueTuplePtr>();
-  auto mul_shp_data = value_tuple_mul->value();
-  (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp),
-                       [](const ValuePtr &e) -> int { return GetValue<int>(e); });
-  if (input_shape->shape().size() != mul_shp_data.size()) {
-    MS_LOG(EXCEPTION) << "Tile requires input and multiples size equal, while the input size is "
-                      << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << ".";
-  }
-
-  std::vector<int> result_shp;
-  for (size_t i = 0; i < mul_shp_data.size(); ++i) {
-    result_shp.push_back(input_shape->shape()[i] * mul_shp[i]);
-  }
-  return std::make_shared<AbstractTensor>(arg->element(), std::make_shared<Shape>(result_shp));
-}
-
-AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                              const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a tuple of tensor.
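prim::BroadcastShape_, called in the hunk above, implements NumPy-style shape broadcasting: align the two shapes on the right, then every dimension pair must be equal or contain a 1. A minimal sketch of that rule, returning {} on failure just as the res.empty() check expects:

    #include <iostream>
    #include <vector>

    std::vector<int> BroadcastShape(std::vector<int> x, std::vector<int> y) {
      if (x.size() < y.size()) x.swap(y);
      y.insert(y.begin(), x.size() - y.size(), 1);  // left-pad the shorter shape with 1s
      std::vector<int> out(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        if (x[i] == y[i] || y[i] == 1) out[i] = x[i];
        else if (x[i] == 1) out[i] = y[i];
        else return {};  // incompatible: neither equal nor bridgeable by a 1
      }
      return out;
    }

    int main() {
      for (int d : BroadcastShape({8, 1, 3}, {4, 3})) std::cout << d << " ";  // 8 4 3
    }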
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto arg = CheckArg(op_name, args_spec_list, 0); - if (arg->elements().empty()) { - MS_LOG(EXCEPTION) << "Arg elements is empty."; - } - - size_t tuple_len = arg->elements().size(); - AbstractTensorPtr tensor_base = CheckArg(op_name, arg->elements(), 0); - int rank_base = SizeToInt(tensor_base->shape()->shape().size()); - - ValuePtr axis = primitive->GetAttr("axis"); - // Axis value should be in [-(rank_base + 1), rank_base). - int axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base); - // If axis is negative, add offset(rank_base + 1) to turn it to positive. - axis_value = GetPositiveAxis(axis_value, IntToSize(rank_base + 1)); - - for (size_t i = 1; i < tuple_len; ++i) { - AbstractTensorPtr tensor = CheckArg(op_name, arg->elements(), i); - (void)CheckDtypeSame(op_name, tensor_base, tensor); - (void)CheckShapeSame(op_name, tensor_base, tensor); - } - - primitive->set_attr("N", MakeValue(SizeToInt(tuple_len))); - primitive->set_attr("T", tensor_base->element()->BuildType()); - - AbstractTensorPtr ret = dyn_cast(tensor_base->Broaden()); - MS_EXCEPTION_IF_NULL(ret); - auto shape = ret->shape()->shape(); - (void)shape.insert(shape.begin() + axis_value, tuple_len); - ret->set_shape(std::make_shared(shape)); - return ret; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_debug.cc b/mindspore/ccsrc/operator/prim_debug.cc deleted file mode 100644 index 014797fb20..0000000000 --- a/mindspore/ccsrc/operator/prim_debug.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "abstract/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "abstract/utils.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor(value) - const std::string op_name = primitive->name(); - - CheckArgsSize(op_name, args_spec_list, 1); - auto tensor_value = CheckArg(op_name, args_spec_list, 0); - - int tensor_rank = SizeToInt(tensor_value->shape()->shape().size()); - if (tensor_rank == 0) { - MS_LOG(EXCEPTION) << op_name << " summary evaluator second arg should be an tensor, but got a scalar, rank is 0"; - } - - return std::make_shared(AbstractBasePtrList({tensor_value->Broaden()})); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_maths.cc b/mindspore/ccsrc/operator/prim_maths.cc deleted file mode 100644 index e073a3630b..0000000000 --- a/mindspore/ccsrc/operator/prim_maths.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
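InferImplPack's shape rule in isolation: a negative axis is normalized against rank + 1, then the tuple length is inserted at that position. A checkable helper mirroring the source's [-(rank + 1), rank) bound on the raw axis:

    #include <iostream>
    #include <stdexcept>
    #include <vector>

    std::vector<int> PackShape(const std::vector<int> &elem_shape, int n, int axis) {
      int rank = static_cast<int>(elem_shape.size());
      if (axis < -(rank + 1) || axis >= rank) throw std::out_of_range("axis");  // same bound as CheckAxis
      if (axis < 0) axis += rank + 1;  // what GetPositiveAxis does
      std::vector<int> out(elem_shape);
      out.insert(out.begin() + axis, n);
      return out;
    }

    int main() {
      for (int d : PackShape({2, 3}, 4, -1)) std::cout << d << " ";  // 2 3 4
    }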
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
-#include "abstract/utils.h"
-#include "abstract/param_validator.h"
-#include "common/utils.h"
-
-namespace mindspore {
-namespace abstract {
-AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                      const AbstractBasePtrList &args_spec_list) {
-  // Inputs: three tensors.
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 3);
-  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
-  auto input_y = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
-  auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
-  (void)CheckTensorsDTypeSame({input_x, input_y, dout}, {kInt, kUInt, kFloat},
-                              op_name + "evaluator three inputs should be %s");
-
-  AbstractBasePtr dx = input_x->Broaden();
-  AbstractBasePtr dy = input_y->Broaden();
-
-  return std::make_shared<AbstractTuple>(AbstractBasePtrList({dx, dy}));
-}
-}  // namespace abstract
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/prim_nn.cc b/mindspore/ccsrc/operator/prim_nn.cc
deleted file mode 100644
index 729674cace..0000000000
--- a/mindspore/ccsrc/operator/prim_nn.cc
+++ /dev/null
@@ -1,432 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
-#include "abstract/utils.h"
-#include "abstract/param_validator.h"
-
-namespace mindspore {
-namespace abstract {
-AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                 const AbstractBasePtrList &args_spec_list) {
-  // Inputs: a tensor.
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTensorPtr input_tensor = CheckArg(op_name, args_spec_list, 0); - (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "Input 0 of Pooling should be %s"); - - ShapePtr input_shape = dyn_cast(input_tensor->GetShapeTrack()); // NCHW - MS_EXCEPTION_IF_NULL(input_shape); - if (input_shape->shape().size() != 4) { - MS_LOG(EXCEPTION) << "Pooling input should be a 4-D tensor."; - } - int h_input = input_shape->shape()[2]; - int w_input = input_shape->shape()[3]; - - int window = primitive->GetAttr("window")->cast()->value(); - int stride = primitive->GetAttr("stride")->cast()->value(); - int padding = primitive->GetAttr("pad")->cast()->value(); - int nan_opt = primitive->GetAttr("nan_opt")->cast()->value(); - int data_mode = primitive->GetAttr("data_mode")->cast()->value(); - int ceil_mode = primitive->GetAttr("ceil_mode")->cast()->value(); - - if (stride <= 0) { - MS_LOG(EXCEPTION) << "Invalid stride value: " << stride << ", should greater then 0"; - } - if (nan_opt != 0) { - MS_LOG(EXCEPTION) << "Invalid nan_opt value: " << nan_opt << ", should be 0"; - } - if (data_mode != 1) { - MS_LOG(EXCEPTION) << "Invalid data_mode value: " << data_mode << ", should be 1"; - } - if (ceil_mode != 0) { - MS_LOG(EXCEPTION) << "Invalid ceil_mode value: " << ceil_mode << ", should be 0"; - } - - std::set available_pad_mode{"pad", "same", "valid"}; - auto pad_mode_ptr = primitive->GetAttr("pad_mode"); - if ((pad_mode_ptr != nullptr) && pad_mode_ptr->isa()) { - auto pad_mode = pad_mode_ptr->cast()->value(); - if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) { - MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid"; - } - if (pad_mode == "valid") { - padding = 0; - } else if (pad_mode == "same") { - padding = (window - 1) / 2; - } - } - - std::set available_mode{"max", "avg"}; - auto mode_ptr = primitive->GetAttr("mode"); - if ((mode_ptr != nullptr) && mode_ptr->isa()) { - auto mode = mode_ptr->cast()->value(); - if (available_mode.find(mode) == available_mode.end()) { - MS_LOG(EXCEPTION) << "Unsupported pooling mode: " << mode << "."; - } - } - - int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1; - int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1; - std::vector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; - AbstractBasePtr ret = input_tensor->Broaden(); - ret->set_shape(std::make_shared(shape_out)); - return ret; -} - -AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(y, dy, x). 
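The pooling arithmetic above, pulled out so the pad-mode handling is visible: "valid" forces the padding to 0, "same" uses (window - 1) / 2, and each spatial extent follows the usual sliding-window formula. A small sketch:

    #include <iostream>
    #include <string>

    // out = (in + 2*pad - (window - 1) - 1) / stride + 1, per spatial dimension.
    int PoolOut(int in, int window, int stride, int pad, const std::string &pad_mode) {
      if (pad_mode == "valid") pad = 0;
      else if (pad_mode == "same") pad = (window - 1) / 2;
      return (in + 2 * pad - (window - 1) - 1) / stride + 1;
    }

    int main() {
      std::cout << PoolOut(32, 3, 2, 0, "same") << "\n";   // (32 + 2 - 2 - 1) / 2 + 1 = 16
      std::cout << PoolOut(32, 3, 2, 0, "valid") << "\n";  // (32 - 2 - 1) / 2 + 1 = 15
    }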
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto out_y = CheckArg(op_name, args_spec_list, 0); - auto d_out = CheckArg(op_name, args_spec_list, 1); - auto input_x = CheckArg(op_name, args_spec_list, 2); - (void)CheckTensorsDTypeSame({out_y, d_out, input_x}, {kInt, kUInt, kFloat}, - op_name + "evaluator three inputs should be %s"); - - AbstractBasePtr ret = d_out->Broaden(); - auto x_shape = dyn_cast(args_spec_list[2]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(x_shape); - - ret->set_shape(x_shape); - return ret; -} - -void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - // check dimension, x > 1, others equal 1 - const std::string op_name = primitive->name(); - for (std::size_t i = 0; i < args_spec_list.size(); ++i) { - AbstractTensorPtr arg = CheckArg(op_name, args_spec_list, i); - ShapePtr arg_shape = dyn_cast(arg->GetShapeTrack()); - if (arg_shape == nullptr) { - MS_LOG(EXCEPTION) << op_name << " type of args[" << i << "] should be Shape, but " << arg->ToString(); - } - - if (i == 0) { - if (arg_shape->shape().size() < 2) { - MS_LOG(EXCEPTION) << op_name << " shape of args[" << i - << "] should be TensorShape with dimension greater than 1, but shape: " - << arg_shape->ToString(); - } - continue; - } - - if (arg_shape->shape().size() != 1) { - MS_LOG(EXCEPTION) << op_name << " shape of args[" << i - << "] should be TensorShape with dimension: 1, but shape: " << arg_shape->ToString(); - } - } -} - -AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(x, gamma, beta, mean, variance). - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 5); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_LOG(DEBUG) << "InferImplFusedBatchNorm args0:" << args_spec_list[0]->ToString() - << ", arg1:" << args_spec_list[1]->ToString(); - FusedBatchNormCheckDim(primitive, args_spec_list); - - auto input = args_spec_list[0]; - auto input_shape = dyn_cast(input->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(input_shape); - const auto &input_shape_list = input_shape->shape(); - if (input_shape_list.size() < 2) { - MS_LOG(EXCEPTION) << "Input shape size should >= 2."; - } - - for (size_t i = 1; i < args_spec_list.size(); ++i) { - auto arg_shape = dyn_cast(args_spec_list[i]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(arg_shape); - const auto &arg_shape_list = arg_shape->shape(); - if (arg_shape_list.size() < 1) { - MS_LOG(EXCEPTION) << "Arg shape size should >= 1."; - } - if (arg_shape_list[0] != input_shape_list[1]) { - MS_LOG(EXCEPTION) << op_name << " size of tensor param[" << i << "](which is " << arg_shape_list[0] - << ") should match the second dimension of tensor" - " param[0](which is " - << input_shape_list[1] << ")."; - } - } - auto input_tensor = CheckArg(op_name, args_spec_list, 0); - (void)CheckTensorDType(input_tensor, {kFloat16, kFloat32}, "param 0 of FusedBatchNorm should be %s"); - - AbstractTensorPtrList tensorPtrList = std::vector(); - for (size_t i = 1; i < args_spec_list.size(); ++i) { - auto param = CheckArg(op_name, args_spec_list, i); - tensorPtrList.push_back(param); - } - (void)CheckTensorsDTypeSame(tensorPtrList, {kFloat16, kFloat32}, "param 1 to 4 of FusedBatchNorm should be %s"); - - // check validity; - auto epsilon_value = primitive->GetAttr("epsilon"); - auto momentum_value = primitive->GetAttr("momentum"); - 
MS_EXCEPTION_IF_NULL(epsilon_value);
-  MS_EXCEPTION_IF_NULL(momentum_value);
-  if (!epsilon_value->isa<FP32Imm>() || !momentum_value->isa<FP32Imm>()) {
-    MS_LOG(EXCEPTION) << "expected epsilon and momentum to be float, but got epsilon: " << epsilon_value->ToString()
-                      << ", momentum: " << momentum_value->ToString();
-  }
-
-  auto epsilon = epsilon_value->cast<FP32ImmPtr>()->value();
-  auto momentum = momentum_value->cast<FP32ImmPtr>()->value();
-
-  if (epsilon > 1.0f || epsilon <= 0.0f) {
-    MS_LOG(EXCEPTION) << "expected epsilon in range (0, 1], but got epsilon: " << epsilon;
-  }
-  if (momentum > 1.0f || momentum < 0.0f) {
-    MS_LOG(EXCEPTION) << "expected momentum in range [0, 1], but got momentum: " << momentum;
-  }
-
-  // Outputs: y, running_mean, running_variance, save_mean, save_inv_variance.
-  AbstractBasePtr y = input->Broaden();
-  AbstractBasePtr other = args_spec_list[1]->Broaden();
-  MS_LOG(DEBUG) << "output y: " << y->ToString() << ", other: " << other->ToString();
-
-  AbstractBasePtrList elements = {y, other, other, other, other};
-  return std::make_shared<AbstractTuple>(elements);
-}
-
-AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                            const AbstractBasePtrList &args_spec_list) {
-  // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
-  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
-  MS_EXCEPTION_IF_NULL(args_spec_list[2]);
-  MS_EXCEPTION_IF_NULL(args_spec_list[3]);
-
-  CheckArgsSize(primitive->name(), args_spec_list, 5);
-  auto dx = args_spec_list[1]->Broaden();
-  auto dscale = args_spec_list[2]->Broaden();
-  auto dbias = args_spec_list[3]->Broaden();
-
-  AbstractBasePtrList rets = {dx, dscale, dbias};
-  return std::make_shared<AbstractTuple>(rets);
-}
-
-AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                  const AbstractBasePtrList &args_spec_list) {
-  // Inputs: two tensors(y_backprop, x).
-  CheckArgsSize(primitive->name(), args_spec_list, 2);
-  return args_spec_list[1]->Broaden();
-}
-
-AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                             const AbstractBasePtrList &args_spec_list) {
-  // Inputs: three tensors(doutput, input, filters).
-  CheckArgsSize(primitive->name(), args_spec_list, 3);
-  return args_spec_list[1]->Broaden();
-}
-
-AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                              const AbstractBasePtrList &args_spec_list) {
-  // Inputs: three tensors(inputs, filter, doutput).
- CheckArgsSize(primitive->name(), args_spec_list, 3); - return args_spec_list[2]->Broaden(); -} - -AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: at least one tensor(y_backprop) - // Outputs: dbias - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << primitive->name() << " evaluator at least has 1 parameters, while the input size is " - << args_spec_list.size() << "."; - } - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - ShapePtr shape_y = dyn_cast(args_spec_list[0]->GetShapeTrack()); - MS_EXCEPTION_IF_NULL(shape_y); - std::vector y_dims = shape_y->shape(); - if (y_dims.size() < 2) { - MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << "."; - } - std::vector bias_dims = {y_dims[1]}; - ShapePtr ret_shape = std::make_shared(bias_dims); - AbstractBasePtr ret = args_spec_list[0]->Broaden(); - ret->set_shape(ret_shape); - return ret; -} - -AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Broaden(); -} - -AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor. - AbstractBasePtrList args_list; - for (size_t i = 0; i < args_spec_list.size() - 2; i++) { - args_list.push_back(args_spec_list[i]->Broaden()); - } - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three tensors(x, gamma, beta). 
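InferImplBiasAddGrad reduces every axis of y_backprop except the channel axis, so for an NCHW input the bias gradient shape is just {C}. A minimal standalone sketch of that rule (names are illustrative):

#include <cassert>
#include <vector>

std::vector<int> BiasGradShape(const std::vector<int> &y_dims) {
  assert(y_dims.size() >= 2);  // mirrors the "dim should >= 2" check above
  return {y_dims[1]};          // channel dimension of an NCHW layout
}

int main() {
  assert((BiasGradShape({8, 64, 14, 14}) == std::vector<int>{64}));
  return 0;
}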
- // outputs: y, mean, variance - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto input_x = CheckArg(op_name, args_spec_list, 0); - auto input_shape = input_x->shape(); - auto const &input_shape_list = input_shape->shape(); - const size_t input_rank = input_shape_list.size(); - if (input_rank == 0) { - MS_LOG(EXCEPTION) << "input_rank should not be zero"; - } - - // begin_norm_axis and begin_params_axis should be smaller than the size of input_x and >= -1 - ValuePtr bna_ptr = primitive->GetAttr("begin_norm_axis"); - int begin_norm_axis = CheckAxis(op_name, bna_ptr, -1, SizeToInt(input_rank) - 1); - - ValuePtr bpa_ptr = primitive->GetAttr("begin_params_axis"); - int begin_params_axis = CheckAxis(op_name, bpa_ptr, -1, SizeToInt(input_rank) - 1); - begin_params_axis = GetPositiveAxis(begin_params_axis, input_rank); - - // the beta and gama shape should be x_shape[begin_params_axis:] - auto tensor = CheckArg(op_name, args_spec_list, 0); - auto gamma = CheckArg(op_name, args_spec_list, 1); - auto beta = CheckArg(op_name, args_spec_list, 2); - (void)CheckTensorDType(tensor, {kFloat16, kFloat32}, "input 0 of LayerNorm should be %s"); - (void)CheckTensorDType(gamma, {kFloat16, kFloat32}, "input 1 of LayerNorm should be %s"); - (void)CheckTensorDType(beta, {kFloat16, kFloat32}, "input 2 of LayerNorm should be %s"); - auto gamma_shape = dyn_cast(gamma->BuildShape()); - auto beta_shape = dyn_cast(beta->BuildShape()); - MS_EXCEPTION_IF_NULL(gamma_shape); - MS_EXCEPTION_IF_NULL(beta_shape); - - auto const &gamma_shape_list = gamma_shape->shape(); - auto const &beta_shape_list = beta_shape->shape(); - if (gamma_shape_list.empty() || beta_shape_list.empty()) { - MS_LOG(EXCEPTION) << "LayerNorm evaluator gamma or beta is a AbstractScalar that is not support."; - } - - size_t begin_params_axis_u = IntToSize(begin_params_axis); - if ((begin_params_axis_u > input_shape_list.size()) || - (gamma_shape_list.size() + begin_params_axis_u < input_shape_list.size()) || - (beta_shape_list.size() + begin_params_axis_u < input_shape_list.size())) { - MS_LOG(EXCEPTION) << "Gamma and beta shape get wrong size."; - } - for (size_t i = begin_params_axis_u; i < input_shape_list.size(); ++i) { - size_t gamma_beta_shape_dim = i - begin_params_axis_u; - if ((gamma_shape_list[gamma_beta_shape_dim] != input_shape_list[i]) || - (beta_shape_list[gamma_beta_shape_dim] != input_shape_list[i])) { - MS_LOG(EXCEPTION) << "Gamma or beta shape not match input shape, input_shape=" << input_shape->ToString() - << ", gamma_shape=" << gamma_shape->ToString() << ", beta_shape=" << beta_shape->ToString(); - } - } - - auto mean_var_shape_value = input_shape->shape(); - if (begin_norm_axis == -1) { - mean_var_shape_value[input_rank - 1] = 1; - } else { - for (size_t i = begin_norm_axis; i < input_rank; ++i) { - mean_var_shape_value[i] = 1; - } - } - - auto mean = input_x->Broaden(); - mean->set_shape(std::make_shared(mean_var_shape_value)); - auto var = input_x->Broaden(); - var->set_shape(std::make_shared(mean_var_shape_value)); - - AbstractBasePtrList args_list({input_x->Broaden(), mean, var}); - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: five tensors(y_backprob, x, variance, mean, gamma). 
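The mean/variance shape rule in InferImplLayerNorm collapses every dimension from begin_norm_axis onward to 1, while gamma/beta must match x_shape[begin_params_axis:]. A standalone sketch of the first part (helper name and sample shapes are illustrative):

#include <cassert>
#include <vector>

std::vector<int> MeanVarShape(std::vector<int> x_shape, int begin_norm_axis) {
  if (begin_norm_axis < 0) begin_norm_axis += static_cast<int>(x_shape.size());
  for (size_t i = static_cast<size_t>(begin_norm_axis); i < x_shape.size(); ++i) {
    x_shape[i] = 1;  // normalized axes reduce to size 1
  }
  return x_shape;
}

int main() {
  // x: [batch=2, seq=8, hidden=768], normalizing over the last axis only.
  assert((MeanVarShape({2, 8, 768}, -1) == std::vector<int>{2, 8, 1}));
  assert((MeanVarShape({2, 8, 768}, 1) == std::vector<int>{2, 1, 1}));
  return 0;
}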
- // Outputs: x_backprob, gamma_backprob, beta_backprob - CheckArgsSize(primitive->name(), args_spec_list, 5); - - auto x_backprob = args_spec_list[0]->Broaden(); - auto gamma_backprob = args_spec_list[4]->Broaden(); - auto beta_backprob = args_spec_list[4]->Broaden(); - - AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob}); - return std::make_shared(args_list); -} - -AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple and a tensor. - // Outputs: mask. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr x_shape = CheckArg(op_name, args_spec_list, 0); - AbstractTensorPtr keep_prob = CheckArg(op_name, args_spec_list, 1); - - TypePtr prob_type = keep_prob->element()->BuildType(); - if ((prob_type->type_id() != kNumberTypeFloat16) && (prob_type->type_id() != kNumberTypeFloat32)) { - MS_LOG(EXCEPTION) << op_name << " keep_prob type should be float16 or float32, but " << prob_type->ToString() - << "."; - } - - auto x_shape_data = x_shape->elements(); - int count = 1; - for (std::size_t i = 0; i < x_shape->size(); ++i) { - auto value_track = x_shape_data[i]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - if (!value_track->isa()) { - MS_LOG(EXCEPTION) << "DropOutGenMask input x_shape elements is not int32, but " << value_track->ToString() << "."; - } - - int e_value = GetValue(value_track); - if (e_value <= 0) { - MS_LOG(EXCEPTION) << "DropOutGenMask product of x_shape should be > 0"; - } - if (std::numeric_limits::max() / count / e_value < 1) { - MS_LOG(EXCEPTION) << "integer multiply integer overflow"; - } - count = count * e_value; - } - - // convert to bytes(8 bits) mask, using round up - int n128s = count / 128; - if ((count % 128) != 0) { - n128s++; - } - int bytes_count = n128s * 16; - std::vector shape_y{bytes_count}; - - primitive->set_attr("T", kInt32); - return std::make_shared(std::make_shared(kAnyValue, kUInt8), - std::make_shared(std::vector{shape_y})); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc deleted file mode 100644 index f181fcacf7..0000000000 --- a/mindspore/ccsrc/operator/prim_others.cc +++ /dev/null @@ -1,410 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
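DropOutGenMask above emits one mask bit per element, rounded up to whole 128-bit blocks of 16 bytes each. A standalone sketch of that byte count (helper name is illustrative):

#include <cassert>

int MaskBytes(int element_count) {
  int n128s = element_count / 128;
  if (element_count % 128 != 0) {
    ++n128s;  // round up to a full 128-bit block
  }
  return n128s * 16;  // 16 bytes per 128-bit block
}

int main() {
  assert(MaskBytes(128) == 16);
  assert(MaskBytes(129) == 32);  // one extra block for the spill-over bit
  return 0;
}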
- */ - -#include -#include - -#include "ir/dtype.h" -#include "common/utils.h" -#include "operator/ops.h" -#include "abstract/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "abstract/utils.h" -#include "utils/context/ms_context.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // An object of a subclass of AbstractBase - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]; -} - -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: An object of AbstractFunction. - CheckArgsSize(primitive->name(), args_spec_list, 1); - MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString(); - - AbstractFunctionPtr x = dyn_cast(args_spec_list[0]); - if (x == nullptr) { - return std::make_shared(args_spec_list[0]); - } - - AbstractFuncAtomPtrList jv; - auto build_jv = [&jv](const AbstractFuncAtomPtr &func) { - auto j_closure = std::make_shared(func); - jv.push_back(j_closure); - }; - x->Visit(build_jv); - - return AbstractFunction::MakeAbstractFunction(jv); -} - -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(primitive); - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). - CheckArgsSize(primitive->name(), args_spec_list, 3); - auto key = args_spec_list[1]; - auto dflt = args_spec_list[2]; - TypePtr type = key->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kObjectTypeSymbolicKeyType) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString(); - } - - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); - if (enable_sparse && dflt->isa()) { - auto dflt_tensor = dflt->cast(); - return std::make_shared(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone()); - } - - if (!key->GetValueTrack()->isa()) { - return dflt; - } - ValuePtr key_value_ptr = key->GetValueTrack(); - MS_EXCEPTION_IF_NULL(key_value_ptr); - auto key_value_track = key_value_ptr->cast(); - auto expected = key_value_track->abstract(); - MS_EXCEPTION_IF_NULL(expected); - (void)expected->Join(dflt); - return expected; -} - -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). - CheckArgsSize(primitive->name(), args_spec_list, 3); - - auto key = args_spec_list[1]; - ValuePtr key_value_ptr = key->GetValueTrack(); - MS_EXCEPTION_IF_NULL(key_value_ptr); - auto key_value_track = key_value_ptr->cast(); - if (key_value_track == nullptr) { - MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] expected should be able to cast to SymbolicKeyInstancePtrbut: " - << key_value_ptr->ToString(); - } - auto expected = key_value_track->abstract(); - MS_EXCEPTION_IF_NULL(expected); - return std::make_shared(kAnyValue, std::make_shared()); -} - -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Three objects of a subclass of AbstractBase, env, key, dflt(default). 
- CheckArgsSize(primitive->name(), args_spec_list, 2); - return std::make_shared(kAnyValue, std::make_shared()); -} - -AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &prim, const AbstractBasePtrList &) { - ValuePtr name_value = prim->GetAttr("tag"); - auto name = name_value->cast(); - if (name == nullptr) { - MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << "."; - } - auto refkey = std::make_shared(name->value()); - if (refkey == nullptr) { - MS_LOG(EXCEPTION) << "MakeRefKey std::make_shared failed"; - } - return refkey->ToAbstract(); -} - -AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: key, value, original value - if (args_spec_list.size() != 3) { - MS_LOG(EXCEPTION) << "make_ref evaluator requires 3 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRefKey) { - MS_LOG(EXCEPTION) << "First input of make_ref should be a RefKey but a " << type->ToString(); - } - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); -} - -AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_key requires 1 parameters, while the input size is " << args_spec_list.size() << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_key should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref(); -} - -AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_value requires 1 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref(); -} - -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // arguments: value - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "get_ref_origin requires 1 parameters, while the input size is " << args_spec_list.size() - << "."; - } - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kObjectTypeRef) { - MS_LOG(EXCEPTION) << "First input of get_ref_value should be a Ref but a " << type->ToString(); - } - return args_spec_list[0]->cast()->ref_origin(); -} - -AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Two objects of a subclass of AbstractBase, key and value. 
- CheckArgsSize(primitive->name(), args_spec_list, 2); - - TypePtr type = args_spec_list[0]->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kObjectTypeRefKey && type->type_id() != kObjectTypeSymbolicKeyType) { - MS_LOG(EXCEPTION) << "First input of StateSetItem should be a RefKey or SymbolicKeyType but a " << type->ToString(); - } - return std::make_shared(kAnyValue, kBool); -} - -AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; - } - auto depends = args_spec_list[0]->Broaden(); - return depends; -} - -bool CompareShape(const std::vector &x_shape, const std::vector &y_shape) { - if (x_shape.size() != y_shape.size()) { - return false; - } - - for (size_t i = 0; i < x_shape.size(); ++i) { - if (GetValue(x_shape[i]) != GetValue(y_shape[i])) { - return false; - } - } - - return true; -} - -enum State { - SAME, - X_ONE, - Y_ONE, -}; - -void ComputeReduceIndex(const std::vector &reverse_x, const std::vector &reverse_y, - std::vector *grad_x_reduce_idx, std::vector *grad_y_reduce_idy) { - const size_t n = reverse_x.size(); - for (size_t i = 0; i < n; ++i) { - State curr; - const int32_t x_i = reverse_x[i]; - const int32_t y_i = reverse_y[i]; - const int reduce_idx = SizeToInt(n - 1 - i); - if (x_i == y_i) { - curr = SAME; - } else if (x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - curr = X_ONE; - } else if (y_i == 1) { - grad_y_reduce_idy->push_back(reduce_idx); - curr = Y_ONE; - } else { - MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs"; - } - if (curr == SAME && x_i == 1) { - grad_x_reduce_idx->push_back(reduce_idx); - grad_y_reduce_idy->push_back(reduce_idx); - continue; - } - } - - std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end()); - std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end()); -} - -AbstractBasePtr BroadcastGradientArgsDiff(const std::vector &x_shape, const std::vector &y_shape) { - std::vector reverse_x; - std::vector reverse_y; - - (void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x), - [](const ValuePtr &v) { return v->cast()->value(); }); - (void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y), - [](const ValuePtr &v) { return v->cast()->value(); }); - - if (reverse_x.size() > reverse_y.size()) { - reverse_y.resize(reverse_x.size(), 1); - } else { - reverse_x.resize(reverse_y.size(), 1); - } - - std::vector grad_x_reduce_idx; - std::vector grad_y_reduce_idy; - ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy); - - AbstractBasePtrList abs_list_x; - AbstractBasePtrList abs_list_y; - (void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x), - [](int v) { return abstract::FromValue(v); }); - (void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y), - [](int v) { return abstract::FromValue(v); }); - auto x_reduce_idx = std::make_shared(abs_list_x); - auto y_reduce_idx = std::make_shared(abs_list_y); - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); -} - -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) 
{ - // this primitive get the index that need to reduce - // input: x's shape and y's shape, inputs should be tuple - // output: tuple of x and y 's reduce index, reduce index should be a tuple - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto arg_x = CheckArg(op_name, args_spec_list, 0); - auto arg_y = CheckArg(op_name, args_spec_list, 1); - - ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_x_value); - - ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(arg_y_value); - - const std::vector x_shape = arg_x_value->value(); - const std::vector y_shape = arg_y_value->value(); - bool is_same_shape = CompareShape(x_shape, y_shape); - // if it is the same shape , do not need reduce , return empty tuple - if (is_same_shape) { - AbstractBasePtrList empty_list; - auto x_reduce_idx = std::make_shared(empty_list); - auto y_reduce_idx = std::make_shared(empty_list); - - AbstractBasePtrList elem_list; - elem_list.push_back(x_reduce_idx); - elem_list.push_back(y_reduce_idx); - - return std::make_shared(elem_list); - } - - return BroadcastGradientArgsDiff(x_shape, y_shape); -} - -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // args: Two objects of a subclass of AbstractBase - CheckArgsSize(primitive->name(), args_spec_list, 2); - auto arg_src = args_spec_list[0]; - auto arg_dst = args_spec_list[1]; - // control depend can not setup tuple of ops to tuple of ops dependency relation - if (arg_src->isa() && arg_dst->isa()) { - auto src_size = arg_src->cast()->size(); - auto dst_size = arg_src->cast()->size(); - if (src_size > 1 && dst_size > 1) { - MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple"; - } - } - return std::make_shared(kAnyValue, kBool); -} - -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - auto indices = CheckArg(op_name, args_spec_list, 0); - auto values = CheckArg(op_name, args_spec_list, 1); - auto dense_shape = CheckArg(op_name, args_spec_list, 2); - - auto dense_shape_value = dense_shape->BuildValue()->cast(); - MS_EXCEPTION_IF_NULL(dense_shape_value); - auto shp = dense_shape_value->value(); - std::vector dense_shape_vec; - (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), - [](const ValuePtr &e) -> int { - auto elem = GetValue(e); - return elem; - }); - auto ret = std::make_shared(values->element()->BuildType(), dense_shape_vec); - ret->set_indices(indices); - ret->set_values(values); - ret->set_dense_shape(dense_shape); - return ret; -} - -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. 
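The reduction-index computation behind BroadcastGradientArgs compares the two shapes from the trailing axis: wherever one operand has size 1 (including implicitly padded leading axes), its gradient must be summed over that axis. A self-contained sketch of the same logic (names are illustrative, and the incompatible-shape exception is omitted):

#include <algorithm>
#include <cassert>
#include <vector>

void ReduceIndices(std::vector<int> x, std::vector<int> y,
                   std::vector<int> *rx, std::vector<int> *ry) {
  std::reverse(x.begin(), x.end());  // walk from the trailing axis
  std::reverse(y.begin(), y.end());
  size_t n = std::max(x.size(), y.size());
  x.resize(n, 1);  // pad missing leading axes with 1
  y.resize(n, 1);
  for (size_t i = 0; i < n; ++i) {
    int axis = static_cast<int>(n - 1 - i);
    if (x[i] == y[i]) {
      if (x[i] == 1) { rx->push_back(axis); ry->push_back(axis); }
    } else if (x[i] == 1) {
      rx->push_back(axis);
    } else if (y[i] == 1) {
      ry->push_back(axis);
    }
  }
  std::reverse(rx->begin(), rx->end());
  std::reverse(ry->begin(), ry->end());
}

int main() {
  std::vector<int> rx, ry;
  ReduceIndices({2, 1, 3}, {2, 4, 3}, &rx, &ry);
  assert((rx == std::vector<int>{1}) && ry.empty());  // x was broadcast along axis 1
  return 0;
}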
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->values()); - return indexed_slices->values(); -} - -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->indices()); - return indexed_slices->indices(); -} - -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors and a tuple. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - auto indexed_slices = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape()); - return indexed_slices->dense_shape(); -} - -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - bool ret = false; - if (args_spec_list[0]->isa()) { - ret = true; - } - MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); - return std::make_shared(ret); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_statement.cc b/mindspore/ccsrc/operator/prim_statement.cc deleted file mode 100644 index 3760814554..0000000000 --- a/mindspore/ccsrc/operator/prim_statement.cc +++ /dev/null @@ -1,249 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "abstract/param_validator.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "abstract/utils.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace abstract { -AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object - if (args_spec_list.size() != 1) { - MS_LOG(INFO) << "Return evaluator requires 1 parameter, is this the default value attached? 
" - "while the input size is " - << args_spec_list.size() << "."; - } - AbstractBasePtr abs_base = args_spec_list[0]; - return abs_base; -} - -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size() - << "."; - } - AbstractBasePtr abs_base = args_spec_list[0]; - MS_EXCEPTION_IF_NULL(abs_base); - TypePtr type = abs_base->BuildType(); - return std::make_shared(type); -} - -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a pointer to an AbstractBase object and a pointer to a Type - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTypePtr abs_type = CheckArg(op_name, args_spec_list, 1); - - auto mode_v = abs_type->GetValueTrack(); - MS_EXCEPTION_IF_NULL(mode_v); - if (!mode_v->isa()) { - MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed."; - } - - TypePtr mode_t = mode_v->cast(); - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - bool v = IsSubtype(args_spec_list[0], mode_t); - return std::make_shared(std::make_shared(v), kBool); -} - -AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tensors. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTensorPtr input_x = CheckArg(op_name, args_spec_list, 0); - AbstractTensorPtr input_y = CheckArg(op_name, args_spec_list, 1); - - ShapePtr x_shp = input_x->shape(); - auto x_shp_value = x_shp->shape(); - ShapePtr y_shp = input_y->shape(); - auto y_shp_value = y_shp->shape(); - // Should be matrix which shape size is 2. - if (x_shp_value.size() != 2 || y_shp_value.size() != 2) { - MS_LOG(EXCEPTION) << op_name << " evaluator requires input two 2D tensors, while the dimensions of two tensors are " - << x_shp_value.size() << ", " << y_shp_value.size() << " "; - } - if (x_shp_value[1] != y_shp_value[0] && x_shp_value[1] != Shape::SHP_ANY && y_shp_value[0] != Shape::SHP_ANY) { - MS_LOG(EXCEPTION) << "Incompatible shapes in dot: {" << x_shp->ToString() << "} and {" << y_shp->ToString() << "}"; - } - - auto x_element = input_x->element(); - MS_EXCEPTION_IF_NULL(x_element); - (void)x_element->Join(input_y->element()); - auto param = {x_shp_value[0], y_shp_value[1]}; - - return std::make_shared(input_x->element(), std::make_shared(param)); -} - -AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &prim, - const AbstractBasePtrList &args_spec_list) { - // Inputs: condition, true branch, false branch - if (args_spec_list.size() != 3) { - MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_spec_list.size() - << "."; - } - - auto cond = args_spec_list[0]; - auto tb = args_spec_list[1]; - auto fb = args_spec_list[2]; - MS_EXCEPTION_IF_NULL(cond); - - auto unroll_flag = prim->GetAttr(prim::SWITCH_UNROLL_FLAG); - if (unroll_flag != nullptr && GetValue(unroll_flag) == 0) { - return tb->Join(fb); - } - - ValuePtr v = cond->GetValueTrack(); - MS_EXCEPTION_IF_NULL(v); - // for tensor as condition, keeps both true and false branch. 
- if (v->isa() || cond->isa()) { - MS_EXCEPTION_IF_NULL(tb); - return tb->Join(fb); - } - - if (v->isa()) { - if (v->cast()->IsOne()) { - return tb; - } else { - return fb; - } - } - - MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString(); -} - -AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: index, branch - const std::string op_name = primitive->name(); - abstract::CheckArgsSize(op_name, args_spec_list, 2); - (void)CheckArg(op_name, args_spec_list, 0); - AbstractTuplePtr branches_abs = CheckArg(op_name, args_spec_list, 1); - AbstractBasePtrList branches = branches_abs->elements(); - const size_t maximum_layer_num = 1000; - if (branches.size() < 0 || branches.size() > maximum_layer_num) { - MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " - << branches.size() << " branches."; - } - - for (size_t i = 0; i < branches.size(); i++) { - MS_EXCEPTION_IF_NULL(branches[i]); - if (!branches[i]->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got " - << branches[i]->ToString() << " as the " << i << "th element."; - } - } - - auto b = branches[0]; - for (size_t i = 1; i < branches.size(); i++) { - b = b->Join(branches[i]); - } - return b; -} - -std::vector GetSupportedTargetValue() { - std::vector list = {kNone, MakeValue(false), MakeValue(true)}; - return list; -} - -bool SupportedIsTargetValue(const ValuePtr t) { - auto list = GetSupportedTargetValue(); - auto match = std::any_of(list.begin(), list.end(), [&t](const ValuePtr &v) { return *v == *t; }); - return match; -} - -AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x is t - // Inputs: x, t - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - ValuePtr t = args_spec_list[1]->BuildValue(); - if (!SupportedIsTargetValue(t)) { - MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() - << " for statement is, supported list is:None, False, True "; - } - ValuePtr x = args_spec_list[0]->BuildValue(); - - return std::make_shared(*t == *x); -} - -AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x is not t - // Inputs: x, t - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - ValuePtr t = args_spec_list[1]->BuildValue(); - if (!SupportedIsTargetValue(t)) { - MS_LOG(EXCEPTION) << "Not supported type:" << t->ToString() - << " for statement is not, supported list is:None, False, True "; - } - ValuePtr x = args_spec_list[0]->BuildValue(); - - return std::make_shared(!(*t == *x)); -} - -bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - auto key = CheckArg(op_name, args_spec_list, 0); - auto dict = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - auto key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](const AbstractAttribute &item) { 
return item.first == key_str; }); - return it != dict_elems.end(); -} - -AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x in t - // Inputs: x, t - return std::make_shared(IsInDict(primitive, args_spec_list)); -} - -AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: x not in t - // Inputs: x, t - return std::make_shared(!IsInDict(primitive, args_spec_list)); -} - -AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // statement: isconstant(x) - // Inputs: x - if (args_spec_list.size() != 1) { - MS_LOG(EXCEPTION) << "IsConstant requires args input size = 1"; - } - ValuePtr v = args_spec_list[0]->BuildValue(); - return std::make_shared(!v->isa()); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_structures.cc b/mindspore/ccsrc/operator/prim_structures.cc deleted file mode 100644 index 6501e6a843..0000000000 --- a/mindspore/ccsrc/operator/prim_structures.cc +++ /dev/null @@ -1,712 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" -#include "abstract/utils.h" -#include "abstract/param_validator.h" -#include "operator/ops.h" -#include "utils/convert_utils.h" -#include "ir/tensor_py.h" - -using mindspore::tensor::TensorPy; - -namespace mindspore { -namespace abstract { - -AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two scalars whose value is a string. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr value_x = scalar_x->BuildValue(); - ValuePtr value_y = scalar_y->BuildValue(); - if (!value_x->isa() || !value_y->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() - << ", param1: " << value_y->ToString(); - } - - bool ret = (value_x->cast()->value() == value_y->cast()->value()); - return std::make_shared(ret); -} - -AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two scalars whose value is a string. 
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr scalar_x = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr scalar_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr value_x = scalar_x->BuildValue(); - ValuePtr value_y = scalar_y->BuildValue(); - if (!value_x->isa() || !value_y->isa()) { - MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString() - << ", param1: " << value_y->ToString(); - } - - std::string ret = (value_x->cast()->value() + value_y->cast()->value()); - return std::make_shared(ret); -} - -AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(args_spec_list); -} - -AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(args_spec_list); -} - -AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr keys = CheckArg(op_name, args_spec_list, 0); - AbstractTuplePtr values = CheckArg(op_name, args_spec_list, 1); - - size_t keys_size = keys->size(); - if (values->size() != keys_size) { - MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size"; - } - - std::vector key_value; - AbstractScalarPtr key; - AbstractBasePtrList key_list = keys->elements(); - AbstractBasePtrList value_list = values->elements(); - for (size_t index = 0; index < keys_size; index++) { - key = CheckArg(op_name + "key", key_list, index); - ValuePtr keyPtr = key->BuildValue(); - MS_EXCEPTION_IF_NULL(keyPtr); - if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString(); - } - std::string key_string = GetValue(keyPtr); - key_value.emplace_back(key_string, value_list[index]); - } - return std::make_shared(key_value); -} - -AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a string and an object of a subclass of AbstractBase. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); - - ValuePtr keyPtr = key->BuildValue(); - if (!keyPtr->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString(); - } - std::string key_string = GetValue(keyPtr); - return std::make_shared(key_string, args_spec_list[1]); -} - -AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a string and a keyword. 
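InferImplMakeDict above zips a tuple of string keys with a tuple of values of equal length, rejecting mismatched lengths and non-string keys. A standalone sketch of the pairing step (int stands in for the abstract value type; names are illustrative):

#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

std::vector<std::pair<std::string, int>> MakeDict(const std::vector<std::string> &keys,
                                                  const std::vector<int> &values) {
  if (keys.size() != values.size()) throw std::invalid_argument("keys' size not equal to values' size");
  std::vector<std::pair<std::string, int>> out;
  for (size_t i = 0; i < keys.size(); ++i) {
    out.emplace_back(keys[i], values[i]);  // key i pairs with value i
  }
  return out;
}

int main() {
  auto d = MakeDict({"lr", "epoch"}, {1, 10});
  assert(d.size() == 2 && d[0].first == "lr");
  return 0;
}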
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 0); - AbstractKeywordArgPtr kwarg = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - std::string key_input = GetValue(key_value); - std::string key_actual = kwarg->get_key(); - if (key_actual != key_input) { - MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is " - << key_input << ", AbstractKeywordArg' key is " << key_actual; - } - return kwarg->get_arg(); -} - -AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: three scalars whose value is an int32 number. - CheckArgsSize(primitive->name(), args_spec_list, 3); - size_t args_size = args_spec_list.size(); - for (size_t index = 0; index < args_size; index++) { - MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (!args_spec_list[index]->isa() && !args_spec_list[index]->isa()) { - MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone."; - } - if (args_spec_list[index]->isa() && - !dyn_cast(args_spec_list[index])->BuildValue()->isa()) { - MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number."; - } - } - // Slice: start, end, step - return std::make_shared(args_spec_list[0], args_spec_list[1], args_spec_list[2]); -} - -// Eval the return type of make_record -AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: at lease two objects of a subclass of AbstractBase. - if (args_spec_list.size() < 2) { - MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is " - << args_spec_list.size() << "."; - } - - // args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - TypePtr type = args_spec_list[0]->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(type); - if (type->type_id() != kMetaTypeTypeType) { - MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType"; - } - - ValuePtr value_track = args_spec_list[0]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - TypePtr type_ptr = value_track->cast(); - if (type_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString(); - } - - auto cls = dyn_cast(type_ptr); - MS_EXCEPTION_IF_NULL(cls); - ClassAttrVector attributes = cls->GetAttributes(); - CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1); - - std::vector abs_attributes; - for (size_t i = 0; i < attributes.size(); i++) { - AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]); - abs_attributes.push_back(elem); - } - - return std::make_shared(cls->tag(), abs_attributes, cls->methods()); -} - -template -AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list and a scalar whose value is an int32 number. 
- CheckArgsSize(op_name, args_spec_list, 2); - auto queue = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); - - ValuePtr index_value = index->BuildValue(); - if (!index_value->isa()) { - // when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element - // and continue - if (dyn_cast(queue->elements()[0]) != nullptr) { - return std::make_shared(queue->elements()[0]->BuildType()); - } - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); - } - int idx_v = GetValue(index_value); - std::size_t nelems = queue->elements().size(); - if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) { - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", " - << SizeToInt(nelems) << "), but got " << idx_v << "."; - } - - std::size_t uidx_v = 0; - if (idx_v >= 0) { - uidx_v = IntToSize(idx_v); - } else { - uidx_v = IntToSize(idx_v + SizeToInt(nelems)); - } - return queue->elements()[uidx_v]; -} - -template -AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase. - CheckArgsSize(op_name, args_spec_list, 3); - auto queue = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr index = CheckArg(op_name, args_spec_list, 1); - - ValuePtr index_value = index->BuildValue(); - if (!index_value->isa()) { - MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got " - << index_value->ToString(); - } - int idx_v = GetValue(index_value); - if (idx_v < 0) { - MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v - << "."; - } - - size_t uidx_v = IntToSize(idx_v); - AbstractBasePtrList elements = queue->elements(); - std::size_t nelems = elements.size(); - if (uidx_v >= nelems) { - MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1 - << "."; - } - elements[uidx_v] = args_spec_list[2]; - return std::make_shared(elements); -} - -AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListGetItem(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListGetItem(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListSetItem(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListSetItem(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a dict and a scalar whose value is a string. 
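The getitem evaluators above normalize a possibly negative index: it must lie in [-n, n), and a negative index wraps around by adding n. A standalone sketch (helper name is illustrative):

#include <cassert>
#include <cstddef>
#include <stdexcept>

std::size_t NormalizeIndex(int idx, std::size_t nelems) {
  int n = static_cast<int>(nelems);
  if (idx >= n || idx < -n) throw std::out_of_range("index should be in range [-n, n)");
  return static_cast<std::size_t>(idx >= 0 ? idx : idx + n);
}

int main() {
  assert(NormalizeIndex(-1, 5) == 4);  // -1 selects the last element
  assert(NormalizeIndex(0, 5) == 0);
  return 0;
}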
- const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - auto key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](const AbstractAttribute &item) { return item.first == key_str; }); - - if (it == dict_elems.end()) { - MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString(); - } - return it->second; -} - -AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - AbstractDictionaryPtr dict = CheckArg(op_name, args_spec_list, 0); - AbstractScalarPtr key = CheckArg(op_name, args_spec_list, 1); - - ValuePtr key_value = key->BuildValue(); - if (!key_value->isa()) { - MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString(); - } - std::string key_str = GetValue(key_value); - std::vector dict_elems = dict->elements(); - auto it = std::find_if(dict_elems.begin(), dict_elems.end(), - [key_str](AbstractAttribute &item) { return item.first == key_str; }); - - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - auto new_ele = std::make_pair(key_str, args_spec_list[2]); - if (it != dict_elems.end()) { - int index = it - dict_elems.begin(); - dict_elems[IntToSize(index)] = new_ele; - } else { - dict_elems.push_back(new_ele); - } - return std::make_shared(dict_elems); -} - -AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a list and an object of a subclass of AbstractBase. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractListPtr list = CheckArg(op_name, args_spec_list, 0); - (void)AbstractJoin(list->elements()); - return list; -} - -template -AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple or list or dict. 
- CheckArgsSize(op_name, args_spec_list, 1); - auto arg = CheckArg(op_name, args_spec_list, 0); - return std::make_shared(SizeToInt(arg->size())); -} - -AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferTupleOrListOrDictLen(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - return std::make_shared(kAnyValue, kInt32); -} - -AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: fn, list1, list2, ... - MS_EXCEPTION_IF_NULL(engine); - if (args_spec_list.size() <= 1) { - MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << "."; - } - AbstractFunctionPtr fn = CheckArg(primitive->name(), args_spec_list, 0); - // check args from 1. - CheckArgsSpec(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end())); - - AbstractBasePtrList subargs; - for (std::size_t i = 1; i < args_spec_list.size(); i++) { - AbstractListPtr l_ptr = dyn_cast(args_spec_list[i]); - if (l_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list."; - } - subargs.push_back(AbstractJoin(l_ptr->elements())); - } - EvalResultPtr engin_exc = engine->Execute(fn, subargs); - AbstractBasePtrList result; - for (std::size_t i = 1; i < args_spec_list.size(); i++) { - result.push_back(engin_exc->abstract()); - } - return std::make_shared(result); -} - -AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a fn, a list and an object of a subclass of a AbstractBase. 
- MS_EXCEPTION_IF_NULL(engine); - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); - AbstractFunctionPtr fn = CheckArg(op_name, args_spec_list, 0); - AbstractListPtr lst = CheckArg(op_name, args_spec_list, 1); - AbstractBasePtr dflt = args_spec_list[2]; - - AbstractBasePtr list_type = AbstractJoin(lst->elements()); - auto result1 = engine->Execute(fn, lst->elements()); - auto result2 = engine->Execute(fn, {dflt, list_type}); - MS_EXCEPTION_IF_NULL(result1->abstract()); - MS_EXCEPTION_IF_NULL(result2->abstract()); - return result1->abstract()->Join(result2->abstract()); -} - -AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTuplePtr input = CheckArg(op_name, args_spec_list, 0); - - auto tuple_elements = input->elements(); - AbstractBasePtrList elem_list; - (void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list), - [](const AbstractBasePtr &elem) { return elem->Clone(); }); - return std::make_shared(elem_list); -} - -AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value, - const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) { - size_t x_rank = x_shape->size(); - std::set axis_set; - auto axis_data = axis_value_ptr->value(); - if (axis_data.empty()) { - int size = 1; - AbstractBasePtrList values(x_rank, std::make_shared(size)); - return std::make_shared(values); - } - - for (auto &elem : axis_data) { - int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1); - (void)axis_set.insert(e_value); - } - - auto x_shp_data = x_shp_value->cast()->value(); - if (x_shp_data.size() < x_rank) { - MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank; - } - AbstractBasePtrList values; - for (size_t i = 0; i < x_rank; i++) { - if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) { - auto axis_v = MakeValue(1); - values.push_back(std::make_shared(axis_v, axis_v->type())); - } else { - int dim_value = x_shp_data[i]->cast()->value(); - auto dim = MakeValue(dim_value); - values.push_back(std::make_shared(dim, dim->type())); - } - } - - return std::make_shared(values); -} - -AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: x_shape, axis - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - - auto x_shp_value = shape_x->BuildValue(); - if (x_shp_value->isa()) { - MS_LOG(EXCEPTION) << op_name - << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); - } - - // Axis can be scalar, tuple or None - AbstractTuplePtr axis = nullptr; - if (args_spec_list[1]->isa()) { - MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar"; - AbstractBasePtrList axis_list = {dyn_cast(args_spec_list[1])}; - axis = std::make_shared(axis_list); - } else if (args_spec_list[1]->isa()) { - MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple"; - axis = args_spec_list[1]->cast(); - } else { - MS_LOG(EXCEPTION) << op_name << " evaluator second parameter 
should be a scalar or tuple, but got " - << args_spec_list[1]->ToString(); - } - - auto axis_value = axis->BuildValue(); - if (axis_value->isa()) { - MS_LOG(EXCEPTION) << op_name - << " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString(); - } - auto axis_value_ptr = axis_value->cast(); - MS_EXCEPTION_IF_NULL(axis_value_ptr); - - return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive); -} - -AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples. - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); - AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); - AbstractTuplePtr div_shp = CheckArg(op_name, args_spec_list, 1); - MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString(); - - auto div_shp_value = div_shp->BuildValue(); - if (div_shp_value->isa()) { - MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString(); - } - - auto shpx_value = shape_x->BuildValue(); - if (shpx_value->isa()) { - MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString(); - } - - if (div_shp->size() != shape_x->size()) { - MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size() - << ", shapex: " << shape_x->size() << "."; - } - - auto shpx_data = shpx_value->cast()->value(); - auto div_shp_data = div_shp_value->cast()->value(); - AbstractBasePtrList values; - - for (size_t i = 0; i < div_shp_data.size(); i++) { - if (div_shp_data[i]->cast() == nullptr) { - MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString(); - } - int shapex_value = GetValue(shpx_data[i]); - int div_value = GetValue(div_shp_data[i]); - MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value; - if (div_value == 0) { - MS_LOG(EXCEPTION) << "error: division value should not be 0!"; - } - if ((shapex_value % div_value) != 0) { - MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value; - } - - int result = shapex_value / div_value; - auto result_v = MakeValue(result); - values.push_back(std::make_shared(result_v, result_v->type())); - } - - return std::make_shared(values); -} - -AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTuplePtr input = CheckArg(op_name, args_spec_list, 0); - - py::tuple data_tuple = ValuePtrToPyData(input->BuildValue()); - py::array data = py::array(data_tuple); - auto tensor = TensorPy::MakeTensor(data); - auto ret = tensor->ToAbstract(); - ret->set_value(tensor); - MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString(); - return ret; -} - -AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tuple - // example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6 - const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 1); - AbstractTuplePtr shape_x = CheckArg(op_name, args_spec_list, 0); - - auto shpx_value = shape_x->BuildValue(); - if (shpx_value->isa()) { - 
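InferImplTupleDiv above divides two shapes element-wise (e.g. to derive a tile multiplier), rejecting zero divisors and non-exact divisions. A standalone sketch of the same checks (names are illustrative):

#include <cassert>
#include <stdexcept>
#include <vector>

std::vector<int> TupleDiv(const std::vector<int> &shape, const std::vector<int> &div) {
  if (shape.size() != div.size()) throw std::invalid_argument("shapes must have the same rank");
  std::vector<int> out;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (div[i] == 0) throw std::invalid_argument("division value should not be 0");
    if (shape[i] % div[i] != 0) throw std::invalid_argument("shape must divide exactly");
    out.push_back(shape[i] / div[i]);
  }
  return out;
}

int main() {
  assert((TupleDiv({6, 8}, {2, 4}) == std::vector<int>{3, 2}));
  return 0;
}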
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString(); - } - - auto shpx_data = shpx_value->cast()->value(); - - int result = 1; - for (size_t i = 0; i < shpx_data.size(); i++) { - int value = GetValue(shpx_data[i]); - result = IntMulWithOverflowCheck(result, value); - } - - auto result_v = MakeValue(result); - MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString(); - return std::make_shared(result_v, result_v->type()); -} - -template -AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { - // Inputs: two tuples or two lists. - CheckArgsSize(op_name, args_spec_list, 2); - auto input_x = CheckArg(op_name, args_spec_list, 0); - auto input_y = CheckArg(op_name, args_spec_list, 1); - - ValuePtr x_value = input_x->BuildValue(); - ValuePtr y_value = input_y->BuildValue(); - return std::make_shared(*x_value == *y_value); -} - -AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferImplTupleOrListEqual(primitive->name(), args_spec_list); -} - -AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - return InferImplTupleOrListEqual(primitive->name(), args_spec_list); -} - -struct SlideInfo { - int start; - int step; - int stop; -}; - -void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) { - int arg1 = 0; - int arg2 = 0; - if (!args_spec_list.empty()) { - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - auto arg_value = args_spec_list[0]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - arg1 = GetValue(arg_value); - } - - if (args_spec_list.size() >= 2) { - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - auto arg_value = args_spec_list[1]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - arg2 = GetValue(arg_value); - } - - if (args_spec_list.size() == 3) { - MS_EXCEPTION_IF_NULL(args_spec_list[2]); - auto arg_value = args_spec_list[2]->BuildValue(); - if (!arg_value->isa()) { - MS_LOG(EXCEPTION) << "Only supported input an int32 number."; - } - slide->step = GetValue(arg_value); - slide->start = arg1; - slide->stop = arg2; - } - - if (args_spec_list.size() == 2) { - slide->start = arg1; - slide->stop = arg2; - } - - if (args_spec_list.size() == 1) { - slide->stop = arg1; - } -} - -AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "Cannot make range from empty input."; - } - - if (args_spec_list.size() > 3) { - MS_LOG(EXCEPTION) << "Error args size of make range operational."; - } - - SlideInfo slide = {0, 1, 0}; - CalcSlidePara(args_spec_list, &slide); - - if (slide.step == 0) { - MS_LOG(EXCEPTION) << "Error, step value is 0."; - } - - AbstractBasePtrList args; - if (slide.start <= slide.stop) { - if (slide.step <= 0) { - MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; - } - for (int i = slide.start; i < slide.stop; i += slide.step) { - args.push_back(abstract::FromValue(i)); - } - } else { - if (slide.step >= 0) { - MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]"; - } - for (int i = slide.start; i > slide.stop; i += slide.step) { - 
args.push_back(abstract::FromValue(i)); - } - } - - return std::make_shared(args); -} - -AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor - CheckArgsSize(primitive->name(), args_spec_list, 1); - return args_spec_list[0]->Clone(); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc deleted file mode 100644 index 733cdbdb73..0000000000 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "operator/prim_to_function.h" -#include -#include -#include - -namespace mindspore { -// namespace to support prim related definition -namespace prim { - -PrimToFunction::PrimToFunction() - : prim_func_type_map_({// ONE_ARG prim - {"bool_not", kPrimTypeOneArg}, - {"scalar_cos", kPrimTypeOneArg}, - {"scalar_exp", kPrimTypeOneArg}, - {"scalar_floor", kPrimTypeOneArg}, - {"scalar_log", kPrimTypeOneArg}, - {"scalar_sin", kPrimTypeOneArg}, - {"scalar_tan", kPrimTypeOneArg}, - {"scalar_trunc", kPrimTypeOneArg}, - {"typeof", kPrimTypeOneArg}, - {"scalar_uadd", kPrimTypeOneArg}, - {"scalar_usub", kPrimTypeOneArg}, - // TWO_ARGS prim - {"scalar_add", kPrimTypeTwoArgs}, - {"bool_and", kPrimTypeTwoArgs}, - {"bool_eq", kPrimTypeTwoArgs}, - {"bool_or", kPrimTypeTwoArgs}, - {"scalar_div", kPrimTypeTwoArgs}, - {"scalar_eq", kPrimTypeTwoArgs}, - {"scalar_ge", kPrimTypeTwoArgs}, - {"scalar_gt", kPrimTypeTwoArgs}, - {"scalar_le", kPrimTypeTwoArgs}, - {"scalar_lt", kPrimTypeTwoArgs}, - {"scalar_ne", kPrimTypeTwoArgs}, - {"scalar_mod", kPrimTypeTwoArgs}, - {"scalar_mul", kPrimTypeTwoArgs}, - {"scalar_pow", kPrimTypeTwoArgs}, - {"scalar_sub", kPrimTypeTwoArgs}, - {"scalar_floordiv", kPrimTypeTwoArgs}}) {} - -bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { - bool result = false; - - if (func != nullptr) { - int args_num = GetPrimType(prim); - std::vector one_arg{std::make_shared()}; - std::vector two_args{std::make_shared(), std::make_shared()}; - TypePtr retval = std::make_shared(); - result = true; - switch (args_num) { - case kPrimTypeOneArg: - *func = Function(one_arg, retval).DeepCopy()->cast(); - break; - case kPrimTypeTwoArgs: - *func = Function(two_args, retval).DeepCopy()->cast(); - break; - default: - result = false; - break; - } - } - - return result; -} - -int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { - MS_EXCEPTION_IF_NULL(prim); - int prim_type = static_cast(kPrimTypeUnknown); - - auto value = prim_func_type_map_.find(prim->name()); - if (value != prim_func_type_map_.end()) { - prim_type = value->second; - } - return prim_type; -} -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/CMakeLists.txt b/mindspore/ccsrc/optimizer/CMakeLists.txt deleted 
file mode 100644 index 44af01735a..0000000000 --- a/mindspore/ccsrc/optimizer/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -file(GLOB_RECURSE _OPTIMIZER_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -set_property(SOURCE ${_OPTIMIZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_OPTIMIZER) -add_library(_mindspore_optimizer_obj OBJECT ${_OPTIMIZER_SRC_FILES}) diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.cc b/mindspore/ccsrc/optimizer/ad/adjoint.cc deleted file mode 100644 index ed89aba20e..0000000000 --- a/mindspore/ccsrc/optimizer/ad/adjoint.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/adjoint.h" - -#include -#include - -#include "ir/anf.h" -#include "optimizer/ad/dfunctor.h" - -namespace mindspore { -namespace ad { -Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) - : primal_(primal), caller_(caller), dout_(nullptr) { - if (k != nullptr) { - k_ = k; - MS_LOG(DEBUG) << "Add adjoint for " << primal->ToString() << " " << k_->ToString(); - } else { - // Init k hole in a recursive case. - auto k_hole = std::make_shared("k_hole"); - (void)k_hole->AddAttr("info", MakeValue(primal->ToString())); - k_ = NewValueNode(k_hole); - MS_LOG(DEBUG) << "Add hole for " << primal->ToString() << " " << k_->ToString(); - } - - dout_hole_ = caller_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); - RegisterKUser(dout_hole_->cast(), 1); -} - -AnfNodePtr Adjoint::k() { return k_; } - -void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } - -void Adjoint::UpdateK(const AnfNodePtr &new_k) { - MS_EXCEPTION_IF_NULL(new_k); - MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); - // In recursive case, it needs update. 
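// Editorial sketch of the rewiring performed by the loop below (not part of the
// original file): every cnode that consumed the placeholder k_ as its
// `index`-th input is patched in place once the real K node becomes available,
//   user.first->set_input(user.second, new_k);
// which is why RegisterKUser() must record exact (user, index) pairs.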
- for (auto &user : k_user_) { - MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" - << new_k->ToString(); - if (user.first->input(user.second) != k_) { - MS_LOG(EXCEPTION) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k " - << new_k->ToString() << ", user relation is set wrongly"; - } - user.first->set_input(user.second, new_k); - } - k_ = new_k; -} - -AnfNodePtr Adjoint::primal() { return primal_; } - -AnfNodePtr Adjoint::dout() { return dout_hole_; } - -void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { - dout_user_.emplace_back(std::make_pair(user, index)); -} - -void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { - if (dout_ != nullptr) { - MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); - auto add = prim::GetPythonOps("hyper_add"); - dout_ = caller_->NewCNode({NewValueNode(add), dout_, dout_factor}); - return; - } - dout_ = dout_factor; -} - -void Adjoint::CallDoutHole() { - if (dout_ != nullptr) { - for (auto &user : dout_user_) { - MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " - << dout_->ToString(); - if (user.first->input(user.second) != dout_hole_) { - MS_LOG(EXCEPTION) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " - << dout_->ToString() << ", user relation is set wrongly"; - } - user.first->set_input(user.second, dout_); - } - } -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.h b/mindspore/ccsrc/optimizer/ad/adjoint.h deleted file mode 100644 index b2dae8e66f..0000000000 --- a/mindspore/ccsrc/optimizer/ad/adjoint.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ - -#include -#include -#include - -#include "ir/anf.h" -#include "optimizer/opt.h" - -namespace mindspore { -namespace ad { -class Adjoint { - public: - Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); - ~Adjoint() = default; - AnfNodePtr primal(); - AnfNodePtr k(); - void UpdateK(const AnfNodePtr &k); - void RegisterKUser(const CNodePtr &user, size_t index); - AnfNodePtr dout(); - void AccumulateDout(const AnfNodePtr &dout_factor); - void RegisterDoutUser(const CNodePtr &user, size_t index); - void CallDoutHole(); - - private: - AnfNodePtr primal_; - FuncGraphPtr caller_; - // For ```def f(x): return expr```, The representation graph k is ```def kf(kx): return expr, bprop{expr}```. 
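// A concrete illustration (editorial, mirroring the comment above): for
//   def f(x): return x * x
// the K representation behaves like
//   def kf(kx):
//       out = kx * kx
//       def bprop(dout): return (newenv, 2 * kx * dout)
//       return out, bprop
// and k_ points at the node computing that (out, bprop) pair.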
- AnfNodePtr k_; - std::vector> k_user_; - AnfNodePtr dout_; - AnfNodePtr dout_hole_; - std::vector> dout_user_; -}; - -using AdjointPtr = std::shared_ptr; -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_ADJOINT_H_ diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc deleted file mode 100644 index 308f1dd352..0000000000 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc +++ /dev/null @@ -1,617 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/dfunctor.h" - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "debug/info.h" -#include "ir/func_graph_cloner.h" -#include "ir/manager.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" -#include "optimizer/ad/adjoint.h" -#include "optimizer/opt.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "utils/symbolic.h" -#include "utils/context/ms_context.h" -#include "./common.h" - -namespace mindspore { -namespace ad { -std::unordered_map DFunctor::func_graph_to_functor_; -std::unordered_map DFunctor::anfnode_to_adjoin_definition_; -FuncGraphSet DFunctor::scope_; - -DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) - : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { - TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); - k_graph_ = std::make_shared(); - if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); - } - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(primal_graph->debug_info())); - tape_ = std::make_shared(); - // Add "_Grad" postfix - if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - std::string grad_op_name = GetValue(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad"; - tape_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name)); - } - TraceManager::EndTrace(); - - dout_ = tape_->add_parameter(); -} - -void DFunctor::Init(bool is_top) { - func_graph_to_functor_[primal_graph_] = shared_from_this(); - is_top_ = is_top; - if (is_top) { - scope_ = primal_graph_->scope(); - } -} - -void DFunctor::Clear() { - func_graph_to_functor_.clear(); - anfnode_to_adjoin_definition_.clear(); - scope_.clear(); -} - -void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { - auto fv_adjoint = anfnode_to_adjoin_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() - << " " << fv->ToString() << "."; - fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) { - 
MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - auto parent_adjoint = FindAdjoint(fv); - AdjointPtr adjoint = nullptr; - if (parent_adjoint != nullptr) { - adjoint = std::make_shared(fv, parent_adjoint->k(), tape_); - } else { - MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - adjoint = std::make_shared(fv, nullptr, tape_); - } - anfnode_to_adjoin_indirect_fv_[fv] = adjoint; - fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv); - } - } - auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(node, 1); - auto default_val = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(default_val, 1); - auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, node, default_val}); - MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv " - << fv->func_graph()->ToString() << " " << fv->ToString() << "."; - MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << node->ToString() << "."; - fv_adjoint->second->AccumulateDout(dfv); -} - -void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) { - // Take switch_layer as a set of candidate functions. - auto input = cnode_morph->input(2); - if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { - MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << "."; - } - auto tuple_graphs = input->cast(); - for (size_t i = 1; i < tuple_graphs->size(); ++i) { - auto graph = tuple_graphs->input(i); - if (!IsValueNode(graph)) { - MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString() - << " as the " << i << "th element."; - } - auto func_graph = GetValueNode(graph); - auto functor = func_graph_to_functor_.find(func_graph); - if (functor == func_graph_to_functor_.end()) { - MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] " - << func_graph->ToString() << "."; - } - // Consider direct and indirect fvs. - for (auto fv : func_graph->free_variables_nodes()) { - BackPropagateFv(fv, env); - } - for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " " - << indirect_fv.first->ToString() << "."; - BackPropagateFv(indirect_fv.first, env); - } - } -} - -void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) { - auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)}); - // Call with delimited continuation dout. 
- auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()}); - node_adjoint->RegisterDoutUser(bprop_app, 1); - // Special case for switch_layer - if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) { - auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)}); - BackPropagateSwitchLayer(cnode_morph, din); - return; - } - for (size_t i = 0; i < cnode_morph->size(); i++) { - auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))}); - auto input = cnode_morph->input(i); - // Backprop sens wrt fvs. - if (IsValueNode(input)) { - auto func_graph = GetValueNode(input); - auto functor = func_graph_to_functor_.find(func_graph); - if (functor == func_graph_to_functor_.end()) { - MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] " - << func_graph->ToString() << "."; - } - // Consider direct and indirect fvs. - for (auto fv : func_graph->free_variables_nodes()) { - BackPropagateFv(fv, din); - } - for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " " - << indirect_fv.first->ToString() << "."; - BackPropagateFv(indirect_fv.first, din); - } - continue; - } - // Backprop sens wrt inputs. - auto input_adjoint = anfnode_to_adjoin_.find(input); - if (input_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << "."; - } - input_adjoint->second->AccumulateDout(din); - } -} - -// Map a morphism. -AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { - // MapMorphism All type except CNode should already be mapped by MapObject. - if (!morph->isa()) { - return nullptr; - } - ScopeGuard scope_guard(morph->scope()); - auto cnode_morph = morph->cast(); - - std::vector inputs; - std::vector param_adjoints; - for (size_t i = 0; i < cnode_morph->size(); i++) { - auto node = cnode_morph->input(i); - auto node_adjoint_iter = anfnode_to_adjoin_.find(node); - AdjointPtr node_adjoint = nullptr; - AnfNodePtr k = nullptr; - if (node_adjoint_iter != anfnode_to_adjoin_.end()) { - node_adjoint = node_adjoint_iter->second; - } else { - // Input might be a CNode that needs to be handled before hand. 
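// Editorial note: the recursive call below maps operands depth-first, so for
// c = mul(a, b) the adjoints of a and b are guaranteed to exist before the
// k_app for c is constructed.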
- node_adjoint = MapMorphism(node); - } - MS_EXCEPTION_IF_NULL(node_adjoint); - k = node_adjoint->k(); - if (k == nullptr) { - MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << "."; - } - inputs.push_back(k); - param_adjoints.push_back(node_adjoint); - } - TraceManager::DebugTrace(std::make_shared(cnode_morph->debug_info())); - auto k_app = k_graph_->NewCNode(inputs); - TraceManager::EndTrace(); - for (size_t i = 0; i < param_adjoints.size(); ++i) { - param_adjoints[i]->RegisterKUser(k_app, i); - } - - // Do forward computation - auto foward_app = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(0)}); - // K:: cnode -> forward_app - auto node_adjoint = std::make_shared(morph, foward_app, tape_); - UpdateAdjoint(node_adjoint); - anfnode_to_adjoin_[morph] = node_adjoint; - if (cnode_morph->stop_gradient()) { - MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped."; - return node_adjoint; - } - - // Do sens backpropagation - BackPropagate(cnode_morph, k_app, node_adjoint); - MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << "."; - return node_adjoint; -} - -bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { - // Do not care about non-CNode - if (!node->isa()) { - return false; - } - // Do not care about kPrimReturn - if (IsPrimitiveCNode(node, prim::kPrimReturn)) { - return false; - } - auto &users = primal_graph_->manager()->node_users()[node]; - // Do not care about isolated morphisms - if (users.empty()) { - return false; - } - // Not free if it's used by some node in primal_graph - bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) { - auto &user = kv.first; - return user->func_graph() == primal_graph_; - }); - return !nonfree; -} - -void DFunctor::MapFreeMorphism() { - // Handle cnode not attached to output, that might be refered in other functions. - for (auto &node : primal_graph_->nodes()) { - if (!IsFreeMorphism(node)) { - continue; - } - MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << "."; - (void)MapMorphism(node); - } -} - -AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { - AnfNodePtr new_grad_fv = grad_fv; - // Add grads wrt fv. - const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); - for (auto &fv : free_variables_nodes) { - auto fv_adjoint = anfnode_to_adjoin_.find(fv); - if (fv_adjoint == anfnode_to_adjoin_.end()) { - MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << "."; - } - auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()}); - fv_adjoint->second->RegisterKUser(node, 1); - auto sens = fv_adjoint->second->dout(); - new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), - new_grad_fv, - node, - sens, - }); - fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast(), 3); - MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " " - << fv->ToString() << " " << primal_graph_->ToString() << "."; - } - return new_grad_fv; -} - -AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { - AnfNodePtr new_grad_fv = grad_fv; - // Add indirect fv bprop. 
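// Editorial sketch: like AttachFvDoutToTape above, each sensitivity is threaded
// through the gradient environment, e.g.
//   env1 = env_setitem(env0, embed(k(fv1)), dout(fv1))
//   env2 = env_setitem(env1, embed(k(fv2)), dout(fv2))
// so an enclosing graph can later env_getitem the sens of any fv it owns.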
- for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) { - MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " " - << primal_graph_->ToString() << "."; - auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()}); - fv_adjoint.second->RegisterKUser(node, 1); - auto sens = fv_adjoint.second->dout(); - new_grad_fv = tape_->NewCNode({ - NewValueNode(prim::kPrimEnvSetItem), - new_grad_fv, - node, - sens, - }); - fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast(), 3); - MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to " - << new_grad_fv->ToString() << "."; - } - return new_grad_fv; -} - -void DFunctor::MapMorphism() { - // Set stop_gradient before MapMorphism. - BroadCastStopFlag(); - - // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent - MapFreeMorphism(); - // Handle morphism from output. - (void)MapMorphism(primal_graph_->output()); - - // Construct K for primal_graph_ - auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output()); - // Attach dout_ parameter to output_adjoint. - output_adjoint->second->AccumulateDout(dout_); - - // Set output for tape closure. - auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); - - std::vector inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv}; - // Add grads wrt inputs. - std::vector param_adjoints; - for (auto ¶m : primal_graph_->parameters()) { - auto param_adjoint = anfnode_to_adjoin_.find(param); - inputs.push_back(param_adjoint->second->dout()); - param_adjoints.push_back(param_adjoint->second); - } - auto tape_output = tape_->NewCNode(inputs); - for (size_t i = 0; i < param_adjoints.size(); ++i) { - param_adjoints[i]->RegisterDoutUser(tape_output, i + 2); - } - tape_->set_output(tape_output); - // Set output for k_graph_, K:: cnode->forward_app. - auto forward_app = output_adjoint->second->k(); - auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)}); - output_adjoint->second->RegisterKUser(output, 1); - k_graph_->set_output(output); - (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_))); - (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_))); -} - -FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { - // K user defined cell bprop. 
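// Editorial illustration of the Python side this handles (assuming the usual
// Cell API; the class and body are illustrative, not from this patch):
//   class Square(nn.Cell):
//       def construct(self, x): return x * x
//       def bprop(self, x, out, dout): return (2 * x * dout,)
// The parsed graph of that bprop method is what the "bprop" transform
// looked up below resolves to.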
- auto bprop = primal->transforms().find("bprop"); - if (bprop != primal->transforms().end()) { - FuncGraphPtr bprop_graph = bprop->second.func_graph(); - resources_->manager()->AddFuncGraph(bprop_graph); - - if (bprop_graph->free_variables_nodes().size() != 0 || primal->free_variables_nodes().size() != 0) { - MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope " - << primal->output()->scope()->name() << " does not support Parameter data type."; - } - auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph); - if (fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope " - << primal->output()->scope()->name() << "."; - } - - // Cache the grad func - (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg))); - (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal))); - // Reset defer_inline to enable successive inlining - primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); - - auto functor = std::make_shared(primal, resources_); - functor->Init(); - functor->k_graph_ = fg; - - return fg; - } - return nullptr; -} - -// MapToK(func) -AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { - auto f = func_graph_to_functor_.find(primal); - if (f != func_graph_to_functor_.end()) { - MS_LOG(DEBUG) << "K graph functor already exist " << primal->ToString() << "."; - return NewValueNode(f->second->k_graph_); - } - - auto k_user_defined = KUserDefined(primal); - if (k_user_defined != nullptr) { - MS_LOG(DEBUG) << "K graph functor user defined bprop " << primal->ToString() << "."; - return NewValueNode(k_user_defined); - } - - auto functor = std::make_shared(primal, resources_); - functor->Init(); - functor->MapObject(); - functor->MapMorphism(); - - MS_LOG(DEBUG) << "K graph K function graph " << primal->ToString() << " " << functor->k_graph_->ToString() << "."; - return NewValueNode(functor->k_graph_); -} - -// Construct representation graph for given node. -AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { - ScopeGuard scope_guard(primal->scope()); - // MapToK(prim) - if (IsValueNode(primal)) { - auto value_node = primal->cast(); - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { - MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; - need_cut_ = true; - } - auto k_prim = g_k_prims.KPrimitive(value_node, resources_); - if (k_prim != nullptr) { - return NewValueNode(k_prim); - } - // When failed to find k_prim, try k_meta. - auto k_meta = g_k_prims.KMetaFuncGraph(prim); - if (k_meta != nullptr) { - return NewValueNode(k_meta); - } - } - - // MapToK(func) - if (IsValueNode(primal)) { - auto func_graph = GetValueNode(primal); - auto k_func = MapToK(func_graph); - return k_func; - } - - if (primal->isa()) { - TraceManager::DebugTrace(std::make_shared(primal->debug_info())); - auto ret = k_graph_->add_parameter(); - TraceManager::EndTrace(); - return ret; - } - - if (!primal->isa()) { - MS_LOG(EXCEPTION) << "K node keeped node from primal_graph_ " << primal->ToString() << " that is not a ValueNode."; - } - return primal; -} - -bool DFunctor::IsInScope(const AnfNodePtr &node) { - return std::any_of(scope_.begin(), scope_.end(), - [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; }); -} - -void DFunctor::MapFvObject() { - // Map free variable. 
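// Editorial summary of the three cases handled below:
//   1. the fv's adjoint is already defined elsewhere -> reuse its parent k;
//   2. top graph / Parameter / outside the ad scope  -> identity adjoint (k = fv);
//   3. a non-top fv whose definition is unseen so far -> leave a k hole to patch.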
- const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); - for (auto &node : free_variables_nodes) { - ScopeGuard scope_guard(node->scope()); - MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << "."; - // Find fv's K from parent. - AdjointPtr adjoint = nullptr; - auto parent_adjoint = FindAdjoint(node); - if (parent_adjoint != nullptr) { - adjoint = std::make_shared(node, parent_adjoint->k(), tape_); - } else { - if (is_top_ || node->isa() || !IsInScope(node)) { - // Out of ad scope, add adjoint for free variables. - adjoint = std::make_shared(node, node, tape_); - UpdateAdjoint(adjoint); - } else { - MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << "."; - adjoint = std::make_shared(node, nullptr, tape_); - } - } - if (adjoint == nullptr) { - MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << "."; - } - anfnode_to_adjoin_[node] = adjoint; - } -} - -void DFunctor::MapParamObject() { - // Map parameter. - for (auto &p : primal_graph_->parameters()) { - ScopeGuard scope_guard(p->scope()); - MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << "."; - auto adjoint = std::make_shared(p, MapToK(p), tape_); - UpdateAdjoint(adjoint); - anfnode_to_adjoin_[p] = adjoint; - } -} - -void DFunctor::MapValueObject() { - // Map ValueNode. - auto manager = resources_->manager(); - auto &value_nodes = primal_graph_->value_nodes(); - for (const auto &value_pair : value_nodes) { - auto node = value_pair.first; - auto parent_adjoint = FindAdjoint(node); - if (parent_adjoint != nullptr) { - auto adjoint = std::make_shared(node, parent_adjoint->k(), tape_); - anfnode_to_adjoin_[node] = adjoint; - continue; - } - // Skip Return. - if (IsValueNode(node) && GetValueNode(node) == prim::kPrimReturn) { - continue; - } - MS_LOG(DEBUG) << "MapValueObject node " << node->ToString() << "."; - auto adjoint = std::make_shared(node, MapToK(node), tape_); - UpdateAdjoint(adjoint); - anfnode_to_adjoin_[node] = adjoint; - } -} - -// Skip morphism. -void DFunctor::MapObject() { - // The order does not matter - MapFvObject(); - MapParamObject(); - MapValueObject(); -} - -void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) { - auto primal = adjoint_definition->primal(); - if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) { - MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " " - << primal->ToString() << "."; - } - anfnode_to_adjoin_definition_[primal] = adjoint_definition; - // Update k hole for primal. 
- for (auto &f : func_graph_to_functor_) { - auto adjoint = f.second->anfnode_to_adjoin_.find(primal); - if (adjoint != f.second->anfnode_to_adjoin_.end()) { - adjoint->second->UpdateK(adjoint_definition->k()); - } - adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal); - if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) { - adjoint->second->UpdateK(adjoint_definition->k()); - } - } -} - -AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) { - auto adjoint = anfnode_to_adjoin_definition_.find(primal); - if (adjoint != anfnode_to_adjoin_definition_.end()) { - MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << "."; - return adjoint->second; - } - MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << "."; - return nullptr; -} - -void DFunctor::CallDoutHoleOnTape() { - if (!is_top_) { - return; - } - - // Call dout hole of all adjoint. - for (auto &f : func_graph_to_functor_) { - for (auto &adjoint : f.second->anfnode_to_adjoin_) { - adjoint.second->CallDoutHole(); - } - for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) { - adjoint.second->CallDoutHole(); - } - } -} -FuncGraphPtr DFunctor::k_graph() { - CallDoutHoleOnTape(); - return k_graph_; -} - -void DFunctor::BroadCastStopFlag() { - // As stop set expanding, all directly or indirectly stopped CNode will be cut off - while (need_cut_) { - need_cut_ = false; - for (auto &node : primal_graph_->nodes()) { - if (node->isa()) { - auto cnode = node->cast(); - if (!cnode->stop_gradient()) { - // Cut off the cnode only when it's not referred any more - if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) { - MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; - cnode->set_stop_gradient(true); - // The stop set changed, more cut required - need_cut_ = true; - } - } - } - } - } -} - -bool DFunctor::AllReferencesStopped(const CNodePtr &node) { - auto &users = primal_graph_->manager()->node_users()[node]; - // Only care about stop_gradient caused cutting - if (users.empty()) { - return false; - } - for (auto &kv : users) { - auto &user = kv.first; - if (!user->isa() || !user->cast()->stop_gradient()) { - return false; - } - } - return true; -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h deleted file mode 100644 index 09c0f54fc8..0000000000 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "ir/func_graph_cloner.h" -#include "pipeline/resource.h" -#include "optimizer/ad/adjoint.h" -#include "operator/ops.h" -#include "debug/trace.h" - -namespace mindspore { -namespace ad { -struct PrimitiveTotalEqual { - bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { - MS_EXCEPTION_IF_NULL(t1); - MS_EXCEPTION_IF_NULL(t2); - return *t1 == *t2; - } -}; - -using Registry = std::unordered_map; -class KPrim; -extern KPrim g_k_prims; -class DFunctor; -using DFunctorPtr = std::shared_ptr; - -// D Functor's rules to map closure object and morphisms. -class DFunctor : public std::enable_shared_from_this { - public: - DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); - ~DFunctor() = default; - // Map object in D category to K category. - void MapObject(); - // Map morphism in D category to K category. - void MapMorphism(); - FuncGraphPtr k_graph(); - // Construct user defined k object. - FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); - // Register functor objects to form a global view. - void Init(bool is_top = false); - bool IsInScope(const AnfNodePtr &node); - - // Clear resources. - static void Clear(); - - private: - // Map one morphism. - AdjointPtr MapMorphism(const AnfNodePtr &morph); - bool IsFreeMorphism(const AnfNodePtr &node); - // Map morphism that's not attached to output. - void MapFreeMorphism(); - void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din); - void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env); - void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); - AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); - AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); - // Map Anfnode object from D category to K category. - AnfNodePtr MapToK(const AnfNodePtr &primal); - // Map FuncGraph object from D category to K category. - AnfNodePtr MapToK(const FuncGraphPtr &primal); - // MapObject impls. - void MapFvObject(); - void MapValueObject(); - void MapParamObject(); - // Find adjoint with its primary k. - AdjointPtr FindAdjoint(const AnfNodePtr &primal); - // Broadcast stop flags. - void BroadCastStopFlag(); - bool AllReferencesStopped(const CNodePtr &node); - // Update k hole with adjoint_definition, only applied in recursive case. - void UpdateAdjoint(const AdjointPtr &adjoint_definition); - void CallDoutHoleOnTape(); - - std::unordered_map anfnode_to_adjoin_; - // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer. - std::unordered_map anfnode_to_adjoin_indirect_fv_; - FuncGraphPtr primal_graph_; - // K object for primal_graph_; - FuncGraphPtr k_graph_; - // The Backprop part of k_graph_. - FuncGraphPtr tape_; - // Dout parameter for primal_graph_. - AnfNodePtr dout_; - pipeline::ResourceBasePtr resources_; - // Cut off stopped objects in category D. - bool need_cut_; - bool is_top_; - static std::unordered_map> func_graph_to_functor_; - static std::unordered_map anfnode_to_adjoin_definition_; - static FuncGraphSet scope_; -}; - -// D Functor's rules to map primitive object. 
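// Editorial sketch (not from the original file) of what BpropToK below builds
// for a primitive P with bprop rule bprop_p(x, y, out, dout):
//   outer(x, y):
//       out = P(x, y)
//       bprop(dout) := bprop_p(x, y, out, dout)   // x, y, out become free vars
//       return (out, bprop)
// which is exactly the K form DFunctor::MapToK expects for a primitive.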
-class KPrim { - public: - KPrim() = default; - ~KPrim() = default; - - FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); - FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); - - void clear() { - bprop_registry_meta_.clear(); - bprop_registry_.clear(); - } - - private: - FuncGraphPtr GetBprop(const PrimitivePtr &prim); - FuncGraphPtr GetFprop(const PrimitivePtr &prim); - FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); - // Given a bprop rule, do the K mapping. - template - FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); - AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); - void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args); - void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); - - Registry bprop_registry_; - std::unordered_map bprop_registry_meta_; -}; - -template -FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { - MS_EXCEPTION_IF_NULL(primal); - MS_EXCEPTION_IF_NULL(bprop_fg); - CheckBprop(bprop_fg, primal->ToString()); - - auto debug_info = std::make_shared(); - debug_info->set_name(primal->ToString()); - - auto cloned_bprop_fg = BasicClone(bprop_fg); - MS_EXCEPTION_IF_NULL(cloned_bprop_fg); - - cloned_bprop_fg->debug_info()->set_name(""); - cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); - - AnfNodePtr bout = BuildOutput(cloned_bprop_fg); - cloned_bprop_fg->set_output(bout); - - TraceManager::DebugTrace(std::make_shared(debug_info)); - auto outer = std::make_shared(); - (void)outer->transforms().emplace("primal", FuncGraphTransform(primal)); - outer->set_output(NewValueNode(kNone)); - TraceManager::EndTrace(); - - auto mng = Manage({cloned_bprop_fg, outer}, false); - - // Make sure (out, dout) provided. - if (cloned_bprop_fg->parameters().size() < 2) { - MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() - << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() - << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); - } - - // In a bprop definition, the last two param should be out and dout. - auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1]; - auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2]; - std::vector transf_args; - TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); - - TraceManager::DebugTrace(std::make_shared(dout->debug_info())); - (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); - auto out_value = outer->NewCNode(transf_args); - TraceManager::EndTrace(); - - (void)mng->Replace(out_param, out_value); - - TraceManager::DebugTrace(std::make_shared(out_param->debug_info())); - auto new_dout = cloned_bprop_fg->add_parameter(); - (void)mng->Replace(dout, new_dout); - // We remove all parameters except new_dout. 
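// Editorial note: removing the parameters is safe because the original
// (x, ..., out) parameters were replaced above by nodes of the outer graph,
// so the cloned bprop now captures them as free variables and new_dout is
// its only remaining runtime input.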
- std::vector newBpropParams = {new_dout}; - cloned_bprop_fg->set_parameters(newBpropParams); - TraceManager::EndTrace(); - - outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); - return BasicClone(outer); -} -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_D_FUNCTOR_H_ diff --git a/mindspore/ccsrc/optimizer/ad/grad.cc b/mindspore/ccsrc/optimizer/ad/grad.cc deleted file mode 100644 index d141dc6eea..0000000000 --- a/mindspore/ccsrc/optimizer/ad/grad.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/ad/grad.h" -#include "optimizer/ad/dfunctor.h" -#include "ir/func_graph_cloner.h" -#include "utils/context/ms_context.h" -#include "utils/symbolic.h" -#include "utils/graph_utils.h" - -namespace mindspore { -namespace ad { -FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { - MS_EXCEPTION_IF_NULL(func_graph); - auto gradkv = func_graph->transforms().find("grad"); - if (gradkv != func_graph->transforms().end()) { - return gradkv->second.func_graph(); - } - - auto manager_ptr = resources->manager(); - MS_EXCEPTION_IF_NULL(manager_ptr); - manager_ptr->AddFuncGraph(func_graph); - - auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { - if (MsContext::GetInstance()->is_multi_graph_sink()) { - if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - f->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - } - } - }; - - auto f = std::make_shared(func_graph, resources); - auto user_defined = f->KUserDefined(func_graph); - if (user_defined != nullptr) { - multi_graph_sink(user_defined); - if (is_top) { - DFunctor::Clear(); - } - return user_defined; - } - f->Init(is_top); - f->MapObject(); - f->MapMorphism(); - auto ret = f->k_graph(); - if (is_top) { - DFunctor::Clear(); - } - - multi_graph_sink(ret); - return ret; -} - -FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto fg = g_k_prims.KPrimitive(value_node, resources); - if (fg == nullptr) { - return nullptr; - } - return BasicClone(fg); -} - -MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &) { - MetaFuncGraphPtr fg = g_k_prims.KMetaFuncGraph(prim); - return fg; -} - -void CleanRes() { DFunctor::Clear(); } -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/ad/grad.h b/mindspore/ccsrc/optimizer/ad/grad.h deleted file mode 100644 index a878aa9df7..0000000000 --- a/mindspore/ccsrc/optimizer/ad/grad.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ - -#include -#include - -#include "ir/anf.h" -#include "ir/meta_func_graph.h" -#include "pipeline/resource.h" - -namespace mindspore { -namespace ad { -using ResourcePtr = std::shared_ptr; - -FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top = true); -FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); -MetaFuncGraphPtr Kmeta(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &); -void CleanRes(); -} // namespace ad -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_AD_GRAD_H_ diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc deleted file mode 100644 index bdec1dc93c..0000000000 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ /dev/null @@ -1,291 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/primitive_py.h" -#include "ir/meta_func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/manager.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" -#include "optimizer/ad/dfunctor.h" -#include "optimizer/opt.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "utils/symbolic.h" -#include "utils/primitive_utils.h" -#include "utils/context/ms_context.h" -#include "debug/info.h" -#include "debug/trace.h" - -#include "./common.h" - -namespace mindspore { -namespace ad { -using PatternListType = std::initializer_list; -KPrim g_k_prims; - -FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { - // Set a child scope named "grad'PrimitiveName'" for the bprop function, - // and add "Gradients" to the front. - static const std::string gradients_scope = "Gradients/"; - static const std::string grad_op_child_scope_prefix = "/grad"; - MS_EXCEPTION_IF_NULL(prim); - auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + - grad_op_child_scope_prefix + prim->name()); - ScopeGuard scope_guard(scope); - py::function fn = prim->is_base() ? 
GetBpropFunction(prim->name()) : prim->cast()->GetBpropFunction(); - if (fn == nullptr || py::isinstance(fn)) { - MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << "."; - return nullptr; - } - FuncGraphPtr func_graph = parse::ParsePythonCode(fn); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; - return nullptr; - } - return func_graph; -} - -FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) { - static const std::string ad_module = "mindspore.ops._grad.grad_implementations"; - std::string func_name = "_fprop_" + prim->name(); - py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name); - auto func_graph = parse::ParsePythonCode(fn); - MS_EXCEPTION_IF_NULL(func_graph); - return BasicClone(func_graph); -} - -MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(prim); - - auto iter = bprop_registry_meta_.find(prim); - if (iter != bprop_registry_meta_.end()) { - return iter->second; - } - - if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { - MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); - bprop_registry_meta_[prim::kPrimMakeTuple] = meta; - return meta; - } - - MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; -} - -FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - if (!IsValueNode(value_node)) { - MS_LOG(EXCEPTION) << "Primitive node is not valid."; - } - - auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) { - auto fprop = GetFprop(prim); - fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer)); - return fprop; - } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { - return nullptr; - } - - FuncGraphPtr bprop_fg = nullptr; - if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) { - bprop_fg = BpropCut(value_node, resources); - } else { - auto iter = bprop_registry_.find(prim); - if (iter != bprop_registry_.end()) { - bprop_fg = iter->second; - } - - if (bprop_fg == nullptr) { - bprop_fg = GetBprop(prim); - if (bprop_fg != nullptr) { - // Set bprop_g graph cache - bprop_registry_[prim] = bprop_fg; - } else { - bprop_fg = FakeBprop(value_node, resources); - } - } - } - - auto expanded_fg = BpropToK(prim, bprop_fg); - if (expanded_fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed convert " << prim->name() - << " prim bprop function to J expanded func graph. NodeInfo: " - << trace::GetDebugInfo(bprop_fg->debug_info()); - } - - return expanded_fg; -} - -AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) { - // bprop_fg has been checked in caller - if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) { - // Set bprop output as (env, dx, dy, dz, ...) 
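// Editorial illustration: a bprop rule returning make_tuple(dx, dy) is rebuilt
// below as make_tuple(newenv, dx, dy), matching the (env, *dins) layout that
// DFunctor::BackPropagate indexes with tuple_getitem.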
- auto cbprop = bprop_fg->output()->cast(); - auto &inputs = cbprop->inputs(); - - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - args.push_back(NewValueNode(newenv)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - return NewCNode(args, bprop_fg); - } - - // Set bprop output as (env, dx) - std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); - std::string python_ops("_tuple_add"); - auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); - return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg); -} - -void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args) { - MS_EXCEPTION_IF_NULL(mng); - // bprop_fg has been checked in caller - // transform except the last 2 parameters: out, dout. - for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) { - auto p = bprop_fg->parameters()[i]; - MS_EXCEPTION_IF_NULL(p); - - TraceManager::DebugTrace(std::make_shared(p->debug_info())); - auto transf_p = outer->add_parameter(); - TraceManager::EndTrace(); - - (void)mng->Replace(p, transf_p); - transf_args->push_back(transf_p); - } -} - -void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool check_bprop_flag = context->check_bprop_flag(); - // Skip checking if check_bprop not set - if (!check_bprop_flag) { - return; - } - - // bprop_fg has been checked in caller - auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops"); - MS_EXCEPTION_IF_NULL(check_bprop_class); - auto check_bprop = - bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared(prim_to_check))}); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - inputs.insert(inputs.begin() + 1, bprop_fg->parameters().begin(), bprop_fg->parameters().end() - 2); - AnfNodePtr params = bprop_fg->NewCNode(inputs); - - inputs.clear(); - inputs.push_back(check_bprop); - inputs.push_back(bprop_fg->output()); - inputs.push_back(params); - AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs); - bprop_fg->set_output(bprop_out); -} - -FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { - MS_EXCEPTION_IF_NULL(bprop_fg); - auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); - auto expanded_fg = BpropToK(fprop_fg, bprop_fg); - if (expanded_fg == nullptr) { - MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() - << " Cell bprop function to K expanded func graph. 
NodeInfo: " - << trace::GetDebugInfo(fprop_fg->debug_info()); - } - return expanded_fg; -} - -FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto prim = GetValueNode(value_node); - MS_EXCEPTION_IF_NULL(prim); - auto &node_users = resources->manager()->node_users(); - - auto &users = node_users[value_node]; - auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { - return IsPrimitiveCNode(user.first, prim); - }); - if (cnode == users.end()) { - MS_LOG(EXCEPTION) << "Fail to find cnode."; - } - auto inputs_num = cnode->first->cast()->size() - 1; - - auto func_graph = std::make_shared(); - std::vector outputs; - - auto bprop_cut = std::make_shared("bprop_cut", py::object()); - bprop_cut->CopyHookFunction(prim); - - auto cell_id = GetValue(prim->GetAttr("cell_id")); - if (cell_id != "") { - (void)bprop_cut->AddAttr("cell_hook", MakeValue(true)); - (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id)); - } - - outputs.push_back(NewValueNode(bprop_cut)); - for (size_t i = 0; i < inputs_num; ++i) { - auto param = func_graph->add_parameter(); - outputs.push_back(param); - } - auto p1 = func_graph->add_parameter(); - auto p2 = func_graph->add_parameter(); - outputs.push_back(p1); - outputs.push_back(p2); - - func_graph->set_output(func_graph->NewCNode(outputs)); - return func_graph; -} - -FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto prim = value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto &node_users = resources->manager()->node_users(); - - auto &users = node_users[value_node]; - auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair &user) -> bool { - return IsPrimitiveCNode(user.first, prim); - }); - if (cnode == users.end()) { - MS_LOG(EXCEPTION) << "Fail to find cnode."; - } - auto inputs_num = cnode->first->cast()->inputs().size() - 1; - - auto func_graph = std::make_shared(); - std::vector outputs; - outputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto fake_bprop = std::make_shared("fake_bprop"); - (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined.")); - - for (size_t i = 0; i < inputs_num; ++i) { - // Mock params for inputs - auto param = func_graph->add_parameter(); - // Mock derivatives for each inputs - outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param})); - } - // mock params for out and dout - (void)func_graph->add_parameter(); - (void)func_graph->add_parameter(); - func_graph->set_output(func_graph->NewCNode(outputs)); - return func_graph; -} -} // namespace ad -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc deleted file mode 100644 index bb52273568..0000000000 --- a/mindspore/ccsrc/optimizer/clean.cc +++ /dev/null @@ -1,531 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/clean.h" -#include -#include -#include -#include -#include -#include "./common.h" -#include "debug/trace.h" -#include "operator/composite/composite.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { -using mindspore::abstract::AbstractAttribute; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractDictionary; -using mindspore::abstract::AbstractJTagged; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractUndetermined; - -static AbstractBasePtr Reabs(const AbstractBasePtr &t) { - if (t == nullptr) { - return nullptr; - } - - AbstractBasePtr res = t; - if (t->isa()) { - auto abs_class = dyn_cast(t); - AbstractBasePtrList baselist; - auto attributes = abs_class->attributes(); - (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), - [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { - auto abs_dict = dyn_cast(t); - AbstractBasePtrList baselist; - auto elements = abs_dict->elements(); - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), - [](const AbstractAttribute &item) { return item.second; }); - res = std::make_shared(baselist); - } else if (t->isa()) { - auto abs_dict = dyn_cast(t); - res = std::make_shared(abs_dict->elements()); - } - return res; -} - -AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [getattr, data, attribute] - MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) { - return nullptr; - } - - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "First parameter of getattr is not AbstractClass, but " << dt->type_name() << "."; - } - - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->attributes(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); -} - -AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [dict_getitem, dict, item] - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - MS_EXCEPTION_IF_NULL(dt); - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "first parameter of dict_getitem is not AbstractDictionary, but " << dt->type_name(); - } - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? 
GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->elements(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); -} - -AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [dict_setitem, dict, item, value] - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() == 4 && "DictSetItem should have three inputs."); - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - AnfNodePtr item_value = inputs[3]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto dt = data->abstract(); - MS_EXCEPTION_IF_NULL(dt); - if (!dt->isa()) { - MS_LOG(EXCEPTION) << "first parameter of dict_setitem is not AbstractDictionary, but " << dt->type_name(); - } - auto cons_is_str = IsValueNode(cons); - auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; - - auto ct = dyn_cast(dt); - const auto &cmap = ct->elements(); - int count = 0; - for (auto &item : cmap) { - if (cons_is_str && item.first == cons_str) { - break; - } - count++; - } - if (IntToSize(count) >= cmap.size()) { - // for dictionary set, if the key does not exist, we should create a new item - auto tuple_add_op = std::make_shared("tuple_add"); - auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value}); - return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item}); - } - auto idx_c = NewValueNode(count); - AbstractBasePtr aptr = std::make_shared(std::make_shared(count)); - idx_c->set_abstract(aptr); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, idx_c, item_value}); -} - -AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - std::vector inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - // Inputs of node should be [make_record, klass, attr1, attr2, ...], so offset by 2 to get attr; - (void)inputs.insert(inputs.end(), node->inputs().begin() + 2, node->inputs().end()); - return node->func_graph()->NewCNode(inputs); -} - -AnfNodePtr ErasePartialNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; - MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); - - std::vector args(inputs.begin() + 2, inputs.end()); - auto oper = inputs[1]; - if (IsPrimitive(oper, prim::kPrimMakeRecord)) { - if (args.size() == 1) { - return NewValueNode(prim::kPrimMakeTuple); - } - - if (args.size() > 1) { - std::vector new_inputs; - new_inputs.emplace_back(NewValueNode(prim::kPrimPartial)); - new_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - (void)new_inputs.insert(new_inputs.end(), args.begin() + 1, args.end()); - - MS_EXCEPTION_IF_NULL(node->func_graph()); - return node->func_graph()->NewCNode(new_inputs); - } - } - return nullptr; -} - -AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - std::vector 
inputs; - inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items; - (void)inputs.insert(inputs.end(), node->inputs().begin() + 1, node->inputs().end()); - return node->func_graph()->NewCNode(inputs); -} - -AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [list_getitem, list, item] - if (inputs.size() < 3) { - MS_LOG(EXCEPTION) << "Node's input number < 3."; - } - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - MS_EXCEPTION_IF_NULL(data); - MS_EXCEPTION_IF_NULL(cons); - - auto cons_node = cons->cast(); - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); -} - -AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - const auto &inputs = node->inputs(); - // Inputs should be [list_setitem, list, index, item] - if (inputs.size() < 4) { - MS_LOG(EXCEPTION) << "Node's input number < 4."; - } - - AnfNodePtr data = inputs[1]; - AnfNodePtr cons = inputs[2]; - AnfNodePtr value = inputs[3]; - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); -} - -AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); - return inputs[2]; -} - -AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - // Inputs should be [make_keyword_arg, key, value] - MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); - return inputs[2]; -} - -AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - const auto &inputs = node->inputs(); - // Inputs should be [extract_keyword_arg, arg, key] - MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); - return inputs[2]; -} - -ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { - const int DEPTH_MAX = 5; - if (depth > DEPTH_MAX) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; - } - std::vector elements; - for (const auto &it : value_list->value()) { - ValuePtr value = nullptr; - if (it->isa()) { - value = ConvertValueListToValueTuple(it->cast(), depth + 1); - } else { - value = it; - } - elements.push_back(value); - } - return std::make_shared(elements); -} - -AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - ValuePtr value = node->value(); - auto value_list = value->cast(); - MS_EXCEPTION_IF_NULL(value_list); - int depth = 0; - return std::make_shared(ConvertValueListToValueTuple(value_list, depth)); -} - -// Convert class to Tuple -// Convert getattr to getitem -// Convert make_record to make_tuple -bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - bool changed = false; - - // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var - AnfNodeSet all_node = manager->all_nodes(); - for (auto &node : all_node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - AnfNodePtr new_node = nullptr; - 
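The dispatch chain that follows routes every aggregate primitive to its tuple equivalent. As a standalone sketch of the index computation shared by ConvertGetAttrToTupleGetItem and the dict conversions above (plain C++, hypothetical names, not part of the deleted sources): the key's position in the ordered attribute/element map becomes the tuple_getitem index, and a missing key runs off the end of the map, which the dict_setitem conversion uses to detect that it must append a new element.

#include <iostream>
#include <string>
#include <vector>

// Position of `name` in an ordered attribute list; mirrors the counting loop
// used by the getattr/dict_getitem/dict_setitem conversions.
int AttrIndex(const std::vector<std::string> &attrs, const std::string &name) {
  int count = 0;
  for (const auto &attr : attrs) {
    if (attr == name) {
      break;  // first match wins
    }
    ++count;
  }
  return count;  // equals attrs.size() when the key is absent
}

int main() {
  const std::vector<std::string> attrs{"w", "b"};
  std::cout << AttrIndex(attrs, "b") << "\n";        // 1: getattr(x, "b") -> tuple_getitem(x, 1)
  std::cout << AttrIndex(attrs, "missing") << "\n";  // 2 == attrs.size(): dict_setitem appends instead
  return 0;
}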
if (IsValueNode(node)) { - new_node = NewValueNode(prim::kPrimMakeTuple); - } else if (IsPrimitiveCNode(node, prim::kPrimGetAttr)) { - new_node = ConvertGetAttrToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeRecord)) { - new_node = ConvertMakeRecordToMakeTuple(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimPartial)) { - new_node = ErasePartialNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimDictGetItem)) { - new_node = ConvertDictGetItemToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) { - new_node = ConvertDictSetItemToTupleSetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) { - new_node = EraseMakeDictNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) { - new_node = EraseMakeKeywordArgNode(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { - new_node = EraseExtractKeywordArg(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { - new_node = ConvertMakeListToMakeTuple(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { - new_node = ConvertListGetItemToTupleGetItem(cnode); - } else if (IsPrimitiveCNode(node, prim::kPrimListSetItem)) { - new_node = ConvertListSetItemToTupleSetItem(cnode); - } else if (IsValueNode(node)) { - new_node = ConvertValueListNodeToValueTupleNode(node->cast()); - } - - if (new_node != nullptr) { - new_node->set_abstract(node->abstract()); - MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); - (void)manager->Replace(node, new_node); - changed = true; - } - } - - for (auto &node : manager->all_nodes()) { - auto ret = Reabs(node->abstract()); - node->set_abstract(ret); - } - return changed; -} - -// expand tuples in graph parameters -static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, - const std::vector ¶ms) { - MS_EXCEPTION_IF_NULL(mng); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector new_params; - for (const auto ¶m : params) { - MS_EXCEPTION_IF_NULL(param); - auto param_abs = param->abstract(); - MS_EXCEPTION_IF_NULL(param_abs); - - if (param_abs->isa()) { - MS_LOG(EXCEPTION) << "Not Implemented Error NodeInfo: " << trace::GetDebugInfo(param->debug_info()); - } - - if (!param_abs->isa()) { - new_params.emplace_back(param); - continue; - } - - std::vector new_param; - std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; - auto abs_tuple = dyn_cast(param_abs); - for (auto &elem : abs_tuple->elements()) { - auto np = std::make_shared(func_graph); - np->set_abstract(elem); - new_param.emplace_back(np); - } - (void)inputs.insert(inputs.end(), new_param.begin(), new_param.end()); - auto new_tuple = func_graph->NewCNode(inputs); - (void)mng->Replace(param, new_tuple); - - auto expand_param = ExpandTuplesP(mng, func_graph, new_param); - (void)new_params.insert(new_params.end(), expand_param.begin(), expand_param.end()); - } - return new_params; -} - -// expand tuples in graph applies -static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { - MS_EXCEPTION_IF_NULL(graph); - - std::vector new_inputs; - for (const auto &input : inputs) { - MS_EXCEPTION_IF_NULL(input); - - auto input_abs = input->abstract(); - MS_EXCEPTION_IF_NULL(input_abs); - - if (input_abs->isa()) { - auto abstract_tag = dyn_cast(input_abs); - if (abstract_tag->element()->isa()) { - MS_LOG(EXCEPTION) << "Not Implemented Error JTagged NodeInfo: " << 
trace::GetDebugInfo(input->debug_info()); - } - } - - if (!input_abs->isa()) { - new_inputs.emplace_back(input); - continue; - } - - int idx = 0; - std::vector new_input; - auto abs_tuple = dyn_cast(input_abs); - for (auto &elem : abs_tuple->elements()) { - auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); - AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); - c_node->input(2)->set_abstract(aptr); - c_node->set_abstract(elem); - new_input.emplace_back(c_node); - idx++; - } - - auto expand_tuple = ExpandTuplesC(graph, new_input); - (void)new_inputs.insert(new_inputs.end(), expand_tuple.begin(), expand_tuple.end()); - } - - return new_inputs; -} - -// remove most uses of tuples from the graph parameters & apply inputs -// tuples that are returned will be kept -// tuples in CNode's inputs: AbstractTuple (a, b ,c) --> -// CNode("tuple_getitem", (a,b,c), 0) -// CNode("tuple_getitem", (a,b,c), 1) -// CNode("tuple_getitem", (a,b,c), 2) -// tuples in Graph's parameters: AbstractTuple (a, b, c) --> -// CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) -// cppcheck-suppress unusedFunction -void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var - AnfNodeSet all_node = manager->all_nodes(); - for (auto &node : all_node) { - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - - const auto &inputs = cnode->inputs(); - - // Bypass the first input in inputs as it's fn. - if (!IsValueNode(inputs[0])) { - std::vector expand_inputs; - (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 1, inputs.end()); - - auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); - if (new_inputs != expand_inputs) { - std::vector cnode_inputs{inputs[0]}; - (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); - - MS_EXCEPTION_IF_NULL(node->func_graph()); - auto new_node = node->func_graph()->NewCNode(cnode_inputs); - new_node->set_abstract(node->abstract()); - - (void)manager->Replace(node, new_node); - } - // Bypass the first 2 inputs in inputs as it's [partial, fn]. 
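The branch below applies the same rewrite to partial applications, skipping the leading [partial, fn] pair and expanding only the bound arguments. As a standalone sketch of what ExpandTuplesC does to tuple-typed inputs (plain C++ with toy types, not from the deleted sources): nested tuples are flattened depth-first into individual arguments, and non-tuple inputs pass through unchanged.

#include <iostream>
#include <vector>

// Toy stand-in for an abstract value: either a scalar or a tuple of values.
struct Abs {
  bool is_tuple;
  int scalar;              // valid when !is_tuple
  std::vector<Abs> elems;  // valid when is_tuple
};

// Mirrors the recursion in ExpandTuplesC: tuples expand element by element,
// scalars are emitted as-is.
void Expand(const Abs &a, std::vector<int> *out) {
  if (!a.is_tuple) {
    out->push_back(a.scalar);
    return;
  }
  for (const auto &e : a.elems) {
    Expand(e, out);
  }
}

int main() {
  // The input (1, (2, 3)) becomes the flat argument list 1 2 3.
  Abs nested{true, 0, {Abs{false, 1, {}}, Abs{true, 0, {Abs{false, 2, {}}, Abs{false, 3, {}}}}}};
  std::vector<int> flat;
  Expand(nested, &flat);
  for (int v : flat) {
    std::cout << v << ' ';
  }
  return 0;
}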
- } else if (cnode->IsApply(prim::kPrimPartial) && !IsValueNode(inputs[1])) { - std::vector expand_inputs; - (void)expand_inputs.insert(expand_inputs.end(), inputs.begin() + 2, inputs.end()); - - auto new_inputs = ExpandTuplesC(cnode->func_graph(), expand_inputs); - if (new_inputs != expand_inputs) { - std::vector cnode_inputs{inputs[0], inputs[1]}; - (void)cnode_inputs.insert(cnode_inputs.end(), new_inputs.begin(), new_inputs.end()); - - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - auto new_node = cnode->func_graph()->NewCNode(cnode_inputs); - new_node->set_abstract(cnode->abstract()); - - (void)manager->Replace(node, new_node); - } - } - } - - FuncGraphSet all_graph = manager->func_graphs(); - for (auto &func_graph : all_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); - manager->SetParameters(func_graph, expand_p); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/clean.h b/mindspore/ccsrc/optimizer/clean.h deleted file mode 100644 index 672ee78414..0000000000 --- a/mindspore/ccsrc/optimizer/clean.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ - -#include -#include "ir/anf.h" -#include "operator/ops.h" -#include "utils/any.h" -#include "ir/manager.h" -#include "abstract/dshape.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -// Remove the class type from graphs -bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); - -// Remove most uses of tuples from the graph -// tuples that are returned will be kept -void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); - -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_CLEAN_H_ diff --git a/mindspore/ccsrc/optimizer/control_depend.cc b/mindspore/ccsrc/optimizer/control_depend.cc deleted file mode 100644 index 0b5c85b1e0..0000000000 --- a/mindspore/ccsrc/optimizer/control_depend.cc +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "optimizer/control_depend.h" - -#include -#include -#include -#include -#include - -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, - const std::vector &effect_index, const std::vector &cnodes) { - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; - std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; - size_t effect_size = effect_index.size(); - for (size_t i = 0; i < effect_size; i++) { - size_t pre_index = 0; - if (i > 0) { - pre_index = effect_index[i - 1] + 1; - } - size_t this_index = effect_index[i]; - size_t last_index = cnodes.size() - 2; - if (i < effect_size - 1) { - last_index = effect_index[i + 1]; - } - - if (this_index > pre_index) { - std::vector pre_segment; - for (size_t k = pre_index; k < this_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. - if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - pre_segment.push_back(cnodes[k]); - } - auto roots = FindRoots(pre_segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); - make_tuple.push_back(control_depend); - } - } - if (last_index > this_index) { - std::vector last_segment; - for (size_t k = this_index + 1; k <= last_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. - if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - last_segment.push_back(cnodes[k]); - } - auto leaves = FindLeaves(last_segment); - for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter}); - make_tuple.push_back(control_depend); - } - } - } - depend_nodes.push_back(graph->NewCNode(make_tuple)); - return depend_nodes; -} - -void AddControlDepend(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - std::list orders = graph->GetOrderedCnodes(); - std::vector cnodes(orders.begin(), orders.end()); - size_t cnodes_size = cnodes.size(); - // get effect index of cnodes - std::vector effect_index{}; - for (size_t i = 0; i < cnodes_size; i++) { - if (graph->HasEffect(cnodes[i])) { - effect_index.push_back(i); - } - } - if (effect_index.empty()) { - return; - } - AnfNodePtr last_node = cnodes[cnodes_size - 1]; - CNodePtr return_node; - if (last_node->isa()) { - return_node = last_node->cast(); - } - MS_EXCEPTION_IF_NULL(return_node); - if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { - MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; - } - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; - } - - auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes); - auto depend_cnode = graph->NewCNode(depend_node_inputs); - depend_cnode->set_abstract(depend_cnode->input(1)->abstract()); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (!manager->Replace(return_node->input(1), 
depend_cnode)) {
-    MS_LOG(EXCEPTION) << "Failed to replace the return value with the depend node.";
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/cse.cc b/mindspore/ccsrc/optimizer/cse.cc
deleted file mode 100644
index 0b675cca72..0000000000
--- a/mindspore/ccsrc/optimizer/cse.cc
+++ /dev/null
@@ -1,231 +0,0 @@
-/**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "optimizer/cse.h"
-#include <numeric>
-#include <set>
-#include <vector>
-#include "./common.h"
-
-namespace mindspore {
-/* namespace to support opt */
-namespace opt {
-using mindspore::abstract::AbstractBase;
-using mindspore::abstract::AbstractFunction;
-using mindspore::abstract::AbstractFunctionPtr;
-
-BasePtr AbsOf(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto node_abs = node->abstract();
-  // In test case TestOptOpt.CSE, node->abstract() may be null.
-  if (node_abs == nullptr) {
-    return kAnyValue;
-  }
-
-  return node_abs;
-}
-
-bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
-  bool changed = false;
-  for (FuncGraphPtr fg : manager->func_graphs()) {
-    MS_EXCEPTION_IF_NULL(fg);
-    std::vector<std::size_t> order_group;
-    std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
-    std::unordered_map<AnfNodePtr, std::size_t> hashes;
-
-    std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
-    for (auto node : toposet) {
-      MS_EXCEPTION_IF_NULL(node);
-      if (hashes.find(node) != hashes.end()) {
-        continue;
-      }
-
-      std::size_t h = 0;
-      if (node->isa<ValueNode>()) {
-        ValueNodePtr value_node = node->cast<ValueNodePtr>();
-        auto value = value_node->value();
-        MS_EXCEPTION_IF_NULL(value);
-        h = hash_combine(value->hash(), (AbsOf(value_node)->hash()));
-      } else if (node->isa<CNode>()) {
-        auto cnode = node->cast<CNodePtr>();
-        auto &inputs = cnode->inputs();
-        size_t init = 0;
-        h = std::accumulate(inputs.begin(), inputs.end(), init,
-                            [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
-                              return hash_combine(hash, hashes[node_in]);
-                            });
-      } else if (node->isa<Parameter>()) {
-        h = node->hash();
-      } else {
-        MS_LOG(ERROR) << "Unknown node type.";
-      }
-
-      hashes[node] = h;
-      if (groups.find(h) == groups.end()) {
-        std::vector<AnfNodePtr> innervec({node});
-        groups[h] = innervec;
-        order_group.emplace_back(h);
-      } else {
-        groups[h].push_back(node);
-      }
-    }
-
-    changed = DoReplace(manager, order_group, &groups) || changed;
-  }
-
-  return changed;
-}
-
-// Ops like print and summary have no real output and are always used as the input of a depend
-// node; they are treated as having side effects.
-static bool HasSideEffect(const AnfNodePtr &node) {
-  auto prim = GetCNodePrimitive(node);
-  if (prim == nullptr) {
-    return false;
-  }
-  auto side_effect_v = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT);
-  if (side_effect_v != nullptr && side_effect_v->isa<BoolImm>()) {
-    return GetValue<bool>(side_effect_v);
-  }
-  return false;
-}
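As a standalone illustration of the bucketing performed by BuildOrderGroupAndDoReplace above (plain C++ with stand-in types; the combiner follows the usual boost-style formula, which is an assumption about mindspore's hash_combine): a node's hash folds the operand hashes into the operator's hash, so structurally identical expressions collide into one bucket, and only bucket members are ever compared for merging.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Boost-style hash combiner (assumed equivalent in spirit to hash_combine).
std::size_t HashCombine(std::size_t seed, std::size_t v) {
  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct Expr {
  std::string op;
  std::vector<std::size_t> operand_hashes;  // stand-ins for input-node hashes
};

std::size_t HashExpr(const Expr &e) {
  std::size_t h = std::hash<std::string>{}(e.op);
  for (std::size_t oh : e.operand_hashes) {
    h = HashCombine(h, oh);
  }
  return h;
}

int main() {
  const std::vector<Expr> exprs{{"add", {1, 2}}, {"mul", {1, 2}}, {"add", {1, 2}}};
  std::unordered_map<std::size_t, std::vector<const Expr *>> groups;
  for (const auto &e : exprs) {
    groups[HashExpr(e)].push_back(&e);  // merge candidates share a bucket
  }
  for (const auto &g : groups) {
    std::cout << g.second.front()->op << ": " << g.second.size() << " candidate(s)\n";
  }
  return 0;
}

In the real pass, order_group records buckets in first-seen topological order, which keeps the replacement walk deterministic.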
-// If this returns true, the two nodes must not be merged.
-bool CSE::CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const {
-  bool has_random_effect = false;
-  auto prim_main = GetCNodePrimitive(main);
-  auto prim_node = GetCNodePrimitive(node);
-  // If the primitive has a random effect, nodes generated by different op objects
-  // (not the same object) must not be merged.
-  if (prim_main != nullptr) {
-    if (prim_main == prim_node) {
-      return false;
-    }
-    auto effect_val = prim_main->GetAttr(GRAPH_FLAG_RANDOM_EFFECT);
-    if (effect_val != nullptr && effect_val->isa<BoolImm>()) {
-      has_random_effect = GetValue<bool>(effect_val);
-    }
-  }
-  return has_random_effect;
-}
-
-bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect) const {
-  MS_EXCEPTION_IF_NULL(main);
-  MS_EXCEPTION_IF_NULL(node);
-
-  if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
-    auto main_value = GetValueNode(main);
-    auto node_value = GetValueNode(node);
-    return (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value);
-  } else if (main->isa<CNode>() && node->isa<CNode>()) {
-    auto c_main = main->cast<CNodePtr>();
-    auto c_node = node->cast<CNodePtr>();
-    // Nodes with side effects must not be merged.
-    if (check_side_effect && HasSideEffect(main)) {
-      return false;
-    }
-    const auto &inp1 = c_main->inputs();
-    const auto &inp2 = c_node->inputs();
-    if (inp1.size() != inp2.size()) {
-      return false;
-    }
-    for (size_t j = 0; j < inp1.size(); j++) {
-      auto inp1_j = inp1[j];
-      auto inp2_j = inp2[j];
-      MS_EXCEPTION_IF_NULL(inp1_j);
-      MS_EXCEPTION_IF_NULL(inp2_j);
-      if (!(*inp1_j == *inp2_j)) {
-        // Handle the case of two different Tensors that hold the same value.
-        if (IsValueNode<tensor::Tensor>(inp1_j) && IsValueNode<tensor::Tensor>(inp2_j)) {
-          auto tensor1 = GetValueNode<tensor::TensorPtr>(inp1_j);
-          auto tensor2 = GetValueNode<tensor::TensorPtr>(inp2_j);
-          if (tensor1->ValueEqual(*tensor2)) {
-            continue;
-          }
-        } else if (HasSideEffect(inp1_j) && HasSideEffect(inp2_j)) {
-          // When the same kind of side-effect node feeds both candidates, the candidates can
-          // still be merged: such a node can only be the input of a `depend`, so merging
-          // deduplicates the `depend` as well.
-          if (CheckReplace(inp1_j, inp2_j, false)) {
-            continue;
-          }
-        }
-        return false;
-      }
-    }
-    // Nodes with random effects must not be merged.
-    if (CheckRandomEffect(c_main, c_node)) {
-      return false;
-    }
-    return true;
-  }
-  // Parameter nodes are never merged.
-  return false;
-}
-
-bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
-                    std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const {
-  bool changes = false;
-  std::set<size_t> clear_set;
-  for (auto &h : order_group) {
-    std::vector<AnfNodePtr> &group = (*groups)[h];
-    // If a group holds more than one node, its members may be the same common expression
-    // and can be eliminated.
-    if (group.size() > 1) {
-      for (size_t k = 0; k < group.size() - 1; k++) {
-        AnfNodePtr main = group[k];
-        MS_EXCEPTION_IF_NULL(main);
-
-        // Stop when every remaining node in the group has been replaced, or when main is a
-        // ValueNode past the first position; the in-group comparison is skipped in these cases.
-        if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa<ValueNode>())) {
-          break;
-        }
-
-        // Skip nodes that have already been replaced.
-        if (clear_set.find(k) != clear_set.end()) {
-          continue;
-        }
-
-        // Compare with the rest of the elements in this group.
- for (size_t i = k + 1; i < group.size(); i++) { - auto node = group[i]; - MS_EXCEPTION_IF_NULL(node); - - if (clear_set.find(i) != clear_set.end()) { - continue; - } - if (main->func_graph() != node->func_graph()) { - continue; - } - if (CheckReplace(node, main)) { - changes = true; - (void)manager->Replace(node, main); - (void)clear_set.insert(i); - } - } - } - clear_set.clear(); - } - } - - return changes; -} - -bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(root); - - return BuildOrderGroupAndDoReplace(manager); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/cse.h b/mindspore/ccsrc/optimizer/cse.h deleted file mode 100644 index 57163cc5c9..0000000000 --- a/mindspore/ccsrc/optimizer/cse.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "ir/manager.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -// Common subexpression elimination. -class CSE { - public: - explicit CSE(bool report_changes = true) : report_changes_(report_changes) {} - virtual ~CSE() = default; - - bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { - bool chg = Cse(root, optimizer->resource()->manager()); - return chg && report_changes_; - } - - virtual bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const; - - virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const; - - bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const; - - private: - bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const; - bool DoReplace(const FuncGraphManagerPtr manager, const std::vector &order_group, - std::unordered_map> *groups) const; - bool report_changes_; -}; - -BasePtr AbsOf(const AnfNodePtr &node); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_CSE_H_ diff --git a/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc b/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc deleted file mode 100644 index dc20ad925e..0000000000 --- a/mindspore/ccsrc/optimizer/graph_kernel_reuse.cc +++ /dev/null @@ -1,157 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/graph_kernel_reuse.h" -#include -#include -#include -#include "./common.h" -#include "utils/graph_utils.h" - -namespace mindspore { -/* namespace to support opt */ -namespace opt { - -bool GraphKernelReuse::CompareNode(const AnfNodePtr a, const AnfNodePtr b) { - if (a->abstract() && b->abstract()) { - auto a_type = a->abstract()->GetTypeTrack(); - auto b_type = b->abstract()->GetTypeTrack(); - - if (a_type != b_type) { - return false; - } - - auto a_shape = a->abstract()->GetShapeTrack(); - auto b_shape = b->abstract()->GetShapeTrack(); - if (a_shape != nullptr && a_shape == b_shape) { - return true; - } - - if (a_shape != nullptr && b_shape != nullptr && a_shape->isa() && - b_shape->isa()) { - return a_shape->cast()->shape() == b_shape->cast()->shape(); - } - } - return false; -} - -bool GraphKernelReuse::DoReplace(const FuncGraphManagerPtr manager) { - bool changed = false; - auto fgs = manager->func_graphs(); - for (FuncGraphPtr &fg : fgs) { - if (!fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - continue; - } - std::string key = GetValue(fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - if (graph_kernel_ops.find(key) != graph_kernel_ops.end()) { - if (find(graph_kernel_ops[key].begin(), graph_kernel_ops[key].end(), fg) == graph_kernel_ops[key].end()) { - FuncGraphPtr new_fg = nullptr; - for (auto &cfg : graph_kernel_ops[key]) { - // If two graphs have different size then continue - auto fg_topos = TopoSort(fg->get_return()); - auto cfg_topos = TopoSort(cfg->get_return()); - if (fg_topos.size() != cfg_topos.size()) { - continue; - } - - // Compare const tensor - bool has_same = true; - for (size_t i = 0; i < fg_topos.size(); ++i) { - if (IsValueNode(fg_topos[i])) { - if (!IsValueNode(cfg_topos[i])) { - has_same = false; - break; - } - - auto tensor1 = GetValueNode(fg_topos[i]); - auto tensor2 = GetValueNode(cfg_topos[i]); - if (!tensor1->ValueEqual(*tensor2)) { - has_same = false; - break; - } - } - } - - if (!has_same) { - continue; - } - - auto fg_input = fg->parameters(); - auto cfg_input = cfg->parameters(); - if (fg_input.size() != cfg_input.size()) { - continue; - } - // Compare input - for (size_t i = 0; i < fg_input.size(); ++i) { - if (!CompareNode(fg_input[i], cfg_input[i])) { - has_same = false; - break; - } - } - if (!has_same) { - continue; - } - - // Compare output - if (!CompareNode(fg->output(), cfg->output())) { - continue; - } - - // Find reusable fg - new_fg = cfg; - break; - } - - if (new_fg != nullptr) { - // Replace current fg with existing fg - auto users = fg->func_graph_cnodes_index(); - for (auto &iter : users) { - auto cnode = iter.first->first->cast(); - auto new_input = cnode->inputs(); - auto main_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(main_graph); - if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) { - new_input[1] = NewValueNode(new_fg); - } else { - new_input[0] = NewValueNode(new_fg); - } - auto new_cnode = main_graph->NewCNode(new_input); - manager->Replace(iter.first->first, new_cnode); - changed = true; - } - - } else { - // Add current fg to map - graph_kernel_ops[key].push_back(fg); - } - } - } 
else {
-      graph_kernel_ops[key] = {fg};
-    }
-  }
-
-  return changed;
-}
-
-bool GraphKernelReuse::ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
-  MS_EXCEPTION_IF_NULL(manager);
-  manager->AddFuncGraph(root);
-
-  return DoReplace(manager);
-}
-
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/graph_kernel_reuse.h b/mindspore/ccsrc/optimizer/graph_kernel_reuse.h
deleted file mode 100644
index ed5cc93d18..0000000000
--- a/mindspore/ccsrc/optimizer/graph_kernel_reuse.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-#define MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "optimizer/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-
-// Reuse graph kernels that have identical structure, inputs, and outputs.
-class GraphKernelReuse {
- public:
-  GraphKernelReuse() : count(0) {}
-  virtual ~GraphKernelReuse() = default;
-
-  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
-    bool chg = ReuseGraphKernel(root, optimizer->resource()->manager());
-    return chg;
-  }
-
-  bool CompareNode(const AnfNodePtr a, const AnfNodePtr other);
-  bool DoReplace(const FuncGraphManagerPtr manager);
-
-  bool ReuseGraphKernel(const FuncGraphPtr root, const FuncGraphManagerPtr manager);
-
- private:
-  std::unordered_map<std::string, std::vector<FuncGraphPtr>> graph_kernel_ops;
-  int count;
-};
-
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_GRAPH_KERNEL_OP_REUSE_H
diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
deleted file mode 100644
index 166151751f..0000000000
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ /dev/null
@@ -1,174 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include - -#include "optimizer/irpass.h" -#include "optimizer/irpass/arithmetic_simplify.h" -#include "optimizer/irpass/branch_culling.h" -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass/convert.h" -#include "optimizer/irpass/env_item_eliminate.h" -#include "optimizer/irpass/grad_var_prepare.h" -#include "optimizer/irpass/gradient_eliminate.h" -#include "optimizer/irpass/inline.h" -#include "optimizer/irpass/incorporate_call.h" -#include "optimizer/irpass/incorporate_getitem.h" -#include "optimizer/irpass/item_tuple_eliminate.h" -#include "optimizer/irpass/mark_interface_fusion.h" -#include "optimizer/irpass/merge_addn.h" -#include "optimizer/irpass/minmax_grad.h" -#include "optimizer/irpass/param_replace.h" -#include "optimizer/irpass/partial_eliminate.h" -#include "optimizer/irpass/reduce_eliminate.h" -#include "optimizer/irpass/ref_eliminate.h" -#include "optimizer/irpass/reshape_eliminate.h" -#include "optimizer/irpass/special_op_eliminate.h" -#include "optimizer/irpass/specialize_transform.h" -#include "optimizer/irpass/symbol_resolver.h" -#include "optimizer/irpass/tile_eliminate.h" -#include "optimizer/irpass/transpose_eliminate.h" -#include "optimizer/opt.h" -#include "optimizer/irpass/indexed_slices_eliminate.h" - -namespace mindspore { -namespace opt { -namespace irpass { -OptimizeIRPassLib::OptimizeIRPassLib() { - arithmetic_simplify_ = MakeSubstitution(std::make_shared(), "arithmetic_simplify", - {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, - prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); - arithmetic_simplify2_ = - MakeSubstitution(std::make_shared(), "arithmetic_simplify2", {prim::kPrimMul}); - special_op_eliminate_ = - MakeSubstitution(std::make_shared(), "special_op_eliminate", - {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, - prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv}); - zero_like_fill_zero_ = - MakeSubstitution(std::make_shared(), "zero_like_fill_zero", prim::kPrimZerosLike); - adjust_all_reduce_mul_add_ = - MakeSubstitution(std::make_shared(), "adjust_all_reduce_mul_add", prim::kPrimAddN); - - // ops eliminate - item_tuple_eliminate_ = MakeSubstitution(std::make_shared(), "item_tuple_eliminate", - {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); - tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); - cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); - reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); - transpose_eliminate_ = - MakeSubstitution(std::make_shared(), "transpose_eliminate", prim::kPrimTranspose); - reduce_eliminate_ = MakeSubstitution( - std::make_shared(), "reduce_eliminate", - {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); - partial_eliminate_ = MakeSubstitution(std::make_shared(), "partial_eliminate", IsCNodeDup); - same_eliminate_ = MakeSubstitution(std::make_shared(), "same_eliminate", prim::kPrimSameTypeShape); - check_bprop_eliminate_ = - MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); - reset_defer_inline_ = - MakeSubstitution(std::make_shared(), "reset_defer_inline", IsValueNode); - depend_value_elim_ = MakeSubstitution(std::make_shared(), "depend_value_elim", prim::kPrimDepend); - - // Env Item Eliminate - env_get_item_eliminate_ = - 
MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_ = - MakeSubstitution(std::make_shared(), "incorporate_env_get_item", prim::kPrimEnvGetItem); - incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), - "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); - - // Ref eliminate - make_ref_eliminate_ = - MakeSubstitution(std::make_shared(), "make_ref_eliminate", prim::kPrimMakeRef); - get_ref_param_eliminate_ = MakeSubstitution(std::make_shared(), "get_ref_param_eliminate", - {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - get_make_ref_eliminate_ = MakeSubstitution(std::make_shared(), "get_make_ref_eliminate", - {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); - - replace_refkey_by_param_ = MakeSubstitution(std::make_shared(), "replace_refkey_by_param", - IsValueNode, opt::FORCE_RENORM); - replace_old_param_ = MakeSubstitution(std::make_shared(), "replace_old_param", IsParam); - // Gradient transforms - expand_jprim_ = MakeSubstitution(std::make_shared(), "expand_jprim", prim::kPrimJ); - minmaximum_grad_ = MakeSubstitution(std::make_shared(), "minmaximum_grad", prim::kPrimTupleGetItem); - - // branch culling - switch_simplify_ = MakeSubstitution(std::make_shared(), "switch_simplify", prim::kPrimSwitch); - float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared(), - "float_tuple_getitem_switch", prim::kPrimTupleGetItem); - float_env_getitem_switch_ = - MakeSubstitution(std::make_shared(), "float_env_getitem_switch", prim::kPrimEnvGetItem); - convert_switch_replacement_ = - MakeSubstitution(std::make_shared(), "convert_switch_replacement", IsCNodeDup); - - // Addn - merge_addn_ = MakeSubstitution(std::make_shared(), "merge_addn", prim::kPrimAddN); - addn_zero_filter_ = MakeSubstitution(std::make_shared(), "addn_zero_filter", prim::kPrimAddN); - - // inline - inline_ = MakeSubstitution(std::make_shared(), "inline", IsCNodeGraph); - replace_applicator_ = - MakeSubstitution(std::make_shared(), "replace_applicator", IsValueNode); - specialize_transform_ = - MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); - - // Incorporation - incorporate_getitem_set_ = - MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); - incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared(), - "incorporate_getitem_from_param", IsCNodeGraphKernel); - incorporate_call_ = MakeSubstitution(std::make_shared(), "incorporate_call", IsCNodeDup); - incorporate_call_switch_ = - MakeSubstitution(std::make_shared(), "incorporate_call_switch", IsCNodeDup); - - // Virtual Dataset - virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared(), - "virtual_dataset_eliminate", prim::kPrimVirtualDataset); - - // Convert - print_tuple_wrapper_ = - MakeSubstitution(std::make_shared(), "print_tuple_wrapper", prim::kPrimPrint); - - // Unused parameter eliminate - unused_parameter_eliminate_ = - MakeSubstitution(std::make_shared(), "unused_parameter_eliminate", IsCNodeGraphKernel); - unused_output_eliminate_ = - MakeSubstitution(std::make_shared(), "unused_output_eliminate", IsCNodeGraphKernel); - - // AddN eliminate - addn_eliminate_ = MakeSubstitution(std::make_shared(), "addn_eliminate", IsCNodeGraphKernel); - - // Mark interface fusion - mark_interface_fusion_ = - MakeSubstitution(std::make_shared(), "mark_interface_fusion", 
prim::kPrimSelect); - - // IndexedSlices Eliminate - indexed_slices_eliminate_ = MakeSubstitution( - std::make_shared(), "indexed_slices_eliminate", - {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); -} - -ResolveIRPassLib::ResolveIRPassLib() { - resolver_resolve_ = MakeSubstitution(std::make_shared(), "resolver_resolve", prim::kPrimResolve); - resolver_getattr_ = MakeSubstitution(std::make_shared(), "resolver_getattr", prim::kPrimGetAttr); -} - -InferenceOptPrepareLib::InferenceOptPrepareLib() { - grad_var_prepare_ = MakeSubstitution(std::make_shared(), "grad_var_prepare", IsCNode); -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h deleted file mode 100644 index 782eae6124..0000000000 --- a/mindspore/ccsrc/optimizer/irpass.h +++ /dev/null @@ -1,192 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_ - -#include - -#include "optimizer/optimizer.h" -#include "optimizer/opt.h" -#include "ir/visitor.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// the collection of irpass for optimie action -class OptimizeIRPassLib { - public: - OptimizeIRPassLib(); - ~OptimizeIRPassLib() = default; - - SubstitutionPtr arithmetic_simplify_; - SubstitutionPtr arithmetic_simplify2_; - SubstitutionPtr special_op_eliminate_; - SubstitutionPtr zero_like_fill_zero_; - SubstitutionPtr adjust_all_reduce_mul_add_; - - // ops eliminate - SubstitutionPtr item_tuple_eliminate_; - SubstitutionPtr tile_eliminate_; - SubstitutionPtr cast_eliminate_; - SubstitutionPtr reshape_eliminate_; - SubstitutionPtr transpose_eliminate_; - SubstitutionPtr reduce_eliminate_; - SubstitutionPtr partial_eliminate_; - SubstitutionPtr same_eliminate_; - SubstitutionPtr check_bprop_eliminate_; - SubstitutionPtr reset_defer_inline_; - SubstitutionPtr depend_value_elim_; - - // Env Item Eliminate - SubstitutionPtr env_get_item_eliminate_; - SubstitutionPtr new_env_get_item_; - SubstitutionPtr incorporate_env_getitem_; - SubstitutionPtr incorporate_env_getitem_switch_; - - // Ref eliminate - SubstitutionPtr make_ref_eliminate_; - SubstitutionPtr get_ref_param_eliminate_; - SubstitutionPtr get_make_ref_eliminate_; - SubstitutionPtr replace_refkey_by_param_; - SubstitutionPtr replace_old_param_; - - // Branch culling - SubstitutionPtr switch_simplify_; - SubstitutionPtr float_tuple_getitem_switch_; - SubstitutionPtr float_env_getitem_switch_; - SubstitutionPtr convert_switch_replacement_; - - // AddN - SubstitutionPtr merge_addn_; - SubstitutionPtr addn_zero_filter_; - - // Gradient irpasses - SubstitutionPtr expand_jprim_; - SubstitutionPtr minmaximum_grad_; - - // inline - SubstitutionPtr inline_; - SubstitutionPtr replace_applicator_; - SubstitutionPtr specialize_transform_; - - // Incorporation - 
SubstitutionPtr incorporate_getitem_set_;
-  SubstitutionPtr incorporate_getitem_from_param_;
-  SubstitutionPtr incorporate_call_;
-  SubstitutionPtr incorporate_call_switch_;
-
-  // virtual dataset
-  SubstitutionPtr virtual_dataset_eliminate_;
-
-  // Convert
-  SubstitutionPtr print_tuple_wrapper_;
-
-  // Unused parameter eliminate
-  SubstitutionPtr unused_parameter_eliminate_;
-  SubstitutionPtr unused_output_eliminate_;
-
-  // AddN eliminate
-  SubstitutionPtr addn_eliminate_;
-
-  // Fusion
-  SubstitutionPtr mark_interface_fusion_;
-
-  // IndexedSlices Eliminate
-  SubstitutionPtr indexed_slices_eliminate_;
-};
-
-// The collection of irpasses for the resolve action.
-class ResolveIRPassLib {
- public:
-  ResolveIRPassLib();
-  ~ResolveIRPassLib() = default;
-
-  SubstitutionPtr resolver_resolve_;
-  SubstitutionPtr resolver_getattr_;
-};
-
-class InferenceOptPrepareLib {
- public:
-  InferenceOptPrepareLib();
-  ~InferenceOptPrepareLib() = default;
-  SubstitutionPtr grad_var_prepare_;
-};
-
-// Predicate functions
-inline bool IsNode(const AnfNodePtr &) { return true; }
-
-inline bool IsCNode(const AnfNodePtr &node) {
-  if (node != nullptr) {
-    return node->isa<CNode>();
-  }
-  return false;
-}
-
-inline bool IsVNode(const AnfNodePtr &node) {
-  if (node != nullptr) {
-    return node->isa<ValueNode>();
-  }
-  return false;
-}
-
-inline bool IsParam(const AnfNodePtr &node) {
-  if (node != nullptr) {
-    return node->isa<Parameter>();
-  }
-  return false;
-}
-
-// Check whether input 0 of the CNode is a FuncGraph.
-inline bool IsCNodeGraph(const AnfNodePtr &node) {
-  if (node == nullptr || !node->isa<CNode>()) {
-    return false;
-  }
-
-  auto inp0 = node->cast<CNodePtr>()->input(0);
-  return IsValueNode<FuncGraph>(inp0);
-}
-
-// Check whether input 0 of the CNode is the FuncGraph of a graph kernel.
-inline bool IsCNodeGraphKernel(const AnfNodePtr &node) {
-  if (node == nullptr || !node->isa<CNode>()) {
-    return false;
-  }
-
-  auto inp0 = node->cast<CNodePtr>()->input(0);
-  if (IsValueNode<FuncGraph>(inp0)) {
-    auto fg = GetValueNode<FuncGraphPtr>(inp0);
-    if (fg == nullptr) {
-      return false;
-    }
-    return fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
-  }
-  return false;
-}
-
-// Check whether input 0 of the CNode is itself a CNode.
-inline bool IsCNodeDup(const AnfNodePtr &node) {
-  if (node == nullptr || !node->isa<CNode>()) {
-    return false;
-  }
-
-  auto inp0 = node->cast<CNodePtr>()->input(0);
-  return (inp0 != nullptr) && inp0->isa<CNode>();
-}
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc
deleted file mode 100644
index b111a6b67a..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc
+++ /dev/null
@@ -1,680 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include -#include -#include -#include - -#include "optimizer/irpass/arithmetic_simplify.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarMul)(node); - - if (is_zero_) { - return NewValueNode(zero_); - } - if (is_one_) { - return x_; - } - return nullptr; -} - -void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { - if (is_one_ || node->isa()) { - x_ = node; - return; - } - - AnfVisitor::Visit(node); - if (!is_one_) { - x_ = node; - } -} - -void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (*value == *zero_) { - is_zero_ = true; - } else if (*value == *one_) { - is_one_ = true; - } -} - -void MultiplyByZeroOrOne::Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; -} - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // input Data Types is not supported - return false; -} - -bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); -} - -void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { - if (!node->isa()) { - return nullptr; - } - - auto value = node->cast()->value(); - - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); -} - -// Make a new tensor (when possible) with the same shape as of `node` -// If x is nullptr then fill new tensor will "0" -// If x is a tensor with empty shape then fill new tensor with the single value of x -// If x is a tensor with same shape as `node` then return x as result -AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - 
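The remainder of the function implements three fill rules, sketched standalone below (plain C++, with a vector standing in for a tensor; an assumption-level illustration, not the original API): no source tensor yields zeros, a one-element source is broadcast, an equal-sized source is copied through, and anything else bails out, matching the nullptr returns below.

#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Mirrors NewTensorFilledWithData's cases; nullopt plays the role of the
// nullptr result on incompatible shapes.
std::optional<std::vector<float>> FillLike(std::size_t target_size, const std::vector<float> *src) {
  if (src == nullptr) {
    return std::vector<float>(target_size, 0.0f);  // x == nullptr: zero fill
  }
  if (src->size() == 1) {
    return std::vector<float>(target_size, src->front());  // scalar source: broadcast
  }
  if (src->size() == target_size) {
    return *src;  // same shape: take the data as-is
  }
  return std::nullopt;  // shape mismatch: the pass gives up
}

int main() {
  const std::vector<float> scalar{3.0f};
  const auto out = FillLike(4, &scalar);
  for (float v : *out) {
    std::cout << v << ' ';  // 3 3 3 3
  }
  return 0;
}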
auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; - } - return x; - } - - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { - return nullptr; - } - - auto x_tensor_ptr = dyn_cast(x_value); - - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(data, source_data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - return nullptr; -} - -void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { - if (is_zero_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } - return nullptr; -} - -void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { - if (is_one_) { - x_ = node; - return; - } - - if (IsParam(node) || IsCNode(node)) { - x_ = node; - return; - } - - auto value = node->cast()->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = 
vnode; -} -void TensorMultiplyByOne::Reset() { - x_ = nullptr; - is_one_ = false; -} - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void AddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && - ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void AddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimTensorAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void TensorAddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void TensorAddByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } -} - -void TensorAddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { - return nullptr; - } - - // {PrimMomentum, {...}, Y, Z, Xs} - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { - return nullptr; - } - auto y = inputs[2]; - auto z = inputs[3]; - - // {kPrimZerosLike, X} - if (inputs[1]->cast()->size() != 2) { - return nullptr; - } - - // {prim::kPrimMakeTuple, Z, Y} - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); -} - -// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -// Support function to multiply two constant tensors: partially support broadcasting shapes -template -void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, - void **out_data, int out_data_size) { - T *data_1 = reinterpret_cast(in_data_1); - T *data_2 = reinterpret_cast(in_data_2); - T *data_out = new T[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; -} - -AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, - const AnfNodePtr &node_3) { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return 
nullptr;
-  }
-
-  auto tensor_ptr_1 = dyn_cast<tensor::Tensor>(value_1);
-  auto tensor_ptr_2 = dyn_cast<tensor::Tensor>(value_2);
-
-  auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
-  auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
-  auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
-
-  TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
-  TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
-  TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
-
-  if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
-      (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
-    return nullptr;
-  }
-
-  std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
-  int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
-
-  if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
-    return nullptr;
-  }
-  if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
-    return nullptr;
-  }
-
-  void *data_out;
-
-  if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
-      (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
-    Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
-                    tensor_ptr_2->DataSize(), &data_out, data_out_size);
-  } else if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
-    Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
-                     tensor_ptr_2->DataSize(), &data_out, data_out_size);
-  } else if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
-             (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
-    Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
-                  tensor_ptr_2->DataSize(), &data_out, data_out_size);
-  } else {
-    // Unsupported data type.
-    return nullptr;
-  }
-
-  auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
-  size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
-  char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
-  memcpy(data, data_out, mem_size);
-
-  auto new_vnode = NewValueNode(new_tensor_ptr);
-  new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
-  return new_vnode;
-}
-
-AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
-  Reset();
-  // {prim::kPrimMul, Tensor1, {...}}
-  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node);
-  if (vnode_ == nullptr || c_p_node_ == nullptr) {
-    return nullptr;
-  }
-
-  if (!IsCNode(c_p_node_)) {
-    return nullptr;
-  }
-
-  auto tensor1 = vnode_;
-  auto mul = c_p_node_->cast<CNodePtr>();
-
-  Reset();
-  // {prim::kPrimMul, Tensor2, {...}}
-  AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul);
-  if (vnode_ == nullptr || c_p_node_ == nullptr) {
-    return nullptr;
-  }
-  auto tensor2 = vnode_;
-  auto c_p_node = c_p_node_;
-
-  auto PrimMul = GetValueNode<PrimitivePtr>(mul->input(0));
-  auto fg = node->func_graph();
-
-  auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node);
-  if (new_mul_tensor == nullptr) {
-    auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg);
-    return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg);
-  }
-  return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg);
-}
-
-void ConstantDuplicateMul::Visit(const AnfNodePtr &node) {
-  if (IsValueNode<tensor::Tensor>(node)) {
-    vnode_ = node;
-  }
-
-  if (IsCNode(node) || IsParam(node)) {
-    c_p_node_ = node;
-  }
-}
-
-void ConstantDuplicateMul::Reset() {
-  vnode_ = nullptr;
-  c_p_node_ = nullptr;
-}
-
-AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
-  if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) {
-    return nullptr;
-  }
-
-  auto &inputs = node->cast<CNodePtr>()->inputs();
-  if (!IsValueNode<Scalar>(inputs[2])) {
-    return nullptr;
-  }
-  auto scalar = GetValueNode<ScalarPtr>(inputs[2]);
-  if (scalar->isa<FP32Imm>() && GetValue<float>(scalar) == 1.0) {
-    return inputs[1];
-  } else if (scalar->isa<Int32Imm>() && GetValue<int>(scalar) == 1) {
-    return inputs[1];
-  }
-  return nullptr;
-}
-
-// grad = AllReduce(grad) / worker_number
-// grad = grad + weight * decay
-// ->
-// grad = grad + weight * decay
-// grad = AllReduce(grad) / worker_number
-// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
-// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
-AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
-  Reset();
-  // {prim::kPrimAddN, Zs}
-  if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
-    return nullptr;
-  }
-  auto addn = node->cast<CNodePtr>();
-  if (addn->size() != 2) {
-    return nullptr;
-  }
-  AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
-  if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
-    return nullptr;
-  }
-  auto addn_maketuple = addn->input(1);
-
-  auto fg = all_reduce_fg_;
-  // The AddN inputs may come from another graph; rebuild them in the graph that owns the AllReduce node.
-  if (z_->isa<CNode>() && fg != z_->func_graph()) {
-    auto cnode_z = z_->cast<CNodePtr>();
-    z_ = NewCNode(cnode_z->inputs(), fg);
-  }
-
-  auto addn_op_node = addn->input(0);
-  auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
-
-  AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
-  AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
-  AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
-  AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
-  ProcessDependEdge(fg, addn_maketuple, all_reduce);
-  return mul;
-}
-
-void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
-                                              const AnfNodePtr &new_node) {
-  // If dynamic loss scale is used,
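-  // the scaled gradient (mul_cnode_) also feeds MakeTuple users other than the AddN rewritten
-  // above; re-point those edges to the new AllReduce output so the data dependencies survive.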
- auto &users_map = fg->manager()->node_users(); - auto it = users_map.find(mul_cnode_); - if (it != users_map.end()) { - auto users = it->second; - for (auto &user_pair : users) { - auto node = user_pair.first; - if (node != addn_maketuple) { - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - fg->manager()->SetEdge(node, user_pair.second, new_node); - } - } - } - } -} - -void AdjustAllReduceMulAdd::Visit(const AnfNodePtr &node) { - if (level_ == 0) { - level_ = 1; - is_reduce_match_ = false; - // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y} - AnfVisitor::Match(prim::kPrimMul)(node); - level_ = 0; - if (is_reduce_match_) { - mul_ = node->cast()->input(0); - mul_cnode_ = node->cast(); - y_ = tmp_; - } else { - z_ = node; - } - } - - if (level_ == 1) { - // {prim::kPrimAllReduce, X} - if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { - auto cnode = node->cast(); - if (cnode->size() > 1) { - all_reduce_ = cnode->input(0); - x_ = cnode->input(1); - is_reduce_match_ = true; - all_reduce_fg_ = cnode->func_graph(); - } - } else { - tmp_ = node; - } - } -} - -void AdjustAllReduceMulAdd::Reset() { - level_ = 0; - is_reduce_match_ = false; - x_ = nullptr; - y_ = nullptr; - z_ = nullptr; - tmp_ = nullptr; - all_reduce_fg_ = nullptr; -} - -AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} - -AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h deleted file mode 100644 index f4bdb0d655..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ - -#include -#include -#include - -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -class MultiplyByZeroOrOne : public AnfVisitor { - public: - MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} - ~MultiplyByZeroOrOne() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}, is_one_{false}; - ValuePtr zero_, one_; - AnfNodePtr x_{nullptr}; -}; - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -class CheckTensorConstant { - public: - explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} - ~CheckTensorConstant() = default; - - bool IsTensorConstant(const ValuePtr &value); - bool IsTensorScalarConstant(const ValuePtr &value); - - private: - int check_value_; -}; - -class TensorMultiplyBase : public AnfVisitor { - protected: - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); - - // Make a new tensor (when possible) with the same shape as of `node` - // If x is nullptr then fill new tensor will "0" - // If x is a tensor with empty shape then fill new tensor with the single value of x - // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); - - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -class TensorMultiplyByZero : public TensorMultiplyBase { - public: - TensorMultiplyByZero() : zero_(MakeValue(0)) {} - ~TensorMultiplyByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; -}; - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -class TensorMultiplyByOne : public TensorMultiplyBase { - public: - TensorMultiplyByOne() {} - ~TensorMultiplyByOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_one_{false}; -}; - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -class AddByZero : public AnfVisitor { - public: - AddByZero() : zero_(MakeValue(0)) {} - ~AddByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -class TensorAddByZero : public AnfVisitor { - public: - AnfNodePtr operator()(const 
OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - AnfNodePtr x_{nullptr}; -}; - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -class OptUpdateZeroTensor : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -class ConstantDuplicateMul : public AnfVisitor { - public: - // Support function to multiply two constant tensors: partially support broadcasting shapes - template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size); - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - AnfNodePtr vnode_; - AnfNodePtr c_p_node_; -}; - -class PowerOneEliminate : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// grad = AllReduce(grad) / worker_number -// grad = grad + weight * decy -// -> -// grad = grad + weight * decy -// grad = AllReduce(grad) / worker_number - -// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} -> -// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y} -class AdjustAllReduceMulAdd : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node); - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - int level_{0}; - bool is_reduce_match_{false}; - AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; - AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr}; - FuncGraphPtr all_reduce_fg_{nullptr}; -}; - -class ArithmeticSimplify : public OptimizerCaller { - public: - ArithmeticSimplify() - : multiply_by_zero_or_one_(std::make_shared()), - tensor_multiply_by_one_(std::make_shared()), - add_by_zero_(std::make_shared()), - tensor_add_by_zero_(std::make_shared()), - identity_(std::make_shared(prim::kPrimIdentity)), - opt_update_zero_tensor_(std::make_shared()), - constant_duplicate_mul_(std::make_shared()), - power_one_(std::make_shared()) { - eliminaters_.emplace_back(multiply_by_zero_or_one_); - eliminaters_.emplace_back(tensor_multiply_by_one_); - eliminaters_.emplace_back(add_by_zero_); - eliminaters_.emplace_back(tensor_add_by_zero_); - eliminaters_.emplace_back(identity_); - eliminaters_.emplace_back(opt_update_zero_tensor_); - eliminaters_.emplace_back(constant_duplicate_mul_); - eliminaters_.emplace_back(power_one_); - } - ~ArithmeticSimplify() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr multiply_by_zero_or_one_; - OptimizerCallerPtr tensor_multiply_by_one_; - OptimizerCallerPtr add_by_zero_; - OptimizerCallerPtr tensor_add_by_zero_; - OptimizerCallerPtr identity_; - OptimizerCallerPtr opt_update_zero_tensor_; - OptimizerCallerPtr constant_duplicate_mul_; - 
OptimizerCallerPtr power_one_;
-
-  std::vector<OptimizerCallerPtr> eliminaters_{};
-};
-
-// Arithmetic simplifications that should be done after step_parallel.
-// e.g. Mul(0, weight), where weight is a parameter, is simplified to a constant tensor with
-// weight's shape; but step_parallel may change that shape, so the shape of the constant tensor
-// has to change with it. This pass is therefore separated from ArithmeticSimplify and deferred
-// until after step_parallel.
-class ArithmeticSimplify2 : public OptimizerCaller {
- public:
-  ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
-    eliminaters_.emplace_back(tensor_multiply_by_zero_);
-  }
-  ~ArithmeticSimplify2() = default;
-
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override;
-
- private:
-  OptimizerCallerPtr tensor_multiply_by_zero_;
-  std::vector<OptimizerCallerPtr> eliminaters_{};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc
deleted file mode 100644
index 726f4a28b0..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc
+++ /dev/null
@@ -1,584 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "optimizer/irpass/branch_culling.h"
-
-#include <memory>
-#include <utility>
-#include <unordered_map>
-
-#include "ir/func_graph.h"
-#include "ir/func_graph_cloner.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-namespace internal {
-AnfNodePtr GenerateSwitchNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data,
-                              int switch_idx) {
-  auto switch_node = prim::GetPythonOps("geswitch", "mindspore.ops.functional")->cast<PrimitivePtr>();
-  std::vector<AnfNodePtr> switch_nodes{NewValueNode(switch_node), data, cond};
-  auto switch_apply = graph->NewCNode(switch_nodes);
-  std::vector<AnfNodePtr> tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), switch_apply,
-                                              NewValueNode(MakeValue(switch_idx))};
-  return graph->NewCNode(tuple_getitem_nodes);
-}
-
-AnfNodePtr GenerateSwitchTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
-  return GenerateSwitchNode(graph, cond, data, 1);
-}
-
-AnfNodePtr GenerateSwitchFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) {
-  return GenerateSwitchNode(graph, cond, data, 0);
-}
-
-bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
-  // Inputs of the following primitives, at the listed indices, must not be guarded by a geswitch
-  // node, because the input is an attribute or GE requires it as-is.
-  // Example: when converting CNode(kPrimReduceSum, x, axis), input index 2 (axis) must not be
-  // switch guarded.
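-  // For instance, {prim::kPrimTranspose, {2}} below keeps the permutation constant (input 2) out
-  // of the switch, while the tensor being transposed (input 1) is still routed through geswitch.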
- std::vector>> white_list({{prim::kPrimApplyMomentum, {1, 2}}, - {prim::kPrimMomentum, {2, 3}}, - {prim::kPrimStateSetItem, {1}}, - {prim::kPrimTupleGetItem, {2}}, - {prim::kPrimEnvGetItem, {1}}, - {prim::kPrimEnvSetItem, {1}}, - {prim::kPrimReduceSum, {2}}, - {prim::kPrimReduceMean, {2}}, - {prim::kPrimReduceAll, {2}}, - {prim::kPrimCast, {2}}, - {prim::kPrimTranspose, {2}}, - {prim::kPrimOneHot, {2}}, - {prim::kPrimGatherV2, {3}}, - {prim::kPrimReshape, {2}}, - {prim::kPrimAssign, {1}}, - {prim::kPrimAssignAdd, {1}}, - {prim::kPrimAssignSub, {1}}, - {prim::kPrimTensorSummary, {1}}, - {prim::kPrimImageSummary, {1}}, - {prim::kPrimScalarSummary, {1}}, - {prim::kPrimApplyRMSProp, {6, 7, 8}}, - {prim::kPrimCumSum, {2}}, - {prim::kPrimTile, {2}}, - {prim::kPrimExpandDims, {2}}, - {prim::kPrimHistogramSummary, {1}}}); - for (auto &item : white_list) { - auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { - return IsPrimitiveCNode(node, item.first) && idx == index; - }); - if (matched) { - return true; - } - } - - std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend}; - for (auto &item : adapter_convert_ops) { - if (IsPrimitiveCNode(node, item)) { - return true; - } - } - return false; -} - -using NodeInputReplMap = std::unordered_map, AnfNodePtr, PairHasher>; -// replace the nodes which should be changed -void RunSwitchNodeReplace(const FuncGraphManagerPtr &manager, std::vector> nodes_changed, - std::unordered_map repl_node, NodeInputReplMap repl_node_inputs, - const FuncGraphPtr &func_graph) { - for (auto &node_pair : nodes_changed) { - CNodePtr old_node = node_pair.first; - CNodePtr new_node = node_pair.second; - MS_EXCEPTION_IF_NULL(old_node); - MS_EXCEPTION_IF_NULL(new_node); - for (size_t i = 0; i < old_node->size(); i++) { - auto input = old_node->input(i); - if (repl_node.count(input) != 0) { - new_node->add_input(repl_node[input]); - } else if (repl_node_inputs.count(std::pair(old_node, i)) != 0) { - new_node->add_input(repl_node_inputs[std::pair(old_node, i)]); - } else { - new_node->add_input(input); - } - } - } - - for (auto &item : repl_node) { - if (IsPrimitiveCNode(item.second, prim::kPrimReturn)) { - func_graph->set_output(item.second->cast()->input(1)); - } else if (!manager->Replace(item.first, item.second)) { - MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed original:" << item.first->DebugString(2) - << " to new: " << item.second->DebugString(2); - } - } -} - -// trace the node that should add switch and replace them with new nodes in the graph -FuncGraphPtr TransformGraphCondBranchNodes( - const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::function &generate_func) { - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - // record the node that has been changed - std::vector> nodes_changed; - // record the node to be replaced - std::unordered_map repl_node; - // record the node input to be replaced - NodeInputReplMap repl_node_inputs; - const AnfNodeSet &nodes = graph->nodes(); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto inputs = node->cast()->inputs(); - bool should_replace = false; - // if the apply input does not belong to graph, insert a switch node - for (size_t index = 0; index < inputs.size(); index++) { - auto input_node = inputs[index]; - MS_EXCEPTION_IF_NULL(input_node); - // for some ops input should not guard it with switch - if (InConvertWhiteList(node, index)) { - continue; - } 
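-      // Whitelisted inputs are attributes or GE-required constants; leave them unguarded rather
-      // than wrapping them in the geswitch/tuple_getitem pair built by generate_func.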
- - // If the input for node is not the graph belonged, or it is an ValueNode. - // Bypass the Primitive node which is inputs[0]. - if ((index >= 1 && inputs[index]->func_graph() != nullptr && inputs[index]->func_graph() != graph) || - ((index >= 1 && inputs[index]->isa()))) { - input_node = generate_func(graph, cond, inputs[index]); - repl_node_inputs[std::pair(node, index)] = input_node; - should_replace = true; - } - if (input_node == nullptr) { - MS_LOG(EXCEPTION) << "generate switch node failed"; - } - } - if (should_replace) { - auto new_node = graph->NewCNode(); - repl_node[node] = new_node; - nodes_changed.emplace_back(node->cast(), new_node); - } - } - RunSwitchNodeReplace(manager, nodes_changed, repl_node, repl_node_inputs, graph); - return graph; -} - -struct SharedOp { - tensor::TensorPtr const_data; - CNodePtr square_ops[2]; - CNodePtr merge_ops[2]; -} MergeNetOutput; - -inline tensor::TensorPtr GetConstData() { return MergeNetOutput.const_data; } -inline void SetConstData(const tensor::TensorPtr &const_value) { MergeNetOutput.const_data = const_value; } - -inline CNodePtr GetSquareOp(int switch_idx) { return MergeNetOutput.square_ops[switch_idx]; } -inline void SetSquareOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.square_ops[switch_idx] = op; } - -inline CNodePtr GetMergeOp(int switch_idx) { return MergeNetOutput.merge_ops[switch_idx]; } -inline void SetMergeOp(int switch_idx, const CNodePtr &op) { MergeNetOutput.merge_ops[switch_idx] = op; } - -inline void ResetSharedOp() { - SetConstData(nullptr); - SetSquareOp(0, nullptr); - SetSquareOp(1, nullptr); - SetMergeOp(0, nullptr); - SetMergeOp(1, nullptr); -} - -tensor::TensorPtr ConstData() { - std::vector shp = {1}; - tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); - auto *val = static_cast(const_data->data_c()); - *val = 0; - return const_data; -} - -CNodePtr SquareOp(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, - const tensor::TensorPtr &const_data) { - auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); - // for the depended node , add two const data to merge the flow ,one for depended node with same switch, - // the other use the opposite - auto ctrl_data = NewValueNode(const_data); - auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); - - std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; - auto square_op = graph->NewCNode(square_nodes); - - return square_op; -} - -CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int switch_idx, - const tensor::TensorPtr &const_data, const CNodePtr &square_op) { - // for the depended node , add two const data to merge the flow ,one for depended node with same switch, - // the other use the opposite - auto oppsite_ctrl_data = NewValueNode(const_data); - auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); - - std::vector merge_nodes; - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; - merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); - auto merge_op = graph->NewCNode(merge_nodes); - - return merge_op; -} - -// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) -// control_depend(output_node, square_op) -AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr 
&graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, - int switch_idx) { - tensor::TensorPtr const_data = GetConstData(); - if (const_data == nullptr) { - const_data = ConstData(); - SetConstData(const_data); - } - - CNodePtr square_op = GetSquareOp(switch_idx); - if (square_op == nullptr) { - square_op = SquareOp(graph, cond, switch_idx, const_data); - SetSquareOp(switch_idx, square_op); - } - - CNodePtr merge_op = GetMergeOp(switch_idx); - if (merge_op == nullptr) { - merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); - SetMergeOp(switch_idx, merge_op); - } - - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; - auto control_depend_op = graph->NewCNode(control_depend_nodes); - - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op}; - auto depend_op = graph->NewCNode(depend_nodes); - - return depend_op; -} - -// construct a merge output and add dependency with the netoutput node from control_depend -// we need to reserve the control_depend node, besides the generated merge node and control_depend node -CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst, - int switch_idx) { - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast(); - std::vector shp = {1}; - tensor::TensorPtr const_data = std::make_shared(kInt32->type_id(), shp); - auto *val = static_cast(const_data->data_c()); - *val = 0; - // for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same - // switch the other use the opposite - auto ctrl_data = NewValueNode(const_data); - auto oppsite_ctrl_data = NewValueNode(const_data); - auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx); - auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx); - - std::vector square_nodes{NewValueNode(PrimSquare), ctrl_node}; - auto square_op = graph->NewCNode(square_nodes); - - std::vector merge_nodes; - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node}; - merge_nodes.push_back(graph->NewCNode(make_tuple_nodes)); - auto merge_output = graph->NewCNode(merge_nodes); - - std::vector control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op}; - auto cond_dep_output = graph->NewCNode(control_depend_nodes); - - std::vector depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output, - cond_dep_output}; - return graph->NewCNode(depended_make_tuple_nodes); -} - -// generate switch nodes for true graph node inputs -AnfNodePtr GenerateSwitchDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchDependNode(graph, cond, data, 1); -} - -// generate switch nodes for false graph node inputs -AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &data) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchDependNode(graph, cond, data, 0); -} - -// generate switch nodes for true graph node inputs -CNodePtr 
GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1); -} - -// generate switch nodes for false graph node inputs -CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, - const AnfNodePtr &con_input, const AnfNodePtr &output) { - // for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch - return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0); -} - -// to judge if the node used in ControlDepend is a net output node -bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { - auto uses = manager->node_users()[node]; - bool is_output_node = true; - for (auto &item : uses) { - if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { - continue; - } - is_output_node = false; - break; - } - return is_output_node; -} - -// generate node for Depended MakeTuple -void GenerateReplNodeForDependMakeTuple( - const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::shared_ptr> &repl_node, - const std::function &generate_func, - const std::function &gen_ctl_depd_func) { - MS_EXCEPTION_IF_NULL(graph->manager()); - - auto make_tuple_inputs = depended_node->cast()->inputs(); - const size_t make_tuple_begin_idx = 1; - std::vector new_make_tuple_nodes; - bool replace_make_tuple = false; - new_make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t idx = make_tuple_begin_idx; idx < make_tuple_inputs.size(); idx++) { - auto depended_tuple_input_node = make_tuple_inputs[idx]; - if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimDepend)) { - new_make_tuple_nodes.push_back(depended_tuple_input_node); - continue; - } - if (IsPrimitiveCNode(depended_tuple_input_node->cast(), prim::kPrimControlDepend)) { - // only when the control depend input is not square op (the op to use as merge output) - auto control_inputs = depended_tuple_input_node->cast()->inputs(); - if (control_inputs.size() != 3) { - MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size(); - } - // control inputs: primitive, src, dst - auto dst_node = control_inputs[2]; - if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) { - auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node); - MS_EXCEPTION_IF_NULL(gen_node); - auto tuple_inputs = gen_node->inputs(); - // add depended tuple inputs to new_make_tuple directly - for (size_t i = 1; i < tuple_inputs.size(); i++) { - new_make_tuple_nodes.push_back(tuple_inputs[i]); - } - } - replace_make_tuple = true; - continue; - } - - if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { - auto gen_node = generate_func(graph, cond, depended_tuple_input_node); - new_make_tuple_nodes.push_back(gen_node); - replace_make_tuple = true; - continue; - } - - MS_LOG(WARNING) << "depended node being used by others, "; - } - if (replace_make_tuple) { - auto make_tuple_op = graph->NewCNode(new_make_tuple_nodes); - (*repl_node)[depended_node] = make_tuple_op; - } -} - -// generate a replace depend node for a single network output node -void GenerateRepDepend( - const CNodePtr &node, const FuncGraphPtr &graph, const 
AnfNodePtr &cond,
-                      const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
-                      const std::function<AnfNodePtr(const FuncGraphPtr &, const AnfNodePtr &, const AnfNodePtr &)>
-                          &generate_func,
-                      const std::function<CNodePtr(const FuncGraphPtr &, const AnfNodePtr &, const AnfNodePtr &,
-                                                   const AnfNodePtr &)> &gen_ctl_depd_func) {
-  auto inputs = node->inputs();
-  if (inputs.size() != 3) {
-    MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
-  }
-
-  std::vector<AnfNodePtr> new_depend_inputs;
-  // Inputs are [depend, actual_value, depended_node].
-  auto depended_node = inputs[2];
-  new_depend_inputs.push_back(inputs[0]);
-  new_depend_inputs.push_back(inputs[1]);
-  // The depended node is either a make_tuple or a single node.
-  if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
-    GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func);
-  } else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) {
-    // Only rewrite when the control-depend destination is not a square op (the op used as the merge output).
-    auto control_inputs = depended_node->cast<CNodePtr>()->inputs();
-    // control inputs: primitive, src, dst
-    if (control_inputs.size() != 3) {
-      MS_LOG(EXCEPTION) << "ControlDepend input size != 3, got " << control_inputs.size();
-    }
-    auto dst_node = control_inputs[2];
-    if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
-      auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node);
-      (*repl_node)[depended_node] = gen_node;
-    }
-  } else {
-    // Rewrite only when the depended node has a single user.
-    if (graph->manager()->node_users()[depended_node].size() == 1) {
-      auto gen_node = generate_func(graph, cond, depended_node);
-      (*repl_node)[depended_node] = gen_node;
-    } else {
-      MS_LOG(WARNING) << "Depended node is used by other nodes; skip rewriting it.";
-    }
-  }
-}
-
-// Generate a depend node for the net output node to resolve the stream synchronization problem of GE:
-// traverse all depend nodes, find the graph output node, generate a merge node of (square, const),
-// and add a control_depend between the graph output node and the square node.
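-// The rewritten dependency, per GenerateSwitchDependNode above, has roughly this shape:
-//   depend(merge(make_tuple(square(switch_ctrl), switch_ctrl_opposite)),
-//          control_depend(output_node, square))
-// where switch_ctrl and switch_ctrl_opposite are geswitch projections of the same const tensor.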
-FuncGraphPtr TransformGraphDependNode( - const FuncGraphPtr &graph, const AnfNodePtr &cond, - const std::function &gen_depend_func, - const std::function &gen_ctl_depd_func) { - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - ResetSharedOp(); - std::shared_ptr> repl_node = - std::make_shared>(); // record the node to be replaced - const AnfNodeSet &nodes = graph->nodes(); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - auto cnode = node->cast(); - if (cnode->size() != 3) { - MS_LOG(EXCEPTION) << "Dependnode input size != 3"; - } - auto depended_node = cnode->input(2); - MS_EXCEPTION_IF_NULL(depended_node); - if (!depended_node->isa()) { - continue; - } - if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { - continue; - } - GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); - } - } - ResetSharedOp(); - - for (auto &item : *repl_node) { - if (!manager->Replace(item.first, item.second)) { - MS_LOG(EXCEPTION) << "TransformGraphDependNode replace node failed"; - } - } - - return graph; -} - -FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { - (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); -} - -FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { - (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); - return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); -} - -// judge if the true and false graph output is compatible(they shall have same tuple size) -bool GraphOutputCompatible(const AbstractBasePtr &true_branch_abs, const AbstractBasePtr &false_branch_abs) { - MS_EXCEPTION_IF_NULL(true_branch_abs); - MS_EXCEPTION_IF_NULL(false_branch_abs); - - if (true_branch_abs->isa() && false_branch_abs->isa()) { - abstract::AbstractTuplePtr true_branch_tuple = true_branch_abs->cast(); - abstract::AbstractTuplePtr false_branch_tuple = false_branch_abs->cast(); - if (true_branch_tuple->elements().size() != false_branch_tuple->elements().size()) { - MS_LOG(ERROR) << "true branch size:" << true_branch_tuple->elements().size() - << ", not equal to false banch size:" << false_branch_tuple->elements().size() << " "; - return false; - } - bool all_compatible = true; - for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { - all_compatible = - all_compatible && GraphOutputCompatible(true_branch_tuple->elements()[i], false_branch_tuple->elements()[i]); - } - return all_compatible; - } - TypePtr true_branch_type = true_branch_abs->BuildType(); - TypePtr false_branch_type = false_branch_abs->BuildType(); - MS_LOG(DEBUG) << "branch output Type equal?" 
<< (*true_branch_type == *false_branch_type) - << " true:" << true_branch_type->ToString() << " false:" << false_branch_type->ToString(); - return (*true_branch_type == *false_branch_type); -} - -AnfNodePtr GenerateMergeNodes(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const FuncGraphPtr &switch_graph, - const AnfNodePtr &cond) { - MS_EXCEPTION_IF_NULL(true_graph_output_abs); - MS_EXCEPTION_IF_NULL(false_graph_output_abs); - MS_EXCEPTION_IF_NULL(cond); - MS_EXCEPTION_IF_NULL(switch_graph); - auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast(); - MS_EXCEPTION_IF_NULL(PrimMerge); - - if (!true_graph_output_abs->isa()) { - std::vector merge_nodes; - merge_nodes.push_back(NewValueNode(PrimMerge)); - std::vector make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), true_output_node, false_output_node}; - merge_nodes.push_back(switch_graph->NewCNode(make_tuple_nodes)); - std::vector tuple_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), - switch_graph->NewCNode(merge_nodes), NewValueNode(MakeValue(0))}; - return switch_graph->NewCNode(tuple_getitem_nodes); - } else { - abstract::AbstractTuplePtr true_branch_tuple = true_graph_output_abs->cast(); - abstract::AbstractTuplePtr false_branch_tuple = false_graph_output_abs->cast(); - - std::vector make_tuple_nodes; - make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t i = 0; i < true_branch_tuple->elements().size(); i++) { - std::vector true_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), true_output_node, - NewValueNode(MakeValue(SizeToInt(i)))}; - auto true_node = switch_graph->NewCNode(true_getitem_nodes); - std::vector false_getitem_nodes{NewValueNode(prim::kPrimTupleGetItem), false_output_node, - NewValueNode(MakeValue(SizeToInt(i)))}; - auto false_node = switch_graph->NewCNode(false_getitem_nodes); - - auto merge_node = GenerateMergeNodes(true_node, false_node, true_branch_tuple->elements()[i], - false_branch_tuple->elements()[i], switch_graph, cond); - make_tuple_nodes.push_back(merge_node); - } - return switch_graph->NewCNode(make_tuple_nodes); - } -} - -AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, - const FuncGraphPtr &switch_graph) { - if (!GraphOutputCompatible(true_graph_output_abs, false_graph_output_abs)) { - MS_LOG(EXCEPTION) << "Switch output branch not compatible, true:" << true_graph_output_abs->ToString() - << ", false:" << false_graph_output_abs->ToString(); - } - return GenerateMergeNodes(true_output_node, false_output_node, true_graph_output_abs, false_graph_output_abs, - switch_graph, cond); -} -} // namespace internal -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/optimizer/irpass/branch_culling.h deleted file mode 100644 index 2b5b30bdbf..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ - -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/pattern_matcher.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimSwitch, true, X, Y} -// {prim::kPrimSwitch, false, X, Y} -class SwitchSimplify : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br; - auto SwitchSimplLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto cond_value_ = GetValue(GetValueNode(cond.GetNode(node))); - if (cond_value_) { - return true_br.GetNode(node); - } - return false_br.GetNode(node); - }; - - MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), SwitchSimplLambda, - cond.CheckFunc(IsValueNode, node)); - - return nullptr; - } -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimSwith, X0, X1, X2}, C} => -// {prim::kPrimSwith, X0, {prim::kPrimTupleGetItem, X1, C}, {prim::kPrimTupleGetItem, X2, C}} -class FloatTupleGetItemSwitch : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br, x; - MATCH_REPLACE_IF(node, - PPrimitive(prim::kPrimTupleGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimTupleGetItem, true_br, x), - PPrimitive(prim::kPrimTupleGetItem, false_br, x)), - x.CheckFunc(IsVNode, node)); - return nullptr; - } -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimSwitch, X1, X2, X3}, X4, X5} => -// {prim::kPrimSwitch, X1, {prim::kPrimEnvGetItem, X2, X4, X5}, {prim::kPrimEnvGetItem, X3, X4, X5}} -class FloatEnvGetItemSwitch : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode cond, true_br, false_br, x, x2; - MATCH_REPLACE(node, - PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), x, x2), - PPrimitive(prim::kPrimSwitch, cond, PPrimitive(prim::kPrimEnvGetItem, true_br, x, x2), - PPrimitive(prim::kPrimEnvGetItem, false_br, x, x2))); - - return nullptr; - } -}; - -namespace internal { -FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); -FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond); -AnfNodePtr TransformMergeBranches(const AnfNodePtr &true_output_node, const AnfNodePtr &false_output_node, - const AbstractBasePtr &true_graph_output_abs, - const AbstractBasePtr &false_graph_output_abs, const AnfNodePtr &cond, - const FuncGraphPtr &func_graph); -} // namespace internal - -// {{prim::kPrimSwitch, X, G1, G2}, Xs} -class ConvertSwitchReplacement : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if 
(!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode_ = node->cast(); - if (cnode_->size() < 1) { - return nullptr; - } - - auto node_ = cnode_->input(0); - - PatternNode cond, true_br, false_br; - - auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node_)); - auto g2_ = GetValueNode(false_br.GetNode(node_)); - auto x_ = cond.GetNode(node_); - - // for switch replace method, only graphs without graph inside can be replaced - for (auto &item : g1_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; - } - } - - for (auto &item : g2_->value_nodes()) { - auto value_node = item.first; - if (IsValueNode(value_node)) { - return nullptr; - } - } - - auto true_output = g1_->output()->abstract(); - auto false_output = g2_->output()->abstract(); - auto trans_g1 = internal::TransformGraphCondTrueBranchNodes(g1_, x_); - auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); - - std::vector params; - auto fg = node_->func_graph(); - auto cloned_g1 = InlineClone(trans_g1, fg, params); - auto cloned_g2 = InlineClone(trans_g2, fg, params); - auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); - - return nnode; - }; - - MATCH_REPLACE_LAMBDA_IF( - node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); - - return nullptr; - } -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_BRANCH_CULLING_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc deleted file mode 100644 index a497f3d5bd..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "optimizer/irpass/cast_eliminate.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "ir/func_graph.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimCast, X, T} -AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimCast, {IsNode, IsVNode})(node); - - // check pattern match - if (tgt_ == nullptr) { - return nullptr; - } - - // src type check - auto src_type = src_->Type(); - if (src_type == nullptr || !src_type->isa()) { - return nullptr; - } - - src_type = src_type->cast()->element(); - - // tgt type check - auto tgt_type = GetValueNode(tgt_); - if (tgt_type->isa()) { - tgt_type = tgt_type->cast()->element(); - } - - if (src_type->type_id() == tgt_type->type_id()) { - return src_; - } - - return nullptr; -} - -void CastSameTypeEliminater::Visit(const AnfNodePtr &node) { - if (src_ == nullptr) { - src_ = node; - } else { - tgt_ = node; - } -} - -// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -AnfNodePtr TwoCastEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimCast, {IsCNode, IsNode})(node); - - if (x_ != nullptr && t_ != nullptr) { - auto cast_op = parse::python_adapter::GetPyFn("mindspore.ops.operations", "Cast")(); - ValuePtr cast = parse::data_converter::PyDataToValue(cast_op); - auto cnode = NewCNode({NewValueNode(cast), x_, t_}, node->func_graph()); - cnode->set_abstract(node->abstract()); - return cnode; - } - return nullptr; -} - -void TwoCastEliminater::Visit(const AnfNodePtr &node) { - if (IsPrimitiveCNode(node, prim::kPrimCast)) { - auto cnode = node->cast(); - // {prim::kPrimCast, X, Y} - if (cnode->size() != 3) { - return; - } - x_ = cnode->input(1); - } else { - t_ = node; - } -} -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h deleted file mode 100644 index d98d0b677b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ - -#include "ir/visitor.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimCast, X, T} -class CastSameTypeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override; - void Reset() { - src_ = nullptr; - tgt_ = nullptr; - } - - private: - AnfNodePtr src_{nullptr}, tgt_{nullptr}; -}; - -// {prim::kPrimCast, {prim::kPrimCast, X, Y}, T} -class TwoCastEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - void Visit(const AnfNodePtr &node) override; - void Reset() { - x_ = nullptr; - t_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, t_{nullptr}; -}; - -class CastEliminater : public OptimizerCaller { - public: - CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} - ~CastEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - auto new_node = cast_same_type_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - new_node = two_cast_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - return nullptr; - } - - private: - CastSameTypeEliminater cast_same_type_eliminater_; - TwoCastEliminater two_cast_eliminater_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/convert.h b/mindspore/ccsrc/optimizer/irpass/convert.h deleted file mode 100644 index 3049bafb1e..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/convert.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
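[editor's sketch] The CastEliminater above is a first-match dispatcher: it tries CastSameTypeEliminater, then TwoCastEliminater, and returns the first non-null rewrite. A minimal standalone illustration of that shape, with hypothetical Node/Rewrite stand-ins rather than MindSpore's AnfNodePtr/OptimizerCaller types:

    #include <functional>
    #include <optional>
    #include <string>
    #include <vector>

    using Node = std::string;  // stand-in for AnfNodePtr
    // A rewrite returns the replacement node, or nullopt when its pattern misses.
    using Rewrite = std::function<std::optional<Node>(const Node &)>;

    // Try each eliminator in registration order; the first hit wins.
    std::optional<Node> ApplyFirst(const std::vector<Rewrite> &rewrites, const Node &n) {
      for (const auto &r : rewrites) {
        if (auto out = r(n)) {
          return out;
        }
      }
      return std::nullopt;
    }

Registration order matters: the cheap same-type check runs before the cast-of-cast rewrite, mirroring the member order in the class.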
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
-
-#include <vector>
-
-#include "optimizer/optimizer.h"
-#include "optimizer/irpass.h"
-#include "ir/visitor.h"
-#include "ir/func_graph.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// {prim::kPrimPrint, Xs} -> {prim::kPrimPrint, {prim::kPrimMakeTuple, Xs}}
-class PrintTupleWrapper : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    if (!IsPrimitiveCNode(node, prim::kPrimPrint)) {
-      return nullptr;
-    }
-
-    // already in the form {prim::kPrimPrint, {prim::kPrimMakeTuple, Xs}}
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode->size() == 2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
-      return nullptr;
-    }
-
-    std::vector<AnfNodePtr> args;
-    args.push_back(NewValueNode(prim::kPrimMakeTuple));
-
-    // {prim::kPrimPrint, Xs}
-    auto &inputs = cnode->inputs();
-    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
-
-    // {prim::kPrimMakeTuple, Xs}
-    auto fg = node->func_graph();
-    auto tuple = NewCNode(args, fg);
-    auto print = GetValueNode<PrimitivePtr>(cnode->input(0));
-    return NewCNode({NewValueNode(print), tuple}, fg);
-  }
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CONVERT_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
deleted file mode 100644
index 3f100dcaec..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/env_item_eliminate.h
+++ /dev/null
@@ -1,364 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ - -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "utils/symbolic.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class EnvGetitemTransform { - public: - EnvGetitemTransform() : cache_() {} - ~EnvGetitemTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - auto hash_key = std::make_pair(key, default_node); - if (cache.find(hash_key) == cache.end()) { - std::ostringstream ss("env", std::ostringstream::app); - if (key->node() != nullptr) { - ss << key->node()->ToString(); - } - - auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); - auto env = new_fg->output(); - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; - } - - env = inputs[1]; - auto value = inputs[3]; - auto key2 = GetValueNode(inputs[2]); - if (*key2 == *key) { - new_fg->set_output(value); - cache[hash_key] = new_fg; - cache_[fg] = cache; - return new_fg; - } - } - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); - cache[hash_key] = new_fg; - } - - return cache[hash_key]; - } - - private: - std::unordered_map, FuncGraphPtr, PairHasher>> - cache_; -}; -} // namespace internal - -// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y -class NewEnvGetItem : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - auto gety = [this](const AnfNodePtr &node) -> bool { - this->y_ = node; - return true; - }; - - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); - if (env_ != nullptr && env_->Len() == 0) { - return y_; - } - return nullptr; - } - - void Visit(const ValueNodePtr &vnode) override { - if (env_ == nullptr) { - env_ = GetValueNode(vnode); - } - } - - void Reset() { - y_ = nullptr; - env_ = nullptr; - } - - private: - AnfNodePtr y_{nullptr}; - EnvInstancePtr env_{nullptr}; -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> -// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} -class AddEnvGetItem : public AnfVisitor { - public: - AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} - ~AddEnvGetItem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsAddCNode = [](const AnfNodePtr &node) -> bool { - return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); - - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Z} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto c = cnode->input(2); - auto z = cnode->input(3); - - // 
{prim::kPrimEnvAdd, X, Y} - auto x = inp1->input(1); - auto y = inp1->input(2); - - auto fg = node->func_graph(); - auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); - auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); - - return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - ValuePtr PrimHyperAdd_; -}; - -// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} -class EnvGetSetItem : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsSetCNode = [](const AnfNodePtr &node) -> bool { - if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { - return false; - } - - // {prim::kPrimEnvSetItem, X, C1, Y} - auto &inputs = node->cast()->inputs(); - if (inputs.size() != 4) { - return false; - } - - return IsValueNode(inputs[2]); - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); - - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C2, Z} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key2 = cnode->input(2); - auto c2 = GetValueNode(key2); - auto default_v = cnode->input(3); - - // {prim::kPrimEnvSetItem, X, C1, Y} - auto env = inp1->input(1); - auto c1 = GetValueNode(inp1->input(2)); - auto last_set = inp1->input(3); - - if (*c1 == *c2) { - return last_set; - } - - while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { - // {prim::kPrimEnvSetItem, env, symbolickey, value} - auto &inputs = env->cast()->inputs(); - if (inputs.size() != 4 || !IsValueNode(inputs[2])) { - MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; - } - - env = inputs[1]; - last_set = inputs[3]; - auto symbolic_c1 = GetValueNode(inputs[2]); - if (*symbolic_c1 == *c2) { - return last_set; - } - } - - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; -}; - -class EnvGetItemEliminater : public OptimizerCaller { - public: - EnvGetItemEliminater() - : new_env_get_item_(std::make_shared()), - add_env_get_item_(std::make_shared()), - env_get_set_item_(std::make_shared()) { - eliminaters_.emplace_back(new_env_get_item_); - eliminaters_.emplace_back(add_env_get_item_); - eliminaters_.emplace_back(env_get_set_item_); - } - ~EnvGetItemEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_; - std::vector eliminaters_{}; -}; - -// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} -class IncorporateEnvGetitem : public AnfVisitor { - public: - IncorporateEnvGetitem() : env_get_item_transform_() {} - ~IncorporateEnvGetitem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsGCNode = [](const AnfNodePtr &node) -> bool { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() < 1) { - return false; - } - return IsValueNode(cnode->input(0)); - }; - 
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); - - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Y} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key = GetValueNode(cnode->input(2)); - auto default_v = cnode->input(3); - - // {G, Xs} - auto inputs = inp1->inputs(); - auto fg = GetValueNode(inputs[0]); - auto new_fg = env_get_item_transform_(fg, key, default_v); - - std::vector args; - args.push_back(NewValueNode(new_fg)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - return node->func_graph()->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; -}; - -// {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} -class IncorporateEnvGetitemSwitch : public AnfVisitor { - public: - IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} - ~IncorporateEnvGetitemSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - is_match_ = false; - auto IsSwNode = [](const AnfNodePtr &node) -> bool { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() < 1) { - return false; - } - - return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); - }; - AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - // {prim::kPrimEnvGetItem, {...}, C, Y} - auto cnode = node->cast(); - auto inp1 = cnode->input(1)->cast(); - auto key = GetValueNode(cnode->input(2)); - auto default_v = cnode->input(3); - - // {{prim::kPrimSwitch, X, G1, G2}, Xs} - auto inputs = inp1->inputs(); - is_match_ = false; - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs[0]); - if (!is_match_) { - return nullptr; - } - - // {prim::kPrimSwitch, X, G1, G2} - auto sw = inputs[0]->cast(); - auto x = sw->input(1); - auto g1 = GetValueNode(sw->input(2)); - auto g2 = GetValueNode(sw->input(3)); - auto new_g1 = env_get_item_transform_(g1, key, default_v); - auto new_g2 = env_get_item_transform_(g2, key, default_v); - - auto fg = node->func_graph(); - auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); - - std::vector args{new_sw}; - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - return fg->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override { is_match_ = true; } - - private: - bool is_match_{false}; - internal::EnvGetitemTransform env_get_item_transform_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc deleted file mode 100644 index 317d67e792..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc +++ /dev/null @@ -1,143 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
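[editor's sketch] IncorporateEnvGetitemSwitch above pushes an env_getitem through a switch by specializing both branch graphs and re-selecting. The underlying identity is f(switch(c, g1, g2)) == switch(c, f∘g1, f∘g2); a runnable sketch with plain std::function values (toy integer functions, not the ANF machinery):

    #include <functional>

    using Fn = std::function<int(int)>;
    using Pred = std::function<bool(int)>;

    // Rewrite f(select(c, g1, g2)) into select(c, f.g1, f.g2): compose the
    // projection into each branch once, ahead of the runtime selection.
    Fn IncorporateIntoBranches(Fn f, Fn g1, Fn g2, Pred cond) {
      Fn fg1 = [f, g1](int x) { return f(g1(x)); };
      Fn fg2 = [f, g2](int x) { return f(g2(x)); };
      return [fg1, fg2, cond](int x) { return cond(x) ? fg1(x) : fg2(x); };
    }

In the pass, the "composition" is the cached env_get_item_transform_ clone of each branch graph, so repeated getitems on the same branch reuse one specialized graph.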
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "optimizer/irpass/grad_var_prepare.h" -#include -#include -#include -#include - -#include "operator/composite/composite.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -namespace irpass { -static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, FuncGraphPtr func_graph, - AnfNodePtr func_node, bool is_unpack, bool sens_param) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(func_node); - std::vector nodes; - AnfNodePtr unpack_graph_node = nullptr; - if (is_unpack) { - auto unpack_graph = std::make_shared("unpack_graph", sens_param, true); - nodes.push_back(NewValueNode(unpack_graph)); - nodes.push_back(func_node); - // {unpackcall, {GradOperation, ...}, args...} - std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr &node) { return node; }); - unpack_graph_node = func_graph->NewCNode(nodes); - } else { - auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); - nodes.push_back(NewValueNode(unpack_graph)); - nodes.push_back(func_node); - // {{GradOperation, ...}, args...} - std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr &node) { return node; }); - unpack_graph_node = func_graph->NewCNode(nodes); - } - return unpack_graph_node; -} - -// get metagraph of value node -MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { - ValuePtr value; - if (IsValueNode(node)) { - value = GetValueNode(node)->cast()->function(); - } else { - value = GetValueNode(node); - } - if (value == nullptr) { - return nullptr; - } - return value->cast(); -} - -// check if node is a specific metafuncgraph op -bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { - if (node != nullptr) { - auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); - if (meta_func_graph_ptr == nullptr) { - return false; - } - - if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { - return true; - } - } - return false; -} - -// {{GradOperation, g, w}, Ys} -// {UnPackCall, {GradOperation, g, w}, Ys} -AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - // {{...}, Ys} - auto inputs_y = node->cast()->inputs(); - std::vector inputs_x; - if (IsCNode(inputs_y[0])) { - inputs_x = inputs_y[0]->cast()->inputs(); - } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { - inputs_x = inputs_y[1]->cast()->inputs(); - } else { - return nullptr; - } - - // {{...}, Xs} - if (inputs_x.size() < 2) { - return nullptr; - } - - // {GradOperation, g, w} or {GradOperation, g} - if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { - return nullptr; - } - - auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); - if (meta_func == nullptr) { - return nullptr; - } - auto grad_op_ptr = meta_func->cast(); - auto 
func_node = inputs_x[1];
-  if (!IsValueNode<FuncGraph>(func_node)) {
-    return nullptr;
-  }
-
-  AnfNodePtr unpack_graph_node =
-    GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
-                            IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
-  // construct the new GradOperation call
-  inputs_x[1] = unpack_graph_node;
-  auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
-  if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
-    inputs_y[1] = grad_op_cnode;
-  } else {
-    inputs_y[0] = grad_op_cnode;
-  }
-  auto cnode = node->func_graph()->NewCNode(inputs_y);
-  return cnode;
-}
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h
deleted file mode 100644
index 9713017d12..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h
+++ /dev/null
@@ -1,54 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include <algorithm>
-
-#include "operator/composite/composite.h"
-#include "operator/ops.h"
-#include "optimizer/irpass.h"
-#include "optimizer/optimizer.h"
-#include "ir/visitor.h"
-#include "ir/func_graph.h"
-#include "ir/func_graph_cloner.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// {{GradOperation, g, w}, Ys}
-// {UnPackCall, {GradOperation, g, w}, Ys}
-class GradVarPrepare : public AnfVisitor {
- public:
-  GradVarPrepare()
-      : grad_op_(std::make_shared<prim::GradOperation>("grad")),
-        unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
-  ~GradVarPrepare() override = default;
-
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
-
- private:
-  MetaFuncGraphPtr grad_op_;
-  MetaFuncGraphPtr unpack_op_;
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc
deleted file mode 100644
index 3347fa9dc0..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "optimizer/irpass/gradient_eliminate.h" - -#include - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { - ScopeGuard scope_guard(vnode->scope()); - - auto newg = ad::Kprim(vnode, resource); - if (newg != nullptr) { - return NewValueNode(newg); - } - - // when find in J failed, try in Jmeta - auto prim = GetValueNode(vnode); - MetaFuncGraphPtr meta = ad::Kmeta(prim, resource); - if (meta != nullptr) { - return NewValueNode(meta); - } - - return nullptr; -} - -bool CheckIfEmbedJFuncGraph(const FuncGraphPtr func_graph) { - // if func graph also contain J FuncGraph, then ignore this funcgraph. ExpandJ innermost graph first; - auto func_graph_manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(func_graph_manager); - return func_graph_manager->func_graph_j_total(func_graph); -} - -AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource) { - if (IsValueNode(vnode)) { - ScopeGuard scope_guard(vnode->scope()); - - auto func_graph = GetValueNode(vnode); - MS_LOG(DEBUG) << "Node is ValueNodeGraph, graph: " << func_graph->ToString(); - - // high_order_grad begin; - // if graph also contain J Graph, then ignore this graph. ExpandJ innermost graph first; - if (CheckIfEmbedJFuncGraph(func_graph)) { - MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " contains J(funcgraph), will expandJ later"; - return nullptr; - } - // high_order_grad end; - - MS_LOG(DEBUG) << "Funcgraph: " << func_graph->ToString() << " will expandJ now"; - auto newfg = ad::Grad(func_graph, resource); - return NewValueNode(newfg); - } - - if (IsValueNode(vnode)) { - return ExpandJPrimitive(vnode, resource); - } - - return nullptr; -} -} // namespace internal -} // namespace irpass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h b/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h deleted file mode 100644 index 671d9bde49..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/gradient_eliminate.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ - -#include -#include -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "common/utils.h" -#include "operator/ops.h" -#include "optimizer/ad/grad.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const pipeline::ResourceBasePtr &resource); -} // namespace internal - -// {prim::kPrimJ, C} -class ExpandJPrim : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimJ, {IsVNode})(node); - if (x_ != nullptr) { - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - auto j_node = internal::ExpandJ(x_, optimizer->resource()); - TraceManager::EndTrace(); - return j_node; - } - return nullptr; - } - - void Visit(const ValueNodePtr &node) override { x_ = node; } - - private: - ValueNodePtr x_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRADIENT_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_call.h b/mindspore/ccsrc/optimizer/irpass/incorporate_call.h deleted file mode 100644 index 5842b7bfd6..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_call.h +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
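[editor's sketch] ExpandJPrim/ExpandJ above deliberately skip any func graph that still embeds a J and leave it for a later optimizer round, so gradient expansion proceeds innermost-first. A toy model of that fixpoint ordering, with a hypothetical Node type and the actual expansion (ad::Grad) stubbed out:

    #include <vector>

    struct Node {
      bool is_j = false;
      bool expanded = false;
      std::vector<Node> children;
    };

    bool HasEmbeddedJ(const Node &n) {
      for (const auto &c : n.children) {
        if (c.is_j || HasEmbeddedJ(c)) return true;
      }
      return false;
    }

    // One optimizer round: expand a J node only when nothing beneath it is a J.
    // Iterating rounds to a fixpoint yields innermost-first expansion order.
    bool ExpandOneRound(Node *n) {
      bool changed = false;
      for (auto &c : n->children) changed = ExpandOneRound(&c) || changed;
      if (n->is_j && !n->expanded && !HasEmbeddedJ(*n)) {
        n->expanded = true;  // stand-in for ad::Grad(func_graph, resource)
        changed = true;
      }
      return changed;
    }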
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ - -#include -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class CallOutputTransform { - public: - CallOutputTransform() : cache_() {} - ~CallOutputTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - if (cache.find(nargs) == cache.end()) { - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("call")); - - std::vector new_items; - new_items.push_back(new_fg->output()); - for (size_t i = 0; i < nargs; i++) { - new_items.push_back(new_fg->add_parameter()); - } - new_fg->set_output(new_fg->NewCNode(new_items)); - - cache[nargs] = new_fg; - } - return cache[nargs]; - } - - private: - std::unordered_map> cache_; -}; -} // namespace internal - -// {{G, Xs}, Ys} -class IncorporateCall : public AnfVisitor { - public: - IncorporateCall() : call_output_transform_() {} - ~IncorporateCall() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs[0] == nullptr || !inputs[0]->isa()) { - return nullptr; - } - - AnfVisitor::Visit(inputs[0]); - if (fg_ == nullptr) { - return nullptr; - } - - auto xs_size = Xs_.size(); - auto ys_size = inputs.size() - 1; - auto new_fg = call_output_transform_(fg_, ys_size); - - std::vector args; - args.push_back(NewValueNode(new_fg)); - - if (xs_size > 0) { - (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); - } - - if (ys_size > 0) { - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - } - - return node->func_graph()->NewCNode(args); - } - - void Visit(const CNodePtr &cnode) override { - // {G, Xs} - if (cnode->size() < 1 || !IsValueNode(cnode->input(0))) { - return; - } - - auto &inputs = cnode->inputs(); - fg_ = GetValueNode(inputs[0]); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - - void Reset() { - Xs_.clear(); - fg_ = nullptr; - } - - private: - FuncGraphPtr fg_; - std::vector Xs_{}; - internal::CallOutputTransform call_output_transform_; -}; - -// {{{prim::kPrimSwitch, X, G1, G2}, Xs}, Ys} -class IncorporateCallSwitch : public AnfVisitor { - public: - IncorporateCallSwitch() : call_output_transform_() {} - ~IncorporateCallSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - // {{...}, Ys} - auto &inputs = node->cast()->inputs(); - if (inputs[0] == nullptr || !inputs[0]->isa()) { - return nullptr; - } - - // {{...}, Xs} - auto &inputs_x = inputs[0]->cast()->inputs(); - if (inputs_x[0] == nullptr || !inputs_x[0]->isa()) { - return nullptr; - } - - // {prim::kPrimSwitch, X, G1, G2} - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs_x[0]); - if (g2_ == nullptr) { - return nullptr; - } - - auto fg = node->func_graph(); - auto xs_size = inputs_x.size() - 1; - auto ys_size = inputs.size() - 1; - auto new_g1 = 
call_output_transform_(g1_, ys_size); - auto new_g2 = call_output_transform_(g2_, ys_size); - auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); - - std::vector args{sw_node}; - if (xs_size > 0) { - (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); - } - if (ys_size > 0) { - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - } - - return fg->NewCNode(args); - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - return; - } - AnfVisitor::Visit(node); - } - - void Visit(const ValueNodePtr &vnode) override { - auto g = GetValueNode(vnode); - if (g1_ == nullptr) { - g1_ = g; - } else { - g2_ = g; - } - } - - void Reset() { - x_ = nullptr; - g1_ = nullptr; - g2_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}; - FuncGraphPtr g1_{nullptr}, g2_{nullptr}; - internal::CallOutputTransform call_output_transform_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_CALL_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h deleted file mode 100644 index b6c8fb0e18..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/incorporate_getitem.h +++ /dev/null @@ -1,416 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ - -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class GetitemTransform { - public: - GetitemTransform() : cache_() {} - ~GetitemTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &fg, int idx) { - if (cache_.find(fg) == cache_.end()) { - cache_[fg] = {}; - } - - auto &cache = cache_[fg]; - if (cache.find(idx) == cache.end()) { - std::ostringstream ss("tp", std::ostringstream::app); - ss << idx; - - auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); - auto output = new_fg->output(); - if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto cnode = output->cast(); - auto ids = IntToSize(idx + 1); - // Inputs should be [make_tuple, item1, item2, ...], so have to offset idx in tuple_getitem by 1. 
- if (ids >= cnode->size()) { - MS_LOG(EXCEPTION) << "index " << ids << " is out of inputs length " << cnode->size(); - } - new_fg->set_output(cnode->input(ids)); - } else { - new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(idx)})); - } - - cache[idx] = new_fg; - } - return cache[idx]; - } - - private: - std::unordered_map> cache_; -}; -} // namespace internal - -// {prim::kPrimTupleGetItem, {G, Xs}, C} -class IncorporateGetitem : public AnfVisitor { - public: - IncorporateGetitem() : getitem_transform_() {} - ~IncorporateGetitem() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - if (node->func_graph() == nullptr || idx_ == -1 || fg_ == nullptr) { - return nullptr; - } - - if (fg_->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // If graph kernel has muti output, do not split. - // some graph kernel output has EnvInstance node or DeadCode node should split. - auto output = fg_->output(); - if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { - auto output_cnode = output->cast(); - auto outputs = output_cnode->inputs(); - int real_output_cnt = 0; - for (size_t i = 1; i < outputs.size(); ++i) { - if (IsCNode(outputs[i]) || IsValueNode(outputs[i]) || IsParam(outputs[i])) { - real_output_cnt++; - if (real_output_cnt > 1) { - return nullptr; - } - } - } - } - } - - auto new_fg = getitem_transform_(fg_, idx_); - (void)args_.insert(args_.begin(), NewValueNode(new_fg)); - return node->func_graph()->NewCNode(args_); - } - - void Visit(const CNodePtr &cnode) override { - if (cnode->size() == 0 || !IsValueNode(cnode->input(0))) { - return; - } - - auto &inputs = cnode->inputs(); - fg_ = GetValueNode(inputs[0]); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); - } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Reset() { - idx_ = -1; - fg_ = nullptr; - args_.clear(); - } - - private: - int idx_{-1}; - FuncGraphPtr fg_{nullptr}; - std::vector args_{}; - internal::GetitemTransform getitem_transform_; -}; - -class IncorporateGetitemFromParam : public AnfVisitor { - public: - void Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr ¶m, size_t input_idx) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto &node_users = mng->node_users(); - if (node_users.find(param) == node_users.end() || node_users[param].empty()) { - args_.push_back(cnode->input(input_idx + 1)); - return; - } - - for (auto &user : node_users[param]) { - if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { - // we do not process this case. - args_.push_back(cnode->input(input_idx + 1)); - return; - } - } - - // update new args. 
- if (IsPrimitiveCNode(cnode->input(input_idx + 1), prim::kPrimMakeTuple)) { - // case 1 - replace_parameters_[input_idx] = true; - need_update_ = true; - auto make_tuple_cnode = cnode->input(input_idx + 1)->cast(); - auto &make_tuple_cnode_inputs = make_tuple_cnode->inputs(); - inputs_num_[input_idx] = make_tuple_cnode_inputs.size() - 1; - args_.insert(args_.end(), make_tuple_cnode_inputs.begin() + 1, make_tuple_cnode_inputs.end()); - } else { - // case 2 - auto prev_cnode = cnode->input(input_idx + 1)->cast(); - auto prev_fg = GetValueNode(prev_cnode->input(0)); - auto fg_output = prev_fg->output(); - if (!IsPrimitiveCNode(fg_output, prim::kPrimMakeTuple)) { - MS_LOG(ERROR) << "The return of: " << prev_fg->ToString() - << " should be a make tuple, but got: " << fg_output->DebugString(); - return; - } - replace_parameters_[input_idx] = true; - need_update_ = true; - auto make_tuple_cnode = fg_output->cast(); - inputs_num_[input_idx] = make_tuple_cnode->inputs().size() - 1; - for (size_t output_i = 0; output_i < inputs_num_[input_idx]; ++output_i) { - auto new_getitem = - func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), prev_cnode, NewValueNode(SizeToInt(output_i))}); - auto aptr = std::make_shared(std::make_shared(SizeToInt(output_i))); - new_getitem->input(2)->set_abstract(aptr); - new_getitem->set_abstract(make_tuple_cnode->input(output_i + 1)->abstract()); - args_.push_back(new_getitem); - } - } - } - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (node->func_graph() == nullptr) { - return nullptr; - } - - Reset(); - - auto cnode = node->cast(); - if (cnode == nullptr) { - return nullptr; - } - auto &inputs = cnode->inputs(); - auto fg = GetValueNode(inputs[0]); - if (fg == nullptr) { - return nullptr; - } - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto parameters = fg->parameters(); - if (parameters.size() != inputs.size() - 1) { - return nullptr; - } - replace_parameters_ = std::vector(parameters.size(), false); - inputs_num_ = std::vector(parameters.size(), 1); - auto node_fg = node->func_graph(); - - for (size_t i = 1; i < inputs.size(); ++i) { - if (IsPrimitiveCNode(inputs[i], prim::kPrimMakeTuple) || IsCNodeGraphKernel(inputs[i])) { - Process(node_fg, cnode, parameters[i - 1], i - 1); - } else { - args_.push_back(inputs[i]); - } - } - - if (!need_update_) { - return nullptr; - } - - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - auto node_users = mng->node_users(); - std::vector new_fg_parameters = new_fg->parameters(); - std::vector new_parameters; - size_t curr_input_idx{0}; - for (size_t param_i = 0; param_i < new_fg_parameters.size(); ++param_i) { - if (!replace_parameters_[param_i]) { - if (parameters[param_i]->abstract() != nullptr) { - new_fg_parameters[param_i]->set_abstract(parameters[param_i]->abstract()); - } - new_parameters.push_back(new_fg_parameters[param_i]); - curr_input_idx++; - continue; - } - - // make a new parameter. - for (size_t input_i = 0; input_i < inputs_num_[param_i]; ++input_i) { - auto new_param = std::make_shared(new_fg); - new_param->set_abstract(args_.at(curr_input_idx)->abstract()); - - // update users of new parameter. 
- for (auto &user : node_users[new_fg_parameters[param_i]]) { - idx_ = -1; - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsParam, IsValueNode})(user.first); - if (idx_ == -1) { - MS_LOG(ERROR) << "User of: " << new_fg_parameters[param_i]->DebugString() - << " must be tuple getitem here, but got: " << user.first->DebugString(); - return nullptr; - } - - if (input_i == IntToSize(idx_)) { - for (auto &sub_user : node_users[user.first]) { - auto sub_user_cnode = sub_user.first->cast(); - MS_EXCEPTION_IF_NULL(sub_user_cnode); - sub_user_cnode->set_input(sub_user.second, new_param); - (void)mng->Replace(sub_user.first, sub_user_cnode); - } - } - } - - // (void)mng->Replace(new_fg_parameters[param_i], new_param); - new_parameters.push_back(new_param); - curr_input_idx++; - } - } - - mng->SetParameters(new_fg, new_parameters); - (void)args_.insert(args_.begin(), NewValueNode(new_fg)); - auto new_call = node_fg->NewCNode(args_); - new_call->set_abstract(node->abstract()); - return new_call; - } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Visit(const CNodePtr &cnode) override {} - - void Reset() { - replace_parameters_.clear(); - args_.clear(); - inputs_num_.clear(); - need_update_ = false; - idx_ = -1; - } - - private: - std::vector replace_parameters_{}; - std::vector args_{}; - std::vector inputs_num_{}; - bool need_update_{false}; - int idx_{-1}; -}; - -// {prim::kPrimTupleGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C} -class IncorporateGetitemSwitch : public AnfVisitor { - public: - IncorporateGetitemSwitch() : getitem_transform_() {} - ~IncorporateGetitemSwitch() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - is_in_get_ = true; - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - is_in_get_ = false; - - auto fg = node->func_graph(); - if (idx_ == -1 || switch_ == nullptr || fg == nullptr) { - return nullptr; - } - - is_in_switch_ = true; - AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(switch_); - is_in_switch_ = false; - - if (g2_ == nullptr) { - return nullptr; - } - - auto new_g1 = getitem_transform_(g1_, idx_); - auto new_g2 = getitem_transform_(g2_, idx_); - auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); - (void)args_.insert(args_.begin(), sw_node); - - return fg->NewCNode(args_); - } - - void Visit(const AnfNodePtr &node) override { - if (is_in_switch_ && x_ == nullptr) { - x_ = node; - return; - } - AnfVisitor::Visit(node); - } - - void Visit(const CNodePtr &cnode) override { - if (is_in_get_ && cnode->size() != 0) { - auto &inputs = cnode->inputs(); - switch_ = inputs[0]; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args_)); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (is_in_get_) { - idx_ = GetValue(vnode->value()); - } - - if (is_in_switch_) { - auto g = GetValueNode(vnode); - if (g1_ == nullptr) { - g1_ = g; - } else { - g2_ = g; - } - } - } - - void Reset() { - x_ = nullptr; - g1_ = nullptr; - g2_ = nullptr; - switch_ = nullptr; - args_.clear(); - is_in_get_ = false; - is_in_switch_ = false; - } - - private: - int idx_{-1}; - AnfNodePtr switch_{nullptr}, x_{nullptr}; - FuncGraphPtr g1_{nullptr}, g2_{nullptr}; - bool is_in_get_{false}, is_in_switch_{false}; - std::vector args_{}; - internal::GetitemTransform getitem_transform_; -}; - -class IncorporateGetitemSet : public OptimizerCaller { - public: 
- IncorporateGetitemSet() - : incorporate_getitem_(std::make_shared()), - incorporate_getitem_switch_(std::make_shared()) { - eliminaters_.emplace_back(incorporate_getitem_); - eliminaters_.emplace_back(incorporate_getitem_switch_); - } - ~IncorporateGetitemSet() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_; - std::vector eliminaters_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h deleted file mode 100644 index 630d567549..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ - -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}} -// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}} -class IndexedSlicesEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(1); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(2); - } - AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node); - - if (is_match_) { - return tuple_->input(3); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) { - tuple_ = cnode; - is_match_ = true; - } - } - - void Reset() { - tuple_ = nullptr; - is_match_ = false; - } - - private: - bool is_match_{false}; - CNodePtr tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/inline.h b/mindspore/ccsrc/optimizer/irpass/inline.h deleted file mode 100644 index 4b48d604d9..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/inline.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright 2020 Huawei 
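[editor's sketch] IndexedSlicesEliminater above cancels an accessor applied directly to the matching constructor: get_indices/get_values/get_dense_shape over make_indexed_slices fold to constructor inputs 1, 2 and 3, so no composite value survives the rewrite. The same constructor/accessor cancellation, modeled on a std::tuple:

    #include <cassert>
    #include <tuple>

    int main() {
      // make_indexed_slices(indices, values, dense_shape) modeled as a tuple...
      auto made = std::make_tuple(7, 3.5, 42);
      // ...where each accessor folds to the corresponding constructor argument.
      assert(std::get<0>(made) == 7);    // get_indices     -> input 1
      assert(std::get<1>(made) == 3.5);  // get_values      -> input 2
      assert(std::get<2>(made) == 42);   // get_dense_shape -> input 3
      return 0;
    }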
Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -class ReplaceApplicator : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsValueNode(node)) { - return nullptr; - } - - auto fg = GetValueNode(node); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { - return nullptr; - } - - auto out = fg->output(); - MS_EXCEPTION_IF_NULL(out); - if (!out->isa()) { - return nullptr; - } - - auto &inputs = out->cast()->inputs(); - auto params = fg->parameters(); - - // Exclude first elements of inputs which is fn. - auto input_size = inputs.size(); - auto param_size = params.size(); - if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size && - std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) { - auto inner = inputs[0]; - if (IsValueNode(inner) || - (IsValueNode(inner) && GetValueNode(inner)->parent() == nullptr)) { - return inner; - } - } - - return nullptr; - } -}; - -using CriterionFuncType = std::function; - -bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) { - auto n_cnode = fg->nodes().size() - fg->parameters().size(); - // There is at least one CNode(return, other_node). - return n_cnode <= 2; -} - -bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) { - auto &cnodes = fg->func_graph_cnodes_index(); - int n_use = - std::accumulate(cnodes.begin(), cnodes.end(), 0, - [](int sum, const std::pair &item) { return sum + item.second; }); - return n_use == 1; -} - -bool IsInside(FuncGraphPtr, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node->func_graph()); - return node->func_graph()->has_flag("inline_inside"); -} - -bool IsCore(const FuncGraphPtr &fg, AnfNodePtr) { return fg->has_flag("core"); } - -bool NoCriterion(FuncGraphPtr, AnfNodePtr) { return true; } - -// {G, Xs} -class InlinerBase : public AnfVisitor { - public: - explicit InlinerBase(std::vector> criterions) : criterions_(criterions) {} - ~InlinerBase() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa()) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 1 || !IsValueNode(inputs[0])) { - return nullptr; - } - - // G - auto fg = GetValueNode(inputs[0]); - if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) { - return nullptr; - } - // Do not inline GraphKernel to Cell. - if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && !node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // If the GraphKernel only contains a return node, we make it inlined. 
-      if (fg->nodes().size() - fg->parameters().size() > 1) {
-        return nullptr;
-      }
-    }
-
-    Reset();
-    bool is_match = false;
-    for (auto &criterion : criterions_) {
-      if (!criterion.first(fg, node)) {
-        continue;
-      }
-
-      if (criterion.second && IsRecursive(fg)) {
-        continue;
-      }
-
-      is_match = true;
-      break;
-    }
-
-    if (!is_match) {
-      return nullptr;
-    }
-
-    std::vector<AnfNodePtr> params;
-    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(params));
-
-    if (IsUniqueUse(fg, nullptr)) {
-      auto mng = fg->manager();
-      MS_EXCEPTION_IF_NULL(mng);
-      ReplaceParams(mng, params, fg);
-      auto out_node = fg->output();
-      mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
-      return out_node;
-    }
-
-    return InlineClone(fg, node->func_graph(), params, inputs[0]->scope());
-  }
-
-  void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
-                     const FuncGraphPtr &fg) {
-    auto params = fg->parameters();
-    auto old_size = params.size();
-    if (old_size != new_params.size()) {
-      MS_LOG(EXCEPTION) << "Parameter size does not match: old " << old_size << ", new " << new_params.size()
-                        << " " << fg->output()->DebugString(10);
-    }
-    for (size_t i = 0; i < old_size; i++) {
-      (void)mng->Replace(params[i], new_params[i]);
-    }
-  }
-
-  bool IsRecursive(const FuncGraphPtr &fg) {
-    if (!is_checked_) {
-      is_checked_ = true;
-      is_recursive_ = fg->recursive();
-    }
-    return is_recursive_;
-  }
-
-  void Reset() {
-    is_checked_ = false;
-    is_recursive_ = false;
-  }
-
- private:
-  bool is_checked_{false}, is_recursive_{false};
-  std::vector<std::pair<CriterionFuncType, bool>> criterions_;
-};
-
-class Inliner : public InlinerBase {
- public:
-  Inliner()
-      : InlinerBase({
-          {IsUniqueUse, true},
-          {IsTrivial, false},
-          {IsInside, false},
-          {IsCore, false},
-          {NoCriterion, true},
-        }) {}
-  ~Inliner() override = default;
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INLINE_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h b/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
deleted file mode 100644
index 202951a254..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/item_tuple_eliminate.h
+++ /dev/null
@@ -1,301 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
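[editor's sketch] The Inliner above consults an ordered list of (criterion, check_recursion) pairs and inlines on the first accepting criterion; when a pair's flag is set, recursive graphs are rejected even on a match. A standalone sketch of that policy walk, with a toy Graph record standing in for FuncGraph:

    #include <functional>
    #include <utility>
    #include <vector>

    struct Graph {
      bool unique_use = false;
      bool trivial = false;
      bool recursive = false;
    };

    using Criterion = std::pair<std::function<bool(const Graph &)>, bool>;

    bool ShouldInline(const Graph &g, const std::vector<Criterion> &criteria) {
      for (const auto &c : criteria) {
        if (!c.first(g)) continue;              // criterion does not apply
        if (c.second && g.recursive) continue;  // matched, but recursion is banned here
        return true;
      }
      return false;
    }

Note the catch-all {NoCriterion, true} at the end of Inliner's list: everything non-recursive that survives the earlier checks is still inlined.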
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
-
-#include <algorithm>
-#include <memory>
-#include <vector>
-
-#include "ir/optimizer_caller.h"
-#include "ir/visitor.h"
-#include "operator/ops.h"
-#include "optimizer/irpass.h"
-#include "optimizer/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// (a, b, c, ...)[0] => a
-// (a, b, c, ...)[1] => b
-// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
-class GetitemEliminater : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
-
-    if (is_match_) {
-      return tuple_->input(id_);
-    }
-    return nullptr;
-  }
-
-  void Visit(const CNodePtr &cnode) override {
-    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
-      tuple_ = cnode;
-    }
-  }
-
-  void Visit(const ValueNodePtr &vnode) override {
-    if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
-      id_ = IntToSize(GetValue<int>(vnode->value()) + 1);
-      if (tuple_->size() > id_) {
-        is_match_ = true;
-      }
-    }
-  }
-
-  void Reset() {
-    id_ = 0;
-    tuple_ = nullptr;
-    is_match_ = false;
-  }
-
- private:
-  bool is_match_{false};
-  size_t id_{0};
-  CNodePtr tuple_{nullptr};
-};
-
-// (a, b, c, ...)[0] => a
-// (a, b, c, ...)[1] => b
-// {prim::kPrimTupleGetItem, C1, C}
-class GetitemConstEliminater : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimTupleGetItem, {IsVNode, IsVNode})(node);
-
-    if (is_match_) {
-      return NewValueNode((*tuple_)[id_]);
-    }
-    return nullptr;
-  }
-
-  void Visit(const ValueNodePtr &vnode) override {
-    if (IsValueNode<ValueTuple>(vnode)) {
-      tuple_ = GetValueNode<ValueTuplePtr>(vnode);
-    }
-    if (tuple_ != nullptr && IsValueNode<Int32Imm>(vnode)) {
-      id_ = IntToSize(GetValue<int>(vnode->value()));
-      if (tuple_->size() > id_) {
-        is_match_ = true;
-      }
-    }
-  }
-
-  void Reset() {
-    id_ = 0;
-    tuple_ = nullptr;
-    is_match_ = false;
-  }
-
- private:
-  bool is_match_{false};
-  size_t id_{0};
-  ValueTuplePtr tuple_{nullptr};
-};
-
-// setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
-// setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
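[editor's note] A worked instance of the two setitem rules above, not present in the original comments: with t = make_tuple(a, b, c), tuple_setitem(t, 1, z) is rebuilt as make_tuple(a, z, c); a subsequent tuple_getitem at index 1 then folds straight to z via the GetSetitemEliminater below.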
-// {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z} -class SetitemEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && z_ != nullptr) { - args_[id_] = z_; - return fg->NewCNode(args_); - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (is_match_) { - z_ = node; - return; - } - - AnfVisitor::Visit(node); - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - auto &inputs = cnode->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(args_)); - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (args_.size() > 0 && IsValueNode(vnode)) { - id_ = IntToSize(GetValue(vnode->value()) + 1); - if (id_ < args_.size()) { - is_match_ = true; - } - } - } - - void Reset() { - id_ = 0; - z_ = nullptr; - is_match_ = false; - args_.clear(); - } - - private: - bool is_match_{false}; - size_t id_{0}; - AnfNodePtr z_{nullptr}; - std::vector args_{}; -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} -class GetSetitemEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && key1_ >= 0 && key2_ >= 0) { - if (key1_ == key2_) { - return last_; - } - return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_, c2_}); - } - return nullptr; - } - - void Visit(const CNodePtr &cnode) override { - if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) { - if (cnode->size() < 4) { - return; - } - - tuple_ = cnode->input(1); - last_ = cnode->input(3); - - // key of setitem - is_in_set_ = true; - AnfVisitor::Visit(cnode->input(2)); - is_in_set_ = false; - } - } - - void Visit(const ValueNodePtr &vnode) override { - if (IsValueNode(vnode)) { - auto key = GetValue(vnode->value()); - if (is_in_set_) { - key1_ = key; - } else { - c2_ = vnode; - key2_ = key; - } - } - } - - void Reset() { - key1_ = -1; - key2_ = -1; - c2_ = nullptr; - last_ = nullptr; - tuple_ = nullptr; - is_in_set_ = false; - } - - private: - bool is_in_set_{false}; - int key1_{-1}, key2_{-1}; - AnfNodePtr tuple_{nullptr}, last_{nullptr}, c2_{nullptr}; -}; - -// {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} -> -// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} -class GetitemDependReorder : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsValueNode})(node); - if (x_ == nullptr) { - return nullptr; - } - - auto fg = node->func_graph(); - auto item_node = NewCNode({NewValueNode(prim::kPrimTupleGetItem), x_, c_}, fg); - return NewCNode({NewValueNode(prim::kPrimDepend), item_node, y_}, fg); - } - - void Visit(const CNodePtr &cnode) override { - // {prim::kPrimDepend, X, Y} - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && cnode->size() == 3) { - x_ = cnode->input(1); - y_ = cnode->input(2); - } - } - - void Visit(const ValueNodePtr &vnode) override { c_ = vnode; } - - void Reset() { - x_ = nullptr; - y_ = nullptr; - c_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; -}; - -class ItemTupleEliminater : 
public OptimizerCaller { - public: - ItemTupleEliminater() - : get_item_eliminater_(std::make_shared()), - get_item_const_eliminater_(std::make_shared()), - set_item_eliminater_(std::make_shared()), - get_set_item_eliminater_(std::make_shared()), - get_item_depend_reorder_(std::make_shared()) { - eliminaters_.emplace_back(get_item_eliminater_); - eliminaters_.emplace_back(get_item_const_eliminater_); - eliminaters_.emplace_back(set_item_eliminater_); - eliminaters_.emplace_back(get_set_item_eliminater_); - eliminaters_.emplace_back(get_item_depend_reorder_); - } - ~ItemTupleEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_, - get_item_depend_reorder_; - std::vector eliminaters_{}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h b/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h deleted file mode 100644 index 6f2bcc187f..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/mark_interface_fusion.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" -#include "operator/composite/composite.h" - -namespace mindspore { -namespace opt { -namespace irpass { - -static int count = 0; - -std::string GetFusionNumber() { - std::stringstream ss; - ss << std::setw(4) << std::setfill('0') << count; - std::string num = ss.str(); - ++count; - - return "_" + num; -} - -// Mark CNodes which can be merged in kernel build -class MarkInterfaceFusion : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsPrimitiveCNode(node, prim::kPrimSelect)) { - auto cnode = node->cast(); - auto condition = cnode->input(1); - std::string cmp; - std::unordered_map cmp_list = {{"GreaterEqual", "GE"}, {"Greater", "GT"}, - {"LessEqual", "LE"}, {"Less", "LT"}, - {"Equal", "EQ"}, {"NotEqual", "NE"}}; - if (IsPrimitiveCNode(condition)) { - auto prim_name = GetCNodeFuncName(condition->cast()); - if (cmp_list.count(prim_name) != 0) { - // Mark Select and compare node - cmp = cmp_list[prim_name]; - auto cnt = GetFusionNumber(); - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), condition); - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt + "_end"), node); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - if (IsPrimitiveCNode(cnode->input(i), prim::kPrimZerosLike)) { - AnfAlgo::SetNodeAttr("fusion", MakeValue("Select" + cmp + cnt), cnode->input(i)); - } - } - } - } - } - return nullptr; - } - - void Visit(const AnfNodePtr &) override {} - - private: - AnfNodePtr y_{nullptr}; -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MARK_INTERFACE_FUSION_H diff --git a/mindspore/ccsrc/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/optimizer/irpass/merge_addn.h deleted file mode 100644 index e1e4b8878b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/merge_addn.h +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {PrimAddN, {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys}} -> -// {{PrimAddNClass}, {prim::kPrimMakeTuple, Xs, Ys}} -// {PrimAddN, {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}}} -> -// {{PrimAddNClass}, {prim::kPrimMakeTuple, Ys, Xs}} -class MergeAddN : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - optimizer_ = optimizer; - is_outer_ = true; - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); - if (!is_match_ || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode = node->cast(); - auto addn = NewValueNode(GetValueNode(cnode->input(0))); - - // {prim::kPrimMakeTuple, Xs, Ys}, {prim::kPrimMakeTuple, Ys, Xs} - (void)args_.insert(args_.begin(), NewValueNode(prim::kPrimMakeTuple)); - auto fg = node->func_graph(); - auto make_node = fg->NewCNode(args_); - - return fg->NewCNode({addn, make_node}); - } - - void Visit(const CNodePtr &cnode) override { - if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - return; - } - - auto &inputs = cnode->inputs(); - - if (is_outer_) { - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Ys_)); - - is_outer_ = false; - is_inner_ = true; - - // {prim::kPrimMakeTuple, {PrimAddN, {prim::kPrimMakeTuple, Xs}}, Ys} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs[1]); - if (is_match_) { - if (!is_unique(inputs[1])) { - is_match_ = false; - return; - } - (void)Ys_.erase(Ys_.begin()); - (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); - (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); - return; - } - - // {prim::kPrimMakeTuple, Ys, {PrimAddN, {prim::kPrimMakeTuple, Xs}}} - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(inputs.back()); - if (is_match_) { - if (!is_unique(inputs.back())) { - is_match_ = false; - return; - } - Ys_.pop_back(); - (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); - (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); - return; - } - - return; - } - - if (is_inner_) { - is_match_ = true; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - } - } - - bool is_unique(const AnfNodePtr &node) { - auto mng = optimizer_->resource()->manager(); - auto &node_users = mng->node_users(); - if (node_users.find(node) == node_users.end()) { - return false; - } - - size_t n_use = node_users[node].size(); - return n_use == 1; - } - - void Reset() { - Xs_.clear(); - Ys_.clear(); - args_.clear(); - is_inner_ = false; - is_outer_ = false; - is_match_ = false; - } - - private: - OptimizerPtr optimizer_{nullptr}; - std::vector Xs_{}, Ys_{}, args_{}; - bool is_inner_{false}, is_outer_{false}, is_match_{false}; -}; - -// {PrimAddN, {kPrimMakeTuple, Xs}} -class AddNZeroFilter : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimAddN, {IsCNode})(node); - - if (filtered_Xs_.empty() || node->func_graph() == nullptr) { - return nullptr; - } - - // if only two node in filtered_nodes, {make_tuple, x}. return x. 
- if (filtered_Xs_.size() == 2) { - return filtered_Xs_[1]; - } - - // if only one node in filtered_nodes, all node is zerolike, return one of the input. - if (filtered_Xs_.size() == 1 && Xs_.size() > 0) { - return Xs_[0]; - } - - if (!has_zero_like_) { - return nullptr; - } - - auto cnode = node->cast(); - auto addn = NewValueNode(GetValueNode(cnode->input(0))); - auto fg = node->func_graph(); - auto make_tuple = fg->NewCNode(filtered_Xs_); - return fg->NewCNode({addn, make_tuple}); - } - - void Visit(const CNodePtr &cnode) override { - if (!IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) { - return; - } - - auto &inputs = cnode->inputs(); - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_)); - - // {kPrimMakeTuple, X1, X2, ...} - filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (auto &x : Xs_) { - if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) { - filtered_Xs_.push_back(x); - } else { - has_zero_like_ = true; - } - } - } - - void Reset() { - Xs_.clear(); - filtered_Xs_.clear(); - has_zero_like_ = false; - } - - private: - std::vector filtered_Xs_{}, Xs_{}; - bool has_zero_like_{false}; -}; - -// {PrimAddN, {kPrimMakeTuple, Xs}} -// Akg don't support AddN(ValueNode, Tensor, ...), converted to TensorAdd. -// case0: AddN(inputs)(inputs size < 2) -> error -// case1: AddN(inputs)(all inputs is ValueNode) -> error -// case2: AddN(inputs)(inputs size = 2) -> TensorAdd(Tensor, Tensor) -// case3: AddN(ValueNode, Tensor, Tensor, ...)(has one ValueNode input) -// -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) -class AddNEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - if (fg->recursive()) { - return nullptr; - } - - auto new_fg = TransformableClone(fg, std::make_shared("fg")); - mng->AddFuncGraph(new_fg); - need_update_ = false; - bool changed; - do { - changed = Process(new_fg); - } while (changed); - - if (!need_update_) { - return nullptr; - } else { - auto new_sx = inputs; - new_sx[0] = NewValueNode(new_fg); - return node->func_graph()->NewCNode(new_sx); - } - } - - bool Process(const FuncGraphPtr &func_graph) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto nodes = TopoSort(func_graph->output()); - bool changed = false; - - for (size_t i = 0; i < nodes.size(); ++i) { - auto node = nodes[i]; - if (!IsPrimitiveCNode(node, prim::kPrimAddN)) { - continue; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &tuple_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(tuple_input); - auto tuple_input_cnode = tuple_input->cast(); - MS_EXCEPTION_IF_NULL(tuple_input_cnode); - auto &tuple_inputs = tuple_input_cnode->inputs(); - if (tuple_inputs.size() < 3) { - // case0: inputs size < 2, error - MS_EXCEPTION(ArgumentError) << "Inputs size of AddN less than 2. " << cnode->DebugString(2); - } - - int valuenode_num = - std::accumulate(tuple_inputs.begin() + 1, tuple_inputs.end(), 0, [](int accumulator, const AnfNodePtr &node) { - if (IsValueNode(node)) { - return accumulator + 1; - } else { - return accumulator; - } - }); - if (IntToSize(valuenode_num) == tuple_inputs.size()) { - // case1: all inputs is ValueNode, error - MS_EXCEPTION(ArgumentError) << "All inputs of AddN is ValueNode. 
" << cnode->DebugString(2); - } - - if (tuple_inputs.size() == 3) { - // case2: inputs size = 2, -> TensorAdd(Tensor, Tensor) - MS_LOG(DEBUG) << "Replace AddN with two inputs with TensorAdd. " << cnode->DebugString(2); - ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); - std::vector new_xs{func_graph->NewCNode({NewValueNode(prim_tensoradd)}), tuple_inputs[1], - tuple_inputs[2]}; - mng->Replace(node, func_graph->NewCNode(new_xs)); - changed = true; - continue; - } - - auto first_valuenode = std::find_if(tuple_inputs.begin() + 1, tuple_inputs.end(), - [](const AnfNodePtr &node) { return IsValueNode(node); }); - if (first_valuenode == tuple_inputs.end()) { - // no ValueNode input found. - continue; - } else { - // case3: has one ValueNode input -> TensorAdd(ValueNode, AddN(Tensor, Tensor, ...)) - std::vector make_tuple_new_xs{ - NewValueNode(prim::kPrimMakeTuple), - }; - std::for_each(tuple_inputs.begin() + 1, tuple_inputs.end(), - [&make_tuple_new_xs, &first_valuenode](const AnfNodePtr &node) { - if (node != *first_valuenode) { - make_tuple_new_xs.push_back(node); - } - }); - ValuePtr prim_addn = prim::GetPythonOps("AddN", "mindspore.ops.operations"); - auto new_addn = func_graph->NewCNode( - {func_graph->NewCNode({NewValueNode(prim_addn)}), func_graph->NewCNode(make_tuple_new_xs)}); - ValuePtr prim_tensoradd = prim::GetPythonOps("TensorAdd", "mindspore.ops.operations"); - auto new_add = - func_graph->NewCNode({func_graph->NewCNode({NewValueNode(prim_tensoradd)}), *first_valuenode, new_addn}); - (void)mng->Replace(node, new_add); - changed = true; - continue; - } - } - - need_update_ = need_update_ || changed; - return changed; - } - - private: - bool need_update_{false}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MERGE_ADDN_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/minmax_grad.h b/mindspore/ccsrc/optimizer/irpass/minmax_grad.h deleted file mode 100644 index a426a9fb9b..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/minmax_grad.h +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ - -#include -#include - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -// check if node is MinimumGrad() or MaximumGrad() -bool IsOriginMaxMinGrad(const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMaximumGrad) && !IsPrimitiveCNode(node, prim::kPrimMinimumGrad)) { - return false; - } - - auto cnode = node->cast(); - auto prim = GetValueNode(cnode->input(0)); - auto x_v = prim->GetAttr("grad_x"); - auto y_v = prim->GetAttr("grad_y"); - if (x_v == nullptr || y_v == nullptr || !x_v->isa() || !y_v->isa()) { - return false; - } - - bool x = GetValue(x_v); - bool y = GetValue(y_v); - return x && y; -} -} // namespace internal - -// {prim::kPrimTupleGetItem, {target_grad, Xs}, C} -class MinMaximumGrad : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTupleGetItem, {internal::IsOriginMaxMinGrad, IsValueNode})(node); - if (grad_ == nullptr || idx_ < 0 || idx_ > 1 || node->func_graph() == nullptr) { - return nullptr; - } - - // check single use - auto mng = optimizer->resource()->manager(); - auto &users = mng->node_users(); - if (users.find(grad_) == users.end() || users[grad_].size() != 1) { - return nullptr; - } - - // {target_grad, Xs} - auto &inputs = grad_->inputs(); - auto prim = GetValueNode(inputs[0]); - - auto new_prim = std::make_shared(prim->name()); - new_prim->set_attr("grad_x", MakeValue(true)); - new_prim->set_attr("grad_y", MakeValue(true)); - - if (idx_ == 0) { - new_prim->set_attr("grad_y", MakeValue(false)); - } - if (idx_ == 1) { - new_prim->set_attr("grad_x", MakeValue(false)); - } - - std::vector args; - args.push_back(NewValueNode(new_prim)); - (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); - - auto fg = node->func_graph(); - auto tuple = fg->NewCNode(args); - - return fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple, NewValueNode(MakeValue(idx_))}); - } - - void Visit(const CNodePtr &cnode) override { grad_ = cnode; } - - void Visit(const ValueNodePtr &vnode) override { idx_ = GetValue(vnode->value()); } - - void Reset() { - idx_ = -1; - grad_ = nullptr; - } - - private: - int idx_{-1}; - CNodePtr grad_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_MINMAX_GRAD_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/param_replace.h b/mindspore/ccsrc/optimizer/irpass/param_replace.h deleted file mode 100644 index c0c4c832d7..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/param_replace.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
-
-#include <memory>
-
-#include "optimizer/optimizer.h"
-#include "optimizer/irpass.h"
-#include "ir/visitor.h"
-#include "operator/ops.h"
-#include "pipeline/parse/parse.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-class ReplaceOldParam : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
-    if (!IsParam(node)) {
-      return nullptr;
-    }
-    auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
-    MS_EXCEPTION_IF_NULL(resource);
-
-    auto top_graph = resource->func_graph();  // parse::Parser::GetTopFuncGraph();
-    MS_EXCEPTION_IF_NULL(top_graph);
-
-    auto param_node = node->cast<ParameterPtr>();
-    if (!param_node->has_default() || node->func_graph() == top_graph) {
-      return nullptr;
-    }
-    auto para_name = param_node->name();
-    for (const auto &tnode : top_graph->parameters()) {
-      auto para = tnode->cast<ParameterPtr>();
-      if (para != nullptr && para->name() == para_name) {
-        return para;
-      }
-    }
-    return nullptr;
-  }
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
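A minimal standalone sketch of the lookup ReplaceOldParam performs, assuming only that parameters are matched by name; Param and FindByName are illustrative names, not MindSpore APIs. The real pass additionally requires the candidate to be a defaulted parameter that lives outside the top graph before it searches.

#include <cassert>
#include <string>
#include <vector>

// Illustrative stand-in for a graph parameter.
struct Param {
  std::string name;
};

// Reuse the top graph's parameter with the same name, or report no match.
const Param *FindByName(const std::vector<Param> &top_params, const std::string &name) {
  for (const auto &p : top_params) {
    if (p.name == name) {
      return &p;  // same-named parameter in the top graph is reused
    }
  }
  return nullptr;  // no match: the original node is left unchanged
}

int main() {
  std::vector<Param> top{{"w"}, {"b"}};
  assert(FindByName(top, "b") == &top[1]);
  assert(FindByName(top, "x") == nullptr);
}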
diff --git a/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h b/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h
deleted file mode 100644
index bc8ef9d8f3..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/partial_eliminate.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
-
-#include <algorithm>
-#include <memory>
-#include <vector>
-
-#include "optimizer/irpass.h"
-#include "optimizer/optimizer.h"
-#include "ir/visitor.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys}
-class PartialEliminater : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
-      return nullptr;
-    }
-
-    Xs_.clear();
-    auto &inputs = node->cast<CNodePtr>()->inputs();
-    Visit(inputs[0]);
-
-    if (Xs_.size() == 0) {
-      return nullptr;
-    }
-
-    // {X, Xs, Ys}
-    std::vector<AnfNodePtr> args{};
-    (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
-    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
-    TraceManager::DebugTrace(std::make_shared<TraceCopy>(node->debug_info()));
-    auto new_node = node->func_graph()->NewCNode(args);
-    TraceManager::EndTrace();
-    return new_node;
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
-      return;
-    }
-
-    auto &inputs = node->cast<CNodePtr>()->inputs();
-    // {prim::kPrimPartial, X, Xs}
-    if (inputs.size() < 2) {
-      return;
-    }
-
-    // fill Xs
-    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(Xs_));
-  }
-
- private:
-  std::vector<AnfNodePtr> Xs_{};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h b/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h
deleted file mode 100644
index 725c30a6b9..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/prim_eliminate.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ - -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" -#include "ir/visitor.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim, X} -class PrimEliminater : public AnfVisitor { - public: - explicit PrimEliminater(const PrimitivePtr &prim) : prim_(prim) {} - ~PrimEliminater() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim_, {IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { x_ = node; } - - private: - AnfNodePtr x_{nullptr}; - PrimitivePtr prim_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PRIM_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h deleted file mode 100644 index cea002111c..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/reduce_eliminate.h +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ - -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "abstract/dshape.h" - -namespace mindspore { -namespace opt { -namespace irpass { -using abstract::Shape; -using abstract::ShapePtr; - -// {ReduceLike, X, axis} -class ReduceOneEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - PrimitivePtr prim; - if (IsPrimitiveCNode(node, prim::kPrimReduceMean) || IsPrimitiveCNode(node, prim::kPrimReduceAll) || - IsPrimitiveCNode(node, prim::kPrimReduceSum) || IsPrimitiveCNode(node, prim::kPrimReduceMax) || - IsPrimitiveCNode(node, prim::kPrimReduceMin)) { - prim = GetValueNode(node->cast()->input(0)); - AnfVisitor::Match(prim, {IsNode, IsVNode})(node); - if (!is_axis_one_) { - return nullptr; - } - - // consider keep_dims - auto keep_dims = prim->GetAttr("keep_dims"); - auto is_keep_dims = GetValue(keep_dims); - // {_Reduce, X, axis} -> X - if (is_keep_dims) { - return x_; - } - - // {_Reduce, Tensor} - if (is_tensor_) { - return nullptr; - } - - // {_Reduce, X, axis} -> {Reshape, X, new_shape} - std::vector elements; - for (size_t i = 0; i < x_shape_.size(); i++) { - auto iter = find(axis_.begin(), axis_.end(), i); - if (iter == axis_.end()) { - ValuePtr s = MakeValue(x_shape_[i]); - elements.push_back(s); - } - } - auto new_shape = std::make_shared(elements); - auto reshape_op = prim::GetPythonOps("reshape", "mindspore.ops.functional")->cast(); - return node->func_graph()->NewCNode({NewValueNode(reshape_op), x_, NewValueNode(new_shape)}); - } - - return 
nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (!IsVNode(node) && x_ == nullptr) { - if (IsValueNode(node)) { - is_tensor_ = true; - } - // get X's shape - auto x_shape_abs = node->abstract(); - if (x_shape_abs != nullptr) { - auto x_track = x_shape_abs->GetShapeTrack()->cast(); - if (x_track == nullptr) { - return; - } - auto x_shape = x_track->shape(); - (void)std::copy(x_shape.begin(), x_shape.end(), std::back_inserter(x_shape_)); - x_ = node; - } - return; - } - - // check axis - AnfVisitor::Visit(node); - } - - void Visit(const ValueNodePtr &vnode) override { - if (x_shape_.empty()) { - return; - } - - // axis : int - if (IsValueNode(vnode)) { - auto idx = GetValue(vnode->value()); - // axis could be negative - if (idx < 0) { - idx += SizeToInt(x_shape_.size()); - } - if (SizeToInt(x_shape_.size()) > idx && x_shape_[IntToSize(idx)] == 1) { - is_axis_one_ = true; - axis_.push_back(idx); - } - return; - } - - // axis : tuple(int), default () - if (IsValueNode(vnode)) { - auto axis = GetValue>(vnode->value()); - if (axis.empty()) { - return; - } - - auto cmp = std::all_of(axis.cbegin(), axis.cend(), [this](int idx) { - // axis could be negative - if (idx < 0) { - idx += SizeToInt(x_shape_.size()); - } - return SizeToInt(this->x_shape_.size()) > idx && this->x_shape_[IntToSize(idx)] == 1; - }); - if (cmp) { - is_axis_one_ = true; - (void)std::copy(axis.begin(), axis.end(), std::back_inserter(axis_)); - } - } - } - - void Reset() { - axis_.clear(); - x_shape_.clear(); - x_ = nullptr; - is_axis_one_ = false; - is_tensor_ = false; - } - - private: - bool is_axis_one_{false}, is_tensor_{false}; - std::vector axis_{}, x_shape_{}; - AnfNodePtr x_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REDUCE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h b/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h deleted file mode 100644 index 6d81b401c3..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/ref_eliminate.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ - -#include - -#include "ir/pattern_matcher.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -// {prim::kPrimMakeRef, X, Y, Z} -> Y -class MakeRefEliminater : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x, y, z; - MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y); - return nullptr; - } -}; - -// {prim::kPrimGetRefValue, Parameter} -> Parameter -// {prim::kPrimGetRefOrigin, Parameter} -> Parameter -class GetRefParamEliminater : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x; - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node)); - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node)); - return nullptr; - } -}; - -// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X -// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y -// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z -class GetMakeRefEliminater : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x, y, z; - MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x); - MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y); - MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z); - - return nullptr; - } -}; - -// IsValueNode -class ReplaceRefkeyByParam : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr { - auto refkey = GetValueNode(node); - auto resource = std::dynamic_pointer_cast(optimizer->resource()); - MS_EXCEPTION_IF_NULL(resource); - - auto top_graph = resource->func_graph(); - MS_EXCEPTION_IF_NULL(top_graph); - - for (const auto &tnode : top_graph->parameters()) { - auto para = tnode->cast(); - if (para != nullptr && para->name() == refkey->tag()) { - return para; - } - } - return nullptr; - }; - PatternNode x; - MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode, node)); - return nullptr; - } -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h b/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h deleted file mode 100644 index e10ff5c678..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/reshape_eliminate.h +++ /dev/null @@ -1,154 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ - -#include - -#include "ir/func_graph.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "abstract/dshape.h" - -namespace mindspore { -namespace opt { -namespace irpass { -using abstract::Shape; -using abstract::ShapePtr; - -// {reshape_op, X, Shape} -class ReshapeSameShapeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimReshape, {IsNode, IsVNode})(node); - - // check pattern match - if (shape_ == nullptr) { - return nullptr; - } - - auto src_shape_abs = x_->abstract(); - if (src_shape_abs == nullptr) { - return nullptr; - } - - auto src_shape = src_shape_abs->GetShapeTrack(); - auto tgt_shape_abs = node->abstract(); - if (tgt_shape_abs == nullptr) { - return nullptr; - } - auto tgt_shape = tgt_shape_abs->GetShapeTrack(); - if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa() && tgt_shape->isa()) { - auto elements = tgt_shape->cast(); - auto shape = src_shape->cast(); - if (shape->shape() == elements->shape()) { - return x_; - } - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } else { - shape_ = node; - } - } - - void Reset() { - x_ = nullptr; - shape_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, shape_{nullptr}; -}; - -// {PrimReshape, {PrimReshape, X, Y}, Shape} -class TwoReshapeEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimReshape, {IsCNode, IsNode})(node); - - auto fg = node->func_graph(); - if (fg != nullptr && x_ != nullptr && shape_ != nullptr) { - auto new_node = fg->NewCNode({NewValueNode(prim_), x_, shape_}); - new_node->set_abstract(node->abstract()); - return new_node; - } - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (IsPrimitiveCNode(node, prim::kPrimReshape)) { - auto &inputs = node->cast()->inputs(); - // {PrimReshape, X, Y} - if (inputs.size() != 3) { - return; - } - prim_ = GetValueNode(inputs[0]); - x_ = inputs[1]; - } else { - shape_ = node; - } - } - - void Reset() { - prim_ = nullptr; - x_ = nullptr; - shape_ = nullptr; - } - - private: - PrimitivePtr prim_{nullptr}; - AnfNodePtr x_{nullptr}, shape_{nullptr}; -}; - -class ReshapeEliminater : public OptimizerCaller { - public: - ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} - ~ReshapeEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - auto new_node = reshape_same_shape_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - new_node = two_reshape_eliminater_(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - - return nullptr; - } - - private: - ReshapeSameShapeEliminater reshape_same_shape_eliminater_; - TwoReshapeEliminater two_reshape_eliminater_; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_RESHAPE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h 
b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h deleted file mode 100644 index b6a4e1c852..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ - -#include -#include -#include -#include - -#include "ir/optimizer_caller.h" -#include "ir/pattern_matcher.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -namespace irpass { -class SpecialOpEliminater : public OptimizerCaller { - public: - SpecialOpEliminater() - : insert_gradient_of_(std::make_shared(prim::kPrimInsertGradientOf)), - stop_gradient_(std::make_shared(prim::kPrimStopGradient)), - hook_backward_(std::make_shared(prim::kPrimHookBackward)), - print_shape_type_(std::make_shared(prim::kPrimPrintShapeType)), - get_ref_value_(std::make_shared(prim::kPrimGetRefValue)), - mirror_(std::make_shared(prim::kPrimMirror)), - virtual_div_(std::make_shared(prim::kPrimVirtualDiv)) { - eliminaters_.emplace_back(insert_gradient_of_); - eliminaters_.emplace_back(stop_gradient_); - eliminaters_.emplace_back(hook_backward_); - eliminaters_.emplace_back(print_shape_type_); - eliminaters_.emplace_back(get_ref_value_); - eliminaters_.emplace_back(mirror_); - eliminaters_.emplace_back(virtual_div_); - } - ~SpecialOpEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, - virtual_div_; - std::vector eliminaters_{}; -}; - -// {PrimVirtualDataset, X} -> X -// {PrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs} -class VirtualDatasetEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 1) { - return nullptr; - } - - std::vector args; - (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); - if (args.size() == 1) { - return args.front(); - } - - (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple)); - - return node->func_graph()->NewCNode(args); - } - - void Visit(const AnfNodePtr &) override {} -}; - -// {prim::kPrimSameTypeShape, X, Y} -> X -class SameEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const 
AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } - } - - private: - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimCheckBprop, X, Y} -> X -class CheckBpropEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - x_ = nullptr; - AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node); - return x_; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } - } - - private: - AnfNodePtr x_{nullptr}; -}; - -// Reset defer_inline flag -class ResetDeferInline : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (IsValueNode(node)) { - auto fg = GetValueNode(node); - fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false); - } - return nullptr; - } -}; - -// {PrimZerosLike, Y} -> -// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0} -class ZeroLikeFillZero : public AnfVisitor { - public: - ZeroLikeFillZero() - : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast()), - PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast()), - PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast()) {} - ~ZeroLikeFillZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - y_ = nullptr; - AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node); - if (y_ == nullptr || node->func_graph() == nullptr) { - return nullptr; - } - if ((y_->abstract() == nullptr) || !y_->abstract()->isa()) { - auto fg = node->func_graph(); - auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_}); - auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_}); - return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))}); - } - - abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast(); - - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - tensor::TensorPtr new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - (void)memset_s(data, mem_size, 0, mem_size); - - auto new_cnode = NewValueNode(new_tensor_ptr); - new_cnode->set_abstract(new_tensor_ptr->ToAbstract()); - - return new_cnode; - } - - void Visit(const AnfNodePtr &node) override { y_ = node; } - - private: - AnfNodePtr y_{nullptr}; - PrimitivePtr PrimFill_, PrimShape_, PrimDType_; -}; - -// {prim::kPrimDepend, X, ValueCond}->X -class DependValueElim : public OptimizerCaller { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - PatternNode x, cond; - MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node))); - return nullptr; - } -}; - -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/specialize_transform.h b/mindspore/ccsrc/optimizer/irpass/specialize_transform.h deleted file mode 100644 index 3db9e7bd51..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/specialize_transform.h +++ /dev/null @@ -1,305 +0,0 @@ -/** - * Copyright 2020 
Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ - -#include -#include -#include -#include -#include -#include - -#include "optimizer/irpass.h" -#include "optimizer/optimizer.h" -#include "ir/visitor.h" -#include "ir/manager.h" -#include "ir/func_graph.h" -#include "ir/func_graph_cloner.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace irpass { -namespace internal { -class SpecializeTransform { - public: - SpecializeTransform() : cache_() {} - ~SpecializeTransform() = default; - - FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, - std::vector prim_args, std::vector value_args) { - if (cache_.count(func_graph) == 0) { - cache_[func_graph] = {}; - } - - auto &cache = cache_[func_graph]; - auto key = std::make_pair(graph_args, prim_args); - if (cache.count(key) == 0) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - - FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - std::vector params = new_fg->parameters(); - std::vector new_params; - size_t n = graph_args.size(); - for (size_t i = 0; i < n; i++) { - if (graph_args[i] != nullptr) { - auto arg = NewValueNode(graph_args[i]); - (void)mng->Replace(params[i], arg); - continue; - } - if (prim_args[i] != nullptr) { - auto arg = NewValueNode(prim_args[i]); - (void)mng->Replace(params[i], arg); - continue; - } - if (value_args[i] != nullptr) { - auto &const_tensor = *value_args[i]; - auto const_tensor_ptr = std::make_shared(const_tensor); - AnfNodePtr arg = NewValueNode(const_tensor_ptr); - (void)mng->Replace(params[i], arg); - continue; - } - new_params.push_back(params[i]); - } - - mng->SetParameters(new_fg, new_params); - cache[key] = new_fg; - } - return cache[key]; - } - - private: - std::unordered_map, std::vector>, FuncGraphPtr>> - cache_; -}; -} // namespace internal - -// {G, Xs} -class SpecializeOnGraphArguments : public AnfVisitor { - public: - SpecializeOnGraphArguments() : specialize_transform_() {} - ~SpecializeOnGraphArguments() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[0])) { - return nullptr; - } - - auto inp0_fg = GetValueNode(inputs[0]); - if (inp0_fg->recursive()) { - return nullptr; - } - - std::vector graph_args; - std::vector prim_args; - std::vector value_node_args; - std::vector new_xs; - bool hasVNode = false; - for (size_t i = 1; i < inputs.size(); i++) { - if (IsValueNode(inputs[i])) { - auto fg_vnode = GetValueNode(inputs[i]); - graph_args.push_back(fg_vnode); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); - hasVNode = true; - } else if (IsValueNode(inputs[i])) { - auto 
p_vnode = GetValueNode(inputs[i]); - graph_args.emplace_back(nullptr); - prim_args.push_back(p_vnode); - value_node_args.emplace_back(nullptr); - hasVNode = true; - } else if (IsValueNode(inputs[i])) { - tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); - graph_args.emplace_back(nullptr); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(t_vnode); - hasVNode = true; - } else { - graph_args.emplace_back(nullptr); - prim_args.emplace_back(nullptr); - value_node_args.emplace_back(nullptr); - new_xs.push_back(inputs[i]); - } - } - - if (!hasVNode) { - return nullptr; - } - - auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); - (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); - - return node->func_graph()->NewCNode(new_xs); - } - - private: - internal::SpecializeTransform specialize_transform_; -}; - -// Eliminate unused parameters. -// {G, Xs} -class UnusedParasEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - - std::vector parameters = fg->parameters(); - size_t size = parameters.size(); - if (size != inputs.size() - 1) { - return nullptr; - } - - std::vector new_xs; - std::vector keep_parameters; - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto &node_users = mng->node_users(); - bool has_unused_para = false; - for (size_t i = 0; i < size; ++i) { - auto iter = node_users.find(parameters[i]); - if (iter != node_users.end() && !iter->second.empty()) { - keep_parameters.push_back(true); - new_xs.push_back(inputs[i + 1]); - continue; - } - keep_parameters.push_back(false); - has_unused_para = true; - } - - if (!has_unused_para) { - return nullptr; - } - FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); - mng->AddFuncGraph(new_fg); - - std::vector new_fg_parameters = new_fg->parameters(); - std::vector new_parameters; - for (size_t i = 0; i < size; i++) { - if (keep_parameters[i]) { - if (parameters[i]->abstract() != nullptr) { - new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); - } - new_parameters.push_back(new_fg_parameters[i]); - } - } - mng->SetParameters(new_fg, new_parameters); - - (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); - return node->func_graph()->NewCNode(new_xs); - } -}; - -// Eliminate unused outputs. 
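A standalone sketch of the pruning step in UnusedParasEliminater above; PruneArgs and the used-mask are illustrative names, not MindSpore APIs. The pass clones the graph, keeps only parameters that still have users, and rebuilds the call site with the matching argument subset; the UnusedOutputEliminater that follows applies the same idea to tuple outputs.

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Drop call arguments whose matching parameters have no remaining users.
std::vector<std::string> PruneArgs(const std::vector<std::string> &args,
                                   const std::vector<bool> &param_used) {
  assert(args.size() == param_used.size());
  std::vector<std::string> kept;
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (param_used[i]) {
      kept.push_back(args[i]);  // surviving positions keep their original order
    }
  }
  return kept;
}

int main() {
  auto kept = PruneArgs({"x", "y", "z"}, {true, false, true});
  assert((kept == std::vector<std::string>{"x", "z"}));  // "y" fed an unused parameter
}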
-// {G, Xs} -class UnusedOutputEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - if (!node->isa() || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - auto fg = GetValueNode(inputs[0]); - MS_EXCEPTION_IF_NULL(fg); - auto mng = fg->manager(); - MS_EXCEPTION_IF_NULL(mng); - if (fg->recursive()) { - return nullptr; - } - - auto new_fg = TransformableClone(fg, std::make_shared("fg")); - mng->AddFuncGraph(new_fg); - auto new_fg_output = new_fg->output(); - if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { - return nullptr; - } - - auto output_cnode = new_fg_output->cast(); - auto &node_users = mng->node_users(); - if (node_users.count(node) == 0 || node_users[node].empty()) { - return nullptr; - } - std::unordered_set used_output_idx; - std::vector> all_users; - for (auto &node_user : node_users[node]) { - if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { - return nullptr; - } - auto user_cnode = node_user.first->cast(); - size_t used_idx = GetValue(user_cnode->input(2)->cast()->value()); - used_output_idx.insert(used_idx); - all_users.push_back(std::make_pair(node_user.first, used_idx)); - } - - if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { - // all output has users. - return nullptr; - } - - if (used_output_idx.empty()) { - // we do not process this case. - return nullptr; - } else if (used_output_idx.size() == 1) { - // after eliminate, only one output left. - new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); - // update users. - for (auto &ret_user : all_users) { - (void)mng->Replace(ret_user.first, node); - } - } else { - // after eliminate, create new multi output. - std::vector new_output_inputs{output_cnode->input(0)}; - std::unordered_map new_idx_map; - for (auto idx : used_output_idx) { - new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); - new_output_inputs.push_back(output_cnode->input(idx + 1)); - } - new_fg->set_output(new_fg->NewCNode(new_output_inputs)); - // update users. - for (auto &ret_user : all_users) { - auto ret_user_cnode = ret_user.first->cast(); - ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); - } - } - - auto new_sx = inputs; - new_sx[0] = NewValueNode(new_fg); - return node->func_graph()->NewCNode(new_sx); - } -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h deleted file mode 100644 index 7b35cf5451..0000000000 --- a/mindspore/ccsrc/optimizer/irpass/symbol_resolver.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
-
-#include <memory>
-#include <string>
-
-#include "optimizer/optimizer.h"
-#include "optimizer/irpass.h"
-#include "ir/visitor.h"
-#include "operator/ops.h"
-#include "pipeline/parse/data_converter.h"
-#include "pipeline/parse/python_adapter.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// {prim::kPrimResolve, Ns, Sym}
-class ResolverResolve : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimResolve, {IsVNode, IsVNode})(node);
-    if (sym_ != nullptr) {
-      return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
-    }
-    return nullptr;
-  }
-
-  void Visit(const ValueNodePtr &vnode) override {
-    if (IsValueNode<parse::NameSpace>(vnode)) {
-      ns_ = GetValueNode<parse::NameSpacePtr>(vnode);
-    } else if (ns_ != nullptr && IsValueNode<parse::Symbol>(vnode)) {
-      sym_ = GetValueNode<parse::SymbolPtr>(vnode);
-    }
-  }
-
-  void Reset() {
-    ns_ = nullptr;
-    sym_ = nullptr;
-  }
-
- private:
-  parse::NameSpacePtr ns_{nullptr};
-  parse::SymbolPtr sym_{nullptr};
-};
-
-// {prim::kPrimGetAttr, Ns, Str}
-class ResolverGetattr : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimGetAttr, {IsVNode, IsVNode})(node);
-    if (sym_ != nullptr) {
-      return parse::ResolveSymbol(optimizer->manager(), ns_, sym_, node);
-    }
-    return nullptr;
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (IsValueNode<parse::NameSpace>(node)) {
-      ns_ = GetValueNode<parse::NameSpacePtr>(node);
-    } else if (ns_ != nullptr && IsValueNode<StringImm>(node)) {
-      auto str = GetValue<std::string>(GetValueNode(node));
-      sym_ = std::make_shared<parse::Symbol>(str);
-    }
-  }
-
-  void Reset() {
-    ns_ = nullptr;
-    sym_ = nullptr;
-  }
-
- private:
-  parse::NameSpacePtr ns_{nullptr};
-  parse::SymbolPtr sym_{nullptr};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SYMBOL_RESOLVER_H_
diff --git a/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h b/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h
deleted file mode 100644
index 86ac5bab73..0000000000
--- a/mindspore/ccsrc/optimizer/irpass/tile_eliminate.h
+++ /dev/null
@@ -1,77 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
-
-#include <algorithm>
-#include <vector>
-
-#include "optimizer/irpass.h"
-#include "optimizer/optimizer.h"
-#include "ir/visitor.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-namespace irpass {
-// check if node is value tuple and all one. e.g. (1, 1, 1)
-// {PrimTile, X, MultiOne}
-class TileMultiplyByOne : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    AnfVisitor::Match(prim::kPrimTile, {IsNode, IsVNode})(node);
-
-    // check pattern match
-    if (tuple_ == nullptr) {
-      return nullptr;
-    }
-
-    auto value = GetValueNode(tuple_);
-    auto elements = GetValue<std::vector<int>>(value);
-    if (elements.empty()) {
-      return nullptr;
-    }
-
-    auto cmp = std::all_of(elements.cbegin(), elements.cend(), [](int i) { return i == 1; });
-    if (cmp) {
-      return x_;
-    }
-
-    return nullptr;
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (x_ == nullptr) {
-      x_ = node;
-    } else {
-      tuple_ = node;
-    }
-  }
-
-  void Reset() {
-    x_ = nullptr;
-    tuple_ = nullptr;
-  }
-
- private:
-  AnfNodePtr x_{nullptr}, tuple_{nullptr};
-};
-}  // namespace irpass
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TILE_ELIMINATE_H_
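A standalone sketch of the test TileMultiplyByOne performs on the multiples tuple (IsAllOnes is an illustrative helper, not a MindSpore API): when every multiple is 1, Tile is an identity and the pass returns the input unchanged.

#include <algorithm>
#include <cassert>
#include <vector>

// Tile(X, (1, ..., 1)) repeats every axis once, so it is a no-op.
bool IsAllOnes(const std::vector<int> &multiples) {
  return !multiples.empty() &&
         std::all_of(multiples.cbegin(), multiples.cend(), [](int i) { return i == 1; });
}

int main() {
  assert(IsAllOnes({1, 1, 1}));   // eliminate: return X directly
  assert(!IsAllOnes({1, 2, 1}));  // keep the Tile
  assert(!IsAllOnes({}));         // empty tuple: the pattern does not fire
}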
e.g., (0, 1, 2, 3) -// {PrimTranspose, X, AscendingNums} -class TransposeSameIOEliminater : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { - Reset(); - AnfVisitor::Match(prim::kPrimTranspose, {IsNode, IsVNode})(node); - - // check pattern match - if (tuple_ == nullptr) { - return nullptr; - } - - auto value = GetValueNode(tuple_); - auto elements = GetValue>(value); - if (elements.empty()) { - return nullptr; - } - - int j = 0; - bool cmp = std::all_of(elements.cbegin(), elements.cend(), [&j](int i) { return i == j++; }); - // same IO settings, eliminate this transpose - if (cmp) { - return x_; - } - - return nullptr; - } - - void Visit(const AnfNodePtr &node) override { - if (x_ == nullptr) { - x_ = node; - } else { - tuple_ = node; - } - } - - void Reset() { - x_ = nullptr; - tuple_ = nullptr; - } - - private: - AnfNodePtr x_{nullptr}, tuple_{nullptr}; -}; -} // namespace irpass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_TRANSPOSE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc deleted file mode 100644 index 5e893cf1aa..0000000000 --- a/mindspore/ccsrc/optimizer/opt.cc +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "optimizer/opt.h"
-
-#include <algorithm>
-#include <deque>
-#include <iomanip>
-#include <memory>
-#include <unordered_map>
-
-#include "ir/anf.h"
-#include "ir/manager.h"
-#include "optimizer/optimizer.h"
-#include "utils/log_adapter.h"
-#include "utils/ordered_set.h"
-
-namespace mindspore {
-/* namespace to support opt */
-namespace opt {
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
-                                 const RenormAction &renorm_action) {
-  auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
-  return std::make_shared<Substitution>(transform, name, fn, renorm_action);
-}
-
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
-                                 const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
-  auto fn = [prims](const AnfNodePtr &node) -> bool {
-    if (!node->isa<CNode>()) {
-      return false;
-    }
-
-    auto cnode = node->cast<CNodePtr>();
-    auto inp0 = cnode->input(0);
-    auto prim0 = GetValueNode<PrimitivePtr>(inp0);
-    if (prim0 == nullptr) {
-      return false;
-    }
-
-    auto hash = prim0->Hash();
-    auto const &name = prim0->name();
-    for (auto &prim : prims) {
-      if (hash == prim->Hash() && name == prim->name()) {
-        return true;
-      }
-    }
-    return false;
-  };
-
-  return std::make_shared<Substitution>(transform, name, fn, renorm_action);
-}
-
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
-                                 const PredicateFuncType &predicate, const RenormAction &renorm_action) {
-  return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
-}
-
-AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
-#ifdef ENABLE_PROFILE
-  double t = GetTime();
-#endif
-  AnfNodePtr result = (*transform_)(optimizer, node);
-#ifdef ENABLE_PROFILE
-  if (optimizer != nullptr) {
-    auto time = GetTime();
-    MsProfile::StatTime("substitution." + name_, time - t);
-    if (result != nullptr) {
-      MsProfile::StatTime("match." + name_, time - t);
-    }
-  }
-#endif
-  if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
-    if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) {
-      optimizer->set_is_untyped_generated();
-    }
-  }
-
-  return result;
-}
-
-static bool isTraversable(const AnfNodePtr &node) {
-  if (node == nullptr) {
-    return false;
-  }
-  if (node->isa<CNode>() || node->isa<Parameter>()) {
-    return true;
-  }
-  if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
-    return true;
-  }
-  return false;
-}
-
-bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node,
-                                      const SubstitutionPtr &transform) const {
-#ifdef ENABLE_PROFILE
-  double start = GetTime();
-#endif
-  FuncGraphManagerPtr manager = optimizer->manager();
-  auto seen = NewSeenGeneration();
-  // 1024 is for the initial capacity of deque
-  std::deque<AnfNodePtr> todo(1024);
-  todo.clear();
-  todo.push_back(root_node);
-  bool changes = false;
-
-  auto &all_nodes = manager->all_nodes();
-  while (!todo.empty()) {
-    AnfNodePtr node = todo.front();
-    todo.pop_front();
-
-    // check whether this node has been matched.
-    if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) {
-      continue;
-    }
-    node->seen_ = seen;
-
-    // select nodes to which this transform can be applied.
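    // (The predicate is a cheap structural pre-filter; for the single-primitive
    // overload of MakeSubstitution above it is simply
    //   [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }
    // so the potentially expensive transform body below runs only on candidate
    // nodes. After a successful Replace, the users of the rewritten node are
    // re-queued with their seen_ mark cleared, letting cascading rewrites
    // converge within a single ApplyTransform sweep.)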
-    bool is_match = transform->predicate_(node);
-
-    // apply transform on this node
-    bool change = false;
-    if (is_match) {
-      auto ret = (*transform)(optimizer, node);
-      if (ret != nullptr && ret != node) {
-        change = true;
-        changes = true;
-#ifdef ENABLE_PROFILE
-        double t = GetTime();
-#endif
-        (void)manager->Replace(node, ret);
-#ifdef ENABLE_PROFILE
-        MsProfile::StatTime("replace." + transform->name_, GetTime() - t);
-#endif
-        node = ret;
-      }
-    }
-
-    // find successors, and add them to the todo list
-    if (IsValueNode<FuncGraph>(node)) {
-      todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
-    }
-
-    if (node->isa<CNode>()) {
-      auto &inputs = node->cast<CNodePtr>()->inputs();
-      (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
-    }
-
-    auto &node_users = manager->node_users();
-    if (change && node_users.find(node) != node_users.end()) {
-      for (auto &use : node_users[node]) {
-        auto use_node = use.first;
-        if (use_node == nullptr) {
-          continue;
-        }
-        todo.push_back(use_node);
-        if (use_node->seen_ == seen) {
-          use_node->seen_--;
-        }
-      }
-    }
-  }
-
-#ifdef ENABLE_PROFILE
-  MsProfile::StatTime("opt.transform." + optimizer->name(), GetTime() - start);
-#endif
-  return changes;
-}
-
-bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const {
-  MS_EXCEPTION_IF_NULL(optimizer);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  FuncGraphManagerPtr manager = optimizer->manager();
-  manager->AddFuncGraph(func_graph);
-
-  // for transform status counting
-  size_t space = 0;
-  std::unordered_map<std::string, std::vector<bool>> status;
-  if (optimizer->is_on_debug_) {
-    for (size_t i = 0; i < list_.size(); i++) {
-      status[list_[i]->name_ + std::to_string(i)] = {};
-    }
-  }
-
-  bool loop = false;
-  bool changes = false;
-
-  do {
-    loop = false;
-    for (size_t i = 0; i < list_.size(); i++) {
-      auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]);
-      changes = changes || change;
-      loop = loop || change;
-
-      // record the status of each transform
-      if (optimizer->is_on_debug_) {
-        status[list_[i]->name_ + std::to_string(i)].push_back(change);
-        space = std::max(list_[i]->name_.size(), space);
-      }
-    }
-
-    if (is_once_) {
-      break;
-    }
-  } while (loop);
-
-  // display the status of each transform
-  if (optimizer->is_on_debug_) {
-    std::stringstream ss;
-    ss << std::endl
-       << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
-       << std::endl;
-    for (size_t i = 0; i < list_.size(); i++) {
-      auto name = list_[i]->name_;
-      ss << std::left << std::setw(space + 4) << name << "\t";
-      for (auto change : status[name + std::to_string(i)]) {
-        ss << change << " ";
-      }
-      ss << std::endl;
-    }
-    MS_LOG(DEBUG) << ss.str();
-  }
-
-  return changes;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/opt.h b/mindspore/ccsrc/optimizer/opt.h
deleted file mode 100644
index 6601d969d2..0000000000
--- a/mindspore/ccsrc/optimizer/opt.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "ir/anf.h"
-#include "ir/func_graph.h"
-#include "ir/optimizer_caller.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-/* namespace to support opt */
-namespace opt {
-
-// Define the interaction mode between an Optimize pass and the Renormalize pass.
-// FORCE_RENORM: if the pass modified the graph, the next Renormalize will be executed.
-// CHECK_RENORM: check whether the new node is untyped to decide whether the next Renormalize will be executed.
-enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
-
-class Substitution {
- public:
-  OptimizerCallerPtr transform_;
-  std::string name_;
-  PredicateFuncType predicate_{nullptr};
-  // an enum to mark how this Substitution relates to the Renormalize pass
-  RenormAction renorm_action_;
-  Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
-               const RenormAction &renorm_action)
-      : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
-  ~Substitution() = default;
-  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
-};
-
-using SubstitutionPtr = std::shared_ptr<Substitution>;
-
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
-                                 const RenormAction &action_renorm = CHECK_RENORM);
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
-                                 const std::vector<PrimitivePtr> &prims,
-                                 const RenormAction &action_renorm = CHECK_RENORM);
-SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
-                                 const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
-
-class SubstitutionList {
- public:
-  explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false)
-      : list_(patterns), is_once_(is_once) {}
-  ~SubstitutionList() = default;
-
-  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const;
-
- private:
-  bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const;
-  std::vector<SubstitutionPtr> list_;
-  // a flag to mark that this list of Substitutions can be executed only once
-  bool is_once_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h
deleted file mode 100644
index a98a59caf2..0000000000
--- a/mindspore/ccsrc/optimizer/optimizer.h
+++ /dev/null
@@ -1,242 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "debug/trace.h" -#include "optimizer/opt.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -using OptimizeGraphFunc = std::function; - -class OptPassConfig { - public: - explicit OptPassConfig(const OptimizeGraphFunc &func) : func_(func) {} - explicit OptPassConfig(const std::vector &list, bool is_once = false) - : list_(list), is_once_(is_once) {} - OptPassConfig(const std::initializer_list &list, bool is_once = false) - : list_(list), is_once_(is_once) {} - ~OptPassConfig() = default; - - const std::vector &list() const { return list_; } - const OptimizeGraphFunc &func() const { return func_; } - - static OptPassConfig Renormalize() { return OptPassConfig(); } - const bool is_renormalize() const { return is_renormalize_; } - - const bool is_once() const { return is_once_; } - - private: - OptPassConfig() : is_renormalize_(true) {} - - OptimizeGraphFunc func_; - std::vector list_; - bool is_renormalize_{false}; - bool is_once_{false}; -}; - -class OptPass { - public: - explicit OptPass(const OptimizeGraphFunc &func) : pass_func_(func) {} - ~OptPass() = default; - - bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { - return pass_func_(func_graph, optimizer); - } - - static OptPass Renormalize() { return OptPass(); } - const bool is_renormalize() const { return is_renormalize_; } - - private: - OptPass() : is_renormalize_(true) {} - - OptimizeGraphFunc pass_func_; - bool is_renormalize_{false}; -}; -using OptPassGroupMap = std::vector>; - -class Optimizer : public std::enable_shared_from_this { - public: - Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) - : name_(name), - resource_(resource_ptr), - run_only_once_(false), - is_watch_renormalize_(false), - is_enable_(true), - is_untyped_generated_(false) {} - virtual ~Optimizer() = default; - - void Init(const OptPassGroupMap &passes, bool run_only_once) { - run_only_once_ = run_only_once; - is_watch_renormalize_ = false; - is_untyped_generated_ = false; - is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); - - for (auto &iter : passes) { - const std::string &name = iter.first; - pass_names_.push_back(name); - - const OptPassConfig &config = iter.second; - if (config.is_renormalize()) { - passes_.push_back(OptPass::Renormalize()); - continue; - } - - if (config.list().size() > 0) { - OptimizeGraphFunc func = SubstitutionList(config.list(), config.is_once()); - passes_.push_back(OptPass(func)); - continue; - } - - passes_.push_back(OptPass(config.func())); - } - - if (passes_.size() == 1) { - run_only_once_ = true; - } - } - - static std::shared_ptr MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, - const OptPassGroupMap &passes, bool run_only_once = false, - bool watch_renormalize = false) { - OptimizerPtr optimizer = std::make_shared(name, resource_ptr); - optimizer->Init(passes, run_only_once); - if (watch_renormalize) { - optimizer->enable_watch_renormalize(); - } - return optimizer; - } - - FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { - if (!is_enable_) { - return func_graph; - } - // Optimizer step counter; - int counter = 1; - 
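    // The loop below re-runs the whole pass list until one full round leaves the
    // graph unchanged (a fixpoint), unless run_only_once_ is set; counter numbers
    // the rounds for profiling laps and for the per-pass IR dump file names.
    // A minimal usage sketch (the names "demo", cfg and resource are hypothetical,
    // for illustration only):
    //   auto opt = Optimizer::MakeOptimizer(
    //     "demo", resource, {{"simplify", cfg}, {"renormalize", OptPassConfig::Renormalize()}});
    //   func_graph = opt->step(func_graph);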
bool changes = true; - - while (changes) { - changes = false; - auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() { - for (size_t i = 0; i < passes_.size(); ++i) { - const OptPass &opt = passes_[i]; - CurPass_ = {counter, pass_names_[i]}; - auto opt_func = [&func_graph, &changes, &opt, this]() { - if (opt.is_renormalize()) { - auto resource_ptr = std::dynamic_pointer_cast(resource_); - if (resource_ptr != nullptr) { - // StepParallel may replace the AbstractValue of the parameters of func_graph, - // So generate the args_spec from parameters. - abstract::AbstractBasePtrList maybe_new_args_spec; - if (is_watch_renormalize_) { - if (is_untyped_generated_) { - std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), - std::back_inserter(maybe_new_args_spec), - [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); - func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); - clear_is_untyped_generated(); - } else { - MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; - } - } else { - std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), - std::back_inserter(maybe_new_args_spec), - [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); - func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); - } - } - } else if (opt(func_graph, shared_from_this())) { - changes = true; - } - }; - use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); - if (is_on_debug_ && MsContext::GetInstance()->save_graphs_flag()) { - MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; - auto fg_name = - "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; - func_graph->DumpFuncGraph(fg_name); - DumpIR(fg_name + ".ir", func_graph); - ExportIR(fg_name + ".dat", "", func_graph); - MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; - } - } - }; - use_profile ? 
(WITH(MsProfile::GetProfile()->Lap(counter)) run_runc) : run_runc(); - counter++; - - if (run_only_once_) { - break; - } - } - return func_graph; - } - - pipeline::ResourceBasePtr resource() const { return resource_; } - FuncGraphManagerPtr manager() const { - if (resource_ != nullptr) { - return resource_->manager(); - } - MS_LOG(EXCEPTION) << "No ResourceBase exists."; - } - - const std::string name() const { return name_; } - - void set_is_untyped_generated() { is_untyped_generated_ = true; } - void clear_is_untyped_generated() { is_untyped_generated_ = false; } - - void enable_watch_renormalize() { is_watch_renormalize_ = true; } - void disable_watch_renormalize() { is_watch_renormalize_ = false; } - bool is_watch_renormalize() { return is_watch_renormalize_; } - void set_enable(bool enable) { is_enable_ = enable; } - - struct { - int counter; - std::string name; - } CurPass_; - - bool is_on_debug_{false}; - - private: - const std::string name_; - pipeline::ResourceBasePtr resource_; - std::vector passes_; - std::vector pass_names_; - bool run_only_once_; - bool is_watch_renormalize_; - bool is_enable_; - bool is_untyped_generated_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/optimizer/pass_group.cc b/mindspore/ccsrc/optimizer/pass_group.cc deleted file mode 100644 index 2d1ab07f7d..0000000000 --- a/mindspore/ccsrc/optimizer/pass_group.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#include "optimizer/pass_group.h"
-
-namespace mindspore {
-namespace opt {
-namespace python_pass {
-void PassGroup::AddPass(const PythonPassPtr &pass) {
-  if (pass != nullptr) {
-    passes_.push_back(pass);
-  }
-}
-
-bool PassGroup::DeletePass(const std::string &pass_name) {
-  for (auto iter = passes_.begin(); iter != passes_.end(); iter++) {
-    if ((*iter)->name() == pass_name) {
-      *iter = nullptr;
-      passes_.erase(iter);
-      return true;
-    }
-  }
-  return false;
-}
-
-bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
-  if (func_graph == nullptr) {
-    return false;
-  }
-  bool changed = false;
-  for (const auto &pass : passes) {
-    if (pass != nullptr) {
-      if (pass->Run(func_graph)) {
-        changed = true;
-      }
-    }
-  }
-  return changed;
-}
-
-bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
-  bool changed = false;
-  // run all passes until no pass changes the graph
-  bool change = true;
-  while (change) {
-    change = Run(func_graph, passes_);
-    changed = change || changed;
-    if (run_only_once_) {
-      break;
-    }
-  }
-  return changed;
-}
-
-}  // namespace python_pass
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/pass_group.h b/mindspore/ccsrc/optimizer/pass_group.h
deleted file mode 100644
index 895f5a4128..0000000000
--- a/mindspore/ccsrc/optimizer/pass_group.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
-#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "optimizer/py_pass.h"
-
-namespace mindspore {
-namespace opt {
-namespace python_pass {
-class PassGroup {
- public:
-  explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false)
-      : name_(name), passes_{}, run_only_once_(run_only_once) {}
-  virtual ~PassGroup() = default;
-  // Add a graph pass; the pass object will be freed when the pass manager is freed.
-  void AddPass(const PythonPassPtr &pass);
-  // Delete a graph pass before the pass manager is freed.
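  // (Deletion matches on PythonPass::name(): the first registered pass whose name
  // equals pass_name is erased and true is returned; false means nothing matched.)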
- bool DeletePass(const std::string &pass_name); - // Run passes added in pass manager on the input graph - // @param [inout] graph The graph to be optimized - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph) const; - // Run the given graph passes on the input graph - // @param [inout] graph The graph to be optimized - // @param [in] passes The given graph passes - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; - std::string name() const { return name_; } - - private: - const std::string name_; - std::vector passes_; - bool run_only_once_; -}; -using PassGroupPtr = std::shared_ptr; -} // namespace python_pass -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ diff --git a/mindspore/ccsrc/optimizer/py_pass.cc b/mindspore/ccsrc/optimizer/py_pass.cc deleted file mode 100644 index 842ccb75b9..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass.cc +++ /dev/null @@ -1,237 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "optimizer/py_pass.h" -#include -#include -#include -#include -#include - -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/resource.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -namespace internal { -std::string GetNodeRepr(AnfNodePtr node) { - if (node != nullptr) { - if (node->isa()) { - std::string repr = "("; - auto const &inputs = node->cast()->inputs(); - for (auto &input : inputs) { - repr += " "; - repr += GetNodeRepr(input); - repr += " "; - } - repr += ")"; - return repr; - } - if (node->isa()) { - return GetValueNode(node)->ToString(); - } - return node->ToString(); - } - return ""; -} - -void ResolveFuncGraph_(const FuncGraphPtr &fg) { - auto manager = Manage(fg, false); - parse::python_adapter::set_use_signature_in_resolve(false); - parse::ResolveAll(manager); - parse::python_adapter::set_use_signature_in_resolve(true); -} - -bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { - if (node == nullptr) { - return false; - } - MS_EXCEPTION_IF_NULL(pattern); - if (pattern->isa()) { - if (!node->isa()) { - return false; - } - if (GetNodeRepr(pattern) == GetNodeRepr(node)) { - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); - return true; - } - return false; - } else if (pattern->isa()) { - MS_LOG(DEBUG) << pattern->ToString() + "\n"; - // add to equiv_ptr - equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); - return true; - } else if (pattern->isa()) { - // match every single sub ANode - if (!node->isa()) { - return false; - } - auto pattern_inputs = pattern->cast()->inputs(); - auto node_inputs = node->cast()->inputs(); - if (pattern_inputs.size() != node_inputs.size()) { - return false; - } - 
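    // Inputs are matched element-wise: recurse over the pattern's and the node's
    // input lists in lockstep, failing fast on the first pair of sub-trees that
    // does not match.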
for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); - p_item++, node_item++) { - auto res = Match(*p_item, *node_item, equiv_ptr); - if (!res) { - return false; - } - } - return true; - } - MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; -} - -AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, - const NodeEquivPtr &equiv_ptr) { - if (cur_raw_dst_node_->isa()) { - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; - } else if (cur_raw_dst_node_->isa()) { - // check primitive ValueNode - auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast()->value()->ToString()); - if (sub_pair != equiv_ptr->end()) { - return sub_pair->second; - } - return cur_raw_dst_node_; - } else if (cur_raw_dst_node_->isa()) { - std::vector new_inputs; - auto inputs = cur_raw_dst_node_->cast()->inputs(); - for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { - auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); - new_inputs.push_back(subed); - } - return func_graph->NewCNode(new_inputs); - } - MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); -} - -bool isTraversable(const AnfNodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->isa() || node->isa()) { - return true; - } - if (IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; -} -} // namespace internal - -void PythonPass::Build(const py::function &src, const py::function &dst) { - // 1. get FuncGraph from py::function - auto src_fg_ = parse::ParsePythonCode(src); - auto dst_fg_ = parse::ParsePythonCode(dst); - if (src_fg_ == nullptr || dst_fg_ == nullptr) { - MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; - } - // 2. Resolve - internal::ResolveFuncGraph_(src_fg_); - internal::ResolveFuncGraph_(dst_fg_); - // 3. from FuncGraphPtr to ValueNode - src_node_ = src_fg_->output(); - dst_node_ = dst_fg_->output(); -} - -PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, - bool multigraph) - : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { - Build(src, dst); -} - -AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - auto equiv_ptr = std::make_shared(); - bool is_a_match = internal::Match(src_node_, node, equiv_ptr); - if (is_a_match) { - auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); - MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; - return new_node; - } - return nullptr; -} - -bool PythonPass::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - auto seen = NewSeenGeneration(); - // 1024 is for the initial capacity of deque - std::deque todo(1024); - todo.push_back(func_graph->output()); - bool changes = false; - - auto &all_nodes = manager->all_nodes(); - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - - // check whether this node has been matched. 
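    // (The seen_ generation mark mirrors SubstitutionList::ApplyTransform in
    // opt.cc: nodes already visited in this sweep, or nodes the manager no
    // longer owns, are skipped.)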
- if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { - continue; - } - node->seen_ = seen; - - // select nodes that this transform can be applied. - AnfNodePtr new_node = Run(func_graph, node); - bool change = (new_node != nullptr); - if (new_node != nullptr && new_node != node) { - (void)manager->Replace(node, new_node); - } else if (new_node == nullptr) { - new_node = node; - } - if (run_only_once_) { - return change; - } - - // find success, and add them to todo list - if (IsValueNode(node)) { - todo.push_back(GetValueNode(node)->output()); - } - - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); - } - - auto &node_users = manager->node_users(); - if (change && node_users.find(node) != node_users.end()) { - for (auto &use : node_users[node]) { - auto use_node = use.first; - if (use_node == nullptr) { - continue; - } - todo.push_back(use_node); - if (use_node->seen_ == seen) { - use_node->seen_--; - } - } - } - } - return changes; -} -} // namespace python_pass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.cc b/mindspore/ccsrc/optimizer/py_pass_manager.cc deleted file mode 100644 index 1c36e93c9a..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass_manager.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "optimizer/py_pass_manager.h" - -#include -#include -#include -#include - -#include "ir/manager.h" -#include "optimizer/pass_group.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -PyPassManagerPtr PyPassManager::global_instance = nullptr; -std::unordered_map PyPassManager::phase_to_group_; - -PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { - auto pm = phase_to_group_.find(phase); - if (pm == phase_to_group_.end()) { - return nullptr; - } - return pm->second; -} - -PyPassManagerPtr PyPassManager::GetInstance() { - if (global_instance == nullptr) { - global_instance = std::shared_ptr(new (std::nothrow) PyPassManager()); - } - return global_instance; -} - -PyPassManager::PyPassManager() { - phase_to_group_[Phase::RESOLVE] = std::make_shared(); - phase_to_group_[Phase::OPT] = std::make_shared(); -} - -void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, - Phase phase, bool run_only_once, bool multigraph) { - auto cur_pm = GetPassGroup(phase); - MS_EXCEPTION_IF_NULL(cur_pm); - PythonPassPtr new_pass = std::make_shared(pass_name, pattern, target, run_only_once, multigraph); - cur_pm->AddPass(new_pass); -} - -void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { - auto cur_pm = GetPassGroup(phase); - MS_EXCEPTION_IF_NULL(cur_pm); - if (!cur_pm->DeletePass(pass_name)) { - MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; - } -} - -void PyPassManager::ClearRes() { - MS_LOG(INFO) << "Clear PyPassManager resources!"; - global_instance = nullptr; - phase_to_group_.clear(); -} - -REGISTER_PYBIND_DEFINE( - PyPassManager_, ([](const py::module *m) { - (void)py::enum_(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); - (void)py::class_>(*m, "PyPassManager_") - .def(py::init([]() { return PyPassManager::GetInstance(); })) - .def("registe", &PyPassManager::Registe, "Registe python pass") - .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); - })); -} // namespace python_pass -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/py_pass_manager.h b/mindspore/ccsrc/optimizer/py_pass_manager.h deleted file mode 100644 index f7218d5ab2..0000000000 --- a/mindspore/ccsrc/optimizer/py_pass_manager.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive_py.h" -#include "utils/graph_utils.h" -#include "common/utils.h" - -#include "pipeline/parse/resolve.h" -#include "optimizer/py_pass.h" -#include "optimizer/pass_group.h" - -namespace mindspore { -namespace opt { -namespace python_pass { -class PyPassManager; -using PyPassManagerPtr = std::shared_ptr; - -enum Phase { RESOLVE, OPT }; - -class PyPassManager { - protected: - PyPassManager(); - static PyPassManagerPtr global_instance; - - public: - // Singletons should not be cloneable and assignable - PyPassManager(const PyPassManager &other) = delete; - void operator=(const PyPassManager &) = delete; - // Access the only global instance - static PyPassManagerPtr GetInstance(); - virtual ~PyPassManager() = default; - void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, - Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); - void Unregiste(const std::string &pass_name, Phase phase); - PassGroupPtr GetPassGroup(Phase phase); - void ClearRes(); - - private: - static std::unordered_map phase_to_group_; -}; -} // namespace python_pass -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/parallel/CMakeLists.txt b/mindspore/ccsrc/parallel/CMakeLists.txt deleted file mode 100644 index 76ac2cfcd7..0000000000 --- a/mindspore/ccsrc/parallel/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -file(GLOB_RECURSE _PARALLEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -list(REMOVE_ITEM _PARALLEL_SRC_FILES "ps/util.cc" "ps/scheduler.cc" "ps/optimizer_info.cc" "ps/optimizer_info_builder.cc") -if (ENABLE_DUMP_PROTO) - list(REMOVE_ITEM _PARALLEL_SRC_FILES "parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") -endif () - -set_property(SOURCE ${_PARALLEL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARALLEL) -add_library(_mindspore_parallel_obj OBJECT ${_PARALLEL_SRC_FILES}) diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc deleted file mode 100644 index 30173e533c..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ /dev/null @@ -1,435 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "parallel/costmodel_context.h" -#include "parallel/graph_util/node_info.h" -#include "parallel/status.h" -#include "parallel/step_parallel.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto node_set = manager->node_users()[para]; - std::unordered_set cnode_set; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { - (void)cnode_set.emplace(cnode); - } else { - auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto &cnode_sub : cnode_set_sub) { - (void)cnode_set.emplace(cnode_sub); - } - } - } - return cnode_set; -} - -Status AllreduceFusion::AddNodeToGraph() { - const auto ¶meters = root_graph_->parameters(); - for (auto ¶meter : parameters) { - if (!ParameterRequireGrad(parameter)) { - continue; - } - auto cnode_set = FindCNodesWithPara(parameter); - if (cnode_set.empty()) { - continue; - } - for (auto &cnode : cnode_set) { - MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); - if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { - MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); - return FAILED; - } - } - } - return SUCCESS; -} - -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! 
Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(from); - std::unordered_map cnode_dist; - if (!from->isa()) { - return cnode_dist; - } - auto cnode = from->cast(); - if (!IsValueNode(cnode->input(0))) { - return cnode_dist; - } - - MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) - << " operator_info: " << (cnode->operator_info() != nullptr); - - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); - MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; - - if (allreduce_graph_.NodeInGraph(cnode)) { - cnode_dist[cnode] = cost; - return cnode_dist; - } else { - auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto &ele : cnode_dist_next) { - cnode_dist[ele.first] = cost + ele.second; - } - } - } else { - auto cnode_dist_next = FindNextCNodes(cnode); - for (auto &ele : cnode_dist_next) { - cnode_dist[ele.first] = ele.second; - } - } - return cnode_dist; -} - -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - const auto &from_inputs = from->inputs(); - std::unordered_map dist_map; - MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto &input_node : from_inputs) { - auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto &ele : cnode_dist) { - (void)dist_map.emplace(ele); - } - } - return dist_map; -} - -Status AllreduceFusion::AddEdgeToGraph() { - std::unordered_map cnode_state_map; - const auto &cnodes = allreduce_graph_.cnode_set(); - for (auto &cnode : cnodes) { - cnode_state_map[cnode] = 0; - } - const auto &head_cnode = allreduce_graph_.head_cnode(); - std::queue cnode_queue; - cnode_queue.emplace(head_cnode); - cnode_state_map[head_cnode] = 1; - - while (!cnode_queue.empty()) { - const auto cur_cnode = cnode_queue.front(); - cnode_queue.pop(); - cnode_state_map[cur_cnode] = 2; - auto next = FindNextCNodes(cur_cnode); - for (auto &ele : next) { - auto &cnode = ele.first; - auto &dist = ele.second; - if (cnode_state_map[cnode] == 0) { - cnode_queue.emplace(cnode); - cnode_state_map[cnode] = 1; - } - if (allreduce_graph_.AddEdge(cur_cnode, cnode, dist) != SUCCESS) { - MS_LOG(ERROR) << "AddEdge error"; - return FAILED; - } - MS_LOG(DEBUG) << "from " << cur_cnode->DebugString() << ", to " << cnode->DebugString() << " dist " << dist; - } - } - return SUCCESS; -} - -std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { - if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { - MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! 
Max recursive call times is " - << MAX_RECURSIVE_CALL_TIMES; - } - MS_EXCEPTION_IF_NULL(para); - MS_EXCEPTION_IF_NULL(para->func_graph()); - FuncGraphManagerPtr manager = para->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[para]; - std::vector cnode_list; - for (auto &node_pair : node_set) { - auto cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - auto node_prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == CAST) { - auto mirror_cnodes = FindMirror(node_pair.first, recursive_times + 1); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << "mirror node after cast not found"; - continue; - } - if (mirror_cnodes.size() > 1) { - MS_LOG(EXCEPTION) << "mirror node after cast number is not 1"; - } - cnode_list.emplace_back(mirror_cnodes[0]); - } - if (node_prim->name() == MIRROR_OPERATOR) { - cnode_list.emplace_back(cnode); - } - } - return cnode_list; -} - -void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { - MS_EXCEPTION_IF_NULL(mirror_cnode); - MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; - auto node_prim = GetValueNode(mirror_cnode->input(0)); - auto old_value_ptr = node_prim->GetAttr(FUSION); - if (old_value_ptr != nullptr) { - if (old_value_ptr->isa()) { - int32_t old_value = old_value_ptr->cast()->value(); - if (old_value < fusion) { - return; - } - } - } - (void)node_prim->AddAttr(FUSION, MakeValue(std::make_shared(fusion))); - (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); -} - -Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { - auto mirror_cnodes = FindMirror(para); - if (mirror_cnodes.empty()) { - MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; - return SUCCESS; - } - if (mirror_cnodes.size() > 2) { - for (auto &mirror_cnode : mirror_cnodes) { - MS_EXCEPTION_IF_NULL(mirror_cnode); - MS_LOG(INFO) << mirror_cnode->DebugString(); - } - MS_EXCEPTION_IF_NULL(para); - MS_LOG(ERROR) << para->ToString() << " FindMirror is more than 2. 
" << mirror_cnodes.size() - << "Mirror CNode found."; - return FAILED; - } - for (auto &mirror_cnode : mirror_cnodes) { - auto parameter_name = ParameterName(para); - SetMirrorFusion(mirror_cnode, fusion, parameter_name); - } - return SUCCESS; -} - -Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { - for (auto ¶m_node : paras) { - if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusion(const std::vector &cost_map) { - if (cost_map.size() < 2) { - MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); - return FAILED; - } - int32_t fusion = 1; - for (auto cost_iter = cost_map.end() - 1; cost_iter != cost_map.begin(); --cost_iter) { - auto paras = allreduce_graph_.GetParaByCost(*(cost_iter - 1), *cost_iter); - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - fusion++; - } - return SUCCESS; -} - -std::vector AllreduceFusion::GenerateCostMap(int32_t fusion_times, double tail_percent) const { - double offset = allreduce_graph_.max() * (1 - tail_percent) / (fusion_times - 1); - MS_LOG(DEBUG) << "max = " << allreduce_graph_.max() << ", offset = " << offset; - std::vector cost_map; - double begin = 0; - for (auto i = 0; i < fusion_times - 1; i++) { - cost_map.push_back(begin); - begin += offset; - } - cost_map.push_back(allreduce_graph_.max() * (1 - tail_percent)); - cost_map.push_back(allreduce_graph_.max()); - MS_LOG(DEBUG) << "cost_map = " << cost_map; - return cost_map; -} - -Status AllreduceFusion::SetFusionByBackwardCompTime() { - auto fusion_times = CostModelContext::GetInstance()->costmodel_allreduce_fusion_times(); - if (fusion_times < 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_times' is " << fusion_times << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - auto tail_percent = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_percent(); - if (tail_percent < 0 || tail_percent >= 1) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_tail_percent' is " << tail_percent - << ". Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - const auto cost_map = GenerateCostMap(fusion_times, tail_percent); - MS_LOG(DEBUG) << "AllreduceGraph GenerateCostMap succeed."; - if (SetFusion(cost_map) != SUCCESS) { - MS_LOG(ERROR) << "SetFusion failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusion succeed."; - return SUCCESS; -} - -Status AllreduceFusion::GetSetFusionByBackwardCompAndAllreduceTimeParams() { - tail_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_tail_time(); - if (tail_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_inherent_time_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_inherent_time(); - if (allreduce_inherent_time_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - if (tail_time_ <= allreduce_inherent_time_) { - MS_LOG(INFO) << "'costmodel_allreduce_tail_time' is " << tail_time_ - << "'costmodel_allreduce_fusion_allreduce_inherent_time' is " << allreduce_inherent_time_ - << ".tail_time is not more than allreduce_inherent_time. 
Bypass ProcessAllreduceFusion"; - return FAILED; - } - allreduce_bandwidth_ = CostModelContext::GetInstance()->costmodel_allreduce_fusion_allreduce_bandwidth(); - if (allreduce_bandwidth_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_allreduce_bandwidth' is " << allreduce_bandwidth_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - computation_time_parameter_ = - CostModelContext::GetInstance()->costmodel_allreduce_fusion_computation_time_parameter(); - if (computation_time_parameter_ <= 0) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_computation_time_parameter' is " << computation_time_parameter_ - << ". Bypass ProcessAllreduceFusion"; - return FAILED; - } - return SUCCESS; -} - -Status AllreduceFusion::SetFusionByBackwardCompAndAllreduceTime() { - if (GetSetFusionByBackwardCompAndAllreduceTimeParams() != SUCCESS) { - MS_LOG(ERROR) << "GetSetFusionByBackwardCompAndAllreduceTimeParams failed!"; - return FAILED; - } - allreduce_graph_.SortArnode(); - if (allreduce_graph_.RemoveExtraParas() != SUCCESS) { - MS_LOG(ERROR) << "RemoveExtraParas failed!"; - return FAILED; - } - double para_size = (tail_time_ - allreduce_inherent_time_) / allreduce_bandwidth_; - double to_cost = allreduce_graph_.max(); - int32_t fusion = 1; - while (to_cost != 0) { - MS_LOG(INFO) << "to_cost: " << to_cost << " para_size: " << para_size; - auto node_cost_pair = allreduce_graph_.GetParaByParaSize(to_cost, para_size); - MS_LOG(INFO) << "para size: " << node_cost_pair.first.size() << " from_cost: " << node_cost_pair.second; - auto paras = node_cost_pair.first; - if (FindMirrorAndSetFusion(paras, fusion) != SUCCESS) { - MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; - return FAILED; - } - fusion++; - para_size = ((to_cost - node_cost_pair.second) * computation_time_parameter_ - allreduce_inherent_time_) / - allreduce_bandwidth_; - to_cost = node_cost_pair.second; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByBackwardCompAndAllreduceTime succeed."; - return SUCCESS; -} - -Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { - if (algorithm == 1) { - return SetFusionByBackwardCompTime(); - } - return SetFusionByBackwardCompAndAllreduceTime(); -} - -Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { - if (ret == nullptr) { - MS_LOG(ERROR) << "ret is nullptr."; - return FAILED; - } - auto algorithm = CostModelContext::GetInstance()->costmodel_allreduce_fusion_algorithm(); - if (algorithm < 1 || algorithm > 2) { - MS_LOG(INFO) << "'costmodel_allreduce_fusion_algorithm' is " << algorithm << ". 
Bypass ProcessAllreduceFusion"; - return SUCCESS; - } - ret_ = ret; - root_graph_ = ret_->func_graph(); - MS_EXCEPTION_IF_NULL(root_graph_); - auto graph_set = ForwardGraph(root_graph_); - if (graph_set.size() > 1) { - MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; - return SUCCESS; - } - auto forward_graph = *(graph_set.begin()); - MS_EXCEPTION_IF_NULL(forward_graph); - forward_ret_ = forward_graph->get_return(); - MS_EXCEPTION_IF_NULL(forward_ret_); - - if (allreduce_graph_.set_head_cnode(forward_ret_) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceGraph set_head_cnode failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph set_head_cnode succeed."; - if (AddNodeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph AddNodeToGraph succeed."; - if (AddEdgeToGraph() != SUCCESS) { - MS_LOG(ERROR) << "AddNodeToGraph failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph AddEdgeToGraph succeed."; - if (SetFusionByAlgorithm(algorithm) != SUCCESS) { - MS_LOG(ERROR) << "SetFusionByAlgorithm failed."; - return FAILED; - } - MS_LOG(DEBUG) << "AllreduceGraph SetFusionByAlgorithm succeed."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h deleted file mode 100644 index 43a9935095..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ - -#include -#include -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_graph.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -using CNodeCostMap = std::unordered_map; - -constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM = 0; -constexpr int32_t DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES = 0; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; -constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; - -constexpr char FUSION[] = "fusion"; -constexpr char PARAMETER[] = "parameter"; -const uint32_t MAX_RECURSIVE_CALL_TIMES = 100; -class AllreduceFusion { - public: - AllreduceFusion() - : allreduce_graph_(), - ret_(nullptr), - forward_ret_(nullptr), - root_graph_(nullptr), - tail_time_(0), - allreduce_inherent_time_(0), - allreduce_bandwidth_(0), - computation_time_parameter_(0) {} - virtual ~AllreduceFusion() = default; - Status ProcessAllreduceFusion(const CNodePtr &ret); - - private: - Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; - Status AddEdgeToGraph(); - std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector &cost_map); - Status SetFusionByAlgorithm(int32_t algorithm); - Status SetFusionByBackwardCompTime(); - Status SetFusionByBackwardCompAndAllreduceTime(); - Status GetSetFusionByBackwardCompAndAllreduceTimeParams(); - - AllreduceGraph allreduce_graph_; - CNodePtr ret_; - CNodePtr forward_ret_; - FuncGraphPtr root_graph_; - double tail_time_; - double allreduce_inherent_time_; - double allreduce_bandwidth_; - double computation_time_parameter_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc deleted file mode 100644 index 2a98a38add..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc deleted file mode index 2a98a38add..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/allreduce_graph.h" -#include <algorithm> -#include <memory> -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_node.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) { - AllreduceNodePtr arnode; - auto cnode_emplace_return = cnode_set_.emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!"; - auto cnode_arnode_pair = cnode_arnode_map_.find(node); - if (cnode_arnode_pair == cnode_arnode_map_.end()) { - MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!"; - } - arnode = cnode_arnode_pair->second; - } else { - arnode = std::make_shared<AllreduceNode>(); - } - - if (arnode->Init(node) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode Init failed"; - return FAILED; - } - if (arnode->AddPara(para) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode AddPara failed"; - return FAILED; - } - cnode_arnode_map_[node] = arnode; - - auto arnode_emplace_return = arnode_set_.insert(arnode); - if (!arnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!"; - } - cnode_emplace_return = para_cnodeset_map_[para].emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope() - << "'s cnodeset!"; - } - auto para_emplace_return = cnode_paraset_map_[node].emplace(para); - if (!para_emplace_return.second) { - MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString() - << "'s paraset!"; - } - return SUCCESS; -} - -Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { - auto from_arnode_iter = cnode_arnode_map_.find(from); - if (from_arnode_iter == cnode_arnode_map_.end()) { - MS_LOG(ERROR) << "cnode from: " << from->DebugString() << " has not been added"; - PrintCNodeSet(); - return FAILED; - } - auto to_arnode_iter = cnode_arnode_map_.find(to); - if (to_arnode_iter == cnode_arnode_map_.end()) { - MS_LOG(ERROR) << "cnode to: " << to->DebugString() << " has not been added"; - PrintCNodeSet(); - return FAILED; - } - auto from_arnode = from_arnode_iter->second; - auto to_arnode = to_arnode_iter->second; - if (from_arnode->AddNext(to_arnode) != SUCCESS) { - MS_LOG(ERROR) << "from_arnode AddNext failed"; - return FAILED; - } - if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) { - MS_LOG(ERROR) << "to_arnode AddPrev failed"; - return FAILED; - } - max_ = std::max(max_, to_arnode->depend_feat_size()); - MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString(); - MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size() - << ", to depend_feat_size: " << to_arnode->depend_feat_size(); - return SUCCESS; -} - -bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { - auto cnode_iter = cnode_set_.find(node); - return !(cnode_iter == cnode_set_.end()); -} - -std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) { - std::vector<AnfNodePtr> nodes; - for (auto &cnode_arnode : cnode_arnode_map_) { - MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() - << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() - << " curr_para_size: " << cnode_arnode.second->curr_para_size(); - if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) { - (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(), - cnode_paraset_map_[cnode_arnode.first].end()); - } - } - return nodes; -} - -std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) { - std::vector<AnfNodePtr> nodes; - double cur_para_size = 0; - double from = to; - for (auto &arnode : arnode_vec_) { - if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { - continue; - } - if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) { - return std::make_pair(nodes, from); - } - (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end()); - cur_para_size += arnode.curr_para_size(); - from = arnode.depend_feat_size(); - } - MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size - << " cur_para_size: " << cur_para_size << " from: " << from; - return std::make_pair(nodes, from); -} - -void AllreduceGraph::PrintCNodeSet() const { - MS_LOG(INFO) << "CNodeSet:"; - for (auto &cnode : cnode_set_) { - MS_LOG(INFO) << cnode->DebugString(); - } -} - -void AllreduceGraph::PrintAllredueGraphInfo() const { - MS_LOG(INFO) << "max: " << max_; - for (auto &cnode_arnode : cnode_arnode_map_) { - MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); - MS_LOG(INFO) << "arnode info: "; - cnode_arnode.second->ToString(); - } -} - -void AllreduceGraph::PrintArnodeVec() const { - MS_LOG(INFO) << "ArnodeVec:"; - for (auto &arnode : arnode_vec_) { - arnode.ToString(); - } -} - -void AllreduceGraph::PrintArnodeSet() const { - MS_LOG(INFO) << "ArnodeSet:"; - for (auto &arnode : arnode_set_) { - arnode->ToString(); - } -} - -void AllreduceGraph::SortArnode() { - arnode_vec_.clear(); - for (auto &node : arnode_set_) { - arnode_vec_.emplace_back(*node); - } - std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); -} - -Status AllreduceGraph::RemoveExtraParas() { - std::unordered_set<AnfNodePtr> para_map; - for (auto &node : arnode_vec_) { - for (auto &para : node.paras()) { - auto emplace_result = para_map.emplace(para); - if (!emplace_result.second) { - MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << " in arnode"; - if (node.RemovePara(para) != SUCCESS) { - MS_LOG(ERROR) << "remove para failed"; - return FAILED; - } - } - } - } - return SUCCESS; -} - -Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { - auto arnode = std::make_shared<AllreduceNode>(); - if (arnode->Init(node) != SUCCESS) { - MS_LOG(ERROR) << "AllreduceNode Init failed"; - } - head_cnode_ = node; - cnode_arnode_map_[node] = arnode; - auto arnode_emplace_return = arnode_set_.insert(arnode); - if (!arnode_emplace_return.second) { - MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!"; - } - auto cnode_emplace_return = cnode_set_.emplace(node); - if (!cnode_emplace_return.second) { - MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!"; - } - return SUCCESS; -}
-} // namespace parallel -} // namespace mindspore
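GetParaByCost above selects every parameter whose node falls into a half-open interval of cumulative cost. The following standalone analogue (illustrative only; MindSpore's CNodePtr/AnfNodePtr types are replaced by plain strings and doubles) shows the same interval-selection shape:

#include <iostream>
#include <string>
#include <vector>

// Each entry stands in for a gradient node: the cumulative feature size its
// allreduce depends on, plus the parameter it produces.
struct Entry {
  double depend_feat_size;
  std::string para;
};

// Analogue of GetParaByCost: pick every parameter whose node falls into the
// half-open cost interval (from, to].
std::vector<std::string> ParasInInterval(const std::vector<Entry> &entries, double from, double to) {
  std::vector<std::string> out;
  for (const auto &e : entries) {
    if (e.depend_feat_size > from && e.depend_feat_size <= to) out.push_back(e.para);
  }
  return out;
}

int main() {
  std::vector<Entry> g = {{10, "w0"}, {40, "w1"}, {70, "w2"}, {90, "w3"}};
  for (const auto &p : ParasInInterval(g, 30, 80)) std::cout << p << ' ';  // prints: w1 w2
  std::cout << '\n';
  return 0;
}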
diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h deleted file mode index b2084b735c..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h +++ /dev/null @@ -1,85 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ - -#include <memory> -#include <set> -#include <unordered_map> -#include <unordered_set> -#include <utility> -#include <vector> -#include "ir/anf.h" -#include "parallel/allreduce_fusion/allreduce_node.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class AllreduceGraph { - public: - AllreduceGraph() - : head_cnode_(nullptr), - arnode_set_(), - arnode_vec_(), - cnode_set_(), - para_cnode_map_(), - para_cnodeset_map_(), - cnode_paraset_map_(), - cnode_arnode_map_(), - max_(0) {} - virtual ~AllreduceGraph() = default; - Status AddNode(const CNodePtr &node, const AnfNodePtr &para); - Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); - bool NodeInGraph(const CNodePtr &node) const; - std::vector<AnfNodePtr> GetParaByCost(double from, double to); - // Find the first several AllreduceNodes whose depend_feat_size is less than to, the sum of whose parameter size is - // over para_size. - // Return the parameter AnfNodePtr vector corresponding to these AllreduceNodes and the smallest depend_feat_size. - // If the sum of the remaining AllreduceNodes' parameter size is less than para_size, the returned depend_feat_size must be 0. - std::pair<std::vector<AnfNodePtr>, double> GetParaByParaSize(double to, double para_size); - // If one parameter is used by multiple AllreduceNodes, the parameter belonging to the last node in backward - // computation is kept by that AllreduceNode, and the copies belonging to the other AllreduceNodes are removed. - // Called during precise optimization, not implemented temporarily. - void SortArnode(); - Status RemoveExtraParas(); - void PrintCNodeSet() const; - void PrintAllredueGraphInfo() const; - void PrintArnodeVec() const; - void PrintArnodeSet() const; - const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; } - CNodePtr head_cnode() const { return head_cnode_; } - Status set_head_cnode(const CNodePtr &node); - double max() const { return max_; } - - private: - CNodePtr head_cnode_; - std::set<AllreduceNodePtr> arnode_set_; - std::vector<AllreduceNode> arnode_vec_; - std::unordered_set<CNodePtr> cnode_set_; - // If one ParameterPtr is used by multiple CNodes, the last node for backward computation is saved. - std::unordered_map<AnfNodePtr, CNodePtr> para_cnode_map_; - // One ParameterPtr may be used by multiple CNodes - std::unordered_map<AnfNodePtr, std::unordered_set<CNodePtr>> para_cnodeset_map_; - // Multiple Parameters may be inputs to the same CNode - std::unordered_map<CNodePtr, std::unordered_set<AnfNodePtr>> cnode_paraset_map_; - std::unordered_map<CNodePtr, AllreduceNodePtr> cnode_arnode_map_; - double max_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_
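To make the GetParaByParaSize doc comment concrete, here is a self-contained sketch with the same loop shape, assuming nodes are already sorted by depend_feat_size in descending order (as SortArnode arranges). It is illustrative only; the types are simplified stand-ins.

#include <iostream>
#include <utility>
#include <vector>

struct ArNode {
  double depend_feat_size;
  double para_size;
};

// Mirror of GetParaByParaSize's loop: walk nodes in descending
// depend_feat_size order, accumulate parameter size until at least
// `para_size` has been gathered, and report the new interval boundary.
std::pair<int, double> TakeGroup(const std::vector<ArNode> &sorted, double to, double para_size) {
  int taken = 0;
  double cur = 0, from = to;
  for (const auto &n : sorted) {
    if (n.depend_feat_size >= to) continue;                    // outside the current interval
    if (cur >= para_size && n.depend_feat_size < from) break;  // enough parameters gathered
    ++taken;
    cur += n.para_size;
    from = n.depend_feat_size;
  }
  return {taken, from};
}

int main() {
  std::vector<ArNode> v = {{90, 5}, {70, 5}, {40, 5}, {10, 5}};
  auto [n, from] = TakeGroup(v, 100, 8);  // takes the 90 and 70 nodes, boundary 70
  std::cout << n << " nodes, next boundary " << from << '\n';
  return 0;
}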
diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc deleted file mode index 113d4ec59b..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/allreduce_node.h" -#include <queue> -#include "parallel/tensor_layout/tensor_layout.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { - if (next_node == nullptr) { - MS_LOG(ERROR) << "next_node is nullptr!"; - return FAILED; - } - next_.emplace_back(next_node); - return SUCCESS; -} - -Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { - if (prev_node == nullptr) { - MS_LOG(ERROR) << "prev_node is nullptr!"; - return FAILED; - } - if (dist <= 0) { - MS_LOG(ERROR) << "dist must be positive! dist: " << dist; - return FAILED; - } - prev_.emplace_back(prev_node); - double add_dist = prev_node->depend_feat_size() + dist; - depend_feat_size_ += add_dist; - if (depend_feat_size_ > *max) { - *max = depend_feat_size_; - } - std::queue<AllreduceNodePtr> next_queue; - for (auto &next : next_) { - next_queue.push(next); - } - while (!next_queue.empty()) { - auto ele = next_queue.front(); - ele->AddDependFeatSize(add_dist); - if (ele->depend_feat_size() > *max) { - *max = ele->depend_feat_size(); - } - for (auto &next : ele->next()) { - next_queue.push(next); - } - next_queue.pop(); - } - return SUCCESS; -} - -Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { - if (cnode_ptr == nullptr) { - MS_LOG(ERROR) << "cnode_ptr is nullptr!"; - return FAILED; - } - cnode_ptr_ = cnode_ptr; - return SUCCESS; -} - -Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { - if (node_ptr == nullptr) { - MS_LOG(ERROR) << "node_ptr is nullptr!"; - return FAILED; - } - if (!node_ptr->isa<Parameter>()) { - MS_LOG(ERROR) << "node_ptr is not a ParameterPtr!"; - return FAILED; - } - auto para_ptr = node_ptr->cast<ParameterPtr>(); - MS_EXCEPTION_IF_NULL(para_ptr); - auto layout_ptr = para_ptr->tensor_layout(); - if (layout_ptr == nullptr) { - MS_LOG(ERROR) << "layout_ptr is nullptr!"; - return FAILED; - } - auto emplace_return = paras_.emplace(node_ptr); - if (emplace_return.second) { - double para_size = static_cast<double>(layout_ptr->slice_shape().size()); - curr_para_size_ += para_size; - para_size_map_[node_ptr] = para_size; - } else { - MS_LOG(INFO) << "node already exists!"; - } - return SUCCESS; -} - -Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { - if (node_ptr == nullptr) { - MS_LOG(ERROR) << "node_ptr is nullptr!"; - return FAILED; - } - auto erase_num = paras_.erase(node_ptr); - if (erase_num == 0) { - MS_LOG(ERROR) << "para not found!"; - return FAILED; - } - curr_para_size_ -= para_size_map_[node_ptr]; - return SUCCESS; -} - -void AllreduceNode::ToString() const { - MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << " para size: " << paras_.size(); - for (auto &para : paras_) { - MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); - } - MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; -} -} // namespace parallel -} // namespace mindspore
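AddPrev above propagates an added distance breadth-first to every transitive successor. A toy, self-contained version of that propagation (illustrative only; the *max bookkeeping is omitted and the types are simplified):

#include <iostream>
#include <memory>
#include <queue>
#include <vector>

struct Node {
  double depend = 0;
  std::vector<std::shared_ptr<Node>> next;
};

// When a predecessor at distance `dist` is attached, the added amount
// (pred_depend + dist) must be pushed breadth-first to every transitive
// successor, just as AllreduceNode::AddPrev does with its next_ queue.
void AddPrev(const std::shared_ptr<Node> &node, double pred_depend, double dist) {
  double add = pred_depend + dist;
  node->depend += add;
  std::queue<std::shared_ptr<Node>> q;
  for (auto &n : node->next) q.push(n);
  while (!q.empty()) {
    auto cur = q.front();
    q.pop();
    cur->depend += add;  // successors inherit the added distance
    for (auto &n : cur->next) q.push(n);
  }
}

int main() {
  auto a = std::make_shared<Node>(), b = std::make_shared<Node>();
  a->next.push_back(b);
  AddPrev(a, /*pred_depend=*/3.0, /*dist=*/2.0);
  std::cout << a->depend << ' ' << b->depend << '\n';  // prints: 5 5
  return 0;
}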
diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h deleted file mode index db1c4e3f2e..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ - -#include <memory> -#include <unordered_map> -#include <unordered_set> -#include <vector> -#include "ir/anf.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class AllreduceNode; -using AllreduceNodePtr = std::shared_ptr<AllreduceNode>; - -class AllreduceNode { - public: - AllreduceNode() - : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} - Status Init(const CNodePtr &cnode_ptr); - Status AddPara(const AnfNodePtr &node_ptr); - Status RemovePara(const AnfNodePtr &node_ptr); - const std::unordered_set<AnfNodePtr> &paras() const { return paras_; } - double curr_para_size() const { return curr_para_size_; } - virtual ~AllreduceNode() = default; - // Add previous node - // prev_node is the previous node to be added - // max is the current max depend_feat_size of the AllreduceGraph - Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); - Status AddNext(const AllreduceNodePtr &next_node); - double depend_feat_size() const { return depend_feat_size_; } - void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } - const std::vector<AllreduceNodePtr> &next() const { return next_; } - void ToString() const; - bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } - bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } - - private: - CNodePtr cnode_ptr_; - std::vector<AllreduceNodePtr> prev_; - std::vector<AllreduceNodePtr> next_; - std::unordered_set<AnfNodePtr> paras_; - std::unordered_map<AnfNodePtr, double> para_size_map_; - double curr_para_size_; - double depend_feat_size_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_NODE_H_ diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc deleted file mode index 999c4a85a9..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.cc +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/allreduce_fusion/step_allreduce_fusion.h" -#include <chrono> -#include <string> -#include "optimizer/optimizer.h" -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include "parallel/context.h" -#include "parallel/graph_util/graph_info.h" -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(optimizer); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); - bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion(); - // assume no change to graph - bool changes = false; - // control whether to use model_parallel mode - if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || - (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) { - return changes; - } -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); -#endif - MS_LOG(INFO) << "Now entering allreduce fusion"; - DumpGraph(root, std::string(ALLREDUCE_FUSION_BEGIN)); - - pipeline::ResourceBasePtr res = optimizer->resource(); - MS_EXCEPTION_IF_NULL(res); - - FuncGraphManagerPtr manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - CNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - - AllreduceFusion allreduce_fusion; - if (allreduce_fusion.ProcessAllreduceFusion(ret) != SUCCESS) { - MS_LOG(EXCEPTION) << "ProcessAllreduceFusion failed"; - } - - DumpGraph(root, std::string(ALLREDUCE_FUSION_END)); - - // allreduce fusion only runs once - root->set_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY, true); - res->results()[pipeline::kStepParallelGraph] = root; -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time; - MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); - time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Now leaving allreduce fusion, used time: " << time << " us"; -#endif - return changes; -} -} // namespace parallel -} // namespace mindspore
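The pass above times itself with gettimeofday on POSIX and std::chrono on Windows. A portable, self-contained equivalent using only std::chrono::steady_clock (a sketch of the same measurement, not what the patch itself does) would be:

#include <chrono>
#include <iostream>

int main() {
  auto start = std::chrono::steady_clock::now();
  volatile long long acc = 0;
  for (int i = 0; i < 1000000; ++i) acc += i;  // stand-in for the fusion pass
  auto end = std::chrono::steady_clock::now();
  // duration_cast to microseconds matches the "us" unit logged above.
  auto us = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
  std::cout << "used time: " << us << " us\n";
  return 0;
}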
diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h deleted file mode index 2343a7a2fe..0000000000 --- a/mindspore/ccsrc/parallel/allreduce_fusion/step_allreduce_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ -#define MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ - -#include "optimizer/optimizer.h" - -namespace mindspore { -namespace parallel { -constexpr char ALLREDUCE_FUSION_RUN_ONCE_ONLY[] = "allreduce_fusion_run_once_only"; -constexpr char ALLREDUCE_FUSION_BEGIN[] = "allreduce_fusion_begin"; -constexpr char ALLREDUCE_FUSION_END[] = "allreduce_fusion_end"; - -bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_ALLREDUCE_FUSION_STEP_ALLREDUCE_FUSION_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc deleted file mode index 65e9acf714..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/costmodel.h" -#include <algorithm> -#include <numeric> -#include <utility> -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -void Simplify(CostPtrList *clist_ptrs) { - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - SimplifyForDecreasingCommunicationWithPartialPara(clist_ptrs); - } else { - // inference phase - SimplifyForDecreasingCommunicationForward(clist_ptrs); - } -} -void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs) { - // Sort the cost_list with computation_cost_ increasing, and communication_forward_ decreasing. This method - // excludes any cost with both greater computation_cost_ and greater communication_forward_. - // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<100, 20>, <200, 10>} - if (!COST_MODEL_SIMPLIFY_CALCULATION) { - return; - } - MS_EXCEPTION_IF_NULL(clist_ptrs); - std::vector<size_t> id(clist_ptrs->size()); - std::iota(id.begin(), id.end(), size_t(0)); - std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; - }); - CostPtrList ret; - for (size_t i = 0; i < clist_ptrs->size(); ++i) { - if ((ret.size() == size_t(0)) || - (clist_ptrs->at(id[i])->communication_forward_ < ret.back()->communication_forward_)) { - ret.emplace_back(std::move(clist_ptrs->at(id[i]))); - } - } - *clist_ptrs = std::move(ret); -} - -void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { - // Sort the cost_list with computation_cost_ increasing, and communication_with_partial_para_ decreasing. - // This method excludes any cost with both greater computation_cost_ and greater communication_with_partial_para_. - if (!COST_MODEL_SIMPLIFY_CALCULATION) { - return; - } - MS_EXCEPTION_IF_NULL(clist_ptrs); - std::vector<size_t> id(clist_ptrs->size()); - std::iota(id.begin(), id.end(), size_t(0)); - std::sort(id.begin(), id.end(), [&clist_ptrs](size_t x, size_t y) { - return clist_ptrs->at(x)->computation_cost_ < clist_ptrs->at(y)->computation_cost_; - }); - CostPtrList ret; - for (size_t i = 0; i < clist_ptrs->size(); ++i) { - if ((ret.size() == size_t(0)) || - (clist_ptrs->at(id[i])->communication_with_partial_para_ < ret.back()->communication_with_partial_para_)) { - ret.emplace_back(std::move(clist_ptrs->at(id[i]))); - } - } - *clist_ptrs = std::move(ret); -} - -void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { - MS_EXCEPTION_IF_NULL(origin_cost); - if (is_redistribution) { - // Redistribution cost - if ((origin_cost->communication_redis_forward_ > EPS) && - (origin_cost->communication_redis_forward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_forward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_forward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_forward_ += COST_MODEL_COMMUNI_BIAS; - } - if ((origin_cost->communication_redis_backward_ > EPS) && - (origin_cost->communication_redis_backward_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_redis_backward_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_redis_backward_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_redis_backward_ += COST_MODEL_COMMUNI_BIAS; - } - origin_cost->communication_cost_ = - origin_cost->communication_redis_forward_ + origin_cost->communication_redis_backward_; - origin_cost->communication_without_parameter_ = origin_cost->communication_cost_; - origin_cost->communication_with_partial_para_ = origin_cost->communication_cost_; - } else { - // Operator cost - double backward = 0.0; - if (std::abs(origin_cost->communication_cost_ - origin_cost->communication_without_parameter_) > EPS) { - backward = origin_cost->communication_cost_ - origin_cost->communication_without_parameter_; - } - // forward cost - if ((origin_cost->communication_without_parameter_ > EPS) && - (origin_cost->communication_without_parameter_ <= COST_MODEL_COMMUNI_THRESHOLD)) { - origin_cost->communication_without_parameter_ = COST_MODEL_COMMUNI_CONST; - } else if (origin_cost->communication_without_parameter_ > COST_MODEL_COMMUNI_THRESHOLD) { - origin_cost->communication_without_parameter_ += COST_MODEL_COMMUNI_BIAS; - } - // total - if (origin_cost->communication_cost_ > EPS) { - origin_cost->communication_cost_ = origin_cost->communication_without_parameter_ + backward; - } - if (origin_cost->communication_with_partial_para_ > EPS) { - origin_cost->communication_with_partial_para_ = - origin_cost->communication_without_parameter_ + COST_MODEL_GAMMA * backward; - } - } -} -} // namespace parallel -} // namespace mindspore
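Both Simplify variants above keep only the Pareto frontier of the two objectives: sort by computation cost ascending, then retain a cost only if it strictly improves the communication term. A self-contained sketch of that filter, reproducing the worked example from the comment (illustrative types, not MindSpore's CostPtrList):

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

using Cost = std::pair<double, double>;  // {computation, communication}

// Sort by computation ascending, then keep a cost only if it strictly
// improves communication, i.e. retain the Pareto frontier.
std::vector<Cost> Simplify(std::vector<Cost> costs) {
  std::sort(costs.begin(), costs.end(),
            [](const Cost &a, const Cost &b) { return a.first < b.first; });
  std::vector<Cost> ret;
  for (const auto &c : costs) {
    if (ret.empty() || c.second < ret.back().second) ret.push_back(c);
  }
  return ret;
}

int main() {
  for (const auto &c : Simplify({{100, 20}, {200, 10}, {300, 50}})) {
    std::cout << '<' << c.first << ", " << c.second << "> ";  // <100, 20> <200, 10>
  }
  std::cout << '\n';
  return 0;
}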
diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h deleted file mode index 8b92e18cd8..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ /dev/null @@ -1,311 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ - -#include <memory> -#include <string> -#include <utility> -#include <vector> -#include "ir/base.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_info.h" - -namespace mindspore { -namespace parallel { -struct Decision; -using OperatorName = std::string; -using Attr = std::pair<std::string, ValuePtr>; -using Param = std::pair<std::pair<std::string, ValuePtr>, int32_t>; -using OperatorParams = std::vector<Param>; -using OperatorAttrs = std::vector<Attr>; -// OutPutInfo.first: true if the operator's output is a tuple -// OutPutInfo.second: elements number of the tuple output. Only meaningful if OutPutInfo.first is true. -using OutPutInfo = std::pair<bool, size_t>; -using OutPutInfoVector = std::vector<OutPutInfo>; -using OperatorArgs = std::pair<OperatorAttrs, OperatorParams>; -using Operator = std::pair<OperatorName, OperatorArgs>; -using OperatorVector = std::vector<Operator>; -using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPutInfoVector>>; - -struct Cost { - Cost(); - Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr) - : computation_cost_(computation), communication_cost_(communication), decision_ptr_(std::move(decision_)) { - memory_with_reuse_ = 0.0; - communication_without_parameter_ = 0.0; - communication_with_partial_para_ = 0.0; - communication_redis_forward_ = 0.0; - communication_redis_backward_ = 0.0; - communication_forward_ = 0.0; - } - // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase - double memory_with_reuse_; - // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated - // by ONLY the forward phase - double computation_cost_; - // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution) - double communication_cost_; - // communication_without_parameter_ = communication_cost_ - (backward communication from operators) - double communication_without_parameter_; - // communication_with_partial_para_ = - // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_) - double communication_with_partial_para_; - // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution. - double communication_forward_; - double communication_redis_forward_; - double communication_redis_backward_; - std::shared_ptr<Decision> decision_ptr_; -}; - -using CostPtr = std::shared_ptr<Cost>; -using CostPtrList = std::vector<std::shared_ptr<Cost>>; - -class StrategyWithCost { - public: - StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_) - : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} - - StrategyWithCost(const StrategyWithCost &swc) = delete; - StrategyWithCost(StrategyWithCost &&swc) - : strategy_ptr(swc.strategy_ptr), - inputs_ptr(swc.inputs_ptr), - outputs_ptr(swc.outputs_ptr), - cost_list(swc.cost_list) {} - ~StrategyWithCost() = default; - - StrategyPtr strategy_ptr; - std::vector<TensorInfo> inputs_ptr; - std::vector<TensorInfo> outputs_ptr; - CostPtrList cost_list; -}; - -enum DecisionType { - OP_ELIMINATION, - EDGE_ELIMINATION, - MERGE_ELIMINATION, - CONTRACT_ELIMINATION, - TRIANGLE_ELIMINATION, - STAR_ELIMINATION, - FINAL_TYPE, - FINAL_SINGLE -}; - -struct Decision : public Base { - ~Decision() override = default; - DecisionType type_; -}; - -// 'OpEliminationDecision' is for the Operator Elimination in the DP algorithm: u --> v --> w ==> u --> w. -// This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the -// operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w' -struct OpEliminationDecision : public Decision { - OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) - : op_strategy_(std::move(op_stra)), - left_cost_(std::move(l_cost)), - middle_cost_(std::move(m_cost)), - right_cost_(std::move(r_cost)) { - type_ = DecisionType::OP_ELIMINATION; - } - - StrategyPtr op_strategy_; - CostPtr left_cost_; - CostPtr middle_cost_; - CostPtr right_cost_; - MS_DECLARE_PARENT(OpEliminationDecision, Decision); -}; - -/* 'EdgeEliminationDecision' is for the Edge Elimination in the DP algorithm: - ____ - / \ - u v ==> u --> v, which replaces the multi-edges by a single edge. - \____/ - This data structure records the cost list for all edges 'edges_cost_list_' - */ -struct EdgeEliminationDecision : public Decision { - explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) { - type_ = DecisionType::EDGE_ELIMINATION; - } - - CostPtrList edges_cost_list_; - MS_DECLARE_PARENT(EdgeEliminationDecision, Decision); -}; - -// 'MergeEliminationDecision' is for the Merge Elimination in the DP algorithm: -// w -// | -// | ==> u --> v -// u --> v In the original graph, v has two alive incoming edges, w has one alive outgoing edge, -// and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u --> v'. -// This data structure records the strategy 'merged_op_strategy_' for operator 'w', -// the cost 'merged_op_cost_' for operator 'w', and the edge cost 'edge_cost_' for 'w --> v'.
-struct MergeEliminationDecision : public Decision { - MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra, - CostPtr target_op_c) - : merged_op_strategy_(std::move(op_stra)), - merged_op_cost_(std::move(op_cost)), - edge_cost_(std::move(edge_c)), - target_op_strategy_(std::move(tar_op_stra)), - target_op_cost_(std::move(target_op_c)) { - type_ = DecisionType::MERGE_ELIMINATION; - } - - StrategyPtr merged_op_strategy_; - CostPtr merged_op_cost_; - CostPtr edge_cost_; - StrategyPtr target_op_strategy_; - CostPtr target_op_cost_; - MS_DECLARE_PARENT(MergeEliminationDecision, Decision); -}; - -// 'ContractEliminationDecision' is for the Contract Elimination in the DP algorithm: -// u --> v -// | -// | ==> u --> w -// w In the original graph, u has two alive outgoing edges, v has one alive incoming edge, -// and v has zero outgoing edges. After the Contract Elimination, the result graph contains only 'u --> w'. -// This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the cost for -// operator 'contracted_op_cost_', and the edge cost for 'edge_cost_'. -struct ContractEliminationDecision : public Decision { - ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost, - StrategyPtr target_stra, CostPtr tar_cost) - : contracted_op_strategy_(std::move(contra_stra)), - contracted_op_cost_(std::move(contra_op_cost)), - edge_cost_(std::move(edge_cost)), - target_op_strategy_(std::move(target_stra)), - target_cost_(std::move(tar_cost)) { - type_ = DecisionType::CONTRACT_ELIMINATION; - } - - StrategyPtr contracted_op_strategy_; - CostPtr contracted_op_cost_; - CostPtr edge_cost_; - StrategyPtr target_op_strategy_; - CostPtr target_cost_; - MS_DECLARE_PARENT(ContractEliminationDecision, Decision); -}; - -/* 'TriangleEliminationDecision' is for the Triangle Elimination in the DP algorithm: - * - * u - * / \ - * / \ - * v --- w ==> v --- w In the original graph, u has 2 outgoing edges, v has 1 outgoing edge, - * and w has 2 incoming edges, so u can be eliminated into v. - * 'eliminated_op_strategy_' is for u, 'eliminated_op_cost_' is for u, 'eliminated_left_edge_' is for edge u --> v, - * 'eliminated_right_edge_' is for edge u --> w. - */ -struct TriangleEliminationDecision : public Decision { - TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, - StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra) - : eliminated_op_strategy_(std::move(elimi_stra)), - eliminated_op_cost_(std::move(elimi_op_cost)), - left_edge_cost_(std::move(l_edge_cost)), - right_edge_cost_(std::move(r_edge_cost)), - left_node_strategy_(std::move(left_stra)), - left_node_cost_(std::move(l_node_cost)), - right_node_strategy_(std::move(right_stra)) { - type_ = DecisionType::TRIANGLE_ELIMINATION; - } - - StrategyPtr eliminated_op_strategy_; - CostPtr eliminated_op_cost_; - CostPtr left_edge_cost_; - CostPtr right_edge_cost_; - StrategyPtr left_node_strategy_; - CostPtr left_node_cost_; - StrategyPtr right_node_strategy_; - MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); -}; - -/* 'StarEliminationDecision' is for the Star Elimination in the DP algorithm: - * - * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. - * In addition, v and w have other complicated connections, so that no other eliminations can be performed - * on v and w. After the StarElimination, u is merged into v, and the resulting graph is split into multiple - * connected components. - * NOTE: this elimination MUST be performed only when the above five operations cannot be applied. - */ -struct StarEliminationDecision : public Decision { - StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist, - std::vector<StrategyPtr> succ_ops_stra_list, CostPtrList succ_ops_clist) - : eliminated_op_strategy_(std::move(elimi_op_stra)), - eliminated_op_cost_(std::move(elimi_op_cost)), - succ_edges_cost_list_(std::move(succ_edges_clist)), - succ_ops_stra_list_(std::move(succ_ops_stra_list)), - succ_ops_cost_list_(std::move(succ_ops_clist)) { - type_ = DecisionType::STAR_ELIMINATION; - } - - StrategyPtr eliminated_op_strategy_; - CostPtr eliminated_op_cost_; - CostPtrList succ_edges_cost_list_; - std::vector<StrategyPtr> succ_ops_stra_list_; - CostPtrList succ_ops_cost_list_; - MS_DECLARE_PARENT(StarEliminationDecision, Decision); -}; - -// This data structure records the decision for the graph which contains two nodes: u --> v. This includes -// the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', and the cost 'left_cost_' for 'u'. -struct FinalDecision : public Decision { - FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) - : u_strategy_(std::move(u_stra)), - v_strategy_(std::move(v_stra)), - left_cost_(std::move(l_cost)), - middle_cost_(std::move(m_cost)), - right_cost_(std::move(r_cost)) { - type_ = DecisionType::FINAL_TYPE; - } - - StrategyPtr u_strategy_; - StrategyPtr v_strategy_; - CostPtr left_cost_; - CostPtr middle_cost_; - CostPtr right_cost_; - MS_DECLARE_PARENT(FinalDecision, Decision); -}; - -// This data structure records the final decision for the graph containing a single node: u. This includes -// the strategy 'u_strategy_' for 'u', and the cost 'u_cost_' for 'u'. -struct FinalSingleDecision : public Decision { - FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) { - type_ = DecisionType::FINAL_SINGLE; - } - - StrategyPtr u_strategy_; - CostPtr u_cost_; - MS_DECLARE_PARENT(FinalSingleDecision, Decision); -}; - -using DecisionPtr = std::shared_ptr<Decision>; -using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>; -using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>; -using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>; -using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>; -using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>; -using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>; -using FinalDecisionPtr = std::shared_ptr<FinalDecision>; -using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>; - -void Simplify(CostPtrList *clist); -void SimplifyForDecreasingCommunicationForward(CostPtrList *clist); -void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist); -void RefineForPracticalCost(const CostPtr &, bool is_redistribution); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc deleted file mode index 72451fab57..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ /dev/null @@ -1,226 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/dp_algo_costmodel.cc" wait -#include "parallel/auto_parallel/dp_algo_costmodel.h" - -#include <memory> -#include <utility> -#include <vector> - -namespace mindspore { -namespace parallel { -Status GetStrategy(const CostGraphPtr &graph) { - MS_LOG(INFO) << "Searching strategies begins."; - MS_EXCEPTION_IF_NULL(graph); - std::vector<EliminationPtr> eliminations; - bool flag = true; - - // Phase 1: Shrink the CostGraph using 6 operations, and record them in order. - // Note: the checking and applying of the 6 operations MUST be in the current order. - while (flag) { - flag = false; - auto node = graph->CheckOpElimination(); - if (node != nullptr) { - // Applying the Operator Elimination - flag = true; - auto l_edge = node->GetAlivePrevEdges()[0]; - auto r_edge = node->GetAliveSuccEdges()[0]; - auto n_edge = graph->EliminationOp(node); - auto elimi = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge); - eliminations.emplace_back(std::move(elimi)); - } - auto edges = graph->CheckEdgeElimination(); - if ((!flag) && (!edges.empty())) { - // Applying the Edge Elimination - flag = true; - auto n_edge = graph->EliminationEdges(edges); - auto elimi = std::make_shared<EdgeElimination>(n_edge, edges); - eliminations.emplace_back(std::move(elimi)); - } - auto merge_node = graph->CheckMergeElimination(); - if ((!flag) && (merge_node != nullptr)) { - // Applying the Merge Elimination - flag = true; - auto succ_edge = merge_node->GetAliveSuccEdges()[0]; - auto target_node = graph->EliminationMerge(merge_node); - auto elimi = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node); - eliminations.emplace_back(std::move(elimi)); - } - auto contracted_node = graph->CheckContractElimination(); - if ((!flag) && (contracted_node != nullptr)) { - // Applying the Contract Elimination - flag = true; - auto prev_edge = contracted_node->GetAlivePrevEdges()[0]; - auto target_node = graph->EliminationContract(contracted_node); - auto elimi = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node); - eliminations.emplace_back(std::move(elimi)); - } - auto triangle_pair = graph->CheckTriangleElimination(); - if ((!flag) && (triangle_pair.first != nullptr)) { - // Applying the Triangle Elimination - flag = true; - auto eliminated_node = triangle_pair.first; - auto l_r_edge = triangle_pair.second; - - auto left_node = l_r_edge->prev_operator(); - auto left_edge = eliminated_node->GetAliveSuccEdges()[0]; - auto right_edge = eliminated_node->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(left_edge); - if (left_edge->next_operator() != left_node) { - auto tmp = left_edge; - left_edge = right_edge; - right_edge = tmp; - } - auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge); - auto right_node = l_r_edge->next_operator(); - auto elimi = - std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node); - eliminations.emplace_back(std::move(elimi)); - } - auto star_center = graph->CheckStarElimination(); - if ((!flag) && (star_center != nullptr)) { - // Applying the Star Elimination - flag = true; - auto succ_edges = graph->EliminationStar(star_center); - std::vector<OperatorInfoPtr> succ_nodes; - for (size_t i = 0; i < succ_edges.size(); ++i) { - MS_EXCEPTION_IF_NULL(succ_edges[i]); - succ_nodes.push_back(succ_edges[i]->next_operator()); - } - auto elimi = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes); - eliminations.emplace_back(std::move(elimi)); - } - } - - // Phase 2: Search the cost_list in the final graph, and determine the optimal one - if (graph->SearchStrategy() != SUCCESS) { - MS_LOG(ERROR) << "Searching strategy for the final graph failed."; - return FAILED; - } - - // Phase 3: Recover the original CostGraph, and determine the strategy for each operator - if (RecoverStrategy(eliminations) == SUCCESS) { - MS_LOG(INFO) << "Searching strategies ends."; - return SUCCESS; - } else { - MS_LOG(EXCEPTION) << "Searching strategies failed."; - } -} - -Status RecoverStrategy(std::vector<EliminationPtr> eliminations) { - std::vector<EliminationPtr>::reverse_iterator rit; - - for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) { - if ((*rit)->isa<OpElimination>()) { - auto elimination = (*rit)->cast<OpEliminationPtr>(); - auto e = elimination->new_edge_; - auto w = elimination->op_; - MS_EXCEPTION_IF_NULL(e); - MS_EXCEPTION_IF_NULL(w); - auto left_edge = elimination->left_edge_; - auto right_edge = elimination->right_edge_; - MS_EXCEPTION_IF_NULL(left_edge); - MS_EXCEPTION_IF_NULL(right_edge); - auto decision = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>(); - w->SetSelectedStrategyAndCost(decision->op_strategy_, decision->middle_cost_); - left_edge->set_selected_cost(decision->left_cost_); - right_edge->set_selected_cost(decision->right_cost_); - MS_LOG(INFO) << "Recover opElimination succeeded."; - } else if ((*rit)->isa<EdgeElimination>()) { - auto elimination = (*rit)->cast<EdgeEliminationPtr>(); - auto new_edge = elimination->new_edge_; - MS_EXCEPTION_IF_NULL(new_edge); - auto &edges = elimination->edges_; - auto decision = new_edge->selected_cost()->decision_ptr_->cast<EdgeEliminationDecisionPtr>(); - for (size_t j = 0; j < edges.size(); ++j) { - MS_EXCEPTION_IF_NULL(edges[j]); - edges[j]->set_selected_cost(decision->edges_cost_list_[j]); - } - MS_LOG(INFO) << "Recover edgeElimination succeeded."; - } else if ((*rit)->isa<MergeElimination>()) { - auto elimination = (*rit)->cast<MergeEliminationPtr>(); - auto target_node = elimination->target_node_; - MS_EXCEPTION_IF_NULL(target_node); - auto merged_node = elimination->merged_node_; - MS_EXCEPTION_IF_NULL(merged_node); - auto merged_edge = elimination->dir_edge_; - MS_EXCEPTION_IF_NULL(merged_edge); - MS_EXCEPTION_IF_NULL(target_node->selected_cost()); - MS_EXCEPTION_IF_NULL(target_node->selected_cost()->decision_ptr_); - auto decision = target_node->selected_cost()->decision_ptr_->cast<MergeEliminationDecisionPtr>(); - merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_); - merged_edge->set_selected_cost(decision->edge_cost_); - target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_); - - MS_LOG(INFO) << "Recover mergeElimination succeeded."; - } else if ((*rit)->isa<ContractElimination>()) { - auto elimination = (*rit)->cast<ContractEliminationPtr>(); - auto target_node = elimination->target_node_; - auto contracted_node = elimination->contracted_node_; - auto contracted_edge = elimination->dir_edge_; - auto decision = target_node->selected_cost()->decision_ptr_->cast<ContractEliminationDecisionPtr>(); - - contracted_node->SetSelectedStrategyAndCost(decision->contracted_op_strategy_, decision->contracted_op_cost_); - contracted_edge->set_selected_cost(decision->edge_cost_); - target_node->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_cost_); - MS_LOG(INFO) << "Recover contractElimination succeeded."; - } else if ((*rit)->isa<TriangleElimination>()) { - auto elimination = (*rit)->cast<TriangleEliminationPtr>(); - auto left_node = elimination->left_node_; - auto left_edge = elimination->left_edge_; - auto eliminated_node = elimination->eliminated_node_; - auto right_edge = elimination->right_edge_; - auto right_node = elimination->right_node_; - auto decision = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>(); - - eliminated_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); - left_edge->set_selected_cost(decision->left_edge_cost_); - right_edge->set_selected_cost(decision->right_edge_cost_); - // Since the Triangle is eliminated into 'left_node', only 'left_node' is needed to recover the strategy. - left_node->SetSelectedStrategyAndCost(decision->left_node_strategy_, decision->left_node_cost_); - right_node->CheckSelectedStrategy(decision->right_node_strategy_); - MS_LOG(INFO) << "Recover triangleElimination succeeded."; - } else if ((*rit)->isa<StarElimination>()) { - auto elimination = (*rit)->cast<StarEliminationPtr>(); - auto merged_node = elimination->eliminated_node_; - auto succ_edges = elimination->succ_edges_; - auto succ_nodes = elimination->succ_ops_; - // the decision is hidden in succ_nodes[0] - auto decision = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>(); - - merged_node->SetSelectedStrategyAndCost(decision->eliminated_op_strategy_, decision->eliminated_op_cost_); - for (size_t i = 0; i < succ_edges.size(); ++i) { - succ_edges[i]->set_selected_cost(decision->succ_edges_cost_list_[i]); - } - MS_EXCEPTION_IF_NULL(succ_nodes[0]); - MS_EXCEPTION_IF_NULL(decision->succ_ops_stra_list_[0]); - MS_EXCEPTION_IF_NULL(decision->succ_ops_cost_list_[0]); - // Since the Star is eliminated into 'succ_nodes[0]', only 'succ_nodes[0]' is needed to recover the strategy. - succ_nodes[0]->SetSelectedStrategyAndCost(decision->succ_ops_stra_list_[0], decision->succ_ops_cost_list_[0]); - for (size_t k = 1; k < succ_nodes.size(); ++k) { - succ_nodes[k]->CheckSelectedStrategy(decision->succ_ops_stra_list_[k]); - } - MS_LOG(INFO) << "Recover starElimination succeeded."; - } else { - MS_LOG(ERROR) << "Unknown Elimination type."; - return FAILED; - } - } - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h deleted file mode index e3fbfba5a7..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ - -#include <memory> -#include <utility> -#include <vector> -#include "ir/value.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -// There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal -// is to compute the strategy for each operator in the CostGraph. -// -// Phase 1: Shrink the CostGraph using 6 operations, and record them in order -// Using these operations (Operator, Edge, Merge, Contract, Triangle and Star Elimination), each connected -// component in the CostGraph can be shrunk into the final graph: u --> v. See the -// interpretation of the 6 operations in costmodel.h. -// Phase 2: Search the cost_list in the final graph, and determine the optimal one -// Create the cost_list for the final graph, and choose the optimal one: the one with the minimum quantity -// COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost -// Phase 3: Recover the original CostGraph, and determine the strategy for each operator -// After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying -// the elimination operations in the reverse order of Phase 1. Because each operation decision contains the strategy, -// the operators' strategies can all be determined.
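The shape of that three-phase driver can be shown in a few lines. The sketch below is illustrative only: the real code records eliminations in a std::vector and replays them with a reverse iterator, while a std::stack simply makes the LIFO discipline explicit; Phase 2's search is elided.

#include <functional>
#include <iostream>
#include <stack>
#include <string>

// Phase 1 pushes a record for every elimination applied; Phase 3 pops them in
// reverse order, so each recovery step sees the graph state its elimination
// produced.
struct Elimination {
  std::string type;
  std::function<void()> recover;  // restores strategies for the removed piece
};

int main() {
  std::stack<Elimination> applied;
  // Phase 1: shrink (records pushed in application order).
  for (const std::string &t : {"op", "edge", "merge"}) {
    applied.push({t, [t] { std::cout << "recover " << t << '\n'; }});
  }
  // Phase 2: search the strategy on the final two-node graph (omitted here).
  // Phase 3: recover in reverse order, printing: merge, edge, op.
  while (!applied.empty()) {
    applied.top().recover();
    applied.pop();
  }
  return 0;
}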
-struct Elimination : public Base { - enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, TRIANGLE, STAR }; - Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} - - EdgePtr new_edge_; - EliminationType type_; -}; - -// Operator Elimination -struct OpElimination : public Elimination { - OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge) - : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA), - left_edge_(std::move(l_edge)), - op_(std::move(op_info)), - right_edge_(std::move(r_edge)) {} - - EdgePtr left_edge_; - OperatorInfoPtr op_; - EdgePtr right_edge_; - MS_DECLARE_PARENT(OpElimination, Elimination); -}; - -// Edge Elimination -struct EdgeElimination : public Elimination { - EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds) - : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} - - std::vector<EdgePtr> edges_; - MS_DECLARE_PARENT(EdgeElimination, Elimination); -}; - -// Merge Elimination -struct MergeElimination : public Elimination { - MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info) - : Elimination(nullptr, Elimination::EliminationType::MERGE), - merged_node_(std::move(u_info)), - dir_edge_(std::move(merged_target_edge)), - target_node_(std::move(v_info)) {} - - OperatorInfoPtr merged_node_; - EdgePtr dir_edge_; - OperatorInfoPtr target_node_; - MS_DECLARE_PARENT(MergeElimination, Elimination); -}; - -// Contract Elimination -struct ContractElimination : public Elimination { - ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info) - : Elimination(nullptr, Elimination::EliminationType::CONTRACT), - contracted_node_(std::move(con_info)), - dir_edge_(std::move(tar_con_edge)), - target_node_(std::move(tar_info)) {} - - OperatorInfoPtr contracted_node_; - EdgePtr dir_edge_; - OperatorInfoPtr target_node_; - MS_DECLARE_PARENT(ContractElimination, Elimination); -}; - -// Triangle Elimination -struct TriangleElimination : public Elimination { - TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, - OperatorInfoPtr r_node) - : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), - eliminated_node_(std::move(elim_node)), - left_edge_(std::move(l_edge)), - left_node_(std::move(l_node)), - right_edge_(std::move(r_edge)), - right_node_(std::move(r_node)) {} - - OperatorInfoPtr eliminated_node_; - EdgePtr left_edge_; - OperatorInfoPtr left_node_; - EdgePtr right_edge_;
OperatorInfoPtr right_node_; - MS_DECLARE_PARENT(TriangleElimination, Elimination); -}; - -// Star Elimination -struct StarElimination : public Elimination { - StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops) - : Elimination(nullptr, Elimination::EliminationType::STAR), - eliminated_node_(std::move(elimi_node)), - succ_edges_(std::move(s_edges)), - succ_ops_(std::move(s_ops)) {} - - OperatorInfoPtr eliminated_node_; - std::vector<EdgePtr> succ_edges_; - std::vector<OperatorInfoPtr> succ_ops_; - MS_DECLARE_PARENT(StarElimination, Elimination); -}; - -using EliminationPtr = std::shared_ptr<Elimination>; -using OpEliminationPtr = std::shared_ptr<OpElimination>; -using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>; -using MergeEliminationPtr = std::shared_ptr<MergeElimination>; -using ContractEliminationPtr = std::shared_ptr<ContractElimination>; -using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>; -using StarEliminationPtr = std::shared_ptr<StarElimination>; - -// Phase 1 and Phase 2 -Status GetStrategy(const CostGraphPtr &graph); - -// Phase 3 -Status RecoverStrategy(std::vector<EliminationPtr> eliminations); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc deleted file mode index 60256a3ae3..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "parallel/auto_parallel/edge_costmodel.h" - -#include -#include -#include -#include -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status Edge::InitEdgeCost() { - bool has_available_cost = false; - for (auto &swc : prev_op_->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(swc); - pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); - } - for (auto &swc : next_op_->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(swc); - next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); - } - if (is_identity_edge) { - for (auto &target_output : pre_op_output_) { - auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); - auto target_output_str = target_output.first; - for (auto &target_input : next_op_input_) { - auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); - auto target_input_str = target_input.first; - if (target_output_lyt == target_input_lyt) { - CostPtrKey ck = {target_output_str, target_input_str}; - CostPtr cost = std::make_shared(0.0, 0.0); - MS_EXCEPTION_IF_NULL(cost); - cost->communication_without_parameter_ = 0.0; - cost->communication_with_partial_para_ = 0.0; - CostPtrList cl; - cl.push_back(cost); - (void)cost_map_.emplace(std::make_pair(ck, cl)); - has_available_cost = true; - } - } - } - } else { - for (auto &target_output : pre_op_output_) { - auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); - auto target_output_str = target_output.first; - auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; - auto type = prev_op_->outputs_type()[prev_op_output_index_]; - for (auto &target_input : next_op_input_) { - auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); - auto target_input_str = target_input.first; - CostPtr cost; - if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) { - MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed"; - } - MS_EXCEPTION_IF_NULL(cost); - MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_ - << ", communication_cost: " << cost->communication_cost_ - << ", communication_without_parameter_: " << cost->communication_without_parameter_ - << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << "."; - // refine communication cost calculation for practice - RefineForPracticalCost(cost, true); - cost->communication_forward_ = cost->communication_redis_forward_; - CostPtrKey ck = {target_output_str, target_input_str}; - CostPtrList cl; - cl.push_back(cost); - (void)cost_map_.emplace(std::make_pair(ck, cl)); - has_available_cost = true; - } - } - } - if (!has_available_cost) { - if (FULLY_USE_DEVICES) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting 'fully_use_devices' true. Try to set " - "'fully_use_devices' false."; - } else if (ELEMENTWISE_OP_STRA_FOLLOW) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. 
" - "Try to set 'elementwise_op_strategy_follow' false."; - } - if (edge_name_.find(RESHAPE) != std::string::npos) { - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ - << " failed, it may be caused by setting different strategies for operators following Reshape. " - "Try to fix that."; - } - MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed."; - } - return Status::SUCCESS; -} - -Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, - size_t type_length, TypePtr type, CostPtr *cost) { - MS_EXCEPTION_IF_NULL(prev_op_); - MS_EXCEPTION_IF_NULL(cost); - RankList dev_list = prev_op_->global_device_list(); - TensorRedistribution tensor_redistribution(false); - - // Init TensorRedistribution - if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - - double comm_cost = tensor_redistribution.comm_cost(); - double forward_comm_cost = tensor_redistribution.forward_comm_cost(); - double backward_comm_cost = tensor_redistribution.backward_comm_cost(); - double computation_cost = tensor_redistribution.computation_cost(); - double mem_cost = tensor_redistribution.memory_cost(); - - // Now AllGather, ReduceScatter, AlltoAll don't support bool type - MS_EXCEPTION_IF_NULL(type); - if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) { - computation_cost = INF; - comm_cost = INF; - MS_LOG(WARNING) << "Communication Operators don't support bool dtype!"; - } - *cost = std::make_shared(type_length * computation_cost, type_length * comm_cost); - (*cost)->communication_without_parameter_ = type_length * comm_cost; - (*cost)->communication_with_partial_para_ = - (*cost)->communication_without_parameter_ + - COST_MODEL_GAMMA * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_); - (*cost)->communication_redis_forward_ = type_length * forward_comm_cost; - (*cost)->communication_redis_backward_ = type_length * backward_comm_cost; - (*cost)->memory_with_reuse_ = mem_cost; - return Status::SUCCESS; -} - -CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { - CostPtrKey ck = {output_str, input_str}; - CostPtrList result; - if (cost_map_.find(ck) != cost_map_.end()) { - return cost_map_.at(ck); - } - return result; -} - -CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, - const StrategyPtr &input_st_ptr) { - std::function LocalGetCostList = [&](const EdgePtr &edge) { - MS_EXCEPTION_IF_NULL(edge); - return edge->GetCostList(output_st_ptr, input_st_ptr); - }; - CostPtrList result; - std::vector all_cost_list; - all_cost_list.resize(edges.size()); - (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList); - - CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - std::function recursive = - [&](size_t k, double computation, double memory, double communication, double communication_without_para, - double communication_forward) { - if (k == edges.size()) { - auto decision = std::make_shared(selected_cost_list); - CostPtr new_cost = std::make_shared(computation, communication); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = communication_without_para; - 
new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ = memory; - new_cost->communication_forward_ = communication_forward; - new_cost->decision_ptr_ = decision; - result.push_back(new_cost); - return; - } - for (auto &c : all_cost_list[k]) { - MS_EXCEPTION_IF_NULL(c); - selected_cost_list[k] = c; - recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, - communication + c->communication_cost_, - communication_without_para + c->communication_without_parameter_, - communication_forward + c->communication_forward_); - } - }; - recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0); - Simplify(&result); - return result; -} - -void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { - bool valid = false; - for (const auto &output_pair : pre_op_output_) { - StrategyPtr output_st_ptr = output_pair.first; - for (const auto &input_pair : next_op_input_) { - StrategyPtr input_st_ptr = input_pair.first; - CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); - CostPtrKey key = {output_st_ptr, input_st_ptr}; - cost_map_[key] = clist; - if ((!valid) && (!clist.empty())) { - valid = true; - } - } - } - if (!valid) { - MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed."; - } -} - -void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, - const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, - CostPtrList *ret_cost_list) { - for (auto &left_cost : left_cost_list) { - MS_EXCEPTION_IF_NULL(left_cost); - for (auto &middle_cost : middle_cost_list) { - MS_EXCEPTION_IF_NULL(middle_cost); - for (auto &right_cost : right_cost_list) { - MS_EXCEPTION_IF_NULL(right_cost); - double computation = - left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; - double communication = - left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_; - double communication_forward = - left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_; - double communication_without_para = left_cost->communication_without_parameter_ + - middle_cost->communication_without_parameter_ + - right_cost->communication_without_parameter_; - double memory_cost = - left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_; - - auto decision = std::make_shared(op_strategy, left_cost, middle_cost, right_cost); - auto cost = std::make_shared(computation, communication, decision); - MS_EXCEPTION_IF_NULL(cost); - cost->communication_without_parameter_ = communication_without_para; - cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - cost->memory_with_reuse_ = memory_cost; - cost->communication_forward_ = communication_forward; - ret_cost_list->emplace_back(std::move(cost)); - } - } - } -} - -CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, - const OperatorInfoPtr &op, const EdgePtr &e2, - const StrategyPtr &input_st_ptr) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(e1); - MS_EXCEPTION_IF_NULL(e2); - CostPtrList result; - for (const auto &op_strategy : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_strategy); - auto middle_strategy = op_strategy->strategy_ptr; 
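The call below implements the core of operator elimination: for a fixed pair of end strategies, each middle strategy contributes every combination of left-edge, operator and right-edge costs, and Simplify() later prunes dominated candidates. A minimal standalone sketch of that composition rule, with a simplified cost type and a hypothetical function name rather than the actual API:

#include <vector>

struct SimpleCost {
  double computation = 0.0;
  double communication = 0.0;
};

// For one middle strategy: every combination of left-edge, middle-op and
// right-edge costs yields one candidate cost for the replacing edge.
std::vector<SimpleCost> ComposeTriple(const std::vector<SimpleCost> &left,
                                      const std::vector<SimpleCost> &middle,
                                      const std::vector<SimpleCost> &right) {
  std::vector<SimpleCost> out;
  for (const auto &l : left)
    for (const auto &m : middle)
      for (const auto &r : right)
        out.push_back({l.computation + m.computation + r.computation,
                       l.communication + m.communication + r.communication});
  return out;  // the real code then calls Simplify() to drop dominated entries
}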
-    CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
-                                   op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
-  }
-  Simplify(&result);
-  return result;
-}
-
-void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
-  bool valid = false;
-  for (const auto &output_pair : pre_op_output_) {
-    StrategyPtr output_st_ptr = output_pair.first;
-    for (const auto &input_pair : next_op_input_) {
-      StrategyPtr input_st_ptr = input_pair.first;
-
-      CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
-      CostPtrKey key = {output_st_ptr, input_st_ptr};
-      cost_map_[key] = clist;
-      if ((!valid) && (!clist.empty())) {
-        valid = true;
-      }
-    }
-  }
-  if (!valid) {
-    MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
-  }
-}
-
-Status Edge::CalculateMemoryCost() {
-  if (is_output_parameter_involve_ == -1) {
-    MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
-    return FAILED;
-  }
-  if (is_output_parameter_involve_ == 0) {
-    // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
-    // unnecessary to keep them in memory.
-    for (auto &cost_kv : cost_map_) {
-      auto &cost_v = cost_kv.second;
-      if (!cost_v.empty()) {
-        cost_v[0]->memory_with_reuse_ = 0;
-      }
-    }
-  }
-
-  return SUCCESS;
-}
-
-Status Edge::CalculateMemoryCostForInference() {
-  // Currently, memory cost is NOT calculated for redistribution
-  if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
-    MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
-    return FAILED;
-  }
-  for (auto &cost_kv : cost_map_) {
-    auto &cost_v = cost_kv.second;
-    if (!cost_v.empty()) {
-      cost_v[0]->memory_with_reuse_ = 0;
-    }
-  }
-  return SUCCESS;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h
deleted file mode 100644
index 2a5ed3b2a4..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h
+++ /dev/null
@@ -1,171 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
-#define PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "common/utils.h"
-#include "parallel/auto_parallel/costmodel.h"
-#include "parallel/ops_info/operator_info.h"
-#include "parallel/tensor_layout/tensor_info.h"
-#include "parallel/tensor_layout/tensor_layout.h"
-
-namespace mindspore {
-namespace parallel {
-using CostPtrKey = std::pair<StrategyPtr, StrategyPtr>;
-using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
-using EdgePtr = std::shared_ptr<Edge>;
-
-class Edge {
-  // An 'Edge' connects two Operators in the CostGraph.
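Edge's central state is cost_map_, which maps a (previous-operator output strategy, next-operator input strategy) pair to the list of non-dominated costs for realizing that pair. A minimal sketch of the lookup pattern GetCostList uses, with simplified stand-in types rather than the real ones:

#include <map>
#include <memory>
#include <utility>
#include <vector>

struct Strategy;  // stand-in for the real Strategy class
using StrategyPtr = std::shared_ptr<Strategy>;
struct Cost { double computation = 0, communication = 0; };
using CostPtr = std::shared_ptr<Cost>;
using CostPtrList = std::vector<CostPtr>;
using CostPtrKey = std::pair<StrategyPtr, StrategyPtr>;

// GetCostList boils down to a map lookup that returns an empty list on a miss.
CostPtrList Lookup(const std::map<CostPtrKey, CostPtrList> &cost_map,
                   const StrategyPtr &out_stra, const StrategyPtr &in_stra) {
  auto it = cost_map.find({out_stra, in_stra});
  return it == cost_map.end() ? CostPtrList{} : it->second;
}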
- public: - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, - const bool &is_com) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - prev_op_output_index_(output_index_), - next_op_input_index_(input_index_), - is_combined_(is_com) { - is_identity_edge = false; - } - - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, - const bool &is_com, const bool &is_iden) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - prev_op_output_index_(output_index_), - next_op_input_index_(input_index_), - is_combined_(is_com), - is_identity_edge(is_iden) {} - - Edge(const std::string &edge_name, const std::shared_ptr &prev_op, - const std::shared_ptr &next_op, const std::vector &output_indexs_, - const std::vector &input_indexs_, const bool &is_com) - : edge_name_(edge_name), - prev_op_(prev_op), - next_op_(next_op), - pre_op_output_indexs_(output_indexs_), - next_op_input_indexs_(input_indexs_), - is_combined_(is_com) { - prev_op_output_index_ = 0; - next_op_input_index_ = 0; - is_identity_edge = false; - } - - ~Edge() = default; - std::shared_ptr prev_operator() const { return prev_op_; } - std::shared_ptr next_operator() const { return next_op_; } - std::string edge_name() const { return edge_name_; } - // Init cost_map_: for each output layout and input layout, calculate the cost - Status InitEdgeCost(); - // For two operators u--->v, given the output tensor layout of u, - // and the input tensor layout of v, return the redistribution cost, - // and the op_list to carry out the redistribution. - Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, - size_t, TypePtr type, CostPtr *cost); - - void set_pre_op_output(const std::vector, std::vector>> &output_set) { - pre_op_output_ = output_set; - } - void set_next_op_input(const std::vector, std::vector>> &input_set) { - next_op_input_ = input_set; - } - - // Given a pair of output strategy and input strategy, return the corresponding costlist - CostPtrList GetCostList(StrategyPtr output_str, StrategyPtr input_str); - - std::vector, std::vector>> prev_op_output() const { - return pre_op_output_; - } - std::vector, std::vector>> next_op_input() const { - return next_op_input_; - } - - bool is_combined() const { return is_combined_; } - size_t prev_op_output_index() const { return prev_op_output_index_; } - size_t next_op_input_index() const { return next_op_input_index_; } - std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } - std::vector next_op_input_indexs() const { return next_op_input_indexs_; } - - CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, - const std::vector> &edges, - const StrategyPtr &input_st_ptr); - // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. 
This method is used to - // set cost for this new edge - void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, - std::shared_ptr v); - void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, - const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, - CostPtrList *ret_cost_list); - - CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, - const std::shared_ptr &op, const std::shared_ptr &e2, - const StrategyPtr &input_st_ptr); - // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. - // This method is used to set cost for this new edge - void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, - const std::shared_ptr &e2); - - void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } - const CostPtr &selected_cost() const { return selected_cost_; } - void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } - // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving - // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory - // at the end of forward phase. - Status CalculateMemoryCost(); - // In the inference phase, - Status CalculateMemoryCostForInference(); - void mark_output_critical() { is_output_critical_ = 1; } - - private: - std::string edge_name_; - std::shared_ptr prev_op_, next_op_; - std::map cost_map_; - // pre_op_output_ - std::vector, std::vector>> pre_op_output_; - std::vector, std::vector>> next_op_input_; - // the index of outputs of prev_op, and the index of inputs of next_op - size_t prev_op_output_index_, next_op_input_index_; - - // pre_op_output_indexs_ and next_op_input_indexs_ store the indexs of inputs and outputs if is_combined = true - std::vector pre_op_output_indexs_; - std::vector next_op_input_indexs_; - // is this edge constructed by combining multiple edges? If is is, then is_combined = true, else is_combined = false - bool is_combined_; - // When a Parameter in the ANF graph being used by multiple operators, we include the Parameter in the costgraph by - // replace the Parameter by a TmpIdentity operator, and connecting this TmpIdentity operator with subsequent - // operators. The resulting edges are different from those normal edges, thus this Bool variable distinguishes them. - // If it is true, then we should guarantee that the strategy for output tensor consistent with the input tensor. - bool is_identity_edge; - CostPtr selected_cost_; - // In the training phase, 'is_output_parameter_involve_' is used to mark whether the output of the previous operator - // is parameter-involved - int is_output_parameter_involve_ = -1; // -1: unset; 0: not parameter_involved; 1: parameter_involved - // In the inference phase, this is used to mark whether the output of the previous operator is critical. 
-  int is_output_critical_ = 0;
-};
-}  // namespace parallel
-}  // namespace mindspore
-#endif  // PARALLEL_AUTO_PARALLEL_EDGE_COSTMODEL_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
deleted file mode 100644
index d5523aaa62..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc
+++ /dev/null
@@ -1,1677 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <algorithm>
-#include <cfloat>
-#include <cstdlib>
-#include <functional>
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "parallel/auto_parallel/graph_costmodel.h"
-#include "parallel/ops_info/reshape_info.h"
-#include "parallel/step_auto_parallel.h"
-
-namespace mindspore {
-namespace parallel {
-CostGraphPtr entire_costgraph = nullptr;
-size_t TOTAL_OPS = 0;
-double COST_MODEL_GAMMA = DEFAULT_COST_MODEL_GAMMA;
-bool COST_MODEL_SIMPLIFY_CALCULATION = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION;
-double DEVICE_MEMORY_CAPACITY = DEFAULT_DEVICE_MEMORY_CAPACITY;
-double COST_MODEL_COMMUNI_THRESHOLD = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
-double COST_MODEL_COMMUNI_CONST = DEFAULT_COST_MODEL_COMMUNI_CONST;
-double COST_MODEL_COMMUNI_BIAS = DEFAULT_COST_MODEL_COMMUNI_BIAS;
-bool TENSOR_SLICE_ALIGNMENT_ENABLE = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE;
-size_t TENSOR_SLICE_ALIGNMENT_SIZE = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE;
-bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
-bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
-bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
-int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
-
-void CostGraph::SetDeviceMemoryAndCostParameter() {
-  MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
-
-  // DEVICE_MEMORY_CAPACITY
-  auto device_memory = CostModelContext::GetInstance()->device_memory_capacity();
-  if (device_memory <= 0) {
-    MS_LOG(EXCEPTION) << "'device_memory_capacity' must be positive.";
-  }
-  dev_memory_ = device_memory;
-  DEVICE_MEMORY_CAPACITY = device_memory;
-  MS_LOG(INFO) << "device_memory_capacity: " << DEVICE_MEMORY_CAPACITY << ".";
-
-  // COST_MODEL_ALPHA
-  auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
-  if (alpha <= 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_alpha' must be positive.";
-  }
-  costmodel_alpha_ = alpha;
-  MS_LOG(INFO) << "costmodel_alpha: " << costmodel_alpha_ << ".";
-
-  // COST_MODEL_BETA
-  auto beta = CostModelContext::GetInstance()->costmodel_beta();
-  if (beta <= 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_beta' must be positive.";
-  }
-  costmodel_beta_ = beta;
-  MS_LOG(INFO) << "costmodel_beta: " << costmodel_beta_ << ".";
-
-  // COST_MODEL_GAMMA
-  auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
-  if ((gamma < 0) || (gamma > 1)) {
-    MS_LOG(EXCEPTION) << "'costmodel_gamma' must be in [0, 1].";
-  }
-  COST_MODEL_GAMMA = gamma;
-  MS_LOG(INFO) << "costmodel_gamma: " << COST_MODEL_GAMMA << ".";
-
-  // COST_MODEL_SIMPLIFY_CALCULATION
-  auto simplify = CostModelContext::GetInstance()->costmodel_simplify_cal();
-  COST_MODEL_SIMPLIFY_CALCULATION = simplify;
-  if (COST_MODEL_SIMPLIFY_CALCULATION) {
-    MS_LOG(INFO) << "costmodel_simplify_cal: true.";
-  } else {
-    MS_LOG(INFO) << "costmodel_simplify_cal: false.";
-  }
-
-  // COST_MODEL_COMMUNI_THRESHOLD
-  auto communi_threshold = CostModelContext::GetInstance()->costmodel_communi_threshold();
-  if (communi_threshold < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_threshold' must be non-negative.";
-  }
-  COST_MODEL_COMMUNI_THRESHOLD = communi_threshold;
-  MS_LOG(INFO) << "costmodel_communi_threshold: " << COST_MODEL_COMMUNI_THRESHOLD << ".";
-
-  // COST_MODEL_COMMUNI_CONST
-  auto communi_const = CostModelContext::GetInstance()->costmodel_communi_const();
-  if (communi_const < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_const' must be non-negative.";
-  }
-  COST_MODEL_COMMUNI_CONST = communi_const;
-  MS_LOG(INFO) << "costmodel_communi_const: " << COST_MODEL_COMMUNI_CONST << ".";
-
-  // COST_MODEL_COMMUNI_BIAS
-  auto communi_bias = CostModelContext::GetInstance()->costmodel_communi_bias();
-  if (communi_bias < 0) {
-    MS_LOG(EXCEPTION) << "'costmodel_communi_bias' must be non-negative.";
-  }
-  COST_MODEL_COMMUNI_BIAS = communi_bias;
-  MS_LOG(INFO) << "costmodel_communi_bias: " << COST_MODEL_COMMUNI_BIAS << ".";
-
-  // TENSOR_SLICE_ALIGNMENT_ENABLE
-  auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
-  TENSOR_SLICE_ALIGNMENT_ENABLE = align_enable;
-  if (TENSOR_SLICE_ALIGNMENT_ENABLE) {
-    MS_LOG(INFO) << "tensor_slice_align_enable: true.";
-  } else {
-    MS_LOG(INFO) << "tensor_slice_align_enable: false.";
-  }
-
-  // TENSOR_SLICE_ALIGNMENT_SIZE
-  auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
-  if (align_size == 0) {
-    MS_LOG(EXCEPTION) << "'tensor_slice_align_size' must be positive.";
-  }
-  TENSOR_SLICE_ALIGNMENT_SIZE = align_size;
-  MS_LOG(INFO) << "tensor_slice_align_size: " << TENSOR_SLICE_ALIGNMENT_SIZE << ".";
-
-  // FULLY_USE_DEVICES
-  auto fully_devices = CostModelContext::GetInstance()->fully_use_device();
-  FULLY_USE_DEVICES = fully_devices;
-  if (FULLY_USE_DEVICES) {
-    MS_LOG(INFO) << "fully_use_devices: true.";
-  } else {
-    MS_LOG(INFO) << "fully_use_devices: false.";
-  }
-
-  // ELEMENTWISE_OP_STRA_FOLLOW
-  auto is_ele_op_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
-  ELEMENTWISE_OP_STRA_FOLLOW = is_ele_op_follow;
-  if (ELEMENTWISE_OP_STRA_FOLLOW) {
-    MS_LOG(INFO) << "elementwise_op_strategy_follow: true.";
-  } else {
-    MS_LOG(INFO) << "elementwise_op_strategy_follow: false.";
-  }
-
-  // MULTI_SUBGRAPHS
-  auto multi_subgraphs = CostModelContext::GetInstance()->is_multi_subgraphs();
-  MULTI_SUBGRAPHS = multi_subgraphs;
-  if (MULTI_SUBGRAPHS) {
-    MS_LOG(INFO) << "multi_subgraphs: true.";
-  } else {
-    MS_LOG(INFO) << "multi_subgraphs: false.";
-  }
-
-  // RUN_PHASE
-  auto phase = CostModelContext::GetInstance()->run_phase();
-  if (phase != 0 && phase != 1) {
-    MS_LOG(EXCEPTION) << "'run_phase' must be in {0, 1}.";
-  }
-  RUN_PHASE = phase;
-  MS_LOG(INFO) << "run_phase: " << RUN_PHASE << ".";
-}
-
-void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
-  for (auto it = ops_.begin(); it != ops_.end();) {
-    if ((*it) == op) {
-      it = ops_.erase(it);
-    } else {
-      ++it;
-    }
-  }
-}
-
-bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
-  struct IsInGraph {
-    const OperatorInfoPtr test_;
-    explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
-    bool operator()(const OperatorInfoPtr &in)
const { return (test_ == in); } - }; - return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); -} - -void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { - std::vector curr_edges(edges_[{u_node, v_node}]); - curr_edges.push_back(edge); - edges_[{u_node, v_node}] = curr_edges; - - std::vector curr_out_edges(out_edges_[u_node]); - curr_out_edges.push_back(edge); - out_edges_[u_node] = curr_out_edges; - - std::vector curr_in_edges(in_edges_[v_node]); - curr_in_edges.push_back(edge); - in_edges_[v_node] = curr_in_edges; -} - -bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { - for (auto &edge_pair : edges_) { - auto edges = edge_pair.second; - for (auto &edge : edges) { - MS_EXCEPTION_IF_NULL(edge); - bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && - (edge->next_op_input_index() == input_index); - if (bool_result) { - return true; - } - } - } - return false; -} - -std::vector> CostGraph::ConstructConnectedComponents( - std::vector alive_ops) { - std::map visited; - - for (auto &op : alive_ops) { - visited[op] = false; - } - - MS_LOG(INFO) << "visited: " << visited.size() << "."; - for (auto &op : alive_ops) { - if ((!visited[op]) && op->is_alive()) { - std::shared_ptr new_component = std::make_shared(); - MS_EXCEPTION_IF_NULL(new_component); - new_component->SetDeviceMemoryAndCostParameter(); - DFS(op, &visited, new_component); - connected_compoents_.push_back(new_component); - } - } - return connected_compoents_; -} - -void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, - const std::shared_ptr &component) { - MS_EXCEPTION_IF_NULL(visited); - MS_EXCEPTION_IF_NULL(component); - visited->at(current_op) = true; - component->AddOperator(current_op); - - for (auto &edge : current_op->succ_edges()) { - bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && - (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); - if (bool_test) { - component->AddEdge(current_op, edge->next_operator(), edge); - DFS(edge->next_operator(), visited, component); - } - } - - for (auto &edge : current_op->prev_edges()) { - bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) && - (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); - if (bool_test) { - component->AddEdge(edge->prev_operator(), current_op, edge); - DFS(edge->prev_operator(), visited, component); - } - } -} - -// Create final cost list for the graph: u --> v -CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr &e, - const OperatorInfoPtr &v) { - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - MS_EXCEPTION_IF_NULL(e); - CostPtrList ret; - for (const auto &u_strategy : u->GetStrategyCost()) { - for (const auto &v_strategy : v->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(u_strategy); - MS_EXCEPTION_IF_NULL(v_strategy); - auto u_strategy_ptr = u_strategy->strategy_ptr; - auto v_strategy_ptr = v_strategy->strategy_ptr; - CostPtrList clist1 = u_strategy->cost_list; - CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); - CostPtrList clist3 = v_strategy->cost_list; - for (const auto &cost1 : clist1) { - for (const auto &cost2 : clist2) { - for (const auto &cost3 : clist3) { - MS_EXCEPTION_IF_NULL(cost1); - MS_EXCEPTION_IF_NULL(cost2); - MS_EXCEPTION_IF_NULL(cost3); - double computation = cost1->computation_cost_ + 
-                                 cost2->computation_cost_ + cost3->computation_cost_;
-            double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
-            double communication =
-              cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
-            double communication_forward =
-              cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
-            double communication_without_para = cost1->communication_without_parameter_ +
-                                                cost2->communication_without_parameter_ +
-                                                cost3->communication_without_parameter_;
-            auto decision =
-              std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
-            auto cost = std::make_shared<Cost>(computation, communication, decision);
-            MS_EXCEPTION_IF_NULL(cost);
-            cost->communication_without_parameter_ = communication_without_para;
-            cost->communication_with_partial_para_ =
-              communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para);
-            cost->memory_with_reuse_ = memory;
-            cost->communication_forward_ = communication_forward;
-            ret.push_back(cost);
-          }
-        }
-      }
-    }
-  }
-
-  Simplify(&ret);
-  return ret;
-}
-
-// Create the final cost list for the graph containing a single node: u
-CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
-  MS_EXCEPTION_IF_NULL(u);
-  CostPtrList ret;
-  for (const auto &u_strategy : u->GetStrategyCost()) {
-    MS_EXCEPTION_IF_NULL(u_strategy);
-    auto u_strategy_ptr = u_strategy->strategy_ptr;
-    CostPtrList clist1 = u_strategy->cost_list;
-    for (const auto &cost1 : clist1) {
-      MS_EXCEPTION_IF_NULL(cost1);
-      auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
-      auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
-      MS_EXCEPTION_IF_NULL(new_cost);
-      new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
-      new_cost->communication_with_partial_para_ =
-        cost1->communication_without_parameter_ +
-        COST_MODEL_GAMMA * (cost1->communication_cost_ - cost1->communication_without_parameter_);
-      new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
-      new_cost->communication_forward_ = cost1->communication_forward_;
-      ret.push_back(new_cost);
-    }
-  }
-
-  Simplify(&ret);
-  return ret;
-}
-
-CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
-  // Select the cost with the minimum inference time. Currently, the inference time is modeled as:
-  // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
-  if (cost_list.empty()) {
-    MS_LOG(ERROR) << "Final cost list is null.";
-    return nullptr;
-  }
-  CostPtrList after_mem_filter;
-  double minimum_memory = DBL_MAX;
-  // Keep only the costs that fit into the available memory.
-  for (auto &a_cost : cost_list) {
-    if (a_cost->memory_with_reuse_ <= memory) {
-      after_mem_filter.emplace_back(std::move(a_cost));
-    } else if (a_cost->memory_with_reuse_ < minimum_memory) {
-      minimum_memory = a_cost->memory_with_reuse_;
-    }
-  }
-  if (after_mem_filter.empty()) {
-    MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
-                  << ", the memory capacity is: " << memory << ".";
-    return nullptr;
-  }
-  // Init the returned value with the first cost.
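The loop that follows initializes the running best with the first memory-feasible candidate and then performs a plain argmin over the filtered list. A self-contained sketch of the same filter-then-argmin pattern, with hypothetical names standing in for the real types:

#include <vector>

struct Candidate {
  double computation, communication_forward, memory;
};

// Returns the index of the cheapest candidate that fits in 'capacity', or -1.
int SelectMinTime(const std::vector<Candidate> &cands, double capacity,
                  double alpha, double beta) {
  int best = -1;
  double best_time = 0.0;
  for (int i = 0; i < static_cast<int>(cands.size()); ++i) {
    if (cands[i].memory > capacity) continue;  // memory filter first
    double t = alpha * cands[i].computation + beta * cands[i].communication_forward;
    if (best < 0 || t < best_time) { best = i; best_time = t; }
  }
  return best;
}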
- CostPtr ret = after_mem_filter[0]; - - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_forward_; - MS_LOG(INFO) << "Cost 0: " - << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ - << ", communication_forward_: " << ret->communication_forward_ - << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ - << ", communication_cost_: " << ret->communication_cost_ - << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; - for (size_t i = 1; i < after_mem_filter.size(); ++i) { - MS_EXCEPTION_IF_NULL(after_mem_filter[i]); - MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ - << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ - << ", communication_forward_: " << after_mem_filter[i]->communication_forward_ - << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ - << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ - << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ - << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_forward_; - MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; - if (minimum > tmp) { - minimum = tmp; - ret = after_mem_filter[i]; - MS_LOG(INFO) << "Selected: " << i; - } - } - return ret; -} - -CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { - // Select the cost with minimum training time. Currently, the training time is modeled as = - // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ - if (cost_list.empty()) { - MS_LOG(ERROR) << "Final cost list is null."; - return nullptr; - } - CostPtrList after_mem_filter; - double minimum_memory = DBL_MAX; - // Filter out the valid costs. - for (auto &a_cost : cost_list) { - if (a_cost->memory_with_reuse_ <= memory) { - after_mem_filter.emplace_back(std::move(a_cost)); - } else if (a_cost->memory_with_reuse_ < minimum_memory) { - minimum_memory = a_cost->memory_with_reuse_; - } - } - if (after_mem_filter.empty()) { - MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory - << ", the memory capacity is: " << memory << "."; - return nullptr; - } - // Init the returned value with first cost. 
- CostPtr ret = after_mem_filter[0]; - - double minimum = costmodel_alpha_ * ret->computation_cost_ + costmodel_beta_ * ret->communication_with_partial_para_; - MS_LOG(INFO) << "Cost 0: " - << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_ - << ", communication_with_partial_para_: " << ret->communication_with_partial_para_ - << ", communication_cost_: " << ret->communication_cost_ - << ", communication_without_parameter_: " << ret->communication_without_parameter_ << "."; - MS_LOG(INFO) << "Cost 0: total_cost: " << minimum; - for (size_t i = 1; i < after_mem_filter.size(); ++i) { - MS_EXCEPTION_IF_NULL(after_mem_filter[i]); - MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_ - << ", computation_cost_: " << after_mem_filter[i]->computation_cost_ - << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_ - << ", communication_cost_: " << after_mem_filter[i]->communication_cost_ - << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_ - << "."; - auto tmp = costmodel_alpha_ * after_mem_filter[i]->computation_cost_ + - costmodel_beta_ * after_mem_filter[i]->communication_with_partial_para_; - MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp; - if (minimum > tmp) { - minimum = tmp; - ret = after_mem_filter[i]; - MS_LOG(INFO) << "Selected: " << i; - } - } - return ret; -} - -CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, - double available_memory) { - CostPtrList selected_cost_list(all_cost_list.size(), nullptr); - double minimum = DBL_MAX, total_memory = 0.0; - CostPtrList ret(all_cost_list.size(), nullptr); - // Check whether valid costs exist. 
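That feasibility check sums the per-component memory minima before any search begins; the search itself then enumerates the full cross product of per-component candidates recursively. A compact sketch of such an enumeration under the same objective, with hypothetical names (the real code uses communication_with_partial_para_ in the objective and checks the memory bound at the leaves; pruning on entry is equivalent because memory only accumulates):

#include <cstddef>
#include <vector>

struct Candidate { double computation, communication, memory; };

// Exhaustively pick one candidate per component, minimizing the summed time
// objective while the summed memory stays below 'capacity'.
void Enumerate(const std::vector<std::vector<Candidate>> &lists, size_t k,
               std::vector<size_t> *pick, std::vector<size_t> *best,
               double time, double mem, double *best_time,
               double capacity, double alpha, double beta) {
  if (mem >= capacity) return;  // prune: memory bound already exceeded
  if (k == lists.size()) {
    if (time < *best_time) { *best_time = time; *best = *pick; }
    return;
  }
  for (size_t i = 0; i < lists[k].size(); ++i) {
    (*pick)[k] = i;
    const Candidate &c = lists[k][i];
    Enumerate(lists, k + 1, pick, best,
              time + alpha * c.computation + beta * c.communication,
              mem + c.memory, best_time, capacity, alpha, beta);
  }
}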
- for (size_t i = 0; i < all_cost_list.size(); ++i) { - if (all_cost_list[i][0] == nullptr) { - MS_LOG(ERROR) << "The cost list " << i << " is empty."; - return ret; - } else { - double memory_i_cost = DBL_MAX; - for (size_t j = 0; j < all_cost_list[i].size(); ++j) { - if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) { - memory_i_cost = all_cost_list[i][j]->memory_with_reuse_; - } - } - total_memory += memory_i_cost; - } - } - if (total_memory >= available_memory) { - MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory - << ", minimum strategy cost: " << total_memory << "."; - return selected_cost_list; - } - - std::function recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive, - &available_memory, this](size_t k) { - if (k == all_cost_list.size()) { - double tmp_memory = 0.0, tmp_minimum = 0.0; - for (size_t i = 0; i < selected_cost_list.size(); ++i) { - MS_EXCEPTION_IF_NULL(selected_cost_list[i]); - tmp_memory += selected_cost_list[i]->memory_with_reuse_; - tmp_minimum += costmodel_alpha_ * selected_cost_list[i]->computation_cost_ + - costmodel_beta_ * selected_cost_list[i]->communication_with_partial_para_; - } - MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum - << "."; - if (tmp_memory < available_memory && tmp_minimum < minimum) { - ret = selected_cost_list; - minimum = tmp_minimum; - MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << "."; - } - return; - } - - MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; - for (auto &c : all_cost_list[k]) { - selected_cost_list[k] = c; - recursive(k + 1); - } - }; - recursive(0); - return ret; -} - -Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { - MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; - auto connected_components = ConstructConnectedComponents(alive_ops); - MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; - std::vector all_list; - for (size_t j = 0; j < connected_components.size(); ++j) { - auto one_component = connected_components[j]; - MS_EXCEPTION_IF_NULL(one_component); - if (one_component->GetOperators().size() == 1) { - MS_LOG(INFO) << "There are 1 operator in a component in the final graph."; - auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]); - all_list.push_back(cost_list); - } else if (one_component->GetOperators().size() == 2) { - MS_LOG(INFO) << "There are 2 operators in a component in the final graph."; - OperatorInfoPtr u, v; - auto first_op = one_component->GetOperators()[0]; - auto second_op = one_component->GetOperators()[1]; - MS_EXCEPTION_IF_NULL(first_op); - MS_EXCEPTION_IF_NULL(second_op); - if (!first_op->GetAliveSuccEdges().empty() && - first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { - u = first_op; - v = second_op; - } else if (!second_op->GetAliveSuccEdges().empty() && - second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { - u = second_op; - v = first_op; - } else { - MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size() - << ", " << second_op->GetAliveSuccEdges().size() << "."; - } - MS_EXCEPTION_IF_NULL(u); - auto e = u->GetAliveSuccEdges()[0]; - auto cost_list = one_component->CreateFinalCostList(u, e, v); - 
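The u/v classification above depends only on which of the two operators still owns an alive outgoing edge pointing at the other, and the same orientation test recurs twice more in SearchStrategy below. A minimal sketch of it with a toy node type and a hypothetical function name:

#include <memory>
#include <utility>
#include <vector>

struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node { std::vector<NodePtr> succ; };  // alive successor nodes only

// Returns {u, v} if the two-node component is oriented u --> v, else {nullptr, nullptr}.
std::pair<NodePtr, NodePtr> Orient(const NodePtr &a, const NodePtr &b) {
  if (!a->succ.empty() && a->succ.front().get() == b.get()) return {a, b};
  if (!b->succ.empty() && b->succ.front().get() == a.get()) return {b, a};
  return {nullptr, nullptr};
}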
all_list.push_back(cost_list); - } else { - MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size() - << " operators in a component in the final graph."; - } - } - // - auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_); - for (size_t k = 0; k < selected_cost_list.size(); ++k) { - auto selected_cost = selected_cost_list[k]; - if (selected_cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(connected_components[k]); - if (connected_components[k]->GetOperators().size() == 1) { - auto u = connected_components[k]->GetOperators()[0]; - auto decision = selected_cost->decision_ptr_->cast(); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } else if (connected_components[k]->GetOperators().size() == 2) { - OperatorInfoPtr u = nullptr, v = nullptr; - auto first_op = connected_components[k]->GetOperators()[0]; - auto second_op = connected_components[k]->GetOperators()[1]; - MS_EXCEPTION_IF_NULL(first_op); - MS_EXCEPTION_IF_NULL(second_op); - if (!first_op->GetAliveSuccEdges().empty() && - first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) { - u = first_op; - v = second_op; - } else if (!second_op->GetAliveSuccEdges().empty() && - second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) { - u = second_op; - v = first_op; - } - MS_EXCEPTION_IF_NULL(u); - auto e = u->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(v); - MS_EXCEPTION_IF_NULL(e); - MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_); - auto decision = selected_cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); - v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); - e->set_selected_cost(decision->middle_cost_); - MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended."; - } - } - return SUCCESS; -} - -// searching the strategy for the final eliminated graph -Status CostGraph::SearchStrategy() { - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began."; - std::vector alive_ops; - (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - if (op->is_alive()) { - alive_ops.push_back(op); - } - }); - - if (alive_ops.size() > 2) { - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - return SearchStrategyForMultiNodeFinalGraph(alive_ops); - } else { - // inference phase - MS_LOG(EXCEPTION) - << "Currently, searching strategy for the multi-node final graph in inference phase is not supported."; - } - } else if (alive_ops.size() == 1) { - MS_LOG(INFO) << "There are 1 single node in the final graph."; - OperatorInfoPtr u = alive_ops[0]; - auto cost_list = CreateFinalSingleCostList(u); - CostPtr cost = nullptr; - if (RUN_PHASE == TRAINING_PHASE) { - // training phase - cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_); - } else { - // inference phase - cost = SelectCostWithMinInferenceTime(cost_list, dev_memory_); - } - if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(cost->decision_ptr_); - auto decision = 
-    cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
-    MS_EXCEPTION_IF_NULL(decision);
-    u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
-    MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
-    return SUCCESS;
-  } else {
-    // In this case, the final graph should contain exactly 2 nodes.
-    if (alive_ops.empty()) {
-      MS_LOG(INFO) << "0 operators in the final graph.";
-      return SUCCESS;
-    }
-    OperatorInfoPtr u, v;
-    MS_EXCEPTION_IF_NULL(alive_ops[0]);
-    MS_EXCEPTION_IF_NULL(alive_ops[1]);
-    if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
-        alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
-      u = alive_ops[0];
-      v = alive_ops[1];
-    } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
-               alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
-      u = alive_ops[1];
-      v = alive_ops[0];
-    } else {
-      if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
-        MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
-                          << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
-      } else {
-        // In this case, the final graph consists of two single nodes
-        MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
-        std::vector<CostPtrList> all_list;
-        auto connected_components = ConstructConnectedComponents(alive_ops);
-        MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
-        for (size_t i = 0; i < connected_components.size(); ++i) {
-          MS_LOG(INFO) << "There is 1 operator in a component in the final graph.";
-          auto one_component = connected_components[i];
-          MS_EXCEPTION_IF_NULL(one_component);
-          auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
-          all_list.push_back(cost_list);
-        }
-        CostPtrList selected_cost_list;
-        if (RUN_PHASE == TRAINING_PHASE) {
-          // training phase
-          selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, dev_memory_);
-        } else {
-          // inference phase
-          MS_LOG(EXCEPTION) << "Currently, searching the strategy for the two-separated-node final graph in the "
-                               "inference phase is not supported.";
-        }
-        for (size_t k = 0; k < selected_cost_list.size(); ++k) {
-          auto selected_cost = selected_cost_list[k];
-          if (selected_cost == nullptr) {
-            MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << dev_memory_ << ".";
-            return FAILED;
-          }
-          MS_EXCEPTION_IF_NULL(connected_components[k]);
-          auto one_operator = connected_components[k]->GetOperators()[0];
-          MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
-          auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
-          MS_EXCEPTION_IF_NULL(decision);
-          one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
-          MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
-        }
-
-        return SUCCESS;
-      }
-    }
-    MS_LOG(INFO) << "There are 2 nodes in the final graph.";
-    // In this case, the final graph is exactly of the form: u --> v
-    MS_EXCEPTION_IF_NULL(u);
-    MS_EXCEPTION_IF_NULL(v);
-    auto e = u->GetAliveSuccEdges()[0];
-    MS_EXCEPTION_IF_NULL(e);
-    auto cost_list = CreateFinalCostList(u, e, v);
-    CostPtr cost = nullptr;
-    if (RUN_PHASE == TRAINING_PHASE) {
-      // training phase
-      cost = SelectCostWithMinTrainingTime(cost_list, dev_memory_);
-    } else {
-      MS_LOG(EXCEPTION) << "Currently, searching the strategy for the two-connected-node final graph in the "
-                           "inference phase is not supported.";
-    }
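For context, SearchStrategy only ever runs after the dynamic-programming pass has shrunk the cost graph to at most two alive nodes; the Check*/Elimination* routines that follow are the individual reduction steps, tried in a fixed order by a driver that lives outside this file. A schematic sketch of such a driver loop, assuming a hypothetical CostGraph-like interface rather than the actual one:

// A sketch only; the real driver and its exact ordering live in the DP
// algorithm code, not in this file.
template <typename Graph>
void EliminateUntilFinal(Graph *g) {
  bool changed = true;
  while (changed) {
    changed = false;
    if (auto op = g->CheckOpElimination()) { g->EliminationOp(op); changed = true; continue; }
    if (auto edges = g->CheckEdgeElimination(); !edges.empty()) { g->EliminationEdges(edges); changed = true; continue; }
    if (auto op = g->CheckMergeElimination()) { g->EliminationMerge(op); changed = true; continue; }
    if (auto op = g->CheckContractElimination()) { g->EliminationContract(op); changed = true; continue; }
    // Triangle and star eliminations are only attempted when none of the above applies.
  }
}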
- if (cost == nullptr) { - MS_LOG(ERROR) << "No vaild strategy can be found under the current device memory: " << dev_memory_ << "."; - return FAILED; - } - MS_EXCEPTION_IF_NULL(cost->decision_ptr_); - auto decision = cost->decision_ptr_->cast(); - MS_EXCEPTION_IF_NULL(decision); - u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_); - v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_); - e->set_selected_cost(decision->middle_cost_); - MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended."; - return SUCCESS; - } -} - -// Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated -// return the v and the edge u --> v -OperatorInfoPtr CostGraph::CheckOpElimination() const { - for (auto &op : ops_) { - bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; - if (bool_test) { - if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether an EdgeElimination can be performed -std::vector> CostGraph::CheckEdgeElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - if (!op->is_alive()) continue; - std::map count; - for (auto &edge : op->GetAliveSuccEdges()) { - MS_EXCEPTION_IF_NULL(edge); - auto v = edge->next_operator(); - count[v.get()]++; - } - for (auto &pair : count) { - auto *op_ptr = pair.first; - int op_count = pair.second; - if (op_count > 1) { - std::vector> ret; - for (auto &edge : op->GetAliveSuccEdges()) { - MS_EXCEPTION_IF_NULL(edge); - if (edge->next_operator().get() == op_ptr) { - ret.push_back(edge); - } - } - return ret; - } - } - } - return {}; -} - -// Check the graph whether a MergeElimination can be performed -OperatorInfoPtr CostGraph::CheckMergeElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; - if (bool_test) { - auto next_op = op->GetAliveSuccEdges()[0]->next_operator(); - MS_EXCEPTION_IF_NULL(next_op); - if (!next_op->GetAlivePrevEdges().empty()) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether a ContractElimination can be performed -OperatorInfoPtr CostGraph::CheckContractElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); - if (bool_test) { - auto edge = op->GetAlivePrevEdges()[0]; - MS_EXCEPTION_IF_NULL(edge); - auto prev_op = edge->prev_operator(); - MS_EXCEPTION_IF_NULL(prev_op); - if (!prev_op->GetAliveSuccEdges().empty()) { - return op; - } - } - } - return nullptr; -} - -// Check the graph whether a TriangleElimination can be performed -std::pair> CostGraph::CheckTriangleElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); - if (bool_test) { - auto edge1 = op->GetAliveSuccEdges()[0]; - auto edge2 = op->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(edge1); - MS_EXCEPTION_IF_NULL(edge2); - auto first_op = edge1->next_operator(); - auto second_op = edge2->next_operator(); - MS_EXCEPTION_IF_NULL(first_op); - for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { - if (first_op_succ_edge->next_operator() == second_op) { 
- return {op, first_op_succ_edge}; - } - } - MS_EXCEPTION_IF_NULL(second_op); - for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { - if (second_op_succ_edge->next_operator() == first_op) { - return {op, second_op_succ_edge}; - } - } - } - } - return {nullptr, nullptr}; -} - -// Check the graph whether a StarElimination can be performed. -// NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. -OperatorInfoPtr CostGraph::CheckStarElimination() const { - for (auto &op : ops_) { - MS_EXCEPTION_IF_NULL(op); - bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); - if (bool_test) { - return op; - } - } - return nullptr; -} - -// This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace -// 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge. -std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { - // in this case, the operators are organised in the form of u-->op-->v, and the goal - // is to eliminate 'op'. - MS_EXCEPTION_IF_NULL(op); - MS_LOG(INFO) << "Now eliminating node: " << op->name() << "."; - auto edge_u_op = op->GetAlivePrevEdges()[0]; - auto edge_op_v = op->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(edge_u_op); - MS_EXCEPTION_IF_NULL(edge_op_v); - auto u = edge_u_op->prev_operator(); - auto v = edge_op_v->next_operator(); - std::vector output_indexs, input_indexs; - size_t output_index, input_index; - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); - std::shared_ptr new_edge; - if (edge_u_op->is_combined()) { - output_indexs = edge_u_op->prev_op_output_indexs(); - } else { - output_index = edge_u_op->prev_op_output_index(); - output_indexs.push_back(output_index); - } - if (edge_op_v->is_combined()) { - input_indexs = edge_op_v->next_op_input_indexs(); - } else { - input_index = edge_op_v->next_op_input_index(); - input_indexs.push_back(input_index); - } - - if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) { - new_edge = std::make_shared(new_edge_name, u, v, output_index, input_index, false); - } else { - new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); - } - MS_EXCEPTION_IF_NULL(new_edge); - new_edge->set_pre_op_output(edge_u_op->prev_op_output()); - new_edge->set_next_op_input(edge_op_v->next_op_input()); - new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v); - u->ReplaceSuccEdge(op, new_edge); - v->ReplacePreEdge(op, new_edge); - op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded."; - return new_edge; -} - -// This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', -// and sets new costlist for the new edge. 
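EliminationOp above rewires the graph as well as the costs: the new u --> v edge replaces u's successor edge and v's predecessor edge, and the eliminated operator is merely marked dead rather than removed. A toy sketch of that splice with simplified node/edge types and a hypothetical function name; the edge-elimination routine below performs the analogous ReplaceSuccEdges/ReplacePreEdges rewiring for a bundle of parallel edges:

#include <algorithm>
#include <memory>
#include <vector>

struct Edge {};
struct Node {
  std::vector<std::shared_ptr<Edge>> in, out;
  bool alive = true;
};

// Splice a new edge u -> v in place of u -> op -> v and retire 'op'.
void Rewire(Node *u, Node *op, Node *v, const std::shared_ptr<Edge> &e_u_op,
            const std::shared_ptr<Edge> &e_op_v, const std::shared_ptr<Edge> &e_new) {
  std::replace(u->out.begin(), u->out.end(), e_u_op, e_new);
  std::replace(v->in.begin(), v->in.end(), e_op_v, e_new);
  op->alive = false;  // costs were already folded into e_new by OpEliminationSetNewCost
}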
-std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { - MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; - MS_EXCEPTION_IF_NULL(edges[0]); - auto u = edges[0]->prev_operator(); - auto v = edges[0]->next_operator(); - MS_EXCEPTION_IF_NULL(u); - MS_EXCEPTION_IF_NULL(v); - std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); - std::vector output_indexs, input_indexs; - - for (auto &edge : edges) { - MS_EXCEPTION_IF_NULL(edge); - if (edge->is_combined()) { - auto from_output_indexs = edge->prev_op_output_indexs(); - auto from_input_indexs = edge->next_op_input_indexs(); - (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs)); - (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs)); - } else { - output_indexs.push_back(edge->prev_op_output_index()); - input_indexs.push_back(edge->next_op_input_index()); - } - } - - std::shared_ptr new_edge = std::make_shared(new_edge_name, u, v, output_indexs, input_indexs, true); - MS_EXCEPTION_IF_NULL(new_edge); - new_edge->set_pre_op_output(edges[0]->prev_op_output()); - new_edge->set_next_op_input(edges[0]->next_op_input()); - - new_edge->EdgeEliminationSetNewCost(u, edges, v); - - u->ReplaceSuccEdges(v, new_edge); - v->ReplacePreEdges(u, new_edge); - MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded."; - return new_edge; -} - -// Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' -// for this contract under the strategy 'op_strategy' -void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, - const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList &tar_cost_list, - CostPtrList *const tar_cost_list_new) { - for (size_t i = 0; i < op_cost_list.size(); ++i) { - auto &op_cost = op_cost_list[i]; - MS_EXCEPTION_IF_NULL(op_cost); - for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto &edge_cost = edge_cost_list[j]; - MS_EXCEPTION_IF_NULL(edge_cost); - for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto &tar_cost = tar_cost_list[k]; - MS_EXCEPTION_IF_NULL(tar_cost); - double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; - double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; - double communication = - op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; - double communication_forward = - op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_; - double communication_without_para = op_cost->communication_without_parameter_ + - edge_cost->communication_without_parameter_ + - tar_cost->communication_without_parameter_; - - auto decision = - std::make_shared(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost); - auto new_cost = std::make_shared(computation, communication, decision); - MS_EXCEPTION_IF_NULL(new_cost); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ = memory; - new_cost->communication_forward_ = communication_forward; - MS_EXCEPTION_IF_NULL(tar_cost_list_new); - tar_cost_list_new->emplace_back(std::move(new_cost)); - } - } - } 
-} - -// This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the -// target_op -OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); - auto edge_ptr = op->GetAliveSuccEdges()[0]; - MS_EXCEPTION_IF_NULL(target_op); - MS_EXCEPTION_IF_NULL(edge_ptr); - MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; - bool valid = false; - - for (auto &tar_stra_cost : target_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(tar_stra_cost); - auto tar_stra = tar_stra_cost->strategy_ptr; - auto tar_clist_origin = tar_stra_cost->cost_list; - CostPtrList tar_clist_new; - - for (auto &op_stra_cost : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_stra_cost); - auto op_stra = op_stra_cost->strategy_ptr; - auto op_clist = op_stra_cost->cost_list; - auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra); - - CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); - } - Simplify(&tar_clist_new); - // Set the new costlist w.r.t the strategy - tar_stra_cost->cost_list = tar_clist_new; - if ((!valid) && (!tar_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed."; - } - op->SetNotAlive(); - MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded."; - return target_op; -} - -// Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' -// for this contract under the strategy 'contract_op_stra' -void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, - const CostPtrList &contract_op_cost_list, - const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, - const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { - for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { - auto &contract_op_cost = contract_op_cost_list[i]; - MS_EXCEPTION_IF_NULL(contract_op_cost); - for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto &edge_cost = edge_cost_list[j]; - MS_EXCEPTION_IF_NULL(edge_cost); - for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto &tar_cost = tar_cost_list[k]; - MS_EXCEPTION_IF_NULL(tar_cost); - double computation = - contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; - double memory = - contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; - double communication = - contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_; - double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ + - tar_cost->communication_forward_; - double communication_without_para = contract_op_cost->communication_without_parameter_ + - edge_cost->communication_without_parameter_ + - tar_cost->communication_without_parameter_; - - auto decision = std::make_shared(contract_op_stra, contract_op_cost, edge_cost, - target_op_stra, tar_cost); - auto new_cost = std::make_shared(computation, communication, decision); - new_cost->communication_without_parameter_ = communication_without_para; - new_cost->communication_with_partial_para_ = - communication_without_para + COST_MODEL_GAMMA * (communication - communication_without_para); - new_cost->memory_with_reuse_ 
= memory; - new_cost->communication_forward_ = communication_forward; - tar_cost_list_new->emplace_back(std::move(new_cost)); - } - } - } -} - -// This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the -// target_op -OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); - auto edge_ptr = op->GetAlivePrevEdges()[0]; - MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; - bool valid = false; - - for (auto &tar_stra_cost : target_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(tar_stra_cost); - auto tar_stra = tar_stra_cost->strategy_ptr; - auto tar_clist_origin = tar_stra_cost->cost_list; - CostPtrList tar_clist_new; - - for (auto &op_stra_cost : op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(op_stra_cost); - auto op_stra = op_stra_cost->strategy_ptr; - auto op_clist = op_stra_cost->cost_list; - auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra); - - CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new); - } - Simplify(&tar_clist_new); - // Set the new costlist w.r.t the strategy - tar_stra_cost->cost_list = tar_clist_new; - if ((!valid) && (!tar_clist_new.empty())) { - valid = true; - } - } - if (!valid) { - MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed."; - } - op->SetNotAlive(); - MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded."; - return target_op; -} - -void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, - StrategyPtr right_op_stra, const CostPtr &right_op_cost, - const CostPtrList &elimi_op_clist, - const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, - const CostPtrList &left_node_clist_origin, - CostPtrList *left_node_clist_new) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - MS_EXCEPTION_IF_NULL(right_op_cost); - MS_EXCEPTION_IF_NULL(left_node_clist_new); - for (auto &elimi_op_cost : elimi_op_clist) { - MS_EXCEPTION_IF_NULL(elimi_op_cost); - for (auto &left_edge_cost : left_edge_clist) { - MS_EXCEPTION_IF_NULL(left_edge_cost); - for (auto &left_node_cost : left_node_clist_origin) { - MS_EXCEPTION_IF_NULL(left_node_cost); - double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + - left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; - double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ + - left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_; - double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ + - left_node_cost->communication_cost_ + right_edge_cost->communication_cost_; - double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ + - left_node_cost->communication_forward_ + right_edge_cost->communication_forward_; - double new_commu_without = - elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ + - left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_; - - auto decision = std::make_shared( - elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost, left_op_stra, left_node_cost, right_op_stra); - auto new_cost = std::make_shared(new_computation, 
new_commu_cost, decision); - new_cost->communication_without_parameter_ = new_commu_without; - new_cost->communication_with_partial_para_ = - new_commu_without + COST_MODEL_GAMMA * (new_commu_cost - new_commu_without); - new_cost->memory_with_reuse_ = new_memory; - new_cost->communication_forward_ = new_commu_forward; - left_node_clist_new->emplace_back(std::move(new_cost)); - } - } - } -} - -void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, - const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, - const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, - const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, - const CostPtrList &left_node_clist_origin, - CostPtrList *left_node_clist_new) { - MS_EXCEPTION_IF_NULL(elimi_op); - for (auto &right_node_cost : right_node_clist) { - MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto &right_edge_cost : right_edge_clist) { - MS_EXCEPTION_IF_NULL(right_edge_cost); - CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, - elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, - left_node_clist_new); - } - } -} - -OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, - const std::shared_ptr &edge_left_right) { - MS_EXCEPTION_IF_NULL(edge_left_right); - MS_EXCEPTION_IF_NULL(elimi_op); - MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; - auto left_node = edge_left_right->prev_operator(); - auto right_node = edge_left_right->next_operator(); - auto left_edge = elimi_op->GetAliveSuccEdges()[0]; - auto right_edge = elimi_op->GetAliveSuccEdges()[1]; - MS_EXCEPTION_IF_NULL(left_node); - MS_EXCEPTION_IF_NULL(right_node); - MS_EXCEPTION_IF_NULL(left_edge); - MS_EXCEPTION_IF_NULL(right_edge); - MS_LOG(INFO) << "The left operator is: " << left_node->name() << "."; - MS_LOG(INFO) << "The right operator is: " << right_node->name() << "."; - - if (left_edge->next_operator() != left_node) { - auto tmp = left_edge; - left_edge = right_edge; - right_edge = tmp; - } - bool valid = false; - - for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(left_node_stra_cost); - auto left_node_stra = left_node_stra_cost->strategy_ptr; - auto left_node_clist_origin = left_node_stra_cost->cost_list; - CostPtrList left_node_clist_new; - - for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); - auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; - auto elimi_op_clist = elimi_op_stra_cost->cost_list; - auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); - - for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(right_node_stra_cost); - auto right_node_stra = right_node_stra_cost->strategy_ptr; - auto right_node_clist = right_node_stra_cost->cost_list; - auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra); - - CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra, - right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin, - &left_node_clist_new); - } - } - Simplify(&left_node_clist_new); - // Set the new costlist w.r.t the strategy - left_node_stra_cost->cost_list = left_node_clist_new; - if ((!valid) && (!left_node_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << 
"Eliminating triangle: " << elimi_op->name() << " failed."; - } - elimi_op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded."; - return left_node; -} - -void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, - const CostPtrList &first_succ_node_clist, - const CostPtrList &first_succ_edge_clist, - const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, - std::vector succ_nodes_stras, - CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, - CostPtrList *first_succ_node_clist_new) { - for (auto &first_succ_node_cost : first_succ_node_clist) { - for (auto &first_succ_edge_cost : first_succ_edge_clist) { - for (auto &merged_node_cost : merged_op_clist) { - MS_EXCEPTION_IF_NULL(merged_node_cost); - succ_nodes_stras[0] = first_succ_node_stra; - succ_edges_costs[0] = first_succ_edge_cost; - succ_nodes_costs[0] = first_succ_node_cost; - - double computation_cost = merged_node_cost->computation_cost_, - memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_, - commu_without = merged_node_cost->communication_without_parameter_, - commu_forward = merged_node_cost->communication_forward_; - for (size_t i = 0; i < succ_nodes_stras.size(); ++i) { - MS_EXCEPTION_IF_NULL(succ_edges_costs[i]); - if (i == 0) { - computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_; - memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_; - commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_; - commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_; - commu_without += succ_edges_costs[i]->communication_without_parameter_ + - succ_nodes_costs[i]->communication_without_parameter_; - } else { - computation_cost += succ_edges_costs[i]->computation_cost_; - memory_cost += succ_edges_costs[i]->memory_with_reuse_; - commu_cost += succ_edges_costs[i]->communication_cost_; - commu_forward += succ_edges_costs[i]->communication_forward_; - commu_without += succ_edges_costs[i]->communication_without_parameter_; - } - } - - auto decision = std::make_shared(merged_op_stra, merged_node_cost, succ_edges_costs, - succ_nodes_stras, succ_nodes_costs); - auto new_cost = std::make_shared(computation_cost, commu_cost, decision); - new_cost->communication_without_parameter_ = commu_without; - new_cost->communication_with_partial_para_ = commu_without + COST_MODEL_GAMMA * (commu_cost - commu_without); - new_cost->memory_with_reuse_ = memory_cost; - new_cost->communication_forward_ = commu_forward; - first_succ_node_clist_new->emplace_back(std::move(new_cost)); - } - } - } -} - -void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, - const StrategyPtr &first_succ_node_stra, - const CostPtrList &first_succ_node_clist, - const CostPtrList &first_succ_edge_clist, - const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, - CostPtrList *first_succ_node_clist_new) { - std::vector succ_nodes_stras(succ_edges.size(), nullptr); - CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); - std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, - &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs, - &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive, - this](size_t k) 
{ - if (k == succ_edges.size()) { - CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, - merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs, - succ_nodes_costs, first_succ_node_clist_new); - return; - } - MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size() - << ", first_succ_edge_clist: " << first_succ_edge_clist.size() - << ", merged_op_clist: " << merged_op_clist.size() - << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << "."; - auto succ_edge = succ_edges[k]; - MS_EXCEPTION_IF_NULL(succ_edge); - auto succ_node = succ_edge->next_operator(); - MS_EXCEPTION_IF_NULL(succ_node); - for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(succ_node_stra_cost); - auto succ_node_stra = succ_node_stra_cost->strategy_ptr; - auto succ_node_clist = succ_node_stra_cost->cost_list; - auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); - - for (auto &succ_node_cost : succ_node_clist) { - MS_EXCEPTION_IF_NULL(succ_node_cost); - for (auto &succ_edge_cost : succ_edge_clist) { - MS_EXCEPTION_IF_NULL(succ_edge_cost); - succ_nodes_stras[k] = succ_node_stra; - succ_edges_costs[k] = succ_edge_cost; - succ_nodes_costs[k] = succ_node_cost; - recursive(k + 1); - } - } - } - }; - - recursive(1); -} - -std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { - MS_EXCEPTION_IF_NULL(merged_op); - auto succ_edges = merged_op->GetAliveSuccEdges(); - MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; - for (auto &succ_edge : succ_edges) { - MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); - MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; - } - - MS_EXCEPTION_IF_NULL(succ_edges[0]); - auto first_succ_node = succ_edges[0]->next_operator(); - auto first_succ_edge = succ_edges[0]; - bool valid = false; - - // 'merged_op' is merged into first_node - MS_EXCEPTION_IF_NULL(first_succ_node); - for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); - auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; - auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; - CostPtrList first_succ_node_clist_new; - - for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { - MS_EXCEPTION_IF_NULL(merged_op_stra_cost); - auto merged_op_stra = merged_op_stra_cost->strategy_ptr; - auto merged_op_clist = merged_op_stra_cost->cost_list; - auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra); - - CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist, - merged_op_stra, merged_op_clist, &first_succ_node_clist_new); - } - Simplify(&first_succ_node_clist_new); - // Set the new costlist w.r.t the strategy - first_succ_node_stra_cost->cost_list = first_succ_node_clist_new; - if ((!valid) && (!first_succ_node_clist_new.empty())) { - valid = true; - } - } - - if (!valid) { - MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name() << " failed."; - } - - merged_op->SetNotAlive(); - MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded."; - return succ_edges; -} - -size_t CostGraph::GetNumEdges() const { - size_t sum = 0; - for (const auto &kv : edges_) { - auto &edges = kv.second; - sum += edges.size(); - } - return sum; -} -Status 
CostGraph::InitSelectedStrategy() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->name().find(RESHAPEINFO) != std::string::npos) {
-      continue;
-    }
-    auto result = op->InitSelectedStrategy(op->selected_strategy());
-    if (result != SUCCESS) {
-      return result;
-    }
-  }
-  // Reshape init should be applied after the init of its previous node and next node.
-  for (size_t i = 0; i < ops_.size(); ++i) {
-    if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
-      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
-      auto in_edges = GetOriginalPrevEdges(ops_[i]);
-      auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](std::shared_ptr<Edge> edge) {
-        return edge->prev_operator()->name() == reshape_info->pre_operator_name();
-      });
-      auto out_edges = GetOriginalNextEdges(ops_[i]);
-      auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](std::shared_ptr<Edge> edge) {
-        return edge->next_operator()->name() == reshape_info->next_operator_name();
-      });
-      if (pre_iter != in_edges.end()) {
-        MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
-        int32_t pre_index = reshape_info->pre_operator_index();
-        TensorInfo pre_info;
-        if (ops_[i]->name() == (*pre_iter)->prev_operator()->name()) {
-          pre_info = (*pre_iter)->prev_operator()->inputs_tensor_info()[pre_index];
-        } else {
-          pre_info = (*pre_iter)->prev_operator()->outputs_tensor_info()[pre_index];
-        }
-        reshape_info->SetInputLayout(pre_info.tensor_layout());
-        Dimensions stra = pre_info.InferStrategy();
-        if (stra.empty()) {
-          MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
-        }
-        std::vector<Dimensions> stra_inputs = {stra};
-        StrategyPtr reshape_stra =
-          std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
-        reshape_info->set_strategy(reshape_stra);
-      }
-      if (next_iter != out_edges.end()) {
-        MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
-        int32_t next_index = reshape_info->next_operator_index();
-        reshape_info->SetOutputLayout((*next_iter)->next_operator()->inputs_tensor_info()[next_index].tensor_layout());
-      }
-      if (reshape_info->Init(nullptr) != SUCCESS) {
-        return FAILED;
-      }
-    }
-  }
-  return SUCCESS;
-}
-
-Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
-    if ((output_parameter != 0) && (output_parameter != 1)) {
-      MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-
-void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
-                                std::vector<OperatorInfoPtr> *topo_order) {
-  MS_EXCEPTION_IF_NULL(current_op);
-  MS_EXCEPTION_IF_NULL(visited);
-  MS_EXCEPTION_IF_NULL(topo_order);
-
-  visited->at(current_op) = true;
-  for (const auto &s_edge : current_op->succ_edges()) {
-    if (!visited->at(s_edge->next_operator())) {
-      DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
-    }
-  }
-  topo_order->push_back(current_op);
-}
-
-// Compute a topological order of the costgraph
-void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
-  std::map<OperatorInfoPtr, bool> visited;
-  for (auto &op : ops_) {
-    visited[op] = false;
-  }
-
-  for (auto &op : ops_) {
-    if (!visited[op]) {
-      DFSForTopoOrder(op, &visited, topo_order);
-    }
-  }
-}
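
DFSForTopoOrder above emits an operator only after all of its successors have been emitted (a post-order DFS), so reversing the emission list, as ComputeOpsAndEdgesOutputCritical does below, yields a topological order. A minimal standalone sketch of the same idea, with integer node ids and plain adjacency vectors as illustrative stand-ins for OperatorInfoPtr and succ_edges():

#include <algorithm>
#include <vector>

// Post-order DFS: a node is appended only after every node reachable from it,
// so the reversed sequence is a valid topological order.
static void DfsPostOrder(size_t node, const std::vector<std::vector<size_t>> &succ,
                         std::vector<bool> *visited, std::vector<size_t> *order) {
  (*visited)[node] = true;
  for (size_t next : succ[node]) {
    if (!(*visited)[next]) {
      DfsPostOrder(next, succ, visited, order);
    }
  }
  order->push_back(node);
}

std::vector<size_t> TopoOrder(const std::vector<std::vector<size_t>> &succ) {
  std::vector<bool> visited(succ.size(), false);
  std::vector<size_t> order;
  for (size_t v = 0; v < succ.size(); ++v) {
    if (!visited[v]) {
      DfsPostOrder(v, succ, &visited, &order);
    }
  }
  std::reverse(order.begin(), order.end());  // sources first, sinks last
  return order;
}
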
-void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &candidate_ops) {
-  for (auto &op : ops_) {
-    auto search = candidate_ops.find(op);
-    if (search != candidate_ops.end()) {
-      // Mark the critical operators
-      op->mark_output_critical();
-      // Mark the successive edges
-      for (auto &s_edge : op->succ_edges()) {
-        s_edge->mark_output_critical();
-      }
-    } else {
-      op->mark_output_not_critical();
-    }
-  }
-}
-
-Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
-  if (topo_order.size() == 0) {
-    MS_LOG(ERROR) << "0 operator in costgraph.";
-    return FAILED;
-  }
-  auto &first_op = topo_order[0];
-  if (first_op->prev_edges().size() > 0) {
-    MS_LOG(ERROR) << "The first operator in the topological order of the costgraph should have 0 incoming edges, "
-                     "but has " << first_op->prev_edges().size() << " edges.";
-    return FAILED;
-  }
-  // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
-  // of outputs of the OperatorInfo that currently have not been used
-  std::map<OperatorInfoPtr, int> curr_memory_state;
-  (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToInt(first_op->succ_edges().size())));
-  std::map<OperatorInfoPtr, int> max_memory_state = curr_memory_state;
-  // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that have
-  // not been used
-  double curr_memory_size = first_op->GetOutputsTotalSize();
-  double max_memory_size = curr_memory_size;
-
-  for (size_t finished = 1; finished < topo_order.size(); ++finished) {
-    // Produce
-    (void)curr_memory_state.emplace(
-      std::make_pair(topo_order[finished], SizeToInt(topo_order[finished]->succ_edges().size())));
-    curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
-    // Consume
-    for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
-      const auto &prev_op = prev_edge->prev_operator();
-      curr_memory_state[prev_op]--;
-    }
-    for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
-      const auto &prev_op = prev_edge->prev_operator();
-      if (curr_memory_state[prev_op] < 0) {
-        MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
-        return FAILED;
-      } else if (curr_memory_state[prev_op] == 0) {
-        curr_memory_state.erase(prev_op);
-        curr_memory_size -= prev_op->GetOutputsTotalSize();
-      }
-    }
-
-    if (curr_memory_size < 0) {
-      MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
-    }
-    // Modify the max
-    if (curr_memory_size > max_memory_size) {
-      max_memory_size = curr_memory_size;
-      max_memory_state = curr_memory_state;
-    }
-  }
-  // Mark those critical operators
-  MarkCriticalOpsAndEdges(max_memory_state);
-  return SUCCESS;
-}
-
-Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
-  // Two steps to do:
-  // 1. Compute a topological order of the costgraph
-  // 2. Determine and mark the operators (and necessary edges) that are critical
-  std::vector<OperatorInfoPtr> topo_order;
-  TopologyOrder(&topo_order);
-  std::reverse(std::begin(topo_order), std::end(topo_order));
-
-  if (DetermineCriticalOps(topo_order) != SUCCESS) {
-    MS_LOG(ERROR) << "Determining critical operators failed.";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-Status CostGraph::CalculateOpsMemoryCost() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->CalculateMemoryCost() != SUCCESS) {
-      MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-
-Status CostGraph::CalculateOpsMemoryCostForInference() {
-  for (auto &op : ops_) {
-    MS_EXCEPTION_IF_NULL(op);
-    if (op->CalculateMemoryCostForInference() != SUCCESS) {
-      MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-
-Status CostGraph::CalculateEdgesMemoryCost() {
-  for (auto &edge_pair : edges_) {
-    const auto &edges = edge_pair.second;
-    for (auto &one_edge : edges) {
-      if (one_edge->CalculateMemoryCost() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
-        return FAILED;
-      }
-    }
-  }
-  return SUCCESS;
-}
-
-Status CostGraph::CalculateEdgesMemoryCostForInference() {
-  for (auto &edge_pair : edges_) {
-    const auto &edges = edge_pair.second;
-    for (auto &one_edge : edges) {
-      if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
-        return FAILED;
-      }
-    }
-  }
-  return SUCCESS;
-}
-
-OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
-  for (auto one_op : ops_) {
-    if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
-      if (one_op->refkey_parameter_name() == p_name) {
-        return one_op;
-      }
-    }
-  }
-  return nullptr;
-}
-Status CostGraph::CorrectOpsMemoryCost() {
-  for (auto &one_op : ops_) {
-    if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
-      if (one_op->GetAliveSuccEdges().size() > 1) {
-        // Filter out the case where the TmpIdentity is used by multiple operators
-        std::map<size_t, int> output_count;
-        for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
-          auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
-          output_count[output_index]++;
-        }
-        for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
-          auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
-          if (output_count[output_index] <= 1) {
-            continue;
-          }
-          auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
-          MS_EXCEPTION_IF_NULL(next_op);
-          auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
-          if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
-            MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
-                          << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
-            return FAILED;
-          }
-          output_count[output_index]--;
-        }
-      }
-    }
-  }
-  return SUCCESS;
-}
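
DetermineCriticalOps above is essentially a liveness walk over the topological order: each operator's outputs are produced with a use count equal to its outgoing edges, each consumer decrements its producers' counts and frees them at zero, and the snapshot taken at the peak total size is what MarkCriticalOpsAndEdges receives. A self-contained sketch of that produce/consume bookkeeping; Node, PeakLiveMemory, and the integer indices are illustrative stand-ins, not the MindSpore types:

#include <algorithm>
#include <map>
#include <vector>

struct Node {
  double output_size;          // total size of this operator's outputs
  std::vector<int> producers;  // operators whose outputs this one consumes
  std::vector<int> consumers;  // operators that consume this one's outputs
};

// Walk a topological order, keep every produced output alive until its last
// consumer has run, and record the point where the live total peaks.
double PeakLiveMemory(const std::vector<Node> &nodes, const std::vector<int> &topo) {
  std::map<int, int> remaining;  // node -> uses not yet consumed
  double live = 0.0, peak = 0.0;
  for (int idx : topo) {
    remaining[idx] = static_cast<int>(nodes[idx].consumers.size());  // produce
    live += nodes[idx].output_size;
    for (int p : nodes[idx].producers) {  // consume one use per incoming edge
      if (--remaining[p] == 0) {
        remaining.erase(p);
        live -= nodes[p].output_size;
      }
    }
    peak = std::max(peak, live);
  }
  return peak;
}
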
-
-Status CostGraph::CalculateMemoryCost() {
-  if (RUN_PHASE == TRAINING_PHASE) {
-    // training phase
-    if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
-      // Calculate operators' memory usage
-      if (CalculateOpsMemoryCost() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
-        return FAILED;
-      }
-      // Calculate edges' memory usage
-      if (CalculateEdgesMemoryCost() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
-        return FAILED;
-      }
-      // Correct memory usage caused by TmpIdentity
-      if (CorrectOpsMemoryCost() != SUCCESS) {
-        MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
-        return FAILED;
-      }
-    } else {
-      MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
-      return FAILED;
-    }
-  } else {
-    // inference phase
-    if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
-      // Calculate operators' memory usage
-      if (CalculateOpsMemoryCostForInference() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
-        return FAILED;
-      }
-      // Calculate edges' memory usage
-      if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
-        MS_LOG(ERROR) << "Calculating edges' memory cost for inference failed.";
-        return FAILED;
-      }
-    } else {
-      MS_LOG(ERROR) << "Computing operators' critical flag failed.";
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
deleted file mode 100644
index 3b8b389d81..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h
+++ /dev/null
@@ -1,238 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ - -#include -#include -#include -#include -#include -#include "../../common.h" -#include "common/utils.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/costmodel_context.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/ops_info/tmp_identity_info.h" - -namespace mindspore { -namespace parallel { -#define OPERATOR_TO_OPERATOR_CONNECTOR "-" -#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) -#define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 400.0 -#define DEFAULT_COST_MODEL_GAMMA 0.001 -#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true -#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 -#define DEFAULT_COST_MODEL_COMMUNI_CONST 3072.0 -#define DEFAULT_COST_MODEL_COMMUNI_BIAS 1024.0 -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE false -#define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16 -#define DEFAULT_FULLY_USE_DEVICES true -#define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false -#define DEFAULT_IS_MULTI_SUBGRAPHS false -#define DEFAULT_RUN_PHASE 0 -#define TRAINING_PHASE 0 -#define INFERENCE_PHASE 1 - -class CostGraph; -using CostGraphPtr = std::shared_ptr; -extern CostGraphPtr entire_costgraph; -extern size_t TOTAL_OPS; -extern double COST_MODEL_GAMMA; -extern bool COST_MODEL_SIMPLIFY_CALCULATION; -extern double DEVICE_MEMORY_CAPACITY; -extern double COST_MODEL_COMMUNI_THRESHOLD; -extern double COST_MODEL_COMMUNI_CONST; -extern double COST_MODEL_COMMUNI_BIAS; -extern bool TENSOR_SLICE_ALIGNMENT_ENABLE; -extern size_t TENSOR_SLICE_ALIGNMENT_SIZE; -extern bool FULLY_USE_DEVICES; -extern bool ELEMENTWISE_OP_STRA_FOLLOW; -extern bool MULTI_SUBGRAPHS; -extern int32_t RUN_PHASE; - -class CostGraph { - // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have - // output-input dependency relationship. - public: - CostGraph() { - dev_memory_ = DEFAULT_DEVICE_MEMORY_CAPACITY; - costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; - } - ~CostGraph() = default; - void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } - OperatorInfoPtr FindOperatorByIndex(size_t index) { - if (index >= ops_.size()) { - MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; - return nullptr; - } - return ops_[index]; - } - void RemoveOperator(const OperatorInfoPtr &op); - bool IsOperatorInCostGraph(const OperatorInfoPtr &op); - // the edge is in the form: u --> v - void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge); - std::vector> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; } - std::vector> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; } - // An edge is uniquely identified by its name, and its output index and input index. 
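
Wherever the eliminations in the matching .cc file build a combined Cost, communication_with_partial_para_ is the blend comm_without_para + COST_MODEL_GAMMA * (comm_total - comm_without_para), with COST_MODEL_GAMMA defaulting to the 0.001 defined above. A small worked illustration; the function name and the byte counts are made up:

#include <iostream>

// comm_total: all communication; comm_without_para: the part not caused by
// parameters (i.e. excluding gradient AllReduce traffic).
double CommWithPartialPara(double comm_total, double comm_without_para, double gamma) {
  return comm_without_para + gamma * (comm_total - comm_without_para);
}

int main() {
  const double gamma = 0.001;  // DEFAULT_COST_MODEL_GAMMA
  // 5 MB of total communication, 2 MB of it not parameter-related:
  std::cout << CommWithPartialPara(5.0e6, 2.0e6, gamma) << std::endl;  // 2003000
  return 0;
}
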
- bool IsEdgeInCostGraph(const std::string &, size_t, size_t); - - void SetDeviceMemoryAndCostParameter(); - - std::vector> ConstructConnectedComponents(std::vector); - void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, - const std::shared_ptr &component); - - CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); - CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); - CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); - CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); - CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); - Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); - std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { - return edges_[{u_node, v_node}]; - } - double GetDeviceMemory() const { return dev_memory_; } - - // Search the cost_list in the final graph, and determine the optimal one - Status SearchStrategy(); - - // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated - OperatorInfoPtr CheckOpElimination() const; - // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges - // can be eliminated into one - std::vector CheckEdgeElimination() const; - // Given a graph which contains the following subgraph: - // u - // | - // w --- v --- x - // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v. - // u is returned. - OperatorInfoPtr CheckMergeElimination() const; - // Given a graph which contains the following subgraph: - // u - // | - // v --- x - // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted - // into v. u is returned. - OperatorInfoPtr CheckContractElimination() const; - /* Given a graph which contains the following subgraph: - * u - * / \ - * / \ - * v --- w - * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v. - * The returned value includes u and the edge >. - */ - std::pair CheckTriangleElimination() const; - /* Given a graph which contains the following subgraph: - * v <--- u ---> w - * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections, - * resulting in v and w can not be performed ContractElimination. u is returned. - * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 
- */
-  OperatorInfoPtr CheckStarElimination() const;
-  // Applying Operator Elimination in the DP algorithm
-  EdgePtr EliminationOp(const OperatorInfoPtr &op);
-  // Applying Edge Elimination in the DP algorithm
-  EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges);
-  // Applying Merge Elimination in the DP algorithm
-  OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op);
-  void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
-                                         const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
-                                         const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new);
-  // Applying Contract Elimination in the DP algorithm
-  OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op);
-  void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr,
-                                            const CostPtrList &, CostPtrList *);
-
-  // Applying Triangle Elimination in the DP algorithm. Return the left_node
-  OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right);
-  void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &,
-                                         const StrategyPtr &, const StrategyPtr &, const StrategyPtr &,
-                                         const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *);
-  // Given the relevant cost lists, create the TriangleElimination cost
-  void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &,
-                                            const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *);
-
-  // Applying the Star Elimination in the DP algorithm. Return the successive edges of this merged_op
-  // NOTE: this elimination MUST be performed only when the above 5 operations cannot be applied.
-  std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op);
-  void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &,
-                                     const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *);
-  void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
-                                        const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
-                                        CostPtrList &, CostPtrList &, CostPtrList *);
-  // Calculate memory cost for the training phase or the inference phase.
-  Status CalculateMemoryCost();
-  // When the input of an operator is neither a WEIGHT, nor an output of a subsequent operator involving WEIGHT, then
-  // the memory cost can be reused. This is used to calculate memory in the training phase.
-  Status CalculateOpsMemoryCost();
-  // When the input of the edge is neither a WEIGHT, nor an output of a subsequent operator involving WEIGHT, then
-  // the memory cost can be reused. This is used to calculate memory in the training phase.
-  Status CalculateEdgesMemoryCost();
-  // Calculate memory cost of operators in the inference phase.
-  Status CalculateOpsMemoryCostForInference();
-  // Calculate memory cost of edges in the inference phase.
-  Status CalculateEdgesMemoryCostForInference();
-  Status ComputeOpsAndEdgesParameterInvolved();
-  // Compute for each operator whether the output is critical.
-  Status ComputeOpsAndEdgesOutputCritical();
-
-  std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
-  size_t GetNumEdges() const;
-  Status InitSelectedStrategy();
-  OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
-  // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated
-  // only once (instead of multiple times); this method is used to correct this.
-  Status CorrectOpsMemoryCost();
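
The Check*/Elimination* pairs declared above are intended to be driven to a fixpoint, with star elimination kept as the last resort per the NOTE. A schematic driver loop under that reading; RunEliminations and the std::function wrappers illustrate the control flow and are not the actual pipeline code:

#include <functional>
#include <vector>

// Try each elimination in cheap-to-expensive order; after any success restart
// from the first, and stop once no elimination applies any more.
void RunEliminations(const std::vector<std::function<bool()>> &eliminations) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (const auto &try_eliminate : eliminations) {  // op, edge, merge, contract, triangle, star
      if (try_eliminate()) {
        changed = true;
        break;  // star elimination only runs when everything before it failed
      }
    }
  }
}
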
-  // Needed by rec_parser
-  void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
-    inputs_tensor_name_list_.push_back(inputs_tensor_name);
-  }
-  const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
-  void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
-    auto ret = tuple_getitem_list_.insert(tuple_getitem);
-    if (ret.second == false) {
-      MS_LOG(EXCEPTION) << "The inserted item already exists.";
-    }
-  }
-  const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
-
- private:
-  void TopologyOrder(std::vector<OperatorInfoPtr> *);
-  void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
-  Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
-  void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int> &);
-  // Needed by rec_parser
-  std::vector<std::vector<std::string>> inputs_tensor_name_list_;
-  std::map<std::string, std::string> tuple_getitem_list_;
-  double dev_memory_;
-  double costmodel_alpha_;
-  double costmodel_beta_;
-  std::vector<OperatorInfoPtr> ops_;
-  std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
-  std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
-  std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
-  std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
-};
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
deleted file mode 100644
index 8ebfdb7d13..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc
+++ /dev/null
@@ -1,892 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/auto_parallel/operator_costmodel.h"
-
-#include <algorithm>
-#include <vector>
-#include "parallel/device_matrix.h"
-#include "parallel/tensor_layout/tensor_redistribution.h"
-
-namespace mindspore {
-namespace parallel {
-void OperatorCost::set_is_parameter(const std::vector<bool> &is_parameter) { is_parameter_ = is_parameter; }
-
-void OperatorCost::set_is_parameter_involve(const std::vector<bool> &is_parameter_inv) {
-  is_parameter_involve_ = is_parameter_inv;
-}
-
-void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; }
-
-void OperatorCost::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
-                                               const std::vector<size_t> &output_lengths) {
-  inputs_type_lengths_ = input_lengths;
-  outputs_type_lengths_ = output_lengths;
-}
-
-void OperatorCost::set_output_critical(int critical) { is_outputs_critical_ = critical; }
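
GetMemoryCost below boils down to "bytes this device actually holds": the product of each slice shape times the tensor's element width. A standalone sketch of that accounting; ListProduct is re-implemented here, and the int64_t Shape alias is an assumption rather than the MindSpore definition:

#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Number of elements in one slice: the product of its dimensions.
double ListProduct(const Shape &slice_shape) {
  double result = 1.0;
  for (int64_t dim : slice_shape) {
    result *= static_cast<double>(dim);
  }
  return result;
}

// Per-device bytes held for a set of slices, given each tensor's element width.
double SliceBytes(const std::vector<Shape> &slice_shapes, const std::vector<size_t> &type_lengths) {
  double result = 0.0;
  for (size_t i = 0; i < slice_shapes.size(); ++i) {
    result += ListProduct(slice_shapes[i]) * static_cast<double>(type_lengths[i]);
  }
  return result;
}
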
-
-double OperatorCost::GetMemoryCost(const std::vector<TensorInfo> &inputs,
-                                   const std::vector<TensorInfo> &outputs) const {
-  double result = 0.0;
-  if (output_parameter_involve_ == 1) {
-    // When this operator has multiple outputs, they all contribute to the memory.
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
-    }
-    bool is_any_para_inv =
-      std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; });
-    if (is_any_para_inv) {
-      for (size_t i = 0; i < inputs.size(); ++i) {
-        if (is_parameter_[i]) {
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        } else if (inputs_related_ && (!is_parameter_involve_[i])) {
-          // When the inputs of this operator are related, and they are not parameter-involved, then they are included
-          // in the memory cost.
-          result += ListProduct(inputs[i].slice_shape()) * static_cast<double>(inputs_type_lengths_[i]);
-        }
-      }
-    }
-  }
-
-  return result;
-}
-
-double OperatorCost::GetMemoryCostForInference(const std::vector<TensorInfo> &,
-                                               const std::vector<TensorInfo> &outputs) const {
-  double result = 0.0;
-  if (is_outputs_critical_ == -1) {
-    MS_LOG(EXCEPTION) << "The critical flag is not set.";
-  }
-  if (is_outputs_critical_ == 1) {
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      result += ListProduct(outputs[i].slice_shape()) * static_cast<double>(outputs_type_lengths_[i]);
-    }
-  }
-  return result;
-}
-
-// return the per device communication cost in the forward phase.
-double MatMulCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                      int32_t) const {
-  TensorInfo input0 = inputs[0];
-  TensorInfo output0 = outputs[0];
-  Shape input0_shape = input0.shape();
-  Shape input0_slice_shape = input0.slice_shape();
-  if (input0_shape[input0_shape.size() - 1] == input0_slice_shape[input0_slice_shape.size() - 1]) {
-    // If the reduced dimension has not been partitioned, then there is no communication cost.
-    return 0.0;
-  } else {
-    // Else, the communication cost is the size (number of bytes) of a slice of the output tensor.
-    return ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
-  }
-}
-
-// return the per device communication cost in the backward phase.
-double MatMulCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                       int32_t stage_id) const {
-  // In the backward phase, communication cost is incurred only when tensor B is a Parameter and tensor B does not
-  // fully utilize all devices
-  double result = 0.0;
-  if (is_parameter_[1]) {
-    TensorInfo input1 = inputs[1];  // tensor B
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-    Shape input1_shape = input1.shape();
-    Shape input1_slice_shape = input1.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input1_shape.size(); ++i) {
-      used_device_num *= input1_shape[i] / input1_slice_shape[i];
-    }
-
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-  }
-
-  return result;
-}
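
GetForwardCommCost above encodes a single rule for a sharded matmul C = A * B: if the reduced (last) dimension of A is split, every device holds a partial sum and must AllReduce its slice of C; otherwise the forward pass is communication-free. A compact restatement; the Shape alias, names, and byte width are illustrative:

#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

double Product(const Shape &s) {
  double r = 1.0;
  for (int64_t d : s) {
    r *= static_cast<double>(d);
  }
  return r;
}

// Forward communication bytes of C = A * B under sharding: an AllReduce of the
// local C slice is needed exactly when A's reduction dimension is partitioned.
double MatMulForwardCommBytes(const Shape &a_shape, const Shape &a_slice,
                              const Shape &c_slice, size_t c_type_len) {
  if (a_shape.back() == a_slice.back()) {
    return 0.0;  // reduction dimension kept whole: partial sums never arise
  }
  return Product(c_slice) * static_cast<double>(c_type_len);
}
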
-
-// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
-// this operator uses
-double MatMulCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                             const std::vector<TensorInfo> &outputs, int32_t) const {
-  // In the forward phase, the computation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C))
-  double result = 0.0;
-  TensorInfo output0 = outputs[0];
-  Shape input0_slice_shape = inputs[0].slice_shape();
-  Shape input1_slice_shape = inputs[1].slice_shape();
-  Shape input0_shape = inputs[0].shape();
-  if (input0_shape[input0_shape.size() - 1] != input0_slice_shape[input0_slice_shape.size() - 1]) {
-    // If the reduced dimension has been partitioned, the AllReduce of a slice of output tensor C is also counted.
-    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
-  }
-  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
-            ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-  return result;
-}
-
-// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes
-// this operator uses
-double MatMulCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                              int32_t stage_id) const {
-  // In the backward phase, the computation cost = (0 or 1) allreduce(slice(B))
-  double result = 0.0;
-  if (is_parameter_[1]) {
-    TensorInfo input1 = inputs[1];  // tensor B
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-    Shape input1_shape = input1.shape();
-    Shape input1_slice_shape = input1.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input1_shape.size(); ++i) {
-      used_device_num *= input1_shape[i] / input1_slice_shape[i];
-    }
-
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-  }
-
-  return result;
-}
-
-// Return the per device communication cost in the forward phase.
-double ActivationCost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                          int32_t) const {
-  // ReLU is an element-wise operator, thus it does not need communication in the forward phase
-  return 0.0;
-}
-
-// Return the per device communication cost in the backward phase.
-double ActivationCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                           int32_t stage_id) const {
-  double result = 0.0;
-  if (is_parameter_[0]) {
-    TensorInfo input1 = inputs[0];
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-    Shape input1_shape = input1.shape();
-    Shape input1_slice_shape = input1.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input1_shape.size(); ++i) {
-      used_device_num *= input1_shape[i] / input1_slice_shape[i];
-    }
-    if (total_device_num != IntToSize(used_device_num)) {
-      result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-    }
-  }
-  return result;
-}
-
-// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
-// this operator uses
-double ActivationCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                 const std::vector<TensorInfo> &, int32_t) const {
-  TensorInfo input0_info = inputs[0];
-  Shape input0_slice_shape = input0_info.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-}
-
-// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes -// this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const { - return 0.0; -} - -// Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // In the forward phase, the communication cost = 0 - return 0.0; -} - -// Return the per device communication cost in the backward phase. -double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input1 = inputs[0]; - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In the forward phase, the computation cost = slice(A) - TensorInfo input0 = inputs[0]; - Shape input0_slice_shape = input0.slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double TmpIdentityCost::GetForwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // Identity is the element-wise operator, thus it does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double TmpIdentityCost::GetBackwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // Identity is the element-wise operator, thus it does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double TmpIdentityCost::GetForwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, - int32_t) const { - return 0.0; -} - -// Return the per device PEAK memory cost contributed by this operator in a training iteration. 
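
Most backward costs in this file share one test, visible in the MatMul and Activation variants above: a parameter pays for a gradient AllReduce only when the number of devices its sharding actually uses, the product of shape[i] / slice_shape[i], differs from the stage's device count. A standalone sketch; the Shape alias and names are illustrative:

#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Devices used by a sharding: one factor per split dimension.
size_t UsedDeviceNum(const Shape &shape, const Shape &slice_shape) {
  size_t used = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    used *= static_cast<size_t>(shape[i] / slice_shape[i]);
  }
  return used;
}

// A parameter sliced across only part of the stage is replicated on the rest,
// so its gradient must be AllReduced and its slice bytes count as backward cost.
bool NeedsGradAllReduce(const Shape &shape, const Shape &slice_shape, size_t total_devices) {
  return UsedDeviceNum(shape, slice_shape) != total_devices;
}
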
-double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { - return 0.0; -} - -double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &, - int32_t) const { - double cost = 0.0; - for (size_t i = 0; i < inputs.size(); ++i) { - cost += ListProduct(inputs[i].slice_shape()) * static_cast(inputs_type_lengths_[i]); - } - return cost; -} - -double BatchParallelCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, - int32_t) const { - return 0.0; -} - -double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - for (size_t j = 0; j < inputs.size(); ++j) { - if (!is_parameter_[j]) { - continue; - } - TensorInfo input_a_tensor_info = inputs[j]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - } - - return result; -} -// return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { - // prelu does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In forward phase, the computation cost = slice(A) + slice(B) - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes -// this operator uses -double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &, - int32_t stage_id) const { - // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) - double result = 0.0; - if (is_parameter_[1]) { - TensorInfo input1 = inputs[1]; // tensor B - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) { - result += ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// return the per device communication cost in the forward phase. -double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { - // onehot does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, - int32_t) const { - // onehot does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - // In onehot's forward phase, the computation cost = slice(A) - Shape input0_slice_shape = inputs[0].slice_shape(); - return ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]); -} - -// Return the per device computation cost in the backward phase. The cost is calculated according to the bytes -// this operator uses -double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, - int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase - return 0.0; -} - -// return the per device communication cost in the backward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, - const std::vector &, int32_t) const { - // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase - return 0.0; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &, int32_t) const { - // In forward phase, the computation cost = slice(A) + slice(B) - Shape input0_slice_shape = inputs[0].slice_shape(); - Shape input1_slice_shape = inputs[1].slice_shape(); - double result = ListProduct(input0_slice_shape) * static_cast(inputs_type_lengths_[0]) + - ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - return result; -} - -// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes -// this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -// return the per device communication cost in the forward phase. -double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, - int32_t stage_id) const { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution(false, true); - if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - return (inputs_type_lengths_[0] * tensor_redistribution.comm_cost()); -} - -// return the per device communication cost in the backward phase. -double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - if (is_parameter_[0]) { - TensorInfo input1 = inputs[0]; - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - Shape input1_shape = input1.shape(); - Shape input1_slice_shape = input1.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input1_shape.size(); ++i) { - used_device_num *= input1_shape[i] / input1_slice_shape[i]; - } - if (total_device_num != IntToSize(used_device_num)) { - result = ListProduct(input1_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - } - return result; -} - -// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes -// this operator uses -double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, - const std::vector &outputs, int32_t stage_id) const { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - TensorRedistribution tensor_redistribution(false, true); - if (tensor_redistribution.Init(inputs[0].tensor_layout(), outputs[0].tensor_layout(), dev_list) == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed."; - } - if (tensor_redistribution.ComputeCost() == FAILED) { - MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed."; - } - return (inputs_type_lengths_[0] * tensor_redistribution.computation_cost()); -} - -// Return the per device computation cost in the backward phase. 
The cost is calculated according to the bytes -// this operator uses -double ReshapeCost::GetBackwardComputationCost(const std::vector &, - const std::vector &, int32_t) const { - return 0.0; -} - -double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, - int32_t) const { - double result; - result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + - ListProduct(inputs[1].slice_shape()) * static_cast(inputs_type_lengths_[1]); - return result; -} - -double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, - const std::vector &, int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - if (is_parameter_[0]) { - TensorInfo input_a_tensor_info = inputs[0]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - if (is_parameter_[1]) { - TensorInfo input_b_tensor_info = inputs[1]; - Shape input_b_shape = input_b_tensor_info.shape(); - Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_b_shape.size(); ++i) { - used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - return result; -} - -double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, - int32_t stage_id) const { - double result = 0.0; - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - if (is_parameter_[0]) { - TensorInfo input_a_tensor_info = inputs[0]; - Shape input_a_shape = input_a_tensor_info.shape(); - Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_a_shape.size(); ++i) { - used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_a_slice_shape) * static_cast(inputs_type_lengths_[0]); - } - - if (is_parameter_[1]) { - TensorInfo input_b_tensor_info = inputs[1]; - Shape input_b_shape = input_b_tensor_info.shape(); - Shape input_b_slice_shape = input_b_tensor_info.slice_shape(); - int32_t used_device_num = 1; - for (size_t i = 0; i < input_b_shape.size(); ++i) { - used_device_num *= input_b_shape[i] / input_b_slice_shape[i]; - } - - if (total_device_num != IntToSize(used_device_num)) - result += ListProduct(input_b_slice_shape) * static_cast(inputs_type_lengths_[1]); - } - - return result; -} - -bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - auto strategy0 = shape[0] / slice_shape[0]; - - return (total_device_num == IntToSize(strategy0)); -} - -double ReduceMethodCost::GetForwardCommCost(const std::vector 
-double ReduceMethodCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs,
-                                            const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
-  double result = 0.0;
-  TensorInfo input0 = inputs[0];
-  TensorInfo output0 = outputs[0];
-  Shape input0_shape = input0.shape();
-  Shape input0_slice_shape = input0.slice_shape();
-  if (cross_batch_ && IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
-    return result;
-  }
-  std::vector<int32_t> dim_list = input0.reduce_dim();
-  std::vector<int32_t>::iterator pos;
-  pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) {
-    return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)];
-  });
-  if (pos != dim_list.end()) {
-    result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
-  }
-
-  return result;
-}
-
-double ReduceMethodCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                             int32_t stage_id) const {
-  double result = 0.0;
-  if (is_parameter_[0]) {
-    TensorInfo input_tensor_info = inputs[0];
-    CheckGlobalDeviceManager();
-    MS_EXCEPTION_IF_NULL(g_device_manager);
-    auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-    Shape input_shape = input_tensor_info.shape();
-    Shape input_slice_shape = input_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_shape.size(); ++i) {
-      used_device_num *= input_shape[i] / input_slice_shape[i];
-    }
-
-    if (total_device_num != IntToSize(used_device_num))
-      result += ListProduct(input_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  }
-
-  return result;
-}
-
-double ReduceMethodCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                   const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
-  double result = 0.0;
-  TensorInfo input0 = inputs[0];
-  TensorInfo output0 = outputs[0];
-  std::vector<int32_t> dim_list = input0.reduce_dim();
-  Shape input0_slice_shape = input0.slice_shape();
-  Shape input0_shape = input0.shape();
-  if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
-    std::vector<int32_t>::iterator pos;
-    pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) {
-      return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)];
-    });
-    if (pos != dim_list.end()) {
-      result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]);
-    }
-  }
-  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-
-  return result;
-}
-
-double ReduceMeanCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                 const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
-  double result = 0.0;
-  TensorInfo input0 = inputs[0];
-  TensorInfo output0 = outputs[0];
-  std::vector<int32_t> dim_list = input0.reduce_dim();
-  Shape input0_slice_shape = input0.slice_shape();
-  Shape input0_shape = input0.shape();
-  if (!cross_batch_ || !IsDataParallel(input0_shape, input0_slice_shape, stage_id)) {
-    std::vector<int32_t>::iterator pos;
-    pos = std::find_if(dim_list.begin(), dim_list.end(), [input0_shape, input0_slice_shape](int32_t index) {
-      return input0_shape[IntToSize(index)] != input0_slice_shape[IntToSize(index)];
-    });
-    if (pos != dim_list.end()) {
-      result += ListProduct(output0.slice_shape()) * static_cast<double>(outputs_type_lengths_[0]) * 2.0;
-    }
-  }
-  result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-
-  return result;
-}
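
`GetForwardCommCost` above only charges an AllReduce when at least one reduced dimension is actually split, which it detects by comparing the full shape against the slice shape along `reduce_dim()`. A small sketch of the same `find_if` pattern over hypothetical shapes:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true if any reduced axis is partitioned, i.e. the reduction crosses
// device boundaries and a forward AllReduce has to be paid for.
bool ReduceCrossesDevices(const std::vector<int64_t> &shape, const std::vector<int64_t> &slice_shape,
                          const std::vector<int32_t> &dim_list) {
  auto pos = std::find_if(dim_list.begin(), dim_list.end(), [&shape, &slice_shape](int32_t index) {
    return shape[static_cast<size_t>(index)] != slice_shape[static_cast<size_t>(index)];
  });
  return pos != dim_list.end();
}

int main() {
  // Reducing axis 1 of an [8, 16] tensor sliced to [8, 4]: axis 1 is split -> AllReduce.
  std::cout << std::boolalpha << ReduceCrossesDevices({8, 16}, {8, 4}, {1}) << "\n";  // true
  // Reducing axis 0 while only axis 1 is split -> the reduction stays local.
  std::cout << std::boolalpha << ReduceCrossesDevices({8, 16}, {8, 4}, {0}) << "\n";  // false
  return 0;
}
```
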
-double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                              int32_t) const {
-  if (inputs.empty()) {
-    return 0.0;
-  }
-  TensorInfo input0 = inputs[0];
-  Shape input0_slice_shape = input0.slice_shape();
-  return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
-}
-
-// return the per device communication cost in the forward phase.
-double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                        int32_t) const {
-  // GatherV2Cost does not need communication in the forward phase
-  return 0.0;
-}
-
-// return the per device communication cost in the backward phase.
-double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                         int32_t stage_id) const {
-  double result = 0.0;
-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-  for (size_t j = 0; j < inputs.size(); ++j) {
-    if (!is_parameter_[j]) {
-      continue;
-    }
-    TensorInfo input_a_tensor_info = inputs[j];
-    Shape input_a_shape = input_a_tensor_info.shape();
-    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_a_shape.size(); ++i) {
-      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
-    }
-    if (total_device_num != IntToSize(used_device_num)) {
-      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-    }
-  }
-
-  return result;
-}
-
-double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                               const std::vector<TensorInfo> &, int32_t) const {
-  // In forward phase, the computation cost = slice(A) + slice(B)
-  Shape input0_slice_shape = inputs[0].slice_shape();
-  Shape input1_slice_shape = inputs[1].slice_shape();
-  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
-                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-  return result;
-}
-
-double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                                int32_t) const {
-  return 0.0;
-}
-
-double LayerNormCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &,
-                                          int32_t stage_id) const {
-  double result = 0.0;
-  if (is_parameter_.size() != inputs.size()) {
-    MS_LOG(EXCEPTION) << "Invalid parameter size " << is_parameter_.size() << " for layer norm cost";
-  }
-  if (inputs_type_lengths_.size() != inputs.size()) {
-    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
-  }
-
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-  for (size_t index = 0; index < inputs.size(); ++index) {
-    if (is_parameter_[index]) {
-      TensorInfo tensor_info = inputs[index];
-      Shape shape = tensor_info.shape();
-      Shape slice_shape = tensor_info.slice_shape();
-      int32_t used_device_num = 1;
-      for (size_t i = 0; i < shape.size(); ++i) {
-        if (slice_shape[i] == 0) {
-          MS_LOG(EXCEPTION) << "Invalid slice shape " << ShapeToString(slice_shape);
-        }
-        used_device_num *= shape[i] / slice_shape[i];
-      }
-      if (total_device_num != IntToSize(used_device_num)) {
-        result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
-      }
-    }
-  }
-  return result;
-}
-
-double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                const std::vector<TensorInfo> &, int32_t) const {
-  double result = 0.0;
-  if (inputs_type_lengths_.size() != inputs.size()) {
-    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for layer norm cost";
-  }
-
-  for (size_t index = 0; index < inputs.size(); ++index) {
-    TensorInfo tensor_info = inputs[index];
-    Shape slice_shape = tensor_info.slice_shape();
-    result += ListProduct(slice_shape) * static_cast<double>(inputs_type_lengths_[index]);
-  }
-  return result;
-}
-
-double GatherV2PCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                         int32_t stage_id) const {
-  double result = 0.0;
-  if (outputs_type_lengths_.size() != outputs.size()) {
-    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
-  }
-  // don't split axis
-  if (strategy_.at(IntToSize(axis_)) == 1) {
-    return result;
-  }
-
-  // split axis
-  auto param_shape = inputs[0].slice_shape();
-  auto index_shape = inputs[1].slice_shape();
-  Shape reducescatter_shape = index_shape;
-  if (param_shape.size() == 2) {
-    reducescatter_shape.push_back(param_shape.at(1 - axis_));
-  }
-  result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]);
-  return result;
-}
-
-double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                          int32_t stage_id) const {
-  double result = 0.0;
-  CheckGlobalDeviceManager();
-  MS_EXCEPTION_IF_NULL(g_device_manager);
-  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
-
-  for (size_t j = 0; j < inputs.size(); ++j) {
-    if (!is_parameter_[j]) {
-      continue;
-    }
-    TensorInfo input_a_tensor_info = inputs[j];
-    Shape input_a_shape = input_a_tensor_info.shape();
-    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
-    int32_t used_device_num = 1;
-    for (size_t i = 0; i < input_a_shape.size(); ++i) {
-      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
-    }
-    if (total_device_num != IntToSize(used_device_num)) {
-      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-    }
-  }
-  return result;
-}
-
-double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                const std::vector<TensorInfo> &outputs, int32_t stage_id) const {
-  double result = 0.0;
-  Shape input0_slice_shape = inputs[0].slice_shape();
-  Shape input1_slice_shape = inputs[1].slice_shape();
-  if (inputs_type_lengths_.size() != inputs.size()) {
-    MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost";
-  }
-  // don't split axis
-  if (strategy_.at(IntToSize(axis_)) == 1) {
-    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
-              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
-  } else {
-    // split axis
-    result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 +
-              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1;
-  }
-
-  return result;
-}
-
-double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                                 const std::vector<TensorInfo> &outputs, int32_t) const {
-  double result = 0.0;
-  Shape input1_slice_shape = inputs[1].slice_shape();
-  Shape output0_slice_shape = outputs[0].slice_shape();
-  // don't split axis
-  if (strategy_.at(IntToSize(axis_)) == 1) {
-    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]);
-  } else {
-    // split axis
-    result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 +
-              ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3;
-  }
-
-  return result;
-}
-}  // namespace parallel
-}  // namespace mindspore
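
When the gather axis is split, `GatherV2PCost` inflates the parameter and index byte counts by the fixed weights `GATHERV2_COST_WEIGHT0`/`GATHERV2_COST_WEIGHT1` (3 and 7, defined in the header deleted below). A worked check of the forward computation cost under assumed slice sizes (float32 parameter slice [1024, 64], int32 index slice [256]):

```cpp
#include <iostream>

int main() {
  // Assumed slices: parameter [1024, 64] of float32, index [256] of int32.
  const double param_bytes = 1024.0 * 64 * 4;  // ListProduct(slice) * type length
  const double index_bytes = 256.0 * 4;

  // Axis not split: plain sum of bytes touched.
  std::cout << param_bytes + index_bytes << "\n";          // 263168

  // Axis split: weighted by GATHERV2_COST_WEIGHT0 = 3 and GATHERV2_COST_WEIGHT1 = 7.
  std::cout << param_bytes * 3 + index_bytes * 7 << "\n";  // 793600
  return 0;
}
```
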
diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
deleted file mode 100644
index a08a4dbb13..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h
+++ /dev/null
@@ -1,656 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
-#define PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
-
-#include <memory>
-#include <vector>
-#include "parallel/device_manager.h"
-#include "parallel/tensor_layout/tensor_info.h"
-
-namespace mindspore {
-namespace parallel {
-#define MAXIMUM_INPUT_NUMBER 100
-#define DEFAULT_DATA_TYPE_LENGTH 4
-#define DROPOUT_COST_RATE 1.125  // the DropoutGenMask need 12.5% memory
-#define GATHERV2_COST_WEIGHT0 3
-#define GATHERV2_COST_WEIGHT1 7
-#define GATHERV2_COST_WEIGHT2 2
-#define GATHERV2_COST_WEIGHT3 6
-
-class OperatorCost;
-using OperatorCostPtr = std::shared_ptr<OperatorCost>;
-
-template <typename T>
-double ListProduct(std::vector<T> vec) {
-  double result = 1;
-  for (size_t i = 0; i < vec.size(); ++i) {
-    result *= vec[i];
-  }
-  return result;
-}
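
`ListProduct` is the helper every cost in this model folds shapes with; it returns the element count as a `double` so it can be scaled by a byte width without overflow. A quick standalone check of the template as reconstructed above:

```cpp
#include <iostream>
#include <vector>

template <typename T>
double ListProduct(std::vector<T> vec) {
  double result = 1;
  for (size_t i = 0; i < vec.size(); ++i) {
    result *= vec[i];
  }
  return result;
}

int main() {
  std::vector<int> slice_shape = {32, 16, 4};
  // 32 * 16 * 4 = 2048 elements; times 4 bytes for float32 -> 8192 bytes.
  std::cout << ListProduct(slice_shape) * 4 << std::endl;
  return 0;
}
```
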
-// NOTE: Currently, the returned value in each method is bytes of memory size, which is calculated by the number of
-// entries timing the length of each entry's data type
-class OperatorCost {
- public:
-  explicit OperatorCost(bool is_inputs_related) : inputs_related_(is_inputs_related) {
-    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
-    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
-      is_parameter_.push_back(false);
-      is_parameter_involve_.push_back(false);
-      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-    }
-  }
-  OperatorCost() : inputs_related_(false) {
-    // this is only for the case when set_is_parameter() and SetInputAndOutputTypeLength() are not invoked
-    for (size_t i = 0; i < MAXIMUM_INPUT_NUMBER; ++i) {
-      is_parameter_.push_back(false);
-      is_parameter_involve_.push_back(false);
-      inputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-      outputs_type_lengths_.push_back(DEFAULT_DATA_TYPE_LENGTH);
-    }
-  }
-  virtual ~OperatorCost() = default;
-
-  void set_is_parameter(const std::vector<bool> &is_parameter);
-  void set_is_parameter_involve(const std::vector<bool> &);
-  void set_output_parameter_involve(int);
-  void set_output_critical(int);
-  void SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths, const std::vector<size_t> &output_lengths);
-  std::vector<size_t> inputs_type_lengths() const { return inputs_type_lengths_; }
-  std::vector<size_t> outputs_type_lengths() const { return outputs_type_lengths_; }
-
-  // per device communication cost
-  virtual double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const = 0;
-  virtual double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const = 0;
-  virtual double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                     int32_t stage_id) const = 0;
-  // per device computation cost
-  virtual double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const = 0;
-  virtual double GetForwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                           const std::vector<TensorInfo> &outputs, int32_t stage_id) const = 0;
-  virtual double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs,
-                                            const std::vector<TensorInfo> &outputs, int32_t stage_id) const = 0;
-  // per device PEAK memory cost in a training iteration
-  // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-involved),
-  // plus necessary inputs.
-  virtual double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const;
-  // per device memory cost in an inference phase
-  double GetMemoryCostForInference(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &) const;
-
- protected:
-  // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or an output of
-  // a pre-operator that has parameters as input.
-  std::vector<bool> is_parameter_involve_;
-  int output_parameter_involve_ = -1;  // -1: unset; 0: not parameter_involved; 1: parameter_involved
-  // Whether the inputs are related or not? For example, TensorAdd's two inputs are independent (not related), while
-  // Mul's two inputs are dependent (related).
-  bool inputs_related_;
-  // for each input in 'inputs_', there is a bool variable indicating whether the corresponding input is a parameter
-  std::vector<bool> is_parameter_;
-  // for each input and output, the followings record the number of bytes of each element
-  std::vector<size_t> inputs_type_lengths_;
-  std::vector<size_t> outputs_type_lengths_;
-  // Whether the output is critical, which means that this output is included in calculating peak memory cost
-  // in the inference phase.
-  int is_outputs_critical_ = -1;
-};
-
-using OperatorCostPtr = std::shared_ptr<OperatorCost>;
-
-class MatMulCost : public OperatorCost {
- public:
-  explicit MatMulCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  MatMulCost() : OperatorCost(true) {}
-  ~MatMulCost() override = default;
-
-  // per device communication cost
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-
-  // per device computation cost
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using MatMulCostPtr = std::shared_ptr<MatMulCost>;
-
-class ActivationCost : public OperatorCost {
- public:
-  explicit ActivationCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  ActivationCost() : OperatorCost(false) {}
-  ~ActivationCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using ActivationCostPtr = std::shared_ptr<ActivationCost>;
-using TransposeCost = ActivationCost;
-using TransposeCostPtr = std::shared_ptr<TransposeCost>;
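
Every concrete cost class in this header follows the same shape: the total communication (or computation) cost is always the sum of a forward and a backward part, and only those two parts are operator-specific. A stripped-down sketch of that pattern outside the real class hierarchy (toy names, not from the deleted code):

```cpp
#include <iostream>

// Minimal stand-in for the OperatorCost composition pattern: the total is
// always forward + backward, each overridden per operator.
class ToyCost {
 public:
  virtual ~ToyCost() = default;
  virtual double Forward() const = 0;
  virtual double Backward() const = 0;
  double Total() const { return Forward() + Backward(); }
};

// Hypothetical operator: pays 100 bytes forward, nothing backward.
class ToyActivation : public ToyCost {
 public:
  double Forward() const override { return 100.0; }
  double Backward() const override { return 0.0; }
};

int main() {
  ToyActivation act;
  std::cout << act.Total() << std::endl;  // 100
  return 0;
}
```
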
-class SoftmaxCost : public OperatorCost {
- public:
-  explicit SoftmaxCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  SoftmaxCost() : OperatorCost(false) {}
-  ~SoftmaxCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t) const override;
-};
-using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
-
-class TmpIdentityCost : public OperatorCost {
- public:
-  explicit TmpIdentityCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  TmpIdentityCost() : OperatorCost(false) {}
-  ~TmpIdentityCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-  // per device PEAK memory cost in a training iteration
-  double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override;
-};
-using TmpIdentityCostPtr = std::shared_ptr<TmpIdentityCost>;
-
-class BatchParallelCost : public OperatorCost {
- public:
-  explicit BatchParallelCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  BatchParallelCost() : OperatorCost(false) {}
-  ~BatchParallelCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using BatchParallelCostPtr = std::shared_ptr<BatchParallelCost>;
-
-class VirtualDatasetCost : public OperatorCost {
- public:
-  explicit VirtualDatasetCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  VirtualDatasetCost() : OperatorCost(false) {}
-  ~VirtualDatasetCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                   int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-  // per device PEAK memory cost in a training iteration
-  double GetMemoryCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs) const override {
-    return 0.0;
-  }
-};
-using VirtualDatasetCostPtr = std::shared_ptr<VirtualDatasetCost>;
-
-class GeneratorBaseCost : public OperatorCost {
- public:
-  explicit GeneratorBaseCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GeneratorBaseCost() : OperatorCost(false) {}
-  ~GeneratorBaseCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  // Inputs vector is empty for generator ops.
-  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                   int32_t) const override {
-    return 0.0;
-  }
-  // Generator ops don't have backward steps.
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-};
-using GeneratorBaseCostPtr = std::shared_ptr<GeneratorBaseCost>;
-
-class PReLUCost : public OperatorCost {
- public:
-  explicit PReLUCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  PReLUCost() : OperatorCost(true) {}
-  ~PReLUCost() override = default;
-
-  // per device communication cost
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-
-  // per device computation cost
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using PReLUCostPtr = std::shared_ptr<PReLUCost>;
-
-class OneHotCost : public OperatorCost {
- public:
-  explicit OneHotCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  OneHotCost() : OperatorCost(true) {}
-  ~OneHotCost() override = default;
-
-  // per device communication cost
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-
-  // per device computation cost
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using OneHotCostPtr = std::shared_ptr<OneHotCost>;
-
-class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost {
- public:
-  explicit SoftmaxCrossEntropyWithLogitsCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  SoftmaxCrossEntropyWithLogitsCost() : OperatorCost(false) {}
-  ~SoftmaxCrossEntropyWithLogitsCost() override = default;
-
-  // per device communication cost
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  // per device computation cost
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr<SoftmaxCrossEntropyWithLogitsCost>;
-
-class ReshapeCost : public OperatorCost {
- public:
-  explicit ReshapeCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  ReshapeCost() : OperatorCost(true) {}
-
-  ~ReshapeCost() override = default;
-
-  // per device communication cost
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-
-  // per device computation cost
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using ReshapeCostPtr = std::shared_ptr<ReshapeCost>;
-
-class ArithmeticCost : public OperatorCost {
- public:
-  explicit ArithmeticCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  ArithmeticCost() : OperatorCost(false) {}
-  ~ArithmeticCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
-
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t stage_id) const override;
-};
-using ArithmeticCostPtr = std::shared_ptr<ArithmeticCost>;
-using BiasAddCost = ArithmeticCost;
-using BiasAddCostPtr = std::shared_ptr<BiasAddCost>;
-
-class ReduceMethodCost : public OperatorCost {
- public:
-  explicit ReduceMethodCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  ReduceMethodCost() : OperatorCost(true) {}
-  ~ReduceMethodCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-  void set_cross_batch(bool cb) { cross_batch_ = cb; }
-
- protected:
-  bool cross_batch_ = false;
-};
-using ReduceMethodCostPtr = std::shared_ptr<ReduceMethodCost>;
-
-class ReduceMeanCost : public ReduceMethodCost {
- public:
-  explicit ReduceMeanCost(bool is_inputs_related) : ReduceMethodCost(is_inputs_related) {}
-  ReduceMeanCost() : ReduceMethodCost(true) {}
-  ~ReduceMeanCost() override = default;
-
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-};
-using ReduceMeanCostPtr = std::shared_ptr<ReduceMeanCost>;
-
-class GetNextCost : public OperatorCost {
- public:
-  explicit GetNextCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GetNextCost() : OperatorCost(false) {}
-  ~GetNextCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  // Inputs vector is empty for generator ops.
-  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                   int32_t) const override {
-    return 0.0;
-  }
-  // Generator ops don't have backward steps.
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-};
-using GetNextCostPtr = std::shared_ptr<GetNextCost>;
-
-class DropOutCost : public OperatorCost {
- public:
-  explicit DropOutCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  DropOutCost() : OperatorCost(true) {}
-  ~DropOutCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                   int32_t) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-};
-
-using DropOutCostPtr = std::shared_ptr<DropOutCost>;
-
-class LayerNormCost : public OperatorCost {
- public:
-  explicit LayerNormCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  LayerNormCost() : OperatorCost(true) {}
-  ~LayerNormCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override {
-    return 0.0;
-  }
-  double GetBackwardCommCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &, int32_t) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                   int32_t) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &, const std::vector<TensorInfo> &,
-                                    int32_t) const override {
-    return 0.0;
-  }
-};
-
-using DropOutCostPtr = std::shared_ptr<DropOutCost>;
-
-class GatherV2Cost : public OperatorCost {
- public:
-  explicit GatherV2Cost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
-  GatherV2Cost() : OperatorCost(true) {}
-  ~GatherV2Cost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t) const override;
-};
-
-using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
-
-class GatherV2PCost : public OperatorCost {
- public:
-  explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
-  GatherV2PCost() : OperatorCost(true), axis_(0) {}
-  ~GatherV2PCost() override = default;
-
-  double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                     int32_t stage_id) const override {
-    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
-  }
-  double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override;
-  double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                             int32_t stage_id) const override;
-  double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                            int32_t stage_id) const override {
-    return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);
-  }
-  double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                   int32_t stage_id) const override;
-  double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
-                                    int32_t) const override;
-  void set_axis(int32_t axis) { axis_ = axis; }
-  void set_strategy(const Shape &strategy) { strategy_ = strategy; }
-
- protected:
-  int32_t axis_;
-  Shape strategy_;
-};
-
-using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>;
-}  // namespace parallel
-}  // namespace mindspore
-#endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
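
One example of how this header's constants feed the .cc logic: `DropOutCost` multiplies the activation bytes by `DROPOUT_COST_RATE = 1.125`, i.e. it charges an extra 12.5% for the mask that DropoutGenMask produces (my reading of the comment next to the define). A worked check under an assumed float32 slice of 4096 elements:

```cpp
#include <iostream>

int main() {
  const double elements = 4096;
  const double bytes = elements * 4;      // float32 activation slice
  const double kDropoutCostRate = 1.125;  // activation plus 12.5% for the generated mask
  std::cout << bytes * kDropoutCostRate << std::endl;  // 18432 bytes instead of 16384
  return 0;
}
```
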
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
deleted file mode 100644
index 9fb79ceee4..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
+++ /dev/null
@@ -1,750 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/auto_parallel/rec_core/rec_cost.h"
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "ir/anf.h"
-
-namespace mindspore {
-namespace parallel {
-
-// Compute redistributed cost
-double CostRedis(const Graph::NodeType &node,
-                 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                 const std::vector<std::vector<float>> &mode, const Graph &graph) {
-  // Store value of cost redist
-  double cost_redis = 0;
-
-  // Number of current strategies.
-  size_t num_strategy = node_name_to_strategy.size();
-
-  // Number of node-in and node-out
-  size_t num_node_in = node.node_in.size();
-  size_t num_node_out = node.node_out.size();
-
-  // Set tensor edge value with original tensor shape and cutting times.
-  double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n *
-                        node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c *
-                        node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h *
-                        node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w;
-
-  double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n *
-                         node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c *
-                         node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h *
-                         node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w;
-
-  // For each strategy candidate.
-  for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) {
-    // Find its forward nodes
-    for (size_t i_node = 0; i_node < num_node_in; i_node++) {
-      if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first) {
-        bool is_search_forward = true;
-        cost_redis +=
-          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward);
-      }
-    }
-
-    // Find its backward nodes
-    for (size_t i_node = 0; i_node < num_node_out; i_node++) {
-      if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first) {
-        bool is_search_forward = false;
-        cost_redis +=
-          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward);
-      }
-    }
-  }
-
-  return cost_redis;
-}
-
-double CostRedisWithAdjacentNode(const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                 const std::vector<std::vector<float>> &mode, size_t i_strategy, size_t i_node,
-                                 double tensor_size, bool search_forward) {
-  double new_redis_cost = 0;
-  int counter = 0;
-
-  if (search_forward) {
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_n) !=
-        static_cast<int>(1 / mode[i_node][0])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_c) !=
-        static_cast<int>(1 / mode[i_node][1])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_h) !=
-        static_cast<int>(1 / mode[i_node][2])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_w) !=
-        static_cast<int>(1 / mode[i_node][3])) {
-      counter += 1;
-    }
-  } else {
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_n) !=
-        static_cast<int>(1 / mode[2][0])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_c) !=
-        static_cast<int>(1 / mode[2][1])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_h) !=
-        static_cast<int>(1 / mode[2][2])) {
-      counter += 1;
-    }
-    if (static_cast<int>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_w) !=
-        static_cast<int>(1 / mode[2][3])) {
-      counter += 1;
-    }
-  }
-
-  if (counter >= 2) {
-    new_redis_cost = tensor_size / 4.0;
-  } else if (counter == 0 || counter == 1) {
-    new_redis_cost = 0;
-  } else {
-    MS_LOG(EXCEPTION) << "Failure: CostRedis failed.";
-  }
-
-  return new_redis_cost;
-}
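
The rule in `CostRedisWithAdjacentNode` is blunt: compare the four dimension cuts (N, C, H, W) of the candidate strategy against the neighbour's; zero or one mismatch is treated as layout-compatible, two or more charges a quarter of the tensor. A compact sketch of that decision, with cut factors passed as plain ints rather than the `1 / str` floats the real code uses:

```cpp
#include <array>
#include <iostream>

// Redistribution charge between adjacent nodes: free if at most one of the
// four dimension cuts disagrees, otherwise a quarter of the tensor bytes.
double RedisCharge(const std::array<int, 4> &cuts_a, const std::array<int, 4> &cuts_b, double tensor_size) {
  int mismatches = 0;
  for (size_t i = 0; i < cuts_a.size(); ++i) {
    if (cuts_a[i] != cuts_b[i]) {
      ++mismatches;
    }
  }
  return mismatches >= 2 ? tensor_size / 4.0 : 0.0;
}

int main() {
  // One mismatched dimension: no redistribution cost.
  std::cout << RedisCharge({2, 1, 1, 1}, {1, 1, 1, 1}, 1024.0) << "\n";  // 0
  // Two mismatched dimensions: 1024 / 4 = 256.
  std::cout << RedisCharge({2, 2, 1, 1}, {1, 1, 1, 1}, 1024.0) << "\n";  // 256
  return 0;
}
```
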
-
-// Get optimal strategy for MatMul
-StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
-                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                      const Graph &graph) {
-  int edge_i =
-    static_cast<int>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
-  int edge_j =
-    static_cast<int>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
-  int edge_k =
-    static_cast<int>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);
-
-  std::vector<double> cost_op;
-  std::vector<std::vector<float>> mode;
-
-  if (edge_i < 2 || edge_i % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy,
-                                                                mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}},
-                                                                graph));
-  }
-
-  if (edge_j < 2 || edge_j % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy,
-                                                                mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}},
-                                                                graph));
-  }
-
-  if (edge_k < 2 || edge_k % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy,
-                                                                mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}},
-                                                                graph));
-  }
-
-  return ChoseStr(cost_op, node.apply.str);
-}
-
-// Get weight for MatMul
-double CostMatMul::GetMinCostIn(const OperatorRec &op) {
-  int edge_i = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
-  int edge_j = static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
-  int edge_k = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
-
-  std::vector<double> cost_in;
-  cost_in.push_back(StrConcatDimI(edge_j, edge_k));
-  cost_in.push_back(StrConcatDimJ(edge_i, edge_k));
-  cost_in.push_back(StrReduceDimK(edge_i, edge_j));
-
-  return *min_element(cost_in.begin(), cost_in.end());
-}
-
-// Chose strategy for MatMul
-StrategyRec CostMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_h /= 2.0;
-      str.outputTensor.str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_i_;
-      break;
-
-    case 1:
-      str.inputTensor[1].str_w /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_j_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_w /= 2.0;
-      str.inputTensor[1].str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_k_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure:CostMatMul failed.";
-  }
-
-  return str;
-}
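
`ChoseStr` always applies the cheapest legal cut and records it by halving the strategy fractions of the affected tensors (so `str_h == 0.5` means that dimension is split in two). A toy version of the pick-cheapest-then-halve step, with the three MatMul candidates I, J, K (variable names assumed for illustration):

```cpp
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  const double kMax = std::numeric_limits<double>::max();
  // Candidate cut costs for dimensions I, J, K; kMax marks an illegal cut
  // (dimension too small or odd), mirroring the DOUBLE_MAX sentinel above.
  std::vector<double> cost_op = {400.0, kMax, 250.0};

  size_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > kMax - 0.1) {
    std::cout << "no legal cut left\n";
    return 0;
  }

  // Halve the strategy fraction of the chosen dimension, as ChoseStr does.
  double str_k = 1.0;
  if (min_position == 2) str_k /= 2.0;
  std::cout << "cut dimension " << min_position << ", str_k = " << str_k << "\n";
  return 0;
}
```
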
-
-// Get optimal strategy for Conv
-StrategyRec CostConvolution::GetOptimalStr(
-  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-  const Graph &graph, bool channel_partition) {
-  const OperatorRec &op = node.apply;
-
-  int input_tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
-  int input_tensor_w = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
-  int input_tensor_n = static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
-  int input_tensor_c = static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
-
-  int tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c;
-
-  int tensor_filter_h = static_cast<int>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
-  int tensor_filter_w = static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
-  int tensor_filter_n = static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
-  int tensor_filter_c = static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
-
-  int tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;
-
-  int output_tensor_h = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
-  int output_tensor_w = static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
-  int output_tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
-  int output_tensor_c = static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
-
-  int tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c;
-
-  std::vector<double> cost_op;
-  cost_op.reserve(7);
-  std::vector<std::vector<float>> mode;
-
-  if (input_tensor_n < 2 || input_tensor_n % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
-                                                         mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
-  }
-
-  cost_op.push_back(DOUBLE_MAX);
-  cost_op.push_back(DOUBLE_MAX);
-
-  if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy,
-                                                     mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}}, graph));
-  }
-
-  cost_op.push_back(DOUBLE_MAX);
-  cost_op.push_back(DOUBLE_MAX);
-
-  if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy,
-                                                      mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}}, graph));
-  }
-
-  return ChoseStr(cost_op, node.apply.str);
-}
-
-// Get weight for Conv
-double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
-  const OperatorRec &op = node.apply;
-
-  int tensor_in = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
-                  static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
-                  static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
-                  static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
-  int tensor_filter = static_cast<int>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) *
-                      static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
-                      static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
-                      static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
-  int tensor_out = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
-                   static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
-                   static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
-                   static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
-
-  std::vector<double> cost_in;
-  cost_in.push_back(StrDimB(tensor_filter));
-  cost_in.push_back(StrDimI(tensor_in, tensor_filter));
-  cost_in.push_back(StrDimJ(tensor_in, tensor_filter));
-  cost_in.push_back(StrDimK(tensor_in));
-  cost_in.push_back(StrDimDI(tensor_in, tensor_out));
-  cost_in.push_back(StrDimDJ(tensor_in, tensor_out));
-  cost_in.push_back(StrDimQ(tensor_out));
-
-  return *min_element(cost_in.begin(), cost_in.end());
-}
-
-// Chose strategy for Conv
-StrategyRec CostConvolution::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_n /= 2.0;
-      str.outputTensor.str_n /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_b_;
-      break;
-
-    case 1:
-      str.inputTensor[0].str_h /= 2.0;
-      str.outputTensor.str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_i_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_w /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_j_;
-      break;
-
-    case 3:
-      str.inputTensor[1].str_n /= 2.0;
-      str.outputTensor.str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_k_;
-      break;
-
-    case 4:
-      str.inputTensor[1].str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_di_;
-      break;
-
-    case 5:
-      str.inputTensor[1].str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_dj_;
-      break;
-
-    case 6:
-      str.inputTensor[0].str_c /= 2.0;
-      str.inputTensor[1].str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_q_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure: CostConvolution failed.";
-  }
-  return str;
-}
-
-// Get optimal strategy for Pooling
-StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node,
-                                       const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                       const Graph &graph) {
-  int tensor_n = static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
-  int tensor_c = static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
-
-  std::vector<double> cost_op;
-  std::vector<std::vector<float>> mode;
-
-  if (tensor_n < 2 || tensor_n % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
-                                           mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
-  }
-
-  if (tensor_c < 2 || tensor_c % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
-                                           mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
-  }
-
-  cost_op.push_back(DOUBLE_MAX);
-  cost_op.push_back(DOUBLE_MAX);
-
-  return ChoseStr(cost_op, node.apply.str);
-}
-
-// Chose strategy for Pooling
-StrategyRec CostPooling::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_n /= 2.0;
-      str.outputTensor.str_n /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 1:
-      str.inputTensor[0].str_c /= 2.0;
-      str.outputTensor.str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_h /= 2.0;
-      str.outputTensor.str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 3:
-      str.inputTensor[0].str_w /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure: CostPooling failed.";
-  }
-  return str;
-}
-
-// Chose strategy for Add
-StrategyRec CostTensorAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_n /= 2.0;
-      str.inputTensor[1].str_n /= 2.0;
-      str.outputTensor.str_n /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 1:
-      str.inputTensor[0].str_c /= 2.0;
-      str.inputTensor[1].str_c /= 2.0;
-      str.outputTensor.str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_h /= 2.0;
-      str.inputTensor[1].str_h /= 2.0;
-      str.outputTensor.str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 3:
-      str.inputTensor[0].str_w /= 2.0;
-      str.inputTensor[1].str_w /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure: CostAdd failed.";
-  }
-  return str;
-}
-
-// Get optimal strategy for Reshape
-StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); }
-
-StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; }
-
-// Chose strategy for BiasAdd
-StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_n /= 2.0;
-      str.outputTensor.str_n /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 1:
-      str.inputTensor[0].str_c /= 2.0;
-      str.outputTensor.str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_h /= 2.0;
-      str.outputTensor.str_h /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 3:
-      str.inputTensor[0].str_w /= 2.0;
-      str.inputTensor[1].str_w /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
-  }
-  return str;
-}
-
-// Get optimal strategy for Common OPs
-StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node,
-                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                      const Graph &graph) {
-  const OperatorRec &op = node.apply;
-  int tensor_n = static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
-  int tensor_c = static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
-  int tensor_h = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
-  int tensor_w = static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
-
-  std::vector<double> cost_op;
-  std::vector<std::vector<float>> mode;
-
-  if (tensor_n < 2 || tensor_n % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
-                                           mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
-  }
-
-  if (tensor_c < 2 || tensor_c % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
-    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
-                                           mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
-  }
-
-  if (tensor_h < 2 || tensor_h % 2 != 0) {
-    cost_op.push_back(DOUBLE_MAX);
-  } else {
1, 0.5, 1}}, graph)); - } - - if (tensor_w < 2 || tensor_w % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, - mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, graph)); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Chose strategy for Common op -StrategyRec CostCommon::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: Common failed."; - } - return str; -} - -// Get optimal strategy for BatchParallel OPs -StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) { - const OperatorRec &op = node.apply; - int tensor_n = static_cast(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n); - int tensor_c = static_cast(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c); - int tensor_h = static_cast(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h); - int tensor_w = static_cast(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w); - - std::vector cost_op; - - if (tensor_n < 2 || tensor_n % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_c < 2 || tensor_c % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_h < 2 || tensor_h % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - if (tensor_w < 2 || tensor_w % 2 != 0) { - cost_op.push_back(DOUBLE_MAX); - } else { - cost_op.push_back(cost_in_); - } - - return ChoseStr(cost_op, node.apply.str); -} - -// Chose strategy for BatchParallel op -StrategyRec CostBatchParallel::ChoseStr(const std::vector &cost_op, StrategyRec str) { - uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin(); - if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) { - return str; - } - - switch (min_position) { - case 0: - str.inputTensor[0].str_n /= 2.0; - str.outputTensor.str_n /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 1: - str.inputTensor[0].str_c /= 2.0; - str.outputTensor.str_c /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 2: - str.inputTensor[0].str_h /= 2.0; - str.outputTensor.str_h /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - case 3: - str.inputTensor[0].str_w /= 2.0; - str.outputTensor.str_w /= 2.0; - str.cut_counter += 1; - str.cost = str.cost + cost_in_; - break; - - default: - MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed."; - } - return str; -} - -// Chose strategy for CostSoftmaxCrossEntropyWithLogits -StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const 
std::vector<double> &cost_op, StrategyRec str) {
-  uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
-  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
-    return str;
-  }
-
-  switch (min_position) {
-    case 0:
-      str.inputTensor[0].str_n /= 2.0;
-      str.inputTensor[1].str_n /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 1:
-      str.inputTensor[0].str_c /= 2.0;
-      str.inputTensor[1].str_c /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 2:
-      str.inputTensor[0].str_h /= 2.0;
-      str.inputTensor[1].str_h /= 2.0;
-      str.outputTensor.str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    case 3:
-      str.inputTensor[0].str_w /= 2.0;
-      str.inputTensor[1].str_w /= 2.0;
-      str.cut_counter += 1;
-      str.cost = str.cost + cost_in_;
-      break;
-
-    default:
-      MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
-  }
-  return str;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
deleted file mode 100644
index fb4fc27164..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
+++ /dev/null
@@ -1,233 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef PARALLEL_AUTO_PARALLEL_REC_COST_H_
-#define PARALLEL_AUTO_PARALLEL_REC_COST_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "parallel/auto_parallel/rec_core/rec_graph.h"
-#include "parallel/auto_parallel/rec_core/rec_strategy.h"
-
-namespace mindspore {
-namespace parallel {
-#define DOUBLE_MAX (std::numeric_limits<double>::max)()
-
-double CostRedis(const Graph::NodeType &node,
-                 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                 const std::vector<std::vector<float>> &mode, const Graph &graph);
-
-double CostRedisWithAdjacentNode(const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                                 const std::vector<std::vector<float>> &mode, size_t i_strategy, size_t i_node,
-                                 double tensor_size, bool is_search_forward);
-
-// class CostMatMul is used to compute the cost of MatMul operator.
-class CostMatMul {
- public:
-  StrategyRec GetOptimalStr(const Graph::NodeType &node,
-                            const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                            const Graph &graph);
-
-  double GetMinCostIn(const OperatorRec &op);
-
- private:
-  double StrConcatDimI(int32_t a, int32_t b) {
-    cost_in_i_ = (static_cast<double>(a) * static_cast<double>(b)) / 2.0;
-
-    return cost_in_i_;
-  }
-
-  double StrConcatDimJ(int32_t a, int32_t b) {
-    cost_in_j_ = (static_cast<double>(a) * static_cast<double>(b)) / 2.0;
-
-    return cost_in_j_;
-  }
-
-  double StrReduceDimK(int32_t a, int32_t b) {
-    cost_in_k_ = (static_cast<double>(a) * static_cast<double>(b)) / 2.0;
-
-    return cost_in_k_;
-  }
-
-  StrategyRec ChoseStr(const std::vector<double> &cost_op, StrategyRec str);
-
-  double cost_in_i_ = 0;
-
-  double cost_in_j_ = 0;
-
-  double cost_in_k_ = 0;
-};  // class CostMatMul is used to compute the cost of MatMul operator.
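
The ChoseStr methods deleted above all share one greedy rule: estimate the cost of halving each partitionable dimension, pick the cheapest candidate, halve the matching tensor strides, and add that cost to the running strategy cost. A minimal standalone sketch of that selection step follows — hypothetical names, not the MindSpore API, with StrategyRec's per-dimension fields collapsed into one array for brevity:

#include <algorithm>
#include <limits>
#include <vector>

struct MiniStrategy {
  double str[4] = {1.0, 1.0, 1.0, 1.0};  // cut factors for n, c, h, w
  double cost = 0.0;
  int cut_counter = 0;
};

// Greedy step: halve whichever dimension is cheapest to cut, or stop when
// every candidate was rejected (marked with the DOUBLE_MAX sentinel above).
MiniStrategy ChooseCheapestCut(const std::vector<double> &cost_op, MiniStrategy s) {
  const double kMax = (std::numeric_limits<double>::max)();
  size_t min_pos = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_pos] > kMax - 0.1) {
    return s;  // no dimension is large and even enough to split further
  }
  s.str[min_pos] /= 2.0;       // spread that dimension over twice as many devices
  s.cut_counter += 1;
  s.cost += cost_op[min_pos];  // bill the estimated redistribution cost
  return s;
}
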
- -// class CostConvolution is used to compute the cost of Conv operator. -class CostConvolution { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph, bool channel_partition); - - double GetMinCostIn(const Graph::NodeType &node); - - private: - double StrDimB(int32_t TensorFilter) { - cost_in_b_ = static_cast((TensorFilter) / 2.0); - - return cost_in_b_; - } - - double StrDimI(int32_t TensorIn, int32_t TensorFilter) { - cost_in_i_ = static_cast((TensorIn + TensorFilter) / 2.0); - - return cost_in_i_; - } - - double StrDimJ(int32_t TensorIn, int32_t TensorFilter) { - cost_in_j_ = static_cast((TensorIn + TensorFilter) / 2.0); - - return cost_in_j_; - } - - double StrDimK(int32_t TensorIn) { - cost_in_k_ = static_cast((TensorIn) / 2.0); - - return cost_in_k_; - } - - double StrDimDI(int32_t TensorIn, int32_t TensorOut) { - cost_in_di_ = static_cast((TensorIn + TensorOut) / 2.0); - - return cost_in_di_; - } - - double StrDimDJ(int32_t TensorIn, int32_t TensorOut) { - cost_in_dj_ = static_cast((TensorIn + TensorOut) / 2.0); - - return cost_in_dj_; - } - - double StrDimQ(int32_t TensorOut) { - cost_in_q_ = static_cast((TensorOut) / 2.0); - - return cost_in_q_; - } - - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_b_ = 0; - - double cost_in_i_ = 0; - - double cost_in_j_ = 0; - - double cost_in_k_ = 0; - - double cost_in_di_ = 0; - - double cost_in_dj_ = 0; - - double cost_in_q_ = 0; -}; // class CostConvolution is used to compute the cost of Conv operator. - -// class CostPooling is used to compute the cost of Pooling operator. -class CostPooling { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); - - double GetMinCostIn() const { return cost_in_; } - - private: - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class CostPooling is used to compute the cost of Pooling operator. - -// class CostReshape is used to compute the cost of Reshape operator. -class CostReshape { - public: - StrategyRec GetOptimalStr(const Graph::NodeType &node) const; - - double GetMinCostIn() const { return cost_in_; } - - private: - StrategyRec ChoseStr(StrategyRec str) const; - - double cost_in_ = 0; -}; // class CostReshape is used to compute the cost of Reshape operator. - -// class CostCommon is used to compute the cost of an element-wise operator -class CostCommon { - public: - virtual StrategyRec GetOptimalStr(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const Graph &graph); - - virtual double GetMinCostIn() const { return cost_in_; } - - protected: - virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class CostCommon is used to compute the cost of an element-wise operator - -// class CostBiasAdd is used to compute the cost of the addition between a tensor and a bias -class CostBiasAdd : public CostCommon { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; -// class CostAdd is used to compute the cost of Add operator. 
-class CostTensorAdd : public CostCommon { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; - -// all the following operation are element-wise and have the same cost -class CostReLU : public CostCommon {}; -class CostLog : public CostCommon {}; -class CostExp : public CostCommon {}; -class CostAdd : public CostCommon {}; -class CostSub : public CostCommon {}; -class CostMul : public CostCommon {}; -class CostDiv : public CostCommon {}; -class CostSqueeze : public CostCommon {}; -class CostCast : public CostCommon {}; - -// class BatchParallel is used to compute the cost of BatchParallel operator. -class CostBatchParallel { - public: - virtual StrategyRec GetOptimalStr(const Graph::NodeType &node); - - virtual double GetMaxCostIn() const { return DOUBLE_MAX; } - - protected: - virtual StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); - - double cost_in_ = 0; -}; // class BatchParallel is used to compute the cost of BatchParallel operator. - -class CostBatchNorm : public CostBatchParallel {}; -class CostOneHot : public CostBatchParallel {}; -class CostPRelu : public CostBatchParallel {}; -class CostSoftmax : public CostBatchParallel {}; - -class CostSoftmaxCrossEntropyWithLogits : public CostBatchParallel { - StrategyRec ChoseStr(const std::vector &cost_op, StrategyRec str); -}; -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_COST_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc deleted file mode 100644 index 828523fed1..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ /dev/null @@ -1,837 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/auto_parallel/rec_core/rec_generate_strategy.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/rec_core/rec_parse_graph.h" -#include "parallel/auto_parallel/rec_core/rec_partition.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr>> &eli_list, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(eli_list); - MS_EXCEPTION_IF_NULL(index_list); - GeneratePartitionedOperatorStrategy(graph, ops, index_list); - std::shared_ptr> no_stra_op_list(new std::vector); - for (size_t i = 0; i < eli_list->size(); i++) { - no_stra_op_list->push_back(eli_list->at(i)[0]); - } - GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); - GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); - GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list); -} - -std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies; - auto attrs = ops[iter_ops]->attrs(); - bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); - bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); - - // HCCL does not support multi-dimension partition, and the hardware does not support excessive - // number of EVENT, so we temporarily disable matmul's multi-dimension partition function. - const auto max_cut = 1.0 / g_device_manager->DeviceNum(); - if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h != max_cut && - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w != max_cut) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0; - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0; - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0; - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - - auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; - if (transpose_a) { - shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1]; - } - auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1]; - if (transpose_b) { - shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0]; - } - - bool already_cut = false; - if (shape_1 >= shape_4) { - if (shape_1 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; - already_cut = true; - } - if (!already_cut && shape_4 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; - already_cut = true; - } - } else { - if (shape_4 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut; - already_cut = true; - } - if (!already_cut && shape_1 % g_device_manager->DeviceNum() == 0) { - graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut; - 
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut; - already_cut = true; - } - } - - if (!already_cut) { - MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid."; - } - } - - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - std::vector s; - if (transpose_a && (iter_op_inputs == 0)) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - } else if (transpose_b && (iter_op_inputs == 1)) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - } else { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } - strategies.push_back(s); - } - return strategies; -} - -std::vector> PrepareBiasAdd(const std::shared_ptr> &s) { - std::vector> strategies; - strategies.push_back(*s); - std::vector s_biasadd; - s_biasadd.push_back(s->at(1)); - strategies.push_back(s_biasadd); - return strategies; -} - -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - std::vector> strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); - - int32_t axis = -1; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter != ops[iter_ops]->attrs().end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis = iter->second->cast()->value(); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int."; - } - } - if (axis == -1) { - strategies[0][0] = strategies[0][1]; - strategies[0][1] = 1; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - } - - std::vector s_empty = {}; - strategies.push_back(s_empty); - strategies.push_back(s_empty); - return strategies; -} - -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector> strategies; - - auto axis_input = GetValue(ops[iter_ops]->input_value().at(2)); - if (axis_input < 0) { - axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size()); - } - int32_t axis = axis_input; - if (axis >= SizeToInt(s.size())) { - MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range."; - } - s[axis] = 1; - strategies.push_back(s); - - auto pos = ops[iter_ops]->name().find("Info"); - auto name = ops[iter_ops]->name().substr(0, pos); - if (name == "GatherV2") { - return strategies; - } - - std::vector s_indices; - for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) { - s_indices.push_back(1); - } - strategies.push_back(s_indices); - - return strategies; -} - -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - int32_t axis = 0; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter != ops[iter_ops]->attrs().end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis = iter->second->cast()->value(); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << 
" : The value of axis is not int."; - } - } - - int32_t axis_index = axis; - if (axis < 0) { - size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - axis_index = static_cast(input_dim) + axis; - } - - s[IntToSize(axis_index)] = 1; - - std::vector> strategies; - strategies.push_back(s); - return strategies; -} - -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - - StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - std::vector> strategies; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { - MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; - } - - size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); - std::vector s; - if (output_size == 4) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 2) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 1) { - s.push_back( - static_cast(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w)); - } else if (output_size == 0) { - s = {}; - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted."; - } - strategies.push_back(s); - } - return strategies; -} - -std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - - StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - std::vector> strategies; - size_t max_device_num = g_device_manager->DeviceNum(); - size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0]; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (iter_op_inputs >= origin_strategy->GetInputDim().size()) { - MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; - } - - std::vector s; - size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); - for (size_t dim = 0; dim < input_size; dim++) { - if (input_size == 1 || input_size == 2 || input_size == 4) { - if (dim == 0) { - s.push_back(std::min(max_device_num, target_tensor_batch)); - } else { - s.push_back(1); - } - } else if (input_size == 0) { - s = {}; - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; - } - } - strategies.push_back(s); - } - - graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0; 
- graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0; - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0; - if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { - graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch); - } - - return strategies; -} - -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops) { - if (ops.empty()) { - MS_LOG(EXCEPTION) << "Failure: Operators is empty."; - } - if (iter_ops >= ops.size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - MS_EXCEPTION_IF_NULL(ops[iter_ops]); - - auto type = ops[iter_ops]->type(); - auto idx = DictOpType.find(type); - if (idx == DictOpType.end()) { - return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); - } - - if (type == MATMUL) { - return PrepareMatMul(graph, ops, iter_graph, iter_ops); - } else if (type == ONEHOT) { - return PrepareOneHot(graph, ops, iter_graph, iter_ops); - } else { - return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops); - } -} - -void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::shared_ptr> &index_list) { - for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) { - std::vector> strategies; - size_t iter_graph = index_list->at(iter_ops); - if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) { - strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops); - } - StrategyPtr sp = std::make_shared(0, strategies); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } -} - -size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, - const size_t iter_ops) { - size_t incoming_op_index = SIZE_MAX; - for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) { - for (size_t j = 0; j < input_tensor_names.size(); j++) { - if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) { - incoming_op_index = j; - break; - } - } - if (incoming_op_index != SIZE_MAX) { - break; - } - } - return incoming_op_index; -} - -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph) { - std::vector s; - for (auto input : ops[iter_ops]->inputs_tensor_info()) { - auto input_stra_dim = input.shape().size(); - if (input_stra_dim == 0) { - continue; - } - if (input_stra_dim == 1) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else if (input_stra_dim == 2) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else if (input_stra_dim == 4) { - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c); - s.push_back(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h); - s.push_back(1 / 
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w); - } else { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown."; - } - break; - } - return s; -} - -std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index) { - std::vector s; - if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2 || - ops[incoming_op_index]->type() == TRANSPOSE) { - return s; - } - auto strategy = ops[incoming_op_index]->selected_strategy(); - if (strategy->GetInputNumber() == 0) { - return s; - } - - for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) { - if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) { - continue; - } - for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) { - s.push_back(strategy->GetInputDim()[i][j]); - } - break; - } - return s; -} - -std::vector GetAxisList(const std::vector> &ops, const int iter_ops) { - std::vector axis_list; - auto axis_param = ops[iter_ops]->attrs().find(AXIS)->second; - std::vector elements; - if (axis_param->isa()) { - elements = axis_param->cast()->value(); - } else if (axis_param->isa()) { - elements = axis_param->cast()->value(); - } else { - MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl; - } - - for (auto &element : elements) { - if (!element->isa()) { - MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl; - } - auto axis = element->cast()->value(); - axis_list.push_back(axis); - } - return axis_list; -} - -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Squeeze; - std::vector stra_dim_list; - for (size_t i = 0; i < s.size(); i++) { - stra_dim_list.push_back(i); - } - - auto axis_list = GetAxisList(ops, incoming_op_index); - for (auto axis : axis_list) { - auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis); - if (it == stra_dim_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; - } - if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[axis] != 1) { - MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl; - } - stra_dim_list.erase(it); - } - - for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) { - s_Squeeze.push_back(s[stra_dim_list[i]]); - } - return s_Squeeze; -} - -bool GetKeepDims(const std::vector> &ops, const size_t iter_ops) { - bool keepdims = false; - auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS); - if (keep_dims_iter == ops[iter_ops]->attrs().end()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims."; - } - MS_EXCEPTION_IF_NULL(keep_dims_iter->second); - if (!keep_dims_iter->second->isa()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool."; - } - keepdims = keep_dims_iter->second->cast()->value(); - return keepdims; -} - -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; - bool keep_dims = GetKeepDims(ops, iter_ops); - if (keep_dims != false) { - return dim_list; - } - auto input_value = ops[iter_ops]->input_value(); - auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - if (input_value.back()->isa()) { - auto attr_axis = GetValue>(input_value.back()); - if (attr_axis.empty()) { - MS_LOG(EXCEPTION) << "Failure: This output is a 0-D tensor." 
<< std::endl; - } - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } else if (input_value.back()->isa()) { - int axis = GetValue(input_value.back()); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl; - } - return dim_list; -} - -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - std::vector s_Reduce; - std::vector axis_list; - for (size_t i = 0; i < s.size(); i++) { - axis_list.push_back(i); - } - - auto dim_list = GetDimList(ops, incoming_op_index); - for (auto axis : dim_list) { - auto it = find(axis_list.begin(), axis_list.end(), axis); - if (it == axis_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl; - } - axis_list.erase(it); - } - - for (size_t i = 0; i < (size_t)axis_list.size(); i++) { - s_Reduce.push_back(s[axis_list[i]]); - } - return s_Reduce; -} - -std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops) { - std::vector dim_list; - auto iter = ops[iter_ops]->attrs().find(AXIS); - if (iter == ops[iter_ops]->attrs().end()) { - MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis."; - } - auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size(); - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto attr_axis = GetValue>(iter->second); - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (iter->second->isa()) { - int axis = GetValue(iter->second); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - return dim_list; -} - -std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s) { - bool keepdims = GetKeepDims(ops, incoming_op_index); - if (keepdims) { - return s; - } - - std::vector s_Arg; - std::vector axis_list; - for (size_t i = 0; i < s.size(); i++) { - axis_list.push_back(i); - } - - auto dim_list = GetDimListFromAttrs(ops, incoming_op_index); - for (auto axis : dim_list) { - auto it = find(axis_list.begin(), axis_list.end(), axis); - if (it == axis_list.end()) { - MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." 
<< std::endl; - } - axis_list.erase(it); - } - - for (size_t i = 0; i < (size_t)axis_list.size(); i++) { - s_Arg.push_back(s[axis_list[i]]); - } - return s_Arg; -} - -std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index) { - std::vector s; - s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index); - if (s.size() != 0) { - if (ops[incoming_op_index]->type() == SQUEEZE) { - s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s); - } - if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX || - ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) { - s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s); - } - if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) { - s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s); - } - } - return s; -} - -std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra) { - std::vector s_empty = {}; - std::vector> stra; - MS_EXCEPTION_IF_NULL(ops[iter_ops]); - - if (basic_stra.size() == 0) { - for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); - iter_op_inputs++) { - stra.push_back(basic_stra); - } - return stra; - } - - auto s_ptr = std::make_shared>(basic_stra); - if (ops[iter_ops]->type() == BIAS_ADD) { - return PrepareBiasAdd(s_ptr); - } - if (ops[iter_ops]->type() == GATHERV2) { - return PrepareGatherV2(ops, iter_ops, basic_stra); - } - if (ops[iter_ops]->type() == L2_NORMALIZE) { - return PrepareL2Normalize(ops, iter_ops, basic_stra); - } - - for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size(); - iter_op_inputs++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) { - stra.push_back(s_empty); - continue; - } - - std::vector tmp_stra = basic_stra; - bool modified = false; - for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) { - tmp_stra[j] = 1; - modified = true; - } - } - if (modified) { - stra.push_back(tmp_stra); - } else { - stra.push_back(basic_stra); - } - } - return stra; -} - -void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - std::vector no_stra_op_list_bis; - - for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { - size_t iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s; - size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops); - if (incoming_op_index != SIZE_MAX) { - auto iter_graph = index_list->at(incoming_op_index); - if (iter_graph != SIZE_MAX) { - s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph); - } else { - s = CopyIncomingOperatorInputStrategy(ops, iter_ops, incoming_op_index); - } - } - - if (s.size() == 0) { - no_stra_op_list_bis.push_back(iter_ops); - } else { - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - } - - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, 
ops[iter_ops]->selected_cost()); - } - - no_stra_op_list->clear(); - for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { - no_stra_op_list->push_back(no_stra_op_list_bis[i]); - } -} - -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s) { - std::vector s_Squeeze; - auto axis_list = GetAxisList(ops, iter_ops); - size_t s_index = 0; - size_t axis_list_index = 0; - for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) { - if (i == (size_t)axis_list[axis_list_index]) { - s_Squeeze.push_back(1); - axis_list_index++; - } else { - s_Squeeze.push_back(s[s_index]); - s_index++; - } - } - - size_t cut = 1; - for (size_t i = 0; i < s_Squeeze.size(); i++) { - cut *= s_Squeeze[i]; - } - if (cut != g_device_manager->DeviceNum()) { - s_Squeeze.clear(); - } - - return s_Squeeze; -} - -std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops) { - std::vector s; - if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN || - ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE || - ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE || - ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) { - return s; - } - - bool found = false; - size_t outgoing_op_index = SIZE_MAX; - size_t iter_op_inputs = SIZE_MAX; - for (size_t i = 0; i < input_tensor_names.size(); i++) { - for (size_t j = 1; j < input_tensor_names[i].size(); j++) { - if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] && - ops[i]->selected_strategy()->GetInputNumber() != 0) { - outgoing_op_index = i; - iter_op_inputs = j - 1; - found = true; - break; - } - } - if (found) { - break; - } - } - - if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) { - for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) { - s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]); - } - } - return s; -} - -void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - std::vector no_stra_op_list_bis; - - for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) { - auto iter_ops = no_stra_op_list->at(iter_list - 1); - std::vector> stra; - std::vector s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops); - - if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) { - s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s); - } - if (s.size() != 0) { - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - } else { - no_stra_op_list_bis.push_back(iter_ops); - } - - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } - - no_stra_op_list->clear(); - for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) { - no_stra_op_list->push_back(no_stra_op_list_bis[i]); - } -} - -void GenerateRemainingOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list) { - if (no_stra_op_list->size() == 0) { - return; - } - - size_t no_stra_op_list_size = no_stra_op_list->size(); - do { - no_stra_op_list_size = 
no_stra_op_list->size(); - GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list); - GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list); - } while (no_stra_op_list_size > no_stra_op_list->size()); - - for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) { - auto iter_ops = no_stra_op_list->at(iter_list); - std::vector> stra; - std::vector s; - - size_t max_dim_num = 0; - for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) { - max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); - } - } - for (size_t i = 0; i < max_dim_num; i++) { - s.push_back(1); - } - - stra = GenerateStrategiesFromStrategy(ops, iter_ops, s); - StrategyPtr sp = std::make_shared(0, stra); - ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h deleted file mode 100644 index e82efe6798..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ -#define PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ - -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -void GenerateStrategy(const std::shared_ptr &graph, const std::vector> &ops, - const std::shared_ptr>> &eli_list, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list); -std::vector> PrepareMatMul(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareBiasAdd(const std::shared_ptr> &s); -std::vector> PrepareOneHot(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareGatherV2(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> PrepareL2Normalize(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector> MakeRecSearchStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> MakeDataParallelStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -std::vector> PrepareStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_graph, const size_t iter_ops); -void GeneratePartitionedOperatorStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const std::shared_ptr> &index_list); -size_t FindIndexOfOperatorIncoming(const std::vector> &input_tensor_names, - const size_t iter_ops); -std::vector CopyIncomingOperatorOutputStrategy(const std::shared_ptr &graph, - const std::vector> &ops, - const size_t iter_ops, const size_t iter_graph); -std::vector PrepareIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t incoming_op_index); -std::vector GetAxisList(const std::vector> &ops, const int iter_ops); -std::vector ModifyStrategyIfSqueezeIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -bool GetKeepDims(const std::vector> &ops, const size_t iter_ops); -std::vector GetDimList(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfReduceIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector GetDimListFromAttrs(const std::vector> &ops, const size_t iter_ops); -std::vector ModifyStrategyIfArgIncoming(const std::vector> &ops, - const size_t incoming_op_index, std::vector s); -std::vector CopyIncomingOperatorInputStrategy(const std::vector> &ops, - const size_t iter_ops, const size_t incoming_op_index); -std::vector> GenerateStrategiesFromStrategy(const std::vector> &ops, - const size_t iter_ops, - std::vector basic_stra); -void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr &graph, - const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &index_list, - const std::shared_ptr> &no_stra_op_list); -std::vector ModifyStrategyIfSqueezeOutgoing(const std::vector> &ops, - const size_t iter_ops, std::vector s); -std::vector CopyOutgoingOperatorInputStrategy(const std::vector> &ops, - const std::vector> &input_tensor_names, - const size_t iter_ops); -void GenerateEliminatedOperatorStrategyBackward(const std::vector> &ops, - const std::vector> &input_tensor_names, - const std::shared_ptr> &no_stra_op_list); -void GenerateRemainingOperatorStrategy(const 
std::shared_ptr<Graph> &graph,
-                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                       const std::vector<std::vector<std::string>> &input_tensor_names,
-                                       const std::shared_ptr<std::vector<size_t>> &index_list,
-                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
-}  // namespace parallel
-}  // namespace mindspore
-#endif  // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
deleted file mode 100644
index 9007218d15..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_graph.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
-#define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
-
-#include
-#include
-#include
-
-#include "parallel/auto_parallel/rec_core/rec_strategy.h"
-#include "parallel/auto_parallel/rec_core/rec_tensor.h"
-
-namespace mindspore {
-namespace parallel {
-enum OperatorType {
-  kRecUnkownType,
-  kRecMatMul,
-  kRecConvolution,
-  kRecPooling,
-  kRecElmWiseOp,
-  kRecReLU,
-  kRecBatchNorm,
-  kRecReshape,
-  kRecBiasAdd,
-  kRecSoftmax,
-  kRecSparseSoftmaxCrossEntropyWithLogits,
-  kRecSoftmaxCrossEntropyWithLogits,
-  kRecOneHot,
-  kRecLog,
-  kRecExp,
-  kRecAdd,
-  kRecSub,
-  kRecMul,
-  kRecDiv,
-  kRecSqueeze,
-  kRecCast,
-  kRecReduce,
-  kRecPReLU,
-  kRecGatherV2,
-  kRecArgWithValue
-};
-
-enum InfoType { kApplication, kConstant };
-
-struct OperatorRec {
-  OperatorType op_type;
-  TensorParam arguments[MAX_INPUT_NUM];
-  StrategyRec str;
-};
-
-// Define simplified dataflow Graph for partitioning
-class Graph {
- public:
-  struct NodeType {
-    std::string name;
-    // Nodes that point to this node
-    std::vector<size_t> node_in;
-    // Nodes that point from this node
-    std::vector<size_t> node_out;
-    std::vector<size_t> node_in_aux;
-    // Node Type Info: Application or Constant. Defined in enum <InfoType>.
-    InfoType info;
-    // Operator info. Defined in struct <OperatorRec>.
-    OperatorRec apply;
-    // Tensor info. Defined in tensor.h struct <TensorParam>.
-    TensorParam tensor_parm;
-  };
-
-  std::vector<NodeType> nodes;  // Nodes of the graph. Public.
-};  // Define simplified dataflow Graph for partitioning
-}  // namespace parallel
-}  // namespace mindspore
-#endif  // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
deleted file mode 100644
index 0e6a3411e3..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ /dev/null
@@ -1,264 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/auto_parallel/rec_core/rec_parse_graph.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_tensor.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -const TensorParam MakeTensor(int n, int c, int h, int w) { - TensorParam new_tensor; - new_tensor.tensor_type = kFloat32; - new_tensor.tensor_shape.shape_n = n; - new_tensor.tensor_shape.shape_c = c; - new_tensor.tensor_shape.shape_h = h; - new_tensor.tensor_shape.shape_w = w; - const TensorParam &tensor = new_tensor; - return tensor; -} - -Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops) { - Graph::NodeType NewOp; - NewOp.name = ops[iter_ops]->name(); - NewOp.info = InfoType::kApplication; - - auto op_type = ops[iter_ops]->type(); - auto idx = DictOpType.find(op_type); - if (idx == DictOpType.end()) { - NewOp.apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown operator type."; - } else { - NewOp.apply.op_type = DictOpType.at(op_type); - } - - if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { - NewOp.tensor_parm = MakeTensor( - ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], - ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { - NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); - } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape is unknown."; - } - - NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); - return NewOp; -} - -OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, - Graph::NodeType NewTensor) { - for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); - iter_input_tensors++) { - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - NewTensor.apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - 
NewTensor.apply.arguments[iter_input_tensors] =
-        MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
-    } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) {
-      NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1);
-    } else {
-      MS_LOG(ERROR) << "Tensor's shape is unknown.";
-    }
-  }
-  return NewTensor.apply;
-}
-
-TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                             const size_t iter_input_tensors, Graph::NodeType NewTensor) {
-  if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
-    auto attrs = ops[iter_ops]->attrs();
-    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
-    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
-    if (transpose_a && (iter_input_tensors == 0)) {
-      NewTensor.apply.arguments[iter_input_tensors] =
-        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
-                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
-    } else if (transpose_b && (iter_input_tensors == 1)) {
-      NewTensor.apply.arguments[iter_input_tensors] =
-        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1],
-                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]);
-    } else {
-      NewTensor.apply.arguments[iter_input_tensors] =
-        MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
-                   ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]);
-    }
-  } else {
-    NewTensor.apply.arguments[iter_input_tensors] =
-      MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0],
-                 ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]);
-  }
-  return NewTensor.apply.arguments[iter_input_tensors];
-}
-
-std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
-                                  const std::vector<std::vector<std::string>> &input_tensor_names) {
-  std::shared_ptr<Graph> graph(new Graph);
-  if (ops.size() > SIZE_MAX / 2) {
-    MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2;
-  }
-
-  for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) {
-    Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops);
-    graph->nodes.push_back(NewOp);
-  }
-  MakeEdge(input_tensor_names, graph);
-
-  return graph;
-}
-
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
-  for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
-    for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
-      size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
-      if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) {
-        graph->nodes[iter_i].node_in.push_back(head_node_index);
-        graph->nodes[head_node_index].node_out.push_back(iter_i);
-      }
-    }
-  }
-}
-
-size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
-                                  const std::string &input_name) {
-  for (size_t index = 0; index < input_tensor_name.size(); index++) {
-    if (input_tensor_name[index][0] == input_name) {
-      return index;
-    }
-  }
-  MS_LOG(INFO) << "Get index failed, using SIZE_MAX instead";
-  return SIZE_MAX;
-}
-
-void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
-                   const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
-  std::vector<size_t> eli;
-  eli.push_back(node_index);
-  for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
-    eli.push_back(graph->nodes[node_index].node_out[i]);
-  }
-  eli_list->push_back(eli);
-
-  for (size_t i = 0; i <
graph->nodes[node_index].node_in.size(); i++) { - auto *incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in[i]].node_out; - auto it = find(incoming_outputs->begin(), incoming_outputs->end(), node_index); - if (it != incoming_outputs->end()) { - it = incoming_outputs->erase(it); - incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), graph->nodes[node_index].node_out.end()); - } - } - - for (size_t i = 0; i < graph->nodes[node_index].node_in_aux.size(); i++) { - auto *aux_incoming_outputs = &graph->nodes[graph->nodes[node_index].node_in_aux[i]].node_out; - auto it = find(aux_incoming_outputs->begin(), aux_incoming_outputs->end(), node_index); - if (it != aux_incoming_outputs->end()) { - it = aux_incoming_outputs->erase(it); - aux_incoming_outputs->insert(it, graph->nodes[node_index].node_out.begin(), - graph->nodes[node_index].node_out.end()); - } - } - - for (size_t i = 0; i < graph->nodes[node_index].node_out.size(); i++) { - auto *outgoing_inputs = &graph->nodes[graph->nodes[node_index].node_out[i]].node_in; - auto it = find(outgoing_inputs->begin(), outgoing_inputs->end(), node_index); - if (it != outgoing_inputs->end()) { - if (graph->nodes[node_index].node_in.size() > 0) { - outgoing_inputs->at(std::distance(outgoing_inputs->begin(), it)) = graph->nodes[node_index].node_in[0]; - for (size_t j = 1; j < graph->nodes[node_index].node_in.size(); j++) { - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back(graph->nodes[node_index].node_in[j]); - } - for (size_t j = 1; j < graph->nodes[node_index].node_in_aux.size(); j++) { - graph->nodes[graph->nodes[node_index].node_out[i]].node_in_aux.push_back( - graph->nodes[node_index].node_in_aux[j]); - } - } else { - outgoing_inputs->erase(it); - } - } - } -} - -std::shared_ptr EliminateGraph(const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list, - const std::shared_ptr> &index_list) { - MS_EXCEPTION_IF_NULL(graph); - for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { - auto type = graph->nodes[node_index].apply.op_type; - if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) { - Eliminate_Aux(node_index, graph, eli_list); - } - } - index_list->reserve(graph->nodes.size()); - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - index_list->push_back(i); - } - for (size_t i = 0; i < (size_t)eli_list->size(); i++) { - if (eli_list->at(i)[0] >= index_list->size()) { - MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; - } - index_list->at(eli_list->at(i)[0]) = SIZE_MAX; - for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { - index_list->at(j)--; - } - } - std::shared_ptr new_graph(new Graph); - for (size_t i = 0; i < graph->nodes.size(); i++) { - if (index_list->at(i) > SIZE_MAX / 2) { - continue; - } - new_graph->nodes.push_back(graph->nodes[i]); - auto *node_in = &new_graph->nodes[index_list->at(i)].node_in; - for (size_t j = node_in->size(); j > 0; j--) { - bool IsEliminated = (index_list->at(node_in->at(j - 1)) == SIZE_MAX); - if (IsEliminated) { - node_in->erase(node_in->begin() + j - 1); - } else { - node_in->at(j - 1) = index_list->at(node_in->at(j - 1)); - } - } - auto *node_out = &new_graph->nodes[index_list->at(i)].node_out; - for (size_t j = node_out->size(); j > 0; j--) { - bool IsEliminated = (index_list->at(node_out->at(j - 1)) == SIZE_MAX); - if (IsEliminated) { - node_out->erase(node_out->begin() + j - 1); - } else { - node_out->at(j - 1) = index_list->at(node_out->at(j - 1)); - } 
- } - } - return new_graph; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h deleted file mode 100644 index 6112579d51..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ /dev/null @@ -1,145 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ -#define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ - -#include -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/ops_info/operator_info.h" - -namespace mindspore { -namespace parallel { -static const std::set ElementWiseOpType = { - OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, - OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, - OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, - OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue}; - -const std::map DictOpType{ - {MATMUL, OperatorType::kRecMatMul}, - {CONV2D, OperatorType::kRecConvolution}, - {MAXPOOL, OperatorType::kRecPooling}, - {MAXPOOLV2, OperatorType::kRecPooling}, - {SIMPLE_MEAN, OperatorType::kRecPooling}, - {RESHAPE, OperatorType::kRecReshape}, - {BIAS_ADD, OperatorType::kRecBiasAdd}, - {BATCH_NORM, OperatorType::kRecBatchNorm}, - {FUSE_BATCH_NORM, OperatorType::kRecBatchNorm}, - {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, - {ONEHOT, OperatorType::kRecOneHot}, - {SQUEEZE, OperatorType::kRecSqueeze}, - {CAST, OperatorType::kRecCast}, - {REDUCE_SUM, OperatorType::kRecReduce}, - {REDUCE_MAX, OperatorType::kRecReduce}, - {REDUCE_MIN, OperatorType::kRecReduce}, - {REDUCE_MEAN, OperatorType::kRecReduce}, - {GATHERV2, OperatorType::kRecGatherV2}, - {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, - {ARGMINWITHVALUE, OperatorType::kRecArgWithValue}, - - {RELU, OperatorType::kRecReLU}, - {"ReLU6", OperatorType::kRecReLU}, - {"ReLUV2", OperatorType::kRecReLU}, - {SIGMOID, OperatorType::kRecReLU}, - {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, - {"HSigmoid", OperatorType::kRecReLU}, - {GELU, OperatorType::kRecReLU}, - {TANH, OperatorType::kRecReLU}, - - {PRELU, OperatorType::kRecPReLU}, - - {TRANSPOSE, OperatorType::kRecElmWiseOp}, - {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, - {TENSOR_ADD, OperatorType::kRecElmWiseOp}, - {SUB, OperatorType::kRecElmWiseOp}, - {MUL, OperatorType::kRecElmWiseOp}, - {DIV, OperatorType::kRecElmWiseOp}, - {REAL_DIV, OperatorType::kRecElmWiseOp}, - {SOFTMAX, OperatorType::kRecSoftmax}, - {LOG_SOFTMAX, OperatorType::kRecSoftmax}, - {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, - {SQRT, 
OperatorType::kRecElmWiseOp}, - {NEG, OperatorType::kRecElmWiseOp}, - {POW, OperatorType::kRecElmWiseOp}, - {EXP, OperatorType::kRecElmWiseOp}, - {LOG, OperatorType::kRecElmWiseOp}, - {COS, OperatorType::kRecElmWiseOp}, - {ACOS, OperatorType::kRecElmWiseOp}, - {LOGICALNOT, OperatorType::kRecElmWiseOp}, - {"LogicalAnd", OperatorType::kRecElmWiseOp}, - {"LogicalOr", OperatorType::kRecElmWiseOp}, - {SQUARE, OperatorType::kRecElmWiseOp}, - {"Abs", OperatorType::kRecElmWiseOp}, - {"Acosh", OperatorType::kRecElmWiseOp}, - {"AddN", OperatorType::kRecElmWiseOp}, - {"AccumulateNV2", OperatorType::kRecElmWiseOp}, - {"Atan2", OperatorType::kRecElmWiseOp}, - {"Erf", OperatorType::kRecElmWiseOp}, - {"Floor", OperatorType::kRecElmWiseOp}, - {FLOORDIV, OperatorType::kRecElmWiseOp}, - {"FloorMod", OperatorType::kRecElmWiseOp}, - {GREATER, OperatorType::kRecElmWiseOp}, - {"GreaterEqual", OperatorType::kRecElmWiseOp}, - {"HSwish", OperatorType::kRecElmWiseOp}, - {"Less", OperatorType::kRecElmWiseOp}, - {"LessEqual", OperatorType::kRecElmWiseOp}, - {MAXIMUM, OperatorType::kRecElmWiseOp}, - {MINIMUM, OperatorType::kRecElmWiseOp}, - {EQUAL, OperatorType::kRecElmWiseOp}, - {NOT_EQUAL, OperatorType::kRecElmWiseOp}, - {"Reciprocal", OperatorType::kRecElmWiseOp}, - {"Round", OperatorType::kRecElmWiseOp}, - {"Rsqrt", OperatorType::kRecElmWiseOp}, - {"Sign", OperatorType::kRecElmWiseOp}, - {"Sin", OperatorType::kRecElmWiseOp}, - {ASSIGN, OperatorType::kRecElmWiseOp}, - {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, - {"AssignAdd", OperatorType::kRecElmWiseOp}}; - -const TensorParam MakeTensor(int n, int c, int h, int w); - -Graph::NodeType MakeNewOperator(const std::vector> &ops, size_t iter_ops); - -OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, - Graph::NodeType NewTensor); - -TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, - const size_t iter_input_tensor, Graph::NodeType NewTensor); - -std::shared_ptr ParseGraph(const std::vector> &ops, - const std::vector> &input_tensor_names); - -void MakeEdge(const std::vector> &input_tensor_names, const std::shared_ptr &graph); - -size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, - const std::string &input_name); - -void Eliminate_Aux(const size_t node_index, const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list); - -std::shared_ptr EliminateGraph(const std::shared_ptr &graph, - const std::shared_ptr>> &eli_list, - const std::shared_ptr> &index_list); -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc deleted file mode 100644 index d5200f54d8..0000000000 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc +++ /dev/null @@ -1,310 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/auto_parallel/rec_core/rec_partition.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -// Get the target node's weight for sorting. -double GetWeights(const Graph::NodeType &node) { - const OperatorRec &op = node.apply; - - if (op.op_type == OperatorType::kRecMatMul) { - // For MatMul - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(op); - } else if (op.op_type == OperatorType::kRecConvolution) { - // For Convolution - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(node); - } else if (op.op_type == OperatorType::kRecPooling) { - // For Pooling - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecElmWiseOp) { - // For TensorAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecReLU) { - // For Activation - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecReshape) { - // For Reshape - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecBiasAdd) { - // For BiasAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecLog || op.op_type == OperatorType::kRecExp || - op.op_type == OperatorType::kRecAdd || op.op_type == OperatorType::kRecSub || - op.op_type == OperatorType::kRecMul || op.op_type == OperatorType::kRecDiv || - op.op_type == OperatorType::kRecSqueeze || op.op_type == OperatorType::kRecCast) { - // For element-wise op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMinCostIn(); - } else if (op.op_type == OperatorType::kRecBatchNorm || op.op_type == OperatorType::kRecOneHot || - op.op_type == OperatorType::kRecPReLU || op.op_type == OperatorType::kRecSoftmax || - op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits || - op.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { - // For BatchParallel op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetMaxCostIn(); - } else if (op.op_type == OperatorType::kRecUnkownType) { - // For Unkown type - return 0.0; - } else { - MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed."; - } -} - -// Sort all the nodes by their weights -std::vector SortByWeight(const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - - std::vector> weight_to_node_index; - std::vector node_index_by_weights; - - // Get node's weight. - for (size_t i = 0; i < graph->nodes.size(); i++) { - if (graph->nodes[i].info == kApplication) { - const Graph::NodeType &node_ptr = graph->nodes[i]; - double weight = GetWeights(node_ptr); - size_t index = i; - weight_to_node_index.push_back(std::make_pair(weight, index)); - } - } - - // Ordering ops aka nodes of the graph - std::sort(weight_to_node_index.begin(), weight_to_node_index.end()); - - // Store the result in node_index_by_weights. 
- uint64_t size = weight_to_node_index.size(); - for (uint64_t i = 1; i <= size; i++) { - node_index_by_weights.push_back(weight_to_node_index[size - i].second); - } - - return node_index_by_weights; -} - -// Get optimal strategy to partition the target node -StrategyRec PartitionNode(const Graph::NodeType &node, - const std::vector> &node_name_to_strategy, - const std::shared_ptr &graph) { - bool enable_conv_chw_partition = false; - MS_EXCEPTION_IF_NULL(graph); - - if (node.apply.op_type == OperatorType::kRecMatMul) { - // For MatMul - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecConvolution) { - // For Convolution - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph, enable_conv_chw_partition); - } else if (node.apply.op_type == OperatorType::kRecPooling) { - // For Pooling - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecElmWiseOp) { - // For TensorAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecReLU) { - // For Activation - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecReshape) { - // For Reshape - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecBiasAdd) { - // For BiasAdd - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecLog || node.apply.op_type == OperatorType::kRecExp || - node.apply.op_type == OperatorType::kRecAdd || node.apply.op_type == OperatorType::kRecSub || - node.apply.op_type == OperatorType::kRecMul || node.apply.op_type == OperatorType::kRecDiv || - node.apply.op_type == OperatorType::kRecSqueeze || node.apply.op_type == OperatorType::kRecCast) { - // For element-wise op - auto cost_ptr = std::make_shared(); - - return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph); - } else if (node.apply.op_type == OperatorType::kRecBatchNorm || node.apply.op_type == OperatorType::kRecOneHot || - node.apply.op_type == OperatorType::kRecPReLU || node.apply.op_type == kRecSoftmax || - node.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) { - // For BatchParallel type - auto cost_ptr = std::make_shared(); - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecSoftmaxCrossEntropyWithLogits) { - // For SoftmaxCrossEntropyWithLogits type - auto cost_ptr = std::make_shared(); - return cost_ptr->GetOptimalStr(node); - } else if (node.apply.op_type == OperatorType::kRecUnkownType) { - // For Unkown type - StrategyRec default_strategy; - return default_strategy; - } else { - MS_LOG(EXCEPTION) << "Failure: Partition Operator failed."; - } -} - -// Parttion graph into all devices. 
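The function defined next derives its loop count from iter_times = log2(num_device): each pass makes one binary cut, so after all passes a tensor is split into exactly num_device slices. A toy illustration (assumes a power-of-two device count, which InitDevice enforces elsewhere in this diff):

    #include <cmath>
    #include <iostream>

    // Toy version of the N-cuts loop in PartitionForAllDevices: one halving
    // per pass, log2(num_device) passes, hence num_device slices at the end.
    int main() {
      const int num_device = 8;
      const int iter_times = static_cast<int>(std::log2(num_device));

      double slice_fraction = 1.0;  // fraction of a tensor held per device
      for (int loop = 0; loop < iter_times; ++loop) {
        slice_fraction /= 2.0;  // one binary cut per pass
        std::cout << "after cut " << loop + 1 << ": 1/" << 1.0 / slice_fraction
                  << " of the tensor\n";
      }
      // 3 cuts -> 1/8 of the tensor per device, matching num_device = 8.
    }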
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, - const std::shared_ptr &graph) { - if (num_device < 1) { - MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << "."; - } - - if (num_device > 1024) { - MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be larger than 1024."; - } - - MS_EXCEPTION_IF_NULL(graph); - - // Comopute iter times - int iter_times = static_cast(log2(num_device)); - - // N-cuts loop - for (int loop = 0; loop < iter_times; loop++) { - // Sort by weights - std::vector reorder_node_list = SortByWeight(graph); - - // get total node number - size_t iter_nodes = reorder_node_list.size(); - - // temp vector to map nodename to its strategy. - std::vector> node_name_to_strategy; - - // Loop for all the nodes - for (size_t i_node = 0; i_node < iter_nodes; i_node++) { - // get current node's index - size_t index = reorder_node_list[i_node]; - - Graph::NodeType &node_ptr = graph->nodes[index]; - - // Serch optimal strategy to cut this operator. And store the result optimal strategy in graph. - graph->nodes[index].apply.str = PartitionNode(node_ptr, node_name_to_strategy, graph); - - // Apply OP Strategy to Tensor Strategy. - graph->nodes[index] = ApplyStrToTensor(node_ptr); - - // Note down the node name and its strategy in this loop. - auto node_name_to_str = - std::pair(graph->nodes[index].name, graph->nodes[index].apply.str); - node_name_to_strategy.push_back(node_name_to_str); - } - } - - if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) { - return FAILED; - } else { - return SUCCESS; - } -} - -// Apply OP Strategy to Tensor Strategy -Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) { - // Set Node's tensor_parm - Node.tensor_parm.tensor_str.str_n = Node.apply.str.outputTensor.str_n; - Node.tensor_parm.tensor_str.str_c = Node.apply.str.outputTensor.str_c; - Node.tensor_parm.tensor_str.str_h = Node.apply.str.outputTensor.str_h; - Node.tensor_parm.tensor_str.str_w = Node.apply.str.outputTensor.str_w; - - // Set input tensors' tersor_parm - for (int i = 0; i < 2; i++) { - Node.apply.arguments[i].tensor_str.str_n = Node.apply.str.inputTensor[i].str_n; - Node.apply.arguments[i].tensor_str.str_c = Node.apply.str.inputTensor[i].str_c; - Node.apply.arguments[i].tensor_str.str_h = Node.apply.str.inputTensor[i].str_h; - Node.apply.arguments[i].tensor_str.str_w = Node.apply.str.inputTensor[i].str_w; - } - return Node; -} - -Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - if (num_device == 0) { - MS_LOG(EXCEPTION) << "Failure: device number is 0."; - } - - uint64_t iter_nodes = graph->nodes.size(); - double used_memory = 0.0; - - for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) { - if (graph->nodes[i_node].info == 0) { - Graph::NodeType &Node = graph->nodes[i_node]; - for (int index = 0; index < 2; index++) { - used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n * - Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c * - Node.apply.arguments[index].tensor_str.str_h * Node.apply.arguments[index].tensor_shape.shape_h * - Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w * - GetDataTypeSize(Node.apply.arguments[index].tensor_type); - } - } - } - - if (device_memory < (used_memory / num_device)) { - MS_LOG(EXCEPTION) << "Failure: Out of memory!"; - return 
FAILED;
-  } else {
-    return SUCCESS;
-  }
-}
-
-size_t GetDataTypeSize(const TensorType &type) {
-  switch (type) {
-    case kInt8:
-      return sizeof(int);
-    case kFloat16:
-      return sizeof(float) / 2;
-    case kFloat32:
-      return sizeof(float);
-    case kDouble64:
-      return sizeof(double);
-    default:
-      MS_LOG(EXCEPTION) << "GetDataTypeSize Failed. Unexpected type";
-  }
-}
-} // namespace parallel
-} // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
deleted file mode 100644
index c98f3317f8..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
-#define PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "parallel/auto_parallel/rec_core/rec_cost.h"
-#include "parallel/auto_parallel/rec_core/rec_graph.h"
-#include "parallel/auto_parallel/rec_core/rec_strategy.h"
-#include "parallel/status.h"
-
-namespace mindspore {
-namespace parallel {
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph);
-
-double GetWeights(const Graph::NodeType &node);
-
-StrategyRec PartitionNode(const Graph::NodeType &node,
-                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          const std::shared_ptr<Graph> &graph);
-
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
-
-Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
-
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
-
-size_t GetDataTypeSize(const TensorType &type);
-} // namespace parallel
-} // namespace mindspore
-
-#endif // PARALLEL_AUTO_PARALLEL_REC_PARTITION_H_
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h
deleted file mode 100644
index 51ffca4023..0000000000
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_tensor.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ -#define PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ - -#include "parallel/auto_parallel/rec_core/rec_strategy.h" - -namespace mindspore { -namespace parallel { -enum TensorType { kInt8, kFloat16, kFloat32, kDouble64 }; - -struct Shape4D { - int32_t shape_n = 1; - int32_t shape_c = 1; - int32_t shape_h = 1; - int32_t shape_w = 1; -}; - -struct TensorParam { - TensorType tensor_type = kFloat32; // default as float. - Shape4D tensor_shape; - TensorStr4D tensor_str; -}; -} // namespace parallel -} // namespace mindspore - -#endif // PARALLEL_AUTO_PARALLEL_REC_TENSOR_H_ diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc deleted file mode 100644 index 062d814aa0..0000000000 --- a/mindspore/ccsrc/parallel/context.cc +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/context.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "parallel/device_manager.h" - -namespace mindspore { -namespace parallel { -static std::map> param_shapes; - -std::vector PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, - AUTO_PARALLEL}; -std::vector STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; - -std::shared_ptr ParallelContext::inst_context_ = nullptr; - -std::shared_ptr ParallelContext::GetInstance() { - if (inst_context_ == nullptr) { - inst_context_.reset(new (std::nothrow) ParallelContext()); - } - return inst_context_; -} - -ParallelContext::ParallelContext() { Reset(); } - -void ParallelContext::Reset() { - mirror_mean_ = false; - full_batch_ = false; - cast_before_mirror_ = true; - loss_repeated_mean_ = true; - device_num_ = 1; - global_rank_ = 0; - communication_backend_ = HCCL_BACKEND; - device_num_is_set_ = false; - global_rank_is_set_ = false; - parallel_mode_ = STAND_ALONE; - parameter_broadcast_ = false; - parameter_broadcast_is_set_ = false; - enable_all_reduce_fusion_ = false; - strategy_ckpt_load_file_ = ""; - strategy_ckpt_save_file_ = ""; - enable_parallel_optimizer_ = false; -} - -void ParallelContext::set_device_num(int32_t device_num) { - device_num_ = device_num; - device_num_is_set_ = true; -} - -void ParallelContext::set_global_rank(int32_t global_rank) { - global_rank_ = global_rank; - global_rank_is_set_ = true; -} - -void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; } - -void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; } - -void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_before_mirror_ = cast_before_mirror; } - -void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } - -void ParallelContext::set_communication_backend(const std::string &communication_backend) { - communication_backend_ = communication_backend; -} - 
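The two setters that follow validate the incoming string against the mode lists declared at the top of this file before assigning. A standalone sketch of that validate-then-assign pattern (SetParallelMode is a hypothetical free function; the mode strings are the constants from this file):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch of the pattern in set_parallel_mode / set_strategy_search_mode:
    // reject unknown strings and report whether the assignment took effect.
    static const std::vector<std::string> kParallelModes = {
        "stand_alone", "data_parallel", "hybrid_parallel", "semi_auto_parallel", "auto_parallel"};

    bool SetParallelMode(const std::string &mode, std::string *target) {
      if (std::find(kParallelModes.begin(), kParallelModes.end(), mode) == kParallelModes.end()) {
        return false;  // the real setter logs "Invalid parallel mode" here
      }
      *target = mode;
      return true;
    }

    int main() {
      std::string current = "stand_alone";
      std::cout << SetParallelMode("auto_parallel", &current) << " " << current << "\n";  // 1 auto_parallel
      std::cout << SetParallelMode("bogus", &current) << " " << current << "\n";          // 0 auto_parallel
    }

Returning bool instead of raising lets the Python-facing bindings surface the failure as a normal error, which is why these two setters differ from the void setters around them.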
-bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { - auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); - if (iter == PARALLEL_MODE_LIST.end()) { - MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; - return false; - } - parallel_mode_ = parallel_mode; - return true; -} - -bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { - auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); - if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { - MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; - return false; - } - strategy_search_mode_ = strategy_search_mode; - return true; -} - -void ParallelContext::set_parameter_broadcast(bool parameter_broadcast) { - parameter_broadcast_ = parameter_broadcast; - parameter_broadcast_is_set_ = true; -} - -void ParallelContext::set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file) { - strategy_ckpt_load_file_ = strategy_ckpt_load_file; -} - -void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file) { - strategy_ckpt_save_file_ = strategy_ckpt_save_file; -} - -void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) { - all_reduce_fusion_split_indices_[group] = indices; -} - -const std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { - auto iter = all_reduce_fusion_split_indices_.find(group); - if (iter != all_reduce_fusion_split_indices_.end()) { - return iter->second; - } - return {}; -} - -void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group) { - all_reduce_fusion_split_sizes_[group] = sizes; -} - -const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { - auto iter = all_reduce_fusion_split_sizes_.find(group); - if (iter != all_reduce_fusion_split_sizes_.end()) { - return iter->second; - } - return {}; -} - -// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextInit(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { - return; - } - param_shapes.clear(); -} - -// Restore the parameters' shape for evaluation/prediction in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(param_node); - MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || (func_graph->attrs().count(TRAINING) == 0) || - func_graph->has_flag(TRAINING)) { - return; - } - - auto iter = param_shapes.find(param_node->name()); - if (iter == param_shapes.end()) { - MS_LOG(WARNING) << "Can not found the shape for parameter " << param_node->name(); - return; - } - std::vector shape = iter->second; - std::shared_ptr base_shape = std::make_shared(shape); - ptr->set_shape(base_shape); - MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; -} - -// Checkpoint the parameters' shape for training in auto-parallel or semi-auto-parallel mode -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr) { - 
MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(param_node); - MS_EXCEPTION_IF_NULL(ptr); - if (!func_graph->has_flag(AUTO_PARALLEL) || !func_graph->has_flag(TRAINING)) { - return; - } - - std::vector shape = dyn_cast(ptr->GetShapeTrack())->shape(); - auto ret = param_shapes.try_emplace(param_node->name(), shape); - if (!ret.second) { - MS_LOG(EXCEPTION) << "The shape for parameter name " << param_node->name() << " is existed"; - return; - } - - MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h deleted file mode 100644 index 76166f50cf..0000000000 --- a/mindspore/ccsrc/parallel/context.h +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ -#define MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ - -#include -#include -#include -#include -#include - -#include "parallel/ops_info/ops_utils.h" -#include "parallel/status.h" -#include "utils/convert_utils.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "debug/info.h" -#include "abstract/abstract_value.h" - -namespace mindspore { -namespace parallel { -constexpr char STAND_ALONE[] = "stand_alone"; -constexpr char DATA_PARALLEL[] = "data_parallel"; -constexpr char HYBRID_PARALLEL[] = "hybrid_parallel"; -constexpr char AUTO_PARALLEL[] = "auto_parallel"; -constexpr char SEMI_AUTO_PARALLEL[] = "semi_auto_parallel"; - -constexpr char DYNAMIC_PROGRAMMING[] = "dynamic_programming"; -constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; - -constexpr char TRAINING[] = "training"; - -class ParallelContext { - public: - ~ParallelContext() = default; - ParallelContext(const ParallelContext &) = delete; - ParallelContext &operator=(const ParallelContext &) = delete; - - static std::shared_ptr GetInstance(); - - void set_mirror_mean(bool mirror_mean); - bool mirror_mean() const { return mirror_mean_; } - - void set_full_batch(bool full_batch); - bool full_batch() const { return full_batch_; } - - void set_cast_before_mirror(bool cast_before_mirror); - bool cast_before_mirror() const { return cast_before_mirror_; } - - void set_loss_repeated_mean(bool loss_repeated_mean); - bool loss_repeated_mean() const { return loss_repeated_mean_; } - - void set_device_num(int32_t device_num); - int32_t device_num() const { return device_num_; } - - void set_global_rank(int32_t global_rank); - int32_t global_rank() const { return global_rank_; } - - void set_communication_backend(const std::string &communication_backend); - std::string communication_backend() const { return communication_backend_; } - - bool set_parallel_mode(const std::string ¶llel_mode); - std::string parallel_mode() const { return parallel_mode_; } - - bool set_strategy_search_mode(const std::string &strategy_search_mode); - std::string strategy_search_mode() 
const { return strategy_search_mode_; } - - void set_parameter_broadcast(bool parameter_broadcast); - bool parameter_broadcast() const { return parameter_broadcast_; } - - bool device_num_is_set() const { return device_num_is_set_; } - bool global_rank_is_set() const { return global_rank_is_set_; } - bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } - - void SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group); - const std::vector GetAllReduceFusionSplitIndices(const std::string &group) const; - void SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group); - const std::vector GetAllReduceFusionSplitSizes(const std::string &group) const; - void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { - enable_all_reduce_fusion_ = enable_all_reduce_fusion; - } - bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; } - - void set_strategy_ckpt_load_file(const std::string &strategy_ckpt_load_file); - std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; } - void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file); - std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; } - - void set_enable_parallel_optimizer(bool enable_parallel_optimizer) { - enable_parallel_optimizer_ = enable_parallel_optimizer; - } - bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } - - void Reset(); - - private: - ParallelContext(); - static std::shared_ptr inst_context_; - bool mirror_mean_; - bool full_batch_; - bool cast_before_mirror_; - bool loss_repeated_mean_; - int32_t device_num_; - int32_t global_rank_; - std::string communication_backend_; - std::string parallel_mode_; - std::string strategy_search_mode_; - bool parameter_broadcast_; - bool device_num_is_set_; - bool global_rank_is_set_; - bool parameter_broadcast_is_set_; - bool enable_all_reduce_fusion_; - std::map> all_reduce_fusion_split_indices_; - std::map> all_reduce_fusion_split_sizes_; - std::string strategy_ckpt_load_file_; - std::string strategy_ckpt_save_file_; - bool enable_parallel_optimizer_; -}; - -void ParallelParameterContextInit(const FuncGraphPtr &func_graph); -void ParallelParameterContextRestoreInNoTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - AbstractBasePtr ptr); -void ParallelParameterContextCkptInTraining(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, - const AbstractBasePtr &ptr); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_CONTEXT_H_ diff --git a/mindspore/ccsrc/parallel/costmodel_context.cc b/mindspore/ccsrc/parallel/costmodel_context.cc deleted file mode 100644 index 92aff29557..0000000000 --- a/mindspore/ccsrc/parallel/costmodel_context.cc +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/costmodel_context.h" - -#include - -#include "parallel/allreduce_fusion/allreduce_fusion.h" -#include "parallel/auto_parallel/graph_costmodel.h" - -namespace mindspore { -namespace parallel { -std::shared_ptr CostModelContext::cm_context_inst_ = nullptr; - -std::shared_ptr CostModelContext::GetInstance() { - if (cm_context_inst_ == nullptr) { - MS_LOG(INFO) << "Create costmodel_context"; - cm_context_inst_.reset(new (std::nothrow) CostModelContext()); - } - return cm_context_inst_; -} - -CostModelContext::CostModelContext() { - ResetCostModel(); - ResetAlgoParameters(); -} - -void CostModelContext::ResetCostModel() { - device_memory_capacity_ = DEFAULT_DEVICE_MEMORY_CAPACITY; - costmodel_alpha_ = DEFAULT_COST_MODEL_ALPHA; - costmodel_beta_ = DEFAULT_COST_MODEL_BETA; - costmodel_gamma_ = DEFAULT_COST_MODEL_GAMMA; - costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD; - costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST; - costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS; - is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS; - run_phase_ = DEFAULT_RUN_PHASE; - costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM; - costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES; - costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT; - costmodel_allreduce_fusion_tail_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_TIME; - costmodel_allreduce_fusion_allreduce_inherent_time_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME; - costmodel_allreduce_fusion_allreduce_bandwidth_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH; - costmodel_allreduce_fusion_computation_time_parameter_ = - DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER; -} - -void CostModelContext::ResetAlgoParameters() { - costmodel_simplify_cal_ = DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION; - tensor_slice_alignment_enable_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_ENABLE; - tensor_slice_alignment_size_ = DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE; - fully_use_device_ = DEFAULT_FULLY_USE_DEVICES; - elementwise_stra_follow_ = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; -} - -void CostModelContext::set_device_memory_capacity(double dm_capacity) { device_memory_capacity_ = dm_capacity; } - -void CostModelContext::set_costmodel_alpha(double cm_alpha) { costmodel_alpha_ = cm_alpha; } - -void CostModelContext::set_costmodel_beta(double cm_beta) { costmodel_beta_ = cm_beta; } - -void CostModelContext::set_costmodel_gamma(double cm_gamma) { costmodel_gamma_ = cm_gamma; } - -void CostModelContext::set_costmodel_simplify_cal(bool cm_simplify) { costmodel_simplify_cal_ = cm_simplify; } - -void CostModelContext::set_costmodel_communi_threshold(double cm_communi_th) { - costmodel_communi_threshold_ = cm_communi_th; -} - -void CostModelContext::set_costmodel_communi_const(double cm_communi_const) { - costmodel_communi_const_ = cm_communi_const; -} - -void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; } - -void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; } -void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) { - costmodel_allreduce_fusion_algorithm_ = algorithm; -} - -void CostModelContext::set_costmodel_allreduce_fusion_times(int32_t allreduce_fusion_times) { - costmodel_allreduce_fusion_times_ = allreduce_fusion_times; -} - -void 
CostModelContext::set_costmodel_allreduce_fusion_tail_percent(double tail_percent) { - costmodel_allreduce_fusion_tail_percent_ = tail_percent; -} - -void CostModelContext::set_costmodel_allreduce_fusion_tail_time(double tail_time) { - costmodel_allreduce_fusion_tail_time_ = tail_time; -} - -void CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time(double allreduce_inherent_time) { - costmodel_allreduce_fusion_allreduce_inherent_time_ = allreduce_inherent_time; -} - -void CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth(double allreduce_bandwidth) { - costmodel_allreduce_fusion_allreduce_bandwidth_ = allreduce_bandwidth; -} - -void CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter(double computation_time_parameter) { - costmodel_allreduce_fusion_computation_time_parameter_ = computation_time_parameter; -} - -void CostModelContext::set_tensor_slice_alignment_enable(bool ts_align) { tensor_slice_alignment_enable_ = ts_align; } - -void CostModelContext::set_tensor_slice_alignment_size(size_t ts_align_size) { - tensor_slice_alignment_size_ = ts_align_size; -} - -void CostModelContext::set_fully_use_device(bool fully_use) { fully_use_device_ = fully_use; } - -void CostModelContext::set_elementwise_stra_follow(bool elementwise_follow) { - elementwise_stra_follow_ = elementwise_follow; -} - -void CostModelContext::set_run_phase(int32_t phase) { run_phase_ = phase; } -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/device.h b/mindspore/ccsrc/parallel/device.h deleted file mode 100644 index 8c3174ae55..0000000000 --- a/mindspore/ccsrc/parallel/device.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ - -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class Device { - // This class abstract the 'device' information, used in Parallel module. - public: - Device() : rank_(0) { name_.clear(); } - explicit Device(int32_t rank) : rank_(rank) { name_.clear(); } - Device(std::string name, int32_t rank) : name_(std::move(name)), rank_(rank) {} - ~Device() = default; - std::string name() const { return name_; } - int32_t rank() const { return rank_; } - - private: - std::string name_; - int32_t rank_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_H_ diff --git a/mindspore/ccsrc/parallel/device_manager.cc b/mindspore/ccsrc/parallel/device_manager.cc deleted file mode 100644 index 45628bec65..0000000000 --- a/mindspore/ccsrc/parallel/device_manager.cc +++ /dev/null @@ -1,374 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/device_manager.h" - -#include -#include -#include -#include -#include -#include - -#include "parallel/step_parallel.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -DeviceManagerPtr g_device_manager = nullptr; - -Stage::Stage(const std::vector &devices, int num, int rank) - : devices_(devices), number_(num), rank_(rank) { - gm_ = GroupManager(); -} - -// NOTE: '-1' indicates ERROR -int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } - -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { - if (device_num <= 0) { - MS_LOG(ERROR) << "'device_num' must be positive."; - return false; - } - if (global_rank < 0) { - MS_LOG(ERROR) << "'global_rank' must be nonnegative."; - return false; - } - if (device_num > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "'device_num' must be no more than " << MAX_DEVICE_NUM << "."; - return false; - } - // 'device_num_converted' must be the power of 2 - if ((IntToUint(device_num) & IntToUint(device_num - 1)) != 0) { - MS_LOG(ERROR) << "'device_num' must be the power of 2."; - return false; - } - if (global_rank >= device_num) { - MS_LOG(ERROR) << "'global_rank' must be less than 'device_num'."; - return false; - } - if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { - MS_LOG(ERROR) << "Invalid backend: " << backend; - return false; - } - - RankList devices, stage_map; - for (int i = 0; i < device_num; ++i) { - devices.push_back(i); - } - - stage_map.push_back(device_num); - g_device_manager = std::make_shared(); - if (g_device_manager->Init(devices, global_rank, stage_map, backend) == SUCCESS) { - MS_LOG(INFO) << "Device initialization succeeds."; - return true; - } else { - MS_LOG(ERROR) << "Device initialization fails."; - return false; - } -} - -void CheckGlobalDeviceManager() { - if (g_device_manager == nullptr) { - MS_LOG(EXCEPTION) << "Device information has not been set!"; - } -} - -int32_t GetListMemberByIndex(size_t index, const RankList &devices) { - size_t i = 0; - int32_t result = 0; - if ((devices.empty()) || (index >= devices.size())) { - MS_LOG(EXCEPTION) << "Index is out of the list scope"; - } - auto it = devices.begin(); - for (; it != devices.end(); ++it) { - if (i == index) { - result = *it; - break; - } - ++i; - } - return result; -} - -std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { - size_t i = 0; - std::shared_ptr result; - if ((device_list.empty()) || (index >= device_list.size())) { - MS_LOG(EXCEPTION) << "Index is out of the list scope"; - } - auto it = device_list.begin(); - for (; it != device_list.end(); ++it) { - if (i == index) { - result = *it; - break; - } - ++i; - } - return result; -} - -// E.g. devices = [4, 5, 2, 1, 7, 8, 10], stage_map = [4, 3], -// therefore the stage_devices_ = [[4, 5, 2, 1], [7, 8, 10]]. 
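That worked example can be reproduced with the same consume-in-order loop Init uses; a minimal sketch (plain std:: containers, no Device/Stage objects):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Runnable version of the comment's example: devices are consumed in
    // order, stage_map[i] of them per stage, giving [[4, 5, 2, 1], [7, 8, 10]].
    int main() {
      const std::vector<int32_t> devices = {4, 5, 2, 1, 7, 8, 10};
      const std::vector<int32_t> stage_map = {4, 3};

      std::vector<std::vector<int32_t>> stage_devices;
      size_t global_index = 0;
      for (int32_t num_device : stage_map) {
        std::vector<int32_t> stage;
        for (int32_t i = 0; i < num_device; ++i) {
          stage.push_back(devices[global_index++]);
        }
        stage_devices.push_back(stage);
      }

      for (const auto &stage : stage_devices) {
        for (int32_t d : stage) std::cout << d << " ";
        std::cout << "\n";  // "4 5 2 1" then "7 8 10"
      }
    }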
-Status DeviceManager::Init(const RankList &devices, int32_t global_device_rank, const RankList &stage_map, - const std::string &backend) { - auto dev_it = devices.begin(); - auto stage_it = stage_map.begin(); - int32_t sum = 0; - - if ((backend != HCCL_BACKEND) && (backend != NCCL_BACKEND) && (backend != UNDEFINED_BACKEND)) { - MS_LOG(ERROR) << "Invalid backend: " << backend; - return Status::FAILED; - } - - for (; stage_it != stage_map.end(); ++stage_it) { - sum += (*stage_it); - } - if (IntToSize(sum) != devices.size()) { - MS_LOG(ERROR) << "The number of 'devices' in the list is not equal to the mentioned " - << "size of 'stage_map'"; - return Status::FAILED; - } - - for (; dev_it != devices.end(); ++dev_it) { - std::shared_ptr one = std::make_shared(*dev_it); - devices_.push_back(one); - } - - size_t global_index = 0; - for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { - int num_device = *stage_it; - if (num_device > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must not be greater than " << MAX_DEVICE_NUM; - return Status::FAILED; - } - if (num_device <= 0) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; - return Status::FAILED; - } - RankList curr_dev_list; - for (int i = 0; i < num_device; ++i) { - curr_dev_list.push_back(GetListMemberByIndex(global_index, devices)); - global_index++; - } - stage_devices_.push_back(curr_dev_list); - } - - global_index = 0; - for (stage_it = stage_map.begin(); stage_it != stage_map.end(); ++stage_it) { - int num_device = *stage_it; - if (num_device > MAX_DEVICE_NUM) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be less than " << MAX_DEVICE_NUM; - return Status::FAILED; - } - if (num_device <= 0) { - MS_LOG(ERROR) << "The number of 'devices' in a stage must be positive"; - return Status::FAILED; - } - std::vector curr_dev_list; - for (int i = 0; i < num_device; ++i) { - curr_dev_list.push_back(*GetListMemberByIndex(global_index, devices_)); - global_index++; - } - std::shared_ptr new_stage = std::make_shared(curr_dev_list); - stages_.push_back(new_stage); - } - - std::shared_ptr dev = std::make_shared(global_device_rank); - device_ = dev; - set_global_rank(global_device_rank); - backend_ = backend; - - if (backend == HCCL_BACKEND) { - gm_.set_world_group(HCCL_WORLD_GROUP); - } else if (backend_ == NCCL_BACKEND) { - gm_.set_world_group(NCCL_WORLD_GROUP); - } else { - gm_.set_world_group(UNDEFINED_WORLD_GROUP); - } - MS_LOG(INFO) << "The device num: " << devices.size() << "rank id: " << global_device_rank - << "the backend: " << backend; - return Status::SUCCESS; -} - -std::shared_ptr DeviceManager::GetStageById(int32_t stage_id) { - std::shared_ptr res; - if (IntToSize(stage_id) >= stages_.size()) { - MS_LOG(ERROR) << "the 'stage_id': " << stage_id << ", is out of the scope of 'stage_devices_': " << stages_.size(); - return res; - } - int32_t index = 0; - for (auto &stage : stages_) { - if (index == stage_id) return stage; - index++; - } - return res; -} - -RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { - if (IntToSize(stage_id) >= stage_devices_.size()) - MS_LOG(ERROR) << "the 'stage_id': " << stage_id - << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); - RankList res; - int32_t index = 0; - for (auto &stage : stage_devices_) { - if (index == stage_id) { - return stage; - } - index++; - } - return res; -} - -RankList DeviceManager::global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) 
const { - RankList res; - if (split_num <= 0) { - return res; - } - if (IntToSize(stage_id) >= stage_devices_.size()) { - MS_LOG(ERROR) << "the 'stage_id': " << stage_id - << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); - return res; - } - - RankList global_list = GetDeviceListByStageId(stage_id); - if (global_list.size() % IntToSize(split_num)) { - MS_LOG(ERROR) << "dev list size(" << global_list.size() << ") can not be divisible by split num: " << stage_id; - return res; - } - - std::vector dev_list; - (void)std::copy(global_list.begin(), global_list.end(), std::back_inserter(dev_list)); - - size_t index = 0; - size_t slice_size = dev_list.size() / IntToSize(split_num); - for (int32_t i = 0; i < split_num; ++i) { - bool found = false; - index = slice_size * IntToSize(i); - for (size_t j = 0; j < slice_size; ++j) { - if (dev_list[index + j] == rank) { - found = true; - break; - } - } - - if (found) { - break; - } - } - - for (size_t k = 0; k < slice_size; ++k) { - res.push_back(dev_list[index + k]); - } - return res; -} - -Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device(rank); } - -std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { - std::vector dev_list; - for (auto &rank : ranks) { - Device one = CreateNewDeviceByRank(rank); - dev_list.push_back(one); - } - return dev_list; -} - -DeviceManager &DeviceManager::GetInstance() { - static DeviceManager instance = DeviceManager(); - return instance; -} - -std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { - std::string tmp = "WORLD_GROUP"; - if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { - return tmp; - } - auto iter = group_to_rank_.find(hash_name); - if (iter == group_to_rank_.end()) { - MS_LOG(WARNING) << "Can not find the rank list name by hash name: " << hash_name; - return tmp; - } - return iter->second; -} - -std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } - -// Group name is generated using the increasing ranks of the devices. -// E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name -// is '0-1-3-5-7'. -std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { - std::string rank_list_name; - std::vector::iterator it; - std::sort(ranks.begin(), ranks.end()); // sorted in increasing order - for (it = ranks.begin(); it != ranks.end(); ++it) { - if (it == ranks.begin()) { - rank_list_name = std::to_string(*it); - } else { - rank_list_name += "-" + std::to_string(*it); - } - } - - // hash rank-list-name and add ranks' size as prefix - std::string group_hash_name = HashName(rank_list_name); - std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name; - - if (rank_to_group_.find(rank_list_name) == rank_to_group_.end()) { - if (group_to_rank_.find(group_name) == group_to_rank_.end()) { - rank_to_group_[rank_list_name] = group_name; - group_to_rank_[group_name] = rank_list_name; - MS_LOG(INFO) << "The rank list name is " << rank_list_name << "nd group name is " << group_name; - } else { - MS_LOG(EXCEPTION) << "Hash collision, the current rank list: " << rank_list_name - << "the old rank list:" << group_to_rank_.find(group_name)->second - << "the group name: " << group_name; - } - } - return group_name; -} - -// Create the group with the given devices and the given name. The GroupManager -// gm_ will create a new group only if there does not exit a group with the same -// name. 
Otherwise, let the pointer g point to that group. -Group DeviceManager::CreateGroup(const std::string &group_name, - const std::vector &devices) { - if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { - MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; - } - Group g; - (void)gm_.CreateGroup(group_name, devices, &g); - return g; -} - -// Create the group with only the given devices' ranks. -Group DeviceManager::CreateGroup(const RankList &dev_ranks) { - std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); - if (dev_ranks.size() != rank_set.size()) { - MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; - } - - std::string group_name = GenerateGroupNameByRanks(dev_ranks); - auto dev_list = CreateDeviceListByRankList(dev_ranks); - return CreateGroup(group_name, dev_list); -} - -void DeviceManager::Clear() { - devices_.clear(); - stage_devices_.clear(); - gm_.Clear(); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/device_manager.h b/mindspore/ccsrc/parallel/device_manager.h deleted file mode 100644 index 3afafe6a9c..0000000000 --- a/mindspore/ccsrc/parallel/device_manager.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ -#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "parallel/device.h" -#include "parallel/device_matrix.h" -#include "parallel/group_manager.h" -#include "parallel/status.h" -#include "parallel/strategy.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace parallel { -#define MAX_DEVICE_NUM 1024 - -constexpr char HCCL_BACKEND[] = "hccl"; -constexpr char NCCL_BACKEND[] = "nccl"; -constexpr char UNDEFINED_BACKEND[] = "undefined_backend"; - -class DeviceManager; -using DeviceManagerPtr = std::shared_ptr; -// 'g_device_manager' is the globally unique manager to manage the devices. -extern DeviceManagerPtr g_device_manager; - -class Stage { - // This class is used in pipeline-parallelization. Available devices are partitioned into multiple stages. - // Currently, the function of pipeline-parallelization and this class are NOT implemented. 
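For reference, the rank-list naming scheme implemented by GenerateGroupNameByRanks above can be reproduced in isolation (std::hash output is implementation-defined, so the printed hash varies across standard libraries):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch of GenerateGroupNameByRanks: ranks are sorted, joined with '-',
    // hashed with std::hash<std::string>, and the rank count is prepended,
    // e.g. <0, 5, 3, 7, 1> -> "5-<hash of 0-1-3-5-7>".
    int main() {
      std::vector<int32_t> ranks = {0, 5, 3, 7, 1};
      std::sort(ranks.begin(), ranks.end());

      std::string rank_list_name;
      for (size_t i = 0; i < ranks.size(); ++i) {
        rank_list_name += (i == 0 ? "" : "-") + std::to_string(ranks[i]);
      }

      const std::string group_hash_name = std::to_string(std::hash<std::string>{}(rank_list_name));
      const std::string group_name = std::to_string(ranks.size()) + "-" + group_hash_name;

      std::cout << rank_list_name << "\n";  // 0-1-3-5-7
      std::cout << group_name << "\n";      // e.g. "5-1397485...", hash varies
    }

The size prefix is what makes the hash collision check above meaningful: two different rank lists that collide on the hash still differ in the stored bimap entries.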
- public: - explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { - gm_ = GroupManager(); - } - Stage(const std::vector &devices, int num, int rank); - ~Stage() = default; - - int GetStageNum() const { return number_; } - size_t GetDevicesNum() const { return devices_.size(); } - std::vector GetDevicesList() { return devices_; } - int global_rank(Group *g) const; - - private: - std::vector devices_; - int number_; - int32_t rank_; - GroupManager gm_; -}; - -// This method is used for initializing the global DeviceManager 'g_device_manager', -// arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); - -void CheckGlobalDeviceManager(); - -std::string HashName(const std::string &rank_list_name); - -class DeviceManager { - // This class is used to manage the abstract devices, including group-related and stage-related management. - public: - DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } - ~DeviceManager() = default; - - Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); - - static DeviceManager &GetInstance(); - RankList GetDeviceListByStageId(int32_t stage_id) const; - RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; - - Device CreateNewDeviceByRank(int32_t rank) const; - std::vector CreateDeviceListByRankList(RankList ranks); - - std::string GenerateGroupNameByRanks(RankList dev_ranks); - Group CreateGroup(const std::string &group_name, const std::vector &devices); - Group CreateGroup(const RankList &dev_ranks); - std::shared_ptr GetStageById(int32_t stage_id); - - size_t DeviceNum() const { return devices_.size(); } - - int32_t GetStageNum() const { return static_cast(stage_devices_.size()); } - - int32_t global_rank() const { return global_rank_; } - std::string backend() const { return backend_; } - void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } - void Clear(); - std::string world_group() const { return gm_.world_group(); } - std::string FindRankListNameByHashName(const std::string &hash_name); - - private: - std::vector> devices_; - // each stage has a list of devices - std::vector> stage_devices_; - std::shared_ptr device_; - std::vector> stages_; - GroupManager gm_; - std::string backend_; - - // bimap: - std::map rank_to_group_; // the key is rank list, value is hash name - std::map group_to_rank_; // the key is hash name, value is rank list - - int32_t local_rank_; - int32_t global_rank_; - int32_t stage_num_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/parallel/device_matrix.cc b/mindspore/ccsrc/parallel/device_matrix.cc deleted file mode 100644 index 3c9467a223..0000000000 --- a/mindspore/ccsrc/parallel/device_matrix.cc +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/device_matrix.h"
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "parallel/ops_info/operator_info.h"
-#include "parallel/status.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace parallel {
-DeviceMatrix::DeviceMatrix(int32_t rank, RankList dev_list, Shape dev_shape)
-    : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) {
-  if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int32_t a) { return a == rank; })) {
-    MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!";
-  }
-  int32_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies<int32_t>());
-  if (IntToSize(total) != dev_list_.size()) {
-    MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!";
-  }
-}
-
-Status DeviceMatrix::CreateGroupList() {
-  size_t size = dev_shape_.size();
-  RankList group;
-  for (size_t i = 0; i < size; i++) {
-    Status status = GetDevicesAlongDim(SizeToUint(i), &group);
-    group_list_.push_back(group);
-    if (status == Status::FAILED) {
-      return Status::FAILED;
-    }
-  }
-  return Status::SUCCESS;
-}
-
-Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) {
-  if (dim >= dev_shape_.size()) {
-    MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!";
-  }
-  if (dev_shape_[dim] == 1) {
-    *devices = {rank_};
-    return Status::SUCCESS;
-  }
-
-  RankList group;
-  std::vector<RankList> local_group_list;
-
-  // lower than dim
-  int32_t step = 1;
-  for (uint32_t i = dim + 1; i < dev_shape_.size(); i++) {
-    step = step * dev_shape_[i];
-  }
-  int32_t num = *dev_list_.begin();
-  for (int32_t i = 0; i < dev_shape_[dim]; i++) {
-    group.push_back(num);
-    num += step;
-  }
-
-  for (int32_t i = 0; i < step; i++) {
-    local_group_list.push_back(group);
-    (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; });
-  }
-
-  // higher than dim
-  step = step * dev_shape_[dim];
-  int32_t len = SizeToInt(dev_list_.size()) / step;
-
-  // search rank
-  int32_t target = rank_;
-  for (int32_t i = 0; i < len; i++) {
-    for (RankList &temp : local_group_list) {
-      if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) {
-        *devices = temp;
-        return Status::SUCCESS;
-      }
-      (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; });
-    }
-  }
-  MS_LOG(ERROR) << "Can't find groups for rank " << rank_ << " in the device list!";
-  return Status::FAILED;
-}
-
-Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) {
-  Shape dev_coordinate;
-  for (size_t i = 0; i < dev_shape.size(); ++i) {
-    int32_t size = dev_shape[dev_shape.size() - i - 1];
-    if (size == 0) {
-      MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape);
-    } else {
-      int32_t index = rank % size;
-      (void)dev_coordinate.insert(dev_coordinate.begin(), index);
-      rank = rank / size;
-    }
-  }
-  return dev_coordinate;
-}
-
-Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) {
-  for (auto &element : tensor_map) {
-    // -1 means the corresponding dimension is not split.
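-    // Hypothetical walk-through (added for clarity, not in the original file):
-    // with dev_shape_ = [2, 4] and tensor_map = {1, -1}, only the coordinate of
-    // device-matrix dimension 1 (dimensions are counted from the last axis) must
-    // match, so rank 5, whose coordinate is [1, 1], is grouped with every rank
-    // sharing that first coordinate, i.e. ranks {4, 5, 6, 7}.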
-    if (element == MAP_NONE) {
-      continue;
-    } else if ((element < 0) || (IntToSize(element) >= dev_shape_.size())) {
-      MS_LOG(ERROR) << "create group by tensor map: the tensor map is invalid";
-      return FAILED;
-    }
-  }
-
-  Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_);
-  for (auto &tmp_rank : dev_list_) {
-    Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_);
-    bool matched = true;
-    for (auto &map : tensor_map) {
-      if (map == MAP_NONE) {
-        continue;
-      }
-      size_t index = dev_shape_.size() - IntToSize(map) - 1;
-      if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) {
-        matched = false;
-        break;
-      }
-    }
-    if (matched) {
-      rank_list->push_back(tmp_rank);
-    }
-  }
-
-  return SUCCESS;
-}
-
-std::string ShapeToString(const Shape &shape) {
-  std::string str = "[";
-  for (size_t i = 0; i < shape.size(); ++i) {
-    str += std::to_string(shape[i]);
-    if (i < shape.size() - 1) {
-      str += ", ";
-    }
-  }
-  return str + "]";
-}
-
-std::string ListToString(const std::vector<int32_t> &list) {
-  std::string str = "[";
-  for (auto &element : list) {
-    str += std::to_string(element) + ", ";
-  }
-  return str + "]";
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/device_matrix.h b/mindspore/ccsrc/parallel/device_matrix.h
deleted file mode 100644
index 295bf33836..0000000000
--- a/mindspore/ccsrc/parallel/device_matrix.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
-#define MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
-
-#include
-#include
-#include
-
-#include "parallel/status.h"
-#include "utils/convert_utils.h"
-
-namespace mindspore {
-namespace parallel {
-using RankList = std::vector<int32_t>;
-using Shape = std::vector<int32_t>;
-
-class DeviceMatrix {
- public:
-  DeviceMatrix(int32_t rank, RankList devices, Shape dev_shape);
-  DeviceMatrix() = default;
-  ~DeviceMatrix() = default;
-  std::vector<RankList> group_list() const { return group_list_; }
-  Status CreateGroupList();
-  Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list);
-  Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices);
-
- private:
-  int32_t rank_ = -1;
-  RankList dev_list_;
-  // From low dim to high dim, eg: [D0 D1 D2 D3]
-  Shape dev_shape_;
-  std::vector<RankList> group_list_;
-};
-
-std::string ShapeToString(const Shape &shape);
-std::string ListToString(const std::vector<int32_t> &list);
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_DEVICE_MATRIX_H_
diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
deleted file mode 100644
index 352c7449a5..0000000000
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ /dev/null
@@ -1,139 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_
-#define MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_
-
-#include
-#include
-#include
-#include
-
-#include "parallel/ops_info/ops_info_head_files.h"
-#include "parallel/step_parallel.h"
-
-namespace mindspore {
-namespace parallel {
-#define REGISTER(className) \
-  OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \
-    return std::make_shared<className>(name, in, out, attrs); \
-  } \
-  RegisterAction className##Register(#className, (CreatFn)objectCreator##className);
-
-typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
-                                   const PrimitiveAttrs &attrs);
-
-class DynCreator {
- public:
-  ~DynCreator() = default;
-
-  // create the static singleton DynCreator instance
-  static DynCreator &Instance() {
-    static DynCreator fac = DynCreator();
-    return fac;
-  }
-  // register
-  void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
-  // creator
-  OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
-                        const PrimitiveAttrs &attrs, size_t count) {
-    std::string op_name = name + std::to_string(count);
-    auto iter = Function_map_.find(name);
-    if (iter == Function_map_.end()) {
-      MS_LOG(INFO) << name << " is not registered yet";
-      return nullptr;
-    }
-    return iter->second(op_name, shape_in, shape_out, attrs);
-  }
-
- private:
-  DynCreator() = default;
-  std::map<std::string, CreatFn> Function_map_;
-};
-
-class RegisterAction {
- public:
-  RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
-    DynCreator::Instance().Regist(name, creatfn);
-  }
-  ~RegisterAction() = default;
-
- private:
-  std::string name_;
-};
-
-// operator register
-REGISTER(MatMulInfo);
-REGISTER(GeluInfo);
-REGISTER(VirtualDatasetInfo);
-REGISTER(BatchParallelInfo);
-REGISTER(TanhInfo);
-REGISTER(SoftmaxInfo);
-REGISTER(LogSoftmaxInfo);
-REGISTER(ActivationInfo);
-REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
-REGISTER(SubInfo);
-REGISTER(TensorAddInfo);
-REGISTER(BiasAddInfo);
-REGISTER(MulInfo);
-REGISTER(DivInfo);
-REGISTER(RealDivInfo);
-REGISTER(PowInfo);
-REGISTER(ExpInfo);
-REGISTER(OneHotInfo);
-REGISTER(EqualInfo);
-REGISTER(NotEqualInfo);
-REGISTER(LogInfo);
-REGISTER(CosInfo);
-REGISTER(ACosInfo);
-REGISTER(LogicalNotInfo);
-REGISTER(L2NormalizeInfo);
-REGISTER(LayerNormInfo); -REGISTER(ReduceMaxInfo); -REGISTER(ArgMaxWithValueInfo); -REGISTER(ArgMinWithValueInfo); -REGISTER(ReduceMeanInfo); -REGISTER(ReduceSumInfo); -REGISTER(ReduceMinInfo); -REGISTER(TransposeInfo); -REGISTER(PReLUInfo); -REGISTER(DropoutDoMaskInfo); -REGISTER(ReshapeInfo); -REGISTER(FloorDivInfo); -REGISTER(MaximumInfo); -REGISTER(MinimumInfo); -REGISTER(CastInfo); -REGISTER(GreaterInfo); -REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo); -REGISTER(AssignSubInfo); -REGISTER(ReLUInfo); -REGISTER(GatherV2Info); -REGISTER(SparseGatherV2Info); -REGISTER(SqrtInfo); -REGISTER(SigmoidInfo); -REGISTER(GetNextInfo); -REGISTER(NegInfo); -REGISTER(BatchMatMulInfo); -REGISTER(ExpandDimsInfo); -REGISTER(SqueezeInfo); -REGISTER(SigmoidCrossEntropyWithLogitsInfo); -REGISTER(SquareInfo); -REGISTER(GatherV2PInfo); -REGISTER(EmbeddingLookupInfo); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_DYNAMIC_CREATOR_H_ diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc deleted file mode 100644 index 7bd2fa808d..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "parallel/graph_util/generate_graph.h"
-
-#include
-#include
-#include
-#include
-
-using mindspore::tensor::Tensor;
-
-namespace mindspore {
-namespace parallel {
-std::string GetOpPythonPath(const OperatorName &op_name) {
-  // almost all ops are defined in two main paths
-  const std::string ops_module = OP_PATH;
-  const std::string inner_ops_module = INNER_OP_PATH;
-  py::module mod = py::module::import(common::SafeCStr(ops_module));
-  py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
-  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
-      MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " does not have op: " << op_name;
-    }
-    return inner_ops_module;
-  }
-  return ops_module;
-}
-
-ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
-  std::string op_path = GetOpPythonPath(op_name);
-  py::module mod = py::module::import(common::SafeCStr(op_path));
-  if (!py::hasattr(mod, common::SafeCStr(op_name))) {
-    MS_LOG(ERROR) << "Failure: op_path:" << op_path << " does not have attr " << op_name;
-    return nullptr;
-  }
-  std::vector<py::object> arg_list;
-  (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
-                       [](const Attr &attr) { return ValuePtrToPyData(attr.second); });
-  py::object obj =
-    parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
-  ValuePtr op_instance = nullptr;
-  bool succ = parse::ConvertData(obj, &op_instance);
-  if (!succ) {
-    MS_LOG(ERROR) << "Failure: get Python op " << op_path << " from " << op_name << " failed";
-    return nullptr;
-  }
-  return op_instance;
-}
-
-AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
-  auto value_node = NewValueNode(value_ptr);
-  MS_EXCEPTION_IF_NULL(value_node);
-  return value_node->cast<AnfNodePtr>();
-}
-
-static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
-AnfNodePtr CreateInt32Tensor(int32_t value) {
-  auto it = int_tensor_map.find(value);
-  if (it != int_tensor_map.end()) {
-    return it->second;
-  }
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<Tensor>(py::int_(value), kInt32);
-  ValuePtr value_ptr = MakeValue(tensor_ptr);
-  auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
-  int_tensor_map[value] = anf_node_ptr;
-  return anf_node_ptr;
-}
-
-AnfNodePtr CreatTypeInt(int32_t value) {
-  ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
-  return ValuePtrToAnfNodePtr(value_ptr);
-}
-
-AnfNodePtr CreatInt32Imm(int32_t value) {
-  ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
-  return ValuePtrToAnfNodePtr(value_ptr);
-}
-
-std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (!prim) {
-    MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
-  }
-  std::string instance_name = prim->instance_name();
-  return HashInstanceName(instance_name);
-}
-
-std::string HashInstanceName(const std::string &name) {
-  auto using_hash_name = common::GetEnv(USING_HASH_NAME);
-  std::string instance_name;
-  if ((using_hash_name.empty()) || (using_hash_name == "on")) {
-    instance_name = HashName(name);
-  } else {
-    instance_name = name;
-  }
-  return instance_name;
-}
-
-Status GenerateGraph::Init(const CNodePtr &cnode) {
-  if (!cnode) {
-    MS_LOG(ERROR) << "Init: cnode is nullptr";
-    return FAILED;
-  }
-  cnode_ = cnode;
-  func_graph_ = cnode->func_graph();
-  if (!func_graph_) {
-    MS_LOG(ERROR) << "Init: func_graph_ is nullptr";
nullptr"; - return FAILED; - } - manager_ = func_graph_->manager(); - if (!manager_) { - MS_LOG(ERROR) << "Init:manager_ is nullptr"; - return FAILED; - } - scope_ = cnode_->scope(); - if (!scope_) { - MS_LOG(ERROR) << "Init:scope_ is nullptr"; - return FAILED; - } - virtual_input_node_ = std::make_shared(nullptr); - virtual_input_node_->set_scope(scope_); - instance_name_base_ = GetInstanceNameByCNode(cnode_); - name_idx_ = 0; - return SUCCESS; -} - -AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { - CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode - MS_EXCEPTION_IF_NULL(cnode); - cnode->set_scope(scope_); - if (inputs.size() < 2) { - MS_LOG(EXCEPTION) << "inputs.size() must be more than 1"; - } - (void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0] - auto new_anf_node_ptr = cnode->cast(); - MS_EXCEPTION_IF_NULL(new_anf_node_ptr); - return new_anf_node_ptr; -} - -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { - name_idx_++; - ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_)); - if (pyop_instance == nullptr) { - MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; - } - auto value_node = NewValueNode(pyop_instance); - return value_node->cast(); -} - -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { - name_idx_++; - OperatorAttrs attrs; - ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); - if (pyop_instance == nullptr) { - MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed"; - } - auto value_node = NewValueNode(pyop_instance); - return value_node->cast(); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h deleted file mode 100644 index 71227a6e7b..0000000000 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
-#define MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "./common.h"
-#include "optimizer/opt.h"
-#include "parallel/strategy.h"
-#include "parallel/tensor_layout/tensor_redistribution.h"
-
-namespace mindspore {
-namespace parallel {
-#define USING_HASH_NAME "USING_HASH_NAME"
-// Get the path where the operator has been defined
-std::string GetOpPythonPath(const OperatorName &op_name);
-
-// Init a Python operator instance
-ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name);
-
-AnfNodePtr CreatTypeInt(int32_t value);
-AnfNodePtr CreatInt32Imm(int32_t value);
-AnfNodePtr CreateInt32Tensor(int32_t value);
-AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
-std::string HashInstanceName(const std::string &name);
-
-class GenerateGraph {
- public:
-  GenerateGraph() : name_idx_(0) {}
-  Status Init(const CNodePtr &cnode);
-  ~GenerateGraph() = default;
-  AnfNodePtr virtual_input_node() { return virtual_input_node_; }
-  AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs);
-  AnfNodePtr NewOpInst(const OperatorName &op_name);
-  AnfNodePtr PushBack(const std::vector<AnfNodePtr> &inputs);
-
- private:
-  CNodePtr cnode_;
-  FuncGraphManagerPtr manager_;
-  ScopePtr scope_;
-  FuncGraphPtr func_graph_;
-  AnfNodePtr virtual_input_node_;
-  std::string instance_name_base_;
-  int64_t name_idx_;
-};
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_GRAPH_UTIL_GENERATE_GRAPH_H_
diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc
deleted file mode 100644
index 32cd106d8e..0000000000
--- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/graph_util/get_parallel_info.h"
-
-#include
-#include
-#include
-#include
-
-#include "common/utils.h"
-#include "ir/func_graph.h"
-#include "parallel/ops_info/operator_info.h"
-#include "parallel/graph_util/graph_info.h"
-#include "parallel/strategy.h"
-#include "parallel/tensor_layout/tensor_layout.h"
-
-namespace mindspore {
-namespace parallel {
-py::dict GetParameterLayout(const FuncGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  py::dict dict;
-  std::vector<AnfNodePtr> graph_params = graph->parameters();
-
-  for (auto para : graph_params) {
-    std::string name = std::static_pointer_cast<Parameter>(para)->name();
-    std::shared_ptr<TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout();
-    if (tensor_layout == nullptr) {
-      MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
-    } else {
-      auto device_arrangement = tensor_layout->device_arrangement().array();
-      auto tensor_map = tensor_layout->tensor_map().array();
-      auto slice_shape = tensor_layout->slice_shape().array();
-      std::vector<std::vector<int32_t>> layout = {device_arrangement, tensor_map, slice_shape};
-      dict[py::str(name)] = layout;
-      MS_LOG(INFO) << "GetParameterLayout name = " << name << ", layout " << tensor_layout->ToString();
-    }
-  }
-  return dict;
-}
-
-py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  py::dict dict;
-  auto ret = graph->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  auto nodes = DeepScopedGraphSearch(ret);
-
-  for (auto node : nodes) {
-    if (node->isa<CNode>()) {
-      auto cnode = node->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(cnode);
-      auto distributed_operation_info = cnode->operator_info();
-      if (distributed_operation_info != nullptr) {
-        auto strategyPtr = distributed_operation_info->strategy();
-        if (strategyPtr != nullptr) {
-          auto strategy = strategyPtr->GetInputDim();
-          auto name = cnode->fullname_with_scope();
-          dict[py::str(name)] = strategy;
-        }
-      }
-    }
-  }
-  return dict;
-}
-
-py::dict GetAllreduceFusion(const FuncGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  py::dict dict;
-  auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE);
-
-  for (auto prim : allreduce_prim_list) {
-    auto name_ptr = prim->GetAttr("parameter");
-    auto fusion_ptr = prim->GetAttr("fusion");
-    if (fusion_ptr == nullptr) {
-      MS_LOG(EXCEPTION) << "fusion_ptr is nullptr";
-    } else if (name_ptr == nullptr) {
-      continue;
-    }
-    if (!name_ptr->isa<StringImm>()) {
-      MS_LOG(EXCEPTION) << "name is not StringImm";
-    }
-    auto name = name_ptr->cast<StringImmPtr>()->value();
-    if (!fusion_ptr->isa<Int32Imm>()) {
-      MS_LOG(EXCEPTION) << "fusion is not Int32Imm";
-    }
-    int32_t fusion = fusion_ptr->cast<Int32ImmPtr>()->value();
-    dict[py::str(name)] = fusion;
-  }
-  return dict;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/parallel/graph_util/graph_info.cc
deleted file mode 100644
index 175413c0fd..0000000000
--- a/mindspore/ccsrc/parallel/graph_util/graph_info.cc
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/graph_util/graph_info.h"
-#include "debug/anf_ir_dump.h"
-#include "debug/anf_ir_utils.h"
-#include "debug/draw.h"
-#include "ir/func_graph.h"
-#include "utils/context/ms_context.h"
-#include "utils/graph_utils.h"
-
-namespace mindspore {
-namespace parallel {
-std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::string &name) {
-  AnfNodePtr ret = graph->get_return();
-  MS_EXCEPTION_IF_NULL(ret);
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
-  std::vector<PrimitivePtr> prim_list;
-  for (auto &node : all_nodes) {
-    if (!IsValueNode<Primitive>(node)) {
-      continue;
-    }
-    ValueNodePtr prim_node_anf = node->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(prim_node_anf);
-    PrimitivePtr node_prim = prim_node_anf->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(node_prim);
-    if (node_prim->name() == name) {
-      prim_list.emplace_back(node_prim);
-    }
-  }
-  return prim_list;
-}
-
-void DumpGraph(const FuncGraphPtr &root, const std::string &name) {
-  if (MsContext::GetInstance()->save_graphs_flag()) {
-    draw::Draw(name + ".dot", root);
-    DumpIR(name + ".ir", root);
-    ExportIR(name + ".dat", "0", root);
-  }
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc
deleted file mode 100644
index 1bc62f8807..0000000000
--- a/mindspore/ccsrc/parallel/graph_util/node_info.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/graph_util/node_info.h"
-
-#include
-
-#include "ir/anf.h"
-#include "ir/param_value.h"
-#include "pipeline/parse/python_adapter.h"
-
-namespace mindspore {
-namespace parallel {
-std::string ParameterName(const AnfNodePtr &node_ptr) {
-  auto para_ptr = node_ptr->cast<ParameterPtr>();
-  MS_EXCEPTION_IF_NULL(para_ptr);
-  return para_ptr->name();
-}
-
-bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
-  auto para_ptr = node_ptr->cast<ParameterPtr>();
-  if (para_ptr == nullptr) {
-    return false;
-  }
-  if (!para_ptr->has_default()) {
-    return false;
-  }
-  return para_ptr->default_param()->requires_grad();
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/group_manager.cc b/mindspore/ccsrc/parallel/group_manager.cc
deleted file mode 100644
index 1562cbc140..0000000000
--- a/mindspore/ccsrc/parallel/group_manager.cc
+++ /dev/null
@@ -1,178 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/group_manager.h"
-
-#include
-#include
-
-#include "parallel/device_manager.h"
-#include "parallel/ops_info/ops_utils.h"
-#include "utils/comm_manager.h"
-
-namespace mindspore {
-namespace parallel {
-Group::Group() {
-  name_.clear();
-  devices_.clear();
-}
-
-Status Group::Init(const std::string &name, const std::vector<Device> &devices) {
-  this->name_ = name;
-  this->devices_ = devices;
-  return Status::SUCCESS;
-}
-
-std::vector<Device> Group::GetDevicesList() const { return devices_; }
-
-bool Group::IsInThisGroup(int32_t device_rank) {
-  for (auto &device : devices_) {
-    if (device.rank() == device_rank) {
-      return true;
-    }
-  }
-  return false;
-}
-
-// Get the position of the device in the group
-Status Group::GetIndex(size_t *index) {
-  size_t pos = 0;
-  CheckGlobalDeviceManager();
-  int32_t rank = g_device_manager->global_rank();
-  for (auto &device : devices_) {
-    if (device.rank() == rank) {
-      *index = pos;
-      return Status::SUCCESS;
-    } else {
-      pos++;
-    }
-  }
-  MS_LOG(ERROR) << "Could not find device rank " << rank << " in this group!";
-  return Status::FAILED;
-}
-
-GroupManager::GroupManager() { groups_.clear(); }
-
-Status GroupManager::CreateGroup(const std::string &group_name, const std::vector<Device> &devices,
-                                 mindspore::parallel::Group *const group) {
-  // It is simple to use the size to determine whether it is the world group.
-  uint32_t world_size = 0;
-  if (world_group_ != NCCL_WORLD_GROUP) {
-    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
-  }
-
-  if ((world_group_ == NCCL_WORLD_GROUP) || (devices.size() == world_size)) {
-    auto it = groups_.find(world_group_);
-    if (it == groups_.end()) {
-      (void)group->Init(world_group_, devices);
-      groups_[world_group_] = *group;
-    } else {
-      *group = it->second;
-    }
-    MS_LOG(INFO) << "It is world group " << world_group_ << ", no need to create it.";
-    return Status::SUCCESS;
-  }
-
-  auto it = groups_.find(group_name);
-  // If there already exists a group with the desired 'name',
-  // let the pointer point to the group.
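-  // Illustrative usage (added for clarity, not in the original file): a first call
-  // such as CreateGroup("group_2_3", {dev2, dev3}, &g) caches the new Group and
-  // invokes CommManager::GetInstance().CreateGroupSync(); a second call with the
-  // same name only returns the cached entry below.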
-  if (it != groups_.end()) {
-    *group = it->second;
-    return Status::SUCCESS;
-  } else {
-    (void)group->Init(group_name, devices);
-    groups_[group_name] = *group;
-
-    std::vector<uint32_t> ranks;
-    (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
-                         [](const Device dev) { return (uint32_t)dev.rank(); });
-    // Create group through the CommManager interface
-    bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
-    if (!ret) {
-      MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
-      return Status::FAILED;
-    }
-
-    MS_LOG(INFO) << "Create group success, group name is " << group_name;
-    return Status::SUCCESS;
-  }
-}
-
-Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
-  std::string name = (*group).name();
-  auto it = groups_.find(name);
-  if (it == groups_.end()) {
-    MS_LOG(ERROR) << "Could not find group name :" << name;
-    return Status::FAILED;
-  }
-  (void)groups_.erase(it);
-  bool ret = CommManager::GetInstance().DestroyGroup(name);
-  if (!ret) {
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-Status GroupManager::DestroyAllGroups() {
-  for (auto &it : groups_) {
-    std::string name = it.first;
-    bool ret = CommManager::GetInstance().DestroyGroup(name);
-    if (!ret) {
-      return Status::FAILED;
-    }
-  }
-  groups_.clear();
-  return Status::SUCCESS;
-}
-
-Status GroupManager::GetRankID(const std::string &name, unsigned int *const rank_id) {
-  auto it = groups_.find(name);
-  if (it == groups_.end()) {
-    MS_LOG(ERROR) << "Could not find group name :" << name;
-    return Status::FAILED;
-  }
-  bool ret = CommManager::GetInstance().GetRankID(name, rank_id);
-  if (!ret) {
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-Status GroupManager::GetRankSize(const std::string &name, unsigned int *const rank_size) {
-  auto it = groups_.find(name);
-  if (it == groups_.end()) {
-    MS_LOG(ERROR) << "Could not find group name :" << name;
-    return Status::FAILED;
-  }
-  bool ret = CommManager::GetInstance().GetRankSize(name, rank_size);
-  if (!ret) {
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Group **group) {
-  auto it = groups_.find(name);
-  if (it == groups_.end()) {
-    return Status::FAILED;
-  }
-  *group = &it->second;
-  return Status::SUCCESS;
-}
-
-void GroupManager::Clear() { (void)DestroyAllGroups(); }
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/group_manager.h b/mindspore/ccsrc/parallel/group_manager.h
deleted file mode 100644
index f763d483cc..0000000000
--- a/mindspore/ccsrc/parallel/group_manager.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
-#define MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
-
-#include
-#include
-#include
-#include
-
-#include "parallel/device.h"
-#include "parallel/status.h"
-
-namespace mindspore {
-namespace parallel {
-constexpr char HCCL_WORLD_GROUP[] = "hccl_world_group";
-constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
-constexpr char UNDEFINED_WORLD_GROUP[] = "undefined_world_group";
-
-// Devices that need communication should be in the same group. These classes are used to
-// create and destroy groups among devices.
-class Group {
- public:
-  Group();
-  ~Group() = default;
-  Status Init(const std::string &name, const std::vector<Device> &devices);
-  std::vector<Device> GetDevicesList() const;
-  std::string name() const { return name_; }
-  bool IsInThisGroup(int32_t device_rank);
-  Status GetIndex(size_t *index);
-  size_t GetDevNum() const { return devices_.size(); }
-
- private:
-  std::string name_;
-  std::vector<Device> devices_;
-};
-
-class GroupManager {
- public:
-  GroupManager();
-  ~GroupManager() = default;
-
-  Status CreateGroup(const std::string &name, const std::vector<Device> &devices, Group *group);
-  Status DestroyGroup(Group *group);
-  Status DestroyAllGroups();
-  Status GetRankID(const std::string &name, unsigned int *rank_id);
-  Status GetRankSize(const std::string &name, unsigned int *rank_size);
-  Status FindGroup(const std::string &name, Group **group);
-  std::string world_group() const { return world_group_; }
-  void set_world_group(const std::string &name) { world_group_ = name; }
-  void Clear();
-
- private:
-  // the key is the group name (name_)
-  std::map<std::string, Group> groups_;
-  std::string world_group_;
-};
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_GROUP_MANAGER_H_
diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc
deleted file mode 100644
index 6b920f82ec..0000000000
--- a/mindspore/ccsrc/parallel/node_check.cc
+++ /dev/null
@@ -1,89 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/node_check.h"
-
-#include
-#include
-
-#include "parallel/ops_info/ops_utils.h"
-
-namespace mindspore {
-namespace parallel {
-const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
-                                          MAKE_TUPLE,
-                                          J,
-                                          LIST_GETITEM,
-                                          ARRAY_GETITEM,
-                                          TUPLE_SETITEM,
-                                          DEPEND,
-                                          LIST_SETITEM,
-                                          ARRAY_SETITEM,
-                                          DICT_GETITEM,
-                                          LIST_APPEND,
-                                          LIST_MAP,
-                                          LIST_REDUCE,
-                                          TUPLE_REVERSED,
-                                          TILE_SHAPE,
-                                          TUPLE_DIV,
-                                          TUPLE_TO_ARRAY,
-                                          MAKE_LIST,
-                                          MAKE_DICT,
-                                          MAKE_SLICE,
-                                          MAKE_RECORD,
-                                          STRING_EQUAL,
-                                          VIRTUALLOSS,
-                                          RETURN,
-                                          ENV_GETITEM,
-                                          IDENTITY,
-                                          PARTIAL,
-                                          ENVSETITEM,
-                                          ENVGETITEM,
-                                          ENVADD,
-                                          MAKEREFKEY,
-                                          MAKEREF,
-                                          GETREFKEY,
-                                          GETREFVALUE,
-                                          GETREFORIGIN,
-                                          DOT,
-                                          IM2COL,
-                                          COL2IM,
-                                          IM2COLV1,
-                                          STATESETITEM,
-                                          SCALARSUMMARY,
-                                          IMAGESUMMARY,
-                                          TENSORSUMMARY,
-                                          DEBUG,
-                                          HISTOGRAMSUMMARY,
-                                          COL2IMV1,
-                                          RESOLVE,
-                                          BROADCASTGRADIENTARGS,
-                                          INVERTPERMUTATION,
-                                          CONTROLDEPEND,
-                                          DROPOUT_GEN_MASK,
-                                          EMBED,
-                                          CREATINSTANCE,
-                                          ZEROSLIKE,
-                                          ASSIGN,
-                                          REF_TO_EMBED,
-                                          STOP_GRADIENT};
-
-bool IsInBlackList(const PrimitivePtr &prim) {
-  MS_EXCEPTION_IF_NULL(prim);
-  return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end());
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc
deleted file mode 100644
index 6bc33677a6..0000000000
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc
+++ /dev/null
@@ -1,705 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/ops_info/activation_info.h"
-
-#include
-#include
-#include
-#include
-
-#include "ir/value.h"
-#include "parallel/auto_parallel/costmodel.h"
-#include "parallel/device_matrix.h"
-#include "parallel/strategy.h"
-
-namespace mindspore {
-namespace parallel {
-Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
-  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
-    }
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-Status Activation::CheckStrategy(const StrategyPtr &strategy) {
-  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
-    } else {
-      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
-    }
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-Status ActivationInfo::GetAttrs() {
-  if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
-    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
-    return FAILED;
-  }
-
-  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
-                  << outputs_shape_.size() << ") is wrong.";
-    return FAILED;
-  }
-
-  auto iter = attrs_.find(ACTIVATION_TYPE);
-  if (iter != attrs_.end()) {
-    MS_EXCEPTION_IF_NULL(iter->second);
-    if (iter->second->isa<StringImm>()) {
-      std::string val = iter->second->cast<StringImmPtr>()->value();
-      if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) {
-        MS_LOG(ERROR) << name_ << " : Activation type is wrong.";
-        return FAILED;
-      }
-    } else {
-      MS_LOG(ERROR) << name_ << " : The value of activation_type is not a string.";
-      return FAILED;
-    }
-  }
-
-  return SUCCESS;
-}
-
-Status ActivationOther::GetAttrs() {
-  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
-                  << outputs_shape_.size() << ") is wrong.";
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-Status Activation::GenerateStrategies(int32_t stage_id) {
-  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
-                  << outputs_shape_.size() << ") is wrong.";
-    return FAILED;
-  }
-
-  is_auto_parallel_ = true;
-  Shape input0_split(inputs_shape_[0].size(), 1);
-  Shapes splittable_inputs = {input0_split};
-
-  std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
-    return FAILED;
-  }
-  size_t success = 0;
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == SUCCESS) {
-      success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
-      PrintStrategy(sp);
-    }
-  }
-  return SUCCESS;
-}
-
-Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
-  if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
-    } else {
-      MS_LOG(ERROR) << name_ << " : Invalid strategy.";
-    }
-    return FAILED;
-  }
-
-  std::vector<Dimensions> stra = strategy->GetInputDim();
-  Dimensions input_strategy = stra.at(0);
-
-  for (auto &element : axis_) {
-    int32_t axis_index = element;
-    if (element < 0) {
-      size_t input_dim = inputs_shape_.at(0).size();
-      axis_index = static_cast<int32_t>(input_dim) + element;
-    }
-
-    int32_t axis_strategy = input_strategy.at(IntToSize(axis_index));
-    // Dimension corresponding to axis is un-splittable
-    if (axis_strategy != MIN_SLICE_NUM) {
-      if (is_auto_parallel_) {
-        MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
-      } else {
-        MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
-      }
-      return FAILED;
-    }
-  }
-
-  return SUCCESS;
-}
-
-Status Softmax::GetAttrs() {
-  if (attrs_.size() < SOFTMAX_ATTR_SIZE) {
-    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
-    return FAILED;
-  }
-
-  auto iter = attrs_.find(AXIS);
-  if (iter != attrs_.end()) {
-    MS_EXCEPTION_IF_NULL(iter->second);
-    if (iter->second->isa<Int32Imm>()) {  // the axis is a number
-      int32_t axis_element = iter->second->cast<Int32ImmPtr>()->value();
-      axis_.push_back(axis_element);
-      MS_LOG(INFO) << name_ << " : The axis is int, value is " << axis_element;
-    } else if (iter->second->isa<ValueTuple>()) {  // the axis is a tuple
-      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
-      if (value_tuple == nullptr) {
-        MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr.";
-        return FAILED;
-      }
-      std::vector<ValuePtr> value_vector = value_tuple->value();
-      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
-                           [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
-      if (axis_.empty()) {
-        MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
-        return FAILED;
-      }
-      MS_LOG(INFO) << name_ << " : The axis is a tuple, value is " << ShapeToString(axis_);
-    } else {
-      MS_LOG(ERROR) << name_ << " : The value of axis is not an int or a tuple of int.";
-      return FAILED;
-    }
-  }
-
-  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
-    return FAILED;
-  }
-
-  // for example: tensor dimension is 4, then axis range [-4, 3]
-  int32_t dim = SizeToInt(inputs_shape_.at(0).size());
-  auto it =
-    std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); });
-  if (it != axis_.end()) {
-    MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "].";
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
-  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
-    }
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-Status Softmax::GenerateStrategies(int32_t stage_id) {
-  if (GetAttrs() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
-    return FAILED;
-  }
-  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
-    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
-    return FAILED;
-  }
-
-  is_auto_parallel_ = true;
-  Shape input0_split;
-  (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
-  for (auto &element : axis_) {
-    int32_t axis_index = element;
-    if (element < 0) {
-      size_t input_dim = inputs_shape_.at(0).size();
-      axis_index = static_cast<int32_t>(input_dim) + element;
-    }
-    input0_split[IntToSize(axis_index)] = 0;
-  }
-  Shapes splittable_inputs = {input0_split};
-
-  std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
-    return FAILED;
-  }
-  size_t success = 0;
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == SUCCESS) {
-      success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy.";
-      PrintStrategy(sp);
-    }
-  }
-  return SUCCESS;
-}
-
-Status ActivationBase::InferDevMatrixShape() {
-  std::vector<Dimensions> stra = strategy_->GetInputDim();
-  Dimensions input_strategy = stra.at(0);
-
-  dev_matrix_shape_ = input_strategy;
-
-  return SUCCESS;
-}
-
-Status ActivationBase::InferMirrorOps() {
-  mirror_ops_.clear();
-
-  Shape tensor_map = inputs_tensor_map_[0];
-  std::vector<Group> group;
-  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Create group failed.";
-    return FAILED;
-  }
-
-  OperatorVector mirror_op;
-  if (group.empty()) {
-    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
-    mirror_ops_.push_back(mirror_op);
-    std::string group_name = group[0].name();
-    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
-  }
-
-  return SUCCESS;
-}
-
-Status ActivationBase::InferForwardCommunication() {
-  // do nothing
-  return SUCCESS;
-}
-
-Status ActivationBase::InferTensorMap() {
-  std::vector<int32_t> tensor_map_index;
-  size_t size = inputs_shape_.at(0).size();
-  // such as 4: tensor_map_index [3,2,1,0]
-  for (size_t i = 0; i < size; ++i) {
-    tensor_map_index.push_back((int32_t)(size - i - 1));
-  }
-
-  inputs_tensor_map_.push_back(tensor_map_index);
-  outputs_tensor_map_.push_back(tensor_map_index);
-  return SUCCESS;
-}
-
-Status ActivationBase::InferTensorInfo() {
-  // infer tensor shape
-  Shape input_shape = inputs_shape_.at(0);
-
-  // infer slice shape
-  Shapes inputs_slice_shape, outputs_slice_shape;
-  Strategys inputs_strategy = strategy_->GetInputDim();
-  Strategys outputs_strategy = {inputs_strategy.at(0)};
-  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
-    return FAILED;
-  }
-  Shape input_slice_shape = inputs_slice_shape.at(0);
-
-  TensorLayout input_tensor_layout;
-  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
-    return FAILED;
-  }
-
-  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
-
-  inputs_tensor_info_.push_back(input_tensor_info);
-  outputs_tensor_info_.push_back(input_tensor_info);  // the same as input
-
-  return SUCCESS;
-}
-
-Status ActivationBase::Init(const StrategyPtr &strategy) {
-  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Init failed.";
-    return FAILED;
-  }
-
-  MS_LOG(INFO) << name_ << " : Init success.";
-  return SUCCESS;
-}
-
-Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
-  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
-    }
-    return FAILED;
-  }
-
-  MS_LOG(INFO) << name_ << " : Init for cost model success.";
name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status CastInfo::InferMirrorOps() { - mirror_ops_.clear(); - - Shape tensor_map = inputs_tensor_map_[0]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector mirror_op; - OperatorVector op_for_value; - if (group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - mirror_ops_.push_back(mirror_op); - mirror_ops_.push_back(op_for_value); - std::string group_name = group[0].name(); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; - } - - return SUCCESS; -} - -Status ExpandDimsInfo::GetAttrs() { - if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size(); - return FAILED; - } - - if (!input_value_.back()->isa()) { - MS_LOG(ERROR) << name_ << ": The type of axis is not int"; - return FAILED; - } - - int32_t axis = GetValue(input_value_.back()); - - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - int32_t dim = SizeToInt(inputs_shape_[0].size()); - if ((axis > dim) || (axis < -dim - 1)) { - MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim - 1 << ", " << dim << "]"; - return FAILED; - } - - if (axis < 0) { - positive_axis_ = dim + axis + 1; - } else { - positive_axis_ = axis; - } - MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_; - return SUCCESS; -} - -Status ExpandDimsInfo::InferTensorMap() { - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - // for example: if the dimension of input is 3, and the axis is 2, - // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0] - std::vector input_tensor_map, output_tensor_map; - size_t size = inputs_shape_[0].size(); - for (size_t i = 0; i < size; ++i) { - input_tensor_map.push_back(SizeToInt(size - i - 1)); - } - - inputs_tensor_map_.push_back(input_tensor_map); - - output_tensor_map = input_tensor_map; - if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(size))) { - MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; - return FAILED; - } - (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP); - outputs_tensor_map_.push_back(output_tensor_map); - - MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map) - << ", and the tensor map of output is " << ShapeToString(output_tensor_map); - return SUCCESS; -} - -Status ExpandDimsInfo::InferTensorStrategy() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - inputs_strategy_ = strategy_->GetInputDim(); - if (inputs_strategy_.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - - Shape output_strategy = inputs_strategy_[0]; - if ((positive_axis_ < 0) || (positive_axis_ > SizeToInt(output_strategy.size()))) { - MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; - return FAILED; - } - (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY); - outputs_strategy_ = {output_strategy}; - return SUCCESS; -} - -Status 
-Status ExpandDimsInfo::InferTensorInfo() {
-  if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  Shape input_shape = inputs_shape_[0];
-  Shape output_shape = outputs_shape_[0];
-
-  // infer slice shape
-  if (InferTensorStrategy() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer tensor strategy failed";
-    return FAILED;
-  }
-  Shapes inputs_slice_shape, outputs_slice_shape;
-  if (InferSliceShape(inputs_strategy_, outputs_strategy_, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
-    return FAILED;
-  }
-
-  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
-    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  Shape input_slice_shape = inputs_slice_shape[0];
-  Shape output_slice_shape = outputs_slice_shape[0];
-
-  TensorLayout input_tensor_layout, output_tensor_layout;
-  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
-    return FAILED;
-  }
-
-  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed";
-    return FAILED;
-  }
-
-  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
-  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
-
-  inputs_tensor_info_.push_back(input_tensor_info);
-  outputs_tensor_info_.push_back(output_tensor_info);
-  return SUCCESS;
-}
-
-Status ExpandDimsInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-
-  if (inputs_tensor_map_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
-    return FAILED;
-  }
-
-  std::vector<Group> group;
-  if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Create group failed";
-    return FAILED;
-  }
-
-  if (group.empty()) {
-    MS_LOG(INFO) << name_ << ": No need to create mirror ops";
-    return SUCCESS;
-  }
-
-  OperatorVector mirror_op, placeholder_op;
-  mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
-  mirror_ops_.push_back(mirror_op);
-  mirror_ops_.push_back(placeholder_op);
-  MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
-  return SUCCESS;
-}
-
-Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
-  std::vector<int32_t> axis;
-  auto axis_list = value_tuple->value();
-  if (inputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
-    return FAILED;
-  }
-  Shape input_shape = inputs_shape_.at(0);
-  size_t input_size = input_shape.size();
-  // if the axis tuple is empty, we should exclude the axes whose corresponding slice shape is 1.
-  if (axis_list.empty()) {
-    for (size_t i = 0; i < input_size; ++i) {
-      if (input_shape[i] == 1) {
-        axis.push_back(i);
-      }
-    }
-    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
-    return SUCCESS;
-  }
-
-  // convert negative axis to positive.
-  for (auto &dim : axis_list) {
-    if (!dim->isa<Int32Imm>()) {
-      MS_LOG(ERROR) << name_ << ": The type of axis is not int";
-      return FAILED;
-    }
-    int32_t dim_value = GetValue<int32_t>(dim);
-    int32_t positive_value = (dim_value < 0) ? (dim_value + SizeToInt(input_size)) : dim_value;
-    axis.push_back(positive_value);
-  }
-  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
-  return SUCCESS;
-}
-
-Status SqueezeInfo::GetAttrs() {
-  auto iter = attrs_.find(AXIS);
-  if (iter == attrs_.end()) {
-    MS_LOG(ERROR) << name_ << ": Can't find the axis attribute.";
-    return FAILED;
-  }
-  MS_EXCEPTION_IF_NULL(iter->second);
-  auto value_tuple = iter->second->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  InferAxis(value_tuple);
-  attrs_[AXIS] = axis_;
-  return SUCCESS;
-}
-
-Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) {
-  Attr attr = std::make_pair(AXIS, axis_);
-  OperatorAttrs attrs = {attr};
-  OperatorParams params;
-  OperatorArgs args = std::make_pair(attrs, params);
-  replace_op_ = {std::make_pair(SQUEEZE, args)};
-  return SUCCESS;
-}
-
-Status SqueezeInfo::InferTensorMap() {
-  // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
-  // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
-  std::vector<int32_t> input_tensor_map, output_tensor_map;
-  if (inputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
-    return FAILED;
-  }
-  size_t size = inputs_shape_[0].size();
-  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
-  for (size_t i = 0; i < size; ++i) {
-    size_t index = size - i - 1;
-    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
-    if (iter == axis.end()) {
-      output_tensor_map.push_back(SizeToInt(index));
-    }
-    input_tensor_map.push_back(SizeToInt(index));
-  }
-  inputs_tensor_map_.push_back(input_tensor_map);
-  outputs_tensor_map_.push_back(output_tensor_map);
-  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
-               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
-
-  return SUCCESS;
-}
-
-Status SqueezeInfo::InferTensorInfo() {
-  if (inputs_shape_.empty() || outputs_shape_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The shape of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  if (inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The tensor map of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  Shape input_shape = inputs_shape_[0];
-  Shape output_shape = outputs_shape_[0];
-
-  // infer slice shape
-  Shapes inputs_slice_shape, outputs_slice_shape;
-  Strategys inputs_strategy = strategy_->GetInputDim();
-  Dimensions output_strategy;
-  std::vector<int32_t> axis = GetValue<std::vector<int32_t>>(axis_);
-  for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
-    auto iter = std::find(axis.begin(), axis.end(), SizeToInt(i));
-    if (iter == axis.end()) {
-      output_strategy.push_back(inputs_strategy[0].at(i));
-    }
-  }
-  Strategys outputs_strategy = {output_strategy};
-  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer slice shape failed";
-    return FAILED;
-  }
-
-  if (inputs_slice_shape.empty() || outputs_slice_shape.empty()) {
-    MS_LOG(ERROR) << name_ << ": The slice shape of inputs or outputs is empty";
-    return FAILED;
-  }
-
-  Shape input_slice_shape = inputs_slice_shape[0];
-  Shape output_slice_shape = outputs_slice_shape[0];
-
-  // infer tensor layout
-  TensorLayout input_tensor_layout, output_tensor_layout;
-  if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Init tensor layout for input failed";
-    return FAILED;
-  }
-
-  if (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS) {
outputs_tensor_map_[0], output_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for output failed"; - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status SqueezeInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - } - - if (InferReplaceOps(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer replace ops failed"; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h deleted file mode 100644 index cd66bf8e8b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ /dev/null @@ -1,224 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ - -#include -#include -#include -#include -#include - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ActivationBase : public OperatorInfo { - public: - ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs, OperatorCostPtr cost) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} - ~ActivationBase() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - protected: - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; -}; - -class Activation : public ActivationBase { - public: - Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~Activation() override = default; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; -}; - -class ActivationInfo : public Activation { - public: - ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~ActivationInfo() override = 
default; - - protected: - Status GetAttrs() override; // activation_type: relu, relu6, sigmoid -}; - -class ActivationOther : public Activation { - public: - ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~ActivationOther() override = default; - - protected: - Status GetAttrs() override; -}; - -class GeluInfo : public ActivationOther { - public: - GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~GeluInfo() override = default; -}; - -class TanhInfo : public ActivationOther { - public: - TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~TanhInfo() override = default; -}; - -class Softmax : public ActivationBase { - public: - explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~Softmax() override = default; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - - private: - std::vector axis_; -}; - -class SoftmaxInfo : public Softmax { - public: - SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Softmax(name, inputs_shape, outputs_shape, attrs) {} - ~SoftmaxInfo() override = default; -}; - -class LogSoftmaxInfo : public Softmax { - public: - LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Softmax(name, inputs_shape, outputs_shape, attrs) {} - ~LogSoftmaxInfo() override = default; -}; - -class ReLUInfo : public ActivationOther { - public: - ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ReLUInfo() override = default; -}; - -class CastInfo : public ActivationOther { - public: - CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~CastInfo() override = default; - - protected: - Status InferMirrorOps() override; -}; - -class SqrtInfo : public ActivationOther { - public: - SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SqrtInfo() override = default; -}; - -class NegInfo : public ActivationOther { - public: - NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~NegInfo() override = default; -}; - -class ExpandDimsInfo : public ActivationOther { - public: - ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, 
outputs_shape, attrs) {} - ~ExpandDimsInfo() override = default; - - protected: - Status GetAttrs() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferMirrorOps() override; - Status InferTensorStrategy(); - - private: - int32_t positive_axis_ = -1; - Strategys inputs_strategy_; - Strategys outputs_strategy_; -}; - -class SqueezeInfo : public ActivationOther { - public: - SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SqueezeInfo() override = default; - - protected: - Status InferAxis(const ValueTuplePtr &value_tuple); - Status GetAttrs() override; - Status InferReplaceOps(const StrategyPtr &strategy); - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status Init(const StrategyPtr &strategy) override; - - private: - ValueTuplePtr axis_; -}; - -class SquareInfo : public ActivationOther { - public: - SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SquareInfo() override = default; -}; - -class SigmoidInfo : public ActivationOther { - public: - SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~SigmoidInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc deleted file mode 100644 index 02c26ea965..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.cc +++ /dev/null @@ -1,363 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
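[Editor's note, not part of the patch: the SqueezeInfo::InferTensorMap implementation above derives the input and output tensor maps from the squeeze axes. The following is a minimal, self-contained C++ sketch of that derivation; SqueezeTensorMaps is an illustrative name, not a function from this codebase, and it assumes axes are already normalized to non-negative values.]

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

std::pair<std::vector<int32_t>, std::vector<int32_t>> SqueezeTensorMaps(size_t rank,
                                                                        const std::vector<int32_t> &axis) {
  std::vector<int32_t> input_map, output_map;
  for (size_t i = 0; i < rank; ++i) {
    int32_t index = static_cast<int32_t>(rank - i - 1);  // highest dimension gets the largest map value
    if (std::find(axis.begin(), axis.end(), static_cast<int32_t>(i)) == axis.end()) {
      output_map.push_back(index);  // only non-squeezed dimensions survive in the output map
    }
    input_map.push_back(index);
  }
  return {input_map, output_map};
}

int main() {
  // input shape [32, 32, 1] with axis (2,): expect input map [2, 1, 0] and output map [2, 1],
  // matching the worked example in the source comment above
  auto maps = SqueezeTensorMaps(3, {2});
  for (int32_t v : maps.first) std::cout << v << ' ';
  std::cout << '\n';
  for (int32_t v : maps.second) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}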
- */ - -#include "parallel/ops_info/arithmetic_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Shape ExpendShape(const Shape &bigger_size_shape, Shape smaller_size_shape) { - size_t insert_num = bigger_size_shape.size() - smaller_size_shape.size(); - for (size_t num = 0; num < insert_num; ++num) { - (void)smaller_size_shape.insert(smaller_size_shape.begin(), 1); - } - return smaller_size_shape; -} - -Shapes ArithmeticBase::InferExpendShape() { - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shapes input_shapes; - size_t input_a_size = input_a_shape.size(); - size_t input_b_size = input_b_shape.size(); - if (input_a_size > input_b_size) { - input_shapes.push_back(input_a_shape); - input_shapes.push_back(ExpendShape(input_a_shape, input_b_shape)); - } else if (input_a_size < input_b_size) { - input_shapes.push_back(ExpendShape(input_b_shape, input_a_shape)); - input_shapes.push_back(input_b_shape); - } else { - input_shapes.push_back(input_a_shape); - input_shapes.push_back(input_b_shape); - } - return input_shapes; -} - -std::vector ExpendStrategy(const StrategyPtr &strategy) { - std::vector expend_strategy; - std::vector stra = strategy->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - size_t input_a_size = sub_a_strategy.size(); - size_t input_b_size = sub_b_strategy.size(); - if (input_a_size > input_b_size) { - expend_strategy.push_back(sub_a_strategy); - expend_strategy.push_back(ExpendShape(sub_a_strategy, sub_b_strategy)); - } else if (input_a_size < input_b_size) { - expend_strategy.push_back(ExpendShape(sub_b_strategy, sub_a_strategy)); - expend_strategy.push_back(sub_b_strategy); - } else { - expend_strategy = stra; - } - return expend_strategy; -} - -Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - Shapes input_shapes = InferExpendShape(); - std::vector expend_strategy = ExpendStrategy(strategy); - Dimensions sub_a_strategy = expend_strategy.at(0); - Dimensions sub_b_strategy = expend_strategy.at(1); - Shape input_a_shape = input_shapes.at(0); - Shape input_b_shape = input_shapes.at(1); - - for (size_t i = 0; i < input_a_shape.size(); ++i) { - if ((sub_a_strategy[i] != sub_b_strategy[i]) && (input_a_shape[i] != 1) && (input_b_shape[i] != 1)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - } - return SUCCESS; -} - -Status ArithmeticBase::InferDevMatrixShape() { - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_strategy = expend_strategy.at(0); - Dimensions sub_b_strategy = expend_strategy.at(1); - Shape dev_shape; - for (size_t i = 0; i < sub_a_strategy.size(); ++i) { - if (sub_a_strategy[i] != sub_b_strategy[i]) { - dev_shape.push_back(sub_a_strategy[i] * sub_b_strategy[i]); - } else { - dev_shape.push_back(sub_a_strategy[i]); - } - } - dev_matrix_shape_ = dev_shape; - - return SUCCESS; -} - -TensorMap SetExpendTensorMap(const Shape &strategy, const Shape &dev_matrix_shape) { - TensorMap 
tensor_map_index; - for (size_t i = 0; i < strategy.size(); ++i) { - if (strategy[i] == dev_matrix_shape[i]) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(strategy.size())) - i)); - } else { - tensor_map_index.push_back(-1); - } - } - return tensor_map_index; -} - -TensorMap SetTensorMap(const Shape &strategy_expend, const Shape &dev_matrix_shape, const Shape &strategy) { - TensorMap expend_map = SetExpendTensorMap(strategy_expend, dev_matrix_shape); - size_t dev_matrix_size = dev_matrix_shape.size(); - size_t strategy_size = strategy.size(); - if (dev_matrix_size != strategy_size) { - (void)expend_map.erase(expend_map.begin(), - expend_map.begin() + static_cast(dev_matrix_size - strategy_size)); - } - return expend_map; -} - -void ArithmeticBase::ReComputeBatchSplitFlagList() { - Shapes expend_shapes = InferExpendShape(); - Shape expend_a_shape = expend_shapes.at(0); - Shape expend_b_shape = expend_shapes.at(1); - if (expend_a_shape.size() != expend_b_shape.size()) { - MS_LOG(EXCEPTION) << name_ << " : Recompute batch split flag list is wrong."; - } - if (expend_a_shape.empty()) { - split_flag_list_[0] = false; - split_flag_list_[1] = false; - return; - } - (expend_a_shape.at(0) != 1) ? (split_flag_list_[0] = true) : (split_flag_list_[0] = false); - (expend_b_shape.at(0) != 1) ? (split_flag_list_[1] = true) : (split_flag_list_[1] = false); -} - -Status ArithmeticBase::InferTensorMap() { - std::vector tensor_map_index; - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_expend_strategy = expend_strategy.at(0); - Dimensions sub_b_expend_strategy = expend_strategy.at(1); - Strategys stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_expend_strategy.size())) - i)); - } - - Shape dev_shape; - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { - dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); - } else { - dev_shape.push_back(sub_a_expend_strategy[i]); - } - } - inputs_tensor_map_.push_back(SetTensorMap(sub_a_expend_strategy, dev_shape, sub_a_strategy)); - inputs_tensor_map_.push_back(SetTensorMap(sub_b_expend_strategy, dev_shape, sub_b_strategy)); - outputs_tensor_map_.push_back(tensor_map_index); - - return SUCCESS; -} - -Status ArithmeticBase::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_a_tensor_map = inputs_tensor_map_.at(0); - Shape input_b_tensor_map = inputs_tensor_map_.at(1); - std::vector input_a_group, input_b_group; - if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input a failed."; - return FAILED; - } - if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input b failed."; - return FAILED; - } - - OperatorVector op_for_input_a, op_for_input_b; - if (input_a_group.empty() && input_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror group is empty."; - return SUCCESS; - } - if (!input_a_group.empty()) { - op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); - } - if (!input_b_group.empty()) { - op_for_input_b = 
CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); - } - mirror_ops_.push_back(op_for_input_a); - mirror_ops_.push_back(op_for_input_b); - - return SUCCESS; -} - -Status ArithmeticBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, - const Shape &dev_matrix_array) { - if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); - TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); - TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); - Shape input_a_shape_array = inputs_shape_.at(0); - Shape input_b_shape_array = inputs_shape_.at(1); - Shape out_shape_array = outputs_shape_.at(0); - - TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; - if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; - return FAILED; - } - if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; - return FAILED; - } - if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; - return FAILED; - } - inputs_layout->push_back(input_a_tensor_layout); - inputs_layout->push_back(input_b_tensor_layout); - outputs_layout->push_back(out_tensor_layout); - - return SUCCESS; -} - -Status ArithmeticBase::InferTensorInfo() { - // infer tensor shape - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - std::vector expend_strategy = ExpendStrategy(strategy_); - Dimensions sub_a_expend_strategy = expend_strategy.at(0); - Dimensions sub_b_expend_strategy = expend_strategy.at(1); - Strategys inputs_strategy = strategy_->GetInputDim(); - Shape dev_shape; - for (size_t i = 0; i < sub_a_expend_strategy.size(); ++i) { - if (sub_a_expend_strategy[i] != sub_b_expend_strategy[i]) { - dev_shape.push_back(sub_a_expend_strategy[i] * sub_b_expend_strategy[i]); - } else { - dev_shape.push_back(sub_a_expend_strategy[i]); - } - } - Strategys outputs_strategy = {dev_shape}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_a_slice_shape = inputs_slice_shape.at(0); - Shape input_b_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; - return FAILED; - } - - TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); - TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); - TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a - 
inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b - outputs_tensor_info_.push_back(out_tensor_info); // output - - return SUCCESS; -} - -Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { - Shape input0_split(inputs_shape_[0].size(), 1); - Shape input1_split(inputs_shape_[1].size(), 1); - Shapes splittable_inputs = {input0_split, input1_split}; - - std::vector sp_vector; - is_auto_parallel_ = true; - if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies with broadcast failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; - - size_t success = 0; - for (auto &sp : sp_vector) { - PrintStrategy(sp); - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status ArithmeticBase::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h deleted file mode 100644 index 27caacc30c..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
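[Editor's note, not part of the patch: a standalone sketch of the broadcast handling that ExpendShape, ExpendStrategy, and InferDevMatrixShape implement in the file above. AlignRank and InferDevMatrix are illustrative names; Shape is assumed to be std::vector<int32_t> as elsewhere in this module.]

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;

// Left-pad the lower-rank operand with 1s, as ExpendShape does for both shapes and strategies.
Shape AlignRank(const Shape &bigger, Shape smaller) {
  size_t insert_num = bigger.size() - smaller.size();  // assumes bigger.size() >= smaller.size()
  smaller.insert(smaller.begin(), insert_num, 1);      // e.g. [64] aligned against [32, 64] becomes [1, 64]
  return smaller;
}

// Where the two aligned strategies differ (a broadcast dimension), the device-matrix
// dimension is their product; elsewhere they agree and either value can be taken.
Shape InferDevMatrix(const Shape &stra_a, const Shape &stra_b) {
  Shape dev_matrix;
  for (size_t i = 0; i < stra_a.size(); ++i) {
    dev_matrix.push_back(stra_a[i] == stra_b[i] ? stra_a[i] : stra_a[i] * stra_b[i]);
  }
  return dev_matrix;
}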
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
-#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_
-
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "ir/value.h"
-#include "parallel/auto_parallel/operator_costmodel.h"
-#include "parallel/ops_info/operator_info.h"
-#include "parallel/strategy.h"
-
-namespace mindspore {
-namespace parallel {
-class ArithmeticBase : public OperatorInfo {
- public:
-  ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                 const PrimitiveAttrs &attrs, OperatorCostPtr cost)
-      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
-  ~ArithmeticBase() override = default;
-  Status Init(const StrategyPtr &strategy) override;
-  Status InitForCostModel(const StrategyPtr &strategy) override;
-  Status GenerateStrategies(int32_t) override;
-  Status SetCostUnderStrategy(const StrategyPtr &) override;
-  void ReComputeBatchSplitFlagList() override;
-
- protected:
-  Status GetAttrs() override { return SUCCESS; }
-  Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override;
-  Status InferForwardCommunication() override { return SUCCESS; }
-  Status InferTensorInfo() override;
-  Status InferDevMatrixShape() override;
-  Status InferTensorMap() override;
-  Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array);
-  Shapes InferExpendShape();
-};
-
-class SubInfo : public ArithmeticBase {
- public:
-  SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SubCost>(false)) {}
-  ~SubInfo() override = default;
-};
-
-class TensorAddInfo : public ArithmeticBase {
- public:
-  TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-                const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorAddCost>(false)) {}
-  ~TensorAddInfo() override = default;
-};
-
-class MulInfo : public ArithmeticBase {
- public:
-  MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<MulCost>(true)) {}
-  ~MulInfo() override = default;
-};
-
-class DivInfo : public ArithmeticBase {
- public:
-  DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<DivCost>(true)) {}
-  ~DivInfo() override = default;
-};
-
-class RealDivInfo : public ArithmeticBase {
- public:
-  RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-              const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<RealDivCost>(true)) {}
-  ~RealDivInfo() override = default;
-};
-
-class FloorDivInfo : public ArithmeticBase {
- public:
-  FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
-               const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<FloorDivCost>(true)) {}
-  ~FloorDivInfo() override = default;
-};
-
-class PowInfo : public ArithmeticBase {
- public:
-  PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
-      : ArithmeticBase(name, inputs_shape,
outputs_shape, attrs, std::make_shared(true)) {} - ~PowInfo() override = default; -}; - -class GreaterInfo : public ArithmeticBase { - public: - GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~GreaterInfo() override = default; -}; - -class AssignSubInfo : public ArithmeticBase { - public: - AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~AssignSubInfo() override = default; -}; - -// All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. -class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { - public: - SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~SigmoidCrossEntropyWithLogitsInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ARITHMETIC_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc deleted file mode 100644 index dac3b0a675..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc +++ /dev/null @@ -1,235 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
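[Editor's note, not part of the patch: a sketch of the validity rule that ArithmeticBase::CheckStrategy enforces in the .cc file above. After both strategies are aligned to the same rank, they may differ in a dimension only where one operand's shape is 1, i.e. where broadcasting applies. BroadcastStrategyValid is an illustrative name; Shape is assumed to be std::vector<int32_t>.]

#include <cstdint>
#include <vector>

using Shape = std::vector<int32_t>;

bool BroadcastStrategyValid(const Shape &stra_a, const Shape &stra_b,
                            const Shape &shape_a, const Shape &shape_b) {
  for (size_t i = 0; i < shape_a.size(); ++i) {
    if ((stra_a[i] != stra_b[i]) && (shape_a[i] != 1) && (shape_b[i] != 1)) {
      return false;  // a genuinely shared dimension must be split identically on both inputs
    }
  }
  return true;
}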
- */ - -#include "parallel/ops_info/batch_parallel_info.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" - -namespace mindspore { -namespace parallel { -Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - int32_t stage = strategy->GetInputStage(); - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - dev_num_ = dev_num; - - size_t strategy_size = strategy->GetInputNumber(); - std::vector stra = strategy->GetInputDim(); - for (size_t i = 0; i < strategy_size; ++i) { - Shape sub_strategy = stra.at(i); - size_t strategy_len = sub_strategy.size(); - bool flag = false; - for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); - if (strategy_value > 1) { - if (flag || strategy_value != dev_num_) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : It is not a valid data parallel strategy."; - } else { - MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; - } - return FAILED; - } - flag = true; - } - } - } - return SUCCESS; -} - -Status BatchParallelInfo::InferDevMatrixShape() { - dev_matrix_shape_.push_back(dev_num_); - return SUCCESS; -} - -Status BatchParallelInfo::InferMirrorOps() { - mirror_ops_.clear(); - if (g_device_manager->DeviceNum() == 1) { - MS_LOG(INFO) << name_ << " : The device num is 1, no need to create mirror ops."; - return SUCCESS; - } - - MS_LOG(INFO) << name_ << " : Batch parallel input number " << strategy_->GetInputNumber(); - for (size_t i = 0; i < input_value_.size(); i++) { - MS_EXCEPTION_IF_NULL(g_device_manager); - OperatorVector op_vec = CreateMirrorOps(g_device_manager->world_group(), g_device_manager->DeviceNum()); - mirror_ops_.push_back(op_vec); - } - return SUCCESS; -} - -Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; } - -Status BatchParallelInfo::InferTensorMap() { - if (strategy_->GetInputDim()[0][0] != dev_num_) { - MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy."; - return FAILED; - } - for (size_t i = 0; i < inputs_shape_.size(); i++) { - std::vector tensor_map_index; - for (size_t j = 0; j < inputs_shape_[i].size(); ++j) { - if (strategy_->GetInputDim()[i][j] == dev_num_ && j == 0) { - tensor_map_index.push_back(0); - } else { - tensor_map_index.push_back(MAP_NONE); - } - } - inputs_tensor_map_.push_back(tensor_map_index); - } - for (size_t i = 0; i < outputs_shape_.size(); i++) { - std::vector tensor_map_index; - for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { - if (i == 0 && j == 0) { - tensor_map_index.push_back(0); - } else { - tensor_map_index.push_back(MAP_NONE); - } - } - outputs_tensor_map_.push_back(tensor_map_index); - } - return SUCCESS; -} - -Strategys BatchParallelInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - std::vector strategy; - for (size_t j = 0; j < outputs_shape_[i].size(); ++j) { - if (i == 0 && j == 0) { - strategy.push_back(dev_num_); - } else { - strategy.push_back(1); - } - } - outputs_strategy.push_back(strategy); - } - - return outputs_strategy; -} - -Status 
BatchParallelInfo::InferTensorInfo() { - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - MS_LOG(INFO) << name_ << " : The input size is " << strategy_->GetInputNumber(); - TensorLayout tensor_layout_in; - if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_in(tensor_layout_in); - inputs_tensor_info_.push_back(tensor_info_in); - } - for (size_t i = 0; i < outputs_shape_.size(); i++) { - TensorLayout tensor_layout_out; - if (tensor_layout_out.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(i), outputs_shape_.at(i)) != - SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_out(tensor_layout_out); - outputs_tensor_info_.push_back(tensor_info_out); - } - return SUCCESS; -} - -Status BatchParallelInfo::GetAttrs() { return SUCCESS; } - -Status BatchParallelInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { - CheckGlobalDeviceManager(); - is_auto_parallel_ = true; - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - StrategyPtr sp; - std::vector strategy; - for (size_t i = 0; i < inputs_shape_.size(); i++) { - Shape temp(inputs_shape_[i].size(), 1); - if (split_flag_list_[i]) { - temp[0] = SizeToInt(total_dev_num); - } - strategy.push_back(temp); - } - sp = std::make_shared(stage_id, strategy); - - if (SetCostUnderStrategy(sp) == SUCCESS) { - MS_LOG(INFO) << name_ << " : Successfully generated batch-parallel-strategy."; - PrintStrategy(sp); - } else { - MS_LOG(ERROR) << name_ << " : Generating batch-parallel-strategy failed."; - return FAILED; - } - return SUCCESS; -} - -void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); i++) { - split_flag_list_[i] = true; - } -} - -Status BatchParallelInfo::InferAsLossDivisor() { - as_loss_divisor_ = 1; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h deleted file mode 100644 index db6cb206d5..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
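[Editor's note, not part of the patch: a sketch of the pure data-parallel strategy that BatchParallelInfo::GenerateStrategies builds above. Every input gets an all-ones strategy except that dimension 0 of each batch-splittable input is sliced across the whole stage's device list. BuildBatchParallelStrategy is an illustrative name.]

#include <cstdint>
#include <vector>

using Dimensions = std::vector<int32_t>;

std::vector<Dimensions> BuildBatchParallelStrategy(const std::vector<Dimensions> &inputs_shape,
                                                   const std::vector<bool> &split_flag_list, int32_t dev_num) {
  std::vector<Dimensions> strategy;
  for (size_t i = 0; i < inputs_shape.size(); ++i) {
    Dimensions dims(inputs_shape[i].size(), 1);  // default: no split anywhere
    if (split_flag_list[i] && !dims.empty()) {
      dims[0] = dev_num;  // the batch dimension absorbs the whole device list
    }
    strategy.push_back(dims);
  }
  return strategy;
}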
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class BatchParallelInfo : public OperatorInfo { - public: - BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs, OperatorCostPtr cost) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} - BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), - dev_num_(1) {} - - ~BatchParallelInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - Status InferAsLossDivisor() override; - - private: - int32_t dev_num_; -}; - -class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { - public: - SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, - const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; - void ReComputeBatchSplitFlagList() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BATCH_PARALLEL_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc b/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc deleted file mode 100644 index 005edaf7c7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.cc +++ /dev/null @@ -1,261 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
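[Editor's note, not part of the patch: a sketch of the output-side tensor mapping in BatchParallelInfo::InferTensorMap above. With the 1-D device matrix [dev_num], only dimension 0 of the first output maps to a device axis; everything else is unmapped. The input side keys off whether that input's dim-0 strategy equals dev_num, but the resulting map has the same shape. kMapNone and BatchParallelTensorMap are illustrative names; MAP_NONE in the source is the unmapped marker, assumed here to be -1.]

#include <cstdint>
#include <vector>

constexpr int32_t kMapNone = -1;  // stand-in for the MAP_NONE constant used in the file above

std::vector<int32_t> BatchParallelTensorMap(size_t tensor_index, size_t rank) {
  std::vector<int32_t> tensor_map(rank, kMapNone);  // unmapped by default
  if (tensor_index == 0 && rank > 0) {
    tensor_map[0] = 0;  // only dim 0 of the first tensor maps onto the 1-D device matrix
  }
  return tensor_map;
}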
- */ - -#include "parallel/ops_info/bias_add_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - std::vector stra = strategy->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - Dimensions sub_b_strategy = stra.at(1); - int32_t channel_a_strategy = sub_a_strategy.at(1); - int32_t channel_b_strategy = sub_b_strategy.at(0); - if (channel_a_strategy != channel_b_strategy) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - return SUCCESS; -} - -Status BiasAddInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - dev_matrix_shape_ = sub_a_strategy; - return SUCCESS; -} - -void BiasAddInfo::ReComputeBatchSplitFlagList() { - split_flag_list_[0] = true; - split_flag_list_[1] = false; -} - -Status BiasAddInfo::InferTensorMap() { - TensorMap sub_a_tensor_map; - TensorMap sub_b_tensor_map; - std::vector stra = strategy_->GetInputDim(); - Dimensions sub_a_strategy = stra.at(0); - size_t sub_a_strategy_size = sub_a_strategy.size(); - for (size_t i = 0; i < sub_a_strategy_size; ++i) { - sub_a_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - i)); - } - sub_b_tensor_map.push_back((int32_t)(LAST_INDEX(SizeToUint(sub_a_strategy_size)) - 1)); - - inputs_tensor_map_.push_back(sub_a_tensor_map); - inputs_tensor_map_.push_back(sub_b_tensor_map); - outputs_tensor_map_.push_back(sub_a_tensor_map); - - return SUCCESS; -} - -Status BiasAddInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_a_tensor_map = inputs_tensor_map_.at(0); - Shape input_b_tensor_map = inputs_tensor_map_.at(1); - std::vector input_a_group, input_b_group; - if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input a failed."; - return FAILED; - } - if (CreateGroupByTensorMap(input_b_tensor_map, &input_b_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input b failed."; - return FAILED; - } - - OperatorVector op_for_input_a, op_for_input_b; - if (input_a_group.empty() && input_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror group is empty."; - return SUCCESS; - } - if (!input_a_group.empty()) { - op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name(); - } - if (!input_b_group.empty()) { - op_for_input_b = CreateMirrorOps(input_b_group[0].name(), input_b_group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input b success, group is " << input_b_group[0].name(); - } - mirror_ops_.push_back(op_for_input_a); - mirror_ops_.push_back(op_for_input_b); - - return SUCCESS; -} - -Status BiasAddInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, - const Shape &dev_matrix_array) { - if ((inputs_layout == nullptr) || (outputs_layout == 
nullptr)) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - TensorMap input_a_tensor_map_array = inputs_tensor_map_.at(0); - TensorMap input_b_tensor_map_array = inputs_tensor_map_.at(1); - TensorMap out_tensor_map_array = outputs_tensor_map_.at(0); - Shape input_a_shape_array = inputs_shape_.at(0); - Shape input_b_shape_array = inputs_shape_.at(1); - Shape out_shape_array = outputs_shape_.at(0); - - TensorLayout input_a_tensor_layout, input_b_tensor_layout, out_tensor_layout; - if (input_a_tensor_layout.InitFromVector(dev_matrix_array, input_a_tensor_map_array, input_a_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input a failed."; - return FAILED; - } - if (input_b_tensor_layout.InitFromVector(dev_matrix_array, input_b_tensor_map_array, input_b_shape_array) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for input b failed."; - return FAILED; - } - if (out_tensor_layout.InitFromVector(dev_matrix_array, out_tensor_map_array, out_shape_array) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create tensor layout for output failed."; - return FAILED; - } - inputs_layout->push_back(input_a_tensor_layout); - inputs_layout->push_back(input_b_tensor_layout); - outputs_layout->push_back(out_tensor_layout); - - return SUCCESS; -} - -Status BiasAddInfo::InferTensorInfo() { - // infer tensor shape - Shape input_a_shape = inputs_shape_.at(0); - Shape input_b_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_a_slice_shape = inputs_slice_shape.at(0); - Shape input_b_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout, dev_matrix_shape_) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer tensor layout failed."; - return FAILED; - } - - TensorInfo input_a_tensor_info(inputs_layout.at(0), input_a_shape, input_a_slice_shape); - TensorInfo input_b_tensor_info(inputs_layout.at(1), input_b_shape, input_b_slice_shape); - TensorInfo out_tensor_info(outputs_layout.at(0), output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_a_tensor_info); // inputs_a - inputs_tensor_info_.push_back(input_b_tensor_info); // inputs_b - outputs_tensor_info_.push_back(out_tensor_info); // output - - return SUCCESS; -} - -Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split, input0_split}; - - std::vector sp_vector; - is_auto_parallel_ = true; - Shapes tmp_inputs_shape = {inputs_shape_[0], inputs_shape_[0]}; - Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; - if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, 
tmp_splittable_inputs, &sp_vector) != - SUCCESS) { - return FAILED; - } - MS_LOG(INFO) << name_ << " : Generate strategies with broadcast success."; - - for (auto &sp : sp_vector) { - std::vector tmp_strategy; - Dimensions input0_strategy = sp->GetInputDim()[0]; - tmp_strategy.push_back(input0_strategy); // input0 - - Dimensions input1_strategy = {input0_strategy.at(1)}; - - // reset the strategy - tmp_strategy.push_back(input1_strategy); // input1 - sp->ResetInputs(tmp_strategy); - } - size_t success = 0; - for (auto &sp : sp_vector) { - PrintStrategy(sp); - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status BiasAddInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h deleted file mode 100644 index 37f555a258..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
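[Editor's note, not part of the patch: a sketch of the strategy pairing done in BiasAddInfo::GenerateStrategies above. Candidate strategies are generated for the data input alone, and the 1-D bias strategy is then forced to match the channel split (dimension 1) of input 0. PairBiasAddStrategy is an illustrative name.]

#include <cstdint>
#include <vector>

using Dimensions = std::vector<int32_t>;

std::vector<Dimensions> PairBiasAddStrategy(const Dimensions &input0_strategy) {
  Dimensions bias_strategy = {input0_strategy.at(1)};  // channel is dimension 1 of the data input
  return {input0_strategy, bias_strategy};             // {data strategy, bias strategy}
}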
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ - -#include - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class BiasAddInfo : public OperatorInfo { - public: - BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~BiasAddInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr &) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_BIAS_ADD_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h deleted file mode 100644 index 8dd2976b04..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
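[Editor's note, not part of the patch: a sketch of BiasAddInfo::InferTensorMap from the .cc file above. The data input maps every dimension to a device-matrix axis from highest to lowest, and the 1-D bias reuses the axis assigned to the channel dimension. BiasAddTensorMaps is an illustrative name.]

#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<int32_t>, std::vector<int32_t>> BiasAddTensorMaps(size_t rank) {
  std::vector<int32_t> data_map, bias_map;
  int32_t last_index = static_cast<int32_t>(rank) - 1;
  for (size_t i = 0; i < rank; ++i) {
    data_map.push_back(last_index - static_cast<int32_t>(i));  // [rank-1, ..., 1, 0]
  }
  bias_map.push_back(last_index - 1);  // the bias follows the channel dimension's axis
  return {data_map, bias_map};
}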
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class EqualInfo : public ArithmeticBase { - public: - EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~EqualInfo() override = default; -}; - -class NotEqualInfo : public ArithmeticBase { - public: - NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~NotEqualInfo() override = default; -}; - -class MaximumInfo : public ArithmeticBase { - public: - MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MaximumInfo() override = default; -}; - -class MinimumInfo : public ArithmeticBase { - public: - MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MinimumInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_COMPARISON_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc deleted file mode 100644 index e88868c772..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ /dev/null @@ -1,323 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
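[Editor's note, not part of the patch: a sketch of the seed handling described in the DropoutDoMask file below. When both DropoutGenMask seeds are 0 and the op performs repeated calculation, fresh monotonically increasing seeds are assigned so replicas do not all draw the same mask. g_seed_num mirrors the file-static SEED_NUM counter; PickDropoutSeeds is an illustrative name.]

#include <cstdint>
#include <utility>

static int32_t g_seed_num = 1;  // mirrors the file-static SEED_NUM counter in the source

std::pair<int32_t, int32_t> PickDropoutSeeds(int32_t seed0, int32_t seed1, int32_t repeated_calc_num) {
  if (seed0 == 0 && seed1 == 0 && repeated_calc_num > 1) {
    seed0 = g_seed_num;
    seed1 = g_seed_num;
    ++g_seed_num;  // the next DropoutGenMask gets a different pair
  }
  return {seed0, seed1};
}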
- */ - -#include "parallel/ops_info/dropout_do_mask_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "pipeline/resource.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -static int32_t SEED_NUM = 1; - -Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - if (stra.size() != 1) { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size() << ", it must be 1"; - return FAILED; - } - - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - // only check the input[0] - Shapes input_shape = {inputs_shape_[0]}; - if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy"; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; - } - return FAILED; - } - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferDevMatrixShape() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - std::vector strategy = strategy_->GetInputDim(); - if (strategy.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - - dev_matrix_shape_ = strategy[0]; - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferTensorMap() { - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - std::vector tensor_map_index; - size_t size = inputs_shape_[0].size(); - // if the dimension of input is 4, and tensor_map_index is [3, 2, 1, 0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back(SizeToInt(size - i - 1)); - } - - // the input[1] do not need tensor map - inputs_tensor_map_.push_back(tensor_map_index); // input_0 - outputs_tensor_map_.push_back(tensor_map_index); // output - return SUCCESS; -} - -Status DropoutDoMaskInfo::InferTensorInfo() { - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": Invalid inputs shape size " << inputs_shape_.size(); - return FAILED; - } - - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - - Shape input_0_shape = inputs_shape_[0]; - - if (inputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty"; - return FAILED; - } - - TensorLayout input_0_tensor_layout; - if (input_0_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_0_shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout failed"; - return FAILED; - } - - TensorInfo input_0_tensor_info(input_0_tensor_layout); - - // input_1 do not need tensor info - inputs_tensor_info_.push_back(input_0_tensor_info); // input_0 - outputs_tensor_info_.push_back(input_0_tensor_info); // output - return SUCCESS; -} - -Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) 
{ - if (inputs_shape_.empty()) { - MS_LOG(ERROR) << name_ << ": The inputs shape is empty"; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - Shapes used_inputs_shape = {inputs_shape_[0]}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, used_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate strategies failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -std::shared_ptr>> DropoutDoMaskInfo::GenerateBatchStrategies() { - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - Dimensions strategy(inputs_shape_[0].size() - 1, 1); - (void)strategy.insert(strategy.begin(), SizeToInt(dev_num)); - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); -} - -Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - if (!dropout_gen_mask->isa()) { - MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode"; - } - - auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; - } - if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { - MS_LOG(EXCEPTION) << "The input[0] of dropout gen mask cnode is not primitive"; - } - - ValueNodePtr value_node = dropout_gen_mask_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value_node); - PrimitivePtr prim = value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() != DROPOUT_GEN_MASK) { - MS_LOG(EXCEPTION) << "The primitive name is not DropoutGenMask"; - } - return prim; -} - -void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - - AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX); - MS_EXCEPTION_IF_NULL(dropout_gen_mask); - if (!dropout_gen_mask->isa()) { - MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << 
DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
-  }
-
-  auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
-  if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
-    MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
-  }
-
-  if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
-    MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
-  }
-
-  FuncGraphPtr func_graph = cnode->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  FuncGraphManagerPtr manager = func_graph->manager();
-  if (manager == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
-  }
-
-  ValuePtr new_shape = MakeValue(input_slice_shape);
-  AnfNodePtr val = NewValueNode(new_shape);
-  (void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
-}
-
-// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
-// split. Find the DropoutGenMask node in the anf graph according to the DropoutDoMask node, and modify the input
-// shape of DropoutGenMask according to the strategy of DropoutDoMask. When DropoutDoMask performs repeated
-// calculation and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
-std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
-  std::vector<Operator> replace_ops;
-  MS_EXCEPTION_IF_NULL(cnode);
-  PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
-  MS_EXCEPTION_IF_NULL(prim);
-
-  if (inputs_tensor_info_.empty()) {
-    MS_LOG(EXCEPTION) << "The tensor info of dropout do mask is empty";
-  }
-
-  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
-    MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
-  }
-
-  if (!cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX)->isa<ValueNode>()) {
-    MS_LOG(EXCEPTION) << "The keep prob of dropout do mask is not a value node";
-  }
-
-  ValuePtr keep_prob = GetValueNode(cnode->input(DROPOUT_DO_MASK_KEEP_PROB_INDEX));
-  MS_EXCEPTION_IF_NULL(keep_prob);
-  auto attr = prim->attrs();
-  if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
-    MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must have seed0 and seed1";
-  }
-
-  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
-  int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
-  int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
-  if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
-    seed_0 = SEED_NUM;
-    seed_1 = SEED_NUM;
-    SEED_NUM++;
-  } else {
-    SetGenMaskShape(cnode, input_slice_shape);
-    MS_LOG(DEBUG) << "The input slice shape of dropout is " << ShapeToString(input_slice_shape);
-    return replace_ops;
-  }
-
-  ValuePtr new_shape = MakeValue(input_slice_shape);
-  Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
-  Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
-  OperatorAttrs attrs = {attr_0, attr_1};
-  Attr param_0 = std::make_pair(SHAPE, new_shape);
-  Attr param_1 = std::make_pair(KEEP_PROB, keep_prob);
-  OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
-  OperatorArgs args = std::make_pair(attrs, params);
-  Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
-  replace_ops.push_back(replace_op);
-  return replace_ops;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h
deleted file mode
100644 index c51a0a9513..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class DropoutDoMaskInfo : public OperatorInfo { - public: - DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~DropoutDoMaskInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; - std::vector GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorMap() override; - Status GetAttrs() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; -}; - -using DropoutDoMaskInfoPtr = std::shared_ptr; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_DROPOUT_DO_MASK_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h deleted file mode 100644 index 2172c5cd89..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ - -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ExpInfo : public ActivationOther { - public: - ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ExpInfo() override = default; -}; - -class LogInfo : public ActivationOther { - public: - LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~LogInfo() override = default; -}; - -class CosInfo : public ActivationOther { - public: - CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~CosInfo() override = default; -}; - -class ACosInfo : public ActivationOther { - public: - ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~ACosInfo() override = default; -}; - -class LogicalNotInfo : public ActivationOther { - public: - LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} - ~LogicalNotInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ELEMENTARY_FUNCTION_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc deleted file mode 100644 index 078be08128..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc +++ /dev/null @@ -1,350 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/ops_info/gather_v2_info.h" - -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/strategy.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status GatherV2Info::GetAttrs() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) { - MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size(); - return FAILED; - } - // the second input is the index tensor - - // the third input is the axis, is a ValueNode - if (input_value_.at(2) == nullptr) { - MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; - return FAILED; - } - - if (inputs_shape_.at(0).size() == 0) { - MS_LOG(ERROR) << name_ << ": input can not be a scalar!"; - return FAILED; - } - int axis = GetValue(input_value_.at(2)); - if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) { - MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", " - << inputs_shape_.at(0).size() << ")."; - } - if (axis < 0) { - axis += SizeToInt(inputs_shape_[0].size()); - } - axis_ = axis; - - index_size_ = inputs_shape_.at(1).size(); - - return SUCCESS; -} - -Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - // Only strategy of the first input should be set. - if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - axis_strategy_ = strategy->GetInputDim().at(0).at(axis_); - if (index_size_ != 1 && axis_strategy_ != 1) { - MS_LOG(ERROR) << name_ - << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy " - "corresponding to axis must be 1, but is " - << axis_strategy_; - return FAILED; - } - if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) { - MS_LOG(ERROR) << name_ - << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to " - "axis. 
The first dimension of index is " - << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_; - return FAILED; - } - return SUCCESS; -} - -Status GatherV2Info::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - dev_matrix_shape_ = stra.at(0); - return SUCCESS; -} - -// If index is a scalar, output dimension is input dimension minus 1; -// If index is a n dimension tensor, output dimension is input dimension plus (n - 1). -// Tensor map dimension is equal to the corresponding input and output dimension. -// If index's dimension is more than 1, we insert -1 for the output tensor map. -Status GatherV2Info::InferTensorMap() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - std::vector tensor_map_in; - std::vector tensor_map_out; - size_t size = inputs_shape_.at(0).size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_in.push_back(SizeToInt(size - i - 1)); - tensor_map_out.push_back(SizeToInt(size - i - 1)); - } - - if (index_size_ == 0) { - (void)tensor_map_out.erase(tensor_map_out.begin() + axis_); - } else if (index_size_ > 1) { - (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1); - } - if (tensor_map_out.size() != outputs_shape_.at(0).size()) { - MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size() - << " output size is " << outputs_shape_.at(0).size(); - return FAILED; - } - - std::vector tensor_map_in_index; - if (index_size_ >= 1) { - tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1)); - } - for (size_t i = 1; i < index_size_; ++i) { - tensor_map_in_index.push_back(-1); - } - inputs_tensor_map_.emplace_back(std::move(tensor_map_in)); - inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index)); - outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); - return SUCCESS; -} - -Status GatherV2Info::InferTensorInfo() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_shape_.size(); - return FAILED; - } - if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_tensor_map_.size(); - return FAILED; - } - if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is " - << outputs_tensor_map_.size(); - return FAILED; - } - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape input_index_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - - TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || - 
(input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo input_index_info(input_index_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(input_index_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -OperatorVector CreateSubOp(int32_t sub_value) { - OperatorVector ops; - OperatorName operator_name = SUB; - OperatorAttrs operator_attrs; - - std::vector tensor_data = {sub_value}; - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, kInt32); - ValuePtr op_param_value = MakeValue(tensor_ptr); - - Attr op1_param = std::make_pair("", op_param_value); - OperatorParams operator_param = {std::make_pair(op1_param, 2)}; - - OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); - Operator op = std::make_pair(operator_name, operator_args); - ops.push_back(op); - return ops; -} - -Status GatherV2Info::InferTensorSubOps() { - sub_ops_.clear(); - if ((index_size_ == 0) || (axis_strategy_ == 1)) { - return SUCCESS; - } - int32_t mod_n = 1; - for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) { - mod_n *= dev_matrix_shape_.at(i); - } - if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) { - MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ")."; - } - int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_); - int32_t rank = g_device_manager->global_rank(); - int32_t mod_rank = rank % mod_p; - mod_rank = static_cast(mod_rank / mod_n); - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - return FAILED; - } - if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) { - MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ")."; - } - int32_t sub_value = static_cast(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank; - - OperatorVector sub_op; - sub_ops_.emplace_back(std::move(sub_op)); - sub_op = CreateSubOp(sub_value); - sub_ops_.emplace_back(std::move(sub_op)); - return SUCCESS; -} - -Status GatherV2Info::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - Status status = InferTensorSubOps(); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed."; - return status; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status GatherV2Info::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or 
outputs shape size(" - << outputs_shape_.size() << "is wrong."; - return FAILED; - } - - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != - SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -std::shared_ptr>> GatherV2Info::GenerateBatchStrategies() { - if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { - MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " - << inputs_shape_.size(); - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - if (GetAttrs() != SUCCESS) { - MS_LOG(EXCEPTION) << "GetAttrs failed!"; - } - - Dimensions strategy; - if (index_size_ != 1) { - strategy.push_back(1); - } else { - strategy.push_back(SizeToInt(dev_num)); - } - for (size_t i = 1; i < inputs_shape_[0].size(); i++) { - strategy.push_back(1); - } - std::vector strategy_v = {strategy}; - return std::make_shared>>(strategy_v); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h deleted file mode 100644 index f7aeb6a0d9..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -constexpr size_t GATHER_V2_INPUTS_SIZE = 2; -constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1; -constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; -// We now supported limited parallel strategies. -// If the strategy corresponding to axis is more than 1, index must be evenly distributed across the axis-dimension of -// the input. -// If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. 
-class GatherV2Info : public OperatorInfo { - public: - GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), - axis_(-1), - index_size_(0), - axis_strategy_(1) {} - ~GatherV2Info() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - - private: - Status InferTensorSubOps(); - - int32_t axis_; - size_t index_size_; - int32_t axis_strategy_; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc deleted file mode 100644 index 680d6f3ed6..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc +++ /dev/null @@ -1,636 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/gather_v2_p_info.h" - -#include -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" - -namespace mindspore { -namespace parallel { -Status GatherV2PInfo::GetAttrs() { - // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis. 
- if (target_ != CPU) { - if (input_value_.at(2) == nullptr) { - MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; - return FAILED; - } - auto axis = GetValue(input_value_.at(2)); - // if axis is negative then convert it to positive - auto params_shape = inputs_shape_.at(0); - if (params_shape.size() == 0) { - MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; - return FAILED; - } - if (axis < 0) { - axis += SizeToInt(inputs_shape_[0].size()); - } - axis_ = axis; - } - - auto target_iter = attrs_.find(TARGET); - if (target_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(target_iter->second); - if (target_iter->second->isa()) { - target_ = target_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of target is not a string."; - } - } - auto manual_split_iter = attrs_.find("manual_split"); - if (manual_split_iter != attrs_.end()) { - param_split_shapes_.clear(); - manual_split_ = true; - auto var = manual_split_iter->second->cast(); - MS_LOG(DEBUG) << "Extract manual split strategy " << manual_split_iter->second->ToString(); - - if (var->size() > 0) { - std::vector elements = var->value(); - for (auto &ele : elements) { - if (ele->isa()) { - auto value_tuple = ele->cast(); - std::vector value_vector = value_tuple->value(); - if (value_vector.size() != 2) { - MS_LOG(ERROR) << "Failure: Size of manual_split element must be 2."; - return FAILED; - } - param_split_shapes_.push_back(static_cast(GetValue(value_vector[0]))); - index_offsets_.push_back(static_cast(GetValue(value_vector[1]))); - } else { - MS_LOG(ERROR) << "Failure: Manual split strategy's format is wrong! Need ValueSequeue"; - return FAILED; - } - } - - if (param_split_shapes_.empty()) { - MS_LOG(ERROR) << "Failed to extract param split strategy."; - return FAILED; - } - } - } - - return SUCCESS; -} - -Status GatherV2PInfo::CheckManualSplit() { - auto param_shape = inputs_shape_.at(0); - int32_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, - [](int32_t s, int32_t shape) { return s + shape; }); - if (split_shape_sum < param_shape.at(0)) { - MS_LOG(ERROR) << "Failure: Sum of splited shapes should not be smaller than param_shape."; - return FAILED; - } - - if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int32_t &offset) { return offset < 0; })) { - MS_LOG(ERROR) << "Failure: Index offset must not less than 0."; - return FAILED; - } - - return SUCCESS; -} - -Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - // param slice shape need 32Byte aligned - auto param_shape = inputs_shape_.at(0); - auto param_strategy = strategy->GetInputDim().at(0); - auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1); - if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) { - MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned."; - return FAILED; - } - - // only support 1-dim and 2-dim param - if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { - MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); - return FAILED; - } - - // don't support scalar index - if (inputs_shape_.at(1).size() == 0) { - 
MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
-    return FAILED;
-  }
-
-  // axis=0, index_shape(0)%param_strategy(0) must be 0
-  Shape index_shape = inputs_shape_.at(1);
-  if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
-  }
-
-  if (manual_split_) {
-    if (CheckManualSplit() != SUCCESS) {
-      return FAILED;
-    }
-    // when using manual_split, the remaining checks are unnecessary.
-    return SUCCESS;
-  }
-
-  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
-  if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
-    return FAILED;
-  }
-
-  // param_strategy(axis) != 1, the index can't be split
-  auto index_strategy = strategy->GetInputDim().at(1);
-  auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
-  if ((param_strategy.at(IntToSize(axis_)) != 1) && (product_i != 1)) {
-    MS_LOG(DEBUG) << name_ << ": param is split at dim (axis) " << axis_ << ", index can't be split.";
-    return FAILED;
-  }
-
-  // param_strategy(axis) != 1, don't support repeated calc
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
-  if (IntToSize(product_p) != dev_num && param_strategy.at(IntToSize(axis_)) != 1) {
-    MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
-    return FAILED;
-  }
-
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::InferMirrorOps() {
-  // There are no mirror operators for manual split
-  if (manual_split_) {
-    return SUCCESS;
-  }
-
-  mirror_ops_.clear();
-  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
-  std::vector<Group> input_a_group;
-  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
-    return FAILED;
-  }
-
-  OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
-  if (input_a_group.empty()) {
-    MS_LOG(INFO) << name_ << " : The mirror group is empty.";
-    return SUCCESS;
-  } else {
-    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
-    MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
-  }
-
-  mirror_ops_.push_back(op_for_input_a);
-  mirror_ops_.push_back(op_for_input_b);
-  mirror_ops_.push_back(op_for_axis);
-
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::InferDevMatrixShape() {
-  dev_matrix_shape_.clear();
-  out_dev_matrix_shape_.clear();
-  // infer input dev_matrix_shape
-  auto param_strategy = strategy_->GetInputDim().at(0);
-  auto index_strategy = strategy_->GetInputDim().at(1);
-
-  if (manual_split_) {
-    dev_matrix_shape_ = param_strategy;
-    out_dev_matrix_shape_ = dev_matrix_shape_;
-    return SUCCESS;
-  }
-
-  dev_matrix_shape_ = param_strategy;
-
-  // param_strategy(axis) != 1
-  if (param_strategy.at(IntToSize(axis_)) != 1) {
-    std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end());
-  } else {
-    dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
-  }
-
-  // infer out dev_matrix_shape
-  // axis != 0, split axis
-  if (axis_ != 0 && param_strategy.at(IntToSize(axis_)) != 1) {
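    // Editor's note: an illustrative trace of the branch below (hypothetical
    // numbers, not from the original sources). With param_strategy = {2, 4}
    // and axis = 1, dim 0 is merged with the axis dim: the code pushes
    // 2 * 4 = 8 first and then 1 at the axis position, giving
    // out_dev_matrix_shape_ = {8, 1}, so the output's first dimension maps
    // onto all 8 devices while the split axis disappears.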
out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(IntToSize(axis_))); - for (size_t i = 1; i < param_strategy.size(); ++i) { - if (i == IntToSize(axis_)) { - out_dev_matrix_shape_.push_back(1); - } else { - out_dev_matrix_shape_.push_back(param_strategy.at(i)); - } - } - } else { - out_dev_matrix_shape_ = dev_matrix_shape_; - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); - auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); - if (param_product * index_product < SizeToInt(dev_num)) { - out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product))); - } - - return SUCCESS; -} - -Status GatherV2PInfo::InferTensorMap() { - if (manual_split_) { - inputs_tensor_map_.push_back({1, 0}); - inputs_tensor_map_.push_back({-1, 1}); - outputs_tensor_map_.push_back({-1, 1, 0}); - return SUCCESS; - } - // infer input tensor map - // param_strategy(axis) != 1 - size_t param_size = inputs_shape_.at(0).size(); - size_t index_size = inputs_shape_.at(1).size(); - size_t total_size = param_size + index_size; - std::vector tensor_map_index; - std::vector tensor_map_params; - auto param_strategy = strategy_->GetInputDim().at(0); - if (param_strategy.at(IntToSize(axis_)) != 1) { - tensor_map_index.insert(tensor_map_index.begin(), index_size, -1); - for (size_t i = 0; i < param_size; ++i) { - tensor_map_params.push_back(SizeToInt(i)); - } - } else { - // param_strategy(axis) == 1 - for (size_t i = 0; i < param_size; ++i) { - tensor_map_params.push_back(SizeToInt(total_size - i - 1)); - } - for (size_t i = 0; i < index_size; ++i) { - tensor_map_index.push_back(SizeToInt(index_size - i - 1)); - } - } - - // infer output tensor map - std::vector tensor_map_out; - if (param_strategy.at(IntToSize(axis_)) == 1) { - // param_strategy(axis) == 1 - for (size_t i = 0; i < param_size; ++i) { - if (i == IntToSize(axis_)) { - for (size_t j = 0; j < index_size; ++j) { - tensor_map_out.push_back(SizeToInt(index_size - j - 1)); - } - } else { - tensor_map_out.push_back(SizeToInt(total_size - i - 1)); - } - } - } else { - // param_strategy(axis) != 1 - if (axis_ == 0) { - tensor_map_out.insert(tensor_map_out.end(), 0); - tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); - for (size_t i = 1; i < param_size; ++i) { - tensor_map_out.push_back(i); - } - } else { - for (size_t i = 0; i < param_size; ++i) { - if (i == IntToSize(axis_)) { - tensor_map_out.insert(tensor_map_out.end(), index_size, -1); - } else { - tensor_map_out.push_back(SizeToInt(param_size - i - 1)); - } - } - } - } - - inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); - inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); - outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); - return SUCCESS; -} - -Status GatherV2PInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape input_index_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - int32_t rank = g_device_manager->global_rank(); - // infer tensor layout - TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; - if (manual_split_) { - input_shape[0] = param_split_shapes_[rank / dev_matrix_shape_[1]]; - input_shape[0] = input_shape[0] * dev_matrix_shape_[0]; - } - if 
((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || - (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != - SUCCESS)) { - return FAILED; - } - // infer tensor info - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo input_index_info(input_index_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - Shape slice_shape = input_tensor_info.slice_shape(); - MS_LOG(DEBUG) << "The fake slice shape is: " << ShapeToString(slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(input_index_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status GatherV2PInfo::InferBias() { - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - auto input_shape = inputs_shape_.at(0); - auto params_strategy = strategy_->GetInputDim().at(0); - // axis don't split - if (params_strategy.at(axis_) == 1) { - bias_ = 0; - return SUCCESS; - } - // params_size=1, axis=0 - if ((input_shape.size() == 1) && (axis_ == 0)) { - slice_size_ = input_shape.at(0) / params_strategy.at(0); - bias_ = rank * slice_size_; - return SUCCESS; - } - // params_size=2, axis=0 - if ((input_shape.size() == 2) && (axis_ == 0)) { - slice_size_ = input_shape.at(0) / params_strategy.at(0); - bias_ = rank / params_strategy.at(1) * slice_size_; - return SUCCESS; - } - // params_size=2, axis=1 - if ((input_shape.size() == 2) && (axis_ == 1)) { - slice_size_ = input_shape.at(1) / params_strategy.at(1); - bias_ = rank % params_strategy.at(1) * slice_size_; - return SUCCESS; - } - MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; - return FAILED; -} - -Status GatherV2PInfo::InferOffset() { - CheckGlobalDeviceManager(); - size_t rank = g_device_manager->global_rank(); - if (rank < index_offsets_.size()) { - index_offset_ = index_offsets_.at(rank); - MS_LOG(DEBUG) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_; - return SUCCESS; - } - - MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size(); - return FAILED; -} - -Status GatherV2PInfo::InferGroup() { - auto param_strategy = strategy_->GetInputDim().at(0); - size_t dim = IntToSize(axis_); - if (param_strategy.at(IntToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { - dim = (axis_ + 1) % 2; - } - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - int32_t rank = g_device_manager->global_rank(); - RankList dev_list = g_device_manager->GetDeviceListByStageId(0); - DeviceMatrix dev_matrix(rank, dev_list, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(dim), &group_devices) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Create group failed."; - return FAILED; - } - if (group_devices.size() == 1) { - MS_LOG(INFO) << "the group is empty"; - return SUCCESS; - } - - group_ = g_device_manager->CreateGroup(group_devices); - return SUCCESS; -} - -std::vector GetRankFromGroup(const Group &group) { - std::vector rank_list; - auto device_list = group.GetDevicesList(); - for (auto &device : device_list) { - rank_list.insert(rank_list.end(), device.rank() % 8); - } - return rank_list; -} - -Status GatherV2PInfo::InferForwardCommunication() { - forward_op_.clear(); - auto 
param_strategy = strategy_->GetInputDim().at(0); - // don't split axis or target is not CPU, no need forward communication - if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { - return SUCCESS; - } - // split axis - OperatorName operator_name; - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - Attr attr_group; - operator_name = REDUCE_SCATTER; - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - attr_group = std::make_pair(GROUP, MakeValue(group_.name())); - Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); - OperatorAttrs attrs = {attr_op, attr_group}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - Operator op = std::make_pair(operator_name, args); - - forward_op_.push_back(op); - return SUCCESS; -} - -Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { - GenerateGraph gen_g = GenerateGraph(); - if (gen_g.Init(cnode) != SUCCESS) { - MS_LOG(ERROR) << "GenerateGraph Init failed"; - return FAILED; - } - if (manual_split_) { - if (InferOffset() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Bias failed."; - return FAILED; - } - auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)}); - auto gather_v2 = - gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub, CreatInt32Imm(axis_)}); - std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, gather_v2)); - return SUCCESS; - } - if (InferBias() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Bias failed."; - return FAILED; - } - auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); - auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); - auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); - auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum}); - auto gather_v2 = - gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); - auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); - auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); - auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); - auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); - // don't need expandim,if param_size = 1, - if (inputs_shape_.at(0).size() == 1) { - mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); - } - if (InferGroup() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer Group failed."; - return FAILED; - } - Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); - Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); - OperatorAttrs attrs = {attr_op, attr_group}; - auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); - std::vector> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, reduce_scatter)); - - return SUCCESS; -} - -ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { - if (manual_split_) { - if (ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; - } - return 
replace_graph_;
-  }
-
-  auto param_strategy = strategy_->GetInputDim().at(0);
-  // target_ == CPU, no need to replace the graph
-  if (target_ == CPU) {
-    return nullptr;
-  }
-  if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
-    return nullptr;
-  }
-  return replace_graph_;
-}
-
-Status GatherV2PInfo::ComputeReplaceOp() {
-  if (InferBias() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer bias failed.";
-    return FAILED;
-  }
-  OperatorName op_name = EMBEDDING_LOOKUP;
-  OperatorAttrs attrs;
-  Attr param_offset = std::make_pair("offset", MakeValue(bias_));
-  OperatorParams params = {std::make_pair(param_offset, 3)};
-  OperatorArgs args = std::make_pair(attrs, params);
-  Operator op = std::make_pair(op_name, args);
-  replace_op_.push_back(op);
-
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
-  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Init failed.";
-    return FAILED;
-  }
-  // only when target_ == CPU do we need to replace the op
-  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
-  }
-  MS_LOG(INFO) << name_ << ": Init success.";
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
-  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
-    }
-    return FAILED;
-  }
-  auto param_strategy = strategy_->GetInputDim().at(0);
-  // the cost model sets axis and strategy
-  auto gatherv2_2cost = std::dynamic_pointer_cast<GatherV2PCost>(operator_cost());
-  gatherv2_2cost->set_axis(axis_);
-  gatherv2_2cost->set_strategy(param_strategy);
-  MS_LOG(INFO) << name_ << ": Init for cost model success.";
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
-  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
-    if (is_auto_parallel_) {
-      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
-    } else {
-      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
-    }
-    return FAILED;
-  }
-  return SUCCESS;
-}
-
-Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
-  is_auto_parallel_ = true;
-  Shape input0_split(inputs_shape_[0].size(), 1);
-  Shape input1_split(inputs_shape_[1].size(), 1);
-  Shapes splittable_inputs = {input0_split, input1_split};
-
-  std::vector<StrategyPtr> sp_vector;
-  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs failed.";
-    return FAILED;
-  }
-  size_t success = 0;
-  for (auto &sp : sp_vector) {
-    if (SetCostUnderStrategy(sp) == SUCCESS) {
-      success++;
-      MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy";
-      PrintStrategy(sp);
-    }
-  }
-  return SUCCESS;
-}
-
-std::shared_ptr<std::vector<Dimensions>> GatherV2PInfo::GenerateBatchStrategies() {
-  CheckGlobalDeviceManager();
-  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
-  Dimensions param_strategy(inputs_shape_[0].size(), 1);
-  Dimensions index_strategy;
-  index_strategy.push_back(SizeToInt(dev_num));
-  for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
-    index_strategy.push_back(1);
-  }
-  std::vector<Dimensions> strategy_v = {param_strategy, index_strategy};
-  return std::make_shared<std::vector<Dimensions>>(strategy_v);
-}
-}  // namespace parallel
-}  //
namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h deleted file mode 100644 index 16d5c85622..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class GatherV2PInfo : public OperatorInfo { - public: - GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), - axis_(0), - bias_(0), - index_offset_(0), - slice_size_(0) {} - ~GatherV2PInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - - private: - Status ComputeReplaceGraph(const CNodePtr &cnode); - Status CheckManualSplit(); - Status ComputeReplaceOp(); - Status InferBias(); - Status InferOffset(); - Status InferGroup(); - - int32_t axis_; - std::string target_ = DEVICE; - std::string replace_op_name_ = GATHERV2; - int32_t bias_; - int32_t index_offset_; - int32_t slice_size_; - Shape out_dev_matrix_shape_; - Group group_; - bool manual_split_ = false; - std::vector param_split_shapes_; - std::vector index_offsets_; -}; - -class SparseGatherV2Info : public GatherV2PInfo { - public: - SparseGatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} - ~SparseGatherV2Info() override = default; - - private: - std::string replace_op_name_ = SPARSE_GATHERV2; -}; - -class EmbeddingLookupInfo : public GatherV2PInfo { - public: - EmbeddingLookupInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : GatherV2PInfo(name, inputs_shape, outputs_shape, attrs) {} - ~EmbeddingLookupInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore 
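// Editor's note: an illustrative walk-through of the axis-split replace graph
// built in ComputeReplaceGraph above (numbers are hypothetical, not from the
// original sources). With a 2-way split, slice_size_ = 4 and bias_ = 4 on
// rank 1:
//   index 6 -> Sub: 6 - 4 = 2 -> ReLU: 2 -> Minimum(2, 3): 2 -> Equal(2, 2)
//   is true, so the gathered row is kept;
//   index 1 -> Sub: 1 - 4 = -3 -> ReLU: 0 -> Minimum(0, 3): 0 -> Equal(-3, 0)
//   is false, so the Mul mask zeroes the row.
// ReduceScatter(sum) then combines the partial results across the group so
// that exactly one rank contributes each output row.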
-#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc deleted file mode 100644 index 0fb49364f0..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ /dev/null @@ -1,269 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/get_next_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/context.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status GetNextInfo::InferTensorMap() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - for (auto shp : shapes_) { - TensorMap out_tensor_map; - for (size_t i = 0; i < shp.size(); ++i) { - if (full_batch) { - out_tensor_map.push_back(MAP_NONE); - } else { - out_tensor_map.push_back(SizeToInt(dev_matrix_shape_.size() - i - 1)); - } - } - outputs_tensor_map_.push_back(out_tensor_map); - } - return SUCCESS; -} - -Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { - if (outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << " : The layout is null."; - return FAILED; - } - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - TensorLayout output_layout; - if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) { - return FAILED; - } - outputs_layout->push_back(output_layout); - } - return SUCCESS; -} - -Strategys GetNextInfo::GetOutputStrategy() { - Strategys outputs_strategy; - for (auto shp : shapes_) { - Dimensions out_strategy; - out_strategy.push_back(dev_num_); - for (size_t i = 1; i < shp.size(); ++i) { - out_strategy.push_back(1); - } - outputs_strategy.push_back(out_strategy); - } - return outputs_strategy; -} - -Status GetNextInfo::InferTensorInfo() { - TensorLayouts outputs_layout; - if (InferTensorLayout(&outputs_layout) != SUCCESS) { - return FAILED; - } - for (size_t i = 0; i < outputs_shape_.size(); ++i) { - TensorInfo output_tensor_info(outputs_layout[i]); - outputs_tensor_info_.push_back(output_tensor_info); - } - return SUCCESS; -} - -Status GetNextInfo::InferDevMatrixShape() { - size_t max_shape_length = 0; - for (auto shp : shapes_) { - if (max_shape_length < shp.size()) { - max_shape_length = shp.size(); - } - } - if (max_shape_length == 0) { - MS_LOG(ERROR) << name_ << " : shape is 0"; - } - dev_matrix_shape_.push_back(dev_num_); - for (size_t i = 1; i < max_shape_length; ++i) { - dev_matrix_shape_.push_back(1); - } - return SUCCESS; -} - -Status GetNextInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed"; - return FAILED; - } - if (InferReplaceOps(strategy) != SUCCESS) { - MS_LOG(ERROR) << 
name_ << " : Infer replace Ops failed"; - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init success"; - return SUCCESS; -} - -Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { - std::vector stras = strategy->GetInputDim(); - for (Dimensions stra : stras) { - if (stra.size() != 0) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - } - int32_t stage = strategy->GetInputStage(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - dev_num_ = dev_num; - return SUCCESS; -} - -Status GetNextInfo::GetAttrTypes() { - auto iter = attrs_.find(TYPES); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto iter_cast = iter->second->cast(); - MS_EXCEPTION_IF_NULL(iter_cast); - auto types = iter_cast->value(); - for (auto &type : types) { - MS_EXCEPTION_IF_NULL(type); - types_.push_back(type->ToString()); - } - } else if (iter->second->isa()) { - auto iter_cast = iter->second->cast(); - MS_EXCEPTION_IF_NULL(iter_cast); - auto types = iter_cast->value(); - for (auto &type : types) { - MS_EXCEPTION_IF_NULL(type); - types_.push_back(type->ToString()); - } - } else { - MS_LOG(ERROR) << name_ << " : The value of types is not list."; - return FAILED; - } - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrShapes() { - shapes_ = outputs_shape_; - if (shapes_.size() == 0) { - MS_LOG(ERROR) << name_ << " : Shape is None."; - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrOutPutNum() { - auto iter = attrs_.find(GETNEXT_NUM); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - output_num_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of output_num is not int."; - return FAILED; - } - } - return SUCCESS; -} - -Status GetNextInfo::GetAttrs() { - if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) { - return FAILED; - } - if (types_.size() != IntToSize(output_num_) || shapes_.size() != IntToSize(output_num_) || output_num_ == 0) { - MS_LOG(ERROR) << name_ << " : The output_num is not equal to shapes size."; - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - Shapes out_shapes = outputs_shape_; - for (size_t i = 0; i < out_shapes.size(); ++i) { - if (dev_num_ <= 0) { - MS_LOG(ERROR) << name_ << " : The dev num is 0."; - return FAILED; - } - if (out_shapes[i][0] % dev_num_ != 0) { - MS_LOG(ERROR) << name_ << " : batch num cannot floor div dev num."; - return FAILED; - } - if (!full_batch) { - out_shapes[i][0] = out_shapes[i][0] / dev_num_; - } - } - ValuePtr new_shapes = MakeValue(out_shapes); - Attr attr_types = std::make_pair(TYPES, attrs_[TYPES]); - Attr attr_shapes = std::make_pair(SHAPES, new_shapes); - Attr attr_num = std::make_pair(GETNEXT_NUM, attrs_[GETNEXT_NUM]); - Attr attr_shared_name = std::make_pair(SHARED_NAME, attrs_[SHARED_NAME]); - OperatorAttrs attrs = {attr_types, attr_shapes, attr_num, attr_shared_name}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - replace_op_ = {std::make_pair(GET_NEXT, args)}; - return SUCCESS; -} - -Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { - if 
(InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status GetNextInfo::GenerateStrategies(int32_t stage_id) { - is_auto_parallel_ = true; - std::vector stra; - StrategyPtr sp = std::make_shared(stage_id, stra); - if (SetCostUnderStrategy(sp) == SUCCESS) { - MS_LOG(INFO) << name_ << " : Successfully generated strategy."; - PrintStrategy(sp); - } else { - MS_LOG(ERROR) << name_ << " : Generating strategy failed."; - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.h b/mindspore/ccsrc/parallel/ops_info/get_next_info.h deleted file mode 100644 index ba209910b7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ - -#include -#include -#include -#include - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class GetNextInfo : public OperatorInfo { - public: - GetNextInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>(false)) {} - ~GetNextInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *outputs_layout); - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferReplaceOps(const StrategyPtr &strategy); - Status GetAttrTypes(); - Status GetAttrShapes(); - Status GetAttrOutPutNum(); - Strategys GetOutputStrategy(); - Status InferAsLossDivisor() override { return SUCCESS; } - - private: - int32_t dev_num_ = 1; - std::vector<std::string> types_; - Shapes shapes_; - int32_t output_num_ = 0; - std::string shared_name_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GETNEXT_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc deleted file mode 100644 index 8716997d9f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include "parallel/ops_info/l2_normalize_info.h" - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(INFO) << name_ << " : Init success."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions input_strategy = stra.at(0); - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - - if (input_strategy[IntToSize(axis_index)] != 1) { - MS_LOG(ERROR) << name_ << " : The dim " << axis_index << " of input strategy must be 1."; - return FAILED; - } - - return SUCCESS; -} - -Status L2NormalizeInfo::GetAttrs() { - auto iter = attrs_.find(AXIS); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of axis is not int."; - return FAILED; - } - } - - return SUCCESS; -} - -Status L2NormalizeInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = inputs_tensor_map_.at(0); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group failed."; - return FAILED; - } - - OperatorVector op_for_weight; - if (input_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_weight); - MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group is " << input_group[0].name(); - } - - return SUCCESS; -} - -Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size() - 1, 1); - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - (void)input0_split.insert(input0_split.begin() + axis_index, 0); - Shapes splittable_inputs = {input0_split}; - - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h deleted file mode 100644 index ca063d01d8..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class L2NormalizeInfo : public Activation { - public: - L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : Activation(name, inputs_shape, outputs_shape, attrs) {} - ~L2NormalizeInfo() override = default; - Status GenerateStrategies(int32_t stage_id) override; - - protected: - Status GetAttrs() override; - Status InferMirrorOps() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - - private: - int32_t axis_ = 0; // Default value = 0 -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_L2_NORMALIZE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc deleted file mode 100644 index 5bdd24090f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.cc +++ /dev/null @@ -1,324 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/ops_info/layer_norm_info.h" -#include -#include -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -Status LayerNormInfo::GetAttrs() { - auto iter = attrs_.find(BEGIN_NORM_AXIS); - if (iter == attrs_.end()) { - MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis"; - return FAILED; - } - if ((iter->second == nullptr) || !iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": The axis type is not int"; - return FAILED; - } - - int32_t dim = SizeToInt(input_shape_.size()); - auto axis = GetValue(iter->second); - if ((axis >= dim) || (axis < -dim)) { - MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]"; - return FAILED; - } - - if (axis < 0) { - axis = axis + dim; - } - begin_norm_axis_ = IntToSize(axis); - return SUCCESS; -} - -Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { - MS_EXCEPTION_IF_NULL(strategy); - std::vector stra = strategy->GetInputDim(); - if (stra.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size(); - return FAILED; - } - - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Invalid strategy value"; - return FAILED; - } - - Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX]; - Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX]; - Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX]; - if (begin_norm_axis_ >= input_strategy.size()) { - MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; - return FAILED; - } - // check input strategy - for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) { - if (input_strategy[i] != NO_SPLIT_STRATEGY) { - MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy); - return FAILED; - } - } - - // check gamma and beta strategy - if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) { - MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is lager than input strategy"; - return FAILED; - } - - size_t gamma_diff = input_strategy.size() - gamma_strategy.size(); - for (size_t j = 0; j < gamma_strategy.size(); ++j) { - if (gamma_strategy[j] != input_strategy[gamma_diff + j]) { - MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy); - return FAILED; - } - } - - size_t beta_diff = input_strategy.size() - beta_strategy.size(); - for (size_t k = 0; k < beta_strategy.size(); ++k) { - if (beta_strategy[k] != input_strategy[beta_diff + k]) { - MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy); - return FAILED; - } - } - return SUCCESS; -} - -Status LayerNormInfo::InferDevMatrixShape() { - if (strategy_ == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null"; - return FAILED; - } - std::vector stra = strategy_->GetInputDim(); - if (stra.empty()) { - MS_LOG(ERROR) << name_ << ": The strategy is empty"; - return FAILED; - } - dev_matrix_shape_ = stra[0]; - return SUCCESS; -} - -Status LayerNormInfo::CreateTensorMap(size_t input_index) { - if (inputs_shape_.size() <= input_index) { - MS_LOG(ERROR) << name_ << ": Invalid index" << input_index; - return FAILED; - } - Shape shape = inputs_shape_[input_index]; - Shape tensor_map; - for (size_t i = 0; i < shape.size(); ++i) { - tensor_map.push_back(SizeToInt(shape.size() - i - 1)); - } - 
inputs_tensor_map_.push_back(tensor_map); - outputs_tensor_map_.push_back(tensor_map); - return SUCCESS; -} - -Status LayerNormInfo::InferTensorMap() { - if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create tensor map failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::CreateMirrorOp(size_t input_index) { - if (inputs_tensor_map_.size() <= input_index) { - MS_LOG(ERROR) << name_ << ": Invalid index " << input_index; - return FAILED; - } - Shape tensor_map = inputs_tensor_map_[input_index]; - std::vector group; - if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed"; - return FAILED; - } - OperatorVector mirror_op; - if (!group.empty()) { - mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); - MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is " - << group[0].name(); - } - mirror_ops_.push_back(mirror_op); - return SUCCESS; -} - -Status LayerNormInfo::InferMirrorOps() { - if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create mirror op failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::CreateTensorInfo(size_t input_index) { - if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) { - MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index; - return FAILED; - } - Shape tensor_map = inputs_tensor_map_[input_index]; - Shape shape = inputs_shape_[input_index]; - TensorLayout tensor_layout; - if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed"; - return FAILED; - } - - TensorInfo tensor_info(tensor_layout); - inputs_tensor_info_.push_back(tensor_info); - outputs_tensor_info_.push_back(tensor_info); - return SUCCESS; -} - -Status LayerNormInfo::InferTensorInfo() { - if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) || - (CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Create tensor info failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is error"; - return FAILED; - } - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); - MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) - << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0]) - << ", as_loss_divisor_ is " << as_loss_divisor_; - return SUCCESS; -} - -Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Set cost failed"; - return FAILED; - } - return SUCCESS; -} - -Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector &sp_vector) { - if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { - 
MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is lager than input"; - return FAILED; - } - - size_t gamma_diff = input_shape_.size() - gamma_shape_.size(); - size_t beta_diff = input_shape_.size() - beta_shape_.size(); - for (auto &sp : sp_vector) { - if ((sp == nullptr) || sp->GetInputDim().empty()) { - MS_LOG(ERROR) << name_ << ": Invalid strategy"; - return FAILED; - } - std::vector tmp_strategy; - Dimensions input_strategy = sp->GetInputDim()[0]; - Dimensions gamma_strategy = input_strategy; - (void)gamma_strategy.erase(gamma_strategy.begin(), - gamma_strategy.begin() + static_cast(gamma_diff)); - Dimensions beta_strategy = input_strategy; - (void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast(beta_diff)); - - // reset the strategy - tmp_strategy.push_back(input_strategy); - tmp_strategy.push_back(gamma_strategy); - tmp_strategy.push_back(beta_strategy); - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -Status LayerNormInfo::GenerateStrategies(int32_t stage_id) { - if (InitShapes() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init shapes failed"; - return FAILED; - } - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Get attrs failed"; - return FAILED; - } - Shape input_split(input_shape_.size(), SPLIT_FLAG); - if (begin_norm_axis_ >= input_split.size()) { - MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_; - return FAILED; - } - - // Can not split the dimensions from begin norm axis - for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) { - input_split[i] = NO_SPLIT_FLAG; - } - - // Generate strategy for input - Shapes splittable_inputs = {input_split}; - Shapes tmp_inputs_shape = {input_shape_}; - std::vector sp_vector; - is_auto_parallel_ = true; - if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate input strategy failed"; - return FAILED; - } - - // Generate the strategies for gamma and beta - if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed"; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy"; - } - } - return SUCCESS; -} - -Status LayerNormInfo::InitShapes() { - if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) { - MS_LOG(ERROR) << name_ << ": Invalid inputs size"; - return FAILED; - } - input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX]; - gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX]; - beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX]; - return SUCCESS; -} - -Status LayerNormInfo::Init(const StrategyPtr &strategy) { - if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed"; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success"; - return SUCCESS; -} - -Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) { - if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) { - MS_LOG(ERROR) << name_ << ": Init for cost model failed"; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success"; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h 
b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h deleted file mode 100644 index 50117b8185..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ - -#include -#include -#include -#include -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -constexpr size_t LAYER_NORM_INPUT_SIZE = 3; -constexpr size_t LAYER_NORM_INPUT_INDEX = 0; -constexpr size_t LAYER_NORM_GAMMA_INDEX = 1; -constexpr size_t LAYER_NORM_BETA_INDEX = 2; -constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; - -// The dimensions of input tensor starting from begin norm axis cannot be split. Other dimensions can be split -// arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. -class LayerNormInfo : public OperatorInfo { - public: - LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), - begin_norm_axis_(0) {} - ~LayerNormInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr &) override; - - protected: - Status GetAttrs() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferAsLossDivisor() override; - Status CreateTensorMap(size_t input_index); - Status CreateTensorInfo(size_t input_index); - Status CreateMirrorOp(size_t input_index); - Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); - Status InitShapes(); - - private: - size_t begin_norm_axis_; - Shape input_shape_; - Shape gamma_shape_; - Shape beta_shape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/parallel/ops_info/loss_info.cc deleted file mode 100644 index 0ba325c0cd..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.cc +++ /dev/null @@ -1,232 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/loss_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions input_strategy = stra.at(0); - Dimensions label_strategy = stra.at(1); - if (input_strategy != label_strategy) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_.at(0).size(); - axis_index = static_cast(input_dim) + axis_; - } - - int32_t input_axis_strategy = input_strategy.at(IntToSize(axis_index)); - int32_t label_axis_strategy = label_strategy.at(IntToSize(axis_index)); - // Dimension corresponding to axis is un-splittable - if ((input_axis_strategy != MIN_SLICE_NUM) && (label_axis_strategy != MIN_SLICE_NUM)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ - << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy - << ", label: " << label_axis_strategy; - } else { - MS_LOG(ERROR) << name_ - << " : The strategy corresponding to axis dimension is not 1, input: " << input_axis_strategy - << ", label: " << label_axis_strategy; - } - return FAILED; - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::GetAttrs() { - if ((inputs_shape_.size() != SoftmaxCrossEntropyWithLogitsInputsSize) || - (outputs_shape_.size() != SoftmaxCrossEntropyWithLogitsOutputsSize)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorMap() { - std::vector tensor_map_index; - size_t size = inputs_shape_[0].size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - i - 1)); - } - - std::vector first_output_tensor_map = {tensor_map_index[0]}; - inputs_tensor_map_.push_back(tensor_map_index); // input - inputs_tensor_map_.push_back(tensor_map_index); // label - outputs_tensor_map_.push_back(first_output_tensor_map); // output-0 - outputs_tensor_map_.push_back(tensor_map_index); // output-1 - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape first_output_shape = 
outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {{inputs_strategy[0][0]}, inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape first_output_slice_shape = outputs_slice_shape.at(0); - - TensorMap input_tensor_map = inputs_tensor_map_.at(0); - TensorMap first_output_tensor_map = outputs_tensor_map_.at(0); - - TensorLayout input_tensor_layout, first_output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, input_tensor_map, input_shape) != SUCCESS) || - (first_output_tensor_layout.InitFromVector(dev_matrix_shape_, first_output_tensor_map, first_output_shape) != - SUCCESS)) { - return FAILED; - } - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo first_output_tensor_info(first_output_tensor_layout, first_output_shape, first_output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); // input - inputs_tensor_info_.push_back(input_tensor_info); // label - outputs_tensor_info_.push_back(first_output_tensor_info); // output-0 - outputs_tensor_info_.push_back(input_tensor_info); // output-1 - - return SUCCESS; -} - -// There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload the function. -Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.size() != 2) { - MS_LOG(ERROR) << name_ << " : The size of outputs tensor map " << outputs_tensor_map_.size() << " is error."; - return FAILED; - } - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[1]); - MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_) - << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[1]) << ", as_loss_divisor_ is " - << as_loss_divisor_; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); ++i) { - split_flag_list_[i] = true; - } -} - -Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - int32_t axis_index = axis_; - if (axis_ < 0) { - size_t input_dim = inputs_shape_[0].size(); - axis_index = static_cast(input_dim) + axis_; - } - is_auto_parallel_ = true; - - Shape input0_split; - (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - input0_split[IntToSize(axis_index)] = 0; - Shapes splittable_inputs = {input0_split, input0_split}; - 
std::vector sp_vector; - if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Generate strategies failed."; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - - return SUCCESS; -} - -Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - PrintStrategy(strategy); - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h deleted file mode 100644 index 2679c2d62b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -// infer shape: -// input_0 : [a, b], input_1 : [a, b] -// output_0 : [a], output_1: [a, b] -class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { - public: - SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, - std::make_shared(false)) {} - ~SoftmaxCrossEntropyWithLogitsInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - // There are two outputs for SoftmaxCrossEntropyWithLogits, and outputs[1] is used for grad and overload - // the InferAsLossDivisor. 
- Status InferAsLossDivisor() override; - - private: - int32_t axis_ = -1; // default -1 -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LOSS_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc deleted file mode 100644 index 7d1ab8dc0f..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ /dev/null @@ -1,647 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/matmul_info.h" - -#include -#include -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -namespace mindspore { -namespace parallel { -void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b, - Shape *dev_matrix_shape) { - MS_EXCEPTION_IF_NULL(dev_matrix_shape); - size_t mat_a_size = mat_a_strategy.size(); - size_t mat_b_size = mat_b_strategy.size(); - if (mat_a_size >= mat_b_size) { - // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - - // [2],[4] in the example above - for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) { - dev_matrix_shape->push_back(mat_a_strategy.at(i)); - } - } else { - // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - - // [2],[4] in the example above - for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) { - dev_matrix_shape->push_back(mat_b_strategy.at(i)); - } - } - - // [8],[16] in the example above - dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size))); - dev_matrix_shape->push_back(mat_a_strategy.back()); - - // [32] in the example above - if (!transpose_b) { - dev_matrix_shape->push_back(mat_b_strategy.back()); - } else { - dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size))); - } -} - -Status MatMulBase::GetAttrs() { - if (attrs_.size() < MATMUL_ATTRS_SIZE) { - MS_LOG(ERROR) << name_ << " : The size of attrs small than 2."; - return FAILED; - } - - auto transpose_a_iter = attrs_.find(TRANSPOSE_A); - if (transpose_a_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(transpose_a_iter->second); - if (transpose_a_iter->second->isa()) { - transpose_a_ = transpose_a_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool."; - return FAILED; - } - } - - auto transpose_b_iter = attrs_.find(TRANSPOSE_B); - if (transpose_b_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(transpose_b_iter->second); - if (transpose_b_iter->second->isa()) { - transpose_b_ = transpose_b_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of transpose_a is not 
bool."; - return FAILED; - } - } - - auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER); - if (forward_reduce_scatter_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second); - if (forward_reduce_scatter_iter->second->isa()) { - forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool."; - return FAILED; - } - } - - // infer inputs dimension size - if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong."; - return FAILED; - } - mat_a_dimension_ = inputs_shape_.at(0).size(); - mat_b_dimension_ = inputs_shape_.at(1).size(); - - return SUCCESS; -} - -Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { - size_t long_size = long_strategy.size(); - size_t short_size = short_strategy.size(); - if (long_size < short_size) { - MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is " - << short_size; - return FAILED; - } - - size_t len_diff = long_size - short_size; - for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) { - if (long_strategy.at(len_diff + j) != short_strategy.at(j)) { - MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is " - << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy); - return FAILED; - } - } - - return SUCCESS; -} - -Status MatMul::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << " : Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - Dimensions mat_a_strategy = stra.at(0); - Dimensions mat_b_strategy = stra.at(1); - - size_t mat_a_size = mat_a_strategy.size(); - size_t mat_b_size = mat_b_strategy.size(); - if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; - } else { - MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong."; - } - return FAILED; - } - - // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32] - // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false) - // [16] in the example above - if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - - if (mat_a_size >= mat_b_size) { - if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - } else { - if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal."; - return FAILED; - } - } - - if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) { - MS_LOG(WARNING) << name_ - << ": The 
dimension of mat a and mat b must be 2 in forward reduce scatter mode, " - "setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } - - return SUCCESS; -} - -Status MatMulBase::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions mat_a_strategy = stra.at(0); - Dimensions mat_b_strategy = stra.at(1); - - SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_); - return SUCCESS; -} - -// all-reduce weight's grad -Status MatMulBase::InferMirrorOps() { - mirror_ops_.clear(); - - Shape mat_b_tensor_map = inputs_tensor_map_[1]; - std::vector mat_b_group; - if (CreateGroupByTensorMap(mat_b_tensor_map, &mat_b_group) != SUCCESS) { - return FAILED; - } - - OperatorVector op_for_inputs; // op_for_inputs is empty - OperatorVector op_for_weight; - - if (mat_b_group.empty()) { - MS_LOG(INFO) << name_ << " : The mirror ops is empty."; - return SUCCESS; - } else { - op_for_weight = CreateMirrorOps(mat_b_group[0].name(), mat_b_group[0].GetDevNum()); - mirror_ops_.push_back(op_for_inputs); - mirror_ops_.push_back(op_for_weight); - MS_LOG(INFO) << name_ << " : Create the mirror ops for weight success, group is " << mat_b_group[0].name(); - } - - return SUCCESS; -} - -Status MatMulBase::InferForwardCommunication() { - forward_op_.clear(); - size_t dimension = dev_matrix_shape_.size(); - size_t relevant_dimension_index = SECOND_FROM_END(dimension); - // Relevant dimension is not split and all reduce is not required - if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { - MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; - return SUCCESS; - } - - std::vector group_list; - if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed."; - return FAILED; - } else if (group_list.empty()) { - MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; - return SUCCESS; - } - - Operator op; - if (forward_reduce_scatter_) { - op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name()); - } else { - op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); - } - - forward_op_.push_back(op); - MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); - return SUCCESS; -} - -Status MatMulBase::InferTensorMap() { - size_t size = dev_matrix_shape_.size(); - if (repeated_calc_num_ > 1) { - // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation - size = dev_matrix_shape_.size() - 1; - } - - std::vector tensor_map_index; - // such as 5: tensor_map_index [4,3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); - } - - // infer output tensor map: [4,3,2,0], delete the second-from-end element - TensorMap output_tensor_map = tensor_map_index; - (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast(SECOND_FROM_END(size))); - - // infer mat_a tensor map - // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1] - TensorMap mat_a_tensor_map = tensor_map_index; - // delete last one element - mat_a_tensor_map.pop_back(); - // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements - (void)mat_a_tensor_map.erase( - mat_a_tensor_map.begin(), - mat_a_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_a_dimension_)); - - // infer mat_b tensor map - TensorMap mat_b_tensor_map = tensor_map_index; - // delete the 
third-to-last element - (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast(THIRD_FROM_END(size))); - // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements - (void)mat_b_tensor_map.erase( - mat_b_tensor_map.begin(), - mat_b_tensor_map.begin() + static_cast(LAST_INDEX(size) - mat_b_dimension_)); - if (transpose_b_) { - // swap the last two elements - int32_t last_value = mat_b_tensor_map.back(); - mat_b_tensor_map.pop_back(); - (void)mat_b_tensor_map.insert( - mat_b_tensor_map.begin() + static_cast(LAST_INDEX(mat_b_tensor_map.size())), last_value); - } - - if (forward_reduce_scatter_) { - if (dev_matrix_shape_.size() != 3) { - MS_LOG(WARNING) << name_ - << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, " - "setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) { - MS_LOG(WARNING) << name_ - << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in " - "forward reduce scatter mode, setting the forward reduce scatter mode to false here"; - forward_reduce_scatter_ = false; - } else { - // the forward reduce scatter only support that the dimension of output is 2 - output_tensor_map = {1, 0}; - } - } - - inputs_tensor_map_.push_back(mat_a_tensor_map); - inputs_tensor_map_.push_back(mat_b_tensor_map); - outputs_tensor_map_.push_back(output_tensor_map); - return SUCCESS; -} - -Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - Shape output_dev_matrix_shape; - if (forward_reduce_scatter_) { - if (dev_matrix_shape_.size() != 3) { - MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode"; - return FAILED; - } - output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]}; - } else { - output_dev_matrix_shape = dev_matrix_shape_; - } - - TensorLayout mat_a_layout, mat_b_layout, output_layout; - if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || - (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || - (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { - return FAILED; - } - - inputs_layout->push_back(mat_a_layout); - inputs_layout->push_back(mat_b_layout); - outputs_layout->push_back(output_layout); - return SUCCESS; -} - -Status MatMulBase::InferTensorInfo() { - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - - TensorLayout mat_a_layout = inputs_layout.at(0); - TensorLayout mat_b_layout = inputs_layout.at(1); - TensorLayout output_layout = outputs_layout.at(0); - TensorInfo mat_a_tensor_info(mat_a_layout); - TensorInfo mat_b_tensor_info(mat_b_layout); - TensorInfo output_tensor_info(output_layout); - - inputs_tensor_info_.push_back(mat_a_tensor_info); - inputs_tensor_info_.push_back(mat_b_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status MatMulBase::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << " : Init failed."; - return FAILED; - } - - if (forward_reduce_scatter_) { - virtual_div_op_.clear(); - MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated 
calculation, clear the virtual div op"; - } - - MS_LOG(INFO) << name_ << " : Init success."; - return SUCCESS; -} - -Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << " : Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << " : Init for cost model success."; - return SUCCESS; -} - -Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { - if (input->size() < 2) { - MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; - return FAILED; - } - auto last_1st_value = input->at(input->size() - 1); - auto last_2nd_value = input->at(input->size() - 2); - input->pop_back(); - input->pop_back(); - input->push_back(last_1st_value); - input->push_back(last_2nd_value); - return SUCCESS; -} - -Status MatMulBase::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << " : GetAttrs failed."; - return FAILED; - } - CheckGlobalDeviceManager(); - std::vector dev_list = g_device_manager->GetDeviceListByStageId(stage_id); - size_t dev_num = dev_list.size(); - Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1]; - if (transpose_a_) { - if (SwapLastTwoElements(&input0_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - if (transpose_b_) { - if (SwapLastTwoElements(&input1_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - // The shape of input0 (input1) - // E.g., input0 = [100, 200, 300], input1 = [300, 400] - - // Combining the input0_shape and input1_shape - // E.g., combined_shape = [100, 200, 300, 400] - is_auto_parallel_ = true; - size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size(); - Dimensions combined_partitions; - Shape combined_shape; - // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2 - if (input0_shape.size() >= input1_shape.size()) { - combined_shape = input0_shape; - combined_shape.push_back(input1_shape[input1_shape.size() - 1]); - } else { - combined_shape = input1_shape; - combined_shape.push_back(input0_shape[input0_shape.size() - 2]); - } - std::function recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape, - &input1_shape_size, &recursive, &input0_shape_size, - this](uint32_t current_index, size_t n) { - // Finishing the recursive steps, if the strategy is valid, then calculate the cost - // for this operator under the strategy. 
- if (current_index == combined_shape.size()) { - StrategyPtr sp; - if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) == - FAILED) { - return; - } - if (this->SetCostUnderStrategy(sp) == FAILED) { - MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed."; - return; - } - } else { - MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size - << ", input1_shape_size: " << input1_shape_size; - for (uint32_t i = 1; i <= n; i *= 2) { - if (n % i == 0 && IntToSize(combined_shape[current_index]) % i == 0) { - combined_partitions.push_back(i); - recursive(current_index + 1, n / i); - combined_partitions.pop_back(); - } - } - } - }; - recursive(0, dev_num); - if (strategy_cost_.empty()) { - MS_LOG(EXCEPTION) << name_ << " : No available strategy."; - } - return Status::SUCCESS; -} - -Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, - mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { - int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); - if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { - return FAILED; - } - } else { - if (IntToSize(product) != dev_num) { - return FAILED; - } - } - Dimensions input0_partitions, input1_partitions; - if (input0_shape_size >= input1_shape_size) { - for (size_t i = 0; i < input0_shape_size; ++i) { - input0_partitions.push_back(combined_partitions[i]); - } - if (input1_shape_size == 2) { - input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]); - input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); - } else { - // input1_shape.size() > 2 - for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) { - if (j == combined_partitions.size() - 3) { - continue; - } - input1_partitions.push_back(combined_partitions[j]); - } - } - } else { - for (size_t i = 0; i < input1_shape_size; ++i) { - input1_partitions.push_back(combined_partitions[i]); - } - for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) { - input0_partitions.push_back(combined_partitions[j]); - } - input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]); - input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]); - } - if (transpose_a_) { - if (SwapLastTwoElements(&input0_partitions) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - if (transpose_b_) { - if (SwapLastTwoElements(&input1_partitions) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - } - std::vector stras; - stras.push_back(input0_partitions); - stras.push_back(input1_partitions); - (*sp) = std::make_shared(stage_id, stras); - - return SUCCESS; -} - -void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { - TensorLayout tly; - if (transpose_a_) { - Shape replica_input0_shape(inputs_tensor_info_[0].shape()); - Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape()); - if (SwapLastTwoElements(&replica_input0_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - - TensorInfo 
replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape); - relica_inputs_tensor_vector->push_back(replica_input0_info); - } else { - relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]); - } - if (transpose_b_) { - Shape replica_input1_shape(inputs_tensor_info_[1].shape()); - Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape()); - if (SwapLastTwoElements(&replica_input1_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) { - MS_LOG(ERROR) << name_ << " : Swap last two elements failed."; - } - - TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape); - relica_inputs_tensor_vector->push_back(replica_input1_info); - } else { - relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]); - } -} - -Status MatMulBase::CheckForTensorSliceValid() const { - if (!TENSOR_SLICE_ALIGNMENT_ENABLE) { - return SUCCESS; - } - if (inputs_tensor_info_.empty()) { - return FAILED; - } - for (auto &one_input_tensor : inputs_tensor_info_) { - auto slice_shape = one_input_tensor.slice_shape(); - if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || - (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { - return FAILED; - } - } - return SUCCESS; -} - -Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (InitForCostModel(strategy) == FAILED) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; - } else { - MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed."; - } - return FAILED; - } - PrintStrategy(strategy); - // Check whether the tensor slice of input_tensor_info is valid or not - if (CheckForTensorSliceValid() != SUCCESS) { - MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy."; - return FAILED; - } - // Here, a replicated inputs_ is constructed for the transposed TensorInfo. - std::vector relica_inputs_tensor_vector; - InitTensorInfoForCost(&relica_inputs_tensor_vector); - - int32_t stage_id = strategy->GetInputStage(); - // Here, we use the origin outputs_, because we only use the slice size of the output tensor. - // It does not matter whether the output tensor is transposed or not. 
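// Two standalone sketches of the checks and arithmetic used around here. First,
// the slice-alignment rule of CheckForTensorSliceValid above: when alignment is
// enabled, the last two dimensions of every input slice must be divisible by the
// alignment size (kSliceAlign is an assumed stand-in for TENSOR_SLICE_ALIGNMENT_SIZE).
#include <cstdint>
#include <vector>

constexpr int32_t kSliceAlign = 16;  // assumption, not the real configured value

bool SliceIsAligned(const std::vector<int32_t> &slice_shape) {
  if (slice_shape.size() < 2) return false;
  return slice_shape[slice_shape.size() - 1] % kSliceAlign == 0 &&
         slice_shape[slice_shape.size() - 2] % kSliceAlign == 0;
}

// Second, the blending computed just below: the parameter-related share of the
// communication cost is discounted by COST_MODEL_GAMMA (assumed to lie in [0, 1]).
double CommWithPartialParameter(double comm_cost, double comm_without_parameter,
                                double gamma) {
  return comm_without_parameter + gamma * (comm_cost - comm_without_parameter);
}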
- double computation_cost = - operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - std::shared_ptr result = std::make_shared(computation_cost, communication_cost); - result->communication_without_parameter_ = - operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id); - result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); - - // Breaking ties for preferring data parallelization - BreakingTiesForPerferringDataParallel(strategy, result); - MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_ - << ", communication_cost: " << result->communication_cost_ - << ", communication_without_parameter_: " << result->communication_without_parameter_ - << ", communication_with_partial_para_: " << result->communication_with_partial_para_; - // refine communication cost calculation for practice - RefineForPracticalCost(result, false); - result->communication_forward_ = result->communication_without_parameter_; - - std::shared_ptr swc = - std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); - swc->cost_list.push_back(result); - strategy_cost_.emplace_back(swc); - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h deleted file mode 100644 index cb3e54a048..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ - -#include -#include -#include -#include - -#include "common/utils.h" -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class MatMulBase : public OperatorInfo { - public: - MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~MatMulBase() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - // Generate all strategies and the corresponding cost for this MatMul operator - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, StrategyPtr *sp); - - Status SwapLastTwoElements(Shape *shape); - - protected: - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - void InitTensorInfoForCost(std::vector *); - Status CheckForTensorSliceValid() const; - Status GetAttrs() override; - - bool transpose_a_ = false; - bool transpose_b_ = false; - bool forward_reduce_scatter_ = false; - size_t mat_a_dimension_ = 0; - size_t mat_b_dimension_ = 0; -}; - -class MatMul : public MatMulBase { - public: - MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) - : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} - ~MatMul() override = default; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; -}; - -class MatMulInfo : public MatMul { - public: - MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : MatMul(name, inputs_shape, outputs_shape, attrs) {} - ~MatMulInfo() override = default; -}; - -class BatchMatMulInfo : public MatMul { - public: - BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : MatMul(name, inputs_shape, outputs_shape, attrs) {} - ~BatchMatMulInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_MATMUL_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc deleted file mode 100644 index ea2d045104..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ /dev/null @@ -1,311 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/onehot_info.h" - -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/device_matrix.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/strategy.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status OneHotInfo::GetAttrs() { - auto iter = attrs_.find(AXIS); - if (iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - axis_value_ptr_ = iter->second; - axis_ = iter->second->cast()->value(); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis is not int."; - return FAILED; - } - } - - if (inputs_shape_[0].size() != 1) { - MS_LOG(ERROR) << name_ << ": Input's shape only support 1-D now."; - return FAILED; - } - - if ((axis_ > 1) || (axis_ < -1)) { - MS_LOG(ERROR) << name_ << ": Axis " << axis_ << " is out of range[-1, 1]."; - return FAILED; - } - return SUCCESS; -} - -Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != 1) { - MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, - is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status OneHotInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - // Now input only support 1-D tensor, so the output is a 2-D tensor - // If input is a vector of length features, the output shape will be: - // [features, depth] if axis == -1 (or axis == 1) - // [depth, features] if axis == 0 - if (axis_ == 0) { - dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable - dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable - } else { - dev_matrix_shape_.push_back(input_strategy[0]); // the features is splittable - dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable - } - - return SUCCESS; -} - -Status OneHotInfo::InferTensorMap() { - std::vector input_tensor_map_index, output_tensor_map_index; - size_t size = outputs_shape_[0].size(); - // such as 2: tensor_map_index [1,0] - if (axis_ == 0) { - for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(i)); - } - } else { - for (size_t i = 0; i < size; ++i) { - output_tensor_map_index.push_back((int32_t)(LAST_INDEX(size) - i)); - } - } - outputs_tensor_map_.push_back(output_tensor_map_index); - - // Now input only support 1-D tensor - input_tensor_map_index.push_back(1); - - inputs_tensor_map_.push_back(input_tensor_map_index); - return SUCCESS; -} - -// axis = -1 -// (0,(1,16),(),())reid dev_matrix=(1,16) 
map_in=(1) map_out=(1,0) -// (0,(16,1),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(1,0) -// (0,(2,8),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between -// machines dev_matrix=(2,8) map_in=(1) map_out=(1,0) (0, (2,4),(),())16 devices dev_matrix=(2,4,2) map_in=(1) -// map_out=(1,0) -// axis = 0 -// (0, (16,1),(),())reid dev_matrix=(1,16) map_in=(1) map_out=(0,1) -// (0, (1,16),(),())data parallel dev_matrix=(16,1) map_in=(1) map_out=(0,1) -// (0, (8,2),(),())16 devices two machines,model parallel among devices in the same machine,data parallel between -// machines dev_matrix=(2,8) map_in=(1) map_out=(0,1) (0,(4,2),(),())16 devices dev_matrix=(2,4,2) map_in=(1) -// map_out=(0,1) -Status OneHotInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout); - TensorInfo output_tensor_info(output_tensor_layout); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - - return SUCCESS; -} - -Status OneHotInfo::ExtractInputInfo() { - CheckGlobalDeviceManager(); - rank_ = g_device_manager->global_rank(); - mod_rank_ = rank_ % dev_matrix_shape_.back(); - if (!cnode_) { - MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr"; - return FAILED; - } - if (cnode_->inputs().size() != 5) { - MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, real input size is " - << cnode_->inputs().size(); - return FAILED; - } - if (input_value_.size() != 4) { - MS_LOG(ERROR) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive, and input value size " - "must be 4, real size is " - << input_value_.size(); - return FAILED; - } - auto value_ptr = input_value_.at(1); - if (value_ptr == nullptr) { - MS_LOG(WARNING) << "Input 2 of cnode is not a value node, its type is " << cnode_->input(2)->type_name(); - return FAILED; - } - - if (value_ptr->isa()) { - total_class_number_ = value_ptr->cast()->value(); - } else { - MS_LOG(ERROR) << "OneHot Primitive depth type must be int"; - return FAILED; - } - classes_each_device_ = total_class_number_ / dev_matrix_shape_.back(); - - return SUCCESS; -} - -Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { - if (dev_matrix_shape_.back() == 1) { - replace_graph_ = nullptr; - return SUCCESS; - } - if (ExtractInputInfo() != SUCCESS) { - MS_LOG(ERROR) << "ExtractInputInfo failed"; - return FAILED; - } - GenerateGraph gen_g = GenerateGraph(); - Status status = gen_g.Init(cnode); - if (status != SUCCESS) { - MS_LOG(ERROR) << "GenerateGraph Init failed"; - return FAILED; - } - - auto floor_div = - gen_g.PushBack({gen_g.NewOpInst(FLOORDIV), gen_g.virtual_input_node(), CreateInt32Tensor(classes_each_device_)}); - auto mul1 = gen_g.PushBack({gen_g.NewOpInst(MUL), floor_div, CreateInt32Tensor(classes_each_device_)}); - auto sub1 = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), mul1}); - auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), floor_div, CreateInt32Tensor(mod_rank_)}); - auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, CreatTypeInt(32)}); - 
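// A scalar reading (an interpretation, not part of the original file) of the
// replace graph being assembled here: each device owns a contiguous band of
// classes_each_device_ classes, and a global class id is remapped to a local id,
// or to -1 (which one-hots to all zeros) when the id belongs to another device.
#include <cstdint>

int32_t RemapForShardedOneHot(int32_t class_id, int32_t classes_each_device,
                              int32_t device_band /* plays the role of mod_rank_ */) {
  int32_t band = class_id / classes_each_device;          // FloorDiv
  int32_t local = class_id - band * classes_each_device;  // Mul, Sub
  int32_t mask = (band == device_band) ? 1 : 0;           // Equal, Cast
  return mask * (mask * local + 1) - 1;                   // Mul, TensorAdd, Mul, Sub
}
// E.g., with 4 classes per device, device band 1 maps id 5 to local id 1, while id 2 maps to -1.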
auto mul2 = gen_g.PushBack({gen_g.NewOpInst(MUL), sub1, cast}); - auto tensor_add = gen_g.PushBack({gen_g.NewOpInst(TENSOR_ADD), mul2, CreateInt32Tensor(1)}); - auto mul3 = gen_g.PushBack({gen_g.NewOpInst(MUL), cast, tensor_add}); - auto sub2 = gen_g.PushBack({gen_g.NewOpInst(SUB), mul3, CreateInt32Tensor(1)}); - Attr attr_onehot_axis = std::make_pair(AXIS, axis_value_ptr_); - OperatorAttrs attrs_onehot = {attr_onehot_axis}; - auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_), - cnode->input(3), cnode->input(4)}); - std::vector> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)}; - replace_graph_ = std::make_shared>, AnfNodePtr>>( - std::make_pair(input_nodes, onehot)); - - return SUCCESS; -} - -ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { - if (ComputeReplaceGraph(cnode) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return nullptr; - } - return replace_graph_; -} - -Status OneHotInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - Status status = ComputeReplaceGraph(cnode_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return status; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status OneHotInfo::GenerateStrategies(int32_t stage_id) { - Shapes splittable_inputs = {{1, 1}, {}, {}}; - std::vector sp_vector; - if (inputs_shape_.size() != 3) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); - return FAILED; - } - if (outputs_shape_.size() != 1) { - MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - if (GenerateStrategiesForIndependentInputs(stage_id, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, - splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategies failed."; - return FAILED; - } - - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - - return SUCCESS; -} - -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -std::shared_ptr>> OneHotInfo::GenerateBatchStrategies() { - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - Dimensions strategy = {SizeToInt(dev_num), 1}; - Dimensions empty_strategy; - std::vector strategy_v = {strategy, empty_strategy, empty_strategy}; - return std::make_shared>>(strategy_v); -} -} // namespace parallel -} // namespace mindspore diff --git 
a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h deleted file mode 100644 index 3c8a64f954..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class OneHotInfo : public OperatorInfo { - public: - OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~OneHotInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; - std::shared_ptr>> GenerateBatchStrategies() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status ExtractInputInfo(); - - private: - Status ComputeReplaceGraph(const CNodePtr &cnode); - - int axis_ = -1; - int32_t rank_ = 0; - int32_t total_class_number_ = 1; - int32_t classes_each_device_ = 1; - ValuePtr axis_value_ptr_; - int32_t mod_rank_ = 0; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ONEHOT_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc deleted file mode 100644 index f9b294898c..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ /dev/null @@ -1,1334 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/ops_info/operator_info.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/dtype.h" -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/context.h" -#include "utils/context/ms_context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { - if (strategy == nullptr) { - MS_LOG(ERROR) << "The strategy is null."; - return FAILED; - } - - size_t strategy_size = strategy->GetInputNumber(); - size_t inputs_shape_size = inputs_shape.size(); - if (strategy_size != inputs_shape_size) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; - } else { - MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - for (size_t i = 0; i < strategy_size; ++i) { - Shape sub_strategy = stra.at(i); - Shape sub_input_shape = inputs_shape.at(i); - size_t strategy_len = sub_strategy.size(); - size_t inputs_len = sub_input_shape.size(); - if (strategy_len != inputs_len) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len - << ", index: " << i; - } else { - MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len - << ", index: " << i; - } - return FAILED; - } - - for (size_t j = 0; j < strategy_len; ++j) { - int32_t strategy_value = sub_strategy.at(j); - if (strategy_value < MIN_SLICE_NUM) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; - } else { - MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value; - } - return FAILED; - } - - if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value; - } else { - MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value; - } - return FAILED; - } - - int32_t shape_value = sub_input_shape.at(j); - if ((shape_value % strategy_value) != 0) { - if (is_auto_parallel) { - MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; - } else { - MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; - } - return FAILED; - } - } - } - - return SUCCESS; -} - -void OperatorInfo::ResetQueueMember() { - inputs_tensor_info_.clear(); - outputs_tensor_info_.clear(); - inputs_tensor_map_.clear(); - outputs_tensor_map_.clear(); - dev_matrix_shape_.clear(); - forward_op_.clear(); - mirror_ops_.clear(); - sub_ops_.clear(); - replace_op_.clear(); - replace_op_info_.clear(); - virtual_div_op_.clear(); - global_device_list_.clear(); -} - -Status OperatorInfo::InferAttrs() { - if (infer_attrs_completed_) { - return SUCCESS; - } - - if (GetAttrs() != SUCCESS) { - return FAILED; - } - infer_attrs_completed_ = true; - return SUCCESS; -} - -void OperatorInfo::SetDeviceListByStrategy() { - int32_t stage = strategy_->GetInputStage(); - CheckGlobalDeviceManager(); - global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); -} - -Status 
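// Two standalone sketches of the rules enforced in this file (simplified types,
// no logging; names are illustrative). CheckStrategyValue above accepts a
// per-dimension strategy value only if it is at least MIN_SLICE_NUM (assumed 1),
// a power of two, and divides the matching input dimension; InferRepeatedCalcInfo
// below derives the repeat factor from the surplus of stage devices over the
// device-matrix volume.
#include <cstdint>

bool StrategyValueOk(int32_t strategy_value, int32_t shape_value) {
  if (strategy_value < 1) return false;  // MIN_SLICE_NUM assumed to be 1
  uint32_t v = static_cast<uint32_t>(strategy_value);
  if ((v & (v - 1)) != 0) return false;  // the power-of-two test used above
  return shape_value % strategy_value == 0;  // e.g., StrategyValueOk(3, 12) is false
}

bool InferRepeatedCalc(int32_t dev_list_size, int32_t dev_matrix_size,
                       int32_t *repeated_calc_num) {
  if (dev_matrix_size == 0 || dev_list_size % dev_matrix_size != 0) return false;
  *repeated_calc_num = dev_list_size / dev_matrix_size;  // 1 when the sizes match
  return true;
}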
OperatorInfo::InferRepeatedCalcInfo() { - int32_t g_dev_list_size = SizeToInt(global_device_list_.size()); - int32_t dev_matrix_size = - std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); - if (dev_matrix_size == 0) { - MS_LOG(ERROR) << name_ << ": The dev matrix size is 0"; - return FAILED; - } - - if (g_dev_list_size == dev_matrix_size) { - repeated_calc_num_ = 1; - } else if (g_dev_list_size % dev_matrix_size == 0) { - repeated_calc_num_ = g_dev_list_size / dev_matrix_size; - } else { - MS_LOG(ERROR) << name_ << ": Dev list size " << g_dev_list_size << " can not be divisible by dev matrix size " - << dev_matrix_size; - return FAILED; - } - - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - int32_t stage = strategy_->GetInputStage(); - local_device_list_ = g_device_manager->global_device_list(stage, rank, repeated_calc_num_); - - return SUCCESS; -} - -// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix, -// only use for infer tensor layout -void OperatorInfo::SetRepeatedCalcDevMatrix() { - if (repeated_calc_num_ <= 1) { - return; - } - - (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_); -} - -// use for loss repeated calculation -Operator CreateVirtualDivOp(int32_t div_num) { - OperatorName operator_name = VIRTUAL_DIV; - ValuePtr attr0_value = MakeValue(div_num); - Attr attr0 = std::make_pair(DIVISOR, attr0_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - return op; -} - -// use for forward all reduce -Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { - OperatorName operator_name = ALL_REDUCE; - ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM - ValuePtr attr1_value = MakeValue(group); // group - Attr attr0 = std::make_pair(OP, attr0_value); - Attr attr1 = std::make_pair(GROUP, attr1_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group; - return op; -} - -Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) { - OperatorName operator_name = REDUCE_SCATTER; - ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM - ValuePtr attr1_value = MakeValue(group); // group - Attr attr0 = std::make_pair(OP, attr0_value); - Attr attr1 = std::make_pair(GROUP, attr1_value); - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - - OperatorParams operator_param; - OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group; - return op; -} - -// use for get tensor slice -Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { - Shape tensor_map = tensor_layout.tensor_map().array(); - Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); - OperatorName operator_name = 
GET_TENSOR_SLICE; - - OperatorAttrs attrs; - ValuePtr dev_mat_value = MakeValue(dev_matrix_shape); - Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2); - ValuePtr tensor_map_value = MakeValue(tensor_map); - Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3); - OperatorParams params = {dev_mat_param, tensor_map_param}; - OperatorArgs operator_arg = std::make_pair(attrs, params); - - Operator op = std::make_pair(operator_name, operator_arg); - MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is " - << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map); - return op; -} - -OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { - if ((dev_num == 0) || (dev_num == 1)) { - MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; - } - OperatorVector op_for_weight; - bool mean_flag = ParallelContext::GetInstance()->mirror_mean(); - - OperatorName operator_name = MIRROR_OPERATOR; - ValuePtr attr0_value = MakeValue(group_name); - ValuePtr attr1_value = MakeValue(SizeToInt(dev_num)); - ValuePtr attr2_value = MakeValue(mean_flag); - - Attr attr0 = std::make_pair(GROUP, attr0_value); - Attr attr1 = std::make_pair(DEV_NUM, attr1_value); - Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value); - - OperatorAttrs operator_attrs; - operator_attrs.push_back(attr0); - operator_attrs.push_back(attr1); - operator_attrs.push_back(attr2); - - OperatorParams operator_param; - OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param); - - Operator op = std::make_pair(operator_name, operator_args); - - op_for_weight.push_back(op); - MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is " - << mean_flag; - return op_for_weight; -} - -Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { - if (group == nullptr) { - MS_LOG(ERROR) << "The group is null."; - return FAILED; - } - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) { - return FAILED; - } - - if (group_devices.size() == 1) { - MS_LOG(INFO) << "The dev size is 1, no need to create group."; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} - -Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { - if (group == nullptr) { - MS_LOG(ERROR) << "The group is null."; - return FAILED; - } - CheckGlobalDeviceManager(); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, global_device_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { - return FAILED; - } - - if (group_devices.size() == 1) { - MS_LOG(INFO) << "The dev size is 1, no need to create group."; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} - -Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { - Shape slice_shape; - if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { - MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; - return slice_shape; - } - for 
(size_t i = 0; i < strategy.size(); ++i) { - slice_shape.push_back(tensor_shape.at(i) / strategy.at(i)); - } - return slice_shape; -} - -Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { - if (slice_shapes == nullptr) { - MS_LOG(ERROR) << "The slice_shapes is null."; - return FAILED; - } - if (strategys.size() != shapes.size()) { - MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size(); - return FAILED; - } - - for (size_t i = 0; i < strategys.size(); ++i) { - if (strategys.at(i).size() != shapes.at(i).size()) { - MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension " - << shapes.at(i).size(); - slice_shapes->clear(); - return FAILED; - } - - for (size_t j = 0; j < shapes.at(i).size(); ++j) { - if (strategys.at(i).at(j) <= 0) { - MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i]) - << " the element is less than or equal to 0."; - slice_shapes->clear(); - return FAILED; - } - if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) { - MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : " - << strategys.at(i).at(j); - slice_shapes->clear(); - return FAILED; - } - } - Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i)); - slice_shapes->push_back(slice_shape); - } - - return SUCCESS; -} - -Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, - Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { - if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { - MS_LOG(ERROR) << "The slice_shape is null."; - return FAILED; - } - - if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << "Infer inputs slice shape error."; - return FAILED; - } - - if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) { - MS_LOG(ERROR) << "Infer outputs slice shape error."; - inputs_slice_shape->clear(); - return FAILED; - } - - return SUCCESS; -} - -// method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 -Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InferAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAttrs failed."; - return FAILED; - } - - // must be after InferAttrs() - if (CheckStrategy(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": CheckStrategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; - } - return FAILED; - } - - // need to clear queues before Init(), - // because Init() may be called multiple times by cost model - ResetQueueMember(); - - strategy_ = strategy; - SetDeviceListByStrategy(); - - if (InferDevMatrixShape() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; - return FAILED; - } - - used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies()); - - // must be after InferDevMatrixShape - if (InferRepeatedCalcInfo() != SUCCESS) { - MS_LOG(ERROR) << ": InferRepeatedCalcInfo failed."; - return FAILED; - } - - // if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix for layout - SetRepeatedCalcDevMatrix(); - - if (InferTensorMap() != SUCCESS) 
{ - MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; - return FAILED; - } - - if (InferTensorInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; - return FAILED; - } - - return SUCCESS; -} - -// method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape -Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InferAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAttrs failed."; - return FAILED; - } - - // must be after InferAttrs() - if (CheckStrategy(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; - return FAILED; - } - - // need to clear queues before Init(), - // because Init() may be called multiple times by cost model - ResetQueueMember(); - - strategy_ = strategy; - SetDeviceListByStrategy(); - - if (InferDevMatrixShape() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; - return FAILED; - } - - // must be after InferDevMatrixShape - if (InferRepeatedCalcInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed."; - return FAILED; - } - - if (InferTensorMap() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; - return FAILED; - } - - if (InferTensorInfo() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; - return FAILED; - } - - return SUCCESS; -} - -Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - return FAILED; - } - - if (InferForwardCommunication() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; - return FAILED; - } - - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - - return SUCCESS; -} - -Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { - if (strategy == nullptr) { - MS_LOG(ERROR) << name_ << ": The strategy is null."; - return FAILED; - } - - if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { - return FAILED; - } - - if (InferForwardCommunication() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; - return FAILED; - } - - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - - return SUCCESS; -} - -std::vector> OperatorInfo::GetAliveSuccEdges() { - std::vector> ret; - for (auto &edge : succ_edges_) { - if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { - ret.push_back(edge); - } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) { - // CAST is ordered in front of L2NORMALIZE - ret.push_back(edge); - } - } - for (auto &edge : succ_edges_) { - if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) && - (edge->next_operator()->name().find(CAST) == std::string::npos)) { - ret.push_back(edge); - } - } - return 
ret; -} - -std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() { - std::vector<std::shared_ptr<Edge>> ret; - for (auto &edge : prev_edges_) { - if (edge->prev_operator()->is_alive()) { - ret.push_back(edge); - } - } - return ret; -} - -void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; - return; - } - for (auto &edge : prev_edges_) { - if (edge->prev_operator() == op) { - edge = new_edge; - return; - } - } - MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; -} - -void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; - return; - } - for (auto &edge : succ_edges_) { - if (edge->next_operator() == op) { - edge = new_edge; - return; - } - } - MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; -} - -void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; - return; - } - std::vector<std::shared_ptr<Edge>> new_pre_edges; - for (auto &edge : prev_edges_) { - if (edge->prev_operator() != op) { - new_pre_edges.push_back(edge); - } - } - new_pre_edges.push_back(new_edge); - prev_edges_ = new_pre_edges; -} - -void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) { - if (op == nullptr) { - MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; - return; - } - std::vector<std::shared_ptr<Edge>> new_succ_edges; - for (auto &edge : succ_edges_) { - if (edge->next_operator() != op) { - new_succ_edges.push_back(edge); - } - } - new_succ_edges.push_back(new_edge); - succ_edges_ = new_succ_edges; -} - -std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector<bool> &split_flag_list) { - if (shapes.size() != split_flag_list.size()) { - MS_LOG(ERROR) << "Split_flag_list does not have the same size as inputs shape, " << split_flag_list.size() << " : " - << shapes.size(); - return nullptr; - } - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - std::vector<std::vector<int32_t>> strategy_v; - for (size_t i = 0; i != shapes.size(); i++) { - if (shapes[i].empty()) { - MS_LOG(INFO) << "Elements of shapes is empty."; - std::vector<int32_t> empty_element; - strategy_v.push_back(empty_element); - } else { - std::vector<int32_t> element(shapes[i].size(), 1); - if (split_flag_list[i]) { - element[0] = dev_num; - } - strategy_v.push_back(element); - } - } - return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); -} - -void OperatorInfo::ReComputeBatchSplitFlagList() { - if (!inputs_shape_.empty()) { - split_flag_list_[0] = true; - } -} - -void OperatorInfo::ComputeBatchSplitFlagList() { - split_flag_list_.clear(); - for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) { - split_flag_list_.push_back(false); - } - ReComputeBatchSplitFlagList(); -} - -// This is a common method for checking whether the generated strategy has the correct number of devices.
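// Two compact sketches of the policies implemented nearby (simplified, assumed
// names). GenerateBatchStrategiesBySplitFlag above splits only dimension 0 of
// each flagged input across all devices and replicates everything else; the
// device-count gate of PrepareStrategyBase below accepts a partition product
// that does not exceed the device count, or, when FULLY_USE_DEVICES holds,
// only a product of exactly 1 (pure replication) or exactly dev_num.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::vector<int32_t>> BatchStrategies(
    const std::vector<std::vector<int32_t>> &shapes,
    const std::vector<bool> &split_flags, int32_t dev_num) {  // sizes must match
  std::vector<std::vector<int32_t>> strategies;
  for (size_t i = 0; i < shapes.size(); ++i) {
    std::vector<int32_t> s(shapes[i].size(), 1);       // replicate every dimension
    if (!s.empty() && split_flags[i]) s[0] = dev_num;  // split the batch dimension
    strategies.push_back(s);
  }
  return strategies;
}

bool DeviceCountOk(size_t product, size_t dev_num, bool fully_use_devices) {
  if (!fully_use_devices) return product <= dev_num;
  return product == 1 || product == dev_num;
}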
-Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { - if (sp == nullptr) { - MS_LOG(ERROR) << "The strategy is null."; - return FAILED; - } - int32_t product = 1; - - for (auto &input_partition : inputs_partitions) { - product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); - } - if (!FULLY_USE_DEVICES) { - if (IntToSize(product) > dev_num) { - return FAILED; - } - } else { - if ((product != 1) && (IntToSize(product) != dev_num)) { - return FAILED; - } - } - std::vector stras(inputs_partitions); - (*sp) = std::make_shared(stage_id, stras); - return SUCCESS; -} - -std::shared_ptr>> OperatorInfo::GenerateBatchStrategies() { - ComputeBatchSplitFlagList(); - return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); -} - -void PrintStrategy(const StrategyPtr &strategy) { - if (strategy == nullptr) { - return; - } - std::string all_strategy = ""; - for (size_t i = 0; i < strategy->GetInputNumber(); ++i) { - all_strategy += "["; - for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) { - all_strategy += std::to_string(strategy->GetInputDim()[i][j]); - if (j != strategy->GetInputDim()[i].size() - 1) { - all_strategy += ", "; - } - } - all_strategy += "]"; - if (i != strategy->GetInputNumber() - 1) { - all_strategy += ", "; - } - } - MS_LOG(INFO) << "The strategy is: " << all_strategy; -} - -// generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) -Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { - MS_LOG(ERROR) << "The inputs size is wrong."; - return FAILED; - } - - if ((inputs_shape[0].size() != inputs_shape[1].size()) || - (splittable_inputs[0].size() != splittable_inputs[1].size())) { - MS_LOG(ERROR) << "The size of two inputs are not equal."; - return FAILED; - } - - Shapes input0_shape = {inputs_shape[0]}; - Shapes input0_splittable = {splittable_inputs[0]}; - if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) { - return FAILED; - } - - for (auto &sp : *sp_vector) { - sp->ExpandInputDimFromOneToTwo(); - } - - return SUCCESS; -} - -// generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast -// such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() >= inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // first, generate strategy for input0 the same as input1 - Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]}; - Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // second, get the correct strategy for input0 - for (auto &sp : *sp_vector) { - std::vector 
tmp_strategy; - Dimensions input0_strategy = sp->GetInputDim()[0]; - size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); - - // erase the unnecessary part - (void)input0_strategy.erase(input0_strategy.begin(), - input0_strategy.begin() + static_cast(size_diff)); - - // handle cases like ([1, c, d], [a, b, c, d]) - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] == 1) { - input0_strategy[i] = 1; - } else { - break; - } - } - - // reset the strategy - tmp_strategy.push_back(input0_strategy); // input0 - tmp_strategy.push_back(sp->GetInputDim()[1]); // input1 - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -// generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast -// such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() <= inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // first, generate strategy for input1 the same as input0 - Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]}; - Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // second, get the correct strategy for input1 - for (auto &sp : *sp_vector) { - std::vector<Dimensions> tmp_strategy; - tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 - - Dimensions input1_strategy = sp->GetInputDim()[1]; - size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size(); - - // erase the unnecessary part - (void)input1_strategy.erase(input1_strategy.begin(), - input1_strategy.begin() + static_cast(size_diff)); - - // handle cases like ([a, b, c, d], [1, c, d]) - for (size_t i = 0; i < inputs_shape[1].size(); ++i) { - if (inputs_shape[1][i] == 1) { - input1_strategy[i] = 1; - } else { - break; - } - } - - // reset the strategy - tmp_strategy.push_back(input1_strategy); // input1 - sp->ResetInputs(tmp_strategy); - } - return SUCCESS; -} - -// generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast -// such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector<StrategyPtr> *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if (inputs_shape[0].size() != inputs_shape[1].size()) { - MS_LOG(ERROR) << "Invalid inputs shape."; - return FAILED; - } - - // step1: ([a, 1], [1, b]) -> [a, b] - Shape max_shape, splittable_vector; - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] >= inputs_shape[1][i]) { - max_shape.push_back(inputs_shape[0][i]); - splittable_vector.push_back(splittable_inputs[0][i]); - } else { - max_shape.push_back(inputs_shape[1][i]); - splittable_vector.push_back(splittable_inputs[1][i]); - } - } - - // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b]) - Shapes tmp_inputs_shape = {max_shape, max_shape}; - Shapes tmp_splittable_inputs =
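// A sketch of the strategy alignment performed by the two functions above: copy
// the longer input's strategy, drop its extra leading dimensions, and force the
// strategy back to 1 wherever the shorter shape has a 1, so a broadcast axis is
// never split. (The one-sided variants above only reset the leading run of 1s;
// this simplified version resets every size-1 axis, as the both-sides variant does.)
#include <cstddef>
#include <cstdint>
#include <vector>

// Requires long_strategy.size() >= short_shape.size().
std::vector<int32_t> AlignBroadcastStrategy(const std::vector<int32_t> &long_strategy,
                                            const std::vector<int32_t> &short_shape) {
  std::vector<int32_t> aligned(
      long_strategy.begin() +
          static_cast<std::ptrdiff_t>(long_strategy.size() - short_shape.size()),
      long_strategy.end());
  for (size_t i = 0; i < short_shape.size(); ++i) {
    if (short_shape[i] == 1) aligned[i] = 1;  // never split a broadcast axis
  }
  return aligned;
}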
{splittable_vector, splittable_vector}; - if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - - // step3: reset the strategy if the dimension is 1 - for (auto &sp : *sp_vector) { - Dimensions input0_strategy = sp->GetInputDim()[0]; - Dimensions input1_strategy = sp->GetInputDim()[1]; - for (size_t i = 0; i < inputs_shape[0].size(); ++i) { - if (inputs_shape[0][i] == 1) { - input0_strategy[i] = 1; - } - - if (inputs_shape[1][i] == 1) { - input1_strategy[i] = 1; - } - } - sp->ResetInputs({input0_strategy, input1_strategy}); - } - - return SUCCESS; -} - -// 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that -// the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding -// dimension is splittable. 'inputs_partitions' is the result of partitions. -// NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring -// specific dimensions in inputs have the identical partition should have individual implementation. -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - if (splittable_inputs.size() != inputs_shape.size()) { - MS_LOG(ERROR) << "Splittable_inputs do not have the same input number of inputs shape, " << splittable_inputs.size() - << " : " << inputs_shape.size(); - return FAILED; - } - CheckGlobalDeviceManager(); - size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - - Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions; - for (size_t j = 0; j < inputs_shape.size(); ++j) { - (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end()); - (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(), - splittable_inputs[j].end()); - } - std::function recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape, - &combined_splittable_inputs, &combined_partitions, &recursive, - &inputs_shape](uint32_t current_index, size_t n) { - if (current_index == combined_inputs_shape.size()) { - MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); - Shapes inputs_partitions; - size_t global_index = 0; - for (auto &shape : inputs_shape) { - Shape tmp_partition; - for (size_t j = 0; j < shape.size(); ++j) { - tmp_partition.push_back(combined_partitions[global_index]); - global_index++; - } - inputs_partitions.push_back(tmp_partition); - } - StrategyPtr sp; - if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) { - sp_vector->push_back(sp); - } - return; - } else { - MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size(); - if (combined_splittable_inputs[current_index] == 0) { - combined_partitions.push_back(MIN_SLICE_NUM); - recursive(current_index + 1, n / MIN_SLICE_NUM); - combined_partitions.pop_back(); - } else if (combined_splittable_inputs[current_index] == 1) { - for (uint32_t i = 1; i <= n; i *= 2) { - if (n % i == 0 && IntToSize(combined_inputs_shape[current_index]) % i == 0) { - combined_partitions.push_back(i); - recursive(current_index + 1, n / i); - 
combined_partitions.pop_back(); - } - } - } - } - }; - recursive(0, dev_num); - if (sp_vector->empty()) { - MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo."; - } - return SUCCESS; -} - -// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, -// and the corresponding dimensions that are not broadcast are all relevant dimensions -// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *const sp_vector) { - if (sp_vector == nullptr) { - MS_LOG(ERROR) << "The sp_vector is null."; - return FAILED; - } - - if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) { - MS_LOG(ERROR) << "The inputs' size is wrong."; - return FAILED; - } - - if (inputs_shape[0] == inputs_shape[1]) { - // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy - if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success."; - } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) { - // ([a, b, c, d], []) or ([], [a, b, c, d]) - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "Generate strategies for scalar case failed."; - return FAILED; - } - MS_LOG(INFO) << "Generate strategies for scalar case success."; - } else if (inputs_shape[0].size() > inputs_shape[1].size()) { - // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) - if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success."; - } else if (inputs_shape[0].size() < inputs_shape[1].size()) { - // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) - if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success."; - } else { // same size, but different value - // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) - if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) { - MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed."; - return FAILED; - } - MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success."; - } - return SUCCESS; -} - -Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { - if (InitForCostModel(strategy) == FAILED) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed."; - } - return FAILED; - } - int32_t stage_id = strategy->GetInputStage(); - double computation_cost = - operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - double communication_cost = 
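// A sketch of the case analysis GenerateStrategiesWithBroadcast performs above;
// the enum and function names are illustrative, not from the original source.
#include <cstddef>

enum class TwoInputCase { kEqualShapes, kScalar, kBroadcastRight, kBroadcastLeft, kBroadcastBoth };

TwoInputCase ClassifyTwoInputs(bool shapes_equal, std::size_t rank0, std::size_t rank1) {
  if (shapes_equal) return TwoInputCase::kEqualShapes;         // identical strategies
  if (rank0 == 0 || rank1 == 0) return TwoInputCase::kScalar;  // independent handling
  if (rank0 > rank1) return TwoInputCase::kBroadcastRight;     // input1 broadcasts
  if (rank0 < rank1) return TwoInputCase::kBroadcastLeft;      // input0 broadcasts
  return TwoInputCase::kBroadcastBoth;  // same rank, size-1 dimensions on either side
}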
-      operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
-  result->communication_without_parameter_ =
-      operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
-  result->communication_with_partial_para_ =
-      result->communication_without_parameter_ +
-      COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_);
-
-  // Breaking ties for preferring data parallelization
-  BreakingTiesForPerferringDataParallel(strategy, result);
-  // refine communication cost calculation for practice
-  RefineForPracticalCost(result, false);
-  result->communication_forward_ = result->communication_without_parameter_;
-
-  std::shared_ptr<StrategyWithCost> swc =
-      std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
-  swc->cost_list.push_back(result);
-  strategy_cost_.emplace_back(swc);
-
-  return SUCCESS;
-}
-
-int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
-  if (is_output_parameter_involve_ != -1) {
-    return is_output_parameter_involve_;
-  }
-  is_parameter_involve_ = is_parameter_;
-  const auto &prev_edges = this->GetAlivePrevEdges();
-  for (auto &p_edge : prev_edges) {
-    auto input_index = p_edge->next_op_input_index();
-    auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
-    if (input_index >= is_parameter_involve_.size()) {
-      MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
-                        << ", but got wrong input_index: " << input_index;
-    }
-    if (prev_op_para == 0) {
-      is_parameter_involve_[input_index] = false;
-    } else if (prev_op_para == 1) {
-      is_parameter_involve_[input_index] = true;
-    } else {
-      MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
-    }
-    p_edge->set_parameter_involve(prev_op_para);
-  }
-  if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
-    // If any of the inputs is parameter_involved, the output is parameter_involved.
-    is_output_parameter_involve_ = 1;
-  } else {
-    is_output_parameter_involve_ = 0;
-  }
-
-  return is_output_parameter_involve_;
-}
-
-Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
-  if (is_parameter.size() != inputs_shape_.size()) {
-    MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size()
-                  << " does not match the number of inputs_shape_: " << inputs_shape_.size();
-    return FAILED;
-  }
-  is_parameter_ = is_parameter;
-  operator_cost()->set_is_parameter(is_parameter);
-  return SUCCESS;
-}
-
-Status OperatorInfo::CalculateMemoryCost() {
-  // First, set the 'is_parameter_involve_' and 'is_output_parameter_involve_' into OperatorCost, which are necessary
-  // to calculate memory cost.
-  if (is_parameter_involve_.size() != is_parameter_.size()) {
-    MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same size.";
-    return FAILED;
-  }
-  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
-  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
-  // Set the memory cost in the 'strategy_cost_'
-  for (auto &swc : strategy_cost_) {
-    auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
-    swc->cost_list[0]->memory_with_reuse_ = mem_cost;
-  }
-  return SUCCESS;
-}
-
-Status OperatorInfo::CalculateMemoryCostForInference() {
-  // First, set the 'is_outputs_critical_' flag into OperatorCost.
-  if (is_output_critical_ == -1) {
-    MS_LOG(EXCEPTION) << "The critical flag is not set.";
-    return FAILED;
-  }
-  operator_cost()->set_output_critical(is_output_critical_);
-  // Set the memory cost in the 'strategy_cost_'
-  for (auto &swc : strategy_cost_) {
-    auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
-    swc->cost_list[0]->memory_with_reuse_ = mem_cost;
-  }
-  return SUCCESS;
-}
-
-Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
-  for (auto &swc : strategy_cost_) {
-    double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
-                                static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
-    swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
-    if (swc->cost_list[0]->memory_with_reuse_ < 0) {
-      MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
-                    << ", the parameter memory cost is: " << parameter_mem_cost;
-      return FAILED;
-    }
-  }
-  return SUCCESS;
-}
-
-int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) {
-  int32_t ret = -1;
-
-  // The number of repetitions is equal to the number of all devices divided by the number of devices
-  // used for the tensor map.
-  int32_t device_num =
-      std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int32_t>());
-  for (auto &element : tensor_map) {
-    // -1 means the corresponding dimension is not split.
-    if (element == MAP_NONE) {
-      continue;
-    } else if ((element < 0) || (IntToSize(element) >= dev_matrix_shape.size())) {
-      MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is "
-                    << ShapeToString(dev_matrix_shape);
-      return ret;
-    } else {
-      size_t index = dev_matrix_shape.size() - IntToSize(element) - 1;
-      if (dev_matrix_shape[index] <= 0) {
-        MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape);
-        return ret;
-      }
-      device_num /= dev_matrix_shape[index];
-    }
-  }
-
-  return device_num;
-}
-
-Status OperatorInfo::InferAsLossDivisor() {
-  if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
-    as_loss_divisor_ = 1;
-    return SUCCESS;
-  }
-
-  if (outputs_tensor_map_.empty()) {
-    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
-    return FAILED;
-  }
-
-  if (outputs_tensor_map_.size() > 1) {
-    MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size()
-                  << ", need to override this function.";
-    return FAILED;
-  }
-
-  if (outputs_tensor_map_[0].empty()) {
-    as_loss_divisor_ = SizeToInt(global_device_list_.size());
-    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_
-                 << " as loss divisor.";
-    return SUCCESS;
-  }
-
-  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
-  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(dev_matrix_shape_)
-               << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
-               << as_loss_divisor_;
-  return SUCCESS;
-}
-
-// If the operator is used as a loss, a div node is inserted for the grad of all its inputs.
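// For example, with dev_matrix_shape_ = (2, 4), an output tensor map (1, 0) consumes all 8 devices, so the
// divisor is 1 and no div op is needed; an output tensor map (1, -1) consumes only the device dimension of
// size 2, so the loss is repeated on 8 / 2 = 4 devices and the gradients are divided by 4.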
-Status OperatorInfo::InferVirtualDivOps() { - if (InferAsLossDivisor() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed."; - return FAILED; - } - - if (as_loss_divisor_ <= 0) { - MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_; - return FAILED; - } else if (as_loss_divisor_ == 1) { - MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op."; - return SUCCESS; - } - - virtual_div_op_.clear(); - // if loss is repeated calculation, insert div op - Operator op = CreateVirtualDivOp(as_loss_divisor_); - virtual_div_op_.push_back(op); - return SUCCESS; -} - -Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, - const std::vector &output_lengths) { - if (input_lengths.size() != inputs_shape_.size()) { - MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() - << " do not have the same number of inputs shape: " << inputs_shape_.size(); - return FAILED; - } - if (output_lengths.size() != outputs_shape_.size()) { - MS_LOG(ERROR) << "Output_lengths: " << output_lengths.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - return FAILED; - } - inputs_type_lengths_ = input_lengths; - outputs_type_lengths_ = output_lengths; - operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths); - return SUCCESS; -} - -double OperatorInfo::GetOutputsTotalSize() { - if (is_calculated_outputs_size_) { - return outputs_total_size_; - } - if (outputs_type_lengths_.size() != outputs_shape_.size()) { - MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - } - double sum = 0.0; - for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) { - auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast(1.0), - std::multiplies()); - sum += size * static_cast(outputs_type_lengths_[i]); - } - is_calculated_outputs_size_ = true; - outputs_total_size_ = sum; - return outputs_total_size_; -} - -Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { - if (outputs_type.size() != outputs_shape_.size()) { - MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() - << " do not have the same number of outputs shape: " << outputs_shape_.size(); - return FAILED; - } - outputs_type_ = outputs_type; - return SUCCESS; -} - -void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { - if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { - CheckGlobalDeviceManager(); - auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); - if (IntToSize(stra->GetInputDim()[0][0]) == total_device_num) { - if (cost->computation_cost_ > 1.0) { - cost->computation_cost_ -= 1.0; - } - if (cost->communication_cost_ > 1.0) { - cost->communication_cost_ -= 1.0; - } - if (cost->communication_with_partial_para_ > 1.0) { - cost->communication_with_partial_para_ -= 1.0; - } - if (cost->communication_without_parameter_ > 1.0) { - cost->communication_without_parameter_ -= 1.0; - } - } - } -} - -double OperatorInfo::GetForwardMemoryCostFromCNode() { - return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0); -} - -void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) { - MS_EXCEPTION_IF_NULL(s_strategy); - if (!s_strategy->IsEqual(selected_strategy_)) { - MS_LOG(INFO) << name() << "'s strategy may cause 
suboptimal, the determined strategy:"; - PrintStrategy(selected_strategy_); - MS_LOG(INFO) << "The minimal strategy:"; - PrintStrategy(s_strategy); - } -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h deleted file mode 100644 index a3e6bc2c06..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "common/utils.h" -#include "base/base.h" -#include "parallel/auto_parallel/costmodel.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/group_manager.h" -#include "parallel/ops_info/ops_utils.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_info.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -using ForwardOp = OperatorVector; -using MirrorOps = std::vector; -using Ops = std::vector; -using VirtualDivOp = OperatorVector; -using TensorMaps = std::vector>; -using TensorLayouts = std::vector; -using different_type = std::vector::difference_type; -using PrimitiveAttrs = std::unordered_map; -using Strategys = std::vector; -using ReplaceGraphPtr = std::shared_ptr>, AnfNodePtr>>; - -class Edge; - -class OperatorInfo { - public: - OperatorInfo(std::string name, Shapes inputs_shape, Shapes outputs_shape, PrimitiveAttrs attrs, OperatorCostPtr cost) - : name_(std::move(name)), - inputs_shape_(std::move(inputs_shape)), - outputs_shape_(std::move(outputs_shape)), - attrs_(std::move(attrs)), - is_alive_(true), - operator_cost_(cost), - outputs_type_() { - std::vector not_parameteter(inputs_shape_.size(), false); - is_parameter_ = not_parameteter; - refkey_parameter_name_ = ""; - } - - virtual ~OperatorInfo() = default; - - Status set_is_parameter(const std::vector &is_parameter); - Status SetInputAndOutputTypeLength(const std::vector &input_lengths, - const std::vector &output_lengths); - double GetOutputsTotalSize(); - // Set outputs dtype. - // If only one output, outputs_type.size() is 1. - // If output is tuple, outputs_type.size() is greater than 1. 
- Status set_outputs_type(const std::vector &outputs_type); - const std::vector &outputs_type() const { return outputs_type_; } - virtual Status Init(const StrategyPtr &strategy) = 0; - virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts - - // Given the stage_id (which indicates the number of devices), - // generate all strategies for this operator - virtual Status GenerateStrategies(int32_t stage_id) = 0; - const OperatorCostPtr &operator_cost() const { return operator_cost_; } - void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } - virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; - - virtual std::shared_ptr>> GenerateBatchStrategies(); - virtual void ReComputeBatchSplitFlagList(); - void ComputeBatchSplitFlagList(); - - double GetForwardMemoryCostFromCNode(); - // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy - // is checked - Status SetCostUnderStrategyBase(const StrategyPtr &strategy); - std::vector> GetStrategyCost() { return strategy_cost_; } - // In the training phase, when the input of a operator contains WEIGHT or a output from other operators involving - // WEIGHT, then these input should stay in memory until it is used in the backward phase, which is kept in memory - // at the end of forward phase. - Status CalculateMemoryCost(); - // In the inference phase, the memory cost is incurred only when the operator is critical. The size is calculated - // by the output - Status CalculateMemoryCostForInference(); - int ComputeOpAndPrevEdgeParameterInvolved(); - - ForwardOp forward_op() const { return forward_op_; } - ForwardOp replace_op() const { return replace_op_; } - OutPutInfoVector replace_op_info() const { return replace_op_info_; } - virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } - MirrorOps mirror_ops() const { return mirror_ops_; } - Ops sub_ops() const { return sub_ops_; } - VirtualDivOp virtual_div_op() const { return virtual_div_op_; } - Shape dev_matrix_shape() const { return dev_matrix_shape_; } - std::vector inputs_tensor_info() const { return inputs_tensor_info_; } - std::vector outputs_tensor_info() const { return outputs_tensor_info_; } - std::vector> strategy_cost() const { return strategy_cost_; } - const std::string &name() const { return name_; } - void set_name(const std::string &name) { name_ = name; } - RankList global_device_list() const { return global_device_list_; } - - void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } - void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } - std::vector> succ_edges() const { return succ_edges_; } - std::vector> prev_edges() const { return prev_edges_; } - std::vector> GetAliveSuccEdges(); - std::vector> GetAlivePrevEdges(); - void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); - void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); - std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } - void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { - selected_strategy_ = s_strategy; - selected_cost_ = cost; - } - StrategyPtr selected_strategy() const { return selected_strategy_; } - CostPtr selected_cost() const 
{ return selected_cost_; }
-  void CheckSelectedStrategy(const StrategyPtr &);
-  Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
-  void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
-  const std::vector<ValuePtr> &input_value() const { return input_value_; }
-  void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
-  void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
-  bool is_alive() const { return is_alive_; }
-  void SetNotAlive() { is_alive_ = false; }
-  StrategyPtr strategy() const { return strategy_; }
-  void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; }
-  void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); }
-  const std::string &refkey_parameter_name() const { return refkey_parameter_name_; }
-  // When the output of a Parameter (require_grad) is used by multiple operators, the Parameter's cost is calculated
-  // multiple times. This method corrects this, so that the cost is calculated only once.
-  Status CorrectMemoryCost(size_t input_index);
-  int is_output_parameter_involve() const { return is_output_parameter_involve_; }
-  int is_output_critical() const { return is_output_critical_; }
-  void mark_output_critical() { is_output_critical_ = 1; }
-  void mark_output_not_critical() { is_output_critical_ = 0; }
-  int used_devices() const { return used_devices_; }
-  // needed by rec_parser
-  void set_type(const std::string &type) { type_ = type; }
-  const std::string &type() const { return type_; }
-  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
-
- protected:
-  // needed by rec_parser
-  std::string type_;
-  virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
-  virtual Status InferTensorMap() = 0;
-  virtual Status InferForwardCommunication() = 0;
-  virtual Status InferMirrorOps() = 0;
-  virtual Status GetAttrs() = 0;
-  virtual Status InferTensorInfo() = 0;
-  virtual Status InferDevMatrixShape() = 0;
-  void SetDeviceListByStrategy();
-  void SetRepeatedCalcDevMatrix();
-  Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
-  Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
-  Status InferAttrs();
-  void ResetQueueMember();
-  Status InitWithAutoRepeatCalc(const StrategyPtr &strategy);
-  Status InitWithManualRepeatCalc(const StrategyPtr &strategy);
-  Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy);
-  Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy);
-  Status InferRepeatedCalcInfo();
-  Status InferVirtualDivOps();
-
-  // Calculate the number of repeated calculations for the output by the number of devices and the output tensor map.
-  // The tensor map of Outputs[0] is used by default. If there are multiple outputs, it is necessary to identify
-  // which output is used for grad and overload the function. If the output is a scalar, the function also needs to
-  // be overridden.
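// For example, ArgMaxWithValueInfo has two outputs and overrides this to infer from outputs_tensor_map_[0];
// a scalar output (an empty tensor map) falls back to the size of the whole device group.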
- virtual Status InferAsLossDivisor(); - Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, - Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); - void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); - - std::string name_; - Shapes inputs_shape_; - Shapes outputs_shape_; - std::unordered_map attrs_; - std::vector input_value_; - TypePtr outputs_dtype_; - - StrategyPtr strategy_; - std::vector inputs_tensor_info_; - std::vector outputs_tensor_info_; - Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension - int32_t repeated_calc_num_ = 1; - int32_t as_loss_divisor_ = 1; - TensorMaps inputs_tensor_map_; - TensorMaps outputs_tensor_map_; - ForwardOp forward_op_; - Ops sub_ops_; - ForwardOp replace_op_; - OutPutInfoVector replace_op_info_; - ReplaceGraphPtr replace_graph_; - MirrorOps mirror_ops_; - VirtualDivOp virtual_div_op_; - RankList global_device_list_; // the size of global_device_list equal to the size of stageID - RankList local_device_list_; // the size equal to global_device_list_.size() / repeated_calc_num_ - bool infer_attrs_completed_ = false; - - bool is_auto_parallel_ = false; // false: semi_auto_parallel; true: auto_parallel - // 'corrected_input_indices_' used to store the indices of input that have ALREADY been corrected. - std::vector corrected_input_indices_; - // Given a parallization strategy, there is a cost. - std::vector> strategy_cost_; - // For each input in 'inputs_', there is a bool variable indicating whether that the corresponding input is parameter - std::vector is_parameter_; - // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of - // pre-operator that has parameters as input. - std::vector is_parameter_involve_; - // If any input is parameter-involved, the output is parameter-involved. This variable is used in calculating - // peak memory cost in the training phase. - // -1: unset; 0: not parameter_involved; 1: parameter_involved - int is_output_parameter_involve_ = -1; - // Whether this output is critical, which means that this output is included in calculating peak memory cost - // in the inference phase. 
- // -1 : unset; 0: not critical; 1: critical - int is_output_critical_ = -1; - double outputs_total_size_ = 0.0; - bool is_calculated_outputs_size_ = false; - // for each input and output, the followings record the number of bytes of each element - std::vector inputs_type_lengths_; - std::vector outputs_type_lengths_; - std::vector> prev_edges_; - std::vector> succ_edges_; - StrategyPtr selected_strategy_; - // Used in DP algorithm - bool is_alive_; - CostPtr selected_cost_; - std::vector split_flag_list_; - std::string refkey_parameter_name_; - CNodePtr cnode_; - int32_t used_devices_ = -1; - - private: - OperatorCostPtr operator_cost_; - std::vector outputs_type_; -}; - -Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); -Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); -Operator CreateVirtualDivOp(int32_t div_num); -Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); -Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group); -Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); -OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); -std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes &shapes, const std::vector &split_flag_list); - -void PrintStrategy(const StrategyPtr &strategy); -// generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, - const Shapes &splittable_inputs, std::vector *sp_vector); -// generate strategies for that have two inputs, and input0 or input1 maybe broadcast, -// and the corresponding dimensions that are not broadcast are all relevant dimensions -// such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -// or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -// or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, - std::vector *sp_vector); - -Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPERATOR_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h deleted file mode 100644 index 45b00aed30..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/ops_info_head_files.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
-#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
-
-#include "parallel/ops_info/activation_info.h"
-#include "parallel/ops_info/arithmetic_info.h"
-#include "parallel/ops_info/batch_parallel_info.h"
-#include "parallel/ops_info/bias_add_info.h"
-#include "parallel/ops_info/comparison_function_info.h"
-#include "parallel/ops_info/dropout_do_mask_info.h"
-#include "parallel/ops_info/elementary_function_info.h"
-#include "parallel/ops_info/gather_v2_info.h"
-#include "parallel/ops_info/get_next_info.h"
-#include "parallel/ops_info/l2_normalize_info.h"
-#include "parallel/ops_info/layer_norm_info.h"
-#include "parallel/ops_info/loss_info.h"
-#include "parallel/ops_info/matmul_info.h"
-#include "parallel/ops_info/onehot_info.h"
-#include "parallel/ops_info/prelu_info.h"
-#include "parallel/ops_info/reduce_method_info.h"
-#include "parallel/ops_info/reshape_info.h"
-#include "parallel/ops_info/transpose_info.h"
-#include "parallel/ops_info/virtual_dataset_info.h"
-#include "parallel/ops_info/gather_v2_p_info.h"
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_OPS_INFO_HEAD_FILES_H_
diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc
deleted file mode 100644
index 14483e97a1..0000000000
--- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc
+++ /dev/null
@@ -1,253 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/ops_info/prelu_info.h"
-
-#include
-#include
-#include
-
-#include "parallel/device_manager.h"
-#include "parallel/device_matrix.h"
-#include "parallel/step_parallel.h"
-#include "utils/convert_utils.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace parallel {
-/*
- * PReLU has 2 inputs
- * A: a float tensor of shape [NCHW] representing the output of the previous layer
- * w: a float tensor, w > 0; only two shapes are legitimate: 1, or the number of channels of the input.
- * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 - */ -Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - std::vector stra = strategy->GetInputDim(); - if (stra[1].size() != PRELU_SECOND_INPUT_SIZE) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy size."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy size."; - } - return FAILED; - } - if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid channel strategy."; - } - return FAILED; - } - return SUCCESS; -} - -/* - * device matrix is same with the strategy matrix - */ -Status PReLUInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - input_strategy_ = input_strategy; - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status PReLUInfo::InferMirrorOps() { - Shape param_tensor_map = inputs_tensor_map_[1]; - std::vector param_group; - if (CreateGroupByTensorMap(param_tensor_map, ¶m_group) != SUCCESS) { - return FAILED; - } else if (param_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } - OperatorVector op_for_param; - op_for_param = CreateMirrorOps(param_group[0].name(), param_group[0].GetDevNum()); - // op_for_inputs is empty - OperatorVector op_for_inputs; - mirror_ops_.push_back(op_for_inputs); - mirror_ops_.push_back(op_for_param); - std::string group_name = param_group[0].name(); - MS_LOG(INFO) << name_ << ": The mirror ops group is " << group_name; - return SUCCESS; -} - -Status PReLUInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * the output tensor map is the same as the input tensor map - */ -Status PReLUInfo::InferTensorMap() { - TensorMap input_tensor_map; - // such as 4: input_tensor_map [3,2,1,0] - for (size_t i = 0; i < inputs_shape_[0].size(); ++i) { - input_tensor_map.push_back((int32_t)(inputs_shape_[0].size() - i - 1)); - } - - TensorMap param_tensor_map; - if (inputs_shape_[1][0] == 1) { - param_tensor_map.push_back(-1); - } else { - param_tensor_map.push_back(input_tensor_map.at(1)); - } - inputs_tensor_map_.push_back(input_tensor_map); - inputs_tensor_map_.push_back(param_tensor_map); - outputs_tensor_map_.push_back(input_tensor_map); - return SUCCESS; -} - -Dimensions PReLUInfo::GetOutputStrategy() { - Dimensions output_strategy = input_strategy_; - return output_strategy; -} - -Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if (inputs_layout == nullptr || outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - TensorLayout input_layout, param_layout, output_layout; - if ((input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || - (param_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || - (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) { - return FAILED; - } - inputs_layout->push_back(input_layout); - 
inputs_layout->push_back(param_layout); - outputs_layout->push_back(output_layout); - return SUCCESS; -} - -Status PReLUInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape param_shape = inputs_shape_.at(1); - Shape output_shape = outputs_shape_.at(0); - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Dimensions output_strategy = GetOutputStrategy(); - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape param_slice_shape = inputs_slice_shape.at(1); - Shape output_slice_shape = outputs_slice_shape.at(0); - - // infer tensor layout - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - - TensorLayout input_layout = inputs_layout.at(0); - TensorLayout param_layout = inputs_layout.at(1); - TensorLayout output_layout = outputs_layout.at(0); - TensorInfo input_tensor_info(input_layout, input_shape, input_slice_shape); - TensorInfo param_tensor_info(param_layout, param_shape, param_slice_shape); - TensorInfo output_tensor_info(output_layout, output_shape, output_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - inputs_tensor_info_.push_back(param_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status PReLUInfo::GetAttrs() { - if ((inputs_shape_.size() != PRELU_INPUTS_SIZE) || (outputs_shape_.size() != PRELU_OUTPUTS_SIZE)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size " << inputs_shape_.size() << " or outputs shape size " - << outputs_shape_.size() << " is wrong."; - return FAILED; - } - return SUCCESS; -} - -Status PReLUInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status PReLUInfo::GenerateStrategies(int32_t stage_id) { - if (inputs_shape_.size() != PRELU_INPUTS_SIZE) { - return FAILED; - } - if (inputs_shape_[1].size() != PRELU_SECOND_INPUT_SIZE) { - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split; - input0_split.emplace_back(1); - input0_split.emplace_back(0); - (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size() - 2, 1); - Shape input1_split(inputs_shape_[1].size(), 0); - Shapes splittable_inputs = {input0_split, input1_split}; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - 
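A minimal standalone sketch of the enumeration behind GenerateStrategiesForIndependentInputs, which PReLUInfo::GenerateStrategies above drives with the channel dimension of input 0 and all of input 1 marked unsplittable. The shapes, the Enumerate helper, and its names are hypothetical; for brevity it only keeps strategies that consume all devices, whereas the actual routine also admits strategies that use fewer devices through repeated calculation.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Depth-first enumeration: each splittable dimension takes a power-of-two factor that divides both the
// remaining device count and the dimension size; unsplittable dimensions are pinned to 1.
void Enumerate(const Shape &shape, const Shape &splittable, size_t dim, int64_t dev_left, Shape *current,
               std::vector<Shape> *out) {
  if (dim == shape.size()) {
    if (dev_left == 1) {  // for brevity, keep only strategies that consume all devices
      out->push_back(*current);
    }
    return;
  }
  for (int64_t f = 1; f <= dev_left; f *= 2) {
    if (splittable[dim] == 0 && f > 1) {
      break;  // unsplittable dimension: the factor stays 1
    }
    if (dev_left % f != 0 || shape[dim] % f != 0) {
      continue;
    }
    current->push_back(f);
    Enumerate(shape, splittable, dim + 1, dev_left / f, current, out);
    current->pop_back();  // backtrack
  }
}

int main() {
  const Shape shape = {32, 16, 8};     // hypothetical input shape
  const Shape splittable = {1, 0, 1};  // dimension 1 must not be split, like PReLU's channel dimension
  std::vector<Shape> strategies;
  Shape current;
  Enumerate(shape, splittable, 0, /*dev_num=*/8, &current, &strategies);
  for (const auto &s : strategies) {
    std::cout << "(" << s[0] << ", " << s[1] << ", " << s[2] << ")\n";  // (1, 1, 8), (2, 1, 4), (4, 1, 2), (8, 1, 1)
  }
  return 0;
}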
-Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h deleted file mode 100644 index 28e149fad7..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for PReLU Primitive - */ -class PReLUInfo : public OperatorInfo { - public: - PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~PReLUInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Dimensions GetOutputStrategy(); - - private: - Dimensions input_strategy_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_PRELU_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc deleted file mode 100644 index 7304666a77..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.cc +++ /dev/null @@ -1,571 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/reduce_method_info.h" - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/tensor_layout/tensor_redistribution.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - - dev_matrix_shape_ = input_strategy; - - return SUCCESS; -} - -std::vector ReduceMethod::reduce_dim() { - std::vector dim_list; - if (input_value_.size() < 2) { - MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2."; - } - if (input_value_.back() == nullptr) { - MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr."; - } - MS_ASSERT(inputs_shape_.size() == 1); - auto input_dim = inputs_shape_.at(0).size(); - if (input_value_.back()->isa()) { - auto attr_axis = GetValue>(input_value_.back()); - // axis is (), reduce all dim - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (input_value_.back()->isa()) { - int axis = GetValue(input_value_.back()); - axis < 0 ? 
dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - - return dim_list; -} - -Status ReduceMethod::GetAttrs() { - // get attr cross_batch and keep_dims - auto keep_dims_iter = attrs_.find(KEEP_DIMS); - if (keep_dims_iter == attrs_.end()) { - MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims."; - return FAILED; - } - - if (keep_dims_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(keep_dims_iter->second); - if (!keep_dims_iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool."; - return FAILED; - } - keepdims_ = keep_dims_iter->second->cast()->value(); - } - - auto cross_batch_iter = attrs_.find(CROSS_BATCH); - if (cross_batch_iter != attrs_.end()) { - MS_EXCEPTION_IF_NULL(cross_batch_iter->second); - if (!cross_batch_iter->second->isa()) { - MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; - return FAILED; - } - cross_batch_ = cross_batch_iter->second->cast()->value(); - } - auto reducemethodcost = std::dynamic_pointer_cast(operator_cost()); - if (reducemethodcost == nullptr) { - MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; - return FAILED; - } - reducemethodcost->set_cross_batch(cross_batch_); - return SUCCESS; -} - -Status ReduceMethod::InferTensorMap() { - std::vector tensor_map_index, dim_list, output_tensor_map; - size_t size = inputs_shape_.at(0).size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); - } - dim_list = reduce_dim(); - for (size_t i = 0; i < size; ++i) { - if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { - if (keepdims_) { - output_tensor_map.push_back(-1); - } else { - continue; - } - } else { - output_tensor_map.push_back(tensor_map_index[i]); - } - } - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(output_tensor_map); - - return SUCCESS; -} - -bool IsDataParallelStrategy(const Dimensions &strategy) { - CheckGlobalDeviceManager(); - size_t total_dev_num = g_device_manager->GetDeviceListByStageId(0).size(); - if (strategy.empty()) { - MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty"; - } - - return (IntToSize(strategy[0]) == total_dev_num); -} - -Status ReduceMethod::InferForwardCommunication() { - Dimensions stra = strategy_->GetInputDim().at(0); - if (cross_batch_ && IsDataParallelStrategy(stra)) { - MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; - return SUCCESS; - } - if (cross_batch_) { - MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication"; - return SUCCESS; - } - forward_op_.clear(); - std::vector dim_list = reduce_dim(); - size_t size = stra.size(); - // judge if the reduce dim is partitioned. 
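// For example, for a 3-d input with stra = (2, 1, 4) and dim_list = {2}: dimension 2 is both reduced and
// split, so it is skipped below and group_creat_map becomes [2, 1]; the AllReduce group created from this
// map then spans the 4 devices that hold slices of the reduced dimension.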
-  Shape group_creat_map;
-  if (dev_matrix_shape_.size() > size) {
-    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
-  }
-  for (size_t index = 0; index < size; ++index) {
-    auto pos =
-      std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; });
-    if (pos != dim_list.end() && stra[index] != 1) {
-      continue;
-    }
-    group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
-  }
-  std::vector<Group> forward_group;
-  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
-    return FAILED;
-  }
-  if (!forward_group.empty()) {
-    Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name());
-    forward_op_.push_back(op);
-    std::string group_name = forward_group[0].name();
-    MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name;
-  }
-
-  return SUCCESS;
-}
-
-ForwardOp CreatReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
-  // Create AllReduceSum op
-  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
-  std::string group_name = forward_group[0].name();
-  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
-
-  // Create RealDiv op
-  OperatorName operator1_name = REAL_DIV;
-  std::vector device_list = forward_group[0].GetDevicesList();
-  auto divisor = static_cast<float>(device_list.size());
-  std::vector<float> tensor_data = {divisor};
-  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, dtype);
-  ValuePtr op1_param_value = MakeValue(tensor_ptr);
-  Attr op1_param = std::make_pair("divisor", op1_param_value);
-  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
-  OperatorAttrs operator1_attrs;
-  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
-  Operator op1 = std::make_pair(operator1_name, operator1_args);
-  ForwardOp forward_op = {op0, op1};
-
-  std::string dtype_name = dtype->ToString();
-  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
-  return forward_op;
-}
-
-Status ReduceMeanInfo::InferForwardCommunication() {
-  Dimensions stra = strategy_->GetInputDim().at(0);
-  if (cross_batch_ && IsDataParallelStrategy(stra)) {
-    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
-    return SUCCESS;
-  }
-  forward_op_.clear();
-  std::vector<int32_t> dim_list = reduce_dim();
-  size_t size = stra.size();
-  // judge if the reduce dim is partitioned.
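// Unlike the base class, ReduceMean cannot use a bare AllReduce: with equal-sized slices, e.g. [1, 2, 3, 4]
// split across 2 devices, the local means are 1.5 and 3.5; AllReduceSum yields 5 on every device, and the
// RealDiv by the group size (see CreatReduceMeanForwardOp above) restores the global mean 2.5.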
-  Shape group_creat_map;
-  if (dev_matrix_shape_.size() > size) {
-    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
-  }
-  for (size_t index = 0; index < size; ++index) {
-    auto pos =
-      std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; });
-    if (pos != dim_list.end() && stra[index] != 1) {
-      continue;
-    }
-    group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1);
-  }
-  std::vector<Group> forward_group;
-  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
-    return FAILED;
-  }
-  if (!forward_group.empty()) {
-    if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa()) {
-      MS_LOG(ERROR) << name_ << ": The dtype of output is not Array";
-      return FAILED;
-    }
-
-    auto element_type = outputs_dtype_->cast()->element();
-    forward_op_ = CreatReduceMeanForwardOp(forward_group, element_type);
-  }
-
-  return SUCCESS;
-}
-
-Status ReduceMethod::InferMirrorOps() {
-  mirror_ops_.clear();
-  Shape input_tensor_map = inputs_tensor_map_.at(0);
-  std::vector<Group> input_group;
-  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << " Infer MirrorOps failed.";
-    return FAILED;
-  }
-
-  OperatorVector op_for_weight;
-  OperatorVector op_for_reduce_axis;  // helper node
-  if (input_group.empty()) {
-    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
-    mirror_ops_.push_back(op_for_weight);
-    mirror_ops_.push_back(op_for_reduce_axis);
-    std::string group_name = input_group[0].name();
-    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name;
-  }
-
-  return SUCCESS;
-}
-
-Status ArgMaxWithValueInfo::InferMirrorOps() {
-  mirror_ops_.clear();
-  Shape input_tensor_map = inputs_tensor_map_.at(0);
-  std::vector<Group> input_group;
-  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
-    MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed.";
-    return FAILED;
-  }
-
-  OperatorVector op_for_weight;
-  if (input_group.empty()) {
-    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
-    return SUCCESS;
-  } else {
-    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
-    mirror_ops_.push_back(op_for_weight);
-    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success.";
-  }
-
-  return SUCCESS;
-}
-
-Dimensions ReduceMethod::InferOutputStrategy() {
-  std::vector<int32_t> dim_list = reduce_dim();
-  Dimensions output_strategy;
-  Dimensions stra = strategy_->GetInputDim().at(0);
-  // if keepdims_ is true, then the output strategy is the same as the input strategy.
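// For example, with stra = (2, 4, 1) and dim_list = {1}: keepdims_ == true yields (2, 1, 1), while
// keepdims_ == false drops the reduced dimension and yields (2, 1).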
- for (size_t i = 0; i < stra.size(); ++i) { - if (find(dim_list.begin(), dim_list.end(), SizeToInt(i)) != dim_list.end()) { - if (keepdims_) { - output_strategy.push_back(1); - } - } else { - output_strategy.push_back(stra[i]); - } - } - return output_strategy; -} - -Status ReduceMethod::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Dimensions output_strategy = InferOutputStrategy(); - - Strategys outputs_strategy = {output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape output_slice_shape = outputs_slice_shape.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - std::vector dim_list = reduce_dim(); - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - input_tensor_info.set_reduce_dim(dim_list); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - - return SUCCESS; -} - -Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - is_auto_parallel_ = true; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} - -Status ReduceMethod::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - return SUCCESS; -} - -Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed"; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed"; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success"; - return SUCCESS; -} - -std::vector ArgMaxWithValueInfo::reduce_dim() { - std::vector dim_list; - auto iter = 
attrs_.find(AXIS); - if (iter == attrs_.end()) { - MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis."; - } - - MS_ASSERT(inputs_shape_.size() == 1); - auto input_dim = inputs_shape_.at(0).size(); - MS_EXCEPTION_IF_NULL(iter->second); - if (iter->second->isa()) { - auto attr_axis = GetValue>(iter->second); - if (attr_axis.empty()) { - for (size_t i = 0; i < input_dim; ++i) { - dim_list.push_back(SizeToInt(i)); - } - } else { - for (auto &axis : attr_axis) { - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } - } - } else if (iter->second->isa()) { - int axis = GetValue(iter->second); - axis < 0 ? dim_list.push_back(axis + SizeToInt(input_dim)) : dim_list.push_back(axis); - } else { - MS_LOG(EXCEPTION) << "Axis type is invalid."; - } - - return dim_list; -} - -Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) { - if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; - } else { - MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; - } - return FAILED; - } - std::vector dim_list = reduce_dim(); - MS_ASSERT(dim_list.size() == 1); - - std::vector stra = strategy->GetInputDim(); - MS_ASSERT(stra.size() == 1); - Shape input_strategy = stra.at(0); - MS_ASSERT(dim_list.at(0) < input_strategy.size()); - if (input_strategy.at(IntToSize(dim_list.at(0))) != 1) { - MS_LOG(WARNING) - << name_ - << " CheckStrategy for ArgMaxWithValueInfo, the strategy corresponding to axis is not one, real strategy " - "is " - << input_strategy.at(IntToSize(dim_list.at(0))) - << ", the output index may be not compatible with the stand alone Primitive"; - } - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferTensorMap() { - if (ReduceMethod::InferTensorMap() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed"; - return FAILED; - } - MS_ASSERT(outputs_tensor_map_.size() == 1); - outputs_tensor_map_.push_back(outputs_tensor_map_[0]); - return SUCCESS; -} - -Status ArgMaxWithValueInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - Shape output_shape = outputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Dimensions output_strategy = InferOutputStrategy(); - - Strategys outputs_strategy = {output_strategy, output_strategy}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - Shape output_slice_shape = outputs_slice_shape.at(0); - - TensorLayout input_tensor_layout, output_tensor_layout; - if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) || - (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) { - return FAILED; - } - - std::vector dim_list = reduce_dim(); - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape); - input_tensor_info.set_reduce_dim(dim_list); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - outputs_tensor_info_.push_back(output_tensor_info); - return SUCCESS; -} - -Status 
ArgMaxWithValueInfo::InferAsLossDivisor() { - if (outputs_tensor_map_.empty()) { - MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; - return FAILED; - } - - MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer"; - if (outputs_tensor_map_[0].empty()) { - as_loss_divisor_ = SizeToInt(global_device_list_.size()); - MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor."; - return SUCCESS; - } - - as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]); - - std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_); - std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]); - MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << dev_matrix_shape_str - << ", " << output_tensor_map_str << ", " << as_loss_divisor_; - return SUCCESS; -} - -Status ArgMaxWithValueInfo::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 2)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - is_auto_parallel_ = true; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h b/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h deleted file mode 100644 index 796c7e457b..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reduce_method_info.h +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ - -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/value.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class ReduceMethod : public OperatorInfo { - public: - ReduceMethod(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} - ~ReduceMethod() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - std::string reduce_method_; - bool keepdims_ = false; - bool cross_batch_ = false; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override; - Dimensions InferOutputStrategy(); - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferMirrorOps() override; - virtual std::vector reduce_dim(); - Status InferForwardCommunication() override; - Status InferDevMatrixShape() override; -}; - -class ReduceMaxInfo : public ReduceMethod { - public: - ReduceMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MAX; - } - - ~ReduceMaxInfo() override = default; -}; - -class ArgMaxWithValueInfo : public ReduceMethod { - public: - ArgMaxWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MAX; - } - - ~ArgMaxWithValueInfo() override = default; - - Status GenerateStrategies(int32_t stage_id) override; - - protected: - std::vector reduce_dim() override; - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferAsLossDivisor() override; -}; - -class ArgMinWithValueInfo : public ArgMaxWithValueInfo { - public: - ArgMinWithValueInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ArgMaxWithValueInfo(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MIN; - } - - ~ArgMinWithValueInfo() override = default; -}; - -class ReduceMeanInfo : public ReduceMethod { - public: - ReduceMeanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - set_cost(std::make_shared()); - } - - ~ReduceMeanInfo() override = default; - - protected: - Status InferForwardCommunication() override; -}; - -class ReduceSumInfo : public ReduceMethod { - public: - ReduceSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_SUM; - } - - ~ReduceSumInfo() override = default; -}; - -class ReduceMinInfo : public ReduceMethod { 
- public: - ReduceMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : ReduceMethod(name, inputs_shape, outputs_shape, attrs) { - reduce_method_ = REDUCE_OP_MIN; - } - - ~ReduceMinInfo() override = default; -}; -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc deleted file mode 100644 index 57e1a76d0a..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ /dev/null @@ -1,507 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/reshape_info.h" - -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - size_t strategy_size = strategy->GetInputNumber(); - if (strategy_size != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy size " << strategy_size; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy size " << strategy_size; - } - return FAILED; - } - return SUCCESS; -} - -/* - * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of - * device matrix - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Status ReshapeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - input_strategy_ = stra.at(0); - dev_matrix_shape_.push_back(input_strategy_[0]); - return SUCCESS; -} - -/* - * there is no Parameter for Reshape Primitive, so no need to do allreduce - */ -Status ReshapeInfo::InferMirrorOps() { - mirror_ops_.clear(); - Shape input_tensor_map = input_layout_.tensor_map().array(); - std::vector input_group; - if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed."; - return FAILED; - } - - OperatorVector op_for_input; - if (input_group.empty()) { - MS_LOG(INFO) << name_ << ": The mirror ops is empty."; - return SUCCESS; - } - if (!input_group.empty()) { - op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum()); - std::string group_name = input_group[0].name(); - MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << group_name; - } - mirror_ops_.push_back(op_for_input); - 
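The deleted ReshapeInfo only supports batch-parallel sharding: InferDevMatrixShape() keeps just the first element of the input strategy, and, as the InferTensorMap() code just below shows, every non-batch dimension is mapped to MAP_NONE. A hedged sketch of that mapping, assuming MAP_NONE is -1 as elsewhere in this module:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int32_t kMapNone = -1;  // stand-in for the module's MAP_NONE

// Batch-parallel tensor map as built by the deleted ReshapeInfo::InferTensorMap():
// dimension 0 is bound to device-matrix axis 0, every other dimension is replicated.
std::vector<int32_t> BatchParallelTensorMap(size_t rank) {
  std::vector<int32_t> tensor_map;
  tensor_map.push_back(0);
  for (size_t j = 1; j < rank; ++j) {
    tensor_map.push_back(kMapNone);
  }
  return tensor_map;
}

int main() {
  assert((BatchParallelTensorMap(4) == std::vector<int32_t>{0, -1, -1, -1}));
  return 0;
}
```
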
OperatorVector op_for_input_empty; - mirror_ops_.push_back(op_for_input_empty); - - return SUCCESS; -} - -/* - * there is no reduction dimension for forward computation of Reshape Primitive, so no need to do allreduce - */ -Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * get shape input of Reshape Primitive - * the result is saved in parameter_input_v_ - * not support -1 - */ -Status ReshapeInfo::GetParameterInput() { - if (input_value_[1] == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; - return FAILED; - } - std::vector elements; - ValueTuplePtr dim_tuple = input_value_[1]->cast(); - if (dim_tuple == nullptr) { - MS_LOG(ERROR) << name_ << ": Input_value_[1] must be ValueTuplePtr."; - return FAILED; - } - elements = dim_tuple->value(); - if (elements.size() != outputs_shape_[0].size()) { - MS_LOG(ERROR) << name_ << ": Elements size must equal to outputs shape[0] size."; - return FAILED; - } - - for (auto &element : elements) { - MS_EXCEPTION_IF_NULL(element); - if (element->isa()) { - int32_t axis = element->cast()->value(); - parameter_input_v_.push_back(axis); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; - return FAILED; - } - } - return SUCCESS; -} - -Status ReshapeInfo::ComputeReplaceOp() { - RankList dev_list = global_device_list(); - TensorRedistribution tensor_redistribution(!is_generating_costs_, true); - if (tensor_redistribution.Init(input_layout_, output_layout_, dev_list) == FAILED) { - if (is_generating_costs_) { - MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed."; - } else { - MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; - } - return FAILED; - } - MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); - MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); - MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); - RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); - if (redistribution_oplist_ptr == nullptr) { - if (is_generating_costs_) { - MS_LOG(DEBUG) << name_ << "InferTensorRedistribution failed."; - } else { - MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; - } - return FAILED; - } - replace_op_ = redistribution_oplist_ptr->first; - replace_op_info_ = redistribution_oplist_ptr->second; - MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); - return SUCCESS; -} - -/* - * the first dimension of input tensor map and output tensor map is set to the last dimension of device arrangement, - * all other dimension is set to None - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Status ReshapeInfo::InferTensorMap() { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs shape and outputs shape size must be 1. 
inputs shape and outputs shape are " - << inputs_shape_.size() << " and " << outputs_shape_.size(); - return FAILED; - } - - std::vector tensor_map_index_input; - tensor_map_index_input.push_back(0); - - for (size_t j = 1; j < inputs_shape_[0].size(); ++j) { - tensor_map_index_input.push_back(MAP_NONE); - } - inputs_tensor_map_.push_back(tensor_map_index_input); - - std::vector tensor_map_index_output; - tensor_map_index_output.push_back(0); - - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { - tensor_map_index_output.push_back(MAP_NONE); - } - outputs_tensor_map_.push_back(tensor_map_index_output); - return SUCCESS; -} - -/* - * the output tensor strategy is the same as input tensor strategy - * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number) - */ -Strategys ReshapeInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - std::vector strategy; - strategy.push_back(input_strategy_[0]); - for (size_t j = 1; j < outputs_shape_[0].size(); ++j) { - strategy.push_back(1); - } - outputs_strategy.push_back(strategy); - return outputs_strategy; -} - -Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if (inputs_layout == nullptr || outputs_layout == nullptr) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - Arrangement dev_matrix; - Status status = dev_matrix.Init(dev_matrix_shape_); - if (status != Status::SUCCESS) { - return status; - } - // infer input tensor info - Shape shape_array_in = inputs_shape_.at(0); - TensorMap tensor_map_array_in = inputs_tensor_map_.at(0); - TensorLayout tensor_layout_in; - Map tensor_map_in; - status = tensor_map_in.Init(tensor_map_array_in); - if (status != Status::SUCCESS) { - return status; - } - Arrangement shape_in; - status = shape_in.Init(shape_array_in); - if (status != Status::SUCCESS) { - return status; - } - (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in); - inputs_layout->push_back(tensor_layout_in); - // infer output tensor info - Shape shape_array_out = outputs_shape_.at(0); - - TensorMap tensor_map_array_out = outputs_tensor_map_.at(0); - TensorLayout tensor_layout_out; - Map tensor_map_out; - status = tensor_map_out.Init(tensor_map_array_out); - if (status != Status::SUCCESS) { - return status; - } - Arrangement shape_out; - status = shape_out.Init(shape_array_out); - if (status != Status::SUCCESS) { - return status; - } - (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out); - outputs_layout->push_back(tensor_layout_out); - - input_layout_ = tensor_layout_in; - output_layout_ = tensor_layout_out; - return SUCCESS; -} - -Status ReshapeInfo::InferTensorInfo() { - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = GetOutputsStrategy(); - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - TensorLayout tensor_layout_in = inputs_layout.at(0); - TensorLayout tensor_layout_out = outputs_layout.at(0); - Shape shape_array_in = inputs_shape_.at(0); - Shape slice_shape_in = inputs_slice_shape.at(0); - Shape shape_array_out = outputs_shape_.at(0); - Shape slice_shape_out = outputs_slice_shape.at(0); - TensorInfo tensor_info_in(tensor_layout_in, 
shape_array_in, slice_shape_in); - TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); - return SUCCESS; -} - -void ReshapeInfo::InferTensorInfoByLayout() { - TensorInfo tensor_info_in(input_layout_); - TensorInfo tensor_info_out(output_layout_); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); -} - -/* - * compute parameter_input_v_ during this method - */ -Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } - -void ReshapeInfo::device_number(const StrategyPtr &strategy) { - int32_t stage = 0; - if (strategy != nullptr) { - stage = strategy->GetInputStage(); - } - CheckGlobalDeviceManager(); - global_device_list_ = g_device_manager->GetDeviceListByStageId(stage); - dev_num_ = SizeToInt(global_device_list_.size()); - MS_ASSERT(dev_num_ > 0); -} - -Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { - std::vector tensor_map_index; - for (size_t i = 0; i < shape.size(); i++) { - tensor_map_index.push_back(MAP_NONE); - } - Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape); - if (status != Status::SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed."; - return status; - } - return Status::SUCCESS; -} - -Status ReshapeInfo::Init(const StrategyPtr &strategy) { - ResetQueueMember(); - device_number(strategy); - if (strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - } else { - if (!input_layout_set_flag_) { - MS_ASSERT(inputs_shape_.size() == 1); - Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": infer input default layout failed."; - return status; - } - } - if (!output_layout_set_flag_) { - MS_ASSERT(output_layout_.size() == 1); - Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": infer output default layout failed."; - return status; - } - } - inputs_tensor_map_.push_back(input_layout_.tensor_map().array()); - outputs_tensor_map_.push_back(output_layout_.tensor_map().array()); - InferTensorInfoByLayout(); - // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps - dev_matrix_shape_ = input_layout_.device_arrangement().array(); - if (InferMirrorOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; - return FAILED; - } - // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps - dev_matrix_shape_ = output_layout_.device_arrangement().array(); - if (InferVirtualDivOps() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; - return FAILED; - } - } - Status status = ComputeReplaceOp(); - if (status != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed."; - return status; - } - return SUCCESS; -} - -Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status ReshapeInfo::SetCostUnderStrategy(const 
mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -void ReshapeInfo::SetCostForReshapeWithParameter() { - size_t success = 0; - for (auto &sp : sp_vector_) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } -} - -void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) { - MS_EXCEPTION_IF_NULL(strategy); - int32_t stage_id = strategy->GetInputStage(); - double computation_cost = - operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - std::shared_ptr result = std::make_shared(computation_cost, communication_cost); - result->communication_without_parameter_ = - operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id); - result->communication_with_partial_para_ = - result->communication_without_parameter_ + - COST_MODEL_GAMMA * (communication_cost - result->communication_without_parameter_); - - // Breaking ties for preferring data parallelization - BreakingTiesForPerferringDataParallel(strategy, result); - // refine communication cost calculation for practice - RefineForPracticalCost(result, false); - - std::shared_ptr swc = - std::make_shared(strategy, inputs_tensor_info_, outputs_tensor_info_); - swc->cost_list.push_back(result); - strategy_cost_.emplace_back(swc); -} - -Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed."; - return FAILED; - } - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split; - (void)input0_split.insert(input0_split.end(), inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - // strategy used only in the input node is parameter, - // in other case, use the input node's output_layout as input_layout. - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - return SUCCESS; -} - -Status ReshapeInfo::GenetateStrategyCosts(const std::vector> &pre_stra_costs, - const std::vector> &next_stra_costs, - int32_t out_index, int32_t in_index, bool is_prev_param) { - is_generating_costs_ = true; - for (auto pre_stra_cost : pre_stra_costs) { - std::vector pre_out_tensor_infos; - if (is_prev_param) { - pre_out_tensor_infos = pre_stra_cost->inputs_ptr; - } else { - pre_out_tensor_infos = pre_stra_cost->outputs_ptr; - } - if (pre_out_tensor_infos.size() <= IntToSize(out_index)) { - MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout"; - return FAILED; - } - TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; - SetInputLayout(pre_out_tensor_info.tensor_layout()); - // infer pre_node output strategy from output_layout. 
- Dimensions stra = pre_out_tensor_info.InferStrategy(); - if (stra.empty()) { - MS_LOG(ERROR) << "Infer strategy by tensor_info failed"; - return FAILED; - } - std::vector stra_inputs = {stra}; - StrategyPtr reshape_stra = std::make_shared(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs); - if (next_stra_costs.empty()) { - if (Init(nullptr) == FAILED) { - MS_LOG(ERROR) << "Failure:operator reshape init failed"; - return FAILED; - } - SetCostForReshape(reshape_stra); - continue; - } - for (auto next_stra_cost : next_stra_costs) { - std::vector next_in_tensor_infos = next_stra_cost->inputs_ptr; - if (next_in_tensor_infos.size() <= IntToSize(in_index)) { - MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout"; - return FAILED; - } - TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; - SetOutputLayout(next_in_tensor_info.tensor_layout()); - if (Init(nullptr) == FAILED) { - MS_LOG(DEBUG) << "Failure:operator reshape init failed"; - continue; - } - SetCostForReshape(reshape_stra); - } - } - is_generating_costs_ = false; - if (strategy_cost_.empty()) { - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h deleted file mode 100644 index 77a1f8e7f1..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
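The SetCostForReshape() code deleted above combines computation and communication costs, weighting the parameter-related share of communication by COST_MODEL_GAMMA. A small worked example of that composition — all numeric values, including gamma, are illustrative rather than framework defaults:

```cpp
#include <iostream>

// Sketch of the cost composition in the deleted SetCostForReshape():
// communication_with_partial_para =
//   comm_without_param + gamma * (comm_cost - comm_without_param).
int main() {
  double computation_cost = 120.0;                // GetForwardComputationCost(...)
  double communication_cost = 50.0;               // GetCommCost(...)
  double communication_without_parameter = 30.0;  // GetForwardCommCost(...)
  double gamma = 0.1;                             // stand-in for COST_MODEL_GAMMA

  double communication_with_partial_para =
      communication_without_parameter + gamma * (communication_cost - communication_without_parameter);

  // 30 + 0.1 * (50 - 30) = 32
  std::cout << computation_cost << " " << communication_with_partial_para << '\n';
  return 0;
}
```
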
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ - -#include - -#include -#include -#include -#include - -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for Reshape Primitive - */ -class ReshapeInfo : public OperatorInfo { - public: - ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), - dev_num_(0), - pre_operator_index_(0), - next_operator_index_(0), - input_layout_set_flag_(false), - output_layout_set_flag_(false) {} - ~ReshapeInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - void SetInputLayout(const TensorLayout &input_layout) { - input_layout_ = input_layout; - input_layout_set_flag_ = true; - } - void SetOutputLayout(const TensorLayout &output_layout) { - output_layout_ = output_layout; - output_layout_set_flag_ = true; - } - void SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy); - void SetCostForReshapeWithParameter(); - void set_pre_operator_name(const std::string &pre_name) { pre_operator_name_ = pre_name; } - void set_next_operator_name(const std::string &next_name) { next_operator_name_ = next_name; } - void set_pre_operator_index(int32_t pre_index) { pre_operator_index_ = pre_index; } - void set_next_operator_index(int32_t next_index) { next_operator_index_ = next_index; } - Status GenetateStrategyCosts(const std::vector> &pre_stra_costs, - const std::vector> &next_stra_costs, int32_t out_index, - int32_t in_index, bool is_prev_param); - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - std::string pre_operator_name() const { return pre_operator_name_; } - std::string next_operator_name() const { return next_operator_name_; } - int32_t pre_operator_index() const { return pre_operator_index_; } - int32_t next_operator_index() const { return next_operator_index_; } - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorMap() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - - private: - Status GetParameterInput(); - Status ComputeReplaceOp(); - void InferTensorInfoByLayout(); - void device_number(const StrategyPtr &strategy); - Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); - - int32_t dev_num_; - int32_t pre_operator_index_; - int32_t next_operator_index_; - std::vector parameter_input_v_; - std::vector sp_vector_; - Dimensions input_strategy_; - TensorLayout input_layout_; - TensorLayout output_layout_; - bool input_layout_set_flag_; - bool output_layout_set_flag_; - bool is_generating_costs_; - std::string pre_operator_name_; - std::string next_operator_name_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_RESHAPE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc deleted file mode 
100644 index 772a4f83f6..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.cc +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/tmp_identity_info.h" - -#include <memory> -#include <vector> - -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": invalid strategy."; - } - return FAILED; - } - return SUCCESS; -} - -Status TmpIdentityInfo::InferDevMatrixShape() { - std::vector<Dimensions> stra = strategy_->GetInputDim(); - Dimensions input_strategy = stra.at(0); - dev_matrix_shape_ = input_strategy; - return SUCCESS; -} - -Status TmpIdentityInfo::InferTensorMap() { - std::vector<int32_t> tensor_map_index; - size_t size = inputs_shape_[0].size(); - // such as 4: tensor_map_index [3,2,1,0] - for (size_t i = 0; i < size; ++i) { - tensor_map_index.push_back((int32_t)(size - 1 - i)); - } - - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(tensor_map_index); - return SUCCESS; -} - -Status TmpIdentityInfo::InferTensorInfo() { - // infer tensor shape - Shape input_shape = inputs_shape_.at(0); - - // infer slice shape - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = {inputs_strategy.at(0)}; - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - Shape input_slice_shape = inputs_slice_shape.at(0); - - TensorLayout input_tensor_layout; - if (input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) { - return FAILED; - } - - TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape); - - inputs_tensor_info_.push_back(input_tensor_info); - outputs_tensor_info_.push_back(input_tensor_info); // the same as input - - return SUCCESS; -} - -Status TmpIdentityInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if 
(SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": Inputs shape size or outputs shape size is wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - std::vector<StrategyPtr> sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed."; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h deleted file mode 100644 index f7895d0511..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class TmpIdentityInfo : public OperatorInfo { - // This operator is only used for the case of a parameter tensor being used by multiple operators, where we - // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator takes as input a tensor, - // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. 
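The TmpIdentity operator described in the comment above shards every dimension independently: the deleted InferTensorMap() binds tensor dimension i to device-matrix axis rank-1-i. A self-contained sketch of that mapping (IdentityTensorMap is a hypothetical helper name):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// The deleted TmpIdentityInfo::InferTensorMap() binds dimension i of the
// tensor to device-matrix axis (rank - 1 - i), i.e. a fully split layout.
std::vector<int32_t> IdentityTensorMap(size_t rank) {
  std::vector<int32_t> tensor_map;
  for (size_t i = 0; i < rank; ++i) {
    tensor_map.push_back(static_cast<int32_t>(rank - 1 - i));
  }
  return tensor_map;
}

int main() {
  // As the deleted comment notes: for a rank-4 tensor the map is [3, 2, 1, 0].
  assert((IdentityTensorMap(4) == std::vector<int32_t>{3, 2, 1, 0}));
  return 0;
}
```
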
- public: - TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, - const std::string &name = IDENTITY_INFO) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~TmpIdentityInfo() override = default; - - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status GetAttrs() override { return SUCCESS; } - Status InferMirrorOps() override { return SUCCESS; } - Status InferForwardCommunication() override { return SUCCESS; } - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TMP_IDENTITY_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc deleted file mode 100644 index 49bbae0cb4..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc +++ /dev/null @@ -1,247 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/ops_info/transpose_info.h" - -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TransposeInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - input_strategy_ = stra.at(0); - for (auto &iter : input_strategy_) { - dev_matrix_shape_.push_back(iter); - } - return SUCCESS; -} - -// there is no Parameter for Transpose Primitive, so no need to do all reduce -Status TransposeInfo::InferMirrorOps() { return SUCCESS; } - -// there is no reduction dimension for forward computation of Transpose Primitive, so no need to do all reduce -Status TransposeInfo::InferForwardCommunication() { return SUCCESS; } - -/* - * get perm input of Transpose Primitive - * perm is a permutation of the dimensions of input - * the result is saved in axis_v_ - */ -Status TransposeInfo::ComputeAxis() { - if (input_value_[1] == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr."; - return FAILED; - } - std::vector elements; - ValueTuplePtr dim_tuple = input_value_[1]->cast(); - if (dim_tuple == nullptr) { - MS_LOG(ERROR) << name_ << ": input_value_[1] must be ValueTuplePtr."; - return FAILED; - } - elements = dim_tuple->value(); - if (elements.size() != inputs_shape_[0].size()) { - MS_LOG(ERROR) << name_ << ": elements size must equal to inputs shape 0 size."; - return FAILED; - } - axis_v_.clear(); - for (auto &element : elements) { - MS_EXCEPTION_IF_NULL(element); - if (element->isa()) { - int32_t axis = element->cast()->value(); - axis_v_.push_back(axis); - } else { - MS_LOG(ERROR) << name_ << ": The value of axis must be int32."; - return FAILED; - } - } - - for (int32_t i = 0; i < SizeToInt(axis_v_.size()); i++) { - auto iter = std::find(axis_v_.begin(), axis_v_.end(), i); - if (iter == axis_v_.end()) { - MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation."; - } - } - return SUCCESS; -} - -// the output tensor map is the permutation of input tensor map, the permutation is axis_v -Status TransposeInfo::InferTensorMap() { - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs_shape_ and outputs_shape_ size must be 1, inputs shape and outputs shape is " - << inputs_shape_.size() << ", " << outputs_shape_.size(); - return FAILED; - } - - std::vector tensor_map_index_input; - for (size_t j = 0; j < inputs_shape_[0].size(); ++j) { - tensor_map_index_input.push_back(SizeToInt(inputs_shape_[0].size() - j - 1)); - } - inputs_tensor_map_.push_back(tensor_map_index_input); - - std::vector tensor_map_index_output = tensor_map_index_input; - for (uint32_t i = 0; i < tensor_map_index_output.size(); i++) { - tensor_map_index_output[i] = tensor_map_index_input[IntToUint(axis_v_[i])]; - } - outputs_tensor_map_.push_back(tensor_map_index_output); - return SUCCESS; -} - -// the output tensor strategy is the permutation of input tensor strategy, the permutation is axis_v -Strategys TransposeInfo::GetOutputsStrategy() { - Strategys outputs_strategy; - std::vector strategy = 
input_strategy_; - for (uint32_t i = 0; i < strategy.size(); i++) { - strategy[i] = input_strategy_[IntToUint(axis_v_[i])]; - } - outputs_strategy.push_back(strategy); - return outputs_strategy; -} - -Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { - if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { - MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; - return FAILED; - } - Shape shape_in = inputs_shape_.at(0); - TensorMap tensor_map_in = inputs_tensor_map_.at(0); - Shape shape_out = outputs_shape_.at(0); - TensorMap tensor_map_out = outputs_tensor_map_.at(0); - - TensorLayout tensor_layout_in, tensor_layout_out; - if ((tensor_layout_in.InitFromVector(dev_matrix_shape_, tensor_map_in, shape_in) != SUCCESS) || - (tensor_layout_out.InitFromVector(dev_matrix_shape_, tensor_map_out, shape_out) != SUCCESS)) { - return FAILED; - } - - inputs_layout->push_back(tensor_layout_in); - outputs_layout->push_back(tensor_layout_out); - return SUCCESS; -} - -Status TransposeInfo::InferTensorInfo() { - Shapes inputs_slice_shape, outputs_slice_shape; - Strategys inputs_strategy = strategy_->GetInputDim(); - Strategys outputs_strategy = GetOutputsStrategy(); - if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) { - return FAILED; - } - - TensorLayouts inputs_layout, outputs_layout; - if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) { - return FAILED; - } - TensorLayout tensor_layout_in = inputs_layout.at(0); - TensorLayout tensor_layout_out = outputs_layout.at(0); - Shape shape_array_in = inputs_shape_.at(0); - Shape slice_shape_in = inputs_slice_shape.at(0); - Shape shape_array_out = outputs_shape_.at(0); - Shape slice_shape_out = outputs_slice_shape.at(0); - TensorInfo tensor_info_in(tensor_layout_in, shape_array_in, slice_shape_in); - TensorInfo tensor_info_out(tensor_layout_out, shape_array_out, slice_shape_out); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_out); - return SUCCESS; -} - -// compute axis_v_ during this method -Status TransposeInfo::GetAttrs() { return ComputeAxis(); } - -Status TransposeInfo::Init(const StrategyPtr &strategy) { - if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - MS_LOG(INFO) << name_ << ": Init success."; - return SUCCESS; -} - -Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status TransposeInfo::GenerateStrategies(int32_t stage_id) { - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed."; - return FAILED; - } - if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { - MS_LOG(ERROR) << name_ << ": inputs shape size or outputs shape size is 
wrong, " << inputs_shape_.size() << ", " - << outputs_shape_.size(); - return FAILED; - } - is_auto_parallel_ = true; - Shape input0_split(inputs_shape_[0].size(), 1); - Shapes splittable_inputs = {input0_split}; - std::vector sp_vector; - if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed"; - return FAILED; - } - size_t success = 0; - for (auto &sp : sp_vector) { - if (SetCostUnderStrategy(sp) == SUCCESS) { - success++; - MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; - PrintStrategy(sp); - } - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h deleted file mode 100644 index 50b76bde65..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -/* - * parallel class for Transpose Primitive - */ -class TransposeInfo : public OperatorInfo { - public: - TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~TransposeInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); - Status GetAttrs() override; - Strategys GetOutputsStrategy(); - - private: - Status ComputeAxis(); - std::vector axis_v_; - Dimensions input_strategy_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_TRANSPOSE_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc deleted file mode 100644 index ce8b04d802..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ /dev/null @@ -1,229 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the 
Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ops_info/virtual_dataset_info.h" - -#include -#include -#include - -#include "parallel/device_manager.h" -#include "parallel/device_matrix.h" -#include "parallel/step_parallel.h" -#include "parallel/context.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { - if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy."; - } else { - MS_LOG(ERROR) << name_ << ": Invalid strategy."; - } - return FAILED; - } - - std::vector stra = strategy->GetInputDim(); - if (stra.size() < 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Strategy size must be larger than 1."; - } else { - MS_LOG(ERROR) << name_ << ": Strategy size must be larger than 1."; - } - return FAILED; - } - if (stra.size() == 1) { - MS_LOG(WARNING) << name_ << ": Strategy size is 1."; - return SUCCESS; - } - Dimensions strategy_first = stra.at(1); - for (auto iter_strategy = stra.begin() + 1; iter_strategy != stra.end(); ++iter_strategy) { - if (iter_strategy->empty()) { - MS_LOG(ERROR) << name_ << ": iter_strategy size is zero."; - } - if (strategy_first.at(0) != *(iter_strategy->begin())) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": The first dimension of each strategy must be the same."; - } else { - MS_LOG(ERROR) << name_ << ": The first dimension of each strategy must be the same."; - } - return FAILED; - } - - for (auto iter_element = iter_strategy->begin() + 1; iter_element != iter_strategy->end(); ++iter_element) { - if (*iter_element != 1) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": All dimension except the first dimension of each strategy must be 1."; - } else { - MS_LOG(ERROR) << name_ << ": All dimension except the first dimension of each strategy must be 1."; - } - return FAILED; - } - } - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferDevMatrixShape() { - std::vector stra = strategy_->GetInputDim(); - Dimensions strategy_first = stra.at(0); - int32_t stage = strategy_->GetInputStage(); - CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(stage).size()); - int32_t batch_split_num = strategy_first.at(0); - dev_matrix_shape_.push_back(batch_split_num); - if (dev_num > batch_split_num) { - dev_matrix_shape_.push_back(dev_num / batch_split_num); - } - - return SUCCESS; -} - -Status VirtualDatasetInfo::InferMirrorOps() { return SUCCESS; } - -Status VirtualDatasetInfo::InferForwardCommunication() { return SUCCESS; } - -Status VirtualDatasetInfo::InferTensorMap() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - std::vector tensor_map_index; - if (full_batch) { - tensor_map_index.push_back(MAP_NONE); - } else { - 
tensor_map_index.push_back((int32_t)(LAST_INDEX(SizeToUint(dev_matrix_shape_.size())))); - } - for (size_t j = 1; j < strategy_->GetInputDim()[i].size(); ++j) { - tensor_map_index.push_back(MAP_NONE); - } - inputs_tensor_map_.push_back(tensor_map_index); - outputs_tensor_map_.push_back(tensor_map_index); - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferTensorInfo() { - for (size_t i = 0; i < strategy_->GetInputNumber(); i++) { - MS_LOG(INFO) << name_ << ": InferTensorInfo " << i << ", size " << strategy_->GetInputNumber(); - TensorLayout tensor_layout_in; - if (tensor_layout_in.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(i), inputs_shape_.at(i)) != SUCCESS) { - return FAILED; - } - TensorInfo tensor_info_in(tensor_layout_in); - inputs_tensor_info_.push_back(tensor_info_in); - outputs_tensor_info_.push_back(tensor_info_in); - } - return SUCCESS; -} - -Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } - -Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { - if (InitWithManualRepeatCalc(strategy) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": Init failed."; - return FAILED; - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { - if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; - } else { - MS_LOG(ERROR) << name_ << ": Init for cost model failed."; - } - return FAILED; - } - - MS_LOG(INFO) << name_ << ": Init for cost model success."; - return SUCCESS; -} - -void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { - for (size_t i = 0; i < inputs_shape_.size(); i++) { - split_flag_list_[i] = true; - } -} - -Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { - if (SetCostUnderStrategyBase(strategy) != SUCCESS) { - if (is_auto_parallel_) { - MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; - } - return FAILED; - } - - return SUCCESS; -} - -Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - size_t total_dev_num; - - if (GetAttrs() != SUCCESS) { - MS_LOG(ERROR) << name_ << ": GetAttrs failed"; - return FAILED; - } - - CheckGlobalDeviceManager(); - is_auto_parallel_ = true; - if (full_batch) { - total_dev_num = 1; - } else { - total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); - } - StrategyPtr sp; - std::vector strategy; - for (auto &shape : inputs_shape_) { - Shape temp; - temp.emplace_back(SizeToInt(total_dev_num)); - (void)temp.insert(temp.end(), shape.size() - 1, 1); - strategy.push_back(temp); - } - sp = std::make_shared(stage_id, strategy); - - if (SetCostUnderStrategy(sp) == SUCCESS) { - if (full_batch) { - MS_LOG(INFO) << name_ << ": Successfully generated full-batch-parallel-strategy."; - } else { - MS_LOG(INFO) << name_ << ": Successfully generated batch-parallel-strategy."; - } - PrintStrategy(sp); - } else { - if (full_batch) { - MS_LOG(ERROR) << name_ << ": Generating full-batch-parallel-strategy failed."; - } else { - MS_LOG(ERROR) << name_ << ": Generating batch-parallel-strategy failed."; - } - return FAILED; - } - return SUCCESS; -} - -Status VirtualDatasetInfo::InferAsLossDivisor() { - // no need to insert div op - as_loss_divisor_ = 1; - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff 
--git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h deleted file mode 100644 index 312ac7a6a4..0000000000 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PARALLEL_OPS_INFO_DATASET_INFO_H_ -#define PARALLEL_OPS_INFO_DATASET_INFO_H_ - -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/strategy.h" - -namespace mindspore { -namespace parallel { -class VirtualDatasetInfo : public OperatorInfo { - public: - VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, - const PrimitiveAttrs &attrs) - : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} - ~VirtualDatasetInfo() override = default; - Status Init(const StrategyPtr &strategy) override; - Status InitForCostModel(const StrategyPtr &strategy) override; - - Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr &strategy) override; - void ReComputeBatchSplitFlagList() override; - - protected: - Status CheckStrategy(const StrategyPtr &strategy) override; - Status InferMirrorOps() override; - Status InferForwardCommunication() override; - Status InferTensorInfo() override; - Status InferDevMatrixShape() override; - Status InferTensorMap() override; - Status GetAttrs() override; - Status InferAsLossDivisor() override; -}; -} // namespace parallel -} // namespace mindspore - -#endif // PARALLEL_OPS_INFO_VIRTUAL_DATASET_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/parallel/ps/optimizer_info.cc deleted file mode 100644 index 98d36ad038..0000000000 --- a/mindspore/ccsrc/parallel/ps/optimizer_info.cc +++ /dev/null @@ -1,184 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
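The VirtualDatasetInfo::InferDevMatrixShape() deleted above folds leftover devices into a trailing replication axis whenever the batch split is smaller than the device count. A minimal sketch of that shape computation (the device counts in main are illustrative):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the deleted VirtualDatasetInfo::InferDevMatrixShape(): the batch
// split becomes the leading device-matrix axis, and any leftover devices form
// a repeated (replicated) axis behind it.
std::vector<int32_t> VirtualDatasetDevMatrix(int32_t dev_num, int32_t batch_split_num) {
  std::vector<int32_t> dev_matrix;
  dev_matrix.push_back(batch_split_num);
  if (dev_num > batch_split_num) {
    dev_matrix.push_back(dev_num / batch_split_num);
  }
  return dev_matrix;
}

int main() {
  // 8 devices with a batch split of 4 leaves a replication factor of 2.
  assert((VirtualDatasetDevMatrix(8, 4) == std::vector<int32_t>{4, 2}));
  return 0;
}
```
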
- */ - -#include "parallel/ps/optimizer_info.h" -#include - -namespace mindspore { -namespace parallel { -namespace ps { -void OptimizerInfo::AddWorkspace(const AddressPtr &workspace) { workspaces_.push_back(workspace); } - -const std::vector &OptimizerInfo::inputs() { return inputs_; } - -const std::vector &OptimizerInfo::workspaces() { return workspaces_; } - -const std::vector &OptimizerInfo::outputs() { return outputs_; } - -bool OptimizerInfo::IsSparse() const { return false; } - -size_t OptimizerInfo::grad_index() { return 0; } - -size_t OptimizerInfo::indices_index() { return 0; } - -void OptimizerInfo::UpdateWeight(const WeightPtr &weight) { - AddressPtr weight_addr = std::make_shared(); - weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); - inputs_[0] = weight_addr; -} - -void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { - float *accum_grad_data = reinterpret_cast(gradient()->addr); - size_t size = gradient()->size / sizeof(float); - size_t grad_index = this->grad_index(); - size_t grad_offset = 0; - for (size_t i = 0; i < grad_index; i++) { - grad_offset += lengths[i]; - } - float *grad_data = values.data() + grad_offset; - CHECK_EQ(size, static_cast(lengths[grad_index])); - - for (size_t i = 0; i < size; i++) { - accum_grad_data[i] += grad_data[i]; - } -} - -void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { - // Append grad data to the end - float *accum_grad_data = reinterpret_cast(gradient()->addr); - - size_t grad_index = this->grad_index(); - size_t grad_offset = 0; - for (size_t i = 0; i < grad_index; i++) { - grad_offset += lengths[i]; - } - float *incr_grad_data = values.data() + grad_offset; - size_t incr_grad_size = lengths[grad_index] * sizeof(float); - - auto ret = memcpy_s(accum_grad_data + grads_offset_, incr_grad_size, incr_grad_data, incr_grad_size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - } - grads_offset_ += incr_grad_size; - gradient()->size += incr_grad_size; - - // Append indice data to the end - int *accum_indices_data = reinterpret_cast(indices()->addr); - - size_t indices_index = this->indices_index(); - size_t indice_offset = 0; - for (size_t i = 0; i < indices_index; i++) { - indice_offset += lengths[i]; - } - int *incr_indice_data = reinterpret_cast(values.data() + indice_offset); - size_t incr_indice_size = lengths[indices_index] * sizeof(float); - - auto ret2 = memcpy_s(accum_indices_data + indices_offset_, incr_indice_size, incr_indice_data, incr_indice_size); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; - } - indices_offset_ += incr_indice_size; - indices()->size += incr_indice_size; -} - -void SparseOptimInfo::Reset() { - auto &gradient = this->gradient(); - gradient->size = 0; - auto &indices = this->indices(); - indices->size = 0; - grads_offset_ = 0; - indices_offset_ = 0; -} - -MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, - const AddressPtr &learning_rate, const AddressPtr &gradient, - const AddressPtr &momentum) { - inputs_.push_back(weight); - inputs_.push_back(accumulate); - inputs_.push_back(learning_rate); - inputs_.push_back(gradient); - inputs_.push_back(momentum); -} - -const AddressPtr &MomentumOptimInfo::gradient() { return inputs_[3]; } - -const AddressPtr &MomentumOptimInfo::indices() { return inputs_[3]; } - -SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, - 
const AddressPtr &beta1_power, const AddressPtr &beta2_power, - const AddressPtr &learning_rate, const AddressPtr &beta1, - const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, - const AddressPtr &indices, size_t grads_offset, size_t indices_offset) { - inputs_.push_back(weight); - inputs_.push_back(m); - inputs_.push_back(v); - inputs_.push_back(beta1_power); - inputs_.push_back(beta2_power); - inputs_.push_back(learning_rate); - inputs_.push_back(beta1); - inputs_.push_back(beta2); - inputs_.push_back(epsilon); - inputs_.push_back(grad); - inputs_.push_back(indices); - grads_offset_ = grads_offset; - indices_offset_ = indices_offset; -} - -void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { - void *data_ptr = values.data(); - AddressPtr beta1_power = inputs_[3]; - size_t size = values.size() * sizeof(float); - auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - } -} - -const AddressPtr &SparseAdamOptimInfo::gradient() { return inputs_[9]; } - -const AddressPtr &SparseAdamOptimInfo::indices() { return inputs_[10]; } - -bool SparseAdamOptimInfo::IsSparse() const { return true; } - -size_t SparseAdamOptimInfo::grad_index() { return 6; } - -size_t SparseAdamOptimInfo::indices_index() { return 7; } - -SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, - const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, - size_t indices_offset) { - inputs_.push_back(weight); - inputs_.push_back(accum); - inputs_.push_back(linear); - inputs_.push_back(grad); - inputs_.push_back(indices); - grads_offset_ = grads_offset; - indices_offset_ = indices_offset; -} - -const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; } - -const AddressPtr &SparseFtrlOptimInfo::indices() { return inputs_[4]; } - -bool SparseFtrlOptimInfo::IsSparse() const { return true; } - -size_t SparseFtrlOptimInfo::grad_index() { return 0; } - -size_t SparseFtrlOptimInfo::indices_index() { return 1; } -} // namespace ps -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ps/optimizer_info.h b/mindspore/ccsrc/parallel/ps/optimizer_info.h deleted file mode 100644 index b7c130764d..0000000000 --- a/mindspore/ccsrc/parallel/ps/optimizer_info.h +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ - -#include -#include "kernel/kernel.h" -#include "parallel/ps/common.h" - -namespace mindspore { -namespace parallel { -namespace ps { -using mindspore::kernel::AddressPtr; -class OptimizerInfo { - public: - OptimizerInfo() = default; - virtual ~OptimizerInfo() = default; - - virtual void Update(const Values &values, const Lengths &lengths) {} - virtual void UpdateWeight(const WeightPtr &weight); - virtual void Accumulate(const Values &values, const Lengths &lengths) = 0; - virtual void Reset() {} - void AddWorkspace(const AddressPtr &workspace); - - virtual const AddressPtr &gradient() = 0; - virtual const AddressPtr &indices() = 0; - const std::vector &inputs(); - const std::vector &workspaces(); - const std::vector &outputs(); - - virtual bool IsSparse() const; - virtual size_t grad_index(); - virtual size_t indices_index(); - - protected: - std::vector inputs_; - std::vector workspaces_; - std::vector outputs_; -}; - -class DenseOptimInfo : public OptimizerInfo { - public: - DenseOptimInfo() = default; - ~DenseOptimInfo() override = default; - - void Accumulate(const Values &values, const Lengths &lens) override; -}; - -class SparseOptimInfo : public OptimizerInfo { - public: - SparseOptimInfo() = default; - ~SparseOptimInfo() override = default; - - void Accumulate(const Values &values, const Lengths &lens) override; - void Reset() override; - - protected: - size_t grads_offset_{0}; - size_t indices_offset_{0}; -}; - -class MomentumOptimInfo : public DenseOptimInfo { - public: - MomentumOptimInfo(const AddressPtr &weight, const AddressPtr &accumulate, const AddressPtr &learning_rate, - const AddressPtr &gradient, const AddressPtr &momentum); - ~MomentumOptimInfo() override = default; - - const AddressPtr &gradient(); - const AddressPtr &indices(); -}; - -class SparseAdamOptimInfo : public SparseOptimInfo { - public: - SparseAdamOptimInfo(const AddressPtr &weight, const AddressPtr &m, const AddressPtr &v, const AddressPtr &beta1_power, - const AddressPtr &beta2_power, const AddressPtr &learning_rate, const AddressPtr &beta1, - const AddressPtr &beta2, const AddressPtr &epsilon, const AddressPtr &grad, - const AddressPtr &indices, size_t grads_offset, size_t indices_offset); - ~SparseAdamOptimInfo() override = default; - - void Update(const Values &values, const Lengths &lens) override; - const AddressPtr &gradient(); - const AddressPtr &indices(); - bool IsSparse() const override; - size_t grad_index() override; - size_t indices_index() override; -}; - -class SparseFtrlOptimInfo : public SparseOptimInfo { - public: - SparseFtrlOptimInfo(const AddressPtr &weight, const AddressPtr &accum, const AddressPtr &linear, - const AddressPtr &grad, const AddressPtr &indices, size_t grads_offset, size_t indices_offset); - ~SparseFtrlOptimInfo() override = default; - - const AddressPtr &gradient(); - const AddressPtr &indices(); - bool IsSparse() const override; - size_t grad_index() override; - size_t indices_index() override; -}; -} // namespace ps -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_H_ diff --git a/mindspore/ccsrc/parallel/ps/optimizer_info_builder.cc b/mindspore/ccsrc/parallel/ps/optimizer_info_builder.cc deleted file mode 100644 index 02c99c4959..0000000000 --- a/mindspore/ccsrc/parallel/ps/optimizer_info_builder.cc +++ /dev/null @@ -1,184 +0,0 @@ -/** - * Copyright 2020 Huawei 
Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/ps/optimizer_info_builder.h" -#include -#include -#include - -namespace mindspore { -namespace parallel { -namespace ps { -OptimizerInfo *OptimizerInfoBuilder::Build(const std::shared_ptr &pserver_kernel, - const WeightPtr &weight, const Keys &keys, const Values &values, - const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) { - OptimizerInfo *optim_info = BuildInputs(weight, keys, values, lens, inputs_shape, worker_num); - std::vector ws_sizes = pserver_kernel->workspace_sizes(); - BuildWorkspaces(optim_info, ws_sizes, worker_num); - BuildOutputs(optim_info, worker_num); - return optim_info; -} - -void OptimizerInfoBuilder::BuildWorkspaces(OptimizerInfo *info, const std::vector &ws_sizes, - size_t worker_num) { - for (size_t i = 0; i < ws_sizes.size(); i++) { - size_t size = ws_sizes[i]; - AddressPtr workspace = std::make_shared(); - workspace->addr = new float[size]; - workspace->size = size; - info->AddWorkspace(workspace); - } -} - -OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, - const Lengths &lens, const InputsShapePtr &inputs_shape, - size_t worker_num) { - AddressPtr weight_addr = std::make_shared(); - weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); - void *data_ptr = values.data(); - AddressPtr accumulate = std::make_shared(); - accumulate->addr = new float[weight->size()]; - accumulate->size = weight->size(); - AddressPtr learning_rate = std::make_shared(); - learning_rate->addr = data_ptr; - learning_rate->size = lens[0]; - AddressPtr gradient = std::make_shared(); - gradient->addr = reinterpret_cast(learning_rate->addr) + lens[0]; - gradient->size = lens[1]; - AddressPtr momentum = std::make_shared(); - momentum->addr = reinterpret_cast(gradient->addr) + lens[1]; - momentum->size = lens[2]; - - return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum); -} - -OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, - const Lengths &lens, const InputsShapePtr &inputs_shape, - size_t worker_num) { - AddressPtr weight_addr = std::make_shared(); - weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); - AddressPtr m = std::make_shared(); - m->addr = new float[weight->size()]; - m->size = weight->size() * sizeof(float); - AddressPtr v = std::make_shared(); - v->addr = new float[weight->size()]; - v->size = weight->size() * sizeof(float); - - void *data_ptr = values.data(); - void *copy_data_ptr = new float[values.size()]; - auto ret = memcpy_s(copy_data_ptr, values.size() * sizeof(float), data_ptr, values.size() * sizeof(float)); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - } - - AddressPtr beta1_power = std::make_shared(); - beta1_power->addr = copy_data_ptr; - beta1_power->size = lens[0] * sizeof(float); - 
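// Both BuildInputs implementations above carve the flat `values` payload into
// per-input fields using the parallel `lens` array: field i starts where
// fields 0..i-1 end. (The momentum builder records sizes as raw lens[i]
// element counts while the Adam builder records lens[i] * sizeof(float) bytes,
// an apparent inconsistency in the deleted code.) A sketch of the layout walk,
// with hypothetical names and a float payload:
#include <cstddef>
#include <utility>
#include <vector>

// Returns one (pointer, element count) view into `values` per entry of `lens`.
std::vector<std::pair<float *, int>> SliceByLens(float *values, const std::vector<int> &lens) {
  std::vector<std::pair<float *, int>> fields;
  size_t offset = 0;
  for (int len : lens) {
    fields.emplace_back(values + offset, len);
    offset += len;  // lens holds element counts, so advance in elements
  }
  return fields;
}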
AddressPtr beta2_power = std::make_shared(); - beta2_power->addr = reinterpret_cast(beta1_power->addr) + lens[0]; - beta2_power->size = lens[1] * sizeof(float); - - AddressPtr learning_rate = std::make_shared(); - learning_rate->addr = reinterpret_cast(beta2_power->addr) + lens[1]; - learning_rate->size = lens[2] * sizeof(float); - - AddressPtr beta1 = std::make_shared(); - beta1->addr = reinterpret_cast(learning_rate->addr) + lens[2]; - beta1->size = lens[3] * sizeof(float); - - AddressPtr beta2 = std::make_shared(); - beta2->addr = reinterpret_cast(beta1->addr) + lens[3]; - beta2->size = lens[4] * sizeof(float); - - AddressPtr epsilon = std::make_shared(); - epsilon->addr = reinterpret_cast(beta2->addr) + lens[4]; - epsilon->size = lens[5] * sizeof(float); - - const std::shared_ptr> &grad_shape = (*inputs_shape)[9]; - size_t total_grad_size = - std::accumulate((*grad_shape).begin(), (*grad_shape).end(), sizeof(float), std::multiplies()); - AddressPtr grad = std::make_shared(); - grad->addr = new float[total_grad_size * worker_num]; - auto ret2 = memcpy_s(grad->addr, lens[6] * sizeof(float), reinterpret_cast(epsilon->addr) + lens[5], - lens[6] * sizeof(float)); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; - } - grad->size = lens[6] * sizeof(float); - - const std::shared_ptr> &indices_shape = (*inputs_shape)[10]; - size_t total_indice_size = - std::accumulate((*indices_shape).begin(), (*indices_shape).end(), sizeof(float), std::multiplies()); - AddressPtr indices = std::make_shared(); - indices->addr = new float[total_indice_size * worker_num]; - auto ret3 = memcpy_s(indices->addr, lens[7] * sizeof(float), - reinterpret_cast(epsilon->addr) + lens[5] + lens[6], lens[7] * sizeof(float)); - if (ret3 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret3 << ")"; - } - indices->size = lens[7] * sizeof(float); - - return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, - grad, indices, total_grad_size, total_indice_size); -} - -OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, - const Lengths &lens, const InputsShapePtr &inputs_shape, - size_t worker_num) { - AddressPtr weight_addr = std::make_shared(); - weight_addr->addr = weight->data(); - weight_addr->size = weight->size(); - AddressPtr accum = std::make_shared(); - accum->addr = new float[weight->size()]; - accum->size = weight->size() * sizeof(float); - for (size_t i = 0; i < weight->size(); i++) { - float *tmp = reinterpret_cast(accum->addr); - tmp[i] = 1.0; - } - AddressPtr linear = std::make_shared(); - linear->addr = new float[weight->size()]; - memcpy_s(linear->addr, weight->size() * sizeof(float), 0x00, weight->size() * sizeof(float)); - linear->size = weight->size() * sizeof(float); - - const std::shared_ptr> &grad_shape = (*inputs_shape)[3]; - size_t total_grad_size = std::accumulate((*grad_shape).begin(), (*grad_shape).end(), 1, std::multiplies()); - AddressPtr grad = std::make_shared(); - grad->addr = new float[total_grad_size * worker_num]; - auto ret = memcpy_s(grad->addr, lens[0] * sizeof(float), values.data(), lens[0] * sizeof(float)); - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - } - grad->size = lens[0] * sizeof(float); - - const std::shared_ptr> &indices_shape = (*inputs_shape)[4]; - size_t total_indice_size = - std::accumulate((*indices_shape).begin(), (*indices_shape).end(), 1, std::multiplies()); - 
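// The sparse builders size the server-side accumulation buffers as the product
// of the gradient/indices shape times worker_num, so every worker can append
// its rows before the reduce. Note the deleted code seeds std::accumulate with
// sizeof(float) in the Adam path but with 1 in the FTRL path, i.e. bytes in
// one case and elements in the other. A sketch of the capacity computation
// with an explicit element count (hypothetical helper):
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

size_t AccumCapacity(const std::vector<size_t> &shape, size_t worker_num) {
  size_t per_worker =
      std::accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
  return per_worker * worker_num;  // elements; multiply by sizeof(T) for bytes
}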
AddressPtr indices = std::make_shared(); - indices->addr = new float[total_indice_size * worker_num]; - auto ret2 = memcpy_s(indices->addr, lens[1] * sizeof(float), reinterpret_cast(values.data()) + lens[0], - lens[1] * sizeof(float)); - if (ret2 != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; - } - indices->size = lens[1] * sizeof(float); - - return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, total_grad_size, total_indice_size); -} -} // namespace ps -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ps/optimizer_info_builder.h b/mindspore/ccsrc/parallel/ps/optimizer_info_builder.h deleted file mode 100644 index 0703f5e755..0000000000 --- a/mindspore/ccsrc/parallel/ps/optimizer_info_builder.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ - -#include -#include -#include "kernel/kernel.h" -#include "kernel/ps/pserver_kernel.h" -#include "parallel/ps/optimizer_info.h" - -namespace mindspore { -namespace parallel { -namespace ps { -using mindspore::kernel::KernelMod; -using mindspore::kernel::ps::PServerKernel; -class OptimizerInfoBuilder { - public: - OptimizerInfoBuilder() = default; - virtual ~OptimizerInfoBuilder() = default; - - OptimizerInfo *Build(const std::shared_ptr &pserver_kernel, const WeightPtr &weight, const Keys &keys, - const Values &values, const Lengths &lens, const InputsShapePtr &inputs_shape, - size_t worker_num); - - virtual OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, - const Lengths &lens, const InputsShapePtr &inputs_shape, size_t worker_num) = 0; - - virtual void BuildWorkspaces(OptimizerInfo *info, const std::vector &ws_sizes, size_t worker_num); - virtual void BuildOutputs(OptimizerInfo *info, size_t worker_num) {} -}; - -class MomentumOptimInfoBuilder : public OptimizerInfoBuilder { - public: - OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, - const InputsShapePtr &inputs_shape, size_t worker_num) override; -}; - -class SparseAdamOptimInfoBuilder : public OptimizerInfoBuilder { - public: - OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, - const InputsShapePtr &inputs_shpae, size_t worker_num) override; -}; - -class SparseFtrlOptimInfoBuilder : public OptimizerInfoBuilder { - public: - OptimizerInfo *BuildInputs(const WeightPtr &weight, const Keys &keys, const Values &values, const Lengths &lens, - const InputsShapePtr &inputs_shpae, size_t worker_num) override; -}; -} // namespace ps -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_OPTIMIZER_INFO_BUILDER_H_ diff --git a/mindspore/ccsrc/parallel/ps/parameter_server.h b/mindspore/ccsrc/parallel/ps/parameter_server.h deleted file mode 
100755 index 4d3aa41306..0000000000 --- a/mindspore/ccsrc/parallel/ps/parameter_server.h +++ /dev/null @@ -1,559 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "session/session_factory.h" -#include "parallel/ps/common.h" -#include "parallel/ps/optimizer_info.h" -#include "parallel/ps/optimizer_info_builder.h" -#include "parallel/ps/util.h" -#include "device/cpu/kernel_select_cpu.h" -#include "utils/context/ms_context.h" -#include "kernel/kernel.h" -#include "kernel/ps/pserver_kernel.h" -#include "kernel/cpu/cpu_kernel_factory.h" -#include "kernel/ps/sparse_apply_adam_ps_kernel.h" -#include "kernel/ps/sparse_apply_ftrl_ps_kernel.h" -#include "kernel/ps/apply_momentum_ps_kernel.h" -#include "kernel/ps/embedding_look_up_ps_kernel.h" - -namespace mindspore { -namespace parallel { -namespace ps { -using mindspore::kernel::ps::PServerKernel; -template -class ParameterServer { - public: - static ParameterServer &GetInstance() { - static ParameterServer instance; - return instance; - } - - void Run(const FuncGraphPtr &func_graph); - - private: - ParameterServer() - : pserver_num_(0), - worker_num_(0), - rank_id_(0), - grad_accum_count_(0), - ps_(new ::ps::KVServer(0)), - handler_(nullptr), - func_graph_(nullptr), - kernel_graph_(nullptr), - sess_(nullptr), - thread_(nullptr) {} - ~ParameterServer() = default; - ParameterServer(const ParameterServer &) = delete; - ParameterServer &operator=(const ParameterServer &) = delete; - - struct ServerHandler { - explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} - void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); - void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); - void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - void HandleInitWeights(const ::ps::KVPairs &req_data); - void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); - void HandleInitInputsShape(const ::ps::KVPairs &req_data); - void HandleInitEmbeddings(const ::ps::KVPairs &req_data); - void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); - ParameterServer *ps_; - }; - - bool Init(const FuncGraphPtr &func_graph); - void InitOptimInfoBuilders(); - void InitWeightKeyToOptims(const Key &key, const int &optim_id); - void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); - void InitWeight(const Key &key, const WeightPtr &weight); - void InitGrad(const Key &key, const GradPtr &grad); - void 
InitEmbeddingTable(const Key &key, - const std::shared_ptr>>> &shapes); - void UpdateWeights(); - void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); - WeightPtr weight(const Key &key); - void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); - int SumOfShapes(const std::vector &shapes) const; - size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); - bool ReadyForUpdateWeights(); - bool ReadyForAccumGrads(); - void ResetGradAccumCount(); - - size_t pserver_num_; - size_t worker_num_; - size_t rank_id_; - size_t grad_accum_count_; - std::unique_ptr<::ps::KVServer> ps_; - std::unique_ptr handler_; - FuncGraphPtr func_graph_; - std::shared_ptr kernel_graph_; - std::shared_ptr sess_; - - std::unordered_map> optimizers_; - std::unordered_map optim_inputs_shape_; - std::unordered_map> optim_infos_; - std::unordered_map> optim_info_builders_; - std::unordered_map weight_key_to_optims_; - std::unordered_map weights_; - std::unordered_map grads_; - std::unordered_map grads_accum_counter_; - // std::unordered_map embeddings_; - std::unordered_map> embedding_lookup_ops_; - std::unordered_map embedding_row_lens_; - - T learning_rate_; - T momentum_; - - std::mutex mutex_; - std::condition_variable apply_grads_cv_; - std::condition_variable accum_grads_cv_; - - std::unique_ptr thread_; - - friend struct ServerHandler; -}; - -class FuncGraph; -template -void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVServer *server) { - ::ps::KVPairs res; - if (req_meta.cmd == kInitWeightsCmd) { - MS_LOG(ERROR) << "handle init weights cmd" << std::endl; - HandleInitWeights(req_data); - } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { - MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; - HandleInitWeightToOptimId(req_data); - } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { - MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; - HandleInitInputsShape(req_data); - } else if (req_meta.cmd == kInitEmbeddingsCmd) { - MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; - HandleInitEmbeddings(req_data); - } else if (req_meta.cmd == kEmbeddingLookupCmd) { - MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; - HandleEmbeddingLookup(req_meta, req_data, &res); - } else if (req_meta.push) { - MS_LOG(ERROR) << "handle push req cmd" << std::endl; - HandlePushReq(req_meta, req_data); - } else { - MS_LOG(ERROR) << "handle pull req cmd" << std::endl; - HandlePullReq(req_meta, req_data, &res); - } - server->Response(req_meta, res); -} - -template -void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { - ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); -} - -template -void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, - ::ps::KVPairs *res) { - res->keys = req_data.keys; - ::ps::Key key = req_data.keys[0]; - res->vals = *(ps_->weight(key)); -} - -template -void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { - size_t key_num = req_data.keys.size(); - T *data_ptr = req_data.vals.data(); - size_t pos = 0; - for (size_t i = 0; i < key_num; i++) { - Key key = req_data.keys[i]; - size_t data_len = req_data.lens.size() != key_num ? 
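// ServerHandler::operator() above is the server's single entry point: explicit
// commands cover the control plane (weight/shape/embedding init, embedding
// lookup) and the push/pull flag covers the data plane. The MS_LOG(ERROR)
// calls on every branch log normal traffic at error severity, which reads like
// leftover debugging. A stripped-down sketch of the dispatch shape, with
// hypothetical command values (the real constants live in parallel/ps/common.h):
#include <iostream>

struct Meta {
  int cmd;
  bool push;
};

void Dispatch(const Meta &meta) {
  constexpr int kInitWeightsCmd = 1;      // assumed value
  constexpr int kEmbeddingLookupCmd = 2;  // assumed value
  if (meta.cmd == kInitWeightsCmd) {
    std::cout << "init weights\n";
  } else if (meta.cmd == kEmbeddingLookupCmd) {
    std::cout << "embedding lookup\n";
  } else if (meta.push) {
    std::cout << "push: accumulate gradients\n";
  } else {
    std::cout << "pull: return current weights\n";
  }
}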
req_data.vals.size() / key_num : req_data.lens[i]; - - WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); - weight_ptr->CopyFrom(data_ptr + pos, data_len); - ps_->InitWeight(key, weight_ptr); - - GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); - ps_->InitGrad(key, grad_ptr); - pos += data_len; - } -} - -template -void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { - size_t key_num = req_data.keys.size(); - for (size_t i = 0; i < key_num; i++) { - Key key = req_data.keys[i]; - T val = req_data.vals[i]; - ps_->InitWeightKeyToOptims(key, val); - } -} - -template -void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { - ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); -} - -template -void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { - std::shared_ptr>>> shapes = - std::make_shared>>>(); - std::shared_ptr> input_shape = std::make_shared>(); - std::shared_ptr> indices_shape = std::make_shared>(); - std::shared_ptr> output_shape = std::make_shared>(); - shapes->push_back(input_shape); - shapes->push_back(indices_shape); - shapes->push_back(output_shape); - - const Key &key = req_data.keys[0]; - const Lengths &lens = req_data.lens; - size_t index = 0; - for (int i = 0; i < lens[0]; i++) { - input_shape->push_back(static_cast(req_data.vals[index++])); - } - for (int j = 0; j < lens[1]; j++) { - indices_shape->push_back(static_cast(req_data.vals[index++])); - } - for (int k = 0; k < lens[2]; k++) { - output_shape->push_back(static_cast(req_data.vals[index++])); - } - ps_->InitEmbeddingTable(key, shapes); -} - -template -void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, - const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { - const Key &key = req_data.keys[0]; - ps_->DoEmbeddingLookup(key, req_data.vals, res); - for (size_t i = 0; i < req_data.vals.size(); i++) { - res->keys->push_back(req_data.vals[i]); - } -} - -template -bool ParameterServer::Init(const FuncGraphPtr &func_graph) { - const char *server_num = getenv(kEnvPServerNum); - const char *worker_num = getenv(kEnvWorkerNum); - if (server_num != nullptr) { - pserver_num_ = *server_num - '0'; - } - if (worker_num != nullptr) { - worker_num_ = *worker_num - '0'; - } - func_graph_ = func_graph; - rank_id_ = ::ps::MyRank(); - handler_.reset(new ServerHandler(this)); - - InitOptimInfoBuilders(); - - ps_->set_request_handle(*handler_); - thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); - return true; -} - -template -void ParameterServer::InitOptimInfoBuilders() { - std::shared_ptr momentum_info_builder = std::make_shared(); - std::shared_ptr sparse_adam_info_builder = std::make_shared(); - std::shared_ptr sparse_ftrl_info_builder = std::make_shared(); - optim_info_builders_[kApplyMomentum] = momentum_info_builder; - optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; - optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; -} - -template -void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_id) { - if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { - return; - } - weight_key_to_optims_[key] = Util::optimizer_name(optim_id); -} - -template -void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { - InputsShapePtr inputs_shape = std::make_shared(); - int val_idx = 0; - const Key &key = keys[0]; - - if 
(optim_inputs_shape_.count(key) == 0) { - optim_inputs_shape_[key] = inputs_shape; - } - for (size_t i = 0; i < keys.size(); i++) { - auto shape = std::make_shared>(); - inputs_shape->push_back(shape); - - int len = lengths[i]; - for (int j = 0; j < len; j++) { - shape->push_back(values[val_idx++]); - } - } - if (weight_key_to_optims_.count(key) > 0) { - const std::string &optim_name = weight_key_to_optims_[key]; - if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { - if (optim_name == kSparseAdam) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; - } else if (optim_name == kApplyMomentum) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; - } else if (optim_name == kSparseFtrl) { - std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); - optimizer->InitKernel(optim_inputs_shape_[key]); - optimizers_[optim_name] = optimizer; - } - } - } -} - -template -void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { - if (weights_.count(key) == 0) { - weights_[key] = weight; - } -} - -template -void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { - if (grads_.count(key) == 0) { - grads_[key] = grad; - grads_accum_counter_[key] = 0; - } -} - -template -void ParameterServer::InitEmbeddingTable( - const Key &key, const std::shared_ptr>>> &shapes) { - // Init embedding lookup kernel - std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); - lookup->InitKernel(shapes); - embedding_lookup_ops_[key] = lookup; - - // Init embedding weight - const std::vector &input_shapes = lookup->input_sizes(); - size_t total_dims = 1; - for (auto shape : input_shapes) { - total_dims *= shape; - } - WeightPtr embedding = std::make_shared(total_dims, 0.01); - weights_[key] = embedding; - - grads_accum_counter_[key] = 0; -} - -template -void ParameterServer::UpdateWeights() { - while (true) { - std::unique_lock lock(mutex_); - apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); - - for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { - Key key = iter->first; - WeightPtr weight_ptr = iter->second; - - std::shared_ptr optimizer = nullptr; - if (weight_key_to_optims_.count(key) > 0) { - const std::string &optim_name = weight_key_to_optims_[key]; - optimizer = optimizers_[optim_name]; - } - MS_EXCEPTION_IF_NULL(optimizer); - - std::shared_ptr optim_info = optim_infos_[key]; - if (optim_info == nullptr) { - continue; - } - const WeightPtr &weight = weights_[key]; - optim_info->UpdateWeight(weight); - const std::vector &inputs = optim_info->inputs(); - const std::vector &workspaces = optim_info->workspaces(); - const std::vector &outputs = optim_info->outputs(); - - optimizer->Execute(inputs, workspaces, outputs); - optim_info->Reset(); - } - ResetGradAccumCount(); - accum_grads_cv_.notify_all(); - } -} - -template -void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { - std::unique_lock lock(mutex_); - accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); }); - - const Key &key = keys[0]; - std::shared_ptr optim_info = optim_infos_[key]; - - // Create or update the optimizer info - if (optim_info == nullptr) { - const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; - 
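// UpdateWeights and AccumGrad above form a barrier-style hand-off: pushes
// accumulate until every weight has received worker_num gradients, then
// apply_grads_cv_ wakes the single update thread, which runs the optimizer
// kernels, clears the counters, and unblocks further pushes through
// accum_grads_cv_. A minimal sketch of that hand-off with one shared counter
// (hypothetical class; the real code tracks a counter per weight key):
#include <condition_variable>
#include <mutex>

class GradBarrier {
 public:
  explicit GradBarrier(int worker_num) : worker_num_(worker_num) {}

  // Called once per worker gradient, as in AccumGrad.
  void Push() {
    std::unique_lock<std::mutex> lock(mu_);
    accum_cv_.wait(lock, [this] { return count_ < worker_num_; });
    if (++count_ == worker_num_) {
      apply_cv_.notify_one();
    }
  }

  // The optimizer thread's loop body, as in UpdateWeights.
  void Apply() {
    std::unique_lock<std::mutex> lock(mu_);
    apply_cv_.wait(lock, [this] { return count_ == worker_num_; });
    // ... execute optimizer kernels here ...
    count_ = 0;
    accum_cv_.notify_all();
  }

 private:
  const int worker_num_;
  int count_ = 0;
  std::mutex mu_;
  std::condition_variable apply_cv_;
  std::condition_variable accum_cv_;
};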
std::shared_ptr pserver_kernel = optimizers_[weight_key_to_optims_[key]]; - if (pserver_kernel == nullptr) { - MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; - } - MS_EXCEPTION_IF_NULL(pserver_kernel); - OptimizerInfo *optim = - builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_); - optim_info.reset(optim); - optim_infos_[key] = optim_info; - } else { - optim_info->Update(values, lengths); - } - MS_EXCEPTION_IF_NULL(optim_info); - - optim_info->Accumulate(values, lengths); - - grads_accum_counter_[key] += 1; - if (grads_accum_counter_[key] == worker_num_) { - grad_accum_count_++; - } - if (ReadyForUpdateWeights()) { - apply_grads_cv_.notify_one(); - } -} - -template -WeightPtr ParameterServer::weight(const Key &key) { - std::unique_lock lock(mutex_); - - if (weights_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid weight key " << key; - return nullptr; - } - WeightPtr weight_ptr = weights_[key]; - WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray>(weight_ptr->size(), 0); - copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); - return copy_weight_ptr; -} - -template -void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res) { - std::unique_lock lock(mutex_); - if (weights_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding table key " << key; - return; - } - if (embedding_lookup_ops_.count(key) == 0) { - MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; - return; - } - WeightPtr table_ptr = weights_[key]; - std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; - - // Update shapes of lookup operator - std::shared_ptr>>> shapes = - std::make_shared>>>(); - std::shared_ptr> indices_shape = std::make_shared>(); - indices_shape->emplace_back(lookup_ids.size()); - shapes->push_back(indices_shape); - table_lookup_op->ReInit(shapes); - - const std::vector output_shapes = table_lookup_op->output_sizes(); - std::vector inputs; - AddressPtr embedding_table = std::make_shared(); - AddressPtr indices = std::make_shared(); - inputs.push_back(embedding_table); - inputs.push_back(indices); - embedding_table->addr = table_ptr->data(); - embedding_table->size = table_ptr->size() * sizeof(T); - indices->addr = lookup_ids.data(); - indices->size = lookup_ids.size() * sizeof(T); - - std::vector workspaces; - std::vector outputs; - AddressPtr output = std::make_shared(); - std::shared_ptr addr = std::make_shared(output_shapes[0] / sizeof(T), 0); - - output->addr = addr->data(); - output->size = output_shapes[0]; - outputs.push_back(output); - - table_lookup_op->Execute(inputs, workspaces, outputs); - res->vals = *addr; - res->lens.push_back(res.vals.size()); -} - -template -int ParameterServer::SumOfShapes(const std::vector &shapes) const { - int sum = 1; - for (auto shape : shapes) { - sum *= shape; - } - return sum; -} - -template -size_t ParameterServer::PreComputeCapacity(const Keys &keys, const Lengths &lens) { - size_t capacity = 0; - for (size_t i = 0; i < keys.size(); i++) { - Key key = keys[i]; - if (embedding_row_lens_.count(key) > 0) { - capacity += embedding_row_lens_[key] * lens[i]; - } else { - MS_LOG(ERROR) << "Invalid embedding lookup id " << key; - } - } - return capacity; -} - -template -inline bool ParameterServer::ReadyForUpdateWeights() { - return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); -} - -template -inline bool 
ParameterServer<T>::ReadyForAccumGrads() {
-  return grad_accum_count_ < weights_.size();
-}
-
-template <typename T>
-inline void ParameterServer<T>::ResetGradAccumCount() {
-  grad_accum_count_ = 0;
-  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
-    grads_accum_counter_[iter->first] = 0;
-  }
-}
-
-template <typename T>
-void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
-  ::ps::Start(0);
-  if (!::ps::IsServer()) {
-    std::cout << "This is not the Server" << std::endl;
-    return;
-  }
-  Init(func_graph);
-  thread_->join();
-}
-}  // namespace ps
-}  // namespace parallel
-}  // namespace mindspore
-#endif  // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_
diff --git a/mindspore/ccsrc/parallel/ps/scheduler.cc b/mindspore/ccsrc/parallel/ps/scheduler.cc
deleted file mode 100755
index 81cd5f9358..0000000000
--- a/mindspore/ccsrc/parallel/ps/scheduler.cc
+++ /dev/null
@@ -1,32 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/ps/scheduler.h"
-#include <unistd.h>
-#include "ps/ps.h"
-
-namespace mindspore {
-namespace parallel {
-namespace ps {
-void Scheduler::Run() {
-  ::ps::Start(0);
-  while (true) {
-    sleep(1);
-  }
-}
-}  // namespace ps
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/ps/util.cc b/mindspore/ccsrc/parallel/ps/util.cc
deleted file mode 100644
index dbc258284e..0000000000
--- a/mindspore/ccsrc/parallel/ps/util.cc
+++ /dev/null
@@ -1,128 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
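// ParameterServer::Run and Scheduler::Run above (and Worker::Run further down)
// share one bootstrap: start the ps-lite runtime, then branch on the process
// role. The scheduler merely keeps the process alive; the server installs its
// request handler and spawns the UpdateWeights thread; the worker builds its
// KV proxy. A sketch of that shared shape, assuming the ps-lite entry points
// (::ps::Start, ::ps::IsScheduler, ::ps::IsServer, ::ps::IsWorker) used here:
#include "ps/ps.h"

void Bootstrap() {
  ::ps::Start(0);  // rendezvous with scheduler, servers, and workers
  if (::ps::IsScheduler()) {
    // nothing to do but stay alive for the lifetime of the job
  } else if (::ps::IsServer()) {
    // register the request handler and start the weight-update thread
  } else if (::ps::IsWorker()) {
    // construct the WorkerProxy and begin pushing/pulling
  }
}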
- */ - -#include "parallel/ps/util.h" -#include -#include "parallel/ps/common.h" -#include "common/utils.h" - -namespace mindspore { -namespace parallel { -namespace ps { -std::unordered_map Util::optimizer_to_ids{ - {kApplyMomentum, 0}, - {kSparseAdam, 1}, - {kSparseFtrl, 2}, -}; - -std::unordered_map Util::id_to_optimizers{ - {0, kApplyMomentum}, - {1, kSparseAdam}, - {2, kSparseFtrl}, -}; -bool Util::IsParamServerMode() { return IsRoleOfWorker() || IsRoleOfPServer() || IsRoleOfScheduler(); } - -bool Util::IsRoleOfWorker() { - auto role = common::GetEnv(kEnvRole); - if (strcmp(role.c_str(), kEnvRoleOfWorker) == 0) { - return true; - } else { - return false; - } -} - -bool Util::IsRoleOfPServer() { - auto role = common::GetEnv(kEnvRole); - if (strcmp(role.c_str(), kEnvRoleOfPServer) == 0) { - return true; - } else { - return false; - } -} - -bool Util::IsRoleOfScheduler() { - auto role = common::GetEnv(kEnvRole); - if (strcmp(role.c_str(), kEnvRoleOfScheduler) == 0) { - return true; - } else { - return false; - } -} - -void Util::SetInternalEnvVar() { - if (IsParamServerMode()) { - auto comm_type = common::GetEnv(kEnvCommType); - if (comm_type.size() > 0) { - (void)common::SetEnv(kDmlcCommType, comm_type.c_str()); - } - auto interface = common::GetEnv(kEnvInterface); - if (interface.size() > 0) { - (void)common::SetEnv(kDmlcInterface, interface.c_str()); - } - auto server_num = common::GetEnv(kEnvPServerNum); - if (server_num.size() > 0) { - (void)common::SetEnv(kDmlcPServerNum, server_num.c_str()); - } - auto worker_num = common::GetEnv(kEnvWorkerNum); - if (worker_num.size() > 0) { - (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str()); - } - if (IsRoleOfScheduler()) { - (void)common::SetEnv(kDmlcRole, kRoleOfScheduler); - } else if (IsRoleOfPServer()) { - (void)common::SetEnv(kDmlcRole, kRoleOfPServer); - } else if (IsRoleOfWorker()) { - (void)common::SetEnv(kDmlcRole, kRoleOfWorker); - } - auto scheduler_host = common::GetEnv(kEnvSchedulerHost); - if (scheduler_host.size() > 0) { - (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str()); - } - auto scheduler_port = common::GetEnv(kEnvSchedulerPort); - if (scheduler_port.size() > 0) { - (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str()); - } - } -} - -int Util::optimizer_id(std::string name) { - if (optimizer_to_ids.count(name) > 0) { - return optimizer_to_ids[name]; - } - return -1; -} - -std::string Util::optimizer_name(int id) { - if (id_to_optimizers.count(id) > 0) { - return id_to_optimizers[id]; - } - return ""; -} - -bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } - -int Util::LocalShard(int first_dim, int rank_id, int server_num) { - int shard_size = std::round((static_cast(first_dim)) / server_num); - int remain_size = first_dim % server_num; - if (remain_size == 0 || rank_id < server_num - 1) { - return shard_size; - } else { - return first_dim - (shard_size * (server_num - 1)); - } -} -} // namespace ps -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ps/util.h b/mindspore/ccsrc/parallel/ps/util.h deleted file mode 100644 index b55ced0c97..0000000000 --- a/mindspore/ccsrc/parallel/ps/util.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ - -#include -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace parallel { -namespace ps { -class Util { - public: - static bool IsParamServerMode(); - static bool IsRoleOfWorker(); - static bool IsRoleOfPServer(); - static bool IsRoleOfScheduler(); - static void SetInternalEnvVar(); - static int optimizer_id(std::string name); - static std::string optimizer_name(int id); - static bool is_optimizer(std::string name); - static int LocalShard(int first_dim, int rank_id, int server_num); - - private: - static std::unordered_map optimizer_to_ids; - static std::unordered_map id_to_optimizers; -}; -} // namespace ps -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_UTIL_H_ diff --git a/mindspore/ccsrc/parallel/ps/worker.h b/mindspore/ccsrc/parallel/ps/worker.h deleted file mode 100644 index b9d0cdcc85..0000000000 --- a/mindspore/ccsrc/parallel/ps/worker.h +++ /dev/null @@ -1,259 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ - -#include -#include -#include -#include -#include -#include "ps/ps.h" -#include "utils/log_adapter.h" -#include "parallel/ps/util.h" -#include "parallel/ps/common.h" -#include "parallel/ps/worker_proxy.h" - -namespace mindspore { -namespace parallel { -namespace ps { -template -class Worker { - public: - static Worker &GetInstance() { - static Worker instance; - return instance; - } - - void Run(); - void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes); - void Pull(const size_t key, void *dev_addr, const size_t size); - size_t SetParamKey(const std::string ¶m_name); - void SetKeyOptimId(size_t key, const std::string &optimizer_name); - void SetOptimInputShapes(size_t key, const std::vector &shape); - void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); - void InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size); - void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); - - private: - Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} - ~Worker() { ::ps::Finalize(0, true); } - Worker(const Worker &) = delete; - Worker &operator=(const Worker &) = delete; - - bool IsKeyInit(const size_t key); - size_t GetParamKey(const std::string ¶m_name); - void InitPSOptimId(const size_t param_key); - void InitPSOptimInputShapes(const size_t key); - void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); - static void EmbeddingLookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &ranges, - std::vector>> *sliced) {} - - std::shared_ptr> kv_worker_; - bool running_; - size_t key_cnt_; - std::map param_to_key_; - std::map init_keys_; - std::map key_to_optimId_; - std::map>> key_to_optim_shapes_; -}; - -template -void Worker::Run() { - if (running_) { - MS_LOG(INFO) << "'Worker is already running."; - return; - } - - ::ps::Start(0); - if (!::ps::IsWorker()) { - MS_LOG(EXCEPTION) << "The role is not worker."; - } - kv_worker_ = std::make_shared>(0, 0, 1); - running_ = true; -} - -template -void Worker::Push(const std::vector &keys, std::vector addrs, const std::vector &sizes) { - size_t total_size = 0; - for (auto size : sizes) { - total_size += size; - } - ::ps::SArray total_buffer(total_size, 0); - size_t offset = 0; - for (size_t i = 0; i < sizes.size(); i++) { - memcpy(total_buffer.data() + offset / sizeof(T), addrs[i], sizes[i] * sizeof(T)); - offset += sizes[i] * sizeof(T); - } - kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray(sizes)); -} - -template -void Worker::Pull(const size_t key, void *dev_addr, const size_t size) { - ::ps::SArray variables(size / sizeof(T), 0); - kv_worker_->Wait(kv_worker_->ZPull({key}, &variables)); - memcpy(dev_addr, variables.data(), size); -} - -template -void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd) { - kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, &lookup_result, cmd); -} - -template -void Worker::InitPSParamData(const std::vector &keys, void *origin_addr, size_t size) { - ::ps::SArray addr(reinterpret_cast(origin_addr), size / sizeof(T)); - ::ps::SArray<::ps::Key> 
key(keys); - ::ps::SArray lens; - lens.push_back(addr.size()); - kv_worker_->Wait(kv_worker_->ZPush(key, addr, lens, kInitWeightsCmd)); - init_keys_[key[0]] = true; -} - -template -void Worker::SetOptimInputShapes(size_t key, const std::vector &shape) { - if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { - key_to_optim_shapes_[key] = {shape}; - } else { - key_to_optim_shapes_[key].push_back(shape); - } -} - -template -void Worker::InitPSOptimInputShapes(const size_t key) { - ::ps::SArray<::ps::Key> keys; - ::ps::SArray shape_len; - ::ps::SArray all_shape; - std::vector> shapes = key_to_optim_shapes_[key]; - for (auto shape : shapes) { - keys.push_back(key); - if (shape.size() == 0) { - shape_len.push_back(1); - all_shape.push_back(1); - } else { - shape_len.push_back(SizeToInt(shape.size())); - for (auto dim : shape) { - all_shape.push_back(static_cast(dim)); - } - } - } - MS_LOG(ERROR) << "keys:" << keys; - MS_LOG(ERROR) << "shape_len:" << shape_len; - MS_LOG(ERROR) << "all_shape:" << all_shape; - if (!init_keys_[key]) { - init_keys_[key] = true; - } - kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); -} - -template -bool Worker::IsKeyInit(const size_t key) { - if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) { - return false; - } - return true; -} - -template -size_t Worker::SetParamKey(const std::string ¶m_name) { - size_t key = UINT64_MAX; - if (param_to_key_.count(param_name)) { - key = param_to_key_[param_name]; - MS_LOG(INFO) << param_name << " key is already set: key value is " << key; - } else { - key = key_cnt_++; - param_to_key_[param_name] = key; - MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name; - } - return key; -} - -template -size_t Worker::GetParamKey(const std::string ¶m_name) { - size_t key = kInvalidKey; - if (param_to_key_.find(param_name) != param_to_key_.end()) { - key = param_to_key_[param_name]; - MS_LOG(ERROR) << "Get key of parameter " << param_name << " key is " << key; - } - return key; -} - -template -void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) { - key_to_optimId_[key] = Util::optimizer_id(optimizer_name); -} - -template -void Worker::InitPSOptimId(const size_t param_key) { - if (key_to_optimId_.count(param_key) == 0) { - MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key; - } - int optim_id = key_to_optimId_[param_key]; - - ::ps::SArray<::ps::Key> keys = {param_key}; - ::ps::SArray optim_id_vals = {static_cast(optim_id)}; - ::ps::SArray optim_id_lens = {optim_id_vals.size()}; - kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd); -} - -template -void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, - const std::vector &sizes) { - bool has_init = IsKeyInit(keys[0]); - if (has_init) { - MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; - return; - } - ::ps::SArray shapes_val; - for (auto dim : shapes) { - shapes_val.push_back(static_cast(dim)); - } - kv_worker_->Wait(kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray(sizes))); -} - -template -// Initialize parameters and optimizer kernels of Parameter Server. 
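// InitPSOptimInputShapes above serializes a list of shapes into two parallel
// arrays: shape_len[i] carries the rank of shape i (a scalar is encoded as a
// single dimension of 1) and all_shape carries every dimension back to back,
// so one KV push is enough for the server to rebuild them. A sketch of the
// encode step, with plain vectors standing in for ::ps::SArray:
#include <vector>

void EncodeShapes(const std::vector<std::vector<int>> &shapes,
                  std::vector<int> *shape_len, std::vector<int> *all_shape) {
  for (const auto &shape : shapes) {
    if (shape.empty()) {  // scalar: rank 1, single dimension of 1
      shape_len->push_back(1);
      all_shape->push_back(1);
    } else {
      shape_len->push_back(static_cast<int>(shape.size()));
      all_shape->insert(all_shape->end(), shape.begin(), shape.end());
    }
  }
}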
-void Worker::InitPSParamAndOptim(const std::string ¶m_name, void *param_data, size_t param_size) { - size_t param_key = GetParamKey(param_name); - if (param_key == kInvalidKey) { - MS_LOG(INFO) << "Parameter " << param_name << " has no key assigned."; - return; - } - bool init = IsKeyInit(param_key); - if (!init) { - MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name; - // No need to push embedding table data to Parameter Server. - if (param_name.find("embedding_table") == std::string::npos && param_name.find("wide_w") == std::string::npos) { - InitPSParamData({param_key}, param_data, param_size); - } - InitPSOptimId(param_key); - InitPSOptimInputShapes(param_key); - } -} - -template -void Worker::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { - kv_worker_->AddEmbeddingTable(key, row_count); -} - -} // namespace ps -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_H_ diff --git a/mindspore/ccsrc/parallel/ps/worker_proxy.h b/mindspore/ccsrc/parallel/ps/worker_proxy.h deleted file mode 100644 index 8ffdde84ea..0000000000 --- a/mindspore/ccsrc/parallel/ps/worker_proxy.h +++ /dev/null @@ -1,311 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ -#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ - -#include -#include -#include -#include -#include -#include "ps/ps.h" -#include "parallel/ps/util.h" - -namespace mindspore { -namespace parallel { -namespace ps { -template -class WorkerProxy : public ::ps::KVWorker { - public: - using Worker = ::ps::KVWorker; - using Callback = std::function; - using SlicedKVs = std::vector>>; - using Slicer = - std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; - using ::ps::SimpleApp::obj_; - explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { - using _1 = std::placeholders::_1; - using _2 = std::placeholders::_2; - using _3 = std::placeholders::_3; - lookup_customer_ = std::unique_ptr<::ps::Customer>( - new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); - lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); - init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); - push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); - broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); - } - ~WorkerProxy() override = default; - - void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, - int priority = 0); - int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, - const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); - void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, - int cmd = 0, int priority = 0); - - private: - template - int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, - const Callback &cb); - void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced); - void ProcessLookupResult(const ::ps::Message &msg); - void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, - const Slicer &slicer); - - std::unique_ptr<::ps::Customer> lookup_customer_; - std::unordered_map<::ps::Key, std::shared_ptr>> embedding_table_ranges_; - std::unordered_map>> lookup_results_; - std::mutex mutex_; - Slicer lookup_slicer_; - Slicer init_embedding_slicer_; - Slicer push_slicer_; - Slicer broadcast_slicer_; - std::unordered_map lookup_callbacks_; -}; - -template -void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { - uint64_t begin = 0; - uint64_t end = 0; - int server_num = ::ps::NumServers(); - for (int i = 0; i < server_num; i++) { - int local_row_cnt = Util::LocalShard(row_count, i, server_num); - if (i == 0) { - end = local_row_cnt - 1; - } else { - begin = end + 1; - end += local_row_cnt; - } - ::ps::Range range(begin, end); - if (embedding_table_ranges_.count(key) == 0) { - 
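// AddEmbeddingTable above turns the per-server shard sizes produced by
// Util::LocalShard into contiguous, inclusive [begin, end] row ranges; the
// lookup slicer later uses these ranges to route ids to the owning server.
// A sketch of the range construction (plain struct standing in for
// ::ps::Range, assuming every shard is non-empty):
#include <cstdint>
#include <vector>

struct RowRange {
  uint64_t begin;
  uint64_t end;  // inclusive, matching the deleted code
};

std::vector<RowRange> MakeRanges(const std::vector<int> &shard_sizes) {
  std::vector<RowRange> ranges;
  uint64_t begin = 0;
  for (int size : shard_sizes) {
    ranges.push_back({begin, begin + static_cast<uint64_t>(size) - 1});
    begin += static_cast<uint64_t>(size);
  }
  return ranges;
}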
embedding_table_ranges_[key] = std::make_shared>(); - } - embedding_table_ranges_[key]->push_back(range); - } -} - -template -void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, - int priority) { - int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); - ::ps::KVPairs kvs; - kvs.keys = keys; - kvs.vals = lookup_ids; - kvs.lens = lens; - kvs.priority = priority; - Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); - lookup_customer_->WaitRequest(ts); -} - -template -int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, - const ::ps::SArray &lens, const Callback &cb, int priority) { - int ts = obj_->NewRequest(::ps::kServerGroup); - ::ps::KVPairs kvs; - kvs.keys = keys; - kvs.vals = vals; - kvs.lens = lens; - kvs.priority = priority; - Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); - return ts; -} - -template -void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, - const ::ps::SArray &lens, int cmd, int priority) { - int ts = obj_->NewRequest(::ps::kServerGroup); - ::ps::KVPairs kvs; - kvs.keys = keys; - kvs.vals = vals; - kvs.lens = lens; - kvs.priority = priority; - Send(obj_, ts, true, false, cmd, kvs, push_slicer_); - obj_->WaitRequest(ts); -} - -template -template -int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, - C *lookup_result, int cmd, const Callback &cb) { - int ts = lookup_customer_->NewRequest(::ps::kServerGroup); - const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { - mutex_.lock(); - auto &kvs = lookup_results_[ts]; - mutex_.unlock(); - - size_t total_len = 0; - const auto &s = kvs[0]; - for (size_t i = 0; i < s.lens.size(); i++) { - total_len += s.lens[i]; - } - lookup_result->resize(total_len, 0); - T *result_addr = lookup_result->data(); - - for (const auto &s : kvs) { - size_t offset = 0; - for (size_t i = 0; i < s.vals.size(); i++) { - result_addr[offset++] += s.vals[i]; - } - } - - mutex_.lock(); - lookup_results_.erase(ts); - mutex_.unlock(); - if (cb) cb(); - }; - lookup_callbacks_[ts] = callback; - return ts; -} - -template -void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - int *data = send.lens.data(); - size_t size = send.lens.size(); - std::vector lookup_ids(data, data + size); - std::sort(lookup_ids.begin(), lookup_ids.end()); - - const Key &key = send.keys[0]; - const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); - sliced->resize(ranges.size()); - - size_t index = 0; - for (size_t i = 0; i < ranges.size(); i++) { - const ::ps::Range &range = ranges[i]; - const auto &begin = range.begin(); - const auto &end = range.end(); - auto &kvs = sliced->at(i).second; - - auto lookup_id = static_cast(lookup_ids[index]); - while (lookup_id >= begin && lookup_id <= end) { - kvs.vals.push_back(lookup_id); - if (++index >= lookup_ids.size()) { - break; - } - lookup_id = static_cast(lookup_ids[index]); - } - kvs.keys.push_back(key); - kvs.lens.push_back(kvs.vals.size()); - - if (kvs.vals.size() == 0) { - sliced->at(i).first = false; - } else { - sliced->at(i).first = true; - } - } -} - -template -void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - const Key &key = 
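// LookupIdSlicer above sorts the lookup ids, then walks them once against each
// server's row range so every server receives only the ids it owns; slices
// that end up empty are flagged false and skipped by Send. A sketch of that
// single-pass partition (hypothetical plain types):
#include <algorithm>
#include <utility>
#include <vector>

std::vector<std::vector<int>> SliceIds(std::vector<int> ids,
                                       const std::vector<std::pair<int, int>> &ranges) {
  std::sort(ids.begin(), ids.end());
  std::vector<std::vector<int>> sliced(ranges.size());
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size() && idx < ids.size(); ++i) {
    // Consume every sorted id that falls inside the inclusive range i.
    while (idx < ids.size() && ids[idx] >= ranges[i].first && ids[idx] <= ranges[i].second) {
      sliced[i].push_back(ids[idx++]);
    }
  }
  return sliced;
}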
send.keys[0]; - const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); - sliced->resize(ranges.size()); - for (size_t i = 0; i < ranges.size(); i++) { - sliced->at(i).first = true; - sliced->at(i).second = send; - } -} - -template -void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - auto server_num = ::ps::Postoffice::Get()->num_servers(); - sliced->resize(server_num); - for (int i = 0; i < server_num; i++) { - sliced->at(i).first = true; - sliced->at(i).second = send; - } -} - -template -void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, - std::vector>> *sliced) { - auto server_num = ::ps::Postoffice::Get()->num_servers(); - sliced->resize(server_num); - for (int i = 0; i < server_num; i++) { - sliced->at(i).first = true; - sliced->at(i).second = send; - } -} - -template -void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { - int ts = msg.meta.timestamp; - if (msg.meta.pull) { - CHECK_GE(msg.data.size(), (size_t)2); - ::ps::KVPairs kvs; - kvs.keys = msg.data[0]; - kvs.vals = msg.data[1]; - if (msg.data.size() > (size_t)2) { - kvs.lens = msg.data[2]; - } - mutex_.lock(); - lookup_results_[ts].push_back(kvs); - mutex_.unlock(); - } - if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { - const auto &cb = lookup_callbacks_[ts]; - cb(); - lookup_callbacks_.erase(ts); - } -} - -template -void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, - const ::ps::KVPairs &kvs, const Slicer &slicer) { - SlicedKVs sliced; - slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); - - for (size_t i = 0; i < sliced.size(); i++) { - const auto &s = sliced[i]; - if (!s.first) continue; - ::ps::Message msg; - msg.meta.app_id = customer->app_id(); - msg.meta.customer_id = customer->customer_id(); - msg.meta.request = true; - msg.meta.push = push; - msg.meta.pull = pull; - msg.meta.head = cmd; - msg.meta.timestamp = timestamp; - msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); - msg.meta.priority = kvs.priority; - const auto &kvs = s.second; - if (kvs.keys.size()) { - msg.AddData(kvs.keys); - msg.AddData(kvs.vals); - if (kvs.lens.size()) { - msg.AddData(kvs.lens); - } - } - ::ps::Postoffice::Get()->van()->Send(msg); - } -} -} // namespace ps -} // namespace parallel -} // namespace mindspore -#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc deleted file mode 100644 index cda2407cd1..0000000000 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ /dev/null @@ -1,1187 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "parallel/step_auto_parallel.h"
-
-#include <inttypes.h>
-#include <sys/time.h>
-#include <algorithm>
-#include <map>
-#include <memory>
-#include <set>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "ir/anf.h"
-#include "ir/param_value.h"
-#include "ir/tensor.h"
-#include "optimizer/opt.h"
-#include "optimizer/optimizer.h"
-#include "parallel/auto_parallel/dp_algo_costmodel.h"
-#include "parallel/auto_parallel/edge_costmodel.h"
-#include "parallel/auto_parallel/graph_costmodel.h"
-#include "parallel/auto_parallel/rec_core/rec_generate_strategy.h"
-#include "parallel/auto_parallel/rec_core/rec_parse_graph.h"
-#include "parallel/auto_parallel/rec_core/rec_partition.h"
-#include "parallel/context.h"
-#include "parallel/ops_info/tmp_identity_info.h"
-#include "parallel/ops_info/reshape_info.h"
-#include "parallel/step_parallel.h"
-#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
-#include "pipeline/parse/python_adapter.h"
-#include "pipeline/pipeline.h"
-
-namespace mindspore {
-namespace parallel {
-bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
-  MS_EXCEPTION_IF_NULL(root);
-  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
-  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
-  // assume no change to graph
-  bool changes = false;
-  // control whether to use model_parallel mode
-  if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
-      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
-    return changes;
-  }
-  // check whether strategy_search_mode is valid
-  std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
-  if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
-    // Setting searching mode: dynamic programming as default.
-    strategy_search_mode = DYNAMIC_PROGRAMMING;
-    MS_LOG(INFO) << "No strategy searching mode indicated, using DP searching mode as default";
-  }
-
-  struct timeval start_time, end_time;
-  (void)gettimeofday(&start_time, nullptr);
-
-  if (MsContext::GetInstance()->save_graphs_flag()) {
-    draw::Draw(STEP_AUTO_PARALLEL_BEGIN, root);
-  }
-  MS_LOG(INFO) << "Now entering step auto parallel";
-  TOTAL_OPS = 0;
-  AnfNodePtr ret = root->get_return();
-  std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
-
-  if (ParallelInit() != SUCCESS) {
-    MS_LOG(EXCEPTION) << "Parallel init failed";
-  }
-
-  // mark the forward cnodes; parallel only cares about these nodes
-  MarkForwardCNode(root);
-
-  if (FindCommunicationOp(all_nodes)) {
-    MS_LOG(EXCEPTION) << "The graph contains a communication op";
-  }
-
-  // search parallelization strategy
-  if (strategy_search_mode == DYNAMIC_PROGRAMMING) {
-    if (ParallelStrategySearch(all_nodes, root) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using DP searching mode";
-    }
-  } else if (strategy_search_mode == RECURSIVE_PROGRAMMING) {
-    if (ParallelStrategyRecSearch(all_nodes, root) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Auto-parallel strategy search failed when using RP searching mode";
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "Auto-parallel strategy searching mode unexpected";
-  }
-
-  (void)gettimeofday(&end_time, nullptr);
-  uint64_t time = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
-  time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
-  MS_LOG(INFO) << "Now leaving step auto parallel, used time: " << time << " us";
-
-  root->set_flag(AUTO_PARALLEL_RUN_ONCE_ONLY, true);
-  return changes;
-}
-
-// Given the node, return whether each input is a parameter or an output of an operator.
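[Editor's aside] The gettimeofday() bookkeeping above reduces to a single microsecond count. A minimal, self-contained sketch of that arithmetic (kUSecondInSecond is assumed to be 1000000, which is what the "us" log unit implies; the helper name is ours, not MindSpore's):

  #include <sys/time.h>
  #include <cstdint>

  // Microseconds elapsed between two gettimeofday() samples; the signed
  // intermediate keeps the result correct when tv_usec wraps across a second.
  uint64_t ElapsedMicroseconds(const struct timeval &start, const struct timeval &end) {
    const int64_t kUSecondInSecond = 1000000;
    int64_t us = kUSecondInSecond * static_cast<int64_t>(end.tv_sec - start.tv_sec);
    us += static_cast<int64_t>(end.tv_usec - start.tv_usec);
    return static_cast<uint64_t>(us);
  }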
-// The returned boolean vector should be the same order of the inputs, thus its implementation -// is closely consistent with ExtractShape() in step_parallel.cc -std::vector ExtractInputParameterByNode(const CNodePtr &node) { - std::vector is_parameter; - std::vector node_inputs{node->inputs()}; - for (size_t i = 1; i < node_inputs.size(); ++i) { - auto input = node_inputs[i]; - - if (input->isa()) { - auto input_parameter = input->cast(); - if (input_parameter->has_default()) { - bool requires_grad = input_parameter->default_param()->requires_grad(); - is_parameter.push_back(requires_grad); - } else { - is_parameter.push_back(false); - } - } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { - is_parameter.push_back(false); - } - } - return is_parameter; -} - -// Given the type, return the number of bytes to represent this type -size_t GetLengthOfDataType(const TypePtr &type) { - switch (type->type_id()) { - case kNumberTypeBool: - return sizeof(bool); - case kNumberTypeInt8: - return sizeof(int8_t); - case kNumberTypeInt16: - return sizeof(int16_t); - case kNumberTypeInt32: - return sizeof(int32_t); - case kNumberTypeInt64: - return sizeof(int64_t); - case kNumberTypeUInt8: - return sizeof(uint8_t); - case kNumberTypeUInt16: - return sizeof(uint16_t); - case kNumberTypeUInt32: - return sizeof(uint32_t); - case kNumberTypeUInt64: - return sizeof(uint64_t); - case kNumberTypeFloat16: - return sizeof(float) / 2; - case kNumberTypeFloat32: - return sizeof(float); - case kNumberTypeFloat64: - return sizeof(double); - case kNumberTypeInt: - return sizeof(int); - case kNumberTypeUInt: - return sizeof(unsigned int); - case kNumberTypeFloat: - return sizeof(float); - default: - MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); - } -} - -size_t GetInputsTypeLen(const AnfNodePtr &input) { - MS_EXCEPTION_IF_NULL(input); - if (!input->isa() && !input->isa() && !IsValueNode(input)) { - MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; - } - - size_t input_type_len = 0; - auto type = input->Type(); - MS_EXCEPTION_IF_NULL(type); - if (type->isa()) { - auto input_element_type = type->cast()->element(); - input_type_len = GetLengthOfDataType(input_element_type); - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); - } - return input_type_len; -} - -std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - std::vector inputs_type_len; - std::vector node_inputs{node->inputs()}; - - // extract input element length - for (auto &input : node_inputs) { - if (IsValueNode(input)) { - auto func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters = FindParameterByRefKeyNode(input, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); - } else if (input->isa() || input->isa() || IsValueNode(input)) { - // extract input shape from parameter and apply node - inputs_type_len.push_back(GetInputsTypeLen(input)); - } - } - return inputs_type_len; -} - -std::vector ExtractOutputTypeByNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - std::vector outputs_type; - // extract output element type - auto primary_output_type = node->Type(); - MS_EXCEPTION_IF_NULL(primary_output_type); - if (primary_output_type->isa()) { - // in this case, the output is a tuple - auto tuple_output_type = primary_output_type->cast(); - auto elements = 
tuple_output_type->elements();
-    for (auto &ele : elements) {
-      if (ele->isa<TensorType>()) {
-        auto ele_element_type = ele->cast<TensorTypePtr>()->element();
-        outputs_type.push_back(ele_element_type);
-      } else {
-        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
-      }
-    }
-  } else {
-    // in this case, the output is a single tensor
-    if (primary_output_type->isa<TensorType>()) {
-      auto element_type = primary_output_type->cast<TensorTypePtr>()->element();
-      outputs_type.push_back(element_type);
-    } else {
-      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
-    }
-  }
-  return outputs_type;
-}
-
-bool IsElementWiseOperator(const std::string &op_name) {
-  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU,
-                                                       SQRT,       CAST, POW,  EXP,     LOG,         COS,
-                                                       ACOS,       LOGICALNOT, NEG, SQUARE, SIGMOID};
-  auto iter = elementwise_op.find(op_name);
-  return (iter != elementwise_op.end());
-}
-
-bool IsSplittableOperator(const std::string &op_name) {
-  // clang-format off
-  static const std::set<std::string> splittable_op =
-    {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
-     FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
-     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
-     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
-     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
-     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2,
-     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
-  // clang-format on
-
-  auto iter = splittable_op.find(op_name);
-  return (iter != splittable_op.end());
-}
-
-bool IsAutoParallelCareNode(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  ValueNodePtr prim_node = cnode->input(0)->cast<ValueNodePtr>();
-  if (prim_node == nullptr) {
-    return false;
-  }
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_node);
-  if (prim == nullptr) {
-    return false;
-  }
-  bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
-  if (bool_result) {
-    MS_LOG(EXCEPTION) << "Should implement OperatorInfo for: " << prim->name();
-  } else if (prim->name() == CAST) {
-    if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
-      // Do not care about CASTs from the optimizer
-      return false;
-    }
-    return true;
-  }
-  return IsParallelCareNode(cnode) && IsSplittableOperator(prim->name());
-}
-
-OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
-  MS_EXCEPTION_IF_NULL(prim);
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto attrs = prim->attrs();
-  std::vector<Shapes> shape_list = ExtractShape(cnode);
-  if (shape_list.empty()) {
-    MS_LOG(EXCEPTION) << "Failure: node " << cnode->UniqueId() << " failed to extract shape";
-  }
-  // Create an OperatorInfo instance
-  OperatorInfoPtr operator_info = NewOperatorInstance(prim, attrs, shape_list);
-  MS_EXCEPTION_IF_NULL(operator_info);
-  // Set the parameter information for this OperatorInfo (whether the inputs are parameters or not)
-  std::vector<bool> parameter_info = ExtractInputParameterByNode(cnode);
-  if (operator_info->set_is_parameter(parameter_info) != SUCCESS) {
-    MS_LOG(ERROR) << "Initializing parameter information failed for operator: " << operator_info->name();
-    return nullptr;
-  }
-  // Set the data
type for inputs and outputs of this OperatorInfo
-  auto inputs_type_length = ExtractInputTypeLengthByNode(cnode);
-  auto outputs_type = ExtractOutputTypeByNode(cnode);
-  std::vector<size_t> outputs_type_length;
-  outputs_type_length.reserve(outputs_type.size());
-  std::transform(outputs_type.begin(), outputs_type.end(), std::back_inserter(outputs_type_length),
-                 GetLengthOfDataType);
-  if (operator_info->SetInputAndOutputTypeLength(inputs_type_length, outputs_type_length) != SUCCESS) {
-    MS_LOG(ERROR) << "Setting the lengths of inputs and outputs failed for operator: " << operator_info->name();
-    return nullptr;
-  }
-  if (operator_info->set_outputs_type(outputs_type) != SUCCESS) {
-    MS_LOG(ERROR) << "Setting the types of outputs failed for operator: " << operator_info->name();
-    return nullptr;
-  }
-  // When the 'inputs' contains numerical values for some operators, these values should be extracted from
-  // the ANF graph
-  auto &inputs = cnode->inputs();
-  std::vector<ValuePtr> input_value;
-  for (size_t index = 1; index < inputs.size(); ++index) {
-    if (inputs[index]->isa<ValueNode>()) {
-      input_value.push_back(GetValueNode(inputs[index]));
-    } else {
-      input_value.emplace_back(nullptr);
-    }
-  }
-  operator_info->set_input_value(input_value);
-  operator_info->set_outputs_dtype(cnode->Type());
-  operator_info->set_cnode(cnode);
-  // key of strategy map
-  std::string strategy_key_name = NodeParameterName(cnode);
-  bool load_strategy_from_ckpt =
-    StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
-  // If no strategy has been configured for this operator, then candidate strategies are generated for
-  // auto-strategy searching; if this primitive is CAST, we ignore the user-specified strategy.
-  // If the strategy is set to load from checkpoint, loading the strategy from checkpoint is preferred.
-  if ((!StrategyFound(attrs) || prim->name() == CAST) && !load_strategy_from_ckpt) {
-    // Compute split_flag_list_, indicating which input has batch dimension.
This is ONLY used for preparation for - // BatchParallelInfo operator - operator_info->ComputeBatchSplitFlagList(); - if (operator_info->GenerateStrategies(0) != SUCCESS) { - MS_LOG(ERROR) << "Strategy search for Operator " << operator_info->name() << " failed."; - return nullptr; - } - } else { - // In this case, the configured strategy should be extracted to help setting cost - StrategyPtr strategyPtr; - if (load_strategy_from_ckpt) { - strategyPtr = (*stra_map)[strategy_key_name]; - } else { - strategyPtr = parallel::ExtractStrategy(attrs); - } - if (strategyPtr != nullptr) { - if (prim->name() == RESHAPE) { - MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; - } - // Set cost for this configured strategy - if (operator_info->SetCostUnderStrategy(strategyPtr) != SUCCESS) { - MS_LOG(EXCEPTION) << "Failure: operator " << prim->name() << " SetCostUnderStrategy failed"; - } else if (FULLY_USE_DEVICES) { - // If configured to fully use devices, then checking for the user-specified strategy - int32_t used_devices = operator_info->used_devices(); - MS_EXCEPTION_IF_NULL(g_device_manager); - auto total_device_num = g_device_manager->GetDeviceListByStageId(0).size(); - // 'used_devices == 1' means that ALL-1 strategy, which is valid in auto-parallel - if (used_devices == 1) { - return operator_info; - } - // 'used_devices == -1' means that 'used_devices_' is not set - if ((used_devices == -1) || IntToSize(used_devices) != total_device_num) { - MS_LOG(EXCEPTION) << "In configuration 'FULLY_USE_DEVICES' = True, " - << "but the specified strategy uses device: " << used_devices - << ", total devices: " << total_device_num; - } - } - } - } - return operator_info; -} - -// Using CNode's UniqueIds to construct nodes -Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &) { - MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); - // The map from CNode's UniqueId to its operatorInfo - std::map from_cnode_to_info; - // extract strategy from checkpoint for multi-train - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - // Step 1 - for (auto &node : all_nodes) { - // NOTE: we only care about splittable Primitive operators - auto cnode = node->cast(); - bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); - if (bool_result) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsAutoParallelCareNode(cnode)) { - // Needed by rec_parser - if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { - auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); - if (prev_cnode != nullptr) { - entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); - } - } - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - - auto search_cnode = from_cnode_to_info.find(cnode->UniqueId()); - if (search_cnode == from_cnode_to_info.end()) { - auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); - if (operator_info == nullptr) { - return FAILED; - } - // Needed by rec_parser - operator_info->set_type(prim->name()); - std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - - 
entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); - (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); - // Needed by rec_parser - entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); - } else { - // Two CNODEs' UniqueIds should not be equal - MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name(); - } - } - - MS_LOG(INFO) << "Constructing nodes for cost graph ends."; - return SUCCESS; -} - -// Using CNode's UniqueIdThroughCopys to construct nodes -Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &) { - MS_LOG(INFO) << "Constructing nodes for cost graph begins."; - entire_costgraph = std::make_shared(); - entire_costgraph->SetDeviceMemoryAndCostParameter(); - // The map from CNode's UniqueIdThroughCopy to its operatorInfo - std::map from_cnode_to_info; - // extract strategy from checkpoint for multi-train - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - for (auto &node : all_nodes) { - // NOTE: we only care about splittable Primitive operators - auto cnode = node->cast(); - bool bool_result = (cnode == nullptr) || (!IsValueNode(cnode->input(0))); - if (bool_result) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsAutoParallelCareNode(cnode)) { - // Needed by rec_parser - if (ParallelContext::GetInstance()->strategy_search_mode() == RECURSIVE_PROGRAMMING) { - auto prev_cnode = GetInternalOperatorInfo(cnode, prim_anf_node); - if (prev_cnode != nullptr) { - entire_costgraph->add_tuple_getitem(std::make_pair(cnode->UniqueId(), prev_cnode->UniqueId())); - } - } - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - - // Find the operatorInfo if it exists - auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy()); - if (search_cnode == from_cnode_to_info.end()) { - // In this case, the corresponding OperatorInfo is not created, create the new one. 
- auto operator_info = CreateTheOperatorInfo(prim, cnode, &stra_map); - if (operator_info == nullptr) { - return FAILED; - } - // Needed by rec_parser - operator_info->set_type(prim->name()); - std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - - entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); - (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info)); - // Needed by rec_parser - entire_costgraph->add_inputs_tensor_name(inputs_tensor_name); - } else { - auto current_op_ptr = search_cnode->second; - if (current_op_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed."; - } else { - bool is_find_wrong = (current_op_ptr->name().find(VIRTUAL_DATA_SET_INFO) == std::string::npos) && - (current_op_ptr->name().find(BATCH_PARALLEL) == std::string::npos) && - (current_op_ptr->name().find(prim->name()) == std::string::npos); - if (is_find_wrong) { - MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() - << " does not match the Prim: " << prim->name(); - } - (void)cnode->set_operator_info(current_op_ptr); - MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() - << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() - << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); - } - } - } - - MS_LOG(INFO) << "Constructing nodes for cost graph ends."; - return SUCCESS; -} - -void ConstructCostGraphEdges(const std::vector &all_nodes) { - // Step 2 - MS_LOG(INFO) << "Constructing edges for cost graph begins."; - for (auto &node : all_nodes) { - auto cnode = node->cast(); - bool bool_result_cnode = (cnode == nullptr) || !IsValueNode(cnode->input(0)); - if (bool_result_cnode) { - continue; - } - auto &inputs = cnode->inputs(); - ValueNodePtr prim_anf_node = inputs[0]->cast(); - if (!IsAutoParallelCareNode(cnode)) { - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - size_t edge_count = 0; - - for (size_t i = 1; i < inputs.size(); ++i) { - auto prev_cnode = inputs[i]->cast(); - bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_prev_cnode) { - continue; - } - ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast(); - PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast(); - size_t output_index = 0; - - bool bool_result = - (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); - while (bool_result) { - if (IsAutoParallelCareNode(prev_cnode)) { - std::string edge_name = - prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); - // If the edge between these two operators already has been added, then the edge will not be added again. - if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { - break; - } - EdgePtr edge_ptr; - MS_LOG(INFO) << "Creating edge: " << edge_name; - - bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) || - (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name())); - if (follow_strategy) { - // Redistribution in not allowed on the edge. 
- // Elementwise operators have the same strategy as their previous operators. - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false, true); - } else { - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false); - } - - // Init costs for this edge - if (edge_ptr->InitEdgeCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Edge cost initialization failed"; - } - cnode->operator_info()->AddPrevEdge(edge_ptr); - prev_cnode->operator_info()->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " - << cnode->operator_info()->name(); - edge_count++; - - break; - } else if (prev_prim->name() == TUPLE_GETITEM) { - // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before - // this 'tuple_getitem' - MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator."; - output_index = IntToSize(GetValue(GetValueNode(prev_cnode->input(2)))); - prev_cnode = prev_cnode->input(1)->cast(); - bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_tuple) { - break; - } - prev_prim_anf_node = prev_cnode->input(0)->cast(); - prev_prim = prev_prim_anf_node->value()->cast(); - if (!IsAutoParallelCareNode(prev_cnode)) { - MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name(); - } - MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, " - << "and creating an edge between the Operator before " - << "'tuple_getitem' and the Operator after 'tuple_getitem'."; - } else if (prev_prim->name() == DEPEND) { - // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before - // this 'depend' - MS_LOG(INFO) << "Jumping the 'depend' operator."; - prev_cnode = prev_cnode->input(1)->cast(); - bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); - if (bool_result_depend) { - break; - } - prev_prim_anf_node = prev_cnode->input(0)->cast(); - prev_prim = prev_prim_anf_node->value()->cast(); - MS_LOG(INFO) << "Jumped the 'depend' operator, " - << "and creating an edge between the Operator before " - << "'depend' and the Operator after 'depend'."; - } - bool_result = - (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); - } - } - MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); - } - - MS_LOG(INFO) << "Constructing edges for cost graph ends."; -} - -std::pair> CNodeWithRefKeys(const AnfNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector refkeys; - if (cnode->isa()) { - auto cnode_ptr = cnode->cast(); - auto inputs = cnode_ptr->inputs(); - for (auto &one_input : inputs) { - if (IsValueNode(one_input)) { - refkeys.push_back(one_input); - } - } - if (refkeys.size() >= 1) { - return std::make_pair(cnode, refkeys); - } - } - return {nullptr, refkeys}; -} - -void AugmentCostGraph(const std::vector &all_nodes) { - // Step 3 - for (auto &node : all_nodes) { - auto cnode_with_refkeys = CNodeWithRefKeys(node); - if ((!node->isa()) && (cnode_with_refkeys.first == nullptr)) { - continue; - } - std::string parameter_name; - AnfNodePtr target_parameter = nullptr; - AnfNodeIndexSet target_set; - - if (cnode_with_refkeys.first != nullptr) { - // Dealing with the RefKey 
case - auto refkeys = cnode_with_refkeys.second; - auto cnode = cnode_with_refkeys.first; - - auto cnode_ptr = cnode->cast(); - if (cnode_ptr == nullptr || !IsValueNode(cnode_ptr->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(cnode_ptr)) { - continue; - } - - if (refkeys.size() > 1) { - MS_LOG(EXCEPTION) << "CNode: " << cnode->fullname_with_scope() << " 's inputs have more than 1 RefKeys."; - } - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - auto cnode_func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); - - // Find the RefKey being used - auto candidate_set_by_refkey = cnode_func_graph->manager()->node_users()[refkeys[0]]; - for (auto &candidate : candidate_set_by_refkey) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - target_set.add(candidate); - } - - // Find the corresponding Parameter being used - std::vector parameters = FindParameterByRefKeyNode(refkeys[0], cnode_func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - parameter_name = parameters[0]->cast()->name(); - target_parameter = parameters[0]; - auto candidate_set_by_para = cnode_func_graph->manager()->node_users()[parameters[0]]; - for (auto &candidate : candidate_set_by_para) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - (void)target_set.insert(candidate); - } - } else if (node->isa()) { - // Dealing with the Parameter case - MS_EXCEPTION_IF_NULL(node->func_graph()); - MS_EXCEPTION_IF_NULL(node->func_graph()->manager()); - auto candidate_set = node->func_graph()->manager()->node_users()[node]; - for (auto &candidate : candidate_set) { - auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !IsValueNode(c->input(0))) { - continue; - } - if (!IsAutoParallelCareNode(c)) { - continue; - } - (void)target_set.insert(candidate); - } - // In this case, node is a Parameter - parameter_name = node->cast()->name(); - target_parameter = node; - } - if (target_set.size() <= 1) { - continue; - } - - // Rule out the case when a Parameter being used by a Operator, but the Operator appears in multiple CNODEs - std::set target_without_duplicate; - for (auto &target : target_set) { - auto target_cnode = target.first->cast(); - auto input_index = target.second; - (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); - } - if (target_without_duplicate.size() <= 1) { - continue; - } - - // Here, it is sure that this Parameter (RefKey) is being used by multiple Operators. - OperatorInfoPtr tmp_identity_ptr; - bool new_identity = false; - std::string tmp_identity_name; - auto returned_identity = entire_costgraph->FindTmpIdentityByParameterName(parameter_name); - if (returned_identity != nullptr) { - // In this case, the TmpIdentityInfo instance has already been created - new_identity = false; - tmp_identity_ptr = returned_identity; - tmp_identity_name = tmp_identity_ptr->name(); - } else { - // In the case, the TmpIdentityInfo instance has NOT been created. Thus, a new one is created. 
- new_identity = true; - // 1) extract input shape from this Parameter - MS_EXCEPTION_IF_NULL(target_parameter); - AbstractBasePtr abstract = target_parameter->abstract(); - if (abstract == nullptr) { - MS_LOG(EXCEPTION) << "Failure: abstract is nullptr"; - } - auto input_shape = dyn_cast(abstract->GetShapeTrack()); - if (input_shape == nullptr) { - MS_LOG(EXCEPTION) << "Failure: input_shape is nullptr"; - } - std::vector shape_int = input_shape->shape(); - Shape shape; - (void)std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(shape), - [](int sub_shape) { return static_cast(sub_shape); }); - Shapes inputs_shape = {shape}; - Shapes outputs_shape = {shape}; - // 2) init the attr - std::unordered_map attr = {}; - - // Create the TmpIdentity instance - tmp_identity_ptr = std::make_shared(inputs_shape, outputs_shape, attr); - tmp_identity_ptr->set_name(tmp_identity_ptr->name() + std::to_string(TOTAL_OPS)); - TOTAL_OPS++; - tmp_identity_ptr->set_refkey_parameter_name(parameter_name); - // Set the parameter and type lengths for inputs and outputs - std::vector is_parameter; - auto casted_target_parameter = target_parameter->cast(); - MS_EXCEPTION_IF_NULL(casted_target_parameter); - if (casted_target_parameter->has_default()) { - bool requires_grad = casted_target_parameter->default_param()->requires_grad(); - is_parameter.push_back(requires_grad); - } else { - is_parameter.push_back(false); - } - if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) { - MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed"; - } - auto node_type = target_parameter->Type(); - if (node_type->isa()) { - auto input_element_type = node_type->cast()->element(); - std::vector type_length = {GetLengthOfDataType(input_element_type)}; - if (tmp_identity_ptr->SetInputAndOutputTypeLength(type_length, type_length) != SUCCESS) { - MS_LOG(EXCEPTION) << "Setting input and output type length for TmpIdentityInfo failed"; - } - } else { - MS_LOG(EXCEPTION) << "Unknown type: " << node_type->type_name(); - } - - // Generate strategies for this TmpIdentityInfo instance; - if (tmp_identity_ptr->GenerateStrategies(0) != SUCCESS) { - MS_LOG(EXCEPTION) << "Strategy search for Operator failed : " << tmp_identity_ptr->name(); - } - } - // A flag recording whether new edges have been created or not - bool add_identity_edge = false; - - // Create edges between this TmpIdentityInfo instance and subsequent Operator instances - for (auto &target : target_set) { - auto target_cnode = target.first->cast(); - auto prim = GetValueNode(target_cnode->input(0)); - auto input_index = target.second; - - std::string edge_name = - std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); - // If the edge between these two operators already has been added, then the edge will not be added again. 
-      if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
-        continue;
-      }
-      std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
-        edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
-
-      if (edge_ptr->InitEdgeCost() != SUCCESS) {
-        MS_LOG(EXCEPTION) << "Edge cost initialization failed";
-      }
-      target_cnode->operator_info()->AddPrevEdge(edge_ptr);
-      tmp_identity_ptr->AddSuccEdge(edge_ptr);
-      entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
-      MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
-                   << target_cnode->operator_info()->name();
-      add_identity_edge = true;
-    }
-    if (new_identity && add_identity_edge) {
-      // Add the TmpIdentityInfo to CostGraph only if BOTH conditions are satisfied
-      entire_costgraph->AddOperator(tmp_identity_ptr);
-    }
-  }
-}
-
-bool FindReshape(const CNodePtr &cnode) {
-  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-    return false;
-  }
-  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
-    return false;
-  }
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  OperatorInfoPtr operator_info = cnode->operator_info();
-  if (operator_info == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure: Primitive " << prim->ToString() << " OperatorInstance is nullptr";
-  }
-  if (prim->name() != RESHAPE) {
-    return false;
-  }
-  return true;
-}
-
-// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
-bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
-  // if the previous node is a parameter, handle it outside.
-  if (node->isa<Parameter>()) {
-    return false;
-  }
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  CNodePtr cnode = node->cast<CNodePtr>();
-  if (!IsValueNode<Primitive>(cnode->input(0))) {
-    return false;
-  }
-  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
-    *pre_operator_info = cnode->operator_info();
-    *out_index = 0;
-    return true;
-  }
-  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-  if (prim->name() == TUPLE_GETITEM) {
-    *out_index = GetTupleGetItemIndex(cnode);
-    // find tuple_get_item's previous node
-    auto pre_node = cnode->input(1);
-    if (!pre_node->isa<CNode>()) {
-      MS_LOG(EXCEPTION) << "tuple_get_item's second input is not a cnode";
-    }
-    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
-    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
-      *pre_operator_info = pre_cnode->operator_info();
-      return true;
-    }
-    return false;
-  }
-  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
-    if (prim->name() == DEPEND && index != 1) {
-      continue;
-    }
-    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
-      continue;
-    }
-    return true;
-  }
-  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; if reshape is not the first primitive, there must be some error";
-  return false;
-}
-
-// find the next node, then obtain its strategy_cost_ vector to get its layout vector.
-// if reshape's output connects to several primitives, return the first layout found
-bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(cnode->func_graph());
-  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  AnfNodeIndexSet node_set = manager->node_users()[cnode];
-  for (auto &node_pair : node_set) {
-    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
-    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
-      continue;
-    }
-    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(prim_anf_node);
-    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(node_prim);
-    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
-    if (node_prim->name() == DEPEND && node_pair.second != 1) {
-      continue;
-    }
-    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
-      *next_operator_info = use_apply->operator_info();
-      *in_index = node_pair.second - 1;
-      return true;
-    }
-    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
-                  << " " << (use_apply->operator_info() != nullptr);
-
-    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
-      return true;
-    }
-  }
-  return false;
-}
-
-void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
-  for (auto node : all_nodes) {
-    auto cnode = node->cast<CNodePtr>();
-    if (!FindReshape(cnode)) {
-      continue;
-    }
-    MS_ASSERT(cnode->inputs().size() == 3);
-    // get previous node's strategy_cost_
-    auto pre_node = cnode->input(1);
-    int32_t out_index = 0;
-    OperatorInfoPtr pre_operator_info;
-    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
-    if (pre_node->isa<Parameter>()) {
-      OperatorInfoPtr operator_info = cnode->operator_info();
-      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
-      reshape_info->SetCostForReshapeWithParameter();
-      pre_operator_info = reshape_info;
-      pre_stra_costs = reshape_info->strategy_cost();
-    } else {
-      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
-        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
-      }
-      pre_stra_costs = pre_operator_info->strategy_cost();
-    }
-    // get the next node's strategy_cost_
-    int32_t in_index = 0;
-    OperatorInfoPtr next_operator_info;
-    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
-    if (!find_next_node) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
-    }
-    // set input_layout and output_layout for reshape.
-    // init reshape and set a cost for each (input_layout, output_layout) pair.
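[Editor's aside] The pairing step named in the two comments above can be pictured with a toy: Reshape inherits candidate input layouts from its producer and candidate output layouts from its consumer, and one cost entry is generated per combination. A simplified sketch (ToyLayout stands in for TensorLayout; every name here is illustrative, not MindSpore API):

  #include <string>
  #include <utility>
  #include <vector>

  using ToyLayout = std::string;  // stand-in for a TensorLayout

  // Enumerate every (input_layout, output_layout) combination; each pair
  // later receives a conversion cost in the real implementation.
  std::vector<std::pair<ToyLayout, ToyLayout>> GenerateReshapePairs(
      const std::vector<ToyLayout> &pre_layouts, const std::vector<ToyLayout> &next_layouts) {
    std::vector<std::pair<ToyLayout, ToyLayout>> pairs;
    for (const auto &in : pre_layouts) {
      for (const auto &out : next_layouts) {
        pairs.emplace_back(in, out);
      }
    }
    return pairs;
  }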
-    OperatorInfoPtr operator_info = cnode->operator_info();
-    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
-    reshape_info->set_pre_operator_name(pre_operator_info->name());
-    reshape_info->set_pre_operator_index(out_index);
-    if (find_next_node) {
-      next_stra_costs = next_operator_info->strategy_cost();
-      reshape_info->set_next_operator_name(next_operator_info->name());
-      reshape_info->set_next_operator_index(in_index);
-    }
-    bool is_prev_param = pre_node->isa<Parameter>();
-    if (reshape_info->GenetateStrategyCosts(pre_stra_costs, next_stra_costs, out_index, in_index, is_prev_param) !=
-        SUCCESS) {
-      MS_LOG(EXCEPTION) << "Reshape generate strategy_costs failed!";
-    }
-  }
-}
-
-Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
-  // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
-  // Step 1: Traverse the ANF graph, and create NODEs for the costgraph:
-  //         create the OperatorInfo object for each primitive, and enumerate the parallelization strategies
-  //         for each OperatorInfo;
-  // Step 1.1: Deal with 'Reshape':
-  //         'Reshape' takes its previous operator's layout as its input layout, and takes its next operator's
-  //         layout as its output layout.
-  // Step 2: Traverse the ANF graph, and create EDGES for the costgraph:
-  //         create the Edge object for each pair of OperatorInfos, and enumerate the parallelization strategies
-  //         for each edge, based on the strategies of the two OperatorInfos;
-  // Step 3: Augment the costgraph:
-  //         taking care of the case of a single Parameter being used by multiple operators. Create a TmpIdentity
-  //         operator for this Parameter, and add an edge for the use of this Parameter by each
-  //         subsequent operator;
-  // Step 3.1: Calculate memory usage:
-  //         note the memory usage calculation differs between the training phase and the inference phase.
-  // Step 4: Run the Dynamic Programming algorithm:
-  //         in this process, cost is calculated based on not only the operators, but also the edges. Here, the edge
-  //         cost is caused by the redistribution of an operator's output tensor layout to the next operator's input
-  //         tensor layout. Note that there may be several connected components in the costgraph, and the DP algorithm
-  //         runs on each of them.
-  //
-  // OUTPUT: the determined strategy for each operator.
-
-  // Step 1
-  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
-    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
-      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
-                   << entire_costgraph->GetOperators().size() << " operators.";
-    } else {
-      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
-    }
-  } else {
-    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
-      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
-                   << entire_costgraph->GetOperators().size() << " operators.";
-    } else {
-      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
-    }
-  }
-  // Step 1.1
-  ReshapeCostCompute(all_nodes);
-  // Step 2
-  ConstructCostGraphEdges(all_nodes);
-  MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
-               << " operators, and " << entire_costgraph->GetNumEdges() << " edges.";
-
-  // Step 3: Augment the costgraph.
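[Editor's aside] Before the call below, a sketch of what the augmentation step does for a weight shared by several consumers: the parameter is modelled as a single TmpIdentity producer with one edge per consuming operator, so each use can carry its own redistribution cost (illustrative types only, not the MindSpore structures):

  #include <string>
  #include <vector>

  struct ToyOp { std::string name; };
  struct ToyEdge { const ToyOp *from; const ToyOp *to; };

  // Build the identity fan-out for a parameter consumed by `users`.
  std::vector<ToyEdge> AugmentSharedParameter(const ToyOp &identity, const std::vector<ToyOp> &users) {
    std::vector<ToyEdge> edges;
    edges.reserve(users.size());
    for (const auto &user : users) {
      edges.push_back({&identity, &user});  // one edge per consuming operator
    }
    return edges;
  }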
- AugmentCostGraph(all_nodes); - MS_LOG(INFO) << "After the augmenting procedure, there are " << entire_costgraph->GetOperators().size() - << " operators, and " << entire_costgraph->GetNumEdges() << " edges."; - - // Step 3.1: Calculate the memory usage - if (entire_costgraph->CalculateMemoryCost() != SUCCESS) { - MS_LOG(EXCEPTION) << "Calculating memory cost failed."; - } - - // Step 4: run DP algorithm on the costgraph. - if (GetStrategy(entire_costgraph) != SUCCESS) { - MS_LOG(ERROR) << "Strategy search for cost-graph fails"; - return FAILED; - } - MS_LOG(INFO) << "Searching strategy succeeded."; - - if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { - MS_LOG(INFO) << "Init selected strategy succeeded."; - } else { - MS_LOG(EXCEPTION) << "Init selected strategy failed."; - } - - // print the selected strategy - for (auto &op : entire_costgraph->GetOperators()) { - StrategyPtr s_strategy = op->selected_strategy(); - MS_LOG(INFO) << op->name() << " : The strategy is:"; - PrintStrategy(s_strategy); - } - - return SUCCESS; -} - -std::vector> RecInputTensorNames(const std::map::iterator &it, - std::vector> input_tensor_names) { - for (size_t j = 0; j < input_tensor_names.size(); j++) { - for (size_t k = 0; k < input_tensor_names[j].size(); k++) { - if (it->first == input_tensor_names[j][k]) { - input_tensor_names[j][k] = it->second; - break; - } - } - } - return input_tensor_names; -} - -CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node) { - PrimitivePtr prim = GetValueNode(prim_anf_node); - if (prim->name() == TUPLE_GETITEM || prim->name() == DEPEND) { - auto prev_cnode = cnode->input(1)->cast(); - if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { - return nullptr; - } - auto prev_prim = prev_cnode->input(0)->cast()->value()->cast(); - while (prev_prim->name() == TUPLE_GETITEM || prev_prim->name() == DEPEND) { - prev_cnode = prev_cnode->input(1)->cast(); - if (prev_cnode == nullptr || !IsValueNode(prev_cnode->input(0))) { - return nullptr; - } - prev_prim = prev_cnode->input(0)->cast()->value()->cast(); - } - return prev_cnode; - } - return nullptr; -} - -Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root) { - if (CostModelContext::GetInstance()->is_multi_subgraphs()) { - if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } else { - if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) { - MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. 
There are " - << entire_costgraph->GetOperators().size() << " operators."; - } else { - MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed."; - } - } - ReshapeCostCompute(all_nodes); - - auto ops = entire_costgraph->GetOperators(); - std::vector> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list(); - auto tuple_getitem_list = entire_costgraph->get_tuple_getitem_list(); - for (auto it = tuple_getitem_list.begin(); it != tuple_getitem_list.end();) { - input_tensor_names = RecInputTensorNames(it++, input_tensor_names); - } - std::shared_ptr graph = ParseGraph(ops, input_tensor_names); - - std::shared_ptr>> eli_list(new std::vector>); - std::shared_ptr> index_list(new std::vector); - graph = EliminateGraph(graph, eli_list, index_list); - - size_t num_device = g_device_manager->DeviceNum(); - double device_memory = entire_costgraph->GetDeviceMemory(); - if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) { - MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; - } else { - MS_LOG(ERROR) << "PartitionForAllDevices failed."; - return FAILED; - } - - GenerateStrategy(graph, ops, eli_list, input_tensor_names, index_list); - - if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { - MS_LOG(INFO) << "Init selected strategy succeeded."; - } else { - MS_LOG(ERROR) << "Init selected strategy failed."; - return FAILED; - } - - // print the selected strategy - for (auto &op : entire_costgraph->GetOperators()) { - StrategyPtr s_strategy = op->selected_strategy(); - MS_LOG(INFO) << op->name() << " : The strategy is:"; - PrintStrategy(s_strategy); - } - - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.h b/mindspore/ccsrc/parallel/step_auto_parallel.h deleted file mode 100644 index c923e5770f..0000000000 --- a/mindspore/ccsrc/parallel/step_auto_parallel.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef PARALLEL_STEP_AUTO_PARALLEL_H_ -#define PARALLEL_STEP_AUTO_PARALLEL_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "optimizer/opt.h" -#include "parallel/status.h" -#include "pipeline/pipeline.h" - -namespace mindspore { -namespace parallel { -bool IsSplittableOperator(const std::string &); - -bool IsAutoParallelCareNode(const CNodePtr &); - -// main step of Auto-parallel -bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); - -size_t GetLengthOfDataType(const TypePtr &type); - -std::vector ExtractInputParameterByNode(const CNodePtr &node); - -std::vector ExtractInputTypeLengthByNode(const CNodePtr &node); - -std::vector ExtractOutputTypeByNode(const CNodePtr &node); - -Status ConstructCostGraphNodesByUniqueId(const std::vector &all_nodes, const FuncGraphPtr &root); - -Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_nodes, const FuncGraphPtr &root); - -void ConstructCostGraphEdges(const std::vector &all_nodes); - -void AugmentCostGraph(const std::vector &all_nodes); - -Status ParallelStrategySearch(const std::vector &all_nodes, const FuncGraphPtr &root); - -Status ParallelStrategyRecSearch(const std::vector &all_nodes, const FuncGraphPtr &root); - -std::vector> RecInputTensorNames(const std::map::iterator &it, - std::vector> input_tensor_names); - -CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node); -} // namespace parallel -} // namespace mindspore -#endif // PARALLEL_STEP_AUTO_PARALLEL_H_ diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc deleted file mode 100644 index c79cc82d15..0000000000 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ /dev/null @@ -1,2362 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/step_parallel.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "ir/param_value.h" -#include "operator/ops.h" -#include "optimizer/optimizer.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/context.h" -#include "parallel/device_manager.h" -#include "parallel/dynamic_creator.h" -#include "parallel/graph_util/generate_graph.h" -#include "parallel/graph_util/graph_info.h" -#include "parallel/graph_util/node_info.h" -#include "parallel/node_check.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" -#include "utils/comm_manager.h" -#include "utils/symbolic.h" -#include "pipeline/static_analysis/prim.h" - -using mindspore::tensor::Tensor; - -namespace mindspore { -namespace parallel { -static const std::set COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; -static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; -// g_RefMap, for CNode B input i is a RefKey[Parameter C], -// it will be one item in map with key: C, and value: (B, i) -static std::map> g_RefMap; - -void SetCommunicationOpGroupLabel(std::vector new_node_input) { - if (new_node_input.empty()) { - return; - } - - ValueNodePtr prim_anf_node = new_node_input[0]->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - - auto attrs = prim->attrs(); - auto iter = attrs.find(GROUP); - if (iter != attrs.end()) { - auto value = iter->second; - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - std::string hash_name = value->cast()->value(); - MS_EXCEPTION_IF_NULL(g_device_manager); - std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name); - (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name)); - } - } -} - -std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { - MS_EXCEPTION_IF_NULL(node); - OperatorArgs arg_forward = op.second; - ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); - MS_EXCEPTION_IF_NULL(pyop_instance); - OperatorParams params = arg_forward.second; - - std::vector new_node_input = {NewValueNode(pyop_instance), node}; - if (!params.empty()) { - for (auto ¶m : params) { - AnfNodePtr val = NewValueNode(param.first.second); - MS_EXCEPTION_IF_NULL(val); - int32_t position = param.second; - (void)new_node_input.insert(new_node_input.begin() + position, val); - } - } - - // if the op have 'group' attr, set the rank list name for the op - SetCommunicationOpGroupLabel(new_node_input); - return new_node_input; -} - -void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, - const FuncGraphPtr &func_graph, const std::string &instance_name) { - // insert new node before the node - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - std::vector node_input = CreateInput(op, pre_node, instance_name); - CNodePtr new_node = func_graph->NewCNode(node_input); - MS_EXCEPTION_IF_NULL(new_node); - if (instance_name.find(SPLIT_SENS) == std::string::npos) { - new_node->set_in_forward_flag(true); // mark forward flag - } - auto new_node_value = node_input[0]->cast(); - MS_EXCEPTION_IF_NULL(new_node_value); - PrimitivePtr new_node_prim = new_node_value->value()->cast(); - new_node_prim->set_instance_name(instance_name); - 
new_node_prim->set_attr("keep_value_node_input", MakeValue(true)); - new_node->set_scope(scope); - node_input[0]->set_scope(scope); - manager->SetEdge(node, SizeToInt(index), new_node); -} - -std::string CreateInstanceName(const CNodePtr &node, size_t index) { - MS_EXCEPTION_IF_NULL(node); - if (!IsValueNode(node->input(0))) { - MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive"; - } - std::string name_base = node->fullname_with_scope(); - std::string name = name_base + "_" + std::to_string(index); - std::string instance_name = HashInstanceName(name); - return instance_name; -} - -void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - // step1:get graph manager distribute_operator - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto uses_set = manager->node_users()[node]; - CNodePtr node_to_insert = node; - for (auto &uses_pair : uses_set) { - auto uses_cnode = uses_pair.first->cast(); - MS_EXCEPTION_IF_NULL(uses_cnode); - if (!IsValueNode(uses_cnode->input(0))) { - break; - } - PrimitivePtr value_node_prim = GetValueNode(uses_cnode->input(0)); - MS_EXCEPTION_IF_NULL(value_node_prim); - if (value_node_prim->name() == TUPLE_GETITEM) { - if (uses_set.size() > 1) { - MS_LOG(EXCEPTION) << "Now only support one output, but got " << uses_set.size(); - } - node_to_insert = uses_cnode; - } - } - MS_EXCEPTION_IF_NULL(node_to_insert); - std::reverse(forward_op.begin(), forward_op.end()); - - // step2:traverse op_list and insert node - for (size_t index = 0; index < forward_op.size(); ++index) { - std::string instance_name_base = FORWARD_OP; - std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index); - std::vector forward_input = CreateInput(forward_op[index], node_to_insert, instance_name); - CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode - MS_EXCEPTION_IF_NULL(forward_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - forward_node->set_scope(scope); - forward_node->set_in_forward_flag(true); - forward_input[0]->set_scope(scope); - (void)manager->Replace(node_to_insert, forward_node); // using Replace function to insert node - } -} - -CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(prev); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (uint32_t i = 0; i < num; i++) { - std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), prev, - CreatInt32Imm(UintToInt(i))}; - auto tuple_get_item = func_graph->NewCNode(tuple_get_item_inputs); - MS_EXCEPTION_IF_NULL(tuple_get_item); - make_tuple_inputs.push_back(tuple_get_item); - } - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(prev, make_tuple); - return make_tuple; -} - -void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, - const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(pre_node); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - 
MS_EXCEPTION_IF_NULL(manager); - if ((redistribution_oplist_ptr->first).size() != (redistribution_oplist_ptr->second).size()) { - MS_LOG(EXCEPTION) << "size of OperatorVector and OutPutInfoVector must be the same!"; - } - for (size_t index = 0; index < (redistribution_oplist_ptr->first).size(); ++index) { - if (pos >= SizeToInt(node->inputs().size())) { - MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size"; - } - // Creat new node - AnfNodePtr target_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(target_node); - // Creat instance_name - auto op = (redistribution_oplist_ptr->first)[index]; - std::string op_name = (redistribution_oplist_ptr->first)[index].first; - std::string instance_name_base = REDISTRIBUTION_OP; - std::string instance_name = instance_name_base + "_" + CreateInstanceName(pre_node, index) + op_name; - InsertNode(op, node, IntToSize(pos), target_node, func_graph, instance_name); - if ((redistribution_oplist_ptr->second)[index].first) { - target_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(target_node); - (void)InsertMakeTuple(target_node, (redistribution_oplist_ptr->second)[index].second, func_graph); - } - } -} - -void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, - const std::string &instance_name) { - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; - } - - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (pos >= SizeToInt(node->inputs().size())) { - MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is " - << instance_name; - } - // Creat new node - AnfNodePtr pre_node = node->input(IntToSize(pos)); - MS_EXCEPTION_IF_NULL(pre_node); - InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); -} - -TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, - const OperatorInfoPtr &distribute_operator) { - TensorInfo tensorinfo_in; - if (middle_prim->name() == TUPLE_GETITEM) { - auto value_node = middle_node->input(2)->cast(); - MS_EXCEPTION_IF_NULL(value_node); - size_t index_s = IntToSize(GetValue(value_node->value())); - if (index_s >= distribute_operator->outputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index out of range, index: " << index_s - << ", vector size: " << distribute_operator->outputs_tensor_info().size(); - } - tensorinfo_in = distribute_operator->outputs_tensor_info()[index_s]; - } else { - if (distribute_operator->outputs_tensor_info().empty()) { - MS_LOG(EXCEPTION) << "The outputs tensor info is empty"; - } - tensorinfo_in = distribute_operator->outputs_tensor_info()[0]; - } - return tensorinfo_in.tensor_layout(); -} - -OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!IsParallelCareNode(node)) { - return nullptr; - } - OperatorInfoPtr distribute_operator = node->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; - } - return distribute_operator; -} - -void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, - const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr &pre_node) { - FuncGraphPtr func_graph = middle_node->func_graph(); - if (func_graph == nullptr) { - 
MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; - } - CNodePtr next_node = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(next_node); - auto middle_value = middle_node->input(0)->cast(); - MS_EXCEPTION_IF_NULL(middle_value); - PrimitivePtr middle_prim = middle_value->value()->cast(); - MS_EXCEPTION_IF_NULL(middle_prim); - OperatorInfoPtr next_distribute_operator = GetDistributeOperator(next_node); - if (next_distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure: " << next_node->ToString() << " GetDistributeOperator failed"; - } - RankList dev_list = distribute_operator->global_device_list(); - std::string next_prim_name = GetValueNode(next_node->input(0))->name(); - MS_LOG(DEBUG) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim " << next_prim_name; - MS_LOG(DEBUG) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " << next_node->ToString(); - // extract tensor layout in and out - if (distribute_operator->outputs_tensor_info().empty()) { - MS_LOG(EXCEPTION) << "Failure:pre_node's tensorinfo_in is empty"; - } - - if (IntToSize(index - 1) >= next_distribute_operator->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, the index is " << index - 1 << ", the vector size is " - << next_distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_out = next_distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); - TensorLayout tensorlayout_in = GetTensorInLayout(middle_node, middle_prim, distribute_operator); - if (tensor_redistribution.Init(tensorlayout_in, tensorlayout_out, dev_list) == FAILED) { - MS_LOG(ERROR) << "Redistribution: middle_prim " << middle_prim->name() << " next_prim : " << next_prim_name; - MS_LOG(ERROR) << "Redistribution: middle_node " << middle_node->ToString() << " next_node " - << next_node->ToString(); - DumpGraph(func_graph, "redistribution_error"); - MS_LOG(EXCEPTION) << "Failure:tensor_redistribution init failed"; - } - RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); - if (redistribution_oplist_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Failure:InferTensorRedistribution failed"; - } - MS_LOG(DEBUG) << "Redistribution size " << redistribution_oplist_ptr->first.size(); - if (!redistribution_oplist_ptr->first.empty()) { - // insert node before next node - InsertRedistribution(redistribution_oplist_ptr, next_node, func_graph, node_pair.second, pre_node); - } -} - -bool StrategyFound(std::unordered_map attrs) { - auto iter = attrs.find(STRATEGY); - return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); -} - -bool HasStrategy(const FuncGraphPtr &root) { - AnfNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - auto attrs = prim->attrs(); - if (StrategyFound(attrs)) { - return true; - } - } - - return false; -} - -bool IsCommunicationOp(const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(prim); - return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); -} - -bool FindCommunicationOp(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if 
(!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_value_node); - PrimitivePtr prim = GetValueNode(prim_value_node); - MS_EXCEPTION_IF_NULL(prim); - - if (IsCommunicationOp(prim) && cnode->in_forward_flag()) { - MS_EXCEPTION_IF_NULL(prim_value_node->scope()); - MS_LOG(INFO) << "The graph contain communication op: " << prim->name() << ", scope name is " - << prim_value_node->scope()->name(); - return true; - } - } - return false; -} - -bool IsParallelCareNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - ValueNodePtr prim_node = cnode->input(0)->cast(); - if (prim_node == nullptr) { - return false; - } - PrimitivePtr prim = prim_node->value()->cast(); - if (prim == nullptr) { - return false; - } - if (IsInBlackList(prim)) { - MS_LOG(INFO) << "Parallel don't care node: " << prim->name(); - return false; - } - // get_next is not in the forward graph, we need mark the get_next as the forward node - if (prim->name() == GET_NEXT) { - return true; - } - if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { - return false; - } - - return cnode->in_forward_flag(); -} - -void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, - const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(node->func_graph()); - FuncGraphManagerPtr manager = node->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - CNodePtr insert_node_new; - if (IsValueNode(node->input(0))) { - auto current_value = node->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - insert_node_new = ((current_prim->name() == TUPLE_GETITEM) ? node : insert_node); - } else { - insert_node_new = insert_node; - } - MS_EXCEPTION_IF_NULL(insert_node_new); - for (auto &node_pair : node_set) { - CNodePtr use_cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(use_cnode); - if (!IsValueNode(use_cnode->input(0))) { - StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); - } else { - ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { - Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, - pre_node); - } else { - StepRedistribution(use_cnode, distribute_operator, insert_node_new, tensor_redistribution, pre_node); - } - } - } -} - -void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(next_node); - OperatorInfoPtr op_info = next_node->operator_info(); - MS_EXCEPTION_IF_NULL(op_info); - - // If the shape of tensor is [] or [1], no need to split it. 
- Shapes shapes = GetNodeShape(node); - if (shapes.size() != 1) { - MS_LOG(EXCEPTION) << "Split tensor for " << op_info->name() - << ": GetNodeShape for tensor_node, output size is not 1"; - } - Shape shape = shapes[0]; - std::string shape_str = ShapeToString(shape); - if (shape.empty() || ((shape.size() == 1) && (shape[0] == 1))) { - MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape is " << shape_str - << ", no need to split it."; - return; - } - - MS_LOG(INFO) << "Split tensor for " << op_info->name() << ": The shape of tensor is " << shape_str; - - // extract tensor layout - if (IntToSize(index - 1) >= op_info->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << index - 1 << ", vector size is " - << op_info->inputs_tensor_info().size(); - } - TensorInfo tensor_info = op_info->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensor_layout = tensor_info.tensor_layout(); - - // Use _GetTensorSlice operator to split the tensor - FuncGraphPtr func_graph = next_node->func_graph(); // only cnode can get the graph - MS_EXCEPTION_IF_NULL(func_graph); - Operator op = CreateGetTensorSliceOp(tensor_layout); - InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR); - if (!op_info->sub_ops().empty()) { - auto sub_ops = op_info->sub_ops(); - for (size_t i = 0; i < sub_ops.size(); i++) { - if (!sub_ops.at(i).empty()) { - InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB); - } - } - } -} - -void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto &node_pair : node_set) { - CNodePtr use_cnode = node_pair.first->cast(); - if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(use_cnode_prim); - if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_cnode)) { - SplitTensor(node, use_cnode, node_pair.second); - } - } -} - -std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, - const CNodePtr &node) { - OperatorArgs arg_replace_op = replace_op.second; - ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); - if (pyop_instance == nullptr) { - MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed"; - } - OperatorParams params = arg_replace_op.second; - if (node->inputs().size() < 2) { - // GetNext operator dose not has input - if (node->inputs().size() == 1) { - return {NewValueNode(pyop_instance)}; - } - MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2"; - } - std::vector replace_input = {NewValueNode(pyop_instance), node->input(1)}; - auto prim = GetValueNode(node->input(0)); - if (prim->name() == EMBEDDING_LOOKUP) { - replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)}; - } - if (!params.empty()) { - Param param_first = *(params.begin()); - int32_t first_position = param_first.second; - if (first_position == 1) { - replace_input.pop_back(); - } - for (auto ¶m : params) { - AnfNodePtr val = NewValueNode(param.first.second); - if (val == nullptr) { - MS_LOG(EXCEPTION) << "Failure:val is nullptr"; - } - int32_t 
position = param.second; - (void)replace_input.insert(replace_input.begin() + position, val); - } - } - - return replace_input; -} - -void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - std::string instance_name = CreateInstanceName(node, 0); - std::vector replace_input; - replace_input = ReplaceOpInput(replace_op, instance_name, node); - CNodePtr replace_node = func_graph->NewCNode(replace_input); - MS_EXCEPTION_IF_NULL(replace_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - replace_node->set_scope(scope); - replace_node->set_in_forward_flag(true); - replace_input[0]->set_scope(scope); - (void)manager->Replace(node, replace_node); -} - -void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { - // step1:get graph manager distribute_operator - OperatorInfoPtr distribute_operator = node->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; - } - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - // step2:traverse op_list and insert node - std::reverse(replace_op.begin(), replace_op.end()); - auto replace_op_info = distribute_operator->replace_op_info(); - std::reverse(replace_op_info.begin(), replace_op_info.end()); - if (!replace_op_info.empty() && replace_op_info.size() != replace_op.size()) { - MS_LOG(EXCEPTION) << "replace_op_info is not empty and size not equal to replace_op!"; - } - bool replace_op_info_flag = !replace_op_info.empty(); - for (size_t index = 0; index < replace_op.size(); ++index) { - std::string instance_name = CreateInstanceName(node, index); - std::vector replace_input; - if (index != replace_op.size() - 1) { - replace_input = CreateInput(replace_op[index], node, instance_name); - } else { - replace_input = ReplaceOpInput(replace_op[index], instance_name, node); - } - CNodePtr replace_node = func_graph->NewCNode(replace_input); - MS_EXCEPTION_IF_NULL(replace_node); - ScopePtr scope = node->scope(); - MS_EXCEPTION_IF_NULL(scope); - replace_node->set_scope(scope); - if (index == replace_op.size() - 1) { - (void)replace_node->set_operator_info(node->operator_info()); - } - replace_node->set_in_forward_flag(true); - replace_input[0]->set_scope(scope); - if (replace_op_info_flag && replace_op_info[index].first) { - auto new_cnode = InsertMakeTuple(replace_node, replace_op_info[index].second, func_graph); - (void)manager->Replace(node, new_cnode); // using Replace function to insert node - } else { - (void)manager->Replace(node, replace_node); // using Replace function to insert node - } - } - MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); -} - -bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { - ValueNodePtr anf_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(anf_node); - PrimitivePtr prim = anf_node->value()->cast(); - return (prim->name() == name); -} - -void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(replace_graph); - MS_EXCEPTION_IF_NULL(node); - 
MS_EXCEPTION_IF_NULL(replace_graph->second); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; - } - for (auto &replace_input : replace_graph->first) { - auto pre_node = node->input(IntToSize(replace_input.second)); - manager->SetEdge(replace_input.first, 1, pre_node); - } - // "(void)manager->Replace(replace_graph->first, pre_node);" can not be called - auto replace_output = replace_graph->second; - MS_EXCEPTION_IF_NULL(replace_output); - (void)manager->Replace(node, replace_output); -} - -int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != 3) { - MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; - } - - if (!cnode->input(2)->isa()) { - MS_LOG(EXCEPTION) << "The index of tuple getitem is not a value node"; - } - - ValuePtr tuple_index_value = GetValueNode(cnode->input(2)); - MS_EXCEPTION_IF_NULL(tuple_index_value); - if (!tuple_index_value->isa()) { - MS_LOG(EXCEPTION) << "The index of tuple getitem is not int32"; - } - return tuple_index_value->cast()->value(); -} - -// Judge whether the node is a loss, and if there are multiple outputs, -// get which output is a grad according to the tuple getitem. -// Currently, it is not supported that the sens is a tuple. -LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { - MS_EXCEPTION_IF_NULL(loss_node); - FuncGraphPtr sub_graph = loss_node->func_graph(); - MS_EXCEPTION_IF_NULL(sub_graph); - CNodePtr return_node = sub_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - LossNodeInfo node_info; - - // return -> cast - auto pre_cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - auto pre_prim = GetValueNode(pre_cnode->input(0)); - if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_node = pre_cnode->input(1); - } - - // return -> loss - if (pre_node == loss_node) { - node_info.has_tuple_getitem = false; - node_info.dout_index = 0; - return node_info; - } - - // return -> tuple_getitem -> loss - auto cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto current_value = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - // size of common cnode is larger than 1 - if (cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2"; - } - - if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) { - // size of tuple_getitem cnode is 3 - auto tuple_index = GetTupleGetItemIndex(cnode); - node_info.has_tuple_getitem = true; - node_info.dout_index = tuple_index; - return node_info; - } - - MS_LOG(EXCEPTION) << "Invalid loss"; -} - -void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - size_t node_size = node->inputs().size(); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (size_t index 
= 1; index < node_size; ++index) { - AnfNodePtr input = node->input(index); - MS_EXCEPTION_IF_NULL(input); - if (!input->isa() && !input->isa()) { // if it is not a tensor, continue - MS_LOG(INFO) << "insert div op: the index " << index << " is not tensor, skip"; - continue; - } - - for (size_t pos = 0; pos < virtual_div_op.size(); ++pos) { - std::string instance_name = CreateInstanceName(node, pos); - InsertNode(virtual_div_op[pos], node, index, node->input(index), func_graph, instance_name); - } - MS_LOG(INFO) << "insert div op for input index " << index << " of node"; - } -} - -std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - if (!node->isa() && !node->isa() && !node->isa()) { - return std::make_pair(nullptr, false); - } else if (node->isa()) { - return std::make_pair(node, false); - } else if (node->isa()) { - if (IsValueNode(node)) { - std::vector param_v = FindParameterByRefKeyNode(node, func_graph); - if (param_v.size() != 1) { - MS_LOG(EXCEPTION) << "FindParameterByRefKeyNode failed, return vector size must be 1, real is " - << param_v.size(); - } - return std::make_pair(node, true); - } - return std::make_pair(nullptr, false); - } else { - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - if (!FindParameter(cnode->input(index), func_graph).first) { - continue; - } - return FindParameter(cnode->input(index), func_graph); - } - } else { - if (IsParallelCareNode(cnode)) { - return std::make_pair(nullptr, false); - } else { - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - PrimitivePtr prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == DEPEND && index != 1) { - continue; - } - if (!FindParameter(cnode->input(index), func_graph).first) { - continue; - } - return FindParameter(cnode->input(index), func_graph); - } - } - } - } - return std::make_pair(nullptr, false); -} - -std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(anode); - MS_EXCEPTION_IF_NULL(anode->func_graph()); - FuncGraphManagerPtr manager = anode->func_graph()->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[anode]; - bool result = false; - CNodePtr cnode_return = nullptr; - for (auto &node_pair : node_set) { - CNodePtr use_apply = node_pair.first->cast(); - if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == name && node_pair.second == 1) { - if (use_apply->func_graph() == func_graph) { - result = true; - cnode_return = use_apply; - MS_LOG(INFO) << "Find Primitive " << name << " in the same func_graph"; - continue; - } - MS_LOG(INFO) << "Find Primitive " << name << " in different func_graph"; - } - } - return std::make_pair(result, cnode_return); -} - -bool IsCastBeforMirror(const CNodePtr &node, size_t index) { - // only if cast_before_mirror is true, pre node is cast and type is not float32 return true - if (!ParallelContext::GetInstance()->cast_before_mirror()) { - return false; - } - auto pre_node = node->input(index); - 
MS_EXCEPTION_IF_NULL(pre_node); - auto cnode = pre_node->cast(); - if (cnode == nullptr || !IsValueNode(cnode->input(0))) { - return false; - } - auto pre_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(pre_value_node); - auto pre_prim = pre_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(pre_prim); - if (pre_prim->name() != CAST) { - return false; - } - auto node_type = pre_node->Type(); - MS_EXCEPTION_IF_NULL(node_type); - if (!node_type->isa()) { - MS_LOG(EXCEPTION) << "Unknown type."; - } - auto input_element_type = node_type->cast()->element(); - MS_EXCEPTION_IF_NULL(input_element_type); - auto type_id = input_element_type->type_id(); - - return (type_id != kNumberTypeFloat32); -} - -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - size_t node_size = node->inputs().size(); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (mirror_ops.size() != node_size - 1) { - MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() - << ", node_size is " << node_size; - } - for (size_t index = 1; index < node_size; ++index) { - OperatorVector backward_op = mirror_ops[index - 1]; - if (backward_op.empty()) { - continue; - } - std::pair param_node_pair = FindParameter(node->input(index), func_graph); - if (!param_node_pair.first) { - continue; - } - // not a RefKey - if (!param_node_pair.second) { - auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph); - // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead - if (next_cnode.first) { - MS_EXCEPTION_IF_NULL(next_cnode.second); - manager->SetEdge(node, SizeToInt(index), next_cnode.second); - continue; - } - } - // if the parameter found is a RefKey, or no MirrorOp is found in the same graph, insert a new MirrorOp - // only one MirrorOp in backward_op - if (backward_op.size() != 1) { - MS_LOG(EXCEPTION) << "backward_op size must be 1, real is " << backward_op.size(); - } - std::string instance_name = MIRROR_OP; - if (IsCastBeforMirror(node, index)) { - for (auto &op : backward_op) { - // insert new node before the node - CNodePtr cnode = node->input(index)->cast(); - MS_EXCEPTION_IF_NULL(cnode); - AnfNodePtr pre_node = cnode->input(1); - InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); - } - } else { - for (auto &op : backward_op) { - AnfNodePtr pre_node = node->input(index); - InsertNode(op, node, index, pre_node, func_graph, instance_name); - } - } - } -} - -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, - const std::vector> &sens_loss_pairs) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(node); - - bool is_loss_cnode = - std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), - [node](const std::pair &element) { return element.second == node; }); - - MirrorOps mirror_ops = distribute_operator->mirror_ops(); - VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); - // insert mirror op - if (!mirror_ops.empty()) { - MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name(); - InsertMirrorOps(mirror_ops, node); - } - // insert virtual div op - if (!virtual_div_op.empty() && is_loss_cnode) { - MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name(); - InsertVirtualDivOp(virtual_div_op, node); - } -} - -std::string 
GetDisOpName(const std::string &prim_name) { - std::string op_name = prim_name; - if (!prim_name.empty() && (prim_name[0] == '_')) { - op_name = prim_name.substr(1); - } - return op_name + "Info"; -} - -OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, - const std::vector &shape_list) { - if (shape_list.size() != 2) { - MS_LOG(ERROR) << "The size of shape list is not 2"; - return nullptr; - } - if (name.length() == 0) { - MS_LOG(EXCEPTION) << "Length of name is zero!"; - } - std::string distribute_opname = GetDisOpName(name); - if (name == GATHERV2) { - distribute_opname = name + "PInfo"; - auto data_parallel_iter = attrs.find(DATA_PARALLEL); - if (data_parallel_iter != attrs.end()) { - MS_EXCEPTION_IF_NULL(data_parallel_iter->second); - if (!data_parallel_iter->second->isa()) { - MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; - } - bool data_parallel = data_parallel_iter->second->cast()->value(); - if (data_parallel) { - distribute_opname = name + "Info"; - } - } - } - OperatorInfoPtr operator_ = - (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); - if (operator_ == nullptr) { - MS_LOG(INFO) << "Creat " << name << " failed"; - return nullptr; - } - std::string origin_name = operator_->name(); - operator_->set_name(origin_name + std::to_string(TOTAL_OPS)); - MS_LOG(INFO) << "Successfully created operator " << origin_name; - ++TOTAL_OPS; - return operator_; -} - -OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - const std::vector &shape_list) { - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); - if (operator_ == nullptr) { - MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel"; - operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list); - MS_EXCEPTION_IF_NULL(operator_); - } - return operator_; -} - -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - std::vector shape_list) { - OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); - for (size_t i = 0; i < shape_list[0].size(); ++i) { - MS_LOG(INFO) << "No: " << i << " input's shape: " << ShapeToString(shape_list[0][i]); - } - return operator_; -} - -StrategyPtr ExtractStrategy(std::unordered_map attrs) { - ValueTuplePtr var = attrs[STRATEGY]->cast(); - StrategyPtr strategyPtr; - MS_LOG(INFO) << "Extract information: strategy " << attrs[STRATEGY]->ToString(); - if (var == nullptr) { - MS_LOG(EXCEPTION) << "Strategy value is nullptr"; - } - if (var->size() > 0) { - std::vector elements = var->value(); - std::vector strategy; - for (uint32_t index = 0; index < elements.size(); ++index) { - Dimensions dim; - if (elements[index]->isa()) { - ValueTuplePtr value_tuple = elements[index]->cast(); - std::vector value_vector = value_tuple->value(); - (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr &value) { return static_cast(GetValue(value)); }); - strategy.push_back(dim); - } else { - MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! 
Need ValueSequeue"; - } - } - if (strategy.empty()) { - MS_LOG(EXCEPTION) << "ExtractStrategy:failed to extract strategy"; - } - strategyPtr = NewStrategy(0, strategy); - } - - return strategyPtr; -} - -Shapes GetNodeShape(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - Shapes shapes; - BaseShapePtr base_shape_ptr = node->Shape(); - if (node->isa()) { - auto cnode = node->cast(); - if (IsValueNode(cnode->input(0))) { - PrimitivePtr prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == MAKEREF) { - AnfNodePtr ref_node = cnode->input(1); - auto func_graph = cnode->func_graph(); - MS_EXCEPTION_IF_NULL(ref_node); - MS_EXCEPTION_IF_NULL(func_graph); - return GetRefKeyNodeShape(ref_node, func_graph); - } - } - if (cnode->input(0)->isa()) { - if (cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2"; - } - base_shape_ptr = cnode->input(1)->Shape(); - } - } - if (base_shape_ptr == nullptr) { - MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " shape_ptr is nullptr, full name is " - << node->fullname_with_scope(); - } - auto tuple_shape_ptr = dyn_cast(base_shape_ptr); - if (tuple_shape_ptr != nullptr) { - auto tuple_shape = tuple_shape_ptr->shape(); - for (auto &shape : tuple_shape) { - auto each_shape = dyn_cast(shape); - MS_EXCEPTION_IF_NULL(each_shape); - shapes.push_back(each_shape->shape()); - } - } else { - auto shape_ptr = dyn_cast(base_shape_ptr); - MS_EXCEPTION_IF_NULL(shape_ptr); - shapes.push_back(shape_ptr->shape()); - } - return shapes; -} - -std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters; - if (!IsValueNode(node)) { - MS_LOG(ERROR) << "The node is not a ref key"; - return parameters; - } - - auto ref_key = GetValueNode(node); - MS_EXCEPTION_IF_NULL(ref_key); - auto name = ref_key->tag(); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto roots = manager->roots(); - if (roots.size() != 1) { - MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1"; - return parameters; - } - - FuncGraphPtr root_g = roots.back(); - MS_EXCEPTION_IF_NULL(root_g); - for (auto ¶m_node : root_g->parameters()) { - auto param = param_node->cast(); - if (param && (name == param->name())) { - parameters.push_back(param_node); - MS_LOG(INFO) << "The name of ref key is: " << name; - return parameters; - } - } - - MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter"; - return parameters; -} - -Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector parameters = FindParameterByRefKeyNode(node, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - - Shapes input_shapes; - input_shapes = GetNodeShape(parameters[0]); - if (input_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "Get input shape failed"; - } - - MS_LOG(INFO) << "The parameter shape is " << ShapeToString(input_shapes[0]); - return input_shapes; -} - -std::vector ExtractShape(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - Shapes shape_inputs, shape_outputs; - std::vector shape_all; - std::vector all_inputs = node->inputs(); - std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - - size_t inputs_size = all_inputs.size(); - for 
(size_t i = 1; i < inputs_size; ++i) { - Shapes input_shapes; - AnfNodePtr input = all_inputs[i]; - if (IsValueNode(input)) { - auto func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - std::vector parameters = FindParameterByRefKeyNode(input, func_graph); - if (parameters.size() != 1) { - MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; - } - std::pair node_pair = std::make_pair(node, SizeToInt(i)); - g_RefMap[parameters[0]] = node_pair; - input_shapes = GetRefKeyNodeShape(input, func_graph); - } else if (IsValueNode(input) || input->isa() || input->isa()) { - input_shapes = GetNodeShape(input); - } else { - continue; - } - if (input_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed"; - } - shape_inputs.push_back(input_shapes[0]); - } - shape_all.push_back(shape_inputs); - // extract out shape - shape_outputs = GetNodeShape(node); - shape_all.push_back(shape_outputs); - return shape_all; -} - -std::pair FindParallelCareNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr func_graph = node->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto &node_pair : node_set) { - CNodePtr cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_node_anf = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_node_anf); - PrimitivePtr node_prim = prim_node_anf->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { - return node_pair; - } else if (FindParallelCareNode(node_pair.first).first != nullptr) { - return FindParallelCareNode(node_pair.first); - } - } - return std::make_pair(nullptr, 0); -} - -std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(parameter); - FuncGraphManagerPtr manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::pair prim_anf_node_pair = FindParallelCareNode(parameter); - if (prim_anf_node_pair.first != nullptr) { - return prim_anf_node_pair; - } else { - AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; - for (auto ¶m_pair : param_sub_set) { - CNodePtr graph_cnode = param_pair.first->cast(); - if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { - continue; - } - CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast(); - if (!IsValueNode(graph_cnode_inp0->input(1))) { - continue; - } - FuncGraphPtr graph_sub = GetValueNode(graph_cnode_inp0->input(1)); - auto parameters = graph_sub->parameters(); - if (IntToSize(param_pair.second - 1) >= parameters.size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is " - << parameters.size(); - } - std::pair res = FindSubGraph(graph_sub, parameters[IntToSize(param_pair.second - 1)]); - if (res.first != nullptr) { - return res; - } - } - } - return std::make_pair(nullptr, 0); -} - -void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res) { - MS_EXCEPTION_IF_NULL(parameter); - AbstractBasePtr abstract = parameter->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); - CNodePtr cnode = 
res.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = cnode->operator_info(); - if (distribute_operator == nullptr) { - MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; - } - - if (IntToSize(res.second - 1) >= distribute_operator->inputs_tensor_info().size()) { - MS_LOG(EXCEPTION) << "The index is out of range, index is " << res.second - 1 << ", vector size is " - << distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)]; - Shape slice_shape = tensorinfo_in.slice_shape(); - MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape " - << MakeValue(slice_shape)->ToString(); - std::shared_ptr parallel_shape = std::make_shared(slice_shape); - MS_EXCEPTION_IF_NULL(parallel_shape); - // Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis. - auto cloned_abstract = abstract->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(parallel_shape); - parameter->set_abstract(cloned_abstract); - TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); - ParameterPtr parameter_ptr = parameter->cast(); - MS_EXCEPTION_IF_NULL(parameter_ptr); - parameter_ptr->set_tensor_layout(std::make_shared(tensor_layout)); -} - -void CoverSliceShape(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - auto parameters = root->parameters(); - for (auto ¶meter : parameters) { - MS_EXCEPTION_IF_NULL(parameter->Shape()); - auto iter = g_RefMap.find(parameter); - if (iter != g_RefMap.end()) { - SetParallelShape(parameter, g_RefMap[parameter]); - continue; - } - std::pair res = FindSubGraph(root, parameter); - if (res.first == nullptr) { - MS_LOG(INFO) << "Parameter " << parameter->ToString() << " don't need to set parallel shape"; - } else { - SetParallelShape(parameter, res); - MS_LOG(DEBUG) << "Parameter " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); - } - } - g_RefMap.clear(); -} - -bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(parameter_node); - FuncGraphManagerPtr manager = root->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto cloned_parameter = parameter_node->cast(); - MS_EXCEPTION_IF_NULL(cloned_parameter); - - // find the clone parameter - if (!cloned_parameter->has_default()) { - return false; - } - - bool cloned = cloned_parameter->default_param()->cloned(); - if (!cloned) { - return false; - } - - MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned"; - return true; -} - -void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - for (auto &cloned_parameter_node : root->parameters()) { - MS_EXCEPTION_IF_NULL(cloned_parameter_node); - auto cloned_parameter = cloned_parameter_node->cast(); - MS_EXCEPTION_IF_NULL(cloned_parameter); - - if (!ParameterIsCloned(root, cloned_parameter_node)) { - continue; - } - - // get the cloned index - int32_t cloned_index = cloned_parameter->default_param()->cloned_index(); - - // find the be cloned parameter - bool found_be_cloned_parameter = false; - ParameterPtr cloned_from_parameter = nullptr; - AnfNodePtr cloned_from_node = nullptr; - for (auto &be_cloned_parameter_node : root->parameters()) { - MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); - auto be_cloned_parameter = be_cloned_parameter_node->cast(); - 
MS_EXCEPTION_IF_NULL(be_cloned_parameter); - if (!be_cloned_parameter->has_default()) { - continue; - } - - const auto ¶m_value_cloned = be_cloned_parameter->default_param(); - if (!param_value_cloned->be_cloned()) { - continue; - } - - // get the be cloned index - auto &be_cloned_index = param_value_cloned->be_cloned_index(); - if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) { - found_be_cloned_parameter = true; - cloned_from_parameter = be_cloned_parameter; - cloned_from_node = be_cloned_parameter_node; - } - } - - if (found_be_cloned_parameter) { - // set the shape and tensor layout for cloned parameter - cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); - MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); - MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); - auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack()); - cloned_parameter_node->set_abstract(cloned_abstract); - MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() - << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name() - << ", clone index is: " << cloned_index; - } else { - MS_LOG(EXCEPTION) << "The parameter: " << cloned_parameter->name() << " is cloned, cloned index is " - << cloned_index << ", but not found the be cloned parameter"; - } - } - std::string env = common::GetEnv("SLICE_ENV"); - if (!env.empty()) { - MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; - } -} - -void SetVirtualDatasetStrategy(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - bool full_batch = ParallelContext::GetInstance()->full_batch(); - - PrimitivePtr prim = GetValueNode(node->input(0)); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == VIRTUAL_DATA_SET) { - CheckGlobalDeviceManager(); - int32_t dev_num; - if (full_batch) { - dev_num = 1; - } else { - dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - } - auto attrs_temp = prim->attrs(); - std::vector shape_list = ExtractShape(node); - if (shape_list.empty()) { - MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; - } - std::vector elements; - for (size_t i = 0; i < shape_list[0].size(); i++) { - if (shape_list[0][i].empty()) { - MS_LOG(EXCEPTION) << "shape_list[ " << i << " ].size() is zero"; - } - std::vector input_strategy = {dev_num}; - for (size_t j = 1; j < shape_list[0][i].size(); j++) { - input_strategy.push_back(1); - } - elements.push_back(MakeValue(input_strategy)); - } - ValueTuplePtr strategy = std::make_shared(elements); - attrs_temp[STRATEGY] = strategy; - (void)prim->SetAttrs(attrs_temp); - } -} - -void ExtractInformation(const std::vector &all_nodes) { - // load strategy map from checkpoint - StrategyMap stra_map; - if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) { - if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; - } - } - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - SetVirtualDatasetStrategy(cnode); - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = GetValueNode(prim_anf_node); - auto attrs = prim->attrs(); - MS_LOG(INFO) << "extract information: node: " << node->ToString() << " 
prim " << prim->name(); - if (IsParallelCareNode(cnode)) { - std::vector shape_list = ExtractShape(cnode); - if (shape_list.empty()) { - MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape"; - } - OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); - if (operator_ == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; - } - auto &inputs = cnode->inputs(); - std::vector input_value; - for (size_t index = 1; index < inputs.size(); ++index) { - if (inputs[index]->isa()) { - input_value.push_back(GetValueNode(inputs[index])); - } else { - input_value.emplace_back(nullptr); - } - } - StrategyPtr strategyPtr = nullptr; - (*operator_).set_input_value(input_value); - (*operator_).set_outputs_dtype(cnode->Type()); - (*operator_).set_cnode(cnode); - if (prim->name() == RESHAPE) { - (void)cnode->set_operator_info(operator_); - continue; - } - // load strategy checkpoint - // key of strategy map - std::string strategy_key_name = NodeParameterName(cnode); - bool load_strategy_from_ckpt = - StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); - if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { - MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() - << " is empty, using batch parallel"; - std::shared_ptr> strategy_v_ptr = operator_->GenerateBatchStrategies(); - if (strategy_v_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Generate batch parallel strategy failed"; - } - std::vector elements; - for (size_t i = 0; i < strategy_v_ptr->size(); i++) { - elements.push_back(MakeValue((*strategy_v_ptr)[i])); - } - ValueTuplePtr strategy = std::make_shared(elements); - // display the strategy generated by batch parallel - attrs[GEN_STRATEGY] = strategy; - (void)prim->SetAttrs(attrs); - MS_LOG(INFO) << "node " << node->ToString() << " prim " << prim->name() << " batch parallel strategy is " - << attrs[GEN_STRATEGY]->ToString(); - strategyPtr = NewStrategy(0, *strategy_v_ptr); - } else if (load_strategy_from_ckpt) { - strategyPtr = stra_map[strategy_key_name]; - } else { - strategyPtr = ExtractStrategy(attrs); - } - if (strategyPtr != nullptr) { - if (operator_->Init(strategyPtr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; - } - (void)cnode->set_operator_info(operator_); - } else { - MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; - } - } - } -} - -TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { - CNodePtr cnode = node_pair.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - MS_EXCEPTION_IF_NULL(distribute_operator); - int index = node_pair.second; - if (index > SizeToInt(distribute_operator->inputs_tensor_info().size())) { - MS_LOG(EXCEPTION) << "The index is out of range, the node_pair.second is " << index - 1 << ", the vector size is " - << distribute_operator->inputs_tensor_info().size(); - } - TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(index - 1)]; - TensorLayout tensorlayout_in = tensorinfo_in.tensor_layout(); - return tensorlayout_in; -} - -// if reshape's output connect to several primitive, return the first layout found -std::shared_ptr FindNextLayout(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(cnode->func_graph()); - FuncGraphManagerPtr manager = cnode->func_graph()->manager(); - 
MS_EXCEPTION_IF_NULL(manager); - AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto &node_pair : node_set) { - CNodePtr use_apply = node_pair.first->cast(); - if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr node_prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(node_prim); - MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); - if (node_prim->name() == DEPEND && node_pair.second != 1) { - continue; - } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { - MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); - auto layout = GetInputLayoutFromCNode(node_pair); - return std::make_shared(layout); - } - MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); - - auto layout_ptr = FindNextLayout(use_apply); - if (layout_ptr) { - return layout_ptr; - } - } - MS_LOG(WARNING) << "FindNextLayout return nullptr, if reshape is not the last primitive, there must be some error"; - return nullptr; -} - -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { - MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - MS_EXCEPTION_IF_NULL(distribute_operator); - if (distribute_operator->outputs_tensor_info().size() < output_index) { - MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size() - << ", must be less than output_index " << output_index; - } - TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index]; - TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout(); - return std::make_shared(tensorlayout_out); -} - -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { - if (!node->isa()) { - return nullptr; - } - CNodePtr cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - return nullptr; - } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); - if (!layout_ptr) { - MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; - } - return layout_ptr; - } - return nullptr; -} - -std::shared_ptr CreateParameterLayout(const AnfNodePtr &node) { - // Create DataParallel tensor layout for parameter(support WideDeep). 
- CheckGlobalDeviceManager(); - int32_t dev_num = SizeToInt(g_device_manager->GetDeviceListByStageId(0).size()); - TensorLayout input_tensor_layout; - // create input_shape - Shapes inputs_shape = GetNodeShape(node); - Shape input_shape_array = inputs_shape[0]; - if (input_shape_array.empty()) { - MS_LOG(EXCEPTION) << "Don't support reshape a scalar parameter."; - } - // create tensor_map - size_t shape_size = input_shape_array.size(); - TensorMap input_tensor_map_array(SizeToInt(shape_size) - 1, -1); - input_tensor_map_array.insert(input_tensor_map_array.begin(), 0); - // create dev_matrix - Shape dev_matrix_array = {dev_num}; - if (input_tensor_layout.InitFromVector(dev_matrix_array, input_tensor_map_array, input_shape_array) != SUCCESS) { - MS_LOG(EXCEPTION) << "Create tensor layout for parameter failed."; - } - return std::make_shared(input_tensor_layout); -} - -std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { - if (node->isa()) { - return CreateParameterLayout(node); - } - if (!node->isa()) { - return nullptr; - } - CNodePtr cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - return nullptr; - } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); - if (!layout_ptr) { - MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; - } - return layout_ptr; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - PrimitivePtr prim = prim_anf_node->value()->cast(); - if (prim->name() == TUPLE_GETITEM) { - auto tuple_index = GetTupleGetItemIndex(cnode); - auto layout_ptr = FindPrevParallelCareNodeLayout(cnode->input(1), IntToSize(tuple_index)); - if (!layout_ptr) { - MS_LOG(EXCEPTION) - << " Failure:FindPrevLayout failed, tuple_getitem before reshape, but there does not exit a parallel care node " - "before tuple_getitem!"; - } - return layout_ptr; - } - for (size_t index = 0; index < cnode->inputs().size(); ++index) { - if (prim->name() == DEPEND && index != 1) { - continue; - } - auto layout_ptr = FindPrevLayout(cnode->inputs()[index]); - if (!layout_ptr) { - continue; - } - return layout_ptr; - } - MS_LOG(WARNING) << "FindPrevLayout return nullptr, if reshape is not the first primitive, there must be some error"; - return nullptr; -} - -void ReshapeInit(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { - continue; - } - PrimitivePtr prim = GetValueNode(prim_anf_node); - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; - } - if (prim->name() != RESHAPE) { - continue; - } - auto attrs = prim->attrs(); - if (StrategyFound(attrs)) { - MS_LOG(EXCEPTION) << "Setting strategy for Reshape goes for nothing!"; - } - MS_ASSERT(cnode->inputs().size() == 3); - auto prev_layout_ptr = FindPrevLayout(cnode->input(1)); - if (prev_layout_ptr) { - auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); - reshape_info_ptr->SetInputLayout(*prev_layout_ptr); - } - auto next_layout_ptr = FindNextLayout(cnode); - if (next_layout_ptr) { - auto reshape_info_ptr = std::dynamic_pointer_cast(operator_info); - reshape_info_ptr->SetOutputLayout(*next_layout_ptr); - } - if 
(operator_info->Init(nullptr) == FAILED) { - MS_LOG(EXCEPTION) << "Failure:operator " << prim->ToString() << " init failed"; - } - } -} - -CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - CNodePtr return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - auto pre_cnode = pre_node->cast(); - if (pre_cnode == nullptr) { - return nullptr; - } - - auto current_prim = GetValueNode(pre_cnode->input(0)); - // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_cnode = pre_cnode->input(1)->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - current_prim = GetValueNode(pre_cnode->input(0)); - } - - // notice: the GetNext op has not input - if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(INFO) << "The loss is: " << current_prim->name(); - return pre_cnode; - } - - // size of common cnode is larger than 1 - if (pre_cnode->size() < 2) { - MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; - } - - // return -> tuple_getitem -> loss - if (current_prim->name() == TUPLE_GETITEM) { - AnfNodePtr pre_pre_node = pre_cnode->input(1); - MS_EXCEPTION_IF_NULL(pre_pre_node); - - auto pre_pre_cnode = pre_pre_node->cast(); - auto value = pre_pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value); - PrimitivePtr prim = value->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - MS_LOG(DEBUG) << "The loss name is " << prim->name(); - return pre_pre_cnode; - } - - // return -> make_tuple - if (current_prim->name() == MAKE_TUPLE) { - MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; - } - - // return -> loss - MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); - return pre_cnode; -} - -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { - TensorLayouts ret; - MS_EXCEPTION_IF_NULL(loss_cnode); - AnfNodePtr node = loss_cnode->cast(); - MS_EXCEPTION_IF_NULL(node); - - LossNodeInfo node_info = GetLossNodeInfo(node); - ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(prim_anf_node); - PrimitivePtr prim = prim_anf_node->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; - return ret; - } - - OperatorInfoPtr operator_info = loss_cnode->operator_info(); - MS_EXCEPTION_IF_NULL(operator_info); - TensorInfo loss_grad_tensor_info; - size_t op_output_size = operator_info->outputs_tensor_info().size(); - MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " - << node_info.has_tuple_getitem << ", the output size is " << op_output_size << ", the dout_index is " - << node_info.dout_index; - - if ((op_output_size == 0) || (op_output_size <= IntToSize(node_info.dout_index))) { - MS_LOG(EXCEPTION) << "The index is " << node_info.dout_index << ", but the size of outputs is " << op_output_size; - } - - if (!node_info.has_tuple_getitem && (op_output_size > 1)) { - MS_LOG(EXCEPTION) << "Currently, it is not supported that the sens is a tuple."; - } - - loss_grad_tensor_info = operator_info->outputs_tensor_info()[IntToSize(node_info.dout_index)]; - 
ret.push_back(loss_grad_tensor_info.tensor_layout()); - return ret; -} - -void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { - MS_EXCEPTION_IF_NULL(grad_sens_node); - if (grad_sens_node->size() <= 1) { - MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2"; - } - AnfNodePtr sens_tensor_node = grad_sens_node->input(1); - MS_EXCEPTION_IF_NULL(sens_tensor_node); - Shapes sens_shapes = GetNodeShape(sens_tensor_node); - if (sens_shapes.size() != 1) { - MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1"; - } - // If the shape of sens tensor is [] or [1], no need to split it. - Shape sens_shape = sens_shapes[0]; - if (sens_shape.empty() || ((sens_shape.size() == 1) && (sens_shape[0] == 1))) { - if (sens_tensor_node->isa()) { - auto sens_tensor_param = sens_tensor_node->cast(); - MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); - } - MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; - return; - } - auto loss_shape = loss_grad_layout.tensor_shape().array(); - if (loss_shape != sens_shape) { - MS_LOG(EXCEPTION) << "The shape of sens is not equal to loss output, it is unsupported now. Sens shape is " - << ShapeToString(sens_shape) << ", loss shape is " << ShapeToString(loss_shape); - } - MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", split it."; - - if (!IsValueNode(sens_tensor_node)) { - if (sens_tensor_node->isa()) { - MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - AbstractBasePtr abstract = sens_tensor_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - auto slice_shape = loss_grad_layout.slice_shape().array(); - std::shared_ptr parallel_shape = std::make_shared(slice_shape); - MS_EXCEPTION_IF_NULL(parallel_shape); - auto cloned_abstract = abstract->Clone(); - MS_EXCEPTION_IF_NULL(cloned_abstract); - cloned_abstract->set_shape(parallel_shape); - sens_tensor_node->set_abstract(cloned_abstract); - auto sens_tensor_param = sens_tensor_node->cast(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); - return; - } - MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; - } - - // Use _GetTensorSlice operator to split the sens tensor - FuncGraphPtr func_graph = grad_sens_node->func_graph(); // only cnode can get the graph - MS_EXCEPTION_IF_NULL(func_graph); - Operator op = CreateGetTensorSliceOp(loss_grad_layout); - InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS); -} - -void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - OperatorVector forward_op = distribute_operator->forward_op(); - if (!forward_op.empty()) { - MS_LOG(INFO) << "Insert forward op for " << distribute_operator->name(); - ForwardCommunication(forward_op, cnode); - } -} - -void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - // StepReplaceOp - OperatorVector replace_op = distribute_operator->replace_op(); - if (!replace_op.empty()) { - MS_LOG(INFO) << "StepReplaceOp " << cnode->ToString(); - StepReplaceOp(replace_op, cnode); - } - - // StepReplaceGraph: after calling StepReplaceGraph, cnode can not be used anymore. 
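The two replacement paths named in the comment above are mutually exclusive: an operator may rewrite itself either through a list of replacement ops or through a whole replacement subgraph, never both. A toy sketch of that contract (stand-in types, not the original ones):

#include <memory>
#include <stdexcept>
#include <vector>

struct Op {};     // stand-in for one replacement operator
struct Graph {};  // stand-in for a replacement subgraph

// Mirrors the StepReplace dispatch: at most one replacement path may be set.
void ApplyReplacement(const std::vector<Op> &replace_op, const std::shared_ptr<Graph> &replace_graph) {
  if (!replace_op.empty() && replace_graph != nullptr) {
    throw std::logic_error("only one of replace_op or replace_graph can be used");
  }
  // ... apply whichever replacement is present, if any ...
}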
- ReplaceGraphPtr replace_graph = distribute_operator->replace_graph(cnode); - if (!replace_op.empty() && replace_graph) { - MS_LOG(EXCEPTION) << "Only one of replace_op or replace_graph can be used"; - } - if (replace_graph) { - MS_LOG(INFO) << "StepReplaceGraph " << cnode->ToString(); - StepReplaceGraph(replace_graph, cnode); - } -} - -void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(distribute_operator); - MS_EXCEPTION_IF_NULL(cnode); - - std::string op_name = distribute_operator->name(); - if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) { - return; - } - - DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast(distribute_operator); - MS_EXCEPTION_IF_NULL(dropout_do_mask); - std::vector replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode); - if (replace_op.empty()) { - MS_LOG(DEBUG) << "No need to replace dropout_gen_mask"; - return; - } - if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { - MS_LOG(EXCEPTION) << "The input size of the DropoutDoMask cnode is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; - } - ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); -} - -void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { - HandleDropoutNode(distribute_operator, cnode); -} - -std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { - // J->CNode->Graph - std::set graph_set; - for (auto &node : root_all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - - auto cnode = node->cast(); - if ((cnode->size() < 2) || !IsValueNode(cnode->input(0))) { - continue; - } - auto expect_j_prim = GetValueNode(cnode->input(0)); - if (expect_j_prim->name() != J) { - continue; - } - if (IsValueNode(cnode->input(1))) { - auto graph = GetValueNode(cnode->input(1)); - MS_LOG(DEBUG) << "Found the forward graph successfully"; - graph_set.insert(graph); - } - } - return graph_set; -} - -void StepSplitSens(const std::pair &sens_loss_pair) { - CNodePtr sens_node = sens_loss_pair.first; - CNodePtr loss_node = sens_loss_pair.second; - auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node); - if (!loss_grad_layout.empty()) { - SplitSens(sens_node, loss_grad_layout[0]); - } -} - -// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -std::vector> GetSensLossPairs(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - std::vector> sens_loss_pairs; - for (auto &node : root->nodes()) { - if (!node->isa()) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem) - auto sens_cnode = node->cast(); - AnfNodePtr expect_tuple_getitem = sens_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem); - if (!expect_tuple_getitem->isa()) { - continue; - } - - auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); - if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem)-->cnode - AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); - MS_EXCEPTION_IF_NULL(expect_anonymous); - if (!expect_anonymous->isa()) { - continue; - } - - // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) - auto expect_anonymous_cnode = expect_anonymous->cast(); - AnfNodePtr expect_j = expect_anonymous_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_j); - if (!expect_j->isa()) { - continue; - } - auto expect_j_cnode = expect_j->cast(); - if (!IsSomePrimitive(expect_j_cnode, J)) { - continue; - } - - 
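// The chain matched so far, reading upward from the sens node:
//   cnode(sens) --input(0)--> tuple_getitem --input(1)--> anonymous cnode --input(0)--> J
// What remains below is to pull the forward graph out of J's argument and
// locate the loss cnode inside it.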
if (!IsValueNode(expect_j_cnode->input(1))) { - MS_LOG(EXCEPTION) << "Failed to find the graph corresponding to the sens node."; - } - auto func_graph = GetValueNode(expect_j_cnode->input(1)); - auto loss_cnode = FindLossCNode(func_graph); - if (loss_cnode == nullptr) { - MS_LOG(WARNING) << "Cannot find the loss cnode"; - continue; - } - std::pair sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); - sens_loss_pairs.push_back(sens_loss_pair); - } - return sens_loss_pairs; -} - -void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, - const FuncGraphManagerPtr &manager) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(manager); - TensorRedistribution tensor_redistribution; - - std::vector> sens_loss_pairs = GetSensLossPairs(root); - bool has_backward = !sens_loss_pairs.empty(); - // sens must be split before inserting the operators. - for (auto &pair : sens_loss_pairs) { - // If the shape of the grad-sens tensor is not [] or [1], use get tensor slice to handle it. - // If the type of the sens node is not Tensor, it is unsupported for now; do nothing by default. - StepSplitSens(pair); - } - - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - if (distribute_operator == nullptr) { - continue; - } - - // insert forward ops - InsertForwardOps(distribute_operator, cnode); - - // insert redistribution ops - StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); - - // insert backward ops - if (has_backward) { - BackwardCommunication(distribute_operator, cnode, sens_loss_pairs); - } - - HandleSpecialNode(distribute_operator, cnode); - } else if (IsValueNode(node)) { - StepSplitTensor(node, manager); - } - } - - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); - if (distribute_operator == nullptr) { - continue; - } - // StepReplace - StepReplace(distribute_operator, cnode); - } - } -} - -namespace { -void RevertSymbolicKeyInstance(const FuncGraphPtr &root, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(node); - auto symbolic_key = GetValueNode(node); - MS_EXCEPTION_IF_NULL(symbolic_key); - auto all_upstream_node = root->manager()->node_users()[node]; - for (auto &upstream_node : all_upstream_node) { - FuncGraphPtr fg = upstream_node.first->func_graph(); - if (symbolic_key->node()->isa()) { - for (auto &param : root->parameters()) { - if (*param == *symbolic_key->node()) { - AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); - MS_EXCEPTION_IF_NULL(reverted_node); - MS_LOG(DEBUG) << "before replacing " << node->ToString() << " with node " << reverted_node->DebugString(); - (void)fg->manager()->Replace(node, reverted_node); - MS_LOG(DEBUG) << "reverted node " << node->ToString() << " to node " << reverted_node->DebugString(); - } - } - } - } -} -} // namespace - -void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { - MS_EXCEPTION_IF_NULL(root); - for (auto &node : all_nodes) { - // revert back SymbolicKeyInstance to embed() primitive - if (IsValueNode(node)) { - RevertSymbolicKeyInstance(root, node); - continue; - } - } -} - -std::string NodeParameterName(const CNodePtr &node) { - std::vector 
node_inputs{node->inputs()}; - for (auto input : node_inputs) { - if (input->isa()) { - auto input_parameter = input->cast(); - if (input_parameter->has_default()) { - const auto ¶m_value = input_parameter->default_param(); - if (param_value->requires_grad()) { - return param_value->name(); - } - } - } - } - return ""; -} - -void CheckpointStrategy(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "Save strategy to checkpoint begin"; - StrategyMap stra_map; - auto ret = func_graph->get_return(); - auto all_nodes = DeepScopedGraphSearch(ret); - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { - continue; - } - std::string param_name = NodeParameterName(cnode); - if (param_name.empty()) { - continue; - } - PrimitivePtr prim = GetValueNode(cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info) { - if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { - continue; - } - StrategyPtr strategyPtr = operator_info->strategy(); - MS_EXCEPTION_IF_NULL(node->scope()); - stra_map[param_name] = strategyPtr; - } - } - if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { - MS_LOG(EXCEPTION) << "Save strategy checkpoint failed"; - } -} - -void SetForwardFlag(const std::vector &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - - // CNode is globally unique. - MS_LOG(DEBUG) << "Set forward flag " << cnode->DebugString() << "."; - cnode->set_in_forward_flag(true); - } -} - -void SetForwardFlag(const AnfNodeSet &all_nodes) { - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (!IsValueNode(cnode->input(0))) { - continue; - } - - // CNode is globally unique. 
- cnode->set_in_forward_flag(true); - } -} - -std::set ForwardGraph(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - const auto &all_nodes = root->nodes(); - std::set graph_set = FindForwardGraphByRootNodes(all_nodes); - return graph_set; -} - -std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { - MS_EXCEPTION_IF_NULL(graph); - std::vector root_forward_nodes; - auto loss_cnode = FindLossCNode(graph); - if (loss_cnode == nullptr) { - MS_LOG(WARNING) << "Cannot find the loss cnode"; - return root_forward_nodes; - } - - auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); - for (auto &node : all_nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - auto root_node_id = node->UniqueIdThroughCopy(); - if (loss_cnode_id == root_node_id) { - root_forward_nodes = DeepLinkedGraphSearch(cnode); - break; - } - } - return root_forward_nodes; -} - -void MarkForwardCNode(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - auto all_nodes = root->nodes(); - std::set graph_set = FindForwardGraphByRootNodes(all_nodes); - - if (graph_set.empty()) { - MS_LOG(INFO) << "Cannot find the forward graph, so mark the ops in the root graph"; - SetForwardFlag(all_nodes); - } else { - for (auto &func_graph : graph_set) { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); - auto return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); - SetForwardFlag(all_dfs_nodes); - auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); - if (root_forward_nodes.empty()) { - continue; - } - // Mark the forward flag for the nodes in the root graph. - SetForwardFlag(root_forward_nodes); - } - } -} - -Status ParallelInit() { - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - int32_t device_num = ParallelContext::GetInstance()->device_num(); - int32_t global_rank = ParallelContext::GetInstance()->global_rank(); - std::string backend = ParallelContext::GetInstance()->communication_backend(); - std::string world_group; - - if (backend == HCCL_BACKEND) { - world_group = HCCL_WORLD_GROUP; - } else if (backend == NCCL_BACKEND) { - world_group = NCCL_WORLD_GROUP; - } else { - MS_LOG(EXCEPTION) << "Invalid communication backend: " << backend; - } - - uint32_t world_rank_size = 0; - if (!ParallelContext::GetInstance()->device_num_is_set()) { - if (!CommManager::GetInstance().GetRankSize(world_group, &world_rank_size)) { - MS_LOG(EXCEPTION) << "Get rank size failed"; - } - device_num = UintToInt(world_rank_size); - MS_LOG(INFO) << "Get device num from communication model, the device num is " << device_num; - } - - uint32_t rank_id = 0; - if (!ParallelContext::GetInstance()->global_rank_is_set()) { - if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { - MS_LOG(EXCEPTION) << "Get rank id failed"; - } - global_rank = UintToInt(rank_id); - MS_LOG(INFO) << "Get global rank from communication model, the global rank is " << global_rank; - } - - if (!InitDevice(device_num, global_rank, backend)) { - MS_LOG(ERROR) << "Init device failed"; - return FAILED; - } - - MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank - << ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean() - << ", cast_before_mirror: " << ParallelContext::GetInstance()->cast_before_mirror(); - return SUCCESS; -} - -bool StepParallel(const 
FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { - MS_EXCEPTION_IF_NULL(root); - MS_EXCEPTION_IF_NULL(optimizer); - MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); - std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); - // assume no change to graph - bool changes = false; - // controls whether model_parallel mode is used - if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || - (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) { - if (!root->has_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY)) { - if (HasStrategy(root)) { - MS_LOG(INFO) << "Strategies ignored in " << parallel_mode - << "; set_strategy() is only valid in [semi_]auto_parallel."; - } - root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); - } - - return changes; - } - - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - - MS_LOG(INFO) << "Now entering step parallel"; - DumpGraph(root, std::string(STEP_PARALLEL_BEGIN)); - - pipeline::ResourceBasePtr res = optimizer->resource(); - MS_EXCEPTION_IF_NULL(res); - - FuncGraphManagerPtr manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - AnfNodePtr ret = root->get_return(); - MS_EXCEPTION_IF_NULL(ret); - std::vector all_nodes = DeepScopedGraphSearch(ret); - std::reverse(all_nodes.begin(), all_nodes.end()); - if (parallel_mode != AUTO_PARALLEL) { - TOTAL_OPS = 0; - if (ParallelInit() != SUCCESS) { - MS_LOG(EXCEPTION) << "Parallel init failed"; - } - - // mark the forward cnodes; parallel only cares about these nodes - MarkForwardCNode(root); - - if (FindCommunicationOp(all_nodes)) { - MS_LOG(EXCEPTION) << "The graph contains communication ops"; - } - - // extract shape and strategy, set operator_info - ExtractInformation(all_nodes); - ReshapeInit(all_nodes); - } - // save strategy as checkpoint for multi-train - if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) { - CheckpointStrategy(root); - } - - HandleSymbolicKeyInstance(root, all_nodes); - - // cover parallel shape - CoverSliceShape(root); - - // set the shape for optimizer's clone tensor - SetClonedTensorShapeForOptimizer(root); - - // ForwardCommunication BackwardCommunication TensorRedistribution - ParallelCommunication(root, all_nodes, manager); - - DumpGraph(root, std::string(STEP_PARALLEL_END)); - - // step parallel only runs once - root->set_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY, true); - res->results()[pipeline::kStepParallelGraph] = root; - - // in auto parallel mode, no need to check if strategies are set - root->set_flag(CHECK_SET_STRATEGY_VALID_ONCE_ONLY, true); - - (void)gettimeofday(&end_time, nullptr); - uint64_t time = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - time += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Now leaving step parallel, used time: " << time << " us"; - return changes; -} - -// Needed by rec_parser -std::vector ExtractInputsTensorName(const CNodePtr &node) { - std::vector name_inputs; - std::vector all_inputs = node->inputs(); - std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - - std::string node_id = node->UniqueId(); - name_inputs.push_back(node_id); - for (auto &input : node_inputs) { - std::string name = input->UniqueId(); - name_inputs.push_back(name); - } - - return name_inputs; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h deleted file mode 100644 index 308473dcd7..0000000000 --- 
a/mindspore/ccsrc/parallel/step_parallel.h +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ -#define MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ - -#include - -#include -#include -#include -#include -#include -#include - -#include "./common.h" -#include "optimizer/opt.h" -#include "parallel/strategy.h" -#include "parallel/tensor_layout/tensor_redistribution.h" - -using OperatorInfoPtr = std::shared_ptr; - -namespace mindspore { -namespace parallel { -const uint64_t kUSecondInSecond = 1000000; - -struct LossNodeInfo { - bool has_tuple_getitem = false; - int dout_index = 0; // now don't support the sens is a tuple -}; - -std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); -std::string CreateInstanceName(const CNodePtr &node, size_t index); -void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); - -void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, - const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); - -TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, - const OperatorInfoPtr &distribute_operator_pre); - -OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); - -void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, - const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr &pre_node); - -bool StrategyFound(std::unordered_map attrs); - -bool IsParallelCareNode(const CNodePtr &cnode); - -void MarkForwardCNode(const FuncGraphPtr &root); - -bool FindCommunicationOp(const std::vector &all_nodes); - -void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, - const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); - -std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, - const CNodePtr &node); - -void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); - -void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); - -std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); - -std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); - -void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); - -void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, - const std::vector> &sens_loss_pairs); - -// Generate and init parallel operator -OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - const std::vector &shape_list); - -// Generate without initing parallel operator -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, - std::vector shape_list); - 
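The two factory declarations above differ only in whether initialization runs before the object is handed back. A toy sketch of that split (illustrative type, not the original OperatorInfo):

#include <memory>

// Toy operator-info type; the real OperatorInfo carries shapes and strategies.
struct ToyOperatorInfo {
  bool initialized = false;
  void Init() { initialized = true; }
};

// "Generate without initing": construct only.
std::shared_ptr<ToyOperatorInfo> NewToyInstance() { return std::make_shared<ToyOperatorInfo>(); }

// "Generate and init": construct, then initialize before returning.
std::shared_ptr<ToyOperatorInfo> ToyInstance() {
  auto op = NewToyInstance();
  op->Init();
  return op;
}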
-// Extract strategy from attr -StrategyPtr ExtractStrategy(std::unordered_map attrs); - -Shapes GetNodeShape(const AnfNodePtr &node); - -std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); - -// Extract shape from anfnode -std::vector ExtractShape(const CNodePtr &node); - -std::pair FindParallelCareNode(const AnfNodePtr &node); - -// Find the final sub graph -std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &parameter); - -// Set the distributed shape for the parameter's abstract -void SetParallelShape(const AnfNodePtr &parameter, const std::pair &res); - -// change parameters' shape in resource -void CoverSliceShape(const FuncGraphPtr &root); - -void SetVirtualDatasetStrategy(const CNodePtr &node); - -// Create parallel operator for primitive node (has strategy) -void ExtractInformation(const std::vector &all_nodes); - -TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair); - -std::shared_ptr FindNextLayout(const CNodePtr &node); - -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); - -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); - -std::shared_ptr FindPrevLayout(const AnfNodePtr &node); - -void ReshapeInit(const std::vector &all_nodes); - -// Add nodes for the whole graph -void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, - const FuncGraphManagerPtr &manager); - -std::string NodeParameterName(const CNodePtr &node); - -void CheckpointStrategy(const FuncGraphPtr &func_graph); - -// main step of Parallel -bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); - -int32_t GetTupleGetItemIndex(const CNodePtr &cnode); - -Status ParallelInit(); - -std::vector ExtractInputsTensorName(const CNodePtr &node); - -std::set ForwardGraph(const FuncGraphPtr &root); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STEP_PARALLEL_H_ diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h deleted file mode 100644 index bc62dd5308..0000000000 --- a/mindspore/ccsrc/parallel/strategy.h +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ - -#include -#include -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -#define MIN_SLICE_NUM 1 - -using Dimensions = std::vector; - -class Strategy; -using StrategyPtr = std::shared_ptr; - -class Strategy { - public: - Strategy(int32_t stage, std::vector inputs) : stage_(stage), inputs_(std::move(inputs)) {} - ~Strategy() = default; - size_t GetInputNumber() const { return inputs_.size(); } - std::vector GetInputDim() const { return inputs_; } - int32_t GetInputStage() const { return stage_; } - void ExpandInputDimFromOneToTwo() { - if (inputs_.size() == 1) { - inputs_.push_back(inputs_[0]); - } - } - void ResetInputs(const std::vector &input) { inputs_ = input; } - - bool IsEqual(const StrategyPtr &another_stra) { - if (another_stra == nullptr) { - return false; - } - if ((stage_ != another_stra->GetInputStage()) || (inputs_ != another_stra->GetInputDim())) { - return false; - } - return true; - } - - private: - const int32_t stage_; - - // The size of Dimensions must equal to inputs_ tensor dimension. - std::vector inputs_; -}; - -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { - return std::make_shared(stage, inputs); -} -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc deleted file mode 100644 index a83b5eb627..0000000000 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
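A usage sketch against the Strategy API shown above (assumes the surrounding header, where Dimensions is the per-input vector of split factors):

// Two one-input strategies at stage 0; each entry of a Dimensions value is the
// split factor for the corresponding tensor dimension.
Dimensions dim = {2, 4};
StrategyPtr s1 = NewStrategy(0, {dim});
StrategyPtr s2 = NewStrategy(0, {dim});
bool same = s1->IsEqual(s2);  // true: same stage and same input dimensions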
- */ - -#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" - -#include -#include -#include - -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" -#include "proto/node_strategy.pb.h" - -namespace mindspore { -namespace parallel { -StrategyCheckpoint &StrategyCheckpoint::GetInstance() { - static StrategyCheckpoint instance = StrategyCheckpoint(); - if (ParallelContext::GetInstance() != nullptr) { - instance.load_file_ = ParallelContext::GetInstance()->strategy_ckpt_load_file(); - instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty(); - instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file(); - instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty(); - } - return instance; -} - -bool StrategyCheckpoint::CheckPointExit(const std::string path) const { - std::ifstream fin(path); - if (fin) { - return true; - } - return false; -} - -Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { - if (strategy_map == nullptr) { - MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; - } - if (!CheckPointExit(load_file_)) { - MS_LOG(EXCEPTION) << "CheckPoint file is not found"; - } - straspb::ParallelStrategyMap parallel_strategy_map; - std::fstream input(load_file_, std::ios::in | std::ios::binary); - if (!parallel_strategy_map.ParseFromIstream(&input)) { - MS_LOG(ERROR) << "Load strategy file failed"; - return FAILED; - } - size_t node_num = IntToSize(parallel_strategy_map.parallel_strategy_item_size()); - for (size_t i = 0; i < node_num; i++) { - straspb::ParallelStrategyItem parallel_strategy_item = parallel_strategy_map.parallel_strategy_item(SizeToInt(i)); - std::string node_name = parallel_strategy_item.node_name(); - straspb::ParallelStrategys parallel_strategys = parallel_strategy_item.parallel_strategys(); - auto stage = (int32_t)parallel_strategys.stage(); - size_t strategys_num = IntToSize(parallel_strategys.parallel_strategy_size()); - std::vector> strategy_inputs; - for (size_t j = 0; j < strategys_num; j++) { - straspb::ParallelStrategy parallel_strategy = parallel_strategys.parallel_strategy(SizeToInt(j)); - std::vector dimension; - size_t dim_num = IntToSize(parallel_strategy.dim_size()); - for (size_t k = 0; k < dim_num; k++) { - dimension.push_back(parallel_strategy.dim(SizeToInt(k))); - } - strategy_inputs.push_back(dimension); - } - - StrategyPtr strategy = NewStrategy(stage, strategy_inputs); - (*strategy_map)[node_name] = strategy; - current_stage_ = (int32_t)parallel_strategy_map.current_stage(); - } - return SUCCESS; -} - -Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { - straspb::ParallelStrategyMap parallel_strategy_map; - parallel_strategy_map.set_current_stage(IntToUint(++current_stage_)); - for (auto &node_stra : strategy_map) { - straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); - MS_EXCEPTION_IF_NULL(parallel_strategy_item); - parallel_strategy_item->set_node_name(node_stra.first); - straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); - MS_EXCEPTION_IF_NULL(parallel_strategys); - MS_EXCEPTION_IF_NULL(node_stra.second); - parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); - for (auto &dims : node_stra.second->GetInputDim()) { - straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); - 
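// Message nesting being populated at this point (it mirrors the Load() path above):
//   ParallelStrategyMap { current_stage, repeated ParallelStrategyItem {
//     node_name, ParallelStrategys { stage, repeated ParallelStrategy { repeated dim } } } }
// The inner loop below writes one dim entry per split factor of each input.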
MS_EXCEPTION_IF_NULL(parallel_strategy); - for (auto dim : dims) { - parallel_strategy->add_dim(IntToUint(dim)); - } - } - } - std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); - if (!parallel_strategy_map.SerializeToOstream(&output)) { - MS_LOG(ERROR) << "Save strategy file failed"; - return FAILED; - } - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h deleted file mode 100644 index a758a9e7bb..0000000000 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ -#define MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ - -#include -#include -#include "parallel/ops_info/ops_utils.h" -#include "parallel/strategy.h" -#include "parallel/context.h" - -namespace mindspore { -namespace parallel { -using StrategyMap = std::unordered_map; -class StrategyCheckpoint { - public: - StrategyCheckpoint() { - current_stage_ = 0; - load_file_ = ""; - load_checkpoint_on_ = false; - save_file_ = ""; - save_checkpoint_on_ = false; - } - ~StrategyCheckpoint() = default; - - Status Load(StrategyMap *strategy_map); - Status Save(const StrategyMap &strategy_map); - - static StrategyCheckpoint &GetInstance(); - bool LoadCheckPointOn() const { return load_checkpoint_on_; } - bool SaveCheckPointOn() const { return save_checkpoint_on_; } - - private: - std::string load_file_; - std::string save_file_; - bool load_checkpoint_on_; - bool save_checkpoint_on_; - bool CheckPointExit(const std::string path) const; - int32_t current_stage_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_STRATEGY_CHEKCPOINT_PARALLEL_STRATEGY_CHECKPOINT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc deleted file mode 100644 index 235ab00302..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc +++ /dev/null @@ -1,248 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/tensor_layout/arrangement.h" -#include -#include -#include -#include "common/utils.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -Status Arrangement::Init(const std::vector &array) { - Status status = Array::Init(array); - if (status != Status::SUCCESS) { - return Status::FAILED; - } - if (!IsValidArrangement()) { - MS_LOG(ERROR) << "invalid arrangement " << this->ToString(); - return Status::FAILED; - } - ComputeSize(); - return Status::SUCCESS; -} - -bool Arrangement::IsValidArrangement() { - return !std::any_of(array_.begin(), array_.end(), [](int32_t value) { return value <= 0; }); -} - -void Arrangement::ComputeSize() { - size_ = 1; - for (auto &value : array_) { - size_ *= value; - } -} - -/* - * if GetDimSize() = 0, return [] - * if value <= array_[0], return [value] - * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]], - * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1], - * if value > size_, return [] - */ -std::vector Arrangement::GetFrontElementByValue(int32_t value) const { - std::vector out; - if (GetDimSize() == 0) { - return out; - } - if (value <= size_) { - int32_t size = 1; - uint32_t shape_list_idx = 0; - while (size < value) { - size *= array_[shape_list_idx]; - if (size <= value) { - out.push_back(array_[shape_list_idx]); - } else { - if (size == 0) { - MS_LOG(ERROR) << "The size is 0"; - out.clear(); - return out; - } - out.push_back(value * array_[shape_list_idx] / size); - } - shape_list_idx++; - } - } - return out; -} - -std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( - const std::vector &expand_list) const { - if (expand_list.size() != GetDimSize()) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { - std::vector expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i)); - if (expand_shape.empty()) { - new_shape.push_back(GetDimByIdx(i)); - } else { - (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end()); - } - } - Arrangement arrangement_new; - (void)arrangement_new.Init(new_shape); - return std::make_shared(arrangement_new); -} - -/* - * example: - * expand_shape = [4, 2, 2, 2] - * array_ = [8, 4], - * arrangement_list = [[4, 2], [2, 2]] - */ -std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { - int32_t size = 1; - uint32_t ind = 0; - std::vector arrangement_list; - std::vector shape; - for (uint32_t i = 0; i < expand_shape.GetDimSize(); i++) { - size *= expand_shape.GetDimByIdx(i); - if (size > GetDimByIdx(ind)) { - MS_LOG(ERROR) << "invalid expand_shape"; - return nullptr; - } else if (size < GetDimByIdx(ind)) { - shape.push_back(expand_shape.GetDimByIdx(i)); - continue; - } else { - shape.push_back(expand_shape.GetDimByIdx(i)); - Arrangement arrangement; - (void)arrangement.Init(shape); - arrangement_list.push_back(arrangement); - shape.clear(); - ind++; - size = 1; - } - } - if (ind != GetDimSize()) { - MS_LOG(ERROR) << "invalid expand_shape"; - return nullptr; - } - auto arrangement_new = std::make_shared>(arrangement_list); - return arrangement_new; -} - -std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( - const Arrangement &expand_shape) const { - std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); - if (expand_shape_list_ptr == nullptr) 
{ - return nullptr; - } - std::vector expand_num_list_shape; - (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), - std::back_inserter(expand_num_list_shape), - [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); - Arrangement expand_num_list; - Status status = expand_num_list.Init(expand_num_list_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list); - return std::make_shared, Arrangement>>(out_value); -} - -std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() const { - std::vector shape_accum; - int32_t size = 0; - for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) { - shape_accum.push_back(size); - size += *iter; - } - return shape_accum; -} - -std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( - const std::vector &expand_list) const { - if (expand_list.size() != GetDimSize()) { - return nullptr; - } - std::vector new_shape; - for (uint32_t i = 0; i < expand_list.size(); i++) { - if (expand_list[i].GetDimSize() >= 1) { - int32_t size = 1; - for (uint32_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) { - new_shape.push_back(expand_list[i].GetDimByIdx(k)); - size *= expand_list[i].GetDimByIdx(k); - } - new_shape.push_back(GetDimByIdx(i) / size); - } else { - new_shape.push_back(GetDimByIdx(i)); - } - } - Arrangement arrangement_new; - (void)arrangement_new.Init(new_shape); - return std::make_shared(arrangement_new); -} - -std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { - std::vector in1_accum; - Status status = ShapeToAccumulateProduct(array_, &in1_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector in2_accum; - status = ShapeToAccumulateProduct(in2.array(), &in2_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector out_accum; - status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum); - if (status != Status::SUCCESS) { - return nullptr; - } - std::vector out_shape; - status = AccumulateProductToShape(out_accum, &out_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - Arrangement out; - status = out.Init(out_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::vector Arrangement::GetSqueezeIdx() const { - std::vector out; - for (size_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(SizeToUint(i)) == 1) { - out.push_back(i); - } - } - return out; -} - -Arrangement Arrangement::GetSqueezeArrangement() const { - std::vector out_shape(array_.size()); - auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int32_t value) { return value != 1; }); - out_shape.resize(LongToSize(std::distance(out_shape.begin(), it))); - - // if all elements are 1, out_shape = {1} - if (out_shape.empty()) { - MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation"; - out_shape.push_back(1); - } - Arrangement out; - (void)out.Init(out_shape); - return out; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h deleted file mode 100644 index ca71b05c91..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file 
except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ - -#include -#include -#include -#include -#include -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/array.h" - -namespace mindspore { -namespace parallel { -class Arrangement : public Array { - public: - Arrangement() : size_(1) {} - ~Arrangement() override = default; - Status Init(const std::vector &array) override; - int32_t size() const { return size_; } - std::vector GetFrontElementByValue(int32_t value) const; - std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; - std::vector ComputeReverseAccumulateSumInReverseOrder() const; - std::shared_ptr GetExpandedShapeByExpandListReserveLeft( - const std::vector &expand_list) const; - std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( - const std::vector &expand_list) const; - std::shared_ptr, Arrangement>> GetExpandShapeListPair( - const Arrangement &expand_shape) const; - std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; - std::vector GetSqueezeIdx() const; - Arrangement GetSqueezeArrangement() const; - - private: - bool IsValidArrangement(); - void ComputeSize(); - int32_t size_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRANGEMENT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.cc b/mindspore/ccsrc/parallel/tensor_layout/array.cc deleted file mode 100644 index ef358e7cde..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/array.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/array.h" -#include -#include "parallel/status.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::string Array::ToString() const { - std::ostringstream buffer; - buffer << "[ "; - for (auto &element : array_) { - buffer << std::to_string(element) + " "; - } - buffer << "]"; - return buffer.str(); -} - -Status Array::Init(const std::vector &array) { - array_ = array; - return IsvalidArray() ? 
Status::SUCCESS : Status::FAILED; -} - -bool Array::IsvalidArray() const { return true; } - -int32_t Array::GetDimByIdx(uint32_t idx) const { - size_t mod_idx = idx; - if (idx >= GetDimSize()) { - MS_LOG(EXCEPTION) << "idx is " << idx << ", but array size is " << GetDimSize(); - } - return array_[mod_idx]; -} - -int32_t Array::GetDimByReverseIdx(uint32_t idx) const { - size_t mod_idx = idx; - if (idx >= GetDimSize()) { - MS_LOG(EXCEPTION) << "idx is " << idx << " but array size is " << GetDimSize(); - } - return array_[GetDimSize() - 1 - mod_idx]; -} - -bool Array::operator==(const Array &shape) const { - if (GetDimSize() != shape.GetDimSize()) { - return false; - } - for (uint32_t i = 0; i < GetDimSize(); i++) { - if (GetDimByIdx(i) != shape.GetDimByIdx(i)) { - return false; - } - } - return true; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.h b/mindspore/ccsrc/parallel/tensor_layout/array.h deleted file mode 100644 index 5aa3bdb138..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/array.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ - -#include -#include -#include -#include -#include -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -class Array { - public: - Array() = default; - virtual ~Array() = default; - std::string ToString() const; - virtual Status Init(const std::vector &array); - bool IsvalidArray() const; - std::vector array() const { return array_; } - size_t GetDimSize() const { return array_.size(); } - int32_t GetDimByIdx(uint32_t idx) const; - int32_t GetDimByReverseIdx(uint32_t idx) const; - bool operator==(const Array &a1) const; - - protected: - std::vector array_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_ARRAY_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc deleted file mode 100644 index b5ca5ed60a..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc +++ /dev/null @@ -1,254 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
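The Array helper above indexes from either end; reverse index 0 addresses the last element. A standalone sketch of the same indexing rule (toy code, not the original class):

#include <cassert>
#include <cstdint>
#include <vector>

// Same rule as Array::GetDimByReverseIdx: reverse index 0 is the last element.
int32_t DimByReverseIdx(const std::vector<int32_t> &array, uint32_t idx) {
  assert(idx < array.size());  // the original raises via MS_LOG(EXCEPTION) instead
  return array[array.size() - 1 - idx];
}
// For {8, 4, 2}: DimByReverseIdx(v, 0) == 2 and DimByReverseIdx(v, 2) == 8.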
- */ - -#include "parallel/tensor_layout/construct_operator.h" - -#include -#include - -namespace mindspore { -namespace parallel { -Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { - dev_size_ = dev_matrix_shape.size(); - dev_matrix_shape_ = dev_matrix_shape; - dev_list_ = dev_list; - return Status::SUCCESS; -} - -Status ConstructOperator::ReshapeOP(Shape shape) { - int32_t prod = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - int32_t prod_expect = std::accumulate(tensor_shape_.begin(), tensor_shape_.end(), 1, std::multiplies()); - if (prod != prod_expect) { - ValuePtr ptr = MakeValue(shape); - MS_EXCEPTION_IF_NULL(ptr); - MS_LOG(ERROR) << "Invalid tensor shape " << ptr->ToString() << " when constructing the Reshape operator!"; - return Status::INVALID_ARGUMENT; - } - OperatorAttrs attrs; - ValuePtr param_value = MakeValue(shape); - Attr param = std::make_pair(SHAPE, param_value); - OperatorParams params = {std::make_pair(param, 2)}; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(RESHAPE, args); - return Status::SUCCESS; -} - -Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) { - ValuePtr attr_value = MakeValue(value); - Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); - Attr attr_end_mask = std::make_pair(END_MASK, attr_value); - Attr attr_ellipsis_mask = std::make_pair(ELLIPSIS_MASK, attr_value); - Attr attr_new_axis_mask = std::make_pair(NEW_AXIS_MASK, attr_value); - Attr attr_shrink_axis_mask = std::make_pair(SHRINK_AXIS_MASK, attr_value); - OperatorAttrs attrs = {attr_begin_mask, attr_end_mask, attr_ellipsis_mask, attr_new_axis_mask, attr_shrink_axis_mask}; - - ValuePtr param_begin_value = MakeValue(begin); - Param param_begin = std::make_pair(std::make_pair(BEGIN, param_begin_value), 2); - ValuePtr param_end_value = MakeValue(end); - Param param_end = std::make_pair(std::make_pair(END, param_end_value), 3); - - ValuePtr param_strides_value = MakeValue(strides); - Param param_strides = std::make_pair(std::make_pair(STRIDES, param_strides_value), 4); - OperatorParams params = {param_begin, param_end, param_strides}; - OperatorArgs op_args = std::make_pair(attrs, params); - - return std::make_pair(STRIDED_SLICE, op_args); -} - -Status ConstructOperator::StridedSliceOP(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - int32_t split_count = args[0]; - if (split_count <= 0) { - MS_LOG(ERROR) << "split_count should be greater than 0!"; - return Status::FAILED; - } - int32_t split_dim = args[1]; - int32_t dev_dim = args[2]; - std::vector group_list; - - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "StridedSlice op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group has only one device, no need to do StridedSlice - MS_LOG(INFO) << "no need for the StridedSlice op"; - return SUCCESS; - } - - Group group = group_list[0]; - size_t rank; - if (group.GetIndex(&rank) == Status::FAILED) { - return Status::FAILED; - } - size_t size = tensor_shape_.size(); - Shape begin(size); - Shape end(size); - Shape strides(size, 1); - size_t index = 0; - for (auto num : tensor_shape_) { - if (index != IntToSize(split_dim)) { - begin[index] = 0; - end[index] = num; - } else { - if (num % split_count != 0) { - MS_LOG(ERROR) << "Tensor cannot be split into " << split_count << " slices in the dimension " << split_dim - << " when constructing the StridedSlice operator!"; - return Status::INVALID_ARGUMENT; - } - int32_t count = num / split_count; - begin[index] = SizeToInt(rank) * count; - end[index] = (SizeToInt(rank) + 1) * count; - } - index++; - } - - op_ = CreateStridedSliceOp(DEFAULT, begin, end, strides); - - return Status::SUCCESS; -} - -Status ConstructOperator::AllGatherOP(int32_t dev_dim) { - if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { - MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AllGather operator!"; - return Status::INVALID_ARGUMENT; - } - - std::vector group_list; - if (CreateGroupByDim(dev_size_ - IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "AllGather op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group has only one device, no need to do AllGather - MS_LOG(INFO) << "no need for the AllGather op"; - return SUCCESS; - } - - std::string group_name = group_list[0].name(); - ValuePtr attr_value = MakeValue(group_name); - Attr attr = std::make_pair(GROUP, attr_value); - OperatorAttrs attrs = {attr}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(ALL_GATHER, args); - return Status::SUCCESS; -} - -Status ConstructOperator::ConcatOP(int32_t concat_dim) { - if (IntToSize(concat_dim) >= tensor_shape_.size()) { - MS_LOG(ERROR) << "Invalid tensor dimension " << concat_dim << " when constructing the Concat operator!"; - return Status::INVALID_ARGUMENT; - } - ValuePtr attr_value = MakeValue(concat_dim); - Attr attr = std::make_pair(AXIS, attr_value); - OperatorAttrs attrs = {attr}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(CONCAT, args); - return Status::SUCCESS; -} - -Status ConstructOperator::SplitOP(int32_t split_count) { - if (split_count <= 0) { - MS_LOG(ERROR) << "Invalid split count when constructing the Split operator!"; - return Status::FAILED; - } - OperatorAttrs attrs; - ValuePtr attr_value_axis = MakeValue(DEFAULT); - Attr attr_axis = std::make_pair(AXIS, attr_value_axis); - ValuePtr attr_value_split = MakeValue(split_count); - Attr attr_split = std::make_pair(OUTPUT_NUM, attr_value_split); - attrs = {attr_axis, attr_split}; - OperatorParams params; - OperatorArgs args = std::make_pair(attrs, params); - op_ = std::make_pair(SPLIT, args); - return Status::SUCCESS; -} - -Status ConstructOperator::AlltoAllOP(Args args) { - if (args.size() < 4) { - MS_LOG(ERROR) << "args size should not be less than 4!"; - return Status::FAILED; - } - int32_t split_count = args[0]; - int32_t split_dim = args[1]; - int32_t concat_dim = args[2]; - int32_t dev_dim = args[3]; - if (split_count <= 0) { - MS_LOG(ERROR) << "Invalid split count when constructing the AlltoAll operator!"; - return Status::FAILED; - } - if (tensor_shape_[IntToSize(split_dim)] % split_count != 0) { - MS_LOG(ERROR) << "Tensor cannot be split into " << split_count << " slices in the dimension " << split_dim - << " when constructing the AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - if (IntToSize(concat_dim) >= tensor_shape_.size()) { - MS_LOG(ERROR) << "Invalid concat dim " << concat_dim << " when constructing the AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - if ((IntToSize(dev_dim) >= dev_size_) || (dev_dim < 0)) { - MS_LOG(ERROR) << "Invalid device dimension " << dev_dim << " when constructing the AlltoAll operator!"; - return Status::INVALID_ARGUMENT; - } - - std::vector group_list; - if (CreateGroupByDim(dev_size_ - 
IntToSize(dev_dim) - 1, &group_list) != SUCCESS) { - MS_LOG(ERROR) << "AlltoAll op: create group failed"; - return FAILED; - } else if (group_list.empty()) { // this group only has one device, don't need do alltoall - MS_LOG(INFO) << "no need all to all op"; - return SUCCESS; - } - - std::string group_name = group_list[0].name(); - ValuePtr attr_value_group = MakeValue(group_name); - Attr attr_group = std::make_pair(GROUP, attr_value_group); - ValuePtr attr_value_split_count = MakeValue(split_count); - Attr attr_split_count = std::make_pair(SPLIT_COUNT, attr_value_split_count); - ValuePtr attr_value_split_dim = MakeValue(split_dim); - Attr attr_split_dim = std::make_pair(SPLIT_DIM, attr_value_split_dim); - ValuePtr attr_value_concat_dim = MakeValue(concat_dim); - Attr attr_concat_dim = std::make_pair(CONCAT_DIM, attr_value_concat_dim); - OperatorAttrs attrs = {attr_split_count, attr_split_dim, attr_concat_dim, attr_group}; - OperatorParams params; - OperatorArgs op_args = std::make_pair(attrs, params); - op_ = std::make_pair(ALL_TO_ALL, op_args); - return Status::SUCCESS; -} - -Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { - MS_EXCEPTION_IF_NULL(group); - CheckGlobalDeviceManager(); - MS_EXCEPTION_IF_NULL(g_device_manager); - int32_t rank = g_device_manager->global_rank(); - DeviceMatrix dev_matrix(rank, dev_list_, dev_matrix_shape_); - RankList group_devices; - if (dev_matrix.GetDevicesAlongDim(SizeToUint(axis), &group_devices) != SUCCESS) { - return FAILED; - } - // this group only has one device, don't need create the group - if (group_devices.size() == 1) { - MS_LOG(INFO) << "the group is empty"; - return SUCCESS; - } - - Group g = g_device_manager->CreateGroup(group_devices); - group->push_back(g); - return SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h deleted file mode 100644 index 1a69638fb6..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
-#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
-
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "ir/value.h"
-#include "parallel/ops_info/operator_info.h"
-#include "parallel/status.h"
-
-namespace mindspore {
-namespace parallel {
-using Args = std::vector<int32_t>;
-
-class ConstructOperator {
- public:
-  const int32_t DEFAULT = 0;
-  ConstructOperator() : dev_size_(0) {}
-  ~ConstructOperator() = default;
-  Status Init(const RankList &dev_list, const Shape &dev_matrix_shape);
-  Status ReshapeOP(Shape shape);
-  Status StridedSliceOP(Args args);
-  Status AllGatherOP(int32_t dev_dim);
-  Status SplitOP(int32_t split_count);
-  Status ConcatOP(int32_t concat_dim);
-  Status AlltoAllOP(Args args);
-  Operator GetOperator() const { return op_; }
-  void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; }
-
- private:
-  Operator op_;
-  size_t dev_size_;
-  Shape tensor_shape_;
-  RankList dev_list_;
-  Shape dev_matrix_shape_;
-  Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
-};
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_CONSTRUCT_OPERATOR_H_
diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc
deleted file mode 100644
index 84c0580ba8..0000000000
--- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/tensor_layout/layout_transfer.h"
-#include "common/utils.h"
-#include "parallel/status.h"
-
-namespace mindspore {
-namespace parallel {
-std::string LayoutTransfer::ToString() const {
-  std::ostringstream buffer;
-  buffer << std::endl << std::string("from_in_ tensor layout:" + from_in_.ToString());
-  buffer << std::endl << std::string("to_in_ tensor layout:" + to_in_.ToString());
-  return buffer.str();
-}
-
-LayoutTransfer::~LayoutTransfer() = default;
-
-Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) {
-  from_in_ = from_in;
-  to_in_ = to_in;
-  MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString();
-  Status status = CheckValidTransfer();
-  return status;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h
deleted file mode 100644
index c4da4b728f..0000000000
--- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -class LayoutTransfer { - public: - LayoutTransfer() = default; - virtual ~LayoutTransfer() = 0; - std::string ToString() const; - Status Init(const TensorLayout &from_in, const TensorLayout &to_in); - TensorLayout from_in() const { return from_in_; } - TensorLayout to_in() const { return to_in_; } - - protected: - bool IsSameTensorShape() const { return from_in_.IsSameTensorShape(to_in_); } - bool IsSameDeviceArrangement() const { return from_in_.IsSameDeviceArrangement(to_in_); } - - TensorLayout from_in_; - TensorLayout to_in_; - - private: - virtual Status CheckValidTransfer() = 0; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.cc b/mindspore/ccsrc/parallel/tensor_layout/map.cc deleted file mode 100644 index 669920fc44..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/map.cc +++ /dev/null @@ -1,171 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
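
The deleted LayoutTransfer class declares a pure virtual destructor (`virtual ~LayoutTransfer() = 0;`) in the header yet still defines it out of line in the .cc file. That is the standard C++ idiom for keeping a base class abstract while giving derived destructors a definition to chain to. A minimal stand-alone illustration with hypothetical class names:

#include <iostream>

class Base {
 public:
  virtual ~Base() = 0;  // pure virtual: Base itself cannot be instantiated
};
Base::~Base() = default;  // a definition is still required; derived dtors call it

class Derived : public Base {
 public:
  ~Derived() override { std::cout << "Derived destroyed" << std::endl; }
};

int main() {
  Base *b = new Derived();
  delete b;  // runs ~Derived, then the (defined) ~Base
  return 0;
}
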
- */
-
-#include "parallel/tensor_layout/map.h"
-#include <algorithm>
-#include <memory>
-#include <vector>
-#include "common/utils.h"
-#include "parallel/status.h"
-#include "parallel/tensor_layout/shape_util.h"
-#include "utils/convert_utils.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace parallel {
-Status Map::Init(const std::vector<int32_t> &array) {
-  Status status = Array::Init(array);
-  if (status != Status::SUCCESS) {
-    return Status::FAILED;
-  }
-  if (!IsValidMap()) {
-    MS_LOG(ERROR) << "invalid map " << this->ToString();
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-bool Map::IsValidMap() {
-  if (std::any_of(array_.begin(), array_.end(), [](int32_t value) { return ((value < 0) && (value != MAP_NONE)); })) {
-    return false;
-  }
-  // check that all values other than MAP_NONE in array_ are distinct
-  std::vector<int32_t> sorted_array = array_;
-  std::sort(sorted_array.begin(), sorted_array.end());
-  int32_t value = MAP_NONE;
-  for (auto &element : sorted_array) {
-    if (element == MAP_NONE) {
-      continue;
-    }
-    if (element == value) {
-      return false;
-    }
-    value = element;
-  }
-  return true;
-}
-
-int32_t Map::GetMaxItem() const {
-  if (!array_.empty()) {
-    return *std::max_element(array_.begin(), array_.end());
-  } else {
-    return MAP_NONE;
-  }
-}
-
-int32_t Map::GetIndexByValue(int32_t value) const {
-  auto iter = find(array_.begin(), array_.end(), value);
-  if (iter != array_.end()) {
-    return static_cast<int32_t>(std::distance(array_.begin(), iter));
-  } else {
-    return MAP_NONE;
-  }
-}
-
-/*
- * expand.size() should be equal to array_.size()
- */
-std::shared_ptr<Map> Map::ExpandMapByNone(const Arrangement &expand_num_list) const {
-  if (expand_num_list.GetDimSize() != GetDimSize()) {
-    return nullptr;
-  }
-  std::vector<int32_t> new_shape;
-  for (uint32_t i = 0; i != GetDimSize(); i++) {
-    if (GetDimByIdx(i) == MAP_NONE) {
-      for (int32_t j = 0; j < expand_num_list.GetDimByIdx(i); j++) {
-        new_shape.push_back(MAP_NONE);
-      }
-    } else {
-      new_shape.push_back(GetDimByIdx(i));
-      int32_t j = 1;
-      while (j < expand_num_list.GetDimByIdx(i)) {
-        new_shape.push_back(MAP_NONE);
-        j++;
-      }
-    }
-  }
-  auto map_new = std::make_shared<Map>();
-  (void)map_new->Init(new_shape);
-  return map_new;
-}
-
-/*
- * expand.size() should be equal to array_.size()
- */
-std::shared_ptr<Map> Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const {
-  if (GetMaxItem() >= static_cast<int32_t>(expand_num_list.GetDimSize())) {
-    return nullptr;
-  }
-  std::vector<int32_t> new_shape;
-  for (uint32_t i = 0; i < GetDimSize(); i++) {
-    if (GetDimByIdx(i) == MAP_NONE) {
-      new_shape.push_back(MAP_NONE);
-    } else {
-      int32_t start_map =
-        expand_num_list.ComputeReverseAccumulateSumInReverseOrder()[static_cast<uint32_t>(GetDimByIdx(i))];
-      for (int32_t k = expand_num_list.GetDimByReverseIdx(static_cast<uint32_t>(GetDimByIdx(i))) - 1; k >= 0; k--) {
-        new_shape.push_back(k + start_map);
-      }
-    }
-  }
-  auto map_new = std::make_shared<Map>();
-  (void)map_new->Init(new_shape);
-  return map_new;
-}
-
-std::shared_ptr<std::vector<Arrangement>> Map::ReMapVector(const std::vector<Arrangement> &input_vector) const {
-  if (GetMaxItem() >= static_cast<int32_t>(input_vector.size())) {
-    return nullptr;
-  }
-  std::vector<Arrangement> out;
-  Arrangement empty_arrangement;
-  for (uint32_t i = 0; i < GetDimSize(); i++) {
-    if (GetDimByIdx(i) == MAP_NONE) {
-      out.push_back(empty_arrangement);
-    } else {
-      out.push_back(input_vector[IntToUint(SizeToInt(input_vector.size()) - 1 - GetDimByIdx(i))]);
-    }
-  }
-  return std::make_shared<std::vector<Arrangement>>(out);
-}
-
-bool Map::CheckNoneByIdxList(std::vector<size_t> idx_list) const {
-  for (auto &value : idx_list) {
-    if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) {
-      return false;
-    }
-  }
-  return true;
-}
-
-Map Map::SqueezeMapByIdxList(std::vector<size_t> idx_list) const {
-  std::vector<int32_t> out_shape;
-  for (size_t i = 0; i < GetDimSize(); i++) {
-    auto it = std::find(idx_list.begin(), idx_list.end(), i);
-    if (it == idx_list.end()) {
-      out_shape.push_back(GetDimByIdx(SizeToUint(i)));
-    }
-  }
-  if (out_shape.empty()) {
-    MS_LOG(ERROR) << "out_shape is empty, which should not happen here";
-    out_shape.push_back(MAP_NONE);
-  }
-  Map out;
-  (void)out.Init(out_shape);
-  return out;
-}
-}  // namespace parallel
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.h b/mindspore/ccsrc/parallel/tensor_layout/map.h
deleted file mode 100644
index 8c8bba2775..0000000000
--- a/mindspore/ccsrc/parallel/tensor_layout/map.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
-#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
-
-#include <cstdint>
-#include <memory>
-#include <string>
-#include <utility>
-#include <vector>
-#include "parallel/status.h"
-#include "parallel/tensor_layout/arrangement.h"
-#include "parallel/tensor_layout/array.h"
-
-namespace mindspore {
-namespace parallel {
-constexpr int32_t MAP_NONE = -1;
-
-class Map : public Array {
- public:
-  Map() = default;
-  ~Map() override = default;
-  Status Init(const std::vector<int32_t> &array) override;
-  int32_t GetMaxItem() const;
-  int32_t GetIndexByValue(int32_t value) const;
-  std::shared_ptr<Map> ExpandMapByNone(const Arrangement &expand_num_list) const;
-  std::shared_ptr<Map> ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const;
-  std::shared_ptr<std::vector<Arrangement>> ReMapVector(const std::vector<Arrangement> &input_vector) const;
-  bool CheckNoneByIdxList(std::vector<size_t> idx_list) const;
-  Map SqueezeMapByIdxList(std::vector<size_t> idx_list) const;
-
- private:
-  bool IsValidMap();
-};
-}  // namespace parallel
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_MAP_H_
diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc
deleted file mode 100644
index 7ed07ac02e..0000000000
--- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.cc
+++ /dev/null
@@ -1,69 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
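
To make the deleted Map semantics concrete: ExpandMapByNone keeps each tensor-map entry and pads it with MAP_NONE so the map stays aligned with an expanded (split) tensor shape. A minimal re-implementation over plain vectors, for illustration only (kMapNone is a local stand-in for the MAP_NONE constant above):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int32_t kMapNone = -1;  // stand-in for MAP_NONE

std::vector<int32_t> ExpandMapByNone(const std::vector<int32_t> &map,
                                     const std::vector<int32_t> &expand_num_list) {
  std::vector<int32_t> out;
  for (std::size_t i = 0; i < map.size(); ++i) {
    out.push_back(map[i]);  // MAP_NONE entries stay MAP_NONE
    for (int32_t j = 1; j < expand_num_list[i]; ++j) {
      out.push_back(kMapNone);  // padding keeps the map aligned with the split shape
    }
  }
  return out;
}

int main() {
  // tensor map [1, 0] with every dimension split in two -> [1, -1, 0, -1]
  for (int32_t v : ExpandMapByNone({1, 0}, {2, 2})) {
    std::cout << v << ' ';
  }
  std::cout << std::endl;
  return 0;
}
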
- */ - -#include "parallel/tensor_layout/redistribution_layout_transfer.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/reshape_layout_transfer.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status RedistributionLayoutTransfer::CheckValidTransfer() { return Status::SUCCESS; } - -/* - * unify device arrangement between in_layout and out_layout - * after this function is called, - * in_step1_layout.device_arrangement and out_step1_layout.device_arrangement will be the same - */ -std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangement() const { - Arrangement in_arrangement; - Arrangement out_arrangement; - in_arrangement = from_in_.device_arrangement(); - out_arrangement = to_in_.device_arrangement(); - std::shared_ptr unify_arrangement_ptr = in_arrangement.GetUnifiedShape(out_arrangement); - if (unify_arrangement_ptr == nullptr) { - return nullptr; - } - std::shared_ptr from_out_ptr = from_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); - if (from_out_ptr == nullptr) { - return nullptr; - } - std::shared_ptr to_out_ptr = to_in_.ExpandDeviceArrangement(*unify_arrangement_ptr); - if (to_out_ptr == nullptr) { - return nullptr; - } - ReshapeLayoutTransfer out; - Status status = out.Init(*from_out_ptr, *to_out_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -/* - * unify tensor shape between in_step1_layout.tensor_shape and out_step1_layout.tensor_shape - * after this function is called, - * in_step2_layout.tensor_shape and out_step2_layout.tensor_shape will be the same - */ -std::shared_ptr RedistributionLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { - std::shared_ptr unified_device_arrangement_ptr = UnifyDeviceArrangement(); - if (unified_device_arrangement_ptr == nullptr) { - return nullptr; - } - return unified_device_arrangement_ptr->UnifyDeviceArrangementAndTensorShape(); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h deleted file mode 100644 index 7b57f46dd6..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_layout_transfer.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
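
The UnifyDeviceArrangement step above boils down to merging the accumulating products of the two device arrangements, which is what the deleted shape_util helpers compute. A simplified sketch, assuming both arrangements cover the same device count and divisibility holds (no error handling):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Ascending accumulating product: [8, 4] -> [8, 32].
std::vector<int64_t> Accum(const std::vector<int64_t> &s) {
  std::vector<int64_t> out;
  int64_t p = 1;
  for (int64_t d : s) {
    out.push_back(p *= d);
  }
  return out;
}

// Unified arrangement: the sorted union of both accumulating products,
// converted back into a shape.
std::vector<int64_t> UnifiedShape(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  std::vector<int64_t> u = Accum(a);
  std::vector<int64_t> bu = Accum(b);
  u.insert(u.end(), bu.begin(), bu.end());
  std::sort(u.begin(), u.end());
  u.erase(std::unique(u.begin(), u.end()), u.end());
  std::vector<int64_t> shape;
  int64_t prev = 1;
  for (int64_t v : u) {
    shape.push_back(v / prev);
    prev = v;
  }
  return shape;
}

int main() {
  // [8, 4] and [2, 16] both cover 32 devices; the unified arrangement is [2, 4, 4]
  for (int64_t d : UnifiedShape({8, 4}, {2, 16})) {
    std::cout << d << ' ';
  }
  std::cout << std::endl;
  return 0;
}
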
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/layout_transfer.h" -#include "parallel/tensor_layout/reshape_layout_transfer.h" - -namespace mindspore { -namespace parallel { -class RedistributionLayoutTransfer : public LayoutTransfer { - public: - RedistributionLayoutTransfer() = default; - ~RedistributionLayoutTransfer() override = default; - std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; - - private: - Status CheckValidTransfer() override; - std::shared_ptr UnifyDeviceArrangement() const; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc deleted file mode 100644 index 946620ec4c..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "parallel/tensor_layout/redistribution_operator_infer.h" - -#include - -#include "parallel/device_manager.h" - -namespace mindspore { -namespace parallel { -Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, - RankList dev_list, bool is_cost_model) { - in_tensor_map_ = tensor_layout.tensor_map(); - dev_mat_ = tensor_layout.device_arrangement(); - - if (in_tensor_map_.GetDimSize() == 0 || out_tensor_map.GetDimSize() != in_tensor_map_.GetDimSize()) { - MS_LOG(ERROR) << "Invalid input when initialize RedistributionOperatorInfer!"; - return Status::FAILED; - } - - cur_tensor_layout_ = tensor_layout; - out_tensor_map_ = out_tensor_map; - dev_list_ = std::move(dev_list); - - operator_list_.clear(); - operator_vector_.clear(); - output_info_vector_.clear(); - - if (constructor_.Init(dev_list_, dev_mat_.array()) != Status::SUCCESS) { - MS_LOG(ERROR) << "Init constructor failed"; - return Status::FAILED; - } - constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); - - size_t key = 0; - std::vector map = in_tensor_map_.array(); - for (int32_t item : map) { - map_[key++] = item; - } - - is_cost_model_ = is_cost_model; - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferRedistributionOperator() { - while (!map_.empty()) { - size_t len_global = operator_list_.size(); - - while (!map_.empty()) { - size_t len_split_by_axis = operator_list_.size(); - // split_by_axis operation - if (InferSplitByAxis() == Status::FAILED) { - return Status::FAILED; - } - // permute_by_axis operation - while (!map_.empty()) { - size_t len_permute_by_axis = operator_list_.size(); - if (InferPermuteByAxis() == Status::FAILED) { - return Status::FAILED; - } - if (len_permute_by_axis == operator_list_.size()) break; - } - if (len_split_by_axis == operator_list_.size()) break; - } - // concat_by_axis operation - if (InferConcatByAxis() == Status::FAILED) { - return Status::FAILED; - } - // break loop structure with concat_by_axis - if (len_global == operator_list_.size() && !map_.empty()) { - size_t index = map_.begin()->first; - int32_t in_dim = map_[index]; - map_[index] = NONE; - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; - if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { - return Status::FAILED; - } - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferSplitByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = iter->second; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim == out_dim) { - (void)map_.erase(iter++); - continue; - } - if (in_dim == NONE && - !std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { - Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; - if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; - } - (void)map_.erase(iter++); - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferPermuteByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim == out_dim) { - (void)map_.erase(iter++); - continue; - } - if (in_dim == NONE && - 
std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { - int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); - int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); - if (is_cost_model_) { - int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); - Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, - dev_num}; - if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { - MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; - return Status::FAILED; - } - } else { - Args args_allconcat = {cat_dim, out_dim, dev_num}; - Args args_allsplit = {dev_num, UintToInt(index), out_dim}; - if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; - } - } - (void)map_.erase(iter++); - map_[IntToSize(cat_dim)] = NONE; - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::InferConcatByAxis() { - for (auto iter = map_.begin(); iter != map_.end();) { - uint32_t index = iter->first; - int32_t in_dim = map_[index]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - if (in_dim != NONE && out_tensor_map_.GetIndexByValue(in_dim) == NONE) { - Args args = {SizeToInt(index), in_dim, dev_mat_.GetDimByReverseIdx(IntToUint(in_dim))}; - if (InsertOperator(CONCAT_BY_AXIS, args) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (out_dim == NONE) { - (void)map_.erase(iter++); - } else { - map_[index] = NONE; - (void)++iter; - } - } else { - (void)++iter; - } - } - return Status::SUCCESS; -} - -// Transfer communicative operators into primitives and insert them into vector -Status RedistributionOperatorInfer::InsertOperator(OperatorName name, Args args) { - OperatorR op = std::make_pair(name, args); - OperatorC op_cost = std::make_pair(op, cur_tensor_layout_.slice_shape().array()); - operator_list_.push_back(op_cost); - if (construct_op_flag_) { - if (name == SPLIT_BY_AXIS) { - if (TransferSplitByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } else if (name == PERMUTE_BY_AXIS) { - if (TransferPermuteByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } else { - if (TransferConcatByAxis(args) == Status::FAILED) { - return Status::FAILED; - } - } - constructor_.UpdateTensorShape(cur_tensor_layout_.slice_shape().array()); - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferSplitByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - uint32_t index = IntToUint(args[1]); - if (constructor_.StridedSliceOP(args) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - if (cur_tensor_layout_.UpdateTensorMap(index, args[2]) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferPermuteByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - if (constructor_.AlltoAllOP(args) != Status::SUCCESS) { - return Status::FAILED; 
- } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - uint32_t index = IntToUint(args[1]); - int32_t val = args[2]; - int32_t out_dim = out_tensor_map_.GetDimByIdx(index); - - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(val), NONE) == Status::FAILED) { - return Status::FAILED; - } - if (cur_tensor_layout_.UpdateTensorMap(index, out_dim) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -Status RedistributionOperatorInfer::TransferConcatByAxis(Args args) { - if (args.size() < 3) { - MS_LOG(ERROR) << "args size should not be less than 3!"; - return Status::FAILED; - } - int32_t tensor_dim = args[0]; - int32_t dev_dim = args[1]; - int32_t split_count = args[2]; - if (constructor_.AllGatherOP(dev_dim) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - if (tensor_dim != 0) { - if (constructor_.SplitOP(split_count) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(true, split_count)); - } - if (constructor_.ConcatOP(tensor_dim) != Status::SUCCESS) { - return Status::FAILED; - } else { - operator_vector_.push_back(constructor_.GetOperator()); - output_info_vector_.push_back(std::make_pair(false, 0)); - } - } - if (cur_tensor_layout_.UpdateTensorMap(IntToUint(tensor_dim), NONE) == Status::FAILED) { - return Status::FAILED; - } - return Status::SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h deleted file mode 100644 index 37a8ac3d9e..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
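
The TransferConcatByAxis path above is worth spelling out: AllGather always concatenates along axis 0, so gathering along a non-zero tensor_dim is lowered to AllGather followed by Split (back into the gathered blocks) and Concat (on the real axis). A schematic sketch with hypothetical op-name strings, not the real Operator types:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> LowerConcatByAxis(int32_t tensor_dim, int32_t split_count) {
  std::vector<std::string> ops = {"AllGather"};
  if (tensor_dim != 0) {
    // AllGather stacked the slices on axis 0; re-slice and re-concat on the target axis.
    ops.push_back("Split(axis=0, output_num=" + std::to_string(split_count) + ")");
    ops.push_back("Concat(axis=" + std::to_string(tensor_dim) + ")");
  }
  return ops;
}

int main() {
  for (const auto &op : LowerConcatByAxis(1, 4)) {
    std::cout << op << std::endl;
  }
  return 0;
}
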
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ - -#include -#include -#include -#include -#include - -#include "parallel/tensor_layout/construct_operator.h" -#include "parallel/tensor_layout/redistribution_layout_transfer.h" -#include "utils/convert_utils.h" -namespace mindspore { -namespace parallel { -using DeviceArrangement = std::vector; -using TensorMap = std::vector; -using TensorShape = std::vector; -using RedistributionOperatorMap = std::unordered_map; -using OperatorR = std::pair; -using OperatorC = std::pair; -using OperatorList = std::vector; - -class RedistributionOperatorInfer { - public: - const int NONE = -1; - explicit RedistributionOperatorInfer(bool construct_op_flag = true) - : construct_op_flag_(construct_op_flag), is_cost_model_(false) {} - Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, - bool is_cost_model = false); - ~RedistributionOperatorInfer() = default; - OperatorList operator_list() const { return operator_list_; } - OperatorVector operator_vector() const { return operator_vector_; } - OutPutInfoVector output_info_vector() const { return output_info_vector_; } - Status InferRedistributionOperator(); - - private: - Status InferSplitByAxis(); - Status InferPermuteByAxis(); - Status InferConcatByAxis(); - Status TransferSplitByAxis(Args args); - Status TransferPermuteByAxis(Args args); - Status TransferConcatByAxis(Args args); - Status InsertOperator(OperatorName name, Args args); - - OperatorList operator_list_; - OperatorVector operator_vector_; - OutPutInfoVector output_info_vector_; - Arrangement dev_mat_; - RedistributionOperatorMap map_; - Map in_tensor_map_; - Map out_tensor_map_; - TensorLayout cur_tensor_layout_; - ConstructOperator constructor_; - RankList dev_list_; - bool construct_op_flag_; - bool is_cost_model_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_REDISTRIBUTION_OPERATOR_INFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc deleted file mode 100644 index 4c66befd78..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
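
The RedistributionOperatorMap typedef above is the entire state of the inference loop: it indexes each tensor dimension to the device dimension it is currently sliced on, with NONE (-1) meaning replicated, exactly as Init populates it from the input tensor map. A tiny stand-alone illustration:

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // dim 0 sliced on device dim 1, dim 1 replicated, dim 2 sliced on device dim 0
  std::vector<int32_t> in_tensor_map = {1, -1, 0};
  std::unordered_map<uint32_t, int32_t> map;  // mirrors RedistributionOperatorMap
  uint32_t key = 0;
  for (int32_t item : in_tensor_map) {
    map[key++] = item;
  }
  for (const auto &kv : map) {
    std::cout << "tensor dim " << kv.first << " -> device dim " << kv.second << std::endl;
  }
  return 0;
}
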
- */ - -#include "parallel/tensor_layout/reshape_layout_transfer.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status ReshapeLayoutTransfer::CheckValidTransfer() { - if (!IsSameDeviceArrangement()) { - return Status::FAILED; - } - return Status::SUCCESS; -} - -std::shared_ptr ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { - bool is_unified = IsSameTensorShape(); - std::shared_ptr out_layout_ptr = std::make_shared(*this); - if (out_layout_ptr == nullptr) { - return nullptr; - } - while (!is_unified) { - std::shared_ptr temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); - if (temp_layout_ptr == nullptr) { - return nullptr; - } - out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); - if (out_layout_ptr == nullptr) { - return nullptr; - } - is_unified = out_layout_ptr->IsSameTensorShape(); - } - return out_layout_ptr; -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByTo() const { - std::shared_ptr out_ptr = std::make_shared(*this); - bool is_expanded = FromTensorShapeCanBeExpandByTo(); - while (!is_expanded) { - out_ptr = out_ptr->ExtendFromTensorShapeByExpandedTensorShape(); - if (out_ptr == nullptr) { - return nullptr; - } - is_expanded = out_ptr->FromTensorShapeCanBeExpandByTo(); - } - return out_ptr; -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByFrom() const { - std::shared_ptr out_ptr = std::make_shared(*this); - bool is_expanded = ToTensorShapeCanBeExpandByFrom(); - while (!is_expanded) { - out_ptr = out_ptr->ExtendToTensorShapeByExpandedTensorShape(); - if (out_ptr == nullptr) { - return nullptr; - } - is_expanded = out_ptr->ToTensorShapeCanBeExpandByFrom(); - } - return out_ptr; -} - -bool ReshapeLayoutTransfer::FromTensorShapeCanBeExpandByTo() const { - return from_in_.TensorShapeCanBeExpanded(to_in_.tensor_shape()); -} - -bool ReshapeLayoutTransfer::ToTensorShapeCanBeExpandByFrom() const { - return to_in_.TensorShapeCanBeExpanded(from_in_.tensor_shape()); -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendFromTensorShapeByExpandedTensorShape() const { - std::shared_ptr expanded_shape_ptr = ComputeExpandedFromTensorShapeByTo(); - if (expanded_shape_ptr == nullptr) { - return nullptr; - } - return ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); -} - -std::shared_ptr ReshapeLayoutTransfer::ExtendToTensorShapeByExpandedTensorShape() const { - std::shared_ptr exchanged_from_and_to_ptr = ExchangeFromAndTo(); - if (exchanged_from_and_to_ptr == nullptr) { - return nullptr; - } - std::shared_ptr expanded_shape_ptr = exchanged_from_and_to_ptr->ComputeExpandedFromTensorShapeByTo(); - if (expanded_shape_ptr == nullptr) { - return nullptr; - } - std::shared_ptr exchanged_out = - exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); - if (exchanged_out == nullptr) { - return nullptr; - } - return exchanged_out->ExchangeFromAndTo(); -} - -std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo() const { - ReshapeLayoutTransfer out; - Status status = out.Init(to_in_, from_in_); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement &expand_shape) const { - std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); - if (extend_tensor_shape_from_ptr == nullptr) { - return nullptr; - } - 
Arrangement unified_device_arrangement = extend_tensor_shape_from_ptr->device_arrangement(); - std::shared_ptr extend_device_arrangement_to_ptr = - to_in_.ExpandDeviceArrangement(unified_device_arrangement); - if (extend_device_arrangement_to_ptr == nullptr) { - return nullptr; - } - ReshapeLayoutTransfer out; - Status status = out.Init(*extend_tensor_shape_from_ptr, *extend_device_arrangement_to_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared(out); -} - -std::shared_ptr ReshapeLayoutTransfer::ComputeExpandedFromTensorShapeByTo() const { - return from_in_.ComputeExpandedTensorShape(to_in_.tensor_shape()); -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h deleted file mode 100644 index ed62cb59da..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ - -#include -#include "parallel/status.h" -#include "parallel/tensor_layout/layout_transfer.h" - -namespace mindspore { -namespace parallel { -class ReshapeLayoutTransfer : public LayoutTransfer { - public: - ReshapeLayoutTransfer() = default; - ~ReshapeLayoutTransfer() override = default; - std::shared_ptr UnifyDeviceArrangementAndTensorShape() const; - std::shared_ptr ExtendFromTensorShapeByTo() const; - std::shared_ptr ExtendToTensorShapeByFrom() const; - std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; - std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; - std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement &expand_shape) const; - std::shared_ptr ExchangeFromAndTo() const; - - private: - Status CheckValidTransfer() override; - std::shared_ptr ComputeExpandedFromTensorShapeByTo() const; - bool FromTensorShapeCanBeExpandByTo() const; - bool ToTensorShapeCanBeExpandByFrom() const; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_RESHAPE_LAYOUT_TRANSFER_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc deleted file mode 100644 index e8f208708c..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc +++ /dev/null @@ -1,263 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
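
The ExtendFromTensorShapeByTo / ExtendToTensorShapeByFrom loop deleted above is a fixed-point iteration: each round refines one tensor shape by the factorization of the other until both factorizations agree. Demonstrated here on plain shapes with an ExpandShape-style helper (assumes compatible, positive dimensions; the real code iterates on TensorLayout objects):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Reverse accumulating product: [2, 8, 32] -> [512, 256, 32].
Shape AccumReverse(const Shape &s) {
  Shape out(s.size());
  int64_t p = 1;
  for (std::size_t i = s.size(); i-- > 0;) {
    out[i] = (p *= s[i]);
  }
  return out;
}

// ExpandShape-style refinement: cut `in` at every cut point of `expand`.
Shape Expand(const Shape &in, const Shape &expand) {
  Shape u = AccumReverse(in);
  Shape e = AccumReverse(expand);
  u.insert(u.end(), e.begin(), e.end());
  std::sort(u.begin(), u.end(), std::greater<int64_t>());
  u.erase(std::unique(u.begin(), u.end()), u.end());
  Shape out;
  for (std::size_t i = 0; i < u.size(); ++i) {
    out.push_back(u[i] / (i + 1 < u.size() ? u[i + 1] : 1));
  }
  return out;
}

int main() {
  Shape from = {2, 8, 32};
  Shape to = {16, 4, 8};
  while (from != to) {        // ping-pong until both factorizations coincide
    from = Expand(from, to);  // -> [2, 8, 4, 8]
    to = Expand(to, from);    // -> [2, 8, 4, 8]
  }
  for (int64_t d : from) {
    std::cout << d << ' ';
  }
  std::cout << std::endl;
  return 0;
}
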
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "parallel/tensor_layout/shape_util.h"
-#include <cstdint>
-#include "parallel/status.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace parallel {
-/*
- * example:
- * shape = [2, 8, 32]
- * shape_accum = [2, 2 * 8, 2 * 8 * 32]
- */
-Status ShapeToAccumulateProduct(const std::vector<int64_t> &shape, std::vector<int64_t> *shape_accum) {
-  MS_EXCEPTION_IF_NULL(shape_accum);
-  shape_accum->clear();
-  int64_t size = 1;
-  for (auto iter = shape.begin(); iter < shape.end(); ++iter) {
-    size *= *iter;
-    if (size <= 0) {
-      MS_LOG(ERROR) << "elements of shape must be larger than zero";
-      return Status::FAILED;
-    }
-    shape_accum->push_back(size);
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example:
- * shape = [2, 8, 32]
- * shape_accum = [2 * 8 * 32, 8 * 32, 32]
- *
- */
-Status ShapeToAccumulateProductReverse(const std::vector<int64_t> &shape, std::vector<int64_t> *shape_accum) {
-  MS_EXCEPTION_IF_NULL(shape_accum);
-  shape_accum->clear();
-  int64_t size = 1;
-  for (auto iter = shape.end() - 1; iter >= shape.begin(); --iter) {
-    size *= *iter;
-    if (size <= 0) {
-      MS_LOG(ERROR) << "elements of shape must be larger than zero";
-      return Status::FAILED;
-    }
-    (void)shape_accum->insert(shape_accum->begin(), size);
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example:
- * shape_accum = [2, 2 * 8, 2 * 8 * 32]
- * shape = [2, 8, 32]
- *
- */
-Status AccumulateProductToShape(const std::vector<int64_t> &shape_accum, std::vector<int64_t> *shape) {
-  MS_EXCEPTION_IF_NULL(shape);
-  shape->clear();
-  int64_t value = 1;
-  for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) {
-    if ((*iter) == 0) {
-      MS_LOG(ERROR) << "element of shape_accum should not be zero";
-      return Status::FAILED;
-    }
-    if ((*iter) % value != 0) {
-      MS_LOG(ERROR) << "shape_accum is not an accumulating product in ascending order";
-      return Status::FAILED;
-    }
-    shape->push_back(static_cast<int64_t>((*iter) / value));
-    value = (*iter);
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example:
- * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
- * shape = [2, 8, 32]
- */
-Status AccumulateProductReverseToShape(const std::vector<int64_t> &shape_accum_reverse, std::vector<int64_t> *shape) {
-  MS_EXCEPTION_IF_NULL(shape);
-  shape->clear();
-  int64_t value = 1;
-  for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) {
-    if (*iter == 0) {
-      MS_LOG(ERROR) << "element of shape_accum should not be zero";
-      return Status::FAILED;
-    }
-    if ((*iter) % value != 0) {
-      MS_LOG(ERROR) << "shape_accum is not an accumulating product in descending order";
-      return Status::FAILED;
-    }
-    (void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
-    value = *iter;
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example1:
- * in1 = [2, 8]
- * in2 = [4, 8]
- * *out = [2, 4, 8]
- *
- * example2:
- * in1 = [2, 4, 16]
- * in2 = [8, 16]
- * *out = [2, 4, 8, 16]
- */
-Status UnifyAccumulateProduct(const std::vector<int64_t> &in1_accum, const std::vector<int64_t> &in2_accum,
-                              std::vector<int64_t> *out_accum) {
-  MS_EXCEPTION_IF_NULL(out_accum);
-  out_accum->clear();
-  auto in1_iter = in1_accum.begin();
-  auto in2_iter = in2_accum.begin();
-  // stop as soon as either accumulation is exhausted; leftovers are reported below
-  while ((in1_iter < in1_accum.end()) && (in2_iter < in2_accum.end())) {
-    if ((*in1_iter <= 0) || (*in2_iter <= 0)) {
-      MS_LOG(ERROR) << "elements of in1 and in2 must be larger than zero";
-      return Status::FAILED;
-    }
-    if (*in1_iter < *in2_iter) {
-      out_accum->push_back(*in1_iter);
-      ++in1_iter;
-      continue;
-    } else if (*in1_iter == *in2_iter) {
-      out_accum->push_back(*in1_iter);
-      ++in1_iter;
-      ++in2_iter;
-    } else {
-      out_accum->push_back(*in2_iter);
-      ++in2_iter;
-    }
-  }
-  if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) {
-    MS_LOG(ERROR) << "last element of in1 and in2 must be equal";
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example:
- * in1 = [8, 4]
- * in2 = [2, 16]
- * out = [2, 4, 4]
- */
-Status UnifyShape(const std::vector<int64_t> &in1, const std::vector<int64_t> &in2, std::vector<int64_t> *out) {
-  MS_EXCEPTION_IF_NULL(out);
-  std::vector<int64_t> in1_accum;
-  Status status = ShapeToAccumulateProduct(in1, &in1_accum);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  std::vector<int64_t> in2_accum;
-  status = ShapeToAccumulateProduct(in2, &in2_accum);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  std::vector<int64_t> out_accum;
-  status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  status = AccumulateProductToShape(out_accum, out);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  return status;
-}
-
-/*
- * example1:
- * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
- * expand_accum_reverse = [2 * 8 * 32, 32, 8]
- * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8]
- *
- * example2:
- * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
- * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
- * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
- */
-Status ExpandAccumulateProduct(const std::vector<int64_t> &in_accum_reverse,
-                               const std::vector<int64_t> &expand_accum_reverse,
-                               std::vector<int64_t> *out_accum_reverse) {
-  MS_EXCEPTION_IF_NULL(out_accum_reverse);
-  out_accum_reverse->clear();
-  auto in_riter = in_accum_reverse.rbegin();
-  auto expand_riter = expand_accum_reverse.rbegin();
-  while (expand_riter != expand_accum_reverse.rend()) {
-    if (in_riter == in_accum_reverse.rend()) {
-      MS_LOG(ERROR) << "invalid ExpandAccumProd inputs";
-      return Status::FAILED;
-    }
-    if (*in_riter > *expand_riter) {
-      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
-      ++expand_riter;
-    } else if (*in_riter == *expand_riter) {
-      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
-      ++in_riter;
-      ++expand_riter;
-    } else {
-      (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
-      ++in_riter;
-    }
-  }
-  while (in_riter != in_accum_reverse.rend()) {
-    (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
-    ++in_riter;
-  }
-  return Status::SUCCESS;
-}
-
-/*
- * example1:
- * in = [2, 8, 32]
- * expand = [16, 4, 8]
- * out = [2, 8, 4, 8]
- *
- * example2:
- * in = [2, 8, 32]
- * expand = [2, 4, 8]
- * out = [2, 4, 2, 4, 8]
- */
-Status ExpandShape(const std::vector<int64_t> &in, const std::vector<int64_t> &expand, std::vector<int64_t> *out) {
-  MS_EXCEPTION_IF_NULL(out);
-  std::vector<int64_t> in_accum_reverse;
-  Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  std::vector<int64_t> expand_accum_reverse;
-  status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse);
-  if (status != Status::SUCCESS) {
-    return status;
-  }
-  std::vector<int64_t> out_accum_reverse;
-  status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse,
&out_accum_reverse); - if (status != Status::SUCCESS) { - return status; - } - status = AccumulateProductReverseToShape(out_accum_reverse, out); - if (status != Status::SUCCESS) { - return status; - } - return status; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h deleted file mode 100644 index 2ec21f3881..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ - -#include -#include -#include -#include -#include - -#include "parallel/status.h" - -namespace mindspore { -namespace parallel { -/* - * compute the accumulating product of all the values in shape from left to right, - * the accumulating results are saved in shape_accum from left to right - * - * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), - * then *shape_accum = [d_n-1, d_n-1 * d_n-2, d_n-1 * d_n-2 * d_n-3, ..., d_n-1 * d_n-2 * ... *d_0] - * - * example: - * shape = [2, 8, 32] - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - * - */ -Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); - -/* - * compute the accumulating product of all the values in shape from right to left, - * the accumulating results are saved in shape_accum from right to left - * - * given a shape = [d_n-1, d_n-2, ..., d_0](d_i > 0, i=0,1,...,n-1, elements of shape must be larger than zero), - * then *shape_accum = [d_n-1 * d_n-2 * ... *d_0, d_n-2 * d_n-3 * ... 
*d_0, ..., d_0] - * - * example: - * shape = [2, 8, 32] - * shape_accum = [2 * 8 * 32, 8 * 32, 32] - * - */ -Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); - -/* - * compute the original shape from the accumulating product shape_accum, - * elements of shape_accum is saved from left to right, - * given shape_accum = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] - * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), - * (accum_i-1 % accum_i == 0, i=1,...,n-1) - * then *shape = [accum_n-2/accum_n-1, accum_n-3/accum_n-2, ..., accum_0/accum_1] - * - * example: - * shape_accum = [2, 2 * 8, 2 * 8 * 32] - * shape = [2, 8, 32] - * - */ -Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); - -/* - * compute the original shape from the accumulating product shape_accum, - * elements of shape_accum is saved from right to left, - * given shape_accum_reverse = [accum_n-1, accum_n-2, accum_n-3, ..., accum_0] - * (accum_i > 0, i=0,1,...,n-1, elements of shape_accum must be larger than zero), - * (accum_i % accum_i-1 == 0, i=1,...,n-1) - * then *shape = [accum_n-1/accum_n-2, accum_n-2/accum_n-1, ..., accum_1/accum_0] - * - * example: - * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * shape = [2, 8, 32] - * - */ -Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); - -/* - * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, - * results are saved in out. - * i.e. *out_accum = in1_accum U in2_accum - * elements of out are saved in increasing order - * - * example1: - * in1_accum = [2, 8] - * in2_accum = [4, 8] - * out_accum = [2, 4, 8] - * - * example2: - * in1_accum = [2, 4, 16] - * in2_accum = [8, 16] - * out_accum = [2, 4, 8, 16] - */ -Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, - std::vector *out_accum); - -/* - * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] - * size = din1_n-1 * din1n-2 * ... * din1_0 = din2_m-1 * din2_m-2 * ... * din2_0 - * find *out = [dout_k-1, dout_k-2, ..., dout_0], s.t. dout_k-1 * dout_k-2 * ... * dout_0 = size and - * suppose in1_accum, in2_accum, and *out_accum is the ShapeToAccumulateProduct result of in1, in2, and *out - * then for each din1_i in in1_accum, din1_i is in *out_accumulate, - * for each din2_i in in2_accum, din2_i is in *out_accumulate - * - * example: - * in1 = [8, 4] - * in2 = [2, 16] - * out = [2, 4, 4] - */ -Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); - -/* - * given two accumulate product in reverse order of in and expand, - * in_accum_reverse = [din_n-1, din_n-2, ..., din_0] and expand_pos_reverse = [dexp_n-1, dexp_n-2, ..., dexp_0], - * i.e. in_accum_reverse is the ShapeToAccumulateProductReverse result of a shape in, - * expand_accum_reverse is the ShapeToAccumulateProductReverse result of a shape expand, - * compute the accumulate product in reverse order out_accum_reverse = [dout_k-1, dout_k-2, ..., dout_0], - * s.t. elements in out_accum_reverse are union of elements in in_accum_reverse and expand_accum_reverse - * (out_accum_reverse = in_accum_reverse U expand_accum_reverse), and - * out_accum_reverse is the ShapeToAccumulateProductReverse result of shape expand, - * i.e. 
dout_i > 0, i=0,1,...,k-1, elements of out_accum_reverse must be larger than zero, - * dout_i-1 % dout_i == 0, i=1,...,k-1 - * - * example1: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 8 * 32, 32, 8] - * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8] - * - * example2: - * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32] - * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] - * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] - */ -Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, - const std::vector &expand_accum_reverse, - std::vector *out_accum_reverse); - -/* - * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], - * compute the expended shape out = [dout_k-1, dout_k-2, ..., dout_0], - * s.t. dout_k-1 * dout_k-2 * ...* dout_0 = din_n-1 * din_n-2 * ... * d_0 - * suppose in_accum_reverse is the ShapeToAccumulateProductReverse result of in, - * expand_accum_reverse is the ShapeToAccumulateProductReverse result of expand, - * out_accum_reverse is the ShapeToAccumulateProductReverse result of out, - * then out_accum_reverse is the union of in_accum_reverse and expand_accum_reverse - * (out_accum_reverse = in_accum_reverse U expand_accum_reverse) - * - * example1: - * in = [2, 8, 32] - * expand = [16, 4, 8] - * out = [2, 8, 4, 8] - * - * example2: - * in = [2, 8, 32] - * expand = [2, 4, 8] - * out = [2, 4, 2, 4, 8] - */ -Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_SHAPE_UTIL_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h deleted file mode 100644 index 0eee736cea..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ - -#include -#include -#include -#include - -#include "parallel/device_matrix.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -using Shapes = std::vector; - -class TensorInfo { - public: - TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) - : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} - explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { - shape_ = tensor_layout.tensor_shape().array(); - slice_shape_ = tensor_layout.slice_shape().array(); - } - // trivial default constructor will not initialize c language types. 
- TensorInfo() = default; - ~TensorInfo() = default; - TensorLayout tensor_layout() const { return tensor_layout_; } - Shape slice_shape() const { return slice_shape_; } - Shape shape() const { return shape_; } - void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } - std::vector reduce_dim() const { return reduce_dim_; } - Dimensions InferStrategy() const { - Dimensions stra; - for (size_t i = 0; i < shape_.size(); ++i) { - if ((slice_shape_[i] == 0) || (shape_[i] % slice_shape_[i] != 0)) { - return stra; - } - int32_t dim = (int32_t)(shape_[i] / slice_shape_[i]); - stra.push_back(dim); - } - return stra; - } - - private: - TensorLayout tensor_layout_; - Shape shape_; - Shape slice_shape_; - // reduce method's reduce dim - std::vector reduce_dim_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_INFO_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc deleted file mode 100644 index f3498065f2..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc +++ /dev/null @@ -1,394 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
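
TensorInfo::InferStrategy above is simple integer arithmetic: each strategy element is shape[i] / slice_shape[i], and a zero or non-dividing slice is treated as invalid. A stand-alone sketch (simplified in one respect: it returns an empty vector on failure, where the original returns the partially built one):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> InferStrategy(const std::vector<int32_t> &shape,
                                   const std::vector<int32_t> &slice) {
  std::vector<int32_t> stra;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    if (slice[i] == 0 || shape[i] % slice[i] != 0) {
      return {};  // inconsistent slice
    }
    stra.push_back(shape[i] / slice[i]);
  }
  return stra;
}

int main() {
  // a [512, 1024] tensor sliced to [128, 256] per device -> strategy [4, 4]
  for (int32_t s : InferStrategy({512, 1024}, {128, 256})) {
    std::cout << s << ' ';
  }
  std::cout << std::endl;
  return 0;
}
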
- */ - -#include "parallel/tensor_layout/tensor_layout.h" -#include -#include -#include "common/utils.h" -#include "ir/value.h" -#include "parallel/device_matrix.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/array.h" -#include "parallel/tensor_layout/shape_util.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parallel { -std::string TensorLayout::ToString() const { return StandardToString() + OriginToString(); } - -std::string TensorLayout::StandardToString() const { - std::ostringstream buffer; - buffer << std::endl << std::string("device arrangement = " + device_arrangement_.ToString()); - buffer << std::endl << std::string("tensor map = " + tensor_map_.ToString()); - buffer << std::endl << std::string("tensor shape = " + tensor_shape_.ToString()); - return buffer.str(); -} - -std::string TensorLayout::OriginToString() const { - std::ostringstream buffer; - buffer << std::endl << std::string("device arrangement origin = " + device_arrangement_origin_.ToString()); - buffer << std::endl << std::string("tensor map origin = " + tensor_map_origin_.ToString()); - buffer << std::endl << std::string("tensor shape origin = " + tensor_shape_origin_.ToString()); - return buffer.str(); -} - -Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, - const Arrangement &tensor_shape) { - device_arrangement_origin_ = device_arrangement; - tensor_map_origin_ = tensor_map; - tensor_shape_origin_ = tensor_shape; - device_arrangement_ = device_arrangement; - tensor_map_ = tensor_map; - tensor_shape_ = tensor_shape; - if (IsValidTensorLayout()) { - MS_LOG(DEBUG) << "valid origin tensor layout " << this->OriginToString(); - RemoveElementEqualToOneInDeviceArrangement(); - MS_LOG(DEBUG) << "standard tensor layout " << this->StandardToString(); - return Status::SUCCESS; - } else { - MS_LOG(ERROR) << "invalid origin tensor layout " << this->OriginToString(); - return Status::FAILED; - } -} - -Status TensorLayout::InitFromVector(const std::vector &device_arrangement, - const std::vector &tensor_map, const std::vector &tensor_shape) { - if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { - return FAILED; - } - if (tensor_map_origin_.Init(tensor_map) != SUCCESS) { - return FAILED; - } - if (tensor_shape_origin_.Init(tensor_shape) != SUCCESS) { - return FAILED; - } - if (Init(device_arrangement_origin_, tensor_map_origin_, tensor_shape_origin_) != SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -bool TensorLayout::IsValidTensorLayout() const { - if (tensor_map_origin_.GetMaxItem() >= static_cast(device_arrangement_origin_.GetDimSize())) { - MS_LOG(ERROR) << "the max element in tensor_map_origin_ must be smaller than device_arrangement_origin_ size!"; - return false; - } - if (tensor_map_origin_.GetDimSize() != tensor_shape_origin_.GetDimSize()) { - MS_LOG(ERROR) << "tensor_map_origin_ size must be equal to tensor_shape_origin_ size!"; - return false; - } - if (!TensorShapeDimensionIsDividedBySplitDeviceDimension()) { - MS_LOG(ERROR) << "TensorShapeDimensionIsDividedBySplitDeviceDimension failed!"; - return false; - } - return true; -} - -bool TensorLayout::TensorShapeDimensionIsDividedBySplitDeviceDimension() const { - for (uint32_t i = 0; i < tensor_map_.GetDimSize(); i++) { - if (tensor_map_.GetDimByIdx(i) != -1) { - int32_t divisor = GetSliceNumByTensorDimensionIndex(i); - if (divisor == 0) { - MS_LOG(ERROR) << "GetSliceNumByTensorDimensionIndex is 0"; - return false; - } - if (tensor_shape_.GetDimByIdx(i) % divisor 
!= 0) { - return false; - } - } - } - return true; -} - -void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { - std::vector<int32_t> device_arrangement_shape; - std::vector<int32_t> tensor_map_shape = tensor_map_origin_.array(); - uint32_t dev_num = SizeToUint(device_arrangement_origin_.GetDimSize()); - int32_t dev_num_left = SizeToInt(device_arrangement_origin_.GetDimSize()); - for (uint32_t i = 0; i < dev_num; i++) { - if (device_arrangement_origin_.GetDimByIdx(i) == 1) { - int32_t idx = GetTensorDimensionIndexByDeviceDimensionIndex(static_cast<int32_t>(dev_num - 1 - i)); - if (idx != -1) { - tensor_map_shape[static_cast<uint32_t>(idx)] = -1; - } - for (auto &value : tensor_map_shape) { - if (value >= dev_num_left - 1 - static_cast<int32_t>(i)) { - value--; - } - } - continue; - } - device_arrangement_shape.push_back(device_arrangement_origin_.GetDimByIdx(i)); - } - (void)device_arrangement_.Init(device_arrangement_shape); - (void)tensor_map_.Init(tensor_map_shape); - tensor_shape_ = tensor_shape_origin_; -} - -// if idx is not in tensor_map, return -1 -int32_t TensorLayout::GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const { - return tensor_map_.GetIndexByValue(idx); -} - -// tensor_map_.GetDimByIdx(idx) should not be -1 -int32_t TensorLayout::GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const { - return static_cast<int32_t>(device_arrangement_.GetDimSize()) - 1 - tensor_map_.GetDimByIdx(idx); -} - -// tensor_map_.GetDimByIdx(idx) should not be -1 -int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { - return device_arrangement_.GetDimByIdx(static_cast<uint32_t>(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); -} - -std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { - std::shared_ptr<Arrangement> expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); - if (expanded_arrangement_ptr == nullptr) { - return nullptr; - } - std::shared_ptr<TensorLayout> temp_tensor_layout_ptr = ExpandDeviceArrangement(*expanded_arrangement_ptr); - if (temp_tensor_layout_ptr == nullptr) { - return nullptr; - } - return temp_tensor_layout_ptr->ExpandTensorShapeWithoutExtendDeviceArrangement(expanded_shape); -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_tensor_shape = [128, 4, 2, 512], - * => - * out_device_arrangement = [8, 2, 2] - */ -std::shared_ptr<Arrangement> TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { - std::shared_ptr<std::vector<Arrangement>> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); - if (expand_list_ptr == nullptr) { - return nullptr; - } - std::vector<Arrangement> re_map_expand_list; - Arrangement empty_arrangement; - for (int32_t i = static_cast<int32_t>(device_arrangement_.GetDimSize()) - 1; i >= 0; i--) { - if (tensor_map_.GetIndexByValue(i) < 0) { - re_map_expand_list.push_back(empty_arrangement); - } else { - re_map_expand_list.push_back((*expand_list_ptr)[IntToUint(tensor_map_.GetIndexByValue(i))]); - } - } - std::shared_ptr<Arrangement> new_arrangement_ptr = - device_arrangement_.GetExpandedShapeByExpandListRemoveLeft(re_map_expand_list); - return new_arrangement_ptr; -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_tensor_shape = [8, 64, 4, 256] - * => - * out_device_arrangement = [8, 4], - * out_tensor_map = [1, -1, 0, -1], - */ -std::shared_ptr<TensorLayout> TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement &expanded_shape) const { - std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>>
expand_list_pair_ptr = - tensor_shape_.GetExpandShapeListPair(expanded_shape); - if (expand_list_pair_ptr == nullptr) { - return nullptr; - } - std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByNone(expand_list_pair_ptr->second); - if (tensor_map_new_ptr == nullptr) { - return nullptr; - } - TensorLayout tensor_layout_new; - Status status = tensor_layout_new.Init(device_arrangement_, *tensor_map_new_ptr, expanded_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared<TensorLayout>(tensor_layout_new); -} - -/* - * example1: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, 0], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [3, 2, 1, 0], - * out_tensor_shape = [4, 128, 2, 512] - * - * example2: - * in_device_arrangement = [8, 4], - * in_tensor_map = [0, 1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [1, 0, 3, 2], - * out_tensor_shape = [2, 256, 4, 256] - * - * example3: - * in_device_arrangement = [8, 4], - * in_tensor_map = [1, -1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 2, 2] - * => - * out_tensor_map = [3, 2, -1], - * out_tensor_shape = [4, 128, 1024] - * - * example4: - * in_device_arrangement = [8, 4], - * in_tensor_map = [0, 1], - * in_tensor_shape = [512, 1024], - * out_device_arrangement = [4, 2, 4] - * => - * out_tensor_map = [0, 2, 1], - * out_tensor_shape = [512, 4, 256] - */ -std::shared_ptr<TensorLayout> TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { - std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> expand_list_pair_ptr = - device_arrangement_.GetExpandShapeListPair(expanded_arrangement); - if (expand_list_pair_ptr == nullptr) { - return nullptr; - } - std::shared_ptr<Map> tensor_map_new_ptr = tensor_map_.ExpandMapByDecreaseNumber(expand_list_pair_ptr->second); - if (tensor_map_new_ptr == nullptr) { - return nullptr; - } - std::shared_ptr<std::vector<Arrangement>> re_map_shape_list_ptr = - tensor_map_.ReMapVector(expand_list_pair_ptr->first); - if (re_map_shape_list_ptr == nullptr) { - return nullptr; - } - std::shared_ptr<Arrangement> tensor_shape_new_ptr = - tensor_shape_.GetExpandedShapeByExpandListReserveLeft(*re_map_shape_list_ptr); - if (tensor_shape_new_ptr == nullptr) { - return nullptr; - } - TensorLayout tensor_layout_new; - Status status = tensor_layout_new.Init(expanded_arrangement, *tensor_map_new_ptr, *tensor_shape_new_ptr); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared<TensorLayout>(tensor_layout_new); -} - -bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { - std::vector<int32_t> in_expand_shape_shape; - Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); - if (status != Status::SUCCESS) { - return false; - } - return (in_expand_shape_shape == tensor_shape_.array()); -} - -std::shared_ptr<Arrangement> TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { - std::vector<int32_t> in_expand_shape_shape; - Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - Arrangement expanded_shape; - status = expanded_shape.Init(in_expand_shape_shape); - if (status != Status::SUCCESS) { - return nullptr; - } - return std::make_shared<Arrangement>(expanded_shape); -} - -Arrangement TensorLayout::slice_shape() const { - std::vector<int32_t> shape; - for (uint32_t index = 0; index < tensor_map_.GetDimSize(); index++) { - int32_t dim =
tensor_map_.GetDimByIdx(index); - int32_t num = tensor_shape_.GetDimByIdx(index); - if (dim == -1) { - shape.push_back(num); - } else { - int32_t divisor = device_arrangement_.GetDimByReverseIdx(IntToUint(dim)); - shape.push_back(num / divisor); - } - } - Arrangement new_tensor_shape; - if (new_tensor_shape.Init(shape) == Status::FAILED) { - ValuePtr ptr = MakeValue(shape); - MS_LOG(EXCEPTION) << "Can't get slice shape when initialize a new shape " << ptr->ToString(); - } else { - return new_tensor_shape; - } -} - -Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { - if (index >= tensor_map_.GetDimSize()) { - MS_LOG(ERROR) << "Index is out of the size of the tensor map!"; - return Status::FAILED; - } - auto shape = tensor_map_.array(); - shape[index] = value; - if (tensor_map_.Init(shape) == Status::FAILED) { - MS_LOG(ERROR) << "Update tensor map failed!"; - return Status::FAILED; - } - return Status::SUCCESS; -} - -bool TensorLayout::operator==(const TensorLayout &t1) const { - return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); -} - -/* - * remove elements equal to 1 in tensor_shape, if all elements are 1, squeeze the tensor_shape to [ 1 ] - * example 1: - * original tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ 0 -1 -1 -1 ] - * tensor shape = [ 128 64 1 1 ] - * return tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ 0 -1 ] - * tensor shape = [ 128 64 ] - * - * example 2: - * device arrangement = [ 8 ] - * tensor map = [ -1 -1 -1 -1 ] - * tensor shape = [ 1 1 1 1 ] - * return tensor layout: - * device arrangement = [ 8 ] - * tensor map = [ -1 ] - * tensor shape = [ 1 ] - */ -TensorLayout TensorLayout::SqueezeShape() const { - TensorLayout out; - Map out_map; - Arrangement out_shape; - if (tensor_shape_.size() == 1) { - (void)out_map.Init({MAP_NONE}); - (void)out_shape.Init({1}); - (void)out.Init(device_arrangement_, out_map, out_shape); - return out; - } - std::vector squeeze_list = tensor_shape_.GetSqueezeIdx(); - if (!tensor_map_.CheckNoneByIdxList(squeeze_list)) { - MS_LOG(ERROR) << "CheckNoneByIdxList failed, this may not happen under current situation"; - return *this; - } - out_shape = tensor_shape_.GetSqueezeArrangement(); - out_map = tensor_map_.SqueezeMapByIdxList(squeeze_list); - (void)out.Init(device_arrangement_, out_map, out_shape); - return out; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h deleted file mode 100644 index f51ed4e3e0..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
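A minimal standalone sketch of the slice-shape rule implemented in the deleted tensor_layout.cc above (the helper name SliceShape and its signature are hypothetical, not the MindSpore API): a tensor dimension mapped to device dimension d is divided by the size of that device dimension, where the tensor map counts device_arrangement from its rightmost element, and unmapped dimensions (map value -1) stay whole.

// Hypothetical, self-contained illustration of TensorLayout::slice_shape().
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> SliceShape(const std::vector<int32_t> &device_arrangement,
                                const std::vector<int32_t> &tensor_map,
                                const std::vector<int32_t> &tensor_shape) {
  std::vector<int32_t> slice;
  for (std::size_t i = 0; i < tensor_shape.size(); ++i) {
    if (tensor_map[i] == -1) {
      slice.push_back(tensor_shape[i]);  // dimension not split across devices
    } else {
      // the tensor map indexes device dimensions from the right
      std::size_t dev_idx = device_arrangement.size() - 1 - static_cast<std::size_t>(tensor_map[i]);
      slice.push_back(tensor_shape[i] / device_arrangement[dev_idx]);
    }
  }
  return slice;
}

int main() {
  // device arrangement [8, 4], tensor map [1, 0], tensor shape [512, 1024]
  // -> each device holds a [64, 256] slice
  for (int32_t d : SliceShape({8, 4}, {1, 0}, {512, 1024})) std::cout << d << " ";
  std::cout << std::endl;
  return 0;
}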
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ - -#include -#include -#include -#include -#include -#include "parallel/device_manager.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/arrangement.h" -#include "parallel/tensor_layout/map.h" -#include "utils/convert_utils.h" - -namespace mindspore { -namespace parallel { -class TensorLayout { - public: - TensorLayout() = default; - ~TensorLayout() = default; - std::string ToString() const; - std::string StandardToString() const; - std::string OriginToString() const; - Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); - Status InitFromVector(const std::vector<int32_t> &device_arrangement, const std::vector<int32_t> &tensor_map, - const std::vector<int32_t> &tensor_shape); - - Arrangement device_arrangement() const { return device_arrangement_; } - - Map tensor_map() const { return tensor_map_; } - - Arrangement tensor_shape() const { return tensor_shape_; } - - Map origin_tensor_map() const { return tensor_map_origin_; } - - std::shared_ptr<TensorLayout> ExpandTensorShape(const Arrangement &expanded_shape) const; - - std::shared_ptr<TensorLayout> ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; - - bool IsSameTensorShape(const TensorLayout &tensor_layout) const { - return (tensor_shape_ == tensor_layout.tensor_shape()); - } - - bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { - return (device_arrangement_ == tensor_layout.device_arrangement()); - } - - bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } - - bool operator==(const TensorLayout &t1) const; - - bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; - - std::shared_ptr<Arrangement> ComputeExpandedTensorShape(const Arrangement &expand_shape) const; - - Arrangement slice_shape() const; - - Status UpdateTensorMap(uint32_t index, int32_t value); - - TensorLayout SqueezeShape() const; - - private: - std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement &expanded_shape) const; - std::shared_ptr<Arrangement> ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; - bool IsValidTensorLayout() const; - void RemoveElementEqualToOneInDeviceArrangement(); - int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; - int32_t GetSliceNumByTensorDimensionIndex(uint32_t idx) const; - bool TensorShapeDimensionIsDividedBySplitDeviceDimension() const; - int32_t GetTensorDimensionIndexByDeviceDimensionIndex(int32_t idx) const; - - Arrangement device_arrangement_origin_; - Map tensor_map_origin_; - Arrangement tensor_shape_origin_; - Arrangement device_arrangement_; - Map tensor_map_; - Arrangement tensor_shape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_LAYOUT_H_ diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc deleted file mode 100644 index 7824c21f3d..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ /dev/null @@ -1,209 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "parallel/tensor_layout/tensor_redistribution.h" -#include -#include -#include -#include "common/utils.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/shape_util.h" - -namespace mindspore { -namespace parallel { -Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { - from_origin_ = from; - to_origin_ = to; - if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { - MS_LOG(ERROR) << "from shape size must be equal to to shape size!"; - MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString(); - MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString(); - return Status::FAILED; - } - - dev_list_ = dev_list; - from_ = from_origin_.SqueezeShape(); - to_ = to_origin_.SqueezeShape(); - return Status::SUCCESS; -} - -RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { - // Step 1: Match device arrangement between from_ and to_ - RedistributionLayoutTransfer layout_transfer; - Status status = layout_transfer.Init(from_, to_); - if (status != Status::SUCCESS) { - return nullptr; - } - std::shared_ptr<ReshapeLayoutTransfer> ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape(); - if (ptr == nullptr) { - MS_LOG(ERROR) << "Infer tensor layout return nullptr!"; - return nullptr; - } - TensorLayout from_layout = ptr->from_in(); - TensorLayout to_layout = ptr->to_in(); - MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString(); - MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString(); - MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString(); - MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString(); - MS_LOG(DEBUG) << "reshape from_ " << from_.ToString(); - MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); - // Step 2: Infer redistribution and insert operators - RedistributionOperatorInfer operator_infer(construct_op_flag_); - if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { - MS_LOG(ERROR) << "Init operatorInfer failed!"; - return nullptr; - } - OperatorVector operator_vector; - OutPutInfoVector output_info_vector; - if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) { - MS_LOG(ERROR) << "Infer redistribution failed!"; - return nullptr; - } else { - operator_vector = operator_infer.operator_vector(); - output_info_vector = operator_infer.output_info_vector(); - operator_list_ = operator_infer.operator_list(); - } - - // Step 3: Infer reshape and insert operators - if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) { - MS_LOG(ERROR) << "Construct Reshape operator failed!"; - return nullptr; - } - - return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>( - std::make_pair(operator_vector, output_info_vector)); -} - -Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, - OperatorVector *const operator_vector, - OutPutInfoVector *const output_info_vector) { - MS_EXCEPTION_IF_NULL(operator_vector); - MS_EXCEPTION_IF_NULL(output_info_vector);
- ConstructOperator constructor; - if (operator_list_.empty()) { - if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) { - reshape_flag_ = true; - constructor.UpdateTensorShape(from_origin_.slice_shape().array()); - Arrangement shape = to_origin_.slice_shape(); - MS_LOG(DEBUG) << "reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); - } - } - return Status::SUCCESS; - } - - if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) { - reshape_flag_ = true; - constructor.UpdateTensorShape(from_origin_.slice_shape().array()); - Arrangement shape = from_layout.slice_shape(); - MS_LOG(DEBUG) << "reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0)); - } - } - - if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) { - reshape_flag_ = true; - constructor.UpdateTensorShape(to_layout.slice_shape().array()); - Arrangement shape = to_origin_.slice_shape(); - MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString(); - if (constructor.ReshapeOP(shape.array()) == Status::FAILED) { - return Status::FAILED; - } else { - (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator()); - (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0)); - } - } - return Status::SUCCESS; -} - -Status TensorRedistribution::ComputeCost() { - RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); - if (redistribution_oplist_ptr == nullptr) { - MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; - return Status::FAILED; - } - // Compute redistribution communication cost and computation cost - for (auto &op_cost : operator_list_) { - OperatorR op = op_cost.first; - Shape slice_shape = op_cost.second; - double prod = - std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>()); - std::string str = op.first; - if (str == PERMUTE_BY_AXIS) { - // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
- // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape - forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; - backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; - comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; - int32_t concat_dim = op.second[2]; - if (concat_dim == 0) { - // memory cost = all_gather - computation_cost_ += prod; - memory_cost_ += prod; - } else { - // memory cost = all_gather + split + concat - int32_t dev_num = op.second[4]; - computation_cost_ += (prod + prod * dev_num + prod * dev_num); - memory_cost_ += (prod * dev_num + prod * dev_num + prod); - } - } else if (str == CONCAT_BY_AXIS) { - // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape - // computation cost = before_slice_shape - if (op.second.size() < 3) { - MS_LOG(ERROR) << "op.second size should not be less than 3!"; - return Status::FAILED; - } - double dev_num = op.second[2]; - // here, communication cost = all_gather + reduce_scatter - forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; - int32_t concat_dim = op.second[0]; - if (concat_dim == 0) { - // computation cost = all_gather - computation_cost_ += prod; - memory_cost_ += prod * dev_num; - } else { - // computation cost = all_gather + split + concat - computation_cost_ += (prod + prod * dev_num + prod * dev_num); - memory_cost_ += (prod * dev_num + prod * dev_num + prod); - } - } else { - // There is only computation cost in SplitByAxis. - // computation cost = before_slice_shape - computation_cost_ += prod; - // This addtion may be erroneous - memory_cost_ += prod; - } - } - if (reshape_flag()) { - Shape prev_slice_shape = from_.slice_shape().array(); - double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies()); - computation_cost_ += 2.0 * prev_prod; - memory_cost_ += 2.0 * prev_prod; - } - return Status::SUCCESS; -} -} // namespace parallel -} // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h deleted file mode 100644 index d1f46108bb..0000000000 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ /dev/null @@ -1,90 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
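To make the CONCAT_BY_AXIS bookkeeping in ComputeCost() above concrete: with the 0.5 factor (ALLGATHER_REDUCESCATTER_SCALE_FACTOR, defined as 0.5 in tensor_redistribution.h below), a slice of prod elements gathered across dev_num devices costs prod * dev_num * 0.5 forward (the all_gather), prod * 0.5 backward (the reduce_scatter), and prod * (dev_num + 1) * 0.5 in total. A small sketch under those assumptions (the names ConcatByAxisCost and CommCost are illustrative, not MindSpore API):

// Illustrative arithmetic for the CONCAT_BY_AXIS communication cost above.
#include <iostream>

constexpr double kAllGatherReduceScatterScale = 0.5;  // mirrors the 0.5 scale factor

struct CommCost {
  double forward;
  double backward;
  double total;
};

CommCost ConcatByAxisCost(double prod, double dev_num) {
  CommCost c;
  c.forward = prod * dev_num * kAllGatherReduceScatterScale;  // all_gather
  c.backward = prod * kAllGatherReduceScatterScale;           // reduce_scatter
  c.total = prod * (dev_num + 1.0) * kAllGatherReduceScatterScale;
  return c;
}

int main() {
  // a 1024-element slice gathered across 8 devices
  CommCost c = ConcatByAxisCost(1024.0, 8.0);
  std::cout << "forward=" << c.forward << " backward=" << c.backward
            << " total=" << c.total << std::endl;  // forward=4096 backward=512 total=4608
  return 0;
}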
- */ - -#ifndef MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ -#define MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ - -#include -#include -#include -#include -#include -#include - -#include "ir/value.h" -#include "parallel/ops_info/operator_info.h" -#include "parallel/status.h" -#include "parallel/tensor_layout/construct_operator.h" -#include "parallel/tensor_layout/redistribution_operator_infer.h" -#include "parallel/tensor_layout/tensor_layout.h" - -namespace mindspore { -namespace parallel { -constexpr double ALLTOALL_SCALE_FACTOR = 2.0; -constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; -class TensorRedistribution { - public: - explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) - : reshape_flag_(false), - comm_cost_(0.0), - forward_comm_cost_(0.0), - backward_comm_cost_(0.0), - computation_cost_(0.0), - memory_cost_(0.0), - construct_op_flag_(construct_op_flag), - keep_reshape_(keep_reshape) {} - Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); - ~TensorRedistribution() = default; - RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); - OperatorList operator_list() const { return operator_list_; } - bool reshape_flag() const { return reshape_flag_; } - Status ComputeCost(); - double comm_cost() const { return comm_cost_; } - double computation_cost() const { return computation_cost_; } - double forward_comm_cost() const { return forward_comm_cost_; } - double backward_comm_cost() const { return backward_comm_cost_; } - double memory_cost() const { return memory_cost_; } - - private: - Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, - OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); - - TensorLayout from_origin_; - TensorLayout to_origin_; - TensorLayout from_; - TensorLayout to_; - RankList dev_list_; - OperatorList operator_list_; - bool reshape_flag_; - // communication cost, which is the sum of forward communication cost and backward communication cost - double comm_cost_; - // forward communication cost - double forward_comm_cost_; - // backward communication cost - double backward_comm_cost_; - // computation_cost models the time spending on computing in this tensor redistribution, which is calculated by the - // inputs. This is calculated ONLY for forward phase. - double computation_cost_; - // memory_cost models the PEAK memory cost in a training iteration contributed by this tensor redistribution, which is - // calculated by the outputs. 
- double memory_cost_; - bool construct_op_flag_; - bool keep_reshape_; -}; -} // namespace parallel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PARALLEL_TENSOR_LAYOUT_TENSOR_REDISTRIBUTION_H_ diff --git a/mindspore/ccsrc/pipeline/CMakeLists.txt b/mindspore/ccsrc/pipeline/CMakeLists.txt deleted file mode 100644 index 39664d717d..0000000000 --- a/mindspore/ccsrc/pipeline/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "pipeline.cc" - "resource.cc" - "pass.cc" - "action.cc" - "validator.cc" - "remove_value_node_dup.cc" - "parse/*.cc" - "static_analysis/*.cc" -) - - -file(GLOB PIPELINE_SRC_FILES "*.cc") -set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) - -file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") -set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) - -file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") -set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) - -if (ENABLE_GE OR ENABLE_D) - file(GLOB_RECURSE _PIPELINE_GE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pipeline_ge.cc") - list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES}) -endif () - -add_library(_mindspore_pipeline_obj OBJECT ${_PIPELINE_SRC_FILES}) diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc deleted file mode 100644 index 3648bc991e..0000000000 --- a/mindspore/ccsrc/pipeline/action.cc +++ /dev/null @@ -1,494 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pipeline/action.h" - -#include -#include -#include -#include -#include -#include - -#include "ir/func_graph_cloner.h" -#include "ir/param_value.h" -#include "parallel/costmodel_context.h" -#include "parallel/context.h" -#include "pipeline/pass.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/data_converter.h" -#include "abstract/abstract_value.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "pipeline/static_analysis/program_specialize.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" -#include "pipeline/remove_value_node_dup.h" -#include "optimizer/optimizer.h" -#include "vm/transform.h" -#include "parse/python_adapter.h" -#include "optimizer/py_pass_manager.h" - -namespace mindspore { -namespace pipeline { -using CompileGraphs = compile::CompileGraphs; -using abstract::AnalysisResult; -using mindspore::abstract::AnalysisContextPtr; - -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec, bool clear) { - MS_LOG(DEBUG) << "AbstractAnalyze start"; - auto engine = res->engine(); - MS_EXCEPTION_IF_NULL(engine); - if (clear) { - auto manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - engine->Clear(); - for (auto &node : manager->all_nodes()) { - MS_EXCEPTION_IF_NULL(node); - const AbstractBasePtr &prev_inferred = node->abstract(); - // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. - if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { - node->set_abstract(nullptr); - MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr"; - } - } - } - auto ret = engine->Run(func_graph, args_spec); - MS_LOG(DEBUG) << "AbstractAnalyze end"; - return ret; -} - -FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context) { - MS_LOG(DEBUG) << "ProgramSpecialize start"; - abstract::ProgramSpecializer spc(res->engine()); - FuncGraphPtr result = spc.Run(func_graph, context); - auto manager = res->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->KeepRoots({result}); - MS_LOG(DEBUG) << "ProgramSpecialize end"; - return result; -} - -FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec) { - MS_LOG(DEBUG) << "Renormalize start"; -#ifdef ENABLE_PROFILE - double t1 = GetTime(); -#endif - abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true); -#ifdef ENABLE_PROFILE - double t2 = GetTime(); -#endif - auto ret = ProgramSpecialize(res, func_graph, result.context); - res->set_func_graph(ret); -#ifdef ENABLE_PROFILE - double t3 = GetTime(); - MsProfile::StatTime("renormalize.infer", t2 - t1); - MsProfile::StatTime("renormalize.specialize", t3 - t2); -#endif - MS_LOG(DEBUG) << "Renormalize end"; - return ret; -} - -bool ParseAction(const ResourcePtr &res) { - if (!res->input()) { - MS_LOG(EXCEPTION) << "Parse error"; - } - - py::object input = res->input(); - parse::Parser::InitParserEnvironment(input); - py::module path = py::module::import("os.path"); - std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast(); - - parse::python_adapter::set_python_env_flag(true); - parse::python_adapter::SetPythonPath(dir); - - FuncGraphPtr fg = parse::ConvertToFuncGraph(input); - if (fg == nullptr) { - MS_LOG(EXCEPTION) << "Parse error."; - } - res->set_func_graph(fg); - - 
FuncGraphManagerPtr manager = res->manager(); - if (manager == nullptr) { - MS_LOG(EXCEPTION) << "Manager is nullptr."; - } - manager->AddFuncGraph(fg); - return true; -} - -// obj_map's graphs have the same construct, these graphs can be optimized to one graph. -// This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> -// graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} -// all obj_map's graph shared base_graph -bool CombineLikeGraphs(const ResourcePtr &res) { - auto &obj_map = parse::data_converter::GetObjGraphs(); - - for (auto it : obj_map) { - auto &graphs = it.second; - MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); - auto fg = graphs[0]; - FuncGraphPtrList func_graphs = {fg}; - ClonerPtr cloner = std::make_shared(func_graphs, false, false, true, std::make_shared(), - std::make_shared()); - cloner->Run(); - auto base_graph = cloner->cloned_func_graph()[fg]; - MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString(); - - if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) { - continue; - } - for (auto &fv : fg->paramter_obj_nodes()) { - TraceManager::DebugTrace(std::make_shared(fv->debug_info())); - auto param = base_graph->add_parameter(); - TraceManager::EndTrace(); - auto &node_users = res->manager()->node_users()[fv]; - for (auto &n : node_users) { - auto repl_n = (*cloner->cloned_node())[n.first]->cast(); - repl_n->set_input(n.second, param); - } - } - MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); - - for (auto &g : graphs) { - auto fvs = g->paramter_obj_nodes(); - std::vector new_node_inputs; - new_node_inputs.push_back(NewValueNode(base_graph)); - for (auto &p : g->parameters()) { - AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); - new_node_inputs.push_back(para_after_cast); - } - (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end()); - AnfNodePtr out = g->NewCNode(new_node_inputs); - g->set_output(out); - MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4); - } - MS_LOG(DEBUG) << "End combine graph:" << it.first; - } - return true; -} - -bool SymbolResolveAction(const ResourcePtr &res) { - if (res->manager() == nullptr) { - MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; - } - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; - } - FuncGraphPtr func_graph = res->func_graph(); - auto succ = parse::ResolveFuncGraph(func_graph, res); - - // Remove unused nodes in cnode order list. 
- func_graph->EraseUnusedNodeInOrder(); - func_graph->ReleaseFullOrderToEffectOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - fg->EraseUnusedNodeInOrder(); - fg->ReleaseFullOrderToEffectOrder(); - } - return succ; -} - -bool InferenceOptPrepareAction(const ResourcePtr &res) { - if (res->manager() == nullptr) { - MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; - } - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null."; - } - return InferenceOptPreparePass(res); -} - -bool AbstractSpecializeAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "AbstractSpecialize error"; - } - - FuncGraphPtr func_graph = res->func_graph(); - abstract::AbstractBasePtrList args_spec = res->args_spec(); - - parallel::ParallelParameterContextInit(func_graph); - - // suppose that there is not KeywordArgument for the top graph - // get the hyper parameter - for (const auto ¶m : func_graph->parameters()) { - auto param_node = std::static_pointer_cast(param); - if (param_node->has_default()) { - const auto ¶m_value = param_node->default_param(); - ValuePtr value = param_value->value(); - constexpr bool broaden = true; - AbstractBasePtr ptr = abstract::FromValue(value, broaden); - - parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr); - args_spec.push_back(ptr); - parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr); - } - } - // Analyze - AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec); - // The top graph may be replaced by infer, update the top graph when the infer is done - parse::Parser::UpdateTopFuncGraph(result.context->func_graph()); - - // Specialize - FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context); - res->set_func_graph(new_fg); - - MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true); - return true; -} - -bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { - size_t counter = 0; - for (auto &pass : passes) { - WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { - MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; - auto result = pass.second(res); - if (!result) { - MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; - } - if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { - auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; - auto func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - func_graph->DumpFuncGraph(fg_name); - DumpIR(fg_name + ".ir", func_graph); - MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; - } - counter++; - MS_LOG(DEBUG) << "Pass " << pass.first << " end."; - }; - } - - return true; -} - -bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } - -bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } - -bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } - -static bool IsCtrlSink() { - auto ms_ctx = MsContext::GetInstance(); - if (ms_ctx->execution_mode() != kGraphMode) { - return false; - } - - std::string device_target = ms_ctx->device_target(); - if (device_target != kAscendDevice) { - return false; - } - - if (!ms_ctx->enable_task_sink()) { - return false; - } - - if 
(!ms_ctx->is_multi_graph_sink()) { - return false; - } - return true; -} - -bool TaskEmitAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "TaskEmit args error"; - } - FuncGraphPtr func_graph = res->func_graph(); - auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (CompileGraphs::ContainMixedTarget(func_graph)) { - bc_ptr->set_is_multi_graph_sink(false); - context_ptr->set_is_multi_graph_sink(false); - context_ptr->set_loop_sink_flag(false); - } else if (context_ptr->execution_mode() != kPynativeMode) { - std::string device_target = context_ptr->device_target(); - if (device_target == kAscendDevice) { - bc_ptr->set_is_multi_graph_sink(true); - context_ptr->set_is_multi_graph_sink(true); - } - } - - if (IsCtrlSink()) { - res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); - return true; - } - std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; - if (bc_ptr->name() == kMsConvert) { - cut_list = compile::GetMsNonlinearOps(); - } - std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); - res->results()[kOutput] = compile->CompileAndLink(func_graph); - return true; -} - -bool ExecuteAction(const ResourcePtr &res) { - if (res->results().count(kOutput) == 0) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - - if (IsCtrlSink()) { - if (!res->results()[kOutput].is<GraphId>()) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - auto graph_id = res->results()[kOutput].cast<GraphId>(); - std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); - std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); - MS_EXCEPTION_IF_NULL(msbc_ptr); - compile::VmEvalFuncPtr run = - std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef { - MS_LOG(INFO) << "Execute args size " << args.size(); - auto outs = msbc_ptr->RunGraph(graph_id, args); - MS_LOG(DEBUG) << "out size " << outs.size(); - return outs[0]; - }); - res->results()[kOutput] = run; - return true; - } - - if (!res->results()[kOutput].is<compile::FinalVMPtr>()) { - MS_LOG(EXCEPTION) << "Execute args error"; - } - compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>(); - if (vm == nullptr) { - MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; - return true; - } - compile::VmEvalFuncPtr run = - std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1)); - res->results()[kOutput] = run; - return true; -} - -// The parallel primitive related valuenode might be partitioned so that its value changes by device, -// that will result in a synchronization error due to different executing order. -// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, -// the final solution will be proposed later as a parallel feature.
-bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { - auto &node_users = res->manager()->node_users(); - auto &users = node_users[value_node]; - auto used_by_keep_value_prim = - std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &user) -> bool { - MS_EXCEPTION_IF_NULL(user.first); - auto cnode = user.first->cast<CNodePtr>(); - if (cnode == nullptr) { - return false; - } - auto prim_node = cnode->input(0); - if (IsValueNode<Primitive>(prim_node)) { - auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value()); - // value_node is referenced by some parallel primitive - return prim->HasAttr("keep_value_node_input"); - } - return false; - }); - return used_by_keep_value_prim; -} - -bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { - if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "Remove value node duplications error."; - } - FuncGraphPtr func_graph = res->func_graph(); - auto manager = res->manager(); - // Remove duplicated value nodes, due to replace operation, can't use reference. - auto value_nodes = func_graph->value_nodes(); - HashCache hash_cache; - HashValue hashes; - for (const auto &value_pair : value_nodes) { - if (KeepValueNodeDuplication(value_pair.first, res)) { - continue; - } - TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes); - } - return true; -} - -bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } - -void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { - MS_EXCEPTION_IF_NULL(res->manager()); - MS_EXCEPTION_IF_NULL(res->func_graph()); - auto ppm = opt::python_pass::PyPassManager::GetInstance(); - if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { - MS_LOG(DEBUG) << "No match.\n"; - } -} - -bool ResolveActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::RESOLVE); - return true; -} - -bool OptActionPyStub(const ResourcePtr &res) { - ActionPyStub(res, opt::python_pass::Phase::OPT); - return true; -} - -static std::vector<ActionItem> CommonPipeline() { - std::vector<ActionItem> actions; - - // Parse the python ast to ANF graph - actions.emplace_back(std::make_pair("parse", ParseAction)); - - // Resolve the python func - actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); - auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); - if (!multi_graphs) { - actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); - } - // Add resolve-stage python pass stub - actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); - actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); - // Evaluate type and shape, and specialize - actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); - - return actions; -} - -std::vector<ActionItem> GePipeline() { - auto actions = CommonPipeline(); - // optimize - actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); - // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); - actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); - actions.emplace_back(std::make_pair("validate", ValidateAction)); - return actions; -} - -std::vector<ActionItem> VmPipeline() { - auto actions = CommonPipeline(); - - // optimize - actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); - - // Add opt-stage python pass stub - actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); - 
actions.emplace_back(std::make_pair("validate", ValidateAction)); - - // compile the ANF graph - actions.emplace_back(std::make_pair("task_emit", TaskEmitAction)); - - // to execute the graph - actions.emplace_back(std::make_pair("execute", ExecuteAction)); - - return actions; -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/action.h b/mindspore/ccsrc/pipeline/action.h deleted file mode 100644 index eed1307872..0000000000 --- a/mindspore/ccsrc/pipeline/action.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_ACTION_H_ -#define MINDSPORE_CCSRC_PIPELINE_ACTION_H_ - -#include -#include -#include -#include -#include "pipeline/resource.h" -#include "vm/segment_runner.h" - -namespace mindspore { -extern const char kMsConvert[]; - -namespace pipeline { -using ActionItem = std::pair<std::string, std::function<bool(ResourcePtr)>>; - -bool ParseAction(const ResourcePtr &res); -bool SymbolResolveAction(const ResourcePtr &res); -bool AbstractSpecializeAction(const ResourcePtr &res); -bool GeOptimizeAction(const ResourcePtr &res); -bool VmOptimizeAction(const ResourcePtr &res); -bool PynativeOptimizeAction(const ResourcePtr &res); -bool TaskEmitAction(const ResourcePtr &res); -bool ExecuteAction(const ResourcePtr &res); - -std::vector<ActionItem> GePipeline(); -std::vector<ActionItem> VmPipeline(); -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec, bool clear = false); -FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AnalysisContextPtr &context); -FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, - const abstract::AbstractBasePtrList &args_spec); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_ACTION_H_ diff --git a/mindspore/ccsrc/pipeline/base.h b/mindspore/ccsrc/pipeline/base.h deleted file mode 100644 index 57edea03a2..0000000000 --- a/mindspore/ccsrc/pipeline/base.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
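The ActionItem type deleted above is just a named predicate over the compile resource, and the pipelines (GePipeline, VmPipeline) are ordered vectors of them; the runner walks the vector in order and stops at the first failing action. A reduced sketch of that pattern with stand-in types (Resource and RunPipeline here are placeholders, not the MindSpore classes):

// Illustrative sketch of the named-action pipeline pattern above.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Resource {};  // stand-in for pipeline::Resource

using ActionItem = std::pair<std::string, std::function<bool(Resource &)>>;

bool RunPipeline(const std::vector<ActionItem> &actions, Resource &res) {
  for (const auto &action : actions) {
    std::cout << "running action: " << action.first << std::endl;
    if (!action.second(res)) {
      std::cerr << "action failed: " << action.first << std::endl;
      return false;  // abort on the first failing action
    }
  }
  return true;
}

int main() {
  Resource res;
  std::vector<ActionItem> pipeline = {
      {"parse", [](Resource &) { return true; }},
      {"symbol_resolve", [](Resource &) { return true; }},
      {"execute", [](Resource &) { return true; }},
  };
  return RunPipeline(pipeline, res) ? 0 : 1;
}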
- */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_BASE_H_ -#define MINDSPORE_CCSRC_PIPELINE_BASE_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace pipeline { -struct ExecutorInfo { - FuncGraphPtr func_graph; - ResourcePtr resource; - std::size_t arg_list_size; -}; -using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>; - -inline std::string GetPhasePrefix(const std::string &phase) { - auto pos = phase.find('.'); - if (pos == std::string::npos) { - MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; - } - return phase.substr(0, pos); -} - -inline std::string GetFilePathName(const std::string &file_name) { - std::ostringstream oss; - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(EXCEPTION) << "ms_context is nullptr"; - } - auto save_graphs_path = ms_context->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - oss << save_graphs_path << "/" << file_name; - return oss.str(); -} -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_BASE_H_ diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc deleted file mode 100644 index f18178f19a..0000000000 --- a/mindspore/ccsrc/pipeline/init.cc +++ /dev/null @@ -1,336 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -#include -#include -#include "kernel/oplib/oplib.h" -#include "kernel/oplib/oploader.h" -#include "pipeline/pipeline.h" -#include "operator/composite/composite.h" -#include "ir/signature.h" -#include "pynative/pynative_execute.h" -#include "utils/symbolic.h" -#include "pybind_api/api_register.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/summary/event_writer.h" -#include "utils/config_manager.h" -#include "utils/mpi/mpi_config.h" -#include "parallel/context.h" -#include "parallel/device_manager.h" -#include "parallel/costmodel_context.h" -#ifdef ENABLE_GPU_COLLECTIVE -#include "device/gpu/distribution/collective_init.h" -#else -#include "device/gpu/distribution/collective_fake_init.h" -#endif -namespace py = pybind11; - -using EnvInstance = mindspore::EnvInstance; -using ExecutorPy = mindspore::pipeline::ExecutorPy; -using Pipeline = mindspore::pipeline::Pipeline; -using PrimitivePy = mindspore::PrimitivePy; -using MetaFuncGraph = mindspore::MetaFuncGraph; -using EventWriter = mindspore::summary::EventWriter; -using OpLib = mindspore::kernel::OpLib; -using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy; -using ParallelContext = mindspore::parallel::ParallelContext; -using CostModelContext = mindspore::parallel::CostModelContext; - -// Interface with python -PYBIND11_MODULE(_c_expression, m) { - m.doc() = "MindSpore c plugin"; - - auto fns = mindspore::PybindDefineRegister::AllFuncs(); - for (auto &item : fns) { - item.second(&m); - } - - // Class Pipeline interface - (void)py::class_<ExecutorPy, std::shared_ptr<ExecutorPy>>(m, "Executor_") - .def_static("get_instance", &ExecutorPy::GetInstance, "Executor get_instance.") - .def("__call__", &ExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.") - .def("del_net_res", &ExecutorPy::DelNetRes, py::arg("network_id") = py::str(""), "Delete network resource.") - .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.") - .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""), - py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") - .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), - py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") - .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"), - "Get Parameter Tensor Layout Dictionary.") - .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"), - "Get CNode Strategy Dictionary.") - .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"), - "Get Allreduce Fusion Dictionary.") - .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"), - "Fetch the inputs of Conv or Matmul for quant export.") - .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"), - py::arg("broadcast_params") = py::dict(), "Build data graph.") - .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") - .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); - - (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_") - .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) - .def(py::init()); - - (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key."); - (void)m.def("real_run_op",
&mindspore::pynative::RunOp, "Run op pynatively."); - (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id"); - (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl"); - (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl"); - (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature."); - (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"), - py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"), - py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset."); - (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode."); - (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend."); - - (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); - - (void)py::class_>(m, "MSContext") - .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") - .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") - .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.") - .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.") - .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.") - .def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.") - .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.") - .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.") - .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") - .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") - .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") - .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.") - .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") - .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") - .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") - .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, - "Get whether to enable auto mixed precision.") - .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, - "Set whether to enable auto mixed precision.") - .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, - "Get whether to enable reduce precision.") - .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, - "Set whether to enable reduce precision.") - .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") - .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") - .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.") - .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.") - .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.") - .def("set_save_ms_model_path", 
&mindspore::MsContext::set_save_ms_model_path, "Set path to save ms model") - .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.") - .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.") - .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.") - .def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.") - .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "set graph memory max size.") - .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size, - "set variable memory max size") - .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.") - .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.") - .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.") - .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.") - .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.") - .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.") - .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get deivce memory max size.") - .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set deivce memory max size.") - .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.") - .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel, - "Set the GraphKernel switch to on or off.") - .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.") - .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.") - .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity."); - - (void)py::class_>(m, "MpiConfig") - .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") - .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") - .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi."); - - (void)py::class_>(m, "AutoParallelContext") - .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.") - .def("get_device_num", &ParallelContext::device_num, "Get device num.") - .def("set_device_num", &ParallelContext::set_device_num, "Set device num.") - .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.") - .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.") - .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.") - .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.") - .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.") - .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.") - .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.") - .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.") - .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.") - .def("set_loss_repeated_mean", 
&ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.") - .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.") - .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.") - .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.") - .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") - .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") - .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") - .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, - "Set all reduce fusion split indices.") - .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices, - "Get all reduce fusion split indices.") - .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes, - "Set all reduce fusion split sizes.") - .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes, - "Get all reduce fusion split sizes.") - .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, - "Set enable/disable all reduce fusion.") - .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion, - "Get enable/disable all reduce fusion.") - .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.") - .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set, - "Get parameter broadcast is set.") - .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.") - .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file, - "Set strategy checkpoint load file.") - .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file, - "Set strategy checkpoint save file.") - .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.") - .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.") - .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.") - .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.") - .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer, - "Set enable/disable parallel optimizer.") - .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, - "Get enable/disable parallel optimizer.") - .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); - - (void)py::class_>(m, "CostModelContext") - .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.") - .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity, - "Set the capacity of device memory.") - .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.") - .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha, - "Set the parameter cost_model_alpha of the DP algorithm.") - .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha, - "Get the parameter cost_model_alpha of the DP algorithm.") - .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta, - "Set the 
parameter cost_model_beta of the DP algorithm.") - .def("get_costmodel_beta", &CostModelContext::costmodel_beta, - "Get the parameter cost_model_beta of the DP algorithm.") - .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma, - "Set the parameter cost_model_gamma of the DP algorithm") - .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma, - "Get the parameter cost_model_gamma of the DP algorithm.") - .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold, - "Set the parameter cost_model_communi_threshold of the DP algorithm.") - .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold, - "Get the parameter cost_model_communi_threshold of the DP algorithm.") - .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const, - "Set the parameter cost_model_communi_const of the DP algorithm.") - .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const, - "Get the parameter cost_model_communi_const of the DP algorithm.") - .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias, - "Set the parameter cost_model_communi_bias of the DP algorithm.") - .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, - "Get the parameter cost_model_communi_bias of the DP algorithm.") - .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") - .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") - .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") - .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") - .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, - "Set the parameter gradient AllReduce fusion algorithm.") - .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, - "Get the parameter gradient AllReduce fusion algorithm.") - .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times, - "Set the parameter gradient AllReduce times.") - .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times, - "Get the parameter gradient AllReduce times.") - .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent, - "Set the parameter gradient AllReduce fusion tail percent.") - .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent, - "Get the parameter gradient AllReduce fusion tail percent.") - .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time, - "Set the parameter gradient AllReduce fusion tail time.") - .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time, - "Get the parameter gradient AllReduce fusion tail time.") - .def("set_costmodel_allreduce_fusion_allreduce_inherent_time", - &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time, - "Set the parameter gradient AllReduce fusion allreduce inherent time.") - .def("get_costmodel_allreduce_fusion_allreduce_inherent_time", - &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time, - "Get the parameter gradient AllReduce fusion allreduce inherent time.") - 
.def("set_costmodel_allreduce_fusion_allreduce_bandwidth", - &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth, - "Set the parameter gradient AllReduce fusion allreduce bandwidth.") - .def("get_costmodel_allreduce_fusion_allreduce_bandwidth", - &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth, - "Get the parameter gradient AllReduce fusion allreduce bandwidth.") - .def("set_costmodel_allreduce_fusion_computation_time_parameter", - &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter, - "Set the parameter gradient AllReduce fusion computation time parameter.") - .def("get_costmodel_allreduce_fusion_computation_time_parameter", - &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter, - "Get the parameter gradient AllReduce fusion computation time parameter.") - .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable, - "Set the parameter tensor_slice_align_enable in strategy generation.") - .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable, - "Get the parameter tensor_slice_align_enable in strategy generation.") - .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size, - "Set the parameter tensor_slice_size in strategy generation.") - .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size, - "Get the parameter tensor_slice_size in strategy generation.") - .def("set_fully_use_devices", &CostModelContext::set_fully_use_device, - "Set the parameter fully_use_devices in the DP algorithm.") - .def("get_fully_use_devices", &CostModelContext::fully_use_device, - "Get the parameter fully_use_devices in the DP algorithm.") - .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow, - "Set the parameter elementwise_op_strategy_follow in the DP algorithm.") - .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow, - "Get the parameter elementwise_op_strategy_follow in the DP algorithm.") - .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.") - .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters."); - - (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void { - // only in case that c++ calling python interface, ClearResAtexit should be called. 
- if (mindspore::parse::python_adapter::IsPythonEnv()) { - mindspore::pipeline::ClearResAtexit(); - -#ifdef ENABLE_MINDDATA - py::module iterators = py::module::import("mindspore.dataset.engine.iterators"); - (void)iterators.attr("_cleanup")(); -#endif - } - }}); - - (void)py::class_>(m, "EventWriter_") - .def(py::init()) - .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") - .def("Open", &EventWriter::Open, "Open the write file.") - .def("Write", &EventWriter::Write, "Write the serialize event.") - .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.") - .def("Flush", &EventWriter::Flush, "Flush the event.") - .def("Close", &EventWriter::Close, "Close the write.") - .def("Shut", &EventWriter::Shut, "Final close the write."); - - (void)py::class_>(m, "Oplib") - .def(py::init()) - .def_static("reg_op", &OpLib::RegOp, "Register op info."); -#ifdef ENABLE_GPU_COLLECTIVE - (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective, - "Init gpu collective communication mode."); - (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective, - "Finalize gpu collective communication mode."); -#else - (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective, - "Init gpu collective communication mode."); - (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective, - "Finalize gpu collective communication mode."); - -#endif - - (void)py::class_>(m, "OpInfoLoaderPy") - .def(py::init()) - .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info."); -} diff --git a/mindspore/ccsrc/pipeline/jit/CMakeLists.txt b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt new file mode 100644 index 0000000000..6188546ce5 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/CMakeLists.txt @@ -0,0 +1,27 @@ +file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + "pipeline.cc" + "resource.cc" + "pass.cc" + "action.cc" + "validator.cc" + "remove_value_node_dup.cc" + "parse/*.cc" + "static_analysis/*.cc" +) + + +file(GLOB PIPELINE_SRC_FILES "*.cc") +set_property(SOURCE ${PIPELINE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PIPELINE) + +file(GLOB_RECURSE PARSER_SRC_FILES "parse/*.cc") +set_property(SOURCE ${PARSER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PARSER) + +file(GLOB_RECURSE ANALYZER_SRC_FILES "static_analysis/*.cc") +set_property(SOURCE ${ANALYZER_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ANALYZER) + +if (ENABLE_GE OR ENABLE_D) + file(GLOB_RECURSE _PIPELINE_GE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pipeline_ge.cc") + list(APPEND _PIPELINE_SRC_FILES ${_PIPELINE_GE_SRC_FILES}) +endif () + +add_library(_mindspore_pipeline_jit_obj OBJECT ${_PIPELINE_SRC_FILES}) diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc new file mode 100644 index 0000000000..74eb9f3f9b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -0,0 +1,494 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/action.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include <algorithm>
+#include <functional>
+
+#include "ir/func_graph_cloner.h"
+#include "ir/param_value.h"
+#include "frontend/parallel/costmodel_context.h"
+#include "frontend/parallel/context.h"
+#include "pipeline/jit/pass.h"
+#include "pipeline/jit/parse/parse_base.h"
+#include "pipeline/jit/parse/data_converter.h"
+#include "abstract/abstract_value.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/program_specialize.h"
+#include "pipeline/jit/resource.h"
+#include "utils/context/ms_context.h"
+#include "pipeline/jit/remove_value_node_dup.h"
+#include "frontend/optimizer/optimizer.h"
+#include "vm/transform.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "frontend/optimizer/py_pass_manager.h"
+
+namespace mindspore {
+namespace pipeline {
+using CompileGraphs = compile::CompileGraphs;
+using abstract::AnalysisResult;
+using mindspore::abstract::AnalysisContextPtr;
+
+abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                                         const abstract::AbstractBasePtrList &args_spec, bool clear) {
+  MS_LOG(DEBUG) << "AbstractAnalyze start";
+  auto engine = res->engine();
+  MS_EXCEPTION_IF_NULL(engine);
+  if (clear) {
+    auto manager = res->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    engine->Clear();
+    for (auto &node : manager->all_nodes()) {
+      MS_EXCEPTION_IF_NULL(node);
+      const AbstractBasePtr &prev_inferred = node->abstract();
+      // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
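+      // That is: reset the abstract unless the node is a ValueNode whose inferred value is a
+      // plain (non-function) constant, which can be reused safely on re-inference.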
+      if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
+        node->set_abstract(nullptr);
+        MS_LOG(DEBUG) << "Abstract of node " << node->ToString() << " is set to nullptr";
+      }
+    }
+  }
+  auto ret = engine->Run(func_graph, args_spec);
+  MS_LOG(DEBUG) << "AbstractAnalyze end";
+  return ret;
+}
+
+FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                               const abstract::AnalysisContextPtr &context) {
+  MS_LOG(DEBUG) << "ProgramSpecialize start";
+  abstract::ProgramSpecializer spc(res->engine());
+  FuncGraphPtr result = spc.Run(func_graph, context);
+  auto manager = res->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->KeepRoots({result});
+  MS_LOG(DEBUG) << "ProgramSpecialize end";
+  return result;
+}
+
+FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                         const abstract::AbstractBasePtrList &args_spec) {
+  MS_LOG(DEBUG) << "Renormalize start";
+#ifdef ENABLE_PROFILE
+  double t1 = GetTime();
+#endif
+  abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true);
+#ifdef ENABLE_PROFILE
+  double t2 = GetTime();
+#endif
+  auto ret = ProgramSpecialize(res, func_graph, result.context);
+  res->set_func_graph(ret);
+#ifdef ENABLE_PROFILE
+  double t3 = GetTime();
+  MsProfile::StatTime("renormalize.infer", t2 - t1);
+  MsProfile::StatTime("renormalize.specialize", t3 - t2);
+#endif
+  MS_LOG(DEBUG) << "Renormalize end";
+  return ret;
+}
+
+bool ParseAction(const ResourcePtr &res) {
+  if (!res->input()) {
+    MS_LOG(EXCEPTION) << "Parse error";
+  }
+
+  py::object input = res->input();
+  parse::Parser::InitParserEnvironment(input);
+  py::module path = py::module::import("os.path");
+  std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
+
+  parse::python_adapter::set_python_env_flag(true);
+  parse::python_adapter::SetPythonPath(dir);
+
+  FuncGraphPtr fg = parse::ConvertToFuncGraph(input);
+  if (fg == nullptr) {
+    MS_LOG(EXCEPTION) << "Parse error.";
+  }
+  res->set_func_graph(fg);
+
+  FuncGraphManagerPtr manager = res->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Manager is nullptr.";
+  }
+  manager->AddFuncGraph(fg);
+  return true;
+}
+
+// The graphs in obj_map share the same construct, so they can be merged into one graph.
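+// (Here "fv" denotes a free-variable input: a parameter object captured from outside the graph.)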
+// This step performs that optimization: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)} ->
+// graph1(x){base_graph(x, fv1, fv2)}, graph2(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx};
+// all graphs in obj_map then share base_graph.
+bool CombineLikeGraphs(const ResourcePtr &res) {
+  auto &obj_map = parse::data_converter::GetObjGraphs();
+
+  for (auto it : obj_map) {
+    auto &graphs = it.second;
+    MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
+    auto fg = graphs[0];
+    FuncGraphPtrList func_graphs = {fg};
+    ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
+                                                std::make_shared<TraceCombileLikeGraphs>());
+    cloner->Run();
+    auto base_graph = cloner->cloned_func_graph()[fg];
+    MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
+
+    if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) {
+      continue;
+    }
+    for (auto &fv : fg->paramter_obj_nodes()) {
+      TraceManager::DebugTrace(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
+      auto param = base_graph->add_parameter();
+      TraceManager::EndTrace();
+      auto &node_users = res->manager()->node_users()[fv];
+      for (auto &n : node_users) {
+        auto repl_n = (*cloner->cloned_node())[n.first]->cast<CNodePtr>();
+        repl_n->set_input(n.second, param);
+      }
+    }
+    MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size: " << fg->paramter_obj_nodes().size();
+
+    for (auto &g : graphs) {
+      auto fvs = g->paramter_obj_nodes();
+      std::vector<AnfNodePtr> new_node_inputs;
+      new_node_inputs.push_back(NewValueNode(base_graph));
+      for (auto &p : g->parameters()) {
+        AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
+        new_node_inputs.push_back(para_after_cast);
+      }
+      (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end());
+      AnfNodePtr out = g->NewCNode(new_node_inputs);
+      g->set_output(out);
+      MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4);
+    }
+    MS_LOG(DEBUG) << "End combine graph:" << it.first;
+  }
+  return true;
+}
+
+bool SymbolResolveAction(const ResourcePtr &res) {
+  if (res->manager() == nullptr) {
+    MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
+  }
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
+  }
+  FuncGraphPtr func_graph = res->func_graph();
+  auto succ = parse::ResolveFuncGraph(func_graph, res);
+
+  // Remove unused nodes in cnode order list.
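+  // This also prunes the order lists of every graph transitively used by the root,
+  // so later passes never iterate over nodes dropped during resolution.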
+  func_graph->EraseUnusedNodeInOrder();
+  func_graph->ReleaseFullOrderToEffectOrder();
+  for (auto fg : func_graph->func_graphs_used_total()) {
+    MS_EXCEPTION_IF_NULL(fg);
+    fg->EraseUnusedNodeInOrder();
+    fg->ReleaseFullOrderToEffectOrder();
+  }
+  return succ;
+}
+
+bool InferenceOptPrepareAction(const ResourcePtr &res) {
+  if (res->manager() == nullptr) {
+    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
+  }
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
+  }
+  return InferenceOptPreparePass(res);
+}
+
+bool AbstractSpecializeAction(const ResourcePtr &res) {
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "AbstractSpecialize error";
+  }
+
+  FuncGraphPtr func_graph = res->func_graph();
+  abstract::AbstractBasePtrList args_spec = res->args_spec();
+
+  parallel::ParallelParameterContextInit(func_graph);
+
+  // Suppose there is no KeywordArgument for the top graph; get the hyper parameters.
+  for (const auto &param : func_graph->parameters()) {
+    auto param_node = std::static_pointer_cast<Parameter>(param);
+    if (param_node->has_default()) {
+      const auto &param_value = param_node->default_param();
+      ValuePtr value = param_value->value();
+      constexpr bool broaden = true;
+      AbstractBasePtr ptr = abstract::FromValue(value, broaden);
+
+      parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
+      args_spec.push_back(ptr);
+      parallel::ParallelParameterContextCkptInTraining(func_graph, param_node, ptr);
+    }
+  }
+  // Analyze
+  AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
+  // The top graph may be replaced by infer; update the top graph when the infer is done.
+  parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
+
+  // Specialize
+  FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context);
+  res->set_func_graph(new_fg);
+
+  MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
+  return true;
+}
+
+bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
+  size_t counter = 0;
+  for (auto &pass : passes) {
+    WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
+      MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
+      auto result = pass.second(res);
+      if (!result) {
+        MS_LOG(EXCEPTION) << "Pass running to end, failed in pass: " << pass.first;
+      }
+      if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) {
+        auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
+        auto func_graph = res->func_graph();
+        MS_EXCEPTION_IF_NULL(func_graph);
+        func_graph->DumpFuncGraph(fg_name);
+        DumpIR(fg_name + ".ir", func_graph);
+        MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
+      }
+      counter++;
+      MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
+    };
+  }
+
+  return true;
+}
+
+bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
+
+bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
+
+bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }
+
+static bool IsCtrlSink() {
+  auto ms_ctx = MsContext::GetInstance();
+  if (ms_ctx->execution_mode() != kGraphMode) {
+    return false;
+  }
+
+  std::string device_target = ms_ctx->device_target();
+  if (device_target != kAscendDevice) {
+    return false;
+  }
+
+  if (!ms_ctx->enable_task_sink()) {
+    return false;
+  }
+
+  if (!ms_ctx->is_multi_graph_sink()) {
+    return false;
+  }
+  return true;
+}
+
+bool TaskEmitAction(const ResourcePtr &res) {
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "TaskEmit args error";
+  }
+  FuncGraphPtr func_graph = res->func_graph();
+  auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (CompileGraphs::ContainMixedTarget(func_graph)) {
+    bc_ptr->set_is_multi_graph_sink(false);
+    context_ptr->set_is_multi_graph_sink(false);
+    context_ptr->set_loop_sink_flag(false);
+  } else if (context_ptr->execution_mode() != kPynativeMode) {
+    std::string device_target = context_ptr->device_target();
+    if (device_target == kAscendDevice) {
+      bc_ptr->set_is_multi_graph_sink(true);
+      context_ptr->set_is_multi_graph_sink(true);
+    }
+  }
+
+  if (IsCtrlSink()) {
+    res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
+    return true;
+  }
+  std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
+  if (bc_ptr->name() == kMsConvert) {
+    cut_list = compile::GetMsNonlinearOps();
+  }
+  std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
+  res->results()[kOutput] = compile->CompileAndLink(func_graph);
+  return true;
+}
+
+bool ExecuteAction(const ResourcePtr &res) {
+  if (res->results().count(kOutput) == 0) {
+    MS_LOG(EXCEPTION) << "Execute args error";
+  }
+
+  if (IsCtrlSink()) {
+    if (!res->results()[kOutput].is<GraphId>()) {
+      MS_LOG(EXCEPTION) << "Execute args error";
+    }
+    auto graph_id = res->results()[kOutput].cast<GraphId>();
+    std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
+    std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
+    MS_EXCEPTION_IF_NULL(msbc_ptr);
+    compile::VmEvalFuncPtr run =
+      std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
+        MS_LOG(INFO) << "Execute args size " << args.size();
+        auto outs = msbc_ptr->RunGraph(graph_id, args);
+        MS_LOG(DEBUG) << "out size " << outs.size();
+        return outs[0];
+      });
+    res->results()[kOutput] = run;
+    return true;
+  }
+
+  if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
+    MS_LOG(EXCEPTION) << "Execute args error";
+  }
+  compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
+  if (vm == nullptr) {
+    MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
+    return true;
+  }
+  compile::VmEvalFuncPtr run =
+    std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
+  res->results()[kOutput] = run;
+  return true;
+}
+
+// The parallel primitive related valuenode might be partitioned so that its value changes by device,
+// which would result in a synchronization error due to a different executing order.
+// Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitives;
+// the final solution will be proposed later as a parallel feature.
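+// A value node keeps its duplicates when any of its users is a CNode whose primitive carries the
+// "keep_value_node_input" attribute; the helper below walks the manager's node-user map to check this.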
+bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
+  auto &node_users = res->manager()->node_users();
+  auto &users = node_users[value_node];
+  auto used_by_keep_value_prim =
+    std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int> &user) -> bool {
+      MS_EXCEPTION_IF_NULL(user.first);
+      auto cnode = user.first->cast<CNodePtr>();
+      if (cnode == nullptr) {
+        return false;
+      }
+      auto prim_node = cnode->input(0);
+      if (IsValueNode<Primitive>(prim_node)) {
+        auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
+        // value_node is referenced by some parallel primitive
+        return prim->HasAttr("keep_value_node_input");
+      }
+      return false;
+    });
+  return used_by_keep_value_prim;
+}
+
+bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "Remove value node duplications error.";
+  }
+  FuncGraphPtr func_graph = res->func_graph();
+  auto manager = res->manager();
+  // Remove duplicated value nodes; because of the replace operation, we cannot hold a reference here.
+  auto value_nodes = func_graph->value_nodes();
+  HashCache hash_cache;
+  HashValue hashes;
+  for (const auto &value_pair : value_nodes) {
+    if (KeepValueNodeDuplication(value_pair.first, res)) {
+      continue;
+    }
+    TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
+  }
+  return true;
+}
+
+bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
+
+void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
+  MS_EXCEPTION_IF_NULL(res->manager());
+  MS_EXCEPTION_IF_NULL(res->func_graph());
+  auto ppm = opt::python_pass::PyPassManager::GetInstance();
+  if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
+    MS_LOG(DEBUG) << "No match.\n";
+  }
+}
+
+bool ResolveActionPyStub(const ResourcePtr &res) {
+  ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
+  return true;
+}
+
+bool OptActionPyStub(const ResourcePtr &res) {
+  ActionPyStub(res, opt::python_pass::Phase::OPT);
+  return true;
+}
+
+static std::vector<ActionItem> CommonPipeline() {
+  std::vector<ActionItem> actions;
+
+  // Parse the python ast to ANF graph
+  actions.emplace_back(std::make_pair("parse", ParseAction));
+
+  // Resolve the python func
+  actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
+  auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
+  if (!multi_graphs) {
+    actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  }
+  // Add resolve-stage python pass stub
+  actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
+  actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
+  // Evaluate type and shape, and specialize
+  actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
+
+  return actions;
+}
+
+std::vector<ActionItem> GePipeline() {
+  auto actions = CommonPipeline();
+  // optimize
+  actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
+  // Add opt-stage python pass stub
+  actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
+  actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
+  actions.emplace_back(std::make_pair("validate", ValidateAction));
+  return actions;
+}
+
+std::vector<ActionItem> VmPipeline() {
+  auto actions = CommonPipeline();
+
+  // optimize
+  actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
+
+  // Add opt-stage python pass stub
+  actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
+
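+  // Unlike GePipeline, VmPipeline keeps going after validation: it also emits tasks and executes the graph.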
+  actions.emplace_back(std::make_pair("validate", ValidateAction));
+
+  // compile the ANF graph
+  actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
+
+  // to execute the graph
+  actions.emplace_back(std::make_pair("execute", ExecuteAction));
+
+  return actions;
+}
+}  // namespace pipeline
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/action.h b/mindspore/ccsrc/pipeline/jit/action.h
new file mode 100644
index 0000000000..0a1feab1c9
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/action.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_ACTION_H_
+#define MINDSPORE_CCSRC_PIPELINE_ACTION_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+#include "pipeline/jit/resource.h"
+#include "vm/segment_runner.h"
+
+namespace mindspore {
+extern const char kMsConvert[];
+
+namespace pipeline {
+using ActionItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
+
+bool ParseAction(const ResourcePtr &res);
+bool SymbolResolveAction(const ResourcePtr &res);
+bool AbstractSpecializeAction(const ResourcePtr &res);
+bool GeOptimizeAction(const ResourcePtr &res);
+bool VmOptimizeAction(const ResourcePtr &res);
+bool PynativeOptimizeAction(const ResourcePtr &res);
+bool TaskEmitAction(const ResourcePtr &res);
+bool ExecuteAction(const ResourcePtr &res);
+
+std::vector<ActionItem> GePipeline();
+std::vector<ActionItem> VmPipeline();
+abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                                         const abstract::AbstractBasePtrList &args_spec, bool clear = false);
+FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                               const abstract::AnalysisContextPtr &context);
+FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
+                         const abstract::AbstractBasePtrList &args_spec);
+}  // namespace pipeline
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PIPELINE_ACTION_H_
diff --git a/mindspore/ccsrc/pipeline/jit/base.h b/mindspore/ccsrc/pipeline/jit/base.h
new file mode 100644
index 0000000000..0a8a2b75f3
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/base.h
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_BASE_H_
+#define MINDSPORE_CCSRC_PIPELINE_BASE_H_
+
+#include <cstddef>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "ir/anf.h"
+#include "pipeline/jit/resource.h"
+#include "utils/context/ms_context.h"
+
+namespace mindspore {
+namespace pipeline {
+struct ExecutorInfo {
+  FuncGraphPtr func_graph;
+  ResourcePtr resource;
+  std::size_t arg_list_size;
+};
+using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;
+
+inline std::string GetPhasePrefix(const std::string &phase) {
+  auto pos = phase.find('.');
+  if (pos == std::string::npos) {
+    MS_LOG(EXCEPTION) << "Phase has no '.' to separate its prefix: " << phase;
+  }
+  return phase.substr(0, pos);
+}
+
+inline std::string GetFilePathName(const std::string &file_name) {
+  std::ostringstream oss;
+  auto ms_context = MsContext::GetInstance();
+  if (ms_context == nullptr) {
+    MS_LOG(EXCEPTION) << "ms_context is nullptr";
+  }
+  auto save_graphs_path = ms_context->save_graphs_path();
+  if (save_graphs_path.empty()) {
+    save_graphs_path = ".";
+  }
+  oss << save_graphs_path << "/" << file_name;
+  return oss.str();
+}
+}  // namespace pipeline
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PIPELINE_BASE_H_
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
new file mode 100644
index 0000000000..65adebb6e2
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -0,0 +1,336 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/kernel_compiler/oplib/oploader.h"
+#include "pipeline/jit/pipeline.h"
+#include "frontend/operator/composite/composite.h"
+#include "ir/signature.h"
+#include "pipeline/pynative/pynative_execute.h"
+#include "utils/symbolic.h"
+#include "pybind_api/api_register.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "utils/summary/event_writer.h"
+#include "utils/config_manager.h"
+#include "utils/mpi/mpi_config.h"
+#include "frontend/parallel/context.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/costmodel_context.h"
+#ifdef ENABLE_GPU_COLLECTIVE
+#include "runtime/device/gpu/distribution/collective_init.h"
+#else
+#include "runtime/device/gpu/distribution/collective_fake_init.h"
+#endif
+namespace py = pybind11;
+
+using EnvInstance = mindspore::EnvInstance;
+using ExecutorPy = mindspore::pipeline::ExecutorPy;
+using Pipeline = mindspore::pipeline::Pipeline;
+using PrimitivePy = mindspore::PrimitivePy;
+using MetaFuncGraph = mindspore::MetaFuncGraph;
+using EventWriter = mindspore::summary::EventWriter;
+using OpLib = mindspore::kernel::OpLib;
+using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
+using ParallelContext = mindspore::parallel::ParallelContext;
+using CostModelContext = mindspore::parallel::CostModelContext;
+
+// Interface with python
+PYBIND11_MODULE(_c_expression, m) {
+  m.doc() = "MindSpore c plugin";
+
+  auto fns = mindspore::PybindDefineRegister::AllFuncs();
+  for (auto &item : fns) {
+    item.second(&m);
+  }
+
+  // Class Pipeline interface
+  (void)py::class_<ExecutorPy, std::shared_ptr<ExecutorPy>>(m, "Executor_")
+    .def_static("get_instance", &ExecutorPy::GetInstance, "Executor get_instance.")
+    .def("__call__", &ExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.")
+    .def("del_net_res", &ExecutorPy::DelNetRes, py::arg("network_id") = py::str(""), "Delete network resource.")
+    .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
+    .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
+         py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
+    .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
+         py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
+    .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
+         "Get Parameter Tensor Layout Dictionary.")
+    .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
+         "Get CNode Strategy Dictionary.")
+    .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
+         "Get Allreduce Fusion Dictionary.")
+    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
+         "Fetch the inputs of Conv or Matmul for quant export.")
+    .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
+         py::arg("broadcast_params") = py::dict(), "Build data graph.")
+    .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""),
+         "Get whether the cell has been compiled.")
+    .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph.");
+
+  (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_")
+    .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_)
+    .def(py::init());
+
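+  // Module-level entry points (graph key generation, HCCL setup, dataset bootstrap, graph export) follow.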
+  (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key.");
+  (void)m.def("real_run_op", &mindspore::pynative::RunOp, "Run op pynatively.");
+  (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
+  (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
+  (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
+  (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
+  (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
+              py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
+              py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
+  (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
+  (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");
+
+  (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
+
+  (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
+    .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
+    .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
+    .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.")
+    .def("get_execution_mode", &mindspore::MsContext::execution_mode, "Get execution mode.")
+    .def("set_execution_mode", &mindspore::MsContext::set_execution_mode, "Set execution mode.")
+    .def("set_precompile_only", &mindspore::MsContext::set_precompile_only, "Set enable precompile only.")
+    .def("get_precompile_only", &mindspore::MsContext::precompile_only, "Get enable precompile only.")
+    .def("get_device_target", &mindspore::MsContext::device_target, "Get device target.")
+    .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.")
+    .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.")
+    .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.")
+    .def("open_tsd", &mindspore::MsContext::OpenTsd, "Open tdt dataset client.")
+    .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
+    .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
+    .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
+    .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
+         "Get whether to enable auto mixed precision.")
+    .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
+         "Set whether to enable auto mixed precision.")
+    .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
+         "Get whether to enable reduce precision.")
+    .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
+         "Set whether to enable reduce precision.")
+    .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
+    .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
+    .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.")
+    .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.")
+    .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.")
+    .def("set_save_ms_model_path", &mindspore::MsContext::set_save_ms_model_path, "Set path to save ms model.")
+    .def("get_enable_dump", &mindspore::MsContext::enable_dump, "Get whether to enable dump.")
+    .def("set_enable_dump", &mindspore::MsContext::set_enable_dump, "Set whether to enable dump.")
+    .def("get_save_dump_path", &mindspore::MsContext::save_dump_path, "Get path to dump.")
+    .def("set_save_dump_path", &mindspore::MsContext::set_save_dump_path, "Set path to dump.")
+    .def("set_graph_memory_max_size", &mindspore::MsContext::set_graph_memory_max_size, "Set graph memory max size.")
+    .def("set_variable_memory_max_size", &mindspore::MsContext::set_variable_memory_max_size,
+         "Set variable memory max size.")
+    .def("get_enable_profiling", &mindspore::MsContext::enable_profiling, "Get whether to open profiling.")
+    .def("set_enable_profiling", &mindspore::MsContext::set_enable_profiling, "Set whether to open profiling.")
+    .def("get_profiling_options", &mindspore::MsContext::profiling_options, "Get options to profiling.")
+    .def("set_profiling_options", &mindspore::MsContext::set_profiling_options, "Set options to profiling.")
+    .def("get_check_bprop_flag", &mindspore::MsContext::check_bprop_flag, "Get whether to check bprop.")
+    .def("set_check_bprop_flag", &mindspore::MsContext::set_check_bprop_flag, "Set whether to check bprop.")
+    .def("get_max_device_memory", &mindspore::MsContext::max_device_memory, "Get device memory max size.")
+    .def("set_max_device_memory", &mindspore::MsContext::set_max_device_memory, "Set device memory max size.")
+    .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
+    .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
+         "Set the GraphKernel switch to on or off.")
+    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
+    .def("get_enable_sparse", &mindspore::MsContext::enable_sparse, "Get whether to enable sparsity.")
+    .def("set_enable_sparse", &mindspore::MsContext::set_enable_sparse, "Set whether to enable sparsity.");
+
+  (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
+    .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
+    .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether mpi is enabled.")
+    .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
+
+  (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
+    .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
+    .def("get_device_num", &ParallelContext::device_num, "Get device num.")
+    .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
+    .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get whether device num is set.")
+    .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
+    .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
+    .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get whether global rank is set.")
+    .def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
+    .def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
+    .def("get_cast_before_mirror", &ParallelContext::cast_before_mirror, "Get cast before mirror.")
+    .def("set_cast_before_mirror", &ParallelContext::set_cast_before_mirror, "Set cast before mirror.")
+    .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
+    .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
+    .def("get_communication_backend", &ParallelContext::communication_backend, "Get communication backend.")
+    .def("set_communication_backend", &ParallelContext::set_communication_backend, "Set communication backend.")
+    .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
+    .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
+    .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
+    .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
+    .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
+         "Set all reduce fusion split indices.")
+    .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
+         "Get all reduce fusion split indices.")
+    .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
+         "Set all reduce fusion split sizes.")
+    .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
+         "Get all reduce fusion split sizes.")
+    .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
+         "Set enable/disable all reduce fusion.")
+    .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
+         "Get enable/disable all reduce fusion.")
+    .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
+    .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
+         "Get whether parameter broadcast is set.")
+    .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
+    .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
+         "Set strategy checkpoint load file.")
+    .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
+         "Set strategy checkpoint save file.")
+    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
+    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
+    .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether to load full batch on each device.")
+    .def("get_full_batch", &ParallelContext::full_batch, "Get whether to load full batch on each device.")
+    .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
+         "Set enable/disable parallel optimizer.")
+    .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
+         "Get enable/disable parallel optimizer.")
+    .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
+
+  (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
+    .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.")
+    .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity,
+         "Set the capacity of device memory.")
+    .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.")
+    .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha,
+         "Set the parameter cost_model_alpha of the DP algorithm.")
+    .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha,
+         "Get the parameter cost_model_alpha of the DP algorithm.")
algorithm.") + .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta, + "Set the parameter cost_model_beta of the DP algorithm.") + .def("get_costmodel_beta", &CostModelContext::costmodel_beta, + "Get the parameter cost_model_beta of the DP algorithm.") + .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma, + "Set the parameter cost_model_gamma of the DP algorithm") + .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma, + "Get the parameter cost_model_gamma of the DP algorithm.") + .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold, + "Set the parameter cost_model_communi_threshold of the DP algorithm.") + .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold, + "Get the parameter cost_model_communi_threshold of the DP algorithm.") + .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const, + "Set the parameter cost_model_communi_const of the DP algorithm.") + .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const, + "Get the parameter cost_model_communi_const of the DP algorithm.") + .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias, + "Set the parameter cost_model_communi_bias of the DP algorithm.") + .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias, + "Get the parameter cost_model_communi_bias of the DP algorithm.") + .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.") + .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.") + .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.") + .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.") + .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm, + "Set the parameter gradient AllReduce fusion algorithm.") + .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm, + "Get the parameter gradient AllReduce fusion algorithm.") + .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times, + "Set the parameter gradient AllReduce times.") + .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times, + "Get the parameter gradient AllReduce times.") + .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent, + "Set the parameter gradient AllReduce fusion tail percent.") + .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent, + "Get the parameter gradient AllReduce fusion tail percent.") + .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time, + "Set the parameter gradient AllReduce fusion tail time.") + .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time, + "Get the parameter gradient AllReduce fusion tail time.") + .def("set_costmodel_allreduce_fusion_allreduce_inherent_time", + &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time, + "Set the parameter gradient AllReduce fusion allreduce inherent time.") + .def("get_costmodel_allreduce_fusion_allreduce_inherent_time", + &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time, + "Get the 
+    .def("set_costmodel_allreduce_fusion_allreduce_bandwidth",
+         &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth,
+         "Set the parameter gradient AllReduce fusion allreduce bandwidth.")
+    .def("get_costmodel_allreduce_fusion_allreduce_bandwidth",
+         &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth,
+         "Get the parameter gradient AllReduce fusion allreduce bandwidth.")
+    .def("set_costmodel_allreduce_fusion_computation_time_parameter",
+         &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter,
+         "Set the parameter gradient AllReduce fusion computation time parameter.")
+    .def("get_costmodel_allreduce_fusion_computation_time_parameter",
+         &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter,
+         "Get the parameter gradient AllReduce fusion computation time parameter.")
+    .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable,
+         "Set the parameter tensor_slice_align_enable in strategy generation.")
+    .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable,
+         "Get the parameter tensor_slice_align_enable in strategy generation.")
+    .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size,
+         "Set the parameter tensor_slice_size in strategy generation.")
+    .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size,
+         "Get the parameter tensor_slice_size in strategy generation.")
+    .def("set_fully_use_devices", &CostModelContext::set_fully_use_device,
+         "Set the parameter fully_use_devices in the DP algorithm.")
+    .def("get_fully_use_devices", &CostModelContext::fully_use_device,
+         "Get the parameter fully_use_devices in the DP algorithm.")
+    .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow,
+         "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
+    .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
+         "Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
+    .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
+    .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");
+
+  (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
+    // ClearResAtexit should be called only in the case where C++ is calling the Python interface.
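+    // Registered through Python's atexit so teardown runs while the interpreter is still alive;
+    // the MindData iterator cleanup below exists only when the dataset module is compiled in.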
+    if (mindspore::parse::python_adapter::IsPythonEnv()) {
+      mindspore::pipeline::ClearResAtexit();
+
+#ifdef ENABLE_MINDDATA
+      py::module iterators = py::module::import("mindspore.dataset.engine.iterators");
+      (void)iterators.attr("_cleanup")();
+#endif
+    }
+  }});
+
+  (void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
+    .def(py::init<const std::string &>())
+    .def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
+    .def("Open", &EventWriter::Open, "Open the write file.")
+    .def("Write", &EventWriter::Write, "Write the serialized event.")
+    .def("EventCount", &EventWriter::GetWriteEventCount, "Get the written event count.")
+    .def("Flush", &EventWriter::Flush, "Flush the events.")
+    .def("Close", &EventWriter::Close, "Close the writer.")
+    .def("Shut", &EventWriter::Shut, "Finally close the writer.");
+
+  (void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
+    .def(py::init())
+    .def_static("reg_op", &OpLib::RegOp, "Register op info.");
+#ifdef ENABLE_GPU_COLLECTIVE
+  (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective,
+              "Init gpu collective communication mode.");
+  (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective,
+              "Finalize gpu collective communication mode.");
+#else
+  (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective,
+              "Init gpu collective communication mode.");
+  (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
+              "Finalize gpu collective communication mode.");
+
+#endif
+
+  (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
+    .def(py::init())
+    .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "Get all ops info.");
+}
diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
new file mode 100644
index 0000000000..baef64481b
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -0,0 +1,559 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/parse/data_converter.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "pipeline/jit/parse/resolve.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "frontend/operator/ops.h"
+#include "frontend/operator/composite/composite.h"
+#include "ir/func_graph_cloner.h"
+#include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
+#include "debug/trace.h"
+#include "frontend/optimizer/ad/grad.h"
+
+namespace mindspore {
+namespace parse {
+using Tensor = mindspore::tensor::Tensor;
+using TensorPtr = mindspore::tensor::TensorPtr;
+using MetaTensor = mindspore::tensor::MetaTensor;
+using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
+
+FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_key = results[0];
+  py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
+
+  auto bprop_graph = std::make_shared<FuncGraph>();
+  std::vector<AnfNodePtr> outputs;
+
+  auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
+  fake_bprop->set_hook(bprop_func);
+  (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
+  outputs.push_back(NewValueNode(fake_bprop));
+
+  py::object code_obj = py::getattr(bprop_func, "__code__");
+  // co_argcount counts self, the forward inputs, out and dout; subtract 3 for self/out/dout.
+  size_t inputs_num = py::cast<size_t>(py::getattr(code_obj, "co_argcount")) - 3;
+  for (size_t i = 0; i < inputs_num; ++i) {
+    auto param = bprop_graph->add_parameter();
+    outputs.push_back(param);
+  }
+  auto p1 = bprop_graph->add_parameter();
+  auto p2 = bprop_graph->add_parameter();
+  outputs.push_back(p1);
+  outputs.push_back(p2);
+
+  bprop_graph->set_output(bprop_graph->NewCNode(outputs));
+  data_converter::SetObjGraphValue(obj_key, bprop_graph);
+  return bprop_graph;
+}
+
+namespace {
+bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
+  MS_LOG(DEBUG) << "Converting python tuple";
+  py::tuple tuple = obj.cast<py::tuple>();
+  std::vector<ValuePtr> value_list;
+  for (size_t it = 0; it < tuple.size(); ++it) {
+    ValuePtr out = nullptr;
+    bool success = ConvertData(tuple[it], &out, use_signature);
+    if (!success) {
+      return false;
+    }
+    value_list.push_back(out);
+  }
+  *data = std::make_shared<ValueTuple>(value_list);
+
+  return true;
+}
+
+bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) {
+  MS_LOG(DEBUG) << "Converting python list";
+
+  py::list list = obj.cast<py::list>();
+  std::vector<ValuePtr> value_list;
+  for (size_t it = 0; it < list.size(); ++it) {
+    ValuePtr out = nullptr;
+    bool success = ConvertData(list[it], &out, use_signature);
+    if (!success) {
+      return false;
+    }
+    value_list.push_back(out);
+  }
+  *data = std::make_shared<ValueList>(value_list);
+  return true;
+}
+
+bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) {
+  MS_LOG(DEBUG) << "Converting cell list";
+  py::sequence list = obj;
+  std::vector<ValuePtr> value_list;
+  for (size_t it = 0; it < list.size(); ++it) {
+    ValuePtr out = nullptr;
+    bool success = ConvertData(list[it], &out, use_signature);
+    if (!success) {
+      return false;
+    }
+    value_list.push_back(out);
+  }
+  *data = std::make_shared<ValueTuple>(value_list);
+  return true;
+}
+
+bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) {
+  MS_LOG(DEBUG) << "Converting python dict";
+
+  py::dict dict_values = obj.cast<py::dict>();
+  std::vector<std::pair<std::string, ValuePtr>> key_values;
+  for (auto item : dict_values) {
+    if (!py::isinstance<py::str>(item.first)) {
+      MS_LOG(EXCEPTION) << "Only str is supported as the key of a dict.";
+    }
+    std::string key = py::str(item.first);
+    ValuePtr out = nullptr;
+    bool success = ConvertData(dict_values[item.first], &out, use_signature);
+    if (!success) {
+      return false;
+    }
+    key_values.emplace_back(key, out);
+  }
+  *data = std::make_shared<ValueDictionary>(key_values);
+  return true;
+}
+
+void ConvertNameSpace(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting python module";
+  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+  py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
+  *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace));
+}
+
+void ConvertDataClass(py::object obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting dataclass";
+  // The obj may be a dataclass definition.
+  auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
+  // desc has the format "<class xxxx>"; strip the '<' and '>' by offsetting by 1.
+  *data = std::make_shared<ClassObject>(obj, std::string(desc.begin() + 1, desc.end() - 1));
+}
+
+bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) {
+  MS_LOG(DEBUG) << "Converting primitive object";
+
+  // Need to check whether the primitive is a class type or an instance.
+  auto obj_type = data_converter::GetObjType(obj);
+  if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
+    auto desc = py::cast<std::string>(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj));
+    // desc has the format "<class xxxx>"; strip the '<' and '>' by offsetting by 1.
+    *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
+  } else {
+    auto primitive = obj.cast<PrimitivePyPtr>();
+    if (primitive == nullptr) {
+      MS_LOG(ERROR) << "Resolve Primitive error: the pointer is null.";
+      return false;
+    }
+    if (py::hasattr(obj, "__setattr_flag__")) {
+      if (py::hasattr(obj, "_clone")) {
+        // If the primitive's attributes were modified via __setattr__, clone it to pick them up.
+        auto clone_fn = obj.attr("_clone");
+        py::object new_obj = clone_fn();
+        primitive = new_obj.cast<PrimitivePyPtr>();
+      }
+    }
+    if (use_signature) {
+      *data = std::make_shared<prim::DoSignaturePrimitive>(primitive->name(), primitive);
+    } else {
+      *data = primitive;
+    }
+  }
+  return true;
+}
+
+bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) {
+  MS_LOG(DEBUG) << "Converting MetaFuncGraph object";
+  auto meta = obj.cast<MetaFuncGraphPtr>();
+  if (meta == nullptr) {
+    MS_LOG(ERROR) << "Resolve MetaFuncGraph error: the pointer is null.";
+    return false;
+  }
+  if (use_signature) {
+    *data = std::make_shared<prim::DoSignatureMetaFuncGraph>(meta->name(), meta);
+  } else {
+    *data = meta;
+  }
+  return true;
+}
+
+bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting type object";
+  auto typeptr = obj.cast<TypePtr>();
+  if (typeptr == nullptr) {
+    MS_LOG(ERROR) << "Resolve TypePtr error: the pointer is null.";
+    return false;
+  }
+  *data = typeptr;
+  return true;
+}
+
+bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting MetaTensor object.";
+
+  auto m_tensor = obj.cast<MetaTensorPtr>();
+  if (m_tensor == nullptr) {
+    MS_LOG(ERROR) << "Resolve MetaTensor error: the pointer is null.";
+    return false;
+  }
+  *data = m_tensor;
+  return true;
+}
+
+bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting tensor object";
+
+  auto m_tensor = obj.cast<TensorPtr>();
+  if (m_tensor == nullptr) {
+    MS_LOG(ERROR) << "Resolve Tensor error: the pointer is null.";
+    return false;
+  }
+  *data = m_tensor;
+  return true;
+}
+
+bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting slice object";
+
+  py::slice slice_obj = obj.cast<py::slice>();
+  auto convert_func = [obj](std::string attr) -> ValuePtr {
+    auto py_attr = py::getattr(obj, attr.c_str());
+    if (py::isinstance<py::none>(py_attr)) {
+      return kNone;
+    } else if (py::isinstance<py::int_>(py_attr)) {
+      int value = py::cast<int>(py_attr);
+      return MakeValue(value);
+    } else {
+      MS_LOG(EXCEPTION) << "Slice should contain only int or none";
+    }
+  };
+  ValuePtr start = convert_func("start");
+  ValuePtr stop = convert_func("stop");
+  ValuePtr step = convert_func("step");
+  *data = std::make_shared<ValueSlice>(start, stop, step);
+  return true;
+}
+
+bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
+  FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Parse resolve function error.";
+    return false;
+  }
+  // If the cell object has a specified bprop, it has a user-defined bprop function; parse and record it.
+  if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
+    FuncGraphPtr bprop_graph = nullptr;
+    bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
+    if (enable_bprop_debug) {
+      bprop_graph = ConvertToBpropCut(obj);
+    } else {
+      bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
+    }
+    if (bprop_graph != nullptr) {
+      (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
+      (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
+      func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
+    }
+  }
+  *data = func_graph;
+  return true;
+}
+
+bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
+  auto obj_type = data_converter::GetObjType(obj);
+  MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " ";
+  if (obj_type == RESOLVE_TYPE_CLASS_TYPE) {
+    MS_LOG(DEBUG) << "Resolve the class type; need to create a class instance.";
+    std::string desc = py::str(obj);
+    // desc has the format "<class xxxx>"; strip the '<' and '>' by offsetting by 1.
+    *data = std::make_shared<ClassType>(obj, std::string(desc.begin() + 1, desc.end() - 1));
+    return true;
+  }
+  if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) {
+    MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type;
+    FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
+    if (func_graph == nullptr) {
+      MS_LOG(ERROR) << "Parse resolve function error.";
+      return false;
+    }
+    *data = func_graph;
+    return true;
+  }
+  if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
+    // Create the namespace for a common class instance.
+    // When the obj is a Cell, parse the 'construct' method by default.
+    if (data_converter::IsCellInstance(obj)) {
+      return ConvertCellObjToFuncGraph(obj, data);
+    }
+
+    py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+    py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
+    *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
+    return true;
+  }
+  MS_LOG(ERROR) << "Resolve type is invalid: " << ((std::string)py::str(obj));
+  return false;
+}
+}  // namespace
+
+bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) {
+  // Check that the output parameter is valid.
+  if (data == nullptr) {
+    MS_LOG(ERROR) << "Data is null pointer";
+    return false;
+  }
+
+  bool ret = true;
+  ValuePtr converted = nullptr;
+  if (py::isinstance<py::none>(obj)) {
+    converted = kNone;
+  } else if (py::isinstance<py::bool_>(obj)) {
+    converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
+  } else if (py::isinstance<py::int_>(obj)) {
+    converted = std::make_shared<Int32Imm>(py::cast<int>(obj));
+  } else if (py::isinstance<py::float_>(obj)) {
+    converted = std::make_shared<FP32Imm>(py::cast<float>(obj));
+  } else if (py::isinstance<py::str>(obj)) {
+    converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
+  } else if (py::isinstance<py::dict>(obj)) {
+    ret = ConvertDict(obj, &converted, use_signature);
+  } else if (py::isinstance<py::slice>(obj)) {
+    ret = ConvertSlice(obj, &converted);
+  } else if (py::isinstance<py::ellipsis>(obj)) {
+    converted = kEllipsis;
+  } else if (py::isinstance<py::tuple>(obj)) {
+    ret = ConvertTuple(obj, &converted, use_signature);
+  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
+    ret = ConvertCellList(obj, &converted, use_signature);
+  } else if (py::isinstance<py::list>(obj)) {
+    ret = ConvertList(obj, &converted, use_signature);
+  } else if (py::isinstance<py::module>(obj)) {
+    ConvertNameSpace(obj, &converted);
+  } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
+    ConvertDataClass(obj, &converted);
+  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
+    ret = ConvertPrimitive(obj, &converted, use_signature);
+  } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) {
+    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
+  } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) {
+    ret = ConvertDataType(obj, &converted);
+  } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) {
+    ret = ConvertTensor(obj, &converted);
+  } else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) {
+    ret = ConvertMetaTensor(obj, &converted);
+  } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) {
+    std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
+    converted = env;
+  } else if (py::hasattr(obj, "__parameter__")) {
+    auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input"));
+    ret = ConvertData(to_convert, &converted);
+  } else {
+    ret = ConvertOtherObj(obj, &converted);
+  }
+
+  *data = converted;
+  return ret;
+}
+
+// Convert data to a func graph.
+FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) {
+  std::vector<std::string> results = data_converter::GetObjKey(obj);
+  std::string obj_id = results[0] + python_mod_get_parse_method;
+  std::string obj_key = results[1];
+  FuncGraphPtr func_graph = nullptr;
+  Any value = Any();
+  bool is_cache = data_converter::GetObjectValue(obj_id, &value);
+  if (is_cache) {
+    if (value.is<FuncGraphPtr>()) {
+      MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id;
+      func_graph = value.cast<FuncGraphPtr>();
+      return func_graph;
+    }
+  }
+
+  func_graph = ParsePythonCode(obj, python_mod_get_parse_method);
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Parse resolve function error.";
+    return nullptr;
+  }
+
+  data_converter::MakeProperNameToFuncGraph(func_graph, obj_id);
+  data_converter::CacheObjectValue(obj_id, func_graph);
+  if (obj_key != "") {
+    MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString();
+    data_converter::SetObjGraphValue(obj_key, func_graph);
+  }
+
+  return func_graph;
+}
+namespace data_converter {
+static std::unordered_map<std::string, Any> object_map_ = std::unordered_map<std::string, Any>();
+
+static std::unordered_map<std::string, std::vector<FuncGraphPtr>> object_graphs_map_ =
+  std::unordered_map<std::string, std::vector<FuncGraphPtr>>();
+
+void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) {
+  object_graphs_map_[obj_key].push_back(data);
+  MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size();
+}
+
+const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs() {
+  MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size();
+  return object_graphs_map_;
+}
+
+void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; }
+bool GetObjectValue(const std::string &obj_key, Any *const data) {
+  if (object_map_.count(obj_key)) {
+    *data = object_map_[obj_key];
+    return true;
+  }
+  return false;
+}
+std::vector<std::string> GetObjKey(const py::object &obj) {
+  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+  py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj);
+  if (obj_tuple.size() != 2) {
+    MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements";
+  }
+  return {py::cast<std::string>(obj_tuple[0]), py::cast<std::string>(obj_tuple[1])};
+}
+
+// Get the object's detail type.
+ResolveTypeDef GetObjType(const py::object &obj) {
+  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+  auto obj_type =
+    ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast<int32_t>());
+  return obj_type;
+}
+
+// Get the class instance's detail type.
+ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) {
+  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+  auto class_type =
+    ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast<int32_t>());
+  return class_type;
+}
+
+// Check whether the object is a Cell instance.
+bool IsCellInstance(const py::object &obj) {
+  auto class_type = GetClassInstanceType(obj);
+  bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL);
+  return isCell;
+}
+
+// Create the python class instance.
+py::object CreatePythonObject(const py::object &type, const py::tuple &params) {
+  py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
+  py::object obj;
+  if (params.size() == 0) {
+    obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type);
+  } else {
+    obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params);
+  }
+  return obj;
+}
+
+// Generate an appropriate name and set it in the graph debuginfo.
+// The characters '<' and '>' cannot be used in a dot file, so replace them with other symbols.
+void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(func_graph->debug_info());
+  // Set the detailed name info of the function.
+  std::ostringstream oss;
+  for (size_t i = 0; i < name.size(); i++) {
+    if (name[i] == '<') {
+      oss << "「";
+    } else if (name[i] == '>') {
+      oss << "」";
+    } else {
+      oss << name[i];
+    }
+  }
+  func_graph->debug_info()->set_full_name(oss.str());
+}
+
+ValuePtr PyDataToValue(const py::object &obj) {
+  py::object to_convert = obj;
+  if (py::hasattr(obj, "__parameter__")) {
+    to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input"));
+  }
+  ValuePtr value = nullptr;
+  (void)ConvertData(to_convert, &value);
+  return value;
+}
+
+void ClearObjectCache() {
+  object_map_.clear();
+  object_graphs_map_.clear();
+}
+}  // namespace data_converter
+
+static std::unordered_map<std::string, ClassPtr> g_dataClassToClass = {};
+
+// Parse a dataclass to a mindspore Class type.
+ClassPtr ParseDataClass(const py::object &cls_obj) {
+  std::string cls_name = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__name__"));
+  std::string cls_module = py::cast<std::string>(python_adapter::GetPyObjAttr(cls_obj, "__module__"));
+  std::string cls = cls_module + "."
+ cls_name; + auto iterator = g_dataClassToClass.find(cls); + if (iterator != g_dataClassToClass.end()) { + return iterator->second; + } + + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + ClassAttrVector attributes; + py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); + for (auto &item : names) { + TypePtr type_value = item.second.cast(); + MS_EXCEPTION_IF_NULL(type_value); + MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; + attributes.push_back(std::make_pair(py::cast(item.first), type_value)); + } + + std::unordered_map methods_map; + py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); + for (auto &item : methods) { + std::string fun_name = item.first.cast(); + py::object obj = py::cast(item.second); + std::shared_ptr method_obj = std::make_shared(obj, fun_name); + methods_map[fun_name] = method_obj; + } + + std::shared_ptr me_class = std::make_shared(Named(cls_name), attributes, methods_map); + // static Variable for cache + // cppcheck-suppress unreadVariable + g_dataClassToClass[cls] = me_class; + + return me_class; +} + +void CleanDataClassToClassMap() { g_dataClassToClass.clear(); } +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.h b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h new file mode 100644 index 0000000000..6632d4801e --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.h @@ -0,0 +1,61 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PIPELINE_PARSE_DATA_CONVERTER_H_ +#define PIPELINE_PARSE_DATA_CONVERTER_H_ + +#include +#include +#include +#include +#include +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parse { +// data convert for parse +namespace data_converter { +void CacheObjectValue(const std::string &obj_key, const Any &data); +bool GetObjectValue(const std::string &obj_key, Any *const data); + +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); + +const std::unordered_map> &GetObjGraphs(); + +std::vector GetObjKey(const py::object &obj); +ResolveTypeDef GetObjType(const py::object &obj); +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); + +bool IsCellInstance(const py::object &obj); +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); +ValuePtr PyDataToValue(const py::object &obj); +void ClearObjectCache(); +} // namespace data_converter + +ClassPtr ParseDataClass(const py::object &cls_obj); +FuncGraphPtr ConvertToBpropCut(const py::object &obj); + +void CleanDataClassToClassMap(); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_DATA_CONVERTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc new file mode 100644 index 0000000000..b52dddda66 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -0,0 +1,374 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "pipeline/jit/parse/function_block.h"
+#include 
+#include 
+#include 
+#include "pipeline/jit/parse/resolve.h"
+#include "pipeline/jit/parse/parse.h"
+#include "frontend/operator/ops.h"
+#include "debug/info.h"
+#include "debug/trace.h"
+#include "pybind11/pybind11.h"
+
+namespace mindspore {
+namespace py = pybind11;
+
+namespace parse {
+FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
+  func_graph_ = std::make_shared<FuncGraph>();
+  matured_ = false;
+}
+
+void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
+
+// WriteVariable binds a variable name to the node that currently defines it.
+void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
+  MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
+  vars_[var_name] = node;
+}
+
+// Read a variable, falling back to this block's predecessors if needed.
+AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
+  // Get the var node if it is defined in this block.
+  if (vars_.count(var)) {
+    AnfNodePtr node = vars_[var];
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<ValueNode>()) {
+      return NewValueNode(GetValueNode(node));
+    } else {
+      return node;
+    }
+  }
+  // Get the var from the predecessor block; if that fails, make a resolve node for it.
+  if (matured_) {
+    // If there is only one predecessor block, read the definition of var from it.
+    if (prev_blocks_.size() == 1) {
+      auto block = prev_blocks_[0];
+      MS_EXCEPTION_IF_NULL(block);
+      return block->ReadVariable(var);
+    } else if (prev_blocks_.empty()) {
+      // Get the namespace and make a Resolve node.
+      return MakeResolveSymbol(var);
+    }
+  }
+  // If there is more than one predecessor block, build a phi node.
+  auto debug_info = std::make_shared<NodeDebugInfo>();
+  debug_info->set_name(var);
+  TraceManager::DebugTrace(std::make_shared<TracePhi>(debug_info));
+  ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
+  TraceManager::EndTrace();
+  MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
+  func_graph()->add_parameter(phi_param);
+  phi_nodes_[phi_param] = var;
+  WriteVariable(var, phi_param);
+  if (matured_) {
+    SetPhiArgument(phi_param);
+  }
+  return phi_param;
+}
+
+// Resolve an AST operator node.
+AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
+  auto ast = parser_.ast();
+  MS_EXCEPTION_IF_NULL(ast);
+  TraceGuard trace_guard(parser_.GetLocation(op));
+  py::tuple namespace_var = ast->CallParserObjMethod(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
+  if (namespace_var.size() != 2) {
+    MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
+  }
+  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
+  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
+  return MakeResolve(name_space, symbol);
+}
+
+// Resolve a class member; there are two possibilities: a method or a member variable.
+AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) {
+  py::object namespace_var =
+    parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
+  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
+  SymbolPtr symbol = std::make_shared<Symbol>(attr);
+  return MakeResolve(name_space, symbol);
+}
+
+// Make a resolve node for a symbol string.
+AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
+  if (value.compare(0, strlen("self."), "self.") == 0) {
+    auto start = value.find_first_of('.') + 1;
+    if (start >= value.size()) {
+      MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
+      return nullptr;
+    }
+    auto bits_str = value.substr(start);
+    return MakeResolveClassMember(bits_str);
+  }
+  py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
+
+  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]);
+  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
+  return MakeResolve(name_space, symbol);
+}
+
+AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
+  py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
+  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
+  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
+  return MakeResolve(name_space, symbol);
+}
+
+AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
+  MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , "
+                << ((std::string)resolve_symbol->symbol());
+  ValueNodePtr module_node = NewValueNode(name_space);
+  ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
+  auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
+  return node;
+}
+
+// Add an input for the block's phi parameter.
+void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
+  std::string var = phi_nodes_[phi];
+  MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var;
+  for (auto &pred : prev_blocks_) {
+    MS_EXCEPTION_IF_NULL(pred);
+    MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString();
+    AnfNodePtr arg_node = pred->ReadVariable(var);
+    CNodePtr jump = pred->jumps_[this];
+    jump->add_input(arg_node);
+  }
+  // If the phi node in the body part of a for/while loop is removed, the
+  // closure-convert phase will generate a cycle in the graph when the loop is
+  // kept after specialization. This should be investigated further. For now,
+  // the user has to set a flag on a function to indicate that the for loop can
+  // definitely be unrolled, i.e. that the sequence in the for statement has a
+  // fixed size at compile time.
+  if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) ||
+      parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) {
+    CollectRemovablePhi(phi);
+  }
+}
+
+AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
+  AnfNodePtr arg_node = nullptr;
+  for (auto &prev : prev_blocks_) {
+    MS_EXCEPTION_IF_NULL(prev);
+    AnfNodePtr temp_node = prev->ReadVariable(var);
+    MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var
+                  << " is " << temp_node->DebugString();
+    if (temp_node != phi) {
+      if (arg_node == nullptr) {
+        arg_node = temp_node;
+        MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString()
+                      << " may be replaced by node " << arg_node->DebugString();
+      } else if (temp_node == arg_node) {
+        MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node "
+                      << arg_node->DebugString();
+      } else {
+        MS_LOG(DEBUG) << "phi " << phi->ToString()
+                      << " cannot be removed as it is assigned different nodes. node1: " << arg_node->DebugString()
node1: " << arg_node->DebugString() + << ", node2: " << temp_node->DebugString(); + return nullptr; + } + } + } + return arg_node; +} + +// Check if there is removable unnecessary phi node in this graph. +// as per the FIRM TR 3.2, a phi node can be remove if: +// +// If all arguments of a φ-function are the same value s or the φfunction itself, +// then we remove the φ-function and let all users directly uses. We call such a +// φ-function obviously unnecessary. +// When we removed a φ-function p, then we recursively try to apply this simplification +// rule with all (former) users of p, because they may have become obviously unnecessary +// due to the removal of p +// +// phi node in graph will be removed after the whole function is parsed in a DFS visit +// of that graph.The reason is : +// 1. when this function is called, not all usage of this phi node had bound to the +// graph of this function block, some may stay in vars_ in other blocks. +// 2. it's costly to iterate the graph to replace the phi for each phi. +// Args : +// phi : This parameter node is functioning as a phi node. +void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { + MS_EXCEPTION_IF_NULL(phi); + std::string var = phi_nodes_[phi]; + MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); + if (prev_blocks_.size() == 0) { + MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); + return; + } + AnfNodePtr arg_node = SearchReplaceNode(var, phi); + if (arg_node != nullptr) { + MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with " + << arg_node->DebugString(); + // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." + WriteVariable(var, arg_node); + removable_phis_[phi] = arg_node; + // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized + // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. 
+    for (auto &prev : prev_blocks_) {
+      MS_EXCEPTION_IF_NULL(prev);
+      if (!prev->matured_) {
+        continue;
+      }
+      for (auto &phi_iter : prev->removable_phis_) {
+        MS_EXCEPTION_IF_NULL(phi_iter.second);
+        if (phi_iter.second->isa<Parameter>()) {
+          const auto &param = phi_iter.second->cast<ParameterPtr>();
+          if (param == phi) {
+            MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString()
+                          << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString();
+            prev->removable_phis_[phi_iter.first] = arg_node;
+          }
+        }
+      }
+    }
+  }
+}
+
+// A block should be marked matured once its predecessor blocks have been processed.
+void FunctionBlock::Mature() {
+  const auto &graphParamVec = func_graph_->parameters();
+  for (auto &paramItr : graphParamVec) {
+    MS_EXCEPTION_IF_NULL(paramItr);
+    ParameterPtr param = paramItr->cast<ParameterPtr>();
+    if (phi_nodes_.find(param) != phi_nodes_.cend()) {
+      SetPhiArgument(param);
+    }
+  }
+  matured_ = true;
+}
+
+// Force the condition node to bool using the bool operation.
+CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) {
+  TraceManager::DebugTrace(std::make_shared<TraceForceBool>(cond->debug_info()));
+  CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond});
+  TraceManager::EndTrace();
+  return op_apply_node;
+}
+
+CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
+  TraceManager::DebugTrace(std::make_shared<TraceForceWhileCond>(cond->debug_info()));
+  CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond});
+  TraceManager::EndTrace();
+  return op_apply_node;
+}
+
+// Perform a jump from this block to the target block.
+void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) {
+  if (func_graph()->get_return() != nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
+                      << trace::GetDebugInfo(func_graph()->get_return()->debug_info());
+  }
+  std::vector<AnfNodePtr> input_nodes;
+  input_nodes.emplace_back(NewValueNode(target_block->func_graph()));
+  if (node != nullptr) {
+    input_nodes.emplace_back(node);
+  }
+
+  CNodePtr jump = func_graph()->NewCNode(input_nodes);
+  jumps_[target_block.get()] = jump;
+  target_block->AddPrevBlock(shared_from_this());
+  func_graph()->set_output(jump);
+  InsertDependItemsBeforeReturn();
+}
+
+// Perform a conditional jump using the switch operation.
+// The first CNode selects the graph with the condition, and the second executes the selected graph.
+void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block,
+                                    const FunctionBlockPtr &false_block, bool unroll_loop) {
+  if (func_graph()->get_return() != nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
NodeInfo: " + << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); + } + // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' + auto prim_switch = std::make_shared(prim::kPrimSwitch->name()); + if (!unroll_loop) { + prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); + } + CNodePtr switch_app = + func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); + CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); + func_graph()->set_output(switch_app_new); + InsertDependItemsBeforeReturn(); +} + +void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { + state_assign_[target] = readid; +} + +void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } + +void FunctionBlock::InsertDependItemsBeforeReturn() { + if (!prev_blocks_.empty()) { + for (auto &prev_block : prev_blocks_) { + MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); + } + } + + ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); + ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); + ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); + const std::string primitive_name("assign"); + const std::string module_name("mindspore.ops.functional"); + ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); + if (state_assign_.size() == 0 && auto_depends_.size() == 0) { + return; + } + AnfNodePtr state = nullptr; + std::vector vec_states; + vec_states.emplace_back(make_tuple_op); + for (auto &item : state_assign_) { + auto source = ReadVariable(item.second); + auto assign = func_graph()->NewCNode({assign_op, item.first, source}); + MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; + vec_states.emplace_back(assign); + } + for (auto &item : auto_depends_) { + MS_LOG(DEBUG) << "auto_depends " << item->ToString(); + vec_states.emplace_back(item); + } + // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2) + // do not need to make_tuple, just use the node. + if (vec_states.size() == 2) { + state = vec_states[1]; + } else { + state = func_graph()->NewCNode(vec_states); + } + + AnfNodePtr old_ret = nullptr; + auto return_node = func_graph()->get_return(); + if (return_node) { + if (return_node->inputs().size() < 1) { + MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2"; + } + old_ret = return_node->input(1); + } else { + old_ret = NewValueNode(kNone); + } + AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); + AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); + func_graph()->set_output(ret, true); + state_assign_.clear(); +} +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h new file mode 100644 index 0000000000..cbf75a3dd8 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -0,0 +1,118 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PIPELINE_PARSE_FUNCTION_BLOCK_H_
+#define PIPELINE_PARSE_FUNCTION_BLOCK_H_
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "pipeline/jit/parse/parse_base.h"
+#include "utils/log_adapter.h"
+#include "utils/ordered_map.h"
+
+namespace mindspore {
+namespace parse {
+
+class Parser;
+class NameSpace;
+class Symbol;
+class FunctionBlock;
+using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
+
+// A function block is a straight-line code sequence with no branches; every block has one exit
+// point, which is the return. When parsing a function, loop or branch, we use function blocks to
+// track the structure of the original source code.
+class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
+ public:
+  explicit FunctionBlock(const Parser &parser);
+  virtual ~FunctionBlock() {}
+
+  FuncGraphPtr func_graph() { return func_graph_; }
+  void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
+  AnfNodePtr ReadVariable(const std::string &var_name);
+  void AddPrevBlock(const FunctionBlockPtr &block);
+  void SetPhiArgument(const ParameterPtr &phi);
+  void CollectRemovablePhi(const ParameterPtr &phi);
+  // A block is matured if all of its predecessors have been generated.
+  void Mature();
+  CNodePtr ForceToBoolNode(const AnfNodePtr &cond);
+  CNodePtr ForceToWhileCond(const AnfNodePtr &cond);
+  void Jump(const FunctionBlockPtr &block, AnfNodePtr node);
+  AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi);
+  void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock,
+                       bool unroll_loop = true);
+  // Record the assign statement of a self.xx weight parameter, which will use the state_setitem op.
+  void SetStateAssgin(const AnfNodePtr &target, const std::string &readid);
+  void AddAutoDepend(const AnfNodePtr &target);
+  void InsertDependItemsBeforeReturn();
+  void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); }
+  bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); }
+  AnfNodePtr MakeResolveAstOp(const py::object &op);
+  AnfNodePtr MakeResolveClassMember(std::string attr);
+  AnfNodePtr MakeResolveSymbol(const std::string &value);
+  AnfNodePtr MakeResolveOperation(const std::string &value);
+  AnfNodePtr MakeResolve(const std::shared_ptr<NameSpace> &name_space, const std::shared_ptr<Symbol> &resolve_symbol);
+  const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis() const { return removable_phis_; }
+
+ private:
+  // The block's graph.
+  FuncGraphPtr func_graph_;
+
+  // The block's parser.
+  const Parser &parser_;
+
+  // A block is matured if all of its prev_blocks have been processed.
+  bool matured_;
+
+  // Stores the nested-level blocks;
+  // refer to the comments in Parser::func_block_list_.
+  std::vector<FunctionBlock *> prev_blocks_;
+
+  // Stores the nodes of args and variables.
+  std::map<std::string, AnfNodePtr> vars_;
+
+  // phi_nodes map a parameter node to a variable; it can be resolved once the block's predecessors are processed.
+  std::map<ParameterPtr, std::string> phi_nodes_;
+
+  // jumps map a successor block to the function call that performs the jump;
+  // refer to the comments in Parser::func_block_list_ on how
+  // to break the cyclic reference.
+  std::map<FunctionBlock *, CNodePtr> jumps_;
+
+  // Keeps all removable phis, which will be removed in one pass.
+  std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
+
+  // Set-state nodes that need to be inserted before the function's return nodes.
+  OrderedMap<AnfNodePtr, std::string> state_assign_;
+
+  // Holds the global variables declared in the function.
+  std::set<std::string> global_vars_;
+
+  // Other depend items (e.g. summary nodes) that need to be inserted before the function's return nodes.
+  std::vector<AnfNodePtr> auto_depends_;
+};
+
+}  // namespace parse
+}  // namespace mindspore
+
+#endif  // PIPELINE_PARSE_FUNCTION_BLOCK_H_
diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
new file mode 100644
index 0000000000..edc9a66594
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc
@@ -0,0 +1,1604 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/parse/parse.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "frontend/operator/ops.h"
+#include "pipeline/jit/parse/data_converter.h"
+#include "frontend/operator/composite/composite.h"
+#include "utils/context/ms_context.h"
+#include "debug/trace.h"
+
+namespace mindspore {
+namespace parse {
+
+FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
+  (void)python_adapter::set_python_scoped();
+
+  if (obj == nullptr || py::isinstance<py::none>(obj)) {
+    MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
+    return nullptr;
+  }
+
+  auto ast = std::make_shared<ParseAst>(obj);
+  bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
+  if (!success) {
+    MS_LOG(ERROR) << "Parse code to ast tree failed.";
+    return nullptr;
+  }
+
+  auto parser = std::make_shared<Parser>(ast);
+
+  FuncGraphPtr func_graph = parser->ParseFuncGraph();
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
+    return nullptr;
+  }
+
+  return func_graph;
+}
+
+// If any mixed-precision flag is set, add a cast node after the parameter node.
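+// Illustration (a sketch of the resulting graph; the % node names are invented):
+// with GRAPH_FLAG_MIX_PRECISION_FP16 set on the graph, a parameter %param is
+// wrapped as
+//   %cast = MixedPrecisionCast(Float16, %param)
+// and the parsed function body reads %cast instead of %param; this is the
+// CNode built by GetMixedPrecisionCastHelp below.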
+AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { + TypePtr dst_type; + if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { + dst_type = kFloat32; + } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { + dst_type = kFloat16; + } else { + return param; + } + auto cast_helper = prim::kPrimMixedPrecisionCast; + auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); + return cast; +} + +FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); + +Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { + errcode_ = PARSE_SUCCESS; + BuildMethodMap(); +} + +void Parser::BuildMethodMap() { + stmt_method_map_["Return"] = &Parser::ParseReturn; + stmt_method_map_["Expr"] = &Parser::ParseExpr; + stmt_method_map_["If"] = &Parser::ParseIf; + stmt_method_map_["Assign"] = &Parser::ParseAssign; + stmt_method_map_["While"] = &Parser::ParseWhile; + stmt_method_map_["For"] = &Parser::ParseFor; + stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; + stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; + stmt_method_map_["Global"] = &Parser::ParseGlobal; + stmt_method_map_["Break"] = &Parser::ParseBreak; + stmt_method_map_["Continue"] = &Parser::ParseContinue; + stmt_method_map_["Pass"] = &Parser::ParsePass; + expr_method_map_["NoneType"] = &Parser::ParseNone; + expr_method_map_["BinOp"] = &Parser::ParseBinOp; + expr_method_map_["Name"] = &Parser::ParseName; + expr_method_map_["Num"] = &Parser::ParseNum; + expr_method_map_["Str"] = &Parser::ParseStr; + expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; + expr_method_map_["Call"] = &Parser::ParseCall; + expr_method_map_["IfExp"] = &Parser::ParseIfExp; + expr_method_map_["Attribute"] = &Parser::ParseAttribute; + expr_method_map_["Compare"] = &Parser::ParseCompare; + expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; + expr_method_map_["Lambda"] = &Parser::ParseLambda; + expr_method_map_["Tuple"] = &Parser::ParseTuple; + expr_method_map_["List"] = &Parser::ParseList; + expr_method_map_["Subscript"] = &Parser::ParseSubscript; + expr_method_map_["Slice"] = &Parser::ParseSlice; + expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; + expr_method_map_["Index"] = &Parser::ParseIndex; + expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; + expr_method_map_["Dict"] = &Parser::ParseDict; + expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; +} + +void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } + +void Parser::InitParserEnvironment(const py::object &obj) { + Parser::top_func_graph_ = FuncGraphWeakPtr(); + ScopeManager::GetInstance().ClearScope(); + (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); +} + +void Parser::CleanParserResource() { + Parser::top_func_graph_ = FuncGraphWeakPtr(); + ScopeManager::GetInstance().ClearScope(); +} + +FuncGraphPtr Parser::ParseFuncGraph() { + // get ast FunctionDef node + py::object node = ast_->GetAstNode(); + FunctionBlockPtr pFnBlock = ParseFunction(node); + if (errcode() != PARSE_SUCCESS) { + MS_LOG(ERROR) << "Parse function error, code is " << errcode(); + return nullptr; + } + + RemoveUnnecessaryPhis(); + + MS_EXCEPTION_IF_NULL(pFnBlock); + return pFnBlock->func_graph(); +} + +void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { + py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); + py::object var_arg_node = 
python_adapter::GetPyObjAttr(func_args, "vararg"); + block->func_graph()->set_has_vararg(!py::isinstance(var_arg_node)); + + py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); + block->func_graph()->set_has_kwarg(!py::isinstance(kw_arg_node)); + + py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); + block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); + + MS_EXCEPTION_IF_NULL(ast_); + py::list args = ast_->GetArgs(fn_node); + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + if (arg_name == "self") { + continue; + } + } + TraceManager::DebugTrace(GetLocation(args[i])); + auto para_node = std::make_shared(block->func_graph()); + MS_EXCEPTION_IF_NULL(para_node); + TraceManager::EndTrace(); + para_node->set_name(arg_name); + para_node->debug_info()->set_name(arg_name); + block->func_graph()->add_parameter(para_node); + AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node); + block->WriteVariable(arg_name, para_after_cast); + MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; + } +} + +void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { + py::list defaults = ast_->GetArgsDefaultValues(fn_node); + py::list args = ast_->GetArgs(fn_node); + std::vector namelist_for_default_value; + std::vector default_values; + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg_name = py::cast(args[i].attr("arg")); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + if (arg_name == "self") { + continue; + } + } + + namelist_for_default_value.push_back(arg_name); + if (py::isinstance(defaults[i])) { + default_values.push_back(NewValueNode(kNull)); + } else { + default_values.push_back(ParseExprNode(block, defaults[i])); + } + } + block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); +} + +ScopePtr Parser::GetScopeForParseFunction() { + ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); + if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { + py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); + if (!py::isinstance(scope_str)) { + auto scope_name = py::cast(scope_str); + scope = std::make_shared(scope_name); + } + } + return scope; +} + +FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { + ScopePtr scope = GetScopeForParseFunction(); + // the node created in the parsefunction context, will inherit the scope created using scope_guard + ScopeGuard scope_guard(scope); + TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node)); + FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this); + if (block != nullptr) { + pFunBlock->AddPrevBlock(block); + } else { + func_graph_ = pFunBlock->func_graph(); + } + pFunBlock->Mature(); + auto current_fg = pFunBlock->func_graph(); + auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); + MS_LOG(DEBUG) << "The function name is " << function_name; + + current_fg->debug_info()->set_name(function_name); + MS_EXCEPTION_IF_NULL(ast_); + py::list deco_list = node.attr("decorator_list"); + if (deco_list.size() > 0) { + current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); + } + + bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); + if (ast_->obj() != 
ast_->function()) { + set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); + } + + if (!set_flag) { + MS_LOG(ERROR) << "Set flags failed"; + return nullptr; + } + GenerateArgsNodeForFunction(pFunBlock, node); + + // when parsing the top graph of construct, save the top graph + if (GetTopFuncGraph() == nullptr) { + UpdateTopFuncGraph(pFunBlock->func_graph()); + } + + // save the function node to block + pFunBlock->WriteVariable(function_name, NewValueNode(current_fg)); + + py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); + (void)ParseStatements(pFunBlock, funcObj); + + if (current_fg->get_return() == nullptr) { + MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); + errcode_ = PARSE_NO_RETURN; + return pFunBlock; + } + GenerateArgsDefaultValueForFunction(pFunBlock, node); + return pFunBlock; +} + +FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { + py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); + size_t count = IntToSize(pcount); + MS_LOG(DEBUG) << "The nodes count is " << count; + for (size_t i = 0; i < count; i++) { + auto node = py::cast(nodes)[i]; + TraceManager::DebugTrace(GetLocation(node)); + fn_block = ParseStatement(fn_block, node); + TraceManager::EndTrace(); + // insert appropriate depended items for the function block if it has a return node + if (fn_block->func_graph()->get_return() != nullptr) { + fn_block->InsertDependItemsBeforeReturn(); + // Skip statements after 'return' (or 'break', 'continue'). + break; + } + } + return fn_block; +} + +FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { + auto node_type = ast_->GetNodeType(node); + + // check the node type + AstMainType nodeType = node_type->main_type(); + if (nodeType != AST_MAIN_TYPE_STMT) { + MS_LOG(INFO) << "Node type is error : " << nodeType; + return block; + } + // call the process function + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + if (stmt_method_map_.count(node_name)) { + TraceManager::DebugTrace(GetLocation(node)); + auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); + TraceManager::EndTrace(); + return stmt_block; + } else { + errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; + } +} + +AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast expr"; + auto node_type = ast_->GetNodeType(node); + // check the node type + AstMainType node_main_type = node_type->main_type(); + if (node_main_type != AST_MAIN_TYPE_EXPR) { + MS_LOG(ERROR) << "Node type is error : " << node_main_type; + errcode_ = PARSE_NODE_TYPE_NO_MATCH; + return nullptr; + } + // call the process function + std::string node_name = node_type->node_name(); + MS_LOG(DEBUG) << "Ast node is " << node_name; + if (expr_method_map_.count(node_name)) { + TraceManager::DebugTrace(GetLocation(node)); + auto expr_node = (this->*expr_method_map_[node_name])(block, node); + TraceManager::EndTrace(); + return expr_node; + } else { + errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; + py::list ret = 
ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
+    auto filename = ret[0].cast<std::string>();
+    auto line_no = ret[1].cast<int>();
+    MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
+  }
+}
+
+// Process an expr statement and expand it if needed,
+// e.g.: x.append(y) -> x = x.append(y)
+FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Expr";
+  // An Expr only has a value, no target.
+  py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
+
+  // Refer to the python function expand_expr_statement; expand_info is one of the following:
+  // True, expr.value, x
+  // True, expr.value
+  // False, None, None
+  // Check the expand info result.
+  auto is_expand = py::cast<bool>(expand_info[0]);
+  if (is_expand) {
+    // Process the expr statement.
+    py::object value_object = expand_info[1];
+    AnfNodePtr value_node = ParseExprNode(block, value_object);
+
+    if (py::len(expand_info) == 2) {
+      // Add it to the depend list and insert it before the output.
+      block->AddAutoDepend(value_node);
+    } else {
+      // Expand the assign statement.
+      py::object target_node = expand_info[2];
+      WriteAssignVars(block, target_node, value_node);
+    }
+  }
+  return block;
+}
+
+LocationPtr Parser::GetLocation(const py::object &node) const {
+  MS_EXCEPTION_IF_NULL(ast_);
+  py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
+  if (ret.size() < 5) {
+    MS_LOG(EXCEPTION) << "List size should not be less than 5.";
+  }
+  // Refer to Location::Location() for each member of ret: file_name, line, column, line_end, column_end.
+  auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int>(), ret[2].cast<int>(),
+                                             ret[3].cast<int>(), ret[4].cast<int>());
+  return location;
+}
+
+void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
+                                 const FunctionBlockPtr &false_block) {
+  true_block->AddPrevBlock(pre_block);
+  true_block->Mature();
+
+  false_block->AddPrevBlock(pre_block);
+  false_block->Mature();
+}
+
+FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast return";
+  MS_EXCEPTION_IF_NULL(block);
+  // Create the return value node.
+  AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
+  // Parse the value of the return statement.
+  py::object value = python_adapter::GetPyObjAttr(node, "value");
+  AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
+  // Create the cnode.
+  CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode});
+
+  block->func_graph()->set_return(pReturnCNode);
+
+  return block;
+}
+
+// Process binary operators, e.g. `a + b`, `a | b`, etc.
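+// Illustration (sketch only; the `%` names are invented): parsing `a + b`
+// resolves the ast.Add operator through the AST operation namespace and emits
+//   %out = %resolved_add(%a, %b)
+// i.e. NewCNode({op_node, left_node, right_node}) in ParseBinOp below.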
+AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast BinOP";
+
+  py::object left = python_adapter::GetPyObjAttr(node, "left");
+  py::object right = python_adapter::GetPyObjAttr(node, "right");
+  py::object op = python_adapter::GetPyObjAttr(node, "op");
+  // create left and right ANF node
+  AnfNodePtr left_node = ParseExprNode(block, left);
+  if (left_node == nullptr) {
+    MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
+    return nullptr;
+  }
+  AnfNodePtr right_node = ParseExprNode(block, right);
+  if (right_node == nullptr) {
+    MS_LOG(WARNING) << "DoBinOp process right node failed: " << errcode();
+    return nullptr;
+  }
+  // resolve the op
+  AnfNodePtr op_node = block->MakeResolveAstOp(op);
+  // create apply node
+  return block->func_graph()->NewCNode({op_node, left_node, right_node});
+}
+
+AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Name";
+  auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
+  MS_LOG(DEBUG) << "The Name id is " << name_id;
+  TraceGuard trace_guard(GetLocation(node));
+  if (block->IsGlobalVar(name_id)) {
+    return block->MakeResolveSymbol(name_id);
+  }
+  return block->ReadVariable(name_id);
+}
+
+AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
+  MS_LOG(DEBUG) << "Process ast NoneType";
+  return NewValueNode(kNone);
+}
+
+AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
+  MS_LOG(DEBUG) << "Process ast Ellipsis";
+  return NewValueNode(kEllipsis);
+}
+
+AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Num";
+  py::object obj = python_adapter::GetPyObjAttr(node, "n");
+  TraceGuard trace_guard(GetLocation(node));
+  if (py::isinstance<py::int_>(obj)) {
+    MS_LOG(INFO) << "The Num is int: " << (std::string)py::str(obj);
+    auto data = py::cast<int>(obj);
+    return NewValueNode(data);
+  } else if (py::isinstance<py::float_>(obj)) {
+    MS_LOG(INFO) << "The Num is float: " << (std::string)py::str(obj);
+    auto data = py::cast<float>(obj);
+    return NewValueNode(data);
+  } else {
+    // no else actually
+    MS_LOG(ERROR) << "Unsupported Num type: " << (std::string)py::str(obj) << GetLocation(node)->ToString();
+    errcode_ = PARSE_NODE_TYPE_UNKOWN;
+    return nullptr;
+  }
+}
+
+AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Str";
+  auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
+  return NewValueNode(str_s);
+}
+
+AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast NameConstant";
+  py::object obj = python_adapter::GetPyObjAttr(node, "value");
+  TraceGuard trace_guard(GetLocation(node));
+  if (py::isinstance<py::bool_>(obj)) {
+    MS_LOG(INFO) << "The NameConstant is bool: " << (std::string)py::str(obj);
+    auto data = py::cast<bool>(obj);
+    return NewValueNode(data);
+  } else if (py::isinstance<py::none>(obj)) {
+    MS_LOG(INFO) << "The NameConstant is none: " << (std::string)py::str(obj);
+    return NewValueNode(kNone);
+  } else {
+    // no else actually
+    MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString();
+    errcode_ = PARSE_NODE_TYPE_UNKOWN;
+    return nullptr;
+  }
+}
+AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
+  AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
+  std::vector<AnfNodePtr> make_tuple_nodes;
+  make_tuple_nodes.push_back(make_tuple_op);
+  (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
+                       [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
+  return block->func_graph()->NewCNode(make_tuple_nodes);
+}
+// process a function call, e.g. f1(x, y) ...
+AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Call";
+  // process function call
+  py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
+  AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node);
+  // function call arguments should be passed in as groups and unpacked later using unpack call
+  py::list args = python_adapter::GetPyObjAttr(node, "args");
+  std::vector<AnfNodePtr> packed_arguments;
+  std::vector<AnfNodePtr> group_arguments;
+
+  bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
+  bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
+  // if there is a starred or keyword argument, unpack may be needed
+  bool need_unpack = need_unpack_args || need_unpack_keywords;
+
+  return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
+}
+
+AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
+                                          const std::vector<AnfNodePtr> &packed_arguments,
+                                          const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
+  // if there are keyword arguments or starred arguments, use an unpack_call op to unpack the arguments
+  if (need_unpack) {
+    std::vector<AnfNodePtr> unpack_call_nodes;
+    auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
+    unpack_call_nodes.push_back(unpack_call_op);
+    unpack_call_nodes.push_back(call_function_anf_node);
+    (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
+                         [](AnfNodePtr node) -> AnfNodePtr { return node; });
+    CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes);
+    return unpack_call;
+  }
+  // else there are no keyword or starred arguments, parse as normal arguments without unpack
+  std::vector<AnfNodePtr> func_call_nodes;
+  func_call_nodes.push_back(call_function_anf_node);
+  (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes),
+                       [](AnfNodePtr node) -> AnfNodePtr { return node; });
+  CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes);
+  return call_anf_node;
+}
+
+bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
+                             std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
+  bool need_unpack = false;
+  for (size_t i = 0; i < args.size(); i++) {
+    auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i])));
+    if (arg_node == AST_SUB_TYPE_STARRED) {
+      if (!group_arguments->empty()) {
+        packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
+      }
+      packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
+      group_arguments->clear();
+      need_unpack = true;
+    } else {
+      group_arguments->push_back(ParseExprNode(block, args[i]));
+    }
+  }
+  if (!group_arguments->empty()) {
+    packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
+  }
+  return need_unpack;
+}
+
+bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
+                                 std::vector<AnfNodePtr> *packed_arguments) {
+  bool need_unpack = false;
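+  // Keywords are packed into a key tuple and a value tuple joined by make_dict below; e.g. a
+  // call like f1(a, *b, k=1) roughly becomes
+  //   unpack_call(f1, make_tuple(a), b, make_dict(make_tuple("k"), make_tuple(1))).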
+  py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
+  if (!keywords.empty()) {
+    need_unpack = true;
+    std::vector<AnfNodePtr> keys;
+    std::vector<AnfNodePtr> values;
+    for (size_t index = 0; index < keywords.size(); index++) {
+      auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
+      auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
+      if (py::isinstance<py::none>(kw_key)) {
+        packed_arguments->push_back(ParseExprNode(block, kw_value));
+      } else {
+        auto kw_key_c = kw_key.cast<std::string>();
+        keys.push_back(NewValueNode(kw_key_c));
+        values.push_back(ParseExprNode(block, kw_value));
+      }
+    }
+    auto keys_tuple = GenerateMakeTuple(block, keys);
+    auto values_tuple = GenerateMakeTuple(block, values);
+    auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
+    std::vector<AnfNodePtr> make_dict_nodes;
+    make_dict_nodes.push_back(make_dict_op);
+    make_dict_nodes.push_back(keys_tuple);
+    make_dict_nodes.push_back(values_tuple);
+    packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes));
+  }
+  return need_unpack;
+}
+
+// process attribute access on a class instance, e.g. x.y()
+AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Attribute";
+
+  // process class value, e.g. self.xx
+  if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
+    if (ast_->IsClassMember(node)) {
+      std::string var_name = "self.";
+      std::string attr_name = node.attr("attr").cast<std::string>();
+      (void)var_name.append(attr_name);
+      auto attr_obj = ast()->obj().attr(attr_name.c_str());
+      if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
+          (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
+           py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
+           py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
+        return block->MakeResolveSymbol(var_name);
+      } else {
+        return block->ReadVariable(var_name);
+      }
+    }
+  }
+
+  // process the get attr
+  // Use the Primitive to replace the operation resolve node (getattr),
+  // because the getattr will eventually be converted to a Primitive node
+  AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
+
+  // process the attr body
+  py::object value_body = python_adapter::GetPyObjAttr(node, "value");
+  AnfNodePtr value_node = ParseExprNode(block, value_body);
+  if (value_node == nullptr) {
+    MS_LOG(WARNING) << "Parse attribute failed";
+    return nullptr;
+  }
+
+  // process the node attr
+  auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
+  MS_LOG(DEBUG) << "Attr = " << attr_str;
+  TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
+  AnfNodePtr attr_node = NewValueNode(attr_str);
+  TraceManager::EndTrace();
+
+  // create the apply node
+  return block->func_graph()->NewCNode({op_node, value_node, attr_node});
+}
+
+// Process comparison expressions, e.g. a == b, a > b.
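+// A single comparison such as 'a > b' becomes the CNode (resolved_op, a, b); chained
+// comparisons like 'a > b > c' carry more than one op and are rejected below.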
+AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast Compare";
+
+  // A Python comparison may carry multiple ops, e.g. 'if x > y > 5',
+  // but we only support one op for now.
+  py::list ops = python_adapter::GetPyObjAttr(node, "ops");
+  if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) {
+    MS_LOG(ERROR) << "MindSpore does not support comparisons with more than one operator for now, ops size = "
+                  << ops.size();
+    return nullptr;
+  }
+
+  py::object left = python_adapter::GetPyObjAttr(node, "left");
+  py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
+  AnfNodePtr left_node = ParseExprNode(block, left);
+  AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
+
+  MS_EXCEPTION_IF_NULL(block);
+  AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
+
+  return block->func_graph()->NewCNode({op_node, left_node, right_node});
+}
+
+AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list,
+                                          const py::object &op) {
+  // if there is only one value, no bool op is needed
+  if (value_list.size() == 1) {
+    AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
+    return first_node;
+  } else {
+    py::object first = value_list[0];
+    py::list rest;
+    for (size_t i = 1; i < value_list.size(); i++) {
+      rest.append(value_list[i]);
+    }
+
+    AnfNodePtr first_node = ParseExprNode(block, first);
+    AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op);
+    auto op_node = block->MakeResolveAstOp(op);
+    return block->func_graph()->NewCNode({op_node, first_node, rest_node});
+  }
+}
+
+// Process bool operation expressions, e.g. 'a and b', 'a or b'.
+AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast BoolOp";
+  py::object op_node = python_adapter::GetPyObjAttr(node, "op");
+  py::list op_values = python_adapter::GetPyObjAttr(node, "values");
+  return ProcessBoolOpValueList(block, op_values, op_node);
+}
+
+// Process a function def
+FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast FunctionDef";
+  FunctionBlockPtr function_block = ParseFunction(node, block);
+  MS_EXCEPTION_IF_NULL(function_block);
+
+  // get function name
+  py::str name = python_adapter::GetPyObjAttr(node, "name");
+  std::string function_name = name;
+  ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
+  block->WriteVariable(function_name, valuenode_graph);
+  return block;
+}
+
+// Process a lambda expression,
like lambda x,y: x + y +AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Lambda"; + FunctionBlockPtr func_block = MakeFunctionBlock(*this); + func_block->AddPrevBlock(block); + func_block->Mature(); + + // get lambda args + py::list args = ast_->GetArgs(node); + for (std::size_t i = 0; i < args.size(); i++) { + std::string arg = py::cast(args[i].attr("arg")); + TraceManager::DebugTrace(GetLocation(args[i])); + auto para_node = std::make_shared(func_block->func_graph()); + TraceManager::EndTrace(); + para_node->debug_info()->set_name(arg); + func_block->func_graph()->add_parameter(para_node); + func_block->WriteVariable(arg, para_node); + MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg; + } + + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); + func_block->func_graph()->set_output(lambda_body_node); + ValueNodePtr const_graph = NewValueNode(func_block->func_graph()); + return const_graph; +} + +// process a tuple +AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Tuple"; + MS_EXCEPTION_IF_NULL(block); + py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); + if (elts.size() == 0) { + auto empty_tuple = std::vector(); + return NewValueNode(std::make_shared(empty_tuple)); + } + + std::vector tuple_vec; + AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); + tuple_vec.emplace_back(make_tuple_op); + for (size_t i = 0; i < elts.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); + tuple_vec.emplace_back(node_ptr); + } + CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); + return tuple_app; +} + +// process a list +AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast List"; + MS_EXCEPTION_IF_NULL(block); + py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); + if (elts.size() == 0) { + auto empty_list = std::vector(); + return NewValueNode(std::make_shared(empty_list)); + } + + std::vector list_vec; + AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST); + list_vec.emplace_back(make_list_op); + for (size_t i = 0; i < elts.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); + list_vec.emplace_back(node_ptr); + } + CNodePtr list_app = block->func_graph()->NewCNode(list_vec); + return list_app; +} + +// process a subscript, such as x[y] , node expressed as value[slice] +AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Subscript"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + py::object value_node = python_adapter::GetPyObjAttr(node, "value"); + py::object slice_node = python_adapter::GetPyObjAttr(node, "slice"); + AnfNodePtr value = ParseExprNode(block, value_node); + AnfNodePtr slice = ParseExprNode(block, slice_node); + + return block->func_graph()->NewCNode({op_getitem, value, slice}); +} + +// process a slice, get the slice value +AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Slice"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE); + py::object start = python_adapter::GetPyObjAttr(node, "lower"); + py::object 
stop = python_adapter::GetPyObjAttr(node, "upper"); + py::object step = python_adapter::GetPyObjAttr(node, "step"); + AnfNodePtr start_node = ParseExprNode(block, start); + AnfNodePtr stop_node = ParseExprNode(block, stop); + AnfNodePtr step_node = ParseExprNode(block, step); + + return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node}); +} + +// process a extslice +AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast ExtSlice"; + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); + py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims"); + + std::vector node_vec; + node_vec.emplace_back(make_tuple_op); + for (size_t i = 0; i < slice_tuple.size(); i++) { + AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); + node_vec.emplace_back(node_ptr); + } + CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec); + return tuple_conde; +} + +// process a index, get the index number +AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Index"; + py::object value_node = python_adapter::GetPyObjAttr(node, "value"); + return ParseExprNode(block, value_node); +} + +// process a UnaryOp, +a, -b +AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast UnaryOp"; + py::object op = python_adapter::GetPyObjAttr(node, "op"); + + MS_EXCEPTION_IF_NULL(block); + // resolve the op + AnfNodePtr op_node = block->MakeResolveAstOp(op); + + py::object operand = python_adapter::GetPyObjAttr(node, "operand"); + AnfNodePtr operand_node = ParseExprNode(block, operand); + return block->func_graph()->NewCNode({op_node, operand_node}); +} + +// process a dict ast node expression +AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Dict"; + py::list keys = node.attr("keys"); + py::list values = node.attr("values"); + std::vector key_nodes; + std::vector value_nodes; + for (size_t i = 0; i < keys.size(); i++) { + key_nodes.push_back(ParseExprNode(block, keys[i])); + value_nodes.push_back(ParseExprNode(block, values[i])); + } + auto keys_tuple = GenerateMakeTuple(block, key_nodes); + auto values_tuple = GenerateMakeTuple(block, value_nodes); + MS_EXCEPTION_IF_NULL(block); + auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); + return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple}); +} + +// process a augment assign such as a += b; +FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast AugAssign"; + py::object op = python_adapter::GetPyObjAttr(node, "op"); + + MS_EXCEPTION_IF_NULL(block); + // resolve the op + AnfNodePtr op_node = block->MakeResolveAstOp(op); + py::object target_node = python_adapter::GetPyObjAttr(node, "target"); + MS_EXCEPTION_IF_NULL(ast_); + auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node))); + AnfNodePtr read_node = nullptr; + if (ast_type == AST_SUB_TYPE_NAME) { + read_node = ParseName(block, target_node); + } else if (ast_->IsClassMember(target_node)) { + read_node = ParseAttribute(block, target_node); + } else { + MS_LOG(EXCEPTION) << "Not supported augassign"; + } + if (read_node == nullptr) { + MS_LOG(EXCEPTION) << "Can not get target node "; + } + + 
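+  // 'a += b' is desugared here into 'a = resolved_add(a, b)': apply the resolved op to the
+  // target read above and the parsed value, then write the result back via WriteAssignVars.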
py::object value = python_adapter::GetPyObjAttr(node, "value"); + AnfNodePtr value_node = ParseExprNode(block, value); + CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node}); + WriteAssignVars(block, target_node, augassign_app); + return block; +} + +// process global declaration such as 'global x'; +FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast Global"; + MS_EXCEPTION_IF_NULL(block); + py::list vars = python_adapter::GetPyObjAttr(node, "names"); + for (auto &item : vars) { + block->AddGlobalVar(py::cast(item)); + } + return block; +} + +// process a if statement +FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast If"; + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(block, test_node); + MS_EXCEPTION_IF_NULL(block); + CNodePtr bool_node = block->ForceToBoolNode(condition_node); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + // process the if-true branch + py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); + + // if the return_ is set ,it has its own continuation block + if (true_end->func_graph()->get_return() == nullptr) { + true_end->Jump(after_block, nullptr); + } + + // process the orelse branch + py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); + FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode); + + // if the return_ is set ,it has its own continuation block + if (false_end->func_graph()->get_return() == nullptr) { + false_end->Jump(after_block, nullptr); + } + + block->ConditionalJump(bool_node, true_block, false_block); + after_block->Mature(); + return after_block; +} + +FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast While"; + MS_EXCEPTION_IF_NULL(block); + MS_LOG(INFO) << "Parse while statement"; + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr header_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr body_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + body_block->AddPrevBlock(header_block); + after_block->AddPrevBlock(header_block); + block->Jump(header_block, nullptr); + + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(header_block, test_node); + condition_node = header_block->ForceToWhileCond(condition_node); + body_block->Mature(); + header_block->ConditionalJump(condition_node, body_block, after_block); + + 
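+  // Block layout for 'while test: body': the current block jumps to header_block, which
+  // evaluates the condition and branches to body_block or after_block; the body jumps back
+  // to the header unless it has already returned, mirroring the CFG of the Python loop.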
// Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, nullptr}; + py::object body_node = python_adapter::GetPyObjAttr(node, "body"); + FunctionBlockPtr after_body = ParseStatements(body_block, body_node); + if (after_body->func_graph()->get_return() == nullptr) { + after_body->Jump(header_block, nullptr); + } + + header_block->Mature(); + after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. + return after_block; +} + +CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node, + const AnfNodePtr &op_iter) { + py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); + AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); + return block->func_graph()->NewCNode({op_iter, iter_anf_node}); +} + +CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, + const AnfNodePtr &op_hasnext) { + MS_EXCEPTION_IF_NULL(header_block); + return header_block->func_graph()->NewCNode({op_hasnext, iter_param}); +} + +FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { + TraceManager::DebugTrace(trace_info); + FunctionBlockPtr body_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + return body_block; +} + +// A for loop will generate 3 functions :the test, the body, and the continuation +// for x in xs: +// body +// it is compiled to be following statement +// if len(xs) < max_loop_cnt: +// ParseForIter() // use iter to implement for loop, which always unroll loop +// else: +// ParseForLoop() // use loop var to implement for loop, which always sink loop +FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast For, create an if else statement"; + MS_EXCEPTION_IF_NULL(block); + // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' + AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); + py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); + AnfNodePtr iter_node = ParseExprNode(block, iter_obj); + CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); + CNodePtr bool_node = block->func_graph()->NewCNode( + {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); + + // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr after_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + FunctionBlockPtr true_end = ParseForIter(true_block, node); + true_end->Jump(after_block, nullptr); + + FunctionBlockPtr false_end = ParseForLoop(false_block, node); + false_end->Jump(after_block, nullptr); + + block->ConditionalJump(bool_node, true_block, false_block); + after_block->Mature(); + return after_block; +} + +// A for loop will generate 3 
functions: the test, the body, and the continuation
+// for x in xs:
+//    body
+// it is compiled to the following statements
+// it = iter(xs)
+// while hasnext(it)
+//    x, it = next(it)
+//    body
+FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast For";
+  MS_EXCEPTION_IF_NULL(block);
+  AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
+  AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
+  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
+  AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
+  // generate the iterator apply
+  CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
+  MS_EXCEPTION_IF_NULL(iter_apply);
+  FunctionBlockPtr header_block =
+    GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
+  MS_EXCEPTION_IF_NULL(header_block);
+  // generate the hasnext apply which is a condition
+  ParameterPtr iter_param = header_block->func_graph()->add_parameter();
+  CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
+  // generate the body of the for statement
+  FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
+  MS_EXCEPTION_IF_NULL(body_block);
+  body_block->AddPrevBlock(header_block);
+  // generate the iterator next apply
+  // process as follows: `app = next(it); target = app[0]; it = app[1];`
+  CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param});
+  CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)});
+  py::object target_node = python_adapter::GetPyObjAttr(node, "target");
+
+  CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)});
+  WriteAssignVars(body_block, target_node, target_app);
+
+  // link the variable name with the target
+  auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
+  iter_param->debug_info()->set_trace_info(it_info);
+  iter2_app->debug_info()->set_trace_info(it_info);
+  iter_apply->debug_info()->set_trace_info(it_info);
+
+  TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
+  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
+  MS_EXCEPTION_IF_NULL(after_block);
+  TraceManager::EndTrace();
+  after_block->AddPrevBlock(header_block);
+
+  block->Jump(header_block, iter_apply);
+  body_block->Mature();
+  header_block->ConditionalJump(cond_apply, body_block, after_block);
+
+  // Parse loop body statements with loop context.
+  LoopContext loop_context{&loops_, header_block, iter2_app};
+  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
+  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
+  if (after_body_block->func_graph()->get_return() == nullptr) {
+    after_body_block->Jump(header_block, iter2_app);
+  }
+
+  header_block->Mature();
+  after_block->Mature();
+  auto &end_block = loop_context.EndBlock();
+  if (end_block) {
+    // end_block exists if we encounter 'break' in the loop body.
+    after_block->Jump(end_block, nullptr);
+    end_block->Mature();
+    return end_block;
+  }
+  // No 'break', no end_block.
+  return after_block;
+}
+
+// A for loop will generate 3 functions: the test, the body, and the continuation
+// for x in xs:
+//    body
+// it is compiled to the following statements
+// i = 0
+// while i < len(xs)
+//    x = xs[i]
+//    i = i + 1
+//    body
+FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast For by loop variable";
+  MS_EXCEPTION_IF_NULL(block);
+  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
+  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
+
+  // get the variable name of 'x' in statement 'for x in xs'
+  py::object target_node = python_adapter::GetPyObjAttr(node, "target");
+
+  // create statement 'len(xs)'
+  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
+  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
+  MS_EXCEPTION_IF_NULL(iter_node);
+  CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node});
+
+  FunctionBlockPtr header_block =
+    GenerateBlockInFor(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
+  MS_EXCEPTION_IF_NULL(header_block);
+  // create loop variable 'i'
+  ParameterPtr loop_var = header_block->func_graph()->add_parameter();
+  // create loop condition 'i < len(xs)'
+  CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter});
+
+  // generate the body of the for statement
+  FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
+  MS_EXCEPTION_IF_NULL(body_block);
+  body_block->AddPrevBlock(header_block);
+  // create 'x = xs[i]'
+  CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var});
+  WriteAssignVars(body_block, target_node, target_var);
+  // create 'i = i + 1'
+  CNodePtr loop_var_inc =
+    body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)});
+  body_block->WriteVariable(loop_var->name(), loop_var_inc);
+
+  // link the variable name with the target
+  auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
+  loop_var->debug_info()->set_trace_info(it_info);
+  len_iter->debug_info()->set_trace_info(it_info);
+
+  TraceManager::DebugTrace(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
+  FunctionBlockPtr after_block = MakeFunctionBlock(*this);
+  MS_EXCEPTION_IF_NULL(after_block);
+  TraceManager::EndTrace();
+  after_block->AddPrevBlock(header_block);
+
+  block->Jump(header_block, NewValueNode(0));
+  body_block->Mature();
+
+  header_block->ConditionalJump(cond_node, body_block, after_block, false);
+
+  // Parse loop body statements with loop context.
+  LoopContext loop_context{&loops_, header_block, loop_var_inc};
+  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
+  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
+  if (after_body_block->func_graph()->get_return() == nullptr) {
+    after_body_block->Jump(header_block, loop_var_inc);
+  }
+
+  header_block->Mature();
+  after_block->Mature();
+  auto &end_block = loop_context.EndBlock();
+  if (end_block) {
+    // end_block exists if we encounter 'break' in the loop body.
+    after_block->Jump(end_block, nullptr);
+    end_block->Mature();
+    return end_block;
+  }
+  // No 'break', no end_block.
+ return after_block; +} + +AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { + MS_LOG(DEBUG) << "Process ast IfExp"; + MS_EXCEPTION_IF_NULL(block); + py::object test_node = python_adapter::GetPyObjAttr(node, "test"); + AnfNodePtr condition_node = ParseExprNode(block, test_node); + CNodePtr bool_node = block->ForceToBoolNode(condition_node); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr true_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + FunctionBlockPtr false_block = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + + MakeConditionBlocks(block, true_block, false_block); + + // process the if-true branch + py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); + true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode)); + AnfNodePtr true_node = ParseExprNode(true_block, bodyNode); + + // process the orelse branch + py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); + false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode)); + AnfNodePtr false_node = ParseExprNode(false_block, orelseNode); + + true_block->func_graph()->set_output(true_node); + false_block->func_graph()->set_output(false_node); + + // Use the Primitive replace the operation resolve node (switch) + // because the switch will eventually be converted to Primitive node + CNodePtr switch_app = + block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); + + std::vector call_graph_nodes{switch_app}; + CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); + return switch_app_call; +} + +void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + MS_EXCEPTION_IF_NULL(assigned_node); + py::str name = python_adapter::GetPyObjAttr(targ, "id"); + std::string name_id = name; + assigned_node->debug_info()->set_name(name_id); + // set the debug name of the constant graph + if (IsValueNode(assigned_node)) { + // the value should be graph + auto fg = GetValueNode(assigned_node); + if (fg->debug_info()->name().empty()) { + fg->debug_info()->set_name(name_id); + } + } + block->WriteVariable(name_id, assigned_node); +} + +void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); + py::list items = python_adapter::GetPyObjAttr(targ, "elts"); + for (size_t i = 0; i < items.size(); i++) { + // Use the Primitive replace the operation resolve node (getitem) + // because the getitem will eventually be converted to Primitive node + CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))}); + + py::object elt = items[i]; + WriteAssignVars(block, elt, item_apply); + } +} + +void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, + const AnfNodePtr &assigned_node) { + // Now only support the self.xx = xxxxx, can't support x.y = xxxx + AnfNodePtr target_node = ParseExprNode(block, targ); + MS_EXCEPTION_IF_NULL(target_node); + + std::string attr_name = targ.attr("attr").cast(); + std::string var_name = "self."; + 
(void)var_name.append(attr_name); + MS_LOG(DEBUG) << "assign " << var_name; + + // Get targ location info for error printing + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type + if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" + << line_no; + } + auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); + auto obj_type = obj.attr("__class__").attr("__name__"); + if (!py::hasattr(obj, "__parameter__")) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" + << py::str(obj).cast() << "' with type '" + << py::str(obj_type).cast() << "' at " << filename << ":" << line_no; + } + + MS_EXCEPTION_IF_NULL(block); + block->WriteVariable(var_name, assigned_node); + MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); + block->SetStateAssgin(target_node, var_name); +} + +void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, + const AnfNodePtr &assigned_node) { + MS_EXCEPTION_IF_NULL(block); + AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); + py::object value_obj = python_adapter::GetPyObjAttr(targ, "value"); + py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice"); + AnfNodePtr value_node = ParseExprNode(block, value_obj); + AnfNodePtr slice_node = ParseExprNode(block, slice_obj); + CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); + // getitem apply should return the sequence data structure itself + std::string var_name = ""; + if (ast_->IsClassMember(value_obj)) { + std::string attr_name = value_obj.attr("attr").cast(); + var_name = "self." 
+ attr_name;
+    if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
+      MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function.";
+    }
+    auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
+    auto obj_type = obj.attr("__class__").attr("__name__");
+    if (!py::hasattr(obj, "__parameter__")) {
+      MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '"
+                              << py::str(obj).cast<std::string>() << "' with type '"
+                              << py::str(obj_type).cast<std::string>() << "'.";
+    }
+  } else {
+    var_name = value_obj.attr("id").cast<std::string>();
+  }
+  block->WriteVariable(var_name, setitem_app);
+}
+
+void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) {
+  MS_EXCEPTION_IF_NULL(value_node);
+  MS_LOG(DEBUG) << "Process WriteAssignVars";
+  auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ)));
+  if (ast_type == AST_SUB_TYPE_NAME) {
+    HandleAssignName(block, targ, value_node);
+  } else if (ast_type == AST_SUB_TYPE_TUPLE) {
+    HandleAssignTuple(block, targ, value_node);
+  } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
+    HandleAssignSubscript(block, targ, value_node);
+  } else if (ast_->IsClassMember(targ)) {
+    HandleAssignClassMember(block, targ, value_node);
+  } else {
+    MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type
+                      << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info());
+  }
+}
+
+// Process an assign statement, such as 'a = b' or 'a, b = tup'
+FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
+  MS_LOG(DEBUG) << "Process ast assign";
+  py::object value_object = python_adapter::GetPyObjAttr(node, "value");
+  AnfNodePtr value_node = ParseExprNode(block, value_object);
+  py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
+  py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
+  size_t count = IntToSize(pcount);
+  MS_LOG(DEBUG) << "The nodes count is " << count;
+  for (size_t i = 0; i < count; i++) {
+    auto target_node = py::cast<py::list>(targets_object)[i];
+    WriteAssignVars(block, target_node, value_node);
+  }
+
+  return block;
+}
+
+FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
+  if (loops_.empty()) {
+    // Report error if loop context is not set for the 'break' statement.
+    py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
+    if (location.size() < 2) {
+      MS_LOG(EXCEPTION) << "List size should not be less than 2.";
+    }
+    auto filename = location[0].cast<std::string>();
+    auto line_no = location[1].cast<int>();
+    MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no;
+  }
+  // Get current loop.
+  Loop &loop = loops_.top();
+  if (loop.end == nullptr) {
+    // Create end_block if it does not exist.
+    TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
+    loop.end = MakeFunctionBlock(*this);
+    TraceManager::EndTrace();
+  }
+  // Jump to the end_block.
+  block->Jump(loop.end, nullptr);
+  return block;
+}
+
+FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
+  if (loops_.empty()) {
+    // Report error if loop context is not set for the 'continue' statement.
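+    // e.g. a bare 'continue' outside any 'for'/'while' body leaves the loop stack empty.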
+ py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no; + } + // Jump to the header of the loop with iterator called. + Loop &loop = loops_.top(); + block->Jump(loop.header, loop.iterator); + return block; +} + +FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { + // We just bypass 'pass' statement. + return block; +} + +void Parser::RemoveUnnecessaryPhis() { + // merge all removable phis to one map; + std::unordered_map removable_phis; + for (FunctionBlockPtr &block : func_block_list_) { + MS_EXCEPTION_IF_NULL(block); + removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end()); + } + + if (removable_phis.size() == 0) { + return; + } + for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) { + if (node->isa()) { + const auto &cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (std::size_t i = 0; i < inputs.size(); i++) { + if (inputs[i]->isa()) { + const auto &inp = inputs[i]->cast(); + const auto &iter = removable_phis.find(inp); + if (iter == removable_phis.end()) { + continue; + } + auto &argNode = iter->second; + MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in " + << cnode->DebugString() << " with " << argNode->DebugString(); + cnode->set_input(i, argNode); + } + } + } + } +} + +// ParseAst class code +bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) { + // init the type + target_type_ = PARSE_TARGET_UNKNOW; + + // call python parse, get the parser fn + module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD); + + // get the obj type + auto type = data_converter::GetObjType(obj_); + if (type == RESOLVE_TYPE_FUNCTION) { + target_type_ = PARSE_TARGET_FUNCTION; + function_ = obj_; + } else if (type == RESOLVE_TYPE_METHOD) { + // process the method ,need get the method's self obj + target_type_ = PARSE_TARGET_METHOD; + py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS); + if (py::isinstance(method_object)) { + MS_LOG(ERROR) << "Get method's self object instance failed."; + return false; + } + target_type_ = PARSE_TARGET_OBJECT_INSTANCE; + function_ = obj_; + obj_ = method_object; + } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) { + // obj is class instance, get the method to parse. 
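+    // (typically the instance's 'construct' method, as chosen by the python-side helper)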
+ function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method); + if (py::isinstance(function_)) { + MS_LOG(ERROR) << "Get obj method function failed."; + return false; + } + target_type_ = PARSE_TARGET_OBJECT_INSTANCE; + // check the fn is method + auto obj_type = data_converter::GetObjType(function_); + if (obj_type != RESOLVE_TYPE_METHOD) { + MS_LOG(WARNING) << "Parse method function is invalid."; + return false; + } + } else { + MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type; + return false; + } + + // call python parse get ast tree + parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method); + ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse"); + + // get fn name and module + function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module")); + function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name")); + function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename")); + function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset")); + + return true; +} + +// Get ast tree node : is the tree bode list[0] +py::object ParseAst::GetAstNode() { + py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body"); + py::object ast_node = tree_body[0]; + return ast_node; +} + +py::list ParseAst::GetArgs(const py::object &func_node) { + py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node); + return ret; +} + +py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) { + py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node); + return ret; +} + +AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) { + py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node); + if (list_value.size() < 2) { + MS_LOG(ERROR) << "The node of python method must has 2 values."; + return nullptr; + } + auto node_name = py::cast(list_value[0]); + auto type = AstMainType(py::cast(list_value[1])); + return std::make_shared(node, node_name, type); +} + +AstSubType ParseAst::GetOpType(const py::object &node) { + auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast()); + return op_type; +} + +bool ParseAst::IsClassMember(const py::object &node) { + py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node); + if (!py::isinstance(ret)) { + MS_LOG(ERROR) << "The result of mod function parse, should be bool type."; + return false; + } + return ret.cast(); +} + +bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + MS_LOG(ERROR) << "FuncGraph is null"; + return false; + } + + if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { + MS_LOG(DEBUG) << "No flags"; + return true; + } + py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); + for (auto &item : flags) { + if (!py::isinstance(item.first)) { + MS_LOG(ERROR) << "Type error in flags dict convert"; + return false; + } + auto name = py::cast(item.first); + if (py::isinstance(item.second)) { + auto value = py::cast(item.second); + MS_LOG(DEBUG) << "Flag name: " << name << ". 
Value: " << value; + func_graph->set_flag(name, value); + } else if (py::isinstance(item.second)) { + auto value = py::cast(item.second); + MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; + func_graph->set_attr(name, MakeValue(value)); + } else { + MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; + return false; + } + } + return true; +} + +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.h b/mindspore/ccsrc/pipeline/jit/parse/parse.h new file mode 100644 index 0000000000..90e965389f --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.h @@ -0,0 +1,360 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_PARSE_H_ +#define PIPELINE_PARSE_PARSE_H_ + +#include +#include +#include +#include +#include +#include +#include "utils/misc.h" +#include "ir/anf.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/function_block.h" + +namespace mindspore { +namespace parse { + +// Parse status define +enum ParseStatusCode : int { + PARSE_SUCCESS = 0, + PARSE_FUNCTION_IS_NULL, // python function is null + PARSE_PARAMETER_INVALID, // parameter is invalid + PARSE_NO_RETURN, // function no return node + PARSE_NODE_TYPE_NO_MATCH, // ast node type is error + PARSE_NODE_TYPE_UNKOWN, // node type is unkown + PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node + PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string + PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported + PARSE_FAILURE = 0xFF +}; + +class AstNodeType; +class ParseAst; + +// Save loop info for 'continue' and 'break' statements. +struct Loop { + // Loop header block. + FunctionBlockPtr header; + // Loop iterator node, used in 'for loop'. + AnfNodePtr iterator; + // Loop end block. + FunctionBlockPtr end; + + Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) + : header(header), iterator(iterator), end(end) {} + ~Loop() = default; +}; + +// Loop context for loop stack management. 
+class LoopContext { + public: + LoopContext(std::stack *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { + loops_->emplace(header, iterator, nullptr); + } + ~LoopContext() { loops_->pop(); } + const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } + + private: + std::stack *loops_; +}; + +// Parser to parse python function +class Parser { + public: + explicit Parser(const std::shared_ptr &ast); + + ~Parser() {} + FuncGraphPtr ParseFuncGraph(); + FuncGraphPtr func_graph() const { return func_graph_; } + ParseStatusCode errcode() const { return errcode_; } + std::shared_ptr ast() const { return ast_; } + // get location info from the ast node + LocationPtr GetLocation(const py::object &node) const; + static void InitParserEnvironment(const py::object &obj); + static void CleanParserResource(); + static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); } + static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph); + + private: + // process the stmt node method list + FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node); + // parse expression + FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node); + // process a if statement + FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node); + // process a while statement + FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); + // process a for statement + FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); + FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node); + FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node); + // process a function def statement + FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node); + // process a augment assign + FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node); + // process a global declaration + FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); + // process assign statement + FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); + // process break statement + FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); + // process continue statement + FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); + // process pass statement + FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); + // process the expr and slice node method list + AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); + // process a variable name + AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); + // process NoneType + AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); + // process Ellipsis + AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); + // process a integer or float number + AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); + // process a string variable + AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node); + // process a name + AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); + // process a function call + AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); + // process the if expression + AnfNodePtr ParseIfExp(const 
FunctionBlockPtr &block, const py::object &node);
+  // process attribute access on a class instance, e.g. x.y()
+  AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
+  // process a compare expression
+  AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
+  // process a bool operation
+  AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
+  // process a lambda operation
+  AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
+  // process a tuple
+  AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
+  // process a list
+  AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
+  // process a subscript
+  AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
+  // process a slice
+  AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
+
+  // process an extended slice
+  AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
+
+  // process an index
+  AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
+
+  // process a unary operation
+  AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
+
+  // process a dict ast node expression
+  AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
+  // generate argument nodes for an ast function node
+  void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
+  // generate argument default values for an ast function node
+  void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
+  // parse an ast function node
+  FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
+  // parse ast statements
+  FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
+  // parse one ast statement node
+  FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
+  // parse an ast expression node
+  AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
+
+  void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
+                           const FunctionBlockPtr &falseBlock);
+  void RemoveUnnecessaryPhis();
+  // write a new var
+  void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
+
+  // assign value to a single variable name
+  void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
+
+  // assign value to a tuple
+  void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
+
+  // assign value to a class member
+  void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
+
+  // assign value to a subscript
+  void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
+
+  // process a bool operation value list
+  AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op);
+
+  CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
+                                 const AnfNodePtr &op_iter);
+
+  CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
+                             const AnfNodePtr &op_hasnext);
+
+  FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr
&trace_info); + + bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, + std::vector *packed_arguments); + + bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, + std::vector *group_arguments); + + AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, + const std::vector &packed_arguments, + const std::vector &group_arguments, bool need_unpack) const; + ScopePtr GetScopeForParseFunction(); + void BuildMethodMap(); + FunctionBlockPtr MakeFunctionBlock(const Parser &parse) { + FunctionBlockPtr block = std::make_shared(parse); + // In order to keep effect order in the sub-graphs which generated by control flow. + // We copy the flags from the top graph to the sub-graphs. + if (func_graph_ && !func_graph_->attrs().empty()) { + block->func_graph()->set_attrs(func_graph_->attrs()); + } + func_block_list_.push_back(block); + return block; + } + // return a make tuple for input elements list + AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes); + + // shared_ptr will be hold by GraphManager, so just hold a weak ref here. + static FuncGraphWeakPtr top_func_graph_; + // Python function id, used to indicate whether two CNodes come from the same Python function + const std::shared_ptr &ast_; + FuncGraphPtr func_graph_; + // error code setwhen parsing ast tree + ParseStatusCode errcode_; + + // hold all reference for FunctionBlock in this round of parsing, + // so in FunctionBlock class we can use FunctionBlock* in member + // pre_blocks_ and jumps_ to break reference cycle. + std::vector func_block_list_; + using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); + using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); + // define the function map to parse ast Statement + std::map stmt_method_map_; + // define the function map to parse ast expression + std::map expr_method_map_; + // Save current loops to support 'continue', 'break' statement. + std::stack loops_; +}; + +// AST node type define code to ast +class AstNodeType { + public: + AstNodeType(const py::object &node, const std::string &name, AstMainType type) + : node_(node), node_name_(name), main_type_(type) {} + + ~AstNodeType() {} + + std::string node_name() const { return node_name_; } + + py::object node() const { return node_; } + + AstMainType main_type() const { return main_type_; } + + private: + const py::object &node_; + const std::string node_name_; + AstMainType main_type_; +}; + +using AstNodeTypePtr = std::shared_ptr; + +// A helper class to parse python function +class ParseAst { + public: + explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} + + ~ParseAst() = default; + + bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); + + py::object GetAstNode(); + + py::list GetArgs(const py::object &func_node); + + py::list GetArgsDefaultValues(const py::object &func_node); + + AstNodeTypePtr GetNodeType(const py::object &node); + + AstSubType GetOpType(const py::object &node); + + template + py::object CallParserObjMethod(const std::string &method, const T &... args) { + return python_adapter::CallPyObjMethod(parser_, method, args...); + } + + template + py::object CallParseModFunction(const std::string &function, const T &... 
args) { + return python_adapter::CallPyModFn(module_, function, args...); + } + + const std::string &function_name() const { return function_name_; } + + const std::string &function_module() const { return function_module_; } + + const std::string &function_filename() const { return function_filename_; } + + int function_line_offset() const { return function_line_offset_; } + + py::function function() { return function_; } + + ParseTargetTypeDef target_type() const { return target_type_; } + + py::object obj() { return obj_; } + + py::object parser() { return parser_; } + + py::object module() { return module_; } + + py::object ast_tree() { return ast_tree_; } + + bool IsClassMember(const py::object &node); + + private: + // save obj, e.g. a class instance or function + py::object obj_; + + // function or class method. + py::function function_; + + py::object ast_tree_; + py::object parser_; + py::module module_; + + // Is it a function or a method + ParseTargetTypeDef target_type_; + + std::string function_name_; + std::string function_module_; + std::string function_filename_; + int function_line_offset_; +}; + +// update the graph flags +bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); + +AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_PARSE_H_ diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h similarity index 100% rename from mindspore/ccsrc/pipeline/parse/parse_base.h rename to mindspore/ccsrc/pipeline/jit/parse/parse_base.h diff --git a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc new file mode 100644 index 0000000000..17be74b2a1 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/python_adapter.h" +#include +#include +#include + +namespace mindspore { +namespace parse { +namespace python_adapter { +// python scoped env, should only have one scoped_ instance +static std::shared_ptr scoped_ = nullptr; +// true: start process from python, false: start process from c++ +static bool python_env_ = false; +static bool use_signature_in_resolve_ = true; +void ResetPythonScope() { scoped_ = nullptr; } +void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_in_resolve_ = use_signature; } +bool UseSignatureInResolve() { return use_signature_in_resolve_; } +void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } +bool IsPythonEnv() { return python_env_; } +void SetPythonPath(const std::string &path) { + // load the python module path + (void)python_adapter::set_python_scoped(); + py::module sys = py::module::import("sys"); + py::list sys_path = sys.attr("path"); + + // check whether the path already exists
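+ // Conceptually, the existence check below mirrors the Python idiom (comment sketch only, not new code): + // import sys + // if path not in sys.path: + //     sys.path.append(path)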
+ bool is_exist = false; + for (size_t i = 0; i < sys_path.size(); i++) { + std::string path_str = py::cast(sys_path[i]); + if (path_str == path) { + is_exist = true; + } + } + if (!is_exist) { + (void)sys_path.attr("append")(path.c_str()); + } +} + +std::shared_ptr set_python_scoped() { + // if the process starts from python, there is no need to set the python scope. + if (!python_env_) { + if ((Py_IsInitialized() == 0) && (scoped_ == nullptr)) { + scoped_ = std::make_shared(); + } + } + return scoped_; +} + +// return the python module +py::module GetPyModule(const std::string &module) { + if (!module.empty()) { + return py::module::import(module.c_str()); + } else { + return py::none(); + } +} + +// Get the attribute of the given object +py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { + if (!attr.empty() && !py::isinstance(obj)) { + if (py::hasattr(obj, attr.c_str())) { + return obj.attr(attr.c_str()); + } + MS_LOG(DEBUG) << "Object has no attribute: " << attr; + } + return py::none(); +} + +py::object GetPyFn(const std::string &module, const std::string &name) { + (void)python_adapter::set_python_scoped(); + if (!module.empty() && !name.empty()) { + py::module mod = py::module::import(module.c_str()); + py::object fn = mod.attr(name.c_str()); + return fn; + } + return py::none(); +} + +} // namespace python_adapter +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h new file mode 100644 index 0000000000..0f49539bc8 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/python_adapter.h @@ -0,0 +1,78 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#define PIPELINE_PARSE_PYTHON_ADAPTER_H_ +#include +#include +#include + +#include "pybind11/embed.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "pipeline/jit/parse/parse_base.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace parse { +// A utility to call the python interface +namespace python_adapter { +py::module GetPyModule(const std::string &module); +py::object GetPyObjAttr(const py::object &obj, const std::string &attr); +template +py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { + if (!method.empty() && !py::isinstance(obj)) { + return obj.attr(method.c_str())(args...); + } + return py::none(); +} + +// call a python function of the module +template +py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { + if (!function.empty() && !py::isinstance(mod)) { + return mod.attr(function.c_str())(args...); + } + return py::none(); +} + +// turn off the signature when UT uses the parser to construct a graph.
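+ // A hypothetical unit-test call pattern (illustrative only; just the setter declared below is real): + //   parse::python_adapter::set_use_signature_in_resolve(false); + //   ... build the graph through the parser under test ... + //   parse::python_adapter::set_use_signature_in_resolve(true);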
+void set_use_signature_in_resolve(bool use_signature) noexcept; +bool UseSignatureInResolve(); + +std::shared_ptr set_python_scoped(); +void ResetPythonScope(); +bool IsPythonEnv(); +void SetPythonPath(const std::string &path); +void set_python_env_flag(bool python_env) noexcept; +py::object GetPyFn(const std::string &module, const std::string &name); +// Call the python function +template +py::object CallPyFn(const std::string &module, const std::string &name, T... args) { + (void)set_python_scoped(); + if (!module.empty() && !name.empty()) { + py::module mod = py::module::import(module.c_str()); + py::object fn = mod.attr(name.c_str())(args...); + return fn; + } + return py::none(); +} +} // namespace python_adapter +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_PYTHON_ADAPTER_H_ diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc new file mode 100644 index 0000000000..9524da4cfd --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -0,0 +1,320 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/parse/resolve.h" + +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "utils/any.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass.h" +#include "./common.h" + +namespace mindspore { +namespace parse { +abstract::AbstractBasePtr ClassObject::ToAbstract() { + ClassPtr cls_ptr = ParseDataClass(obj()); + auto abs_scalar = std::make_shared(); + abs_scalar->set_type(std::make_shared()); + abs_scalar->set_value(cls_ptr); + + AbstractBasePtrList args_spec_list = {abs_scalar}; + auto func_ptr = std::make_shared(prim::kPrimMakeRecord); + return std::make_shared(func_ptr, args_spec_list); +} + +abstract::AbstractBasePtr ClassType::ToAbstract() { + auto abs_scalar = + std::make_shared(shared_from_base(), std::make_shared()); + AbstractBasePtrList args_spec_list = {abs_scalar}; + + auto func_ptr = std::make_shared(prim::kPrimCreateInstance); + auto ret_val = std::make_shared(func_ptr, args_spec_list); + ret_val->set_value_desc(ToString()); + return ret_val; +} + +// call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace +bool SymbolResolver::Resolve() { + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + + py::object obj = namespace_->obj(); + std::string symbol = symbol_->symbol(); + if (py::isinstance(obj)) { + MS_LOG(ERROR) << "Unresolved symbol: " << symbol; + return false; + } + result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol)); + return true; +} + +namespace { +// argument obj should be python Parameter object +// it will be converted to Parameter node here +AnfNodePtr 
ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { + MS_EXCEPTION_IF_NULL(func_graph); + + // parameter object should not be none + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null."; + } + + if (!py::hasattr(obj, "name")) { + MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj"; + } + + // get the parameter name from parameter object + auto name_attr = python_adapter::GetPyObjAttr(obj, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; + } + + std::string param_name = py::cast(name_attr); + auto top_graph = Parser::GetTopFuncGraph(); + // if the parameter node has been created , return it + AnfNodePtr para_node = nullptr; + for (auto const ¶m : top_graph->parameters()) { + auto param_node = dyn_cast(param); + if (param_node != nullptr && param_node->name() == param_name) { + para_node = param; + break; + } + } + if (para_node == nullptr) { + auto node = top_graph->AddWeightParameter(param_name); + auto param_value = py::cast(python_adapter::GetPyObjAttr(obj, "_value")); + node->set_default_param(param_value); + // set_abstract for parameter + ValuePtr value = param_value->value(); + constexpr bool broaden = true; + node->set_abstract(abstract::FromValue(value, broaden)); + para_node = node; + } + auto iter = func_graph->make_ref_params().find(para_node); + if (iter == func_graph->make_ref_params().end()) { + AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); + + AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); + AnfNodePtr ref_key = NewValueNode(std::make_shared(param_name)); + AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); + func_graph->make_ref_params()[para_node] = ref_node; + func_graph->add_parameter_obj_node(ref_node); + return ref_node; + } else { + return iter->second; + } +} + +bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { + AnfNodePtr output = nullptr; + if (py::hasattr(obj, "__parameter__")) { + auto param = ResolveParameterObj(func_graph, obj); + if (param == nullptr) { + MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; + return false; + } + MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString(); + + output = param; + } else if (py::hasattr(obj, "__parameter_tuple__")) { + auto tuple = obj.cast(); + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t it = 0; it < tuple.size(); ++it) { + AnfNodePtr out = nullptr; + bool success = ResolveObjectToNode(func_graph, tuple[it], &out); + if (!success) { + MS_LOG(ERROR) << "Resolve object to node failed"; + return false; + } + args.push_back(out); + } + output = NewCNode(args, func_graph); + } else { + ValuePtr convert_result = nullptr; + bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve()); + if (!converted) { + MS_LOG(ERROR) << "Convert data failed"; + return false; + } + MS_EXCEPTION_IF_NULL(convert_result); + output = NewValueNode(convert_result); + if (convert_result->isa()) { + output = GetMixedPrecisionCastHelp(func_graph, output); + } + } + *node = output; + return true; +} + +bool IsAllGraphInValueSequence(const std::vector &value_vec) { + for (auto &elem : value_vec) { + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + auto is_graph = IsAllGraphInValueSequence(vec); + if (!is_graph) 
{ + return false; + } + } else if (!elem->isa()) { + return false; + } + } + return true; +} + +AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const std::vector &value_vec) { + std::vector nodes; + nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (auto &elem : value_vec) { + AnfNodePtr node = nullptr; + if (elem->isa() || elem->isa()) { + const auto &vec = GetValue>(elem); + node = TransformToMakeTupleNodes(manager, func_graph, vec); + } else if (elem->isa()) { + FuncGraphPtr new_fg = elem->cast(); + manager->AddFuncGraph(new_fg); + node = NewValueNode(new_fg); + } else { + MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expected a funcgraph, got " << elem->ToString(); + } + nodes.emplace_back(node); + } + auto cnode = func_graph->NewCNode(nodes); + return cnode; +} + +// transform a ValueTuple or ValueList of graph nodes into a make_tuple of graph value nodes +bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { + MS_EXCEPTION_IF_NULL(value_node); + const auto &value_vec = GetValue>(value_node->value()); + if (!IsAllGraphInValueSequence(value_vec)) { + return false; + } + + // A CellList or ordered cell will be parsed as a ValueTuple of const graphs, + // so if the list contains graphs, try to replace the node with a make_tuple of graph value nodes. + // We do this because the graph manager won't investigate the graphs inside a ValueTuple, + // so we change the vector of graphs into a make_tuple of graph value nodes. + auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); + // replace the returned ptr with the make_tuple of graph value nodes + *transformed = node_tuple_graphs; + + return true; +} +} // namespace + +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { + if (node->func_graph() == nullptr || manager == nullptr) { + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; + } + SymbolResolver symbol_resolver(name_space, symbol, node); + if (!symbol_resolver.Resolve()) { + MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + + py::object obj = symbol_resolver.result(); + ScopeGuard scope_guard(node->scope()); + AnfNodePtr resolved_node = nullptr; + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node); + if (!success) { + MS_LOG(EXCEPTION) << "Parse Resolve convert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + if (IsValueNode(resolved_node)) { + auto new_fg = GetValueNode(resolved_node); + manager->AddFuncGraph(new_fg); + } + + // if the constant node is a constant vector of graphs, add the graphs to the manager + if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { + (void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast(), + &resolved_node); + } + + TraceManager::EndTrace(); + return resolved_node; +} + +namespace { +opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { + opt::OptPassGroupMap map({ + {"resolve", + { + // for the resolve and getattr primitives; + irpass.resolver_resolve_, + irpass.resolver_getattr_, + }}, + }); + return map; +} +} // namespace + +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const
pipeline::ResourceBasePtr &res, bool use_profile) { + if (func_graph == nullptr || res == nullptr) { + MS_LOG(ERROR) << "func_graph or resource is null"; + return false; + } + opt::irpass::ResolveIRPassLib irpass; + opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); + + (void)parse::python_adapter::set_python_scoped(); + + MS_EXCEPTION_IF_NULL(opt_resolve); + (void)opt_resolve->step(func_graph, use_profile); + return true; +} + +bool ResolveAll(const FuncGraphManagerPtr &manager) { + if (manager == nullptr) { + MS_LOG(ERROR) << "func graph manager is null"; + return false; + } + + if (manager->roots().size() > 1) { + MS_LOG(WARNING) + << "After calling ResolveAll, only one graph will be kept in the GraphManager. ResolveAll can resolve graphs " + "called from the root graph, so it's not necessary to pass all graphs as roots. " + "Please check your usage."; + } + // We should not use pipeline::Resource here, as Resource::Clean will clean some + // global variables such as the ScopeManager, which would cause JExpandedGraphs::GetBprop + // to fail because the valid scope has been cleaned. + auto res = std::make_shared(); + res->set_manager(manager); + + auto roots = manager->roots(); + for (auto &fg : roots) { + bool ret = ResolveFuncGraph(fg, res, false); + if (!ret) { + MS_EXCEPTION_IF_NULL(fg); + MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed"; + } + } + return true; +} +} // namespace parse +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h new file mode 100644 index 0000000000..d924f1ef44 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h @@ -0,0 +1,158 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_PARSE_RESOLVE_H_ +#define PIPELINE_PARSE_RESOLVE_H_ + +#include +#include +#include "ir/anf.h" +#include "ir/manager.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/parse_base.h" +#include "abstract/abstract_value.h" +#include "utils/log_adapter.h" + +// forward declaration of ResourceBase +namespace mindspore { +namespace pipeline { +class ResourceBase; +using ResourceBasePtr = std::shared_ptr; +} // namespace pipeline +} // namespace mindspore + +namespace mindspore { +namespace parse { + +// NameSpace class for resolving python code.
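+ // Together with Symbol and SymbolResolver below, it drives symbol resolution; a rough illustrative + // sketch (mod_obj and node are placeholders, not names from this header): + //   NameSpacePtr ns = std::make_shared<NameSpace>("__main__", mod_obj); + //   SymbolPtr sym = std::make_shared<Symbol>("my_func"); + //   SymbolResolver resolver(ns, sym, node); + //   if (resolver.Resolve()) { py::object obj = resolver.result(); }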
+class NameSpace : public Named { + public: + NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} + ~NameSpace() override = default; + MS_DECLARE_PARENT(NameSpace, Named); + + py::object obj() { return obj_; } + std::string module() { return module_; } + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), std::make_shared()); + } + + private: + // namespace of the module + std::string module_; + // namespace object + py::object obj_; +}; +using NameSpacePtr = std::shared_ptr; + +// Symbol in NameSpace or Class which shall be resolved. +class Symbol : public Named { + public: + explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} + explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} + + ~Symbol() override = default; + MS_DECLARE_PARENT(Symbol, Named); + + std::string symbol() { return symbol_; } + abstract::AbstractBasePtr ToAbstract() override { + return std::make_shared(shared_from_base(), std::make_shared()); + } + + private: + std::string symbol_; +}; +using SymbolPtr = std::shared_ptr; + +// PyObjectWrapper class wraps a resolved python object for further processing. +class PyObjectWrapper : public Named { + public: + explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + ~PyObjectWrapper() override = default; + MS_DECLARE_PARENT(PyObjectWrapper, Named); + py::object obj() { return obj_; } + + private: + // the object that needs to be resolved + py::object obj_; +}; + +// ClassObject class wraps a python dataclass +class ClassObject : public PyObjectWrapper { + public: + explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") + : PyObjectWrapper(obj, name) {} + ~ClassObject() override = default; + MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); + abstract::AbstractBasePtr ToAbstract() override; +}; + +// ClassType class wraps a python class type +class ClassType : public PyObjectWrapper { + public: + explicit ClassType(const py::object &obj, const std::string name = "Python class type") + : PyObjectWrapper(obj, name) {} + ~ClassType() override = default; + MS_DECLARE_PARENT(ClassType, PyObjectWrapper); + abstract::AbstractBasePtr ToAbstract() override; +}; + +// SymbolResolver class for resolving a symbol extracted from an AnfNode. +class SymbolResolver { + public: + SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) + : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} + + ~SymbolResolver() = default; + + // resolve the symbol in the namespace and save it in result_ + bool Resolve(); + + NameSpacePtr get_namespace() { return namespace_; } + + SymbolPtr symbol() { return symbol_; } + + py::object &result() { return result_; } + + AnfNodePtr resolved_node() { return resolved_node_; } + + // Resolve result + py::object result_; + + private: + // namespace where the symbol is located + NameSpacePtr namespace_; + // the symbol that needs to be resolved + SymbolPtr symbol_; + // the node that has been resolved + AnfNodePtr resolved_node_; +}; +using SymbolResolverPtr = std::shared_ptr; +// Resolve symbol in namespace. +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node); + +// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager().
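+ // Illustrative call site, assuming fg is a FuncGraphPtr and Manage() from ir/manager.h builds its manager: + //   pipeline::ResourceBasePtr res = std::make_shared<pipeline::ResourceBase>(); + //   res->set_manager(Manage(fg)); + //   (void)ResolveFuncGraph(fg, res, /*use_profile=*/false);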
+bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); + +// Resolve all graphs in manager which is defined outside of pipeline::Resource. +// Mainly used for test cases or resolve graphs which will not be managed by manager. +bool ResolveAll(const FuncGraphManagerPtr &manager); + +} // namespace parse +} // namespace mindspore + +#endif // PIPELINE_PARSE_RESOLVE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc new file mode 100644 index 0000000000..bb9a517556 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -0,0 +1,340 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/pass.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "ir/func_graph_cloner.h" +#include "debug/anf_ir_utils.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/validator.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/cse.h" +#include "frontend/optimizer/graph_kernel_reuse.h" +#include "frontend/optimizer/clean.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/control_depend.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/step_auto_parallel.h" +#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h" +#include "utils/any.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace pipeline { +using OptPassGroupMap = opt::OptPassGroupMap; +using Optimizer = opt::Optimizer; +using CompileGraphs = compile::CompileGraphs; +using abstract::AnalysisResult; +using mindspore::abstract::AnalysisContextPtr; +using mindspore::validator::Validate; + +bool SimplifyDataStructuresPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + + FuncGraphPtr func_graph = res->func_graph(); + bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); + + abstract::AbstractBasePtrList args_spec; + auto parameters = func_graph->parameters(); + (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); + if (changed) { + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + } + res->set_args_spec(args_spec); + return true; +} + +namespace { +OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig a_1 = opt::OptPassConfig({ + irpass.switch_simplify_, + + // Safe inlining + irpass.inline_, + irpass.partial_eliminate_, + irpass.replace_applicator_, + + // Specialization + irpass.specialize_transform_, + + // Miscellaneous + irpass.item_tuple_eliminate_, + irpass.env_get_item_eliminate_, + irpass.cast_eliminate_, + irpass.reshape_eliminate_, + irpass.reduce_eliminate_, + 
irpass.tile_eliminate_, + irpass.transpose_eliminate_, + irpass.minmaximum_grad_, + irpass.get_make_ref_eliminate_, + + // Arithmetic simplifications + irpass.arithmetic_simplify_, + irpass.addn_zero_filter_, + irpass.adjust_all_reduce_mul_add_, + + // Safe inlining + irpass.inline_, + }); + opt::OptPassConfig a_2 = opt::OptPassConfig({ + irpass.merge_addn_, + irpass.float_tuple_getitem_switch_, + irpass.float_env_getitem_switch_, + irpass.incorporate_getitem_set_, + irpass.incorporate_call_, + irpass.incorporate_call_switch_, + irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_switch_, + irpass.new_env_get_item_, + irpass.depend_value_elim_, + }); + opt::OptPassConfig a_3 = opt::OptPassConfig({ + irpass.arithmetic_simplify2_, + irpass.same_eliminate_, + irpass.check_bprop_eliminate_, + irpass.replace_applicator_, + }); + opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); + opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); + opt::irpass::ResolveIRPassLib resolve_irpass; + + opt::OptPassConfig resolve_pass = + opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, + irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); + + OptPassGroupMap map_a({{"a_1", a_1}, + {"a_2", a_2}, + {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, + {"parallel", opt::OptPassConfig(parallel::StepParallel)}, + {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, + {"virtual_dataset", virtual_dataset}, + {"grad", grad}, + {"resolve", resolve_pass}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + {"a_3", a_3}}); + + return map_a; +} + +OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig b_1 = opt::OptPassConfig({ + irpass.zero_like_fill_zero_, + irpass.item_tuple_eliminate_, + irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, + irpass.inline_, + irpass.special_op_eliminate_, + irpass.get_make_ref_eliminate_, + }); + opt::OptPassConfig b_2 = opt::OptPassConfig({ + irpass.replace_refkey_by_param_, + irpass.make_ref_eliminate_, + irpass.get_ref_param_eliminate_, + irpass.indexed_slices_eliminate_, + }); + OptPassGroupMap map({ + {"b_1", b_1}, + {"b_2", b_2}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + }); + return map; +} + +OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig interface_fusion = opt::OptPassConfig({ + irpass.mark_interface_fusion_, + }); + OptPassGroupMap map({ + {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, + {"interface_fusion", interface_fusion}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"cse", opt::OptPassConfig(opt::CSE(false))}, + }); + return map; +} + +OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig elim_1 = opt::OptPassConfig({ + irpass.addn_eliminate_, + irpass.incorporate_getitem_from_param_, + }); + opt::OptPassConfig elim_2 = opt::OptPassConfig({ + irpass.unused_parameter_eliminate_, + irpass.unused_output_eliminate_, + }); + OptPassGroupMap map({ + {"elim_1", elim_1}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + {"elim_2", elim_2}, + }); + return map; +} + +OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { + return OptPassGroupMap({{"renormalize", 
opt::OptPassConfig::Renormalize()}}); +} + +OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); + OptPassGroupMap map({ + {"control_group", control_group}, + {"renormalize", opt::OptPassConfig::Renormalize()}, + }); + return map; +} + +OptPassGroupMap GetInferenceOptPreparePhases() { + opt::irpass::InferenceOptPrepareLib irpass; + auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); + opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); + return prepare_map; +} + +OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { + opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); + OptPassGroupMap map({{"prepare_group", prepare_group}}); + return map; +} + +static std::unordered_map> g_pass_opts = {}; + +void InitOpt(const ResourcePtr &res) { + if (g_pass_opts.size() == 0) { + opt::irpass::OptimizeIRPassLib irpass; + g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); + g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); + g_pass_opts["opt_graph_kernel_a"] = + Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); + g_pass_opts["opt_graph_kernel_b"] = + Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); + g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); + g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); + g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (!(context_ptr->enable_graph_kernel())) { + g_pass_opts["opt_graph_kernel_a"]->set_enable(false); + g_pass_opts["opt_graph_kernel_b"]->set_enable(false); + } + } +} +} // namespace + +void ReclaimOptimizer() { + for (auto &opt : g_pass_opts) { + opt.second = nullptr; + } + g_pass_opts.clear(); +} + +bool OptPassGroup(const ResourcePtr &res, const std::string &name) { + if (res->func_graph() == nullptr) { + MS_LOG(ERROR) << "Opt passes init error: func_graph is null"; + return false; + } + + FuncGraphPtr func_graph = res->func_graph(); + MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", " + << func_graph->get_return()->DebugString(true); + InitOpt(res); + if (g_pass_opts.find(name) != g_pass_opts.end()) { + res->set_func_graph(g_pass_opts[name]->step(func_graph)); + } + // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to + // res->args_spec_ yet. So if any later pass or action wants to use that variable, it should be set here.
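+ // Such an update would roughly mirror the transform in SimplifyDataStructuresPass above (sketch only): + //   abstract::AbstractBasePtrList updated_spec; + //   for (const auto &param : res->func_graph()->parameters()) updated_spec.push_back(param->abstract()); + //   res->set_args_spec(updated_spec);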
+ return true; +} + +bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } +bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } +bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } +bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } +bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } + +bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } + +bool AddControlDependPass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + + if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { + opt::AddControlDepend(func_graph); + } + for (auto fg : func_graph->func_graphs_used_total()) { + MS_EXCEPTION_IF_NULL(fg); + if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { + opt::AddControlDepend(fg); + } + } + return true; +} + +bool CconvPass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + FuncGraphPtr new_fg = LiftingClone(func_graph); + res->set_func_graph(new_fg); + return true; +} + +bool ValidatePass(const ResourcePtr &res) { + MS_EXCEPTION_IF_NULL(res->func_graph()); + FuncGraphPtr func_graph = res->func_graph(); + Validate(func_graph); + return true; +} + +bool InferenceOptPreparePass(const ResourcePtr &res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto prepare_map = GetInferenceOptPreparePhases(); + auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); + (void)infer_opt_prepare->step(func_graph, false); + return true; +} + +std::vector kVmPasses = {{"opt_a", OptPassAGroup}, + {"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_b", OptPassBGroup}, + {"cconv", CconvPass}, + {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, + {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, + {"add_control_depend", AddControlDependPass}}; + +std::vector kGePasses = { + {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, + {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, + {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, + {"cconv", CconvPass}}; + +std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h new file mode 100644 index 0000000000..0233b6cf26 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PASS_H_ +#define MINDSPORE_CCSRC_PIPELINE_PASS_H_ + +#include +#include +#include +#include +#include "pipeline/jit/resource.h" + +namespace mindspore { +namespace pipeline { +using PassItem = std::pair>; + +extern std::vector kGePasses; +extern std::vector kVmPasses; +extern std::vector kPynativePasses; + +bool CconvPass(const ResourcePtr &res); +bool ValidatePass(const ResourcePtr &res); +bool ConvertPrepareAdapt(const ResourcePtr &res); +bool AddControlDependPass(const ResourcePtr &res); +bool InferenceOptPreparePass(const ResourcePtr &res); +void ReclaimOptimizer(); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PASS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc new file mode 100644 index 0000000000..05699793ff --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -0,0 +1,948 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/pipeline.h" + +#include +#include +#include +#include +#include + +#include "ir/param_value.h" +#include "pipeline/jit/pass.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "debug/anf_ir_dump.h" +#include "debug/anf_ir_utils.h" +#include "utils/config_manager.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" +#include "vm/segment_runner.h" +#include "frontend/parallel/context.h" +#include "frontend/parallel/graph_util/get_parallel_info.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "debug/trace.h" +#include "pipeline/pynative/pynative_execute.h" +#include "frontend/optimizer/py_pass_manager.h" + +#if (ENABLE_GE || ENABLE_D) +#include "pipeline/jit/pipeline_ge.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/df_graph_manager.h" +#endif + +namespace mindspore { +// namespace to support intermediate representation definition +namespace pipeline { +using Tensor = mindspore::tensor::Tensor; +using MetaTensor = mindspore::tensor::MetaTensor; +using TensorOrderMap = std::map>; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTensorPtr; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +const char IR_TYPE_ANF[] = "anf_ir"; +const char IR_TYPE_ONNX[] = "onnx_ir"; +const char IR_TYPE_BINARY[] = "binary_ir"; + +ExecutorPyPtr ExecutorPy::executor_ = nullptr; +std::mutex ExecutorPy::instance_lock_; + +std::unordered_map + g_args_cache; + +namespace { +std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { + std::ostringstream oss; + auto ms_context = MsContext::GetInstance(); + if (ms_context == nullptr) { + MS_LOG(EXCEPTION) << "ms_context is nullptr"; + } + auto save_graphs_path = ms_context->save_graphs_path(); + if 
(save_graphs_path.empty()) { + save_graphs_path = "."; + } + oss << save_graphs_path << "/" << stage_idx << "_" << action_name; + return oss.str(); +} +} // namespace + +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { + MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); + abstract::AbstractBasePtrList args_spec; + + for (auto arg : defaults) { + if (py::isinstance(arg.second)) { + MS_LOG(EXCEPTION) << "GenerateKey failed, argument input should not be py::module"; + } + ValuePtr converted = nullptr; + if (!parse::ConvertData(arg.second, &converted)) { + MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; + } + args_spec.push_back(abstract::FromValue(converted, true)); + } + if (g_args_cache.count(args_spec) == 0) { + static int key = 0; + MS_LOG(INFO) << "Start new args and compile key:" << key; + g_args_cache[args_spec] = key++; + } + auto argSpec = py::tuple(2); + argSpec[0] = name; + argSpec[1] = g_args_cache[args_spec]; + return argSpec; +} + +py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs) { + MS_LOG(DEBUG) << "Verify args size:" << inputs.size(); + if (inputs.size() != input_signature.size()) { + MS_LOG(ERROR) << "Signature size not equal to args size"; + return false; + } + + size_t count = 0; + for (auto arg_obj : inputs) { + if (py::hasattr(arg_obj, PYTHON_TENSOR_FLAG)) { + MS_LOG(DEBUG) << "Verify Tensor"; + std::shared_ptr m_tensor = arg_obj.cast>(); + if (m_tensor == nullptr) { + MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; + return false; + } + std::shared_ptr sig = input_signature[count].cast>(); + std::vector sig_shape = sig->shape(); + TypePtr sig_type = sig->Dtype(); + + std::vector tensor_shape = m_tensor->shape_c(); + if (tensor_shape != sig_shape) { + MS_LOG(ERROR) << "Python input shape is incompatible with input_signature"; + return false; + } + + if (*m_tensor->Dtype() != *sig_type) { + MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature(" + << sig_type->ToString() << ")"; + return false; + } + } + count++; + } + + return true; +} + +ExecutorPy::ExecutorPy() {} + +ResourcePtr ExecutorPy::GetResource(const std::string &phase) { + MS_LOG(DEBUG) << "Phase size:" << info_.size(); + if (info_.count(phase) == 0) { + return nullptr; + } + return info_[phase]->resource; +} + +FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { + if (info_.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + return info_[phase]->func_graph; +} + +std::size_t ExecutorPy::ArgListSize(const std::string &phase) { + if (info_.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + return info_[phase]->arg_list_size; +} + +compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { + ResourcePtr res = GetResource(phase); + MS_EXCEPTION_IF_NULL(res); + if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { + return res->results()[kOutput].cast(); + } + MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput; + return nullptr; +} + +bool ExecutorPy::HasCompiled(const std::string &phase) const { + if (info_.count(phase) == 0) { + return false; + } + return true; +} + +py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { + FuncGraphPtr fg_ptr = GetFuncGraph(phase); + if (fg_ptr == nullptr) { + for (auto &item : info_) { + 
MS_LOG(DEBUG) << "Phase key is: " << item.first; + } + MS_LOG(EXCEPTION) << "Can not find func graph " << phase; + } + + if (ir_type == IR_TYPE_ANF) { + std::string proto_str = GetFuncGraphProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + if (ir_type == IR_TYPE_ONNX) { + std::string proto_str = GetOnnxProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + if (ir_type == IR_TYPE_BINARY) { + std::string proto_str = GetBinaryProtoString(fg_ptr); + if (proto_str.empty()) { + MS_LOG(EXCEPTION) << "Graph proto is empty."; + } + return proto_str; + } + + MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; +} + +py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { + MS_LOG(DEBUG) << "GetParameterLayout!"; + std::string layout_graph = phase + kStepParallelGraph; + auto graph = GetFuncGraph(layout_graph); + return mindspore::parallel::GetParameterLayout(graph); +} + +py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { + MS_LOG(DEBUG) << "GetCNodeStrategy!"; + std::string layout_graph = phase + kStepParallelGraph; + auto graph = GetFuncGraph(layout_graph); + return mindspore::parallel::GetCNodeStrategy(graph); +} + +py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { + MS_LOG(INFO) << "GetAllreduceFusion!"; + auto graph = GetFuncGraph(phase); + return mindspore::parallel::GetAllreduceFusion(graph); +} + +void ExecutorPy::DelNetRes(const std::string &id) { +#ifdef ENABLE_GE + FinalizeBackend(); +#endif + if (executor_ != nullptr) { + bool flag = false; + auto tmp_info = info_; + for (auto &item : tmp_info) { + if (item.first.find(id) != string::npos) { + MS_LOG(DEBUG) << "Delete network res:" << item.first; + (void)info_.erase(item.first); + flag = true; + } + } + + MS_LOG(DEBUG) << "Delete flag:" << flag; +#ifdef ENABLE_GE + if (flag && info_.size() == 0) { + // because GE only supports one Session existing at a time, we delete the old one + transform::DfGraphManager::GetInstance().DeleteGraphRunner(); + transform::DfGraphManager::GetInstance().EraseAnfGraph(); + transform::DfGraphManager::GetInstance().DeleteGeSession(); + } +#endif + } +} + +void ExecutorPy::ClearRes() { + MS_LOG(INFO) << "Clean executor resource!"; + executor_ = nullptr; +} + +ExecutorPy::~ExecutorPy() { + MS_LOG(INFO) << "Release Executor!"; + ConfigManager::GetInstance().ResetConfig(); +} + +std::map> ExecutorPy::FetchInfoForQuantExport( + const std::string &phase_s) { + FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; + std::map> fake_quant_table; + auto filter = [](AnfNodePtr node) { + return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || + IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative)); + }; + std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); + auto is_quant_cnode = [](AnfNodePtr node) { + return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || + IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); + }; + for (auto node : nodes) { + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() != 3) { + continue; + } + auto x = cnode->input(1); + auto weight = cnode->input(2); + if (!is_quant_cnode(weight)) { + continue; + } + // get
parameter weight's name + cnode = weight->cast(); + auto weight_node = cnode->input(2); + if (!weight_node->isa()) { + continue; + } + auto weight_name = weight_node->cast()->name(); + // find the fakequant from input + int count = 0; + const int max_depth = 5; + while (!is_quant_cnode(x)) { + if (count >= max_depth) { + break; + } + cnode = x->cast(); + if (cnode == nullptr || cnode->size() <= 1) { + break; + } + x = cnode->input(1); + count += 1; + } + if (x->isa()) { + fake_quant_table[weight_name] = std::make_pair(nullptr, "input"); + } + // get the fakequant parameter minq's name + if (!is_quant_cnode(x)) { + continue; + } + cnode = x->cast(); + if (cnode == nullptr || cnode->size() != 4) { + continue; + } + auto fakequant_min_node = cnode->input(2); + if (!fakequant_min_node->isa()) { + continue; + } + auto fakequant_min_node_name = fakequant_min_node->cast()->name(); + auto quant_op_value = cnode->input(0)->cast()->value(); + if (!quant_op_value->isa()) { + continue; + } + auto quant_op = quant_op_value->cast(); + fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); + } + + return fake_quant_table; +} + +void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { + // save the graph to ExecutorPy + FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); + std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); + + MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; + info_[phase_s]->func_graph = func_graph; + if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) && + ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) { + MS_LOG(DEBUG) << "Save model parallel parameter layout graph!"; + func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast(); + ExecutorInfoPtr executor_info = std::make_shared(); + std::string layout_graph = phase_s + kStepParallelGraph; + executor_info->func_graph = func_graph; + info_[layout_graph] = executor_info; + } else { + MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!"; + } + MS_LOG(INFO) << "End save compiled func graph!"; +} + +bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { + std::string phase_prefix = GetPhasePrefix(phase_s); + + if (use_vm && phase_prefix == "export") { + MS_LOG(INFO) << "Use ge backend to export geir"; + use_vm = false; + } + return use_vm; +} + +void ExecutorPy::GetGeBackendPolicy() const { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + std::string backend = ms_context->backend_policy(); + if (backend != "ge") { + MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!"; + } +} + +bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { + MS_LOG(DEBUG) << "Start ExecutorPy compile!"; + if ((!py::isinstance(phase))) { + MS_LOG(ERROR) << "Arg phase must be string."; + return false; + } + // check the arg valid? 
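+ // (a None obj here usually means the Python side handed over an uninitialized network or + // function, so we fail fast before building a Resource for it)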
+ if (py::isinstance(obj)) { + MS_LOG(ERROR) << "Find error: parse obj is None."; + return false; + } +#ifdef ENABLE_GE + GetGeBackendPolicy(); +#endif + ExecutorInfoPtr executor_info = std::make_shared(); + std::string phase_s = py::cast(phase); + MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; + ResourcePtr resource = std::make_shared(obj); + std::vector p_actions; + + use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s); + + std::string backend = MsContext::GetInstance()->backend_policy(); + if (use_vm && backend != "ge") { + // Create backend and session + auto backend_ptr = compile::CreateBackend(); + // Connect session to debugger + backend_ptr->SetDebugger(); + resource->results()[kBackend] = backend_ptr; + p_actions = VmPipeline(); + } else { + p_actions = GePipeline(); + } + + std::shared_ptr pip = std::make_shared(resource, FilterActions(p_actions, phase_s)); + + // get the parameters items and add the value to args_spec + abstract::AbstractBasePtrList args_spec; + std::size_t size = args.size(); + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + bool broaden = true; + args_spec.push_back(abstract::FromValue(converted, broaden)); + } + + resource->set_args_spec(args_spec); + executor_info->arg_list_size = size; + executor_info->resource = resource; + info_[phase_s] = executor_info; + pip->Run(); + + // save the run graph func to MsPipeLine + SaveCompiledGraph(phase_s); + + resource->Clean(); + // Reclaim all resource used by optimizer; + ReclaimOptimizer(); + + MS_LOG(INFO) << "End ExecutorPy compile!"; + return true; +} + +std::vector ExecutorPy::FilterActions(const std::vector &actions, const std::string &phase) { + // phase does not contain 'export_onnx' + if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) { + return actions; + } + MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'"; + std::vector filtered_actions; + for (const auto &item : actions) { + filtered_actions.emplace_back(item); + if (item.first == "validate") { + break; + } + } + return filtered_actions; +} + +void ExecutorPy::ReleaseResource(const py::object &phase) { + ResourcePtr res = GetResource(py::cast(phase)); + if (res != nullptr) { + res->Clean(); + } + // Reclaim all resource used by optimizer; + ReclaimOptimizer(); +} + +static std::string PrintArgs(const py::tuple &args) { + py::print(args); + return ""; +} + +bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { + bool ret_value = false; + + try { + MS_LOG(DEBUG) << PrintArgs(args); + ret_value = CompileInner(obj, args, phase, use_vm); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + ReleaseResource(phase); + + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + ReleaseResource(phase); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + ReleaseResource(phase); + throw py::value_error(ex); + } catch (const 
py::index_error &ex) { + ReleaseResource(phase); + throw py::index_error(ex); + } catch (const std::exception &ex) { + ReleaseResource(phase); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + ReleaseResource(phase); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; + } + + return ret_value; +} + +#ifdef ENABLE_LOAD_ANF_IR +// get MindSpore Intermediate Representation File +std::string GetMsIrFile(void) { + std::string file; + const char *path = getenv("MS_IR_FILE"); + if (path == nullptr) { + return file; + } + + char real_path[PATH_MAX] = {0}; + if (realpath(path, real_path) == nullptr) { + MS_LOG(ERROR) << "MS IR path error, " << path; + return file; + } + file = real_path; + return file; +} + +void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { + MS_EXCEPTION_IF_NULL(resource); + MS_EXCEPTION_IF_NULL(result); + + std::string ir_file = GetMsIrFile(); + (void)parse::python_adapter::set_python_scoped(); + if (ir_file.empty()) { + *result = action.second(resource); + return; + } + + // when in loading anf ir mode, action `parse` do nothing + if (action.first == "parse") { + return; + } + + // load MindSpore IR from file + if (action.first == "symbol_resolve") { + MS_LOG(DEBUG) << action.first << " read ir file: " << ir_file; + std::vector graphs = ImportIR(ir_file); + if (graphs.size() == 0) { + MS_LOG(EXCEPTION) << action.first << " read ir file " << ir_file << " failed as no graph found"; + } + auto manager = resource->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (auto &graph : graphs) { + manager->AddFuncGraph(graph); + } + resource->set_func_graph(graphs[0]); + return; + } + + // do normal action when not in `parse` and `symbol_resolve` stage + *result = action.second(resource); +} +#endif + +void Pipeline::Run() { + MS_LOG(INFO) << "Pipeline run"; + MS_EXCEPTION_IF_NULL(resource_); + FuncGraphPtr user_graph = nullptr; + + WITH(MsProfile::GetProfile())[&user_graph, this]() { + int i = 0; + for (auto &action : actions_) { +#ifdef ENABLE_TIMELINE + DumpTime &dump_time = DumpTime::GetInstance(); + dump_time.Record(action.first, GetTime(), true); +#endif + bool result = true; + WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() { + MS_LOG(DEBUG) << "Action " << action.first << " start ..."; +#ifdef ENABLE_LOAD_ANF_IR + RunPipelineAction(action, resource_, &result); +#else + result = action.second(resource_); +#endif + MS_LOG(DEBUG) << "Action " << action.first << " end."; + }; + if (!result) { + MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; + } + if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { + auto graph = resource_->func_graph(); + if (graph != nullptr) { + user_graph = graph; + std::string base_name = GetBaseNameForIR(i, action.first); + + // generate IR file in dot format, which can be converted to svg file using graphviz dot command + draw::Draw(base_name + ".dot", graph); + // generate IR file in human readable format + DumpIR(base_name + ".ir", graph); + // generate IR file in a heavily commented format, which can also be reloaded + ExportIR(base_name + ".dat", std::to_string(i), graph); + } +#ifdef MS_DEBUG + // Dump graph cnode list + MS_LOG(INFO) << "Show CNode list after " << action.first; + graph->DumpCNodeList(); +#endif + } + if 
(resource_->func_graph() != nullptr) { + auto func_graph = resource_->func_graph(); + if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) { + func_graph->EraseUnusedNodeInOrder(); + func_graph->CheckOrder(); + for (auto fg : func_graph->func_graphs_used_total()) { + MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << "."; + fg->EraseUnusedNodeInOrder(); + fg->CheckOrder(); + } + } + } + i++; +#ifdef ENABLE_TIMELINE + dump_time.Record(action.first, GetTime(), false); +#endif + } + }; +#ifdef ENABLE_PROFILE + MsProfile::Print(); + MsProfile::Reset(); +#endif + + if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { + std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); + MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; + draw::DrawUserFuncGraph(user_graph_file, user_graph); + } + MS_LOG(INFO) << "End"; +} + +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { + std::size_t size = args.size(); + + for (std::size_t i = 0; i < size; i++) { + py::object arg = args[i]; + auto ms_context = MsContext::GetInstance(); + if (ms_context->backend_policy() == kMsConvert && py::isinstance<py::array>(arg)) { + MS_LOG(EXCEPTION) << "The " << i << "th arg is a numpy array, not a tensor."; + } + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(arg, &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; + } + if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa<tensor::Tensor>()) { + MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() + << " is not a tensor."; + } + arg_list->push_back(converted); + } + + MS_EXCEPTION_IF_NULL(res); + auto graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + std::vector<AnfNodePtr> graph_params = graph->parameters(); + std::size_t graph_params_size = graph_params.size(); + if ((*arg_list).size() != graph_params_size) { + // Maybe some parameters have default values + for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { + MS_EXCEPTION_IF_NULL(graph_params[i]); + auto param_ptr = (graph_params[i])->cast<ParameterPtr>(); + if (!param_ptr->has_default()) { + MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; + } + arg_list->push_back(param_ptr->default_param()->value()); + } + } +} + +void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) { + ProcessVmArgInner(args, GetResource(phase), arg_list); +} + +py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { + std::size_t size = args.size(); + if (!py::isinstance<py::str>(phase)) { + MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; + } + auto phase_s = py::cast<std::string>(phase); + std::string backend = MsContext::GetInstance()->backend_policy(); +#ifdef ENABLE_GE + if (backend == "ge") { + return ExecDFGraph(info_, args, phase_s); + } +#else + if (backend == "ms" || backend == "ge") { + auto ret_val = std::make_shared<py::object>(); + if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { + if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { + return *ret_val; + } + } + if (backend == "ge") { + if (args.size() > 0) { + return args[0]; + } + return args; + } + } +#endif + std::size_t full_arg_size = ArgListSize(phase_s); + if (size > full_arg_size) { + MS_LOG(WARNING) << "The arg num: size = " << size << ", 
full_arg_size = " << full_arg_size; + } + VectorRef arg_list; + ProcessVmArg(args, phase_s, &arg_list); + + compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); + if (run == nullptr) { + MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; + } + + MS_LOG(DEBUG) << "Eval run: " << backend; + BaseRef value = (*run)(arg_list); + MS_LOG(DEBUG) << "Run end"; + return BaseRefToPyData(value); +} + +FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params) { +#if (ENABLE_GE || ENABLE_D) + return BuildDFGraph(info_, init_params, phase, broadcast_params); +#else + return nullptr; +#endif +} + +void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { +#if ENABLE_GE + RunGEInitGraph(init_params, phase); +#endif +} + +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, + const std::vector<int64_t> &input_indexes, const std::string &phase, bool need_run) { + std::string name = MsContext::GetInstance()->backend_policy(); +#ifndef NO_DLIB + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { + (void)InitBackend(); + } +#endif + if (name == kMsConvert || name == kMsVm) { + return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); + } +#if ENABLE_GE + return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); +#else + std::string backend = MsContext::GetInstance()->backend_policy(); + if (backend == "ge") { + return true; + } +#endif + return false; +} + +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, + const std::vector<int64_t> &input_indexes, bool need_run) { + MS_LOG(INFO) << "Start InitDataSet Entry"; + std::vector<int> int_input_indexes; + (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), + [](int64_t item) { return static_cast<int>(item); }); + std::vector<std::vector<int>> int_shapes; + (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), + [](const std::vector<int64_t> &item) { + std::vector<int> vector_item; + (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), + [](int64_t inner_item) { return static_cast<int>(inner_item); }); + return vector_item; + }); + auto p_init = std::make_shared<Primitive>("InitDataSetQueue"); + p_init->set_attr("queue_name", MakeValue(queue_name)); + p_init->set_attr("size", MakeValue(static_cast<int>(size))); + p_init->set_attr("batch_size", MakeValue(static_cast<int>(batch_size))); + p_init->set_attr("types", MakeValue(types)); + p_init->set_attr("shapes", MakeValue(int_shapes)); + p_init->set_attr("input_indexes", MakeValue(int_input_indexes)); + + const std::vector<std::string> empty_str_list; + p_init->set_attr("input_names", MakeValue(empty_str_list)); + p_init->set_attr("output_names", MakeValue(empty_str_list)); + + FuncGraphPtr func_graph = std::make_shared<FuncGraph>(); + auto app_init = std::make_shared<CNode>(AnfNodePtrList{NewValueNode(p_init)}, func_graph); + func_graph->set_output(app_init); + auto manager = MakeManager(); + manager->AddFuncGraph(func_graph); + + // AbstractNone indicates there is no output for this apply node. 
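+ // (Editor's note, illustrative: the func_graph built above consists of a single InitDataSetQueue apply node with no inputs; it is compiled and run once below purely for its side effect of setting up the named data queue, which is why its abstract value is None rather than a tensor.)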
+ auto abstract_none = std::make_shared<abstract::AbstractNone>(); + app_init->set_abstract(abstract_none); + + auto backend = compile::CreateBackend(); + MS_EXCEPTION_IF_NULL(backend); + auto convert_fn = backend->convert_fn(); + MS_EXCEPTION_IF_NULL(convert_fn); + // Convert the CNode list to a LinConvertResult. + ConfigManager::GetInstance().set_iter_num(1); + auto runner = convert_fn({app_init}, ""); + if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { + backend->Link(runner.graph_id); + } + ConfigManager::GetInstance().set_iter_num(size); + + if (!(*runner.run)) { + // The returned run function is empty + MS_LOG(EXCEPTION) << "Backend " << backend->name() << " does not support the TDT dataset."; + } + + // Launch the init dataset runner with no inputs and no outputs + VectorRef args; + auto fn = runner.run; + if (need_run) { + (void)(*fn)(args); + } + MS_LOG(DEBUG) << "InitDataSetVm End."; + return true; +} + +void ResetOpId() { mindspore::id_generator::reset_id(); } + +void InitHccl() { +#ifdef ENABLE_GE + (void)InitBackend(); +#else + mindspore::parse::python_adapter::set_python_env_flag(true); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + (void)ms_context->OpenTsd(); + uint32_t device_id = ms_context->device_id(); + std::string device_name = ms_context->device_target(); + ms_context->set_enable_hccl(true); + if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + if (!runtime_instance->Init()) { + MS_LOG(ERROR) << "Kernel runtime init error."; + return; + } + } +#endif +} + +void FinalizeHccl() { +#ifdef ENABLE_GE + (void)FinalizeBackend(); +#else + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); +#endif +} + +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { +#if (ENABLE_GE || ENABLE_D) + ExportDFGraph(file_name, phase); +#else + MS_LOG(WARNING) << "Export graph is only supported when GE or D is enabled; doing nothing (e.g. in UT builds)"; +#endif +} + +void ReleaseGeTsd() { + auto context_ptr = MsContext::GetInstance(); + if (context_ptr != nullptr) { + (void)context_ptr->FinalizeGe(true); + (void)context_ptr->CloseTsd(true); + } +} + +void InitBackend() { + // Set the python env flag + mindspore::parse::python_adapter::set_python_env_flag(true); + // Open tsd before GE initializes + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->OpenTsd()) { + MS_LOG(EXCEPTION) << "Open tsd failed"; + } + (void)ms_context->InitGe(); +} + +void FinalizeBackend() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + (void)context_ptr->FinalizeGe(); + (void)context_ptr->CloseTsd(); +} + +void ClearResAtexit() { + MS_LOG(DEBUG) << "Pipeline clear all resource"; + pynative::ClearPyNativeSession(); + session::ClearPythonParasMap(); + device::KernelRuntimeManager::Instance().ClearRuntimeResource(); + + ad::g_k_prims.clear(); + + abstract::ClearPrimEvaluatorMap(); + compile::ClearConvertCache(); + pipeline::GetMethodMap().clear(); + pipeline::ExecutorPy::ClearRes(); + pipeline::ReclaimOptimizer(); + pynative::PynativeExecutor::GetInstance()->ClearRes(); + opt::python_pass::PyPassManager::GetInstance()->ClearRes(); +#ifdef ENABLE_GE + transform::DfGraphManager::GetInstance().ClearGraph(); + transform::DfGraphConvertor::get_adpt_map().clear(); +#endif + ReleaseGeTsd(); + parse::python_adapter::ResetPythonScope(); +} +} // namespace pipeline +} // namespace 
mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h new file mode 100644 index 0000000000..705853d086 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline.h @@ -0,0 +1,148 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "utils/base_ref_extends.h" +#include "debug/draw.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "pipeline/jit/action.h" +#include "vm/segment_runner.h" +#include "vm/transform.h" +#include "pipeline/jit/base.h" + +namespace mindspore { +extern const char kMsConvert[]; +extern const char kMsVm[]; + +// namespace to support pipeline structures definition +namespace pipeline { + +namespace py = pybind11; + +class Pipeline { + public: + Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} + + ~Pipeline() = default; + + void Run(); + + ResourcePtr resource() { return resource_; } + + private: + ResourcePtr resource_; + std::vector actions_; +}; + +// A function pipeline. +class ExecutorPy : public std::enable_shared_from_this { + public: + static std::shared_ptr GetInstance() { + std::lock_guard i_lock(instance_lock_); + if (executor_ == nullptr) { + executor_ = std::shared_ptr(new (std::nothrow) ExecutorPy()); + } + return executor_; + } + + ~ExecutorPy(); + + void SaveCompiledGraph(const std::string &phase_s); + bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + + void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); + + // for pynative mode when use_vm is on + py::object Run(const py::tuple &args, const py::object &phase); + ResourcePtr GetResource(const std::string &phase); + FuncGraphPtr GetFuncGraph(const std::string &phase); + py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + std::size_t ArgListSize(const std::string &phase); + compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); + bool HasCompiled(const std::string &phase) const; + + FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params = {}); + void RunInitGraph(const py::dict &init_params, const std::string &phase); + py::dict GetParameterLayout(const std::string &phase); + py::dict GetCNodeStrategy(const std::string &phase); + py::dict GetAllreduceFusion(const std::string &phase); + void DelNetRes(const std::string &id); + void ReleaseResource(const py::object &phase); + static void ClearRes(); + + std::map> FetchInfoForQuantExport(const std::string &phase_s); + + private: + ExecutorPy(); + void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); + bool 
ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; + void GetGeBackendPolicy() const; + // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after + // 'validate' stage + static std::vector FilterActions(const std::vector &actions, const std::string &phase); + + std::map info_; + static std::shared_ptr executor_; + static std::mutex instance_lock_; +}; +using ExecutorPyPtr = std::shared_ptr; + +// Generate a key for mapping function graph +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); +py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); + +bool InitDistribute(const std::map &options); + +void ResetOpId(); +void InitHccl(); +void FinalizeHccl(); +void InitBackend(); +void FinalizeBackend(); + +void ClearResAtexit(); +void ReleaseGeTsd(); + +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); + +// init and exec dataset sub graph +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase, bool need_run); + +// Build and run dataset subgraph for ms backend +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, bool need_run); + +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); + +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc new file mode 100644 index 0000000000..e08af4f2dc --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -0,0 +1,535 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pipeline/jit/pipeline_ge.h" + +#include +#include +#include +#include +#include + +#include "debug/anf_ir_dump.h" +#include "ir/tensor.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "transform/graph_ir/graph_builder.h" +#include "transform/graph_ir/graph_runner.h" +#include "debug/draw.h" +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace pipeline { +using Tensor = mindspore::tensor::Tensor; +using MetaTensor = mindspore::tensor::MetaTensor; +using TensorOrderMap = std::map>; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; +using mindspore::transform::DfGraphConvertor; +using mindspore::transform::DfGraphManager; +using mindspore::transform::GeTensorPtr; +using mindspore::transform::MeTensorPtr; +using mindspore::transform::Status; +using mindspore::transform::TransformUtil; + +void DoExecNonInputGraph(const std::string &phase) { + std::vector ge_tensors; + std::vector ge_outputs; + transform::RunOptions run_options; + run_options.name = phase; + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Can not found GraphRunner"; + return; + } + + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; + return; + } + } +} + +void SetGeOption(const std::map &options) { + ConfigManager::GetInstance().set_ge_initialize_options(options); +} + +Status CreateSessionAndGraphRunner(bool is_training = true) { + std::shared_ptr sess = DfGraphManager::GetInstance().GetGeSession(); + if (sess == nullptr) { + transform::SessionOptions options; + if (is_training) { + options["ge.trainFlag"] = "1"; + options["ge.streamNum"] = "100"; + options["ge.enabledLocalFmkop"] = "1"; + options["ge.hcomParallel"] = "1"; + } else { + options["ge.trainFlag"] = "0"; + } + + options["ge.enablePrintOpPass"] = "0"; + sess = transform::GraphRunner::NewSession(options); + if (sess == nullptr) { + MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed"; + return Status::FAILED; + } else { + DfGraphManager::GetInstance().SetGeSession(sess); + } + } + + transform::GraphRunnerOptions options; + options.sess_ptr = sess; + auto graph_runner = std::make_shared(options); + if (graph_runner == nullptr) { + MS_LOG(ERROR) << "Create new graph runner failed"; + return Status::FAILED; + } else { + DfGraphManager::GetInstance().SetGraphRunner(graph_runner); + } + + return Status::SUCCESS; +} + +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { + std::vector ge_types; + (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { + return transform::TransformUtil::ConvertDataType(i->type_id()); + }); + + ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); + ConfigManager::GetInstance().set_iter_num(size); + ConfigManager::GetInstance().set_dataset_phase(phase); + + DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); + ConfigManager::GetInstance().set_dataset_param(param); + + if 
(transform::BuildDatasetGraph(param, phase) != transform::SUCCESS) { + MS_LOG(ERROR) << "Build dataset graph failed."; + return false; + } + +#if ENABLE_TRAIN + (void)setenv("GE_TRAIN", "1", 1); +#else + (void)setenv("GE_TRAIN", "0", 1); +#endif + + if (CreateSessionAndGraphRunner(static_cast<bool>(ENABLE_TRAIN)) != Status::SUCCESS) { + MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; + return false; + } + + MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; + DoExecNonInputGraph(phase); + + return true; +} + +void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { + for (auto item : dict) { + if ((!py::isinstance<py::str>(item.first))) { + MS_LOG(WARNING) << "Type of the key of py_dict is not string; ignoring it."; + continue; + } + std::shared_ptr<Tensor> tensor; + std::string name = py::cast<std::string>(item.first); + if (py::isinstance<py::float_>(item.second.attr("default_input"))) { + // Convert float to a tensor with shape([1]) + tensor = std::make_shared<Tensor>(kNumberTypeFloat32, std::vector<int>({1})); + *(static_cast<float *>(tensor->data_c())) = py::cast<float>(item.second.attr("default_input")); + } else if (py::isinstance<py::int_>(item.second.attr("default_input"))) { + // Convert int to a tensor with shape([1]) + tensor = std::make_shared<Tensor>(kNumberTypeInt32, std::vector<int>({1})); + *(static_cast<int32_t *>(tensor->data_c())) = py::cast<int32_t>(item.second.attr("default_input")); + } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { + // Cast to a tensor + tensor = py::cast<std::shared_ptr<Tensor>>(item.second.attr("default_input")); + } + + if (tensor == nullptr) { + MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; + } + (void)tensors->emplace(name, tensor); + } +} + +bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + DfGraphConvertor convertor(anf_graph); + + size_t pos = phase.find('.'); + std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); + std::string phase_prefix = phase.substr(0, pos); + if (phase_prefix == "export") { + MS_LOG(INFO) << "Set DfGraphConvertor training: false"; + convertor.set_training(false); + } + + TensorOrderMap init_tensors{}; + ConvertObjectToTensors(init_params, &init_tensors); + (void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); + + if (broadcast_params != py::none()) { + if (!py::isinstance<py::dict>(broadcast_params)) { + MS_LOG(ERROR) << "Invalid broadcast params; they must be of py::dict type"; + return false; + } + py::dict broadcast = broadcast_params.cast<py::dict>(); + if (broadcast.empty()) { + (void)convertor.GenerateBroadcastGraph(init_tensors); + } else { + TensorOrderMap broadcast_tensors{}; + ConvertObjectToTensors(broadcast, &broadcast_tensors); + (void)convertor.GenerateBroadcastGraph(broadcast_tensors); + } + MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); + } + + (void)convertor.GenerateCheckpointGraph(); + if (convertor.ErrCode() != 0) { + DfGraphManager::GetInstance().ClearGraph(); + MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); + return false; + } + + if (MsContext::GetInstance()->save_graphs_flag()) { + convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug + convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug + convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug + } + std::string init_graph = "init_subgraph." 
+ net_id; + std::string checkpoint_name = "save." + net_id; + if (phase.find("train") != std::string::npos) { + (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); + } else { + (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); + } + (void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); + (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); + + Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); + if (ret == Status::SUCCESS) { + DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); + } + + return true; +} + +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + + if (MsContext::GetInstance()->save_graphs_flag()) { + draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug + DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); + } + + if (!AddDFGraph(info, init_params, phase, broadcast_params)) { + MS_LOG(ERROR) << "GenConvertor failed"; + return nullptr; + } + +#if ENABLE_TRAIN + (void)setenv("GE_TRAIN", "1", 1); +#else + (void)setenv("GE_TRAIN", "0", 1); +#endif + + if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { + MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; + return nullptr; + } + + return anf_graph; +} + +void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { + MS_LOG(DEBUG) << "ExecInitGraph start."; + TensorOrderMap inputs_with_name{}; + ConvertObjectToTensors(init_params, &inputs_with_name); + std::vector inputs; + (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), + [](const std::pair &item) { return item.second; }); + + std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); + if (ge_tensors.size() != inputs.size()) { + MS_LOG(ERROR) << "Args convert to ge tensor error."; + return; + } + MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size() << "."; + + std::vector ge_outputs; + transform::RunOptions run_options; + + run_options.name = phase; + if (DfGraphManager::GetInstance().GetGraphByName(phase) == nullptr) { + MS_LOG(WARNING) << "Can not find " << phase << " sub graph, don't need data init subgraph in INFER mode."; + return; + } + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found GraphRunner."; + } + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec " << phase << " graph failed."; + } + + MS_LOG(INFO) << "Exec " << phase << " graph success."; + + if ((ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::DISTRIBUTION) && + (DfGraphManager::GetInstance().GetGraphByName(BROADCAST_GRAPH_NAME) != nullptr)) { + run_options.name = BROADCAST_GRAPH_NAME; + ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + if (ret != Status::SUCCESS) { + MS_LOG(EXCEPTION) << "Exec BROADCAST_GRAPH_NAME failed."; + } + 
MS_LOG(INFO) << "Exec broadcast graph success."; + } + } +} + +py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { + MS_EXCEPTION_IF_NULL(cnode_data); + + if (cnode_data->isa<AbstractTensor>()) { + if (*count >= data.size()) { + MS_LOG(EXCEPTION) << "The number of elements in the outputs: " << data.size() + << " is less than the number of elements required."; + } + + BaseShapePtr shape = cnode_data->BuildShape(); + if (!shape->isa<abstract::Shape>()) { + MS_LOG(EXCEPTION) << "The shape of the derived tensor is not a Shape; it is " << shape->ToString(); + } + auto shape_me = shape->cast<abstract::ShapePtr>()->shape(); + auto shape_ge = py::cast<Tensor>(data[*count]).shape(); + if (shape_ge != shape_me) { + MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge + << " is not the same as the shape of the derived tensor: " << shape_me; + } + + return data[(*count)++]; + } + + if (!cnode_data->isa<AbstractTuple>()) { + MS_LOG(EXCEPTION) << "The output of an operator in the final anf graph could " + << "only be a tensor or a tuple of tensors, but got " << cnode_data->BuildValue()->ToString() + << "."; + } + auto data_tp = cnode_data->cast<AbstractTuplePtr>(); + auto elements = data_tp->elements(); + size_t size = data_tp->size(); + auto tp = py::tuple(size); + for (size_t i = 0; i < size; i++) { + tp[i] = ExtractGeneralCnodeRet(elements[i], data, count); + } + return std::move(tp); +} + +py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { + MS_EXCEPTION_IF_NULL(output_node); + + if (output_node->isa<ValueNode>()) { + return ValuePtrToPyData(GetValueNode(output_node)); + } + + if (output_node->isa<Parameter>()) { + if (*count >= data.size()) { + MS_LOG(EXCEPTION) << "The number of elements in the outputs: " << data.size() + << " is less than the number of elements required. 
"; + } + return data[(*count)++]; + } + + auto output_c = output_node->cast(); + if (output_c == nullptr) { + MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got " + << output_node->ToString(); + } + + if (output_c->IsApply(prim::kPrimMakeTuple)) { + auto input_list = output_c->inputs(); + size_t size = input_list.size(); + auto tp = py::tuple(size - 1); + for (size_t i = 1; i < size; i++) { + tp[i - 1] = StructureOutput(input_list[i], data, count); + } + return std::move(tp); + } + if (output_c->IsApply(prim::kPrimDepend)) { + return StructureOutput(output_c->input(1), data, count); + } + + return ExtractGeneralCnodeRet(output_c->abstract(), data, count); +} + +std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, + const std::string &phase) { + std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); + if (ge_tensors.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; + } + + std::vector ge_outputs; + transform::RunOptions run_options; + run_options.name = phase; + auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); + if (graph_runner == nullptr) { + MS_LOG(EXCEPTION) << "Can not found GraphRunner."; + } + + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size(); + Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); + MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size(); + if (ret != Status::SUCCESS) { + MS_LOG(ERROR) << "Exec graph failed"; + return nullptr; + } + } + + std::vector me_outputs = TransformUtil::ConvertGeTensors(ge_outputs); + if (me_outputs.size() != ge_outputs.size()) { + MS_LOG(WARNING) << "Convert output Ge tensor to Me tensor failed"; + } + + py::tuple outputs(me_outputs.size()); + for (std::size_t i = 0; i < outputs.size(); i++) { + outputs[i] = *me_outputs[i]; + } + + std::shared_ptr ret = nullptr; + + AnfNodePtr output_node = graph->get_return()->input(1); + MS_EXCEPTION_IF_NULL(output_node); + size_t count = 0; + py::object oj = StructureOutput(output_node, outputs, &count); + ret = std::make_shared(oj); + + return ret; +} + +void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, + std::vector *inputs) { + // check the arg and use the ExecutorPy args + std::size_t size = args.size(); + + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); + } + + auto arg_size = info.at(phase)->arg_list_size; + if (size != arg_size) { + MS_LOG(EXCEPTION) << "The real arg num : size = " << size << ". 
graph_arg_size = " << arg_size; + } + + // process the first args of tensor + // only in dataset normal(non-sink) mode, fp_bp graph need input tensors + if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) { + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; + } + if (converted->isa()) { + inputs->push_back(converted->cast()); + } else { + MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; + } + } + } +} + +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase) { + std::string phase_prefix = GetPhasePrefix(phase); + if (phase_prefix == "save") { + DoExecNonInputGraph(phase); + ConfigManager::GetInstance().ResetConfig(); + return py::none(); + } + + if (info.count(phase) == 0) { + MS_LOG(EXCEPTION) << "There is no phase:" << phase; + } + FuncGraphPtr anf_graph = info.at(phase)->func_graph; + +#ifdef ENABLE_INFER + // Now don't use the graph because the exec ge function don't take effect + MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); + if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { + MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; + ConfigManager::GetInstance().ResetConfig(); + return py::none(); + } +#endif + + std::shared_ptr ret_val = std::make_shared(); + // We will not execute graph when output is constant or just input itself. + if (IsGraphOutputValueNodeOrParameter(info.at(phase)->func_graph->output(), args, ret_val)) { + ConfigManager::GetInstance().ResetConfig(); + return *ret_val; + } + + std::vector inputs; + ProcessGeArg(info, args, phase, &inputs); + + std::shared_ptr ret = DoExecGraph(anf_graph, inputs, phase); + ConfigManager::GetInstance().ResetConfig(); + if (ret != nullptr) { + return *ret; + } else { + MS_LOG(EXCEPTION) << "Exec graph failed"; + } +} +void ExportDFGraph(const std::string &file_name, const std::string &phase) { + MS_LOG(DEBUG) << "ExportGraph Begin"; + transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); + if (wrap_ptr == nullptr) { + MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; + return; + } + + transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_; + if (nullptr == ge_graph) { + MS_LOG(ERROR) << "The export graph is null"; + return; + } + + (void)ge_graph->SaveToFile(file_name); + + MS_LOG(DEBUG) << "ExportGraph End"; +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.h b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h new file mode 100644 index 0000000000..f834125231 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ +#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pipeline/jit/base.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace pipeline { +namespace py = pybind11; + +void SetGeOption(const std::map<std::string, std::string> &options); + +void RunGEInitGraph(const py::dict &init_params, const std::string &phase); + +py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::tuple &args, + const std::string &phase = "train"); + +FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params = {}); + +// Init and exec the dataset subgraph for the GE backend +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, + const std::vector<int64_t> &input_indexes, const std::string &phase); + +void ExportDFGraph(const std::string &file_name, const std::string &phase); +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc new file mode 100644 index 0000000000..e9467e4aeb --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/remove_value_node_dup.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "ir/manager.h" +#include "frontend/optimizer/cse.h" +#include "utils/log_adapter.h" +#include "utils/hashing.h" + +namespace mindspore { +namespace pipeline { +void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, + HashValue *const hash_value) { + const auto &to_check_value = GetValueNode(node); + MS_EXCEPTION_IF_NULL(to_check_value); + + // Calculate the hash value. + size_t h; + auto hash_iter = hash_value->find(node); + if (hash_iter == hash_value->end()) { + h = hash_combine(to_check_value->hash(), (opt::AbsOf(node)->hash())); + (*hash_value)[node] = h; + } else { + h = hash_iter->second; + } + + auto bucket_iter = hash_cache->find(h); + if (bucket_iter == hash_cache->end()) { + // Met for the first time; add a bucket. + (*hash_cache)[h] = {node}; + return; + } + + auto &bucket = bucket_iter->second; + // Check whether the node can be replaced with a value node already met. + for (const auto &v : bucket) { + // Already met and cached. 
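+ // (Editor's note: v == node means this exact value node was already bucketed. Distinct value nodes can also land in the same bucket through hash collisions, so the deep value comparison further below is still required before replacing.)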
+ if (v == node) { + return; + } + const auto &existed_value = GetValueNode(v); + MS_EXCEPTION_IF_NULL(existed_value); + auto equal = [&]() -> bool { + if (existed_value->isa() && to_check_value->isa()) { + return existed_value->cast()->ValueEqual(*(to_check_value->cast())); + } + return *existed_value == *to_check_value; + }; + if (equal()) { + (void)manager->Replace(node, v); + return; + } + } + + // Meet for the first time, append node to bucket. + bucket.emplace_back(node); +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h similarity index 100% rename from mindspore/ccsrc/pipeline/remove_value_node_dup.h rename to mindspore/ccsrc/pipeline/jit/remove_value_node_dup.h diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc new file mode 100644 index 0000000000..ece128b77b --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -0,0 +1,260 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "debug/draw.h" +#include "debug/trace.h" +#include "ir/dtype.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" +#include "utils/graph_utils.h" +#include "frontend/optimizer/ad/dfunctor.h" +#include "vm/segment_runner.h" + +namespace mindspore { +// namespace to support opmap definition +namespace pipeline { + +MethodMap &GetMethodMap() { + static MethodMap method_map = { + {kObjectTypeString, + { + {"__bool__", std::string("str_bool")} // C.str_bool + }}, + {kMetaTypeNone, + { + {"__bool__", std::string("none_bool")} // C.none_bool + }}, + {kNumberTypeBool, + { + {"__and__", prim::kPrimBoolAnd}, // P.bool_and + {"__or__", prim::kPrimBoolOr}, // P.bool_or + {"__eq__", prim::kPrimBoolEq}, // P.bool_eq + {"__ne__", std::string("bool_ne")}, // C.bool_ne + {"__bool__", prim::kPrimIdentity} // P.identity + }}, + {kNumberTypeInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul + {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow + {"__floor__", prim::kPrimIdentity}, // P.identity + {"__trunc__", prim::kPrimIdentity}, // P.identity + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt + {"__gt__", 
prim::kPrimScalarGt}, // P.scalar_gt + {"__le__", prim::kPrimScalarLe}, // P.scalar_le + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array + }}, + {kNumberTypeUInt, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__truediv__", std::string("int_truediv")}, // C.int_truediv + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimIdentity}, // P.identity, + {"__trunc__", prim::kPrimIdentity}, // P.identity, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("int_bool")}, // C.int_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kNumberTypeFloat, + { + {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, + {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, + {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, + {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv + {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, + {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, + {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, + {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, + {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, + {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, + {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, + {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, + {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, + {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, + {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, + {"__le__", prim::kPrimScalarLe}, // P.scalar_le, + {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, + {"__bool__", std::string("float_bool")}, // C.float_bool + {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, + }}, + {kObjectTypeTuple, + { + {"__len__", prim::kPrimTupleLen}, // P.tuple_len, + {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, + {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity, + {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, + {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext + {"__bool__", std::string("tuple_bool")} // C.tuple_bool + }}, + {kObjectTypeList, + { + {"__len__", prim::kPrimListLen}, // P.list_len, + {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, + {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, + {"__ms_iter__", prim::kPrimIdentity}, // P.identity + {"__ms_next__", std::string("list_next")}, // C.list_next + {"append", std::string("list_append")}, // C.list_next + {"__bool__", std::string("list_bool")}, // C.list_bool + {"__ms_hasnext__", std::string("list_hasnext")}, + }}, + {kObjectTypeDictionary, + { + {"__len__", prim::kPrimDictLen}, // P.dict_len + {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem + {"__setitem__", 
prim::kPrimDictSetItem}, // P.dict_setitem, + {"__bool__", std::string("dict_bool")} // C.dict_bool + }}, + {kObjectTypeTensorType, + { + {"__add__", std::string("add")}, // C.add + {"__sub__", std::string("sub")}, // C.sub + {"__mul__", std::string("mul")}, // C.mul + {"__truediv__", std::string("truediv")}, // C.truediv + {"__floordiv__", std::string("floordiv")}, // C.floordiv + {"__mod__", std::string("mod")}, // C.mod + {"__pow__", std::string("pow_")}, // C.pow + {"__floor__", std::string("array_floor")}, // C.array_floor + {"__trunc__", std::string("array_trunc")}, // C.array_trunc + {"__pos__", std::string("array_uadd")}, // C.array_uadd + {"__neg__", std::string("array_usub")}, // C.array_usub + {"__eq__", std::string("eq")}, // C.eq + {"__ne__", std::string("ne")}, // C.ne + {"__lt__", std::string("lt")}, // C.lt + {"__gt__", std::string("gt")}, // C.gt + {"__le__", std::string("le")}, // C.le + {"__ge__", std::string("ge")}, // C.ge + {"__matmul__", prim::kPrimDot}, // P.dot, + {"__len__", prim::kPrimArrayLen}, // P.array_len, + {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, + {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, + {"__ms_iter__", std::string("array_iter")}, // C.array_iter + {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, + {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, + {"transpose", std::string("transpose")}, // P.transpose + {"__bool__", std::string("tensor_bool")}, // C.tensor_bool + }}, + {kObjectTypeIndexedSlicesType, + { + {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values + {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices + {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape + }}, + {kObjectTypeJTagged, {}}, + {kObjectTypeSymbolicKeyType, {}}, + {kObjectTypeEnvType, {}}}; + return method_map; +} + +Resource::Resource(const py::object &obj) + : engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)), + input_(obj), + is_cleaned_(false) {} + +Resource::~Resource() { + MS_LOG(DEBUG) << "Resource clear"; + + // On a normal exit, these global variables are cleaned by the Resource::Clean call issued from + // MsPipeline::Compile. But when exiting via MS_LOG(EXCEPTION), they may not be cleaned, which can + // cause a segmentation fault: the Python objects held inside them would be freed after the Python + // interpreter has already been finalized. So they are also cleaned here. + // This means that on a normal exit these globals are cleaned twice, so the functions below must + // take care to prevent a double free. + if (!is_cleaned_) { + try { + Clean(); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); + } catch (...) 
{ + MS_LOG(ERROR) << "Exception when cleaning resource."; + } + } +} + +bool Resource::IsTypeInMethodMap(const TypeId &type) { + TypeId type_id = NormalizeTypeId(type); + const MethodMap &method_map = GetMethodMap(); + auto iter = method_map.find(static_cast<int>(type_id)); + if (iter != method_map.end()) { + return true; + } + return false; +} + +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { + TypeId type_id = NormalizeTypeId(type); + const MethodMap &method_map = GetMethodMap(); + auto iter = method_map.find(static_cast<int>(type_id)); + if (iter == method_map.end()) { + MS_LOG(WARNING) << "Object type: " << type_id << " is not in the method_map"; + return Any(); + } + + auto iter_map = iter->second.find(name); + if (iter_map == iter->second.end()) { + MS_LOG(WARNING) << "Object type: " << type_id << " has no method: " << name; + return Any(); + } + return iter_map->second; +} + +void Resource::Clean() { + // AbstractTensor->elements() will be saved in AbstractBasePtrList + args_spec_.clear(); + input_ = py::none(); + // A Context with an AbstractBasePtrList may be saved in a GraphEvaluator, and some Evaluators, + // such as ResolveEvaluator, may cache Python objects; these should be cleaned before the Python + // interpreter is destructed. + MS_EXCEPTION_IF_NULL(engine_); + engine_->ClearEvaluatorCache(); + // Clean static variables to prevent a crash: static variables are released after the Python + // threads have been released. + parse::data_converter::ClearObjectCache(); + parse::Parser::CleanParserResource(); + parse::CleanDataClassToClassMap(); + trace::ClearTraceStack(); + is_cleaned_ = true; +} +} // namespace pipeline +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/resource.h b/mindspore/ccsrc/pipeline/jit/resource.h new file mode 100644 index 0000000000..819fdd3d20 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/resource.h @@ -0,0 +1,120 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ +#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ + +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +#include "utils/any.h" +#include "utils/profile.h" +#include "ir/manager.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "./common.h" + +namespace mindspore { +namespace pipeline { + +namespace py = pybind11; + +const char kBackend[] = "backend"; +const char kStepParallelGraph[] = "step_parallel"; +const char kOutput[] = "output"; + +class InferenceResource; + +using MethodMap = std::unordered_map<int, std::unordered_map<std::string, Any>>; + +MethodMap &GetMethodMap(); + +class ResourceBase { + public: + ResourceBase() { manager_ = MakeManager(); } + + virtual ~ResourceBase() = default; + + FuncGraphManagerPtr manager() { return manager_; } + // Set a manager defined outside; it will not manage the graphs. 
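+ // (Editor's note, illustrative: the constructor above installs a manager created via MakeManager(), which does manage the graphs added to it; this setter instead swaps in an externally owned manager.)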
+ void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } + + std::unordered_map &results() { return results_; } + + void SetResult(const std::string &key, const Any &value) { results_[key] = value; } + + Any GetResult(const std::string &key) { + if (results_.count(key) == 0) { + MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; + } + return results_[key]; + } + + bool HasResult(const std::string &key) const { return results_.count(key) != 0; } + + std::unordered_map results_; + + protected: + FuncGraphManagerPtr manager_; +}; + +using ResourceBasePtr = std::shared_ptr; + +class Resource : public ResourceBase { + public: + explicit Resource(const py::object &obj = py::none()); + + ~Resource() override; + + abstract::AnalysisEnginePtr engine() { return engine_; } + + static bool IsTypeInMethodMap(const TypeId &type); + + static Any GetMethodPtr(const TypeId &type, const std::string &name); + + const py::object &input() const { return input_; } + + FuncGraphPtr func_graph() const { return func_graph_; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } + + const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } + void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; } + + // Reclaim resource and clear the cache. + // ExecutorPy::Compile() can be called multiple times, so cache + // should be cleared. + void Clean(); + + private: + abstract::AnalysisEnginePtr engine_; + FuncGraphPtr func_graph_; + abstract::AbstractBasePtrList args_spec_; + py::object input_; + bool is_cleaned_; +}; + +using ResourcePtr = std::shared_ptr; + +} // namespace pipeline +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc new file mode 100644 index 0000000000..8bdb2a0c6c --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.cc @@ -0,0 +1,361 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pipeline/jit/static_analysis/abstract_function.h" + +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" + +namespace mindspore { +namespace abstract { +class Evaluator; +class AnalysisEngine; + +AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { + if (func_list.size() == 1) { + return func_list[0]; + } + return std::make_shared(func_list); +} + +AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (*this_func == *other) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncAtom::Visit(std::function visit_func) const { + visit_func(const_cast(this)->shared_from_base()); +} + +bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } + +AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { + AbstractFuncAtomPtrList new_func_list; + auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; + + first->Visit(build_func_list); + second->Visit(build_func_list); + func_list_ = new_func_list; +} + +std::string AbstractFuncUnion::ToString() const { + std::ostringstream buffer; + buffer << "AbstractFuncUnion({"; + int i = 0; + for (const auto &func : func_list_) { + MS_EXCEPTION_IF_NULL(func); + buffer << "[" << i << "]: " << func->ToString() << ", "; + i++; + } + buffer << "})"; + return buffer.str(); +} + +bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { + MS_EXCEPTION_IF_NULL(other); + std::vector is_in_list; + auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { + auto iter = find(func_list_.begin(), func_list_.end(), func); + if (iter == func_list_.end()) { + is_in_list.push_back(false); + } + return true; + }; + other->Visit(build_in_list); + return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); +} + +AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { + auto this_func = shared_from_base(); + if (other->isa()) { + if (IsSuperSet(other)) { + return this_func; + } + return std::make_shared(this_func, other); + } + auto other_union = dyn_cast(other); + if (other_union->IsSuperSet(this_func)) { + return other; + } + return std::make_shared(this_func, other); +} + +void AbstractFuncUnion::Visit(std::function visit_func) const { + for (AbstractFuncAtomPtr poss : func_list_) { + visit_func(poss); + } +} + +bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_union = static_cast(&other); + if (func_list_.size() != other_union->func_list_.size()) { + return false; + } + if (func_list_ == other_union->func_list_) { + return true; + } + return false; +} + +std::size_t AbstractFuncUnion::hash() const { + std::size_t hash_sum = 0; + for (auto f : func_list_) { + hash_sum = hash_combine(hash_sum, f->hash()); + } + return hash_sum; +} + +EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + 
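+// (Editor's note, illustrative: the operator==/hash() pairs defined in this file act as the cache-key contract for evaluators, which the analysis engine looks up by abstract function. Closures that compare equal must therefore hash equally; for any two closures a and b (hypothetical names), !(*a == *b) || a->hash() == b->hash() should hold. PrimitiveAbstractClosure, for instance, compares prim_ and tracking_id() below and hashes by mixing tid() with prim_->hash().)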
+bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_prim = static_cast(&other); + if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { + return true; + } + return false; +} + +std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } + +EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_fg = static_cast(&other); + if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { + return true; + } + return false; +} + +std::size_t FuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), func_graph_->hash()); + hash_value = hash_combine(hash_value, context_->hash()); + return hash_value; +} + +std::string FuncGraphAbstractClosure::ToString() const { + std::stringstream ss; + ss << "FuncGraphAbstractClosure: " + << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); + return ss.str(); +} + +EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_meta_fg = static_cast(&other); + if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { + return true; + } + return false; +} + +std::size_t MetaFuncGraphAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); + return hash_value; +} + +std::string MetaFuncGraphAbstractClosure::ToString() const { + return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); +} + +bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_partial = static_cast(&other); + if (fn_ != other_partial->fn_) { + return false; + } + if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_partial->args_spec_list_) { + return true; + } + return false; +} + +std::size_t PartialAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +std::string PartialAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; + for (auto arg : args_spec_list_) { + buffer << arg->ToString() << ", "; + } + buffer << "))"; + return buffer.str(); +} + +EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_transformed = static_cast(&other); + if (fn_ == other_transformed->fn_) { + return true; + } + return false; +} + +std::size_t 
JTransformedAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), fn_->hash()); + return hash_value; +} + +EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_virtual = static_cast(&other); + if (output_ != other_virtual->output_) { + return false; + } + if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_virtual->args_spec_list_) { + return true; + } + return false; +} + +std::size_t VirtualAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), output_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string VirtualAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "VirtualAbstractClosure(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { + MS_EXCEPTION_IF_NULL(engine); + + return engine->_GetEvaluatorFor(shared_from_base()); +} + +bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + auto other_typed = static_cast(&other); + if (output_ != other_typed->output_) { + return false; + } + if (prim_ != other_typed->prim_) { + return false; + } + if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { + return false; + } + if (args_spec_list_ == other_typed->args_spec_list_) { + return true; + } + return false; +} + +std::size_t TypedPrimitiveAbstractClosure::hash() const { + auto hash_value = hash_combine(tid(), prim_->hash()); + hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); + return hash_value; +} + +std::string TypedPrimitiveAbstractClosure::ToString() const { + std::ostringstream buffer; + buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; + int i = 0; + for (const auto &arg : args_spec_list_) { + MS_EXCEPTION_IF_NULL(arg); + buffer << "[" << i << "]: " << arg->ToString() << ", "; + i++; + } + buffer << "}, output: " << output_->ToString() << ")"; + return buffer.str(); +} + +bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { + if (!other.isa()) { + return false; + } + return true; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h similarity index 100% rename from mindspore/ccsrc/pipeline/static_analysis/abstract_function.h rename to mindspore/ccsrc/pipeline/jit/static_analysis/abstract_function.h diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc new file mode 100644 index 0000000000..3e820eed3a --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -0,0 +1,404 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/evaluator.h" + +#include +#include + +#include "ir/func_graph_cloner.h" +#include "abstract/utils.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +namespace { +string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, + const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(evaluator); + std::stringstream ss; + if (out_conf != nullptr) { + ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); + } + for (size_t i = 0; i < arg_spec_list.size(); i++) { + ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString(); + } + return ss.str(); +} + +void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { + MS_EXCEPTION_IF_NULL(evaluator); + if (out_conf != nullptr) { + auto node = out_conf->node(); + if (IsValueNode(node)) { + MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope() + << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); + } else { + MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString() + << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); + } + } +} +} // namespace + +AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, + const AbstractBasePtrList &args_spec_list) { + AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); + normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); + FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); + MS_EXCEPTION_IF_NULL(parent_context_); + AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); + return context; +} + +static std::vector FastShadowSort(const AnfNodePtr &ret_node) { + auto current_func_graph = ret_node->func_graph(); + MS_EXCEPTION_IF_NULL(current_func_graph); + + std::vector sorted_nodes; + auto seen = NewSeenGeneration(); + std::size_t index = 0; + sorted_nodes.emplace_back(ret_node); + while (index < sorted_nodes.size()) { + auto current = sorted_nodes[index]; + index++; + MS_EXCEPTION_IF_NULL(current); + if (current->isa()) { + auto &inputs = current->cast()->inputs(); + for (auto it = inputs.begin(); it != inputs.end(); it++) { + AnfNodePtr input = *it; + if (input != nullptr && input->isa() && input->seen_ != seen && + input->func_graph() == current_func_graph) { + sorted_nodes.emplace_back(input); + input->seen_ = seen; + } + } + } + } + return sorted_nodes; +} + +EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { + FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); + MS_EXCEPTION_IF_NULL(fg); + std::size_t nargs = fg->parameters().size(); + if (args_spec_list.size() != nargs) { + MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this 
function is " + << fg->parameters().size() << ", but the number of provided arguments is " + << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); + } + MS_EXCEPTION_IF_NULL(parent_context_); + MS_EXCEPTION_IF_NULL(engine); + graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list); + const auto ¶meters = fg->parameters(); + for (size_t i = 0; i < nargs; i++) { + const auto &arg = args_spec_list[i]; + const auto &node = parameters[i]; + AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); + engine->cache().set_value(conf, std::make_shared(arg, nullptr)); + } + const AnfNodePtr &func_node = fg->get_return(); + + MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() + << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); + AbstractBasePtr ret_base = nullptr; + std::vector nodes = FastShadowSort(func_node); + for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { + const auto &node = *it; + AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); + MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); + ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); + MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() + << ", abstract: " << ret_base->ToString(); + } + + MS_EXCEPTION_IF_NULL(ret_base); + MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() + << ", is stub: " << fg->stub(); + if (fg->stub()) { + return std::make_shared(std::make_shared(), nullptr); + } + return std::make_shared(ret_base, nullptr); +} + +AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { + MS_EXCEPTION_IF_NULL(func_graph_); + if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + AbstractBasePtrList broaded_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(arg); + return arg->Broaden(); + }); + MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) + << ", broaded: " << mindspore::ToString(broaded_list); + return broaded_list; + } + return args_spec_list; +} + +AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(func_graph_); + if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + return args_spec_list; + } + if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { + if (parent_context_) { + MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() + << ", context: " << parent_context_->ToString(); + auto last_context = parent_context_->Filter(func_graph_); + if (last_context && last_context->func_graph() == func_graph_) { + MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); + MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); + MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); + // Join the last eval arguments and current arguments to check if there are loop variant. 
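+        // [Editor's note, not in the original source] Example: if a loop re-enters this graph with
+        // i = AbstractScalar(1) after a first pass with i = AbstractScalar(0), AbstractJoin of the two
+        // is an Int32 scalar carrying kAnyValue. That differs from the current args, so the code below
+        // sets IGNORE_VALUES and re-evaluates the graph with broadened (value-free) arguments.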
+        auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list());
+        MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list);
+        // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation.
+        if (!(joined_args_spec_list == args_spec_list)) {
+          func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+          MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
+        }
+        return joined_args_spec_list;
+      }
+    }
+    if (trace_.size() != 0) {
+      MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
+      MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back());
+      // Join the last eval arguments and current arguments to check whether there is a loop variant.
+      auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back());
+      // If there is a loop variant, all arguments need to be broadened to avoid wrong constant propagation.
+      if (!(joined_args_spec_list == args_spec_list)) {
+        trace_.push_back(joined_args_spec_list);
+        func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
+        MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag.";
+      }
+      MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list);
+      return joined_args_spec_list;
+    } else {
+      trace_.push_back(args_spec_list);
+    }
+  }
+  return args_spec_list;
+}
+
+FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
+  auto iter = func_graph_cache_.find(args_spec_list);
+  FuncGraphPtr ret = nullptr;
+  if (iter == func_graph_cache_.end()) {
+    auto fg = func_graph();
+    MS_EXCEPTION_IF_NULL(fg);
+    TraceManager::DebugTrace(std::make_shared<TraceEvaluatorGenGraph>(fg->debug_info()));
+    FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list);
+    TraceManager::EndTrace();
+    func_graph_cache_[args_spec_list] = generated_graph;
+    MS_EXCEPTION_IF_NULL(engine);
+    engine->func_graph_manager()->AddFuncGraph(generated_graph);
+    ret = generated_graph;
+  } else {
+    ret = iter->second;
+  }
+
+  // For the top graph, if it is replaced by a generated graph, update the top graph to the new one.
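+  // [Editor's note] GenerateGraph may return a specialized clone of the parsed graph (e.g. for default
+  // or keyword arguments), so the parser's record of the top graph has to follow the replacement, or
+  // later passes would keep referring to the unspecialized original.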
+ if (parse::Parser::GetTopFuncGraph() == func_graph()) { + if (ret != func_graph()) { + parse::Parser::UpdateTopFuncGraph(ret); + } + } + return ret; +} + +FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { + auto iter = func_graph_cache_.find(args_spec_list); + if (iter != func_graph_cache_.end()) { + return iter->second; + } + + MS_EXCEPTION_IF_NULL(meta_func_graph_); + FuncGraphPtr generated_func_graph = nullptr; + if (this->bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); + TraceManager::EndTrace(); + } else { + generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); + } + + FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph); + func_graph_cache_[args_spec_list] = cloned_func_graph; + MS_EXCEPTION_IF_NULL(engine); + engine->func_graph_manager()->AddFuncGraph(cloned_func_graph); + return cloned_func_graph; +} + +EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { + const std::string &evaluator_name = ToString(); + + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + args_spec_list = NormalizeArgs(args_spec_list); + args_spec_list = BroadenUndeterminedArgs(args_spec_list); + trace::TraceGraphEvalEnter(shared_from_base(), out_conf); + MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); + MS_EXCEPTION_IF_NULL(cache_); + auto iter = cache_->find(args_spec_list); + if (iter == cache_->end()) { + MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; + EvalResultPtr ret = Eval(engine, args_spec_list); + if (ret->abstract() == nullptr) { + EvalFailLogging(shared_from_base(), args_spec_list, out_conf); + MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; + } + MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; + (*cache_)[args_spec_list] = ret; + trace::TraceGraphEvalLeave(shared_from_base()); + return ret; + } else { + MS_EXCEPTION_IF_NULL(iter->second); + MS_EXCEPTION_IF_NULL(iter->second->abstract()); + MS_LOG(DEBUG) << evaluator_name << " cache hit. 
return: " << iter->second->abstract()->ToString() << "."; + trace::TraceGraphEvalLeave(shared_from_base()); + return iter->second; + } +} + +EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + EvalResultPtr ret = EvalPrim(engine, args_spec_list); + return ret; +} + +EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + if (args_conf_list.size() == 0) { + MS_LOG(EXCEPTION) << "Size should greater than 0"; + } + EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); + // No need to cache. + return ret; +} + +EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { + EvalResultPtr ret = EvalPrim(args_conf_list); + return ret; +} + +EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); + // Don't lookup from cache, as different out_conf with same node but different context + // may add different entry to anfnode_config_map_, like getattr primitive. + (*cache_)[args_spec_list] = ret; + return ret; +} + +EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &conf) -> AbstractBasePtr { + MS_EXCEPTION_IF_NULL(conf); + return conf->GetEvaluatedValue()->abstract(); + }); + MS_EXCEPTION_IF_NULL(cache_); + auto iter = cache_->find(args_spec_list); + if (iter != cache_->end()) { + return iter->second; + } + + ConfigPtrList partial_args_conf_list; + // Join arguments in partial and the rest arguments from args_conf_list. 
+  (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list),
+                       [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
+
+  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
+                       [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
+  EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
+
+  (*cache_)[args_spec_list] = ret;
+  return ret;
+}
+
+EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
+  AbstractBasePtrList args_spec_list;
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &conf) -> AbstractBasePtr {
+                         MS_EXCEPTION_IF_NULL(conf);
+                         return conf->GetEvaluatedValue()->abstract();
+                       });
+  MS_EXCEPTION_IF_NULL(cache_);
+  auto iter = cache_->find(args_spec_list);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
+
+  // Call the original evaluator, get the result: y = f(x)
+  EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
+  // Build a virtual function bprop_f, which takes the sense of y as input and returns the senses of the
+  // function's free variables and input parameters: (sense_f, sense_x, ...) = (*bprop_f)(sense_y)
+  AbstractBasePtrList bparams;
+  bparams.push_back(SensitivityTransform(orig_func_));
+  (void)std::transform(
+    args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
+    [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
+  AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
+  AbstractFunctionPtr bprop =
+    std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
+
+  // J(f)(J(x)) returns a tuple (y, bprop_f)
+  AbstractBasePtrList jargs = {result->abstract(), bprop};
+  AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
+  auto infer_result = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
+  (*cache_)[args_spec_list] = infer_result;
+  return infer_result;
+}
+
+EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
+  if (args_spec_list.size() != args_spec_list_.size()) {
+    MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
+                      << ", arguments no: " << args_spec_list.size();
+  }
+  // Check that each parameter and argument match;
+  for (std::size_t i = 0; i < args_spec_list.size(); i++) {
+    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
+    (void)args_spec_list[i]->Join(args_spec_list_[i]);
+  }
+  return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h
new file mode 100644
index 0000000000..461574257d
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h
@@ -0,0 +1,330 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ +#define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ + +#include +#include +#include +#include + +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace abstract { +using EvaluatorCacheMap = + std::unordered_map; +using EvaluatorCacheMapPtr = std::shared_ptr; + +using EvaluatorAttrMap = + std::unordered_map; +using EvaluatorAttrMapPtr = std::shared_ptr; + +class Evaluator : public Base { + public: + explicit Evaluator(const std::string &id) + : cache_(std::make_shared()), + attr_cache_(std::make_shared()), + identifier_(id) {} + ~Evaluator() override = default; + MS_DECLARE_PARENT(Evaluator, Base); + + // difference between Run() and Eval(): + // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. + // Run() will modify cache_ member, so it cannot marked as const; + virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); + + virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + + virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } + + virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { + return args_spec_list; + } + + virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool enable_sparse = context->enable_sparse(); + if (!enable_sparse) { + return nullptr; + } + + auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { + if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { + return true; + } + return false; + }); + if (is_abstract) { + MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; + return std::make_shared(std::make_shared(), std::make_shared()); + } + return nullptr; + } + + std::string ToString() const override { return identifier_; } + + virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } + + virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } + + EvaluatorCacheMapPtr &cache() { return cache_; } + EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } + + EvaluatorCacheMapPtr cache_; + EvaluatorAttrMapPtr attr_cache_; + std::string identifier_; + + AnfNodeWeakPtr bound_node_; +}; + +class PrimEvaluator : public Evaluator { + public: + explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} + ~PrimEvaluator() override = default; + MS_DECLARE_PARENT(PrimEvaluator, Evaluator); + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } +}; + +class TrivialPrimEvaluator : public PrimEvaluator { + public: + explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~TrivialPrimEvaluator() override = default; + 
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0; +}; + +class TransitionPrimEvaluator : public PrimEvaluator { + public: + explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~TransitionPrimEvaluator() override = default; + MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + // Parameter in_conf0 : the first element in args_conf_list; + virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, + const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0; +}; + +class SymbolicPrimEvaluator : public PrimEvaluator { + public: + explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} + ~SymbolicPrimEvaluator() override = default; + MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator); + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final; + virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0; +}; + +// Evaluator will be stored in AnalysisEngine.constructors_ +using EvaluatorPtrList = std::vector; + +class DummyEvaluator : public Evaluator { + public: + DummyEvaluator() : Evaluator("dummy") {} + ~DummyEvaluator() override = default; + MS_DECLARE_PARENT(DummyEvaluator, Evaluator); + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } +}; + +// Wrap another evaluator to track a subset of uses. +// A TrackedEvaluator has its own cache that maps possible calls to +// their results, but is ultimately backed by a different evaluator. +// Multiple TrackedEvaluators can be backed by the same Evaluator. 
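+// [Editor's note] Per its Run() implementation, a TrackedEvaluator writes results into its own cache
+// but never reads from it: the same node under a different context may forward differently (e.g. for
+// the getattr primitive), so only the backing evaluator's fresh answer is trusted.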
+class TrackedEvaluator : public Evaluator { + public: + explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} + ~TrackedEvaluator() override = default; + MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (sub_evaluator_ != nullptr) { + return sub_evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (sub_evaluator_ != nullptr) { + sub_evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } + + private: + EvaluatorPtr sub_evaluator_; +}; + +class BaseFuncGraphEvaluator : public Evaluator { + public: + explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) + : Evaluator("basegraph"), parent_context_(context) {} + + ~BaseFuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); + + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; + + AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); + AnalysisContextPtr graph_context() const { return graph_context_; } + + protected: + AnalysisContextPtr parent_context_; + + private: + AnalysisContextPtr graph_context_; +}; + +class FuncGraphEvaluator : public BaseFuncGraphEvaluator { + public: + FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) + : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {} + + ~FuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); + + FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + FuncGraphPtr func_graph() { return func_graph_; } + + AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; + AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; + std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } + + private: + FuncGraphPtr func_graph_; + std::unordered_map + func_graph_cache_; + std::vector trace_; +}; +using FuncGraphEvaluatorPtr = std::shared_ptr; + +class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { + public: + // Note: context parameter is not used; + MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, AnalysisContextPtr, const ScopePtr &scope) + : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} + ~MetaFuncGraphEvaluator() override = default; + MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); + + FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + + // Return normalized versions of the arguments. 
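+  // [Editor's note] Delegating to MetaFuncGraph::NormalizeArgs keeps cache keys canonical: Run()
+  // normalizes argument specs before consulting cache_, so calls that normalize identically reuse
+  // one generated graph.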
+ AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { + return meta_func_graph_->NormalizeArgs(args_spec_list); + } + std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } + + private: + MetaFuncGraphPtr meta_func_graph_; + std::unordered_map + func_graph_cache_; + ScopePtr scope_; +}; + +class PartialAppEvaluator : public Evaluator { + public: + PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) + : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {} + ~PartialAppEvaluator() override = default; + MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (evaluator_ != nullptr) { + return evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (evaluator_ != nullptr) { + evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; + } + + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } + + private: + EvaluatorPtr evaluator_; + AbstractBasePtrList args_spec_list_; +}; + +class VirtualEvaluator : public Evaluator { + public: + VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output) + : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {} + ~VirtualEvaluator() override = default; + MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); + + EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; + std::string ToString() const override { return identifier_; } + + private: + AbstractBasePtrList args_spec_list_; + AbstractBasePtr output_; +}; + +class JEvaluator : public Evaluator { + public: + JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) + : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {} + ~JEvaluator() override = default; + MS_DECLARE_PARENT(JEvaluator, Evaluator); + AnfNodePtr bound_node() const override { + if (evaluator_ != nullptr) { + return evaluator_->bound_node(); + } + return bound_node_.lock(); + } + + void set_bound_node(const AnfNodePtr &node) override { + if (evaluator_ != nullptr) { + evaluator_->set_bound_node(node); + } + bound_node_ = AnfNodeWeakPtr(node); + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; + } + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; + std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } + + private: + EvaluatorPtr evaluator_; + AbstractFunctionPtr orig_func_; +}; +} // namespace abstract +} // namespace mindspore +#endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc new file mode 100644 index 0000000000..99e613395c --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -0,0 +1,1384 @@ +/** + * This is the C++ adaptation and derivative work of Myia 
(https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/prim.h" + +#include +#include +#include +#include +#include +#include + +#include "frontend/operator/cc_implementations.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/do_signature.h" +#include "frontend/operator/prim_to_function.h" +#include "abstract/utils.h" +#include "utils/symbolic.h" +#include "./common.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/resolve.h" +#include "ir/tensor.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" +#include "pipeline/jit/parse/data_converter.h" +#include "abstract/param_validator.h" +#include "common/utils.h" + +namespace mindspore { +namespace abstract { +PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { + static PrimitiveEvalImplMap prim_eval_implement_map = { + // Statements + {prim::kPrimReturn, {InferImplReturn, true}}, + {prim::kPrimTypeOf, {InferImplTypeof, false}}, + {prim::kPrimHasType, {InferImplHasType, false}}, + {prim::kPrimDot, {InferImplDot, true}}, + {prim::kPrimSwitch, {InferImplSwitch, true}}, + {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, + {prim::kPrimIs_, {InferImplIs_, true}}, + {prim::kPrimIsNot, {InferImplIsNot, true}}, + {prim::kPrimInDict, {InferImplInDict, true}}, + {prim::kPrimNotInDict, {InferImplNotInDict, true}}, + {prim::kPrimIsConsant, {InferImplIsConstant, true}}, + // Maths + {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, + {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + // Array + {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, + {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, + {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, + {prim::kPrimShape, {InferImplShape, true}}, + {prim::kPrimPack, {InferImplPack, true}}, + // Structure + {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, + {prim::kPrimMakeList, {InferImplMakeList, true}}, + {prim::kPrimMakeDict, {InferImplMakeDict, true}}, + {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, + {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, + {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, + {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, + {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, + {prim::kPrimListGetItem, {InferImplListGetItem, true}}, + {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, + {prim::kPrimListSetItem, {InferImplListSetItem, true}}, + {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, + {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, + {prim::kPrimListAppend, {InferImplListAppend, true}}, + {prim::kPrimTupleLen, {InferImplTupleLen, true}}, + {prim::kPrimListLen, {InferImplListLen, true}}, + {prim::kPrimArrayLen, {InferImplArrayLen, true}}, + {prim::kPrimListMap, {InferImplListMap, false}}, + {prim::kPrimListReduce, {InferImplListReduce, false}}, + 
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}}, + {prim::kPrimReducedShape, {InferImplReduceShape, false}}, + {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, + {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, + {prim::kPrimShapeMul, {InferImplShapeMul, false}}, + {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, + {prim::kPrimListEqual, {InferImplListEqual, false}}, + {prim::kPrimMakeRange, {InferImplMakeRange, false}}, + {prim::kPrimStopGradient, {InferImplStopGradient, false}}, + {prim::kPrimStringEqual, {InferImplStringEqual, false}}, + {prim::kPrimStringConcat, {InferImplStringConcat, false}}, + {prim::kPrimDictLen, {InferImplDictLen, false}}, + // NN + {prim::kPrimPooling, {InferImplPooling, true}}, + {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, + {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, + {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, + {prim::kPrimReluGrad, {InferImplReluGrad, true}}, + {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, + {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, + {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, + {prim::kPrimRelu, {InferImplRelu, true}}, + {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, + {prim::kPrimZerosLike, {InferImplZerosLike, true}}, + {prim::kPrimBpropCut, {InferImplBpropCut, true}}, + {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, + {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, + {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, + // Others + {prim::kPrimIdentity, {InferImplIdentity, true}}, + // Set impl to null as it will use PartialEvaluator; + {prim::kPrimPartial, {nullptr, true}}, + {prim::kPrimJ, {InferImplJ, false}}, + {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, + {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, + {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, + {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, + {prim::kPrimMakeRef, {InferImplMakeRef, true}}, + {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, + {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, + {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, + {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, + {prim::kPrimDepend, {InferImplDepend, true}}, + {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, + {prim::kPrimControlDepend, {InferImplControlDepend, true}}, + // Debug + {prim::kPrimDebug, {InferImplDebug, true}}, + // IndexedSlices + {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, + {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, + {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, + {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, + {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, + }; + return prim_eval_implement_map; +} + +using mindspore::parse::PyObjectWrapper; + +EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { + if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { + auto ret_abstract = AbstractEval(args); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + } + prim_->BeginRecordAddAttr(); + AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); + prim_->EndRecordAddAttr(); + auto added_attrs = prim_->evaluate_added_attrs(); + auto 
infer_result = std::make_shared(abs_base, std::make_shared(added_attrs)); + return infer_result; +} + +EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + auto ret_abstract = AbstractEval(args_spec_list); + if (ret_abstract != nullptr) { + MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; + return ret_abstract; + } + + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + + auto do_signature = dyn_cast(prim_); + auto out_node = dyn_cast(out_conf->node()); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString() + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + + AnfNodePtr new_cnode = nullptr; + if (bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); + TraceManager::EndTrace(); + } else { + new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, + args_inputs); + } + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { + // arg[0] is the func graph to unpack, ignore it + AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); + AbstractBasePtrList graph_specialize_args; + if (need_unpack) { + for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { + MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); + if (specialize_args_before_unpack[index]->isa()) { + AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast(); + std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), + std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); + } else if (specialize_args_before_unpack[index]->isa()) { + AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast(); + auto dict_elems = arg_dict->elements(); + (void)std::transform( + dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), + [](const AbstractAttribute &item) { return std::make_shared(item.first, item.second); }); + } else { + MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " + << specialize_args_before_unpack[index]->ToString(); + } + } + } else { + graph_specialize_args = specialize_args_before_unpack; + } + return graph_specialize_args; +} + +EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const 
ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + + auto unpack_graph = prim_->cast(); + auto out_node = out_conf->node()->cast(); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + // get the forward graph + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + AbstractFunctionPtr fn = args_spec_list[0]->cast(); + if (fn == nullptr) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); + } + auto real_fn = fn->cast(); + MS_EXCEPTION_IF_NULL(real_fn); + FuncGraphPtr forward_graph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(forward_graph); + AbstractBasePtrList graph_specialize_args = + GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); + + AbstractBasePtrList graph_specialize_args_without_sens; + (void)std::transform(graph_specialize_args.begin(), + graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), + std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); + auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens); + engine->func_graph_manager()->AddFuncGraph(new_graph); + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + AnfNodePtr new_vnode = NewValueNode(new_graph); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type, + FuncGraphPtr func_graph) { + AnfNodePtr target_node = source_node; + if (node_type->isa()) { + auto x = node_type->cast(); + if (x->element()->BuildType()->isa()) { + auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); + MS_EXCEPTION_IF_NULL(cast); + target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type}); + } + } else if (node_type->isa()) { + auto x = node_type->cast(); + auto &items = x->elements(); + std::vector nodes; + nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + int idx = 0; + for (const auto &item : items) { + AnfNodePtr tuple_node = + func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)}); + AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph); + nodes.emplace_back(node); + ++idx; + } + target_node = func_graph->NewCNode(nodes); + } else if (node_type->isa()) { + auto x = node_type->cast(); + auto &items = x->elements(); + std::vector dict_key_nodes; + std::vector dict_value_nodes; + dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + for (const auto &item : 
items) { + AnfNodePtr dict_value_node = + func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)}); + AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph); + dict_key_nodes.emplace_back(NewValueNode(item.first)); + dict_value_nodes.emplace_back(node); + } + target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), + func_graph->NewCNode(dict_value_nodes)}); + } else if (node_type->isa()) { + auto x = node_type->cast(); + std::string kwarg_key = x->get_key(); + AnfNodePtr kwarg_value_node = + func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); + AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph); + target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node}); + } + return target_node; +} + +EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + AbstractBasePtrList args_spec_list; + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + auto out_node = out_conf->node()->cast(); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "MixedPrecisionCast" + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); + + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + + FuncGraphPtr func_graph = out_conf->node()->func_graph(); + AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + +namespace { +py::object BuildValue(const ValuePtr &value_ptr) { + if (value_ptr == nullptr) { + return py::none(); + } else { + return ValuePtrToPyData(value_ptr); + } +} +} // end anonymous namespace + +py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { + MS_EXCEPTION_IF_NULL(abs_base); + py::dict dic; + if (abs_base->isa()) { + auto arg_tensor = dyn_cast(abs_base); + dic["shape"] = arg_tensor->shape()->shape(); + dic["dtype"] = arg_tensor->BuildType(); + dic["value"] = BuildValue(arg_tensor->BuildValue()); + } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { + std::vector shape; + dic["shape"] = shape; + dic["dtype"] = abs_base->BuildType(); + dic["value"] = BuildValue(abs_base->BuildValue()); + } else if (abs_base->isa()) { + auto arg_slice = dyn_cast(abs_base); + std::vector shape; + dic["shape"] = shape; + dic["dtype"] = arg_slice->BuildType(); + dic["value"] = BuildValue(arg_slice->BuildValue()); + } else if (abs_base->isa()) { + auto value = abs_base->cast()->ref(); + dic = ConvertAbstractToPython(value); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + 
dic["dtype"] = py::ellipsis(); + dic["value"] = py::ellipsis(); + } else if (abs_base->isa()) { + auto arg_tuple = dyn_cast(abs_base); + size_t len = arg_tuple->size(); + py::tuple shape_tuple(len); + py::tuple dtype_tuple(len); + + for (size_t i = 0; i < len; i++) { + py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]); + shape_tuple[i] = out["shape"]; + dtype_tuple[i] = out["dtype"]; + } + dic["shape"] = shape_tuple; + dic["dtype"] = dtype_tuple; + dic["value"] = BuildValue(arg_tuple->BuildValue()); + } else if (abs_base->isa()) { + auto arg_list = dyn_cast(abs_base); + size_t len = arg_list->size(); + py::list shape_list(len); + py::list dtype_list(len); + + for (size_t i = 0; i < len; i++) { + py::dict out = ConvertAbstractToPython(arg_list->elements()[i]); + shape_list[i] = out["shape"]; + dtype_list[i] = out["dtype"]; + } + dic["shape"] = shape_list; + dic["dtype"] = dtype_list; + dic["value"] = BuildValue(arg_list->BuildValue()); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = py::none(); + dic["value"] = py::none(); + } else if (abs_base->isa()) { + dic["shape"] = py::none(); + dic["dtype"] = abs_base->BuildType(); + dic["value"] = py::none(); + } else { + auto value = abs_base->BuildValue(); + if ((*value == *kAnyValue)) { + auto value_desc = abs_base->value_desc(); + MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc) + << " for python primitive." << abs_base->ToString(); + } + MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is " + << value->ToString(); + } + return dic; +} + +namespace { +py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) { + const AbstractBasePtrList *args_ptr; + + if (prim_py->is_tuple_input_) { + if (args.empty()) { + MS_LOG(EXCEPTION) << "Primitive args is empty"; + } + if (args[0] == nullptr || !args[0]->isa()) { + MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after converting" + "prim convert pass for GE."; + } + args_ptr = &(args[0]->cast()->elements()); + } else { + args_ptr = &args; + } + + py::tuple py_args(args_ptr->size()); + for (size_t i = 0; i < args_ptr->size(); i++) { + auto arg_i = (*args_ptr)[i]; + py_args[i] = ConvertAbstractToPython(arg_i); + } + return py_args; +} + +AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) { + // Convert to AbstractValue based on type and shape + if (output["value"].is_none()) { + auto out_shape = output["shape"]; + auto out_dtype = output["dtype"]; + return PyListDtype2AbstractTensor(out_shape, out_dtype); + } + // Convert pyobject to Value, then to AbstractValue + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(output["value"], &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Convert data failed"; + } + auto res_spec = FromValue(converted_ret); + MS_EXCEPTION_IF_NULL(res_spec); + if (res_spec->isa()) { + // Replace to tensor constant node in specialize + auto res_tensor = res_spec->cast(); + res_tensor->set_value(converted_ret); + } + if (prim_py->IsCustomPrim()) { + // Raise error if output_num is not match the infer result. 
+    int output_num = GetValue<int>(prim_py->GetAttr("output_num"));
+    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
+                        << " does not match the inferred result.";
+    } else if (res_spec->isa<AbstractTuple>() &&
+               (res_spec->cast<AbstractTuplePtr>()->size() != IntToSize(output_num))) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
+                        << " does not match the inferred result.";
+    }
+  }
+  return res_spec;
+}
+}  // end anonymous namespace
+
+EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+    return ret_abstract;
+  }
+  MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
+
+  const auto &iter = cache_->find(args);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
+  auto py_args = PreparePyInputs(prim_py_, args);
+
+  auto pyobj = prim_py_->GetPyObj();
+  if (pyobj == nullptr) {
+    MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
+  }
+  auto infer_func = pyobj.attr("__infer__");
+  prim_py_->BeginRecordAddAttr();
+  py::dict output = infer_func(*py_args);
+  prim_py_->EndRecordAddAttr();
+  auto added_attrs = prim_py_->evaluate_added_attrs();
+  MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
+  auto res_spec = PyInferRes2Abstract(prim_py_, output);
+
+  MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
+  auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
+  (*cache_)[args] = infer_result;
+  return infer_result;
+}
+
+EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto ret_abstract = AbstractEval(args);
+  if (ret_abstract != nullptr) {
+    MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+    return ret_abstract;
+  }
+  // If the retval type of func_desc_ is a superclass of a parameter type, narrow the retval type
+  // to that parameter type.
+  if (nargs_ != args.size()) {
+    MS_LOG(ERROR) << "UniformPrimEvaluator expects " << nargs_ << " args, but got " << args.size() << " inputs";
+    return nullptr;
+  }
+  TypePtr ret_value_type = return_value_type_;
+  ValuePtrList value_list;
+  for (const auto &arg : args) {
+    // Check if all arguments are scalar type.
+    MS_EXCEPTION_IF_NULL(arg);
+    if (arg->isa<AbstractScalar>()) {
+      auto arg_scalar = dyn_cast<AbstractScalar>(arg);
+      auto arg_value = arg_scalar->GetValueTrack();
+      value_list.push_back(arg_value);
+    } else {
+      // Raise TypeError: a scalar was expected.
+      MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
+    }
+  }
+  for (const auto &item : type_map_) {
+    TypePtrList selections;
+    MS_EXCEPTION_IF_NULL(item.second);
+    (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
+                         [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
+    TypePtr res = CheckTypeList(item.first, selections);
+    if (*return_value_type_ == *(item.first)) {
+      ret_value_type = res;
+    }
+  }
+
+  ValuePtr evaluated_value = RunImpl(value_list);
+  if (!(*evaluated_value == *kAnyValue)) {
+    ret_value_type = evaluated_value->type();
+  }
+  // For comparison primitives, the return type shall be specified as bool.
+  if (specify_out_type_ != nullptr) {
+    ret_value_type = specify_out_type_;
+  }
+
+  AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
+  return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
+}
+
+ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
+  if (!eval_value_) {
+    return kAnyValue;
+  } else {
+    if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
+          MS_EXCEPTION_IF_NULL(arg);
+          return arg->isa<AnyValue>();
+        })) {
+      return kAnyValue;
+    }
+    return impl_(args);
+  }
+}
+
+// Primitive implementation
+// static function start
+namespace {
+EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) {
+  EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
+  return prim_evaluator;
+}
+
+EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
+                                      const TypePtr &specify_out_type) {
+  FunctionPtr func = nullptr;
+  (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
+  MS_EXCEPTION_IF_NULL(func);
+
+  EvaluatorPtr uniform_primitive_evaluator =
+    std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
+  return uniform_primitive_evaluator;
+}
+
+const int kResolveCaseUserDefineClass = 1;
+const int kResolveCaseBuildinTypeMethod = 2;
+const int kResolveCaseFunction = 3;
+int GetResolveCase(const TypePtr &data_type) {
+  MS_EXCEPTION_IF_NULL(data_type);
+  if (data_type->type_id() == kObjectTypeClass) {
+    return kResolveCaseUserDefineClass;
+  }
+
+  // Try the method map; if data_type is not in the method map, it should be an External type.
+  if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) {
+    return kResolveCaseBuildinTypeMethod;
+  }
+
+  return kResolveCaseFunction;
+}
+
+FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
+  MS_EXCEPTION_IF_NULL(engine);
+  MS_EXCEPTION_IF_NULL(method);
+  if (!method->isa<parse::PyObjectWrapper>()) {
+    MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
+  }
+
+  std::shared_ptr<parse::PyObjectWrapper> obj = method->cast<std::shared_ptr<parse::PyObjectWrapper>>();
+  FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
+  if (func_graph == nullptr) {
+    MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
+  }
+
+  FuncGraphManagerPtr manager = engine->func_graph_manager();
+  manager->AddFuncGraph(func_graph);
+  return func_graph;
+}
+
+inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
+  MS_EXCEPTION_IF_NULL(engine);
+  FuncGraphManagerPtr manager = engine->func_graph_manager();
+  manager->AddFuncGraph(func_graph);
+}
+
+EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
+                                   const AnfNodeConfigPtr &old_conf) {
+  MS_EXCEPTION_IF_NULL(old_conf);
+
+  AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
+  AbstractFunctionPtr abs_func = dyn_cast<AbstractFunction>(abs_ptr);
+  MS_EXCEPTION_IF_NULL(abs_func);
+
+  // Create new cnode
+  std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
+  auto func_graph_func = dyn_cast<FuncGraphAbstractClosure>(abs_func);
+  if (func_graph_func != nullptr) {
+    FuncGraphPtr fg = func_graph_func->func_graph();
+    input.push_back(NewValueNode(fg));
+  } else {
+    auto prim_func = dyn_cast<PrimitiveAbstractClosure>(abs_func);
+    MS_EXCEPTION_IF_NULL(prim_func);
+    PrimitivePtr prim = prim_func->prim();
+    input.push_back(NewValueNode(prim));
+  }
+
+  AnfNodeConfigPtr conf = dyn_cast<AnfNodeConfig>(data_conf);
+  MS_EXCEPTION_IF_NULL(conf);
+  input.push_back(conf->node());
+  MS_EXCEPTION_IF_NULL(old_conf);
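+  // Build Partial(fn, data) and forward the analysis to the new node, so the getter call
+  // site is re-evaluated against the resolved function.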
+  FuncGraphPtr func_graph = old_conf->node()->func_graph();
+  CNodePtr new_cnode = func_graph->NewCNode(input);
+  AnalysisEnginePtr eng = old_conf->engine();
+  AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
+  return eng->ForwardConfig(old_conf, fn_conf);
+}
+
+EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
+                                                  const AbstractBasePtrList &args_spec_list,
+                                                  const AnfNodeConfigPtr &out_conf) {
+  // args_spec_list: same as StaticGetter
+  if (args_spec_list.size() < 2) {
+    MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
+  }
+  MS_EXCEPTION_IF_NULL(out_conf);
+  // An external type.
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
+  MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
+  MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
+  auto data_v = args_spec_list[0]->BuildValue();
+  if (!data_v->isa<parse::NameSpace>()) {
+    MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString();
+  }
+
+  auto item_v = args_spec_list[1]->BuildValue();
+  if (item_v->isa<StringImm>()) {
+    item_v = std::make_shared<parse::Symbol>(item_v->cast<StringImmPtr>()->value());
+  }
+
+  if (!item_v->isa<parse::Symbol>()) {
+    MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString();
+  }
+
+  // item_name to func addr from obj_map
+  parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
+  parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
+  FuncGraphPtr func_graph = out_conf->node()->func_graph();
+
+  auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
+  if (new_node == nullptr) {
+    MS_LOG(EXCEPTION) << "Resolve node failed";
+  }
+
+  AnalysisEnginePtr eng = out_conf->engine();
+  AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
+  return eng->ForwardConfig(out_conf, fn_conf);
+}
+
+EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
+                                                    const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
+                                                    const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "args_spec_list is empty";
+  }
+  AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
+
+  // If item_v is an attribute, get the abstract value from the AbstractClass.
+  MS_EXCEPTION_IF_NULL(item_v);
+  if (!item_v->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << "Attribute type error";
+  }
+  std::string item_name = item_v->cast<StringImmPtr>()->value();
+  MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
+  MS_LOG(DEBUG) << "Resolve item: " << item_name;
+
+  AbstractBasePtr attr = cls->GetAttribute(item_name);
+  if (attr != nullptr) {
+    return std::make_shared<EvalResult>(attr, nullptr);
+  }
+
+  ValuePtr method = cls->GetMethod(item_name);
+  if (method->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
+                      << ", item value: " << item_v->ToString();
+  }
+
+  // Infer class method
+  ValuePtr converted_v = PyObjToGraph(engine, method);
+  return StaticGetterInferred(converted_v, data_conf, out_conf);
+}
+
+EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
+                                                    const TypePtr &data_type, const ConfigPtr &data_conf,
+                                                    const AnfNodeConfigPtr &out_conf) {
+  MS_EXCEPTION_IF_NULL(item_v);
+  MS_EXCEPTION_IF_NULL(data_type);
+  // The method may be a Primitive or a Composite.
+  if (!item_v->isa<StringImm>()) {
+    MS_LOG(EXCEPTION) << "Error: item is not a string";
+  }
+
+  std::string item_name = item_v->cast<StringImmPtr>()->value();
+  Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
+  if (method.empty()) {
+    MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name;
+  }
+
+  ValuePtr converted_v = nullptr;
+  if (method.is<std::string>()) {
+    // Composites registered in the standard_method_map take this branch.
+    converted_v = prim::GetPythonOps(method.cast<std::string>());
+    AddToManager(engine, converted_v->cast<FuncGraphPtr>());
+  } else if (method.is<PrimitivePtr>()) {
+    converted_v = method.cast<PrimitivePtr>();
+  } else {
+    MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString();
+  }
+  return StaticGetterInferred(converted_v, data_conf, out_conf);
+}
+
+EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
+                           const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
+  // Inputs: namespace and its static function; or class and its member function
+  CheckArgsSize("StaticGetter", args_spec_list, 2);
+
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  MS_EXCEPTION_IF_NULL(args_spec_list[1]);
+  TypePtr data_type = args_spec_list[0]->BuildType();
+  ValuePtr item_value = args_spec_list[1]->BuildValue();
+  ScopePtr scope = kDefaultScope;
+  if (out_conf != nullptr) {
+    scope = out_conf->node()->scope();
+  }
+  ScopeGuard scope_guard(scope);
+  if (item_value->isa<AnyValue>()) {
+    MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
+  }
+
+  int case_v = GetResolveCase(data_type);
+  if (case_v == kResolveCaseUserDefineClass) {
+    return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
+  } else if (case_v == kResolveCaseBuildinTypeMethod) {
+    return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf);
+  } else {
+    return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
+  }
+}
+}  // end anonymous namespace
+
+// static variable start;
+namespace {
+class EmbedEvaluator : public SymbolicPrimEvaluator {
+ public:
+  EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
+  ~EmbedEvaluator() override = default;
+  MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
+  EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
+    // arg: free variable to be embedded
+    if (args_conf_list.size() != 1) {
+      MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
+    }
+    AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
+    MS_EXCEPTION_IF_NULL(node_conf);
+
+    AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
+    x = SensitivityTransform(x);
+    SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
+    AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
+    return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
+  }
+};
+
+static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
+  auto root_g_set = manager->roots();
+  if (root_g_set.size() != 1) {
+    return nullptr;
+  }
+  const FuncGraphPtr &root_g = root_g_set.back();
+
+  for (auto &param_node : root_g->parameters()) {
+    auto param = param_node->cast<ParameterPtr>();
+    if (param && name == param->name()) {
+      return param;
+    }
+  }
+  return nullptr;
+}
+
+class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
+ public:
+  RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
+  ~RefToEmbedEvaluator() override = default;
+  MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
+  EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
+    if (args_conf_list.size() != 1) {
+      MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
+      return nullptr;
+    }
+    static TypePtr type = std::make_shared<SymbolicKeyType>();
+    auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
+    if (node_conf == nullptr) {
+      MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
+      return nullptr;
+    }
+    AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
+    AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
+    if (ref_abs == nullptr) {
+      MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
+      return nullptr;
+    }
+    auto key_abs = ref_abs->ref_key();
+    if (key_abs == nullptr) {
+      MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
+      return nullptr;
+    }
+    auto key_value = key_abs->BuildValue();
+    if (key_value == nullptr) {
+      MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
+      return nullptr;
+    }
+    auto refkey = key_value->cast<RefKeyPtr>();
+    if (refkey == nullptr) {
+      auto ret = std::make_shared<AbstractScalar>(type);
+      auto ref_value = ref_abs->ref();
+      MS_EXCEPTION_IF_NULL(ref_value);
+      return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
+    }
+
+    std::string name = refkey->tag();
+    const auto &manager = node_conf->node()->func_graph()->manager();
+    auto node = FindParameterNodeByString(manager, name);
+    if (node == nullptr) {
+      MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
+      return nullptr;
+    }
+    AbstractBasePtr x = ref_abs->ref();
+    x = SensitivityTransform(x);
+    std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
+    std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
+    return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
+  }
+};
+
+class GetAttrEvaluator : public TransitionPrimEvaluator {
+ public:
+  GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
+  ~GetAttrEvaluator() override = default;
+  MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
+  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
+                         const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
+    auto ret_abstract = AbstractEval(args_spec_list);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+    // Inputs: data, item
+    if (args_spec_list.size() != 2) {
+      MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
+    }
+    EvalResultPtr ret = nullptr;
+    if (bound_node() != nullptr) {
+      TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
+      ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
+      TraceManager::EndTrace();
+    } else {
+      ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
+    }
+    // Don't look the result up in the cache: different out_confs with the same node but different
+    // contexts may add different entries to anfnode_config_map, as the getattr primitive does;
+    (*cache_)[args_spec_list] = ret;
+    return ret;
+  }
+};
+
+class ResolveEvaluator : public TransitionPrimEvaluator {
+ public:
+  ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
+  ~ResolveEvaluator() override = default;
+  MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
+  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
+                         const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
+    // Inputs: namespace, symbol
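+    // Resolution is delegated to StaticGetter, which rewrites this node into the node
+    // denoted by the symbol and forwards the config to it.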
but has size:" << args_spec_list.size(); + } + EvalResultPtr ret = nullptr; + if (bound_node() != nullptr) { + TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + TraceManager::EndTrace(); + } else { + ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); + } + return ret; + } +}; + +class CreateInstanceEvaluator : public TransitionPrimEvaluator { + public: + CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} + ~CreateInstanceEvaluator() override = default; + MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); + EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, + const AnfNodeConfigPtr &out_conf) override { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; + } + + // get the type parameter + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + TypePtr type = args_spec_list[0]->GetTypeTrack(); + if (type->type_id() != kMetaTypeTypeType) { + MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " + << type->ToString(); + } + + ValuePtr value_track = args_spec_list[0]->GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + + std::shared_ptr type_obj = dyn_cast(value_track); + if (type_obj == nullptr) { + MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; + } + + if (!type_obj->isa()) { + MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " + << type_obj->ToString() << "."; + } + + auto class_type = type_obj->obj(); + MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; + + // get the create instance obj's parameters + pybind11::tuple params = GetParameters(args_spec_list); + + // create class instance + auto obj = parse::data_converter::CreatePythonObject(class_type, params); + if (py::isinstance(obj)) { + MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; + } + + // process the object + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(obj, &converted_ret, true); + if (!converted) { + MS_LOG(EXCEPTION) << "Convert the python object failed"; + } + MS_EXCEPTION_IF_NULL(converted_ret); + + if (converted_ret->isa()) { + AddToManager(engine, converted_ret->cast()); + } + + AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); + auto infer_result = std::make_shared(ret, nullptr); + (*cache_)[args_spec_list] = infer_result; + return infer_result; + } + + pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { + // Exclude class type by minus 1; + std::size_t params_size = args_spec_list.size() - 1; + auto params = py::tuple(params_size); + if (params_size > 0) { + for (size_t i = 0; i < params_size; i++) { + // Only support the Scalar parameters type. Bypass class type by offset with 1. + auto arg = args_spec_list[i + 1]; + MS_EXCEPTION_IF_NULL(arg); + // Because the Tensor's AbstractTensor can't get value from GetValueTrack. 
+ ValuePtr param_value = arg->BuildValue(); + py::object param = ValuePtrToPyData(param_value); + params[i] = param; + } + } + return params; + } +}; + +class PartialEvaluator : public Evaluator { + public: + PartialEvaluator() : Evaluator("PartialEvaluator") {} + ~PartialEvaluator() override = default; + EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf = nullptr) override { + if (args_conf_list.size() == 0) { + MS_LOG(EXCEPTION) << "Args size should be greater than 0"; + } + + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); + AbstractBasePtrList args_spec_list{arg0_value}; + // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. + if (arg0_value->isa()) { + auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); + MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() + << " as func is: " << arg0_value->ToString(); + auto eval_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = eval_result; + return eval_result; + } + auto func = CheckArg("partial", args_spec_list, 0); + // Sometimes, node[0] in out_conf becomes phi0; + if (func->isa()) { + auto prim_func = dyn_cast(func); + if (prim_func->prim()->isa()) { + prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); + return HandleDoSignature(engine, do_signature_prim->function(), out_conf); + } + } + + (void)std::transform( + args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); + AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); + + auto cnode = out_conf->node()->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != (args_conf_list.size() + 1)) { + MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() + << ", args_conf_list: " << mindspore::ToString(args_conf_list); + } + + AbstractFuncAtomPtrList partial_funcs_list; + auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { + auto new_func = std::make_shared(atom_func, args, cnode); + partial_funcs_list.push_back(new_func); + }; + func->Visit(build_partial); + + auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); + auto infer_result = std::make_shared(ret, std::make_shared()); + (*cache_)[args_spec_list] = infer_result; + return infer_result; + } + + EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; + } + + EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, + const AnfNodeConfigPtr &out_conf = nullptr) const { + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + auto cnode = out_conf->node()->cast(); + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "Cnode is nullptr"; + } + std::vector new_nodes_inputs = cnode->inputs(); + auto new_signature_value = std::make_shared("signature", signature_value); + new_nodes_inputs[1] = NewValueNode(new_signature_value); + FuncGraphPtr func_graph = cnode->func_graph(); + + ScopePtr scope = out_conf->node()->scope(); + ScopeGuard scope_guard(scope); + + CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); + AnfNodeConfigPtr fn_conf = 
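+    // Re-analyze the rewritten call through a forwarded config instead of evaluating it in place.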
+    AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context());
+    return engine->ForwardConfig(out_conf, fn_conf);
+  }
+};
+
+struct PrimitiveImplInferValue {
+  PrimitiveImpl impl_;        // implementation function of the primitive
+  bool eval_value_;           // whether to evaluate the value
+  TypePtr specify_out_type_;  // the return type to enforce, if not nullptr
+  bool in_white_list_;        // true if this Primitive is in the white list, else false.
+};
+
+using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
+PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
+  static PrimitiveToImplMap uniform_prim_implement_map = {
+    {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
+    {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
+    {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
+    {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
+    {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
+    {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
+    {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
+    {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
+    {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
+    {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
+    {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
+    {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
+  };
+  return uniform_prim_implement_map;
+}
+
+PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
+std::mutex PrimEvaluatorConstructorMutex;
+
+void InitPrimEvaluatorConstructors() {
+  PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
+
+  for (const auto &iter : GetPrimitiveToEvalImplMap()) {
+    constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
+  }
+
+  for (const auto &iter : GetUniformPrimitiveToImplMap()) {
+    constructor[iter.first] =
+      InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
+  }
+  constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
+  constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
+  constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
+  constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
+  constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
+  constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
+}
+}  // namespace
+
+void ClearPrimEvaluatorMap() {
+  PrimEvaluatorConstructors.clear();
+  GetPrimitiveToEvalImplMap().clear();
+  GetUniformPrimitiveToImplMap().clear();
+}
+
+bool IsInWhiteList(const PrimitivePtr primitive) {
+  MS_EXCEPTION_IF_NULL(primitive);
+
+  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
+  if (iter != GetPrimitiveToEvalImplMap().end()) {
+    return iter->second.in_white_list_;
+  }
+
+  auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
+  if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
+    return uni_iter->second.in_white_list_;
+  }
+
+  return false;
+}
+
+StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
+  if (iter == GetPrimitiveToEvalImplMap().end()) {
+    return nullptr;
+  }
+  return iter->second.impl_;
+}
+
+PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
+  PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
+  if (!constructor.empty()) {
+    return constructor;
+  }
+  std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
+  if (constructor.empty()) {
+    InitPrimEvaluatorConstructors();
+  }
+
+  return constructor;
+}
+
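+// Structural subtype checks: a generic model (one carrying no element types) matches any
+// instance of its kind; otherwise the check recurses element-wise.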
+namespace {
+bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  auto x_tuple = dyn_cast<AbstractTuple>(x);
+  auto model_tuple = dyn_cast<Tuple>(model);
+
+  if (x_tuple == nullptr || model_tuple == nullptr) {
+    return false;
+  }
+
+  if (model->IsGeneric()) {
+    return true;
+  }
+
+  if (x_tuple->size() != model_tuple->size()) {
+    return false;
+  }
+
+  for (size_t i = 0; i < x_tuple->size(); i++) {
+    bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
+    if (!is_subtype) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  auto x_tensor = dyn_cast<AbstractTensor>(x);
+  auto model_tensor = dyn_cast<TensorType>(model);
+
+  if (x_tensor == nullptr || model_tensor == nullptr) {
+    return false;
+  }
+
+  if (model->IsGeneric()) {
+    return true;
+  }
+
+  return IsSubtype(x_tensor->element(), model_tensor->element());
+}
+
+bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  auto x_list = dyn_cast<AbstractList>(x);
+  auto model_list = dyn_cast<List>(model);
+
+  if (x_list == nullptr || model_list == nullptr) {
+    return false;
+  }
+
+  if (model->IsGeneric()) {
+    return true;
+  }
+
+  if (x_list->size() != model_list->size()) {
+    return false;
+  }
+
+  bool is_subtype = true;
+  for (size_t i = 0; i < x_list->size(); i++) {
+    is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
+    if (!is_subtype) {
+      return false;
+    }
+  }
+  return is_subtype;
+}
+
+bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  auto x_class = dyn_cast<AbstractClass>(x);
+  auto model_class = dyn_cast<Class>(model);
+  if (x_class == nullptr) {
+    return false;
+  }
+  if (model->IsGeneric()) {
+    return true;
+  }
+
+  if (x_class->tag() == model_class->tag()) {
+    auto m_attributes = model_class->GetAttributes();
+    auto x_attributes = x_class->attributes();
+    if (m_attributes.size() != x_attributes.size()) {
+      return false;
+    }
+
+    for (size_t i = 0; i < m_attributes.size(); i++) {
+      if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  return false;
+}
+
+inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  if (dyn_cast<AbstractScalar>(x) == nullptr) {
+    return false;
+  }
+  TypePtr x_type = x->GetTypeTrack();
+  return IsSubType(x_type, model);
+}
+}  // namespace
+
+bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(model);
+  TypeId model_typeid = model->type_id();
+  switch (model_typeid) {
+    case kMetaTypeObject:
+      return true;
+    case kObjectTypeTuple:
+      return IsSubtypeTuple(x, model);
+    case kObjectTypeTensorType:
+      return IsSubtypeArray(x, model);
+    case kObjectTypeList:
+      return IsSubtypeList(x, model);
+    case kObjectTypeClass:
+      return IsSubtypeClass(x, model);
+    default:
+      if (IsSubType(model, std::make_shared<Number>())) {
+        return IsSubtypeScalar(x, model);
+      }
+      MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
+  }
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
new file mode 100644
index 0000000000..692fbe66e8
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
@@ -0,0 +1,366 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_
+#define PIPELINE_STATIC_ANALYSIS_PRIM_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "pipeline/jit/static_analysis/evaluator.h"
+
+namespace mindspore {
+namespace abstract {
+using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                                      const AbstractBasePtrList &);
+struct StandartPrimitiveImplReg {
+  StandardPrimitiveEvalImpl impl_;  // Implement function of Primitive.
+  bool in_white_list_;              // true if this Primitive is in the white list, else false.
+};
+
+using PrimitiveEvalImplMap =
+  std::unordered_map<PrimitivePtr, StandartPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
+
+class StandardPrimEvaluator : public TrivialPrimEvaluator {
+ public:
+  StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
+      : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
+  ~StandardPrimEvaluator() override = default;
+  MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
+  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
+  PrimitivePtr prim() { return prim_; }
+
+  std::string ToString() const override { return identifier_ + prim_->name(); }
+
+ private:
+  PrimitivePtr prim_;
+  const StandardPrimitiveEvalImpl eval_impl_;
+};
+
+using StandardPrimEvaluatorPtr = std::shared_ptr<StandardPrimEvaluator>;
+
+class PythonPrimEvaluator : public TrivialPrimEvaluator {
+ public:
+  explicit PythonPrimEvaluator(const PrimitivePyPtr primitive)
+      : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
+  ~PythonPrimEvaluator() override = default;
+  MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
+  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
+  PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
+
+  std::string ToString() const override { return identifier_ + prim_py_->name(); }
+
+ private:
+  PrimitivePyPtr prim_py_;
+};
+
+class DoSignatureEvaluator : public Evaluator {
+ public:
+  explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
+  ~DoSignatureEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
+class UnpackGraphEvaluator : public Evaluator {
+ public:
+  explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
+  ~UnpackGraphEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
+class MixedPrecisionCastEvaluator : public Evaluator {
+ public:
+  explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive)
+      : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
+  ~MixedPrecisionCastEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
+bool IsInWhiteList(PrimitivePtr primitive);
+StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
+
+using ValuePtrList = std::vector<ValuePtr>;
+using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &);
+
+class UniformPrimEvaluator : public TrivialPrimEvaluator {
+ public:
+  UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type)
+      : TrivialPrimEvaluator("UniformPrimEvaluator"),
+        impl_(impl),
+        eval_value_(eval_value),
+        func_desc_(func_desc),
+        nargs_(func_desc_->args().size()),
+        return_value_type_(func_desc_->retval()),
+        specify_out_type_(specify_out_type) {
+    for (size_t i = 0; i < nargs_; ++i) {
+      TypePtr type = func_desc_->args()[i];
+      if (type_map_[type]) {
+        type_map_[type]->push_back(i);
+      } else {
+        type_map_[type] = std::make_shared<std::vector<size_t>>();
+        type_map_[type]->push_back(i);
+      }
+    }
+  }
+  ~UniformPrimEvaluator() override = default;
+  MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
+
+  EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
+  ValuePtr RunImpl(const ValuePtrList &args) const;
+
+  // If eval_value_ is false, return broadened arguments.
+  AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override {
+    if (!eval_value_) {
+      AbstractBasePtrList broadened_args_spec_list;
+      (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list),
+                           [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); });
+      return broadened_args_spec_list;
+    }
+    return args_spec_list;
+  }
+
+ private:
+  PrimitiveImpl impl_;
+  bool eval_value_;
+  const FunctionPtr func_desc_;
+  const std::size_t nargs_;
+  const TypePtr return_value_type_;
+  const TypePtr specify_out_type_;
+  std::unordered_map<TypePtr, std::shared_ptr<std::vector<size_t>>, TypeHasher, TypeEqual> type_map_;
+};
+
+PrimEvaluatorMap &GetPrimEvaluatorConstructors();
+
+// Check whether type x is a subtype of model.
+bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
+
+void ClearPrimEvaluatorMap();
+
+py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
+
+AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
+                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
+                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                            const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                              const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                   const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const AbstractBasePtrList &args_spec_list);
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
+}  // namespace abstract
+}  // namespace mindspore
+
+#endif  // PIPELINE_STATIC_ANALYSIS_PRIM_H_
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
new file mode 100644
index 0000000000..ad39190dc3
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
@@ -0,0 +1,728 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/static_analysis/program_specialize.h"
+
+#include <algorithm>
+#include <unordered_set>
+#include "./common.h"
+#include "frontend/operator/ops.h"
+#include "frontend/operator/composite/do_signature.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+#include "utils/graph_utils.h"
+#include "utils/log_adapter.h"
+#include "utils/profile.h"
+#include "debug/trace.h"
+
+namespace mindspore {
+namespace abstract {
+namespace {
+inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
+  if (conf->node()->intermediate_abstract()) {
+    return conf->node()->intermediate_abstract();
+  }
+  return conf->GetEvaluatedValue()->abstract();
+}
+
+AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
+  AnfNodePtr value_node = NewValueNode(v);
+  value_node->set_abstract(abs_base);
+  MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString();
+  return value_node;
+}
+
+bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) {
+  while (fg != nullptr && fg != parent) {
+    fg = fg->parent();
+  }
+  return fg == parent;
+}
+}  // namespace
+
+FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
+  MS_EXCEPTION_IF_NULL(fg);
+  MS_EXCEPTION_IF_NULL(context);
+  MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString();
+  return SpecializeFuncGraph(fg, context);
+}
+
+FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) {
+  MS_EXCEPTION_IF_NULL(fg);
+  MS_EXCEPTION_IF_NULL(context);
+  auto iter = specializations_.find(context->SpecializeKey());
+  if (iter != specializations_.end()) {
+    return iter->second->specialized_func_graph();
+  }
+
+  std::shared_ptr<FuncGraphSpecializer> fg_spec = std::make_shared<FuncGraphSpecializer>(this, fg, context);
+  FuncGraphPtr fg2 = fg_spec->specialized_func_graph();
+  specializations_[context->SpecializeKey()] = fg_spec;
+  fg_spec->Run();
+  return fg2;
+}
+
+std::shared_ptr<FuncGraphSpecializer> ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  auto iter = specializations_.find(context->SpecializeKey());
+  if (iter != specializations_.end()) {
+    return iter->second;
+  }
+  return nullptr;
+}
+
+std::string GetNextCounter() {
+  static int g_CloneCounter = 1;
+  std::string str_count = std::to_string(g_CloneCounter);
+  g_CloneCounter++;
+  return str_count;
+}
+
+FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg,
+                                           const AnalysisContextPtr &context)
+    : specializer_(s), func_graph_(fg), context_(context) {
+  parent_ = s->GetFuncGraphSpecializer(context->parent());
+  engine_ = s->engine();
+  cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
+  repl_node_ = cloner_->cloned_node();
+  specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
+  todo_.push_back(fg->get_return());
+  auto ps = fg->parameters();
+  (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
+}
+
+AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  FuncGraphPtr fg = node->func_graph();
+
+  if (node->isa<ValueNode>()) {
+    return node;
+  }
+  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
+  while (fg != nullptr && fg != specializer->func_graph_) {
+    specializer = specializer->parent_;
+  }
+  // If the node has already been replicated, just return it.
+  auto iter = specializer->repl_node_->find(node);
+  if (iter != specializer->repl_node_->end()) {
+    return iter->second;
+  }
+
+  auto new_node = specializer->cloner_->CloneDisconnected(node);
+  if (node->isa<CNode>()) {
+    if (!new_node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << ".";
+    }
+    auto c_node = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(c_node);
+    auto inputs = c_node->inputs();
+    std::vector<AnfNodePtr> new_inputs;
+    (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs),
+                         [this](const AnfNodePtr &inp) -> AnfNodePtr {
+                           if (inp->isa<ValueNode>()) {
+                             return inp;
+                           }
+                           return ReplicateDisconnectedNode(inp);
+                         });
+    auto c_new_node = new_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(c_new_node);
+    c_new_node->set_inputs(new_inputs);
+  }
+
+  iter = specializer->repl_node_->find(node);
+  if (iter != specializer->repl_node_->end()) {
+    if (iter->second == node) {
+      MS_LOG(EXCEPTION) << "The replicated node is the same as the original node, node: " << node->ToString();
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString();
+  }
+  return new_node;
+}
+
+AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  FuncGraphPtr fg = node->func_graph();
+
+  std::shared_ptr<FuncGraphSpecializer> specializer = shared_from_this();
+  while (fg != nullptr && fg != specializer->func_graph_) {
+    specializer = specializer->parent_;
+  }
+
+  MS_EXCEPTION_IF_NULL(specializer->repl_node_);
+  auto iter = specializer->repl_node_->find(node);
+  if (iter != specializer->repl_node_->end()) {
+    return iter->second;
+  }
+  return node;
+}
+
+void FuncGraphSpecializer::Run() {
+  MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString()
+                << ", cloned func graph name: " << specialized_func_graph_->ToString()
+                << ", func graph: " << func_graph_->get_return()->DebugString();
+  FirstPass();
+  SecondPass();
+  MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString()
+                << ", cloned func graph name: " << specialized_func_graph_->ToString()
+                << ", new func graph: " << specialized_func_graph_->get_return()->DebugString();
+}
+
+void FuncGraphSpecializer::FirstPass() {
+  while (todo_.size()) {
+    AnfNodePtr node = todo_.back();
+    todo_.pop_back();
+    if (node->func_graph() == nullptr) {
+      // Do nothing for ValueNode.
+      continue;
+    }
+    if (node->func_graph() != func_graph_) {
+      if (parent_ == nullptr) {
+        MS_LOG(EXCEPTION) << "Parent must not be null. NodeInfo: " << trace::GetDebugInfo(node->debug_info());
+      }
+      parent_->AddTodoItem(node);
+      parent_->FirstPass();
+      AnfNodePtr new_node = parent_->GetReplicatedNode(node);
+      if (node->isa<CNode>()) {
+        parent_->ProcessCNode(new_node->cast<CNodePtr>());
+      }
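+      // The node belongs to an enclosing graph, so its replica is owned by the parent specializer.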
continue; + } + if (marked_.count(node) > 0) { + continue; + } + (void)marked_.insert(node); + ProcessNode(node); + } +} + +// Specialize CNode in func graphs +void FuncGraphSpecializer::SecondPass() { + for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { + if (node->isa()) { + ProcessCNode(node->cast()); + } + } +} + +void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + ScopeGuard scope_guard(node->scope()); + AnfNodeConfigPtr conf = MakeConfig(node); + AnfNodePtr new_node = GetReplicatedNode(node); + MS_EXCEPTION_IF_NULL(new_node); + if (new_node->func_graph() != specialized_func_graph_) { + MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() + << ", new_node: " << new_node->DebugString() + << ", new_node->func_graph(): " << new_node->func_graph()->ToString() + << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); + return; + } + new_node->set_abstract(GetEvaluatedValueWrap(conf)); + if (new_node->isa() && new_node->abstract()->isa()) { + auto partial_abstract = dyn_cast(new_node->abstract()); + if (partial_abstract->node() == node) { + partial_abstract->set_node(new_node); + } + } + + MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); + + if (node->isa()) { + auto attrs = conf->GetEvaluatedValue()->attribute(); + auto c_old = node->cast(); + auto c_new = new_node->cast(); + auto new_inputs = c_new->inputs(); + auto old_inputs = c_old->inputs(); + for (size_t i = 0; i < old_inputs.size(); ++i) { + auto node_input = old_inputs[i]; + AnfNodeConfigPtr iconf = MakeConfig(node_input); + AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); + // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if + // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. 
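+      // For example, an input whose abstract evaluates to a constant becomes a ValueNode, while an
+      // input that static analysis forwarded (recorded in anfnode_config_map) is rebuilt from its
+      // forwarded config by BuildReplacedNode below.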
+ AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); + if (replace_node == nullptr) { + replace_node = BuildReplacedNode(iconf); + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_abstract(ival); + MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); + } else { + MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() + << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); + } + if (new_inputs[i] != replace_node) { + new_inputs[i] = replace_node; + MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); + } + } + c_new->set_inputs(new_inputs); + } +} + +AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + + auto conf_iter = engine_->anfnode_config_map().find(conf); + AnfNodeConfigPtr new_conf = conf; + while (conf_iter != engine_->anfnode_config_map().end()) { + MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node(" + << new_conf->node()->DebugString() << ")"; + new_conf = conf_iter->second; + MS_EXCEPTION_IF_NULL(new_conf); + MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node(" + << conf->node()->DebugString() << ")"; + (void)ReplicateDisconnectedNode(new_conf->node()); + conf_iter = engine_->anfnode_config_map().find(new_conf); + } + todo_.push_back(new_conf->node()); + auto repl = GetReplicatedNode(new_conf->node()); + if (repl->func_graph()) { + MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString() + << ") to replace origin:" << new_conf->node()->DebugString(); + } else { + MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString() + << ") to replace origin: " << new_conf->node()->DebugString(); + } + return repl; +} + +namespace { +const StringImmPtr kDeadNode = std::make_shared("Dead Node"); +const StringImmPtr kPolyNode = std::make_shared("Poly Node"); + +inline bool CanSpecializeNode(const AnfNodePtr &node) { + if (IsValueNode(node) || IsValueNode(node) || IsValueNode(node)) { + return true; + } + return false; +} +} // namespace + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, + const AbstractBasePtrList &argvals) { + MS_EXCEPTION_IF_NULL(abs); + AbstractFunctionPtr real_a = dyn_cast(abs); + MS_EXCEPTION_IF_NULL(real_a); + + AbstractFunctionPtr func = real_a->GetUnique(); + SpecializeStatusCode errcode; + ScopeGuard scope_guard(node->scope()); + AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode); + if (repl == nullptr) { + if (errcode == kSpecializeFindUniqueArgvalDead) { + const auto error_dead_node = std::make_shared(kDeadNode, node); + repl = BuildValueNode(kDeadNode, error_dead_node); + MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString(); + } else if (errcode == kSpecializeFindUniqueArgvalPoly) { + const auto error_poly_node = std::make_shared(kPolyNode, node); + repl = BuildValueNode(kPolyNode, error_poly_node); + MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString(); + } else { + MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString() + << ", abstract: " << abs->ToString(); + } + } + + return repl; +} + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const 
AbstractBasePtr &abs, + const AbstractFunctionPtr &func, + const AbstractBasePtrList &args, + SpecializeStatusCode *errcode) { + MS_EXCEPTION_IF_NULL(abs); + MS_EXCEPTION_IF_NULL(func); + MS_EXCEPTION_IF_NULL(errcode); + *errcode = kSpecializeSuccess; + + auto real_func = dyn_cast(func); + if (real_func != nullptr) { + return BuildValueNode(real_func->prim(), abs); + } + + EvaluatorPtr eval; + eval = engine_->GetEvaluatorFor(func); + MS_EXCEPTION_IF_NULL(eval); + AbstractBasePtrList argvals = eval->NormalizeArgs(args); + + std::pair result; + SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result); + if (status != kSpecializeSuccess) { + *errcode = status; + return nullptr; + } + argvals = result.first; + AbstractBasePtr unique_output = result.second; + + auto prim_func = dyn_cast(func); + if (prim_func != nullptr) { + auto type_func = std::make_shared(prim_func->prim(), argvals, unique_output); + return BuildValueNode(prim_func->prim(), type_func); + } + + if (!eval->isa()) { + MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString(); + } + auto real_eval = dyn_cast(eval); + + if (func->context() == nullptr) { + MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); + } + AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); + MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() + << ", graph: " << context->func_graph()->get_return()->DebugString(); + if (context->func_graph()->stub()) { + MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString() + << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString() + << ", " << node->ToString(); + return node; + } + FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); + v->set_flag(kFuncGraphFlagUndetermined, false); + return BuildValueNode(v, abs); +} + +AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { + auto new_inputs = new_node->inputs(); + AnfNodePtr func = new_inputs[0]; + AbstractBasePtr fnval = new_inputs[0]->abstract(); + + AbstractBasePtrList args; + auto backed_fnval = fnval; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + backed_fnval = partial_closure->fn(); + args = partial_closure->args(); + } + std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), + [](const AnfNodePtr &inp) { return inp->abstract(); }); + + ScopeGuard scope_guard(new_node->scope()); + + auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); + auto wrapped_node = specialized_node; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), + specialized_node}; + auto anf_node = partial_closure->node(); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); + } + auto cnode = anf_node->cast(); + if (cnode->size() != partial_closure->args().size() + 2) { + MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() + << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); + } + auto attrs = std::make_shared(); + for (size_t i = 0; i < partial_closure->args().size(); i++) { + auto old_node = cnode->input(i + 2); + auto possibile_value_node = 
BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); + if (possibile_value_node != nullptr) { + partial_node_list.push_back(possibile_value_node); + } else { + if (!(old_node->isa() || old_node->isa())) { + MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); + } + partial_node_list.push_back(old_node); + } + } + wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); + wrapped_node->set_abstract(partial_closure); + } + return wrapped_node; +} + +const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { + auto cache_iter = evalcaches_.find(eval); + if (cache_iter == evalcaches_.end()) { + evalcaches_[eval] = eval->cache(); + return eval->cache(); + } + return cache_iter->second; +} + +std::pair FuncGraphSpecializer::BuildFromBroadedArgsVal( + const EvaluatorPtr &eval) { + MS_EXCEPTION_IF_NULL(eval); + std::unordered_set choices; + EvalResultPtr ret = nullptr; + AbstractBasePtrList broaded_argvals; + for (auto &argvals_map : *evalcaches_[eval]) { + auto argvals = argvals_map.first; + broaded_argvals.clear(); + + (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals), + [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); + (void)choices.insert(broaded_argvals); + MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals); + } + + if (1 == choices.size()) { + ConfigPtrList args_conf_list; + (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), + [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared(v); }); + + // if broaden return null + ret = eval->Run(engine_, args_conf_list, nullptr); + EvaluatorCacheMapPtr real = std::make_shared(); + + (*real)[broaded_argvals] = ret; + evalcaches_[eval] = real; + return std::make_pair(broaded_argvals, ret->abstract()); + } else { + MS_LOG(DEBUG) << "Choices.size: " << choices.size(); + return std::make_pair(AbstractBasePtrList(), nullptr); + } +} + +void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(new_node); + if (specializer_->seen().count(new_node) > 0) { + return; + } + specializer_->AddSeen(new_node); + auto new_inputs = new_node->inputs(); + if (new_inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; + } + AnfNodePtr func = new_inputs[0]; + MS_EXCEPTION_IF_NULL(func); + + // First element is func so arg start from 1 + std::vector args(new_inputs.begin() + 1, new_inputs.end()); + // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) 
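+  // Flatten nested Partial applications iteratively: each round prepends the partially applied
+  // arguments to args and replaces func with the wrapped callee, so chained partials collapse
+  // into a single direct call.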
+  while (IsPrimitiveCNode(func, prim::kPrimPartial)) {
+    std::vector<AnfNodePtr> inputs = func->cast<CNodePtr>()->inputs();
+    // First element is partial, second is func, so args start from index 2.
+    (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
+    func = inputs[1];
+  }
+  new_inputs = args;
+  (void)new_inputs.insert(new_inputs.begin(), func);
+
+  AbstractBasePtrList argvals;
+  MS_EXCEPTION_IF_NULL(new_inputs[0]);
+  AbstractBasePtr fnval = new_inputs[0]->abstract();
+  MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", "
+                << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString();
+
+  // First element is func, so function arguments start from index 1.
+  for (size_t i = 1; i < new_inputs.size(); ++i) {
+    argvals.push_back(new_inputs[i]->abstract());
+    MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", "
+                  << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
+  }
+
+  if (!func->isa<ValueNode>()) {
+    MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
+    if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
+      auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
+      EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
+      std::pair<AbstractBasePtrList, AbstractBasePtr> result;
+      AbstractBasePtrList empty_args;
+      auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
+      MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
+      // If the node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early.
+      if (status == kSpecializeFindUniqueArgvalPoly ||
+          (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
+                                      func->abstract()->isa<PartialAbstractClosure>()))) {
+        auto wrapped_node = BuildSpecializedParameterNode(new_node);
+        new_inputs[0] = wrapped_node;
+      }
+    }
+  }
+
+  if (CanSpecializeNode(func)) {
+    // For a primitive node, the specialized node with inferred attributes was already built in the
+    // first pass, so we do not build a replaced node again here in the second pass.
+    if (IsValueNode<Primitive>(func)) {
+      new_inputs[0] = func;
+    } else {
+      new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
+    }
+  }
+
+  for (size_t i = 0; i < argvals.size(); ++i) {
+    if (CanSpecializeNode(args[i])) {
+      new_inputs[i + 1] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
+    }
+  }
+  new_node->set_inputs(new_inputs);
+}
+
+namespace {
+void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) {
+  MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". Check cache all items.";
+  int i = 0;
+  for (const auto &item : evaluator_cache_map) {
+    MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
+  }
+}
+
+bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
+  if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
+    MS_LOG(DEBUG) << "High order primitive return POLY.";
+    return true;
+  }
+  if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
+    auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
+    auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
+    if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
+      auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
+      if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
+        MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
+        return true;
+      }
+    }
+  }
+  return false;
+}
+}  // end anonymous namespace
+
+SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
+                                                             const AbstractBasePtrList &argvals,
+                                                             std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
+  MS_EXCEPTION_IF_NULL(func);
+  MS_EXCEPTION_IF_NULL(eval);
+  MS_EXCEPTION_IF_NULL(result);
+
+  EvaluatorCacheMap evaluator_cache_map = *eval->cache();
+  if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
+    *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
+    return kSpecializeSuccess;
+  }
+  DumpEvaluatorCache(evaluator_cache_map, argvals);
+
+  const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
+  MS_EXCEPTION_IF_NULL(choices);
+
+  if (choices->count(argvals)) {
+    *result = std::make_pair(argvals, (*choices)[argvals]->abstract());
+    return kSpecializeSuccess;
+  } else if (choices->size() == 1) {
+    MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
+    *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
+    return kSpecializeSuccess;
+  } else if (choices->empty()) {
+    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
+                  << func->type_name();
+    return kSpecializeFindUniqueArgvalDead;
+  } else {
+    if (IsPolyFunc(func, argvals)) {
+      return kSpecializeFindUniqueArgvalPoly;
+    }
+
+    MS_LOG(DEBUG) << "Try to find generalized argvals.";
+    *result = BuildFromBroadedArgsVal(eval);
+    if (!result->first.empty()) {
+      return kSpecializeSuccess;
+    }
+    MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
+    return kSpecializeFindUniqueArgvalPoly;
+  }
+}
+
+static PrimitivePtr BuildPrimitiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
+  auto &prim_attrs = prim->attrs();
+  bool is_attr_same = true;
+  for (auto &item : *attrs) {
+    auto itr = prim_attrs.find(item.first);
+    if (itr != prim_attrs.end()) {
+      if (!(*(itr->second) == *(item.second))) {
+        is_attr_same = false;
+        break;
+      }
+    } else {
+      is_attr_same = false;
+      break;
+    }
+  }
+  if (!is_attr_same) {
+    if (prim->isa<PrimitivePy>()) {
+      PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
+      auto clone_fn = prim_py->GetPyObj().attr("_clone");
+      py::object new_obj = clone_fn();
+      auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
+      for (auto &item : *attrs) {
+        cloned_prim->AddAttr(item.first, item.second);
+      }
+      return cloned_prim;
+    }
+    auto cloned_prim = std::make_shared<Primitive>(*prim);
+    for (auto &item : *attrs) {
+      cloned_prim->AddAttr(item.first, item.second);
+    }
+    return cloned_prim;
+  }
+  return prim;
+}
+
+AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
+                                                        const AttrValueMapPtr &attrs) {
+  MS_EXCEPTION_IF_NULL(origin_node);
+  MS_EXCEPTION_IF_NULL(ival);
+
+  AbstractFunctionPtr abs = dyn_cast<AbstractFunction>(ival);
+  if (abs != nullptr) {
+    // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction.
+    if (abs->isa<AbstractFuncUnion>()) {
+      return nullptr;
+    }
+    ValuePtr value = nullptr;
+    if (abs->isa<PrimitiveAbstractClosure>()) {
+      auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
+      // For a primitive, check whether the attributes match the inferred attributes on the CNode;
+      // if they do not, clone a new primitive carrying the inferred attributes.
+      if (attrs != nullptr) {
+        value = BuildPrimitiveValueWithAttributes(real_fn->prim(), attrs);
+      } else {
+        value = real_fn->prim();
+      }
+    } else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
+      auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
+      value = real_fn->meta_func_graph();
+    } else if (abs->isa<FuncGraphAbstractClosure>()) {
+      auto real_fn = dyn_cast<FuncGraphAbstractClosure>(abs);
+      value = real_fn->func_graph();
+    } else {
+      return nullptr;
+    }
+    if (!value->isa<FuncGraph>() || value->cast<FuncGraphPtr>()->parent() == nullptr ||
+        (IsValueNode<FuncGraph>(origin_node) && IsVisible(func_graph_, value->cast<FuncGraphPtr>()->parent()))) {
+      return BuildValueNode(value, ival);
+    } else {
+      return nullptr;
+    }
+  } else {
+    ValuePtr val = ival->BuildValue();
+    if (val->isa<AnyValue>()) {
+      return nullptr;
+    }
+    // Keep the primitive 'depend' from being optimized away.
+    if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) {
+      return nullptr;
+    }
+    return BuildValueNode(val, ival);
+  }
+}
+
+AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) {
+  return engine_->MakeConfig(node, context_);
+}
+}  // namespace abstract
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h
new file mode 100644
index 0000000000..d7f95be4ca
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h
@@ -0,0 +1,136 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_
+#define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "ir/anf.h"
+#include "ir/func_graph_cloner.h"
+#include "pipeline/jit/static_analysis/evaluator.h"
+
+namespace mindspore {
+namespace abstract {
+enum SpecializeStatusCode {
+  kSpecializeSuccess = 0,
+  kSpecializeFindUniqueArgvalDead = 1,  // Dead Node
+  kSpecializeFindUniqueArgvalPoly = 2,  // Poly Node
+  kSpecializeFailure = 0xFF
+};
+
+class FuncGraphSpecializer;
+
+// Specialize a func graph using analyzed abstract values.
+class ProgramSpecializer {
+ public:
+  explicit ProgramSpecializer(const std::shared_ptr<AnalysisEngine> &engine) : engine_(engine) {
+    mng_ = engine_->func_graph_manager();
+  }
+  ~ProgramSpecializer() = default;
+  // Run the program specializer on the topmost graph in the given context.
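+  // Illustrative usage (a sketch, not part of this file): after static analysis has produced a
+  // context, e.g.
+  //   AnalysisResult result = engine->Run(fg, args_spec_list);
+  //   ProgramSpecializer specializer(engine);
+  //   FuncGraphPtr specialized = specializer.Run(fg, result.context);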
+  FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
+  const std::unordered_set<AnfNodePtr> &seen() const { return seen_; }
+  void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); }
+
+  std::shared_ptr<FuncGraphSpecializer> GetFuncGraphSpecializer(const AnalysisContextPtr &context);
+  // Specialize one FuncGraph in a given context.
+  FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context);
+
+  std::shared_ptr<AnalysisEngine> engine() { return engine_; }
+
+ private:
+  std::shared_ptr<AnalysisEngine> engine_;
+  std::unordered_set<AnfNodePtr> seen_;
+  FuncGraphManagerPtr mng_;
+  std::unordered_map<AnalysisContextPtr, std::shared_ptr<FuncGraphSpecializer>, ContextHasher, ContextEqual>
+    specializations_;
+};
+
+class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecializer> {
+ public:
+  FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context);
+  virtual ~FuncGraphSpecializer() {
+    specializer_ = nullptr;
+    repl_node_ = nullptr;
+  }
+  void Run();
+  FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; }
+
+ private:
+  ProgramSpecializer *specializer_;
+  FuncGraphPtr func_graph_;
+  FuncGraphPtr specialized_func_graph_;
+  AnalysisContextPtr context_;
+  std::shared_ptr<FuncGraphSpecializer> parent_;
+  std::shared_ptr<AnalysisEngine> engine_;
+  ClonerPtr cloner_;
+  // ProcessNode -> [cloner_->CloneDisconnected] may clone an AnfNode again,
+  // so repl_node_ must point at the cloner's repl_node_ map rather than hold a copy of it.
+  std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_node_;
+  std::vector<AnfNodePtr> todo_;
+  std::unordered_set<AnfNodePtr> marked_;
+  std::unordered_map<EvaluatorPtr, EvaluatorCacheMapPtr> evalcaches_;
+
+  void FirstPass();
+  void SecondPass();
+  void ProcessNode(const AnfNodePtr &node);
+  void ProcessCNode(const CNodePtr &new_node);
+
+  AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node);
+  inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); }
+  // Get the node replicated by the Cloner.
+  AnfNodePtr GetReplicatedNode(const AnfNodePtr &node);
+  // Replicate a node that is not used directly by the func graph, so it is not reachable from its return node
+  // (disconnected).
+  AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
+
+  // Build a value node from a parameter if the function graph has the special flag hinting that it can be done.
+  AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
+
+  // Build a value node if ival is constant and not an any-value.
+  AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
+                                    const AttrValueMapPtr &attrs);
+  // Build a replaceable node for iconf->node; it may be a replicated forwarded CNode from static analysis or just a
+  // replicated node.
+  AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
+  // Build a specialized node from the given argvals.
+  AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                  const AbstractBasePtrList &argvals);
+  AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs,
+                                       const AbstractFunctionPtr &func, const AbstractBasePtrList &args,
+                                       SpecializeStatusCode *errcode);
+
+  // Find the unique argument values which can be used to specialize a primitive or graph function.
+  SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval,
+                                         const AbstractBasePtrList &argvals,
+                                         std::pair<AbstractBasePtrList, AbstractBasePtr> *result);
+  // Get the cache; it may be the evaluator's own cache or one built from broadened argument values.
+  const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval);
+  // Try to build unique argvals from the broadened arg vals if the result is unique.
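+  // Returns an empty AbstractBasePtrList and a null abstract when broadening still leaves more
+  // than one candidate in the evaluator cache.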
+ std::pair BuildFromBroadedArgsVal(const EvaluatorPtr &eval); +}; +} // namespace abstract +} // namespace mindspore +#endif // PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc new file mode 100644 index 0000000000..acecb2980e --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -0,0 +1,655 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/static_analysis/static_analysis.h" + +#include +#include + +#include "abstract/utils.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "frontend/operator/ops.h" +#include "utils/symbolic.h" +#include "ir/tensor.h" +#include "ir/func_graph_cloner.h" +#include "./common.h" +#include "pipeline/jit/parse/data_converter.h" +#include "debug/draw.h" +#include "pipeline/jit/static_analysis/evaluator.h" +#include "debug/trace.h" + +namespace mindspore { +namespace abstract { +bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { + if (dyn_cast(arg_spec)) { + auto v = arg_spec->GetValueTrack(); + if (v->isa()) { + return true; + } else { + return false; + } + } else { + return false; + } +} + +AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { + if (dyn_cast(arg1) && dyn_cast(arg2)) { + return arg1->Join(arg2); + } + return nullptr; +} + +void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { + MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() + << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() + << ", Pointer: " << result->abstract().get(); + cache_[conf] = result; + + // Set intermediate abstract value. + if (IsIntermediateAbstract(result->abstract())) { + if (conf->node()->intermediate_abstract() == nullptr) { + conf->node()->set_intermediate_abstract(result->abstract()); + MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); + } else { + auto old_spec = conf->node()->intermediate_abstract(); + auto joined_spec = IntermediateJoin(result->abstract(), old_spec); + conf->node()->set_intermediate_abstract(joined_spec); + MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" + << result->abstract()->ToString() << "\njoined_spec:\t" + << (joined_spec != nullptr ? 
joined_spec->ToString() : "nullptr"); + } + } +} + +EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { + auto value = cache_.find(conf); + if (value == cache_.end()) { + return nullptr; + } + return value->second; +} + +std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(conf->node()); + std::size_t hash_value = conf->node()->hash(); + if (!conf->context()->IsDummyContext()) { + hash_value = hash_combine(hash_value, std::hash{}(conf->context().get())); + } + if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) { + MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() + << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value; + } else { + MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value; + } + return hash_value; +} + +bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + if (lhs == rhs) { + return true; + } + return (*lhs == *rhs); +} + +AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { + ConfigPtrList args_conf_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + MS_EXCEPTION_IF_NULL(func_graph_manager_); + func_graph_manager_->AddFuncGraph(func_graph); + + AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); + + // Running the analyzer. + AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); + MS_EXCEPTION_IF_NULL(root_context); + MS_EXCEPTION_IF_NULL(root_context->func_graph()); + AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; + + AnalysisResult result; + MS_EXCEPTION_IF_NULL(output_conf); + result.inferred = output_conf->GetEvaluatedValue(); + result.context = root_context; + return result; +} + +AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const ConfigPtrList &args_conf_list) { + std::shared_ptr eval = std::make_shared(func_graph, context); + (void)eval->Run(shared_from_this(), args_conf_list, nullptr); + return eval->graph_context(); +} + +EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + auto value = cache_.GetValue(conf); + if (value != nullptr) { + MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() + << ", " << value->abstract()->ToString(); + return value; + } + + MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); + value = Eval(conf); + if (value == nullptr) { + MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; + } + cache_.set_value(conf, value); + return value; +} + +EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + AnfNodePtr node = conf->node(); + EvalResultPtr eval_result = nullptr; +#ifdef DEBUG + compute_conf_stack_.push_back(node); + std::ostringstream buffer; + buffer << "Compute Config Begin:"; + for (auto iter : 
compute_conf_stack_) { + buffer << " -> " << iter->DebugString(); + } + MS_LOG(DEBUG) << buffer.str(); +#endif + MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString(); + MS_EXCEPTION_IF_NULL(node); + if (node->abstract() != nullptr) { + MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); + eval_result = std::make_shared(node->abstract(), std::make_shared()); + } else if (node->isa()) { + auto value_node = node->cast(); + eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); + } else if (node->isa()) { + auto cnode = node->cast(); + trace::TraceEvalCNodeEnter(conf); + eval_result = EvalCNode(cnode, conf); + trace::TraceEvalCNodeLeave(); + } else { + MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() + << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } + +#ifdef DEBUG + compute_conf_stack_.pop_back(); + if (eval_result == nullptr) { + MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() + << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); + } +#endif + MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); + return eval_result; +} + +AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(value_node); + return ToAbstract(value_node->value(), conf->context(), conf); +} + +EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { + MS_EXCEPTION_IF_NULL(conf); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString(); + } + + AnfNodePtr func_node = inputs[0]; + MS_EXCEPTION_IF_NULL(func_node); + MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString(); + AnalysisContextPtr context = conf->context(); + AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); + MS_EXCEPTION_IF_NULL(func_conf); + // Keep it in a local variable, otherwise smart pointer will free it. 
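+  // GetEvaluatedValue() returns a temporary EvalResult; binding its abstract to a named local
+  // keeps the value alive while maybe_func is inspected below.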
+ AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); + if (maybe_func == nullptr) { + MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() + << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); + } + if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { + MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; + return std::make_shared(maybe_func->Clone(), std::make_shared()); + } + AbstractFunctionPtr func = dyn_cast(maybe_func); + if (func == nullptr) { + MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() + << ", func_conf: " << func_conf->ToString() + << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); + } + + ConfigPtrList args_conf_list; + // ignore the first node which is function name + for (std::size_t i = 1; i < inputs.size(); i++) { + const AnfNodePtr &node = inputs[i]; + args_conf_list.push_back(MakeConfig(node, context)); + } + std::vector infs; + + auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) { + auto evaluator = this->GetEvaluatorFor(poss); + evaluator->set_bound_node(cnode); + infs.push_back(evaluator); + }; + func->Visit(build_evaluator); + + return ExecuteEvaluators(infs, conf, args_conf_list); +} + +EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { + ConfigPtrList args_conf_list; + (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), + [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); + std::vector infs; + MS_EXCEPTION_IF_NULL(func); + auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) { + auto evaluator = this->GetEvaluatorFor(poss); + infs.push_back(evaluator); + }; + func->Visit(build_evaluator); + return ExecuteEvaluators(infs, nullptr, args_conf_list); +} + +void AnalysisEngine::ClearEvaluatorCache() { + for (std::pair element : constructors_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } + for (auto &element : prim_constructors_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } + for (auto &element : prim_py_evaluators_) { + EvaluatorPtr evaluator = element.second; + MS_EXCEPTION_IF_NULL(evaluator); + MS_EXCEPTION_IF_NULL(evaluator->cache()); + evaluator->cache()->clear(); + } +} + +void AnalysisEngine::Clear() { + cache_.Clear(); + anfnode_config_map_.clear(); + eval_trace_.clear(); + constructors_.clear(); +} + +namespace { +EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { + // Custom Primitive with python infer_shape, infer_type + EvaluatorPtr evaluator = nullptr; + MS_EXCEPTION_IF_NULL(prim); + if (prim->isa()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->isa()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) { + evaluator = std::make_shared(prim); + return evaluator; + } + if (prim->HasPyEvaluator()) { + auto prim_py = dyn_cast(prim); + if (prim_py != nullptr) { + if (engine == nullptr) { + return std::make_shared(prim_py); + } + + const auto &iter = 
engine->prim_py_evaluators_.find(prim_py); + if (iter != engine->prim_py_evaluators_.end()) { + return iter->second; + } + evaluator = std::make_shared(prim_py); + engine->prim_py_evaluators_[prim_py] = evaluator; + return evaluator; + } + MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; + } + + if (prim->isa() || prim->HasAttr()) { + if (engine == nullptr) { + (void)GetPrimEvaluatorConstructors(); + } + // If a primitive may have attr, try to create a new evaluator. + StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); + if (eval_impl != nullptr) { + return std::make_shared(prim, eval_impl); + } + } + + if (engine == nullptr) { + // If engine is nullptr, get constructor from default. + const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; + } + } else { + // If engine is given, get constructor from engine resource. + const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); + auto iter = prim_evaluator_map.find(prim); + if (iter != prim_evaluator_map.end()) { + evaluator = iter->second; + } + } + if (evaluator == nullptr) { + MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ")."; + } + return evaluator; +} +} // namespace + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + auto primitive = func->prim(); + auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); + constructors_[func] = evaluator; + return evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr func_graph_evaluator = + std::make_shared(func->func_graph(), func->context()); + constructors_[func] = func_graph_evaluator; + return func_graph_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr evaluator = + std::make_shared(func->meta_func_graph(), func->context(), func->GetScope()); + constructors_[func] = evaluator; + return evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + AbstractFunctionPtr func_orig = func->fn(); + EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + auto jevaluator = std::make_shared(evaluator_orig, func_orig); + return jevaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + std::shared_ptr virtual_evaluator = + std::make_shared(func->args_spec_list(), func->output()); + return virtual_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { + MS_EXCEPTION_IF_NULL(func); + AbstractFunctionPtr func_orig = func->fn(); + EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); + std::shared_ptr partial_evaluator = + std::make_shared(evaluator_orig, func->args()); + return partial_evaluator; +} + +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &) { + MS_LOG(EXCEPTION) << 
"Should not be called "; +} + +// Forward to specific subclass of FunctionWrapper. +EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { + MS_EXCEPTION_IF_NULL(func); + EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); + return evaluator; +} + +EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { + MS_LOG(DEBUG) << "The func value: " << func->ToString(); + if (func->tracking_id() != nullptr) { + MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); + } + MS_EXCEPTION_IF_NULL(func); + if (func->tracking_id() == nullptr) { + EvaluatorPtr evaluator = _GetEvaluatorFor(func); + return evaluator; + } + auto inf_pair = constructors_.find(func); + if (inf_pair != constructors_.end()) { + return inf_pair->second; + } + + AbstractFunctionPtr func_generic = func->Copy(); + func_generic->set_tracking_id(nullptr); + EvaluatorPtr eval = _GetEvaluatorFor(func_generic); + auto tracked_eval = std::make_shared(eval); + constructors_[func] = tracked_eval; + + return tracked_eval; +} + +EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, + const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { + if (evaluators.size() == 1) { + EvaluatorPtr eval = evaluators[0]; + MS_EXCEPTION_IF_NULL(eval); + return eval->Run(shared_from_this(), args_conf_list, out_conf); + } + return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); +} + +void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { + auto fg_eval = evaluator->cast(); + if (fg_eval == nullptr) { + return; + } + auto fg = fg_eval->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto undetermined_fgs = fg->recursive_graphs(); + if (undetermined_fgs) { + auto fg_parent = fg->parent(); + MS_EXCEPTION_IF_NULL(fg_parent); + fg_parent->set_flag(kFuncGraphFlagUndetermined, true); + MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); + } +} + +EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector &evaluators, + const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, + const EvalTraceRevIter &it, bool *continue_flag) { + *continue_flag = false; + // Find latest entry function to handle nested recursion. + EvaluatorPtr latest_entry = eval; + auto latest_entry_iter = eval_trace_.rbegin(); + for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { + auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); + if (it_temp != evaluators.end()) { + latest_entry = *it_temp; + latest_entry_iter = r_it; + break; + } + latest_entry_iter = ++r_it; + } + if (latest_entry != eval) { + MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); + *continue_flag = true; + return latest_entry; + } + + bool has_undetermined = false; + // Check whether sub loop has untraced undetermined evaluator. 
+  std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals;
+  for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) {
+    undetermined_evals.insert(*r_it);
+  }
+  MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size();
+
+  for (auto u_eval : undetermined_evals) {
+    MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined.";
+    if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) {
+      MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined.";
+      has_undetermined = true;
+      break;
+    }
+  }
+  if (!has_undetermined) {
+    MS_LOG(DEBUG) << eval->ToString() << " has no undetermined.";
+    *continue_flag = true;
+    return latest_entry;
+  }
+
+  return latest_entry;
+}
+
+EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) {
+  if (out_specs.empty()) {
+    MS_LOG(EXCEPTION) << "There is an endless loop for evaluator.";
+  }
+
+  if (out_specs.size() == 1) {
+    MS_EXCEPTION_IF_NULL(out_specs[0]);
+    // If only one result is derived, broaden it to avoid wrong constant propagation.
+    return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
+  }
+  auto joined_spec = AbstractJoin(out_specs);
+  MS_EXCEPTION_IF_NULL(joined_spec);
+  MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
+  return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
+}
+
+EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
+                                                        const AnfNodeConfigPtr &out_conf,
+                                                        const ConfigPtrList &args_conf_list) {
+  AbstractBasePtrList out_specs;
+  if (!multi_poss_.count(evaluators[0])) {
+    multi_poss_[evaluators[0]] = evaluators[1];
+    multi_poss_[evaluators[1]] = evaluators[0];
+  }
+  AbstractBasePtrList args_spec_list;
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &conf) -> AbstractBasePtr {
+                         MS_EXCEPTION_IF_NULL(conf);
+                         return conf->GetEvaluatedValue()->abstract();
+                       });
+  for (auto eval : evaluators) {
+    SetUndeterminedFlag(eval);
+
+    auto current_inf = std::make_pair(eval, args_spec_list);
+    MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
+
+    // If the current evaluator is already being traced, skip it to avoid recursive evaluation.
+    auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
+    if (it == eval_trace_.rend()) {
+      eval_trace_.push_back(current_inf);
+      MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
+      MS_EXCEPTION_IF_NULL(eval);
+      auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
+      MS_EXCEPTION_IF_NULL(eval_result->abstract());
+      MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
+      out_specs.push_back(eval_result->abstract());
+      eval_trace_.pop_back();
+      if (eval_trace_.empty()) {
+        multi_poss_.clear();
+      }
+    } else if (it != eval_trace_.rbegin()) {
+      bool continue_flag = false;
+      auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag);
+      if (continue_flag) {
+        continue;
+      }
+
+      // Try to re-run from the latest undetermined entry.
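+      // Running the latest entry directly (instead of the current evaluator) resolves the nested
+      // recursion once, at its entry point, and its result is returned for the whole group.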
+ if (latest_entry != eval_trace_.rbegin()->first) { + MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); + auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); + MS_EXCEPTION_IF_NULL(eval_result->abstract()); + MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() + << " return out_spec: " << eval_result->abstract()->ToString(); + return eval_result; + } + } + } + + return ProcessEvalResults(out_specs); +} + +EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { + AnfNodeConfigPtr self = shared_from_base(); + return engine_.lock()->GetEvaluatedValue(self); +} + +AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { + if (value->isa()) { + auto func_graph = value->cast(); + return func_graph->MakeAbstractClosure(context); + } + AnfNodePtr anf_node = nullptr; + if (conf != nullptr) { + anf_node = conf->node(); + } + if (value->isa()) { + auto meta_func_graph = value->cast(); + return meta_func_graph->MakeAbstractClosure(anf_node); + } + if (value->isa()) { + auto prim = value->cast(); + return prim->ToPrimAbstract(anf_node); + } + return value->ToAbstract(); +} + +AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { + AbstractBasePtr a = ToAbstract(value, nullptr, nullptr); + if (broaden) { + a = a->Broaden(); + } + return a; +} + +EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { + auto evaluator = GetPrimEvaluator(primitive, nullptr); + MS_EXCEPTION_IF_NULL(evaluator); + if (!evaluator->isa()) { + MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but " + << evaluator->ToString(); + } + auto trivial_evaluator = dyn_cast(evaluator); + auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); + return eval_result; +} +} // namespace abstract +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h new file mode 100644 index 0000000000..181696f756 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -0,0 +1,280 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ +#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#ifdef DEBUG +#include +#endif + +#include "utils/log_adapter.h" +#include "ir/anf.h" +#include "ir/primitive_py.h" +#include "abstract/analysis_context.h" +#include "pipeline/jit/static_analysis/abstract_function.h" +#include "pipeline/jit/parse/parse.h" + +namespace mindspore { +namespace abstract { +// define attribute value map +using AttrValueMap = std::unordered_map; +using AttrValueMapPtr = std::shared_ptr; + +// the class to save evaluated result: abstract value and modified attribute +class EvalResult : public Base { + public: + EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} + ~EvalResult() override = default; + MS_DECLARE_PARENT(EvalResult, Base); + AbstractBasePtr abstract() { return abstract_; } + AttrValueMapPtr attribute() { return attribute_; } + + private: + AbstractBasePtr abstract_; + AttrValueMapPtr attribute_; +}; + +using EvalResultPtr = std::shared_ptr; +// Superclass for AnfNodeConfig and VirtualConfig. +class Config : public Base { + public: + Config() = default; + ~Config() override = default; + MS_DECLARE_PARENT(Config, Base); + virtual EvalResultPtr GetEvaluatedValue() = 0; +}; + +// Config will be stored in AnalysisCache +using ConfigPtr = std::shared_ptr; +using ConfigPtrList = std::vector; + +// Config to a certain node in a certain context. +class AnfNodeConfig : public Config { + public: + AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) + : Config(), engine_(std::weak_ptr(engine)), node_(node) { + FuncGraphPtr fg; + if (IsValueNode(node)) { + auto v = node->cast(); + fg = v->value()->cast(); + } else { + fg = node->func_graph(); + } + context_ = nullptr; + if (context != nullptr) { + context_ = context->Filter(fg); + } + } + + ~AnfNodeConfig() override = default; + MS_DECLARE_PARENT(AnfNodeConfig, Config); + + EvalResultPtr GetEvaluatedValue() override; + + AnalysisContextPtr context() const { return context_; } + + AnfNodePtr node() const { return node_; } + + AnalysisEnginePtr engine() const { return engine_.lock(); } + + // used by unordered_map; + bool operator==(const AnfNodeConfig &other) const { + // compare node with pointer, context with pointer except DummyContext as it's created by make_shared; + // context should not be nullptr; + if (context_->IsDummyContext() && other.context_->IsDummyContext()) { + return true; + } + return (node_ == other.node_) && (context_ == other.context_); + } + + std::string ToString() const override { + std::ostringstream buffer; + buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); + return buffer.str(); + } + + private: + // AnalysisEngine is global. + // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use + // weak_ptr to break Config cycle. 
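+  // engine_ must be promoted with engine_.lock() before use; see GetEvaluatedValue() in the
+  // corresponding .cc file.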
+ std::weak_ptr engine_; + AnfNodePtr node_; + AnalysisContextPtr context_; +}; + +using AnfNodeConfigPtr = std::shared_ptr; + +struct AnfNodeConfigHasher { + std::size_t operator()(const AnfNodeConfigPtr conf) const; +}; + +struct AnfNodeConfigEqual { + bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const; +}; + +class VirtualConfig : public Config { + public: + explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} + + ~VirtualConfig() override = default; + MS_DECLARE_PARENT(VirtualConfig, Config); + EvalResultPtr GetEvaluatedValue() override { + return std::make_shared(abstract_, std::make_shared()); + } + + private: + AbstractBasePtr abstract_; +}; + +// AnalysisCache +class AnalysisCache { + public: + AnalysisCache() = default; + ~AnalysisCache() = default; + void Clear() { cache_.clear(); } + void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); + EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); + + private: + std::unordered_map cache_; +}; + +using PrimEvaluatorMap = std::unordered_map; +using AnfNodeConfigMap = + std::unordered_map; + +struct AnalysisResult { + EvalResultPtr inferred; + AnalysisContextPtr context; +}; + +using EvalTraceRevIter = std::list>::reverse_iterator; + +class AnalysisEngine : public std::enable_shared_from_this { + public: + AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) + : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} + ~AnalysisEngine() = default; + + // func_graph: The func_graph to analyze. + // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. + AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); + EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); + // Return the Evaluator for the given function. + EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); + + AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); + EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); + // Infer the result of fn(args). + EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); + void Clear(); + void ClearEvaluatorCache(); + AnalysisCache &cache() { return cache_; } + AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { + return std::make_shared(shared_from_this(), node, context); + } + // Overloaded function. + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &); + EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); + + FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } + const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } + + // Set the analysis result for orig to the result for new. + // This sets an entry in anfnode_config_map from orig to new. + EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { + // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. 
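+    // emplace constructs the mapping in place and leaves any existing entry for orig_conf untouched.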
+ (void)anfnode_config_map_.emplace(orig_conf, new_conf); + MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() + << ", to new_conf: " << new_conf->node()->DebugString(); + return GetEvaluatedValue(new_conf); + } + const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } + + AnalysisCache cache_; + std::unordered_map prim_py_evaluators_; + + private: + void SetUndeterminedFlag(const EvaluatorPtr &evaluator); + EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, + const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, + bool *continue_flag); + EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); + + const PrimEvaluatorMap &prim_constructors_; + FuncGraphManagerPtr func_graph_manager_; + std::unordered_map constructors_; + AnfNodeConfigMap anfnode_config_map_; + // Use a list to trace multiple evaluators. + std::list> eval_trace_; + std::map multi_poss_; + + AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, + const ConfigPtrList &args_conf_list); + EvalResultPtr Eval(const AnfNodeConfigPtr &conf); + EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); + EvalResultPtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); + EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, + const ConfigPtrList &args_conf_list); + +#ifdef DEBUG + std::vector compute_conf_stack_; +#endif +}; + +// Translate the value to an abstract value. +// Arguments: +// value: The value to convert. +// context: The context in which the value was found, used if the value is a Graph. +// conf: The Config to the valuenode we are converting, if there is one, +// so that we can generate a tracking_id. +AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, + const AnfNodeConfigPtr &conf = nullptr); + +// Convert a value to an abstract value. +// Arguments: +// v: The value to convert. +// broaden: If True, concrete values will be made more abstract, so e.g. +// the value 1234 would become ANYTHING. +AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); + +template +AbstractBasePtr FromValue(const T &value, bool broaden = false) { + return FromValueInside(MakeValue(value), broaden); +} + +EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); +} // namespace abstract +} // namespace mindspore + +#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc new file mode 100644 index 0000000000..04aa6efd05 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -0,0 +1,120 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pipeline/jit/validator.h" + +#include +#include + +#include "ir/manager.h" +#include "ir/dtype.h" +#include "./common.h" +#include "pipeline/jit/static_analysis/prim.h" + +namespace mindspore { +namespace validator { +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractClass; +using mindspore::abstract::AbstractError; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractIndexedSlices; +using mindspore::abstract::AbstractJTagged; +using mindspore::abstract::AbstractList; +using mindspore::abstract::AbstractScalar; +using mindspore::abstract::AbstractTensor; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractType; + +void ValidateOperation(const AnfNodePtr &node) { + if (!IsValueNode(node)) { + return; + } + + // Primitive must in whitelist + PrimitivePtr prim = GetValueNode(node); + if (abstract::IsInWhiteList(prim)) { + return; + } + if (prim->HasPyEvaluator()) { + MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; + return; + } + if (prim->name() == "fake_bprop") { + MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue(prim->GetAttr("info")); + } + + MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); +} + +void ValidateAbstract(const AnfNodePtr &node) { + if (node == nullptr) { + MS_LOG(DEBUG) << "Node to validate is invalid"; + return; + } + AbstractBasePtr ptrBase = node->abstract(); + if (ptrBase == nullptr) { + MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); + return; + } + if (ptrBase->isa() || ptrBase->isa()) { + // Validate a type. + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + } + if (ptrBase->isa()) { + TypePtr ptrType = ptrBase->GetTypeTrack(); + MS_EXCEPTION_IF_NULL(ptrType); + if (ptrType->isa() || ptrType->isa()) { + // only send string in external + if (!IsValueNode(node)) { + // Validate a type. + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); + } + } + return; + } + if (ptrBase->isa()) { + // NOTICE: validate dead code? + MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); + return; + } + + if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa()) { + return; + } + + if (ptrBase->isa()) { + return; + } + + // Other types show exception + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); +} + +void Validate(const FuncGraphPtr &fg) { + FuncGraphManagerPtr mgr = Manage(fg, false); + MS_EXCEPTION_IF_NULL(mgr); + AnfNodeSet &all_nodes = mgr->all_nodes(); + for (const auto &anf_node : all_nodes) { + ValidateOperation(anf_node); + ValidateAbstract(anf_node); + } +} +} // namespace validator +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/validator.h b/mindspore/ccsrc/pipeline/jit/validator.h new file mode 100644 index 0000000000..041448aed9 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/validator.h @@ -0,0 +1,38 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_
+#define MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_
+
+#include <memory>
+#include <string>
+#include <mutex>
+#include <unordered_map>
+#include "frontend/operator/ops.h"
+#include "ir/anf.h"
+#include "utils/misc.h"
+
+namespace mindspore {
+namespace validator {
+void Validate(const FuncGraphPtr &func_graph);
+void ValidateAbstract(const AnfNodePtr &node);
+void ValidateOperation(const AnfNodePtr &node);
+}  // namespace validator
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_
diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc
deleted file mode 100644
index 330d03d11c..0000000000
--- a/mindspore/ccsrc/pipeline/parse/data_converter.cc
+++ /dev/null
@@ -1,559 +0,0 @@
-/**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "pipeline/parse/data_converter.h" -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/parse/resolve.h" -#include "pipeline/parse/python_adapter.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "ir/func_graph_cloner.h" -#include "utils/symbolic.h" -#include "utils/context/ms_context.h" -#include "debug/trace.h" -#include "optimizer/ad/grad.h" - -namespace mindspore { -namespace parse { -using Tensor = mindspore::tensor::Tensor; -using TensorPtr = mindspore::tensor::TensorPtr; -using MetaTensor = mindspore::tensor::MetaTensor; -using MetaTensorPtr = mindspore::tensor::MetaTensorPtr; - -FuncGraphPtr ConvertToBpropCut(const py::object &obj) { - std::vector results = data_converter::GetObjKey(obj); - std::string obj_key = results[0]; - py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME); - - auto bprop_graph = std::make_shared(); - std::vector outputs; - - auto fake_bprop = std::make_shared("bprop_cut", py::object()); - fake_bprop->set_hook(bprop_func); - (void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true)); - outputs.push_back(NewValueNode(fake_bprop)); - - py::object code_obj = py::getattr(bprop_func, "__code__"); - size_t inputs_num = py::cast(py::getattr(code_obj, "co_argcount")) - 3; - for (size_t i = 0; i < inputs_num; ++i) { - auto param = bprop_graph->add_parameter(); - outputs.push_back(param); - } - auto p1 = bprop_graph->add_parameter(); - auto p2 = bprop_graph->add_parameter(); - outputs.push_back(p1); - outputs.push_back(p2); - - bprop_graph->set_output(bprop_graph->NewCNode(outputs)); - data_converter::SetObjGraphValue(obj_key, bprop_graph); - return bprop_graph; -} - -namespace { -bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python tuple"; - py::tuple tuple = obj.cast(); - std::vector value_list; - for (size_t it = 0; it < tuple.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(tuple[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - - return true; -} - -bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python list"; - - py::list list = obj.cast(); - std::vector value_list; - for (size_t it = 0; it < list.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(list[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - return true; -} - -bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { - MS_LOG(DEBUG) << "Converting cell list"; - py::sequence list = obj; - std::vector value_list; - for (size_t it = 0; it < list.size(); ++it) { - ValuePtr out = nullptr; - bool success = ConvertData(list[it], &out, use_signature); - if (!success) { - return false; - } - value_list.push_back(out); - } - *data = std::make_shared(value_list); - return true; -} - -bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { - MS_LOG(DEBUG) << "Converting python dict"; - - py::dict dict_values = obj.cast(); - std::vector> key_values; - for (auto item : dict_values) { - if (!py::isinstance(item.first)) { - MS_LOG(EXCEPTION) << "The key of dict is only support str."; - } - std::string key = py::str(item.first); - ValuePtr out = nullptr; - bool success = ConvertData(dict_values[item.first], &out, 
use_signature); - if (!success) { - return false; - } - key_values.emplace_back(key, out); - } - *data = std::make_shared(key_values); - return true; -} - -void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting python module"; - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); -} - -void ConvertDataClass(py::object obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting dataclass"; - // Maybe the obj is dataclass define - auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); -} - -bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { - MS_LOG(DEBUG) << "Converting primitive object"; - - // need check the primitive is class type or instance - auto obj_type = data_converter::GetObjType(obj); - if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { - auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - } else { - auto primitive = obj.cast(); - if (primitive == nullptr) { - MS_LOG(ERROR) << "Resolve Primitive error, get ptr is null"; - return false; - } - if (py::hasattr(obj, "__setattr_flag__")) { - if (py::hasattr(obj, "_clone")) { - auto clone_fn = obj.attr("_clone"); - py::object new_obj = clone_fn(); - primitive = new_obj.cast(); - } - } - if (use_signature) { - *data = std::make_shared(primitive->name(), primitive); - } else { - *data = primitive; - } - } - return true; -} - -bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { - MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; - auto meta = obj.cast(); - if (meta == nullptr) { - MS_LOG(ERROR) << "Resolve MetaFuncGraph error, get ptr is null"; - return false; - } - if (use_signature) { - *data = std::make_shared(meta->name(), meta); - } else { - *data = meta; - } - return true; -} - -bool ConvertDataType(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting type object"; - auto typeptr = obj.cast(); - if (typeptr == nullptr) { - MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null"; - return false; - } - *data = typeptr; - return true; -} - -bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting MetaTensor object."; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null."; - return false; - } - *data = m_tensor; - return true; -} - -bool ConvertTensor(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting tensor object"; - - auto m_tensor = obj.cast(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null"; - return false; - } - *data = m_tensor; - return true; -} - -bool ConvertSlice(const py::object &obj, ValuePtr *const data) { - MS_LOG(DEBUG) << "Converting slice object"; - - py::slice slice_obj = obj.cast(); - auto convert_func = [obj](std::string attr) -> ValuePtr { - auto py_attr = py::getattr(obj, attr.c_str()); - if (py::isinstance(py_attr)) { - return kNone; - } 
else if (py::isinstance(py_attr)) { - int value = py::cast(py_attr); - return MakeValue(value); - } else { - MS_LOG(EXCEPTION) << "Slice should contain only int or none"; - } - }; - ValuePtr start = convert_func("start"); - ValuePtr stop = convert_func("stop"); - ValuePtr step = convert_func("step"); - *data = std::make_shared(start, stop, step); - return true; -} - -bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { - FuncGraphPtr func_graph = ConvertToFuncGraph(obj); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return false; - } - // if the cell object has specified bprop, it has user-defined bprop function parse and record it - if (py::hasattr(obj, CUSTOM_BPROP_NAME)) { - FuncGraphPtr bprop_graph = nullptr; - bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); - if (enable_bprop_debug) { - bprop_graph = ConvertToBpropCut(obj); - } else { - bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); - } - if (bprop_graph != nullptr) { - (void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); - func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true); - } - } - *data = func_graph; - return true; -} - -bool ConvertOtherObj(py::object obj, ValuePtr *const data) { - auto obj_type = data_converter::GetObjType(obj); - MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; - if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { - MS_LOG(DEBUG) << "Resolve the class type, need create class instance."; - std::string desc = py::str(obj); - // desc has format "", strip the '<' and '>' by offset 1; - *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); - return true; - } - if (obj_type == RESOLVE_TYPE_FUNCTION || obj_type == RESOLVE_TYPE_METHOD) { - MS_LOG(DEBUG) << "Convert the obj to func graph, type is " << obj_type; - FuncGraphPtr func_graph = ConvertToFuncGraph(obj); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return false; - } - *data = func_graph; - return true; - } - if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { - // Create the namespace for common class instance - // When the obj is Cell, default parse the 'construct' - if (data_converter::IsCellInstance(obj)) { - return ConvertCellObjToFuncGraph(obj, data); - } - - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); - return true; - } - MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); - return false; -} -} // namespace - -bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { - // check parameter valid - if (data == nullptr) { - MS_LOG(ERROR) << "Data is null pointer"; - return false; - } - - bool ret = true; - ValuePtr converted = nullptr; - if (py::isinstance(obj)) { - converted = kNone; - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if (py::isinstance(obj)) { - converted = std::make_shared(py::cast(obj)); - } else if 
(py::isinstance(obj)) { - ret = ConvertDict(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertSlice(obj, &converted); - } else if (py::isinstance(obj)) { - converted = kEllipsis; - } else if (py::isinstance(obj)) { - ret = ConvertTuple(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { - ret = ConvertCellList(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ret = ConvertList(obj, &converted, use_signature); - } else if (py::isinstance(obj)) { - ConvertNameSpace(obj, &converted); - } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) { - ConvertDataClass(obj, &converted); - } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) { - ret = ConvertPrimitive(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_METAFUNCGRAPH_FLAG)) { - ret = ConvertMetaFuncGraph(obj, &converted, use_signature); - } else if (py::hasattr(obj, PYTHON_DTYPE_FLAG)) { - ret = ConvertDataType(obj, &converted); - } else if (py::hasattr(obj, PYTHON_TENSOR_FLAG)) { - ret = ConvertTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_META_TENSOR_FLAG)) { - ret = ConvertMetaTensor(obj, &converted); - } else if (py::hasattr(obj, PYTHON_ENVINSTANCE_FLAG)) { - std::shared_ptr env = obj.cast>(); - converted = env; - } else if (py::hasattr(obj, "__parameter__")) { - auto to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); - ret = ConvertData(to_convert, &converted); - } else { - ret = ConvertOtherObj(obj, &converted); - } - - *data = converted; - return ret; -} - -// convert data to graph -FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { - std::vector results = data_converter::GetObjKey(obj); - std::string obj_id = results[0] + python_mod_get_parse_method; - std::string obj_key = results[1]; - FuncGraphPtr func_graph = nullptr; - Any value = Any(); - bool is_cache = data_converter::GetObjectValue(obj_id, &value); - if (is_cache) { - if (value.is()) { - MS_LOG(DEBUG) << "Get the cache data, obj = " << obj_id; - func_graph = value.cast(); - return func_graph; - } - } - - func_graph = ParsePythonCode(obj, python_mod_get_parse_method); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return nullptr; - } - - data_converter::MakeProperNameToFuncGraph(func_graph, obj_id); - data_converter::CacheObjectValue(obj_id, func_graph); - if (obj_key != "") { - MS_LOG(DEBUG) << "Add graph:" << obj_key << ", func_graph:" << func_graph->ToString(); - data_converter::SetObjGraphValue(obj_key, func_graph); - } - - return func_graph; -} -namespace data_converter { -static std::unordered_map object_map_ = std::unordered_map(); - -static std::unordered_map> object_graphs_map_ = - std::unordered_map>(); - -void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { - object_graphs_map_[obj_key].push_back(data); - MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); -} - -const std::unordered_map> &GetObjGraphs() { - MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); - return object_graphs_map_; -} - -void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string &obj_key, Any *const data) { - if (object_map_.count(obj_key)) { - *data = object_map_[obj_key]; - return true; - } - return false; -} -std::vector GetObjKey(const py::object &obj) { - py::module mod = 
python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); - if (obj_tuple.size() != 2) { - MS_LOG(EXCEPTION) << "Get_obj_key must return 2 elements"; - } - return {py::cast(obj_tuple[0]), py::cast(obj_tuple[1])}; -} - -// get obj detail type -ResolveTypeDef GetObjType(const py::object &obj) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - auto obj_type = - ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); - return obj_type; -} - -// get class instance detail type -ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - auto class_type = - ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); - return class_type; -} - -// check the object is Cell Instance -bool IsCellInstance(const py::object &obj) { - auto class_type = GetClassInstanceType(obj); - bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); - return isCell; -} - -// create the python class instance -py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object obj; - if (params.size() == 0) { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type); - } else { - obj = python_adapter::CallPyModFn(mod, PYTHON_MOD_CREATE_OBJ_INSTANCE, type, params); - } - return obj; -} - -// Generate an appropriate name and set to graph debuginfo -// character <> can not used in the dot file, so change to another symbol -void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(func_graph->debug_info()); - // set detail name info of function - std::ostringstream oss; - for (size_t i = 0; i < name.size(); i++) { - if (name[i] == '<') { - oss << "「"; - } else if (name[i] == '>') { - oss << "」"; - } else { - oss << name[i]; - } - } - func_graph->debug_info()->set_full_name(oss.str()); -} - -ValuePtr PyDataToValue(const py::object &obj) { - py::object to_convert = obj; - if (py::hasattr(obj, "__parameter__")) { - to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); - } - ValuePtr value = nullptr; - (void)ConvertData(to_convert, &value); - return value; -} - -void ClearObjectCache() { - object_map_.clear(); - object_graphs_map_.clear(); -} -} // namespace data_converter - -static std::unordered_map g_dataClassToClass = {}; - -// parse dataclass to mindspore Class type -ClassPtr ParseDataClass(const py::object &cls_obj) { - std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); - std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); - std::string cls = cls_module + "." 
+ cls_name; - auto iterator = g_dataClassToClass.find(cls); - if (iterator != g_dataClassToClass.end()) { - return iterator->second; - } - - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - ClassAttrVector attributes; - py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); - for (auto &item : names) { - TypePtr type_value = item.second.cast(); - MS_EXCEPTION_IF_NULL(type_value); - MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; - attributes.push_back(std::make_pair(py::cast(item.first), type_value)); - } - - std::unordered_map methods_map; - py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); - for (auto &item : methods) { - std::string fun_name = item.first.cast(); - py::object obj = py::cast(item.second); - std::shared_ptr method_obj = std::make_shared(obj, fun_name); - methods_map[fun_name] = method_obj; - } - - std::shared_ptr me_class = std::make_shared(Named(cls_name), attributes, methods_map); - // static Variable for cache - // cppcheck-suppress unreadVariable - g_dataClassToClass[cls] = me_class; - - return me_class; -} - -void CleanDataClassToClassMap() { g_dataClassToClass.clear(); } -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h deleted file mode 100644 index 0165b55363..0000000000 --- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#ifndef PIPELINE_PARSE_DATA_CONVERTER_H_
-#define PIPELINE_PARSE_DATA_CONVERTER_H_
-
-#include <unordered_map>
-#include <utility>
-#include <vector>
-#include <string>
-#include <memory>
-#include "pipeline/parse/parse_base.h"
-#include "pipeline/parse/python_adapter.h"
-#include "utils/log_adapter.h"
-
-namespace mindspore {
-namespace parse {
-// data convert for parse
-namespace data_converter {
-void CacheObjectValue(const std::string &obj_key, const Any &data);
-bool GetObjectValue(const std::string &obj_key, Any *const data);
-
-void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data);
-
-const std::unordered_map<std::string, std::vector<FuncGraphPtr>> &GetObjGraphs();
-
-std::vector<std::string> GetObjKey(const py::object &obj);
-ResolveTypeDef GetObjType(const py::object &obj);
-ClassInstanceTypeDef GetClassInstanceType(const py::object &obj);
-
-bool IsCellInstance(const py::object &obj);
-py::object CreatePythonObject(const py::object &type, const py::tuple &params);
-void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name);
-ValuePtr PyDataToValue(const py::object &obj);
-void ClearObjectCache();
-}  // namespace data_converter
-
-ClassPtr ParseDataClass(const py::object &cls_obj);
-FuncGraphPtr ConvertToBpropCut(const py::object &obj);
-
-void CleanDataClassToClassMap();
-
-}  // namespace parse
-}  // namespace mindspore
-
-#endif  // PIPELINE_PARSE_DATA_CONVERTER_H_
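Among the declarations deleted above, MakeProperNameToFuncGraph exists because '<' and '>' cannot appear in dot files, so the name is rewritten with CJK corner brackets before it is stored as the graph's debug name. The same substitution loop from the deleted implementation, lifted into a standalone function for illustration (the MakeProperName wrapper and the main driver are additions for this sketch):

#include <iostream>
#include <sstream>
#include <string>

// Replace the '<' and '>' characters, which are illegal in dot files,
// with CJK corner brackets, exactly as the deleted helper does.
std::string MakeProperName(const std::string &name) {
  std::ostringstream oss;
  for (char c : name) {
    if (c == '<') {
      oss << "「";
    } else if (c == '>') {
      oss << "」";
    } else {
      oss << c;
    }
  }
  return oss.str();
}

int main() {
  // e.g. "<lambda>" becomes "「lambda」"
  std::cout << MakeProperName("<lambda>") << std::endl;
  return 0;
}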
diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc
deleted file mode 100644
index 701f7d0f6b..0000000000
--- a/mindspore/ccsrc/pipeline/parse/function_block.cc
+++ /dev/null
@@ -1,374 +0,0 @@
-/**
- * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
- *
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pipeline/parse/function_block.h"
-#include <string>
-#include <memory>
-#include <algorithm>
-#include "pipeline/parse/resolve.h"
-#include "pipeline/parse/parse.h"
-#include "operator/ops.h"
-#include "debug/info.h"
-#include "debug/trace.h"
-#include "pybind11/pybind11.h"
-
-namespace mindspore {
-namespace py = pybind11;
-
-namespace parse {
-FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
-  func_graph_ = std::make_shared<FuncGraph>();
-  matured_ = false;
-}
-
-void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); }
-
-// write variable records the variable name to the corresponding node
-void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
-  MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString();
-  vars_[var_name] = node;
-}
-
-// read variable from predecessors
-AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
-  // get var node if it is found
-  if (vars_.count(var)) {
-    AnfNodePtr node = vars_[var];
-    MS_EXCEPTION_IF_NULL(node);
-    if (node->isa<ValueNode>()) {
-      return NewValueNode(GetValueNode(node));
-    } else {
-      return node;
-    }
-  }
-  // get var from the predecessor block; if it cannot be found, make a resolve node for it
-  if (matured_) {
-    // If only one predecessor block, read the definition of var from it.
-    if (prev_blocks_.size() == 1) {
-      auto block = prev_blocks_[0];
-      MS_EXCEPTION_IF_NULL(block);
-      return block->ReadVariable(var);
-    } else if (prev_blocks_.empty()) {
-      // get namespace and make Resolve
-      return MakeResolveSymbol(var);
-    }
-  }
-  // If there is more than one predecessor block, build a phi node.
-  auto debug_info = std::make_shared<NodeDebugInfo>();
-  debug_info->set_name(var);
-  TraceManager::DebugTrace(std::make_shared<TracePhi>(debug_info));
-  ParameterPtr phi_param = std::make_shared<Parameter>(func_graph());
-  TraceManager::EndTrace();
-  MS_LOG(DEBUG) << func_graph_->ToString() << " generate phi node " << phi_param->ToString() << " for " << var;
-  func_graph()->add_parameter(phi_param);
-  phi_nodes_[phi_param] = var;
-  WriteVariable(var, phi_param);
-  if (matured_) {
-    SetPhiArgument(phi_param);
-  }
-  return phi_param;
-}
-
-// Resolve Ast operator node
-AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
-  auto ast = parser_.ast();
-  MS_EXCEPTION_IF_NULL(ast);
-  TraceGuard trace_guard(parser_.GetLocation(op));
-  py::tuple namespace_var = ast->CallParserObjMethod(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
-  if (namespace_var.size() != 2) {
-    MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
-  }
-  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, namespace_var[0]);
-  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
-  return MakeResolve(name_space, symbol);
-}
-
-// Resolve a class member; two possibilities: method or member variable
-AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) {
-  py::object namespace_var =
-    parser_.ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, parser_.ast()->obj());
-  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
-  SymbolPtr symbol = std::make_shared<Symbol>(attr);
-  return MakeResolve(name_space, symbol);
-}
-
-// Make a resolve node for symbol string
-AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
-  if (value.compare(0, strlen("self."), "self.") == 0) {
-    auto start = value.find_first_of('.') + 1;
-    if (start >= value.size()) {
-      MS_LOG(ERROR)
<< "Find invalid resolve symbol str: " << value; - return nullptr; - } - auto bits_str = value.substr(start); - return MakeResolveClassMember(bits_str); - } - py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value); - - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, namespace_var[0]); - SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); - return MakeResolve(name_space, symbol); -} - -AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { - py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); - NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); - SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); - return MakeResolve(name_space, symbol); -} - -AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { - MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " - << ((std::string)resolve_symbol->symbol()); - ValueNodePtr module_node = NewValueNode(name_space); - ValueNodePtr symbol_node = NewValueNode(resolve_symbol); - auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node}); - return node; -} - -// add input for the block's phi parameter -void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { - std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; - for (auto &pred : prev_blocks_) { - MS_EXCEPTION_IF_NULL(pred); - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); - AnfNodePtr arg_node = pred->ReadVariable(var); - CNodePtr jump = pred->jumps_[this]; - jump->add_input(arg_node); - } - // If the phi node in the body part of a for/while loop is being removed, - // then the closure convert phase will generate a cycle in graph if the - // loop is kept after specialization. This should be investigate further. - // Just now user has to set a flag on a function to indicate the for loop - // will definitely can be unroll as the sequence in for statement is fixed - // size in compile time. - if (parser_.func_graph()->has_flag(GRAPH_FLAG_LOOP_CAN_UNROLL) || - parser_.func_graph()->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - CollectRemovablePhi(phi); - } -} - -AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { - AnfNodePtr arg_node = nullptr; - for (auto &prev : prev_blocks_) { - MS_EXCEPTION_IF_NULL(prev); - AnfNodePtr temp_node = prev->ReadVariable(var); - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var - << " is " << temp_node->DebugString(); - if (temp_node != phi) { - if (arg_node == nullptr) { - arg_node = temp_node; - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() - << " may be replaced by node " << arg_node->DebugString(); - } else if (temp_node == arg_node) { - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " is same as node " - << arg_node->DebugString(); - } else { - MS_LOG(DEBUG) << "phi " << phi->ToString() - << " cannot be removed as it assigns to different node. 
node1: " << arg_node->DebugString() - << ", node2: " << temp_node->DebugString(); - return nullptr; - } - } - } - return arg_node; -} - -// Check if there is removable unnecessary phi node in this graph. -// as per the FIRM TR 3.2, a phi node can be remove if: -// -// If all arguments of a φ-function are the same value s or the φfunction itself, -// then we remove the φ-function and let all users directly uses. We call such a -// φ-function obviously unnecessary. -// When we removed a φ-function p, then we recursively try to apply this simplification -// rule with all (former) users of p, because they may have become obviously unnecessary -// due to the removal of p -// -// phi node in graph will be removed after the whole function is parsed in a DFS visit -// of that graph.The reason is : -// 1. when this function is called, not all usage of this phi node had bound to the -// graph of this function block, some may stay in vars_ in other blocks. -// 2. it's costly to iterate the graph to replace the phi for each phi. -// Args : -// phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { - MS_EXCEPTION_IF_NULL(phi); - std::string var = phi_nodes_[phi]; - MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); - if (prev_blocks_.size() == 0) { - MS_LOG(DEBUG) << "no phi " << phi->ToString() << " for var " << var << " in graph " << func_graph_->ToString(); - return; - } - AnfNodePtr arg_node = SearchReplaceNode(var, phi); - if (arg_node != nullptr) { - MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " phi " << phi->ToString() << " can be replaced with " - << arg_node->DebugString(); - // replace var with new one. This equal to statement in TR "v0 is immediately replaced by v1." - WriteVariable(var, arg_node); - removable_phis_[phi] = arg_node; - // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized - // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. 
- for (auto &prev : prev_blocks_) { - MS_EXCEPTION_IF_NULL(prev); - if (!prev->matured_) { - continue; - } - for (auto &phi_iter : prev->removable_phis_) { - MS_EXCEPTION_IF_NULL(phi_iter.second); - if (phi_iter.second->isa()) { - const auto ¶m = phi_iter.second->cast(); - if (param == phi) { - MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() - << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); - prev->removable_phis_[phi_iter.first] = arg_node; - } - } - } - } - } -} - -// A block should be marked matured if its predecessor blocks have been processed -void FunctionBlock::Mature() { - const auto &graphParamVec = func_graph_->parameters(); - for (auto ¶mItr : graphParamVec) { - MS_EXCEPTION_IF_NULL(paramItr); - ParameterPtr param = paramItr->cast(); - if (phi_nodes_.find(param) != phi_nodes_.cend()) { - SetPhiArgument(param); - } - } - matured_ = true; -} - -// Force the conditIon node to bool using bool operation -CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { - TraceManager::DebugTrace(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); - TraceManager::EndTrace(); - return op_apply_node; -} - -CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) { - TraceManager::DebugTrace(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond}); - TraceManager::EndTrace(); - return op_apply_node; -} - -// Perform a jump from this block to target block -void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { - if (func_graph()->get_return() != nullptr) { - MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " - << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); - } - std::vector input_nodes; - input_nodes.emplace_back(NewValueNode(target_block->func_graph())); - if (node != nullptr) { - input_nodes.emplace_back(node); - } - - CNodePtr jump = func_graph()->NewCNode(input_nodes); - jumps_[target_block.get()] = jump; - target_block->AddPrevBlock(shared_from_this()); - func_graph()->set_output(jump); - InsertDependItemsBeforeReturn(); -} - -// Perform a conditional jump using switch operation. -// The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, - const FunctionBlockPtr &false_block, bool unroll_loop) { - if (func_graph()->get_return() != nullptr) { - MS_LOG(EXCEPTION) << "Failure: have return node! 
NodeInfo: " - << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); - } - // Here we need set an attribute to primtive 'switch', so we create a new variable instead of global 'kPrimSwitch' - auto prim_switch = std::make_shared(prim::kPrimSwitch->name()); - if (!unroll_loop) { - prim_switch->AddAttr(prim::SWITCH_UNROLL_FLAG, MakeValue(0)); - } - CNodePtr switch_app = - func_graph()->NewCNode({NewValueNode(prim_switch), condNode, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); - CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); - func_graph()->set_output(switch_app_new); - InsertDependItemsBeforeReturn(); -} - -void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { - state_assign_[target] = readid; -} - -void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } - -void FunctionBlock::InsertDependItemsBeforeReturn() { - if (!prev_blocks_.empty()) { - for (auto &prev_block : prev_blocks_) { - MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); - } - } - - ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); - ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); - ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); - const std::string primitive_name("assign"); - const std::string module_name("mindspore.ops.functional"); - ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); - if (state_assign_.size() == 0 && auto_depends_.size() == 0) { - return; - } - AnfNodePtr state = nullptr; - std::vector vec_states; - vec_states.emplace_back(make_tuple_op); - for (auto &item : state_assign_) { - auto source = ReadVariable(item.second); - auto assign = func_graph()->NewCNode({assign_op, item.first, source}); - MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; - vec_states.emplace_back(assign); - } - for (auto &item : auto_depends_) { - MS_LOG(DEBUG) << "auto_depends " << item->ToString(); - vec_states.emplace_back(item); - } - // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2) - // do not need to make_tuple, just use the node. - if (vec_states.size() == 2) { - state = vec_states[1]; - } else { - state = func_graph()->NewCNode(vec_states); - } - - AnfNodePtr old_ret = nullptr; - auto return_node = func_graph()->get_return(); - if (return_node) { - if (return_node->inputs().size() < 1) { - MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2"; - } - old_ret = return_node->input(1); - } else { - old_ret = NewValueNode(kNone); - } - AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); - AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); - func_graph()->set_output(ret, true); - state_assign_.clear(); -} -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h deleted file mode 100644 index b93838b43c..0000000000 --- a/mindspore/ccsrc/pipeline/parse/function_block.h +++ /dev/null @@ -1,118 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_FUNCTION_BLOCK_H_ -#define PIPELINE_PARSE_FUNCTION_BLOCK_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "pipeline/parse/parse_base.h" -#include "utils/log_adapter.h" -#include "utils/ordered_map.h" - -namespace mindspore { -namespace parse { - -class Parser; -class NameSpace; -class Symbol; -class FunctionBlock; -using FunctionBlockPtr = std::shared_ptr; - -// A function block is a straight-line code sequence with no branches, every block has one one exit point -// which is return. When parsing function, loop or branch , we use function block to track the structure of -// the original source code. -class FunctionBlock : public std::enable_shared_from_this { - public: - explicit FunctionBlock(const Parser &parser); - virtual ~FunctionBlock() {} - - FuncGraphPtr func_graph() { return func_graph_; } - void WriteVariable(const std::string &var_name, const AnfNodePtr &node); - AnfNodePtr ReadVariable(const std::string &var_name); - void AddPrevBlock(const FunctionBlockPtr &block); - void SetPhiArgument(const ParameterPtr &phi); - void CollectRemovablePhi(const ParameterPtr &phi); - // A block is matured if all its predecessors is generated - void Mature(); - CNodePtr ForceToBoolNode(const AnfNodePtr &cond); - CNodePtr ForceToWhileCond(const AnfNodePtr &cond); - void Jump(const FunctionBlockPtr &block, AnfNodePtr node); - AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); - void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock, - bool unroll_loop = true); - // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); - void AddAutoDepend(const AnfNodePtr &target); - void InsertDependItemsBeforeReturn(); - void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } - bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } - AnfNodePtr MakeResolveAstOp(const py::object &op); - AnfNodePtr MakeResolveClassMember(std::string attr); - AnfNodePtr MakeResolveSymbol(const std::string &value); - AnfNodePtr MakeResolveOperation(const std::string &value); - AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); - const std::unordered_map &removable_phis() const { return removable_phis_; } - - private: - // block graph - FuncGraphPtr func_graph_; - - // the block's parser - const Parser &parser_; - - // A block is matured if all its prev_blocks is processed - bool matured_; - - // store the nest-level block - // refer to comments in Parser::func_block_list_; - std::vector prev_blocks_; - - // store args and variable's node - std::map vars_; - - // phi_nodes map the parameter node to variable, it can be resolved if the block's predecessors are processed - std::map phi_nodes_; - - // jumps map the successor block and the function call that perform jump - // refer to comments in Parser::func_block_list_ that how to 
break the cyclic reference - std::map jumps_; - - // keeps all removable phis which will be removed in one pass. - std::unordered_map removable_phis_; - - // set state nodes need to insert before function return nodes. - OrderedMap state_assign_; - - // hold declared global variables in function - std::set global_vars_; - - // other depend need to insert before function return nodes. - // summary or some other node - std::vector auto_depends_; -}; - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_FUNCTION_BLOCK_H_ diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc deleted file mode 100644 index 1d306d9ca4..0000000000 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ /dev/null @@ -1,1604 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/parse.h" -#include -#include -#include -#include -#include -#include "operator/ops.h" -#include "pipeline/parse/data_converter.h" -#include "operator/composite/composite.h" -#include "utils/context/ms_context.h" -#include "debug/trace.h" - -namespace mindspore { -namespace parse { - -FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) { - (void)python_adapter::set_python_scoped(); - - if (obj == nullptr || py::isinstance(obj)) { - MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none"; - return nullptr; - } - - auto ast = std::make_shared(obj); - bool success = ast->InitParseAstInfo(python_mod_get_parse_method); - if (!success) { - MS_LOG(ERROR) << "Parse code to ast tree failed."; - return nullptr; - } - - auto parser = std::make_shared(ast); - - FuncGraphPtr func_graph = parser->ParseFuncGraph(); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode(); - return nullptr; - } - - return func_graph; -} - -// if any mixed precision flag add a cast node after the parameter node. 
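A standalone illustration of the flag-to-dtype dispatch performed by the deleted GetMixedPrecisionCastHelp below: if a mixed-precision flag is set on the graph, the parameter is wrapped in a cast node targeting the flagged dtype; otherwise it passes through untouched. All names here (Node, PrecisionFlag, MaybeInsertCast) are hypothetical stand-ins, a sketch rather than the real IR types:

#include <memory>
#include <string>

// Hypothetical miniature IR node: an op name plus a single input.
struct Node {
  std::string op;
  std::shared_ptr<Node> input;
};
using NodePtr = std::shared_ptr<Node>;

enum class PrecisionFlag { kNone, kFP16, kFP32 };

// Mirrors the shape of the helper below: no flag set -> return the
// parameter unchanged; otherwise wrap it in a cast node.
NodePtr MaybeInsertCast(PrecisionFlag flag, const NodePtr &param) {
  if (flag == PrecisionFlag::kNone) {
    return param;
  }
  const std::string dst = (flag == PrecisionFlag::kFP16) ? "float16" : "float32";
  return std::make_shared<Node>(Node{"cast_to_" + dst, param});
}

The pass-through case matters: inserting an unconditional cast would change graphs that never requested mixed precision.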
-AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) { - TypePtr dst_type; - if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) { - dst_type = kFloat32; - } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) { - dst_type = kFloat16; - } else { - return param; - } - auto cast_helper = prim::kPrimMixedPrecisionCast; - auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); - return cast; -} - -FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr(); - -Parser::Parser(const std::shared_ptr &ast) : ast_(ast) { - errcode_ = PARSE_SUCCESS; - BuildMethodMap(); -} - -void Parser::BuildMethodMap() { - stmt_method_map_["Return"] = &Parser::ParseReturn; - stmt_method_map_["Expr"] = &Parser::ParseExpr; - stmt_method_map_["If"] = &Parser::ParseIf; - stmt_method_map_["Assign"] = &Parser::ParseAssign; - stmt_method_map_["While"] = &Parser::ParseWhile; - stmt_method_map_["For"] = &Parser::ParseFor; - stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; - stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; - stmt_method_map_["Global"] = &Parser::ParseGlobal; - stmt_method_map_["Break"] = &Parser::ParseBreak; - stmt_method_map_["Continue"] = &Parser::ParseContinue; - stmt_method_map_["Pass"] = &Parser::ParsePass; - expr_method_map_["NoneType"] = &Parser::ParseNone; - expr_method_map_["BinOp"] = &Parser::ParseBinOp; - expr_method_map_["Name"] = &Parser::ParseName; - expr_method_map_["Num"] = &Parser::ParseNum; - expr_method_map_["Str"] = &Parser::ParseStr; - expr_method_map_["NameConstant"] = &Parser::ParseNameConstant; - expr_method_map_["Call"] = &Parser::ParseCall; - expr_method_map_["IfExp"] = &Parser::ParseIfExp; - expr_method_map_["Attribute"] = &Parser::ParseAttribute; - expr_method_map_["Compare"] = &Parser::ParseCompare; - expr_method_map_["BoolOp"] = &Parser::ParseBoolOp; - expr_method_map_["Lambda"] = &Parser::ParseLambda; - expr_method_map_["Tuple"] = &Parser::ParseTuple; - expr_method_map_["List"] = &Parser::ParseList; - expr_method_map_["Subscript"] = &Parser::ParseSubscript; - expr_method_map_["Slice"] = &Parser::ParseSlice; - expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice; - expr_method_map_["Index"] = &Parser::ParseIndex; - expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; - expr_method_map_["Dict"] = &Parser::ParseDict; - expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; -} - -void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } - -void Parser::InitParserEnvironment(const py::object &obj) { - Parser::top_func_graph_ = FuncGraphWeakPtr(); - ScopeManager::GetInstance().ClearScope(); - (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj); -} - -void Parser::CleanParserResource() { - Parser::top_func_graph_ = FuncGraphWeakPtr(); - ScopeManager::GetInstance().ClearScope(); -} - -FuncGraphPtr Parser::ParseFuncGraph() { - // get ast FunctionDef node - py::object node = ast_->GetAstNode(); - FunctionBlockPtr pFnBlock = ParseFunction(node); - if (errcode() != PARSE_SUCCESS) { - MS_LOG(ERROR) << "Parse function error, code is " << errcode(); - return nullptr; - } - - RemoveUnnecessaryPhis(); - - MS_EXCEPTION_IF_NULL(pFnBlock); - return pFnBlock->func_graph(); -} - -void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { - py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args"); - py::object var_arg_node = 
python_adapter::GetPyObjAttr(func_args, "vararg"); - block->func_graph()->set_has_vararg(!py::isinstance(var_arg_node)); - - py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg"); - block->func_graph()->set_has_kwarg(!py::isinstance(kw_arg_node)); - - py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs"); - block->func_graph()->set_kwonlyargs_count(SizeToInt(kwonly_args.size())); - - MS_EXCEPTION_IF_NULL(ast_); - py::list args = ast_->GetArgs(fn_node); - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - if (arg_name == "self") { - continue; - } - } - TraceManager::DebugTrace(GetLocation(args[i])); - auto para_node = std::make_shared(block->func_graph()); - MS_EXCEPTION_IF_NULL(para_node); - TraceManager::EndTrace(); - para_node->set_name(arg_name); - para_node->debug_info()->set_name(arg_name); - block->func_graph()->add_parameter(para_node); - AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block->func_graph(), para_node); - block->WriteVariable(arg_name, para_after_cast); - MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name; - } -} - -void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) { - py::list defaults = ast_->GetArgsDefaultValues(fn_node); - py::list args = ast_->GetArgs(fn_node); - std::vector namelist_for_default_value; - std::vector default_values; - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg_name = py::cast(args[i].attr("arg")); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - if (arg_name == "self") { - continue; - } - } - - namelist_for_default_value.push_back(arg_name); - if (py::isinstance(defaults[i])) { - default_values.push_back(NewValueNode(kNull)); - } else { - default_values.push_back(ParseExprNode(block, defaults[i])); - } - } - block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values); -} - -ScopePtr Parser::GetScopeForParseFunction() { - ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope(); - if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) { - py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj()); - if (!py::isinstance(scope_str)) { - auto scope_name = py::cast(scope_str); - scope = std::make_shared(scope_name); - } - } - return scope; -} - -FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) { - ScopePtr scope = GetScopeForParseFunction(); - // the node created in the parsefunction context, will inherit the scope created using scope_guard - ScopeGuard scope_guard(scope); - TraceGuard trace_guard(data_converter::GetObjKey(ast()->obj())[0], GetLocation(node)); - FunctionBlockPtr pFunBlock = MakeFunctionBlock(*this); - if (block != nullptr) { - pFunBlock->AddPrevBlock(block); - } else { - func_graph_ = pFunBlock->func_graph(); - } - pFunBlock->Mature(); - auto current_fg = pFunBlock->func_graph(); - auto function_name = py::cast(python_adapter::GetPyObjAttr(node, "name")); - MS_LOG(DEBUG) << "The function name is " << function_name; - - current_fg->debug_info()->set_name(function_name); - MS_EXCEPTION_IF_NULL(ast_); - py::list deco_list = node.attr("decorator_list"); - if (deco_list.size() > 0) { - current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); - } - - bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); - if (ast_->obj() != 
ast_->function()) { - set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); - } - - if (!set_flag) { - MS_LOG(ERROR) << "Set flags failed"; - return nullptr; - } - GenerateArgsNodeForFunction(pFunBlock, node); - - // when parsing the top graph of construct, save the top graph - if (GetTopFuncGraph() == nullptr) { - UpdateTopFuncGraph(pFunBlock->func_graph()); - } - - // save the function node to block - pFunBlock->WriteVariable(function_name, NewValueNode(current_fg)); - - py::object funcObj = python_adapter::GetPyObjAttr(node, "body"); - (void)ParseStatements(pFunBlock, funcObj); - - if (current_fg->get_return() == nullptr) { - MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); - errcode_ = PARSE_NO_RETURN; - return pFunBlock; - } - GenerateArgsDefaultValueForFunction(pFunBlock, node); - return pFunBlock; -} - -FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { - py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); - size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count is " << count; - for (size_t i = 0; i < count; i++) { - auto node = py::cast(nodes)[i]; - TraceManager::DebugTrace(GetLocation(node)); - fn_block = ParseStatement(fn_block, node); - TraceManager::EndTrace(); - // insert appropriate depended items for the function block if it has a return node - if (fn_block->func_graph()->get_return() != nullptr) { - fn_block->InsertDependItemsBeforeReturn(); - // Skip statements after 'return' (or 'break', 'continue'). - break; - } - } - return fn_block; -} - -FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) { - auto node_type = ast_->GetNodeType(node); - - // check the node type - AstMainType nodeType = node_type->main_type(); - if (nodeType != AST_MAIN_TYPE_STMT) { - MS_LOG(INFO) << "Node type is error : " << nodeType; - return block; - } - // call the process function - std::string node_name = node_type->node_name(); - MS_LOG(DEBUG) << "Ast node is " << node_name; - if (stmt_method_map_.count(node_name)) { - TraceManager::DebugTrace(GetLocation(node)); - auto stmt_block = (this->*stmt_method_map_[node_name])(block, node); - TraceManager::EndTrace(); - return stmt_block; - } else { - errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; - } -} - -AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast expr"; - auto node_type = ast_->GetNodeType(node); - // check the node type - AstMainType node_main_type = node_type->main_type(); - if (node_main_type != AST_MAIN_TYPE_EXPR) { - MS_LOG(ERROR) << "Node type is error : " << node_main_type; - errcode_ = PARSE_NODE_TYPE_NO_MATCH; - return nullptr; - } - // call the process function - std::string node_name = node_type->node_name(); - MS_LOG(DEBUG) << "Ast node is " << node_name; - if (expr_method_map_.count(node_name)) { - TraceManager::DebugTrace(GetLocation(node)); - auto expr_node = (this->*expr_method_map_[node_name])(block, node); - TraceManager::EndTrace(); - return expr_node; - } else { - errcode_ = PARSE_NODE_METHOD_UNSUPPORTED; - py::list ret = 
-    auto filename = ret[0].cast<std::string>();
-    auto line_no = ret[1].cast<int>();
-    MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no;
-  }
-}
-
-// process the expr statement and expand it
-// eg: x.append(y) -> x = x.append(y)
-FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast Expr";
-  // Expr only has a value, no target
-  py::tuple expand_info = ast_->CallParserObjMethod(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
-
-  // refer to the python function expand_expr_statement; expand_info is one of the following:
-  // True, expr.value, x
-  // True, expr.value
-  // False, None, None
-  // check the expand info result
-  auto is_expand = py::cast<bool>(expand_info[0]);
-  if (is_expand) {
-    // process the expr statement
-    py::object value_object = expand_info[1];
-    AnfNodePtr value_node = ParseExprNode(block, value_object);
-
-    if (py::len(expand_info) == 2) {
-      // add to the depend list and insert before the output
-      block->AddAutoDepend(value_node);
-    } else {
-      // expand the assign statement
-      py::object target_node = expand_info[2];
-      WriteAssignVars(block, target_node, value_node);
-    }
-  }
-  return block;
-}
-
-LocationPtr Parser::GetLocation(const py::object &node) const {
-  MS_EXCEPTION_IF_NULL(ast_);
-  py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
-  if (ret.size() < 5) {
-    MS_LOG(EXCEPTION) << "List size should not be less than 5.";
-  }
-  // refer to Location::Location() for each member of ret: file name, line, column, line_end, column_end.
-  auto location = std::make_shared<Location>(ret[0].cast<std::string>(), ret[1].cast<int>(), ret[2].cast<int>(),
-                                             ret[3].cast<int>(), ret[4].cast<int>());
-  return location;
-}
-
-void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
-                                 const FunctionBlockPtr &false_block) {
-  true_block->AddPrevBlock(pre_block);
-  true_block->Mature();
-
-  false_block->AddPrevBlock(pre_block);
-  false_block->Mature();
-}
-
-FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast return";
-  MS_EXCEPTION_IF_NULL(block);
-  // create the return value node
-  AnfNodePtr pReturnValueNode = NewValueNode(prim::kPrimReturn);
-  // parse the return statement's value
-  py::object value = python_adapter::GetPyObjAttr(node, "value");
-  AnfNodePtr pReturnStatementNode = ParseExprNode(block, value);
-  // create the CNode
-  CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode});
-
-  block->func_graph()->set_return(pReturnCNode);
-
-  return block;
-}
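
[Editor's note] ParseStatement and ParseExprNode above both dispatch on the Python AST node name through tables of member-function pointers (stmt_method_map_ / expr_method_map_), falling through to the PARSE_NODE_METHOD_UNSUPPORTED error path for unknown names. A minimal, self-contained sketch of that dispatch pattern, with all names invented for illustration (not MindSpore code):

#include <iostream>
#include <map>
#include <string>

class MiniParser {
 public:
  // Handler: pointer to a member function keyed by AST node name.
  using Handler = std::string (MiniParser::*)(const std::string &);
  MiniParser() { handlers_["Return"] = &MiniParser::HandleReturn; }
  std::string Dispatch(const std::string &node_name, const std::string &payload) {
    auto it = handlers_.find(node_name);
    if (it == handlers_.end()) {
      // Mirrors the "Unsupported syntax" exception branch above.
      return "unsupported syntax '" + node_name + "'";
    }
    return (this->*it->second)(payload);
  }

 private:
  std::string HandleReturn(const std::string &p) { return "return(" + p + ")"; }
  std::map<std::string, Handler> handlers_;
};

int main() {
  MiniParser p;
  std::cout << p.Dispatch("Return", "x") << "\n";    // return(x)
  std::cout << p.Dispatch("Lambda", "...") << "\n";  // unsupported syntax 'Lambda'
  return 0;
}
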
-
-// Process binary operators, eg: `a + b`, `a | b`, etc.
-AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast BinOP";
-
-  py::object left = python_adapter::GetPyObjAttr(node, "left");
-  py::object right = python_adapter::GetPyObjAttr(node, "right");
-  py::object op = python_adapter::GetPyObjAttr(node, "op");
-  // create left and right ANF node
-  AnfNodePtr left_node = ParseExprNode(block, left);
-  if (left_node == nullptr) {
-    MS_LOG(WARNING) << "DoBinOp process left node failed: " << errcode();
-    return nullptr;
-  }
-  AnfNodePtr right_node = ParseExprNode(block, right);
-  if (right_node == nullptr) {
-    MS_LOG(WARNING) << "DoBinOp process right node failed:" << errcode();
-    return nullptr;
-  }
-  // resolve the op
-  AnfNodePtr op_node = block->MakeResolveAstOp(op);
-  // create apply node
-  return block->func_graph()->NewCNode({op_node, left_node, right_node});
-}
-
-AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast Name";
-  auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
-  MS_LOG(DEBUG) << "The Name id is " << name_id;
-  TraceGuard trace_guard(GetLocation(node));
-  if (block->IsGlobalVar(name_id)) {
-    return block->MakeResolveSymbol(name_id);
-  }
-  return block->ReadVariable(name_id);
-}
-
-AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
-  MS_LOG(DEBUG) << "Process ast NoneType";
-  return NewValueNode(kNone);
-}
-
-AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
-  MS_LOG(DEBUG) << "Process ast Ellipsis";
-  return NewValueNode(kEllipsis);
-}
-
-AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast Num";
-  py::object obj = python_adapter::GetPyObjAttr(node, "n");
-  TraceGuard trace_guard(GetLocation(node));
-  if (py::isinstance<py::int_>(obj)) {
-    MS_LOG(INFO) << "The Num is int:" << (std::string)py::str(obj);
-    auto data = py::cast<int>(obj);
-    return NewValueNode(data);
-  } else if (py::isinstance<py::float_>(obj)) {
-    MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
-    auto data = py::cast<float>(obj);
-    return NewValueNode(data);
-  } else {
-    // no else actually
-    MS_LOG(ERROR) << "Unsupported Num type : " << (std::string)py::str(obj) << GetLocation(node)->ToString();
-    errcode_ = PARSE_NODE_TYPE_UNKOWN;
-    return nullptr;
-  }
-}
-
-AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast Str";
-  auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
-  return NewValueNode(str_s);
-}
-
-AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast NameConstant";
-  py::object obj = python_adapter::GetPyObjAttr(node, "value");
-  TraceGuard trace_guard(GetLocation(node));
-  if (py::isinstance<py::bool_>(obj)) {
-    MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
-    auto data = py::cast<bool>(obj);
-    return NewValueNode(data);
-  } else if (py::isinstance<py::none>(obj)) {
-    MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
-    return NewValueNode(kNone);
-  } else {
-    // no else actually
-    MS_LOG(ERROR) << "Unsupported NameConstant type: " << (std::string)py::str(obj) << GetLocation(node)->ToString();
-    errcode_ = PARSE_NODE_TYPE_UNKOWN;
-    return nullptr;
-  }
-}
-AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
-  AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
-  std::vector<AnfNodePtr>
make_tuple_nodes; - make_tuple_nodes.push_back(make_tuple_op); - (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), - [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); - return block->func_graph()->NewCNode(make_tuple_nodes); -} -// process function call, eg : f1(x, y) ... -AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Call"; - // process function call - py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func"); - AnfNodePtr call_function_anf_node = ParseExprNode(block, function_ast_node); - // function call arguments should be passed in as groups and unpacked later using unpack call - py::list args = python_adapter::GetPyObjAttr(node, "args"); - std::vector packed_arguments; - std::vector group_arguments; - - bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments); - bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments); - // if there is stared or keyword argument, unpack may be needed - bool need_unpack = need_unpack_args || need_unpack_keywords; - - return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); -} - -AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, - const std::vector &packed_arguments, - const std::vector &group_arguments, bool need_unpack) const { - // if there is keyword arguments or starred, using an unpack_call op to unpack the argument - if (need_unpack) { - std::vector unpack_call_nodes; - auto unpack_call_op = NewValueNode(std::make_shared(NAMED_METAGRAPH_UNPACKCALL)); - unpack_call_nodes.push_back(unpack_call_op); - unpack_call_nodes.push_back(call_function_anf_node); - (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), - [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); - return unpack_call; - } - // else there is no keyword arguments and starred, parsed as normal arguments without unpack - std::vector func_call_nodes; - func_call_nodes.push_back(call_function_anf_node); - (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), - [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes); - return call_anf_node; -} - -bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, - std::vector *packed_arguments, std::vector *group_arguments) { - bool need_unpack = false; - for (size_t i = 0; i < args.size(); i++) { - auto arg_node = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, args[i]))); - if (arg_node == AST_SUB_TYPE_STARRED) { - if (!group_arguments->empty()) { - packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); - } - packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value"))); - group_arguments->clear(); - need_unpack = true; - } else { - group_arguments->push_back(ParseExprNode(block, args[i])); - } - } - if (!group_arguments->empty()) { - packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments)); - } - return need_unpack; -} - -bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, - std::vector *packed_arguments) { - bool need_unpack = false; 
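// [Editor's note, illustrative sketch] For a call like `f(x, k=1, **d)`, the code below
// splits named keywords into parallel key/value tuples and folds them into one dict via
// the resolved make_dict operation, while `**d` (a keyword with no name) is parsed and
// packed directly; GenerateAnfNodeForCall then wraps everything in an unpack_call
// meta-graph. Roughly:
//   packed_arguments = [make_tuple(x), make_dict(make_tuple("k"), make_tuple(1)), d]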
-  py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
-  if (!keywords.empty()) {
-    need_unpack = true;
-    std::vector<AnfNodePtr> keys;
-    std::vector<AnfNodePtr> values;
-    for (size_t index = 0; index < keywords.size(); index++) {
-      auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
-      auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
-      if (py::isinstance<py::none>(kw_key)) {
-        packed_arguments->push_back(ParseExprNode(block, kw_value));
-      } else {
-        auto kw_key_c = kw_key.cast<std::string>();
-        keys.push_back(NewValueNode(kw_key_c));
-        values.push_back(ParseExprNode(block, kw_value));
-      }
-    }
-    auto keys_tuple = GenerateMakeTuple(block, keys);
-    auto values_tuple = GenerateMakeTuple(block, values);
-    auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
-    std::vector<AnfNodePtr> make_dict_nodes;
-    make_dict_nodes.push_back(make_dict_op);
-    make_dict_nodes.push_back(keys_tuple);
-    make_dict_nodes.push_back(values_tuple);
-    packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes));
-  }
-  return need_unpack;
-}
-
-// process attribute access on a class instance, eg: x.y()
-AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
-  MS_LOG(DEBUG) << "Process ast Attribute";
-
-  // process the class member value, eg: self.xx
-  if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
-    if (ast_->IsClassMember(node)) {
-      std::string var_name = "self.";
-      std::string attr_name = node.attr("attr").cast<std::string>();
-      (void)var_name.append(attr_name);
-      auto attr_obj = ast()->obj().attr(attr_name.c_str());
-      if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
-          (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
-           py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
-           py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
-        return block->MakeResolveSymbol(var_name);
-      } else {
-        return block->ReadVariable(var_name);
-      }
-    }
-  }
-
-  // process the getattr
-  // Use a Primitive to replace the operation resolve node (getattr),
-  // because getattr will eventually be converted to a Primitive node anyway
-  AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
-
-  // process the attr body
-  py::object value_body = python_adapter::GetPyObjAttr(node, "value");
-  AnfNodePtr value_node = ParseExprNode(block, value_body);
-  if (value_node == nullptr) {
-    MS_LOG(WARNING) << "Parse attribute failed";
-    return nullptr;
-  }
-
-  // process the node attr
-  auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
-  MS_LOG(DEBUG) << "Attr = " << attr_str;
-  TraceManager::DebugTrace(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
-  AnfNodePtr attr_node = NewValueNode(attr_str);
-  TraceManager::EndTrace();
-
-  // create the apply node
-  return block->func_graph()->NewCNode({op_node, value_node, attr_node});
-}
-
-// Process comparison expressions, eg: a == b, a > b.
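// [Editor's note] The handler below intentionally supports a single comparison op per
// expression: a chained Python comparison such as `x > y > 5` arrives with two entries
// in `ops` and is rejected via the MAX_COMPARISON_OPS_SUPPORTED check. A sketch of the
// effect, assuming MAX_COMPARISON_OPS_SUPPORTED == 1 (per "we only support one now"):
//   if (ops.size() > 1) return nullptr;  // user must rewrite as `x > y and y > 5`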
-AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Compare"; - - // for python comparison ,there may be if x>y>5 , - // which there is two ops , but we only support one now - py::list ops = python_adapter::GetPyObjAttr(node, "ops"); - if (ops.size() > MAX_COMPARISON_OPS_SUPPORTED) { - MS_LOG(ERROR) << "MindSpore does not support comparison with operators more than one now, ops size =" << ops.size(); - return nullptr; - } - - py::object left = python_adapter::GetPyObjAttr(node, "left"); - py::list comparators = python_adapter::GetPyObjAttr(node, "comparators"); - AnfNodePtr left_node = ParseExprNode(block, left); - AnfNodePtr right_node = ParseExprNode(block, comparators[0]); - - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); - - return block->func_graph()->NewCNode({op_node, left_node, right_node}); -} - -AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, - const py::object &op) { - // if there is only one bool op now - if (value_list.size() == 1) { - AnfNodePtr first_node = ParseExprNode(block, value_list[0]); - return first_node; - } else { - py::object first = value_list[0]; - py::list rest; - for (size_t i = 1; i < value_list.size(); i++) { - rest.append(value_list[i]); - } - - AnfNodePtr first_node = ParseExprNode(block, first); - AnfNodePtr rest_node = ProcessBoolOpValueList(block, rest, op); - auto op_node = block->MakeResolveAstOp(op); - return block->func_graph()->NewCNode({op_node, first_node, rest_node}); - } -} - -// Process comparison expression : a and b. a or b . -AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast BoolOp"; - py::object op_node = python_adapter::GetPyObjAttr(node, "op"); - py::list op_values = python_adapter::GetPyObjAttr(node, "values"); - return ProcessBoolOpValueList(block, op_values, op_node); -} - -// Process a function def -FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast FunctionDef"; - FunctionBlockPtr function_block = ParseFunction(node, block); - MS_EXCEPTION_IF_NULL(function_block); - - // get function name - py::str name = python_adapter::GetPyObjAttr(node, "name"); - std::string function_name = name; - ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph()); - block->WriteVariable(function_name, valuenode_graph); - return block; -} - -// Process a lambda expression . 
like lambda x,y: x + y -AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Lambda"; - FunctionBlockPtr func_block = MakeFunctionBlock(*this); - func_block->AddPrevBlock(block); - func_block->Mature(); - - // get lambda args - py::list args = ast_->GetArgs(node); - for (std::size_t i = 0; i < args.size(); i++) { - std::string arg = py::cast(args[i].attr("arg")); - TraceManager::DebugTrace(GetLocation(args[i])); - auto para_node = std::make_shared(func_block->func_graph()); - TraceManager::EndTrace(); - para_node->debug_info()->set_name(arg); - func_block->func_graph()->add_parameter(para_node); - func_block->WriteVariable(arg, para_node); - MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg; - } - - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node); - func_block->func_graph()->set_output(lambda_body_node); - ValueNodePtr const_graph = NewValueNode(func_block->func_graph()); - return const_graph; -} - -// process a tuple -AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Tuple"; - MS_EXCEPTION_IF_NULL(block); - py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); - if (elts.size() == 0) { - auto empty_tuple = std::vector(); - return NewValueNode(std::make_shared(empty_tuple)); - } - - std::vector tuple_vec; - AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); - tuple_vec.emplace_back(make_tuple_op); - for (size_t i = 0; i < elts.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); - tuple_vec.emplace_back(node_ptr); - } - CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); - return tuple_app; -} - -// process a list -AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast List"; - MS_EXCEPTION_IF_NULL(block); - py::tuple elts = python_adapter::GetPyObjAttr(node, "elts"); - if (elts.size() == 0) { - auto empty_list = std::vector(); - return NewValueNode(std::make_shared(empty_list)); - } - - std::vector list_vec; - AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST); - list_vec.emplace_back(make_list_op); - for (size_t i = 0; i < elts.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); - list_vec.emplace_back(node_ptr); - } - CNodePtr list_app = block->func_graph()->NewCNode(list_vec); - return list_app; -} - -// process a subscript, such as x[y] , node expressed as value[slice] -AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Subscript"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - py::object value_node = python_adapter::GetPyObjAttr(node, "value"); - py::object slice_node = python_adapter::GetPyObjAttr(node, "slice"); - AnfNodePtr value = ParseExprNode(block, value_node); - AnfNodePtr slice = ParseExprNode(block, slice_node); - - return block->func_graph()->NewCNode({op_getitem, value, slice}); -} - -// process a slice, get the slice value -AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Slice"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE); - py::object start = python_adapter::GetPyObjAttr(node, "lower"); - py::object 
stop = python_adapter::GetPyObjAttr(node, "upper"); - py::object step = python_adapter::GetPyObjAttr(node, "step"); - AnfNodePtr start_node = ParseExprNode(block, start); - AnfNodePtr stop_node = ParseExprNode(block, stop); - AnfNodePtr step_node = ParseExprNode(block, step); - - return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node}); -} - -// process a extslice -AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast ExtSlice"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE); - py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims"); - - std::vector node_vec; - node_vec.emplace_back(make_tuple_op); - for (size_t i = 0; i < slice_tuple.size(); i++) { - AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); - node_vec.emplace_back(node_ptr); - } - CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec); - return tuple_conde; -} - -// process a index, get the index number -AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Index"; - py::object value_node = python_adapter::GetPyObjAttr(node, "value"); - return ParseExprNode(block, value_node); -} - -// process a UnaryOp, +a, -b -AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast UnaryOp"; - py::object op = python_adapter::GetPyObjAttr(node, "op"); - - MS_EXCEPTION_IF_NULL(block); - // resolve the op - AnfNodePtr op_node = block->MakeResolveAstOp(op); - - py::object operand = python_adapter::GetPyObjAttr(node, "operand"); - AnfNodePtr operand_node = ParseExprNode(block, operand); - return block->func_graph()->NewCNode({op_node, operand_node}); -} - -// process a dict ast node expression -AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Dict"; - py::list keys = node.attr("keys"); - py::list values = node.attr("values"); - std::vector key_nodes; - std::vector value_nodes; - for (size_t i = 0; i < keys.size(); i++) { - key_nodes.push_back(ParseExprNode(block, keys[i])); - value_nodes.push_back(ParseExprNode(block, values[i])); - } - auto keys_tuple = GenerateMakeTuple(block, key_nodes); - auto values_tuple = GenerateMakeTuple(block, value_nodes); - MS_EXCEPTION_IF_NULL(block); - auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); - return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple}); -} - -// process a augment assign such as a += b; -FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast AugAssign"; - py::object op = python_adapter::GetPyObjAttr(node, "op"); - - MS_EXCEPTION_IF_NULL(block); - // resolve the op - AnfNodePtr op_node = block->MakeResolveAstOp(op); - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - MS_EXCEPTION_IF_NULL(ast_); - auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, target_node))); - AnfNodePtr read_node = nullptr; - if (ast_type == AST_SUB_TYPE_NAME) { - read_node = ParseName(block, target_node); - } else if (ast_->IsClassMember(target_node)) { - read_node = ParseAttribute(block, target_node); - } else { - MS_LOG(EXCEPTION) << "Not supported augassign"; - } - if (read_node == nullptr) { - MS_LOG(EXCEPTION) << "Can not get target node "; - } - - 
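// [Editor's note, sketch] The remainder of this handler lowers `a += b` into an explicit
// read-modify-write: parse `b`, apply the resolved AST op (eg ast.Add) to the value just
// read for `a`, then write the result back through WriteAssignVars, conceptually:
//   a_cur = read(a); tmp = resolved_op(a_cur, b_node); write(a, tmp)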
py::object value = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr value_node = ParseExprNode(block, value); - CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, read_node, value_node}); - WriteAssignVars(block, target_node, augassign_app); - return block; -} - -// process global declaration such as 'global x'; -FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast Global"; - MS_EXCEPTION_IF_NULL(block); - py::list vars = python_adapter::GetPyObjAttr(node, "names"); - for (auto &item : vars) { - block->AddGlobalVar(py::cast(item)); - } - return block; -} - -// process a if statement -FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast If"; - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(block, test_node); - MS_EXCEPTION_IF_NULL(block); - CNodePtr bool_node = block->ForceToBoolNode(condition_node); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr true_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr false_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - MakeConditionBlocks(block, true_block, false_block); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - // process the if-true branch - py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode); - - // if the return_ is set ,it has its own continuation block - if (true_end->func_graph()->get_return() == nullptr) { - true_end->Jump(after_block, nullptr); - } - - // process the orelse branch - py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); - FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode); - - // if the return_ is set ,it has its own continuation block - if (false_end->func_graph()->get_return() == nullptr) { - false_end->Jump(after_block, nullptr); - } - - block->ConditionalJump(bool_node, true_block, false_block); - after_block->Mature(); - return after_block; -} - -FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast While"; - MS_EXCEPTION_IF_NULL(block); - MS_LOG(INFO) << "Parse while statement"; - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr header_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr body_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - body_block->AddPrevBlock(header_block); - after_block->AddPrevBlock(header_block); - block->Jump(header_block, nullptr); - - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(header_block, test_node); - condition_node = header_block->ForceToWhileCond(condition_node); - body_block->Mature(); - header_block->ConditionalJump(condition_node, body_block, after_block); - - 
// Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, nullptr}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body = ParseStatements(body_block, body_node); - if (after_body->func_graph()->get_return() == nullptr) { - after_body->Jump(header_block, nullptr); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, nullptr); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. - return after_block; -} - -CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node, - const AnfNodePtr &op_iter) { - py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); - AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); - return block->func_graph()->NewCNode({op_iter, iter_anf_node}); -} - -CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, - const AnfNodePtr &op_hasnext) { - MS_EXCEPTION_IF_NULL(header_block); - return header_block->func_graph()->NewCNode({op_hasnext, iter_param}); -} - -FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { - TraceManager::DebugTrace(trace_info); - FunctionBlockPtr body_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - return body_block; -} - -// A for loop will generate 3 functions :the test, the body, and the continuation -// for x in xs: -// body -// it is compiled to be following statement -// if len(xs) < max_loop_cnt: -// ParseForIter() // use iter to implement for loop, which always unroll loop -// else: -// ParseForLoop() // use loop var to implement for loop, which always sink loop -FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For, create an if else statement"; - MS_EXCEPTION_IF_NULL(block); - // create statement 'len(xs) < prim::MAX_FOR_LOOP_COUNT' - AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); - py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); - AnfNodePtr iter_node = ParseExprNode(block, iter_obj); - CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); - CNodePtr bool_node = block->func_graph()->NewCNode( - {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(prim::MAX_FOR_LOOP_COUNT)}); - - // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr true_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr false_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - MakeConditionBlocks(block, true_block, false_block); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - FunctionBlockPtr true_end = ParseForIter(true_block, node); - true_end->Jump(after_block, nullptr); - - FunctionBlockPtr false_end = ParseForLoop(false_block, node); - false_end->Jump(after_block, nullptr); - - block->ConditionalJump(bool_node, true_block, false_block); - after_block->Mature(); - return after_block; -} - -// A for loop will generate 3 
functions :the test, the body, and the continuation -// for x in xs: -// body -// it is compiled to be following statement -// it = iter(xs) -// while hastnext(it) -// x, it = next(it) -// body -FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER); - AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT); - // generate the iterator apply - CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter); - MS_EXCEPTION_IF_NULL(iter_apply); - FunctionBlockPtr header_block = - GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(header_block); - // generate the hasnext apply which is a condition - ParameterPtr iter_param = header_block->func_graph()->add_parameter(); - CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext); - // generate the body of the for statement - FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(body_block); - body_block->AddPrevBlock(header_block); - // generate the iterator next apply - // process as following: `app = next(it); target = app[0]; it = app[1];` - CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param}); - CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(0)}); - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - - CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(1)}); - WriteAssignVars(body_block, target_node, target_app); - - // link the variable name with the target - auto it_info = std::make_shared(target_app->debug_info()); - iter_param->debug_info()->set_trace_info(it_info); - iter2_app->debug_info()->set_trace_info(it_info); - iter_apply->debug_info()->set_trace_info(it_info); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - MS_EXCEPTION_IF_NULL(after_block); - TraceManager::EndTrace(); - after_block->AddPrevBlock(header_block); - - block->Jump(header_block, iter_apply); - body_block->Mature(); - header_block->ConditionalJump(cond_apply, body_block, after_block); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, iter2_app}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); - if (after_body_block->func_graph()->get_return() == nullptr) { - after_body_block->Jump(header_block, iter2_app); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, nullptr); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. 
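// [Editor's note] Loop state is threaded through block parameters rather than mutation:
// header_block takes the iterator as a parameter (iter_param), the entry block jumps in
// with iter_apply, and the body jumps back with iter2_app (the advanced iterator), so
// iter_param effectively acts as the phi node joining both incoming values.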
- return after_block; -} - -// A for loop will generate 3 functions :the test, the body, and the continuation -// for x in xs: -// body -// it is compiled to be following statement -// i = 0 -// while i < len(xs) -// x = xs[i] -// i = i + 1 -// body -FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast For by loop variable"; - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - - // get varibale name of 'x' in statement 'for x in xs' - py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - - // create statement 'len(xs)' - py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter"); - AnfNodePtr iter_node = ParseExprNode(block, iter_obj); - MS_EXCEPTION_IF_NULL(iter_node); - CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); - - FunctionBlockPtr header_block = - GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(header_block); - // create loop variable 'i' - ParameterPtr loop_var = header_block->func_graph()->add_parameter(); - // create loop condition 'i < len(xs)' - CNodePtr cond_node = header_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarLt), loop_var, len_iter}); - - // generate the body of the for statement - FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); - MS_EXCEPTION_IF_NULL(body_block); - body_block->AddPrevBlock(header_block); - // create 'x = xs[i]' - CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var}); - WriteAssignVars(body_block, target_node, target_var); - // create 'i = i + 1' - CNodePtr loop_var_inc = - body_block->func_graph()->NewCNode({NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(1)}); - body_block->WriteVariable(loop_var->name(), loop_var_inc); - - // link the variable name with the target - auto it_info = std::make_shared(loop_var_inc->debug_info()); - loop_var->debug_info()->set_trace_info(it_info); - len_iter->debug_info()->set_trace_info(it_info); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr after_block = MakeFunctionBlock(*this); - MS_EXCEPTION_IF_NULL(after_block); - TraceManager::EndTrace(); - after_block->AddPrevBlock(header_block); - - block->Jump(header_block, NewValueNode(0)); - body_block->Mature(); - - header_block->ConditionalJump(cond_node, body_block, after_block, false); - - // Parse loop body statements with loop context. - LoopContext loop_context{&loops_, header_block, loop_var_inc}; - py::object body_node = python_adapter::GetPyObjAttr(node, "body"); - FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); - if (after_body_block->func_graph()->get_return() == nullptr) { - after_body_block->Jump(header_block, loop_var_inc); - } - - header_block->Mature(); - after_block->Mature(); - auto &end_block = loop_context.EndBlock(); - if (end_block) { - // end_block exists if we encounter 'break' in loop body. - after_block->Jump(end_block, nullptr); - end_block->Mature(); - return end_block; - } - // No 'break', no end_block. 
- return after_block; -} - -AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast IfExp"; - MS_EXCEPTION_IF_NULL(block); - py::object test_node = python_adapter::GetPyObjAttr(node, "test"); - AnfNodePtr condition_node = ParseExprNode(block, test_node); - CNodePtr bool_node = block->ForceToBoolNode(condition_node); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr true_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - FunctionBlockPtr false_block = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - - MakeConditionBlocks(block, true_block, false_block); - - // process the if-true branch - py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); - true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode)); - AnfNodePtr true_node = ParseExprNode(true_block, bodyNode); - - // process the orelse branch - py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse"); - false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode)); - AnfNodePtr false_node = ParseExprNode(false_block, orelseNode); - - true_block->func_graph()->set_output(true_node); - false_block->func_graph()->set_output(false_node); - - // Use the Primitive replace the operation resolve node (switch) - // because the switch will eventually be converted to Primitive node - CNodePtr switch_app = - block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); - - std::vector call_graph_nodes{switch_app}; - CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); - return switch_app_call; -} - -void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - MS_EXCEPTION_IF_NULL(assigned_node); - py::str name = python_adapter::GetPyObjAttr(targ, "id"); - std::string name_id = name; - assigned_node->debug_info()->set_name(name_id); - // set the debug name of the constant graph - if (IsValueNode(assigned_node)) { - // the value should be graph - auto fg = GetValueNode(assigned_node); - if (fg->debug_info()->name().empty()) { - fg->debug_info()->set_name(name_id); - } - } - block->WriteVariable(name_id, assigned_node); -} - -void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - py::list items = python_adapter::GetPyObjAttr(targ, "elts"); - for (size_t i = 0; i < items.size(); i++) { - // Use the Primitive replace the operation resolve node (getitem) - // because the getitem will eventually be converted to Primitive node - CNodePtr item_apply = block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))}); - - py::object elt = items[i]; - WriteAssignVars(block, elt, item_apply); - } -} - -void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, - const AnfNodePtr &assigned_node) { - // Now only support the self.xx = xxxxx, can't support x.y = xxxx - AnfNodePtr target_node = ParseExprNode(block, targ); - MS_EXCEPTION_IF_NULL(target_node); - - std::string attr_name = targ.attr("attr").cast(); - std::string var_name = "self."; - 
(void)var_name.append(attr_name); - MS_LOG(DEBUG) << "assign " << var_name; - - // Get targ location info for error printing - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, targ); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type - if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" - << line_no; - } - auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); - auto obj_type = obj.attr("__class__").attr("__name__"); - if (!py::hasattr(obj, "__parameter__")) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" - << py::str(obj).cast() << "' with type '" - << py::str(obj_type).cast() << "' at " << filename << ":" << line_no; - } - - MS_EXCEPTION_IF_NULL(block); - block->WriteVariable(var_name, assigned_node); - MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); - block->SetStateAssgin(target_node, var_name); -} - -void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, - const AnfNodePtr &assigned_node) { - MS_EXCEPTION_IF_NULL(block); - AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM); - py::object value_obj = python_adapter::GetPyObjAttr(targ, "value"); - py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice"); - AnfNodePtr value_node = ParseExprNode(block, value_obj); - AnfNodePtr slice_node = ParseExprNode(block, slice_obj); - CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); - // getitem apply should return the sequence data structure itself - std::string var_name = ""; - if (ast_->IsClassMember(value_obj)) { - std::string attr_name = value_obj.attr("attr").cast(); - var_name = "self." 
+ attr_name; - if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function."; - } - auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); - auto obj_type = obj.attr("__class__").attr("__name__"); - if (!py::hasattr(obj, "__parameter__")) { - MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" - << py::str(obj).cast() << "' with type '" - << py::str(obj_type).cast() << "'."; - } - } else { - var_name = value_obj.attr("id").cast(); - } - block->WriteVariable(var_name, setitem_app); -} - -void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "Process WriteAssignVars"; - auto ast_type = AstSubType(py::cast(ast_->CallParserObjMethod(PYTHON_PARSE_GET_AST_TYPE, targ))); - if (ast_type == AST_SUB_TYPE_NAME) { - HandleAssignName(block, targ, value_node); - } else if (ast_type == AST_SUB_TYPE_TUPLE) { - HandleAssignTuple(block, targ, value_node); - } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) { - HandleAssignSubscript(block, targ, value_node); - } else if (ast_->IsClassMember(targ)) { - HandleAssignClassMember(block, targ, value_node); - } else { - MS_LOG(EXCEPTION) << "Not supported assign type: " << ast_type - << " NodeInfo: " << trace::GetDebugInfo(value_node->debug_info()); - } -} - -// process a assign statement, such as a =b, a,b = tup -FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast assgin"; - py::object value_object = python_adapter::GetPyObjAttr(node, "value"); - AnfNodePtr value_node = ParseExprNode(block, value_object); - py::object targets_object = python_adapter::GetPyObjAttr(node, "targets"); - py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__"); - size_t count = IntToSize(pcount); - MS_LOG(DEBUG) << "The nodes count is " << count; - for (size_t i = 0; i < count; i++) { - auto target_node = py::cast(targets_object)[i]; - WriteAssignVars(block, target_node, value_node); - } - - return block; -} - -FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { - if (loops_.empty()) { - // Report error if loop context not set for the 'break' statement. - py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); - if (location.size() < 2) { - MS_LOG(EXCEPTION) << "List size should not be less than 2."; - } - auto filename = location[0].cast(); - auto line_no = location[1].cast(); - MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no; - } - // Get current loop. - Loop &loop = loops_.top(); - if (loop.end == nullptr) { - // Create end_block if it is not existed. - TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); - loop.end = MakeFunctionBlock(*this); - TraceManager::EndTrace(); - } - // Jump to the end_block. - block->Jump(loop.end, nullptr); - return block; -} - -FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { - if (loops_.empty()) { - // Report error if loop context not set for the 'continue' statement. 
-    py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
-    if (location.size() < 2) {
-      MS_LOG(EXCEPTION) << "List size should not be less than 2.";
-    }
-    auto filename = location[0].cast<std::string>();
-    auto line_no = location[1].cast<int>();
-    MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no;
-  }
-  // Jump to the header of the loop with iterator called.
-  Loop &loop = loops_.top();
-  block->Jump(loop.header, loop.iterator);
-  return block;
-}
-
-FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
-  // We just bypass 'pass' statement.
-  return block;
-}
-
-void Parser::RemoveUnnecessaryPhis() {
-  // merge all removable phis into one map
-  std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
-  for (FunctionBlockPtr &block : func_block_list_) {
-    MS_EXCEPTION_IF_NULL(block);
-    removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
-  }
-
-  if (removable_phis.size() == 0) {
-    return;
-  }
-  for (auto &node : DeepUsedGraphSearch(func_graph_->get_return())) {
-    if (node->isa<CNode>()) {
-      const auto &cnode = node->cast<CNodePtr>();
-      auto &inputs = cnode->inputs();
-      for (std::size_t i = 0; i < inputs.size(); i++) {
-        if (inputs[i]->isa<Parameter>()) {
-          const auto &inp = inputs[i]->cast<ParameterPtr>();
-          const auto &iter = removable_phis.find(inp);
-          if (iter == removable_phis.end()) {
-            continue;
-          }
-          auto &argNode = iter->second;
-          MS_LOG(DEBUG) << "graph " << cnode->func_graph()->ToString() << " replace phi " << inp->ToString() << " in "
-                        << cnode->DebugString() << " with " << argNode->DebugString();
-          cnode->set_input(i, argNode);
-        }
-      }
-    }
-  }
-}
-
-// ParseAst class code
-bool ParseAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
-  // init the type
-  target_type_ = PARSE_TARGET_UNKNOW;
-
-  // call python parse to get the parser fn
-  module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
-  py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
-
-  // get the obj type
-  auto type = data_converter::GetObjType(obj_);
-  if (type == RESOLVE_TYPE_FUNCTION) {
-    target_type_ = PARSE_TARGET_FUNCTION;
-    function_ = obj_;
-  } else if (type == RESOLVE_TYPE_METHOD) {
-    // process a method; we also need the method's bound self object
-    target_type_ = PARSE_TARGET_METHOD;
-    py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
-    if (py::isinstance<py::none>(method_object)) {
-      MS_LOG(ERROR) << "Get method's self object instance failed.";
-      return false;
-    }
-    target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
-    function_ = obj_;
-    obj_ = method_object;
-  } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
-    // obj is a class instance; get the method to parse.
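// [Editor's note] Summary of the target-type resolution implemented here:
//   RESOLVE_TYPE_FUNCTION       -> PARSE_TARGET_FUNCTION; parse obj_ itself
//   RESOLVE_TYPE_METHOD         -> PARSE_TARGET_OBJECT_INSTANCE; parse function_,
//                                  with obj_ rebound to the method's self instance
//   RESOLVE_TYPE_CLASS_INSTANCE -> PARSE_TARGET_OBJECT_INSTANCE; the method to parse
//                                  is looked up via python_mod_get_parse_method and
//                                  must itself resolve to a method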
- function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method); - if (py::isinstance(function_)) { - MS_LOG(ERROR) << "Get obj method function failed."; - return false; - } - target_type_ = PARSE_TARGET_OBJECT_INSTANCE; - // check the fn is method - auto obj_type = data_converter::GetObjType(function_); - if (obj_type != RESOLVE_TYPE_METHOD) { - MS_LOG(WARNING) << "Parse method function is invalid."; - return false; - } - } else { - MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type; - return false; - } - - // call python parse get ast tree - parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method); - ast_tree_ = python_adapter::CallPyObjMethod(parser_, "parse"); - - // get fn name and module - function_module_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_module")); - function_name_ = py::cast(python_adapter::GetPyObjAttr(parser_, "function_name")); - function_filename_ = py::cast(python_adapter::GetPyObjAttr(parser_, "filename")); - function_line_offset_ = py::cast(python_adapter::GetPyObjAttr(parser_, "line_offset")); - - return true; -} - -// Get ast tree node : is the tree bode list[0] -py::object ParseAst::GetAstNode() { - py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body"); - py::object ast_node = tree_body[0]; - return ast_node; -} - -py::list ParseAst::GetArgs(const py::object &func_node) { - py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS, func_node); - return ret; -} - -py::list ParseAst::GetArgsDefaultValues(const py::object &func_node) { - py::list ret = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node); - return ret; -} - -AstNodeTypePtr ParseAst::GetNodeType(const py::object &node) { - py::list list_value = python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_NODE_TYPE, node); - if (list_value.size() < 2) { - MS_LOG(ERROR) << "The node of python method must has 2 values."; - return nullptr; - } - auto node_name = py::cast(list_value[0]); - auto type = AstMainType(py::cast(list_value[1])); - return std::make_shared(node, node_name, type); -} - -AstSubType ParseAst::GetOpType(const py::object &node) { - auto op_type = AstSubType(python_adapter::CallPyObjMethod(parser_, PYTHON_PARSE_GET_AST_TYPE, node).cast()); - return op_type; -} - -bool ParseAst::IsClassMember(const py::object &node) { - py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node); - if (!py::isinstance(ret)) { - MS_LOG(ERROR) << "The result of mod function parse, should be bool type."; - return false; - } - return ret.cast(); -} - -bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "FuncGraph is null"; - return false; - } - - if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { - MS_LOG(DEBUG) << "No flags"; - return true; - } - py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); - for (auto &item : flags) { - if (!py::isinstance(item.first)) { - MS_LOG(ERROR) << "Type error in flags dict convert"; - return false; - } - auto name = py::cast(item.first); - if (py::isinstance(item.second)) { - auto value = py::cast(item.second); - MS_LOG(DEBUG) << "Flag name: " << name << ". 
Value: " << value; - func_graph->set_flag(name, value); - } else if (py::isinstance(item.second)) { - auto value = py::cast(item.second); - MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value; - func_graph->set_attr(name, MakeValue(value)); - } else { - MS_LOG(ERROR) << "Type error in flags/attrs dict convert"; - return false; - } - } - return true; -} - -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h deleted file mode 100644 index 65ed5ddd12..0000000000 --- a/mindspore/ccsrc/pipeline/parse/parse.h +++ /dev/null @@ -1,360 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_PARSE_H_ -#define PIPELINE_PARSE_PARSE_H_ - -#include -#include -#include -#include -#include -#include -#include "utils/misc.h" -#include "ir/anf.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/function_block.h" - -namespace mindspore { -namespace parse { - -// Parse status define -enum ParseStatusCode : int { - PARSE_SUCCESS = 0, - PARSE_FUNCTION_IS_NULL, // python function is null - PARSE_PARAMETER_INVALID, // parameter is invalid - PARSE_NO_RETURN, // function no return node - PARSE_NODE_TYPE_NO_MATCH, // ast node type is error - PARSE_NODE_TYPE_UNKOWN, // node type is unkown - PARSE_NODE_METHOD_UNSUPPORTED, // no method to parse the node - PARSE_DONT_RESOLVE_SYMBOL, // can't resolve the string - PARSE_NOT_SUPPORTED_COMPARE_EXPR, // the comparison is not supported - PARSE_FAILURE = 0xFF -}; - -class AstNodeType; -class ParseAst; - -// Save loop info for 'continue' and 'break' statements. -struct Loop { - // Loop header block. - FunctionBlockPtr header; - // Loop iterator node, used in 'for loop'. - AnfNodePtr iterator; - // Loop end block. - FunctionBlockPtr end; - - Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) - : header(header), iterator(iterator), end(end) {} - ~Loop() = default; -}; - -// Loop context for loop stack management. 
-class LoopContext {
- public:
-  LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
-    loops_->emplace(header, iterator, nullptr);
-  }
-  ~LoopContext() { loops_->pop(); }
-  const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
-
- private:
-  std::stack<Loop> *loops_;
-};
-
-// Parser to parse python function
-class Parser {
- public:
-  explicit Parser(const std::shared_ptr<ParseAst> &ast);
-
-  ~Parser() {}
-  FuncGraphPtr ParseFuncGraph();
-  FuncGraphPtr func_graph() const { return func_graph_; }
-  ParseStatusCode errcode() const { return errcode_; }
-  std::shared_ptr<ParseAst> ast() const { return ast_; }
-  // get location info from the ast node
-  LocationPtr GetLocation(const py::object &node) const;
-  static void InitParserEnvironment(const py::object &obj);
-  static void CleanParserResource();
-  static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); }
-  static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
-
- private:
-  // process the stmt node method list
-  FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
-  // parse expression
-  FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
-  // process an if statement
-  FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
-  // process a while statement
-  FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
-  // process a for statement
-  FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
-  FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
-  FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
-  // process a function def statement
-  FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
-  // process an augmented assignment
-  FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
-  // process a global declaration
-  FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
-  // process an assign statement
-  FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
-  // process a break statement
-  FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
-  // process a continue statement
-  FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
-  // process a pass statement
-  FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
-  // process the expr and slice node method list
-  AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
-  // process a variable name
-  AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
-  // process NoneType
-  AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
-  // process Ellipsis
-  AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
-  // process an integer or float number
-  AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
-  // process a string variable
-  AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
-  // process a name constant
-  AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
-  // process a function call
-  AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
-  // process the if expression
-  AnfNodePtr ParseIfExp(const
FunctionBlockPtr &block, const py::object &node); - // process class type define - AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node); - // process a compare expression - AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node); - // process a bool operation - AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node); - // process a lambda operation - AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node); - // process a tuple - AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node); - // process a slice - AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node); - - // process a extslice - AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node); - - // process a tuple - AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node); - - // process a unaryop - AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node); - - // process a dict ast node expression - AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node); - // generate argument nodes for ast function node - void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node); - // generate argument default value for ast function node - void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); - // parse ast function node - FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); - // parse ast statements - FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); - // parse one ast statement node - FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node); - // parse an ast expresion node - AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node); - - void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock, - const FunctionBlockPtr &falseBlock); - void RemoveUnnecessaryPhis(); - // write a new var - void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node); - - // assign value to single variable name - void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to tuple - void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to class member - void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // assign value to subscript - void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); - - // process a bool operation value list - AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, const py::object &op); - - CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, - const AnfNodePtr &op_iter); - - CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, - const AnfNodePtr &op_hasnext); - - FunctionBlockPtr GenerateBlockInFor(const TraceInfoPtr 
&trace_info); - - bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, - std::vector *packed_arguments); - - bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector *packed_arguments, - std::vector *group_arguments); - - AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, - const std::vector &packed_arguments, - const std::vector &group_arguments, bool need_unpack) const; - ScopePtr GetScopeForParseFunction(); - void BuildMethodMap(); - FunctionBlockPtr MakeFunctionBlock(const Parser &parse) { - FunctionBlockPtr block = std::make_shared(parse); - // In order to keep effect order in the sub-graphs which generated by control flow. - // We copy the flags from the top graph to the sub-graphs. - if (func_graph_ && !func_graph_->attrs().empty()) { - block->func_graph()->set_attrs(func_graph_->attrs()); - } - func_block_list_.push_back(block); - return block; - } - // return a make tuple for input elements list - AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector &element_nodes); - - // shared_ptr will be hold by GraphManager, so just hold a weak ref here. - static FuncGraphWeakPtr top_func_graph_; - // Python function id, used to indicate whether two CNodes come from the same Python function - const std::shared_ptr &ast_; - FuncGraphPtr func_graph_; - // error code setwhen parsing ast tree - ParseStatusCode errcode_; - - // hold all reference for FunctionBlock in this round of parsing, - // so in FunctionBlock class we can use FunctionBlock* in member - // pre_blocks_ and jumps_ to break reference cycle. - std::vector func_block_list_; - using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); - using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); - // define the function map to parse ast Statement - std::map stmt_method_map_; - // define the function map to parse ast expression - std::map expr_method_map_; - // Save current loops to support 'continue', 'break' statement. - std::stack loops_; -}; - -// AST node type define code to ast -class AstNodeType { - public: - AstNodeType(const py::object &node, const std::string &name, AstMainType type) - : node_(node), node_name_(name), main_type_(type) {} - - ~AstNodeType() {} - - std::string node_name() const { return node_name_; } - - py::object node() const { return node_; } - - AstMainType main_type() const { return main_type_; } - - private: - const py::object &node_; - const std::string node_name_; - AstMainType main_type_; -}; - -using AstNodeTypePtr = std::shared_ptr; - -// A helper class to parse python function -class ParseAst { - public: - explicit ParseAst(const py::object &obj) : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} - - ~ParseAst() = default; - - bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); - - py::object GetAstNode(); - - py::list GetArgs(const py::object &func_node); - - py::list GetArgsDefaultValues(const py::object &func_node); - - AstNodeTypePtr GetNodeType(const py::object &node); - - AstSubType GetOpType(const py::object &node); - - template - py::object CallParserObjMethod(const std::string &method, const T &... args) { - return python_adapter::CallPyObjMethod(parser_, method, args...); - } - - template - py::object CallParseModFunction(const std::string &function, const T &... 
args) { - return python_adapter::CallPyModFn(module_, function, args...); - } - - const std::string &function_name() const { return function_name_; } - - const std::string &function_module() const { return function_module_; } - - const std::string &function_filename() const { return function_filename_; } - - int function_line_offset() const { return function_line_offset_; } - - py::function function() { return function_; } - - ParseTargetTypeDef target_type() const { return target_type_; } - - py::object obj() { return obj_; } - - py::object parser() { return parser_; } - - py::object module() { return module_; } - - py::object ast_tree() { return ast_tree_; } - - bool IsClassMember(const py::object &node); - - private: - // save obj,eg: class instance or function - py::object obj_; - - // function or class method. - py::function function_; - - py::object ast_tree_; - py::object parser_; - py::module module_; - - // Is function or method - ParseTargetTypeDef target_type_; - - std::string function_name_; - std::string function_module_; - std::string function_filename_; - int function_line_offset_; -}; - -// update the graph flags -bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); - -AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_PARSE_H_ diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/parse/python_adapter.cc deleted file mode 100644 index df2f7d0d45..0000000000 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/python_adapter.h" -#include -#include -#include - -namespace mindspore { -namespace parse { -namespace python_adapter { -// python scoped env, should only have one scoped_ instance -static std::shared_ptr scoped_ = nullptr; -// true: start process from python, false: start process from c++ -static bool python_env_ = false; -static bool use_signature_in_resolve_ = true; -void ResetPythonScope() { scoped_ = nullptr; } -void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_in_resolve_ = use_signature; } -bool UseSignatureInResolve() { return use_signature_in_resolve_; } -void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } -bool IsPythonEnv() { return python_env_; } -void SetPythonPath(const std::string &path) { - // load the python module path - (void)python_adapter::set_python_scoped(); - py::module sys = py::module::import("sys"); - py::list sys_path = sys.attr("path"); - - // check the path is exist? 
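
Aside: the loop that follows implements the comment above (i.e. "check whether the path already exists"): a linear scan of sys.path with an append only on a miss, so repeated SetPythonPath calls never create duplicate entries. A self-contained pybind11 sketch of the same dedup-append pattern (the path literal is hypothetical):

    #include <pybind11/embed.h>

    #include <string>

    namespace py = pybind11;

    // Append `path` to sys.path only if it is not already present -- the same
    // dedup contract as SetPythonPath in the deleted file.
    void AppendOnce(const std::string &path) {
      py::module sys = py::module::import("sys");
      py::list sys_path = sys.attr("path");
      for (const auto &entry : sys_path) {
        if (entry.cast<std::string>() == path) {
          return;  // already on sys.path; avoid a duplicate entry
        }
      }
      sys_path.append(path);
    }

    int main() {
      py::scoped_interpreter guard{};   // embedded interpreter, for the sketch only
      AppendOnce("/tmp/my_models");     // hypothetical path
      AppendOnce("/tmp/my_models");     // second call is a no-op
      return 0;
    }
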
- bool is_exist = false; - for (size_t i = 0; i < sys_path.size(); i++) { - std::string path_str = py::cast(sys_path[i]); - if (path_str == path) { - is_exist = true; - } - } - if (!is_exist) { - (void)sys_path.attr("append")(path.c_str()); - } -} - -std::shared_ptr set_python_scoped() { - // if start process from python, no need set the python scope. - if (!python_env_) { - if ((Py_IsInitialized() == 0) && (scoped_ == nullptr)) { - scoped_ = std::make_shared(); - } - } - return scoped_; -} - -// return the module of python -py::module GetPyModule(const std::string &module) { - if (!module.empty()) { - return py::module::import(module.c_str()); - } else { - return py::none(); - } -} - -// Get the obj of attr -py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { - if (!attr.empty() && !py::isinstance(obj)) { - if (py::hasattr(obj, attr.c_str())) { - return obj.attr(attr.c_str()); - } - MS_LOG(DEBUG) << "Obj have not the attr: " << attr; - } - return py::none(); -} - -py::object GetPyFn(const std::string &module, const std::string &name) { - (void)python_adapter::set_python_scoped(); - if (!module.empty() && !name.empty()) { - py::module mod = py::module::import(module.c_str()); - py::object fn = mod.attr(name.c_str()); - return fn; - } - return py::none(); -} - -} // namespace python_adapter -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.h b/mindspore/ccsrc/pipeline/parse/python_adapter.h deleted file mode 100644 index 98adcd4f73..0000000000 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_PYTHON_ADAPTER_H_ -#define PIPELINE_PARSE_PYTHON_ADAPTER_H_ -#include -#include -#include - -#include "pybind11/embed.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -#include "pipeline/parse/parse_base.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace parse { -// A utility to call python interface -namespace python_adapter { -py::module GetPyModule(const std::string &module); -py::object GetPyObjAttr(const py::object &obj, const std::string &attr); -template -py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { - if (!method.empty() && !py::isinstance(obj)) { - return obj.attr(method.c_str())(args...); - } - return py::none(); -} - -// call python function of module -template -py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { - if (!function.empty() && !py::isinstance(mod)) { - return mod.attr(function.c_str())(args...); - } - return py::none(); -} - -// turn off the signature when ut use parser to construct a graph. 
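
Aside: the pair of functions declared next exposes a process-wide toggle read during symbol resolution; unit tests that drive the parser directly switch it off and should restore the default afterwards. A sketch of that discipline with an RAII restorer (the guard type is hypothetical and not part of the deleted file; only the flag pair models the declarations below):

    #include <cassert>

    namespace python_adapter {  // simplified model of the flag pair declared below
    static bool use_signature_in_resolve_ = true;
    void set_use_signature_in_resolve(bool v) noexcept { use_signature_in_resolve_ = v; }
    bool UseSignatureInResolve() { return use_signature_in_resolve_; }
    }  // namespace python_adapter

    // Hypothetical RAII restorer so a test cannot forget to put the
    // process-wide default back.
    struct SignatureFlagGuard {
      bool saved_ = python_adapter::UseSignatureInResolve();
      explicit SignatureFlagGuard(bool v) { python_adapter::set_use_signature_in_resolve(v); }
      ~SignatureFlagGuard() { python_adapter::set_use_signature_in_resolve(saved_); }
    };

    int main() {
      {
        SignatureFlagGuard guard(false);  // a UT builds a graph without signature wrapping
        assert(!python_adapter::UseSignatureInResolve());
      }
      assert(python_adapter::UseSignatureInResolve());  // default restored on scope exit
      return 0;
    }
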
-void set_use_signature_in_resolve(bool use_signature) noexcept; -bool UseSignatureInResolve(); - -std::shared_ptr set_python_scoped(); -void ResetPythonScope(); -bool IsPythonEnv(); -void SetPythonPath(const std::string &path); -void set_python_env_flag(bool python_env) noexcept; -py::object GetPyFn(const std::string &module, const std::string &name); -// Call the python function -template -py::object CallPyFn(const std::string &module, const std::string &name, T... args) { - (void)set_python_scoped(); - if (!module.empty() && !name.empty()) { - py::module mod = py::module::import(module.c_str()); - py::object fn = mod.attr(name.c_str())(args...); - return fn; - } - return py::none(); -} -} // namespace python_adapter -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_PYTHON_ADAPTER_H_ diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc deleted file mode 100644 index b4b45c078a..0000000000 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ /dev/null @@ -1,320 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/parse/resolve.h" - -#include -#include -#include -#include - -#include "ir/param_value.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/python_adapter.h" -#include "utils/any.h" -#include "operator/ops.h" -#include "optimizer/opt.h" -#include "optimizer/irpass.h" -#include "./common.h" - -namespace mindspore { -namespace parse { -abstract::AbstractBasePtr ClassObject::ToAbstract() { - ClassPtr cls_ptr = ParseDataClass(obj()); - auto abs_scalar = std::make_shared(); - abs_scalar->set_type(std::make_shared()); - abs_scalar->set_value(cls_ptr); - - AbstractBasePtrList args_spec_list = {abs_scalar}; - auto func_ptr = std::make_shared(prim::kPrimMakeRecord); - return std::make_shared(func_ptr, args_spec_list); -} - -abstract::AbstractBasePtr ClassType::ToAbstract() { - auto abs_scalar = - std::make_shared(shared_from_base(), std::make_shared()); - AbstractBasePtrList args_spec_list = {abs_scalar}; - - auto func_ptr = std::make_shared(prim::kPrimCreateInstance); - auto ret_val = std::make_shared(func_ptr, args_spec_list); - ret_val->set_value_desc(ToString()); - return ret_val; -} - -// call python PYTHON_MOD_RESOLVE_FUNCTION interface to resolve the symbol in corresponding namespace -bool SymbolResolver::Resolve() { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - - py::object obj = namespace_->obj(); - std::string symbol = symbol_->symbol(); - if (py::isinstance(obj)) { - MS_LOG(ERROR) << "Unresolved symbol: " << symbol; - return false; - } - result_ = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_FUNCTION, obj, common::SafeCStr(symbol)); - return true; -} - -namespace { -// argument obj should be python Parameter object -// it will be converted to Parameter node here -AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, 
const py::object &obj) { - MS_EXCEPTION_IF_NULL(func_graph); - - // parameter object should not be none - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Resolve class Parameter error because obj is null."; - } - - if (!py::hasattr(obj, "name")) { - MS_LOG(EXCEPTION) << "Resolve class Parameter error: cannot find name attr for obj"; - } - - // get the parameter name from parameter object - auto name_attr = python_adapter::GetPyObjAttr(obj, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - - std::string param_name = py::cast(name_attr); - auto top_graph = Parser::GetTopFuncGraph(); - // if the parameter node has been created , return it - AnfNodePtr para_node = nullptr; - for (auto const ¶m : top_graph->parameters()) { - auto param_node = dyn_cast(param); - if (param_node != nullptr && param_node->name() == param_name) { - para_node = param; - break; - } - } - if (para_node == nullptr) { - auto node = top_graph->AddWeightParameter(param_name); - auto param_value = py::cast(python_adapter::GetPyObjAttr(obj, "_value")); - node->set_default_param(param_value); - // set_abstract for parameter - ValuePtr value = param_value->value(); - constexpr bool broaden = true; - node->set_abstract(abstract::FromValue(value, broaden)); - para_node = node; - } - auto iter = func_graph->make_ref_params().find(para_node); - if (iter == func_graph->make_ref_params().end()) { - AnfNodePtr value = GetMixedPrecisionCastHelp(func_graph, para_node); - - AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); - AnfNodePtr ref_key = NewValueNode(std::make_shared(param_name)); - AnfNodePtr ref_node = func_graph->NewCNode({make_ref, ref_key, value, para_node}); - func_graph->make_ref_params()[para_node] = ref_node; - func_graph->add_parameter_obj_node(ref_node); - return ref_node; - } else { - return iter->second; - } -} - -bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { - AnfNodePtr output = nullptr; - if (py::hasattr(obj, "__parameter__")) { - auto param = ResolveParameterObj(func_graph, obj); - if (param == nullptr) { - MS_LOG(ERROR) << "Resolve parameter object failed, got nullptr"; - return false; - } - MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString(); - - output = param; - } else if (py::hasattr(obj, "__parameter_tuple__")) { - auto tuple = obj.cast(); - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t it = 0; it < tuple.size(); ++it) { - AnfNodePtr out = nullptr; - bool success = ResolveObjectToNode(func_graph, tuple[it], &out); - if (!success) { - MS_LOG(ERROR) << "Resolve object to node failed"; - return false; - } - args.push_back(out); - } - output = NewCNode(args, func_graph); - } else { - ValuePtr convert_result = nullptr; - bool converted = ConvertData(obj, &convert_result, parse::python_adapter::UseSignatureInResolve()); - if (!converted) { - MS_LOG(ERROR) << "Convert data failed"; - return false; - } - MS_EXCEPTION_IF_NULL(convert_result); - output = NewValueNode(convert_result); - if (convert_result->isa()) { - output = GetMixedPrecisionCastHelp(func_graph, output); - } - } - *node = output; - return true; -} - -bool IsAllGraphInValueSequence(const std::vector &value_vec) { - for (auto &elem : value_vec) { - if (elem->isa() || elem->isa()) { - const auto &vec = GetValue>(elem); - auto is_graph = IsAllGraphInValueSequence(vec); - if (!is_graph) { - return false; - } - } else if (!elem->isa()) { 
- return false; - } - } - return true; -} - -AnfNodePtr TransformToMakeTupleNodes(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, - const std::vector &value_vec) { - std::vector nodes; - nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - for (auto &elem : value_vec) { - AnfNodePtr node = nullptr; - if (elem->isa() || elem->isa()) { - const auto &vec = GetValue>(elem); - node = TransformToMakeTupleNodes(manager, func_graph, vec); - } else if (elem->isa()) { - FuncGraphPtr new_fg = elem->cast(); - manager->AddFuncGraph(new_fg); - node = NewValueNode(new_fg); - } else { - MS_LOG(EXCEPTION) << "TransformToMakeTupleNodes error, expect funcgraph, got " << elem->ToString(); - } - nodes.emplace_back(node); - } - auto cnode = func_graph->NewCNode(nodes); - return cnode; -} - -// transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const FuncGraphPtr &func_graph, - const ValueNodePtr &value_node, AnfNodePtr *const transformed) { - MS_EXCEPTION_IF_NULL(value_node); - const auto &value_vec = GetValue>(value_node->value()); - if (!IsAllGraphInValueSequence(value_vec)) { - return false; - } - - // The celllist or ordered_cell will be parsed as valuetuple of const graph in it, - // So if has graph in list, try to replace the node with make tuple of graph value node. - // we do this because the graphmanger won't investigate the graph inside valuetuple, - // change the vector of graph to be make_tuple of graph value node - auto node_tuple_graphs = TransformToMakeTupleNodes(manager, func_graph, value_vec); - // replace the ret ptr to be make tuple of graph value node - *transformed = node_tuple_graphs; - - return true; -} -} // namespace - -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node) { - if (node->func_graph() == nullptr || manager == nullptr) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; - } - SymbolResolver symbol_resolver(name_space, symbol, node); - if (!symbol_resolver.Resolve()) { - MS_LOG(EXCEPTION) << "Parse Resolve node failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - - py::object obj = symbol_resolver.result(); - ScopeGuard scope_guard(node->scope()); - AnfNodePtr resolved_node = nullptr; - TraceManager::DebugTrace(std::make_shared(node->debug_info())); - bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node); - if (!success) { - MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - if (IsValueNode(resolved_node)) { - auto new_fg = GetValueNode(resolved_node); - manager->AddFuncGraph(new_fg); - } - - // if the constant node is constant of vector of graph ,add graph to manager - if (IsValueNode(resolved_node) || IsValueNode(resolved_node)) { - (void)TransformVectorGraphValueNode(manager, node->func_graph(), resolved_node->cast(), - &resolved_node); - } - - TraceManager::EndTrace(); - return resolved_node; -} - -namespace { -opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { - opt::OptPassGroupMap map({ - {"resolve", - { - // for resolve and getattr primitive; - irpass.resolver_resolve_, - irpass.resolver_getattr_, - }}, - }); - return map; -} -} // namespace - -bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { - if (func_graph == 
nullptr || res == nullptr) { - MS_LOG(ERROR) << "func_graph or resource is null"; - return false; - } - opt::irpass::ResolveIRPassLib irpass; - opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); - - (void)parse::python_adapter::set_python_scoped(); - - MS_EXCEPTION_IF_NULL(opt_resolve); - (void)opt_resolve->step(func_graph, use_profile); - return true; -} - -bool ResolveAll(const FuncGraphManagerPtr &manager) { - if (manager == nullptr) { - MS_LOG(ERROR) << "func graph manager is null"; - return false; - } - - if (manager->roots().size() > 1) { - MS_LOG(WARNING) - << "After call ResolveAll, only one graph will be kept in GraphManager. ResolveAll can resolve graphs" - "called from root graph, so it's not necessary to pass all graphs as roots. " - "Please ensure your usage."; - } - // should not use pipeline::Resource as Resource::Clean will clean some - // global variable such as ScopeManager, it will cause JExpandedGraphs::GetBprop - // fail as valid scope has been cleaned. - auto res = std::make_shared(); - res->set_manager(manager); - - auto roots = manager->roots(); - for (auto &fg : roots) { - bool ret = ResolveFuncGraph(fg, res, false); - if (!ret) { - MS_EXCEPTION_IF_NULL(fg); - MS_LOG(ERROR) << "Resolve fg " << fg->ToString() << " failed"; - } - } - return true; -} -} // namespace parse -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/resolve.h b/mindspore/ccsrc/pipeline/parse/resolve.h deleted file mode 100644 index a84b533bd0..0000000000 --- a/mindspore/ccsrc/pipeline/parse/resolve.h +++ /dev/null @@ -1,158 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_PARSE_RESOLVE_H_ -#define PIPELINE_PARSE_RESOLVE_H_ - -#include -#include -#include "ir/anf.h" -#include "ir/manager.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/parse/parse_base.h" -#include "abstract/abstract_value.h" -#include "utils/log_adapter.h" - -// forward declaration of ResourceBase -namespace mindspore { -namespace pipeline { -class ResourceBase; -using ResourceBasePtr = std::shared_ptr; -} // namespace pipeline -} // namespace mindspore - -namespace mindspore { -namespace parse { - -// NameSpace class for resolving python code. -class NameSpace : public Named { - public: - NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} - ~NameSpace() override = default; - MS_DECLARE_PARENT(NameSpace, Named); - - py::object obj() { return obj_; } - std::string module() { return module_; } - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - - private: - // namespace of the module - std::string module_; - // namespace object - py::object obj_; -}; -using NameSpacePtr = std::shared_ptr; - -// Symbol in NameSpace or Class which shall be resolved. 
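
Aside: resolution here is a two-level lookup: a NameSpace (a wrapped Python module or closure) plus the Symbol defined next, evaluated lazily when the resolver pass reaches the node. A standalone model of that lookup and its failure path (toy types, not the MindSpore API):

    #include <cassert>
    #include <map>
    #include <string>

    // Toy stand-ins: a NameSpace is a named bag of objects, a Symbol is a key in it.
    using Value = int;
    struct NameSpace {
      std::string module;
      std::map<std::string, Value> objects;
    };
    struct Symbol {
      std::string name;
    };

    // Models SymbolResolver::Resolve(): look the symbol up inside its namespace,
    // reporting failure for the "Unresolved symbol" path instead of throwing.
    bool Resolve(const NameSpace &ns, const Symbol &sym, Value *out) {
      auto it = ns.objects.find(sym.name);
      if (it == ns.objects.end()) return false;
      *out = it->second;
      return true;
    }

    int main() {
      NameSpace ns{"toy.module", {{"kTwo", 2}}};
      Value v = 0;
      assert(Resolve(ns, Symbol{"kTwo"}, &v) && v == 2);
      assert(!Resolve(ns, Symbol{"missing"}, &v));
      return 0;
    }
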
-class Symbol : public Named { - public: - explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} - explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} - - ~Symbol() override = default; - MS_DECLARE_PARENT(Symbol, Named); - - std::string symbol() { return symbol_; } - abstract::AbstractBasePtr ToAbstract() override { - return std::make_shared(shared_from_base(), std::make_shared()); - } - - private: - std::string symbol_; -}; -using SymbolPtr = std::shared_ptr; - -// PyObjectWrapper class wrappers resolved python object for further processing. -class PyObjectWrapper : public Named { - public: - explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} - ~PyObjectWrapper() override = default; - MS_DECLARE_PARENT(PyObjectWrapper, Named); - py::object obj() { return obj_; } - - private: - // the object that needs to be resolved - py::object obj_; -}; - -// ClassObject class wrappers dataclass -class ClassObject : public PyObjectWrapper { - public: - explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") - : PyObjectWrapper(obj, name) {} - ~ClassObject() override = default; - MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); - abstract::AbstractBasePtr ToAbstract() override; -}; - -// ClassType class wrappers class name in python -class ClassType : public PyObjectWrapper { - public: - explicit ClassType(const py::object &obj, const std::string name = "Python class type") - : PyObjectWrapper(obj, name) {} - ~ClassType() override = default; - MS_DECLARE_PARENT(ClassType, PyObjectWrapper); - abstract::AbstractBasePtr ToAbstract() override; -}; - -// SymbolResolver class for resolving symbol extracted from AnfNode. -class SymbolResolver { - public: - SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) - : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} - - ~SymbolResolver() = default; - - // resolve symbol in namespace and save it in result_; - bool Resolve(); - - NameSpacePtr get_namespace() { return namespace_; } - - SymbolPtr symbol() { return symbol_; } - - py::object &result() { return result_; } - - AnfNodePtr resolved_node() { return resolved_node_; } - - // Resolve result - py::object result_; - - private: - // namespace where the symbol locates - NameSpacePtr namespace_; - // the symbol that needs to be resovled - SymbolPtr symbol_; - // the node that has been resolved - AnfNodePtr resolved_node_; -}; -using SymbolResolverPtr = std::shared_ptr; -// Resolve symbol in namespace. -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, - const AnfNodePtr &node); - -// Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). -bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); - -// Resolve all graphs in manager which is defined outside of pipeline::Resource. -// Mainly used for test cases or resolve graphs which will not be managed by manager. 
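
Aside: per the implementation deleted above, ResolveAll warns when handed more than one root (resolution already reaches callees from the root graph) and keeps resolving the remaining roots after a per-graph failure. A toy model of that control flow:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    using Graph = std::string;  // toy stand-in for FuncGraphPtr

    // Models ResolveAll's control flow: warn on multiple roots, run one resolve
    // step per root, and log (rather than abort on) per-graph failures.
    bool ResolveAllModel(const std::vector<Graph> &roots,
                         const std::function<bool(const Graph &)> &resolve_one) {
      if (roots.size() > 1) {
        std::cerr << "warning: resolve reaches callees from the root graph;"
                     " passing every graph as a root is unnecessary\n";
      }
      for (const auto &g : roots) {
        if (!resolve_one(g)) std::cerr << "resolve failed for " << g << "\n";
      }
      return true;
    }

    int main() {
      ResolveAllModel({"root_graph"}, [](const Graph &) { return true; });
      return 0;
    }
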
-bool ResolveAll(const FuncGraphManagerPtr &manager); - -} // namespace parse -} // namespace mindspore - -#endif // PIPELINE_PARSE_RESOLVE_H_ diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc deleted file mode 100644 index abffc37bb2..0000000000 --- a/mindspore/ccsrc/pipeline/pass.cc +++ /dev/null @@ -1,340 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/pass.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/func_graph_cloner.h" -#include "debug/anf_ir_utils.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/resource.h" -#include "pipeline/validator.h" -#include "optimizer/optimizer.h" -#include "optimizer/cse.h" -#include "optimizer/graph_kernel_reuse.h" -#include "optimizer/clean.h" -#include "optimizer/irpass.h" -#include "optimizer/control_depend.h" -#include "parallel/step_parallel.h" -#include "parallel/step_auto_parallel.h" -#include "parallel/allreduce_fusion/step_allreduce_fusion.h" -#include "utils/any.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace pipeline { -using OptPassGroupMap = opt::OptPassGroupMap; -using Optimizer = opt::Optimizer; -using CompileGraphs = compile::CompileGraphs; -using abstract::AnalysisResult; -using mindspore::abstract::AnalysisContextPtr; -using mindspore::validator::Validate; - -bool SimplifyDataStructuresPass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - - FuncGraphPtr func_graph = res->func_graph(); - bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); - - abstract::AbstractBasePtrList args_spec; - auto parameters = func_graph->parameters(); - (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); - if (changed) { - FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); - res->set_func_graph(new_fg); - } - res->set_args_spec(args_spec); - return true; -} - -namespace { -OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig a_1 = opt::OptPassConfig({ - irpass.switch_simplify_, - - // Safe inlining - irpass.inline_, - irpass.partial_eliminate_, - irpass.replace_applicator_, - - // Specialization - irpass.specialize_transform_, - - // Miscellaneous - irpass.item_tuple_eliminate_, - irpass.env_get_item_eliminate_, - irpass.cast_eliminate_, - irpass.reshape_eliminate_, - irpass.reduce_eliminate_, - irpass.tile_eliminate_, - irpass.transpose_eliminate_, - irpass.minmaximum_grad_, - irpass.get_make_ref_eliminate_, - - // Arithmetic simplifications - irpass.arithmetic_simplify_, - irpass.addn_zero_filter_, - irpass.adjust_all_reduce_mul_add_, - - // Safe inlining - irpass.inline_, - }); - opt::OptPassConfig a_2 = opt::OptPassConfig({ - irpass.merge_addn_, - irpass.float_tuple_getitem_switch_, - 
irpass.float_env_getitem_switch_, - irpass.incorporate_getitem_set_, - irpass.incorporate_call_, - irpass.incorporate_call_switch_, - irpass.incorporate_env_getitem_, - irpass.incorporate_env_getitem_switch_, - irpass.new_env_get_item_, - irpass.depend_value_elim_, - }); - opt::OptPassConfig a_3 = opt::OptPassConfig({ - irpass.arithmetic_simplify2_, - irpass.same_eliminate_, - irpass.check_bprop_eliminate_, - irpass.replace_applicator_, - }); - opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); - opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); - opt::irpass::ResolveIRPassLib resolve_irpass; - - opt::OptPassConfig resolve_pass = - opt::OptPassConfig({resolve_irpass.resolver_resolve_, resolve_irpass.resolver_getattr_, - irpass.get_make_ref_eliminate_, irpass.replace_old_param_}); - - OptPassGroupMap map_a({{"a_1", a_1}, - {"a_2", a_2}, - {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)}, - {"parallel", opt::OptPassConfig(parallel::StepParallel)}, - {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, - {"virtual_dataset", virtual_dataset}, - {"grad", grad}, - {"resolve", resolve_pass}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - {"a_3", a_3}}); - - return map_a; -} - -OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = opt::OptPassConfig({ - irpass.zero_like_fill_zero_, - irpass.item_tuple_eliminate_, - irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, - irpass.inline_, - irpass.special_op_eliminate_, - irpass.get_make_ref_eliminate_, - }); - opt::OptPassConfig b_2 = opt::OptPassConfig({ - irpass.replace_refkey_by_param_, - irpass.make_ref_eliminate_, - irpass.get_ref_param_eliminate_, - irpass.indexed_slices_eliminate_, - }); - OptPassGroupMap map({ - {"b_1", b_1}, - {"b_2", b_2}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - }); - return map; -} - -OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig interface_fusion = opt::OptPassConfig({ - irpass.mark_interface_fusion_, - }); - OptPassGroupMap map({ - {"graph_kernel_reuse", opt::OptPassConfig(opt::GraphKernelReuse())}, - {"interface_fusion", interface_fusion}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"cse", opt::OptPassConfig(opt::CSE(false))}, - }); - return map; -} - -OptPassGroupMap GetOptPassesGraphKernelB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig elim_1 = opt::OptPassConfig({ - irpass.addn_eliminate_, - irpass.incorporate_getitem_from_param_, - }); - opt::OptPassConfig elim_2 = opt::OptPassConfig({ - irpass.unused_parameter_eliminate_, - irpass.unused_output_eliminate_, - }); - OptPassGroupMap map({ - {"elim_1", elim_1}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - {"elim_2", elim_2}, - }); - return map; -} - -OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &irpass) { - return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}}); -} - -OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}, true); - OptPassGroupMap map({ - {"control_group", control_group}, - {"renormalize", opt::OptPassConfig::Renormalize()}, - }); - return map; -} - -OptPassGroupMap GetInferenceOptPreparePhases() { 
- opt::irpass::InferenceOptPrepareLib irpass; - auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); - opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); - return prepare_map; -} - -OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); - OptPassGroupMap map({{"prepare_group", prepare_group}}); - return map; -} - -static std::unordered_map> g_pass_opts = {}; - -void InitOpt(const ResourcePtr &res) { - if (g_pass_opts.size() == 0) { - opt::irpass::OptimizeIRPassLib irpass; - g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); - g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true); - g_pass_opts["opt_graph_kernel_a"] = - Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true); - g_pass_opts["opt_graph_kernel_b"] = - Optimizer::MakeOptimizer("opt_graph_kernel_b", res, GetOptPassesGraphKernelB(irpass), false); - g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass)); - g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true); - g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - g_pass_opts["opt_graph_kernel_a"]->set_enable(false); - g_pass_opts["opt_graph_kernel_b"]->set_enable(false); - } - } -} -} // namespace - -void ReclaimOptimizer() { - for (auto &opt : g_pass_opts) { - opt.second = nullptr; - } - g_pass_opts.clear(); -} - -bool OptPassGroup(const ResourcePtr &res, const std::string &name) { - if (res->func_graph() == nullptr) { - MS_LOG(ERROR) << "Opt passes int error"; - return false; - } - - FuncGraphPtr func_graph = res->func_graph(); - MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", " - << func_graph->get_return()->DebugString(true); - InitOpt(res); - if (g_pass_opts.find(name) != g_pass_opts.end()) { - res->set_func_graph(g_pass_opts[name]->step(func_graph)); - } - // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to - // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here. 
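
Aside: OptPassGroup, wrapped up just below, is a name-based dispatch into the lazily built g_pass_opts registry above: the first call constructs every named optimizer, then each group runs by name and replaces the resource's graph with the optimizer's output. A compact model of that registry-plus-dispatch shape (toy Graph/Pass types):

    #include <cassert>
    #include <functional>
    #include <string>
    #include <unordered_map>

    using Graph = int;                                 // toy stand-in for FuncGraphPtr
    using Pass = std::function<Graph(Graph)>;

    // Lazily built name->optimizer registry, like InitOpt filling g_pass_opts.
    std::unordered_map<std::string, Pass> &Registry() {
      static std::unordered_map<std::string, Pass> passes;
      if (passes.empty()) {
        passes["opt_a"] = [](Graph g) { return g + 1; };
        passes["opt_b"] = [](Graph g) { return g * 2; };
      }
      return passes;
    }

    // Name-based dispatch, like OptPassGroup: unknown names are a silent no-op,
    // otherwise the graph is replaced by the optimizer's output.
    bool RunGroup(const std::string &name, Graph *g) {
      auto &passes = Registry();
      auto it = passes.find(name);
      if (it != passes.end()) *g = it->second(*g);
      return true;
    }

    int main() {
      Graph g = 0;
      RunGroup("opt_a", &g);   // g = 1
      RunGroup("opt_b", &g);   // g = 2
      assert(g == 2);
      return 0;
    }
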
- return true; -} - -bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } -bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } -bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); } -bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); } -bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } -bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } - -bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); } - -bool AddControlDependPass(const ResourcePtr &res) { - FuncGraphPtr func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - - if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(func_graph); - } - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(fg); - } - } - return true; -} - -bool CconvPass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - FuncGraphPtr func_graph = res->func_graph(); - FuncGraphPtr new_fg = LiftingClone(func_graph); - res->set_func_graph(new_fg); - return true; -} - -bool ValidatePass(const ResourcePtr &res) { - MS_EXCEPTION_IF_NULL(res->func_graph()); - FuncGraphPtr func_graph = res->func_graph(); - Validate(func_graph); - return true; -} - -bool InferenceOptPreparePass(const ResourcePtr &res) { - FuncGraphPtr func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - auto prepare_map = GetInferenceOptPreparePhases(); - auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); - (void)infer_opt_prepare->step(func_graph, false); - return true; -} - -std::vector kVmPasses = {{"opt_a", OptPassAGroup}, - {"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_b", OptPassBGroup}, - {"cconv", CconvPass}, - {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, - {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, - {"add_control_depend", AddControlDependPass}}; - -std::vector kGePasses = { - {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, - {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, - {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, - {"cconv", CconvPass}}; - -std::vector kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h deleted file mode 100644 index 9064df52ee..0000000000 --- a/mindspore/ccsrc/pipeline/pass.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PASS_H_ -#define MINDSPORE_CCSRC_PIPELINE_PASS_H_ - -#include -#include -#include -#include -#include "pipeline/resource.h" - -namespace mindspore { -namespace pipeline { -using PassItem = std::pair>; - -extern std::vector kGePasses; -extern std::vector kVmPasses; -extern std::vector kPynativePasses; - -bool CconvPass(const ResourcePtr &res); -bool ValidatePass(const ResourcePtr &res); -bool ConvertPrepareAdapt(const ResourcePtr &res); -bool AddControlDependPass(const ResourcePtr &res); -bool InferenceOptPreparePass(const ResourcePtr &res); -void ReclaimOptimizer(); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PASS_H_ diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc deleted file mode 100644 index 5325cc8249..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ /dev/null @@ -1,948 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/pipeline.h" - -#include -#include -#include -#include -#include - -#include "ir/param_value.h" -#include "pipeline/pass.h" -#include "pipeline/parse/data_converter.h" -#include "optimizer/ad/dfunctor.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "utils/config_manager.h" -#include "utils/convert_utils.h" -#include "utils/utils.h" -#include "vm/segment_runner.h" -#include "parallel/context.h" -#include "parallel/graph_util/get_parallel_info.h" -#include "device/kernel_runtime_manager.h" -#include "debug/trace.h" -#include "pynative/pynative_execute.h" -#include "optimizer/py_pass_manager.h" - -#if (ENABLE_GE || ENABLE_D) -#include "pipeline/pipeline_ge.h" -#include "transform/convert.h" -#include "transform/df_graph_manager.h" -#endif - -namespace mindspore { -// namespace to support intermediate representation definition -namespace pipeline { -using Tensor = mindspore::tensor::Tensor; -using MetaTensor = mindspore::tensor::MetaTensor; -using TensorOrderMap = std::map>; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTensorPtr; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; - -const char IR_TYPE_ANF[] = "anf_ir"; -const char IR_TYPE_ONNX[] = "onnx_ir"; -const char IR_TYPE_BINARY[] = "binary_ir"; - -ExecutorPyPtr ExecutorPy::executor_ = nullptr; -std::mutex ExecutorPy::instance_lock_; - -std::unordered_map - g_args_cache; - -namespace { -std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) { - std::ostringstream oss; - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(EXCEPTION) << "ms_context is nullptr"; - } - auto save_graphs_path = ms_context->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - oss << save_graphs_path << "/" << stage_idx << "_" << 
action_name; - return oss.str(); -} -} // namespace - -py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { - MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); - abstract::AbstractBasePtrList args_spec; - - for (auto arg : defaults) { - if (py::isinstance(arg.second)) { - MS_LOG(EXCEPTION) << "GenerateKey failed, argument input should not be py::module"; - } - ValuePtr converted = nullptr; - if (!parse::ConvertData(arg.second, &converted)) { - MS_LOG(EXCEPTION) << "GenerateKey convert arg failed"; - } - args_spec.push_back(abstract::FromValue(converted, true)); - } - if (g_args_cache.count(args_spec) == 0) { - static int key = 0; - MS_LOG(INFO) << "Start new args and compile key:" << key; - g_args_cache[args_spec] = key++; - } - auto argSpec = py::tuple(2); - argSpec[0] = name; - argSpec[1] = g_args_cache[args_spec]; - return argSpec; -} - -py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs) { - MS_LOG(DEBUG) << "Verify args size:" << inputs.size(); - if (inputs.size() != input_signature.size()) { - MS_LOG(ERROR) << "Signature size not equal to args size"; - return false; - } - - size_t count = 0; - for (auto arg_obj : inputs) { - if (py::hasattr(arg_obj, PYTHON_TENSOR_FLAG)) { - MS_LOG(DEBUG) << "Verify Tensor"; - std::shared_ptr m_tensor = arg_obj.cast>(); - if (m_tensor == nullptr) { - MS_LOG(ERROR) << "Verify Tensor error, get ptr is null"; - return false; - } - std::shared_ptr sig = input_signature[count].cast>(); - std::vector sig_shape = sig->shape(); - TypePtr sig_type = sig->Dtype(); - - std::vector tensor_shape = m_tensor->shape_c(); - if (tensor_shape != sig_shape) { - MS_LOG(ERROR) << "Python input shape is incompatible with input_signature"; - return false; - } - - if (*m_tensor->Dtype() != *sig_type) { - MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature(" - << sig_type->ToString() << ")"; - return false; - } - } - count++; - } - - return true; -} - -ExecutorPy::ExecutorPy() {} - -ResourcePtr ExecutorPy::GetResource(const std::string &phase) { - MS_LOG(DEBUG) << "Phase size:" << info_.size(); - if (info_.count(phase) == 0) { - return nullptr; - } - return info_[phase]->resource; -} - -FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { - if (info_.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - return info_[phase]->func_graph; -} - -std::size_t ExecutorPy::ArgListSize(const std::string &phase) { - if (info_.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - return info_[phase]->arg_list_size; -} - -compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { - ResourcePtr res = GetResource(phase); - MS_EXCEPTION_IF_NULL(res); - if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { - return res->results()[kOutput].cast(); - } - MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput; - return nullptr; -} - -bool ExecutorPy::HasCompiled(const std::string &phase) const { - if (info_.count(phase) == 0) { - return false; - } - return true; -} - -py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { - FuncGraphPtr fg_ptr = GetFuncGraph(phase); - if (fg_ptr == nullptr) { - for (auto &item : info_) { - MS_LOG(DEBUG) << "Phase key is: " << item.first; - } - MS_LOG(EXCEPTION) << "Can not find func graph " << phase; - } - - if 
(ir_type == IR_TYPE_ANF) { - std::string proto_str = GetFuncGraphProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - if (ir_type == IR_TYPE_ONNX) { - std::string proto_str = GetOnnxProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - if (ir_type == IR_TYPE_BINARY) { - std::string proto_str = GetBinaryProtoString(fg_ptr); - if (proto_str.empty()) { - MS_LOG(EXCEPTION) << "Graph proto is empty."; - } - return proto_str; - } - - MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; -} - -py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { - MS_LOG(DEBUG) << "GetParameterLayout!"; - std::string layout_graph = phase + kStepParallelGraph; - auto graph = GetFuncGraph(layout_graph); - return mindspore::parallel::GetParameterLayout(graph); -} - -py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { - MS_LOG(DEBUG) << "GetCNodeStrategy!"; - std::string layout_graph = phase + kStepParallelGraph; - auto graph = GetFuncGraph(layout_graph); - return mindspore::parallel::GetCNodeStrategy(graph); -} - -py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { - MS_LOG(INFO) << "GetAllreduceFusion!"; - auto graph = GetFuncGraph(phase); - return mindspore::parallel::GetAllreduceFusion(graph); -} - -void ExecutorPy::DelNetRes(const std::string &id) { -#ifdef ENABLE_GE - FinalizeBackend(); -#endif - if (executor_ != nullptr) { - bool flag = false; - auto tmp_info = info_; - for (auto &item : tmp_info) { - if (item.first.find(id) != string::npos) { - MS_LOG(DEBUG) << "Delete network res:" << item.first; - (void)info_.erase(item.first); - flag = true; - } - } - - MS_LOG(DEBUG) << "Delete flag:" << flag; -#ifdef ENABLE_GE - if (flag && info_.size() == 0) { - // because Ge only support one Session exist at the same time ,so we delete the old one - transform::DfGraphManager::GetInstance().DeleteGraphRunner(); - transform::DfGraphManager::GetInstance().EraseAnfGraph(); - transform::DfGraphManager::GetInstance().DeleteGeSession(); - } -#endif - } -} - -void ExecutorPy::ClearRes() { - MS_LOG(INFO) << "Clean executor resource!"; - executor_ = nullptr; -} - -ExecutorPy::~ExecutorPy() { - MS_LOG(INFO) << "Release Executor!"; - ConfigManager::GetInstance().ResetConfig(); -} - -std::map> ExecutorPy::FetchInfoForQuantExport( - const std::string &phase_s) { - FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; - std::map> fake_quant_table; - auto filter = [](AnfNodePtr node) { - return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) || - IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative)); - }; - std::vector nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter); - auto is_quant_cnode = [](AnfNodePtr node) { - return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) || - IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel); - }; - for (auto node : nodes) { - auto cnode = node->cast(); - if (cnode == nullptr || cnode->size() != 3) { - continue; - } - auto x = cnode->input(1); - auto weight = cnode->input(2); - if (!is_quant_cnode(weight)) { - continue; - } - // get parameter weight's name - cnode = weight->cast(); - auto weight_node = cnode->input(2); - if (!weight_node->isa()) { - continue; 
- } - auto weight_name = weight_node->cast()->name(); - // find the fakequant from input - int count = 0; - const int max_depth = 5; - while (!is_quant_cnode(x)) { - if (count >= max_depth) { - break; - } - cnode = x->cast(); - if (cnode == nullptr || cnode->size() <= 1) { - break; - } - x = cnode->input(1); - count += 1; - } - if (x->isa()) { - fake_quant_table[weight_name] = std::make_pair(nullptr, "input"); - } - // get the fakequant parameter minq's name - if (!is_quant_cnode(x)) { - continue; - } - cnode = x->cast(); - if (cnode == nullptr || cnode->size() != 4) { - continue; - } - auto fakequant_min_node = cnode->input(2); - if (!fakequant_min_node->isa()) { - continue; - } - auto fakequant_min_node_name = fakequant_min_node->cast()->name(); - auto quant_op_value = cnode->input(0)->cast()->value(); - if (!quant_op_value->isa()) { - continue; - } - auto quant_op = quant_op_value->cast(); - fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name); - } - - return fake_quant_table; -} - -void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { - // save the graph to ExecutorPy - FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); - std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); - - MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!"; - info_[phase_s]->func_graph = func_graph; - if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) && - ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) { - MS_LOG(DEBUG) << "Save model parallel parameter layout graph!"; - func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast(); - ExecutorInfoPtr executor_info = std::make_shared(); - std::string layout_graph = phase_s + kStepParallelGraph; - executor_info->func_graph = func_graph; - info_[layout_graph] = executor_info; - } else { - MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!"; - } - MS_LOG(INFO) << "End save compiled func graph!"; -} - -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { - std::string phase_prefix = GetPhasePrefix(phase_s); - - if (use_vm && phase_prefix == "export") { - MS_LOG(INFO) << "Use ge backend to export geir"; - use_vm = false; - } - return use_vm; -} - -void ExecutorPy::GetGeBackendPolicy() const { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - std::string backend = ms_context->backend_policy(); - if (backend != "ge") { - MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!"; - } -} - -bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { - MS_LOG(DEBUG) << "Start ExecutorPy compile!"; - if ((!py::isinstance(phase))) { - MS_LOG(ERROR) << "Arg phase must be string."; - return false; - } - // check the arg valid? 
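
Aside: the checks that follow guard the compile entry point (the phase must be a Python string and the object to compile must not be None) before any Resource is allocated; further down in the same function, each positional argument is converted and broadened into an abstract spec that keys the compiled-graph cache. A toy model of that validate-then-broaden shape (field and phase names are illustrative only):

    #include <cassert>
    #include <optional>
    #include <string>
    #include <vector>

    // Toy argument record and its "broadened" spec: shape and dtype are kept,
    // the concrete value is dropped, mirroring abstract::FromValue(..., broaden).
    struct Arg { std::string dtype; std::vector<int> shape; int value; };
    struct ArgSpec { std::string dtype; std::vector<int> shape; };

    // Models CompileInner's entry checks: reject a non-string phase or a None
    // object up front, then broaden each positional argument for the cache key.
    std::optional<std::vector<ArgSpec>> CheckAndBroaden(const std::string *phase, bool obj_is_none,
                                                        const std::vector<Arg> &args) {
      if (phase == nullptr) return std::nullopt;  // "Arg phase must be string."
      if (obj_is_none) return std::nullopt;       // "parse obj is None."
      std::vector<ArgSpec> specs;
      for (const auto &a : args) specs.push_back({a.dtype, a.shape});
      return specs;
    }

    int main() {
      std::string phase = "train.0";  // hypothetical phase name
      auto specs = CheckAndBroaden(&phase, /*obj_is_none=*/false, {{"float32", {2, 3}, 0}});
      assert(specs && specs->size() == 1);
      assert(!CheckAndBroaden(nullptr, /*obj_is_none=*/false, {}));
      return 0;
    }
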
- if (py::isinstance(obj)) { - MS_LOG(ERROR) << "Find error: parse obj is None."; - return false; - } -#ifdef ENABLE_GE - GetGeBackendPolicy(); -#endif - ExecutorInfoPtr executor_info = std::make_shared(); - std::string phase_s = py::cast(phase); - MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!"; - ResourcePtr resource = std::make_shared(obj); - std::vector p_actions; - - use_vm = ChangeExportGeirUseVmFlag(use_vm, phase_s); - - std::string backend = MsContext::GetInstance()->backend_policy(); - if (use_vm && backend != "ge") { - // Create backend and session - auto backend_ptr = compile::CreateBackend(); - // Connect session to debugger - backend_ptr->SetDebugger(); - resource->results()[kBackend] = backend_ptr; - p_actions = VmPipeline(); - } else { - p_actions = GePipeline(); - } - - std::shared_ptr pip = std::make_shared(resource, FilterActions(p_actions, phase_s)); - - // get the parameters items and add the value to args_spec - abstract::AbstractBasePtrList args_spec; - std::size_t size = args.size(); - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - bool broaden = true; - args_spec.push_back(abstract::FromValue(converted, broaden)); - } - - resource->set_args_spec(args_spec); - executor_info->arg_list_size = size; - executor_info->resource = resource; - info_[phase_s] = executor_info; - pip->Run(); - - // save the run graph func to MsPipeLine - SaveCompiledGraph(phase_s); - - resource->Clean(); - // Reclaim all resource used by optimizer; - ReclaimOptimizer(); - - MS_LOG(INFO) << "End ExecutorPy compile!"; - return true; -} - -std::vector ExecutorPy::FilterActions(const std::vector &actions, const std::string &phase) { - // phase does not contain 'export_onnx' - if (GetPhasePrefix(phase).find("export_onnx") == std::string::npos) { - return actions; - } - MS_LOG(INFO) << "Phase is '" << phase << "', filter out actions after stage 'validate'"; - std::vector filtered_actions; - for (const auto &item : actions) { - filtered_actions.emplace_back(item); - if (item.first == "validate") { - break; - } - } - return filtered_actions; -} - -void ExecutorPy::ReleaseResource(const py::object &phase) { - ResourcePtr res = GetResource(py::cast(phase)); - if (res != nullptr) { - res->Clean(); - } - // Reclaim all resource used by optimizer; - ReclaimOptimizer(); -} - -static std::string PrintArgs(const py::tuple &args) { - py::print(args); - return ""; -} - -bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { - bool ret_value = false; - - try { - MS_LOG(DEBUG) << PrintArgs(args); - ret_value = CompileInner(obj, args, phase, use_vm); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - ReleaseResource(phase); - - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - ReleaseResource(phase); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - ReleaseResource(phase); - throw py::value_error(ex); - } catch (const 
py::index_error &ex) { - ReleaseResource(phase); - throw py::index_error(ex); - } catch (const std::exception &ex) { - ReleaseResource(phase); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - ReleaseResource(phase); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } - - return ret_value; -} - -#ifdef ENABLE_LOAD_ANF_IR -// get MindSpore Intermediate Representation File -std::string GetMsIrFile(void) { - std::string file; - const char *path = getenv("MS_IR_FILE"); - if (path == nullptr) { - return file; - } - - char real_path[PATH_MAX] = {0}; - if (realpath(path, real_path) == nullptr) { - MS_LOG(ERROR) << "MS IR path error, " << path; - return file; - } - file = real_path; - return file; -} - -void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { - MS_EXCEPTION_IF_NULL(resource); - MS_EXCEPTION_IF_NULL(result); - - std::string ir_file = GetMsIrFile(); - (void)parse::python_adapter::set_python_scoped(); - if (ir_file.empty()) { - *result = action.second(resource); - return; - } - - // when in loading anf ir mode, action `parse` do nothing - if (action.first == "parse") { - return; - } - - // load MindSpore IR from file - if (action.first == "symbol_resolve") { - MS_LOG(DEBUG) << action.first << " read ir file: " << ir_file; - std::vector graphs = ImportIR(ir_file); - if (graphs.size() == 0) { - MS_LOG(EXCEPTION) << action.first << " read ir file " << ir_file << " failed as no graph found"; - } - auto manager = resource->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (auto &graph : graphs) { - manager->AddFuncGraph(graph); - } - resource->set_func_graph(graphs[0]); - return; - } - - // do normal action when not in `parse` and `symbol_resolve` stage - *result = action.second(resource); -} -#endif - -void Pipeline::Run() { - MS_LOG(INFO) << "Pipeline run"; - MS_EXCEPTION_IF_NULL(resource_); - FuncGraphPtr user_graph = nullptr; - - WITH(MsProfile::GetProfile())[&user_graph, this]() { - int i = 0; - for (auto &action : actions_) { -#ifdef ENABLE_TIMELINE - DumpTime &dump_time = DumpTime::GetInstance(); - dump_time.Record(action.first, GetTime(), true); -#endif - bool result = true; - WITH(MsProfile::GetProfile()->Step(action.first))[&result, &action, this]() { - MS_LOG(DEBUG) << "Action " << action.first << " start ..."; -#ifdef ENABLE_LOAD_ANF_IR - RunPipelineAction(action, resource_, &result); -#else - result = action.second(resource_); -#endif - MS_LOG(DEBUG) << "Action " << action.first << " end."; - }; - if (!result) { - MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first; - } - if (MsContext::GetInstance()->save_graphs_flag() && resource_->func_graph() != nullptr) { - auto graph = resource_->func_graph(); - if (graph != nullptr) { - user_graph = graph; - std::string base_name = GetBaseNameForIR(i, action.first); - - // generate IR file in dot format, which can be converted to svg file using graphviz dot command - draw::Draw(base_name + ".dot", graph); - // generate IR file in human readable format - DumpIR(base_name + ".ir", graph); - // generate IR file in a heavily commented format, which can also be reloaded - ExportIR(base_name + ".dat", std::to_string(i), graph); - } -#ifdef MS_DEBUG - // Dump graph cnode list - MS_LOG(INFO) << "Show CNode list after " << action.first; - graph->DumpCNodeList(); -#endif - } - if 
(resource_->func_graph() != nullptr) { - auto func_graph = resource_->func_graph(); - if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - func_graph->EraseUnusedNodeInOrder(); - func_graph->CheckOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << "."; - fg->EraseUnusedNodeInOrder(); - fg->CheckOrder(); - } - } - } - i++; -#ifdef ENABLE_TIMELINE - dump_time.Record(action.first, GetTime(), false); -#endif - } - }; -#ifdef ENABLE_PROFILE - MsProfile::Print(); - MsProfile::Reset(); -#endif - - if (MsContext::GetInstance()->save_graphs_flag() && (user_graph != nullptr)) { - std::string user_graph_file = GetFilePathName("ModelDigraph.dot"); - MS_LOG(DEBUG) << "Save user graph to: " << user_graph_file; - draw::DrawUserFuncGraph(user_graph_file, user_graph); - } - MS_LOG(INFO) << "End"; -} - -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { - std::size_t size = args.size(); - - for (std::size_t i = 0; i < size; i++) { - py::object arg = args[i]; - auto ms_context = MsContext::GetInstance(); - if (ms_context->backend_policy() == kMsConvert && py::isinstance(arg)) { - MS_LOG(EXCEPTION) << "The " << i << "th arg is numpy array, not tensor."; - } - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(arg, &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; - } - if (MsContext::GetInstance()->execution_mode() == 0 && !converted->isa()) { - MS_EXCEPTION(TypeError) << "For 'graph mode', the " << i << "th arg: " << converted->ToString() - << " is not tensor."; - } - arg_list->push_back(converted); - } - - MS_EXCEPTION_IF_NULL(res); - auto graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(graph); - std::vector graph_params = graph->parameters(); - std::size_t graph_params_size = graph_params.size(); - if ((*arg_list).size() != graph_params_size) { - // maybe some default parameter - for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) { - MS_EXCEPTION_IF_NULL(graph_params[i]); - auto param_ptr = (graph_params[i])->cast(); - if (!param_ptr->has_default()) { - MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; - } - arg_list->push_back(param_ptr->default_param()->value()); - } - } -} - -void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) { - ProcessVmArgInner(args, GetResource(phase), arg_list); -} - -py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { - std::size_t size = args.size(); - if (!py::isinstance(phase)) { - MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; - } - auto phase_s = py::cast(phase); - std::string backend = MsContext::GetInstance()->backend_policy(); -#ifdef ENABLE_GE - if (backend == "ge") { - return ExecDFGraph(info_, args, phase_s); - } -#else - if (backend == "ms" || backend == "ge") { - auto ret_val = std::make_shared(); - if (info_.count(phase_s) != 0 && info_[phase_s]->func_graph != nullptr) { - if (IsGraphOutputValueNodeOrParameter(info_[phase_s]->func_graph->output(), args, ret_val)) { - return *ret_val; - } - } - if (backend == "ge") { - if (args.size() > 0) { - return args[0]; - } - return args; - } - } -#endif - std::size_t full_arg_size = ArgListSize(phase_s); - if (size > full_arg_size) { - MS_LOG(WARNING) << "The arg num : size = " << size << ". 
full_arg_size = " << full_arg_size; - } - VectorRef arg_list; - ProcessVmArg(args, phase_s, &arg_list); - - compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); - if (run == nullptr) { - MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; - } - - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); - MS_LOG(DEBUG) << "Run end"; - return BaseRefToPyData(value); -} - -FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, - const py::object &broadcast_params) { -#if (ENABLE_GE || ENABLE_D) - return BuildDFGraph(info_, init_params, phase, broadcast_params); -#else - return nullptr; -#endif -} - -void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { -#if ENABLE_GE - RunGEInitGraph(init_params, phase); -#endif -} - -bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase, bool need_run) { - std::string name = MsContext::GetInstance()->backend_policy(); -#ifndef NO_DLIB - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->IsTsdOpened() || !ms_context->IsGeInited()) { - (void)InitBackend(); - } -#endif - if (name == kMsConvert || name == kMsVm) { - return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run); - } -#if ENABLE_GE - return InitExecDatasetGe(queue_name, iter_num, batch_size, types, shapes, input_indexes, phase); -#else - std::string backend = MsContext::GetInstance()->backend_policy(); - if (backend == "ge") { - return true; - } -#endif - return false; -} - -bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, bool need_run) { - MS_LOG(INFO) << "Start InitDataSet Entry"; - std::vector int_input_indexes; - (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), - [](int64_t item) { return static_cast(item); }); - std::vector> int_shapes; - (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), - [](const std::vector &item) { - std::vector vector_item; - (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), - [](int64_t inner_item) { return static_cast(inner_item); }); - return vector_item; - }); - auto p_init = std::make_shared("InitDataSetQueue"); - p_init->set_attr("queue_name", MakeValue(queue_name)); - p_init->set_attr("size", MakeValue(static_cast(size))); - p_init->set_attr("batch_size", MakeValue(static_cast(batch_size))); - p_init->set_attr("types", MakeValue(types)); - p_init->set_attr("shapes", MakeValue(int_shapes)); - p_init->set_attr("input_indexes", MakeValue(int_input_indexes)); - - const std::vector empty_str_list; - p_init->set_attr("input_names", MakeValue(empty_str_list)); - p_init->set_attr("output_names", MakeValue(empty_str_list)); - - FuncGraphPtr func_graph = std::make_shared(); - auto app_init = std::make_shared(AnfNodePtrList{NewValueNode(p_init)}, func_graph); - func_graph->set_output(app_init); - auto manager = MakeManager(); - manager->AddFuncGraph(func_graph); - - // AbstractNone indicates there is no output for this apply node. 
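// For context on the line below: app_init is annotated with AbstractNone
// because InitDataSetQueue yields no data output; the graph exists purely for
// its side effect. The attribute plumbing above follows a simple
// name-plus-attribute-bag shape, sketched here with hypothetical ToyPrim/Attr
// types rather than MindSpore's Primitive:
#include <map>
#include <string>
#include <variant>
#include <vector>

using Attr = std::variant<int, std::string, std::vector<int>>;

struct ToyPrim {
  std::string name;
  std::map<std::string, Attr> attrs;
  void set_attr(const std::string &key, Attr value) { attrs[key] = std::move(value); }
};

ToyPrim MakeInitDataSetQueue(const std::string &queue_name, int size, int batch_size) {
  ToyPrim p{"InitDataSetQueue", {}};
  p.set_attr("queue_name", queue_name);
  p.set_attr("size", size);
  p.set_attr("batch_size", batch_size);
  return p;
}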
- auto abstract_none = std::make_shared(); - app_init->set_abstract(abstract_none); - - auto backend = compile::CreateBackend(); - MS_EXCEPTION_IF_NULL(backend); - auto convert_fn = backend->convert_fn(); - MS_EXCEPTION_IF_NULL(convert_fn); - // Convert CNodeList to LinConvertResult. - ConfigManager::GetInstance().set_iter_num(1); - auto runner = convert_fn({app_init}, ""); - if (MsContext::GetInstance()->execution_mode() != kPynativeMode) { - backend->Link(runner.graph_id); - } - ConfigManager::GetInstance().set_iter_num(size); - - if (!(*runner.run)) { - // empty function - MS_LOG(EXCEPTION) << "Backend " << backend->name() << " unsupported tdt dataset."; - } - - // launch init dataset runner without inputs and outputs - VectorRef args; - auto fn = runner.run; - if (need_run) { - (void)(*fn)(args); - } - MS_LOG(DEBUG) << "InitDataSetVm End."; - return true; -} - -void ResetOpId() { mindspore::id_generator::reset_id(); } - -void InitHccl() { -#ifdef ENABLE_GE - (void)InitBackend(); -#else - mindspore::parse::python_adapter::set_python_env_flag(true); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - (void)ms_context->OpenTsd(); - uint32_t device_id = ms_context->device_id(); - std::string device_name = ms_context->device_target(); - ms_context->set_enable_hccl(true); - if (ms_context->backend_policy() == "ms" && ms_context->device_target() == kAscendDevice) { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_name, device_id); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(ERROR) << "Kernel runtime init error."; - return; - } - } -#endif -} - -void FinalizeHccl() { -#ifdef ENABLE_GE - (void)FinalizeBackend(); -#else - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); -#endif -} - -void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { -#if (ENABLE_GE || ENABLE_D) - ExportDFGraph(file_name, phase); -#endif - MS_LOG(WARNING) << "In ut test no export_graph"; -} - -void ReleaseGeTsd() { - auto context_ptr = MsContext::GetInstance(); - if (context_ptr != nullptr) { - (void)context_ptr->FinalizeGe(true); - (void)context_ptr->CloseTsd(true); - } -} - -void InitBackend() { - // set python env flag - mindspore::parse::python_adapter::set_python_env_flag(true); - // open tsd before ge initialize - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (!ms_context->OpenTsd()) { - MS_LOG(EXCEPTION) << "Open tsd failed"; - } - (void)ms_context->InitGe(); -} - -void FinalizeBackend() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - (void)context_ptr->FinalizeGe(); - (void)context_ptr->CloseTsd(); -} - -void ClearResAtexit() { - MS_LOG(DEBUG) << "Pipeline clear all resource"; - pynative::ClearPyNativeSession(); - session::ClearPythonParasMap(); - device::KernelRuntimeManager::Instance().ClearRuntimeResource(); - - ad::g_k_prims.clear(); - - abstract::ClearPrimEvaluatorMap(); - compile::ClearConvertCache(); - pipeline::GetMethodMap().clear(); - pipeline::ExecutorPy::ClearRes(); - pipeline::ReclaimOptimizer(); - pynative::PynativeExecutor::GetInstance()->ClearRes(); - opt::python_pass::PyPassManager::GetInstance()->ClearRes(); -#ifdef ENABLE_GE - transform::DfGraphManager::GetInstance().ClearGraph(); - transform::DfGraphConvertor::get_adpt_map().clear(); -#endif - ReleaseGeTsd(); - parse::python_adapter::ResetPythonScope(); -} -} // namespace pipeline -} // namespace 
mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h deleted file mode 100644 index 58456c4d3b..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "utils/base_ref_extends.h" -#include "debug/draw.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "pipeline/action.h" -#include "vm/segment_runner.h" -#include "vm/transform.h" -#include "pipeline/base.h" - -namespace mindspore { -extern const char kMsConvert[]; -extern const char kMsVm[]; - -// namespace to support pipeline structures definition -namespace pipeline { - -namespace py = pybind11; - -class Pipeline { - public: - Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} - - ~Pipeline() = default; - - void Run(); - - ResourcePtr resource() { return resource_; } - - private: - ResourcePtr resource_; - std::vector actions_; -}; - -// A function pipeline. -class ExecutorPy : public std::enable_shared_from_this { - public: - static std::shared_ptr GetInstance() { - std::lock_guard i_lock(instance_lock_); - if (executor_ == nullptr) { - executor_ = std::shared_ptr(new (std::nothrow) ExecutorPy()); - } - return executor_; - } - - ~ExecutorPy(); - - void SaveCompiledGraph(const std::string &phase_s); - bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - - void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); - - // for pynative mode when use_vm is on - py::object Run(const py::tuple &args, const py::object &phase); - ResourcePtr GetResource(const std::string &phase); - FuncGraphPtr GetFuncGraph(const std::string &phase); - py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); - std::size_t ArgListSize(const std::string &phase); - compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); - bool HasCompiled(const std::string &phase) const; - - FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, - const py::object &broadcast_params = {}); - void RunInitGraph(const py::dict &init_params, const std::string &phase); - py::dict GetParameterLayout(const std::string &phase); - py::dict GetCNodeStrategy(const std::string &phase); - py::dict GetAllreduceFusion(const std::string &phase); - void DelNetRes(const std::string &id); - void ReleaseResource(const py::object &phase); - static void ClearRes(); - - std::map> FetchInfoForQuantExport(const std::string &phase_s); - - private: - ExecutorPy(); - void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); - bool 
ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; - void GetGeBackendPolicy() const; - // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after - // 'validate' stage - static std::vector FilterActions(const std::vector &actions, const std::string &phase); - - std::map info_; - static std::shared_ptr executor_; - static std::mutex instance_lock_; -}; -using ExecutorPyPtr = std::shared_ptr; - -// Generate a key for mapping function graph -py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); -py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); - -bool InitDistribute(const std::map &options); - -void ResetOpId(); -void InitHccl(); -void FinalizeHccl(); -void InitBackend(); -void FinalizeBackend(); - -void ClearResAtexit(); -void ReleaseGeTsd(); - -void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); - -// init and exec dataset sub graph -bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase, bool need_run); - -// Build and run dataset subgraph for ms backend -bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, bool need_run); - -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); - -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_H_ diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc deleted file mode 100644 index ffc907f698..0000000000 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ /dev/null @@ -1,535 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pipeline/pipeline_ge.h" - -#include -#include -#include -#include -#include - -#include "debug/anf_ir_dump.h" -#include "ir/tensor.h" -#include "transform/convert.h" -#include "transform/df_graph_manager.h" -#include "transform/graph_builder.h" -#include "transform/graph_runner.h" -#include "debug/draw.h" -#include "abstract/abstract_value.h" - -namespace mindspore { -namespace pipeline { -using Tensor = mindspore::tensor::Tensor; -using MetaTensor = mindspore::tensor::MetaTensor; -using TensorOrderMap = std::map>; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractTuplePtr; -using mindspore::transform::DfGraphConvertor; -using mindspore::transform::DfGraphManager; -using mindspore::transform::GeTensorPtr; -using mindspore::transform::MeTensorPtr; -using mindspore::transform::Status; -using mindspore::transform::TransformUtil; - -void DoExecNonInputGraph(const std::string &phase) { - std::vector ge_tensors; - std::vector ge_outputs; - transform::RunOptions run_options; - run_options.name = phase; - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(ERROR) << "Can not found GraphRunner"; - return; - } - - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "Exec graph:" << run_options.name << " failed"; - return; - } - } -} - -void SetGeOption(const std::map &options) { - ConfigManager::GetInstance().set_ge_initialize_options(options); -} - -Status CreateSessionAndGraphRunner(bool is_training = true) { - std::shared_ptr sess = DfGraphManager::GetInstance().GetGeSession(); - if (sess == nullptr) { - transform::SessionOptions options; - if (is_training) { - options["ge.trainFlag"] = "1"; - options["ge.streamNum"] = "100"; - options["ge.enabledLocalFmkop"] = "1"; - options["ge.hcomParallel"] = "1"; - } else { - options["ge.trainFlag"] = "0"; - } - - options["ge.enablePrintOpPass"] = "0"; - sess = transform::GraphRunner::NewSession(options); - if (sess == nullptr) { - MS_LOG(ERROR) << "Init data graph failed, because of create Ge session failed"; - return Status::FAILED; - } else { - DfGraphManager::GetInstance().SetGeSession(sess); - } - } - - transform::GraphRunnerOptions options; - options.sess_ptr = sess; - auto graph_runner = std::make_shared(options); - if (graph_runner == nullptr) { - MS_LOG(ERROR) << "Create new graph runner failed"; - return Status::FAILED; - } else { - DfGraphManager::GetInstance().SetGraphRunner(graph_runner); - } - - return Status::SUCCESS; -} - -bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase) { - std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { - return transform::TransformUtil::ConvertDataType(i->type_id()); - }); - - ConfigManager::GetInstance().set_dataset_mode(DatasetMode::DS_SINK_MODE); - ConfigManager::GetInstance().set_iter_num(size); - ConfigManager::GetInstance().set_dataset_phase(phase); - - DatasetGraphParam param(queue_name, size, batch_size, ge_types, shapes, input_indexes); - ConfigManager::GetInstance().set_dataset_param(param); - - if (transform::BuildDatasetGraph(param, phase) != 
transform::SUCCESS) { - MS_LOG(ERROR) << "Build dateset graph failed."; - return false; - } - -#if ENABLE_TRAIN - (void)setenv("GE_TRAIN", "1", 1); -#else - (void)setenv("GE_TRAIN", "0", 1); -#endif - - if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { - MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; - return false; - } - - MS_LOG(INFO) << "DoExecNonInputGraph:" << phase; - DoExecNonInputGraph(phase); - - return true; -} - -void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { - for (auto item : dict) { - if ((!py::isinstance(item.first))) { - MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; - continue; - } - std::shared_ptr tensor; - std::string name = py::cast(item.first); - if (py::isinstance(item.second.attr("default_input"))) { - // convert float to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); - } else if (py::isinstance(item.second.attr("default_input"))) { - // convert int to tensor with shape([1]) - tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c())) = py::cast(item.second.attr("default_input")); - } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { - // cast tensor - tensor = py::cast>(item.second.attr("default_input")); - } - - if (tensor == nullptr) { - MS_LOG(EXCEPTION) << "Get default value for " << name << " failed"; - } - (void)tensors->emplace(name, tensor); - } -} - -bool AddDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params) { - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - DfGraphConvertor convertor(anf_graph); - - size_t pos = phase.find('.'); - std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); - std::string phase_prefix = phase.substr(0, pos); - if (phase_prefix == "export") { - MS_LOG(INFO) << "Set DfGraphConvertor training : false"; - convertor.set_training(false); - } - - TensorOrderMap init_tensors{}; - ConvertObjectToTensors(init_params, &init_tensors); - (void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); - - if (broadcast_params != py::none()) { - if (!py::isinstance(broadcast_params)) { - MS_LOG(ERROR) << "Invalid broadcast params, it must be py::dict type"; - return false; - } - py::dict broadcast = broadcast_params.cast(); - if (broadcast.empty()) { - (void)convertor.GenerateBroadcastGraph(init_tensors); - } else { - TensorOrderMap broadcast_tensors{}; - ConvertObjectToTensors(broadcast, &broadcast_tensors); - (void)convertor.GenerateBroadcastGraph(broadcast_tensors); - } - MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); - } - - (void)convertor.GenerateCheckpointGraph(); - if (convertor.ErrCode() != 0) { - DfGraphManager::GetInstance().ClearGraph(); - MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); - return false; - } - - if (MsContext::GetInstance()->save_graphs_flag()) { - convertor.DrawComputeGraph(GetFilePathName("ge_graph.dot")); // for debug - convertor.DrawInitGraph(GetFilePathName("init_graph.dot")); // for debug - convertor.DrawSaveCheckpointGraph(GetFilePathName("save_checkpoint_graph.dot")); // for debug - } - std::string init_graph = "init_subgraph." + net_id; - std::string checkpoint_name = "save." 
+ net_id; - if (phase.find("train") != std::string::npos) { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); - } else { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); - } - (void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); - (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); - - Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); - if (ret == Status::SUCCESS) { - DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); - } - - return true; -} - -FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params) { - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - - if (MsContext::GetInstance()->save_graphs_flag()) { - draw::Draw(GetFilePathName("anf_graph.dot"), anf_graph); // for debug - DumpIR(GetFilePathName("anf_graph.ir"), anf_graph, true); - } - - if (!AddDFGraph(info, init_params, phase, broadcast_params)) { - MS_LOG(ERROR) << "GenConvertor failed"; - return nullptr; - } - -#if ENABLE_TRAIN - (void)setenv("GE_TRAIN", "1", 1); -#else - (void)setenv("GE_TRAIN", "0", 1); -#endif - - if (CreateSessionAndGraphRunner(static_cast(ENABLE_TRAIN)) != Status::SUCCESS) { - MS_LOG(ERROR) << "Create GE Session or GraphRunner failed."; - return nullptr; - } - - return anf_graph; -} - -void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { - MS_LOG(DEBUG) << "ExecInitGraph start."; - TensorOrderMap inputs_with_name{}; - ConvertObjectToTensors(init_params, &inputs_with_name); - std::vector inputs; - (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), - [](const std::pair &item) { return item.second; }); - - std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); - if (ge_tensors.size() != inputs.size()) { - MS_LOG(ERROR) << "Args convert to ge tensor error."; - return; - } - MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size() << "."; - - std::vector ge_outputs; - transform::RunOptions run_options; - - run_options.name = phase; - if (DfGraphManager::GetInstance().GetGraphByName(phase) == nullptr) { - MS_LOG(WARNING) << "Can not find " << phase << " sub graph, don't need data init subgraph in INFER mode."; - return; - } - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(EXCEPTION) << "Can not found GraphRunner."; - } - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(EXCEPTION) << "Exec " << phase << " graph failed."; - } - - MS_LOG(INFO) << "Exec " << phase << " graph success."; - - if ((ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::DISTRIBUTION) && - (DfGraphManager::GetInstance().GetGraphByName(BROADCAST_GRAPH_NAME) != nullptr)) { - run_options.name = BROADCAST_GRAPH_NAME; - ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - if (ret != Status::SUCCESS) { - MS_LOG(EXCEPTION) << "Exec BROADCAST_GRAPH_NAME failed."; - } - MS_LOG(INFO) << "Exec broadcast graph success."; - } - 
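// Both RunGraph calls above sit inside a py::gil_scoped_release block: the GIL
// is dropped for the duration of the (potentially long) GE execution and
// re-acquired automatically when the scope ends. The idiom in isolation, as a
// sketch (LongRunningGeCall is a hypothetical placeholder):
#include <pybind11/pybind11.h>

void LongRunningGeCall();  // stand-in for graph_runner->RunGraph(...)

void RunWithoutGil() {
  pybind11::gil_scoped_release release;  // GIL released here...
  LongRunningGeCall();                   // ...so other Python threads can run
}                                        // GIL re-acquired when `release` is destroyed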
} -} - -py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { - MS_EXCEPTION_IF_NULL(cnode_data); - - if (cnode_data->isa()) { - if (*count >= data.size()) { - MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() - << " less than the number of elements required. "; - } - - BaseShapePtr shape = cnode_data->BuildShape(); - if (!shape->isa()) { - MS_LOG(EXCEPTION) << "The shape of the tensor derived is not Shape, is " << shape->ToString(); - } - auto shape_me = shape->cast()->shape(); - auto shape_ge = py::cast(data[*count]).shape(); - if (shape_ge != shape_me) { - MS_LOG(EXCEPTION) << "The shape of the " << *count << "th tensor returned: " << shape_ge - << " is not the same as the shape of the tensor derived: " << shape_me; - } - - return data[(*count)++]; - } - - if (!cnode_data->isa()) { - MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could " - << "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString() - << "."; - } - auto data_tp = cnode_data->cast(); - auto elements = data_tp->elements(); - size_t size = data_tp->size(); - auto tp = py::tuple(size); - for (size_t i = 0; i < size; i++) { - tp[i] = ExtractGeneralCnodeRet(elements[i], data, count); - } - return std::move(tp); -} - -py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { - MS_EXCEPTION_IF_NULL(output_node); - - if (output_node->isa()) { - return ValuePtrToPyData(GetValueNode(output_node)); - } - - if (output_node->isa()) { - if (*count >= data.size()) { - MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() - << " less than the number of elements required. "; - } - return data[(*count)++]; - } - - auto output_c = output_node->cast(); - if (output_c == nullptr) { - MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got " - << output_node->ToString(); - } - - if (output_c->IsApply(prim::kPrimMakeTuple)) { - auto input_list = output_c->inputs(); - size_t size = input_list.size(); - auto tp = py::tuple(size - 1); - for (size_t i = 1; i < size; i++) { - tp[i - 1] = StructureOutput(input_list[i], data, count); - } - return std::move(tp); - } - if (output_c->IsApply(prim::kPrimDepend)) { - return StructureOutput(output_c->input(1), data, count); - } - - return ExtractGeneralCnodeRet(output_c->abstract(), data, count); -} - -std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, - const std::string &phase) { - std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); - if (ge_tensors.size() != inputs.size()) { - MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; - } - - std::vector ge_outputs; - transform::RunOptions run_options; - run_options.name = phase; - auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); - if (graph_runner == nullptr) { - MS_LOG(EXCEPTION) << "Can not found GraphRunner."; - } - - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - MS_LOG(DEBUG) << "Run graph begin, inputs size is: " << inputs.size(); - Status ret = graph_runner->RunGraph(run_options, ge_tensors, &ge_outputs); - MS_LOG(DEBUG) << "Run graph finish, outputs size is: " << ge_outputs.size(); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "Exec graph failed"; - return nullptr; - } - } - - std::vector me_outputs = TransformUtil::ConvertGeTensors(ge_outputs); - if 
(me_outputs.size() != ge_outputs.size()) { - MS_LOG(WARNING) << "Convert output Ge tensor to Me tensor failed"; - } - - py::tuple outputs(me_outputs.size()); - for (std::size_t i = 0; i < outputs.size(); i++) { - outputs[i] = *me_outputs[i]; - } - - std::shared_ptr ret = nullptr; - - AnfNodePtr output_node = graph->get_return()->input(1); - MS_EXCEPTION_IF_NULL(output_node); - size_t count = 0; - py::object oj = StructureOutput(output_node, outputs, &count); - ret = std::make_shared(oj); - - return ret; -} - -void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, - std::vector *inputs) { - // check the arg and use the ExecutorPy args - std::size_t size = args.size(); - - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); - } - - auto arg_size = info.at(phase)->arg_list_size; - if (size != arg_size) { - MS_LOG(EXCEPTION) << "The real arg num : size = " << size << ". graph_arg_size = " << arg_size; - } - - // process the first args of tensor - // only in dataset normal(non-sink) mode, fp_bp graph need input tensors - if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) { - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; - } - if (converted->isa()) { - inputs->push_back(converted->cast()); - } else { - MS_EXCEPTION(TypeError) << "The " << i << "th arg: " << converted->ToString() << " is not tensor."; - } - } - } -} - -py::object ExecDFGraph(const std::map &info, const py::tuple &args, - const std::string &phase) { - std::string phase_prefix = GetPhasePrefix(phase); - if (phase_prefix == "save") { - DoExecNonInputGraph(phase); - ConfigManager::GetInstance().ResetConfig(); - return py::none(); - } - - if (info.count(phase) == 0) { - MS_LOG(EXCEPTION) << "There is no phase:" << phase; - } - FuncGraphPtr anf_graph = info.at(phase)->func_graph; - -#ifdef ENABLE_INFER - // Now don't use the graph because the exec ge function don't take effect - MS_EXCEPTION_IF_NULL(info.at(phase)->func_graph); - if (ENABLE_TRAIN != info.at(phase)->func_graph->has_flag("training")) { - MS_LOG(ERROR) << "Graph training mode mismatch mode of libraries"; - ConfigManager::GetInstance().ResetConfig(); - return py::none(); - } -#endif - - std::shared_ptr ret_val = std::make_shared(); - // We will not execute graph when output is constant or just input itself. 
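// The IsGraphOutputValueNodeOrParameter check that follows lets execution be
// skipped entirely when the answer is already known: a constant graph output
// is returned as-is, and a parameter output just echoes one of the caller's
// arguments back. The decision in miniature, with a hypothetical ToyOutput
// type standing in for the graph's output node:
#include <cstddef>
#include <optional>

struct ToyOutput {
  bool is_constant = false;                // output is a ValueNode
  std::optional<std::size_t> param_index;  // output is the i-th graph parameter
};

bool CanSkipExecution(const ToyOutput &out, std::size_t num_args) {
  return out.is_constant || (out.param_index.has_value() && *out.param_index < num_args);
}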
-  if (IsGraphOutputValueNodeOrParameter(info.at(phase)->func_graph->output(), args, ret_val)) {
-    ConfigManager::GetInstance().ResetConfig();
-    return *ret_val;
-  }
-
-  std::vector<MeTensorPtr> inputs;
-  ProcessGeArg(info, args, phase, &inputs);
-
-  std::shared_ptr<py::object> ret = DoExecGraph(anf_graph, inputs, phase);
-  ConfigManager::GetInstance().ResetConfig();
-  if (ret != nullptr) {
-    return *ret;
-  } else {
-    MS_LOG(EXCEPTION) << "Exec graph failed";
-  }
-}
-void ExportDFGraph(const std::string &file_name, const std::string &phase) {
-  MS_LOG(DEBUG) << "ExportGraph Begin";
-  transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase);
-  if (wrap_ptr == nullptr) {
-    MS_LOG(ERROR) << "Get graph from DfGraphManager failed!";
-    return;
-  }
-
-  transform::DfGraphPtr ge_graph = wrap_ptr->graph_ptr_;
-  if (nullptr == ge_graph) {
-    MS_LOG(ERROR) << "The export graph is null";
-    return;
-  }
-
-  (void)ge_graph->SaveToFile(file_name);
-
-  MS_LOG(DEBUG) << "ExportGraph End";
-}
-} // namespace pipeline
-} // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.h b/mindspore/ccsrc/pipeline/pipeline_ge.h
deleted file mode 100644
index f3a363dbe8..0000000000
--- a/mindspore/ccsrc/pipeline/pipeline_ge.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ -#define MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pipeline/base.h" -#include "operator/ops.h" - -namespace mindspore { -namespace pipeline { -namespace py = pybind11; - -void SetGeOption(const std::map &options); - -void RunGEInitGraph(const py::dict &init_params, const std::string &phase); - -py::object ExecDFGraph(const std::map &info, const py::tuple &args, - const std::string &phase = "train"); - -FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, - const std::string &phase, const py::object &broadcast_params = {}); - -// init and exec dataset sub graph for GE backend -bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, - const std::vector &types, const std::vector> &shapes, - const std::vector &input_indexes, const std::string &phase); - -void ExportDFGraph(const std::string &file_name, const std::string &phase); -} // namespace pipeline -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_PIPELINE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt new file mode 100644 index 0000000000..c15928ee76 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/CMakeLists.txt @@ -0,0 +1,9 @@ +file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") + +if (ENABLE_GE) + file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") + list(APPEND _PYNATIVE_SRC_LIST ${_GE_SRC_LIST}) +endif () + +set_property(SOURCE ${_PYNATIVE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PYNATIVE) +add_library(_mindspore_pipeline_pynative_obj OBJECT ${_PYNATIVE_SRC_LIST}) diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h similarity index 100% rename from mindspore/ccsrc/pynative/base.h rename to mindspore/ccsrc/pipeline/pynative/base.h diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc new file mode 100644 index 0000000000..5e3add1b5f --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -0,0 +1,1167 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pipeline/pynative/pynative_execute.h" + +#include +#include +#include +#include +#include + +#include "debug/trace.h" +#include "ir/tensor_py.h" +#include "ir/param_value.h" +#include "utils/any.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/composite/do_signature.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "backend/session/session_factory.h" +#include "backend/optimizer/pass/const_input_to_attr_registry.h" +#include "backend/optimizer/common/helper.h" +#include "pipeline/jit/action.h" + +#include "pipeline/pynative/base.h" +#include "pybind_api/api_register.h" +#include "vm/transform.h" + +#include "frontend/optimizer/ad/grad.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/pass.h" + +#ifdef ENABLE_GE +#include "pipeline/pynative/pynative_execute_ge.h" +#endif + +using mindspore::tensor::TensorPy; + +const char SINGLE_OP_GRAPH[] = "single_op_graph"; +// primitive unable to infer value for constant input in PyNative mode +const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; + +namespace mindspore { +namespace pynative { + +static std::shared_ptr session = nullptr; +PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; +std::mutex PynativeExecutor::instance_lock_; +ResourcePtr PynativeExecutor::resource_; + +template +void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... args) { + try { + (executor->*method)(args...); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see + // these info from screen, no need to open log file to find these info + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::index_error(ex); + } catch (const std::exception &ex) { + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + PynativeExecutor::GetInstance()->Clean(); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compile graph. 
Exception name: " << exName; + } +} + +inline ValuePtr PyAttrValue(const py::object &obj) { + ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); + if (!converted_ret) { + MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); + } + return converted_ret; +} + +std::string GetId(const py::object &obj) { + py::object to_process = obj; + std::string prefix = ""; + if (py::isinstance(to_process)) { + auto p_list = py::cast(to_process); + if (p_list.size() == 0) { + return "empty"; + } + prefix = "tuple:"; + std::string key = ""; + for (size_t i = 0; i < p_list.size(); ++i) { + key += std::string(py::str(GetId(p_list[i]))) + ":"; + } + return prefix + key; + } + if (py::isinstance(to_process)) { + return prefix + std::string(py::str(to_process)); + } + if (py::isinstance(to_process)) { + return prefix + std::string(py::str(to_process)); + } + if (py::isinstance(to_process)) { + auto tensor_ptr = py::cast(to_process); + return prefix + tensor_ptr->id(); + } + + py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj); + return py::cast(ret); +} + +py::object GetTupleObj(const py::object &obj) { + py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); + py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); + return obj_tuple; +} + +std::map> GetTypeIndex(const std::vector &dtypes) { + std::map> type_indexes; + for (size_t i = 0; i < dtypes.size(); ++i) { + auto it = type_indexes.find(dtypes[i]); + if (it == type_indexes.end()) { + (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector{i})); + } else { + it->second.push_back(i); + } + } + return type_indexes; +} + +std::map GetDstType(const py::tuple &py_args, + const std::map> &type_indexes) { + std::map dst_type; + for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { + auto type = it->first; + auto indexes = it->second; + if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) { + continue; + } + size_t priority = 0; + TypeId max_type = TypeId::kTypeUnknown; + bool has_float = false; + bool has_int = false; + for (size_t index : indexes) { + if (!has_float && py::isinstance(py_args[index])) { + has_float = true; + } + if (!has_int && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { + has_int = true; + } + if (py::isinstance(py_args[index])) { + auto arg = py::cast(py_args[index]); + TypeId arg_type_id = arg->data_type(); + auto type_priority = prim::type_map.find(arg_type_id); + if (type_priority == prim::type_map.end()) { + continue; + } + if (type_priority->second > priority) { + max_type = type_priority->first; + priority = type_priority->second; + } + } + } + if (max_type == TypeId::kNumberTypeBool) { + if (has_int) { + max_type = TypeId::kNumberTypeInt32; + } + if (has_float) { + max_type = TypeId::kNumberTypeFloat32; + } + } + (void)dst_type.insert(std::make_pair(type, max_type)); + } + return dst_type; +} + +std::string TypeIdToMsTypeStr(const TypeId &type_id) { + auto type_name = type_name_map.find(type_id); + if (type_name == type_name_map.end()) { + MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id); + } + return type_name->second; +} + +py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { + py::tuple args(3); + std::string module_name = "mindspore.ops.functional"; + std::string op_name = "cast"; + 
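// GetDstType above collapses each signature dtype group to one target type:
// tensor dtypes compete by priority, and a Bool result is bumped to Int32 or
// Float32 when plain Python ints/floats are present among the arguments. A
// compact standalone sketch of that rule (the priority table is illustrative,
// not MindSpore's real prim::type_map):
#include <map>
#include <string>
#include <vector>

int ToyPriority(const std::string &t) {
  static const std::map<std::string, int> kPrio = {
      {"Bool", 0}, {"Int32", 1}, {"Float16", 2}, {"Float32", 3}};
  auto it = kPrio.find(t);
  return it == kPrio.end() ? -1 : it->second;
}

std::string PromoteGroup(const std::vector<std::string> &tensor_types,
                         bool has_scalar_int, bool has_scalar_float) {
  std::string max_type = "Bool";
  for (const auto &t : tensor_types) {
    if (ToyPriority(t) > ToyPriority(max_type)) max_type = t;
  }
  // Python scalars carry no tensor dtype; they only upgrade a Bool result.
  if (max_type == "Bool") {
    if (has_scalar_int) max_type = "Int32";
    if (has_scalar_float) max_type = "Float32";  // float wins over int, as above
  }
  return max_type;
}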
args[0] = parse::python_adapter::GetPyFn(module_name, op_name); + args[1] = "Cast"; + + std::string dst_type_str = TypeIdToMsTypeStr(type_id); + module_name = "mindspore.common.dtype"; + py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str); + py::tuple inputs(2); + inputs[0] = arg; + inputs[1] = dst_type; + args[2] = inputs; + + return RunOp(args)[0]; +} +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, + py::list *const out_args_list) { + auto &py_args = *out_args; + py::tuple input_mask(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + input_mask[i] = py::hasattr(args[i], "__parameter__"); + py_args[i] = GetTupleObj(args[i]); + } + auto signature = prim->signatures(); + std::vector dtypes; + (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), + [](const Signature &sig) { return sig.dtype; }); + int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); + if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { + return input_mask; + } + auto type_indexes = GetTypeIndex(dtypes); + auto dst_type = GetDstType(py_args, type_indexes); + + for (size_t i = 0; i < dtypes.size(); ++i) { + if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { + continue; + } + auto it = dst_type.find(dtypes[i]); + if (it == dst_type.end() || it->second == kTypeUnknown) { + continue; + } + if (py::isinstance(py_args[i])) { + auto arg = py::cast(py_args[i]); + if (arg->data_type() == it->second) { + continue; + } + if (signature[i].rw == SignatureEnumRW::kRWWrite) { + prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()), + TypeIdToMsTypeStr(it->second)); + } + } + py::object cast_output = DoAutoCast(py_args[i], it->second); + (*out_args)[i] = cast_output; + (*out_args_list)[i] = cast_output; + } + return input_mask; +} + +void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { + size_t size = py_args.size(); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < size; i++) { + ValuePtr input_value = PyAttrValue(py_args[i]); + args_spec_list.emplace_back(abstract::FromValueInside( + input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); + } + AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); + op_exec_info->abstract = infer_res; +} + +OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) { + if (args.size() != PY_ARGS_NUM) { + MS_LOG(ERROR) << "Three args are needed by RunOp"; + return nullptr; + } + auto op_exec_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(op_exec_info); + op_exec_info->op_name = py::cast(args[PY_NAME]); + auto prim = py::cast(args[PY_PRIM]); + auto pyobj = prim->GetPyObj(); + if (pyobj == nullptr) { + MS_LOG(EXCEPTION) << "pyobj is empty"; + } + + py::list a = args[PY_INPUTS]; + size_t input_num = a.size(); + op_exec_info->op_inputs = py::tuple(input_num); + + op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args); + // use python infer method + if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); + } + op_exec_info->py_primitive = prim; + op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); + if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { + 
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; + return nullptr; + } + return op_exec_info; +} + +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(op_exec_info); + std::string graph_info; + // get input tensor info + size_t input_num = op_exec_info->op_inputs.size(); + for (size_t index = 0; index < input_num; ++index) { + auto input = op_exec_info->op_inputs[index]; + if (py::isinstance(input)) { + auto tensor_ptr = py::cast(input); + (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); + } + } + // get prim and abstract info + MS_EXCEPTION_IF_NULL(op_exec_info->abstract); + (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + + op_exec_info->abstract->ToString()); + return graph_info; +} + +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_LOG(INFO) << "RunOpInVM start"; + + MS_EXCEPTION_IF_NULL(status); + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); + if (op_exec_info->op_name == "HookBackward") { + auto op_inputs = op_exec_info->op_inputs; + py::tuple result(op_inputs.size()); + for (size_t i = 0; i < op_inputs.size(); i++) { + py::object input = op_inputs[i]; + if (py::hasattr(input, "__parameter__")) { + input = py::getattr(input, "data"); + } + auto tensor = py::cast(input); + auto new_tensor = std::make_shared(tensor->data_type(), tensor->shape(), tensor->data_ptr()); + new_tensor->set_device_address(tensor->device_address()); + new_tensor->set_dirty(tensor->is_dirty()); + result[i] = new_tensor; + } + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInVM end"; + return std::move(result); + } + auto func = op_exec_info->py_primitive->GetComputeFunction(); + if (py::isinstance(func)) { + MS_LOG(ERROR) << "VM failed to get func"; + *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; + py::tuple err_ret(0); + return std::move(err_ret); + } + + // execute op + py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInVM end"; + return std::move(result); +} + +bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, + const std::unordered_set &input_attrs) { + MS_EXCEPTION_IF_NULL(op_prim); + auto input_names_value = op_prim->GetAttr(kAttrInputNames); + if (input_names_value == nullptr) { + return false; + } + auto input_names_vec = GetValue>(input_names_value); + if (input_index >= input_names_vec.size()) { + MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!"; + } + + if (input_attrs.find(input_index) != input_attrs.end()) { + ValuePtr value = parse::data_converter::PyDataToValue(input_object); + MS_EXCEPTION_IF_NULL(value); + auto input_name = input_names_vec[input_index]; + op_prim->set_attr(input_name, value); + return true; + } + return false; +} + +void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, + std::vector *input_tensors) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + for (const auto &input_object : tuple_inputs) { + if (!py::isinstance(input_object)) { + MS_LOG(EXCEPTION) << "The input object is not a tensor!"; + } + auto tensor = py::cast(input_object); + MS_EXCEPTION_IF_NULL(tensor); + input_tensors->push_back(tensor); + } + op_prim->set_attr(kAttrDynInputSizes, 
MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())})); +} + +void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) { + MS_EXCEPTION_IF_NULL(input_tensors); + ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); + MS_EXCEPTION_IF_NULL(input_value); + if (!input_value->isa<ValueTuple>()) { + MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; + } + auto value_tuple = input_value->cast<ValueTuplePtr>(); + MS_EXCEPTION_IF_NULL(value_tuple); + tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); + MS_EXCEPTION_IF_NULL(tensor_ptr); + input_tensors->push_back(tensor_ptr); +} + +void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, + std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + MS_EXCEPTION_IF_NULL(tensor_mask); + + if (!py::isinstance<py::tuple>(input_object)) { + MS_LOG(EXCEPTION) << "The input should be a tuple!"; + } + auto tuple_inputs = py::cast<py::tuple>(input_object); + if (tuple_inputs.size() == 0) { + MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; + } + if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { + PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); + } else { + ConvertValueTupleToTensor(input_object, input_tensors); + *tensor_mask = kValueNodeTensorMask; + } +} + +void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, + std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) { + MS_EXCEPTION_IF_NULL(op_prim); + MS_EXCEPTION_IF_NULL(input_tensors); + MS_EXCEPTION_IF_NULL(tensor_mask); + tensor::TensorPtr tensor_ptr = nullptr; + if (py::isinstance<tensor::Tensor>(input_object)) { + tensor_ptr = py::cast<tensor::TensorPtr>(input_object); + } else if (py::isinstance<py::float_>(input_object)) { + double input_value = py::cast<py::float_>(input_object); + tensor_ptr = std::make_shared<tensor::Tensor>(input_value, kFloat32); + *tensor_mask = kValueNodeTensorMask; + } else if (py::isinstance<py::int_>(input_object)) { + tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32); + *tensor_mask = kValueNodeTensorMask; + } else if (py::isinstance<py::array>(input_object)) { + tensor_ptr = TensorPy::MakeTensor(py::cast<py::array>(input_object), nullptr); + } else if (py::isinstance<py::list>(input_object)) { + auto list_inputs = py::cast<py::list>(input_object); + py::tuple tuple_inputs(list_inputs.size()); + for (size_t i = 0; i < tuple_inputs.size(); ++i) { + tuple_inputs[i] = list_inputs[i]; + } + ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask); + return; + } else if (py::isinstance<py::tuple>(input_object)) { + ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask); + return; + } else if (py::isinstance<py::none>(input_object)) { + return; + } else { + MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; + } + MS_EXCEPTION_IF_NULL(tensor_ptr); + input_tensors->push_back(tensor_ptr); +} + +void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask, + std::vector<tensor::TensorPtr> *input_tensors) { + MS_EXCEPTION_IF_NULL(op_run_info); + MS_EXCEPTION_IF_NULL(tensors_mask); + MS_EXCEPTION_IF_NULL(input_tensors); + PrimitivePtr op_prim = op_run_info->py_primitive; + MS_EXCEPTION_IF_NULL(op_prim); + + if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { + MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " + << op_run_info->inputs_mask.size(); + } + opt::ConstInputToAttrInfoRegister reg; + bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name,
&reg); + size_t input_num = op_run_info->op_inputs.size(); + for (size_t index = 0; index < input_num; ++index) { + // convert const input to attr + if (reg_exist && + RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { + continue; + } + // convert const and tuple input to tensor + int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]); + ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); + // mark tensors, data : 0, weight : 1, valuenode: 2 + std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); + tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); + } +} + +void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) { + MS_EXCEPTION_IF_NULL(input_tensors); + if (input_tensors->size() != tensors_mask.size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size " + << tensors_mask.size(); + } + std::vector<tensor::TensorPtr> new_input_tensors; + for (size_t index = 0; index < tensors_mask.size(); ++index) { + if (tensors_mask[index] != kValueNodeTensorMask) { + new_input_tensors.push_back(input_tensors->at(index)); + } + } + *input_tensors = new_input_tensors; +} + +py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_enable_pynative_infer(true); + std::string device_target = ms_context->device_target(); + if (device_target != kAscendDevice && device_target != kGPUDevice) { + MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; + } + + if (session == nullptr) { + session = session::SessionFactory::Get().Create(device_target); + } + MS_EXCEPTION_IF_NULL(session); + session->Init(ms_context->device_id()); + + std::vector<tensor::TensorPtr> input_tensors; + std::vector<int> tensors_mask; + ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); + // get graph info to check whether the graph already exists in the cache + std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); + session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); + EraseValueNodeTensor(tensors_mask, &input_tensors); + py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); + ms_context->set_enable_pynative_infer(false); + *status = PYNATIVE_SUCCESS; + return result; +} + +py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, + PynativeStatusCode *const status) { + MS_EXCEPTION_IF_NULL(status); + py::object result; + switch (backend_policy) { + case kMsBackendVmOnly: { + // use vm only + MS_LOG(INFO) << "RunOp use VM only backend"; + result = RunOpInVM(op_exec_info, status); + break; + } + case kMsBackendGePrior: { +#ifdef ENABLE_GE + // use GE first, use vm when GE fails + MS_LOG(INFO) << "RunOp use GE first backend"; + result = RunOpInGE(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + result = RunOpInVM(op_exec_info, status); + } +#endif + break; + } + case kMsBackendMsPrior: { + // use MS first, use others when MS fails + MS_LOG(INFO) << "RunOp use Ms first backend"; + result = RunOpInMs(op_exec_info, status); + if (*status != PYNATIVE_SUCCESS) { + MS_LOG(ERROR) << "RunOp use Ms backend failed!";
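
// The backend_policy switch encodes the fallback order: VM-only runs the op through the primitive's Python
// compute function, GE-prior tries GE and falls back to VM, and MS-prior runs the MindSpore session and only
// reports failure. The same dispatch shape, as a self-contained sketch (the runner callbacks are stand-ins
// for RunOpInVM/RunOpInGE/RunOpInMs):
#include <functional>

enum PolicySketch { kVmOnlySketch, kGePriorSketch, kMsPriorSketch };

bool DispatchSketch(PolicySketch p, const std::function<bool()> &run_vm, const std::function<bool()> &run_ge,
                    const std::function<bool()> &run_ms) {
  switch (p) {
    case kVmOnlySketch:
      return run_vm();
    case kGePriorSketch:
      return run_ge() || run_vm();  // GE first, VM only when GE fails
    case kMsPriorSketch:
      return run_ms();  // no fallback; a failure is logged by the caller
  }
  return false;
}
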
+ } + break; + } + default: + MS_LOG(ERROR) << "No backend configured for run op"; + } + return result; +} + +AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { + if (!grad_flag_ || graph_info_map_.empty()) { + return nullptr; + } + std::vector<AnfNodePtr> inputs; + auto prim = op_exec_info->py_primitive; + inputs.push_back(NewValueNode(prim)); + py::tuple op_masks = op_exec_info->inputs_mask; + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < args.size(); i++) { + auto node = GetInput(args[i], op_masks[i]); + args_spec_list.push_back(node->abstract()); + inputs.push_back(node); + } + + auto cnode = curr_g_->NewCNode(inputs); + MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); + py::object out_real = out; + if (out.size() == 1) { + MS_LOG(DEBUG) << "MakeCnode out size is one."; + out_real = out[0]; + } + std::string obj_id = GetId(out_real); + if (py::isinstance<py::tuple>(out_real)) { + auto value = py::cast<py::tuple>(out_real); + if (value.size() > 1) { + for (int i = 0; i < static_cast<int>(value.size()); i++) { + auto value_id = GetId(value[i]); + MS_LOG(DEBUG) << "MakeCnode set node id " << value_id; + set_obj_node_map(curr_g_, value_id, cnode, i); + } + } + } + MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id; + set_obj_node_map(curr_g_, obj_id, cnode); + set_pyobj(curr_g_, obj_id); + return cnode; +} + +AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { + auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; + if (out.second.size() == 1 && out.second[0] == -1) { + return out.first; + } + auto node = out.first; + MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); + for (auto &idx : out.second) { + std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; + node = curr_g_->NewCNode(tuple_get_item_inputs); + } + MS_LOG(DEBUG) << "GetObjNode output " << node->DebugString(6); + return node; +} + +py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { + MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; + mindspore::parse::python_adapter::set_python_env_flag(true); + MsBackendPolicy backend_policy; +#if (!defined ENABLE_GE) + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->backend_policy() == "ms") { + backend_policy = kMsBackendMsPrior; + } else { + backend_policy = kMsBackendVmOnly; + } +#else + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->PynativeInitGe(); + backend_policy = kMsBackendGeOnly; +#endif + if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { + backend_policy = kMsBackendVmOnly; + } + PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; + // returns a null py::tuple on error + py::tuple err_ret(0); + py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); + if (status != PYNATIVE_SUCCESS) { + MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; + return err_ret; + } + + auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); + if (node != nullptr) { + node->set_abstract(op_exec_info->abstract); + MS_LOG(DEBUG) << "RunOp MakeCnode, new node is: " << node->DebugString(); + } + MS_LOG(DEBUG) << "RunOp end"; + return result; +} + +py::tuple RunOpInner(const py::args &args) { + MS_LOG(DEBUG) << "RunOp start, args size " << args.size(); + py::list args_input = args[PY_INPUTS]; + + OpExecInfoPtr op_exec_info =
GenerateOpExecInfo(args, &args_input); + MS_EXCEPTION_IF_NULL(op_exec_info); + + if (op_exec_info->abstract != nullptr) { + py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); + if (!output["value"].is_none()) { + py::tuple value_ret(1); + value_ret[0] = output["value"]; + return value_ret; + } + if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { + py::tuple value_ret(1); + value_ret[0] = ""; + return value_ret; + } + } + return RunOpInner(op_exec_info, args_input); +} + +py::tuple RunOp(const py::args &args) { + try { + return RunOpInner(args); + } catch (const py::error_already_set &ex) { + // print function call stack info before release + std::ostringstream oss; + trace::TraceGraphEval(); + trace::GetEvalStackInfo(oss); + // call py::print to output the function call stack to STDOUT; even when the log goes to a file, the user + // can still see this info on screen without opening the log file + py::print(oss.str()); + MS_LOG(ERROR) << oss.str(); + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to the Python interpreter to handle it + throw(py::error_already_set(ex)); + } catch (const py::type_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::type_error(ex); + } catch (const py::value_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::value_error(ex); + } catch (const py::index_error &ex) { + PynativeExecutor::GetInstance()->Clean(); + throw py::index_error(ex); + } catch (const std::exception &ex) { + PynativeExecutor::GetInstance()->Clean(); + // re-throw this exception to the Python interpreter to handle it + throw(std::runtime_error(ex.what())); + } catch (...) { + PynativeExecutor::GetInstance()->Clean(); + std::string exName(abi::__cxa_current_exception_type()->name()); + MS_LOG(EXCEPTION) << "Error occurred when compiling graph. 
Exception name: " << exName; + } +} + +void ClearPyNativeSession() { session = nullptr; } + +PynativeExecutor::~PynativeExecutor() { ClearRes(); } + +PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } + +void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { + auto cell_id = GetId(cell); + if (cell_graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "Newgraph already compiled"; + return; + } + + auto g = std::make_shared(); + + if (top_g_ == nullptr) { + top_g_ = curr_g_ = g; + df_builder_ = std::make_shared(); + MS_LOG(DEBUG) << "First new graph" << top_g_.get(); + Pushp(); + } else { + Pushp(); + curr_g_ = g; + } + if (graph_info_map_.count(g) == 0) { + graph_info_map_[g] = GraphInfo(); + } + for (size_t i = 0; i < args.size(); i++) { + auto new_param = g->add_parameter(); + std::string param_obj = GetId(args[i]); + graph_info_map_[g].param_map[param_obj] = new_param; + } +} + +AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { + ValuePtr converted_ret = nullptr; + parse::ConvertData(obj, &converted_ret); + auto node = NewValueNode(converted_ret); + set_obj_node_map(curr_g_, obj_id, node); + return node; +} + +AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) { + AnfNodePtr node = nullptr; + std::string obj_id = GetId(obj); + + if (op_mask != nullptr && py::cast(op_mask)) { + MS_LOG(DEBUG) << "Topgraph free parameter"; + // get the parameter name from parameter object + auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); + if (py::isinstance(name_attr)) { + MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; + } + auto param_name = py::cast(name_attr); + if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { + auto free_param = df_builder_->add_parameter(); + free_param->set_name(param_name); + auto free_param_new = py::cast(obj.attr("_value")); + free_param->set_default_param(free_param_new); + free_param->debug_info()->set_name(param_name); + MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; + graph_info_map_[df_builder_].param_map[obj_id] = free_param; + return free_param; + } + return graph_info_map_[df_builder_].param_map[obj_id]; + } + + // if input is graph output + if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) { + // op(x, y) + node = graph_info_map_[curr_g_].param_map[obj_id]; + } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) { + // out = op(op1(x, y)) + // out = op(cell1(x, y)) + // out = op(cell1(x, y)[0]) + node = GetObjNode(obj); + } else if (py::isinstance(obj)) { + // out = op((x, y)) + // out = cell((x, y)) + auto tuple = obj.cast(); + + // cell((1,2)): support not mix (scalar, tensor) + if (tuple.size() > 0 && !py::isinstance(tuple[0])) { + return MakeValueNode(obj, obj_id); + } + + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto tuple_size = static_cast(tuple.size()); + for (int i = 0; i < tuple_size; i++) { + args.push_back(GetInput(tuple[i], py::object())); + } + auto cnode = curr_g_->NewCNode(args); + set_obj_node_map(curr_g_, GetId(obj), cnode); + node = cnode; + } else { + node = MakeValueNode(obj, obj_id); + } + + MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; + return node; +} + +// for output[0][1] need getitem multi +void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx) { + if (py::isinstance(obj)) { + auto tuple = obj.cast(); + for (int i = 0; 
i < static_cast<int>(tuple.size()); i++) { + std::vector<int> tmp = idx; + tmp.push_back(i); + set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp); + SetTupleOutput(tuple[i], cnode, tmp); + } + } +} + +void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); } + +void PynativeExecutor::Popp() { + if (graph_p_.empty()) { + MS_LOG(EXCEPTION) << "Stack graph_p_ is empty"; + } + curr_g_ = graph_p_.top(); + graph_p_.pop(); +} + +void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { + auto cell_id = GetId(cell); + if (cell_graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "Endgraph already compiled"; + return; + } + cell_graph_map_[cell_id] = curr_g_; + auto out_id = GetId(out); + if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { + // cell construct return x, y + if (py::isinstance<py::tuple>(out)) { + std::vector<AnfNodePtr> args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto tuple = out.cast<py::tuple>(); + MS_LOG(DEBUG) << "End graph start, tuple size " << tuple.size(); + auto tuple_size = static_cast<int>(tuple.size()); + auto cnode = curr_g_->NewCNode(args); + for (int i = 0; i < tuple_size; i++) { + args.push_back(GetInput(tuple[i], py::object())); + set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); + SetTupleOutput(tuple[i], cnode, std::vector<int>{i}); + } + cnode->set_inputs(args); + set_obj_node_map(curr_g_, out_id, cnode); + } else { + MS_LOG(ERROR) << "Graph does not have this output: " << out_id; + return; + } + } + EndGraphByOutId(out_id, cell, out, args); +} + +void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, + const py::args &args) { + AnfNodePtr output_node; + if (graph_info_map_[curr_g_].param_map.count(out_id)) { + output_node = graph_info_map_[curr_g_].param_map[out_id]; + } else { + output_node = GetObjNode(out); + } + curr_g_->set_output(output_node); + std::vector<AnfNodePtr> inputs; + inputs.push_back(NewValueNode(curr_g_)); + MS_LOG(DEBUG) << "Current graph " << curr_g_->output()->DebugString(); + resource_->manager()->AddFuncGraph(curr_g_); + // custom bprop debug + if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { + MS_LOG(DEBUG) << "Use cell custom bprop function."; + FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); + if (bprop_graph != nullptr) { + (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); + } + } + auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); + if (curr_g_ != top_g_) { + Popp(); + for (size_t i = 0; i < args.size(); i++) { + auto input = GetInput(args[i], py::object()); + inputs.push_back(input); + } + auto out_cnode = curr_g_->NewCNode(inputs); + set_pyobj(curr_g_, GetId(cell)); + if (py::isinstance<py::tuple>(out)) { + auto out_list = py::cast<py::tuple>(out); + auto out_size = static_cast<int>(out_list.size()); + for (int i = 0; i < out_size; i++) { + set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); + SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i}); + } + } + set_obj_node_map(curr_g_, GetId(out), out_cnode); + } else { + parse::ResolveFuncGraph(newfg, resource_); + resource_->set_func_graph(newfg); + } +} + +std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) { + std::vector<AnfNodePtr> w_args; + if (py::hasattr(weights, "__parameter_tuple__")) { + auto tuple = weights.cast<py::tuple>(); + MS_LOG(DEBUG) << "GradNet start, weights tuple size " << tuple.size();
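
// Tuple outputs above are tracked per object id as a (node, index path) pair: SetTupleOutput records
// {1, 0} for out[1][0], and GetObjNode later folds one TupleGetItem per index to rebuild the value.
// A std-only sketch of that fold (NodeSketch/Materialize are illustrative names):
#include <string>
#include <vector>

struct NodeSketch {
  std::string repr;
};

// Folding {1, 0} over "out" yields tuple_getitem(tuple_getitem(out, 1), 0).
NodeSketch Materialize(NodeSketch root, const std::vector<int> &path) {
  for (int idx : path) {
    root.repr = "tuple_getitem(" + root.repr + ", " + std::to_string(idx) + ")";
  }
  return root;
}
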
w_args.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t it = 0; it < tuple.size(); ++it) { + auto param = tuple[it]; + auto param_id = GetId(param); + AnfNodePtr para_node = nullptr; + if (graph_info_map_[df_builder_].param_map.count(param_id)) { + para_node = graph_info_map_[df_builder_].param_map[param_id]; + + AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); + AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); + auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); + AnfNodePtr ref_key_node = NewValueNode(refkey); + AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); + + w_args.push_back(ref_node); + } + } + } else { + MS_LOG(DEBUG) << "Training weights are not a ParameterTuple"; + } + return w_args; +} + +abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { + abstract::AbstractBasePtrList args_spec; + std::size_t size = args.size(); + for (std::size_t i = 0; i < size; i++) { + ValuePtr converted = nullptr; + bool succ = parse::ConvertData(args[i], &converted); + if (!succ) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + bool broaden = true; + auto abs = abstract::FromValue(converted, broaden); + args_spec.push_back(abs); + auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]); + param_node->set_abstract(abs); + } + + for (const auto &param : df_builder_->parameters()) { + auto param_node = std::static_pointer_cast<Parameter>(param); + if (param_node->has_default()) { + const auto &param_value = param_node->default_param(); + ValuePtr value = param_value->value(); + AbstractBasePtr ptr = abstract::FromValue(value, true); + if (ptr == nullptr) { + MS_LOG(EXCEPTION) << "Args convert error"; + } + args_spec.push_back(ptr); + param_node->set_abstract(ptr); + } + } + + return args_spec; +} + +void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args) { + MS_LOG(INFO) << "GradNet start, args size " << args.size(); + + std::size_t size = args.size(); + auto cell_id = GetId(cell); + if (graph_map_.count(cell_id) != 0) { + MS_LOG(DEBUG) << "GradNet already compiled"; + return; + } + MS_LOG(DEBUG) << "GradNet first compiled"; + std::vector<AnfNodePtr> new_params; + for (size_t i = 0; i < size; i++) { + ParameterPtr p = std::make_shared<Parameter>(df_builder_); + new_params.push_back(p); + } + MS_LOG(DEBUG) << "GradNet start, weight size " << df_builder_->parameters().size(); + new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); + df_builder_->set_parameters(new_params); + resource_->manager()->SetParameters(df_builder_, new_params); + + std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights); + MS_EXCEPTION_IF_NULL(resource_->func_graph()); + auto g = GradGraph(resource_->func_graph(), grad, w_args, size); + resource_->set_func_graph(g); + resource_->manager()->KeepRoots({g}); + + // get the parameter items and add their values to args_spec + abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); + MS_LOG(DEBUG) << "Args_spec size " << args_spec.size(); + + resource_->set_args_spec(args_spec); + MS_LOG(DEBUG) << "Start opt"; + + // Create backend and session + resource_->results()[pipeline::kBackend] = compile::CreateBackend(); + + graph_map_[cell_id] = g; + PynativeOptimizeAction(resource_); + TaskEmitAction(resource_); + ExecuteAction(resource_); + resource_->Clean(); + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + +void PynativeExecutor::Clear(const std::string &flag) { + if 
(!flag.empty()) { + MS_LOG(INFO) << "Clear res"; + (void)graph_map_.erase(flag); + (void)cell_graph_map_.erase(flag); + Clean(); + // May exit while a PyNative op is running, so the PyNative infer flag needs to be reset. + auto ms_context = MsContext::GetInstance(); + if (ms_context != nullptr) { + ms_context->set_enable_pynative_infer(false); + } + return; + } + + MS_LOG(INFO) << "Clear"; + top_g_ = nullptr; + curr_g_ = nullptr; + graph_info_map_.clear(); + std::stack<FuncGraphPtr>().swap(graph_p_); +} + +void PynativeExecutor::Clean() { + MS_LOG(INFO) << "Clean all res"; + Clear(); + grad_flag_ = false; + df_builder_ = nullptr; + ad::CleanRes(); + pipeline::ReclaimOptimizer(); +} + +void PynativeExecutor::ClearRes() { + Clean(); + resource_.reset(); +} + +py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { + VectorRef arg_list; + pipeline::ProcessVmArgInner(args, resource_, &arg_list); + if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || + !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) { + MS_LOG(EXCEPTION) << "Can't find the run graph function"; + } + compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>(); + if (run == nullptr) { + MS_LOG(EXCEPTION) << "Can't find the run graph function"; + } + + std::string backend = MsContext::GetInstance()->backend_policy(); + + MS_LOG(DEBUG) << "Eval run " << backend; + BaseRef value = (*run)(arg_list); + MS_LOG(DEBUG) << "Run end " << value.ToString(); + return BaseRefToPyData(value); +} + +FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, + const std::vector<AnfNodePtr> &weights, size_t arg_size) { + auto nparam = top_g_->parameters().size(); + std::ostringstream ss; + ss << "grad{" << nparam << "}"; + df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); + df_builder_->debug_info()->set_name(ss.str()); + + auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); + std::vector<AnfNodePtr> inputs = {NewValueNode(df)}; + for (size_t i = 0; i < arg_size; ++i) { + inputs.push_back(df_builder_->parameters()[i]); + } + auto out = df_builder_->NewCNode(inputs); + df_builder_->set_output(out); + resource_->manager()->AddFuncGraph(df); + resource_->manager()->AddFuncGraph(df_builder_); + return df_builder_; +} + +void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args); +} + +void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); +} + +void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args) { + PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); +} + +REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { + (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") + .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") + .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") + .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") + .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") + .def("clear", &PynativeExecutor::Clear, "pynative clear status.") + .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), + "Executor run function.") + .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = 
py::bool_(false), + "Executor set grad flag."); + })); +} // namespace pynative +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h new file mode 100644 index 0000000000..152d58aca4 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ +#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +#include "pipeline/pynative/base.h" +#include "utils/context/ms_context.h" +#include "ir/anf.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/composite/composite.h" + +namespace mindspore { +namespace pynative { + +namespace py = pybind11; +using ResourcePtr = std::shared_ptr; +using GradOperationPtr = std::shared_ptr; + +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); + +py::tuple RunOp(const py::args &args); + +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, + py::list *const out_args_list); + +void ClearPyNativeSession(); + +struct GraphInfo { + std::unordered_map param_map; + std::unordered_map>> obj_node_map; + AnfNodePtr output; + std::vector objects; +}; + +class PynativeExecutor : public std::enable_shared_from_this { + public: + static std::shared_ptr GetInstance() { + std::lock_guard i_lock(instance_lock_); + if (executor_ == nullptr) { + executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); + resource_ = std::make_shared(); + } + return executor_; + } + void NewGraph(const py::object &cell, const py::args &args); + void NewGraphInner(const py::object &cell, const py::args &args); + void EndGraph(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); + void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); + std::vector GetWeightsArgs(const py::object &weights); + abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); + void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); + void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, + const py::args &args); + void Clear(const std::string &flag = ""); + void Clean(); + void ClearRes(); + bool grad_flag() { return grad_flag_; } + void set_grad_flag(bool flag) { grad_flag_ = flag; } + AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask); + AnfNodePtr GetObjNode(const py::object &obj); + FuncGraphPtr curr_g() { return curr_g_; } + void set_pyobj(FuncGraphPtr g, const std::string obj) { 
graph_info_map_[g].objects.push_back(obj); } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); + } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); + } + void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { + graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); + } + AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); + py::object Run(const py::tuple &args, const py::object &phase); + + void Pushp(); + void Popp(); + FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, + size_t arg_size); + void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); + AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); + + ~PynativeExecutor(); + + private: + PynativeExecutor(); + static std::shared_ptr executor_; + static std::mutex instance_lock_; + static ResourcePtr resource_; + bool grad_flag_; + std::unordered_map graph_map_; + std::unordered_map cell_graph_map_; + std::unordered_map graph_info_map_; + std::stack graph_p_; + FuncGraphPtr top_g_; + FuncGraphPtr df_builder_; + FuncGraphPtr curr_g_; +}; + +using PynativeExecutorPtr = std::shared_ptr; + +} // namespace pynative +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc new file mode 100644 index 0000000000..897c21fc90 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.cc @@ -0,0 +1,312 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pipeline/pynative/pynative_execute_ge.h" + +#include +#include +#include +#include + +#include "utils/any.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "backend/session/session_factory.h" +#include "ir/tensor_py.h" + +const char SINGLE_OP_GRAPH[] = "single_op_graph"; + +using mindspore::tensor::TensorPy; + +namespace mindspore { +namespace pynative { +using MeTensor = mindspore::tensor::Tensor; +using MeTensorPtr = mindspore::tensor::TensorPtr; +using GeOperator = ge::Operator; +using GeOperatorPtr = std::shared_ptr; + +using transform::GraphRunner; +using transform::GraphRunnerOptions; +using transform::OperatorPtr; +static std::shared_ptr session = nullptr; +inline ValuePtr PyAttrValue(const py::object &obj) { + ValuePtr converted_ret = nullptr; + bool converted = parse::ConvertData(obj, &converted_ret); + if (!converted) { + MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); + } + return converted_ret; +} + +MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { + MeTensorPtr me_tensor_ptr = nullptr; + if (py::isinstance(obj)) { + me_tensor_ptr = py::cast(obj); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); + } else if (py::isinstance(obj)) { + me_tensor_ptr = TensorPy::MakeTensor(py::cast(obj), nullptr); + } else { + MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; + } + return me_tensor_ptr; +} + +bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const OperatorPtr &op, std::vector *graph_input_nodes) { + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(graph_input_nodes); + auto op_inputs = op_exec_info->op_inputs; + std::string op_name = op_exec_info->op_name; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + return false; + } + + int op_input_idx = 1; + size_t size = inputs.size(); + for (size_t i = 0; i < size; i++) { + if (inputs[i] == nullptr) { + continue; + } + auto const_op = std::make_shared(); + MS_EXCEPTION_IF_NULL(const_op); + (void)const_op->set_attr_value(*inputs[i]); + MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); + MS_EXCEPTION_IF_NULL(me_tensor_ptr); + auto const_op_desc = + transform::TransformUtil::GetGeTensorDesc(me_tensor_ptr->shape_c(), me_tensor_ptr->data_type(), kOpFormat_NCHW); + if (const_op_desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << op_name << " output descriptor failed!"; + return false; + } + auto pointer_cast_const_op = std::static_pointer_cast(const_op); + MS_EXCEPTION_IF_NULL(pointer_cast_const_op); + (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); + auto &input_map = adapter->getInputMap(); + if (input_map.find(op_input_idx) == input_map.end()) { + continue; + } + if (adapter->setInput(op, op_input_idx++, const_op)) { + MS_LOG(ERROR) << "Failed to set params, index is " << op_input_idx; + return false; + } + graph_input_nodes->push_back(*const_op); + } + return 
true; +} + +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(op_exec_info); + std::string op_name = op_exec_info->op_name; + auto op_inputs = op_exec_info->op_inputs; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + MS_LOG(ERROR) << "Unable to find Adapter for " << ((std::string)py::str(op_name)); + return false; + } + OperatorPtr op = adapter->generate(op_name); + MS_EXCEPTION_IF_NULL(op); + + std::vector graph_input_nodes; + // hold param nodes after setting input and output for the graph + // set input + if (!SetInputsForSingleOpGraph(op_exec_info, inputs, op, &graph_input_nodes)) { + return false; + } + // set attributes + for (auto attr : attrs) { + (void)adapter->setAttr(op, attr.first, attr.second); + } + // set default attributes + auto extra_attrs = adapter->GetExtraAttr(); + for (auto attr : extra_attrs) { + (void)adapter->setAttr(op, attr.first, attr.second); + } + // set input attributes + auto &input_attr_map = adapter->getInputAttrMap(); + for (auto &it : input_attr_map) { + if (op_inputs.size() < it.first) { + continue; + } + auto const_value = PyAttrValue(op_inputs[it.first - 1]); + if (const_value->isa()) { + continue; + } + it.second.set_attr(op, const_value); + } + // construct output data nodes + std::vector graph_outputs{*op}; + // set input and output nodes for the graph + MS_EXCEPTION_IF_NULL(graph); + (void)graph->SetInputs(graph_input_nodes).SetOutputs(graph_outputs); + MS_LOG(INFO) << "BuildSingleOpGraph done"; + return true; +} + +void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { + MS_EXCEPTION_IF_NULL(inputs); + MS_EXCEPTION_IF_NULL(op_exec_info); + auto op_inputs = op_exec_info->op_inputs; + size_t size = op_inputs.size(); + for (size_t i = 0; i < size; i++) { + if (py::isinstance(op_inputs[i])) { + inputs->emplace_back(nullptr); + continue; + } + MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); + auto ge_tensor_ptr = transform::TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW); + if (ge_tensor_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Convert inputs to GE tensor failed in op " << op_exec_info->op_name << "."; + } + // set inputs for operator to build single node graph + inputs->push_back(ge_tensor_ptr); + } +} + +PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { + MS_EXCEPTION_IF_NULL(op_exec_info); + auto op_attrs = op_exec_info->op_attrs; + std::unordered_map attrs{}; + + for (auto &item : op_attrs) { + if (!py::isinstance(item.first)) { + MS_LOG(ERROR) << "Type error in py dict convert"; + return PYNATIVE_OP_ATTRS_ERR; + } + std::string name = py::cast(item.first); + auto attr_value = PyAttrValue(py::cast(item.second)); + (void)attrs.emplace(name, attr_value); + } + + // build graph + GeGraphPtr graph = std::make_shared(op_exec_info->op_name); + if (BuildSingleOpGraph(op_exec_info, inputs, attrs, graph) == false) { + MS_LOG(ERROR) << "Failed to BuildSingleOpGraph"; + return PYNATIVE_GRAPH_GE_BUILD_ERR; + } + + // add the single op graph into the graph manager, which will be iterated by session. 
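
// BuildSingleOpGraph above assembles a throwaway GE graph for one op: look up the op adapter, wrap each
// input tensor in a GE Const node, apply explicit attributes and then the adapter's defaults, and finally
// set the graph's inputs and outputs. The same assembly order as a std-only sketch (GraphSketch and
// BuildSingleOpSketch are illustrative names, not the transform:: API):
#include <string>
#include <utility>
#include <vector>

struct GraphSketch {
  std::string op;
  std::vector<std::string> const_inputs;
  std::vector<std::pair<std::string, std::string>> attrs;
};

GraphSketch BuildSingleOpSketch(const std::string &op, const std::vector<std::string> &inputs,
                                const std::vector<std::pair<std::string, std::string>> &attrs) {
  GraphSketch g{op, {}, {}};
  for (const auto &in : inputs) {
    g.const_inputs.push_back("Const(" + in + ")");  // every operand becomes a constant node
  }
  g.attrs = attrs;  // explicit attrs; the real code also folds in adapter defaults
  return g;
}
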
+ transform::Status ret = + transform::DfGraphManager::GetInstance().AddGraph(SINGLE_OP_GRAPH, std::shared_ptr(graph)); + if (ret != transform::SUCCESS) { + MS_LOG(ERROR) << "Failed to AddGraph into graph manager"; + return PYNATIVE_GRAPH_MANAGER_ERR; + } + + return PYNATIVE_SUCCESS; +} + +std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, + const std::vector &ge_tensors) { + std::vector outputs; + AbstractBasePtr abs_base = op_exec_info->abstract; + std::vector> shapes; + if (abs_base != nullptr && abs_base->isa()) { + auto arg_tensor = dyn_cast(abs_base); + shapes.emplace_back(arg_tensor->shape()->shape()); + outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); + return outputs; + } + if (abs_base != nullptr && abs_base->isa()) { + auto arg_tuple = dyn_cast(abs_base); + size_t len = arg_tuple->size(); + + for (size_t i = 0; i < len; i++) { + if (arg_tuple->elements()[i]->isa()) { + auto arg_tensor = dyn_cast(arg_tuple->elements()[i]); + shapes.emplace_back(arg_tensor->shape()->shape()); + } + } + outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); + return outputs; + } + for (auto &it : ge_tensors) { + auto tensor = transform::TransformUtil::ConvertGeTensor(it); + if (tensor != nullptr) { + outputs.emplace_back(tensor); + } + } + return outputs; +} + +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { + MS_LOG(INFO) << "RunOpInGe start"; + MS_EXCEPTION_IF_NULL(op_exec_info); + MS_EXCEPTION_IF_NULL(status); + + // returns a null py::tuple on error + py::tuple err_ret(0); + auto op_name = op_exec_info->op_name; + transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); + if (adapter == nullptr) { + MS_LOG(ERROR) << "Unable to find GE Adapter for " << ((std::string)py::str(op_name)); + *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; + return std::move(err_ret); + } + + std::vector inputs{}; + ToTensorPtr(op_exec_info, &inputs); + // convert me attr to ge AttrValue + PynativeStatusCode ret = ConvertAttributes(op_exec_info, inputs); + if (ret != PYNATIVE_SUCCESS) { + *status = ret; + return std::move(err_ret); + } + // run graph + transform::RunOptions run_options; + run_options.name = SINGLE_OP_GRAPH; + std::vector ge_inputs; + std::vector ge_outputs; + transform::GraphRunnerOptions graph_runner_options; + graph_runner_options.options["ge.trainFlag"] = "1"; + auto graph_runner = std::make_shared(graph_runner_options); + transform::Status run_ret; + { + // Release GIL before calling into (potentially long-running) C++ code + py::gil_scoped_release release; + run_ret = graph_runner->RunGraph(run_options, ge_inputs, &ge_outputs); + } + if (run_ret != transform::Status::SUCCESS) { + MS_LOG(ERROR) << "GraphRunner fails to run graph"; + *status = PYNATIVE_GRAPH_GE_RUN_ERR; + return std::move(err_ret); + } + + std::vector graph_outputs = ConvertOutputTensors(op_exec_info, ge_outputs); + size_t output_size = graph_outputs.size(); + py::tuple result(output_size); + for (size_t i = 0; i < output_size; i++) { + MS_EXCEPTION_IF_NULL(graph_outputs[i]); + result[i] = *graph_outputs[i]; + } + + *status = PYNATIVE_SUCCESS; + MS_LOG(INFO) << "RunOpInGe end"; + return std::move(result); +} +} // namespace pynative +} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h new file mode 100644 index 0000000000..2978278489 --- /dev/null +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute_ge.h 
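
The header below exposes the GE path's entry points. One detail worth calling out from the file above: ConvertOutputTensors leans on the inferred abstract to recover output shapes, treating a single AbstractTensor and an AbstractTuple of tensors uniformly as a shape list paired one-to-one with GE's output tensors. A minimal sketch of that shape gathering, with illustrative types rather than the real AbstractBase hierarchy:

    #include <vector>

    using ShapeSketch = std::vector<int>;
    struct AbsSketch {
      bool is_tuple;
      ShapeSketch shape;               // used when is_tuple is false
      std::vector<ShapeSketch> elems;  // used when is_tuple is true
    };

    // One shape per output, in order; an unknown abstract would fall back to
    // converting GE tensors with their own metadata.
    std::vector<ShapeSketch> GatherShapes(const AbsSketch &abs) {
      if (!abs.is_tuple) {
        return {abs.shape};
      }
      return abs.elems;
    }
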
@@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ +#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ + +#include +#include +#include +#include +#include + +#include "pipeline/pynative/base.h" +#include "transform/graph_ir/convert.h" +#include "transform/graph_ir/graph_runner.h" +#include "transform/graph_ir/types.h" +#include "utils/context/ms_context.h" + +using GeTensor = ge::Tensor; +using GeTensorPtr = std::shared_ptr; +using GeGraph = ge::Graph; +using GeGraphPtr = std::shared_ptr; + +namespace mindspore { +namespace pynative { +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph); + +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); +} // namespace pynative +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc deleted file mode 100644 index 47881e4b91..0000000000 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/remove_value_node_dup.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "ir/manager.h" -#include "optimizer/cse.h" -#include "utils/log_adapter.h" -#include "utils/hashing.h" - -namespace mindspore { -namespace pipeline { -void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, - HashValue *const hash_value) { - const auto &to_check_value = GetValueNode(node); - MS_EXCEPTION_IF_NULL(to_check_value); - - // Calculate hash value. - size_t h; - auto hash_iter = hash_value->find(node); - if (hash_iter == hash_value->end()) { - h = hash_combine(to_check_value->hash(), (opt::AbsOf(node)->hash())); - (*hash_value)[node] = h; - } else { - h = hash_iter->second; - } - - auto bucket_iter = hash_cache->find(h); - if (bucket_iter == hash_cache->end()) { - // Meet for the first time, add bucket. - (*hash_cache)[h] = {node}; - return; - } - - auto &bucket = bucket_iter->second; - // Check if need to replace node with value node already met. - for (const auto &v : bucket) { - // Already met and cached. 
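
// TryToDoReplace dedupes equal value nodes in two steps: a combined hash of (value, abstract) picks a
// bucket, then bucket members are compared by deep value equality before the manager replaces the node.
// The same probe with std containers (HashCombine mirrors the usual boost-style mixer; Canonical is an
// illustrative name):
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

inline std::size_t HashCombine(std::size_t a, std::size_t b) {
  return a ^ (b + 0x9e3779b9 + (a << 6) + (a >> 2));
}

// Returns the canonical representative for v, inserting it if unseen.
const std::string &Canonical(std::unordered_map<std::size_t, std::vector<std::string>> *cache, std::size_t h,
                             const std::string &v) {
  auto &bucket = (*cache)[h];
  for (const auto &seen : bucket) {
    if (seen == v) {
      return seen;  // deep equality, not just matching hashes
    }
  }
  bucket.push_back(v);
  return bucket.back();
}
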
- if (v == node) { - return; - } - const auto &existed_value = GetValueNode(v); - MS_EXCEPTION_IF_NULL(existed_value); - auto equal = [&]() -> bool { - if (existed_value->isa() && to_check_value->isa()) { - return existed_value->cast()->ValueEqual(*(to_check_value->cast())); - } - return *existed_value == *to_check_value; - }; - if (equal()) { - (void)manager->Replace(node, v); - return; - } - } - - // Meet for the first time, append node to bucket. - bucket.emplace_back(node); -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc deleted file mode 100644 index cd79b2466a..0000000000 --- a/mindspore/ccsrc/pipeline/resource.cc +++ /dev/null @@ -1,260 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/resource.h" -#include "pipeline/pipeline.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "debug/draw.h" -#include "debug/trace.h" -#include "ir/dtype.h" -#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" -#include "utils/graph_utils.h" -#include "optimizer/ad/dfunctor.h" -#include "vm/segment_runner.h" - -namespace mindspore { -// namespace to support opmap definition -namespace pipeline { - -MethodMap &GetMethodMap() { - static MethodMap method_map = { - {kObjectTypeString, - { - {"__bool__", std::string("str_bool")} // C.str_bool - }}, - {kMetaTypeNone, - { - {"__bool__", std::string("none_bool")} // C.none_bool - }}, - {kNumberTypeBool, - { - {"__and__", prim::kPrimBoolAnd}, // P.bool_and - {"__or__", prim::kPrimBoolOr}, // P.bool_or - {"__eq__", prim::kPrimBoolEq}, // P.bool_eq - {"__ne__", std::string("bool_ne")}, // C.bool_ne - {"__bool__", prim::kPrimIdentity} // P.identity - }}, - {kNumberTypeInt, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul - {"__floordiv__", std::string("int_floordiv")}, // C.int_floordiv - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow - {"__floor__", prim::kPrimIdentity}, // P.identity - {"__trunc__", prim::kPrimIdentity}, // P.identity - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt - {"__le__", prim::kPrimScalarLe}, // P.scalar_le - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array - }}, - {kNumberTypeUInt, - { - {"__add__", 
prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__truediv__", std::string("int_truediv")}, // C.int_truediv - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimIdentity}, // P.identity, - {"__trunc__", prim::kPrimIdentity}, // P.identity, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("int_bool")}, // C.int_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kNumberTypeFloat, - { - {"__add__", prim::kPrimScalarAdd}, // P.scalar_add, - {"__sub__", prim::kPrimScalarSub}, // P.scalar_sub, - {"__mul__", prim::kPrimScalarMul}, // P.scalar_mul, - {"__floordiv__", std::string("float_floordiv")}, // C.float_floordiv - {"__truediv__", prim::kPrimScalarDiv}, // P.scalar_div, - {"__mod__", prim::kPrimScalarMod}, // P.scalar_mod, - {"__pow__", prim::kPrimScalarPow}, // P.scalar_pow, - {"__floor__", prim::kPrimScalarFloor}, // P.scalar_floor, - {"__trunc__", prim::kPrimScalarTrunc}, // P.scalar_trunc, - {"__pos__", prim::kPrimScalarUadd}, // P.scalar_uadd, - {"__neg__", prim::kPrimScalarUsub}, // P.scalar_usub, - {"__eq__", prim::kPrimScalarEq}, // P.scalar_eq, - {"__ne__", prim::kPrimScalarNe}, // P.scalar_ne, - {"__lt__", prim::kPrimScalarLt}, // P.scalar_lt, - {"__gt__", prim::kPrimScalarGt}, // P.scalar_gt, - {"__le__", prim::kPrimScalarLe}, // P.scalar_le, - {"__ge__", prim::kPrimScalarGe}, // P.scalar_ge, - {"__bool__", std::string("float_bool")}, // C.float_bool - {"__ms_to_array__", prim::kPrimScalarToArray}, // P.scalar_to_array, - }}, - {kObjectTypeTuple, - { - {"__len__", prim::kPrimTupleLen}, // P.tuple_len, - {"__getitem__", prim::kPrimTupleGetItem}, // P.tuple_getitem, - {"__setitem__", prim::kPrimTupleSetItem}, // P.tuple_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity, - {"__ms_next__", std::string("tuple_next")}, // C.tuple_next, - {"__ms_hasnext__", std::string("tuple_hasnext")}, // C.tuple_hasnext - {"__bool__", std::string("tuple_bool")} // C.tuple_bool - }}, - {kObjectTypeList, - { - {"__len__", prim::kPrimListLen}, // P.list_len, - {"__getitem__", prim::kPrimListGetItem}, // P.list_getitem, - {"__setitem__", prim::kPrimListSetItem}, // P.list_setitem, - {"__ms_iter__", prim::kPrimIdentity}, // P.identity - {"__ms_next__", std::string("list_next")}, // C.list_next - {"append", std::string("list_append")}, // C.list_next - {"__bool__", std::string("list_bool")}, // C.list_bool - {"__ms_hasnext__", std::string("list_hasnext")}, - }}, - {kObjectTypeDictionary, - { - {"__len__", prim::kPrimDictLen}, // P.dict_len - {"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem - {"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem, - {"__bool__", std::string("dict_bool")} // C.dict_bool - }}, - {kObjectTypeTensorType, - { - {"__add__", std::string("add")}, // C.add - {"__sub__", std::string("sub")}, // C.sub - {"__mul__", std::string("mul")}, // C.mul - {"__truediv__", std::string("truediv")}, // 
C.truediv - {"__floordiv__", std::string("floordiv")}, // C.floordiv - {"__mod__", std::string("mod")}, // C.mod - {"__pow__", std::string("pow_")}, // C.pow - {"__floor__", std::string("array_floor")}, // C.array_floor - {"__trunc__", std::string("array_trunc")}, // C.array_trunc - {"__pos__", std::string("array_uadd")}, // C.array_uadd - {"__neg__", std::string("array_usub")}, // C.array_usub - {"__eq__", std::string("eq")}, // C.eq - {"__ne__", std::string("ne")}, // C.ne - {"__lt__", std::string("lt")}, // C.lt - {"__gt__", std::string("gt")}, // C.gt - {"__le__", std::string("le")}, // C.le - {"__ge__", std::string("ge")}, // C.ge - {"__matmul__", prim::kPrimDot}, // P.dot, - {"__len__", prim::kPrimArrayLen}, // P.array_len, - {"__getitem__", prim::kPrimArrayGetItem}, // P.array_getitem, - {"__setitem__", prim::kPrimArraySetItem}, // P.array_setitem, - {"__ms_iter__", std::string("array_iter")}, // C.array_iter - {"__ms_to_array__", prim::kPrimIdentity}, // P.identity, - {"item", prim::kPrimArrayToScalar}, // P.array_to_scalar, - {"transpose", std::string("transpose")}, // P.transpose - {"__bool__", std::string("tensor_bool")}, // C.tensor_bool - }}, - {kObjectTypeIndexedSlicesType, - { - {"values", prim::kPrimIndexedSlicesGetValues}, // F.indexed_slices_get_values - {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices - {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape - }}, - {kObjectTypeJTagged, {}}, - {kObjectTypeSymbolicKeyType, {}}, - {kObjectTypeEnvType, {}}}; - return method_map; -} - -Resource::Resource(const py::object &obj) - : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), - input_(obj), - is_cleaned_(false) {} - -Resource::~Resource() { - MS_LOG(DEBUG) << "Resource clear"; - - // If exit normally, these global variables will be cleaned - // in Resource::Clean call by MsPipeline::Compile, but if exit with MS_LOGEXCEPTION, - // these global variables may not being cleaned, it may - // cause segmentfault when free python object inside these global variables - // after python interpreter got freed, so these global variables - // are cleaned here. - // So if exit normally, these global variable will be cleaned twice, - // care be taken to prevent double free in the following functions. - if (!is_cleaned_) { - try { - Clean(); - } catch (const std::exception &e) { - MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); - } catch (...) 
{ - MS_LOG(ERROR) << "Exception when cleaning resource."; - } - } -} - -bool Resource::IsTypeInMethodMap(const TypeId &type) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter != method_map.end()) { - return true; - } - return false; -} - -Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { - TypeId type_id = NormalizeTypeId(type); - const MethodMap &method_map = GetMethodMap(); - auto iter = method_map.find(static_cast(type_id)); - if (iter == method_map.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; - return Any(); - } - - auto iter_map = iter->second.find(name); - if (iter_map == iter->second.end()) { - MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name; - return Any(); - } - return iter_map->second; -} - -void Resource::Clean() { - // AbstractTensor->elements() will be saved in AbstractBasePtrList - args_spec_.clear(); - input_ = py::none(); - // Context with AbstractBasePtrList may be saved in GraphEvaluator - // some Evaluator like ResolveEvaluator may save Python object in cache, - // it should be cleaned before Python Interpreter destructed. - MS_EXCEPTION_IF_NULL(engine_); - engine_->ClearEvaluatorCache(); - // clean static variable to prevent from crash. As static variable is released after - // Python threads is released. - parse::data_converter::ClearObjectCache(); - parse::Parser::CleanParserResource(); - parse::CleanDataClassToClassMap(); - trace::ClearTraceStack(); - is_cleaned_ = true; -} -} // namespace pipeline -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h deleted file mode 100644 index 0c1348fd94..0000000000 --- a/mindspore/ccsrc/pipeline/resource.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ -#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_ - -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -#include "utils/any.h" -#include "utils/profile.h" -#include "ir/manager.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/static_analysis.h" -#include "./common.h" - -namespace mindspore { -namespace pipeline { - -namespace py = pybind11; - -const char kBackend[] = "backend"; -const char kStepParallelGraph[] = "step_parallel"; -const char kOutput[] = "output"; - -class InferenceResource; - -using MethodMap = std::unordered_map>; - -MethodMap &GetMethodMap(); - -class ResourceBase { - public: - ResourceBase() { manager_ = MakeManager(); } - - virtual ~ResourceBase() = default; - - FuncGraphManagerPtr manager() { return manager_; } - // set a manager defined outside which will not manage the graphs. 
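
// ResourceBase is a string-keyed property bag: pipeline stages publish results (kBackend, kOutput, ...)
// into an Any map and later stages fetch them by key, failing loudly when the key is missing. The same
// contract with std::any (C++17), as an illustrative stand-in for the in-tree Any class:
#include <any>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

struct ResultBagSketch {
  std::unordered_map<std::string, std::any> results;

  void Set(const std::string &key, std::any value) { results[key] = std::move(value); }

  const std::any &Get(const std::string &key) const {
    auto it = results.find(key);
    if (it == results.end()) {
      throw std::runtime_error("key is not in the resource list: " + key);
    }
    return it->second;
  }
};
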
diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h
deleted file mode 100644
index 0c1348fd94..0000000000
--- a/mindspore/ccsrc/pipeline/resource.h
+++ /dev/null
@@ -1,120 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_
-#define MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_
-
-#include
-#include
-#include
-#include
-#include
-
-#include "pybind11/pybind11.h"
-#include "pybind11/stl.h"
-
-#include "utils/any.h"
-#include "utils/profile.h"
-#include "ir/manager.h"
-#include "pipeline/static_analysis/prim.h"
-#include "pipeline/static_analysis/static_analysis.h"
-#include "./common.h"
-
-namespace mindspore {
-namespace pipeline {
-
-namespace py = pybind11;
-
-const char kBackend[] = "backend";
-const char kStepParallelGraph[] = "step_parallel";
-const char kOutput[] = "output";
-
-class InferenceResource;
-
-using MethodMap = std::unordered_map<int, std::unordered_map<std::string, Any>>;
-
-MethodMap &GetMethodMap();
-
-class ResourceBase {
- public:
-  ResourceBase() { manager_ = MakeManager(); }
-
-  virtual ~ResourceBase() = default;
-
-  FuncGraphManagerPtr manager() { return manager_; }
-  // Set a manager defined outside, which will not manage the graphs.
-  void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; }
-
-  std::unordered_map<std::string, Any> &results() { return results_; }
-
-  void SetResult(const std::string &key, const Any &value) { results_[key] = value; }
-
-  Any GetResult(const std::string &key) {
-    if (results_.count(key) == 0) {
-      MS_LOG(EXCEPTION) << "This key is not in the resource list: " << key;
-    }
-    return results_[key];
-  }
-
-  bool HasResult(const std::string &key) const { return results_.count(key) != 0; }
-
-  std::unordered_map<std::string, Any> results_;
-
- protected:
-  FuncGraphManagerPtr manager_;
-};
-
-using ResourceBasePtr = std::shared_ptr<ResourceBase>;
-
-class Resource : public ResourceBase {
- public:
-  explicit Resource(const py::object &obj = py::none());
-
-  ~Resource() override;
-
-  abstract::AnalysisEnginePtr engine() { return engine_; }
-
-  static bool IsTypeInMethodMap(const TypeId &type);
-
-  static Any GetMethodPtr(const TypeId &type, const std::string &name);
-
-  const py::object &input() const { return input_; }
-
-  FuncGraphPtr func_graph() const { return func_graph_; }
-  void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
-
-  const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; }
-  void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; }
-
-  // Reclaim resources and clear the cache.
-  // ExecutorPy::Compile() can be called multiple times, so the cache
-  // should be cleared.
-  void Clean();
-
- private:
-  abstract::AnalysisEnginePtr engine_;
-  FuncGraphPtr func_graph_;
-  abstract::AbstractBasePtrList args_spec_;
-  py::object input_;
-  bool is_cleaned_;
-};
-
-using ResourcePtr = std::shared_ptr<Resource>;
-
-}  // namespace pipeline
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PIPELINE_RESOURCE_H_
diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc
deleted file mode 100644
index cd768f7515..0000000000
--- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc
+++ /dev/null
@@ -1,361 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "pipeline/static_analysis/abstract_function.h" - -#include - -#include "pipeline/static_analysis/static_analysis.h" - -namespace mindspore { -namespace abstract { -class Evaluator; -class AnalysisEngine; - -AbstractFunctionPtr AbstractFunction::MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list) { - if (func_list.size() == 1) { - return func_list[0]; - } - return std::make_shared(func_list); -} - -AbstractFunctionPtr AbstractFuncAtom::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (*this_func == *other) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncAtom::Visit(std::function visit_func) const { - visit_func(const_cast(this)->shared_from_base()); -} - -bool AbstractFuncAtom::operator==(const AbstractFunction &other) const { return this == &other; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list) { func_list_ = func_list; } - -AbstractFuncUnion::AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second) { - AbstractFuncAtomPtrList new_func_list; - auto build_func_list = [&new_func_list](const AbstractFuncAtomPtr &func) { new_func_list.push_back(func); }; - - first->Visit(build_func_list); - second->Visit(build_func_list); - func_list_ = new_func_list; -} - -std::string AbstractFuncUnion::ToString() const { - std::ostringstream buffer; - buffer << "AbstractFuncUnion({"; - int i = 0; - for (const auto &func : func_list_) { - MS_EXCEPTION_IF_NULL(func); - buffer << "[" << i << "]: " << func->ToString() << ", "; - i++; - } - buffer << "})"; - return buffer.str(); -} - -bool AbstractFuncUnion::IsSuperSet(const AbstractFunctionPtr &other) { - MS_EXCEPTION_IF_NULL(other); - std::vector is_in_list; - auto build_in_list = [this, &is_in_list](const AbstractFuncAtomPtr &func) { - auto iter = find(func_list_.begin(), func_list_.end(), func); - if (iter == func_list_.end()) { - is_in_list.push_back(false); - } - return true; - }; - other->Visit(build_in_list); - return std::all_of(is_in_list.begin(), is_in_list.end(), [](bool is_in) { return is_in; }); -} - -AbstractFunctionPtr AbstractFuncUnion::Join(const AbstractFunctionPtr &other) { - auto this_func = shared_from_base(); - if (other->isa()) { - if (IsSuperSet(other)) { - return this_func; - } - return std::make_shared(this_func, other); - } - auto other_union = dyn_cast(other); - if (other_union->IsSuperSet(this_func)) { - return other; - } - return std::make_shared(this_func, other); -} - -void AbstractFuncUnion::Visit(std::function visit_func) const { - for (AbstractFuncAtomPtr poss : func_list_) { - visit_func(poss); - } -} - -bool AbstractFuncUnion::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_union = static_cast(&other); - if (func_list_.size() != other_union->func_list_.size()) { - return false; - } - if (func_list_ == other_union->func_list_) { - return true; - } - return false; -} - -std::size_t AbstractFuncUnion::hash() const { - std::size_t hash_sum = 0; - for (auto f : func_list_) { - hash_sum = hash_combine(hash_sum, f->hash()); - } - return hash_sum; -} - -EvaluatorPtr PrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool 
PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_prim = static_cast(&other); - if (prim_ == other_prim->prim_ && tracking_id() == other_prim->tracking_id()) { - return true; - } - return false; -} - -std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); } - -EvaluatorPtr FuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_fg = static_cast(&other); - if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) { - return true; - } - return false; -} - -std::size_t FuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), func_graph_->hash()); - hash_value = hash_combine(hash_value, context_->hash()); - return hash_value; -} - -std::string FuncGraphAbstractClosure::ToString() const { - std::stringstream ss; - ss << "FuncGraphAbstractClosure: " - << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); - return ss.str(); -} - -EvaluatorPtr MetaFuncGraphAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_meta_fg = static_cast(&other); - if (meta_func_graph_ == other_meta_fg->meta_func_graph_) { - return true; - } - return false; -} - -std::size_t MetaFuncGraphAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), meta_func_graph_->hash()); - return hash_value; -} - -std::string MetaFuncGraphAbstractClosure::ToString() const { - return "MetaFuncGraphAbstractClosure: " + meta_func_graph_->name(); -} - -bool PartialAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_partial = static_cast(&other); - if (fn_ != other_partial->fn_) { - return false; - } - if (args_spec_list_.size() != other_partial->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_partial->args_spec_list_) { - return true; - } - return false; -} - -std::size_t PartialAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -EvaluatorPtr PartialAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -std::string PartialAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "PartialAbstractClosure(" << fn_->ToString() << "("; - for (auto arg : args_spec_list_) { - buffer << arg->ToString() << ", "; - } - buffer << "))"; - return buffer.str(); -} - -EvaluatorPtr JTransformedAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool JTransformedAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_transformed = static_cast(&other); - if (fn_ == other_transformed->fn_) { - return true; - } - return false; -} - -std::size_t 
JTransformedAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), fn_->hash()); - return hash_value; -} - -EvaluatorPtr VirtualAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool VirtualAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_virtual = static_cast(&other); - if (output_ != other_virtual->output_) { - return false; - } - if (args_spec_list_.size() != other_virtual->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_virtual->args_spec_list_) { - return true; - } - return false; -} - -std::size_t VirtualAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), output_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string VirtualAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "VirtualAbstractClosure(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -EvaluatorPtr TypedPrimitiveAbstractClosure::GetEvaluator(AnalysisEnginePtr engine) { - MS_EXCEPTION_IF_NULL(engine); - - return engine->_GetEvaluatorFor(shared_from_base()); -} - -bool TypedPrimitiveAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - auto other_typed = static_cast(&other); - if (output_ != other_typed->output_) { - return false; - } - if (prim_ != other_typed->prim_) { - return false; - } - if (args_spec_list_.size() != other_typed->args_spec_list_.size()) { - return false; - } - if (args_spec_list_ == other_typed->args_spec_list_) { - return true; - } - return false; -} - -std::size_t TypedPrimitiveAbstractClosure::hash() const { - auto hash_value = hash_combine(tid(), prim_->hash()); - hash_value = hash_combine(hash_value, AbstractBasePtrListHash(args_spec_list_)); - return hash_value; -} - -std::string TypedPrimitiveAbstractClosure::ToString() const { - std::ostringstream buffer; - buffer << "TypedPrimitiveAbstractClosure: primitive: " << prim_->name() << "(args: {"; - int i = 0; - for (const auto &arg : args_spec_list_) { - MS_EXCEPTION_IF_NULL(arg); - buffer << "[" << i << "]: " << arg->ToString() << ", "; - i++; - } - buffer << "}, output: " << output_->ToString() << ")"; - return buffer.str(); -} - -bool DummyAbstractClosure::operator==(const AbstractFunction &other) const { - if (!other.isa()) { - return false; - } - return true; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc deleted file mode 100644 index 14ebeb0fc7..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc +++ /dev/null @@ -1,404 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/evaluator.h" - -#include -#include - -#include "ir/func_graph_cloner.h" -#include "abstract/utils.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -namespace { -string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, - const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(evaluator); - std::stringstream ss; - if (out_conf != nullptr) { - ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); - } - for (size_t i = 0; i < arg_spec_list.size(); i++) { - ss << evaluator->ToString() << " input[" << i << "] abstract value: " << arg_spec_list[i]->ToString(); - } - return ss.str(); -} - -void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(evaluator); - if (out_conf != nullptr) { - auto node = out_conf->node(); - if (IsValueNode(node)) { - MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope() - << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); - } else { - MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString() - << ", with debug info: " << trace::GetDebugInfo(node->debug_info()); - } - } -} -} // namespace - -AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list) { - AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); - normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); - FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); - MS_EXCEPTION_IF_NULL(parent_context_); - AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); - return context; -} - -static std::vector FastShadowSort(const AnfNodePtr &ret_node) { - auto current_func_graph = ret_node->func_graph(); - MS_EXCEPTION_IF_NULL(current_func_graph); - - std::vector sorted_nodes; - auto seen = NewSeenGeneration(); - std::size_t index = 0; - sorted_nodes.emplace_back(ret_node); - while (index < sorted_nodes.size()) { - auto current = sorted_nodes[index]; - index++; - MS_EXCEPTION_IF_NULL(current); - if (current->isa()) { - auto &inputs = current->cast()->inputs(); - for (auto it = inputs.begin(); it != inputs.end(); it++) { - AnfNodePtr input = *it; - if (input != nullptr && input->isa() && input->seen_ != seen && - input->func_graph() == current_func_graph) { - sorted_nodes.emplace_back(input); - input->seen_ = seen; - } - } - } - } - return sorted_nodes; -} - -EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); - MS_EXCEPTION_IF_NULL(fg); - std::size_t nargs = fg->parameters().size(); - if (args_spec_list.size() != nargs) { - MS_EXCEPTION(TypeError) << "Function " << fg->ToString() << ", The number of parameters of this function is " - << 
fg->parameters().size() << ", but the number of provided arguments is " - << args_spec_list.size() << ". NodeInfo: " << trace::GetDebugInfo(fg->debug_info()); - } - MS_EXCEPTION_IF_NULL(parent_context_); - MS_EXCEPTION_IF_NULL(engine); - graph_context_ = parent_context_->NewFuncGraphContext(fg, args_spec_list); - const auto ¶meters = fg->parameters(); - for (size_t i = 0; i < nargs; i++) { - const auto &arg = args_spec_list[i]; - const auto &node = parameters[i]; - AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_); - engine->cache().set_value(conf, std::make_shared(arg, nullptr)); - } - const AnfNodePtr &func_node = fg->get_return(); - - MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg->ToString() - << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); - AbstractBasePtr ret_base = nullptr; - std::vector nodes = FastShadowSort(func_node); - for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { - const auto &node = *it; - AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); - MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString(); - ret_base = engine->GetEvaluatedValue(node_conf)->abstract(); - MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString() - << ", abstract: " << ret_base->ToString(); - } - - MS_EXCEPTION_IF_NULL(ret_base); - MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() - << ", is stub: " << fg->stub(); - if (fg->stub()) { - return std::make_shared(std::make_shared(), nullptr); - } - return std::make_shared(ret_base, nullptr); -} - -AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { - MS_EXCEPTION_IF_NULL(func_graph_); - if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - AbstractBasePtrList broaded_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(arg); - return arg->Broaden(); - }); - MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) - << ", broaded: " << mindspore::ToString(broaded_list); - return broaded_list; - } - return args_spec_list; -} - -AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { - MS_EXCEPTION_IF_NULL(func_graph_); - if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { - return args_spec_list; - } - if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { - if (parent_context_) { - MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() - << ", context: " << parent_context_->ToString(); - auto last_context = parent_context_->Filter(func_graph_); - if (last_context && last_context->func_graph() == func_graph_) { - MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); - MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); - MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); - // Join the last eval arguments and current arguments to check if there are loop variant. 
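// Illustration (hypothetical values, not from this patch): for a loop body such as
// i = i + 1, the previous pass may have seen AbstractScalar(value=0) where the current
// pass sees AbstractScalar(value=1). AbstractJoin keeps the common type but drops the
// conflicting constant (roughly AbstractScalar(kAnyValue, Int32)), so the joined list
// differs from args_spec_list and the IGNORE_VALUES broadening below is triggered.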
- auto joined_args_spec_list = AbstractJoin(args_spec_list, last_context->args_spec_list()); - MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list); - // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. - if (!(joined_args_spec_list == args_spec_list)) { - func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; - } - return joined_args_spec_list; - } - } - if (trace_.size() != 0) { - MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); - MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); - // Join the last eval arguments and current arguments to check if there are loop variant. - auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); - // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. - if (!(joined_args_spec_list == args_spec_list)) { - trace_.push_back(joined_args_spec_list); - func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); - MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; - } - MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); - return joined_args_spec_list; - } else { - trace_.push_back(args_spec_list); - } - } - return args_spec_list; -} - -FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - auto iter = func_graph_cache_.find(args_spec_list); - FuncGraphPtr ret = nullptr; - if (iter == func_graph_cache_.end()) { - auto fg = func_graph(); - MS_EXCEPTION_IF_NULL(fg); - TraceManager::DebugTrace(std::make_shared(fg->debug_info())); - FuncGraphPtr generated_graph = fg->GenerateGraph(args_spec_list); - TraceManager::EndTrace(); - func_graph_cache_[args_spec_list] = generated_graph; - MS_EXCEPTION_IF_NULL(engine); - engine->func_graph_manager()->AddFuncGraph(generated_graph); - ret = generated_graph; - } else { - ret = iter->second; - } - - // For the top graph, if it is replaced by generated graph, update the top graph to the new one. 
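// (For orientation, a hypothetical example of the caching above: func_graph_cache_ is
// keyed by the call's AbstractBasePtrList, so f(Tensor[2x3]) and f(Tensor[4x5]) each
// get their own specialized graph from GenerateGraph, while a repeated call with
// identical abstract arguments reuses the cached graph.)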
- if (parse::Parser::GetTopFuncGraph() == func_graph()) { - if (ret != func_graph()) { - parse::Parser::UpdateTopFuncGraph(ret); - } - } - return ret; -} - -FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { - auto iter = func_graph_cache_.find(args_spec_list); - if (iter != func_graph_cache_.end()) { - return iter->second; - } - - MS_EXCEPTION_IF_NULL(meta_func_graph_); - FuncGraphPtr generated_func_graph = nullptr; - if (this->bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); - TraceManager::EndTrace(); - } else { - generated_func_graph = meta_func_graph_->GenerateFuncGraph(args_spec_list); - } - - FuncGraphPtr cloned_func_graph = BasicClone(generated_func_graph); - func_graph_cache_[args_spec_list] = cloned_func_graph; - MS_EXCEPTION_IF_NULL(engine); - engine->func_graph_manager()->AddFuncGraph(cloned_func_graph); - return cloned_func_graph; -} - -EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) { - const std::string &evaluator_name = ToString(); - - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - args_spec_list = NormalizeArgs(args_spec_list); - args_spec_list = BroadenUndeterminedArgs(args_spec_list); - trace::TraceGraphEvalEnter(shared_from_base(), out_conf); - MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base(), args_spec_list, out_conf); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter == cache_->end()) { - MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval()."; - EvalResultPtr ret = Eval(engine, args_spec_list); - if (ret->abstract() == nullptr) { - EvalFailLogging(shared_from_base(), args_spec_list, out_conf); - MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; - } - MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << "."; - (*cache_)[args_spec_list] = ret; - trace::TraceGraphEvalLeave(shared_from_base()); - return ret; - } else { - MS_EXCEPTION_IF_NULL(iter->second); - MS_EXCEPTION_IF_NULL(iter->second->abstract()); - MS_LOG(DEBUG) << evaluator_name << " cache hit. 
return: " << iter->second->abstract()->ToString() << "."; - trace::TraceGraphEvalLeave(shared_from_base()); - return iter->second; - } -} - -EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - EvalResultPtr ret = EvalPrim(engine, args_spec_list); - return ret; -} - -EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - if (args_conf_list.size() == 0) { - MS_LOG(EXCEPTION) << "Size should greater than 0"; - } - EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf); - // No need to cache. - return ret; -} - -EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - EvalResultPtr ret = EvalPrim(args_conf_list); - return ret; -} - -EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf); - // Don't lookup from cache, as different out_conf with same node but different context - // may add different entry to anfnode_config_map_, like getattr primitive. - (*cache_)[args_spec_list] = ret; - return ret; -} - -EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter != cache_->end()) { - return iter->second; - } - - ConfigPtrList partial_args_conf_list; - // Join arguments in partial and the rest arguments from args_conf_list. 
- (void)std::transform(args_spec_list_.begin(), args_spec_list_.end(), std::back_inserter(partial_args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf); - - (*cache_)[args_spec_list] = ret; - return ret; -} - -EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - MS_EXCEPTION_IF_NULL(cache_); - auto iter = cache_->find(args_spec_list); - if (iter != cache_->end()) { - return iter->second; - } - - // Call the original evaluator, get the result: y = f(x) - EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr); - // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input - // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) - AbstractBasePtrList bparams; - bparams.push_back(SensitivityTransform(orig_func_)); - (void)std::transform( - args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), - [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); - AbstractBasePtr bparams_final = std::make_shared(bparams); - AbstractFunctionPtr bprop = - std::make_shared(SensitivityTransform(result->abstract()), bparams_final); - - // J(f)(J(x)) return a tuple (y, bprop_f) - AbstractBasePtrList jargs = {result->abstract(), bprop}; - AbstractBasePtr jtuple = std::make_shared(jargs); - auto infer_reuslt = std::make_shared(jtuple, std::make_shared()); - (*cache_)[args_spec_list] = infer_reuslt; - return infer_reuslt; -} - -EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { - if (args_spec_list.size() != args_spec_list_.size()) { - MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() - << ", arguments no: " << args_spec_list.size(); - } - // Check each parameter and argument match; - for (std::size_t i = 0; i < args_spec_list.size(); i++) { - MS_EXCEPTION_IF_NULL(args_spec_list[i]); - (void)args_spec_list[i]->Join(args_spec_list_[i]); - } - return std::make_shared(output_, std::make_shared()); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h deleted file mode 100644 index 079c1aac61..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h +++ /dev/null @@ -1,330 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ -#define PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ - -#include -#include -#include -#include - -#include "pipeline/static_analysis/static_analysis.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace abstract { -using EvaluatorCacheMap = - std::unordered_map; -using EvaluatorCacheMapPtr = std::shared_ptr; - -using EvaluatorAttrMap = - std::unordered_map; -using EvaluatorAttrMapPtr = std::shared_ptr; - -class Evaluator : public Base { - public: - explicit Evaluator(const std::string &id) - : cache_(std::make_shared()), - attr_cache_(std::make_shared()), - identifier_(id) {} - ~Evaluator() override = default; - MS_DECLARE_PARENT(Evaluator, Base); - - // difference between Run() and Eval(): - // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr. - // Run() will modify cache_ member, so it cannot marked as const; - virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); - - virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; - - virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } - - virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { - return args_spec_list; - } - - virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) { - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - bool enable_sparse = context->enable_sparse(); - if (!enable_sparse) { - return nullptr; - } - - auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) { - if (arg->BuildType()->type_id() == kObjectTypeUndeterminedType) { - return true; - } - return false; - }); - if (is_abstract) { - MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result"; - return std::make_shared(std::make_shared(), std::make_shared()); - } - return nullptr; - } - - std::string ToString() const override { return identifier_; } - - virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } - - virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); } - - EvaluatorCacheMapPtr &cache() { return cache_; } - EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; } - - EvaluatorCacheMapPtr cache_; - EvaluatorAttrMapPtr attr_cache_; - std::string identifier_; - - AnfNodeWeakPtr bound_node_; -}; - -class PrimEvaluator : public Evaluator { - public: - explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} - ~PrimEvaluator() override = default; - MS_DECLARE_PARENT(PrimEvaluator, Evaluator); - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } -}; - -class TrivialPrimEvaluator : public PrimEvaluator { - public: - explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {} - ~TrivialPrimEvaluator() override = default; - 
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
-  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
-  virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
-};
-
-class TransitionPrimEvaluator : public PrimEvaluator {
- public:
-  explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
-  ~TransitionPrimEvaluator() override = default;
-  MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
-  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
-  // Parameter in_conf0 is the first element in args_conf_list.
-  virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
-                                 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
-};
-
-class SymbolicPrimEvaluator : public PrimEvaluator {
- public:
-  explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
-  ~SymbolicPrimEvaluator() override = default;
-  MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
-  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
-  virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
-};
-
-// Evaluators will be stored in AnalysisEngine.constructors_
-using EvaluatorPtrList = std::vector<EvaluatorPtr>;
-
-class DummyEvaluator : public Evaluator {
- public:
-  DummyEvaluator() : Evaluator("dummy") {}
-  ~DummyEvaluator() override = default;
-  MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
-  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
-};
-
-// Wrap another evaluator to track a subset of uses.
-// A TrackedEvaluator has its own cache that maps possible calls to
-// their results, but is ultimately backed by a different evaluator.
-// Multiple TrackedEvaluators can be backed by the same Evaluator.
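// A sketch of that behaviour (hypothetical call sequence, not part of this header):
//
//   EvaluatorPtr e = std::make_shared<SomeEvaluator>();   // shared backing evaluator
//   auto t1 = std::make_shared<TrackedEvaluator>(e);
//   auto t2 = std::make_shared<TrackedEvaluator>(e);
//   t1->Run(engine, args_conf_list, out_conf);  // result cached in t1's cache_
//   t2->Run(engine, args_conf_list, out_conf);  // forwarded to e again; cached in t2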
-class TrackedEvaluator : public Evaluator { - public: - explicit TrackedEvaluator(const EvaluatorPtr &subinf) : Evaluator("TrackedEvaluator"), sub_evaluator_(subinf) {} - ~TrackedEvaluator() override = default; - MS_DECLARE_PARENT(TrackedEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (sub_evaluator_ != nullptr) { - return sub_evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (sub_evaluator_ != nullptr) { - sub_evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); } - - private: - EvaluatorPtr sub_evaluator_; -}; - -class BaseFuncGraphEvaluator : public Evaluator { - public: - explicit BaseFuncGraphEvaluator(const AnalysisContextPtr &context) - : Evaluator("basegraph"), parent_context_(context) {} - - ~BaseFuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); - - EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; - - AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); - AnalysisContextPtr graph_context() const { return graph_context_; } - - protected: - AnalysisContextPtr parent_context_; - - private: - AnalysisContextPtr graph_context_; -}; - -class FuncGraphEvaluator : public BaseFuncGraphEvaluator { - public: - FuncGraphEvaluator(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context) - : BaseFuncGraphEvaluator(context->Filter(func_graph)), func_graph_(func_graph) {} - - ~FuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(FuncGraphEvaluator, BaseFuncGraphEvaluator); - - FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - FuncGraphPtr func_graph() { return func_graph_; } - - AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; - AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; - std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } - - private: - FuncGraphPtr func_graph_; - std::unordered_map - func_graph_cache_; - std::vector trace_; -}; -using FuncGraphEvaluatorPtr = std::shared_ptr; - -class MetaFuncGraphEvaluator : public BaseFuncGraphEvaluator { - public: - // Note: context parameter is not used; - MetaFuncGraphEvaluator(const MetaFuncGraphPtr &meta_func_graph, AnalysisContextPtr, const ScopePtr &scope) - : BaseFuncGraphEvaluator(AnalysisContext::DummyContext()), meta_func_graph_(meta_func_graph), scope_(scope) {} - ~MetaFuncGraphEvaluator() override = default; - MS_DECLARE_PARENT(MetaFuncGraphEvaluator, BaseFuncGraphEvaluator); - - FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - - // Return normalized versions of the arguments. 
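// Hypothetical illustration: a MetaFuncGraph that ignores scalar constants could map
//   {AbstractScalar(3), AbstractTensor(...)} -> {AbstractScalar(kAnyValue), AbstractTensor(...)}
// so calls that differ only in that constant share one cached specialization.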
- AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { - return meta_func_graph_->NormalizeArgs(args_spec_list); - } - std::string ToString() const override { return identifier_ + "_" + meta_func_graph_->ToString(); } - - private: - MetaFuncGraphPtr meta_func_graph_; - std::unordered_map - func_graph_cache_; - ScopePtr scope_; -}; - -class PartialAppEvaluator : public Evaluator { - public: - PartialAppEvaluator(const EvaluatorPtr &evaluator, const AbstractBasePtrList &args) - : Evaluator("PartialAppEvaluator"), evaluator_(evaluator), args_spec_list_(args) {} - ~PartialAppEvaluator() override = default; - MS_DECLARE_PARENT(PartialAppEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (evaluator_ != nullptr) { - return evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (evaluator_ != nullptr) { - evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; - } - - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } - - private: - EvaluatorPtr evaluator_; - AbstractBasePtrList args_spec_list_; -}; - -class VirtualEvaluator : public Evaluator { - public: - VirtualEvaluator(const AbstractBasePtrList &args_spec_list, const AbstractBasePtr &output) - : Evaluator("virtual"), args_spec_list_(args_spec_list), output_(output) {} - ~VirtualEvaluator() override = default; - MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); - - EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; - std::string ToString() const override { return identifier_; } - - private: - AbstractBasePtrList args_spec_list_; - AbstractBasePtr output_; -}; - -class JEvaluator : public Evaluator { - public: - JEvaluator(const EvaluatorPtr &evaluator, const AbstractFunctionPtr &orig_func) - : Evaluator("JEvaluator"), evaluator_(evaluator), orig_func_(orig_func) {} - ~JEvaluator() override = default; - MS_DECLARE_PARENT(JEvaluator, Evaluator); - AnfNodePtr bound_node() const override { - if (evaluator_ != nullptr) { - return evaluator_->bound_node(); - } - return bound_node_.lock(); - } - - void set_bound_node(const AnfNodePtr &node) override { - if (evaluator_ != nullptr) { - evaluator_->set_bound_node(node); - } - bound_node_ = AnfNodeWeakPtr(node); - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; - } - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override; - std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); } - - private: - EvaluatorPtr evaluator_; - AbstractFunctionPtr orig_func_; -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_EVALUATOR_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc deleted file mode 100644 index bf16bb5237..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ /dev/null @@ -1,1384 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia 
(https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/prim.h" - -#include -#include -#include -#include -#include -#include - -#include "operator/cc_implementations.h" -#include "operator/ops.h" -#include "operator/composite/do_signature.h" -#include "operator/prim_to_function.h" -#include "abstract/utils.h" -#include "utils/symbolic.h" -#include "./common.h" -#include "pipeline/resource.h" -#include "pipeline/parse/resolve.h" -#include "ir/tensor.h" -#include "utils/convert_utils.h" -#include "utils/context/ms_context.h" -#include "pipeline/parse/data_converter.h" -#include "abstract/param_validator.h" -#include "common/utils.h" - -namespace mindspore { -namespace abstract { -PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { - static PrimitiveEvalImplMap prim_eval_implement_map = { - // Statements - {prim::kPrimReturn, {InferImplReturn, true}}, - {prim::kPrimTypeOf, {InferImplTypeof, false}}, - {prim::kPrimHasType, {InferImplHasType, false}}, - {prim::kPrimDot, {InferImplDot, true}}, - {prim::kPrimSwitch, {InferImplSwitch, true}}, - {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}}, - {prim::kPrimIs_, {InferImplIs_, true}}, - {prim::kPrimIsNot, {InferImplIsNot, true}}, - {prim::kPrimInDict, {InferImplInDict, true}}, - {prim::kPrimNotInDict, {InferImplNotInDict, true}}, - {prim::kPrimIsConsant, {InferImplIsConstant, true}}, - // Maths - {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, - {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, - // Array - {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, - {prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}}, - {prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}}, - {prim::kPrimShape, {InferImplShape, true}}, - {prim::kPrimPack, {InferImplPack, true}}, - // Structure - {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, - {prim::kPrimMakeList, {InferImplMakeList, true}}, - {prim::kPrimMakeDict, {InferImplMakeDict, true}}, - {prim::kPrimMakeSlice, {InferImplMakeSlice, true}}, - {prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}}, - {prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}}, - {prim::kPrimMakeRecord, {InferImplMakeRecord, false}}, - {prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}}, - {prim::kPrimListGetItem, {InferImplListGetItem, true}}, - {prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}}, - {prim::kPrimListSetItem, {InferImplListSetItem, true}}, - {prim::kPrimDictGetItem, {InferImplDictGetItem, true}}, - {prim::kPrimDictSetItem, {InferImplDictSetItem, true}}, - {prim::kPrimListAppend, {InferImplListAppend, true}}, - {prim::kPrimTupleLen, {InferImplTupleLen, true}}, - {prim::kPrimListLen, {InferImplListLen, true}}, - {prim::kPrimArrayLen, {InferImplArrayLen, true}}, - {prim::kPrimListMap, {InferImplListMap, false}}, - {prim::kPrimListReduce, {InferImplListReduce, false}}, - {prim::kPrimTupleReversed, {InferImplTupleReversed, 
false}}, - {prim::kPrimReducedShape, {InferImplReduceShape, false}}, - {prim::kPrimTupleDiv, {InferImplTupleDiv, false}}, - {prim::kPrimTupleToArray, {InferImplTuple2Array, false}}, - {prim::kPrimShapeMul, {InferImplShapeMul, false}}, - {prim::kPrimTupleEqual, {InferImplTupleEqual, false}}, - {prim::kPrimListEqual, {InferImplListEqual, false}}, - {prim::kPrimMakeRange, {InferImplMakeRange, false}}, - {prim::kPrimStopGradient, {InferImplStopGradient, false}}, - {prim::kPrimStringEqual, {InferImplStringEqual, false}}, - {prim::kPrimStringConcat, {InferImplStringConcat, false}}, - {prim::kPrimDictLen, {InferImplDictLen, false}}, - // NN - {prim::kPrimPooling, {InferImplPooling, true}}, - {prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}}, - {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}}, - {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, - {prim::kPrimReluGrad, {InferImplReluGrad, true}}, - {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, - {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, - {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, - {prim::kPrimRelu, {InferImplRelu, true}}, - {prim::kPrimFakeBprop, {InferImplFakeBprop, false}}, - {prim::kPrimZerosLike, {InferImplZerosLike, true}}, - {prim::kPrimBpropCut, {InferImplBpropCut, true}}, - {prim::kPrimLayerNorm, {InferImplLayerNorm, true}}, - {prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}}, - {prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}}, - // Others - {prim::kPrimIdentity, {InferImplIdentity, true}}, - // Set impl to null as it will use PartialEvaluator; - {prim::kPrimPartial, {nullptr, true}}, - {prim::kPrimJ, {InferImplJ, false}}, - {prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}}, - {prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}}, - {prim::kPrimEnvAdd, {InferImplEnvAdd, true}}, - {prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}}, - {prim::kPrimMakeRef, {InferImplMakeRef, true}}, - {prim::kPrimGetRefKey, {InferImplGetRefKey, true}}, - {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, - {prim::kPrimGetRefOrigin, {InferImplGetRefOrigin, true}}, - {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, - {prim::kPrimDepend, {InferImplDepend, true}}, - {prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}}, - {prim::kPrimControlDepend, {InferImplControlDepend, true}}, - // Debug - {prim::kPrimDebug, {InferImplDebug, true}}, - // IndexedSlices - {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}}, - {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, - {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, - {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, - {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, - }; - return prim_eval_implement_map; -} - -using mindspore::parse::PyObjectWrapper; - -EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) { - if (prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) { - auto ret_abstract = AbstractEval(args); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; - return ret_abstract; - } - } - prim_->BeginRecordAddAttr(); - AbstractBasePtr abs_base = eval_impl_(engine, prim_, args); - prim_->EndRecordAddAttr(); - auto added_attrs = prim_->evaluate_added_attrs(); - auto infer_result = std::make_shared(abs_base, 
std::make_shared(added_attrs)); - return infer_result; -} - -EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - auto ret_abstract = AbstractEval(args_spec_list); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined"; - return ret_abstract; - } - - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - - auto do_signature = dyn_cast(prim_); - auto out_node = dyn_cast(out_conf->node()); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "Op: " << do_signature->function()->ToString() - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - - AnfNodePtr new_cnode = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); - TraceManager::EndTrace(); - } else { - new_cnode = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), do_signature->function(), args_spec_list, - args_inputs); - } - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { - // arg[0] is the func graph to unpack, ignore it - AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); - AbstractBasePtrList graph_specialize_args; - if (need_unpack) { - for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) { - MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]); - if (specialize_args_before_unpack[index]->isa()) { - AbstractTuplePtr arg_tuple = specialize_args_before_unpack[index]->cast(); - std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), - std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; }); - } else if (specialize_args_before_unpack[index]->isa()) { - AbstractDictionaryPtr arg_dict = specialize_args_before_unpack[index]->cast(); - auto dict_elems = arg_dict->elements(); - (void)std::transform( - dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args), - [](const AbstractAttribute &item) { return std::make_shared(item.first, item.second); }); - } else { - MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " - << specialize_args_before_unpack[index]->ToString(); - } - } - } else { - graph_specialize_args = specialize_args_before_unpack; - } - return graph_specialize_args; -} - -EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr 
out_conf) { - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - - auto unpack_graph = prim_->cast(); - auto out_node = out_conf->node()->cast(); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - // get the forward graph - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - AbstractFunctionPtr fn = args_spec_list[0]->cast(); - if (fn == nullptr) { - MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); - } - auto real_fn = fn->cast(); - MS_EXCEPTION_IF_NULL(real_fn); - FuncGraphPtr forward_graph = real_fn->func_graph(); - MS_EXCEPTION_IF_NULL(forward_graph); - AbstractBasePtrList graph_specialize_args = - GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); - - AbstractBasePtrList graph_specialize_args_without_sens; - (void)std::transform(graph_specialize_args.begin(), - graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), - std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); - auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens); - engine->func_graph_manager()->AddFuncGraph(new_graph); - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - AnfNodePtr new_vnode = NewValueNode(new_graph); - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type, - FuncGraphPtr func_graph) { - AnfNodePtr target_node = source_node; - if (node_type->isa()) { - auto x = node_type->cast(); - if (x->element()->BuildType()->isa()) { - auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); - MS_EXCEPTION_IF_NULL(cast); - target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type}); - } - } else if (node_type->isa()) { - auto x = node_type->cast(); - auto &items = x->elements(); - std::vector nodes; - nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - int idx = 0; - for (const auto &item : items) { - AnfNodePtr tuple_node = - func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)}); - AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph); - nodes.emplace_back(node); - ++idx; - } - target_node = func_graph->NewCNode(nodes); - } else if (node_type->isa()) { - auto x = node_type->cast(); - auto &items = x->elements(); - std::vector dict_key_nodes; - std::vector dict_value_nodes; - dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - for (const auto &item : items) { - AnfNodePtr dict_value_node = - 
func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)}); - AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph); - dict_key_nodes.emplace_back(NewValueNode(item.first)); - dict_value_nodes.emplace_back(node); - } - target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes), - func_graph->NewCNode(dict_value_nodes)}); - } else if (node_type->isa()) { - auto x = node_type->cast(); - std::string kwarg_key = x->get_key(); - AnfNodePtr kwarg_value_node = - func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node}); - AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph); - target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node}); - } - return target_node; -} - -EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf) { - AbstractBasePtrList args_spec_list; - if (out_conf->node() == nullptr || !out_conf->node()->isa()) { - MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; - } - auto out_node = out_conf->node()->cast(); - const auto &out_node_inputs = out_node->inputs(); - if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { - MS_LOG(EXCEPTION) << "MixedPrecisionCast" - << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() - << ", inputs size " << out_node_inputs.size(); - } - AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); }); - - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - - FuncGraphPtr func_graph = out_conf->node()->func_graph(); - AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph); - AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context()); - - return engine->ForwardConfig(out_conf, fn_conf); -} - -namespace { -py::object BuildValue(const ValuePtr &value_ptr) { - if (value_ptr == nullptr) { - return py::none(); - } else { - return ValuePtrToPyData(value_ptr); - } -} -} // end anonymous namespace - -py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { - MS_EXCEPTION_IF_NULL(abs_base); - py::dict dic; - if (abs_base->isa()) { - auto arg_tensor = dyn_cast(abs_base); - dic["shape"] = arg_tensor->shape()->shape(); - dic["dtype"] = arg_tensor->BuildType(); - dic["value"] = BuildValue(arg_tensor->BuildValue()); - } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { - std::vector shape; - dic["shape"] = shape; - dic["dtype"] = abs_base->BuildType(); - dic["value"] = BuildValue(abs_base->BuildValue()); - } else if (abs_base->isa()) { - auto arg_slice = dyn_cast(abs_base); - std::vector shape; - dic["shape"] = shape; - dic["dtype"] = arg_slice->BuildType(); - dic["value"] = BuildValue(arg_slice->BuildValue()); - } else if (abs_base->isa()) { - auto value = abs_base->cast()->ref(); - dic = ConvertAbstractToPython(value); - } else if (abs_base->isa()) { - dic["shape"] = py::none(); - dic["dtype"] = py::ellipsis(); - dic["value"] 
= py::ellipsis();
-  } else if (abs_base->isa<AbstractTuple>()) {
-    auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
-    size_t len = arg_tuple->size();
-    py::tuple shape_tuple(len);
-    py::tuple dtype_tuple(len);
-
-    for (size_t i = 0; i < len; i++) {
-      py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
-      shape_tuple[i] = out["shape"];
-      dtype_tuple[i] = out["dtype"];
-    }
-    dic["shape"] = shape_tuple;
-    dic["dtype"] = dtype_tuple;
-    dic["value"] = BuildValue(arg_tuple->BuildValue());
-  } else if (abs_base->isa<AbstractList>()) {
-    auto arg_list = dyn_cast<AbstractList>(abs_base);
-    size_t len = arg_list->size();
-    py::list shape_list(len);
-    py::list dtype_list(len);
-
-    for (size_t i = 0; i < len; i++) {
-      py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
-      shape_list[i] = out["shape"];
-      dtype_list[i] = out["dtype"];
-    }
-    dic["shape"] = shape_list;
-    dic["dtype"] = dtype_list;
-    dic["value"] = BuildValue(arg_list->BuildValue());
-  } else if (abs_base->isa<AbstractNone>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = py::none();
-    dic["value"] = py::none();
-  } else if (abs_base->isa<AbstractFunction>()) {
-    dic["shape"] = py::none();
-    dic["dtype"] = abs_base->BuildType();
-    dic["value"] = py::none();
-  } else {
-    auto value = abs_base->BuildValue();
-    if ((*value == *kAnyValue)) {
-      auto value_desc = abs_base->value_desc();
-      MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
-                              << " for python primitive." << abs_base->ToString();
-    }
-    MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
-                            << value->ToString();
-  }
-  return dic;
-}
-
-namespace {
-py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrList &args) {
-  const AbstractBasePtrList *args_ptr;
-
-  if (prim_py->is_tuple_input_) {
-    if (args.empty()) {
-      MS_LOG(EXCEPTION) << "Primitive args is empty";
-    }
-    if (args[0] == nullptr || !args[0]->isa<AbstractTuple>()) {
-      MS_LOG(EXCEPTION) << "Custom Primitive inputs should be packed into a Tuple after the prim convert pass for GE.";
-    }
-    args_ptr = &(args[0]->cast<AbstractTuplePtr>()->elements());
-  } else {
-    args_ptr = &args;
-  }
-
-  py::tuple py_args(args_ptr->size());
-  for (size_t i = 0; i < args_ptr->size(); i++) {
-    auto arg_i = (*args_ptr)[i];
-    py_args[i] = ConvertAbstractToPython(arg_i);
-  }
-  return py_args;
-}
-
-AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
-  // Convert to AbstractValue based on type and shape
-  if (output["value"].is_none()) {
-    auto out_shape = output["shape"];
-    auto out_dtype = output["dtype"];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype);
-  }
-  // Convert pyobject to Value, then to AbstractValue
-  ValuePtr converted_ret = nullptr;
-  bool converted = parse::ConvertData(output["value"], &converted_ret);
-  if (!converted) {
-    MS_LOG(EXCEPTION) << "Convert data failed";
-  }
-  auto res_spec = FromValue(converted_ret);
-  MS_EXCEPTION_IF_NULL(res_spec);
-  if (res_spec->isa<AbstractTensor>()) {
-    // Replace to tensor constant node in specialize
-    auto res_tensor = res_spec->cast<AbstractTensorPtr>();
-    res_tensor->set_value(converted_ret);
-  }
-  if (prim_py->IsCustomPrim()) {
-    // Raise an error if output_num does not match the infer result.
-    int output_num = GetValue<int>(prim_py->GetAttr("output_num"));
-    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " does not match the infer result.";
-    } else if (res_spec->isa<AbstractTuple>() &&
-               (res_spec->cast<AbstractTuplePtr>()->size() != IntToSize(output_num))) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " does not match the infer result.";
-    }
-  }
-  return res_spec;
-}
-}  // end anonymous namespace
-
-EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto ret_abstract = AbstractEval(args);
-  if (ret_abstract != nullptr) {
-    MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
-    return ret_abstract;
-  }
-  MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
-
-  const auto &iter = cache_->find(args);
-  if (iter != cache_->end()) {
-    return iter->second;
-  }
-  auto py_args = PreparePyInputs(prim_py_, args);
-
-  auto pyobj = prim_py_->GetPyObj();
-  if (pyobj == nullptr) {
-    MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
-  }
-  auto infer_func = pyobj.attr("__infer__");
-  prim_py_->BeginRecordAddAttr();
-  py::dict output = infer_func(*py_args);
-  prim_py_->EndRecordAddAttr();
-  auto added_attrs = prim_py_->evaluate_added_attrs();
-  MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
-  auto res_spec = PyInferRes2Abstract(prim_py_, output);
-
-  MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
-  auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
-  (*cache_)[args] = infer_result;
-  return infer_result;
-}
-
-EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
-  auto ret_abstract = AbstractEval(args);
-  if (ret_abstract != nullptr) {
-    MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
-    return ret_abstract;
-  }
-  // If the retval type of func_desc_ is a superclass of the parameter type, narrow the retval type to the
-  // parameter type.
-  if (nargs_ != args.size()) {
-    MS_LOG(ERROR) << "UniformPrimEvaluator expects " << nargs_ << " args, but got " << args.size() << " inputs";
-    return nullptr;
-  }
-  TypePtr ret_value_type = return_value_type_;
-  ValuePtrList value_list;
-  for (const auto &arg : args) {
-    // Check whether all arguments are of scalar type.
-    MS_EXCEPTION_IF_NULL(arg);
-    if (arg->isa<AbstractScalar>()) {
-      auto arg_scalar = dyn_cast<AbstractScalar>(arg);
-      auto arg_value = arg_scalar->GetValueTrack();
-      value_list.push_back(arg_value);
-    } else {
-      // Raise TypeError: expected a scalar argument.
-      MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
-    }
-  }
-  for (const auto &item : type_map_) {
-    TypePtrList selections;
-    MS_EXCEPTION_IF_NULL(item.second);
-    (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
-                         [&args](size_t arg_idx) -> TypePtr { return args[arg_idx]->GetTypeTrack(); });
-    TypePtr res = CheckTypeList(item.first, selections);
-    if (*return_value_type_ == *(item.first)) {
-      ret_value_type = res;
-    }
-  }
-
-  ValuePtr evaluated_value = RunImpl(value_list);
-  if (!(*evaluated_value == *kAnyValue)) {
-    ret_value_type = evaluated_value->type();
-  }
-  // For comparison primitives, the return type shall be specified to be bool.
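-  // (Resolution order above: ret_value_type starts from the declared return
-  //  type, is narrowed by the type_map_ consistency check, is then replaced by
-  //  the concrete type of evaluated_value when constant folding succeeds, and
-  //  finally specify_out_type_ below overrides everything else.)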
- if (specify_out_type_ != nullptr) { - ret_value_type = specify_out_type_; - } - - AbstractScalarPtr abs_base = std::make_shared(evaluated_value, ret_value_type); - return std::make_shared(abs_base, std::make_shared()); -} - -ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const { - if (!eval_value_) { - return kAnyValue; - } else { - if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) { - MS_EXCEPTION_IF_NULL(arg); - return arg->isa(); - })) { - return kAnyValue; - } - return impl_(args); - } -} - -// Primitive implementation -// static function start -namespace { -EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveEvalImpl eval_impl) { - EvaluatorPtr prim_evaluator = std::make_shared(primitive, eval_impl); - return prim_evaluator; -} - -EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value, - const TypePtr &specify_out_type) { - FunctionPtr func = nullptr; - (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func); - MS_EXCEPTION_IF_NULL(func); - - EvaluatorPtr uniform_primitive_evaluator = - std::make_shared(func, prim_impl, eval_value, specify_out_type); - return uniform_primitive_evaluator; -} - -const int kResolveCaseUserDefineClass = 1; -const int kResolveCaseBuildinTypeMethod = 2; -const int kResolveCaseFunction = 3; -int GetResolveCase(const TypePtr &data_type) { - MS_EXCEPTION_IF_NULL(data_type); - if (data_type->type_id() == kObjectTypeClass) { - return kResolveCaseUserDefineClass; - } - - // try method map, if not in method map, the data_type should be External type. - if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) { - return kResolveCaseBuildinTypeMethod; - } - - return kResolveCaseFunction; -} - -FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) { - MS_EXCEPTION_IF_NULL(engine); - MS_EXCEPTION_IF_NULL(method); - if (!method->isa()) { - MS_LOG(EXCEPTION) << "Method type error: " << method->ToString(); - } - - std::shared_ptr obj = method->cast>(); - FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj()); - if (func_graph == nullptr) { - MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed"; - } - - FuncGraphManagerPtr manager = engine->func_graph_manager(); - manager->AddFuncGraph(func_graph); - return func_graph; -} - -inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) { - MS_EXCEPTION_IF_NULL(engine); - FuncGraphManagerPtr manager = engine->func_graph_manager(); - manager->AddFuncGraph(func_graph); -} - -EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &old_conf) { - MS_EXCEPTION_IF_NULL(old_conf); - - AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf); - AbstractFunctionPtr abs_func = dyn_cast(abs_ptr); - MS_EXCEPTION_IF_NULL(abs_func); - - // Create new cnode - std::vector input = {NewValueNode(prim::kPrimPartial)}; - auto func_graph_func = dyn_cast(abs_func); - if (func_graph_func != nullptr) { - FuncGraphPtr fg = func_graph_func->func_graph(); - input.push_back(NewValueNode(fg)); - } else { - auto prim_func = dyn_cast(abs_func); - MS_EXCEPTION_IF_NULL(prim_func); - PrimitivePtr prim = prim_func->prim(); - input.push_back(NewValueNode(prim)); - } - - AnfNodeConfigPtr conf = dyn_cast(data_conf); - MS_EXCEPTION_IF_NULL(conf); - input.push_back(conf->node()); - MS_EXCEPTION_IF_NULL(old_conf); - FuncGraphPtr 
func_graph = old_conf->node()->func_graph(); - CNodePtr new_cnode = func_graph->NewCNode(input); - AnalysisEnginePtr eng = old_conf->engine(); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context()); - return eng->ForwardConfig(old_conf, fn_conf); -} - -EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, - const AnfNodeConfigPtr &out_conf) { - // args_spec_list: same as StaticGetter - if (args_spec_list.size() < 2) { - MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2"; - } - MS_EXCEPTION_IF_NULL(out_conf); - // An external type. - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString(); - MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString(); - auto data_v = args_spec_list[0]->BuildValue(); - if (!data_v->isa()) { - MS_LOG(EXCEPTION) << "Data is not NameSpace : " << data_v->ToString(); - } - - auto item_v = args_spec_list[1]->BuildValue(); - if (item_v->isa()) { - item_v = std::make_shared(item_v->cast()->value()); - } - - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_v->ToString(); - } - - // item_name to func addr from obj_map - parse::SymbolPtr symbol = item_v->cast(); - parse::NameSpacePtr name_space = data_v->cast(); - FuncGraphPtr func_graph = out_conf->node()->func_graph(); - - auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node()); - if (new_node == nullptr) { - MS_LOG(EXCEPTION) << "Resolve node failed"; - } - - AnalysisEnginePtr eng = out_conf->engine(); - AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context()); - return eng->ForwardConfig(out_conf, fn_conf); -} - -EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine, - const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "args_spec_list is empty"; - } - AbstractClassPtr cls = CheckArg("__FUNC__", args_spec_list, 0); - - // If item_v is an attribute, get abstract value from AbstractClass - MS_EXCEPTION_IF_NULL(item_v); - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "Attribute type error"; - } - std::string item_name = item_v->cast()->value(); - MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name(); - MS_LOG(DEBUG) << "Resolve item: " << item_name; - - AbstractBasePtr attr = cls->GetAttribute(item_name); - if (attr != nullptr) { - return std::make_shared(attr, nullptr); - } - - ValuePtr method = cls->GetMethod(item_name); - if (method->isa()) { - MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() - << ", item value: " << item_v->ToString(); - } - - // Infer class method - ValuePtr converted_v = PyObjToGraph(engine, method); - return StaticGetterInferred(converted_v, data_conf, out_conf); -} - -EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v, - const TypePtr &data_type, const ConfigPtr &data_conf, - const AnfNodeConfigPtr &out_conf) { - MS_EXCEPTION_IF_NULL(item_v); - MS_EXCEPTION_IF_NULL(data_type); - // The method maybe a Primitive or Composite - if (!item_v->isa()) { - MS_LOG(EXCEPTION) << "Error item is not string"; - } - - std::string item_name = item_v->cast()->value(); - Any method = 
pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name); - if (method.empty()) { - MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name; - } - - ValuePtr converted_v = nullptr; - if (method.is()) { - // composite registered in standard_method_map go to this branch - converted_v = prim::GetPythonOps(method.cast()); - AddToManager(engine, converted_v->cast()); - } else if (method.is()) { - converted_v = method.cast(); - } else { - MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString(); - } - return StaticGetterInferred(converted_v, data_conf, out_conf); -} - -EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) { - // Inputs: namespace and its static function; or class and its member function - CheckArgsSize("StaticGetter", args_spec_list, 2); - - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - MS_EXCEPTION_IF_NULL(args_spec_list[1]); - TypePtr data_type = args_spec_list[0]->BuildType(); - ValuePtr item_value = args_spec_list[1]->BuildValue(); - ScopePtr scope = kDefaultScope; - if (out_conf != nullptr) { - scope = out_conf->node()->scope(); - } - ScopeGuard scope_guard(scope); - if (item_value->isa()) { - MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString(); - } - - int case_v = GetResolveCase(data_type); - if (case_v == kResolveCaseUserDefineClass) { - return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf); - } else if (case_v == kResolveCaseBuildinTypeMethod) { - return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf); - } else { - return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf); - } -} -} // end anonymous namespace - -// static variable start; -namespace { -class EmbedEvaluator : public SymbolicPrimEvaluator { - public: - EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {} - ~EmbedEvaluator() override = default; - MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator); - EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override { - // arg: free variable to be embedded - if (args_conf_list.size() != 1) { - MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size(); - } - AnfNodeConfigPtr node_conf = dyn_cast(args_conf_list[0]); - MS_EXCEPTION_IF_NULL(node_conf); - - AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract(); - x = SensitivityTransform(x); - SymbolicKeyInstancePtr key = std::make_shared(node_conf->node(), x); - AbstractScalarPtr abs_scalar = std::make_shared(key, std::make_shared()); - return std::make_shared(abs_scalar, std::make_shared()); - } -}; - -static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) { - auto root_g_set = manager->roots(); - if (root_g_set.size() != 1) { - return nullptr; - } - const FuncGraphPtr &root_g = root_g_set.back(); - - for (auto ¶m_node : root_g->parameters()) { - auto param = param_node->cast(); - if (param && name == param->name()) { - return param; - } - } - return nullptr; -} - -class RefToEmbedEvaluator : public SymbolicPrimEvaluator { - public: - RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {} - ~RefToEmbedEvaluator() override = default; - MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator); - EvalResultPtr 
EvalPrim(const ConfigPtrList &args_conf_list) override { - if (args_conf_list.size() != 1) { - MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size(); - return nullptr; - } - static TypePtr type = std::make_shared(); - auto node_conf = dyn_cast(args_conf_list[0]); - if (node_conf == nullptr) { - MS_LOG(ERROR) << "Conf should be AnfNodeConfig"; - return nullptr; - } - AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract(); - AbstractRefPtr ref_abs = abs->cast(); - if (ref_abs == nullptr) { - MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString(); - return nullptr; - } - auto key_abs = ref_abs->ref_key(); - if (key_abs == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr."; - return nullptr; - } - auto key_value = key_abs->BuildValue(); - if (key_value == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr."; - return nullptr; - } - auto refkey = key_value->cast(); - if (refkey == nullptr) { - auto ret = std::make_shared(type); - auto ref_value = ref_abs->ref(); - MS_EXCEPTION_IF_NULL(ref_value); - return std::make_shared(ret, std::make_shared()); - } - - std::string name = refkey->tag(); - const auto &manager = node_conf->node()->func_graph()->manager(); - auto node = FindParameterNodeByString(manager, name); - if (node == nullptr) { - MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph."; - return nullptr; - } - AbstractBasePtr x = ref_abs->ref(); - x = SensitivityTransform(x); - std::shared_ptr key = std::make_shared(node, x); - std::shared_ptr abs_scalar = std::make_shared(key, type); - return std::make_shared(abs_scalar, std::make_shared()); - } -}; - -class GetAttrEvaluator : public TransitionPrimEvaluator { - public: - GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {} - ~GetAttrEvaluator() override = default; - MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { - auto ret_abstract = AbstractEval(args_spec_list); - if (ret_abstract != nullptr) { - MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined"; - return ret_abstract; - } - // Inputs: data, item - if (args_spec_list.size() != 2) { - MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size(); - } - EvalResultPtr ret = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - TraceManager::EndTrace(); - } else { - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - } - // don't lookup from cache, as different out_conf with same node but different context - // may add different entry to anfnode_config_map, like getattr primitive; - (*cache_)[args_spec_list] = ret; - return ret; - } -}; - -class ResolveEvaluator : public TransitionPrimEvaluator { - public: - ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {} - ~ResolveEvaluator() override = default; - MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, - const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override { - // Inputs: namespace, symbol - if (args_spec_list.size() != 2) { - MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, 
but has size:" << args_spec_list.size(); - } - EvalResultPtr ret = nullptr; - if (bound_node() != nullptr) { - TraceManager::DebugTrace(std::make_shared(bound_node()->debug_info())); - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - TraceManager::EndTrace(); - } else { - ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf); - } - return ret; - } -}; - -class CreateInstanceEvaluator : public TransitionPrimEvaluator { - public: - CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {} - ~CreateInstanceEvaluator() override = default; - MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &, - const AnfNodeConfigPtr &out_conf) override { - if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty"; - } - - // get the type parameter - MS_EXCEPTION_IF_NULL(args_spec_list[0]); - TypePtr type = args_spec_list[0]->GetTypeTrack(); - if (type->type_id() != kMetaTypeTypeType) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got " - << type->ToString(); - } - - ValuePtr value_track = args_spec_list[0]->GetValueTrack(); - MS_EXCEPTION_IF_NULL(value_track); - - std::shared_ptr type_obj = dyn_cast(value_track); - if (type_obj == nullptr) { - MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << "."; - } - - if (!type_obj->isa()) { - MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got " - << type_obj->ToString() << "."; - } - - auto class_type = type_obj->obj(); - MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << "."; - - // get the create instance obj's parameters - pybind11::tuple params = GetParameters(args_spec_list); - - // create class instance - auto obj = parse::data_converter::CreatePythonObject(class_type, params); - if (py::isinstance(obj)) { - MS_LOG(EXCEPTION) << "Create python object failed, only support Cell and Primitive type"; - } - - // process the object - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(obj, &converted_ret, true); - if (!converted) { - MS_LOG(EXCEPTION) << "Convert the python object failed"; - } - MS_EXCEPTION_IF_NULL(converted_ret); - - if (converted_ret->isa()) { - AddToManager(engine, converted_ret->cast()); - } - - AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf); - auto infer_result = std::make_shared(ret, nullptr); - (*cache_)[args_spec_list] = infer_result; - return infer_result; - } - - pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const { - // Exclude class type by minus 1; - std::size_t params_size = args_spec_list.size() - 1; - auto params = py::tuple(params_size); - if (params_size > 0) { - for (size_t i = 0; i < params_size; i++) { - // Only support the Scalar parameters type. Bypass class type by offset with 1. - auto arg = args_spec_list[i + 1]; - MS_EXCEPTION_IF_NULL(arg); - // Because the Tensor's AbstractTensor can't get value from GetValueTrack. 
- ValuePtr param_value = arg->BuildValue(); - py::object param = ValuePtrToPyData(param_value); - params[i] = param; - } - } - return params; - } -}; - -class PartialEvaluator : public Evaluator { - public: - PartialEvaluator() : Evaluator("PartialEvaluator") {} - ~PartialEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, - AnfNodeConfigPtr out_conf = nullptr) override { - if (args_conf_list.size() == 0) { - MS_LOG(EXCEPTION) << "Args size should be greater than 0"; - } - - MS_EXCEPTION_IF_NULL(out_conf); - MS_EXCEPTION_IF_NULL(out_conf->node()); - auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract(); - AbstractBasePtrList args_spec_list{arg0_value}; - // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. - if (arg0_value->isa()) { - auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); - MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() - << " as func is: " << arg0_value->ToString(); - auto eval_result = std::make_shared(ret, std::make_shared()); - (*cache_)[args_spec_list] = eval_result; - return eval_result; - } - auto func = CheckArg("partial", args_spec_list, 0); - // Sometimes, node[0] in out_conf becomes phi0; - if (func->isa()) { - auto prim_func = dyn_cast(func); - if (prim_func->prim()->isa()) { - prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast(prim_func->prim()); - return HandleDoSignature(engine, do_signature_prim->function(), out_conf); - } - } - - (void)std::transform( - args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); }); - AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); - - auto cnode = out_conf->node()->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() != (args_conf_list.size() + 1)) { - MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() - << ", args_conf_list: " << mindspore::ToString(args_conf_list); - } - - AbstractFuncAtomPtrList partial_funcs_list; - auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { - auto new_func = std::make_shared(atom_func, args, cnode); - partial_funcs_list.push_back(new_func); - }; - func->Visit(build_partial); - - auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); - auto infer_result = std::make_shared(ret, std::make_shared()); - (*cache_)[args_spec_list] = infer_result; - return infer_result; - } - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, - const AnfNodeConfigPtr &out_conf = nullptr) const { - MS_EXCEPTION_IF_NULL(out_conf); - MS_EXCEPTION_IF_NULL(out_conf->node()); - auto cnode = out_conf->node()->cast(); - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "Cnode is nullptr"; - } - std::vector new_nodes_inputs = cnode->inputs(); - auto new_signature_value = std::make_shared("signature", signature_value); - new_nodes_inputs[1] = NewValueNode(new_signature_value); - FuncGraphPtr func_graph = cnode->func_graph(); - - ScopePtr scope = out_conf->node()->scope(); - ScopeGuard scope_guard(scope); - - CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs); - AnfNodeConfigPtr fn_conf = 
engine->MakeConfig(new_cnode, out_conf->context()); - return engine->ForwardConfig(out_conf, fn_conf); - } -}; - -struct PrimitiveImplInferValue { - PrimitiveImpl impl_; // implement function of primitive - bool eval_value_; // whether evaluate value - TypePtr specify_out_type_; // whether specify return type - bool in_white_list_; // true if this Primitive in white list, else false. -}; - -using PrimitiveToImplMap = std::unordered_map; -PrimitiveToImplMap &GetUniformPrimitiveToImplMap() { - static PrimitiveToImplMap uniform_prim_implement_map = { - {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}}, - {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}}, - {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}}, - {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}}, - {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}}, - {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}}, - {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}}, - {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}}, - {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}}, - {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}}, - {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared(), true}}, - {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared(), true}}, - {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared(), true}}, - {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared(), true}}, - {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared(), true}}, - {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared(), true}}, - {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared(), true}}, - {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared(), true}}, - {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared(), true}}, - {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared(), true}}, - }; - return uniform_prim_implement_map; -} - -PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap(); -std::mutex PrimEvaluatorConstructorMutex; - -void InitPrimEvaluatorConstructors() { - PrimEvaluatorMap &constructor = PrimEvaluatorConstructors; - - for (const auto &iter : GetPrimitiveToEvalImplMap()) { - constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_); - } - - for (const auto &iter : GetUniformPrimitiveToImplMap()) { - constructor[iter.first] = - InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_); - } - constructor[prim::kPrimEmbed] = std::make_shared(); - constructor[prim::kPrimRefToEmbed] = std::make_shared(); - constructor[prim::kPrimGetAttr] = std::make_shared(); - constructor[prim::kPrimResolve] = std::make_shared(); - constructor[prim::kPrimCreateInstance] = std::make_shared(); - constructor[prim::kPrimPartial] = std::make_shared(); -} -} // namespace - -void ClearPrimEvaluatorMap() { - PrimEvaluatorConstructors.clear(); - GetPrimitiveToEvalImplMap().clear(); - GetUniformPrimitiveToImplMap().clear(); -} - -bool IsInWhiteList(const PrimitivePtr primitive) { - MS_EXCEPTION_IF_NULL(primitive); - - auto iter = GetPrimitiveToEvalImplMap().find(primitive); - if (iter != GetPrimitiveToEvalImplMap().end()) { - return iter->second.in_white_list_; - } - - auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive); - if (uni_iter != GetUniformPrimitiveToImplMap().end()) { - return uni_iter->second.in_white_list_; - } - - return false; -} - 
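An illustrative, self-contained sketch of the registry-and-dispatch pattern behind GetUniformPrimitiveToImplMap above: a lazily initialized table maps each uniform scalar primitive to its implementation plus an optional forced output type (the bool override used by the comparison entries). The names here (ImplEntry, GetScalarImplMap) are simplified stand-ins invented for this example; the real map keys on PrimitivePtr and stores PrimitiveImplInferValue entries.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Each registered primitive carries an implementation and a flag that mirrors
// specify_out_type_ == bool for comparison primitives.
struct ImplEntry {
  std::function<double(const std::vector<double> &)> impl;
  bool force_bool_out;
};

std::unordered_map<std::string, ImplEntry> &GetScalarImplMap() {
  static std::unordered_map<std::string, ImplEntry> impl_map = {
      {"ScalarAdd", {[](const std::vector<double> &a) { return a[0] + a[1]; }, false}},
      {"ScalarLt", {[](const std::vector<double> &a) { return a[0] < a[1] ? 1.0 : 0.0; }, true}},
  };
  return impl_map;
}

int main() {
  // Dispatch: look the primitive up once, then run the stored implementation.
  const auto &entry = GetScalarImplMap().at("ScalarLt");
  double v = entry.impl({2.0, 3.0});
  std::string out = entry.force_bool_out ? std::string(v != 0.0 ? "true" : "false") : std::to_string(v);
  std::cout << out << std::endl;  // prints "true"
  return 0;
}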
-StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto iter = GetPrimitiveToEvalImplMap().find(primitive);
-  if (iter == GetPrimitiveToEvalImplMap().end()) {
-    return nullptr;
-  }
-  return iter->second.impl_;
-}
-
-PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
-  PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
-  if (!constructor.empty()) {
-    return constructor;
-  }
-  std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
-  if (constructor.empty()) {
-    InitPrimEvaluatorConstructors();
-  }
-
-  return constructor;
-}
-
-namespace {
-bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  auto x_tuple = dyn_cast<AbstractTuple>(x);
-  auto model_tuple = dyn_cast<Tuple>(model);
-
-  if (x_tuple == nullptr || model_tuple == nullptr) {
-    return false;
-  }
-
-  if (model->IsGeneric()) {
-    return true;
-  }
-
-  if (x_tuple->size() != model_tuple->size()) {
-    return false;
-  }
-
-  for (size_t i = 0; i < x_tuple->size(); i++) {
-    bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
-    if (!is_subtype) {
-      return false;
-    }
-  }
-  return true;
-}
-
-bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  auto x_tensor = dyn_cast<AbstractTensor>(x);
-  auto model_tensor = dyn_cast<TensorType>(model);
-
-  if (x_tensor == nullptr || model_tensor == nullptr) {
-    return false;
-  }
-
-  if (model->IsGeneric()) {
-    return true;
-  }
-
-  return IsSubtype(x_tensor->element(), model_tensor->element());
-}
-
-bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  auto x_list = dyn_cast<AbstractList>(x);
-  auto model_list = dyn_cast<List>(model);
-
-  if (x_list == nullptr || model_list == nullptr) {
-    return false;
-  }
-
-  if (model->IsGeneric()) {
-    return true;
-  }
-
-  if (x_list->size() != model_list->size()) {
-    return false;
-  }
-
-  bool is_subtype = true;
-  for (size_t i = 0; i < x_list->size(); i++) {
-    is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
-    if (!is_subtype) {
-      return false;
-    }
-  }
-  return is_subtype;
-}
-
-bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  auto x_class = dyn_cast<AbstractClass>(x);
-  auto model_class = dyn_cast<Class>(model);
-  if (x_class == nullptr) {
-    return false;
-  }
-  if (model->IsGeneric()) {
-    return true;
-  }
-
-  if (x_class->tag() == model_class->tag()) {
-    auto m_attributes = model_class->GetAttributes();
-    auto x_attributes = x_class->attributes();
-    if (m_attributes.size() != x_attributes.size()) {
-      return false;
-    }
-
-    for (size_t i = 0; i < m_attributes.size(); i++) {
-      if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  return false;
-}
-
-inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  if (dyn_cast<AbstractScalar>(x) == nullptr) {
-    return false;
-  }
-  TypePtr x_type = x->GetTypeTrack();
-  return IsSubType(x_type, model);
-}
-}  // namespace
-
-bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
-  MS_EXCEPTION_IF_NULL(x);
-  MS_EXCEPTION_IF_NULL(model);
-  TypeId model_typeid = model->type_id();
-  switch (model_typeid) {
-    case kMetaTypeObject:
-      return true;
-    case kObjectTypeTuple:
-      return IsSubtypeTuple(x, model);
-    case kObjectTypeTensorType:
-      return IsSubtypeArray(x, model);
-    case kObjectTypeList:
-      return
IsSubtypeList(x, model); - case kObjectTypeClass: - return IsSubtypeClass(x, model); - default: - if (IsSubType(model, std::make_shared())) { - return IsSubtypeScalar(x, model); - } - MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << "."; - } -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h deleted file mode 100644 index 5a686fbadc..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ /dev/null @@ -1,366 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_PRIM_H_ -#define PIPELINE_STATIC_ANALYSIS_PRIM_H_ - -#include -#include -#include -#include -#include - -#include "pipeline/static_analysis/evaluator.h" - -namespace mindspore { -namespace abstract { -using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &); -struct StandartPrimitiveImplReg { - StandardPrimitiveEvalImpl impl_; // Implement function of Primitive. - bool in_white_list_; // true if this Primitive in white list, else false. 
-}; - -using PrimitiveEvalImplMap = - std::unordered_map; - -class StandardPrimEvaluator : public TrivialPrimEvaluator { - public: - StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl) - : TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {} - ~StandardPrimEvaluator() override = default; - MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - PrimitivePtr prim() { return prim_; } - - std::string ToString() const override { return identifier_ + prim_->name(); } - - private: - PrimitivePtr prim_; - const StandardPrimitiveEvalImpl eval_impl_; -}; - -using StandardPrimEvaluatorPtr = std::shared_ptr; - -class PythonPrimEvaluator : public TrivialPrimEvaluator { - public: - explicit PythonPrimEvaluator(const PrimitivePyPtr primitive) - : TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {} - ~PythonPrimEvaluator() override = default; - MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator); - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - PrimitivePtr prim() { return dyn_cast(prim_py_); } - - std::string ToString() const override { return identifier_ + prim_py_->name(); } - - private: - PrimitivePyPtr prim_py_; -}; - -class DoSignatureEvaluator : public Evaluator { - public: - explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {} - ~DoSignatureEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -class UnpackGraphEvaluator : public Evaluator { - public: - explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} - ~UnpackGraphEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -class MixedPrecisionCastEvaluator : public Evaluator { - public: - explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive) - : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {} - ~MixedPrecisionCastEvaluator() override = default; - EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, - AnfNodeConfigPtr out_config = nullptr) override; - - EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { - MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called"; - } - - private: - PrimitivePtr prim_; -}; - -bool IsInWhiteList(PrimitivePtr primitive); -StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); - -using ValuePtrList = std::vector; -using PrimitiveImpl = ValuePtr (*)(const ValuePtrList &); - -class UniformPrimEvaluator : public TrivialPrimEvaluator { - public: - UniformPrimEvaluator(const FunctionPtr func_desc, PrimitiveImpl impl, bool eval_value, const TypePtr specify_out_type) - : 
TrivialPrimEvaluator("UniformPrimEvaluator"), - impl_(impl), - eval_value_(eval_value), - func_desc_(func_desc), - nargs_(func_desc_->args().size()), - return_value_type_(func_desc_->retval()), - specify_out_type_(specify_out_type) { - for (size_t i = 0; i < nargs_; ++i) { - TypePtr type = func_desc_->args()[i]; - if (type_map_[type]) { - type_map_[type]->push_back(i); - } else { - type_map_[type] = std::make_shared>(); - type_map_[type]->push_back(i); - } - } - } - ~UniformPrimEvaluator() override = default; - MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator); - - EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override; - ValuePtr RunImpl(const ValuePtrList &args) const; - - // If eval_value_ is False, return broadened arguments. - AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override { - if (!eval_value_) { - AbstractBasePtrList broadened_args_spec_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened_args_spec_list), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); - return broadened_args_spec_list; - } - return args_spec_list; - } - - private: - PrimitiveImpl impl_; - bool eval_value_; - const FunctionPtr func_desc_; - const std::size_t nargs_; - const TypePtr return_value_type_; - const TypePtr specify_out_type_; - std::unordered_map>, TypeHasher, TypeEqual> type_map_; -}; - -PrimEvaluatorMap &GetPrimEvaluatorConstructors(); - -// Check whether type x is a subtype of model. -bool IsSubtype(const AbstractBasePtr x, const TypePtr model); - -void ClearPrimEvaluatorMap(); - -py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base); - -AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 
- const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const 
AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList 
&args_spec_list); -AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); - -AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_PRIM_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc 
deleted file mode 100644 index b0ad1c3d67..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ /dev/null @@ -1,728 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/program_specialize.h" - -#include -#include -#include "./common.h" -#include "operator/ops.h" -#include "operator/composite/do_signature.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "utils/graph_utils.h" -#include "utils/log_adapter.h" -#include "utils/profile.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -namespace { -inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) { - if (conf->node()->intermediate_abstract()) { - return conf->node()->intermediate_abstract(); - } - return conf->GetEvaluatedValue()->abstract(); -} - -AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) { - AnfNodePtr value_node = NewValueNode(v); - value_node->set_abstract(abs_base); - MS_LOG(DEBUG) << "Create ValueNode: " << value_node->ToString() << ", with abstract: " << abs_base->ToString(); - return value_node; -} - -bool IsVisible(FuncGraphPtr fg, const FuncGraphPtr &parent) { - while (fg != nullptr && fg != parent) { - fg = fg->parent(); - } - return fg == parent; -} -} // namespace - -FuncGraphPtr ProgramSpecializer::Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(context); - MS_LOG(DEBUG) << "Specialize topmost function graph: " << context->func_graph()->ToString(); - return SpecializeFuncGraph(fg, context); -} - -FuncGraphPtr ProgramSpecializer::SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(fg); - MS_EXCEPTION_IF_NULL(context); - auto iter = specializations_.find(context->SpecializeKey()); - if (iter != specializations_.end()) { - return iter->second->specialized_func_graph(); - } - - std::shared_ptr fg_spec = std::make_shared(this, fg, context); - FuncGraphPtr fg2 = fg_spec->specialized_func_graph(); - specializations_[context->SpecializeKey()] = fg_spec; - fg_spec->Run(); - return fg2; -} - -std::shared_ptr ProgramSpecializer::GetFuncGraphSpecializer(const AnalysisContextPtr &context) { - MS_EXCEPTION_IF_NULL(context); - auto iter = specializations_.find(context->SpecializeKey()); - if (iter != specializations_.end()) { - return iter->second; - } - return nullptr; -} - -std::string GetNextCounter() { - static int g_CloneCounter = 1; - std::string str_count = std::to_string(g_CloneCounter); - g_CloneCounter++; - return str_count; -} - -FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, - const AnalysisContextPtr &context) - : specializer_(s), func_graph_(fg), context_(context) { - parent_ = s->GetFuncGraphSpecializer(context->parent()); - 
engine_ = s->engine(); - cloner_ = SpecializerClone(fg, std::make_shared(GetNextCounter())); - repl_node_ = cloner_->cloned_node(); - specialized_func_graph_ = cloner_->cloned_func_graph()[fg]; - todo_.push_back(fg->get_return()); - auto ps = fg->parameters(); - (void)todo_.insert(todo_.end(), ps.begin(), ps.end()); -} - -AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg = node->func_graph(); - - if (node->isa()) { - return node; - } - std::shared_ptr specializer = shared_from_this(); - while (fg != nullptr && fg != specializer->func_graph_) { - specializer = specializer->parent_; - } - // If had replicated, just return that. - auto iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - return iter->second; - } - - auto new_node = specializer->cloner_->CloneDisconnected(node); - if (node->isa()) { - if (!new_node->isa()) { - MS_LOG(EXCEPTION) << "new_node must be a CNode, but is " << new_node->DebugString() << "."; - } - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - auto inputs = c_node->inputs(); - std::vector new_inputs; - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs), - [this](const AnfNodePtr &inp) -> AnfNodePtr { - if (inp->isa()) { - return inp; - } - return ReplicateDisconnectedNode(inp); - }); - auto c_new_node = new_node->cast(); - MS_EXCEPTION_IF_NULL(c_new_node); - c_new_node->set_inputs(new_inputs); - } - - iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - if (iter->second == node) { - MS_LOG(EXCEPTION) << "Replicated is same as original node, node: " << node->ToString(); - } - } else { - MS_LOG(EXCEPTION) << "Replicate node failed, node: " << node->ToString(); - } - return new_node; -} - -AnfNodePtr FuncGraphSpecializer::GetReplicatedNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - FuncGraphPtr fg = node->func_graph(); - - std::shared_ptr specializer = shared_from_this(); - while (fg != nullptr && fg != specializer->func_graph_) { - specializer = specializer->parent_; - } - - MS_EXCEPTION_IF_NULL(specializer->repl_node_); - auto iter = specializer->repl_node_->find(node); - if (iter != specializer->repl_node_->end()) { - return iter->second; - } - return node; -} - -void FuncGraphSpecializer::Run() { - MS_LOG(DEBUG) << "Before run, origin func graph name: " << func_graph_->ToString() - << ", cloned func graph name: " << specialized_func_graph_->ToString() - << ", func graph: " << func_graph_->get_return()->DebugString(); - FirstPass(); - SecondPass(); - MS_LOG(DEBUG) << "After run, origin func graph name: " << func_graph_->ToString() - << ", cloned func graph name: " << specialized_func_graph_->ToString() - << ", new func graph: " << specialized_func_graph_->get_return()->DebugString(); -} - -void FuncGraphSpecializer::FirstPass() { - while (todo_.size()) { - AnfNodePtr node = todo_.back(); - todo_.pop_back(); - if (node->func_graph() == nullptr) { - // do nothing for ValueNode - continue; - } - if (node->func_graph() != func_graph_) { - if (parent_ == nullptr) { - MS_LOG(EXCEPTION) << "Parent must not null NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - parent_->AddTodoItem(node); - parent_->FirstPass(); - AnfNodePtr new_node = parent_->GetReplicatedNode(node); - if (node->isa()) { - parent_->ProcessCNode(new_node->cast()); - } - continue; - } - if (marked_.count(node) > 0) { - continue; - } - (void)marked_.insert(node); - 
ProcessNode(node); - } -} - -// Specialize CNode in func graphs -void FuncGraphSpecializer::SecondPass() { - for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { - if (node->isa()) { - ProcessCNode(node->cast()); - } - } -} - -void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - ScopeGuard scope_guard(node->scope()); - AnfNodeConfigPtr conf = MakeConfig(node); - AnfNodePtr new_node = GetReplicatedNode(node); - MS_EXCEPTION_IF_NULL(new_node); - if (new_node->func_graph() != specialized_func_graph_) { - MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString() - << ", new_node: " << new_node->DebugString() - << ", new_node->func_graph(): " << new_node->func_graph()->ToString() - << ", specialized_func_graph_: " << specialized_func_graph_->ToString(); - return; - } - new_node->set_abstract(GetEvaluatedValueWrap(conf)); - if (new_node->isa() && new_node->abstract()->isa()) { - auto partial_abstract = dyn_cast(new_node->abstract()); - if (partial_abstract->node() == node) { - partial_abstract->set_node(new_node); - } - } - - MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); - - if (node->isa()) { - auto attrs = conf->GetEvaluatedValue()->attribute(); - auto c_old = node->cast(); - auto c_new = new_node->cast(); - auto new_inputs = c_new->inputs(); - auto old_inputs = c_old->inputs(); - for (size_t i = 0; i < old_inputs.size(); ++i) { - auto node_input = old_inputs[i]; - AnfNodeConfigPtr iconf = MakeConfig(node_input); - AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); - // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if - // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node. 
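// In other words, each input is resolved with a three-step fallback:
//   1. constant-fold: if the input's abstract carries a unique constant value, emit a ValueNode;
//   2. forward: if static analysis recorded a replacement in anfnode_config_map, chase it via BuildReplacedNode;
//   3. otherwise keep the cloner's replicated copy of the input.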
- AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); - if (replace_node == nullptr) { - replace_node = BuildReplacedNode(iconf); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_abstract(ival); - MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); - } else { - MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() - << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); - } - if (new_inputs[i] != replace_node) { - new_inputs[i] = replace_node; - MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); - } - } - c_new->set_inputs(new_inputs); - } -} - -AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - - auto conf_iter = engine_->anfnode_config_map().find(conf); - AnfNodeConfigPtr new_conf = conf; - while (conf_iter != engine_->anfnode_config_map().end()) { - MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node(" - << new_conf->node()->DebugString() << ")"; - new_conf = conf_iter->second; - MS_EXCEPTION_IF_NULL(new_conf); - MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node(" - << conf->node()->DebugString() << ")"; - (void)ReplicateDisconnectedNode(new_conf->node()); - conf_iter = engine_->anfnode_config_map().find(new_conf); - } - todo_.push_back(new_conf->node()); - auto repl = GetReplicatedNode(new_conf->node()); - if (repl->func_graph()) { - MS_LOG(DEBUG) << "Set repl: graph(" << repl->func_graph()->ToString() << "), node:" << repl->DebugString() - << ") to replace origin:" << new_conf->node()->DebugString(); - } else { - MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString() - << ") to replace origin: " << new_conf->node()->DebugString(); - } - return repl; -} - -namespace { -const StringImmPtr kDeadNode = std::make_shared("Dead Node"); -const StringImmPtr kPolyNode = std::make_shared("Poly Node"); - -inline bool CanSpecializeNode(const AnfNodePtr &node) { - if (IsValueNode(node) || IsValueNode(node) || IsValueNode(node)) { - return true; - } - return false; -} -} // namespace - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, - const AbstractBasePtrList &argvals) { - MS_EXCEPTION_IF_NULL(abs); - AbstractFunctionPtr real_a = dyn_cast(abs); - MS_EXCEPTION_IF_NULL(real_a); - - AbstractFunctionPtr func = real_a->GetUnique(); - SpecializeStatusCode errcode; - ScopeGuard scope_guard(node->scope()); - AnfNodePtr repl = BuildSpecializedNodeInner(node, abs, func, argvals, &errcode); - if (repl == nullptr) { - if (errcode == kSpecializeFindUniqueArgvalDead) { - const auto error_dead_node = std::make_shared(kDeadNode, node); - repl = BuildValueNode(kDeadNode, error_dead_node); - MS_LOG(DEBUG) << "DEAD for node: " << node->DebugString() << ", abstract: " << abs->ToString(); - } else if (errcode == kSpecializeFindUniqueArgvalPoly) { - const auto error_poly_node = std::make_shared(kPolyNode, node); - repl = BuildValueNode(kPolyNode, error_poly_node); - MS_LOG(DEBUG) << "POLY for node: " << node->DebugString() << ", abstract: " << abs->ToString(); - } else { - MS_LOG(EXCEPTION) << "Failed to build specialized node, node: " << node->DebugString() - << ", abstract: " << abs->ToString(); - } - } - - return repl; -} - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AnfNodePtr &node, const 
AbstractBasePtr &abs, - const AbstractFunctionPtr &func, - const AbstractBasePtrList &args, - SpecializeStatusCode *errcode) { - MS_EXCEPTION_IF_NULL(abs); - MS_EXCEPTION_IF_NULL(func); - MS_EXCEPTION_IF_NULL(errcode); - *errcode = kSpecializeSuccess; - - auto real_func = dyn_cast(func); - if (real_func != nullptr) { - return BuildValueNode(real_func->prim(), abs); - } - - EvaluatorPtr eval; - eval = engine_->GetEvaluatorFor(func); - MS_EXCEPTION_IF_NULL(eval); - AbstractBasePtrList argvals = eval->NormalizeArgs(args); - - std::pair result; - SpecializeStatusCode status = FindUniqueArgvals(func, eval, argvals, &result); - if (status != kSpecializeSuccess) { - *errcode = status; - return nullptr; - } - argvals = result.first; - AbstractBasePtr unique_output = result.second; - - auto prim_func = dyn_cast(func); - if (prim_func != nullptr) { - auto type_func = std::make_shared(prim_func->prim(), argvals, unique_output); - return BuildValueNode(prim_func->prim(), type_func); - } - - if (!eval->isa()) { - MS_LOG(EXCEPTION) << "Eval is not BaseGraphEvaluator, but " << eval->ToString(); - } - auto real_eval = dyn_cast(eval); - - if (func->context() == nullptr) { - MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); - } - AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); - MS_LOG(DEBUG) << "Specialize function graph: " << context->func_graph()->ToString() << ", args: " << argvals.size() - << ", graph: " << context->func_graph()->get_return()->DebugString(); - if (context->func_graph()->stub()) { - MS_LOG(DEBUG) << "Specialize stub function graph, return the original node: " << context->func_graph()->ToString() - << ", args: " << argvals.size() << ", graph: " << context->func_graph()->get_return()->DebugString() - << ", " << node->ToString(); - return node; - } - FuncGraphPtr v = specializer_->SpecializeFuncGraph(context->func_graph(), context); - v->set_flag(kFuncGraphFlagUndetermined, false); - return BuildValueNode(v, abs); -} - -AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { - auto new_inputs = new_node->inputs(); - AnfNodePtr func = new_inputs[0]; - AbstractBasePtr fnval = new_inputs[0]->abstract(); - - AbstractBasePtrList args; - auto backed_fnval = fnval; - if (fnval->isa()) { - auto partial_closure = dyn_cast(fnval); - backed_fnval = partial_closure->fn(); - args = partial_closure->args(); - } - std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), - [](const AnfNodePtr &inp) { return inp->abstract(); }); - - ScopeGuard scope_guard(new_node->scope()); - - auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); - auto wrapped_node = specialized_node; - if (fnval->isa()) { - auto partial_closure = dyn_cast(fnval); - AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), - specialized_node}; - auto anf_node = partial_closure->node(); - if (!anf_node->isa()) { - MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); - } - auto cnode = anf_node->cast(); - if (cnode->size() != partial_closure->args().size() + 2) { - MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() - << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); - } - auto attrs = std::make_shared(); - for (size_t i = 0; i < partial_closure->args().size(); i++) { - auto old_node = cnode->input(i + 2); - auto possibile_value_node = 
BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); - if (possibile_value_node != nullptr) { - partial_node_list.push_back(possibile_value_node); - } else { - if (!(old_node->isa() || old_node->isa())) { - MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); - } - partial_node_list.push_back(old_node); - } - } - wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); - wrapped_node->set_abstract(partial_closure); - } - return wrapped_node; -} - -const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { - auto cache_iter = evalcaches_.find(eval); - if (cache_iter == evalcaches_.end()) { - evalcaches_[eval] = eval->cache(); - return eval->cache(); - } - return cache_iter->second; -} - -std::pair FuncGraphSpecializer::BuildFromBroadedArgsVal( - const EvaluatorPtr &eval) { - MS_EXCEPTION_IF_NULL(eval); - std::unordered_set choices; - EvalResultPtr ret = nullptr; - AbstractBasePtrList broaded_argvals; - for (auto &argvals_map : *evalcaches_[eval]) { - auto argvals = argvals_map.first; - broaded_argvals.clear(); - - (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals), - [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); - (void)choices.insert(broaded_argvals); - MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals); - } - - if (1 == choices.size()) { - ConfigPtrList args_conf_list; - (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), - [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared(v); }); - - // if broaden return null - ret = eval->Run(engine_, args_conf_list, nullptr); - EvaluatorCacheMapPtr real = std::make_shared(); - - (*real)[broaded_argvals] = ret; - evalcaches_[eval] = real; - return std::make_pair(broaded_argvals, ret->abstract()); - } else { - MS_LOG(DEBUG) << "Choices.size: " << choices.size(); - return std::make_pair(AbstractBasePtrList(), nullptr); - } -} - -void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { - MS_EXCEPTION_IF_NULL(new_node); - if (specializer_->seen().count(new_node) > 0) { - return; - } - specializer_->AddSeen(new_node); - auto new_inputs = new_node->inputs(); - if (new_inputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs of CNode is empty"; - } - AnfNodePtr func = new_inputs[0]; - MS_EXCEPTION_IF_NULL(func); - - // First element is func so arg start from 1 - std::vector args(new_inputs.begin() + 1, new_inputs.end()); - // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) 
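// Example: for CNode(CNode(Partial, f, x), y) the loop below starts with func = CNode(Partial, f, x)
// and args = [y]; one iteration prepends the bound argument, giving args = [x, y] and func = f.
// Nested partials unwrap the same way, with the innermost bindings ending up first.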
- while (IsPrimitiveCNode(func, prim::kPrimPartial)) { - std::vector inputs = func->cast()->inputs(); - // First element is partial, second is func so arg is start from 2 - (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); - func = inputs[1]; - } - new_inputs = args; - (void)new_inputs.insert(new_inputs.begin(), func); - - AbstractBasePtrList argvals; - MS_EXCEPTION_IF_NULL(new_inputs[0]); - AbstractBasePtr fnval = new_inputs[0]->abstract(); - MS_LOG(DEBUG) << "The new_inputs[0] node: pointer: " << new_inputs[0]->ToString() << ", " - << new_inputs[0]->DebugString() << ", abstract: " << new_inputs[0]->abstract()->ToString(); - - // First element is func so function arguments start from 1 - for (size_t i = 1; i < new_inputs.size(); ++i) { - argvals.push_back(new_inputs[i]->abstract()); - MS_LOG(DEBUG) << "The new_inputs[" << i << "] node: pointer: " << new_inputs[i]->ToString() << ", " - << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); - } - - if (!func->isa()) { - MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString(); - if (func->abstract()->isa() && !func->abstract()->isa()) { - auto func_abs = func->abstract()->cast(); - EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs); - std::pair result; - AbstractBasePtrList empty_args; - auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result); - MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; - // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early - if (status == kSpecializeFindUniqueArgvalPoly || - (func->isa() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || - func->abstract()->isa()))) { - auto wrapped_node = BuildSpecializedParameterNode(new_node); - new_inputs[0] = wrapped_node; - } - } - } - - if (CanSpecializeNode(func)) { - // for primitive node , we build the primitive node with infered attributes in the first pass - // so we do not build replaced node again here in second pass - if (IsValueNode(func)) { - new_inputs[0] = func; - } else { - new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); - } - } - - for (size_t i = 0; i < argvals.size();) { - size_t next = i + 1; - if (CanSpecializeNode(args[i])) { - new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector{}); - } - i = next; - } - new_node->set_inputs(new_inputs); -} - -namespace { -void DumpEvaluatorCache(const EvaluatorCacheMap &evaluator_cache_map, const AbstractBasePtrList &argvals) { - MS_LOG(DEBUG) << "Find unique argvals failed: " << argvals.size() << ", " << argvals << ". 
Check cache all items.";
-  int i = 0;
-  for (const auto &item : evaluator_cache_map) {
-    MS_LOG(DEBUG) << "evaluator_cache_map[" << i++ << "]: " << item.first;
-  }
-}
-
-bool IsPolyFunc(const AbstractFunctionPtr &func, const AbstractBasePtrList &argvals) {
-  if (func->isa<PrimitiveAbstractClosure>() && argvals.empty()) {
-    MS_LOG(DEBUG) << "High order primitive return POLY.";
-    return true;
-  }
-  if (func->isa<MetaFuncGraphAbstractClosure>() && argvals.empty()) {
-    auto meta_func_graph_wrapper = dyn_cast<MetaFuncGraphAbstractClosure>(func);
-    auto meta_func_graph = meta_func_graph_wrapper->meta_func_graph();
-    if (meta_func_graph != nullptr && meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>()) {
-      auto do_signature = dyn_cast<prim::DoSignatureMetaFuncGraph>(meta_func_graph);
-      if (do_signature != nullptr && do_signature->function()->isa<Primitive>()) {
-        MS_LOG(DEBUG) << "High order primitive " << do_signature->function()->ToString() << " return POLY.";
-        return true;
-      }
-    }
-  }
-  return false;
-}
-}  // end anonymous namespace
-
-SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunctionPtr &func, const EvaluatorPtr &eval,
-                                                             const AbstractBasePtrList &argvals,
-                                                             std::pair<AbstractBasePtrList, AbstractBasePtr> *result) {
-  MS_EXCEPTION_IF_NULL(func);
-  MS_EXCEPTION_IF_NULL(eval);
-  MS_EXCEPTION_IF_NULL(result);
-
-  EvaluatorCacheMap evaluator_cache_map = *eval->cache();
-  if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
-    *result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
-    return kSpecializeSuccess;
-  }
-  DumpEvaluatorCache(evaluator_cache_map, argvals);
-
-  const EvaluatorCacheMapPtr &choices = GetEvalCache(eval);
-  MS_EXCEPTION_IF_NULL(choices);
-
-  if (choices->count(argvals)) {
-    *result = std::make_pair(argvals, (*choices)[argvals]->abstract());
-    return kSpecializeSuccess;
-  } else if (choices->size() == 1) {
-    MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
-    *result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
-    return kSpecializeSuccess;
-  } else if (choices->empty()) {
-    MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase " << func->ToString() << " | "
-                  << func->type_name();
-    return kSpecializeFindUniqueArgvalDead;
-  } else {
-    if (IsPolyFunc(func, argvals)) {
-      return kSpecializeFindUniqueArgvalPoly;
-    }
-
-    MS_LOG(DEBUG) << "Try to find generalized argvals.";
-    *result = BuildFromBroadedArgsVal(eval);
-    if (!result->first.empty()) {
-      return kSpecializeSuccess;
-    }
-    MS_LOG(DEBUG) << "Find POLY code, it may be unused code or unresolved polymorphism.";
-    return kSpecializeFindUniqueArgvalPoly;
-  }
-}
-static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
-  auto &prim_attrs = prim->attrs();
-  bool is_attr_same = true;
-  for (auto &item : *attrs) {
-    auto itr = prim_attrs.find(item.first);
-    if (itr != prim_attrs.end()) {
-      if (!(*(itr->second) == *(item.second))) {
-        is_attr_same = false;
-        break;
-      }
-    } else {
-      is_attr_same = false;
-      break;
-    }
-  }
-  if (!is_attr_same) {
-    if (prim->isa<PrimitivePy>()) {
-      PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
-      auto clone_fn = prim_py->GetPyObj().attr("_clone");
-      py::object new_obj = clone_fn();
-      auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
-      for (auto &item : *attrs) {
-        cloned_prim->AddAttr(item.first, item.second);
-      }
-      return cloned_prim;
-    }
-    auto cloned_prim = std::make_shared<Primitive>(*prim);
-    for (auto &item : *attrs) {
-      cloned_prim->AddAttr(item.first, item.second);
-    }
-    return cloned_prim;
-  }
-  return prim;
-}
-
-AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
-                                                        const
AttrValueMapPtr &attrs) { - MS_EXCEPTION_IF_NULL(origin_node); - MS_EXCEPTION_IF_NULL(ival); - - AbstractFunctionPtr abs = dyn_cast(ival); - if (abs != nullptr) { - // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction. - if (abs->isa()) { - return nullptr; - } - ValuePtr value = nullptr; - if (abs->isa()) { - auto real_fn = dyn_cast(abs); - // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one - if (attrs != nullptr) { - value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); - } else { - value = real_fn->prim(); - } - } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); - value = real_fn->meta_func_graph(); - } else if (abs->isa()) { - auto real_fn = dyn_cast(abs); - value = real_fn->func_graph(); - } else { - return nullptr; - } - if (!value->isa() || value->cast()->parent() == nullptr || - (IsValueNode(origin_node) && IsVisible(func_graph_, value->cast()->parent()))) { - return BuildValueNode(value, ival); - } else { - return nullptr; - } - } else { - ValuePtr val = ival->BuildValue(); - if (val->isa()) { - return nullptr; - } - // keep primitive 'depend' not to be optimized - if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) { - return nullptr; - } - return BuildValueNode(val, ival); - } -} - -AnfNodeConfigPtr FuncGraphSpecializer::MakeConfig(const AnfNodePtr &node) { - return engine_->MakeConfig(node, context_); -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h deleted file mode 100644 index 831c404873..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h +++ /dev/null @@ -1,136 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ -#define PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph_cloner.h" -#include "pipeline/static_analysis/evaluator.h" - -namespace mindspore { -namespace abstract { -enum SpecializeStatusCode { - kSpecializeSuccess = 0, - kSpecializeFindUniqueArgvalDead = 1, // Dead Node - kSpecializeFindUniqueArgvalPoly = 2, // Poly Node - kSpecializeFailure = 0xFF -}; - -class FuncGraphSpecializer; - -// Specialize a func graph using analyzed abstract values. -class ProgramSpecializer { - public: - explicit ProgramSpecializer(const std::shared_ptr &engine) : engine_(engine) { - mng_ = engine_->func_graph_manager(); - } - ~ProgramSpecializer() = default; - // Run the program specializer on the topmost graph in the given context. 
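// Results are memoized per analysis context: SpecializeFuncGraph keys its specializations_
// map by context->SpecializeKey(), so one FuncGraph analyzed under two different contexts
// yields two independently specialized graphs.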
- FuncGraphPtr Run(const FuncGraphPtr &fg, const AnalysisContextPtr &context); - const std::unordered_set &seen() const { return seen_; } - void AddSeen(const AnfNodePtr &node) { (void)seen_.insert(node); } - - std::shared_ptr GetFuncGraphSpecializer(const AnalysisContextPtr &context); - // Specialze one FuncGraph in a given context. - FuncGraphPtr SpecializeFuncGraph(const FuncGraphPtr &fg, const AnalysisContextPtr &context); - - std::shared_ptr engine() { return engine_; } - - private: - std::shared_ptr engine_; - std::unordered_set seen_; - FuncGraphManagerPtr mng_; - std::unordered_map, ContextHasher, ContextEqual> - specializations_; -}; - -class FuncGraphSpecializer : public std::enable_shared_from_this { - public: - FuncGraphSpecializer(ProgramSpecializer *const s, const FuncGraphPtr &fg, const AnalysisContextPtr &context); - virtual ~FuncGraphSpecializer() { - specializer_ = nullptr; - repl_node_ = nullptr; - } - void Run(); - FuncGraphPtr specialized_func_graph() { return specialized_func_graph_; } - - private: - ProgramSpecializer *specializer_; - FuncGraphPtr func_graph_; - FuncGraphPtr specialized_func_graph_; - AnalysisContextPtr context_; - std::shared_ptr parent_; - std::shared_ptr engine_; - ClonerPtr cloner_; - // ProcessNode-> [cloner_->CloneDisconnected] will clone AnfNode again. - // So, repl_node_ should pointer to GraphCloner->repl_node_ other than a copy of that. - std::unordered_map *repl_node_; - std::vector todo_; - std::unordered_set marked_; - std::unordered_map evalcaches_; - - void FirstPass(); - void SecondPass(); - void ProcessNode(const AnfNodePtr &node); - void ProcessCNode(const CNodePtr &new_node); - - AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node); - inline void AddTodoItem(const AnfNodePtr &node) { todo_.push_back(node); } - // Get node replicated by Cloner. - AnfNodePtr GetReplicatedNode(const AnfNodePtr &node); - // Replicated node which is not used directly by a func graph, so it's not searchable from it's return node - // (disconnected). - AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node); - - // Build a value node from parameter if the function graph has special flag to hint it can be done. - AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node); - - // Build a value node if ival is constant and not any-value - AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, - const AttrValueMapPtr &attrs); - // Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a - // replicated node. - AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); - // Build a specialized node from given argvals; - AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, - const AbstractBasePtrList &argvals); - AnfNodePtr BuildSpecializedNodeInner(const AnfNodePtr &node, const AbstractBasePtr &abs, - const AbstractFunctionPtr &func, const AbstractBasePtrList &args, - SpecializeStatusCode *errcode); - - // Find the unique argument values which can be used to specialize a primitive or graph function. - SpecializeStatusCode FindUniqueArgvals(const AbstractFunctionPtr &fn, const EvaluatorPtr &eval, - const AbstractBasePtrList &argvals, - std::pair *result); - // Get cache, it may be eval's cache or cache built from broaded argument values. - const EvaluatorCacheMapPtr &GetEvalCache(const EvaluatorPtr &eval); - // Try to build unique argvals from the broaded arg vals if it is unique. 
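// "Broadened" argvals are the cached argument lists widened by AbstractBase::Broaden()
// (concrete values relaxed toward their types). If every cached entry collapses to the
// same broadened signature, the evaluator is re-run once on that signature and its cache
// is replaced by the single generalized entry; otherwise an empty pair is returned.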
- std::pair BuildFromBroadedArgsVal(const EvaluatorPtr &eval); -}; -} // namespace abstract -} // namespace mindspore -#endif // PIPELINE_STATIC_ANALYSIS_SPECIALIZE_H_ diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc deleted file mode 100644 index 53c2c064b4..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ /dev/null @@ -1,655 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/static_analysis/static_analysis.h" - -#include -#include - -#include "abstract/utils.h" -#include "pipeline/static_analysis/prim.h" -#include "operator/ops.h" -#include "utils/symbolic.h" -#include "ir/tensor.h" -#include "ir/func_graph_cloner.h" -#include "./common.h" -#include "pipeline/parse/data_converter.h" -#include "debug/draw.h" -#include "pipeline/static_analysis/evaluator.h" -#include "debug/trace.h" - -namespace mindspore { -namespace abstract { -bool IsIntermediateAbstract(const AbstractBasePtr &arg_spec) { - if (dyn_cast(arg_spec)) { - auto v = arg_spec->GetValueTrack(); - if (v->isa()) { - return true; - } else { - return false; - } - } else { - return false; - } -} - -AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBasePtr &arg2) { - if (dyn_cast(arg1) && dyn_cast(arg2)) { - return arg1->Join(arg2); - } - return nullptr; -} - -void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { - MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString() - << ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString() - << ", Pointer: " << result->abstract().get(); - cache_[conf] = result; - - // Set intermediate abstract value. - if (IsIntermediateAbstract(result->abstract())) { - if (conf->node()->intermediate_abstract() == nullptr) { - conf->node()->set_intermediate_abstract(result->abstract()); - MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString(); - } else { - auto old_spec = conf->node()->intermediate_abstract(); - auto joined_spec = IntermediateJoin(result->abstract(), old_spec); - conf->node()->set_intermediate_abstract(joined_spec); - MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t" - << result->abstract()->ToString() << "\njoined_spec:\t" - << (joined_spec != nullptr ? 
joined_spec->ToString() : "nullptr"); - } - } -} - -EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) { - auto value = cache_.find(conf); - if (value == cache_.end()) { - return nullptr; - } - return value->second; -} - -std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(conf->node()); - std::size_t hash_value = conf->node()->hash(); - if (!conf->context()->IsDummyContext()) { - hash_value = hash_combine(hash_value, std::hash{}(conf->context().get())); - } - if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() - << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value; - } else { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value; - } - return hash_value; -} - -bool AnfNodeConfigEqual::operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const { - if (lhs == nullptr || rhs == nullptr) { - return false; - } - if (lhs == rhs) { - return true; - } - return (*lhs == *rhs); -} - -AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list) { - ConfigPtrList args_conf_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - MS_EXCEPTION_IF_NULL(func_graph_manager_); - func_graph_manager_->AddFuncGraph(func_graph); - - AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); - - // Running the analyzer. - AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); - MS_EXCEPTION_IF_NULL(root_context); - MS_EXCEPTION_IF_NULL(root_context->func_graph()); - AnfNodeConfigPtr output_conf = MakeConfig(root_context->func_graph()->get_return(), root_context); - MS_EXCEPTION_IF_NULL(func_graph); - MS_LOG(INFO) << func_graph->ToString() << ": Run finished."; - - AnalysisResult result; - MS_EXCEPTION_IF_NULL(output_conf); - result.inferred = output_conf->GetEvaluatedValue(); - result.context = root_context; - return result; -} - -AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, - const ConfigPtrList &args_conf_list) { - std::shared_ptr eval = std::make_shared(func_graph, context); - (void)eval->Run(shared_from_this(), args_conf_list, nullptr); - return eval->graph_context(); -} - -EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - auto value = cache_.GetValue(conf); - if (value != nullptr) { - MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() - << ", " << value->abstract()->ToString(); - return value; - } - - MS_LOG(DEBUG) << "Evaluate cache miss for NodeConfig: " << conf->ToString(); - value = Eval(conf); - if (value == nullptr) { - MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; - } - cache_.set_value(conf, value); - return value; -} - -EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - AnfNodePtr node = conf->node(); - EvalResultPtr eval_result = nullptr; -#ifdef DEBUG - compute_conf_stack_.push_back(node); - std::ostringstream buffer; - buffer << "Compute Config Begin:"; - for (auto iter : 
compute_conf_stack_) { - buffer << " -> " << iter->DebugString(); - } - MS_LOG(DEBUG) << buffer.str(); -#endif - MS_LOG(DEBUG) << "Begin Eval NodeConfig " << conf->ToString(); - MS_EXCEPTION_IF_NULL(node); - if (node->abstract() != nullptr) { - MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString(); - eval_result = std::make_shared(node->abstract(), std::make_shared()); - } else if (node->isa()) { - auto value_node = node->cast(); - eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); - } else if (node->isa()) { - auto cnode = node->cast(); - trace::TraceEvalCNodeEnter(conf); - eval_result = EvalCNode(cnode, conf); - trace::TraceEvalCNodeLeave(); - } else { - MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() - << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } - -#ifdef DEBUG - compute_conf_stack_.pop_back(); - if (eval_result == nullptr) { - MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString() - << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); - } -#endif - MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); - return eval_result; -} - -AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(value_node); - return ToAbstract(value_node->value(), conf->context(), conf); -} - -EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { - MS_EXCEPTION_IF_NULL(conf); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "CNode->inputs() is empty, CNode: " << cnode->DebugString(); - } - - AnfNodePtr func_node = inputs[0]; - MS_EXCEPTION_IF_NULL(func_node); - MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString(); - AnalysisContextPtr context = conf->context(); - AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); - MS_EXCEPTION_IF_NULL(func_conf); - // Keep it in a local variable, otherwise smart pointer will free it. 
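// GetEvaluatedValue() returns a temporary EvalResultPtr; binding its abstract() to the
// local maybe_func keeps the underlying AbstractBase alive across the dyn_cast and type
// checks below, which would otherwise read from a freed object.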
- AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); - if (maybe_func == nullptr) { - MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() - << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); - } - if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { - MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; - return std::make_shared(maybe_func->Clone(), std::make_shared()); - } - AbstractFunctionPtr func = dyn_cast(maybe_func); - if (func == nullptr) { - MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() - << ", func_conf: " << func_conf->ToString() - << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); - } - - ConfigPtrList args_conf_list; - // ignore the first node which is function name - for (std::size_t i = 1; i < inputs.size(); i++) { - const AnfNodePtr &node = inputs[i]; - args_conf_list.push_back(MakeConfig(node, context)); - } - std::vector infs; - - auto build_evaluator = [this, &infs, &cnode](const AbstractFuncAtomPtr &poss) { - auto evaluator = this->GetEvaluatorFor(poss); - evaluator->set_bound_node(cnode); - infs.push_back(evaluator); - }; - func->Visit(build_evaluator); - - return ExecuteEvaluators(infs, conf, args_conf_list); -} - -EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { - ConfigPtrList args_conf_list; - (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list), - [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared(arg); }); - std::vector infs; - MS_EXCEPTION_IF_NULL(func); - auto build_evaluator = [this, &infs](const AbstractFuncAtomPtr &poss) { - auto evaluator = this->GetEvaluatorFor(poss); - infs.push_back(evaluator); - }; - func->Visit(build_evaluator); - return ExecuteEvaluators(infs, nullptr, args_conf_list); -} - -void AnalysisEngine::ClearEvaluatorCache() { - for (std::pair element : constructors_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } - for (auto &element : prim_constructors_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } - for (auto &element : prim_py_evaluators_) { - EvaluatorPtr evaluator = element.second; - MS_EXCEPTION_IF_NULL(evaluator); - MS_EXCEPTION_IF_NULL(evaluator->cache()); - evaluator->cache()->clear(); - } -} - -void AnalysisEngine::Clear() { - cache_.Clear(); - anfnode_config_map_.clear(); - eval_trace_.clear(); - constructors_.clear(); -} - -namespace { -EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr &engine) { - // Custom Primitive with python infer_shape, infer_type - EvaluatorPtr evaluator = nullptr; - MS_EXCEPTION_IF_NULL(prim); - if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->isa()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) { - evaluator = std::make_shared(prim); - return evaluator; - } - if (prim->HasPyEvaluator()) { - auto prim_py = dyn_cast(prim); - if (prim_py != nullptr) { - if (engine == nullptr) { - return std::make_shared(prim_py); - } - - const auto &iter = 
engine->prim_py_evaluators_.find(prim_py); - if (iter != engine->prim_py_evaluators_.end()) { - return iter->second; - } - evaluator = std::make_shared(prim_py); - engine->prim_py_evaluators_[prim_py] = evaluator; - return evaluator; - } - MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive."; - } - - if (prim->isa() || prim->HasAttr()) { - if (engine == nullptr) { - (void)GetPrimEvaluatorConstructors(); - } - // If a primitive may have attr, try to create a new evaluator. - StandardPrimitiveEvalImpl eval_impl = GetPrimitiveInferImpl(prim); - if (eval_impl != nullptr) { - return std::make_shared(prim, eval_impl); - } - } - - if (engine == nullptr) { - // If engine is nullptr, get constructor from default. - const PrimEvaluatorMap &prim_evaluator_map = GetPrimEvaluatorConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter != prim_evaluator_map.end()) { - evaluator = iter->second; - } - } else { - // If engine is given, get constructor from engine resource. - const PrimEvaluatorMap &prim_evaluator_map = engine->PrimConstructors(); - auto iter = prim_evaluator_map.find(prim); - if (iter != prim_evaluator_map.end()) { - evaluator = iter->second; - } - } - if (evaluator == nullptr) { - MS_LOG(EXCEPTION) << "The evaluator of the primitive is not defined (" << prim->name() << ")."; - } - return evaluator; -} -} // namespace - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - auto primitive = func->prim(); - auto evaluator = GetPrimEvaluator(primitive, shared_from_this()); - constructors_[func] = evaluator; - return evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr func_graph_evaluator = - std::make_shared(func->func_graph(), func->context()); - constructors_[func] = func_graph_evaluator; - return func_graph_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr evaluator = - std::make_shared(func->meta_func_graph(), func->context(), func->GetScope()); - constructors_[func] = evaluator; - return evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - AbstractFunctionPtr func_orig = func->fn(); - EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); - auto jevaluator = std::make_shared(evaluator_orig, func_orig); - return jevaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - std::shared_ptr virtual_evaluator = - std::make_shared(func->args_spec_list(), func->output()); - return virtual_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &func) { - MS_EXCEPTION_IF_NULL(func); - AbstractFunctionPtr func_orig = func->fn(); - EvaluatorPtr evaluator_orig = GetEvaluatorFor(func_orig); - std::shared_ptr partial_evaluator = - std::make_shared(evaluator_orig, func->args()); - return partial_evaluator; -} - -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr &) { - MS_LOG(EXCEPTION) << 
"Should not be called "; -} - -// Forward to specific subclass of FunctionWrapper. -EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const AbstractFunctionPtr &func) { - MS_EXCEPTION_IF_NULL(func); - EvaluatorPtr evaluator = func->GetEvaluator(shared_from_this()); - return evaluator; -} - -EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { - MS_LOG(DEBUG) << "The func value: " << func->ToString(); - if (func->tracking_id() != nullptr) { - MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString(); - } - MS_EXCEPTION_IF_NULL(func); - if (func->tracking_id() == nullptr) { - EvaluatorPtr evaluator = _GetEvaluatorFor(func); - return evaluator; - } - auto inf_pair = constructors_.find(func); - if (inf_pair != constructors_.end()) { - return inf_pair->second; - } - - AbstractFunctionPtr func_generic = func->Copy(); - func_generic->set_tracking_id(nullptr); - EvaluatorPtr eval = _GetEvaluatorFor(func_generic); - auto tracked_eval = std::make_shared(eval); - constructors_[func] = tracked_eval; - - return tracked_eval; -} - -EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { - if (evaluators.size() == 1) { - EvaluatorPtr eval = evaluators[0]; - MS_EXCEPTION_IF_NULL(eval); - return eval->Run(shared_from_this(), args_conf_list, out_conf); - } - return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); -} - -void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { - auto fg_eval = evaluator->cast(); - if (fg_eval == nullptr) { - return; - } - auto fg = fg_eval->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto undetermined_fgs = fg->recursive_graphs(); - if (undetermined_fgs) { - auto fg_parent = fg->parent(); - MS_EXCEPTION_IF_NULL(fg_parent); - fg_parent->set_flag(kFuncGraphFlagUndetermined, true); - MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); - } -} - -EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector &evaluators, - const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, - const EvalTraceRevIter &it, bool *continue_flag) { - *continue_flag = false; - // Find latest entry function to handle nested recursion. - EvaluatorPtr latest_entry = eval; - auto latest_entry_iter = eval_trace_.rbegin(); - for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { - auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); - if (it_temp != evaluators.end()) { - latest_entry = *it_temp; - latest_entry_iter = r_it; - break; - } - latest_entry_iter = ++r_it; - } - if (latest_entry != eval) { - MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); - *continue_flag = true; - return latest_entry; - } - - bool has_undetermined = false; - // Check whether sub loop has untraced undetermined evaluator. 
- std::set> undetermined_evals; - for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { - undetermined_evals.insert(*r_it); - } - MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); - - for (auto u_eval : undetermined_evals) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; - if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; - has_undetermined = true; - break; - } - } - if (has_undetermined == false) { - MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; - *continue_flag = true; - return latest_entry; - } - - return latest_entry; -} - -EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { - if (out_specs.size() == 0) { - MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; - } - - if (out_specs.size() == 1) { - MS_EXCEPTION_IF_NULL(out_specs[0]); - // If only one result derived, then broaden it to avoid wrong constant propagation. - return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); - } - auto joined_spec = AbstractJoin(out_specs); - MS_EXCEPTION_IF_NULL(joined_spec); - MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); - return std::make_shared(joined_spec, std::make_shared()); -} - -EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, - const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list) { - AbstractBasePtrList out_specs; - if (!multi_poss_.count(evaluators[0])) { - multi_poss_[evaluators[0]] = evaluators[1]; - multi_poss_[evaluators[1]] = evaluators[0]; - } - AbstractBasePtrList args_spec_list; - (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &conf) -> AbstractBasePtr { - MS_EXCEPTION_IF_NULL(conf); - return conf->GetEvaluatedValue()->abstract(); - }); - for (auto eval : evaluators) { - SetUndeterminedFlag(eval); - - auto current_inf = std::make_pair(eval, args_spec_list); - MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); - - // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. - auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); - if (it == eval_trace_.rend()) { - eval_trace_.push_back(current_inf); - MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); - MS_EXCEPTION_IF_NULL(eval); - auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString(); - out_specs.push_back(eval_result->abstract()); - eval_trace_.pop_back(); - if (eval_trace_.empty()) { - multi_poss_.clear(); - } - } else if (it != eval_trace_.rbegin()) { - bool continue_flag = false; - auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); - if (continue_flag) { - continue; - } - - // Try to travel the latest undetermined. 
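// If the recursion bottomed out in a different evaluator than the current trace head,
// re-run that latest entry directly and return its result instead of joining partial results.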
- if (latest_entry != eval_trace_.rbegin()->first) { - MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); - auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); - MS_EXCEPTION_IF_NULL(eval_result->abstract()); - MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() - << " return out_spec: " << eval_result->abstract()->ToString(); - return eval_result; - } - } - } - - return ProcessEvalResults(out_specs); -} - -EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { - AnfNodeConfigPtr self = shared_from_base(); - return engine_.lock()->GetEvaluatedValue(self); -} - -AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) { - if (value->isa()) { - auto func_graph = value->cast(); - return func_graph->MakeAbstractClosure(context); - } - AnfNodePtr anf_node = nullptr; - if (conf != nullptr) { - anf_node = conf->node(); - } - if (value->isa()) { - auto meta_func_graph = value->cast(); - return meta_func_graph->MakeAbstractClosure(anf_node); - } - if (value->isa()) { - auto prim = value->cast(); - return prim->ToPrimAbstract(anf_node); - } - return value->ToAbstract(); -} - -AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { - AbstractBasePtr a = ToAbstract(value, nullptr, nullptr); - if (broaden) { - a = a->Broaden(); - } - return a; -} - -EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { - auto evaluator = GetPrimEvaluator(primitive, nullptr); - MS_EXCEPTION_IF_NULL(evaluator); - if (!evaluator->isa()) { - MS_LOG(EXCEPTION) << "Prim " << primitive->ToString() << " should build a TrivialPrimEvaluator, but " - << evaluator->ToString(); - } - auto trivial_evaluator = dyn_cast(evaluator); - auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); - return eval_result; -} -} // namespace abstract -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h deleted file mode 100644 index d4a3fd6a8d..0000000000 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ /dev/null @@ -1,280 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ -#define PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ - -#include -#include -#include -#include -#include -#include -#include - -#ifdef DEBUG -#include -#endif - -#include "utils/log_adapter.h" -#include "ir/anf.h" -#include "ir/primitive_py.h" -#include "abstract/analysis_context.h" -#include "pipeline/static_analysis/abstract_function.h" -#include "pipeline/parse/parse.h" - -namespace mindspore { -namespace abstract { -// define attribute value map -using AttrValueMap = std::unordered_map; -using AttrValueMapPtr = std::shared_ptr; - -// the class to save evaluated result: abstract value and modified attribute -class EvalResult : public Base { - public: - EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {} - ~EvalResult() override = default; - MS_DECLARE_PARENT(EvalResult, Base); - AbstractBasePtr abstract() { return abstract_; } - AttrValueMapPtr attribute() { return attribute_; } - - private: - AbstractBasePtr abstract_; - AttrValueMapPtr attribute_; -}; - -using EvalResultPtr = std::shared_ptr; -// Superclass for AnfNodeConfig and VirtualConfig. -class Config : public Base { - public: - Config() = default; - ~Config() override = default; - MS_DECLARE_PARENT(Config, Base); - virtual EvalResultPtr GetEvaluatedValue() = 0; -}; - -// Config will be stored in AnalysisCache -using ConfigPtr = std::shared_ptr; -using ConfigPtrList = std::vector; - -// Config to a certain node in a certain context. -class AnfNodeConfig : public Config { - public: - AnfNodeConfig(const AnalysisEnginePtr &engine, const AnfNodePtr &node, const AnalysisContextPtr &context) - : Config(), engine_(std::weak_ptr(engine)), node_(node) { - FuncGraphPtr fg; - if (IsValueNode(node)) { - auto v = node->cast(); - fg = v->value()->cast(); - } else { - fg = node->func_graph(); - } - context_ = nullptr; - if (context != nullptr) { - context_ = context->Filter(fg); - } - } - - ~AnfNodeConfig() override = default; - MS_DECLARE_PARENT(AnfNodeConfig, Config); - - EvalResultPtr GetEvaluatedValue() override; - - AnalysisContextPtr context() const { return context_; } - - AnfNodePtr node() const { return node_; } - - AnalysisEnginePtr engine() const { return engine_.lock(); } - - // used by unordered_map; - bool operator==(const AnfNodeConfig &other) const { - // compare node with pointer, context with pointer except DummyContext as it's created by make_shared; - // context should not be nullptr; - if (context_->IsDummyContext() && other.context_->IsDummyContext()) { - return true; - } - return (node_ == other.node_) && (context_ == other.context_); - } - - std::string ToString() const override { - std::ostringstream buffer; - buffer << "Node: " << node_->DebugString() << ", Context: " << context_->ToString(); - return buffer.str(); - } - - private: - // AnalysisEngine is global. - // As AnfNodeConfig is cached in AnalysisEngine.AnalysisCache, use - // weak_ptr to break Config cycle. 
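// (The engine owns the analysis cache and the cache holds the configs; a shared_ptr from a
// config back to the engine would keep that whole cycle alive. engine() re-locks on demand.)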
- std::weak_ptr engine_; - AnfNodePtr node_; - AnalysisContextPtr context_; -}; - -using AnfNodeConfigPtr = std::shared_ptr; - -struct AnfNodeConfigHasher { - std::size_t operator()(const AnfNodeConfigPtr conf) const; -}; - -struct AnfNodeConfigEqual { - bool operator()(const AnfNodeConfigPtr lhs, const AnfNodeConfigPtr rhs) const; -}; - -class VirtualConfig : public Config { - public: - explicit VirtualConfig(const AbstractBasePtr &abstract) : Config(), abstract_(abstract) {} - - ~VirtualConfig() override = default; - MS_DECLARE_PARENT(VirtualConfig, Config); - EvalResultPtr GetEvaluatedValue() override { - return std::make_shared(abstract_, std::make_shared()); - } - - private: - AbstractBasePtr abstract_; -}; - -// AnalysisCache -class AnalysisCache { - public: - AnalysisCache() = default; - ~AnalysisCache() = default; - void Clear() { cache_.clear(); } - void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg); - EvalResultPtr GetValue(const AnfNodeConfigPtr &conf); - - private: - std::unordered_map cache_; -}; - -using PrimEvaluatorMap = std::unordered_map; -using AnfNodeConfigMap = - std::unordered_map; - -struct AnalysisResult { - EvalResultPtr inferred; - AnalysisContextPtr context; -}; - -using EvalTraceRevIter = std::list>::reverse_iterator; - -class AnalysisEngine : public std::enable_shared_from_this { - public: - AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) - : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} - ~AnalysisEngine() = default; - - // func_graph: The func_graph to analyze. - // args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase. - AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list); - EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf); - // Return the Evaluator for the given function. - EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); - - AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); - EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); - // Infer the result of fn(args). - EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); - void Clear(); - void ClearEvaluatorCache(); - AnalysisCache &cache() { return cache_; } - AnfNodeConfigPtr MakeConfig(const AnfNodePtr &node, const AnalysisContextPtr &context) { - return std::make_shared(shared_from_this(), node, context); - } - // Overloaded function. - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &); - EvaluatorPtr _GetEvaluatorFor(const std::shared_ptr &fn); - - FuncGraphManagerPtr func_graph_manager() { return func_graph_manager_; } - const AnfNodeConfigMap &anfnode_config_map() const { return anfnode_config_map_; } - - // Set the analysis result for orig to the result for new. - // This sets an entry in anfnode_config_map from orig to new. - EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { - // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. 
- (void)anfnode_config_map_.emplace(orig_conf, new_conf); - MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() - << ", to new_conf: " << new_conf->node()->DebugString(); - return GetEvaluatedValue(new_conf); - } - const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } - - AnalysisCache cache_; - std::unordered_map prim_py_evaluators_; - - private: - void SetUndeterminedFlag(const EvaluatorPtr &evaluator); - EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, - const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, - bool *continue_flag); - EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); - - const PrimEvaluatorMap &prim_constructors_; - FuncGraphManagerPtr func_graph_manager_; - std::unordered_map constructors_; - AnfNodeConfigMap anfnode_config_map_; - // Use a list to trace multiple evaluators. - std::list> eval_trace_; - std::map multi_poss_; - - AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, - const ConfigPtrList &args_conf_list); - EvalResultPtr Eval(const AnfNodeConfigPtr &conf); - EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn); - EvalResultPtr ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list); - EvalResultPtr ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, - const ConfigPtrList &args_conf_list); - -#ifdef DEBUG - std::vector compute_conf_stack_; -#endif -}; - -// Translate the value to an abstract value. -// Arguments: -// value: The value to convert. -// context: The context in which the value was found, used if the value is a Graph. -// conf: The Config to the valuenode we are converting, if there is one, -// so that we can generate a tracking_id. -AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context = nullptr, - const AnfNodeConfigPtr &conf = nullptr); - -// Convert a value to an abstract value. -// Arguments: -// v: The value to convert. -// broaden: If True, concrete values will be made more abstract, so e.g. -// the value 1234 would become ANYTHING. -AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden = false); - -template -AbstractBasePtr FromValue(const T &value, bool broaden = false) { - return FromValueInside(MakeValue(value), broaden); -} - -EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); -} // namespace abstract -} // namespace mindspore - -#endif // PIPELINE_STATIC_ANALYSIS_STATIC_ANALYSIS_H_ diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc deleted file mode 100644 index bbca3c8721..0000000000 --- a/mindspore/ccsrc/pipeline/validator.cc +++ /dev/null @@ -1,120 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
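ForwardConfig above records an orig-config → new-config edge in anfnode_config_map_ and immediately re-evaluates through the new config; when several graph rewrites stack up, evaluation effectively chases a forwarding chain. The patch does not show how the engine's lookup consumes that map, so the following chain-following helper with a cycle guard is an assumption about usage, not the actual implementation:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Follow forwarding edges until a config with no forward entry is reached.
// A visited set guards against accidental cycles in the forward map.
std::string ResolveForward(const std::unordered_map<std::string, std::string> &fwd,
                           std::string conf) {
  std::unordered_set<std::string> seen;
  while (true) {
    auto it = fwd.find(conf);
    if (it == fwd.end()) return conf;             // final config: evaluate here
    if (!seen.insert(conf).second) return conf;   // cycle: stop chasing
    conf = it->second;
  }
}

int main() {
  std::unordered_map<std::string, std::string> fwd{{"a", "b"}, {"b", "c"}};
  std::cout << ResolveForward(fwd, "a") << "\n";  // c
}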
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pipeline/validator.h" - -#include -#include - -#include "ir/manager.h" -#include "ir/dtype.h" -#include "./common.h" -#include "pipeline/static_analysis/prim.h" - -namespace mindspore { -namespace validator { -using mindspore::abstract::AbstractBase; -using mindspore::abstract::AbstractClass; -using mindspore::abstract::AbstractError; -using mindspore::abstract::AbstractFunction; -using mindspore::abstract::AbstractIndexedSlices; -using mindspore::abstract::AbstractJTagged; -using mindspore::abstract::AbstractList; -using mindspore::abstract::AbstractScalar; -using mindspore::abstract::AbstractTensor; -using mindspore::abstract::AbstractTuple; -using mindspore::abstract::AbstractType; - -void ValidateOperation(const AnfNodePtr &node) { - if (!IsValueNode(node)) { - return; - } - - // Primitive must in whitelist - PrimitivePtr prim = GetValueNode(node); - if (abstract::IsInWhiteList(prim)) { - return; - } - if (prim->HasPyEvaluator()) { - MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator."; - return; - } - if (prim->name() == "fake_bprop") { - MS_LOG(EXCEPTION) << "Illegal primitive: " << GetValue(prim->GetAttr("info")); - } - - MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); -} - -void ValidateAbstract(const AnfNodePtr &node) { - if (node == nullptr) { - MS_LOG(DEBUG) << "Node to validate is invalid"; - return; - } - AbstractBasePtr ptrBase = node->abstract(); - if (ptrBase == nullptr) { - MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); - return; - } - if (ptrBase->isa() || ptrBase->isa()) { - // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); - } - if (ptrBase->isa()) { - TypePtr ptrType = ptrBase->GetTypeTrack(); - MS_EXCEPTION_IF_NULL(ptrType); - if (ptrType->isa() || ptrType->isa()) { - // only send string in external - if (!IsValueNode(node)) { - // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); - } - } - return; - } - if (ptrBase->isa()) { - // NOTICE: validate dead code? - MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); - return; - } - - if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa()) { - return; - } - - if (ptrBase->isa()) { - return; - } - - // Other types show exception - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); -} - -void Validate(const FuncGraphPtr &fg) { - FuncGraphManagerPtr mgr = Manage(fg, false); - MS_EXCEPTION_IF_NULL(mgr); - AnfNodeSet &all_nodes = mgr->all_nodes(); - for (const auto &anf_node : all_nodes) { - ValidateOperation(anf_node); - ValidateAbstract(anf_node); - } -} -} // namespace validator -} // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/validator.h b/mindspore/ccsrc/pipeline/validator.h deleted file mode 100644 index 61f7470349..0000000000 --- a/mindspore/ccsrc/pipeline/validator.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
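Validate above manages the graph, walks every node, and rejects any primitive that is neither whitelisted nor backed by a Python evaluator, with fake_bprop singled out for a clearer message. The same guard shape in a minimal form — the whitelist contents and the exception type here are placeholders, not the real ones:

#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Reject any operation name outside the supported set; mirrors the
// whitelist-then-throw structure of ValidateOperation.
void ValidateOps(const std::vector<std::string> &ops) {
  static const std::set<std::string> kWhiteList = {"Add", "Mul", "Return"};  // placeholder list
  for (const auto &op : ops) {
    if (kWhiteList.count(op) == 0) {
      throw std::runtime_error("Illegal primitive: " + op);
    }
  }
}

int main() {
  ValidateOps({"Add", "Return"});  // passes silently
  try {
    ValidateOps({"fake_bprop"});
  } catch (const std::runtime_error &e) {
    std::cout << e.what() << "\n";  // Illegal primitive: fake_bprop
  }
}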
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ -#define MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H_ - -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/anf.h" -#include "utils/misc.h" - -namespace mindspore { -namespace validator { -void Validate(const FuncGraphPtr &func_graph); -void ValidateAbstract(const AnfNodePtr &node); -void ValidateOperation(const AnfNodePtr &node); -} // namespace validator -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PIPELINE_VALIDATOR_H__ diff --git a/mindspore/ccsrc/pre_activate/CMakeLists.txt b/mindspore/ccsrc/pre_activate/CMakeLists.txt deleted file mode 100644 index 239757fb17..0000000000 --- a/mindspore/ccsrc/pre_activate/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -file(GLOB_RECURSE _PREACTIVATE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "common/*.cc" - "mem_reuse/*.cc" - "pass/*.cc" - "gpu/*.cc" -) - -if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc") - list(APPEND _PREACTIVATE_SRC_LIST ${_D_SRC_LIST}) -endif () - -set_property(SOURCE ${_PREACTIVATE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PRE_ACT) -add_library(_mindspore_pre_activate_obj OBJECT ${_PREACTIVATE_SRC_LIST}) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc deleted file mode 100644 index f6020500f8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ /dev/null @@ -1,495 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ascend_backend_optimization.h" -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fission/bn_split.h" -#include "pre_activate/ascend/ir_fission/bn_grad_split.h" -#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" -#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" -#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" -#include "pre_activate/pass/communication_op_fusion.h" -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" -#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" -#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" -#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" -#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" -#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" -#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" -#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h" -#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" -#include "pre_activate/ascend/ir_fission/transdata_split.h" -#include "pre_activate/ascend/ir_fission/topk_split.h" -#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" -#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" -#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" -#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" -#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" -#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" -#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" -#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" -#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" -#include "pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.h" -#include "pre_activate/pass/getitem_tuple.h" -#include "pre_activate/pass/optimize_dependence.h" -#include "pre_activate/pass/erase_visit_attr.h" -#include "pre_activate/ascend/format_type/insert_cast.h" -#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" -#include "pre_activate/pass/eliminate_redundant_op.h" -#include "pre_activate/pass/common_subexpression_elimination.h" -#include "pre_activate/pass/fuse_graph_kernel.h" -#include "pre_activate/pass/fuse_basic.h" -#include "pre_activate/pass/add_atomic_clean.h" -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" -#include "pre_activate/ascend/format_type/check_consistency.h" -#include 
"pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" -#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" -#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" -#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" -#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" -#include "pre_activate/ascend/ir_fission/addn_fission.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" -#include "pre_activate/ascend/ir_fission/split_fission.h" -#include "pre_activate/ascend/format_type/modify_ops_attrs.h" -#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h" -#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" - -namespace mindspore { -namespace opt { -namespace { -void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { - MS_EXCEPTION_IF_NULL(ir_fusion_pm); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - 
ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); -} -} // namespace - -void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto data_layout_pm = std::make_shared("pynative_transop_pm"); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(data_layout_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - MS_EXCEPTION_IF_NULL(optimizer); - auto common_process = std::make_shared("graph_kernel_common_process"); - MS_EXCEPTION_IF_NULL(common_process); - common_process->AddPass(std::make_shared()); - common_process->AddPass(std::make_shared()); - optimizer->AddPassManager(common_process); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendDataLayout(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto data_layout_pm = std::make_shared("transop_pm"); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - data_layout_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(data_layout_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendMixPrecision(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto mixed_precision_pm = std::make_shared("cast_pm"); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - 
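Every optimization entry point in this file, including the AddPass chains on either side of this aside, follows one recipe: build a GraphOptimizer, build a named PassManager, register passes in a fixed order, then Optimize and refresh the execution order. A stripped-down sketch of that composition pattern, with hypothetical Pass/Graph types rather than the real opt:: classes:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph { int version = 0; };

// A pass is just a named graph transformation here.
struct Pass {
  std::string name;
  std::function<void(Graph *)> run;
};

class PassManager {
 public:
  explicit PassManager(std::string name) : name_(std::move(name)) {}
  void AddPass(Pass p) { passes_.push_back(std::move(p)); }
  void Run(Graph *g) const {
    for (const auto &p : passes_) {
      p.run(g);  // passes execute strictly in registration order
      std::cout << name_ << ": ran " << p.name << "\n";
    }
  }
 private:
  std::string name_;
  std::vector<Pass> passes_;
};

int main() {
  Graph g;
  PassManager pm("ir_fusion_pm");
  pm.AddPass({"SquareSumFusion", [](Graph *gr) { ++gr->version; }});
  pm.AddPass({"MulAddFusion",    [](Graph *gr) { ++gr->version; }});
  pm.Run(&g);
  std::cout << "graph version: " << g.version << "\n";  // 2
}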
mixed_precision_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(mixed_precision_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before" + "_graph_" + - std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - DumpIRProto(kernel_graph, "before_hwopt_" + std::to_string(kernel_graph->graph_id())); - } - auto optimizer = std::make_shared(); - auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); - if (context_ptr->execution_mode() == kPynativeMode) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - } else { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - } - ir_fusion_pm->AddPass(std::make_shared()); - if (context_ptr->ir_fusion_flag()) { - AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); - } - - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - } - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(ir_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_ir_fusion_after" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} - -void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { - MS_LOG(INFO) << "IRFusion is not enable, skip"; - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_before.ir"; - DumpIR(file_path, kernel_graph); - } - auto optimizer = std::make_shared(); - auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - - optimizer->AddPassManager(ir_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_ir_fusion_after.ir"; - DumpIR(file_path, kernel_graph); - } -} - -void AscendBackendOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = 
context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_before" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - // data layout optimization - AscendDataLayout(kernel_graph); - // mixed precision optimization - AscendMixPrecision(kernel_graph); - // other optimization - auto optimizer = std::make_shared(); - auto other_pm = std::make_shared("other_pm"); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - other_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(other_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - // buffer fusion - AscendBackendUBFusionOptimization(kernel_graph); - - // other2 optimization - auto optimizer2 = std::make_shared(); - auto other2_pm = std::make_shared("other2_pm"); - other2_pm->AddPass(std::make_shared()); - other2_pm->AddPass(std::make_shared()); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { - other2_pm->AddPass(std::make_shared()); - } - other2_pm->AddPass(std::make_shared()); - optimizer2->AddPassManager(other2_pm); - (void)optimizer2->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph, true); - DumpIRProto(kernel_graph, "after_hwopt_" + std::to_string(kernel_graph->graph_id())); - kernel_graph->DumpFuncGraph("hwopt_d_end"); - } -} - -void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_before_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph); - } - - // Fuse graph kernels with basic ops - FuseGraphKernel(kernel_graph, is_before_kernel_select); - - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_graph_kernel_opt_end_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_before_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - 
DumpIR(file_path, kernel_graph, true); - } - - // Fuse basic ops with basic ops - FuseBasic(kernel_graph, is_before_kernel_select); - - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_fuse_basic_opt_end_graph_" + - std::to_string(!is_before_kernel_select) + "_" + std::to_string(kernel_graph->graph_id()) + - ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!(context_ptr->enable_graph_kernel())) { - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "hwopt_d_add_atomic_clean_before" + "_graph_" + - std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - - AddAtomicClean(kernel_graph); - - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "hwopt_d_end" + "_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph, true); - } -} - -void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (!context_ptr->ir_fusion_flag()) { - MS_LOG(INFO) << "UBFusion is not enable, skip"; - return; - } - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_d_ub_fusion_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - auto fusion_id_allocator = std::make_shared(); - MS_EXCEPTION_IF_NULL(fusion_id_allocator); - fusion_id_allocator->Init(); - auto optimizer = std::make_shared(); - auto ub_fusion_pm = std::make_shared("ub_fusion_pm"); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); - ub_fusion_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(ub_fusion_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_d_ub_fusion_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h deleted file mode 100644 index 222c4b90b5..0000000000 --- 
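Each phase above also brackets its work with the same dump protocol: when save_graphs is set, write a "_before" IR file named after the phase and graph id, run the passes, then write the "_after"/"_end" variant, defaulting the dump directory to "." when the configured path is empty. A small sketch of that bracket as a reusable helper — the helper name is invented, and the real code calls DumpIR inline rather than through a wrapper:

#include <functional>
#include <iostream>
#include <string>

// Stub standing in for the real debug/anf_ir_dump facility.
void DumpIR(const std::string &path) { std::cout << "dump -> " << path << "\n"; }

// Run `work` on a graph, dumping IR before and after when enabled.
void WithIRDumps(bool save_graphs, std::string dir, const std::string &tag,
                 int graph_id, const std::function<void()> &work) {
  if (dir.empty()) dir = ".";  // mirror the empty-path default in the pass code
  if (save_graphs) DumpIR(dir + "/" + tag + "_before_graph_" + std::to_string(graph_id) + ".ir");
  work();
  if (save_graphs) DumpIR(dir + "/" + tag + "_after_graph_" + std::to_string(graph_id) + ".ir");
}

int main() {
  WithIRDumps(true, "", "hwopt_d_ir_fusion", 0, [] { /* run the pass managers */ });
}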
a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ -#include -#include "session/kernel_graph.h" -namespace mindspore { -namespace opt { -void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph); -void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); -void AscendDataLayout(const std::shared_ptr &kernel_graph); -void AscendMixPrecision(const std::shared_ptr &kernel_graph); -void AscendBackendOptimization(const std::shared_ptr &kernel_graph); -void AscendGraphKernelCommonProcess(const std::shared_ptr &kernel_graph); -void AscendBackendGraphKernelOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select = false); -void AscendBackendFuseBasicOpt(const std::shared_ptr &kernel_graph, - bool is_before_kernel_select = false); -void AscendBackendAddAtomicClean(const std::shared_ptr &kernel_graph); -void AscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); -void AscendBackendUBFusionOptimization(const std::shared_ptr &kernel_graph); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_BACKEND_OPTIMIZATION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc deleted file mode 100644 index 9c498bd736..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ /dev/null @@ -1,345 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pre_activate/ascend/ascend_helper.h" -#include -#include "common/trans.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; -namespace { -const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW}; -AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, - const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { - std::vector trans_inputs; - auto prim = std::make_shared(prim::kPrimReshape->name()); - trans_inputs.emplace_back(NewValueNode(prim)); - trans_inputs.emplace_back(input_node); - auto reshape = func_graph->NewCNode(trans_inputs); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); - reshape->set_scope(input_node->scope()); - kernel_select->SelectKernel(reshape); - return reshape; -} - -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = node; - CNodePtr trans_data = nullptr; - std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); - std::string dst_format = is_insert_input ? 
AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); - MS_EXCEPTION_IF_NULL(node); - // if insert transdata for input we need to change the input - if (is_insert_input) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; - } - auto cnode = node->cast(); - dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); - input_node = AnfAlgo::GetInputNode(cnode, insert_index); - padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); - } - bool need_padding = false; - if (is_insert_input) { - need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } else { - need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else if (is_insert_input) { - // if need padding & is input need insert a transdata - // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else { - // if need padding & is output need insert a transdata - // node -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); - trans_node = reshape_node; - } - // refresh the transdata's format to ori format & dst format - RefreshKernelBuildInfo(input_format, dst_format, trans_data, padding_axis); - return trans_node; -} - -AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - auto input_node = AnfAlgo::GetInputNode(node, index); - auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); - MS_EXCEPTION_IF_NULL(node_with_index.first); - auto real_input = node_with_index.first; - if (real_input->isa() || real_input->isa()) { - input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); - MS_EXCEPTION_IF_NULL(input_node); - AnfAlgo::SetNodeInput(node, input_node, index); - } - std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); - std::string dest_format = AnfAlgo::GetInputFormat(node, index); - if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) - << " To DefaultFormat , index: " << index; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); - } - return input_node; -} - -AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - std::string output_format = AnfAlgo::GetOutputFormat(node, 0); - std::vector origin_shape = 
AnfAlgo::GetOutputInferShape(node, 0); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " - << node->DebugString(); - } - if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); - } - return node; -} - -AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { - std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " - << node->DebugString(); - } - auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) { - make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); - } else { - // No need insert trans op. - make_tuple_inputs.push_back(tuple_getitem); - } - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type) { - MS_EXCEPTION_IF_NULL(trans_data); - auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); - MS_EXCEPTION_IF_NULL(ori_build_info); - auto builder = std::make_shared(ori_build_info); - builder->SetInputsFormat({input_format}); - builder->SetInputReshapeType({reshape_type}); - builder->SetOutputReshapeType({reshape_type}); - builder->SetOutputsFormat({output_format}); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); -} - -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(input); - std::vector trans_inputs; - auto prim = std::make_shared(op_name); - trans_inputs.push_back(NewValueNode(prim)); - trans_inputs.push_back(input); - CNodePtr trans_node = func_graph->NewCNode(trans_inputs); - MS_EXCEPTION_IF_NULL(trans_node); - auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); - if (need_padding) { - // if need padding we should set the transdata node's shape to the padding shape - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, - trans_node.get()); - } else { - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); - } - // special handle for ut - if (trans_node->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - trans_node->set_kernel_info(kernel_info); - } - MS_EXCEPTION_IF_NULL(kernel_select); - 
kernel_select->SelectKernel(trans_node); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); - MS_EXCEPTION_IF_NULL(trans_node); - trans_node->set_scope(input->scope()); - return trans_node; -} - -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type) { - MS_EXCEPTION_IF_NULL(func_graph); - std::string input_format = format; - std::string output_format = format; - std::vector new_cast_inputs; - auto prim = std::make_shared(prim::kPrimCast->name()); - new_cast_inputs.push_back(NewValueNode(prim)); - new_cast_inputs.push_back(input); - CNodePtr cast = func_graph->NewCNode(new_cast_inputs); - MS_EXCEPTION_IF_NULL(cast); - // set kernel build info - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetInputsFormat({input_format}); - builder.SetOutputsFormat({output_format}); - builder.SetInputsDeviceType({input_type}); - builder.SetOutputsDeviceType({output_type}); - builder.SetFusionType(kernel::FusionType::OPAQUE); - builder.SetProcessor(kernel::Processor::AICORE); - if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kTBE) != nullptr) { - builder.SetKernelType(KernelType::TBE_KERNEL); - } else { - builder.SetKernelType(KernelType::AKG_KERNEL); - } - // if kernel info is null , it remarks this function is running ut - if (cast->kernel_info() == nullptr) { - auto kernel_info = std::make_shared(); - cast->set_kernel_info(kernel_info); - } - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); - AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); - return cast; -} - -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - size_t outputs_num = AnfAlgo::GetOutputTensorNum(node); - if (outputs_num == 0) { - return node; - } - // Single output - if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) { - return InsertTransOpForSingleOutput(func_graph, node, kernel_select); - } - // Multiple output - return InsertTransOpForMultipleOutput(func_graph, node, kernel_select); -} - -AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); - MS_EXCEPTION_IF_NULL(input_node); - new_inputs.push_back(input_node); - } - CNodePtr new_cnode = nullptr; - // cnode changed so make a new cnode to differ from original one. 
- auto kernel_graph = func_graph->cast>(); - if (kernel_graph == nullptr) { - new_cnode = std::make_shared(*cnode); - } else { - new_cnode = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_inputs(new_inputs); - return new_cnode; -} - -CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - const auto infer_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - TypeId origin_type(kTypeUnknown); - auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); - auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); - auto real_input_node = kernel_with_index.first; - if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - // weight - origin_type = AnfAlgo::GetPrevNodeOutputPrecision(cnode, input_index); - if (origin_type == kTypeUnknown) { - origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); - } - } else { - // feature map - origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); - } - const std::string dev_fmt = AnfAlgo::GetInputFormat(cnode, input_index); - const std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_index); - const TypeId device_type = AnfAlgo::GetInputDeviceDataType(cnode, input_index); - // In graph kernel, we check parameter, - // the eliminate pass will not eliminate this case, so we just do not insert the noused cast. - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL) && IsValueNode(cur_input)) { - new_inputs.push_back(cur_input); - } else if (origin_type != device_type) { - auto cast = - AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, origin_type, device_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(cast); - cast->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast); - new_inputs.push_back(cast); - } else { - new_inputs.push_back(cur_input); - } - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_node = nullptr; - if (kernel_graph == nullptr) { - new_node = std::make_shared(*cnode); - } else { - new_node = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_inputs(new_inputs); - return new_node; -} - -AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto prim = std::make_shared(kMemCpyAsyncOpName); - std::vector new_node_inputs = {NewValueNode(prim), node}; - auto new_node = graph->NewCNode(new_node_inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_abstract(node->abstract()); - new_node->set_scope(node->scope()); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h deleted file mode 100644 index dc88ca2e52..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
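InsertCastForInput above compares each input's origin type with the type the selected kernel expects on device and splices a Cast node in front of the input only on mismatch, choosing the origin type differently for weights and feature maps. Stripped of the graph plumbing, the decision reduces to a small type check per input, sketched here with plain enums instead of the real TypeId machinery:

#include <iostream>
#include <string>
#include <vector>

enum class TypeId { kFloat16, kFloat32, kUnknown };

struct Input {
  std::string name;
  TypeId origin;  // type inferred on the host graph
  TypeId device;  // type the selected kernel expects
};

// Return the inputs that need a Cast spliced in front of them.
std::vector<std::string> CastsNeeded(const std::vector<Input> &inputs) {
  std::vector<std::string> casts;
  for (const auto &in : inputs) {
    if (in.origin != in.device) casts.push_back(in.name);  // mismatch -> insert Cast
  }
  return casts;
}

int main() {
  std::vector<Input> inputs = {
      {"x", TypeId::kFloat32, TypeId::kFloat16},   // mixed precision: cast
      {"w", TypeId::kFloat16, TypeId::kFloat16}};  // already matches: no cast
  for (const auto &n : CastsNeeded(inputs)) std::cout << "insert Cast before " << n << "\n";
}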
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ - -#include -#include -#include -#include "device/ascend/kernel_select_ascend.h" -#include "kernel/kernel_query.h" -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -class KernelSelect { - public: - KernelSelect() = default; - virtual ~KernelSelect() = default; - virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } -}; -using KernelSelectPtr = std::shared_ptr; - -class SupportedChecker { - public: - SupportedChecker() = default; - virtual ~SupportedChecker() = default; - virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); - } - virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) { - return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); - } -}; -using SupportedCheckerPtr = std::shared_ptr; - -class KernelQuery { - public: - KernelQuery() = default; - virtual ~KernelQuery() = default; - virtual void Query(const CNodePtr &kernel_node, - std::vector> *kernel_info_list) { - kernel::KernelQuery(kernel_node, kernel_info_list); - } - virtual bool IsTbeRef(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); - if (op_info != nullptr) { - return op_info->is_ref(); - } - return false; - } -}; -using KernelQueryPtr = std::shared_ptr; - -class OpFinder { - public: - OpFinder() = default; - virtual ~OpFinder() = default; - virtual int GetOpRegisteredOutputNum(const std::string &op_name) { - auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); - if (op_info == nullptr) { - return -1; - } - return op_info->outputs_ptr().size(); - } -}; -using OpFinderPtr = std::shared_ptr; - -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); - -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name); - -AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, - const TypeId &input_type, const TypeId &output_type, - const std::vector &origin_shape, const TypeId &origin_type); - -AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select); - -AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select); - -CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); - -AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node); 
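KernelSelect, SupportedChecker, KernelQuery, and OpFinder above are thin classes with one virtual method each, and the passes take them as shared pointers; that gives unit tests a seam to substitute fakes instead of touching the real Ascend kernel selector (the "special handle for ut" branches in the .cc file point the same way). The seam in miniature — FakeSelect here is a test double, not real code:

#include <iostream>
#include <memory>
#include <string>

// Production seam: virtual by design so tests can override it.
class KernelSelect {
 public:
  virtual ~KernelSelect() = default;
  virtual std::string Select(const std::string &node) { return "real kernel for " + node; }
};

class FakeSelect : public KernelSelect {
 public:
  std::string Select(const std::string &node) override { return "stub kernel for " + node; }
};

// A pass depends only on the interface, never on the concrete selector.
void RunPass(const std::shared_ptr<KernelSelect> &sel) {
  std::cout << sel->Select("TransData") << "\n";
}

int main() {
  RunPass(std::make_shared<KernelSelect>());  // production
  RunPass(std::make_shared<FakeSelect>());    // unit test
}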
-} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ASCEND_HELPER_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc deleted file mode 100644 index 94318d63ca..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(relu_input); - auto add = relu_input->cast(); - MS_EXCEPTION_IF_NULL(add); - auto tuple_getitem = add->input(1); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->isa() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) { - auto getitem = tuple_getitem->cast(); - MS_EXCEPTION_IF_NULL(getitem); - auto bnupdate = getitem->input(1); - MS_EXCEPTION_IF_NULL(bnupdate); - if (bnupdate->isa() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) { - std::vector output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0); - for (auto out_getitem : manager->node_users()[bnupdate]) { - MS_EXCEPTION_IF_NULL(out_getitem.first); - auto out_getitem_ptr = out_getitem.first->cast(); - MS_EXCEPTION_IF_NULL(out_getitem_ptr); - auto input2 = out_getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size()); - } - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); - std::unordered_set record{cnode, relu_input, bnupdate}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - 
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-      auto eltwise_input = cnode->input(1);
-      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
-        MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
-      }
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h
deleted file mode 100644
index 6cdc5885f6..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass {
- public:
-  explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {}
-  ~BnupdateEltwiseEltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
-                            const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_ELTWISE_FUSION_PASS_H_
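Aside: the passes deleted in this patch are all constructed with a shared FusionIdAllocator and registered on an optimizer pass manager elsewhere in the backend. A minimal sketch of that wiring follows; the function name and the "fusion_pm" pass-manager name are illustrative assumptions, not lines taken from this patch.

// Sketch only (assumes the mindspore::opt headers above are in scope).
#include <memory>

std::shared_ptr<PassManager> MakeUbFusionPm() {
  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
  auto fusion_pm = std::make_shared<PassManager>("fusion_pm");
  fusion_pm->AddPass(std::make_shared<BnupdateEltwiseEltwiseFusionPass>(fusion_id_allocator));
  return fusion_pm;
}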
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
deleted file mode 100644
index 1f7fef9e62..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
-                                                  const session::KernelGraph &kernel_graph,
-                                                  FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  MS_EXCEPTION_IF_NULL(relu_input);
-  auto getitem = relu_input->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(getitem);
-  auto bnupdate = getitem->input(1);
-  MS_EXCEPTION_IF_NULL(bnupdate);
-  if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
-    std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
-    for (auto out_getitem : manager->node_users()[bnupdate]) {
-      MS_EXCEPTION_IF_NULL(out_getitem.first);
-      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(out_getitem_ptr);
-      auto input2 = out_getitem_ptr->input(2);
-      auto output_idx = GetValue<int>(GetValueNode(input2));
-      output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
-    }
-    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
-    std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                         FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-      auto eltwise_input = cnode->input(1);
-      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
-        MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
-      }
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h
deleted file mode 100644
index b5688f3a36..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class BnupdateEltwiseFusionPass : public FusionBasePass {
- public:
-  explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {}
-  ~BnupdateEltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
-                         FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc
deleted file mode 100644
index 6091eb572d..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise(
-  const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::unordered_set<AnfNodePtr> record{cnode};
-  auto eltwise_input = cnode->input(1);
-  MS_EXCEPTION_IF_NULL(eltwise_input);
-  if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
-    (void)record.insert(eltwise_input);
-  } else {
-    return;
-  }
-  auto input_cnode = eltwise_input->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(input_cnode);
-  auto double_in_eltwise_input = input_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
-  if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
-      fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
-    return;
-  }
-  if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) {
-    (void)record.insert(double_in_eltwise_input);
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
-        (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) {
-      MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h
deleted file mode 100644
index 7d779d35f8..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass {
- public:
-  explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {}
-  ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                              FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_
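For orientation: the pass above records a three-node chain, Conv2DBackpropInput feeding a two-input elementwise node that feeds a single-input elementwise node, and every pass in this family represents each candidate the same way, as one set of nodes appended to a vector of sets. A stub-typed sketch of that shared bookkeeping (not the MindSpore API; StubNode stands in for AnfNode):

#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

// Each fusion candidate is one unordered_set of graph nodes; the pass
// output is a vector of such sets (cf. FusedNodeRecord above).
struct StubNode {};
using StubNodePtr = std::shared_ptr<StubNode>;
using StubRecord = std::vector<std::unordered_set<StubNodePtr>>;

void RecordCandidate(StubRecord *candidates, std::unordered_set<StubNodePtr> record) {
  // The real passes additionally stamp every node in the record with a
  // shared fusion id via SetRecordFusionId(record).
  candidates->push_back(std::move(record));
}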
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc
deleted file mode 100644
index 963f1885fe..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,70 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode,
-                                                                      const session::KernelGraph &kernel_graph,
-                                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::unordered_set<AnfNodePtr> record{cnode};
-  auto eltwise_input = cnode->input(1);
-  MS_EXCEPTION_IF_NULL(eltwise_input);
-  if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
-      fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
-    return;
-  }
-  if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) {
-    (void)record.insert(eltwise_input);
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                               FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
-        (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) {
-      MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h
deleted file mode 100644
index 171352de9b..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class Conv2DBackpropEltwiseFusionPass : public FusionBasePass {
- public:
-  explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {}
-  ~Conv2DBackpropEltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                       FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc
deleted file mode 100644
index 63e7dcf6b8..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.cc
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h"
-
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                               FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto conv = cnode->input(1);
-  MS_EXCEPTION_IF_NULL(conv);
-  if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
-    std::vector<int> output_used_num{SizeToInt(manager->node_users()[conv].size())};
-    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
-    std::unordered_set<AnfNodePtr> record{cnode, conv};
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
-      MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h
deleted file mode 100644
index 7a06faa624..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_CONV_BNREDUCE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class ConvBnReduceFusionPass : public FusionBasePass {
- public:
-  explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {}
-  ~ConvBnReduceFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                         FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_CONV_BNREDUCE_FUSION_PASS_H_
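One detail worth calling out in ConvBnReduceFusionPass above (the BNTrainingUpdate passes do the same): before a producer is fused, the pass stamps it with kAttrOutputUsedNum, the number of graph users of each of its outputs, so the generated fused kernel knows which results remain externally visible. A plain-integer sketch of the single-output case, with no graph types involved:

#include <cstddef>
#include <vector>

// Single-output producer (as in MatchConvBnreduce): the attribute is just
// the user count wrapped in a one-element vector. Multi-output producers
// instead fill one slot per output index, keyed by each TupleGetItem.
std::vector<int> SingleOutputUsedNum(std::size_t user_count) {
  return {static_cast<int>(user_count)};
}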
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc
deleted file mode 100644
index a126143811..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::unordered_set<AnfNodePtr> record{cnode};
-  auto eltwise_input = cnode->input(1);
-  MS_EXCEPTION_IF_NULL(eltwise_input);
-  if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) {
-    (void)record.insert(eltwise_input);
-  } else {
-    return;
-  }
-  auto input_cnode = eltwise_input->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(input_cnode);
-  auto double_in_eltwise_input = input_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
-  if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
-      fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
-    return;
-  }
-  if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
-      AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) {
-    (void)record.insert(double_in_eltwise_input);
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
-      MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h
deleted file mode 100644
index 062b8182fb..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class ConvDoubleInFusionPass : public FusionBasePass {
- public:
-  explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {}
-  ~ConvDoubleInFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_DOUBLE_IN_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc
deleted file mode 100644
index d83b32a888..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::unordered_set<AnfNodePtr> record{cnode};
-  auto eltwise_input = cnode->input(1);
-  while (CheckEltWiseNode(manager.get(), eltwise_input)) {
-    (void)record.insert(eltwise_input);
-    auto input_cnode = eltwise_input->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(input_cnode);
-    eltwise_input = input_cnode->input(1);
-    if (record.size() == MAX_ELTWISE_NUM) {
-      break;
-    }
-  }
-  MS_EXCEPTION_IF_NULL(eltwise_input);
-  if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
-      fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
-    return;
-  }
-  if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
-      AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) {
-    (void)record.insert(eltwise_input);
-    candidate_fusion->push_back(record);
-    SetRecordFusionId(record);
-  }
-}
-
-void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                      FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
-      MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h
deleted file mode 100644
index bf7e581dff..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class ConvSingleInFusionPass : public FusionBasePass {
- public:
-  explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("ConvSingleInFusionPass", idAllocator) {}
-  ~ConvSingleInFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchConvSingleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV_SINGLE_IN_FUSION_PASS_H_
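ConvSingleInFusionPass above (and EltwiseFusionPass further down) both walk a chain of single-input elementwise nodes through input(1), capped at a fixed node count. Note in passing that kernel::FusionType::CONVLUTION is the enum's actual spelling in this codebase, so it is kept verbatim in the deleted code. A self-contained sketch of the capped walk, with stub types in place of CNodePtr and the real predicate:

#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_set>

// Stub node: "input" stands in for cnode->input(1).
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  NodePtr input;
};

constexpr std::size_t kMaxEltwiseNum = 3;  // mirrors MAX_ELTWISE_NUM above

// Follow input(1) while the predicate holds, collecting at most
// kMaxEltwiseNum nodes, exactly as the while-loop in the pass does.
std::unordered_set<NodePtr> WalkEltwiseChain(NodePtr start, const std::function<bool(const NodePtr &)> &is_eltwise) {
  std::unordered_set<NodePtr> record;
  for (NodePtr cur = start; cur != nullptr && is_eltwise(cur); cur = cur->input) {
    record.insert(cur);
    if (record.size() == kMaxEltwiseNum) {
      break;
    }
  }
  return record;
}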
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc
deleted file mode 100644
index 98a6838bed..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,86 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
-
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode,
-                                                            const session::KernelGraph &kernel_graph,
-                                                            FusedNodeRecord *candidate_fusion, bool is_order) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  if (is_order) {
-    // DepthwiseConvolution--->Elemwise
-    auto depthwise_conv = cnode->input(1);
-    MS_EXCEPTION_IF_NULL(depthwise_conv);
-    if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
-      std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
-      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
-      std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
-      candidate_fusion->push_back(record);
-      SetRecordFusionId(record);
-    }
-  } else {
-    // Elemwise-->DepthwiseConvolution
-    auto relu = cnode->input(1);
-    MS_EXCEPTION_IF_NULL(relu);
-    if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
-      std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
-      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
-      std::unordered_set<AnfNodePtr> record{cnode, relu};
-      candidate_fusion->push_back(record);
-      SetRecordFusionId(record);
-    }
-  }
-}
-
-void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                              FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-      auto eltwise_input = cnode->input(1);
-      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
-        MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
-      }
-    } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
-      MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h
deleted file mode 100644
index c2e72f26ff..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class DepthwiseConvEltwiseFusionPass : public FusionBasePass {
- public:
-  explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
-      : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {}
-  ~DepthwiseConvEltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                              FusedNodeRecord *candidate_fusion, bool is_order);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc
deleted file mode 100644
index 2f04e16692..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <algorithm>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
-                                     FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::unordered_set<AnfNodePtr> record{cnode};
-  auto eltwise_input = cnode->input(1);
-  MS_EXCEPTION_IF_NULL(eltwise_input);
-  while (CheckEltWiseNode(manager.get(), eltwise_input)) {
-    (void)record.insert(eltwise_input);
-    if (record.size() == MAX_ELTWISE_SIZE) {
-      break;
-    }
-    auto input_cnode = eltwise_input->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(input_cnode);
-    eltwise_input = input_cnode->input(1);
-  }
-  if (record.size() < MIN_ELTWISE_SIZE) {
-    return;
-  }
-  candidate_fusion->push_back(record);
-  SetRecordFusionId(record);
-}
-
-void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                 FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  std::reverse(node_list.begin(), node_list.end());
-  for (auto &node : node_list) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
-      MatchEltwise(cnode, kernel_graph, candidate_fusion);
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h
deleted file mode 100644
index 54ff0f5982..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h
+++ /dev/null
@@ -1,46 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_
-
-#include <unordered_set>
-#include <vector>
-
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-class EltwiseFusionPass : public FusionBasePass {
- public:
-  explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {}
-  ~EltwiseFusionPass() override = default;
-  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
-
- private:
-  void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_ELTWISE_FUSION_PASS_H_
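EltwiseFusionPass above and MultiOutputFusionPass below differ from the other passes in one respect: they reverse the topological order before scanning, so matching starts at nodes closest to the graph output and walks producers through input(1). A trivial, self-contained illustration of that worklist trick (a plain int vector stands in for TopoSort's node list):

#include <algorithm>
#include <vector>

// Reverse a topologically sorted worklist so the scan visits consumers
// first; each match then grows its record toward the producers.
std::vector<int> ReverseForMatching(std::vector<int> topo_order) {
  std::reverse(topo_order.begin(), topo_order.end());
  return topo_order;
}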
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc
deleted file mode 100644
index a516f04442..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h"
-#include <memory>
-#include <unordered_set>
-#include "debug/anf_ir_dump.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(manager);
-  MS_EXCEPTION_IF_NULL(node);
-  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto user_nodes = manager->node_users()[node];
-  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
-         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE &&
-         cnode->inputs().size() == ELTWISE_INPUT_SIZE;
-}
-
-bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(manager);
-  MS_EXCEPTION_IF_NULL(node);
-  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto user_nodes = manager->node_users()[node];
-  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
-         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE &&
-         cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
-}
-
-bool FusionBasePass::CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(manager);
-  MS_EXCEPTION_IF_NULL(node);
-  if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto user_nodes = manager->node_users()[node];
-  return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
-         AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_MULTI_USE &&
-         cnode->inputs().size() == ELTWISE_INPUT_SIZE;
-}
-
-void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
-  auto id = fusion_id_allocator->AllocateFusionId();
-  for (auto node : record) {
-    fusion_id_allocator->SetFusionId(node, id);
-  }
-}
-
-bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto return_node = kernel_graph.get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() <= 1) {
-    return false;
-  }
-  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
-  FusedNodeRecord candidate_fusion;
-  MatchSingleFusionPattern(kernel_graph, &candidate_fusion);
-  if (candidate_fusion.empty()) {
-    return false;
-  }
-  MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
-  return true;
-}
-
-bool FusionBasePass::Run(const FuncGraphPtr &graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  return MatchUBFusionPattern(*kernel_graph);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h
deleted file mode 100644
index 8d6eca774c..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
-#include <unordered_set>
-#include <vector>
-#include <string>
-#include <memory>
-
-#include "ir/anf.h"
-#include "pre_activate/common/pass.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/kernel.h"
-#include "session/kernel_graph.h"
-
-namespace mindspore {
-namespace opt {
-const int8_t MAX_ELTWISE_NUM = 3;
-const int8_t MIN_ELTWISE_SIZE = 2;
-const int8_t ELTWISE_INPUT_SIZE = 2;
-const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
-const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
-const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
-const int8_t ELTWISE_USE = 1;
-const int8_t ELTWISE_MULTI_USE = 2;
-const int8_t MAX_ELTWISE_SIZE = 6;
-const int8_t MULTI_ELTWISE_SIZE = 4;
-using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
-
-struct BufferFusionInfo_t {
-  std::vector<AnfNodePtr> anf_nodes;
-  std::vector<AnfNodePtr> inputs_list;
-  std::vector<AnfNodePtr> outputs_list;
-  kernel::KernelBuildInfoPtr kernel_build_info;
-};
-
-class FusionBasePass : public Pass {
- public:
-  FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator)
-      : Pass(name), fusion_id_allocator(idAllocator) {}
-  ~FusionBasePass() override = default;
-  bool Run(const FuncGraphPtr &graph) override;
-  bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph);
-
- protected:
-  virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                        FusedNodeRecord *candidate_fusion) = 0;
-  void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
-  bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
-  bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
-  bool CheckMultiOutputEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node);
-  FusionIdAllocatorPtr fusion_id_allocator;
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_
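FusionBasePass, deleted above, is the template-method base of every pass in this directory: Run casts the graph to a KernelGraph and calls MatchUBFusionPattern, which defers to the subclass's MatchSingleFusionPattern. A hypothetical minimal subclass, shown only to make the contract explicit (the class name and body are illustrative and presume the header above):

// Sketch against the fusion_base_pass.h declarations above; "MyFusionPass"
// does not exist in the codebase.
class MyFusionPass : public FusionBasePass {
 public:
  explicit MyFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("MyFusionPass", idAllocator) {}
  ~MyFusionPass() override = default;

 protected:
  void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
                                FusedNodeRecord *candidate_fusion) override {
    // Typical shape: walk TopoSort(kernel_graph.get_return()), collect an
    // std::unordered_set<AnfNodePtr> record per candidate, push it onto
    // *candidate_fusion, then stamp it with SetRecordFusionId(record).
  }
};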
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc
deleted file mode 100644
index d1ef5dc83b..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h"
-#include <vector>
-#include <unordered_set>
-#include <memory>
-#include <string>
-#include "kernel/kernel_fusion.h"
-#include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/fusion_id_allocator.h"
-
-namespace mindspore {
-namespace opt {
-void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input,
-                                                 const session::KernelGraph &kernel_graph,
-                                                 FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  auto manager = kernel_graph.manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu_input].size())};
-  AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
-  std::unordered_set<AnfNodePtr> record{cnode, relu_input};
-  candidate_fusion->push_back(record);
-  SetRecordFusionId(record);
-}
-
-void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
-                                                       FusedNodeRecord *candidate_fusion) {
-  MS_EXCEPTION_IF_NULL(candidate_fusion);
-  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
-  for (auto &node : node_list) {
-    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
-        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
-      continue;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
-        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
-      auto eltwise_input = cnode->input(1);
-      MS_EXCEPTION_IF_NULL(eltwise_input);
-      if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimMatMul)) {
-        MatchMatmulEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
-      }
-    }
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h
deleted file mode 100644
index 5baaa6db86..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class MatmulEltwiseFusionPass : public FusionBasePass { - public: - explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} - ~MatmulEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchMatmulEltwise(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc deleted file mode 100644 index be4d2af1cb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - MS_EXCEPTION_IF_NULL(eltwise_input); - if (CheckMultiOutputEltWiseNode(manager.get(), eltwise_input)) { - std::vector output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())}; - AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input); - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - } else { - return; - } - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - if (record.size() == MULTI_ELTWISE_SIZE) { - break; - } - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - } - if (record.size() != MULTI_ELTWISE_SIZE) { - return; - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); -} - -void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h deleted file mode 100644 index 0e2510128a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/multi_output_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class MultiOutputFusionPass : public FusionBasePass { - public: - explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("MultiOutputFusionPass", idAllocator) {} - ~MultiOutputFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchMultiOutputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_MULTI_OUTPUT_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc deleted file mode 100644 index 623f0e3426..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.cc +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
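Note: before a node is absorbed, these passes ask the graph manager how many consumers its output has and store the count in kAttrOutputUsedNum; the fusion build later uses it to decide which intermediate results must stay visible outside the fused kernel. A tiny illustrative version of that consumer count (not the MindSpore API):

#include <cstddef>
#include <map>
#include <set>
#include <string>

// Count the consumers of a node the way the passes query manager->node_users()
// before setting kAttrOutputUsedNum on the absorbed node.
using Users = std::map<std::string, std::set<std::string>>;  // node -> consuming nodes

std::size_t OutputUsedNum(const Users &users, const std::string &node) {
  auto it = users.find(node);
  return it == users.end() ? 0 : it->second.size();
}

int main() {
  Users users{{"relu", {"matmul_consumer", "debug_print"}}};
  // A count > 1 means "relu"'s result escapes the fused scope, so the generated
  // FusionOp must still materialize it as a real output.
  return OutputUsedNum(users, "relu") == 2 ? 0 : 1;
}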
- */ -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - if (record.size() == MAX_ELTWISE_NUM) { - break; - } - } - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE) { - (void)record.insert(eltwise_input); - auto previous_input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_input_cnode); - auto previous_eltwise_input = previous_input_cnode->input(1); - auto previous_size = record.size(); - while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { - (void)record.insert(previous_eltwise_input); - auto previous_node = previous_eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_node); - previous_eltwise_input = previous_node->input(1); - if (record.size() - previous_size == MAX_ELTWISE_NUM) { - break; - } - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchReduceEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h deleted file mode 100644 index 42d896e96b..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class ReduceEltwiseFusionPass : public FusionBasePass { - public: - explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} - ~ReduceEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchReduceEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_REDUCE_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc deleted file mode 100644 index 0dcf2362bc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
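Note: ReduceEltwiseFusionPass above matches up to MAX_ELTWISE_NUM elementwise producers, then requires one COMMREDUCE anchor, then takes up to MAX_ELTWISE_NUM further elementwise producers; SegmentEltwiseFusionPass below is the same shape with a SEGMENT anchor. A condensed sketch of the two-phase walk (simplified types, not the MindSpore API):

#include <memory>
#include <unordered_set>

enum class Kind { kEltwise, kReduce, kOther };

struct Node {
  Kind kind = Kind::kOther;
  std::shared_ptr<Node> input;  // stands in for cnode->input(1)
};
using NodePtr = std::shared_ptr<Node>;

constexpr size_t kMaxEltwiseNum = 3;  // mirrors MAX_ELTWISE_NUM

bool MatchEltwiseReduceEltwise(NodePtr cur, std::unordered_set<NodePtr> *record) {
  size_t taken = 0;
  while (cur && cur->kind == Kind::kEltwise && taken < kMaxEltwiseNum) {  // phase 1
    record->insert(cur);
    cur = cur->input;
    ++taken;
  }
  if (!cur || cur->kind != Kind::kReduce) return false;  // the anchor op is mandatory
  record->insert(cur);
  cur = cur->input;
  taken = 0;
  while (cur && cur->kind == Kind::kEltwise && taken < kMaxEltwiseNum) {  // phase 2
    record->insert(cur);
    cur = cur->input;
    ++taken;
  }
  return true;
}

int main() {
  auto reduce = std::make_shared<Node>(); reduce->kind = Kind::kReduce;
  auto e1 = std::make_shared<Node>(); e1->kind = Kind::kEltwise; e1->input = reduce;
  std::unordered_set<NodePtr> record;
  return MatchEltwiseReduceEltwise(e1, &record) && record.size() == 2 ? 0 : 1;
}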
- */ -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto eltwise_input = cnode->input(1); - while (CheckEltWiseNode(manager.get(), eltwise_input)) { - (void)record.insert(eltwise_input); - auto input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - eltwise_input = input_cnode->input(1); - if (record.size() == MAX_ELTWISE_NUM) { - break; - } - } - MS_EXCEPTION_IF_NULL(eltwise_input); - if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || - fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { - return; - } - if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) { - (void)record.insert(eltwise_input); - auto previous_input_cnode = eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_input_cnode); - auto previous_eltwise_input = previous_input_cnode->input(1); - auto previous_size = record.size(); - while (CheckEltWiseNode(manager.get(), previous_eltwise_input)) { - (void)record.insert(previous_eltwise_input); - auto previous_node = previous_eltwise_input->cast(); - MS_EXCEPTION_IF_NULL(previous_node); - previous_eltwise_input = previous_node->input(1); - if (record.size() - previous_size == MAX_ELTWISE_NUM) { - break; - } - } - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - std::reverse(node_list.begin(), node_list.end()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) { - MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h deleted file mode 100644 index 41f06ba1f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWISE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class SegmentEltwiseFusionPass : public FusionBasePass { - public: - explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} - ~SegmentEltwiseFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchSegmentEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_SEGMENT_ELTWSIE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc deleted file mode 100644 index 5bc0fdced7..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h" - -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(const CNodePtr &cnode, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - std::unordered_set record{cnode}; - auto write_input = cnode->input(1); - if (CheckEltWiseNode(manager.get(), write_input)) { - (void)record.insert(write_input); - auto input_cnode = write_input->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - write_input = input_cnode->input(1); - } - MS_EXCEPTION_IF_NULL(write_input); - if (!write_input->isa() || !AnfAlgo::IsRealCNodeKernel(write_input) || - fusion_id_allocator->HasFusionIdAttr(write_input)) { - return; - } - auto conv_cnode = write_input->cast(); - MS_EXCEPTION_IF_NULL(conv_cnode); - if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONVLUTION && - conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE && - conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) { - (void)record.insert(write_input); - auto conv_input = conv_cnode->input(1); - MS_EXCEPTION_IF_NULL(conv_input); - if (!conv_input->isa() || !AnfAlgo::IsRealCNodeKernel(conv_input) || - fusion_id_allocator->HasFusionIdAttr(conv_input)) { - return; - } - if (AnfAlgo::GetCNodeName(conv_input) == kStridedReadOpName) { - (void)record.insert(conv_input); - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } - } -} - -void StridedReadConvStridedWriteFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - std::vector node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == kStridedWriteOpName) { - MatchStridedReadConvStridedWrite(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h deleted file mode 100644 index c6c5fe88dc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/stridedread_conv_stridedwrite_fusion_pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ - -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class StridedReadConvStridedWriteFusionPass : public FusionBasePass { - public: - explicit StridedReadConvStridedWriteFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("StridedReadConvStridedWriteFusionPass", idAllocator) {} - ~StridedReadConvStridedWriteFusionPass() override = default; - void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; - - private: - void MatchStridedReadConvStridedWrite(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_STRIDEDREAD_CONV_STRIDEDWRITE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc deleted file mode 100644 index faa5169c40..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.cc +++ /dev/null @@ -1,448 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
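Note: the UbPatternFusion driver deleted below consumes the fusion ids stamped by the matcher passes: it walks the kernel graph in topological order and buckets every node carrying a fusion_id attribute into one BufferFusionInfo per id. An illustrative reduction of that grouping step (hypothetical FakeNode type, not the MindSpore API):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct FakeNode {
  std::string name;
  int32_t fusion_id = -1;  // -1: node was never claimed by a matcher pass
};

// Bucket nodes by fusion id, preserving topological order within each bucket,
// in the spirit of GetFusionScopeComputeNodeList below.
std::map<int32_t, std::vector<std::string>> GroupByFusionId(const std::vector<FakeNode> &topo) {
  std::map<int32_t, std::vector<std::string>> scopes;
  for (const auto &n : topo) {
    if (n.fusion_id >= 0) scopes[n.fusion_id].push_back(n.name);
  }
  return scopes;
}

int main() {
  auto scopes = GroupByFusionId({{"matmul", 1}, {"relu", 1}, {"cast", -1}, {"add", 2}});
  return scopes.size() == 2 && scopes[1].size() == 2 ? 0 : 1;
}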
- */ -#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "kernel/kernel_fusion.h" -#include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -namespace { -const int8_t MAX_PATTERN_SIZE = 7; -const int8_t MIN_PATTERN_SIZE = 2; -const int8_t ELTWISE_INPUT_SIZE = 2; -const int8_t ELTWISE_USE = 1; -const int8_t MULTI_ELTWISE_USE = 2; -const int8_t MAX_MULTI_ELTWISE_SIZE = 4; -const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; -constexpr auto kOpAttrFusionId = "fusion_id"; - -#ifdef DEBUG -std::string GetFusionTypeName(const kernel::FusionType &type) { - switch (type) { - case kernel::FusionType::COMMREDUCE: - return "COMMREDUCE"; - case kernel::FusionType::SEGMENT: - return "SEGMENT"; - case kernel::FusionType::ELEMWISE: - return "ELEMWISE"; - case kernel::FusionType::CONVLUTION: - return "CONVLUTION"; - case kernel::FusionType::OPAQUE: - return "OPAQUE"; - default: - return "OPAQUE"; - } -} - -void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { - MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; - for (auto &node : info.input_nodes) { - MS_LOG(INFO) << "=== Input: " << node->DebugString(); - } - for (auto &node : info.output_nodes) { - MS_LOG(INFO) << "=== Output: " << node->DebugString(); - } - for (auto &node : info.compute_nodes) { - MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) - << ")"; - } - MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; -} -#endif -CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, - const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { - MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; - MS_EXCEPTION_IF_NULL(kernel_graph); - std::string fusion_op_name = "FusionOp"; - for (auto node : anf_nodes) { - fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); - } - auto fusion_op = std::make_shared(fusion_op_name); - MS_EXCEPTION_IF_NULL(fusion_op); - - std::vector input_names; - for (uint8_t i = 0; i < inputs_list.size(); i++) { - input_names.emplace_back("input" + std::to_string(i)); - } - std::vector output_names; - for (uint8_t i = 0; i < outputs_list.size(); i++) { - output_names.emplace_back("output" + std::to_string(i)); - } - - ValuePtr input_names_v = MakeValue(input_names); - ValuePtr output_names_v = MakeValue(output_names); - fusion_op->set_attr("input_names", input_names_v); - fusion_op->set_attr("output_names", output_names_v); - std::vector fusion_inputs_list = inputs_list; - auto value_node = std::make_shared(fusion_op); - (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); - auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); - if (buffer_fusion_kernel == nullptr) { - MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; - } - buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); - - return buffer_fusion_kernel; -} - -kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, - const std::vector &outputs_list) { - MS_LOG(DEBUG) << "Start Create Kernel Info"; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - // inputs format and data type - std::vector inputs_format; - std::vector inputs_data_type; - for (const auto &input : inputs_list) { - auto real_input = 
AnfAlgo::VisitKernel(input, 0); - inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); - inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); - } - // outputs format and data type - std::vector outputs_format; - std::vector outputs_data_type; - for (const auto &output : outputs_list) { - if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto tuple_getitem = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - outputs_format.push_back(AnfAlgo::GetOutputFormat( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( - tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); - } else { - outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); - outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); - } - } - builder.SetInputsFormat(inputs_format); - builder.SetInputsDeviceType(inputs_data_type); - builder.SetOutputsFormat(outputs_format); - builder.SetOutputsDeviceType(outputs_data_type); - builder.SetKernelType(KernelType::TBE_KERNEL); - return builder.Build(); -} - -AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, - size_t output_index) { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector tuple_getitem_inputs_list; - auto value = std::make_shared(prim::kPrimTupleGetItem); - MS_EXCEPTION_IF_NULL(value); - auto idx = NewValueNode(SizeToInt(output_index)); - MS_EXCEPTION_IF_NULL(idx); - int temp = SizeToInt(output_index); - auto imm = std::make_shared(temp); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - tuple_getitem_inputs_list.push_back(value); - tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); - tuple_getitem_inputs_list.push_back(idx); - auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); - MS_EXCEPTION_IF_NULL(tuple_item); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, - {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, - tuple_item.get()); - return tuple_item; -} - -void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const AnfNodePtr &output_item, - const AnfNodePtr &replace_item) { - for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { - auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), - output_item); - if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { - MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; - *itr = replace_item; - } - } -} - -void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - if (buffer_fusion_info.outputs_list.size() == 1) { // single output - (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], - buffer_fusion_kernel); - } else { // multiple output - for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) 
{ - auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); - (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); - ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], - tuple_item); - } - } -} - -void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto nodes = TopoSort(kernel_graph->get_return()); - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { - auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); - (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); - } - } -} - -void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph.manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { - auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == - fusion_info.anf_nodes.end()) { - if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), - (*buffer_fusion_infos)[fusion_id].inputs_list.end(), - cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { - (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); - } - } - } - } - } -} - -bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - auto getitem1 = node1->cast(); - auto getitem2 = node2->cast(); - MS_EXCEPTION_IF_NULL(getitem1); - MS_EXCEPTION_IF_NULL(getitem2); - if (getitem1->size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" - << getitem1->DebugString() << "]"; - } - if (getitem2->size() < kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "node's input size less than " << kTupleGetItemInputSize << ", getitem1[" - << getitem2->DebugString() << "]"; - } - auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); - auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); - return output_idx1 < output_idx2; -} - -void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - auto fusion_id = buffer_fusion_info.first; - auto fusion_info = buffer_fusion_info.second; - for (const auto &node : fusion_info.anf_nodes) { - if (AnfAlgo::GetOutputTensorNum(node) == 1) { - for (auto use_node : manager->node_users()[node]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == - fusion_info.anf_nodes.end()) { - 
(*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); - break; - } - } - } else { - int prev_idx = 0; - std::vector tuple_getitem_nodes; - std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), - std::back_inserter(tuple_getitem_nodes), - [](const std::pair &use_node) { return use_node.first; }); - std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); - for (auto getitem : tuple_getitem_nodes) { - MS_EXCEPTION_IF_NULL(getitem); - auto getitem_ptr = getitem->cast(); - auto input2 = getitem_ptr->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { - auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); - } - prev_idx = output_idx + 1; - for (auto item_use_node : manager->node_users()[getitem]) { - if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == - fusion_info.anf_nodes.end()) { - (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); - break; - } - } - } - } - } - } -} - -void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, - const AnfNodePtr &fusion_kernel) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto manager = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (size_t idx = 0; idx < outputs_list.size(); ++idx) { - auto output = outputs_list[idx]; - MS_EXCEPTION_IF_NULL(output); - if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { - auto real_output = AnfAlgo::VisitKernel(output, 0); - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto input2 = output_cnode->input(2); - auto output_idx = GetValue(GetValueNode(input2)); - session::AnfWithOutIndex out_pair(real_output.first, output_idx); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } else { - session::AnfWithOutIndex out_pair(output, 0); - if (kernel_graph->IsInRefOutputMap(out_pair)) { - auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); - session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); - kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); - } - } - } -} -} // namespace - -void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); - GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); - GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); - for (auto &buffer_fusion_info : *buffer_fusion_infos) { - buffer_fusion_info.second.kernel_build_info = - CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); - } -} - -bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - bool change = false; - std::unordered_map buffer_fusion_infos; - buffer_fusion_infos.clear(); - GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); - - std::vector fusion_scope_infos; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - 
mindspore::kernel::FusionScopeInfo fusion_scope_info; - fusion_scope_info.scope_id = buffer_fusion_info.first; - fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; - fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; - fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; - fusion_scope_infos.push_back(fusion_scope_info); -#ifdef DEBUG - DumpFusionScopeInfo(fusion_scope_info); -#endif - } - auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); - std::vector fusion_ids; - for (auto &buffer_fusion_info : buffer_fusion_infos) { - MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() - << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() - << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); - fusion_ids.push_back(buffer_fusion_info.first); - } - // Replace fusion op from return to head - std::sort(fusion_ids.begin(), fusion_ids.end()); - for (auto &fusion_id : fusion_ids) { - // Get kernel mod when supporting tbe - if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { - MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; - continue; - } - change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); - } - MS_LOG(DEBUG) << "End Buffer Fusion"; - return change; -} - -bool UbPatternFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, - int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, - session::KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; - auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, - buffer_fusion_info.anf_nodes, kernel_graph); - AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); - // Set abstract of fusion_op node - std::vector types; - std::vector> shapes; - for (const auto &out_node : buffer_fusion_info.outputs_list) { - for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { - types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); - shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); - } - } - if (types.empty() || shapes.empty()) { - MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; - return false; - } - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); - AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); - SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); - ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); - return true; -} - -bool UbPatternFusion::Run(const FuncGraphPtr &graph) { - bool changed = false; - MS_EXCEPTION_IF_NULL(graph); - auto kernel_graph = graph->cast>(); - MS_EXCEPTION_IF_NULL(kernel_graph); - changed = FuseBufferFusionPattern(kernel_graph.get()); - // clear fusion_id attr - for (auto &node : graph->nodes()) { - if (node != nullptr && node->isa()) { - AnfAlgo::EraseNodeAttr(kAttrFusionId, node); - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h deleted file mode 100644 index 7099c92772..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h +++ /dev/null @@ -1,50 
+0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ -#include -#include -#include - -#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" -#include "ir/anf.h" -#include "pre_activate/common/pass.h" -#include "pre_activate/common/fusion_id_allocator.h" -#include "device/kernel_info.h" -#include "kernel/kernel.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -using FusedNodeRecord = std::vector>; - -class UbPatternFusion : public Pass { - public: - UbPatternFusion() : Pass("TbeBufferFusion") {} - ~UbPatternFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - void GetBufferFusionInfo(session::KernelGraph *kernel_graph, - std::unordered_map *buffer_fusion_infos) const; - bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, - const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; - bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_UB_PATTERN_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc deleted file mode 100644 index 6d0906363e..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
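Note: in GetFusionScopeOutputNodeList above, a multi-output node's TupleGetItem consumers are sorted by output index and every unused index below the largest one gets a stub getitem, so the fused kernel's output tuple stays dense. A standalone sketch of that stub-filling logic (illustrative names, not the MindSpore API):

#include <algorithm>
#include <vector>

struct GetItem {
  int index;
  bool stub = false;  // true: no real consumer, created only to keep the tuple dense
};

// Consumers only reference some tuple indices; the fused kernel must still emit
// every index up to the largest one used, so missing slots become stub getitems.
std::vector<GetItem> EmitOutputs(std::vector<int> used_indices) {
  std::sort(used_indices.begin(), used_indices.end());  // TupleGetitemNodeCompare
  std::vector<GetItem> outputs;
  int prev = 0;
  for (int idx : used_indices) {
    for (int stub = prev; stub < idx; ++stub) outputs.push_back({stub, true});
    outputs.push_back({idx, false});
    prev = idx + 1;
  }
  return outputs;
}

int main() {
  auto outs = EmitOutputs({2, 0});                  // index 1 is unused by any consumer
  return outs.size() == 3 && outs[1].stub ? 0 : 1;  // slot 1 becomes a stub
}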
- */ - -#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" - -namespace mindspore::opt { - -const BaseRef GetnextMemcpyElimination::DefinePattern() const { - auto prim_memcpy = std::make_shared(kMemCpyAsyncOpName); - VarPtr x = std::make_shared(); - VectorRef memcpy_async({prim_memcpy, x}); - return memcpy_async; -} - -const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - auto memcpy_cnode = node->cast(); - if (memcpy_cnode == nullptr) { - return nullptr; - } - - // 1. memcpy has attr kAttrLabelForInsertStreamActive - if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) { - MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr"; - return nullptr; - } - - // 2. memcpy's output has only one user next_node - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(memcpy_cnode) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "memcpy has no output in manager"; - } - auto next_nodes = manager->node_users()[memcpy_cnode]; - if (next_nodes.size() > 1) { - MS_LOG(DEBUG) << "node's output has more than one users"; - return nullptr; - } - - // 3. next_node is not nop node and it has only one input which is memcpy's output - for (auto &item : next_nodes) { - auto next_node = item.first->cast(); - if (opt::IsNopNode(next_node)) { - return nullptr; - } - if (next_node->inputs().size() != 2) { - MS_LOG(DEBUG) << "next node has more than one input"; - return nullptr; - } - // add attr label_for_insert_stream_active for next_node - AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), next_node); - } - - return memcpy_cnode->input(1); -} -} // namespace mindspore::opt diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h deleted file mode 100644 index 523fc87a38..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
-
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class GetnextMemcpyElimination : public PatternProcessPass {
- public:
-  explicit GetnextMemcpyElimination(bool multigraph = true)
-      : PatternProcessPass("getnext_memcpy_elimination", multigraph) {}
-  ~GetnextMemcpyElimination() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_GETNEXT_MEMCPY_ELIMINATION_H
diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc
deleted file mode 100644
index 01a3f789e7..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
-#include <vector>
-#include <memory>
-#include "pre_activate/ascend/ascend_helper.h"
-#include "pre_activate/common/helper.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-AnfNodePtr InsertMemcpyAsyncForGetNextOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
-  if (func_graph == nullptr || node == nullptr) {
-    return nullptr;
-  }
-
-  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
-  if (output_num == 0) {
-    MS_LOG(DEBUG) << "Output number is zero, no need to insert memcpy_async!";
-    return node;
-  }
-
-  // getnext output is tuple and dynamic
-  std::vector<AnfNodePtr> make_tuple_inputs;
-  make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
-
-  for (size_t output_index = 0; output_index < output_num; ++output_index) {
-    auto tuple_get_item = CreatTupleGetItemNode(func_graph, node, output_index);
-    auto new_node = CreateMemcpyAsyncOp(func_graph, tuple_get_item);
-    if (new_node == nullptr) {
-      MS_LOG(EXCEPTION) << "Create memcpy_async op failed!";
-    }
-    AnfAlgo::SetNodeAttr(kAttrLabelForInsertStreamActive, MakeValue(true), new_node);
-    make_tuple_inputs.push_back(new_node);
-  }
-  AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs);
-  return make_tuple;
-}
-
-const BaseRef InsertMemcpyAsyncForGetNext::DefinePattern() const {
-  std::shared_ptr<Var> Xs = std::make_shared<Var>();
-  auto prim = std::make_shared<Primitive>(kGetNextOpName);
-
-  return VectorRef({prim, Xs});
-}
-
-const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                      const EquivPtr &) const {
-  if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) {
-    return nullptr;
-  }
-
-  auto cnode = node->cast<CNodePtr>();
-  if
(AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) { - MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited."; - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode); - - return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h deleted file mode 100644 index eb3b78d33f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class InsertMemcpyAsyncForGetNext : public PatternProcessPass { - public: - explicit InsertMemcpyAsyncForGetNext(bool multigraph = true) - : PatternProcessPass("insert_memcpy_async_for_getnext", multigraph) {} - ~InsertMemcpyAsyncForGetNext() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_GETNEXT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc deleted file mode 100644 index 63ea59d744..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
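Note: InsertMemcpyAsyncForHcclOp, whose implementation follows, copies an HCCL operand defensively when the operand is a parameter or value node, comes from an in-place (TBE Ref) kernel or one of a few special Lamb ops, or still has other consumers. A schematic decision table for that check (names abridged and illustrative, not the exact kOpName constants):

#include <set>
#include <string>

struct OperandInfo {
  bool is_parameter_or_value = false;  // graph input or constant
  bool is_tbe_ref = false;             // producing kernel writes its input in place
  std::string op_name;
  int user_count = 1;
};

// Ops that always need a copy in front of the collective (abridged list).
const std::set<std::string> kAlwaysCopyOps = {"LambNextMV", "LambNextMVWithDecay",
                                              "LambUpdateWithLR"};

// Mirrors the shape of NeedInsertMemcpy below: any condition that lets the
// collective observe or clobber shared memory forces a MemCpyAsync.
bool NeedInsertMemcpy(const OperandInfo &in) {
  return in.is_parameter_or_value || in.is_tbe_ref ||
         kAlwaysCopyOps.count(in.op_name) > 0 || in.user_count > 1;
}

int main() {
  OperandInfo weight{/*is_parameter_or_value=*/true};
  return NeedInsertMemcpy(weight) ? 0 : 1;
}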
- */ -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" -#include -#include -#include -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -namespace { -// insert memcpy for some cnode even if not a Ref cnode -const std::set kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, - kLambUpdateWithLROpName}; - -bool IsParameterOrValueNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - return kernel_with_index.first->isa() || kernel_with_index.first->isa(); -} - -void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(hccl_node); - MS_EXCEPTION_IF_NULL(memcpy_async); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(hccl_node); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // find hccl_node's output which is a control depend - for (const auto &node_index : iter->second) { - AnfNodePtr output = node_index.first; - int output_index = node_index.second; - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - CNodePtr control_depend = output->cast(); - MS_EXCEPTION_IF_NULL(control_depend); - std::vector new_inputs; - for (size_t i = 0; i < control_depend->size(); ++i) { - if (i == IntToSize(output_index)) { - new_inputs.push_back(memcpy_async); - } else { - new_inputs.push_back(control_depend->input(i)); - } - } - control_depend->set_inputs(new_inputs); - } - } -} -} // namespace - -bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input); - // when input is a parameter or is a value node - if (IsParameterOrValueNode(input)) { - return true; - } - - // when input is a Ref or some special cnodes - if (kernel_query_->IsTbeRef(input) || - kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { - return true; - } - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(input); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - // when input is used by others - if (iter->second.size() > 1) { - return true; - } - return false; -} - -void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(hccl_node); - bool has_insert_memcpy = false; - AnfNodePtr memcpy_async = nullptr; - std::vector new_inputs = {hccl_node->input(0)}; - for (size_t i = 1; i < hccl_node->size(); ++i) { - auto input = hccl_node->input(i); - if (NeedInsertMemcpy(graph, input)) { - memcpy_async = CreateMemcpyAsyncOp(graph, input); - has_insert_memcpy = true; - new_inputs.push_back(memcpy_async); - } else { - new_inputs.push_back(input); - } - } - - if (has_insert_memcpy) { - CNodePtr new_hccl_node = std::make_shared(*hccl_node); - new_hccl_node->set_inputs(new_inputs); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; - 
(void)manager->Replace(hccl_node, new_hccl_node); - MS_LOG(DEBUG) << "end replace"; - - // transer hccl op's control to the memcpy_async - if (hccl_node->size() == 2) { - TransferControl(new_hccl_node, memcpy_async, graph); - } - } -} - -const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (func_graph == nullptr || node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - if (!AnfAlgo::IsCommunicationOp(node)) { - return nullptr; - } - InsertMemcpyAsync(func_graph, cnode); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h deleted file mode 100644 index e2f3b781ed..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { - public: - explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) - : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), - kernel_query_(std::make_shared()) {} - ~InsertMemcpyAsyncForHcclOp() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; - bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; - KernelQueryPtr kernel_query_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc deleted file mode 100644 index b73fe6c83c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
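The HCCL pass above decides, per input, whether a MemcpyAsync must be interposed: always for parameters and value nodes, for TBE Ref outputs and a short allow-list of Lamb ops, and for any input that has other consumers. A simplified, self-contained restatement of `NeedInsertMemcpy` follows; the op-name string literals are assumed values for `kLambNextMVOpName` and friends, and the struct fields stand in for the real node/manager queries:

```cpp
#include <set>
#include <string>

enum class NodeKind { kParameter, kValueNode, kCNode };

struct InputInfo {
  NodeKind kind;
  std::string op_name;  // name of the producing op, for CNode inputs
  bool is_tbe_ref;      // models kernel_query_->IsTbeRef(input)
  int user_count;       // size of the manager's node_users entry
};

// Ops that need a copy even though they are not Ref nodes (mirrors
// kNeedInsertMemcpyOpSet in the deleted source; literals are assumptions).
const std::set<std::string> kNeedInsertMemcpyOps = {
    "LambNextMV", "LambNextMVWithDecay", "LambUpdateWithLR"};

bool NeedInsertMemcpy(const InputInfo &input) {
  if (input.kind == NodeKind::kParameter || input.kind == NodeKind::kValueNode) return true;
  if (input.is_tbe_ref || kNeedInsertMemcpyOps.count(input.op_name) != 0) return true;
  return input.user_count > 1;  // the input is also consumed elsewhere
}
```

The copy isolates the collective's in-place buffer from memory the rest of the graph may still read, which is why shared and Ref inputs are the trigger conditions.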
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" -#include -#include -#include -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "kernel//oplib/oplib.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef InsertPadForNMSWithMask::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimNMSWithMask, Xs}); -} - -AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type, - const std::vector &origin_shape) { - MS_EXCEPTION_IF_NULL(func_graph); - std::vector new_pad_inputs; - auto prim = std::make_shared(prim::kPrimPad->name()); - new_pad_inputs.push_back(NewValueNode(prim)); - new_pad_inputs.push_back(input); - CNodePtr pad = func_graph->NewCNode(new_pad_inputs); - MS_EXCEPTION_IF_NULL(pad); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get()); - return pad; -} - -const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - - size_t input_num = AnfAlgo::GetInputTensorNum(node); - if (input_num == 0) { - return nullptr; - } - std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { - auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); - auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); - auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); - if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) { - return nullptr; - } - origin_shape[1] = 8; - auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape); - MS_EXCEPTION_IF_NULL(pad); - pad->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector>{{0, 0}, {0, 3}}), pad); - new_inputs.push_back(pad); - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_node = nullptr; - if (kernel_graph == nullptr) { - new_node = std::make_shared(*cnode); - } else { - new_node = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_inputs(new_inputs); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h deleted file mode 100644 index bfc201ed11..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
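InsertPadForNMSWithMask above only rewrites inputs shaped (N, 5), widening them to (N, 8) via a Pad node with `paddings {{0, 0}, {0, 3}}`; any other shape makes the pass return early. A small stand-alone sketch of that shape rule:

```cpp
#include <cstdio>
#include <vector>

// Shape rule from InsertPadForNMSWithMask::Process: a rank-2, 5-column box
// tensor is padded on the last axis to 8 columns; anything else is untouched.
std::vector<size_t> PaddedShape(const std::vector<size_t> &shape) {
  if (shape.size() != 2 || shape[1] != 5) return shape;
  return {shape[0], 8};
}

int main() {
  auto out = PaddedShape({128, 5});
  std::printf("(%zu, %zu)\n", out[0], out[1]);  // (128, 8)
}
```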
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class InsertPadForNMSWithMask : public PatternProcessPass { - public: - explicit InsertPadForNMSWithMask(bool multigraph = true) - : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {} - ~InsertPadForNMSWithMask() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.cc deleted file mode 100644 index b661df9d98..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.h" - -#include -#include -#include -#include - -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -namespace { -using ConvertFunction = std::function; - -void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode); -const size_t kAxis_H = 2; -const size_t kAxis_W = 3; -const size_t kAxis_6HD_H = 1; -const size_t kAxis_6HD_W = 2; -const std::map kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD}, - {kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}}; -void SafeCheckFunction(const CNodePtr &cnode, const std::vector &reduce_axis) { - if (reduce_axis.empty()) { - MS_LOG(EXCEPTION) << "The node " << cnode->DebugString() << "'s reduce axis got a empty vector"; - } - if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) && - AnfAlgo::GetInputTensorNum(cnode) != 1) { - MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString() - << "] is not single input or single output "; - } - for (auto elem : reduce_axis) { - if (elem > 4) { - MS_LOG(INFO) << "reduce axis is larger than 4 dims reduce axis : [" << elem << "]"; - } - } -} - -void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) { - auto axis = kernel::GetReduceAttrAxis(cnode); - std::vector convert_axis; - SafeCheckFunction(cnode, axis); - auto format = AnfAlgo::GetInputFormat(cnode, 0); - if (format != kOpFormat_FRAC_Z || format != kOpFormat_C1HWNCoC0) { - MS_LOG(EXCEPTION) << "The node [" << cnode->DebugString() << "] format " << format << " is not 5hd"; - } - for (auto elem : axis) { - switch (elem) { - case kAxis_H: - convert_axis.emplace_back(kAxis_6HD_H); - break; - case kAxis_W: - convert_axis.emplace_back(kAxis_6HD_W); - break; - default: - MS_LOG(INFO) << "reduce axis is axis : [" << elem << "]" - << " but the format is not supported this reduce axis"; - } - } - AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(convert_axis), cnode); -} -} // namespace - -const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) { - return nullptr; - } - auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0)); - if (convert_map == kReduceConvertMap.end()) { - return nullptr; - } - convert_map->second(node->cast()); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.h b/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.h deleted file mode 100644 index ec23baf0ab..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/chang_axis_of_reduce_kernel.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ChangeAxisOfReduceKernel : public PatternProcessPass { - public: - explicit ChangeAxisOfReduceKernel(bool multigraph = true) - : PatternProcessPass("change_axis_of_reduce_kernel", multigraph) {} - ~ChangeAxisOfReduceKernel() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc deleted file mode 100644 index 7c8fb70fda..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/format_type/check_consistency.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckFormatForConsistency(const CNodePtr &node, const size_t input_index) { - MS_EXCEPTION_IF_NULL(node); - // get prior node's device output format - string pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(node, input_index); - string selected_input_format = AnfAlgo::GetInputFormat(node, input_index); - if (pre_output_format == selected_input_format) { - return true; - } - auto input_origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, input_index); - if (pre_output_format == kOpFormat_DEFAULT || selected_input_format == kOpFormat_DEFAULT) { - string checking_format = (pre_output_format == kOpFormat_DEFAULT) ? selected_input_format : pre_output_format; - // when input shape size is 1D, default format and NC1HWC0 are compatible - if (input_origin_shape.size() == 1 && checking_format == kOpFormat_NC1HWC0) { - return true; - } - if (kDefaultCompatibleFormat.find(checking_format) != kDefaultCompatibleFormat.end()) { - return true; - } - } - if (input_origin_shape.size() == 0) { - return true; - } - MS_LOG(ERROR) << "Found inconsistent format! 
input format " << input_index << ": " << pre_output_format - << ", selected input format: " << selected_input_format; - return false; -} - -bool CheckDataTypeForConsistency(const CNodePtr &node, const size_t input_index) { - MS_EXCEPTION_IF_NULL(node); - TypeId input_data_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(node, input_index); - TypeId selected_data_type = AnfAlgo::GetInputDeviceDataType(node, input_index); - if (input_data_type == selected_data_type) { - return true; - } - MS_LOG(ERROR) << "Found inconsistent dtype! input dtype " << input_index << ": " << TypeIdLabel(input_data_type) - << ", selected dtype: " << TypeIdLabel(selected_data_type); - return false; -} -} // namespace - -const BaseRef CheckConsistency::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - - std::vector todos = {node}; - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - kernel::GetValidKernelNodes(sub_graph, &todos); - } - - for (auto &t : todos) { - CNodePtr cnode = t->cast(); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { - if (!CheckFormatForConsistency(cnode, i) || !CheckDataTypeForConsistency(cnode, i)) { - MS_LOG(EXCEPTION) << "Found inconsistent format or data type! Op: " << AnfAlgo::GetCNodeName(cnode) << "[" - << cnode->DebugString() << "]"; - } - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h b/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h deleted file mode 100644 index e134547dc8..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/check_consistency.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class CheckConsistency : public PatternProcessPass { - public: - explicit CheckConsistency(bool multigraph = true) : PatternProcessPass("check_consistency", multigraph) {} - ~CheckConsistency() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_CHECK_CONSISTENCY_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc deleted file mode 100644 index c0f99ed415..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "kernel/kernel_query.h" -namespace mindspore { -namespace opt { -const BaseRef ConvertUnSupportNodeToAICPU::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraphPtr &, - const mindspore::AnfNodePtr &node, - const mindspore::EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::KPrimTransData->name() && node_name != prim::kPrimCast->name()) { - return nullptr; - } - auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); - if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { - return nullptr; - } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { - auto builder = std::make_shared(kernel_builder_info); - builder->SetKernelType(AICPU_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), node); - } else { - MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" - << node->DebugString() << "]"; - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h deleted file mode 100644 index 80cc8170ac..0000000000 --- 
a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" -#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H -#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H -namespace mindspore { -namespace opt { -class ConvertUnSupportNodeToAICPU : public PatternProcessPass { - public: - explicit ConvertUnSupportNodeToAICPU(bool multigraph = true) - : PatternProcessPass("convert_unsupported_node_to_aicpu", multigraph), - supported_checker_(std::make_shared()) {} - ~ConvertUnSupportNodeToAICPU() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - SupportedCheckerPtr supported_checker_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc deleted file mode 100644 index f909dae9e4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ /dev/null @@ -1,226 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
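The AICPU conversion pass above retargets only TransData and Cast nodes, and only when the selected kernel info is rejected by AiCore; if AiCPU rejects it too, the pass throws. A sketch of that three-way decision, where the op-name literals and the two support flags are stand-ins for `prim::KPrimTransData->name()`, `prim::kPrimCast->name()`, and the `supported_checker_` calls:

```cpp
#include <stdexcept>
#include <string>

enum class Placement { kUntouched, kAICPU };

Placement SelectPlacement(const std::string &op, bool aicore_ok, bool aicpu_ok) {
  if (op != "TransData" && op != "Cast") return Placement::kUntouched;  // pass ignores other ops
  if (aicore_ok) return Placement::kUntouched;  // the chosen AiCore kernel is fine
  if (aicpu_ok) return Placement::kAICPU;       // retag the node as AICPU_KERNEL
  throw std::runtime_error(op + " is supported by neither AiCore nor AiCPU");
}
```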
- */ - -#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" -#include -#include -#include -#include -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { - session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); - AnfNodePtr cur_node = kernel_with_index.first; - size_t cur_out_index = kernel_with_index.second; - MS_EXCEPTION_IF_NULL(cur_node); - if (cur_node->isa()) { - auto cnode = cur_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); - // deal ref op - if (op_info != nullptr && op_info->is_ref()) { - auto ref_infos = op_info->ref_infos(); - if (ref_infos.count(cur_out_index) != 0) { - auto in_index = ref_infos.at(cur_out_index); - if (in_index > cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() - << ", ref info is " << cur_out_index; - } - AnfNodePtr next_node = cnode->input(in_index + 1); - return FindRefOriginNode(next_node); - } - } - - // deal special (trans,cast,reshape) op - if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || - op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { - AnfNodePtr next_node = cnode->input(1); - return FindRefOriginNode(next_node); - } - } - - return kernel_with_index; -} - -void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, - const AnfNodePtr &final_node, size_t final_index, - const session::KernelWithIndex &origin_pair) { - // record the ref_pair - auto kernel_graph = func_graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - // if the final node is get item, means no trans or cast op is added, the final node is itself - // so add the pair for itself, because the get item will removed later - auto final_ref = (final_node == get_item ? cnode : final_node); - session::AnfWithOutIndex final_pair = std::make_pair(final_ref, final_index); - if (kernel_graph->IsInRefOutputMap(final_pair)) { - MS_LOG(EXCEPTION) << "ref_pair is already in ref map, node is " << final_ref->DebugString() << ", index is " - << final_index; - } - MS_LOG(DEBUG) << "Add Ref pair, final {node ptr " << final_pair.first.get() << " , info is " - << final_pair.first->DebugString() << " , index is " << final_pair.second << "}, origin {node ptr " - << origin_pair.first.get() << ", info is " << origin_pair.first->DebugString() << " : index " - << origin_pair.second << "}"; - kernel_graph->AddRefCorrespondPairs(final_pair, origin_pair); -} - -// if get_item is nullptr, the additional node will link to the cnode -// else the additional node will link to the get_item node (the get_item node link to cnode) -AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, - size_t input_index, const AnfNodePtr &get_item) { - AnfNodePtr final_node = (get_item == nullptr ? 
cnode : get_item); - size_t final_index = output_index; - AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); - session::KernelWithIndex origin_pair; - origin_pair = FindRefOriginNode(input_node); - MS_EXCEPTION_IF_NULL(origin_pair.first); - if (!origin_pair.first->isa()) { - MS_LOG(EXCEPTION) << "ref op origin node is not parameter"; - } - MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is " - << origin_pair.first->DebugString() << ", index is " << origin_pair.second; - auto origin_format = AnfAlgo::GetOutputFormat(origin_pair.first, origin_pair.second); - auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); - auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); - auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); - auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); - // insert trans - if (origin_format != cur_format && cur_shape.size() > 1) { - auto kernel_select = std::make_shared(); - final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, final_node); - final_index = 0; - MS_EXCEPTION_IF_NULL(final_node); - MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); - } - // insert cast - if (origin_type != cur_type) { - final_node = - AddCastOpNodeToGraph(func_graph, final_node, origin_format, cur_type, origin_type, cur_shape, cur_type); - MS_EXCEPTION_IF_NULL(final_node); - final_node->set_scope(cnode->scope()); - final_index = 0; - MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); - } - // add ref pair - AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); - // insert depend - if (origin_format != cur_format || origin_type != cur_type) { - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; - final_node = func_graph->NewCNode(depend_nodes); - MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); - } - - return final_node; -} -AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(op_info); - auto ref_infos = op_info->ref_infos(); - std::vector make_tuple_inputs; - AbstractBasePtrList abstract_list; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); - // deal with ref output - if (ref_infos.count(output_index) != 0) { - auto input_index = ref_infos.at(output_index); - final_node = AddAdditionalToRefOutput(func_graph, cnode, output_index, input_index, final_node); - } - MS_EXCEPTION_IF_NULL(final_node); - abstract_list.push_back(final_node->abstract()); - make_tuple_inputs.push_back(final_node); - } - MS_EXCEPTION_IF_NULL(func_graph); - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - make_tuple->set_abstract(std::make_shared(abstract_list)); - return make_tuple; -} - -AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::shared_ptr &op_info) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(op_info); - auto ref_infos = op_info->ref_infos(); - for 
(const auto &ref_info : ref_infos) { - if (ref_info.second > cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "ref op has wrong inputs: op inputs num is " << cnode->inputs().size() << ", ref info is " - << ref_info.second; - } - return AddAdditionalToRefOutput(func_graph, cnode, ref_info.first, ref_info.second, nullptr); - } - return nullptr; -} -} // namespace - -const BaseRef DealRefTransAndCast::DefinePattern() const { - VarPtr V = std::make_shared(UnVisited); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { - auto input_size = AnfAlgo::GetInputTensorNum(cnode); - for (size_t i = 0; i < input_size; ++i) { - auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i); - auto input_node = input_node_with_index.first; - MS_EXCEPTION_IF_NULL(input_node); - MS_LOG(INFO) << "origin node:" << input_node->fullname_with_scope(); - AddRefPairToKernelGraph(func_graph, cnode, nullptr, cnode, i, input_node_with_index); - } - } -} - -const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::IsRealCNodeKernel(cnode)) { - return nullptr; - } - - DealBroadCastAsRef(graph, cnode); - - auto op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); - if (op_info == nullptr || !op_info->is_ref()) { - return nullptr; - } - if (op_info->is_ref()) { - auto type = cnode->Type(); - MS_EXCEPTION_IF_NULL(type); - if (!type->isa()) { - return DealRefSigleOutput(graph, cnode, op_info); - } else { - return DealRefForMultipleOutput(graph, cnode, op_info); - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h deleted file mode 100644 index 1b54a7b111..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
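For Ref outputs, `DealRefTransAndCast` above writes the result back to the origin parameter's layout: a TransData when the formats differ (and the output has rank greater than 1), a Cast when the device dtypes differ, and a Depend edge whenever either was inserted so execution order is preserved. A compact stand-alone model of that plan (the descriptor struct is a simplification of the `AnfAlgo` format/dtype/shape queries):

```cpp
#include <string>
#include <vector>

struct TensorDesc {
  std::string format;
  int dtype;  // stand-in for TypeId
  std::vector<size_t> shape;
};

struct RefFixup {
  bool need_trans;
  bool need_cast;
  bool need_depend;
};

// Mirrors the three insertion decisions in AddAdditionalToRefOutput.
RefFixup PlanRefWriteBack(const TensorDesc &origin, const TensorDesc &current) {
  RefFixup fix{};
  fix.need_trans = origin.format != current.format && current.shape.size() > 1;
  fix.need_cast = origin.dtype != current.dtype;
  fix.need_depend = fix.need_trans || fix.need_cast;
  return fix;
}
```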
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class DealRefTransAndCast : public PatternProcessPass { - public: - explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} - ~DealRefTransAndCast() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc deleted file mode 100644 index 2b2749090a..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.cc +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/format_type/insert_cast.h" - -#include -#include -#include -#include - -#include "device/kernel_info.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "kernel/kernel_build_info.h" -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "utils/utils.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::vector &need_insert_cast) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - std::vector make_tuple_inputs; - AbstractBasePtrList abstract_list; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) { - AnfNodePtr replace_node = nullptr; - const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx); - const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, output_idx); - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(output_idx); - idx->set_abstract(std::make_shared(imm)); - auto getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); - AnfAlgo::SetOutputInferTypeAndShape({infer_type}, {origin_shape}, getitem.get()); - if (need_insert_cast[output_idx]) { - const auto dev_fmt = AnfAlgo::GetOutputFormat(cnode, output_idx); - TypeId origin_type(kTypeUnknown); - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); - } - origin_type = origin_type == kTypeUnknown ? infer_type : origin_type; - const auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_idx); - if (origin_type != device_type) { - replace_node = - AddCastOpNodeToGraph(func_graph, getitem, dev_fmt, device_type, origin_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - } else { - replace_node = getitem; - } - } else { - replace_node = getitem; - } - abstract_list.push_back(replace_node->abstract()); - make_tuple_inputs.push_back(replace_node); - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - make_tuple->set_abstract(std::make_shared(abstract_list)); - return make_tuple; -} // namespace - -AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::vector &need_insert_cast) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetOutputTensorNum(cnode) == 0) { - return cnode; - } - MS_EXCEPTION_IF_NULL(cnode->Type()); - // Single output - if (!cnode->Type()->isa()) { - if (!need_insert_cast[0]) { - return cnode; - } - - const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, 0); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(cnode, 0); - const auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, 0); - TypeId origin_type(kTypeUnknown); - if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - origin_type = AnfAlgo::GetCNodeOutputPrecision(cnode); - } - origin_type = origin_type == kTypeUnknown ? 
infer_type : origin_type; - const TypeId device_type = AnfAlgo::GetOutputDeviceDataType(cnode, 0); - AnfNodePtr replace_node = cnode; - if (origin_type != device_type) { - replace_node = - AddCastOpNodeToGraph(func_graph, cnode, dev_fmt, device_type, origin_type, origin_shape, infer_type); - MS_EXCEPTION_IF_NULL(replace_node); - replace_node->set_scope(cnode->scope()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node); - } - return replace_node; - } - // Multiple output - return InsertCastForMultipleOutput(func_graph, cnode, need_insert_cast); -} - -AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - // insert cast for ops in graph kernel. - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - std::vector> graph_rets; - kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); - for (auto &t : todo) { - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); - // process input - CNodePtr t_cnode = t->cast(); - MS_EXCEPTION_IF_NULL(t_cnode); - auto t_new_node = InsertCastForInput(sub_graph, t_cnode); - AnfNodePtr t_new_node_1 = nullptr; - std::vector need_insert_cast(AnfAlgo::GetOutputTensorNum(t), true); - // process output - auto iter = std::find_if(graph_rets.begin(), graph_rets.end(), - [&t](const std::pair &ret) { return ret.first == t; }); - if (iter != graph_rets.end()) { - auto t_fix_output_type = AnfAlgo::GetCNodeOutputPrecision(t); - auto t_output_type = AnfAlgo::GetOutputDeviceDataType(t, iter->second); - auto graph_output_type = AnfAlgo::GetOutputDeviceDataType(node, iter - graph_rets.begin()); - if (t_fix_output_type == kTypeUnknown && t_output_type == graph_output_type) { - need_insert_cast[iter->second] = false; - } else if (t_fix_output_type == t_output_type && t_output_type == graph_output_type) { - need_insert_cast[iter->second] = false; - } - t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); - } else { - t_new_node_1 = InsertCastForOutput(sub_graph, t_new_node, need_insert_cast); - } - - if (t_new_node_1 != nullptr && t_new_node_1 != t) { - (void)mng->Replace(t, t_new_node_1); - } - } - - // insert cast for graph kernel. - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); -} -} // namespace - -const BaseRef InsertCast::DefinePattern() const { - VarPtr V = std::make_shared(UnVisited); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { - return nullptr; - } - - if (AnfAlgo::IsGraphKernel(node)) { - return ProcessGraphKernelOp(func_graph, node); - } - // insert cast for single op. 
- AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto new_node = InsertCastForInput(func_graph, cnode); - // process output - return InsertCastForOutput(func_graph, new_node, std::vector(AnfAlgo::GetOutputTensorNum(new_node), true)); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h deleted file mode 100644 index a7f93ec8f3..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ -#include - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pattern_engine.h" -#include "ir/anf.h" - -namespace mindspore { -namespace opt { -class InsertCast : public PatternProcessPass { - public: - explicit InsertCast(bool multigraph = true) : PatternProcessPass("insert_cast", multigraph) {} - ~InsertCast() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc deleted file mode 100644 index 3f77c68f86..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.cc +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
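InsertCast keys its per-output decision on an "origin" type that prefers the graph-kernel output precision when one is fixed and falls back to the inferred type; a Cast is added only when that origin differs from the selected device type. Restated minimally, with `TypeId` reduced to an `int` and `kTypeUnknown` to an assumed sentinel:

```cpp
constexpr int kTypeUnknown = -1;  // stand-in for the real TypeId sentinel

// Mirrors the decision in InsertCastForOutput: only when the resolved origin
// type differs from the device type is AddCastOpNodeToGraph invoked.
bool NeedOutputCast(int fixed_precision, int infer_type, int device_type) {
  const int origin = (fixed_precision == kTypeUnknown) ? infer_type : fixed_precision;
  return origin != device_type;
}
```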
- */ - -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include -#include -#include "utils/utils.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -const BaseRef InsertTransOp::DefinePattern() const { - std::shared_ptr V = std::make_shared(UnVisited); - std::shared_ptr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -bool IsGraphOutput(const AnfNodePtr &node, const std::vector &outputs) { - auto iter = std::find(outputs.begin(), outputs.end(), node); - if (iter != outputs.end()) { - return true; - } - - return false; -} - -const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - AnfNodePtr front_node; - auto kernel_graph = func_graph->cast>(); - if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) { - front_node = kernel_graph->GetFrontNodeByInternalOutput(node); - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) { - if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { - return new_node; - } - } - auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_); - if (kernel_graph != nullptr && front_node != nullptr) { - auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node); - kernel_graph->ReplaceInternalOutput(old_node, final_node); - } - return final_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h deleted file mode 100644 index eb6cfa9542..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_trans_op.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class InsertTransOp : public PatternProcessPass { - public: - explicit InsertTransOp(bool multigraph = true) - : PatternProcessPass("insert_trans_op", multigraph), kernel_select_(std::make_shared()) {} - ~InsertTransOp() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANS_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc deleted file mode 100644 index 3df513a19f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.cc +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" -#include -#include "utils/utils.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace opt { -const BaseRef RunOpInsertTransData::DefinePattern() const { - std::shared_ptr V = std::make_shared(UnVisited); - MS_EXCEPTION_IF_NULL(V); - std::shared_ptr Xs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - return VectorRef({V, Xs}); -} - -const AnfNodePtr RunOpInsertTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - MS_LOG(DEBUG) << "====process op: " << node->DebugString(); - return InsertTransOpForInput(func_graph, node, kernel_select_); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h deleted file mode 100644 index f699cdd580..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_transdata_for_runop.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
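RunOpInsertTransData, shown next, is the single-op counterpart of InsertTransOp: it reuses the same UnVisited pattern but calls only `InsertTransOpForInput`, leaving outputs in device format for the run-op executor. A toy contrast under that reading, with string tags standing in for the actual graph rewriting:

```cpp
#include <iostream>
#include <string>

// Toy stand-ins: each stage just records the wrapper it would insert.
std::string ForInput(const std::string &n) { return "TransData->" + n; }
std::string ForOutput(const std::string &n) { return n + "->TransData"; }

int main() {
  const std::string node = "Conv2D";
  std::cout << ForOutput(ForInput(node)) << "\n";  // InsertTransOp: both sides
  std::cout << ForInput(node) << "\n";             // RunOpInsertTransData: inputs only
}
```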
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class RunOpInsertTransData : public PatternProcessPass { - public: - explicit RunOpInsertTransData(bool multigraph = true) - : PatternProcessPass("insert_transdata_for_runop", multigraph), - kernel_select_(std::make_shared()) {} - ~RunOpInsertTransData() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_TRANSDATA_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc deleted file mode 100644 index b1817cec3d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc +++ /dev/null @@ -1,282 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kCastInputNum = 2; -const size_t kTupleGetitemInputNum = 3; -bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, - const std::shared_ptr &candidate_kernel_info) { - if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == nullptr) { - return false; - } - - // checkout inputs' fmt and dtype except index equal change_idx - for (size_t i = 0; i < candidate_kernel_info->GetInputNum(); i++) { - if (i == change_idx) { - if (candidate_kernel_info->GetInputDeviceType(i) != dst_type || - candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { - return false; - } - } else if (candidate_kernel_info->GetInputDeviceType(i) != AnfAlgo::GetInputDeviceDataType(node, i) || - candidate_kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(node, i)) { - return false; - } - } - - // check outputs's fmt and dtype - for (size_t i = 0; i < candidate_kernel_info->GetOutputNum(); i++) { - if (candidate_kernel_info->GetOutputDeviceType(i) != AnfAlgo::GetOutputDeviceDataType(node, i) || - candidate_kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(node, i)) { - return false; - } - } - return true; -} - -bool GetNextNodeAndCastIndex(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodePtr *next_node, - size_t *cast_index) { - auto output_node_list = GetRealNodeUsedList(graph, node); - MS_EXCEPTION_IF_NULL(output_node_list); - if (output_node_list->size() != 1) { - return false; - } - auto node_pair = output_node_list->at(0); - *next_node = node_pair.first; - *cast_index = node_pair.second - 1; - return true; -} - -bool CheckInputs(const CNodePtr &node, const std::shared_ptr &kernel_info) { - MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetInputTensorNum(node) != kernel_info->GetInputNum()) { - return false; - } - - for (size_t index = 0; index < kernel_info->GetInputNum(); ++index) { - if (AnfAlgo::GetInputFormat(node, index) != kernel_info->GetInputFormat(index) || - AnfAlgo::GetInputDeviceDataType(node, index) != kernel_info->GetInputDeviceType(index)) { - return false; - } - } - return true; -} - -bool CheckOtherOutputs(const CNodePtr &node, const std::shared_ptr &kernel_info, - const size_t idx) { - MS_EXCEPTION_IF_NULL(kernel_info); - if (AnfAlgo::GetOutputTensorNum(node) != kernel_info->GetOutputNum()) { - return false; - } - for (size_t index = 0; index < kernel_info->GetOutputNum(); ++index) { - if (idx == index) { - continue; - } - if (AnfAlgo::GetOutputFormat(node, index) != kernel_info->GetOutputFormat(index) || - AnfAlgo::GetOutputDeviceDataType(node, index) != kernel_info->GetOutputDeviceType(index)) { - return false; - } - } - return true; -} - -bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr &kernel_info, size_t index) { - if (kernel_info == nullptr) { - return false; - } - - if (AnfAlgo::GetOutputDeviceDataType(node, 0) != kernel_info->GetOutputDeviceType(index)) { - return false; - } - if (AnfAlgo::GetOutputInferShape(node, 0).size() == 4 && AnfAlgo::GetOutputFormat(node, 0) == kOpFormat_NCHW && - kernel_info->GetOutputFormat(index) == kOpFormat_DEFAULT) { - return true; - } - return AnfAlgo::GetOutputFormat(node, 0) == kernel_info->GetOutputFormat(index); -} - -void 
-void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) {
-  using Shape = std::vector<size_t>;
-  auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0);
-  auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
-  std::vector<Shape> shapes;
-  std::vector<TypeId> types;
-  for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
-    if (cast_index == index) {
-      shapes.emplace_back(cast_shape);
-      types.emplace_back(cast_dtype);
-      continue;
-    }
-    shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
-    types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
-  }
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
-}
-
-AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(kernel_query);
-  AnfNodePtr next_node = nullptr;
-  size_t cast_index = 0;
-  if (!GetNextNodeAndCastIndex(graph, node, &next_node, &cast_index)) {
-    return nullptr;
-  }
-  MS_EXCEPTION_IF_NULL(next_node);
-  if (!next_node->isa<CNode>() || !AnfAlgo::IsRealKernel(next_node)) {
-    return nullptr;
-  }
-  auto next_cnode = next_node->cast<CNodePtr>();
-  if (AnfAlgo::IsGraphKernel(next_node)) {
-    return nullptr;
-  }
-  auto next_op_name = AnfAlgo::GetCNodeName(next_node);
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  kernel_query->Query(next_cnode, &kernel_info_list);
-
-  auto dst_type_id = AnfAlgo::GetInputDeviceDataType(node, 0);
-  auto alternative_kernel_info = std::find_if(
-    kernel_info_list.begin(), kernel_info_list.end(),
-    [&next_cnode, &dst_type_id, &cast_index](const std::shared_ptr<kernel::KernelBuildInfo> &candidate_kernel_info) {
-      return AlternativeKernelInfoForInput(next_cnode, dst_type_id, cast_index, candidate_kernel_info);
-    });
-  if (alternative_kernel_info == kernel_info_list.end()) {
-    return nullptr;
-  }
-  auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(next_node);
-  MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << next_cnode->DebugString()
-               << ", ori kernel info: " << ori_kernel_info->ToString() << ", alternative kernel info: "
-               << (*alternative_kernel_info)->ToString();
-  AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get());
-  ChangeNodeInferInfo(next_cnode, node, cast_index);
-  if (node->inputs().size() < kCastInputNum) {
-    MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num: " << node->inputs().size();
-  }
-  return node->input(1);
-}
-
-bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_output, size_t *output_idx) {
-  MS_EXCEPTION_IF_NULL(x_node);
-  if (x_node->isa<CNode>()) {
-    auto x_cnode = x_node->cast<CNodePtr>();
-    *prior_op = x_cnode;
-    // When x_node is a tuple_getitem.
-    if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) {
-      if (x_cnode->inputs().size() < kTupleGetitemInputNum) {
-        MS_LOG(EXCEPTION) << "The tuple_getitem node has wrong input num: " << x_cnode->inputs().size();
-      }
-      MS_EXCEPTION_IF_NULL(output_idx);
-      AnfNodePtr input1 = x_cnode->input(1);
-      MS_EXCEPTION_IF_NULL(input1);
-      if (!input1->isa<CNode>()) {
-        return false;
-      }
-      *prior_op = input1->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(*prior_op);
-      AnfNodePtr input2 = x_cnode->input(2);
-      MS_EXCEPTION_IF_NULL(input2);
-      auto value_ptr = input2->cast<ValueNodePtr>();
-      MS_EXCEPTION_IF_NULL(value_ptr);
-      *output_idx = IntToSize(GetValue<int>(value_ptr->value()));
-      *single_output = false;
-    }
-    return AnfAlgo::IsRealKernel(*prior_op);
-  }
-  return false;
-}
-
-AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) {
-  MS_EXCEPTION_IF_NULL(cur_node);
-  MS_EXCEPTION_IF_NULL(kernel_query);
-  if (cur_node->inputs().size() < kCastInputNum) {
-    MS_LOG(EXCEPTION) << "Op[Cast] has wrong input num: " << cur_node->inputs().size();
-  }
-  AnfNodePtr x_node = cur_node->input(1);
-  if (IsUsedByOthers(graph, x_node)) {
-    return nullptr;
-  }
-
-  CNodePtr prior_op = nullptr;
-  bool single_output = true;
-  size_t output_idx = 0;
-  if (!GetPriorOp(x_node, &prior_op, &single_output, &output_idx)) {
-    return nullptr;
-  }
-  MS_EXCEPTION_IF_NULL(prior_op);
-  if (AnfAlgo::IsGraphKernel(prior_op)) {
-    return nullptr;
-  }
-
-  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  kernel_query->Query(prior_op, &kernel_info_list);
-  auto kernel_info_it = std::find_if(
-    kernel_info_list.begin(), kernel_info_list.end(),
-    [&prior_op, &cur_node, &output_idx](const std::shared_ptr<kernel::KernelBuildInfo> &item_kernel_info) {
-      return CheckInputs(prior_op, item_kernel_info) && CheckOtherOutputs(prior_op, item_kernel_info, output_idx) &&
-             CheckIndexOutput(cur_node, item_kernel_info, output_idx);
-    });
-  if (kernel_info_it == kernel_info_list.end()) {
-    return nullptr;
-  }
-  auto ori_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(prior_op);
-  MS_LOG(INFO) << "Found alternative kernel info for current anf kernel " << prior_op->DebugString()
-               << ", ori kernel info: " << ori_kernel_info->ToString() << ", alternative kernel info: "
-               << (*kernel_info_it)->ToString();
-  AnfAlgo::SetSelectKernelBuildInfo(*kernel_info_it, prior_op.get());
-  ChangeNodeInferInfo(prior_op, cur_node, output_idx);
-  if (!single_output) {
-    MS_EXCEPTION_IF_NULL(x_node);
-    ChangeNodeInferInfo(x_node->cast<CNodePtr>(), cur_node, 0);
-  }
-  auto prior_name = AnfAlgo::GetCNodeName(prior_op);
-  if (prior_name == kFive2FourOpName) {
-    AnfAlgo::CopyNodeAttr("dst_type", "dstType", cur_node, prior_op);
-  } else if (prior_name == kFour2FiveOpName) {
-    AnfAlgo::CopyNodeAttr("dst_type", cur_node, prior_op);
-  }
-  return single_output ? prior_op : x_node;
-}
-} // namespace
-
-const BaseRef MergeCastToOp::DefinePattern() const {
-  VarPtr X = std::make_shared<Var>();
-  return VectorRef({prim::kPrimCast, X});
-}
-
-const AnfNodePtr MergeCastToOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
-  if (node == nullptr || !node->isa<CNode>()) {
-    return nullptr;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  auto new_node = MergeCastToNextOp(graph, cnode, kernel_query_);
-  if (new_node == nullptr) {
-    new_node = MergeCastToPriorOp(graph, cnode, kernel_query_);
-  }
-  return new_node;
-}
-} // namespace opt
-} // namespace mindspore
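The merge decision above boils down to one predicate: among the consumer's candidate kernel build infos, accept a candidate whose input slot at the cast position already takes the pre-cast dtype, while every other input and output keeps its current dtype and format. A self-contained model of that check; the Slot/KernelInfo structs and the sample data are illustrative stand-ins, not the real kernel::KernelBuildInfo API:

#include <cstddef>
#include <string>
#include <vector>

struct Slot { std::string format; std::string dtype; };
struct KernelInfo { std::vector<Slot> inputs, outputs; };

// True if candidate matches current everywhere, except that input change_idx takes dst_dtype.
bool AlternativeForInput(const KernelInfo &current, const KernelInfo &candidate,
                         const std::string &dst_dtype, size_t change_idx) {
  if (candidate.inputs.size() != current.inputs.size() ||
      candidate.outputs.size() != current.outputs.size()) return false;
  for (size_t i = 0; i < candidate.inputs.size(); ++i) {
    const std::string &want = (i == change_idx) ? dst_dtype : current.inputs[i].dtype;
    if (candidate.inputs[i].dtype != want || candidate.inputs[i].format != current.inputs[i].format) return false;
  }
  for (size_t i = 0; i < candidate.outputs.size(); ++i) {
    if (candidate.outputs[i].dtype != current.outputs[i].dtype ||
        candidate.outputs[i].format != current.outputs[i].format) return false;
  }
  return true;
}

int main() {
  KernelInfo current{{{"NC1HWC0", "fp16"}}, {{"NC1HWC0", "fp16"}}};
  KernelInfo candidate{{{"NC1HWC0", "fp32"}}, {{"NC1HWC0", "fp16"}}};
  // The consumer can take fp32 directly on input 0, so an fp32->fp16 Cast feeding it can be merged away.
  return AlternativeForInput(current, candidate, "fp32", 0) ? 0 : 1;
}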
diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h
deleted file mode 100644
index 7e05c8a02a..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.h
+++ /dev/null
@@ -1,40 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H
-
-#include <memory>
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-#include "pre_activate/ascend/ascend_helper.h"
-
-namespace mindspore {
-namespace opt {
-class MergeCastToOp : public PatternProcessPass {
- public:
-  explicit MergeCastToOp(bool multigraph = true)
-      : PatternProcessPass("merge_cast_to_op", multigraph), kernel_query_(std::make_shared<KernelQuery>()) {}
-  ~MergeCastToOp() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  KernelQueryPtr kernel_query_;
-};
-} // namespace opt
-} // namespace mindspore
-#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MERGE_CAST_TO_OP_H
- */ - -#include "pre_activate/ascend/format_type/modify_ops_attrs.h" -#include -#include -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -AnfNodePtr ModifyReduceOpsAttrs(const CNodePtr &cnode) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); - auto input_format = AnfAlgo::GetInputFormat(cnode, 0); - if (input_shape.size() == 5 || input_format != kOpFormat_NC1HWC0) { - return nullptr; - } - if (!AnfAlgo::HasNodeAttr(kAttrKeepDims, cnode)) { - return nullptr; - } - - AnfAlgo::SetNodeAttr(kAttrKeepDims, MakeValue(true), cnode); - return cnode; -} - -AnfNodePtr ModifyTileOpAttrs(const CNodePtr &cnode) { - auto input_shape = AnfAlgo::GetInputDeviceShape(cnode, 0); - if (input_shape.size() != 5) { - return nullptr; - } - if (!AnfAlgo::HasNodeAttr(kAttrMultiples, cnode)) { - return nullptr; - } - - auto multiples = AnfAlgo::GetNodeAttr>(cnode, kAttrMultiples); - if (multiples.size() == 4 && multiples[1] == 1) { - multiples.push_back(1); - AnfAlgo::SetNodeAttr(kAttrMultiples, MakeValue(multiples), cnode); - } - - return cnode; -} - -AnfNodePtr ModifyAttrs(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto op_name = AnfAlgo::GetCNodeName(cnode); - if (op_name == prim::kPrimTile->name()) { - return ModifyTileOpAttrs(cnode); - } else if (op_name == prim::kPrimReduceSum->name()) { - // kPrimReduceMean - // kPrimReduceSum - // kPrimReduceAll - // kPrimReduceMax - // kPrimReduceMin - return ModifyReduceOpsAttrs(cnode); - } - return nullptr; -} -} // namespace - -const AnfNodePtr ModifyOpAttrs::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsGraphKernel(node)) { - return nullptr; - } - MS_LOG(DEBUG) << "====Process op: " << AnfAlgo::GetCNodeName(node); - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - auto manager = fg->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector todos; - kernel::GetValidKernelNodes(fg, &todos); - for (auto &t : todos) { - auto new_node = ModifyAttrs(t->cast()); - if (new_node != nullptr && new_node != t) { - (void)manager->Replace(t, new_node); - } - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h b/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h deleted file mode 100644 index 25ec94b6b4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/modify_ops_attrs.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ModifyOpAttrs : public PatternProcessPass { - public: - explicit ModifyOpAttrs(bool multigraph = true) : PatternProcessPass("modify_ops_attrs", multigraph) {} - ~ModifyOpAttrs() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_MODIFY_OPS_ATTRS_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc deleted file mode 100644 index 571e70dca5..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc +++ /dev/null @@ -1,184 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "kernel/common_utils.h" -#include "utils/context/ms_context.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kPynativeMode) { - return RectifyKernelInfoInPynativeProcess(node); - } - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { - return nullptr; - } - std::vector do_mask_node_list; - auto gen_mask_output_nodes = GetRealNodeUsedList(graph, cnode); - MS_EXCEPTION_IF_NULL(gen_mask_output_nodes); - for (const auto &output_node : *gen_mask_output_nodes) { - if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { - MS_EXCEPTION_IF_NULL(output_node.first); - auto output_cnode = output_node.first->cast(); - do_mask_node_list.push_back(output_cnode); - } - } - std::vector input_shape; - for (const auto &output_node : do_mask_node_list) { - if (input_shape.empty()) { - input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); - continue; - } - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); - if (!kernel::IsSameShape(shape, input_shape)) { - MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be 
equal!" - << " GenMask " << node->DebugString(); - } - } - RectifyKernelInfo(do_mask_node_list, graph); - return nullptr; -} - -void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list, - const FuncGraphPtr &graph) const { - std::map format_counter; - std::string special_format; - std::string convert_format; - for (const auto &do_mask : do_mask_node_list) { - auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); - if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) { - special_format = do_mask_data_format; - } - if (format_counter.find(do_mask_data_format) == format_counter.end()) { - format_counter[do_mask_data_format] = 1; - } else { - format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; - } - } - if (format_counter.size() == 1) { - return; - } - if (convert_format.empty()) { - convert_format = GetConvertFormat(format_counter); - } - RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format, graph); -} - -std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { - std::string convert_format = kOpFormat_DEFAULT; - size_t counter = 0; - if (format_counter.size() > 2) { - return kOpFormat_DEFAULT; - } - if (format_counter.size() == 2 && format_counter.find(kOpFormat_DEFAULT) == format_counter.end()) { - return kOpFormat_DEFAULT; - } - for (const auto &iter : format_counter) { - if (counter < iter.second) { - convert_format = iter.first; - counter = iter.second; - } else if (counter == iter.second && kHWSpecialFormatSet.find(iter.first) != kHWSpecialFormatSet.end()) { - convert_format = iter.first; - } - } - return convert_format; -} - -void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, - const std::string &format, - const FuncGraphPtr &graph) const { - for (const auto &do_mask : do_mask_node_list) { - if (AnfAlgo::GetInputFormat(do_mask, 0) != format) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); - builder->SetInputFormat(format, 0); - builder->SetOutputFormat(format, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); - ReSelecChildNodeKernelInfo(do_mask, graph); - } - } -} - -AnfNodePtr RectifyDoMaskKernelInfo::RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { - return nullptr; - } - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { - return nullptr; - } - auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); - if (do_mask_input_format != kOpFormat_DEFAULT) { - auto builder = - std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); - builder->SetInputFormat(kOpFormat_DEFAULT, 0); - builder->SetOutputFormat(kOpFormat_DEFAULT, 0); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); - } - return nullptr; -} - -void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const { - MS_EXCEPTION_IF_NULL(cnode); - auto output_node_list = GetRealNodeUsedList(graph, cnode); - MS_EXCEPTION_IF_NULL(output_node_list); - for (const auto &out_node_info : *output_node_list) { - MS_EXCEPTION_IF_NULL(out_node_info.first); - auto out_node = out_node_info.first->cast(); - if (AnfAlgo::IsRealKernel(out_node_info.first)) { - auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); - kernel_selecter->SelectKernel(out_node); - auto 
new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node); - MS_EXCEPTION_IF_NULL(new_build_info); - MS_EXCEPTION_IF_NULL(ori_build_info); - if ((*new_build_info) != (*ori_build_info)) { - ReSelecChildNodeKernelInfo(out_node, graph); - } - } else if (AnfAlgo::GetCNodeName(out_node) == prim::kPrimTupleGetItem->name() || - AnfAlgo::GetCNodeName(out_node) == prim::kPrimDepend->name()) { - ReSelecChildNodeKernelInfo(out_node, graph); - } else { - MS_LOG(INFO) << "Reselected the node " << cnode->DebugString() << " failed"; - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h deleted file mode 100644 index b03937db47..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H -#include -#include -#include -#include - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" -namespace mindspore { -namespace opt { -class RectifyDoMaskKernelInfo : public PatternProcessPass { - public: - explicit RectifyDoMaskKernelInfo(bool multigraph = true) - : PatternProcessPass("batch_norm_bert_fission", multigraph), kernel_selecter(std::make_shared()) {} - ~RectifyDoMaskKernelInfo() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void RectifyKernelInfo(const std::vector &do_mask_node_list, const FuncGraphPtr &graph) const; - AnfNodePtr RectifyKernelInfoInPynativeProcess(const AnfNodePtr &node) const; - std::string GetConvertFormat(const std::map &format_counter) const; - void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format, - const FuncGraphPtr &graph) const; - void ReSelecChildNodeKernelInfo(const CNodePtr &cnode, const FuncGraphPtr &graph) const; - KernelSelectPtr kernel_selecter; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc deleted file mode 100644 index dde40a5090..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pre_activate/ascend/format_type/remove_no_use_reshape_op.h"
-#include <memory>
-#include <vector>
-#include "pre_activate/common/helper.h"
-#include "kernel/common_utils.h"
-#include "session/anf_runtime_algorithm.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-AnfNodePtr RemoveReshapeOp(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto op_name = AnfAlgo::GetCNodeName(cnode);
-  if (op_name != prim::kPrimReshape->name()) {
-    return nullptr;
-  }
-
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
-  auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0);
-  if (input_shape.size() != 1 || input_format != kOpFormat_NC1HWC0) {
-    return nullptr;
-  }
-
-  return cnode->input(1);
-}
-} // namespace
-
-const AnfNodePtr RemoveNoUseReshapeOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                               const EquivPtr &) const {
-  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsGraphKernel(node)) {
-    return nullptr;
-  }
-  MS_LOG(DEBUG) << "Process op: " << AnfAlgo::GetCNodeName(node);
-  auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node);
-  MS_EXCEPTION_IF_NULL(fg);
-  auto manager = fg->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  std::vector<AnfNodePtr> todos;
-  kernel::GetValidKernelNodes(fg, &todos);
-  for (auto &t : todos) {
-    auto new_node = RemoveReshapeOp(t->cast<CNodePtr>());
-    if (new_node != nullptr && new_node != t) {
-      (void)manager->Replace(t, new_node);
-    }
-  }
-  return node;
-}
-} // namespace opt
-} // namespace mindspore
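ModifyOpAttrs and RemoveNoUseReshapeOp share the same driver loop: collect the real kernel nodes of the fused sub-graph with kernel::GetValidKernelNodes, compute a per-node replacement, and splice any non-trivial result in through the graph manager. A generic sketch of that loop over illustrative stand-in types (Node and the two std::function parameters are not MindSpore names):

#include <functional>
#include <vector>

struct Node {};  // stand-in for AnfNode
using NodePtr = Node *;

// Applies `rewrite` to every collected kernel node; a non-null result that
// differs from the input is spliced in via `replace` (manager->Replace(...)
// in the real passes), and the loop simply moves on otherwise.
inline void RewriteKernelNodes(const std::vector<NodePtr> &kernel_nodes,
                               const std::function<NodePtr(NodePtr)> &rewrite,
                               const std::function<void(NodePtr, NodePtr)> &replace) {
  for (NodePtr node : kernel_nodes) {
    NodePtr new_node = rewrite(node);
    if (new_node != nullptr && new_node != node) {
      replace(node, new_node);
    }
  }
}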
diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h b/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h
deleted file mode 100644
index 4942c2fc08..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/format_type/remove_no_use_reshape_op.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
-
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class RemoveNoUseReshapeOp : public PatternProcessPass {
- public:
-  explicit RemoveNoUseReshapeOp(bool multigraph = true) : PatternProcessPass("remove_no_use_reshape_op", multigraph) {}
-  ~RemoveNoUseReshapeOp() override = default;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-} // namespace opt
-} // namespace mindspore
-
-#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_NO_USE_RESHAPE_OP_H
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc
deleted file mode 100644
index b9a86f7bcb..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/addn_fission.h"
-#include <memory>
-#include <vector>
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_addn_cnode, size_t begin_index,
-                         size_t offset) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(origin_addn_cnode);
-  std::vector<AnfNodePtr> new_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
-  for (size_t i = begin_index; i < begin_index + offset; ++i) {
-    new_addn_inputs.push_back(origin_addn_cnode->input(i));
-  }
-  CNodePtr new_addn = func_graph->NewCNode(new_addn_inputs);
-  MS_EXCEPTION_IF_NULL(new_addn);
-  new_addn->set_scope(origin_addn_cnode->scope());
-  new_addn->set_abstract(origin_addn_cnode->abstract());
-  AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn);
-  std::vector<int> dyn_input_sizes{SizeToInt(offset)};
-  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn);
-  return new_addn;
-}
-} // namespace
-
-const BaseRef AddnFission::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimAddN, Xs});
-}
-
-const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  // The real input begins with index 1.
-  size_t origin_input_size = cnode->inputs().size() - 1;
-  if (origin_input_size <= inputs_divisor_) {
-    return nullptr;
-  }
-  CNodePtr new_cnode = cnode;
-  while (origin_input_size > inputs_divisor_) {
-    MS_EXCEPTION_IF_NULL(new_cnode);
-    std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
-    size_t cur_input_index = 1;
-    // Divide the inputs of addn by inputs_divisor_.
-    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
-      base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
-      cur_input_index += inputs_divisor_;
-    }
-    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
-      base_addn_inputs.push_back(new_cnode->input(i));
-    }
-    CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
-    MS_EXCEPTION_IF_NULL(base_addn);
-    base_addn->set_scope(new_cnode->scope());
-    base_addn->set_abstract(new_cnode->abstract());
-    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn);
-    std::vector<int> dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)};
-    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
-    new_cnode = base_addn;
-    origin_input_size = base_addn->inputs().size() - 1;
-  }
-
-  return new_cnode;
-}
-} // namespace opt
-} // namespace mindspore
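The fission loop above repeatedly packs groups of inputs_divisor_ inputs into child AddN nodes until the top node's fan-in fits the divisor. A standalone sketch of just that grouping arithmetic, using a divisor of 3 for readability (the real pass uses kAddnInputsDivisor = 63):

#include <cstddef>
#include <iostream>
#include <vector>

// Returns the fan-in of each layer of the resulting AddN tree: each pass
// groups `divisor` inputs into one child AddN; leftovers are carried up.
std::vector<size_t> AddnFissionLayers(size_t inputs, size_t divisor) {
  std::vector<size_t> layers{inputs};
  while (inputs > divisor) {
    size_t groups = inputs / divisor;  // full child AddN nodes created this pass
    size_t rest = inputs % divisor;    // inputs forwarded to the next layer unchanged
    inputs = groups + rest;            // fan-in of the next layer's AddN
    layers.push_back(inputs);
  }
  return layers;
}

int main() {
  // 10 inputs with divisor 3: 3 children + 1 leftover = 4, then 1 child + 1 leftover = 2.
  for (size_t n : AddnFissionLayers(10, 3)) std::cout << n << ' ';
  std::cout << '\n';  // prints: 10 4 2
}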
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h
deleted file mode 100644
index 3c62391f9a..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.h
+++ /dev/null
@@ -1,37 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_
-
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-constexpr size_t kAddnInputsDivisor = 63;
-class AddnFission : public PatternProcessPass {
- public:
-  explicit AddnFission(bool multigraph = true)
-      : PatternProcessPass("addn_fission", multigraph), inputs_divisor_(kAddnInputsDivisor) {}
-  ~AddnFission() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  size_t inputs_divisor_;
-};
-} // namespace opt
-} // namespace mindspore
-#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_ADDN_FISSION_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
deleted file mode 100644
index e6a8864e46..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.cc
+++ /dev/null
@@ -1,172 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const std::vector kOutputIndex{0, 3, 4, 5}; -constexpr size_t kBatchNormRealOutputNum = 3; -constexpr size_t kBatchNormRealInputNum = 3; - -bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(bn) == manager->node_users().end()) { - return false; - } - size_t output_num = 0; - for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getiterm_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); - auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) { - return false; - } - bn_outputs->push_back(output); - output_num++; - } - return output_num == kBatchNormRealOutputNum; -} - -AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), bn_cnode->input(1)}; - auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - auto bn_input1 = bn_cnode->input(2); - MS_EXCEPTION_IF_NULL(bn_input1); - auto bn_input2 = bn_cnode->input(3); - MS_EXCEPTION_IF_NULL(bn_input2); - AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input2->abstract()}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_reduce->set_abstract(abstract_tuple); - bn_training_reduce->set_scope(bn->scope()); - AnfAlgo::CopyNodeAttrs(bn, bn_training_reduce); - return bn_training_reduce; -} - -AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, - const std::vector &bn_training_reduce_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - auto bn_cnode = bn->cast(); - MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) { - MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than " - << kBatchNormRealInputNum + 1; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } - std::vector bn_training_update_v2_inputs = { - 
NewValueNode(std::make_shared(kBNTrainingUpdateV2OpName)), - bn_cnode->input(1), - bn_training_reduce_outputs[0], - bn_training_reduce_outputs[1], - bn_cnode->input(2), - bn_cnode->input(3)}; - auto bn_training_update_v2 = func_graph->NewCNode(bn_training_update_v2_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update_v2); - - auto bn_abstract_tuple = dyn_cast(bn->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " - << bn_abstract_tuple->elements().size(); - } - std::vector abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3], - bn_abstract_tuple->elements()[4]}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update_v2->set_abstract(abstract_tuple); - bn_training_update_v2->set_scope(bn->scope()); - AnfAlgo::CopyNodeAttrs(bn, bn_training_update_v2); - return bn_training_update_v2; -} -} // namespace - -const BaseRef BatchNormBertFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimBatchNorm, Xs}); -} - -const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - std::vector bn_outputs; - if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) { - MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed"; - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != kBatchNormRealInputNum + 1) { - MS_LOG(INFO) << "The input size of BatchNorm should be " << kBatchNormRealInputNum - << ". The node should not be changed"; - return nullptr; - } - AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node); - std::vector bn_training_reduce_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, - &bn_training_reduce_outputs); - - AnfNodePtr bn_training_update_v2 = CreateBNTrainingUpdateV2(func_graph, node, bn_training_reduce_outputs); - std::vector bn_training_update_v2_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v2, kBNTrainingUpdateV2OutputNum, - &bn_training_update_v2_outputs); - if (bn_training_update_v2_outputs.size() != kBNTrainingUpdateV2OutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum - << ", but it is " << bn_training_update_v2_outputs.size(); - } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem); - size_t output_index = 0; - for (const auto &output : bn_outputs) { - (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); - output_index++; - } - // Return the new node for control depends. 
- return bn_training_update_v2; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h deleted file mode 100644 index fc214817fc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_bert_fission.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormBertFission : public PatternProcessPass { - public: - explicit BatchNormBertFission(bool multigraph = true) : PatternProcessPass("batch_norm_bert_fission", multigraph) {} - ~BatchNormBertFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_BERT_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc deleted file mode 100644 index 5e41111660..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.cc +++ /dev/null @@ -1,172 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kBatchNormGradInferOutputNum = 3; -bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end()) { - MS_LOG(DEBUG) << "The node " << node->DebugString() << " should have some outputs"; - return false; - } - for (const auto &node_index : manager->node_users()[node]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getiterm_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode); - auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index == kBatchNormGradInferOutputNum || index == kBatchNormGradInferOutputNum + 1) { - MS_LOG(DEBUG) << "The output " << index << " of node " << node->DebugString() << " is not null, no need change"; - return false; - } - } - return true; -} -} // namespace - -AnfNodePtr BatchNormGradInferFission::CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_grad); - MS_EXCEPTION_IF_NULL(equiv); - // Set inputs - auto iter_input0 = (*equiv).find(input0_var_); - if (iter_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; - } - auto iter_input2 = (*equiv).find(input2_var_); - if (iter_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input2 var after matched."; - } - auto iter_input4 = (*equiv).find(input4_var_); - if (iter_input4 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; - } - std::vector bn_infer_grad_inputs = { - NewValueNode(std::make_shared(kBNInferGradOpName)), utils::cast(iter_input0->second), - utils::cast(iter_input2->second), utils::cast(iter_input4->second)}; - auto bn_infer_grad = func_graph->NewCNode(bn_infer_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_infer_grad); - // Set abstract, the output of new node is taking the place of the 0th output of bn_grad. 
- auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); - MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); - if (bn_grad_abstract_tuple->elements().empty()) { - MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be empty"; - } - bn_infer_grad->set_abstract(bn_grad_abstract_tuple->elements()[0]); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_infer_grad); - bn_infer_grad->set_scope(bn_grad->scope()); - return bn_infer_grad; -} - -AnfNodePtr BatchNormGradInferFission::CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, - const AnfNodePtr &bn_grad, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn_grad); - MS_EXCEPTION_IF_NULL(equiv); - // Set inputs - auto iter_input0 = (*equiv).find(input0_var_); - if (iter_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input0 var after matched."; - } - auto iter_input1 = (*equiv).find(input1_var_); - if (iter_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input1 var after matched."; - } - auto iter_input3 = (*equiv).find(input3_var_); - if (iter_input3 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input3 var after matched."; - } - auto iter_input4 = (*equiv).find(input4_var_); - if (iter_input4 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the input4 var after matched."; - } - std::vector bn_training_update_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), - utils::cast(iter_input0->second), utils::cast(iter_input1->second), - utils::cast(iter_input3->second), utils::cast(iter_input4->second)}; - auto bn_training_update_grad = func_graph->NewCNode(bn_training_update_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update_grad); - // Set abstract, the outputs of new node are taking the place of the 1st and 2nd outputs of bn_grad. 
- auto bn_grad_abstract_tuple = dyn_cast(bn_grad->abstract()); - MS_EXCEPTION_IF_NULL(bn_grad_abstract_tuple); - if (bn_grad_abstract_tuple->elements().size() < kBatchNormGradInferOutputNum) { - MS_LOG(EXCEPTION) << "The abstract tuple of node " << bn_grad->DebugString() << "should not be less than 3"; - } - std::vector abstract_list{bn_grad_abstract_tuple->elements()[1], - bn_grad_abstract_tuple->elements()[2]}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update_grad->set_abstract(abstract_tuple); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad, bn_training_update_grad); - bn_training_update_grad->set_scope(bn_grad->scope()); - return bn_training_update_grad; -} - -const BaseRef BatchNormGradInferFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - return VectorRef({prim::kPrimBatchNormGrad, input0_var_, input1_var_, input2_var_, input3_var_, input4_var_, Xs}); -} - -const AnfNodePtr BatchNormGradInferFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, node->cast())) { - MS_LOG(DEBUG) << "The BatchNormGrad " << node->DebugString() << " has no is_training attr, should not be changed"; - return nullptr; - } - if (AnfAlgo::GetNodeAttr(node, kAttrIsTraining)) { - MS_LOG(DEBUG) << "The is_training attr value of " << node->DebugString() << " is true, no need change"; - return nullptr; - } - if (!CheckOutputsIndex(func_graph, node)) { - MS_LOG(DEBUG) << "The output 3 or 4 of BatchNormGrad is not null, no need change"; - return nullptr; - } - AnfNodePtr bn_infer_grad = CreateBNInferGrad(func_graph, node, equiv); - AnfNodePtr bn_training_update_grad = CreateBNTrainingUpdateGrad(func_graph, node, equiv); - std::vector bn_training_update_grad_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_grad, kBNTrainingUpdateGradOutputNum, - &bn_training_update_grad_outputs); - if (bn_training_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "The output size of " << bn_training_update_grad << " should be " - << kBNTrainingUpdateGradOutputNum << ", but it is " << bn_training_update_grad_outputs.size(); - } - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_infer_grad, - bn_training_update_grad_outputs[0], bn_training_update_grad_outputs[1]}; - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h deleted file mode 100644 index a8eefdaa85..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
-
-#include <memory>
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class BatchNormGradInferFission : public PatternProcessPass {
- public:
-  explicit BatchNormGradInferFission(bool multigraph = true)
-      : PatternProcessPass("batch_norm_grad_infer_fission", multigraph),
-        input0_var_(std::make_shared<Var>()),
-        input1_var_(std::make_shared<Var>()),
-        input2_var_(std::make_shared<Var>()),
-        input3_var_(std::make_shared<Var>()),
-        input4_var_(std::make_shared<Var>()) {}
-  ~BatchNormGradInferFission() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  AnfNodePtr CreateBNInferGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad, const EquivPtr &equiv) const;
-  AnfNodePtr CreateBNTrainingUpdateGrad(const FuncGraphPtr &func_graph, const AnfNodePtr &bn_grad,
-                                        const EquivPtr &equiv) const;
-
-  VarPtr input0_var_;
-  VarPtr input1_var_;
-  VarPtr input2_var_;
-  VarPtr input3_var_;
-  VarPtr input4_var_;
-};
-} // namespace opt
-} // namespace mindspore
-#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_INFER_FISSION_H_
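BatchNormGradInferFission above illustrates the var-capture idiom: DefinePattern() binds each BatchNormGrad input slot to a VarPtr member, and Process() pulls the concrete nodes back out of the equiv map. Condensed, the lookup reduces to the helper below (GetEquiv is an illustrative name for this note, not a function in the patch; the real code raises MS_LOG(EXCEPTION) when a var is missing instead of returning nullptr):

#include <memory>
#include "pre_activate/common/optimizer.h"

namespace mindspore {
namespace opt {
// Fetches the AnfNode bound to `var` during pattern matching, or nullptr.
inline AnfNodePtr GetEquiv(const EquivPtr &equiv, const VarPtr &var) {
  auto it = (*equiv).find(var);
  if (it == (*equiv).end()) {
    return nullptr;
  }
  return utils::cast<AnfNodePtr>(it->second);
}
} // namespace opt
} // namespace mindspore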
- */ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_split.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - std::vector *bn_update_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - std::vector bn_update_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], - bn_grad_inputs[4], bn_grad_inputs[5]}; - auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_update_grad); - bn_update_grad->set_kernel_info(std::make_shared()); - bn_update_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad); - CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs); -} - -void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node, - const std::vector &bn_update_grad_outputs, - std::vector *bn_reduce_grad_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_grad_node); - auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size"; - } - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size"; - } - std::vector bn_reduce_grad_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceGradOpName)), - bn_grad_inputs[1], - bn_grad_inputs[2], - bn_update_grad_outputs[0], - bn_update_grad_outputs[1], - bn_grad_inputs[3], - bn_grad_inputs[4], - bn_grad_inputs[5]}; - auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs); - MS_EXCEPTION_IF_NULL(bn_reduce_grad); - bn_reduce_grad->set_kernel_info(std::make_shared()); - bn_reduce_grad->set_scope(bn_grad_node->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get()); - - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad); - (*bn_reduce_grad_outputs).push_back(bn_reduce_grad); -} -} // namespace -const BaseRef BatchNormGradSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto prim = std::make_shared(kBatchNormGradOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr BatchNormGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - if 
(!primitive->HasAttr(kAttrIsTraining)) { - MS_LOG(INFO) << "Op BatchNormGrad must have attrs of is_training"; - return nullptr; - } - if (!AnfAlgo::GetNodeAttr(cnode, kAttrIsTraining)) { - MS_LOG(INFO) << "is_training must be true"; - return nullptr; - } - - std::vector bn_update_grad_outputs; - CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs); - if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { - MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; - } - - std::vector bn_reduce_grad_outputs; - CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs); - if (bn_reduce_grad_outputs.size() != kSingleOutputNum) { - MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size"; - } - - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0], - bn_update_grad_outputs[0], bn_update_grad_outputs[1]}; - auto make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h deleted file mode 100644 index e539fdb27c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -class BatchNormGradSplit : public PatternProcessPass { - public: - explicit BatchNormGradSplit(bool multigraph = true) : PatternProcessPass("batch_norm_grad_split", multigraph) {} - ~BatchNormGradSplit() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BATCH_NORM_GRAD_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc deleted file mode 100644 index 6282ed4f76..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc
deleted file mode 100644
index 6282ed4f76..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.cc
+++ /dev/null
@@ -1,123 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/bn_grad_split.h"
-
-#include <vector>
-#include <memory>
-#include <string>
-
-#include "utils/utils.h"
-#include "utils/context/ms_context.h"
-#include "common/utils.h"
-#include "pre_activate/common/helper.h"
-#include "device/kernel_info.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               std::vector<AnfNodePtr> *bn_update_grad_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(bn_grad_node);
-  auto bn_grad_inputs = bn_grad_node->inputs();
-  if (bn_grad_inputs.size() != kBNGradInputNum) {
-    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
-  }
-  std::vector<AnfNodePtr> bn_update_grad_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2],
-    bn_grad_inputs[4], bn_grad_inputs[5]};
-  auto bn_update_grad = graph->NewCNode(bn_update_grad_inputs);
-  MS_EXCEPTION_IF_NULL(bn_update_grad);
-  bn_update_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
-  bn_update_grad->set_scope(bn_grad_node->scope());
-
-  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 1), AnfAlgo::GetOutputInferDataType(bn_grad_node, 2)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 1), AnfAlgo::GetOutputInferShape(bn_grad_node, 2)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_update_grad.get());
-
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_update_grad);
-  CreateMultipleOutputsOfAnfNode(graph, bn_update_grad, kBNTrainingUpdateGradOutputNum, bn_update_grad_outputs);
-}
-
-void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node,
-                               const std::vector<AnfNodePtr> &bn_update_grad_outputs,
-                               std::vector<AnfNodePtr> *bn_reduce_grad_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(bn_grad_node);
-  auto bn_grad_inputs = bn_grad_node->inputs();
-  if (bn_grad_inputs.size() != kBNGradInputNum) {
-    MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size";
-  }
-  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
-    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
-  }
-  std::vector<AnfNodePtr> bn_reduce_grad_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceGradOpName)),
-    bn_grad_inputs[1],
-    bn_grad_inputs[2],
-    bn_update_grad_outputs[0],
-    bn_update_grad_outputs[1],
-    bn_grad_inputs[3],
-    bn_grad_inputs[4],
-    bn_grad_inputs[5]};
-  auto bn_reduce_grad = graph->NewCNode(bn_reduce_grad_inputs);
-  MS_EXCEPTION_IF_NULL(bn_reduce_grad);
-  bn_reduce_grad->set_kernel_info(std::make_shared<device::KernelInfo>());
-  bn_reduce_grad->set_scope(bn_grad_node->scope());
-
-  auto types = {AnfAlgo::GetOutputInferDataType(bn_grad_node, 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(bn_grad_node, 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_reduce_grad.get());
-
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
-  (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
-}
-
-CNodePtr BNGradSplitForTBE(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  std::vector<AnfNodePtr> bn_update_grad_outputs;
-  CreateOutputsOfUpdateGrad(func_graph, cnode, &bn_update_grad_outputs);
-  if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) {
-    MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size";
-  }
-
-  std::vector<AnfNodePtr> bn_reduce_grad_outputs;
-  CreateOutputsOfReduceGrad(func_graph, cnode, bn_update_grad_outputs, &bn_reduce_grad_outputs);
-  if (bn_reduce_grad_outputs.size() != 1) {
-    MS_LOG(EXCEPTION) << "bn_reduce_grad_outputs has wrong size";
-  }
-
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), bn_reduce_grad_outputs[0],
-                                               bn_update_grad_outputs[0], bn_update_grad_outputs[1]};
-  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
-  MS_EXCEPTION_IF_NULL(make_tuple);
-  return make_tuple;
-}
-}  // namespace
-
-const BaseRef BnGradSplit::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimFusedBatchNormGrad, Xs});
-}
-
-const AnfNodePtr BnGradSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  return BNGradSplitForTBE(func_graph, cnode);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h
deleted file mode 100644
index 17e1f9b98e..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_grad_split.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
-
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-class BnGradSplit : public PatternProcessPass {
- public:
-  explicit BnGradSplit(bool multigraph = true) : PatternProcessPass("bn_grad_split", multigraph) {}
-  ~BnGradSplit() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_GRAD_SPLIT_H_
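BnGradSplit builds the same two-node replacement as BatchNormGradSplit but matches FusedBatchNormGrad, skips the is_training gate, and uses stricter arity checks. A minimal sketch of just that guard difference, with an assumed constant value (the real kBNGradInputNum lives in utils/utils.h):

#include <cstddef>
#include <stdexcept>
#include <vector>

constexpr std::size_t kBNGradInputNum = 6;  // assumed: primitive + five data inputs

// BatchNormGradSplit accepts any node with at least the required inputs.
void CheckAtLeast(const std::vector<int> &inputs) {
  if (inputs.size() < kBNGradInputNum) throw std::runtime_error("BNGrad has wrong inputs size");
}

// BnGradSplit requires the exact arity, so extra monitor inputs are rejected.
void CheckExact(const std::vector<int> &inputs) {
  if (inputs.size() != kBNGradInputNum) throw std::runtime_error("BNGrad has wrong inputs size");
}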
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc
deleted file mode 100644
index 66ffa24bf1..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.cc
+++ /dev/null
@@ -1,132 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/bn_split.h"
-
-#include <vector>
-#include <memory>
-#include <string>
-
-#include "utils/utils.h"
-#include "utils/context/ms_context.h"
-#include "pre_activate/common/helper.h"
-#include "device/kernel_info.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
-                                     std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(bn_cnode);
-  if (bn_cnode->inputs().size() != kBnInputNum) {
-    MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
-    return false;
-  }
-  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
-  bn_training_reduce_inputs.push_back(bn_cnode->input(1));
-  auto bn_training_reduce = graph->NewCNode(bn_training_reduce_inputs);
-  MS_EXCEPTION_IF_NULL(bn_training_reduce);
-  auto kernel_info = std::make_shared<device::KernelInfo>();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  bn_training_reduce->set_kernel_info(kernel_info);
-  std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
-  if (bn_shape_i0.size() < kShape2dDims) {
-    MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
-    return false;
-  }
-  std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
-  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
-  auto shapes = {bn_training_reduce_shape, bn_training_reduce_shape};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, bn_training_reduce.get());
-  bn_training_reduce->set_scope(bn_cnode->scope());
-  AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);
-
-  CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
-  return true;
-}
-
-AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
-                                           const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(bn_cnode);
-  if (bn_cnode->inputs().size() != kBnInputNum) {
-    MS_LOG(EXCEPTION) << "BN node has wrong input size";
-  }
-  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
-    MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size";
-  }
-  // the inputs of BNTrainingUpdate are from the outputs of BNTrainingReduce and the inputs of BN
-  std::vector<AnfNodePtr> bn_training_update_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateOpName))};
-  bn_training_update_inputs.push_back(bn_cnode->input(1));
-  bn_training_update_inputs.push_back(bn_training_reduce_outputs[0]);
-  bn_training_update_inputs.push_back(bn_training_reduce_outputs[1]);
-  bn_training_update_inputs.push_back(bn_cnode->input(2));
-  bn_training_update_inputs.push_back(bn_cnode->input(3));
-  bn_training_update_inputs.push_back(bn_cnode->input(4));
-  bn_training_update_inputs.push_back(bn_cnode->input(5));
-  auto bn_training_update = graph->NewCNode(bn_training_update_inputs);
-  MS_EXCEPTION_IF_NULL(bn_training_update);
-  auto kernel_info = std::make_shared<device::KernelInfo>();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  bn_training_update->set_kernel_info(kernel_info);
-  bn_training_update->set_abstract(bn_cnode->abstract());
-  bn_training_update->set_scope(bn_cnode->scope());
-  auto factor = AnfAlgo::GetNodeAttr<float>(bn_cnode, kAttrMomentum);
-  AnfAlgo::SetNodeAttr(kAttrFactor, MakeValue(factor), bn_training_update);
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update);
-  AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update);
-  return bn_training_update;
-}
-
-AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().size() < kBnInputNum) {
-    MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
-    return nullptr;
-  }
-  // Create BNTrainingReduce node and get outputs of BNTrainingReduce
-  std::vector<AnfNodePtr> bn_training_reduce_outputs;
-  if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
-    MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
-    return nullptr;
-  }
-  if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
-    MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail";
-  }
-
-  // Create BNTrainingUpdate node
-  return CreateOutputsOfBNTrainingUpdate(func_graph, cnode, bn_training_reduce_outputs);
-}
-}  // namespace
-
-const BaseRef BnSplit::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  MS_EXCEPTION_IF_NULL(Xs);
-  return VectorRef({prim::kPrimFusedBatchNorm, Xs});
-}
-
-const AnfNodePtr BnSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
-  return SplitFusedBatchNormForTBE(func_graph, node);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h
deleted file mode 100644
index bc5975af17..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/bn_split.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
-
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-class BnSplit : public PatternProcessPass {
- public:
-  explicit BnSplit(bool multigraph = true) : PatternProcessPass("bn_split", multigraph) {}
-  ~BnSplit() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_BN_SPLIT_H_
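In BnSplit above, BNTrainingReduce's two outputs are both float32 tensors over the channel axis: the pass takes element 1 of the input's inferred shape, i.e. C for NCHW layout. A tiny standalone illustration of that shape derivation (example dimensions only):

#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> bn_shape_i0 = {32, 64, 7, 7};     // assumed NCHW input: N=32, C=64
  std::vector<size_t> reduce_shape = {bn_shape_i0[1]};  // both reduce outputs are shape {C}
  std::cout << "BNTrainingReduce outputs two tensors of shape {" << reduce_shape[0] << "}\n";
  return 0;
}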
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc
deleted file mode 100644
index 479e00e4c0..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.cc
+++ /dev/null
@@ -1,91 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/lars_v2_fission.h"
-#include <memory>
-#include <vector>
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/helper.h"
-#include "utils/utils.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
-                                 std::vector<AnfNodePtr> *square_sum_all_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(lars_v2);
-  if (lars_v2->size() != kLarsV2InputNum) {
-    MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum;
-  }
-
-  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kSquareSumAllOpName))};
-  inputs.push_back(lars_v2->input(1));
-  inputs.push_back(lars_v2->input(2));
-  auto square_sum_all = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(square_sum_all);
-  square_sum_all->set_scope(lars_v2->scope());
-
-  auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
-  std::vector<size_t> shape;
-  auto shapes = {shape, shape};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sum_all.get());
-
-  CreateMultipleOutputsOfAnfNode(graph, square_sum_all, 2, square_sum_all_outputs);
-}
-
-CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2,
-                            const std::vector<AnfNodePtr> &square_sum_all_outputs) {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(lars_v2);
-  if (square_sum_all_outputs.size() != 2) {
-    MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2";
-  }
-  if (lars_v2->size() != kLarsV2InputNum) {
-    MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum;
-  }
-  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kLarsV2UpdateOpName))};
-  inputs.push_back(lars_v2->input(1));
-  inputs.push_back(lars_v2->input(2));
-  inputs.push_back(square_sum_all_outputs[0]);
-  inputs.push_back(square_sum_all_outputs[1]);
-  inputs.push_back(lars_v2->input(3));
-  inputs.push_back(lars_v2->input(4));
-  auto lars_v2_update = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(lars_v2_update);
-  lars_v2_update->set_scope(lars_v2->scope());
-  lars_v2_update->set_abstract(lars_v2->abstract());
-  return lars_v2_update;
-}
-}  // namespace
-
-const BaseRef LarsV2Fission::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  auto lars_v2_prim = std::make_shared<Primitive>(kLarsV2OpName);
-  return VectorRef({lars_v2_prim, Xs});
-}
-
-const AnfNodePtr LarsV2Fission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto lars_v2 = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(lars_v2);
-
-  std::vector<AnfNodePtr> square_sum_all_outputs;
-  CreateOutputsOfSquareSumAll(graph, lars_v2, &square_sum_all_outputs);
-  return CreateLarsV2Update(graph, lars_v2, square_sum_all_outputs);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h
deleted file mode 100644
index 846d221c53..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/lars_v2_fission.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
-
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class LarsV2Fission : public PatternProcessPass {
- public:
-  explicit LarsV2Fission(bool multigraph = true) : PatternProcessPass("lars_v2_fission", multigraph) {}
-  ~LarsV2Fission() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_LARS_V2_FISSION_H_
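LarsV2Fission splits LarsV2 into a SquareSumAll (two tensors in, two scalar sums of squares out, inferred above as empty-shape float32) feeding LarsV2Update. A standalone model of SquareSumAll's contract, with assumed tensor roles (weight, gradient) since the diff only shows input slot numbers:

#include <iostream>
#include <numeric>
#include <vector>

double SquareSum(const std::vector<double> &t) {
  // Scalar sum of squares, matching the empty output shape set by the pass.
  return std::accumulate(t.begin(), t.end(), 0.0, [](double acc, double v) { return acc + v * v; });
}

int main() {
  std::vector<double> weight = {1.0, 2.0}, grad = {0.5, 0.5};
  // LarsV2Update then sees (weight, grad, square_sum(weight), square_sum(grad),
  // plus the two remaining LarsV2 inputs, presumably weight_decay and learning rate).
  std::cout << SquareSum(weight) << " " << SquareSum(grad) << "\n";  // 5 0.5
  return 0;
}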
- */ -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -void LayerNormGradSplit::CreateOutputsOfLayerNormXBackprop( - const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_x_backprop_outputs) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(layer_norm_grad); - auto prim = std::make_shared(kLayerNormXBackpropOpName); - std::vector layer_norm_x_backprop_inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < layer_norm_grad->inputs().size(); ++i) { - layer_norm_x_backprop_inputs.push_back(layer_norm_grad->input(i)); - } - auto layer_norm_x_backprop = graph->NewCNode(layer_norm_x_backprop_inputs); - MS_EXCEPTION_IF_NULL(layer_norm_x_backprop); - layer_norm_x_backprop->set_scope(layer_norm_grad->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_x_backprop.get()); - - (*layer_norm_x_backprop_outputs).push_back(layer_norm_x_backprop); -} - -void LayerNormGradSplit::CreateOutputsOfLayerNormBetaGammaBackprop( - const FuncGraphPtr &graph, const CNodePtr &layer_norm_grad, - std::vector *layer_norm_beta_gamma_backprop_outputs) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(layer_norm_grad); - auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); - std::vector layer_norm_beta_gamma_backprop_inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < layer_norm_grad->inputs().size() - 1; ++i) { - layer_norm_beta_gamma_backprop_inputs.push_back(layer_norm_grad->input(i)); - } - auto layer_norm_beta_gamma_backprop = graph->NewCNode(layer_norm_beta_gamma_backprop_inputs); - MS_EXCEPTION_IF_NULL(layer_norm_beta_gamma_backprop); - auto kernel_info = std::make_shared(); - layer_norm_beta_gamma_backprop->set_kernel_info(kernel_info); - layer_norm_beta_gamma_backprop->set_scope(layer_norm_grad->scope()); - - auto types = {AnfAlgo::GetOutputInferDataType(layer_norm_grad, 1), - AnfAlgo::GetOutputInferDataType(layer_norm_grad, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(layer_norm_grad, 1), AnfAlgo::GetOutputInferShape(layer_norm_grad, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, layer_norm_beta_gamma_backprop.get()); - - // get device shape of LayerNormGrad's 5th Input, and convert it to attr - std::vector shape_gamma = AnfAlgo::GetPrevNodeOutputInferShape(layer_norm_grad, 4); - AnfAlgo::SetNodeAttr(kAttrShapeGamma, MakeValue(opt::Convert2Int(shape_gamma)), layer_norm_beta_gamma_backprop); - - CreateMultipleOutputsOfAnfNode(graph, layer_norm_beta_gamma_backprop, kLayerNormBetaGammaBackpropOutputNum, - layer_norm_beta_gamma_backprop_outputs); -} - -const BaseRef LayerNormGradSplit::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VectorRef pattern({prim::kPrimLayerNormGrad, Xs}); - return pattern; -} - -const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode->inputs().size() != kLayerNormGradInputNum) { - return nullptr; - } - - // create layer_norm_x_backprop - std::vector layer_norm_x_backprop_outputs; - 
CreateOutputsOfLayerNormXBackprop(graph, cnode, &layer_norm_x_backprop_outputs); - if (layer_norm_x_backprop_outputs.size() != kSingleOutputNum) { - MS_LOG(EXCEPTION) << "layer_norm_grad_outputs has wrong size"; - } - - // create layer_norm_beta_gamma_backprop - std::vector layer_norm_beta_gamma_backprop_outputs; - CreateOutputsOfLayerNormBetaGammaBackprop(graph, cnode, &layer_norm_beta_gamma_backprop_outputs); - if (layer_norm_beta_gamma_backprop_outputs.size() != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(EXCEPTION) << "layer_norm_beta_gamma_outputs has wrong size"; - } - - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple), layer_norm_x_backprop_outputs[0], - layer_norm_beta_gamma_backprop_outputs[0], - layer_norm_beta_gamma_backprop_outputs[1]}; - auto make_tuple = graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - return make_tuple; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h deleted file mode 100644 index f442446b01..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/layer_norm_grad_split.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
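One detail of LayerNormGradSplit worth calling out: the device shape of gamma is not passed as a tensor but converted to an integer attribute (kAttrShapeGamma) via opt::Convert2Int. A standalone sketch of that size_t-to-int narrowing (the helper's real body lives in pre_activate/common/helper.h; this mirrors only its intent):

#include <iostream>
#include <vector>

std::vector<int> Convert2Int(const std::vector<size_t> &v) {
  // Narrow each device dimension to int so it can be stored as a node attribute.
  return std::vector<int>(v.begin(), v.end());
}

int main() {
  std::vector<size_t> shape_gamma = {768};  // assumed hidden size, illustration only
  for (int d : Convert2Int(shape_gamma)) std::cout << d << "\n";
  return 0;
}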
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
deleted file mode 100644
index 159be2ac3b..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h"
-#include <vector>
-#include <memory>
-#include <string>
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-constexpr size_t kBatchNormRealInputNum = 3;
-
-AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(bn);
-  auto bn_cnode = bn->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(bn_cnode);
-  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
-    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
-                      << kBatchNormRealInputNum + 1;
-  }
-  std::vector<AnfNodePtr> bn_training_reduce_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
-  auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
-  MS_EXCEPTION_IF_NULL(bn_training_reduce);
-
-  // set abstract
-  auto bn_input1 = bn_cnode->input(2);
-  MS_EXCEPTION_IF_NULL(bn_input1);
-  AbstractBasePtrList abstract_list{bn_input1->abstract(), bn_input1->abstract()};
-  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
-  bn_training_reduce->set_abstract(abstract_tuple);
-  bn_training_reduce->set_scope(bn->scope());
-  return bn_training_reduce;
-}
-
-AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNodePtr &bn,
-                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(bn);
-  auto bn_cnode = bn->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(bn_cnode);
-  if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
-    MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
-                      << kBatchNormRealInputNum + 1;
-  }
-  if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
-    MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
-                      << ", but it is " << bn_training_reduce_outputs.size();
-  }
-  std::vector<AnfNodePtr> bn_training_update_v3_inputs = {
-    NewValueNode(std::make_shared<Primitive>(kBNTrainingUpdateV3OpName)),
-    bn_cnode->input(1),
-    bn_training_reduce_outputs[0],
-    bn_training_reduce_outputs[1],
-    bn_cnode->input(2),
-    bn_cnode->input(3)};
-  auto bn_training_update_v3 = func_graph->NewCNode(bn_training_update_v3_inputs);
-  MS_EXCEPTION_IF_NULL(bn_training_update_v3);
-
-  auto bn_abstract_tuple = dyn_cast<abstract::AbstractTuple>(bn->abstract());
-  MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
-  if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) {
-    MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is "
-                      << bn_abstract_tuple->elements().size();
-  }
-  bn_training_update_v3->set_abstract(bn->abstract());
-  bn_training_update_v3->set_scope(bn->scope());
-  AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_cnode, bn_training_update_v3);
-  return bn_training_update_v3;
-}
-}  // namespace
-
-const BaseRef SingleBatchNormFission::DefinePattern() const {
-  VarPtr Xs = std::make_shared<SeqVar>();
-  return VectorRef({prim::kPrimBatchNorm, Xs});
-}
-
-const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                 const EquivPtr &) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() < kBatchNormRealInputNum + 1) {
-    MS_LOG(INFO) << "The input num of BatchNorm less than" << kBatchNormRealInputNum
-                 << ". The node should not be changed";
-    return nullptr;
-  }
-  if (!GetBoolAttr(cnode, kAttrIsTraining)) {
-    MS_LOG(INFO) << "is training should be true if do fusion";
-    return nullptr;
-  }
-  AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
-  std::vector<AnfNodePtr> bn_training_reduce_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
-                                 &bn_training_reduce_outputs);
-
-  return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h
deleted file mode 100644
index 145603132b..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/single_batch_norm_fission.h
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
-
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class SingleBatchNormFission : public PatternProcessPass {
- public:
-  explicit SingleBatchNormFission(bool multigraph = true)
-      : PatternProcessPass("single_batch_norm_fission", multigraph) {}
-  ~SingleBatchNormFission() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SINGLE_BATCH_NORM_FISSION_H_
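A standalone sketch of the wiring SingleBatchNormFission builds: BatchNorm's first data input feeds BNTrainingReduce, whose two outputs are spliced between x and the scale/offset inputs of BNTrainingUpdateV3, which then inherits BatchNorm's whole output tuple. The input names are assumed for illustration; the diff only guarantees the slot numbers.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Assumed BatchNorm real inputs (kBatchNormRealInputNum == 3): x, scale, offset.
  std::vector<std::string> bn = {"BatchNorm", "x", "scale", "offset"};
  std::vector<std::string> reduce = {"BNTrainingReduce", bn[1]};  // two outputs: sum, square_sum
  std::vector<std::string> update_v3 = {"BNTrainingUpdateV3", bn[1], "sum", "square_sum", bn[2], bn[3]};
  std::cout << update_v3[0] << " takes " << update_v3.size() - 1 << " inputs\n";
  return 0;
}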
- */ -#include "pre_activate/ascend/ir_fission/split_fission.h" -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(input_node); - std::vector splitv_inputs{NewValueNode(std::make_shared(kSplitVOpName)), input_node}; - CNodePtr splitv = func_graph->NewCNode(splitv_inputs); - MS_EXCEPTION_IF_NULL(splitv); - splitv->set_scope(input_node->scope()); - return splitv; -} - -CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { - MS_EXCEPTION_IF_NULL(origin_cnode); - if (origin_cnode->inputs().size() < kSplitInputNum) { - MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " - << kSplitInputNum - 1; - } - return CreateSplitVNode(func_graph, origin_cnode->input(1)); -} - -void SetAttrForSplitVNode(const AnfNodePtr &splitv, const std::vector &size_splits, int split_dim, int num_split) { - AnfAlgo::SetNodeAttr(kAttrSizeSplits, MakeValue(size_splits), splitv); - AnfAlgo::SetNodeAttr(kAttrSplitDim, MakeValue(split_dim), splitv); - AnfAlgo::SetNodeAttr(kAttrNumSplit, MakeValue(num_split), splitv); -} - -size_t GetSmallSplitSize(const AnfNodePtr &split_node, int split_dim, int num_split) { - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(split_node, 0); - if (split_dim < 0) { - split_dim += input_shape.size(); - } - if (IntToSize(split_dim) >= input_shape.size()) { - MS_LOG(EXCEPTION) << "The split_dim value should be less than the shape size of input 0"; - } - return input_shape[split_dim] / num_split; -} - -void AddNewOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &new_splitv, int outputs_num, - std::vector *inputs) { - MS_EXCEPTION_IF_NULL(inputs); - std::vector new_splitv_output; - CreateMultipleOutputsOfAnfNode(func_graph, new_splitv, outputs_num, &new_splitv_output); - inputs->insert(inputs->end(), new_splitv_output.begin(), new_splitv_output.end()); -} - -AnfNodePtr CreateTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input, size_t index) { - MS_EXCEPTION_IF_NULL(func_graph); - auto idx = NewValueNode(SizeToInt(index)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(index)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, idx}); - return tuple_getitem; -} - -void CreateOutputShapeAndTypeId(const CNodePtr &origin_cnode, int split_dim, int split_size, int num_split, - std::vector *new_type_ids, - std::vector> *new_output_shapes) { - MS_EXCEPTION_IF_NULL(new_type_ids); - MS_EXCEPTION_IF_NULL(new_output_shapes); - auto output_shape = AnfAlgo::GetOutputInferShape(origin_cnode, 0); - if (split_dim < 0) { - split_dim += output_shape.size(); - } - output_shape[split_dim] = split_size; - TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); - for (int i = 0; i < num_split; ++i) { - new_type_ids->emplace_back(type_id); - new_output_shapes->emplace_back(output_shape); - } -} - -void SetAttrAndAbstractForBaseSplitv(const CNodePtr &origin_cnode, const CNodePtr &base_splitv, - const std::vector &size_splits_base, int split_dim, int num_split) { - SetAttrForSplitVNode(base_splitv, size_splits_base, split_dim, num_split); - std::vector base_type_ids; - std::vector> base_output_shapes_base; - auto output_shape = 
AnfAlgo::GetOutputInferShape(origin_cnode, 0); - TypeId type_id = AnfAlgo::GetOutputInferDataType(origin_cnode, 0); - if (split_dim < 0) { - split_dim += output_shape.size(); - } - for (int i = 0; i < num_split; ++i) { - output_shape[split_dim] = size_splits_base[i]; - base_output_shapes_base.emplace_back(output_shape); - base_type_ids.emplace_back(type_id); - } - AnfAlgo::SetOutputInferTypeAndShape(base_type_ids, base_output_shapes_base, base_splitv.get()); -} - -AnfNodePtr DoFission(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int num_split, int divisor) { - MS_EXCEPTION_IF_NULL(func_graph); - auto split_dim = AnfAlgo::GetNodeAttr(cnode, kAttrAxis); - CNodePtr base_splitv = CreateBaseSplitVNode(func_graph, cnode); - - // Create new size_splits for "size_splits" attr of each new Splitv node which has full inputs. - auto small_split_size = SizeToInt(GetSmallSplitSize(cnode, split_dim, num_split)); - std::vector size_splits_new; - for (int i = 0; i < divisor; ++i) { - size_splits_new.emplace_back(small_split_size); - } - // Create new output shape and new output type id for each new Splitv node which has full inputs. - std::vector new_type_ids; - std::vector> new_output_shapes; - CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, divisor, &new_type_ids, &new_output_shapes); - - // Create make_tuple input to create a make_tuple for replacing the old Split node. - std::vector make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; - // Start to divide the outputs of Split. - std::vector size_splits_base; - const auto base_split_size = divisor * small_split_size; - int nodes_num = 0; - int cur_output_index = 0; - while (num_split - cur_output_index > divisor) { - CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - SetAttrForSplitVNode(new_splitv, size_splits_new, split_dim, divisor); - AnfAlgo::SetOutputInferTypeAndShape(new_type_ids, new_output_shapes, new_splitv.get()); - AddNewOutputs(func_graph, new_splitv, divisor, &make_tuple_inputs); - cur_output_index += divisor; - size_splits_base.emplace_back(base_split_size); - nodes_num++; - } - if (cur_output_index < num_split) { - auto last_node_num_split = num_split - cur_output_index; - if (last_node_num_split > 1) { - CNodePtr new_splitv = CreateSplitVNode(func_graph, CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - std::vector size_splits_new_last; - for (int i = 0; i < last_node_num_split; ++i) { - size_splits_new_last.emplace_back(small_split_size); - } - SetAttrForSplitVNode(new_splitv, size_splits_new_last, split_dim, last_node_num_split); - // Create new output shape and new output type id for the last Splitv node - std::vector last_new_type_ids; - std::vector> last_new_output_shapes; - CreateOutputShapeAndTypeId(cnode, split_dim, small_split_size, last_node_num_split, &last_new_type_ids, - &last_new_output_shapes); - AnfAlgo::SetOutputInferTypeAndShape(last_new_type_ids, last_new_output_shapes, new_splitv.get()); - AddNewOutputs(func_graph, new_splitv, last_node_num_split, &make_tuple_inputs); - size_splits_base.emplace_back(last_node_num_split * small_split_size); - } else { - make_tuple_inputs.emplace_back(CreateTupleGetItem(func_graph, base_splitv, nodes_num)); - size_splits_base.emplace_back(small_split_size); - } - nodes_num++; - } - // Set Attr and abstract for the base splitv - SetAttrAndAbstractForBaseSplitv(cnode, base_splitv, size_splits_base, split_dim, nodes_num); - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - 
return make_tuple; -} -} // namespace - -const BaseRef SplitFission::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto split_prim = std::make_shared(kSplitOpName); - return VectorRef({split_prim, Xs}); -} - -const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // Check output num - if (!AnfAlgo::HasNodeAttr(kAttrOutputNum, cnode)) { - return nullptr; - } - auto num_split = AnfAlgo::GetNodeAttr(cnode, kAttrOutputNum); - if (num_split <= outputs_divisor_) { - return nullptr; - } - return DoFission(func_graph, cnode, num_split, outputs_divisor_); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h deleted file mode 100644 index c2763bb714..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/split_fission.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -constexpr int kSplitOutputsDivisor = 63; -class SplitFission : public PatternProcessPass { - public: - explicit SplitFission(bool multigraph = true) - : PatternProcessPass("split_fission", multigraph), outputs_divisor_(kSplitOutputsDivisor) {} - ~SplitFission() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - int outputs_divisor_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_SPLIT_FISSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc deleted file mode 100644 index c8477353f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ /dev/null @@ -1,182 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
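The arithmetic in DoFission above is easy to check in isolation: a Split with more than 63 outputs becomes a base SplitV whose chunks hold at most `divisor` small slices each, plus second-level SplitV nodes per full chunk. A runnable model of just that chunking (example values only):

#include <iostream>
#include <vector>

int main() {
  const int num_split = 130, divisor = 63, small_split_size = 4;  // illustrative values
  std::vector<int> size_splits_base;
  int cur = 0;
  while (num_split - cur > divisor) {
    size_splits_base.push_back(divisor * small_split_size);  // a full second-level SplitV
    cur += divisor;
  }
  // Remainder chunk: a smaller SplitV if it covers >1 output, else a direct tuple item.
  size_splits_base.push_back((num_split - cur) * small_split_size);
  for (int s : size_splits_base) std::cout << s << "\n";  // prints 252, 252, 16
  return 0;
}

The 63-output cap mirrors kSplitOutputsDivisor in the header above, which presumably tracks a TBE SplitV operator limit.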
- */ -#include "pre_activate/ascend/ir_fission/topk_split.h" -#include -#include -#include -#include -#include "pre_activate/common/helper.h" -#include "kernel/kernel_build_info.h" -#include "utils/utils.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -constexpr size_t kFloat16Len = 2; // size of float16; -constexpr size_t kTopkIndexK = 1; -namespace { -tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { - // 1 create tensor - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - auto last_dim = shape[shape.size() - 1]; - std::vector indices_shape = {SizeToInt(last_dim * 2)}; - TensorTypePtr tensor_type = std::make_shared(kFloat16); - MS_EXCEPTION_IF_NULL(tensor_type); - tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type}; - tensor::TensorPtr indices_tensor = std::make_shared(kFloat16->type_id(), indices_shape); - MS_EXCEPTION_IF_NULL(indices_tensor); - indices_tensor->set_device_info(device_info); - - // 2 set value of tensor - auto data_ptr = indices_tensor->data_c(); - MS_EXCEPTION_IF_NULL(data_ptr); - std::vector half_data; - for (size_t i = 0; i < last_dim; ++i) { - half_data.emplace_back(Eigen::half(static_cast(i))); - } - for (size_t i = 0; i < last_dim; ++i) { - auto gap = static_cast(i) - static_cast(Eigen::half(static_cast(i))); - half_data.emplace_back(Eigen::half(static_cast(gap))); - } - auto elem_num = last_dim * kFloat16Len * 2; - auto ret_code = memcpy_s(data_ptr, static_cast(indices_tensor->data().nbytes()), half_data.data(), elem_num); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data into Tensor."; - return nullptr; - } - return indices_tensor; -} - -ValueNodePtr CreateValueNode(const AnfNodePtr &node) { - tensor::TensorPtr indices_tensor = CreateTensor(node); - MS_EXCEPTION_IF_NULL(indices_tensor); - auto indices_const = std::make_shared(indices_tensor); - MS_EXCEPTION_IF_NULL(indices_const); - auto indices_abstract = indices_tensor->ToAbstract(); - indices_const->set_abstract(indices_abstract); - auto indices_kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(indices_kernel_info); - indices_const->set_kernel_info(indices_kernel_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1; - builder1.SetOutputsFormat({kOpFormat_DEFAULT}); - builder1.SetOutputsDeviceType({kNumberTypeFloat16}); - AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); - return indices_const; -} - -kernel::KernelBuildInfoPtr CreateKernelBuildInfo() { - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetKernelType(TBE_KERNEL); - builder.SetFusionType(kernel::OPAQUE); - builder.SetProcessor(kernel::AICORE); - builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); - builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); - builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); - return builder.Build(); -} - -bool CheckInputNamesSize(const CNodePtr &cnode) { - auto input_names_vec = AnfAlgo::GetNodeAttr>(cnode, kAttrInputNames); - if (input_names_vec.size() < kTopkIndexK + 1) { - MS_LOG(INFO) << "The input k of topk has been converted to attr"; - return false; - } - return true; -} - -bool CheckOutputShape(const AnfNodePtr &node) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - if (shape.empty()) { - MS_LOG(INFO) << "The output shape of topk 
to split must not be empty"; - return false; - } - auto last_dim = shape[shape.size() - 1]; - const size_t kMaxFloat16 = 65500; - if (last_dim > kMaxFloat16) { - MS_LOG(INFO) << "The last dim is more than " << kMaxFloat16 << ", switch to aicpu ops."; - return false; - } - return true; -} -} // namespace - -const BaseRef TopKSplit::DefinePattern() const { - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - auto prim = std::make_shared(kTopKOpName); - return VectorRef({prim, X1, X2}); -} - -const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto kernel_graph = func_graph->cast(); - // set value node as topk's input - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!CheckInputNamesSize(cnode)) { - return nullptr; - } - if (!CheckOutputShape(cnode)) { - return nullptr; - } - // Copy a new node to check supported. - std::vector new_inputs{NewValueNode(std::make_shared(kTopKOpName))}; - new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - CNodePtr new_cnode = func_graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - AnfAlgo::CopyNodeAttrs(cnode, new_cnode); - CheckCNodeInputSize(new_cnode, kTopkInputNum); - // Convert the tensor input to scalar and convert it to attr - auto input_k = new_cnode->input(kTopkIndexK + 1); - MS_EXCEPTION_IF_NULL(input_k); - if (!IsValueNode(input_k)) { - return nullptr; - } - ValuePtr value = GetValueNode(input_k); - MS_EXCEPTION_IF_NULL(value); - auto tensor = value->cast(); - MS_EXCEPTION_IF_NULL(tensor); - int32_t *data = reinterpret_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(data); - auto new_value_node = std::make_shared(MakeValue(*data)); - new_cnode->set_input(kTopkIndexK + 1, new_value_node); - - std::unordered_set attr_index{kTopkIndexK}; - ConstInputToAttr(new_cnode, attr_index); - auto indices_const = CreateValueNode(new_cnode); - new_cnode->add_input(indices_const); - MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { - MS_LOG(INFO) << "split topk failed, check to aicpu."; - return nullptr; - } - - if (kernel_graph != nullptr) { - MS_LOG(INFO) << "split topk success. use tbe aicore."; - kernel_graph->AddValueNodeToGraph(indices_const); - } - - return new_cnode; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h deleted file mode 100644 index e7293e1fa3..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
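The assist tensor CreateTensor builds for TopKSplit holds 2 * last_dim float16 values: the first half stores float16(i), the second half the rounding gap i - float(float16(i)), so the kernel can reconstruct exact indices beyond float16's 2048 exact-integer limit. A standalone illustration with plain float standing in for Eigen::half (the rounding emulation below is an approximation for demonstration, not the real conversion):

#include <cmath>
#include <iostream>
#include <vector>

float ToHalfApprox(float v) {
  // float16 has a 10-bit mantissa: integers above 2048 snap to coarser steps.
  if (v <= 2048.0f) return v;
  float step = std::exp2(std::floor(std::log2(v)) - 10.0f);
  return std::round(v / step) * step;
}

int main() {
  const size_t last_dim = 4096;
  std::vector<float> data;
  for (size_t i = 0; i < last_dim; ++i) data.push_back(ToHalfApprox(static_cast<float>(i)));
  for (size_t i = 0; i < last_dim; ++i) data.push_back(static_cast<float>(i) - ToHalfApprox(static_cast<float>(i)));
  std::cout << data[3001] + data[last_dim + 3001] << "\n";  // value + gap reconstructs 3001
  return 0;
}

This also explains the 65500 guard in CheckOutputShape: beyond float16's maximum finite value the scheme cannot work, so the pass bails out to AICPU.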
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc
deleted file mode 100644
index bfb7e50486..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc
+++ /dev/null
@@ -1,103 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fission/transdata_split.h"
-#include <set>
-#include "pre_activate/ascend/ascend_helper.h"
-#include "session/anf_runtime_algorithm.h"
-#include "debug/anf_ir_dump.h"
-
-namespace mindspore {
-namespace opt {
-const std::set<std::pair<std::string, std::string>> invalid_formats_pair = {
-  {kOpFormat_C1HWNCoC0, kOpFormat_NCHW},
-  {kOpFormat_NCHW, kOpFormat_C1HWNCoC0},
-  {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT},
-  {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}};
-
-bool TransDataSplit::Run(const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  bool changed = false;
-  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
-  for (auto &node : node_list) {
-    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) {
-      CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum);
-      if (IsFormatInvaild(node)) {
-        changed = DoSplit(func_graph, node);
-      }
-    }
-  }
-  return changed;
-}
-bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto input_format = AnfAlgo::GetInputFormat(node, 0);
-  auto output_format = AnfAlgo::GetOutputFormat(node, 0);
-  auto format_pair = std::make_pair(input_format, output_format);
-
-  return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end();
-}
-// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW)
-bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node);
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto input_node = node->cast<CNodePtr>()->input(1);
-  MS_EXCEPTION_IF_NULL(input_node);
-
-  auto input_format = AnfAlgo::GetInputFormat(node, 0);
-  auto output_format = AnfAlgo::GetOutputFormat(node, 0);
-  AnfNodePtr new_transdata_node = nullptr;
-  AnfNodePtr new_transpose_node = nullptr;
-  AnfNodePtr new_replace_node = nullptr;
-  // if output_format=default transdata need split transdata->transpose else transpose->transdata
-  if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
-    // trans input_format to hwcn
-    new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
-                                        false, prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transdata_node);
-    // trans hwcn to default_format
-    new_transpose_node =
-      NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
-    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transpose_node);
-    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
-    new_replace_node = new_transpose_node;
-  } else {
-    // trans default to hwcn
-    new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
-                                        false, prim::kPrimTranspose->name());
-    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
-    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, new_transpose_node);
-
-    // trans hwcn to output_format
-    new_transdata_node =
-      NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
-    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, new_transdata_node);
-    new_replace_node = new_transdata_node;
-  }
-  FuncGraphManagerPtr manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  manager->AddFuncGraph(func_graph);
-
-  if (!manager->Replace(node, new_replace_node)) {
-    MS_LOG(EXCEPTION) << "Manager replace node failed";
-  }
-  MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success.";
-  return true;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h
deleted file mode 100644
index f450897db1..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
-#include <string>
-#include <memory>
-#include <set>
-#include <utility>
-
-#include "pre_activate/common/pass.h"
-#include "ir/func_graph.h"
-#include "ir/anf.h"
-#include "pre_activate/common/helper.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ascend_helper.h"
-
-namespace mindspore {
-namespace opt {
-class TransDataSplit : public Pass {
- public:
-  TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared<KernelSelect>()) {}
-  ~TransDataSplit() override = default;
-  bool Run(const FuncGraphPtr &graph) override;
-
- private:
-  bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
-  bool IsFormatInvaild(const AnfNodePtr &node);
-  KernelSelectPtr kernel_select_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_
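The two permutation attributes TransDataSplit sets can be verified in isolation. With the usual transpose convention new_shape[i] = old_shape[perm[i]], {3,2,0,1} maps HWCN to NCHW (the TransData-then-Transpose branch) and {2,3,1,0} maps NCHW back to HWCN (the Transpose-then-TransData branch). A runnable check with example dimensions:

#include <iostream>
#include <vector>

std::vector<int> ApplyPerm(const std::vector<int> &shape, const std::vector<int> &perm) {
  std::vector<int> out;
  for (int p : perm) out.push_back(shape[p]);  // output axis i takes input axis perm[i]
  return out;
}

int main() {
  std::vector<int> hwcn = {7, 7, 64, 32};  // H, W, C, N
  for (int d : ApplyPerm(hwcn, {3, 2, 0, 1})) std::cout << d << " ";  // 32 64 7 7 == NCHW
  std::cout << "\n";
  std::vector<int> nchw = {32, 64, 7, 7};
  for (int d : ApplyPerm(nchw, {2, 3, 1, 0})) std::cout << d << " ";  // 7 7 64 32 == HWCN
  std::cout << "\n";
  return 0;
}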
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ -#include -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TransDataSplit : public Pass { - public: - TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared()) {} - ~TransDataSplit() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); - bool IsFormatInvaild(const AnfNodePtr &node); - KernelSelectPtr kernel_select_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TRANSDATA_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc deleted file mode 100644 index 4db08d0859..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" -#include "pre_activate/common/helper.h" -namespace mindspore { -namespace opt { -AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto prim = std::make_shared(kAdamApplyOneOpName); - std::vector new_node_inputs = {NewValueNode(prim)}; - for (const auto &input_var : input_vars_) { - auto input_node = utils::cast((*equiv)[input_var]); - MS_EXCEPTION_IF_NULL(input_node); - new_node_inputs.push_back(input_node); - } - for (const auto &mul_x_input_var : mul_x_input_vars_) { - auto mul_x_input_node = utils::cast((*equiv)[mul_x_input_var]); - MS_EXCEPTION_IF_NULL(mul_x_input_node); - new_node_inputs.push_back(mul_x_input_node); - } - auto add2_y_node = utils::cast((*equiv)[add2_y_]); - MS_EXCEPTION_IF_NULL(add2_y_node); - new_node_inputs.push_back(add2_y_node); - auto new_node = func_graph->NewCNode(new_node_inputs); - return new_node; -} - -const BaseRef AdamApplyOneFusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); -} - -const BaseRef AdamApplyOneCond1Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); -} - -const BaseRef AdamApplyOneCond2Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, VectorRef({prim::kPrimSquare, input_vars_[0]}), mul_x_input_vars_[3]}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = 
VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const BaseRef AdamApplyOneCond3Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const BaseRef AdamApplyOneCond4Fusion::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); - VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); - VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef true_div0 = VectorRef({prim_real_div, add0, VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt0})}); - return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, true_div0, input_vars_[4]})}); -} - -const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto new_node = CreateAdamApplyOneNode(func_graph, equiv); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(node->scope()); - // Set abstract of new node - AbstractBasePtrList new_node_abstract_list; - auto iter_add0 = (*equiv).find(add0_var_); - if (iter_add0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; - } - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add0 = utils::cast(iter_add0->second); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - new_node_abstract_list.push_back(add1->abstract()); - new_node_abstract_list.push_back(add0->abstract()); - new_node_abstract_list.push_back(node->abstract()); - auto abstract_tuple = std::make_shared(new_node_abstract_list); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kAdamApplyOneOutputNum, &new_node_outputs); - if 
(new_node_outputs.size() != kAdamApplyOneOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node " << new_node->DebugString() << " should be " - << kAdamApplyOneOutputNum; - } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, new_node_outputs[0]); - (void)manager->Replace(add0, new_node_outputs[1]); - return new_node_outputs[2]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h deleted file mode 100644 index 5ee8a86cfb..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h +++ /dev/null @@ -1,95 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ - -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -constexpr size_t kAdamApplyOneInputVarNum = 5; -constexpr size_t kAdamApplyOneMulInputVarNum = 4; - -class AdamApplyOneFusion : public PatternProcessPass { - public: - explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { - input_vars_.push_back(std::make_shared()); - } - for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { - mul_x_input_vars_.push_back(std::make_shared()); - } - add2_y_ = std::make_shared(); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - - ~AdamApplyOneFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; - std::vector input_vars_; - std::vector mul_x_input_vars_; - VarPtr add2_y_; - VarPtr add0_var_; - VarPtr add1_var_; -}; - -class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond1Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} - - ~AdamApplyOneCond1Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond2Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} - - ~AdamApplyOneCond2Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { - public: - explicit 
AdamApplyOneCond3Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} - - ~AdamApplyOneCond3Fusion() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { - public: - explicit AdamApplyOneCond4Fusion(bool multigraph = true) - : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} - - ~AdamApplyOneCond4Fusion() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc deleted file mode 100644 index f6077c95f2..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - auto input0 = utils::cast((*equiv)[input0_]); - auto input1 = utils::cast((*equiv)[input1_]); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto input4 = utils::cast((*equiv)[input4_]); - auto mul0_x = utils::cast((*equiv)[mul0_x_]); - auto mul1_x = utils::cast((*equiv)[mul1_x_]); - auto mul2_x = utils::cast((*equiv)[mul2_x_]); - auto mul3_x = utils::cast((*equiv)[mul3_x_]); - auto mul4_x = utils::cast((*equiv)[mul4_x_]); - auto add2_y = utils::cast((*equiv)[add2_y_]); - auto prim = std::make_shared(kAdamApplyOneWithDecayOpName); - return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; -} - -const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, input4_, add3}); - VectorRef 
sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); - VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { - auto sqrt = std::make_shared(kSqrtOpName); - auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0({prim::kPrimSquare, input0_}); - VectorRef add0({add0_var_, mul0, mul1}); - VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); - VectorRef add1({add1_var_, mul2, mul3}); - VectorRef sqrt0({sqrt, add1}); - VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); - VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div0({real_div, add0, add2}); - VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); - VectorRef mul5({prim::kPrimMul, add3, 
input4_}); - VectorRef sub0({prim::kPrimSub, input3_, mul5}); - return sub0; -} - -const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - std::vector inputs = GetFusionNodeInputs(equiv); - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); - - auto iter_add0 = (*equiv).find(add0_var_); - if (iter_add0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; - } - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add0 = utils::cast(iter_add0->second); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), - AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), - AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - - std::vector fusion_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, fusion_node, kAdamApplyOneWithDecayOutputNum, &fusion_node_outputs); - if (fusion_node_outputs.size() != kAdamApplyOneWithDecayOutputNum) { - MS_LOG(ERROR) << "create multiple outputs for fusion node fail!"; - return nullptr; - } - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, fusion_node_outputs[0]); - (void)manager->Replace(add0, fusion_node_outputs[1]); - return fusion_node_outputs[2]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h deleted file mode 100644 index 742295dd9c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
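The AdamApplyOne patterns earlier and the WithDecay variants above all match the same elementwise Adam step; the Cond1..Cond5 classes only cover different operand orderings of the commutative Mul/TensorAdd nodes. My scalar reading of that dataflow, with descriptive names of my own (input0 = grad, input1 = v, input2 = m, input3 = param, input4 = lr, add2_y = epsilon, mul4_x = decay; a sketch, not MindSpore code):

#include <cmath>

// Outputs in the order the fused op exposes them:
// outputs[0] = add1 (v_new), outputs[1] = add0 (m_new), outputs[2] = sub0 (param_new).
struct AdamApplyOneOut {
  float v_new, m_new, param_new;
};

AdamApplyOneOut AdamApplyOne(float grad, float v, float m, float param, float lr,
                             float mul0_x, float mul1_x, float mul2_x, float mul3_x,
                             float add2_y, float decay /* mul4_x; 0 for the no-decay form */) {
  float v_new = mul2_x * v + mul3_x * grad * grad;     // add1 = mul2 + mul3
  float m_new = mul0_x * m + mul1_x * grad;            // add0 = mul0 + mul1
  float update = m_new / (std::sqrt(v_new) + add2_y);  // true_div0
  update += decay * param;                             // add3, WithDecay variants only
  return {v_new, m_new, param - lr * update};          // sub0
}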
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ - -#include <memory> -#include <string> -#include <vector> -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" -namespace mindspore { -namespace opt { -class AdamApplyOneWithDecayRule : public PatternProcessPass { - public: - explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - input0_ = std::make_shared<Var>(); - input1_ = std::make_shared<Var>(); - input2_ = std::make_shared<Var>(); - input3_ = std::make_shared<Var>(); - input4_ = std::make_shared<Var>(); - mul0_x_ = std::make_shared<Var>(); - mul1_x_ = std::make_shared<Var>(); - mul2_x_ = std::make_shared<Var>(); - mul3_x_ = std::make_shared<Var>(); - mul4_x_ = std::make_shared<Var>(); - add2_y_ = std::make_shared<Var>(); - add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name())); - } - ~AdamApplyOneWithDecayRule() override = default; - const BaseRef DefinePattern() const override = 0; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const; - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr mul0_x_; - VarPtr mul1_x_; - VarPtr mul2_x_; - VarPtr mul3_x_; - VarPtr mul4_x_; - VarPtr add2_y_; - VarPtr add0_var_; - VarPtr add1_var_; -}; - -class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond1() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond2() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond3() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond4() override = default; - const BaseRef DefinePattern() const override; -}; - -class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { - public: - explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) - : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} - - ~AdamApplyOneWithDecayRuleCond5() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git
a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc deleted file mode 100644 index 867f30b9d2..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.cc +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" -#include -#include -#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" - -namespace mindspore { -namespace opt { -namespace { -void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector *names_vec) { - MS_EXCEPTION_IF_NULL(names_vec); - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - ValuePtr names_value = primitive->GetAttr(attr_name); - if (names_value == nullptr) { - return; - } - *names_vec = GetValue>(names_value); -} - -void AddOutputs(const CNodePtr &cnode, const std::vector &input_indices) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector input_names_vec; - GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec); - std::vector output_names_vec; - GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec); - AbstractBasePtrList abstract_list; - auto origin_abstract = cnode->abstract(); - MS_EXCEPTION_IF_NULL(origin_abstract); - if (origin_abstract->isa()) { - auto origin_abstract_tuple = dyn_cast(origin_abstract); - MS_EXCEPTION_IF_NULL(origin_abstract_tuple); - AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements(); - (void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list)); - } else { - abstract_list.emplace_back(origin_abstract); - } - - for (size_t i = 0; i < input_indices.size(); ++i) { - size_t index = input_indices[i]; - if (index + 1 >= cnode->inputs().size()) { - MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, " - << "node: " << cnode->DebugString(); - continue; - } - auto node_to_output = cnode->input(index + 1); - MS_EXCEPTION_IF_NULL(node_to_output); - abstract_list.emplace_back(node_to_output->abstract()); - if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) { - output_names_vec.emplace_back(input_names_vec[index]); - } - } - if (!output_names_vec.empty()) { - AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode); - } - auto abstract_tuple = std::make_shared(abstract_list); - cnode->set_abstract(abstract_tuple); -} -} // namespace - -const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::string op_name = 
AnfAlgo::GetCNodeName(cnode); - InputToOutputRegister reg; - if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) { - return nullptr; - } - int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); - // No need add output when it is not a tbe op. - if (output_num == -1) { - return nullptr; - } - // No need add output if the output num matches the registered output num for tbe. - if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) { - return nullptr; - } - bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode); - AddOutputs(cnode, reg.input_indices()); - // No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems - // pointed to the outputs. - if (is_origin_tuple_output) { - return nullptr; - } - std::vector new_outputs; - auto new_abstract_tuple = dyn_cast(cnode->abstract()); - MS_EXCEPTION_IF_NULL(new_abstract_tuple); - CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs); - if (new_outputs.size() != new_abstract_tuple->size()) { - MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString(); - } - return new_outputs[0]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h deleted file mode 100644 index d57b32f370..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/add_input_to_output.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class AddInputToOutput : public PatternProcessPass { - public: - explicit AddInputToOutput(bool multigraph = true) - : PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared()) {} - ~AddInputToOutput() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - OpFinderPtr op_finder_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc deleted file mode 100644 index debe9e8351..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
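The AddInputToOutput pass whose files just ended appends selected inputs to a node's output tuple so the extra outputs a registered TBE kernel produces stay addressable. The index surgery in AddOutputs reduces to a small loop; a toy model with plain strings standing in for the AbstractBasePtrList (names mine, not MindSpore's):

#include <cstddef>
#include <string>
#include <vector>

// Extend a node's output list with copies of the inputs named by
// input_indices, skipping indices that are out of range (the
// "input index ... out of range" branch above).
void AddInputsToOutputs(const std::vector<std::string> &inputs,
                        const std::vector<size_t> &input_indices,
                        std::vector<std::string> *outputs) {
  for (size_t index : input_indices) {
    if (index >= inputs.size()) {
      continue;
    }
    outputs->push_back(inputs[index]);
  }
}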
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "abstract/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateBNInfer(const FuncGraphPtr &graph, const CNodePtr &batchnorm, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnorm); - MS_EXCEPTION_IF_NULL(node); - auto prim = std::make_shared(kBNInferOpName); - std::vector inputs = {NewValueNode(prim)}; - for (size_t i = 1; i < batchnorm->size(); ++i) { - inputs.push_back(batchnorm->input(i)); - } - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(batchnorm->scope()); - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnorm, new_node); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnorm, new_node); - return new_node; -} - -bool CheckIndex(const AnfNodePtr &index_node) { - MS_EXCEPTION_IF_NULL(index_node); - if (!IsValueNode(index_node)) { - return false; - } - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index != 0) { - MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNorm"; - return false; - } - return true; -} - -bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnorm); - if (batchnorm->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; - return false; - } - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { - return false; - } - auto is_training = AnfAlgo::GetNodeAttr(batchnorm, kAttrIsTraining); - if (is_training) { - MS_LOG(DEBUG) << "is_training is true, no need do fusion"; - return false; - } - - if (IsUsedByOthers(graph, batchnorm)) { - MS_LOG(DEBUG) << "Only the 0th output of BatchNorm is used, then do fusion"; - return false; - } - return true; -} - -bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnorm) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (!CheckIndex(index_node)) { - return false; - } - - AnfNodePtr batchnorm_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(batchnorm_anf); - MS_EXCEPTION_IF_NULL(batchnorm); - *batchnorm = batchnorm_anf->cast(); - MS_EXCEPTION_IF_NULL(*batchnorm); - return CheckBatchNorm(graph, *batchnorm); -} -} // namespace - -const BaseRef BatchNorm2BNInfer::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr Y = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Y); - VectorRef 
batchnorm({prim::kPrimBatchNorm, Xs}); - VectorRef pattern({prim::kPrimTupleGetItem, batchnorm, Y}); - return pattern; -} - -const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - - CNodePtr batchnorm = nullptr; - if (!NeedFusion(graph, node, &batchnorm)) { - return nullptr; - } - return CreateBNInfer(graph, batchnorm, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h deleted file mode 100644 index 551fe0f6f9..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNorm2BNInfer : public PatternProcessPass { - public: - explicit BatchNorm2BNInfer(bool multigraph = true) : PatternProcessPass("batchnorm_to_bninfer", multigraph) {} - ~BatchNorm2BNInfer() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORM_TO_BNINFER_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc deleted file mode 100644 index e9d28c32dc..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
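BatchNorm2BNInfer above only fires when is_training is false and just the 0th output of BatchNorm is consumed, because then the op degenerates to the standard inference normalization. A scalar statement of what the resulting BNInfer node computes (the textbook formula; the kernel itself is not part of this diff):

#include <cmath>

// Inference batch norm: y = gamma * (x - mean) / sqrt(var + eps) + beta,
// with mean/var taken from the frozen moving statistics.
float BNInfer(float x, float gamma, float beta, float mean, float var, float eps) {
  return gamma * (x - mean) / std::sqrt(var + eps) + beta;
}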
- */ -#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "abstract/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateBNInferGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnormgrad); - auto prim = std::make_shared(kBNInferGradOpName); - std::vector inputs = {NewValueNode(prim)}; - inputs.push_back(batchnormgrad->input(1)); - inputs.push_back(batchnormgrad->input(3)); - inputs.push_back(batchnormgrad->input(5)); - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(batchnormgrad->scope()); - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrIsTraining, batchnormgrad, new_node); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, batchnormgrad, new_node); - return new_node; -} - -bool CheckIndex(const AnfNodePtr &index_node) { - MS_EXCEPTION_IF_NULL(index_node); - if (!IsValueNode(index_node)) { - return false; - } - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index != 0) { - MS_LOG(DEBUG) << "tuple_getitem must be 0th output of BatchNormGrad"; - return false; - } - return true; -} - -bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(batchnormgrad); - if (batchnormgrad->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; - return false; - } - if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) { - return false; - } - auto is_training = AnfAlgo::GetNodeAttr(batchnormgrad, kAttrIsTraining); - if (is_training) { - MS_LOG(DEBUG) << "is_training is true, no need do fusion"; - return false; - } - - if (IsUsedByOthers(graph, batchnormgrad)) { - MS_LOG(DEBUG) << "Only the 0th output of BatchNormGrad is used, then do fusion"; - return false; - } - return true; -} - -bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *batchnormgrad) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (!CheckIndex(index_node)) { - return false; - } - - AnfNodePtr batchnormgrad_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(batchnormgrad_anf); - MS_EXCEPTION_IF_NULL(batchnormgrad); - *batchnormgrad = batchnormgrad_anf->cast(); - MS_EXCEPTION_IF_NULL(*batchnormgrad); - return CheckBatchNormGrad(graph, *batchnormgrad); -} -} // namespace - -const BaseRef BatchNormGrad2BNInferGrad::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr Y = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Y); - VectorRef batchnormgrad({prim::kPrimBatchNormGrad, Xs}); - VectorRef pattern({prim::kPrimTupleGetItem, batchnormgrad, Y}); - return pattern; -} - -const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - 
MS_EXCEPTION_IF_NULL(node); - - CNodePtr batchnormgrad = nullptr; - if (!NeedFusion(graph, node, &batchnormgrad)) { - return nullptr; - } - return CreateBNInferGrad(graph, batchnormgrad, node); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h deleted file mode 100644 index 020dc1a999..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class BatchNormGrad2BNInferGrad : public PatternProcessPass { - public: - explicit BatchNormGrad2BNInferGrad(bool multigraph = true) - : PatternProcessPass("batchnormgrad_to_bninfergrad", multigraph) {} - ~BatchNormGrad2BNInferGrad() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_BATCHNORMGRAD_TO_BNINFERGRAD_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc deleted file mode 100644 index 2af3afbf19..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.cc +++ /dev/null @@ -1,74 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
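Note why CreateBNInferGrad above forwards only inputs 1, 3, and 5 (dy, scale, variance) of BatchNormGrad: with frozen statistics the normalization is an affine map of x, so its gradient needs nothing else. As a scalar sketch:

#include <cmath>

// Inference-mode BatchNorm backward: dx = dy * scale / sqrt(var + eps).
// save_mean and x drop out because mean/var are constants here.
float BNInferGrad(float dy, float scale, float var, float eps) {
  return dy * scale / std::sqrt(var + eps);
}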
- */ -#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -const BaseRef ClipByNormNoDivSquareSumFusion::DefinePattern() const { - auto greater = std::make_shared(kGreaterOpName); - MS_EXCEPTION_IF_NULL(greater); - auto sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(sqrt); - - VectorRef greater_pattern({greater, input_, constant_greater_}); - VectorRef pattern( - {prim::kPrimMaximum, - VectorRef({prim::kPrimSelect, greater_pattern, - VectorRef({sqrt, VectorRef({prim::kPrimSelect, greater_pattern, input_, constant_select_})}), input_}), - constant_maximum_}); - return pattern; -} - -const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - BaseRef &input_gnode = (*equiv)[input_]; - BaseRef &constant_select_gnode = (*equiv)[constant_select_]; - BaseRef &constant_greater_gnode = (*equiv)[constant_greater_]; - BaseRef &constant_maximum_gnode = (*equiv)[constant_maximum_]; - auto input = utils::cast(input_gnode); - auto constant_select = utils::cast(constant_select_gnode); - auto constant_greater = utils::cast(constant_greater_gnode); - auto constant_maximum = utils::cast(constant_maximum_gnode); - MS_EXCEPTION_IF_NULL(input); - MS_EXCEPTION_IF_NULL(constant_select); - MS_EXCEPTION_IF_NULL(constant_greater); - MS_EXCEPTION_IF_NULL(constant_maximum); - - auto prim = std::make_shared(kClipByNormNoDivSumOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), input, constant_select, constant_greater, constant_maximum}; - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - fusion_node->set_scope(node->scope()); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h deleted file mode 100644 index 126480603e..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ - -#include <memory> -#include <string> -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -constexpr auto kInputVarName = "input"; -constexpr auto kConstantSelectVarName = "constant_select"; -constexpr auto kConstantGreaterVarName = "constant_greater"; -constexpr auto kConstantMaximumVarName = "constant_maximum"; - -class ClipByNormNoDivSquareSumFusion : public PatternProcessPass { - public: - explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true) - : PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) { - input_ = std::make_shared<Var>(kInputVarName); - constant_select_ = std::make_shared<Var>(kConstantSelectVarName); - constant_greater_ = std::make_shared<Var>(kConstantGreaterVarName); - constant_maximum_ = std::make_shared<Var>(kConstantMaximumVarName); - } - ~ClipByNormNoDivSquareSumFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_; - VarPtr constant_select_; - VarPtr constant_greater_; - VarPtr constant_maximum_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_NORM_NO_DIV_SQUARE_SUM_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc deleted file mode 100644 index df94e897ec..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
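The nested Select/Sqrt/Maximum subgraph matched by ClipByNormNoDivSquareSumFusion is the guarded denominator of clip-by-global-norm: take sqrt of the summed squares only where it is safe, then floor the result. My scalar reading of the pattern (a sketch of the matched dataflow, not the fused kernel):

#include <algorithm>
#include <cmath>

// Matches: Maximum(Select(greater, Sqrt(Select(greater, x, c_select)), x), c_maximum)
// where greater = (x > c_greater).
float ClipByNormNoDivSum(float x, float c_select, float c_greater, float c_maximum) {
  bool greater = x > c_greater;
  float sqrt_in = greater ? x : c_select;          // inner Select guards the Sqrt input
  float norm = greater ? std::sqrt(sqrt_in) : x;   // outer Select of the Sqrt
  return std::max(norm, c_maximum);                // final Maximum floors the norm
}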
- */ -#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool GetMinimumOp(const AnfNodePtr &input0, const AnfNodePtr &input1, CNodePtr *minimum, bool *is_first_input) { - MS_EXCEPTION_IF_NULL(input0); - MS_EXCEPTION_IF_NULL(input1); - - CNodePtr cnode = nullptr; - if (input0->isa() && !input1->isa()) { - cnode = input0->cast(); - *is_first_input = true; - } else if (!input0->isa() && input1->isa()) { - cnode = input1->cast(); - *is_first_input = false; - } else if (input0->isa() && input1->isa()) { - if (AnfAlgo::GetCNodeName(input0) == prim::kPrimMinimum->name()) { - cnode = input0->cast(); - *is_first_input = true; - } else { - cnode = input1->cast(); - *is_first_input = false; - } - } else { - return false; - } - - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMinimum->name()) { - return false; - } - *minimum = cnode; - return true; -} -} // namespace - -const BaseRef ClipByValueFusion::DefinePattern() const { - VectorRef pattern({prim::kPrimMaximum, maximum_input0_, maximum_input1_}); - return pattern; -} - -const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto maximum_input0 = utils::cast((*equiv)[maximum_input0_]); - auto maximum_input1 = utils::cast((*equiv)[maximum_input1_]); - MS_EXCEPTION_IF_NULL(maximum_input0); - MS_EXCEPTION_IF_NULL(maximum_input1); - - CNodePtr minimum = nullptr; - bool is_first_input = true; - if (!GetMinimumOp(maximum_input0, maximum_input1, &minimum, &is_first_input)) { - return nullptr; - } - MS_EXCEPTION_IF_NULL(minimum); - if (minimum->inputs().size() != kMinimumInputNum) { - return nullptr; - } - - auto prim = std::make_shared(kClipByValueOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), minimum->input(1), - is_first_input ? maximum_input1 : maximum_input0, minimum->input(2)}; - auto clip_by_value = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(clip_by_value); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, clip_by_value.get()); - clip_by_value->set_scope(node->scope()); - return clip_by_value; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h deleted file mode 100644 index 309b7cedd0..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/clip_by_value_fusion.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ - -#include <memory> -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ClipByValueFusion : public PatternProcessPass { - public: - explicit ClipByValueFusion(bool multigraph = true) : PatternProcessPass("clip_by_value_fusion", multigraph) { - maximum_input0_ = std::make_shared<Var>(); - maximum_input1_ = std::make_shared<Var>(); - } - ~ClipByValueFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr maximum_input0_; - VarPtr maximum_input1_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CLIP_BY_VALUE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc deleted file mode 100644 index 41c0b21d10..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
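ClipByValueFusion above finds a Maximum whose one operand is a Minimum and rebuilds it as a single ClipByValue(data, clip_min, clip_max) node, where data and clip_max come from the Minimum's inputs and clip_min is the Maximum's other operand. The scalar semantics being fused (standard clipping, stated here for reference):

#include <algorithm>

// Maximum(Minimum(x, hi), lo) == clip x into [lo, hi].
float ClipByValue(float x, float lo, float hi) {
  return std::max(std::min(x, hi), lo);
}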
- */ -#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "abstract/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kConfusionMulGradOutputNum = 2; - -CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &reduce_sum, const AnfNodePtr &mul0_anf, - const AnfNodePtr &input3) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(reduce_sum); - MS_EXCEPTION_IF_NULL(mul0_anf); - MS_EXCEPTION_IF_NULL(input3); - auto mul0 = mul0_anf->cast(); - MS_EXCEPTION_IF_NULL(mul0); - - auto prim = std::make_shared(kConfusionMulGradOpName); - std::vector inputs = {NewValueNode(prim), mul0->input(1), mul0->input(2), input3}; - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(reduce_sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, reduce_sum, fusion_node); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, reduce_sum, fusion_node); - auto types = {AnfAlgo::GetOutputInferDataType(mul0, 0), AnfAlgo::GetOutputInferDataType(reduce_sum, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(mul0, 0), AnfAlgo::GetOutputInferShape(reduce_sum, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get()); - return fusion_node; -} - -AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const AnfNodePtr &mul1) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input2); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(input2) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - - AnfNodePtr mul0 = nullptr; - const AnfNodeIndexSet &outputs_set = manager->node_users()[input2]; - // input2 must be the 2rd input of mul0 - auto it = std::find_if(outputs_set.begin(), outputs_set.end(), [&mul1](const std::pair &node_index) { - return node_index.first != mul1 && node_index.second == 2; - }); - if (it != outputs_set.end() && AnfAlgo::GetCNodeName(it->first) == prim::kPrimMul->name()) { - mul0 = it->first; - } - return mul0; -} - -bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, - const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { - MS_EXCEPTION_IF_NULL(mul0_anf); - MS_EXCEPTION_IF_NULL(mul1_anf); - MS_EXCEPTION_IF_NULL(reduce_sum); - MS_EXCEPTION_IF_NULL(input2); - auto addn = input2->cast(); - if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { - MS_LOG(INFO) << "mul's second input is not addn"; - return true; - } - std::vector shape = AnfAlgo::GetOutputInferShape(addn, 0); - if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { - MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; - return true; - } - if (!mul0_anf->isa() || !mul1_anf->isa()) { - return true; - } - auto mul1 = mul1_anf->cast(); - MS_EXCEPTION_IF_NULL(mul1); - auto mul0 = mul0_anf->cast(); - MS_EXCEPTION_IF_NULL(mul0); - - if (IsDepend(graph, mul0->input(1), reduce_sum)) { - MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; - return true; - } - if (IsDepend(graph, mul1->input(1), mul0)) { - MS_LOG(INFO) << "mul1->input(1) depends on mul0, quit fusion"; - return true; - } - return false; -} -} // namespace - -const BaseRef ConfusionMulGradFusion::DefinePattern() const { - VectorRef 
mul1({prim::kPrimMul, input3_, input2_}); - VectorRef reduce_sum({prim::kPrimReduceSum, mul1}); - return reduce_sum; -} - -const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto reduce_sum = node->cast(); - MS_EXCEPTION_IF_NULL(reduce_sum); - auto mul1 = reduce_sum->input(1); - if (IsUsedByOthers(graph, mul1)) { - MS_LOG(INFO) << "Mul1 is used by others, quit fusion!"; - return nullptr; - } - auto mul0 = GetMul0(graph, input2, mul1); - if (mul0 == nullptr) { - MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; - return nullptr; - } - if (QuitFusion(graph, mul0, mul1, node, input2)) { - return nullptr; - } - - auto fusion_node = CreateFusionNode(graph, reduce_sum, mul0, input3); - std::vector fusion_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, fusion_node, kConfusionMulGradOutputNum, &fusion_node_outputs); - - auto manage = graph->manager(); - MS_EXCEPTION_IF_NULL(manage); - manage->Replace(mul0, fusion_node_outputs[0]); - return fusion_node_outputs[1]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h deleted file mode 100644 index 170df5b0e4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
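Editor's note: reading CreateFusionNode above, the fused op takes (a, b, c) from mul0 = a*b and mul1 = c*b and returns two outputs, one replacing mul0 and one replacing reduce_sum(mul1). A standalone sketch of those semantics (illustrative names, reduce over all axes for brevity, not the operator's real signature):

    #include <utility>
    #include <vector>

    std::pair<std::vector<float>, float> ConfusionMulGrad(const std::vector<float> &a,
                                                          const std::vector<float> &b,
                                                          const std::vector<float> &c) {
      std::vector<float> output0(a.size());
      float output1 = 0.0f;  // stands in for reduce_sum(mul1) along the copied axis attr
      for (size_t i = 0; i < a.size(); ++i) {
        output0[i] = a[i] * b[i];  // replaces mul0
        output1 += c[i] * b[i];    // replaces reduce_sum(c * b)
      }
      return {std::move(output0), output1};
    }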
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
deleted file mode 100644
index 9e2c6374ce..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h"
-
-#include <memory>
-#include <vector>
-
-#include "session/anf_runtime_algorithm.h"
-#include "ir/primitive.h"
-#include "utils/utils.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef ConfusionSoftmaxGradRule::DefinePattern() const {
-  return VectorRef({prim::kPrimSub, input0_, VectorRef({reduce_sum_, VectorRef({prim::kPrimMul, input1_, input0_})})});
-}
-
-const AnfNodePtr ConfusionSoftmaxGradRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
-                                                   const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(graph);
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(equiv);
-  AnfNodePtr input0 = GetAnfNodeByVar(equiv, input0_);
-  AnfNodePtr input1 = GetAnfNodeByVar(equiv, input1_);
-  AnfNodePtr sum_anf = GetAnfNodeByVar(equiv, reduce_sum_);
-  if (sum_anf == nullptr || !sum_anf->isa<CNode>()) {
-    MS_LOG(WARNING) << "Matched ReduceSum is not a CNode!";
-    return nullptr;
-  }
-  if (!GetBoolAttr(sum_anf, kAttrKeepDims)) {
-    MS_LOG(INFO) << "ReduceSum's attr keep_dims should be true if do fusion. Otherwise the calculation will be wrong";
-    return nullptr;
-  }
-
-  auto prim = std::make_shared<Primitive>(kConfusionSoftmaxGradOpName);
-  MS_EXCEPTION_IF_NULL(prim);
-  std::vector<AnfNodePtr> inputs = {NewValueNode(prim), input0, input1};
-  auto fusion_node = graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(fusion_node);
-  fusion_node->set_abstract(node->abstract());
-  fusion_node->set_scope(node->scope());
-  AnfAlgo::CopyNodeAttr(kAttrAxis, sum_anf, fusion_node);
-  AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum_anf, fusion_node);
-  return fusion_node;
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h
deleted file mode 100644
index a4d0d1ce7a..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h
+++ /dev/null
@@ -1,43 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
-
-#include <memory>
-#include "pre_activate/common/optimizer.h"
-
-namespace mindspore {
-namespace opt {
-class ConfusionSoftmaxGradRule : public PatternProcessPass {
- public:
-  explicit ConfusionSoftmaxGradRule(bool multigraph = true)
-      : PatternProcessPass("confusion_softmax_grad_rule", multigraph) {
-    input0_ = std::make_shared<Var>();
-    input1_ = std::make_shared<Var>();
-    reduce_sum_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimReduceSum->name()));
-  }
-  ~ConfusionSoftmaxGradRule() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  VarPtr input0_;
-  VarPtr input1_;
-  VarPtr reduce_sum_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_CONFUSION_SOFTMAX_GRAD_RULE_H_
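Editor's note: the rule above fuses Sub(input0, ReduceSum(Mul(input1, input0), axis, keep_dims=true)), the core of a softmax backward step. A standalone sketch showing why keep_dims must be true: the row-wise sum has to broadcast back across the row it was reduced from (illustrative row-major stand-in, not MindSpore code):

    #include <vector>

    using Matrix = std::vector<std::vector<float>>;

    Matrix ConfusionSoftmaxGrad(const Matrix &input0, const Matrix &input1) {
      Matrix out = input0;
      for (size_t r = 0; r < input0.size(); ++r) {
        // ReduceSum along the last axis; keep_dims keeps it as a size-1 dim
        float row_sum = 0.0f;
        for (size_t c = 0; c < input0[r].size(); ++c) row_sum += input1[r][c] * input0[r][c];
        // Sub broadcasts the kept-dim sum over the whole row
        for (size_t c = 0; c < input0[r].size(); ++c) out[r][c] = input0[r][c] - row_sum;
      }
      return out;
    }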
- */ -#include "pre_activate/ascend/ir_fusion/derelu_fusion.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "abstract/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t kReluV2OutputNum = 2; - -CNodePtr GetRelu(const CNodePtr &relu_grad) { - MS_EXCEPTION_IF_NULL(relu_grad); - if (relu_grad->size() != kReluGradInputNum) { - MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); - } - auto relu_anf = relu_grad->input(2); - MS_EXCEPTION_IF_NULL(relu_anf); - return relu_anf->cast(); -} - -CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(relu); - if (relu->size() != kReluInputNum) { - MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); - } - - auto prim = std::make_shared(kReluV2OpName); - std::vector inputs = {NewValueNode(prim), relu->input(1)}; - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(relu->scope()); - - // ReluV2's 2rd output is mask whose data type is uint8 - TypeId mask_dtype = kNumberTypeUInt8; - std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); - if (mask_shape.size() != 4) { - MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; - return nullptr; - } - auto input_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(relu, 0); - if (input_dtype == kNumberTypeUInt8 || input_dtype == kNumberTypeInt8) { - mask_shape[1] = (mask_shape[1] + 31) / 32; - mask_shape.push_back(4); - } else { - mask_shape[1] = (mask_shape[1] + 15) / 16; - mask_shape.push_back(2); - } - - auto types = {AnfAlgo::GetOutputInferDataType(relu, 0), mask_dtype}; - auto shapes = {AnfAlgo::GetOutputInferShape(relu, 0), mask_shape}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, new_node.get()); - return new_node; -} - -CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, const AnfNodePtr &second_input) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(relu_grad); - MS_EXCEPTION_IF_NULL(second_input); - - auto prim = std::make_shared(kReluGradV2OpName); - std::vector inputs = {NewValueNode(prim), relu_grad->input(1), second_input}; - auto new_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_scope(relu_grad->scope()); - new_node->set_abstract(relu_grad->abstract()); - return new_node; -} -} // namespace - -const BaseRef DereluFusion::DefinePattern() const { - VarPtr i0 = std::make_shared(); - VarPtr i1 = std::make_shared(); - VectorRef relu({prim::kPrimRelu, i1}); - VectorRef relu_grad({prim::kPrimReluGrad, i0, relu}); - return relu_grad; -} - -const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto relu_grad = node->cast(); - MS_EXCEPTION_IF_NULL(relu_grad); - auto relu = GetRelu(relu_grad); - MS_EXCEPTION_IF_NULL(relu); - - auto relu_v2 = CreateReluV2(graph, relu); - if (relu_v2 == nullptr) { - return nullptr; - } - std::vector relu_v2_node_outputs; - CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); - - auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); - - auto manage = graph->manager(); - MS_EXCEPTION_IF_NULL(manage); - manage->Replace(relu, relu_v2_node_outputs[0]); - return relu_grad_v2; -} -} // namespace opt -} // 
namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h deleted file mode 100644 index e1811f4db4..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/derelu_fusion.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class DereluFusion : public PatternProcessPass { - public: - explicit DereluFusion(bool multigraph = true) : PatternProcessPass("derelu_fusion", multigraph) {} - ~DereluFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_DERELU_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc deleted file mode 100644 index efc9ee7934..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.cc +++ /dev/null @@ -1,340 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
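Editor's note: the mask-shape arithmetic in CreateReluV2 packs the NCHW channel dimension into bitmask blocks. A standalone sketch of just that arithmetic (illustrative helper, not MindSpore code):

    #include <vector>

    // shape is NCHW; shape[1] is the channel count C. ReluV2's mask packs
    // 32 channels per 4-byte block for (u)int8 inputs, else 16 per 2-byte block.
    std::vector<size_t> ReluV2MaskShape(std::vector<size_t> shape, bool is_int8_input) {
      if (is_int8_input) {
        shape[1] = (shape[1] + 31) / 32;  // ceil(C / 32)
        shape.push_back(4);
      } else {
        shape[1] = (shape[1] + 15) / 16;  // ceil(C / 16)
        shape.push_back(2);
      }
      return shape;  // e.g. {32, 64, 7, 7} with fp16 input -> {32, 4, 7, 7, 2}
    }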
- */ -#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h" -#include -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kReplaceOutputIndex0 = 3; -constexpr size_t kReplaceOutputIndex1 = 4; -bool IsC(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - return in->isa(); - } - return false; -} - -void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector *bn_outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(bn); - MS_EXCEPTION_IF_NULL(bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(bn) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"; - } - for (const auto &node_index : manager->node_users()[bn]) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - bn_outputs->push_back(output); - } -} -} // namespace - -const BaseRef FusedBatchNormFusion::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); - VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} - -ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - auto iter_constant_input0 = (*equiv).find(constant_input0_var_); - if (iter_constant_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; - } - auto constant_input = utils::cast(iter_constant_input0->second); - MS_EXCEPTION_IF_NULL(constant_input); - if (!constant_input->isa()) { - return nullptr; - } - auto value_node = constant_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - return nullptr; - } - auto tensor_ptr = value->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - if (tensor_ptr->data_type() == kNumberTypeFloat16) { - auto *half_data = static_cast(tensor_ptr->data_c()); - MS_EXCEPTION_IF_NULL(half_data); - float float_data = Eigen::half_impl::half_to_float(half_data[0]); - return MakeValue(float_data); - } else if (tensor_ptr->data_type() == kNumberTypeFloat32) { - auto *tensor_data = static_cast(tensor_ptr->data_c()); - MS_EXCEPTION_IF_NULL(tensor_data); - return 
MakeValue(tensor_data[0]); - } else { - MS_LOG(WARNING) << "The factor data type of value node " << value_node->DebugString() << " is not fp16 or fp32"; - return nullptr; - } -} - -AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - // Set input to create node - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; - } - std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), - utils::cast(iter_data_input0->second)}; - auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); - MS_EXCEPTION_IF_NULL(bn_training_reduce); - bn_training_reduce->set_scope(node->scope()); - // Set abstract - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; - } - auto data_input1 = utils::cast(iter_data_input1->second); - MS_EXCEPTION_IF_NULL(data_input1); - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; - } - auto data_input2 = utils::cast(iter_data_input2->second); - MS_EXCEPTION_IF_NULL(data_input2); - AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()}; - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_reduce->set_abstract(abstract_tuple); - return bn_training_reduce; -} - -void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, - const std::vector &bn_training_reduce_outputs, - std::vector *bn_training_update_inputs) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(bn_training_update_inputs); - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; - } - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; - } - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; - } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; - } - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } - *bn_training_update_inputs = { - NewValueNode(std::make_shared(kBNTrainingUpdateOpName)), - utils::cast(iter_data_input0->second), - bn_training_reduce_outputs[0], 
- bn_training_reduce_outputs[1], - utils::cast(iter_data_input1->second), - utils::cast(iter_data_input2->second), - utils::cast(iter_variable_input0->second), - utils::cast(iter_variable_input1->second), - }; -} - -void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn, - std::vector *abstract_list) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(bn); - MS_EXCEPTION_IF_NULL(abstract_list); - auto bn_abstract_tuple = dyn_cast(bn->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() < kBnOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " - << bn_abstract_tuple->elements().size(); - } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; - } - auto variable_input0 = utils::cast(iter_variable_input0->second); - MS_EXCEPTION_IF_NULL(variable_input0); - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; - } - auto variable_input1 = utils::cast(iter_variable_input1->second); - MS_EXCEPTION_IF_NULL(variable_input1); - *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), - bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; -} - -AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( - const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - const std::vector &bn_training_reduce_outputs) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - // Set input - std::vector bn_training_update_inputs; - GetBNTrainingUpdateInputs(equiv, bn_training_reduce_outputs, &bn_training_update_inputs); - auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs); - MS_EXCEPTION_IF_NULL(bn_training_update); - // Set abstract - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); - MS_EXCEPTION_IF_NULL(bn); - AbstractBasePtrList abstract_list; - GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list); - auto abstract_tuple = std::make_shared(abstract_list); - bn_training_update->set_abstract(abstract_tuple); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn, bn_training_update); - ValuePtr factor = GetFactor(equiv); - if (factor == nullptr) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrFactor, factor, bn_training_update); - AnfAlgo::SetNodeAttr(kAttrIsRef, MakeValue(true), bn_training_update); - bn_training_update->set_scope(node->scope()); - return bn_training_update; -} - -const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(node); - AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node, equiv); - std::vector bn_training_reduce_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum, - &bn_training_reduce_outputs); - 
AnfNodePtr bn_training_update = CreateBNTrainingUpdate(func_graph, node, equiv, bn_training_reduce_outputs); - if (bn_training_update == nullptr) { - MS_LOG(DEBUG) << "Create BNTrainingUpdate failed for bn node " << node->DebugString(); - return nullptr; - } - std::vector bn_training_update_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update, kBNTrainingUpdateOutputNum, - &bn_training_update_outputs); - if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is " - << bn_training_update_outputs.size(); - } - // Replace old bn outputs with new outputs - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."; - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); - std::vector bn_outputs; - GetBNOutput(func_graph, bn, &bn_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - for (const auto &output : bn_outputs) { - MS_EXCEPTION_IF_NULL(output); - if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { - continue; - } - auto tuple_getitem_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem_cnode); - AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - auto value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - if (index == kReplaceOutputIndex0 || index == kReplaceOutputIndex1) { - (void)manager->Replace(output, bn_training_update_outputs[index]); - } - } - return bn_training_update_outputs[0]; -} - -const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); - VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); - VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); - VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); - VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); - VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} - -const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VarPtr index1 = std::make_shared(IsC); - VarPtr index2 = 
std::make_shared(IsC); - VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); - VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); - VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef cast_variable_input0 = VectorRef({prim::kPrimCast, variable_input0_var_}); - VectorRef cast_variable_input1 = VectorRef({prim::kPrimCast, variable_input1_var_}); - VectorRef sub0 = VectorRef({prim::kPrimSub, cast_variable_input0, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, cast_variable_input1, tuple_getitem2}); - VectorRef cast0 = VectorRef({prim::kPrimCast, sub0}); - VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); - VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h deleted file mode 100644 index f476e96062..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
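Editor's note: the Sub/Mul/AssignSub chain matched above is the moving-statistics update that BNTrainingUpdate performs internally with the extracted `factor` attribute. A standalone sketch of the scalar arithmetic (illustrative names, not MindSpore code):

    #include <cassert>
    #include <cmath>

    // AssignSub(var, (var - batch_stat) * factor), the matched subgraph,
    // is the exponential moving average (1 - factor)*var + factor*batch_stat.
    float UpdateMovingStat(float moving, float batch_stat, float factor) {
      moving -= (moving - batch_stat) * factor;
      return moving;
    }

    int main() {
      float updated = UpdateMovingStat(1.0f, 0.5f, 0.1f);
      assert(std::fabs(updated - (0.9f * 1.0f + 0.1f * 0.5f)) < 1e-6f);
    }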
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
deleted file mode 100644
index f476e96062..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h
+++ /dev/null
@@ -1,83 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include "pre_activate/common/optimizer.h"
-#include "utils/utils.h"
-
-namespace mindspore {
-namespace opt {
-class FusedBatchNormFusion : public PatternProcessPass {
- public:
-  explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true)
-      : PatternProcessPass(name, multigraph),
-        data_input0_var_(std::make_shared<Var>()),
-        data_input1_var_(std::make_shared<Var>()),
-        data_input2_var_(std::make_shared<Var>()),
-        variable_input0_var_(std::make_shared<Var>()),
-        variable_input1_var_(std::make_shared<Var>()),
-        constant_input0_var_(std::make_shared<Var>()),
-        constant_input1_var_(std::make_shared<Var>()),
-        batch_norm_var_(std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()))) {}
-  ~FusedBatchNormFusion() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- protected:
-  AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                    const EquivPtr &equiv) const;
-  void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
-                                 std::vector<AnfNodePtr> *bn_training_update_inputs) const;
-  void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
-                                       std::vector<AbstractBasePtr> *abstract_list) const;
-  AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
-                                    const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
-  ValuePtr GetFactor(const EquivPtr &equiv) const;
-
-  VarPtr data_input0_var_;
-  VarPtr data_input1_var_;
-  VarPtr data_input2_var_;
-  VarPtr variable_input0_var_;
-  VarPtr variable_input1_var_;
-  VarPtr constant_input0_var_;
-  VarPtr constant_input1_var_;
-  VarPtr batch_norm_var_;
-};
-
-class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
- public:
-  explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true)
-      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
-
-  ~FusedBatchNormMixPrecisionFusion0() override = default;
-  const BaseRef DefinePattern() const override;
-};
-
-class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion {
- public:
-  explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true)
-      : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
-
-  ~FusedBatchNormMixPrecisionFusion1() override = default;
-  const BaseRef DefinePattern() const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc
deleted file mode 100644
index b82efdf86a..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/input_to_output_registry.cc
+++ /dev/null
@@ -1,122 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
-#include <utility>
-#include "utils/utils.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-bool ApplyRMSPropPreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-
-bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) {
-  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
-  return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
-}
-
-bool SparseApplyRMSPropPreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-
-bool ApplyAdagradV2PreCheck(const CNodePtr &node) {
-  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
-  return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
-}
-
-bool ApplyKerasMomentumPreCheck(const CNodePtr &node) {
-  TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
-  return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
-}
-
-bool SparseApplyFtrlPreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-
-bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-
-bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-
-bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) {
-  return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
-}
-}  // namespace
-InputToOutputRegistry::InputToOutputRegistry() {
-  Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck);
-  Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck);
-  Register(kApplyAdagradOpName, {1});
-  Register(kApplyAdagradDAName, {1, 2});
-  Register(kApplyAdadeltaOpName, {1, 2});
-  Register(kApplyPowerSignOpName, {1});
-  Register(kApplyProximalAdagradOpName, {1});
-  Register(kApplyAdaMaxOpName, {1, 2});
-  Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck);
-  Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck);
-  Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck);
-  Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck);
-  Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck);
-  Register(kSparseApplyProximalAdagradOpName, {1});
-  Register(kSparseApplyAdagradOpName, {1});
-  Register(kApplyFtrlV2OpName, {1, 2});
-  Register(kApplyMomentumOpName, {1});
-  Register(kApplyFtrlOpName, {1, 2});
-  Register(kApplyAdamOpName, {1, 2});
-  Register(kApplyCenteredRMSPropOpName, {1, 2, 3});
-  Register(kApplyAddSignOpName, {1});
-  Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck);
-  Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck);
-  Register(kApplyAdamWithAmsgradOpName, {1, 2});
-}
-
-InputToOutputRegistry &InputToOutputRegistry::Instance() {
-  static InputToOutputRegistry instance;
-  return instance;
-}
-
-void InputToOutputRegistry::Register(const InputToOutputRegister &reg) {
-  auto op_name = reg.op_name();
-  if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) {
-    (void)op_input_to_output_map_.insert(make_pair(op_name, reg));
-    MS_LOG(DEBUG) << op_name << " input2output registered successfully!";
-  }
-}
-
-void InputToOutputRegistry::Register(const std::string &op_name, const std::vector<size_t> &input_indices,
-                                     const PreCheckFunc &pre_check_func) {
-  if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) {
-    InputToOutputRegister reg(op_name, pre_check_func);
-    reg.set_input_indices(input_indices);
-    (void)op_input_to_output_map_.insert(make_pair(op_name, reg));
-    MS_LOG(DEBUG) << op_name << " input2output registered successfully!";
-  }
-}
-
-bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const {
-  if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) {
-    *reg = op_input_to_output_map_.at(op_name);
-    MS_LOG(DEBUG) << op_name << " input2output found in registry.";
-    return true;
-  }
-  return false;
-}
-}  // namespace opt
-}  // namespace mindspore
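Editor's note: the file above is a classic name-keyed singleton registry with an optional pre-check predicate per entry. A standalone sketch of the same pattern (all names illustrative, not the MindSpore API):

    #include <cstdio>
    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Entry {
      std::vector<size_t> input_indices;
      std::function<bool()> pre_check = [] { return true; };  // default accepts
    };

    class MiniRegistry {
     public:
      static MiniRegistry &Instance() {
        static MiniRegistry instance;  // process-wide singleton, like InputToOutputRegistry
        return instance;
      }
      void Register(const std::string &op, Entry entry) { map_.emplace(op, std::move(entry)); }
      bool Get(const std::string &op, Entry *out) const {
        auto it = map_.find(op);
        if (it == map_.end()) return false;  // op not registered
        *out = it->second;
        return true;
      }

     private:
      std::unordered_map<std::string, Entry> map_;
    };

    int main() {
      MiniRegistry::Instance().Register("ApplyMomentum", {{1}});
      Entry e;
      if (MiniRegistry::Instance().Get("ApplyMomentum", &e)) std::printf("%zu\n", e.input_indices[0]);
    }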
- */ - -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -bool LambNextMVRule::IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - std::vector *old_pattern_outputs) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto real_div0 = GetAnfNodeByVar(equiv, real_div0_var_); - auto real_div2 = GetAnfNodeByVar(equiv, real_div2_var_); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &users = manager->node_users(); - if (users.find(real_div0) == users.end() || users[real_div0].size() < 2) { - return false; - } - AnfNodeIndexSet real_div0_outputs = users[real_div0]; - auto iter = std::find_if(real_div0_outputs.begin(), real_div0_outputs.end(), - [&real_div2, &equiv, this](const std::pair &node_index) { - return node_index.first != real_div2 && node_index.second == 1 && - MatchAnotherPattern(node_index.first, equiv); - }); - if (iter == real_div0_outputs.end()) { - return false; - } - - (*old_pattern_outputs).push_back(node); - (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add0_var_)); - (*old_pattern_outputs).push_back(GetAnfNodeByVar(equiv, add1_var_)); - (*old_pattern_outputs).push_back(iter->first); - - return true; -} - -AnfNodePtr LambNextMVRule::CreateLambNextMVNode(const FuncGraphPtr &func_graph, - const std::vector &old_pattern_outputs, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto prim = std::make_shared(kLambNextMVOpName); - std::vector lamb_next_mv_rule_inputs = {NewValueNode(prim)}; - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input0_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input1_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input2_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input3_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input4_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input5_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[input6_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul0_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul1_sub_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul2_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul3_sub1_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[mul4_x_])); - lamb_next_mv_rule_inputs.push_back(utils::cast((*equiv)[add2_y_])); - auto lamb_next_mv_rule = func_graph->NewCNode(lamb_next_mv_rule_inputs); - MS_EXCEPTION_IF_NULL(lamb_next_mv_rule); - - // Set abstract of new node - AbstractBasePtrList new_abstracts; - (void)std::transform(old_pattern_outputs.begin(), old_pattern_outputs.end(), std::back_inserter(new_abstracts), - [](const AnfNodePtr &out) { return out->abstract(); }); - auto abstract_tuple = std::make_shared(new_abstracts); - MS_EXCEPTION_IF_NULL(abstract_tuple); - lamb_next_mv_rule->set_abstract(abstract_tuple); - - // Create tuple_getitem node for outputs - std::vector lamb_next_mv_rule_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, lamb_next_mv_rule, kLambNextMVRuleOutputNum, &lamb_next_mv_rule_outputs); - - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(old_pattern_outputs[1], 
lamb_next_mv_rule_outputs[1]); - (void)manager->Replace(old_pattern_outputs[2], lamb_next_mv_rule_outputs[2]); - (void)manager->Replace(old_pattern_outputs[3], lamb_next_mv_rule_outputs[3]); - - return lamb_next_mv_rule_outputs[0]; -} - -bool LambNextMVRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { - return IsSameNode(equiv1, equiv2, real_div0_var_) && IsSameNode(equiv1, equiv2, real_div1_var_) && - IsSameNode(equiv1, equiv2, add2_y_); -} - -const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - std::vector old_pattern_outputs; - if (!IsRuleMatched(func_graph, node, equiv, &old_pattern_outputs)) { - return nullptr; - } - - return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv); -} - -const BaseRef LambNextMVRuleCond1::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); - auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); - auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, add2_y_, sqrt1}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond2::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); - auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); - auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = 
std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond3::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_}); - auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_}); - auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_}); - auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_}); - auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0}); - - return VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); -} - -BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} - -const BaseRef LambNextMVRuleCond4::DefinePattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - - auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_}); - auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_}); - auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_}); - auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_}); - auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_}); - auto add0 = VectorRef({add0_var_, mul0, mul1}); - auto add1 = VectorRef({add1_var_, mul2, mul3}); - - auto real_div0 = VectorRef({real_div0_var_, add0, input5_}); - auto real_div1 = VectorRef({real_div1_var_, add1, input2_}); - - auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_}); - auto sqrt0 = VectorRef({prim_rsqrt, add2}); - auto real_div2 = VectorRef({real_div2_var_, real_div0, sqrt0}); - - return VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); -} - -BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - const auto prim_real_div = std::make_shared(kRealDivOpName); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - // Two patterns share: real_div0, real_div1, add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, add2_y_}); - VectorRef real_div4 = VectorRef({prim_real_div, real_div0, add4}); - return real_div4; -} -} // namespace 
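Editor's note: the four Cond variants above differ only in operand order; all of them fuse the same two computations, which the rule requires to share real_div0, real_div1, and the epsilon constant add2_y_. A standalone scalar sketch of those two fused outputs as I read the patterns (illustrative values, not MindSpore code):

    #include <cmath>
    #include <cstdio>

    int main() {
      const float real_div0 = 0.02f, real_div1 = 0.5f, eps = 1e-6f, mul4 = 0.001f;
      // Main pattern: mul4 + real_div0 * Rsqrt(real_div1 + eps)
      float main_out = mul4 + real_div0 / std::sqrt(real_div1 + eps);
      // DefineAnotherPattern: real_div0 / (Sqrt(real_div1) + eps), on the shared nodes
      float other_out = real_div0 / (std::sqrt(real_div1) + eps);
      std::printf("main=%f other=%f\n", main_out, other_out);
    }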
opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h deleted file mode 100644 index 0089c33f87..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ - -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambNextMVRule : public MultipleOutputPatternProcessPass { - public: - explicit LambNextMVRule(const std::string &name = "", bool multigraph = true) - : MultipleOutputPatternProcessPass(name, multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - input5_ = std::make_shared(); - input6_ = std::make_shared(); - mul0_x_ = std::make_shared(); - mul1_sub_ = std::make_shared(); - mul2_x_ = std::make_shared(); - mul3_sub1_ = std::make_shared(); - mul4_x_ = std::make_shared(); - add2_y_ = std::make_shared(); - real_div0_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - real_div1_var_ = std::make_shared(std::make_shared(kRealDivOpName)); - real_div2_var_ = std::make_shared(std::make_shared(prim::kPrimMul->name())); - add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); - } - ~LambNextMVRule() override = default; - const BaseRef DefinePattern() const override = 0; - BaseRef DefineAnotherPattern() const override = 0; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override; - - protected: - bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, - std::vector *old_pattern_outputs) const; - AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector &old_pattern_outputs, - const EquivPtr &equiv) const; - - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr input5_; - VarPtr input6_; - VarPtr mul0_x_; - VarPtr mul1_sub_; - VarPtr mul2_x_; - VarPtr mul3_sub1_; - VarPtr mul4_x_; - VarPtr add2_y_; - // nodes which two patterns share, and add2_y_ also. 
- VarPtr real_div0_var_; - VarPtr real_div1_var_; - // part of output nodes - VarPtr add0_var_; - VarPtr add1_var_; - // other node - VarPtr real_div2_var_; -}; - -class LambNextMVRuleCond1 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {} - - ~LambNextMVRuleCond1() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond2 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {} - - ~LambNextMVRuleCond2() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond3 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {} - - ~LambNextMVRuleCond3() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; - -class LambNextMVRuleCond4 : public LambNextMVRule { - public: - explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {} - - ~LambNextMVRuleCond4() override = default; - const BaseRef DefinePattern() const override; - BaseRef DefineAnotherPattern() const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc deleted file mode 100644 index 0e3cd28a66..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.cc +++ /dev/null @@ -1,278 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
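For orientation before the next deleted file: every LambNextMV* rule above pairs a declarative pattern (DefinePattern) with a second fragment (DefineAnotherPattern) that must bind the same real_div0/real_div1 nodes before the fusion fires. Below is a minimal, self-contained model of that two-pattern-with-shared-bindings idea; the Node/Pattern types and the match functions are invented for illustration and are not the MindSpore pattern-engine API.

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string op;                             // operator name, e.g. "RealDiv"
  std::vector<std::shared_ptr<Node>> inputs;  // operand subtrees
};
using NodePtr = std::shared_ptr<Node>;

// A pattern is an op name with sub-patterns, or a named capture variable.
struct Pattern {
  std::string op;   // empty op means "capture anything"
  std::string var;  // capture name used when op is empty
  std::vector<Pattern> sub;
};

// Structural match that records captured nodes, mirroring how the real
// pattern engine fills an EquivPtr for each matched fragment.
bool Match(const Pattern &p, const NodePtr &n, std::map<std::string, NodePtr> *equiv) {
  if (p.op.empty()) {  // variable: bind (or re-check an existing binding)
    auto it = equiv->find(p.var);
    if (it != equiv->end()) return it->second == n;
    (*equiv)[p.var] = n;
    return true;
  }
  if (!n || n->op != p.op || n->inputs.size() != p.sub.size()) return false;
  for (size_t i = 0; i < p.sub.size(); ++i) {
    if (!Match(p.sub[i], n->inputs[i], equiv)) return false;
  }
  return true;
}

// Two fragments "share" nodes when the listed vars bound identical nodes in
// both matches -- the role IsShareNodes plays for real_div0/real_div1/add2_y_.
bool ShareNodes(const std::map<std::string, NodePtr> &e1, const std::map<std::string, NodePtr> &e2,
                const std::vector<std::string> &shared_vars) {
  for (const auto &v : shared_vars) {
    auto i1 = e1.find(v), i2 = e2.find(v);
    if (i1 == e1.end() || i2 == e2.end() || i1->second != i2->second) return false;
  }
  return true;
}

int main() {
  auto shared = std::make_shared<Node>(Node{"RealDiv", {}});
  auto root1 = std::make_shared<Node>(Node{"Sqrt", {shared}});
  auto root2 = std::make_shared<Node>(Node{"Rsqrt", {shared}});
  Pattern p1{"Sqrt", "", {Pattern{"", "real_div", {}}}};
  Pattern p2{"Rsqrt", "", {Pattern{"", "real_div", {}}}};
  std::map<std::string, NodePtr> e1, e2;
  bool fuse = Match(p1, root1, &e1) && Match(p2, root2, &e2) && ShareNodes(e1, e2, {"real_div"});
  std::cout << (fuse ? "patterns share real_div, fuse\n" : "no fusion\n");
  return 0;
}

The real engine layers sequence variables (SeqVar), primitive-constrained variables, and multi-output bookkeeping on top of this skeleton, but the share-check logic is the same shape.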
- */ -#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" - -namespace mindspore { -namespace opt { -AnfNodePtr LambNextMVWithDecayRule::GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, - const AnfNodePtr &new_node, const AnfNodePtr &add3, - const AnfNodePtr &add5, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(new_node); - MS_EXCEPTION_IF_NULL(add3); - MS_EXCEPTION_IF_NULL(add5); - MS_EXCEPTION_IF_NULL(equiv); - auto add0 = GetAnfNodeByVar(equiv, add0_var_); - MS_EXCEPTION_IF_NULL(add0); - auto add1 = GetAnfNodeByVar(equiv, add1_var_); - MS_EXCEPTION_IF_NULL(add1); - - // Set abstract of new node - AbstractBasePtrList new_node_list; - new_node_list.push_back(add3->abstract()); - new_node_list.push_back(add0->abstract()); - new_node_list.push_back(add1->abstract()); - new_node_list.push_back(add5->abstract()); - auto abstract_tuple = std::make_shared(new_node_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextMVWithDecayOutputNum, &new_node_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add3, new_node_outputs[0]); - (void)manager->Replace(add0, new_node_outputs[1]); - (void)manager->Replace(add1, new_node_outputs[2]); - return new_node_outputs[3]; -} - -AnfNodePtr LambNextMVWithDecayRule::CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, - const AnfNodePtr &add3, const AnfNodePtr &add5, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(add3); - MS_EXCEPTION_IF_NULL(equiv); - // Create new node with all the inputs - auto prim = std::make_shared(kLambNextMVWithDecayOpName); - std::vector new_node_inputs = {NewValueNode(prim)}; - for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) { - auto input_node = utils::cast((*equiv)[input_vars_[i]]); - MS_EXCEPTION_IF_NULL(input_node); - new_node_inputs.push_back(input_node); - } - for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) { - auto constant_mul_input_node = utils::cast((*equiv)[constant_mul_input_vars_[i]]); - MS_EXCEPTION_IF_NULL(constant_mul_input_node); - new_node_inputs.push_back(constant_mul_input_node); - } - auto constant_add2_y_node = utils::cast((*equiv)[constant_add2_y_]); - MS_EXCEPTION_IF_NULL(constant_add2_y_node); - new_node_inputs.push_back(constant_add2_y_node); - auto new_node = func_graph->NewCNode(new_node_inputs); - return GetLambNextMVWithDecayOutput(func_graph, new_node, add3, add5, equiv); -} - -bool LambNextMVWithDecayRule::IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const { - return IsSameNode(equiv1, equiv2, mul4_var_) && IsSameNode(equiv1, equiv2, real_div0_var_) && - IsSameNode(equiv1, equiv2, real_div1_var_) && IsSameNode(equiv1, equiv2, constant_add2_y_); -} - -const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - AnfNodePtr mul4 = GetAnfNodeByVar(equiv, mul4_var_); - MS_EXCEPTION_IF_NULL(mul4); - // Get add3 and match the add3 pattern - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); 
- if (manager->node_users().find(mul4) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input"; - } - AnfNodeIndexSet mul4_outputs = manager->node_users()[mul4]; - auto iter = std::find_if(mul4_outputs.begin(), mul4_outputs.end(), - [&node, &equiv, this](const std::pair &node_index) { - return node_index.first != node && MatchAnotherPattern(node_index.first, equiv); - }); - if (iter != mul4_outputs.end()) { - return CreateLambNextMVWithDecayNode(func_graph, iter->first, node, equiv); - } - return nullptr; -} - -BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const { - const auto 
prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1}); - VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4}); - return add5; -} - -BaseRef LambNextMVWithDecayRuleCond4::DefineAnotherPattern() const { - const auto prim_rsqrt = std::make_shared(kRsqrtOpName); - MS_EXCEPTION_IF_NULL(prim_rsqrt); - VarPtr Xs = std::make_shared(); - VarPtr Ys = std::make_shared(); - VarPtr Zs = std::make_shared(); - 
MS_EXCEPTION_IF_NULL(Xs); - MS_EXCEPTION_IF_NULL(Ys); - MS_EXCEPTION_IF_NULL(Zs); - // Two patterns share: real_div0, real_div1, mul4, constant_add2_y_ - VectorRef real_div0 = VectorRef({real_div0_var_, Xs}); - VectorRef real_div1 = VectorRef({real_div1_var_, Ys}); - VectorRef mul4 = VectorRef({mul4_var_, Zs}); - - VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_}); - VectorRef sqrt0 = VectorRef({prim_rsqrt, add2}); - VectorRef real_div2 = VectorRef({prim::kPrimMul, real_div0, sqrt0}); - VectorRef add3 = VectorRef({prim::kPrimTensorAdd, real_div2, mul4}); - return add3; -} - -const BaseRef LambNextMVWithDecayRuleCond4::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(prim_deal_div); - VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]}); - VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]}); - VectorRef add1 = VectorRef({add1_var_, mul2, mul3}); - VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]}); - VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1}); - VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_}); - VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]}); - VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]}); - VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); - VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]}); - VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4}); - VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]}); - VectorRef add5 = VectorRef({prim::kPrimTensorAdd, real_div4, mul4}); - return add5; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h deleted file mode 100644 index 5d61975197..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
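A note on the replacement scheme the decay rules share: the pass emits one fused CNode, stamps it with a tuple abstract assembled from the abstracts of the nodes it supersedes, splits it with TupleGetItem, and splices each element over an old output. A condensed sketch of that flow, using only calls visible in the deleted code above (argument shapes are hedged, not checked against the real signatures):

// Hedged sketch, not compiled against MindSpore: condensed from
// GetLambNextMVWithDecayOutput. 'fused' is the new fusion CNode; 'old_outputs'
// lists the nodes it replaces, with the matched root last.
AnfNodePtr ReplaceWithFusedOutputs(const FuncGraphPtr &func_graph, const CNodePtr &fused,
                                   const std::vector<AnfNodePtr> &old_outputs, size_t output_num) {
  // One abstract per original output, bundled into a tuple abstract.
  AbstractBasePtrList abstracts;
  for (const auto &old_node : old_outputs) {
    abstracts.push_back(old_node->abstract());
  }
  fused->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
  // TupleGetItem nodes expose each element of the fused result.
  std::vector<AnfNodePtr> getitems;
  CreateMultipleOutputsOfAnfNode(func_graph, fused, output_num, &getitems);
  // Splice each getitem over the node it replaces; the last one is returned
  // so the caller can substitute it for the matched root.
  auto manager = func_graph->manager();
  for (size_t i = 0; i + 1 < old_outputs.size(); ++i) {
    (void)manager->Replace(old_outputs[i], getitems[i]);
  }
  return getitems.back();
}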
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
deleted file mode 100644
index 5d61975197..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h
+++ /dev/null
@@ -1,110 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-class LambNextMVWithDecayRule : public MultipleOutputPatternProcessPass {
- public:
-  explicit LambNextMVWithDecayRule(const std::string &name = "", bool multigraph = true)
-      : MultipleOutputPatternProcessPass(name, multigraph) {
-    for (size_t i = 0; i < kLambNextMVWithDecayInputNum; ++i) {
-      input_vars_.push_back(std::make_shared<Var>());
-    }
-    for (size_t i = 0; i < kLambNextMVWithDecayConstantMulInputNum; ++i) {
-      constant_mul_input_vars_.push_back(std::make_shared<Var>());
-    }
-    constant_add2_y_ = std::make_shared<Var>();
-    mul4_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
-    real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
-    real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
-    add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-    add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimTensorAdd->name()));
-  }
-
-  ~LambNextMVWithDecayRule() override = default;
-  const BaseRef DefinePattern() const override = 0;
-  BaseRef DefineAnotherPattern() const override = 0;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-  bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
-
- protected:
-  AnfNodePtr GetLambNextMVWithDecayOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &new_node,
-                                          const AnfNodePtr &add3, const AnfNodePtr &add5, const EquivPtr &equiv) const;
-  AnfNodePtr CreateLambNextMVWithDecayNode(const FuncGraphPtr &func_graph, const AnfNodePtr &add3,
-                                           const AnfNodePtr &add5, const EquivPtr &equiv) const;
-  std::vector<VarPtr> input_vars_;
-  std::vector<VarPtr> constant_mul_input_vars_;
-  // nodes which two patterns share
-  VarPtr constant_add2_y_;
-  VarPtr mul4_var_;
-  VarPtr real_div0_var_;
-  VarPtr real_div1_var_;
-  // part of output nodes
-  VarPtr add0_var_;
-  VarPtr add1_var_;
-};
-
-class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
- public:
-  explicit LambNextMVWithDecayRuleCond1(bool multigraph = true)
-      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {}
-
-  ~LambNextMVWithDecayRuleCond1() override = default;
-  const BaseRef DefinePattern() const override;
-  BaseRef DefineAnotherPattern() const override;
-};
-
-class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
- public:
-  explicit LambNextMVWithDecayRuleCond2(bool multigraph = true)
-      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {}
-
-  ~LambNextMVWithDecayRuleCond2() override = default;
-  const BaseRef DefinePattern() const override;
-  BaseRef DefineAnotherPattern() const override;
-};
-
-class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
- public:
-  explicit LambNextMVWithDecayRuleCond3(bool multigraph = true)
-      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {}
-
-  ~LambNextMVWithDecayRuleCond3() override = default;
-  const BaseRef DefinePattern() const override;
-  BaseRef DefineAnotherPattern() const override;
-};
-
-class LambNextMVWithDecayRuleCond4 : public LambNextMVWithDecayRule {
- public:
-  explicit LambNextMVWithDecayRuleCond4(bool multigraph = true)
-      : LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond4", multigraph) {}
-
-  ~LambNextMVWithDecayRuleCond4() override = default;
-  const BaseRef DefinePattern() const override;
-  BaseRef DefineAnotherPattern() const override;
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_RULE_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
deleted file mode 100644
index 26828f2137..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc
+++ /dev/null
@@ -1,208 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h"
-
-#include <algorithm>
-#include <memory>
-#include <tuple>
-#include <vector>
-
-#include "session/anf_runtime_algorithm.h"
-#include "optimizer/opt.h"
-
-namespace mindspore {
-namespace opt {
-namespace {
-std::tuple<AnfNodePtr, AnfNodePtr, AnfNodePtr, AnfNodePtr> GetSharedNodes(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto add3 = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(add3);
-  if (add3->inputs().size() < kAddInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum;
-  }
-  auto real_div2_anf = add3->input(1);
-  MS_EXCEPTION_IF_NULL(real_div2_anf);
-  auto real_div2 = real_div2_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div2);
-  if (real_div2->inputs().size() < kRealDivInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum;
-  }
-  auto sqrt0_anf = real_div2->input(2);
-  MS_EXCEPTION_IF_NULL(sqrt0_anf);
-  auto sqrt0 = sqrt0_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(sqrt0);
-  if (sqrt0->inputs().size() < kRsqrtInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum;
-  }
-  auto add2_anf = sqrt0->input(1);
-  MS_EXCEPTION_IF_NULL(add2_anf);
-  auto add2 = add2_anf->cast<CNodePtr>();
-  if (add2->inputs().size() < kAddInputNum) {
-    MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum;
-  }
-  return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2));
-}
-
-bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfNodePtr &real_div0,
-                      const AnfNodePtr &real_div1, const AnfNodePtr &add2_y) {
-  if (node == nullptr || !node->isa<CNode>()) {
-    return false;
-  }
-  auto add5 = node->cast<CNodePtr>();
-  if (AnfAlgo::GetCNodeName(add5) != prim::kPrimTensorAdd->name() || add5->inputs().size() != kAddInputNum) {
-    return false;
-  }
-  auto real_div4_anf = add5->input(1);
-  if (real_div4_anf == nullptr || !real_div4_anf->isa<CNode>()) {
-    return false;
-  }
-  auto real_div4 = real_div4_anf->cast<CNodePtr>();
-  if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) {
-    return false;
-  }
-  auto add4_anf = real_div4->input(2);
-  if (add4_anf == nullptr || !add4_anf->isa<CNode>()) {
-    return false;
-  }
-  auto add4 = add4_anf->cast<CNodePtr>();
-  if (AnfAlgo::GetCNodeName(add4) != prim::kPrimTensorAdd->name() || add4->inputs().size() != kAddInputNum) {
-    return false;
-  }
-  auto sqrt1_anf = add4->input(1);
-  if (sqrt1_anf == nullptr || !sqrt1_anf->isa<CNode>()) {
-    return false;
-  }
-  auto sqrt1 = sqrt1_anf->cast<CNodePtr>();
-  if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) {
-    return false;
-  }
-  return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 &&
-         *add4->input(2) == *add2_y;
-}
-
-std::tuple<AnfNodePtr, AnfNodePtr> GetAdd0Add1Nodes(const AnfNodePtr &real_div0_anf, const AnfNodePtr &real_div1_anf) {
-  MS_EXCEPTION_IF_NULL(real_div0_anf);
-  MS_EXCEPTION_IF_NULL(real_div1_anf);
-  auto real_div0 = real_div0_anf->cast<CNodePtr>();
-  auto real_div1 = real_div1_anf->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(real_div0);
-  MS_EXCEPTION_IF_NULL(real_div1);
-  if (real_div0->inputs().size() != kRealDivInputNum) {
-    MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size";
-  }
-  if (real_div1->inputs().size() != kRealDivInputNum) {
-    MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size";
-  }
-  return std::make_tuple(real_div0->input(1), real_div1->input(1));
-}
-}  // namespace
-
-std::vector<AnfNodePtr> LambNextMVWithDecayV1Rule::GetFusionNodeInputs(const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(equiv);
-  auto i0 = utils::cast<AnfNodePtr>((*equiv)[input0_]);
-  auto i1 = utils::cast<AnfNodePtr>((*equiv)[input1_]);
-  auto i2 = utils::cast<AnfNodePtr>((*equiv)[input2_]);
-  auto i3 = utils::cast<AnfNodePtr>((*equiv)[input3_]);
-  auto i4 = utils::cast<AnfNodePtr>((*equiv)[input4_]);
-  auto i5 = utils::cast<AnfNodePtr>((*equiv)[input5_]);
-  auto i6 = utils::cast<AnfNodePtr>((*equiv)[input6_]);
-  auto i7 = utils::cast<AnfNodePtr>((*equiv)[mul0_x_]);
-  auto i8 = utils::cast<AnfNodePtr>((*equiv)[mul1_sub_]);
-  auto i9 = utils::cast<AnfNodePtr>((*equiv)[mul2_x_]);
-  auto i10 = utils::cast<AnfNodePtr>((*equiv)[mul3_sub1_]);
-  auto i11 = utils::cast<AnfNodePtr>((*equiv)[mul4_x_]);
-  auto i12 = utils::cast<AnfNodePtr>((*equiv)[add2_y_]);
-  auto prim = std::make_shared<Primitive>(kLambNextMVWithDecayV1OpName);
-  return {NewValueNode(prim), i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12};
-}
-
-const BaseRef LambNextMVWithDecayV1Rule::DefinePattern() const {
-  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
-  const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
-  VectorRef mul3({prim::kPrimMul, mul3_sub1_, input0_});
-  VectorRef mul2({prim::kPrimMul, mul2_x_, input1_});
-  VectorRef add1({prim::kPrimTensorAdd, mul2, mul3});
-  VectorRef real_div1({prim_real_div, add1, input2_});
-  VectorRef add2({prim::kPrimTensorAdd, real_div1, add2_y_});
-  VectorRef mul0({prim::kPrimMul, mul0_x_, input4_});
-  VectorRef mul1({prim::kPrimMul, mul1_sub_, input3_});
-  VectorRef sqrt0({prim_rsqrt, add2});
-  VectorRef add0({prim::kPrimTensorAdd, mul0, mul1});
-  VectorRef real_div0({prim_real_div, add0, input5_});
-  VectorRef real_div2({prim::kPrimMul, real_div0, sqrt0});
-  VectorRef mul4({prim::kPrimMul, mul4_x_, input6_});
-  VectorRef add3({prim::kPrimTensorAdd, real_div2, mul4});
-  return add3;
-}
-
-const AnfNodePtr LambNextMVWithDecayV1Rule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                    const EquivPtr &equiv) const {
-  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
-    return nullptr;
-  }
-  if (!CheckSupportDataType(node, kFloatDataTypeSet)) {
-    return nullptr;
-  }
-  AnfNodePtr mul4 = nullptr;
-  AnfNodePtr real_div0 = nullptr;
-  AnfNodePtr real_div1 = nullptr;
-  AnfNodePtr add2_y = nullptr;
-  std::tie(mul4, real_div0, real_div1, add2_y) = GetSharedNodes(node);
-
-  auto manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  if (manager->node_users().find(mul4) == manager->node_users().end()) {
-    MS_LOG(EXCEPTION) << "The Mul4 should be used by at least another node input";
-  }
-  AnfNodeIndexSet mul4_output_node_index_set = manager->node_users()[mul4];
-  auto iter = std::find_if(
-    mul4_output_node_index_set.begin(), mul4_output_node_index_set.end(),
-    [&node, &mul4, &real_div0, &real_div1, &add2_y](const std::pair<AnfNodePtr, int> &node_index) {
-      return node_index.first != node && MatchAdd5Pattern(node_index.first, mul4, real_div0, real_div1, add2_y);
-    });
-  if (iter == mul4_output_node_index_set.end()) {
-    return nullptr;
-  }
-
-  std::vector<AnfNodePtr> inputs = GetFusionNodeInputs(equiv);
-  auto fusion_node = func_graph->NewCNode(inputs);
-  MS_EXCEPTION_IF_NULL(fusion_node);
-  fusion_node->set_scope(node->scope());
-
-  AnfNodePtr add0 = nullptr;
-  AnfNodePtr add1 = nullptr;
-  AnfNodePtr add5 = iter->first;
-  std::tie(add0, add1) = GetAdd0Add1Nodes(real_div0, real_div1);
-  auto types = {AnfAlgo::GetOutputInferDataType(node, 0), AnfAlgo::GetOutputInferDataType(add0, 0),
-                AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add5, 0)};
-  auto shapes = {AnfAlgo::GetOutputInferShape(node, 0), AnfAlgo::GetOutputInferShape(add0, 0),
-                 AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add5, 0)};
-  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fusion_node.get());
-
-  std::vector<AnfNodePtr> fusion_node_outputs;
-  CreateMultipleOutputsOfAnfNode(func_graph, fusion_node, kLambNextMVWithDecayV1OutputNum, &fusion_node_outputs);
-  if (fusion_node_outputs.size() != kLambNextMVWithDecayV1OutputNum) {
-    MS_LOG(ERROR) << "create multiple outputs for fusion node fail!";
-    return nullptr;
-  }
-
-  (void)manager->Replace(add0, fusion_node_outputs[1]);
-  (void)manager->Replace(add1, fusion_node_outputs[2]);
-  (void)manager->Replace(add5, fusion_node_outputs[3]);
-  return fusion_node_outputs[0];
-}
-}  // namespace opt
-}  // namespace mindspore
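The V1 rule just deleted matches its second fragment imperatively instead of through DefineAnotherPattern: starting from the shared Mul4 node, it scans the manager's node_users map for a sibling consumer that passes MatchAdd5Pattern. The idiom reduces to the sketch below (hedged; FindSiblingConsumer and its signature are mine, and <functional> is assumed for std::function):

// Find another consumer of 'shared' that satisfies a structural predicate.
// AnfNodeIndexSet maps a node to its (consumer, input-index) pairs, as used
// in Process above. Not compiled against MindSpore.
AnfNodePtr FindSiblingConsumer(const FuncGraphManagerPtr &manager, const AnfNodePtr &shared,
                               const AnfNodePtr &self,
                               const std::function<bool(const AnfNodePtr &)> &matches) {
  const auto &users = manager->node_users();
  auto it = users.find(shared);
  if (it == users.end()) {
    return nullptr;  // the deleted code raises MS_LOG(EXCEPTION) here instead
  }
  for (const auto &node_index : it->second) {
    if (node_index.first != self && matches(node_index.first)) {
      return node_index.first;  // e.g. the Add5 root of the second fragment
    }
  }
  return nullptr;
}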
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h
deleted file mode 100644
index ff14a253dd..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h
+++ /dev/null
@@ -1,68 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_
-
-#include <memory>
-#include <vector>
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace opt {
-class LambNextMVWithDecayV1Rule : public PatternProcessPass {
- public:
-  explicit LambNextMVWithDecayV1Rule(bool multigraph = true)
-      : PatternProcessPass("lamb_next_mv_with_decay_v1_rule", multigraph) {
-    input0_ = std::make_shared<Var>();
-    input1_ = std::make_shared<Var>();
-    input2_ = std::make_shared<Var>();
-    input3_ = std::make_shared<Var>();
-    input4_ = std::make_shared<Var>();
-    input5_ = std::make_shared<Var>();
-    input6_ = std::make_shared<Var>();
-    mul0_x_ = std::make_shared<Var>();
-    mul1_sub_ = std::make_shared<Var>();
-    mul2_x_ = std::make_shared<Var>();
-    mul3_sub1_ = std::make_shared<Var>();
-    mul4_x_ = std::make_shared<Var>();
-    add2_y_ = std::make_shared<Var>();
-  }
-
-  ~LambNextMVWithDecayV1Rule() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  std::vector<AnfNodePtr> GetFusionNodeInputs(const EquivPtr &equiv) const;
-  VarPtr input0_;
-  VarPtr input1_;
-  VarPtr input2_;
-  VarPtr input3_;
-  VarPtr input4_;
-  VarPtr input5_;
-  VarPtr input6_;
-  VarPtr mul0_x_;
-  VarPtr mul1_sub_;
-  VarPtr mul2_x_;
-  VarPtr mul3_sub1_;
-  VarPtr mul4_x_;
-  VarPtr add2_y_;
-};
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_MV_WITH_DECAY_V1_RULE_H_
- */ -#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" -#include -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - std::vector new_node_inputs; - auto prim = std::make_shared(kLambNextRightOpName); - MS_EXCEPTION_IF_NULL(prim); - new_node_inputs.push_back(NewValueNode(prim)); - auto input0 = utils::cast((*equiv)[input0_]); - MS_EXCEPTION_IF_NULL(input0); - new_node_inputs.push_back(input0); - auto input1 = utils::cast((*equiv)[input1_]); - MS_EXCEPTION_IF_NULL(input1); - new_node_inputs.push_back(input1); - auto mul2_x = utils::cast((*equiv)[mul2_x_]); - MS_EXCEPTION_IF_NULL(mul2_x); - new_node_inputs.push_back(mul2_x); - auto mul3_x = utils::cast((*equiv)[mul3_x_]); - MS_EXCEPTION_IF_NULL(mul3_x); - new_node_inputs.push_back(mul3_x); - auto true_div1_recip = utils::cast((*equiv)[true_div1_recip_]); - MS_EXCEPTION_IF_NULL(true_div1_recip); - new_node_inputs.push_back(true_div1_recip); - auto add2_y = utils::cast((*equiv)[add2_y_]); - MS_EXCEPTION_IF_NULL(add2_y); - new_node_inputs.push_back(add2_y); - auto new_node = func_graph->NewCNode(new_node_inputs); - return new_node; -} - -const BaseRef LambNextRightRule::DefinePattern() const { - const auto prim_sqrt = std::make_shared(kSqrtOpName); - MS_EXCEPTION_IF_NULL(prim_sqrt); - VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); - VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); - return VectorRef( - {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); -} - -const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto new_node = CreateLambNextRightNode(func_graph, equiv); - MS_EXCEPTION_IF_NULL(new_node); - // Set abstract of new node - auto iter_add1 = (*equiv).find(add1_var_); - if (iter_add1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; - } - auto add1 = utils::cast(iter_add1->second); - MS_EXCEPTION_IF_NULL(add1); - AbstractBasePtrList new_node_abstract_list; - new_node_abstract_list.push_back(add1->abstract()); - new_node_abstract_list.push_back(node->abstract()); - auto abstract_tuple = std::make_shared(new_node_abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - new_node->set_abstract(abstract_tuple); - // Create tuple_getitem node for outputs - std::vector new_node_outputs; - CreateMultipleOutputsOfAnfNode(func_graph, new_node, kLambNextRightOutputNum, &new_node_outputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - (void)manager->Replace(add1, new_node_outputs[0]); - return new_node_outputs[1]; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h deleted file mode 100644 index 3d15001da2..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache 
License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -class LambNextRightRule : public PatternProcessPass { - public: - explicit LambNextRightRule(bool multigraph = true) - : PatternProcessPass("lamb_next_right_rule", multigraph), - input0_(std::make_shared()), - input1_(std::make_shared()), - mul2_x_(std::make_shared()), - mul3_x_(std::make_shared()), - true_div1_recip_(std::make_shared()), - add2_y_(std::make_shared()), - add1_var_(std::make_shared(std::make_shared(prim::kPrimTensorAdd->name()))) {} - - ~LambNextRightRule() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - AnfNodePtr CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; - - VarPtr input0_; - VarPtr input1_; - VarPtr mul2_x_; - VarPtr mul3_x_; - VarPtr true_div1_recip_; - VarPtr add2_y_; - VarPtr add1_var_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_NEXT_RIGHT_RULE_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc deleted file mode 100644 index b5b6d2bb08..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.cc +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "common/utils.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -const BaseRef LambUpdateWithLRRuleFusion::DefinePattern() const { - auto real_div = std::make_shared(kRealDivOpName); - MS_EXCEPTION_IF_NULL(real_div); - auto greater = std::make_shared(kGreaterOpName); - MS_EXCEPTION_IF_NULL(greater); - - VectorRef pattern_real_div0({real_div, input1_, input2_}); - VectorRef pattern_greater0({greater, input0_, constant_greater_max_}); - VectorRef pattern_greater1({greater, input1_, constant_greater_max_}); - VectorRef pattern_select0({prim::kPrimSelect, pattern_greater0, pattern_real_div0, constant_select_}); - VectorRef pattern_select1({prim::kPrimSelect, pattern_greater1, pattern_select0, constant_select_}); - VectorRef pattern_minimum0({prim::kPrimMinimum, pattern_select1, constant_minimum_}); - VectorRef pattern_maximum0({prim::kPrimMaximum, pattern_minimum0, constant_greater_max_}); - VectorRef pattern_mul0({prim::kPrimMul, pattern_maximum0, input3_}); - VectorRef pattern_mul1({prim::kPrimMul, pattern_mul0, input4_}); - VectorRef pattern({prim::kPrimSub, input5_, pattern_mul1}); - return pattern; -} - -const AnfNodePtr LambUpdateWithLRRuleFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto input0 = utils::cast((*equiv)[input0_]); - auto input1 = utils::cast((*equiv)[input1_]); - auto input2 = utils::cast((*equiv)[input2_]); - auto input3 = utils::cast((*equiv)[input3_]); - auto input4 = utils::cast((*equiv)[input4_]); - auto input5 = utils::cast((*equiv)[input5_]); - auto input6 = utils::cast((*equiv)[constant_greater_max_]); - auto input7 = utils::cast((*equiv)[constant_select_]); - auto input8 = utils::cast((*equiv)[constant_minimum_]); - - auto prim = std::make_shared(kLambUpdateWithLROpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), input0, input1, input2, input3, input4, input5, input6, input7, input8}; - auto lamb_update_with_lr = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(lamb_update_with_lr); - - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, lamb_update_with_lr.get()); - lamb_update_with_lr->set_scope(node->scope()); - return lamb_update_with_lr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h deleted file mode 100644 index cb3939549f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambUpdateWithLRRuleFusion : public PatternProcessPass { - public: - explicit LambUpdateWithLRRuleFusion(bool multigraph = true) - : PatternProcessPass("lamb_update_with_lr_rule_fusion", multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - input3_ = std::make_shared(); - input4_ = std::make_shared(); - input5_ = std::make_shared(); - constant_greater_max_ = std::make_shared(); - constant_select_ = std::make_shared(); - constant_minimum_ = std::make_shared(); - } - ~LambUpdateWithLRRuleFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr input3_; - VarPtr input4_; - VarPtr input5_; - VarPtr constant_greater_max_; - VarPtr constant_select_; - VarPtr constant_minimum_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_RULE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc deleted file mode 100644 index 43e1872163..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.cc +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
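Read as scalar arithmetic, the graph matched by LambUpdateWithLRRuleFusion is the LAMB trust-ratio update: a ratio guarded by two Greater/Select gates, clamped by Minimum/Maximum, then scaled into a subtraction. A plain, compilable C++ rendering of that math follows; this is my reading of the pattern, and all names are invented:

#include <algorithm>
#include <iostream>

// Scalar model of the matched subgraph; inputN follows the pattern's capture
// order (input0/input1 are norms, input3/input4 scale factors, input5 the
// value being updated).
float LambUpdateWithLr(float input0, float input1, float input2, float input3, float input4,
                       float input5, float greater_max, float select_val, float minimum_val) {
  // Select(Greater(input0, max), input1 / input2, fallback)
  float r = input0 > greater_max ? input1 / input2 : select_val;
  // Select(Greater(input1, max), r, fallback)
  r = input1 > greater_max ? r : select_val;
  r = std::min(r, minimum_val);  // Minimum(r, constant_minimum)
  r = std::max(r, greater_max);  // Maximum(r, constant_greater_max)
  // Sub(input5, Mul(Mul(r, input3), input4))
  return input5 - r * input3 * input4;
}

int main() {
  // Both norms positive -> ratio 2/4 survives the gates; prints 0.995.
  std::cout << LambUpdateWithLr(0.5f, 2.0f, 4.0f, 0.01f, 1.0f, 1.0f, 0.0f, 1.0f, 10.0f) << "\n";
  return 0;
}

The fused kLambUpdateWithLR kernel receives exactly these nine operands, which is why Process forwards input0..input8 in capture order.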
- */ - -#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h" -#include -#include -#include -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef LambUpdateWithLrV2::DefinePattern() const { - const auto prim_greater = std::make_shared(kGreaterOpName); - const auto prim_deal_div = std::make_shared(kRealDivOpName); - - VectorRef greater0({prim_greater, input_varptr_[0], input_varptr_[5]}); - VectorRef greater1({prim_greater, input_varptr_[1], input_varptr_[5]}); - VectorRef real_div0({prim_deal_div, input_varptr_[0], input_varptr_[1]}); - VectorRef select0({prim::kPrimSelect, greater1, real_div0, input_varptr_[6]}); - VectorRef select1({prim::kPrimSelect, greater0, select0, input_varptr_[6]}); - VectorRef mul0({prim::kPrimMul, select1, input_varptr_[2]}); - VectorRef mul1({prim::kPrimMul, mul0, input_varptr_[3]}); - - return VectorRef({prim::kPrimSub, input_varptr_[4], mul1}); -} - -const AnfNodePtr LambUpdateWithLrV2::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - if (!CheckSupportDataType(node, kFloatDataTypeSet)) { - return nullptr; - } - auto prim = std::make_shared(kLambUpdateWithLrV2OpName); - std::vector inputs = {NewValueNode(prim)}; - (void)std::transform(input_varptr_.begin(), input_varptr_.end(), std::back_inserter(inputs), - [&equiv](const VarPtr &in) { return utils::cast((*equiv)[in]); }); - auto lamb_update_with_lr_v2 = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(lamb_update_with_lr_v2); - lamb_update_with_lr_v2->set_abstract(node->abstract()); - - return lamb_update_with_lr_v2; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h deleted file mode 100644 index ea614d3d2d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ - -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class LambUpdateWithLrV2 : public PatternProcessPass { - public: - explicit LambUpdateWithLrV2(bool multigraph = true) : PatternProcessPass("lamb_update_with_lr_v2", multigraph) { - for (size_t i = 0; i < kLambUpdateWithLrV2InputNum - 1; ++i) { - input_varptr_.push_back(std::make_shared()); - } - } - ~LambUpdateWithLrV2() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - std::vector input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAMB_UPDATE_WITH_LR_V2_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc deleted file mode 100644 index b16387d8f1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc +++ /dev/null @@ -1,162 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h" -#include -#include -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -using common::SafeCStr; -namespace { -void GetOutputCastNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, std::vector *cast_nodes) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(node) == manager->node_users().end()) { - return; - } - for (const auto &node_index : manager->node_users()[node]) { - AnfNodePtr output = node_index.first; - auto output_cnode = output->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - if (AnfAlgo::GetCNodeName(output_cnode) != prim::kPrimTupleGetItem->name()) { - MS_LOG(EXCEPTION) << "The output of node " << node->DebugString() << " should be " - << prim::kPrimTupleGetItem->name(); - } - if (manager->node_users().find(output) == manager->node_users().end() || - manager->node_users()[output].size() != 1) { - continue; - } - AnfNodePtr transitive_output = manager->node_users()[output].begin()->first; - MS_EXCEPTION_IF_NULL(transitive_output); - auto transitive_output_cnode = transitive_output->cast(); - MS_EXCEPTION_IF_NULL(transitive_output_cnode); - if (AnfAlgo::GetCNodeName(transitive_output_cnode) == prim::kPrimCast->name()) { - cast_nodes->push_back(transitive_output_cnode); - } - } -} - -bool CheckKernelBuildInfo(const CNodePtr &cnode, const kernel::KernelBuildInfoPtr &kernel_info) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(kernel_info); - for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) { - if (kernel_info->GetInputDeviceType(i) != kNumberTypeFloat16 || - kernel_info->GetInputFormat(i) != AnfAlgo::GetInputFormat(cnode, i)) { - return false; - } - } - for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) { - if (kernel_info->GetOutputDeviceType(i) != kNumberTypeFloat32 || - kernel_info->GetOutputFormat(i) != AnfAlgo::GetOutputFormat(cnode, i)) { - return false; - } - } - return true; -} - -bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - std::vector *cast_nodes) { - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::HasNodeAttr(kAttrShapeGamma, cnode)) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; - return false; - } - if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " - << kLayerNormBetaGammaBackpropInputNum; - return false; - } - if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(INFO) << "The node " << cnode->DebugString() << " outputs num is not equal to " - << kLayerNormBetaGammaBackpropOutputNum; - return false; - } - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { - if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { - MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; - return false; - } - } - GetOutputCastNodes(func_graph, cnode, cast_nodes); - if (cast_nodes->size() != kLayerNormBetaGammaBackpropOutputNum) { - MS_LOG(INFO) << "The num of cast node in node " << cnode->DebugString() << " outputs is not equal to " - << kLayerNormBetaGammaBackpropOutputNum; - return false; - } - for (const auto &cast : *cast_nodes) { - if (AnfAlgo::GetInputDeviceDataType(cast, 0) != 
kNumberTypeFloat16 || - AnfAlgo::GetOutputDeviceDataType(cast, 0) != kNumberTypeFloat32) { - MS_LOG(INFO) << "The cast " << cast->DebugString() << " should be fp16->fp32"; - return false; - } - } - return true; -} -} // namespace - -const BaseRef LayerNormBetaGammaBackpropFusion::DefinePattern() const { - std::shared_ptr Xs = std::make_shared(); - const auto prim = std::make_shared(kLayerNormBetaGammaBackpropOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - std::vector cast_nodes; - if (!CheckLayernormBetaGammaBackprop(func_graph, cnode, &cast_nodes)) { - return nullptr; - } - std::vector> kernel_info_list; - MS_EXCEPTION_IF_NULL(kernel_query_); - kernel_query_->Query(cnode, &kernel_info_list); - auto alternative_kernel_build_info = - std::find_if(kernel_info_list.begin(), kernel_info_list.end(), - [&cnode](const kernel::KernelBuildInfoPtr &candidate_kernel_build_info) { - return CheckKernelBuildInfo(cnode, candidate_kernel_build_info); - }); - if (alternative_kernel_build_info == kernel_info_list.end()) { - MS_LOG(INFO) << "Can not find alternative kernel build info for node " << node->DebugString(); - return nullptr; - } - AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_build_info, cnode.get()); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // The cast_nodes size has been checked above. - MS_EXCEPTION_IF_NULL(cast_nodes[0]); - MS_EXCEPTION_IF_NULL(cast_nodes[1]); - if (cast_nodes[0]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum; - } - (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1)); - if (cast_nodes[1]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum; - } - (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1)); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h deleted file mode 100644 index 2655c0f14d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
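The fusion just deleted is really a kernel re-selection: it asks kernel_query_ for candidates, keeps the first whose inputs are all float16 and outputs all float32 (with formats matching the node), and then drops the now-redundant Cast nodes. The dtype side of that check reduces to a find_if over candidates; a standalone model follows (types invented, format checks omitted for brevity):

#include <algorithm>
#include <vector>

enum class DType { kFloat16, kFloat32 };

// Stand-in for kernel::KernelBuildInfo: just the device dtypes per port.
struct KernelCandidate {
  std::vector<DType> input_types;
  std::vector<DType> output_types;
};

// Mirrors CheckKernelBuildInfo's dtype conditions: every input fp16, every
// output fp32, i.e. a kernel that absorbs the trailing fp16->fp32 Casts.
bool CastsFoldedByKernel(const KernelCandidate &k) {
  return std::all_of(k.input_types.begin(), k.input_types.end(),
                     [](DType t) { return t == DType::kFloat16; }) &&
         std::all_of(k.output_types.begin(), k.output_types.end(),
                     [](DType t) { return t == DType::kFloat32; });
}

// Selection over a candidate list, as Process does with the Query results;
// returns nullptr when no candidate qualifies (the pass then bails out).
const KernelCandidate *SelectKernel(const std::vector<KernelCandidate> &candidates) {
  auto it = std::find_if(candidates.begin(), candidates.end(), CastsFoldedByKernel);
  return it == candidates.end() ? nullptr : &*it;
}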
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h
deleted file mode 100644
index 2655c0f14d..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h
+++ /dev/null
@@ -1,41 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_
-
-#include <memory>
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/helper.h"
-#include "pre_activate/ascend/ascend_helper.h"
-
-namespace mindspore {
-namespace opt {
-class LayerNormBetaGammaBackpropFusion : public PatternProcessPass {
- public:
-  explicit LayerNormBetaGammaBackpropFusion(bool multigraph = true)
-      : PatternProcessPass("layer_norm_beta_gamma_backprop_fusion", multigraph),
-        kernel_query_(std::make_shared<KernelQuery>()) {}
-
-  ~LayerNormBetaGammaBackpropFusion() override = default;
-  const BaseRef DefinePattern() const override;
-  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
-
- private:
-  KernelQueryPtr kernel_query_;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_LAYER_NORM_BETA_GAMMA_BACKPROP_FUSION_H_
- */ -#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" -#include -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kMatMulInputIndex = 1; -constexpr size_t kBiasInputIndex = 2; -} // namespace - -const BaseRef MatmulBiasaddFusion::DefinePattern() const { - VarPtr X0 = std::make_shared(); - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - const auto prim_bias_add = std::make_shared(kBiasAddOpName); - return VectorRef({prim_bias_add, VectorRef({prim::kPrimMatMul, X0, X1}), X2}); -} - -const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kBiasAddInputNum); - AnfNodePtr matmul = cnode->input(kMatMulInputIndex); - MS_EXCEPTION_IF_NULL(matmul); - auto matmul_cnode = matmul->cast(); - MS_EXCEPTION_IF_NULL(matmul_cnode); - matmul_cnode->add_input(cnode->input(kBiasInputIndex)); - AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul); - return matmul; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h deleted file mode 100644 index 56675243de..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MatmulBiasaddFusion : public PatternProcessPass { - public: - explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {} - - ~MatmulBiasaddFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc deleted file mode 100644 index e7a73a9c7f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
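MatmulBiasaddFusion above is the simplest rewrite in this family: BiasAdd(MatMul(x, w), b) becomes a single MatMul carrying the bias. The core of Process, reconstructed with its template arguments spelled out:

auto matmul_cnode = matmul->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(matmul_cnode);
matmul_cnode->add_input(cnode->input(kBiasInputIndex));       // bias becomes MatMul's third input
AnfAlgo::SetNodeAttr(kAttrHasBias, MakeValue(true), matmul);  // kernel selection sees a bias-fused MatMul
return matmul;                                                // the BiasAdd node drops out of the graph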
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h" -#include <vector> -#include <memory> -#include <string> -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr size_t kAccumIndex = 1; -bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa<ValueNode>()) { - return false; - } - std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); - return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); -} -} // namespace - -const BaseRef MomentumLossscaleFusion::DefinePattern() const { - VarPtr Xs = std::make_shared<SeqVar>(); - VarPtr X0 = std::make_shared<Var>(); - VarPtr X1 = std::make_shared<Var>(); - VarPtr X2 = std::make_shared<Var>(); - VarPtr X4 = std::make_shared<Var>(); - return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4}); -} - -const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kApplyMomentumInputNum); - AnfNodePtr mul = cnode->input(4); - MS_EXCEPTION_IF_NULL(mul); - auto mul_cnode = mul->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(mul_cnode); - CheckCNodeInputSize(mul_cnode, kMulInputNum); - size_t value_node_index = 0; - for (size_t i = 1; i < kMulInputNum; ++i) { - if (CheckValueNodeInputOfMul(mul_cnode->input(i))) { - value_node_index = i; - break; - } - } - if (value_node_index == 0) { - MS_LOG(DEBUG) << "The Mul " << mul->DebugString() << " to be fused must have a scalar constant input"; - return nullptr; - } - auto new_prim = std::make_shared<Primitive>(kFusedMulApplyMomentumOpName); - std::vector<AnfNodePtr> new_node_inputs{NewValueNode(new_prim), - cnode->input(1), - cnode->input(2), - cnode->input(3), - mul_cnode->input(kMulInputNum - value_node_index), - cnode->input(5), - mul_cnode->input(value_node_index)}; - auto new_node = func_graph->NewCNode(new_node_inputs); - MS_EXCEPTION_IF_NULL(new_node); - AnfAlgo::CopyNodeAttrs(node, new_node); - auto input_names_value = AnfAlgo::GetNodeAttr<std::vector<std::string>>(new_node, kAttrInputNames); - input_names_value[3] = "x1"; - input_names_value.emplace_back("x2"); - AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node); - new_node->set_abstract(node->abstract()); - new_node->set_scope(node->scope()); - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h deleted file mode 100644 index c092e0ca22..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
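The index arithmetic in MomentumLossscaleFusion is worth spelling out. With kMulInputNum == 3 (primitive plus two operands), once the scalar loss-scale constant is found at value_node_index, the expression kMulInputNum - value_node_index selects the other, tensor-valued Mul operand. A sketch of the resulting operand layout (the input names are the ones patched into kAttrInputNames above):

// FusedMulApplyMomentum(var, accum, lr, x1, grad, x2)
//   x1 = Mul's tensor operand    -> mul_cnode->input(kMulInputNum - value_node_index)
//   x2 = Mul's scalar loss-scale -> mul_cnode->input(value_node_index), appended last,
// which is why slot 3 of kAttrInputNames is renamed to "x1" and "x2" is appended.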
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MomentumLossscaleFusion : public PatternProcessPass { - public: - explicit MomentumLossscaleFusion(bool multigraph = true) - : PatternProcessPass("momentum_lossscale_fusion", multigraph) {} - - ~MomentumLossscaleFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MOMENTUM_LOSSSCALE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc deleted file mode 100644 index 2536255fc1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.cc +++ /dev/null @@ -1,99 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(add); - - for (size_t index = 1; index < add->size(); ++index) { - auto input = add->input(index); - MS_EXCEPTION_IF_NULL(input); - if (input->isa()) { - auto cnode = input->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) { - if (!opt::IsUsedByOthers(graph, cnode)) { - auto full_name = cnode->fullname_with_scope(); - // exclude lamb and adam, and only work in bert - if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") || - std::string::npos == full_name.find("bert")) { - MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion"; - return false; - } - - *mul = cnode; - *mul_index = index; - return true; - } - } - } - } - return false; -} -} // namespace -const BaseRef MulAddFusion::DefinePattern() const { - VarPtr x = std::make_shared(); - VarPtr y = std::make_shared(); - VectorRef pattern({prim::kPrimTensorAdd, x, y}); - return pattern; -} - -const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - if (graph == nullptr || node == nullptr) { - return nullptr; - } - auto add = node->cast(); - if (add == nullptr || add->inputs().size() != kAddInputNum) { - return nullptr; - } - CNodePtr mul = nullptr; - size_t mul_index = 0; - if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) { - MS_LOG(DEBUG) << "Cannot find used-by-only-one-op Mul in Add's inputs"; - return nullptr; - } - - auto prim = std::make_shared(kFusedMulAddOpName); - std::vector inputs = {NewValueNode(prim)}; - for (size_t index = 1; index < mul->size(); ++index) { - inputs.push_back(mul->input(index)); - } - auto another_input_node = add->input(add->size() - mul_index); - if (another_input_node->isa() && - AnfAlgo::GetCNodeName(another_input_node) == prim::kPrimTupleGetItem->name()) { - MS_LOG(INFO) << "Add's another input node has multiple outputs, do not fuse"; - return nullptr; - } - inputs.push_back(another_input_node); - auto fusion_node = graph->NewCNode(inputs); - fusion_node->set_scope(add->scope()); - fusion_node->set_abstract(add->abstract()); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h deleted file mode 100644 index 4b4db2b312..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_add_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
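MulAddFusion is deliberately narrow: GetMul only accepts a Mul that feeds nothing but this TensorAdd, and only inside a bert scope (adam/lamb scopes are skipped). The fused-node construction, reconstructed with its container types:

auto prim = std::make_shared<Primitive>(kFusedMulAddOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
for (size_t index = 1; index < mul->size(); ++index) {
  inputs.push_back(mul->input(index));                   // Mul's two operands come first
}
inputs.push_back(add->input(add->size() - mul_index));   // then Add's remaining operand

Since add has three inputs (primitive plus two operands), add->size() - mul_index flips index 1 to 2 and 2 to 1, selecting whichever Add operand is not the matched Mul.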
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MulAddFusion : public PatternProcessPass { - public: - explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {} - ~MulAddFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_MUL_ADD_FUSION_H diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc deleted file mode 100644 index a5e4675c8f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" -#include -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "optimizer/opt.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const CNodePtr &addn, - const size_t &lossscale_input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(mul); - MS_EXCEPTION_IF_NULL(addn); - auto prim = std::make_shared(kFusedMulAddNOpName); - std::vector inputs = {NewValueNode(prim)}; - inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); - inputs.push_back(addn->input(2)); - // scalar input should be 3rd input - inputs.push_back(mul->input(lossscale_input_index)); - auto fusion_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(addn->scope()); - fusion_node->set_abstract(addn->abstract()); - return fusion_node; -} -} // namespace - -const BaseRef MulAddNFusion::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - VarPtr Z = std::make_shared(); - - VectorRef mul({prim::kPrimMul, X, Z}); - VectorRef addn({prim::kPrimAddN, mul, Y}); - return addn; -} - -const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - if (graph == nullptr || node == nullptr || equiv == nullptr) { - return nullptr; - } - - auto addn = node->cast(); - if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { - return nullptr; - } - auto mul_anf = addn->input(1); - if (mul_anf == nullptr) { - return nullptr; - } - auto mul = mul_anf->cast(); - if (mul == nullptr || mul->inputs().size() != kMulInputNum) { - return nullptr; - } - 
if (IsUsedByOthers(graph, mul)) { - MS_LOG(DEBUG) << "Mul is used by more than one node, cannot fuse"; - return nullptr; - } - - size_t lossscale_input_index = 1; - for (size_t index = 1; index < mul->inputs().size(); ++index) { - auto input_node = mul->input(index); - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa<ValueNode>()) { - lossscale_input_index = index; - break; - } - } - auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0); - if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) { - MS_LOG(DEBUG) << "The const input of Mul node must be a scalar or have shape (1,), but its shape size is " - << constant_shape.size() << " and shape[0] is " << constant_shape[0]; - return nullptr; - } - - return CreateFusionNode(graph, mul, addn, lossscale_input_index); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h deleted file mode 100644 index d03309bf73..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class MulAddNFusion : public PatternProcessPass { - public: - explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {} - ~MulAddNFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc deleted file mode 100644 index a3c87dad5d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
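The rewrite performed by MulAddNFusion is AddN(Mul(x, s), y) -> FusedMulAddN(x, y, s), where s is a loss-scale constant of shape () or (1,). CreateFusionNode's operand ordering, annotated (a sketch; the scalar is required in the third slot by the fused TBE kernel):

std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));  // x: the tensor operand of Mul
inputs.push_back(addn->input(2));                                    // y: AddN's other summand
inputs.push_back(mul->input(lossscale_input_index));                 // s: scalar, third by convention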
- */ - -#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -namespace { -const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, - std::vector *trans_road) { - if (node == nullptr) { - MS_LOG(ERROR) << "nullptr"; - return nullptr; - } - if (node->isa()) { - auto cnode = node->cast(); - auto op_name = AnfAlgo::GetCNodeName(cnode); - auto manager = func_graph->manager(); - if (manager == nullptr) { - return nullptr; - } - if (op_name == prim::kPrimCast->name() || op_name == prim::kPrimTranspose->name() || - op_name == prim::kPrimReshape->name() || op_name == kTransDataOpName) { - auto users = manager->node_users()[node]; - if (users.size() > 1 && !first_flag) { - return nullptr; - } - trans_road->push_back(cnode); - first_flag = false; - auto next_node = AnfAlgo::GetInputNode(cnode, 0); - if (next_node->isa() || next_node->isa()) { - return next_node; - } - return ParamTransRoad(func_graph, next_node, first_flag, trans_road); - } - } else if (node->isa() || node->isa()) { - return node; - } - return nullptr; -} - -kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, - TypeId output_type) { - MS_EXCEPTION_IF_NULL(cast); - auto kernel_info = cast->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto cast_build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(cast_build_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat({format}); - builder.SetInputsFormat({format}); - builder.SetInputsDeviceType({input_type}); - builder.SetOutputsDeviceType({output_type}); - builder.SetKernelType(cast_build_info->kernel_type()); - builder.SetFusionType(cast_build_info->fusion_type()); - builder.SetProcessor(cast_build_info->processor()); - return builder.Build(); -} -} // namespace -bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Func graph is nullptr"; - return false; - } - auto manager = func_graph->manager(); - if (manager == nullptr) { - return false; - } - std::vector node_list = TopoSort(func_graph->get_return()); - bool changed = false; - for (auto node : node_list) { - if (node == nullptr || !node->isa()) { - continue; - } - auto cnode = node->cast(); - auto node_name = AnfAlgo::GetCNodeName(cnode); - if (node_name == prim::kPrimCast->name() || node_name == prim::kPrimTranspose->name() || - node_name == prim::kPrimReshape->name() || node_name == kTransDataOpName) { - MS_LOG(DEBUG) << "Skip trans op"; - continue; - } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { - std::vector trans_road; - bool first_flag = true; - auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road); - if (final_node != nullptr && trans_road.size() == 3 && AnfAlgo::GetCNodeName(trans_road[0]) == kTransDataOpName && - AnfAlgo::GetCNodeName(trans_road[1]) == prim::kPrimCast->name() && - AnfAlgo::GetCNodeName(trans_road[2]) == kTransDataOpName) { - auto cur_transop = trans_road[0]; - auto format = AnfAlgo::GetOutputFormat(cur_transop, 0); - auto dtype = 
AnfAlgo::GetOutputDeviceDataType(cur_transop, 0); - auto param_format = AnfAlgo::GetOutputFormat(final_node, 0); - auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); - - auto cast = trans_road[1]; - if (param_format == format && param_dtype != dtype) { - AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); - manager->Replace(trans_road[2], final_node); - manager->Replace(cur_transop, cast); - } - changed = true; - } - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h deleted file mode 100644 index 823ec083b1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_PARAMETER_AND_TRANSOP_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class ParameterTransOpFusion : public Pass { - public: - explicit ParameterTransOpFusion(size_t groups = 1) : Pass("Parameter_and_transop_fusion"), groups_(groups) {} - ~ParameterTransOpFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - size_t groups_ = 1; -}; -} // namespace opt -} // namespace mindspore - -#endif diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc deleted file mode 100644 index 857670a384..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
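ParamTransRoad walks upward from a consumer through Cast/Transpose/Reshape/TransData nodes until it reaches a Parameter or ValueNode; the pass then acts only on the exact three-step road TransData -> Cast -> TransData. When the parameter is already stored in the consumer-side format, the two layout conversions cancel and only the dtype change survives (annotated excerpt from the body above):

// param(format F) -> TransData -> Cast -> TransData -> consumer(format F)
// collapses to: param(format F) -> Cast(F, param_dtype -> dtype) -> consumer
if (param_format == format && param_dtype != dtype) {
  AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
  manager->Replace(trans_road[2], final_node);  // drop the TransData beside the parameter
  manager->Replace(cur_transop, cast);          // drop the TransData beside the consumer
}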
- */ - -#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -void DoRefresh(const CNodePtr &cnode) { - if (cnode == nullptr) { - MS_LOG(EXCEPTION) << "node is nullptr"; - } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { - auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index); - if (input_kernel_node->isa()) { - std::shared_ptr builder = - std::make_shared(); - auto cnode_input_format = AnfAlgo::GetInputFormat(cnode, input_index); - auto kernel_node_format = AnfAlgo::GetOutputFormat(input_kernel_node, 0); - auto dtype = AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0); - if (kernel_node_format != cnode_input_format) { - builder->SetOutputsFormat({cnode_input_format}); - builder->SetOutputsDeviceType({dtype}); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get()); - } - } - } -} - -bool RefreshParameterFormat::Run(const FuncGraphPtr &func_graph) { - if (func_graph == nullptr) { - MS_LOG(ERROR) << "func_graph is nullptr."; - return false; - } - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto node : node_list) { - if (node == nullptr || !node->isa()) { - continue; - } - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - auto node_name = AnfAlgo::GetCNodeName(cnode); - if (node_name == kBNTrainingUpdateOpName) { - DoRefresh(cnode); - } - } - return true; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h deleted file mode 100644 index 0ba688b134..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
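DoRefresh exists because a Parameter feeding BNTrainingUpdate may still advertise a default output format; rebuilding its kernel info to match the consumer's input format avoids a TransData being inserted later. The builder type is elided in the hunk above; it is presumably the KernelBuildInfoBuilder used elsewhere in this diff (reconstruction):

auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (kernel_node_format != cnode_input_format) {
  builder->SetOutputsFormat({cnode_input_format});  // adopt the consumer's expected input format
  builder->SetOutputsDeviceType({dtype});           // keep the parameter's device dtype unchanged
  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
}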
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ - -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -class RefreshParameterFormat : public Pass { - public: - explicit RefreshParameterFormat(size_t groups = 1) : Pass("refresh_parameter_format"), groups_(groups) {} - ~RefreshParameterFormat() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - size_t groups_ = 1; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc deleted file mode 100644 index fa2815ff62..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" -#include <memory> -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -const BaseRef RemoveReshapePair::DefinePattern() const { - VarPtr X = std::make_shared<Var>(); - MS_EXCEPTION_IF_NULL(X); - return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})}); -} - -const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_1); - // If the reshape operator is used by more than one other operator, it cannot be removed directly - if (IsUsedByOthers(func_graph, reshape_op_1)) { - return nullptr; - } - auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_op_2); - if (IsUsedByOthers(func_graph, reshape_op_2)) { - return nullptr; - } - auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0); - auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0); - if (input_shape == output_shape) { - auto input_node = reshape_op_2->input(1); - return input_node; - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h deleted file mode 100644 index ddb25df70c..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the
"License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ - -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class RemoveReshapePair : public PatternProcessPass { - public: - explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {} - ~RemoveReshapePair() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc deleted file mode 100644 index 9b13002798..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckShapeDimInfo(const std::vector &shape) { - if (shape.empty()) { - return false; - } - if (shape.size() == 1 && shape[0] % kCubeSize != 0) { - return false; - } - return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); -} -} // namespace - -const BaseRef ReshapeTransposeFusion::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef reshape({prim_reshape, input_varptr_}); - - return VectorRef({prim::kPrimTranspose, reshape}); -} - -const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(transpose_cnode); - auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_cnode); - std::vector reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); - std::vector transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0); - if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) { - return nullptr; - } - auto prim = std::make_shared(kConfusionTransposeDOpName); - std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; - auto new_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - new_node->set_abstract(node->abstract()); - - AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); - AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); - AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(false), new_node); - auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); - - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h deleted file mode 100644 index 5abf3e0d53..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
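Both Reshape+Transpose fusions in this diff gate on CheckShapeDimInfo: the trailing one or two dimensions must divide evenly by kCubeSize, since ConfusionTransposeD works on cube-aligned fractal blocks. A standalone rendering of the predicate (kCubeSize is taken to be 16, the usual Ascend cube edge; treat that value as an assumption here):

#include <cstddef>
#include <vector>

constexpr std::size_t kCubeSize = 16;  // assumed Ascend cube edge length

// Mirror of CheckShapeDimInfo above: the trailing one or two dimensions
// must be multiples of kCubeSize for the fused ConfusionTransposeD to be legal.
bool CheckShapeDimInfo(const std::vector<std::size_t> &shape) {
  if (shape.empty()) return false;
  if (shape.size() == 1 && shape[0] % kCubeSize != 0) return false;
  return !(shape.size() >= 2 &&
           (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0));
}
// e.g. {32, 64} -> true, {32, 60} -> false, {16} -> true, {8} -> false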
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ReshapeTransposeFusion : public PatternProcessPass { - public: - explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) { - input_varptr_ = std::make_shared(); - } - ~ReshapeTransposeFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc deleted file mode 100644 index f95406e5e1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.cc +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" -#include <memory> -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef SoftmaxGradExtFusion::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input2_, input1_}); - VectorRef mul_grad({prim::kPrimMul, mul1, sub}); - return mul_grad; -} - -const BaseRef SoftmaxGradExtFusionV2::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input1_, sub}); - VectorRef mul_grad({prim::kPrimMul, input2_, mul1}); - return mul_grad; -} - -const BaseRef SoftmaxGradExtFusionV3::DefinePattern() const { - VectorRef mul({prim::kPrimMul, input1_, input0_}); - VectorRef sum({sum_var_, mul}); - VectorRef sub({prim::kPrimSub, input0_, sum}); - VectorRef mul1({prim::kPrimMul, input1_, sub}); - VectorRef mul_grad({prim::kPrimMul, mul1, input2_}); - return mul_grad; -} - -const AnfNodePtr SoftmaxGradExtFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(node); - auto input0 = GetAnfNodeByVar(equiv, input0_); - auto input1 = GetAnfNodeByVar(equiv, input1_); - auto input2 = GetAnfNodeByVar(equiv, input2_); - auto sum = GetAnfNodeByVar(equiv, sum_var_); - if (!GetBoolAttr(sum, kAttrKeepDims)) { - MS_LOG(INFO) << "sum's attr keep_dims must be true for this fusion"; - return nullptr; - } - - auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName); - auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2}); - MS_EXCEPTION_IF_NULL(fusion_node); - fusion_node->set_scope(node->scope()); - fusion_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node); - return fusion_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h deleted file mode 100644 index 59032e6973..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
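The three DefinePattern variants match the same dataflow and differ only in the operand order of the final two Muls; all collapse to a single SoftmaxGradExt node over (input0, input1, input2). keep_dims must be true on the ReduceSum so the reduced value broadcasts back against input0 inside Sub, and the reduction attributes travel with the fused node:

auto prim = std::make_shared<Primitive>(kSoftmaxGradExtOpName);
auto fusion_node = graph->NewCNode({NewValueNode(prim), input0, input1, input2});
AnfAlgo::CopyNodeAttr(kAttrKeepDims, "keepdims", sum, fusion_node);  // attr is renamed on the fused kernel
AnfAlgo::CopyNodeAttr(kAttrAxis, sum, fusion_node);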
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ - -#include -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SoftmaxGradExtFusion : public PatternProcessPass { - public: - explicit SoftmaxGradExtFusion(const std::string &name = "softmax_grad_ext_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph) { - input0_ = std::make_shared(); - input1_ = std::make_shared(); - input2_ = std::make_shared(); - sum_var_ = std::make_shared(std::make_shared(prim::kPrimReduceSum->name())); - } - ~SoftmaxGradExtFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - protected: - VarPtr input0_; - VarPtr input1_; - VarPtr input2_; - VarPtr sum_var_; -}; - -class SoftmaxGradExtFusionV2 : public SoftmaxGradExtFusion { - public: - explicit SoftmaxGradExtFusionV2(bool multigraph = true) - : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v2", multigraph) {} - ~SoftmaxGradExtFusionV2() override = default; - const BaseRef DefinePattern() const override; -}; - -class SoftmaxGradExtFusionV3 : public SoftmaxGradExtFusion { - public: - explicit SoftmaxGradExtFusionV3(bool multigraph = true) - : SoftmaxGradExtFusion("softmax_grad_ext_fusion_v3", multigraph) {} - ~SoftmaxGradExtFusionV3() override = default; - const BaseRef DefinePattern() const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SOFTMAX_GRAD_EXT_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc deleted file mode 100644 index 8c0335ecc1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.cc +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" - -#include -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "operator/ops.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(square); - MS_EXCEPTION_IF_NULL(sum); - if (square->inputs().size() != kSquareNodeInputNum) { - MS_LOG(EXCEPTION) << "Square node has wrong input size"; - } - auto prim = std::make_shared(kSquareSumV1OpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector square_sumv1_inputs = {NewValueNode(prim), square->input(1)}; - auto square_sumv1 = graph->NewCNode(square_sumv1_inputs); - MS_EXCEPTION_IF_NULL(square_sumv1); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - square_sumv1->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(sum, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv1.get()); - square_sumv1->set_scope(sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv1); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv1); - auto names = MakeValue>({square->fullname_with_scope(), sum->fullname_with_scope()}); - AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv1); - return square_sumv1; -} - -CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, const CNodePtr &sum) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(square); - MS_EXCEPTION_IF_NULL(sum); - if (square->inputs().size() != kSquareNodeInputNum) { - MS_LOG(EXCEPTION) << "Square node has wrong input size"; - } - auto prim = std::make_shared(kSquareSumV2OpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector square_sumv2_inputs = {NewValueNode(prim), square->input(1)}; - auto square_sumv2 = graph->NewCNode(square_sumv2_inputs); - MS_EXCEPTION_IF_NULL(square_sumv2); - auto types = {AnfAlgo::GetOutputInferDataType(sum, 0), AnfAlgo::GetOutputInferDataType(square, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(sum, 0), AnfAlgo::GetOutputInferShape(square, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, square_sumv2.get()); - square_sumv2->set_scope(sum->scope()); - AnfAlgo::CopyNodeAttr(kAttrAxis, sum, square_sumv2); - AnfAlgo::CopyNodeAttr(kAttrKeepDims, sum, square_sumv2); - auto names = MakeValue>({square->fullname_with_scope(), sum->fullname_with_scope()}); - AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, names, square_sumv2); - return square_sumv2; -} - -std::tuple GetPrevNodes(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto sum = node->cast(); - MS_EXCEPTION_IF_NULL(sum); - if (sum->inputs().size() != kSumNodeInputNum) { - MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size"; - } - auto square_anf = sum->input(1); - MS_EXCEPTION_IF_NULL(square_anf); - auto square = square_anf->cast(); - MS_EXCEPTION_IF_NULL(square); - - return std::make_tuple(sum, square_anf, square); -} -} // namespace - -const BaseRef SquareSumFusion::DefinePattern() const { - VarPtr X = std::make_shared(); - MS_EXCEPTION_IF_NULL(X); - return VectorRef({prim::kPrimReduceSum, VectorRef({prim::kPrimSquare, X})}); -} - -const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - 
MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - CNodePtr sum = nullptr; - AnfNodePtr square_anf = nullptr; - CNodePtr square = nullptr; - std::tie(sum, square_anf, square) = GetPrevNodes(node); - - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (manager->node_users().find(square_anf) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "Square node has no output in NodeUsersMap"; - } - AnfNodePtr ret_node = nullptr; - if (manager->node_users()[square_anf].size() == 1) { - ret_node = GenerateSquareSumV1(graph, square, sum); - } else if (manager->node_users()[square_anf].size() == 2) { - auto square_sumv2 = GenerateSquareSumV2(graph, square, sum); - - std::vector<AnfNodePtr> square_sumv2_outputs; - CreateMultipleOutputsOfAnfNode(graph, square_sumv2, kSquareSumv2OutputNum, &square_sumv2_outputs); - if (square_sumv2_outputs.size() != kSquareSumv2OutputNum) { - MS_LOG(EXCEPTION) << "Failed to create SquareSumV2 outputs"; - } - (void)manager->Replace(square, square_sumv2_outputs[1]); - ret_node = square_sumv2_outputs[0]; - } - return ret_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h deleted file mode 100644 index 5a694a5585..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/square_sum_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class SquareSumFusion : public PatternProcessPass { - public: - explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {} - ~SquareSumFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc deleted file mode 100644 index 250f86d9b1..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License.
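SquareSumFusion dispatches on the Square's fan-out: with a single user it emits SquareSumV1 (sum only); with exactly two users it emits SquareSumV2, whose second output re-materializes the squared tensor for the other consumer. Sketch of the two-output rewiring using the helpers shown above:

// users(Square) == 1: ReduceSum(Square(x)) -> SquareSumV1(x)
// users(Square) == 2: SquareSumV2(x) yields {sum, square}
std::vector<AnfNodePtr> outputs;
CreateMultipleOutputsOfAnfNode(graph, square_sumv2, kSquareSumv2OutputNum, &outputs);
(void)manager->Replace(square, outputs[1]);  // output 1: the square, rewired to the second consumer
ret_node = outputs[0];                       // output 0: the reduced sum, replaces the ReduceSum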
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckShapeDimInfo(const std::vector &shape) { - if (shape.empty()) { - return false; - } - if (shape.size() == 1 && shape[0] % kCubeSize != 0) { - return false; - } - return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); -} -} // namespace - -const BaseRef TransposeReshapeFusion::DefinePattern() const { - const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); - VectorRef transpose({prim::kPrimTranspose, input_varptr_}); - - return VectorRef({prim_reshape, transpose}); -} - -const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(equiv); - auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(reshape_cnode); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); - MS_EXCEPTION_IF_NULL(transpose_cnode); - std::vector reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - std::vector transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); - if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { - return nullptr; - } - auto prim = std::make_shared(kConfusionTransposeDOpName); - std::vector inputs = {NewValueNode(prim), utils::cast((*equiv)[input_varptr_])}; - auto new_node = func_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(new_node); - - new_node->set_abstract(node->abstract()); - AnfAlgo::CopyNodeAttrs(reshape_cnode, new_node); - AnfAlgo::CopyNodeAttr(kAttrPerm, transpose_cnode, new_node); - AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node); - auto reshape_output_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(Convert2Int(reshape_output_shape)), new_node); - - return new_node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h deleted file mode 100644 index 8b979f869d..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
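TransposeReshapeFusion is the mirror image of ReshapeTransposeFusion earlier in this diff: same ConfusionTransposeD target, same cube-size guard, but with the operator order reversed. The only semantic difference in the emitted node is the transpose_first flag (excerpt):

AnfAlgo::SetNodeAttr(kAttrTransposeFirst, MakeValue(true), new_node);  // Transpose ran before Reshape
// versus MakeValue(false) in ReshapeTransposeFusion, where Reshape ran first.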
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class TransposeReshapeFusion : public PatternProcessPass { - public: - explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) { - input_varptr_ = std::make_shared(); - } - ~TransposeReshapeFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_RESHAPE_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc deleted file mode 100644 index e45fc2637f..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
deleted file mode 100644
index e45fc2637f..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
-#include <memory>
-#include "session/anf_runtime_algorithm.h"
-#include "utils/utils.h"
-#include "operator/ops.h"
-
-namespace mindspore {
-namespace opt {
-const BaseRef TransposeTransDataFusion::DefinePattern() const {
-  const auto prim_transdata = std::make_shared<Primitive>(prim::KPrimTransData->name());
-  VectorRef transpose({prim::kPrimTranspose, input_varptr_});
-
-  return VectorRef({prim_transdata, transpose});
-}
-
-const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                                   const EquivPtr &equiv) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(equiv);
-  auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum);
-  MS_EXCEPTION_IF_NULL(transdata_cnode);
-  auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum);
-  MS_EXCEPTION_IF_NULL(transpose_cnode);
-  auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode);
-  auto transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode);
-  MS_EXCEPTION_IF_NULL(transpose_kernel_build_info);
-  MS_EXCEPTION_IF_NULL(transdata_kernel_build_info);
-
-  auto new_transdata_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-  auto transpose_input_formats = transpose_kernel_build_info->GetAllInputFormats();
-  new_transdata_builder->SetInputsFormat(transpose_input_formats);
-  new_transdata_builder->SetOutputsFormat(transdata_kernel_build_info->GetAllOutputFormats());
-  new_transdata_builder->SetInputsDeviceType(transdata_kernel_build_info->GetAllInputDeviceTypes());
-  new_transdata_builder->SetOutputsDeviceType(transdata_kernel_build_info->GetAllOutputDeviceTypes());
-  new_transdata_builder->SetKernelType(transdata_kernel_build_info->kernel_type());
-  new_transdata_builder->SetFusionType(transdata_kernel_build_info->fusion_type());
-  new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor());
-
-  auto new_fusion_transdata = std::make_shared<Primitive>(kTransDataOpName);
-  if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) {
-    std::vector<AnfNodePtr> inputs = {NewValueNode(new_fusion_transdata),
-                                      utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
-    auto new_node = func_graph->NewCNode(inputs);
-    MS_EXCEPTION_IF_NULL(new_node);
-    new_node->set_abstract(node->abstract());
-    AnfAlgo::CopyNodeAttrs(transdata_cnode, new_node);
-    AnfAlgo::SetNodeAttr(kAttrSrcFormat, MakeValue(transpose_input_formats[0]), new_node);
-    AnfAlgo::SetSelectKernelBuildInfo(new_transdata_builder->Build(), new_node.get());
-    MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " success";
-    return new_node;
-  } else {
-    MS_LOG(INFO) << "transpose transdata fusion node:" << node->fullname_with_scope() << " failed";
-    return node;
-  }
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h
deleted file mode 100644
index 833588cf45..0000000000
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ascend_helper.h" - -namespace mindspore { -namespace opt { -class TransposeTransDataFusion : public PatternProcessPass { - public: - explicit TransposeTransDataFusion(bool multigraph = true) - : PatternProcessPass("transpose_transdata_fusion", multigraph) { - input_varptr_ = std::make_shared(); - supported_checker_ = std::make_shared(); - } - ~TransposeTransDataFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr input_varptr_; - - private: - SupportedCheckerPtr supported_checker_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_TRANSPOSE_TRANSDATA_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc deleted file mode 100644 index b930ac69c9..0000000000 --- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.cc +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/common/common_backend_optimization.h" -#include -#include -#include "pre_activate/common/optimizer.h" -#include "pre_activate/pass/convert_const_input_to_attr.h" -#include "pre_activate/pass/convert_tuple_output_to_maketuple.h" -#include "pre_activate/pass/convert_const_input_to_tensor_input.h" -#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" -#include "utils/context/ms_context.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -void BackendCommonOptimization(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "start common opt graph:" << kernel_graph->graph_id(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_common_before_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } - auto optimizer = std::make_shared(); - auto common_pm = std::make_shared("common_pm"); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - common_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(common_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/hwopt_common_after_graph_" + std::to_string(kernel_graph->graph_id()) + ".ir"; - DumpIR(file_path, kernel_graph); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h b/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h deleted file mode 100644 index 6ce92da0dc..0000000000 --- a/mindspore/ccsrc/pre_activate/common/common_backend_optimization.h +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
-#include <memory>
-#include "session/kernel_graph.h"
-namespace mindspore {
-namespace opt {
-void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
-}  // namespace opt
-}  // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_COMMON_BACKEND_OPTIMIZATION_H_
diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc
deleted file mode 100644
index 2b45fc6579..0000000000
--- a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "pre_activate/common/fusion_id_allocator.h"
-#include "session/anf_runtime_algorithm.h"
-
-namespace mindspore {
-namespace opt {
-FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; }
-
-FusionIdAllocator::~FusionIdAllocator() {}
-
-void FusionIdAllocator::Init() { fusion_id = 0; }
-
-int32_t FusionIdAllocator::AllocateFusionId() {
-  fusion_id++;
-  return fusion_id;
-}
-
-bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode);
-}
-
-int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
-  if (HasFusionIdAttr(node)) {
-    return AnfAlgo::GetNodeAttr<int32_t>(node, kAttrFusionId);
-  }
-  return -1;
-}
-
-void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) {
-  ValuePtr fusion_id_v = MakeValue(id);
-  AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
-}
-}  // namespace opt
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc
deleted file mode 100644
index e1db0ed6ed..0000000000
--- a/mindspore/ccsrc/pre_activate/common/helper.cc
+++ /dev/null
@@ -1,785 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#include "pre_activate/common/helper.h" -#include -#include -#include -#include -#include -#include -#include -#include "utils/utils.h" -#include "utils/base_ref.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" -#include "common/utils.h" -#include "device/kernel_info.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace opt { -constexpr size_t kType32Len = 4; -std::vector Convert2Int(const std::vector &v) { - std::vector result; - (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt); - return result; -} - -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node1); - MS_EXCEPTION_IF_NULL(node2); - std::vector node_list = TopoSort(graph->get_return()); - std::map> control_depend_map; - for (auto &nd : node_list) { - MS_EXCEPTION_IF_NULL(nd); - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { - auto control_depend = nd->cast(); - auto prior_node = control_depend->input(kControlDependPriorIndex); - auto behind_node = control_depend->input(kControlDependBehindIndex); - auto it = control_depend_map.find(behind_node); - if (it == control_depend_map.end()) { - control_depend_map[behind_node] = std::set{prior_node}; - } else { - it->second.insert(prior_node); - } - } - } - - FuncGraphManagerPtr manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - std::unordered_set seen_node; - std::deque todo{node1}; - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { - continue; - } - (void)seen_node.insert(node); - - if (node == node2) { - return true; - } - if (node->isa()) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); - } - auto it = control_depend_map.find(node); - if (it != control_depend_map.end()) { - (void)todo.insert(todo.end(), it->second.begin(), it->second.end()); - } - } - return false; -} - -bool UnVisited(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - if (IsValueNode(in)) { - auto value_node = in->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto prim_py = value->cast(); - MS_EXCEPTION_IF_NULL(prim_py); - return !prim_py->HasAttr(kAttrVisited); - } else if (IsValueNode(in)) { - auto func_graph = GetValueNode(in); - MS_EXCEPTION_IF_NULL(func_graph); - return !func_graph->has_flag(kAttrVisited); - } - return false; - } - return false; -} - -bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(ERROR) << "The node is expected to be a cnode"; - return false; - } - *cnode = node->cast(); - if (*cnode == nullptr) { - return false; - } - if ((*cnode)->inputs().size() < IntToSize(input_size)) { - auto op_name = AnfAlgo::GetCNodeName(*cnode); - MS_LOG(ERROR) << "op[" + op_name + "] has less than " << input_size << " inputs."; - return false; - } - return true; -} - -CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "The node is expected to be a cnode"; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != 
IntToSize(input_size)) { - auto op_name = AnfAlgo::GetCNodeName(cnode); - MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; - } - return cnode; -} - -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != input_size) { - MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; - } -} - -bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y) { - MS_EXCEPTION_IF_NULL(node_x); - MS_EXCEPTION_IF_NULL(node_y); - return (AnfAlgo::GetInputDeviceDataType(node_x, 0) == AnfAlgo::GetOutputDeviceDataType(node_y, 0) && - AnfAlgo::GetOutputDeviceDataType(node_x, 0) == AnfAlgo::GetInputDeviceDataType(node_y, 0)); -} - -const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - - auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); - MS_EXCEPTION_IF_NULL(transop_cnode); - auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); - auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); - MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); - MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); - auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); - MS_EXCEPTION_IF_NULL(transed_node); - - std::vector replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, - depend_cnode->input(kDependInputNum - 1)}; - AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); - MS_EXCEPTION_IF_NULL(replace_depend); - auto transed_abstract = transed_node->abstract(); - replace_depend->set_abstract(transed_abstract); - return replace_depend; -} - -bool Visited(const BaseRef &n) { - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - if (IsValueNode(in)) { - auto value_node = in->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto prim_py = value->cast(); - MS_EXCEPTION_IF_NULL(prim_py); - return prim_py->HasAttr(kAttrVisited); - } else if (IsValueNode(in)) { - auto func_graph = GetValueNode(in); - MS_EXCEPTION_IF_NULL(func_graph); - return func_graph->has_flag(kAttrVisited); - } - return false; - } - return false; -} - -void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, - std::vector *conv_bn1_outputs) { - auto prim = std::make_shared(kConvBN1OpName); - std::vector conv_bn1_inputs = {NewValueNode(prim)}; - MS_EXCEPTION_IF_NULL(conv_cnode); - // All the inputs of conv_bn1 are from the inputs of conv - for (size_t i = 1; i < conv_cnode->inputs().size(); i++) { - conv_bn1_inputs.push_back(conv_cnode->input(i)); - } - MS_EXCEPTION_IF_NULL(func_graph); - CNodePtr conv_bn1_cnode = func_graph->NewCNode(conv_bn1_inputs); - MS_EXCEPTION_IF_NULL(conv_bn1_cnode); - auto kernel_info = std::make_shared(); - conv_bn1_cnode->set_kernel_info(kernel_info); - // Set attr for conv_bn1 - AnfAlgo::CopyNodeAttrs(conv_cnode, conv_bn1_cnode); - // Set abstract of conv_bn1 - MS_EXCEPTION_IF_NULL(bn_cnode); - auto bn_abstract_tuple = dyn_cast(bn_cnode->abstract()); - MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - AbstractBasePtrList conv_bn1_abstract_list; - conv_bn1_abstract_list.push_back(conv_cnode->abstract()); - auto abstract_tensor = 
std::make_shared( - kFloat32, Convert2Int(AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, kVariance - 1))); - conv_bn1_abstract_list.push_back(abstract_tensor); - conv_bn1_abstract_list.push_back(bn_abstract_tuple->elements()[kSaveMean]); - auto abstract_tuple = std::make_shared(conv_bn1_abstract_list); - conv_bn1_cnode->set_abstract(abstract_tuple); - - CreateMultipleOutputsOfAnfNode(func_graph, conv_bn1_cnode, kConvBn1OutputNum, conv_bn1_outputs); -} - -void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, - const CNodePtr &bn_node, std::vector *fused_bn2_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(bn_node); - MS_EXCEPTION_IF_NULL(fused_bn2_outputs); - if (bn_node->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size"; - } - if (fused_bn1_outputs.size() != kBN1OutputNum) { - MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; - } - - // the inputs of fused_bn2 are from the outputs of fused_bn1 and the inputs of bn - std::vector fused_bn2_inputs = {NewValueNode(std::make_shared(kFusedBN2OpName))}; - fused_bn2_inputs.push_back(fused_bn1_outputs[0]); - fused_bn2_inputs.push_back(fused_bn1_outputs[1]); - fused_bn2_inputs.push_back(bn_node->input(4)); - fused_bn2_inputs.push_back(bn_node->input(5)); - auto fused_bn2 = graph->NewCNode(fused_bn2_inputs); - MS_EXCEPTION_IF_NULL(fused_bn2); - auto kernel_info = std::make_shared(); - fused_bn2->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 4), AnfAlgo::GetOutputInferDataType(bn_node, 1), - AnfAlgo::GetOutputInferDataType(bn_node, 2)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 4), AnfAlgo::GetOutputInferShape(bn_node, 1), - AnfAlgo::GetOutputInferShape(bn_node, 2)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn2.get()); - fused_bn2->set_scope(bn_node->scope()); - AnfAlgo::CopyNodeAttr(kAttrMomentum, bn_node, fused_bn2); - - CreateMultipleOutputsOfAnfNode(graph, fused_bn2, kBN2OutputNum, fused_bn2_outputs); -} - -void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, - const std::vector &fused_bn1_outputs, - const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, - std::vector *fused_bn3_outputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(data_input); - MS_EXCEPTION_IF_NULL(bn_node); - MS_EXCEPTION_IF_NULL(fused_bn3_outputs); - if (bn_node->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size"; - } - - if (fused_bn1_outputs.size() != kBN1OutputNum) { - MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size"; - } - - if (fused_bn2_outputs.size() != kBN2OutputNum) { - MS_LOG(EXCEPTION) << "BN2 outputs has wrong input size"; - } - - // the inputs of fused_bn3 are from the outputs of fused_bn1 and the inputs of bn - std::vector fused_bn3_inputs = {NewValueNode(std::make_shared(kFusedBN3OpName))}; - fused_bn3_inputs.push_back(data_input); - fused_bn3_inputs.push_back(fused_bn1_outputs[0]); - fused_bn3_inputs.push_back(fused_bn2_outputs[0]); - fused_bn3_inputs.push_back(bn_node->input(2)); - fused_bn3_inputs.push_back(bn_node->input(3)); - auto fused_bn3 = graph->NewCNode(fused_bn3_inputs); - MS_EXCEPTION_IF_NULL(fused_bn3); - auto kernel_info = std::make_shared(); - fused_bn3->set_kernel_info(kernel_info); - auto types = {AnfAlgo::GetOutputInferDataType(bn_node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(bn_node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, fused_bn3.get()); - - 
fused_bn3->set_scope(bn_node->scope()); - AnfAlgo::CopyNodeAttr(kAttrEpsilon, kAttrEps, bn_node, fused_bn3); - - (*fused_bn3_outputs).push_back(fused_bn3); -} - -void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_num, - std::vector *outputs) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(outputs); - for (size_t i = 0; i < output_num; i++) { - auto idx = NewValueNode(SizeToInt(i)); - MS_EXCEPTION_IF_NULL(idx); - int temp = SizeToInt(i); - auto imm = std::make_shared(temp); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(node, i)}, - {AnfAlgo::GetOutputInferShape(node, i)}, tuple_getitem.get()); - (*outputs).push_back(tuple_getitem); - } -} - -template -tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, - size_t data_length) { - MS_EXCEPTION_IF_NULL(value_tuple_ptr); - MS_EXCEPTION_IF_NULL(type_ptr); - std::vector values; - for (const auto &v : value_tuple_ptr->value()) { - MS_EXCEPTION_IF_NULL(v); - if (v->isa()) { - ScalarPtr scalar = v->cast(); - values.push_back(GetValue(scalar)); - } else { - MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; - return nullptr; - } - } - std::vector tensor_shape = {SizeToInt(values.size())}; - tensor::TensorPtr tensor = std::make_shared(type_ptr->type_id(), tensor_shape); - MS_EXCEPTION_IF_NULL(tensor); - tensor::DeviceInfo device_info{kOpFormat_DEFAULT, type_ptr}; - tensor->set_device_info(device_info); - auto data_ptr = tensor->data_c(); - MS_EXCEPTION_IF_NULL(data_ptr); - auto elem_num = values.size() * data_length; - auto ret_code = memcpy_s(data_ptr, static_cast(tensor->data().nbytes()), values.data(), elem_num); - if (ret_code != 0) { - MS_LOG(EXCEPTION) << "Failed to copy data into Tensor."; - } - return tensor; -} - -tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { - MS_EXCEPTION_IF_NULL(value_tuple); - tensor::TensorPtr tensor = nullptr; - if (value_tuple->value().empty()) { - MS_LOG(WARNING) << "The value tuple is empty."; - return nullptr; - } - ValuePtr v = *(value_tuple->value().begin()); - MS_EXCEPTION_IF_NULL(v); - // Currently we only deal with the scalar tuple - if (!v->isa()) { - MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar"; - return nullptr; - } - ScalarPtr scalar = v->cast(); - MS_EXCEPTION_IF_NULL(scalar); - if (scalar->isa()) { - tensor = CreateTensorWithValueTuple(value_tuple, kInt32, kType32Len); - } else if (scalar->isa()) { - tensor = CreateTensorWithValueTuple(value_tuple, kFloat32, kType32Len); - } else { - auto type = scalar->type(); - auto type_str = (type == nullptr) ? 
"nullptr" : type->ToString(); - MS_LOG(ERROR) << "Invalid scalar type: " << type_str; - return nullptr; - } - return tensor; -} - -bool IsNopNode(const AnfNodePtr &node) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { - return false; - } - static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, - prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), - kFlattenGradOpName}; - if (node == nullptr || !node->isa()) { - return false; - } - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (nop_nodes.find(AnfAlgo::GetCNodeName(cnode)) == nop_nodes.end()) { - return false; - } - return true; -} - -bool IsAllNopNode(const session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - auto execution_order = graph->execution_order(); - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsNopNode(cnode)) { - return false; - } - } - return true; -} - -void HideNopNode(session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - if (IsAllNopNode(graph) == true) { - return; - } - auto execution_order = graph->execution_order(); - MS_LOG(INFO) << "nop node info (Before Remove) size: " << execution_order.size(); - std::vector new_nodes; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - if (!IsNopNode(cnode)) { - new_nodes.push_back(cnode); - } - } - graph->set_execution_order(new_nodes); - MS_LOG(INFO) << "nop node info (After Remove) size: " << graph->execution_order().size(); -} - -void RemoveNopNode(session::KernelGraph *const graph) { - MS_EXCEPTION_IF_NULL(graph); - if (IsAllNopNode(graph) == true) { - return; - } - bool changed = true; - while (changed) { - changed = false; - std::vector new_nodes; - for (auto &cnode : graph->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode); - // ignore nop node itself - if (IsNopNode(cnode)) { - continue; - } - // Replace the input which is nop node - std::vector new_inputs; - new_inputs.push_back(cnode->input(0)); - bool need_update = false; - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto input = cnode->input(i); - MS_EXCEPTION_IF_NULL(input); - auto cinput = input->cast(); - if (cinput == nullptr || !IsNopNode(cinput)) { - new_inputs.push_back(input); - continue; - } - if (cinput->inputs().size() == 2) { - new_inputs.push_back(cinput->input(1)); - need_update = true; - changed = true; - } else { - new_inputs.push_back(input); - } - } - if (need_update) { - cnode->set_inputs(new_inputs); - } - // push into new execution list - new_nodes.push_back(cnode); - } - graph->set_execution_order(new_nodes); - } -} - -std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, - const AnfNodePtr &node) { - auto output_node_list = std::make_shared>>(); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto iter = manager->node_users().find(node); - if (iter == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "node has no output in manager"; - } - auto output_info_list = iter->second; - for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { - continue; - } - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && - output_info.second == kDependAttachNodeIndex) { - continue; - } - output_node_list->push_back(output_info); - } - return output_node_list; 
-} - -bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto output_node_list = GetRealNodeUsedList(graph, node); - MS_EXCEPTION_IF_NULL(output_node_list); - return output_node_list->size() > 1; -} - -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { - auto idx = NewValueNode(SizeToInt(output_idx)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(SizeToInt(output_idx)); - auto abstract_scalar = std::make_shared(imm); - idx->set_abstract(abstract_scalar); - AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); - MS_EXCEPTION_IF_NULL(tuple_getitem); - tuple_getitem->set_scope(node->scope()); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); - AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); - return tuple_getitem; -} - -void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs; - std::vector new_input_names; - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - auto input_names = primitive->GetAttr(kAttrInputNames); - if (input_names == nullptr) { - MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; - return; - } - auto input_names_vec = GetValue>(input_names); - auto inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; - if (i >= input_names_vec.size()) { - MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; - } - primitive->set_attr(input_names_vec[i], value_node->value()); - need_update = true; - } else { - new_inputs.push_back(input_node); - if (i < input_names_vec.size()) { - new_input_names.push_back(input_names_vec[i]); - } - } - } - if (need_update) { - // Update cnode's inputs - cnode->set_inputs(new_inputs); - // Update cnode's input_names attr - primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); - } -} - -bool AnfEqual(const BaseRef &a, const BaseRef &b) { - if (utils::isa(a) && utils::isa(b)) { - auto a_node = utils::cast(a); - auto b_node = utils::cast(b); - MS_EXCEPTION_IF_NULL(a_node); - MS_EXCEPTION_IF_NULL(b_node); - if (IsValueNode(a_node) && IsValueNode(b_node)) { - auto a_value_node = a_node->cast(); - MS_EXCEPTION_IF_NULL(a_value_node); - auto a_value = a_value_node->value(); - MS_EXCEPTION_IF_NULL(a_value); - auto a_prim = a_value->cast(); - MS_EXCEPTION_IF_NULL(a_prim); - - auto b_value_node = b_node->cast(); - MS_EXCEPTION_IF_NULL(b_value_node); - auto b_value = b_value_node->value(); - MS_EXCEPTION_IF_NULL(b_value); - auto b_prim = b_value->cast(); - MS_EXCEPTION_IF_NULL(b_prim); - - return a_prim->name() == b_prim->name(); - } else if (a_node->isa() && b_node->isa()) { - auto a_value_node_ptr = a_node->cast(); - if (a_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto 
a_value_ptr = a_value_node_ptr->value(); - if (a_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - auto b_value_node_ptr = b_node->cast(); - if (b_value_node_ptr == nullptr) { - MS_LOG(EXCEPTION) << "cast value node ptr fail"; - } - auto b_value_ptr = b_value_node_ptr->value(); - if (b_value_ptr == nullptr) { - MS_LOG(EXCEPTION) << "value ptr is nullptr"; - } - - return (*a_value_ptr) == (*b_value_ptr); - } - MS_LOG(DEBUG) << "check AnfNodePtr equal"; - } - if (utils::isa(a) && utils::isa(b)) { - MS_LOG(DEBUG) << "check GraphPtr equal"; - } - return a == b; -} - -bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b) { - // To matchCNode and Kernel's type - if (utils::isa(a) && utils::isa(b)) { - return true; - } - return a.type() == b.type(); -} - -namespace { -ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - if (utils::isa(sexp)) { - return NewValueNode(utils::cast(sexp)); - } - return nullptr; -} - -CNodePtr CreateCNodeWithGraph(const std::vector &input_nodes, const BaseRef &graph) { - if (utils::isa(graph)) { - return std::make_shared(input_nodes, utils::cast(graph)); - } - if (utils::isa(graph)) { - return std::make_shared(input_nodes, utils::cast(graph)); - } - return nullptr; -} - -VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { - if (utils::isa(graph)) { - MS_LOG(DEBUG) << "make VarPtr " + graph.ToString(); - return std::make_shared(utils::cast(sexp), nullptr); - } - if (utils::isa(graph)) { - MS_LOG(DEBUG) << "VarNode, should input a Var in graph. It's GraphPtr: " + graph.ToString(); - return std::make_shared(utils::cast(sexp), utils::cast(graph)); - } - MS_LOG(ERROR) << "VarNode, should input a Var in graph. 
It's " + graph.ToString(); - return nullptr; -} - -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph) { - MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - std::vector input_nodes; - const auto &tuple = utils::cast(sexp); - if (multigraph && utils::isa(graph)) { - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); - input_nodes.push_back(node); - } - VarPtr var_ptr = utils::cast(graph); - return std::make_shared(input_nodes, var_ptr); - } - - for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); - input_nodes.push_back(node); - } - return CreateCNodeWithGraph(input_nodes, graph); -} -} // namespace - -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, bool multigraph) { - MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); - MS_EXCEPTION_IF_NULL(primitive_vars); - if (utils::isa(sexp)) { - return HandleSexpVector(sexp, graph, primitive_vars, multigraph); - } - if (utils::isa(sexp)) { - auto var_ptr = utils::cast(sexp); - MS_EXCEPTION_IF_NULL(var_ptr); - if (var_ptr->primitive()) { - (*primitive_vars)[var_ptr->primitive()] = var_ptr; - return NewValueNode(var_ptr->primitive()); - } - return CreateVarNodeWithSexp(sexp, graph); - } - if (utils::isa(sexp)) { - return utils::cast(sexp); - } - auto value_node = CreateValueNodeWithSexp(sexp); - if (value_node == nullptr) { - MS_LOG(EXCEPTION) << "sexp cannot converted. sexp: " + sexp.ToString(); - } - return value_node; -} - -bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node) { - MS_EXCEPTION_IF_NULL(equiv1); - MS_EXCEPTION_IF_NULL(equiv2); - MS_EXCEPTION_IF_NULL(var_node); - auto equiv1_node = GetAnfNodeByVar(equiv1, var_node); - MS_EXCEPTION_IF_NULL(equiv1_node); - auto equiv2_node = GetAnfNodeByVar(equiv2, var_node); - MS_EXCEPTION_IF_NULL(equiv2_node); - return *equiv1_node == *equiv2_node; -} - -AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node) { - MS_EXCEPTION_IF_NULL(equiv); - MS_EXCEPTION_IF_NULL(var_node); - auto iter = (*equiv).find(var_node); - if (iter == (*equiv).end()) { - MS_LOG(INFO) << "The equiv map doesn't contain the var_node after matched."; - return nullptr; - } - auto res = utils::cast(iter->second); - if (res == nullptr) { - MS_LOG(EXCEPTION) << "Cast fail! 
Maybe var is not a anf node"; - } - return res; -} - -bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) { - MS_EXCEPTION_IF_NULL(n1); - MS_EXCEPTION_IF_NULL(n2); - auto n1_cnode = n1->cast(); - auto n2_cnode = n2->cast(); - MS_EXCEPTION_IF_NULL(n1_cnode); - MS_EXCEPTION_IF_NULL(n2_cnode); - auto index_input1 = n1_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_input1); - auto value_node1 = index_input1->cast(); - MS_EXCEPTION_IF_NULL(value_node1); - auto index_input2 = n2_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_input2); - auto value_node2 = index_input2->cast(); - MS_EXCEPTION_IF_NULL(value_node2); - return GetValue(value_node1->value()) < GetValue(value_node2->value()); -} - -bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(INFO) << "node is not a cnode"; - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - return AnfAlgo::HasNodeAttr(attr_name, cnode) && AnfAlgo::GetNodeAttr(node, attr_name); -} - -bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set) { - MS_EXCEPTION_IF_NULL(node); - TypeId data_type = AnfAlgo::GetOutputInferDataType(node, 0); - if (supported_data_type_set.find(data_type) != supported_data_type_set.end()) { - return true; - } - MS_LOG(DEBUG) << "Not supported data type. Node:" << node->DebugString(); - return false; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h deleted file mode 100644 index 49a1d47d0c..0000000000 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ /dev/null @@ -1,199 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ - -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "session/kernel_graph.h" -#include "common/utils.h" -#include "pre_activate/common/pattern_engine.h" - -namespace mindspore { -namespace opt { -constexpr size_t kTransOpInputNum = 2; -constexpr size_t kCastInputNum = 2; -constexpr size_t kDependInputNum = 3; -constexpr size_t kReluInputNum = 2; -constexpr size_t kReluGradInputNum = 3; -constexpr size_t kAddInputNum = 3; -constexpr size_t kAddNInputNum = 3; -constexpr size_t kTupleGetitemInputNum = 3; -constexpr size_t kConvInputNum = 3; -constexpr size_t kRealDivInputNum = 3; -constexpr size_t kSqrtInputNum = 2; -constexpr size_t kMulInputNum = 3; -constexpr size_t kRsqrtInputNum = 2; -constexpr size_t kSubInputNum = 3; -constexpr size_t kAssignSubInputNum = 3; - -constexpr size_t kConvBn1OutputNum = 3; -constexpr size_t kBn2ReluOutputNum = 4; - -constexpr size_t kBnInputNum = 6; -constexpr size_t kBnOutputNum = 5; -constexpr size_t kBatchNormInputNum = 5; -constexpr size_t kBatchNormOutputNum = 5; - -constexpr size_t kBN1OutputNum = 2; -constexpr size_t kBN2OutputNum = 3; -constexpr size_t kBN3OutputNum = 1; - -constexpr size_t kBNGradInputNum = 6; -constexpr size_t kBNGradOutputNum = 3; - -constexpr size_t kBNGrad1OutputNum = 3; -constexpr size_t kBNGrad2OutputNum = 5; -constexpr size_t kBNGrad3OutputNum = 1; - -constexpr size_t kBNTrainingReduceOutputNum = 2; -constexpr size_t kBNTrainingUpdateOutputNum = 5; -constexpr size_t kBNTrainingUpdateV2OutputNum = 3; -constexpr size_t kBNTrainingUpdateV3OutputNum = 5; -constexpr size_t kBNTrainingUpdateGradOutputNum = 2; - -constexpr size_t kSingleOutputNum = 1; -constexpr size_t kSumNodeInputNum = 2; -constexpr size_t kSquareNodeInputNum = 2; -constexpr size_t kSquareSumv2OutputNum = 2; -constexpr size_t kMinimumInputNum = 3; - -constexpr size_t kLambNextMVWithDecayInputNum = 7; -constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; -constexpr size_t kLambNextMVWithDecayOutputNum = 4; -constexpr size_t kLambNextMVWithDecayV1OutputNum = 4; -constexpr size_t kLambNextRightOutputNum = 2; -constexpr size_t kLambUpdateWithLrV2InputNum = 8; -constexpr size_t kLambNextMVRuleInputNum = 14; -constexpr size_t kLambNextMVRuleOutputNum = 4; -constexpr size_t kBackendReshapeInputNum = 2; -constexpr size_t kBackendTransposeInputNum = 2; -constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; -constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; -constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; -constexpr size_t kLayerNormGradInputNum = 6; -constexpr size_t kAdamApplyOneOutputNum = 3; -constexpr size_t kBackendTransDataInputNum = 2; -constexpr size_t kApplyMomentumInputNum = 6; -constexpr size_t kBiasAddInputNum = 3; -constexpr size_t kTopkInputNum = 3; -constexpr size_t kLarsV2InputNum = 5; -constexpr size_t kFusedMulApplyMomentumOutputNum = 2; -constexpr size_t kSplitInputNum = 2; - -enum FusedBatchNormInput { - kX = 1, - kVariance = 5, -}; -enum FusedBatchNormOutput { - kY = 0, - kRunningMean, - kRunningVariance, - kSaveMean, - kSaveInvVariance, -}; -enum ConvBn1Output { - kData = 0, - kVarPart, - kMean, -}; - -std::vector Convert2Int(const std::vector &v); - -// check whether node1 depends on node2 or not -bool IsDepend(const FuncGraphPtr &graph, const AnfNodePtr &node1, const AnfNodePtr &node2); - -bool UnVisited(const BaseRef &n); - -bool Visited(const BaseRef &n); - 
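
As a concrete reference for how the helper declarations above and below are consumed, here is a hedged sketch of a minimal pattern pass in the style of the fusion passes deleted earlier in this diff. The class, its name, and the Squeeze-elimination behavior are hypothetical and exist only for illustration; CheckAnfNodeIfCNodeAndInputSize, GetAnfNodeByVar, Var, VectorRef, and PatternProcessPass all come from the sources shown in this section:

#include <memory>
// Illustrative pre-move include paths, matching the files deleted in this diff.
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/helper.h"
#include "operator/ops.h"

namespace mindspore {
namespace opt {
// Hypothetical pass: match Squeeze(x) and replace it with x. Not part of the
// diff; it only demonstrates the DefinePattern/Process contract.
class SqueezeEliminate : public PatternProcessPass {
 public:
  explicit SqueezeEliminate(bool multigraph = true)
      : PatternProcessPass("squeeze_eliminate", multigraph), input_(std::make_shared<Var>()) {}
  ~SqueezeEliminate() override = default;

  const BaseRef DefinePattern() const override {
    // SexpToNode (defined in helper.cc above) turns this VectorRef into a
    // CNode pattern: Squeeze applied to a free variable.
    return VectorRef({prim::kPrimSqueeze, input_});
  }

  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node,
                           const EquivPtr &equiv) const override {
    // Arity check via the helper declared above (Squeeze takes prim + 1 input),
    // then look up the matched input in the equivalence map.
    auto cnode = CheckAnfNodeIfCNodeAndInputSize(node, 2);
    MS_EXCEPTION_IF_NULL(cnode);
    return GetAnfNodeByVar(equiv, input_);
  }

 private:
  VarPtr input_;
};
}  // namespace opt
}  // namespace mindspore
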
-// check if the input node is CNode, then check it's input_size, if meet condition above, return true, otherwise return -// false. cnode can only be used when return true. -bool CheckIfCNodeAndInputSize(const AnfNodePtr &node, int input_size, CNodePtr *cnode); - -// check if the input node is CNode, then check it's input_size, return CNodePtr if check success. -CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, int input_size); - -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); - -bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); - -const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node); - -void CreateOutputsOfConvBn1(const FuncGraphPtr &func_graph, const CNodePtr &conv_cnode, const CNodePtr &bn_cnode, - std::vector *conv_bn1_outputs); - -void CreateOutputsOfFusedBn2(const FuncGraphPtr &graph, const std::vector &fused_bn1_outputs, - const CNodePtr &bn_node, std::vector *fused_bn2_outputs); -void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_input, - const std::vector &fused_bn1_outputs, - const std::vector &fused_bn2_outputs, const CNodePtr &bn_node, - std::vector *fused_bn3_outputs); - -void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num, - std::vector *outputs); - -tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr, - size_t data_length); - -tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); - -bool IsAllNopNode(const session::KernelGraph *const graph); - -bool IsNopNode(const AnfNodePtr &node); - -void HideNopNode(session::KernelGraph *const graph); - -void RemoveNopNode(session::KernelGraph *const graph); - -AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); - -bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); - -std::shared_ptr>> GetRealNodeUsedList(const FuncGraphPtr &graph, - const AnfNodePtr &node); - -void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); - -bool AnfEqual(const BaseRef &a, const BaseRef &b); - -bool CNodeTypeEqual(const BaseRef &a, const BaseRef &b); - -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, - bool multigraph = false); - -// Check var_node in two equivs is the same node -bool IsSameNode(const EquivPtr &equiv1, const EquivPtr &equiv2, const VarPtr &var_node); - -// Get anf_node from equiv by var_node -AnfNodePtr GetAnfNodeByVar(const EquivPtr &equiv, const VarPtr &var_node); - -// Compare tuple getitem's index, return bool[n1's index < n2's index] -bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2); - -// Get attr which is bool from cnode -bool GetBoolAttr(const AnfNodePtr &node, const std::string &attr_name); - -// Check node's data type is in supported data type set -bool CheckSupportDataType(const AnfNodePtr &node, const std::set &supported_data_type_set); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.cc b/mindspore/ccsrc/pre_activate/common/node_pass.cc deleted file mode 100644 index 876da8667b..0000000000 --- a/mindspore/ccsrc/pre_activate/common/node_pass.cc +++ /dev/null @@ -1,73 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, 
Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/node_pass.h" - -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -bool NodePass::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - std::unordered_set seen_node; - std::deque todo{func_graph->output()}; - bool changes = false; - while (!todo.empty()) { - AnfNodePtr node = todo.front(); - todo.pop_front(); - if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { - continue; - } - (void)seen_node.insert(node); - AnfNodePtr new_node = Run(func_graph, node); - bool change = (new_node != nullptr); - if (new_node != nullptr && new_node != node) { - (void)manager->Replace(node, new_node); - (void)seen_node.erase(node); - } else if (new_node == nullptr) { - new_node = node; - } - if (new_node && IsValueNode(new_node)) { - auto const_func_graph = GetValueNode(new_node); - MS_EXCEPTION_IF_NULL(const_func_graph); - if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - todo.push_back(const_func_graph->output()); - } - } else if (new_node && new_node->isa()) { - if (AnfAlgo::IsGraphKernel(new_node)) { - todo.push_back(new_node); - } - auto cnode = new_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - (void)todo.insert(todo.end(), inputs.begin(), inputs.end()); - } - changes = changes || change; - } - return changes; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/node_pass.h b/mindspore/ccsrc/pre_activate/common/node_pass.h deleted file mode 100644 index 7750a59e59..0000000000 --- a/mindspore/ccsrc/pre_activate/common/node_pass.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ -#include -#include - -#include "pre_activate/common/pass.h" - -namespace mindspore { -namespace opt { -// @brief ANF Node level optimization base pass -class NodePass : public Pass { - public: - explicit NodePass(const std::string &name) : Pass(name) {} - ~NodePass() override = default; - bool Run(const FuncGraphPtr &func_graph) final; - virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; -}; -using NodePassPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.cc b/mindspore/ccsrc/pre_activate/common/optimizer.cc deleted file mode 100644 index 71a523ea1d..0000000000 --- a/mindspore/ccsrc/pre_activate/common/optimizer.cc +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/optimizer.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "pre_activate/common/pass_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "ir/manager.h" - -namespace mindspore { -namespace opt { -PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) - : NodePass(name), - multigraph_(multigraph), - pattern_engine_(PatternEngine(std::make_shared(), - std::function(AnfEqual), - std::function(CNodeTypeEqual))), - primitive_vars_(std::make_shared()) {} - -const BaseRef PatternProcessPass::DefinePattern() const { - VarPtr X = std::make_shared(); - return BaseRef({X}); -} - -void PatternProcessPass::Build() { - VarPtr fg = std::make_shared("RootG"); - BaseRef pattern = std::move(DefinePattern()); - pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); -} - -AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - if (pattern_ == nullptr) { - Build(); - } - - auto empty_equiv = std::make_shared(); - MS_EXCEPTION_IF_NULL(primitive_vars_); - EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); - if (equiv != nullptr && !equiv->empty()) { - return Process(func_graph, node, equiv); - } - return nullptr; -} - -bool MultipleOutputPatternProcessPass::MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - VarPtr fg = std::make_shared("RootG"); - auto empty_equiv = std::make_shared(); - MS_EXCEPTION_IF_NULL(child_primitive_vars_); - EquivPtr another_equiv = - child_pattern_engine_.Match(SexpToNode(DefineAnotherPattern(), fg, child_primitive_vars_.get(), true), node, - *child_primitive_vars_, empty_equiv); - if (another_equiv != nullptr && !another_equiv->empty()) { - return IsShareNodes(equiv, another_equiv); - } - return false; -} - -void GraphOptimizer::AddPassManager(const 
PassManagerPtr &pass_manager) { - if (pass_manager != nullptr) { - pass_managers_.push_back(pass_manager); - } -} - -FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_only_once) { - MS_EXCEPTION_IF_NULL(func_graph); - run_only_once_ = (pass_managers_.size() == 1) ? true : run_only_once; - // Performance risk by creating new manager each time - auto manager = Manage(func_graph, true); - - bool changed = true; - while (changed) { - changed = false; - for (size_t i = 0; i < pass_managers_.size(); ++i) { - const PassManagerPtr &pm = pass_managers_[i]; - if (pm != nullptr && pm->Run(func_graph)) { - changed = true; - } - } - if (run_only_once_) { - break; - } - } - - std::vector func_graphs; - func_graphs.push_back(func_graph); - manager->KeepRoots(func_graphs); - (void)TopoSort(func_graph->get_return()); - return func_graph; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.h b/mindspore/ccsrc/pre_activate/common/optimizer.h deleted file mode 100644 index 1f9961df6b..0000000000 --- a/mindspore/ccsrc/pre_activate/common/optimizer.h +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ - -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/primitive.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/common/pattern_engine.h" -#include "utils/graph_utils.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -using PatternListType = std::initializer_list; - -class PatternProcessPass : public NodePass { - public: - explicit PatternProcessPass(const std::string &name = "", bool multigraph = true); - ~PatternProcessPass() override = default; - virtual const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0; - virtual const BaseRef DefinePattern() const; - AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; - - private: - void Build(); - - AnfNodePtr pattern_ = nullptr; - bool multigraph_ = true; - PatternEngine pattern_engine_; - PrimitiveVarMapPtr primitive_vars_; -}; - -class MultipleOutputPatternProcessPass : public PatternProcessPass { - public: - explicit MultipleOutputPatternProcessPass(const std::string &name = "", bool multigraph = true) - : PatternProcessPass(name, multigraph), - child_pattern_engine_(PatternEngine(std::make_shared(), - std::function(AnfEqual), - std::function(CNodeTypeEqual))), - child_primitive_vars_(std::make_shared()) {} - ~MultipleOutputPatternProcessPass() override = default; - virtual BaseRef DefineAnotherPattern() const = 0; - // check two patterns whether share the same nodes or not - virtual bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const = 0; - - protected: - bool MatchAnotherPattern(const AnfNodePtr &node, const EquivPtr &equiv) const; - PatternEngine child_pattern_engine_; - PrimitiveVarMapPtr child_primitive_vars_; -}; - -class GraphOptimizer { - public: - explicit GraphOptimizer(const std::string &name = "graph_optimizer") : name_(name) {} - virtual ~GraphOptimizer() = default; - - void AddPassManager(const PassManagerPtr &pass_manager); - FuncGraphPtr Optimize(const FuncGraphPtr &func_graph, bool run_only_once = true); - - private: - const std::string name_ = "graph_optimizer"; - std::vector pass_managers_{}; - bool run_only_once_ = true; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_OPTIMIZER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pass.h b/mindspore/ccsrc/pre_activate/common/pass.h deleted file mode 100644 index 3d2468cddb..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ -#include -#include - -#include "ir/anf.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -// @brief ANF Graph level optimization base pass -class Pass { - public: - explicit Pass(const std::string &name = "pass") : name_(name) {} - virtual ~Pass() = default; - virtual bool Run(const FuncGraphPtr &func_graph) = 0; - virtual std::string name() const { return name_; } - - private: - const std::string name_; -}; -using PassPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pass_manager.cc b/mindspore/ccsrc/pre_activate/common/pass_manager.cc deleted file mode 100644 index 3213b8a6d2..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass_manager.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/common/pass_manager.h" - -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "ir/manager.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -const std::vector &PassManager::Passes() const { return passes_; } - -void PassManager::AddPass(const PassPtr &pass) { - if (pass != nullptr) { - passes_.push_back(pass); - } -} - -bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector &passes) const { - if (func_graph == nullptr) { - return false; - } - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - bool changed = false; - size_t num = 0; - for (const auto &pass : passes) { - if (pass != nullptr) { -#if defined(_WIN32) || defined(_WIN64) - auto start_time = std::chrono::steady_clock::now(); -#else - struct timeval start_time {}; - struct timeval end_time {}; - (void)gettimeofday(&start_time, nullptr); -#endif - if (pass->Run(func_graph)) { - changed = true; - } -#if defined(_WIN32) || defined(_WIN64) - auto end_time = std::chrono::steady_clock::now(); - std::chrono::duration> cost = end_time - start_time; - MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost.count() << " us"; -#else - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Run pass hwopt_" + name() + "_" << num << "_" + pass->name() + " in " << cost << " us"; -#endif - if (save_graphs) { - auto dump_file_path = - save_graphs_path + "/" + "hwopt_" + name() + 
"_" + std::to_string(num) + "_" + pass->name() + ".ir"; - DumpIR(dump_file_path, func_graph); - } - num++; - } - } - return changed; -} - -bool PassManager::Run(const FuncGraphPtr &func_graph) const { - bool changed = false; - // run all passes - bool change = true; - while (change) { - change = Run(func_graph, passes_); - changed = change || changed; - if (run_only_once_) { - break; - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/pass_manager.h b/mindspore/ccsrc/pre_activate/common/pass_manager.h deleted file mode 100644 index 38fe49b94c..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pass_manager.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ - -#include -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "pre_activate/common/node_pass.h" - -namespace mindspore { -namespace opt { -// @brief For optimization passes management -class PassManager { - public: - explicit PassManager(const std::string &name = "pm", bool run_only_once = true) - : name_(name), passes_{}, run_only_once_(run_only_once) {} - virtual ~PassManager() = default; - // Get all the passes added by AddPass - const std::vector &Passes() const; - // Add graph pass, the pass object will be freed when pass manager freed. - void AddPass(const PassPtr &pass); - // Run passes added in pass manager on the input graph - // @param [inout] graph The graph to be optimized - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph) const; - // Run the given graph passes on the input graph - // @param [inout] graph The graph to be optimized - // @param [in] passes The given graph passes - // @return true, graph changed - // @return false, graph not changed - bool Run(const FuncGraphPtr &func_graph, const std::vector &passes) const; - std::string name() const { return name_; } - - private: - const std::string name_; - std::vector passes_; - bool run_only_once_; -}; -using PassManagerPtr = std::shared_ptr; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PASS_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc b/mindspore/ccsrc/pre_activate/common/pattern_engine.cc deleted file mode 100644 index 42f966aa3d..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc +++ /dev/null @@ -1,360 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/common/pattern_engine.h" - -#include -#include -#include -#include - -#include "optimizer/opt.h" - -#include "ir/anf.h" -#include "utils/convert_utils_base.h" -#include "utils/overload.h" - -namespace mindspore { -static int GetNextTag() { - static int kID = 0; - return kID++; -} - -void Var::EnsureTag() { - if (tag_.length() == 0) { - std::ostringstream buffer; - buffer << "_" << GetNextTag(); - tag_ = buffer.str(); - } -} - -bool operator==(const VarPtr &lhs, const VarPtr &rhs) { - if (lhs->isa() && rhs->isa()) { - CondVarPtr v1 = dyn_cast(lhs); - CondVarPtr v2 = dyn_cast(rhs); - return *v1 == *v2; - } - - if (lhs->isa() && rhs->isa()) { - SVarPtr v1 = dyn_cast(lhs); - SVarPtr v2 = dyn_cast(rhs); - return *v1 == *v2; - } - return (*lhs == *rhs); -} - -std::string SeqVar::ToString() const { - std::ostringstream buffer; - buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")"; - return buffer.str(); -} - -std::ostream &operator<<(std::ostream &os, const VarPtr &var) { - if (var == nullptr) { - os << ""; - } else { - os << var->ToString(); - } - return os; -} - -template <> -std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { - os << "[Equiv]" - << "\n"; - for (auto &equiv_item : equiv) { - auto k = equiv_item.first; - os << k << ":"; - BaseRef x = equiv_item.second; - if (utils::isa(x)) { - auto node = utils::cast(x); - os << "TypeString[" << node->type_name() << "]"; - if (IsValueNode(node)) { - os << "IsValueNodeGraph "; - } - os << "type " << node->type_name(); - if (node->isa()) { - os << " value " << GetValueNode(node); - } - os << " addr: " << node; - } else if (utils::isa(x)) { - os << "Named " << x.ToString().c_str(); - } else if (utils::isa(x)) { - os << "TypeString[Var]"; - os << utils::cast(x); - } else if (utils::isa(x)) { - os << "TypeString[Graph]"; - } - os << "\n"; - } - return os; -} - -static BaseRef GetVar(const BaseRef &x) { - MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); - if (utils::isa(x)) { - auto node = utils::cast(x); - MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; - if (node->isa()) { - MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); - return node->cast()->var_; - } - if (node->isa()) { - MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString(); - } else { - MS_LOG(DEBUG) << "type " + node->type_name(); - } - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "Named " + x.ToString(); - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "VectorRef"; - } else if (utils::isa(x)) { - MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString(); - } - MS_LOG(DEBUG) << "GetVar end: " + x.ToString(); - return x; -} - -EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { - MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); - MS_EXCEPTION_IF_NULL(equiv); - if (utils::isa(pattern)) { - VarPtr var = utils::cast(pattern); - if (var->matches(expr)) { - (*equiv)[var] = expr; - MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString(); - return equiv; - } - } 
- - return nullptr; -} - -bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const { - MS_EXCEPTION_IF_NULL(values_expr); - if (utils::isa(pattern_ref)) { - *values_pattern = pattern_ref; - *values_expr = expr_ref; - return true; - } - return false; -} - -bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const { - MS_EXCEPTION_IF_NULL(values_expr); - // visitor to visite the list - auto appender_pattern = [](VectorRef &values) { - std::function fn = [&](const BaseRef &u) { - values.push_back(GetVar(u)); - return u; - }; - return fn; - }; - - visitor_->SetFn(appender_pattern(*values_pattern)); - MS_LOG(DEBUG) << "visit pattern_ref"; - bool success = visitor_->Visit(pattern_ref, nullptr); - if (!success) { - return false; - } - - auto appender_expr = [](VectorRef &values) { - std::function fn = [&](const BaseRef &u) { - values.push_back(u); - return u; - }; - return fn; - }; - - visitor_->SetFn(appender_expr(*values_expr)); - MS_LOG(DEBUG) << "visit expr_ref"; - return visitor_->Visit(expr_ref, nullptr); -} - -static int GetSVarStartIndex(const VectorRef &values) { - int index = -1; - int count = 0; - for (auto &value : values) { - if (utils::isa(value) && utils::cast(value)->isa()) { - if (index != -1) { - MS_LOG(DEBUG) << "Multiple SVars in sequence"; - return kInvalidVarIndex; - } - index = count; - } - count++; - } - return index; -} - -void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) { - if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || - !utils::isa(expr_ref)) { - return; - } - auto real_node = utils::cast(expr_ref); - MS_EXCEPTION_IF_NULL(real_node); - if (!real_node->isa()) { - return; - } - auto prim_node = utils::cast(values_pattern[0]); - MS_EXCEPTION_IF_NULL(prim_node); - if (!IsValueNode(prim_node)) { - return; - } - ValuePtr value = GetValueNode(prim_node); - MS_EXCEPTION_IF_NULL(value); - auto prim = value->cast(); - MS_EXCEPTION_IF_NULL(prim); - auto iter = primitive_vars.find(prim); - if (iter == primitive_vars.end()) { - return; - } - (*equiv)[iter->second] = real_node; -} - -EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, - const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { - int svar_index = GetSVarStartIndex(values_pattern); - if (svar_index == kInvalidVarIndex) { - return nullptr; - } - - size_t values_pattern_len = values_pattern.size(); - size_t values_expr_len = values_expr.size(); - - if (svar_index == -1) { - if (values_pattern_len != values_expr_len) { - MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len " - << values_expr_len; - return nullptr; - } - } - if (values_expr_len < values_pattern_len - 1) { - MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len; - return nullptr; - } - size_t diff = values_expr_len - values_pattern_len + 1; - for (size_t i = 0; i < values_pattern_len; i++) { - size_t expr_i = i; - if (svar_index != -1 && i == IntToSize(svar_index)) { - auto seq = - std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); - equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); - } else { - if (svar_index != -1 && i > 
IntToSize(svar_index)) { - expr_i = i + diff - 1; - } - equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); - } - if (equiv == nullptr) { - return nullptr; - } - } - return equiv; -} - -EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) const { - MS_LOG(DEBUG) << "-----[in Match]"; - MS_LOG(DEBUG) << "GetVar w"; - BaseRef pattern_ref = GetVar(pattern); - MS_LOG(DEBUG) << "GetVar v"; - BaseRef expr_ref = expr; - - if (equiv == nullptr) { - MS_LOG(EXCEPTION) << "Equiv pointer is null"; - } - - MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString(); - // 1. if pattern_ref is var and already in equiv, replace it. - if (utils::isa(pattern_ref)) { - VarPtr var = utils::cast(pattern_ref); - auto iter = equiv->find(var); - if (iter != equiv->end()) { - pattern_ref = iter->second; - } - } - - // 2. check equal - if (eq_(pattern_ref, expr_ref)) { - return equiv; - } - - // 3. match var - EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv); - if (ret_equiv) { - return ret_equiv; - } - - // 4. here the type can be std:vector, std:list, - // or cnode. - if (!type_eq_(pattern_ref, expr_ref)) { - MS_LOG(DEBUG) << "Type mismatch"; - return nullptr; - } - - // 5. transfer the Containers by visitor to std::vector - VectorRef values_pattern; - VectorRef values_expr; - if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) { - return nullptr; - } - - // 6. if any svar in both side, find the SeqVar index, - // try to pack the Var s in std::vector to a Seq and match elements one by one. - // check svar - equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); - UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); - return equiv; -} - -BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(equiv); - MS_LOG(DEBUG) << "-----[in Replace]"; - BaseRef ref = GetVar(pattern); - BaseRef out; - bool is_match = false; - - // w is var - if (utils::isa(ref)) { - const VarPtr &var = utils::cast(ref); - auto iter = equiv->find(var); - if (iter != equiv->end()) { - out = iter->second; - is_match = true; - } - } - if (is_match) { - return out; - } - - // visitor to visit the list - std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; - - visitor_->SetFn(fn); - BaseRef visit_out; - if (!visitor_->Visit(pattern, &visit_out)) { - return pattern; - } - return visit_out; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.h b/mindspore/ccsrc/pre_activate/common/pattern_engine.h deleted file mode 100644 index ff38c50423..0000000000 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
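// Worked example of the AlignSVar index arithmetic above: for a pattern of
// length 4 with the SeqVar at index 2 matched against a 6-input expression,
// diff = 6 - 4 + 1 = 3, the SeqVar absorbs expr[2..5), and pattern slots after
// the SeqVar map to expr index i + diff - 1. A standalone check of that mapping:
#include <cstddef>
#include <cstdio>

int main() {
  const size_t pat_len = 4, expr_len = 6, svar = 2;  // SeqVar position (example values)
  const size_t diff = expr_len - pat_len + 1;        // elements absorbed by the SeqVar
  for (size_t i = 0; i < pat_len; ++i) {
    if (i == svar)
      std::printf("pattern[%zu] <- expr[%zu..%zu)\n", i, svar, svar + diff);
    else
      std::printf("pattern[%zu] <- expr[%zu]\n", i, i > svar ? i + diff - 1 : i);
  }
  return 0;
}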
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "pre_activate/common/visit.h" -#include "base/base.h" -#include "utils/log_adapter.h" -#include "utils/base_ref.h" - -namespace mindspore { -class CondVar; -class SeqVar; -using CondVarPtr = std::shared_ptr; -using SVarPtr = std::shared_ptr; -const int kInvalidVarIndex = -2; - -using ConditionFunc = std::function; - -// Base wildcard variable which could match any anf node. -class Var : public Base { - friend class VarHasher; - - public: - explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } - explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { - EnsureTag(); - } - Var(const Var &other) : Base(other), tag_(other.tag_) {} - virtual Var &operator=(const Var &other) { - if (&other == this) { - return *this; - } - this->tag_ = other.tag_; - return *this; - } - ~Var() override = default; - MS_DECLARE_PARENT(Var, Base); - - virtual bool matches(const BaseRef &) { return true; } - - virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } - bool operator!=(const Var &other) const { return !(&other == this); } - - std::string tag() const { return tag_; } - PrimitivePtr primitive() const { return primitive_; } - std::string ToString() const override { - std::ostringstream buffer; - buffer << "Var(" << tag_ << ")"; - return buffer.str(); - } - std::size_t hash() const override { return std::hash()(tag_); } - - protected: - void EnsureTag(); - - std::string tag_; - PrimitivePtr primitive_; -}; - -// VarNode means variable node, a subclass of AnfNode -class VarNode : public AnfNode { - public: - VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} - ~VarNode() override = default; - MS_DECLARE_PARENT(VarNode, AnfNode); - - const VarPtr var_; -}; -using VarNodePtr = std::shared_ptr; - -class VarHasher { - public: - std::size_t operator()(const Var &var) const { return var.hash(); } -}; - -// Condition Var, match an anf node when condition function return true. -class CondVar : public Var { - public: - explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} - ~CondVar() override = default; - MS_DECLARE_PARENT(CondVar, Var); - bool matches(const BaseRef &value) override { - MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); - if (utils::isa(value)) { - return false; - } - return cond_fn_(value); - } - ConditionFunc cond_fn_; -}; - -using Seq = VectorRef; -using SeqPtr = std::shared_ptr; - -// Sequence Var which could match multiple consecutive input nodes of a CNode. -class SeqVar : public Var { - public: - SeqVar() { subvar_ = std::make_shared(); } - ~SeqVar() override = default; - MS_DECLARE_PARENT(SeqVar, Var); - explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } - bool matches(const BaseRef &value) override { - // match Seq. 
- if (utils::isa(value)) { - const Seq &seq = utils::cast(value); - return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { - auto eq = subvar_->matches(v); - return eq; - }); - } - return false; - } - bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } - std::string ToString() const override; - - private: - VarPtr subvar_; -}; - -bool operator==(const VarPtr &lhs, const VarPtr &rhs); - -inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } - -std::ostream &operator<<(std::ostream &os, const VarPtr &var); - -using Equiv = std::map; -using EquivPtr = std::shared_ptr; -using PrimitiveVarMap = std::unordered_map; -using PrimitiveVarMapPtr = std::shared_ptr; - -inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } - -class PatternEngine { - public: - PatternEngine(const std::shared_ptr &visitor, - const std::function &eq, - const std::function &type_eq = DefaultTypeEq) - : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} - ~PatternEngine() = default; - - EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, - EquivPtr equiv) const; - // Replace pattern with equivalent - BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; - - private: - EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, - const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; - bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, - VectorRef *const values_expr) const; - bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, - VectorRef *const values_expr) const; - std::shared_ptr visitor_; - std::function eq_; - std::function type_eq_; -}; -} // namespace mindspore -namespace std { -using mindspore::ERROR; -using mindspore::LogStream; -using mindspore::NoExceptionType; -template <> -struct hash { - std::size_t operator()(const mindspore::VarPtr var) const { - if (var == nullptr) { - MS_LOG(ERROR) << "Invalid var ptr"; - return 0; - } - return std::hash{}(var->tag()); - } -}; -} // namespace std -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_PATTERN_ENGINE_H_ diff --git a/mindspore/ccsrc/pre_activate/common/visit.cc b/mindspore/ccsrc/pre_activate/common/visit.cc deleted file mode 100644 index 179177dd67..0000000000 --- a/mindspore/ccsrc/pre_activate/common/visit.cc +++ /dev/null @@ -1,166 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
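// Sketch of how the three variable kinds defined above compose into one pattern.
// The TensorAdd shape and the ValueNode predicate are illustrative assumptions,
// not code removed by this diff: a plain Var binds any node, a CondVar binds only
// when its predicate holds, and a SeqVar binds a run of consecutive inputs.
auto any = std::make_shared<Var>();                 // matches any node
auto value_only = std::make_shared<CondVar>(
    [](const BaseRef &n) { return utils::isa<ValueNodePtr>(n); });
auto rest = std::make_shared<SeqVar>();             // matches consecutive inputs
VectorRef pattern({prim::kPrimTensorAdd, value_only, rest});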
- */ - -#include "pre_activate/common/visit.h" - -#include -#include -#include -#include - -#include "pre_activate/common/pattern_engine.h" -#include "utils/any.h" -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "utils/log_adapter.h" - -/* namespace to support utils definition */ -namespace mindspore { -bool CheckIfNeedExpand(const std::vector &list) { - return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); -} - -std::shared_ptr ExpandList(const std::vector &list) { - std::shared_ptr new_list = std::make_shared(); - for (auto &item : list) { - if (utils::isa(item)) { - const Seq &seq = utils::cast(item); - new_list->insert(new_list->end(), seq.begin(), seq.end()); - } else { - new_list->push_back(item); - } - } - return new_list; -} - -bool DefaultVisitor::Visit(const VectorRef &v_any, BaseRef *const visit_out) const { - std::vector out; - (void)std::transform(v_any.begin(), v_any.end(), std::back_inserter(out), - [this](const BaseRef &item) { return fn_(item); }); - if (visit_out != nullptr) { - *visit_out = ExpandList(out); - } - return true; -} - -bool DefaultVisitor::Visit(const BaseRef &any, BaseRef *const visit_out) const { - if (utils::isa(any)) { - return Visit(utils::cast(any), visit_out); - } else if (utils::isa(any)) { - auto nodeptr = utils::cast(any); - AnfNodePtr output; - AnfNodePtr *p_output = &output; - if (visit_out == nullptr) { - p_output = nullptr; - } - Visit(nodeptr, fn_, p_output); - if (visit_out != nullptr) { - *visit_out = output; - } - return true; - } - MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); - return false; -} - -void DefaultVisitor::Visit(const AnfNodePtr &node, const VisitFn &fn, AnfNodePtr *output) const { - if (node->isa()) { - Visit(node->cast(), fn, output); - return; - } - - if (node->isa()) { - Visit(node->cast(), fn, output); - return; - } - - if (output != nullptr) { - *output = node; - } -} - -void DefaultVisitor::Visit(const CNodePtr &cnode, const VisitFn &fn, AnfNodePtr *output) const { - // if output is nullptr, it's not required to make the new CNode node. 
- if (output == nullptr) { - for (auto &inp : cnode->inputs()) { - (void)fn(inp); - } - - if (cnode->func_graph() != nullptr) { - (void)fn(cnode->func_graph()); - } else { - (void)fn(cnode->func_graph_as_var()); - } - return; - } - - std::vector new_inputs; - std::vector after_cnode_fn; - std::shared_ptr out; - (void)std::transform(cnode->inputs().begin(), cnode->inputs().end(), std::back_inserter(after_cnode_fn), fn); - if (CheckIfNeedExpand(after_cnode_fn)) { - out = ExpandList(after_cnode_fn); - } - - std::vector &outs = after_cnode_fn; - if (out != nullptr) { - outs = out->elements(); - } - - for (auto &any_item : outs) { - if (!utils::isa(any_item)) { - MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; - } - new_inputs.push_back(utils::cast(any_item)); - } - - BaseRef any_fg; - AnfNodePtr new_cnode = nullptr; - if (cnode->func_graph() != nullptr) { - any_fg = fn(cnode->func_graph()); - if (!utils::isa(any_fg)) { - MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; - } - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else { - any_fg = fn(cnode->func_graph_as_var()); - if (utils::isa(any_fg)) { - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else if (utils::isa(any_fg)) { - new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); - } else { - MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; - } - } - new_cnode->set_abstract(cnode->abstract()); - *output = new_cnode; -} - -void DefaultVisitor::Visit(const ValueNodePtr &vnode, const VisitFn &fn, AnfNodePtr *output) const { - const BaseRef &value = utils::cast(fn(vnode->value())); - if (utils::isa(value)) { - if (output != nullptr) { - auto ct = NewValueNode(utils::cast(value)); - ct->set_abstract(vnode->abstract()); - *output = ct; - } - return; - } - MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; -} -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc deleted file mode 100644 index 8111ee429d..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
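// CheckIfNeedExpand/ExpandList above splice any Seq produced by the visit
// function back into the flat input list of the rebuilt CNode; non-sequence
// items pass through unchanged in the real code. A standalone analogue of the
// flattening step with types simplified to int (illustration only):
#include <vector>

std::vector<int> Flatten(const std::vector<std::vector<int>> &seqs) {
  std::vector<int> flat;
  for (const auto &seq : seqs) {
    flat.insert(flat.end(), seq.begin(), seq.end());  // splice each Seq in place
  }
  return flat;
}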
- */ -#include "pre_activate/gpu/adam_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { - std::vector inputs_format; - std::vector outputs_format; - std::vector inputs_type; - std::vector outputs_type; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); - inputs_format.push_back(kOpFormat_DEFAULT); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); - outputs_format.push_back(kOpFormat_DEFAULT); - } - builder.SetInputsDeviceType(inputs_type); - builder.SetInputsFormat(inputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetOutputsFormat(outputs_format); - return builder.Build(); -} -} // namespace - -const BaseRef AdamFusion::DefinePattern() const { - VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), - VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); - VectorRef next_v = - VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), - VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); - VectorRef update = VectorRef( - {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); - VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); - VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; -} - -const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto beta1_input = utils::cast((*equiv)[beta1_]); - auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); - auto beta2_input = utils::cast((*equiv)[beta2_]); - auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); - auto eps_input = utils::cast((*equiv)[eps_]); - auto lr_input = utils::cast((*equiv)[lr_]); - auto param_input = utils::cast((*equiv)[param_]); - auto m_input = utils::cast((*equiv)[m_]); - auto v_input = utils::cast((*equiv)[v_]); - auto gradient_input = utils::cast((*equiv)[gradient_]); - MS_EXCEPTION_IF_NULL(beta1_input); - MS_EXCEPTION_IF_NULL(one_sub_beta1_input); - MS_EXCEPTION_IF_NULL(beta2_input); - MS_EXCEPTION_IF_NULL(one_sub_beta2_input); - MS_EXCEPTION_IF_NULL(eps_input); - MS_EXCEPTION_IF_NULL(lr_input); - MS_EXCEPTION_IF_NULL(param_input); - MS_EXCEPTION_IF_NULL(m_input); - MS_EXCEPTION_IF_NULL(v_input); - MS_EXCEPTION_IF_NULL(gradient_input); - - auto prim = std::make_shared(kFusedAdamName); - MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, 
one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input}; - auto adam = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(adam); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam.get()); - adam->set_scope(node->scope()); - - auto build_info = GenerateKernelBuildInfo(adam); - AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); - return adam; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h deleted file mode 100644 index d8c10a0986..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_fusion.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class AdamFusion : public PatternProcessPass { - public: - explicit AdamFusion(bool multigraph = true) : PatternProcessPass("adam_fusion", multigraph) { - beta1_ = std::make_shared(); - one_sub_beta1_ = std::make_shared(); - beta2_ = std::make_shared(); - one_sub_beta2_ = std::make_shared(); - eps_ = std::make_shared(); - lr_ = std::make_shared(); - param_ = std::make_shared(); - m_ = std::make_shared(); - v_ = std::make_shared(); - gradient_ = std::make_shared(); - } - ~AdamFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr beta1_; - VarPtr one_sub_beta1_; - VarPtr beta2_; - VarPtr one_sub_beta2_; - VarPtr eps_; - VarPtr lr_; - VarPtr param_; - VarPtr m_; - VarPtr v_; - VarPtr gradient_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc deleted file mode 100644 index c950cbd56f..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
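// Reading DefinePattern() above back as math, the deleted AdamFusion matches the
// composite Adam update (the 1-beta factors arrive as explicit inputs rather than
// being derived inside the pattern):
//   m' = beta1 * m + (1 - beta1) * g
//   v' = beta2 * v + (1 - beta2) * g^2
//   param' = param - lr * m' / (sqrt(v') + eps)
// The Depend/Assign chain ties the three state writes (param, m, v) together so
// the whole subgraph can be swapped for the single FusedAdam node built in Process().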
- */ -#include "pre_activate/gpu/adam_weight_decay_fusion.h" - -#include -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { - std::vector inputs_format; - std::vector outputs_format; - std::vector inputs_type; - std::vector outputs_type; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { - inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); - inputs_format.push_back(kOpFormat_DEFAULT); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { - outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); - outputs_format.push_back(kOpFormat_DEFAULT); - } - builder.SetInputsDeviceType(inputs_type); - builder.SetInputsFormat(inputs_format); - builder.SetOutputsDeviceType(outputs_type); - builder.SetOutputsFormat(outputs_format); - return builder.Build(); -} -} // namespace - -const BaseRef AdamWeightDecayFusion::DefinePattern() const { - VectorRef next_m = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta1_, m_}), - VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); - VectorRef next_v = - VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, beta2_, v_}), - VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); - VectorRef update = VectorRef( - {prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimTensorAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); - VectorRef new_update = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); - - VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); - VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - VectorRef depend1 = VectorRef({prim::kPrimDepend, next_v, VectorRef({prim::kPrimAssign, param_, next_param})}); - VectorRef depend2 = VectorRef({prim::kPrimDepend, depend1, VectorRef({prim::kPrimAssign, m_, next_m})}); - VectorRef depend3 = VectorRef({prim::kPrimDepend, depend2, VectorRef({prim::kPrimAssign, v_, depend2})}); - return depend3; -} - -const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - auto beta1_input = utils::cast((*equiv)[beta1_]); - auto one_sub_beta1_input = utils::cast((*equiv)[one_sub_beta1_]); - auto beta2_input = utils::cast((*equiv)[beta2_]); - auto one_sub_beta2_input = utils::cast((*equiv)[one_sub_beta2_]); - auto eps_input = utils::cast((*equiv)[eps_]); - auto lr_input = utils::cast((*equiv)[lr_]); - auto weight_decay_input = utils::cast((*equiv)[weight_decay_]); - auto param_input = utils::cast((*equiv)[param_]); - auto m_input = utils::cast((*equiv)[m_]); - auto v_input = utils::cast((*equiv)[v_]); - auto gradient_input = utils::cast((*equiv)[gradient_]); - MS_EXCEPTION_IF_NULL(beta1_input); - MS_EXCEPTION_IF_NULL(one_sub_beta1_input); - MS_EXCEPTION_IF_NULL(beta2_input); - MS_EXCEPTION_IF_NULL(one_sub_beta2_input); - MS_EXCEPTION_IF_NULL(eps_input); - MS_EXCEPTION_IF_NULL(lr_input); - MS_EXCEPTION_IF_NULL(weight_decay_input); - MS_EXCEPTION_IF_NULL(param_input); - 
MS_EXCEPTION_IF_NULL(m_input); - MS_EXCEPTION_IF_NULL(v_input); - MS_EXCEPTION_IF_NULL(gradient_input); - - auto prim = std::make_shared<Primitive>(kFusedAdamWeightDecayName); - MS_EXCEPTION_IF_NULL(prim); - std::vector<AnfNodePtr> inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input, weight_decay_input}; - auto adam_weight_decay = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(adam_weight_decay); - auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; - auto shapes = {AnfAlgo::GetOutputInferShape(node, 0)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, adam_weight_decay.get()); - adam_weight_decay->set_scope(node->scope()); - - auto build_info = GenerateKernelBuildInfo(adam_weight_decay); - AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); - return adam_weight_decay; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h deleted file mode 100644 index 0ada5756e3..0000000000 --- a/mindspore/ccsrc/pre_activate/gpu/adam_weight_decay_fusion.h +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
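// Relative to AdamFusion, the AdamWeightDecayFusion pattern above adds one term
// before the learning rate is applied, visible in DefinePattern() as the extra
// Mul(weight_decay, param):
//   update' = m' / (sqrt(v') + eps) + weight_decay * param
//   param'  = param - lr * update'
// which is why the fused node's input list carries weight_decay_input at the end.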
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class AdamWeightDecayFusion : public PatternProcessPass { - public: - explicit AdamWeightDecayFusion(bool multigraph = true) : PatternProcessPass("adam_weight_decay_fusion", multigraph) { - beta1_ = std::make_shared(); - one_sub_beta1_ = std::make_shared(); - beta2_ = std::make_shared(); - one_sub_beta2_ = std::make_shared(); - eps_ = std::make_shared(); - lr_ = std::make_shared(); - weight_decay_ = std::make_shared(); - param_ = std::make_shared(); - m_ = std::make_shared(); - v_ = std::make_shared(); - gradient_ = std::make_shared(); - } - ~AdamWeightDecayFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - VarPtr beta1_; - VarPtr one_sub_beta1_; - VarPtr beta2_; - VarPtr one_sub_beta2_; - VarPtr eps_; - VarPtr lr_; - VarPtr weight_decay_; - VarPtr param_; - VarPtr m_; - VarPtr v_; - VarPtr gradient_; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_GPU_IR_FUSION_ADAM_WEIGHT_DECAY_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc b/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc deleted file mode 100644 index c75860a8df..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/kernel_refcount.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include -#include "utils/log_adapter.h" -namespace mindspore { -namespace memreuse { -/** - * Add some set && get function - */ -void KernelRefCount::SetKernelRefCountInfo(int index, size_t size, RefCountType reftype) { - index_ = index; - size_ = size; - reftype_ = reftype; -} - -std::vector KernelDef::GetInputRefIndexs() const { - std::vector input_ref_indexs; - if (input_refs_.empty()) { - return input_ref_indexs; - } - (void)std::transform(input_refs_.begin(), input_refs_.end(), std::back_inserter(input_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return input_ref_indexs; -} - -std::vector KernelDef::GetOutputRefIndexs() const { - std::vector output_ref_indexs; - if (output_refs_.empty()) { - return output_ref_indexs; - } - (void)std::transform(output_refs_.begin(), output_refs_.end(), std::back_inserter(output_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return output_ref_indexs; -} - -std::vector KernelDef::GetWorkspaceRefIndexs() const { - std::vector wk_ref_indexs; - if (wk_space_.empty()) { - return wk_ref_indexs; - } - // only one key - auto wk_refs_iter = wk_space_.begin(); - auto wk_refs = wk_refs_iter->second; - (void)std::transform(wk_refs.begin(), wk_refs.end(), std::back_inserter(wk_ref_indexs), - [](const KernelRefCountPtr &ref_info) { return ref_info->index_; }); - return wk_ref_indexs; -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h deleted file mode 100644 index ea9947b41b..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_copy_manager.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
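// The three Get*RefIndexs() accessors above share one shape: project index_ out
// of a list of ref-count records with std::transform. A self-contained version
// of that projection with the record type reduced to its index field:
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

struct RefInfo {
  int index_;
};

std::vector<int> RefIndexes(const std::vector<std::shared_ptr<RefInfo>> &refs) {
  std::vector<int> out;
  (void)std::transform(refs.begin(), refs.end(), std::back_inserter(out),
                       [](const std::shared_ptr<RefInfo> &r) { return r->index_; });
  return out;
}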
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ - -#include -#include -#include -#include -#include -#include "session/kernel_graph.h" -#include "kernel/kernel.h" - -using HostAddress = mindspore::kernel::Address; -namespace mindspore { -namespace device { -namespace memswap { -enum class SwapKind { kDeviceToHost = 0, kHostToDevice = 1 }; - -struct TensorInfo { - size_t tensor_size_{0}; - AnfNodePtr kernel_{nullptr}; - size_t output_idx_{0}; -}; - -struct KernelExecutionInfo { - size_t topo_order_{0}; - float execution_perform_{0.0}; - bool trigger_swap_{false}; - bool need_swap_{false}; - // output index to topo orders of node users - std::map> node_users_map_; - // kernel output idx to host addr - std::map host_addrs_; - - KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} - explicit KernelExecutionInfo(size_t topo_order) - : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} - KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) - : topo_order_(topo_order), - execution_perform_(execution_perform), - trigger_swap_(trigger_swap), - need_swap_(need_swap) {} -}; - -// trigger swap -struct MemSwapInfo { - SwapKind swap_kind_; - // kernel need to be swapped - AnfNodePtr kernel_{nullptr}; - size_t output_idx_{0}; -}; - -class MemCopyManager { - public: - MemCopyManager() = default; - - virtual ~MemCopyManager() = default; - - virtual void Init() {} - - virtual void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} - - virtual void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) {} - - virtual bool SyncMemCopyStream(SwapKind swap_kind) { return true; } - - virtual DeviceAddressPtr UpdateSwapOutQueue() { return nullptr; } - - virtual DeviceAddressPtr UpdateSwapInQueue() { return nullptr; } - - virtual bool AllocHostPinnedMem(size_t size, void **addr) const { return true; } - - virtual void FreeHostPinnedMem(void *addr) const {} - - virtual void ClearSwapQueue() {} -}; -using MemCopyManagerPtr = std::shared_ptr; -} // namespace memswap -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc deleted file mode 100644 index 7c5e87b128..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ /dev/null @@ -1,326 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
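// Every hook on the MemCopyManager base class above has a safe no-op default, so
// a backend overrides only what its copy engine supports. A hypothetical stub
// (class name and the malloc fallback are invented for illustration) for a device
// with no dedicated pinned-memory pool:
#include <cstddef>
#include <cstdlib>

class NoopMemCopyManager : public mindspore::device::memswap::MemCopyManager {
 public:
  bool AllocHostPinnedMem(size_t size, void **addr) const override {
    *addr = std::malloc(size);  // plain host memory instead of pinned pages
    return *addr != nullptr;
  }
  void FreeHostPinnedMem(void *addr) const override { std::free(addr); }
};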
- */ - -#include "pre_activate/mem_reuse/mem_dynamic_allocator.h" -#include "common/utils.h" -#include "utils/convert_utils.h" -#include "utils/log_adapter.h" - -namespace mindspore { -namespace device { -DynamicMemPoolBestFit::~DynamicMemPoolBestFit() { - global_mem_block_list_.clear(); - global_idle_mem_buf_map_.clear(); -} - -DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { - size_t align_size = AlignMemorySize(size); - // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf. - DeviceMemPtr device_addr = FindIdleMemBuf(align_size); - if (!device_addr) { - device_addr = AddMemBlockAndMemBuf(align_size); - } - return device_addr; -} - -std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, - std::vector size_list) { - std::vector device_addr_list; - // Pre-alloc the one whole piece memory. - auto device_addr = AllocTensorMem(total_size); - if (!device_addr) { - return device_addr_list; - } - // Remove the pre-alloc memory. - auto mem_block = FindMemBlock(device_addr); - MS_EXCEPTION_IF_NULL(mem_block); - auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); - if (iter == mem_block->block_all_mem_buf_map_.end()) { - MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; - } - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - auto rest_size = mem_buf->size_ - total_size; - (void)mem_block->block_all_mem_buf_map_.erase(iter); - // Split the pre-alloc memory into continuous memory by the size list. - DynamicMemBufPtr continuous_mem_buf; - auto buf_addr = device_addr; - for (size_t i = 0; i < size_list.size(); i++) { - continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); - (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); - device_addr_list.emplace_back(buf_addr); - buf_addr = AddressOffset(buf_addr, size_list[i]); - } - // Update the size of the last memory buf. 
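// AlignMemorySize in this file rounds every request up to DYNAMIC_MEM_ALIGN_SIZE
// (the constant's value lives in the deleted header, not shown in this hunk).
// A standalone restatement of the rounding with the constant kept symbolic:
#include <cstddef>

constexpr size_t AlignUp(size_t size, size_t align) {
  return size == 0 ? align : ((size + align - 1) / align) * align;
}
// e.g. for align = 512 (illustrative value): AlignUp(1, 512) == 512,
// AlignUp(512, 512) == 512, AlignUp(513, 512) == 1024.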
- continuous_mem_buf->size_ += rest_size; - return device_addr_list; -} - -size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { - if (size == 0) { - return DYNAMIC_MEM_ALIGN_SIZE; - } - return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE; -} - -DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) { - auto iter = global_idle_mem_buf_map_.lower_bound(size); - if (iter != global_idle_mem_buf_map_.end()) { - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - if (mem_buf->status_ != kMemBufIdle) { - MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_ - << "] mem_buf_address[" << mem_buf->device_addr_ << "]."; - } - mem_buf->status_ = kMemBufUsed; - // Remove map of old idle memory buf - (void)global_idle_mem_buf_map_.erase(iter); - // Divide memory buf - if (IsDivide(size, mem_buf->size_)) { - DivideMemBuf(size, mem_buf); - } - // Memory statistics - total_used_mem_statistics_ += mem_buf->size_; - if (total_used_mem_statistics_ > used_mem_peak_statistics_) { - used_mem_peak_statistics_ = total_used_mem_statistics_; - } - return mem_buf->device_addr_; - } - return nullptr; -} - -DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) { - size_t alloc_mem_size = CalMemBlockAllocSize(size); - if (alloc_mem_size == 0) { - return nullptr; - } - // Add new memory block - DeviceMemPtr device_addr = nullptr; - auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr); - if (real_alloc_size < size) { - MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size - << "]."; - return nullptr; - } - auto mem_block = std::make_shared(device_addr, real_alloc_size); - MS_EXCEPTION_IF_NULL(mem_block); - auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); - (void)global_mem_block_list_.insert(iter, mem_block); - // Add new memory buf - auto mem_buf = std::make_shared(device_addr, kMemBufUsed, real_alloc_size); - MS_EXCEPTION_IF_NULL(mem_buf); - // Add map of new memory buf in the block - (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf); - // Divide memory buf - if (IsDivide(size, mem_buf->size_)) { - DivideMemBuf(size, mem_buf); - } - // Memory statistics - total_mem_statistics_ += real_alloc_size; - total_used_mem_statistics_ += mem_buf->size_; - if (total_used_mem_statistics_ > used_mem_peak_statistics_) { - used_mem_peak_statistics_ = total_used_mem_statistics_; - } - return mem_buf->device_addr_; -} - -size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) { - auto device_free_mem_size = free_mem_size(); - if (device_free_mem_size < size) { - MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size - << "] is smaller than required size[" << size << "]."; - return 0; - } - auto alloc_mem_size = mem_alloc_unit_size(); - // Growing at twice of alloc size - while (alloc_mem_size < size) { - alloc_mem_size = alloc_mem_size * 2; - } - alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size); - return alloc_mem_size; -} - -bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const { - return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; -} - -void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { - MS_EXCEPTION_IF_NULL(mem_buf); - auto mem_block = FindMemBlock(mem_buf->device_addr_); - 
MS_EXCEPTION_IF_NULL(mem_block); - // Divide new memory buf - size_t newbuf_size = mem_buf->size_ - size; - mem_buf->size_ = size; - DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size); - auto new_mem_buf = std::make_shared(newbuf_addr, kMemBufIdle, newbuf_size); - // Add map of new memory buf in the block - (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf); - // Add map of new idle memory buf - (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf); -} - -bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block) { - MS_EXCEPTION_IF_NULL(device_addr); - MS_EXCEPTION_IF_NULL(mem_block); - return device_addr < mem_block->device_addr(); -} - -DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock); - if (iter != global_mem_block_list_.begin()) { - return *(--iter); - } - return nullptr; -} - -void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto mem_block = FindMemBlock(device_addr); - if (mem_block == nullptr) { - MS_LOG(WARNING) << "Can't find the mem_block of the device address[" << device_addr << "]."; - return; - } - CombineMemBuf(mem_block, device_addr); -} - -void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(mem_block); - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); - if (iter == mem_block->block_all_mem_buf_map_.end()) { - MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; - } - auto mem_buf = iter->second; - MS_EXCEPTION_IF_NULL(mem_buf); - if (mem_buf->status_ != kMemBufUsed) { - MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "]."; - } - mem_buf->status_ = kMemBufIdle; - total_used_mem_statistics_ -= mem_buf->size_; - // Combine backward(combine the next_mem_buf to mem_buf) - auto next_iter = iter; - (void)next_iter++; - if (next_iter != mem_block->block_all_mem_buf_map_.end()) { - auto next_mem_buf = next_iter->second; - MS_EXCEPTION_IF_NULL(next_mem_buf); - if (next_mem_buf->status_ == kMemBufIdle) { - mem_buf->size_ += next_mem_buf->size_; - EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_); - (void)mem_block->block_all_mem_buf_map_.erase(next_iter); - } - } - // Combine forward(combine the mem_buf to prev_mem_buf) - bool forward_combine = false; - DynamicMemBufPtr prev_mem_buf; - if (iter != mem_block->block_all_mem_buf_map_.begin()) { - auto prev_iter = iter; - (void)prev_iter--; - prev_mem_buf = prev_iter->second; - MS_EXCEPTION_IF_NULL(prev_mem_buf); - if (prev_mem_buf->status_ == kMemBufIdle) { - EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_); - prev_mem_buf->size_ += mem_buf->size_; - (void)mem_block->block_all_mem_buf_map_.erase(iter); - forward_combine = true; - } - } - // Add map of new idle memory - if (forward_combine) { - (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf); - } else { - (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf); - } -} - -void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr) { - MS_EXCEPTION_IF_NULL(device_addr); - auto iter = global_idle_mem_buf_map_.equal_range(size); - 
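// The equal_range() walk that continues below is needed because the idle map is
// keyed by size and several idle bufs can share one size; only the entry whose
// device address matches may be erased. A standalone multimap demonstration of
// that pattern (buffer names are placeholders):
#include <cstddef>
#include <cstdio>
#include <map>
#include <string_view>

int main() {
  std::multimap<std::size_t, const char *> idle{{512, "bufA"}, {512, "bufB"}};
  auto range = idle.equal_range(512);  // all idle bufs of this size
  for (auto it = range.first; it != range.second; ++it) {
    if (std::string_view(it->second) == "bufB") {  // match by identity, not size
      idle.erase(it);
      break;
    }
  }
  std::printf("idle entries left: %zu\n", idle.size());
  return 0;
}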
-bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block) {
- MS_EXCEPTION_IF_NULL(device_addr);
- MS_EXCEPTION_IF_NULL(mem_block);
- return device_addr < mem_block->device_addr();
-}
-
-DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr device_addr) {
- MS_EXCEPTION_IF_NULL(device_addr);
- auto iter = std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
- if (iter != global_mem_block_list_.begin()) {
- return *(--iter);
- }
- return nullptr;
-}
-
-void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) {
- MS_EXCEPTION_IF_NULL(device_addr);
- auto mem_block = FindMemBlock(device_addr);
- if (mem_block == nullptr) {
- MS_LOG(WARNING) << "Can't find the mem_block of the device address[" << device_addr << "].";
- return;
- }
- CombineMemBuf(mem_block, device_addr);
-}
-
-void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) {
- MS_EXCEPTION_IF_NULL(mem_block);
- MS_EXCEPTION_IF_NULL(device_addr);
- auto iter = mem_block->block_all_mem_buf_map_.find(device_addr);
- if (iter == mem_block->block_all_mem_buf_map_.end()) {
- MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "].";
- }
- auto mem_buf = iter->second;
- MS_EXCEPTION_IF_NULL(mem_buf);
- if (mem_buf->status_ != kMemBufUsed) {
- MS_LOG(EXCEPTION) << "Found mem_buf that is not in use, mem_buf_address[" << mem_buf->device_addr_ << "].";
- }
- mem_buf->status_ = kMemBufIdle;
- total_used_mem_statistics_ -= mem_buf->size_;
- // Combine backward (merge next_mem_buf into mem_buf)
- auto next_iter = iter;
- (void)next_iter++;
- if (next_iter != mem_block->block_all_mem_buf_map_.end()) {
- auto next_mem_buf = next_iter->second;
- MS_EXCEPTION_IF_NULL(next_mem_buf);
- if (next_mem_buf->status_ == kMemBufIdle) {
- mem_buf->size_ += next_mem_buf->size_;
- EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_);
- (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
- }
- }
- // Combine forward (merge mem_buf into prev_mem_buf)
- bool forward_combine = false;
- DynamicMemBufPtr prev_mem_buf;
- if (iter != mem_block->block_all_mem_buf_map_.begin()) {
- auto prev_iter = iter;
- (void)prev_iter--;
- prev_mem_buf = prev_iter->second;
- MS_EXCEPTION_IF_NULL(prev_mem_buf);
- if (prev_mem_buf->status_ == kMemBufIdle) {
- EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_);
- prev_mem_buf->size_ += mem_buf->size_;
- (void)mem_block->block_all_mem_buf_map_.erase(iter);
- forward_combine = true;
- }
- }
- // Add map of new idle memory
- if (forward_combine) {
- (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
- } else {
- (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
- }
-}
-
-void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr) {
- MS_EXCEPTION_IF_NULL(device_addr);
- auto iter = global_idle_mem_buf_map_.equal_range(size);
- while (iter.first != iter.second) {
- MS_EXCEPTION_IF_NULL(iter.first->second);
- // Remove map of the idle memory buf by size and device address
- if (iter.first->second->device_addr_ == device_addr) {
- (void)global_idle_mem_buf_map_.erase(iter.first);
- return;
- }
- (void)iter.first++;
- }
- MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf.";
-}
-
-void DynamicMemPoolBestFit::ReleaseDeviceRes() {
- MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is "
- << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << ".";
- for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
- auto device_addr = (*iter)->device_addr();
- if (device_addr != nullptr) {
- if (!FreeDeviceMem(device_addr)) {
- MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
- }
- }
- }
-}
-
-void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
- MS_LOG(INFO) << "Start dump dynamic memory pool info.";
- DeviceAddrMapMemBuf mem_block_map;
- DynamicMemBufPtr mem_buf;
- size_t total_mem = 0;
- size_t total_used_mem = 0;
- size_t total_idle_mem1 = 0;
- size_t total_idle_mem2 = 0;
- // Dump the memory block info and memory buf info
- MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "].";
- for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
- total_mem += (*iter)->size();
- mem_block_map = (*iter)->block_all_mem_buf_map_;
- MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts["
- << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
- << (*iter)->size() << "].";
- for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) {
- mem_buf = iter_mem_buf->second;
- MS_EXCEPTION_IF_NULL(mem_buf);
- if (mem_buf->status_ == kMemBufIdle) {
- total_idle_mem1 += mem_buf->size_;
- } else {
- total_used_mem += mem_buf->size_;
- }
- MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
- << mem_buf->status_ << "].";
- }
- }
- // Dump all the idle memory buf info
- MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "].";
- for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) {
- mem_buf = iter_idle->second;
- MS_EXCEPTION_IF_NULL(mem_buf);
- total_idle_mem2 += mem_buf->size_;
- MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status["
- << mem_buf->status_ << "].";
- }
- // Dump the memory statistics
- MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory["
- << total_idle_mem1 << "].";
- if (total_idle_mem1 != total_idle_mem2) {
- MS_LOG(ERROR) << "Check error: the idle memory in the mem_blocks is not equal to the global idle memory.";
- }
- if (total_mem != total_used_mem + total_idle_mem1) {
- MS_LOG(ERROR) << "Check error: the total memory is not equal to the sum of used memory and idle memory.";
- }
- MS_LOG(INFO) << "Finish dump dynamic memory pool info.";
-}
-} // namespace device
-} // namespace mindspore
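Both FindIdleMemBuf and EraseIdleMemBuf above lean on the idle list being a std::multimap keyed by size: allocation is a lower_bound lookup for the smallest sufficient buffer, and removal must match size and address because several idle buffers can share a size. A self-contained sketch of that pattern (TakeBestFit and EraseBySizeAndAddr are illustrative names, not functions from the patch):

    #include <cassert>
    #include <cstddef>
    #include <map>

    // void* stands in for DeviceMemPtr; sizes key the free list.
    using FreeList = std::multimap<std::size_t, void *>;

    // Best-fit take: the smallest idle buffer whose size is >= the request.
    void *TakeBestFit(FreeList *fl, std::size_t size) {
      auto it = fl->lower_bound(size);
      if (it == fl->end()) return nullptr;
      void *addr = it->second;
      fl->erase(it);
      return addr;
    }

    // Remove one entry matching both size and address (mirrors EraseIdleMemBuf).
    bool EraseBySizeAndAddr(FreeList *fl, std::size_t size, void *addr) {
      auto range = fl->equal_range(size);  // all idle buffers of this size
      for (auto it = range.first; it != range.second; ++it) {
        if (it->second == addr) {
          fl->erase(it);
          return true;
        }
      }
      return false;  // corresponds to the ERROR log above
    }

    int main() {
      int a = 0, b = 0;
      FreeList fl{{512, &a}, {1024, &b}};
      assert(TakeBestFit(&fl, 600) == &b);  // 1024 is the smallest fit
      assert(EraseBySizeAndAddr(&fl, 512, &a));
    }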
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
deleted file mode 100644
index e050f3d590..0000000000
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
+++ /dev/null
@@ -1,436 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "pre_activate/mem_reuse/mem_reuse.h"
-#include
-#include
-#include "pre_activate/mem_reuse/mem_reuse_checker.h"
-#include "pre_activate/common/helper.h"
-
-namespace mindspore {
-namespace memreuse {
-bool MemReuseUtil::InitDynamicOutputKernelRef() {
- int index = util_index_;
- auto kernel_cnodes = graph_->execution_order();
- if (kernel_cnodes.empty()) {
- return true;
- }
- int kernel_out_ref_num = 0;
- for (auto &kernel_cnode : kernel_cnodes) {
-#ifdef MEM_REUSE_DEBUG
- MemReuseChecker::GetInstance().CheckSignalOps(kernel_cnode);
-#endif
- if (kernel_cnode == nullptr) {
- return false;
- }
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode);
- if (kernel_mod == nullptr) {
- return false;
- }
- auto key = kernel_cnode.get();
- // for every apply kernel, set a new output
- auto iter = kernel_output_refs_.find(key);
- if (iter == kernel_output_refs_.end()) {
- auto output_sizes = kernel_mod->GetOutputSizeList();
- KernelRefCountPtrList kernel_refs;
- for (auto size : output_sizes) {
- total_dy_size_ += size;
- // do not MallocDynamicMem here, just record it
- KernelRefCountPtr kernel_ref = std::make_shared<KernelRefCount>();
- index++;
- auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
- kernel_ref->stream_id_ = curr_stream_id;
- kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
- kernel_refs.push_back(kernel_ref);
- kernel_out_ref_num++;
- total_refs_list_.push_back(kernel_ref);
- }
- if (!kernel_refs.empty()) {
- kernel_output_refs_[key] = kernel_refs;
- }
- }
- }
- return true;
-}
-
-bool MemReuseUtil::InitDynamicWorkspaceKernelRef() {
- int WkIndex = util_index_;
- auto kernel_cnodes = graph_->execution_order();
- if (kernel_cnodes.empty()) {
- return true;
- }
- for (auto &kernel_cnode : kernel_cnodes) {
- if (kernel_cnode == nullptr) {
- return false;
- }
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel_cnode);
- if (kernel_mod == nullptr) {
- return false;
- }
- auto key = kernel_cnode.get();
- auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
- KernelRefCountPtrList workspace_kernel_refs;
- for (auto size : workspace_sizes) {
- total_workspace_size_ += size;
- ++WkIndex;
- KernelRefCountPtr workspace_ref = std::make_shared<KernelRefCount>();
- workspace_ref->SetKernelRefCountInfo(WkIndex, size, kDynamicRefCount);
- workspace_kernel_refs.push_back(workspace_ref);
- // total workspace ref list
- total_wk_ref_list_.push_back(workspace_ref);
- }
- if (!workspace_kernel_refs.empty()) {
- // every key indexes its workspace refs
- kernel_workspace_refs_[key] = workspace_kernel_refs;
- }
- }
- return true;
-}
-
-bool MemReuseUtil::InitDynamicKernelRef(const KernelGraph *graph) {
- MS_EXCEPTION_IF_NULL(graph);
- graph_ = graph;
- is_all_nop_node_ = opt::IsAllNopNode(graph);
- if (!InitDynamicOutputKernelRef()) {
- MS_LOG(INFO) << "InitDynamicOutputKernelRef failed";
- return false;
- }
- if (!InitDynamicWorkspaceKernelRef()) {
- MS_LOG(INFO) << "InitDynamicWorkspaceKernelRef failed";
- return false;
- }
- return true;
-}
-
-// set the longest workspace list && the largest workspace sizes
-void MemReuseUtil::SetWorkSpaceList() {
- int max_list_size = 0;
- std::vector<size_t> total_sizes;
- std::vector<size_t> max_list;
- auto kernel_cnodes = graph_->execution_order();
- for (auto &kernel_cnode : kernel_cnodes) {
- MS_EXCEPTION_IF_NULL(kernel_cnode);
- auto cnode_key = kernel_cnode.get();
- auto cnode_iter = kernel_workspace_refs_.find(cnode_key);
- if (cnode_iter != kernel_workspace_refs_.end()) {
- auto kernel_refs = cnode_iter->second;
- std::vector<size_t> current_list;
- for (size_t i = 0; i < kernel_refs.size(); ++i) {
- auto size = kernel_refs[i]->size_;
- current_list.push_back(size);
- }
- if (max_list_size < SizeToInt(current_list.size())) {
- max_list_size = SizeToInt(current_list.size());
- }
- (void)std::copy(current_list.begin(), current_list.end(), std::back_inserter(total_sizes));
- }
- }
- sort(total_sizes.rbegin(), total_sizes.rend());
- max_list.resize(IntToSize(max_list_size));
- if (SizeToInt(total_sizes.size()) < max_list_size) {
- MS_LOG(EXCEPTION) << "total workspace size is less than the required max list size";
- }
- max_list.assign(total_sizes.begin(), total_sizes.begin() + max_list_size);
- for (auto &ma : max_list) {
- total_reuseworkspace_size_ += ma;
- }
- max_workspace_size_ = max_list_size;
- max_workspace_list_ = max_list;
-}
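SetWorkSpaceList builds one shared workspace arena: it is as long as the longest per-kernel workspace list, and its entries are the largest sizes seen anywhere, so any single kernel's workspaces fit into a prefix of it. A compact sketch of the same selection (SharedWorkspace is an illustrative name, not from the patch):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Gather every kernel's workspace sizes, sort descending, and keep as many
    // entries as the longest per-kernel list.
    std::vector<std::size_t> SharedWorkspace(
        const std::vector<std::vector<std::size_t>> &per_kernel) {
      std::size_t max_len = 0;
      std::vector<std::size_t> all;
      for (const auto &ws : per_kernel) {
        max_len = std::max(max_len, ws.size());
        all.insert(all.end(), ws.begin(), ws.end());
      }
      std::sort(all.rbegin(), all.rend());  // descending
      all.resize(max_len);                  // keep the max_len largest sizes
      return all;
    }

    int main() {
      auto ws = SharedWorkspace({{100, 300}, {200}, {50, 400, 10}});
      assert(ws == (std::vector<std::size_t>{400, 300, 200}));
    }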
-
-void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
- MS_EXCEPTION_IF_NULL(kernel);
- MS_EXCEPTION_IF_NULL(kernel_def_ptr);
- auto key = kernel.get();
- for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
- auto ref_ptr = GetKernelInputRef(kernel, i);
- if (ref_ptr != nullptr) {
- if (ref_ptr->reftype() == kStaticRefCount) {
- continue;
- } else if (ref_ptr->reftype() == kDynamicRefCount) {
- auto iter = kernel_def_ptr->inputs_.find(key);
- if (iter == kernel_def_ptr->inputs_.end()) {
- kernel_def_ptr->inputs_[key].push_back(ref_ptr);
- } else {
- iter->second.push_back(ref_ptr);
- }
- }
- }
- }
-}
-
-void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
- MS_EXCEPTION_IF_NULL(kernel);
- MS_EXCEPTION_IF_NULL(kernel_def_ptr);
- auto key = kernel.get();
- auto iter = kernel_def_ptr->outputs_.find(key);
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
- MS_EXCEPTION_IF_NULL(kernel_mod);
- for (size_t k = 0; k < kernel_mod->GetOutputSizeList().size(); ++k) {
- KernelRefCountPtr kernel_ref = kernel_output_refs_[key][k];
- if (iter == kernel_def_ptr->outputs_.end()) {
- kernel_def_ptr->outputs_[key].push_back(kernel_ref);
- } else {
- iter->second.push_back(kernel_ref);
- }
- }
-}
-
-void MemReuseUtil::SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr) {
- MS_EXCEPTION_IF_NULL(kernel);
- MS_EXCEPTION_IF_NULL(kernel_def_ptr);
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
- MS_EXCEPTION_IF_NULL(kernel_mod);
- auto key = kernel.get();
- for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
- if (kernel_workspace_refs_.find(key) != kernel_workspace_refs_.end()) {
- auto wk_refs = kernel_workspace_refs_[key];
- if (i < wk_refs.size()) {
- auto wk_ref = wk_refs[i];
- kernel_def_ptr->wk_space_[key].push_back(wk_ref);
- } else {
- MS_LOG(EXCEPTION) << "current index: " << i << " is larger than wk_refs size " << wk_refs.size();
- }
- } else {
- MS_LOG(EXCEPTION) << "kernel_workspace_refs_ init error";
- }
- }
-}
-
-KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) {
- if (node == nullptr) {
- MS_LOG(EXCEPTION) << "The node pointer is a nullptr.";
- }
- if (node->isa<CNode>()) {
- auto ak_node = node->cast<CNodePtr>();
- auto key = ak_node.get();
- MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx));
- return kernel_output_refs_[key][IntToSize(output_idx)];
- }
- return nullptr;
-}
-
-KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t input_idx) {
- if (input_idx >= AnfAlgo::GetInputTensorNum(kernel)) {
- MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number "
- << AnfAlgo::GetInputTensorNum(kernel);
- }
- auto input_node = kernel->input(input_idx + 1);
- // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
- session::KernelWithIndex kernel_input;
- if (is_all_nop_node_) {
- // The graph does not remove the nop node.
- kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
- } else {
- // The graph removes the nop node.
- kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
- }
- if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
- MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
- }
- auto result = GetRef(kernel_input.first, SizeToInt(kernel_input.second));
- return result;
-}
-
-void MemReuseUtil::SetKernelDefMap() {
- auto kernel_cnodes = graph_->execution_order();
- for (auto &kernel : kernel_cnodes) {
- KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
- kernel_def_ptr->set_kernel_name(AnfAlgo::GetCNodeName(kernel));
- kernel_def_ptr->set_scope_full_name(kernel->fullname_with_scope());
- kernel_def_ptr->set_stream_id(AnfAlgo::GetStreamId(kernel));
- SetInputMap(kernel, kernel_def_ptr.get());
- SetOutputMap(kernel, kernel_def_ptr.get());
- SetWkMap(kernel, kernel_def_ptr.get());
- auto key = kernel.get();
- kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
- kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
- kernel_def_ptr_list_.push_back(kernel_def_ptr);
- kernel_map_[key] = kernel_def_ptr;
- }
- SetKernelDefInputs();
-}
-
-void MemReuseUtil::SetKernelDefInputs() {
- for (const auto &kernel : graph_->execution_order()) {
- MS_EXCEPTION_IF_NULL(kernel);
- auto key = kernel.get();
- // find kernel_def according to cnode addr
- auto iter = kernel_map_.find(key);
- if (iter == kernel_map_.end()) {
- MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init.";
- }
- auto kernel_def = iter->second;
- for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
- auto ref_ptr = GetKernelInputRef(kernel, i);
- if (ref_ptr != nullptr) {
- // set the inputs of this kernel_def
- auto input_node = AnfAlgo::GetInputNode(kernel, i);
- // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
- session::KernelWithIndex input;
- if (is_all_nop_node_) {
- // The graph does not remove the nop node.
- input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
- } else {
- // The graph removes the nop node.
- input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); - } - if (IsPrimitive(input.first, prim::kPrimMakeTuple)) { - MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple"; - } - auto input_key = (input.first).get(); - auto input_iter = kernel_map_.find(input_key); - if (input_iter == kernel_map_.end()) { - MS_LOG(EXCEPTION) << "kernel [" << (input.first)->fullname_with_scope() << "] is not init."; - } - kernel_def->InsertInputKernel(input_iter->second); - } - } - } -} - -void MemReuseUtil::SetReuseRefCount() { - auto kernels = graph_->execution_order(); - for (auto &kernel : kernels) { - auto key = kernel.get(); - for (auto &def : kernel_def_ptr_list_) { - auto iter = def->inputs_.find(key); - if (iter != def->inputs_.end()) { - for (auto &input : iter->second) { - input->ref_count_++; - input->ref_count_dynamic_use_++; - } - } - } - } -} - -void MemReuseUtil::SetSummaryNodesRefCount() { - bool summary_exist = graph_->summary_node_exist(); - if (!summary_exist) { - return; - } - - auto summary_nodes = graph_->summary_nodes(); - if (summary_nodes.empty()) { - return; - } - - size_t total_summary_size = 0; - for (auto &node_item : summary_nodes) { - auto node = node_item.second.first; - size_t index = IntToSize(node_item.second.second); - if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) { - KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index]; - kernel_ref->ref_count_ = kMaxRefCount; - kernel_ref->ref_count_dynamic_use_ = kMaxRefCount; - total_summary_size += kernel_ref->size_; - MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index; - } else { - MS_LOG(WARNING) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index; - } - } -#ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); -#endif - MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size; -} - -void MemReuseUtil::SetGraphOutputRefCount() { - auto nodes = AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); - for (const auto &node : nodes) { - session::KernelWithIndex kernel_input; - if (is_all_nop_node_) { - // The graph does not remove the nop node. - kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); - } else { - // The graph removes the nop node. 
- kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - } - MS_EXCEPTION_IF_NULL(kernel_input.first); - if (!kernel_input.first->isa() || !AnfAlgo::IsRealKernel(kernel_input.first)) { - continue; - } - auto ak_node = kernel_input.first->cast(); - auto key = ak_node.get(); - auto iter = kernel_output_refs_.find(key); - if ((iter != kernel_output_refs_.end()) && (kernel_input.second < iter->second.size())) { - auto kernel_ref_count_ptr = kernel_output_refs_[key][kernel_input.second]; - MS_EXCEPTION_IF_NULL(kernel_ref_count_ptr); - kernel_ref_count_ptr->ref_count_ = kMaxRefCount; - kernel_ref_count_ptr->ref_count_dynamic_use_ = kMaxRefCount; - } - } -#ifdef MEM_REUSE_DEBUG - auto graph = *graph_; - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph); -#endif -} - -void MemReuseUtil::ResetDynamicUsedRefCount() { - for (auto iter = kernel_output_refs_.begin(); iter != kernel_output_refs_.end(); ++iter) { - for (auto &ref_count : iter->second) { - MS_EXCEPTION_IF_NULL(ref_count); - ref_count->ref_count_dynamic_use_ = ref_count->ref_count_; - } - } -} - -void MemReuseUtil::SetAllInfo(KernelGraph *graph) { - if (!InitDynamicKernelRef(graph)) { - MS_LOG(EXCEPTION) << "Init ReuseAssignDynamicMemory Fault"; - } - SetKernelDefMap(); - SetReuseRefCount(); - SetSummaryNodesRefCount(); - SetWorkSpaceList(); -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); -#endif -} - -uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { - auto key = node.get(); - auto iter = kernel_output_refs_.find(key); - uint8_t *ptr = nullptr; - if (iter != kernel_output_refs_.end()) { - if (index >= iter->second.size()) { - MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; - } - auto output_ref = iter->second[index]; - ptr = mem_base_ + output_ref->offset_; - } else { - MS_LOG(EXCEPTION) << "node [" << AnfAlgo::GetCNodeName(node) << "] don't exist in kernel_output_refs"; - } - return ptr; -} - -uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const { - auto key = node.get(); - auto iter = kernel_workspace_refs_.find(key); - uint8_t *ptr = nullptr; - if (iter != kernel_workspace_refs_.end()) { - if (index >= iter->second.size()) { - MS_LOG(EXCEPTION) << "index:[" << index << "] is larger than it's workspace size:[" << iter->second.size() << "]"; - } - auto wk_ref = iter->second[index]; - ptr = mem_base_ + wk_ref->offset_; - } - return ptr; -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h deleted file mode 100644 index 37281a7128..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.h +++ /dev/null @@ -1,107 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ -#include -#include -#include -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "kernel/tbe/tbe_utils.h" -using mindspore::kernel::tbe::TbeUtils; -namespace mindspore { -namespace memreuse { -static constexpr int kMaxRefCount = 9999; -static constexpr size_t kDefaultMemAlignSize = 512; -static constexpr size_t kAttAlignSize = 31; -static constexpr int kInvalidIndex = -2; - -using KernelDefPtrMaps = std::vector; -using KernelRefs = std::map; - -using KernelGraph = mindspore::session::KernelGraph; - -class MemReuseUtil { - public: - KernelRefs kernel_output_refs_; - KernelRefCountPtrList total_refs_list_; - KernelRefCountPtrList total_wk_ref_list_; - KernelRefs kernel_workspace_refs_; - MemReuseUtil() : util_index_(kInitIndex), graph_(nullptr), is_all_nop_node_(false) {} - ~MemReuseUtil() { - if (graph_ != nullptr) { - graph_ = nullptr; - } - MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_; - MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_; - MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; - } - - void SetAllInfo(KernelGraph *graph); - bool InitDynamicOutputKernelRef(); - bool InitDynamicWorkspaceKernelRef(); - bool InitDynamicKernelRef(const KernelGraph *graph); - void SetWorkSpaceList(); - void SetKernelDefMap(); - void SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetWkMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr); - void SetKernelDefInputs(); - void SetReuseRefCount(); - void SetSummaryNodesRefCount(); - // Set the reference count of graph output specially. - void SetGraphOutputRefCount(); - // Reset the dynamic used reference count by ref_count_. 
- void ResetDynamicUsedRefCount(); - - KernelRefCountPtr GetRef(const AnfNodePtr &node, int output_idx); - KernelRefCountPtr GetKernelInputRef(const CNodePtr &kernel, size_t input_idx); - KernelRefCountPtrList total_refs_list() const { return total_refs_list_; } - KernelRefCountPtrList total_wk_ref_list() const { return total_wk_ref_list_; } - KernelDefPtrMaps kernel_def_ptr_list() const { return kernel_def_ptr_list_; } - int max_workspace_size() const { return max_workspace_size_; } - std::vector max_workspace_list() const { return max_workspace_list_; } - void set_total_refs_list(const KernelRefCountPtrList &total_refs_list) { total_refs_list_ = total_refs_list; } - void set_kernel_def_ptr_list(const KernelDefPtrMaps &kernel_def_ptr_list) { - kernel_def_ptr_list_ = kernel_def_ptr_list; - } - void set_mem_base(uint8_t *mem_base) { mem_base_ = mem_base; } - uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; - uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; - - private: - int util_index_; - const KernelGraph *graph_; - bool is_all_nop_node_; - KernelRefCountPtrList ref_list_; - KernelDefPtrMaps kernel_def_ptr_list_; - KernelRefCountPtrList last_ref_list_; - int max_workspace_size_ = 0; - std::vector max_workspace_list_; - size_t total_dy_size_ = 0; - size_t total_workspace_size_ = 0; - size_t total_reuseworkspace_size_ = 0; - uint8_t *mem_base_{nullptr}; - // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef - std::map kernel_map_; -}; -using MemReuseUtilPtr = std::shared_ptr; -} // namespace memreuse -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc deleted file mode 100644 index c50cb4b021..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc +++ /dev/null @@ -1,411 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "pre_activate/mem_reuse/mem_reuse_checker.h" -#ifdef ENABLE_D -#include "device/ascend/ascend_stream_assign.h" -#endif - -namespace mindspore { -namespace memreuse { -void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - set_tensor_ptr_list(mem_reuse_util_ptr->total_refs_list()); - set_workspace_ptr_list(mem_reuse_util_ptr->total_wk_ref_list()); - set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); - // check info Correctness - for (auto &tensor : tensor_ptr_list_) { - tensor->size_ = AlignMemorySize(tensor->size_); - } - // align wk size to 512 && refcount == 1 - for (auto &wk : wk_tensor_list_) { - wk->size_ = AlignMemorySize(wk->size_); - wk->ref_count_ = 1; - } -#ifdef ENABLE_D - stream_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group(); -#endif -} - -void BestFitMemReuse::InitKernelDependence() { - for (const auto &kernel : op_ptr_list_) { - std::set front; - std::queue to_visit; - to_visit.push(kernel); - // find all kernels before current kernel - while (!to_visit.empty()) { - auto curr = to_visit.front(); - to_visit.pop(); - if (front.count(curr)) { - continue; - } - front.insert(curr); - auto iter = kernel_front_map_.find(curr); - if (iter != kernel_front_map_.end()) { - auto visited_front = iter->second; - front.insert(visited_front.begin(), visited_front.end()); - continue; - } - for (const auto &input : curr->input_kernels()) { - to_visit.push(input); - } - } - kernel_front_map_[kernel] = front; - } -} - -bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf) { - // determine whether the kernel_curr can reuse kernel_prev's output tensor membuf - MS_EXCEPTION_IF_NULL(kernel_curr); - MS_EXCEPTION_IF_NULL(mem_buf); - auto kernel_prev = mem_buf->used_kernel_; - MS_EXCEPTION_IF_NULL(kernel_prev); - auto curr_stream_id = kernel_curr->stream_id(); - auto prev_stream_id = kernel_prev->stream_id(); - if (curr_stream_id == prev_stream_id) { - mem_buf->type_ = IN_STREAM_REUSE; - return true; - } - - bool reuse_between_streams = true; - for (auto &stream_group : stream_groups_) { - size_t cur_index = UINT32_MAX; - size_t prev_index = UINT32_MAX; - for (size_t index = 0; index < stream_group.size(); index++) { - if (curr_stream_id == stream_group[index]) { - cur_index = index; - continue; - } - if (prev_stream_id == stream_group[index]) { - prev_index = index; - continue; - } - } - if ((prev_index != UINT32_MAX) && (cur_index == UINT32_MAX || (prev_index > cur_index))) { - // previous stream and current stream are not in the same group can't be reused - // previous stream is behind current stream can't be reused - reuse_between_streams = false; - break; - } - } - - if (reuse_between_streams) { - mem_buf->type_ = BETWEEN_STREAMS_REUSE; - return true; - } - - auto iter = kernel_front_map_.find(kernel_curr); - if (iter == kernel_front_map_.end()) { - MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not init."; - } - auto kernel_curr_front = iter->second; - auto depend_count = kernel_curr_front.count(kernel_prev); - if (depend_count) { - mem_buf->type_ = KERNEL_DEPENDENCE_REUSE; - return true; - } - - return false; -} - -void BestFitMemReuse::AssignNodeOutputOffset() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { - size_t index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[index]; 
- MS_EXCEPTION_IF_NULL(tensor_desc); - auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); - if (!reusable_membuf_map.empty()) { - auto membuf_index = reusable_membuf_map.begin()->second; - // find the best suitable membuf in membuf list, and reuse it - ReuseExistMembuf(tensor_desc.get(), membuf_index, kDynamicMem); - } else { - // no membuf can reuse, add new membuf after the membuf_ptr_list - AddNewMembufPtr(tensor_desc.get(), kDynamicMem); -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; -#endif - } - } -} - -void BestFitMemReuse::AssignNodeWorkspaceOffset() { - for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) { - size_t index = GetWorkspaceIndex(wk_idx); - auto wk_ref = wk_tensor_list_[index]; - MS_EXCEPTION_IF_NULL(wk_ref); - auto re_wk_membuf_map = GetReusableMembufMap(wk_ref->size_); - if (!re_wk_membuf_map.empty()) { - auto membuf_index = re_wk_membuf_map.begin()->second; - ReuseExistMembuf(wk_ref.get(), membuf_index, kWorkspaceMem); - } else { - AddNewMembufPtr(wk_ref.get(), kWorkspaceMem); - } - } -} - -void BestFitMemReuse::ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - CheckMembufIndx(membuf_index); - auto membuf = membuf_ptr_list_[membuf_index]; - MS_EXCEPTION_IF_NULL(membuf); - // first to split && then update membuf_info - if (IsSplit(tensor_desc->size_, membuf->size_)) { - // split the membuf, and insert a new membuf after this membuf - SplitMembuf(tensor_desc, membuf_index); - } - // update membuf status, and set tensor offset - UpdateMembufInfo(tensor_desc, membuf.get(), flag); -} - -std::map BestFitMemReuse::GetReusableMembufMap(size_t tensor_size) { - std::map size_map; - for (size_t i = 0; i < membuf_ptr_list_.size(); ++i) { - auto membuf = membuf_ptr_list_[i]; - auto index = i; - bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size; - if (is_membuf_ok && IsUsable(current_kernel_, membuf)) { - (void)size_map.insert(std::make_pair(membuf->size_, index)); - break; - } - } - return size_map; -} - -void BestFitMemReuse::UpdateMembufInfo(KernelRefCount *tensor_desc, Membuf *membuf, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - MS_EXCEPTION_IF_NULL(membuf); - auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag); - membuf->status_ = kReused; - membuf->index_ = real_index; - membuf->used_kernel_ = current_kernel_; - tensor_desc->offset_ = membuf->offset_; -} - -bool BestFitMemReuse::IsSplit(size_t tensor_size, size_t membuf_size) const { return tensor_size < membuf_size; } - -void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index) { - MS_EXCEPTION_IF_NULL(tensor_desc); - CheckMembufIndx(membuf_index); - auto membuf = membuf_ptr_list_[membuf_index]; - MS_EXCEPTION_IF_NULL(membuf); - auto bias = membuf->size_ - tensor_desc->size_; - membuf->size_ = tensor_desc->size_; - // to check if spilt membuf can be merge - auto new_membuf = std::make_shared(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, - membuf->type_, current_kernel_); - (void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf); -} - -void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) { - MS_EXCEPTION_IF_NULL(tensor_desc); - size_t membuf_offset = 0; - if (!membuf_ptr_list_.empty()) { - membuf_offset = membuf_ptr_list_.back()->offset_ + membuf_ptr_list_.back()->size_; - } - auto membuf_size = tensor_desc->size_; - auto real_index 
= GetRealIndex(IntToSize(tensor_desc->index_), flag); - auto membuf = std::make_shared(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_); - membuf_ptr_list_.push_back(membuf); - tensor_desc->offset_ = membuf_offset; -} - -void BestFitMemReuse::UpdateNodeInputAndMembuf() { - // process node input tensor - for (const auto &tensor_idx : current_kernel_->GetInputRefIndexs()) { - size_t tensor_index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[tensor_index]; - MS_EXCEPTION_IF_NULL(tensor_desc); - tensor_desc->ref_count_--; - if (tensor_desc->ref_count_ == 0) { - ReleaseMembuf(tensor_index, kDynamicMem); - } else if (tensor_desc->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ - << " check error"; - } - } -} - -void BestFitMemReuse::ReleaseNodeUnusedOutput() { - for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { - size_t tensor_index = GetTensorIndex(tensor_idx); - auto tensor_desc = tensor_ptr_list_[tensor_index]; - MS_EXCEPTION_IF_NULL(tensor_desc); - if (tensor_desc->ref_count_ == 0) { - ReleaseMembuf(tensor_index, kDynamicMem); - } else if (tensor_desc->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << tensor_desc->index_ << " refcount: " << tensor_desc->ref_count_ - << " check error"; - } - } -} - -void BestFitMemReuse::ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr) { - for (auto &workspace_index : kernel_def_ptr->GetWorkspaceRefIndexs()) { - size_t index = GetWorkspaceIndex(workspace_index); - auto wk_tensor = wk_tensor_list_[index]; - wk_tensor->ref_count_--; - if (wk_tensor->ref_count_ == 0) { - ReleaseMembuf(index, kWorkspaceMem); - } else if (wk_tensor->ref_count_ < 0) { - MS_LOG(EXCEPTION) << "tensor: " << wk_tensor->index_ << " refcount: " << wk_tensor->ref_count_ << " check error"; - } - } -} - -void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { - if (membuf_ptr_list_.empty()) { - return; - } - auto real_index = GetRealIndex(tensor_index, flag); - auto membuf_iter = std::find_if(membuf_ptr_list_.begin(), membuf_ptr_list_.end(), - [real_index](const MembufPtr &membuf) { return membuf->index_ == real_index; }); - if (membuf_iter == membuf_ptr_list_.end()) { - return; - } - auto membuf = (*membuf_iter); - MS_EXCEPTION_IF_NULL(membuf); - membuf->status_ = kUnused; - if (membuf_iter != membuf_ptr_list_.end() - 1) { - auto next_iter = membuf_iter + 1; - auto membuf_next = (*next_iter); - MS_EXCEPTION_IF_NULL(membuf_next); - if (membuf_next->status_ == kUnused) { - bool is_merge = IsUsable(current_kernel_, membuf_next); - if (is_merge) { - membuf->size_ += membuf_next->size_; - (void)membuf_ptr_list_.erase(next_iter); - } - } - } - if (membuf_iter != membuf_ptr_list_.begin()) { - auto prev_iter = membuf_iter - 1; - auto membuf_prev = (*prev_iter); - MS_EXCEPTION_IF_NULL(membuf_prev); - if (membuf_prev->status_ == kUnused) { - bool is_merge = IsUsable(current_kernel_, membuf_prev); - if (is_merge) { - membuf->size_ += membuf_prev->size_; - membuf->offset_ = membuf_prev->offset_; - (void)membuf_ptr_list_.erase(prev_iter); - } - } - } -} - -size_t BestFitMemReuse::AlignMemorySize(size_t size) const { - // memory size 512 align - return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; -} - -size_t BestFitMemReuse::GetAllocatedSize() { - size_t AllocatedSize = kTotalSize; - if (membuf_ptr_list_.empty()) { - return AllocatedSize; - } - AllocatedSize = membuf_ptr_list_.back()->offset_ + 
membuf_ptr_list_.back()->size_; - MS_LOG(INFO) << "MemReuse Allocated Dynamic Size: " << AllocatedSize; - return AllocatedSize; -} - -bool BestFitMemReuse::IsRelease() { - // unable_used_node include the node type that output tensor cannot be released, - // even if its refcount is equal to zero. - std::unordered_set unable_used_node = {prim::kPrimBatchNorm->name(), prim::kPrimBatchNormGrad->name(), - prim::kPrimFusedBatchNorm->name(), - prim::kPrimFusedBatchNormGrad->name()}; - return unable_used_node.find(current_kernel_->kernel_name()) == unable_used_node.end(); -} - -size_t BestFitMemReuse::GetTensorIndex(int index) const { - if (index < 0 || IntToSize(index) >= tensor_ptr_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid tensor index"; - } - return IntToSize(index); -} - -size_t BestFitMemReuse::GetWorkspaceIndex(int index) const { - if (index < 0 || IntToSize(index) >= wk_tensor_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid tensor index"; - } - return IntToSize(index); -} - -int BestFitMemReuse::GetRealIndex(size_t index, int flag) const { - if (flag == kDynamicMem) { - return SizeToInt(index); - } else if (flag == kWorkspaceMem) { - return kWorkspaceIndexFactor * SizeToInt(index + 1); - } else { - MS_LOG(EXCEPTION) << "flag " << flag << " is invalid"; - } -} - -void BestFitMemReuse::CheckMembufIndx(size_t membuf_index) const { - if (membuf_index >= membuf_ptr_list_.size()) { - MS_LOG(WARNING) << "current cnode: " << current_kernel_->scope_full_name(); - MS_LOG(EXCEPTION) << "invalid membuf index: " << membuf_index << ", real size: " << membuf_ptr_list_.size(); - } -} - -void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { - MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); - InitMemReuseInfo(mem_reuse_util_ptr); - InitKernelDependence(); - KernelDefPtr pre_op = nullptr; -#ifdef MEM_REUSE_DEBUG - size_t op_num = 0; -#endif - for (const auto &op_def_ptr : op_ptr_list_) { - current_kernel_ = op_def_ptr; - // releas pre_op_def - if (pre_op != nullptr) { - ReleasePreNodeWorkspace(pre_op.get()); - } - MemReuseChecker::GetInstance().IsAddNewMembuf_ = false; - // process node output tensor - AssignNodeOutputOffset(); -#ifdef MEM_REUSE_DEBUG - if (MemReuseChecker::GetInstance().IsAddNewMembuf_) { - MemReuseChecker::GetInstance().SetAddNewMembuInfos(op_def_ptr.get(), membuf_ptr_list_, op_num); - } -#endif - // deal with current op'workspace - AssignNodeWorkspaceOffset(); - pre_op = op_def_ptr; - // update node input tensor refcount, and membuf list status - UpdateNodeInputAndMembuf(); - // check node output tensor which refcount is equal to zero - if (IsRelease()) { - ReleaseNodeUnusedOutput(); - } -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().SetMembuInfos(op_def_ptr.get(), membuf_ptr_list_); - ++op_num; -#endif - } -#ifdef MEM_REUSE_DEBUG - MemReuseChecker::GetInstance().ExportMembufInfoIR(); - MemReuseChecker::GetInstance().ExportAddNewMmebufIR(); - MemReuseChecker::GetInstance().set_kernel_front_map(kernel_front_map_); - MemReuseChecker::GetInstance().ExportKernelDependence(); -#endif -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h deleted file mode 100644 index 321a36c824..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.h +++ /dev/null @@ -1,159 
+0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "pre_activate/mem_reuse/kernel_refcount.h" -#include "pre_activate/mem_reuse/mem_reuse.h" - -namespace mindspore { -namespace memreuse { -static constexpr int kWorkspaceIndexFactor = -1000; -static constexpr int kDynamicMem = -1; -static constexpr int kWorkspaceMem = 1; -static constexpr size_t kTotalSize = 0; -enum Status { kUnused, kReused }; -enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE }; -class Membuf { - public: - Membuf() = default; - Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel) - : status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {} - ~Membuf() = default; - // Memory block status flags - Status status_ = kUnused; - size_t size_{0}; - size_t offset_{0}; - // Store the tensor index stored in this memory block at a certain moment - int index_{0}; - MEMTYPE type_{NEW}; - KernelDefPtr used_kernel_; -}; -using MembufPtr = std::shared_ptr; - -class BestFitMemReuse { - public: - BestFitMemReuse() = default; - ~BestFitMemReuse() { membuf_ptr_list_.clear(); } - /** - * Init all information need by memory reuse - * @param mem_reuse_util_ptr, initialize in the memreuse.cc - */ - void InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr); - void CheckMembufIndx(size_t check_idx) const; - void AssignNodeWorkspaceOffset(); - void ReleasePreNodeWorkspace(const KernelDef *kernel_def_ptr); - /** - * Assign output tensor memory offset of current kernel - */ - void AssignNodeOutputOffset(); - /** - * Update input tensor's status of current kernel, and the status of membuf used by current kernel - */ - void UpdateNodeInputAndMembuf(); - /** - * Check whether to release the kernel output tensor which refcount is equal to zero - */ - void ReleaseNodeUnusedOutput(); - /** - * Reuse the exist membuf if possible - * @param tensor_desc, the output tensor of current kernel - * @param membuf_index, the index of membuf to be reused - * @param flag - */ - void ReuseExistMembuf(KernelRefCount *tensor_desc, size_t membuf_index, int flag); - /** - * Get the membuf that can be reused - * @param tensor_size, the size of the tensor ready to assign memory offset - * @return membuf map, key: the membuf size, value: the membuf index - */ - std::map GetReusableMembufMap(size_t tensor_size); - /** - * Update the status of the reused memory block - * @param tensor_desc, the tensor ready to assign memory - * @param membuf, the membuf to be reused - * @param flag, distinguish dynamic memory and workspace - */ - void UpdateMembufInfo(KernelRefCount *tensor_desc, 
Membuf *membuf, int flag); - // If the size of the memory block is greater than the size of the tensor, split the extra memory - void SplitMembuf(const KernelRefCount *tensor_desc, size_t membuf_index); - // Determine if the memory block needs to be split - bool IsSplit(size_t tensor_size, size_t membuf_size) const; - // If there is no memory block that can be reused, add a new memory block at the end - void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag); - // Merge unused membuf - void ReleaseMembuf(size_t tensor_index, int flag); - // Memory address alignment 512 - size_t AlignMemorySize(size_t size) const; - int GetRealIndex(size_t index, int flag = kDynamicMem) const; - size_t GetTensorIndex(int index) const; - size_t GetWorkspaceIndex(int index) const; - // Memory reuse main program entry - void Reuse(const MemReuseUtil *mem_reuse_util_ptr); - // Get the total memory that needs to be applied eventually - size_t GetAllocatedSize(); - // return false, when the node output cannot be released - bool IsRelease(); - /** - * determine if the kernel_curr can reuse the output tensor add of kernel_prev - * @param kernel_curr, current kernel - * @param mem_buf, the membuf - * @return bool - */ - bool IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf); - /** - * init the dependence of all kernels in the graph - */ - void InitKernelDependence(); - // set tensor_def and op_def - void set_tensor_ptr_list(const std::vector &tensor_ptr_list) { - tensor_ptr_list_ = tensor_ptr_list; - } - void set_workspace_ptr_list(const std::vector &workspace_ptr_list) { - wk_tensor_list_ = workspace_ptr_list; - } - void set_op_ptr_list(const std::vector &op_ptr_list) { op_ptr_list_ = op_ptr_list; } - - private: - KernelDefPtr current_kernel_; - // Save all tensor information - std::vector tensor_ptr_list_; - std::vector wk_tensor_list_; - // Save all op information, including input and output tensor index - std::vector op_ptr_list_; - // Memory block information sequence, temporary variables - std::vector membuf_ptr_list_; - // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def - std::map> kernel_front_map_; - std::vector> stream_groups_; -}; -} // namespace memreuse -} // namespace mindspore -#endif // #define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc deleted file mode 100644 index 1421bc6a7d..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.cc +++ /dev/null @@ -1,572 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-
-#include "pre_activate/mem_reuse/mem_reuse_checker.h"
-#include
-#include
-#include
-#include
-
-namespace mindspore {
-namespace memreuse {
-MemReuseChecker &MemReuseChecker::GetInstance() {
- static MemReuseChecker instance;
- return instance;
-}
-
-void MemReuseChecker::CheckSignalOps(const CNodePtr &c_node) {
- std::string node_name = AnfAlgo::GetCNodeName(c_node);
- if (node_name == kSend || node_name == kRecv) {
- MS_LOG(INFO) << "MemReuseChecker checks op_name of Send or Recv";
- // get op's info && check
- MS_LOG(INFO) << "op: " << node_name << " in_num: " << AnfAlgo::GetInputTensorNum(c_node)
- << " out_num: " << AnfAlgo::GetOutputTensorNum(c_node);
- }
-}
-
-void MemReuseChecker::CheckWorkSpace(const std::vector<size_t> &max_list) {
- for (auto &ma : max_list) {
- total_re_wkspe_size_checker_ += ma;
- }
-}
-
-void MemReuseChecker::CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx) {
- auto key = c_node.get();
- auto iter = kernel_refs.find(key);
- auto node_name = AnfAlgo::GetCNodeName(c_node);
- if (iter == kernel_refs.end()) {
- MS_LOG(EXCEPTION) << "kernel [" << node_name << "] has no output tensor, node: " << c_node->DebugString()
- << " output index: " << output_idx;
- }
- if (output_idx >= iter->second.size()) {
- MS_LOG(INFO) << "invalid cnode: " << c_node->fullname_with_scope().c_str();
- MS_LOG(EXCEPTION) << "The index: " << output_idx
- << " is out of the size of kernel_output_refs_:" << iter->second.size();
- }
-}
-
-int64_t MemReuseChecker::CalculOriInput(const KernelGraph *graph) const {
- MS_EXCEPTION_IF_NULL(graph);
- int64_t static_input_size = 0;
- for (auto &item : graph->inputs()) {
- if (!item->isa<Parameter>()) {
- continue;
- }
- auto output_size = AnfAlgo::GetOutputTensorNum(item);
- for (size_t index = 0; index < output_size; index++) {
- TypeId ou_type = AnfAlgo::GetOutputDeviceDataType(item, index);
- // the parameter has not been initialized by a cnode
- if (ou_type == kTypeUnknown) {
- ou_type = AnfAlgo::GetOutputInferDataType(item, index);
- }
- size_t type_size = GetTypeByte(TypeIdToType(ou_type));
- std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(item, index);
- size_t tensor_size =
- shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
- auto checker_size = SizeToLong(tensor_size);
- static_input_size += checker_size;
- }
- }
- return static_input_size;
-}
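The checker sizes each tensor as the element size folded over its device shape, treating an empty shape as a scalar. The same computation in isolation (TensorBytes is an illustrative name, not a function from the patch):

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Byte size of a tensor: type_size multiplied by every shape dimension;
    // an empty shape means a scalar of type_size bytes.
    std::size_t TensorBytes(std::size_t type_size, const std::vector<std::size_t> &shape) {
      return shape.empty() ? type_size
                           : std::accumulate(shape.begin(), shape.end(), type_size,
                                             std::multiplies<std::size_t>());
    }

    int main() {
      assert(TensorBytes(4, {2, 3}) == 24);  // 2*3 float32 elements
      assert(TensorBytes(4, {}) == 4);       // scalar
    }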
-
-int64_t MemReuseChecker::CalculOriValue(KernelGraph *graph) const {
- MS_EXCEPTION_IF_NULL(graph);
- int64_t static_value_size = 0;
- for (auto &value_node : graph->graph_value_nodes()) {
- MS_EXCEPTION_IF_NULL(value_node);
- auto &node_value = value_node->value();
- MS_EXCEPTION_IF_NULL(node_value);
- auto tensor = node_value->cast<tensor::TensorPtr>();
- if (tensor == nullptr) {
- continue;
- }
- size_t tensor_size = tensor->data().nbytes();
- auto checker_size = SizeToLong(tensor_size);
- static_value_size += checker_size;
- }
- return static_value_size;
-}
-
-int64_t MemReuseChecker::CalculOriStatic(KernelGraph *graph) const {
- // calculate static inputs
- auto static_input_size = CalculOriInput(graph);
- // do not calculate the output size
- auto static_value_size = CalculOriValue(graph);
- auto total_ori_static_size = static_input_size + static_value_size;
- return total_ori_static_size;
-}
-
-int64_t MemReuseChecker::CalculOriDy(const KernelGraph *graph) const {
- MS_EXCEPTION_IF_NULL(graph);
- int64_t ori_dy_size = 0;
- auto kernels = graph->execution_order();
- for (auto &kernel : kernels) {
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
- MS_EXCEPTION_IF_NULL(kernel_mod);
- for (auto &dy_size : kernel_mod->GetOutputSizeList()) {
- auto checker_size = SizeToLong(dy_size);
- ori_dy_size += checker_size;
- }
- }
- return ori_dy_size;
-}
-
-int64_t MemReuseChecker::CalculOriWk(const KernelGraph *graph) const {
- MS_EXCEPTION_IF_NULL(graph);
- int64_t ori_wk_size = 0;
- auto kernels = graph->execution_order();
- for (auto &kernel : kernels) {
- auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
- MS_EXCEPTION_IF_NULL(kernel_mod);
- for (auto &wk_size : kernel_mod->GetWorkspaceSizeList()) {
- auto checker_size = SizeToLong(wk_size);
- ori_wk_size += checker_size;
- }
- }
- return ori_wk_size;
-}
-
-std::string MemReuseChecker::GetSplitName(const std::string &scope_name) const {
- auto indx = scope_name.rfind(kSplitC);
- if (indx == std::string::npos) {
- return scope_name;
- } else {
- if (indx < scope_name.size() - 1) {
- auto split_name = scope_name.substr(indx + 1);
- return split_name;
- }
- return scope_name;
- }
-}
-
-void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list,
- const KernelDefPtrMaps &kernel_def_ptr_list, KernelGraph *graph) {
- total_ori_static_size_ = CalculOriStatic(graph);
- total_ori_input_size_ = CalculOriInput(graph);
- total_ori_value_size_ = CalculOriValue(graph);
- total_ori_dy_size_ = CalculOriDy(graph);
- total_ori_wkspace_size_ = CalculOriWk(graph);
- std::string graph_id = std::to_string(graph->graph_id());
- std::string filename = "./memreuse_" + graph_id + ".ir";
- std::ofstream ofs(filename);
- if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; - return; - } - size_t i = 0; - for (const auto &kernel_front : kernel_front_map_) { - auto kernel = kernel_front.first; - auto front = kernel_front.second; - ofs << "[" << i++ << "] " << kernel->scope_full_name() << "\n"; - for (const auto &node : front) { - ofs << node->scope_full_name() << "\n"; - } - ofs << "\n\n"; - } - - ofs.close(); -} - -bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph) { - // set real graph output node to be special who's refcount equal kMaxRefCount - for (const auto &output : graph->outputs()) { - MS_EXCEPTION_IF_NULL(output); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { - if (output->isa()) { - auto cnode = output->cast(); - auto input_node = cnode->input(i + 1); - auto kernel_input_with_idx = AnfAlgo::VisitKernel(input_node, 0); - auto kernel_input = kernel_input_with_idx.first; - MS_EXCEPTION_IF_NULL(kernel_input); - auto kernel_mod = AnfAlgo::GetKernelMod(kernel_input); - if (kernel_mod == nullptr) { - continue; - } - auto output_sizes = kernel_mod->GetOutputSizeList(); - if (output_sizes.empty()) { - continue; - } - for (size_t j = 0; j < output_sizes.size(); ++j) { - if (!AnfAlgo::OutputAddrExist(kernel_input, j)) { - return false; - } - } - } - } - } - return true; -} - -void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) { - auto scope_name = def->scope_full_name(); - std::string split_name = GetSplitName(scope_name); - ofs << "$" << def_idx << "\t" << split_name << "\t"; - ofs << "inputs["; - for (auto &in : def->inputs_) { - for (auto &in_ref : in.second) { - ofs << "%" << in_ref->index_ << "T" - << ","; - } - } - ofs << "]"; - ofs << "\toutpus["; - for (auto &ou : def->outputs_) { - for (auto &ou_ref : ou.second) { - ofs << "%" << ou_ref->index_ << "T" - << ","; - } - } - ofs << "]"; - ofs << "\tstreamID[" - << "@" << def->stream_id() << "]\n"; -} - -void MemReuseChecker::ExportNormalTensorIR(std::ofstream &ofs) { - ofs << "all_tensor_refs:\n"; - ofs << "index:" - << "\tsize:" - << "\trefcount:\n"; - size_t ou_idx = 0; - for (auto &ou : nor_output_tensors_) { - ofs << "%" << ou_idx << "T" - << "\t" - << "#" << nor_tensor_sizes_[ou_idx] << "S" - << "\t"; - auto iter_ref = ptr_refs_.find(ou); - if (iter_ref != ptr_refs_.end()) { - ofs << iter_ref->second << "C" - << "\n"; - } else { - MS_LOG(EXCEPTION) << "can not find refs for output"; - } - ou_idx++; - } - ofs << "kernel_def exc_order:\n"; -} - -int MemReuseChecker::GetTensorIdx(const void *in) const { - auto iter = ptr_idx_.find(in); - if (iter == ptr_idx_.end()) { - return kInvalidIndex; - } else { - return SizeToInt(iter->second); - } -} - -void MemReuseChecker::ExportNormalOpIr(const std::vector &cnodes) { - std::ofstream ofs("./normal_mem.ir"); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file failed!"; - return; - } - ExportNormalTensorIR(ofs); - size_t node_idx = 0; - for (const auto &node : cnodes) { - MS_EXCEPTION_IF_NULL(node); - ofs << "$" << node_idx << "\t" << GetSplitName(node->fullname_with_scope()) << "\t"; - std::vector in_idx; - auto iter = node_ins_.find(node.get()); - if (iter != node_ins_.end()) { - for (auto &in : iter->second) { - if (GetTensorIdx(in) != kInvalidIndex) { - in_idx.push_back(GetTensorIdx(in)); - } - } - } - std::vector ou_idx; - iter = node_ous_.find(node.get()); - if (iter != node_ous_.end()) { - for (auto &ou : iter->second) { - if (GetTensorIdx(ou) != kInvalidIndex) { - 
ou_idx.push_back(GetTensorIdx(ou)); - } - } - } - ofs << "inputs["; - for (auto idx : in_idx) { - bool has_in_ou = std::any_of(ou_idx.begin(), ou_idx.end(), [idx](int odx) { return idx == odx; }); - if (!has_in_ou) { - ofs << "%" << idx << "T,"; - } - } - ofs << "]\toutpus["; - for (auto odx : ou_idx) { - ofs << "%" << odx << "T,"; - } - ofs << "]\tstreamID[@" << AnfAlgo::GetStreamId(node) << "]\n"; - node_idx++; - } - ofs.close(); -} - -void MemReuseChecker::SetTesnorFromAndToInfo(const KernelDef *op_def) { - auto split_name = GetSplitName(op_def->scope_full_name()); - for (auto &in : op_def->inputs_) { - auto in_tensors = in.second; - for (auto &tensor : in_tensors) { - auto indx = tensor->index_; - tensor_to_[indx].push_back(split_name); - } - } - for (auto &ou : op_def->outputs_) { - auto ou_tensors = ou.second; - for (auto &tensor : ou_tensors) { - auto indx = tensor->index_; - tensor_from_[indx].push_back(split_name); - } - } -} - -void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { - const auto &cnodes = graph->execution_order(); - for (const auto &node : cnodes) { - std::vector curr_ous; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { - auto it = AnfAlgo::GetOutputAddr(node, i); - MS_EXCEPTION_IF_NULL(it); - auto ptr = it->GetPtr(); - nor_output_tensors_.push_back(ptr); - nor_tensor_sizes_.push_back(it->GetSize()); - curr_ous.push_back(it->GetPtr()); - } - (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); - std::vector curr_ins; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { - if (i + 1 >= node->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index: " << i - << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); - } - auto real_input_index = AnfAlgo::GetRealInputIndex(node, i); - auto input = node->input(real_input_index + 1); - MS_EXCEPTION_IF_NULL(input); - auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); - if (kernel_with_index.first->isa()) { - continue; - } - auto device_address = AnfAlgo::GetPrevNodeOutputAddr(node, real_input_index); - MS_EXCEPTION_IF_NULL(device_address); - nor_input_tensors_.push_back(device_address->GetPtr()); - curr_ins.push_back(device_address->GetPtr()); - } - (void)node_ins_.insert(std::make_pair(node.get(), curr_ins)); - } - size_t ou_idx = 0; - for (const auto &ou : nor_output_tensors_) { - (void)ptr_idx_.insert(std::make_pair(ou, ou_idx)); - (void)ptr_refs_.insert(std::make_pair(ou, 0)); - ou_idx++; - } - for (const auto &in : nor_input_tensors_) { - if (ptr_idx_.find(in) != ptr_idx_.end()) { - if (ptr_refs_.find(in) != ptr_refs_.end()) { - auto iter = ptr_refs_.find(in); - (iter->second)++; - } else { - MS_LOG(EXCEPTION) << "ptr_refs is not equal to ptr_idx"; - } - } - } - ExportNormalOpIr(cnodes); -} - -void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list) { - std::vector curr_mem_infos; - for (const auto &mem : membuf_ptr_list) { - auto mem_checker = - std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); - curr_mem_infos.push_back(mem_checker); - } - membuf_all_infos_.push_back(curr_mem_infos); - auto split_name = GetSplitName(op_def->scope_full_name()); - all_split_names_.push_back(split_name); - SetTesnorFromAndToInfo(op_def); -} - -void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, - size_t op_idx) { - std::vector add_new_curr_mem; - - for (const auto &mem : membuf_ptr_list) { - auto mem_checker = - 
std::make_shared(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_); - add_new_curr_mem.push_back(mem_checker); - } - add_new_mem_infos_.push_back(add_new_curr_mem); - auto split_name = GetSplitName(op_def->scope_full_name()); - add_new_names_.push_back(split_name); - add_new_op_indxs_.push_back(op_idx); - add_new_stream_ids_.push_back(op_def->stream_id()); -} - -void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) { - size_t i = 0; - std::vector each_node_used_size; - std::vector each_node_allocated_size; - for (const auto &curr_membuf_list : membuf_all_infos_) { - ofs << all_split_names_.at(i) << "\n"; - ++i; - ofs << "mem_num\t" - << "stream_id\t" - << "status\t" - << "tensor_idex\t" - << "mem_size\t" - << "mem_head\t" - << "mem_tail\t" - << "mem_type\t" - << "used_kernel\n"; - size_t curr_used = 0; - size_t curr_allocated = 0; - for (size_t j = 0; j < curr_membuf_list.size(); ++j) { - auto membuf = curr_membuf_list.at(j); - auto used_kernel = membuf->used_kernel_->scope_full_name(); - ofs << "&" << j << "\t" - << "streamID[@" << membuf->used_kernel_->stream_id() << "]" - << "\t" - << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" - << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t\t" << membuf->offset_ + membuf->size_ << "\t" - << "\t" << static_cast(membuf->type_) << "\t" << GetSplitName(used_kernel) << "\n"; - if (membuf->status_ == kReused) { - curr_used += membuf->size_; - } - } - if (!curr_membuf_list.empty()) { - curr_allocated = curr_membuf_list.back()->offset_ + curr_membuf_list.back()->size_; - } - each_node_used_size.push_back(curr_used); - each_node_allocated_size.push_back(curr_allocated); - ofs << "curr real used size: \t" << curr_used << "\n"; - ofs << "curr allocated size: \t" << curr_allocated << "\n"; - ofs << "\n\n"; - } - auto optimal_iter = std::max_element(each_node_used_size.begin(), each_node_used_size.end()); - ofs << "theoretical optimal size: " << *optimal_iter << "\n"; - ofs << "each node used size: \n"; - for (auto size : each_node_used_size) { - ofs << size << "\t"; - } - ofs << "\n\n"; - ofs << "each node allocated size: \n"; - for (auto size : each_node_allocated_size) { - ofs << size << "\t"; - } - ofs << "\n\n"; -} - -void MemReuseChecker::ExportMembufInfoIR() { - std::string ir_file_name = "./mem_buf_info.ir"; - std::ofstream ofs(ir_file_name); - int64_t total_reuse_size = 0; - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; - } - ofs << "Total static size:\t" << total_ori_static_size_ << "\n"; - ofs << "Graph inputs size:\t" << total_ori_input_size_ << "\n"; - ofs << "Value nodes size:\t" << total_ori_value_size_ << "\n"; - ofs << "Total dynamic size:\t" << total_ori_dy_size_ << "\n"; - ofs << "Total workspace size:\t" << total_ori_wkspace_size_ << "\n"; - // get last membuf_list - if (membuf_all_infos_.empty()) { - return; - } - auto last_membuf_list = membuf_all_infos_.back(); - for (const auto &membuf : last_membuf_list) { - auto checker_size = SizeToLong(membuf->size_); - total_reuse_size += checker_size; - } - ofs << "After reuse size:\t" << total_reuse_size << "\n\n"; - ExportEachMembufInfo(ofs); - ofs.close(); -} - -void MemReuseChecker::ExportAddNewMmebufIR() { - std::string ir_file_name = "./AddNewMembuf.ir"; - std::ofstream ofs(ir_file_name); - if (!ofs.is_open()) { - MS_LOG(ERROR) << "Open file [" << ir_file_name << "] failed!"; - } - auto check_idx = add_new_mem_infos_.size(); - if (check_idx == add_new_op_indxs_.size() && 
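// The four bookkeeping vectors (mem infos, op indexes, names, stream ids) are
// appended in lockstep by SetAddNewMembuInfos; the export below only proceeds
// when their sizes still agree.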
check_idx == add_new_names_.size() && - check_idx == add_new_stream_ids_.size()) { - size_t i = 0; - for (const auto &curr_membuf_list : add_new_mem_infos_) { - ofs << "op_idx:$" << add_new_op_indxs_.at(i) << "\t" << add_new_names_.at(i) << "\t"; - ofs << "streamID[@" << add_new_stream_ids_.at(i) << "]" - << "\n"; - i++; - ofs << "mem_num\t" - << "status\t" - << "tensor_idex\t" - << "mem_size\t" - << "mem_head\t" - << "mem_tail\t" - << "FromOp\t" - << "ToOp\n"; - for (size_t j = 0; j < curr_membuf_list.size(); ++j) { - auto membuf = curr_membuf_list.at(j); - ofs << "&" << j << "\t" - << "\t" - << "#" << static_cast(membuf->status_) << "\t%" << membuf->index_ << "T" - << "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"; - auto in_idx_iter = tensor_from_.find(membuf->index_); - if (in_idx_iter != tensor_from_.end()) { - for (auto &in_name : in_idx_iter->second) { - ofs << in_name << ","; - } - ofs << "\t"; - } - auto ou_idx_iter = tensor_to_.find(membuf->index_); - if (ou_idx_iter != tensor_to_.end()) { - for (auto &ou_name : ou_idx_iter->second) { - ofs << ou_name << ","; - } - ofs << "\n"; - } - } - ofs << "\n"; - } - } - ofs.close(); -} -} // namespace memreuse -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h deleted file mode 100644 index 5fd3d0f5ae..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_checker.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
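// The checker above validates reuse by giving every kernel-output pointer a
// tensor index and counting how often it reappears as a kernel input. A
// minimal self-contained sketch of that bookkeeping with plain std containers
// (CountRefs is an illustrative name, not part of this file):
#include <cstddef>
#include <map>
#include <vector>

// Map each produced pointer to the number of later consumers that read it.
std::map<const void *, size_t> CountRefs(const std::vector<const void *> &outputs,
                                         const std::vector<const void *> &inputs) {
  std::map<const void *, size_t> refs;
  for (const void *out : outputs) {
    refs.emplace(out, 0);  // every output starts with zero recorded consumers
  }
  for (const void *in : inputs) {
    auto iter = refs.find(in);
    if (iter != refs.end()) {
      ++iter->second;  // input produced by an earlier kernel: count the use
    }
  }
  return refs;
}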
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ -#include -#include -#include -#include -#include -#include -#include "mindspore/ccsrc/ir/anf.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/mem_reuse/mem_reuse.h" -#include "kernel/common_utils.h" -#include "pre_activate/mem_reuse/mem_reuse_allocator.h" -namespace mindspore { -namespace memreuse { -constexpr auto kSend = "Send"; -constexpr auto kRecv = "Recv"; -constexpr auto kSplitC = '/'; -class MemReuseChecker { - public: - bool IsAddNewMembuf_ = false; - static MemReuseChecker &GetInstance(); - MemReuseChecker(const MemReuseChecker &) = delete; - MemReuseChecker &operator=(const MemReuseChecker &) = delete; - void CheckSignalOps(const CNodePtr &c_node); - void CheckWorkSpace(const std::vector &max_list); - void CheckOutRef(const KernelRefs &kernel_refs, const CNodePtr &c_node, size_t output_idx); - bool CheckGraphOutputAssigned(const session::KernelGraph *graph); - void CheckMemReuseIR(const KernelRefCountPtrList &total_refs_list, const KernelDefPtrMaps &kernel_def_ptr_list, - KernelGraph *graph); - int64_t CalculOriStatic(KernelGraph *graph) const; - int64_t CalculOriInput(const KernelGraph *graph) const; - int64_t CalculOriValue(KernelGraph *graph) const; - int64_t CalculOriDy(const KernelGraph *graph) const; - int64_t CalculOriWk(const KernelGraph *graph) const; - std::string GetSplitName(const std::string &scope_name) const; - int GetTensorIdx(const void *in) const; - void SetMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list); - void SetTesnorFromAndToInfo(const KernelDef *op_def); - void ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx); - void ExportNormalOpIr(const std::vector &cnodes); - void ExportNormalTensorIR(std::ofstream &ofs); - void CheckNormalIR(const session::KernelGraph *graph); - void ExportMembufInfoIR(); - void ExportEachMembufInfo(std::ofstream &ofs); - void SetAddNewMembuInfos(const KernelDef *op_def, const std::vector &membuf_ptr_list, size_t op_idx); - void ExportAddNewMmebufIR(); - void set_kernel_front_map(const std::map> &kernel_front_map) { - kernel_front_map_ = kernel_front_map; - } - void ExportKernelDependence(); - - private: - MemReuseChecker() = default; - ~MemReuseChecker() {} - size_t total_re_wkspe_size_checker_{0}; - std::vector> membuf_all_infos_; - std::vector nor_output_tensors_; - std::vector nor_tensor_sizes_; - std::vector nor_input_tensors_; - std::map ptr_idx_; - std::map ptr_refs_; - std::map> node_ins_; - std::map> node_ous_; - std::vector> add_new_mem_infos_; - std::vector add_new_names_; - std::vector add_new_op_indxs_; - std::vector add_new_stream_ids_; - std::vector all_split_names_; - std::map> tensor_from_; - std::map> tensor_to_; - std::map> kernel_front_map_; - int64_t total_ori_static_size_ = 0; - int64_t total_ori_input_size_ = 0; - int64_t total_ori_value_size_ = 0; - int64_t total_ori_dy_size_ = 0; - int64_t total_ori_wkspace_size_ = 0; -}; -} // namespace memreuse -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_REUSE_CHECKER_H_ diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc deleted file mode 100644 index 14073bfbc9..0000000000 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.cc +++ /dev/null @@ -1,344 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the 
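// Init() below ranks every kernel output by size so the largest tensors become
// the first swap candidates, and counts the distinct sizes, which bounds how
// often RetreatSwapInfo can lower the size threshold. A compact sketch of that
// ranking step (TensorDesc and RankBySize are illustrative stand-ins for the
// real TensorInfo bookkeeping):
#include <algorithm>
#include <cstddef>
#include <vector>

struct TensorDesc {
  size_t tensor_size;
};

// Sort descending by size and return the number of distinct sizes.
size_t RankBySize(std::vector<TensorDesc> *tensors) {
  std::sort(tensors->begin(), tensors->end(),
            [](const TensorDesc &a, const TensorDesc &b) { return a.tensor_size > b.tensor_size; });
  size_t distinct = tensors->empty() ? 0 : 1;
  for (size_t i = 1; i < tensors->size(); ++i) {
    if ((*tensors)[i].tensor_size != (*tensors)[i - 1].tensor_size) {
      ++distinct;
    }
  }
  return distinct;
}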
Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/mem_reuse/mem_swap_manager.h" -#include -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace device { -namespace memswap { -void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - graph_manager_ = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(graph_manager_); - auto &kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { - execution_order_.push_back(kernel); - } - } - - size_t kernel_index = 0; - for (const auto &kernel : execution_order_) { - // parse topo order of kernel - (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); - // parse tensor info - auto kernel_mod = AnfAlgo::GetKernelMod(kernel); - MS_EXCEPTION_IF_NULL(kernel_mod); - auto output_sizes = kernel_mod->GetOutputSizeList(); - - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { - TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; - ordered_tensors_.push_back(tensor_info); - } - } - - // parse topo order of user kernel - SaveUserKernelTopoOrder(); - - sort(ordered_tensors_.begin(), ordered_tensors_.end(), - [](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; }); - - auto cur_tensor_size = ordered_tensors_.front().tensor_size_; - for (auto &tensor_info : ordered_tensors_) { - if (cur_tensor_size != tensor_info.tensor_size_) { - cur_tensor_size = tensor_info.tensor_size_; - tensor_size_num_++; - } - } - tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; - tensor_size_threshold_idx_ = 0; - - distance_threshold_ = kernel_index / kDistanceInitFactor; - mem_swap_initialized_ = true; - MS_EXCEPTION_IF_NULL(mem_copy_manager_); - mem_copy_manager_->Init(); -} - -bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - NodeUsersMap &user_map = graph_manager_->node_users(); - auto iter = user_map.find(kernel); - bool adjacent_with_communication_op = false; - if (iter != user_map.end()) { - AnfNodeIndexSet node_set = iter->second; - adjacent_with_communication_op = std::any_of( - node_set.begin(), node_set.end(), - [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); - } - return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; -} - -void MemSwapManager::SaveUserKernelTopoOrder() { - NodeUsersMap &user_map = graph_manager_->node_users(); - for (const auto &kernel : execution_order_) { - auto iter = user_map.find(kernel); - if (iter == user_map.end()) { - continue; - } - AnfNodeIndexSet node_set = iter->second; - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - for (auto &node_pair : node_set) { - auto user_kernel = node_pair.first; - if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || 
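// Only users that were registered in kernel_execution_info_ (real, non-nop
// kernels) carry a topo order, so all other user nodes are skipped here.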
opt::IsNopNode(user_kernel)) { - continue; - } - - size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_; - auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1); - auto &output_idx = kernel_with_index.second; - if (kernel_with_index.first.get() != kernel.get()) { - MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort); - } - for (auto &node_user_pair : kernel_exec_info.node_users_map_) { - sort(node_user_pair.second.begin(), node_user_pair.second.end()); - } - } -} - -void MemSwapManager::AddSwapInfo() { - for (const auto &tensor : ordered_tensors_) { - size_t tensor_size = tensor.tensor_size_; - if (tensor_size < tensor_size_threshold_) { - break; - } - - size_t output_idx = tensor.output_idx_; - const AnfNodePtr &kernel = tensor.kernel_; - if (IsCommunicationRelevantOp(kernel)) { - continue; - } - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &node_users_map = kernel_exec_info.node_users_map_; - - auto iter = node_users_map.find(output_idx); - if (iter == node_users_map.end()) { - continue; - } - auto &node_users = iter->second; - bool need_swap = (node_users.size() == 1 && node_users[0] - kernel_exec_info.topo_order_ >= distance_threshold_) || - (node_users.size() > 1 && node_users[1] - node_users[0] >= distance_threshold_); - if (!need_swap) { - continue; - } - AddKernelNeedSwap(kernel, true); - HostAddress host_addr; - host_addr.size = tensor_size; - auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast(&host_addr.addr)); - if (!ret) { - MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; - } - kernel_exec_info.host_addrs_[output_idx] = host_addr; - MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; - if (node_users.size() > 1) { - AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); - AddKernelTriggerSwap(execution_order_[node_users[0]], true); - } else { - AddKernelMemSwapInfo(kernel, mem_swap_out_info); - AddKernelTriggerSwap(kernel, true); - } - - size_t swap_in_order = node_users.size() == 1 ? 
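// Schedule the swap-in one kernel before the consumer: the sole user when
// there is only one, otherwise the second user, since the first user already
// ran at the point the swap-out was triggered.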
node_users[0] - 1 : node_users[1] - 1; - if (swap_in_order <= kernel_exec_info.topo_order_) { - MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - auto swap_in_kernel = execution_order_[swap_in_order]; - MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; - AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); - AddKernelTriggerSwap(swap_in_kernel, true); - - host_addrs_list_.push_back(host_addr); - } -} - -void MemSwapManager::AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const { - if (swap_kind == SwapKind::kDeviceToHost) { - mem_copy_manager_->AddMemSwapOutTask(device_address, host_address); - } else if (swap_kind == SwapKind::kHostToDevice) { - mem_copy_manager_->AddMemSwapInTask(device_address, host_address); - } -} - -bool MemSwapManager::SyncMemCopyStream(SwapKind swap_kind) const { - return mem_copy_manager_->SyncMemCopyStream(swap_kind); -} - -DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { - if (swap_kind == SwapKind::kDeviceToHost) { - return mem_copy_manager_->UpdateSwapOutQueue(); - } else { - return mem_copy_manager_->UpdateSwapInQueue(); - } -} - -// retreat to find a workable swap scheme -bool MemSwapManager::RetreatSwapInfo() { - if (!trigger_swap_) { - trigger_swap_ = true; - } - if (swap_info_already_set_) { - ResetSwapInfo(); - if (distance_threshold_ >= kDistanceLowerBound) { - auto distance_decay_step = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; - distance_threshold_ -= (distance_decay_step > 1 ? distance_decay_step : 1); - } - - while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { - ++tensor_size_threshold_idx_; - if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { - tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; - break; - } - } - - if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { - MS_LOG(ERROR) << "Retreat swap info failed"; - return false; - } - } else { - swap_info_already_set_ = true; - } - AddSwapInfo(); - return true; -} - -KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter = kernel_execution_info_.find(kernel.get()); - if (iter == kernel_execution_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find execution info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return const_cast(iter->second); -} - -void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.execution_perform_ = perform; -} - -void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.trigger_swap_ = trigger_swap; -} - -void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.need_swap_ = need_swap; -} - -void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, - const std::pair &perform) { - MS_EXCEPTION_IF_NULL(kernel); - kernel_swap_perform_[kernel.get()][output_idx] = perform; -} - -void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { - 
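// A kernel can trigger several swaps (one swap-out plus later swap-ins), so
// swap records are appended per kernel rather than overwritten.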
MS_EXCEPTION_IF_NULL(kernel); - mem_swap_info_[kernel.get()].push_back(mem_swap_info); -} - -float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.execution_perform_; -} - -bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.trigger_swap_; -} - -bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { - const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.need_swap_; -} - -const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter_kernel = kernel_swap_perform_.find(kernel.get()); - if (iter_kernel == kernel_swap_perform_.end()) { - MS_LOG(EXCEPTION) << "Can not find swap performance data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - - auto &perform_map = iter_kernel->second; - auto iter_output = perform_map.find(output_idx); - if (iter_output == perform_map.end()) { - MS_LOG(EXCEPTION) << "Can not find swap performance data of output[" << output_idx << "] of op[" - << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter_output->second; -} - -const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { - MS_EXCEPTION_IF_NULL(kernel); - auto iter = mem_swap_info_.find(kernel.get()); - if (iter == mem_swap_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter->second; -} - -void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } - -bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { - auto iter = swap_in_blacklist_.find(device_ptr); - return iter != swap_in_blacklist_.end(); -} - -const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &host_addrs = kernel_exec_info.host_addrs_; - auto iter = host_addrs.find(output_idx); - if (iter == host_addrs.end()) { - MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter->second; -} - -bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { - return mem_copy_manager_->AllocHostPinnedMem(size, addr); -} - -void MemSwapManager::ReleaseHostPinnedMem() { - for (const auto &host_addr : host_addrs_list_) { - if (host_addr.addr) { - mem_copy_manager_->FreeHostPinnedMem(host_addr.addr); - } - } - host_addrs_list_.clear(); -} - -void MemSwapManager::ClearSwapQueue() const { mem_copy_manager_->ClearSwapQueue(); } - -void MemSwapManager::ResetSwapInfo() { - ClearSwapQueue(); - for (auto &kernel_exec_info_pair : kernel_execution_info_) { - auto &kernel_exec_info = kernel_exec_info_pair.second; - kernel_exec_info.trigger_swap_ = false; - kernel_exec_info.need_swap_ = false; - kernel_exec_info.host_addrs_.clear(); - } - ReleaseHostPinnedMem(); - swap_in_blacklist_.clear(); - mem_swap_info_.clear(); -} -} // namespace memswap -} // namespace device -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h deleted file mode 100644 index 1969dadb54..0000000000 --- 
a/mindspore/ccsrc/pre_activate/mem_reuse/mem_swap_manager.h +++ /dev/null @@ -1,132 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include "pre_activate/mem_reuse/mem_copy_manager.h" - -using PerformPair = std::pair; -namespace mindspore { -namespace device { -namespace memswap { -class MemSwapManager { - public: - explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) - : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { - mem_copy_manager_ = mem_copy_manager; - } - - MemSwapManager(const MemSwapManager &) = delete; - - MemSwapManager &operator=(const MemSwapManager &) = delete; - - ~MemSwapManager() = default; - - void Init(const mindspore::session::KernelGraph *kernel_graph); - - void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, - const HostAddress &host_address) const; - - bool SyncMemCopyStream(SwapKind swap_kind) const; - - DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; - - // retreat to find a workable swap scheme - bool RetreatSwapInfo(); - - bool trigger_swap() const { return trigger_swap_; } - - bool mem_swap_init() const { return mem_swap_initialized_; } - - KernelExecutionInfo &SearchKernelExecutionInfo(const AnfNodePtr &kernel) const; - - void AddKernelExecutionPerform(const AnfNodePtr &kernel, float perform); - - float QueryKernelExecutionPerform(const AnfNodePtr &kernel) const; - - void AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const PerformPair &perform); - - const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const; - - bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; - - bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; - - const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; - - void InsertSwapInBlackList(const void *device_ptr); - - bool FindInSwapInBlackList(const void *device_ptr) const; - - const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; - - bool AllocHostPinnedMem(size_t size, void **addr) const; - - void ReleaseHostPinnedMem(); - - void ClearSwapQueue() const; - - private: - void AddSwapInfo(); - - void ResetSwapInfo(); - - void SaveUserKernelTopoOrder(); - - void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); - - void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); - - void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); - - bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; - - std::vector execution_order_; - std::vector ordered_tensors_; - std::unordered_map kernel_execution_info_; - std::unordered_map> kernel_swap_perform_; - // trigger swap kernel key : MemSwapInfo 
of kernel need to be swapped - std::unordered_map> mem_swap_info_; - std::vector host_addrs_list_; - std::unordered_set swap_in_blacklist_; - - size_t tensor_size_threshold_; - size_t tensor_size_threshold_idx_; - size_t tensor_size_num_; - size_t distance_threshold_; - - MemCopyManagerPtr mem_copy_manager_{nullptr}; - FuncGraphManagerPtr graph_manager_{nullptr}; - bool mem_swap_initialized_{false}; - bool swap_info_already_set_{false}; - bool trigger_swap_{false}; - - static constexpr size_t kDistanceInitFactor = 3; - static constexpr size_t kDistanceLowerBound = 3; -}; -using MemSwapManagerPtr = std::shared_ptr; -} // namespace memswap -} // namespace device -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_MEM_REUSE_MEM_SWAP_MANAGER_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc b/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc deleted file mode 100644 index 9df34a1c59..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.cc +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/add_atomic_clean.h" -#include -#include -#include -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "utils/log_adapter.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -namespace { - -static std::vector g_output_idx; - -bool HasAtomic(const AnfNodePtr &input) { - if (IsPrimitiveCNode(input)) { - const auto &cnode = input->cast(); - const auto &prim = GetValueNode(cnode->input(0)); - return prim->HasAttr("atomic_add"); - } - return false; -} - -std::vector CalCleanSize(const CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(pre_node); - std::vector clean_size_list; - // clean output - for (auto &index : g_output_idx) { - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index); - size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); - std::vector shape = AnfAlgo::GetOutputDeviceShape(pre_node, index); - auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize); - } - MS_LOG(DEBUG) << "Clear output size: " << clean_size_list.size() << ", pre_node: " << pre_node->fullname_with_scope(); - return clean_size_list; -} - -CNodePtr CreateTbeAtomicCleanNode(const std::shared_ptr &kernel_graph, - const mindspore::CNodePtr &pre_node) { - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(pre_node); - auto clean_zero_prim = std::make_shared(kAtomicAddrCleanOpName); - auto new_value_node = NewValueNode(clean_zero_prim); - std::vector inputs = {new_value_node}; - CNodePtr clean_zero = kernel_graph->NewCNode(inputs); - AbstractBasePtr abstract = std::make_shared(); - clean_zero->set_abstract(abstract); - auto builder = std::make_shared(); - 
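// The build-info builder marks the clean node as a TBE kernel below, so the
// Ascend backend selects a TBE implementation for it.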
builder->SetKernelType(KernelType::TBE_KERNEL); - AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clean_zero.get()); - auto clean_size = CalCleanSize(pre_node); - AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clean_zero); - AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(g_output_idx), clean_zero); - AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clean_zero.get()); - return clean_zero; -} -} // namespace - -void AddAtomicClean(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - auto &todos = kernel_graph->execution_order(); - for (auto iter = todos.cbegin(); iter != todos.end(); ++iter) { - auto node = *iter; - if (AnfAlgo::IsGraphKernel(node) && kernel_graph->nodes().contains(node)) { - auto fg = GetValueNode(node->input(kAnfPrimitiveIndex)); - MS_EXCEPTION_IF_NULL(fg); - auto input = fg->get_return()->input(1); - if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { - const auto &cnode = input->cast(); - for (size_t i = 0; i < cnode->inputs().size(); ++i) { - if (HasAtomic(cnode->input(i))) { - g_output_idx.push_back(i - 1); - } - } - } else if (HasAtomic(input)) { - g_output_idx.push_back(0); - } - - if (!g_output_idx.empty()) { - auto zero_node = CreateTbeAtomicCleanNode(kernel_graph, node); - auto depend = kernel_graph->NewCNode({NewValueNode(prim::kPrimDepend), node->input(1), zero_node}); - std::vector new_input = node->inputs(); - new_input[1] = depend; - auto new_cnode = std::make_shared(new_input, kernel_graph); - // Set abstract - new_cnode->set_abstract(node->abstract()); - // Set kernel info - new_cnode->set_kernel_info(node->kernel_info_ptr()); - mng->Replace(node, new_cnode); - g_output_idx.clear(); - } - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h b/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h deleted file mode 100644 index bb1edb0e35..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/add_atomic_clean.h +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
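// AddAtomicClean above enforces ordering purely through a data edge: input 1
// of the fused op is rerouted through a Depend whose second input is the clean
// node, so the runtime must zero the atomic buffers first. Schematically:
//
//   before:  out = op(x, ...)
//   after:   out = op(Depend(x, clean), ...)   // clean executes before op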
- */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H_ - -#include -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -void AddAtomicClean(const std::shared_ptr &kernel_graph); -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ADD_ATOMIC_CLEAN_H diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc deleted file mode 100644 index 297a167aa8..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/common_subexpression_elimination.h" -#include -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(main); - MS_EXCEPTION_IF_NULL(node); - auto main_kernel_info = main->kernel_info(); - auto node_kernel_info = node->kernel_info(); - if (main_kernel_info == nullptr && node_kernel_info == nullptr) { - return true; - } - if (main_kernel_info != nullptr && node_kernel_info != nullptr) { - return *main_kernel_info == *node_kernel_info; - } - return false; -} -} // namespace - -bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool) const { - MS_EXCEPTION_IF_NULL(main); - MS_EXCEPTION_IF_NULL(node); - - bool replace = false; - if (main->isa() && node->isa()) { - auto main_value = GetValueNode(main); - auto node_value = GetValueNode(node); - if (main_value->isa() && node_value->isa()) { - replace = false; - } else if (main_value->isa() && node_value->isa()) { - replace = (AbsOf(main) == AbsOf(node)) && CheckEqualKernelBuildInfo(main, node); - } else { - replace = (AbsOf(main) == AbsOf(node)) && (*main_value == *node_value); - } - } else if (main->isa() && node->isa()) { - if (!CheckEqualKernelBuildInfo(main, node)) { - replace = false; - } else { - auto c_main = main->cast(); - MS_EXCEPTION_IF_NULL(c_main); - auto c_node = node->cast(); - MS_EXCEPTION_IF_NULL(c_node); - const auto &inp1 = c_main->inputs(); - const auto &inp2 = c_node->inputs(); - if (inp1.size() == inp2.size()) { - bool appsame = true; - for (size_t j = 0; j < inp1.size(); j++) { - MS_EXCEPTION_IF_NULL(inp1[j]); - MS_EXCEPTION_IF_NULL(inp2[j]); - if (!(*inp1[j] == *inp2[j])) { - appsame = false; - break; - } - } - replace = appsame; - } - } - } - return replace; -} - -bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto backend_cse = std::make_shared(); - return backend_cse->Cse(func_graph, func_graph->manager()); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h 
b/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h deleted file mode 100644 index 18f433ab95..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/common_subexpression_elimination.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ -#include "pre_activate/common/pass.h" -#include "optimizer/cse.h" - -namespace mindspore { -namespace opt { -class CommonSubexpressionElimination : public Pass { - public: - CommonSubexpressionElimination() : Pass("cse") {} - ~CommonSubexpressionElimination() override = default; - bool Run(const FuncGraphPtr &func_graph) override; -}; - -class BackendCSE : public CSE { - public: - BackendCSE() = default; - ~BackendCSE() override = default; - bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMON_SUBEXPRESSION_ELIMINATION_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc deleted file mode 100644 index aa4690abcb..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
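// BackendCSE (declared above) merges two nodes only when both their abstracts
// and their selected kernel build info agree; merging nodes that differ only
// in kernel selection would silently change device formats. A minimal sketch
// of the build-info half of that test (Node and BuildInfo are illustrative
// stand-ins, not the real AnfNode/KernelInfo types):
#include <memory>

struct BuildInfo {
  int format;
  int dtype;
  bool operator==(const BuildInfo &other) const { return format == other.format && dtype == other.dtype; }
};

struct Node {
  std::shared_ptr<BuildInfo> kernel_info;
};

// Equal when both are unset, or both set and identical; one-sided info never merges.
bool EqualKernelBuildInfo(const Node &a, const Node &b) {
  if (a.kernel_info == nullptr && b.kernel_info == nullptr) {
    return true;
  }
  if (a.kernel_info != nullptr && b.kernel_info != nullptr) {
    return *a.kernel_info == *b.kernel_info;
  }
  return false;
}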
- */ -#include "pre_activate/pass/communication_op_fusion.h" - -#include -#include -#include - -#include "utils/graph_utils.h" -#include "operator/ops.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" -#include "parallel/context.h" - -namespace mindspore { -namespace opt { -namespace { -constexpr auto kAttrDefaultGroup = "default_group"; -constexpr auto kAttrDefaultOp = "default_op"; - -kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index, - size_t end_index) { - if (end_index >= communication_op_info.communication_op_nodes.size()) { - MS_LOG(EXCEPTION) << "end index out of vector size"; - } - std::vector inputs_device_format; - std::vector outputs_device_format; - std::vector inputs_device_type; - std::vector outputs_device_type; - std::vector> outputs_shape; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); - inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); - outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); - } - builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); - builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); - builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); - } - builder.SetInputsFormat(inputs_device_format); - builder.SetOutputsFormat(outputs_device_format); - builder.SetInputsDeviceType(inputs_device_type); - builder.SetOutputsDeviceType(outputs_device_type); - return builder.Build(); -} - -std::string GetFusionGroupKey(const AnfNodePtr &node) { - auto primitive = AnfAlgo::GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(primitive); - ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion); - if (attr_fusion == nullptr) { - return ""; - } - int fusion = GetValue(attr_fusion); - if (fusion == 0) { - return ""; - } - std::string group = kAttrDefaultGroup; - ValuePtr attr_group = primitive->GetAttr(kAttrGroup); - if (attr_group != nullptr) { - group = GetValue(attr_group); - } - std::string op = kAttrDefaultOp; - ValuePtr attr_op = primitive->GetAttr(kAttrOp); - if (attr_op != nullptr) { - op = GetValue(attr_op); - } - return group + op + std::to_string(fusion); -} -} // namespace - -bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index, const std::string &group) const { - MS_EXCEPTION_IF_NULL(segment_num); - MS_EXCEPTION_IF_NULL(segment_index); - size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); - MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size; - - auto parallel_context = parallel::ParallelContext::GetInstance(); - MS_EXCEPTION_IF_NULL(parallel_context); - const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); - - size_t segments = 0; - if (split_indices.size() != 0) { - uint32_t 
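// User-provided split_indices must be strictly increasing and in range;
// last_index tracks the previous boundary so each index can be validated
// against it.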
last_index = 0; - for (size_t i = 0; i < split_indices.size(); ++i) { - uint32_t index = split_indices[i]; - if (index <= last_index || index >= communication_op_node_size) { - MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index; - } - segment_index->push_back(index); - last_index = index; - segments++; - } - if (last_index != communication_op_node_size - 1) { - segment_index->push_back(communication_op_node_size - 1); - segments++; - } - } else { - segments = groups_; - for (size_t i = 0; i < segments - 1; ++i) { - segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1); - } - segment_index->push_back(communication_op_node_size - 1); - } - - if (segments >= communication_op_node_size) { - MS_LOG(INFO) << "fusion not changed: segment_num=" << segments - << ", communication_op_node_size=" << communication_op_node_size; - return false; - } - if (segment_index->at(segments - 1) != communication_op_node_size - 1) { - MS_LOG(EXCEPTION) << "the last segment index is invalid."; - } - for (size_t i = 0; i < segments - 1; ++i) { - if (segment_index->at(i) > segment_index->at(i + 1)) { - MS_LOG(EXCEPTION) << "illegal split: segment_index[" << i << "]=" << segment_index->at(i) << ", segment_index[ " - << i + 1 << "]=" << segment_index->at(i + 1); - } - } - *segment_num = segments; - return true; -} - -AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, - const CommunicationOpInfo &communication_op_info, - size_t start_index, size_t end_index) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto prim = std::make_shared(op_name_); - MS_EXCEPTION_IF_NULL(prim); - std::vector fusion_inputs = {NewValueNode(prim)}; - // get all inputs of current segment - if (end_index >= communication_op_info.communication_op_nodes.size()) { - MS_LOG(EXCEPTION) << "end index out of vector size"; - } - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - } - AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs); - MS_EXCEPTION_IF_NULL(fused_node); - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - fused_node->set_kernel_info(kernel_info); - AbstractBasePtrList abstract_list; - for (size_t idx = start_index; idx <= end_index; ++idx) { - auto cnode = communication_op_info.communication_op_nodes[idx]; - MS_EXCEPTION_IF_NULL(cnode); - AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node); - AnfAlgo::CopyNodeAttr("op", cnode, fused_node); - AnfAlgo::CopyNodeAttr("group", cnode, fused_node); - abstract_list.push_back(cnode->abstract()); - } - auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get()); - auto abstract_tuple = std::make_shared(abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - fused_node->set_abstract(abstract_tuple); - return fused_node; -} - -bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, - size_t segment_num, const std::vector &segment_index) const { - MS_EXCEPTION_IF_NULL(func_graph); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - bool changed = false; - size_t start_index = 0; - for (size_t segment_idx = 0; segment_idx < segment_num; ++segment_idx) { - size_t end_index = 
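// Each segment [start_index, end_index] collapses into one fused communication
// op; segments containing a single node are left untouched.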
segment_index.at(segment_idx); - if (end_index - start_index < 1) { - start_index = end_index + 1; - continue; - } - AnfNodePtr new_communication_op = - CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index); - // replace old communication op with new communication op - for (auto idx = start_index; idx <= end_index; ++idx) { - std::vector tuple_getitem_input; - tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem)); - tuple_getitem_input.push_back(new_communication_op); - auto index = NewValueNode(SizeToInt(idx - start_index)); - MS_EXCEPTION_IF_NULL(index); - auto imm = std::make_shared(idx - start_index); - MS_EXCEPTION_IF_NULL(imm); - auto abstract_scalar = std::make_shared(); - MS_EXCEPTION_IF_NULL(abstract_scalar); - index->set_abstract(abstract_scalar); - tuple_getitem_input.push_back(index); - AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input); - MS_EXCEPTION_IF_NULL(tuple_getitem); - auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx); - MS_EXCEPTION_IF_NULL(communication_op_node_item); - tuple_getitem->set_abstract(communication_op_node_item->abstract()); - if (!manager->Replace(communication_op_node_item, tuple_getitem)) { - MS_LOG(EXCEPTION) << "manager replace node failed"; - } - } - start_index = end_index + 1; - changed = true; - } - return changed; -} - -bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - const float input_grad_size_num = 0.0; - const float input_grad_time_num = 0.0; - // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion - std::unordered_map candidate_groups; - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == op_name_) { - std::string key = GetFusionGroupKey(node); - if (key.empty()) { - continue; - } - if (candidate_groups.find(key) == candidate_groups.end()) { - CommunicationOpInfo communication_op_info; - candidate_groups[key] = communication_op_info; - } - candidate_groups[key].communication_op_nodes.push_back(node->cast()); - candidate_groups[key].input_grad_size.push_back(input_grad_size_num); - candidate_groups[key].input_grad_time.push_back(input_grad_time_num); - } - } - // split candidate group to segments according to _group class member - bool changed = false; - for (auto &it : candidate_groups) { - if (it.second.communication_op_nodes.size() <= 1) { - continue; - } - auto first_node = it.second.communication_op_nodes[0]; - if (AnfAlgo::HasNodeAttr(kAttrIndex, first_node) && AnfAlgo::GetNodeAttr(first_node, kAttrIndex) > 0) { - std::stable_sort(it.second.communication_op_nodes.begin(), it.second.communication_op_nodes.end(), - [](const CNodePtr &a, const CNodePtr &b) { - return AnfAlgo::GetNodeAttr(a, kAttrIndex) < AnfAlgo::GetNodeAttr(b, kAttrIndex); - }); - } - size_t segment_num = 0; - std::vector segment_index; - if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { - if (DoFusion(func_graph, it.second, segment_num, segment_index)) { - changed = true; - } - } - } - return changed; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h deleted file mode 100644 index d00180f97f..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h +++ /dev/null @@ -1,80 +0,0 @@ -/** - * 
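// DoFusion above rewires consumers through TupleGetItem: the fused op returns
// one tuple element per original op, e.g.
//
//   before:  y0 = AllReduce(x0);  y1 = AllReduce(x1)
//   after:   t = AllReduce(x0, x1);  y0 = TupleGetItem(t, 0);  y1 = TupleGetItem(t, 1)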
Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -struct CommunicationOpInfo { - std::vector communication_op_nodes; - std::vector input_grad_size; - std::vector input_grad_time; -}; - -class CommunicationOpFusion : public Pass { - public: - explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1) - : Pass(name), op_name_(std::move(op_name)), groups_(groups) {} - ~CommunicationOpFusion() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num, - const std::vector &segment_index) const; - AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, - const CommunicationOpInfo &communication_op_info, size_t start_index, - size_t end_index) const; - bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index, const std::string &group) const; - std::string op_name_; - size_t groups_ = 1; -}; - -class AllReduceFusion : public CommunicationOpFusion { - public: - explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {} - ~AllReduceFusion() override = default; -}; - -class AllGatherFusion : public CommunicationOpFusion { - public: - explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {} - ~AllGatherFusion() override = default; -}; - -class BroadcastFusion : public CommunicationOpFusion { - public: - explicit BroadcastFusion(size_t groups = 1) : CommunicationOpFusion("broadcast_fusion", kBroadcastOpName, groups) {} - ~BroadcastFusion() override = default; -}; - -class ReduceScatterFusion : public CommunicationOpFusion { - public: - explicit ReduceScatterFusion(size_t groups = 1) - : CommunicationOpFusion("reduce_scatter_fusion", kReduceScatterOpName, groups) {} - ~ReduceScatterFusion() override = default; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc deleted file mode 100644 index af82f380f5..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ /dev/null @@ -1,111 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
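// The concrete passes above differ only in the op name they target and the
// default segment count. A hedged usage sketch (the PassManager API is assumed
// from the backend optimizer framework, not shown in this patch):
//
//   auto pm = std::make_shared<opt::PassManager>();
//   pm->AddPass(std::make_shared<opt::AllReduceFusion>(/*groups=*/2));
//   pm->AddPass(std::make_shared<opt::ReduceScatterFusion>());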
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/const_input_to_attr_registry.h" - -#include - -#include "utils/utils.h" -#include "utils/log_adapter.h" -#include "operator/ops.h" - -namespace mindspore { -namespace opt { -ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { - Register(prim::kPrimCast->name(), {1}); - Register(prim::kPrimAvgPoolGrad->name(), {0}); - Register(prim::kPrimConv2DBackpropInput->name(), {2}); - Register(prim::kPrimConv2DBackpropFilter->name(), {2}); - Register(prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), {1}); - Register(prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), {0}); - Register(prim::kPrimReshape->name(), {1}); - Register(prim::kPrimReduceMax->name(), {1}); - Register(prim::kPrimReduceMin->name(), {1}); - Register(prim::kPrimReduceSum->name(), {1}); - Register(prim::kPrimReduceMean->name(), {1}); - Register(prim::kPrimGatherV2->name(), {2}); - Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5}); - Register(prim::kPrimEmbeddingLookupCommGrad->name(), {1}); - Register(prim::kPrimSubscalar->name(), {1}); - Register(prim::kPrimTranspose->name(), {1}); - Register(prim::kPrimUnsortedSegmentSum->name(), {2}); - Register(prim::kPrimOneHot->name(), {1}); - Register(prim::kPrimConcat->name(), {0}); - Register(prim::kPrimCumSum->name(), {1}); - Register(prim::kPrimCumProd->name(), {1}); - Register(prim::kPrimReduceAll->name(), {1}); - Register(prim::kPrimUnsortedSegmentMin->name(), {2}); - Register(kSparseGatherV2, {2}); - Register(kUnsortedSegmentProdOpName, {2}); - Register(kSimpleMeanGradOpName, {1}); - Register(kMeanGradOpName, {1}); - Register(kSliceOpName, {1, 2}); - Register(kSliceGradOpName, {2, 3}); - Register(kTileOpName, {1}); - Register(kScatterNdOpName, {2}); - Register(kStridedSliceAssignOpName, {1, 2, 3}); - Register(kStridedSliceOpName, {1, 2, 3}); - Register(kFlattenGradOpName, {1}); - Register(kExpandDimsOpName, {1}); - Register(kSplitOpName, {0}); - Register(kErfOpName, {1}); - Register(kSparseApplyAdagradOpName, {2}); - Register(kResizeNearestNeighborGradOpName, {1}); - Register(kResizeNearestNeighborV2OpName, {1}); - Register(kResizeNearestNeighborV2GradOpName, {1}); - Register(kApplyRMSPropOpname, {5, 6, 7}); - Register(kResizeBilinearV2OpName, {1}); - Register(kReduceProdOpName, {1}); - Register(kCumprodOpName, {1}); - Register(kSpaceToBatchOpName, {1}); - Register(kBatchToSpaceOpName, {1}); - Register(kPadOpName, {1}); - Register(kPushOpName, {1}); -} - -ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() { - static ConstInputToAttrInfoRegistry instance; - return instance; -} - -void ConstInputToAttrInfoRegistry::Register(const ConstInputToAttrInfoRegister ®) { - auto op_name = reg.GetOpName(); - if (op_input_to_attr_map_.find(op_name) == op_input_to_attr_map_.end()) { - (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); - MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; - } -} - -void ConstInputToAttrInfoRegistry::Register(const std::string &op_name, - const std::unordered_set &input_attr_set) { - if (op_input_to_attr_map_.find(op_name) == 
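// Registration is first-wins: a later Register call for an op name that is
// already present is silently ignored, so the built-in entries take precedence.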
op_input_to_attr_map_.end()) { - ConstInputToAttrInfoRegister reg(op_name); - (void)reg.SetConstInputToAttr(input_attr_set); - (void)op_input_to_attr_map_.insert(make_pair(op_name, reg)); - MS_LOG(DEBUG) << op_name << " const2attr register successfully!"; - } -} - -bool ConstInputToAttrInfoRegistry::GetRegisterByOpName(const std::string &op_name, - ConstInputToAttrInfoRegister *reg) const { - if (op_input_to_attr_map_.find(op_name) != op_input_to_attr_map_.end()) { - *reg = op_input_to_attr_map_.at(op_name); - MS_LOG(DEBUG) << op_name << " const2attr find in registery."; - return true; - } - return false; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc deleted file mode 100644 index ec2d232584..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.cc +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "ir/primitive.h" -#include "utils/context/ms_context.h" -#include "utils/utils.h" -#include "abstract/abstract_value.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -const size_t strides_index = 5; - -bool GetStridesValues(const CNodePtr &strided_slice_grad, ValuePtrList *strides_values) { - MS_EXCEPTION_IF_NULL(strided_slice_grad); - if (strided_slice_grad->size() < 6) { - MS_LOG(DEBUG) << "Op strided_slice_grad's inputs size less than 6, graph not changed"; - return false; - } - auto strides_input = strided_slice_grad->input(strides_index); - MS_EXCEPTION_IF_NULL(strides_input); - auto strides_value_node = strides_input->cast(); - if (strides_value_node == nullptr) { - MS_LOG(DEBUG) << "strides is not a value node."; - return false; - } - auto value = strides_value_node->value(); - if (value == nullptr) { - MS_LOG(DEBUG) << "strides has no value."; - return false; - } - auto value_tuple = value->cast(); - if (value_tuple == nullptr) { - MS_LOG(DEBUG) << "strides is not a value tuple."; - return false; - } - *strides_values = value_tuple->value(); - return true; -} - -bool CheckValues(const ValuePtrList &strides_values) { - if (strides_values.empty()) { - MS_LOG(DEBUG) << "strides_values is empty"; - return false; - } - for (auto &value : strides_values) { - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - auto scalar = value->cast(); - MS_EXCEPTION_IF_NULL(scalar); - if (!scalar->isa()) { - MS_LOG(DEBUG) << "strides value is not a Integer"; - return false; - } - if (GetValue(scalar) != 1) { - MS_LOG(DEBUG) << "StridedSliceGrad has no 1 value"; - return false; - } - } else { - MS_LOG(DEBUG) << "The value " << value << "of tuple is not a scalar"; - return false; - } - } - return true; -} - -bool CheckAttrs(const CNodePtr 
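// CheckAttrs requires both new_axis_mask and shrink_axis_mask to be present
// and equal to 0; otherwise the const-to-attr conversion is skipped even when
// every stride is a literal 1.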
&strided_slice_grad) { - MS_EXCEPTION_IF_NULL(strided_slice_grad); - if (!AnfAlgo::HasNodeAttr(kAttrNewAxisMask, strided_slice_grad) || - !AnfAlgo::HasNodeAttr(kAttrShrinkAxisMask, strided_slice_grad)) { - MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not exist in cnode[" + strided_slice_grad->DebugString() + "]"; - return false; - } - auto new_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrNewAxisMask); - auto shrink_axis_mask = AnfAlgo::GetNodeAttr(strided_slice_grad, kAttrShrinkAxisMask); - if (new_axis_mask != 0 || shrink_axis_mask != 0) { - MS_LOG(INFO) << "new_axis_mask or shrink_axis_mask not equal 0"; - return false; - } - return true; -} -} // namespace - -const BaseRef ConstToAttrStridedSliceGradPass::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); - return VectorRef({strided_slice_grad_prim, Xs}); -} - -const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto strided_slice_grad = node->cast(); - MS_EXCEPTION_IF_NULL(strided_slice_grad); - - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - - if (ms_context->device_target() == kAscendDevice) { - if (!CheckAttrs(strided_slice_grad)) { - MS_LOG(INFO) << "Check strided_slice_grad's attrs failed, graph not changed"; - return nullptr; - } - - ValuePtrList strides_values; - if (!GetStridesValues(strided_slice_grad, &strides_values)) { - return nullptr; - } - - if (!CheckValues(strides_values)) { - MS_LOG(INFO) << "Check strides' values failed, graph not changed"; - return nullptr; - } - } - - ConstInputToAttr(strided_slice_grad, {1, 2, 3, 4}); - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h b/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h deleted file mode 100644 index 2e364244bf..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/const_to_attr_strided_slice_grad.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
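For reference, the guard in GetStridesValues/CheckValues above boils down to: the rewrite only fires when every element of the strides tuple is the integer 1. A standalone sketch of that check, with a plain std::vector<int> standing in for the ValuePtrList (an illustrative assumption, not MindSpore's API):

// Toy version of the strides guard used by ConstToAttrStridedSliceGradPass.
#include <iostream>
#include <vector>

bool AllStridesAreOne(const std::vector<int> &strides) {
  if (strides.empty()) {
    return false;  // an empty strides tuple gives us nothing to prove
  }
  for (int s : strides) {
    if (s != 1) {
      return false;  // any non-unit stride keeps the pass from firing
    }
  }
  return true;
}

int main() {
  std::cout << AllStridesAreOne({1, 1, 1}) << "\n";  // 1: convertible
  std::cout << AllStridesAreOne({1, 2, 1}) << "\n";  // 0: left unchanged
}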
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConstToAttrStridedSliceGradPass : public PatternProcessPass { - public: - explicit ConstToAttrStridedSliceGradPass(bool multigraph = true) - : PatternProcessPass("const_to_attr_strided_slice_grad_", multigraph) {} - ~ConstToAttrStridedSliceGradPass() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONST_TO_ATTR_STRIDED_SLICE_GRAD_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc deleted file mode 100644 index 38d629c415..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_const_input_to_attr.h" - -#include -#include -#include -#include - -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { - return nullptr; - } - std::vector todos; - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - kernel::GetValidKernelNodes(sub_graph, &todos); - } else { - todos.push_back(node); - } - - for (auto &t : todos) { - CNodePtr cnode = t->cast(); - ConstInputToAttrInfoRegister reg; - if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), ®)) { - continue; - } - ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h deleted file mode 100644 index e124ff8cf4..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
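The driver above looks each real kernel up in ConstInputToAttrInfoRegistry and moves the registered constant inputs into attributes. A minimal self-contained sketch of that mechanism, using a toy Node type and a hypothetical kConstInputToAttr table in place of the registry singleton:

// Registry-driven const-input-to-attr rewrite, reduced to a toy IR.
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::vector<int> inputs;           // toy: every input is a constant int
  std::map<std::string, int> attrs;  // attributes produced by the rewrite
};

const std::map<std::string, std::set<std::size_t>> kConstInputToAttr = {
    {"Cast", {1}}, {"Transpose", {1}}, {"GatherV2", {2}}};

void ConstInputToAttr(Node *node) {
  auto it = kConstInputToAttr.find(node->op);
  if (it == kConstInputToAttr.end()) {
    return;  // op not registered: leave the node unchanged
  }
  std::vector<int> kept;
  for (std::size_t i = 0; i < node->inputs.size(); ++i) {
    if (it->second.count(i) != 0) {
      node->attrs["input_" + std::to_string(i)] = node->inputs[i];  // input -> attr
    } else {
      kept.push_back(node->inputs[i]);
    }
  }
  node->inputs = kept;  // converted constants disappear from the input list
}

int main() {
  Node cast{"Cast", {7, 42}, {}};
  ConstInputToAttr(&cast);
  std::cout << cast.inputs.size() << " input left, input_1=" << cast.attrs["input_1"] << "\n";
}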
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ -#include -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertConstInputToAttr : public PatternProcessPass { - public: - explicit ConvertConstInputToAttr(bool multigraph = true) - : PatternProcessPass("convert_const_input_to_attr", multigraph) {} - ~ConvertConstInputToAttr() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - std::unordered_map> op_input_attr_map_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_ATTR_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc deleted file mode 100644 index b4f98cc6d7..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/pass/convert_const_input_to_tensor_input.h" - -#include -#include -#include - -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { - MS_EXCEPTION_IF_NULL(value_node); - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - -AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { - MS_EXCEPTION_IF_NULL(input_node); - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - tensor::TensorPtr tensor_ptr = nullptr; - if (value->isa()) { - tensor_ptr = ScalarToTensor(value->cast()); - } else if (value->isa()) { - tensor_ptr = CreateTupleTensor(value->cast()); - } else { - MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple"; - } - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "Create tensor failed"; - return nullptr; - } - auto tensor_input = std::make_shared(tensor_ptr); - MS_EXCEPTION_IF_NULL(tensor_input); - tensor_input->set_abstract(tensor_ptr->ToAbstract()); - if (kernel_graph != nullptr) { - tensor_input = kernel_graph->NewValueNode(tensor_input); - kernel_graph->AddValueNodeToGraph(tensor_input); - } else { - tensor_input = MakeValueNode(tensor_input); - } - tensor_input->set_scope(input_node->scope()); - return tensor_input; -} - -AnfNodePtr ConstInputToTensorInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs; - auto kernel_graph = func_graph->cast>(); - auto inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - // the first input is primitive node which is not the real input - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - if (IsValueNode(input_node) || IsValueNode(input_node)) { - auto tensor_input = CreateTensorInput(kernel_graph, input_node); - if (tensor_input == nullptr) { - new_inputs.push_back(input_node); - continue; - } - new_inputs.push_back(tensor_input); - need_update = true; - } else { - new_inputs.push_back(input_node); - } - } - if (need_update) { - MS_EXCEPTION_IF_NULL(func_graph); - auto new_cnode = func_graph->NewCNode(new_inputs); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - AnfAlgo::CopyNodeAttrs(cnode, new_cnode); - return new_cnode; - } - return 
nullptr; -} - -AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - std::vector> graph_rets; - kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); - - for (auto &t : todo) { - auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast()); - if (t_new_node != nullptr && t_new_node != t) { - (void)mng->Replace(t, t_new_node); - } - } - - return node; -} -} // namespace - -const AnfNodePtr ConvertConstInputToTensorInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || func_graph == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - return ProcessGraphKernelOp(node); - } else { - return ConstInputToTensorInput(func_graph, node->cast()); - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h deleted file mode 100644 index 1cc2bdf0ec..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertConstInputToTensorInput : public PatternProcessPass { - public: - explicit ConvertConstInputToTensorInput(bool multigraph = true) - : PatternProcessPass("convert_const_input_to_tensor_input", multigraph) {} - ~ConvertConstInputToTensorInput() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_CONST_INPUT_TO_TENSOR_INPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc deleted file mode 100644 index a03087c1a4..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc +++ /dev/null @@ -1,148 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" -#include "session/kernel_graph.h" -#include "kernel/common_utils.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace opt { -namespace { -bool MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return false; - } - - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - TypeId infer_data_type; - if (AnfAlgo::GetOutputTensorNum(value_node) == 0) { - infer_data_type = kTypeUnknown; - } else { - infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0); - } - kernel_build_info_builder->SetOutputsDeviceType(std::vector{infer_data_type}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get()); - return true; -} - -void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node, - std::vector *plant_inputs, std::vector *dyn_input_sizes) { - MS_EXCEPTION_IF_NULL(plant_inputs); - MS_EXCEPTION_IF_NULL(dyn_input_sizes); - MS_EXCEPTION_IF_NULL(graph); - auto output_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes->push_back(output_size); - std::vector convert_inputs; - auto kernel_graph = graph->cast(); - MS_EXCEPTION_IF_NULL(kernel_graph); - if (input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node); - } else { - for (size_t index = 0; index < output_size; ++index) { - auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)}, - {AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get()); - convert_inputs.emplace_back(tuple_get_item); - } - } - (void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs)); -} - -void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - MS_EXCEPTION_IF_NULL(graph); - auto &ori_args = cnode_ptr->inputs(); - if (ori_args.size() < 1) { - return; - } - std::vector plant_inputs; - std::vector dyn_input_sizes; - plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]); - for (size_t i = 1; i < ori_args.size(); ++i) { - auto input_node = ori_args[i]; - if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) { - auto input_size = AnfAlgo::GetOutputTensorNum(input_node); - dyn_input_sizes.push_back(input_size); - auto cnode = input_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto inputs = cnode->inputs(); - for (size_t j 
= 1; j < inputs.size(); ++j) { - MS_EXCEPTION_IF_NULL(inputs[j]); - if (IsValueNode(inputs[j])) { - auto success = MakeValueNode(inputs[j]); - if (!success) { - MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString(); - } - } - plant_inputs.push_back(inputs[j]); - } - } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { - ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); - } else { - dyn_input_sizes.push_back(-1); - plant_inputs.push_back(input_node); - } - } - // If there is dynamic input, set the dyn_input_sizes as an attribute and update the inputs. - if (std::any_of(dyn_input_sizes.begin(), dyn_input_sizes.end(), [](int s) { return s >= 0; })) { - AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), cnode_ptr); - cnode_ptr->set_inputs(plant_inputs); - } -} -} // namespace - -const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const { - VarPtr V = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa() || !AnfAlgo::IsRealKernel(node)) { - return nullptr; - } - if (AnfAlgo::IsGraphKernel(node)) { - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - std::vector todos; - kernel::GetValidKernelNodes(sub_graph, &todos); - for (auto &t : todos) { - ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast()); - } - } else { - ConvertMakeTupleInputToPlantInputs(func_graph, node->cast()); - } - return node; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h deleted file mode 100644 index b3d8e25d6e..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
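The flattening above records, for each original input slot, either the tuple's element count or -1, and attaches the result as the dyn_input_sizes attribute. A self-contained sketch of that bookkeeping, with ints standing in for AnfNodePtr inputs:

// dyn_input_sizes bookkeeping for tuple-to-dynamic-input conversion.
#include <cstddef>
#include <iostream>
#include <vector>

struct FlatInputs {
  std::vector<int> plant_inputs;     // flattened inputs
  std::vector<int> dyn_input_sizes;  // per original slot: tuple size or -1
};

FlatInputs Flatten(const std::vector<std::vector<int>> &args, const std::vector<bool> &is_tuple) {
  FlatInputs out;
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (is_tuple[i]) {
      out.dyn_input_sizes.push_back(static_cast<int>(args[i].size()));
      out.plant_inputs.insert(out.plant_inputs.end(), args[i].begin(), args[i].end());
    } else {
      out.dyn_input_sizes.push_back(-1);     // ordinary input, not expanded
      out.plant_inputs.push_back(args[i][0]);
    }
  }
  return out;
}

int main() {
  auto r = Flatten({{10}, {1, 2, 3}, {20}}, {false, true, false});
  std::cout << r.plant_inputs.size() << " flat inputs; sizes:";
  for (int s : r.dyn_input_sizes) std::cout << ' ' << s;  // -1 3 -1
  std::cout << "\n";
}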
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ - -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertTupleInputToDynamicInput : public PatternProcessPass { - public: - explicit ConvertTupleInputToDynamicInput(bool multigraph = true) - : PatternProcessPass("convert_tuple_input_to_dynamic_input", multigraph) {} - - ~ConvertTupleInputToDynamicInput() override = default; - - const BaseRef DefinePattern() const override; - - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_CONVERT_TUPLE_INPUT_TO_DYNAMIC_INPUT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc deleted file mode 100644 index a5e51411bc..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/pass/convert_tuple_output_to_maketuple.h" - -#include -#include - -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -namespace { -CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - MS_EXCEPTION_IF_NULL(graph); - std::vector convert_inputs = {cnode_ptr->input(0)}; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode_ptr); ++index) { - auto input_node = AnfAlgo::GetInputNode(cnode_ptr, index); - if (AnfAlgo::IsTupleOutput(input_node)) { - std::vector types; - std::vector> shapes; - std::vector make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)}; - for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) { - make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index)); - types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index)); - shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index)); - } - auto make_tuple = graph->NewCNode(make_tuple_inputs_list); - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get()); - convert_inputs.emplace_back(make_tuple); - } else { - convert_inputs.push_back(input_node); - } - } - return graph->NewCNode(convert_inputs); -} -} // namespace - -const BaseRef ConvertTupleOutputToMaketuple::DefinePattern() const { - VarPtr V = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - return nullptr; - } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); - })) { - return ConvertTupleInputToMakeTuple(func_graph, cnode); - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h deleted file mode 100644 index a16ffaf674..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H -#define MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H -#include -#include - -#include "ir/anf.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class ConvertTupleOutputToMaketuple : public PatternProcessPass { - public: - explicit ConvertTupleOutputToMaketuple(bool multigraph = true) - : PatternProcessPass("convert_tuple_output_to_maketuple", multigraph) {} - - ~ConvertTupleOutputToMaketuple() override = default; - - const BaseRef DefinePattern() const override; - - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CONVERT_TUPLE_OUTPUT_TO_MAKETUPLE_H diff --git a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc deleted file mode 100644 index 4d3dcfccc0..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.cc +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/eliminate_redundant_op.h" -#include -#include -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" -#include "operator/ops.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace opt { -using KernelWithIndex = std::pair; -namespace { -CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector *pass_vector) { - MS_EXCEPTION_IF_NULL(pass_vector); - if (node == nullptr || !node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return cnode; - } - - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - auto temp_node = cnode->input(index + IntToSize(1)); - MS_EXCEPTION_IF_NULL(temp_node); - pass_vector->push_back(make_pair(cnode, index + IntToSize(1))); - return GetRealPrevCNode(temp_node, 0, pass_vector); - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - auto input2 = cnode->input(2); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return GetRealPrevCNode(cnode->input(1), IntToSize(item_idx), pass_vector); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - pass_vector->push_back(make_pair(cnode, IntToSize(1))); - return GetRealPrevCNode(cnode->input(1), 0, pass_vector); - } else { - return nullptr; - } -} - -bool TransOpEliminateCondition(const CNodePtr &, const CNodePtr &) { return true; } - -bool CastEliminateCondition(const CNodePtr &node1, const CNodePtr 
&node2) { - return HasSymmetricalKernelInfo(node1, node2); -} - -bool TransDataOpEliminateCondition(const CNodePtr &node1, const CNodePtr &node2) { - return AnfAlgo::GetInputFormat(node1, 0) == AnfAlgo::GetOutputFormat(node2, 0) && - AnfAlgo::GetOutputFormat(node1, 0) == AnfAlgo::GetInputFormat(node2, 0); -} - -const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const CNodePtr &prev_cnode, - std::vector *pass_vector) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(pass_vector); - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - - bool has_depend_node = false; - bool has_node_used_more_than_once = false; - auto &users = manager->node_users(); - - auto pass_size = pass_vector->size(); - for (size_t idx = 1; idx <= pass_size - 1; ++idx) { - auto nd = (*pass_vector)[idx].first; - if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || - AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) { - has_depend_node = true; - } - if (users[nd].size() >= 2) { - has_node_used_more_than_once = true; - } - } - - // when no depend node and no node used more than once, no need to rebuild the pass nodes - if (!has_depend_node) { - return prev_cnode->input(1); - } else if (!has_node_used_more_than_once) { - (void)manager->Replace(prev_cnode, prev_cnode->input(1)); - return cnode->input(1); - } else { // rebuild the pass nodes - for (size_t idx = pass_size - 2; idx > 0; --idx) { - auto new_node = func_graph->NewCNode((*pass_vector)[idx].first->inputs()); - new_node->set_input((*pass_vector)[idx].second, - (*pass_vector)[idx + 1].first->input((*pass_vector)[idx + 1].second)); - (*pass_vector)[idx].first = new_node; - } - return (*pass_vector)[1].first; - } -} -} // namespace - -void EliminateRedundantOp::Init() { - (void)redundant_process_map_.emplace(std::pair( - kFour2FiveOpName, std::pair(kFive2FourOpName, TransOpEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - kFive2FourOpName, std::pair(kFour2FiveOpName, TransOpEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - prim::kPrimCast->name(), std::pair(prim::kPrimCast->name(), CastEliminateCondition))); - (void)redundant_process_map_.emplace(std::pair( - kTransDataOpName, std::pair(kTransDataOpName, TransDataOpEliminateCondition))); -} - -const AnfNodePtr EliminateRedundantOp::DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { - // match the first name - auto name1 = AnfAlgo::GetCNodeName(cnode); - auto it = redundant_process_map_.find(name1); - if (it == redundant_process_map_.end()) { - return nullptr; - } - std::vector pass_vector; - pass_vector.push_back(make_pair(cnode, 1)); - auto prev_cnode = GetRealPrevCNode(cnode->input(1), 0, &pass_vector); - if (prev_cnode == nullptr) { - return nullptr; - } - // match the second name - auto name2 = AnfAlgo::GetCNodeName(prev_cnode); - if (name2 != it->second.first) { - return nullptr; - } - // match condition - auto condition_func = it->second.second; - if (condition_func == nullptr) { - return nullptr; - } - if (!condition_func(cnode, prev_cnode)) { - return nullptr; - } - - return ProcessMatchedNodes(func_graph, cnode, prev_cnode, &pass_vector); -} - -const AnfNodePtr EliminateRedundantOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr || func_graph == nullptr) { - return nullptr; - } - - if 
(AnfAlgo::IsGraphKernel(node)) { - // do eliminate for ops in graph kernel. - auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(sub_graph); - auto mng = sub_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - std::vector todo; - kernel::GetValidKernelNodes(sub_graph, &todo); - for (auto &t : todo) { - CNodePtr t_cnode = t->cast(); - MS_EXCEPTION_IF_NULL(t_cnode); - auto t_new_node = DoEliminate(sub_graph, t_cnode); - if (t_new_node != nullptr && t_new_node != t) { - (void)mng->Replace(t, t_new_node); - } - } - return node; - } - // do eliminate for single op. - return DoEliminate(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h b/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h deleted file mode 100644 index c44190f645..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/eliminate_redundant_op.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ - -#include -#include -#include -#include -#include "ir/anf.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -using ConditionFunc = std::function; -using RedundantOpPair = std::pair; - -class EliminateRedundantOp : public PatternProcessPass { - public: - explicit EliminateRedundantOp(bool multigraph = true) : PatternProcessPass("eliminate_redundant_op", multigraph) { - Init(); - } - ~EliminateRedundantOp() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - - private: - void Init(); - const AnfNodePtr DoEliminate(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; - std::unordered_map redundant_process_map_; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ELIMINATE_REDUNDANT_OP_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc b/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc deleted file mode 100644 index 3b566b4f7c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.cc +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
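EliminateRedundantOp above cancels an op against a matching predecessor (Four2Five/Five2Four, Cast/Cast, TransData/TransData) when a per-pair condition holds. A simplified sketch over a linear op list rather than a graph; the always-true lambda stands in for TransOpEliminateCondition, and strings stand in for CNodes:

// Pair-table driven elimination of mutually inverse ops.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using Cond = std::function<bool(const std::string &, const std::string &)>;

int main() {
  std::unordered_map<std::string, std::pair<std::string, Cond>> table = {
      {"Four2Five", {"Five2Four", [](const std::string &, const std::string &) { return true; }}},
      {"Five2Four", {"Four2Five", [](const std::string &, const std::string &) { return true; }}},
  };
  std::vector<std::string> ops = {"Conv", "Five2Four", "Four2Five", "ReLU"};
  std::vector<std::string> out;
  for (const auto &op : ops) {
    auto it = table.find(op);
    if (it != table.end() && !out.empty() && out.back() == it->second.first &&
        it->second.second(op, out.back())) {
      out.pop_back();  // the pair cancels; neither op is emitted
    } else {
      out.push_back(op);
    }
  }
  for (const auto &op : out) std::cout << op << ' ';  // Conv ReLU
  std::cout << "\n";
}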
- */ - -#include "pre_activate/pass/erase_visit_attr.h" -#include -#include -#include "kernel/common_utils.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -const BaseRef EraseVisitAttr::DefinePattern() const { - std::shared_ptr V = std::make_shared(Visited); - std::shared_ptr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr EraseVisitAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node != nullptr && AnfAlgo::IsRealCNodeKernel(node)) { - if (AnfAlgo::IsGraphKernel(node)) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - std::vector todos; - kernel::GetValidKernelNodes(fg, &todos); - for (auto &t : todos) { - AnfAlgo::EraseNodeAttr(kAttrVisited, t); - } - } - AnfAlgo::EraseNodeAttr(kAttrVisited, node); - } else { - AnfAlgo::EraseNodeAttr(kAttrVisited, node); - } - return nullptr; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h b/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h deleted file mode 100644 index a986aad83a..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/erase_visit_attr.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ - -#include -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class EraseVisitAttr : public PatternProcessPass { - public: - explicit EraseVisitAttr(bool multigraph = true) : PatternProcessPass("erase_visit_attr", multigraph) {} - ~EraseVisitAttr() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ERASE_VISIT_ATTR_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc b/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc deleted file mode 100644 index 84edd5c5e2..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_basic.cc +++ /dev/null @@ -1,222 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/pass/fuse_basic.h" -#include "pre_activate/pass/fuse_graph_kernel.h" - -#include -#include -#include -#include -#include -#include - -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "vm/segment_runner.h" -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -namespace { -std::vector get_fusable_basic_ops(bool is_before_kernel_select) { - std::vector fusable_basic_ops = {prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, - prim::kPrimExpandDims}; - if (!is_before_kernel_select) { - fusable_basic_ops.push_back(prim::kPrimCast); - } - return fusable_basic_ops; -} - -IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); - bool is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - - return is_fusable ? FOLLOW : EXCLUDE; -} - -std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { - GraphKernelInfo info; - info.is_before_kernel_select = is_before_kernel_select; - // Search fusable nodes according input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes, false); - } - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) { - AnfNodeSet outputs_set; - for (auto out : *outputs) { - outputs_set.insert(out); - } - - AnfNodePtrList vir_outputs; - std::unordered_map eqv; - auto fg_outputs = fg->output(); - if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) { - auto cnode = fg_outputs->cast(); - for (size_t i = 1; i < cnode->size(); ++i) { - vir_outputs.push_back(cnode->input(i)); - } - } else { - vir_outputs.push_back(fg_outputs); - } - - if (vir_outputs.size() != outputs->size()) { - MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output"; - } - bool has_erase_outs = false; - size_t index = -1; - for (auto it = outputs->begin(); it != outputs->end();) { - index++; - auto out = *it; - eqv[out] = vir_outputs[index]; - auto users = mng->node_users()[out]; - bool is_only_control_depend_use = true; - std::vector control_depend_use_index; - std::vector control_depend_nodes; - AnfNodePtr use_out = nullptr; - for (auto &user : users) { - auto use_node = user.first; - if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - is_only_control_depend_use = false; - continue; - } - if (outputs_set.count(use_node) != 0) { - use_out = use_node; - } - - if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) { - control_depend_nodes.push_back(use_node->cast()); - control_depend_use_index.push_back(user.second); - } - } - - if (is_only_control_depend_use && !control_depend_nodes.empty()) { - MS_EXCEPTION_IF_NULL(use_out); - it = outputs->erase(it); - for (size_t i = 0; i < 
control_depend_nodes.size(); ++i) { - auto control_depend_node = control_depend_nodes[i]; - std::vector new_control_depend_inputs; - for (size_t j = 0; j < control_depend_node->size(); ++j) { - if (j == control_depend_use_index[i]) { - new_control_depend_inputs.push_back(use_out); - } else { - new_control_depend_inputs.push_back(control_depend_node->input(j)); - } - } - auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs); - mng->Replace(control_depend_node, new_control_depend); - has_erase_outs = true; - } - } else { - it++; - } - } - - if (!has_erase_outs) { - return; - } - - AnfNodePtr fg_new_output; - if (outputs->size() > 1) { - std::vector output_args; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(std::begin(*outputs), std::end(*outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); - // Set output for AnfGraph - fg_new_output = fg->NewCNode(output_args); - } else { - fg_new_output = eqv[(*outputs)[0]]; - } - fg->set_output(fg_new_output, true); -} - -void FuseBasic(const std::shared_ptr &kernel_graph, const std::vector &todos, - std::unordered_set *fused_ops, bool is_before_kernel_select) { - auto mng = kernel_graph->manager(); - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto node = (*iter)->cast(); - if (node == nullptr) { - continue; - } - if (fused_ops->count(node)) { - continue; - } - auto fusable_basic_ops = get_fusable_basic_ops(is_before_kernel_select); - bool is_basic_op = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - if (!is_basic_op || !kernel_graph->nodes().contains(node)) { - continue; - } - - auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); - if (fuse_nodes.size() <= 1) { - continue; - } - - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); - RemoveControlDependOut(fg, &outputs, mng); - auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs, is_before_kernel_select); - - ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); - - // Set graph kernel attr - std::string fuse_op_name = ""; - for (auto &fuse_node : fuse_nodes) { - fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_"; - } - fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end()); - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name)); - } -} -} // namespace - -void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - std::unordered_set fused_ops; - auto todos = TopoSort(kernel_graph->get_return()); - std::reverse(todos.begin(), todos.end()); - FuseBasic(kernel_graph, todos, &fused_ops, is_before_kernel_select); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_basic.h b/mindspore/ccsrc/pre_activate/pass/fuse_basic.h deleted file mode 100644 index fbbf5d9937..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_basic.h +++ /dev/null @@ -1,29 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the 
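FuseBasic above grows a fusion group by walking from a seed node toward its producers and keeping only ops from the fusable set; DeepLinkedGraphSearch and the AnfNode graph are not modeled here. A standalone sketch of that forward search on a toy adjacency-list graph:

// Greedy forward collection of fusable producers from a seed node.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::vector<int> inputs;  // indices of producer ops
};

void Collect(const std::vector<Op> &g, int cur, const std::set<std::string> &fusable,
             std::set<int> *group) {
  if (!group->insert(cur).second) return;  // already visited
  for (int in : g[cur].inputs) {
    if (fusable.count(g[in].name) != 0) {  // non-fusable producers are excluded
      Collect(g, in, fusable, group);
    }
  }
}

int main() {
  // 0:Mul and 1:MatMul both feed 2:TensorAdd; seed the search at TensorAdd.
  std::vector<Op> g = {{"Mul", {}}, {"MatMul", {}}, {"TensorAdd", {0, 1}}};
  std::set<int> group;
  Collect(g, 2, {"TensorAdd", "Mul", "Sub", "ExpandDims"}, &group);
  std::cout << group.size() << " nodes fused\n";  // 2: MatMul is excluded
}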
License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ - -#include -#include "pre_activate/common/optimizer.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -void FuseBasic(const std::shared_ptr &kernel_graph, bool is_before_kernel_select); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_BASIC_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc b/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc deleted file mode 100644 index 0e287587a2..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.cc +++ /dev/null @@ -1,562 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/pass/fuse_graph_kernel.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "operator/ops.h" -#include "utils/utils.h" -#include "utils/graph_utils.h" -#include "pre_activate/common/helper.h" -#include "session/anf_runtime_algorithm.h" -#include "vm/segment_runner.h" -#include "debug/draw.h" -#include "debug/anf_ir_dump.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace opt { -std::vector get_fusable_basic_ops(bool is_before_kernel_select) { - std::vector fusable_basic_ops = { - prim::kPrimAddN, prim::kPrimTensorAdd, prim::kPrimMul, prim::kPrimSub, prim::kPrimMaximum, - prim::kPrimMinimum, prim::kPrimNeg, prim::kPrimRealDiv, prim::kPrimPow, prim::kPrimSqrt, - prim::kPrimReciprocal, prim::kPrimExpandDims, prim::kPrimLessEqual}; - if (!is_before_kernel_select) { - fusable_basic_ops.push_back(prim::kPrimCast); - } - return fusable_basic_ops; -} - -std::vector get_fusable_basic_ops_with_reduce(bool is_before_kernel_select) { - std::vector fusable_basic_ops_with_reduce; - if (!is_before_kernel_select) { - fusable_basic_ops_with_reduce.push_back(prim::kPrimCast); - } - return fusable_basic_ops_with_reduce; -} - -std::vector get_reduce_ops() { - std::vector reduce_ops = {prim::kPrimReduceSum, prim::kPrimReduceMean, prim::kPrimReduceMin, - prim::kPrimReduceMax, prim::kPrimReduceAll}; - return reduce_ops; -} - -void GetGraphKernelInfo(const FuncGraphPtr fg, GraphKernelInfo *info) { - MS_EXCEPTION_IF_NULL(fg); - auto reduce_ops = get_reduce_ops(); - const auto &nodes = fg->nodes(); - info->op_type = ELEWISE; - info->cal_step = -1; - info->reduce_op_num = 0; - for (auto node : nodes) { - auto cnode = node->cast(); - if (cnode == nullptr) { - continue; - } - 
info->cal_step++; - auto prim = GetValueNode(cnode->input(0)); - if (prim != nullptr) { - bool is_reudce = std::any_of(reduce_ops.begin(), reduce_ops.end(), [&prim](const PrimitivePtr &op) { - return op->hash() == prim->hash() && op->name() == prim->name(); - }); - if (is_reudce) { - info->op_type = REDUCE; - info->reduce_op_num++; - } - } - } -} - -bool IsFuse(const GraphKernelInfo &info, const AnfNodePtr &node) { - auto fusable_basic_ops = get_fusable_basic_ops(info.is_before_kernel_select); - auto fusable_basic_ops_with_reduce = get_fusable_basic_ops_with_reduce(info.is_before_kernel_select); - bool is_fusable = false; - if (info.op_type == REDUCE && - (info.cal_step >= MAX_REDUCE_OP_FUSION_CAL_STEP || info.reduce_op_num >= MAX_REDUCE_OP_FUSION_REDUCE_NUM)) { - is_fusable = std::any_of(fusable_basic_ops_with_reduce.begin(), fusable_basic_ops_with_reduce.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - } else { - is_fusable = std::any_of(fusable_basic_ops.begin(), fusable_basic_ops.end(), - [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }); - } - - return is_fusable; -} - -IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - bool is_fusable = IsFuse(info, node); - return is_fusable ? FOLLOW : EXCLUDE; -} - -IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const GraphKernelInfo &info, - const AnfNodePtr &node) { - if (cur_node == node) { - return FOLLOW; - } - if (AnfAlgo::IsGraphKernel(node)) { - auto cnode = node->cast(); - auto fg = GetValueNode(cnode->input(kAnfPrimitiveIndex)); - auto fg_attr_val = fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - MS_EXCEPTION_IF_NULL(fg_attr_val); - auto fg_attr = GetValue(fg_attr_val); - if (fg_attr == kApplyMomentumOpName) { - return FOLLOW; - } - return EXCLUDE; - } - if (!IsPrimitiveCNode(node)) { - return EXCLUDE; - } - - bool is_fusable = IsFuse(info, node); - return is_fusable ? 
FOLLOW : EXCLUDE; -} - -bool CheckCircle(const std::set &fused_op_set, const AnfNodePtr &check_node, - std::set *cached_unconnected_set) { - if (!check_node->isa() || AnfAlgo::IsGraphKernel(check_node)) { - return false; - } - - auto cnode = check_node->cast(); - const auto &inputs = cnode->inputs(); - // there is a input not in fused_op_set, but the input depends on the fused_op_set - bool has_circle = false; - for (auto input : inputs) { - if (input->isa() && !fused_op_set.count(input)) { - std::set done; - std::vector todos = {input}; - while (!todos.empty()) { - auto node = todos.back(); - todos.pop_back(); - if (done.count(node) || cached_unconnected_set->count(node)) { - continue; - } - - done.insert(node); - if (fused_op_set.count(node)) { - has_circle = true; - break; - } - - if (node->isa()) { - auto cnode_ptr = node->cast(); - for (auto it : cnode_ptr->inputs()) { - if (it->isa()) { - todos.push_back(it); - } - } - } - } - - if (has_circle) { - return true; - } - cached_unconnected_set->insert(done.begin(), done.end()); - } - } - - return false; -} - -bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) { - if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { - auto &inputs = out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - real_outs->push_back(inputs[i]); - } - return true; - } - - if (AnfAlgo::GetCNodeFuncGraphPtr(out) != nullptr) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); - auto fg_out = fg->output(); - if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) { - auto inputs = fg_out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - real_outs->push_back(inputs[i]); - } - return true; - } - } - return false; -} - -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward) { - std::set cached_unconnected_set; - std::set fused_op_set(fused_op.begin(), fused_op.end()); - auto include = [&fused_op_set](const AnfNodePtr &node) { - if (fused_op_set.count(node)) { - return FOLLOW; - } - return EXCLUDE; - }; - for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) { - bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set); - // delete the circle node and the node which depend on the circle node in fused op - if (has_circle) { - auto mng = (*iter)->func_graph()->manager(); - std::vector erase_nodes; - if (is_backward) { - erase_nodes = DeepUsersSearch(*iter, include, mng); - } else { - erase_nodes = DeepLinkedGraphSearch(*iter, include); - } - for (auto erase_node : erase_nodes) { - fused_op_set.erase(erase_node); - } - } - } - - std::vector res; - for (auto node : fused_op) { - if (fused_op_set.count(node)) { - res.push_back(node); - } - } - return res; -} - -void TopoSortForNodeList(std::vector *lst) { - if (lst->size() < 2) { - return; - } - - std::vector res; - std::set node_sets(lst->begin(), lst->end()); - std::map> ins; - std::map> outs; - std::queue q; - for (auto node : *lst) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (auto input : cnode->inputs()) { - if (!node_sets.count(input)) { - continue; - } - // out_degree - outs[input].insert(node); - // in_degree - ins[node].insert(input); - } - if (!ins.count(node)) { - ins[node] = {}; - } - } - - for (auto p : ins) { - if (p.second.size() == 0) { - q.push(p.first); - } - } - - while (!q.empty()) { - auto node = q.front(); - q.pop(); - res.push_back(node); - if (!outs.count(node)) { - continue; - } - for (auto out : outs[node]) { - if (!ins.count(out)) { - continue; - } - ins[out].erase(node); - if 
- -std::vector FindFuseCNodes(const CNodePtr &cnode, bool is_before_kernel_select) { - auto func_graph = cnode->func_graph(); - auto graph_kernel_g = GetValueNode(cnode->input(0)); - GraphKernelInfo info; - info.is_before_kernel_select = is_before_kernel_select; - GetGraphKernelInfo(graph_kernel_g, &info); - auto mng = func_graph->manager(); - // Search fusable nodes according to the input direction. - auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, info, std::placeholders::_1); - auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward); - std::reverse(used_nodes.begin(), used_nodes.end()); - // Search fusable nodes according to the output direction. - auto include_func_backward = std::bind(IncludeFusedBasicOpBackward, cnode, info, std::placeholders::_1); - auto user_nodes = DeepUsersSearch(cnode, include_func_backward, mng); - - used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end()); - if (used_nodes.size() > 1) { - used_nodes = RemoveCircle(used_nodes); - } - TopoSortForNodeList(&used_nodes); - return used_nodes; -} - -AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) { - auto out_spec = node->abstract(); - if (out_spec->isa()) { - return out_spec->cast()->elements()[output_idx]; - } - return out_spec; -} - -AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, - const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, - bool is_before_kernel_select) { - auto func_node = NewValueNode(fg); - std::vector fn_inputs; - fn_inputs.push_back(func_node); - fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end()); - auto fuse_cnode = kernel_graph->NewCNode(fn_inputs); - // Set output abstract - if (outputs.size() > 1) { - std::vector out_specs; - for (size_t i = 0; i < outputs.size(); ++i) { - out_specs.push_back(outputs[i]->abstract()); - } - auto out_spec = std::make_shared(out_specs); - fuse_cnode->set_abstract(out_spec); - } else { - fuse_cnode->set_abstract(outputs[0]->abstract()); - } - // Set parameter abstract. - for (size_t i = 0; i < inputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); - auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); - fg->parameters()[i]->set_abstract(input_abs); - if (is_before_kernel_select) { - fg->parameters()[i]->set_kernel_info(std::make_shared()); - } - } - // Set kernel info.
- if (!is_before_kernel_select) { - std::vector graph_input_format; - std::vector graph_input_type; - std::vector graph_output_format; - std::vector graph_output_type; - for (size_t i = 0; i < inputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0); - auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); - graph_input_format.push_back(input_format); - auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); - graph_input_type.push_back(input_type); - auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second); - fg->parameters()[i]->set_abstract(input_abs); - } - auto new_outputs = outputs; - if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) { - std::vector real_outs; - if (IsMakeTupleOut(outputs[0], &real_outs)) { - new_outputs = real_outs; - } - } - for (size_t i = 0; i < new_outputs.size(); ++i) { - auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0); - auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); - auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); - graph_output_format.push_back(output_format); - graph_output_type.push_back(output_type); - } - kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder; - graph_info_builder.SetInputsFormat(graph_input_format); - graph_info_builder.SetInputsDeviceType(graph_input_type); - graph_info_builder.SetOutputsFormat(graph_output_format); - graph_info_builder.SetOutputsDeviceType(graph_output_type); - graph_info_builder.SetProcessor(kernel::Processor::AICORE); - graph_info_builder.SetKernelType(KernelType::AKG_KERNEL); - graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE); - auto graph_selected_info = graph_info_builder.Build(); - AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, fuse_cnode.get()); - } - return fuse_cnode; -} - -void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - // single output - if (outputs.size() == 1) { - mng->Replace(outputs[0], new_fuse_cnode); - return; - } - - std::vector fn_inputs; - for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) { - AnfNodePtrList real_outs; - // not a make_tuple output, replace it directly - if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) { - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(MakeValue(SizeToInt(out_idx)))); - auto new_out = kernel_graph->NewCNode(fn_inputs); - new_out->set_abstract(outputs[out_idx]->abstract()); - mng->Replace(outputs[out_idx], new_out); - continue; - } - - // the output is a make_tuple, modify the get_item node's index value - auto users = mng->node_users()[outputs[out_idx]]; - for (auto &user : users) { - auto use_node = user.first; - if (use_node->isa() && (IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem))) { - auto get_item_cnode = use_node->cast(); - auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(value_input); - auto value_node = value_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - int new_item_idx = SizeToInt(out_idx) + item_idx; - fn_inputs.clear(); - fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem)); - fn_inputs.push_back(new_fuse_cnode); - fn_inputs.push_back(NewValueNode(new_item_idx)); - auto new_out = kernel_graph->NewCNode(fn_inputs); - new_out->set_abstract(get_item_cnode->abstract()); - mng->Replace(get_item_cnode, new_out); - } - } - } -}
- -AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr *fg, FuncGraphManagerPtr *mng) { - AnfNodePtrList outs; - auto out_node = (*fg)->output(); - if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) { - std::vector output_args; - auto out_cnode = out_node->cast(); - for (auto out : out_cnode->inputs()) { - if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) { - auto inputs = out->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - output_args.push_back(inputs[i]); - } - } else { - output_args.push_back(out); - } - } - if (output_args.size() != out_cnode->inputs().size()) { - auto new_out = (*fg)->NewCNode(output_args); - (*mng)->Replace(out_node, new_out); - } - - for (size_t i = 1; i < output_args.size(); ++i) { - outs.push_back(output_args[i]); - } - return outs; - } - - outs.push_back(out_node); - return outs; -} - -AnfNodePtrList GetExpandOuts(const AnfNodePtrList &outs) { - AnfNodePtrList res; - if (outs.size() <= 1) { - return outs; - } - - for (auto out : outs) { - AnfNodePtrList real_outs; - if (IsMakeTupleOut(out, &real_outs)) { - res.insert(res.end(), real_outs.begin(), real_outs.end()); - continue; - } - res.push_back(out); - } - return res; -} - -void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto mng = kernel_graph->manager(); - if (mng == nullptr) { - mng = Manage(kernel_graph, true); - kernel_graph->set_manager(mng); - } - auto &todos = kernel_graph->execution_order(); - for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { - auto node = *iter; - if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) { - continue; - } - - auto origin_fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - auto fg_attr = origin_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); - if (fg_attr != nullptr) { - auto fg_name = GetValue(fg_attr); - if (graph_kernel_black_list.count(fg_name) != 0) { - continue; - } - } - - auto fuse_nodes = FindFuseCNodes(node, is_before_kernel_select); - if (fuse_nodes.size() <= 1) { - continue; - } - - FuncGraphPtr fg; - AnfNodePtrList inputs; - AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes); - - // Remove nested make_tuple in outputs - auto expand_out = GetExpandOuts(outputs); - auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, expand_out, is_before_kernel_select); - - ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs); - - // Inline the original graph kernels - auto cnodes = fg->GetOrderedCnodes(); - for (const auto &n : cnodes) { - if (!AnfAlgo::IsGraphKernel(n)) { - continue; - } - auto graph_kernel_g = GetValueNode(n->input(0)); - AnfNodePtrList ins; - ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end()); - auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope()); - mng->Replace(n, out); - } - - EliminateMakeTuple(&fg, &mng); - // Set the graph kernel attr - auto ori_fg = GetValueNode(node->input(kAnfPrimitiveIndex)); - fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, ori_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); - } -} -} // namespace opt -} // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h b/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h deleted file mode index a5a26765a3..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/fuse_graph_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ - -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ - -#include -#include -#include -#include -#include "pre_activate/common/optimizer.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace opt { -enum GraphKernelType { - ELEWISE = 0, // only contains elewise basic ops - REDUCE, // contains reduce ops - CUBE, // contains cube ops -}; -struct GraphKernelInfo { - GraphKernelType op_type = ELEWISE; - bool is_before_kernel_select = false; - int reduce_op_num = 0; - int cal_step = 0; -}; - -// when a reduce graph kernel's cal step is greater than this number, do not fuse -const int MAX_REDUCE_OP_FUSION_CAL_STEP = 5; -// when the number of reduce ops contained in a reduce graph kernel is greater than this number, do not fuse -const int MAX_REDUCE_OP_FUSION_REDUCE_NUM = 2; - -const std::set graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward", - "LambNextMV", "LambUpdateWithLR"}; - -std::vector RemoveCircle(const std::vector &fused_op, bool is_backward = true); - -void TopoSortForNodeList(std::vector *lst); - -AnfNodePtr CreateNewFuseCNode(const std::shared_ptr &kernel_graph, const FuncGraphPtr &fg, - const AnfNodePtrList &inputs, const AnfNodePtrList &outputs, - bool is_before_kernel_select); - -void ReplaceNewFuseCNode(const std::shared_ptr &kernel_graph, const AnfNodePtr &new_fuse_cnode, - const AnfNodePtrList &outputs); - -void FuseGraphKernel(const std::shared_ptr &kernel_graph, bool is_before_kernel_select = false); -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_FUSE_GRAPH_KERNEL_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc b/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc deleted file mode index af16017a7c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.cc +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "pre_activate/pass/getitem_tuple.h" - -#include -#include "operator/ops.h" -#include "utils/utils.h" -#include "pre_activate/common/helper.h" - -namespace mindspore { -namespace opt { -namespace { -bool IsC(const BaseRef &n) { - MS_EXCEPTION_IF_NULL(n); - if (utils::isa(n)) { - AnfNodePtr in = utils::cast(n); - MS_EXCEPTION_IF_NULL(in); - return in->isa(); - } else { - return false; - } -} -} // namespace - -const BaseRef GetitemTuple::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VarPtr C = std::make_shared(IsC); - return VectorRef({prim::kPrimTupleGetItem, VectorRef({prim::kPrimMakeTuple, Xs}), C}); -} - -const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - CNodePtr tuple_getitem = node->cast(); - MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) { - MS_LOG(EXCEPTION) << "tuple_getitem's input number is wrong"; - } - AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(make_tuple_anf); - AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(index_node); - if (IsValueNode(index_node)) { - ValueNodePtr value_node = index_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int index = GetValue(value_node->value()); - CNodePtr make_tuple = make_tuple_anf->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - if (make_tuple->inputs().size() > IntToSize(index + 1)) { - auto ret = make_tuple->input(IntToSize(index + 1)); - MS_EXCEPTION_IF_NULL(ret); - return ret; - } - } - return nullptr; -} -} // namespace opt -} // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h b/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h deleted file mode index 0fc42a15dc..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/getitem_tuple.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class GetitemTuple : public PatternProcessPass { - public: - explicit GetitemTuple(bool multigraph = true) : PatternProcessPass("getitem_tuple", multigraph) {} - ~GetitemTuple() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_GETITEM_TUPLE_SPLIT_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc deleted file mode index 1d5f909e7d..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ /dev/null @@ -1,161 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pre_activate/pass/optimize_dependence.h" -#include -#include -#include -#include "pre_activate/common/helper.h" -#include "operator/ops.h" -#include "utils/utils.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" - -namespace mindspore { -namespace opt { -constexpr auto kSingleInputIndex = 1; -namespace { -AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return nullptr; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - string op_name = AnfAlgo::GetCNodeName(cnode); - // Currently we only eliminate transdata or cast nodes. - if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { - return nullptr; - } - CheckCNodeInputSize(cnode, kSingleInputIndex + 1); - return cnode->input(kSingleInputIndex); -} - -AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimMakeTuple->name()) { - return nullptr; - } - std::vector new_make_tuple_inputs; - bool need_update = false; - for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(input); - // If replace_input is not null, it is the input of the TransData or Cast node.
- if (replace_input == nullptr) { - new_make_tuple_inputs.push_back(input); - continue; - } - new_make_tuple_inputs.push_back(replace_input); - need_update = true; - } - if (need_update) { - auto kernel_graph = func_graph->cast>(); - CNodePtr new_make_tuple = nullptr; - if (kernel_graph == nullptr) { - new_make_tuple = func_graph->NewCNode(new_make_tuple_inputs); - } else { - new_make_tuple = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_make_tuple); - new_make_tuple->set_inputs(new_make_tuple_inputs); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->Replace(cnode, new_make_tuple); - return new_make_tuple; - } - return nullptr; -} -} // namespace - -const BaseRef OptimizeDependence::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Xs = std::make_shared(); - return VectorRef({X, Xs}); -} - -const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return nullptr; - } - auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) { - return nullptr; - } - size_t index = 0; - auto depend_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(depend_cnode); - std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; - if (node_name == prim::kPrimDepend->name()) { - index = 1; - new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); - } - if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { - MS_LOG(EXCEPTION) << "The depend node's input size should be at least 2, but got " - << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); - } - auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); - while (index < input_num) { - auto replace_node = GetConvertNode(func_graph, node, index); - MS_EXCEPTION_IF_NULL(replace_node); - new_depend_inputs.push_back(replace_node); - ++index; - } - auto kernel_graph = func_graph->cast>(); - CNodePtr new_depend = nullptr; - if (kernel_graph == nullptr) { - new_depend = func_graph->NewCNode(new_depend_inputs); - MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_abstract(node->abstract()); - new_depend->set_scope(node->scope()); - } else { - new_depend = kernel_graph->NewCNode(depend_cnode); - MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_inputs(new_depend_inputs); - } - return new_depend; -} - -const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, - const size_t index) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto depend_cnode = node->cast(); - auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); - MS_EXCEPTION_IF_NULL(replacing_node); - if (!replacing_node->isa()) { - return replacing_node; - } - auto replacing_cnode = replacing_node->cast(); - MS_EXCEPTION_IF_NULL(replacing_cnode); - // Deal with the make_tuple with TransData or Cast inputs. - auto make_tuple_replace_node = ReplaceMakeTuple(graph, replacing_cnode); - if (make_tuple_replace_node != nullptr) { - return make_tuple_replace_node; - } - AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); - if (replace_node == nullptr) { - MS_LOG(DEBUG) << "Cannot find the TransData or Cast node with a single output. Depend node: " << node->DebugString(); - return replacing_node; - } - return replace_node; -} - -} // namespace opt -} // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h deleted file mode index 30027b790a..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ - -#include "pre_activate/common/optimizer.h" - -namespace mindspore { -namespace opt { -class OptimizeDependence : public PatternProcessPass { - public: - explicit OptimizeDependence(bool multigraph = true) : PatternProcessPass("optimize_dependence", multigraph) {} - ~OptimizeDependence() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - const AnfNodePtr GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t index) const; -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_OPTIMIZE_DEPENDENCE_H_ diff --git a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc deleted file mode index fd342ec43c..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.cc +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -#include "pre_activate/pass/replace_node_by_proxy.h" -#include -#include -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace opt { -kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector inputs_device_format; - std::vector outputs_device_format; - std::vector inputs_device_type; - std::vector outputs_device_type; - std::vector> outputs_shape; - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { - inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); - inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); - } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { - outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); - outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); - outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); - } - builder.SetFusionType(AnfAlgo::GetFusionType(cnode)); - builder.SetProcessor(AnfAlgo::GetProcessor(cnode)); - builder.SetKernelType(AnfAlgo::GetKernelType(cnode)); - - builder.SetInputsFormat(inputs_device_format); - builder.SetOutputsFormat(outputs_device_format); - builder.SetInputsDeviceType(inputs_device_type); - builder.SetOutputsDeviceType(outputs_device_type); - return builder.Build(); -} - -bool ReplaceNodeByProxy::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) { - CNodePtr cnode = node->cast(); - auto prim = std::make_shared(kEmbeddingLookupProxyOpName); - MS_EXCEPTION_IF_NULL(prim); - std::vector proxy_inputs = {NewValueNode(prim)}; - proxy_inputs.insert(proxy_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); - AnfNodePtr proxy_node = func_graph->NewCNode(proxy_inputs); - MS_EXCEPTION_IF_NULL(proxy_node); - - auto kernel_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(kernel_info); - proxy_node->set_kernel_info(kernel_info); - - AbstractBasePtrList abstract_list; - AnfAlgo::CopyNodeAttr(kAttrPsKey, cnode, proxy_node); - AnfAlgo::CopyNodeAttr("reduce_scatter_flag", cnode, proxy_node); - AnfAlgo::CopyNodeAttr("offset", cnode, proxy_node); - abstract_list.push_back(cnode->abstract()); - auto abstract_tuple = std::make_shared(abstract_list); - MS_EXCEPTION_IF_NULL(abstract_tuple); - proxy_node->set_abstract(abstract_tuple); - - auto kernel_build_info = GenerateKernelBuildInfo(cnode); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, proxy_node.get()); - - if (!manager->Replace(cnode, proxy_node)) { - MS_LOG(EXCEPTION) << "Replace node by proxy node failed."; - } - } - } - return true; -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h b/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h deleted file mode 100644 index 2549501a0a..0000000000 --- a/mindspore/ccsrc/pre_activate/pass/replace_node_by_proxy.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under 
the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ -#include -#include -#include - -#include "pre_activate/common/pass.h" -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "utils/utils.h" -#include "kernel/kernel_build_info.h" - -namespace mindspore { -namespace opt { -class ReplaceNodeByProxy : public Pass { - public: - explicit ReplaceNodeByProxy(const std::string &name) : Pass(name) {} - ~ReplaceNodeByProxy() override = default; - bool Run(const FuncGraphPtr &graph) override; - - private: - kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CNodePtr &cnode); -}; -} // namespace opt -} // namespace mindspore -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_REPLACE_NODE_BY_PROXY_H_ diff --git a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h index 5c7551a190..612ccde1a5 100644 --- a/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h +++ b/mindspore/ccsrc/predict/converter/attr_utils/convert_util.h @@ -25,7 +25,7 @@ #include #include #include "ir/tensor.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #include "predict/schema/inner/ms_generated.h" using TensorPtr = mindspore::tensor::TensorPtr; diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.cc b/mindspore/ccsrc/predict/converter/kernel2ms.cc index 1b1277aade..04aceb62eb 100644 --- a/mindspore/ccsrc/predict/converter/kernel2ms.cc +++ b/mindspore/ccsrc/predict/converter/kernel2ms.cc @@ -18,7 +18,7 @@ #include #include "ir/anf.h" #include "predict/converter/lite_model/op_attr_packer.h" -#include "mindspore/ccsrc/operator/ops.h" +#include "mindspore/ccsrc/frontend/operator/ops.h" namespace mindspore { namespace executor { diff --git a/mindspore/ccsrc/predict/converter/kernel2ms.h b/mindspore/ccsrc/predict/converter/kernel2ms.h index 7013f88107..8cbc89ed6a 100644 --- a/mindspore/ccsrc/predict/converter/kernel2ms.h +++ b/mindspore/ccsrc/predict/converter/kernel2ms.h @@ -22,7 +22,7 @@ #include #include #include -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" #include "predict/converter/executor_tensor.h" #include "predict/schema/inner/ms_generated.h" #include "predict/converter/attr_utils/convert_util.h" diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h index 89e38d1871..31f14ef73a 100644 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h +++ b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.h @@ -20,7 +20,7 @@ #include #include #include -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #include "predict/schema/inner/ms_generated.h" static constexpr size_t kNIndex = 0; diff --git a/mindspore/ccsrc/predict/predict.h b/mindspore/ccsrc/predict/predict.h index 7c65f16619..9125451492 
100644 --- a/mindspore/ccsrc/predict/predict.h +++ b/mindspore/ccsrc/predict/predict.h @@ -19,7 +19,7 @@ #include #include -#include "session/session_basic.h" +#include "backend/session/session_basic.h" #include "predict/converter/kernel2ms.h" namespace mindspore { diff --git a/mindspore/ccsrc/pynative/CMakeLists.txt b/mindspore/ccsrc/pynative/CMakeLists.txt deleted file mode 100644 index 5139160774..0000000000 --- a/mindspore/ccsrc/pynative/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -file(GLOB_RECURSE _PYNATIVE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "base.cc" "pynative_execute.cc") - -if (ENABLE_GE) - file(GLOB_RECURSE _GE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "pynative_execute_ge.cc") - list(APPEND _PYNATIVE_SRC_LIST ${_GE_SRC_LIST}) -endif () - -set_property(SOURCE ${_PYNATIVE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PYNATIVE) -add_library(_mindspore_pynative_obj OBJECT ${_PYNATIVE_SRC_LIST}) diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc deleted file mode 100644 index 16b55554d4..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ /dev/null @@ -1,1167 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "pynative/pynative_execute.h" - -#include -#include -#include -#include -#include - -#include "debug/trace.h" -#include "ir/tensor_py.h" -#include "ir/param_value.h" -#include "utils/any.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "operator/composite/composite.h" -#include "operator/composite/do_signature.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/resolve.h" -#include "pipeline/static_analysis/prim.h" -#include "session/session_factory.h" -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "pipeline/action.h" - -#include "pynative/base.h" -#include "pybind_api/api_register.h" -#include "vm/transform.h" - -#include "optimizer/ad/grad.h" -#include "pipeline/resource.h" -#include "pipeline/pipeline.h" -#include "pipeline/pass.h" - -#ifdef ENABLE_GE -#include "pynative/pynative_execute_ge.h" -#endif - -using mindspore::tensor::TensorPy; - -const char SINGLE_OP_GRAPH[] = "single_op_graph"; -// primitive unable to infer value for constant input in PyNative mode -const std::set vm_operators = {"make_ref", "HookBackward", "stop_gradient"}; - -namespace mindspore { -namespace pynative { - -static std::shared_ptr session = nullptr; -PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; -std::mutex PynativeExecutor::instance_lock_; -ResourcePtr PynativeExecutor::resource_; - -template -void PynativeExecutorTry(PynativeExecutor *const executor, void (PynativeExecutor::*method)(Args...), Args &&... 
args) { - try { - (executor->*method)(args...); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - PynativeExecutor::GetInstance()->Clean(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. Exception name: " << exName; - } -} - -inline ValuePtr PyAttrValue(const py::object &obj) { - ValuePtr converted_ret = parse::data_converter::PyDataToValue(obj); - if (!converted_ret) { - MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); - } - return converted_ret; -} - -std::string GetId(const py::object &obj) { - py::object to_process = obj; - std::string prefix = ""; - if (py::isinstance(to_process)) { - auto p_list = py::cast(to_process); - if (p_list.size() == 0) { - return "empty"; - } - prefix = "tuple:"; - std::string key = ""; - for (size_t i = 0; i < p_list.size(); ++i) { - key += std::string(py::str(GetId(p_list[i]))) + ":"; - } - return prefix + key; - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - return prefix + std::string(py::str(to_process)); - } - if (py::isinstance(to_process)) { - auto tensor_ptr = py::cast(to_process); - return prefix + tensor_ptr->id(); - } - - py::object ret = parse::python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_MOD_GET_OBJ_ID, obj); - return py::cast(ret); -} - -py::object GetTupleObj(const py::object &obj) { - py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); - py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); - return obj_tuple; -} - -std::map> GetTypeIndex(const std::vector &dtypes) { - std::map> type_indexes; - for (size_t i = 0; i < dtypes.size(); ++i) { - auto it = type_indexes.find(dtypes[i]); - if (it == type_indexes.end()) { - (void)type_indexes.insert(std::make_pair(dtypes[i], std::vector{i})); - } else { - it->second.push_back(i); - } - } - return type_indexes; -} - -std::map GetDstType(const py::tuple &py_args, - const std::map> &type_indexes) { - std::map dst_type; - for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { - auto type = it->first; - auto indexes = it->second; - if (type == SignatureEnumDType::kDTypeEmptyDefaultValue || indexes.size() < 2) { - continue; - } - size_t priority = 0; - TypeId max_type = TypeId::kTypeUnknown; - bool has_float = false; - bool has_int = 
false; - for (size_t index : indexes) { - if (!has_float && py::isinstance(py_args[index])) { - has_float = true; - } - if (!has_int && !py::isinstance(py_args[index]) && py::isinstance(py_args[index])) { - has_int = true; - } - if (py::isinstance(py_args[index])) { - auto arg = py::cast(py_args[index]); - TypeId arg_type_id = arg->data_type(); - auto type_priority = prim::type_map.find(arg_type_id); - if (type_priority == prim::type_map.end()) { - continue; - } - if (type_priority->second > priority) { - max_type = type_priority->first; - priority = type_priority->second; - } - } - } - if (max_type == TypeId::kNumberTypeBool) { - if (has_int) { - max_type = TypeId::kNumberTypeInt32; - } - if (has_float) { - max_type = TypeId::kNumberTypeFloat32; - } - } - (void)dst_type.insert(std::make_pair(type, max_type)); - } - return dst_type; -} - -std::string TypeIdToMsTypeStr(const TypeId &type_id) { - auto type_name = type_name_map.find(type_id); - if (type_name == type_name_map.end()) { - MS_LOG(EXCEPTION) << "For implicit type conversion, not support convert to the type: " << TypeIdToType(type_id); - } - return type_name->second; -} - -py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { - py::tuple args(3); - std::string module_name = "mindspore.ops.functional"; - std::string op_name = "cast"; - args[0] = parse::python_adapter::GetPyFn(module_name, op_name); - args[1] = "Cast"; - - std::string dst_type_str = TypeIdToMsTypeStr(type_id); - module_name = "mindspore.common.dtype"; - py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str); - py::tuple inputs(2); - inputs[0] = arg; - inputs[1] = dst_type; - args[2] = inputs; - - return RunOp(args)[0]; -} -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args, - py::list *const out_args_list) { - auto &py_args = *out_args; - py::tuple input_mask(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - input_mask[i] = py::hasattr(args[i], "__parameter__"); - py_args[i] = GetTupleObj(args[i]); - } - auto signature = prim->signatures(); - std::vector dtypes; - (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature &sig) { return sig.dtype; }); - int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); - if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { - return input_mask; - } - auto type_indexes = GetTypeIndex(dtypes); - auto dst_type = GetDstType(py_args, type_indexes); - - for (size_t i = 0; i < dtypes.size(); ++i) { - if (dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { - continue; - } - auto it = dst_type.find(dtypes[i]); - if (it == dst_type.end() || it->second == kTypeUnknown) { - continue; - } - if (py::isinstance(py_args[i])) { - auto arg = py::cast(py_args[i]); - if (arg->data_type() == it->second) { - continue; - } - if (signature[i].rw == SignatureEnumRW::kRWWrite) { - prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg->data_type()), - TypeIdToMsTypeStr(it->second)); - } - } - py::object cast_output = DoAutoCast(py_args[i], it->second); - (*out_args)[i] = cast_output; - (*out_args_list)[i] = cast_output; - } - return input_mask; -} - -void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { - size_t size = py_args.size(); - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < size; i++) { - ValuePtr input_value = 
PyAttrValue(py_args[i]); - args_spec_list.emplace_back(abstract::FromValueInside( - input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); - } - AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); - op_exec_info->abstract = infer_res; -} - -OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) { - if (args.size() != PY_ARGS_NUM) { - MS_LOG(ERROR) << "Three args are needed by RunOp"; - return nullptr; - } - auto op_exec_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_exec_info); - op_exec_info->op_name = py::cast(args[PY_NAME]); - auto prim = py::cast(args[PY_PRIM]); - auto pyobj = prim->GetPyObj(); - if (pyobj == nullptr) { - MS_LOG(EXCEPTION) << "pyobj is empty"; - } - - py::list a = args[PY_INPUTS]; - size_t input_num = a.size(); - op_exec_info->op_inputs = py::tuple(input_num); - - op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs, out_args); - // use python infer method - if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); - } - op_exec_info->py_primitive = prim; - op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { - MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size does not equal inputs mask size"; - return nullptr; - } - return op_exec_info; -} - -std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(op_exec_info); - std::string graph_info; - // get input tensor info - size_t input_num = op_exec_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - auto input = op_exec_info->op_inputs[index]; - if (py::isinstance(input)) { - auto tensor_ptr = py::cast(input); - (void)graph_info.append(tensor_ptr->GetShapeAndDataTypeInfo() + "_"); - } - } - // get prim and abstract info - MS_EXCEPTION_IF_NULL(op_exec_info->abstract); - (void)graph_info.append(std::to_string((uintptr_t)(op_exec_info->py_primitive.get())) + "_" + - op_exec_info->abstract->ToString()); - return graph_info; -} - -py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_LOG(INFO) << "RunOpInVM start"; - - MS_EXCEPTION_IF_NULL(status); - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive); - if (op_exec_info->op_name == "HookBackward") { - auto op_inputs = op_exec_info->op_inputs; - py::tuple result(op_inputs.size()); - for (size_t i = 0; i < op_inputs.size(); i++) { - py::object input = op_inputs[i]; - if (py::hasattr(input, "__parameter__")) { - input = py::getattr(input, "data"); - } - auto tensor = py::cast(input); - auto new_tensor = std::make_shared(tensor->data_type(), tensor->shape(), tensor->data_ptr()); - new_tensor->set_device_address(tensor->device_address()); - new_tensor->set_dirty(tensor->is_dirty()); - result[i] = new_tensor; - } - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); - } - auto func = op_exec_info->py_primitive->GetComputeFunction(); - if (py::isinstance(func)) { - MS_LOG(ERROR) << "VM failed to get func"; - *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; - py::tuple err_ret(0); - return std::move(err_ret); - } - - // execute op - py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); -}
- -bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, - const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(op_prim); - auto input_names_value = op_prim->GetAttr(kAttrInputNames); - if (input_names_value == nullptr) { - return false; - } - auto input_names_vec = GetValue>(input_names_value); - if (input_index >= input_names_vec.size()) { - MS_LOG(EXCEPTION) << "The input index: " << input_index << " is larger than the input names vector size!"; - } - - if (input_attrs.find(input_index) != input_attrs.end()) { - ValuePtr value = parse::data_converter::PyDataToValue(input_object); - MS_EXCEPTION_IF_NULL(value); - auto input_name = input_names_vec[input_index]; - op_prim->set_attr(input_name, value); - return true; - } - return false; -} - -void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim, - std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - for (const auto &input_object : tuple_inputs) { - if (!py::isinstance(input_object)) { - MS_LOG(EXCEPTION) << "The input object is not a tensor!"; - } - auto tensor = py::cast(input_object); - MS_EXCEPTION_IF_NULL(tensor); - input_tensors->push_back(tensor); - } - op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector{SizeToInt(tuple_inputs.size())})); -} - -void ConvertValueTupleToTensor(const py::object &input_object, std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(input_tensors); - ValuePtr input_value = parse::data_converter::PyDataToValue(input_object); - MS_EXCEPTION_IF_NULL(input_value); - if (!input_value->isa()) { - MS_LOG(EXCEPTION) << "The input object is not a value tuple!"; - } - auto value_tuple = input_value->cast(); - MS_EXCEPTION_IF_NULL(value_tuple); - tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple); - MS_EXCEPTION_IF_NULL(tensor_ptr); - input_tensors->push_back(tensor_ptr); -} - -void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, - std::vector *input_tensors, int *tensor_mask) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - MS_EXCEPTION_IF_NULL(tensor_mask); - - if (!py::isinstance(input_object)) { - MS_LOG(EXCEPTION) << "The input should be a tuple!"; - } - auto tuple_inputs = py::cast(input_object); - if (tuple_inputs.size() == 0) { - MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; - } - if (py::isinstance(tuple_inputs[0])) { - PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); - } else { - ConvertValueTupleToTensor(input_object, input_tensors); - *tensor_mask = kValueNodeTensorMask; - } -} - -void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, - std::vector *input_tensors, int *tensor_mask) { - MS_EXCEPTION_IF_NULL(op_prim); - MS_EXCEPTION_IF_NULL(input_tensors); - MS_EXCEPTION_IF_NULL(tensor_mask); - tensor::TensorPtr tensor_ptr = nullptr; - if (py::isinstance(input_object)) { - tensor_ptr = py::cast(input_object); - } else if (py::isinstance(input_object)) { - double input_value = py::cast(input_object); - tensor_ptr = std::make_shared(input_value, kFloat32); - *tensor_mask = kValueNodeTensorMask; - } else if (py::isinstance(input_object)) { - tensor_ptr = std::make_shared(py::cast(input_object), kInt32); - *tensor_mask = kValueNodeTensorMask; - } else if (py::isinstance(input_object)) { - tensor_ptr = TensorPy::MakeTensor(py::cast(input_object), nullptr); - } else if (py::isinstance(input_object)) {
- auto list_inputs = py::cast(input_object); - py::tuple tuple_inputs(list_inputs.size()); - for (size_t i = 0; i < tuple_inputs.size(); ++i) { - tuple_inputs[i] = list_inputs[i]; - } - ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask); - return; - } else if (py::isinstance(input_object)) { - ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask); - return; - } else if (py::isinstance(input_object)) { - return; - } else { - MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; - } - MS_EXCEPTION_IF_NULL(tensor_ptr); - input_tensors->push_back(tensor_ptr); -} - -void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *tensors_mask, - std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(op_run_info); - MS_EXCEPTION_IF_NULL(tensors_mask); - MS_EXCEPTION_IF_NULL(input_tensors); - PrimitivePtr op_prim = op_run_info->py_primitive; - MS_EXCEPTION_IF_NULL(op_prim); - - if (op_run_info->op_inputs.size() != op_run_info->inputs_mask.size()) { - MS_LOG(EXCEPTION) << "Op input size " << op_run_info->op_inputs.size() << " should be equal to op input mask size " - << op_run_info->inputs_mask.size(); - } - opt::ConstInputToAttrInfoRegister reg; - bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); - size_t input_num = op_run_info->op_inputs.size(); - for (size_t index = 0; index < input_num; ++index) { - // convert const input to attr - if (reg_exist && - RunOpConvertConstInputToAttr(op_run_info->op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) { - continue; - } - // convert const and tuple input to tensor - int tensor_mask = py::cast(op_run_info->inputs_mask[index]); - ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask); - // mark tensors, data : 0, weight : 1, valuenode: 2 - std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); - tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); - } -} - -void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors) { - MS_EXCEPTION_IF_NULL(input_tensors); - if (input_tensors->size() != tensors_mask.size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size " - << tensors_mask.size(); - } - std::vector new_input_tensors; - for (size_t index = 0; index < tensors_mask.size(); ++index) { - if (tensors_mask[index] != kValueNodeTensorMask) { - new_input_tensors.push_back(input_tensors->at(index)); - } - } - *input_tensors = new_input_tensors; -} - -py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_enable_pynative_infer(true); - std::string device_target = ms_context->device_target(); - if (device_target != kAscendDevice && device_target != kGPUDevice) { - MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode"; - } - - if (session == nullptr) { - session = session::SessionFactory::Get().Create(device_target); - } - MS_EXCEPTION_IF_NULL(session); - session->Init(ms_context->device_id()); - - std::vector input_tensors; - std::vector tensors_mask; - ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); - // get graph info for checking 
whether it already exists in the cache - std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); - session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask); - EraseValueNodeTensor(tensors_mask, &input_tensors); - py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors); - ms_context->set_enable_pynative_infer(false); - *status = PYNATIVE_SUCCESS; - return result; -} - -py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info, - PynativeStatusCode *const status) { - MS_EXCEPTION_IF_NULL(status); - py::object result; - switch (backend_policy) { - case kMsBackendVmOnly: { - // use vm only - MS_LOG(INFO) << "RunOp use VM only backend"; - result = RunOpInVM(op_exec_info, status); - break; - } - case kMsBackendGePrior: { -#ifdef ENABLE_GE - // use GE first, use vm when GE fails - MS_LOG(INFO) << "RunOp use GE first backend"; - result = RunOpInGE(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - result = RunOpInVM(op_exec_info, status); - } -#endif - break; - } - case kMsBackendMsPrior: { - // use Ms first, use others when Ms fails - MS_LOG(INFO) << "RunOp use Ms first backend"; - result = RunOpInMs(op_exec_info, status); - if (*status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "RunOp use Ms backend failed!!!"; - } - break; - } - default: - MS_LOG(ERROR) << "No backend configured for run op"; - } - return result; -} - -AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) { - if (!grad_flag_ || graph_info_map_.empty()) { - return nullptr; - } - std::vector inputs; - auto prim = op_exec_info->py_primitive; - inputs.push_back(NewValueNode(prim)); - py::tuple op_masks = op_exec_info->inputs_mask; - AbstractBasePtrList args_spec_list; - for (size_t i = 0; i < args.size(); i++) { - auto node = GetInput(args[i], op_masks[i]); - args_spec_list.push_back(node->abstract()); - inputs.push_back(node); - } - - auto cnode = curr_g_->NewCNode(inputs); - MS_LOG(DEBUG) << "MakeCnode set node " << cnode->DebugString(4); - py::object out_real = out; - if (out.size() == 1) { - MS_LOG(DEBUG) << "MakeCnode out size is one."; - out_real = out[0]; - } - std::string obj_id = GetId(out_real); - if (py::isinstance(out_real)) { - auto value = py::cast(out_real); - if (value.size() > 1) { - for (int i = 0; i < static_cast(value.size()); i++) { - auto value_id = GetId(value[i]); - MS_LOG(DEBUG) << "MakeCnode set node id " << value_id; - set_obj_node_map(curr_g_, value_id, cnode, i); - } - } - } - MS_LOG(DEBUG) << "MakeCnode set node id " << obj_id; - set_obj_node_map(curr_g_, obj_id, cnode); - set_pyobj(curr_g_, obj_id); - return cnode; -} - -AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { - auto &out = graph_info_map_[curr_g_].obj_node_map[GetId(obj)]; - if (out.second.size() == 1 && out.second[0] == -1) { - return out.first; - } - auto node = out.first; - MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString(); - for (auto &idx : out.second) { - std::vector tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)}; - node = curr_g_->NewCNode(tuple_get_item_inputs); - } - MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6); - return node; -} - -py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info, const py::args &args) { - MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; - mindspore::parse::python_adapter::set_python_env_flag(true); - MsBackendPolicy
backend_policy; -#if (!defined ENABLE_GE) - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->backend_policy() == "ms") { - backend_policy = kMsBackendMsPrior; - } else { - backend_policy = kMsBackendVmOnly; - } -#else - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - ms_context->PynativeInitGe(); - backend_policy = kMsBackendGeOnly; -#endif - if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { - backend_policy = kMsBackendVmOnly; - } - PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; - // returns a null py::tuple on error - py::tuple err_ret(0); - py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); - if (status != PYNATIVE_SUCCESS) { - MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; - return err_ret; - } - - auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result); - if (node != nullptr) { - node->set_abstract(op_exec_info->abstract); - MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString(); - } - MS_LOG(DEBUG) << "RunOp end"; - return result; -} - -py::tuple RunOpInner(const py::args &args) { - MS_LOG(DEBUG) << "RunOp start" << args.size(); - py::list args_input = args[PY_INPUTS]; - - OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args, &args_input); - MS_EXCEPTION_IF_NULL(op_exec_info); - - if (op_exec_info->abstract != nullptr) { - py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); - if (!output["value"].is_none()) { - py::tuple value_ret(1); - value_ret[0] = output["value"]; - return value_ret; - } - if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { - py::tuple value_ret(1); - value_ret[0] = ""; - return value_ret; - } - } - return RunOpInner(op_exec_info, args_input); -} - -py::tuple RunOp(const py::args &args) { - try { - return RunOpInner(args); - } catch (const py::error_already_set &ex) { - // print function call stack info before release - std::ostringstream oss; - trace::TraceGraphEval(); - trace::GetEvalStackInfo(oss); - // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see - // these info from screen, no need to open log file to find these info - py::print(oss.str()); - MS_LOG(ERROR) << oss.str(); - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(py::error_already_set(ex)); - } catch (const py::type_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::type_error(ex); - } catch (const py::value_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::value_error(ex); - } catch (const py::index_error &ex) { - PynativeExecutor::GetInstance()->Clean(); - throw py::index_error(ex); - } catch (const std::exception &ex) { - PynativeExecutor::GetInstance()->Clean(); - // re-throw this exception to Python interpreter to handle it - throw(std::runtime_error(ex.what())); - } catch (...) { - PynativeExecutor::GetInstance()->Clean(); - std::string exName(abi::__cxa_current_exception_type()->name()); - MS_LOG(EXCEPTION) << "Error occurred when compile graph. 
Exception name: " << exName; - } -} - -void ClearPyNativeSession() { session = nullptr; } - -PynativeExecutor::~PynativeExecutor() { ClearRes(); } - -PynativeExecutor::PynativeExecutor() { grad_flag_ = false; } - -void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) { - auto cell_id = GetId(cell); - if (cell_graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "Newgraph already compiled"; - return; - } - - auto g = std::make_shared(); - - if (top_g_ == nullptr) { - top_g_ = curr_g_ = g; - df_builder_ = std::make_shared(); - MS_LOG(DEBUG) << "First new graph" << top_g_.get(); - Pushp(); - } else { - Pushp(); - curr_g_ = g; - } - if (graph_info_map_.count(g) == 0) { - graph_info_map_[g] = GraphInfo(); - } - for (size_t i = 0; i < args.size(); i++) { - auto new_param = g->add_parameter(); - std::string param_obj = GetId(args[i]); - graph_info_map_[g].param_map[param_obj] = new_param; - } -} - -AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::string &obj_id) { - ValuePtr converted_ret = nullptr; - parse::ConvertData(obj, &converted_ret); - auto node = NewValueNode(converted_ret); - set_obj_node_map(curr_g_, obj_id, node); - return node; -} - -AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &op_mask) { - AnfNodePtr node = nullptr; - std::string obj_id = GetId(obj); - - if (op_mask != nullptr && py::cast(op_mask)) { - MS_LOG(DEBUG) << "Topgraph free parameter"; - // get the parameter name from parameter object - auto name_attr = mindspore::parse::python_adapter::GetPyObjAttr(obj, "name"); - if (py::isinstance(name_attr)) { - MS_LOG(EXCEPTION) << "Parameter object should have name attribute"; - } - auto param_name = py::cast(name_attr); - if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { - auto free_param = df_builder_->add_parameter(); - free_param->set_name(param_name); - auto free_param_new = py::cast(obj.attr("_value")); - free_param->set_default_param(free_param_new); - free_param->debug_info()->set_name(param_name); - MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; - graph_info_map_[df_builder_].param_map[obj_id] = free_param; - return free_param; - } - return graph_info_map_[df_builder_].param_map[obj_id]; - } - - // if input is graph output - if (graph_info_map_[curr_g_].param_map.count(obj_id) != 0) { - // op(x, y) - node = graph_info_map_[curr_g_].param_map[obj_id]; - } else if (graph_info_map_[curr_g_].obj_node_map.count(obj_id) != 0) { - // out = op(op1(x, y)) - // out = op(cell1(x, y)) - // out = op(cell1(x, y)[0]) - node = GetObjNode(obj); - } else if (py::isinstance(obj)) { - // out = op((x, y)) - // out = cell((x, y)) - auto tuple = obj.cast(); - - // cell((1,2)): support not mix (scalar, tensor) - if (tuple.size() > 0 && !py::isinstance(tuple[0])) { - return MakeValueNode(obj, obj_id); - } - - std::vector args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto tuple_size = static_cast(tuple.size()); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); - } - auto cnode = curr_g_->NewCNode(args); - set_obj_node_map(curr_g_, GetId(obj), cnode); - node = cnode; - } else { - node = MakeValueNode(obj, obj_id); - } - - MS_LOG(DEBUG) << "Now getinput node " << node->ToString() << obj_id; - return node; -} - -// for output[0][1] need getitem multi -void PynativeExecutor::SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx) { - if (py::isinstance(obj)) { - auto tuple = obj.cast(); - for (int i = 0; 
i < static_cast<int>(tuple.size()); i++) { - std::vector<int> tmp = idx; - tmp.push_back(i); - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, tmp); - SetTupleOutput(tuple[i], cnode, tmp); - } - } -} - -void PynativeExecutor::Pushp() { graph_p_.push(curr_g_); } - -void PynativeExecutor::Popp() { - if (graph_p_.empty()) { - MS_LOG(EXCEPTION) << "Stack graph_p_ is empty"; - } - curr_g_ = graph_p_.top(); - graph_p_.pop(); -} - -void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { - auto cell_id = GetId(cell); - if (cell_graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "Endgraph already compiled"; - return; - } - cell_graph_map_[cell_id] = curr_g_; - auto out_id = GetId(out); - if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) { - // cell construct return x, y - if (py::isinstance<py::tuple>(out)) { - std::vector<AnfNodePtr> args; - args.push_back(NewValueNode(prim::kPrimMakeTuple)); - - auto tuple = out.cast<py::tuple>(); - MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size(); - auto tuple_size = static_cast<int>(tuple.size()); - auto cnode = curr_g_->NewCNode(args); - for (int i = 0; i < tuple_size; i++) { - args.push_back(GetInput(tuple[i], py::object())); - set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); - SetTupleOutput(tuple[i], cnode, std::vector<int>{i}); - } - cnode->set_inputs(args); - set_obj_node_map(curr_g_, out_id, cnode); - } else { - MS_LOG(ERROR) << "Graph has no such output: " << out_id; - return; - } - } - EndGraphByOutId(out_id, cell, out, args); -} - -void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, - const py::args &args) { - AnfNodePtr output_node; - if (graph_info_map_[curr_g_].param_map.count(out_id)) { - output_node = graph_info_map_[curr_g_].param_map[out_id]; - } else { - output_node = GetObjNode(out); - } - curr_g_->set_output(output_node); - std::vector<AnfNodePtr> inputs; - inputs.push_back(NewValueNode(curr_g_)); - MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString(); - resource_->manager()->AddFuncGraph(curr_g_); - // custom bprop debug - if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { - MS_LOG(DEBUG) << "Use cell custom bprop function."; - FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell); - if (bprop_graph != nullptr) { - (void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_))); - } - } - auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_); - if (curr_g_ != top_g_) { - Popp(); - for (size_t i = 0; i < args.size(); i++) { - auto input = GetInput(args[i], py::object()); - inputs.push_back(input); - } - auto out_cnode = curr_g_->NewCNode(inputs); - set_pyobj(curr_g_, GetId(cell)); - if (py::isinstance<py::tuple>(out)) { - auto out_list = py::cast<py::tuple>(out); - auto out_size = static_cast<int>(out_list.size()); - for (int i = 0; i < out_size; i++) { - set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i); - SetTupleOutput(out_list[i], out_cnode, std::vector<int>{i}); - } - } - set_obj_node_map(curr_g_, GetId(out), out_cnode); - } else { - parse::ResolveFuncGraph(newfg, resource_); - resource_->set_func_graph(newfg); - } -} - -std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) { - std::vector<AnfNodePtr> w_args; - if (py::hasattr(weights, "__parameter_tuple__")) { - auto tuple = weights.cast<py::tuple>(); - MS_LOG(DEBUG) << "GradNet start weights tuple size" << tuple.size(); 
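// Illustrative sketch: the obj_node_map bookkeeping used by the executor above records, per Python
// object id, the producing node plus an index path into nested tuple outputs ({-1} meaning "the node
// itself"); GetObjNode unrolls that path into one TupleGetItem per index. A minimal standalone model
// of that lookup, using plain STL stand-ins rather than the real AnfNodePtr/CNode API (NodeRef and
// ResolveObj are illustrative names, not part of the diff):
#include <map>
#include <string>
#include <vector>

struct NodeRef {
  std::string node;             // stand-in for the producing CNode
  std::vector<int> index_path;  // {-1}: the node itself; otherwise one TupleGetItem per index
};

std::string ResolveObj(const std::map<std::string, NodeRef> &obj_node_map, const std::string &obj_id) {
  const NodeRef &ref = obj_node_map.at(obj_id);
  if (ref.index_path.size() == 1 && ref.index_path[0] == -1) {
    return ref.node;  // object maps directly to its node
  }
  std::string expr = ref.node;
  for (int idx : ref.index_path) {
    expr = "TupleGetItem(" + expr + ", " + std::to_string(idx) + ")";  // one getitem per nesting level
  }
  return expr;  // e.g. output[0][1] resolves to TupleGetItem(TupleGetItem(op, 0), 1)
}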
- w_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t it = 0; it < tuple.size(); ++it) { - auto param = tuple[it]; - auto param_id = GetId(param); - AnfNodePtr para_node = nullptr; - if (graph_info_map_[df_builder_].param_map.count(param_id)) { - para_node = graph_info_map_[df_builder_].param_map[param_id]; - - AnfNodePtr value = parse::GetMixedPrecisionCastHelp(df_builder_, para_node); - AnfNodePtr make_ref = NewValueNode(prim::kPrimMakeRef); - auto refkey = std::make_shared<RefKey>(para_node->cast<ParameterPtr>()->name()); - AnfNodePtr ref_key_node = NewValueNode(refkey); - AnfNodePtr ref_node = df_builder_->NewCNode({make_ref, ref_key_node, value, para_node}); - - w_args.push_back(ref_node); - } - } - } else { - MS_LOG(DEBUG) << "training not parameter_tuple"; - } - return w_args; -} - -abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { - abstract::AbstractBasePtrList args_spec; - std::size_t size = args.size(); - for (std::size_t i = 0; i < size; i++) { - ValuePtr converted = nullptr; - bool succ = parse::ConvertData(args[i], &converted); - if (!succ) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - bool broaden = true; - auto abs = abstract::FromValue(converted, broaden); - args_spec.push_back(abs); - auto param_node = std::static_pointer_cast<Parameter>(df_builder_->parameters()[i]); - param_node->set_abstract(abs); - } - - for (const auto &param : df_builder_->parameters()) { - auto param_node = std::static_pointer_cast<Parameter>(param); - if (param_node->has_default()) { - const auto &param_value = param_node->default_param(); - ValuePtr value = param_value->value(); - AbstractBasePtr ptr = abstract::FromValue(value, true); - if (ptr == nullptr) { - MS_LOG(EXCEPTION) << "Args convert error"; - } - args_spec.push_back(ptr); - param_node->set_abstract(ptr); - } - } - - return args_spec; -} - -void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - MS_LOG(INFO) << "GradNet start" << args.size(); - - std::size_t size = args.size(); - auto cell_id = GetId(cell); - if (graph_map_.count(cell_id) != 0) { - MS_LOG(DEBUG) << "GradNet already compiled"; - return; - } - MS_LOG(DEBUG) << "GradNet first compiled"; - std::vector<AnfNodePtr> new_params; - for (size_t i = 0; i < size; i++) { - ParameterPtr p = std::make_shared<Parameter>(df_builder_); - new_params.push_back(p); - } - MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); - new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); - df_builder_->set_parameters(new_params); - resource_->manager()->SetParameters(df_builder_, new_params); - - std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights); - MS_EXCEPTION_IF_NULL(resource_->func_graph()); - auto g = GradGraph(resource_->func_graph(), grad, w_args, size); - resource_->set_func_graph(g); - resource_->manager()->KeepRoots({g}); - - // get the parameters items and add the value to args_spec - abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); - MS_LOG(DEBUG) << "Args_spec size" << args_spec.size(); - - resource_->set_args_spec(args_spec); - MS_LOG(DEBUG) << "Start opt"; - - // Create backend and session - resource_->results()[pipeline::kBackend] = compile::CreateBackend(); - - graph_map_[cell_id] = g; - PynativeOptimizeAction(resource_); - TaskEmitAction(resource_); - ExecuteAction(resource_); - resource_->Clean(); - ad::CleanRes(); - pipeline::ReclaimOptimizer(); -} - -void PynativeExecutor::Clear(const std::string &flag) { - if 
(!flag.empty()) { - MS_LOG(INFO) << "Clear res"; - (void)graph_map_.erase(flag); - (void)cell_graph_map_.erase(flag); - Clean(); - // Maybe exit in the middle of running an op in pynative mode, so the pynative flag needs to be reset. - auto ms_context = MsContext::GetInstance(); - if (ms_context != nullptr) { - ms_context->set_enable_pynative_infer(false); - } - return; - } - - MS_LOG(INFO) << "Clear"; - top_g_ = nullptr; - curr_g_ = nullptr; - graph_info_map_.clear(); - std::stack<FuncGraphPtr>().swap(graph_p_); -} - -void PynativeExecutor::Clean() { - MS_LOG(INFO) << "Clean all res"; - Clear(); - grad_flag_ = false; - df_builder_ = nullptr; - ad::CleanRes(); - pipeline::ReclaimOptimizer(); -} - -void PynativeExecutor::ClearRes() { - Clean(); - resource_.reset(); -} - -py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) { - VectorRef arg_list; - pipeline::ProcessVmArgInner(args, resource_, &arg_list); - if (resource_->results().find(pipeline::kOutput) == resource_->results().end() || - !resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - compile::VmEvalFuncPtr run = resource_->results()[pipeline::kOutput].cast<compile::VmEvalFuncPtr>(); - if (run == nullptr) { - MS_LOG(EXCEPTION) << "Can't find run graph func for "; - } - - std::string backend = MsContext::GetInstance()->backend_policy(); - - MS_LOG(DEBUG) << "Eval run" << backend; - BaseRef value = (*run)(arg_list); - MS_LOG(DEBUG) << "Run end" << value.ToString(); - return BaseRefToPyData(value); -} - -FuncGraphPtr PynativeExecutor::GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, - const std::vector<AnfNodePtr> &weights, size_t arg_size) { - auto nparam = top_g_->parameters().size(); - std::ostringstream ss; - ss << "grad{" << nparam << "}"; - df_builder_->set_flag(FUNC_GRAPH_FLAG_CORE, true); - df_builder_->debug_info()->set_name(ss.str()); - - auto df = grad_op->GetGrad(NewValueNode(g), nullptr, top_g_->parameters(), weights); - std::vector<AnfNodePtr> inputs = {NewValueNode(df)}; - for (size_t i = 0; i < arg_size; ++i) { - inputs.push_back(df_builder_->parameters()[i]); - } - auto out = df_builder_->NewCNode(inputs); - df_builder_->set_output(out); - resource_->manager()->AddFuncGraph(df); - resource_->manager()->AddFuncGraph(df_builder_); - return df_builder_; -} - -void PynativeExecutor::NewGraph(const py::object &cell, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::NewGraphInner, cell, args); -} - -void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::EndGraphInner, cell, out, args); -} - -void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args) { - PynativeExecutorTry(this, &PynativeExecutor::GradNetInner, grad, cell, weights, args); -} - -REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { - (void)py::class_<PynativeExecutor, std::shared_ptr<PynativeExecutor>>(*m, "PynativeExecutor_") - .def_static("get_instance", &PynativeExecutor::GetInstance, "PynativeExecutor get_instance.") - .def("new_graph", &PynativeExecutor::NewGraph, "pynative new a graph.") - .def("end_graph", &PynativeExecutor::EndGraph, "pynative end a graph.") - .def("grad_net", &PynativeExecutor::GradNet, "pynative grad graph.") - .def("clear", &PynativeExecutor::Clear, "pynative clear status.") - .def("__call__", &PynativeExecutor::Run, py::arg("args"), py::arg("phase") = py::str(""), - "Executor run function.") - .def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = 
py::bool_(false), - "Executor set grad flag."); - })); -} // namespace pynative -} // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h deleted file mode 100644 index 83cbea88d4..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "pybind11/pybind11.h" -#include "pybind11/numpy.h" - -#include "pynative/base.h" -#include "utils/context/ms_context.h" -#include "ir/anf.h" -#include "pipeline/resource.h" -#include "operator/composite/composite.h" - -namespace mindspore { -namespace pynative { - -namespace py = pybind11; -using ResourcePtr = std::shared_ptr; -using GradOperationPtr = std::shared_ptr; - -py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); - -py::tuple RunOp(const py::args &args); - -py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args, - py::list *const out_args_list); - -void ClearPyNativeSession(); - -struct GraphInfo { - std::unordered_map param_map; - std::unordered_map>> obj_node_map; - AnfNodePtr output; - std::vector objects; -}; - -class PynativeExecutor : public std::enable_shared_from_this { - public: - static std::shared_ptr GetInstance() { - std::lock_guard i_lock(instance_lock_); - if (executor_ == nullptr) { - executor_ = std::shared_ptr(new (std::nothrow) PynativeExecutor()); - resource_ = std::make_shared(); - } - return executor_; - } - void NewGraph(const py::object &cell, const py::args &args); - void NewGraphInner(const py::object &cell, const py::args &args); - void EndGraph(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args); - void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args); - std::vector GetWeightsArgs(const py::object &weights); - abstract::AbstractBasePtrList GetArgsSpec(const py::args &args); - void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, - const py::args &args); - void Clear(const std::string &flag = ""); - void Clean(); - void ClearRes(); - bool grad_flag() { return grad_flag_; } - void set_grad_flag(bool flag) { grad_flag_ = flag; } - AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask); - AnfNodePtr GetObjNode(const py::object &obj); - FuncGraphPtr curr_g() { return curr_g_; } - void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); } - void 
set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{-1}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, std::vector{index}); - } - void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector index) { - graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index); - } - AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out); - py::object Run(const py::tuple &args, const py::object &phase); - - void Pushp(); - void Popp(); - FuncGraphPtr GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector &weights, - size_t arg_size); - void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector idx); - AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); - - ~PynativeExecutor(); - - private: - PynativeExecutor(); - static std::shared_ptr executor_; - static std::mutex instance_lock_; - static ResourcePtr resource_; - bool grad_flag_; - std::unordered_map graph_map_; - std::unordered_map cell_graph_map_; - std::unordered_map graph_info_map_; - std::stack graph_p_; - FuncGraphPtr top_g_; - FuncGraphPtr df_builder_; - FuncGraphPtr curr_g_; -}; - -using PynativeExecutorPtr = std::shared_ptr; - -} // namespace pynative -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_H_ diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pynative/pynative_execute_ge.cc deleted file mode 100644 index 8e10468236..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.cc +++ /dev/null @@ -1,312 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "pynative/pynative_execute_ge.h" - -#include -#include -#include -#include - -#include "utils/any.h" -#include "utils/utils.h" -#include "utils/context/ms_context.h" -#include "operator/ops.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/static_analysis/prim.h" -#include "session/session_factory.h" -#include "ir/tensor_py.h" - -const char SINGLE_OP_GRAPH[] = "single_op_graph"; - -using mindspore::tensor::TensorPy; - -namespace mindspore { -namespace pynative { -using MeTensor = mindspore::tensor::Tensor; -using MeTensorPtr = mindspore::tensor::TensorPtr; -using GeOperator = ge::Operator; -using GeOperatorPtr = std::shared_ptr; - -using transform::GraphRunner; -using transform::GraphRunnerOptions; -using transform::OperatorPtr; -static std::shared_ptr session = nullptr; -inline ValuePtr PyAttrValue(const py::object &obj) { - ValuePtr converted_ret = nullptr; - bool converted = parse::ConvertData(obj, &converted_ret); - if (!converted) { - MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); - } - return converted_ret; -} - -MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { - MeTensorPtr me_tensor_ptr = nullptr; - if (py::isinstance(obj)) { - me_tensor_ptr = py::cast(obj); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::array(py::cast(obj)), nullptr); - } else if (py::isinstance(obj)) { - me_tensor_ptr = TensorPy::MakeTensor(py::cast(obj), nullptr); - } else { - MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; - } - return me_tensor_ptr; -} - -bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const OperatorPtr &op, std::vector *graph_input_nodes) { - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(graph_input_nodes); - auto op_inputs = op_exec_info->op_inputs; - std::string op_name = op_exec_info->op_name; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - return false; - } - - int op_input_idx = 1; - size_t size = inputs.size(); - for (size_t i = 0; i < size; i++) { - if (inputs[i] == nullptr) { - continue; - } - auto const_op = std::make_shared(); - MS_EXCEPTION_IF_NULL(const_op); - (void)const_op->set_attr_value(*inputs[i]); - MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); - MS_EXCEPTION_IF_NULL(me_tensor_ptr); - auto const_op_desc = - transform::TransformUtil::GetGeTensorDesc(me_tensor_ptr->shape_c(), me_tensor_ptr->data_type(), kOpFormat_NCHW); - if (const_op_desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << op_name << " output descriptor failed!"; - return false; - } - auto pointer_cast_const_op = std::static_pointer_cast(const_op); - MS_EXCEPTION_IF_NULL(pointer_cast_const_op); - (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); - auto &input_map = adapter->getInputMap(); - if (input_map.find(op_input_idx) == input_map.end()) { - continue; - } - if (adapter->setInput(op, op_input_idx++, const_op)) { - MS_LOG(ERROR) << "Failed to set params, index is " << op_input_idx; - return false; - } - graph_input_nodes->push_back(*const_op); - } - return true; -} - -bool 
BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const std::unordered_map &attrs, const GeGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(op_exec_info); - std::string op_name = op_exec_info->op_name; - auto op_inputs = op_exec_info->op_inputs; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Unable to find Adapter for " << ((std::string)py::str(op_name)); - return false; - } - OperatorPtr op = adapter->generate(op_name); - MS_EXCEPTION_IF_NULL(op); - - std::vector graph_input_nodes; - // hold param nodes after setting input and output for the graph - // set input - if (!SetInputsForSingleOpGraph(op_exec_info, inputs, op, &graph_input_nodes)) { - return false; - } - // set attributes - for (auto attr : attrs) { - (void)adapter->setAttr(op, attr.first, attr.second); - } - // set default attributes - auto extra_attrs = adapter->GetExtraAttr(); - for (auto attr : extra_attrs) { - (void)adapter->setAttr(op, attr.first, attr.second); - } - // set input attributes - auto &input_attr_map = adapter->getInputAttrMap(); - for (auto &it : input_attr_map) { - if (op_inputs.size() < it.first) { - continue; - } - auto const_value = PyAttrValue(op_inputs[it.first - 1]); - if (const_value->isa()) { - continue; - } - it.second.set_attr(op, const_value); - } - // construct output data nodes - std::vector graph_outputs{*op}; - // set input and output nodes for the graph - MS_EXCEPTION_IF_NULL(graph); - (void)graph->SetInputs(graph_input_nodes).SetOutputs(graph_outputs); - MS_LOG(INFO) << "BuildSingleOpGraph done"; - return true; -} - -void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { - MS_EXCEPTION_IF_NULL(inputs); - MS_EXCEPTION_IF_NULL(op_exec_info); - auto op_inputs = op_exec_info->op_inputs; - size_t size = op_inputs.size(); - for (size_t i = 0; i < size; i++) { - if (py::isinstance(op_inputs[i])) { - inputs->emplace_back(nullptr); - continue; - } - MeTensorPtr me_tensor_ptr = ConvertPyObjToTensor(op_inputs[i]); - auto ge_tensor_ptr = transform::TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW); - if (ge_tensor_ptr == nullptr) { - MS_LOG(EXCEPTION) << "Convert inputs to GE tensor failed in op " << op_exec_info->op_name << "."; - } - // set inputs for operator to build single node graph - inputs->push_back(ge_tensor_ptr); - } -} - -PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { - MS_EXCEPTION_IF_NULL(op_exec_info); - auto op_attrs = op_exec_info->op_attrs; - std::unordered_map attrs{}; - - for (auto &item : op_attrs) { - if (!py::isinstance(item.first)) { - MS_LOG(ERROR) << "Type error in py dict convert"; - return PYNATIVE_OP_ATTRS_ERR; - } - std::string name = py::cast(item.first); - auto attr_value = PyAttrValue(py::cast(item.second)); - (void)attrs.emplace(name, attr_value); - } - - // build graph - GeGraphPtr graph = std::make_shared(op_exec_info->op_name); - if (BuildSingleOpGraph(op_exec_info, inputs, attrs, graph) == false) { - MS_LOG(ERROR) << "Failed to BuildSingleOpGraph"; - return PYNATIVE_GRAPH_GE_BUILD_ERR; - } - - // add the single op graph into the graph manager, which will be iterated by session. 
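// Illustrative sketch: the step below registers the freshly built single-op graph under the fixed
// name SINGLE_OP_GRAPH and later runs it by that same name. A minimal stand-in for that name-keyed
// registry pattern; StubGraph, GraphRegistry, and kSingleOpGraphName are hypothetical stand-ins,
// not the real transform::DfGraphManager API:
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

using StubGraph = std::function<int(int)>;  // stands in for a built single-op GE graph

constexpr const char kSingleOpGraphName[] = "single_op_graph";  // mirrors SINGLE_OP_GRAPH above

class GraphRegistry {
 public:
  // One slot per name in this sketch; the real manager's overwrite/duplicate semantics may differ.
  bool AddGraph(const std::string &name, StubGraph g) {
    graphs_[name] = std::move(g);
    return true;
  }
  // The runner side looks the graph up by the same name, as RunOptions.name does below.
  int RunGraph(const std::string &name, int input) const {
    auto it = graphs_.find(name);
    if (it == graphs_.end()) {
      throw std::runtime_error("graph not found: " + name);
    }
    return it->second(input);
  }

 private:
  std::map<std::string, StubGraph> graphs_;
};
// Usage: registry.AddGraph(kSingleOpGraphName, g); registry.RunGraph(kSingleOpGraphName, x);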
- transform::Status ret = - transform::DfGraphManager::GetInstance().AddGraph(SINGLE_OP_GRAPH, std::shared_ptr(graph)); - if (ret != transform::SUCCESS) { - MS_LOG(ERROR) << "Failed to AddGraph into graph manager"; - return PYNATIVE_GRAPH_MANAGER_ERR; - } - - return PYNATIVE_SUCCESS; -} - -std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, - const std::vector &ge_tensors) { - std::vector outputs; - AbstractBasePtr abs_base = op_exec_info->abstract; - std::vector> shapes; - if (abs_base != nullptr && abs_base->isa()) { - auto arg_tensor = dyn_cast(abs_base); - shapes.emplace_back(arg_tensor->shape()->shape()); - outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); - return outputs; - } - if (abs_base != nullptr && abs_base->isa()) { - auto arg_tuple = dyn_cast(abs_base); - size_t len = arg_tuple->size(); - - for (size_t i = 0; i < len; i++) { - if (arg_tuple->elements()[i]->isa()) { - auto arg_tensor = dyn_cast(arg_tuple->elements()[i]); - shapes.emplace_back(arg_tensor->shape()->shape()); - } - } - outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); - return outputs; - } - for (auto &it : ge_tensors) { - auto tensor = transform::TransformUtil::ConvertGeTensor(it); - if (tensor != nullptr) { - outputs.emplace_back(tensor); - } - } - return outputs; -} - -py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { - MS_LOG(INFO) << "RunOpInGe start"; - MS_EXCEPTION_IF_NULL(op_exec_info); - MS_EXCEPTION_IF_NULL(status); - - // returns a null py::tuple on error - py::tuple err_ret(0); - auto op_name = op_exec_info->op_name; - transform::OpAdapterPtr adapter = transform::DfGraphConvertor::FindAdapter(op_name, true); - if (adapter == nullptr) { - MS_LOG(ERROR) << "Unable to find GE Adapter for " << ((std::string)py::str(op_name)); - *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; - return std::move(err_ret); - } - - std::vector inputs{}; - ToTensorPtr(op_exec_info, &inputs); - // convert me attr to ge AttrValue - PynativeStatusCode ret = ConvertAttributes(op_exec_info, inputs); - if (ret != PYNATIVE_SUCCESS) { - *status = ret; - return std::move(err_ret); - } - // run graph - transform::RunOptions run_options; - run_options.name = SINGLE_OP_GRAPH; - std::vector ge_inputs; - std::vector ge_outputs; - transform::GraphRunnerOptions graph_runner_options; - graph_runner_options.options["ge.trainFlag"] = "1"; - auto graph_runner = std::make_shared(graph_runner_options); - transform::Status run_ret; - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - run_ret = graph_runner->RunGraph(run_options, ge_inputs, &ge_outputs); - } - if (run_ret != transform::Status::SUCCESS) { - MS_LOG(ERROR) << "GraphRunner fails to run graph"; - *status = PYNATIVE_GRAPH_GE_RUN_ERR; - return std::move(err_ret); - } - - std::vector graph_outputs = ConvertOutputTensors(op_exec_info, ge_outputs); - size_t output_size = graph_outputs.size(); - py::tuple result(output_size); - for (size_t i = 0; i < output_size; i++) { - MS_EXCEPTION_IF_NULL(graph_outputs[i]); - result[i] = *graph_outputs[i]; - } - - *status = PYNATIVE_SUCCESS; - MS_LOG(INFO) << "RunOpInGe end"; - return std::move(result); -} -} // namespace pynative -} // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pynative/pynative_execute_ge.h deleted file mode 100644 index 2dca3df018..0000000000 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.h +++ /dev/null @@ -1,46 +0,0 @@ -/** 
- * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ -#define MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ - -#include -#include -#include -#include -#include - -#include "pynative/base.h" -#include "transform/convert.h" -#include "transform/graph_runner.h" -#include "transform/types.h" -#include "utils/context/ms_context.h" - -using GeTensor = ge::Tensor; -using GeTensorPtr = std::shared_ptr; -using GeGraph = ge::Graph; -using GeGraphPtr = std::shared_ptr; - -namespace mindspore { -namespace pynative { -bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, - const std::unordered_map &attrs, const GeGraphPtr &graph); - -py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); -} // namespace pynative -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PYNATIVE_PYNATIVE_EXECUTE_GE_H_ diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt new file mode 100644 index 0000000000..9c95aee0dc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt @@ -0,0 +1,65 @@ +file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" + "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" +) + +if (ENABLE_GPU) + list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_init.cc") +else () + list(APPEND DEVICE_SRC_LIST "gpu/distribution/collective_fake_init.cc") +endif () + +if (ENABLE_D) + file(GLOB_RECURSE D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend/*.cc" "kernel_adjust.cc") +endif () + +if (ENABLE_CPU) + file(GLOB_RECURSE CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "cpu/*.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/mpi/mpi_adapter.cc") +endif () + +if (ENABLE_MPI) + # _ms_mpi + file(GLOB_RECURSE MPI_SRC_LIST "cpu/mpi/mpi_adapter.cc") + set_property(SOURCE ${MPI_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(mpi_adapter SHARED ${MPI_SRC_LIST}) + target_link_libraries(mpi_adapter PRIVATE mindspore::ompi) + + set_property(SOURCE "gpu/mpi/mpi_initializer.cc" + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + pybind11_add_module(_ms_mpi "gpu/mpi/mpi_initializer.cc") + target_link_libraries(_ms_mpi PRIVATE mindspore::pybind11_module mindspore::ompi) +endif () + +# gpu +if (ENABLE_GPU) + file(GLOB_RECURSE CUDA_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc" "gpu/*.cu") + + set(GPU_QUEUE_SRCS "gpu/blocking_queue.cc" "gpu/gpu_buffer_mgr.cc") + set(GPU_COLLECTIVE_SRCS "gpu/distribution/collective_wrapper.cc" + "gpu/distribution/mpi_wrapper.cc" + "gpu/distribution/nccl_wrapper.cc") + + # gpu_queue + list(REMOVE_ITEM CUDA_SRC_LIST ${GPU_QUEUE_SRCS}) + set_property(SOURCE ${GPU_QUEUE_SRCS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(gpu_queue 
SHARED ${GPU_QUEUE_SRCS}) + target_link_libraries(gpu_queue ${CMAKE_THREAD_LIBS_INIT} ${CUDA_PATH}/lib64/libcudart.so) + + list(REMOVE_ITEM CUDA_SRC_LIST "gpu/mpi/mpi_initializer.cc" ${GPU_COLLECTIVE_SRCS}) + + if (ENABLE_MPI) + include(ExternalProject) + # gpu_collective + set_property(SOURCE ${GPU_COLLECTIVE_SRCS} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) + add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS}) + target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl) + endif () + + # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST}) +endif () + +set_property(SOURCE ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST} + PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE) +add_library(_mindspore_runtime_device_obj OBJECT ${DEVICE_SRC_LIST} ${D_SRC_LIST} ${CPU_SRC_LIST}) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc new file mode 100644 index 0000000000..32238a0603 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -0,0 +1,415 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/ascend_device_address.h" +#include +#include +#include +#include +#include "runtime/mem.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/convert_tensor_utils.h" +#include "ir/dtype/type.h" +#include "ir/tensor.h" +#include "backend/kernel_compiler/common_utils.h" +#include "utils/utils.h" +#include "common/utils.h" +#include "common/trans.h" +#ifdef ENABLE_DUMP_E2E +#include "debug/e2e_dump.h" +#endif +#ifdef ENABLE_DEBUGGER +#include "debug/tensor_load.h" +#endif + +namespace mindspore { +namespace device { +namespace ascend { +const int FLOAT_LEN = sizeof(float); +const int FLOAT16_LEN = 2; // sizeof(float16); +const std::set kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, + kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, + kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04}; + +void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) { + auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; + } +} + +bool FloatToHalfAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { + auto elem_num = src_size / FLOAT_LEN; + if (elem_num != (dst_size / FLOAT16_LEN)) { + MS_EXCEPTION(ArgumentError) << "FloatToHalf failed. 
size not match src_size[" << src_size << "], dst_size[" + << dst_size << "]"; + } + std::vector half_data(elem_num); + FloatToHalf(half_data.data(), src, elem_num); + SyncMemory(dst, half_data.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); + return true; +} + +bool Float64ToFloatAndSyncHostToDevice(void *dst, size_t dst_size, const void *src, size_t src_size) { + if (src_size / 2 != dst_size) { + MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; + } + size_t elem_num = dst_size / sizeof(float); + auto host_tmp = std::vector(elem_num); + DoubleToFloat(host_tmp.data(), src, elem_num); + SyncMemory(dst, host_tmp.data(), dst_size, RT_MEMCPY_HOST_TO_DEVICE); + return true; +} + +bool SyncDeviceToHostAndHalfToFloat(void *dst, size_t dst_size, const void *src, size_t src_size) { + auto elem_num = src_size / FLOAT16_LEN; + if (elem_num != (dst_size / FLOAT_LEN)) { + MS_EXCEPTION(ArgumentError) << "HalfToFloat failed. size not match src_size[" << src_size << "], dst_size[" + << dst_size << "]"; + } + std::vector half_data(elem_num); + SyncMemory(half_data.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); + HalfToFloat(dst, half_data.data(), elem_num); + return true; +} + +bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *src, size_t src_size) { + if (src_size != dst_size / 2) { + MS_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]"; + } + size_t elem_num = src_size / sizeof(float); + auto host_tmp = std::vector(elem_num); + SyncMemory(host_tmp.data(), src, src_size, RT_MEMCPY_DEVICE_TO_HOST); + FloatToDouble(dst, host_tmp.data(), elem_num); + return true; +} + +void AscendDeviceAddress::SyncStream() const { + MS_LOG(INFO) << "Start!"; + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() != kPynativeMode) { + MS_LOG(INFO) << "Finish!"; + return; + } + auto device_id = ms_context->device_id(); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto ret = runtime_instance->SyncStream(); + if (!ret) { + MS_LOG(EXCEPTION) << "Sync stream error!"; + } + MS_LOG(INFO) << "Finish!"; +} + +bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t size, mindspore::TypeId type, + void *host_ptr) const { + MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + SyncStream(); + bool sync_ok = false; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { + if (type_id_ == type) { + SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST); + sync_ok = true; + } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { + sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); + } else { + auto shape_size = trans::ShapeSize(host_shape); + auto host = std::vector(size_); + SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; + sync_ok = trans::TransDataType(type_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "trans data type 
failed."; + return false; + } + } + } else { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { + sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; + } + } + if (!sync_ok) { + MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) + << ", host_type:" << TypeIdLabel(type); + return false; + } + return sync_ok; +} + +bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, + mindspore::TypeId type, void *host_ptr) const { + MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + bool sync_ok = false; + auto host_tmp = std::vector(size_); + SyncMemory(host_tmp.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + std::vector device_shape; + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { + device_shape = trans::TransShapeToDevice(host_shape, format_); + } else { + if (host_shape_.empty()) { + host_shape = trans::PaddingShapeTo4d(host_shape); + } else { + host_shape.clear(); + (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize); + } + + device_shape = trans::TransShapeToDevice(host_shape, format_); + } + if (type_id_ != type) { + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + auto host = std::vector(size_); + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; + sync_ok = trans::TransDataType(type_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + } else { + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + } + return sync_ok; +} + +bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t size, mindspore::TypeId type, + const void *host_ptr) const { + MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + SyncStream(); + bool sync_ok = false; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) { + if (type_id_ == type) { + SyncMemory(ptr_, host_ptr, size_, RT_MEMCPY_HOST_TO_DEVICE); + sync_ok = true; + } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { + sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); + } else { + auto shape_size = 
trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransDataType(type_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans data type failed."; + return false; + } + SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } + } else { + auto iter = kOpNeedTransFormat.find(format_); + if (iter != kOpNeedTransFormat.end()) { + sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } else { + MS_LOG(INFO) << "Can not find format transfer for :" << format_; + } + } + if (!sync_ok) { + MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) + << ", host_type:" << TypeIdLabel(type); + return false; + } + return sync_ok; +} + +bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, + mindspore::TypeId type, const void *host_ptr) const { + bool sync_ok = false; + MS_LOG(INFO) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) + << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + std::vector host_shape; + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.emplace_back(1); + } + std::vector device_shape; + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) { + device_shape = trans::TransShapeToDevice(host_shape, format_); + } else { + host_shape = trans::PaddingShapeTo4d(host_shape); + device_shape = trans::TransShapeToDevice(host_shape, format_); + } + if (type_id_ != type) { + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransDataType(type_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans datatype failed."; + return false; + } + const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, + host_shape, device_shape, type_id_}; + auto dst_tmp = std::vector(size_); + sync_ok = trans::TransFormat(format_args, dst_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } else { + const trans::FormatArgs format_args{host_ptr, size_, kOpFormat_NCHW, format_, host_shape, device_shape, type_id_}; + auto host_tmp = std::vector(size_); + sync_ok = trans::TransFormat(format_args, host_tmp.data()); + if (!sync_ok) { + MS_LOG(ERROR) << "Trans format failed."; + return false; + } + SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); + } + return sync_ok; +} + +void AscendDeviceAddress::UpdateCommunicationAddress() { + MS_EXCEPTION_IF_NULL(ptr_); + communication_ptr_ = reinterpret_cast(ptr_) - kMemAlignSize; +} + +AscendDeviceAddress::~AscendDeviceAddress() { + if (ptr_ == nullptr) { + return; + } + if (from_mem_pool_) { + if (communication_ptr_ != nullptr) { + AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_); + communication_ptr_ = nullptr; + } else { + AscendMemoryPool::GetInstance().FreeTensorMem(ptr_); + } + ptr_ = nullptr; + } +} + +#ifdef ENABLE_DUMP_E2E +bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &filepath, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type) const { + bool ret = 
false; + if (filepath.empty()) { + MS_LOG(ERROR) << "Dump file path is null!"; + return ret; + } + std::string shape = "shape"; + if (host_shape.size()) { + for (auto &value : host_shape) { + shape = shape + '_' + std::to_string(value); + } + } else { + shape = shape + "_0"; + } + std::string file_extension = ".bin"; + if (trans_flag) { + std::string path = filepath + '_' + shape + '_' + TypeIdLabel(host_type) + '_' + host_fmt + file_extension; + MS_LOG(INFO) << "E2E Dump path is " << path; + mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); + size_t host_size = out_tensor->data().nbytes(); + ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); + if (!ret) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + ret = mindspore::Dump::DumpToFile(path, out_tensor->data_c(), host_size); + } else { + auto host_tmp = std::vector(size_); + auto ret_rt_memcpy = rtMemcpy(host_tmp.data(), size_, ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; + } + std::string path = + filepath + '_' + shape + '_' + TypeIdToType(type_id_)->ToString() + '_' + format_ + file_extension; + MS_LOG(INFO) << "E2E Dump path is " << path; + ret = mindspore::Dump::DumpToFile(path, host_tmp.data(), size_); + } + + return ret; +} +#endif + +#ifdef ENABLE_DEBUGGER +bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, + const std::string &host_fmt, const std::vector &host_shape, + TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const { + bool ret = false; + + DebugServices *debug_services = debugger->debug_services(); + TensorLoader *tensor_loader = debug_services->get_tensor_loader(); + + if (trans_flag) { + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + mindspore::tensor::TensorPtr out_tensor = std::make_shared(host_type, host_shape); + size_t host_size = out_tensor->data().nbytes(); + ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c()); + if (!ret) { + MS_LOG(ERROR) << "Copy device mem to host failed"; + return ret; + } + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetTensor(out_tensor); + tensor_data->SetSlot(slot); + ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); + } else { + mindspore::tensor::TensorPtr out_tensor = std::make_shared(type_id_, host_shape); + size_t host_size = out_tensor->data().nbytes(); + auto ret_rt_memcpy = rtMemcpy(out_tensor->data_c(), host_size, ptr_, host_size, RT_MEMCPY_DEVICE_TO_HOST); + + auto tensor_data = std::make_shared(); + tensor_data->SetName(tensor_name); + tensor_data->SetExecutionOrder(execution_order); + tensor_data->SetTensor(out_tensor); + tensor_data->SetSlot(slot); + ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev); + if (ret_rt_memcpy != RT_ERROR_NONE) { + MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]"; + } + MS_LOG(INFO) << "E2E tensor name is " << tensor_name; + } + return ret; +} +#endif +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h new file mode 100644 index 0000000000..78d7006b56 --- /dev/null +++ 
b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ + +#include +#include +#include +#include "runtime/device/device_address.h" +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "ir/dtype.h" + +namespace mindspore { +#ifdef ENABLE_DEBUGGER +class Debugger; +#endif +namespace device { +namespace ascend { +class AscendDeviceAddress : public DeviceAddress { + public: + explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} + explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) + : DeviceAddress(ptr, size, format, type_id) {} + ~AscendDeviceAddress() override; + bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } + void UpdateCommunicationAddress() override; +#ifdef ENABLE_DUMP_E2E + bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type) const; +#endif +#ifdef ENABLE_DEBUGGER + bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, + const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + bool keep_prev) const; +#endif + + private: + bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; + bool ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, TypeId type, + const void *host_ptr) const; + void SyncStream() const; + uint8_t *communication_ptr_{nullptr}; +}; +using AscendDeviceAddressPtr = std::shared_ptr; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc new file mode 100644 index 0000000000..07669a9b3c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -0,0 +1,713 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#define PATH_MAX 0x3ffff +#include "runtime/device/ascend/ascend_kernel_runtime.h" +#include +#include +#include +#include +#include +#include +#include "runtime/device/ascend/ascend_device_address.h" +#include "runtime/device/cpu/mpi/mpi_adapter.h" +#include "utils/context/ms_context.h" +#include "utils/mpi/mpi_config.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "hccl/hcom.h" +#include "common/trans.h" +#include "runtime/context.h" +#include "runtime/device/ascend/ascend_label_assign.h" +#include "runtime/device/ascend/ascend_stream_assign.h" +#include "runtime/device/ascend/ascend_memory_pool.h" +#include "framework/ge_runtime/model_runner.h" +#include "runtime/device/ascend/tasksink/task_generator.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/tbe/tbe_python_funcs.h" +#include "backend/optimizer/mem_reuse/mem_reuse_checker.h" +#include "runtime/device/ascend/ascend_memory_manager.h" +#include "debug/tensor_load.h" + +using ge::model_runner::ModelRunner; +using mindspore::device::ascend::ProfilingManager; +using mindspore::device::ascend::ProfilingUtils; +using mindspore::device::ascend::tasksink::TaskGenerator; +using mindspore::kernel::tbe::TbeUtils; +using std::vector; + +namespace mindspore { +namespace device { +namespace ascend { +static const size_t PRAMATER_OUTPUT_INDEX = 0; +namespace { +std::string GetRankId() { + std::string rank_id_str; +#ifdef ENABLE_MPI + auto mpi_config_ptr = MpiConfig::GetInstance(); + MS_EXCEPTION_IF_NULL(mpi_config_ptr); + if (mpi_config_ptr->enable_mpi()) { + auto mpi_instance = device::cpu::MPIAdapter::Instance(); + MS_EXCEPTION_IF_NULL(mpi_instance); + int rank_id = mpi_instance->GetRankId(); + const char *offset = std::getenv("RANK_OFFSET"); + if (offset != nullptr) { + try { + int rank_offset = std::stoi(offset); + rank_id += rank_offset; + } catch (std::invalid_argument) { + MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset; + } catch (std::out_of_range) { + MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset; + } + } + rank_id_str = std::to_string(rank_id); + } else { + rank_id_str = std::getenv("RANK_ID"); + } +#else + rank_id_str = std::getenv("RANK_ID"); +#endif + if (rank_id_str.empty()) { + MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID"; + } + return rank_id_str; +} +} // namespace + +AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } + +void AscendKernelRuntime::ClearGraphModelMap() { +#ifdef ENABLE_DATA_DUMP + for (auto &iter : graph_data_dumper_) { + MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first; + iter.second->UnloadDumpInfo(); + } + graph_data_dumper_.clear(); +#endif + for (auto &iter : graph_model_map_) { + MS_LOG(INFO) << "Ge UnloadModel " << iter.first; + auto ret = ModelRunner::Instance().UnloadModel(iter.first); + if (!ret) { + MS_LOG(ERROR) << "UnloadModel failed"; + } + } +} + +void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; + auto iter = graph_model_map_.find(graph_id); + if (iter == graph_model_map_.end()) { + MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; + return; + } + MS_LOG(DEBUG) << "Ge UnloadModel " << iter->first; + auto ret = 
+bool AscendKernelRuntime::Init() {
+  if (initialized_) {
+    return true;
+  }
+  bool ret = false;
+#ifdef ENABLE_DUMP_E2E
+  ret = SetDumpConf();
+  if (!ret) {
+    MS_LOG(INFO) << "No dump conf to set!";
+  }
+#endif
+
+#ifdef ENABLE_DATA_DUMP
+  DataDumpParser::GetInstance().ParseDumpConfig();
+#endif
+
+  // Start up profiling before rtSetDevice
+  ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
+  if (!ret) {
+    MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed.";
+  }
+
+  ret = InitDevice();
+  if (!ret) {
+    return ret;
+  }
+  mem_manager_ = std::make_shared<AscendMemoryManager>();
+  MS_EXCEPTION_IF_NULL(mem_manager_);
+  mem_manager_->MallocDeviceMemory();
+
+  initialized_ = true;
+  return ret;
+}
+
path:" << filepath + << ", host_format:" << format << ".!"; + } + } + } +} + +void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_path, DumpConfPtr dump_conf) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(dump_conf); + bool trans_flag = dump_conf->trans_flag(); + const auto ¶meters = graph->inputs(); + for (auto &item : parameters) { + if (!item->isa()) { + continue; + } + std::string parameter_name = item->fullname_with_scope(); + if (!dump_conf->IsKernelNeedDump(parameter_name)) { + continue; + } + auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX); + (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), + [](size_t inner_item) { return SizeToInt(inner_item); }); + } + auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX); + auto format = kOpFormat_DEFAULT; + string filepath = dump_path + '/' + parameter_name + '_' + "output_0"; + auto ascend_addr = dynamic_cast(addr); + auto ret = ascend_addr->DumpMemToFile(trans_flag, filepath, format, int_shapes, type); + if (!ret) { + MS_LOG(ERROR) << "DumpMemToFile Failed: flag:" << trans_flag << ", path:" << filepath + << ", host_format:" << format << ".!"; + } + } +} +} // namespace +#endif + +bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); +#ifdef ENABLE_DUMP_E2E + MS_LOG(INFO) << "Start dump step"; + DumpConfPtr dump_conf = GetDumpConf(); + MS_EXCEPTION_IF_NULL(dump_conf); + dump_conf->UpdataCurIter(); + bool dump_flag = dump_conf->dump_enable(); + if (!dump_flag) { + MS_LOG(INFO) << "Dump flag is disable, pass dump step"; + return true; + } + uint32_t cur_iter = dump_conf->cur_iter(); + if (dump_conf->dump_iter() != 0) { + if (cur_iter != dump_conf->dump_iter()) { + return true; + } + } + MS_LOG(INFO) << "Cur iter is " << cur_iter; + std::string net_name = dump_conf->dump_net_name(); + std::string iterator = to_string(cur_iter); + std::string dump_path = dump_conf->dump_path(); + if (dump_path.back() == '/') { + dump_path = dump_path + net_name + '/' + iterator; + } else { + dump_path = dump_path + '/' + net_name + '/' + iterator; + } + // dump output + DumpOutput(graph, dump_path, dump_conf); + // dump parameters + DumpParameters(graph, dump_path, dump_conf); +#endif + return true; +} + +#ifdef ENABLE_DEBUGGER +namespace { +void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { + MS_EXCEPTION_IF_NULL(graph); + bool trans_flag = false; + const auto &apply_kernels = graph->execution_order(); + // for kernels, execution order starts from 1 + int exec_order = 1; + for (const auto &node : apply_kernels) { + MS_EXCEPTION_IF_NULL(node); + auto node_name = AnfAlgo::GetCNodeName(node); + std::string kernel_name = node->fullname_with_scope(); + auto output_size = AnfAlgo::GetOutputTensorNum(node); + for (size_t j = 0; j < output_size; ++j) { + auto addr = AnfAlgo::GetOutputAddr(node, j); + auto type = AnfAlgo::GetOutputInferDataType(node, j); + auto format = kOpFormat_DEFAULT; + string tensor_name = kernel_name + ':' + std::to_string(j); + auto ascend_addr = dynamic_cast(addr); + std::vector int_shapes; + if (trans_flag) { + int_shapes = trans::GetRuntimePaddingShape(node, j); + } else { + auto shape = AnfAlgo::GetOutputDeviceShape(node, j); + 
+bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+#ifdef ENABLE_DUMP_E2E
+  MS_LOG(INFO) << "Start dump step";
+  DumpConfPtr dump_conf = GetDumpConf();
+  MS_EXCEPTION_IF_NULL(dump_conf);
+  dump_conf->UpdataCurIter();
+  bool dump_flag = dump_conf->dump_enable();
+  if (!dump_flag) {
+    MS_LOG(INFO) << "Dump flag is disable, pass dump step";
+    return true;
+  }
+  uint32_t cur_iter = dump_conf->cur_iter();
+  if (dump_conf->dump_iter() != 0) {
+    if (cur_iter != dump_conf->dump_iter()) {
+      return true;
+    }
+  }
+  MS_LOG(INFO) << "Cur iter is " << cur_iter;
+  std::string net_name = dump_conf->dump_net_name();
+  std::string iterator = std::to_string(cur_iter);
+  std::string dump_path = dump_conf->dump_path();
+  if (dump_path.back() == '/') {
+    dump_path = dump_path + net_name + '/' + iterator;
+  } else {
+    dump_path = dump_path + '/' + net_name + '/' + iterator;
+  }
+  // dump output
+  DumpOutput(graph, dump_path, dump_conf);
+  // dump parameters
+  DumpParameters(graph, dump_path, dump_conf);
+#endif
+  return true;
+}
+
+#ifdef ENABLE_DEBUGGER
+namespace {
+void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+  bool trans_flag = false;
+  const auto &apply_kernels = graph->execution_order();
+  // for kernels, execution order starts from 1
+  int exec_order = 1;
+  for (const auto &node : apply_kernels) {
+    MS_EXCEPTION_IF_NULL(node);
+    auto node_name = AnfAlgo::GetCNodeName(node);
+    std::string kernel_name = node->fullname_with_scope();
+    auto output_size = AnfAlgo::GetOutputTensorNum(node);
+    for (size_t j = 0; j < output_size; ++j) {
+      auto addr = AnfAlgo::GetOutputAddr(node, j);
+      auto type = AnfAlgo::GetOutputInferDataType(node, j);
+      auto format = kOpFormat_DEFAULT;
+      string tensor_name = kernel_name + ':' + std::to_string(j);
+      auto ascend_addr = dynamic_cast<const AscendDeviceAddress *>(addr);
+      std::vector<int> int_shapes;
+      if (trans_flag) {
+        int_shapes = trans::GetRuntimePaddingShape(node, j);
+      } else {
+        auto shape = AnfAlgo::GetOutputDeviceShape(node, j);
+        (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                             [](size_t inner_item) { return SizeToInt(inner_item); });
+      }
+      auto ret =
+        ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger, false);
+      if (!ret) {
+        MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", tensor_name:" << tensor_name
+                      << ", host_format:" << format << "!";
+      }
+    }
+    exec_order = exec_order + 1;
+  }
+}
+
+void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+  bool trans_flag = false;
+  const auto &parameters = graph->inputs();
+  // for parameters, set its execution order to be 0;
+  int exec_order = 0;
+  for (auto &item : parameters) {
+    if (!item->isa<Parameter>()) {
+      continue;
+    }
+    std::string parameter_name = item->fullname_with_scope();
+    auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX);
+    auto type = AnfAlgo::GetOutputInferDataType(item, PRAMATER_OUTPUT_INDEX);
+    auto format = kOpFormat_DEFAULT;
+    string tensor_name = parameter_name + ':' + "0";
+    auto ascend_addr = dynamic_cast<const AscendDeviceAddress *>(addr);
+    std::vector<int> int_shapes;
+    if (trans_flag) {
+      int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX);
+    } else {
+      auto shape = AnfAlgo::GetOutputDeviceShape(item, PRAMATER_OUTPUT_INDEX);
+      (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
+                           [](size_t inner_item) { return SizeToInt(inner_item); });
+    }
+    auto ret =
+      ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true);
+    if (!ret) {
+      MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", tensor_name:" << tensor_name
+                    << ", host_format:" << format << "!";
+    }
+  }
+}
+}  // namespace
+#endif
+
+bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
+  MS_EXCEPTION_IF_NULL(graph);
+#ifdef ENABLE_DEBUGGER
+  MS_LOG(INFO) << "Start load step";
+  uint32_t cur_iter = 0;
+  MS_LOG(INFO) << "Cur iter is " << cur_iter;
+  // load output
+  LoadOutput(graph, debugger);
+  // load parameters
+  LoadParameters(graph, debugger);
+#endif
+  return true;
+}
+
+bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
+  if (AnfAlgo::OutputAddrExist(kernel, index)) {
+    auto address = AnfAlgo::GetOutputAddr(kernel, index);
+    MS_EXCEPTION_IF_NULL(address);
+    return address->DeviceType() == DeviceAddressType::kAscend;
+  }
+  return false;
+}
+
+DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
+                                                          TypeId type_id) {
+  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
+}
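+// In task-sink mode the whole graph is compiled into a task list and wrapped
+// in a ge::model_runner::DavinciModel that records the stream/event/label
+// resources the sunk graph needs.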
GraphId:" << graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_task_sink = context_ptr->enable_task_sink(); + if (!is_task_sink) { + return true; + } +#ifdef MEM_REUSE_DEBUG + if (!context_ptr->enable_mem_reuse()) { + // Get normal graph ir for memreuse + mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(graph); + } +#endif + vector> task_info_list; + auto anf_node_list = graph->execution_order(); + TaskGenerator::GenTasks(anf_node_list, &task_info_list, graph->graph_id()); + // Store the task_info_list + auto insert_ret = task_map_.insert(std::make_pair(graph->graph_id(), task_info_list)); + if (!insert_ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; + } + // Graph may have no compute node, such TensorAddGrad. + if (task_info_list.empty()) { + MS_LOG(WARNING) << "Graph " << graph->graph_id() << " have no compute node"; + return true; + } + AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); + // the streams' flag not HEAD_STREAM + std::vector wait_active_stream_list; + assign_instance.GetWaitStreams(&wait_active_stream_list); + std::vector force_copy_stream_list; + assign_instance.GetHcomStreams(&force_copy_stream_list); + MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() + << ", total event num:" << resource_manager.get_cur_event_num() + << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) + << ", wait_active_stream_list size:" << wait_active_stream_list.size() + << ", force_copy_stream_list size:" << force_copy_stream_list.size(); + std::vector> empty_list; + auto model = std::make_shared( + task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, + 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), + resource_manager.get_cur_event_num(), 0); + auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); + if (!ret.second) { + MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session."; + } + MS_LOG(INFO) << "TaskGenerator GetTaskInfo end..."; + return true; +} + +bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { + if (graph == nullptr) { + MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; + } + MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id(); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_task_sink = context_ptr->enable_task_sink(); + if (!is_task_sink) { + return true; + } + + if (GraphWithEmptyTaskList(graph)) { + MS_LOG(WARNING) << "LoadTask end, task list is empty"; + return true; + } + + auto model_iter = graph_model_map_.find(graph->graph_id()); + if (model_iter == graph_model_map_.end()) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! 
+bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
+  if (graph == nullptr) {
+    MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
+  }
+  MS_LOG(INFO) << "LoadTask start. GraphId:" << graph->graph_id();
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  bool is_task_sink = context_ptr->enable_task_sink();
+  if (!is_task_sink) {
+    return true;
+  }
+
+  if (GraphWithEmptyTaskList(graph)) {
+    MS_LOG(WARNING) << "LoadTask end, task list is empty";
+    return true;
+  }
+
+  auto model_iter = graph_model_map_.find(graph->graph_id());
+  if (model_iter == graph_model_map_.end()) {
+    MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph LoadTask without GenTask.";
+    return false;
+  }
+
+  std::shared_ptr<ge::ModelListener> listener;
+  MS_LOG(INFO) << "LoadDavinciModel model_id:" << model_iter->first;
+  bool status =
+    ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second, listener);
+  if (!status) {
+    MS_LOG(EXCEPTION) << "Load Task Failed";
+  }
+  if (ProfilingManager::GetInstance().IsProfiling()) {
+    auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
+    auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first);
+    ProfilingUtils::ReportProfilingData(task_ids, stream_ids, NOT_NULL(graph));
+  }
+
+#ifdef ENABLE_DATA_DUMP
+  LaunchDataDump(NOT_NULL(graph));
+#endif
+  if (!ModelRunner::Instance().LoadModelComplete(model_iter->first)) {
+    MS_LOG(ERROR) << "Call ge runtime LoadModelComplete failed";
+    return false;
+  }
+  return true;
+}
+
+#ifdef ENABLE_DATA_DUMP
+void AscendKernelRuntime::LaunchDataDump(NotNull<const session::KernelGraph *> graph) {
+  if (!DataDumpParser::GetInstance().DumpEnabled()) {
+    return;
+  }
+  auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph->graph_id());
+  auto data_dumper = std::make_shared<DataDumper>(graph.get(), runtime_info_map);
+  MS_EXCEPTION_IF_NULL(data_dumper);
+  data_dumper->LoadDumpInfo();
+  auto ret = graph_data_dumper_.try_emplace(graph->graph_id(), data_dumper);
+  if (!ret.second) {
+    MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph->graph_id() << " data dumper failed";
+  }
+}
+#endif
+
+void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
+  auto task_ids = ModelRunner::Instance().GetTaskIdList(graph_id);
+  auto graph_task_names = ProfilingUtils::graph_kernel_name();
+  auto iter = graph_task_names.find(graph_id);
+  if (iter != graph_task_names.end()) {
+    const auto &task_names = iter->second;
+    if (task_ids.size() != task_names.size()) {
+      MS_LOG(WARNING) << "Task_ids and task_names size not match";
+      return;
+    }
+    for (size_t i = 0; i < task_ids.size(); ++i) {
+      MS_LOG(INFO) << "Task_id:" << task_ids[i] << " task_name:" << task_names[i];
+    }
+  }
+}
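+// RunTask replays the loaded model on device; the ge::InputData/OutputData
+// passed to RunModel stay empty because all graph data is already bound to
+// device memory in sink mode.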
+bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
+
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  ge::InputData input_tensors = ge::InputData();
+  ge::OutputData *output_tensors = nullptr;
+  if (GraphWithEmptyTaskList(graph)) {
+    MS_LOG(WARNING) << "RunTask end, no task info found";
+    return true;
+  }
+
+  if (!CheckGraphIdValid(graph->graph_id())) {
+    MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Invalid! Graph RunTask without GenTask.";
+    return false;
+  }
+
+  bool status = ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
+  if (!status) {
+    MS_LOG(ERROR) << "Run task failed";
+    DebugTaskIdName(graph->graph_id());
+    return false;
+  }
+  return true;
+}
+
+bool AscendKernelRuntime::SyncStream() {
+  if (RT_ERROR_NONE != rtStreamSynchronize(stream_)) {
+    MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
+    return false;
+  }
+  return true;
+}
+
+bool AscendKernelRuntime::InitDevice() {
+  int device_count = 0;
+  auto ret = rtGetDeviceCount(&device_count);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
+  }
+
+  ret = rtSetDevice(device_id_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
+  }
+
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr == nullptr) {
+    MS_LOG(ERROR) << "Get MsContext instance failed";
+    return false;
+  }
+  if (context_ptr->enable_hccl()) {
+    if (!HcclInit()) {
+      MS_LOG(ERROR) << "HcclInit init failed";
+      return false;
+    }
+  }
+
+  ret = rtCtxCreate(&rt_context_, 0, device_id_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
+  }
+
+  ret = rtCtxSetCurrent(rt_context_);
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
+  }
+
+  ret = rtStreamCreate(&stream_, 0);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
+  }
+
+  return true;
+}
+
+bool AscendKernelRuntime::ResetDevice() {
+  auto ret = rtCtxSetCurrent(rt_context_);
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(ERROR) << "Call rtCtxSetCurrent failed";
+    return false;
+  }
+
+  if (stream_ != nullptr) {
+    ret = rtStreamDestroy(stream_);
+    if (ret != RT_ERROR_NONE) {
+      MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
+    }
+    stream_ = nullptr;
+  }
+
+  if (rt_context_ != nullptr) {
+    ret = rtCtxDestroy(rt_context_);
+    if (ret != RT_ERROR_NONE) {
+      MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
+    }
+    rt_context_ = nullptr;
+  }
+  return true;
+}
+
+bool AscendKernelRuntime::HcclInit() {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (!context_ptr->IsTsdOpened()) {
+    MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
+  }
+  MS_LOG(INFO) << "Do hcom init";
+  auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
+  if (config_path_str == nullptr) {
+    config_path_str = std::getenv("RANK_TABLE_FILE");
+    if (config_path_str == nullptr) {
+      MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE";
+      return false;
+    }
+  }
+  if (strlen(config_path_str) > PATH_MAX) {
+    MS_LOG(ERROR) << "File path oversize";
+    return false;
+  }
+  std::string rank_id_str = GetRankId();
+  auto full_path = realpath(config_path_str, nullptr);
+  if (full_path == nullptr) {
+    MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";
+    return false;
+  }
+  MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
+  hcclResult_t res = hcom_init(full_path, rank_id_str.c_str());
+  free(full_path);
+  if (res != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res);
+    return false;
+  }
+  return true;
+}
+
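+// HCCL teardown mirrors HcclInit above: hcom_destroy is only called while the
+// context still has hccl enabled, and the flag is cleared afterwards so a
+// second finalize becomes a no-op.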
+bool AscendKernelRuntime::DestroyHccl() {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (!NeedDestroyHccl()) {
+    MS_LOG(INFO) << "Hccl is not enable, no need to close.";
+    return true;
+  }
+  hcclResult_t res = hcom_destroy();
+  if (res != HCCL_SUCCESS) {
+    MS_LOG(ERROR) << "Hccl destroy failed";
+    return false;
+  }
+  MS_LOG(INFO) << "Hccl destroy successful, status = " << res << ".";
+  context_ptr->set_enable_hccl(false);
+  return true;
+}
+
+bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph *graph) const {
+  auto iter = task_map_.find(graph->graph_id());
+  if (iter == task_map_.end()) {
+    MS_LOG(EXCEPTION) << "Unknown graph ptr";
+  }
+  return iter->second.empty();
+}
+
+bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const {
+  return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end();
+}
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
new file mode 100644
index 0000000000..4f1663d4d5
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
+#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
+#include <memory>
+#include <vector>
+#include <string>
+#include <unordered_map>
+#include "runtime/device/kernel_runtime.h"
+#include "runtime/context.h"
+#include "framework/ge_runtime/davinci_model.h"
+#include "runtime/device/kernel_runtime_manager.h"
+#include "backend/session/session_basic.h"
+#ifdef ENABLE_DATA_DUMP
+#include "debug/data_dump_parser.h"
+#include "runtime/device/ascend/dump/data_dumper.h"
+#endif
+
+using ge::model_runner::TaskInfo;
+using std::unordered_map;
+using std::vector;
+namespace mindspore {
+namespace device {
+namespace ascend {
+class AscendKernelRuntime : public KernelRuntime {
+ public:
+  AscendKernelRuntime() = default;
+  ~AscendKernelRuntime() override;
+  bool Init() override;
+  bool DumpData(session::KernelGraph *graph) override;
+  bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
+  bool GenTask(const session::KernelGraph *graph) override;
+  bool RunTask(const session::KernelGraph *graph) override;
+  bool LoadTask(const session::KernelGraph *graph) override;
+  void ClearGraphRuntimeResource(uint32_t graph_id) override;
+  bool SyncStream() override;
+
+ protected:
+  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
+                                       TypeId type_id) override;
+  bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
+
+ private:
+  bool InitDevice();
+  bool ResetDevice();
+  bool HcclInit();
+  bool NeedDestroyHccl();
+  bool DestroyHccl();
+
+  void ClearGraphModelMap();
+  void ReleaseDeviceRes() override;
+  bool GraphWithEmptyTaskList(const session::KernelGraph *graph) const;
+  bool CheckGraphIdValid(GraphId graph_id) const;
+  static void DebugTaskIdName(GraphId graph_id);
+
+  rtContext_t rt_context_{nullptr};
+  bool initialized_{false};
+  unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_;
+  unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_;
+#ifdef ENABLE_DATA_DUMP
+  void LaunchDataDump(NotNull<const session::KernelGraph *> graph);
+  unordered_map<GraphId, std::shared_ptr<DataDumper>> graph_data_dumper_;
+#endif
+};
+
+MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc
new file mode 100644
index 0000000000..035f4dd8e3
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.cc
@@ -0,0 +1,163 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <set>
+#include <vector>
+#include "runtime/device/ascend/ascend_label_assign.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+static constexpr uint32_t kLabelGotoLabelId = 1;
+static constexpr uint32_t kLabelSwitchLabelId = 2;
+
+namespace mindspore {
+namespace device {
+namespace ascend {
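+// LabelGoto/LabelSwitch nodes carry their LabelSet targets as extra inputs.
+// The helpers below read the label id that was assigned to each target
+// LabelSet, record it as an attribute (kAttrLabelIndex / kAttrLabelSwitchList)
+// and then strip the auxiliary inputs again.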
+static void UpdateLabelGoto(NotNull<CNodePtr> node) {
+  if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
+    return;
+  }
+  if (node->size() <= kLabelGotoLabelId) {
+    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
+  }
+
+  auto input = node->input(kLabelGotoLabelId);
+  uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
+  AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get());
+  MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id;
+  node->set_inputs({node->input(0)});
+}
+
+static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
+  if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) {
+    return;
+  }
+  if (node->size() <= kLabelGotoLabelId) {
+    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size();
+  }
+  std::vector<uint32_t> label_list;
+  for (size_t i = kLabelSwitchLabelId; i < node->size(); ++i) {
+    auto input = node->input(i);
+    if (!input->isa<CNode>() || AnfAlgo::GetCNodeName(input) != kLabelSetOpName) {
+      break;
+    }
+
+    uint32_t goto_label_id = AnfAlgo::GetNodeAttr<uint32_t>(input, kAttrLabelIndex);
+    label_list.push_back(goto_label_id);
+    MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
+  }
+  AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
+  node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)});
+}
+
+static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,
+                                   NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
+  if (memo->find(graph.get()) != memo->end()) {
+    return;
+  }
+  memo->insert(graph.get());
+
+  MS_LOG(INFO) << "Assign label for " << graph->ToString();
+  graph->SetExecOrderByDefault();
+  auto nodes = graph->execution_order();
+
+  for (auto &node : nodes) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::string node_name = AnfAlgo::GetCNodeName(node);
+    if (node_name == kLabelSetOpName && !AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) {
+      AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(*label_id), node);
+      MS_LOG(INFO) << "Node " << node->DebugString() << " assign label id " << *label_id;
+      ++(*label_id);
+    }
+  }
+
+  for (auto &cg : graph->child_graph_order()) {
+    AssignLabelForLabelSet(NOT_NULL(cg), label_id, memo);
+  }
+}
+
+static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGraph>> graph,
+                                     NotNull<std::set<std::shared_ptr<session::KernelGraph>> *> memo) {
+  if (memo->find(graph.get()) != memo->end()) {
+    return;
+  }
+  memo->insert(graph.get());
+
+  MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
+
+  auto nodes = graph->execution_order();
+  auto end_goto = graph->get_end_goto();
+  if (end_goto != nullptr) {
+    nodes.push_back(end_goto);
+  }
+  for (auto &node : nodes) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::string node_name = AnfAlgo::GetCNodeName(node);
+    if (node_name == kLabelGotoOpName) {
+      UpdateLabelGoto(NOT_NULL(cnode));
+      cnode->set_abstract(nullptr);
+    }
+
+    if (node_name == kLabelSwitchOpName) {
+      UpdateLabelSwitch(NOT_NULL(cnode));
+    }
+  }
+  for (auto &cg : graph->child_graph_order()) {
+    AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
+  }
+  graph->SetExecOrderByDefault();
+}
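+// Label assignment runs in two passes over the graph tree: the first pass
+// gives every LabelSet a unique id, the second pass rewrites LabelGoto and
+// LabelSwitch nodes against those ids and re-derives the execution order.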
+void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
+  MS_LOG(INFO) << "Assign label start.";
+  std::set<std::shared_ptr<session::KernelGraph>> memo;
+  uint32_t label_id = 0;
+  AssignLabelForLabelSet(graph, NOT_NULL(&label_id), NOT_NULL(&memo));
+  memo.clear();
+  {
+    std::lock_guard lock(label_num_mutex_);
+    label_num_[graph.get().get()] = label_id;
+  }
+  AssignLabelForGotoSwitch(graph, NOT_NULL(&memo));
+  MS_LOG(INFO) << "Assign label end.";
+}
+
+uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) {
+  std::lock_guard lock(label_num_mutex_);
+  auto iter = label_num_.find(graph.get());
+  if (iter == label_num_.end()) {
+    MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, default is 0.";
+    return 0;
+  }
+  return iter->second;
+}
+
+uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
+  return GetLabelNum(NOT_NULL(graph.get().get()));
+}
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h
new file mode 100644
index 0000000000..6b09f2940e
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_label_assign.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
+#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
+
+#include <map>
+#include <mutex>
+#include "backend/session/kernel_graph.h"
+#include "utils/contract.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+class AscendLabelAssign {
+ public:
+  static AscendLabelAssign &GetInstance() {
+    static AscendLabelAssign instance;  // Guaranteed to be destroyed.
+    return instance;
+  }
+
+  AscendLabelAssign(const AscendLabelAssign &) = delete;
+  AscendLabelAssign &operator=(const AscendLabelAssign &) = delete;
+
+  void AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph);
+  uint32_t GetLabelNum(NotNull<const session::KernelGraph *> graph);
+  uint32_t GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph);
+
+ private:
+  AscendLabelAssign() = default;
+  ~AscendLabelAssign() = default;
+
+  std::map<const session::KernelGraph *, uint32_t> label_num_;
+  std::mutex label_num_mutex_;
+};
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_LABEL_ASSIGN_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
new file mode 100644
index 0000000000..f9da0850c6
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
@@ -0,0 +1,137 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <string>
+#include "runtime/device/ascend/ascend_memory_manager.h"
+#include "runtime/device/ascend/ascend_memory_pool.h"
+#include "utils/context/ms_context.h"
+#include "runtime/mem.h"
+namespace mindspore {
+namespace device {
+namespace ascend {
+constexpr uint64_t kAscendDeviceMemGB = 30;
+constexpr uint64_t kMemSizeGB = 30;
+constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB);
+
+void AscendMemoryManager::MallocDeviceMemory() {
+  auto context_mem = GetDeviceMemSizeFromContext();
+  device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
+  dynamic_mem_offset_ = device_mem_size_;
+  auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), dynamic_mem_offset_, RT_MEMORY_HBM);
+
+  if (ret != RT_ERROR_NONE) {
+    MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << dynamic_mem_offset_ << "] fail, ret[" << ret << "]";
+  }
+
+  AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_base_);
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
+}
+
+uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  auto variable_memory_max_size = context->variable_memory_max_size();
+  if (variable_memory_max_size == "0") {
+    return 0;
+  }
+  MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size;
+  auto pos = variable_memory_max_size.find('*');
+  if (pos == std::string::npos) {
+    MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
+  }
+  auto gb_str = variable_memory_max_size.substr(0, pos);
+  auto gb_var = std::stoull(gb_str);
+  MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
+  if (gb_var > kAscendDeviceMemGB || gb_var == 0) {
+    MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB";
+  }
+  return gb_var << kMemSizeGB;
+}
+
+void AscendMemoryManager::FreeDeviceMemory() {
+  if (device_mem_base_ != nullptr) {
+    auto ret = rtFree(device_mem_base_);
+    if (ret != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
+    }
+    device_mem_base_ = nullptr;
+  }
+  if (device_mem_pool_base_ != nullptr) {
+    auto ret = rtFree(device_mem_pool_base_);
+    if (ret != RT_ERROR_NONE) {
+      MS_LOG(ERROR) << "rtFree mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
+    }
+    device_mem_pool_base_ = nullptr;
+  }
+}
+
+void AscendMemoryManager::ResetDynamicMemory() {
+  total_dynamic_size_ = 0;
+  dynamic_mem_offset_ = device_mem_size_;
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
+}
+
+void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
+  auto align_size = GetCommonAlignSize(size);
+  return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
+}
+
+uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
+  size_t align_size = 0;
+  if (communication_mem) {
+    align_size = GetCommunicationAlignSize(size);
+  } else {
+    align_size = GetCommonAlignSize(size);
+  }
+  if (communication_mem) {
+    // create protect area [kMemAlignSize -- data -- kMemAlignSize]
+    uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
+    return alloc_address + kMemAlignSize;
+  } else {
+    return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
+  }
+}
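+// Device memory layout: the memory pool grows upwards from device_mem_base_
+// while graph dynamic memory is carved downwards from the top
+// (device_mem_size_), i.e.
+//   [base -> device_mem_pool_offset ...... dynamic_mem_offset <- top]
+// MallocDynamicMem fails as soon as the two offsets would cross.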
+uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
+  size_t align_size = 0;
+  if (communication_mem) {
+    align_size = GetCommunicationAlignSize(size);
+  } else {
+    align_size = GetCommonAlignSize(size);
+  }
+  if (dynamic_mem_offset_ < align_size) {
+    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+                      << "]) malloc [" << align_size << "] failed!";
+  }
+  auto new_offset = dynamic_mem_offset_ - align_size;
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  if (new_offset <= device_mem_pool_offset) {
+    MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+                      << "] memory pool[" << device_mem_pool_offset << "])"
+                      << " malloc [" << align_size << "] failed!";
+  }
+  total_dynamic_size_ += align_size;
+  dynamic_mem_offset_ = new_offset;
+  AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
+  if (communication_mem) {
+    // create protect area [kMemAlignSize -- data -- kMemAlignSize]
+    return device_mem_base_ + new_offset + kMemAlignSize;
+  } else {
+    return device_mem_base_ + new_offset;
+  }
+}
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h
new file mode 100644
index 0000000000..720f15be00
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_
+#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_
+#include "runtime/device/memory_manager.h"
+namespace mindspore {
+namespace device {
+namespace ascend {
+class AscendMemoryManager : public MemoryManager {
+ public:
+  AscendMemoryManager() = default;
+  ~AscendMemoryManager() override = default;
+
+  void MallocDeviceMemory() override;
+  void FreeDeviceMemory() override;
+  void ResetDynamicMemory() override;
+  void *MallocMemFromMemPool(size_t size) override;
+
+ protected:
+  uint8_t *MallocStaticMem(size_t size, bool communication_mem) override;
+  uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override;
+
+ private:
+  uint8_t *device_mem_pool_base_{nullptr};
+  uint64_t device_mem_pool_size_{0};
+
+  uint64_t GetDeviceMemSizeFromContext();
+};
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc
new file mode 100644
index 0000000000..fe71ba43fc
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ascend_memory_pool.h"
+#include "runtime/device/ascend/ascend_kernel_runtime.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
+  if (size == 0) {
+    MS_LOG(EXCEPTION) << "Cannot alloc memory size(0) in memory pool!";
+  }
+  if (device_mem_pool_offset_ + size >= graph_dynamic_mem_offset_) {
+    MS_LOG(EXCEPTION) << "Failed to alloc memory pool memory, the current device_mem_pool_offset_ ["
+                      << device_mem_pool_offset_ << "], current graph_dynamic_mem_offset_ ["
+                      << graph_dynamic_mem_offset_ << "], need memory size [" << size << "]";
+  }
+  *addr = device_mem_pool_base_ + device_mem_pool_offset_;
+  device_mem_pool_offset_ += size;
+  if (*addr == nullptr) {
+    MS_LOG(EXCEPTION) << "Alloc device address is nullptr, failed to alloc memory pool memory!";
+  }
+  return size;
+}
+
+bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
+  MS_EXCEPTION_IF_NULL(addr);
+  return true;
+}
+
+size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
+  if (size == 0) {
+    MS_LOG(EXCEPTION) << "The aligned memory size is zero!";
+  }
+  return size;
+}
+
+void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
+  MS_EXCEPTION_IF_NULL(device_mem_pool_base);
+  device_mem_pool_base_ = device_mem_pool_base;
+}
+
+void AscendMemoryPool::set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset) {
+  graph_dynamic_mem_offset_ = graph_dynamic_mem_offset;
+}
+
+uint64_t AscendMemoryPool::device_mem_pool_offset() const { return device_mem_pool_offset_; }
+
+size_t AscendMemoryPool::free_mem_size() {
+  if (graph_dynamic_mem_offset_ < device_mem_pool_offset_) {
+    MS_LOG(EXCEPTION) << "graph dynamic mem offset [" << graph_dynamic_mem_offset_
+                      << "] less than device mem pool offset [" << device_mem_pool_offset_ << "]!";
+  }
+  return graph_dynamic_mem_offset_ - device_mem_pool_offset_;
+}
+
+size_t AscendMemoryPool::total_mem_size() { return graph_dynamic_mem_offset_ == 0 ? 0 : graph_dynamic_mem_offset_ - 1; }
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h
new file mode 100644
index 0000000000..7a75198ab4
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.h
@@ -0,0 +1,60 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
+#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
+
+#include <memory>
+#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+class AscendMemoryPool : public DynamicMemPoolBestFit {
+ public:
+  ~AscendMemoryPool() override = default;
+  AscendMemoryPool(const AscendMemoryPool &) = delete;
+  AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
+
+  size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
+  bool FreeDeviceMem(const DeviceMemPtr &addr) override;
+  void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
+  void set_graph_dynamic_mem_offset(uint64_t graph_dynamic_mem_offset);
+
+  uint64_t device_mem_pool_offset() const;
+  size_t free_mem_size() override;
+  size_t total_mem_size() override;
+
+  static AscendMemoryPool &GetInstance() {
+    static AscendMemoryPool instance;
+    return instance;
+  }
+
+ protected:
+  // The real size of the memory allocation after alignment.
+  size_t AlignMemorySize(size_t size) const override;
+
+ private:
+  AscendMemoryPool() = default;
+  uint8_t *device_mem_pool_base_{nullptr};
+  uint64_t device_mem_pool_offset_{0};
+  uint64_t graph_dynamic_mem_offset_{0};
+};
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_MEMORY_POOL_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
new file mode 100644
index 0000000000..7cf5b94d45
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
@@ -0,0 +1,1268 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/ascend_stream_assign.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "ir/manager.h"
+#include "utils/context/ms_context.h"
+#include "common/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_adjust.h"
+#include "predict/generator/utils/ir_model_util.h"
+#include "backend/optimizer/common/helper.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+const uint32_t kHcomMaxTask = 5;
+const uint32_t kCommonMaxTask = 350;
+
+void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
+  if (IsTaskSink()) {
+    Reset();
+    ReorderIndependentOrders(graph_ptr);
+    AssignAllNodesStream(graph_ptr);
+    UpdateAtomicAddrCleanStreamId(graph_ptr);
+    InsertStreamActive(graph_ptr);
+    InsertEventForHcomParallel(graph_ptr);
+    InsertEventForIndependentParallel(graph_ptr);
+    GetNeedActiveStreams(graph_ptr);
+    graph_ptr->PrintGraphExecuteOrder();
+    CheckResourceAssign(graph_ptr);
+    MS_LOG(INFO) << "After finish stream assign";
+
+    FindStreamRelations(graph_ptr);
+    PrintStreamRelations();
+    GetStreamRelations();
+    PrintStreamGroups();
+    FindEventRelations(graph_ptr);
+
+    // Get info for D Model
+    AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
+    generator::IRModelUtil::GetInstance().set_event_num(resource_manager.get_cur_event_num());
+    generator::IRModelUtil::GetInstance().set_stream_num(resource_manager.get_cur_stream_num());
+    // Init to 1, temporarily
+    generator::IRModelUtil::GetInstance().set_batch_num(1);
+  }
+}
+
+// section 1
+void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
+  std::vector<CNodePtr> exe_orders;
+  std::vector<CNodePtr> independents;
+  std::vector<CNodePtr> others;
+
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size();
+  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+    auto cur_cnode_ptr = cnode_ptr_list[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    if (IsIndependentNode(cur_cnode_ptr)) {
+      independents.emplace_back(cur_cnode_ptr);
+    } else {
+      others.emplace_back(cur_cnode_ptr);
+    }
+  }
+
+  if (others.empty() || independents.empty()) {
+    MS_LOG(INFO) << "Independent or others is empty, no need reorder";
+    return;
+  }
+
+  std::set<CNode *> processed;
+  for (size_t i = 0; i < others.size(); i++) {
+    auto begin = others.begin() + i;
+    auto end = begin + 1;
+    bool flag = false;
+    for (size_t j = 0; j < independents.size(); j++) {
+      auto cur_independent = independents[j];
+      auto it = std::find(processed.begin(), processed.end(), cur_independent.get());
+      if (it != processed.end()) {
+        continue;
+      }
+
+      auto res = FindTargetOp(begin, end, cur_independent);
+      if (res != end) {
+        flag = true;
+        exe_orders.emplace_back(cur_independent);
+        exe_orders.emplace_back(*begin);
+        processed.emplace(cur_independent.get());
+        break;
+      }
+    }
+
+    if (!flag) {
+      exe_orders.emplace_back(*begin);
+    }
+  }
+
+  MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size();
+  if (processed.size() != independents.size()) {
+    MS_LOG(WARNING) << "Processed independent nodes size is not equal to existing independent nodes size";
+    return;
+  }
+
+  graph_ptr->set_execution_order(exe_orders);
+}
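+// Nodes are spread over three kinds of streams: common AICORE streams, hcom
+// (HCCL communication) streams and independent (AICPU) streams. A stream is
+// split once it exceeds its task budget (kCommonMaxTask = 350 common tasks,
+// kHcomMaxTask = 5 hcom tasks), so e.g. 700 common nodes would end up spread
+// over two common streams.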
+// section 2
+void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr) {
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  bool exit_independent = false;
+  bool exit_hcom = false;
+  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
+  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+    CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    // node has been assigned stream before
+    if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
+      continue;
+    }
+
+    if (IsHcom(cur_cnode_ptr)) {
+      exit_hcom = true;
+      continue;
+    }
+
+    if (IsIndependentNode(cur_cnode_ptr)) {
+      exit_independent = true;
+      continue;
+    }
+
+    AssignCommonStreamId(cur_cnode_ptr);
+  }
+  MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num();
+
+  if (exit_hcom) {
+    uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream();
+    for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+      CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
+      // node has been assigned stream before
+      if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
+        continue;
+      }
+
+      if (IsHcom(cur_cnode_ptr)) {
+        AssignHcomStreamId(cur_cnode_ptr);
+      }
+    }
+    MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size();
+  }
+
+  if (exit_independent) {
+    uint32_t first_independ = resource_manager.ApplyNewStream();
+    for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+      CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
+      if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
+        continue;
+      }
+      if (IsIndependentNode(cur_cnode_ptr)) {
+        AssignIndependentStreamId(cur_cnode_ptr);
+      }
+    }
+    MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size();
+  }
+
+  MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
+}
+
+void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
+  MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
+  uint32_t cur_common_stream_id = 0;
+  uint32_t cur_stream_num = resource_manager.get_cur_stream_num();
+  if (cur_stream_num == 0) {
+    cur_common_stream_id = resource_manager.ApplyNewStream();
+  } else {
+    cur_common_stream_id = resource_manager.GetCurAllocStreamId();
+  }
+
+  auto it = common_stream_map_.find(cur_common_stream_id);
+  if (it == common_stream_map_.end()) {
+    AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
+    common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
+  } else {
+    if (it->second < kCommonMaxTask) {
+      AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
+      it->second++;
+    } else {
+      cur_common_stream_id = resource_manager.ApplyNewStream();
+      AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
+      common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
+    }
+  }
+}
+
+void AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr) {
+  MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
+  uint32_t cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
+  auto it = hcom_stream_map_.find(cur_hcom_stream_id);
+  if (it == hcom_stream_map_.end()) {
+    AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
+    hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
+  } else {
+    if (it->second < kHcomMaxTask) {
+      AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
+      it->second++;
+    } else {
+      cur_hcom_stream_id = resource_manager.ApplyNewStream();
+      AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
+      hcom_stream_map_.insert(std::make_pair(cur_hcom_stream_id, 1));
+    }
+  }
+}
+void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr) {
+  MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
+  uint32_t cur_independent_id = resource_manager.GetCurAllocStreamId();
+  auto it = independent_stream_map_.find(cur_independent_id);
+  if (it == independent_stream_map_.end()) {
+    AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
+    independent_stream_map_.insert(std::make_pair(cur_independent_id, 1));
+  } else {
+    if (it->second < kCommonMaxTask) {
+      AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
+      it->second++;
+    } else {
+      cur_independent_id = resource_manager.ApplyNewStream();
+      AnfAlgo::SetStreamId(cur_independent_id, cur_cnode_ptr.get());
+      independent_stream_map_.insert(std::make_pair(cur_independent_id, 1));
+    }
+  }
+}
+
+bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
+  MS_EXCEPTION_IF_NULL(node_ptr);
+  if (AnfAlgo::GetKernelType(node_ptr) != AICPU_KERNEL) {
+    return false;
+  }
+
+  if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) {
+    MS_LOG(INFO) << "GetNext should not be independent node";
+    return false;
+  }
+
+  uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr);
+  if (input_nums == 0) {
+    MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero";
+    return true;
+  }
+
+  auto inputs = node_ptr->inputs();
+  for (size_t i = 1; i < inputs.size(); i++) {
+    if (!inputs[i]->isa<ValueNode>()) {
+      return false;
+    }
+  }
+  MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as its inputs are all value nodes";
+  return true;
+}
+
+// section 3
+void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
+  MS_LOG(INFO) << "Start";
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+    CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    // update AtomicAddrClean stream to be the same as the next node
+    if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
+      AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
+    }
+  }
+  MS_LOG(INFO) << "End";
+}
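+// A newly split stream is not activated by the scheduler on its own: a
+// StreamActive op has to be inserted on the previous stream, carrying the new
+// stream id in kAttrActiveStreamList, before any task on the new stream runs.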
+// section 4
+void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
+  MS_LOG(INFO) << "Start";
+  GetProcessedStream(graph_ptr);
+  std::vector<CNodePtr> update_cnode_list;
+  CNodePtr cur_cnode_ptr = nullptr;
+  CNodePtr pre_cnode_ptr = nullptr;
+  uint32_t pre_stream_id = UINT32_MAX;
+
+  bool independent_flag = !(independent_stream_map_.empty());
+  bool hcom_flag = !(hcom_stream_map_.empty());
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+    cur_cnode_ptr = cnode_ptr_list[i];
+    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
+    if (IsIndependentNode(cur_cnode_ptr)) {
+      update_cnode_list.emplace_back(cur_cnode_ptr);
+      continue;
+    }
+
+    if (IsHcom(cur_cnode_ptr)) {
+      update_cnode_list.emplace_back(cur_cnode_ptr);
+      continue;
+    }
+    uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
+    bool processed = IsProcessedStream(cur_stream_id);
+    // 1) inner stream assign, need to insert active op
+    if (!processed) {
+      MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id;
+      CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
+      // 1.set stream id
+      AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get());
+      // 2.set active stream ids
+      std::vector<uint32_t> active_index_list{cur_stream_id};
+      AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
+      update_cnode_list.emplace_back(active_ptr);
+    }
+
+    if ((independent_flag || hcom_flag) && (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName)) {
+      MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
+      UpdateStreamSwitch(graph_ptr, cur_cnode_ptr, &update_cnode_list);
+    } else {
+      update_cnode_list.emplace_back(cur_cnode_ptr);
+    }
+
+    processed_streams_.emplace(cur_stream_id);
+    pre_stream_id = cur_stream_id;
+    pre_cnode_ptr = cur_cnode_ptr;
+  }
+  graph_ptr->set_execution_order(update_cnode_list);
+  MS_LOG(INFO) << "End";
+}
+
+void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr) {
+  // 0 stream is activated at first
+  processed_streams_.emplace(0);
+  auto cnode_ptr_list = graph_ptr->execution_order();
+  for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
+    auto cur_cnode_ptr = cnode_ptr_list[i];
+    uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
+
+    if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
+      auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
+      processed_streams_.emplace(true_stream_id);
+
+      if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
+        continue;
+      }
+      auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
+      if (need_active) {
+        processed_streams_.emplace(cur_stream_id);
+      }
+    }
+  }
+  for (const auto &item : processed_streams_) {
+    MS_LOG(INFO) << "Before active:" << item << " has been processed";
+  }
+}
+
+void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
+                                            vector<CNodePtr> *orders) {
+  orders->emplace_back(switch_ptr);
+  if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
+    return;
+  }
+
+  auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
+  if (!need_active) {
+    return;
+  }
+
+  MS_EXCEPTION_IF_NULL(switch_ptr);
+  auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrTrueBranchStream);
+  MS_LOG(INFO) << "StreamSwitch stream id:" << AnfAlgo::GetStreamId(switch_ptr)
+               << "; active stream id:" << true_stream_id;
+
+  CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
+  AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
+  vector<uint32_t> active_ids;
+  // activate independent streams
+  for (const auto &item : independent_stream_map_) {
+    active_ids.emplace_back(item.first);
+  }
+  // activate hcom streams
+  for (const auto &item : hcom_stream_map_) {
+    active_ids.emplace_back(item.first);
+  }
+  AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
+
+  // update processed stream
+  independent_stream_activated_ = true;
+  for (const auto &item : independent_stream_map_) {
+    processed_streams_.emplace(item.first);
+  }
+
+  hcom_stream_activated_ = true;
+  for (const auto &item : hcom_stream_map_) {
+    processed_streams_.emplace(item.first);
+  }
+
+  orders->emplace_back(active_ptr);
+}
+
+bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
+  auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id);
+  if (it != processed_streams_.end()) {
+    return true;
+  }
+  return false;
+}
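+// Hcom kernels run asynchronously on their own streams, so explicit send/recv
+// event pairs are inserted to order them against the compute stream. The
+// common->hcom, hcom->common and hcom->hcom dependencies each get their own
+// pass below.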
<< "End"; +} + +void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes = cnode_ptr_list; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto it = cnodes.begin(); + while (it != cnodes.end() && (it + 1) != cnodes.end()) { + MS_EXCEPTION_IF_NULL(*it); + MS_EXCEPTION_IF_NULL(*(it + 1)); + if (IsHcom(*it) && !IsHcom(*(it + 1))) { + CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); + it = cnodes.insert(it + 1, send_cnode_ptr); + + auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); + if (target == cnodes.end()) { + MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope() + << ", can't find target for insert recv op, no insert send/recv"; + it = cnodes.erase(it); + continue; + } + + if (IsHcom(*target)) { + it = cnodes.erase(it); + continue; + } + + // deal recv op + uint32_t stream_id = AnfAlgo::GetStreamId(*target); + CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); + (void)cnodes.insert(target, recv_cnode_ptr); + cur_event_id = resource_manager.ApplyNewEvent(); + } + ++it; + } + // one event allocated additional, should delete + resource_manager.DeleteEvent(); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes; + CNodePtr cur_cnode_ptr = nullptr; + uint32_t pre_stream_id = UINT32_MAX; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (i == 0) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (!IsHcom(cur_cnode_ptr)) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (cur_stream_id == pre_stream_id) { + cnodes.emplace_back(cur_cnode_ptr); + pre_stream_id = cur_stream_id; + continue; + } + + if (!IsHcom(cnode_ptr_list[i - 1])) { + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id); + cnodes.emplace_back(send); + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id); + cnodes.emplace_back(recv); + cnodes.emplace_back(cur_cnode_ptr); + } else { + cnodes.emplace_back(cur_cnode_ptr); + } + pre_stream_id = cur_stream_id; + } + + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + uint32_t first_hcom_stream = kInvalidStreamId; + uint32_t last_hcom_stream = kInvalidStreamId; + // key: stream id, value:hcom index + std::map> hcom_index; + for (size_t i = 0; i < cnode_ptr_list.size(); i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsHcom(cur_cnode)) { + continue; + } + uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto it = 
hcom_index.find(cur_stream_id); + if (it != hcom_index.end()) { + hcom_index[cur_stream_id].emplace_back(i); + } else { + hcom_index[cur_stream_id] = {i}; + } + + // record first hcom stream id + if (first_hcom_stream == kInvalidStreamId) { + first_hcom_stream = cur_stream_id; + } + + // record last hcom stream id + if (cur_stream_id != last_hcom_stream) { + last_hcom_stream = cur_stream_id; + } + } + + if (hcom_index.size() < 2) { + MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them"; + return; + } + InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); + MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); +} + +void AscendStreamAssign::InsertEventBetweenHcom(const NotNull &graph_ptr, + const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream) { + vector orders; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + size_t first_stream_last_index = hcom_index.at(first_hcom_stream).back(); + size_t last_stream_first_index = hcom_index.at(last_hcom_stream).front(); + std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders)); + for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) { + auto cur_cnode = cnode_ptr_list[i]; + if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) { + orders.emplace_back(cur_cnode); + continue; + } + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (i == first_stream_last_index) { + // first fusion hcom + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else if (i == last_stream_first_index) { + // last fusion hcom + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + orders.emplace_back(cur_cnode); + } else { + auto cur_stream_hcom_size = hcom_index.at(cur_hcom_stream_id).size(); + if (cur_stream_hcom_size == 1) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, first hcom:add recv op + if (i == hcom_index.at(cur_hcom_stream_id).front()) { + auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(recv); + cur_event_id = resource_manager.ApplyNewEvent(); + orders.emplace_back(cur_cnode); + } else if (i == hcom_index.at(cur_hcom_stream_id).back()) { + // current stream, last hcom:add send op + orders.emplace_back(cur_cnode); + auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); + orders.emplace_back(send); + } else { + // current stream, not first and last op + orders.emplace_back(cur_cnode); + } + } + } + } + std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); + graph_ptr->set_execution_order(orders); +} + +bool AscendStreamAssign::IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, + size_t index) { + MS_EXCEPTION_IF_NULL(node_ptr); + auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr); + auto it = 
hcom_index.find(cur_hcom_stream_id); + if (it == hcom_index.end()) { + return false; + } + return std::find(hcom_index.at(cur_hcom_stream_id).begin(), hcom_index.at(cur_hcom_stream_id).end(), index) != + hcom_index.at(cur_hcom_stream_id).end(); +} + +// section6 +void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull &graph_ptr) { + MS_LOG(INFO) << "Start"; + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto cnode_ptr_list = graph_ptr->execution_order(); + vector cnodes = cnode_ptr_list; + uint32_t cur_event_id = resource_manager.ApplyNewEvent(); + auto it = cnodes.begin(); + while (it != cnodes.end()) { + MS_EXCEPTION_IF_NULL(*it); + if (IsIndependentNode(*it)) { + MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]"; + CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); + it = cnodes.insert(it + 1, send_cnode_ptr); + + auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); + if (target == cnodes.end()) { + MS_LOG(DEBUG) << "Independent node[" << (*(it - 1))->fullname_with_scope() + << "] cannot find a target op to insert a recv for; skip inserting send/recv"; + it = cnodes.erase(it); + continue; + } + + // handle the recv op + uint32_t stream_id = AnfAlgo::GetStreamId(*target); + CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id); + (void)cnodes.insert(target, recv_cnode_ptr); + cur_event_id = resource_manager.ApplyNewEvent(); + } + ++it; + } + // one extra event was allocated above; release it + resource_manager.DeleteEvent(); + graph_ptr->set_execution_order(cnodes); + MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num(); + MS_LOG(INFO) << "End"; +} + +// section7 +void AscendStreamAssign::GetNeedActiveStreams(const NotNull &graph_ptr) { + CNodePtr cur_cnode_ptr = nullptr; + auto cnode_ptr_list = graph_ptr->execution_order(); + // 1) stream 0 should be activated first; + need_first_active_streams_.emplace_back(0); + + // 2) streams with the kStreamNeedActivedFirst attr should be activated; + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) { + continue; + } + + auto need_active = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kStreamNeedActivedFirst); + if (need_active) { + auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_LOG(INFO) << "Stream id:" << stream_id << " needs to be activated first"; + need_first_active_streams_.push_back(stream_id); + } + } + + // 3) independent streams: if not yet activated, add them to the need-active vector + if (!independent_stream_activated_) { + for (auto &item : independent_stream_map_) { + need_first_active_streams_.emplace_back(item.first); + } + } + + // 4) hcom streams: if not yet activated, add them to the need-active vector + if (!hcom_stream_activated_) { + for (auto &item : hcom_stream_map_) { + need_first_active_streams_.emplace_back(item.first); + } + } +} + +// section8 +void AscendStreamAssign::CheckResourceAssign(const NotNull &graph_ptr) { + CheckStreamAssign(graph_ptr); + CheckEventAssign(graph_ptr); +} + +void AscendStreamAssign::CheckStreamAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::set streams; + uint32_t max_stream = 0; + uint32_t min_stream = kInvalidStreamId; + auto cnode_ptr_list = graph_ptr->execution_order();
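+ // The loop below collects every assigned stream id and tracks the min/max; the checks that + // follow require the ids to form a contiguous range [0, stream_num) matching the resource + // manager's allocation count. For example, ids {0, 1, 3} with three allocated streams would fail.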
+ for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + if (stream_id == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " has not been assigned a stream"; + } + + (void)streams.emplace(stream_id); + if (stream_id > max_stream) { + max_stream = stream_id; + } + if (stream_id < min_stream) { + min_stream = stream_id; + } + } + + // check stream assign + if (!streams.empty()) { + if (min_stream != 0) { + MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream; + } + uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); + if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { + MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream + << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); + } + } +} + +void AscendStreamAssign::CheckEventAssign(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + std::map> event_map; + uint32_t max_event_id = 0; + uint32_t min_event_id = kInvalidEventId; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr); + if (name == kSendOpName || name == kRecvOpName) { + uint32_t event_id = AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId); + if (event_id > max_event_id) { + max_event_id = event_id; + } + + if (event_id < min_event_id) { + min_event_id = event_id; + } + auto it = event_map.find(event_id); + if (it == event_map.end()) { + event_map[event_id] = {cur_cnode_ptr}; + } else { + event_map[event_id].emplace_back(cur_cnode_ptr); + } + } + } + // check event assign + if (!event_map.empty()) { + if (min_event_id != 0) { + MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id; + } + uint32_t assigned_event_num = resource_manager.get_cur_event_num(); + if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { + MS_LOG(EXCEPTION) << "Event should be consecutive"; + } + for (const auto &item : event_map) { + if (item.second.size() != 2) { + MS_LOG(EXCEPTION) << "Send/recv should come in pairs and share one event id"; + } + auto first_name = AnfAlgo::GetCNodeName(item.second[0]); + auto second_name = AnfAlgo::GetCNodeName(item.second[1]); + if (!(first_name == kSendOpName && second_name == kRecvOpName)) { + MS_LOG(EXCEPTION) << "Send should be before recv"; + } + } + } +} + +// section9 +CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); +
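+ // Send produces no data output, so it carries a None abstract; only the event id attribute + // and the stream id set below matter to the runtime.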
MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + AnfAlgo::SetStreamId(stream_id, send_node_ptr.get()); + return send_node_ptr; +} + +CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, + uint32_t stream_id) { + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + +vector::iterator AscendStreamAssign::FindTargetOp(vector::iterator begin, + vector::iterator end, const CNodePtr &node) { + while (begin != end) { + auto inputs = (*begin)->inputs(); + for (size_t i = 1; i < inputs.size(); i++) { + auto input = inputs[i]; + if (opt::IsNopNode(input)) { + CNodePtr cnode = input->cast(); + auto new_inputs = cnode->inputs(); + for (size_t j = 1; j < new_inputs.size(); j++) { + auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0); + if (node == new_real_input.first) { + MS_LOG(INFO) << "Nop node found target op[" << (*begin)->DebugString() << "]"; + return begin; + } + } + } else { + auto real_input = AnfAlgo::VisitKernel(input, 0); + if (node == real_input.first) { + MS_LOG(INFO) << "Found target op[" << (*begin)->DebugString() << "]"; + return begin; + } + } + } + ++begin; + } + return end; +} + +bool AscendStreamAssign::IsTaskSink() { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (!ms_context->enable_task_sink()) { + MS_LOG(INFO) << "Task sink mode is not enabled"; + return false; + } else { + MS_LOG(INFO) << "Task sink mode is enabled"; + return true; + } +} + +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { + MS_EXCEPTION_IF_NULL(wait_active_stream_list); + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + uint32_t total_stream_num = resource_manager.get_cur_stream_num(); + if (total_stream_num == 0) { + MS_LOG(INFO) << "The total_stream_num is zero"; + return; + } + + // streams not in the first-active list must wait to be activated + for (uint32_t i = 0; i < total_stream_num; i++) { + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); + if (it == need_first_active_streams_.end()) { + MS_LOG(INFO) << "Wait common stream id = " << i; + wait_active_stream_list->push_back(i); + } + } +} + +bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) { + MS_EXCEPTION_IF_NULL(apply_kernel); + return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL; +} + +void AscendStreamAssign::GetHcomStreams(std::vector *streams) { + MS_EXCEPTION_IF_NULL(streams); + for (const auto &item : hcom_stream_map_) { + streams->emplace_back(item.first); + } +} + +void AscendStreamAssign::Reset() { + independent_stream_activated_ = false; + hcom_stream_activated_ = false; + independent_stream_map_.clear(); + hcom_stream_map_.clear(); +
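+ // The remaining containers hold per-graph analysis state (stream relations, stream groups and + // send/recv event pairs used for the memory-reuse analysis); they are rebuilt for each graph.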
common_stream_map_.clear(); + processed_streams_.clear(); + need_first_active_streams_.clear(); + stream_groups_.clear(); + stream_relations_.clear(); + event_map_.clear(); +} + +// section 10 +bool AscendStreamAssign::IsVecExist(std::vector *group) { + auto group_size = group->size(); + if (group_size == 0) { + return false; + } + for (const auto &item : stream_groups_) { + if (item.size() < group->size()) { + continue; + } + + bool flag = true; + for (size_t i = 0; i < group_size; i++) { + if (item[i] != group->at(i)) { + flag = false; + break; + } + } + + if (flag) { + return true; + } else { + continue; + } + } + + return false; +} + +void AscendStreamAssign::DFS(uint32_t start, std::vector *group) { + auto it = stream_relations_.find(start); + if (it == stream_relations_.end()) { + if (!IsVecExist(group)) { + stream_groups_.emplace_back(*group); + } else { + MS_LOG(WARNING) << "DFS should not print this log"; + } + return; + } + + vector active_streams = stream_relations_[start]; + + for (const auto &item : active_streams) { + group->emplace_back(item); + DFS(item, group); + group->pop_back(); + } +} + +void AscendStreamAssign::GetStreamRelations() { + for (const auto &start : need_first_active_streams_) { + vector group{start}; + DFS(start, &group); + } +} + +void AscendStreamAssign::FindStreamRelations(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto stream_num = resource_manager.get_cur_stream_num(); + if (stream_num <= 1) { + return; + } + + auto exe_orders = graph_ptr->execution_order(); + for (size_t i = 0; i < exe_orders.size(); i++) { + auto cur_cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cur_cnode); + if (name != kStreamSwitchOpName && name != kStreamActiveOpName) { + continue; + } + + // support:streamswitch is begin of the stream + if (name == kStreamSwitchOpName) { + GetStreamSwitchStreamRelation(cur_cnode); + } + + if (name == kStreamActiveOpName) { + GetStreamActiveStreamRelation(graph_ptr, i); + } + } +} + +void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) { + MS_EXCEPTION_IF_NULL(node_ptr); + auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr); + auto true_stream_id = AnfAlgo::GetNodeAttr(node_ptr, kAttrTrueBranchStream); + if (true_stream_id <= cur_stream_id) { + MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id + << " is greater than true branch stream id:" << true_stream_id; + } + auto it = stream_relations_.find(cur_stream_id); + if (it == stream_relations_.end()) { + stream_relations_[cur_stream_id] = {true_stream_id}; + } else { + auto iter = + std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id); + if (iter == stream_relations_[cur_stream_id].end()) { + stream_relations_[cur_stream_id].emplace_back(true_stream_id); + } + } +} + +void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index) { + StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index); + if (kind == kInvalid) { + MS_LOG(INFO) << "Invalid streamActive kind"; + return; + } + + auto orders = graph_ptr->execution_order(); + auto cur_cnode = orders[index]; + auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + auto active_list = AnfAlgo::GetNodeAttr>(cur_cnode, kAttrActiveStreamList); + if (kind == kHead) { + uint32_t active_current_node = GetStreamByActivedStream(cur_stream_id); + if (active_current_node == kInvalidStreamId) { + MS_LOG(EXCEPTION) << "No stream to active streamactive 
stream"; + } + + for (const auto &item : active_list) { + if (item <= active_current_node) { + MS_LOG(WARNING) << "Actived stream is less than activing stream"; + continue; + } + auto it = + std::find(stream_relations_[active_current_node].begin(), stream_relations_[active_current_node].end(), item); + if (it == stream_relations_[active_current_node].end()) { + stream_relations_[active_current_node].emplace_back(item); + } + } + } + + if (kind == kMiddle) { + for (const auto &stream : active_list) { + if (stream <= cur_stream_id) { + MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal"; + } else { + MS_LOG(ERROR) << "MIDDLE StreamActive active stream is greater than self stream, should not be exit now"; + } + } + } + + if (kind == kTail) { + auto it = stream_relations_.find(cur_stream_id); + if (it == stream_relations_.end()) { + stream_relations_[cur_stream_id] = active_list; + } else { + for (const auto &stream : active_list) { + if (stream <= cur_stream_id) { + MS_LOG(WARNING) << "Actived stream is less than activing stream"; + continue; + } + auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream); + if (iter == stream_relations_[cur_stream_id].end()) { + stream_relations_[cur_stream_id].emplace_back(stream); + } + } + } + } +} + +StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull &graph_ptr, size_t index) { + auto exe_orders = graph_ptr->execution_order(); + if (index >= exe_orders.size()) { + MS_LOG(EXCEPTION) << "Invalid op index:" << index; + } + + auto cur_cnode = exe_orders[index]; + auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode); + if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) { + MS_LOG(EXCEPTION) << "Current node name is not StreamActive"; + } + + if (index == 0) { + return kInvalid; + } + + if (index == exe_orders.size() - 1) { + return kInvalid; + } + + uint32_t pre_stream_id = UINT32_MAX; + uint32_t next_stream_id = UINT32_MAX; + int32_t start = SizeToInt(index) - 1; + for (int32_t i = start; i >= 0; i--) { + auto cnode = exe_orders[IntToSize(i)]; + auto name = AnfAlgo::GetCNodeName(cnode); + if (name == kSendOpName || name == kRecvOpName) { + continue; + } + + pre_stream_id = AnfAlgo::GetStreamId(cnode); + break; + } + + for (size_t i = index + 1; i < exe_orders.size(); i++) { + auto cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cnode); + if (name == kSendOpName || name == kRecvOpName) { + continue; + } + + next_stream_id = AnfAlgo::GetStreamId(cnode); + break; + } + + // pre_stream_id = UINT32_MAX:means no node active current StreamActive + // next_stream_id = UINT32_MAX:means current StreamActive active no node + if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) { + return kInvalid; + } + + if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) { + return kMiddle; + } + + if (cur_stream_id == pre_stream_id) { + return kTail; + } + + if (cur_stream_id == next_stream_id) { + return kHead; + } + + return kInvalid; +} + +uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) { + if (stream_relations_.empty()) { + return kInvalidStreamId; + } + + for (const auto &item : stream_relations_) { + auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id); + if (it != item.second.end()) { + return item.first; + } + } + + return kInvalidStreamId; +} + +void AscendStreamAssign::PrintStreamRelations() { + MS_LOG(INFO) << "Stream relations size:" << 
stream_relations_.size(); + for (const auto &item : stream_relations_) { + MS_LOG(INFO) << "Stream:" << item.first; + for (const auto &stream : item.second) { + MS_LOG(INFO) << "--actived stream id:" << stream; + } + } +} + +void AscendStreamAssign::PrintStreamGroups() { + MS_LOG(INFO) << "Stream group size:" << stream_groups_.size(); + for (const auto &item : stream_groups_) { + MS_LOG(INFO) << "Group:"; + for (const auto &stream : item) { + MS_LOG(INFO) << "Stream id:" << stream; + } + } +} + +// section 11 +bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const { + size_t send_group = 0; + size_t recv_group = 0; + bool send_flag = true; + bool recv_flag = true; + for (size_t i = 0; i < stream_groups_.size(); i++) { + auto group = stream_groups_[i]; + if (send_flag) { + auto it = std::find(group.begin(), group.end(), send_stream_id); + if (it != group.end()) { + send_group = i; + send_flag = false; + } + } + + if (recv_flag) { + auto it = std::find(group.begin(), group.end(), recv_stream_id); + if (it != group.end()) { + recv_group = i; + recv_flag = false; + } + } + } + + if (!(send_flag || recv_flag)) { + return (send_group != recv_group); + } + + return false; +} + +void AscendStreamAssign::FindEventRelations(const NotNull &graph_ptr) { + AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); + auto event_nums = resource_manager.get_cur_event_num(); + if (event_nums == 0) { + return; + } + auto exe_orders = graph_ptr->execution_order(); + // find all event info + for (size_t i = 0; i < exe_orders.size(); i++) { + auto cur_cnode = exe_orders[i]; + auto name = AnfAlgo::GetCNodeName(cur_cnode); + if (name == kSendOpName) { + event_map_[cur_cnode] = {}; + } + + if (name == kRecvOpName) { + auto recv_event_id = AnfAlgo::GetNodeAttr(cur_cnode, kAttrEventId); + for (auto &item : event_map_) { + auto send_event_id = AnfAlgo::GetNodeAttr(item.first, kAttrEventId); + if (recv_event_id == send_event_id) { + item.second = cur_cnode; + break; + } + } + } + } + + // delete useless event info + auto begin = event_map_.begin(); + while (begin != event_map_.end()) { + auto send_stream_id = AnfAlgo::GetStreamId(begin->first); + auto recv_stream_id = AnfAlgo::GetStreamId(begin->second); + bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id); + if (!flag) { + begin = event_map_.erase(begin); + } else { + begin++; + } + } + + MS_LOG(INFO) << "Satisfied event info"; + for (const auto &item : event_map_) { + MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr(item.first, kAttrEventId); + } +} + +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h new file mode 100644 index 0000000000..00fca60e8d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.h @@ -0,0 +1,185 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "runtime/base.h" +#include "runtime/rt_model.h" +#include "runtime/stream.h" +#include "backend/session/kernel_graph.h" +#include "utils/contract.h" + +namespace mindspore { +namespace device { +namespace ascend { +using std::map; +using std::shared_ptr; +using std::unordered_map; +using std::unordered_set; +using std::vector; +const uint32_t kInvalidStreamId = UINT32_MAX; +const uint32_t kInvalidEventId = UINT32_MAX; +class AscendResourceMng { + public: + static AscendResourceMng &GetInstance() { + static AscendResourceMng instance; + return instance; + } + + void ResetResource() { + cur_stream_num_ = 0; + cur_event_num_ = 0; + } + uint32_t ApplyNewStream() { + uint32_t cur_stream_id = cur_stream_num_; + cur_stream_num_++; + return cur_stream_id; + } + uint32_t ApplyNewEvent() { + uint32_t cur_event_id = cur_event_num_; + cur_event_num_++; + return cur_event_id; + } + + void DeleteEvent() { + if (!cur_event_num_) { + MS_LOG(WARNING) << "total event num is 0, no event to delete"; + } else { + --cur_event_num_; + } + } + uint32_t get_cur_stream_num() { return cur_stream_num_; } + uint32_t GetCurAllocStreamId() { + if (!cur_stream_num_) { + MS_LOG(EXCEPTION) << "stream num is 0, no stream id to get"; + } + return cur_stream_num_ - 1; + } + uint32_t get_cur_event_num() { return cur_event_num_; } + + private: + uint32_t cur_stream_num_{0}; + uint32_t cur_event_num_{0}; +}; + +enum StreamActiveKind { kInvalid = 0, kHead, kMiddle, kTail }; +class AscendStreamAssign { + public: + static AscendStreamAssign &GetInstance() { + static AscendStreamAssign instance; // Guaranteed to be destroyed.
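+ // A function-local static is initialized exactly once and is thread-safe since C++11, so this + // Meyers singleton needs no explicit locking. Illustrative call site (hypothetical; real callers + // live in the Ascend kernel runtime): + // AscendStreamAssign::GetInstance().AssignStream(NOT_NULL(kernel_graph));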
return instance; + } + + AscendStreamAssign(const AscendStreamAssign &) = delete; + AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; + + void AssignStream(const NotNull &graph_ptr); + void GetHcomStreams(std::vector *streams); + void GetWaitStreams(vector *wait_active_stream_list); + CNodePtr CreateSendApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); + CNodePtr CreateRecvApplyKernel(const NotNull &graph_ptr, uint32_t event_id, uint32_t stream_id); + const std::vector> &get_stream_group() const { return stream_groups_; } + const std::map &get_event_map() const { return event_map_; } + + private: + AscendStreamAssign() = default; + ~AscendStreamAssign() = default; + void Reset(); + void CheckResourceAssign(const NotNull &graph_ptr); + void CheckStreamAssign(const NotNull &graph_ptr); + void CheckEventAssign(const NotNull &graph_ptr); + void AssignAllNodesStream(const NotNull &graph_ptr); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr); + void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr); + void UpdateAtomicAddrCleanStreamId(const NotNull &graph_ptr); + void FindHcomParallelStreams(const NotNull &graph_ptr); + void InsertStreamActive(const NotNull &graph_ptr); + void UpdateStreamSwitch(const NotNull &graph_ptr, const CNodePtr &switch_ptr, + vector *orders); + void InsertEventForIndependentParallel(const NotNull &graph_ptr); + void InsertEventForHcomParallel(const NotNull &graph_ptr); + void InsertEventCommonDependHcom(const NotNull &graph_ptr); + void InsertEventHcomDependCommon(const NotNull &graph_ptr); + void InsertEventHcomDependHcom(const NotNull &graph_ptr); + void InsertEventBetweenHcom(const NotNull &graph_ptr, const map> &hcom_index, + uint32_t first_hcom_stream, uint32_t last_hcom_stream); + bool IsSatisfiedHcom(const std::map> &hcom_index, const CNodePtr &node_ptr, size_t index); + + void GetProcessedStream(const NotNull &graph_ptr); + void GetNeedActiveStreams(const NotNull &graph_ptr); + void ReorderIndependentOrders(const NotNull &graph_ptr); + + bool IsTaskSink(); + bool IsHcom(const CNodePtr &cur_cnode_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + bool IsProcessedStream(uint32_t stream_id); + vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, + const CNodePtr &node); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_active_id, std::vector *parallel_streams); + + // functions for memory reuse + void GetStreamRelations(); + void DFS(uint32_t start, std::vector *group); + bool IsVecExist(std::vector *group); + void FindStreamRelations(const NotNull &graph_ptr); + void GetStreamSwitchStreamRelation(const CNodePtr &node_ptr); + void GetStreamActiveStreamRelation(const NotNull &graph_ptr, size_t index); + StreamActiveKind GetStreamActiveKind(const NotNull &graph_ptr, size_t index); + uint32_t GetStreamByActivedStream(uint32_t actived_stream_id); + void PrintStreamRelations(); + void PrintStreamGroups(); + void FindEventRelations(const NotNull &graph_ptr); + bool IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const; + + bool independent_stream_activated_{false}; + bool hcom_stream_activated_{false}; + std::map independent_stream_map_{}; + std::map hcom_stream_map_{}; + std::map common_stream_map_{}; + std::set processed_streams_{}; + std::vector need_first_active_streams_{}; + + // attrs for memory copy reuse + std::map> stream_relations_{}; + std::vector> stream_groups_{}; + std::map
event_map_; + // new policy end +}; +} // namespace ascend +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_ASCEND_ASCEND_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc new file mode 100644 index 0000000000..ab2c6b2748 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -0,0 +1,282 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef ENABLE_DATA_DUMP +#include "runtime/device/ascend/dump/data_dumper.h" + +#include +#include +#include +#include "utility" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/mem.h" +#include "runtime/kernel.h" +#include "runtime/device/ascend/dump/ge_dump.h" +#include "proto/op_mapping_info.pb.h" +#include "utils/context/ms_context.h" +#include "debug/data_dump_parser.h" + +constexpr uint32_t kAicpuLoadFlag = 1; +constexpr uint32_t kAicpuUnloadFlag = 0; +constexpr uint32_t kTupleTaskId = 0; +constexpr uint32_t kTupleStreamId = 1; +constexpr uint32_t kTupleArgs = 2; +constexpr uint32_t kCurrentStepTensorIndex = 0; +constexpr uint32_t kCurrentEpochTensorIndex = 1; +constexpr uint32_t kStepsPerEpochTensorIndex = 2; + +namespace mindspore { +namespace device { +namespace ascend { +void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task); +void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task); +void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr); + +DataDumper::~DataDumper() { + ReleaseDevMem(&dev_load_mem_); + ReleaseDevMem(&dev_unload_mem_); +} + +void DataDumper::LoadDumpInfo() { + MS_LOG(INFO) << "[DataDump] LoadDumpInfo start"; + MS_EXCEPTION_IF_NULL(kernel_graph_); + aicpu::dump::OpMappingInfo dump_info; + SetOpMappingInfo(NOT_NULL(&dump_info)); + + auto kernels = kernel_graph_->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + if (!KernelNeedDump(kernel)) { + continue; + } + MS_LOG(INFO) << "[DataDump] LoadDumpInfo kernel:" << kernel->fullname_with_scope(); + dump_kernel_names_.emplace_back(kernel->fullname_with_scope()); + + aicpu::dump::Task task; + ConstructDumpTask(NOT_NULL(kernel), NOT_NULL(&task)); + MS_EXCEPTION_IF_NULL(dump_info.mutable_task()); + dump_info.mutable_task()->Add(std::move(task)); + } + RtLoadDumpData(dump_info, &dev_load_mem_); + load_flag_ = true; + MS_LOG(INFO) << "[DataDump] LoadDumpInfo end"; +} + +void DataDumper::SetOpMappingInfo(NotNull dump_info) const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(kernel_graph_); + auto dump_path = DataDumpParser::GetInstance().GetDumpPath(); + if (!dump_path.has_value()) { + MS_LOG(EXCEPTION) << "Dump path invalid"; + } + auto device_id = context_ptr->device_id(); + dump_info->set_dump_path(dump_path.value() + "_" + std::to_string(device_id) + "/"); + MS_LOG(INFO) << "[DataDump] 
dump_path:" << dump_path.value(); + + dump_info->set_model_name(DataDumpParser::GetInstance().net_name() + "_" + std::to_string(kernel_graph_->graph_id())); + dump_info->set_dump_step(std::to_string(DataDumpParser::GetInstance().dump_step())); + dump_info->set_model_id(kernel_graph_->graph_id()); + dump_info->set_flag(kAicpuLoadFlag); + + const auto &input_ctrl_tensors = kernel_graph_->input_ctrl_tensors(); + if (input_ctrl_tensors == nullptr || input_ctrl_tensors->size() < 3) { + MS_LOG(INFO) << "[DataDump] Not in data sink mode; input control tensors are absent"; + return; + } + const auto &current_step_tensor = input_ctrl_tensors->at(kCurrentStepTensorIndex); + const auto &current_epoch_tensor = input_ctrl_tensors->at(kCurrentEpochTensorIndex); + const auto &steps_per_epoch_tensor = input_ctrl_tensors->at(kStepsPerEpochTensorIndex); + + MS_EXCEPTION_IF_NULL(current_step_tensor); + MS_EXCEPTION_IF_NULL(current_epoch_tensor); + MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor); + MS_EXCEPTION_IF_NULL(current_step_tensor->device_address()); + MS_EXCEPTION_IF_NULL(current_epoch_tensor->device_address()); + MS_EXCEPTION_IF_NULL(steps_per_epoch_tensor->device_address()); + + void *current_step = current_step_tensor->device_address()->ptr_; + void *current_epoch = current_epoch_tensor->device_address()->ptr_; + void *steps_per_epoch = steps_per_epoch_tensor->device_address()->ptr_; + + if (current_epoch != nullptr && current_step != nullptr && steps_per_epoch != nullptr) { + dump_info->set_step_id_addr(reinterpret_cast(current_epoch)); + dump_info->set_loop_cond_addr(reinterpret_cast(current_step)); + dump_info->set_iterations_per_loop_addr(reinterpret_cast(steps_per_epoch)); + } else { + MS_LOG(INFO) << "Invalid ctrl tensor device address"; + } +} + +bool DataDumper::KernelNeedDump(const CNodePtr &kernel) const { + if (AnfAlgo::GetKernelType(kernel) != TBE_KERNEL && AnfAlgo::GetKernelType(kernel) != AICPU_KERNEL && + AnfAlgo::GetKernelType(kernel) != AKG_KERNEL) { + return false; + } + MS_EXCEPTION_IF_NULL(kernel); + // dump all kernels if mode is set to 0 in data_dump.json + return DataDumpParser::GetInstance().NeedDump(kernel->fullname_with_scope()); +} + +void DataDumper::UnloadDumpInfo() { + if (!load_flag_) { + MS_LOG(WARNING) << "Load was not successful; nothing to unload"; + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph_); + MS_LOG(INFO) << "[DataDump] UnloadDumpInfo start.
graphId:" << kernel_graph_->graph_id(); + + aicpu::dump::OpMappingInfo op_mapping_info; + op_mapping_info.set_model_id(kernel_graph_->graph_id()); + op_mapping_info.set_flag(kAicpuUnloadFlag); + + for (const auto &kernel_name : dump_kernel_names_) { + aicpu::dump::Task task; + auto iter = runtime_info_map_.find(kernel_name); + if (iter == runtime_info_map_.end()) { + MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; + } + MS_EXCEPTION_IF_NULL(iter->second); + auto task_id = std::get(*iter->second); + task.set_task_id(task_id); + MS_EXCEPTION_IF_NULL(op_mapping_info.mutable_task()); + op_mapping_info.mutable_task()->Add(std::move(task)); + } + + RtLoadDumpData(op_mapping_info, &dev_unload_mem_); +} + +void DataDumper::ReleaseDevMem(void **ptr) const { + if (ptr == nullptr) { + return; + } + if (*ptr != nullptr) { + rtError_t rt_error = rtFree(*ptr); + if (rt_error != RT_ERROR_NONE) { + MS_LOG(ERROR) << "[DataDump] Call rtFree failed, ret:" << rt_error; + } + *ptr = nullptr; + } +} + +void DataDumper::ConstructDumpTask(NotNull kernel, NotNull dump_task) const { + dump_task->set_end_graph(false); + auto iter = runtime_info_map_.find(kernel->fullname_with_scope()); + if (iter == runtime_info_map_.end()) { + MS_LOG(EXCEPTION) << "[DataDump] kernel name not found in runtime_info_map"; + } + MS_EXCEPTION_IF_NULL(iter->second); + auto task_id = std::get(*iter->second); + auto stream_id = std::get(*iter->second); + auto args = std::get(*iter->second); + MS_LOG(INFO) << "[DataDump] Get runtime info task_id:" << task_id << " stream_id:" << stream_id; + + dump_task->set_task_id(task_id); + dump_task->set_stream_id(stream_id); + MS_EXCEPTION_IF_NULL(dump_task->mutable_op()); + dump_task->mutable_op()->set_op_name(kernel->fullname_with_scope()); + dump_task->mutable_op()->set_op_type(AnfAlgo::GetCNodeName(kernel.get())); + + DumpKernelOutput(kernel, args, dump_task); + DumpKernelInput(kernel, args, dump_task); +} + +void RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, void **ptr) { + std::string proto_str; + size_t proto_size = dump_info.ByteSizeLong(); + bool ret = dump_info.SerializeToString(&proto_str); + if (!ret || proto_size == 0) { + MS_LOG(EXCEPTION) << "[DataDump] Protobuf SerializeToString failed, proto size %zu."; + } + + rtError_t rt_ret = rtMalloc(ptr, proto_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMalloc failed"; + } + + if (ptr == nullptr) { + MS_LOG(ERROR) << "[DataDump] rtMalloc failed, ptr is nullptr"; + return; + } + rt_ret = rtMemcpy(*ptr, proto_size, proto_str.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtMemcpy failed"; + } + + MS_LOG(INFO) << "[DataDump] rtDatadumpInfoLoad start"; + rt_ret = rtDatadumpInfoLoad(*ptr, proto_size); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "[DataDump] Call rtDatadumpInfoLoad failed"; + } +} + +void DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task) { + MS_LOG(INFO) << "[DataDump] DumpKernelOutput start. 
Kernel:" << kernel->fullname_with_scope(); + auto input_size = AnfAlgo::GetInputTensorNum(kernel); + auto output_size = AnfAlgo::GetOutputTensorNum(kernel); + uint64_t offset = sizeof(void *) * input_size; + for (size_t i = 0; i < output_size; ++i) { + auto data_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i); + + aicpu::dump::Output output; + output.set_data_type(GetGeDataType(data_type)); + output.set_format(GetGeFormat(output_format, output_shape.size())); + MS_EXCEPTION_IF_NULL(output.mutable_shape()); + for (auto dim : output_shape) { + output.mutable_shape()->add_dim(dim); + } + output.set_original_output_format(GetGeFormat(output_format, output_shape.size())); + output.set_address(static_cast(reinterpret_cast(args)) + offset); + MS_EXCEPTION_IF_NULL(task->mutable_output()); + task->mutable_output()->Add(std::move(output)); + offset += sizeof(void *); + } +} + +void DumpKernelInput(const CNodePtr &kernel, void *args, NotNull task) { + MS_LOG(INFO) << "[DataDump] DumpKernelInput start. Kernel:" << kernel->fullname_with_scope(); + auto input_size = AnfAlgo::GetInputTensorNum(kernel); + uint64_t offset = 0; + for (size_t i = 0; i < input_size; ++i) { + aicpu::dump::Input input; + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + auto input_node = input_node_with_index.first; + auto input_index = input_node_with_index.second; + std::string output_format = AnfAlgo::GetOutputFormat(input_node, input_index); + auto output_type = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); + if (output_type == kTypeUnknown) { + MS_LOG(WARNING) << "[DataDump] It is not suggested to use a lonely weight parameter as the output of graph"; + output_type = AnfAlgo::GetOutputInferDataType(input_node, input_index); + } + auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); + + input.set_data_type(GetGeDataType(output_type)); + input.set_format(GetGeFormat(output_format, output_shape.size())); + MS_EXCEPTION_IF_NULL(input.mutable_shape()); + for (auto dim : output_shape) { + input.mutable_shape()->add_dim(dim); + } + input.set_address(static_cast(reinterpret_cast(args)) + offset); + MS_EXCEPTION_IF_NULL(task->mutable_input()); + task->mutable_input()->Add(std::move(input)); + offset += sizeof(void *); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h new file mode 100644 index 0000000000..d99eb4db68 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ +#ifdef ENABLE_DATA_DUMP +#include +#include +#include +#include +#include +#include "backend/session/kernel_graph.h" + +namespace aicpu { +namespace dump { +class OpMappingInfo; +class Task; +} // namespace dump +} // namespace aicpu +namespace mindspore { +namespace device { +namespace ascend { +// tuple(op_name, task_id, stream_id, args) +using RuntimeInfo = std::tuple; +class DataDumper { + public: + DataDumper(const session::KernelGraph *kernel_graph, + const std::map> &runtime_info_map) + : load_flag_(false), + dev_load_mem_(nullptr), + dev_unload_mem_(nullptr), + kernel_graph_(kernel_graph), + runtime_info_map_(runtime_info_map) {} + ~DataDumper(); + void LoadDumpInfo(); + + void UnloadDumpInfo(); + + private: + void ReleaseDevMem(void **ptr) const; + bool KernelNeedDump(const CNodePtr &kernel) const; + void SetOpMappingInfo(NotNull dump_info) const; + void ConstructDumpTask(NotNull kernel, NotNull dump_task) const; + + bool load_flag_; + void *dev_load_mem_; + void *dev_unload_mem_; + std::vector dump_kernel_names_; + const session::KernelGraph *kernel_graph_; + std::map> runtime_info_map_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_DUMP_DATADUMP_H_ diff --git a/mindspore/ccsrc/device/ascend/dump/ge_dump.h b/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h similarity index 100% rename from mindspore/ccsrc/device/ascend/dump/ge_dump.h rename to mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h diff --git a/mindspore/ccsrc/device/ascend/dump/proto/ge_dtype.proto b/mindspore/ccsrc/runtime/device/ascend/dump/proto/ge_dtype.proto similarity index 100% rename from mindspore/ccsrc/device/ascend/dump/proto/ge_dtype.proto rename to mindspore/ccsrc/runtime/device/ascend/dump/proto/ge_dtype.proto diff --git a/mindspore/ccsrc/device/ascend/dump/proto/op_mapping_info.proto b/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto similarity index 100% rename from mindspore/ccsrc/device/ascend/dump/proto/op_mapping_info.proto rename to mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc new file mode 100644 index 0000000000..39cefcb020 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -0,0 +1,286 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/kernel_build_ascend.h" + +#include +#include +#include +#include + +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" +#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" +#include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include "backend/kernel_compiler/hccl/hccl_kernel_build.h" +#include "backend/kernel_compiler/rts/rt_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/common_utils.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "./common.h" + +namespace mindspore { +namespace device { +namespace ascend { +using mindspore::kernel::tbe::TbeUtils; +using std::make_shared; +static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { + kernel::KernelModPtr kernel_mod_ptr = nullptr; + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::AICPU_KERNEL: { + kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); + break; + } + case KernelType::RT_KERNEL: { + kernel_mod_ptr = kernel::RtOpBuild(anf_node); + break; + } + case KernelType::HCCL_KERNEL: { + kernel_mod_ptr = kernel::HcclOpBuild(anf_node); + break; + } + default: { + MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type; + } + } + return kernel_mod_ptr; +} + +static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + std::vector tbe_nodes; + for (const auto &anf_node : kernel_graph_ptr->execution_order()) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + continue; + } + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::TBE_KERNEL: { + if (AnfAlgo::GetKernelMod(anf_node) == nullptr && + AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) { + tbe_nodes.push_back(anf_node); + } + break; + } + default: { + break; + } + } + } + bool ret = kernel::TbeOpParallelPreBuild(tbe_nodes); + return ret; +} + +static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + std::vector tbe_nodes; + std::vector akg_nodes; + std::vector other_nodes; + for (const auto &anf_node : kernel_graph_ptr->execution_order()) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!AnfAlgo::IsRealKernel(anf_node)) { + continue; + } + KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); + switch (kernel_type) { + case KernelType::TBE_KERNEL: { + if (AnfAlgo::GetKernelMod(anf_node) == nullptr) { + tbe_nodes.push_back(anf_node); + } + break; + } + case KernelType::AKG_KERNEL: { + akg_nodes.push_back(anf_node); + break; + } + default: { + other_nodes.push_back(anf_node); + break; + } + } + } + bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes); + bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes); + auto bin_map = kernel::tbe::KernelMeta::GetInstance(); + (void)bin_map->ReadIndex(kernel::kCceKernelMeta); + for (const auto &anf_node : other_nodes) { + kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get()); 
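+ // AICPU/RT/HCCL kernels have no parallel build path, so they are compiled serially above; + // an unsupported kernel type throws inside SerialCompileImpl rather than flipping the return value.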
+ } + return tbe_ret && akg_ret; +} + +static std::vector CalCleanZerosSize(const CNodePtr &pre_node) { + MS_EXCEPTION_IF_NULL(pre_node); + auto kernel_mod = AnfAlgo::GetKernelMod(pre_node); + MS_EXCEPTION_IF_NULL(kernel_mod); + std::vector clean_size_list; + // clean output + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + auto output_mem_size = kernel_mod->GetOutputSizeList(); + for (auto index : output_indexs) { + auto clean_item = (output_mem_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; + clean_size_list.emplace_back(clean_item); + } + } + // clean workspace + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + auto workspace_mem_sizes = kernel_mod->GetWorkspaceSizeList(); + for (const auto &index : workspace_indexs) { + auto clean_item = (workspace_mem_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; + clean_size_list.emplace_back(clean_item); + } + } + MS_LOG(INFO) << "clean size list size:" << clean_size_list.size() << ", pre_node:" << pre_node->fullname_with_scope(); + return clean_size_list; +} + +static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph, + const mindspore::CNodePtr &pre_node, std::vector *new_nodes) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(pre_node); + MS_EXCEPTION_IF_NULL(new_nodes); + auto clear_zero_prim = std::make_shared(kAtomicAddrCleanOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.push_back(pre_node); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + clear_zero->set_abstract(abstract); + auto builder = std::make_shared(); + builder->SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); + auto clean_size = CalCleanZerosSize(pre_node); + AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get()); + new_nodes->push_back(clear_zero); +} + +static bool IsAtomicNode(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto parameters_indexs = kernel_mod->GenParameters(); + if (parameters_indexs.empty()) { + return false; + } + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size(); + size_t param_num = parameters_indexs.size(); + size_t total_num = input_num + workspace_num + output_num; + MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num; + size_t pad_index = param_num; + for (; pad_index < total_num; ++pad_index) { + parameters_indexs.emplace_back(0); + } + // process input + for (size_t j = 0; j < input_num; ++j) { + if (parameters_indexs.at(j) == 1) { + MS_LOG(EXCEPTION) << "Atomic addr clean doesn't support cleaning input addresses, input index: " << j; + } + } + // process output + std::vector output_indexs = {}; + for (size_t i =
0; i < output_num; ++i) { + auto param_output = parameters_indexs.at(input_num + workspace_num + i); + if (param_output == 1) { + output_indexs.emplace_back(i); + MS_LOG(INFO) << "Atomic clear output index: " << i; + } + } + if (!output_indexs.empty()) { + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); + } + // process workspace + std::vector workspace_indexs = {}; + for (size_t k = 0; k < workspace_num; ++k) { + auto param_workspace = parameters_indexs.at(input_num + k); + if (param_workspace == 1) { + workspace_indexs.emplace_back(k); + MS_LOG(INFO) << "Atomic clear workspace index: " << k; + } + } + if (!workspace_indexs.empty()) { + AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node); + } + return !(workspace_indexs.empty() && output_indexs.empty()); +} + +bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + bool ret = device::ascend::KernelPreBuildParallelCompile(kernel_graph_ptr); + return ret; +} + +bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + TbeUtils::LoadCache(); + bool ret; + ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr); + return ret; +} + +void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector new_nodes; + for (const auto &anf_node : kernel_graph->execution_order()) { + std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); + if (apply_function_name == prim::kPrimMaxPoolGrad->name() && + AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { + auto clear_zero_prim = std::make_shared(kClearZeroOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.push_back(anf_node); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + clear_zero->set_kernel_info(kernel_info); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector({"x"})), clear_zero); + SelectKernelInfo(clear_zero); + // set the distinction label of clear same with anf + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); + new_nodes.push_back(clear_zero); + } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { + if (IsAtomicNode(anf_node)) { + AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } + } + new_nodes.push_back(anf_node); + } + kernel_graph->set_execution_order(new_nodes); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h new file mode 100644 index 0000000000..0d2870eb0a --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
+#define MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
+
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+/**
+ * @brief kernel pre build for ascend.
+ */
+bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
+/**
+ * @brief kernel build for ascend.
+ */
+bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr);
+/**
+ * @brief preprocess of kernel build for ascend, e.g. inserting a clear_zero node for maxpool and bn.
+ * These changes must be made just before kernel build, after all other optimizations on the AnfGraph.
+ */
+void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph);
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_BUILD_ASCEND_H_
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
new file mode 100644
index 0000000000..e8fc6c7a98
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
@@ -0,0 +1,584 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "runtime/device/ascend/kernel_select_ascend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "debug/anf_ir_dump.h" +#include "frontend/operator/ops.h" +#include "ir/func_graph.h" +#include "utils/context/ms_context.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace { +const float kWegihtBaseScore = 1; +const float kFeatureMapBaseScore = 10; +constexpr auto kPriChoosenFormat = "pri_format"; +enum MatchCountPriority : int { + MATCH_COUNT_PRIORITY_BEGIN = 0, + MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN, + MATCH_FORMAT_COUNT, + MATCH_SPECIAL_FORMAT_COUNT, + MATCH_DEFAULT_FORMAT_COUNT, + MATCH_OUTPUT_DTYPE_COUNT, + MATCH_COUNT_PRIORITY_END +}; + +const int kUnSupportMixedDataTypeIndex = -1; + +bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) { + MS_EXCEPTION_IF_NULL(cnode); + // Check input data type + for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { + TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) { + return false; + } + } + // Check output data type + for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { + if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) { + return false; + } + } + return true; +} + +string GetPriorityMatchFormat(const CNodePtr &cnode) { + string priority_matched_format = kOpFormat_NC1HWC0; + bool is_init = false; + bool need_change_nd = false; + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); + if (AnfAlgo::IsFeatureMapInput(cnode, index) && + kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { + priority_matched_format = !is_init ? 
pre_output_format : priority_matched_format;
+      is_init = true;
+    }
+    // the feature map has two or more special formats;
+    if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
+      priority_matched_format = kOpFormat_DEFAULT;
+    }
+    auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
+    need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
+  }
+  if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
+    priority_matched_format = kOpFormat_DEFAULT;
+  }
+  AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
+  return priority_matched_format;
+}
+/**
+ * Compare two vectors by priority and select the better one, like comparing two numbers: compare the highest
+ * location first; if equal, move on to the next location.
+ * example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
+ */
+bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
+  MS_EXCEPTION_IF_NULL(best_item);
+  if (cur_item.size() != best_item->size()) {
+    MS_LOG(ERROR) << "Item sizes should be the same!";
+    return false;
+  }
+  // Update the best_item by comparing the cur_item and best_item
+  for (size_t i = 0; i < cur_item.size(); i++) {
+    if (cur_item[i] > best_item->at(i)) {
+      *best_item = cur_item;
+      return true;
+    } else if (cur_item[i] == best_item->at(i)) {
+      continue;
+    } else {
+      return false;
+    }
+  }
+  return false;
+}
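Editor's note: `PriorityChooseItem` above is a lexicographic comparison over the match-count vectors built by `UpdateCurMatchCounts` below. A minimal, self-contained restatement of that comparison (simplified to plain ints, without the size check):

```cpp
#include <cassert>
#include <vector>

// Returns true when `cur` beats `best` lexicographically, matching the
// doc-comment example: [3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3].
bool Better(const std::vector<int> &cur, const std::vector<int> &best) {
  for (size_t i = 0; i < cur.size(); ++i) {
    if (cur[i] != best[i]) return cur[i] > best[i];
  }
  return false;  // equal: keep the earlier candidate
}

int main() {
  assert(Better({3, 1, 1, 1}, {2, 2, 2, 2}));
  assert(Better({2, 2, 2, 2}, {2, 2, 1, 2}));
  assert(!Better({2, 1, 1, 3}, {2, 2, 1, 2}));
}
```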
+
+void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info,
+                          const std::shared_ptr<CNode> &kernel_node,
+                          std::vector<int> *const cur_kernelinfo_match_counts) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
+  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
+    MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
+  }
+  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    auto input_anf_node = kernel_node->input(input_index + 1);
+    // we do not take ValueNodes into consideration in graph kernels.
+    if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
+      if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
+        continue;
+      }
+    }
+    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
+    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
+      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
+    }
+    // we match output fix precision first.
+    auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
+    if (prev_device_type == kTypeUnknown) {
+      prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
+    }
+    if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
+      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
+    }
+    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
+      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
+    }
+    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
+      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
+    }
+  }
+
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+    // count matches of the output dtype between abstract and kernel info
+    if (kernel_build_info.GetOutputDeviceType(output_index) ==
+        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
+      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
+    }
+  }
+}
+
+void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
+  MS_EXCEPTION_IF_NULL(support_index);
+  int index = kUnSupportMixedDataTypeIndex;
+  switch (data_type) {
+    case kNumberTypeFloat16:
+      index = 0;
+      break;
+    case kNumberTypeFloat32:
+    case kNumberTypeFloat:
+      index = 1;
+      break;
+    default:
+      break;
+  }
+  support_index->push_back(index);
+}
+
+void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
+                                   std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
+  MS_EXCEPTION_IF_NULL(support_datatype);
+  auto data_type = kernel_build_info.GetInputDeviceType(input_index);
+  support_datatype->push_back(data_type);
+  AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
+}
+
+void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
+                                    std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
+  MS_EXCEPTION_IF_NULL(support_datatype);
+  auto data_type = kernel_build_info.GetOutputDeviceType(output_index);
+  support_datatype->push_back(data_type);
+  AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
+}
+
+void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
+                          std::vector<int> *node_mix_precision_datatype_index,
+                          std::vector<TypeId> *node_mix_precision_datatype) {
+  AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
+  MS_EXCEPTION_IF_NULL(cur_input);
+  MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
+  TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
+  AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
+  node_mix_precision_datatype->push_back(input_origin_type);
+}
+
+void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
+                           std::vector<int> *node_mix_precision_datatype_index,
+                           std::vector<TypeId> *node_mix_precision_datatype) {
+  MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
+  auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
+  AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
+  node_mix_precision_datatype->push_back(output_origin_type);
+}
+
+void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
+                         const std::vector<TypeId> &node_mix_precision_datatype,
+                         const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
+                         std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
+  if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
+    MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size()
+                      << " != datatype size " << node_mix_precision_datatype.size();
+  }
+  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
+  if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
+    MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
+                      << kernel_support_datatypes.size();
+  }
+}
+
+bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
+                                  const std::vector<TypeId> &node_mix_precision_datatype,
+                                  const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
+                                  std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
+  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
+  CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
+                      kernel_match_datatype_idx);
+  for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
+    if (node_mix_precision_datatype[i] == kTypeUnknown) {
+      continue;
+    }
+    auto iter = kernel_match_datatype_idx->begin();
+    while (iter != kernel_match_datatype_idx->end()) {
+      if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
+        auto find_iter = kernel_support_datatypes.find(iter->first);
+        if (find_iter == kernel_support_datatypes.end()) {
+          MS_LOG(EXCEPTION) << "Kernel datatype index " << iter->first << " cannot be found";
+        }
+        if (i >= find_iter->second.size()) {
+          MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
+        }
+        if (node_mix_precision_datatype[i] != find_iter->second[i]) {
+          iter = kernel_match_datatype_idx->erase(iter);
+        } else {
+          ++iter;
+        }
+        continue;
+      }
+      auto datatype_indexes = iter->second;
+      if (i >= datatype_indexes.size()) {
+        MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size();
+      }
+      if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
+        iter = kernel_match_datatype_idx->erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+  }
+  return !kernel_match_datatype_idx->empty();
+}
+
+bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
+                       const std::vector<int> &node_mix_precision_datatype_index) {
+  auto check_index_tmp = IntToSize(check_index);
+  if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
+    return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
+           datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
+  }
+  MS_LOG(EXCEPTION) << "Check index " << check_index << " is out of range";
+}
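Editor's note: `RaiseDataTypePrecisionSelect` above keeps only candidates whose per-position precision index is at least the node's (a pure raise), while `CanDataTypeReduce` also accepts indexes at or below the node's. A small sketch of the raise filter under the mapping from `AddSupportMixedPrecisionDataTypeIndex` (float16 -> 0, float32 -> 1, unsupported -> -1); a simplification, not the real candidate-map erasure loop:

```cpp
#include <cassert>
#include <vector>

constexpr int kUnSupported = -1;

// A kernel survives the "raise" filter when every constrained position
// offers at least the node's precision index.
bool KernelCanRaise(const std::vector<int> &kernel, const std::vector<int> &node) {
  for (size_t i = 0; i < node.size(); ++i) {
    if (node[i] == kUnSupported) continue;
    if (kernel[i] < node[i]) return false;
  }
  return true;
}

int main() {
  std::vector<int> node = {0, 1};          // node inputs: fp16, fp32
  assert(KernelCanRaise({1, 1}, node));    // fp16 -> fp32 is a raise: kept
  assert(!KernelCanRaise({0, 0}, node));   // fp32 -> fp16 would reduce: dropped
}
```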
+
+bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
+                                          const std::vector<TypeId> &node_mix_precision_datatype,
+                                          const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
+                                          std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
+  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
+  CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
+                      kernel_match_datatype_idx);
+  for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
+    if (node_mix_precision_datatype[i] == kTypeUnknown) {
+      continue;
+    }
+    auto iter = kernel_match_datatype_idx->begin();
+    while (iter != kernel_match_datatype_idx->end()) {
+      if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
+        auto find_iter = kernel_support_datatypes.find(iter->first);
+        if (find_iter == kernel_support_datatypes.end()) {
+          MS_LOG(EXCEPTION) << "Kernel datatype index " << iter->first << " cannot be found";
+        }
+        if (i >= find_iter->second.size()) {
+          MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
+        }
+        if (node_mix_precision_datatype[i] != find_iter->second[i]) {
+          iter = kernel_match_datatype_idx->erase(iter);
+        } else {
+          ++iter;
+        }
+        continue;
+      }
+      auto datatype_indexes = iter->second;
+      if (i >= datatype_indexes.size()) {
+        MS_LOG(EXCEPTION) << "Index " << i << " > kernel datatype indexes size " << datatype_indexes.size();
+      }
+      if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
+        iter = kernel_match_datatype_idx->erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+  }
+  return !kernel_match_datatype_idx->empty();
+}
+
+void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
+                              std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
+                              std::vector<TypeId> *support_datatypes,
+                              std::vector<int> *node_mix_precision_datatype_index) {
+  MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
+  bool add_node_datatype_flag = false;
+  if (node_mix_precision_datatype->empty()) {
+    add_node_datatype_flag = true;
+  }
+  for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
+    AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes);
+    if (add_node_datatype_flag) {
+      AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
+    }
+  }
+  // Check output data type
+  for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
+    AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes);
+    if (add_node_datatype_flag) {
+      AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
+    }
+  }
+}
+
+void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
+                     const std::vector<TypeId> &node_mix_precision_datatype,
+                     const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype,
+                     std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
+  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  MS_EXCEPTION_IF_NULL(precision_reduce);
+  std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
+  // raise precision
+  bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
+                                                   kernel_support_datatype, kernel_match_datatype_idx);
+  if (selected_ret) {
+    *precision_reduce = false;
+    return;
+  }
+  if (context_ptr->enable_reduce_precision()) {
+    selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
+                                                        kernel_support_datatype, &kernel_match_datatype_idx_copy);
+  }
+  if (selected_ret) {
+    *precision_reduce = true;
+    *kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
+  }
+}
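Editor's note: `PrecisionReduce` above encodes a two-stage policy: try the raise-only filter first; only when it fails, and only when `enable_reduce_precision()` is on, retry the raise-or-reduce filter on a copy of the candidate map. Condensed into flags standing in for the real filters:

```cpp
#include <cassert>

enum class Outcome { kRaised, kReduced, kNone };

// Condensed view of the PrecisionReduce policy: pure raise first, reduce
// only as an explicitly enabled fallback.
Outcome Pick(bool raise_ok, bool reduce_ok, bool reduce_enabled) {
  if (raise_ok) return Outcome::kRaised;
  if (reduce_enabled && reduce_ok) return Outcome::kReduced;
  return Outcome::kNone;
}

int main() {
  assert(Pick(true, true, false) == Outcome::kRaised);
  assert(Pick(false, true, true) == Outcome::kReduced);
  assert(Pick(false, true, false) == Outcome::kNone);
}
```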
+
+void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
+                                             const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
+                                             bool precision_reduce) {
+  MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::ostringstream buffer;
+  buffer << cnode->DebugString();
+  if (precision_reduce) {
+    buffer << " Reduce precision, node datatype: \n";
+  } else {
+    buffer << " Raise precision, node datatype: \n";
+  }
+  PrintInputAndOutputInferType(buffer, cnode);
+  buffer << ", select kernel:" << selected_kernel_build_info->ToString();
+  MS_LOG(INFO) << buffer.str();
+}
+
+std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
+  const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
+  if (kernel_info_list.empty()) {
+    return nullptr;
+  }
+  std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
+  size_t selected_index = 0;
+  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
+    std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
+    auto kernel_info_ptr = kernel_info_list[info_index];
+    MS_EXCEPTION_IF_NULL(kernel_info_ptr);
+    UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
+    // Currently the selection policy is to match the format counts first, and then the datatype counts.
+    if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
+      selected_index = SizeToInt(info_index);
+    }
+  }
+  return kernel_info_list[selected_index];
+}
+
+std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
+  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
+  for (const auto &kernel_build_info : kernel_info_list) {
+    MS_EXCEPTION_IF_NULL(kernel_build_info);
+    if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
+      continue;
+    }
+    result.push_back(kernel_build_info);
+  }
+  return result;
+}
+
+std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
+  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
+  bool *precision_reduce) {
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
+  std::map<size_t, std::vector<int>> kernel_match_datatype_idx;
+  std::map<size_t, std::vector<TypeId>> kernel_support_datatype;
+  std::vector<int> node_mix_precision_datatype_index;
+  std::vector<TypeId> node_mix_precision_datatype;
+  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
+    std::vector<int> support_indexes;
+    std::vector<TypeId> support_datatypes;
+    MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
+    AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype,
+                             &support_datatypes, &node_mix_precision_datatype_index);
+    kernel_match_datatype_idx[info_index] = support_indexes;
+    kernel_support_datatype[info_index] = support_datatypes;
+  }
+  PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype,
+                  &kernel_match_datatype_idx, precision_reduce);
+  std::transform(
+    kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list),
+    [&](const std::pair<size_t, std::vector<int>> &matched_idx) -> std::shared_ptr<kernel::KernelBuildInfo> {
+      return kernel_info_list[matched_idx.first];
+    });
+  return filtered_kernel_info_list;
+}
+}  // namespace
+
+void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
+    MS_EXCEPTION_IF_NULL(input_with_index.first);
+    auto real_input_node = input_with_index.first;
+    if (real_input_node->isa<CNode>()) {
+      continue;
+    }
+    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
+      continue;
+    }
+    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+    if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
+        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
+      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
+      continue;
+    }
+    // we set the special device info of an input tensor.
+    bool is_ref = false;
+    auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
+    if (op_info != nullptr) {
+      is_ref = op_info->is_ref();
+    }
+    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
+    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
+        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
+      continue;
+    }
+    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
+      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
+    }
+  }
+}
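Editor's note: the next function, `SetMatchedKernelInfo`, drives the whole cascade: exact dtype matches win; otherwise the raised/reduced-precision pool is consulted, and the returned status records which path was taken. A simplified model of that control flow (the `Candidate` struct is an illustrative stand-in, not a MindSpore type):

```cpp
#include <vector>

enum class SelectStatus { kNoMatched, kAllMatched, kReducePrecision, kRaisePrecision };

// Each candidate records how it could satisfy the node's dtypes.
struct Candidate {
  bool exact_dtype;
  bool reachable_by_raise;
  bool reachable_by_reduce;
};

// Simplified mirror of SetMatchedKernelInfo's control flow: exact matches
// first, then precision-raised, then precision-reduced candidates.
SelectStatus Select(const std::vector<Candidate> &list) {
  for (const auto &c : list)
    if (c.exact_dtype) return SelectStatus::kAllMatched;
  for (const auto &c : list)
    if (c.reachable_by_raise) return SelectStatus::kRaisePrecision;
  for (const auto &c : list)
    if (c.reachable_by_reduce) return SelectStatus::kReducePrecision;
  return SelectStatus::kNoMatched;
}

int main() {
  std::vector<Candidate> list = {{false, true, false}, {false, false, true}};
  return Select(list) == SelectStatus::kRaisePrecision ? 0 : 1;
}
```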
+
+KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
+                                        const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  KernelSelectStatus select_status = kNoMatched;
+  bool precision_reduce = false;
+  std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
+  // Matched kernel info
+  // Filter the kernel infos that match the inferred data types
+  auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
+  if (!filtered_kernel_info_list.empty()) {
+    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
+    select_status = kStatusAllMatched;
+  } else {
+    // select a kernel info using raised or reduced precision
+    filtered_kernel_info_list =
+      FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
+    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
+    if (selected_kernel_info == nullptr) {
+      return select_status;
+    } else {
+      PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
+      select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
+    }
+  }
+  // Set the kernel info to the anf node
+  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
+  // Set format and data type for the input tensors.
+  SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
+  return select_status;
+}
+
+KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  if (AnfAlgo::IsGraphKernel(kernel_node)) {
+    auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
+    MS_EXCEPTION_IF_NULL(func_graph);
+    SelectGraphKernelInfo(kernel_node, func_graph);
+    return kStatusAllMatched;
+  }
+  kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
+  auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
+  // If AI Core does not find a valid kernel info, load the AI CPU kernel info list and retry
+  if (select_status == kNoMatched) {
+    MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
+                    << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
+    kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
+    select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
+    AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
+  }
+  // The kernel info was found in neither the aicpu kernel list nor the aicore kernel list
+  if (select_status == kNoMatched) {
+    std::ostringstream buffer;
+    PrintInputAndOutputInferType(buffer, kernel_node);
+    MS_LOG(WARNING) << ">>> Candidates kernel info list:";
+    for (size_t index = 0; index < kernel_info_list.size(); ++index) {
+      MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString();
+    }
+    for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
+      MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
+                      << "] :" << aicpu_kernel_info_list[index]->ToString();
+    }
+    if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
+      auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
+      AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
+      // Set format and data type for the input tensors.
+      SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
+    } else {
+      MS_LOG(WARNING) << " <<<";
+      MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
+                              << "] cannot find valid kernel info, unsupported type: " << buffer.str()
+                              << ", please refer to the supported dtypes in the candidates kernel info list";
+    }
+  }
+  return select_status;
+}
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h
new file mode 100644
index 0000000000..8a93b77cec
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.h
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
+#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
+#include "ir/anf.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+namespace mindspore {
+namespace device {
+namespace ascend {
+enum KernelSelectStatus {
+  kNoMatched = -1,
+  kStatusAllMatched = 0,
+  kStatusReducePrecision = 1,
+  kStatusRaisePrecision = 2,
+};
+KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node,
+                                    KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
+void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node);
+void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph);
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_KERNEL_SELECT_ASCEND_ANFALGO_H_
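Editor's note: a hypothetical call-site sketch of how a backend pass might consume the `KernelSelectStatus` declared in this header; the `Report` helper and its log strings are illustrative, not part of this patch:

```cpp
#include <iostream>

enum KernelSelectStatus { kNoMatched = -1, kStatusAllMatched = 0, kStatusReducePrecision = 1, kStatusRaisePrecision = 2 };

// Branch on the selector's outcome; a caller would typically warn on the
// precision-changing paths and fail hard on kNoMatched.
void Report(KernelSelectStatus s) {
  switch (s) {
    case kStatusAllMatched: std::cout << "exact match\n"; break;
    case kStatusRaisePrecision: std::cout << "raised precision (e.g. fp16 -> fp32)\n"; break;
    case kStatusReducePrecision: std::cout << "reduced precision (accuracy may drop)\n"; break;
    default: std::cout << "no kernel found\n"; break;
  }
}

int main() { Report(kStatusRaisePrecision); }
```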
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
new file mode 100644
index 0000000000..42e856d112
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
@@ -0,0 +1,531 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/kernel_select_ascend.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
+#include "ir/func_graph.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/kernel_query.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+namespace {
+// sort formats according to the number of occurrences.
+bool cmp_format_num(const std::pair<std::string, size_t> &a, const std::pair<std::string, size_t> &b) {
+  if (a.second != b.second) {
+    return a.second > b.second;
+  } else if (a.first == kOpFormat_DEFAULT) {
+    return a.second + 1 > b.second;
+  } else if (b.first == kOpFormat_DEFAULT) {
+    return a.second > b.second + 1;
+  }
+  return a.second > b.second;
+}
+
+TypeId GetPrimitivePrecision(const CNodePtr &cnode) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
+
+  TypeId except_type = kTypeUnknown;
+  if (primitive->GetAttr(kAttrFixPrecision) != nullptr) {
+    auto strExceptDtype = GetValue<std::string>(primitive->GetAttr(kAttrFixPrecision));
+    if (strExceptDtype == "float16") {
+      except_type = kNumberTypeFloat16;
+    } else if (strExceptDtype == "float32") {
+      except_type = kNumberTypeFloat32;
+    } else {
+      MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << strExceptDtype;
+    }
+  }
+
+  return except_type;
+}
+}  // namespace
+
+void ResetKernelBuildInfo(const CNodePtr &kernel_node) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
+    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    auto kernel_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
+    if (!kernel::IsWeightBoundary(kernel_with_index.first)) {
+      continue;
+    }
+    // reset format and dtype.
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
+    builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_kernel_node.get());
+  }
+}
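Editor's note: `cmp_format_num` at the top of this file gives `kOpFormat_DEFAULT` a one-count bonus, so it wins ties against special formats. A quick standalone check, with plain strings standing in for the format constants ("DefaultFormat" plays the role of `kOpFormat_DEFAULT`):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  using FormatCount = std::pair<std::string, size_t>;
  std::vector<FormatCount> counts = {{"NC1HWC0", 2}, {"DefaultFormat", 2}, {"FRACTAL_NZ", 3}};
  std::sort(counts.begin(), counts.end(), [](const FormatCount &a, const FormatCount &b) {
    if (a.second != b.second) return a.second > b.second;
    if (a.first == "DefaultFormat") return a.second + 1 > b.second;  // default wins ties
    if (b.first == "DefaultFormat") return a.second > b.second + 1;
    return a.second > b.second;
  });
  // Highest count first; on the 2-vs-2 tie, DefaultFormat precedes NC1HWC0.
  std::cout << counts[0].first << ", " << counts[1].first << ", " << counts[2].first << '\n';
}
```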
+
+void UpdateKernelInfo(const std::vector<AnfNodePtr> &node_list) {
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    // select nodes in subgraph.
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto fix_precision_type = GetPrimitivePrecision(cnode);
+    if (fix_precision_type != kTypeUnknown) {
+      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
+      kernel::KernelQuery(cnode, &kernel_info_list, KernelType::AKG_KERNEL);
+
+      for (size_t index = 0; index < kernel_info_list.size(); ++index)
+        // only match the first input
+        if (kernel_info_list[index]->GetInputDeviceType(0) == fix_precision_type &&
+            kernel_info_list[index]->GetInputFormat(0) == AnfAlgo::GetPrevNodeOutputFormat(cnode, 0) &&
+            AnfAlgo::GetInputDeviceDataType(cnode, 0) != fix_precision_type) {
+          auto selected_kernel_info_ptr = kernel_info_list[index];
+          ResetKernelBuildInfo(cnode);
+          AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, cnode.get());
+          SetTensorDeviceInfo(*selected_kernel_info_ptr, cnode);
+          break;
+        }
+    }
+  }
+}
+
+bool CanConvertDefaultShapeToNZ(const std::vector<size_t> &shape) {
+  for (size_t i = 1; i <= shape.size(); ++i) {
+    if (i > 2) {
+      break;
+    }
+    if (shape[shape.size() - i] != 1 && shape[shape.size() - i] % kCubeSize != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
+  std::vector<int> frac_nz_axis = axis;
+  auto shape_len = ori_shape.size();
+  for (size_t i = 0; i < axis.size(); ++i) {
+    auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
+    if (axis_idx == shape_len - 1) {
+      frac_nz_axis[i] = axis_idx - 1;
+      frac_nz_axis.push_back(axis_idx + 2);
+    } else if (axis_idx == shape_len - 2) {
+      frac_nz_axis[i] = axis_idx + 1;
+      frac_nz_axis.push_back(axis_idx + 2);
+    } else {
+      frac_nz_axis[i] = axis_idx;
+    }
+  }
+  return frac_nz_axis;
+}
+
+std::vector<size_t> GetReducedFracNZShape(const std::vector<size_t> &ori_shape, const std::vector<int> &axis,
+                                          bool keep_dims) {
+  std::vector<size_t> result;
+  std::set<size_t> positive_idx;
+  for (const auto &a : axis) {
+    positive_idx.insert(a >= 0 ?
a : ori_shape.size() + a); + } + for (size_t i = 0; i < ori_shape.size(); ++i) { + if (positive_idx.count(i) == 0) { + result.push_back(ori_shape[i]); + } else if (keep_dims) { + result.push_back(1); + } + } + return result; +} + +void UpdateFracNZReduceOp(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, 0); + if (input_format == kOpFormat_FRAC_NZ) { + // Clone primitive to modify it + auto prim = GetCNodePrimitive(cnode); + auto new_prim = std::make_shared(*prim); + auto new_prim_node = NewValueNode(new_prim); + cnode->set_input(0, new_prim_node); + + auto axis_value = new_prim->GetAttr(kAttrAxis); + std::vector default_axis; + if (axis_value->isa()) { + auto value_list = dyn_cast(axis_value); + for (const auto &item : value_list->value()) { + if (item->isa()) { + default_axis.push_back(GetValue(item)); + } + } + } else if (axis_value->isa()) { + auto value_tuple = dyn_cast(axis_value); + for (const auto &item : value_tuple->value()) { + if (item->isa()) { + default_axis.push_back(GetValue(item)); + } + } + } else { + MS_LOG(ERROR) << "Axis attr type is not correct!"; + } + auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); + std::vector frac_nz_axis = DefaultToFracNZAxis(infer_shape, default_axis); + AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue>(frac_nz_axis), cnode); + auto output_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + if (output_shape.size() == 1) { + AnfAlgo::SetNodeAttr(kAttrOutputDefault, MakeValue(true), cnode); + } + } +} + +void GetDefaultFormat(const CNodePtr &kernel_node, std::string *default_format, bool *use_same_format) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(default_format); + MS_EXCEPTION_IF_NULL(use_same_format); + std::unordered_map all_input_formats; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; ++i) { + auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_kernel_node); + if (!input_kernel_node->isa()) { + ++all_input_formats[AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i)]; + continue; + } + auto para = input_kernel_node->cast(); + if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) { + ++all_input_formats[AnfAlgo::GetOutputFormat(para, 0)]; + continue; + } + *use_same_format = false; + } + + if (all_input_formats.empty()) { + // all inputs are parameter. 
+    *default_format = kOpFormat_NC1HWC0;
+  } else {
+    std::vector<std::pair<std::string, size_t>> pairs;
+    for (auto iter = all_input_formats.begin(); iter != all_input_formats.end(); ++iter) {
+      pairs.push_back(std::make_pair(iter->first, iter->second));
+    }
+
+    std::sort(pairs.begin(), pairs.end(), cmp_format_num);
+    *default_format = pairs.begin()->first;
+  }
+
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (!input_kernel_node->isa<Parameter>() ||
+        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) != kTypeUnknown) {
+      continue;
+    }
+    auto weight_infer_shape = AnfAlgo::GetOutputInferShape(input_kernel_node, 0);
+    if (weight_infer_shape.size() < 2 && *default_format == kOpFormat_FRAC_NZ) {
+      *default_format = kOpFormat_DEFAULT;
+      *use_same_format = true;
+      break;
+    }
+  }
+}
+
+void UpdateInputsKernelInfo(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
+                            const std::string &default_format, bool use_same_format,
+                            std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type) {
+  MS_EXCEPTION_IF_NULL(graph_input_format);
+  MS_EXCEPTION_IF_NULL(graph_input_type);
+  // We set the same format for all inputs of the graph kernel subgraph, and process this later.
+  // We set the dtypes of the subgraph inputs to the inferred dtypes.
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_kernel_node = AnfAlgo::VisitKernel(kernel_node->input(i + 1), 0).first;
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (use_same_format) {
+      bool can_convert = true;
+      if (default_format == kOpFormat_FRAC_NZ) {
+        auto infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+        if (!CanConvertDefaultShapeToNZ(infer_shape)) {
+          MS_LOG(WARNING) << "Shape can't be converted to frac nz shape, so use default format instead";
+          can_convert = false;
+        }
+      }
+      if (can_convert) {
+        graph_input_format->push_back(default_format);
+      } else {
+        graph_input_format->push_back(kOpFormat_DEFAULT);
+      }
+      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
+      continue;
+    }
+
+    if (!input_kernel_node->isa<Parameter>()) {
+      // subgraph parameter comes from the output of another node.
+      graph_input_format->push_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, i));
+      graph_input_type->push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, i));
+      continue;
+    }
+
+    auto para = input_kernel_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(para);
+    if (AnfAlgo::GetOutputDeviceDataType(para, 0) != kTypeUnknown) {
+      // parameter already selected.
+      graph_input_format->push_back(AnfAlgo::GetOutputFormat(para, 0));
+      graph_input_type->push_back(AnfAlgo::GetOutputDeviceDataType(para, 0));
+      continue;
+    }
+
+    // weight parameter.
+    graph_input_format->push_back(default_format);
+    graph_input_type->push_back(AnfAlgo::GetOutputInferDataType(input_kernel_node, 0));
+  }
+
+  for (size_t i = 0; i < input_num; ++i) {
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    std::vector<std::string> outputs_format = {(*graph_input_format)[i]};
+    std::vector<TypeId> outputs_device_type = {(*graph_input_type)[i]};
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputsDeviceType(outputs_device_type);
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
+  }
+}
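Editor's note: `UpdateEquivFormat` below rewrites ReduceSum axes via `DefaultToFracNZAxis`, defined earlier. A worked example, under my reading of the FRACTAL_NZ layout in which a default 2-D shape (m, n) is stored as (n1, m1, m0, n0): reducing a default-format axis expands to the two fractal axes that carry that dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same remapping as DefaultToFracNZAxis in the patch, with templates restored.
std::vector<int> DefaultToFracNZAxis(const std::vector<size_t> &ori_shape, const std::vector<int> &axis) {
  std::vector<int> frac_nz_axis = axis;
  auto shape_len = ori_shape.size();
  for (size_t i = 0; i < axis.size(); ++i) {
    auto axis_idx = (frac_nz_axis[i] + shape_len) % shape_len;
    if (axis_idx == shape_len - 1) {        // last dim n lives in n1 (idx-1) and n0 (idx+2)
      frac_nz_axis[i] = axis_idx - 1;
      frac_nz_axis.push_back(axis_idx + 2);
    } else if (axis_idx == shape_len - 2) {  // dim m lives in m1 (idx+1) and m0 (idx+2)
      frac_nz_axis[i] = axis_idx + 1;
      frac_nz_axis.push_back(axis_idx + 2);
    } else {
      frac_nz_axis[i] = axis_idx;
    }
  }
  return frac_nz_axis;
}

int main() {
  // (32, 64) -> FRACTAL_NZ (4, 2, 16, 16): reducing axis 1 (n) reduces n1 and n0.
  assert((DefaultToFracNZAxis({32, 64}, {1}) == std::vector<int>{0, 3}));
  assert((DefaultToFracNZAxis({32, 64}, {0}) == std::vector<int>{1, 2}));
}
```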
+
+void UpdateEquivFormat(const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
+                       const std::vector<AnfNodePtr> &node_list, const FuncGraphPtr &func_graph,
+                       const FuncGraphManagerPtr &mng) {
+  MS_EXCEPTION_IF_NULL(mng);
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    // select nodes in subgraph.
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
+    SelectKernelInfo(cnode, KernelType::AKG_KERNEL);
+    // Update ReduceSum
+    if (!IsPrimitiveCNode(cnode, prim::kPrimReduceSum)) {
+      continue;
+    }
+    UpdateFracNZReduceOp(cnode);
+    // If ReduceSum's output is 1-D and not in the default format, convert it to the default format
+    auto out_format = AnfAlgo::GetOutputFormat(cnode, 0);
+    if (out_format == kOpFormat_DEFAULT || !AnfAlgo::HasNodeAttr(kAttrOutputDefault, cnode)) {
+      continue;
+    }
+    auto infer_shape = AnfAlgo::GetOutputInferShape(cnode, 0);
+    // Insert an EquivFormat node, then select the kernel info again
+    std::vector<AnfNodePtr> trans_inputs;
+    trans_inputs.push_back(NewValueNode(prim::kPrimEquivFormat));
+    trans_inputs.push_back(cnode);
+    CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0)},
+                                        {AnfAlgo::GetOutputInferShape(cnode, 0)}, trans_node.get());
+    AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue<std::vector<std::string>>({"x"}), trans_node);
+
+    if (trans_node->kernel_info() == nullptr) {
+      trans_node->set_kernel_info(std::make_shared<device::KernelInfo>());
+    }
+    SelectKernelInfo(trans_node, KernelType::AKG_KERNEL);
+    mng->Replace(cnode, trans_node);
+  }
+}
+
+void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &input_list,
+                           const FuncGraphManagerPtr &mng, const std::string &default_format,
+                           std::vector<std::string> *graph_input_format, std::vector<TypeId> *graph_input_type,
+                           std::vector<bool> *need_update) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(mng);
+  MS_EXCEPTION_IF_NULL(graph_input_format);
+  MS_EXCEPTION_IF_NULL(graph_input_type);
+  MS_EXCEPTION_IF_NULL(need_update);
+  // check the graph input format and dtype used by the inner ops.
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (graph_input_format->size() != input_num || graph_input_type->size() != input_num ||
+      need_update->size() != input_num) {
+    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
+                      << "], [" << graph_input_format->size() << "] != [" << input_num << "]";
+  }
+  auto &node_users = mng->node_users();
+  for (size_t i = 0; i < input_num; ++i) {
+    auto &input = input_list[i];
+    auto iter = node_users.find(input);
+    if (iter == node_users.end() || iter->second.empty()) {
+      continue;
+    }
+    for (auto &node_user : iter->second) {
+      if (node_user.first->kernel_info() == nullptr ||
+          node_user.first->kernel_info()->select_kernel_build_info() == nullptr) {
+        // maybe not a real kernel.
+        continue;
+      }
+      auto user_format = AnfAlgo::GetInputFormat(node_user.first, IntToSize(node_user.second - 1));
+      if (user_format != (*graph_input_format)[i]) {
+        MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
+                        << kernel_node->DebugString()
+                        << "] selected a different format. We use the default: " << default_format;
+        (*graph_input_format)[i] = default_format;
+        (*need_update)[i] = true;
+      }
+
+      if (kernel_node->input(i + 1)->isa<ValueNode>() ||
+          AnfAlgo::GetInputDeviceDataType(node_user.first, IntToSize(node_user.second - 1)) == (*graph_input_type)[i]) {
+        continue;
+      }
+
+      TypeId default_dtype = AnfAlgo::GetOutputInferDataType(input, 0);
+      MS_LOG(WARNING) << "Users of input: [" << i << "][" << input->DebugString(2) << " of ["
+                      << kernel_node->DebugString()
+                      << "] selected a different dtype. We use the default: " << TypeIdLabel(default_dtype);
+      (*graph_input_type)[i] = default_dtype;
+      (*need_update)[i] = true;
+    }
+  }
+}
+
+void UpdateFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNodePtr> &node_list,
+                            const std::vector<AnfNodePtr> &input_list, const std::vector<bool> &need_update,
+                            const std::vector<std::string> &graph_input_format,
+                            const std::vector<TypeId> &graph_input_type) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  // update the graph input format and dtype used by the inner ops.
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  if (graph_input_format.size() != input_num || graph_input_type.size() != input_num ||
+      need_update.size() != input_num) {
+    MS_LOG(EXCEPTION) << "Graph input format size is not equal to input num of cnode[" << kernel_node->DebugString()
+                      << "], [" << graph_input_format.size() << "] != [" << input_num << "]";
+  }
+  for (size_t i = 0; i < input_num; ++i) {
+    if (!need_update[i]) {
+      continue;
+    }
+
+    MS_LOG(DEBUG) << "Update input format: " << i << " of: [" << kernel_node->DebugString()
+                  << "] to: " << graph_input_format[i];
+    MS_LOG(DEBUG) << "Update input dtype: " << i << " of: [" << kernel_node->DebugString()
+                  << "] to: " << TypeIdLabel(graph_input_type[i]);
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    std::vector<std::string> outputs_format = {graph_input_format[i]};
+    std::vector<TypeId> outputs_device_type = {graph_input_type[i]};
+    builder.SetOutputsFormat(outputs_format);
+    builder.SetOutputsDeviceType(outputs_device_type);
+    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_list[i].get());
+  }
+
+  ResetKernelBuildInfo(kernel_node);
+  // select nodes in the subgraph again.
+  for (size_t i = 0; i < node_list.size(); ++i) {
+    auto anf_node = node_list[i];
+    MS_EXCEPTION_IF_NULL(anf_node);
+    auto cnode = anf_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+    size_t cnode_input_num = AnfAlgo::GetInputTensorNum(cnode);
+    for (size_t j = 0; j < cnode_input_num; ++j) {
+      auto input_node = cnode->input(j + 1);
+      MS_EXCEPTION_IF_NULL(input_node);
+      if (!IsValueNode<tensor::Tensor>(input_node)) {
+        continue;
+      }
+      // reset format and dtype of the const tensor.
+      builder.SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
+      builder.SetOutputsDeviceType(std::vector<TypeId>{kTypeUnknown});
+      AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), input_node.get());
+    }
+    SelectKernelInfo(node_list[i]->cast<CNodePtr>(), KernelType::AKG_KERNEL);
+  }
+}
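Editor's note: `SetGraphKernelInfo` below assembles the subgraph's final kernel info through the same builder pattern used throughout this pass. A minimal stand-in showing the shape of that API; the class and field names here are simplified, not the real `KernelBuildInfoBuilder`:

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct BuildInfo {
  std::vector<std::string> output_formats;
  std::vector<int> output_types;  // stand-in for TypeId
};

// Collect format/dtype vectors, then freeze them into an immutable info object.
class Builder {
 public:
  Builder &SetOutputsFormat(std::vector<std::string> f) { info_.output_formats = std::move(f); return *this; }
  Builder &SetOutputsDeviceType(std::vector<int> t) { info_.output_types = std::move(t); return *this; }
  std::shared_ptr<BuildInfo> Build() { return std::make_shared<BuildInfo>(info_); }

 private:
  BuildInfo info_;
};

int main() {
  auto info = Builder().SetOutputsFormat({"DefaultFormat"}).SetOutputsDeviceType({0}).Build();
  return info->output_formats.size() == 1 ? 0 : 1;
}
```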
+
+void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair<AnfNodePtr, size_t>> &output_index,
+                        const std::vector<std::string> &graph_input_format,
+                        const std::vector<TypeId> &graph_input_type) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  std::vector<std::string> graph_output_format;
+  std::vector<TypeId> graph_output_type;
+  for (size_t i = 0; i < output_index.size(); ++i) {
+    auto const &output = output_index[i];
+    graph_output_format.push_back(AnfAlgo::GetOutputFormat(output.first, output.second));
+    TypeId output_type(kTypeUnknown);
+    if (output.first->isa<CNode>()) {
+      output_type = AnfAlgo::GetCNodeOutputPrecision(output.first);
+    }
+    if (output_type == kTypeUnknown) {
+      output_type = AnfAlgo::GetOutputDeviceDataType(output.first, output.second);
+    }
+    graph_output_type.push_back(output_type);
+  }
+
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
+  graph_info_builder.SetInputsFormat(graph_input_format);
+  graph_info_builder.SetInputsDeviceType(graph_input_type);
+  graph_info_builder.SetOutputsFormat(graph_output_format);
+  graph_info_builder.SetOutputsDeviceType(graph_output_type);
+  graph_info_builder.SetProcessor(kernel::Processor::AICORE);
+  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
+  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
+  auto graph_selected_info = graph_info_builder.Build();
+  MS_EXCEPTION_IF_NULL(graph_selected_info);
+  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
+  SetTensorDeviceInfo(*graph_selected_info, kernel_node);
+}
+
+void SelectGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+
+  // collect the input info of the funcgraph
+  std::vector<AnfNodePtr> node_list;
+  std::vector<AnfNodePtr> input_list;
+  std::vector<AnfNodePtr> output_list;
+  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
+  if (input_list.size() != kernel_node->inputs().size() - 1) {
+    MS_EXCEPTION(ArgumentError) << "Input num of funcgraph[" << func_graph->ToString() << "] not equal input of cnode["
+                                << kernel_node->DebugString() << "], [" << input_list.size() << "] != ["
+                                << kernel_node->inputs().size() << "]";
+  }
+
+  std::string default_format;
+  bool use_same_format = true;
+  GetDefaultFormat(kernel_node, &default_format, &use_same_format);
+  MS_LOG(DEBUG) << "GraphKernel[" << func_graph->ToString() << "] uses the same input format[" << default_format
+                << "] for ParameterWeight.";
+
+  std::vector<std::string> graph_input_format;
+  std::vector<TypeId> graph_input_type;
+  UpdateInputsKernelInfo(kernel_node, input_list, default_format, use_same_format, &graph_input_format,
+                         &graph_input_type);
+
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+  }
+  auto output_index = kernel::GetOutputIndex(node_list, input_list, output_list);
+  UpdateEquivFormat(output_index, node_list, func_graph, mng);
+  node_list.clear();
+  input_list.clear();
+  output_list.clear();
+  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
+
+  // update the graph input format and dtype used by the inner ops.
+ std::vector need_update(AnfAlgo::GetInputTensorNum(kernel_node), false); + CheckFormatsAndDtypes(kernel_node, input_list, mng, default_format, &graph_input_format, &graph_input_type, + &need_update); + UpdateFormatsAndDtypes(kernel_node, node_list, input_list, need_update, graph_input_format, graph_input_type); + + // set fix_precision for kernel when the me prim has fix_precision attr + UpdateKernelInfo(node_list); + + output_index = kernel::GetOutputIndex(node_list, input_list, output_list); + SetGraphKernelInfo(kernel_node, output_index, graph_input_format, graph_input_type); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc new file mode 100644 index 0000000000..4886c00a8e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include +#include "utils/log_adapter.h" +using std::string; + +namespace mindspore { +namespace device { +namespace ascend { +Reporter *PluginImpl::reporter_ = nullptr; + +PluginImpl::PluginImpl(const std::string &module) : module_(module) { MS_LOG(INFO) << "Create PluginImpl."; } + +int PluginImpl::Init(const Reporter *reporter) { + MS_LOG(INFO) << "PluginImpl init"; + MS_EXCEPTION_IF_NULL(reporter); + reporter_ = const_cast(reporter); + return 0; +} + +int PluginImpl::UnInit() { + MS_LOG(INFO) << " PluginImpl Uninit "; + reporter_ = nullptr; + return 0; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/plugin_impl.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/plugin_impl.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc new file mode 100644 index 0000000000..1f35cba0f7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "runtime/device/ascend/profiling/profiling_engine_impl.h" +#include "utils/log_adapter.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" + +namespace mindspore { +namespace device { +namespace ascend { +PluginIntf *ProfilingEngineImpl::CreatePlugin() { + MS_LOG(INFO) << "Create Plugin."; + return new (std::nothrow) PluginImpl("Framework"); +} + +int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { + if (plugin != nullptr) { + delete plugin; + plugin = nullptr; + } + return 0; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/profiling_engine_impl.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc new file mode 100644 index 0000000000..6117fe5ecf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include +#include +#include "securec/include/securec.h" +#include "./prof_mgr_core.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include "runtime/device/ascend/profiling/profiling_engine_impl.h" +#include "utils/log_adapter.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "utils/convert_utils.h" +#include "runtime/base.h" + +namespace mindspore { +namespace device { +namespace ascend { +ProfilingManager &ProfilingManager::GetInstance() { + static ProfilingManager inst; + return inst; +} + +ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { + engine_0_ = std::make_shared(); +} + +uint64_t ProfilingManager::GetJobId() const { + const char *job_id = std::getenv("JOB_ID"); + return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); +} + +bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { + if (!IsProfiling()) { + MS_LOG(INFO) << "No need profiling. 
Please export PROFILING_MODE and run in train mode.";
+    return false;
+  }
+  if (op_taskId_map.empty()) {
+    MS_LOG(WARNING) << "op_taskId_map is empty.";
+    return false;
+  }
+  auto reporter = PluginImpl::GetPluginReporter();
+  if (reporter == nullptr) {
+    MS_LOG(ERROR) << "No profiling data report!";
+    return false;
+  }
+  MS_LOG(INFO) << "DistributeTask: op taskId map size = " << op_taskId_map.size();
+
+  Msprof::Engine::ReporterData reporter_data = {};
+  for (const auto &iter : op_taskId_map) {
+    auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
+    reporter_data.deviceId = UintToInt(device_id_);
+    reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
+    reporter_data.dataLen = data.size();
+    auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+      return false;
+    }
+    ret = reporter->Report(&reporter_data);
+    if (ret != 0) {
+      MS_LOG(ERROR) << "reporter data fail, errorno(" << ret << ")";
+      return false;
+    }
+  }
+  return true;
+}
+
+static std::vector<std::string> Split(const std::string &str, const char delim) {
+  std::vector<std::string> elems;
+
+  if (str.empty()) {
+    elems.emplace_back("");
+    return elems;
+  }
+
+  std::stringstream ss(str);
+  std::string item;
+
+  while (getline(ss, item, delim)) {
+    elems.push_back(item);
+  }
+  auto str_size = str.size();
+  if (str_size > 0 && str[str_size - 1] == delim) {
+    elems.emplace_back("");
+  }
+
+  return elems;
+}
+
+bool ProfilingManager::StartupProfiling(uint32_t device_id) {
+  auto is_profiling = IsProfiling();
+  if (!is_profiling) {
+    MS_LOG(INFO) << "No need for profiling. Please export PROFILING_MODE and run in train mode.";
+    return true;
+  }
+  device_id_ = device_id;
+  // register Framework to profiling
+  int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
+  if (result != 0) {
+    MS_LOG(ERROR) << "Register profiling Engine failed.";
+    return false;
+  }
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  const string prof_options_str = context->profiling_options();
+  std::vector<string> opts = Split(prof_options_str, ':');
+  if (opts.empty()) {
+    MS_LOG(WARNING) << "Profiling is enabled, but the profiling option is not set!";
+    return true;
+  }
+  // currently one docker uses only one device
+  nlohmann::json p_device;
+  // JOBID
+  auto job_id = GetJobId();
+  p_device["jobID"] = std::to_string(job_id);
+  // device_id
+  p_device["deviceID"] = std::to_string(device_id);
+  // features: 'training_trace', 'task_trace' etc.
+  nlohmann::json features;
+  for (std::vector<string>::size_type i = 0; i < opts.size(); i++) {
+    nlohmann::json f;
+    f["name"] = opts[i];
+    features[i] = f;
+  }
+  p_device["features"] = features;
+  // only one device, but the ProfMgrStartUp API requires a device list
+  nlohmann::json devices;
+  devices[0] = p_device;
+  nlohmann::json startCfg;
+  startCfg["startCfg"] = devices;
+
+  if (!ProfStartUp(NOT_NULL(&startCfg))) {
+    MS_LOG(ERROR) << "ProfMgrStartUp failed.";
+    return false;
+  }
+  return true;
+}
+
+bool ProfilingManager::ProfStartUp(NotNull<nlohmann::json *> startCfg) {
+  // convert json to string
+  std::stringstream ss;
+  ss << *startCfg;
+  std::string cfg = ss.str();
+  MS_LOG(INFO) << "profiling config " << cfg;
+  auto ret = rtProfilerStart();
+  if (ret != RT_ERROR_NONE) {
+    MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret;
+    return false;
+  }
+
+  // call the profiling startup API
+  ProfMgrCfg prof_cfg = {cfg};
+  prof_handle_ = ProfMgrStartUp(&prof_cfg);
+  if (prof_handle_ == nullptr) {
MS_LOG(ERROR) << "Startup profiling failed."; + return false; + } + return true; +} + +bool ProfilingManager::StopProfiling() { + MS_LOG(INFO) << "StopProfiling"; + if (!IsProfiling()) { + MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; + return true; + } + Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); + if (reporter != nullptr) { + MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); + } + + auto rt_ret = rtProfilerStop(); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "Call rtProfilerStop failed"; + return false; + } + + if (prof_handle_ != nullptr) { + int result = ProfMgrStop(prof_handle_); + if (result != 0) { + MS_LOG(ERROR) << "ProfMgr stop return fail:" << result << "."; + prof_handle_ = nullptr; + return false; + } + prof_handle_ = nullptr; + } + + return true; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h similarity index 100% rename from mindspore/ccsrc/device/ascend/profiling/profiling_manager.h rename to mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc new file mode 100644 index 0000000000..5b1db6a404 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -0,0 +1,367 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/profiling/reporter/graph_desc_reporter.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "utils/utils.h" +#include "runtime/device/ascend/profiling/reporter/task_desc_reporter.h" +#include "utils/context/ms_context.h" +#include "runtime/device/ascend/profiling/reporter/point_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +constexpr uint32_t kMaxProfilingNodeNum = 100; +constexpr char kCustomNode[] = "PROFILING_CUSTOM_"; +constexpr char kFpStartNode[] = "PROFILING_FP_START"; +constexpr char kBpEndNode[] = "PROFILING_BP_END"; +constexpr char kIterEndNode[] = "PROFILING_ITER_END"; +// PROFILING_CUSTOM_LOGID_START 3 +constexpr uint64_t kProfilingFpStartLogId = 1; +constexpr uint64_t kProfilingBpEndLogId = 2; +constexpr uint64_t kProfilingIterEndLogId = 255; +std::map> ProfilingUtils::graph_profiling_cnode_; +std::map> ProfilingUtils::graph_kernel_name_; +std::map>> ProfilingUtils::graph_point_; +uint32_t ProfilingUtils::custom_node_index_ = 1; + +ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull graph_ptr) { + MS_LOG(INFO) << "get env start"; + custom_node_index_ = 1; + auto &cnode_exec_order = graph_ptr->execution_order(); + ProfilingTraceInfo profiling_trace; + profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); + profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); + profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); + + for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { + std::string env_str = std::string(kCustomNode) + std::to_string(i); + const char *node_full_name = std::getenv(env_str.c_str()); + if (node_full_name == nullptr) { + break; + } + MS_LOG(INFO) << "Get profiling node:" << node_full_name; + profiling_trace.trace_custom_node.insert(node_full_name); + } + MS_LOG(INFO) << "get env end"; + GetTraceHccl(cnode_exec_order, NOT_NULL(&profiling_trace)); + + MS_LOG(INFO) << "[profiling]trace_begin:" << profiling_trace.trace_begin + << " trace_bp_end:" << profiling_trace.trace_bp_end + << " trace_netoutput:" << profiling_trace.trace_netoutput; + return profiling_trace; +} + +void ProfilingUtils::GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace) { + for (const auto &node : cnode_exec_order) { + if (AnfAlgo::IsCommunicationOp(node)) { + MS_EXCEPTION_IF_NULL(node); + profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); + MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); + } + } +} + +std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { + const char *trace_begin = std::getenv(kFpStartNode); + if (trace_begin != nullptr) { + return std::string(trace_begin); + } + + std::string fp_start_str; + std::set getnext_outputs; + GetCNodeOutputRealNode(kGetNextOpName, cnode_exec_order, NOT_NULL(&getnext_outputs)); + if (getnext_outputs.empty()) { + auto first_node = cnode_exec_order.front(); + MS_EXCEPTION_IF_NULL(first_node); + fp_start_str = first_node->fullname_with_scope(); + } else { + for (auto &cnode : cnode_exec_order) { + if (getnext_outputs.count(cnode->fullname_with_scope()) != 0) { + fp_start_str = cnode->fullname_with_scope(); + break; + } + } + } + return fp_start_str; +} + +void ProfilingUtils::GetCNodeOutputRealNode(const std::string &node_name, 
const std::vector &cnode_exec_order, + NotNull *> getnext_outputs) { + for (const auto &cnode : cnode_exec_order) { + MS_EXCEPTION_IF_NULL(cnode); + for (const auto &input : cnode->inputs()) { + auto prev_cnode = AnfAlgo::VisitKernel(input, 0); + if (!prev_cnode.first->isa()) { + continue; + } + if (AnfAlgo::GetCNodeName(prev_cnode.first) == node_name) { + getnext_outputs->insert(cnode->fullname_with_scope()); + MS_LOG(INFO) << "Find GetNext Output CNode:" << cnode->fullname_with_scope(); + } + } + } + if (getnext_outputs->empty()) { + MS_LOG(WARNING) << "GetNext not found"; + } +} + +std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { + const char *trace_bp_end = std::getenv(kBpEndNode); + + if (trace_bp_end != nullptr) { + return std::string(trace_bp_end); + } + std::string bp_end_str; + // Contain hccl kernel + auto iter = cnode_exec_order.rbegin(); + while (iter != cnode_exec_order.rend()) { + if (AnfAlgo::IsCommunicationOp(*iter)) { + // store communication op input nodes' name + std::set ar_input_node_names; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); + auto input_node = input_node_with_index.first; + ar_input_node_names.insert(input_node->fullname_with_scope()); + } + // start from previous node + ++iter; + // find input names in previous node + while (iter != cnode_exec_order.rend()) { + if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { + bp_end_str = (*iter)->fullname_with_scope(); + break; + } + ++iter; + } + break; + } + ++iter; + } + + if (bp_end_str.empty()) { + bp_end_str = GetGraphLastTbeKernelName(cnode_exec_order); + } + return bp_end_str; +} + +std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vector &cnode_exec_order) { + std::string last_tbe_kernel_name; + // find last tbe_kernel + for (auto iter = cnode_exec_order.rbegin(); iter != cnode_exec_order.rend(); ++iter) { + if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL) { + last_tbe_kernel_name = (*iter)->fullname_with_scope(); + break; + } + } + if (last_tbe_kernel_name.empty()) { + MS_LOG(WARNING) << "tbe kernel not found in graph"; + } + return last_tbe_kernel_name; +} + +std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { + const char *trace_netoutput = std::getenv(kIterEndNode); + return trace_netoutput == nullptr ? 
GetGraphLastTbeKernelName(cnode_exec_order) : std::string(trace_netoutput); +} + +NotNull ProfilingUtils::CreateProfilingCNode(const ProfilingContent &profiling_content, + NotNull graph_ptr) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + abstract::AbstractBasePtr type_none_abstract = std::make_shared(); + auto primitive = std::make_shared(ProfilingUtils::kProfiling); + std::vector inputs; + inputs.emplace_back(NewValueNode(primitive)); + CNodePtr cnode_ptr = graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(cnode_ptr); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), cnode_ptr.get()); + cnode_ptr->set_abstract(type_none_abstract); + // set attr + ValuePtr notify_value = MakeValue(profiling_content.notify); + ValuePtr trace_id_value = MakeValue(profiling_content.profiler_trace_id); + ValuePtr flags_value = MakeValue(profiling_content.flags); + AnfAlgo::SetNodeAttr(ProfilingUtils::kNotify, notify_value, cnode_ptr); + AnfAlgo::SetNodeAttr(ProfilingUtils::kProfilerTraceId, trace_id_value, cnode_ptr); + AnfAlgo::SetNodeAttr(ProfilingUtils::kFlags, flags_value, cnode_ptr); + return NOT_NULL(cnode_ptr); +} + +void ProfilingUtils::SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id) { + std::shared_ptr prof_desc_ptr = std::make_shared(node_name, point_id); + auto iter = graph_point_.find(graph_id); + if (iter == graph_point_.end()) { + std::vector> tmp_vect = {prof_desc_ptr}; + graph_point_.insert({graph_id, tmp_vect}); + } else { + iter->second.emplace_back(prof_desc_ptr); + } +} + +void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node, + const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; + ProfilingTraceJobId(anf_node, graph_ptr, kernel_list); + ProfilingContent fp_profiling_content = {false, kProfilingFpStartLogId, 0}; + auto fp_profiling_node = CreateProfilingCNodeWithStream(anf_node, fp_profiling_content, graph_ptr); + kernel_list->emplace_back(fp_profiling_node); + // insert ProfDesc + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingFpStartLogId); + } +} + +void ProfilingUtils::ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull graph_ptr, + NotNull *> kernel_list) { + MS_LOG(INFO) << "Profiling Match start"; + auto job_id = ProfilingManager::GetInstance().GetJobId(); + ProfilingContent job_profiling_context = {false, job_id, 0}; + auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); + kernel_list->emplace_back(job_profiling_node); +} + +CNodePtr ProfilingUtils::CreateProfilingCNodeWithStream(const mindspore::AnfNodePtr &anf_node, + const ProfilingContent &profiling_content, + NotNull graph_ptr) { + CNodePtr profiling_node = CreateProfilingCNode(profiling_content, graph_ptr); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), profiling_node.get()); + 
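+  // Keep the profiling node on the same stream as its anchor kernel: copy both the
+  // stream distinction label (above) and the stream id (below) from anf_node.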
AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(anf_node), profiling_node.get()); + return profiling_node; +} + +void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto iter = profiling_trace_info.trace_custom_node.find(anf_node->fullname_with_scope()); + if (iter == profiling_trace_info.trace_custom_node.end()) { + return; + } + MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); + // custom op profiling job start from 3. + auto custom_point_id = 2 * custom_node_index_ + 1; + ProfilingContent front_profiling_content = {false, custom_point_id, 0}; + CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); + kernel_list->insert(kernel_list->end() - 1, front_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id); + + ProfilingContent back_profiling_content = {false, custom_point_id + 1, 0}; + CNodePtr back_node = CreateProfilingCNodeWithStream(anf_node, back_profiling_content, graph_ptr); + kernel_list->insert(kernel_list->end(), back_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), custom_point_id + 1); + ++custom_node_index_; +} + +void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; + ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; + CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); + kernel_list->emplace_back(bp_end_node); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingBpEndLogId); + } +} + +void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list) { + MS_EXCEPTION_IF_NULL(anf_node); + auto full_scope_name = anf_node->fullname_with_scope(); + if (profiling_trace_info.trace_netoutput == full_scope_name) { + MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; + ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; + CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); + kernel_list->emplace_back(bp_kernel_ptr); + SaveProfilingPoint(graph_ptr->graph_id(), anf_node->fullname_with_scope(), kProfilingIterEndLogId); + } +} + +void ProfilingUtils::SetGraphKernelName(uint32_t graph_id, const std::vector &kernel_names) { + auto ret = graph_kernel_name_.try_emplace(graph_id, kernel_names); + if (!ret.second) { + MS_LOG(ERROR) << "[profiling]graph " << graph_id << " kernel names already exist"; + } +} + +void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector &profiling_cnode_list) { + auto ret = graph_profiling_cnode_.try_emplace(graph_id, profiling_cnode_list); + if (!ret.second) { + MS_LOG(ERROR) << "[profiling]graph " << graph_id << " profiling cnode list already exist"; + } +} + +bool ProfilingUtils::ValidComputeGraph(NotNull graph_ptr) { + for (const auto &node : graph_ptr->execution_order()) { + if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) { + 
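+      // a single TBE kernel is enough to mark this as a compute graph worth profiling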
return true;
+    }
+  }
+  return false;
+}
+
+void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
+                                         const std::vector<uint32_t> &stream_ids,
+                                         NotNull<const session::KernelGraph *> graph) {
+  if (!ValidComputeGraph(graph)) {
+    MS_LOG(WARNING) << "Not a valid compute graph:" << graph->graph_id();
+    return;
+  }
+
+  auto ret = graph_profiling_cnode_.find(graph->graph_id());
+  if (ret == graph_profiling_cnode_.end()) {
+    MS_LOG(ERROR) << "Graph id not found";
+    return;
+  }
+
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  TaskDescReporter task_reporter(context->device_id(), "vm.task_desc_info", ret->second);
+  task_reporter.set_task_ids(task_ids);
+  task_reporter.set_stream_ids(stream_ids);
+  task_reporter.ReportData();
+
+  GraphDescReporter graph_reporter(context->device_id(), "vm.graph_desc_info", ret->second);
+  graph_profiling_cnode_.erase(ret);
+  graph_reporter.ReportData();
+
+  // Report profiling point
+  auto point_iter = graph_point_.find(graph->graph_id());
+  if (point_iter == graph_point_.end()) {
+    MS_LOG(ERROR) << "Graph id not found in graph_point";
+    return;
+  }
+  PointReporter point_reporter(context->device_id(), "vm.point");
+  for (const auto &point : point_iter->second) {
+    point_reporter.AddReportData(point);
+  }
+  point_reporter.ReportData();
+}
+} // namespace ascend
+} // namespace device
+} // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h
new file mode 100644
index 0000000000..de8ff2ac39
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
+#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+#include "backend/session/kernel_graph.h"
+#include "utils/contract.h"
+#include "runtime/device/ascend/profiling/reporter/profiling_desc.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+struct ProfilingTraceInfo {
+  // the first op in execution order (like Cast or Four2Five), excluding tdt ops (GetNext ...)
+  std::string trace_begin;
+  // the first net_output (apply kernel) found from the graph outputs: fp -> net_output <- bp
+  std::string trace_bp_end;
+  // the last op in execution order (like Conv2DBackpropFilter)
+  std::string trace_netoutput;
+
+  // profiling specific ops, such as AllReduce
+  std::set<std::string> trace_custom_node;
+
+  // 1. insert profiling_trace_begin if profiling_trace_bp_end is not empty.
+  // 2. op launch gets task info through a callback function.
+  // 3. insert profiling_trace_bp_end.
+  // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty.
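+  //
+  // For illustration only (the node names below are hypothetical), a trained network
+  // might resolve to:
+  //   trace_begin       = "Default/network/Cast-op0"
+  //   trace_bp_end      = "Default/network/Conv2DBackpropFilter-op42"
+  //   trace_netoutput   = "Default/network/MakeTuple-op99"
+  //   trace_custom_node = { "Default/network/AllReduce-op7" }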
+
+  bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); }
+};
+
+struct ProfilingContent {
+  // true - send data from device to host and finish profiling
+  bool notify;
+  uint64_t profiler_trace_id;
+  uint32_t flags;
+};
+
+class ProfilingUtils {
+ public:
+  ProfilingUtils() = default;
+  ~ProfilingUtils() = default;
+
+  // Insert the job_id profiling node and the fp_start profiling node.
+  // Job_id is read from the environment and should be a number greater than 255.
+  // The fp_start node should be inserted at the start of the network; its log_id is hard-coded to 1.
+  static void ProfilingTraceFpStart(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
+                                    NotNull<session::KernelGraph *> graph_ptr,
+                                    NotNull<std::vector<CNodePtr> *> kernel_list);
+
+  static void ProfilingTraceJobId(const AnfNodePtr &anf_node, NotNull<session::KernelGraph *> graph_ptr,
+                                  NotNull<std::vector<CNodePtr> *> kernel_list);
+
+  // Insert the net output profiling node, which tells the device to stop profiling.
+  // The notify field in ProfilingContent should be 'true', which tells the device to send data to the host.
+  static void ProfilingTraceEnd(const AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
+                                NotNull<session::KernelGraph *> graph_ptr,
+                                NotNull<std::vector<CNodePtr> *> kernel_list);
+
+  // Insert the bp_end profiling node, which should be inserted after the last backpropagation CNode in the network.
+  static void ProfilingTraceBpEnd(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info,
+                                  NotNull<session::KernelGraph *> graph_ptr,
+                                  NotNull<std::vector<CNodePtr> *> kernel_list);
+
+  // Map a graph id to the kernels' names in the graph.
+  static void SetGraphProfilingCNode(uint32_t graph_id, const std::vector<CNodePtr> &profiling_cnode_list);
+
+  static void SetGraphKernelName(uint32_t graph_id, const std::vector<std::string> &kernel_names);
+
+  // Map task_id to kernel name so the device can attribute the time cost of a specific kernel.
+  // The device measures the time cost of the task marked by a task id,
+  // but we need (kernel name, time cost) pairs.
+  static void ReportProfilingData(const std::vector<uint32_t> &task_ids, const std::vector<uint32_t> &stream_ids,
+                                  NotNull<const session::KernelGraph *> graph);
+
+  // Get profiling trace points from the environment.
+  // export PROFILING_FP_START='full name of the first cnode to execute'
+  // export PROFILING_BP_END='full name of the last backpropagation cnode to execute'
+  // export PROFILING_ITER_END='full name of the last cnode in the graph to execute'
+  // For other cnodes, like AllReduce: export PROFILING_CUSTOM_1='full name of the AllReduce cnode'
+  // For GetNext: export PROFILING_CUSTOM_2='full name of the GetNext cnode'
+  // The index i in PROFILING_CUSTOM_i should start from 1 and be contiguous.
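+  //
+  // A minimal sketch of the lookup loop this contract implies (it mirrors the
+  // GetProfilingTraceFromEnv implementation in profiling_utils.cc; the helper name
+  // CollectCustomNodes is illustrative and not part of this patch):
+  //
+  //   std::set<std::string> CollectCustomNodes(uint32_t max_num) {
+  //     std::set<std::string> nodes;
+  //     for (uint32_t i = 1; i <= max_num; ++i) {
+  //       std::string key = "PROFILING_CUSTOM_" + std::to_string(i);
+  //       const char *value = std::getenv(key.c_str());
+  //       if (value == nullptr) {
+  //         break;  // contiguous numbering: stop at the first missing index
+  //       }
+  //       nodes.insert(value);
+  //     }
+  //     return nodes;
+  //   }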
+ static ProfilingTraceInfo GetProfilingTraceFromEnv(NotNull graph_ptr); + + // Insert two profiling trace points, one in front and one behind + static void ProfilingCustomOp(const mindspore::AnfNodePtr &anf_node, const ProfilingTraceInfo &profiling_trace_info, + NotNull graph_ptr, + NotNull *> kernel_list); + + static std::map> graph_kernel_name() { return graph_kernel_name_; } + + inline static constexpr char kProfiling[] = "Profiling"; + inline static constexpr char kNotify[] = "notify"; + inline static constexpr char kProfilerTraceId[] = "profiler_trace_id"; + inline static constexpr char kFlags[] = "flags"; + + private: + static NotNull CreateProfilingCNode(const ProfilingContent &profiling_content, + NotNull graph_ptr); + static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, + NotNull graph_ptr); + static std::string GetTraceBegin(const std::vector &cnode_exec_order); + static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); + static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); + static std::string GetGraphLastTbeKernelName(const std::vector &cnode_exec_order); + static void GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace); + static void GetCNodeOutputRealNode(const std::string &node_name, const std::vector &cnode_exec_order, + NotNull *> getnext_outputs); + + static bool ValidComputeGraph(NotNull graph_ptr); + static void SaveProfilingPoint(uint32_t graph_id, const std::string &node_name, uint32_t point_id); + + // graph id --> (kernel name list) + static std::map> graph_profiling_cnode_; + static std::map> graph_kernel_name_; + static std::map>> graph_point_; + static uint32_t custom_node_index_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_PROFILING_UTILS_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc new file mode 100644 index 0000000000..87e2bbcb06 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" +#include "runtime/device/ascend/profiling/plugin_impl.h" +#include "utils/log_adapter.h" + +constexpr size_t kReportMaxLen = 2048; + +namespace mindspore { +namespace device { +namespace ascend { +DescReporter::~DescReporter() = default; + +void DescReporter::ReportByLine(const std::string &data, const std::string &file_name) const { + auto reporter = PluginImpl::GetPluginReporter(); + MS_EXCEPTION_IF_NULL(reporter); + + auto tot_size = data.size(); + size_t cur_size = 0; + while (cur_size < tot_size) { + size_t remain_size = tot_size - cur_size; + size_t report_size = std::min(remain_size, kReportMaxLen); + + Msprof::Engine::ReporterData report_data{}; + report_data.deviceId = device_id_; + report_data.dataLen = report_size; + report_data.data = (unsigned char *)data.c_str() + cur_size; + auto ret = memcpy_s(report_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, file_name.c_str(), file_name.length()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Memcpy_s report data tag failed"; + } + auto report_ret = reporter->Report(&report_data); + if (report_ret != 0) { + MS_LOG(EXCEPTION) << "Report data failed"; + } + if (report_size == 0) { + MS_LOG(WARNING) << "Report_size is 0"; + break; + } + cur_size += report_size; + } +} + +void DescReporter::ReportAllLine() { + for (const auto &desc : prof_desc_list_) { + auto data = desc->ToString(); + ReportByLine(data, file_name_); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h new file mode 100644 index 0000000000..f25c64ce05 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/desc_reporter.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ + +#include +#include +#include +#include +#include "toolchain/prof_reporter.h" +#include "runtime/device/ascend/profiling/reporter/profiling_desc.h" +#include "utils/contract.h" +#include "backend/session/kernel_graph.h" + +namespace mindspore { +namespace device { +namespace ascend { +class DescReporter { + public: + virtual ~DescReporter() = 0; + DescReporter(int device_id, std::string file_name) : device_id_(device_id), file_name_(std::move(file_name)) {} + + virtual void ReportData() = 0; + + protected: + void ReportByLine(const std::string &data, const std::string &file_name) const; + void ReportAllLine(); + + int device_id_; + std::string file_name_; + std::vector> prof_desc_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc new file mode 100644 index 0000000000..5c028986d4 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <utility>
+#include <vector>
+#include "runtime/device/ascend/profiling/reporter/graph_desc_reporter.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+void GraphDescReporter::ReportData() {
+  for (const auto &node : cnode_list_) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
+      MS_LOG(WARNING) << "Skip non TBE/AKG kernel";
+      continue;
+    }
+    std::vector<DataElement> input_data_list;
+    std::vector<DataElement> output_data_list;
+    auto op_name = node->fullname_with_scope();
+    auto op_type = AnfAlgo::GetCNodeName(node);
+    auto input_size = AnfAlgo::GetInputTensorNum(node);
+    for (size_t i = 0; i < input_size; ++i) {
+      auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
+      auto input_node = input_node_with_index.first;
+      auto input_index = input_node_with_index.second;
+      DataElement element{};
+      element.index_ = i;
+      element.data_type_ = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
+      element.data_format_ = AnfAlgo::GetOutputFormat(input_node, input_index);
+      element.data_shape_ = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
+      input_data_list.emplace_back(element);
+    }
+
+    auto output_size = AnfAlgo::GetOutputTensorNum(node);
+    for (size_t i = 0; i < output_size; ++i) {
+      DataElement element{};
+      element.index_ = i;
+      element.data_type_ = AnfAlgo::GetOutputDeviceDataType(node, i);
+      element.data_format_ = AnfAlgo::GetOutputFormat(node, i);
+      element.data_shape_ = AnfAlgo::GetOutputDeviceShape(node, i);
+      output_data_list.emplace_back(element);
+    }
+
+    auto graph_desc = std::make_shared<GraphDesc>(op_name, op_type, input_data_list, output_data_list);
+    prof_desc_list_.emplace_back(graph_desc);
+  }
+  ReportAllLine();
+}
+} // namespace ascend
+} // namespace device
+} // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h
new file mode 100644
index 0000000000..531f122cde
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/graph_desc_reporter.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ + +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class GraphDescReporter : public DescReporter { + public: + GraphDescReporter(uint32_t device_id, const std::string &file_name, std::vector cnode_list) + : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} + ~GraphDescReporter() override = default; + void ReportData() override; + + private: + std::vector cnode_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_GRAPH_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc new file mode 100644 index 0000000000..42a1b4c286 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/profiling/reporter/point_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +void PointReporter::ReportData() { ReportAllLine(); } + +void PointReporter::AddReportData(const std::shared_ptr &prof_desc) { + prof_desc_list_.emplace_back(prof_desc); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h new file mode 100644 index 0000000000..c24535f4ec --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/point_reporter.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ + +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class PointReporter : public DescReporter { + public: + PointReporter(uint32_t device_id, const std::string &file_name) : DescReporter(device_id, file_name) {} + ~PointReporter() override = default; + void ReportData() override; + void AddReportData(const std::shared_ptr &prof_desc); +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_POINT_REPORTER_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc new file mode 100644 index 0000000000..4aec72472c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/profiling_desc.h" + +namespace mindspore { +namespace device { +namespace ascend { +std::string TaskDesc::ToString() { + std::string out = op_name_; + out.append(" ") + .append(std::to_string(block_dim_)) + .append(" ") + .append(std::to_string(task_id_)) + .append(" ") + .append(std::to_string(stream_id_)) + .append("\n"); + return out; +} + +std::string GraphDesc::ToString() { + std::string desc; + desc.append("op_name:").append(op_name_).append(" op_type:").append(op_type_); + int input_id = 0; + for (const auto &element : input_data_list_) { + desc.append(" input_id:") + .append(std::to_string(input_id++)) + .append(" input_format:") + .append(element.data_format_) + .append(" input_data_type:") + .append(std::to_string(element.data_type_)) + .append(" input_shape:") + .append(DataShapeToString(element.data_shape_)); + } + + input_id = 0; + for (const auto &element : output_data_list_) { + desc.append(" output_id:") + .append(std::to_string(input_id++)) + .append(" output_format:") + .append(element.data_format_) + .append(" output_data_type:") + .append(std::to_string(element.data_type_)) + .append(" output_shape:") + .append((DataShapeToString(element.data_shape_))); + } + + desc.append("\n"); + + return desc; +} + +std::string PointDesc::ToString() { + std::string desc; + desc.append(std::to_string(point_id_)).append(" ").append(op_name_).append("\n"); + return desc; +} + +std::string GraphDesc::DataShapeToString(const std::vector &shape) { + std::ostringstream oss; + oss << "\""; + if (!shape.empty()) { + std::copy(shape.begin(), shape.end() - 1, std::ostream_iterator(oss, ",")); + oss << shape.back(); + } + oss << "\""; + return oss.str(); +} +} // namespace ascend +} // namespace device 
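+// For reference: GraphDesc::DataShapeToString({2, 3, 4}) above serializes to "2,3,4"
+// including the surrounding double quotes, and an empty shape serializes to "".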
+} // namespace mindspore
diff --git a/mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h
similarity index 100%
rename from mindspore/ccsrc/device/ascend/profiling/reporter/profiling_desc.h
rename to mindspore/ccsrc/runtime/device/ascend/profiling/reporter/profiling_desc.h
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc
new file mode 100644
index 0000000000..26d722aa1a
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.cc
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include "runtime/device/ascend/profiling/reporter/task_desc_reporter.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/ascend_kernel_mod.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+void TaskDescReporter::ReportData() {
+  MS_LOG(INFO) << "cnode_list.size()=" << cnode_list_.size() << " task_ids_.size()=" << task_ids_.size();
+  if (cnode_list_.size() != task_ids_.size()) {
+    MS_LOG(ERROR) << "cnode list size is not equal to task ids size";
+    return;
+  }
+
+  size_t task_index = 0;
+  for (const auto &node : cnode_list_) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (AnfAlgo::GetKernelType(node) != TBE_KERNEL && AnfAlgo::GetKernelType(node) != AKG_KERNEL) {
+      MS_LOG(WARNING) << "Skip non TBE/AKG kernel";
+      ++task_index;
+      continue;
+    }
+    auto kernel_mod = AnfAlgo::GetKernelMod(node);
+    auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
+    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
+    // Both id vectors are indexed by task_index, so validate it against each of them
+    CheckStreamTaskValid(task_index, task_index);
+    auto desc_ptr = std::make_shared<TaskDesc>(node->fullname_with_scope(), task_ids_[task_index],
+                                               ascend_kernel_mod->block_dim(), stream_ids_[task_index]);
+    prof_desc_list_.emplace_back(desc_ptr);
+    ++task_index;
+  }
+  ReportAllLine();
+}
+
+void TaskDescReporter::CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id) {
+  if (task_id >= task_ids_.size() || stream_id >= stream_ids_.size()) {
+    MS_LOG(EXCEPTION) << "Index invalid. task_id:" << task_id << ", task_ids.size:" << task_ids_.size()
+                      << ", stream_id:" << stream_id << ", stream_ids.size:" << stream_ids_.size();
+  }
+}
+} // namespace ascend
+} // namespace device
+} // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h
new file mode 100644
index 0000000000..51526735a9
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/profiling/reporter/task_desc_reporter.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ + +#include +#include +#include +#include "runtime/device/ascend/profiling/reporter/desc_reporter.h" + +namespace mindspore { +namespace device { +namespace ascend { +class TaskDescReporter : public DescReporter { + public: + TaskDescReporter(int device_id, const std::string &file_name, std::vector cnode_list) + : DescReporter(device_id, file_name), cnode_list_(std::move(cnode_list)) {} + ~TaskDescReporter() override = default; + void ReportData() override; + void set_task_ids(const std::vector &task_ids) { task_ids_ = task_ids; } + void set_stream_ids(const std::vector &stream_ids) { stream_ids_ = stream_ids; } + + private: + std::vector task_ids_; + std::vector stream_ids_; + void CheckStreamTaskValid(uint32_t task_id, uint32_t stream_id); + std::vector cnode_list_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_ASCEND_PROFILING_REPORTER_TASK_DESC_REPORTER_H_ diff --git a/mindspore/ccsrc/device/ascend/readme.md b/mindspore/ccsrc/runtime/device/ascend/readme.md similarity index 100% rename from mindspore/ccsrc/device/ascend/readme.md rename to mindspore/ccsrc/runtime/device/ascend/readme.md diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc new file mode 100644 index 0000000000..dba71edfd3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/tasksink/runtime_utils.h" + +#include + +#include "hccl/hcom.h" +#include "utils/log_adapter.h" +#include "utils/utils.h" + +constexpr auto kHcomBroadcast = "hcom_broadcast_"; +constexpr auto kHcomAllGather = "hcom_all_gather_"; +constexpr auto kHcomAllReduce = "hcom_all_reduce_"; +constexpr auto kHcomReduceScatter = "hcom_reduce_scatter_"; +constexpr auto kUnderline = "_"; +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) { + hcclResult_t ret = hcom_bind_model(model, stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast(ret); + return false; + } + return true; +} + +bool RuntimeUtils::HcomUnbindModel(rtModel_t model) { + hcclResult_t ret = hcom_unbind_model(model); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast(ret); + return false; + } + return true; +} + +bool RuntimeUtils::HcomDistribute(const std::shared_ptr &task_info, rtStream_t stream) { + MS_LOG(INFO) << "hccl distribute start"; + MS_EXCEPTION_IF_NULL(task_info); + hcclResult_t ret; + static uint32_t task_counter = 0; + auto hccl_group = task_info->group(); + if (task_info->hccl_type() == kBroadcastOpName) { + // call hcom broadcast interface to run op + const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast(task_info->count()), + static_cast(task_info->data_type()), static_cast(task_info->root_id()), + hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast(ret); + return false; + } + } else if (task_info->hccl_type() == kAllGatherOpName) { + // call hcom allgather interface to run op + const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret; + return false; + } + } else if (task_info->hccl_type() == kAllReduceOpName) { + // call hcom allreduce interface to run op + const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret; + return false; + } + } else if (task_info->hccl_type() == kReduceScatterOpName) { + // call hcom reducescatter interface to run op + const string tag_reduce_scatter = + kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0); + ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(), + static_cast(task_info->count()), static_cast(task_info->data_type()), + static_cast(task_info->op_type()), hccl_group.c_str(), stream); + if (ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret; + 
return false; + } + } + return true; +} +} // namespace tasksink +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/tasksink/runtime_utils.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h similarity index 100% rename from mindspore/ccsrc/device/ascend/tasksink/runtime_utils.h rename to mindspore/ccsrc/runtime/device/ascend/tasksink/runtime_utils.h diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc new file mode 100644 index 0000000000..5aeb932105 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/tasksink/task_generator.h" + +#include +#include "backend/kernel_compiler/task_stream.h" +#include "utils/context/ms_context.h" +#include "common/utils.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace tasksink { +bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *task_info_list, + uint32_t graph_id) { + MS_LOG(INFO) << "GenTasks start..."; + MS_EXCEPTION_IF_NULL(task_info_list); + // Traverse graph applykernel list and run + if (!LaunchAllKernel(anf_node_list, task_info_list, graph_id)) { + MS_LOG(ERROR) << "LaunchAllKernel failed"; + return false; + } + MS_LOG(INFO) << "GenTasks end..."; + return true; +} + +void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + MS_EXCEPTION_IF_NULL(kernel_inputs); + // akg process + // set atomic clean addr + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, anf_node_ptr)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicOutputIndexs); + auto graph = anf_node_ptr->func_graph(); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_users = manager->node_users(); + if (node_users[anf_node_ptr].empty()) { + MS_LOG(EXCEPTION) << "Node users of " << anf_node_ptr->ToString() << " is empty."; + } + auto depend_node = node_users[anf_node_ptr].pop().first; + if (!IsPrimitiveCNode(depend_node, prim::kPrimDepend)) { + MS_LOG(EXCEPTION) << "Checking Depend node failed"; + } + if (node_users[depend_node].empty()) { + MS_LOG(EXCEPTION) << "Node users of " << depend_node->ToString() << " is empty."; + } + auto post_node = node_users[depend_node].pop().first; + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(post_node, index); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + input->size = device_address->size_; + kernel_inputs->push_back(input); 
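+      // each cleaned output buffer of the post node becomes an input of the AtomicAddClean kernel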
+    }
+    MS_LOG(DEBUG) << "AtomicAddClean clean output size: " << clean_output_indexs.size();
+  }
+}
+
+void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) {
+  MS_EXCEPTION_IF_NULL(anf_node_ptr);
+  MS_EXCEPTION_IF_NULL(kernel_inputs);
+  if (anf_node_ptr->inputs().size() != 2) {
+    LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs);
+    return;
+  }
+  MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]);
+  auto pre_node = (anf_node_ptr->inputs()[1])->cast<CNodePtr>();
+  // set clean output addr
+  if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
+    auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
+    for (auto index : clean_output_indexs) {
+      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
+      kernel::AddressPtr input = std::make_shared<kernel::Address>();
+      MS_EXCEPTION_IF_NULL(input);
+      input->addr = device_address->ptr_;
+      MS_EXCEPTION_IF_NULL(input->addr);
+      input->size = device_address->size_;
+      kernel_inputs->push_back(input);
+    }
+    MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size();
+  }
+  // set clean workspace address
+  if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
+    auto clean_workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
+    for (const auto &index : clean_workspace_indexs) {
+      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
+      kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
+      MS_EXCEPTION_IF_NULL(workspace);
+      workspace->addr = device_address->ptr_;
+      MS_EXCEPTION_IF_NULL(workspace->addr);
+      workspace->size = device_address->size_;
+      kernel_inputs->push_back(workspace);
+    }
+  }
+  auto clear_mems = AnfAlgo::GetNodeAttr<std::vector<size_t>>(anf_node_ptr, kAttrAtomicAddMemSize);
+  if (kernel_inputs->size() != clear_mems.size()) {
+    MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size is not equal to clear memory size, kernel_inputs size:"
+                      << kernel_inputs->size() << ", clean mem size:" << clear_mems.size();
+  }
+}
+
+bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id,
+                                 std::vector<TaskInfoPtr> *task_info_list) {
+  MS_EXCEPTION_IF_NULL(task_info_list);
+  MS_EXCEPTION_IF_NULL(anf_node_ptr);
+  AddressPtrList kernel_inputs;
+  AddressPtrList kernel_workspaces;
+  AddressPtrList kernel_outputs;
+  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node_ptr);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  kernel_mod->set_kernel_name(anf_node_ptr->fullname_with_scope());
+  if (AnfAlgo::GetCNodeName(anf_node_ptr) != kAtomicAddrCleanOpName) {
+    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) {
+      auto real_input_index = AnfAlgo::GetRealInputIndex(anf_node_ptr, i);
+      auto device_address = AnfAlgo::GetPrevNodeOutputAddr(anf_node_ptr, real_input_index);
+      AddressPtr input = std::make_shared<Address>();
+      input->addr = device_address->ptr_;
+      input->size = device_address->size_;
+      kernel_inputs.push_back(input);
+    }
+
+    for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node_ptr); ++i) {
+      auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i);
+      AddressPtr output = std::make_shared<Address>();
+      output->addr = it->ptr_;
+      output->size = it->size_;
+      kernel_outputs.push_back(output);
+    }
+
+    for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
+      auto device_address = AnfAlgo::GetWorkspaceAddr(anf_node_ptr, i);
+      kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
+      MS_EXCEPTION_IF_NULL(workspace);
+      workspace->addr = device_address->ptr_;
+      workspace->size = device_address->size_;
+      kernel_workspaces.push_back(workspace);
+    }
+  } else {
+    LaunchAddrCleanKernel(anf_node_ptr, &kernel_inputs);
+  }
+
+  auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
+  MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
+  std::vector<TaskInfoPtr> task_info_ptrs =
+    ascend_kernel_mod->GenTask(kernel_inputs, kernel_workspaces, kernel_outputs, stream_id);
+  task_info_list->insert(task_info_list->end(), task_info_ptrs.begin(), task_info_ptrs.end());
+  return true;
+}
+
+bool TaskGenerator::LaunchAllKernel(const std::vector<CNodePtr> &anf_node_list,
+                                    std::vector<TaskInfoPtr> *task_info_list, uint32_t graph_id) {
+  uint32_t current_op_index = 0;
+  std::vector<CNodePtr> profiling_cnode_list;
+  std::vector<std::string> kernel_name_list;
+  for (const auto &anf_node_ptr : anf_node_list) {
+    MS_EXCEPTION_IF_NULL(anf_node_ptr);
+    size_t old_size = task_info_list->size();
+    uint32_t stream_id = AnfAlgo::GetStreamId(anf_node_ptr);
+    MS_LOG(INFO) << "Task gen launch begin, current_op_idx:" << current_op_index
+                 << " name:" << anf_node_ptr->fullname_with_scope() << ", stream id:" << stream_id;
+    if (!LaunchKernel(anf_node_ptr, stream_id, task_info_list)) {
+      MS_LOG(ERROR) << "LaunchKernel failed.";
+      return false;
+    }
+    for (size_t i = old_size; i < task_info_list->size(); ++i) {
+      profiling_cnode_list.emplace_back(anf_node_ptr);
+      kernel_name_list.emplace_back(anf_node_ptr->fullname_with_scope());
+    }
+    current_op_index++;
+  }
+
+  ProfilingUtils::SetGraphKernelName(graph_id, kernel_name_list);
+  if (ProfilingManager::GetInstance().IsProfiling()) {
+    ProfilingUtils::SetGraphProfilingCNode(graph_id, profiling_cnode_list);
+  }
+  return true;
+}
+} // namespace tasksink
+} // namespace ascend
+} // namespace device
+} // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h
new file mode 100644
index 0000000000..134dec48b6
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.h
@@ -0,0 +1,61 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_
+#define MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include "runtime/device/kernel_runtime.h"
+#include "ir/anf.h"
+#include "backend/kernel_compiler/ascend_kernel_mod.h"
+#include "framework/ge_runtime/task_info.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+namespace tasksink {
+using mindspore::kernel::Address;
+using mindspore::kernel::AddressPtr;
+using AddressPtrList = std::vector<AddressPtr>;
+using ge::model_runner::TaskInfo;
+using TaskInfoPtr = std::shared_ptr<TaskInfo>;
+class TaskGenerator {
+ public:
+  TaskGenerator() = default;
+  ~TaskGenerator() = default;
+  TaskGenerator(const TaskGenerator &in) = delete;
+  TaskGenerator &operator=(const TaskGenerator &in) = delete;
+
+  static bool GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *task_info_list,
+                       uint32_t graph_id);
+
+ private:
+  static void LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs);
+  static void LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs);
+  static bool LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_id,
+                           std::vector<TaskInfoPtr> *task_info_list);
+  static bool LaunchAllKernel(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *task_info_list,
+                              uint32_t graph_id);
+};
+}  // namespace tasksink
+}  // namespace ascend
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_ASCEND_TASK_TASK_BUILD_H_
diff --git a/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc b/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc
new file mode 100644
index 0000000000..cfd9b0fbdf
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/convert_tensor_utils.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "runtime/device/convert_tensor_utils.h"
+#include
+namespace mindspore {
+namespace device {
+void HalfToFloat(void *dst, const void *src, size_t elem_num) {
+  auto half_data = static_cast<const Eigen::half *>(src);
+  auto float_data = static_cast<float *>(dst);
+  for (size_t i = 0; i < elem_num; ++i) {
+    float tmp = Eigen::half_impl::half_to_float(half_data[i]);
+    float_data[i] = tmp;
+  }
+}
+
+void FloatToHalf(void *dst, const void *src, size_t elem_num) {
+  auto float_data = static_cast<const float *>(src);
+  auto half_data = static_cast<Eigen::half *>(dst);
+  for (size_t i = 0; i < elem_num; ++i) {
+    half_data[i] = Eigen::half(float_data[i]);
+  }
+}
+
+void DoubleToFloat(void *dst, const void *src, size_t elem_num) {
+  auto double_data = static_cast<const double *>(src);
+  auto float_data = static_cast<float *>(dst);
+  for (size_t i = 0; i < elem_num; ++i) {
+    float_data[i] = static_cast<float>(double_data[i]);
+  }
+}
+
+void FloatToDouble(void *dst, const void *src, size_t elem_num) {
+  auto float_data = static_cast<const float *>(src);
+  auto double_data = static_cast<double *>(dst);
+  for (size_t i = 0; i < elem_num; ++i) {
+    double_data[i] = static_cast<double>(float_data[i]);
+  }
+}
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/convert_tensor_utils.h b/mindspore/ccsrc/runtime/device/convert_tensor_utils.h
similarity index 100%
rename from mindspore/ccsrc/device/convert_tensor_utils.h
rename to mindspore/ccsrc/runtime/device/convert_tensor_utils.h
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc
new file mode 100644
index 0000000000..92269233bd
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "runtime/device/cpu/cpu_device_address.h"
+#include <vector>
+#include "runtime/device/convert_tensor_utils.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size_t size, TypeId type,
+                                        void *host_ptr) const {
+  if (ptr_ == nullptr) {
+    MS_LOG(ERROR) << "The pointer ptr_ is null!";
+    return false;
+  }
+
+  if (host_ptr == ptr_) {
+    MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored.";
+    return true;
+  }
+
+  if (type == type_id_) {
+    auto ret_code = memcpy_s(host_ptr, size, ptr_, size_);
+    if (ret_code != EOK) {
+      MS_LOG(ERROR) << "Failed to copy tensor!";
+      return false;
+    }
+  } else if (type == kNumberTypeFloat16) {
+    FloatToHalf(host_ptr, ptr_, size / 2);
+  } else if (type == kNumberTypeFloat64) {
+    FloatToDouble(host_ptr, ptr_, size / sizeof(double));
+  } else {
+    MS_LOG(ERROR) << "Types do not match. Device type: " << TypeIdLabel(type_id_)
+                  << ", host type: " << TypeIdLabel(type) << "!";
+    return false;
+  }
+  return true;
+}
+
+bool CPUDeviceAddress::SyncHostToDevice(const std::vector<int> & /*shape*/, size_t size, TypeId type,
+                                        const void *host_ptr) const {
+  if (type == kNumberTypeFloat16) {
+    HalfToFloat(ptr_, host_ptr, size / 2);
+  } else if (type == kNumberTypeFloat64) {
+    DoubleToFloat(ptr_, host_ptr, size / sizeof(double));
+  }
+  return true;
+}
+}  // namespace cpu
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h
new file mode 100644
index 0000000000..63cf171fa2
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_
+#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_
+
+#include <string>
+#include <vector>
+#include "runtime/device/device_address.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+class CPUDeviceAddress : public DeviceAddress {
+ public:
+  CPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
+
+  CPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
+      : DeviceAddress(ptr, size, format, type_id) {}
+
+  ~CPUDeviceAddress() override = default;
+
+  bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const override;
+  bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type, const void *host_ptr) const override;
+  DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; }
+};
+}  // namespace cpu
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_CPU_CPU_DEVICE_ADDRESS_H_
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
new file mode 100644
index 0000000000..d2e41a1fbd
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -0,0 +1,324 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "runtime/device/cpu/cpu_kernel_runtime.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "utils/context/ms_context.h" +#include "utils/config_manager.h" +#include "utils/profile.h" +#include "common/utils.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/session_basic.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +namespace device { +namespace cpu { +const size_t INIT_NODE_REF = 1; +namespace { +TypeId GetCPUSupportOutputTypeId(const TypeId type_id) { + TypeId support_type_id = type_id; + if (type_id == kNumberTypeUInt32) { + support_type_id = kNumberTypeInt32; + } + if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 || + type_id == kNumberTypeFloat64) { + support_type_id = kNumberTypeFloat32; + } + if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) { + MS_LOG(EXCEPTION) << "Check output type failed."; + } + return support_type_id; +} +} // namespace + +void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { + AssignValueNodeAddress(kernel_graph); + AssignInputNodeAddress(kernel_graph); + AssignKernelOutputAddress(kernel_graph); + resource_manager_.MemPlan(kernel_graph); + resource_manager_.MemMalloc(kernel_graph); +} + +void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + size_t type_size = sizeof(float); + for (auto &item_node : kernel_graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(item_node); + if (item_node->isa()) { + auto value_node = item_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (!node_value->isa()) { + continue; + } + auto tensor = node_value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + std::vector data_shape = tensor->shape(); + size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); + DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); + MS_EXCEPTION_IF_NULL(address); + if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { + address->ptr_ = tensor->data_c(); + } else { + address->ptr_ = resource_manager_.MemMalloc(tensor_size); + if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "Value node sync host to device failed!"; + } + } + address->ref_count_ = INIT_NODE_REF; + AnfAlgo::SetOutputAddr(address, 0, item_node.get()); + } + } +} + +void CPUKernelRuntime::AssignInputNodeAddress(const session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + size_t type_size = sizeof(float); + for (auto &item : kernel_graph->inputs()) { + MS_EXCEPTION_IF_NULL(item); + if (item->isa()) { + auto output_num = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_num; index++) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + std::vector fmt_shape = AnfAlgo::GetOutputDeviceShape(item, index); + size_t tensor_size = + fmt_shape.empty() ? 
type_size + : std::accumulate(fmt_shape.begin(), fmt_shape.end(), type_size, std::multiplies()); + auto format = AnfAlgo::GetOutputFormat(item, index); + auto address = CreateDeviceAddress(nullptr, tensor_size, format, output_type_id); + AnfAlgo::SetOutputAddr(address, index, item.get()); + } + } + } +} + +void CPUKernelRuntime::AssignKernelOutputAddress(const session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernels = kernel_graph->execution_order(); + for (auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + auto output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i, + kernel.get()); + } + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32), + i, kernel.get()); + } + } +} + +DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) { + return std::make_shared(device_ptr, device_size, format, type_id); +} + +tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, + std::set *bound_addresses, + std::vector *need_sync_outputs) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(bound_addresses); + MS_EXCEPTION_IF_NULL(need_sync_outputs); + size_t output_size = AnfAlgo::GetOutputTensorNum(node); + if (index >= output_size) { + MS_LOG(EXCEPTION) << "Invalid input index " << index; + } + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + auto shape = AnfAlgo::GetOutputInferShape(node, index); + std::vector temp_shape; + (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); + type_id = GetCPUSupportOutputTypeId(type_id); + tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + MS_EXCEPTION_IF_NULL(tensor); + if (bound_addresses->find(address) != bound_addresses->end()) { + tensor->set_device_address(address); + need_sync_outputs->emplace_back(tensor); + } else { + address->ptr_ = tensor->data_c(); + address->ref_count_ = INIT_NODE_REF; + (void)bound_addresses->insert(address); + } + tensor->set_dirty(false); + return tensor; +} + +BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, + const std::unordered_map &input_map, + std::set *bound_addresses, + std::vector *need_sync_outputs) { + auto &input_node = kernel_with_index.first; + auto index = kernel_with_index.second; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + auto node = input_node->cast(); + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::GetCNodeName(input_node) == prim::kPrimMakeTuple->name()) { + VectorRef ret; + for (size_t i = 1; i < node->inputs().size(); i++) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); + auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); + ret.push_back(out); + } + return ret; + } + return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); + } else if 
(input_node->isa() || input_node->isa()) { + auto iter = input_map.find(input_node.get()); + if (iter != input_map.end()) { + return iter->second; + } + } + return BaseRef(); +} + +void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, + const std::vector &inputs, VectorRef *outputs, + std::vector *need_sync_outputs) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(outputs); + // bind input ptr + auto &input_nodes = kernel_graph->inputs(); + if (input_nodes.size() != inputs.size()) { + MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; + } + std::unordered_map input_map; + size_t input_idx = 0; + for (auto &item : input_nodes) { + MS_EXCEPTION_IF_NULL(item); + input_map[item.get()] = inputs[input_idx]; + if (item->isa()) { + auto address = AnfAlgo::GetMutableOutputAddr(item, 0); + auto tensor = inputs[input_idx]; + auto tensor_address = tensor->device_address(); + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor_address != nullptr && tensor_address != address) { + (void)tensor->data_sync(); + } + std::vector data_shape = tensor->shape(); + size_t tensor_size = + std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); + if (tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { + address->ptr_ = tensor->data_c(); + } else { + address->ptr_ = resource_manager_.MemMalloc(tensor_size); + if (!address->SyncHostToDevice(data_shape, LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; + } + tensor->set_dirty(true); + } + address->ref_count_ = INIT_NODE_REF; + tensor->set_device_address(address); + } + input_idx++; + } + // new output and bind ptr + std::set bound_addresses; + auto output_nodes = kernel_graph->outputs(); + for (const auto &item : output_nodes) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); + auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); + outputs->push_back(std::move(out)); + } +} + +void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector *input_list) { + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(input_list); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + if (address->ptr_ == nullptr) { + address->ptr_ = resource_manager_.MemMalloc(address->size_); + } + MS_EXCEPTION_IF_NULL(address->ptr_); + input->addr = address->ptr_; + input->size = address->size_; + input_list->push_back(input); +} + +void CPUKernelRuntime::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.IncreaseSummaryRefCount(summary_outputs); +} + +void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + resource_manager_.DecreaseSummaryRefCount(summary_outputs); +} + +bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + resource_manager_.IncreaseAddressRefCount(kernel_graph); + + auto kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { +#ifdef ENABLE_PROFILE + double start_time = GetTime(); +#endif + std::vector kernel_inputs; + std::vector kernel_workspaces; + std::vector kernel_outputs; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 
i).get(); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_inputs); + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i).get(); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_outputs); + } + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + AddRuntimeAddress(device_address, &kernel_workspaces); + } + auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0); + resource_manager_.DecreaseAddressRefCount(kernel); + if (!ret) { + MS_LOG(EXCEPTION) << "Launch kernel failed."; + } +#ifdef ENABLE_PROFILE + double cost_time = GetTime() - start_time; + MS_LOG(INFO) << "cpu kernel: " << kernel->fullname_with_scope() << " costs " << cost_time * 1e6 << " us"; +#endif + } + return true; +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h new file mode 100644 index 0000000000..a29f840bfd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -0,0 +1,70 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
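
Run's per-kernel profiling brackets each Launch call with GetTime and logs the cost in microseconds. A self-contained equivalent using std::chrono; GetTime itself lives in utils/profile.h, and this sketch only mimics how it is used:

// Sketch of the per-kernel timing pattern in CPUKernelRuntime::Run,
// with std::chrono standing in for the internal GetTime() helper.
#include <chrono>
#include <cstdio>

template <typename F>
bool TimedLaunch(const char *name, F &&launch) {
  auto start = std::chrono::steady_clock::now();
  bool ok = launch();  // stands in for kernel_mod->Launch(...)
  auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                std::chrono::steady_clock::now() - start).count();
  std::printf("cpu kernel: %s costs %lld us\n", name, static_cast<long long>(us));
  return ok;
}

int main() {
  return TimedLaunch("Dummy/MatMul-op0", [] { return true; }) ? 0 : 1;
}
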
+ */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/session_basic.h" +#include "runtime/device/cpu/cpu_resource_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/any.h" +namespace mindspore { +namespace device { +namespace cpu { +class CPUKernelRuntime : public KernelRuntime { + public: + CPUKernelRuntime() = default; + ~CPUKernelRuntime() override = default; + + bool Init() override { return true; } + bool Run(session::KernelGraph *graph) override; + void AssignKernelAddress(session::KernelGraph *kernel_graph); + void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, + VectorRef *outputs, std::vector *need_sync_outputs); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + + protected: + bool SyncStream() override { return true; }; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override; + + private: + tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, + std::set *bound_addresses, + std::vector *need_sync_outputs); + + BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, + const std::unordered_map &input_map, + std::set *bound_addresses, + std::vector *need_sync_outputs); + void AssignValueNodeAddress(session::KernelGraph *kernel_graph); + void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); + void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); + void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); + CPUResourceManager resource_manager_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc new file mode 100644 index 0000000000..c607260ab3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.cc @@ -0,0 +1,174 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
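
The CPUResourceManager that follows tries to pre-allocate one block for the whole graph and falls back to per-address malloc when that fails. The control flow, reduced to its core; the Arena type and field names are illustrative, not the real class:

// Reduced sketch of the MemPlan fallback in the resource manager below:
// grow the arena when the plan needs more, else switch to dynamic malloc.
#include <cstdio>
#include <cstdlib>

struct Arena {
  unsigned char *mem_ptr = nullptr;
  size_t mem_size = 0;
  bool dynamic_malloc = false;

  void Plan(size_t graph_mem_size) {
    if (graph_mem_size <= mem_size) return;  // current block is already large enough
    std::free(mem_ptr);
    mem_ptr = static_cast<unsigned char *>(std::malloc(graph_mem_size));
    if (mem_ptr != nullptr) {
      mem_size = graph_mem_size;
      dynamic_malloc = false;
    } else {
      mem_size = 0;
      dynamic_malloc = true;  // fall back to per-address allocation
    }
  }
};

int main() {
  Arena arena;
  arena.Plan(1 << 20);
  std::printf("dynamic fallback: %s\n", arena.dynamic_malloc ? "yes" : "no");
  std::free(arena.mem_ptr);
  return 0;
}
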
+ */ +#include "runtime/device/cpu/cpu_resource_manager.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace cpu { +CPUResourceManager::~CPUResourceManager() { MemFree(); } + +void CPUResourceManager::MemFree() { + if (mem_ptr_ != nullptr) { + free(mem_ptr_); + mem_ptr_ = nullptr; + mem_size_ = 0; + } + + for (auto &&iter : dynamic_mem_) { + free(iter.first); + } + dynamic_mem_.clear(); +} + +void CPUResourceManager::MemPlan(const session::KernelGraph *graph) { + mem_plan_.MemPlan(graph); + size_t graph_mem_size = mem_plan_.GetGraphMemSize(graph); + if (graph_mem_size > mem_size_) { + MemFree(); + mem_ptr_ = reinterpret_cast(malloc(graph_mem_size)); + if (mem_ptr_ != nullptr) { + mem_size_ = graph_mem_size; + dynamic_malloc_ = false; + } else { + MS_LOG(INFO) << "Switch to dynamic malloc"; + dynamic_malloc_ = true; + } + } +} + +void CPUResourceManager::MemMalloc(const session::KernelGraph *graph) { + if (dynamic_malloc_) { + return; + } + mem_plan_.MemAssign(graph, mem_ptr_); +} + +void *CPUResourceManager::MemMalloc(size_t mem_size) { + void *ptr = malloc(mem_size); + if (ptr != nullptr) { + memset_s(ptr, mem_size, 0, mem_size); + dynamic_mem_[ptr] = mem_size; + return ptr; + } else { + MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size; + } +} + +void CPUResourceManager::MemFree(void *ptr) { + auto iter = dynamic_mem_.find(ptr); + if (iter != dynamic_mem_.end()) { + (void)dynamic_mem_.erase(iter); + free(ptr); + } +} + +void CPUResourceManager::IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } +} + +void CPUResourceManager::DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs) { + if (!dynamic_malloc_) { + return; + } + + if (summary_outputs.empty()) { + return; + } + + for (auto &output_item : summary_outputs) { + auto node = output_item.second.first; + size_t index = IntToSize(output_item.second.second); + auto address = AnfAlgo::GetMutableOutputAddr(node, index); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } +} + +void CPUResourceManager::IncreaseAddressRefCount(const session::KernelGraph *graph) { + if (!dynamic_malloc_) { + return; + } + MS_EXCEPTION_IF_NULL(graph); + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_++; + } + } +} + +void CPUResourceManager::DecreaseAddressRefCount(const AnfNodePtr &kernel) { + if (!dynamic_malloc_) { + return; + } + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = 
AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + address->ref_count_--; + if (address->ref_count_ == 0 && address->ptr_ != nullptr) { + MemFree(address->ptr_); + address->ptr_ = nullptr; + } + } +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h new file mode 100644 index 0000000000..d251760dd2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_resource_manager.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ + +#include +#include +#include "backend/session/kernel_graph.h" +#include "backend/session/session_basic.h" +#include "runtime/device/device_address.h" +#include "runtime/device/cpu/cpu_simple_mem_plan.h" +namespace mindspore { +namespace device { +namespace cpu { +class CPUResourceManager { + public: + CPUResourceManager() = default; + ~CPUResourceManager(); + + void MemPlan(const session::KernelGraph *graph); + void MemMalloc(const session::KernelGraph *graph); + void IncreaseAddressRefCount(const session::KernelGraph *graph); + void DecreaseAddressRefCount(const AnfNodePtr &kernel); + void *MemMalloc(size_t mem_size); + void MemFree(void *ptr); + void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + + private: + void MemFree(); + CPUSimpleMemPlan mem_plan_; + + size_t mem_size_{0}; + uint8_t *mem_ptr_{nullptr}; + bool dynamic_malloc_{false}; + std::unordered_map dynamic_mem_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_RESOURCE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc new file mode 100644 index 0000000000..7838e66984 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.cc @@ -0,0 +1,118 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
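
The ref-count discipline in the resource manager is the classic one: every pending consumer bumps the count up front, each completed use decrements it, and the backing memory is released exactly at zero. A compact illustration with simplified types:

// Minimal model of the IncreaseAddressRefCount / DecreaseAddressRefCount
// pairing: free an address exactly when its last pending use completes.
#include <cstdio>
#include <cstdlib>

struct RefAddress {
  void *ptr = nullptr;
  int ref_count = 0;
};

void Release(RefAddress *a) {
  a->ref_count--;
  if (a->ref_count == 0 && a->ptr != nullptr) {
    std::free(a->ptr);
    a->ptr = nullptr;  // later observers see it as already released
  }
}

int main() {
  RefAddress a{std::malloc(64), 0};
  a.ref_count += 2;  // two kernels will read this output
  Release(&a);       // first reader done, memory kept
  Release(&a);       // last reader done, memory freed
  std::printf("ptr after last release: %p\n", a.ptr);
  return 0;
}
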
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/cpu/cpu_simple_mem_plan.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace cpu { +void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + size_t total_mem_size = 0; + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + MS_EXCEPTION_IF_NULL(kernel_with_index.first); + if (kernel_with_index.first->isa()) { + continue; + } + auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto address = AnfAlgo::GetOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + total_mem_size += address->size_; + } + } + } + graph_mem_size_[graph] = total_mem_size; +} + +size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const { + auto iter = graph_mem_size_.find(graph); + if (iter != graph_mem_size_.end()) { + return iter->second; + } + return 0; +} + +void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(base_ptr); + uint8_t *mem_ptr = base_ptr; + auto kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + MS_EXCEPTION_IF_NULL(kernel_with_index.first); + if (kernel_with_index.first->isa()) { + continue; + } + auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { + auto address = AnfAlgo::GetMutableOutputAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto address = AnfAlgo::GetWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(address); + 
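+      // Workspace addresses follow the same bump-pointer rule as inputs and
+      // outputs: an address that is still unassigned receives the current
+      // cursor, and the cursor then advances by that address's size.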
if (address->ptr_ == nullptr) { + address->ptr_ = mem_ptr; + mem_ptr = mem_ptr + address->size_; + } + } + } +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h new file mode 100644 index 0000000000..123e29fbe5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_simple_mem_plan.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ +#define MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ + +#include +#include +#include "backend/session/kernel_graph.h" +#include "runtime/device/device_address.h" + +namespace mindspore { +namespace device { +namespace cpu { +class CPUSimpleMemPlan { + public: + CPUSimpleMemPlan() = default; + ~CPUSimpleMemPlan() = default; + + void MemPlan(const session::KernelGraph *graph); + void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); + size_t GetGraphMemSize(const session::KernelGraph *graph) const; + + private: + std::unordered_map graph_mem_size_; +}; +} // namespace cpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_CPU_CPU_SIMPLE_MEM_PLAN_H_ diff --git a/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc new file mode 100644 index 0000000000..9528e61ee9 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc @@ -0,0 +1,170 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
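
CPUSimpleMemPlan is a two-pass scheme: MemPlan walks the execution order summing the sizes of every still-unassigned address, and MemAssign replays the same walk handing out offsets from one base pointer. The essence, on a flat list of sizes rather than the real graph walk:

// Two-pass bump allocation in the spirit of CPUSimpleMemPlan:
// pass 1 totals the sizes, pass 2 assigns offsets from one base pointer.
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  std::vector<size_t> sizes = {256, 1024, 64};  // stand-ins for address->size_
  size_t total = 0;
  for (size_t s : sizes) total += s;            // MemPlan: total_mem_size

  unsigned char *base = static_cast<unsigned char *>(std::malloc(total));
  if (base == nullptr) return 1;

  std::vector<unsigned char *> ptrs;
  unsigned char *cursor = base;                 // MemAssign: mem_ptr
  for (size_t s : sizes) {
    ptrs.push_back(cursor);
    cursor += s;
  }
  for (size_t i = 0; i < ptrs.size(); ++i) {
    std::printf("addr %zu at offset %ld\n", i, static_cast<long>(ptrs[i] - base));
  }
  std::free(base);
  return 0;
}
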
+ */ + +#include "runtime/device/cpu/kernel_select_cpu.h" + +#include +#include +#include + +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace device { +namespace cpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; +using mindspore::kernel::KernelBuildInfo; +namespace { +bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) { + auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa() || input_node->isa()) { + return true; + } + return false; +} + +void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector &input_not_cnode_indexes, + const CNodePtr kernel_node) { + for (auto &input_index : input_not_cnode_indexes) { + auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first; + MS_EXCEPTION_IF_NULL(input_node); + std::vector output_types; + output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first); + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + builder->SetOutputsFormat({kOpFormat_DEFAULT}); + builder->SetOutputsDeviceType(output_types); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get()); + } +} + +void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector *input_formats, + std::vector *input_types, std::vector *input_no_cnode_indexes) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { + TypeId dtype = kTypeUnknown; + if (IsInputNotCNode(kernel_node, input_index)) { + input_no_cnode_indexes->emplace_back(input_index); + dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); + } else { + dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index); + } + input_formats->emplace_back(kOpFormat_DEFAULT); + input_types->emplace_back(dtype); + } +} + +void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr, + std::vector *output_formats, std::vector *output_types) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { + output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second); + auto dtype = kernel_attr.GetOutputAttr(output_index).first; + output_types->emplace_back(dtype); + } +} + +bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector &input_formats, + const std::vector &input_types, + const std::vector &input_not_cnode_indexes) { + if (kernel_attr.GetInputSize() != input_types.size()) { + MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size(); + return false; + } + auto input_num = input_types.size(); + for (size_t i = 0; i < input_num; ++i) { + bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(), + [i](size_t index) { return index == i; }); + bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size()); + if (have_cnode_input && is_not_cnode_idx) { + continue; + } + if (kernel_attr.GetInputAttr(i).first != input_types[i]) { + MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first + << ", actual input dtype:" << input_types[i]; + return false; + } + if (kernel_attr.GetInputAttr(i).second != input_formats[i]) { + MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second + << 
", actual input format:" << input_formats[i]; + return false; + } + } + return true; +} + +void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) { + MS_EXCEPTION_IF_NULL(kernel_attr); + TypeId input_dtype = kernel_attr->GetInputAttr(0).first; + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 1; i < input_num; ++i) { + kernel_attr->AddInputAttr(input_dtype); + } + + TypeId output_dtype = kernel_attr->GetOutputAttr(0).first; + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t i = 1; i < output_num; ++i) { + kernel_attr->AddOutputAttr(output_dtype); + } +} +} // namespace + +void SetKernelInfo(const CNodePtr &kernel_node) { + std::vector input_formats; + std::vector input_types; + std::vector input_not_cnode_indexes; + std::vector output_formats; + std::vector output_types; + + MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node); + GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes); + + auto kernel_attrs = + kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node)); + + for (size_t index = 0; index < kernel_attrs.size(); ++index) { + auto kernel_attr = kernel_attrs[index]; + if (kernel_attr.GetAllSame()) { + ExpandKernelAttr(kernel_node, &kernel_attr); + } + if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (kernel_attr.GetOutputSize() != output_num) { + MS_LOG(DEBUG) << "Output num is not equal!"; + continue; + } + MS_LOG(INFO) << "Input format and dtype is matched, index: " << index; + GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types); + UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node); + for (auto &input_index : input_not_cnode_indexes) { + input_types[input_index] = kernel_attr.GetInputAttr(input_index).first; + } + break; + } + } + + auto builder = std::make_shared(); + MS_EXCEPTION_IF_NULL(builder); + builder->SetInputsFormat(input_formats); + builder->SetInputsDeviceType(input_types); + builder->SetOutputsFormat(output_formats); + builder->SetOutputsDeviceType(output_types); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get()); +} +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/kernel_select_cpu.h b/mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h similarity index 100% rename from mindspore/ccsrc/device/cpu/kernel_select_cpu.h rename to mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.h diff --git a/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc new file mode 100644 index 0000000000..c124523d59 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.cc @@ -0,0 +1,277 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "runtime/device/cpu/mpi/mpi_adapter.h"
+#ifdef ENABLE_MPI
+#include <algorithm>
+#include <sstream>
+#include "pybind11/pybind11.h"
+#endif  // ENABLE_MPI
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace device {
+namespace cpu {
+std::shared_ptr<MPIAdapter> MPIAdapter::instance_ = nullptr;
+std::shared_ptr<MPIAdapter> MPIAdapter::Instance() {
+  if (instance_ == nullptr) {
+    MS_LOG(DEBUG) << "Create new mpi adapter instance.";
+    instance_.reset(new (std::nothrow) MPIAdapter());
+  }
+  return instance_;
+}
+
+#ifdef ENABLE_MPI
+
+#define RAISE_EXCEPTION(message)                                     \
+  {                                                                  \
+    std::ostringstream oss;                                          \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message;  \
+    pybind11::pybind11_fail(oss.str());                              \
+  }
+
+#define RAISE_EXCEPTION_WITH_PARAM(message, param)                            \
+  {                                                                           \
+    std::ostringstream oss;                                                   \
+    oss << "[" << __FILE__ << "] [" << __LINE__ << "] " << message << param;  \
+    pybind11::pybind11_fail(oss.str());                                       \
+  }
+
+namespace {
+MPI_Op GetMpiOp(const std::string &op_type) {
+  if (op_type == "sum") {
+    return MPI_SUM;
+  } else if (op_type == "max") {
+    return MPI_MAX;
+  } else if (op_type == "min") {
+    return MPI_MIN;
+  } else if (op_type == "prod") {
+    return MPI_PROD;
+  }
+
+  RAISE_EXCEPTION_WITH_PARAM("unsupported op_type: ", op_type);
+  return MPI_SUM;
+}
+
+int GetScatterIndex(int rankid, const std::vector<int> &ranks_group) {
+  int scatter_index = -1;
+  for (size_t i = 0; i < ranks_group.size(); ++i) {
+    if (ranks_group[i] == rankid) {
+      scatter_index = static_cast<int>(i);
+      break;
+    }
+  }
+  if (scatter_index == -1) {
+    RAISE_EXCEPTION_WITH_PARAM("local rank id is not in the input rank group! local rank id: ", rankid);
+  }
+  return scatter_index;
+}
+}  // namespace
+
+MPIAdapter::MPIAdapter() : comm_group_world_(MPI_GROUP_NULL) { Init(); }
+
+MPIAdapter::~MPIAdapter() {
+  int finalized;
+  MPI_Finalized(&finalized);
+  if (finalized != 0) {
+    return;
+  }
+
+  for (auto iter = ranks_group_.begin(); iter != ranks_group_.end(); ++iter) {
+    MPI_Group_free(&iter->second);
+  }
+  ranks_group_.clear();
+  if (comm_group_world_ != MPI_GROUP_NULL) {
+    MPI_Group_free(&comm_group_world_);
+    comm_group_world_ = MPI_GROUP_NULL;
+  }
+  MPI_Finalize();
+}
+
+void MPIAdapter::Init() {
+  static bool init = false;
+  if (init) {
+    return;
+  }
+
+  int init_flag = 0;
+  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
+    RAISE_EXCEPTION("Check mpi initialized fail!");
+  }
+  if (init_flag == 0) {
+    auto ret = MPI_Init(nullptr, nullptr);
+    if (ret != MPI_SUCCESS) {
+      RAISE_EXCEPTION("Failed to init mpi!");
+    }
+  }
+
+  MPI_Comm_group(MPI_COMM_WORLD, &comm_group_world_);
+  if (comm_group_world_ == MPI_GROUP_NULL) {
+    RAISE_EXCEPTION("comm_group_world_ init fail!");
+  }
+  auto ret = MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
+  if (ret != MPI_SUCCESS) {
+    RAISE_EXCEPTION("Failed to init mpi rank id!");
+  }
+
+  ret = MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
+  if (ret != MPI_SUCCESS) {
+    RAISE_EXCEPTION_WITH_PARAM("Failed to init mpi rank size! rank id: ", rank_id_)
+  }
+  init = true;
+}
+
+MPI_Group MPIAdapter::AddGroup(const std::vector<int> &ranks) {
+  if (ranks.size() > static_cast<size_t>(rank_size_) || ranks.empty()) {
+    RAISE_EXCEPTION_WITH_PARAM("input rank size: ", ranks.size());
+  }
+
+  if (std::find(ranks.begin(), ranks.end(), rank_id_) == ranks.end()) {
+    RAISE_EXCEPTION_WITH_PARAM("local rank id is not in the input rank group! local rank id: ", rank_id_);
+  }
+  std::lock_guard<std::mutex> lock(group_mutex_);
+  auto iter = ranks_group_.find(ranks);
+  if
(iter != ranks_group_.end()) { + return iter->second; + } + const auto ranks_size = ranks.size(); + std::vector ranks_input(ranks_size, 0); + for (size_t i = 0; i < ranks_size; ++i) { + ranks_input[i] = ranks[i]; + } + + MPI_Group group = MPI_GROUP_NULL; + MPI_Group_incl(comm_group_world_, ranks.size(), ranks_input.data(), &group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi group fail!rankid:", rank_id_) + } + + ranks_group_[ranks] = group; + return group; +} + +bool MPIAdapter::ReduceScatter(const float *input, float *output, const std::vector &ranks_group, size_t data_num, + const std::string &op_type) { + if (ranks_group.empty()) { + RAISE_EXCEPTION("input rank group is empty!"); + return false; + } + + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_) + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); + } + std::vector receive_count(ranks_group.size(), 0); + for (size_t i = 0; i < ranks_group.size(); ++i) { + receive_count[i] = data_num; + } + + auto op = GetMpiOp(op_type); + auto ret = MPI_Reduce_scatter(input, output, receive_count.data(), MPI_FLOAT, op, comm); + bool result = true; + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi reduce_scatter fail!ret = ", ret); + result = false; + } + + ret = MPI_Comm_free(&comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail! ret = ", ret); + } + return result; +} + +bool MPIAdapter::ReduceScatterOverwriteInput(float *input, const std::vector &ranks_group, size_t input_data_num, + size_t output_size, const std::string &op_type, float *output) { + int scatter_index = GetScatterIndex(rank_id_, ranks_group); + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail!rankid:", rank_id_); + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail!rankid:", rank_id_); + } + + MPI_Win window; + auto ret = MPI_Win_create(input, input_data_num * sizeof(float), sizeof(float), MPI_INFO_NULL, comm, &window); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi window create fail! 
ret = ", ret); + } + MPI_Win_fence(0, window); + for (size_t i = 0; i < ranks_group.size(); ++i) { + int remote_rank = ranks_group[i]; + if (rank_id_ == remote_rank) { + continue; + } + auto op = GetMpiOp(op_type); + ret = MPI_Accumulate(input + i * input_data_num, input_data_num, MPI_FLOAT, remote_rank, i * input_data_num, + input_data_num, MPI_FLOAT, op, window); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi accumulate fail!ret = ", ret); + } + } + MPI_Win_fence(0, window); + if (output != nullptr) { + auto data_size = input_data_num * sizeof(float); + if (output_size < data_size) { + std::ostringstream exception_msg; + exception_msg << "output buffer size " << output_size << " < input size " << data_size; + RAISE_EXCEPTION(exception_msg.str()) + } + auto copy_ret = memcpy_s(output, output_size, input + scatter_index * input_data_num, data_size); + if (copy_ret != 0) { + RAISE_EXCEPTION_WITH_PARAM("copy output memory fail!ret = ", copy_ret); + } + } + MPI_Win_free(&window); + MPI_Comm_free(&comm); + return true; +} + +bool MPIAdapter::AllGather(const float *input, float *output, const std::vector &ranks_group, size_t data_num) { + if (ranks_group.empty()) { + RAISE_EXCEPTION("input rank group is empty!"); + return false; + } + auto group = AddGroup(ranks_group); + if (group == MPI_GROUP_NULL) { + RAISE_EXCEPTION_WITH_PARAM("Get mpi group fail! rankid:", rank_id_); + } + MPI_Comm comm; + MPI_Comm_create_group(MPI_COMM_WORLD, group, 0, &comm); + if (comm == MPI_COMM_NULL) { + RAISE_EXCEPTION_WITH_PARAM("create mpi comm fail! rankid:", rank_id_); + } + auto ret = MPI_Allgather(input, data_num, MPI_FLOAT, output, data_num, MPI_FLOAT, comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi allgater fail!ret = ", ret); + } + ret = MPI_Comm_free(&comm); + if (ret != MPI_SUCCESS) { + RAISE_EXCEPTION_WITH_PARAM("mpi comm free fail!ret = ", ret); + } + return true; +} +#endif // ENABLE_MPI +} // namespace cpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h b/mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h similarity index 100% rename from mindspore/ccsrc/device/cpu/mpi/mpi_adapter.h rename to mindspore/ccsrc/runtime/device/cpu/mpi/mpi_adapter.h diff --git a/mindspore/ccsrc/device/cpu/readme.md b/mindspore/ccsrc/runtime/device/cpu/readme.md similarity index 100% rename from mindspore/ccsrc/device/cpu/readme.md rename to mindspore/ccsrc/runtime/device/cpu/readme.md diff --git a/mindspore/ccsrc/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h similarity index 100% rename from mindspore/ccsrc/device/device_address.h rename to mindspore/ccsrc/runtime/device/device_address.h diff --git a/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc new file mode 100644 index 0000000000..547c2fbe64 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.cc @@ -0,0 +1,143 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/gpu/blocking_queue.h"
+#include <chrono>
+#include "runtime/device/gpu/gpu_common.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace device {
+GpuQueue::GpuQueue(void *addr, const std::vector<size_t> &shape, const size_t &capacity)
+    : buffer_(addr), head_(0), tail_(0), shape_(shape), len_(0), capacity_(capacity), stream_(0), node_info_(nullptr) {
+  CHECK_CUDA_RET_WITH_ERROR(cudaStreamCreate(&stream_), "Cuda Create Stream Failed");
+  node_info_ = std::make_unique<NodeInfo[]>(capacity);
+  for (auto item : shape) {
+    len_ += item;
+  }
+}
+
+GpuQueue::~GpuQueue() { buffer_ = nullptr; }
+
+BlockQueueStatus_T GpuQueue::Push(const std::vector<DataItemGpu> &data) {
+  int offset = 0;
+  for (size_t i = 0; i < data.size(); i++) {
+    auto item = data[i];
+    if (item.data_ptr_ == nullptr || item.data_len_ != shape_[i]) {
+      MS_LOG(ERROR) << "Invalid Input: ptr: " << item.data_ptr_ << ", len: " << item.data_len_;
+      return ERROR_INPUT;
+    }
+
+    void *addr = reinterpret_cast<unsigned char *>(buffer_) + tail_ * len_ + offset;
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(addr, item.data_ptr_, item.data_len_, cudaMemcpyHostToDevice, stream_),
+                              "Cuda Memcpy Error");
+
+    offset += item.data_len_;
+  }
+
+  node_info_[tail_].event_.reset(new cudaEvent_t());
+  CHECK_CUDA_RET_WITH_ERROR(cudaEventCreate(&(*(node_info_[tail_].event_))), "Cuda Create Event Failed");
+  node_info_[tail_].data_ = data;
+  tail_ = (tail_ + 1) % (capacity_);
+  return SUCCESS;
+}
+
+BlockQueueStatus_T GpuQueue::Front(void **addr, size_t *len) const {
+  CHECK_CUDA_RET_WITH_ERROR(cudaEventSynchronize(*(node_info_[head_].event_)), "Cuda Event Sync Failed");
+  CHECK_CUDA_RET_WITH_ERROR(cudaEventDestroy(*(node_info_[head_].event_)), "Cuda Destroy Event Failed");
+  *addr = (unsigned char *)buffer_ + head_ * len_;
+  *len = len_;
+
+  for (auto item : node_info_[head_].data_) {
+    host_release_(item.data_ptr_);
+  }
+  return SUCCESS;
+}
+
+BlockQueueStatus_T GpuQueue::Pop() {
+  head_ = (head_ + 1) % (capacity_);
+  return SUCCESS;
+}
+
+bool GpuQueue::Destroy() {
+  if (stream_ != nullptr) {
+    auto ret = cudaStreamDestroy(stream_);
+    return ret == cudaSuccess;
+  } else {
+    return true;
+  }
+}
+
+BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector<size_t> &shape, const size_t &capacity) {
+  if (addr == nullptr) {
+    MS_LOG(ERROR) << "addr is nullptr";
+    return INTERNAL_ERROR;
+  }
+  queue_ = std::make_shared<GpuQueue>(addr, shape, capacity);
+  return SUCCESS;
+}
+
+void BlockingQueue::RegisterRelease(const std::function<void(void *)> &func) { queue_->RegisterRelease(func); }
+
+BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec) {
+  std::unique_lock<std::mutex> locker(mutex_);
+  if (queue_->IsFull()) {
+    if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) {
+      return TIMEOUT;
+    }
+  }
+  auto ret = queue_->Push(data);
+  if (ret) {
+    return ret;
+  }
+  not_empty_cond_.notify_one();
+  return SUCCESS;
+}
+
+BlockQueueStatus_T BlockingQueue::Front(void **addr, size_t *len) {
+  std::unique_lock<std::mutex> locker(mutex_);
+  bool timeout = not_empty_cond_.wait_for(locker, std::chrono::seconds(30), [this] { return !queue_->IsEmpty(); });
+  if (!timeout) {
+    return TIMEOUT;
+  }
+
+  return queue_->Front(addr, len);
+}
+
+BlockQueueStatus_T BlockingQueue::Pop() {
+  std::unique_lock<std::mutex> locker(mutex_);
+  not_empty_cond_.wait(locker, [this] { return !queue_->IsEmpty(); });
+  auto ret =
queue_->Pop(); + if (ret) { + return ret; + } + not_full_cond_.notify_one(); + return SUCCESS; +} + +bool BlockingQueue::Destroy() { + if (queue_ != nullptr) { + return queue_->Destroy(); + } else { + return true; + } +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.h b/mindspore/ccsrc/runtime/device/gpu/blocking_queue.h similarity index 100% rename from mindspore/ccsrc/device/gpu/blocking_queue.h rename to mindspore/ccsrc/runtime/device/gpu/blocking_queue.h diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h new file mode 100644 index 0000000000..2689fdbaca --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ + +#include +#include "runtime/device/gpu/gpu_device_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +class CudaCommon { + public: + inline int threads_num() const { return threads_per_block_; } + inline int major_sm() const { return major_sm_; } + inline int blocks_num(const int total_threads) const { + return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_); + } + + static CudaCommon &GetInstance() { + static CudaCommon instance; + return instance; + } + + private: + CudaCommon() { + uint32_t device_id = GPUDeviceManager::GetInstance().cur_device_id(); + cudaDeviceProp prop; + (void)cudaGetDeviceProperties(&prop, device_id); + threads_per_block_ = prop.maxThreadsPerBlock; + max_blocks_ = prop.multiProcessorCount; + major_sm_ = prop.major; + } + ~CudaCommon() = default; + CudaCommon(const CudaCommon &) = delete; + CudaCommon &operator=(const CudaCommon &) = delete; + + int max_blocks_; + int threads_per_block_; + int major_sm_; +}; +#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) +#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() +#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() +#define MINIUM_SM 6 +#define RECOMMEND_SM 7 +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_CUDA_COMMON_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc new file mode 100644 index 0000000000..1f5e5e3c22 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc @@ -0,0 +1,231 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/gpu/cuda_driver.h"
+#include
+#include "utils/log_adapter.h"
+#include "utils/convert_utils.h"
+
+namespace mindspore {
+namespace device {
+namespace gpu {
+size_t CudaDriver::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
+  size_t retreat_count = 0;
+  auto ret = cudaMalloc(reinterpret_cast<void **>(addr), size);
+  // If free memory is not enough, then retry with mem_malloc_retry_rate_.
+  while (ret == cudaErrorMemoryAllocation) {
+    size = FloatToSize(size * mem_malloc_retry_rate_);
+    size = (size / mem_malloc_align_size_) * mem_malloc_align_size_;
+    ret = cudaMalloc(reinterpret_cast<void **>(addr), size);
+    retreat_count++;
+    if (retreat_count > mem_malloc_retry_conut_max_) {
+      break;
+    }
+  }
+
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMalloc failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return 0;
+  }
+  return size;
+}
+
+bool CudaDriver::FreeDeviceMem(const DeviceMemPtr &addr) {
+  auto ret = cudaFree(addr);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaFree failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+size_t CudaDriver::AllocHostPinnedMem(size_t size, void **addr) {
+  if (size == 0) {
+    MS_LOG(EXCEPTION) << "The memory allocate size is 0";
+  }
+  auto ret = cudaHostAlloc(addr, size, cudaHostAllocDefault);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaHostAlloc failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return 0;
+  }
+  return size;
+}
+
+void CudaDriver::FreeHostPinnedMem(void *addr) {
+  if (addr) {
+    auto ret = cudaFreeHost(addr);
+    if (ret != cudaSuccess) {
+      MS_LOG(EXCEPTION) << "cudaFreeHost failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    }
+  }
+}
+
+bool CudaDriver::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) {
+  auto ret = cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) {
+  auto ret = cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemcpy failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) {
+  auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, (cudaStream_t)stream);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size,
+                                          DeviceStream stream) {
+  auto ret = cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, (cudaStream_t)stream);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemcpyAsync failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+size_t CudaDriver::total_mem_size() {
+  size_t free;
+  size_t total;
+  auto ret = cudaMemGetInfo(&free, &total);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return 0;
+  }
+  return total;
+}
+
+size_t CudaDriver::free_mem_size() {
+  size_t free;
+  size_t total;
+  auto ret = cudaMemGetInfo(&free, &total);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaMemGetInfo failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return 0;
+  }
+
+  return free;
+}
+
+bool CudaDriver::CreateStream(DeviceStream *stream) {
+  auto ret = cudaStreamCreateWithFlags(reinterpret_cast<cudaStream_t *>(stream), cudaStreamNonBlocking);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaStreamCreate failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::DestroyStream(const DeviceStream &stream) {
+  auto ret = cudaStreamDestroy((cudaStream_t)stream);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaStreamDestroy failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::SyncStream(const DeviceStream &stream) {
+  auto ret = cudaStreamSynchronize((cudaStream_t)stream);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaStreamSynchronize failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::CreateEvent(DeviceEvent *event, unsigned int flag) {
+  auto ret = cudaEventCreateWithFlags(reinterpret_cast<cudaEvent_t *>(event), flag);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaEventCreateWithFlags failed, ret[" << static_cast<int>(ret) << "], "
+                  << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::DestroyEvent(const DeviceEvent &event) {
+  auto ret = cudaEventDestroy((cudaEvent_t)event);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaEventDestroy failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::RecordEvent(DeviceEvent event, DeviceStream stream) {
+  auto ret = cudaEventRecord((cudaEvent_t)event, (cudaStream_t)stream);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaEventRecord failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::SyncEvent(const DeviceEvent &event) {
+  auto ret = cudaEventSynchronize((cudaEvent_t)event);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaEventSynchronize failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+  return true;
+}
+
+bool CudaDriver::QueryEvent(const DeviceEvent &event) {
+  auto ret = cudaEventQuery((cudaEvent_t)event);
+  if (ret == cudaSuccess) {
+    return true;
+  } else if (ret == cudaErrorNotReady) {
+    return false;
+  } else {
+    MS_LOG(ERROR) << "cudaEventQuery failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+    return false;
+  }
+}
+
+int CudaDriver::device_count() {
+  int dev_count;
+  auto ret = cudaGetDeviceCount(&dev_count);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaGetDeviceCount failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
+  }
+  return dev_count;
+}
+
+bool CudaDriver::set_current_device(int index) {
+  auto ret = cudaSetDevice(index);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) <<
"cudaSetDevice failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } + return true; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/cuda_driver.h b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h similarity index 100% rename from mindspore/ccsrc/device/gpu/cuda_driver.h rename to mindspore/ccsrc/runtime/device/gpu/cuda_driver.h diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_common.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h similarity index 100% rename from mindspore/ccsrc/device/gpu/distribution/collective_common.h rename to mindspore/ccsrc/runtime/device/gpu/distribution/collective_common.h diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc new file mode 100644 index 0000000000..80793042fd --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/collective_fake_init.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace gpu { +void CollectiveFakeInitializer::InitCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } + +void CollectiveFakeInitializer::FinalizeCollective() { MS_LOG(EXCEPTION) << "build without enable gpu!"; } +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_fake_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h similarity index 100% rename from mindspore/ccsrc/device/gpu/distribution/collective_fake_init.h rename to mindspore/ccsrc/runtime/device/gpu/distribution/collective_fake_init.h diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc new file mode 100644 index 0000000000..cba789b38d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/gpu/distribution/collective_init.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace gpu { +CollectiveInitializer &CollectiveInitializer::instance() { + static CollectiveInitializer instance = {}; + return instance; +} + +bool CollectiveInitializer::collective_inited() const { return collective_inited_; } + +const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } + +void CollectiveInitializer::InitCollective() { + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); + if (handle == nullptr) { + MS_LOG(EXCEPTION) + << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " + "installed.\n2.nccl is not " + "installed or found.\n3.mpi is not installed or found"; + } + auto mpi_init_funcptr = reinterpret_cast(dlsym(handle, "InitMPI")); + MS_EXCEPTION_IF_NULL(mpi_init_funcptr); + (*mpi_init_funcptr)(); + + CollectiveInitializer::instance().collective_inited_ = true; + CollectiveInitializer::instance().collective_handle_ = handle; +} + +void CollectiveInitializer::FinalizeCollective() { + if (CollectiveInitializer::instance().collective_handle_ != nullptr) { + if (dlclose(CollectiveInitializer::instance().collective_handle_) != 0) { + MS_LOG(EXCEPTION) << "Closing libgpu_collective.so handle failed."; + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.h b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h similarity index 100% rename from mindspore/ccsrc/device/gpu/distribution/collective_init.h rename to mindspore/ccsrc/runtime/device/gpu/distribution/collective_init.h diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc new file mode 100644 index 0000000000..927c93cfaf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/collective_wrapper.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/mpi_wrapper.h" +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +#ifndef EXPORT_WRAPPER +#define EXPORT_WRAPPER __attribute__((visibility("default"))) +#endif + +using MPIWrapper = mindspore::device::gpu::MPIWrapper; +using NCCLWrapper = mindspore::device::gpu::NCCLWrapper; + +extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); } + +extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); } + +extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); } + +extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, + cudaStream_t stream) { + return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream); +} + +extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, cudaStream_t stream) { + return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream); +} + +extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, + cudaStream_t stream) { + return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream); +} diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc new file mode 100644 index 0000000000..ed768fbbe5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/gpu/distribution/mpi_wrapper.h" + +#include +#include +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +namespace mindspore { +namespace device { +namespace gpu { +MPIWrapper::MPIWrapper() : rank_id_(0), rank_size_(0), local_rank_id_(0) { Init(); } + +MPIWrapper::~MPIWrapper() { + int finalized; + MPI_Finalized(&finalized); + if (finalized == 0) { + MPI_Finalize(); + } +} + +MPIWrapper &MPIWrapper::instance() { + static MPIWrapper instance; + return instance; +} + +int MPIWrapper::local_rank_id() const { return local_rank_id_; } + +void MPIWrapper::Init() { + int initialized; + CHECK_RET(MPI_Initialized(&initialized), MPI_SUCCESS, "Failed to check mpi initialization status."); + + if (initialized == 0) { + MPI_Init(nullptr, nullptr); + } + CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id."); + CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size."); + NCCLWrapper::instance().set_rank(rank_id_, rank_size_); + AssignLocalRankId(); + + ncclUniqueId unique_id; + if (rank_id_ == 0) { + unique_id = NCCLWrapper::instance().nccl_unique_id(); + } + CHECK_RET(MPI_Bcast(reinterpret_cast(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD), + MPI_SUCCESS, "Failed to broadcast nccl unique id."); + NCCLWrapper::instance().set_nccl_unique_id(unique_id); + return; +} + +void MPIWrapper::AssignLocalRankId() { + char host_name[MAX_HOSTNAME_LEN] = {0}; + CHECK_RET(gethostname(host_name, MAX_HOSTNAME_LEN), 0, "Getting host name failed."); + size_t host_hash = std::hash()(host_name); + + const int kRankSize = rank_size_; + size_t all_host_hashs[kRankSize]; + all_host_hashs[rank_id_] = host_hash; + CHECK_RET(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_host_hashs, sizeof(size_t), MPI_BYTE, MPI_COMM_WORLD), + MPI_SUCCESS, "MPI_Allgather host hashs failed."); + for (int global_rank = 0; global_rank < kRankSize; global_rank++) { + if (global_rank == rank_id_) { + break; + } + if (all_host_hashs[global_rank] == all_host_hashs[rank_id_]) { + local_rank_id_++; + } + } + return; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h new file mode 100644 index 0000000000..3d54b376cf --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/mpi_wrapper.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_common.h" + +namespace mindspore { +namespace device { +namespace gpu { +class MPIWrapper { + public: + MPIWrapper(MPIWrapper const &) = delete; + MPIWrapper &operator=(const MPIWrapper &) = delete; + static MPIWrapper &instance(); + int local_rank_id() const; + + private: + MPIWrapper(); + ~MPIWrapper(); + void Init(); + void AssignLocalRankId(); + + int rank_id_; + int rank_size_; + int local_rank_id_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_MPI_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc new file mode 100644 index 0000000000..adf0b2f6fb --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/distribution/nccl_wrapper.h" + +namespace mindspore { +namespace device { +namespace gpu { +NCCLWrapper &NCCLWrapper::instance() { + static NCCLWrapper instance; + return instance; +} + +ncclUniqueId NCCLWrapper::nccl_unique_id() const { + ncclUniqueId unique_id; + CHECK_RET(ncclGetUniqueId(&unique_id), ncclSuccess, "Failed to create nccl unique id."); + return unique_id; +} + +void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; } + +void NCCLWrapper::set_rank(int rank_id, int rank_size) { + rank_id_ = rank_id; + rank_size_ = rank_size; +} + +void NCCLWrapper::InitNCCLComm() { + CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess, + "Failed to init nccl communicator."); +} + +ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + ncclRedOp_t reduce_type, cudaStream_t stream) { + return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, comm_, stream); +} + +ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type, + cudaStream_t stream) { + return ncclAllGather(input_addr, output_addr, count, data_type, comm_, stream); +} + +ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count, + ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream) { + return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, comm_, stream); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h new file mode 100644 index 0000000000..fb09efc085 --- /dev/null +++ 
b/mindspore/ccsrc/runtime/device/gpu/distribution/nccl_wrapper.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ + +#include +#include +#include +#include "runtime/device/gpu/distribution/collective_common.h" + +namespace mindspore { +namespace device { +namespace gpu { +class NCCLWrapper { + public: + NCCLWrapper(NCCLWrapper const &) = delete; + NCCLWrapper &operator=(const NCCLWrapper &) = delete; + static NCCLWrapper &instance(); + ncclUniqueId nccl_unique_id() const; + void set_nccl_unique_id(ncclUniqueId unique_id); + void set_rank(int rank_id, int rank_size); + void InitNCCLComm(); + ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, cudaStream_t stream); + ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + cudaStream_t stream); + ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype, + ncclRedOp_t op, cudaStream_t stream); + + private: + NCCLWrapper() : rank_id_(-1), rank_size_(0) {} + ~NCCLWrapper() = default; + + private: + int rank_id_; + int rank_size_; + ncclUniqueId unique_id_; + ncclComm_t comm_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_DISTRIBUTION_NCCL_WRAPPER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc new file mode 100644 index 0000000000..a1b1fa9b79 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "runtime/device/gpu/gpu_buffer_mgr.h"
+#include <cuda_runtime_api.h>
+#include <utility>
+#include "utils/log_adapter.h"
+#include "common/utils.h"
+
+namespace mindspore {
+namespace device {
+unsigned int HandleMgr::AllocHandle() {
+  for (size_t i = 0; i < MAX_HANDLE_NUM; ++i) {
+    if (!handle_list_[i]) {
+      handle_list_[i] = true;
+      return (unsigned int)i;
+    }
+  }
+  return INVALID_HANDLE;
+}
+
+void HandleMgr::FreeHandle(unsigned int handle_id) {
+  if (handle_id >= MAX_HANDLE_NUM) {
+    return;
+  }
+  handle_list_[handle_id] = false;
+}
+
+GpuBufferMgr &GpuBufferMgr::GetInstance() noexcept {
+  static GpuBufferMgr instance;
+  return instance;
+}
+
+BlockQueueStatus_T GpuBufferMgr::Create(unsigned int device_id, const std::string &channel_name, void *addr,
+                                        const std::vector<size_t> &shape, const size_t &capacity) {
+  std::string name = std::to_string(device_id) + std::string("_") + channel_name;
+  if (name_queue_map_.count(name)) {
+    MS_LOG(ERROR) << "Queue already exists: " << name;
+    return QUEUE_NOT_EXIST;
+  }
+  std::shared_ptr<BlockingQueue> queue = std::make_shared<BlockingQueue>();
+  BlockQueueStatus_T rt = queue->Create(addr, shape, capacity);
+  if (rt != SUCCESS) {
+    return rt;
+  }
+  (void)name_queue_map_.insert(std::make_pair(name, queue));
+  init_ = true;
+  return SUCCESS;
+}
+
+unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name,
+                                const std::vector<size_t> &shape, const std::function<void(void *)> func) {
+  set_device();
+  std::string name = std::to_string(device_id) + std::string("_") + channel_name;
+  if (!name_queue_map_.count(name)) {
+    MS_LOG(ERROR) << "Queue not exist " << name;
+    return HandleMgr::INVALID_HANDLE;
+  }
+  unsigned int handle = handle_mgr_.AllocHandle();
+  if (handle == HandleMgr::INVALID_HANDLE) {
+    MS_LOG(ERROR) << "handle is invalid";
+    return HandleMgr::INVALID_HANDLE;
+  }
+  (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name]));
+  name_queue_map_[name]->RegisterRelease(func);
+  open_by_dataset_++;
+  return handle;
+}
+
+unsigned int GpuBufferMgr::Open(unsigned int device_id, const std::string &channel_name,
+                                const std::vector<size_t> &shape) {
+  set_device();
+  std::string name = std::to_string(device_id) + std::string("_") + channel_name;
+  if (!name_queue_map_.count(name)) {
+    MS_LOG(ERROR) << "Queue not exist " << name;
+    return HandleMgr::INVALID_HANDLE;
+  }
+  unsigned int handle = handle_mgr_.AllocHandle();
+  if (handle == HandleMgr::INVALID_HANDLE) {
+    MS_LOG(ERROR) << "handle is invalid";
+    return HandleMgr::INVALID_HANDLE;
+  }
+  (void)handle_queue_map_.insert(std::make_pair(handle, name_queue_map_[name]));
+  return handle;
+}
+
+void GpuBufferMgr::set_device_id(int device_id) { cur_dev_id_ = device_id; }
+
+void GpuBufferMgr::set_device() const {
+  auto ret = cudaSetDevice(cur_dev_id_);
+  if (ret != cudaSuccess) {
+    MS_LOG(ERROR) << "cudaSetDevice failed, ret[" << static_cast<int>(ret) << "]";
+  }
+}
+
+BlockQueueStatus_T GpuBufferMgr::Push(unsigned int handle, const std::vector<DataItemGpu> &data,
+                                      unsigned int timeout_in_sec) {
+  auto iter = handle_queue_map_.find(handle);
+  if (iter == handle_queue_map_.end()) {
+    return HANDLE_NOT_EXIST;
+  }
+  return iter->second->Push(data, timeout_in_sec);
+}
+
+BlockQueueStatus_T GpuBufferMgr::Front(unsigned int handle, void **addr, size_t *len) {
+  auto iter = handle_queue_map_.find(handle);
+  if (iter == handle_queue_map_.end()) {
+    return HANDLE_NOT_EXIST;
+  }
+  return iter->second->Front(addr, len);
+}
+
+BlockQueueStatus_T GpuBufferMgr::Pop(unsigned int handle) {
+  auto iter = handle_queue_map_.find(handle);
+  if (iter == handle_queue_map_.end()) {
return HANDLE_NOT_EXIST;
+  }
+  return iter->second->Pop();
+}
+
+void GpuBufferMgr::Close(unsigned int handle) noexcept {
+  if (!handle_queue_map_.count(handle)) {
+    return;
+  }
+  (void)handle_queue_map_.erase(handle);
+  handle_mgr_.FreeHandle(handle);
+  return;
+}
+
+bool GpuBufferMgr::IsInit() const { return init_; }
+
+bool GpuBufferMgr::IsClosed() const { return closed_; }
+
+bool GpuBufferMgr::Destroy() {
+  for (auto iter = name_queue_map_.begin(); iter != name_queue_map_.end(); ++iter) {
+    std::shared_ptr<BlockingQueue> queue = iter->second;
+    if (queue != nullptr) {
+      if (!queue->Destroy()) {
+        return false;
+      }
+      queue.reset();
+    }
+  }
+  name_queue_map_.clear();
+  return true;
+}
+
+inline bool GpuBufferMgr::isCreated(unsigned int device_id, const std::string &channel_name) {
+  std::string name = std::to_string(device_id) + std::string("_") + channel_name;
+  return name_queue_map_.count(name) != 0;
+}
+
+bool GpuBufferMgr::CloseNotify() {
+  bool result = true;
+  // lock scope
+  {
+    std::lock_guard<std::mutex> lk(close_mutex_);
+    // set closed_ to true so that all dataset retry loops can exit their while loops
+    closed_ = true;
+  }
+
+  // wait for the dataset threads' ack
+  for (int i = 0; i < open_by_dataset_; i++) {
+    if (!sema.Wait()) {
+      MS_LOG(ERROR) << "timed out waiting for signals";
+      result = false;
+    }
+    MS_LOG(DEBUG) << "receive one signal (" << i + 1 << "/" << open_by_dataset_ << ")";
+  }
+  return result;
+}
+
+void GpuBufferMgr::CloseConfirm() { sema.Signal(); }
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h
new file mode 100644
index 0000000000..722a36c4ed
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_buffer_mgr.h
@@ -0,0 +1,139 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_
+#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_
+
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
+#include "runtime/device/gpu/blocking_queue.h"
+
+#define EXPORT __attribute__((visibility("default")))
+
+namespace mindspore {
+namespace device {
+static const unsigned int MAX_WAIT_TIME_IN_SEC = 60;
+
+class Semaphore {
+ public:
+  explicit Semaphore(int count = 0) : count_(count) {}
+
+  inline void Signal() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    ++count_;
+    cv_.notify_one();
+  }
+
+  inline bool Wait() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    while (count_ == 0) {
+      if (cv_.wait_for(lock, std::chrono::seconds(MAX_WAIT_TIME_IN_SEC)) == std::cv_status::timeout) {
+        return false;
+      }
+    }
+    --count_;
+    return true;
+  }
+
+ private:
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  int count_;
+};
+
+class HandleMgr {
+ public:
+  static const unsigned int MAX_HANDLE_NUM = 32;
+  static const unsigned int INVALID_HANDLE = 0xffffffffUL;
+
+  unsigned int AllocHandle();
+  void FreeHandle(unsigned int);
+
+ private:
+  // value-initialized so AllocHandle() starts from a clean, all-free table
+  bool handle_list_[MAX_HANDLE_NUM] = {false};
+};
+
+class GpuBufferMgr {
+ public:
+  EXPORT GpuBufferMgr() : cur_dev_id_(0), init_(false), closed_(false), open_by_dataset_(0) {}
+
+  EXPORT virtual ~GpuBufferMgr() = default;
+
+  EXPORT static GpuBufferMgr &GetInstance() noexcept;
+
+  EXPORT BlockQueueStatus_T Create(unsigned int device_id, const std::string &channel_name, void *addr,
+                                   const std::vector<size_t> &shape, const size_t &capacity);
+
+  // call for Push thread
+  EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name,
+                           const std::vector<size_t> &shape, std::function<void(void *)> func);
+
+  // call for Front/Pop thread
+  EXPORT unsigned int Open(unsigned int device_id, const std::string &channel_name, const std::vector<size_t> &shape);
+
+  EXPORT BlockQueueStatus_T Push(unsigned int handle, const std::vector<DataItemGpu> &data,
+                                 unsigned int timeout_in_sec);
+  EXPORT BlockQueueStatus_T Front(unsigned int handle, void **addr, size_t *len);
+  EXPORT BlockQueueStatus_T Pop(unsigned int handle);
+
+  EXPORT void set_device_id(int device_id);
+
+  EXPORT void Close(unsigned int handle) noexcept;
+
+  EXPORT bool IsInit() const;
+
+  EXPORT bool IsClosed() const;
+
+  EXPORT bool Destroy();
+
+  // call for Release GPU Resources
+  EXPORT bool CloseNotify();
+
+  // call for dataset send thread
+  EXPORT void CloseConfirm();
+
+ private:
+  void set_device() const;
+
+  int cur_dev_id_;
+  bool init_;
+  bool closed_;
+  std::mutex mutex_;
+  std::mutex close_mutex_;
+  // how many queues opened by dataset
+  int open_by_dataset_;
+  Semaphore sema;
+
+  HandleMgr handle_mgr_;
+
+  std::map<unsigned int, std::shared_ptr<BlockingQueue>> handle_queue_map_;
+  std::map<std::string, std::shared_ptr<BlockingQueue>> name_queue_map_;
+
+  inline bool isCreated(unsigned int device_id, const std::string &channel_name);
+
+  GpuBufferMgr(const GpuBufferMgr &) = delete;
+  GpuBufferMgr &operator=(const GpuBufferMgr &) = delete;
+};
+}  // namespace device
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_DEVICE_GPU_GPU_BUFFER_MGR_H_
diff --git a/mindspore/ccsrc/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h
similarity index 100%
rename from mindspore/ccsrc/device/gpu/gpu_common.h
rename to mindspore/ccsrc/runtime/device/gpu/gpu_common.h
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc
new file mode 100644
index 0000000000..a20a6a9a3c
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc
@@ -0,0 +1,64 @@
+/**
+
* Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_device_address.h" +#include +#include "runtime/device/gpu/gpu_device_manager.h" +#include "utils/log_adapter.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, TypeId, void *host_ptr) const { + MS_EXCEPTION_IF_NULL(host_ptr); + auto &stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(stream); + auto ret = GPUDeviceManager::GetInstance().SyncStream(stream); + if (!ret) { + MS_LOG(ERROR) << "SyncStream failed"; + return ret; + } + if (size != size_) { + MS_LOG(WARNING) << "SyncDeviceToHost ignored, host size: " << size << ", device size " << size_; + return true; + } + return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size_); +} + +bool GPUDeviceAddress::SyncHostToDevice(const std::vector &, size_t, TypeId, const void *host_ptr) const { + MS_EXCEPTION_IF_NULL(host_ptr); + auto &stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(stream); + if (!GPUDeviceManager::GetInstance().CopyHostMemToDeviceAsync(ptr_, host_ptr, size_, stream)) { + MS_LOG(ERROR) << "CopyHostMemToDeviceAsync failed"; + return false; + } + return GPUDeviceManager::GetInstance().SyncStream(stream); +} + +GPUDeviceAddress::~GPUDeviceAddress() { + if (ptr_ == nullptr) { + return; + } + if (from_mem_pool_) { + GPUMemoryAllocator::GetInstance().FreeTensorMem(ptr_); + ptr_ = nullptr; + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h new file mode 100644 index 0000000000..ade738deed --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ + +#include +#include +#include "runtime/device/device_address.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUDeviceAddress : public DeviceAddress { + public: + GPUDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} + GPUDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) + : DeviceAddress(ptr, size, format, type_id) {} + ~GPUDeviceAddress() override; + + bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + void set_status(DeviceAddressStatus status) { status_ = status; } + DeviceAddressStatus status() const { return status_; } + DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } + + private: + DeviceAddressStatus status_{DeviceAddressStatus::kInDevice}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_ADDRESS_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc new file mode 100644 index 0000000000..8f17fc20b5 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/gpu/gpu_device_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "utils/log_adapter.h" +#include "utils/convert_utils.h" +#include "runtime/device/gpu/gpu_buffer_mgr.h" + +namespace mindspore { +namespace device { +namespace gpu { +void GPUDeviceManager::InitDevice() { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id"); + CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream."); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle"); + CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuDNN handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle."); + CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast(default_stream())), + "Failed to set stream for cuBLAS handle."); + CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator") +} + +void GPUDeviceManager::ReleaseDevice() { + for (DeviceStream stream : gpu_streams_) { + if (stream != nullptr) { + CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream."); + } + } + if (cudnn_handle_ != nullptr) { + CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle"); + } + if (cublas_handle_ != nullptr) { + CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle."); + } + CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); +} + +bool GPUDeviceManager::CreateStream(DeviceStream *stream) { + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); + gpu_streams_.emplace_back(*stream); + return true; +} + +const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } + +int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } + +bool GPUDeviceManager::set_cur_device_id(uint32_t device_id) { + if (!dev_id_init_) { + dev_id_init_ = true; + cur_dev_id_ = device_id; + mindspore::device::GpuBufferMgr::GetInstance().set_device_id(UintToInt(device_id)); + return true; + } else { + MS_LOG(ERROR) << "Device already been set."; + return false; + } +} + +uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } + +bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } + +const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } + +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } + +bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } + +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { + return CudaDriver::CopyDeviceMemToHost(dst, src, size); +} + +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { + return CudaDriver::CopyHostMemToDevice(dst, src, size); +} + +bool GPUDeviceManager::CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, + DeviceStream stream) const { + return CudaDriver::CopyDeviceMemToHostAsync(dst, src, size, stream); +} + +bool GPUDeviceManager::CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t 
size, + DeviceStream stream) const { + return CudaDriver::CopyHostMemToDeviceAsync(dst, src, size, stream); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h new file mode 100644 index 0000000000..002806675c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_manager.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ + +#include +#include +#include +#include +#include "runtime/device/gpu/cuda_driver.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUDeviceManager { + public: + void InitDevice(); + void ReleaseDevice(); + + int device_count() const; + bool set_cur_device_id(uint32_t device_id); + uint32_t cur_device_id() const; + bool is_device_id_init() const; + + bool CreateStream(DeviceStream *stream); + bool SyncStream(const DeviceStream &stream) const; + const DeviceStream &default_stream() const; + + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; + + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; + + bool CopyDeviceMemToHostAsync(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size, DeviceStream stream) const; + bool CopyHostMemToDeviceAsync(const DeviceMemPtr &dst, const void *src, size_t size, DeviceStream stream) const; + + static GPUDeviceManager &GetInstance() { + static GPUDeviceManager instance; + return instance; + } + + private: + GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} + ~GPUDeviceManager() = default; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; + + // default CUDA stream used for all the kernels. + DeviceStream default_stream_{nullptr}; + + // all gpu CUDA streams including default_stream_. + std::vector gpu_streams_; + + // handle used for cuDNN kernels. + cudnnHandle_t cudnn_handle_{nullptr}; + + // handle used for cuBLAS kernels. 
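+  // Like cudnn_handle_ above, this handle is bound to default_stream_ by
+  // InitDevice() (via cublasSetStream), so cuBLAS calls issued through it
+  // serialize on the default stream.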
+ cublasHandle_t cublas_handle_{nullptr}; + + bool dev_id_init_; + uint32_t cur_dev_id_; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_DEVICE_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc new file mode 100644 index 0000000000..9d88a205bc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "runtime/device/gpu/gpu_kernel_build.h" +#include +#include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/akg/akg_kernel_build.h" +#include "backend/kernel_compiler/akg/gpu/akg_gpu_kernel_build.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "frontend/operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +namespace mindspore { +namespace device { +namespace gpu { +void GpuBuild(const KernelGraphPtr &kernel_graph) { + kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); + MS_EXCEPTION_IF_NULL(bin_map); + bin_map->Initialize(); + MS_EXCEPTION_IF_NULL(kernel_graph); + auto kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + std::string kernel_name = session::AnfRuntimeAlgorithm::GetCNodeName(kernel); + if (kernel_name == prim::kPrimTupleGetItem->name() || kernel_name == prim::kPrimMakeTuple->name() || + kernel_name == prim::kPrimDepend->name() || kernel_name == prim::kPrimStateSetItem->name()) { + continue; + } + + if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) == KernelType::AKG_KERNEL) { + auto gpu_kernel_ptr = kernel::AkgGpuKernelBuild(kernel); + if (!gpu_kernel_ptr) { + MS_LOG(EXCEPTION) << "Build akg kernel op[" << kernel_name << "] failed"; + } + session::AnfRuntimeAlgorithm::SetKernelMod(gpu_kernel_ptr, kernel.get()); + } else { + auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel); + if (!gpu_kernel_ptr) { + MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel_name << "] failed"; + } + if (!gpu_kernel_ptr->Init(kernel)) { + MS_LOG(EXCEPTION) << "Initialize gpu kernel op[" << kernel_name << "] failed."; + } + session::AnfRuntimeAlgorithm::SetKernelMod((kernel::KernelModPtr)gpu_kernel_ptr, kernel.get()); + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h new file mode 100644 index 0000000000..831c4e9511 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_build.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ + +#include +#include "backend/session/kernel_graph.h" +namespace mindspore { +namespace device { +namespace gpu { +void GpuBuild(const std::shared_ptr &kernel_graph); +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPUKERNELBUILD_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc new file mode 100644 index 0000000000..ddf73841b7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -0,0 +1,646 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_kernel_runtime.h" +#include "runtime/device/gpu/gpu_device_address.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "runtime/device/gpu/gpu_buffer_mgr.h" +#include "runtime/device/gpu/gpu_device_manager.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "runtime/device/gpu/distribution/collective_init.h" +#include "utils/convert_utils.h" +#include "utils/context/ms_context.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "common/utils.h" +#include "runtime/device/gpu/gpu_memory_manager.h" +#include "backend/kernel_compiler/common_utils.h" +#include "runtime/device/gpu/gpu_memory_copy_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemSwapManager; +using mindspore::device::memswap::SwapKind; +bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } + +bool GPUKernelRuntime::Init() { + if (device_init_ == true) { + GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); + return true; + } + auto ret = InitDevice(); + if (!ret) { + MS_LOG(ERROR) << "InitDevice error."; + return ret; + } + mem_manager_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->MallocDeviceMemory(); + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + bool collective_inited = CollectiveInitializer::instance().collective_inited(); + if (collective_inited && collective_handle_ != nullptr) { + auto init_nccl_comm_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "InitNCCLComm")); + MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr); + (*init_nccl_comm_funcptr)(); + } + device_init_ = true; + return ret; +} + +DeviceAddressPtr 
GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) { + return std::make_shared(device_ptr, device_size, format, type_id); +} + +bool GPUKernelRuntime::InitDevice() { + if (GPUDeviceManager::GetInstance().device_count() <= 0) { + MS_LOG(ERROR) << "No GPU device found."; + return false; + } + const void *collective_handle_ = CollectiveInitializer::instance().collective_handle(); + bool collective_inited = CollectiveInitializer::instance().collective_inited(); + if (collective_inited && collective_handle_ != nullptr) { + auto get_local_rank_funcptr = + reinterpret_cast(dlsym(const_cast(collective_handle_), "local_rank_id")); + MS_EXCEPTION_IF_NULL(get_local_rank_funcptr); + device_id_ = IntToUint((*get_local_rank_funcptr)()); + } + if (!GPUDeviceManager::GetInstance().is_device_id_init()) { + if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) { + MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_); + return false; + } + } + GPUDeviceManager::GetInstance().InitDevice(); + stream_ = GPUDeviceManager::GetInstance().default_stream(); + if (stream_ == nullptr) { + MS_LOG(ERROR) << "No default CUDA stream found."; + return false; + } + return true; +} + +void GPUKernelRuntime::ReleaseDeviceRes() { + // For dataset mode. + if (GpuBufferMgr::GetInstance().IsInit()) { + if (!GpuBufferMgr::GetInstance().IsClosed()) { + if (!GpuBufferMgr::GetInstance().CloseNotify()) { + MS_LOG(EXCEPTION) << "Could not close gpu data queue."; + } + } + CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue."); + } + + // Destroy remaining memory swap events and free host memory. + for (auto &item : mem_swap_map_) { + auto &mem_swap_manager = item.second; + MS_EXCEPTION_IF_NULL(mem_swap_manager); + if (mem_swap_manager->trigger_swap()) { + mem_swap_manager->ClearSwapQueue(); + mem_swap_manager->ReleaseHostPinnedMem(); + } + } + + GPUDeviceManager::GetInstance().ReleaseDevice(); + if (mem_manager_ != nullptr) { + mem_manager_->FreeDeviceMemory(); + } + + kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance(); + MS_EXCEPTION_IF_NULL(bin_map); + bin_map->RemoveKernelCache(); +} + +void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->ResetDynamicMemory(); + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); + bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + if (is_enable_dynamic_mem) { + // Use the dynamic memory pool. 
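+    // With the dynamic pool, no device memory is assigned up front: reference
+    // counts from InitKernelRefCount() drive buffer reuse, InitMemorySwapInfo()
+    // prepares a per-graph host/device swap manager, and InitKernelOutputAddress()
+    // creates output addresses whose ptr_ stays null until the pool backs them
+    // at launch time.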
+ InitKernelRefCount(graph); + InitMemorySwapInfo(graph); + InitKernelOutputAddress(graph); + } else { + AssignDynamicMemory(graph); + } +} + +bool GPUKernelRuntime::Run(session::KernelGraph *graph) { + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); + bool ret = true; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); + bool is_enable_pynative_infer = context_ptr->enable_pynative_infer(); + if (is_enable_dynamic_mem && !is_enable_pynative_infer) { + auto graph_id = graph->graph_id(); + auto iter = mem_swap_map_.find(graph_id); + if (iter == mem_swap_map_.end()) { + MS_LOG(EXCEPTION) << "Find memory swap map failed."; + } + mem_swap_manager_ = iter->second; + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + while (!LaunchKernelDynamic(graph)) { + MS_LOG(WARNING) << "Run out of memory and try memory swapping, it may take some time, please wait a moment."; + if (!UpdateMemorySwapInfo(graph)) { + return false; + } + } + } else { + ret = LaunchKernel(graph); + } + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(DEBUG) << "GPU kernel runtime run graph in " << cost << " us"; + return ret; +} + +void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Init the kernel reference count. + if (!mem_reuse_util_ptr->InitDynamicKernelRef(graph)) { + MS_LOG(EXCEPTION) << "Init kernel reference count failed"; + } + mem_reuse_util_ptr->SetKernelDefMap(); + mem_reuse_util_ptr->SetReuseRefCount(); + // Can't free the device address of graph output, so set the reference count of graph output specially. + mem_reuse_util_ptr->SetGraphOutputRefCount(); + // Can't free the device address of summary nodes, so set the reference count of summary nodes specially. 
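+  // (Summary outputs are read back on the host side after the graph runs,
+  //  so their buffers must outlive the kernels that produced them.)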
+ mem_reuse_util_ptr->SetSummaryNodesRefCount(); + auto graph_id = graph->graph_id(); + mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr; +} + +void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared(); + MS_EXCEPTION_IF_NULL(gpu_mem_copy_manager); + MemSwapManagerPtr mem_swap_manager = std::make_shared(gpu_mem_copy_manager); + MS_EXCEPTION_IF_NULL(mem_swap_manager); + auto graph_id = graph->graph_id(); + mem_swap_map_[graph_id] = mem_swap_manager; +} + +void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + } + } +} + +void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (!AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); + } + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } +} + +bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto graph_id = graph->graph_id(); + auto iter = mem_reuse_util_map_.find(graph_id); + if (iter == mem_reuse_util_map_.end()) { + MS_LOG(EXCEPTION) << "Find memory reuse map failed."; + } + auto mem_reuse_util_ptr = iter->second; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Reset the reference count. + mem_reuse_util_ptr->ResetDynamicUsedRefCount(); + // The inputs and outputs memory of communication kernel need be continuous, so separate processing. 
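+  // (Collective libraries such as NCCL operate on fused operands in one
+  //  contiguous block, so these buffers are reserved as a single
+  //  continuous allocation before the per-kernel loop below.)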
+ AllocCommunicationOpDynamicRes(graph); + + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + if (!ret) { + return false; + } + if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { + MS_LOG(EXCEPTION) << "Launch kernel failed."; + } + FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id); + UpdateMemorySwapTask(kernel); + } + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + ClearSwapQueue(); + return true; +} + +bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_list) { + auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); + const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; + auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); + + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); + } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + auto status = device_address->status(); + if (status == DeviceAddressStatus::kInDeviceToHost) { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + device_address->set_status(DeviceAddressStatus::kInDevice); + } else if (status == DeviceAddressStatus::kInHost) { + if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) { + return false; + } + if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) { + mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address); + } + } + } + } + return true; +} + +bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + ClearKernelOutputAddress(graph); + if (!mem_swap_manager_->mem_swap_init()) { + mem_swap_manager_->Init(graph); + } + return mem_swap_manager_->RetreatSwapInfo(); +} + +bool GPUKernelRuntime::UpdateMemorySwapTask(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return true; + } + if (mem_swap_manager_->QueryKernelTriggerSwap(kernel)) { + CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); + if (!AddMemorySwapTask(kernel)) { + return false; + } + } + CHECK_OP_RET_WITH_EXCEPT(mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost), "SyncCopyStream failed."); + return true; +} + +void GPUKernelRuntime::UpdateHostSwapQueue(const DeviceAddressPtr device_address) { + MS_EXCEPTION_IF_NULL(mem_swap_manager_); + if (!mem_swap_manager_->trigger_swap()) { + return; + } + while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) { + device_address_swap_in->set_status(DeviceAddressStatus::kInDevice); + } + auto status = device_address->status(); + switch (status) { + case DeviceAddressStatus::kInDevice: + break; + case DeviceAddressStatus::kInDeviceToHost: { + mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_); + 
device_address->set_status(DeviceAddressStatus::kInDevice);
+      break;
+    }
+    case DeviceAddressStatus::kInHostToDevice: {
+      while (device_address->status() != DeviceAddressStatus::kInDevice) {
+        while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
+          device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
+        }
+      }
+      break;
+    }
+    case DeviceAddressStatus::kInHost:
+      MS_LOG(ERROR) << "Invalid device address status:" << status;
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Invalid device address status:" << status;
+  }
+}
+
+void GPUKernelRuntime::UpdateDeviceSwapQueue() {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return;
+  }
+  while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
+    if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
+      device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
+      mem_manager_->FreeMemFromMemPool(device_address_swap_out);
+    }
+  }
+}
+
+void GPUKernelRuntime::ClearSwapQueue() {
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  if (!mem_swap_manager_->trigger_swap()) {
+    return;
+  }
+  mem_swap_manager_->ClearSwapQueue();
+}
+
+bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
+  MS_EXCEPTION_IF_NULL(mem_manager_);
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
+  if (!ret) {
+    if (!mem_swap_manager_->trigger_swap()) {
+      return false;
+    }
+    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
+    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
+      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
+        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
+        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
+      }
+    }
+    ret = mem_manager_->MallocMemFromMemPool(device_address, size);
+    if (!ret) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
+  MS_EXCEPTION_IF_NULL(mem_manager_);
+  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
+  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
+  if (!device_ptr) {
+    if (!mem_swap_manager_->trigger_swap()) {
+      return nullptr;
+    }
+    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
+    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
+      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
+        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
+        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
+      }
+    }
+    device_ptr = mem_manager_->MallocMemFromMemPool(size);
+    if (!device_ptr) {
+      return nullptr;
+    }
+  }
+  return device_ptr;
+}
+
+bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
+                                             const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
+                                             AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
+  if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) {
+    return false;
+  }
+  if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) {
+    return false;
+  }
+  if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) {
+    return false;
+  }
+  return true;
+}
+
+bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_inputs);
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+    // The graph may consist entirely of nop nodes that were not removed, so nop nodes can not be skipped here.
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    MS_EXCEPTION_IF_NULL(device_address);
+    UpdateHostSwapQueue(device_address);
+    MS_EXCEPTION_IF_NULL(device_address->ptr_);
+    kernel::AddressPtr input = std::make_shared<kernel::Address>();
+    MS_EXCEPTION_IF_NULL(input);
+    input->addr = device_address->ptr_;
+    input->size = device_address->size_;
+    kernel_inputs->emplace_back(input);
+  }
+  return true;
+}
+
+bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
+                                                   const mindspore::AnfNodePtr &kernel,
+                                                   AddressPtrList *kernel_outputs) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_outputs);
+  UpdateDeviceSwapQueue();
+  auto output_sizes = kernel_mod.GetOutputSizeList();
+  for (size_t i = 0; i < output_sizes.size(); ++i) {
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
+      return false;
+    }
+    kernel::AddressPtr output = std::make_shared<kernel::Address>();
+    MS_EXCEPTION_IF_NULL(output);
+    output->addr = device_address->ptr_;
+    output->size = output_sizes[i];
+    kernel_outputs->emplace_back(output);
+  }
+  return true;
+}
+
+bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
+                                                      const mindspore::AnfNodePtr &kernel,
+                                                      AddressPtrList *kernel_workspaces) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_workspaces);
+  auto workspace_sizes = kernel_mod.GetWorkspaceSizeList();
+  for (size_t i = 0; i < workspace_sizes.size(); ++i) {
+    if (workspace_sizes[i] == 0) {
+      kernel_workspaces->emplace_back(nullptr);
+      continue;
+    }
+    auto device_ptr = AttemptMallocMem(workspace_sizes[i]);
+    if (!device_ptr) {
+      return false;
+    }
+    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
+    MS_EXCEPTION_IF_NULL(workspace);
+    workspace->addr = device_ptr;
+    workspace->size = workspace_sizes[i];
+    kernel_workspaces->emplace_back(workspace);
+  }
+  return true;
+}
+
+void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto &kernels = graph->execution_order();
+  for (auto &kernel : kernels) {
+    MS_EXCEPTION_IF_NULL(kernel);
+    if (AnfAlgo::IsCommunicationOp(kernel)) {
+      AllocCommunicationOpInputDynamicRes(kernel);
+      AllocCommunicationOpOutputDynamicRes(kernel);
+    }
+  }
+}
+
+void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  bool is_need_alloc_memory = false;
+  bool is_need_free_memory = false;
+  size_t total_size = 0;
+  std::vector<size_t> size_list;
+  DeviceAddressPtrList addr_list;
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
+    MS_EXCEPTION_IF_NULL(device_address);
+    if (device_address->ptr_ == nullptr) {
+      is_need_alloc_memory = true;
+    } else {
+      is_need_free_memory = true;
+    }
+    total_size += device_address->size_;
+    size_list.emplace_back(device_address->size_);
+    addr_list.emplace_back(device_address);
+  }
+  AllocCommunicationOpMemory(is_need_alloc_memory,
is_need_free_memory, addr_list, total_size, size_list); +} + +void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + bool is_need_alloc_memory = false; + bool is_need_free_memory = false; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_ == nullptr) { + is_need_alloc_memory = true; + } else { + is_need_free_memory = true; + } + total_size += output_sizes[i]; + size_list.emplace_back(output_sizes[i]); + addr_list.emplace_back(device_address); + } + AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list); +} + +void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + MS_EXCEPTION_IF_NULL(mem_manager_); + if (!is_need_alloc_memory) { + return; + } + if (is_need_free_memory) { + for (const auto &iter : addr_list) { + MS_EXCEPTION_IF_NULL(iter); + // Free the inputs/outputs of communication kernel which are not released. + if (iter->ptr_ != nullptr) { + mem_manager_->FreeMemFromMemPool(iter); + } + } + } + auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } +} + +void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, + const AddressPtrList &kernel_workspaces, uint32_t graph_id) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id]; + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + auto cnode = kernel->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::IsCommunicationOp(kernel)) { + return; + } + // Free the input of kernel by reference count. + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i); + if (kernel_ref_count_ptr == nullptr) { + continue; + } + kernel_ref_count_ptr->ref_count_dynamic_use_--; + if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) { + MS_LOG(EXCEPTION) << "Check dynamic reference count failed."; + } + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false); + mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } + // Free the output of kernel, if output has no reference. + for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { + auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i); + if (kernel_ref_count_ptr == nullptr) { + continue; + } + if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + mem_manager_->FreeMemFromMemPool(device_address); + device_address->set_status(DeviceAddressStatus::kInDevice); + } + } + // Free the workspace of kernel. 
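+  // Workspaces are scratch buffers that live only for a single launch,
+  // so they are returned to the pool unconditionally.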
+ for (size_t i = 0; i < kernel_workspaces.size(); ++i) { + auto workspace = kernel_workspaces[i]; + if (workspace != nullptr) { + MS_EXCEPTION_IF_NULL(workspace->addr); + mem_manager_->FreeMemFromMemPool(workspace->addr); + workspace->addr = nullptr; + } + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h new file mode 100644 index 0000000000..2b1f8198ce --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -0,0 +1,91 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "backend/optimizer/mem_reuse/mem_swap_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemSwapManagerPtr; +class GPUKernelRuntime : public KernelRuntime { + public: + GPUKernelRuntime() = default; + ~GPUKernelRuntime() override = default; + bool Init() override; + void ReleaseDeviceRes() override; + void AssignMemory(session::KernelGraph *graph) override; + bool Run(session::KernelGraph *graph) override; + + protected: + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) override; + bool SyncStream() override; + + private: + GPUKernelRuntime(const GPUKernelRuntime &); + GPUKernelRuntime &operator=(const GPUKernelRuntime &); + bool InitDevice(); + bool device_init_{false}; + + // The related functions and members for using dynamic memory pool. 
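+  // These paths are only exercised when enable_dynamic_mem_pool is set in
+  // the context; otherwise the static assignment of the base class is used.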
+ void InitKernelRefCount(const session::KernelGraph *graph); + void InitKernelOutputAddress(const session::KernelGraph *graph); + void InitMemorySwapInfo(const session::KernelGraph *graph); + void ClearKernelOutputAddress(const session::KernelGraph *graph); + bool LaunchKernelDynamic(const session::KernelGraph *graph); + bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); + void *AttemptMallocMem(size_t size); + bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, + AddressPtrList *kernel_outputs); + bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs); + bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_outputs); + bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, + const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces); + void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph); + void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel); + void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); + void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, + const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, + uint32_t graph_id); + bool AddMemorySwapTask(const AnfNodePtr &kernel); + bool UpdateMemorySwapInfo(const session::KernelGraph *graph); + bool UpdateMemorySwapTask(const AnfNodePtr &kernel); + void UpdateHostSwapQueue(const DeviceAddressPtr device_address); + void UpdateDeviceSwapQueue(); + void ClearSwapQueue(); + std::unordered_map mem_reuse_util_map_; + std::unordered_map mem_swap_map_; + MemSwapManagerPtr mem_swap_manager_{nullptr}; +}; +MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc new file mode 100644 index 0000000000..e2395bbaf2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <algorithm>
+#include "runtime/device/gpu/gpu_memory_allocator.h"
+#include "runtime/device/gpu/cuda_driver.h"
+#include "utils/log_adapter.h"
+#include "utils/context/ms_context.h"
+#include "utils/convert_utils_base.h"
+
+namespace mindspore {
+namespace device {
+namespace gpu {
+bool GPUMemoryAllocator::Init() {
+  size_t total_size = total_mem_size();
+  size_t free_size = CudaDriver::free_mem_size();
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  limited_device_memory_ = context_ptr->max_device_memory();
+  available_device_memory_ = FloatToSize(limited_device_memory_ * 1024 * 1024 * 1024);
+  if (total_size > 0 && free_size > 0 && available_device_memory_ > 0) {
+    MS_LOG(INFO) << "GPU device total memory size " << total_size << ", current free memory size " << free_size
+                 << ", set max available memory size " << available_device_memory_ << ".";
+  } else {
+    MS_LOG(EXCEPTION) << "GPU device memory error, total memory size " << total_size << ", current free memory size "
+                      << free_size << ", set max available memory size " << available_device_memory_ << ".";
+  }
+  return true;
+}
+
+void GPUMemoryAllocator::CheckMaxDeviceMemory() const {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  auto max_device_memory = context_ptr->max_device_memory();
+  // Modifying the max device memory at runtime is currently not supported.
+  if (limited_device_memory_ != max_device_memory) {
+    MS_LOG(EXCEPTION)
+      << "Can't change context param max_device_memory in runtime, currently effective max_device_memory("
+      << limited_device_memory_ << "GB), set new max_device_memory(" << max_device_memory << "GB) failed.";
+  }
+}
+
+bool GPUMemoryAllocator::Finalize() {
+  if (buffer_q_addr_ != nullptr) {
+    if (!CudaDriver::FreeDeviceMem(buffer_q_addr_)) {
+      MS_LOG(ERROR) << "Could not free buffer queue memory.";
+      return false;
+    }
+  }
+  return true;
+}
+
+bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
+  auto alloc_size = AllocDeviceMem(size, addr);
+  buffer_q_addr_ = *addr;
+  // The buffer queue needs to ensure that alloc_size and size are equal.
+  return (alloc_size == size) ?
true : false; +} + +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { + if (size == 0) { + MS_LOG(EXCEPTION) << "The memory alloc size is 0."; + } + auto free_size = free_mem_size(); + if (size > free_size) { + MS_LOG(EXCEPTION) << "Memory not enough: current free memory size[" << free_size + << "] is smaller than required size[" << size << "]."; + } + + auto alloc_size = CudaDriver::AllocDeviceMem(size, addr); + if (alloc_size == 0) { + MS_LOG(EXCEPTION) << "Alloc device memory[" << size << "] failed."; + } + total_used_device_memory_ += alloc_size; + available_device_memory_ -= alloc_size; + MS_LOG(INFO) << "Current free memory size[" << free_size - alloc_size << "], current alloc size[" << alloc_size + << "], total used size[" << total_used_device_memory_ << "]."; + return alloc_size; +} + +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } + +size_t GPUMemoryAllocator::free_mem_size() { return std::min(CudaDriver::free_mem_size(), available_device_memory_); } + +size_t GPUMemoryAllocator::total_mem_size() { return CudaDriver::total_mem_size(); } +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h new file mode 100644 index 0000000000..4b6eaa4e14 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ + +#include +#include "runtime/device/gpu/cuda_driver.h" +#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h" + +namespace mindspore { +namespace device { +namespace gpu { +class GPUMemoryAllocator : public DynamicMemPoolBestFit { + public: + ~GPUMemoryAllocator() override = default; + bool Init(); + void CheckMaxDeviceMemory() const; + bool Finalize(); + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); + + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + size_t free_mem_size() override; + size_t total_mem_size() override; + + static GPUMemoryAllocator &GetInstance() { + static GPUMemoryAllocator instance; + return instance; + } + + private: + GPUMemoryAllocator() = default; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; + + // Used to track address of data buffer queue. 
+ DeviceMemPtr buffer_q_addr_{nullptr}; + + float limited_device_memory_{0.0}; + size_t total_used_device_memory_{0}; + size_t available_device_memory_{0}; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc new file mode 100644 index 0000000000..0406c0f151 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_memory_copy_manager.h" +#include "runtime/device/gpu/gpu_common.h" +#include "runtime/device/gpu/gpu_device_manager.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace gpu { +void GPUMemCopyManager::Init() { + CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_out_stream_), + "Failed to create CUDA stream of memory swap out."); + CHECK_OP_RET_WITH_EXCEPT(GPUDeviceManager::GetInstance().CreateStream(&swap_in_stream_), + "Failed to create CUDA stream of memory swap in."); +} + +void GPUMemCopyManager::AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(host_addr.addr); + DeviceEvent event = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); + DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); + MS_EXCEPTION_IF_NULL(device_ptr); + device_address->set_status(DeviceAddressStatus::kInDeviceToHost); + + CHECK_OP_RET_WITH_EXCEPT( + CudaDriver::CopyDeviceMemToHostAsync(host_addr.addr, device_ptr, host_addr.size, swap_out_stream_), + "Failed to copy device memory to host."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_out_stream_), + "Failed to record CUDA event to swap out stream."); + swap_out_queue_.emplace(device_address, event); +} + +void GPUMemCopyManager::AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) { + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(host_addr.addr); + DeviceEvent event = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&event, cudaEventDisableTiming), "Failed to create CUDA event."); + DeviceMemPtr device_ptr = const_cast(device_address->GetPtr()); + MS_EXCEPTION_IF_NULL(device_ptr); + device_address->set_status(DeviceAddressStatus::kInHostToDevice); + + CHECK_OP_RET_WITH_EXCEPT( + CudaDriver::CopyHostMemToDeviceAsync(device_ptr, host_addr.addr, host_addr.size, swap_in_stream_), + "Failed to copy host memory to device."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(event, swap_in_stream_), + "Failed to record CUDA event to swap in stream."); + swap_in_queue_.emplace(device_address, event); +} + 
+bool GPUMemCopyManager::SyncMemCopyStream(SwapKind swap_kind) { + if (swap_kind == SwapKind::kDeviceToHost) { + return GPUDeviceManager::GetInstance().SyncStream(swap_out_stream_); + } else { + return GPUDeviceManager::GetInstance().SyncStream(swap_in_stream_); + } +} + +DeviceAddressPtr GPUMemCopyManager::UpdateSwapOutQueue() { + if (swap_out_queue_.empty()) { + return nullptr; + } + auto &task = swap_out_queue_.front(); + auto device_address = task.first; + auto &event = task.second; + bool finish_swap = CudaDriver::QueryEvent(event); + if (!finish_swap) { + return nullptr; + } + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); + swap_out_queue_.pop(); + return device_address; +} + +DeviceAddressPtr GPUMemCopyManager::UpdateSwapInQueue() { + if (swap_in_queue_.empty()) { + return nullptr; + } + auto &task = swap_in_queue_.front(); + auto device_address = task.first; + auto &event = task.second; + bool finish_swap = CudaDriver::QueryEvent(event); + if (!finish_swap) { + return nullptr; + } + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); + swap_in_queue_.pop(); + return device_address; +} + +bool GPUMemCopyManager::AllocHostPinnedMem(size_t size, void **addr) const { + auto alloc_size = CudaDriver::AllocHostPinnedMem(size, addr); + return alloc_size == size; +} + +void GPUMemCopyManager::FreeHostPinnedMem(void *addr) const { CudaDriver::FreeHostPinnedMem(addr); } + +void GPUMemCopyManager::ClearSwapQueue() { + CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kDeviceToHost), "Failed to sync swap out stream"); + CHECK_OP_RET_WITH_EXCEPT(SyncMemCopyStream(SwapKind::kHostToDevice), "Failed to sync swap in stream"); + + while (!swap_out_queue_.empty()) { + auto &event = swap_out_queue_.front().second; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap out."); + swap_out_queue_.pop(); + } + while (!swap_in_queue_.empty()) { + auto &event = swap_in_queue_.front().second; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(event), "Failed to destroy CUDA event of swap in."); + swap_in_queue_.pop(); + } +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h new file mode 100644 index 0000000000..dc99b7f7d0 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_copy_manager.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ + +#include +#include +#include +#include "backend/optimizer/mem_reuse/mem_copy_manager.h" +#include "runtime/device/device_address.h" +#include "runtime/device/gpu/cuda_driver.h" +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace device { +namespace gpu { +using mindspore::device::memswap::MemCopyManager; +using mindspore::device::memswap::SwapKind; +class GPUMemCopyManager : public MemCopyManager { + public: + GPUMemCopyManager() = default; + + ~GPUMemCopyManager() override = default; + + void Init() override; + + void AddMemSwapOutTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; + + void AddMemSwapInTask(const DeviceAddressPtr &device_address, const HostAddress &host_addr) override; + + bool SyncMemCopyStream(SwapKind swap_kind) override; + + DeviceAddressPtr UpdateSwapOutQueue() override; + + DeviceAddressPtr UpdateSwapInQueue() override; + + bool AllocHostPinnedMem(size_t size, void **addr) const override; + + void FreeHostPinnedMem(void *addr) const override; + + void ClearSwapQueue() override; + + private: + DeviceStream swap_out_stream_{nullptr}; + DeviceStream swap_in_stream_{nullptr}; + std::queue> swap_out_queue_; + std::queue> swap_in_queue_; +}; +using GPUMemCopyManagerPtr = std::shared_ptr; +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_COPY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc new file mode 100644 index 0000000000..ffa07eea0d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc @@ -0,0 +1,92 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_memory_manager.h" +#include "runtime/device/gpu/gpu_memory_allocator.h" +#include "utils/context/ms_context.h" +#include "utils/convert_utils.h" +namespace mindspore { +namespace device { +namespace gpu { +void *GPUMemoryManager::MallocMemFromMemPool(size_t size) { + return GPUMemoryAllocator::GetInstance().AllocTensorMem(size); +} + +void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { + GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); +} + +std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); +} + +void GPUMemoryManager::MallocDeviceMemory() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + // If use the dynamic memory pool, then alloc the first memory block to init. 
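+  // A 1-byte allocation is enough to force the pool to create its first
+  // memory block up front instead of on the first kernel launch.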
+ if (context_ptr->enable_dynamic_mem_pool()) { + auto device_addr = MallocMemFromMemPool(1); + if (!device_addr) { + MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; + } + } else { + // Need to reserve 20% space for dynamic memory + const float init_gpu_mem_ratio = 0.8; + size_t mem_size = FloatToSize(GPUMemoryAllocator::GetInstance().free_mem_size() * init_gpu_mem_ratio); + auto alloc_size = + GPUMemoryAllocator::GetInstance().AllocDeviceMem(mem_size, reinterpret_cast(&device_mem_base_)); + device_mem_size_ = alloc_size; + static_mem_offset_ = device_mem_size_; + } +} + +void GPUMemoryManager::FreeDeviceMemory() { + if (device_mem_base_ != nullptr) { + if (!GPUMemoryAllocator::GetInstance().FreeDeviceMem(device_mem_base_)) { + MS_LOG(EXCEPTION) << "Could not free gpu device memory."; + } + } + GPUMemoryAllocator::GetInstance().ReleaseDeviceRes(); +} + +uint8_t *GPUMemoryManager::MallocStaticMem(size_t size, bool) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->enable_dynamic_mem_pool()) { + auto device_ptr = MallocMemFromMemPool(size); + MS_EXCEPTION_IF_NULL(device_ptr); + return AddressOffset(device_ptr, 0); + } + + auto align_size = GetCommonAlignSize(size); + if (static_mem_offset_ < align_size) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + auto offset = static_mem_offset_ - align_size; + if (dynamic_mem_offset_ > offset) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_static_size_ += align_size; + static_mem_offset_ = offset; + return device_mem_base_ + offset; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h new file mode 100644 index 0000000000..533116cefc --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#include +#include "runtime/device/memory_manager.h" +namespace mindspore { +namespace device { +namespace gpu { +class GPUMemoryManager : public MemoryManager { + public: + GPUMemoryManager() = default; + virtual ~GPUMemoryManager() = default; + + void MallocDeviceMemory() override; + void FreeDeviceMemory() override; + + void *MallocMemFromMemPool(size_t size) override; + void FreeMemFromMemPool(void *device_ptr) override; + std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); + + protected: + uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; +}; +} // namespace gpu +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc new file mode 100644 index 0000000000..78915f10d7 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/gpu/gpu_stream_assign.h" +#include +#include +#include +#include +#include "runtime/device/gpu/gpu_common.h" +#include "runtime/device/gpu/kernel_info_setter.h" +#include "runtime/device/gpu/gpu_device_manager.h" + +namespace mindspore { +namespace device { +namespace gpu { +void AssignGpuStream(const std::shared_ptr &kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector allreduce_kernels; + auto execution_kernels = kernel_graph->execution_order(); + for (auto kernel_node : execution_kernels) { + std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == kAllReduceOpName) { + allreduce_kernels.emplace_back(kernel_node); + } else { + DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream(); + MS_EXCEPTION_IF_NULL(compute_stream); + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(compute_stream)), kernel_node); + } + } + if (allreduce_kernels.size() > 1) { + // Assign multiple streams only when there're multiple AllReduce nodes. 
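+    // Each AllReduce is moved onto a dedicated communication stream; the
+    // send/recv pairs located below mark where event record/wait nodes must
+    // be inserted so the compute stream stays correctly ordered.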
+ std::vector send_recv_pairs; + if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) { + DeviceStream comm_stream = nullptr; + GPUDeviceManager::GetInstance().CreateStream(&comm_stream); + std::transform( + allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) { + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); + return allreduce_kernel; + }); + InsertStreamSwitchNode(kernel_graph, send_recv_pairs); + } else { + return; + } + } +} + +bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, + std::vector *send_recv_pairs) { + auto execution_kernels = kernel_graph->execution_order(); + std::vector::iterator iter, iter_begin; + iter = iter_begin = execution_kernels.begin(); + std::vector::iterator iter_end = execution_kernels.end(); + for (; iter != execution_kernels.end(); ++iter) { + std::string kernel_name = AnfAlgo::GetCNodeName(*iter); + if (kernel_name == kAllReduceOpName) { + // Find AllReduce node's last input node. + std::vector::iterator mock_send_node_iter = + FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch); + if (mock_send_node_iter == iter + 1) { + MS_LOG(WARNING) << "Can't find send node place before AllReduce node."; + continue; + } + SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter, + IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)}; + send_recv_pairs->push_back(pair1); + // Find node which uses AllReduce as input[0]. + std::vector::iterator mock_recv_node_iter = + FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch); + if (mock_recv_node_iter == iter_end) { + MS_LOG(WARNING) << "Can't find recv node place after AllReduce node."; + return false; + } + SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1), + IntToSize(mock_recv_node_iter - iter_begin)}; + send_recv_pairs->push_back(pair2); + } + } + return true; +} + +std::vector::iterator FindSendNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_recv_node, + StreamSwitchType stream_switch_type) { + MS_EXCEPTION_IF_NULL(mock_recv_node); + if (stream_switch_type == kAllReduceStreamSwitch) { + for (auto iter = begin; iter != end; iter++) { + if (*(iter + 1) == mock_recv_node) { + return iter; + } + } + } + return end; +} + +std::vector::iterator FindRecvNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_send_node, + StreamSwitchType stream_switch_type) { + MS_EXCEPTION_IF_NULL(mock_send_node); + for (auto iter = begin; iter != end; iter++) { + auto node = *iter; + if (stream_switch_type == kAllReduceStreamSwitch) { + for (auto input : node->inputs()) { + if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) { + return iter; + } + } + } + } + return end; +} + +void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, + const std::vector &send_recv_pairs) { + std::set ordered_stream_switch_nodes; + for (SendRecvPair pair : send_recv_pairs) { + StreamSwitchType stream_switch_type = pair.stream_switch_type; + CNodePtr mock_send_node = pair.mock_send_node; + CNodePtr mock_recv_node = pair.mock_recv_node; + size_t send_node_offset = pair.send_node_offset; + size_t recv_node_offset = pair.recv_node_offset; + CNodePtr send_node = nullptr; + CNodePtr recv_node = nullptr; + // Step 1: generate Send and Recv CNodes. 
+ if (stream_switch_type == kAllReduceStreamSwitch) { + if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) { + MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch"; + } + } + // Step 2: sort send and recv CNodes by offset. + ordered_stream_switch_nodes.insert({send_node_offset, send_node}); + ordered_stream_switch_nodes.insert({recv_node_offset, recv_node}); + } + // Step 3: insert stream switch CNodes into execution kernel list. + auto execution_kernels = kernel_graph->execution_order(); + for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) { + execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode); + } + kernel_graph->set_execution_order(execution_kernels); +} + +bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, + const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, + CNodePtr *recv_node) { + *send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName); + MS_EXCEPTION_IF_NULL(*send_node); + *recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName); + MS_EXCEPTION_IF_NULL(*recv_node); + + cudaEvent_t event = nullptr; + CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed."); + AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast(event)), *send_node); + AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast(event)), *recv_node); + + uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrRecordEventStream, MakeValue(send_stream), *send_node); + uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrWaitEventStream, MakeValue(recv_stream), *recv_node); + return true; +} + +CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name) { + auto op = std::make_shared(name); + MS_EXCEPTION_IF_NULL(op); + auto apply = std::make_shared(op); + MS_EXCEPTION_IF_NULL(apply); + std::vector input_list = {apply}; + CNodePtr node = kernel_graph->NewCNode(input_list); + MS_EXCEPTION_IF_NULL(node); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get()); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + node->set_abstract(abstract_none); + SetKernelInfo(node); + return node; +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h new file mode 100644 index 0000000000..f22ce8fe38 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_stream_assign.h @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ +#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ + +#include +#include +#include +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace device { +namespace gpu { +enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 }; +struct SendRecvPair { + StreamSwitchType stream_switch_type; + CNodePtr mock_send_node; + CNodePtr mock_recv_node; + size_t send_node_offset; + size_t recv_node_offset; +}; +struct StreamSwitchNode { + size_t offset; + CNodePtr cnode; + bool operator<(const StreamSwitchNode &n) const { + if (offset < n.offset) { + return true; + } else if (offset == n.offset) { + return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false; + } else { + return false; + } + } +}; +void AssignGpuStream(const std::shared_ptr &kernel_graph); +bool FindAllReduceStreamSwitchPos(const std::shared_ptr &kernel_graph, + std::vector *send_recv_pairs); +// Find Send node position according to "mock" recv node. +// "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node. +std::vector::iterator FindSendNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_recv_node, + StreamSwitchType stream_switch_type); +// Find Recv node position according to "mock" send node. +// "mock" send node is a gpu kernel node before a real send node, e.g. AllReduce node. +std::vector::iterator FindRecvNodePos(std::vector::iterator begin, + std::vector::iterator end, const CNodePtr mock_send_node, + StreamSwitchType stream_switch_type); +void InsertStreamSwitchNode(const std::shared_ptr &kernel_graph, + const std::vector &send_recv_pairs); +bool GenSendRecvCNodesForAllReduce(const std::shared_ptr &kernel_graph, + const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node, + CNodePtr *recv_node); +CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name); +} // namespace gpu +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc new file mode 100644 index 0000000000..4326987784 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/gpu/kernel_info_setter.h" +#include +#include +#include "backend/kernel_compiler/kernel.h" +#include "utils/utils.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "common/utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/oplib/opinfo.h" + +namespace mindspore { +namespace device { +namespace gpu { +using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; +using mindspore::kernel::KernelBuildInfo; +namespace { +bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, + const std::shared_ptr &selected_kernel_info) { + MS_EXCEPTION_IF_NULL(selected_kernel_info); + MS_EXCEPTION_IF_NULL(alternative_kernel_info); + size_t selected_input_num = selected_kernel_info->GetInputNum(); + size_t alternative_input_num = alternative_kernel_info->GetInputNum(); + if (selected_input_num != alternative_input_num) { + return false; + } + for (size_t i = 0; i < selected_input_num; i++) { + if (selected_kernel_info->GetInputFormat(i) != alternative_kernel_info->GetInputFormat(i)) { + return false; + } + if (selected_kernel_info->GetInputDeviceType(i) != alternative_kernel_info->GetInputDeviceType(i)) { + return false; + } + } + + size_t selected_output_num = selected_kernel_info->GetOutputNum(); + size_t alternative_output_num = alternative_kernel_info->GetOutputNum(); + if (selected_output_num != alternative_output_num) { + return false; + } + for (size_t i = 0; i < selected_output_num; i++) { + if (selected_kernel_info->GetOutputFormat(i) != alternative_kernel_info->GetOutputFormat(i)) { + return false; + } + if (selected_kernel_info->GetOutputDeviceType(i) != alternative_kernel_info->GetOutputDeviceType(i)) { + return false; + } + } + return true; +} + +std::string SupportedTypeList(const CNodePtr &kernel_node) { + std::string supported_type_lists = + kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); + if (!supported_type_lists.empty()) { + return supported_type_lists; + } + std::vector> kernel_info_list; + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG); + if (op_info_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Unsupported op [" << op_name << "]"; + } + (void)ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list); + for (size_t i = 0; i < kernel_info_list.size(); i++) { + auto supported_akg_type = kernel_info_list[i]->GetAllInputDeviceTypes(); + auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes(); + std::string supported_akg_type_list = "in["; + for (auto type : supported_akg_type) { + supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); + } + supported_type_lists = supported_type_lists + supported_akg_type_list + "], out["; + supported_akg_type_list.clear(); + for (auto type : supported_akg_type_out) { + supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type); + } + supported_type_lists = supported_type_lists + supported_akg_type_list + "]; "; + } + return supported_type_lists; +} + +bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(selected_kernel_info); + 
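// Fall back to AKG: look the op up in the AKG op library and accept it
+  // only if one of its registered build infos matches the formats and
+  // dtypes requested by the caller.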
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
+  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
+
+  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kAKG);
+  if (op_info_ptr == nullptr) {
+    MS_LOG(ERROR) << "Cannot find op[" << op_name << "] in akg";
+    return false;
+  }
+  if (!ParseMetadata(kernel_node, op_info_ptr, kernel::Processor::CUDA, &kernel_info_list)) {
+    MS_LOG(EXCEPTION) << "Parse metadata of op[" << op_name << "] failed.";
+  }
+  if (kernel_info_list.empty()) {
+    MS_LOG(EXCEPTION) << "Akg does not have metadata of op[" << op_name << "].";
+  }
+
+  bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
+                           [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
+                             return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
+                           });
+  if (!match) {
+    MS_LOG(ERROR) << "No akg kernel info of op[" << op_name << "] matches the selected build info";
+    return false;
+  }
+  return true;
+}
+
+void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    auto input_kernel_node = kernel_node->input(input_index + 1);
+    MS_EXCEPTION_IF_NULL(input_kernel_node);
+    if (!input_kernel_node->isa<Parameter>()) {
+      continue;
+    }
+    std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
+      std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
+
+    auto param = input_kernel_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param);
+    if (!AnfAlgo::IsParameterWeight(param)) {
+      std::vector<std::string> output_format = {kOpFormat_DEFAULT};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(input_kernel_node, 0)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
+      continue;
+    }
+    if ((AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) ||
+        (AnfAlgo::GetCNodeName(kernel_node) == "ApplyMomentum")) {
+      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
+      builder->SetOutputsFormat(output_format);
+      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
+      builder->SetOutputsDeviceType(output_type);
+      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
+    }
+  }
+}
+}  // namespace
+
+void SetKernelInfo(const CNodePtr &kernel_node) {
+  std::vector<std::string> inputs_format;
+  std::vector<TypeId> inputs_type;
+  std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
+    std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    inputs_format.emplace_back(kOpFormat_DEFAULT);
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
+  }
+  builder->SetInputsFormat(inputs_format);
+  builder->SetInputsDeviceType(inputs_type);
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> outputs_type;
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+    outputs_format.emplace_back(kOpFormat_DEFAULT);
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
+  }
+  builder->SetOutputsFormat(outputs_format);
+  builder->SetOutputsDeviceType(outputs_type);
+
+  bool result =
+    kernel::GpuKernelFactory::GetInstance().SearchRegistered(AnfAlgo::GetCNodeName(kernel_node), builder->Build());
+  KernelType kernel_type = UNKNOWN_KERNEL_TYPE;
+
+  if (!result) {
+    result = SelectAkgKernel(kernel_node, builder->Build());
+    kernel_type = AKG_KERNEL;
+  }
+
+  if (!result) {
+    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
+    std::string build_type = "in [";
+    std::for_each(std::begin(inputs_type), std::end(inputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "] out [";
+    std::for_each(std::begin(outputs_type), std::end(outputs_type),
+                  [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+    build_type += "]";
+    auto supported_type_lists = SupportedTypeList(kernel_node);
+    MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
+                            << "] failed! Incompatible data type!\nThe supported data types are "
+                            << supported_type_lists << ", but got " << build_type;
+  }
+  builder->SetKernelType(kernel_type);
+  builder->SetProcessor(kernel::Processor::CUDA);
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
+  SetTensorDeviceInfo(*(builder->Build()), kernel_node);
+}
+}  // namespace gpu
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
similarity index 100%
rename from mindspore/ccsrc/device/gpu/kernel_info_setter.h
rename to mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.h
diff --git a/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc
new file mode 100644
index 0000000000..4605a0eb4e
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "runtime/device/gpu/mpi/mpi_initializer.h" + +#include +#include +#include + +namespace mindspore { +namespace device { +namespace gpu { +MPIInitializer::MPIInitializer() { + int init_flag = 0; + if (MPI_Initialized(&init_flag) != MPI_SUCCESS) { + return; + } + if (init_flag == 0) { + auto ret = MPI_Init(nullptr, nullptr); + if (ret != MPI_SUCCESS) { + return; + } + } + MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_); + MPI_Comm_size(MPI_COMM_WORLD, &rank_size_); +} + +MPIInitializer::~MPIInitializer() { + int finalized_flag = 0; + (void)MPI_Finalized(&finalized_flag); + if (finalized_flag == 0) { + (void)MPI_Finalize(); + } +} + +MPIInitializer &MPIInitializer::GetInstance() { + static MPIInitializer instance; + return instance; +} + +int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; } + +int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; } + +PYBIND11_MODULE(_ms_mpi, mpi_initializer) { + mpi_initializer.doc() = "mindspore mpi python wrapper"; + mpi_initializer.def("get_rank_id", &MPIInitializer::get_rank_id, "get rank id"); + mpi_initializer.def("get_rank_size", &MPIInitializer::get_rank_size, "get rank size"); +} +} // namespace gpu +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/mpi/mpi_initializer.h b/mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h similarity index 100% rename from mindspore/ccsrc/device/gpu/mpi/mpi_initializer.h rename to mindspore/ccsrc/runtime/device/gpu/mpi/mpi_initializer.h diff --git a/mindspore/ccsrc/device/gpu/readme.md b/mindspore/ccsrc/runtime/device/gpu/readme.md similarity index 100% rename from mindspore/ccsrc/device/gpu/readme.md rename to mindspore/ccsrc/runtime/device/gpu/readme.md diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc new file mode 100644 index 0000000000..bb1f7f723e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -0,0 +1,591 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/kernel_adjust.h" + +#include +#include +#include +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" +#include "common/trans.h" +#include "utils/config_manager.h" +#include "common/utils.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "utils/utils.h" +#include "runtime/device/ascend/profiling/profiling_manager.h" +#include "runtime/device/ascend/kernel_select_ascend.h" +#include "runtime/base.h" +#include "runtime/device/ascend/ascend_stream_assign.h" + +namespace mindspore { +namespace device { +using device::ascend::ProfilingUtils; +void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); + std::vector getnext_list; + std::vector other_list; + for (const auto &cnode : origin_cnode_list) { + if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { + getnext_list.emplace_back(cnode); + } else { + other_list.emplace_back(cnode); + } + } + std::vector new_order_list; + new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); + new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); + kernel_graph_ptr->set_execution_order(new_order_list); +} + +bool KernelAdjust::NeedInsertSwitch() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && + ConfigManager::GetInstance().iter_num() > 1); +} + +CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + return send_node_ptr; +} + +CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + +void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { + 
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance(); + resource_manager.ResetResource(); + if (!NeedInsertSwitch()) { + return; + } + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; + ReorderGetNext(kernel_graph_ptr); + std::map switch_loop_input; + CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); + + std::vector *mute_inputs = kernel_graph_ptr->MutableInputs(); + MS_EXCEPTION_IF_NULL(mute_inputs); + mute_inputs->push_back(switch_loop_input[kLoopCountParamName]); + mute_inputs->push_back(switch_loop_input[kEpochParamName]); + mute_inputs->push_back(switch_loop_input[kIterLoopParamName]); + mute_inputs->push_back(switch_loop_input[kZeroParamName]); + mute_inputs->push_back(switch_loop_input[kOneParamName]); + for (const auto &input : kernel_graph_ptr->inputs()) { + MS_EXCEPTION_IF_NULL(input); + if (input->isa()) { + ParameterPtr param_ptr = input->cast(); + if (param_ptr == nullptr) { + MS_EXCEPTION(NotSupportError) << "Cast to parameter point failed !"; + } + } + } + + const std::vector &orders = kernel_graph_ptr->execution_order(); + if (orders.empty()) { + MS_LOG(EXCEPTION) << "graph execution order is empty"; + } + + std::vector exec_order; + std::vector getnext_active_streams; + std::vector fpbp_active_streams; + CNodePtr getnext_cnode; + uint32_t eos_done_event_id = UINT32_MAX; + + // getnext loop process + // getnext loop stream switch op + CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(getnext_switch_app); + uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get()); + exec_order.push_back(getnext_switch_app); + + // getnext op + uint32_t getnext_stream_id = resource_manager.ApplyNewStream(); + size_t i = 0; + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamId(getnext_stream_id, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + getnext_cnode = node; + break; + } + } + + // update getnext loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(getnext_stream_id), getnext_switch_app); + + // getnext loop fpbp start send + uint32_t fpbp_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get()); + exec_order.push_back(fpbp_start_send); + + if (eos_mode) { + // getnext loop eos start send + uint32_t eos_start_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, eos_start_event_id); + AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get()); + exec_order.push_back(eos_start_send); + + // End Of Sequence loop process + // eos loop stream switch + CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(eos_switch_app); + uint32_t eos_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_switch_stream_id, eos_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), eos_switch_app); + exec_order.push_back(eos_switch_app); + + // eos loop eos start recv + CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id); + uint32_t 
eos_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get()); + exec_order.push_back(eos_start_recv); + + // update eos loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(eos_stream_id), eos_switch_app); + + // EndOfSequence op + CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode); + MS_EXCEPTION_IF_NULL(end_of_sequence_op); + AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get()); + exec_order.push_back(end_of_sequence_op); + + // eos loop eos done send + eos_done_event_id = resource_manager.ApplyNewEvent(); + CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get()); + exec_order.push_back(eos_done_send); + + // eos loop stream active + fpbp_active_streams.push_back(eos_switch_stream_id); + } + + // fpbp loop process + // fpbp loop stream switch + CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(fpbp_switch_app); + uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get()); + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), fpbp_switch_app); + exec_order.push_back(fpbp_switch_app); + + // fpbp loop fpbp start recv + CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id); + uint32_t fpbp_stream_id = resource_manager.ApplyNewStream(); + AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get()); + exec_order.push_back(fpbp_start_recv); + + // update fpbp loop stream switch true_branch_stream attr + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(fpbp_stream_id), fpbp_switch_app); + + // fpbp loop AssignAdd + CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get()); + exec_order.push_back(assign_add_one); + + // fpbp memcpy + std::vector memcpy_list; + std::vector other_list; + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + memcpy_list.emplace_back(cur_cnode); + } else { + other_list.emplace_back(cur_cnode); + } + } + + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + + // fpbp loop eos done recv + if (eos_mode) { + CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id); + AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get()); + exec_order.push_back(eos_done_recv); + } + + // stream active to activate getnext loop + CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(getnext_active_app); + getnext_active_streams.push_back(getnext_switch_stream_id); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(getnext_active_streams), + getnext_active_app); + exec_order.push_back(getnext_active_app); + + // fpbp loop other ops + (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order)); + + // stream active to activate fpbp loop and eos loop + CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(fpbp_active_app); + fpbp_active_streams.push_back(fpbp_switch_stream_id); + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, 
MakeValue>(fpbp_active_streams), fpbp_active_app); + exec_order.push_back(fpbp_active_app); + + kernel_graph_ptr->set_execution_order(exec_order); +} + +void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, + std::map *switch_loop_input) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(switch_loop_input); + std::vector shp = {1}; + tensor::TensorPtr tensor_ptr = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(tensor_ptr); + mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); + if (paremeter_abstract_ptr == nullptr) { + MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!"; + } + + ParameterPtr loop_count = std::make_shared(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(loop_count); + loop_count->set_name(kLoopCountParamName); + loop_count->set_abstract(paremeter_abstract_ptr); + ParameterPtr loop_count_new = kernel_graph_ptr->NewParameter(loop_count); + + (*switch_loop_input)[kLoopCountParamName] = loop_count_new; + + ParameterPtr iter_loop = std::make_shared(kernel_graph_ptr); + iter_loop->set_name(kIterLoopParamName); + iter_loop->set_abstract(paremeter_abstract_ptr); + ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop); + (*switch_loop_input)[kIterLoopParamName] = iter_loop_new; + + ParameterPtr zero = std::make_shared(kernel_graph_ptr); + zero->set_name(kZeroParamName); + zero->set_abstract(paremeter_abstract_ptr); + ParameterPtr zero_new = kernel_graph_ptr->NewParameter(zero); + (*switch_loop_input)[kZeroParamName] = zero_new; + + ParameterPtr one = std::make_shared(kernel_graph_ptr); + one->set_name(kOneParamName); + one->set_abstract(paremeter_abstract_ptr); + ParameterPtr one_new = kernel_graph_ptr->NewParameter(one); + (*switch_loop_input)[kOneParamName] = one_new; + + ParameterPtr epoch = std::make_shared(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(epoch); + epoch->set_name(kEpochParamName); + epoch->set_abstract(paremeter_abstract_ptr); + ParameterPtr epoch_new = kernel_graph_ptr->NewParameter(epoch); + (*switch_loop_input)[kEpochParamName] = epoch_new; +} + +kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder( + const std::vector &formats, const std::vector &type_ids) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat(formats); + selected_kernel_builder.SetInputsDeviceType(type_ids); + + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICORE); + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + return selected_kernel_builder; +} + +CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( + {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + auto typeNone_abstract = std::make_shared(); + auto stream_switch = std::make_shared(kStreamSwitchOpName); + std::vector inputs; + inputs.push_back(NewValueNode(stream_switch)); + inputs.push_back(switch_loop_input.at(kLoopCountParamName)); + inputs.push_back(switch_loop_input.at(kIterLoopParamName)); + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(stream_switch_app); + 
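+  // The generated switch keeps its true-branch stream active while
+  // loop_count < iter_loop; the RT_LESS condition and the compared data type
+  // are attached as attributes just below.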
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get()); + stream_switch_app->set_abstract(typeNone_abstract); + // set attr: cond_ RT_LESS + int condition = static_cast(RT_LESS); + ValuePtr cond = MakeValue(condition); + AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); + // set attr:data_type + int data_type = static_cast(RT_SWITCH_INT64); + ValuePtr dt = MakeValue(data_type); + AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); + // set distinction label and graph id + return stream_switch_app; +} + +CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( + {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); + abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); + auto stream_active_others = std::make_shared(kStreamActiveOpName); + std::vector inputs; + inputs.push_back(NewValueNode(stream_active_others)); + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(stream_active_others_app); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); + stream_active_others_app->set_abstract(typeNone_abstract); + return stream_active_others_app; +} + +CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &node, size_t output_idx) { + auto idx = NewValueNode(SizeToInt(output_idx)); + MS_EXCEPTION_IF_NULL(idx); + auto imm = std::make_shared(SizeToInt(output_idx)); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); + MS_EXCEPTION_IF_NULL(tuple_getitem); + tuple_getitem->set_scope(node->scope()); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx); + AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); + return tuple_getitem; +} + +CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8}); + + selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE); + selected_kernel_builder.SetProcessor(kernel::Processor::AICPU); + selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL); + + selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); + selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8}); + // EndOfSequence + auto end_of_sequence = std::make_shared(kEndOfSequence); + std::vector inputs; + inputs.push_back(NewValueNode(end_of_sequence)); + // GetNext output 0 is EndOfSequence's input + auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0); + inputs.push_back(tuple_get_item); + CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(end_of_sequence_node); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get()); + std::vector input_names = {"x"}; + ValuePtr input_names_v = 
MakeValue(input_names);
+  AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node);
+  std::vector<std::string> output_names = {"y"};
+  ValuePtr output_names_v = MakeValue(output_names);
+  AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node);
+  end_of_sequence_node->set_abstract(tuple_get_item->abstract());
+  return end_of_sequence_node;
+}
+
+CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
+  const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
+  const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
+  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
+    {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
+  selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
+  selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
+  // AssignAdd
+  auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
+  std::vector<AnfNodePtr> inputs;
+  inputs.push_back(NewValueNode(assign_add));
+  inputs.push_back(switch_loop_input.at(kLoopCountParamName));
+  inputs.push_back(switch_loop_input.at(kOneParamName));
+  CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
+  MS_EXCEPTION_IF_NULL(assign_add_one);
+  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get());
+  std::vector<std::string> input_names = {"ref", "value"};
+  std::vector<std::string> output_names = {"output"};
+  ValuePtr input_names_v = MakeValue(input_names);
+  ValuePtr output_names_v = MakeValue(output_names);
+  AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
+  AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
+  selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
+  MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName));
+  assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract());
+  return assign_add_one;
+}
+
+bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
+  if (!NeedInsertSwitch()) {
+    return true;
+  }
+  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
+  auto input_nodes = kernel_graph_ptr->inputs();
+  std::vector<tensor::TensorPtr> inputs;
+  LoadSwitchInputs(&inputs);
+  std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
+  kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
+  size_t input_ctrl_size = inputs.size();
+  // input_nodes holds the user inputs followed by the five ctrl parameters
+  // appended at the back: loop_count, epoch, iter_loop, zero and one.
+  // Sync each ctrl tensor to its matching ctrl parameter below.
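+  // Illustration (hypothetical sizes): with 3 user inputs and these 5 ctrl
+  // tensors, input_ctrl_size == 5 and ctrl tensor i is synced into
+  // input_nodes[input_nodes.size() - 5 + i], i.e. input_nodes[3 + i].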
+ for (size_t i = 0; i < inputs.size(); ++i) { + auto tensor = inputs[i]; + size_t deal_index = input_nodes.size() - input_ctrl_size + i; + if (deal_index >= input_nodes.size()) { + MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range"; + } + auto input_node = input_nodes[deal_index]; + bool need_sync = false; + MS_EXCEPTION_IF_NULL(input_node); + if (input_node->isa()) { + auto pk_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(tensor); + MS_EXCEPTION_IF_NULL(pk_node); + if (tensor->is_dirty() || !pk_node->has_default()) { + need_sync = true; + } + } + if (need_sync) { + auto pk_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + MS_EXCEPTION_IF_NULL(device_address); + tensor->set_device_address(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(INFO) << "SyncHostToDevice failed."; + return false; + } + } + tensor->set_dirty(false); + } + return true; +} + +void KernelAdjust::LoadSwitchInputs(std::vector *inputs) { + MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; + MS_EXCEPTION_IF_NULL(inputs); + std::vector shp = {1}; + tensor::TensorPtr loop_count_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(loop_count_tensor); + int32_t *val = nullptr; + val = static_cast(loop_count_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(loop_count_tensor); + + // Epoch in device + tensor::TensorPtr epoch_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(epoch_tensor); + val = static_cast(epoch_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(epoch_tensor); + + tensor::TensorPtr iter_loop_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(iter_loop_tensor); + val = static_cast(iter_loop_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num())); + MS_LOG(INFO) << "iter_loop_tensor = " << *val; + inputs->push_back(iter_loop_tensor); + + tensor::TensorPtr zero_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(zero_tensor); + val = static_cast(zero_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 0; + inputs->push_back(zero_tensor); + + tensor::TensorPtr one_tensor = std::make_shared(kInt32->type_id(), shp); + MS_EXCEPTION_IF_NULL(one_tensor); + val = static_cast(one_tensor->data_c()); + MS_EXCEPTION_IF_NULL(val); + *val = 1; + inputs->push_back(one_tensor); + + MS_LOG(INFO) << "---------------- LoadSwitchInputs End--"; +} + +void KernelAdjust::Profiling(NotNull kernel_graph_ptr) { + if (!ascend::ProfilingManager::GetInstance().IsProfiling()) { + MS_LOG(INFO) << "No need to profiling"; + return; + } + ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GetProfilingTraceFromEnv(kernel_graph_ptr); + if (!profiling_trace_info.IsValid()) { + MS_LOG(WARNING) << "[profiling] no profiling node found!"; + return; + } + InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr); +} + +void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, + NotNull kernel_graph_ptr) { + MS_LOG(INFO) << "[profiling] Insert profiling kernel start"; + if (!profiling_trace_info.IsValid()) { + MS_LOG(WARNING) << "Profiling trace point not found"; + return; + } + std::vector new_cnode_list; + std::vector cnode_ptr_list = 
kernel_graph_ptr->execution_order(); + if (cnode_ptr_list.empty()) { + MS_LOG(ERROR) << "No CNode in graph"; + return; + } + for (const auto &cnode_ptr : cnode_ptr_list) { + ProfilingUtils::ProfilingTraceFpStart(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + new_cnode_list.emplace_back(cnode_ptr); + ProfilingUtils::ProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + ProfilingUtils::ProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + ProfilingUtils::ProfilingTraceEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr, NOT_NULL(&new_cnode_list)); + } + kernel_graph_ptr->set_execution_order(new_cnode_list); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.h b/mindspore/ccsrc/runtime/device/kernel_adjust.h new file mode 100644 index 0000000000..dbd6f226af --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ + +#include +#include +#include +#include +#include +#include "ir/anf.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/session/session_context.h" +#include "ir/tensor.h" +#include "runtime/device/ascend/profiling/profiling_utils.h" +#include "runtime/device/kernel_info.h" + +using mindspore::device::ascend::ProfilingTraceInfo; +using mindspore::device::ascend::ProfilingUtils; +namespace mindspore { +constexpr auto kLoopCountParamName = "loop_count"; +constexpr auto kIterLoopParamName = "iter_loop"; +constexpr auto kZeroParamName = "zero"; +constexpr auto kOneParamName = "one"; +constexpr auto kEpochParamName = "loop_epoch"; +constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; +constexpr uint32_t kSecondStreamSwitchLabel = 2; + +namespace device { +class KernelAdjust { + public: + static KernelAdjust &GetInstance() { + static KernelAdjust instance; + return instance; + } + + void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); + bool StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr); + void Profiling(NotNull kernel_graph_ptr); + static bool NeedInsertSwitch(); + CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); + + private: + KernelAdjust() = default; + ~KernelAdjust() = default; + + void ReorderGetNext(const std::shared_ptr &kernel_graph_ptr); + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, + std::map *switch_loop_input); + CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, + const std::map 
&switch_loop_input); + CNodePtr CreatTupleGetItemNode(const std::shared_ptr &kernel_graph_ptr, const CNodePtr &node, + size_t output_idx); + CNodePtr CreateEndOfSequenceOP(const std::shared_ptr &kernel_graph_ptr, + const CNodePtr &getnext_cnode); + CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, + const std::map &switch_loop_input); + kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, + const std::vector &type_ids); + void LoadSwitchInputs(std::vector *inputs); + void InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info, + NotNull kernel_graph_ptr); +}; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_ADJUST_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_info.cc b/mindspore/ccsrc/runtime/device/kernel_info.cc new file mode 100644 index 0000000000..692532e70b --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_info.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_info.h" + +namespace mindspore { +namespace device { +const kernel::KernelBuildInfo *KernelInfo::select_kernel_build_info() const { return select_kernel_build_info_.get(); } + +kernel::KernelBuildInfoPtr KernelInfo::GetMutableSelectKernelBuildInfo() const { return select_kernel_build_info_; } + +const DeviceAddress *KernelInfo::GetOutputAddr(size_t index) const { + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return output_address_list_[index].get(); +} + +DeviceAddressPtr KernelInfo::GetMutableOutputAddr(size_t index) const { + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return output_address_list_[index]; +} + +bool KernelInfo::OutputAddrExist(size_t index) const { + if (index >= output_address_list_.size()) { + return false; + } + return output_address_list_[index] != nullptr; +} + +bool KernelInfo::SetOutputAddr(const DeviceAddressPtr &output_address, size_t index) { + // parameter and valuenode + if (kernel_mod_ == nullptr && index >= output_address_list_.size()) { + for (size_t i = output_address_list_.size(); i <= index; i++) { + output_address_list_.emplace_back(nullptr); + } + } else if (output_address_list_.empty()) { + // set cnode + for (size_t i = 0; i < kernel_mod_->GetOutputSizeList().size(); i++) { + output_address_list_.emplace_back(nullptr); + } + } + if (index >= output_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return false; + } + output_address_list_[index] = output_address; + return true; +} + +DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { + if (index >= workspace_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return workspace_address_list_[index].get(); +} 
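+
+// The workspace list is sized lazily on the first SetWorkspaceAddr call: a
+// single null slot for parameters/value nodes (no kernel_mod_), otherwise one
+// slot per entry in kernel_mod_->GetWorkspaceSizeList().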
+bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
+  if (workspace_address_list_.empty()) {
+    // parameter and valuenode
+    if (kernel_mod_ == nullptr) {
+      workspace_address_list_.emplace_back(nullptr);
+    } else {
+      // set cnode
+      for (size_t i = 0; i < kernel_mod_->GetWorkspaceSizeList().size(); i++) {
+        workspace_address_list_.emplace_back(nullptr);
+      }
+    }
+  }
+  if (index >= workspace_address_list_.size()) {
+    MS_LOG(ERROR) << "Index [" << index << "] out of range";
+    return false;
+  }
+  workspace_address_list_[index] = output_address;
+  return true;
+}
+
+void KernelInfo::set_kernel_mod(const kernel::KernelModPtr &kernel_mod) { kernel_mod_ = kernel_mod; }
+
+kernel::KernelMod *KernelInfo::MutableKernelMod() const { return kernel_mod_.get(); }
+
+const kernel::KernelMod *KernelInfo::kernel_mod() const { return kernel_mod_.get(); }
+
+bool KernelInfo::operator==(const KernelInfo &other) const {
+  if (stream_id_ != other.stream_id_ || stream_distinction_label_ != other.stream_distinction_label_ ||
+      graph_id_ != other.graph_id_) {
+    return false;
+  }
+  if ((select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ == nullptr) ||
+      (select_kernel_build_info_ == nullptr && other.select_kernel_build_info_ != nullptr)) {
+    return false;
+  }
+  if (select_kernel_build_info_ != nullptr && other.select_kernel_build_info_ != nullptr) {
+    if (!(*select_kernel_build_info_ == *(other.select_kernel_build_info_))) {
+      return false;
+    }
+  }
+  // Currently we only check whether both kernel_mod_ are initialized or uninitialized.
+  if ((kernel_mod_ == nullptr && other.kernel_mod_ != nullptr) ||
+      (kernel_mod_ != nullptr && other.kernel_mod_ == nullptr)) {
+    return false;
+  }
+  // Currently we only check that output_address_list_ and workspace_address_list_
+  // have equal sizes; an element-wise comparison can be added in the future.
+  if (output_address_list_.size() != other.output_address_list_.size() ||
+      workspace_address_list_.size() != other.workspace_address_list_.size()) {
+    return false;
+  }
+  return true;
+}
+}  // namespace device
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h
new file mode 100644
index 0000000000..b8ab985c86
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/kernel_info.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_DEVICE_KERNEL_INFO_H_
+#define MINDSPORE_DEVICE_KERNEL_INFO_H_
+
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/kernel_build_info.h"
+#include "runtime/device/ascend/ascend_device_address.h"
+#include "backend/kernel_compiler/kernel.h"
+
+namespace mindspore {
+const uint32_t kInvalidGraphId = UINT32_MAX;
+const uint32_t kInvalidDistincLabel = UINT32_MAX;
+namespace device {
+class KernelInfo {
+ public:
+  KernelInfo() {
+    kernel_mod_ = nullptr;
+    is_feature_map_ = false;
+    select_kernel_build_info_ = nullptr;
+    output_address_list_ = {};
+    workspace_address_list_ = {};
+    stream_id_ = UINT32_MAX;
+    stream_distinction_label_ = kInvalidDistincLabel;
+    graph_id_ = kInvalidGraphId;
+  }
+  virtual ~KernelInfo() = default;
+
+  const kernel::KernelBuildInfo *select_kernel_build_info() const;
+  kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const;
+  void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
+    select_kernel_build_info_ = select_kernel_build_info;
+  }
+  void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; }
+  const DeviceAddress *GetOutputAddr(size_t index) const;
+  DeviceAddressPtr GetMutableOutputAddr(size_t index) const;
+  bool OutputAddrExist(size_t index) const;
+  bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
+  DeviceAddress *GetWorkspaceAddr(size_t index) const;
+  bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
+  void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
+  kernel::KernelMod *MutableKernelMod() const;
+  const kernel::KernelMod *kernel_mod() const;
+  uint32_t stream_id() const { return stream_id_; }
+  void set_stream_id(uint32_t stream_id) { stream_id_ = stream_id; }
+  uint32_t stream_distinction_label() const { return stream_distinction_label_; }
+  void set_stream_distinction_label(uint32_t stream_distinction_label) {
+    stream_distinction_label_ = stream_distinction_label;
+  }
+  void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; }
+  uint32_t graph_id() const { return graph_id_; }
+  bool operator==(const KernelInfo &other) const;
+  bool is_feature_map() const { return is_feature_map_; }
+
+ private:
+  bool is_feature_map_;
+  kernel::KernelBuildInfoPtr select_kernel_build_info_;
+  std::vector<std::shared_ptr<DeviceAddress>> output_address_list_;
+  std::vector<std::shared_ptr<DeviceAddress>> workspace_address_list_;
+  kernel::KernelModPtr kernel_mod_;
+  // stream_id_ is the index into the stream object vector
+  uint32_t stream_id_;
+  // stream_distinction_label_ is used to mark ops that run on different streams
+  uint32_t stream_distinction_label_;
+  // records which graph the node belongs to
+  uint32_t graph_id_;
+};
+}  // namespace device
+}  // namespace mindspore
+#endif  // MINDSPORE_DEVICE_KERNEL_INFO_H_
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
new file mode 100644
index 0000000000..49fddcae45
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -0,0 +1,772 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/kernel_runtime.h" +#include +#include +#include +#include +#include "common/utils.h" +#include "common/trans.h" +#include "utils/utils.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/kernel_compiler/oplib/oplib.h" +#include "ir/value.h" +using mindspore::kernel::Address; +using mindspore::kernel::AddressPtr; + +namespace mindspore { +namespace device { +KernelRuntime::~KernelRuntime() { +#ifdef ENABLE_DUMP_E2E + dump_conf_ptr_ = nullptr; +#endif +} + +bool KernelRuntime::Run(session::KernelGraph *graph) { + bool ret = false; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); +#if defined(_WIN32) || defined(_WIN64) + auto start_time = std::chrono::steady_clock::now(); +#else + struct timeval start_time, end_time; + (void)gettimeofday(&start_time, nullptr); +#endif + bool is_task_sink = context_ptr->enable_task_sink(); + if (is_task_sink) { + ret = RunTask(graph); + } else { + ret = LaunchKernel(graph); + } +#if defined(_WIN32) || defined(_WIN64) + auto end_time = std::chrono::steady_clock::now(); + std::chrono::duration> cost = end_time - start_time; + MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us"; +#else + (void)gettimeofday(&end_time, nullptr); + const uint64_t kUSecondInSecond = 1000000; + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(INFO) << "Call MS Run Success in " << cost << " us"; +#endif + return ret; +} + +// for D to impl +bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::GenTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +bool KernelRuntime::LoadTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +// for D to impl +bool KernelRuntime::RunTask(const session::KernelGraph *graph) { + if (graph != nullptr) { + return true; + } + return false; +} + +bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::OutputAddrExist(kernel, index)) { + return true; + } + return false; +} + +size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { + MS_EXCEPTION_IF_NULL(node); + if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { + MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" + << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; + } + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + } + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); + std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); + 
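+  // A scalar output reported with a non-default format is first padded to 4D
+  // and translated to the device layout before its byte size is accumulated below.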
auto format = AnfAlgo::GetOutputFormat(node, output_index); + if (shape.empty() && format != kOpFormat_DEFAULT) { + shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); + shape = trans::TransShapeToDevice(shape, format); + } + // scalar's output shape is a empty vector + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + return tensor_size; +} + +void KernelRuntime::AssignMemory(session::KernelGraph *graph) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(mem_manager_); + mem_manager_->ResetDynamicMemory(); + AssignStaticMemory(graph); + AssignDynamicMemory(graph); + UpdateRefNodeOutputMem(graph); +} + +void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, + session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + RunOpAssignInputMemory(input_tensors, graph); + AssignStaticMemoryValueNode(graph); + for (const auto &cnode : graph->execution_order()) { + RunOpAssignOutputMemory(cnode); + RunOpAssignWorkSpaceMemory(cnode); + } + UpdateRefNodeOutputMem(graph); +} + +void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + // clear input parameter memory resource + for (const auto &input_node : graph->inputs()) { + MS_EXCEPTION_IF_NULL(input_node); + AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get()); + } + // clear input value node memory resource + for (const auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get()); + } + for (const auto &cnode : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(cnode); + // clear output memory resource + for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + AnfAlgo::SetOutputAddr(nullptr, index, cnode.get()); + } + // clear workspace memory resource + auto kernel_mod = AnfAlgo::GetKernelMod(cnode); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); + for (size_t index = 0; index < workspace_lists.size(); ++index) { + AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get()); + } + } +} + +void KernelRuntime::AssignStaticMemory(session::KernelGraph *graph) { + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); + AssignStaticMemoryOutput(graph); +} + +void KernelRuntime::RunOpAssignInputMemory(const std::vector &input_tensors, + const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (input_tensors.size() != graph->inputs().size()) { + MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() + << " should be equal to graph input parameter size " << graph->inputs().size(); + } + + for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) { + auto item = graph->inputs()[input_index]; + MS_EXCEPTION_IF_NULL(item); + if (!item->isa()) { + continue; + } + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + MS_EXCEPTION_IF_NULL(input_tensors[input_index]); + if (input_tensors[input_index]->device_address().get() != nullptr) { + AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get()); + continue; + } + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(item, index); + } + auto tensor_size = 
CountNodeDeviceMemorySize(item, index); + auto device_address = + CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + MS_EXCEPTION_IF_NULL(device_address); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetOutputAddr(device_address, index, item.get()); + } + } +} + +void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + return; + } + + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i)); + MS_EXCEPTION_IF_NULL(device_address); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + } +} + +void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (kernel->isa()) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_lists = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_lists.size(); ++i) { + auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown); + MS_EXCEPTION_IF_NULL(device_address); + auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]); + if (!ret) { + MS_LOG(EXCEPTION) << "Malloc device memory failed."; + } + AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); + } + } +} + +void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto graph_inputs = graph->inputs(); + auto graph_valid_input = graph->valid_inputs(); + std::vector need_alloc_nodes; + for (size_t i = 0; i < graph_inputs.size(); ++i) { + auto item = graph_inputs[i]; + MS_EXCEPTION_IF_NULL(item); + if (i < graph_valid_input.size() && !graph_valid_input[i]) { + continue; + } + + if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { + auto outs = AnfAlgo::GetAllOutput(item); + for (auto &out : outs) { + MS_EXCEPTION_IF_NULL(out); + if (!out->isa()) { + continue; + } + if (NodeOutputDeviceAddressExist(out, 0)) { + continue; + } + need_alloc_nodes.push_back(out); + } + } + if (!item->isa()) { + continue; + } + if (NodeOutputDeviceAddressExist(item, 0)) { + continue; + } + need_alloc_nodes.push_back(item); + } + + for (auto &item : need_alloc_nodes) { + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + TypeId output_type_id = 
AnfAlgo::GetOutputDeviceDataType(item, index); + // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown + if (output_type_id == kTypeUnknown) { + MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; + output_type_id = AnfAlgo::GetOutputInferDataType(item, index); + } + auto tensor_size = CountNodeDeviceMemorySize(item, index); + auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); + auto address = CreateDeviceAddress(ptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + AnfAlgo::SetOutputAddr(address, index, item.get()); + } + } +} + +void KernelRuntime::AssignStaticMemoryOutput(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); + std::vector non_communication_op; + // Assign Communicate Op Memory firstly. + for (const auto &node : nodes) { + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + MS_EXCEPTION_IF_NULL(item_with_index.first); + if (!item_with_index.first->isa() || !AnfAlgo::IsRealKernel(item_with_index.first)) { + continue; + } + graph->AddFinalOutputKernel(item_with_index.first); + if (AnfAlgo::IsCommunicationOp(item_with_index.first)) { + AssignCommunicationNodeMem(kStaticMem, item_with_index.first); + } else { + non_communication_op.emplace_back(item_with_index); + } + } + + for (const auto &item_with_index : non_communication_op) { + AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second)); + } +} + +void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel has no output size."; + continue; + } + for (size_t i = 0; i < output_sizes.size(); ++i) { + session::AnfWithOutIndex out_pair(kernel, i); + if (graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = graph->GetRefCorrespondOutput(out_pair); + MS_EXCEPTION_IF_NULL(origin_pair.first); + auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second); + MS_EXCEPTION_IF_NULL(origin_node_output_addr); + auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i); + if (origin_node_output_addr.get() != cur_node_output_addr.get()) { + MS_LOG(INFO) << "REF address is not same, ref node output need address update"; + MS_LOG(INFO) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is " + << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i; + AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get()); + } + } + } + } +} + +void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) { + AssignCommunicationNodeInputMem(node); + AssignCommunicationNodeOutputMem(flag, node); +} + +void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output 
size."; + return; + } + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + size_t total_size = 0; + size_t output_index = 0; + std::vector align_size_list; + for (uint64_t mem_size : output_sizes) { + if (AnfAlgo::OutputAddrExist(node, output_index++)) { + MS_LOG(INFO) << "communication op addr exist"; + continue; + } + if (context_ptr->enable_hccl()) { + mem_size = mem_manager_->GetCommonAlignSize(mem_size); + } + total_size += mem_size; + align_size_list.emplace_back(mem_size); + } + uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size); + for (size_t j = 0; j < align_size_list.size(); ++j) { + std::string output_format = AnfAlgo::GetOutputFormat(node, j); + auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j); + auto address = CreateDeviceAddress(output_ptr, output_sizes[j], output_format, output_type); + MS_EXCEPTION_IF_NULL(address); + if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { + address->UpdateCommunicationAddress(); + } + AnfAlgo::SetOutputAddr(address, j, node.get()); + output_ptr += align_size_list[j]; + } +} + +DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.size() <= index) { + MS_LOG(EXCEPTION) << "Previous node output size < node index"; + } + std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index); + auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index); + auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type); + AnfAlgo::SetOutputAddr(address, index, anf_node.get()); + return address; +} + +void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + size_t total_size = 0; + std::vector> addr_size; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); + auto input_node = input_node_with_index.first; + DeviceAddressPtr address = nullptr; + if (input_node->isa()) { + address = PreAssignCNodeMemory(input_node, input_node_with_index.second); + } else { + MS_LOG(EXCEPTION) << "Communication node inputs only support CNode"; + } + MS_EXCEPTION_IF_NULL(address); + auto mem_size = mem_manager_->GetCommonAlignSize(address->size()); + total_size += mem_size; + addr_size.emplace_back(address.get(), mem_size); + } + uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size); + for (const auto &iter : addr_size) { + MS_EXCEPTION_IF_NULL(iter.first); + iter.first->set_ptr(input_ptr); + input_ptr += iter.second; + } +} + +void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + MS_LOG(INFO) << "GetNext disable mem_reuse"; + flag = kDynamicMem; + } + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; + 
+void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + if (AnfAlgo::IsGetNext(NOT_NULL(node)) && flag == kReuseDynamicMem) { + MS_LOG(INFO) << "GetNext disable mem_reuse"; + flag = kDynamicMem; + } + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + if (output_sizes.empty()) { + MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size."; + return; + } + for (size_t i = 0; i < output_sizes.size(); ++i) { + if ((kGetAllOuts != index) && (SizeToInt(i) != index)) { + continue; + } + if (NodeOutputDeviceAddressExist(node, i)) { + MS_LOG(INFO) << "Already malloc index:" << i; + continue; + } + auto ptr = mem_manager_->MallocOutputMem(node, i, flag, output_sizes[i]); + if (ptr == nullptr) { + // Reused pointer, no allocation needed. + continue; + } + std::string output_format = AnfAlgo::GetOutputFormat(node, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); + auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type); + MS_EXCEPTION_IF_NULL(device_address); + device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); + if (AnfAlgo::IsCommunicationOp(node) && context_ptr->enable_hccl()) { + device_address->UpdateCommunicationAddress(); + } + AnfAlgo::SetOutputAddr(device_address, i, node.get()); + } +} + +void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, + size_t output_idx) { + MS_EXCEPTION_IF_NULL(value_node); + MS_EXCEPTION_IF_NULL(node_value); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto tensor = node_value->cast<TensorPtr>(); + if (tensor == nullptr) { + MS_LOG(WARNING) << "Tensor is null"; + return; + } + size_t tensor_size = tensor->data().nbytes(); + auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } + auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); + DeviceAddressPtr address = nullptr; + if (ms_context->enable_pynative_infer()) { + address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id); + MS_EXCEPTION_IF_NULL(address); + if (!mem_manager_->MallocMemFromMemPool(address, node_size)) { + MS_LOG(EXCEPTION) << "Malloc value node device memory failed!"; + } + } else { + auto ptr = mem_manager_->MallocMem(kStaticMem, node_size); + address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id); + MS_EXCEPTION_IF_NULL(address); + } + AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); + if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), + tensor->data_c())) { + MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail! " << value_node->DebugString() << ", node format is " + << AnfAlgo::GetOutputFormat(value_node, output_idx) << ", node dtype is " + << AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } +} +
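+// Uploads value-node constants (tensors and strings) to device memory: pool memory under PyNative inference, static memory otherwise.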
+void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + if (NodeOutputDeviceAddressExist(value_node, 0)) { + MS_LOG(INFO) << "value_node[" << value_node->DebugString() << "] address already exists"; + continue; + } + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (node_value->isa<Tensor>()) { + AssignValueNodeTensor(value_node, node_value, 0); + } else if (node_value->isa<StringImm>()) { + auto value = GetValue<std::string>(node_value); + size_t tensor_size = value.size(); + DeviceAddressPtr address = nullptr; + if (ms_context->enable_pynative_infer()) { + address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) { + MS_LOG(EXCEPTION) << "Malloc value node device memory failed!"; + } + } else { + auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size); + address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + } + AnfAlgo::SetOutputAddr(address, 0, value_node.get()); + std::vector<int> shape = {1, SizeToInt(tensor_size)}; + if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) { + MS_LOG(EXCEPTION) << "ValueNode SyncHostToDevice fail!"; + } + } + } +} + +void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + bool is_enable_mem_reuse = context_ptr->enable_mem_reuse(); + auto mem_flag = kDynamicMem; + if (is_enable_mem_reuse) { + mem_manager_->MallocReusedDynamicMem(graph); + mem_flag = kReuseDynamicMem; + } + auto &execution_nodes = graph->execution_order(); + std::vector<CNodePtr> compute_nodes; + // Allocate communication nodes first. + for (auto &node : execution_nodes) { + if (AnfAlgo::IsCommunicationOp(node)) { + // AssignCommunicationNodeMem skips outputs whose memory is already allocated. + AssignCommunicationNodeMem(mem_flag, node); + } else { + compute_nodes.emplace_back(node); + } + } + + // Then allocate compute nodes. + for (auto &node : compute_nodes) { + AssignNodeOutputMem(mem_flag, node, kGetAllOuts); + AssignWorkSpaceMem(mem_flag, node); + } +} + +void KernelRuntime::AssignWorkSpaceMem(int flag, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(mem_manager_); + auto kernel_mod = AnfAlgo::GetKernelMod(node); + MS_EXCEPTION_IF_NULL(kernel_mod); + size_t index = 0; + for (auto &size : kernel_mod->GetWorkspaceSizeList()) { + auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, flag, size); + AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get()); + index++; + } +} +
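+// Gathers the previously assigned input, workspace, and output device addresses of a kernel into the argument lists consumed by KernelMod::Launch.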
+void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces, + AddressPtrList *kernel_outputs) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(kernel_inputs); + MS_EXCEPTION_IF_NULL(kernel_workspaces); + MS_EXCEPTION_IF_NULL(kernel_outputs); + auto cnode = kernel->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) { + return GenAddrCleanLaunchArgs(cnode, kernel_inputs); + } + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); + auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); + MS_EXCEPTION_IF_NULL(device_address); + kernel::AddressPtr input = std::make_shared<kernel::Address>(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->emplace_back(input); + } + + for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetOutputAddr(kernel, i); + kernel::AddressPtr output = std::make_shared<kernel::Address>(); + MS_EXCEPTION_IF_NULL(output); + output->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(output->addr); + output->size = device_address->size_; + kernel_outputs->emplace_back(output); + } + + for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i); + kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_workspaces->emplace_back(workspace); + } +} + +void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs) { + if (cnode->inputs().size() != 2) { + MS_LOG(EXCEPTION) << "Atomic addr clean node must have exactly 2 inputs."; + } + MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); + auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>(); + MS_EXCEPTION_IF_NULL(pre_node); + // Set the output addresses to be cleaned. + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs); + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); + kernel::AddressPtr input = std::make_shared<kernel::Address>(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->emplace_back(input); + } + MS_LOG(INFO) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); + } + // Set the workspace addresses to be cleaned. + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto clean_workspaces_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs); + for (const auto &index : clean_workspaces_indexs) { + auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); + kernel::AddressPtr workspace = std::make_shared<kernel::Address>(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_inputs->emplace_back(workspace); + } + } +} + +bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { + auto &kernels = graph.execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + if (!ret) { + MS_LOG(ERROR) << 
"Launch kernel failed."; + return false; + } + } + return true; +} + +bool KernelRuntime::LaunchKernel(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + if (!LaunchKernelMod(*graph)) { + MS_LOG(ERROR) << "LaunchKernelMod failed!"; + return false; + } + return true; +} + +void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { + MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource"; +} + +#ifdef ENABLE_DUMP_E2E +bool KernelRuntime::SetDumpConf() { + dump_conf_ptr_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(dump_conf_ptr_); + bool ret = dump_conf_ptr_->SetDumpConfFromJsonFile(); + return ret; +} + +DumpConfPtr KernelRuntime::GetDumpConf() { return dump_conf_ptr_; } +#endif +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h new file mode 100644 index 0000000000..8320355b82 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -0,0 +1,122 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#define MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ +#include +#include +#include +#include + +#include "runtime/device/device_address.h" +#include "ir/tensor.h" +#include "predict/generator/utils/ir_model_util.h" +#ifdef ENABLE_DUMP_E2E +#include "debug/e2e_dump.h" +#endif +#ifdef ENABLE_DEBUGGER +#include "debug/debugger/debugger.h" +#endif +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/kernel.h" +#include "utils/context/ms_context.h" +#include "runtime/device/memory_manager.h" + +using mindspore::tensor::Tensor; +using std::vector; +using TensorPtr = std::shared_ptr; +using mindspore::kernel::AddressPtr; +using AddressPtrList = std::vector; + +namespace mindspore { +#ifndef ENABLE_DEBUGGER +class Debugger; +#endif +namespace device { +class KernelRuntime { + public: + KernelRuntime() = default; + virtual ~KernelRuntime(); + virtual bool Init() = 0; + virtual void AssignMemory(session::KernelGraph *graph); + void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); + void RunOpClearMemory(const session::KernelGraph *graph); + virtual bool Run(session::KernelGraph *graph); + virtual bool DumpData(session::KernelGraph *graph); + virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); + virtual bool RunTask(const session::KernelGraph *graph); + virtual bool GenTask(const session::KernelGraph *graph); + bool LaunchKernel(const session::KernelGraph *graph); + virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); + virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); + virtual void ClearGraphRuntimeResource(uint32_t graph_id); + virtual bool SyncStream() = 0; + +#ifdef ENABLE_DUMP_E2E + DumpConfPtr GetDumpConf(); +#endif + virtual bool LoadTask(const 
session::KernelGraph *graph); + // for GPU and D to impl + virtual void ReleaseDeviceRes() {} + void set_device_id(uint32_t device_id) { device_id_ = device_id; } + + protected: + virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) = 0; + virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index); + void AssignStaticMemory(session::KernelGraph *graph); + void AssignDynamicMemory(session::KernelGraph *graph); + void ReuseAssignDynamicMemory(session::KernelGraph *graph); + void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); + void AssignWorkSpaceMem(int flag, const AnfNodePtr &node); + void AssignReuseWorkSpaceMem(const AnfNodePtr &node); + + void UpdateRefNodeOutputMem(const session::KernelGraph *graph); + + void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node); + void AssignCommunicationNodeInputMem(const AnfNodePtr &node); + void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node); +#ifdef ENABLE_DUMP_E2E + bool SetDumpConf(); +#endif + + private: + void AssignStaticMemoryOutput(session::KernelGraph *graph); + void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, + AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); + bool LaunchKernelMod(const session::KernelGraph &graph); + void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); + size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); + void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); + void RunOpAssignOutputMemory(const AnfNodePtr &kernel); + void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); + void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); + DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); + + protected: + uint32_t device_id_{0}; +#ifdef ENABLE_DUMP_E2E + DumpConfPtr dump_conf_ptr_; +#endif + void *stream_ = nullptr; + std::shared_ptr mem_manager_{nullptr}; +}; +using KernelRuntimePtr = std::shared_ptr; +} // namespace device +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_H_ diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc new file mode 100644 index 0000000000..626259f9ce --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/kernel_runtime_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +void KernelRuntimeManager::ClearRuntimeResource() { + std::lock_guard<std::mutex> guard(lock_); + for (auto &iter : runtime_map_) { + MS_LOG(INFO) << "Release device " << iter.first; + MS_EXCEPTION_IF_NULL(iter.second); + iter.second->ReleaseDeviceRes(); + } + runtime_map_.clear(); +} + +void KernelRuntimeManager::ClearGraphResource(uint32_t graph_id) { + std::lock_guard<std::mutex> guard(lock_); + for (auto &iter : runtime_map_) { + MS_LOG(INFO) << "Clear device " << iter.first << " graph " << graph_id << " runtime resource"; + if (!iter.second) { + MS_LOG(ERROR) << "Kernel runtime is nullptr"; + continue; + } + iter.second->ClearGraphRuntimeResource(graph_id); + } +} + +void KernelRuntimeManager::Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { + if (runtime_creators_.find(device_name) == runtime_creators_.end()) { + (void)runtime_creators_.emplace(device_name, runtime_creator); + } +} + +KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::string runtime_key = device_name + "_" + std::to_string(device_id); + auto runtime_iter = runtime_map_.find(runtime_key); + if (runtime_iter != runtime_map_.end()) { + return runtime_iter->second.get(); + } else if (runtime_map_.size() > 0) { + auto cur_runtime_key = runtime_map_.begin()->first; + auto find_pos = cur_runtime_key.rfind('_'); + if (find_pos != std::string::npos) { + if (cur_runtime_key.size() > find_pos + 1) { + auto cur_device_id = cur_runtime_key.substr(find_pos + 1); + MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id + << ", set device id: " << device_id << " failed"; + } else { + MS_LOG(EXCEPTION) << "Can't change device id in runtime, current runtime_key size error, set device id: " + << device_id << " failed"; + } + } + } + return GetKernelRuntime(device_name, device_id); +} +
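+// Returns the runtime for (device_name, device_id), lazily creating and caching it through the registered creator on first use.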
+KernelRuntime *KernelRuntimeManager::GetKernelRuntime(const std::string &device_name, uint32_t device_id) { + std::lock_guard<std::mutex> guard(lock_); + std::string runtime_key = device_name + "_" + std::to_string(device_id); + auto runtime_iter = runtime_map_.find(runtime_key); + if (runtime_iter != runtime_map_.end()) { + return runtime_iter->second.get(); + } + std::shared_ptr<KernelRuntime> kernel_runtime; + auto creator_iter = runtime_creators_.find(device_name); + if (creator_iter != runtime_creators_.end()) { + MS_EXCEPTION_IF_NULL(creator_iter->second); + kernel_runtime = (creator_iter->second)(); + MS_EXCEPTION_IF_NULL(kernel_runtime); + kernel_runtime->set_device_id(device_id); + runtime_map_[runtime_key] = kernel_runtime; + } else { + MS_LOG(EXCEPTION) << "No kernel runtime creator for " << device_name << " with device id " << device_id; + } + + return kernel_runtime.get(); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h new file mode 100644 index 0000000000..7fcb40ae67 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.h @@ -0,0 +1,65 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ +#include +#include +#include +#include +#include +#include +#include "common/utils.h" +#include "runtime/device/kernel_runtime.h" +namespace mindspore { +namespace device { +using KernelRuntimeCreator = std::function<std::shared_ptr<KernelRuntime>()>; + +class KernelRuntimeManager { + public: + static KernelRuntimeManager &Instance() { + static KernelRuntimeManager instance; + return instance; + } + void Register(const std::string &device_name, KernelRuntimeCreator &&runtime_creator); + KernelRuntime *GetKernelRuntime(const std::string &device_name, uint32_t device_id); + KernelRuntime *GetSingleKernelRuntime(const std::string &device_name, uint32_t device_id); + void ClearRuntimeResource(); + void ClearGraphResource(uint32_t graph_id); + + private: + KernelRuntimeManager() = default; + ~KernelRuntimeManager() = default; + DISABLE_COPY_AND_ASSIGN(KernelRuntimeManager); + std::map<std::string, std::shared_ptr<KernelRuntime>> runtime_map_; + std::map<std::string, KernelRuntimeCreator> runtime_creators_; + std::mutex lock_; +}; + +class KernelRuntimeRegistrar { + public: + KernelRuntimeRegistrar(const std::string &device_name, KernelRuntimeCreator &&runtime_creator) { + KernelRuntimeManager::Instance().Register(device_name, std::move(runtime_creator)); + } + ~KernelRuntimeRegistrar() = default; +}; + +#define MS_REG_KERNEL_RUNTIME(DEVICE_NAME, RUNTIME_CLASS) \ + static const KernelRuntimeRegistrar g_kernel_runtime_##DEVICE_NAME##_reg( \ + DEVICE_NAME, []() { return std::make_shared<RUNTIME_CLASS>(); });
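+// Illustrative use of the registration macro (names here are hypothetical; actual registrations live in each backend's sources):
+// MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);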
+} // namespace device +} // namespace mindspore +#endif  // MINDSPORE_MINDSPORE_CCSRC_DEVICE_KERNEL_RUNTIME_MANAGER_H_ diff --git a/mindspore/ccsrc/runtime/device/memory_manager.cc b/mindspore/ccsrc/runtime/device/memory_manager.cc new file mode 100644 index 0000000000..563d5f0f50 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/memory_manager.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/memory_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" +using mindspore::memreuse::BestFitMemReuse; +using mindspore::memreuse::MemReuseUtilPtr; +namespace mindspore { +namespace device { +size_t MemoryManager::GetCommonAlignSize(size_t input_size) const { + return (input_size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize; +} + +size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) const { + return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + 2 * kMemAlignSize; +} + +void MemoryManager::MallocReusedDynamicMem(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>(); + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr); + // Collect all reuse information from the graph. + mem_reuse_util_ptr->SetAllInfo(graph); + auto bestfit_mem_reuse = std::make_shared<BestFitMemReuse>(); + MS_EXCEPTION_IF_NULL(bestfit_mem_reuse); + bestfit_mem_reuse->Reuse(mem_reuse_util_ptr.get()); + size_t total_allocated_size = bestfit_mem_reuse->GetAllocatedSize(); + MS_LOG(INFO) << "TotalReuseDynamicSize [" << total_allocated_size << "]"; + mem_reuse_util_ptr_ = mem_reuse_util_ptr; + auto base_ptr = MallocDynamicMem(total_allocated_size, false); + mem_reuse_util_ptr_->set_mem_base(base_ptr); +} + +uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { + MS_EXCEPTION_IF_NULL(node); + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + uint8_t *ptr = nullptr; + if (AnfAlgo::IsCommunicationOp(node)) { + bool communication_mem = false; + if (context_ptr->enable_hccl()) { + communication_mem = true; + } + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, communication_mem); + } else { + ptr = MallocDynamicMem(size, communication_mem); + } + return ptr; + } + + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, false); + } else if (flag == kDynamicMem) { + ptr = MallocDynamicMem(size, false); + } else if (flag == kReuseDynamicMem) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); + ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index); + } + return ptr; +} + +uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size) { + if (flag == kReuseDynamicMem) { + MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_); + return mem_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index); + } + return MallocDynamicMem(size, false); +} + +uint8_t *MemoryManager::MallocMem(int flag, size_t size) { + uint8_t *ptr = nullptr; + if (flag == kStaticMem) { + ptr = MallocStaticMem(size, false); + } else if (flag == kDynamicMem) { + ptr = MallocDynamicMem(size, false); + } + return ptr; +} +
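+// The static region is carved downward from the top of the device arena while the dynamic region grows upward from the bottom; allocation fails once the two offsets would cross.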
total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_static_size_ += align_size; + auto offset = static_mem_offset_ - align_size; + if (dynamic_mem_offset_ > offset) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + static_mem_offset_ = offset; + if (communication_mem) { + return device_mem_base_ + offset + kMemAlignSize; + } else { + return device_mem_base_ + offset; + } +} + +uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) { + size_t align_size = 0; + if (communication_mem) { + align_size = GetCommunicationAlignSize(size); + } else { + align_size = GetCommonAlignSize(size); + } + + MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] communication_mem: " << communication_mem; + + uint64_t offset = dynamic_mem_offset_; + auto new_offset = dynamic_mem_offset_ + align_size; + if (new_offset > static_mem_offset_) { + MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_ + << "] static[" << total_static_size_ << "])" + << " malloc [" << align_size << "] failed!"; + } + total_dynamic_size_ += align_size; + dynamic_mem_offset_ = new_offset; + + if (communication_mem) { + return device_mem_base_ + offset + kMemAlignSize; + } else { + return device_mem_base_ + offset; + } +} + +bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) { + auto device_ptr = MallocMemFromMemPool(size); + if (!device_ptr) { + return false; + } + address->ptr_ = device_ptr; + address->from_mem_pool_ = true; + return true; +} + +void *MemoryManager::MallocMemFromMemPool(size_t size) { + if (size == 0) { + MS_LOG(ERROR) << "MallocMemFromMemPool size is 0."; + } + return nullptr; +} + +void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) { + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(address->ptr_); + FreeMemFromMemPool(address->ptr_); + address->ptr_ = nullptr; +} + +void MemoryManager::FreeMemFromMemPool(void *device_ptr) { + if (device_ptr == nullptr) { + MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; + } +} + +bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); + if (device_ptr_list.size() == 0) { + return false; + } + if (addr_list.size() != device_ptr_list.size()) { + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + } + for (size_t i = 0; i < addr_list.size(); i++) { + MS_EXCEPTION_IF_NULL(device_ptr_list[i]); + MS_EXCEPTION_IF_NULL(addr_list[i]); + addr_list[i]->ptr_ = device_ptr_list[i]; + addr_list[i]->from_mem_pool_ = true; + } + return true; +} + +std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + if (total_size == 0) { + MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; + } + std::vector device_ptr_list; + device_ptr_list.emplace_back(nullptr); + return device_ptr_list; +} +} // namespace device +} // namespace mindspore diff --git 
a/mindspore/ccsrc/runtime/device/memory_manager.h b/mindspore/ccsrc/runtime/device/memory_manager.h new file mode 100644 index 0000000000..3c6fb1b39a --- /dev/null +++ b/mindspore/ccsrc/runtime/device/memory_manager.h @@ -0,0 +1,73 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ +#include +#include +#include "backend/optimizer/mem_reuse/mem_reuse.h" +#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h" +namespace mindspore { +namespace device { +const int kStaticMem = 0; +const int kDynamicMem = 1; +const int kReuseDynamicMem = 2; +const int kGetAllOuts = -1; +const uint64_t kMemAlignSize = 512; +using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr; + +class MemoryManager { + public: + MemoryManager() = default; + virtual ~MemoryManager() = default; + + virtual void MallocDeviceMemory() = 0; + virtual void FreeDeviceMemory() = 0; + virtual void ResetDynamicMemory() { + total_dynamic_size_ = 0; + dynamic_mem_offset_ = 0; + } + + void MallocReusedDynamicMem(session::KernelGraph *graph); + uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, int flag, size_t size); + uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, int flag, size_t size); + virtual uint8_t *MallocMem(int flag, size_t size); + + virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size); + virtual void *MallocMemFromMemPool(size_t size); + virtual void FreeMemFromMemPool(const DeviceAddressPtr address); + virtual void FreeMemFromMemPool(void *device_ptr); + virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); + + size_t GetCommonAlignSize(size_t input_size) const; + size_t GetCommunicationAlignSize(size_t input_size) const; + + protected: + virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem); + virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem); + uint8_t *device_mem_base_{nullptr}; + uint64_t device_mem_size_{0}; + uint64_t dynamic_mem_offset_{0}; + uint64_t static_mem_offset_{0}; + size_t total_static_size_ = 0; + size_t total_dynamic_size_ = 0; + MemReuseUtilPtr mem_reuse_util_ptr_{nullptr}; +}; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ diff --git a/mindspore/ccsrc/session/CMakeLists.txt b/mindspore/ccsrc/session/CMakeLists.txt deleted file mode 100644 index 782eb51183..0000000000 --- a/mindspore/ccsrc/session/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -file(GLOB_RECURSE _SESSION_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "kernel_graph.cc" - "session_basic.cc" - "session_factory.cc" - "anf_runtime_algorithm.cc" -) - -if (ENABLE_GPU) - file(GLOB_RECURSE _GPU_SRC_LIST RELATIVE 
${CMAKE_CURRENT_SOURCE_DIR} - "gpu_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_GPU_SRC_LIST}) -endif () - -if (ENABLE_CPU) - file(GLOB_RECURSE _CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "cpu_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_CPU_SRC_LIST}) -endif () - -if (ENABLE_D) - file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - "ascend_session.cc" - "ascend_control_parser.cc" - "ascend_inference_session.cc" - ) - list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) -endif () - -set_property(SOURCE ${_SESSION_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_SESSION) -add_library(_mindspore_session_obj OBJECT ${_SESSION_SRC_LIST}) diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc deleted file mode 100644 index 81ad02e787..0000000000 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ /dev/null @@ -1,1121 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/anf_runtime_algorithm.h" -#include -#include -#include -#include -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "operator/ops.h" -#include "utils/utils.h" -#include "device/kernel_info.h" -#include "device/device_address.h" -#include "pre_activate/common/helper.h" -#include "kernel/kernel.h" -#include "kernel/kernel_build_info.h" -#include "common/utils.h" -#include "common/trans.h" - -namespace mindspore { -namespace session { -using abstract::AbstractTensor; -using abstract::AbstractTuple; -using device::KernelInfo; -using device::ascend::AscendDeviceAddress; -using kernel::KernelBuildInfoPtr; -using kernel::KernelMod; -using kernel::KernelModPtr; -namespace { -std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { - MS_EXCEPTION_IF_NULL(shape); - std::vector shape_size_t; - std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); - return shape_size_t; -} -} // namespace - -KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) { - MS_EXCEPTION_IF_NULL(anf_node); - if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimMakeTuple)) { - auto node = cnode->input(index + IntToSize(1)); - MS_EXCEPTION_IF_NULL(node); - return VisitKernel(node, 0); - } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = 
GetValue(value_node->value()); - return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx)); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernel(cnode->input(kRealInputIndexInDepend), 0); - } else { - return std::make_pair(anf_node, index); - } - } else { - MS_LOG(EXCEPTION) << "The input is invalid"; - } -} - -KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index, - bool visit_nop_node, - const std::vector &return_types) { - MS_EXCEPTION_IF_NULL(anf_node); - for (const auto &prim_type : return_types) { - if (CheckPrimitiveType(anf_node, prim_type)) { - return std::make_pair(anf_node, index); - } - } - if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - return std::make_pair(anf_node, 0); - } else if (anf_node->isa()) { - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input0 = cnode->input(0); - MS_EXCEPTION_IF_NULL(input0); - if (IsPrimitive(input0, prim::kPrimTupleGetItem)) { - if (cnode->inputs().size() != kTupleGetItemInputSize) { - MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!"; - } - auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem); - MS_EXCEPTION_IF_NULL(input2); - auto value_node = input2->cast(); - MS_EXCEPTION_IF_NULL(value_node); - int item_idx = GetValue(value_node->value()); - return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), - visit_nop_node, return_types); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); - } else if (opt::IsNopNode(cnode) && visit_nop_node) { - if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); - } else { - MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; - } - } else { - return std::make_pair(anf_node, index); - } - } else { - MS_LOG(EXCEPTION) << "The input is invalid"; - } -} - -std::vector AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node, - const std::vector &return_types) { - std::vector ret; - auto return_prim_type = return_types; - // if visited make_tuple should return back - return_prim_type.push_back(prim::kPrimMakeTuple); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type); - if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { - MS_EXCEPTION_IF_NULL(item_with_index.first); - auto make_tuple = item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t i = 1; i < make_tuple->inputs().size(); i++) { - auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types); - (void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret)); - } - return ret; - } - ret.push_back(item_with_index.first); - return ret; -} - -AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - return node->input(kAnfPrimitiveIndex); -} - -PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = GetCNodePrimitiveNode(cnode); - MS_EXCEPTION_IF_NULL(attr_input); - auto value_node = attr_input->cast(); - 
MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - auto primitive = value->cast(); - return primitive; -} - -bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type); -} - -FuncGraphPtr AnfRuntimeAlgorithm::GetCNodeFuncGraphPtr(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - auto value_node = attr_input->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto value = value_node->value(); - MS_EXCEPTION_IF_NULL(value); - return value->cast(); -} - -std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - return primitive->name(); - } - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(func_graph); - return func_graph->ToString(); - } - MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString(); -} - -std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - return node->DebugString(); -} - -void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - primitive->set_attr(key, value); - return; - } - // graph kernel cnode. 
- auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - fg->set_attr(key, value); -} - -void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) { - CopyNodeAttr(key, key, from, to); -} - -void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, - const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (!from->isa() || !to->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is " - << to->DebugString(); - } - auto from_primitive = AnfAlgo::GetCNodePrimitive(from); - MS_EXCEPTION_IF_NULL(from_primitive); - auto to_primitive = AnfAlgo::GetCNodePrimitive(to); - MS_EXCEPTION_IF_NULL(to_primitive); - to_primitive->set_attr(new_key, from_primitive->GetAttr(old_key)); -} - -void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (!from->isa() || !to->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is " - << from->DebugString(); - } - auto from_primitive = AnfAlgo::GetCNodePrimitive(from); - MS_EXCEPTION_IF_NULL(from_primitive); - auto to_primitive = AnfAlgo::GetCNodePrimitive(to); - MS_EXCEPTION_IF_NULL(to_primitive); - (void)to_primitive->SetAttrs(from_primitive->attrs()); -} - -void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString(); - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - primitive->EraseAttr(key); - return; - } - // graph kernel cnode. - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - fg->erase_flag(key); -} - -bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString(); - return false; - } - // single op cnode. - auto primitive = AnfAlgo::GetCNodePrimitive(node); - if (primitive != nullptr) { - return primitive->HasAttr(key); - } - // graph kernel cnode. 
- auto fg = AnfAlgo::GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - return fg->has_attr(key); -} - -size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString(); - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - size_t input_num = cnode->inputs().size(); - if (input_num == 0) { - MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"; - } - // exclude intputs[0],which is value_node storing attr,inputs left are real input - return input_num - 1; -} - -size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TypePtr type = node->Type(); - if (type == nullptr) { - return 0; - } - if (type->isa()) { - auto tuple_type = type->cast(); - MS_EXCEPTION_IF_NULL(tuple_type); - return tuple_type->size(); - } else if (type->isa() || type->isa()) { - return 1; - } else if (type->isa()) { - return 0; - } else { - return 1; - } -} - -std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "Output index:" << output_idx - << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" - << node->DebugString() << "]"; - } - if (!AnfAlgo::IsRealKernel(node)) { - return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto format = build_info->GetOutputFormat(output_idx); - if (format == kernel::KernelBuildInfo::kInvalidFormat) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid output format"; - } - return format; -} - -std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "Input index :" << input_idx - << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" - << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - GetPrevNodeOutputFormat(node, input_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto format = build_info->GetInputFormat(input_idx); - if (format == kernel::KernelBuildInfo::kInvalidFormat) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid input format"; - } - return format; -} - -KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(anf_node); - if (!anf_node->isa()) { - MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; - } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); - } - auto node = cnode->input(input_idx + 1); - MS_EXCEPTION_IF_NULL(node); - return VisitKernel(node, 0); -} - -std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return 
AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - abstract::BaseShapePtr base_shape = node->Shape(); - MS_EXCEPTION_IF_NULL(base_shape); - if (base_shape->isa() && output_idx == 0) { - return TransShapeToSizet(base_shape->cast()); - } else if (base_shape->isa()) { - auto tuple_shape = base_shape->cast(); - MS_EXCEPTION_IF_NULL(tuple_shape); - if (output_idx >= tuple_shape->size()) { - MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size() - << "."; - } - auto b_shp = (*tuple_shape)[output_idx]; - if (b_shp->isa()) { - return TransShapeToSizet(b_shp->cast()); - } else if (b_shp->isa()) { - return std::vector(); - } else { - MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx - << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString(); - } - } else if (base_shape->isa()) { - return std::vector(); - } - MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is " - << base_shape->ToString(); -} - -std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second); -} - -std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { - auto format = GetOutputFormat(node, output_idx); - auto infer_shape = GetOutputInferShape(node, output_idx); - if (infer_shape.empty()) { - return infer_shape; - } - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); - } - return trans::TransShapeToDevice(infer_shape, format); -} - -std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { - auto format = GetInputFormat(node, input_idx); - auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); - if (infer_shape.empty()) { - return infer_shape; - } - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); - } - return trans::TransShapeToDevice(infer_shape, format); -} - -std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index:" << input_idx - << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" - << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputReshapeType(node, input_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - 
MS_EXCEPTION_IF_NULL(build_info); - if (build_info->IsInputDefaultPadding()) { - return {}; - } - return build_info->GetInputReshapeType(input_idx); -} - -std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputReshapeType(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - if (build_info->IsOutputDefaultPadding()) { - return {}; - } - return build_info->GetOutputReshapeType(output_idx); -} - -TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - TypePtr type_ptr = node->Type(); - MS_EXCEPTION_IF_NULL(type_ptr); - if (type_ptr->isa() && output_idx == 0) { - auto tensor_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem->type_id(); - } else if (type_ptr->isa()) { - auto tuple_ptr = type_ptr->cast(); - MS_EXCEPTION_IF_NULL(tuple_ptr); - if (output_idx >= tuple_ptr->size()) { - MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size(); - } - auto tuple_i = (*tuple_ptr)[output_idx]; - MS_EXCEPTION_IF_NULL(tuple_i); - if (tuple_i->isa()) { - auto tensor_ptr = tuple_i->cast(); - MS_EXCEPTION_IF_NULL(tensor_ptr); - TypePtr elem = tensor_ptr->element(); - MS_EXCEPTION_IF_NULL(elem); - return elem->type_id(); - } else if (tuple_i->isa()) { - return tuple_i->type_id(); - } else { - MS_LOG(WARNING) << "Not support type " << tuple_i->ToString(); - return tuple_i->type_id(); - } - } else if (type_ptr->isa()) { - return type_ptr->type_id(); - } - return type_ptr->type_id(); -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); - return AnfRuntimeAlgorithm::GetOutputInferDataType(kernel_with_index.first, kernel_with_index.second); -} - -TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputDeviceDataType(node, output_idx); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto dtype = build_info->GetOutputDeviceType(output_idx); - if (dtype == TypeId::kNumberTypeEnd) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid dtype"; - } - return dtype; -} - -TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { - MS_EXCEPTION_IF_NULL(node); - if (input_idx > GetInputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " - << GetInputTensorNum(node) 
<< "#node [ " << node->DebugString() << "]"; - } - if (!IsRealKernel(node)) { - return GetPrevNodeOutputDeviceDataType(node, 0); - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto build_info = kernel_info->select_kernel_build_info(); - MS_EXCEPTION_IF_NULL(build_info); - auto dtype = build_info->GetInputDeviceType(input_idx); - if (dtype == TypeId::kNumberTypeEnd) { - MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]" - << " has a invalid dtype"; - } - return dtype; -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second); -} - -// get output device addr of anf_node -const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx, - bool visit_nop_node) { - MS_EXCEPTION_IF_NULL(node); - if (opt::IsNopNode(node) && visit_nop_node) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { - return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0); - } else { - MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; - } - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto addr = kernel_info->GetOutputAddr(output_idx); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString() - << " output addr is not exist"; - } - return addr; -} - -DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, - bool visit_nop_node) { - MS_EXCEPTION_IF_NULL(node); - if (opt::IsNopNode(node) && visit_nop_node) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() == 2) { - return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0); - } else { - MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; - } - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - auto addr = kernel_info->GetMutableOutputAddr(output_idx); - if (addr == nullptr) { - MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() - << " output addr is not exist"; - } - return addr; -} - -// get output device addr of anf_node -bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { - MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; - } - auto kernel_info = node->kernel_info(); - MS_EXCEPTION_IF_NULL(kernel_info); - return kernel_info->OutputAddrExist(output_idx); -} - -const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool visit_nop_node) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node); -} - -DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool visit_nop_node) { - KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx); - return 
-
-// get output device addr of anf_node
-const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
-                                                        bool visit_nop_node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node) && visit_nop_node) {
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (cnode->inputs().size() == 2) {
-      return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
-    } else {
-      MS_LOG(EXCEPTION) << node->DebugString() << " is an invalid nop node";
-    }
-  }
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  auto addr = kernel_info->GetOutputAddr(output_idx);
-  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "Output index " << output_idx << " of node " << node->DebugString()
-                      << ": output address does not exist";
-  }
-  return addr;
-}
-
-DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
-                                                           bool visit_nop_node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node) && visit_nop_node) {
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (cnode->inputs().size() == 2) {
-      return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
-    } else {
-      MS_LOG(EXCEPTION) << node->DebugString() << " is an invalid nop node.";
-    }
-  }
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
-  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "Output index " << output_idx << " of node " << node->DebugString()
-                      << ": output address does not exist";
-  }
-  return addr;
-}
-
-// check whether the output device addr of anf_node exists
-bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (output_idx > GetOutputTensorNum(node)) {
-    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size ["
-                      << GetOutputTensorNum(node) << "], #node [" << node->DebugString() << "]";
-  }
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->OutputAddrExist(output_idx);
-}
-
-const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
-                                                                bool visit_nop_node) {
-  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
-}
-
-DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
-                                                                   bool visit_nop_node) {
-  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
-}
-
-// set output device addr of anf_node
-void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  if (!kernel_info->SetOutputAddr(addr, output_idx)) {
-    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " set output addr " << output_idx << " failed";
-  }
-}
-
-// set workspace device addr of anf_node
-void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) {
-    MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " set workspace addr " << output_idx << " failed";
-  }
-}
-
-// get workspace device addr of anf_node
-DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  auto addr = kernel_info->GetWorkspaceAddr(output_idx);
-  if (addr == nullptr) {
-    MS_LOG(EXCEPTION) << "Workspace index " << output_idx << " of node " << node->DebugString()
-                      << ": workspace address does not exist";
-  }
-  return addr;
-}
-
-// set infer shapes and types of anf node
-void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
-                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (types.size() != shapes.size()) {
-    MS_LOG(EXCEPTION) << "Types size " << types.size() << " should be the same as shapes size " << shapes.size();
-  }
-  if (shapes.empty()) {
-    node->set_abstract(std::make_shared<abstract::AbstractNone>());
-  } else if (shapes.size() == 1) {
-    // single output handle
-    std::vector<int> shape_int;
-    std::transform(shapes[0].begin(), shapes[0].end(), std::back_inserter(shape_int), SizeToInt);
-    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
-    node->set_abstract(abstract);
-  } else {
-    // multiple output handle
-    std::vector<AbstractBasePtr> abstract_list;
-    for (size_t i = 0; i < types.size(); ++i) {
-      std::vector<int> shape_int;
-      std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(shape_int), SizeToInt);
-      abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
-    }
-    auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
-    node->set_abstract(abstract_tuple);
-  }
-}
-// copy an abstract of a node to another node
-void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
-  to_node->set_abstract(from_node->abstract());
-}
-
-kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  // select_kernel_build_info() has checked whether the returned pointer is null
-  auto build_info = kernel_info->select_kernel_build_info();
-  MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->op_pattern();
-}
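
(Editorial note, not part of the original patch: a minimal usage sketch of the SetOutputInferTypeAndShape helper above, assuming a hypothetical freshly built CNode `new_node`; a single entry yields an AbstractTensor, several entries an AbstractTuple.)

    // Declare one float32 output of shape (32, 10), then read the dtype back.
    AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {{32, 10}}, new_node.get());
    TypeId dtype = AnfAlgo::GetOutputInferDataType(new_node, 0);  // kNumberTypeFloat32
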
-
-// get KernelBuildType of node, such as ATT, RT, FWK and so on
-KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  // select_kernel_build_info() has checked whether the returned pointer is null
-  auto build_info = kernel_info->select_kernel_build_info();
-  MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->kernel_type();
-}
-
-kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  auto build_info = kernel_info->select_kernel_build_info();
-  MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->processor();
-}
-
-kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  auto build_info = kernel_info->select_kernel_build_info();
-  MS_EXCEPTION_IF_NULL(build_info);
-  return build_info->fusion_type();
-}
-
-// set select kernel_build_info
-void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->set_select_kernel_build_info(select_kernel_build_info);
-}
-
-// get select kernel_build_info
-KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->GetMutableSelectKernelBuildInfo();
-}
-
-// get kernel mod
-KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->MutableKernelMod();
-}
-
-// set kernel mod
-void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  kernel_info->set_kernel_mod(kernel_mod);
-}
-
-bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  // parameter and value node are treated as real kernels too
-  if (!node->isa<CNode>()) {
-    return true;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().empty()) {
-    MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << node->DebugString();
-  }
-  auto input = cnode->inputs()[0];
-  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
-                         IsPrimitive(input, prim::kPrimTensorSummary) ||
-                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
-                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
-                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
-                         IsPrimitive(input, prim::kPrimReturn);
-  return !is_virtual_node;
-}
-
-bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  // parameter and value node are not real cnode kernels
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  // return is considered a real node
-  if (CheckPrimitiveType(node, prim::kPrimReturn)) {
-    return true;
-  }
-  return IsRealKernel(node);
-}
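
(Editorial note, not part of the original patch: a hypothetical illustration of the real/virtual split that IsRealKernel above encodes.)

    // MakeTuple, TupleGetItem, Depend, ControlDepend, StateSetItem, the
    // Summary ops and Return are graph glue that never launches on the
    // device, so:
    //   AnfAlgo::IsRealKernel(make_tuple_cnode)  -> false
    //   AnfAlgo::IsRealKernel(add_cnode)         -> true
    // Parameters and value nodes also count as "real" here, which is why the
    // getters above fall back to GetPrevNodeOutput* only for virtual cnodes.
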
-
-bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  // a graph kernel should be a real cnode kernel.
-  if (!IsRealCNodeKernel(node)) {
-    return false;
-  }
-
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto input = cnode->input(kAnfPrimitiveIndex);
-  // a graph kernel should have a func_graph as its first input.
-  if (!IsValueNode<FuncGraph>(input)) {
-    return false;
-  }
-
-  auto func_graph = GetValueNode<FuncGraphPtr>(input);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  return func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
-}
-
-bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  return node->has_default();
-}
-
-void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  kernel_info->set_stream_id(stream_id);
-}
-
-uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->stream_id();
-}
-
-void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  kernel_info->set_stream_distinction_label(stream_label);
-}
-
-uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->stream_distinction_label();
-}
-
-void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  kernel_info->set_graph_id(graph_id);
-}
-
-uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->graph_id();
-}
-
-bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
-  MS_EXCEPTION_IF_NULL(anf);
-  TypePtr type = anf->Type();
-  MS_EXCEPTION_IF_NULL(type);
-  return type->isa<Tuple>();
-}
-
-AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto get_input_index = index + 1;
-  if (index + 1 > node->inputs().size()) {
-    MS_LOG(EXCEPTION) << "Input index " << get_input_index << " is out of range, the node only has "
-                      << node->inputs().size() << " inputs";
-  }
-  // input 0 is the primitive node
-  return node->input(get_input_index);
-}
-
-bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (node->isa<ValueNode>()) {
-    return false;
-  }
-  auto kernel_info = node->kernel_info();
-  MS_EXCEPTION_IF_NULL(kernel_info);
-  return kernel_info->is_feature_map();
-}
-
-bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
-  if (!node->isa<CNode>()) {
-    MS_LOG(EXCEPTION) << "Cannot check whether the input is a feature map: only a cnode has inputs, not a parameter "
-                         "or value node";
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto input_node = cnode->input(input_index + 1);
-  return IsFeatureMapOutput(input_node);
-}
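
(Editorial note, not part of the original patch: the spec_node_list table in GetRealInputIndex below remaps graph-level (ME) input positions to the positions the TBE kernel implementation expects; reading one entry, Conv2DBackpropInput's {{0, 1}, {1, 0}} swaps its first two inputs. A hypothetical lookup, assuming the node was compiled as a TBE kernel:)

    // For a hypothetical Conv2DBackpropInput cnode:
    size_t tbe_idx = AnfAlgo::GetRealInputIndex(conv2d_backprop_input_cnode, 0);  // -> 1
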
-
-size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
-    {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
-    {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
-    {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
-    {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
-    {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
-    {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
-    {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
-    {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
-    {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
-    {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
-    {prim::kPrimApplyCenteredRMSProp->name(),
-     {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 5}, {5, 6}, {6, 7}, {7, 8}, {8, 4}}}};
-  size_t ret = cur_index;
-  auto node_name = AnfAlgo::GetCNodeName(anf_node);
-  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
-    auto find = spec_node_list.find(node_name);
-    if (find != spec_node_list.end()) {
-      ret = find->second[cur_index];
-      MS_LOG(INFO) << "Real input index changes to " << ret << ", node name: " << node_name;
-    }
-  }
-  return ret;
-}
-
-void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(input_node);
-  node->set_input(index + 1, input_node);
-}
-
-bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  auto kernel_name = AnfAlgo::GetCNodeName(node);
-  if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName ||
-      kernel_name == kReduceScatterOpName) {
-    return true;
-  }
-  return false;
-}
-
-bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
-  auto kernel_name = AnfAlgo::GetCNodeName(node);
-  return kernel_name == kGetNextOpName;
-}
-
-FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto value_node = node->cast<ValueNodePtr>();
-  if (value_node == nullptr) {
-    return nullptr;
-  }
-  auto value = value_node->value();
-  if (value == nullptr) {
-    return nullptr;
-  }
-  auto func_graph = value->cast<FuncGraphPtr>();
-  return func_graph;
-}
-
-std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
-  MS_EXCEPTION_IF_NULL(call_node);
-  if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
-    MS_LOG(EXCEPTION) << "Anf node: " << call_node->DebugString() << " is not a call node.";
-  }
-  auto input1 = call_node->input(1);
-  MS_EXCEPTION_IF_NULL(input1);
-  if (input1->isa<ValueNode>()) {
-    auto value_node = input1->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value_node);
-    auto kernel_graph = value_node->value();
-    MS_EXCEPTION_IF_NULL(kernel_graph);
-    return {kernel_graph->cast<KernelGraphPtr>()};
-  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
-    auto switch_node = input1->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(switch_node);
-    auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
-      auto partial = switch_node->input(input_index);
-      MS_EXCEPTION_IF_NULL(partial);
-      if (IsValueNode<KernelGraph>(partial)) {
-        return GetValueNode<KernelGraphPtr>(partial);
-      }
-      auto partial_cnode = partial->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(partial_cnode);
-      auto graph_node = partial_cnode->input(1);
-      MS_EXCEPTION_IF_NULL(graph_node);
-      auto graph_value_node = graph_node->cast<ValueNodePtr>();
-      MS_EXCEPTION_IF_NULL(graph_value_node);
-      auto graph_value = graph_value_node->value();
-      MS_EXCEPTION_IF_NULL(graph_value);
-      auto child_graph = graph_value->cast<KernelGraphPtr>();
-      return child_graph;
-    };
-    return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
-  }
-  return {};
-}
-
-bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
-  MS_EXCEPTION_IF_NULL(call_node);
-  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
-    MS_LOG(EXCEPTION) << "Call node should
be a 'call', but is a " << call_node->DebugString(); - } - auto input1 = call_node->input(1); - if (input1->isa()) { - return false; - } else if (input1->isa() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) { - return true; - } - MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); -} - -bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); - if (shape.empty()) { - return true; - } - return shape.size() == kShape1dDims && shape[0] == 1; -} - -bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); - if (shape.empty()) { - return true; - } - return shape.size() == kShape1dDims && shape[0] == 1; -} - -void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_list) { - std::vector all_opt_list; - std::vector non_opt_list; - - for (const auto &node : *node_list) { - MS_EXCEPTION_IF_NULL(node); - if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) { - all_opt_list.emplace_back(node); - } else { - non_opt_list.emplace_back(node); - } - } - node_list->clear(); - std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list)); - std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list)); -} - -TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto prim = AnfAlgo::GetCNodePrimitive(node); - if (prim == nullptr) { - return kTypeUnknown; - } - - TypeId except_type = kTypeUnknown; - if (prim->GetAttr(kAttrOutputPrecision) != nullptr) { - auto output_type_str = GetValue(prim->GetAttr(kAttrOutputPrecision)); - if (output_type_str == "float16") { - except_type = kNumberTypeFloat16; - } else if (output_type_str == "float32") { - except_type = kNumberTypeFloat32; - } else { - MS_LOG(EXCEPTION) << "The fix precision must be float16 or float32, but got " << output_type_str; - } - } - - return except_type; -} - -TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << node->DebugString() << ", input node is not CNode."; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); - } - auto input_node = cnode->input(input_idx + 1); - MS_EXCEPTION_IF_NULL(input_node); - auto kernel_with_index = VisitKernel(input_node, 0); - if (!kernel_with_index.first->isa()) { - return kTypeUnknown; - } - return GetCNodeOutputPrecision(kernel_with_index.first); -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h deleted file mode 100644 index 3238b1cecc..0000000000 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H -#define MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H -#include -#include -#include -#include -#include -#include -#include -#include "ir/anf.h" -#include "ir/dtype.h" -#include "base/base.h" -#include "ir/primitive.h" -#include "device/device_address.h" -#include "kernel/kernel.h" -#include "kernel/kernel_build_info.h" -#include "operator/ops.h" -#include "utils/contract.h" -#include "session/kernel_graph.h" - -namespace mindspore { -namespace session { -using AnfVisitFuncion = std::function; -using KernelWithIndex = std::pair; -class AnfRuntimeAlgorithm { - public: - // get input_anf_node's real kernel by recurse - static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index); - static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index, - bool visit_nop_node = false, - const std::vector &return_types = { - prim::kPrimMakeTuple}); - static std::vector GetAllOutput(const AnfNodePtr &node, - const std::vector &return_types = {}); - // get cnode primitive - static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node); - static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index); - static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); - // check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple - static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type); - // get cnode primitive - static FuncGraphPtr GetCNodeFuncGraphPtr(const AnfNodePtr &node); - // get kernel_name of anf node - static std::string GetCNodeName(const AnfNodePtr &node); - // get detail info of anf node - static std::string GetNodeDebugString(const AnfNodePtr &node); - // get attr of anf node - template - static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) { - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - std::string node_debug_log = node->DebugString(); - MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str(); - } - // single op cnode. - if (auto primitive = GetCNodePrimitive(node); primitive != nullptr) { - return GetValue(primitive->GetAttr(key)); - } - // graph kernel cnode. - auto fg = GetCNodeFuncGraphPtr(node); - MS_EXCEPTION_IF_NULL(fg); - return GetValue(fg->get_attr(key)); - } - static bool IsTupleOutput(const AnfNodePtr &anf); - // set attr of anf node - static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); - // set attr of key from 'from' node to 'to' node - static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to); - // set a new key for attr from 'from' node to 'to' node - static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from, - const AnfNodePtr &to); - // set all attrs from 'from' node to 'to' node - static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); - // check whether a cnode has the specified attr. 
- static bool HasNodeAttr(const std::string &key, const CNodePtr &node); - // delete attr of anf node - static void EraseNodeAttr(const std::string &key, AnfNodePtr node); - // get the num of input real_kernel(which can be build and run in device) - static size_t GetInputTensorNum(const AnfNodePtr &node); - // get the num of output real_kernel(which can be build and run in device) - static size_t GetOutputTensorNum(const AnfNodePtr &node); - // get output format select of anf node - static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx); - // get input format select of anf node - static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); - // get prev node output width output index - static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); - // get output format from prev node,input_index is the input index of current node related to prev node - static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); - // get reshape_type of from the output of input node. - static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); - // get output shapes inferred by ME from input nodes. - static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); - // get input shapes inferred by ME from input nodes. - static std::vector GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx); - // get output shapes which will built and run in device - static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); - // get input shapes which will built and run in device - static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); - // Get Input Padding Axis - static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); - // Get Output Padding Axis - static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); - // get output data type inferred by ME of anf node - static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); - // get output original data type from prev node,input_index is the input index of current node related to prev node - static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx); - // get output select data type of anf node - static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx); - // get input select data type of anf node - static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx); - // get output select data type from prev node,input_index is the input index of current node related to prev node - static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx); - // get output device addr of anf_node - static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); - // get mutable output device addr of anf_node - static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true); - // check whether output addr is exist or not - static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx); - // get address from prev node,input_index is the input index of current node related to prev node - static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx, - bool visit_nop_node = true); - static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx, - bool 
visit_nop_node = true); - // set output device addr of anf_node - static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); - // set workspace device addr of anf_node - static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); - // get workspace device addr of anf_node - static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); - // set infer shapes and types of anf node - static void SetOutputInferTypeAndShape(const std::vector &types, - const std::vector> &shapes, AnfNode *node); - static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node); - // get op pattern of the node - static kernel::OpPattern GetOpPattern(const AnfNodePtr &node); - // get KernelBuildType of node ,such as ATT,RT,FWK and so on - static KernelType GetKernelType(const AnfNodePtr &node); - // get processor type:AICORE,AICPU... - static kernel::Processor GetProcessor(const AnfNodePtr &node); - // get fusion type:AICORE,AICPU... - static kernel::FusionType GetFusionType(const AnfNodePtr &node); - // set select kernel_build_info - static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node); - // get select kernel_build_info - static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node); - // get kernelMode - static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node); - // set kernel mod - static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node); - // checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too - static bool IsRealKernel(const AnfNodePtr &node); - // checkout whether the anf node is a real kernel that is a cnode and can run on device - static bool IsRealCNodeKernel(const AnfNodePtr &node); - // checkout whether the anf node is a graph kernel. 
- static bool IsGraphKernel(const AnfNodePtr &node); - // check parameter is weight or data - static bool IsParameterWeight(const ParameterPtr &node); - // set stream id of kernel,which will be set in stream assign and be used in stream generate - static void SetStreamId(uint32_t stream_id, AnfNode *node); - // get stream id - static uint32_t GetStreamId(const AnfNodePtr &node); - // set stream distinction label to distinguish different ops in different streams - static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node); - // get stream distinction label - static uint32_t GetStreamDistinctionLabel(const AnfNode *node); - // set graph id - static void SetGraphId(uint32_t graph_id, AnfNode *node); - // get graph id - static uint32_t GetGraphId(const AnfNode *node); - static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); - // charge if the node's output is a feature map output - static bool IsFeatureMapOutput(const AnfNodePtr &node); - // charge if the node's input is from a feature map output - static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); - // get real input index for some tbe ops which input order is different between me and tbe impl - static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); - static bool IsCommunicationOp(const AnfNodePtr &node); - static bool IsGetNext(const NotNull &node); - static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); - static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); - static bool IsSwitchCall(const CNodePtr &call_node); - static bool IsScalarInput(const CNodePtr &cnode, size_t index); - static bool IsScalarOutput(const CNodePtr &cnode, size_t index); - static void ReorderExecList(NotNull *> node_list); - // get fix output precision of cnode. - static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); - // get fix output precision from prev node, input_idx is the input index of current node related to prev node. - static TypeId GetPrevNodeOutputPrecision(const AnfNodePtr &node, size_t input_idx); -}; -} // namespace session -using AnfAlgo = session::AnfRuntimeAlgorithm; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ANF_RUNTIME_ALGORITHM_H diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc deleted file mode 100644 index 0c97116c6e..0000000000 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ /dev/null @@ -1,643 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "session/ascend_control_parser.h" -#include -#include -#include "session/anf_runtime_algorithm.h" -#include "utils/union_find_set.h" -#include "device/ascend/ascend_label_assign.h" - -static constexpr size_t kCNodePrim = 0; -static constexpr size_t kCNodeCallArg = 1; -static constexpr size_t kCNodeSwitchCond = 1; -static constexpr size_t kCNodeSwitchTrue = 2; -static constexpr size_t kCNodeSwitchFalse = 3; -static constexpr size_t kCNodeSwitchLength = 4; -static constexpr size_t kCNodePartialLength = 2; -static constexpr size_t kCNodePartialFunc = 1; -static constexpr size_t kCNodeSwitchLayerBranch = 2; -static constexpr size_t kCNodeSwitchLayerLength = 3; - -namespace mindspore { -namespace session { -static CNodePtr GetJumpNode(NotNull parent_graph, NotNull child_graph) { - auto &nodes = parent_graph->execution_order(); - CNodePtr last_jump_node = nullptr; - for (auto &node : nodes) { - if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) { - if (child_graph->get_start_label() == node->input(kCNodeCallArg)) { - return node; - } - last_jump_node = node; - } else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) { - if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) || - child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) { - return node; - } - last_jump_node = node; - } - } - if (last_jump_node == nullptr) { - MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString(); - } - return last_jump_node; -} - -static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - MS_EXCEPTION_IF_NULL(para); - if (para->isa()) { - union_find_set->Add(para); - } - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - union_find_set->Add(arg); - } - } - for (auto &child : kg->child_graph_order()) { - InitUnionFindSet(NOT_NULL(child), union_find_set, memo); - } -} - -static void UnionParentParameter(NotNull kg, const NotNull *> union_find_set, - const NotNull *> memo) { - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &iter : real_inputs) { - auto ¶ = iter.first; - for (auto &arg : iter.second) { - MS_EXCEPTION_IF_NULL(arg); - if (!arg->isa()) { - continue; - } - if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) { - continue; - } - union_find_set->Union(arg, para); - } - } - for (auto &child : kg->child_graph_order()) { - UnionParentParameter(NOT_NULL(child), union_find_set, memo); - } -} - -static UnionFindSet MakeUnionFindSet(NotNull root_kg) { - UnionFindSet result; - std::set memo; - InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - memo.clear(); - UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); - return result; -} - -static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter, - const std::set ¶meter_reuse_set, - const NotNull *> memo) { - if (parameter_reuse_set.empty()) { - MS_LOG(EXCEPTION) << "Parameter_reuse_set is empty."; - } - if (memo->find(kg.get()) != memo->end()) { - return; - } - memo->insert(kg.get()); - - for (auto ¶ : parameter_reuse_set) { - if (para == main_parameter.get()) { - continue; - } - MS_EXCEPTION_IF_NULL(para); - MS_LOG(INFO) << 
"Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " - << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); - kg->ReplaceNode(NOT_NULL(para), main_parameter); - } - - for (auto &child : kg->child_graph_order()) { - RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); - } -} - -static AnfNodePtr GetMainParameter(NotNull root_kg, const AnfNodePtr key, - const std::set ¶meter_reuse_set) { - AnfNodePtr main_parameter = key; - std::set root_inputs_set; - const auto &root_inputs_vector = root_kg->inputs(); - root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); - for (auto &node : parameter_reuse_set) { - if (root_inputs_set.find(node) != root_inputs_set.end()) { - main_parameter = node; - break; - } - } - return main_parameter; -} - -static void ReuseParameter(NotNull root_kg, NotNull *> parameter_set) { - auto parameter_reuse_sets = parameter_set->GetSets(); - for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { - if (parameter_reuse_set.size() <= 1) { - continue; - } - auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set); - std::set memo; - RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); - } -} - -CNodePtr GetNextRealKernel(const std::vector &list, size_t start) { - for (size_t i = start; i < list.size() - 1; ++i) { - if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { - return list[i]; - } - } - return nullptr; -} - -void AscendControlParser::LinkGraph(NotNull kg) { - std::set memo; - (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); - device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg); - std::map graph_id_map; - for (auto &g : memo) { - MS_EXCEPTION_IF_NULL(g); - if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() - << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); - } - graph_id_map[g->graph_id()] = g; - } - - // Insert Assign - ChildGraphDataAssign(graph_id_map); - // Make UnionFindSet - UnionFindSet parameter_set = MakeUnionFindSet(kg); - // Reuse Parameter - ReuseParameter(kg, NOT_NULL(¶meter_set)); -} - -void AscendControlParser::ExecutorValidate(NotNull root_graph) { - std::set memo; - (void)RecurseGraph(root_graph, NOT_NULL(&memo)); -} - -void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { - for (auto &iter : graph_id_map) { - auto &kg = iter.second; - MS_LOG(INFO) << "Data assign graph:" << kg->graph_id(); - MS_EXCEPTION_IF_NULL(kg); - std::set> memo; - const std::vector>> &real_inputs = kg->real_inputs(); - for (auto &it : real_inputs) { - auto ¶meter = it.first; - auto &args = it.second; - for (auto &arg : args) { - MS_EXCEPTION_IF_NULL(arg); - if (memo.find({parameter, arg}) != memo.end()) { - continue; - } else { - memo.emplace(parameter, arg); - } - auto unreuse_args_map = kg->unreuse_args(); - auto unreuse_arg_iter = unreuse_args_map.find(arg); - if (unreuse_arg_iter == unreuse_args_map.end()) { - MS_EXCEPTION_IF_NULL(arg); - MS_EXCEPTION_IF_NULL(parameter); - if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << "."; - } - MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() - << ", arg:" << arg->DebugString(); - continue; - } - auto target_graph_iter 
= graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); - if (target_graph_iter == graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; - } - InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), - NOT_NULL(parameter)); - } - } - kg->SetExecOrderByDefault(); - } -} - -NotNull AscendControlParser::GetStartLabel(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label) { - CNodePtr start_label; - if (last_node != nullptr && last_label != nullptr) { - start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); - kg->set_start_label(start_label); - } else { - // no goto node will jump to start label of root graph, so return a fake label - start_label = std::make_shared(std::vector(), FuncGraphPtr(nullptr)); - } - return NOT_NULL(start_label); -} - -NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, - const NotNull *> memo) { - MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); - - // 1. recursive condition - if (memo->find(kg) != memo->end()) { - MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); - return NOT_NULL(kg->get_start_label()); - } - memo->insert(kg.get()); - - // 2. args replace placeholder - LinkParentGraph(kg, last_node, last_label); - - // 3. topological sort - kg->SetExecOrderByDefault(); - const std::vector &nodes = kg->execution_order(); - // 4. insert first_label - CNodePtr start_label = GetStartLabel(kg, last_node, last_label); - - // 5. traverse - for (size_t i = 0; i < nodes.size(); ++i) { - auto &cnode = nodes[i]; - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->size() < kCNodePrim + 1) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex); - if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { - MS_LOG(DEBUG) << "Continue node " << cnode->DebugString(); - continue; - } - AnfNodePtr arg = cnode->input(kFirstDataInputIndex); - MS_EXCEPTION_IF_NULL(arg); - if (IsValueNode(arg)) { - RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (!arg->isa()) { - MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); - RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { - auto arg_cnode = arg->cast(); - MS_EXCEPTION_IF_NULL(arg_cnode); - cnode->set_inputs(arg_cnode->inputs()); - RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); - } - } - kg->SetExecOrderByDefault(); - MS_LOG(INFO) << "End KernelGraph process: " << kg->ToString(); - return NOT_NULL(start_label); -} - -void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { - auto return_node = kg->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - return_node->input(kFirstDataInputIndex), attch_node.get()}; - auto depend_node = kg->NewCNode(inputs); - return_node->set_input(1, depend_node); -} - -void AscendControlParser::InsertControlDependToGraph(NotNull kg, NotNull first_node, - NotNull 
<AnfNodePtr> second_node) {
-  MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
-               << ", the second node is " << second_node->DebugString();
-  std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
-                                    first_node, second_node};
-  auto control_depend = kg->NewCNode(inputs);
-  InsertDependToGraph(kg, NOT_NULL(control_depend));
-}
-
-void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
-                                          const CNodePtr &last_label) {
-  // if this is not the entry graph, replace return with label_goto
-  if (from_graph_call_node != nullptr && last_label != nullptr) {
-    auto label_goto =
-      kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName)), last_label});
-    MS_EXCEPTION_IF_NULL(label_goto);
-    MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString();
-    kg->set_end_goto(label_goto);
-  }
-}
-
-void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
-                                      const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
-  MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
-
-  // 1 get kernel graph
-  const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
-  if (kCNodeCallArg >= origin_inputs.size()) {
-    MS_LOG(EXCEPTION) << "Index out of range, size: " << origin_inputs.size();
-  }
-  std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
-  if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
-    MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
-    return;
-  }
-  // 2 return label
-  auto back_label = kg->NewCNode({std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSetOpName))});
-  MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node "
-               << cur_node->DebugString();
-  // 3 add depend relationship
-  InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label));
-  if (next_node != nullptr && next_node != kg->get_return()) {
-    InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node));
-  }
-  auto call_kg = GetValueNode<KernelGraphPtr>(origin_inputs[kCNodeCallArg]);
-  // 4 modify call op to goto op
-  cur_node->set_input(kCNodePrim, new_inputs[kCNodePrim]);
-  // 5 recurse sub graph
-  CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
-  new_inputs.push_back(sub_label);
-  cur_node->set_inputs(new_inputs);
-  cur_node->set_abstract(nullptr);
-  MS_LOG(INFO) << "Succeeded processing call func " << cur_node->DebugString();
-}
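
(Editorial note, not part of the original patch: a schematic of the lowering RecurseCall performs on a call node, with hypothetical label names.)

    // before:  %n = call(@child_graph, args...)          // in the parent graph
    // after:   %n = LabelGoto(child_start_label)         // jump into the child graph
    //          back_label = LabelSet()                   // return point in the caller,
    //                                                    // ordered after %n via ControlDepend
    // The child graph itself gets an end LabelGoto(back_label) through
    // LinkParentGraph, closing the call/return loop without a call stack.
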
MS_LOG(EXCEPTION) << "The size of origin_switch_inputs is not more than " << kCNodeSwitchCond; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - } - std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); - - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(nullptr); - MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString(); -} - -void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull cur_node, - const CNodePtr &next_node, - const NotNull *> memo) { - MS_LOG(INFO) << "Process switch node " << cur_node->DebugString(); - - if (cur_node->size() < kCNodeSwitchLayerLength) { - MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; - } - - auto branch_tuple = cur_node->input(kCNodeSwitchLayerBranch); - MS_EXCEPTION_IF_NULL(branch_tuple); - if (!branch_tuple->isa()) { - MS_LOG(EXCEPTION) << branch_tuple->DebugString() << " is not a CNode"; - } - const std::vector &branch_partial = utils::cast(branch_tuple)->inputs(); - // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - // 2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - if (next_node != nullptr && next_node != kg->get_return()) { - InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); - } - // 3 recurse sub graph - const std::vector &origin_switch_inputs = cur_node->inputs(); - if (kCNodeSwitchCond >= origin_switch_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << "."; - } - std::vector new_switch_inputs = { - std::make_shared(std::make_shared(kLabelSwitchOpName)), - origin_switch_inputs[kCNodeSwitchCond]}; - for (size_t i = 0; i < branch_partial.size(); ++i) { - // 3.1 branch kernel graph and args - KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 3.2 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); - new_switch_inputs.push_back(branch_label); - } - new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); - cur_node->set_inputs(new_switch_inputs); - cur_node->set_abstract(nullptr); - MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString(); -} - -KernelGraphPtr AscendControlParser::ParsePartial(NotNull node) { - if (!node.get()->isa()) { - if (IsValueNode(node)) { - return GetValueNode(node); - } - MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); - } - // 2.1 branch kernel graph and args - auto partial_cnode = utils::cast(node.get()); - MS_EXCEPTION_IF_NULL(partial_cnode); - if (partial_cnode->size() < kCNodePartialLength) { - MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; - } - - const auto &partial_inputs = partial_cnode->inputs(); - if (kCNodePartialFunc >= partial_inputs.size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << "."; - } - auto branch_kg = 
GetValueNode(partial_inputs[kCNodePartialFunc]); - return branch_kg; -} - -void AscendControlParser::InsertMultipleAssignToGraph(NotNull from_graph, - NotNull to_graph, NotNull from, - NotNull to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert multi-assign from [" << from->DebugString() << "] to [" << to->DebugString() << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); - if (assign_node != nullptr) { - auto jump_node = GetJumpNode(from_graph, to_graph); - const auto &from_graph_exe_order = from_graph->execution_order(); - auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node); - if (jump_node_iter == from_graph_exe_order.end()) { - MS_EXCEPTION_IF_NULL(jump_node); - MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id(); - } - // insert assign between jump_node -1 and jump_node - if (jump_node_iter != from_graph_exe_order.begin()) { - InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node)); - } - if (jump_node != nullptr) { - InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node)); - } - } - } -} - -AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, - NotNull to) { - if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && - AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return nullptr; - } - if (from.get() == to.get()) { - return nullptr; - } - MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " - << to->DebugString(); - // config inputs of assign node - std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; - // generate a new cnode - auto assign_node = kg->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_node); - assign_node->set_abstract(to->abstract()); - return assign_node; -} - -std::vector AscendControlParser::RecurseGraph(NotNull graph, - const NotNull *> memo) { - MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start"; - if (memo->find(graph) != memo->end()) { - return {}; - } - memo->insert(graph.get()); - graph->SetExecOrderByDefault(); - std::vector cnodes = graph->execution_order(); - - auto end_label_goto = graph->get_end_goto(); - if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { - cnodes.pop_back(); - } - AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); - if (end_label_goto != nullptr) { - cnodes.push_back(end_label_goto); - } - - std::vector execution_order; - uint32_t child_order_index = 0; - for (auto &node : cnodes) { - execution_order.push_back(node); - if (node == graph->get_end_goto()) { - continue; - } - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { - std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); - for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { - if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - if 
(child_order_index >= graph->child_graph_order().size()) { - MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size(); - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - } - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { - uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); - if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - } - } - graph->set_execution_order(execution_order); - graph->PrintGraphExecuteOrder(); - return execution_order; -} - -bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, - NotNull graph) { - const std::vector> &child_graph_order = graph->child_graph_order(); - // check index and child order size - if (child_graph_order.size() <= IntToSize(order_index)) { - MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " - << child_graph_order.size() << " goto index " << order_index; - } - auto child_graph = child_graph_order[order_index]; - MS_EXCEPTION_IF_NULL(child_graph); - - // get start_label_set_index of child graph - auto start_label_set = child_graph->get_start_label(); - uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); - if (label_index != start_label_set_index) { - MS_EXCEPTION_IF_NULL(cur_label); - MS_EXCEPTION_IF_NULL(start_label_set); - MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() - << " index " << start_label_set_index << " current child graph order : " << order_index; - return false; - } else { - return true; - } -} - -void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { - MS_LOG(INFO) << "Graph id:" << kg->graph_id(); - kg->SetExecOrderByDefault(); - auto call_nodes = kg->FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); - std::vector child_graph_order; - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); - for (const auto &child_graph : call_child_graphs) { - MS_EXCEPTION_IF_NULL(child_graph); - if (child_graph != kg->parent_graph()) { - child_graph->set_parent_graph(kg.get()); - } - child_graph_order.push_back(child_graph); - } - } - for (size_t i = 0; i < child_graph_order.size(); i++) { - MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; - } - kg->set_child_graph_order(child_graph_order); -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h deleted file mode 100644 index 7530f2019e..0000000000 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H - -#include -#include -#include -#include -#include "session/kernel_graph.h" -#include "utils/base_ref.h" -#include "utils/contract.h" -#include "utils/union_find_set.h" - -namespace mindspore { -namespace session { -class AscendControlParser { - public: - static void ChildGraphDataAssign(const std::map &graph_id_map); - static void LinkGraph(NotNull kg); - - static void InsertDependToGraph(NotNull kg, NotNull attch_node); - static void InsertControlDependToGraph(NotNull kg, NotNull first_node, - NotNull second_node); - static void ExecutorValidate(NotNull root_graph); - static void UpdateChildGraphOrder(NotNull kg); - - private: - static NotNull GetStartLabel(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label); - static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, - const NotNull *> memo); - static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, - const NotNull *> memo); - - static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label); - static KernelGraphPtr ParsePartial(NotNull node); - - static void InsertMultipleAssignToGraph(NotNull from_graph, NotNull to_graph, - NotNull from, NotNull to); - static AnfNodePtr InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - - // root graph order - static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, - NotNull graph); - static std::vector RecurseGraph(NotNull graph, - const NotNull *> memo); -}; -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H diff --git a/mindspore/ccsrc/session/ascend_inference_session.cc b/mindspore/ccsrc/session/ascend_inference_session.cc deleted file mode 100644 index 8593d0104a..0000000000 --- a/mindspore/ccsrc/session/ascend_inference_session.cc +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
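Note — the deleted `RecurseGraph`/`CheckLabelIndex` pair above linearizes the kernel-graph tree: it walks each graph's execution order and, at every `LabelSwitch`/`LabelGoto`, splices in the matching child graph's order, using a memo set so a graph reachable through several labels is emitted only once. A minimal standalone sketch of that traversal shape, with a toy `Graph` struct standing in for `KernelGraph` (none of these names are MindSpore APIs):

#include <iostream>
#include <set>
#include <vector>

struct Graph {
  int id;
  std::vector<Graph *> children;  // analogous to child_graph_order()
};

std::vector<int> Recurse(Graph *g, std::set<Graph *> *memo) {
  if (!memo->insert(g).second) {
    return {};  // already linearized via another label
  }
  std::vector<int> order = {g->id};  // stands in for the graph's own kernels
  for (Graph *child : g->children) {
    auto child_order = Recurse(child, memo);
    order.insert(order.end(), child_order.begin(), child_order.end());
  }
  return order;
}

int main() {
  Graph c{2, {}}, b{1, {&c}}, root{0, {&b, &c}};  // c is reachable twice
  std::set<Graph *> memo;
  for (int id : Recurse(&root, &memo)) std::cout << id << ' ';  // prints: 0 1 2
  std::cout << '\n';
}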
- */ -#include "session/ascend_inference_session.h" -#include "operator/ops.h" -#include "ir/tensor.h" -#include "ir/anf.h" -#include "ir/param_value.h" -#include "device/kernel_runtime.h" -#include "session/anf_runtime_algorithm.h" -#include "common/utils.h" -#include "common/trans.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "utils/config_manager.h" -#include "utils/base_ref_extends.h" - -namespace mindspore { -namespace session { -void AscendInferenceSession::LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - std::vector inputs(inputs_const); - auto input_nodes = kernel_graph->inputs(); - - size_t no_weight_input = 0; - for (size_t i = 0; i < input_nodes.size(); ++i) { - tensor::TensorPtr tensor = nullptr; - if (!input_nodes[i]->isa()) { - MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; - continue; - } - auto pk_node = input_nodes[i]->cast(); - MS_EXCEPTION_IF_NULL(pk_node); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - MS_EXCEPTION_IF_NULL(device_address); - if (!AnfAlgo::IsParameterWeight(pk_node)) { - tensor = inputs[no_weight_input++]; - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } -} - -GraphId AscendInferenceSession::CompileGraph(NotNull func_graph) { - auto graph_id = AscendSession::CompileGraph(func_graph); - auto kernel_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(kernel_graph); - // load weight data to device - auto input_nodes = kernel_graph->inputs(); - for (size_t i = 0; i < input_nodes.size(); ++i) { - if (!input_nodes[i]->isa()) { - MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; - continue; - } - auto pk_node = input_nodes[i]->cast(); - MS_EXCEPTION_IF_NULL(pk_node); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - MS_EXCEPTION_IF_NULL(device_address); - if (AnfAlgo::IsParameterWeight(pk_node)) { - const auto ¶m_value = pk_node->default_param(); - MS_EXCEPTION_IF_NULL(param_value); - auto tensor = std::dynamic_pointer_cast(param_value->value()); - MS_EXCEPTION_IF_NULL(tensor); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } - return graph_id; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_inference_session.h b/mindspore/ccsrc/session/ascend_inference_session.h deleted file mode 100644 index e8ccff3f17..0000000000 --- a/mindspore/ccsrc/session/ascend_inference_session.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
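Note — `AscendInferenceSession` above splits parameter loading in two: `CompileGraph` pushes weight parameters to the device once, from their default values, while `LoadInputData` copies only the non-weight inputs on each run, advancing a separate cursor (`no_weight_input`) over the user-supplied tensors. A self-contained sketch of that bookkeeping; `Param` and `SyncHostToDevice` here are illustrative stand-ins, not the real device-address API:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Param {
  std::string name;
  bool is_weight;
};

// Stand-in for device_address->SyncHostToDevice(...).
void SyncHostToDevice(const std::string &param, const std::string &src) {
  std::cout << param << " <- " << src << '\n';
}

int main() {
  std::vector<Param> graph_inputs = {{"w0", true}, {"x0", false}, {"w1", true}, {"x1", false}};
  // Compile time: weights come from the parameter's default value.
  for (const auto &p : graph_inputs) {
    if (p.is_weight) SyncHostToDevice(p.name, "default_param");
  }
  // Run time: user tensors map onto non-weight inputs only.
  std::vector<std::string> user_inputs = {"t0", "t1"};
  std::size_t no_weight_input = 0;  // same cursor idea as in LoadInputData
  for (const auto &p : graph_inputs) {
    if (!p.is_weight) SyncHostToDevice(p.name, user_inputs[no_weight_input++]);
  }
}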
- */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "session/ascend_session.h" -#include "session/kernel_graph.h" -#include "kernel/kernel.h" -#include "session/session_factory.h" -#include "session/ascend_control_parser.h" - -namespace mindspore { -namespace session { -class AscendInferenceSession : public AscendSession { - public: - AscendInferenceSession() = default; - ~AscendInferenceSession() = default; - void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const; - GraphId CompileGraph(NotNull func_graph) override; -}; -MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc deleted file mode 100644 index 9505eb20ff..0000000000 --- a/mindspore/ccsrc/session/ascend_session.cc +++ /dev/null @@ -1,1752 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/ascend_session.h" -#include -#include -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/tensor.h" -#include "ir/anf.h" -#include "common/trans.h" -#include "device/kernel_runtime.h" -#include "device/ascend/kernel_select_ascend.h" -#include "device/ascend/kernel_build_ascend.h" -#include "device/ascend/ascend_kernel_runtime.h" -#include "device/ascend/ascend_device_address.h" -#include "pre_activate/ascend/ascend_backend_optimization.h" -#include "pre_activate/common/common_backend_optimization.h" -#include "device/kernel_adjust.h" -#include "device/ascend/ascend_stream_assign.h" -#include "device/ascend/ascend_label_assign.h" -#include "predict/predict.h" -#include "session/anf_runtime_algorithm.h" -#include "ir/scalar.h" -#include "debug/anf_ir_dump.h" -#include "debug/anf_ir_utils.h" -#include "debug/draw.h" -#include "common/utils.h" -#include "pre_activate/common/helper.h" -#include "device/kernel_runtime_manager.h" -#include "kernel/tbe/tbe_python_funcs.h" -#include "utils/config_manager.h" -#include "utils/base_ref_extends.h" -#include "debug/tensor_load.h" - -namespace mindspore { -namespace session { -const size_t kInvalidIndex = SIZE_MAX; -constexpr size_t kReturnDataIndex = 1; -namespace { -void DumpGraphExeOrder(const std::vector &execution_order, const std::string &tag = "") { - MS_LOG(INFO) << "Dump execution_order size " << execution_order.size(); - MS_LOG(INFO) << "[index][stream_label][graph_id][node string]"; - int i = 0; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(INFO) << "[ " << i << "]" - << "[" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "]" - << "[" << AnfAlgo::GetGraphId(cnode.get()) << "]" - << "[" << cnode->DebugString() 
<< "]"; - i++; - } - - std::stringstream buf; - buf << "================== execution order ==================\n"; - if (!tag.empty()) { - buf << tag << "\n"; - } - buf << "execution_order size: " << execution_order.size() << "\n"; - i = 0; - for (auto &cnode : execution_order) { - MS_EXCEPTION_IF_NULL(cnode); - buf << i << ":\n"; - buf << "\t" << cnode->DebugString() << "\n"; - buf << "\t" << AnfAlgo::GetStreamDistinctionLabel(cnode.get()) << "\n"; - buf << "\t" << AnfAlgo::GetGraphId(cnode.get()) << "\n"; - i++; - } - buf << "================== execution order ==================\n"; - // std::cout << buf.str() << std::endl; -} - -void DumpGraphInputArgs(const VectorRef &args) { - MS_LOG(INFO) << "Args size[%lu]" << args.size(); - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto anf = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(anf); - MS_LOG(INFO) << "Parameter arg" << i << " = [%s]" << anf->DebugString(); - } else if (utils::isa(args[i])) { - auto value = utils::cast(args[i]); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Tensor arg" << i << " = " << value->ToString(); - } else { - MS_LOG(INFO) << "Unknown arg" << i << " = " << args[i].ToString(); - } - } -} - -void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { - MS_EXCEPTION_IF_NULL(graph); - if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) { - graph->set_stream_distinction_label(label); - } -} - -std::vector GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) { - MS_EXCEPTION_IF_NULL(graph); - std::vector graph_inputs = graph->inputs(); - auto valid_inputs = graph->valid_inputs(); - size_t real_args_size = 0; - std::vector real_args = {}; - for (size_t i = 0; i < args.size(); i++) { - if (utils::isa(args[i])) { - auto tmp_args = AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}); - for (auto &real_arg : tmp_args) { - auto anf_node = utils::cast(real_arg); - MS_EXCEPTION_IF_NULL(anf_node); - auto abstract = anf_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && - !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - real_args_size += tuple_abstract->size(); - continue; - } - real_args_size += 1; - real_args.push_back(real_arg); - } - } else { - real_args_size += 1; - real_args.push_back(args[i]); - } - } - if (graph_inputs.size() != valid_inputs.size()) { - MS_LOG(EXCEPTION) << "Graph_inputs.size(): " << graph_inputs.size() - << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; - } - if (real_args_size != graph_inputs.size()) { - for (size_t j = 0; j < valid_inputs.size(); j++) { - if (valid_inputs[j]) { - MS_LOG(INFO) << "Index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); - } - } - MS_LOG(WARNING) << "Real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() - << " not equal"; - } - return real_args; -} - -std::vector GetCNodes(const std::vector &anf_nodes) { - std::vector cnodes = {}; - size_t i = 0; - for (const auto &anf : anf_nodes) { - MS_LOG(INFO) << "Apply_list[" << i++ << "] = " << anf->DebugString(); - MS_EXCEPTION_IF_NULL(anf); - if (anf->isa()) { - cnodes.push_back(anf->cast()); - } - } - return cnodes; -} - -static std::vector> GetChildList(const std::vector &cnodes, - const std::set &cut_prims) { - size_t after_cut_index = 0; - std::vector> 
ret; - for (size_t i = 0; i < cnodes.size(); ++i) { - bool is_cut_node = false; - for (auto &prim : cut_prims) { - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim)) { - is_cut_node = true; - break; - } - } - if (is_cut_node) { - // is call and not switch call,cut to 3 lists - if (!AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall)) { - // if is not a call,cut to 2 lists - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - after_cut_index = i; - } else if (!AnfAlgo::IsSwitchCall(cnodes[i])) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.begin() + i); - ret.emplace_back(1, cnodes[i]); - after_cut_index = i + 1; - continue; - } - } - // get last child graph list - if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimReturn)) { - ret.emplace_back(cnodes.begin() + after_cut_index, cnodes.end()); - continue; - } - } - return ret; -} - -static void BindCallArgsWithParameter(const std::vector ¶meters, const std::vector &args, - const KernelGraphPtr &graph, KernelGraphPtr child_graph, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(child_graph); - MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id(); - if (args.empty()) { - return; - } - if (parameters.size() != args.size()) { - MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() - << " and args size:" << args.size() << " not equal!"; - } - child_graph->SetExecOrderByDefault(); - for (size_t i = 0; i < parameters.size(); i++) { - MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]" - << args[i]->DebugString(); - if (args[i] == parameters[i]) { - MS_LOG(INFO) << "Parameter and arg are same."; - continue; - } - child_graph->SetRealInput(parameters[i], args[i]); - if (memo->find(child_graph) != memo->end() || !args[i]->isa()) { - MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id(); - child_graph->AddUnreuseArgs(args[i], graph); - } - } -} - -// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of -// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] -static void UpdateRealInput(NotNull graph, bool split_flag, - const NotNull *> memo) { - MS_EXCEPTION_IF_NULL(memo.get()); - auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); - if (child_graphs.size() == 1) { - MS_EXCEPTION_IF_NULL(child_graphs[0]); - std::vector real_args = - std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); - std::vector child_inputs = child_graphs[0]->inputs(); - BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo); - if (split_flag) { - call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); - } - } else if (child_graphs.size() == 2) { - auto get_partial_args = [&](size_t input_index) -> std::vector { - auto switch_node = call_node->input(1); - MS_EXCEPTION_IF_NULL(switch_node); - auto switch_cnode = switch_node->cast(); - MS_EXCEPTION_IF_NULL(switch_cnode); - auto partial = switch_cnode->input(input_index); - MS_EXCEPTION_IF_NULL(partial); - if (IsValueNode(partial)) { - return {}; - } - auto partial_cnode = partial->cast(); - MS_EXCEPTION_IF_NULL(partial_cnode); - auto ret = std::vector(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end()); - 
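Note — `BindCallArgsWithParameter` above zips a child graph's formal parameters with the call-site arguments, skips pairs where the argument already is the parameter, and records the rest as the child's real inputs. The pairing logic in isolation (toy strings instead of AnfNode pointers; the printout stands in for `child_graph->SetRealInput`):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> parameters = {"p0", "p1", "p2"};
  std::vector<std::string> args = {"a0", "p1", "a2"};  // second arg already is the parameter
  assert(parameters.size() == args.size());  // mirrors the size check above

  for (std::size_t i = 0; i < parameters.size(); ++i) {
    if (args[i] == parameters[i]) {
      std::cout << "skip " << parameters[i] << " (parameter and arg are same)\n";
      continue;
    }
    // Stand-in for child_graph->SetRealInput(parameters[i], args[i]).
    std::cout << "bind " << parameters[i] << " <- " << args[i] << '\n';
  }
}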
if (split_flag) { - partial_cnode->set_inputs( - std::vector<AnfNodePtr>(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); - } - return ret; - }; - BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo); - BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo); - } - } -} - -static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, - const NotNull<std::set<KernelGraphPtr> *> memo) { - memo->insert(graph.get()); - MS_LOG(INFO) << "Start graph id:" << graph->graph_id(); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() - << ",parent graph:" << graph->parent_graph()->graph_id(); - continue; - } - RecurseToUpdateCallRealInput(NOT_NULL(child_graph), memo); - } - // this action should run from bottom to top - graph->UpdateCallRealInput(); -} -} // namespace - -GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - MS_LOG(INFO) << "Start"; - // construct the graph; on success, graph_sum_ increases by 1 - auto graph = ConstructKernelGraph(lst, outputs); - auto graph_id = graph->graph_id(); - MS_LOG(INFO) << "Compile graph " << graph_id << " success"; - return graph_id; -} - -GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { - MS_LOG(INFO) << "Start"; - std::vector<KernelGraphPtr> all_graphs; - auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); - BackendOptimization(all_graphs); - // split switch - SplitGraphs(NOT_NULL(root_graph)); - // an empty graph does not enter the backend - if (root_graph->execution_order().empty()) { - MS_LOG(INFO) << root_graph->ToString() << " is empty graph."; - root_graph->set_executable(false); - InitRuntimeResource(); - return root_graph->graph_id(); - } - // insert goto labels and label_sets - LinkChildGraphs(NOT_NULL(root_graph)); - // resource initialize - InitRuntimeResource(); - // recursively compile child graphs - std::set<KernelGraphPtr> memo; - RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo)); - // validate the root graph, including generating the execution order and so on - RootGraphExecutorValidate(NOT_NULL(root_graph)); - // adjust kernel - AdjustKernel(root_graph); - // assign stream - AssignStream(NOT_NULL(root_graph)); - // insert profiling point - device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get())); - // build kernel - BuildKernel(root_graph); - // alloc mem - MemoryAlloc(root_graph.get()); - // task generate - GenerateTaskInfo(root_graph); - // load task into device - LoadTask(root_graph); - DumpAllGraphs(all_graphs); - // return the root_graph id to backend - auto graph_id = root_graph->graph_id(); - return graph_id; -} - -void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto graph_order = GetGraphOrder(kernel_graph->graph_id()); - for (auto graph_id : graph_order) { - auto child_graph = GetGraph(graph_id); - if (child_graph == nullptr) { - continue; - } - if (child_graph->summary_node_exist()) { - kernel_graph->set_summary_node_exist(true); - return; - } - } - kernel_graph->set_summary_node_exist(false); -} - -void AscendSession::BuildGraph(GraphId graph_id) { - MS_LOG(INFO) << "Start"; - auto graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(graph); - // resource initialize - InitRuntimeResource(); - // multiple graph handle - if (graph_id == final_graph_id_) { - if (!graph->executable()) { - return; - } - // insert assigns to child graph - 
InsertAllAssigns(); - // insert switch and active to child graph - MergeSwitchCompile(); - SetFinalGraphSummaryFlag(graph); - // OptChildGraphs - auto graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { - continue; - } - MS_LOG(INFO) << "Start build child graph " << graph_order[i]; - auto child_graph = GetGraph(graph_order[i]); - CompileChildGraph(child_graph); - } - GetSummaryNodes(graph.get()); - // merge child graph - MergeGraphExecOrder(); - } else { - auto single_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(single_graph); - CompileChildGraph(single_graph); - // set the distinction label of single graph - single_graph->set_stream_distinction_label(graph_id); - single_graph->UpdateExecuteKernelStreamLabel(); - } - // adjust execution order because merge child graph and other special operations - AdjustKernel(graph); - // Assign streams for control sink and hccl and so on - AssignStream(NOT_NULL(graph)); - - device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get())); - // build kernel if node is cnode - BuildKernel(graph); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->precompile_only()) { - MS_LOG(INFO) << "Precompile only, stop in build kernel step"; - } else { - // alloc memory, including static memory and dynamic memory - MemoryAlloc(graph.get()); - // generate task info for task sink mode - GenerateTaskInfo(graph); - // load task info to device if it is sink mode - LoadTask(graph); - } - // sync the inital const tensor to device - SyncInitialTenosrToDevice(); - DumpAllGraphs({graph}); - MS_LOG(INFO) << "End"; -} - -void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { - MS_EXCEPTION_IF_NULL(child_graph); - MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); - opt::AscendBackendIRFusionOptimization(child_graph); - opt::AscendBackendFuseBasicOpt(child_graph, true); - opt::AscendBackendGraphKernelOpt(child_graph, true); - child_graph->SetExecOrderByDefault(); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "select_kernel_before" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; - DumpIR(file_path, child_graph); - } - // select kernel build info - SelectKernel(*child_graph); - if (save_graphs) { - std::string file_path = - save_graphs_path + "/" + "select_kernel_after" + "_graph_" + std::to_string(child_graph->graph_id()) + ".ir"; - DumpIR(file_path, child_graph); - } - // convert kernel Graph to model - predictmodel::StepConvertGraph(child_graph); - // optimize graph - HardwareOptimize(child_graph); - // assign static memory of parameters - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignStaticMemoryInput(child_graph.get()); - runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); -} - -void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, - VectorRef *const outputs) { - MS_LOG(INFO) << "Start"; - auto kernel_graph = GetGraph(graph_id); - 
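Note — both `CompileGraph` and `BuildGraph` above are fixed pipelines: optimize, link/merge child graphs, adjust kernels, assign streams, build kernels, allocate memory, then generate and load tasks, each stage mutating the same graph. One readable way to model such a sequence is as a list of named stages; this is a sketch of the idea only, not how `AscendSession` is actually structured:

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  using Stage = std::pair<std::string, std::function<void()>>;
  std::vector<Stage> pipeline = {
      {"adjust kernel", [] { /* AdjustKernel(graph) */ }},
      {"assign stream", [] { /* AssignStream(graph) */ }},
      {"build kernel", [] { /* BuildKernel(graph) */ }},
      {"alloc memory", [] { /* MemoryAlloc(graph) */ }},
      {"generate task", [] { /* GenerateTaskInfo(graph) */ }},
      {"load task", [] { /* LoadTask(graph) */ }},
  };
  for (auto &[name, run] : pipeline) {
    std::cout << "stage: " << name << '\n';
    run();  // each stage sees the graph as mutated by the previous ones
  }
}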
MS_EXCEPTION_IF_NULL(kernel_graph); - // if none of child graph and no anf output exists - if (!kernel_graph->executable()) { - MS_LOG(INFO) << "No child graph has anf output"; - UpdateOutputs(kernel_graph, outputs, inputs); - return; - } - // load input data from user input - LoadInputData(kernel_graph, inputs); - // convert inputs to model - predictmodel::StepConvertWeight(inputs); -#ifdef ENABLE_DEBUGGER - // debugger pre-execution processing - if (debugger_) { - debugger_->PreExecute(kernel_graph); - } -#endif - { - py::gil_scoped_release release; - // run task on device - ExecTask(kernel_graph); - } - // get result from device - UpdateOutputs(kernel_graph, outputs, inputs); - // summary - Summary(kernel_graph.get()); -#ifdef ENABLE_DEBUGGER - // load tensor from device for debugger - if (debugger_ && debugger_->debugger_enabled()) { - LoadTensor(kernel_graph); - } -#endif - // dump used for debug - Dump(kernel_graph); -#ifdef ENABLE_DEBUGGER - // debugger post-execution processing - if (debugger_) { - debugger_->PostExecute(); - } -#endif - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpHardwareOptimize(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start"; - // data layout optimization - opt::RunOpAscendDataLayout(kernel_graph); - // mixed precision optimization - opt::AscendMixPrecision(kernel_graph); - MS_LOG(INFO) << "Finish"; -} - -void AscendSession::RunOpExecTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Run task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const { - if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { - return true; - } - - return false; -} - -void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) { - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; - if (GraphCacheExist(graph_info)) { - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " graph cache has existed !"; - return; - } - - // construct graph include one op - auto graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); - MS_EXCEPTION_IF_NULL(graph); - opt::RunOpAscendBackendIRFusionOptimization(graph); - // kernel select - SelectKernel(*graph); - // optimize - RunOpHardwareOptimize(graph); - // init runtime resource - InitRuntimeResource(); - // build kernel - RunOpAdjustKernel(graph); - BuildKernel(graph); - run_op_graphs_[graph_info] = graph; - MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; -} - -py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) { - auto graph = run_op_graphs_[graph_info]; - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; - // malloc mem - RunOpMemoryAlloc(input_tensors, graph.get()); - // load input data to device - LoadInputData(graph, input_tensors); - // run op - RunOpExecTask(graph); - // get output - VectorRef outputs; - UpdateOutputs(graph, &outputs, input_tensors); - // trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - 
!py::isinstance(utils::cast(output_tensors).object_)) { - MS_LOG(EXCEPTION) << "The output tensors should be a tuple !"; - } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); - RunOpMemoryClear(graph.get()); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; - return tuple_tensors; -} - -// compile graph steps -void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - size_t raise_precision_count = 0; - size_t reduce_precision_count = 0; - for (const auto &cnode : kernel_graph.execution_order()) { - auto status = device::ascend::SelectKernelInfo(cnode); - if (status == device::ascend::kStatusRaisePrecision) { - raise_precision_count++; - } else if (status == device::ascend::kStatusReducePrecision) { - reduce_precision_count++; - } - MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString(); - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - if (ms_context->execution_mode() == kGraphMode) { - if (raise_precision_count > 0) { - MS_LOG(WARNING) << "There has " << raise_precision_count - << " node/nodes used raise precision to selected the kernel!"; - } - if (reduce_precision_count > 0) { - MS_LOG(WARNING) << "There has " << reduce_precision_count - << " node/nodes used reduce precision to selected the kernel!"; - } - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::InitRuntimeResource() { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(EXCEPTION) << "Kernel runtime init error."; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::HardwareOptimize(const std::shared_ptr &kernel_graph) const { - device::ascend::KernelPreBuild(kernel_graph.get()); - MS_LOG(INFO) << "HardwareOptimize start!"; - opt::AscendBackendOptimization(kernel_graph); - opt::AscendGraphKernelCommonProcess(kernel_graph); - opt::AscendBackendFuseBasicOpt(kernel_graph, false); - opt::AscendBackendAddAtomicClean(kernel_graph); - MS_EXCEPTION_IF_NULL(kernel_graph); - kernel_graph->SetExecOrderByDefault(); - MS_LOG(INFO) << "HardwareOptimize Finish!"; -} - -void AscendSession::AdjustKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - opt::HideNopNode(kernel_graph.get()); - // Insert CLearZero op - // prepare for next step from json get atomic info - BuildKernel(kernel_graph); - device::ascend::KernelBuildPreprocess(kernel_graph.get()); - device::KernelAdjust::GetInstance().InsertSwitchLoop(kernel_graph); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - if (save_graphs) { - std::string file_path = save_graphs_path + "/" + "after_adjust_kernel.ir"; - DumpIR(file_path, kernel_graph); - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - opt::HideNopNode(kernel_graph.get()); - // Insert CLearZero op - // prepare for next step from json get atomic info - BuildKernel(kernel_graph); - device::ascend::KernelBuildPreprocess(kernel_graph.get()); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::AssignStream(NotNull kernel_graph) const { - MS_LOG(INFO) 
<< "Start!"; - device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::BuildKernel(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - auto ret = device::ascend::KernelBuild(kernel_graph.get()); - if (!ret) { - MS_LOG(EXCEPTION) << "Kernel build error."; - } - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "KernelBuild run in " << PRIu64 << " us " << cost; - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - opt::RemoveNopNode(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignMemory(kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpMemoryAlloc(const std::vector &input_tensors, - KernelGraph *kernel_graph) const { - MS_LOG(INFO) << "Start memory alloc!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - opt::RemoveNopNode(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpClearMemory(kernel_graph); -} - -void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Generate task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::LoadTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->LoadTask(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "Load task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::ExecTask(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - bool ret_ok = runtime_instance->Run(kernel_graph.get()); - if (!ret_ok) { - MS_LOG(EXCEPTION) << "run task error!"; - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::Dump(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = 
device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - (void)runtime_instance->DumpData(kernel_graph.get()); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::DumpAllGraphs(const std::vector &all_graphs) { -#ifdef ENABLE_DUMP_IR - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - bool save_graphs = context_ptr->save_graphs_flag(); - if (!save_graphs) { - return; - } - auto save_graphs_path = context_ptr->save_graphs_path(); - if (save_graphs_path.empty()) { - save_graphs_path = "."; - } - for (auto &graph : all_graphs) { - MS_EXCEPTION_IF_NULL(graph); - std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(graph->graph_id()) + ".ir"; - DumpIR(file_path, graph, true); - DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id())); - } -#endif -} - -void AscendSession::LoadTensor(const std::shared_ptr &kernel_graph) const { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(kernel_graph); -#ifdef ENABLE_DEBUGGER - auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - DebugServices *debug_services = debugger_->debug_services(); - TensorLoader *tensor_loader = debug_services->get_tensor_loader(); - tensor_loader->EmptyTensor(); - uint32_t iter_num = tensor_loader->GetIterNum(); - tensor_loader->set_iter_num(++iter_num); - (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get()); - tensor_loader->EmptyPrevTensor(); -#endif - MS_LOG(INFO) << "Finish!"; -} - -GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { - MS_LOG(INFO) << "Start! Args size " << args.size(); - auto final_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph_id_ = final_graph->graph_id(); - MS_LOG(INFO) << "Create a new final graph" << final_graph_id_ << " success"; - // init private variables and bind them with final_graph_id - graph_execute_orders_[final_graph_id_] = std::vector(); - graph_order_types_[final_graph_id_] = std::vector(); - for (const auto ¶meter : args) { - MS_EXCEPTION_IF_NULL(parameter); - if (!parameter->isa()) { - MS_LOG(EXCEPTION) << parameter->DebugString() << " is not a parameter type!"; - } - AnfNodePtr parameter_backend = nullptr; - // if function return UINT_MAX,the parameter is not exist in child graph - auto parameter_belong_graph_id = GetGraphIdByNode(parameter); - if (parameter_belong_graph_id == kInvalidGraphId) { - parameter_backend = CreateNewParameterFromParameter(parameter, true, final_graph.get()); - final_graph->FrontBackendlMapAdd(parameter, parameter_backend); - MS_LOG(INFO) << "New parameter" << parameter->DebugString() << "in final_graph"; - } else { - // parametr is a parameter of child graph - auto graph = GetGraph(parameter_belong_graph_id); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Reuse parameter [" << parameter->DebugString() << "] of child graph [" - << parameter_belong_graph_id << "]"; - parameter_backend = graph->GetBackendAnfByFrontAnf(parameter); - // add parameter in backend to final graph inputs - auto final_graph_inputs = final_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(final_graph_inputs); - final_graph_inputs->push_back(parameter_backend); - } - MS_EXCEPTION_IF_NULL(parameter_backend); - MS_LOG(INFO) << "Parameter backend " << parameter_backend->DebugString() << " belong_graph_id " - << AnfAlgo::GetGraphId(parameter_backend.get()); - } - MS_LOG(INFO) << "End 
final_graph_id " << final_graph_id_; - return final_graph_id_; -} - -void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, - std::map> *summary) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(summary); - // if final graph have no child graph - auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); - if (graph_order_iter == graph_execute_orders_.end()) { - SessionBasic::GetSummaryNodes(graph); - auto summary_nodes = graph->summary_nodes(); - summary->insert(summary_nodes.begin(), summary_nodes.end()); - return; - } - // for every child graph, find summary nodes - auto graph_order = GetGraphOrder(graph->graph_id()); - for (size_t i = 0; i < graph_order.size(); i++) { - auto child_graph = GetGraph(graph_order[i]); - if (child_graph == nullptr) { - continue; - } - SessionBasic::GetSummaryNodes(child_graph.get()); - auto child_graph_summary = child_graph->summary_nodes(); - summary->insert(child_graph_summary.begin(), child_graph_summary.end()); - RecurseGetSummaryNodes(child_graph.get(), summary); - } - graph->set_summary_nodes(*summary); -} - -void AscendSession::GetSummaryNodes(KernelGraph *graph) { - MS_LOG(DEBUG) << "Update summary Start"; - MS_EXCEPTION_IF_NULL(graph); - auto summary_nodes = graph->summary_nodes(); - std::map> summary; - summary.insert(summary_nodes.begin(), summary_nodes.end()); - RecurseGetSummaryNodes(graph, &summary); - graph->set_summary_nodes(summary); - MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); -} - -AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { - auto fake_graph = GetGraph(fake_graph_id); - MS_EXCEPTION_IF_NULL(fake_graph); - auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { - auto parameter = fake_graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = fake_graph->NewParameter(parameter); - // Add new parameter to the graph input of fake_graph to sure that all parameters will be allocated memory. 
- auto graph_inputs = fake_graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->push_back(new_parameter); - return new_parameter; - }; - auto create_parameter_from_cnode = [&](const AnfNodePtr &cnode, size_t output_idx) -> AnfNodePtr { - MS_EXCEPTION_IF_NULL(cnode); - auto abstract = cnode->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa()) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]"; - return create_parameter((*tuple_abstract)[output_idx]); - } - return create_parameter(cnode->abstract()); - }; - if (AnfAlgo::CheckPrimitiveType(output_item_with_index.first, prim::kPrimMakeTuple)) { - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - auto make_tuple = output_item_with_index.first->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t i = 1; i < make_tuple->inputs().size(); i++) { - auto input = make_tuple->inputs()[i]; - make_tuple_inputs.push_back(CreateFakeOutput(fake_graph_id, input)); - } - return fake_graph->NewCNode(make_tuple_inputs); - } - return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); -} - -void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { - // get the backend anf node related to the output node of front - auto output_from_graph_id = GetGraphIdByNode(node); - auto output_from_graph = GetGraph(output_from_graph_id); - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id - << "] to final graph"; - MS_EXCEPTION_IF_NULL(output_from_graph); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - // if output is from final graph,it remarks no child graph exist - if (final_graph_id_ == output_from_graph_id) { - MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); - final_graph->set_output(ConstructOutput({node}, final_graph)); - final_graph->set_executable(false); - return; - } - final_graph->set_output(output_from_graph->output()); -} - -void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { - auto value_node = NewValueNode(value); - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - value_node->set_abstract(abstract::FromValue(value)); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); - final_graph->set_executable(false); - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; -} - -void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { - for (auto &output : vec_output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } - } -} - -void AscendSession::SetFinalGraphOutput(const BaseRef &output) { - if (utils::isa(output)) { - auto output_anf_node = utils::cast(output); - SetFinalGraphOutput(output_anf_node); - } else if (utils::isa(output)) { - auto value = utils::cast(output); - SetFinalGraphOutput(value); - } else if (utils::isa(output)) { - auto vec_output = utils::cast(output); - 
SetFinalGraphOutput(vec_output); - } else { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } -} - -void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id) { - MS_LOG(INFO) << "Start!"; - MS_LOG(INFO) << "Condition graph id[" << condition_graph_id << "],true graph id[" << true_graph_id << "]"; - auto condition_graph = GetGraph(condition_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - tensor::TensorPtr tensor = std::make_shared(kNumberTypeInt32, std::vector{1}); - int32_t *val = nullptr; - val = static_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - auto value_node = std::make_shared(tensor); - value_node->set_abstract(abstract::FromValue(tensor, false)); - auto counter_const = condition_graph->NewValueNode(value_node); - condition_graph->AddValueNodeToGraph(counter_const); - // create a new switch op - auto switch_primitive = std::make_shared("StreamSwitch"); - auto cond_output_it = condition_output_.find(condition_graph_id); - if (cond_output_it == condition_output_.end()) { - MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; - } - auto cond_output_kernel = - AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; - MS_EXCEPTION_IF_NULL(cond_output_kernel); - std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; - CNodePtr switch_node = condition_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(switch_node); - switch_node->set_abstract(std::make_shared()); - AnfAlgo::SetGraphId(condition_graph_id, switch_node.get()); - // set attr: cond_ RT_GREATER - AnfAlgo::SetNodeAttr(kAttrSwitchCondition, MakeValue(static_cast(RT_GREATER)), switch_node); - // set attr:data_type - AnfAlgo::SetNodeAttr(kAttrDataType, MakeValue(static_cast(RT_SWITCH_INT64)), switch_node); - // set attr:true branch graph id ,which is same to stream distinction label - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(true_graph_id), switch_node); - // append switch at the end of condition graph - auto return_node = condition_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(condition_graph_id, return_node->input(kReturnDataIndex), switch_node); - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { - auto &graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - auto false_index = ExecOrderOfChildGraph(final_graph_id_, false_graph_id); - if (false_index == kInvalidIndex || false_index == 0) { - return; - } - for (int i = SizeToInt(false_index) - 1; i >= 0; i--) { - size_t graph_index = IntToSize(i); - if (graph_index >= graph_execute_order.size()) { - MS_LOG(EXCEPTION) << "Graph index[" << graph_index << "] out of range[" << graph_execute_order.size() << "]"; - } - if (graph_order_type[graph_index] == COMMON_GRAPH) { - auto true_last_id = graph_execute_order[graph_index]; - MS_LOG(INFO) << "The last graph of if true branch is " << true_last_id; - auto true_last = GetGraph(true_last_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - auto false_last = GetGraph(false_graph_id); - MS_EXCEPTION_IF_NULL(true_last); - MS_EXCEPTION_IF_NULL(false_last); - MS_LOG(INFO) << "The last graph of false branch is " << false_graph_id; - // create fake output - auto fake_output_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(fake_output_graph); - 
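Note — `InsertSwitchToGraph` above appends a `StreamSwitch` kernel to the condition graph: it compares the condition graph's output against a freshly created scalar int32 constant (initialized to 0) under an `RT_GREATER` predicate, and tags the node with the true branch's stream label (`kAttrTrueBranchStream`). Reduced to plain values, the run-time decision looks like this (the `RT_GREATER` semantics are assumed from the attribute name):

#include <cstdint>
#include <iostream>

int main() {
  std::int32_t counter_const = 0;        // the scalar value node created above
  std::int32_t cond_output = 1;          // produced by the condition graph at run time
  std::uint32_t true_branch_stream = 7;  // illustrative stream label

  // RT_GREATER: take the true branch when cond_output > counter_const.
  if (cond_output > counter_const) {
    std::cout << "activate stream " << true_branch_stream << " (true branch)\n";
  } else {
    std::cout << "fall through to the false branch\n";
  }
}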
graph_execute_order.push_back(fake_output_graph->graph_id()); - graph_order_type.push_back(COMMON_GRAPH); - fake_output_graph->set_output(CreateFakeOutput(fake_output_graph->graph_id(), final_graph->output())); - final_graph->set_output(fake_output_graph->output()); - InsertMultipleAssignToGraph(true_last_id, true_last->output(), final_graph->output()); - InsertMultipleAssignToGraph(false_graph_id, false_last->output(), final_graph->output()); - // insert stream active for loop sink - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1) { - // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); - } - break; - } - } -} - -void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, - const AnfNodePtr &output) { - if (switches_.find(cond_graph_id) != switches_.end()) { - MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; - return; - } - switches_[cond_graph_id] = std::pair(true_graph_id, false_graph_id); - condition_output_[cond_graph_id] = output; - MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - // set the type of condition graph - auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - if (cond_graph_index >= graph_order_type.size()) { - MS_LOG(EXCEPTION) << "Cond_graph_index " << cond_graph_index << " out of range " << graph_order_types_.size(); - } - graph_order_type[cond_graph_index] = CONDITION_GRAPH; - // update distinction label of false graph,update before merge to sure the distinction - if (false_graph_id != kInvalidGraphId) { - // false graph and condition in graph same stream - auto condition_graph = GetGraph(cond_graph_id); - MS_EXCEPTION_IF_NULL(condition_graph); - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - // if false graph is a condition graph and has been switch compiled before,it's false should be updated again - auto cond_it = switches_.find(false_graph_id); - while (cond_it != switches_.end() && cond_it->second.second != kInvalidGraphId) { - cond_graph_id = cond_it->first; - false_graph_id = cond_it->second.second; - condition_graph = GetGraph(cond_graph_id); - if (condition_graph == nullptr) { - continue; - } - SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); - cond_it = switches_.find(false_graph_id); - } - } -} // namespace session - -void AscendSession::MergeSwitchCompile() { - auto graph_execute_order = GetGraphOrder(final_graph_id_); - auto &graph_order_type = GetGraphOrderType(final_graph_id_); - for (auto switch_compile : switches_) { - auto cond_graph_id = switch_compile.first; - auto true_graph_id = switch_compile.second.first; - auto false_graph_id = switch_compile.second.second; - MS_LOG(INFO) << "Switch compile: " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; - auto condition_graph = GetGraph(cond_graph_id); - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(condition_graph); - MS_EXCEPTION_IF_NULL(final_graph); - // insert switch to condition graph - InsertSwitchToGraph(cond_graph_id, true_graph_id); - auto cond_graph_index = 
ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); - auto prev_graph_id = kInvalidGraphId; - // if condition graph is the first graph and final graph has assign op,then the final graph is the common graph - if (cond_graph_index == 0 && !final_graph->execution_order().empty()) { - prev_graph_id = final_graph_id_; - // set the distinction label of final graph - SetStreamDistinctionLabel(final_graph, final_graph_id_, true); - // if condition graph is not the first graph - } else if ((cond_graph_index - 1 < graph_execute_order.size()) && - (graph_order_type[cond_graph_index - 1] == COMMON_GRAPH)) { - prev_graph_id = graph_execute_order[cond_graph_index - 1]; - } - // insert stream active to common graph - if (prev_graph_id != kInvalidGraphId) { - InsertStreamActiveToGraph(prev_graph_id, condition_graph->stream_distinction_label()); - } - // if this is a 'if' condition - auto it = while_condition_graphs_.find(cond_graph_id); - if (it == while_condition_graphs_.end()) { - CopyOutputOfIf(false_graph_id); - } else { - // if it is a while,insert a stream active to true graph - GraphId from_graph = it->second; - InsertStreamActiveToGraph(from_graph, condition_graph->stream_distinction_label()); - } - } - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::InsertAllAssigns() { - std::vector> assigns; - for (auto assign : assigns_) { - auto front_anf = std::get<0>(assign); - auto to_graph_id = std::get<1>(assign); - auto input_idx = std::get<2>(assign); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - assigns.emplace_back(std::pair(front_anf, backend_parameter)); - } - // erase the repeat assign - std::set> inserted_nodes; - for (auto &assign : assigns) { - auto front_anf = assign.first; - auto backend_parameter = assign.second; - auto from_graph_id = GetGraphIdByNode(front_anf); - auto from_graph = GetGraph(from_graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); - if (inserted_nodes.find(assign) == inserted_nodes.end()) { - InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); - (void)inserted_nodes.insert(assign); - } - } -} - -// insert active to graph -void AscendSession::SetActive(GraphId from, GraphId to) { - if (while_condition_graphs_.find(to) != while_condition_graphs_.end()) { - MS_LOG(WARNING) << "To " << to << " has been exits in map,from " << from << ",exist from " - << while_condition_graphs_[to]; - return; - } - MS_LOG(INFO) << "From " << from << " to " << to; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - std::vector graph_order_new; - std::vector graph_type_new; - for (size_t i = 0; i < graph_order.size(); i++) { - auto graph_id = graph_order[i]; - graph_order_new.push_back(graph_id); - graph_type_new.push_back(graph_type[i]); - if (from == graph_id) { - graph_order_new.push_back(kInvalidGraphId); - graph_type_new.push_back(BRANCH_END); - } - } - graph_order = graph_order_new; - graph_type = graph_type_new; - // set the graph type of condition graph - graph_type[ExecOrderOfChildGraph(final_graph_id_, to)] = CONDITION_GRAPH; - // record the condition graph into while condition set - while_condition_graphs_[to] = from; -} - -void 
AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx) { - MS_LOG(INFO) << "Start!"; - MS_EXCEPTION_IF_NULL(front_anf); - auto from_graph_id = GetGraphIdByNode(front_anf); - auto from_graph = GetGraph(from_graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - MS_EXCEPTION_IF_NULL(backend_parameter); - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); - MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" - << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) - << "]"; - // a node should not assign to itself - if (backend_arg.get() == backend_parameter.get()) { - return; - } - // if arg is the the parameter of child graph,it is parameter of final graph too - if (front_anf->isa()) { - MS_EXCEPTION_IF_NULL(backend_arg); - MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() - << "] will be replaced."; - to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); - return; - } - MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" - << backend_parameter->DebugString() << "of graph " << to_graph_id; - assigns_.emplace_back(std::tuple(front_anf, to_graph_id, input_idx)); -} - -void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, - size_t input_idx) { - MS_LOG(INFO) << "Start!"; - std::pair graph_input_pair(to_graph_id, input_idx); - initial_tenosrs_[graph_input_pair] = front_tensor; - MS_LOG(INFO) << "Finish!"; -} - -void AscendSession::UpdateGraphOrder(GraphId to_graph_id) { - MS_LOG(INFO) << "To_graph_id " << to_graph_id; - auto &graph_order = GetGraphOrder(final_graph_id_); - auto &graph_type = GetGraphOrderType(final_graph_id_); - for (size_t i = 0; i < graph_order.size(); i++) { - if (graph_order[i] == to_graph_id) { - return; - } - } - // if graph is not in graph order,add it to graph order - SetStreamDistinctionLabel(GetGraph(to_graph_id), to_graph_id, false); - graph_order.push_back(to_graph_id); - graph_type.push_back(COMMON_GRAPH); - for (size_t i = 0; i < graph_order.size(); i++) { - MS_LOG(INFO) << "Index " << i << ",graph_id " << graph_order[i] << ",graph_type" << graph_type[i]; - } -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - auto output_num = AnfAlgo::GetOutputTensorNum(node); - if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - return input_index + output_num; - } - auto valid_inputs = graph->valid_inputs(); - if (valid_inputs[input_index]) { - SetChildGraphParameter(node, graph->graph_id(), input_index); - } else { - MS_LOG(DEBUG) << "Invalid input arg: " << node->DebugString(); - } - return ++input_index; -} - -size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value Node should 
be a tensor, unexpected value: " << value->ToString();
-  }
-  SetChildGraphParameter(value->cast<tensor::TensorPtr>(), graph->graph_id(), input_index);
-  return ++input_index;
-}
-
-size_t AscendSession::SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index) {
-  auto index = input_index;
-  for (auto &arg : vec_args) {
-    if (utils::isa<AnfNodePtr>(arg)) {
-      // the arg is an anf node
-      auto node = utils::cast<AnfNodePtr>(arg);
-      index = SetChildGraphInput(graph, node, input_index);
-    } else if (utils::isa<ValuePtr>(arg)) {
-      // the arg is a tensor
-      auto value = utils::cast<ValuePtr>(arg);
-      index = SetChildGraphInput(graph, value, input_index);
-    } else {
-      MS_LOG(EXCEPTION) << "Unexpected arg type " << arg.ToString();
-    }
-  }
-  return index;
-}
-
-void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) {
-  MS_LOG(INFO) << "Set input of graph " << g;
-  auto to_graph = GetGraph(g);
-  MS_EXCEPTION_IF_NULL(to_graph);
-  DumpGraphInputArgs(args);
-  UpdateGraphOrder(g);
-  auto &graph_inputs = to_graph->inputs();
-  auto real_args = GetRealArgs(to_graph, args);
-  size_t input_index = 0;
-  for (size_t i = 0; i < real_args.size(); i++) {
-    if (input_index >= graph_inputs.size()) {
-      MS_LOG(EXCEPTION) << "Input_index " << input_index << " out of range size " << graph_inputs.size();
-    }
-    auto &real_arg = real_args[i];
-    if (utils::isa<AnfNodePtr>(real_arg)) {
-      // the arg is an anf node
-      auto node = utils::cast<AnfNodePtr>(real_arg);
-      input_index = SetChildGraphInput(to_graph, node, input_index);
-    } else if (utils::isa<ValuePtr>(real_arg)) {
-      // the arg is a tensor
-      auto value = utils::cast<ValuePtr>(real_arg);
-      input_index = SetChildGraphInput(to_graph, value, input_index);
-    } else if (utils::isa<VectorRef>(real_arg)) {
-      // the arg is a VectorRef
-      auto vec_args = utils::cast<VectorRef>(real_arg);
-      input_index = SetChildGraphInput(to_graph, vec_args, input_index);
-    } else {
-      MS_LOG(EXCEPTION) << "Unexpected arg type " << real_arg.ToString();
-    }
-  }
-  MS_LOG(INFO) << "Finish!";
-}
-
-GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
-  for (const auto &graph_item : graphs_) {
-    auto graph = graph_item.second;
-    MS_EXCEPTION_IF_NULL(graph);
-    // if front_anf is a parameter, more than one graph may hold a backend parameter for it
-    if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
-      return graph_item.first;
-    }
-  }
-  MS_EXCEPTION_IF_NULL(front_anf);
-  MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " does not exist in any graph";
-  return kInvalidGraphId;
-}
-
-void AscendSession::MergeGraphExecOrder() {
-  MS_LOG(INFO) << "Start!";
-  // merge graph order
-  auto &graph_order = GetGraphOrder(final_graph_id_);
-  auto &graph_type = GetGraphOrderType(final_graph_id_);
-  auto final_graph = GetGraph(final_graph_id_);
-  MS_EXCEPTION_IF_NULL(final_graph);
-  if (graph_order.empty()) {
-    MS_LOG(WARNING) << "Graph output is a lonely variable not linked to any op!";
-    return;
-  }
-  if (graph_order.size() > 1) {
-    auto context_ptr = MsContext::GetInstance();
-    MS_EXCEPTION_IF_NULL(context_ptr);
-    if (!context_ptr->enable_task_sink()) {
-      MS_LOG(EXCEPTION) << "Control sink network should run with task-sink mode!";
-    }
-  }
-  // if the first graph is a common graph, the final graph has no label, so give the final graph
-  // the same stream label as the first graph
-  SetStreamDistinctionLabel(final_graph, graph_order[0], false);
-  std::vector<CNodePtr> final_exec_order = final_graph->execution_order();
-  KernelGraphPtr last_graph = nullptr;
-  for (size_t i = 0; i < graph_order.size(); i++) {
-    auto graph_id = graph_order[i];
-    if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
-      continue;
-    }
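The loop MergeGraphExecOrder is entering here concatenates each child graph's execution order into one flat list, stamping every kernel with its child graph's stream distinction label as it goes, then folds the child's value nodes and ref map into the final graph. Below is a minimal standalone sketch of that tag-while-concatenating step; FakeKernel and FakeGraph are illustrative stand-ins, not MindSpore types.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Each child graph contributes its kernels, tagged with the graph's
// stream label, to one flat final execution order.
struct FakeKernel {
  std::string name;
  std::uint32_t stream_label = 0;
};

struct FakeGraph {
  std::uint32_t stream_label;
  std::vector<FakeKernel> exec_order;
};

int main() {
  std::vector<FakeGraph> children = {
      {1, {{"matmul"}, {"relu"}}},
      {2, {{"allreduce"}}},
  };
  std::vector<FakeKernel> final_order;
  for (const auto &child : children) {
    // same shape as the std::transform in MergeGraphExecOrder: append the
    // child's order while stamping the child's stream label on each kernel
    std::transform(child.exec_order.begin(), child.exec_order.end(), std::back_inserter(final_order),
                   [&](FakeKernel k) {
                     k.stream_label = child.stream_label;
                     return k;
                   });
  }
  for (const auto &k : final_order) {
    std::cout << k.name << " @stream " << k.stream_label << "\n";
  }
  return 0;
}

One difference to note: the original tags the shared node objects in place through AnfAlgo::SetStreamDistinctionLabel, while this sketch works on value copies for simplicity.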
- auto child_graph = GetGraph(graph_id); - last_graph = child_graph; - MS_EXCEPTION_IF_NULL(child_graph); - auto exec_order = child_graph->execution_order(); - MS_LOG(INFO) << "Merge graph,graph_id " << graph_id; - (void)std::transform(exec_order.begin(), exec_order.end(), std::back_inserter(final_exec_order), - [&](CNodePtr node) -> CNodePtr { - AnfAlgo::SetStreamDistinctionLabel(child_graph->stream_distinction_label(), node.get()); - return node; - }); - // add all value nodes of child graphs to final graph - for (auto &value_node : child_graph->graph_value_nodes()) { - final_graph->AddValueNodeToGraph(value_node); - } - // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); - for (auto &item : child_ref_map) { - if (final_graph->IsInRefOutputMap(item.first)) { - MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; - } - final_graph->AddRefCorrespondPairs(item.first, item.second); - } - } - // set final_exec_order into final graph - MS_EXCEPTION_IF_NULL(final_graph); - DumpGraphExeOrder(final_exec_order); - final_graph->set_execution_order(final_exec_order); -} - -void AscendSession::InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { - MS_EXCEPTION_IF_NULL(from); - MS_EXCEPTION_IF_NULL(to); - if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && - AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { - return; - } - if (from.get() == to.get()) { - return; - } - MS_LOG(INFO) << "Insert assign to graph " << graph_id << " from " << from->DebugString() << " to " - << to->DebugString(); - auto graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(graph); - // config inputs of assign node - std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; - // generate a new cnode - auto assign_node = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(assign_node); - assign_node->set_abstract(to->abstract()); - // append the assign at the end of from graph - InsertDependToGraph(graph_id, assign_node); -} - -void AscendSession::InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to) { - std::vector from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); - std::vector to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); - MS_LOG(INFO) << "Insert assigns from [" << AnfAlgo::GetGraphId(from.get()) << "] to [" - << AnfAlgo::GetGraphId(to.get()) << "]"; - if (from_outputs.size() != to_outputs.size()) { - MS_LOG(INFO) << "From[" << from->DebugString(5) << "] to[" << to->DebugString(5) << "]"; - MS_LOG(EXCEPTION) << "From outputs size[" << from_outputs.size() << "] is not equal to to outputs size[" - << to_outputs.size() << "]"; - } - for (size_t i = 0; i < from_outputs.size(); i++) { - InsertAssignToGraph(graph_id, from_outputs[i], to_outputs[i]); - } -} - -void AscendSession::InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream) { - MS_LOG(INFO) << "Insert stream_active from " << graph_id << " to " << actived_stream; - auto from_graph = GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(from_graph); - std::vector inputs = {NewValueNode(std::make_shared("StreamActive"))}; - auto active_node = from_graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(active_node); - active_node->set_abstract(std::make_shared()); - // set the active stream id into the attr of active node - std::vector active_index_value = {}; - active_index_value.push_back(actived_stream); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, 
MakeValue>(active_index_value), active_node); - // append the active node at the end of from graph - auto return_node = from_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - InsertControlDependToGraph(graph_id, return_node->input(kReturnDataIndex), active_node); -} - -void AscendSession::InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node) { - AscendControlParser::InsertDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(attch_node)); -} - -void AscendSession::InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, - const AnfNodePtr &second_node) { - AscendControlParser::InsertControlDependToGraph(NOT_NULL(GetGraph(graph_id)), NOT_NULL(first_node), - NOT_NULL(second_node)); -} - -size_t AscendSession::ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph) { - auto &graph_order = GetGraphOrder(final_graph); - for (size_t i = 0; i < graph_order.size(); i++) { - if (child_graph == graph_order[i]) { - return i; - } - } - return kInvalidIndex; -} - -std::vector &AscendSession::GetGraphOrder(GraphId final_graph_id) { - auto graph_order_iter = graph_execute_orders_.find(final_graph_id); - if (graph_order_iter == graph_execute_orders_.end()) { - MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no child graph"; - } - return graph_order_iter->second; -} - -// get graph order type vector by graph id -std::vector &AscendSession::GetGraphOrderType(GraphId final_graph_id) { - auto graph_type_iter = graph_order_types_.find(final_graph_id); - if (graph_type_iter == graph_order_types_.end()) { - MS_LOG(EXCEPTION) << "Final graph" << final_graph_id << "has no graph_order_types_"; - } - return graph_type_iter->second; -} - -void AscendSession::SyncInitialTenosrToDevice() { - for (auto &item : initial_tenosrs_) { - auto to_graph_id = item.first.first; - auto input_idx = item.first.second; - auto front_tensor = item.second; - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); - std::vector graph_inputs = to_graph->inputs(); - if (input_idx >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "Input_index " << input_idx << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_idx]; - // sync data from host to device - MS_EXCEPTION_IF_NULL(front_tensor); - size_t tensor_size = front_tensor->data().nbytes(); - auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); - MS_EXCEPTION_IF_NULL(addr); - if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, - front_tensor->data_type(), front_tensor->data_c())) { - MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; - } - } -} - -static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph, const std::vector &list) { - // count the output of every anf node - std::set has_output_nodes; - for (auto &anf_node : list) { - MS_EXCEPTION_IF_NULL(anf_node); - for (auto &input : anf_node->inputs()) { - (void)has_output_nodes.insert(input); - } - } - - auto make_tuple_primitve = NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())); - std::vector make_tuple_inputs = {make_tuple_primitve}; - int output_idx = 0; - MS_EXCEPTION_IF_NULL(new_kernel_graph); - for (auto &anf_node : list) { - if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { - new_kernel_graph->set_return(anf_node); - } - if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString(); - 
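ConstructSplitedGraphOutput, whose body this hunk sits in the middle of, picks the outputs of a split subgraph with a simple set difference: first collect every node that appears as an input of some node in the list, then treat everything never consumed inside the list as an element of the final MakeTuple. A compilable toy version of that idea, with ToyNode as an illustrative stand-in for AnfNodePtr:

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Toy node: a name plus the names of its inputs.
struct ToyNode {
  std::string name;
  std::vector<std::string> inputs;
};

// Nodes that never appear as an input of another node in the list are
// the outputs of the split subgraph (same idea as has_output_nodes above).
std::vector<std::string> CollectOutputs(const std::vector<ToyNode> &list) {
  std::set<std::string> used_as_input;
  for (const auto &node : list) {
    used_as_input.insert(node.inputs.begin(), node.inputs.end());
  }
  std::vector<std::string> outputs;
  for (const auto &node : list) {
    if (used_as_input.find(node.name) == used_as_input.end()) {
      outputs.push_back(node.name);  // would become a MakeTuple element
    }
  }
  return outputs;
}

int main() {
  std::vector<ToyNode> list = {{"a", {}}, {"b", {"a"}}, {"c", {"a"}}, {"d", {"b"}}};
  for (const auto &out : CollectOutputs(list)) {
    std::cout << out << "\n";  // prints: c, d
  }
  return 0;
}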
make_tuple_inputs.push_back(anf_node);
-    }
-  }
-  if (new_kernel_graph->get_return() == nullptr) {
-    new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs));
-  }
-}
-
-std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
-                                                             const std::vector<CNodePtr> &list) {
-  MS_EXCEPTION_IF_NULL(new_kernel_graph);
-  MS_LOG(INFO) << "Start constructing split kernel graph:" << new_kernel_graph->graph_id();
-  MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id();
-  std::vector<AnfNodePtr> call_node_inputs;
-  std::vector<AnfNodePtr> new_graph_inputs;
-  // create new parameters from cnode inputs
-  for (auto &anf_node : list) {
-    MS_EXCEPTION_IF_NULL(anf_node);
-    auto cnode = anf_node->cast<CNodePtr>();
-    for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
-      auto input = cnode->inputs()[input_idx];
-      MS_EXCEPTION_IF_NULL(input);
-      AnfNodePtr new_parameter = nullptr;
-      // check whether the input has already been put into the call args; if one parameter or cnode is used
-      // multiple times, set only one parameter in the graph inputs and one arg in the call node
-      auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input);
-      if (call_input_it != call_node_inputs.end()) {
-        cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]);
-        continue;
-      }
-      // a value node is used directly and may be moved to the new graph
-      if (input->isa<ValueNode>()) {
-        cnode->set_input(input_idx, input);
-        continue;
-      } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
-        // a cnode not in the current child graph is passed in through a new parameter
-        new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
-        cnode->set_input(input_idx, new_parameter);
-      } else {
-        // a cnode already in the current graph needs no parameter
-        continue;
-      }
-      new_graph_inputs.push_back(new_parameter);
-      call_node_inputs.push_back(input);
-    }
-  }
-  // set graph inputs of the new graph
-  auto graph_inputs = new_kernel_graph->MutableInputs();
-  MS_EXCEPTION_IF_NULL(graph_inputs);
-  graph_inputs->clear();
-  std::copy(new_graph_inputs.begin(), new_graph_inputs.end(), std::back_inserter(*graph_inputs));
-
-  MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id();
-  ConstructSplitedGraphOutput(new_kernel_graph, list);
-  MS_LOG(INFO) << "End";
-  return call_node_inputs;
-}
-
-void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
-  MS_LOG(INFO) << "Start BackendCommonOptimization";
-  for (auto &graph : all_graphs) {
-    opt::BackendCommonOptimization(graph);
-  }
-  MS_LOG(INFO) << "End.";
-}
-
-void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
-  std::set<KernelGraphPtr> memo;
-  // if the graph output is nullptr, there is no need to insert a maketuple at the end of the graph
-  if (root_graph->output() == nullptr) {
-    return;
-  }
-  // if the root graph output is a call node, the root graph is the condition graph of an 'if' statement
-  auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
-  if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
-    SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo));
-    for (auto &child_graph : root_graph->child_graph_order()) {
-      RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
-    }
-  } else {
-    RecurseSplitGraph(root_graph, NOT_NULL(&memo));
-  }
-  memo.clear();
-  // add a maketuple to the end of the last child graph to suit the old process
-  auto output_graph = root_graph->child_graph_order().empty() ?
root_graph : root_graph->child_graph_order().back();
-  auto make_tuple = output_graph->NewCNode(
-    {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), output_graph->output()});
-  output_graph->set_output(make_tuple);
-  // replace the real input if the real input is a call
-  RecurseToUpdateCallRealInput(root_graph, NOT_NULL(&memo));
-}
-
-AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
-                                                const std::vector<CNodePtr> &child_graph_list) {
-  // if the child graph list only has a call, then return the existing call
-  if (child_graph_list.size() == 1 && AnfAlgo::CheckPrimitiveType(child_graph_list[0], prim::kPrimCall)) {
-    return child_graph_list[0];
-  }
-  // create a new child graph
-  auto child_graph = NewKernelGraph();
-  MS_EXCEPTION_IF_NULL(child_graph);
-  // create a new value node to bind the child graph
-  auto graph_value_node = graph->NewValueNode(NewValueNode(child_graph));
-  std::vector<AnfNodePtr> new_call_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())),
-                                            graph_value_node};
-  // set the graph id of all nodes of the child graph
-  for (auto &child_graph_node : child_graph_list) {
-    AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get());
-  }
-  auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
-  std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
-  auto new_call = graph->NewCNode(new_call_input);
-  AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call);
-  return new_call;
-}
-
-void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
-                               const NotNull<std::set<KernelGraphPtr> *> memo) {
-  MS_LOG(INFO) << "Start, graph_id:" << graph->graph_id();
-  bool split_flag = false;
-  auto apply_list = GetCNodes(TopoSort(graph->get_return()));
-  // update the root graph child graph order
-  AscendControlParser::UpdateChildGraphOrder(graph);
-  // get the child list from the current graph
-  std::vector<std::vector<CNodePtr>> child_graph_lists = GetChildList(apply_list, cut_prims);
-  if (child_graph_lists.size() > 1) {
-    std::list<AnfNodePtr> depend_input = {};
-    for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) {
-      auto call_node = BindNewCallToNewGraph(graph, child_graph_lists[call_index]);
-      MS_EXCEPTION_IF_NULL(call_node);
-      // if the call node is the last call of the true-branch graph, no child graph needs to be created after it
-      auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
-      depend_input.push_front(call_node);
-      if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) {
-        break;
-      }
-    }
-    depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name()))));
-    auto depend = graph->NewCNode(std::vector<AnfNodePtr>(depend_input.begin(), depend_input.end()));
-    auto new_return_primitive =
-      graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name())));
-    graph->set_return(graph->NewCNode({new_return_primitive, depend}));
-    AnfNodePtr pre_call_node = nullptr;
-    AnfNodePtr cur_call_node = nullptr;
-    auto iter = depend_input.begin();
-    for (++iter; iter != depend_input.end(); ++iter) {
-      pre_call_node = cur_call_node;
-      cur_call_node = *iter;
-      if (pre_call_node != nullptr && cur_call_node != nullptr) {
-        AscendControlParser::InsertControlDependToGraph(graph, NOT_NULL(cur_call_node), NOT_NULL(pre_call_node));
-      }
-    }
-    split_flag = true;
-  }
-  AscendControlParser::UpdateChildGraphOrder(graph);
-  UpdateRealInput(graph, split_flag, memo);
-  MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
-}
-
-void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const
NotNull *> memo) { - memo->insert(graph.get()); - SplitGraph(graph, {prim::kPrimCall}, memo); - for (auto &child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) == memo->end()) { - RecurseSplitGraph(NOT_NULL(child_graph), memo); - } - } -} - -void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } - -void AscendSession::RootGraphExecutorValidate(NotNull graph) { - AscendControlParser::ExecutorValidate(graph); -} - -void AscendSession::RecurseCompileGraph(NotNull graph, const NotNull *> memo) { - memo->insert(graph.get()); - CompileChildGraph(graph); - for (auto child_graph : graph->child_graph_order()) { - if (memo->find(child_graph) != memo->end()) { - continue; - } - RecurseCompileGraph(NOT_NULL(child_graph), memo); - // copy ref map to final graph - auto child_ref_map = child_graph->GetRefMap(); - for (auto &item : child_ref_map) { - if (graph->IsInRefOutputMap(item.first)) { - MS_LOG(EXCEPTION) << "The ref pair is already in final graph!"; - } - graph->AddRefCorrespondPairs(item.first, item.second); - } - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h deleted file mode 100755 index 8a6df2bd26..0000000000 --- a/mindspore/ccsrc/session/ascend_session.h +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H -#define MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "kernel/kernel.h" -#include "session/session_factory.h" -#include "session/ascend_control_parser.h" - -namespace mindspore { -namespace session { -enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; - -class AscendSession : public SessionBasic { - public: - AscendSession() { final_graph_id_ = kInvalidGraphId; } - ~AscendSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kAscendDevice, device_id); - } - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - GraphId CompileGraph(NotNull func_graph) override; - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - void BuildGraph(GraphId) override; - void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; - - // set parameters of final graph - GraphId SetFinalGraphInput(const std::vector &args) override; - // set output of final graph - void SetFinalGraphOutput(const BaseRef &output) override; - // insert switch and set the relative active ops - void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - void SetChildGraphInput(GraphId g, const VectorRef &args) override; - // get graph id in child graphs by ME front anf node pointer - GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; - // get graph id of final graph - GraphId GetFinalRunGraph() const override { return final_graph_id_; } - // insert active to graph - void SetActive(GraphId, GraphId) override; - // compile child graph when session have multiple child graphs - void CompileChildGraph(const KernelGraphPtr &child_graph); - void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); - void GetSummaryNodes(KernelGraph *graph); - - private: - void InitRuntimeResource(); - void SelectKernel(const KernelGraph &kernel_graph) const; - void HardwareOptimize(const std::shared_ptr &kernel_graph) const; - void AdjustKernel(const std::shared_ptr &kernel_graph) const; - void RunOpAdjustKernel(const std::shared_ptr &kernel_graph) const; - void AssignStream(NotNull kernel_graph) const; - void BuildKernel(const std::shared_ptr &kernel_graph) const; - void MemoryAlloc(KernelGraph *kernel_graph) const; - void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; - void RunOpMemoryClear(const KernelGraph *kernel_graph) const; - void GenerateTaskInfo(const std::shared_ptr &kernel_graph) const; - void LoadTask(const std::shared_ptr &kernel_graph) const; - void ExecTask(const std::shared_ptr &kernel_graph) const; - void Dump(const std::shared_ptr &kernel_graph) const; - void DumpAllGraphs(const std::vector &all_graphs); - void LoadTensor(const std::shared_ptr &kernel_graph) const; - // below functions are used for run op - void RunOpHardwareOptimize(const std::shared_ptr 
&kernel_graph) const;
-  void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
-
-  size_t SetChildGraphInput(const KernelGraphPtr &graph, const AnfNodePtr &node, size_t input_index);
-  size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index);
-  size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index);
-
-  void SetFinalGraphOutput(const AnfNodePtr &node);
-  void SetFinalGraphOutput(const ValuePtr &value);
-  void SetFinalGraphOutput(const VectorRef &vec_output);
-
-  void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
-                  const NotNull<std::set<KernelGraphPtr> *> memo);
-  // recursively split graphs starting from the root graph
-  void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
-  void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
-  void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
-  void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
-  std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
-                                                const std::vector<CNodePtr> &list);
-  void RecurseCompileGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
-  void RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo);
-  AnfNodePtr BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph, const std::vector<CNodePtr> &child_graph_list);
-
-  // merge execution order list of child graphs
-  void MergeGraphExecOrder();
-  // insert assign op to sync data between different graphs
-  void InsertAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
-  // insert multiple assigns to graph
-  void InsertMultipleAssignToGraph(GraphId graph_id, const AnfNodePtr &from, const AnfNodePtr &to);
-  // insert active op to graph
-  void InsertStreamActiveToGraph(GraphId graph_id, uint32_t actived_stream);
-  // get execute index of graph
-  size_t ExecOrderOfChildGraph(GraphId final_graph, GraphId child_graph);
-  // handle condition graph from vm
-  void InsertSwitchToGraph(GraphId condition_graph_id, GraphId true_graph_id);
-  // insert depend to graph, used to attach control nodes to graph
-  void InsertDependToGraph(GraphId graph_id, const AnfNodePtr &attch_node);
-  // insert control depend to graph, used to attach control nodes to graph
-  void InsertControlDependToGraph(GraphId graph_id, const AnfNodePtr &first_node, const AnfNodePtr &second_node);
-  // set child graph parameter if the front arg is an anf node
-  void SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId to_graph_id, size_t input_idx);
-  // set child graph parameter if the front arg is a tensor
-  void SetChildGraphParameter(const tensor::TensorPtr &front_tensor, GraphId to_graph_id, size_t input_idx);
-  // update the execution order of all child graphs
-  void UpdateGraphOrder(GraphId to_graph);
-  // handle switch when merging
-  void MergeSwitchCompile();
-  // get graph order vector by graph id
-  std::vector<GraphId> &GetGraphOrder(GraphId final_graph_id);
-  // get graph order type vector by graph id
-  std::vector<GraphType> &GetGraphOrderType(GraphId final_graph_id);
-  // copy output of if and else
-  void CopyOutputOfIf(GraphId false_graph_id);
-  // check if the graph cache exists
-  bool GraphCacheExist(const GraphInfo &graph_info) const;
-  // insert all assigns to the child graph
-  void InsertAllAssigns();
-  // create fake output of final graph
-  AnfNodePtr CreateFakeOutput(GraphId final_graph_id, const AnfNodePtr &true_output);
-  // sync initial tensors' data to device
-  void SyncInitialTenosrToDevice();
-  void SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph);
-
-  // member variables
-  // key is final_graph_id, value is the child graph execution order of the final graph
-  std::unordered_map<GraphId, std::vector<GraphId>> graph_execute_orders_;
-  // key
is final_graph_id,value is the graph types of child graphs - std::unordered_map> graph_order_types_; - // record condition graph of while - std::unordered_map while_condition_graphs_; - // record all conditions - std::unordered_map> switches_; - std::unordered_map condition_output_; - // share parameters - std::vector> assigns_; - // initial tensors, these tensor will sync data to device before run graph - std::map, tensor::TensorPtr> initial_tenosrs_; - // final_graph_id is used in every root graph has it's own session situation - GraphId final_graph_id_; -}; -MS_REG_SESSION(kAscendDevice, AscendSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_ASCEND_SESSION_H diff --git a/mindspore/ccsrc/session/cpu_session.cc b/mindspore/ccsrc/session/cpu_session.cc deleted file mode 100644 index 1927df2f49..0000000000 --- a/mindspore/ccsrc/session/cpu_session.cc +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "session/cpu_session.h" -#include -#include "ir/tensor.h" -#include "ir/anf.h" -#include "kernel/kernel.h" -#include "common/utils.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_runtime.h" -#include "predict/predict.h" -#include "kernel/cpu/cpu_kernel_factory.h" -#include "device/cpu/kernel_select_cpu.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -namespace session { -ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; - } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - ParameterPtr new_parameter = graph->NewParameter(anf->cast()); - TraceManager::EndTrace(); - graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); - return new_parameter; -} - -GraphId CPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - auto graph_id = graph_sum_; - auto graph = ConstructKernelGraph(lst, outputs); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Set kernel info"; - SetKernelInfo(graph.get()); - predictmodel::StepConvertGraph(graph); - MS_LOG(INFO) << "Build kernel"; - BuildKernel(graph.get()); - MS_LOG(INFO) << "Assign kernel address"; - runtime_.AssignKernelAddress(graph.get()); - return graph_id; -} - -void CPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - auto &kernel_graph = graphs_[graph_id]; - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_LOG(INFO) << "Bind input output address"; - std::vector need_sync_outputs; - runtime_.BindInputOutput(kernel_graph.get(), 
inputs, outputs, &need_sync_outputs); - MS_LOG(INFO) << "Run graph start"; - predictmodel::StepConvertWeight(inputs); - auto execution_order = kernel_graph->execution_order(); - Reorder(&execution_order); - - bool enable_summary = summary_callback_ != nullptr; - kernel_graph->set_execution_order(execution_order); - NamedSummaryOutputs summary_outputs; - if (enable_summary) { - GetSummaryNodes(kernel_graph.get()); - summary_outputs = kernel_graph->summary_nodes(); - runtime_.IncreaseSummaryRefCount(summary_outputs); - } -#ifdef ENABLE_DEBUGGER - // debugger pre-execution processing - if (debugger_) { - debugger_->PreExecute(kernel_graph); - } -#endif - bool ret = runtime_.Run(kernel_graph.get()); - if (!ret) { - MS_LOG(EXCEPTION) << "Run graph failed"; - } - for (auto output : need_sync_outputs) { - (void)output->data_sync(); - } - - if (enable_summary) { - Summary(kernel_graph.get()); - runtime_.DecreaseSummaryRefCount(summary_outputs); - } - -#ifdef ENABLE_DEBUGGER - // debugger post-execution processing - if (debugger_) { - debugger_->PostExecute(); - } -#endif - MS_LOG(INFO) << "Run graph end"; -} - -void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - device::cpu::SetKernelInfo(kernel_node); - } -} - -void CPUSession::BuildKernel(const KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto &kernel_nodes = kernel_graph->execution_order(); - for (const auto &kernel_node : kernel_nodes) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "]."; - std::shared_ptr cpu_kernel = - kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node); - if (cpu_kernel == nullptr) { - MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support."; - } - cpu_kernel->Init(kernel_node); - AnfAlgo::SetKernelMod(cpu_kernel, kernel_node.get()); - MS_LOG(INFO) << "Cpu build success operator[" << kernel_name << "]."; - } -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/cpu_session.h b/mindspore/ccsrc/session/cpu_session.h deleted file mode 100644 index 36b987e840..0000000000 --- a/mindspore/ccsrc/session/cpu_session.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_SESSION_CPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_CPU_SESSION_H -#include -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "device/cpu/cpu_kernel_runtime.h" -#include "session/session_factory.h" -namespace mindspore { -namespace session { -class CPUSession : public SessionBasic { - public: - CPUSession() = default; - ~CPUSession() override = default; - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kCPUDevice, device_id); - } - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - - protected: - ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) override; - - private: - void SetKernelInfo(const KernelGraph *kernel_graph); - void BuildKernel(const KernelGraph *kernel_graph); - device::cpu::CPUKernelRuntime runtime_; -}; -MS_REG_SESSION(kCPUDevice, CPUSession); -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_CPU_SESSION_H diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc deleted file mode 100644 index 8d6d176970..0000000000 --- a/mindspore/ccsrc/session/gpu_session.cc +++ /dev/null @@ -1,268 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "session/gpu_session.h" -#include "device/gpu/kernel_info_setter.h" -#include "device/gpu/gpu_kernel_build.h" -#include "device/gpu/gpu_kernel_runtime.h" -#include "device/gpu/gpu_stream_assign.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/common/helper.h" -#include "pre_activate/pass/communication_op_fusion.h" -#include "pre_activate/pass/getitem_tuple.h" -#include "pre_activate/gpu/adam_weight_decay_fusion.h" -#include "pre_activate/gpu/adam_fusion.h" -#include "device/kernel_runtime_manager.h" -#include "predict/predict.h" -#include "common/utils.h" -#include "common/trans.h" -#include "utils/context/ms_context.h" -#include "utils/base_ref_extends.h" - -namespace mindspore { -namespace session { -namespace gpu { -using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; - -void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - for (const auto &kernel_node : kernel_graph->execution_order()) { - MS_EXCEPTION_IF_NULL(kernel_node); - device::gpu::SetKernelInfo(kernel_node); - } -} - -void GPUSession::StartKernelRT() const { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Init()) { - MS_LOG(EXCEPTION) << "GPU start kernel runtime failed"; - } -} - -void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_graph) { - auto optimizer = std::make_shared(); - auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); - pm->AddPass(std::make_shared()); - optimizer->AddPassManager(pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - -void GPUSession::AssignStream(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - device::gpu::AssignGpuStream(kernel_graph); -} - -void GPUSession::BuildKernel(const std::shared_ptr &kernel_graph) const { - device::gpu::GpuBuild(kernel_graph); -} - -void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->AssignMemory(kernel_graph); -} - -void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, - KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); -} - -void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpClearMemory(kernel_graph); -} - -void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const { - 
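The LoadInputData body that follows decides, tensor by tensor, whether host data actually needs to be copied to the device: a dirty tensor or one with no device address must be synced; a tensor already bound to the graph's address is skipped; and a tensor whose address has the same device type can donate its address to the parameter instead of copying. A distilled sketch of that decision, ignoring the PyNative branch; FakeTensor and FakeAddress are simplified stand-ins, not MindSpore types.

#include <iostream>

enum class DeviceType { kCPU, kGPU };

struct FakeAddress { DeviceType type; };

struct FakeTensor {
  bool dirty;
  const FakeAddress *address;  // nullptr => never placed on a device
};

// Mirrors the need_sync logic below: copy host data only when required.
bool NeedSync(const FakeTensor &tensor, const FakeAddress &graph_address) {
  if (tensor.dirty || tensor.address == nullptr) return true;  // stale or absent
  if (tensor.address == &graph_address) return false;          // already bound
  // same device type: the parameter can adopt the tensor's address instead
  return tensor.address->type != graph_address.type;
}

int main() {
  FakeAddress gpu{DeviceType::kGPU};
  FakeTensor fresh{false, &gpu};
  FakeTensor stale{true, &gpu};
  std::cout << NeedSync(fresh, gpu) << " " << NeedSync(stale, gpu) << "\n";  // prints: 0 1
  return 0;
}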
std::vector inputs(inputs_const); - MS_EXCEPTION_IF_NULL(kernel_graph); - auto input_nodes = kernel_graph->inputs(); - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - - for (size_t i = 0; i < inputs.size(); ++i) { - auto tensor = inputs[i]; - MS_EXCEPTION_IF_NULL(tensor); - auto input_node = input_nodes[i]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - auto tensor_address = tensor->device_address(); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor_address == nullptr || tensor_address != device_address) { - need_sync = true; - } - } else if (tensor->is_dirty() || tensor_address == nullptr) { - need_sync = true; - } else if (tensor_address != device_address) { - if (tensor_address->DeviceType() == device_address->DeviceType()) { - AnfAlgo::SetOutputAddr(tensor_address, 0, pk_node.get()); - } else { - need_sync = true; - } - } - if (need_sync) { - tensor->set_device_address(device_address); - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } - tensor->set_dirty(false); - } -} - -void GPUSession::Execute(const std::shared_ptr &kernel_graph) const { - auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); - MS_EXCEPTION_IF_NULL(runtime_instance); - if (!runtime_instance->Run(kernel_graph.get())) { - MS_LOG(EXCEPTION) << "GPU execute graph failed!"; - } -} - -GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - // Construct graph, if successfully, graph_sum_ + 1 - auto graph_id = graph_sum_; - auto graph = ConstructKernelGraph(lst, outputs); - MS_EXCEPTION_IF_NULL(graph); - // Optimize - Optimize(graph); - // Select kernel build info - SelectKernel(graph); - // Convert kernel Graph to model - predictmodel::StepConvertGraph(graph); - // Start gpu kernel runtime - StartKernelRT(); - // HardwareOptimize - HardwareOptimize(graph); - // Assign CUDA streams - AssignStream(graph); - // Hide NoOp from execution graph - opt::HideNopNode(graph.get()); - // Build kernel if node is cnode - BuildKernel(graph); - // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph - auto execution_order = graph->execution_order(); - Reorder(&execution_order); - graph->set_execution_order(execution_order); - // Get summary nodes. - GetSummaryNodes(graph.get()); - // Remove NoOp from execution graph - opt::RemoveNopNode(graph.get()); - // Set graph manager. 
- MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = MakeManager({graph}); - context_->AddManager(manager); - if (manager) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - // Alloc memory, including static memory and dynamic memory - AllocateMemory(graph.get()); - return graph_id; -} - -void GPUSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { - auto &kernel_graph = graphs_[graph_id]; - // Load input data from user input - LoadInputData(kernel_graph, inputs); - MS_EXCEPTION_IF_NULL(kernel_graph); - // Convert inputs to model - predictmodel::StepConvertWeight(inputs); - { - py::gil_scoped_release gil_release; - // Run graph on GPU - Execute(kernel_graph); - } - // Get result from GPU - UpdateOutputs(kernel_graph, outputs, inputs); - // Summary - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->enable_gpu_summary()) { - Summary(kernel_graph.get()); - } -} - -void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) { - // Check if the graph cache exists. - if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) { - return; - } - // Prepare the graph - auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask); - MS_EXCEPTION_IF_NULL(kernel_graph); - SelectKernel(kernel_graph); - StartKernelRT(); - // Hide NoOp from execution graph - opt::HideNopNode(kernel_graph.get()); - BuildKernel(kernel_graph); - run_op_graphs_[graph_info] = kernel_graph; -} - -py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) { - auto kernel_graph = run_op_graphs_[graph_info]; - MS_EXCEPTION_IF_NULL(kernel_graph); - // Remove NoOp from execution graph - opt::RemoveNopNode(kernel_graph.get()); - RunOpAllocateMemory(input_tensors, kernel_graph.get()); - // Execute the computation - LoadInputData(kernel_graph, input_tensors); - Execute(kernel_graph); - // Fetch outputs - VectorRef outputs; - UpdateOutputs(kernel_graph, &outputs, input_tensors); - // Trans output to tuple - auto output_tensors = TransformBaseRefListToTuple(outputs); - if (!utils::isa(output_tensors) || - !py::isinstance(utils::cast(output_tensors).object_)) { - MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !"; - } - py::object tuple_obj = utils::cast(output_tensors).object_; - py::tuple tuple_tensors = py::cast(tuple_obj); - RunOpClearMemory(kernel_graph.get()); - return tuple_tensors; -} -} // namespace gpu -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/gpu_session.h b/mindspore/ccsrc/session/gpu_session.h deleted file mode 100644 index 4e46c2138d..0000000000 --- a/mindspore/ccsrc/session/gpu_session.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_SESSION_GPU_SESSION_H -#define MINDSPORE_CCSRC_SESSION_GPU_SESSION_H - -#include -#include -#include "session/session_basic.h" -#include "session/kernel_graph.h" -#include "session/session_factory.h" -using KernelGraph = mindspore::session::KernelGraph; - -namespace mindspore { -namespace session { -namespace gpu { -class GPUSession : public SessionBasic { - public: - GPUSession() = default; - ~GPUSession() override = default; - - void Init(uint32_t device_id) override { - SessionBasic::Init(device_id); - context_ = std::make_shared(kGPUDevice, device_id); - } - - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; - void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors, const std::vector &tensors_mask) override; - py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - const std::vector &input_tensors) override; - - private: - void SelectKernel(const std::shared_ptr &kernel_graph) const; - - void StartKernelRT() const; - - void Optimize(const std::shared_ptr &kernel_graph); - - void HardwareOptimize(const std::shared_ptr &kernel_graph); - - void AssignStream(const std::shared_ptr &kernel_graph); - - void BuildKernel(const std::shared_ptr &kernel_graph) const; - - void AllocateMemory(KernelGraph *kernel_graph) const; - - void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; - - void RunOpClearMemory(KernelGraph *kernel_graph) const; - - void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const override; - - void Execute(const std::shared_ptr &kernel_graph) const; -}; -using GPUSessionPtr = std::shared_ptr; -MS_REG_SESSION(kGPUDevice, GPUSession); -} // namespace gpu -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_GPU_SESSION_H diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc deleted file mode 100644 index c8cc6fbbee..0000000000 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ /dev/null @@ -1,998 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "session/kernel_graph.h" -#include -#include -#include -#include -#include "operator/ops.h" -#include "ir/param_value.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" -#include "kernel/kernel_build_info.h" -#include "device/kernel_runtime_manager.h" -#include "kernel/common_utils.h" - -namespace mindspore { -namespace session { -namespace { -constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput"; -constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList"; -void PushNoVisitedNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(que); - MS_EXCEPTION_IF_NULL(visited_nodes); - if (visited_nodes->find(node) == visited_nodes->end()) { - que->push(node); - (void)visited_nodes->insert(node); - MS_LOG(DEBUG) << "Push que:" << node->DebugString(); - } -} - -std::vector GetCallRealOutputs(const AnfNodePtr &call_node) { - auto item_with_index = - AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple}); - AnfNodePtr node = item_with_index.first; - MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) { - auto outputs = AnfAlgo::GetAllOutput(node); - std::set memo; - std::vector new_output; - for (auto &output : outputs) { - if (memo.find(output) != memo.end()) { - continue; - } - memo.insert(output); - new_output.push_back(output); - } - if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) { - node = new_output[0]; - } - } - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) { - return {node}; - } - std::vector real_inputs; - auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node->cast()); - for (const auto &child_graph : child_graphs) { - if (child_graph->get_output_null()) { - continue; - } - auto real_input = child_graph->output(); - auto child_real_inputs = GetCallRealOutputs(real_input); - std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs)); - } - return real_inputs; -} - -AnfNodePtr MakeValueNode(const AnfNodePtr &node) { - auto value_node = node->cast(); - if (value_node == nullptr) { - return nullptr; - } - - ValueNodePtr new_value_node = std::make_shared(value_node->value()); - new_value_node->set_abstract(value_node->abstract()); - // create kernel_info fo new value node - auto kernel_info = std::make_shared(); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - // set the format of value_node to DEFAULT_FORMAT - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - // set value node initial device data type = infer data type - std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { - types.push_back(kTypeUnknown); - } - kernel_build_info_builder->SetOutputsDeviceType(types); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - return new_value_node; -} - -bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) { - if (left == right) { - return true; - } - if (left == nullptr || right == nullptr) { - return false; - } - if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) { - return false; - } - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) { - return AnfAlgo::GetNodeAttr(left, kAttrLabelIndex) == - 
AnfAlgo::GetNodeAttr(right, kAttrLabelIndex); - } - return false; -} -} // namespace -std::vector KernelGraph::outputs() const { - auto graph_output = output(); - if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) { - auto make_tuple = output()->cast(); - MS_EXCEPTION_IF_NULL(make_tuple); - auto &inputs = make_tuple->inputs(); - return std::vector(inputs.begin() + 1, inputs.end()); - } - return std::vector(1, graph_output); -} - -void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(visit_queue); - MS_EXCEPTION_IF_NULL(visited_nodes); - auto it = node_output_edges_.find(node); - if (it == node_output_edges_.end()) { - // value node and parameter has no input,no need to print log - if (node->isa()) { - MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; - } - return; - } - - // visit all reduce node first, then other nodes - std::vector active_nodes; - for (const auto &output_edge : it->second) { - auto next_node = output_edge.first; - MS_EXCEPTION_IF_NULL(next_node); - if (node_input_num_.find(next_node) == node_input_num_.end()) { - MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; - } - MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() - << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; - if (node_input_num_[next_node] < output_edge.second) { - MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] - << ",depend edge:" << output_edge.second; - } - node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; - // allreduce first - if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { - (void)visited_nodes->insert(next_node); - if (AnfAlgo::IsCommunicationOp(next_node)) { - MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); - visit_queue->push(next_node); - } else { - active_nodes.emplace_back(next_node); - } - } - } - - for (auto &node : active_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); - visit_queue->push(node); - } -} - -void KernelGraph::SetExecOrderByDefault() { - std::queue seed_nodes; - UpdateNodeEdgeList(&seed_nodes); - execution_order_.clear(); - std::unordered_set visited_nodes; - std::queue zero_input_nodes; - AnfNodePtr last_communication_node = nullptr; - std::queue communication_descendants; - while (!seed_nodes.empty() || last_communication_node != nullptr) { - // seed nodes first, then visit last all reduce node descendant - if (seed_nodes.empty()) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); - last_communication_node = nullptr; - } else { - zero_input_nodes.push(seed_nodes.front()); - seed_nodes.pop(); - } - // all reduce node descendant first, then common queue - while (!zero_input_nodes.empty() || !communication_descendants.empty()) { - AnfNodePtr node = nullptr; - bool is_communication_descendant = false; - if (communication_descendants.empty()) { - node = zero_input_nodes.front(); - zero_input_nodes.pop(); - } else { - node = communication_descendants.front(); - communication_descendants.pop(); - is_communication_descendant = true; - } - // add execute node - MS_EXCEPTION_IF_NULL(node); - if (node->isa() && AnfAlgo::IsRealKernel(node)) { - execution_order_.push_back(node->cast()); - } - // for all reduce node, visit 
last all reduce node descendant - if (AnfAlgo::IsCommunicationOp(node)) { - if (last_communication_node != nullptr) { - VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); - } - last_communication_node = node; - } else if (is_communication_descendant) { - VisitNodeDescendants(node, &communication_descendants, &visited_nodes); - } else { - VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); - } - } - } - CheckLoop(); - // resort start label / end goto - std::vector re_order; - if (start_label_ != nullptr) { - re_order.push_back(start_label_); - } - for (auto &node : execution_order_) { - if (node == start_label_ || node == end_goto_) { - continue; - } - - if (IsSameLabel(node, end_goto_)) { - end_goto_ = node; - MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id(); - continue; - } - - if (IsSameLabel(node, start_label_)) { - start_label_ = node; - MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id(); - continue; - } - - re_order.push_back(node); - } - if (end_goto_ != nullptr) { - re_order.push_back(end_goto_); - } - execution_order_ = re_order; -} - -void KernelGraph::CheckLoop() { - std::map none_zero_nodes; - if (node_input_edges_.size() != node_input_num_.size()) { - MS_LOG(EXCEPTION) << "node_input_edges_ size :" << node_input_edges_.size() - << "not equal to node_input_num_ size:" << node_input_num_.size(); - } - for (auto &it : node_input_num_) { - MS_EXCEPTION_IF_NULL(it.first); - string str; - auto node_input_it = node_input_edges_.find(it.first); - if (node_input_it == node_input_edges_.end()) { - MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]"; - } - for (const auto &input_edge : node_input_edges_[it.first]) { - MS_EXCEPTION_IF_NULL(input_edge.first); - str = str.append(input_edge.first->DebugString()).append("|"); - } - if (it.second != 0) { - MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ",inputs:" << str << ",input num:" << it.second; - none_zero_nodes[it.first] = it.second; - } - } - // if don't consider control depend and loop exit,a exception will be throw - if (!none_zero_nodes.empty()) { - MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); - } -} - -CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { - auto cnode = FuncGraph::NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode); - cnode->set_abstract(std::make_shared()); - CreateKernelInfoFromNewParameter(cnode); - - auto kernel_info = std::make_shared(); - std::vector feature_map_input_indexs; - // if the node only has the primitive(such as getNext) or the node's input has a feature map input - // then the node's output is a feature map output - for (size_t index = 1; index < inputs.size(); ++index) { - auto node = inputs[index]; - if (AnfAlgo::IsFeatureMapOutput(node)) { - feature_map_input_indexs.push_back(index); - } - } - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) { - AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode); - } - if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { - kernel_info->SetFeatureMapFlag(true); - } - if (AnfAlgo::IsRealCNodeKernel(cnode)) { - AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); - AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); - } - cnode->set_kernel_info(kernel_info); - AnfAlgo::SetGraphId(graph_id_, cnode.get()); - return cnode; -} - -void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) { 
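SetExecOrderByDefault above is essentially Kahn's algorithm over node_input_num_, with one twist: nodes that become ready as descendants of a communication op are drained from a dedicated queue before ordinary ready nodes, so an AllReduce and its users are emitted together. A compilable toy model of the two-queue scheme follows; GraphNode and Schedule are illustrative, and the real code's lazier handling of the last communication node's descendants is simplified here.

#include <cstddef>
#include <iostream>
#include <queue>
#include <vector>

struct GraphNode {
  bool is_comm;            // models AnfAlgo::IsCommunicationOp
  std::vector<int> users;  // outgoing edges: node -> user
};

std::vector<int> Schedule(const std::vector<GraphNode> &nodes, std::vector<int> input_num) {
  std::queue<int> comm_ready, plain_ready;
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    if (input_num[i] == 0) {
      (nodes[i].is_comm ? comm_ready : plain_ready).push(static_cast<int>(i));
    }
  }
  std::vector<int> order;
  while (!comm_ready.empty() || !plain_ready.empty()) {
    // communication work and its descendants always go first
    std::queue<int> &ready = comm_ready.empty() ? plain_ready : comm_ready;
    int n = ready.front();
    ready.pop();
    order.push_back(n);
    for (int user : nodes[n].users) {
      if (--input_num[user] == 0) {
        // descendants of a communication op jump the queue, mirroring
        // the communication_descendants queue in the original
        ((nodes[n].is_comm || nodes[user].is_comm) ? comm_ready : plain_ready).push(user);
      }
    }
  }
  return order;
}

int main() {
  // 0: param -> {1: allreduce, 2: plain op}; 1 -> {3: user of the allreduce}
  std::vector<GraphNode> g = {{false, {1, 2}}, {true, {3}}, {false, {}}, {false, {}}};
  for (int n : Schedule(g, {0, 1, 1, 1})) {
    std::cout << n << " ";  // prints: 0 1 3 2
  }
  std::cout << "\n";
  return 0;
}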
- if (!AnfAlgo::IsGraphKernel(cnode)) { - return; - } - auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(func_graph); - - std::vector node_list; - std::vector input_list; - std::vector output_list; - kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list); - for (auto &anf_node : node_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_info = std::make_shared(); - anf_node->set_kernel_info(kernel_info); - auto anf_cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(anf_cnode); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) { - auto input_node = anf_cnode->input(i + 1); - MS_EXCEPTION_IF_NULL(input_node); - if (IsValueNode(input_node)) { - auto new_input_node = MakeValueNode(input_node); - if (new_input_node != nullptr) { - anf_cnode->set_input(i + 1, new_input_node); - } - } - } - } - for (auto &anf_node : input_list) { - MS_EXCEPTION_IF_NULL(anf_node); - auto kernel_info = std::make_shared(); - anf_node->set_kernel_info(kernel_info); - } -} - -CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - auto new_cnode = std::make_shared(*cnode); - // if a cnode is created not from front,this cnode won't be in map,so when replace it,we shouldn't update map - if (BackendNodeExistInFrontBackendMap(cnode)) { - FrontBackendlMapUpdate(cnode, new_cnode); - } - AnfAlgo::SetGraphId(graph_id_, cnode.get()); - if (IsInternalOutput(cnode)) { - ReplaceInternalOutput(cnode, new_cnode); - } - return new_cnode; -} - -ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { - ParameterPtr new_parameter = add_parameter(); - MS_EXCEPTION_IF_NULL(new_parameter); - // create kernel_info form new parameter - auto kernel_info = std::make_shared(); - size_t output_tensor_num = 1; - // if use default parameter = nullptr,it remarks create a new parameter from no parameter - if (parameter == nullptr) { - new_parameter->set_abstract(std::make_shared()); - kernel_info->SetFeatureMapFlag(true); - } else { - // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter - new_parameter->set_abstract(parameter->abstract()); - new_parameter->set_name(parameter->name()); - if (AnfAlgo::IsParameterWeight(parameter)) { - new_parameter->set_default_param(parameter->default_param()); - kernel_info->SetFeatureMapFlag(false); - } else { - kernel_info->SetFeatureMapFlag(true); - } - } - new_parameter->set_kernel_info(kernel_info); - // create kernel_build_info for new parameter - auto kernel_build_info_builder = std::make_shared(); - // create init data type, - std::vector init_data_type = {}; - - TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0); - init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? 
kTypeUnknown : infer_data_type);
-
-  // set the format of parameter to DEFAULT_FORMAT
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
-  // set parameter initial device data type
-  kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
-  AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
-  return new_parameter;
-}
-
-std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
-  MS_EXCEPTION_IF_NULL(value_node);
-  auto node_value = value_node->value();
-  auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
-  std::vector<AnfNodePtr> convert_inputs;
-  if (!node_value->isa<ValueTuple>()) {
-    MS_LOG(EXCEPTION) << "Multiple-output value node's value must be a value tuple, but got " << node_value->ToString();
-  }
-  auto value_tuple = node_value->cast<ValueTuplePtr>();
-  MS_EXCEPTION_IF_NULL(value_tuple);
-  if (value_tuple->size() != output_size) {
-    MS_LOG(EXCEPTION) << "Value tuple size " << value_tuple->size()
-                      << " does not match the value node's output size " << output_size;
-  }
-  for (size_t index = 0; index < value_tuple->value().size(); ++index) {
-    auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
-                                        {AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
-    AddValueNodeToGraph(new_value_node);
-    auto kernel_info = std::make_shared<device::KernelInfo>();
-    new_value_node->set_kernel_info(kernel_info);
-    kernel_info->SetFeatureMapFlag(false);
-    // create kernel_build_info for new value node
-    auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
-    // set the format of value_node to DEFAULT_FORMAT
-    kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
-    // set value node initial device data type = infer data type
-    kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
-    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
-    AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
-    convert_inputs.emplace_back(new_value_node);
-  }
-  if (!RemoveValueNodeFromGraph(value_node)) {
-    MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
-  }
-  return convert_inputs;
-}
-
-ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
-  MS_EXCEPTION_IF_NULL(value_node);
-  auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
-  AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
-  return new_value_node;
-}
-
-const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
-  MS_EXCEPTION_IF_NULL(inputs_);
-  return *inputs_;
-}
-
-void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
-  MS_EXCEPTION_IF_NULL(front_anf);
-  MS_EXCEPTION_IF_NULL(backend_anf);
-  if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " already exists in the front_backend_anf_map_";
-  }
-  if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << " already exists in the backend_front_anf_map_";
-  }
-  front_backend_anf_map_[front_anf] = backend_anf;
-  backend_front_anf_map_[backend_anf] = front_anf;
-}
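// ---------------------------------------------------------------------------
// A standalone sketch of the splitting rule in SplitTupleValueNodeToNodeList()
// above, with invented simplified types (plain doubles instead of ValuePtr):
// a value node holding an N-element tuple is replaced by N single-value nodes
// so that every graph output maps to exactly one node.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

struct SimpleValueNode { double value; };

std::vector<SimpleValueNode> SplitTuple(const std::vector<double> &tuple, std::size_t output_size) {
  if (tuple.size() != output_size) {
    throw std::runtime_error("value tuple size does not match the value node's output size");
  }
  std::vector<SimpleValueNode> nodes;
  nodes.reserve(tuple.size());
  for (double v : tuple) nodes.push_back(SimpleValueNode{v});  // one node per element
  return nodes;
}

int main() {
  for (const auto &n : SplitTuple({1.0, 2.0, 3.0}, 3)) std::cout << n.value << " ";
  std::cout << "\n";  // prints: 1 2 3
  return 0;
}
// ---------------------------------------------------------------------------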
-
-void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
-  MS_EXCEPTION_IF_NULL(old_backend_anf);
-  MS_EXCEPTION_IF_NULL(new_backend_anf);
-  if (old_backend_anf == new_backend_anf) {
-    MS_LOG(DEBUG) << "Old node is the same as the new node: " << old_backend_anf->DebugString();
-    return;
-  }
-  if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
-    MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " does not exist in the map";
-    return;
-  }
-  if (front_backend_anf_map_.find(backend_front_anf_map_[old_backend_anf]) == front_backend_anf_map_.end()) {
-    MS_LOG(EXCEPTION) << "Anf does not exist in the map, old: " << old_backend_anf->DebugString();
-  }
-  front_backend_anf_map_[backend_front_anf_map_[old_backend_anf]] = new_backend_anf;
-  backend_front_anf_map_[new_backend_anf] = backend_front_anf_map_[old_backend_anf];
-  // delete old kernel
-  (void)backend_front_anf_map_.erase(old_backend_anf);
-}
-// get kernel by anf
-AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
-  if (front_backend_anf_map_.find(front_anf) == front_backend_anf_map_.end()) {
-    return nullptr;
-  }
-  return front_backend_anf_map_[front_anf];
-}
-
-bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
-  return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
-}
-
-ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
-  if (tensor_to_value_node_map_.find(tensor) == tensor_to_value_node_map_.end()) {
-    return nullptr;
-  }
-  return tensor_to_value_node_map_[tensor];
-}
-
-void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
-  MS_EXCEPTION_IF_NULL(tensor);
-  MS_EXCEPTION_IF_NULL(value_node);
-  tensor_to_value_node_map_[tensor] = value_node;
-}
-
-void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(input);
-  MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ", num:" << depend_edge_num;
-  auto output_depend_edge = std::pair<AnfNodePtr, size_t>(node, depend_edge_num);
-  // add output depend edge of input
-  auto output_it = node_output_edges_.find(input);
-  if (output_it == node_output_edges_.end()) {
-    node_output_edges_[input] = std::vector<std::pair<AnfNodePtr, size_t>>{output_depend_edge};
-  } else {
-    output_it->second.push_back(output_depend_edge);
-  }
-  // add input depend edge of output
-  auto input_depend_edge = std::pair<AnfNodePtr, size_t>(input, depend_edge_num);
-  auto input_it = node_input_edges_.find(node);
-  if (input_it == node_input_edges_.end()) {
-    node_input_edges_[node] = std::vector<std::pair<AnfNodePtr, size_t>>{input_depend_edge};
-  } else {
-    input_it->second.push_back(input_depend_edge);
-  }
-  // add node input depend num
-  auto depend_it = node_input_num_.find(node);
-  if (depend_it == node_input_num_.end()) {
-    node_input_num_[node] = depend_edge_num;
-  } else {
-    depend_it->second += depend_edge_num;
-  }
-}
-
-std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  auto it = node_output_edges_.find(node);
-  if (it == node_output_edges_.end()) {
-    MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
-  }
-  std::vector<AnfNodePtr> output_nodes;
-  auto trans = [](const std::pair<AnfNodePtr, size_t> &pair) -> AnfNodePtr { return pair.first; };
-  (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes), trans);
-  return output_nodes;
-}
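// ---------------------------------------------------------------------------
// A standalone sketch of the invariant maintained by FrontBackendlMapUpdate()
// above, using plain strings instead of AnfNodePtr (names invented here):
// front_to_backend and backend_to_front are two halves of one bidirectional
// map, so replacing a backend node rewrites both sides and erases the stale
// backend key.
// ---------------------------------------------------------------------------
#include <cassert>
#include <string>
#include <unordered_map>

struct BidiMap {
  std::unordered_map<std::string, std::string> front_to_backend;
  std::unordered_map<std::string, std::string> backend_to_front;

  void Add(const std::string &front, const std::string &backend) {
    front_to_backend[front] = backend;
    backend_to_front[backend] = front;
  }

  void UpdateBackend(const std::string &old_backend, const std::string &new_backend) {
    auto it = backend_to_front.find(old_backend);
    if (it == backend_to_front.end()) return;  // unknown backend node: nothing to do
    std::string front = it->second;
    backend_to_front.erase(it);            // drop the stale key first (insert may rehash)
    front_to_backend[front] = new_backend;
    backend_to_front[new_backend] = front;
  }
};

int main() {
  BidiMap map;
  map.Add("front_add", "backend_add_v1");
  map.UpdateBackend("backend_add_v1", "backend_add_v2");
  assert(map.front_to_backend["front_add"] == "backend_add_v2");
  assert(map.backend_to_front.count("backend_add_v1") == 0);
  return 0;
}
// ---------------------------------------------------------------------------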
-
-// Find control_depend real input nodes.
-void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  MS_EXCEPTION_IF_NULL(result);
-  MS_EXCEPTION_IF_NULL(visited);
-  if (visited->find(anf_node) != visited->end()) {
-    MS_LOG(WARNING) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
-    return;
-  }
-  visited->insert(anf_node);
-  if (AnfAlgo::IsRealKernel(anf_node)) {
-    result->emplace_back(anf_node);
-    return;
-  }
-  if (!anf_node->isa<CNode>()) {
-    return;
-  }
-  auto cnode = anf_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->inputs().empty()) {
-    MS_LOG(EXCEPTION) << "Illegal null input of cnode: " << anf_node->DebugString();
-  }
-  auto input0 = cnode->input(0);
-  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
-    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
-      GetAllFatherRealNode(cnode->input(i), result, visited);
-    }
-  } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
-    if (cnode->inputs().size() != kTupleGetItemInputSize) {
-      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
-    }
-    GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
-  } else if (IsPrimitive(input0, prim::kPrimDepend)) {
-    if (cnode->inputs().size() != kDependInputSize) {
-      MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
-    }
-    GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
-    GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
-  }
-}
-
-// update the depend relations of control depend
-void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
-  for (const auto &node : depends) {
-    MS_EXCEPTION_IF_NULL(node);
-    if (!node->isa<CNode>()) {
-      return;
-    }
-    auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
-      MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
-    }
-    auto prior_node = cnode->input(kControlDependPriorIndex);
-    auto depend_node = cnode->input(kControlDependBehindIndex);
-    MS_EXCEPTION_IF_NULL(prior_node);
-    MS_EXCEPTION_IF_NULL(depend_node);
-    std::vector<AnfNodePtr> prior_nodes = {prior_node};
-    std::vector<AnfNodePtr> depend_nodes = {depend_node};
-    int depend_mode = 0;
-    if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
-      depend_mode = AnfAlgo::GetNodeAttr<int>(cnode, kControlDependMode);
-    }
-    MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
-                 << "], depend_mode: " << depend_mode << ".";
-    if (prior_node->isa<Parameter>() && depend_mode == 1) {
-      prior_nodes = GetOutputNodes(prior_node);
-    }
-    if (depend_node->isa<Parameter>()) {
-      depend_nodes = depend_mode == 1 ?
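// ---------------------------------------------------------------------------
// A standalone sketch of the wrapper-peeling done by GetAllFatherRealNode()
// above, with an invented simplified node type: MakeTuple/TupleGetItem/Depend
// wrappers are expanded recursively until real kernels are reached, and a
// visited set guards against revisiting shared nodes. (The real code follows
// only the relevant input index for TupleGetItem/Depend; this sketch expands
// all operands.)
// ---------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

struct SketchNode {
  std::string op;  // "real", "make_tuple", "tuple_get_item" or "depend"
  std::string name;
  std::vector<std::shared_ptr<SketchNode>> inputs;
};
using SketchNodePtr = std::shared_ptr<SketchNode>;

void CollectRealNodes(const SketchNodePtr &node, std::vector<SketchNodePtr> *out,
                      std::set<SketchNode *> *visited) {
  if (!node || !visited->insert(node.get()).second) return;  // already visited
  if (node->op == "real") {
    out->push_back(node);
    return;
  }
  for (const auto &input : node->inputs) CollectRealNodes(input, out, visited);
}

int main() {
  auto a = std::make_shared<SketchNode>(SketchNode{"real", "a", {}});
  auto b = std::make_shared<SketchNode>(SketchNode{"real", "b", {}});
  auto tup = std::make_shared<SketchNode>(SketchNode{"make_tuple", "t", {a, b}});
  std::vector<SketchNodePtr> real;
  std::set<SketchNode *> visited;
  CollectRealNodes(tup, &real, &visited);
  for (const auto &n : real) std::cout << n->name << " ";  // prints: a b
  std::cout << "\n";
  return 0;
}
// ---------------------------------------------------------------------------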
GetOutputNodes(depend_node) : std::vector{}; - } - - std::vector real_prior_nodes; - std::set prior_visited; - for (const auto &tmp : prior_nodes) { - GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); - } - - std::vector real_depend_nodes; - std::set depend_visited; - for (const auto &tmp : depend_nodes) { - GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited); - } - - for (auto &first_node : real_prior_nodes) { - if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) { - continue; - } - for (auto &second_node : real_depend_nodes) { - if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) { - continue; - } - MS_EXCEPTION_IF_NULL(first_node); - MS_EXCEPTION_IF_NULL(second_node); - MS_LOG(INFO) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); - AddDependEdge(second_node, first_node, 1); - } - } - } -} - -bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(que); - MS_EXCEPTION_IF_NULL(visited_nodes); - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) { - return false; - } - // set the control depend visited but don't push it into the que - if (visited_nodes->find(node) != visited_nodes->end()) { - return true; - } - (void)visited_nodes->insert(cnode); - // add a 0 depend num to keep the link relations to prepare for finding zero output nodes - auto prior_node = cnode->input(kControlDependPriorIndex); - auto depend_node = cnode->input(kControlDependBehindIndex); - for (const auto &input : cnode->inputs()) { - AddDependEdge(node, input, 0); - } - PushNoVisitedNode(depend_node, que, visited_nodes); - PushNoVisitedNode(prior_node, que, visited_nodes); - return true; -} - -void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { - MS_EXCEPTION_IF_NULL(seed_nodes); - node_output_edges_.clear(); - node_input_num_.clear(); - node_input_edges_.clear(); - std::vector control_depends; - std::unordered_set visited_nodes; - std::queue que; - que.push(get_return()); - while (!que.empty()) { - auto node = que.front(); - que.pop(); - MS_EXCEPTION_IF_NULL(node); - if (node->isa() || node->isa()) { - seed_nodes->push(node); - continue; - } - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // handle data links - for (const auto &input : cnode->inputs()) { - size_t depend_edge_num = 1; - // handle control depend,all inputs of control depend has no depend edge - if (HandleControlDependNode(input, &que, &visited_nodes)) { - control_depends.push_back(input); - depend_edge_num = 0; - } - PushNoVisitedNode(input, &que, &visited_nodes); - AddDependEdge(node, input, depend_edge_num); - } - } - UpdateControlDependRelations(control_depends); -} - -void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); } - -bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; } - -AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const { - if (!IsInRefOutputMap(out_pair)) { - MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap"; - } - return ref_out_in_map_.at(out_pair); -} - -void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) 
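// ---------------------------------------------------------------------------
// A standalone sketch of the edge bookkeeping behind AddDependEdge() and
// UpdateNodeEdgeList() above (invented simplified types): data edges carry
// weight 1 and gate execution, while the bookkeeping edges added for control
// depend inputs carry weight 0, keeping the link visible to later passes
// without blocking the topological sort.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct EdgeBook {
  std::map<std::string, std::vector<std::pair<std::string, std::size_t>>> output_edges;
  std::map<std::string, std::vector<std::pair<std::string, std::size_t>>> input_edges;
  std::map<std::string, std::size_t> input_num;

  void AddDependEdge(const std::string &node, const std::string &input, std::size_t weight) {
    output_edges[input].emplace_back(node, weight);  // input -> user
    input_edges[node].emplace_back(input, weight);   // user -> input
    input_num[node] += weight;  // weight 0 records the link but does not gate
  }
};

int main() {
  EdgeBook book;
  book.AddDependEdge("matmul", "param_x", 1);         // data edge: must wait
  book.AddDependEdge("control_depend", "matmul", 0);  // link only, never gates
  std::cout << "matmul waits on " << book.input_num["matmul"] << " input(s)\n";         // 1
  std::cout << "control_depend waits on " << book.input_num["control_depend"] << "\n";  // 0
  return 0;
}
// ---------------------------------------------------------------------------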
{
-  if (IsInRefOutputMap(final_pair)) {
-    MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap";
-  }
-  (void)ref_out_in_map_.insert(std::make_pair(final_pair, origin_pair));
-}
-
-bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
-  if (graph_value_nodes_.find(value_node) != graph_value_nodes_.end()) {
-    (void)graph_value_nodes_.erase(value_node);
-    return true;
-  }
-  return false;
-}
-
-void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
-  MS_EXCEPTION_IF_NULL(inputs_);
-  {
-    std::queue<AnfNodePtr> seed_nodes;
-    UpdateNodeEdgeList(&seed_nodes);
-  }
-  auto it = node_output_edges_.find(old_anf_node);
-  if (it != node_output_edges_.end()) {
-    const auto &outputs = it->second;
-    for (auto &output_node : outputs) {
-      MS_EXCEPTION_IF_NULL(output_node.first);
-      auto output_cnode = output_node.first->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(output_cnode);
-      auto &output_node_inputs = output_cnode->inputs();
-      // don't replace node if it is a control edge => output_node.second == 0
-      if (output_node.second == 0) {
-        continue;
-      }
-      for (size_t i = 1; i < output_node_inputs.size(); i++) {
-        if (output_node_inputs[i] == old_anf_node.get()) {
-          output_cnode->set_input(i, new_anf_node);
-        }
-      }
-      // update graph inputs
-      for (size_t i = 0; i < inputs_->size(); i++) {
-        if ((*inputs_)[i] == old_anf_node.get()) {
-          MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
-                       << ", new graph input: " << new_anf_node->DebugString();
-          (*inputs_)[i] = new_anf_node.get();
-          break;
-        }
-      }
-    }
-    // update front to backend map
-    FrontBackendlMapUpdate(old_anf_node, new_anf_node);
-  }
-  {
-    std::queue<AnfNodePtr> seed_nodes;
-    UpdateNodeEdgeList(&seed_nodes);
-  }
-  // update graph inputs in child graph
-  auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
-                                     [&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
-                                       return n.first == old_anf_node.get();
-                                     });
-  if (it_real_inputs != real_inputs_.end()) {
-    // erase old parameter in map
-    auto old_args = it_real_inputs->second;
-    real_inputs_.erase(it_real_inputs);
-    // insert new parameter to map
-    auto iter = std::find_if(real_inputs_.begin(), real_inputs_.end(),
-                             [&new_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
-                               return n.first == new_anf_node.get();
-                             });
-    if (iter != real_inputs_.end()) {
-      MS_LOG(WARNING) << new_anf_node->DebugString() << " already exists in real inputs; it will be rewritten.";
-      iter->second = old_args;
-    } else {
-      real_inputs_.emplace_back(new_anf_node, old_args);
-    }
-  }
-}
-
-void KernelGraph::UpdateExecuteKernelStreamLabel() {
-  for (auto &kernel : execution_order_) {
-    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
-  }
-}
-
-std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
-  std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
-  if (IsLeafGraph()) {
-    leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
-  } else {
-    for (const auto &child_graph : child_graph_order_) {
-      MS_EXCEPTION_IF_NULL(child_graph);
-      auto child_leaf_graph_order = child_graph->GetLeafGraphOrder();
-      std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
-    }
-  }
-  return leaf_graph_order;
-}
-
-bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
-
-std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
-  std::vector<CNodePtr> result;
-  for (const auto &anf : execution_order_) {
-    if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_)
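// ---------------------------------------------------------------------------
// A standalone sketch of the rewiring step inside ReplaceNode() above
// (invented simplified node type): every user that consumed old_node through
// a data edge gets the matching input slot pointed at new_node, while
// weight-0 control edges are deliberately skipped.
// ---------------------------------------------------------------------------
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct UserNode {
  std::string name;
  std::vector<std::shared_ptr<UserNode>> inputs;
};
using UserNodePtr = std::shared_ptr<UserNode>;

void ReplaceInput(const std::vector<UserNodePtr> &users, const UserNodePtr &old_node,
                  const UserNodePtr &new_node) {
  for (const auto &user : users) {
    for (auto &input : user->inputs) {
      if (input == old_node) input = new_node;  // rewire the data edge in place
    }
  }
}

int main() {
  auto old_p = std::make_shared<UserNode>(UserNode{"param_v1", {}});
  auto new_p = std::make_shared<UserNode>(UserNode{"param_v2", {}});
  auto add = std::make_shared<UserNode>(UserNode{"add", {old_p, old_p}});
  ReplaceInput({add}, old_p, new_p);
  std::cout << add->inputs[0]->name << "\n";  // prints: param_v2
  return 0;
}
// ---------------------------------------------------------------------------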
{ - result.push_back(anf->cast()); - } - } - return result; -} - -void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - MS_LOG(INFO) << "Parameter: " << parameter->DebugString() << ", real input : " << arg->DebugString(); - MS_EXCEPTION_IF_NULL(parameter); - MS_EXCEPTION_IF_NULL(arg); - auto iter = std::find_if( - real_inputs_.begin(), real_inputs_.end(), - [¶meter](const std::pair> &n) -> bool { return n.first == parameter; }); - if (iter != real_inputs_.end()) { - auto &args = iter->second; - args.push_back(arg); - } else { - real_inputs_.emplace_back(parameter, std::vector(1, arg)); - } -} - -void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph) { - unreuse_args_[arg] = from_graph; -} - -void KernelGraph::UpdateCallRealInput() { - MS_LOG(INFO) << "Update graph id: " << graph_id_; - std::vector>> real_inputs_map; - for (auto &it : real_inputs_) { - auto parameter = it.first; - MS_EXCEPTION_IF_NULL(parameter); - auto real_inputs = it.second; - std::vector new_real_inputs; - for (auto &real_input : real_inputs) { - // if real input is a call node ,find the child graph output act as the new real input - auto tmp_real_input = GetCallRealOutputs(real_input); - std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs)); - // replace the call in unreuse_args_ - auto unreuse_arg_it = unreuse_args_.find(real_input); - if (unreuse_arg_it != unreuse_args_.end()) { - auto old_graph = unreuse_arg_it->second; - for (auto new_real_input : new_real_inputs) { - // if call reference graph output is parameter, it will be allowed to reuse - if (!new_real_input->isa()) { - unreuse_args_[new_real_input] = old_graph; - } - } - } - } - real_inputs_map.emplace_back(parameter, new_real_inputs); - } - real_inputs_ = real_inputs_map; -} - -void KernelGraph::PrintGraphExecuteOrder() const { - MS_LOG(INFO) << "Graph:" << graph_id_ << "execution order"; - for (size_t i = 0; i < execution_order_.size(); i++) { - CNodePtr cur_cnode_ptr = execution_order_[i]; - MS_EXCEPTION_IF_NULL(cur_cnode_ptr); - std::string event_str; - std::string label_str; - if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) { - event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrEventId)) + "]"; - } - - if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) { - label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr(cur_cnode_ptr, kAttrLabelIndex)) + "]"; - } - - if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) { - auto label_list = AnfAlgo::GetNodeAttr>(cur_cnode_ptr, kAttrLabelSwitchList); - label_str = ", label_id["; - for (size_t j = 0; j < label_list.size(); ++j) { - label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? 
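// ---------------------------------------------------------------------------
// A standalone sketch of the argument flattening in UpdateCallRealInput()
// above (invented names; RealOutputsOf stands in for GetCallRealOutputs):
// when a recorded real input is itself a call node, it is replaced by the
// outputs of the called child graph, so parameters end up bound only to
// concrete producer nodes.
// ---------------------------------------------------------------------------
#include <iostream>
#include <map>
#include <string>
#include <vector>

std::vector<std::string> RealOutputsOf(const std::string &arg,
                                       const std::map<std::string, std::vector<std::string>> &call_outputs) {
  auto it = call_outputs.find(arg);
  if (it != call_outputs.end()) return it->second;  // expand the call
  return {arg};                                     // already a real node
}

int main() {
  const std::map<std::string, std::vector<std::string>> call_outputs{{"call_g1", {"g1_out0", "g1_out1"}}};
  std::vector<std::string> real_inputs;
  for (const auto &arg : {std::string("x"), std::string("call_g1")}) {
    auto outs = RealOutputsOf(arg, call_outputs);
    real_inputs.insert(real_inputs.end(), outs.begin(), outs.end());
  }
  for (const auto &r : real_inputs) std::cout << r << " ";  // prints: x g1_out0 g1_out1
  std::cout << "\n";
  return 0;
}
// ---------------------------------------------------------------------------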
", " : "]"); - } - } - - MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id[" - << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" - << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]" - << event_str << label_str; - } -} - -void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { - if (front_node == nullptr || node == nullptr) { - MS_LOG(INFO) << "Front node or node is nullptr"; - return; - } - MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); - front_to_internal_outputs_map_[front_node] = node; - internal_outputs_to_front_map_[node] = front_node; -} - -void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) { - if (new_node == nullptr || node == nullptr) { - MS_LOG(INFO) << "New node or node is nullptr"; - return; - } - if (node == new_node) { - MS_LOG(INFO) << "New node and node is the same"; - return; - } - auto iter = internal_outputs_to_front_map_.find(node); - if (iter == internal_outputs_to_front_map_.end()) { - MS_LOG(INFO) << "Node is not internal output"; - return; - } - MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString(); - internal_outputs_to_front_map_[new_node] = iter->second; - front_to_internal_outputs_map_[iter->second] = new_node; - internal_outputs_to_front_map_.erase(iter); -} - -AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { - auto iter = front_to_internal_outputs_map_.find(front_node); - if (iter != front_to_internal_outputs_map_.end()) { - return iter->second; - } - return nullptr; -} - -bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const { - if (internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end()) { - return true; - } - return false; -} - -AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const { - auto iter = internal_outputs_to_front_map_.find(node); - if (iter != internal_outputs_to_front_map_.end()) { - return iter->second; - } - return nullptr; -} - -void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) { - if (node == nullptr) { - return; - } - (void)final_output_kernels_.insert(node); -} - -bool KernelGraph::IsFinalOutputKernel(const AnfNodePtr &node) const { - if (node == nullptr) { - return false; - } - if (final_output_kernels_.find(node) != final_output_kernels_.end()) { - return true; - } - return false; -} - -std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } - -KernelGraph::~KernelGraph() { device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_); } -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h deleted file mode 100644 index 2e46cfa76a..0000000000 --- a/mindspore/ccsrc/session/kernel_graph.h +++ /dev/null @@ -1,226 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H -#define MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "ir/func_graph.h" -#include "ir/anf.h" -#include "utils/graph_utils.h" -#include "utils/contract.h" -#include "device/kernel_info.h" - -namespace mindspore { -namespace session { -using AnfWithOutIndex = std::pair; -class KernelGraph : public FuncGraph { - public: - KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) { - inputs_ = std::make_shared>(); - execution_order_ = {}; - executable_ = true; - summary_node_exist_ = false; - stream_distinction_label_ = kInvalidDistincLabel; - } - ~KernelGraph() override; - - MS_DECLARE_PARENT(KernelGraph, FuncGraph); - - const std::vector &inputs() const; - std::vector *MutableInputs() const { return inputs_.get(); } - std::vector outputs() const; - CNodePtr NewCNode(const std::vector &inputs) override; - void CreateKernelInfoFromNewParameter(const CNodePtr &cnode); - CNodePtr NewCNode(const CNodePtr &cnode); - ParameterPtr NewParameter(const ParameterPtr ¶meter = nullptr); - ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr); - std::vector SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node); - void set_execution_order(const std::vector &order) { execution_order_ = order; } - const std::vector &execution_order() const { return execution_order_; } - void SetExecOrderByDefault(); - uint32_t graph_id() const { return graph_id_; } - void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } - - // and a new front to backend anf relation to maop - void FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf); - // replace old backend anf with new backend anf - void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); - // get backend anf by front anf - AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); - // check backend node whether exist in map - bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); - // get value node by tensor - ValueNodePtr GetValueNodeByTensor(const tensor::TensorPtr &tensor); - // add value node tensor relation map - void TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node); - // get all value nodes of graph - const std::unordered_set graph_value_nodes() const { return graph_value_nodes_; } - // add value node to graph - void AddValueNodeToGraph(const ValueNodePtr &value_node); - // ref output is in map - bool IsInRefOutputMap(const AnfWithOutIndex &pair) const; - // get ref correspond pairs - AnfWithOutIndex GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const; - // add ref correspond pairs - void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair); - // get map - std::map GetRefMap() const { return ref_out_in_map_; } - // checkout whether loop exist in graph - void CheckLoop(); - // check whether graph is executable - bool executable() const { return 
executable_; } - // set executable of graph - void set_executable(bool executable) { executable_ = executable; } - // set summary_node of graph - void set_summary_node_exist(bool summary_node_exist) { summary_node_exist_ = summary_node_exist; } - // check whether exist summary node in graph - bool summary_node_exist() const { return summary_node_exist_; } - // set invalid inputs for control sink - std::vector *MutableValidInputs() { return &valid_inputs_; } - std::vector valid_inputs() const { return valid_inputs_; } - // replace node in graph - void ReplaceNode(NotNull old_anf_node, NotNull new_anf_node); - // set stream label of graph - void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } - // get stream label of graph - uint32_t stream_distinction_label() { return stream_distinction_label_; } - // refresh execute kernel stream label - void UpdateExecuteKernelStreamLabel(); - // calculate the leaf graph order of root graph - std::vector> GetLeafGraphOrder(); - // the child graph of current graph - const std::vector> &child_graph_order() const { return child_graph_order_; } - void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } - // checkout whether current graph is leaf graph - bool IsLeafGraph() const; - - // set input_tensors pointer of control parameter - void set_input_ctrl_tensors(const std::shared_ptr> &input_tensors_ptr) { - input_ctrl_tensors_ = input_tensors_ptr; - } - // get input_tensors pointer of control parameter - std::shared_ptr> input_ctrl_tensors() const { return input_ctrl_tensors_; } - // get parent kernel graph - std::shared_ptr parent_graph() const { return parent_graph_; } - // set parent kernel graph - void set_parent_graph(const std::shared_ptr &parent_graph) { parent_graph_ = parent_graph; } - // find anf node in graph - std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; - // get real inputs - const std::vector>> &real_inputs() const { return real_inputs_; } - void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); - // mark unreused args - void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr &from_graph); - const std::map> &unreuse_args() const { return unreuse_args_; } - // used to dump ir - std::string ToString() const override; - // update the real input if the node is a call - void UpdateCallRealInput(); - - void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } - CNodePtr get_start_label() { return start_label_; } - void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } - CNodePtr get_end_goto() { return end_goto_; } - bool get_output_null() { return null_output_; } - void set_output_null(bool is_output_null) { null_output_ = is_output_null; } - void PrintGraphExecuteOrder() const; - const std::map> &summary_nodes() const { return summary_nodes_; } - void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } - void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); - void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node); - AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; - bool IsInternalOutput(const AnfNodePtr &node) const; - AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const; - void AddFinalOutputKernel(const AnfNodePtr &node); - bool IsFinalOutputKernel(const AnfNodePtr &node) const; - uint32_t current_epoch() const { return current_epoch_; } - void set_current_epoch(uint32_t 
epoch) { current_epoch_ = epoch; } - - private: - // remove value node form graph - bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); - void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes); - // update node edge list - void UpdateNodeEdgeList(std::queue *seed_nodes); - // add node depend edge by data edge or control depend - void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); - // handle control depend - std::vector GetOutputNodes(const AnfNodePtr &node); - bool HandleControlDependNode(const AnfNodePtr &node, std::queue *que, - std::unordered_set *visited_nodes); - void UpdateControlDependRelations(const std::vector &depends); - - std::shared_ptr> inputs_; - std::vector execution_order_; - uint32_t graph_id_; - uint32_t stream_distinction_label_; - - // record map bettween front anf and backend anf,use two map implement bidirectional map - std::unordered_map front_backend_anf_map_; - std::unordered_map backend_front_anf_map_; - // there may be a tensor from ME backend ,a value ndoe will be create according the tensor,map record - std::unordered_map tensor_to_value_node_map_; - // include all value nodes - std::unordered_set graph_value_nodes_; - std::unordered_map node_input_num_; - std::unordered_map>> node_input_edges_; - // record map between ref final output anf with index and ref origin input with index - std::map ref_out_in_map_; - std::unordered_map>> node_output_edges_; - std::map> summary_nodes_; - // graph needn't execute - bool executable_; - // exist summary node in graph - bool summary_node_exist_; - // valid inputs - std::vector valid_inputs_; - - // new members for control sink process - // all child grahs refers to partial node - std::map> node_to_child_graphs_; - // child graph execute order in root graph - std::vector> child_graph_order_; - - // input_tensors of control parameter - std::shared_ptr> input_ctrl_tensors_; - - // parameter graph - std::shared_ptr parent_graph_; - // record real parameters,inputs_ is the formal parameters - std::vector>> real_inputs_; - std::map> unreuse_args_; - - CNodePtr start_label_; - CNodePtr end_goto_; - bool null_output_; - std::unordered_map front_to_internal_outputs_map_; - std::unordered_map internal_outputs_to_front_map_; - std::set final_output_kernels_; - uint32_t current_epoch_; -}; -} // namespace session -using KernelGraphPtr = std::shared_ptr; -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_KERNEL_GRAPH_H diff --git a/mindspore/ccsrc/session/session.cc b/mindspore/ccsrc/session/session.cc deleted file mode 100644 index ae70fc77aa..0000000000 --- a/mindspore/ccsrc/session/session.cc +++ /dev/null @@ -1,208 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include "include/inference.h" -#include "session/session.h" -#include "utils/load_onnx/anf_converter.h" -#include "session/session_basic.h" -#include "session/session_factory.h" -#include "utils/base_ref_utils.h" -#include "kernel/oplib/oplib.h" -#ifdef ENABLE_D -#include "utils/context/ms_context.h" -#include "session/ascend_session.h" -#else -#include "session/cpu_session.h" -#endif - -namespace py = pybind11; -namespace mindspore::inference { -std::shared_ptr LoadModel(const char *model_buf, size_t size, const std::string &device) { - try { - inference::Session::RegAllOp(); - auto anf_graph = lite::AnfConverter::RunAnfConverter(model_buf, size); - return anf_graph; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference LoadModel failed"; - return nullptr; - } -} - -void ExitInference() { - auto ms_context = MsContext::GetInstance(); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return; - } - if (!ms_context->CloseTsd()) { - MS_LOG(ERROR) << "Inference CloseTsd failed!"; - return; - } -} - -std::shared_ptr MSSession::CreateSession(const std::string &device, uint32_t device_id) { - try { - auto session = std::make_shared(); - auto ret = session->Init(device, device_id); - if (ret != 0) { - return nullptr; - } - return session; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CreatSession failed"; - return nullptr; - } -} - -void Session::RegAllOp() { - static std::mutex init_mutex; - static bool Initialized = false; - - std::lock_guard lock(init_mutex); - if (Initialized) { - return; - } - Initialized = true; - MsContext::GetInstance()->set_execution_mode(kGraphMode); - Py_Initialize(); - auto c_expression = PyImport_ImportModule("mindspore._c_expression"); - if (c_expression == nullptr) { - MS_LOG(EXCEPTION) << "Failed to import mindspore._c_expression module."; - return; - } - PyObject *c_expression_dict = PyModule_GetDict(c_expression); - - PyObject *op_info_loader_class = PyDict_GetItemString(c_expression_dict, "OpInfoLoaderPy"); - if (op_info_loader_class == nullptr) { - MS_LOG(EXCEPTION) << "Failed to get op_info_loader_class from mindspore._c_expression."; - return; - } - PyObject *op_info_loader = PyInstanceMethod_New(op_info_loader_class); - if (op_info_loader == nullptr) { - MS_LOG(EXCEPTION) << "Failed to create op_info_loader instance."; - return; - } - PyObject *op_info_loader_ins = PyObject_CallObject(op_info_loader, nullptr); - if (op_info_loader_ins == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call op_info_loader instance."; - return; - } - auto all_ops_info_vector_addr_ul = PyObject_CallMethod(op_info_loader_ins, "get_all_ops_info", nullptr); - if (all_ops_info_vector_addr_ul == nullptr) { - MS_LOG(EXCEPTION) << "Failed to call get_all_ops_addr."; - return; - } - auto all_ops_info_vector_addr = PyLong_AsVoidPtr(all_ops_info_vector_addr_ul); - auto all_ops_info = static_cast *>(all_ops_info_vector_addr); - for (auto op_info : *all_ops_info) { - kernel::OpLib::RegOpInfo(std::shared_ptr(op_info)); - } - all_ops_info->clear(); - delete all_ops_info; - Py_DECREF(op_info_loader); - Py_DECREF(op_info_loader_class); - Py_DECREF(c_expression_dict); - Py_DECREF(c_expression); - return; -} - -uint32_t Session::CompileGraph(std::shared_ptr funcGraphPtr) { - MS_ASSERT(session_impl_ != nullptr); - try { - auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr)); - py::gil_scoped_release gil_release; - return graph_id; - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference CompileGraph 
failed"; - return static_cast(-1); - } -} - -MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector> &inputs) { - try { - std::vector inTensors; - inTensors.resize(inputs.size()); - bool has_error = false; - std::transform(inputs.begin(), inputs.end(), inTensors.begin(), - [&has_error](const std::shared_ptr &tensor_ptr) -> tensor::TensorPtr { - if (tensor_ptr == nullptr) { - MS_LOG(WARNING) << "input MSTensor is nullptr, return nullptr"; - has_error = true; - return nullptr; - } - auto tensor = static_cast(tensor_ptr.get()); - if (tensor == nullptr) { - MS_LOG(ERROR) << "Can not cast input MSTensor to tensor"; - has_error = true; - return nullptr; - } - return tensor->tensor(); - }); - if (has_error) { - MS_LOG(ERROR) << "Init Tensor failed, returning empty result"; - std::vector> multiTensor; - return multiTensor; - } - VectorRef outputs; - session_impl_->RunGraph(graph_id, inTensors, &outputs); - - return TransformVectorRefToMultiTensor(outputs); - } catch (std::exception &e) { - MS_LOG(ERROR) << "Inference Rungraph failed"; - return MultiTensor(); - } -} -namespace { -string AjustTargetName(const std::string &device) { - if (device == kAscendDevice) { - return std::string(kAscendDevice) + "Inference"; - } else { - MS_LOG(ERROR) << "Only support device Ascend right now"; - return ""; - } -} -} // namespace -int Session::Init(const std::string &device, uint32_t device_id) { - RegAllOp(); - auto ms_context = MsContext::GetInstance(); - ms_context->set_execution_mode(kGraphMode); - ms_context->set_device_id(device_id); - auto ajust_device = AjustTargetName(device); - if (ajust_device == "") { - return -1; - } - ms_context->set_device_target(device); - session_impl_ = session::SessionFactory::Get().Create(ajust_device); - if (session_impl_ == nullptr) { - MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; - return -1; - } - session_impl_->Init(device_id); - if (ms_context == nullptr) { - MS_LOG(ERROR) << "Get Context failed!"; - return -1; - } - if (!ms_context->OpenTsd()) { - MS_LOG(ERROR) << "Session init OpenTsd failed!"; - return -1; - } - return 0; -} - -Session::Session() = default; -} // namespace mindspore::inference diff --git a/mindspore/ccsrc/session/session.h b/mindspore/ccsrc/session/session.h deleted file mode 100644 index b608163067..0000000000 --- a/mindspore/ccsrc/session/session.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H -#define MINDSPORE_CCSRC_SESSION_SESSION_H - -#include -#include -#include -#include -#include -#include - -#include "session/session_basic.h" -#include "ir/anf.h" -#include "include/inference.h" - -namespace mindspore { -namespace inference { -class Session : public MSSession { - public: - Session(); - - uint32_t CompileGraph(std::shared_ptr funcGraphPtr) override; - - MultiTensor RunGraph(uint32_t graph_id, const std::vector> &inputs) override; - - int Init(const std::string &device, uint32_t device_id); - - static void RegAllOp(); - - private: - std::shared_ptr session_impl_ = nullptr; - std::vector graph_id_; -}; -} // namespace inference -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc deleted file mode 100644 index 59cc0dd020..0000000000 --- a/mindspore/ccsrc/session/session_basic.cc +++ /dev/null @@ -1,1128 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "session/session_basic.h" -#include -#include -#include -#include -#include "pipeline/parse/data_converter.h" -#include "ir/manager.h" -#include "ir/param_value.h" -#include "kernel/common_utils.h" -#include "operator/ops.h" -#include "common/trans.h" -#include "utils/context/ms_context.h" -#include "utils/config_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" -#include "pre_activate/common/common_backend_optimization.h" -#include "pre_activate/pass/const_input_to_attr_registry.h" -#include "pre_activate/common/helper.h" -#include "common/utils.h" -#include "ir/dtype.h" -#include "ir/anf.h" -#include "ir/func_graph_cloner.h" - -namespace mindspore { -namespace session { -static std::shared_ptr> python_paras; -void ClearPythonParasMap() { python_paras = nullptr; } -namespace { -const int kSummaryGetItem = 2; - -ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { - if (node == nullptr) { - return nullptr; - } - auto parameter = node->cast(); - if (parameter == nullptr || !parameter->has_default()) { - return nullptr; - } - return parameter->default_param(); -} - -BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, - const std::vector &input_tensors) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; - // if node is a value node, no need sync addr from device to host - if (!AnfAlgo::OutputAddrExist(node, output_index)) { - if (node->isa()) { - auto value_node = node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - return value_node->value(); - } - if (node->isa()) { - for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { - if (input_idx >= input_tensors.size()) { - MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); - } - if 
(graph.inputs()[input_idx] == node) {
-          return input_tensors[input_idx];
-        }
-      }
-      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
-    }
-  }
-  // if the process reaches here, item_with_index is a real node (Parameter, or executable CNode)
-  auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
-  MS_EXCEPTION_IF_NULL(address);
-  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
-  TypeId type_id = kNumberTypeFloat32;
-  type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
-  std::vector<int> temp_shape;
-  if (graph.IsInternalOutput(node)) {
-    temp_shape.emplace_back(1);
-    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-    tensor->set_device_address(address);
-    tensor->set_dirty(false);
-    return tensor;
-  }
-  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
-  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
-  // in pynative mode, data is only copied to host when the user wants to print it
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
-  if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
-    tensor->set_device_address(address);
-    tensor->set_dirty(false);
-  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
-                                        LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
-    MS_LOG(INFO) << "Output sync device to host error!!!";
-    tensor->set_dirty(false);
-  }
-  return tensor;
-}
-
-BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
-                             const std::vector<tensor::TensorPtr> &input_tensors) {
-  MS_EXCEPTION_IF_NULL(anf);
-  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
-  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
-  MS_EXCEPTION_IF_NULL(item_with_index.first);
-  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
-  // special handle for maketuple
-  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
-    auto cnode = item_with_index.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
-    VectorRef ret;
-    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
-      auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
-      ret.push_back(out);
-    }
-    return ret;
-  }
-  // if the graph returns nothing, the function should return an empty list
-  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
-  if (size == 0) {
-    return VectorRef();
-  }
-  return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
-}
-
-BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
-                            const std::vector<tensor::TensorPtr> &input_tensors) {
-  MS_EXCEPTION_IF_NULL(anf);
-  if (!AnfAlgo::IsRealKernel(anf)) {
-    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be an executable kernel";
-  }
-  if (anf->isa<ValueNode>()) {
-    return CreateOneTensor(anf, 0, graph, input_tensors);
-  }
-  VectorRef ret;
-  if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
-    for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
-      auto out = CreateOneTensor(anf, i, graph, input_tensors);
-      ret.emplace_back(out);
-    }
-  }
-  return ret;
-}
-
-ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
-  MS_EXCEPTION_IF_NULL(anf);
-  MS_EXCEPTION_IF_NULL(graph);
-  auto value_node = anf->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(value_node);
-  auto value = value_node->value();
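// ---------------------------------------------------------------------------
// A standalone sketch of the copy-back policy in CreateOneTensor() above
// (invented simplified types): in pynative mode or on GPU the output tensor
// just keeps its device address and data is fetched lazily; otherwise the
// original code calls SyncDeviceToHost eagerly.
// ---------------------------------------------------------------------------
#include <iostream>
#include <string>

enum class ExecMode { kGraph, kPynative };

struct OutTensor {
  bool keeps_device_address = false;
  bool dirty = true;
};

OutTensor MakeOutputTensor(ExecMode mode, const std::string &device_target, bool *synced_eagerly) {
  OutTensor tensor;
  tensor.dirty = false;
  if (mode == ExecMode::kPynative || device_target == "GPU") {
    tensor.keeps_device_address = true;  // lazy: copy only when the data is read
    *synced_eagerly = false;
  } else {
    *synced_eagerly = true;  // graph mode: SyncDeviceToHost right away
  }
  return tensor;
}

int main() {
  bool synced = false;
  MakeOutputTensor(ExecMode::kGraph, "Ascend", &synced);
  std::cout << std::boolalpha << "graph mode on Ascend syncs eagerly: " << synced << "\n";  // true
  MakeOutputTensor(ExecMode::kPynative, "Ascend", &synced);
  std::cout << "pynative mode syncs eagerly: " << synced << "\n";                           // false
  return 0;
}
// ---------------------------------------------------------------------------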
MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - return nullptr; - } - auto new_value_node = graph->NewValueNode(value_node); - graph->FrontBackendlMapAdd(anf, new_value_node); - graph->AddValueNodeToGraph(new_value_node); - return new_value_node; -} - -size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Load kInputCtrlTensors"; - auto inputs_params = graph->input_ctrl_tensors(); - if (inputs_params == nullptr) { - return 0; - } - if (inputs_params->size() < 2) { - MS_LOG(EXCEPTION) << "Illegal inputs_params size"; - } - auto tensor = (*inputs_params)[0]; - MS_EXCEPTION_IF_NULL(tensor); - auto *val = static_cast(tensor->data_c()); - MS_EXCEPTION_IF_NULL(val); - *val = 0; - tensor->set_dirty(true); - // set loop_count to zero - MS_EXCEPTION_IF_NULL(inputs); - inputs->push_back(tensor); - - auto epoch_tensor = (*inputs_params)[1]; - MS_EXCEPTION_IF_NULL(epoch_tensor); - auto *epoch_val = static_cast(epoch_tensor->data_c()); - MS_EXCEPTION_IF_NULL(epoch_val); - *epoch_val = graph->current_epoch(); - epoch_tensor->set_dirty(true); - inputs->push_back(epoch_tensor); - MS_LOG(INFO) << "Load epoch_val:" << *epoch_val; - - graph->set_current_epoch(graph->current_epoch() + 1); - - return inputs_params->size(); -} - -ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(input_tensor); - auto value_node = std::make_shared(input_tensor); - MS_EXCEPTION_IF_NULL(value_node); - // construct abstract of value node - auto type_of_tensor = input_tensor->Dtype(); - auto shape_of_tensor = input_tensor->shape(); - auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); - value_node->set_abstract(abstract); - // add value node to graph - auto input_value_node = graph->NewValueNode(value_node); - graph->AddValueNodeToGraph(input_value_node); - return input_value_node; -} - -ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, const tensor::TensorPtr &input_tensor, - int tensor_mask) { - MS_EXCEPTION_IF_NULL(graph); - auto param = graph->NewParameter(); - MS_EXCEPTION_IF_NULL(param); - if (tensor_mask == kParameterWeightTensorMask) { - auto param_value_new = std::make_shared(); - param->set_default_param(param_value_new); - } - // set the kernel info of parameter - auto kernel_build_info_builder = std::make_shared(); - MS_EXCEPTION_IF_NULL(input_tensor); - if (input_tensor->device_address().get() == nullptr) { - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? 
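// ---------------------------------------------------------------------------
// A standalone sketch of the control-tensor update in LoadCtrlInputTensor()
// above (invented simplified types): slot 0 is the loop counter, reset to 0
// before each run, and slot 1 is the epoch counter, stamped with the graph's
// current epoch, which then advances for the next run.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>

struct CtrlTensors {
  int32_t loop_count = 0;
  int32_t epoch = 0;
};

void LoadCtrlInputs(CtrlTensors *tensors, uint32_t *current_epoch) {
  tensors->loop_count = 0;                                // reset per run
  tensors->epoch = static_cast<int32_t>(*current_epoch);  // stamp current epoch
  ++(*current_epoch);                                     // advance for next run
}

int main() {
  CtrlTensors t;
  uint32_t epoch = 7;
  LoadCtrlInputs(&t, &epoch);
  std::cout << "loop=" << t.loop_count << " epoch=" << t.epoch << " next_epoch=" << epoch << "\n";
  return 0;  // prints: loop=0 epoch=7 next_epoch=8
}
// ---------------------------------------------------------------------------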
kTypeUnknown : input_tensor->data_type(); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{param_init_data_type}); - } else { - kernel_build_info_builder->SetOutputsFormat(std::vector{input_tensor->device_address()->format()}); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{input_tensor->device_address()->type_id()}); - } - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); - // construct abstract of parameter - auto type_of_tensor = input_tensor->Dtype(); - auto shape_of_tensor = input_tensor->shape(); - auto abstract = std::make_shared(type_of_tensor, shape_of_tensor); - param->set_abstract(abstract); - return param; -} - -void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { - MS_LOG(INFO) << "Graph outputs:"; - const size_t max_deep = 10; - if (recurse_level > max_deep) { - MS_LOG(INFO) << "Recurse too deep"; - return; - } - std::string tab_str; - for (size_t i = 0; i < recurse_level; i++) { - tab_str = tab_str.append(" "); - } - if (any.is()) { - (void)tab_str.append("{"); - MS_LOG(INFO) << tab_str; - auto any_list = any.cast(); - for (auto &it : any_list) { - DumpGraphOutput(it, recurse_level + 1); - } - (void)tab_str.append("}"); - MS_LOG(INFO) << tab_str; - } - (void)tab_str.append(any.ToString()); - MS_LOG(INFO) << tab_str; -} - -bool ExistSummaryNode(const KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto ret = graph->get_return(); - MS_EXCEPTION_IF_NULL(ret); - auto all_nodes = DeepLinkedGraphSearch(ret); - for (auto &n : all_nodes) { - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - return true; - } - } - return false; -} -} // namespace - -GraphId SessionBasic::graph_sum_ = 0; - -KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) { - auto it = graphs_.find(graph_id); - if (it == graphs_.end()) { - MS_LOG(WARNING) << "Can't find graph " << graph_id; - return nullptr; - } - return it->second; -} - -void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter) { - auto graph_id = GetGraphIdByNode(out_node); - if (graph_id == kInvalidGraphId) { - return; - } - auto node_graph = GetGraph(graph_id); - if (node_graph == nullptr) { - return; - } - MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString(); - auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node); - if (ref_node == nullptr) { - MS_LOG(INFO) << "No corresponding internal output for output node"; - return; - } - auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0); - auto ref_real_node = real_kernel.first; - auto ref_real_node_index = real_kernel.second; - if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && - node_graph->IsFinalOutputKernel(ref_real_node)) { - auto kernel_info = ref_real_node->kernel_info(); - if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { - MS_LOG(INFO) << "No kernel info"; - return; - } - auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index); - if (address == nullptr) { - MS_LOG(INFO) << "No kernel address"; - return; - } - auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index); - auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); - parameter->set_kernel_info(std::make_shared()); - auto d_kernel_info = parameter->kernel_info(); - 
MS_EXCEPTION_IF_NULL(d_kernel_info); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsDeviceType({type}); - builder.SetOutputsFormat({format}); - d_kernel_info->set_select_kernel_build_info(builder.Build()); - AnfAlgo::SetOutputAddr(address, 0, parameter.get()); - } -} - -std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, - KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(graph); - std::vector parameters; - std::vector pre_graph_out = {node}; - // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive - if (!AnfAlgo::IsRealKernel(node)) { - pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); - } - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { - auto parameter = graph->NewParameter(); - MS_EXCEPTION_IF_NULL(parameter); - parameter->set_abstract(abstract); - auto new_parameter = graph->NewParameter(parameter); - parameters.push_back(new_parameter); - valid_inputs->push_back(valid_input); - graph_inputs->push_back(new_parameter); - }; - for (const auto &out_node : pre_graph_out) { - MS_EXCEPTION_IF_NULL(out_node); - auto abstract = out_node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - // create multiple parameters if is a tuple output real kernel - if (abstract->isa() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]"; - for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) { - create_parameter((*tuple_abstract)[output_idx]); - } - continue; - } - // create single parameter if is a abstract real kernel - create_parameter(out_node->abstract()); - InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]); - } - return parameters; -} - -ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, - KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; - } - MS_EXCEPTION_IF_NULL(graph); - auto param_value = GetParamDefaultValue(anf); - auto valid_inputs = graph->MutableValidInputs(); - MS_EXCEPTION_IF_NULL(valid_inputs); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - ParameterPtr new_parameter = nullptr; - // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter - if (python_paras == nullptr) { - python_paras = std::make_shared>(); - } - auto iter = python_paras->find(param_value); - if (iter != python_paras->end()) { - new_parameter = iter->second; - } else { - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - new_parameter = graph->NewParameter(anf->cast()); - if (param_value != nullptr) { - (*python_paras)[param_value] = new_parameter; - } - TraceManager::EndTrace(); - } - graph_inputs->push_back(new_parameter); - valid_inputs->push_back(valid_input); - return new_parameter; -} - -AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Create a new parameter from cnode[" << 
anf->DebugString() << "]"; - auto parameters = CreateParameterFromTuple(anf, valid_input, graph); - if (parameters.empty()) { - MS_LOG(EXCEPTION) << "No parameter exist!!"; - } - if (parameters.size() == 1) { - return parameters[0]; - } - std::vector make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)}; - (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); - auto make_tuple = graph->NewCNode(make_tuple_input); - MS_EXCEPTION_IF_NULL(make_tuple); - MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; - return make_tuple; -} - -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, - bool *from_other_graph, - std::unordered_map *other_graph_cnode) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(from_other_graph); - MS_EXCEPTION_IF_NULL(other_graph_cnode); - *from_other_graph = false; - // get primitive of old node - std::vector cnode_inputs; - auto prim = AnfAlgo::GetCNodePrimitive(cnode); - if (prim != nullptr) { - // push attr to inputs[0] of new cnode - cnode_inputs.push_back(std::make_shared(std::make_shared(*prim))); - } else { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(fg); - auto new_fg = BasicClone(fg); - cnode_inputs.push_back(std::make_shared(new_fg)); - } - auto origin_inputs = cnode->inputs(); - bool optimize_depend = false; - if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 && - origin_inputs[kRealInputIndexInDepend]->isa()) { - optimize_depend = true; - } - // if has multiple depends,only select first depend as parameter - for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { - auto anf = origin_inputs[input_idx]; - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { - cnode_inputs.push_back((*other_graph_cnode)[anf]); - continue; - } else if (anf->isa() && !IsValueNode(anf)) { - // if input is a value node, - auto new_value_node = CreateNewValueNode(anf, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - continue; - } else if (anf->isa()) { - auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); - cnode_inputs.push_back(new_parameter); - if (GetGraphIdByNode(anf) == kInvalidGraphId) { - graph->FrontBackendlMapAdd(anf, new_parameter); - } else { - (*other_graph_cnode)[anf] = new_parameter; - } - continue; - } else if (optimize_depend && input_idx == kDependAttachNodeIndex) { - cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]); - continue; - } else { - *from_other_graph = true; - // the input node is a cnode from other graph - auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); - cnode_inputs.push_back(parameter_from_cnode); - (*other_graph_cnode)[anf] = parameter_from_cnode; - } - } - TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); - auto new_cnode = graph->NewCNode(cnode_inputs); - TraceManager::EndTrace(); - return new_cnode; -} - -CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(node_input); - MS_EXCEPTION_IF_NULL(graph); - // switch input generalizes partial - if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) || - 
AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) { - return node_input->cast<CNodePtr>(); - } - if (node_input->isa<CNode>()) { - MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it must be partial or call."; - } - std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))}; - if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) { - partial_inputs.emplace_back(node_input); - auto partial_node = graph->NewCNode(partial_inputs); - return partial_node; - } - KernelGraphPtr kernel_graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(kernel_graph); - kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input)); - partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph)); - auto partial_node = graph->NewCNode(partial_inputs); - return partial_node; -} - -CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf_node); - MS_EXCEPTION_IF_NULL(graph); - auto node = anf_node->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(node); - if (node->inputs().size() < kSwitchInputSize) { - MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize; - } - auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name())); - std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)}; - for (size_t index = 2; index < node->inputs().size(); index++) { - auto input = CreateSwitchInput(node->input(index), graph); - switch_inputs.emplace_back(input); - } - auto switch_node = graph->NewCNode(switch_inputs); - return switch_node; -} - -std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - // create primitive of cnode:call(partial or switch) - std::vector<AnfNodePtr> cnode_inputs = { - graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))}; - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input); - if (cnode_input == nullptr) { - MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString() - << ", but input[0] has not been created."; - } - // if the node is partial, insert the inputs of partial to the call - if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) { - auto partial_node = attr_input->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(partial_node); - auto partial_inputs = partial_node->inputs(); - std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(), - std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node)); - return graph->GetBackendAnfByFrontAnf(node); - }); - return cnode_inputs; - } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) { - auto switch_node = HandleSwitchInputs(cnode_input, graph); - cnode_inputs.emplace_back(switch_node); - return cnode_inputs; - } - MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch."; -} - -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(graph); - std::vector<AnfNodePtr> cnode_inputs; - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - if (AnfAlgo::IsGraphKernel(cnode)) { - auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode); - MS_EXCEPTION_IF_NULL(fg); - auto new_fg = BasicClone(fg); - cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg)); - } else if (IsValueNode<FuncGraph>(attr_input)) { - // create
primitive of cnode:call - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimCall->name())))}; - // create a ValueNode as input of cnode:call - if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input)); - } else { - auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - } - } else if (attr_input->isa()) { - cnode_inputs = CreateSwitchOrPartialNode(cnode, graph); - } else { - // get primitive of old node - auto prim = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(prim); - // push attr to inputs[0] of new cnode - cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared(*prim)))}; - } - - for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->input(input_idx); - MS_EXCEPTION_IF_NULL(anf); - // anf has been created before - if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { - cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); - continue; - } else if (IsValueNode(anf)) { - continue; - } - MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; - } - TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); - auto new_cnode = graph->NewCNode(cnode_inputs); - TraceManager::EndTrace(); - return new_cnode; -} - -ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - auto value_node = anf->cast(); - MS_EXCEPTION_IF_NULL(value_node); - auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf); - MS_EXCEPTION_IF_NULL(sub_func_graph); - if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) { - MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph."; - } - auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph]; - - ValueNodePtr new_value_node = std::make_shared(sub_kernel_graph); - new_value_node->set_abstract(value_node->abstract()); - // create new kernel_info of new value_node - auto kernel_info = std::make_shared(); - kernel_info->SetFeatureMapFlag(false); - new_value_node->set_kernel_info(kernel_info); - // create kernel_build_info for new value node - auto kernel_build_info_builder = std::make_shared(); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get()); - AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get()); - - graph->FrontBackendlMapAdd(anf, new_value_node); - - return new_value_node; -} - -ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(anf); - MS_EXCEPTION_IF_NULL(graph); - if (!anf->isa()) { - MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter"; - } - - auto param_value = GetParamDefaultValue(anf); - ParameterPtr new_parameter = nullptr; - if (python_paras == nullptr) { - python_paras = std::make_shared>(); - } - auto iter = python_paras->find(param_value); - if (iter != python_paras->end()) { - new_parameter = iter->second; - } else { - TraceManager::DebugTrace(std::make_shared(anf->debug_info())); - new_parameter = graph->NewParameter(anf->cast()); - if (param_value != nullptr) { - (*python_paras)[param_value] = new_parameter; - } - TraceManager::EndTrace(); - } - - return new_parameter; -} - -KernelGraphPtr SessionBasic::ConstructKernelGraph(const 
AnfNodePtrList &lst, const AnfNodePtrList &outputs) { - std::unordered_map other_graph_cnode; - auto graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Create graph: " << graph->graph_id(); - size_t from_other_graph_depend_num = 0; - for (const auto &node : lst) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); - if (!node->isa()) { - MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // create a new cnode object - bool from_other_graph = false; - // only first depend from other graph can create - bool valid_input = true; - if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { - valid_input = false; - } - auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { - from_other_graph_depend_num++; - } - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_scope(cnode->scope()); - // record map relations between anf from ME and new anf node used in backend - graph->FrontBackendlMapAdd(node, new_cnode); - } - // add a make_tuple at the end of graph as output - graph->set_output(ConstructOutput(outputs, graph)); - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = MakeManager({graph}); - if (manager) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - graph->SetExecOrderByDefault(); - if (ExistSummaryNode(graph.get())) { - graph->set_summary_node_exist(true); - } - opt::BackendCommonOptimization(graph); - return graph; -} - -void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(graph); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - // create a new cnode object - auto new_cnode = CreateNewCNode(cnode, graph.get()); - MS_EXCEPTION_IF_NULL(new_cnode); - new_cnode->set_abstract(cnode->abstract()); - new_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); - new_cnode->set_scope(cnode->scope()); - graph->FrontBackendlMapAdd(node, new_cnode); - if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { - graph->set_return(new_cnode); - } -} -std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph, - std::vector *all_out_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(all_out_graph); - auto node_list = TopoSort(func_graph->get_return()); - auto graph = NewKernelGraph(); - MS_EXCEPTION_IF_NULL(graph); - front_backend_graph_map_[func_graph] = graph; - MS_LOG(INFO) << "Create graph: " << graph->graph_id(); - - bool is_trace_back = false; - for (const auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); - if (node->isa()) { - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - auto new_parameter = CreateNewParameter(node, graph.get()); - graph_inputs->push_back(new_parameter); - graph->FrontBackendlMapAdd(node, new_parameter); - continue; - } else if (node->isa()) { - if (!IsValueNode(node)) { - // if input is a common value node, - (void)CreateNewValueNode(node, graph.get()); - } else { - // if input is a ValueNode - FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node); - if 
(front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) { - is_trace_back = true; - } else { - (void)ConstructKernelGraph(child_graph, all_out_graph); - } - (void)CreateValueNodeKernelGraph(node, graph.get()); - } - continue; - } else { - CreateCNodeKernelGraph(node, graph); - } - } - // if a graph jumps back unconditionally, the return op of this graph will never be executed, so its output is null. - graph->set_output_null(is_trace_back); - AddParameterToGraphInputs(func_graph->parameters(), graph.get()); - graph->SetExecOrderByDefault(); - if (ExistSummaryNode(graph.get())) { - graph->set_summary_node_exist(true); - } - all_out_graph->push_back(graph); - return graph; -} - -void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) { - MS_EXCEPTION_IF_NULL(graph); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - graph_inputs->clear(); - for (auto &parameter : parameters) { - MS_EXCEPTION_IF_NULL(parameter); - auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); - if (backend_parameter == nullptr) { - // for example "def f(x,y,z) {return x + y}", parameter z is unused - auto new_parameter = CreateNewParameter(parameter, graph); - graph_inputs->push_back(new_parameter); - MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); - continue; - } - MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString(); - graph_inputs->push_back(backend_parameter); - } -} - -// run graph steps -void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, - const std::vector<tensor::TensorPtr> &inputs_const) const { - std::vector<tensor::TensorPtr> inputs(inputs_const); - size_t input_ctrl_size = 2; - MS_EXCEPTION_IF_NULL(kernel_graph); - if (kernel_graph->input_ctrl_tensors()) { - input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); - } - auto input_nodes = kernel_graph->inputs(); - if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { - MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal to graph inputs:" << input_nodes.size() - << ", input_ctrl_size:" << input_ctrl_size; - } - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - for (size_t i = 0; i < inputs.size(); ++i) { - auto tensor = inputs[i]; - MS_EXCEPTION_IF_NULL(tensor); - auto input_node = input_nodes[i]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast<ParameterPtr>(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { - need_sync = true; - } - } else { - if (tensor->is_dirty()) { - need_sync = true; - } else if (tensor->device_address() != device_address) { - (void)tensor->data_sync(); - need_sync = true; - } - } - if (need_sync) { - if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { - tensor->set_device_address(device_address); - } - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } - } - } - tensor->set_dirty(false); - } -} - -void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, - const std::vector<tensor::TensorPtr> &input_tensors) const
{ - MS_EXCEPTION_IF_NULL(kernel_graph); - MS_EXCEPTION_IF_NULL(outputs); - if (!kernel_graph->child_graph_order().empty()) { - // use the last child graph output as the root graph output - UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); - return; - } - auto anf_outputs = kernel_graph->outputs(); - for (auto &item : anf_outputs) { - MS_EXCEPTION_IF_NULL(item); - MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) { - outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors)); - continue; - } - outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors)); - } -} - -void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { - MS_EXCEPTION_IF_NULL(callback); - summary_callback_ = callback; -} - -void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } - -void SessionBasic::GetSummaryNodes(KernelGraph *graph) { - MS_LOG(DEBUG) << "Update summary Start"; - MS_EXCEPTION_IF_NULL(graph); - if (!graph->summary_node_exist()) { - return; - } - auto summary = graph->summary_nodes(); - auto apply_list = TopoSort(graph->get_return()); - for (auto &n : apply_list) { - MS_EXCEPTION_IF_NULL(n); - if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) || - IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) { - auto cnode = n->cast<CNodePtr>(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() <= kSummaryGetItem) { - MS_LOG(EXCEPTION) << "The Summary node should have at least 2 inputs!"; - } - auto node = cnode->input(kSummaryGetItem); - MS_EXCEPTION_IF_NULL(node); - auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); - MS_EXCEPTION_IF_NULL(item_with_index.first); - if (!AnfAlgo::IsRealKernel(item_with_index.first)) { - MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString(); - } - summary[n->fullname_with_scope()] = item_with_index; - } - } - graph->set_summary_nodes(summary); - MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); -} - -void SessionBasic::Summary(KernelGraph *graph) { - if (summary_callback_ == nullptr) { - return; - } - MS_EXCEPTION_IF_NULL(graph); - bool exist_summary = graph->summary_node_exist(); - if (!exist_summary) { - return; - } - GetSummaryNodes(graph); - auto summary_outputs = graph->summary_nodes(); - std::map<std::string, tensor::TensorPtr> params_list; - // fetch outputs apply kernel in session & run callback functions - for (auto &output_item : summary_outputs) { - auto node = output_item.second.first; - size_t index = IntToSize(output_item.second.second); - auto address = AnfAlgo::GetOutputAddr(node, index); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); - std::vector<int> temp_shape; - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); - tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); - MS_EXCEPTION_IF_NULL(address); - if (!address->GetPtr()) { - continue; - } - if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()), - tensor->data_type(), tensor->data_c())) { - MS_LOG(ERROR) << "Failed to sync output from device to host."; - } - tensor->set_dirty(false); - params_list[output_item.first] = tensor; - } - // call callback function here - summary_callback_(0, params_list); -} -
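The summary machinery deleted above reduces to a small contract: GetSummaryNodes records which real kernel output feeds each Summary-type operator, Summary() syncs those device outputs into host tensors, and the registered CallBackFunc receives the whole map keyed by node name. Below is a minimal standalone sketch of that callback contract; SimpleTensor, SummarySession, and PrintSummaries are illustrative stand-ins rather than the real MindSpore types, and the device-to-host sync is omitted since the sketch has no device runtime.

#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for tensor::TensorPtr: just a shape and host data.
struct SimpleTensor {
  std::vector<int> shape;
  std::vector<float> data;
};
using SimpleTensorPtr = std::shared_ptr<SimpleTensor>;

// Mirrors the shape of CallBackFunc: a plain function pointer that receives
// the graph id and the summary-name -> host-tensor map.
using CallBackFunc = uint32_t (*)(uint32_t graph_id,
                                  const std::map<std::string, SimpleTensorPtr> &params_list);

class SummarySession {
 public:
  void RegisterSummaryCallBackFunc(CallBackFunc callback) { summary_callback_ = callback; }

  // Analogue of Summary(): if a callback is registered, hand it the tensors
  // gathered from the summary nodes of one graph execution.
  void Summary(uint32_t graph_id, const std::map<std::string, SimpleTensorPtr> &params_list) {
    if (summary_callback_ == nullptr) {
      return;  // same early-out as the deleted code
    }
    (void)summary_callback_(graph_id, params_list);
  }

 private:
  CallBackFunc summary_callback_ = nullptr;
};

uint32_t PrintSummaries(uint32_t graph_id, const std::map<std::string, SimpleTensorPtr> &params_list) {
  for (const auto &item : params_list) {
    std::cout << "graph " << graph_id << ": " << item.first << " holds "
              << item.second->data.size() << " element(s)" << std::endl;
  }
  return 0;
}

int main() {
  SummarySession session;
  session.RegisterSummaryCallBackFunc(PrintSummaries);
  auto loss = std::make_shared<SimpleTensor>();
  loss->shape = {1};
  loss->data = {0.125f};
  session.Summary(0, {{"loss[:Scalar]", loss}});
  return 0;
}

Compiled with any C++11 compiler, this registers PrintSummaries the way a training loop would register a summary writer and invokes it once per graph run; a function pointer (rather than std::function) keeps the registration trivially copyable, which matches the CallBackFunc alias in the deleted header.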
-CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { - MS_EXCEPTION_IF_NULL(graph); - std::vector output_args; - for (const auto &output : outputs) { - MS_EXCEPTION_IF_NULL(output); - MS_LOG(INFO) << "Output:" << output->DebugString(); - } - auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { - auto backend_anf = graph->GetBackendAnfByFrontAnf(out); - if (backend_anf != nullptr) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->execution_mode() == kPynativeMode) { - return backend_anf; - } - auto front_real_kernel = AnfAlgo::VisitKernel(out, 0); - auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0); - MS_EXCEPTION_IF_NULL(out); - auto out_func_graph = out->func_graph(); - MS_EXCEPTION_IF_NULL(out_func_graph); - auto out_func_graph_manager = out_func_graph->manager(); - if (out_func_graph_manager == nullptr) { - return backend_anf; - } - auto node_users = out_func_graph_manager->node_users(); - auto users = node_users[out]; - bool internal_output = true; - std::string kernel_target = GetCNodeTarget(front_real_kernel.first); - for (auto user : users) { - if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { - internal_output = false; - break; - } - } - if (internal_output) { - MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString(); - graph->AddInternalOutput(out, backend_real_kernel.first); - } - return backend_anf; - } - MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; - }; - output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); - (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), - [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); }); - return graph->NewCNode(output_args); -} - -void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph) { - MS_LOG(INFO) << "Start!"; - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - MS_EXCEPTION_IF_NULL(graph); - if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) { - for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) { - auto idx = NewValueNode(SizeToInt(output_index)); - MS_EXCEPTION_IF_NULL(idx); - auto imm = std::make_shared(output_index); - idx->set_abstract(std::make_shared(imm)); - auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx}); - std::vector types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)}; - std::vector> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)}; - AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get()); - make_tuple_inputs.push_back(getitem); - } - } else { - make_tuple_inputs.push_back(cnode); - } - // create output - auto g_output = graph->NewCNode(make_tuple_inputs); - graph->set_output(g_output); - // set graph manager,which now is only used to get valuenodes and hardware optimizing - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = context_->manager(); - if (manager != nullptr) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } - MS_LOG(INFO) << "Finish!"; -} - -std::shared_ptr SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info, - const std::vector &input_tensors, - const std::vector &tensors_mask) { - auto graph = std::make_shared(); - std::vector inputs; - // set input[0] - 
PrimitivePtr op_prim = op_run_info.py_primitive; - MS_EXCEPTION_IF_NULL(op_prim); - inputs.push_back(std::make_shared(op_prim)); - // set input parameter - MS_LOG(INFO) << "Input tensor size: " << input_tensors.size(); - if (input_tensors.size() != tensors_mask.size()) { - MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size " - << tensors_mask.size(); - } - for (size_t i = 0; i < input_tensors.size(); ++i) { - if (tensors_mask[i] == kValueNodeTensorMask) { - auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]); - inputs.push_back(value_node); - continue; - } - auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); - inputs.push_back(parameter); - auto mutable_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(mutable_inputs); - mutable_inputs->push_back(parameter); - } - // set execution order - auto cnode = graph->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(cnode); - // set abstract,which include inferred shapes and types - cnode->set_abstract(op_run_info.abstract); - // set execution order - std::vector exe_order = {cnode}; - graph->set_execution_order(exe_order); - // set output - CreateOutputNode(cnode, graph); - return graph; -} - -BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) { - if (utils::isa(base_ref)) { - auto ref_list = utils::cast(base_ref); - py::tuple output_tensors(ref_list.size()); - for (size_t i = 0; i < ref_list.size(); ++i) { - auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef - if (utils::isa(output)) { - auto tensor_ptr = utils::cast(output); - MS_EXCEPTION_IF_NULL(tensor_ptr); - output_tensors[i] = tensor_ptr; - } else if (utils::isa(output)) { - py::object obj = utils::cast(output).object_; - py::tuple tensor_tuple = py::cast(obj); - output_tensors[i] = tensor_tuple; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } - } - return output_tensors; // turn tuple to py::object and store in PyObjectRef - } else if (utils::isa(base_ref)) { - return base_ref; - } else { - MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!"; - } -} - -KernelGraphPtr SessionBasic::NewKernelGraph() { - auto graph = std::make_shared(); - graph->set_graph_id(graph_sum_); - graphs_[graph_sum_++] = graph; - return graph; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h deleted file mode 100755 index 8f8f88e65a..0000000000 --- a/mindspore/ccsrc/session/session_basic.h +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H -#define MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H - -#include -#include -#include -#include -#include -#include - -#include "utils/base_ref_extends.h" -#include "session/session_context.h" -#include "session/kernel_graph.h" -#include "ir/anf.h" -#include "ir/tensor.h" -#include "utils/any.h" -#include "utils/contract.h" -#include "pynative/pynative_execute.h" -#include "device/kernel_info.h" -#ifdef ENABLE_DEBUGGER -#include "debug/debugger/debugger.h" -#endif - -namespace mindspore { -using GraphId = uint32_t; -using GraphInfo = std::string; -namespace session { -void ClearPythonParasMap(); -using CallBackFunc = uint32_t (*)(uint32_t graph_id, - const std::map ¶ms_list); -using AnyList = std::vector; -using AnyListPtr = std::shared_ptr; - -using OpRunInfo = pynative::OpExecInfo; -using OpRunInfoPtr = std::shared_ptr; - -class SessionBasic { - public: - SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { -#ifdef ENABLE_DEBUGGER - debugger_ = nullptr; -#endif - } - - virtual void Init(uint32_t device_id) { device_id_ = device_id; } - - virtual ~SessionBasic() { summary_callback_ = nullptr; } - - virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; - virtual GraphId CompileGraph(NotNull func_graph) { return kInvalidGraphId; } - // build graph, used to handle multiple child graphs - virtual void BuildGraph(GraphId) {} - - virtual void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) = 0; - - virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors, - const std::vector &tensors_mask) {} - - virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector &input_tensors) { - return py::tuple(); - } - - virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); - - void CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph); - - std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); - std::shared_ptr ConstructKernelGraph(const FuncGraphPtr &func_graph, - std::vector *all_out_graph); - - CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, - std::unordered_map *other_graph_cnode); - CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); - - CNodePtr CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph); - CNodePtr HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph); - std::vector CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph); - - // set parameters of final graph - virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } - // set output of final graph - virtual void SetFinalGraphOutput(const BaseRef &) {} - // insert switch and set the relative active ops - virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} - // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter - virtual void SetChildGraphInput(GraphId, const VectorRef &) {} - // get graph id in child graphs by ME front anf node pointer - virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; } - virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } - virtual void SetActive(GraphId, GraphId) {} - virtual void GetSummaryNodes(KernelGraph *graph); - -#ifdef ENABLE_DEBUGGER - // set debugger - void SetDebugger() 
{ - debugger_ = Debugger::GetInstance(); - debugger_->Init(device_id_); - } -#endif - - protected: - // Get graph by graph id ,if not exist return null ptr - KernelGraphPtr GetGraph(GraphId graph_id); - virtual void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const; - void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, - const std::vector &input_tensors) const; - void Reorder(std::vector *node_list); - void Summary(KernelGraph *graph); - // create graph output for RunOp - void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph); - CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph); - // create a single run op graph - std::shared_ptr ConstructSingleOpGraph(const OpRunInfo &op_run_info, - const std::vector &input_tensors, - const std::vector &tensors_mask); - // trans BaseRef list to py::tuple - BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); - // create a new kernel graph and update the graph sum - KernelGraphPtr NewKernelGraph(); - std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph); - virtual ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); - ValueNodePtr CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph); - ParameterPtr CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph); - AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph); - void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); - void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); - - std::unordered_map> graphs_; - std::unordered_map> run_op_graphs_; - std::unordered_map front_backend_graph_map_; - std::shared_ptr context_; - CallBackFunc summary_callback_; - static GraphId graph_sum_; - uint32_t device_id_; -#ifdef ENABLE_DEBUGGER - std::shared_ptr debugger_; -#endif -}; - -using SessionPtr = std::shared_ptr; -using NamedSummaryOutputs = std::map>; -} // namespace session -} // namespace mindspore -#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H diff --git a/mindspore/ccsrc/session/session_context.cc b/mindspore/ccsrc/session/session_context.cc deleted file mode 100644 index 2b6ebf6b84..0000000000 --- a/mindspore/ccsrc/session/session_context.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "session/session_context.h" -namespace mindspore { -namespace session { -std::shared_ptr Context::GetInstance() { - static std::shared_ptr context_singleton = std::make_shared(); - return context_singleton; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_context.h b/mindspore/ccsrc/session/session_context.h deleted file mode 100644 index 78794c348e..0000000000 --- a/mindspore/ccsrc/session/session_context.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H -#define MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H -#include -#include -#include -#include -#include -#include - -#include "ir/tensor.h" -#include "pipeline/resource.h" -#include "utils/context/ms_context.h" -namespace mindspore { -namespace session { -const char kInputCtrlTensors[] = "input_ctrl_tensors"; - -class Context : public pipeline::ResourceBase { - public: - explicit Context(std::string target = kAscendDevice, uint32_t device_id = 0) - : target_(std::move(target)), device_id_(device_id) {} - ~Context() override = default; - - uint32_t device_id() const { return device_id_; } - static std::shared_ptr GetInstance(); - void AddManager(const FuncGraphManagerPtr &m) { manager_list_.push_back(m); } - - private: - std::vector manager_list_; - std::string target_; - uint32_t device_id_; -}; -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_SESSION_CONTEXT_H diff --git a/mindspore/ccsrc/session/session_factory.cc b/mindspore/ccsrc/session/session_factory.cc deleted file mode 100644 index 4cd0481f8c..0000000000 --- a/mindspore/ccsrc/session/session_factory.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "session/session_factory.h" -#include -#include -#include -namespace mindspore { -namespace session { -SessionFactory &SessionFactory::Get() { - static SessionFactory instance; - return instance; -} - -void SessionFactory::Register(const std::string &device_name, SessionCreator &&session_creator) { - if (session_creators_.end() == session_creators_.find(device_name)) { - (void)session_creators_.emplace(device_name, session_creator); - } -} - -std::shared_ptr SessionFactory::Create(const std::string &device_name) { - auto iter = session_creators_.find(device_name); - if (session_creators_.end() != iter) { - MS_EXCEPTION_IF_NULL(iter->second); - return (iter->second)(); - } - return nullptr; -} -} // namespace session -} // namespace mindspore diff --git a/mindspore/ccsrc/session/session_factory.h b/mindspore/ccsrc/session/session_factory.h deleted file mode 100644 index 99db0afeb7..0000000000 --- a/mindspore/ccsrc/session/session_factory.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ -#define MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ - -#include -#include -#include -#include -#include -#include "common/utils.h" -#include "session/session_basic.h" -namespace mindspore { -namespace session { -using SessionCreator = std::function()>; -class SessionFactory { - public: - static SessionFactory &Get(); - void Register(const std::string &device_name, SessionCreator &&session_creator); - std::shared_ptr Create(const std::string &device_name); - - private: - SessionFactory() = default; - ~SessionFactory() = default; - DISABLE_COPY_AND_ASSIGN(SessionFactory) - std::map session_creators_; -}; - -class SessionRegistrar { - public: - SessionRegistrar(const std::string &device_name, SessionCreator &&session_creator) { - SessionFactory::Get().Register(device_name, std::move(session_creator)); - } - ~SessionRegistrar() = default; -}; - -#define MS_REG_SESSION(DEVICE_NAME, SESSION_CLASS) \ - static const SessionRegistrar g_session_registrar__##DEVICE_NAME##_##_reg( \ - DEVICE_NAME, []() { return std::make_shared(); }); -} // namespace session -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_SESSION_SESSION_FACTORY_H_ diff --git a/mindspore/ccsrc/transform/CMakeLists.txt b/mindspore/ccsrc/transform/CMakeLists.txt deleted file mode 100644 index c783cc0060..0000000000 --- a/mindspore/ccsrc/transform/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -if (ENABLE_GE OR ENABLE_D) - file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") - set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) - add_library(_mindspore_transform_obj OBJECT ${_TRANSFORM_SRC_LIST}) - - if (NOT ENABLE_GE) - target_compile_definitions(_mindspore_transform_obj PRIVATE NO_GE_CLIENT) - endif() -endif () diff --git a/mindspore/ccsrc/transform/convert.cc 
b/mindspore/ccsrc/transform/convert.cc deleted file mode 100644 index 56ce06d2d7..0000000000 --- a/mindspore/ccsrc/transform/convert.cc +++ /dev/null @@ -1,2073 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/convert.h" - -#include -#include -#include -#include "utils/utils.h" - -#include "operator/ops.h" -#include "utils/log_adapter.h" -#include "utils/graph_utils.h" -#include "utils/symbolic.h" -#include "utils/config_manager.h" -#include "utils/convert_utils.h" -#include "./common.h" -#include "utils/context/ms_context.h" - -namespace mindspore { -namespace transform { -using std::endl; - -#define ADPT_DESC_ONE(T) std::make_shared(std::make_shared>()) -#define ADPT_DESC_TWO(T, I) \ - std::make_shared(std::make_shared>(), std::make_shared>()) -#define GET_MACRO(_1, _2, DESC, ...) DESC -#define ADPT_DESC(...) GET_MACRO(__VA_ARGS__, ADPT_DESC_TWO, ADPT_DESC_ONE, ...)(__VA_ARGS__) - -using ge::Operator; -using mindspore::kAnyValue; -using std::make_shared; -using std::shared_ptr; -using std::string; -using std::vector; - -const char kNameCustomOp[] = "CustomOp"; -const char kNameConst[] = "Const"; -const char kNameParam[] = "parameter"; -const char kNameRandomUniform[] = "RandomUniform"; -const char kNameSimpleMean[] = "SimpleMean"; -const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; -const char kNameAllReduce[] = "AllReduce"; -const char kNameBroadcast[] = "Broadcast"; -const char kNameAllgather[] = "AllGather"; -const char kNameReduceScatter[] = "ReduceScatter"; -const char kNameReduceSum[] = "ReduceSum"; -const char kNameIsFinite[] = "isFinite"; -const char kNameReciprocal[] = "Reciprocal"; -const char kNameRsqrt[] = "Rsqrt"; -const char kNameRsqrtGrad[] = "RsqrtGrad"; -const char kNameSqrt[] = "Sqrt"; -const char kNameSquare[] = "Square"; -const char kNameSquaredDifference[] = "SquaredDifference"; -const char kNamePow[] = "Pow"; -const char kNameBatchMatMul[] = "BatchMatMul"; -const char kNameStridedSlice[] = "StridedSlice"; -const char kNameStridedSliceGrad[] = "StridedSliceGrad"; -const char kNameExpandDims[] = "ExpandDims"; -const char kNameLog[] = "Log"; -const char kNameLogicalAnd[] = "LogicalAnd"; -const char kNameLogicalNot[] = "LogicalNot"; -const char kNameLogicalOr[] = "LogicalOr"; -const char kNameExp[] = "Exp"; -const char kNameLessEqual[] = "LessEqual"; -const char kNameGreaterEqual[] = "GreaterEqual"; -const char kNameEqual[] = "Equal"; -const char kNameNotEqual[] = "NotEqual"; -const char kNameFlattenGrad[] = "FlattenGrad"; -const char kNameConvolution[] = "Convolution"; -const char kNameBiasAdd[] = "BiasAdd"; -const char kNameMaxPoolGrad[] = "MaxPoolGrad"; -const char kNameAvgPoolGrad[] = "AvgPoolGrad"; -const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; -const char kNameApplyMomentum[] = "ApplyMomentum"; -const char kNameDropoutDoMask[] = "DropoutDoMask"; -const char kNameResizeBilinear[] = "ResizeBilinear"; -const char kNameResizeBilinearGrad[] = 
"ResizeBilinearGrad"; -const char kNameZerosLike[] = "ZerosLike"; -const char kNameOnesLike[] = "OnesLike"; -const char kNameTruncatedNormal[] = "TruncatedNormal"; -const char kNameSpaceToBatchNd[] = "SpaceToBatchNd"; -const char kNameConfusionMatrix[] = "ConfusionMatrix"; -const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; -const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; -const char kNameApplyAdam[] = "Adam"; -const char kNameExtractImagePatches[] = "ExtractImagePatches"; -const char kNameReLU6[] = "ReLU6"; -const char kNameReLU6Grad[] = "ReLU6Grad"; -const char kNameElu[] = "Elu"; -const char kNameEluGrad[] = "EluGrad"; -const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; -const char kNameScatterUpdate[] = "ScatterUpdate"; -const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; -const char kNameScatterMax[] = "ScatterMax"; -const char kNameNMSWithMask[] = "NMSWithMask"; -const char kNameCheckValid[] = "CheckValid"; -const char kNameSmoothL1Loss[] = "SmoothL1Loss"; -const char kNameSmoothL1LossGrad[] = "SmoothL1LossGrad"; -const char kNameSGD[] = "SGD"; -const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits"; -const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; -const char kNameScatterNdD[] = "ScatterNd"; -const char kNamePadD[] = "Pad"; -const char kNameMirrorPad[] = "MirrorPad"; -const char kNameMirrorPadGrad[] = "MirrorPadGrad"; -const char kNameGatherNd[] = "GatherNd"; -const char kNameArgmax[] = "Argmax"; -const char kNameArgmin[] = "Argmin"; -const char kNameArgMaxWithValue[] = "ArgMaxWithValue"; -const char kNameArgMinWithValue[] = "ArgMinWithValue"; -const char kNameReduceProd[] = "ReduceProd"; -const char kNameCumProd[] = "CumProd"; -const char kNameDiagpart[] = "Diagpart"; -const char kNameSplitD[] = "Split"; -const char kNameBatchToSpaceNd[] = "BatchToSpaceNd"; -const char kNameFloor[] = "Floor"; -const char kNameNPUGetFloatStatus[] = "NPUGetFloatStatus"; -const char kNameAssign[] = "Assign"; -const char kNameAssignAdd[] = "AssignAdd"; -const char kNameAssignSub[] = "AssignSub"; -const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; -const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; -const char kNameReshape[] = "Reshape"; -const char kNameTransShape[] = "TransShape"; -const char kNameRealDiv[] = "RealDiv"; -const char kNameTile[] = "Tile"; -const char kNameCos[] = "Cos"; -const char kNameACos[] = "ACos"; -const char kNameACosGrad[] = "ACosGrad"; -const char kNameFloorDiv[] = "FloorDiv"; -const char kNameSin[] = "Sin"; -const char kNamePrelu[] = "PReLU"; -const char kNamePreluGrad[] = "PReLUGrad"; -const char kNameSigmoid[] = "Sigmoid"; -const char kNameSigmoidGrad[] = "SigmoidGrad"; -const char kNameL2Normalize[] = "L2Normalize"; -const char kNameL2NormalizeGrad[] = "L2NormalizeGrad"; -const char kNameSoftmax[] = "Softmax"; -const char kNameIOU[] = "IOU"; -const char kNameBoundingBoxDecode[] = "BoundingBoxDecode"; -const char kNameBoundingBoxEncode[] = "BoundingBoxEncode"; -const char kNameSlice[] = "Slice"; -const char kNameAddN[] = "AddN"; -const char kNameLess[] = "Less"; -const char kNameGreater[] = "Greater"; -const char kNamePack[] = "Pack"; -const char kNameUnpack[] = "Unpack"; -const char kNameMerge[] = "Merge"; -const char kNameGeSwitch[] = "GeSwitch"; - -const char kNameHuberLoss[] = "HuberLoss"; -const char kNameCumSum[] = "CumSum"; -const char kNameHuberLossGrad[] = "HuberLossGrad"; -const char 
kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy"; -const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad"; -const char kNameTopK[] = "TopK"; -const char kNameSoftmaxGrad[] = "SoftmaxGrad"; -const char kNameMaxPool[] = "MaxPool"; -const char kNameAvgPool[] = "AvgPool"; -const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax"; -const char kNameBatchNorm[] = "BatchNorm"; -const char kNameBatchNormGrad[] = "BatchNormGrad"; -const char kNameROIAlign[] = "ROIAlign"; -const char kNameROIAlignGrad[] = "ROIAlignGrad"; -const char kNameRandomChoiceWithMask[] = "RandomChoiceWithMask"; -const char kNameAbs[] = "Abs"; -const char kNameAbsGrad[] = "AbsGrad"; -const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; -const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; -const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; -const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD"; -const char kNameApplyProximalAdagrad[] = "ApplyProximalAdagrad"; -const char kNameAcosh[] = "Acosh"; -const char kNameAcoshGrad[] = "AcoshGrad"; -const char kNameFloorMod[] = "FloorMod"; -const char kNameSpaceToDepth[] = "SpaceToDepth"; -const char kNameDepthToSpace[] = "DepthToSpace"; -const char kNameSign[] = "Sign"; -const char kNameLARSUpdate[] = "LARSUpdate"; -const char kNameRound[] = "Round"; -const char kNamePrint[] = "Print"; -const char kNameApplyFtrl[] = "ApplyFtrl"; -const char kNameDiag[] = "Diag"; -const char kNameDiagPart[] = "DiagPart"; -const char kNameSpaceToBatch[] = "SpaceToBatch"; -const char kNameBatchToSpace[] = "BatchToSpace"; -const char kNameAtan2[] = "Atan2"; -const char kNameApplyRMSProp[] = "ApplyRMSProp"; -const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; -const char kNameL2Loss[] = "L2Loss"; -const char kNameCTCLoss[] = "CTCLoss"; -const char kNameRange[] = "Range"; -const char kNameSquareSumAll[] = "SquareSumAll"; -const char kNameAscendQuant[] = "AscendQuant"; -const char kNameAscendDequant[] = "AscendDequant"; -const char kNameCase[] = "Case"; - -// -----------------OpAdapter initialization-------------- -std::unordered_map &DfGraphConvertor::get_adpt_map() { - static std::unordered_map adpt_map = { - {string(kNameCustomOp), ADPT_DESC(Operator)}, - {string(kNameIOU), ADPT_DESC(Iou)}, - {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)}, - {string(kNameSlice), ADPT_DESC(SliceD)}, - {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)}, - {string(kNameMaxPool), ADPT_DESC(MaxPool)}, - {string(kNameAvgPool), ADPT_DESC(AvgPool)}, - {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, - {string(kNameTopK), ADPT_DESC(TopK)}, - {string(kNamePack), ADPT_DESC(Pack)}, - {string(kNameUnpack), ADPT_DESC(Unpack)}, - {string(kNameSplitD), ADPT_DESC(SplitD)}, - {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)}, - {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)}, - {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, - {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, - {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, - {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, - {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, - {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, - {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, - {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, - {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, - {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, - {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, - 
{prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, - {prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD)}, - {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, - {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, - {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, - {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, - {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, - {string(kNameReshape), ADPT_DESC(Reshape)}, - {string(kNameTransShape), ADPT_DESC(TransShape)}, - {string(kNameFlattenGrad), ADPT_DESC(Reshape)}, - {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, - {string(kNameAddN), ADPT_DESC(AddN)}, - {string(kNameLess), ADPT_DESC(Less)}, - {string(kNameSqrt), ADPT_DESC(Sqrt)}, - {string(kNameRsqrt), ADPT_DESC(Rsqrt)}, - {string(kNameSquare), ADPT_DESC(Square)}, - {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, - {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, - {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, - {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, - {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, - {string(kNameReLU6), ADPT_DESC(Relu6)}, - {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, - {string(kNameElu), ADPT_DESC(Elu)}, - {string(kNameEluGrad), ADPT_DESC(EluGrad)}, - {string(kNameResizeBilinearGrad), ADPT_DESC(ResizeBilinearV2Grad)}, - {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, - {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, - {string(kNameOnesLike), ADPT_DESC(OnesLike)}, - {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, - {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, - {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, - {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, - {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, - {string(kNameCheckValid), ADPT_DESC(CheckValid)}, - {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, - {string(kNameSmoothL1LossGrad), ADPT_DESC(SmoothL1LossGrad)}, - {string(kNameSigmoidCrossEntropyWithLogits), ADPT_DESC(SigmoidCrossEntropyWithLogits)}, - {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, - {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, - {string(kNamePadD), ADPT_DESC(PadD)}, - {string(kNameMirrorPad), ADPT_DESC(MirrorPad)}, - {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)}, - {string(kNameGatherNd), ADPT_DESC(GatherNd)}, - {string(kNameArgmax), ADPT_DESC(ArgMaxD)}, - {string(kNameArgmin), ADPT_DESC(ArgMinD)}, - {string(kNameArgMaxWithValue), ADPT_DESC(ArgMaxWithValue)}, - {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, - {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, - {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, - {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD)}, - {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, - {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, - {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, - {string(kNameReduceProd), ADPT_DESC(ReduceProdD)}, - {string(kNameCumProd), ADPT_DESC(CumprodD)}, - {string(kNameMerge), ADPT_DESC(Merge)}, - {string(kNameGeSwitch), ADPT_DESC(Switch)}, - {string(kNameCumSum), ADPT_DESC(CumsumD)}, - - {prim::kPrimMul->name(), ADPT_DESC(Mul)}, - {string(kNameTile), ADPT_DESC(TileD)}, - {prim::kPrimOneHot->name(), ADPT_DESC(OneHot)}, - - {prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D)}, - 
{string(kNameCos), ADPT_DESC(Cos)}, - {string(kNameACos), ADPT_DESC(Acos)}, - {string(kNameACosGrad), ADPT_DESC(AcosGrad)}, - {string(kNameFloor), ADPT_DESC(Floor)}, - {string(kNameFloorDiv), ADPT_DESC(FloorDiv)}, - {string(kNameSin), ADPT_DESC(Sin)}, - {string(kNameExp), ADPT_DESC(Exp)}, - {string(kNameBoundingBoxEncode), ADPT_DESC(BoundingBoxEncode)}, - {string(kNameBoundingBoxDecode), ADPT_DESC(BoundingBoxDecode)}, - - {prim::kPrimCast->name(), ADPT_DESC(Cast)}, - {string(kNameRealDiv), ADPT_DESC(RealDiv)}, - {prim::kPrimNeg->name(), ADPT_DESC(Neg)}, - {prim::kPrimTranspose->name(), ADPT_DESC(TransposeD)}, - {prim::kPrimSub->name(), ADPT_DESC(Sub)}, - {string(kNameReciprocal), ADPT_DESC(Reciprocal)}, - {prim::kPrimDropoutGenMask->name(), ADPT_DESC(DropOutGenMask)}, - {string(kNameAssignAdd), ADPT_DESC(AssignAdd)}, - {string(kNameAssignSub), ADPT_DESC(AssignSub)}, - {prim::kPrimConcat->name(), ADPT_DESC(ConcatD)}, - {string(kNamePow), ADPT_DESC(Pow)}, - {string(kNameExp), ADPT_DESC(Exp)}, - {string(kNameEqual), ADPT_DESC(Equal)}, - {string(kNameNotEqual), ADPT_DESC(NotEqual)}, - {string(kNameLog), ADPT_DESC(Log)}, - {string(kNameLogicalAnd), ADPT_DESC(LogicalAnd)}, - {string(kNameLogicalNot), ADPT_DESC(LogicalNot)}, - {string(kNameLogicalOr), ADPT_DESC(LogicalOr)}, - {string(kNameGreater), ADPT_DESC(Greater)}, - {prim::kPrimMaximum->name(), ADPT_DESC(Maximum)}, - {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, - {string(kNamePrelu), ADPT_DESC(PRelu)}, - {string(kNamePreluGrad), ADPT_DESC(PReluGrad)}, - {string(kNameSigmoid), ADPT_DESC(Sigmoid)}, - {string(kNameSigmoidGrad), ADPT_DESC(SigmoidGrad)}, - {string(kNameSGD), ADPT_DESC(SGD)}, - {prim::kPrimLogSoftmaxGrad->name(), ADPT_DESC(LogSoftmaxGrad)}, - {prim::kPrimMaximumGrad->name(), ADPT_DESC(MaximumGrad)}, - {prim::kPrimMinimumGrad->name(), ADPT_DESC(MinimumGrad)}, - {string(kNameL2Normalize), ADPT_DESC(L2Normalize)}, - {string(kNameL2NormalizeGrad), ADPT_DESC(L2NormalizeGrad)}, - - {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, - {prim::kPrimSelect->name(), ADPT_DESC(Select)}, - {string(kNameLessEqual), ADPT_DESC(LessEqual)}, - {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)}, - {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, - {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, - {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, - {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, - {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, - {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)}, - {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, - {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, - {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, - {prim::kPrimLayerNorm->name(), ADPT_DESC(LayerNorm)}, - {prim::kPrimLayerNormGrad->name(), ADPT_DESC(LayerNormGrad)}, - {string(kNameBatchMatMul), ADPT_DESC(BatchMatMul)}, - {string(kNameDropoutDoMask), ADPT_DESC(DropOutDoMask)}, - - {string(kNameNPUGetFloatStatus), ADPT_DESC(NPUGetFloatStatus)}, - {string(kNameNPUAllocFloatStatus), ADPT_DESC(NPUAllocFloatStatus)}, - {string(kNameNPUClearFloatStatus), ADPT_DESC(NPUClearFloatStatus)}, - - {string(kNameRandomChoiceWithMask), ADPT_DESC(RandomChoiceWithMask)}, - {prim::kPrimSoftmaxCrossEntropyWithLogits->name(), ADPT_DESC(SoftmaxCrossEntropyWithLogits)}, - - {prim::kPrimScalarSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, - {prim::kPrimHistogramSummary->name(), 
ADPT_DESC(Summary)}, - {prim::kPrimDebug->name(), ADPT_DESC(Summary)}, - {prim::kPrimTensorAdd->name(), - std::make_shared(std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})), - std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})))}, - {string(kNameBiasAdd), ADPT_DESC(BiasAdd)}, - {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, - - {prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2)}, - - {string(kNameConst), ADPT_DESC(Constant, Const)}, - {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)}, - {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, - {string(kNameParam), ADPT_DESC(Data)}, - {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, - {string(kNameROIAlignGrad), ADPT_DESC(ROIAlignGrad)}, - {string(kNameAbs), ADPT_DESC(Abs)}, - {string(kNameAbsGrad), ADPT_DESC(AbsGrad)}, - {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, - {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, - {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, - {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, - {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, - {string(kNameAcosh), ADPT_DESC(Acosh)}, - {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, - {string(kNameFloorMod), ADPT_DESC(FloorMod)}, - {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)}, - {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)}, - {string(kNameSign), ADPT_DESC(Sign)}, - {string(kNameRound), ADPT_DESC(Round)}, - {string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)}, - {string(kNameDiag), ADPT_DESC(Diag)}, - {string(kNameDiagPart), ADPT_DESC(DiagPart)}, - {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, - {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, - {string(kNameAtan2), ADPT_DESC(Atan2)}, - {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, - {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, - {string(kNameL2Loss), ADPT_DESC(L2Loss)}, - {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, - {string(kNameRange), ADPT_DESC(RangeD)}, - {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, - {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, - {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}, - {string(kNameCase), ADPT_DESC(Case)}}; -#ifdef ENABLE_GE - adpt_map[string(kNamePrint)] = ADPT_DESC(Print); - adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); -#endif - return adpt_map; -} - -// ---------------implement of DfGraphConvertor------------- -PrimType GetCNodeFuncType(const CNodePtr cnode) { - if (cnode->inputs().empty()) { - return kPrimTypeUnknown; - } - - AnfNodePtr valuenode = cnode->input(0); - if (IsValueNode(valuenode)) { - // check whether the valuenode is primitive - return GetValueNode(valuenode)->prim_type(); - } - return kPrimTypeUnknown; -} - -bool IsCaseNode(const CNodePtr node) { - if (!node->inputs().empty() && node->input(0)->isa() && - GetCNodeFuncName(node->input(0)->cast()) == "switch_layer") { - return true; - } - return false; -} - -std::string GetCNodeTargetFuncName(const CNodePtr cnode) { - if (IsCaseNode(cnode)) { - return string(kNameCase); - } - auto name = GetCNodeFuncName(cnode); - if (name == "switch_layer") { - name = ""; - } - return name; -} - -OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { - if (node->isa()) { - auto cnode = node->cast(); - - std::string name = kNameCustomOp; - if (!IsCustomCNode(cnode)) { - name = GetCNodeTargetFuncName(cnode); - } - - auto it_adpt = get_adpt_map().find(name); - if (it_adpt != get_adpt_map().end()) { - return 
it_adpt->second->Get(train); - } - MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; - } - - if (node->isa()) { - return get_adpt_map()[kNameConst]->Get(train); - } - if (node->isa()) { - return get_adpt_map()[kNameParam]->Get(train); - } - return OpAdapterPtr(nullptr); -} - -void DfGraphConvertor::InitLoopVar(std::vector *init_input) { - if (this->training_) { - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); - auto var_iter_num = std::make_shared("npu_runconfig/iterations_per_loop"); - auto var_loop_cond = std::make_shared("npu_runconfig/loop_cond"); - auto var_one = std::make_shared("npu_runconfig/one"); - auto var_zero = std::make_shared("npu_runconfig/zero"); - (void)var_iter_num->update_output_desc_y(desc); - (void)var_loop_cond->update_output_desc_y(desc); - (void)var_one->update_output_desc_y(desc); - (void)var_zero->update_output_desc_y(desc); - vars_["npu_runconfig/iterations_per_loop"] = var_iter_num; - vars_["npu_runconfig/loop_cond"] = var_loop_cond; - vars_["npu_runconfig/one"] = var_one; - vars_["npu_runconfig/zero"] = var_zero; - - int64_t value = 0; - auto const_iter_num = std::make_shared("const/npu_runconfig/iterations_per_loop"); - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - value = ConfigManager::GetInstance().iter_num(); - } else { - MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1"; - value = 1; - ConfigManager::GetInstance().set_iter_num(value); - } - value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1 - (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_loop_cond = std::make_shared("const/npu_runconfig/loop_cond"); - value = 0; - (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_one = std::make_shared("const/npu_runconfig/one"); - value = 1; - (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - auto const_zero = std::make_shared("const/npu_runconfig/zero"); - value = 0; - (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); - - (void)const_iter_num->update_output_desc_y(desc); - (void)const_loop_cond->update_output_desc_y(desc); - (void)const_one->update_output_desc_y(desc); - (void)const_zero->update_output_desc_y(desc); - - auto assign_iter_num = std::make_shared("assign/npu_runconfig/iterations_per_loop"); - (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num); - auto assign_loop_cond = std::make_shared("assign/npu_runconfig/loop_cond"); - (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond); - auto assign_one = std::make_shared("assign/npu_runconfig/one"); - (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one); - auto assign_zero = std::make_shared("assign/npu_runconfig/zero"); - (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero); - - init_input->push_back(*var_iter_num); - init_input->push_back(*var_loop_cond); - init_input->push_back(*var_one); - init_input->push_back(*var_zero); - init_ops_.push_back(var_iter_num); - init_ops_.push_back(var_loop_cond); - init_ops_.push_back(var_one); - init_ops_.push_back(var_zero); - init_ops_.push_back(const_iter_num); - init_ops_.push_back(const_loop_cond); - init_ops_.push_back(const_one); - init_ops_.push_back(const_zero); - init_ops_.push_back(assign_iter_num); - 
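The `InitLoopVar` block above wires each `npu_runconfig` counter as a (Variable, Const, Assign) triple: the init graph assigns the constant's value into the variable once before the compute graph runs. A minimal stand-in sketch of that pattern (the `Op` type below is illustrative, not the real GE operator API):

```cpp
// Sketch only: stand-in types modeling the Variable/Const/Assign triple that
// InitLoopVar builds for every npu_runconfig counter.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

struct Op {            // stand-in for a GE operator
  std::string name;
  int64_t value = 0;   // only meaningful for constants
};

struct InitTriple {
  std::shared_ptr<Op> var, cst, assign;
};

// One counter: a Variable, a Const holding its initial value, and an Assign
// that feeds the Const into the Variable inside the init subgraph.
InitTriple MakeLoopVar(const std::string &name, int64_t init_value) {
  InitTriple t;
  t.var = std::make_shared<Op>(Op{name, 0});
  t.cst = std::make_shared<Op>(Op{"const/" + name, init_value});
  t.assign = std::make_shared<Op>(Op{"assign/" + name, 0});
  return t;
}

int main() {
  // In sink mode the device loops iter_num times; the counter starts at 0,
  // so the stored limit is iter_num - 1 (this mirrors the `value -= 1` above).
  int64_t iter_num = 100;
  auto iters = MakeLoopVar("npu_runconfig/iterations_per_loop", iter_num - 1);
  auto cond = MakeLoopVar("npu_runconfig/loop_cond", 0);
  std::cout << iters.cst->name << " = " << iters.cst->value << "\n"
            << cond.cst->name << " = " << cond.cst->value << "\n";
}
```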
init_ops_.push_back(assign_loop_cond); - init_ops_.push_back(assign_one); - init_ops_.push_back(assign_zero); - } -} - -OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) { - auto it = get_adpt_map().find(name); - if (it != get_adpt_map().end()) { - return it->second->Get(train); - } - MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; -} - -void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) { - // draw init subgraph - init_sout_ << "op_assign" << it.get() << "[label=<"; - init_sout_ << "" << endl; - init_sout_ << ""; - init_sout_ << ""; - init_sout_ << ""; - init_sout_ << "" << endl; - init_sout_ << "" << endl; - init_sout_ << "
<tr><td port='1'>resource</td><td port='2'>value</td></tr>" << endl;
-  init_sout_ << "<tr><td colspan=\"2\">"
-             << "\"assign_" << name << "\"</td></tr>" << endl;
-  init_sout_ << "</table>> shape=plaintext]" << endl;
-  init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl;
-  init_sout_ << "const" << it.get() << "[label= \"" << name << "_const"
-             << "\" shape=ellipse]" << endl;
-  init_sout_ << "param" << it.get() << "->"
-             << "op_assign" << it.get() << ":1" << endl;
-  init_sout_ << "const" << it.get() << "->"
-             << "op_assign" << it.get() << ":2" << endl;
-}
-
-void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<Operator> *init_input) {
-  DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
-  std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
-
-  for (auto &it : nodes) {
-    if (it->isa<ValueNode>()) {
-      if (IsValueNode<SymbolicKeyInstance>(it)) {
-        auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
-        auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
-        auto iter = vars_.find(name);  // get the corresponding variable op
-        if (iter != vars_.end()) {
-          op_cache_[it.get()] = iter->second;
-          // #ifdef DRAW_GE_GRAPH
-          compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
-                        << "[style=\"dotted\"]" << endl;
-          // #endif
-        }
-      } else if (IsValueNode<RefKey>(it)) {
-        auto refkey = GetValueNode<RefKeyPtr>(it);
-        auto name = refkey->tag();
-        auto iter = vars_.find(name);  // get the corresponding variable op
-        if (iter != vars_.end()) {
-          op_cache_[it.get()] = iter->second;
-          compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
-                        << "[style=\"dotted\"]" << endl;
-        }
-      }
-    }
-  }
-
-  for (auto &it : tensors) {
-    if (vars_.find(it.first) == vars_.end()) {
-      MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in the graph.";
-      vars_[it.first] = nullptr;
-    }
-  }
-
-  // set up the init subgraph
-  if (init_input->size()) {
-    // the init subgraph needs no input
-    MS_LOG(INFO) << "Build data init subgraph.";
-    (void)init_graph->SetInputs(*init_input);
-    this->init_graph_ = init_graph;
-  } else {
-    this->init_graph_ = nullptr;
-  }
-}
-
-void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
-  MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
-  if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
-    auto getnext_idx = static_cast<int64_t>(input_idx);
-    DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
-    if (!param.input_indexes().empty() && input_idx < param.input_indexes().size()) {
-      getnext_idx = param.input_indexes()[input_idx] - 1;  // input_idx starts from 0.
-      MS_LOG(INFO) << "Remap input_index: " << input_idx << " to getnext_index: " << getnext_idx << ".";
-    }
-    // use the iterator_getnext op with output_name instead of a data op in BuildGraph.
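`MakeDatasetHandler` binds the parameter to output `y<k>` of the shared `iterator_getnext` operator, as the next statement shows. A self-contained sketch of just the index remapping, assuming 1-based `input_indexes` as the `- 1` implies (note that indexing `input_indexes[input_idx]` requires `input_idx < size()`, hence the `<` bound above):

```cpp
// Sketch of the parameter -> GetNext-output remapping in MakeDatasetHandler.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::string GetNextOutputName(size_t input_idx, const std::vector<int64_t> &input_indexes) {
  int64_t getnext_idx = static_cast<int64_t>(input_idx);
  if (!input_indexes.empty() && input_idx < input_indexes.size()) {
    getnext_idx = input_indexes[input_idx] - 1;  // remap; the indexes are 1-based
  }
  return "y" + std::to_string(getnext_idx);  // the OutHandler output name
}

int main() {
  std::vector<int64_t> remap = {2, 1};               // swap the first two queue slots
  std::cout << GetNextOutputName(0, remap) << "\n";  // y1
  std::cout << GetNextOutputName(1, remap) << "\n";  // y0
  std::cout << GetNextOutputName(0, {}) << "\n";     // y0 (identity mapping)
}
```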
- out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx)); - } -} - -void DfGraphConvertor::SetupBroadcast(const std::shared_ptr &broadcast, - const std::vector &broadcast_desc, - const DfGraphPtr &broadcast_graph, std::vector broadcast_input) { - MS_LOG(INFO) << "build broadcast subgraph"; - if (broadcast_desc.size() != broadcast_input.size()) { - MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input"; - } - (void)broadcast->create_dynamic_input_x(static_cast(broadcast_input.size())); - (void)broadcast->create_dynamic_output_y(static_cast(broadcast_desc.size())); - for (unsigned int i = 0; i < broadcast_input.size(); i++) { - (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]); - (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]); - } - (void)broadcast_graph->SetInputs(broadcast_input); - this->broadcast_graph_ = broadcast_graph; -} - -void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { - int index = 0; - std::vector init_input; - for (auto it : tensors) { - std::string name = it.first; - auto node_itor = params_.find(name); - // if name not in params_, create a node in graph - if (node_itor == params_.end()) { - MS_LOG(WARNING) << name << " is not in params, and create a new node."; - ParameterPtr param = std::make_shared(nullptr); - name = name + "_temp"; - param->set_name(name); - (void)ConvertParameter(param); - node_itor = params_.find(name); - } - auto node = node_itor->second; - auto op_itor = op_cache_.find(node.get()); - if (op_itor == op_cache_.end()) { - MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << "."; - } - auto adpt = FindAdapter(kNameParam, training_); - if (adpt == nullptr) continue; - auto param_op = adpt->generate(name + "_data"); - MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << "."; - - if (!training_) { - auto adpt_const = FindAdapter(kNameConst, training_); - if (adpt_const == nullptr) continue; - auto const_op = adpt_const->generate(name + "_const"); - (void)adpt_const->setAttr(const_op, "value", it.second); - - auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); - if (const_op_desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - (void)std::static_pointer_cast(const_op)->update_output_desc_y(*const_op_desc); - - vars_[name] = const_op; - op_itor->second = const_op; - continue; - } - - // create tensor descriptor for output descriptor - auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - - // we need three variable ops for each graph with same name - // build init subgraph - if (it.second->is_init() == 0) { - (void)std::static_pointer_cast(param_op)->set_attr_index(index++); - auto init_var = std::make_shared(name); - auto assign_op = std::make_shared("assign_" + name); - (void)init_var->update_output_desc_y(*desc); - (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op); - init_input.push_back(*init_var); - init_ops_.push_back(param_op); - init_ops_.push_back(assign_op); - init_ops_.push_back(init_var); - } - - auto variable = std::make_shared(name); - (void)variable->update_output_desc_y(*desc); - // do not use read variable while variable sink - MS_LOG(DEBUG) << "InitParam, op_name = " 
<< name << ", var = " << variable->GetName() << "."; - op_itor->second = variable; // replace parameter with variable - vars_[name] = variable; // prevent the variable operator from being freed - DrawParamInitSubGraph(name, node); - } - InitLoopVar(&init_input); - SetupParamInitSubGraph(tensors, &init_input); -} - -// convert all parameter need initialize to variable -DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) { - size_t input_idx = 0; - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in InitParam."; - return *this; - } - - // Processing input with MakeDatasetHandler - for (auto &it : anf_graph_->parameters()) { - auto op_itor = op_cache_.find(it.get()); // converted node - if (it->isa() && op_itor != op_cache_.end()) { - string name = std::static_pointer_cast(it)->name(); - auto tensor_itor = tensors.find(name); // in init value map - if (tensor_itor == tensors.end()) { - DfGraphConvertor::MakeDatasetHandler(name, input_idx, it); - input_idx++; - } - } - } - InitParamWithData(tensors); - init_sout_ << "}" << endl; - return *this; -} - -#if (defined ENABLE_GE) -void DfGraphConvertor::BuildSaveCheckpointGraph() { - std::vector graph_inputs; - ge::op::Save save_op("save_parms"); - int save_op_is_active = 0; - size_t index = 0; - string name; - - int32_t count_size = std::count_if(vars_.begin(), vars_.end(), [](const std::pair &it) { - return (it.second == nullptr || it.first.find("/") != std::string::npos); - }); - - (void)save_op.create_dynamic_input_tensors(vars_.size() - static_cast(count_size)); - - // for each "parameter" in anf graph excluding "input" - for (const auto &it : vars_) { - name = it.first; - if (it.second == nullptr || name.find("/") != std::string::npos) continue; - Variable variable(name); - (void)variable.update_output_desc_y(it.second->GetOutputDesc(0)); - (void)save_op.set_dynamic_input_tensors(index++, variable); - - graph_inputs.push_back(variable); - - if (save_op_is_active == 0) { - checkpoint_sout_ << "op_save" << &save_op << "[label=<"; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "" << endl; - checkpoint_sout_ << "
<tr><td port='1'>tensor</td></tr>" << endl;
-      checkpoint_sout_ << "<tr><td colspan=\"1\">"
-                       << "\"saveop"
-                       << "\"</td></tr>" << endl;
-      checkpoint_sout_ << "</table>
> shape=plaintext]" << endl; - } - - checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl; - - checkpoint_sout_ << "param" << it.second << "->" - << "op_save" << &save_op << ":1" << endl; - save_op_is_active = 1; - } - if (save_op_is_active) { - std::vector graph_output; - graph_output.emplace_back(save_op); - DfGraphPtr checkpoint_graph = std::make_shared("checkpoint"); - (void)checkpoint_graph->SetInputs(graph_inputs); - (void)checkpoint_graph->SetOutputs(graph_output); - this->save_ckp_graph_ = checkpoint_graph; - } else { - this->save_ckp_graph_ = nullptr; - } - - checkpoint_sout_ << "}" << endl; - return; -} -#endif - -DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) { - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph"; - return *this; - } - - DfGraphPtr broadcast_graph = std::make_shared("broadcast"); - // collect the operators create for broadcast sub graph, in order to avoid auto release - std::vector broadcast_input; - std::vector broadcast_desc; - auto broadcast = std::make_shared("broadcast_parameter"); - (void)broadcast->set_attr_root_rank(0); - (void)broadcast->set_attr_group("hccl_world_group"); - broadcast_ops_.push_back(broadcast); - - // find every parameter, build broadcast subgraph (or initialize the parameter with constant) - for (auto &it : anf_graph_->parameters()) { - auto op_itor = op_cache_.find(it.get()); // converted node - if (it->isa() && op_itor != op_cache_.end()) { - string name = std::static_pointer_cast(it)->name(); - auto tensor_itor = tensors.find(name); // in init tensor map - if (tensor_itor != tensors.end()) { - auto tensor = tensor_itor->second; - auto shape_ge = tensor->shape_c(); - - // create tensor descriptor for output descriptor - auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; - continue; - } - - // build broadcast subgraph - if (distribute_) { - auto broadcast_var = std::make_shared(name); - (void)broadcast_var->update_output_desc_y(*desc); - broadcast_input.push_back(*broadcast_var); - broadcast_desc.push_back(*desc); - broadcast_ops_.push_back(broadcast_var); - } - } - } - } - - // set up broadcast sub graph - if (!broadcast_input.empty()) { - DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input); - } else { - this->broadcast_graph_ = nullptr; - } - return *this; -} - -DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() { - if (error_ != 0) { - MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << "."; - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - error_ = INVALID_ARGUMENT; - MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph"; - return *this; - } -#if (defined ENABLE_GE) - BuildSaveCheckpointGraph(); - // Restoring from checkpoint file is done by pyfront, not in graph now. 
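For reference, `BuildSaveCheckpointGraph` above sizes the `Save` op's dynamic input by counting only checkpointable entries in `vars_`: null operators and internal names containing `/` are skipped. A minimal sketch of that filter (`OpPtr` is a stand-in for the real `OperatorPtr`):

```cpp
// Sketch of the eligibility filter applied to vars_ before building the
// checkpoint graph: null operators and names with '/' are not saved.
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>

using OpPtr = std::shared_ptr<int>;  // stand-in for OperatorPtr

size_t CountSavable(const std::map<std::string, OpPtr> &vars) {
  size_t n = 0;
  for (const auto &kv : vars) {
    if (kv.second != nullptr && kv.first.find('/') == std::string::npos) ++n;
  }
  return n;
}

int main() {
  std::map<std::string, OpPtr> vars = {
      {"fc.weight", std::make_shared<int>(0)},
      {"npu_runconfig/loop_cond", std::make_shared<int>(0)},  // internal, skipped
      {"unused", nullptr},                                    // never converted, skipped
  };
  std::cout << CountSavable(vars) << "\n";  // 1
}
```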
-#endif - return *this; -} - -DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { - if (error_ != 0) { - return *this; - } - if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { - MS_LOG(ERROR) << "Invalid AnfGraph"; - error_ = FAILED; - return *this; - } - - compute_sout_.clear(); - compute_sout_ << "digraph {" << endl; - init_sout_.clear(); - init_sout_ << "digraph {" << endl; - checkpoint_sout_.clear(); - checkpoint_sout_ << "digraph {" << endl; - restore_checkpoint_sout_.clear(); - restore_checkpoint_sout_ << "digraph {" << endl; - - // Convert all anf node to Operator - MS_LOG(DEBUG) << "convert all node"; - std::vector nodes = TopoSort(anf_graph_->get_return()); - for (auto &it : nodes) { - (void)Convert(it); - if (this->error_ != 0) { - MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << "."; - } - } - - // Create dataset iterator and iterator_getnext node - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); - MS_LOG(INFO) << "Dataset param is " << param.ToString() << "."; - // GetNext - auto iter_getnext_op = make_shared("get_next_tmp"); - (void)iter_getnext_op->set_attr_output_types(param.ge_types()); - (void)iter_getnext_op->set_attr_output_shapes(param.shapes()); - (void)iter_getnext_op->set_attr_channel_name(param.queue_name()); - - // save iter_getnext_op for later use - dataset_iter_getnext_ = iter_getnext_op; - } - - // return the data flow graph - return *this; -} - -void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) { - auto it = out_handle_cache_.find(anf_out.get()); - if (it != out_handle_cache_.end()) { - OutHandler handle = it->second; - auto op = handle.op; - if (op != nullptr) { - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; - graph_outputs_.emplace_back(std::make_pair(*op, handle.out)); - } else { - MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted"; - } - } else { - // invalid tuple_getitem e.g. 
tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple()) - MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope(); - } -} - -void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { - AnfNodePtr anf_out = node; - AnfNodePtr pre_node = nullptr; - - // trace Parameter node - TraceOutputFromParameter(anf_out); - // then trace cnode - if (!node->isa()) { - return; - } - - // trace tuple_getitem - while (anf_out->isa() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) { - pre_node = anf_out; - anf_out = anf_out->cast()->input(1); - } - // trace every element of make_tuple - auto c = anf_out->cast(); - std::string name = ""; - if (anf_out->isa()) { - name = GetCNodeTargetFuncName(c); - } - - if (name == "make_tuple") { - for (unsigned int i = 1; i < c->inputs().size(); i++) { - TraceOutput(c->input(i)); - } - } else if (name == "Depend") { - if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3"; - } - TraceOutput(c->input(1)); - } else if (name == "tuple_getitem") { - TraceOutputFromTupleGetItem(anf_out); - } else { - // add outputs; - auto op = Convert(anf_out); - std::string index; - if (op != nullptr) { - if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) { - auto item = out_handle_cache_.find(pre_node.get()); - if (item != out_handle_cache_.end()) { - index = item->second.out; - } else { - MS_LOG(WARNING) << "Can't get operater: " << anf_out->fullname_with_scope() << " 's output item"; - } - } - MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index; - graph_outputs_.emplace_back(make_pair(*op, index)); - } - } -} - -void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) { - if (anf_out->isa()) { - MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope(); - auto it = out_handle_cache_.find(anf_out.get()); - if (it != out_handle_cache_.end()) { - // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler. 
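`TraceOutput` above and `TraceRealOp` further below share one idea: peel wrapper nodes until a real operator is reached, with `tuple_getitem` pushing its index and `make_tuple` consuming it. A toy, self-contained version of that unwinding (the node kinds and the inline `index` field are simplifications; the real code reads the index from a ValueNode input):

```cpp
// Toy sketch of the wrapper unwinding: tuple_getitem pushes its index,
// make_tuple selects with the top index, Depend forwards its first input.
#include <iostream>
#include <memory>
#include <stack>
#include <string>
#include <vector>

struct Node {
  std::string kind;                           // e.g. "op:A", "tuple_getitem", "make_tuple", "Depend"
  std::vector<std::shared_ptr<Node>> inputs;
  size_t index = 0;                           // used only when kind == "tuple_getitem"
};
using NodePtr = std::shared_ptr<Node>;

NodePtr TraceReal(NodePtr n) {
  std::stack<size_t> idx;
  while (true) {
    if (n->kind == "tuple_getitem") {
      idx.push(n->index);
      n = n->inputs[0];
    } else if (n->kind == "make_tuple" && !idx.empty()) {
      n = n->inputs[idx.top()];
      idx.pop();
    } else if (n->kind == "Depend") {
      n = n->inputs[0];
    } else {
      return n;  // a real operator; leftover indexes would select one of its outputs
    }
  }
}

int main() {
  auto a = std::make_shared<Node>(Node{"op:A", {}, 0});
  auto b = std::make_shared<Node>(Node{"op:B", {}, 0});
  auto tup = std::make_shared<Node>(Node{"make_tuple", {a, b}, 0});
  auto get1 = std::make_shared<Node>(Node{"tuple_getitem", {tup}, 1});
  auto dep = std::make_shared<Node>(Node{"Depend", {get1}, 0});
  std::cout << TraceReal(dep)->kind << "\n";  // op:B
}
```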
- OutHandler handle = it->second; - auto op = handle.op; - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; - graph_outputs_.emplace_back(make_pair(*op, handle.out)); - } else { - // common parameter case - auto op = Convert(anf_out); - if (op != nullptr) { - MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType(); - graph_outputs_.emplace_back(std::make_pair(*op, "")); - } - } - } -} - -void SetupDatasetIterGetNextNode(const OperatorPtr &op) { - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); - size_t output_num = param.ge_types().size(); - MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << "."; - // set iterator_getnext op's output num - shared_ptr iter_getnext = std::static_pointer_cast(op); - (void)iter_getnext->create_dynamic_output_y(static_cast(output_num)); - - for (uint32_t i = 0; i < output_num; i++) { - ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]); - // we don't SetRealDimCnt here since GE do not use this output's real-dim - (void)iter_getnext->update_dynamic_output_desc_y((i), desc); - } - } - return; -} - -void DfGraphConvertor::SetSubgraph(AnfNodePtr node) { - if (!node->isa()) { - return; - } - auto cnode = node->cast(); - if (!IsCaseNode(cnode)) { - return; - } - std::vector case_inputs; - for (size_t i = 1; i < cnode->inputs().size(); i++) { - case_inputs.emplace_back(cnode->input(i)); - } - std::shared_ptr> branches = std::make_shared>(); - auto bnode = cnode->input(0)->cast()->input(2)->cast(); - - for (size_t i = 1; i < bnode->inputs().size(); i++) { - auto branch_node = bnode->input(i)->cast(); - for (size_t j = 2; j < branch_node->inputs().size(); j++) { - if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { - case_inputs.emplace_back(branch_node->input(j)); - } - } - } - - for (size_t i = 1; i < bnode->inputs().size(); i++) { - ProcessSubgraph(bnode->input(i), case_inputs); - } - - for (size_t i = 1; i < bnode->inputs().size(); i++) { - branches->emplace_back(branches_map_[bnode->input(i).get()]); - } - - if (op_cache_.find(node.get()) == op_cache_.end()) { - return; - } - - OpAdapterPtr adpt = FindAdapter(node, training_); - if (nullptr == adpt) { - MS_LOG(DEBUG) << "Not found adapter"; - return; - } - - OperatorPtr op = Convert(node); - adpt->setSubgraph(op, 0, branches); - return; -} - -void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) { - std::vector case_inputs; - for (size_t i = 1; i < node->inputs().size(); i++) { - case_inputs.emplace_back(node->input(i)); - } - std::shared_ptr> branches = std::make_shared>(); - auto bnode = input_node->input(2)->cast(); - - for (size_t i = 1; i < bnode->inputs().size(); i++) { - auto branch_node = bnode->input(i)->cast(); - for (size_t j = 2; j < branch_node->inputs().size(); j++) { - if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { - case_inputs.emplace_back(branch_node->input(j)); - } - } - } - - const size_t case_index = 1; - const size_t make_tuple_index = 2; - - AnfNodePtr case_index_iter = input_node->input(case_index); - AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index); - auto make_tuple_node = make_tuple_iter->cast(); - std::shared_ptr> tuple_items = std::make_shared>(); - - for (size_t i = 0; i < 
case_inputs.size(); i++) { - auto item = case_inputs[i]; - auto op = Convert(item); - if (op != nullptr) { - tuple_items->emplace_back(OutHandler(op, "")); - } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { - tuple_items->push_back(out_handle_cache_[item.get()]); - } else { - MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); - continue; - } - } - - tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items; - - std::shared_ptr> case_input_items = std::make_shared>(); - case_input_items->emplace_back(case_index_iter); - case_input_items->emplace_back(make_tuple_iter); - case_input_handle_cache_[node.get()] = case_input_items; -} - -DfGraphConvertor &DfGraphConvertor::BuildGraph() { - SetupDatasetIterGetNextNode(dataset_iter_getnext_); - - if (error_ != 0) { - return *this; - } - - // Case node set input. - std::vector nodes = ::mindspore::TopoSort(anf_graph_->get_return()); - for (auto &it : nodes) { - if (it->isa() && IsCaseNode(it->cast())) { - auto node = it->cast(); - auto input_node = node->input(0)->cast(); - GetCaseNodeInput(node, input_node); - } - } - - // update tuple_out_handle_cache_ - for (auto it : tuple_out_handle_cache_) { - std::size_t len = it.second->size(); - for (std::size_t i = 0; i < len; i++) { - OutHandler handle = (*it.second)[i]; - if (handle.op) { - string name = handle.op->GetName(); - if (vars_.count(name)) { - OperatorPtr new_op = vars_[name]; - if (new_op != nullptr) { - MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; - (*it.second)[i] = OutHandler(new_op, handle.out); - } - } - } - } - } - - // set up dependices - MS_LOG(DEBUG) << "set up dependices"; - nodes = ::mindspore::TopoSort(anf_graph_->get_return()); - for (auto &it : nodes) { - SetNodeInput(it); - SetOpControlInput(it); - SetSubgraph(it); - UpdateOpDesc(it); - } - - if (error_ == 0) { - df_graph_ = make_shared(anf_graph_->ToString()); - } else { - return *this; - } - - // set graph input according to the order from anf graph - std::vector inputs; - if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { - inputs.push_back(*dataset_iter_getnext_); - } else { - auto params = anf_graph_->parameters(); - if (use_inputs_) { - params = inputs_; - auto anf_params = anf_graph_->parameters(); - for (size_t i = 0; i < params.size(); i++) { - for (size_t j = 0; j < anf_params.size(); j++) { - if (params[i]->ToString() == anf_params[j]->ToString()) { - params[i] = anf_params[j]; - } - } - } - } - - int index = 0; - for (auto &it : params) { - auto name = std::static_pointer_cast(it)->name(); - // the parameters which has not been converted to var - if (vars_.find(name) == vars_.end()) { - auto op = Convert(it); - MS_EXCEPTION_IF_NULL(op); - MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index; - if (op == nullptr) { - MS_LOG(ERROR) << "Convert graph failed!"; - return *this; - } - UpdateDataOpDesc(it, op); - - MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index; - (void)std::static_pointer_cast(op)->set_attr_index(index++); - inputs.push_back(*op); - } else if (vars_[name] != nullptr) { - MS_LOG(INFO) << "add var input " << it->ToString(); - auto op = Convert(it); - MS_EXCEPTION_IF_NULL(op); - inputs.push_back(*op); - } - } - } - - // Add const nodes as graph input for some operator work with constant - std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), - [](OperatorPtr x) { return *x; }); - - MS_LOG(INFO) << "set graph input num: 
" << inputs.size(); - (void)df_graph_->SetInputs(inputs); - - // set graph output - // set the value of finale return apply node as the output of dataflow graph - MS_LOG(DEBUG) << "set output"; - graph_outputs_.clear(); - TraceOutput(anf_graph_->get_return()->input(1)); - MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); - (void)df_graph_->SetOutputs(graph_outputs_); - - compute_sout_ << "}" << endl; - // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag. - if (ConfigManager::GetInstance().iter_num() > 1) { - df_graph_->SetNeedIteration(true); - } - return *this; -} - -void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const { - auto node = std::static_pointer_cast(it); - if (node == nullptr) { - MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node."; - return; - } - auto normal_shape_ptr = dyn_cast(node->Shape()); - vector shape; - if (normal_shape_ptr == nullptr) { - MS_LOG(INFO) << "Invalid shape to update data op descriptor."; - return; - } - shape = normal_shape_ptr->shape(); - if (node->Type() == nullptr) { - MS_LOG(INFO) << "Invalid type to update data op descriptor."; - return; - } - TypeId me_type = node->Type()->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(node->Type())->element()->type_id(); - } - std::ostringstream buf; - buf << "[" << shape << "]"; - MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type; - auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; - } else { - (void)std::static_pointer_cast(op)->update_input_desc_x(*desc); - (void)std::static_pointer_cast(op)->update_output_desc_y(*desc); - } -} - -DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; } - -DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; } - -DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; } - -DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; } - -void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { - if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { - return; - } - - std::vector control_edges = control_depend_cache_[node.get()]; - if ((control_edges.empty())) { - MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; - return; - } - - for (auto &item : control_edges) { - (void)item.dest_op->AddControlInput(*item.src_op); - } -} - -const std::vector trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)}; - -void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { - OperatorPtr src = Convert(node); - int case_flag = 0; - auto &inputs = node->inputs(); - size_t input_size = inputs.size(); - if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) { - case_flag = 1; - input_size = case_input_handle_cache_[node.get()]->size() + 1; - } - - for (size_t i = 1; i < input_size; i++) { - auto pred = inputs[i]; - if (case_flag != 0) { - pred = case_input_handle_cache_[node.get()]->at(i - 1); - } - - while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "Depend") { - pred = pred->cast()->input(1); - } - // skip the None input - if (IsValueNode(pred)) { - continue; - } - // transform "Const" op to "Variable" op when the next node is "Assign" op. 
- std::string c_name = GetCNodeTargetFuncName(node); - auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); - if (!training_ && pos != trans_var_list.end() && pred->isa()) { - std::string name = std::static_pointer_cast(pred)->name(); - auto op_itor = op_cache_.find(pred.get()); - if (op_itor == op_cache_.end()) { - MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; - } - if (op_itor->second != nullptr && - (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && - vars_.find(name) != vars_.end()) { - auto variable = std::make_shared(name); - auto desc = vars_[name]->GetOutputDesc("y"); - (void)variable->update_output_desc_y(desc); - MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; - op_itor->second = variable; // replace parameter with variable - vars_[name] = variable; - } - } - // find in out_hadnle_cache_ first - auto it = out_handle_cache_.find(pred.get()); - if (it != out_handle_cache_.end()) { - int ret = adpt->setInput(src, SizeToInt(i), it->second); - if (ret == 0) { - if (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "tuple_getitem") { - compute_sout_ << op_draw_name_[pred->cast()->input(1).get()] << " -> " << op_draw_name_[node.get()] - << ":" << i << endl; - } else if (pred->isa()) { - compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; - } else { - // don't draw anything. - MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case."; - } - AddGraphConstInput(it->second.op); - } - } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) { - std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; - int ret = adpt->setInput(src, SizeToInt(i), handler_vec); - if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { - for (unsigned int j = 0; j < handler_vec->size(); j++) { - compute_sout_ << op_draw_name_[pred->cast()->input(j + 1).get()] << " -> " - << op_draw_name_[node.get()] << ":" << i << endl; - AddGraphConstInput(handler_vec->at(j).op); - } - } else { - MS_LOG(WARNING) << "Convert tuple node setInput failed : " << node->ToString(); - } - } else { - auto op = Convert(pred); - int ret = adpt->setInput(src, SizeToInt(i), op); - if (ret == 0) { - compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; - AddGraphConstInput(op); - } - } - } -} - -void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { - if (op->GetOpType() == "Constant") { - graph_const_inputs_.push_back(op); - } -} - -void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { - if (!node->isa()) { - return; - } - if (op_cache_.find(node.get()) == op_cache_.end()) { - return; - } - auto cnode = node->cast(); - OpAdapterPtr adpt = FindAdapter(cnode, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return; - } - - // get Operator from op_cache_, use adapter to set Inputs - DfGraphConvertor::SetOpInput(adpt, cnode); -} - -void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector &inputs) { - if (!node->isa() || GetCNodeFuncName(node->cast()) != "Partial") { - return; - } - auto graph_node = node->cast()->input(1)->cast(); - FuncGraphPtr anf_graph = graph_node->value()->cast(); - DfGraphConvertor convertor(anf_graph); - convertor.use_inputs_ = true; - convertor.inputs_ = inputs; - (void)convertor.ConvertAllNode().BuildGraph(); - std::string name = graph_node->ToString() + 
"_ge_graph.dot"; - if (MsContext::GetInstance()->save_graphs_flag()) { - convertor.DrawComputeGraph(name); - } - branches_map_[node.get()] = *(convertor.df_graph_); -} - -// Update GE op's shape and type info -void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { - if (nullptr == node || !node->isa()) { - return; - } - - if (op_cache_.find(node.get()) == op_cache_.end()) { - return; - } - - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return; - } - - // get Operator from op_cache_ - OperatorPtr op = Convert(node); - - adpt->updateOutputDesc(op, node->Shape(), node->Type(), node); -} - -OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { - if (node == nullptr) { - MS_LOG(ERROR) << "node is nullptr"; - error_ = NOT_FOUND; - return nullptr; - } - // find in cache - if (op_cache_.count(node.get())) { - return op_cache_[node.get()]; - } - - // do not convert primitive node - if (IsValueNode(node)) { - return nullptr; - } - - // convert a new one - if (node->isa()) { - return ConvertCNode(node->cast()); - } - if (node->isa()) { - return ConvertParameter(node); - } - if (node->isa()) { - return ConvertValueNode(node->cast()); - } - - MS_LOG(ERROR) << "Invalide AnfNode"; - error_ = INVALID_ARGUMENT; - return nullptr; -} - -void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { - std::shared_ptr> tuple_items = std::make_shared>(); - // convert each tuple item to a OutHandler - for (size_t i = 1; i < node->inputs().size(); i++) { - AnfNodePtr item = node->input(i); - OperatorPtr op = Convert(item); - if (op != nullptr) { - tuple_items->emplace_back(OutHandler(op, "")); - } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { - tuple_items->push_back(out_handle_cache_[item.get()]); - } else { - MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << item->ToString(); - return; - } - } - - MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size(); - tuple_out_handle_cache_[node.get()] = tuple_items; -} - -AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned int *index) { - const int TUPLE_GET_ITEM_INDEX = 2; - if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3"; - } - auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX]; - if (!index_node->isa()) { - error_ = INVALID_ARGUMENT; - MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; - } - *index = IntToUint(GetValue(GetValueNode(index_node))); - return node->inputs()[1]; -} - -AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) { - auto cnode = node->cast(); - if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs - MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3"; - } - return cnode->inputs()[1]; -} - -AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, unsigned int index) { - if (index + 1 >= node->inputs().size()) { - MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index; - } - return node->inputs()[index + 1]; -} - -OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack &index_stack, - AnfNode *const draw_index) { - if (node == nullptr) { - MS_LOG(ERROR) << "Get nullptr while trace real op"; - return OutHandler(nullptr, ""); - } - std::ostringstream ss; - ss << "op" << node.get(); - if (index_stack.empty()) { - op_draw_name_[draw_index] = ss.str(); - return 
OutHandler(Convert(node), ""); - } else { - OpAdapterPtr adpt = FindAdapter(node, training_); - if (nullptr == adpt) { - MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!"; - error_ = NOT_FOUND; - return OutHandler(nullptr, ""); - } - OperatorPtr op = Convert(node); - if (op == nullptr) { - error_ = NOT_FOUND; - MS_LOG(ERROR) << "Can not convert node for trace real op"; - return OutHandler(nullptr, ""); - } - op_draw_name_[draw_index] = ss.str(); - return adpt->getOutput(Convert(node), UintToInt(index_stack.top())); - } -} - -// get the real operator through maketuple tuple_getitem depend -OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) { - bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) || - IsPrimitiveCNode(node, prim::kPrimDepend); - std::stack index_stack; - auto draw_index = node.get(); - while (flag) { - flag = false; - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - unsigned int index; - node = TraceTupleGetItem(node->cast(), &index); - index_stack.push(index); - flag = true; - } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - if (index_stack.empty()) { - MS_LOG(ERROR) << "TraceRealOp find a make_tuple node"; - return OutHandler(nullptr, ""); - } else { - node = TraceMakeTuple(node->cast(), index_stack.top()); - index_stack.pop(); - flag = true; - } - } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - node = TraceDepend(node->cast()); - flag = true; - } - } - return GetHandler(node, index_stack, draw_index); -} - -void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) { - auto handle = TraceRealOp(node); - if (handle.op == nullptr) { - MS_LOG(ERROR) << "Failed to trace tuple get item"; - return; - } - out_handle_cache_[node.get()] = handle; -} - -// Get the real op for tuple_getitem through make tuple, or depend -AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) { - const int TUPLE_GET_ITEM_INDEX = 2; - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - auto node_inputs = node->cast()->inputs(); - if (node_inputs.size() != 3) { // "tuple_getitem" primitive must have 3 inputs - MS_LOG(ERROR) << "tuple get item node not correct!"; - error_ = FAILED; - return node; - } - MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]); - if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa()) { - error_ = INVALID_ARGUMENT; - MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; - } - auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast(); - if (value_ptr == nullptr) { - MS_LOG(ERROR) << "Can not convert get item as value is nullptr!"; - error_ = FAILED; - return node; - } - int index = value_ptr->value(); - - // make_tuple apply inputs:make_tuple, [tuple_items,] - if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) { - auto tuple_inputs = node->cast()->inputs(); - if (tuple_inputs.size() < IntToSize(index + 1)) { - MS_LOG(ERROR) << "make tuple input items node not correct! 
size:" << tuple_inputs.size() - << ", item index:" << index; - error_ = FAILED; - return node; - } - return GetRealOpNode(tuple_inputs[IntToSize(index + 1)]); - } - return GetRealOpNode(node_inputs[1]); - } - - // depend apply inputs: depend,output,depended_node - if (IsPrimitiveCNode(node, prim::kPrimDepend)) { - auto depend_inputs = node->cast()->inputs(); - if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs - MS_LOG(ERROR) << "depend input items not correct"; - error_ = FAILED; - return node; - } - return GetRealOpNode(depend_inputs[1]); - } - return node; -} - -// convert the anf node to corresponding operator list -std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) { - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - std::vector op_lists; - auto node_inputs = node->cast()->inputs(); - for (size_t index = 1; index < node_inputs.size(); index++) { - auto op = Convert(GetRealOpNode(node_inputs[index])); - if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; - error_ = FAILED; - return std::vector({}); - } - op_lists.push_back(op); - } - return op_lists; - } - - auto op = Convert(GetRealOpNode(node)); - if (op == nullptr) { - MS_LOG(ERROR) << "Convert control depend node to operator failed"; - error_ = FAILED; - return std::vector({}); - } - return std::vector({op}); -} - -// get the anf node list for depend -std::vector DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { - std::vector nodes; - // for make tuple, should control depend on the tuple items - if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - auto node_inputs = node->cast()->inputs(); - for (size_t index = 1; index < node_inputs.size(); index++) { - nodes.push_back(GetRealOpNode(node_inputs[index])); - } - return nodes; - } - - // for parameter ,find the apply that used the parameter as the control depended node - if (node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { - nodes.push_back(GetRealOpNode(use_node)); - } - } - return nodes; - } - nodes.push_back(GetRealOpNode(node)); - return nodes; -} - -void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { -#ifdef DRAW_GE_GRAPH - auto src_depend_nodes = GetDependNodes(src_node); - auto dst_depend_nodes = GetDependNodes(dest_node); - if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { - for (auto &item : dst_depend_nodes) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { - for (auto &item : src_depend_nodes) { - compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } - } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { - compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] - << "[style=\"dotted\"]" << endl; - } -#endif -} - -void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, - const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - if (src_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[src_node]; - for (auto &use : 
uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } - - if (dest_node->isa()) { - auto uses = node->func_graph()->manager()->node_users()[dest_node]; - for (auto &use : uses) { - auto use_node = use.first; - if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && - (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { - auto converted_list = ConvertDependNode(use_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - } - } -} - -bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list) { - const int CONTROL_DEPEND_INDEX = 0; - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - const int DEPEND_MODE_NORMAL_USE = 0; - const int DEPEND_MODE_ON_PARAMETER_USE = 1; - - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return false; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return false; - } - AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; - PrimitivePtr prim_ptr = GetValueNode(fn); - ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); - int depend_mode = DEPEND_MODE_NORMAL_USE; - if (mode_ptr != nullptr) { - auto mode_int = mode_ptr->cast(); - MS_EXCEPTION_IF_NULL(mode_int); - depend_mode = mode_int->value(); - MS_LOG(DEBUG) << "depend_mode = " << depend_mode; - } - if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { - GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); - } - - if (src_node->isa()) { - auto converted_list = ConvertDependNode(src_node); - src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); - } - - if (dest_node->isa()) { - auto converted_list = ConvertDependNode(dest_node); - dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); - } - if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; - error_ = SUCCESS; - } - return true; -} - -void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { - const int SRC_NODE_INDEX = 1; - const int DEST_NODE_INDEX = 2; - if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { - return; - } - auto node_inputs = node->inputs(); - if (node_inputs.size() <= DEST_NODE_INDEX) { - MS_LOG(WARNING) << "Control depend node input size error"; - return; - } - auto src_node = node_inputs[SRC_NODE_INDEX]; - auto dest_node = node_inputs[DEST_NODE_INDEX]; - if ((src_node == nullptr) || (dest_node == nullptr)) { - MS_LOG(ERROR) << "Control depend node miss src or dest node"; - error_ = FAILED; - return; - } - std::shared_ptr> src_ops_list = std::make_shared>(); - std::shared_ptr> dst_ops_list = std::make_shared>(); - if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { - MS_LOG(ERROR) << "Get depend list failed"; - error_ = FAILED; - return; - } - std::vector control_edges; - if 
(src_ops_list->size() == 1 && dst_ops_list->size() > 1) { - (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), - [src_ops_list](const OperatorPtr &op) -> ControlEdge { - return {(*src_ops_list)[0], op}; - }); - } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { - (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), - [dst_ops_list](const OperatorPtr &op) -> ControlEdge { - return {op, (*dst_ops_list)[0]}; - }); - } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { - control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); - } else if (src_ops_list->empty() || dst_ops_list->empty()) { - MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; - } else { - MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() - << " -> dst:" << dst_ops_list->size(); - error_ = FAILED; - return; - } - control_depend_cache_[node.get()] = control_edges; - -#ifdef DRAW_GE_GRAPH - DrawControlDepend(src_node, dest_node); -#endif -} - -bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { - // ignore apply node of return - if (name == "return" || name == "Depend") { - return false; - } - - if (name == "" && GetCNodeFuncName(node) == "switch_layer") { - return false; - } - - if (name == "Partial") { - return false; - } - - // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers - if (name == "make_tuple") { - ConvertMakeTuple(node); - return false; - } - - // As for nodes with multi outputs, convert tuple_getitem to OutHandle - if (name == "tuple_getitem") { - ConvertTupleGetItem(node); - return false; - } - - if (name == "ControlDepend") { - ConvertControlDependNode(node); - return false; - } - - return true; -} - -OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { - std::string name = GetCNodeTargetFuncName(node); - if (!CheckCNode(name, node)) { - return nullptr; - } - - // get corresponding OpAdapter - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return nullptr; - } - - // get operator - OperatorPtr op = nullptr; - auto it_op = op_cache_.find(node.get()); - if (it_op != op_cache_.end()) { - op = it_op->second; - } else { - op = adpt->generate(node); - } - - // set attribute for primitive - (void)adpt->setAttr(op, node); - - // add into cache - (void)op_cache_.insert(std::make_pair(node.get(), op)); - - DrawCNode(node, adpt); - - return op_cache_[node.get()]; -} - -OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) { - // convert Parameter in ANF to variable in DataFlow - auto op = FindAdapter(node, training_)->generate(node); - op_cache_[node.get()] = op; - - // build index for parameter using name - std::string name = std::static_pointer_cast(node)->name(); - params_[name] = node; - - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl; - return op_cache_[node.get()]; -} - -Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) { - MS_EXCEPTION_IF_NULL(node); - ValuePtr value = node->value(); - MS_EXCEPTION_IF_NULL(value); - if (!value->isa() && !value->isa()) { - return FAILED; - } - - auto vec = value->isa() ? 
value->cast()->value() : value->cast()->value(); - if (vec.empty()) { - return FAILED; - } - - std::shared_ptr> tuple_items = std::make_shared>(); - for (size_t i = 0; i < vec.size(); i++) { - MS_EXCEPTION_IF_NULL(vec[i]); - if (vec[i]->isa()) { - GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast(), kOpFormat_NCHW); - auto const_op = std::make_shared(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i)); - (void)const_op->set_attr_value(*ge_tensor); - (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc()); - tuple_items->emplace_back(OutHandler(const_op, "")); - } else { - return FAILED; - } - } - if (tuple_items->empty()) { - return FAILED; - } - - tuple_out_handle_cache_[node.get()] = tuple_items; - return SUCCESS; -} - -OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { - // convert valuenode in ANF to Const in DataFlow - // find paramerte referenced by SymbolicKeyInstance of valuenode - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl; - - if (TryConvertValueNodeToMultiConst(node) == SUCCESS) { - MS_LOG(INFO) << "Convert value node to multi Constant OP success"; - return nullptr; - } - - OpAdapterPtr adpt = FindAdapter(node, training_); - if (adpt == nullptr) { - error_ = NOT_FOUND; - return nullptr; - } - auto op = adpt->generate(node); - // set const's attrs - if (adpt->setAttr(op, "value", node->value()) != 0) { - MS_LOG(WARNING) << "set attr value for const failed"; - } - -#if (defined ENABLE_GE) - auto const_op = std::static_pointer_cast(op); - if (const_op == nullptr) { - MS_LOG(ERROR) << "Get Constant operator failed"; - return nullptr; - } - auto ge_tensor = const_op->get_attr_value(); - auto ge_desc = ge_tensor.GetTensorDesc(); - (void)const_op->update_output_desc_y(ge_desc); -#endif - - op_cache_[node.get()] = op; - return op_cache_[node.get()]; -} - -void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { - if (nullptr == adpt || nullptr == node) { - MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!"; - return; - } - std::ostringstream ss; - ss << "op" << node.get(); - op_draw_name_[node.get()] = ss.str(); - - compute_sout_ << ss.str() << "[label=<"; - compute_sout_ << "" << endl; - - auto input_map = adpt->getInputMap(); - auto dyn_input_map = adpt->getDynInputMap(); - if (input_map.size() + dyn_input_map.size() > 0) { - compute_sout_ << ""; - for (auto &it : input_map) { - compute_sout_ << ""; - } - for (auto &it : dyn_input_map) { - compute_sout_ << ""; - } - compute_sout_ << "" << endl; - } - - compute_sout_ << "" << endl; - - // print attrs' values - auto atts = adpt->GetAttrsFromDrawGraph(); - for (auto &it : atts) { - compute_sout_ << ""; - } - - adpt->clearAttrVect(); - - compute_sout_ << "
" << it.second.name << "" << it.second.name << "
\"" << node->ToString() - << ":" << GetCNodeTargetFuncName(node) << "\"
\"" << it - << "\"
> shape=plaintext]" << endl; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h deleted file mode 100644 index cca0371c2e..0000000000 --- a/mindspore/ccsrc/transform/convert.h +++ /dev/null @@ -1,258 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ -#define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ - -#define DRAW_GE_GRAPH - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ir/anf.h" -#include "ir/func_graph.h" -#include "transform/util.h" -#include "ir/tensor.h" -#include "transform/df_graph_manager.h" -#include "utils/config_manager.h" -#include "transform/op_declare.h" -#include "graph/operator_reg.h" -#ifdef OPEN_SOURCE -#include "ge/client/ge_api.h" -#else -#include "external/ge/ge_api.h" -#endif -#include "graph/tensor.h" -#include "ops/all_ops.h" - -namespace mindspore { -namespace transform { -class OpAdapterDesc { - public: - OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} - - OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} - - explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} - - OpAdapterDesc(const OpAdapterDesc &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - } - - OpAdapterDesc(OpAdapterDesc &&desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - desc.train_ = nullptr; - desc.infer_ = nullptr; - } - - ~OpAdapterDesc() = default; - - OpAdapterPtr Get(bool train) const { return train ? 
train_ : infer_; } - - OpAdapterDesc &operator=(const OpAdapterDesc &desc) { - if (this != &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - } - return *this; - } - - OpAdapterDesc &operator=(OpAdapterDesc &&desc) { - if (this != &desc) { - this->train_ = desc.train_; - this->infer_ = desc.infer_; - desc.train_ = nullptr; - desc.infer_ = nullptr; - } - return *this; - } - - private: - OpAdapterPtr train_; - OpAdapterPtr infer_; -}; - -using OpAdapterDescPtr = std::shared_ptr; -using TensorOrderMap = std::map>; - -class DfGraphConvertor { - public: - explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) - : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { -#if (!defined ENABLE_GE) || (defined ENABLE_INFER) - training_ = anf_graph->has_flag("training"); -#else - training_ = ENABLE_TRAIN; -#endif - distribute_ = anf_graph->has_flag("broadcast_flag"); - if (anf_graph->has_flag("broadcast_flag")) { - ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); - } else { - ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); - } - - MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; - } - - ~DfGraphConvertor() {} - - static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { - get_adpt_map()[name] = std::make_shared(adpt); - } - static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { - get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); - } - - void DrawComputeGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << compute_sout_.str(); - fout.close(); - } - void DrawInitGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << init_sout_.str(); - fout.close(); - } - void DrawSaveCheckpointGraph(const std::string &name) { - std::ofstream fout(name); - if (!fout.is_open()) { - MS_LOG(ERROR) << "Open file '" << name << "' failed!"; - return; - } - fout << checkpoint_sout_.str(); - fout.close(); - } - - DfGraphConvertor &ConvertAllNode(); - DfGraphConvertor &BuildGraph(); - DfGraphConvertor &InitParam(const TensorOrderMap &tensors); - DfGraphConvertor &GenerateCheckpointGraph(); - DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); - void InitParamWithData(const TensorOrderMap &tensors); - void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); - void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, - const DfGraphPtr &broadcast_graph, std::vector broadcast_input); - void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); - void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); - void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); - - DfGraphPtr GetComputeGraph(); - DfGraphPtr GetInitGraph(); - DfGraphPtr GetSaveCheckpointGraph(); - DfGraphPtr GetBroadcastGraph(); - static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); - static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); - int ErrCode() const { return static_cast(error_); } - - static std::unordered_map &get_adpt_map(); - bool is_training() const { return training_; } - void set_training(bool 
is_training) { training_ = is_training; } - - protected: - void InitLoopVar(std::vector *init_input); - - private: - std::ostringstream compute_sout_; - std::ostringstream init_sout_; - std::ostringstream checkpoint_sout_; - std::ostringstream restore_checkpoint_sout_; - std::unordered_map op_draw_name_; - - AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); - AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); - AnfNodePtr TraceDepend(const CNodePtr &node); - OutHandler TraceRealOp(AnfNodePtr node); - OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); - OperatorPtr Convert(AnfNodePtr node); - OperatorPtr ConvertCNode(CNodePtr node); - std::vector ConvertDependNode(AnfNodePtr node); - AnfNodePtr GetRealOpNode(AnfNodePtr node); - std::vector GetDependNodes(const AnfNodePtr &node); - OperatorPtr ConvertParameter(AnfNodePtr node); - Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); - OperatorPtr ConvertValueNode(ValueNodePtr node); - void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); - void ConvertTupleGetItem(const CNodePtr node); - void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, - const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, - const std::shared_ptr> &dst_ops_list); - void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); - void ConvertControlDependNode(const CNodePtr node); - void ConvertMakeTuple(const CNodePtr node); - bool CheckCNode(const std::string &name, const CNodePtr node); - void TraceOutput(AnfNodePtr node); - void TraceOutputFromParameter(const AnfNodePtr &anf_out); - void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); - void SetNodeInput(AnfNodePtr node); - void SetOpControlInput(const AnfNodePtr node); - void UpdateOpDesc(AnfNodePtr node); - void SetSubgraph(AnfNodePtr node); - void ProcessSubgraph(AnfNodePtr node, const std::vector &inputs); - void BuildSaveCheckpointGraph(); - void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); - void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; - void AddGraphConstInput(const OperatorPtr &op); - - std::shared_ptr anf_graph_{nullptr}; - std::shared_ptr df_graph_{nullptr}; - std::shared_ptr init_graph_{nullptr}; - std::shared_ptr save_ckp_graph_{nullptr}; - std::shared_ptr restore_ckp_graph_{nullptr}; - std::shared_ptr broadcast_graph_{nullptr}; - std::unordered_map branches_map_; - std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; - /* record "tuple_getitem"<->"out_handler" mapping */ - std::unordered_map out_handle_cache_; - /* record "make_tuple"<->"out_handler vector" mapping */ - std::unordered_map>> tuple_out_handle_cache_; - std::unordered_map>> case_input_handle_cache_; - std::unordered_map params_; - std::unordered_map vars_; - std::vector> graph_outputs_; - std::vector graph_const_inputs_; - std::vector init_ops_; - std::vector broadcast_ops_; - std::vector inputs_; - OperatorPtr dataset_iter_getnext_; - Status error_ = SUCCESS; - bool training_ = false; - bool distribute_ = false; - bool use_inputs_ = false; -}; -} // namespace transform -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ diff --git a/mindspore/ccsrc/transform/df_graph_manager.cc b/mindspore/ccsrc/transform/df_graph_manager.cc deleted file mode 
100644 index f62c386587..0000000000 --- a/mindspore/ccsrc/transform/df_graph_manager.cc +++ /dev/null @@ -1,214 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/df_graph_manager.h" - -#include -#include -#include -#include - -#include "securec/include/securec.h" -#include "pipeline/parse/python_adapter.h" -#include "pipeline/pipeline.h" -#include "utils/config_manager.h" -#ifndef NO_DLIB -#include "tdt/tsd_client.h" -#endif - -namespace mindspore { -namespace transform { -DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, - const OptionMap &options) - : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} - -DfGraphManager::DfGraphManager() { - graph_id_ = 0; - graph_runner_ptr_ = nullptr; - sess_ptr_ = nullptr; -} - -DfGraphManager::~DfGraphManager() { - // in python fisrt destroy after atexit but in c++ destoy before atexit - DeleteGraphRunner(); - DeleteGeSession(); - ClearGraph(); - parse::python_adapter::set_python_env_flag(false); -} - -DfGraphManager &DfGraphManager::GetInstance() { - static DfGraphManager instance; - return instance; -} - -int DfGraphManager::GenerateId() { - graph_id_++; - if (graph_id_ <= 0) { - graph_id_ = 1; - } - MS_LOG(INFO) << "Generate graph Id : " << graph_id_; - return graph_id_; -} - -Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { - std::lock_guard lg(lock_); - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null, add graph failed"; - return Status::INVALID_ARGUMENT; - } - - if (graph_ptr == nullptr) { - MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed"; - return Status::INVALID_ARGUMENT; - } - - int id = GenerateId(); - DfGraphWrapperPtr wrap_ptr = std::make_shared(name, id, graph_ptr, options); - auto ret = graphs_.emplace(name, wrap_ptr); - if (ret.second == false) { - MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! 
The old graph will be overwritten!!"; - ret.first->second = wrap_ptr; - } - MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!"; - return Status::SUCCESS; -} - -std::vector DfGraphManager::GetAllGraphs() { - std::lock_guard lg(lock_); - std::vector ret; - std::stringstream ss; - ss << "{ "; - for (auto it = graphs_.begin(); it != graphs_.end(); ++it) { - ss << it->first << ", "; - ret.emplace_back(it->second); - } - ss << "}"; - MS_LOG(INFO) << "Return graphs: " << ss.str(); - return ret; -} -std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } - -void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } - -DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { - std::lock_guard lg(lock_); - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null"; - return nullptr; - } - - auto it = graphs_.find(name); - if (it == graphs_.end()) { - MS_LOG(INFO) << "Can't found graph name: " << name; - return nullptr; - } - MS_LOG(INFO) << "Return graph: " << name; - return it->second; -} - -void DfGraphManager::ClearGraph() noexcept { - std::lock_guard lg(lock_); - graphs_.clear(); - anf_graphs_.clear(); - MS_LOG(INFO) << "Remove all graphs in GraphManager"; -} - -void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { - DfGraphWrapperPtr df_graph = GetGraphByName(name); - if (df_graph == nullptr) { - MS_LOG(ERROR) << "Can't found graph name: " << name; - return; - } - std::lock_guard lg(lock_); - anf_graphs_[df_graph->id_] = anf_graph_ptr; -} - -AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) { - std::lock_guard lg(lock_); - auto iter = anf_graphs_.find(graph_id); - if (iter == anf_graphs_.end()) { - MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id; - return nullptr; - } - - return iter->second; -} - -void DfGraphManager::EraseAnfGraph() { - std::lock_guard lg(lock_); - anf_graphs_.clear(); -} - -void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { - std::lock_guard lg(lock_); - if (sess_ptr == nullptr) { - MS_LOG(WARNING) << "You are adding a empty Ge Session"; - } - - if (sess_ptr_ == nullptr) { - MS_LOG(INFO) << "Add a new Ge Session success"; - } else { - MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!"; - } - sess_ptr_ = sess_ptr; -} - -std::shared_ptr DfGraphManager::GetGeSession() { - std::lock_guard lg(lock_); - return sess_ptr_; -} - -void DfGraphManager::DeleteGeSession() noexcept { - std::lock_guard lg(lock_); - if (sess_ptr_ == nullptr) { - MS_LOG(INFO) << "Ge Session is not exist"; - } else { - sess_ptr_ = nullptr; - saved_graphs_.clear(); - MS_LOG(INFO) << "Delete Ge Session success"; - } -} - -void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { - std::lock_guard lg(lock_); - if (graph_runner_ptr == nullptr) { - MS_LOG(WARNING) << "You are adding a empty GraphRunner"; - } - - if (graph_runner_ptr_ == nullptr) { - MS_LOG(INFO) << "Add a new GraphRunner success"; - } else { - MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!"; - } - graph_runner_ptr_ = graph_runner_ptr; -} - -std::shared_ptr DfGraphManager::GetGraphRunner() { - std::lock_guard lg(lock_); - return graph_runner_ptr_; -} - -void DfGraphManager::DeleteGraphRunner() noexcept { - std::lock_guard lg(lock_); - if (graph_runner_ptr_ == nullptr) { - MS_LOG(INFO) << "GraphRunner is not exist"; - } else { - graph_runner_ptr_ = nullptr; - MS_LOG(INFO) << 
"Delete GraphRunner success"; - } -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/df_graph_manager.h b/mindspore/ccsrc/transform/df_graph_manager.h deleted file mode 100644 index 2ca43d1f07..0000000000 --- a/mindspore/ccsrc/transform/df_graph_manager.h +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_DF_GRAPH_MANAGER_H_ -#define TRANSFORM_DF_GRAPH_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "transform/types.h" -#include "ir/anf.h" - -namespace mindspore { -const char BROADCAST_GRAPH_NAME[] = "broadcast_subgraph"; - -namespace transform { -class GraphRunner; -using OptionMap = std::map; - -struct DfGraphWrapper { - public: - DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); - ~DfGraphWrapper() {} - - std::string name_; - int id_; - DfGraphPtr graph_ptr_; - OptionMap options_ = {}; -}; - -using DfGraphWrapperPtr = std::shared_ptr; - -class DfGraphManager { - public: - ~DfGraphManager(); - void ClearGraph() noexcept; - - static DfGraphManager &GetInstance(); - Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); - std::vector GetAllGraphs(); - std::set GetSavedGraphs(); - void AddSavedGraphs(const std::string &id); - DfGraphWrapperPtr GetGraphByName(const std::string &name); - DfGraphManager(const DfGraphManager &) = delete; - void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); - AnfGraphPtr GetAnfGraph(uint32_t graph_id); - std::shared_ptr GetGraphRunner(); - void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; - void DeleteGraphRunner() noexcept; - void SetGeSession(const std::shared_ptr &sess_ptr); - std::shared_ptr GetGeSession(); - void DeleteGeSession() noexcept; - void EraseAnfGraph(); - - private: - DfGraphManager(); - int GenerateId(); - - std::mutex lock_; - std::map graphs_; - std::set saved_graphs_; - int graph_id_; - std::map anf_graphs_; - std::shared_ptr graph_runner_ptr_; - std::shared_ptr sess_ptr_; -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_DF_GRAPH_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_builder.cc b/mindspore/ccsrc/transform/graph_builder.cc deleted file mode 100644 index 785c5c7f3a..0000000000 --- a/mindspore/ccsrc/transform/graph_builder.cc +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/graph_builder.h" - -#include -#include - -namespace mindspore { -namespace transform { -DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { - MS_LOG(INFO) << "BuildMDDatasetGraph."; - - // InitData - auto d = ge::op::InitData("init_data_tmp").set_attr_channel_name(param.queue_name()); - - // set graph inputs & outputs - std::vector inputs{d}; - std::vector outputs{d}; - DfGraphPtr dataset_graph = std::make_shared("dataset"); - (void)dataset_graph->SetInputs(inputs); - (void)dataset_graph->SetOutputs(outputs); - - return dataset_graph; -} - -Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { - Status ret; - std::string graph_name = phase; - - MS_LOG(INFO) << "BuildDatasetGraph begin. phase is " << phase; - MS_LOG(INFO) << "param is " << param.ToString() << "."; - - DfGraphPtr dataset_graph = BuildMDDatasetGraph(param); - ret = DfGraphManager::GetInstance().AddGraph(graph_name, dataset_graph); - if (ret != Status::SUCCESS) { - MS_LOG(ERROR) << "BuildDatasetGraph failed."; - } else { - MS_LOG(INFO) << "BuildDatasetGraph end."; - } - return ret; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_builder.h b/mindspore/ccsrc/transform/graph_builder.h deleted file mode 100644 index 3d959f5a85..0000000000 --- a/mindspore/ccsrc/transform/graph_builder.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef TRANSFORM_GRAPH_BUILDER_H_ -#define TRANSFORM_GRAPH_BUILDER_H_ - -#include -#include -#include -#include -#include -#include "transform/types.h" -#include "transform/convert.h" - -namespace mindspore { -namespace transform { -Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_GRAPH_BUILDER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt new file mode 100644 index 0000000000..3f062609d5 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/CMakeLists.txt @@ -0,0 +1,9 @@ +if (ENABLE_GE OR ENABLE_D) + file(GLOB_RECURSE _TRANSFORM_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") + set_property(SOURCE ${_TRANSFORM_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_GE_ADPT) + add_library(_mindspore_transform_graph_ir_obj OBJECT ${_TRANSFORM_SRC_LIST}) + + if (NOT ENABLE_GE) + target_compile_definitions(_mindspore_transform_graph_ir_obj PRIVATE NO_GE_CLIENT) + endif() +endif () diff --git a/mindspore/ccsrc/transform/all_ops.h b/mindspore/ccsrc/transform/graph_ir/all_ops.h similarity index 100% rename from mindspore/ccsrc/transform/all_ops.h rename to mindspore/ccsrc/transform/graph_ir/all_ops.h diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc new file mode 100644 index 0000000000..7419dd2cc9 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -0,0 +1,2073 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/convert.h" + +#include +#include +#include +#include "utils/utils.h" + +#include "frontend/operator/ops.h" +#include "utils/log_adapter.h" +#include "utils/graph_utils.h" +#include "utils/symbolic.h" +#include "utils/config_manager.h" +#include "utils/convert_utils.h" +#include "./common.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace transform { +using std::endl; + +#define ADPT_DESC_ONE(T) std::make_shared(std::make_shared>()) +#define ADPT_DESC_TWO(T, I) \ + std::make_shared(std::make_shared>(), std::make_shared>()) +#define GET_MACRO(_1, _2, DESC, ...) DESC +#define ADPT_DESC(...) 
GET_MACRO(__VA_ARGS__, ADPT_DESC_TWO, ADPT_DESC_ONE, ...)(__VA_ARGS__) + +using ge::Operator; +using mindspore::kAnyValue; +using std::make_shared; +using std::shared_ptr; +using std::string; +using std::vector; + +const char kNameCustomOp[] = "CustomOp"; +const char kNameConst[] = "Const"; +const char kNameParam[] = "parameter"; +const char kNameRandomUniform[] = "RandomUniform"; +const char kNameSimpleMean[] = "SimpleMean"; +const char kNameSimpleMeanGrad[] = "SimpleMeanGrad"; +const char kNameAllReduce[] = "AllReduce"; +const char kNameBroadcast[] = "Broadcast"; +const char kNameAllgather[] = "AllGather"; +const char kNameReduceScatter[] = "ReduceScatter"; +const char kNameReduceSum[] = "ReduceSum"; +const char kNameIsFinite[] = "isFinite"; +const char kNameReciprocal[] = "Reciprocal"; +const char kNameRsqrt[] = "Rsqrt"; +const char kNameRsqrtGrad[] = "RsqrtGrad"; +const char kNameSqrt[] = "Sqrt"; +const char kNameSquare[] = "Square"; +const char kNameSquaredDifference[] = "SquaredDifference"; +const char kNamePow[] = "Pow"; +const char kNameBatchMatMul[] = "BatchMatMul"; +const char kNameStridedSlice[] = "StridedSlice"; +const char kNameStridedSliceGrad[] = "StridedSliceGrad"; +const char kNameExpandDims[] = "ExpandDims"; +const char kNameLog[] = "Log"; +const char kNameLogicalAnd[] = "LogicalAnd"; +const char kNameLogicalNot[] = "LogicalNot"; +const char kNameLogicalOr[] = "LogicalOr"; +const char kNameExp[] = "Exp"; +const char kNameLessEqual[] = "LessEqual"; +const char kNameGreaterEqual[] = "GreaterEqual"; +const char kNameEqual[] = "Equal"; +const char kNameNotEqual[] = "NotEqual"; +const char kNameFlattenGrad[] = "FlattenGrad"; +const char kNameConvolution[] = "Convolution"; +const char kNameBiasAdd[] = "BiasAdd"; +const char kNameMaxPoolGrad[] = "MaxPoolGrad"; +const char kNameAvgPoolGrad[] = "AvgPoolGrad"; +const char kNameMaxPoolGradWithArgmax[] = "MaxPoolGradWithArgmax"; +const char kNameApplyMomentum[] = "ApplyMomentum"; +const char kNameDropoutDoMask[] = "DropoutDoMask"; +const char kNameResizeBilinear[] = "ResizeBilinear"; +const char kNameResizeBilinearGrad[] = "ResizeBilinearGrad"; +const char kNameZerosLike[] = "ZerosLike"; +const char kNameOnesLike[] = "OnesLike"; +const char kNameTruncatedNormal[] = "TruncatedNormal"; +const char kNameSpaceToBatchNd[] = "SpaceToBatchNd"; +const char kNameConfusionMatrix[] = "ConfusionMatrix"; +const char kNameResizeNearestNeighborD[] = "ResizeNearestNeighbor"; +const char kNameResizeNearestNeighborGrad[] = "ResizeNearestNeighborGrad"; +const char kNameApplyAdam[] = "Adam"; +const char kNameExtractImagePatches[] = "ExtractImagePatches"; +const char kNameReLU6[] = "ReLU6"; +const char kNameReLU6Grad[] = "ReLU6Grad"; +const char kNameElu[] = "Elu"; +const char kNameEluGrad[] = "EluGrad"; +const char kNameTensorScatterUpdate[] = "TensorScatterUpdate"; +const char kNameScatterUpdate[] = "ScatterUpdate"; +const char kNameScatterNdUpdate[] = "ScatterNdUpdate"; +const char kNameScatterMax[] = "ScatterMax"; +const char kNameNMSWithMask[] = "NMSWithMask"; +const char kNameCheckValid[] = "CheckValid"; +const char kNameSmoothL1Loss[] = "SmoothL1Loss"; +const char kNameSmoothL1LossGrad[] = "SmoothL1LossGrad"; +const char kNameSGD[] = "SGD"; +const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits"; +const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad"; +const char kNameScatterNdD[] = "ScatterNd"; +const char kNamePadD[] = "Pad"; +const char kNameMirrorPad[] = "MirrorPad"; +const char 
kNameMirrorPadGrad[] = "MirrorPadGrad"; +const char kNameGatherNd[] = "GatherNd"; +const char kNameArgmax[] = "Argmax"; +const char kNameArgmin[] = "Argmin"; +const char kNameArgMaxWithValue[] = "ArgMaxWithValue"; +const char kNameArgMinWithValue[] = "ArgMinWithValue"; +const char kNameReduceProd[] = "ReduceProd"; +const char kNameCumProd[] = "CumProd"; +const char kNameDiagpart[] = "Diagpart"; +const char kNameSplitD[] = "Split"; +const char kNameBatchToSpaceNd[] = "BatchToSpaceNd"; +const char kNameFloor[] = "Floor"; +const char kNameNPUGetFloatStatus[] = "NPUGetFloatStatus"; +const char kNameAssign[] = "Assign"; +const char kNameAssignAdd[] = "AssignAdd"; +const char kNameAssignSub[] = "AssignSub"; +const char kNameNPUAllocFloatStatus[] = "NPUAllocFloatStatus"; +const char kNameNPUClearFloatStatus[] = "NPUClearFloatStatus"; +const char kNameReshape[] = "Reshape"; +const char kNameTransShape[] = "TransShape"; +const char kNameRealDiv[] = "RealDiv"; +const char kNameTile[] = "Tile"; +const char kNameCos[] = "Cos"; +const char kNameACos[] = "ACos"; +const char kNameACosGrad[] = "ACosGrad"; +const char kNameFloorDiv[] = "FloorDiv"; +const char kNameSin[] = "Sin"; +const char kNamePrelu[] = "PReLU"; +const char kNamePreluGrad[] = "PReLUGrad"; +const char kNameSigmoid[] = "Sigmoid"; +const char kNameSigmoidGrad[] = "SigmoidGrad"; +const char kNameL2Normalize[] = "L2Normalize"; +const char kNameL2NormalizeGrad[] = "L2NormalizeGrad"; +const char kNameSoftmax[] = "Softmax"; +const char kNameIOU[] = "IOU"; +const char kNameBoundingBoxDecode[] = "BoundingBoxDecode"; +const char kNameBoundingBoxEncode[] = "BoundingBoxEncode"; +const char kNameSlice[] = "Slice"; +const char kNameAddN[] = "AddN"; +const char kNameLess[] = "Less"; +const char kNameGreater[] = "Greater"; +const char kNamePack[] = "Pack"; +const char kNameUnpack[] = "Unpack"; +const char kNameMerge[] = "Merge"; +const char kNameGeSwitch[] = "GeSwitch"; + +const char kNameHuberLoss[] = "HuberLoss"; +const char kNameCumSum[] = "CumSum"; +const char kNameHuberLossGrad[] = "HuberLossGrad"; +const char kNameSparseSoftmaxCrossEntropy[] = "SparseSoftmaxCrossEntropy"; +const char kNameSparseSoftmaxCrossEntropyGrad[] = "SparseSoftmaxCrossEntropyGrad"; +const char kNameTopK[] = "TopK"; +const char kNameSoftmaxGrad[] = "SoftmaxGrad"; +const char kNameMaxPool[] = "MaxPool"; +const char kNameAvgPool[] = "AvgPool"; +const char kNameMaxPoolWithArgmax[] = "MaxPoolWithArgmax"; +const char kNameBatchNorm[] = "BatchNorm"; +const char kNameBatchNormGrad[] = "BatchNormGrad"; +const char kNameROIAlign[] = "ROIAlign"; +const char kNameROIAlignGrad[] = "ROIAlignGrad"; +const char kNameRandomChoiceWithMask[] = "RandomChoiceWithMask"; +const char kNameAbs[] = "Abs"; +const char kNameAbsGrad[] = "AbsGrad"; +const char kNameBinaryCrossEntropy[] = "BinaryCrossEntropy"; +const char kNameBinaryCrossEntropyGrad[] = "BinaryCrossEntropyGrad"; +const char kNameSparseApplyAdagrad[] = "SparseApplyAdagrad"; +const char kNameSparseApplyFtrlD[] = "SparseApplyFtrlD"; +const char kNameApplyProximalAdagrad[] = "ApplyProximalAdagrad"; +const char kNameAcosh[] = "Acosh"; +const char kNameAcoshGrad[] = "AcoshGrad"; +const char kNameFloorMod[] = "FloorMod"; +const char kNameSpaceToDepth[] = "SpaceToDepth"; +const char kNameDepthToSpace[] = "DepthToSpace"; +const char kNameSign[] = "Sign"; +const char kNameLARSUpdate[] = "LARSUpdate"; +const char kNameRound[] = "Round"; +const char kNamePrint[] = "Print"; +const char kNameApplyFtrl[] = "ApplyFtrl"; +const char kNameDiag[] = "Diag"; 
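+// How ADPT_DESC (defined above) dispatches: GET_MACRO(_1, _2, DESC, ...) resolves
+// to its third argument, so ADPT_DESC(T) selects ADPT_DESC_ONE (one adapter shared
+// by training and inference) and ADPT_DESC(T, I) selects ADPT_DESC_TWO (a training
+// adapter T plus an inference adapter I, chosen via OpAdapterDesc::Get(train)).
+// A minimal sketch: the entry {string(kNameConst), ADPT_DESC(Constant, Const)} in
+// the map below registers Constant for training graphs and Const for inference.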
+const char kNameDiagPart[] = "DiagPart"; +const char kNameSpaceToBatch[] = "SpaceToBatch"; +const char kNameBatchToSpace[] = "BatchToSpace"; +const char kNameAtan2[] = "Atan2"; +const char kNameApplyRMSProp[] = "ApplyRMSProp"; +const char kNameApplyCenteredRMSProp[] = "ApplyCenteredRMSProp"; +const char kNameL2Loss[] = "L2Loss"; +const char kNameCTCLoss[] = "CTCLoss"; +const char kNameRange[] = "Range"; +const char kNameSquareSumAll[] = "SquareSumAll"; +const char kNameAscendQuant[] = "AscendQuant"; +const char kNameAscendDequant[] = "AscendDequant"; +const char kNameCase[] = "Case"; + +// -----------------OpAdapter initialization-------------- +std::unordered_map &DfGraphConvertor::get_adpt_map() { + static std::unordered_map adpt_map = { + {string(kNameCustomOp), ADPT_DESC(Operator)}, + {string(kNameIOU), ADPT_DESC(Iou)}, + {string(kNameGreaterEqual), ADPT_DESC(GreaterEqual)}, + {string(kNameSlice), ADPT_DESC(SliceD)}, + {string(kNameApplyMomentum), ADPT_DESC(ApplyMomentumD)}, + {string(kNameMaxPool), ADPT_DESC(MaxPool)}, + {string(kNameAvgPool), ADPT_DESC(AvgPool)}, + {string(kNameMaxPoolWithArgmax), ADPT_DESC(MaxPoolWithArgmax)}, + {string(kNameTopK), ADPT_DESC(TopK)}, + {string(kNamePack), ADPT_DESC(Pack)}, + {string(kNameUnpack), ADPT_DESC(Unpack)}, + {string(kNameSplitD), ADPT_DESC(SplitD)}, + {string(kNameAllReduce), ADPT_DESC(HcomAllReduce)}, + {string(kNameBroadcast), ADPT_DESC(HcomBroadcast)}, + {string(kNameAllgather), ADPT_DESC(HcomAllGather)}, + {string(kNameReduceScatter), ADPT_DESC(HcomReduceScatter)}, + {string(kNameMaxPoolGrad), ADPT_DESC(MaxPoolGrad)}, + {string(kNameAvgPoolGrad), ADPT_DESC(AvgPoolGrad)}, + {string(kNameMaxPoolGradWithArgmax), ADPT_DESC(MaxPoolGradWithArgmax)}, + {string(kNameExtractImagePatches), ADPT_DESC(ExtractImagePatches)}, + {prim::kPrimAssign->name(), ADPT_DESC(Assign)}, + {prim::kPrimStateSetItem->name(), ADPT_DESC(Assign)}, + {prim::kPrimReluGrad->name(), ADPT_DESC(ReluGrad)}, + {prim::kPrimBiasAddGrad->name(), ADPT_DESC(BiasAddGrad)}, + {prim::kPrimConv2D->name(), ADPT_DESC(Conv2D)}, + {prim::kPrimConv2DBackpropInput->name(), ADPT_DESC(Conv2DBackpropInputD)}, + {prim::kPrimConv2DBackpropFilter->name(), ADPT_DESC(Conv2DBackpropFilterD)}, + {prim::kPrimDepthwiseConv2dNative->name(), ADPT_DESC(DepthwiseConv2D)}, + {prim::kPrimDepthwiseConv2dNativeBackpropFilter->name(), ADPT_DESC(DepthwiseConv2DBackpropFilterD)}, + {prim::kPrimDepthwiseConv2dNativeBackpropInput->name(), ADPT_DESC(DepthwiseConv2DBackpropInputD)}, + {string(kNameBatchNorm), ADPT_DESC(BatchNorm)}, + {string(kNameBatchNormGrad), ADPT_DESC(BatchNormGrad)}, + {string(kNameReshape), ADPT_DESC(Reshape)}, + {string(kNameTransShape), ADPT_DESC(TransShape)}, + {string(kNameFlattenGrad), ADPT_DESC(Reshape)}, + {prim::kPrimFlatten->name(), ADPT_DESC(Flatten)}, + {string(kNameAddN), ADPT_DESC(AddN)}, + {string(kNameLess), ADPT_DESC(Less)}, + {string(kNameSqrt), ADPT_DESC(Sqrt)}, + {string(kNameRsqrt), ADPT_DESC(Rsqrt)}, + {string(kNameSquare), ADPT_DESC(Square)}, + {prim::kPrimTanh->name(), ADPT_DESC(Tanh)}, + {prim::kPrimTanhGrad->name(), ADPT_DESC(TanhGrad)}, + {string(kNameResizeNearestNeighborD), ADPT_DESC(ResizeNearestNeighborV2D)}, + {string(kNameResizeNearestNeighborGrad), ADPT_DESC(ResizeNearestNeighborV2Grad)}, + {string(kNameApplyAdam), ADPT_DESC(ApplyAdam)}, + {string(kNameReLU6), ADPT_DESC(Relu6)}, + {string(kNameReLU6Grad), ADPT_DESC(Relu6Grad)}, + {string(kNameElu), ADPT_DESC(Elu)}, + {string(kNameEluGrad), ADPT_DESC(EluGrad)}, + {string(kNameResizeBilinearGrad), 
ADPT_DESC(ResizeBilinearV2Grad)}, + {string(kNameResizeBilinear), ADPT_DESC(ResizeBilinearV2D)}, + {string(kNameZerosLike), ADPT_DESC(ZerosLike)}, + {string(kNameOnesLike), ADPT_DESC(OnesLike)}, + {string(kNameTensorScatterUpdate), ADPT_DESC(TensorScatterUpdate)}, + {string(kNameScatterUpdate), ADPT_DESC(ScatterUpdate)}, + {string(kNameScatterNdUpdate), ADPT_DESC(ScatterNdUpdate)}, + {string(kNameScatterMax), ADPT_DESC(ScatterMax)}, + {string(kNameNMSWithMask), ADPT_DESC(NMSWithMask)}, + {string(kNameCheckValid), ADPT_DESC(CheckValid)}, + {string(kNameSmoothL1Loss), ADPT_DESC(SmoothL1Loss)}, + {string(kNameSmoothL1LossGrad), ADPT_DESC(SmoothL1LossGrad)}, + {string(kNameSigmoidCrossEntropyWithLogits), ADPT_DESC(SigmoidCrossEntropyWithLogits)}, + {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)}, + {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)}, + {string(kNamePadD), ADPT_DESC(PadD)}, + {string(kNameMirrorPad), ADPT_DESC(MirrorPad)}, + {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)}, + {string(kNameGatherNd), ADPT_DESC(GatherNd)}, + {string(kNameArgmax), ADPT_DESC(ArgMaxD)}, + {string(kNameArgmin), ADPT_DESC(ArgMinD)}, + {string(kNameArgMaxWithValue), ADPT_DESC(ArgMaxWithValue)}, + {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, + {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, + {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, + {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD)}, + {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, + {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, + {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, + {string(kNameReduceProd), ADPT_DESC(ReduceProdD)}, + {string(kNameCumProd), ADPT_DESC(CumprodD)}, + {string(kNameMerge), ADPT_DESC(Merge)}, + {string(kNameGeSwitch), ADPT_DESC(Switch)}, + {string(kNameCumSum), ADPT_DESC(CumsumD)}, + + {prim::kPrimMul->name(), ADPT_DESC(Mul)}, + {string(kNameTile), ADPT_DESC(TileD)}, + {prim::kPrimOneHot->name(), ADPT_DESC(OneHot)}, + + {prim::kPrimGatherV2->name(), ADPT_DESC(GatherV2D)}, + {string(kNameCos), ADPT_DESC(Cos)}, + {string(kNameACos), ADPT_DESC(Acos)}, + {string(kNameACosGrad), ADPT_DESC(AcosGrad)}, + {string(kNameFloor), ADPT_DESC(Floor)}, + {string(kNameFloorDiv), ADPT_DESC(FloorDiv)}, + {string(kNameSin), ADPT_DESC(Sin)}, + {string(kNameExp), ADPT_DESC(Exp)}, + {string(kNameBoundingBoxEncode), ADPT_DESC(BoundingBoxEncode)}, + {string(kNameBoundingBoxDecode), ADPT_DESC(BoundingBoxDecode)}, + + {prim::kPrimCast->name(), ADPT_DESC(Cast)}, + {string(kNameRealDiv), ADPT_DESC(RealDiv)}, + {prim::kPrimNeg->name(), ADPT_DESC(Neg)}, + {prim::kPrimTranspose->name(), ADPT_DESC(TransposeD)}, + {prim::kPrimSub->name(), ADPT_DESC(Sub)}, + {string(kNameReciprocal), ADPT_DESC(Reciprocal)}, + {prim::kPrimDropoutGenMask->name(), ADPT_DESC(DropOutGenMask)}, + {string(kNameAssignAdd), ADPT_DESC(AssignAdd)}, + {string(kNameAssignSub), ADPT_DESC(AssignSub)}, + {prim::kPrimConcat->name(), ADPT_DESC(ConcatD)}, + {string(kNamePow), ADPT_DESC(Pow)}, + {string(kNameExp), ADPT_DESC(Exp)}, + {string(kNameEqual), ADPT_DESC(Equal)}, + {string(kNameNotEqual), ADPT_DESC(NotEqual)}, + {string(kNameLog), ADPT_DESC(Log)}, + {string(kNameLogicalAnd), ADPT_DESC(LogicalAnd)}, + {string(kNameLogicalNot), ADPT_DESC(LogicalNot)}, + {string(kNameLogicalOr), ADPT_DESC(LogicalOr)}, + {string(kNameGreater), ADPT_DESC(Greater)}, + {prim::kPrimMaximum->name(), ADPT_DESC(Maximum)}, + {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, + {string(kNamePrelu), 
ADPT_DESC(PRelu)}, + {string(kNamePreluGrad), ADPT_DESC(PReluGrad)}, + {string(kNameSigmoid), ADPT_DESC(Sigmoid)}, + {string(kNameSigmoidGrad), ADPT_DESC(SigmoidGrad)}, + {string(kNameSGD), ADPT_DESC(SGD)}, + {prim::kPrimLogSoftmaxGrad->name(), ADPT_DESC(LogSoftmaxGrad)}, + {prim::kPrimMaximumGrad->name(), ADPT_DESC(MaximumGrad)}, + {prim::kPrimMinimumGrad->name(), ADPT_DESC(MinimumGrad)}, + {string(kNameL2Normalize), ADPT_DESC(L2Normalize)}, + {string(kNameL2NormalizeGrad), ADPT_DESC(L2NormalizeGrad)}, + + {prim::kPrimMinimum->name(), ADPT_DESC(Minimum)}, + {prim::kPrimSelect->name(), ADPT_DESC(Select)}, + {string(kNameLessEqual), ADPT_DESC(LessEqual)}, + {prim::kPrimLogSoftmax->name(), ADPT_DESC(LogSoftmaxV2)}, + {string(kNameTruncatedNormal), ADPT_DESC(TruncatedNormal)}, + {string(kNameStridedSliceGrad), ADPT_DESC(StridedSliceGrad)}, + {prim::kPrimGelu->name(), ADPT_DESC(Gelu)}, + {prim::kPrimGeluGrad->name(), ADPT_DESC(GeluGrad)}, + {string(kNameStridedSlice), ADPT_DESC(StridedSlice)}, + {prim::kPrimUnsortedSegmentMin->name(), ADPT_DESC(UnsortedSegmentMin)}, + {prim::kPrimUnsortedSegmentSum->name(), ADPT_DESC(UnsortedSegmentSumD)}, + {string(kNameExpandDims), ADPT_DESC(ExpandDims)}, + {prim::kPrimSqueeze->name(), ADPT_DESC(Squeeze)}, + {prim::kPrimLayerNorm->name(), ADPT_DESC(LayerNorm)}, + {prim::kPrimLayerNormGrad->name(), ADPT_DESC(LayerNormGrad)}, + {string(kNameBatchMatMul), ADPT_DESC(BatchMatMul)}, + {string(kNameDropoutDoMask), ADPT_DESC(DropOutDoMask)}, + + {string(kNameNPUGetFloatStatus), ADPT_DESC(NPUGetFloatStatus)}, + {string(kNameNPUAllocFloatStatus), ADPT_DESC(NPUAllocFloatStatus)}, + {string(kNameNPUClearFloatStatus), ADPT_DESC(NPUClearFloatStatus)}, + + {string(kNameRandomChoiceWithMask), ADPT_DESC(RandomChoiceWithMask)}, + {prim::kPrimSoftmaxCrossEntropyWithLogits->name(), ADPT_DESC(SoftmaxCrossEntropyWithLogits)}, + + {prim::kPrimScalarSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimImageSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimTensorSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimHistogramSummary->name(), ADPT_DESC(Summary)}, + {prim::kPrimDebug->name(), ADPT_DESC(Summary)}, + {prim::kPrimTensorAdd->name(), + std::make_shared(std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})), + std::make_shared>(ExtraAttr({{"mode", MakeValue(1)}})))}, + {string(kNameBiasAdd), ADPT_DESC(BiasAdd)}, + {prim::kPrimRelu->name(), ADPT_DESC(Relu)}, + + {prim::kPrimMatMul->name(), ADPT_DESC(MatMulV2)}, + + {string(kNameConst), ADPT_DESC(Constant, Const)}, + {string(kNameSoftmax), ADPT_DESC(SoftmaxV2)}, + {string(kNameSoftmaxGrad), ADPT_DESC(SoftmaxGrad)}, + {string(kNameParam), ADPT_DESC(Data)}, + {string(kNameROIAlign), ADPT_DESC(ROIAlign)}, + {string(kNameROIAlignGrad), ADPT_DESC(ROIAlignGrad)}, + {string(kNameAbs), ADPT_DESC(Abs)}, + {string(kNameAbsGrad), ADPT_DESC(AbsGrad)}, + {string(kNameBinaryCrossEntropy), ADPT_DESC(BinaryCrossEntropy)}, + {string(kNameBinaryCrossEntropyGrad), ADPT_DESC(BinaryCrossEntropyGrad)}, + {string(kNameSparseApplyAdagrad), ADPT_DESC(SparseApplyAdagradD)}, + {string(kNameSparseApplyFtrlD), ADPT_DESC(SparseApplyFtrlD)}, + {string(kNameApplyProximalAdagrad), ADPT_DESC(ApplyProximalAdagradD)}, + {string(kNameAcosh), ADPT_DESC(Acosh)}, + {string(kNameAcoshGrad), ADPT_DESC(AcoshGrad)}, + {string(kNameFloorMod), ADPT_DESC(FloorMod)}, + {string(kNameSpaceToDepth), ADPT_DESC(SpaceToDepth)}, + {string(kNameDepthToSpace), ADPT_DESC(DepthToSpace)}, + {string(kNameSign), ADPT_DESC(Sign)}, + {string(kNameRound), ADPT_DESC(Round)}, + 
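+    // Minimal lookup sketch for the entries in this map (FindAdapter is defined
+    // later in this file): a CNode is keyed by GetCNodeTargetFuncName(node), a
+    // Parameter by kNameParam and a ValueNode by kNameConst, e.g.
+    //   OpAdapterPtr adpt = DfGraphConvertor::FindAdapter(node, training);
+    //   OperatorPtr op = adpt->generate(node);  // materialize the GE operator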
{string(kNameApplyFtrl), ADPT_DESC(ApplyFtrlD)}, + {string(kNameDiag), ADPT_DESC(Diag)}, + {string(kNameDiagPart), ADPT_DESC(DiagPart)}, + {string(kNameSpaceToBatch), ADPT_DESC(SpaceToBatchD)}, + {string(kNameBatchToSpace), ADPT_DESC(BatchToSpaceD)}, + {string(kNameAtan2), ADPT_DESC(Atan2)}, + {string(kNameApplyRMSProp), ADPT_DESC(ApplyRMSPropD)}, + {string(kNameApplyCenteredRMSProp), ADPT_DESC(ApplyCenteredRMSProp)}, + {string(kNameL2Loss), ADPT_DESC(L2Loss)}, + {string(kNameCTCLoss), ADPT_DESC(CTCLoss)}, + {string(kNameRange), ADPT_DESC(RangeD)}, + {string(kNameSquareSumAll), ADPT_DESC(SquareSumAll)}, + {string(kNameAscendQuant), ADPT_DESC(AscendQuant)}, + {string(kNameAscendDequant), ADPT_DESC(AscendDequant)}, + {string(kNameCase), ADPT_DESC(Case)}}; +#ifdef ENABLE_GE + adpt_map[string(kNamePrint)] = ADPT_DESC(Print); + adpt_map[string(kNameApplyAdam)] = ADPT_DESC(ApplyAdamD); +#endif + return adpt_map; +} + +// ---------------implement of DfGraphConvertor------------- +PrimType GetCNodeFuncType(const CNodePtr cnode) { + if (cnode->inputs().empty()) { + return kPrimTypeUnknown; + } + + AnfNodePtr valuenode = cnode->input(0); + if (IsValueNode(valuenode)) { + // check whether the valuenode is primitive + return GetValueNode(valuenode)->prim_type(); + } + return kPrimTypeUnknown; +} + +bool IsCaseNode(const CNodePtr node) { + if (!node->inputs().empty() && node->input(0)->isa() && + GetCNodeFuncName(node->input(0)->cast()) == "switch_layer") { + return true; + } + return false; +} + +std::string GetCNodeTargetFuncName(const CNodePtr cnode) { + if (IsCaseNode(cnode)) { + return string(kNameCase); + } + auto name = GetCNodeFuncName(cnode); + if (name == "switch_layer") { + name = ""; + } + return name; +} + +OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) { + if (node->isa()) { + auto cnode = node->cast(); + + std::string name = kNameCustomOp; + if (!IsCustomCNode(cnode)) { + name = GetCNodeTargetFuncName(cnode); + } + + auto it_adpt = get_adpt_map().find(name); + if (it_adpt != get_adpt_map().end()) { + return it_adpt->second->Get(train); + } + MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; + } + + if (node->isa()) { + return get_adpt_map()[kNameConst]->Get(train); + } + if (node->isa()) { + return get_adpt_map()[kNameParam]->Get(train); + } + return OpAdapterPtr(nullptr); +} + +void DfGraphConvertor::InitLoopVar(std::vector *init_input) { + if (this->training_) { + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto var_iter_num = std::make_shared("npu_runconfig/iterations_per_loop"); + auto var_loop_cond = std::make_shared("npu_runconfig/loop_cond"); + auto var_one = std::make_shared("npu_runconfig/one"); + auto var_zero = std::make_shared("npu_runconfig/zero"); + (void)var_iter_num->update_output_desc_y(desc); + (void)var_loop_cond->update_output_desc_y(desc); + (void)var_one->update_output_desc_y(desc); + (void)var_zero->update_output_desc_y(desc); + vars_["npu_runconfig/iterations_per_loop"] = var_iter_num; + vars_["npu_runconfig/loop_cond"] = var_loop_cond; + vars_["npu_runconfig/one"] = var_one; + vars_["npu_runconfig/zero"] = var_zero; + + int64_t value = 0; + auto const_iter_num = std::make_shared("const/npu_runconfig/iterations_per_loop"); + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + value = ConfigManager::GetInstance().iter_num(); + } else { + MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1"; + value = 1; + ConfigManager::GetInstance().set_iter_num(value); + } 
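+    // Worked example (sketch): in dataset sink mode with iter_num == 100, the
+    // statement below initializes the device-side loop limit to 99, since
+    // iterations are numbered 0..n-1; in non-sink mode iter_num is forced to 1
+    // above, so the limit becomes 0 and the loop body runs exactly once.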
+ value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1 + (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_loop_cond = std::make_shared("const/npu_runconfig/loop_cond"); + value = 0; + (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_one = std::make_shared("const/npu_runconfig/one"); + value = 1; + (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + auto const_zero = std::make_shared("const/npu_runconfig/zero"); + value = 0; + (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast(&value), sizeof(int64_t))); + + (void)const_iter_num->update_output_desc_y(desc); + (void)const_loop_cond->update_output_desc_y(desc); + (void)const_one->update_output_desc_y(desc); + (void)const_zero->update_output_desc_y(desc); + + auto assign_iter_num = std::make_shared("assign/npu_runconfig/iterations_per_loop"); + (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num); + auto assign_loop_cond = std::make_shared("assign/npu_runconfig/loop_cond"); + (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond); + auto assign_one = std::make_shared("assign/npu_runconfig/one"); + (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one); + auto assign_zero = std::make_shared("assign/npu_runconfig/zero"); + (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero); + + init_input->push_back(*var_iter_num); + init_input->push_back(*var_loop_cond); + init_input->push_back(*var_one); + init_input->push_back(*var_zero); + init_ops_.push_back(var_iter_num); + init_ops_.push_back(var_loop_cond); + init_ops_.push_back(var_one); + init_ops_.push_back(var_zero); + init_ops_.push_back(const_iter_num); + init_ops_.push_back(const_loop_cond); + init_ops_.push_back(const_one); + init_ops_.push_back(const_zero); + init_ops_.push_back(assign_iter_num); + init_ops_.push_back(assign_loop_cond); + init_ops_.push_back(assign_one); + init_ops_.push_back(assign_zero); + } +} + +OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) { + auto it = get_adpt_map().find(name); + if (it != get_adpt_map().end()) { + return it->second->Get(train); + } + MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name; +} + +void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) { + // draw init subgraph + init_sout_ << "op_assign" << it.get() << "[label=<"; + init_sout_ << "" << endl; + init_sout_ << ""; + init_sout_ << ""; + init_sout_ << ""; + init_sout_ << "" << endl; + init_sout_ << "" << endl; + init_sout_ << "
resourcevalue
" + << "\"assign_" << name << "\"
> shape=plaintext]" << endl; + init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl; + init_sout_ << "const" << it.get() << "[label= \"" << name << "_const" + << "\" shape=ellipse]" << endl; + init_sout_ << "param" << it.get() << "->" + << "op_assign" << it.get() << ":1" << endl; + init_sout_ << "const" << it.get() << "->" + << "op_assign" << it.get() << ":2" << endl; +} + +void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input) { + DfGraphPtr init_graph = std::make_shared("init"); + std::vector nodes = TopoSort(anf_graph_->get_return()); + + for (auto &it : nodes) { + if (it->isa()) { + if (IsValueNode(it)) { + auto symbolic = GetValueNode(it); + auto name = std::static_pointer_cast(symbolic->node())->name(); + auto iter = vars_.find(name); // get corresponding variable op + if (iter != vars_.end()) { + op_cache_[it.get()] = iter->second; + // #ifdef DRAW_GE_GRAPH + compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] + << "[style=\"dotted\"]" << endl; + // #endif + } + } else if (IsValueNode(it)) { + auto refkey = GetValueNode(it); + auto name = refkey->tag(); + auto iter = vars_.find(name); // get corresponding variable op + if (iter != vars_.end()) { + op_cache_[it.get()] = iter->second; + compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] + << "[style=\"dotted\"]" << endl; + } + } + } + } + + for (auto &it : tensors) { + if (vars_.find(it.first) == vars_.end()) { + MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph."; + vars_[it.first] = nullptr; + } + } + + // set up init sub graph + if (init_input->size()) { + // init sub graph needs no input + MS_LOG(INFO) << "Build data init subgraph."; + (void)init_graph->SetInputs(*init_input); + this->init_graph_ = init_graph; + } else { + this->init_graph_ = nullptr; + } +} + +void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) { + MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input"; + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + auto getnext_idx = static_cast(input_idx); + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) { + getnext_idx = param.input_indexes()[input_idx] - 1; // input_idx starts from 0. + MS_LOG(INFO) << "remap input_index:" << input_idx << " to getnext_index:" << getnext_idx << "."; + } + // use iterator_getnext op with output_name instead of data op in BuildGraph. 
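+    // Example (sketch): with param.input_indexes() == {2, 1}, the parameter at
+    // input_idx 0 is remapped above to getnext_idx 1, so the handler below binds
+    // it to output "y1" of the shared GetNext operator instead of a Data op.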
+ out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx)); + } +} + +void DfGraphConvertor::SetupBroadcast(const std::shared_ptr &broadcast, + const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input) { + MS_LOG(INFO) << "build broadcast subgraph"; + if (broadcast_desc.size() != broadcast_input.size()) { + MS_LOG(EXCEPTION) << "Desc number of Broadcast is not equal to number of inputs"; + } + (void)broadcast->create_dynamic_input_x(static_cast(broadcast_input.size())); + (void)broadcast->create_dynamic_output_y(static_cast(broadcast_desc.size())); + for (unsigned int i = 0; i < broadcast_input.size(); i++) { + (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]); + (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]); + } + (void)broadcast_graph->SetInputs(broadcast_input); + this->broadcast_graph_ = broadcast_graph; +} + +void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { + int index = 0; + std::vector init_input; + for (auto it : tensors) { + std::string name = it.first; + auto node_itor = params_.find(name); + // if name not in params_, create a node in graph + if (node_itor == params_.end()) { + MS_LOG(WARNING) << name << " is not in params, create a new node."; + ParameterPtr param = std::make_shared(nullptr); + name = name + "_temp"; + param->set_name(name); + (void)ConvertParameter(param); + node_itor = params_.find(name); + } + auto node = node_itor->second; + auto op_itor = op_cache_.find(node.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << "."; + } + auto adpt = FindAdapter(kNameParam, training_); + if (adpt == nullptr) continue; + auto param_op = adpt->generate(name + "_data"); + MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << "."; + + if (!training_) { + auto adpt_const = FindAdapter(kNameConst, training_); + if (adpt_const == nullptr) continue; + auto const_op = adpt_const->generate(name + "_const"); + (void)adpt_const->setAttr(const_op, "value", it.second); + + auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); + if (const_op_desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; + continue; + } + (void)std::static_pointer_cast(const_op)->update_output_desc_y(*const_op_desc); + + vars_[name] = const_op; + op_itor->second = const_op; + continue; + } + + // create tensor descriptor for output descriptor + auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; + continue; + } + + // we need three variable ops for each graph with the same name + // build init subgraph + if (it.second->is_init() == 0) { + (void)std::static_pointer_cast(param_op)->set_attr_index(index++); + auto init_var = std::make_shared(name); + auto assign_op = std::make_shared("assign_" + name); + (void)init_var->update_output_desc_y(*desc); + (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op); + init_input.push_back(*init_var); + init_ops_.push_back(param_op); + init_ops_.push_back(assign_op); + init_ops_.push_back(init_var); + } + + auto variable = std::make_shared(name); + (void)variable->update_output_desc_y(*desc); + // do not use read variable while variable sink + MS_LOG(DEBUG) << "InitParam, op_name = " 
<< name << ", var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; // prevent the variable operator from being freed + DrawParamInitSubGraph(name, node); + } + InitLoopVar(&init_input); + SetupParamInitSubGraph(tensors, &init_input); +} + +// convert all parameter need initialize to variable +DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) { + size_t input_idx = 0; + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in InitParam."; + return *this; + } + + // Processing input with MakeDatasetHandler + for (auto &it : anf_graph_->parameters()) { + auto op_itor = op_cache_.find(it.get()); // converted node + if (it->isa() && op_itor != op_cache_.end()) { + string name = std::static_pointer_cast(it)->name(); + auto tensor_itor = tensors.find(name); // in init value map + if (tensor_itor == tensors.end()) { + DfGraphConvertor::MakeDatasetHandler(name, input_idx, it); + input_idx++; + } + } + } + InitParamWithData(tensors); + init_sout_ << "}" << endl; + return *this; +} + +#if (defined ENABLE_GE) +void DfGraphConvertor::BuildSaveCheckpointGraph() { + std::vector graph_inputs; + ge::op::Save save_op("save_parms"); + int save_op_is_active = 0; + size_t index = 0; + string name; + + int32_t count_size = std::count_if(vars_.begin(), vars_.end(), [](const std::pair &it) { + return (it.second == nullptr || it.first.find("/") != std::string::npos); + }); + + (void)save_op.create_dynamic_input_tensors(vars_.size() - static_cast(count_size)); + + // for each "parameter" in anf graph excluding "input" + for (const auto &it : vars_) { + name = it.first; + if (it.second == nullptr || name.find("/") != std::string::npos) continue; + Variable variable(name); + (void)variable.update_output_desc_y(it.second->GetOutputDesc(0)); + (void)save_op.set_dynamic_input_tensors(index++, variable); + + graph_inputs.push_back(variable); + + if (save_op_is_active == 0) { + checkpoint_sout_ << "op_save" << &save_op << "[label=<"; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "" << endl; + checkpoint_sout_ << "
tensor
" + << "\"saveop" + << "\"
> shape=plaintext]" << endl; + } + + checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl; + + checkpoint_sout_ << "param" << it.second << "->" + << "op_save" << &save_op << ":1" << endl; + save_op_is_active = 1; + } + if (save_op_is_active) { + std::vector graph_output; + graph_output.emplace_back(save_op); + DfGraphPtr checkpoint_graph = std::make_shared("checkpoint"); + (void)checkpoint_graph->SetInputs(graph_inputs); + (void)checkpoint_graph->SetOutputs(graph_output); + this->save_ckp_graph_ = checkpoint_graph; + } else { + this->save_ckp_graph_ = nullptr; + } + + checkpoint_sout_ << "}" << endl; + return; +} +#endif + +DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) { + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph"; + return *this; + } + + DfGraphPtr broadcast_graph = std::make_shared("broadcast"); + // collect the operators created for the broadcast subgraph, in order to avoid auto release + std::vector broadcast_input; + std::vector broadcast_desc; + auto broadcast = std::make_shared("broadcast_parameter"); + (void)broadcast->set_attr_root_rank(0); + (void)broadcast->set_attr_group("hccl_world_group"); + broadcast_ops_.push_back(broadcast); + + // find every parameter, build broadcast subgraph (or initialize the parameter with constant) + for (auto &it : anf_graph_->parameters()) { + auto op_itor = op_cache_.find(it.get()); // converted node + if (it->isa() && op_itor != op_cache_.end()) { + string name = std::static_pointer_cast(it)->name(); + auto tensor_itor = tensors.find(name); // in init tensor map + if (tensor_itor != tensors.end()) { + auto tensor = tensor_itor->second; + auto shape_ge = tensor->shape_c(); + + // create tensor descriptor for output descriptor + auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; + continue; + } + + // build broadcast subgraph + if (distribute_) { + auto broadcast_var = std::make_shared(name); + (void)broadcast_var->update_output_desc_y(*desc); + broadcast_input.push_back(*broadcast_var); + broadcast_desc.push_back(*desc); + broadcast_ops_.push_back(broadcast_var); + } + } + } + } + + // set up broadcast sub graph + if (!broadcast_input.empty()) { + DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input); + } else { + this->broadcast_graph_ = nullptr; + } + return *this; +} + +DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() { + if (error_ != 0) { + MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << "."; + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + error_ = INVALID_ARGUMENT; + MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph"; + return *this; + } +#if (defined ENABLE_GE) + BuildSaveCheckpointGraph(); + // Restoring from checkpoint file is done by pyfront, not in graph now. 
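+  // Typical call sequence for this converter (minimal sketch; all names appear
+  // in this file or its header, error handling elided):
+  //   DfGraphConvertor convertor(anf_graph);
+  //   (void)convertor.ConvertAllNode().InitParam(tensors).BuildGraph();
+  //   (void)convertor.GenerateCheckpointGraph();
+  //   if (convertor.ErrCode() == 0) { DfGraphPtr graph = convertor.GetComputeGraph(); }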
+#endif + return *this; +} + +DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { + if (error_ != 0) { + return *this; + } + if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) { + MS_LOG(ERROR) << "Invalid AnfGraph"; + error_ = FAILED; + return *this; + } + + compute_sout_.clear(); + compute_sout_ << "digraph {" << endl; + init_sout_.clear(); + init_sout_ << "digraph {" << endl; + checkpoint_sout_.clear(); + checkpoint_sout_ << "digraph {" << endl; + restore_checkpoint_sout_.clear(); + restore_checkpoint_sout_ << "digraph {" << endl; + + // Convert all anf node to Operator + MS_LOG(DEBUG) << "convert all node"; + std::vector nodes = TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + (void)Convert(it); + if (this->error_ != 0) { + MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << "."; + } + } + + // Create dataset iterator and iterator_getnext node + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + MS_LOG(INFO) << "Dataset param is " << param.ToString() << "."; + // GetNext + auto iter_getnext_op = make_shared("get_next_tmp"); + (void)iter_getnext_op->set_attr_output_types(param.ge_types()); + (void)iter_getnext_op->set_attr_output_shapes(param.shapes()); + (void)iter_getnext_op->set_attr_channel_name(param.queue_name()); + + // save iter_getnext_op for later use + dataset_iter_getnext_ = iter_getnext_op; + } + + // return the data flow graph + return *this; +} + +void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) { + auto it = out_handle_cache_.find(anf_out.get()); + if (it != out_handle_cache_.end()) { + OutHandler handle = it->second; + auto op = handle.op; + if (op != nullptr) { + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; + graph_outputs_.emplace_back(std::make_pair(*op, handle.out)); + } else { + MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted"; + } + } else { + // invalid tuple_getitem e.g. 
tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple()) + MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope(); + } +} + +void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { + AnfNodePtr anf_out = node; + AnfNodePtr pre_node = nullptr; + + // trace Parameter node + TraceOutputFromParameter(anf_out); + // then trace cnode + if (!node->isa()) { + return; + } + + // trace tuple_getitem + while (anf_out->isa() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) { + pre_node = anf_out; + anf_out = anf_out->cast()->input(1); + } + // trace every element of make_tuple + auto c = anf_out->cast(); + std::string name = ""; + if (anf_out->isa()) { + name = GetCNodeTargetFuncName(c); + } + + if (name == "make_tuple") { + for (unsigned int i = 1; i < c->inputs().size(); i++) { + TraceOutput(c->input(i)); + } + } else if (name == "Depend") { + if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3"; + } + TraceOutput(c->input(1)); + } else if (name == "tuple_getitem") { + TraceOutputFromTupleGetItem(anf_out); + } else { + // add outputs; + auto op = Convert(anf_out); + std::string index; + if (op != nullptr) { + if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) { + auto item = out_handle_cache_.find(pre_node.get()); + if (item != out_handle_cache_.end()) { + index = item->second.out; + } else { + MS_LOG(WARNING) << "Can't get operater: " << anf_out->fullname_with_scope() << " 's output item"; + } + } + MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index; + graph_outputs_.emplace_back(make_pair(*op, index)); + } + } +} + +void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) { + if (anf_out->isa()) { + MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope(); + auto it = out_handle_cache_.find(anf_out.get()); + if (it != out_handle_cache_.end()) { + // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler. 
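      // An OutHandler pairs the producing GE operator with the name of one of
      // its output ports; in the dataset case above that is the
      // iterator_get_next operator plus a port name of the form "yN", which
      // lets a single multi-output operator feed several graph outputs.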
+ OutHandler handle = it->second; + auto op = handle.op; + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out; + graph_outputs_.emplace_back(make_pair(*op, handle.out)); + } else { + // common parameter case + auto op = Convert(anf_out); + if (op != nullptr) { + MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType(); + graph_outputs_.emplace_back(std::make_pair(*op, "")); + } + } + } +} + +void SetupDatasetIterGetNextNode(const OperatorPtr &op) { + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + DatasetGraphParam param = ConfigManager::GetInstance().dataset_param(); + size_t output_num = param.ge_types().size(); + MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << "."; + // set iterator_getnext op's output num + shared_ptr iter_getnext = std::static_pointer_cast(op); + (void)iter_getnext->create_dynamic_output_y(static_cast(output_num)); + + for (uint32_t i = 0; i < output_num; i++) { + ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]); + // we don't SetRealDimCnt here since GE do not use this output's real-dim + (void)iter_getnext->update_dynamic_output_desc_y((i), desc); + } + } + return; +} + +void DfGraphConvertor::SetSubgraph(AnfNodePtr node) { + if (!node->isa()) { + return; + } + auto cnode = node->cast(); + if (!IsCaseNode(cnode)) { + return; + } + std::vector case_inputs; + for (size_t i = 1; i < cnode->inputs().size(); i++) { + case_inputs.emplace_back(cnode->input(i)); + } + std::shared_ptr> branches = std::make_shared>(); + auto bnode = cnode->input(0)->cast()->input(2)->cast(); + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + auto branch_node = bnode->input(i)->cast(); + for (size_t j = 2; j < branch_node->inputs().size(); j++) { + if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { + case_inputs.emplace_back(branch_node->input(j)); + } + } + } + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + ProcessSubgraph(bnode->input(i), case_inputs); + } + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + branches->emplace_back(branches_map_[bnode->input(i).get()]); + } + + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (nullptr == adpt) { + MS_LOG(DEBUG) << "Not found adapter"; + return; + } + + OperatorPtr op = Convert(node); + adpt->setSubgraph(op, 0, branches); + return; +} + +void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) { + std::vector case_inputs; + for (size_t i = 1; i < node->inputs().size(); i++) { + case_inputs.emplace_back(node->input(i)); + } + std::shared_ptr> branches = std::make_shared>(); + auto bnode = input_node->input(2)->cast(); + + for (size_t i = 1; i < bnode->inputs().size(); i++) { + auto branch_node = bnode->input(i)->cast(); + for (size_t j = 2; j < branch_node->inputs().size(); j++) { + if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) { + case_inputs.emplace_back(branch_node->input(j)); + } + } + } + + const size_t case_index = 1; + const size_t make_tuple_index = 2; + + AnfNodePtr case_index_iter = input_node->input(case_index); + AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index); + auto make_tuple_node = make_tuple_iter->cast(); + std::shared_ptr> tuple_items = std::make_shared>(); + + for (size_t i = 0; i < 
case_inputs.size(); i++) { + auto item = case_inputs[i]; + auto op = Convert(item); + if (op != nullptr) { + tuple_items->emplace_back(OutHandler(op, "")); + } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { + tuple_items->push_back(out_handle_cache_[item.get()]); + } else { + MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); + continue; + } + } + + tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items; + + std::shared_ptr> case_input_items = std::make_shared>(); + case_input_items->emplace_back(case_index_iter); + case_input_items->emplace_back(make_tuple_iter); + case_input_handle_cache_[node.get()] = case_input_items; +} + +DfGraphConvertor &DfGraphConvertor::BuildGraph() { + SetupDatasetIterGetNextNode(dataset_iter_getnext_); + + if (error_ != 0) { + return *this; + } + + // Case node set input. + std::vector nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + if (it->isa() && IsCaseNode(it->cast())) { + auto node = it->cast(); + auto input_node = node->input(0)->cast(); + GetCaseNodeInput(node, input_node); + } + } + + // update tuple_out_handle_cache_ + for (auto it : tuple_out_handle_cache_) { + std::size_t len = it.second->size(); + for (std::size_t i = 0; i < len; i++) { + OutHandler handle = (*it.second)[i]; + if (handle.op) { + string name = handle.op->GetName(); + if (vars_.count(name)) { + OperatorPtr new_op = vars_[name]; + if (new_op != nullptr) { + MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; + (*it.second)[i] = OutHandler(new_op, handle.out); + } + } + } + } + } + + // set up dependices + MS_LOG(DEBUG) << "set up dependices"; + nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + for (auto &it : nodes) { + SetNodeInput(it); + SetOpControlInput(it); + SetSubgraph(it); + UpdateOpDesc(it); + } + + if (error_ == 0) { + df_graph_ = make_shared(anf_graph_->ToString()); + } else { + return *this; + } + + // set graph input according to the order from anf graph + std::vector inputs; + if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) { + inputs.push_back(*dataset_iter_getnext_); + } else { + auto params = anf_graph_->parameters(); + if (use_inputs_) { + params = inputs_; + auto anf_params = anf_graph_->parameters(); + for (size_t i = 0; i < params.size(); i++) { + for (size_t j = 0; j < anf_params.size(); j++) { + if (params[i]->ToString() == anf_params[j]->ToString()) { + params[i] = anf_params[j]; + } + } + } + } + + int index = 0; + for (auto &it : params) { + auto name = std::static_pointer_cast(it)->name(); + // the parameters which has not been converted to var + if (vars_.find(name) == vars_.end()) { + auto op = Convert(it); + MS_EXCEPTION_IF_NULL(op); + MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index; + if (op == nullptr) { + MS_LOG(ERROR) << "Convert graph failed!"; + return *this; + } + UpdateDataOpDesc(it, op); + + MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index; + (void)std::static_pointer_cast(op)->set_attr_index(index++); + inputs.push_back(*op); + } else if (vars_[name] != nullptr) { + MS_LOG(INFO) << "add var input " << it->ToString(); + auto op = Convert(it); + MS_EXCEPTION_IF_NULL(op); + inputs.push_back(*op); + } + } + } + + // Add const nodes as graph input for some operator work with constant + std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs), + [](OperatorPtr x) { return *x; }); + + MS_LOG(INFO) << "set graph input num: 
" << inputs.size(); + (void)df_graph_->SetInputs(inputs); + + // set graph output + // set the value of finale return apply node as the output of dataflow graph + MS_LOG(DEBUG) << "set output"; + graph_outputs_.clear(); + TraceOutput(anf_graph_->get_return()->input(1)); + MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size(); + (void)df_graph_->SetOutputs(graph_outputs_); + + compute_sout_ << "}" << endl; + // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag. + if (ConfigManager::GetInstance().iter_num() > 1) { + df_graph_->SetNeedIteration(true); + } + return *this; +} + +void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const { + auto node = std::static_pointer_cast(it); + if (node == nullptr) { + MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node."; + return; + } + auto normal_shape_ptr = dyn_cast(node->Shape()); + vector shape; + if (normal_shape_ptr == nullptr) { + MS_LOG(INFO) << "Invalid shape to update data op descriptor."; + return; + } + shape = normal_shape_ptr->shape(); + if (node->Type() == nullptr) { + MS_LOG(INFO) << "Invalid type to update data op descriptor."; + return; + } + TypeId me_type = node->Type()->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(node->Type())->element()->type_id(); + } + std::ostringstream buf; + buf << "[" << shape << "]"; + MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type; + auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null."; + } else { + (void)std::static_pointer_cast(op)->update_input_desc_x(*desc); + (void)std::static_pointer_cast(op)->update_output_desc_y(*desc); + } +} + +DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; } + +DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; } + +DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; } + +DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; } + +void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { + if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { + return; + } + + std::vector control_edges = control_depend_cache_[node.get()]; + if ((control_edges.empty())) { + MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; + return; + } + + for (auto &item : control_edges) { + (void)item.dest_op->AddControlInput(*item.src_op); + } +} + +const std::vector trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)}; + +void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { + OperatorPtr src = Convert(node); + int case_flag = 0; + auto &inputs = node->inputs(); + size_t input_size = inputs.size(); + if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) { + case_flag = 1; + input_size = case_input_handle_cache_[node.get()]->size() + 1; + } + + for (size_t i = 1; i < input_size; i++) { + auto pred = inputs[i]; + if (case_flag != 0) { + pred = case_input_handle_cache_[node.get()]->at(i - 1); + } + + while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "Depend") { + pred = pred->cast()->input(1); + } + // skip the None input + if (IsValueNode(pred)) { + continue; + } + // transform "Const" op to "Variable" op when the next node is "Assign" op. 
+ std::string c_name = GetCNodeTargetFuncName(node); + auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); + if (!training_ && pos != trans_var_list.end() && pred->isa()) { + std::string name = std::static_pointer_cast(pred)->name(); + auto op_itor = op_cache_.find(pred.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; + } + if (op_itor->second != nullptr && + (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && + vars_.find(name) != vars_.end()) { + auto variable = std::make_shared(name); + auto desc = vars_[name]->GetOutputDesc("y"); + (void)variable->update_output_desc_y(desc); + MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; + } + } + // find in out_hadnle_cache_ first + auto it = out_handle_cache_.find(pred.get()); + if (it != out_handle_cache_.end()) { + int ret = adpt->setInput(src, SizeToInt(i), it->second); + if (ret == 0) { + if (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "tuple_getitem") { + compute_sout_ << op_draw_name_[pred->cast()->input(1).get()] << " -> " << op_draw_name_[node.get()] + << ":" << i << endl; + } else if (pred->isa()) { + compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; + } else { + // don't draw anything. + MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case."; + } + AddGraphConstInput(it->second.op); + } + } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) { + std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; + int ret = adpt->setInput(src, SizeToInt(i), handler_vec); + if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { + for (unsigned int j = 0; j < handler_vec->size(); j++) { + compute_sout_ << op_draw_name_[pred->cast()->input(j + 1).get()] << " -> " + << op_draw_name_[node.get()] << ":" << i << endl; + AddGraphConstInput(handler_vec->at(j).op); + } + } else { + MS_LOG(WARNING) << "Convert tuple node setInput failed : " << node->ToString(); + } + } else { + auto op = Convert(pred); + int ret = adpt->setInput(src, SizeToInt(i), op); + if (ret == 0) { + compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; + AddGraphConstInput(op); + } + } + } +} + +void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) { + if (op->GetOpType() == "Constant") { + graph_const_inputs_.push_back(op); + } +} + +void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) { + if (!node->isa()) { + return; + } + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + auto cnode = node->cast(); + OpAdapterPtr adpt = FindAdapter(cnode, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return; + } + + // get Operator from op_cache_, use adapter to set Inputs + DfGraphConvertor::SetOpInput(adpt, cnode); +} + +void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector &inputs) { + if (!node->isa() || GetCNodeFuncName(node->cast()) != "Partial") { + return; + } + auto graph_node = node->cast()->input(1)->cast(); + FuncGraphPtr anf_graph = graph_node->value()->cast(); + DfGraphConvertor convertor(anf_graph); + convertor.use_inputs_ = true; + convertor.inputs_ = inputs; + (void)convertor.ConvertAllNode().BuildGraph(); + std::string name = graph_node->ToString() + 
"_ge_graph.dot"; + if (MsContext::GetInstance()->save_graphs_flag()) { + convertor.DrawComputeGraph(name); + } + branches_map_[node.get()] = *(convertor.df_graph_); +} + +// Update GE op's shape and type info +void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) { + if (nullptr == node || !node->isa()) { + return; + } + + if (op_cache_.find(node.get()) == op_cache_.end()) { + return; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return; + } + + // get Operator from op_cache_ + OperatorPtr op = Convert(node); + + adpt->updateOutputDesc(op, node->Shape(), node->Type(), node); +} + +OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { + if (node == nullptr) { + MS_LOG(ERROR) << "node is nullptr"; + error_ = NOT_FOUND; + return nullptr; + } + // find in cache + if (op_cache_.count(node.get())) { + return op_cache_[node.get()]; + } + + // do not convert primitive node + if (IsValueNode(node)) { + return nullptr; + } + + // convert a new one + if (node->isa()) { + return ConvertCNode(node->cast()); + } + if (node->isa()) { + return ConvertParameter(node); + } + if (node->isa()) { + return ConvertValueNode(node->cast()); + } + + MS_LOG(ERROR) << "Invalide AnfNode"; + error_ = INVALID_ARGUMENT; + return nullptr; +} + +void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { + std::shared_ptr> tuple_items = std::make_shared>(); + // convert each tuple item to a OutHandler + for (size_t i = 1; i < node->inputs().size(); i++) { + AnfNodePtr item = node->input(i); + OperatorPtr op = Convert(item); + if (op != nullptr) { + tuple_items->emplace_back(OutHandler(op, "")); + } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { + tuple_items->push_back(out_handle_cache_[item.get()]); + } else { + MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << item->ToString(); + return; + } + } + + MS_LOG(WARNING) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size(); + tuple_out_handle_cache_[node.get()] = tuple_items; +} + +AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned int *index) { + const int TUPLE_GET_ITEM_INDEX = 2; + if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3"; + } + auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX]; + if (!index_node->isa()) { + error_ = INVALID_ARGUMENT; + MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; + } + *index = IntToUint(GetValue(GetValueNode(index_node))); + return node->inputs()[1]; +} + +AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) { + auto cnode = node->cast(); + if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs + MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3"; + } + return cnode->inputs()[1]; +} + +AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, unsigned int index) { + if (index + 1 >= node->inputs().size()) { + MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index; + } + return node->inputs()[index + 1]; +} + +OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack &index_stack, + AnfNode *const draw_index) { + if (node == nullptr) { + MS_LOG(ERROR) << "Get nullptr while trace real op"; + return OutHandler(nullptr, ""); + } + std::ostringstream ss; + ss << "op" << node.get(); + if (index_stack.empty()) { + op_draw_name_[draw_index] = ss.str(); + return 
OutHandler(Convert(node), ""); + } else { + OpAdapterPtr adpt = FindAdapter(node, training_); + if (nullptr == adpt) { + MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!"; + error_ = NOT_FOUND; + return OutHandler(nullptr, ""); + } + OperatorPtr op = Convert(node); + if (op == nullptr) { + error_ = NOT_FOUND; + MS_LOG(ERROR) << "Can not convert node for trace real op"; + return OutHandler(nullptr, ""); + } + op_draw_name_[draw_index] = ss.str(); + return adpt->getOutput(Convert(node), UintToInt(index_stack.top())); + } +} + +// get the real operator through maketuple tuple_getitem depend +OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) { + bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) || + IsPrimitiveCNode(node, prim::kPrimDepend); + std::stack index_stack; + auto draw_index = node.get(); + while (flag) { + flag = false; + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + unsigned int index; + node = TraceTupleGetItem(node->cast(), &index); + index_stack.push(index); + flag = true; + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + if (index_stack.empty()) { + MS_LOG(ERROR) << "TraceRealOp find a make_tuple node"; + return OutHandler(nullptr, ""); + } else { + node = TraceMakeTuple(node->cast(), index_stack.top()); + index_stack.pop(); + flag = true; + } + } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + node = TraceDepend(node->cast()); + flag = true; + } + } + return GetHandler(node, index_stack, draw_index); +} + +void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) { + auto handle = TraceRealOp(node); + if (handle.op == nullptr) { + MS_LOG(ERROR) << "Failed to trace tuple get item"; + return; + } + out_handle_cache_[node.get()] = handle; +} + +// Get the real op for tuple_getitem through make tuple, or depend +AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) { + const int TUPLE_GET_ITEM_INDEX = 2; + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + auto node_inputs = node->cast()->inputs(); + if (node_inputs.size() != 3) { // "tuple_getitem" primitive must have 3 inputs + MS_LOG(ERROR) << "tuple get item node not correct!"; + error_ = FAILED; + return node; + } + MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]); + if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa()) { + error_ = INVALID_ARGUMENT; + MS_LOG(EXCEPTION) << "can't convert get item with non-constant index"; + } + auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast(); + if (value_ptr == nullptr) { + MS_LOG(ERROR) << "Can not convert get item as value is nullptr!"; + error_ = FAILED; + return node; + } + int index = value_ptr->value(); + + // make_tuple apply inputs:make_tuple, [tuple_items,] + if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) { + auto tuple_inputs = node->cast()->inputs(); + if (tuple_inputs.size() < IntToSize(index + 1)) { + MS_LOG(ERROR) << "make tuple input items node not correct! 
size:" << tuple_inputs.size() + << ", item index:" << index; + error_ = FAILED; + return node; + } + return GetRealOpNode(tuple_inputs[IntToSize(index + 1)]); + } + return GetRealOpNode(node_inputs[1]); + } + + // depend apply inputs: depend,output,depended_node + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + auto depend_inputs = node->cast()->inputs(); + if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs + MS_LOG(ERROR) << "depend input items not correct"; + error_ = FAILED; + return node; + } + return GetRealOpNode(depend_inputs[1]); + } + return node; +} + +// convert the anf node to corresponding operator list +std::vector DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) { + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + std::vector op_lists; + auto node_inputs = node->cast()->inputs(); + for (size_t index = 1; index < node_inputs.size(); index++) { + auto op = Convert(GetRealOpNode(node_inputs[index])); + if (op == nullptr) { + MS_LOG(ERROR) << "Convert control depend node to operator failed"; + error_ = FAILED; + return std::vector({}); + } + op_lists.push_back(op); + } + return op_lists; + } + + auto op = Convert(GetRealOpNode(node)); + if (op == nullptr) { + MS_LOG(ERROR) << "Convert control depend node to operator failed"; + error_ = FAILED; + return std::vector({}); + } + return std::vector({op}); +} + +// get the anf node list for depend +std::vector DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) { + std::vector nodes; + // for make tuple, should control depend on the tuple items + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + auto node_inputs = node->cast()->inputs(); + for (size_t index = 1; index < node_inputs.size(); index++) { + nodes.push_back(GetRealOpNode(node_inputs[index])); + } + return nodes; + } + + // for parameter ,find the apply that used the parameter as the control depended node + if (node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[node]; + for (auto &use : uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) { + nodes.push_back(GetRealOpNode(use_node)); + } + } + return nodes; + } + nodes.push_back(GetRealOpNode(node)); + return nodes; +} + +void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) { +#ifdef DRAW_GE_GRAPH + auto src_depend_nodes = GetDependNodes(src_node); + auto dst_depend_nodes = GetDependNodes(dest_node); + if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) { + for (auto &item : dst_depend_nodes) { + compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()] + << "[style=\"dotted\"]" << endl; + } + } else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) { + for (auto &item : src_depend_nodes) { + compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] + << "[style=\"dotted\"]" << endl; + } + } else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) { + compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()] + << "[style=\"dotted\"]" << endl; + } +#endif +} + +void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, + const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list) { + if (src_node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[src_node]; + for (auto &use : 
uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && + (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { + auto converted_list = ConvertDependNode(use_node); + src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); + } + } + } + + if (dest_node->isa()) { + auto uses = node->func_graph()->manager()->node_users()[dest_node]; + for (auto &use : uses) { + auto use_node = use.first; + if ((use_node->isa()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) && + (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) { + auto converted_list = ConvertDependNode(use_node); + dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); + } + } + } +} + +bool DfGraphConvertor::GetControlDependList(const CNodePtr &node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list) { + const int CONTROL_DEPEND_INDEX = 0; + const int SRC_NODE_INDEX = 1; + const int DEST_NODE_INDEX = 2; + const int DEPEND_MODE_NORMAL_USE = 0; + const int DEPEND_MODE_ON_PARAMETER_USE = 1; + + auto node_inputs = node->inputs(); + if (node_inputs.size() <= DEST_NODE_INDEX) { + MS_LOG(WARNING) << "Control depend node input size error"; + return false; + } + auto src_node = node_inputs[SRC_NODE_INDEX]; + auto dest_node = node_inputs[DEST_NODE_INDEX]; + if ((src_node == nullptr) || (dest_node == nullptr)) { + MS_LOG(ERROR) << "Control depend node miss src or dest node"; + error_ = FAILED; + return false; + } + AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX]; + PrimitivePtr prim_ptr = GetValueNode(fn); + ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); + int depend_mode = DEPEND_MODE_NORMAL_USE; + if (mode_ptr != nullptr) { + auto mode_int = mode_ptr->cast(); + MS_EXCEPTION_IF_NULL(mode_int); + depend_mode = mode_int->value(); + MS_LOG(DEBUG) << "depend_mode = " << depend_mode; + } + if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) { + GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list); + } + + if (src_node->isa()) { + auto converted_list = ConvertDependNode(src_node); + src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end()); + } + + if (dest_node->isa()) { + auto converted_list = ConvertDependNode(dest_node); + dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end()); + } + if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it"; + error_ = SUCCESS; + } + return true; +} + +void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { + const int SRC_NODE_INDEX = 1; + const int DEST_NODE_INDEX = 2; + if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) { + return; + } + auto node_inputs = node->inputs(); + if (node_inputs.size() <= DEST_NODE_INDEX) { + MS_LOG(WARNING) << "Control depend node input size error"; + return; + } + auto src_node = node_inputs[SRC_NODE_INDEX]; + auto dest_node = node_inputs[DEST_NODE_INDEX]; + if ((src_node == nullptr) || (dest_node == nullptr)) { + MS_LOG(ERROR) << "Control depend node miss src or dest node"; + error_ = FAILED; + return; + } + std::shared_ptr> src_ops_list = std::make_shared>(); + std::shared_ptr> dst_ops_list = std::make_shared>(); + if (!GetControlDependList(node, src_ops_list, dst_ops_list)) { + MS_LOG(ERROR) << "Get depend list failed"; + error_ = FAILED; + return; + } + std::vector control_edges; + if 
(src_ops_list->size() == 1 && dst_ops_list->size() > 1) { + (void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges), + [src_ops_list](const OperatorPtr &op) -> ControlEdge { + return {(*src_ops_list)[0], op}; + }); + } else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) { + (void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges), + [dst_ops_list](const OperatorPtr &op) -> ControlEdge { + return {op, (*dst_ops_list)[0]}; + }); + } else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) { + control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]}); + } else if (src_ops_list->empty() || dst_ops_list->empty()) { + MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it"; + } else { + MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size() + << " -> dst:" << dst_ops_list->size(); + error_ = FAILED; + return; + } + control_depend_cache_[node.get()] = control_edges; + +#ifdef DRAW_GE_GRAPH + DrawControlDepend(src_node, dest_node); +#endif +} + +bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { + // ignore apply node of return + if (name == "return" || name == "Depend") { + return false; + } + + if (name == "" && GetCNodeFuncName(node) == "switch_layer") { + return false; + } + + if (name == "Partial") { + return false; + } + + // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers + if (name == "make_tuple") { + ConvertMakeTuple(node); + return false; + } + + // As for nodes with multi outputs, convert tuple_getitem to OutHandle + if (name == "tuple_getitem") { + ConvertTupleGetItem(node); + return false; + } + + if (name == "ControlDepend") { + ConvertControlDependNode(node); + return false; + } + + return true; +} + +OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) { + std::string name = GetCNodeTargetFuncName(node); + if (!CheckCNode(name, node)) { + return nullptr; + } + + // get corresponding OpAdapter + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return nullptr; + } + + // get operator + OperatorPtr op = nullptr; + auto it_op = op_cache_.find(node.get()); + if (it_op != op_cache_.end()) { + op = it_op->second; + } else { + op = adpt->generate(node); + } + + // set attribute for primitive + (void)adpt->setAttr(op, node); + + // add into cache + (void)op_cache_.insert(std::make_pair(node.get(), op)); + + DrawCNode(node, adpt); + + return op_cache_[node.get()]; +} + +OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) { + // convert Parameter in ANF to variable in DataFlow + auto op = FindAdapter(node, training_)->generate(node); + op_cache_[node.get()] = op; + + // build index for parameter using name + std::string name = std::static_pointer_cast(node)->name(); + params_[name] = node; + + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl; + return op_cache_[node.get()]; +} + +Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) { + MS_EXCEPTION_IF_NULL(node); + ValuePtr value = node->value(); + MS_EXCEPTION_IF_NULL(value); + if (!value->isa() && !value->isa()) { + return FAILED; + } + + auto vec = value->isa() ? 
value->cast()->value() : value->cast()->value(); + if (vec.empty()) { + return FAILED; + } + + std::shared_ptr> tuple_items = std::make_shared>(); + for (size_t i = 0; i < vec.size(); i++) { + MS_EXCEPTION_IF_NULL(vec[i]); + if (vec[i]->isa()) { + GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast(), kOpFormat_NCHW); + auto const_op = std::make_shared(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i)); + (void)const_op->set_attr_value(*ge_tensor); + (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc()); + tuple_items->emplace_back(OutHandler(const_op, "")); + } else { + return FAILED; + } + } + if (tuple_items->empty()) { + return FAILED; + } + + tuple_out_handle_cache_[node.get()] = tuple_items; + return SUCCESS; +} + +OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) { + // convert valuenode in ANF to Const in DataFlow + // find paramerte referenced by SymbolicKeyInstance of valuenode + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl; + + if (TryConvertValueNodeToMultiConst(node) == SUCCESS) { + MS_LOG(INFO) << "Convert value node to multi Constant OP success"; + return nullptr; + } + + OpAdapterPtr adpt = FindAdapter(node, training_); + if (adpt == nullptr) { + error_ = NOT_FOUND; + return nullptr; + } + auto op = adpt->generate(node); + // set const's attrs + if (adpt->setAttr(op, "value", node->value()) != 0) { + MS_LOG(WARNING) << "set attr value for const failed"; + } + +#if (defined ENABLE_GE) + auto const_op = std::static_pointer_cast(op); + if (const_op == nullptr) { + MS_LOG(ERROR) << "Get Constant operator failed"; + return nullptr; + } + auto ge_tensor = const_op->get_attr_value(); + auto ge_desc = ge_tensor.GetTensorDesc(); + (void)const_op->update_output_desc_y(ge_desc); +#endif + + op_cache_[node.get()] = op; + return op_cache_[node.get()]; +} + +void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) { + if (nullptr == adpt || nullptr == node) { + MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!"; + return; + } + std::ostringstream ss; + ss << "op" << node.get(); + op_draw_name_[node.get()] = ss.str(); + + compute_sout_ << ss.str() << "[label=<"; + compute_sout_ << "" << endl; + + auto input_map = adpt->getInputMap(); + auto dyn_input_map = adpt->getDynInputMap(); + if (input_map.size() + dyn_input_map.size() > 0) { + compute_sout_ << ""; + for (auto &it : input_map) { + compute_sout_ << ""; + } + for (auto &it : dyn_input_map) { + compute_sout_ << ""; + } + compute_sout_ << "" << endl; + } + + compute_sout_ << "" << endl; + + // print attrs' values + auto atts = adpt->GetAttrsFromDrawGraph(); + for (auto &it : atts) { + compute_sout_ << ""; + } + + adpt->clearAttrVect(); + + compute_sout_ << "
" << it.second.name << "" << it.second.name << "
\"" << node->ToString() + << ":" << GetCNodeTargetFuncName(node) << "\"
\"" << it + << "\"
> shape=plaintext]" << endl; +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h new file mode 100644 index 0000000000..6fa27831bf --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -0,0 +1,258 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ +#define MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ + +#define DRAW_GE_GRAPH + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "transform/graph_ir/util.h" +#include "ir/tensor.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "utils/config_manager.h" +#include "transform/graph_ir/op_declare.h" +#include "graph/operator_reg.h" +#ifdef OPEN_SOURCE +#include "ge/client/ge_api.h" +#else +#include "external/ge/ge_api.h" +#endif +#include "graph/tensor.h" +#include "ops/all_ops.h" + +namespace mindspore { +namespace transform { +class OpAdapterDesc { + public: + OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} + + OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} + + explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} + + OpAdapterDesc(const OpAdapterDesc &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + } + + OpAdapterDesc(OpAdapterDesc &&desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + desc.train_ = nullptr; + desc.infer_ = nullptr; + } + + ~OpAdapterDesc() = default; + + OpAdapterPtr Get(bool train) const { return train ? 
train_ : infer_; } + + OpAdapterDesc &operator=(const OpAdapterDesc &desc) { + if (this != &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + } + return *this; + } + + OpAdapterDesc &operator=(OpAdapterDesc &&desc) { + if (this != &desc) { + this->train_ = desc.train_; + this->infer_ = desc.infer_; + desc.train_ = nullptr; + desc.infer_ = nullptr; + } + return *this; + } + + private: + OpAdapterPtr train_; + OpAdapterPtr infer_; +}; + +using OpAdapterDescPtr = std::shared_ptr; +using TensorOrderMap = std::map>; + +class DfGraphConvertor { + public: + explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) + : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { +#if (!defined ENABLE_GE) || (defined ENABLE_INFER) + training_ = anf_graph->has_flag("training"); +#else + training_ = ENABLE_TRAIN; +#endif + distribute_ = anf_graph->has_flag("broadcast_flag"); + if (anf_graph->has_flag("broadcast_flag")) { + ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::DISTRIBUTION); + } else { + ConfigManager::GetInstance().set_parallel_strategy(ParallelStrategy::ONE_DEVICE); + } + + MS_LOG(INFO) << "Create DfGraphConvertor with training: " << training_ << ", distribute: " << distribute_; + } + + ~DfGraphConvertor() {} + + static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { + get_adpt_map()[name] = std::make_shared(adpt); + } + static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { + get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); + } + + void DrawComputeGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << compute_sout_.str(); + fout.close(); + } + void DrawInitGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << init_sout_.str(); + fout.close(); + } + void DrawSaveCheckpointGraph(const std::string &name) { + std::ofstream fout(name); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << name << "' failed!"; + return; + } + fout << checkpoint_sout_.str(); + fout.close(); + } + + DfGraphConvertor &ConvertAllNode(); + DfGraphConvertor &BuildGraph(); + DfGraphConvertor &InitParam(const TensorOrderMap &tensors); + DfGraphConvertor &GenerateCheckpointGraph(); + DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); + void InitParamWithData(const TensorOrderMap &tensors); + void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); + void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input); + void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); + void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); + void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); + + DfGraphPtr GetComputeGraph(); + DfGraphPtr GetInitGraph(); + DfGraphPtr GetSaveCheckpointGraph(); + DfGraphPtr GetBroadcastGraph(); + static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); + static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); + int ErrCode() const { return static_cast(error_); } + + static std::unordered_map &get_adpt_map(); + bool is_training() const { return training_; } + void set_training(bool 
is_training) { training_ = is_training; } + + protected: + void InitLoopVar(std::vector *init_input); + + private: + std::ostringstream compute_sout_; + std::ostringstream init_sout_; + std::ostringstream checkpoint_sout_; + std::ostringstream restore_checkpoint_sout_; + std::unordered_map op_draw_name_; + + AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); + AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); + AnfNodePtr TraceDepend(const CNodePtr &node); + OutHandler TraceRealOp(AnfNodePtr node); + OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); + OperatorPtr Convert(AnfNodePtr node); + OperatorPtr ConvertCNode(CNodePtr node); + std::vector ConvertDependNode(AnfNodePtr node); + AnfNodePtr GetRealOpNode(AnfNodePtr node); + std::vector GetDependNodes(const AnfNodePtr &node); + OperatorPtr ConvertParameter(AnfNodePtr node); + Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); + OperatorPtr ConvertValueNode(ValueNodePtr node); + void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); + void ConvertTupleGetItem(const CNodePtr node); + void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); + void ConvertControlDependNode(const CNodePtr node); + void ConvertMakeTuple(const CNodePtr node); + bool CheckCNode(const std::string &name, const CNodePtr node); + void TraceOutput(AnfNodePtr node); + void TraceOutputFromParameter(const AnfNodePtr &anf_out); + void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); + void SetNodeInput(AnfNodePtr node); + void SetOpControlInput(const AnfNodePtr node); + void UpdateOpDesc(AnfNodePtr node); + void SetSubgraph(AnfNodePtr node); + void ProcessSubgraph(AnfNodePtr node, const std::vector &inputs); + void BuildSaveCheckpointGraph(); + void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); + void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; + void AddGraphConstInput(const OperatorPtr &op); + + std::shared_ptr anf_graph_{nullptr}; + std::shared_ptr df_graph_{nullptr}; + std::shared_ptr init_graph_{nullptr}; + std::shared_ptr save_ckp_graph_{nullptr}; + std::shared_ptr restore_ckp_graph_{nullptr}; + std::shared_ptr broadcast_graph_{nullptr}; + std::unordered_map branches_map_; + std::unordered_map op_cache_; + std::unordered_map> control_depend_cache_; + /* record "tuple_getitem"<->"out_handler" mapping */ + std::unordered_map out_handle_cache_; + /* record "make_tuple"<->"out_handler vector" mapping */ + std::unordered_map>> tuple_out_handle_cache_; + std::unordered_map>> case_input_handle_cache_; + std::unordered_map params_; + std::unordered_map vars_; + std::vector> graph_outputs_; + std::vector graph_const_inputs_; + std::vector init_ops_; + std::vector broadcast_ops_; + std::vector inputs_; + OperatorPtr dataset_iter_getnext_; + Status error_ = SUCCESS; + bool training_ = false; + bool distribute_ = false; + bool use_inputs_ = false; +}; +} // namespace transform +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_TRANSFORM_CONVERT_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc 
new file mode 100644 index 0000000000..29985d6784 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/df_graph_manager.h" + +#include +#include +#include +#include + +#include "securec/include/securec.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/pipeline.h" +#include "utils/config_manager.h" +#ifndef NO_DLIB +#include "tdt/tsd_client.h" +#endif + +namespace mindspore { +namespace transform { +DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, + const OptionMap &options) + : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} + +DfGraphManager::DfGraphManager() { + graph_id_ = 0; + graph_runner_ptr_ = nullptr; + sess_ptr_ = nullptr; +} + +DfGraphManager::~DfGraphManager() { + // in python fisrt destroy after atexit but in c++ destoy before atexit + DeleteGraphRunner(); + DeleteGeSession(); + ClearGraph(); + parse::python_adapter::set_python_env_flag(false); +} + +DfGraphManager &DfGraphManager::GetInstance() { + static DfGraphManager instance; + return instance; +} + +int DfGraphManager::GenerateId() { + graph_id_++; + if (graph_id_ <= 0) { + graph_id_ = 1; + } + MS_LOG(INFO) << "Generate graph Id : " << graph_id_; + return graph_id_; +} + +Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { + std::lock_guard lg(lock_); + if (name.empty()) { + MS_LOG(ERROR) << "The graph name is null, add graph failed"; + return Status::INVALID_ARGUMENT; + } + + if (graph_ptr == nullptr) { + MS_LOG(WARNING) << "The new graph {" << name << "}'s pointer is null, add graph failed"; + return Status::INVALID_ARGUMENT; + } + + int id = GenerateId(); + DfGraphWrapperPtr wrap_ptr = std::make_shared(name, id, graph_ptr, options); + auto ret = graphs_.emplace(name, wrap_ptr); + if (ret.second == false) { + MS_LOG(WARNING) << "The graph name:{ " << name << " }is already exists! 
The old graph will be overwritten!!"; + ret.first->second = wrap_ptr; + } + MS_LOG(INFO) << "Add graph " << name << " to GraphManager success!"; + return Status::SUCCESS; +} + +std::vector DfGraphManager::GetAllGraphs() { + std::lock_guard lg(lock_); + std::vector ret; + std::stringstream ss; + ss << "{ "; + for (auto it = graphs_.begin(); it != graphs_.end(); ++it) { + ss << it->first << ", "; + ret.emplace_back(it->second); + } + ss << "}"; + MS_LOG(INFO) << "Return graphs: " << ss.str(); + return ret; +} +std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } + +void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } + +DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { + std::lock_guard lg(lock_); + if (name.empty()) { + MS_LOG(ERROR) << "The graph name is null"; + return nullptr; + } + + auto it = graphs_.find(name); + if (it == graphs_.end()) { + MS_LOG(INFO) << "Can't found graph name: " << name; + return nullptr; + } + MS_LOG(INFO) << "Return graph: " << name; + return it->second; +} + +void DfGraphManager::ClearGraph() noexcept { + std::lock_guard lg(lock_); + graphs_.clear(); + anf_graphs_.clear(); + MS_LOG(INFO) << "Remove all graphs in GraphManager"; +} + +void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { + DfGraphWrapperPtr df_graph = GetGraphByName(name); + if (df_graph == nullptr) { + MS_LOG(ERROR) << "Can't found graph name: " << name; + return; + } + std::lock_guard lg(lock_); + anf_graphs_[df_graph->id_] = anf_graph_ptr; +} + +AnfGraphPtr DfGraphManager::GetAnfGraph(uint32_t graph_id) { + std::lock_guard lg(lock_); + auto iter = anf_graphs_.find(graph_id); + if (iter == anf_graphs_.end()) { + MS_LOG(ERROR) << "Can't found anf graph, graph_id = " << graph_id; + return nullptr; + } + + return iter->second; +} + +void DfGraphManager::EraseAnfGraph() { + std::lock_guard lg(lock_); + anf_graphs_.clear(); +} + +void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { + std::lock_guard lg(lock_); + if (sess_ptr == nullptr) { + MS_LOG(WARNING) << "You are adding a empty Ge Session"; + } + + if (sess_ptr_ == nullptr) { + MS_LOG(INFO) << "Add a new Ge Session success"; + } else { + MS_LOG(INFO) << "Add a new Ge Session success, the old Ge Session will be overwritten!!"; + } + sess_ptr_ = sess_ptr; +} + +std::shared_ptr DfGraphManager::GetGeSession() { + std::lock_guard lg(lock_); + return sess_ptr_; +} + +void DfGraphManager::DeleteGeSession() noexcept { + std::lock_guard lg(lock_); + if (sess_ptr_ == nullptr) { + MS_LOG(INFO) << "Ge Session is not exist"; + } else { + sess_ptr_ = nullptr; + saved_graphs_.clear(); + MS_LOG(INFO) << "Delete Ge Session success"; + } +} + +void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { + std::lock_guard lg(lock_); + if (graph_runner_ptr == nullptr) { + MS_LOG(WARNING) << "You are adding a empty GraphRunner"; + } + + if (graph_runner_ptr_ == nullptr) { + MS_LOG(INFO) << "Add a new GraphRunner success"; + } else { + MS_LOG(INFO) << "Add a new GraphRunner success, the old GraphRunner will be overwritten!!"; + } + graph_runner_ptr_ = graph_runner_ptr; +} + +std::shared_ptr DfGraphManager::GetGraphRunner() { + std::lock_guard lg(lock_); + return graph_runner_ptr_; +} + +void DfGraphManager::DeleteGraphRunner() noexcept { + std::lock_guard lg(lock_); + if (graph_runner_ptr_ == nullptr) { + MS_LOG(INFO) << "GraphRunner is not exist"; + } else { + graph_runner_ptr_ = nullptr; + MS_LOG(INFO) << 
"Delete GraphRunner success"; + } +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h new file mode 100644 index 0000000000..8a574b7a04 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/df_graph_manager.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_DF_GRAPH_MANAGER_H_ +#define TRANSFORM_DF_GRAPH_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "transform/graph_ir/types.h" +#include "ir/anf.h" + +namespace mindspore { +const char BROADCAST_GRAPH_NAME[] = "broadcast_subgraph"; + +namespace transform { +class GraphRunner; +using OptionMap = std::map; + +struct DfGraphWrapper { + public: + DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); + ~DfGraphWrapper() {} + + std::string name_; + int id_; + DfGraphPtr graph_ptr_; + OptionMap options_ = {}; +}; + +using DfGraphWrapperPtr = std::shared_ptr; + +class DfGraphManager { + public: + ~DfGraphManager(); + void ClearGraph() noexcept; + + static DfGraphManager &GetInstance(); + Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); + std::vector GetAllGraphs(); + std::set GetSavedGraphs(); + void AddSavedGraphs(const std::string &id); + DfGraphWrapperPtr GetGraphByName(const std::string &name); + DfGraphManager(const DfGraphManager &) = delete; + void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); + AnfGraphPtr GetAnfGraph(uint32_t graph_id); + std::shared_ptr GetGraphRunner(); + void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; + void DeleteGraphRunner() noexcept; + void SetGeSession(const std::shared_ptr &sess_ptr); + std::shared_ptr GetGeSession(); + void DeleteGeSession() noexcept; + void EraseAnfGraph(); + + private: + DfGraphManager(); + int GenerateId(); + + std::mutex lock_; + std::map graphs_; + std::set saved_graphs_; + int graph_id_; + std::map anf_graphs_; + std::shared_ptr graph_runner_ptr_; + std::shared_ptr sess_ptr_; +}; +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_DF_GRAPH_MANAGER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/graph_builder.cc b/mindspore/ccsrc/transform/graph_ir/graph_builder.cc new file mode 100644 index 0000000000..6ee45feef8 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/graph_builder.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "transform/graph_ir/graph_builder.h"
+
+#include
+#include
+
+namespace mindspore {
+namespace transform {
+DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam &param) {
+  MS_LOG(INFO) << "BuildMDDatasetGraph.";
+
+  // InitData
+  auto d = ge::op::InitData("init_data_tmp").set_attr_channel_name(param.queue_name());
+
+  // set graph inputs & outputs
+  std::vector<Operator> inputs{d};
+  std::vector<Operator> outputs{d};
+  DfGraphPtr dataset_graph = std::make_shared<DfGraph>("dataset");
+  (void)dataset_graph->SetInputs(inputs);
+  (void)dataset_graph->SetOutputs(outputs);
+
+  return dataset_graph;
+}
+
+Status BuildDatasetGraph(const DatasetGraphParam &param, const std::string &phase) {
+  Status ret;
+  std::string graph_name = phase;
+
+  MS_LOG(INFO) << "BuildDatasetGraph begin. phase is " << phase;
+  MS_LOG(INFO) << "param is " << param.ToString() << ".";
+
+  DfGraphPtr dataset_graph = BuildMDDatasetGraph(param);
+  ret = DfGraphManager::GetInstance().AddGraph(graph_name, dataset_graph);
+  if (ret != Status::SUCCESS) {
+    MS_LOG(ERROR) << "BuildDatasetGraph failed.";
+  } else {
+    MS_LOG(INFO) << "BuildDatasetGraph end.";
+  }
+  return ret;
+}
+} // namespace transform
+} // namespace mindspore
diff --git a/mindspore/ccsrc/transform/graph_ir/graph_builder.h b/mindspore/ccsrc/transform/graph_ir/graph_builder.h
new file mode 100644
index 0000000000..5162674242
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/graph_builder.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRANSFORM_GRAPH_BUILDER_H_
+#define TRANSFORM_GRAPH_BUILDER_H_
+
+#include
+#include
+#include
+#include
+#include
+#include "transform/graph_ir/types.h"
+#include "transform/graph_ir/convert.h"
+
+namespace mindspore {
+namespace transform {
+Status BuildDatasetGraph(const DatasetGraphParam &param, const std::string &phase = "dataset");
+} // namespace transform
+} // namespace mindspore
+
+#endif // TRANSFORM_GRAPH_BUILDER_H_
diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.cc b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc
new file mode 100644
index 0000000000..d20c49a381
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.cc
@@ -0,0 +1,213 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "transform/graph_ir/graph_runner.h"
+#include
+#include
+#include
+#include "utils/log_adapter.h"
+#include "utils/config_manager.h"
+#include "sys/time.h"
+#include "utils/callbacks.h"
+#include "utils/utils.h"
+#include "./common.h"
+#ifdef ENABLE_GE
+#include "utils/callbacks_ge.h"
+#endif
+
+#ifdef NO_GE_CLIENT
+namespace ge {
+Session::Session(const std::map<std::string, std::string> &options) {
+  if (options.empty()) {
+    MS_LOG(ERROR) << "Session input options are empty";
+  }
+  sessionId_ = 0;
+}
+Session::~Session() {}
+} // namespace ge
+#endif
+
+namespace mindspore {
+namespace transform {
+std::shared_ptr<ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
+  std::shared_ptr<ge::Session> ret = std::make_shared<ge::Session>(sess_options);
+  if (ret == nullptr) {
+    MS_LOG(ERROR) << "Create GE session failed";
+    return nullptr;
+  }
+  MS_LOG(INFO) << "Create new GE session successfully";
+  return ret;
+}
+
+GraphRunner::GraphRunner(const GraphRunnerOptions &options)
+    : options_(options), graph_manager_(DfGraphManager::GetInstance()) {
+  if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
+    MS_LOG(INFO) << "ME runs in ONE_DEVICE strategy mode";
+  }
+
+  if (options.sess_ptr != nullptr) {
+    sess_ = options.sess_ptr;
+  } else {
+    sess_ = NewSession(options.options);
+    if (sess_ == nullptr) {
+      MS_LOG(EXCEPTION) << "GraphRunner initialize failed!";
+      return;
+    }
+  }
+
+#if (defined ENABLE_GE)
+  // register the callback function
+  if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ge::GRAPH_SUCCESS) {
+    MS_LOG(EXCEPTION) << "Register checkpoint callback failed!";
+    return;
+  }
+
+  if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ge::GRAPH_SUCCESS) {
+    MS_LOG(EXCEPTION) << "Register summary callback failed!";
+    return;
+  }
+#endif
+
+  std::vector<DfGraphWrapperPtr> wrappers = graph_manager_.GetAllGraphs();
+  if (wrappers.empty()) {
+    MS_LOG(INFO) << "The GraphManager is empty!";
+    return;
+  }
+
+#ifdef ENABLE_GE
+  for (auto &it : wrappers) {
+    std::set<std::string> saved_graph = graph_manager_.GetSavedGraphs();
+    auto iter_find = saved_graph.find(std::to_string(it->id_));
+    if (iter_find != saved_graph.end()) {
+      continue;
+    }
+    MS_LOG(INFO) << "Add the graph " << (*it).name_ << " to GE, its id is: " << (*it).id_;
+    graph_manager_.AddSavedGraphs(std::to_string(it->id_));
+    (void)sess_->AddGraph(it->id_, *(it->graph_ptr_), it->options_);
+  }
+#endif
+}
+
+Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs,
+                             std::vector<GeTensorPtr> *outputs) {
+  std::string name = options.name;
+  if (name.empty()) {
+    MS_LOG(ERROR) << "The graph name is null";
+    return Status::INVALID_ARGUMENT;
+  }
+
+  DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
+  if (wrap_ptr == nullptr) {
+    MS_LOG(ERROR) << "Get graph from DfGraphManager failed!";
+    return Status::NOT_FOUND;
+  }
+
+  if (wrap_ptr->graph_ptr_ == nullptr) {
+    MS_LOG(WARNING) << "The graph is null";
+    return Status::NOT_FOUND;
+  }
+
+  // Call ge::Session::RunGraph() to execute the graph.
+  std::vector<GeTensor> ge_inputs;
+  std::vector<GeTensor> ge_outputs;
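+  // The GE session API consumes GeTensor values, so each shared GeTensorPtr input is
+  // dereferenced into a by-value copy before the call below.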
+  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs),
+                       [](const GeTensorPtr &i) { return *i; });
+
+  MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs";
+
+  struct timeval start_time, end_time;
+  (void)gettimeofday(&start_time, nullptr);
+
+#ifdef ENABLE_GE
+  if (sess_ == nullptr) {
+    MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
+    return Status::FAILED;
+  }
+
+  // The information of some nodes may change after fusion in some cases,
+  // so the graph needs to be rebuilt in that situation.
+  if (sess_->IsGraphNeedRebuild(wrap_ptr->id_)) {
+    sess_->RemoveGraph(wrap_ptr->id_);
+    sess_->AddGraph(wrap_ptr->id_, *(wrap_ptr->graph_ptr_), wrap_ptr->options_);
+  }
+
+  ge::Status ret = sess_->RunGraph(wrap_ptr->id_, ge_inputs, ge_outputs);
+  if (ret != ge::GRAPH_SUCCESS) {
+    MS_LOG(ERROR) << "Call GE RunGraph failed, ret is: " << ret;
+    return Status::FAILED;
+  }
+#else
+  ge_outputs.swap(ge_inputs);
+#endif
+
+  (void)gettimeofday(&end_time, nullptr);
+  const uint64_t kUSecondInSecond = 1000000;
+  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
+  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
+  MS_LOG(INFO) << "Call GE RunGraph succeeded in " << cost << " us, the GE outputs num is: " << ge_outputs.size();
+
+  (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs),
+                       [](const GeTensor &ge_tensor) { return std::make_shared<GeTensor>(ge_tensor); });
+
+  return Status::SUCCESS;
+}
+
+Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs,
+                             std::vector<MeTensorPtr> *const outputs) {
+  std::vector<GeTensorPtr> ge_inputs;
+  for (auto it : inputs) {
+    MS_LOG(INFO) << "Input tensor's data size is: " << (*it).DataSize();
+    auto shape = (*it).shape();
+    std::string shape_str;
+    for (const auto &elem : shape) {
+      shape_str += std::to_string(elem);
+      shape_str += " ";
+    }
+    MS_LOG(INFO) << "Input tensor's shape is: { " << shape_str << "}";
+
+    auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW);
+    if (ge_tensor_ptr != nullptr) {
+      ge_inputs.emplace_back(ge_tensor_ptr);
+    } else {
+      MS_LOG(INFO) << "Convert input ME tensor to GE tensor failed; abort this graph";
+      return Status::FAILED;
+    }
+  }
+
+  std::vector<GeTensorPtr> ge_outputs;
+  Status ret;
+  {
+    // Release GIL before calling into (potentially long-running) C++ code
+    py::gil_scoped_release release;
+    ret = RunGraph(options, ge_inputs, &ge_outputs);
+  }
+  if (ret != Status::SUCCESS) {
+    return ret;
+  } else {
+    // Convert GeTensor outputs back to MeTensor.
+    for (auto &it : ge_outputs) {
+      auto tensor = TransformUtil::ConvertGeTensor(it);
+      if (tensor != nullptr) {
+        outputs->emplace_back(tensor);
+      }
+    }
+    MS_LOG(INFO) << "Return ME tensor outputs num is: " << outputs->size();
+    return Status::SUCCESS;
+  }
+}
+} // namespace transform
+} // namespace mindspore
diff --git a/mindspore/ccsrc/transform/graph_ir/graph_runner.h b/mindspore/ccsrc/transform/graph_ir/graph_runner.h
new file mode 100644
index 0000000000..92db9e1413
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/graph_runner.h
@@ -0,0 +1,63 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRANSFORM_GRAPH_RUNNER_H_
+#define TRANSFORM_GRAPH_RUNNER_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "transform/graph_ir/types.h"
+#include "transform/graph_ir/util.h"
+#include "ir/tensor.h"
+#include "transform/graph_ir/df_graph_manager.h"
+
+namespace mindspore {
+namespace transform {
+using SessionOptions = std::map<std::string, std::string>;
+
+struct GraphRunnerOptions {
+  std::string target{"default_graph_runner"};
+  SessionOptions options;
+  // if sess_ptr is nullptr, GraphRunner will create a new GE session
+  std::shared_ptr<ge::Session> sess_ptr{nullptr};
+};
+
+struct RunOptions {
+  // graph's name
+  std::string name;
+};
+
+class GraphRunner {
+ public:
+  explicit GraphRunner(const GraphRunnerOptions &options);
+  ~GraphRunner() { sess_ = nullptr; }
+  Status RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs);
+  Status RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs, std::vector<MeTensorPtr> *outputs);
+  static std::shared_ptr<ge::Session> NewSession(const SessionOptions &sess_options);
+
+ private:
+  std::shared_ptr<ge::Session> sess_;
+  transform::GraphRunnerOptions options_;
+  DfGraphManager &graph_manager_;
+};
+} // namespace transform
+} // namespace mindspore
+
+#endif // TRANSFORM_GRAPH_RUNNER_H_
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.h b/mindspore/ccsrc/transform/graph_ir/op_adapter.h
new file mode 100644
index 0000000000..358cbd20a1
--- /dev/null
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.h
@@ -0,0 +1,913 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TRANSFORM_OP_ADAPTER_H_
+#define TRANSFORM_OP_ADAPTER_H_
+
+#include
+#include
+#include
+#include
+
+#include "transform/graph_ir/op_adapter_util.h"
+#include "utils/utils.h"
+namespace mindspore {
+namespace transform {
+static uint32_t CustomInferFunc(const Operator &) { return 0; }
+
+template <typename T>
+class OpAdapter : public BaseOpAdapter {
+ public:
+  using OpType = T;
+  OpAdapter() {}
+  explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {}
+  ~OpAdapter() override {}
+
+  bool IsCustomOp(const OperatorPtr &op) {
+    MS_EXCEPTION_IF_NULL(op);
+    auto it = cus_input_map_.find(op->GetOpType());
+    if (it == cus_input_map_.end()) {
+      return false;
+    }
+    return true;
+  }
+
+  Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) {
+    MS_EXCEPTION_IF_NULL(op);
+    MS_EXCEPTION_IF_NULL(prim);
+    // Create the map of custom op from input index to input name.
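+    // Note: the map keys are 1-based (input_map[i + 1] below) to match GE's input numbering.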
+ std::unordered_map input_map; + auto value = prim->GetAttr("input_names"); + if (value == nullptr) { + cus_output_map_[prim->name()] = input_map; + return NOT_FOUND; + } + + auto input_names = GetValue>(value); + for (size_t i = 0; i < input_names.size(); ++i) { + // input_map begin form 1 + input_map[i + 1] = input_names[i]; + op->CustomInputRegister(input_names[i]); + } + + if (cus_input_map_.find(prim->name()) == cus_input_map_.end()) { + cus_input_map_[prim->name()] = input_map; + } + return SUCCESS; + } + + Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(prim); + // Create the map of custom op from output index to output name. + std::unordered_map output_map; + auto value = prim->GetAttr("output_names"); + if (value == nullptr) { + // generate a empty output_map for it + cus_output_map_[prim->name()] = output_map; + return NOT_FOUND; + } + + auto output_names = GetValue>(value); + for (size_t i = 0; i < output_names.size(); ++i) { + // output_map begin form 0 + output_map[i] = output_names[i]; + op->CustomOutputRegister(output_names[i]); + } + + if (cus_output_map_.find(prim->name()) == cus_output_map_.end()) { + cus_output_map_[prim->name()] = output_map; + } + return SUCCESS; + } + + // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. + OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { + MS_EXCEPTION_IF_NULL(anf); + auto node = anf->cast(); + if (node == nullptr) { + return nullptr; + } + + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "length of node inputs is empty"; + } + + auto prim = GetValueNode(node->inputs()[0]); + MS_EXCEPTION_IF_NULL(prim); + auto op = std::make_shared(node->fullname_with_scope(), prim->name()); + if (GenerateCustomOpInputMap(op, prim) != SUCCESS) { + MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "]."; + } + + if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) { + MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "]."; + } + + op->CustomInferFuncRegister(CustomInferFunc); + + return op; + } + + OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { + OperatorPtr op = nullptr; + // There are duplicate names in ANF graph, do not assign ANF node name to GE + // GE will generate unique name automatically + if (anf != nullptr && anf->fullname_with_scope() != "") { + MS_LOG(DEBUG) << anf->fullname_with_scope(); + op = std::make_shared(anf->fullname_with_scope()); + } else { + MS_LOG(DEBUG) << "no fullname_with_scope"; + op = std::make_shared(); + } + + // set dynamic output num if op use DYNAMIC_OUTPUT + if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { + TypePtr type = anf->Type(); + if (type == nullptr) { + MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; + } + size_t num = type->isa() ? 
(type->cast>()->size()) : 1; + MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString() + << ", num:" << num; + dyn_output_map_.begin()->second.create_dyn_output(op, static_cast(num)); + } + return op; + } + + OperatorPtr generate(const AnfNodePtr &anf) override { + OperatorPtr op = nullptr; + if (IsCustomCNode(anf)) { + op = GenerateCustomOp(anf); + } else { + op = GenerateNormalOp(anf); + } + return op; + } + + OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } + + const std::unordered_map &getInputMap() override { return input_map_; } + const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } + const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } + const std::unordered_map &getOutputMap() override { return output_map_; } + const std::unordered_map &getDynSubgraphMap() override { return dyn_subgraph_map_; } + + Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr> branches) { + MS_EXCEPTION_IF_NULL(op); + auto it = dyn_subgraph_map_.find(index); + if (it != dyn_subgraph_map_.end()) { + auto size = branches->size(); + it->second.create_dyn_subgraph(op, static_cast(size)); + for (size_t i = 0; i < size; i++) { + it->second.set_subgraph(op, static_cast(i), std::make_shared((*branches)[i])); + } + return SUCCESS; + } + return NOT_FOUND; + } + + int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) override { + return static_cast(SetOpSubgraphFunc(op, index, branches)); + } + + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(input); + auto it = cus_input_map_.find(op->GetOpType()); + if (it == cus_input_map_.end()) { + return NOT_FOUND; + } + std::unordered_map &input_map = it->second; + + if ((input_map.find(index) != input_map.end())) { + MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; + (void)op->SetInput(input_map[index], *input); + return SUCCESS; + } + return NOT_FOUND; + } + + Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { + MS_EXCEPTION_IF_NULL(op); + auto it = input_map_.find(index); + if (it != input_map_.end()) { + MS_EXCEPTION_IF_NULL(input); + MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, input); + return SUCCESS; + } + return NOT_FOUND; + } + + int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + return static_cast(SetCustomOpInput(cus_op, index, input)); + } else { + return static_cast(SetNormalOpInput(op, index, input)); + } + } + + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_input_map_.find(op->GetOpType()); + if (it == cus_input_map_.end()) { + return NOT_FOUND; + } + + std::unordered_map &input_map = it->second; + if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { + if (handle.out.empty()) { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; + (void)op->SetInput(input_map[index], *(handle.op)); + } else { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" + << input_map[index]; + 
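+ // Custom op inputs are wired by registered input name (see CustomInputRegister), optionally qualified by the producer's named output.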
(void)op->SetInput(input_map[index], *(handle.op), handle.out); + } + return SUCCESS; + } + return NOT_FOUND; + } + + Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { + MS_EXCEPTION_IF_NULL(op); + auto it = input_map_.find(index); + if ((handle.op != nullptr) && (it != input_map_.end())) { + if (handle.out.empty()) { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, handle.op); + } else { + MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" + << it->second.name; + it->second.set_handle(op, handle); + } + return SUCCESS; + } + return NOT_FOUND; + } + + int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + return static_cast(SetCustomOpInput(cus_op, index, handle)); + } else { + return static_cast(SetNormalOpInput(op, index, handle)); + } + } + + int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { + MS_EXCEPTION_IF_NULL(handler_vec); + if (IsCustomOp(op)) { + MS_LOG(ERROR) << "Custom Op do not support dynamic input"; + return static_cast(FAILED); + } + MS_EXCEPTION_IF_NULL(op); + auto it = dyn_input_map_.find(index); + if (it != dyn_input_map_.end()) { + it->second.create_dyn_input(op, static_cast(handler_vec->size())); + for (unsigned int i = 0; i < handler_vec->size(); ++i) { + OutHandler h = (*handler_vec)[i]; + MS_EXCEPTION_IF_NULL(h.op); + if (h.out.empty()) { + MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name; + it->second.set_op(op, (i) /* index start from 0 */, h.op); + } else { + MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":" + << it->second.name; + it->second.set_handle(op, i, h); + } + } + return 0; + } + return static_cast(NOT_FOUND); + } + + OutHandler getOutput(const OperatorPtr &op, int index) override { + MS_EXCEPTION_IF_NULL(op); + if (IsCustomOp(op)) { + return getCustomOutput(op, index); + } + return getNormalOutput(op, index); + } + + OutHandler getCustomOutput(const OperatorPtr &op, int index) { + MS_EXCEPTION_IF_NULL(op); + auto it = cus_output_map_.find(op->GetOpType()); + if (it == cus_output_map_.end()) { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!"; + return OutHandler(); + } + + std::unordered_map &output_map = it->second; + + if ((output_map.find(index) != output_map.end())) { + return OutHandler(op, output_map[index]); + } + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!"; + return OutHandler(); + } + + OutHandler getNormalOutput(const OperatorPtr &op, int index) { + MS_EXCEPTION_IF_NULL(op); + if (!dyn_output_map_.empty() && !output_map_.empty()) { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; + return OutHandler(); + } + auto it = output_map_.find(index); + if (it != output_map_.end()) { + return OutHandler(op, it->second.name); + } else if (!dyn_output_map_.empty()) { + return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index)); + } else { + MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!"; + return OutHandler(); + } + } + + Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const 
TypePtr &type) { + MS_EXCEPTION_IF_NULL(type); + std::string format = "NCHW"; + if (op->GetOpType() == kExtractImagePatchesOpName) { + format = "NHWC"; + } + + auto desc = CreateOutputDesc(dyn_cast(shp), type, format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update output descriptor failed!"; + return FAILED; + } + + if (IsCustomOp(op)) { + if (cus_output_map_.find(op->GetOpType()) == cus_output_map_.end() || + (cus_output_map_[op->GetOpType()].empty())) { + MS_LOG(ERROR) << "This op does not create custom output map"; + return FAILED; + } + auto cus_op = std::dynamic_pointer_cast(op); + MS_EXCEPTION_IF_NULL(cus_op); + std::unordered_map output_map = cus_output_map_[op->GetOpType()]; + (void)cus_op->UpdateOutputDesc(output_map[0], *desc); + } else { + if (output_map_.empty()) { + MS_LOG(INFO) << "This op does not have output map"; + return FAILED; + } + output_map_.begin()->second.update_out_desc(op, *desc); + } + return SUCCESS; + } + + size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { + MS_EXCEPTION_IF_NULL(cus_op); + if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { + MS_LOG(ERROR) << "This op does not create custom output map"; + return 0; + } + size_t output_size = cus_output_map_[cus_op->GetOpType()].size(); + return output_size; + } + + std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, + const std::string &format) { + if (shape_ptr == nullptr) { + MS_LOG(ERROR) << "Shape ptr is nullptr"; + return nullptr; + } + + if (type == nullptr) { + MS_LOG(ERROR) << "Type ptr is nullptr"; + return nullptr; + } + + TypeId me_type = type->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(type)->element()->type_id(); + } + auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format); + return desc; + } + + Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { + auto tuple_shp = dyn_cast(shp); + MS_EXCEPTION_IF_NULL(tuple_shp); + + size_t output_size = 0; + bool is_custom_op = IsCustomOp(op); + if (is_custom_op) { + output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast(op)); + } else { + output_size = output_map_.size(); + } + + if (output_size == 0) { + MS_LOG(INFO) << "This op does not have output map"; + return FAILED; + } + + if (output_size != tuple_shp->shape().size()) { + MS_LOG(ERROR) << "output_map is not equal tuple_shape size"; + return FAILED; + } + std::string format = "NCHW"; + if (op->GetOpType() == kTopKOpName) { + format = "NHWC"; + } + for (size_t i = 0; i < tuple_shp->shape().size(); ++i) { + auto tuple_type = dyn_cast(type); + MS_EXCEPTION_IF_NULL(tuple_type); + TypePtr type_elem = tuple_type->elements()[i]; + + auto desc = CreateOutputDesc(dyn_cast(tuple_shp->shape()[i]), type_elem, format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create output descriptor failed!"; + return FAILED; + } + + if (is_custom_op) { + (void)std::dynamic_pointer_cast(op)->UpdateOutputDesc(cus_output_map_[op->GetOpType()][i], + *desc); + } else { + auto it = output_map_.find(i); + if (it != output_map_.end()) { + it->second.update_out_desc(op, *desc); + } + } + } + return SUCCESS; + } + + std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TypeId me_type = node->Type()->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(node->Type())->element()->type_id(); + } + if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) { + return nullptr; + } + + 
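+ // If the node carries no concrete shape, the descriptor defaults to an empty (scalar) shape; NCHW is assumed as the layout here.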
std::vector shape; + auto shape_ptr = dyn_cast(node->Shape()); + if (nullptr != shape_ptr) { + shape = shape_ptr->shape(); + } + + auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); + if (desc == nullptr) { + MS_LOG(ERROR) << "Update output descriptor failed!"; + return nullptr; + } + return desc; + } + + void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + + auto inputs = node->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + auto it = input_map_.find(i); + if (it != input_map_.end()) { + auto desc = CreateNodeDesc(inputs[i]); + if (desc == nullptr) { + continue; + } + if (op->GetOpType() == kExtractImagePatchesOpName) { + desc->SetFormat(ge::Format::FORMAT_NHWC); + } + it->second.update_input_desc(op, *desc); + } + } + } + + void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + + if (cus_input_map_.find(op->GetOpType()) == cus_input_map_.end() || (cus_input_map_[op->GetOpType()].empty())) { + MS_LOG(ERROR) << "This op does not create custom input map"; + return; + } + + std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; + auto inputs = node->cast()->inputs(); + for (size_t i = 1; i < inputs.size(); ++i) { + if (input_map.find(i) != input_map.end()) { + auto desc = CreateNodeDesc(inputs[i]); + if (desc == nullptr) { + continue; + } + (void)op->UpdateInputDesc(input_map[i], *desc); + } + } + } + + void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(op); + MS_EXCEPTION_IF_NULL(node); + if (IsCustomOp(op)) { + auto cus_op = std::dynamic_pointer_cast(op); + UpdateCustomOpInputDesc(cus_op, node); + } else { + UpdateNormalOpInputDesc(op, node); + } + } + + void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) override { + if (op == nullptr) { + MS_LOG(ERROR) << "op is nullptr"; + return; + } + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Op name is " << op->GetName(); + + auto normal_shape_ptr = dyn_cast(shp); + auto no_shape_ptr = dyn_cast(shp); + + if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) { + if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) { + return; + } + } else if (nullptr != dyn_cast(shp)) { + if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) { + return; + } + } else { + MS_LOG(WARNING) << "Update output desc failed, unknow output shape type"; + return; + } + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return; + } + + // Need to update input_desc while the output_desc is updated + updateInputDesc(op, node); + } + + int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { + auto it = attr_map_.find(attrKey); + if (it != attr_map_.end()) { + // switch case for each avalilable attribute type + MS_LOG(INFO) << "Set attr: " << attrKey << "(" << it->second.name << "), value: " << attrValue->ToString(); + AddAttrToDrawGraph(attrKey + std::string("=") + attrValue->ToString()); + it->second.set_attr(op, attrValue); + return 0; + } + return static_cast(NOT_FOUND); + } + + int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { + enum ValueType { + SINGLE_VALUE = 0, + SEQUEUE_VALUE, + UNKNOWN_VALUE, + }; + + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(op); + + ValueType 
value_type = SINGLE_VALUE; + for (auto item : prim->attrs()) { + if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + (void)op->SetAttr(item.first, GetValue(item.second)); + } else if (item.second->isa()) { + value_type = SEQUEUE_VALUE; + auto val_seq = item.second->cast(); + if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else if ((*val_seq)[0]->isa()) { + (void)op->SetAttr(item.first, GetValue>(item.second)); + } else { + MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() + << ", attr name: " << item.first << ", value: " << item.second->ToString(); + } + } else { + value_type = UNKNOWN_VALUE; + MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() + << ", attr name: " << item.first << ", value: " << item.second->ToString(); + return static_cast(NOT_FOUND); + } + + if (value_type == SINGLE_VALUE) { + AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString()); + } else if (value_type == SEQUEUE_VALUE) { + AddAttrToDrawGraph(item.first + std::string("=") + "[...]"); + } + } + return 0; + } + + int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { + int ret = 0; + MS_EXCEPTION_IF_NULL(prim); + MS_EXCEPTION_IF_NULL(op); + for (auto &it : attr_map_) { + auto value = prim->GetAttr(it.first); + if (value != nullptr) { + // set attr from primitive + ret = setAttr(op, it.first, value); + if (ret) { + return ret; + } + } else { + // set attr from extra_attr + auto it_extra = extra_attr_.find(it.first); + if (it_extra != extra_attr_.end()) { + ret = setAttr(op, it.first, it_extra->second); + if (ret) { + return ret; + } + } + } + } + return 0; + } + + int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { + int ret = 0; + if (IsCustomPrim(prim)) { + auto cus_op = std::dynamic_pointer_cast(op); + ret = SetCustomOpAttr(cus_op, prim); + } else { + ret = SetNormalOpAttr(op, prim); + } + return ret; + } + + int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { + // no attribute for lonely node + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return 0; + } + + auto cnode = node->cast(); + if (cnode == nullptr) { + return 0; + } + + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + return 0; + } + + // get Attr T from abstract of anfnode first, + // if attr "T" appears in primitive, the primitive T will cover this one + if (attr_map_.find("T") != attr_map_.end()) { + // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype + TypePtr type; + if (inputs.size() > 1) { + type = inputs[1]->Type(); + } else { + type = node->Type(); + } + if (type != nullptr) { + (void)setAttr(op, "T", MakeValue(type)); + } + } + + // set attr from primitive and ExtraAttr + if (IsValueNode(inputs[0])) { + // set attr from primitive + PrimitivePtr prim = GetValueNode(inputs[0]); + int ret = setAttr(op, prim); + if (ret != 0) { + return ret; + } + } + + // set attr from const input + for (auto &it : input_attr_map_) { + if (inputs.size() <= it.first || 
!inputs[it.first]->isa()) { + continue; + } + auto const_value = GetValueNode(inputs[it.first]); + MS_LOG(INFO) << "Set attr: input_" << it.first << "(" << it.second.name + << "), value: " << const_value->ToString(); + if (const_value->isa()) { + continue; + } + AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString()); + it.second.set_attr(op, const_value); + } + return 0; + } + + std::unordered_map GetExtraAttr() override { return extra_attr_; } + + private: + template + static S ConvertAny(const ValuePtr &value, const AnyTraits &) { + return GetValue(value); + } + + // specialization for reverse bool + static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { + return reverse != GetValue(value); + } + + template + static Q ConvertAny(const ValuePtr &value, const AnyTraits
<P>
&traits_from, const AnyTraits &traits_to) { + return ConvertAnyUtil(value, traits_from, traits_to); + } + + // specialization for tensor + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { + // To-DO the format may read from ME tensor + return ConvertAnyUtil(value, traits); + } + + // specialization for int + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { + return static_cast(GetValue(value)); + } + + // specialization for int or tuple broadcast to Vector + static std::vector ConvertAny(const ValuePtr &value, const std::string &name, + const AnyTraits> anyTraitsInt) { + return ConvertAnyUtil(value, name, anyTraitsInt); + } + + static std::vector> ConvertAny(const ValuePtr &value, + const AnyTraits>>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Value: " << value->type_name(); + std::vector> list; + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name(); + } + auto vec = value->cast(); + MS_EXCEPTION_IF_NULL(vec); + for (auto &it : vec->value()) { + MS_EXCEPTION_IF_NULL(it); + if (!it->isa()) { + MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); + } + auto sub_vector = it->cast(); + std::vector sublist; + for (auto &item : sub_vector->value()) { + sublist.push_back(static_cast(GetValue(item))); + } + list.push_back(sublist); + } + return list; + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(DEBUG) << "Value: " << value->type_name(); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name(); + } + auto vec = value->cast(); + std::vector list; + for (auto &it : vec->value()) { + MS_EXCEPTION_IF_NULL(it); + if (!it->isa()) { + MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); + } + auto sub_vector = it->cast(); + for (auto &item : sub_vector->value()) { + list.push_back(static_cast(GetValue(item))); + } + } + return list; + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(INFO) << "Value: " << value->type_name(); + std::vector list; + if (value->isa()) { + auto vec = value->cast(); + MS_EXCEPTION_IF_NULL(vec); + for (auto &it : vec->value()) { + list.push_back(static_cast(GetValue(it))); + } + return list; + } + if (value->isa()) { + list.push_back(static_cast(GetValue(value))); + return list; + } + MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); + } + + static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsStr) { + return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); + } + + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsFlo) { + return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); + } + + static std::vector ConvertAny(const ValuePtr &value, const std::string &format, + const AnyTraits> anyTraitsVec, + const AnyTraits anyTraitsInt) { + return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); + } + + // convert value list for value tuple to vector + template + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits
<P>
&anyTraitsP, + const AnyTraits> anyTraitsQ) { + return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); + } + + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { + auto name = GetValue(value); + auto it = enum_map_.find(name); + int v = 0; + if (it != enum_map_.end()) { + v = it->second; + } + return v; + } + + static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { + return ConvertAnyUtil(value, anyTraitsGE); + } + + // convert any value to tensor + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { + return ConvertAnyUtil(value, anyTraitsValue); + } + + static const std::unordered_map input_map_; + static const std::unordered_map dyn_input_map_; + static const std::unordered_map output_map_; + static const std::unordered_map dyn_output_map_; + static const std::unordered_map dyn_subgraph_map_; + static const std::unordered_map attr_map_; + static const std::unordered_map enum_map_; + // convert input from anf graph to Attr in Operators + static const std::unordered_map input_attr_map_; + static std::unordered_map> cus_input_map_; + static std::unordered_map> cus_output_map_; + std::unordered_map extra_attr_; + std::unordered_map name_counts_; +}; + +template +const std::unordered_map OpAdapter::input_map_; +template +const std::unordered_map OpAdapter::dyn_input_map_; +template +const std::unordered_map OpAdapter::output_map_; +template +const std::unordered_map OpAdapter::dyn_output_map_; +template +const std::unordered_map OpAdapter::dyn_subgraph_map_; +template +const std::unordered_map OpAdapter::attr_map_; +template +const std::unordered_map OpAdapter::enum_map_; +template +const std::unordered_map OpAdapter::input_attr_map_; +template +std::unordered_map> OpAdapter::cus_input_map_; +template +std::unordered_map> OpAdapter::cus_output_map_; + +// specialization for method +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_OP_ADAPTER_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h new file mode 100644 index 0000000000..77e28dda94 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h @@ -0,0 +1,198 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TRANSFORM_OP_ADAPTER_BASE_H_ +#define TRANSFORM_OP_ADAPTER_BASE_H_ + +#include +#include +#include +#include +#include +#include + +#include "transform/graph_ir/util.h" +#include "ir/anf.h" +#include "ir/primitive.h" +#include "ir/value.h" +#include "transform/graph_ir/types.h" +#ifdef ENABLE_GE +#ifdef OPEN_SOURCE +#include "graph/types.h" +#endif +#endif + +#include "graph/operator_reg.h" +#ifdef OPEN_SOURCE +#include "ge/client/ge_api.h" +#else +#include "external/ge/ge_api.h" +#endif +#include "graph/tensor.h" +#include "transform/graph_ir/all_ops.h" + +namespace ge { +class CustomOperator : public Operator { + public: + CustomOperator(const string &name, const string &type) : Operator(name, type) {} + + ~CustomOperator() override{}; + + void CustomInputRegister(const string &name) { Operator::InputRegister(name); } + + void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } + + void CustomInferFuncRegister(const std::function &func) { + Operator::InferFuncRegister(func); + } +}; +} // namespace ge + +namespace mindspore { +namespace transform { +using CusOperatorPtr = std::shared_ptr; +using CustomOperator = ge::CustomOperator; + +struct OutHandler { + OperatorPtr op; + std::string out; + OutHandler() : op(nullptr), out("") {} + OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} +}; + +struct ControlEdge { + OperatorPtr src_op; + OperatorPtr dest_op; +}; + +using AttrFunc = std::function; +using OutputFunc = std::function; +using InputOpFunc = std::function; +using InputHandleFunc = std::function; +using CreateDynInputOpFunc = std::function; +using DynInputOpFunc = std::function; +using DynInputHandleFunc = std::function; +using UpdateOutputDescFunc = std::function; +using CreateDynOutputOpFunc = std::function; +using CreateDynSubGraphFunc = std::function; +using DynSubGraphFunc = std::function; + +struct AttrDesc { + std::string name; + AttrFunc set_attr; +}; + +struct InputDesc { + std::string name; + InputOpFunc set_op; + InputHandleFunc set_handle; + UpdateOutputDescFunc update_input_desc; +}; + +struct DynInputDesc { + std::string name; + CreateDynInputOpFunc create_dyn_input; + DynInputOpFunc set_op; + DynInputHandleFunc set_handle; +}; + +struct DynSubGraphDesc { + std::string name; + CreateDynSubGraphFunc create_dyn_subgraph; + DynSubGraphFunc set_subgraph; +}; + +struct OutputDesc { + std::string name; + UpdateOutputDescFunc update_out_desc; +}; + +struct DynOutputDesc { + std::string name; + CreateDynOutputOpFunc create_dyn_output; +}; + +class BaseOpAdapter { + public: + virtual ~BaseOpAdapter() {} + virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; + virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } + virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; + virtual int setInput(const OperatorPtr &op, int index, + const std::shared_ptr> &handler_vec) = 0; + virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; + virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; + virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; + virtual std::unordered_map GetExtraAttr() = 0; + template ::value>::type> + int setAttr(const OperatorPtr &op, const std::string &attrKey, const 
std::shared_ptr &attrValue) { + return setAttr(op, attrKey, MakeValue(attrValue)); + } + template ::value>::type> + int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { + return setAttr(op, attrKey, MakeValue(attrValue)); + } + virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; + virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) = 0; + virtual const std::unordered_map &getInputMap() = 0; + virtual const std::unordered_map &getInputAttrMap() = 0; + virtual const std::unordered_map &getDynInputMap() = 0; + virtual const std::unordered_map &getOutputMap() = 0; + virtual const std::unordered_map &getDynSubgraphMap() = 0; + void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } + const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } + void clearAttrVect() { attrs_vec_.clear(); } + + private: + std::vector attrs_vec_; +}; + +using OpAdapterPtr = std::shared_ptr; + +enum AttrType { + ATTR_INT = 0, + ATTR_FLOAT, + ATTR_DOUBLE, + ATTR_STRING, + ATTR_TENSOR, + ATTR_BOOL, + ATTR_LIST_INT, + ATTR_LIST_ANY_INT, + ATTR_ENUM +}; + +struct GeEnum {}; +struct TFType {}; +struct GEType {}; + +// declare Any type +template +struct AnyTraits { + using type = T; +}; + +template <> +struct AnyTraits { + using type = int64_t; +}; + +using ExtraAttr = std::unordered_map; +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_ADAPTER_BASE_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc new file mode 100644 index 0000000000..78f1f263de --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc @@ -0,0 +1,264 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/op_adapter_util.h" + +#include +#include +#include + +#include "utils/utils.h" +#include "transform/graph_ir/op_adapter_base.h" + +namespace mindspore { +namespace transform { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { + // To-DO the format may read from ME tensor + MS_EXCEPTION_IF_NULL(value); + auto me_tensor = value->cast(); + auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); + return ge_tensor == nullptr ? 
GeTensor() : *ge_tensor; +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, + const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); + std::vector list; + if (name == "pad") { + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); + } + auto vec = value->cast(); + list.resize(vec->value().size() + 2); + list[0] = 1; + list[1] = 1; + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2, + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + } else { + int64_t data = GetValue(value); + int size = 2; // 2 int in list + list = TransformUtil::ConvertIntToList(data, size); + } + + return list; +} + +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::ostringstream buffer; + int i = 0; + for (auto &it : vec->value()) { + if (i != 0) { + buffer << ","; + } + buffer << GetValue(it); + i++; + } + return buffer.str(); +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::vector list; + list.resize(vec->value().size()); + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + return list; +} + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, + const AnyTraits>, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + auto vec = value->cast(); + if (nullptr == vec) { + MS_LOG(EXCEPTION) << "not ValueTuplePtr"; + } + std::vector list; + list.resize(vec->value().size()); + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); + if (format == kOpFormat_NHWC) { + if (list.size() < 4) { + MS_LOG(EXCEPTION) << "The size of list is less than 4"; + } else { + int64_t temp = list[1]; + list[1] = list[2]; + list[2] = list[3]; + list[3] = temp; + } + } + return list; +} + +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + if (!value->isa()) { + MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() + << ", type: " << value->type_name() << ", value should be a Typeptr"; + } + auto type = value->cast(); + MS_EXCEPTION_IF_NULL(type); + TypeId me_type = type->type_id(); + if (kObjectTypeTensorType == me_type) { + me_type = dyn_cast(type)->element()->type_id(); + } + return TransformUtil::ConvertDataType(me_type); +} + +GeTensor VectorToTensorUtil(const ValuePtr &value) { + // convert tuple or list to ge tensor, only supported one dim for now + MS_EXCEPTION_IF_NULL(value); + auto vec = value->isa() ? 
value->cast()->value() : value->cast()->value(); + if (vec.empty()) { + MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor"; + return GeTensor(); + } + MS_EXCEPTION_IF_NULL(vec[0]); + if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Int32"; + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); + } else if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Float32"; + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); + } else if (vec[0]->isa()) { + MS_LOG(INFO) << "convert value to tensor with data type = Bool"; + // We use uint8_t to save bool type data + auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); + auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeBool, kOpFormat_NCHW); + if (desc == nullptr) { + MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; + } + return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); + } else { + MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); + } + + return GeTensor(); +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + // convert me tensor to ge tensor + return ConvertAnyUtil(value, AnyTraits()); + } else if (value->isa() || value->isa()) { + return VectorToTensorUtil(value); + } else if (value->isa()) { + // convert scalar Int to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Int32"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); + } else if (value->isa()) { + // convert scalar Int64 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); + } else if (value->isa()) { + // convert scalar FP32 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); + } else if (value->isa()) { + // convert scalar FP32 to GeTensor + MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; + GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); + auto v = GetValue(value); + desc.SetRealDimCnt(0); + return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); + } else if (value->isa()) { + // convert String to GeTensor + MS_LOG(INFO) << "convert string to tensor with data type = String"; + std::string v = GetValue(value); + std::vector ge_shape; + GeShape shape(ge_shape); + GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING); + GeTensor str_tensor(desc); 
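+ // The string payload is attached with SetData() after the tensor is constructed from its DT_STRING descriptor.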
+ str_tensor.SetData(v); + return str_tensor; + } else { + MS_LOG(WARNING) << "Unsupported value type: " << value->type_name() + << " to convert to tensor. Value: " << value->ToString(); + } + return GeTensor(); +} + +bool IsCustomPrim(const PrimitivePtr &prim) { + if (prim == nullptr) { + return false; + } + + ValuePtr flag = prim->GetAttr("_custom_op_flag"); + if (flag == nullptr) { + return false; + } + + bool is_custom_op = GetValue(flag); + if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) { + MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op " + "can not assign the op information config path."; + } + + return is_custom_op; +} + +bool IsCustomCNode(const AnfNodePtr &anf) { + if (anf == nullptr) { + return false; + } + auto node = anf->cast(); + if (node == nullptr) { + return false; + } + if (node->inputs().empty()) { + MS_LOG(EXCEPTION) << "length of node inputs is empty"; + } + MS_EXCEPTION_IF_NULL(node->inputs()[0]); + if (!node->inputs()[0]->isa()) { + return false; + } + auto cus_prim = GetValueNode(node->inputs()[0]); + if (cus_prim == nullptr) { + return false; + } + + return IsCustomPrim(cus_prim); +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h new file mode 100644 index 0000000000..0a0d745ba2 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.h @@ -0,0 +1,66 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_ADAPTER_UTIL_H_ +#define TRANSFORM_OP_ADAPTER_UTIL_H_ + +#include +#include + +#include "transform/graph_ir/op_adapter_base.h" + +namespace mindspore { +namespace transform { +template +static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits
<P> &, const AnyTraits<Q> &) { + return static_cast<Q>(GetValue<P>
(value)); +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, + const AnyTraits>); + +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); + +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, + const AnyTraits>, const AnyTraits); + +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); + +template +std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits

, const AnyTraits>) { + if (!value->isa() && !value->isa()) { + MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() + << ", type: " << value->type_name() << ", value should be a tuple or list"; + } + auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); + std::vector data; + for (auto &it : vec) { + data.push_back(ConvertAnyUtil(it, AnyTraits

(), AnyTraits())); + } + return data; +} + +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); + +bool IsCustomPrim(const PrimitivePtr &prim); +bool IsCustomCNode(const AnfNodePtr &node); +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.cc b/mindspore/ccsrc/transform/graph_ir/op_declare.cc new file mode 100644 index 0000000000..e3751e0c92 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.cc @@ -0,0 +1,1330 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "transform/graph_ir/op_declare.h" + +#include + +#include "transform/graph_ir/all_ops.h" +#include "utils/utils.h" + +namespace mindspore { +namespace transform { +#define INPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::input_map_ +#define EMPTY_INPUT_MAP std::unordered_map() +#define INPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, const OperatorPtr input) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_input_##name(*input); \ + }, \ + [](const OperatorPtr op, const OutHandler& handle) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_input_##name(*(handle.op), handle.out); \ + }, \ + [](const OperatorPtr op, const GeTensorDesc desc) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->update_input_desc_##name(desc); \ + } \ + } + +#define DYN_INPUT_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_input_map_ +#define DYN_INPUT_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, unsigned int num) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->create_dynamic_input_##name(num); \ + }, \ + [](const OperatorPtr op, unsigned int index, const OperatorPtr input) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_input_##name(index, *input); \ + }, \ + [](const OperatorPtr op, unsigned int index, const OutHandler& handle) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_input_##name(index, *(handle.op), handle.out); \ + } \ + } + +#define DYN_SUBGRAPH_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_subgraph_map_ +#define DYN_SUBGRAPH_DESC(name) \ + { \ +#name, \ + [](const OperatorPtr op, unsigned int num) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->create_dynamic_subgraph_##name(num); \ + }, \ + [](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \ + auto p = std::static_pointer_cast(op); \ + (void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \ + } \ + } + +#define ATTR_MAP(T) \ + template <> \ + const std::unordered_map OpAdapter::attr_map_ +#define EMPTY_ATTR_MAP std::unordered_map() +#define ATTR_DESC(name, ...) 
+  { \
+#name, \
+      [](const OperatorPtr op, const ValuePtr& value) { \
+        auto p = std::static_pointer_cast<OpType>(op); \
+        (void)p->set_attr_##name(ConvertAny(value, __VA_ARGS__)); \
+      } \
+  }
+
+#define INPUT_ATTR_MAP(T) \
+  template <> \
+  const std::unordered_map<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_
+
+#define OUTPUT_MAP(T) \
+  template <> \
+  const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_
+#define OUTPUT_DESC(name) \
+  { \
+#name, \
+      [](const OperatorPtr op, const GeTensorDesc desc) { \
+        auto p = std::static_pointer_cast<OpType>(op); \
+        (void)p->update_output_desc_##name(desc); \
+      } \
+  }
+
+#define DYN_OUTPUT_MAP(T) \
+  template <> \
+  const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_
+
+#define DYN_OUTPUT_DESC(name) \
+  { \
+#name, \
+      [](const OperatorPtr op, unsigned int num) { \
+        auto p = std::static_pointer_cast<OpType>(op); \
+        (void)p->create_dynamic_output_##name(num); \
+      } \
+  }
+
+template <>
+std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<CusOperator>::cus_input_map_{};
+template <>
+std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<CusOperator>::cus_output_map_{};
+
+// --------------specialization for each operator----------
+// const
+INPUT_MAP(Const) = EMPTY_INPUT_MAP;
+ATTR_MAP(Const) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
+OUTPUT_MAP(Const) = {{0, OUTPUT_DESC(y)}};
+
+// Assign
+INPUT_MAP(Assign) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}};
+ATTR_MAP(Assign) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(Assign) = {{0, OUTPUT_DESC(ref)}};
+
+// Constant
+INPUT_MAP(Constant) = EMPTY_INPUT_MAP;
+ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits<AnyValue>())}};
+OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}};
+
+// ApplyMomentumD
+INPUT_MAP(ApplyMomentumD) = {
+  {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}};
+ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())},
+                            {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
+OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
+
+// ScalarSummary
+INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}};
+ATTR_MAP(Summary) = EMPTY_ATTR_MAP;
+
+// Data
+INPUT_MAP(Data) = EMPTY_INPUT_MAP;
+ATTR_MAP(Data) = EMPTY_ATTR_MAP;
+
+// BatchNorm
+INPUT_MAP(BatchNorm) = {{1, INPUT_DESC(x)},
+                        {2, INPUT_DESC(scale)},
+                        {3, INPUT_DESC(offset)},
+                        {4, INPUT_DESC(mean)},
+                        {5, INPUT_DESC(variance)}};
+ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
+                       {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
+                       {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
+OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)},
+                         {1, OUTPUT_DESC(batch_mean)},
+                         {2, OUTPUT_DESC(batch_variance)},
+                         {3, OUTPUT_DESC(reserve_space_1)},
+                         {4, OUTPUT_DESC(reserve_space_2)}};
+
+// BatchNormGrad
+INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)},
+                            {2, INPUT_DESC(x)},
+                            {3, INPUT_DESC(scale)},
+                            {4, INPUT_DESC(reserve_space_1)},
+                            {5, INPUT_DESC(reserve_space_2)}};
+ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
+                           {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
+                           {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
+OUTPUT_MAP(BatchNormGrad) = {{0, OUTPUT_DESC(x_backprop)},
+                             {1, OUTPUT_DESC(scale_backprop)},
+                             {2, OUTPUT_DESC(offset_backprop)},
+                             {3, OUTPUT_DESC(reserve_space_4)},
+                             {4, OUTPUT_DESC(reserve_space_5)}};
+
+// Relu
+INPUT_MAP(Relu) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(Relu) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(Relu) = {{0, OUTPUT_DESC(y)}};
+
+// Elu
+INPUT_MAP(Elu) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(Elu) = {{"alpha", ATTR_DESC(alpha, AnyTraits<float>())}};
+OUTPUT_MAP(Elu) = {{0,
OUTPUT_DESC(y)}}; + +// EluGrad +INPUT_MAP(EluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(activations)}}; +ATTR_MAP(EluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(EluGrad) = {{0, OUTPUT_DESC(y)}}; + +// PRelu +INPUT_MAP(PRelu) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight)}}; +ATTR_MAP(PRelu) = EMPTY_ATTR_MAP; +OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; + +// PReluGrad +INPUT_MAP(PReluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}}; +ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}}; + +// Sigmoid +INPUT_MAP(Sigmoid) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sigmoid) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sigmoid) = {{0, OUTPUT_DESC(y)}}; + +// SigmoidGrad +INPUT_MAP(SigmoidGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(SigmoidGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidGrad) = {{0, OUTPUT_DESC(z)}}; + +// L2NormalizeGrad +INPUT_MAP(L2NormalizeGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(dy)}}; +ATTR_MAP(L2NormalizeGrad) = { + {"axis", ATTR_DESC(dim, AnyTraits>(), AnyTraits>())}, + {"epsilon", ATTR_DESC(eps, AnyTraits())}}; +OUTPUT_MAP(L2NormalizeGrad) = {{0, OUTPUT_DESC(dx)}}; + +// LarsV2Update +INPUT_MAP(LarsV2Update) = {{1, INPUT_DESC(w)}, + {2, INPUT_DESC(g)}, + {3, INPUT_DESC(w_square_sum)}, + {4, INPUT_DESC(g_square_sum)}, + {5, INPUT_DESC(weight_decay)}, + {6, INPUT_DESC(learning_rate)}}; +ATTR_MAP(LarsV2Update) = {{"epsilon", ATTR_DESC(epsilon, AnyTraits())}, + {"hyperpara", ATTR_DESC(hyperpara, AnyTraits())}, + {"use_clip", ATTR_DESC(use_clip, AnyTraits())}}; +OUTPUT_MAP(LarsV2Update) = {{0, OUTPUT_DESC(g_new)}}; + +// L2Normalize +INPUT_MAP(L2Normalize) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(L2Normalize) = { + {"axis", ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}, + {"epsilon", ATTR_DESC(eps, AnyTraits())}}; +OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}}; + +// CumsumD +INPUT_MAP(CumsumD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(CumsumD) = {{2, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(CumsumD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits())}, + {"reverse", ATTR_DESC(reverse, AnyTraits())}}; +OUTPUT_MAP(CumsumD) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxV2 +INPUT_MAP(SoftmaxV2) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SoftmaxV2) = { + {"axis", ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(SoftmaxV2) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxGrad +INPUT_MAP(SoftmaxGrad) = {{1, INPUT_DESC(softmax)}, {2, INPUT_DESC(grad_softmax)}}; +OUTPUT_MAP(SoftmaxGrad) = {{0, OUTPUT_DESC(grad_x)}}; +ATTR_MAP(SoftmaxGrad) = EMPTY_ATTR_MAP; + +// Flatten +INPUT_MAP(Flatten) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Flatten) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Flatten) = {{0, OUTPUT_DESC(y)}}; + +// add +INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Add) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; + +// GatherV2 +INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(axis)}}; +ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; + +// ReduceSumD +INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceSumD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}}; + +// ReduceProdD +INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceProdD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, 
AnyTraits())}}; +OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}}; + +// CumprodD +INPUT_MAP(CumprodD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(CumprodD) = {{2, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(CumprodD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits())}, + {"reverse", ATTR_DESC(reverse, AnyTraits())}}; +OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}}; + +// SoftmaxCrossEntropyWithLogits +INPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(labels)}}; +ATTR_MAP(SoftmaxCrossEntropyWithLogits) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(backprop)}}; + +// MeanGrad +INPUT_MAP(MeanGrad) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(MeanGrad) = {{2, ATTR_DESC(mean_grad_output_shape_value, kOpFormat_NHWC, + AnyTraits>(), AnyTraits())}}; +ATTR_MAP(MeanGrad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; + +INPUT_MAP(SliceD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(SliceD) = {{2, ATTR_DESC(offsets, AnyTraits(), AnyTraits>())}, + {3, ATTR_DESC(size, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(SliceD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SliceD) = {{0, OUTPUT_DESC(y)}}; + +// MaxPool +INPUT_MAP(MaxPool) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MaxPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(MaxPool) = {{0, OUTPUT_DESC(y)}}; + +// AvgPool +INPUT_MAP(AvgPool) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(AvgPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(AvgPool) = {{0, OUTPUT_DESC(y)}}; + +// GreaterEqual +INPUT_MAP(GreaterEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(GreaterEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GreaterEqual) = {{0, OUTPUT_DESC(y)}}; + +// AssignAdd +INPUT_MAP(AssignAdd) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}}; +ATTR_MAP(AssignAdd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AssignAdd) = {{0, OUTPUT_DESC(ref)}}; + +// AssignSub +INPUT_MAP(AssignSub) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(value)}}; +ATTR_MAP(AssignSub) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AssignSub) = {{0, OUTPUT_DESC(var)}}; + +// Cos +INPUT_MAP(Cos) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Cos) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}}; + +// Acos +INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Acos) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Acos) = {{0, OUTPUT_DESC(y)}}; + +// AcosGrad +INPUT_MAP(AcosGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AcosGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AcosGrad) = {{0, OUTPUT_DESC(z)}}; + +// Acosh +INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Acosh) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}}; + +// AcoshGrad +INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}}; + +// Floor +INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Floor) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Floor) = {{0, OUTPUT_DESC(y)}}; + +// FloorDiv +INPUT_MAP(FloorDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(FloorDiv) = EMPTY_ATTR_MAP; +OUTPUT_MAP(FloorDiv) = {{0, OUTPUT_DESC(y)}}; + +// FloorMod +INPUT_MAP(FloorMod) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(FloorMod) = EMPTY_ATTR_MAP; +OUTPUT_MAP(FloorMod) 
= {{0, OUTPUT_DESC(y)}}; + +// Sin +INPUT_MAP(Sin) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sin) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}}; + +// Exp +INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Exp) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Exp) = {{0, OUTPUT_DESC(y)}}; + +// BoundingBoxEncode +INPUT_MAP(BoundingBoxEncode) = { + {1, INPUT_DESC(anchor_box)}, + {2, INPUT_DESC(ground_truth_box)}, +}; +ATTR_MAP(BoundingBoxEncode) = { + {"means", ATTR_DESC(means, AnyTraits>(), AnyTraits())}, + {"stds", ATTR_DESC(stds, AnyTraits>(), AnyTraits())}, +}; +OUTPUT_MAP(BoundingBoxEncode) = {{0, OUTPUT_DESC(delats)}}; + +// BoundingBoxDecode +INPUT_MAP(BoundingBoxDecode) = { + {1, INPUT_DESC(rois)}, + {2, INPUT_DESC(deltas)}, +}; +ATTR_MAP(BoundingBoxDecode) = { + {"means", ATTR_DESC(means, AnyTraits>(), AnyTraits())}, + {"stds", ATTR_DESC(stds, AnyTraits>(), AnyTraits())}, + {"max_shape", ATTR_DESC(max_shape, AnyTraits>(), AnyTraits>())}, + {"wh_ratio_clip", ATTR_DESC(wh_ratio_clip, AnyTraits())}, +}; +OUTPUT_MAP(BoundingBoxDecode) = {{0, OUTPUT_DESC(bboxes)}}; + +// TopK +INPUT_MAP(TopK) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(k)}}; +ATTR_MAP(TopK) = {{"sorted", ATTR_DESC(sorted, AnyTraits())}}; +OUTPUT_MAP(TopK) = {{0, OUTPUT_DESC(values)}, {1, OUTPUT_DESC(indices)}}; + +// Multiply +INPUT_MAP(Multiply) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}}; +ATTR_MAP(Multiply) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Multiply) = {{0, OUTPUT_DESC(z)}}; + +// TileD +INPUT_MAP(TileD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TileD) = {{2, ATTR_DESC(multiples, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TileD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TileD) = {{0, OUTPUT_DESC(y)}}; + +// OneHot +INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}}; +ATTR_MAP(OneHot) = {{"axis", ATTR_DESC(axis, AnyTraits())}}; +OUTPUT_MAP(OneHot) = {{0, OUTPUT_DESC(y)}}; + +// GatherV2D +INPUT_MAP(GatherV2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}}; +INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits())}}; +ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}}; + +// Reshape +INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}}; +ATTR_MAP(Reshape) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}}; + +// TransShape +INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TransShape) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}}; + +// BiasAdd +INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}}; +ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(BiasAdd) = {{0, OUTPUT_DESC(y)}}; + +// Iou +INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}}; +ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; + +// ResizeNearestNeighborV2D +INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeNearestNeighborV2D) = { + {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, + {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}}; + +// ResizeNearestNeighborV2Grad +INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; +ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}}; + +// 
ApplyAdam +INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, + {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, + {10, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}}; +OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; + +// ApplyAdamD +INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, + {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)}, + {10, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}}; +OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; + +// Relu6 +INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Relu6) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Relu6) = {{0, OUTPUT_DESC(y)}}; + +// Relu6Grad +INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; +ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; + +// ResizeBilinearV2Grad +INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; +ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}}; + +// ResizeBilinearV2D +INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeBilinearV2D) = { + {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, + {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}}; + +// ZerosLike +INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ZerosLike) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ZerosLike) = {{0, OUTPUT_DESC(y)}}; + +// OnesLike +INPUT_MAP(OnesLike) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(OnesLike) = EMPTY_ATTR_MAP; +OUTPUT_MAP(OnesLike) = {{0, OUTPUT_DESC(y)}}; + +// NMSWithMask +INPUT_MAP(NMSWithMask) = {{1, INPUT_DESC(box_scores)}}; +ATTR_MAP(NMSWithMask) = {{"iou_threshold", ATTR_DESC(iou_threshold, AnyTraits())}}; +OUTPUT_MAP(NMSWithMask) = { + {0, OUTPUT_DESC(selected_boxes)}, {1, OUTPUT_DESC(selected_idx)}, {2, OUTPUT_DESC(selected_mask)}}; + +// Unpack +INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits())}, {"num", ATTR_DESC(num, AnyTraits())}}; +DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}}; + +// TensorScatterUpdate +INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}}; + +// ScatterUpdate +INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}}; + +// ScatterNdUpdate +INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}}; +ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}}; + +// ScatterMax +INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, 
INPUT_DESC(updates)}}; +ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}}; + +// CheckValid +INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}}; +ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP; +OUTPUT_MAP(CheckValid) = {{0, OUTPUT_DESC(valid_tensor)}}; + +// SmoothL1Loss +INPUT_MAP(SmoothL1Loss) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}}; +ATTR_MAP(SmoothL1Loss) = {{"sigma", ATTR_DESC(sigma, AnyTraits())}}; +OUTPUT_MAP(SmoothL1Loss) = {{0, OUTPUT_DESC(loss)}}; + +// SmoothL1LossGrad +INPUT_MAP(SmoothL1LossGrad) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}, {3, INPUT_DESC(dout)}}; +ATTR_MAP(SmoothL1LossGrad) = {{"sigma", ATTR_DESC(sigma, AnyTraits())}}; +OUTPUT_MAP(SmoothL1LossGrad) = {{0, OUTPUT_DESC(gradient)}}; + +// SigmoidCrossEntropyWithLogits +INPUT_MAP(SigmoidCrossEntropyWithLogits) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}}; +ATTR_MAP(SigmoidCrossEntropyWithLogits) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}}; + +// SigmoidCrossEntropyWithLogitsGrad +INPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = { + {1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(dout)}}; +ATTR_MAP(SigmoidCrossEntropyWithLogitsGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {{0, OUTPUT_DESC(gradient)}}; + +// ScatterNdD +INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ScatterNdD) = { + {3, ATTR_DESC(shape, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ScatterNdD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ScatterNdD) = {{0, OUTPUT_DESC(y)}}; + +// PadD +INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits>>())}}; +OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}}; + +// MirrorPad +INPUT_MAP(MirrorPad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; +ATTR_MAP(MirrorPad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(MirrorPad) = {{0, OUTPUT_DESC(y)}}; + +// MirrorPadGrad +INPUT_MAP(MirrorPadGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}}; +ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; +OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}}; + +// GatherNd +INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}}; +ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GatherNd) = {{0, OUTPUT_DESC(y)}}; + +// ROIAlign +INPUT_MAP(ROIAlign) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(rois)}}; +OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}}; +ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits())}, + {"pooled_width", ATTR_DESC(pooled_width, AnyTraits())}, + {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits())}, + {"sample_num", ATTR_DESC(sample_num, AnyTraits())}, + {"roi_end_mode", ATTR_DESC(roi_end_mode, AnyTraits())}}; + +// ROIAlignGrad +INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}}; +OUTPUT_MAP(ROIAlignGrad) = {{0, OUTPUT_DESC(xdiff)}}; +ATTR_MAP(ROIAlignGrad) = { + {"xdiff_shape", ATTR_DESC(xdiff_shape, AnyTraits>(), AnyTraits>())}, + {"pooled_height", ATTR_DESC(pooled_height, AnyTraits())}, + {"pooled_width", ATTR_DESC(pooled_width, AnyTraits())}, + {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits())}, + {"sample_num", ATTR_DESC(sample_num, AnyTraits())}}; + +// ArgMaxD +INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, + {"output_type", ATTR_DESC(dtype, AnyTraits())}}; 
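+// Illustrative sketch of how the maps above are consumed (assumes the ConvertAny dispatch helpers
+// from op_adapter.h; shown here only as documentation): for ArgMaxD, the ME attribute
+// "output_type" reaches the GE attribute "dtype" roughly as
+//   auto p = std::static_pointer_cast<OpType>(op);                    // OpType = ArgMaxD
+//   (void)p->set_attr_dtype(ConvertAny(value, AnyTraits<GEType>()));  // TypePtr -> ge::DataType
+// which is the lambda that ATTR_DESC(dtype, AnyTraits<GEType>()) expands to.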
+OUTPUT_MAP(ArgMaxD) = {{0, OUTPUT_DESC(y)}};
+
+// ArgMinD
+INPUT_MAP(ArgMinD) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ArgMinD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
+                     {"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
+OUTPUT_MAP(ArgMinD) = {{0, OUTPUT_DESC(y)}};
+
+// ArgMaxWithValue
+INPUT_MAP(ArgMaxWithValue) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ArgMaxWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
+                             {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ArgMaxWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};
+
+// ArgMinWithValue
+INPUT_MAP(ArgMinWithValue) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
+                             {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};
+
+// ReduceAllD
+INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
+INPUT_ATTR_MAP(ReduceAllD) = {
+  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
+ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};
+
+// ReduceMeanD
+INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}};
+INPUT_ATTR_MAP(ReduceMeanD) = {
+  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
+ATTR_MAP(ReduceMeanD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ReduceMeanD) = {{0, OUTPUT_DESC(y)}};
+
+// HCOMAllReduce
+INPUT_MAP(HcomAllReduce) = {{1, INPUT_DESC(x)}};
+OUTPUT_MAP(HcomAllReduce) = {{0, OUTPUT_DESC(y)}};
+ATTR_MAP(HcomAllReduce) = {{"op", ATTR_DESC(reduction, AnyTraits<std::string>())},
+                           {"group", ATTR_DESC(group, AnyTraits<std::string>())},
+                           {"fusion", ATTR_DESC(fusion, AnyTraits<int64_t>())}};
+
+// HCOMBroadcast
+INPUT_MAP(HcomBroadcast) = EMPTY_INPUT_MAP;
+DYN_INPUT_MAP(HcomBroadcast) = {{1, DYN_INPUT_DESC(x)}};
+DYN_OUTPUT_MAP(HcomBroadcast) = {{0, DYN_OUTPUT_DESC(y)}};
+ATTR_MAP(HcomBroadcast) = {{"root_rank", ATTR_DESC(root_rank, AnyTraits<int64_t>())},
+                           {"group", ATTR_DESC(group, AnyTraits<std::string>())}};
+
+// HCOMAllGather
+INPUT_MAP(HcomAllGather) = {{1, INPUT_DESC(x)}};
+OUTPUT_MAP(HcomAllGather) = {{0, OUTPUT_DESC(y)}};
+ATTR_MAP(HcomAllGather) = {{"group", ATTR_DESC(group, AnyTraits<std::string>())},
+                           {"rank_size", ATTR_DESC(rank_size, AnyTraits<int64_t>())}};
+
+// HCOMReduceScatter
+INPUT_MAP(HcomReduceScatter) = {{1, INPUT_DESC(x)}};
+OUTPUT_MAP(HcomReduceScatter) = {{0, OUTPUT_DESC(y)}};
+ATTR_MAP(HcomReduceScatter) = {{"group", ATTR_DESC(group, AnyTraits<std::string>())},
+                               {"op", ATTR_DESC(reduction, AnyTraits<std::string>())},
+                               {"rank_size", ATTR_DESC(rank_size, AnyTraits<int64_t>())}};
+
+// Variable
+INPUT_MAP(Variable) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(Variable) = EMPTY_ATTR_MAP;
+
+// ReluGrad
+INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
+ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP;
+OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}};
+
+// BiasAddGrad
+INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
+OUTPUT_MAP(BiasAddGrad) = {{0, OUTPUT_DESC(y)}};
+
+// MaxPoolGrad
+INPUT_MAP(MaxPoolGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grad)}};
+ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
+                         {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
+                         {"padding", ATTR_DESC(padding, AnyTraits<std::string>())},
+                         {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
+OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
+
+// AvgPoolGrad
+INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
+ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}}; +OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}}; + +// MaxPoolWithArgmax +INPUT_MAP(MaxPoolWithArgmax) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}}; + +// MaxPoolGradWithArgmax +INPUT_MAP(MaxPoolGradWithArgmax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)}}; +ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}}; + +// ExtractImagePatches +INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits(), AnyTraits>())}, + {"strides", ATTR_DESC(strides, AnyTraits(), AnyTraits>())}, + {"rates", ATTR_DESC(rates, AnyTraits(), AnyTraits>())}, + {"padding", ATTR_DESC(padding, AnyTraits())}}; +OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}}; + +// Conv2D +INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +ATTR_MAP(Conv2D) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; + +// Conv2DBackpropInputD +INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}}; +INPUT_ATTR_MAP(Conv2DBackpropInputD) = { + {3, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(Conv2DBackpropInputD) = { + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; + +// Conv2DBackpropFilterD +INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { + {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(Conv2DBackpropFilterD) = { + {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, +}; +OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; + +// DepthwiseConv2D +INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; +ATTR_MAP(DepthwiseConv2D) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, +}; +OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}}; + +// 
DepthwiseConv2DBackpropInputD +INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_DESC(out_backprop)}}; +INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = { + {1, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(DepthwiseConv2DBackpropInputD) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}}; + +// DepthwiseConv2DBackpropFilterD +INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_DESC(out_backprop)}}; +INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { + {2, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, +}; +OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; + +// MatMulV2 +INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits())}, + {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits())}}; +OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}}; + +// Merge +INPUT_MAP(Merge) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Merge) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Merge) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Merge) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(value_index)}}; + +// Switch +INPUT_MAP(Switch) = {{1, INPUT_DESC(data)}, {2, INPUT_DESC(pred)}}; +OUTPUT_MAP(Switch) = {{0, OUTPUT_DESC(output_false)}, {1, OUTPUT_DESC(output_true)}}; +ATTR_MAP(Switch) = EMPTY_ATTR_MAP; + +// AddN +INPUT_MAP(AddN) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(AddN) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(AddN) = {{"n", ATTR_DESC(N, AnyTraits())}}; +OUTPUT_MAP(AddN) = {{0, OUTPUT_DESC(y)}}; + +// Mul +INPUT_MAP(Mul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Mul) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Mul) = {{0, OUTPUT_DESC(y)}}; + +// RealDiv +INPUT_MAP(RealDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(RealDiv) = EMPTY_ATTR_MAP; +OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}}; + +// Cast +INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits())}}; +ATTR_MAP(Cast) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; + +// Case +INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}}; +DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}}; +ATTR_MAP(Case) = EMPTY_ATTR_MAP; +DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}}; +DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}}; + +// Reciprocal +INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Reciprocal) = {{0, OUTPUT_DESC(y)}}; + +// Sub +INPUT_MAP(Sub) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Sub) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sub) = {{0, OUTPUT_DESC(y)}}; + +// SplitD +INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits())}, + {"output_num", ATTR_DESC(num_split, AnyTraits())}}; +DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}}; + +// Range +INPUT_MAP(RangeD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(RangeD) = {{"start", ATTR_DESC(start, AnyTraits())}, + {"limit", ATTR_DESC(limit, AnyTraits())}, + {"delta", ATTR_DESC(delta, AnyTraits())}}; +OUTPUT_MAP(RangeD) = {{0, OUTPUT_DESC(y)}}; + +// Neg 
+INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Neg) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Neg) = {{0, OUTPUT_DESC(y)}}; + +// Transpose +INPUT_MAP(TransposeD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(TransposeD) = {{2, ATTR_DESC(perm, AnyTraits(), AnyTraits>())}}; +ATTR_MAP(TransposeD) = EMPTY_ATTR_MAP; +// Do not set Transpose operator output descriptor + +// DropOutGenMask +INPUT_MAP(DropOutGenMask) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(prob)}}; +ATTR_MAP(DropOutGenMask) = {{"Seed0", ATTR_DESC(seed, AnyTraits())}, + {"Seed1", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(DropOutGenMask) = {{0, OUTPUT_DESC(y)}}; + +// Pack +INPUT_MAP(Pack) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Pack) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Pack) = {{"num", ATTR_DESC(N, AnyTraits())}, {"axis", ATTR_DESC(axis, AnyTraits())}}; +OUTPUT_MAP(Pack) = {{0, OUTPUT_DESC(y)}}; + +// ConcatD +INPUT_MAP(ConcatD) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(ConcatD) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(ConcatD) = { + {"axis", ATTR_DESC(concat_dim, AnyTraits())}, + {"inputNums", ATTR_DESC(N, AnyTraits())}, +}; +OUTPUT_MAP(ConcatD) = {{0, OUTPUT_DESC(y)}}; + +// Less +INPUT_MAP(Less) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Less) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Less) = {{0, OUTPUT_DESC(y)}}; + +// Rsqrt +INPUT_MAP(Rsqrt) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Rsqrt) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Rsqrt) = {{0, OUTPUT_DESC(y)}}; + +// Sqrt +INPUT_MAP(Sqrt) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sqrt) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sqrt) = {{0, OUTPUT_DESC(y)}}; + +// Square +INPUT_MAP(Square) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Square) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Square) = {{0, OUTPUT_DESC(y)}}; + +// SquareSumAll +INPUT_MAP(SquareSumAll) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(SquareSumAll) = EMPTY_ATTR_MAP; +OUTPUT_MAP(SquareSumAll) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// Tanh +INPUT_MAP(Tanh) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Tanh) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Tanh) = {{0, OUTPUT_DESC(y)}}; + +// TanhGrad +INPUT_MAP(TanhGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(TanhGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}}; + +// ReduceMinD +INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceMinD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}}; + +// ReduceMaxD +INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceMaxD) = { + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}}; + +// Maximum +INPUT_MAP(Maximum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Maximum) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Maximum) = {{0, OUTPUT_DESC(y)}}; + +// Minimum +INPUT_MAP(Minimum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Minimum) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Minimum) = {{0, OUTPUT_DESC(y)}}; + +// MaximumGrad +INPUT_MAP(MaximumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}}; +ATTR_MAP(MaximumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits())}, + {"grad_y", ATTR_DESC(grad_y, AnyTraits())}}; +OUTPUT_MAP(MaximumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// MinimumGrad +INPUT_MAP(MinimumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}}; +ATTR_MAP(MinimumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits())}, + {"grad_y", 
ATTR_DESC(grad_y, AnyTraits())}}; +OUTPUT_MAP(MinimumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}}; + +// Pow +INPUT_MAP(Pow) = { + {1, INPUT_DESC(x1)}, + {2, INPUT_DESC(x2)}, +}; +ATTR_MAP(Pow) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Pow) = {{0, OUTPUT_DESC(y)}}; + +// Equal +INPUT_MAP(Equal) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Equal) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}}; + +// NotEqual +INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(NotEqual) = {{0, OUTPUT_DESC(y)}}; + +// Log +INPUT_MAP(Log) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Log) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Log) = {{0, OUTPUT_DESC(y)}}; + +// LogicalAnd +INPUT_MAP(LogicalAnd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LogicalAnd) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalAnd) = {{0, OUTPUT_DESC(y)}}; + +// LogicalOr +INPUT_MAP(LogicalOr) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LogicalOr) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalOr) = {{0, OUTPUT_DESC(y)}}; + +// LogicalNot +INPUT_MAP(LogicalNot) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(LogicalNot) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LogicalNot) = {{0, OUTPUT_DESC(y)}}; + +// Greater +INPUT_MAP(Greater) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Greater) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Greater) = {{0, OUTPUT_DESC(y)}}; + +// LogSoftmaxGrad +INPUT_MAP(LogSoftmaxGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}}; +ATTR_MAP(LogSoftmaxGrad) = { + {"axis", ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}}; +OUTPUT_MAP(LogSoftmaxGrad) = {{0, OUTPUT_DESC(y)}}; + +// Select +INPUT_MAP(Select) = {{1, INPUT_DESC(condition)}, {2, INPUT_DESC(x1)}, {3, INPUT_DESC(x2)}}; +ATTR_MAP(Select) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Select) = {{0, OUTPUT_DESC(y)}}; + +// LessEqual +INPUT_MAP(LessEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(LessEqual) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LessEqual) = {{0, OUTPUT_DESC(y)}}; + +// LogSoftmaxV2 +INPUT_MAP(LogSoftmaxV2) = {{1, INPUT_DESC(logits)}}; +ATTR_MAP(LogSoftmaxV2) = { + {"axis", ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; +OUTPUT_MAP(LogSoftmaxV2) = {{0, OUTPUT_DESC(logsoftmax)}}; + +// RandomChoiceWithMask +INPUT_MAP(RandomChoiceWithMask) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(RandomChoiceWithMask) = {{"count", ATTR_DESC(count, AnyTraits())}, + {"seed", ATTR_DESC(seed, AnyTraits())}, + {"seed2", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(RandomChoiceWithMask) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mask)}}; + +// TruncatedNormal +INPUT_MAP(TruncatedNormal) = {{1, INPUT_DESC(shape)}}; +ATTR_MAP(TruncatedNormal) = {{"seed", ATTR_DESC(seed, AnyTraits())}, + {"seed2", ATTR_DESC(seed2, AnyTraits())}}; +OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}}; + +// StridedSliceGrad +INPUT_MAP(StridedSliceGrad) = { + {1, INPUT_DESC(dy)}, {2, INPUT_DESC(shape)}, {3, INPUT_DESC(begin)}, {4, INPUT_DESC(end)}, {5, INPUT_DESC(strides)}}; +ATTR_MAP(StridedSliceGrad) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits())}, + {"end_mask", ATTR_DESC(end_mask, AnyTraits())}, + {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits())}, + {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits())}, + {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits())}}; +OUTPUT_MAP(StridedSliceGrad) = {{0, OUTPUT_DESC(output)}}; + +// Gelu +INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Gelu) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}}; + +// GeluGrad +INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}}; 
+ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}}; + +// StridedSlice +INPUT_MAP(StridedSlice) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(begin)}, {3, INPUT_DESC(end)}, {4, INPUT_DESC(strides)}}; +ATTR_MAP(StridedSlice) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits())}, + {"end_mask", ATTR_DESC(end_mask, AnyTraits())}, + {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits())}, + {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits())}, + {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits())}}; +OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}}; + +// UnsortedSegmentSum +INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}}; +INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits())}}; +ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP; +OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}}; + +// UnsortedSegmentMin +INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}}; +ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP; +OUTPUT_MAP(UnsortedSegmentMin) = {{0, OUTPUT_DESC(y)}}; + +// ExpandDims +INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; +ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP; +OUTPUT_MAP(ExpandDims) = {{0, OUTPUT_DESC(y)}}; + +// Squeeze +INPUT_MAP(Squeeze) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Squeeze) = {{"axis", ATTR_DESC(axis, AnyTraits(), AnyTraits>())}}; +OUTPUT_MAP(Squeeze) = {{0, OUTPUT_DESC(y)}}; + +// SGD +INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)}, + {4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}}; +ATTR_MAP(SGD) = {{"dampening", ATTR_DESC(dampening, AnyTraits())}, + {"weight_decay", ATTR_DESC(weight_decay, AnyTraits())}, + {"nesterov", ATTR_DESC(nesterov, AnyTraits())}}; +OUTPUT_MAP(SGD) = {{0, OUTPUT_DESC(parameters)}}; + +// LayerNorm +INPUT_MAP(LayerNorm) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(gamma)}, {3, INPUT_DESC(beta)}}; +ATTR_MAP(LayerNorm) = {{"begin_norm_axis", ATTR_DESC(begin_norm_axis, AnyTraits())}, + {"begin_params_axis", ATTR_DESC(begin_params_axis, AnyTraits())}, + {"epsilon", ATTR_DESC(epsilon, AnyTraits())}}; +OUTPUT_MAP(LayerNorm) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mean)}, {2, OUTPUT_DESC(variance)}}; + +// LayerNormGrad +INPUT_MAP(LayerNormGrad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(dy)}, {3, INPUT_DESC(variance)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(gamma)}}; +ATTR_MAP(LayerNormGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(LayerNormGrad) = {{0, OUTPUT_DESC(pd_x)}, {1, OUTPUT_DESC(pd_gamma)}, {2, OUTPUT_DESC(pd_beta)}}; + +// BatchMatMul +INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits())}, + {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits())}}; +OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}}; + +// DropoutDoMask +INPUT_MAP(DropOutDoMask) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(mask)}, {3, INPUT_DESC(keep_prob)}}; +ATTR_MAP(DropOutDoMask) = EMPTY_ATTR_MAP; +OUTPUT_MAP(DropOutDoMask) = {{0, OUTPUT_DESC(y)}}; + +// NPUGetFloatStatus +INPUT_MAP(NPUGetFloatStatus) = {{1, INPUT_DESC(addr)}}; +OUTPUT_MAP(NPUGetFloatStatus) = {{0, OUTPUT_DESC(data)}}; +ATTR_MAP(NPUGetFloatStatus) = EMPTY_ATTR_MAP; + +// NPUAllocFloatStatus +INPUT_MAP(NPUAllocFloatStatus) = EMPTY_INPUT_MAP; +ATTR_MAP(NPUAllocFloatStatus) = EMPTY_ATTR_MAP; +OUTPUT_MAP(NPUAllocFloatStatus) = {{0, OUTPUT_DESC(data)}}; + +// NPUClearFloatStatus 
+INPUT_MAP(NPUClearFloatStatus) = {{1, INPUT_DESC(addr)}}; +OUTPUT_MAP(NPUClearFloatStatus) = {{0, OUTPUT_DESC(data)}}; +ATTR_MAP(NPUClearFloatStatus) = EMPTY_ATTR_MAP; + +// Abs +INPUT_MAP(Abs) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Abs) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Abs) = {{0, OUTPUT_DESC(y)}}; + +// AbsGrad +INPUT_MAP(AbsGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; +ATTR_MAP(AbsGrad) = EMPTY_ATTR_MAP; +OUTPUT_MAP(AbsGrad) = {{0, OUTPUT_DESC(z)}}; + +// BinaryCrossEntropy +INPUT_MAP(BinaryCrossEntropy) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(weight)}}; +ATTR_MAP(BinaryCrossEntropy) = {{"reduction", ATTR_DESC(reduction, AnyTraits())}}; +OUTPUT_MAP(BinaryCrossEntropy) = {{0, OUTPUT_DESC(output)}}; + +// BinaryCrossEntropyGrad +INPUT_MAP(BinaryCrossEntropyGrad) = { + {1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(grad_output)}, {4, INPUT_DESC(weight)}}; +ATTR_MAP(BinaryCrossEntropyGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits())}}; +OUTPUT_MAP(BinaryCrossEntropyGrad) = {{0, OUTPUT_DESC(output)}}; + +// SparseApplyAdagradD +INPUT_MAP(SparseApplyAdagradD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; +ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits())}, + {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; + +// ApplyProximalAdagradD +INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, + {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}}; +ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; + +// SparseApplyFtrlD +INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, + {2, INPUT_DESC(accum)}, + {3, INPUT_DESC(linear)}, + {4, INPUT_DESC(grad)}, + {5, INPUT_DESC(indices)}}; +ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, + {"lr", ATTR_DESC(lr, AnyTraits())}, + {"l1", ATTR_DESC(l1, AnyTraits())}, + {"l2", ATTR_DESC(l2, AnyTraits())}, + {"lr_power", ATTR_DESC(lr_power, AnyTraits())}}; +OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}}; + +// SpaceToDepth +INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SpaceToDepth) = {{"block_size", ATTR_DESC(block_size, AnyTraits())}}; +OUTPUT_MAP(SpaceToDepth) = {{0, OUTPUT_DESC(y)}}; + +// DepthToSpace +INPUT_MAP(DepthToSpace) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(DepthToSpace) = {{"block_size", ATTR_DESC(block_size, AnyTraits())}}; +OUTPUT_MAP(DepthToSpace) = {{0, OUTPUT_DESC(y)}}; + +// Sign +INPUT_MAP(Sign) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Sign) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Sign) = {{0, OUTPUT_DESC(y)}}; + +// Round +INPUT_MAP(Round) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Round) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}}; + +// ApplyFtrlD +INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)}, + {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)}, + {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}}; +ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}}; + +// Diag +INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(Diag) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}}; + +// DiagPart +INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(DiagPart) = 
EMPTY_ATTR_MAP; +OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}}; + +// SpaceToBatchD +INPUT_MAP(SpaceToBatchD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(SpaceToBatchD) = { + {"block_size", ATTR_DESC(block_size, AnyTraits())}, + {"paddings", ATTR_DESC(paddings, AnyTraits>>(), AnyTraits>())}}; +OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}}; + +// BatchToSpaceD +INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(BatchToSpaceD) = { + {"block_size", ATTR_DESC(block_size, AnyTraits())}, + {"crops", ATTR_DESC(crops, AnyTraits>>(), AnyTraits>())}}; +OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}}; + +// Atan2 +INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; +ATTR_MAP(Atan2) = EMPTY_ATTR_MAP; +OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}}; + +// ApplyRMSPropD +INPUT_MAP(ApplyRMSPropD) = { + {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(lr)}, {5, INPUT_DESC(grad)}}; +INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits())}, + {7, ATTR_DESC(momentum, AnyTraits())}, + {8, ATTR_DESC(epsilon, AnyTraits())}}; +ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}}; + +// ApplyCenteredRMSProp +INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)}, + {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)}, + {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}}; +ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; +OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}}; + +// L2Loss +INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP; +OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}}; + +// CTCLoss +INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)}, + {2, INPUT_DESC(labels_indices)}, + {3, INPUT_DESC(labels_values)}, + {4, INPUT_DESC(sequence_length)}}; +ATTR_MAP(CTCLoss) = { + {"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits())}, + {"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits())}, + {"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits())}}; +OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}}; + +// AscendQuant +INPUT_MAP(AscendQuant) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(AscendQuant) = {{"scale", ATTR_DESC(scale, AnyTraits())}, + {"offset", ATTR_DESC(offset, AnyTraits())}, + {"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits())}, + {"round_mode", ATTR_DESC(round_mode, AnyTraits())}}; +OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}}; + +// AscendDequant +INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}}; +ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits())}, + {"relu_flag", ATTR_DESC(relu_flag, AnyTraits())}}; +OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}}; +#ifdef ENABLE_GE +// Print +INPUT_MAP(Print) = EMPTY_INPUT_MAP; +DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}}; +ATTR_MAP(Print) = EMPTY_ATTR_MAP; +#endif +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/op_declare.h b/mindspore/ccsrc/transform/graph_ir/op_declare.h new file mode 100755 index 0000000000..e493ea0e52 --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/op_declare.h @@ -0,0 +1,505 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TRANSFORM_OP_DECLARE_H_ +#define TRANSFORM_OP_DECLARE_H_ + +#include +#include +#include "transform/graph_ir/op_adapter.h" + +namespace mindspore { +namespace transform { +#define DECLARE_OP_ADAPTER(T) \ + using T = ge::op::T; \ + template <> \ + const std::unordered_map OpAdapter::input_map_; \ + template <> \ + const std::unordered_map OpAdapter::attr_map_; + +#define DECLARE_OP_USE_OUTPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::output_map_; + +#define DECLARE_OP_USE_ENUM(T) \ + template <> \ + const std::unordered_map OpAdapter::enum_map_; + +#define DECLARE_OP_USE_INPUT_ATTR(T) \ + template <> \ + const std::unordered_map OpAdapter::input_attr_map_; + +#define DECLARE_OP_USE_DYN_INPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_input_map_; + +#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_subgraph_map_; + +#define DECLARE_OP_USE_DYN_OUTPUT(T) \ + template <> \ + const std::unordered_map OpAdapter::dyn_output_map_; + +template <> +std::unordered_map> OpAdapter::cus_input_map_; +template <> +std::unordered_map> OpAdapter::cus_output_map_; + +DECLARE_OP_ADAPTER(GreaterEqual) +DECLARE_OP_USE_OUTPUT(GreaterEqual) +DECLARE_OP_ADAPTER(SliceD) +DECLARE_OP_USE_INPUT_ATTR(SliceD) +DECLARE_OP_USE_OUTPUT(SliceD) +DECLARE_OP_ADAPTER(AssignAdd) +DECLARE_OP_USE_OUTPUT(AssignAdd) +DECLARE_OP_ADAPTER(AssignSub) +DECLARE_OP_USE_OUTPUT(AssignSub) + +DECLARE_OP_ADAPTER(ReduceMean) +DECLARE_OP_ADAPTER(Multiply) +DECLARE_OP_USE_OUTPUT(Multiply) + +// ** Distributed Operations ** +DECLARE_OP_ADAPTER(HcomReduceScatter) +DECLARE_OP_USE_OUTPUT(HcomReduceScatter) +DECLARE_OP_ADAPTER(HcomBroadcast) +DECLARE_OP_USE_DYN_INPUT(HcomBroadcast) +DECLARE_OP_USE_DYN_OUTPUT(HcomBroadcast) +DECLARE_OP_ADAPTER(HcomAllReduce) +DECLARE_OP_USE_OUTPUT(HcomAllReduce) +DECLARE_OP_ADAPTER(HcomAllGather) +DECLARE_OP_USE_OUTPUT(HcomAllGather) +DECLARE_OP_ADAPTER(Variable) +DECLARE_OP_ADAPTER(ReluGrad) +DECLARE_OP_USE_OUTPUT(ReluGrad) +DECLARE_OP_ADAPTER(BiasAddGrad) +DECLARE_OP_USE_OUTPUT(BiasAddGrad) +DECLARE_OP_ADAPTER(MaxPoolWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax) +DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax) +DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax) +DECLARE_OP_ADAPTER(Conv2D) +DECLARE_OP_USE_ENUM(Conv2D) +DECLARE_OP_USE_OUTPUT(Conv2D) +DECLARE_OP_ADAPTER(ExtractImagePatches) +DECLARE_OP_USE_OUTPUT(ExtractImagePatches) +DECLARE_OP_ADAPTER(Conv2DBackpropInputD) +DECLARE_OP_USE_ENUM(Conv2DBackpropInputD) +DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD) +DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD) +DECLARE_OP_ADAPTER(Conv2DBackpropFilterD) +DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD) +DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD) +DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD) +DECLARE_OP_ADAPTER(DepthwiseConv2D) +DECLARE_OP_USE_ENUM(DepthwiseConv2D) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2D) +DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropFilterD) +DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropFilterD) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD) 
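+// Illustrative sketch of the two-step wiring (names like "MyOp" are hypothetical, shown only as
+// documentation): a new adapter is declared here so its template specializations exist,
+//   DECLARE_OP_ADAPTER(MyOp)
+//   DECLARE_OP_USE_OUTPUT(MyOp)
+// and its maps are then defined in op_declare.cc,
+//   INPUT_MAP(MyOp) = {{1, INPUT_DESC(x)}};
+//   ATTR_MAP(MyOp) = EMPTY_ATTR_MAP;
+//   OUTPUT_MAP(MyOp) = {{0, OUTPUT_DESC(y)}};
+// A real operator must additionally exist as ge::op::MyOp and be registered with the converter.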
+DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD) +DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD) +DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD) +DECLARE_OP_ADAPTER(Reshape) +DECLARE_OP_USE_OUTPUT(Reshape) +DECLARE_OP_ADAPTER(TransShape) +DECLARE_OP_USE_INPUT_ATTR(TransShape) +DECLARE_OP_USE_OUTPUT(TransShape) +DECLARE_OP_ADAPTER(Iou) +DECLARE_OP_USE_OUTPUT(Iou) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D) +DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad) +DECLARE_OP_ADAPTER(ApplyAdam) +DECLARE_OP_USE_OUTPUT(ApplyAdam) +DECLARE_OP_ADAPTER(ApplyAdamD) +DECLARE_OP_USE_OUTPUT(ApplyAdamD) +DECLARE_OP_ADAPTER(Relu6) +DECLARE_OP_USE_OUTPUT(Relu6) +DECLARE_OP_ADAPTER(Relu6Grad) +DECLARE_OP_USE_OUTPUT(Relu6Grad) +DECLARE_OP_ADAPTER(ResizeBilinearV2D) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D) +DECLARE_OP_ADAPTER(ResizeBilinearV2Grad) +DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad) +DECLARE_OP_ADAPTER(ZerosLike) +DECLARE_OP_USE_OUTPUT(ZerosLike) +DECLARE_OP_ADAPTER(OnesLike) +DECLARE_OP_USE_OUTPUT(OnesLike) +DECLARE_OP_ADAPTER(TensorScatterUpdate) +DECLARE_OP_USE_OUTPUT(TensorScatterUpdate) +DECLARE_OP_ADAPTER(ScatterUpdate) +DECLARE_OP_USE_OUTPUT(ScatterUpdate) +DECLARE_OP_ADAPTER(ScatterNdUpdate) +DECLARE_OP_USE_OUTPUT(ScatterNdUpdate) +DECLARE_OP_ADAPTER(ScatterMax) +DECLARE_OP_USE_OUTPUT(ScatterMax) +DECLARE_OP_ADAPTER(NMSWithMask) +DECLARE_OP_USE_OUTPUT(NMSWithMask) +DECLARE_OP_ADAPTER(Unpack) +DECLARE_OP_USE_DYN_OUTPUT(Unpack) +DECLARE_OP_ADAPTER(CheckValid) +DECLARE_OP_USE_OUTPUT(CheckValid) +DECLARE_OP_ADAPTER(SmoothL1Loss) +DECLARE_OP_USE_OUTPUT(SmoothL1Loss) +DECLARE_OP_ADAPTER(SmoothL1LossGrad) +DECLARE_OP_USE_OUTPUT(SmoothL1LossGrad) +DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogits) +DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogits) +DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogitsGrad) +DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogitsGrad) +DECLARE_OP_ADAPTER(ScatterNdD) +DECLARE_OP_USE_INPUT_ATTR(ScatterNdD) +DECLARE_OP_USE_OUTPUT(ScatterNdD) +DECLARE_OP_ADAPTER(PadD) +DECLARE_OP_USE_OUTPUT(PadD) +DECLARE_OP_ADAPTER(MirrorPad) +DECLARE_OP_USE_OUTPUT(MirrorPad) +DECLARE_OP_ADAPTER(MirrorPadGrad) +DECLARE_OP_USE_OUTPUT(MirrorPadGrad) +DECLARE_OP_ADAPTER(BoundingBoxEncode) +DECLARE_OP_USE_OUTPUT(BoundingBoxEncode) +DECLARE_OP_ADAPTER(BoundingBoxDecode) +DECLARE_OP_USE_OUTPUT(BoundingBoxDecode) +DECLARE_OP_ADAPTER(GatherNd) +DECLARE_OP_USE_OUTPUT(GatherNd) +DECLARE_OP_ADAPTER(ArgMaxD) +DECLARE_OP_USE_OUTPUT(ArgMaxD) +DECLARE_OP_ADAPTER(ArgMinD) +DECLARE_OP_USE_OUTPUT(ArgMinD) +DECLARE_OP_ADAPTER(ArgMaxWithValue) +DECLARE_OP_USE_OUTPUT(ArgMaxWithValue) +DECLARE_OP_ADAPTER(ArgMinWithValue) +DECLARE_OP_USE_OUTPUT(ArgMinWithValue) +DECLARE_OP_ADAPTER(Mul) +DECLARE_OP_USE_OUTPUT(Mul) +DECLARE_OP_ADAPTER(AddN) +DECLARE_OP_USE_DYN_INPUT(AddN) +DECLARE_OP_USE_OUTPUT(AddN) +DECLARE_OP_ADAPTER(Less) +DECLARE_OP_USE_OUTPUT(Less) +DECLARE_OP_ADAPTER(Rsqrt) +DECLARE_OP_USE_OUTPUT(Rsqrt) +DECLARE_OP_ADAPTER(Sqrt) +DECLARE_OP_USE_OUTPUT(Sqrt) +DECLARE_OP_ADAPTER(Square) +DECLARE_OP_USE_OUTPUT(Square) +DECLARE_OP_ADAPTER(SplitD) +DECLARE_OP_USE_DYN_OUTPUT(SplitD) +DECLARE_OP_ADAPTER(SGD) +DECLARE_OP_USE_OUTPUT(SGD) +DECLARE_OP_ADAPTER(SquareSumAll) +DECLARE_OP_USE_OUTPUT(SquareSumAll) + +DECLARE_OP_ADAPTER(Tanh) +DECLARE_OP_USE_OUTPUT(Tanh) +DECLARE_OP_ADAPTER(TanhGrad) +DECLARE_OP_USE_OUTPUT(TanhGrad) +DECLARE_OP_ADAPTER(Maximum) +DECLARE_OP_USE_OUTPUT(Maximum) 
+DECLARE_OP_ADAPTER(Minimum) +DECLARE_OP_USE_OUTPUT(Minimum) +DECLARE_OP_ADAPTER(MaximumGrad) +DECLARE_OP_USE_OUTPUT(MaximumGrad) +DECLARE_OP_ADAPTER(MinimumGrad) +DECLARE_OP_USE_OUTPUT(MinimumGrad) +DECLARE_OP_ADAPTER(ReduceMinD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMinD) +DECLARE_OP_USE_OUTPUT(ReduceMinD) +DECLARE_OP_ADAPTER(ReduceMaxD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD) +DECLARE_OP_USE_OUTPUT(ReduceMaxD) +DECLARE_OP_ADAPTER(Merge) +DECLARE_OP_USE_DYN_INPUT(Merge) +DECLARE_OP_USE_OUTPUT(Merge) +DECLARE_OP_ADAPTER(Switch) +DECLARE_OP_USE_OUTPUT(Switch) + +DECLARE_OP_ADAPTER(TopK) +DECLARE_OP_USE_OUTPUT(TopK) + +DECLARE_OP_ADAPTER(RealDiv) +DECLARE_OP_USE_OUTPUT(RealDiv) + +DECLARE_OP_ADAPTER(Cast) +DECLARE_OP_USE_INPUT_ATTR(Cast) +DECLARE_OP_USE_OUTPUT(Cast) +DECLARE_OP_ADAPTER(Case) +DECLARE_OP_USE_DYN_INPUT(Case) +DECLARE_OP_USE_DYN_SUBGRAPH(Case) +DECLARE_OP_USE_DYN_OUTPUT(Case) +DECLARE_OP_ADAPTER(Reciprocal) +DECLARE_OP_USE_OUTPUT(Reciprocal) +DECLARE_OP_ADAPTER(Neg) +DECLARE_OP_USE_OUTPUT(Neg) +DECLARE_OP_ADAPTER(TransposeD) +DECLARE_OP_USE_INPUT_ATTR(TransposeD) +// Do not set Transpose operator output descriptor +DECLARE_OP_ADAPTER(Sub) +DECLARE_OP_USE_OUTPUT(Sub) +DECLARE_OP_ADAPTER(DropOutGenMask) +DECLARE_OP_USE_OUTPUT(DropOutGenMask) +DECLARE_OP_ADAPTER(ConcatD) +DECLARE_OP_USE_DYN_INPUT(ConcatD) +DECLARE_OP_USE_OUTPUT(ConcatD) +DECLARE_OP_ADAPTER(Pack) +DECLARE_OP_USE_DYN_INPUT(Pack) +DECLARE_OP_USE_OUTPUT(Pack) + +DECLARE_OP_ADAPTER(Pow) +DECLARE_OP_USE_OUTPUT(Pow) +DECLARE_OP_ADAPTER(Equal) +DECLARE_OP_USE_OUTPUT(Equal) +DECLARE_OP_ADAPTER(NotEqual) +DECLARE_OP_USE_OUTPUT(NotEqual) +DECLARE_OP_ADAPTER(Log) +DECLARE_OP_USE_OUTPUT(Log) +DECLARE_OP_ADAPTER(LogicalAnd) +DECLARE_OP_USE_OUTPUT(LogicalAnd) +DECLARE_OP_ADAPTER(LogicalOr) +DECLARE_OP_USE_OUTPUT(LogicalOr) +DECLARE_OP_ADAPTER(LogicalNot) +DECLARE_OP_USE_OUTPUT(LogicalNot) +DECLARE_OP_ADAPTER(LogSoftmaxGrad) +DECLARE_OP_USE_OUTPUT(LogSoftmaxGrad) + +DECLARE_OP_ADAPTER(RandomChoiceWithMask) +DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask) + +DECLARE_OP_ADAPTER(Select) +DECLARE_OP_USE_OUTPUT(Select) +DECLARE_OP_ADAPTER(LessEqual) +DECLARE_OP_USE_OUTPUT(LessEqual) +DECLARE_OP_ADAPTER(LogSoftmaxV2) +DECLARE_OP_USE_OUTPUT(LogSoftmaxV2) +DECLARE_OP_ADAPTER(TruncatedNormal) +DECLARE_OP_USE_OUTPUT(TruncatedNormal) +DECLARE_OP_ADAPTER(StridedSliceGrad) +DECLARE_OP_USE_OUTPUT(StridedSliceGrad) +DECLARE_OP_ADAPTER(Gelu) +DECLARE_OP_USE_OUTPUT(Gelu) +DECLARE_OP_ADAPTER(GeluGrad) +DECLARE_OP_USE_OUTPUT(GeluGrad) +DECLARE_OP_ADAPTER(StridedSlice) +DECLARE_OP_USE_OUTPUT(StridedSlice) +DECLARE_OP_ADAPTER(UnsortedSegmentSumD) +DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD) +DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD) +DECLARE_OP_ADAPTER(UnsortedSegmentMin) +DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin) +DECLARE_OP_ADAPTER(ExpandDims) +DECLARE_OP_USE_OUTPUT(ExpandDims) +DECLARE_OP_ADAPTER(Squeeze) +DECLARE_OP_USE_OUTPUT(Squeeze) +DECLARE_OP_ADAPTER(LayerNorm) +DECLARE_OP_USE_OUTPUT(LayerNorm) +DECLARE_OP_ADAPTER(LayerNormGrad) +DECLARE_OP_USE_OUTPUT(LayerNormGrad) +DECLARE_OP_ADAPTER(BatchMatMul) +DECLARE_OP_USE_OUTPUT(BatchMatMul) +DECLARE_OP_ADAPTER(DropOutDoMask) +DECLARE_OP_USE_OUTPUT(DropOutDoMask) +// ** Mix-precision Operations ** +DECLARE_OP_ADAPTER(NPUGetFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUGetFloatStatus) +DECLARE_OP_ADAPTER(NPUAllocFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUAllocFloatStatus) +DECLARE_OP_ADAPTER(NPUClearFloatStatus) +DECLARE_OP_USE_OUTPUT(NPUClearFloatStatus) +DECLARE_OP_ADAPTER(MatMulV2) 
+DECLARE_OP_USE_OUTPUT(MatMulV2) + +DECLARE_OP_ADAPTER(SoftmaxCrossEntropyWithLogits) +DECLARE_OP_USE_OUTPUT(SoftmaxCrossEntropyWithLogits) + +DECLARE_OP_ADAPTER(MeanGrad) +DECLARE_OP_USE_INPUT_ATTR(MeanGrad) + +DECLARE_OP_ADAPTER(Assign) +DECLARE_OP_USE_OUTPUT(Assign) +DECLARE_OP_ADAPTER(Constant) +DECLARE_OP_USE_OUTPUT(Constant) +DECLARE_OP_ADAPTER(ApplyMomentumD) +DECLARE_OP_USE_OUTPUT(ApplyMomentumD) +// ** Summary Operations ** +DECLARE_OP_ADAPTER(Summary) + +// fully supported +DECLARE_OP_ADAPTER(Add) +DECLARE_OP_USE_OUTPUT(Add) +DECLARE_OP_ADAPTER(Const) +DECLARE_OP_USE_OUTPUT(Const) +DECLARE_OP_ADAPTER(Cos) +DECLARE_OP_USE_OUTPUT(Cos) + +DECLARE_OP_ADAPTER(Acos) +DECLARE_OP_USE_OUTPUT(Acos) +DECLARE_OP_ADAPTER(AcosGrad) +DECLARE_OP_USE_OUTPUT(AcosGrad) +DECLARE_OP_ADAPTER(Acosh) +DECLARE_OP_USE_OUTPUT(Acosh) +DECLARE_OP_ADAPTER(AcoshGrad) +DECLARE_OP_USE_OUTPUT(AcoshGrad) + +DECLARE_OP_ADAPTER(Floor) +DECLARE_OP_USE_OUTPUT(Floor) +DECLARE_OP_ADAPTER(FloorDiv) +DECLARE_OP_USE_OUTPUT(FloorDiv) +DECLARE_OP_ADAPTER(FloorMod) +DECLARE_OP_USE_OUTPUT(FloorMod) +DECLARE_OP_ADAPTER(Sin) +DECLARE_OP_USE_OUTPUT(Sin) +DECLARE_OP_ADAPTER(Exp) +DECLARE_OP_USE_OUTPUT(Exp) + +DECLARE_OP_ADAPTER(ReduceAllD) +DECLARE_OP_USE_INPUT_ATTR(ReduceAllD) +DECLARE_OP_USE_OUTPUT(ReduceAllD) +DECLARE_OP_ADAPTER(ReduceSumD) +DECLARE_OP_USE_INPUT_ATTR(ReduceSumD) +DECLARE_OP_USE_OUTPUT(ReduceSumD) +DECLARE_OP_ADAPTER(ReduceMeanD) +DECLARE_OP_USE_INPUT_ATTR(ReduceMeanD) +DECLARE_OP_USE_OUTPUT(ReduceMeanD) +DECLARE_OP_ADAPTER(ReduceProdD) +DECLARE_OP_USE_INPUT_ATTR(ReduceProdD) +DECLARE_OP_USE_OUTPUT(ReduceProdD) +DECLARE_OP_ADAPTER(CumprodD) +DECLARE_OP_USE_INPUT_ATTR(CumprodD) +DECLARE_OP_USE_OUTPUT(CumprodD) + +DECLARE_OP_ADAPTER(TileD) +DECLARE_OP_USE_INPUT_ATTR(TileD) +DECLARE_OP_USE_OUTPUT(TileD) +DECLARE_OP_ADAPTER(OneHot) +DECLARE_OP_USE_OUTPUT(OneHot) +DECLARE_OP_ADAPTER(GatherV2D) +DECLARE_OP_USE_INPUT_ATTR(GatherV2D) +DECLARE_OP_USE_OUTPUT(GatherV2D) +DECLARE_OP_ADAPTER(RangeD) +DECLARE_OP_USE_OUTPUT(RangeD) + +DECLARE_OP_ADAPTER(Data) +DECLARE_OP_ADAPTER(BiasAdd) +DECLARE_OP_USE_OUTPUT(BiasAdd) +DECLARE_OP_ADAPTER(BatchNorm) +DECLARE_OP_USE_OUTPUT(BatchNorm) +DECLARE_OP_ADAPTER(BatchNormGrad) +DECLARE_OP_USE_OUTPUT(BatchNormGrad) +DECLARE_OP_ADAPTER(Relu) +DECLARE_OP_USE_OUTPUT(Relu) +DECLARE_OP_ADAPTER(PRelu) +DECLARE_OP_USE_OUTPUT(PRelu) +DECLARE_OP_ADAPTER(Elu) +DECLARE_OP_USE_OUTPUT(Elu) + +DECLARE_OP_ADAPTER(EluGrad) +DECLARE_OP_USE_OUTPUT(EluGrad) +DECLARE_OP_ADAPTER(PReluGrad) +DECLARE_OP_USE_OUTPUT(PReluGrad) + +DECLARE_OP_ADAPTER(L2Normalize) +DECLARE_OP_USE_OUTPUT(L2Normalize) + +DECLARE_OP_ADAPTER(CumsumD) +DECLARE_OP_USE_INPUT_ATTR(CumsumD) +DECLARE_OP_USE_OUTPUT(CumsumD) +DECLARE_OP_ADAPTER(L2NormalizeGrad) +DECLARE_OP_USE_OUTPUT(L2NormalizeGrad) +DECLARE_OP_ADAPTER(Sigmoid) +DECLARE_OP_USE_OUTPUT(Sigmoid) +DECLARE_OP_ADAPTER(SigmoidGrad) +DECLARE_OP_USE_OUTPUT(SigmoidGrad) +DECLARE_OP_ADAPTER(SoftmaxV2) +DECLARE_OP_USE_OUTPUT(SoftmaxV2) +DECLARE_OP_ADAPTER(SoftmaxGrad) +DECLARE_OP_USE_OUTPUT(SoftmaxGrad) +DECLARE_OP_ADAPTER(Greater) +DECLARE_OP_USE_OUTPUT(Greater) +DECLARE_OP_ADAPTER(Flatten) +DECLARE_OP_USE_OUTPUT(Flatten) +DECLARE_OP_ADAPTER(GatherV2) +DECLARE_OP_USE_OUTPUT(GatherV2) +DECLARE_OP_ADAPTER(MaxPool) +DECLARE_OP_USE_OUTPUT(MaxPool) +DECLARE_OP_ADAPTER(MaxPoolGrad) +DECLARE_OP_USE_OUTPUT(MaxPoolGrad) +DECLARE_OP_ADAPTER(AvgPool) +DECLARE_OP_USE_OUTPUT(AvgPool) +DECLARE_OP_ADAPTER(AvgPoolGrad) +DECLARE_OP_USE_OUTPUT(AvgPoolGrad) +DECLARE_OP_ADAPTER(ROIAlign) 
+DECLARE_OP_USE_OUTPUT(ROIAlign) +DECLARE_OP_ADAPTER(ROIAlignGrad) +DECLARE_OP_USE_OUTPUT(ROIAlignGrad) +DECLARE_OP_ADAPTER(Abs) +DECLARE_OP_USE_OUTPUT(Abs) +DECLARE_OP_ADAPTER(AbsGrad) +DECLARE_OP_USE_OUTPUT(AbsGrad) +DECLARE_OP_ADAPTER(BinaryCrossEntropy) +DECLARE_OP_USE_OUTPUT(BinaryCrossEntropy) +DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad) +DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad) +DECLARE_OP_ADAPTER(SparseApplyAdagradD) +DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD) +DECLARE_OP_ADAPTER(ApplyProximalAdagradD) +DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD) +DECLARE_OP_ADAPTER(SpaceToDepth) +DECLARE_OP_USE_OUTPUT(SpaceToDepth) +DECLARE_OP_ADAPTER(DepthToSpace) +DECLARE_OP_USE_OUTPUT(DepthToSpace) +DECLARE_OP_ADAPTER(Sign) +DECLARE_OP_USE_OUTPUT(Sign) +DECLARE_OP_ADAPTER(LarsV2Update) +DECLARE_OP_USE_OUTPUT(LarsV2Update) +DECLARE_OP_ADAPTER(Round) +DECLARE_OP_USE_OUTPUT(Round) +DECLARE_OP_ADAPTER(ApplyFtrlD) +DECLARE_OP_USE_OUTPUT(ApplyFtrlD) +DECLARE_OP_ADAPTER(SparseApplyFtrlD) +DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD) +DECLARE_OP_ADAPTER(Diag) +DECLARE_OP_USE_OUTPUT(Diag) +DECLARE_OP_ADAPTER(DiagPart) +DECLARE_OP_USE_OUTPUT(DiagPart) +DECLARE_OP_ADAPTER(SpaceToBatchD) +DECLARE_OP_USE_OUTPUT(SpaceToBatchD) +DECLARE_OP_ADAPTER(BatchToSpaceD) +DECLARE_OP_USE_OUTPUT(BatchToSpaceD) +DECLARE_OP_ADAPTER(Atan2) +DECLARE_OP_USE_OUTPUT(Atan2) +DECLARE_OP_ADAPTER(ApplyRMSPropD) +DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD) +DECLARE_OP_USE_OUTPUT(ApplyRMSPropD) +DECLARE_OP_ADAPTER(ApplyCenteredRMSProp) +DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp) +DECLARE_OP_ADAPTER(L2Loss) +DECLARE_OP_USE_OUTPUT(L2Loss) +DECLARE_OP_ADAPTER(CTCLoss) +DECLARE_OP_USE_OUTPUT(CTCLoss) +DECLARE_OP_ADAPTER(AscendQuant) +DECLARE_OP_USE_OUTPUT(AscendQuant) +DECLARE_OP_ADAPTER(AscendDequant) +DECLARE_OP_USE_OUTPUT(AscendDequant) +#ifdef ENABLE_GE +DECLARE_OP_ADAPTER(Print) +DECLARE_OP_USE_DYN_INPUT(Print) +#endif +} // namespace transform +} // namespace mindspore +#endif // TRANSFORM_OP_DECLARE_H_ diff --git a/mindspore/ccsrc/transform/types.h b/mindspore/ccsrc/transform/graph_ir/types.h similarity index 100% rename from mindspore/ccsrc/transform/types.h rename to mindspore/ccsrc/transform/graph_ir/types.h diff --git a/mindspore/ccsrc/transform/graph_ir/util.cc b/mindspore/ccsrc/transform/graph_ir/util.cc new file mode 100644 index 0000000000..6ae665d69f --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/util.cc @@ -0,0 +1,452 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "transform/graph_ir/util.h" + +#include +#include +#include + +#include "securec/include/securec.h" +#include "utils/convert_utils.h" +#include "utils/utils.h" + +namespace mindspore { +namespace transform { +using std::make_shared; +using std::shared_ptr; +using std::string; +using std::vector; + +const size_t kErrorSize = 0; + +vector TransformUtil::ConvertIntToList(int64_t data, int size) { + vector list{}; + if (size <= 0) { + MS_LOG(WARNING) << "size <= 0"; + return list; + } + for (int i = 0; i < size; ++i) { + list.push_back(data); + } + return list; +} + +static std::map datatype_trans_map = { + {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT}, + {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE}, {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8}, + {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16}, {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32}, + {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64}, {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8}, + {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, + {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; + +GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { + MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; + if (datatype_trans_map.find(type) != datatype_trans_map.end()) { + return datatype_trans_map[type]; + } else { + return GeDataType::DT_UNDEFINED; + } +} + +static std::map datatype_size_map = { + {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)}, // 1/2 of float + {MeDataType::kNumberTypeFloat64, sizeof(double)}, {MeDataType::kNumberTypeInt8, sizeof(int8_t)}, + {MeDataType::kNumberTypeInt16, sizeof(int16_t)}, {MeDataType::kNumberTypeInt32, sizeof(int32_t)}, + {MeDataType::kNumberTypeInt64, sizeof(int64_t)}, {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)}, + {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, + {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}}; + +size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { + if (datatype_size_map.find(type) != datatype_size_map.end()) { + return datatype_size_map[type]; + } else { + MS_LOG(ERROR) << "Illegal tensor data type!"; + return kErrorSize; + } +} + +GeFormat TransformUtil::ConvertFormat(const string &format) { + if (format == kOpFormat_NCHW) { + return GeFormat::FORMAT_NCHW; + } else if (format == kOpFormat_NC1HWC0) { + return GeFormat::FORMAT_NC1HWC0; + } else if (format == kOpFormat_NHWC) { + return GeFormat::FORMAT_NHWC; + } else if (format == kOpFormat_HWCN) { + return GeFormat::FORMAT_HWCN; + } else { + return GeFormat::FORMAT_ND; + } +} + +static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } + +std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, + const MeDataType &me_type, const std::string &format) { + // convert me shape to ge shape + std::vector ge_shape; + + if (me_shape.size() == 1) { + ge_shape.push_back(static_cast(me_shape[0])); + } else { + ge_shape.resize(me_shape.size()); + (void)std::transform(me_shape.begin(), me_shape.end(), ge_shape.begin(), IntegerCastFunc); + } + + GeShape shape(ge_shape); + if (shape.GetDimNum() == 0) { + MS_LOG(INFO) << "The dims size of Ge tensor is zero"; + } + // 
convert me format to ge format + GeFormat ge_format = ConvertFormat(format); + if (ge_format == GeFormat::FORMAT_ND) { + MS_LOG(ERROR) << "undefined data format : " << static_cast(ge_format); + return nullptr; + } + // convert me datatype to ge datatype + GeDataType data_type = ConvertDataType(me_type); + if (data_type == GeDataType::DT_UNDEFINED) { + MS_LOG(ERROR) << "undefined data type :" << me_type; + return nullptr; + } + + auto desc = std::make_shared(shape, ge_format, data_type); + if (desc == nullptr) { + MS_LOG(ERROR) << "Create GeTensorDesc failed!"; + return nullptr; + } + MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size(); + desc->SetRealDimCnt(SizeToInt(me_shape.size())); + return desc; +} + +// if failed, return empty vector. +std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, + const std::string &format) { + std::vector ge_tensors; + + for (size_t index = 0; index < me_tensors.size(); index++) { + MS_EXCEPTION_IF_NULL(me_tensors[index]); + MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize(); + auto shape = me_tensors[index]->shape(); + std::string shape_str; + for (size_t i = 0; i < shape.size(); i++) { + shape_str += std::to_string(shape[i]); + shape_str += " "; + } + MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}"; + MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type(); + + auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format); + if (ge_tensor_ptr != nullptr) { + ge_tensors.emplace_back(ge_tensor_ptr); + } else { + MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!"; + ge_tensors.clear(); + return ge_tensors; + } + } + return ge_tensors; +} + +GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) { + // get tensor data type size + MS_EXCEPTION_IF_NULL(tensor); + size_t type_size = GetDataTypeSize(tensor->data_type()); + if (type_size == kErrorSize) { + MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; + return nullptr; + } + size_t elements_num = IntToSize(tensor->ElementsNum()); + if (UINT_MAX / type_size < elements_num) { + MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size + << " overflowed UINT_MAX: " << UINT_MAX << "."; + return nullptr; + } + + // get tensor buff size + size_t data_buff_size = elements_num * type_size; + if (data_buff_size == 0) { + MS_LOG(INFO) << "The Me Tensor data buff size is 0."; + } + // create ge tensor + auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format); + if (desc == nullptr) { + MS_LOG(ERROR) << "Failed to get Tensor Desc"; + return nullptr; + } + GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); + if (tensor_ptr != nullptr) { + MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; + } + return tensor_ptr; +} + +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims) { + std::vector outputs; + + for (size_t index = 0; index < ge_tensors.size(); index++) { + MeTensorPtr me_tensor_ptr = nullptr; + if (index < request_dims.size()) { + me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]); + } else { + std::vector empty_shape; + me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape); + } + + if (me_tensor_ptr != nullptr) { + outputs.emplace_back(me_tensor_ptr); + } else { + 
MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; + return outputs; + } + } + return outputs; +} + +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { + std::vector outputs; + + for (size_t index = 0; index < ge_tensors.size(); index++) { + MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]); + if (me_tensor_ptr != nullptr) { + outputs.emplace_back(me_tensor_ptr); + } else { + MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; + return outputs; + } + } + return outputs; +} + +MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) { + switch (type) { + case GeDataType::DT_FLOAT16: + return MeDataType::kNumberTypeFloat16; + case GeDataType::DT_FLOAT: + return MeDataType::kNumberTypeFloat32; + case GeDataType::DT_DOUBLE: + return MeDataType::kNumberTypeFloat64; + case GeDataType::DT_INT64: + return MeDataType::kNumberTypeInt64; + case GeDataType::DT_INT32: + return MeDataType::kNumberTypeInt32; + case GeDataType::DT_INT16: + return MeDataType::kNumberTypeInt16; + case GeDataType::DT_INT8: + return MeDataType::kNumberTypeInt8; + case GeDataType::DT_BOOL: + return MeDataType::kNumberTypeBool; + case GeDataType::DT_UINT8: + return MeDataType::kNumberTypeUInt8; + case GeDataType::DT_UINT16: + return MeDataType::kNumberTypeUInt16; + case GeDataType::DT_UINT32: + return MeDataType::kNumberTypeUInt32; + case GeDataType::DT_UINT64: + return MeDataType::kNumberTypeUInt64; + case GeDataType::DT_UNDEFINED: + case GeDataType::DT_DUAL_SUB_UINT8: + case GeDataType::DT_DUAL_SUB_INT8: + case GeDataType::DT_DUAL: + return MeDataType::kTypeUnknown; + default: + return MeDataType::kTypeUnknown; + } +} + +namespace { +bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { + MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); + MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); + + const int GE_DIMS = 4; + std::vector ge_dims = ge_shape.GetDims(); + if (request_dims.size() > ge_dims.size()) { + MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's"; + return false; + } + + // convert NHWC to NCHW + if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) && + (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) { + MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; + return true; + } + + std::string::size_type i = 0; + for (; i < request_dims.size(); i++) { + if (ge_dims[i] != request_dims[i]) { + MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's"; + return false; + } + } + + for (; i < ge_dims.size(); i++) { + if (ge_dims[i] != 1) { + MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1"; + return false; + } + } + MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; + return true; +} +} // namespace + +GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { + std::vector ge_dims; + (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); + return GeShape(ge_dims); +} + +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { + std::vector me_dims; + std::vector ge_dims = ge_shape.GetDims(); + (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); + return me_dims; +} + +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { + vector ret; + if (ge_shape.GetDimNum() == 0) { + MS_LOG(DEBUG) 
<< "GeTensor's shape is scalar"; + return ret; + } + + if (IsGeShapeCompatible(ge_shape, request_dims) == true) { + ret = request_dims; + } else { + MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape"; + ret = ConvertGeShape(ge_shape); + } + return ret; +} + +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type) { + MeTensor me_tensor(me_type, me_dims); + + // Get the writable data pointer of the tensor and cast it to its data type + auto me_data_ptr = reinterpret_cast(me_tensor.data_c()); + size_t me_data_size = static_cast(me_tensor.data().nbytes()); + MS_EXCEPTION_IF_NULL(me_data_ptr); + MS_EXCEPTION_IF_NULL(ge_tensor); + if (me_data_size < ge_tensor->GetSize()) { + MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" + << ge_tensor->GetSize() << " bytes]"; + return nullptr; + } + + // Copy or use the writable data pointer of the ME tensor + MS_EXCEPTION_IF_NULL(ge_tensor->GetData()); + if (ge_tensor->GetSize() == 0) { + MS_LOG(ERROR) << "GE tensor data size is zero!"; + return nullptr; + } + + // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB + // which is the size limit of memcpy_s + memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize()); + + return make_shared(me_tensor); +} + +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { + MS_EXCEPTION_IF_NULL(ge_tensor); + GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); + vector me_dims = ConvertGeShape(ge_shape); + + TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); + if (type_id == MeDataType::kTypeUnknown) { + MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " + << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + return nullptr; + } + return GenerateMeTensor(ge_tensor, me_dims, type_id); +} + +// if request_dims is empty, use ge tensor's shape,otherwise convert to request shape +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { + MS_EXCEPTION_IF_NULL(ge_tensor); + GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); + vector me_dims = ConvertGeShape(ge_shape, request_dims); + MS_LOG(INFO) << "GE tensor type is " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + // Create a tensor with wanted data type and shape + TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); + if (type_id == MeDataType::kTypeUnknown) { + MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " + << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + return nullptr; + } + return GenerateMeTensor(ge_tensor, me_dims, type_id); +} + +std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) { + std::string ret; + if (ge_tensor == nullptr) { + MS_LOG(ERROR) << "Input ge tensor is nullptr"; + return ret; + } + + MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); + switch (ge_tensor->GetTensorDesc().GetDataType()) { + case GeDataType::DT_UINT32: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_FLOAT: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT32: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_DOUBLE: + ret = 
PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT64: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT64: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_INT16: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT16: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_DUAL_SUB_INT8: + case GeDataType::DT_INT8: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_UINT8: + case GeDataType::DT_DUAL_SUB_UINT8: + ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); + break; + case GeDataType::DT_FLOAT16: + case GeDataType::DT_BOOL: + case GeDataType::DT_UNDEFINED: + case GeDataType::DT_DUAL: + default: + MS_LOG(ERROR) << "Unsupported to print type:" << static_cast(ge_tensor->GetTensorDesc().GetDataType()) + << " ge tensor"; + break; + } + return ret; +} +} // namespace transform +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_ir/util.h b/mindspore/ccsrc/transform/graph_ir/util.h new file mode 100644 index 0000000000..32d4242c4f --- /dev/null +++ b/mindspore/ccsrc/transform/graph_ir/util.h @@ -0,0 +1,241 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TRANSFORM_UTIL_H_ +#define TRANSFORM_UTIL_H_ + +#include +#include +#include +#include +#include "securec/include/securec.h" +#include "ir/anf.h" +#include "ir/dtype.h" +#include "ir/tensor.h" +#include "transform/graph_ir/types.h" + +#include "graph/tensor.h" + +namespace mindspore { +namespace transform { +class TransformUtil { + public: + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [GeDataType] the data type for ge tensor + * */ + static std::vector ConvertIntToList(int64_t data, int size); + + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [GeDataType] the data type for ge tensor + * */ + static GeDataType ConvertDataType(const MeDataType &type); + + /* + * Parameters: + * type: [string] the data format in ME op + * Return: + * [GeFormat] the data format for ge tensor + * */ + static GeFormat ConvertFormat(const std::string &format); + + /* + * Parameters: + * type: [MeDataType] the data type for ME tensor + * Return: + * [size_t] the buff size for the type in ME + * */ + static size_t GetDataTypeSize(const MeDataType &type); + + /* + * Parameters: + * tensor: [MeTensorPtr] the me tensor to get description from + * format: [string] the data format in ME + * is_input: [bool] whether the tensor is used as input, default:false + * Return: + * [shared_ptr] the shared pointer of ge tensor description + * */ + static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + const std::string &format); + + /* + * Parameters: + * tensor: [MeTensor] the data tensor in ME + * format: [string] the data format in ME op + * is_input: [bool] whether the tensor is used as input, default:false + * Return: + * [GeTensor] the data tensor in GE + * */ + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); + + /* + * Parameters: + * me_tensors: [vector] the data tensors in ME + * format: [string] the data format in ME op + * Return: + * [std::vector] the data tensors in GE + * */ + static std::vector ConvertInputTensors(const std::vector &me_tensors, + const std::string &format); + + /* + * Parameters: + * tensor: [GeTensor] the data tensor in GE + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); + + /* + * Parameters: + * tensor: [GeTensor] the data tensor in GE + * request_dims [std::vector] the output Me tensors must adjust to this shapes + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); + /* + * Parameters: + * ge_tensors: [std::vector] the data tensor in GE + * request_dims [std::vector>] the output Me tensors must adjust to this shapes + * Return: + * [std::vector] the data tensor in ME + * */ + static std::vector ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims); + /* + * Parameters: + * ge_tensors: [std::vector] the data tensor in GE + * Return: + * [std::vector] the data tensor in ME + * */ + static std::vector ConvertGeTensors(const std::vector &ge_tensors); + /* + * Parameters: + * ge_tensor: [GeTensor] the data tensor in GE + * me_dims: [std::vector] the shape of created Me tensor + * me_type: [TypeId] the type of created Me tensor + * Return: + * [MeTensor] the data tensor in ME + * */ + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type); + /* + * Parameters: + * 
type: [GeDataType] the ge tensor data type + * Return: + * [MeDataType] the me tensor data type + * */ + static MeDataType ConvertGeDataType(const GeDataType &type); + + /* + * Parameters: + * me_dims: [std::vector] the me shape + * Return: + * [GeShape] the ge shape + * */ + static GeShape ConvertMeShape(const std::vector &me_dims); + + /* + * Parameters: + * ge_shape: [GeShape] the ge shape + * Return: + * [vector] the me shape + * */ + static std::vector ConvertGeShape(const GeShape &ge_shape); + + /* Function: + * Convert GeShape to Me request shape, Support pattern: + * {1, x, 1, 1} --> {x} + * {x, 1, 1, 1} --> {x} + * {x, x, 1, 1} --> {x, x} + * {x, x, x, 1} --> {x, x, x} + * {x, x, x, x} --> {x, x, x, x} + * If unmatch upon patterns, return original ge dims + * Parameters: + * ge_shape: [GeShape] the ge shape + * request_dims: [vector] request dims + * Return: + * [vector] the me shape + * */ + static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); + + /* + * Parameters: + * vec: [std::vector] the vector to print + * Return: + * [string] value string + * */ + template ::value>::type> + static std::string PrintVector(const std::vector &vec) { + const int MAX_PRINT_NUM = 100; + std::stringstream ss; + ss << "{ "; + int i = 0; + for (auto it = vec.begin(); it != vec.end(); ++it) { + ss << std::to_string(*it) << ", "; + i++; + if (i >= MAX_PRINT_NUM) { + break; + } + } + + if (i >= MAX_PRINT_NUM) { + ss << "... to be continue}"; + } else { + ss << "}"; + } + return ss.str(); + } + + /* + * Parameters: + * ge_tensor: [GeTensorPtr] the ge tensor + * Return: + * [stringstream] value string + * */ + static std::string PrintGeTensor(const GeTensorPtr ge_tensor); + + /* + * Parameters: + * data: [uint8_t *] the ge tensor data pointer + * size: [size_t] the ge tensor data bytes + * Return: + * [shared_ptr] vector pointer + * */ + template ::value>::type> + static std::vector MakeVector(const uint8_t *const data, size_t size) { + auto dest = std::vector(size / sizeof(T)); + if (data == nullptr) { + return dest; + } + + errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size); + if (EOK != ret) { + return std::vector(); + } + return dest; + } +}; +} // namespace transform +} // namespace mindspore + +#endif // TRANSFORM_UTIL_H_ diff --git a/mindspore/ccsrc/transform/graph_runner.cc b/mindspore/ccsrc/transform/graph_runner.cc deleted file mode 100644 index 52d0d8e17f..0000000000 --- a/mindspore/ccsrc/transform/graph_runner.cc +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * Limitations under the License. 
- */ - -#include "transform/graph_runner.h" -#include -#include -#include -#include "utils/log_adapter.h" -#include "utils/config_manager.h" -#include "sys/time.h" -#include "utils/callbacks.h" -#include "utils/utils.h" -#include "./common.h" -#ifdef ENABLE_GE -#include "utils/callbacks_ge.h" -#endif - -#ifdef NO_GE_CLIENT -namespace ge { -Session::Session(const std::map &options) { - if (options.empty()) { - MS_LOG(ERROR) << "session input options is empty"; - } - sessionId_ = 0; -} -Session::~Session() {} -} // namespace ge -#endif - -namespace mindspore { -namespace transform { -std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { - std::shared_ptr ret = std::make_shared(sess_options); - if (ret == nullptr) { - MS_LOG(ERROR) << "Create GE session failed"; - return nullptr; - } - MS_LOG(INFO) << "Create new GE session success"; - return ret; -} - -GraphRunner::GraphRunner(const GraphRunnerOptions &options) - : options_(options), graph_manager_(DfGraphManager::GetInstance()) { - if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { - MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; - } - - if (options.sess_ptr != nullptr) { - sess_ = options.sess_ptr; - } else { - sess_ = NewSession(options.options); - if (sess_ == nullptr) { - MS_LOG(EXCEPTION) << "GraphRunner initialize failed!!"; - return; - } - } - -#if (defined ENABLE_GE) - // register the callback function - if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ge::GRAPH_SUCCESS) { - MS_LOG(EXCEPTION) << "register callback failed!"; - return; - } - - if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ge::GRAPH_SUCCESS) { - MS_LOG(EXCEPTION) << "register summary callback failed!"; - return; - } -#endif - - std::vector wrappers = graph_manager_.GetAllGraphs(); - if (wrappers.empty()) { - MS_LOG(INFO) << "The GraphManager is empty!!"; - return; - } - -#ifdef ENABLE_GE - for (auto &it : wrappers) { - std::set saved_graph = graph_manager_.GetSavedGraphs(); - auto iter_find = saved_graph.find(std::to_string(it->id_)); - if (iter_find != saved_graph.end()) { - continue; - } - MS_LOG(INFO) << "Add the graph " << (*it).name_ << " to GE, it's id is: " << (*it).id_; - graph_manager_.AddSavedGraphs(std::to_string(it->id_)); - (void)sess_->AddGraph(it->id_, *(it->graph_ptr_), it->options_); - } -#endif -} - -Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, - std::vector *outputs) { - std::string name = options.name; - if (name.empty()) { - MS_LOG(ERROR) << "The graph name is null"; - return Status::INVALID_ARGUMENT; - } - - DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name); - if (wrap_ptr == nullptr) { - MS_LOG(ERROR) << "Get graph form DfGraphManager failed!"; - return Status::NOT_FOUND; - } - - if (wrap_ptr->graph_ptr_ == nullptr) { - MS_LOG(WARNING) << "The graph is null"; - return Status::NOT_FOUND; - } - - // call ge::RunGraph() to exec a graph; - std::vector ge_inputs; - std::vector ge_outputs; - - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), - [](const GeTensorPtr &i) { return *i; }); - - MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; - - struct timeval start_time, end_time; - (void)gettimeofday(&start_time, nullptr); - -#ifdef ENABLE_GE - if (sess_ == nullptr) { - MS_LOG(ERROR) << "The GE session is null, can't run the graph!"; - return Status::FAILED; - } - - // The information of 
some nodes could be changed after fusion in some cases - // Therefore a graph needs to be rebuilt in above situation - if (sess_->IsGraphNeedRebuild(wrap_ptr->id_)) { - sess_->RemoveGraph(wrap_ptr->id_); - sess_->AddGraph(wrap_ptr->id_, *(wrap_ptr->graph_ptr_), wrap_ptr->options_); - } - - ge::Status ret = sess_->RunGraph(wrap_ptr->id_, ge_inputs, ge_outputs); - if (ret != ge::GRAPH_SUCCESS) { - MS_LOG(ERROR) << "Call GE RunGraph Failed, ret is: " << ret; - return Status::FAILED; - } -#else - ge_outputs.swap(ge_inputs); -#endif - - (void)gettimeofday(&end_time, nullptr); - const uint64_t kUSecondInSecond = 1000000; - uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); - cost += static_cast(end_time.tv_usec - start_time.tv_usec); - MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << ge_outputs.size(); - - (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs), - [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); - - return Status::SUCCESS; -} - -Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, - std::vector *const outputs) { - std::vector ge_inputs; - for (auto it : inputs) { - MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); - auto shape = (*it).shape(); - std::string shape_str; - for (const auto &elem : shape) { - shape_str += std::to_string(elem); - shape_str += " "; - } - MS_LOG(INFO) << "inputs tensor's shape is: { " << shape_str << "}"; - - auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW); - if (ge_tensor_ptr != nullptr) { - ge_inputs.emplace_back(ge_tensor_ptr); - } else { - MS_LOG(INFO) << "Convert input Me tensor to Ge tensor failed. Abort this graph"; - return Status::FAILED; - } - } - - std::vector ge_outputs; - Status ret; - { - // Release GIL before calling into (potentially long-running) C++ code - py::gil_scoped_release release; - ret = RunGraph(options, ge_inputs, &ge_outputs); - } - if (ret != Status::SUCCESS) { - return ret; - } else { - // conver GeTensor to MeTensor - for (auto &it : ge_outputs) { - auto tensor = TransformUtil::ConvertGeTensor(it); - if (tensor != nullptr) { - outputs->emplace_back(tensor); - } - } - MS_LOG(INFO) << "Return Me tensor outputs num is: " << outputs->size(); - return Status::SUCCESS; - } -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_runner.h b/mindspore/ccsrc/transform/graph_runner.h deleted file mode 100644 index 30769c8310..0000000000 --- a/mindspore/ccsrc/transform/graph_runner.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef TRANSFORM_GRAPH_RUNNER_H_ -#define TRANSFORM_GRAPH_RUNNER_H_ - -#include -#include -#include -#include -#include - -#include "transform/types.h" -#include "transform/util.h" -#include "ir/tensor.h" -#include "transform/df_graph_manager.h" - -namespace mindspore { -namespace transform { -using SessionOptions = std::map; - -struct GraphRunnerOptions { - std::string target{"default_graph_runner"}; - SessionOptions options; - // if sess_ptr is nullptr, GraphRunner will create a new ge session - std::shared_ptr sess_ptr{nullptr}; -}; - -struct RunOptions { - // graph's name - std::string name; -}; - -class GraphRunner { - public: - explicit GraphRunner(const GraphRunnerOptions &options); - ~GraphRunner() { sess_ = nullptr; } - Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); - Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); - static std::shared_ptr NewSession(const SessionOptions &sess_options); - - private: - std::shared_ptr sess_; - transform::GraphRunnerOptions options_; - DfGraphManager &graph_manager_; -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_GRAPH_RUNNER_H_ diff --git a/mindspore/ccsrc/transform/onnx/CMakeLists.txt b/mindspore/ccsrc/transform/onnx/CMakeLists.txt new file mode 100644 index 0000000000..0d2f6c947b --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE _ONNX_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_ONNX_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ONNX) +add_library(_mindspore_transform_onnx_obj OBJECT ${_ONNX_SRC_FILES}) diff --git a/mindspore/ccsrc/transform/onnx/ir_exporter.cc b/mindspore/ccsrc/transform/onnx/ir_exporter.cc new file mode 100644 index 0000000000..78858eea8a --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/ir_exporter.cc @@ -0,0 +1,618 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ir/tensor.h" +#include "ir/param_value.h" +#include "debug/anf_ir_utils.h" +#include "frontend/operator/ops.h" +#include "proto/onnx.pb.h" + +namespace mindspore { +using FloatPtr = std::shared_ptr; +using IntPtr = std::shared_ptr; + +// anf type to onnx type map +static std::unordered_map g_data_type_map = { + {kNumberTypeBool, onnx::TensorProto_DataType_BOOL}, {kNumberTypeInt8, onnx::TensorProto_DataType_INT8}, + {kNumberTypeInt16, onnx::TensorProto_DataType_INT16}, {kNumberTypeInt32, onnx::TensorProto_DataType_INT32}, + {kNumberTypeInt64, onnx::TensorProto_DataType_INT64}, {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8}, + {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16}, {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32}, + {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64}, {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16}, + {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT}, {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE}, + {kObjectTypeString, onnx::TensorProto_DataType_STRING}, +}; + +static std::unordered_map g_data_bits_int_map = { + {8, onnx::TensorProto_DataType_INT8}, + {16, onnx::TensorProto_DataType_INT16}, + {32, onnx::TensorProto_DataType_INT32}, + {64, onnx::TensorProto_DataType_INT64}, +}; + +static std::unordered_map g_data_bits_float_map = { + {16, onnx::TensorProto_DataType_FLOAT16}, + {32, onnx::TensorProto_DataType_FLOAT}, +}; + +// Can build different builder according to format +class IrExportBuilder; +using IrExportBuilderPtr = std::shared_ptr; + +class IrExporter { + public: + explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {} + virtual ~IrExporter() = default; + std::string GetDumpString(const FuncGraphPtr &func_graph); + + private: + IrExportBuilderPtr builder_; +}; + +class IrExportBuilder { + public: + IrExportBuilder() = default; + ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); } + std::string GetProtoString(const FuncGraphPtr &func_graph); + void BuildModelInfo(); + void BuildModel(const FuncGraphPtr &func_graph); + + private: + void BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto); + void BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto); + void BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto); + std::string BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto); + + void SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto); + void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, onnx::ValueInfoProto *const value_proto); + void SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto); + void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, onnx::TensorProto *const tensor_proto); + void SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto); + void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, onnx::NodeProto *const node_proto, + std::string suffix = "0"); + void SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTypeToAttributeProto(const ValuePtr 
&value, onnx::AttributeProto *const attr_proto); + void SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto); + void SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto); + void SetSequenceToAttributeProto(const ValueSequeuePtr &value, onnx::AttributeProto *const attr_proto); + + onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); + onnx::TensorProto_DataType GetOnnxDataBitsIntType(int bits); + onnx::TensorProto_DataType GetOnnxDataBitsFloatType(int bits); + std::string GetNodeName(const AnfNodePtr &node); + std::string GetUniqueNodeName(const AnfNodePtr &node); + std::string GetOpTypeName(const AnfNodePtr &node); + size_t AllocateIndex() { return ++node_index_; } + void ResetIndex() { node_index_ = 0; } + + private: + onnx::ModelProto model_; + onnx::NodeProto *last_node_{nullptr}; + std::list todo_; + std::map node_index_map_; + size_t node_index_{0}; +}; + +using IrExporterPtr = std::shared_ptr; + +std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) { + if ((builder_ == nullptr) || (func_graph == nullptr)) { + MS_LOG(EXCEPTION) << "Input params is null."; + } + + // Export model info + builder_->BuildModelInfo(); + + // Export model and return string + builder_->BuildModel(func_graph); + + return builder_->GetProtoString(func_graph); +} + +std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) { + MS_LOG(DEBUG) << "BuildModel complete!"; + return model_.SerializeAsString(); +} + +void IrExportBuilder::BuildModelInfo() { + model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_producer_name("MindSpore"); + model_.set_model_version(1); +} + +void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) { + onnx::GraphProto *graph_proto = model_.mutable_graph(); + graph_proto->set_name(func_graph->ToString()); + ResetIndex(); + todo_.clear(); + todo_.push_back(func_graph); + while (!todo_.empty()) { + FuncGraphPtr fg = todo_.back(); + todo_.pop_back(); + BuildFuncGraph(fg, graph_proto); + } +} + +void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + // Export parameters + // 1. parameters should be mapped to ValueInfoProto + // 2. 
parameters with default value should be mapped to Initializer + BuildParameters(func_graph, graph_proto); + + // Export operator nodes(include output) + BuildNodes(func_graph, graph_proto); +} + +void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto &item : func_graph->parameters()) { + auto param = item->cast(); + if (param == nullptr) { + MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter."; + } + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); + std::string param_name = GetUniqueNodeName(param); + input_proto->set_name(param_name); + SetValueInfoProto(param, input_proto); + if (!param->has_default()) { + MS_LOG(DEBUG) << "Parameter: '" << item->ToString() << "' has no default"; + continue; + } + + // Using ONNX initializer to set parameter's default value + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); + initializer_proto->set_name(param_name); + SetParamToTensorProto(param, initializer_proto); + auto tensor = std::dynamic_pointer_cast(param->default_param()->value()); + if (tensor) { + initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); + } + } +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataType(TypeId type_id) { + auto iter = g_data_type_map.find(type_id); + if (iter == g_data_type_map.end()) { + MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsIntType(int bits) { + auto iter = g_data_bits_int_map.find(bits); + if (iter == g_data_bits_int_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits; + } + return iter->second; +} + +onnx::TensorProto_DataType IrExportBuilder::GetOnnxDataBitsFloatType(int bits) { + auto iter = g_data_bits_float_map.find(bits); + if (iter == g_data_bits_float_map.end()) { + MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! 
" << bits; + } + return iter->second; +} + +void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto) { + if (node == nullptr || value_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!"; + } + MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString(); + SetValueInfoProto(node->Type(), node->Shape(), value_proto); +} + +void IrExportBuilder::SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::ValueInfoProto *const value_proto) { + onnx::TypeProto *type_proto = value_proto->mutable_type(); + if (type->isa() && shape->isa()) { + auto tensor = type->cast(); + auto elem_type = tensor->element(); + const auto &dims = shape->cast()->shape(); + type_proto->mutable_tensor_type()->set_elem_type(GetOnnxDataType(elem_type->type_id())); + for (const auto &dim : dims) { + MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim; + type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); + } + } else if (type->isa()) { + auto tup_shape = shape->cast(); + type_proto->set_denotation(std::to_string(tup_shape->shape().size())); + } else { + MS_LOG(EXCEPTION) << "Value type: " << type->type_name() << " is not supported!"; + } +} + +void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("tensor"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = value->cast(); + tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::TensorProto *const tensor_proto) { + if (!type->isa() || !shape->isa()) { + MS_LOG(EXCEPTION) << "Type or shape is not supported! 
" << type->ToString(); + } + auto tensor = type->cast(); + const auto &dims = shape->cast()->shape(); + tensor_proto->set_data_type(GetOnnxDataType(tensor->element()->type_id())); + for (const auto &dim : dims) { + tensor_proto->add_dims(dim); + } +} + +void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { + if (param == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!"; + } + MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString(); + SetTensorProto(param->Type(), param->Shape(), tensor_proto); +} + +void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); + for (const AnfNodePtr &node : nodes) { + if (!node->isa()) { + MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode"; + continue; + } + auto cnode = node->cast(); + if (cnode == func_graph->get_return()) { + BuildOutput(cnode, graph_proto); + } else { + BuildCNode(cnode, graph_proto); + } + } +} + +void IrExportBuilder::BuildOutput(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + if (node->size() != 2) { + MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; + } + AnfNodePtr arg = node->input(1); + // Using make_tuple to set multi-output + if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple)) { + auto tuple_node = arg->cast(); + for (size_t i = 1; i < tuple_node->size(); i++) { + auto input_node = arg->cast()->input(i); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + auto output_name = GetUniqueNodeName(tuple_node->input(i)); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(tuple_node->input(i), output_proto); + } + } else { + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + std::string output_name = GetUniqueNodeName(node); + output_proto->set_name(output_name); + last_node_->add_output(output_name); + SetValueInfoProto(arg, output_proto); + } +} + +std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { + // May be ValueNode/CNode/Parameter + std::string type_name = ""; + if (IsValueNode(node)) { + PrimitivePtr prim = GetValueNode(node); + type_name = prim->ToString(); + } else if (IsValueNode(node)) { + FuncGraphPtr fg = GetValueNode(node); + todo_.push_back(fg); + type_name = fg->ToString(); + } else if (node->isa() || node->isa()) { + type_name = node->ToString(); + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name(); + } + MS_LOG(DEBUG) << "ExportType: " << type_name; + return type_name; +} + +void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, + onnx::NodeProto *const node_proto, std::string suffix) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_ref_attr_name("shape"); + if (suffix.compare("0") != 0) { + attr_proto->set_name("shape" + suffix); + } else { + attr_proto->set_name("shape"); + } + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetTensorProto(type, shape, tensor_proto); +} + +void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, onnx::NodeProto *const node_proto) { + // Get shape of cnode + // 1. prim ArgMaxWithValue need to get shape from tuple element + // 2. some cnode doesn't has shape, such as LayerNorm + // 3. 
other cnodes have shape + if (node->IsApply(prim::kPrimArgMaxWithValue) || node->IsApply(prim::kPrimLayerNorm)) { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa()) { + MS_LOG(EXCEPTION) << "Output data of ArgMaxWithValue cnode must be tuple: " << type->type_name(); + } + auto elements = type->cast()->elements(); + auto tuple_shape = shape->cast()->shape(); + for (size_t i = 0; i < elements.size(); i++) { + SetShapeToNodeProto(elements[i], tuple_shape[i], node_proto, std::to_string(i)); + } + } else { + auto type = node->Type(); + auto shape = node->Shape(); + if (!type->isa() || !shape->isa()) { + MS_LOG(DEBUG) << "Cnode has no shape: " << node->ToString(); + return; + } + SetShapeToNodeProto(type, shape, node_proto); + } +} + +void IrExportBuilder::BuildCNode(const CNodePtr &node, onnx::GraphProto *const graph_proto) { + auto inputs_size = node->size(); + if (inputs_size < 1) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + // Need to build input node before dealing with cnode + std::vector op_inputs; + std::vector input_names; + for (size_t i = 1; i < inputs_size; i++) { + auto input = node->input(i); + op_inputs.push_back(input); + input_names.push_back(BuildInputNode(input, graph_proto)); + } + + // Build cnode + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string output_name = GetUniqueNodeName(node); + node_proto->add_output(output_name); + node_proto->set_name(output_name); + node_proto->set_domain(node->fullname_with_scope()); + AnfNodePtr op = node->input(0); + std::string type_name = GetOpTypeName(op); + node_proto->set_op_type(type_name); + last_node_ = node_proto; + SetShapeToNodeProto(node, node_proto); + (void)std::for_each(input_names.begin(), input_names.end(), + [&node_proto](const string &name) { node_proto->add_input(name); }); + + // Add primitive attrs + if (IsValueNode(op)) { + auto prim = GetValueNode(op); + for (auto attr : prim->attrs()) { + MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name(attr.first); + SetValueToAttributeProto(attr.second, attr_proto); + } + } else { + MS_LOG(EXCEPTION) << "Need to support op type: " << op->type_name(); + } +} + +std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, onnx::GraphProto *const graph_proto) { + std::string node_name = GetUniqueNodeName(node); + if (node->isa()) { + // When node input is a ValueNode, need to create a Constant Node + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(node_name); + SetAttributeProto(node, node_proto); + } + return node_name; +} + +std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) { + // Naming anfnode + // 1. parameter is unique in one func_graph + // 2. cnode and valuenode may be reduplicative, so add index to identify. 
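+  // e.g. a Parameter "w" in func_graph "g" is named "g:w", while a cnode or valuenode gets an
+  // index suffix such as "g:op:3" (the exact strings depend on the nodes' ToString() output).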
+ std::string node_name = ""; + if (node->isa()) { + node_name = GetNodeName(node); + } else if (node->isa() || node->isa()) { + auto iter = node_index_map_.find(node); + if (iter != node_index_map_.end()) { + node_name = GetNodeName(node) + ":" + std::to_string(iter->second); + } else { + auto node_idx = AllocateIndex(); + node_index_map_[node] = node_idx; + node_name = GetNodeName(node) + ":" + std::to_string(node_idx); + } + } else { + MS_LOG(EXCEPTION) << "Can not support type of node:" << node->ToString(); + } + MS_LOG(DEBUG) << "Node name: " << node_name; + return node_name; +} + +std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) { + std::string node_name = ""; + if ((node != nullptr) && (node->func_graph() != nullptr)) { + node_name = node->func_graph()->ToString() + ":"; + } + node_name += node->ToString(); + MS_LOG(DEBUG) << "GetNodeName: " << node_name; + return node_name; +} + +void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, onnx::NodeProto *const node_proto) { + if (node == nullptr || node_proto == nullptr) { + MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!"; + } + auto value = node->cast()->value(); + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString(); + SetValueToAttributeProto(value, attr_proto); +} + +void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("type"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + auto int_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (value->isa()) { + auto float_value = value->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else if (value->isa()) { + tensor_proto->set_name("tensor"); + auto elem_type = value->cast()->element(); + if (elem_type->isa()) { + auto int_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsIntType(int_value->nbits())); + } else if (elem_type->isa()) { + auto float_value = elem_type->cast(); + tensor_proto->set_data_type(GetOnnxDataBitsFloatType(float_value->nbits())); + } else { + MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name(); + } + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; + } + if (value->isa() || value->isa()) { + SetScalarToAttributeProto(value, attr_proto); + } else if (value->isa() || value->isa()) { + SetTypeToAttributeProto(value, attr_proto); + } else if (value->isa()) { + SetSequenceToAttributeProto(value->cast(), attr_proto); + } else if (value->isa()) { + SetTensorToAttributeProto(value, attr_proto); + } else { + MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); + } +} + +void IrExportBuilder::SetScalarToAttributeProto(const ValuePtr &value, onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!"; 
+ } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + SetScalarToProto(value, tensor_proto); +} + +void IrExportBuilder::SetScalarToProto(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { + if (value == nullptr || tensor_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValuePtr or TensorProto is null!"; + } + if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_STRING); + tensor_proto->add_string_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_BOOL); + tensor_proto->add_int32_data(GetValue(value)); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT8); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT16); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT32); + tensor_proto->add_int32_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(value->cast()->value()); + } else if (value->isa()) { + tensor_proto->set_data_type(onnx::TensorProto_DataType_FLOAT); + tensor_proto->add_float_data(GetValue(value)); + } else { + MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name(); + } +} + +void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value, + onnx::AttributeProto *const attr_proto) { + if (value == nullptr || attr_proto == nullptr) { + MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!"; + } + attr_proto->set_ref_attr_name("scalar"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + if (value->isa()) { + const ValueTuplePtr &tuple_value = value->cast(); + if (tuple_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0"; + return; + } + auto type_id = tuple_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : tuple_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } else if (value->isa()) { + const ValueListPtr &list_value = value->cast(); + if (list_value->value().size() == 0) { + MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0"; + return; + } + auto type_id = list_value->value()[0]->type()->type_id(); + tensor_proto->set_data_type(GetOnnxDataType(type_id)); + for (const auto &item : list_value->value()) { + SetScalarToProto(item, tensor_proto); + } + } +} + +std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { + auto builder = std::make_shared(); + if (builder == nullptr) { + MS_LOG(ERROR) << "Create ir exporter failed!"; + return ""; + } + auto exporter = std::make_shared(builder); + if (exporter == nullptr) { + return ""; + } + return exporter->GetDumpString(func_graph); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/onnx/onnx_exporter.cc b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc new file mode 100644 index 0000000000..f69fb81a7e --- /dev/null +++ b/mindspore/ccsrc/transform/onnx/onnx_exporter.cc @@ -0,0 +1,1207 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "debug/anf_ir_utils.h" +#include "proto/onnx.pb.h" +#include "frontend/operator/ops.h" +#include "ir/tensor.h" +#include "ir/param_value.h" + +namespace mindspore { +enum OpMergeMode { + OP_MERGE_UNDEFINED = 0, // undefined behavior + OP_MERGE_IGNORE = 1, // indicate an input op merged into other op in compute node list + OP_MERGE_CONV = 2, // indicate `MindSpore Conv + BiasAdd` --> `ONNX Conv` + OP_MERGE_GEMM = 3, // indicate `MindSpore MatMul + BiasAdd` --> `ONNX Gemm` + OP_MERGE_BATCH_NORM = 4, // indicate `MindSpore BatchNorm(x)[0]` --> `ONNX BatchNormalization` + OP_MERGE_MAXPOOL_WITH_ARGMAX = 5, // indicate `MindSpore MaxPoolWithArgmax(x)[0]` --> `ONNX MaxPool` +}; + +struct OpMergedInfo { + OpMergeMode mode = OP_MERGE_UNDEFINED; + int referred_count = 0; +}; + +using GenAttrFuncType = + std::function; + +template +void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + auto casted_value = dyn_cast(value); + if (casted_value == nullptr) { + MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; + } + auto attr_value = casted_value->value(); + switch (attr_type) { + case onnx::AttributeProto_AttributeType_INT: + attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); + break; + case onnx::AttributeProto_AttributeType_FLOAT: + attr_proto->set_f(static_cast(attr_value)); + break; + case onnx::AttributeProto_AttributeType_INTS: + for (size_t i = 0; i < rep_cnt; ++i) { + attr_proto->add_ints(static_cast<::google::protobuf::int64>(attr_value)); + } + break; + case onnx::AttributeProto_AttributeType_FLOATS: + for (size_t i = 0; i < rep_cnt; ++i) { + attr_proto->add_floats(static_cast(attr_value)); + } + break; + default: + MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; + } + attr_proto->set_type(attr_type); +} + +template +void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + auto tuple_ptr = dyn_cast(value); + if (tuple_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; + } + switch (attr_type) { + case onnx::AttributeProto_AttributeType_INTS: + for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + break; + case onnx::AttributeProto_AttributeType_FLOATS: + for (size_t i = beg_idx; i < tuple_ptr->size(); ++i) { + attr_proto->add_floats(GetValue((*tuple_ptr)[i])); + } + break; + default: + MS_LOG(EXCEPTION) << "Convert attribute fail, unexpected ONNX type " << attr_type; + } + attr_proto->set_type(attr_type); +} + +void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + 
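+  // NB: only "VALID" is mapped explicitly; every other pad mode value falls through to
+  // ONNX auto_pad = SAME_UPPER.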
attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + auto attr_value = GetValue(value); + if (attr_value == "VALID") { + attr_proto->set_s("VALID"); + } else { + attr_proto->set_s("SAME_UPPER"); + } +} + +class OpAttrInfo { + public: + OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) + : attr_name_(attr_name), + onnx_attr_name_(onnx_attr_name), + onnx_attr_type_(onnx_attr_type), + fn_gen_attr_(fn_gen_attr) {} + ~OpAttrInfo() {} + + const std::string &attr_name() const { return attr_name_; } + const std::string &onnx_attr_name() const { return onnx_attr_name_; } + onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } + GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } + + private: + std::string attr_name_; // attribute name of MindSpore + std::string onnx_attr_name_; // corresponding attribute name of ONNX + onnx::AttributeProto_AttributeType onnx_attr_type_; // corresponding attribute type of ONNX + GenAttrFuncType fn_gen_attr_; // function used convert +}; + +class OpNameInfo { + public: + OpNameInfo &set_op_type(const std::string &op_type) { + op_type_ = op_type; + return *this; + } + + const std::string &op_type() const { return op_type_; } + + OpNameInfo &set_onnx_type(const std::string &onnx_type) { + onnx_type_ = onnx_type; + return *this; + } + + const std::string &onnx_type() const { return onnx_type_; } + + OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { + op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); + return *this; + } + + const std::vector &op_attrs() const { return op_attrs_; } + + private: + std::string op_type_; // operator type of MindSpore + std::string onnx_type_; // corresponding ONNX operator type + std::vector op_attrs_; // operator attributes map info +}; + +#define OPERATOR_ONNX_CONVERT_DEFINE(name, onnx_name, impl) \ + OpNameInfo GetOpOnnxConvertInfo_##name() { return impl.set_op_type(#name).set_onnx_type(#onnx_name); } + +OPERATOR_ONNX_CONVERT_DEFINE(TensorAdd, Add, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Mul, Mul, OpNameInfo()) + +OPERATOR_ONNX_CONVERT_DEFINE(ReLU, Relu, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Sigmoid, Sigmoid, OpNameInfo()) + +OPERATOR_ONNX_CONVERT_DEFINE(Flatten, Flatten, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Squeeze, Squeeze, + OpNameInfo().Attr("axis", "axes", onnx::AttributeProto_AttributeType_INTS, + SetAttrTupleValueToProto<0>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + Conv2D, Conv, + OpNameInfo() + .Attr("dilation", "dilations", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) + .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) + .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, + [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, + const PrimitivePtr &prim) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + auto attr_value = GetValue(value); + if (attr_value == "valid") { + attr_proto->set_s("VALID"); + } else if (attr_value == "same") { + attr_proto->set_s("SAME_UPPER"); + } else { // pad_mode is 'pad', use attribute 'pad_list' to fill ONNX attribute 
'pads' + attr_proto->set_name("pads"); + SetAttrTupleValueToProto(prim->GetAttr("pad_list"), onnx::AttributeProto_AttributeType_INTS, attr_proto, + prim); + } + }) + .Attr("stride", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) +OPERATOR_ONNX_CONVERT_DEFINE(BiasAdd, Add, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(MatMul, Gemm, + OpNameInfo() + .Attr("transpose_a", "transA", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto) + .Attr("transpose_b", "transB", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto)) + +OPERATOR_ONNX_CONVERT_DEFINE(BatchNorm, BatchNormalization, + OpNameInfo().Attr("epsilon", "epsilon", onnx::AttributeProto_AttributeType_FLOAT, + SetAttrValueToProto)) + +OPERATOR_ONNX_CONVERT_DEFINE(Reshape, Reshape, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ReduceMean, ReduceMean, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Cast, Cast, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(PReLU, PRelu, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, + OpNameInfo() + .Attr("axis", "axis", onnx::AttributeProto_AttributeType_INT, + SetAttrValueToProto) + .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, + [](ValuePtr, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(0); + })) + +OPERATOR_ONNX_CONVERT_DEFINE(SimpleMean, AveragePool, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE( + MaxPool, MaxPool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + MaxPoolWithArgmax, MaxPool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE( + AvgPool, AveragePool, + OpNameInfo() + .Attr("ksize", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>) + .Attr("padding", "auto_pad", onnx::AttributeProto_AttributeType_STRING, SetPoolingPadMode) + .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) + +OPERATOR_ONNX_CONVERT_DEFINE(GatherV2, Gather, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(Sub, Sub, OpNameInfo()) + +#define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name + +void RegisterOpConverters(const std::function &fn) { + fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); + fn(OP_CONVERT_FUNCTION_NAME(Mul)()); + + fn(OP_CONVERT_FUNCTION_NAME(ReLU)()); + fn(OP_CONVERT_FUNCTION_NAME(Sigmoid)()); + + fn(OP_CONVERT_FUNCTION_NAME(Conv2D)()); + fn(OP_CONVERT_FUNCTION_NAME(Argmax)()); + + fn(OP_CONVERT_FUNCTION_NAME(Flatten)()); + fn(OP_CONVERT_FUNCTION_NAME(MaxPool)()); + fn(OP_CONVERT_FUNCTION_NAME(MaxPoolWithArgmax)()); + fn(OP_CONVERT_FUNCTION_NAME(AvgPool)()); + + 
fn(OP_CONVERT_FUNCTION_NAME(Squeeze)()); + fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); + fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); + + fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); + fn(OP_CONVERT_FUNCTION_NAME(Concat)()); + fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); + fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); + fn(OP_CONVERT_FUNCTION_NAME(Sub)()); +} + +class OpConvertRegistry { + public: + ~OpConvertRegistry() { Clear(); } + + static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } + + static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } + + static OpConvertRegistry &GetSingleton() { + static OpConvertRegistry registry = OpConvertRegistry(); + return registry; + } + + static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } + + void Clear() noexcept { op_map_.clear(); } + + private: + OpConvertRegistry() {} + + std::unordered_map op_map_; +}; + +class OnnxExporter { + public: + OnnxExporter() {} + ~OnnxExporter() {} + + std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); + + private: + void InitModelInfo(); + + void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + + size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *graph_proto); + + static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); + void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); + void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); + + void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr); + void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduce(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeGemm(const FuncGraphPtr &func_graph, 
const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto); + + void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); + + size_t AllocateNodeIndex() { return ++onnx_node_index_; } + + void ResetNodeIndex() { onnx_node_index_ = 0; } + + static int GetInt32Value(const AnfNodePtr &node) { + auto value_node_ptr = dyn_cast(node); + MS_EXCEPTION_IF_NULL(value_node_ptr); + return GetValue(value_node_ptr->value()); + } + + onnx::ModelProto model_; + + size_t onnx_node_index_ = 0; +}; + +std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + return ""; + } + ResetNodeIndex(); + OpConvertRegistry::GetSingleton().Clear(); + OpConvertRegistry::RegisterAllOpConverters(); + InitModelInfo(); + onnx::GraphProto *graph_proto = model_.mutable_graph(); + ExportFuncGraph(func_graph, graph_proto); + return model_.SerializeAsString(); +} + +void OnnxExporter::InitModelInfo() { + model_.set_ir_version(onnx::IR_VERSION_2019_1_22); + model_.set_producer_name("MindSpore"); + model_.set_producer_version("1.0"); + onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); + opset_proto->set_version(9); +} + +void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + std::map node_map; + + MS_LOG(INFO) << "Begin exporting onnx model for graph " << func_graph->ToString(); + + onnx_node_index_ = func_graph->parameters().size(); + + // set graph name + graph_proto->set_name(func_graph->ToString()); + + // export parameters + // 1. all parameters (with or without default value) will be mapped to ONNX parameters + // 2. 
parameters with default value will be mapped to ONNX initializers
+  ExportParameters(func_graph, graph_proto);
+
+  // export computational nodes and output nodes
+  ExportNodes(func_graph, &node_map, graph_proto);
+
+  MS_LOG(INFO) << "End exporting onnx model for graph " << func_graph->ToString();
+}
+
+void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) {
+  for (auto &param : func_graph->parameters()) {
+    const ParameterPtr param_ptr = dyn_cast<Parameter>(param);
+    if (param_ptr == nullptr) {
+      MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not be cast to Parameter.";
+    }
+
+    onnx::ValueInfoProto *input_proto = graph_proto->add_input();
+    input_proto->set_name(param_ptr->ToString());
+    SetValueInfoType(param_ptr, input_proto);
+
+    if (!param_ptr->has_default()) {
+      continue;
+    }
+    // a parameter with a default value becomes an ONNX initializer
+    onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
+    initializer_proto->set_name(param_ptr->ToString());
+    SetTensorProtoInfo(param_ptr, initializer_proto);
+    // set the value of the initializer
+    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param()->value());
+    if (tensor) {
+      initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
+    }
+  }
+}
+
+onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) {
+  // clang-format off
+  static std::unordered_map<TypeId, onnx::TensorProto_DataType> type_map = {
+    {kNumberTypeBool, onnx::TensorProto_DataType_BOOL},
+    {kNumberTypeInt8, onnx::TensorProto_DataType_INT8},
+    {kNumberTypeInt16, onnx::TensorProto_DataType_INT16},
+    {kNumberTypeInt32, onnx::TensorProto_DataType_INT32},
+    {kNumberTypeInt64, onnx::TensorProto_DataType_INT64},
+    {kNumberTypeUInt8, onnx::TensorProto_DataType_UINT8},
+    {kNumberTypeUInt16, onnx::TensorProto_DataType_UINT16},
+    {kNumberTypeUInt32, onnx::TensorProto_DataType_UINT32},
+    {kNumberTypeUInt64, onnx::TensorProto_DataType_UINT64},
+    {kNumberTypeFloat16, onnx::TensorProto_DataType_FLOAT16},
+    {kNumberTypeFloat32, onnx::TensorProto_DataType_FLOAT},
+    {kNumberTypeFloat64, onnx::TensorProto_DataType_DOUBLE},
+  };
+  // clang-format on
+
+  auto iter = type_map.find(type_id);
+  if (iter == type_map.end()) {
+    MS_LOG(EXCEPTION) << "Convert type error, unsupported type " << type_id;
+  }
+
+  return iter->second;
+}
+
+void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) {
+  auto dtype = node->Type();
+  auto shape = node->Shape();
+  onnx::TypeProto *type_proto = value_proto->mutable_type();
+  if (dtype->isa<TensorType>() && shape->isa<abstract::Shape>()) {
+    auto tensor = dyn_cast<TensorType>(dtype);
+    auto elem_type = tensor->element();
+    const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
+    // the output type of MindSpore 'Argmax' is int32, while ONNX 'ArgMax' outputs int64
+    auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id());
+    type_proto->mutable_tensor_type()->set_elem_type(type);
+
+    for (const auto &dim : dims) {
+      type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim);
+    }
+  }
+}
+
+void OnnxExporter::SetTensorProtoInfo(const ParameterPtr &param, onnx::TensorProto *const tensor_proto) {
+  auto dtype = param->Type();
+  auto shape = param->Shape();
+  if (!dtype->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
+    MS_LOG(EXCEPTION) << "Parameter " << param->name() << " is not a regular tensor, with value " << param->ToString();
+  }
+
+  auto tensor = dyn_cast<TensorType>(dtype);
+  auto elem_type = tensor->element();
+  const auto &dims = dyn_cast<abstract::Shape>(shape)->shape();
+  tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id()));
+  for (const auto &dim : dims) {
+    tensor_proto->add_dims(dim);
+  }
+}
+
+void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes,
+                                std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) {
+  std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr;
+
+  for (auto &node : nodes) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == func_graph->get_return()) {
+      // if the key `input` does not exist, just create a new one
+      op_merged_infos[cnode].referred_count += 1;
+    }
+    for (auto &input : cnode->inputs()) {
+      if (!input->isa<CNode>()) {
+        continue;
+      }
+      // if the key `input` does not exist, just create a new one
+      op_merged_infos[input].referred_count += 1;
+    }
+    // MindSpore Conv + BiasAdd --> ONNX Conv
+    if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) &&
+        IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) {
+      op_merged_infos[cnode].mode = OP_MERGE_CONV;
+      op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
+      op_merged_infos[cnode->input(1)].referred_count -= 1;
+    } else if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) &&
+               IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) {
+      op_merged_infos[cnode].mode = OP_MERGE_GEMM;
+      op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
+      op_merged_infos[cnode->input(1)].referred_count -= 1;
+    } else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
+               IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("BatchNorm")) &&
+               GetInt32Value(cnode->input(2)) == 0) {
+      op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM;
+      op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
+      op_merged_infos[cnode->input(1)].referred_count -= 1;
+    } else if (cnode->IsApply(prim::kPrimTupleGetItem) &&
+               IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) &&
+               GetInt32Value(cnode->input(2)) == 0) {
+      op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX;
+      op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE;
+      op_merged_infos[cnode->input(1)].referred_count -= 1;
+    }
+  }
+}
+
+/**
+ * AnfNode
+ * +-- CNode
+ * +-- ANode
+ * |   +-- Parameter
+ * |   `-- ValueNode
+ */
+void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr,
+                               onnx::GraphProto *const graph_proto) {
+  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
+
+  std::unordered_map<AnfNodePtr, OpMergedInfo> op_merged_infos;
+  MatchAndMark(func_graph, nodes, &op_merged_infos);
+
+  for (const AnfNodePtr &node : nodes) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    auto iter = op_merged_infos.find(cnode);
+    // the node is not referenced by any other nodes, skip it
+    if (iter == op_merged_infos.end()) {
+      continue;
+    }
+    auto merged_info = iter->second;
+    // the op node is merged with other node and not used any
more, skip it + if (merged_info.mode == OP_MERGE_IGNORE && merged_info.referred_count == 0) { + continue; + } + if (cnode == func_graph->get_return()) { + ExportOutput(func_graph, cnode, node_map_ptr, graph_proto); + continue; + } + switch (merged_info.mode) { + case OP_MERGE_CONV: + ExportMergeConv(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_GEMM: + ExportMergeGemm(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_BATCH_NORM: + ExportMergeBatchNorm(func_graph, cnode, node_map_ptr, graph_proto); + break; + case OP_MERGE_MAXPOOL_WITH_ARGMAX: + ExportMergeMaxPoolWithArgmax(func_graph, cnode, node_map_ptr, graph_proto); + break; + default: + ExportCNode(func_graph, cnode, node_map_ptr, graph_proto); + break; + } + } +} + +void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_shape = node->input(2); + std::string name_shape; + if (input_shape->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[input_shape] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_shape = std::to_string(const_node_idx); + node_proto->add_output(name_shape); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(input_shape)->value(), attr_proto->mutable_t()); + } else { + name_shape = GetNodeInputName(input_shape, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Reshape."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimReshape->name()); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_shape); +} + +void OnnxExporter::ExportPrimReduce(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_axis = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + auto name = prim::kPrimReduceMean->name(); + if (node->IsApply(prim::kPrimReduceSum)) { + name = prim::kPrimReduceSum->name(); + } + node_proto->set_op_type(name); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_axis->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("axes"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + auto axis_value = dyn_cast(input_axis)->value(); + auto int_ptr = dyn_cast(axis_value); + if (int_ptr == nullptr) { + auto tuple_ptr = dyn_cast(axis_value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + attr_proto->add_ints(GetValue((*tuple_ptr)[i])); + } + } else { + attr_proto->add_ints(int_ptr->value()); + } + } else { + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to attributes for " << name; + } +} + +void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr 
&node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_type = node->input(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimCast->name()); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_data); + + if (input_type->isa()) { + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("to"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + auto type_value = dyn_cast(input_type)->value(); + auto type_ptr = dyn_cast(type_value); + MS_EXCEPTION_IF_NULL(type_ptr); + attr_proto->set_i(GetOnnxDataType(type_ptr->type_id())); + } else { + MS_LOG(EXCEPTION) << "Need to convert MindSpore Cast input(1) to ONNX Cast to attribute."; + } +} + +void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + + auto x_shape = dyn_cast(node->input(1)->Shape()); + auto slope_shape = dyn_cast(node->input(2)->Shape()); + MS_EXCEPTION_IF_NULL(x_shape); + MS_EXCEPTION_IF_NULL(slope_shape); + + // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] + if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { + auto node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Unsqueeze"); + node_proto->add_output(std::to_string(node_idx)); + + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); + attr_proto->set_name("axes"); + attr_proto->add_ints(1); + attr_proto->add_ints(2); + + node_proto->add_input(input_slope); + input_slope = std::to_string(node_idx); + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("PRelu"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_x); + node_proto->add_input(input_slope); +} + +void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Clip"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(input_x); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("min"); + attr_proto->set_f(0.f); + attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto->set_name("max"); + attr_proto->set_f(6.f); +} + +void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto input_w = 
GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto x_shape = dyn_cast(node->input(1)->Shape()); + auto w_shape = dyn_cast(node->input(2)->Shape()); + MS_EXCEPTION_IF_NULL(x_shape); + MS_EXCEPTION_IF_NULL(w_shape); + if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; + } + if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { + MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; + } + // create w_shape constant node + auto node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto = graph_proto->add_node(); + std::string name_w_shape = std::to_string(node_idx); + node_proto->add_output(name_w_shape); + node_proto->set_op_type("Constant"); + // create Value Tensor + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + // reshape + tensor_proto->add_int64_data(w_shape->shape()[1]); + tensor_proto->add_int64_data(w_shape->shape()[0]); + tensor_proto->add_int64_data(w_shape->shape()[2]); + tensor_proto->add_int64_data(w_shape->shape()[3]); + + // add reshape node + node_idx = AllocateNodeIndex(); + node_proto = graph_proto->add_node(); + node_proto->set_op_type(prim::kPrimReshape->name()); + node_proto->add_input(input_w); + node_proto->add_input(name_w_shape); + input_w = std::to_string(node_idx); + node_proto->add_output(input_w); + + // add conv node + node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + node_proto = graph_proto->add_node(); + node_proto->set_op_type("Conv"); + node_proto->add_input(input_x); + node_proto->add_input(input_w); + node_proto->add_output(std::to_string(node_idx)); + // set attributes + AnfNodePtr op = node->input(0); + auto op_value = dyn_cast(op); + auto prim = dyn_cast(op_value->value()); + // set dilations + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("dilations"); + SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + // set group + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("group"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + onnx_attr_proto->set_i(x_shape->shape()[1]); + // set kernel_shape + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("kernel_shape"); + SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, + prim); + + // set pad + onnx_attr_proto = node_proto->add_attribute(); + auto attr_value = GetValue(prim->GetAttr("pad_mode")); + onnx_attr_proto->set_name("auto_pad"); + onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); + if (attr_value == "valid") { + onnx_attr_proto->set_s("VALID"); + } else if (attr_value == "same") { + onnx_attr_proto->set_s("SAME_UPPER"); + } else { + onnx_attr_proto->set_name("pads"); + SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); + } + // set strides + onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name("strides"); + 
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); +} + +void OnnxExporter::ExportPrimTile(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto multiples = node->input(2); + std::string name_multiples; + if (multiples->isa()) { + auto const_node_idx = AllocateNodeIndex(); + (*node_map_ptr)[multiples] = const_node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + name_multiples = std::to_string(const_node_idx); + node_proto->add_output(name_multiples); + + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("repeat"); + + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + ConvertTupleToTensor(dyn_cast(multiples)->value(), attr_proto->mutable_t()); + } else { + name_multiples = GetNodeInputName(multiples, node_map_ptr, graph_proto); + MS_LOG(EXCEPTION) << "Need to insert op convert variable from tuple to tensor for Tile."; + } + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Tile"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_multiples); +} + +void OnnxExporter::ExportPrimSquare(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + std::string name_exponent; + auto const_node_idx = AllocateNodeIndex(); + onnx::NodeProto *node_proto_exp = graph_proto->add_node(); + name_exponent = std::to_string(const_node_idx); + node_proto_exp->add_output(name_exponent); + + node_proto_exp->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto_exp->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + tensor_proto->set_name("exponent"); + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(1)); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + tensor_proto->add_int64_data(2); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Pow"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_exponent); +} + +void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); + auto name_indices = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); + auto axis = node->input(3)->cast()->value(); + + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->set_op_type("Gather"); + node_proto->add_output(std::to_string(node_idx)); + node_proto->add_input(name_x); + node_proto->add_input(name_indices); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto->set_i(static_cast<::google::protobuf::int64>(dyn_cast(axis)->value())); +} + 
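+// Note for readers: the Export* helpers above all follow one lowering pattern -- resolve the
+// input names (emitting Constant nodes for value inputs), allocate an output index, record it
+// in node_map, then append the ONNX node that consumes the inputs. A hypothetical one-to-one
+// lowering, e.g. a MindSpore Neg(x) --> ONNX Neg(x) exporter (sketched here only to show the
+// pattern; it is not registered anywhere), would reduce to:
+//
+//   void OnnxExporter::ExportPrimNeg(const FuncGraphPtr &, const CNodePtr &node,
+//                                    std::map<AnfNodePtr, size_t> *node_map_ptr,
+//                                    onnx::GraphProto *const graph_proto) {
+//     auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto);
+//     auto node_idx = AllocateNodeIndex();
+//     (*node_map_ptr)[node] = node_idx;
+//     onnx::NodeProto *node_proto = graph_proto->add_node();
+//     node_proto->set_op_type("Neg");
+//     node_proto->add_output(std::to_string(node_idx));
+//     node_proto->add_input(name_x);
+//   }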
+void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node,
+                               std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) {
+  // The 2nd input of MindSpore 'Reshape' is a tuple, while ONNX expects a tensor, so a conversion is needed
+  if (node->IsApply(prim::kPrimReshape)) {
+    return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) {
+    return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore Cast(x, T) --> ONNX Cast[to=T](x)
+  if (node->IsApply(prim::kPrimCast)) {
+    return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // ONNX PRelu requires unidirectional broadcasting, so the slope input may need extra processing
+  if (node->IsApply(std::make_shared<Primitive>("PReLU"))) {
+    return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x)
+  if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) {
+    return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w))
+  if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) {
+    return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore Tile(x) --> ONNX Tile(x, repeat)
+  if (node->IsApply(prim::kPrimTile)) {
+    return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore Square(x) --> ONNX Pow(x, 2)
+  if (node->IsApply(prim::kPrimSquare)) {
+    return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  // MindSpore GatherV2(x, indices, axis) --> ONNX Gather[axis](x, indices)
+  if (node->IsApply(prim::kPrimGatherV2)) {
+    return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto);
+  }
+
+  auto inputs = node->inputs();
+  if (inputs.size() < 1) {
+    MS_LOG(EXCEPTION) << "Inputs of apply node are empty";
+  }
+
+  AnfNodePtr op = inputs[0];
+  std::vector<AnfNodePtr> op_inputs;
+  // process inputs 1, 2, ... first, since a ValueNode input requires creating a Constant operator
+  for (size_t i = 1; i < inputs.size(); i++) {
+    op_inputs.push_back(inputs[i]);
+  }
+  auto op_value = dyn_cast<ValueNode>(op);
+  if (op_value == nullptr) {
+    MS_LOG(EXCEPTION) << "Unsupported node op type " << op->type_name();
+  }
+  auto prim = dyn_cast<Primitive>(op_value->value());
+  if (prim == nullptr) {
+    MS_LOG(EXCEPTION) << "Unsupported node op type " << op_value->value()->type_name();
+  }
+
+  (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto);
+}
+
+size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map<AnfNodePtr, size_t> *node_map_ptr,
+                                     const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs,
+                                     onnx::GraphProto *const graph_proto) {
+  auto op_map = OpConvertRegistry::GetOpConvertMap();
+  auto op_iter = op_map.find(prim->name());
+  if (op_iter == op_map.end()) {
+    MS_LOG(EXCEPTION) << "Cannot find key " << prim->name() << " in convert map";
+  }
+  const OpNameInfo &op_convert_info = op_iter->second;
+
+  auto node_idx = AllocateNodeIndex();
+
+  onnx::NodeProto *node_proto = graph_proto->add_node();
+  node_proto->add_output(std::to_string(node_idx));
+  node_proto->set_op_type(op_convert_info.onnx_type());
+
+  // Set inputs
+  for (const auto &input : inputs) {
+    auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto);
+    node_proto->add_input(input_name);
+  }
+
+  // Set node attributes
+  for (const OpAttrInfo &attr : op_convert_info.op_attrs()) {
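+    // Registered attributes with a non-empty MindSpore name are looked up on the primitive;
+    // an empty name (e.g. the synthesized "keepdims" of Argmax above) lets the converter
+    // function produce the ONNX value on its own.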
const std::string &attr_name = attr.attr_name(); + ValuePtr attr_value = nullptr; + if (!attr_name.empty()) { + attr_value = prim->GetAttr(attr_name); + if (attr_value == nullptr) { + MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; + } + } + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); + onnx_attr_proto->set_name(attr.onnx_attr_name()); + attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); + } + return node_idx; +} + +void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto conv_node = dyn_cast(node->input(1)); + auto input_x = conv_node->input(1); // conv input x + auto input_w = conv_node->input(2); // conv weight(filter) + auto input_b = node->input(2); // conv bias + + PrimitivePtr prim_conv = dyn_cast((dyn_cast(conv_node->input(0)))->value()); + std::vector inputs{input_x, input_w, input_b}; + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto matmul_node = dyn_cast(node->input(1)); + auto input_x = matmul_node->input(1); // matmul input x + auto input_y = matmul_node->input(2); // matmul input y + auto input_b = node->input(2); // matmul bias + + PrimitivePtr prim_matmul = dyn_cast((dyn_cast(matmul_node->input(0)))->value()); + std::vector inputs{input_x, input_y, input_b}; + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto batch_norm_node = dyn_cast(node->input(1)); + + PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); + std::vector inputs; + for (size_t i = 1; i < batch_norm_node->inputs().size(); i++) { + inputs.push_back(batch_norm_node->input(i)); + } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); +} + +void OnnxExporter::ExportMergeMaxPoolWithArgmax(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { + auto maxpool_with_argmax_node = dyn_cast(node->input(1)); + + PrimitivePtr prim_maxpool_with_argmax = + dyn_cast((dyn_cast(maxpool_with_argmax_node->input(0)))->value()); + std::vector inputs; + for (size_t i = 1; i < maxpool_with_argmax_node->inputs().size(); i++) { + inputs.push_back(maxpool_with_argmax_node->input(i)); + } + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_maxpool_with_argmax, inputs, graph_proto); +} + +void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + if (node->inputs().size() != 2) { + MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; + } + AnfNodePtr arg = node->input(1); + std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); + output_proto->set_name(name); + SetValueInfoType(arg, output_proto, false); +} + +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const 
graph_proto) { + if (node->isa()) { + auto iter = node_map_ptr->find(node); + if (iter == node_map_ptr->end()) { + MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in node_map"; + } + return std::to_string(iter->second); + } + + if (node->isa()) { + return node->ToString(); + } + + // for ValueNode input, create a Constant Operator + if (node->isa()) { + auto iter = node_map_ptr->find(node); + if (iter != node_map_ptr->end()) { + return std::to_string(iter->second); + } + // the id number starts at 1, so the id of created node should be size of map plus one + auto node_idx = AllocateNodeIndex(); + (*node_map_ptr)[node] = node_idx; + std::string node_name = std::to_string(node_idx); + + onnx::NodeProto *node_proto = graph_proto->add_node(); + node_proto->add_output(node_name); + + SetNodeAttribute(node->cast()->value(), node_proto); + + return node_name; + } + + MS_LOG(EXCEPTION) << "Unexpected node type " << node->type_name(); +} + +void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { + auto tuple_ptr = dyn_cast(value); + MS_EXCEPTION_IF_NULL(tuple_ptr); + if (tuple_ptr->size() == 0) { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, the size of converted tuple is 0."; + } + auto type_id = (*tuple_ptr)[0]->type()->type_id(); + for (size_t i = 1; i < tuple_ptr->size(); ++i) { + if ((*tuple_ptr)[i]->type()->type_id() != type_id) { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, type of tuple elements is not same."; + } + } + + tensor_proto->add_dims(static_cast<::google::protobuf::int64>(tuple_ptr->size())); + tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); + for (size_t i = 0; i < tuple_ptr->size(); ++i) { + ValuePtr elem = (*tuple_ptr)[i]; + if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else if (elem->isa()) { + tensor_proto->add_int64_data(dyn_cast(elem)->value()); + } else { + MS_LOG(EXCEPTION) << "Convert tuple to tensor fail, unexpected tuple element type " << elem->type()->type_name() + << "."; + } + } +} + +void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { + node_proto->set_op_type("Constant"); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); + attr_proto->set_name("value"); + if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + auto casted_value = dyn_cast(value); + if (casted_value == nullptr) { + MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; + } + auto attr_value = casted_value->value(); + attr_proto->set_i(static_cast<::google::protobuf::int64>(attr_value)); + attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); + } else if (value->isa()) { + attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); + onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); + auto data = dyn_cast(value); + tensor_proto->set_raw_data(data->data_c(), static_cast(data->data().nbytes())); + auto dtype = data->data_type(); + auto shape = data->shape_c(); + + tensor_proto->set_data_type(GetOnnxDataType(dtype)); + for (const auto &dim : shape) { + tensor_proto->add_dims(dim); + } + } else { + MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; + } +} + +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { 
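+  // A throwaway exporter is constructed per call, so repeated exports do not share
+  // node-index state.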
+ OnnxExporter exporter; + return exporter.GetOnnxProtoString(func_graph); +} +} // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h deleted file mode 100644 index caac4258df..0000000000 --- a/mindspore/ccsrc/transform/op_adapter.h +++ /dev/null @@ -1,913 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_ADAPTER_H_ -#define TRANSFORM_OP_ADAPTER_H_ - -#include -#include -#include -#include - -#include "transform/op_adapter_util.h" -#include "utils/utils.h" -namespace mindspore { -namespace transform { -static uint32_t CustomInferFunc(const Operator &) { return 0; } - -template -class OpAdapter : public BaseOpAdapter { - public: - using OpType = T; - OpAdapter() {} - explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} - ~OpAdapter() override {} - - bool IsCustomOp(const OperatorPtr &op) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return false; - } - return true; - } - - Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(prim); - // Create the map of custom op from input index to input name. - std::unordered_map input_map; - auto value = prim->GetAttr("input_names"); - if (value == nullptr) { - cus_output_map_[prim->name()] = input_map; - return NOT_FOUND; - } - - auto input_names = GetValue>(value); - for (size_t i = 0; i < input_names.size(); ++i) { - // input_map begin form 1 - input_map[i + 1] = input_names[i]; - op->CustomInputRegister(input_names[i]); - } - - if (cus_input_map_.find(prim->name()) == cus_input_map_.end()) { - cus_input_map_[prim->name()] = input_map; - } - return SUCCESS; - } - - Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(prim); - // Create the map of custom op from output index to output name. - std::unordered_map output_map; - auto value = prim->GetAttr("output_names"); - if (value == nullptr) { - // generate a empty output_map for it - cus_output_map_[prim->name()] = output_map; - return NOT_FOUND; - } - - auto output_names = GetValue>(value); - for (size_t i = 0; i < output_names.size(); ++i) { - // output_map begin form 0 - output_map[i] = output_names[i]; - op->CustomOutputRegister(output_names[i]); - } - - if (cus_output_map_.find(prim->name()) == cus_output_map_.end()) { - cus_output_map_[prim->name()] = output_map; - } - return SUCCESS; - } - - // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs. 
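GenerateCustomOpInputMap and GenerateCustomOpOutputMap above differ in one detail that is easy to miss: input indices start at 1 (slot 0 of a CNode holds the primitive itself) while output indices start at 0. A stand-alone sketch of the two map builders, with hypothetical helper names:

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

std::unordered_map<int, std::string> BuildInputMap(const std::vector<std::string> &input_names) {
  std::unordered_map<int, std::string> input_map;
  for (std::size_t i = 0; i < input_names.size(); ++i) {
    input_map[static_cast<int>(i) + 1] = input_names[i];  // inputs are 1-based
  }
  return input_map;
}

std::unordered_map<int, std::string> BuildOutputMap(const std::vector<std::string> &output_names) {
  std::unordered_map<int, std::string> output_map;
  for (std::size_t i = 0; i < output_names.size(); ++i) {
    output_map[static_cast<int>(i)] = output_names[i];  // outputs are 0-based
  }
  return output_map;
}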
- OperatorPtr GenerateCustomOp(const AnfNodePtr anf) { - MS_EXCEPTION_IF_NULL(anf); - auto node = anf->cast(); - if (node == nullptr) { - return nullptr; - } - - if (node->inputs().empty()) { - MS_LOG(EXCEPTION) << "length of node inputs is empty"; - } - - auto prim = GetValueNode(node->inputs()[0]); - MS_EXCEPTION_IF_NULL(prim); - auto op = std::make_shared(node->fullname_with_scope(), prim->name()); - if (GenerateCustomOpInputMap(op, prim) != SUCCESS) { - MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "]."; - } - - if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) { - MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "]."; - } - - op->CustomInferFuncRegister(CustomInferFunc); - - return op; - } - - OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { - OperatorPtr op = nullptr; - // There are duplicate names in ANF graph, do not assign ANF node name to GE - // GE will generate unique name automatically - if (anf != nullptr && anf->fullname_with_scope() != "") { - MS_LOG(DEBUG) << anf->fullname_with_scope(); - op = std::make_shared(anf->fullname_with_scope()); - } else { - MS_LOG(DEBUG) << "no fullname_with_scope"; - op = std::make_shared(); - } - - // set dynamic output num if op use DYNAMIC_OUTPUT - if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) { - TypePtr type = anf->Type(); - if (type == nullptr) { - MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!"; - } - size_t num = type->isa() ? (type->cast>()->size()) : 1; - MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString() - << ", num:" << num; - dyn_output_map_.begin()->second.create_dyn_output(op, static_cast(num)); - } - return op; - } - - OperatorPtr generate(const AnfNodePtr &anf) override { - OperatorPtr op = nullptr; - if (IsCustomCNode(anf)) { - op = GenerateCustomOp(anf); - } else { - op = GenerateNormalOp(anf); - } - return op; - } - - OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } - - const std::unordered_map &getInputMap() override { return input_map_; } - const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } - const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } - const std::unordered_map &getOutputMap() override { return output_map_; } - const std::unordered_map &getDynSubgraphMap() override { return dyn_subgraph_map_; } - - Status SetOpSubgraphFunc(const OperatorPtr &op, int index, std::shared_ptr> branches) { - MS_EXCEPTION_IF_NULL(op); - auto it = dyn_subgraph_map_.find(index); - if (it != dyn_subgraph_map_.end()) { - auto size = branches->size(); - it->second.create_dyn_subgraph(op, static_cast(size)); - for (size_t i = 0; i < size; i++) { - it->second.set_subgraph(op, static_cast(i), std::make_shared((*branches)[i])); - } - return SUCCESS; - } - return NOT_FOUND; - } - - int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) override { - return static_cast(SetOpSubgraphFunc(op, index, branches)); - } - - Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(input); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return NOT_FOUND; - } - std::unordered_map &input_map = it->second; - - if ((input_map.find(index) != input_map.end())) { - MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() 
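GenerateNormalOp above sizes a DYN_OUTPUT from the node's inferred type: a Tuple type contributes one dynamic output per element, any other type exactly one. Reduced to a sketch with stand-in types:

#include <cstddef>

struct InferredType {      // stand-in for the node's TypePtr
  bool is_tuple;
  std::size_t tuple_size;  // meaningful only when is_tuple is true
};

unsigned int DynOutputNum(const InferredType &t) {
  return static_cast<unsigned int>(t.is_tuple ? t.tuple_size : 1);
}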
<< ":" << input_map[index]; - (void)op->SetInput(input_map[index], *input); - return SUCCESS; - } - return NOT_FOUND; - } - - Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { - MS_EXCEPTION_IF_NULL(op); - auto it = input_map_.find(index); - if (it != input_map_.end()) { - MS_EXCEPTION_IF_NULL(input); - MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, input); - return SUCCESS; - } - return NOT_FOUND; - } - - int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - return static_cast(SetCustomOpInput(cus_op, index, input)); - } else { - return static_cast(SetNormalOpInput(op, index, input)); - } - } - - Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_input_map_.find(op->GetOpType()); - if (it == cus_input_map_.end()) { - return NOT_FOUND; - } - - std::unordered_map &input_map = it->second; - if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { - if (handle.out.empty()) { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; - (void)op->SetInput(input_map[index], *(handle.op)); - } else { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" - << input_map[index]; - (void)op->SetInput(input_map[index], *(handle.op), handle.out); - } - return SUCCESS; - } - return NOT_FOUND; - } - - Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { - MS_EXCEPTION_IF_NULL(op); - auto it = input_map_.find(index); - if ((handle.op != nullptr) && (it != input_map_.end())) { - if (handle.out.empty()) { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, handle.op); - } else { - MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":" - << it->second.name; - it->second.set_handle(op, handle); - } - return SUCCESS; - } - return NOT_FOUND; - } - - int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - return static_cast(SetCustomOpInput(cus_op, index, handle)); - } else { - return static_cast(SetNormalOpInput(op, index, handle)); - } - } - - int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { - MS_EXCEPTION_IF_NULL(handler_vec); - if (IsCustomOp(op)) { - MS_LOG(ERROR) << "Custom Op do not support dynamic input"; - return static_cast(FAILED); - } - MS_EXCEPTION_IF_NULL(op); - auto it = dyn_input_map_.find(index); - if (it != dyn_input_map_.end()) { - it->second.create_dyn_input(op, static_cast(handler_vec->size())); - for (unsigned int i = 0; i < handler_vec->size(); ++i) { - OutHandler h = (*handler_vec)[i]; - MS_EXCEPTION_IF_NULL(h.op); - if (h.out.empty()) { - MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name; - it->second.set_op(op, (i) /* index start from 0 */, h.op); - } else { - MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":" - << it->second.name; - it->second.set_handle(op, i, h); - } - } - return 0; - } - return static_cast(NOT_FOUND); - } - - OutHandler getOutput(const OperatorPtr 
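The SetCustomOpInput/SetNormalOpInput overloads above encode one linking rule: an OutHandler with an empty out field links the producer as a whole, while a named out links that specific output to the consumer's input. A toy reproduction with simplified stand-in types:

#include <cstdio>
#include <memory>
#include <string>

struct Op {
  std::string name;
};
using OpPtr = std::shared_ptr<Op>;

struct OutHandlerLike {
  OpPtr op;         // producer operator
  std::string out;  // named output; empty means the producer's single output
};

void Link(const OpPtr &consumer, const std::string &input_name, const OutHandlerLike &h) {
  if (h.out.empty()) {
    std::printf("link %s -> %s:%s\n", h.op->name.c_str(), consumer->name.c_str(), input_name.c_str());
  } else {
    std::printf("link %s:%s -> %s:%s\n", h.op->name.c_str(), h.out.c_str(), consumer->name.c_str(),
                input_name.c_str());
  }
}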
&op, int index) override { - MS_EXCEPTION_IF_NULL(op); - if (IsCustomOp(op)) { - return getCustomOutput(op, index); - } - return getNormalOutput(op, index); - } - - OutHandler getCustomOutput(const OperatorPtr &op, int index) { - MS_EXCEPTION_IF_NULL(op); - auto it = cus_output_map_.find(op->GetOpType()); - if (it == cus_output_map_.end()) { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!"; - return OutHandler(); - } - - std::unordered_map &output_map = it->second; - - if ((output_map.find(index) != output_map.end())) { - return OutHandler(op, output_map[index]); - } - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!"; - return OutHandler(); - } - - OutHandler getNormalOutput(const OperatorPtr &op, int index) { - MS_EXCEPTION_IF_NULL(op); - if (!dyn_output_map_.empty() && !output_map_.empty()) { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; - return OutHandler(); - } - auto it = output_map_.find(index); - if (it != output_map_.end()) { - return OutHandler(op, it->second.name); - } else if (!dyn_output_map_.empty()) { - return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index)); - } else { - MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!"; - return OutHandler(); - } - } - - Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { - MS_EXCEPTION_IF_NULL(type); - std::string format = "NCHW"; - if (op->GetOpType() == kExtractImagePatchesOpName) { - format = "NHWC"; - } - - auto desc = CreateOutputDesc(dyn_cast(shp), type, format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update output descriptor failed!"; - return FAILED; - } - - if (IsCustomOp(op)) { - if (cus_output_map_.find(op->GetOpType()) == cus_output_map_.end() || - (cus_output_map_[op->GetOpType()].empty())) { - MS_LOG(ERROR) << "This op does not create custom output map"; - return FAILED; - } - auto cus_op = std::dynamic_pointer_cast(op); - MS_EXCEPTION_IF_NULL(cus_op); - std::unordered_map output_map = cus_output_map_[op->GetOpType()]; - (void)cus_op->UpdateOutputDesc(output_map[0], *desc); - } else { - if (output_map_.empty()) { - MS_LOG(INFO) << "This op does not have output map"; - return FAILED; - } - output_map_.begin()->second.update_out_desc(op, *desc); - } - return SUCCESS; - } - - size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { - MS_EXCEPTION_IF_NULL(cus_op); - if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { - MS_LOG(ERROR) << "This op does not create custom output map"; - return 0; - } - size_t output_size = cus_output_map_[cus_op->GetOpType()].size(); - return output_size; - } - - std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, - const std::string &format) { - if (shape_ptr == nullptr) { - MS_LOG(ERROR) << "Shape ptr is nullptr"; - return nullptr; - } - - if (type == nullptr) { - MS_LOG(ERROR) << "Type ptr is nullptr"; - return nullptr; - } - - TypeId me_type = type->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(type)->element()->type_id(); - } - auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format); - return desc; - } - - Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { - auto tuple_shp = dyn_cast(shp); - MS_EXCEPTION_IF_NULL(tuple_shp); - - size_t 
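getNormalOutput above resolves an output name in two steps: a fixed OUTPUT entry wins, otherwise the DYN_OUTPUT base name is suffixed with the index (base "y" and index 3 yield "y3"). The resolution order in isolation, assuming simplified map types:

#include <optional>
#include <string>
#include <unordered_map>

std::optional<std::string> ResolveOutputName(const std::unordered_map<int, std::string> &output_map,
                                             const std::optional<std::string> &dyn_output_base, int index) {
  auto it = output_map.find(index);
  if (it != output_map.end()) return it->second;  // fixed OUTPUT entry wins
  if (dyn_output_base) return *dyn_output_base + std::to_string(index);  // "y" + 3 -> "y3"
  return std::nullopt;  // neither map knows this index
}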
output_size = 0; - bool is_custom_op = IsCustomOp(op); - if (is_custom_op) { - output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast(op)); - } else { - output_size = output_map_.size(); - } - - if (output_size == 0) { - MS_LOG(INFO) << "This op does not have output map"; - return FAILED; - } - - if (output_size != tuple_shp->shape().size()) { - MS_LOG(ERROR) << "output_map is not equal tuple_shape size"; - return FAILED; - } - std::string format = "NCHW"; - if (op->GetOpType() == kTopKOpName) { - format = "NHWC"; - } - for (size_t i = 0; i < tuple_shp->shape().size(); ++i) { - auto tuple_type = dyn_cast(type); - MS_EXCEPTION_IF_NULL(tuple_type); - TypePtr type_elem = tuple_type->elements()[i]; - - auto desc = CreateOutputDesc(dyn_cast(tuple_shp->shape()[i]), type_elem, format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create output descriptor failed!"; - return FAILED; - } - - if (is_custom_op) { - (void)std::dynamic_pointer_cast(op)->UpdateOutputDesc(cus_output_map_[op->GetOpType()][i], - *desc); - } else { - auto it = output_map_.find(i); - if (it != output_map_.end()) { - it->second.update_out_desc(op, *desc); - } - } - } - return SUCCESS; - } - - std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - TypeId me_type = node->Type()->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(node->Type())->element()->type_id(); - } - if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) { - return nullptr; - } - - std::vector shape; - auto shape_ptr = dyn_cast(node->Shape()); - if (nullptr != shape_ptr) { - shape = shape_ptr->shape(); - } - - auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW"); - if (desc == nullptr) { - MS_LOG(ERROR) << "Update output descriptor failed!"; - return nullptr; - } - return desc; - } - - void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - - auto inputs = node->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - auto it = input_map_.find(i); - if (it != input_map_.end()) { - auto desc = CreateNodeDesc(inputs[i]); - if (desc == nullptr) { - continue; - } - if (op->GetOpType() == kExtractImagePatchesOpName) { - desc->SetFormat(ge::Format::FORMAT_NHWC); - } - it->second.update_input_desc(op, *desc); - } - } - } - - void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - - if (cus_input_map_.find(op->GetOpType()) == cus_input_map_.end() || (cus_input_map_[op->GetOpType()].empty())) { - MS_LOG(ERROR) << "This op does not create custom input map"; - return; - } - - std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; - auto inputs = node->cast()->inputs(); - for (size_t i = 1; i < inputs.size(); ++i) { - if (input_map.find(i) != input_map.end()) { - auto desc = CreateNodeDesc(inputs[i]); - if (desc == nullptr) { - continue; - } - (void)op->UpdateInputDesc(input_map[i], *desc); - } - } - } - - void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(op); - MS_EXCEPTION_IF_NULL(node); - if (IsCustomOp(op)) { - auto cus_op = std::dynamic_pointer_cast(op); - UpdateCustomOpInputDesc(cus_op, node); - } else { - UpdateNormalOpInputDesc(op, node); - } - } - - void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, - const 
AnfNodePtr &node) override { - if (op == nullptr) { - MS_LOG(ERROR) << "op is nullptr"; - return; - } - MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Op name is " << op->GetName(); - - auto normal_shape_ptr = dyn_cast(shp); - auto no_shape_ptr = dyn_cast(shp); - - if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) { - if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) { - return; - } - } else if (nullptr != dyn_cast(shp)) { - if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) { - return; - } - } else { - MS_LOG(WARNING) << "Update output desc failed, unknow output shape type"; - return; - } - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return; - } - - // Need to update input_desc while the output_desc is updated - updateInputDesc(op, node); - } - - int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { - auto it = attr_map_.find(attrKey); - if (it != attr_map_.end()) { - // switch case for each avalilable attribute type - MS_LOG(INFO) << "Set attr: " << attrKey << "(" << it->second.name << "), value: " << attrValue->ToString(); - AddAttrToDrawGraph(attrKey + std::string("=") + attrValue->ToString()); - it->second.set_attr(op, attrValue); - return 0; - } - return static_cast(NOT_FOUND); - } - - int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { - enum ValueType { - SINGLE_VALUE = 0, - SEQUEUE_VALUE, - UNKNOWN_VALUE, - }; - - MS_EXCEPTION_IF_NULL(prim); - MS_EXCEPTION_IF_NULL(op); - - ValueType value_type = SINGLE_VALUE; - for (auto item : prim->attrs()) { - if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - (void)op->SetAttr(item.first, GetValue(item.second)); - } else if (item.second->isa()) { - value_type = SEQUEUE_VALUE; - auto val_seq = item.second->cast(); - if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else if ((*val_seq)[0]->isa()) { - (void)op->SetAttr(item.first, GetValue>(item.second)); - } else { - MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() - << ", attr name: " << item.first << ", value: " << item.second->ToString(); - } - } else { - value_type = UNKNOWN_VALUE; - MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name() - << ", attr name: " << item.first << ", value: " << item.second->ToString(); - return static_cast(NOT_FOUND); - } - - if (value_type == SINGLE_VALUE) { - AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString()); - } else if (value_type == SEQUEUE_VALUE) { - AddAttrToDrawGraph(item.first + std::string("=") + "[...]"); - } - } - return 0; - } - - int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { - int ret = 0; - MS_EXCEPTION_IF_NULL(prim); - MS_EXCEPTION_IF_NULL(op); - for (auto &it : attr_map_) { - auto value = prim->GetAttr(it.first); - if (value != nullptr) { - // set attr from primitive - ret = setAttr(op, it.first, value); - if (ret) { - return ret; - } - } else { - // set attr from extra_attr - auto it_extra = extra_attr_.find(it.first); - if (it_extra != 
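SetNormalOpAttr above gives attributes carried on the primitive precedence over the adapter's construction-time extra_attr_ defaults. The lookup order, isolated with illustrative types (std::string stands in for ValuePtr):

#include <optional>
#include <string>
#include <unordered_map>

using ValueLike = std::string;

std::optional<ValueLike> ResolveAttr(const std::string &key,
                                     const std::unordered_map<std::string, ValueLike> &prim_attrs,
                                     const std::unordered_map<std::string, ValueLike> &extra_attr) {
  auto it = prim_attrs.find(key);
  if (it != prim_attrs.end()) return it->second;  // value set on the primitive wins
  auto it_extra = extra_attr.find(key);
  if (it_extra != extra_attr.end()) return it_extra->second;  // otherwise fall back to the default
  return std::nullopt;
}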
extra_attr_.end()) { - ret = setAttr(op, it.first, it_extra->second); - if (ret) { - return ret; - } - } - } - } - return 0; - } - - int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { - int ret = 0; - if (IsCustomPrim(prim)) { - auto cus_op = std::dynamic_pointer_cast(op); - ret = SetCustomOpAttr(cus_op, prim); - } else { - ret = SetNormalOpAttr(op, prim); - } - return ret; - } - - int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { - // no attribute for lonely node - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - return 0; - } - - auto cnode = node->cast(); - if (cnode == nullptr) { - return 0; - } - - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - return 0; - } - - // get Attr T from abstract of anfnode first, - // if attr "T" appears in primitive, the primitive T will cover this one - if (attr_map_.find("T") != attr_map_.end()) { - // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype - TypePtr type; - if (inputs.size() > 1) { - type = inputs[1]->Type(); - } else { - type = node->Type(); - } - if (type != nullptr) { - (void)setAttr(op, "T", MakeValue(type)); - } - } - - // set attr from primitive and ExtraAttr - if (IsValueNode(inputs[0])) { - // set attr from primitive - PrimitivePtr prim = GetValueNode(inputs[0]); - int ret = setAttr(op, prim); - if (ret != 0) { - return ret; - } - } - - // set attr from const input - for (auto &it : input_attr_map_) { - if (inputs.size() <= it.first || !inputs[it.first]->isa()) { - continue; - } - auto const_value = GetValueNode(inputs[it.first]); - MS_LOG(INFO) << "Set attr: input_" << it.first << "(" << it.second.name - << "), value: " << const_value->ToString(); - if (const_value->isa()) { - continue; - } - AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString()); - it.second.set_attr(op, const_value); - } - return 0; - } - - std::unordered_map GetExtraAttr() override { return extra_attr_; } - - private: - template - static S ConvertAny(const ValuePtr &value, const AnyTraits &) { - return GetValue(value); - } - - // specialization for reverse bool - static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { - return reverse != GetValue(value); - } - - template - static Q ConvertAny(const ValuePtr &value, const AnyTraits
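The ConvertAny overload set that follows is classic tag dispatch: empty AnyTraits<T> structs select a conversion at compile time, which lets the ATTR_DESC macros name a target type without any runtime switch. A minimal reproduction, with suffixed names to mark it as a sketch:

#include <cstdint>
#include <string>

template <typename T>
struct AnyTraitsLike {  // empty tag type; the template argument is the whole payload
  using type = T;
};

int64_t ConvertAnyLike(const std::string &v, const AnyTraitsLike<int64_t> &) { return std::stoll(v); }
bool ConvertAnyLike(const std::string &v, const AnyTraitsLike<bool> &) { return v == "true"; }

// ConvertAnyLike(value, AnyTraitsLike<int64_t>()) selects the integer overload at compile time.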
&traits_from, const AnyTraits &traits_to) { - return ConvertAnyUtil(value, traits_from, traits_to); - } - - // specialization for tensor - static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { - // To-DO the format may read from ME tensor - return ConvertAnyUtil(value, traits); - } - - // specialization for int - static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { - return static_cast(GetValue(value)); - } - - // specialization for int or tuple broadcast to Vector - static std::vector ConvertAny(const ValuePtr &value, const std::string &name, - const AnyTraits> anyTraitsInt) { - return ConvertAnyUtil(value, name, anyTraitsInt); - } - - static std::vector> ConvertAny(const ValuePtr &value, - const AnyTraits>>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Value: " << value->type_name(); - std::vector> list; - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name(); - } - auto vec = value->cast(); - MS_EXCEPTION_IF_NULL(vec); - for (auto &it : vec->value()) { - MS_EXCEPTION_IF_NULL(it); - if (!it->isa()) { - MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); - } - auto sub_vector = it->cast(); - std::vector sublist; - for (auto &item : sub_vector->value()) { - sublist.push_back(static_cast(GetValue(item))); - } - list.push_back(sublist); - } - return list; - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(DEBUG) << "Value: " << value->type_name(); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name(); - } - auto vec = value->cast(); - std::vector list; - for (auto &it : vec->value()) { - MS_EXCEPTION_IF_NULL(it); - if (!it->isa()) { - MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); - } - auto sub_vector = it->cast(); - for (auto &item : sub_vector->value()) { - list.push_back(static_cast(GetValue(item))); - } - } - return list; - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - MS_LOG(INFO) << "Value: " << value->type_name(); - std::vector list; - if (value->isa()) { - auto vec = value->cast(); - MS_EXCEPTION_IF_NULL(vec); - for (auto &it : vec->value()) { - list.push_back(static_cast(GetValue(it))); - } - return list; - } - if (value->isa()) { - list.push_back(static_cast(GetValue(value))); - return list; - } - MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); - } - - static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsStr) { - return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); - } - - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsFlo) { - return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); - } - - static std::vector ConvertAny(const ValuePtr &value, const std::string &format, - const AnyTraits> anyTraitsVec, - const AnyTraits anyTraitsInt) { - return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); - } - - // convert value list for value tuple to vector - template - static std::vector ConvertAny(const ValuePtr &value, const AnyTraits
<P>
&anyTraitsP, - const AnyTraits> anyTraitsQ) { - return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); - } - - static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { - auto name = GetValue(value); - auto it = enum_map_.find(name); - int v = 0; - if (it != enum_map_.end()) { - v = it->second; - } - return v; - } - - static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { - return ConvertAnyUtil(value, anyTraitsGE); - } - - // convert any value to tensor - static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { - return ConvertAnyUtil(value, anyTraitsValue); - } - - static const std::unordered_map input_map_; - static const std::unordered_map dyn_input_map_; - static const std::unordered_map output_map_; - static const std::unordered_map dyn_output_map_; - static const std::unordered_map dyn_subgraph_map_; - static const std::unordered_map attr_map_; - static const std::unordered_map enum_map_; - // convert input from anf graph to Attr in Operators - static const std::unordered_map input_attr_map_; - static std::unordered_map> cus_input_map_; - static std::unordered_map> cus_output_map_; - std::unordered_map extra_attr_; - std::unordered_map name_counts_; -}; - -template -const std::unordered_map OpAdapter::input_map_; -template -const std::unordered_map OpAdapter::dyn_input_map_; -template -const std::unordered_map OpAdapter::output_map_; -template -const std::unordered_map OpAdapter::dyn_output_map_; -template -const std::unordered_map OpAdapter::dyn_subgraph_map_; -template -const std::unordered_map OpAdapter::attr_map_; -template -const std::unordered_map OpAdapter::enum_map_; -template -const std::unordered_map OpAdapter::input_attr_map_; -template -std::unordered_map> OpAdapter::cus_input_map_; -template -std::unordered_map> OpAdapter::cus_output_map_; - -// specialization for method -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_OP_ADAPTER_H_ diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h deleted file mode 100644 index 2c6fcedf09..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ /dev/null @@ -1,198 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef TRANSFORM_OP_ADAPTER_BASE_H_ -#define TRANSFORM_OP_ADAPTER_BASE_H_ - -#include -#include -#include -#include -#include -#include - -#include "transform/util.h" -#include "ir/anf.h" -#include "ir/primitive.h" -#include "ir/value.h" -#include "transform/types.h" -#ifdef ENABLE_GE -#ifdef OPEN_SOURCE -#include "graph/types.h" -#endif -#endif - -#include "graph/operator_reg.h" -#ifdef OPEN_SOURCE -#include "ge/client/ge_api.h" -#else -#include "external/ge/ge_api.h" -#endif -#include "graph/tensor.h" -#include "transform/all_ops.h" - -namespace ge { -class CustomOperator : public Operator { - public: - CustomOperator(const string &name, const string &type) : Operator(name, type) {} - - ~CustomOperator() override{}; - - void CustomInputRegister(const string &name) { Operator::InputRegister(name); } - - void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } - - void CustomInferFuncRegister(const std::function &func) { - Operator::InferFuncRegister(func); - } -}; -} // namespace ge - -namespace mindspore { -namespace transform { -using CusOperatorPtr = std::shared_ptr; -using CustomOperator = ge::CustomOperator; - -struct OutHandler { - OperatorPtr op; - std::string out; - OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} -}; - -struct ControlEdge { - OperatorPtr src_op; - OperatorPtr dest_op; -}; - -using AttrFunc = std::function; -using OutputFunc = std::function; -using InputOpFunc = std::function; -using InputHandleFunc = std::function; -using CreateDynInputOpFunc = std::function; -using DynInputOpFunc = std::function; -using DynInputHandleFunc = std::function; -using UpdateOutputDescFunc = std::function; -using CreateDynOutputOpFunc = std::function; -using CreateDynSubGraphFunc = std::function; -using DynSubGraphFunc = std::function; - -struct AttrDesc { - std::string name; - AttrFunc set_attr; -}; - -struct InputDesc { - std::string name; - InputOpFunc set_op; - InputHandleFunc set_handle; - UpdateOutputDescFunc update_input_desc; -}; - -struct DynInputDesc { - std::string name; - CreateDynInputOpFunc create_dyn_input; - DynInputOpFunc set_op; - DynInputHandleFunc set_handle; -}; - -struct DynSubGraphDesc { - std::string name; - CreateDynSubGraphFunc create_dyn_subgraph; - DynSubGraphFunc set_subgraph; -}; - -struct OutputDesc { - std::string name; - UpdateOutputDescFunc update_out_desc; -}; - -struct DynOutputDesc { - std::string name; - CreateDynOutputOpFunc create_dyn_output; -}; - -class BaseOpAdapter { - public: - virtual ~BaseOpAdapter() {} - virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; - virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } - virtual int setSubgraph(const OperatorPtr &op, int index, std::shared_ptr> branches) = 0; - virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; - virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; - virtual int setInput(const OperatorPtr &op, int index, - const std::shared_ptr> &handler_vec) = 0; - virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; - virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; - virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; - virtual std::unordered_map GetExtraAttr() = 0; - template ::value>::type> - int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { 
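ge::CustomOperator above relies on a small C++ idiom: the Operator registration methods are presumably protected (which is why these wrappers exist at all), so a thin subclass re-exposes them through public forwarders. In miniature:

class Base {
 protected:
  void Register(const char * /*name*/) {}  // protected in the real Operator (assumption)
};

class Custom : public Base {
 public:
  void CustomRegister(const char *name) { Register(name); }  // public forwarding wrapper
};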
- return setAttr(op, attrKey, MakeValue(attrValue)); - } - template ::value>::type> - int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { - return setAttr(op, attrKey, MakeValue(attrValue)); - } - virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; - virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, - const AnfNodePtr &node) = 0; - virtual const std::unordered_map &getInputMap() = 0; - virtual const std::unordered_map &getInputAttrMap() = 0; - virtual const std::unordered_map &getDynInputMap() = 0; - virtual const std::unordered_map &getOutputMap() = 0; - virtual const std::unordered_map &getDynSubgraphMap() = 0; - void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } - const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } - void clearAttrVect() { attrs_vec_.clear(); } - - private: - std::vector attrs_vec_; -}; - -using OpAdapterPtr = std::shared_ptr; - -enum AttrType { - ATTR_INT = 0, - ATTR_FLOAT, - ATTR_DOUBLE, - ATTR_STRING, - ATTR_TENSOR, - ATTR_BOOL, - ATTR_LIST_INT, - ATTR_LIST_ANY_INT, - ATTR_ENUM -}; - -struct GeEnum {}; -struct TFType {}; -struct GEType {}; - -// declare Any type -template -struct AnyTraits { - using type = T; -}; - -template <> -struct AnyTraits { - using type = int64_t; -}; - -using ExtraAttr = std::unordered_map; -} // namespace transform -} // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_BASE_H_ diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc deleted file mode 100644 index cae43c13dc..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ /dev/null @@ -1,264 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/op_adapter_util.h" - -#include -#include -#include - -#include "utils/utils.h" -#include "transform/op_adapter_base.h" - -namespace mindspore { -namespace transform { -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { - // To-DO the format may read from ME tensor - MS_EXCEPTION_IF_NULL(value); - auto me_tensor = value->cast(); - auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW); - return ge_tensor == nullptr ? 
GeTensor() : *ge_tensor; -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, - const AnyTraits>) { - MS_EXCEPTION_IF_NULL(value); - std::vector list; - if (name == "pad") { - if (!value->isa()) { - MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); - } - auto vec = value->cast(); - list.resize(vec->value().size() + 2); - list[0] = 1; - list[1] = 1; - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2, - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - } else { - int64_t data = GetValue(value); - int size = 2; // 2 int in list - list = TransformUtil::ConvertIntToList(data, size); - } - - return list; -} - -std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::ostringstream buffer; - int i = 0; - for (auto &it : vec->value()) { - if (i != 0) { - buffer << ","; - } - buffer << GetValue(it); - i++; - } - return buffer.str(); -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::vector list; - list.resize(vec->value().size()); - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - return list; -} - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, - const AnyTraits>, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - auto vec = value->cast(); - if (nullptr == vec) { - MS_LOG(EXCEPTION) << "not ValueTuplePtr"; - } - std::vector list; - list.resize(vec->value().size()); - (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr &val) { return static_cast(GetValue(val)); }); - if (format == kOpFormat_NHWC) { - if (list.size() < 4) { - MS_LOG(EXCEPTION) << "The size of list is less than 4"; - } else { - int64_t temp = list[1]; - list[1] = list[2]; - list[2] = list[3]; - list[3] = temp; - } - } - return list; -} - -GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - if (!value->isa()) { - MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() - << ", type: " << value->type_name() << ", value should be a Typeptr"; - } - auto type = value->cast(); - MS_EXCEPTION_IF_NULL(type); - TypeId me_type = type->type_id(); - if (kObjectTypeTensorType == me_type) { - me_type = dyn_cast(type)->element()->type_id(); - } - return TransformUtil::ConvertDataType(me_type); -} - -GeTensor VectorToTensorUtil(const ValuePtr &value) { - // convert tuple or list to ge tensor, only supported one dim for now - MS_EXCEPTION_IF_NULL(value); - auto vec = value->isa() ? 
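The format-aware ConvertAnyUtil overload above reorders a 4-element NCHW attribute into NHWC by rotating positions 1 through 3. Isolated into a stand-alone helper (the name is illustrative):

#include <cstdint>
#include <stdexcept>
#include <vector>

void NchwToNhwcInPlace(std::vector<int64_t> *v) {
  if (v->size() < 4) throw std::runtime_error("expected at least 4 elements");
  int64_t c = (*v)[1];
  (*v)[1] = (*v)[2];  // H moves forward
  (*v)[2] = (*v)[3];  // W moves forward
  (*v)[3] = c;        // C moves last
}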
value->cast()->value() : value->cast()->value(); - if (vec.empty()) { - MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor"; - return GeTensor(); - } - MS_EXCEPTION_IF_NULL(vec[0]); - if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Int32"; - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); - } else if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Float32"; - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); - } else if (vec[0]->isa()) { - MS_LOG(INFO) << "convert value to tensor with data type = Bool"; - // We use uint8_t to save bool type data - auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); - auto desc = TransformUtil::GetGeTensorDesc({static_cast(vec.size())}, kNumberTypeBool, kOpFormat_NCHW); - if (desc == nullptr) { - MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; - } - return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); - } else { - MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); - } - - return GeTensor(); -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - // convert me tensor to ge tensor - return ConvertAnyUtil(value, AnyTraits()); - } else if (value->isa() || value->isa()) { - return VectorToTensorUtil(value); - } else if (value->isa()) { - // convert scalar Int to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Int32"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); - } else if (value->isa()) { - // convert scalar Int64 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); - } else if (value->isa()) { - // convert scalar FP32 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); - } else if (value->isa()) { - // convert scalar FP32 to GeTensor - MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; - GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); - auto v = GetValue(value); - desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); - } else if (value->isa()) { - // convert String to GeTensor - MS_LOG(INFO) << "convert string to tensor with data type = String"; - std::string v = GetValue(value); - std::vector ge_shape; - GeShape shape(ge_shape); - GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING); - GeTensor str_tensor(desc); 
- str_tensor.SetData(v); - return str_tensor; - } else { - MS_LOG(WARNING) << "Unsupported value type: " << value->type_name() - << " to convert to tensor. Value: " << value->ToString(); - } - return GeTensor(); -} - -bool IsCustomPrim(const PrimitivePtr &prim) { - if (prim == nullptr) { - return false; - } - - ValuePtr flag = prim->GetAttr("_custom_op_flag"); - if (flag == nullptr) { - return false; - } - - bool is_custom_op = GetValue(flag); - if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) { - MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op " - "can not assign the op information config path."; - } - - return is_custom_op; -} - -bool IsCustomCNode(const AnfNodePtr &anf) { - if (anf == nullptr) { - return false; - } - auto node = anf->cast(); - if (node == nullptr) { - return false; - } - if (node->inputs().empty()) { - MS_LOG(EXCEPTION) << "length of node inputs is empty"; - } - MS_EXCEPTION_IF_NULL(node->inputs()[0]); - if (!node->inputs()[0]->isa()) { - return false; - } - auto cus_prim = GetValueNode(node->inputs()[0]); - if (cus_prim == nullptr) { - return false; - } - - return IsCustomPrim(cus_prim); -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter_util.h b/mindspore/ccsrc/transform/op_adapter_util.h deleted file mode 100644 index fcabc732d5..0000000000 --- a/mindspore/ccsrc/transform/op_adapter_util.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef TRANSFORM_OP_ADAPTER_UTIL_H_ -#define TRANSFORM_OP_ADAPTER_UTIL_H_ - -#include -#include - -#include "transform/op_adapter_base.h" - -namespace mindspore { -namespace transform { -template -static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits
<P> &, const AnyTraits<Q> &) { - return static_cast<Q>(GetValue<P>
(value)); -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, - const AnyTraits>); - -std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); - -std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, - const AnyTraits>, const AnyTraits); - -GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); - -template -std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits
<P>
, const AnyTraits>) { - if (!value->isa() && !value->isa()) { - MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() - << ", type: " << value->type_name() << ", value should be a tuple or list"; - } - auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); - std::vector data; - for (auto &it : vec) { - data.push_back(ConvertAnyUtil(it, AnyTraits
<P>
(), AnyTraits())); - } - return data; -} - -GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); - -bool IsCustomPrim(const PrimitivePtr &prim); -bool IsCustomCNode(const AnfNodePtr &node); -} // namespace transform -} // namespace mindspore -#endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc deleted file mode 100644 index ffaaa952db..0000000000 --- a/mindspore/ccsrc/transform/op_declare.cc +++ /dev/null @@ -1,1330 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "transform/op_declare.h" - -#include - -#include "transform/all_ops.h" -#include "utils/utils.h" - -namespace mindspore { -namespace transform { -#define INPUT_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::input_map_ -#define EMPTY_INPUT_MAP std::unordered_map() -#define INPUT_DESC(name) \ - { \ -#name, \ - [](const OperatorPtr op, const OperatorPtr input) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_input_##name(*input); \ - }, \ - [](const OperatorPtr op, const OutHandler& handle) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_input_##name(*(handle.op), handle.out); \ - }, \ - [](const OperatorPtr op, const GeTensorDesc desc) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->update_input_desc_##name(desc); \ - } \ - } - -#define DYN_INPUT_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::dyn_input_map_ -#define DYN_INPUT_DESC(name) \ - { \ -#name, \ - [](const OperatorPtr op, unsigned int num) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->create_dynamic_input_##name(num); \ - }, \ - [](const OperatorPtr op, unsigned int index, const OperatorPtr input) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_dynamic_input_##name(index, *input); \ - }, \ - [](const OperatorPtr op, unsigned int index, const OutHandler& handle) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_dynamic_input_##name(index, *(handle.op), handle.out); \ - } \ - } - -#define DYN_SUBGRAPH_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::dyn_subgraph_map_ -#define DYN_SUBGRAPH_DESC(name) \ - { \ -#name, \ - [](const OperatorPtr op, unsigned int num) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->create_dynamic_subgraph_##name(num); \ - }, \ - [](const OperatorPtr op, unsigned int index, const DfGraphPtr graph) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_dynamic_subgraph_builder_##name(index, [graph](){return *graph;}); \ - } \ - } - -#define ATTR_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::attr_map_ -#define EMPTY_ATTR_MAP std::unordered_map() -#define ATTR_DESC(name, ...) 
\ - { \ -#name, \ - [](const OperatorPtr op, const ValuePtr& value) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->set_attr_##name(ConvertAny(value, __VA_ARGS__)); \ - } \ - } - -#define INPUT_ATTR_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::input_attr_map_ - -#define OUTPUT_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::output_map_ -#define OUTPUT_DESC(name) \ - { \ -#name, \ - [](const OperatorPtr op, const GeTensorDesc desc) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->update_output_desc_##name(desc); \ - } \ - } - -#define DYN_OUTPUT_MAP(T) \ - template <> \ - const std::unordered_map OpAdapter::dyn_output_map_ - -#define DYN_OUTPUT_DESC(name) \ - { \ -#name, \ - [](const OperatorPtr op, unsigned int num) { \ - auto p = std::static_pointer_cast(op); \ - (void)p->create_dynamic_output_##name(num); \ - } \ - } - -template <> -std::unordered_map> OpAdapter::cus_input_map_{}; -template <> -std::unordered_map> OpAdapter::cus_output_map_{}; - -// --------------specialization for each operator---------- -// const -INPUT_MAP(Const) = EMPTY_INPUT_MAP; -ATTR_MAP(Const) = {{"value", ATTR_DESC(value, AnyTraits())}}; -OUTPUT_MAP(Const) = {{0, OUTPUT_DESC(y)}}; - -// Assign -INPUT_MAP(Assign) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}}; -ATTR_MAP(Assign) = EMPTY_ATTR_MAP; -OUTPUT_MAP(Assign) = {{0, OUTPUT_DESC(ref)}}; - -// Constant -INPUT_MAP(Constant) = EMPTY_INPUT_MAP; -ATTR_MAP(Constant) = {{"value", ATTR_DESC(value, AnyTraits())}}; -OUTPUT_MAP(Constant) = {{0, OUTPUT_DESC(y)}}; - -// ApplyMomentumD -INPUT_MAP(ApplyMomentumD) = { - {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)}, {4, INPUT_DESC(grad)}, {5, INPUT_DESC(momentum)}}; -ATTR_MAP(ApplyMomentumD) = {{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}, - {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(ApplyMomentumD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; - -// ScalarSummary -INPUT_MAP(Summary) = {{2, INPUT_DESC(x)}}; -ATTR_MAP(Summary) = EMPTY_ATTR_MAP; - -// Data -INPUT_MAP(Data) = EMPTY_INPUT_MAP; -ATTR_MAP(Data) = EMPTY_ATTR_MAP; - -// BatchNorm -INPUT_MAP(BatchNorm) = {{1, INPUT_DESC(x)}, - {2, INPUT_DESC(scale)}, - {3, INPUT_DESC(offset)}, - {4, INPUT_DESC(mean)}, - {5, INPUT_DESC(variance)}}; -ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, - {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, - {"is_training", ATTR_DESC(is_training, AnyTraits())}}; -OUTPUT_MAP(BatchNorm) = {{0, OUTPUT_DESC(y)}, - {1, OUTPUT_DESC(batch_mean)}, - {2, OUTPUT_DESC(batch_variance)}, - {3, OUTPUT_DESC(reserve_space_1)}, - {4, OUTPUT_DESC(reserve_space_2)}}; - -// BatchNormGrad -INPUT_MAP(BatchNormGrad) = {{1, INPUT_DESC(y_backprop)}, - {2, INPUT_DESC(x)}, - {3, INPUT_DESC(scale)}, - {4, INPUT_DESC(reserve_space_1)}, - {5, INPUT_DESC(reserve_space_2)}}; -ATTR_MAP(BatchNormGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, - {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, - {"is_training", ATTR_DESC(is_training, AnyTraits())}}; -OUTPUT_MAP(BatchNormGrad) = {{0, OUTPUT_DESC(x_backprop)}, - {1, OUTPUT_DESC(scale_backprop)}, - {2, OUTPUT_DESC(offset_backprop)}, - {3, OUTPUT_DESC(reserve_space_4)}, - {4, OUTPUT_DESC(reserve_space_5)}}; - -// Relu -INPUT_MAP(Relu) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(Relu) = EMPTY_ATTR_MAP; -OUTPUT_MAP(Relu) = {{0, OUTPUT_DESC(y)}}; - -// Elu -INPUT_MAP(Elu) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(Elu) = {{"alpha", ATTR_DESC(alpha, AnyTraits())}}; -OUTPUT_MAP(Elu) = {{0, 
OUTPUT_DESC(y)}}; - -// EluGrad -INPUT_MAP(EluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(activations)}}; -ATTR_MAP(EluGrad) = EMPTY_ATTR_MAP; -OUTPUT_MAP(EluGrad) = {{0, OUTPUT_DESC(y)}}; - -// PRelu -INPUT_MAP(PRelu) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight)}}; -ATTR_MAP(PRelu) = EMPTY_ATTR_MAP; -OUTPUT_MAP(PRelu) = {{0, OUTPUT_DESC(y)}}; - -// PReluGrad -INPUT_MAP(PReluGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(features)}, {3, INPUT_DESC(weights)}}; -ATTR_MAP(PReluGrad) = EMPTY_ATTR_MAP; -OUTPUT_MAP(PReluGrad) = {{0, OUTPUT_DESC(dx)}, {1, OUTPUT_DESC(da)}}; - -// Sigmoid -INPUT_MAP(Sigmoid) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(Sigmoid) = EMPTY_ATTR_MAP; -OUTPUT_MAP(Sigmoid) = {{0, OUTPUT_DESC(y)}}; - -// SigmoidGrad -INPUT_MAP(SigmoidGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}}; -ATTR_MAP(SigmoidGrad) = EMPTY_ATTR_MAP; -OUTPUT_MAP(SigmoidGrad) = {{0, OUTPUT_DESC(z)}}; - -// L2NormalizeGrad -INPUT_MAP(L2NormalizeGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(dy)}}; -ATTR_MAP(L2NormalizeGrad) = { - {"axis", ATTR_DESC(dim, AnyTraits>(), AnyTraits>())}, - {"epsilon", ATTR_DESC(eps, AnyTraits())}}; -OUTPUT_MAP(L2NormalizeGrad) = {{0, OUTPUT_DESC(dx)}}; - -// LarsV2Update -INPUT_MAP(LarsV2Update) = {{1, INPUT_DESC(w)}, - {2, INPUT_DESC(g)}, - {3, INPUT_DESC(w_square_sum)}, - {4, INPUT_DESC(g_square_sum)}, - {5, INPUT_DESC(weight_decay)}, - {6, INPUT_DESC(learning_rate)}}; -ATTR_MAP(LarsV2Update) = {{"epsilon", ATTR_DESC(epsilon, AnyTraits())}, - {"hyperpara", ATTR_DESC(hyperpara, AnyTraits())}, - {"use_clip", ATTR_DESC(use_clip, AnyTraits())}}; -OUTPUT_MAP(LarsV2Update) = {{0, OUTPUT_DESC(g_new)}}; - -// L2Normalize -INPUT_MAP(L2Normalize) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(L2Normalize) = { - {"axis", ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}, - {"epsilon", ATTR_DESC(eps, AnyTraits())}}; -OUTPUT_MAP(L2Normalize) = {{0, OUTPUT_DESC(y)}}; - -// CumsumD -INPUT_MAP(CumsumD) = {{1, INPUT_DESC(x)}}; -INPUT_ATTR_MAP(CumsumD) = {{2, ATTR_DESC(axis, AnyTraits())}}; -ATTR_MAP(CumsumD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits())}, - {"reverse", ATTR_DESC(reverse, AnyTraits())}}; -OUTPUT_MAP(CumsumD) = {{0, OUTPUT_DESC(y)}}; - -// SoftmaxV2 -INPUT_MAP(SoftmaxV2) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(SoftmaxV2) = { - {"axis", ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}, -}; -OUTPUT_MAP(SoftmaxV2) = {{0, OUTPUT_DESC(y)}}; - -// SoftmaxGrad -INPUT_MAP(SoftmaxGrad) = {{1, INPUT_DESC(softmax)}, {2, INPUT_DESC(grad_softmax)}}; -OUTPUT_MAP(SoftmaxGrad) = {{0, OUTPUT_DESC(grad_x)}}; -ATTR_MAP(SoftmaxGrad) = EMPTY_ATTR_MAP; - -// Flatten -INPUT_MAP(Flatten) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(Flatten) = EMPTY_ATTR_MAP; -OUTPUT_MAP(Flatten) = {{0, OUTPUT_DESC(y)}}; - -// add -INPUT_MAP(Add) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}}; -ATTR_MAP(Add) = EMPTY_ATTR_MAP; -OUTPUT_MAP(Add) = {{0, OUTPUT_DESC(y)}}; - -// GatherV2 -INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(axis)}}; -ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; -OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; - -// ReduceSumD -INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; -INPUT_ATTR_MAP(ReduceSumD) = { - {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; -ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}}; - -// ReduceProdD -INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}}; -INPUT_ATTR_MAP(ReduceProdD) = { - {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; -ATTR_MAP(ReduceProdD) = {{"keep_dims", ATTR_DESC(keep_dims, 
AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(ReduceProdD) = {{0, OUTPUT_DESC(y)}};
-
-// CumprodD
-INPUT_MAP(CumprodD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(CumprodD) = {{2, ATTR_DESC(axis, AnyTraits<int64_t>())}};
-ATTR_MAP(CumprodD) = {{"exclusive", ATTR_DESC(exclusive, AnyTraits<bool>())},
-                      {"reverse", ATTR_DESC(reverse, AnyTraits<bool>())}};
-OUTPUT_MAP(CumprodD) = {{0, OUTPUT_DESC(y)}};
-
-// SoftmaxCrossEntropyWithLogits
-INPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(labels)}};
-ATTR_MAP(SoftmaxCrossEntropyWithLogits) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(SoftmaxCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(backprop)}};
-
-// MeanGrad
-INPUT_MAP(MeanGrad) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(MeanGrad) = {{2, ATTR_DESC(mean_grad_output_shape_value, kOpFormat_NHWC,
-                                          AnyTraits<std::vector<int64_t>>(), AnyTraits<int64_t>())}};
-ATTR_MAP(MeanGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<int64_t>())}};
-
-INPUT_MAP(SliceD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(SliceD) = {{2, ATTR_DESC(offsets, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                          {3, ATTR_DESC(size, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(SliceD) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(SliceD) = {{0, OUTPUT_DESC(y)}};
-
-// MaxPool
-INPUT_MAP(MaxPool) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(MaxPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                     {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                     {"padding", ATTR_DESC(padding, AnyTraits<std::string>())},
-                     {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(MaxPool) = {{0, OUTPUT_DESC(y)}};
-
-// AvgPool
-INPUT_MAP(AvgPool) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(AvgPool) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                     {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                     {"padding", ATTR_DESC(padding, AnyTraits<std::string>())},
-                     {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(AvgPool) = {{0, OUTPUT_DESC(y)}};
-
-// GreaterEqual
-INPUT_MAP(GreaterEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(GreaterEqual) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(GreaterEqual) = {{0, OUTPUT_DESC(y)}};
-
-// AssignAdd
-INPUT_MAP(AssignAdd) = {{1, INPUT_DESC(ref)}, {2, INPUT_DESC(value)}};
-ATTR_MAP(AssignAdd) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(AssignAdd) = {{0, OUTPUT_DESC(ref)}};
-
-// AssignSub
-INPUT_MAP(AssignSub) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(value)}};
-ATTR_MAP(AssignSub) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(AssignSub) = {{0, OUTPUT_DESC(var)}};
-
-// Cos
-INPUT_MAP(Cos) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Cos) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Cos) = {{0, OUTPUT_DESC(y)}};
-
-// Acos
-INPUT_MAP(Acos) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Acos) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Acos) = {{0, OUTPUT_DESC(y)}};
-
-// AcosGrad
-INPUT_MAP(AcosGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
-ATTR_MAP(AcosGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(AcosGrad) = {{0, OUTPUT_DESC(z)}};
-
-// Acosh
-INPUT_MAP(Acosh) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Acosh) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Acosh) = {{0, OUTPUT_DESC(y)}};
-
-// AcoshGrad
-INPUT_MAP(AcoshGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
-ATTR_MAP(AcoshGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(AcoshGrad) = {{0, OUTPUT_DESC(z)}};
-
-// Floor
-INPUT_MAP(Floor) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Floor) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Floor) = {{0, OUTPUT_DESC(y)}};
-
-// FloorDiv
-INPUT_MAP(FloorDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(FloorDiv) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(FloorDiv) = {{0, OUTPUT_DESC(y)}};
-
-// FloorMod
-INPUT_MAP(FloorMod) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(FloorMod) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(FloorMod) = {{0, OUTPUT_DESC(y)}};
-
-// Sin
-INPUT_MAP(Sin) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Sin) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Sin) = {{0, OUTPUT_DESC(y)}};
-
-// Exp
-INPUT_MAP(Exp) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Exp) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Exp) = {{0, OUTPUT_DESC(y)}};
-
-// BoundingBoxEncode
-INPUT_MAP(BoundingBoxEncode) = {
-  {1, INPUT_DESC(anchor_box)},
-  {2, INPUT_DESC(ground_truth_box)},
-};
-ATTR_MAP(BoundingBoxEncode) = {
-  {"means", ATTR_DESC(means, AnyTraits<std::vector<float>>(), AnyTraits<float>())},
-  {"stds", ATTR_DESC(stds, AnyTraits<std::vector<float>>(), AnyTraits<float>())},
-};
-OUTPUT_MAP(BoundingBoxEncode) = {{0, OUTPUT_DESC(delats)}};
-
-// BoundingBoxDecode
-INPUT_MAP(BoundingBoxDecode) = {
-  {1, INPUT_DESC(rois)},
-  {2, INPUT_DESC(deltas)},
-};
-ATTR_MAP(BoundingBoxDecode) = {
-  {"means", ATTR_DESC(means, AnyTraits<std::vector<float>>(), AnyTraits<float>())},
-  {"stds", ATTR_DESC(stds, AnyTraits<std::vector<float>>(), AnyTraits<float>())},
-  {"max_shape", ATTR_DESC(max_shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"wh_ratio_clip", ATTR_DESC(wh_ratio_clip, AnyTraits<float>())},
-};
-OUTPUT_MAP(BoundingBoxDecode) = {{0, OUTPUT_DESC(bboxes)}};
-
-// TopK
-INPUT_MAP(TopK) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(k)}};
-ATTR_MAP(TopK) = {{"sorted", ATTR_DESC(sorted, AnyTraits<bool>())}};
-OUTPUT_MAP(TopK) = {{0, OUTPUT_DESC(values)}, {1, OUTPUT_DESC(indices)}};
-
-// Multiply
-INPUT_MAP(Multiply) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}};
-ATTR_MAP(Multiply) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Multiply) = {{0, OUTPUT_DESC(z)}};
-
-// TileD
-INPUT_MAP(TileD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(TileD) = {{2, ATTR_DESC(multiples, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(TileD) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(TileD) = {{0, OUTPUT_DESC(y)}};
-
-// OneHot
-INPUT_MAP(OneHot) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(depth)}, {3, INPUT_DESC(on_value)}, {4, INPUT_DESC(off_value)}};
-ATTR_MAP(OneHot) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}};
-OUTPUT_MAP(OneHot) = {{0, OUTPUT_DESC(y)}};
-
-// GatherV2D
-INPUT_MAP(GatherV2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}};
-INPUT_ATTR_MAP(GatherV2D) = {{3, ATTR_DESC(axis, AnyTraits<int64_t>())}};
-ATTR_MAP(GatherV2D) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(GatherV2D) = {{0, OUTPUT_DESC(y)}};
-
-// Reshape
-INPUT_MAP(Reshape) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(shape)}};
-ATTR_MAP(Reshape) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Reshape) = {{0, OUTPUT_DESC(y)}};
-
-// TransShape
-INPUT_MAP(TransShape) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(TransShape) = {{2, ATTR_DESC(outShape, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(TransShape) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(TransShape) = {{0, OUTPUT_DESC(y)}};
-
-// BiasAdd
-INPUT_MAP(BiasAdd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(bias)}};
-ATTR_MAP(BiasAdd) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(BiasAdd) = {{0, OUTPUT_DESC(y)}};
-
-// Iou
-INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}};
-ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
-OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}};
-
-// ResizeNearestNeighborV2D
-INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeNearestNeighborV2D) = {
-  {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}};
-
-// ResizeNearestNeighborV2Grad
-INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}};
-ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}};
-
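The deleted maps above all follow one pattern: for each GraphEngine operator, INPUT_MAP wires MindSpore input positions onto named GE inputs, ATTR_MAP converts primitive attributes, and OUTPUT_MAP names the result ports. A stand-alone sketch of that idea follows; GeOpStub, InputSetter and the integer tensor ids are deliberately simplified stand-ins, not the real OpAdapter API.

    // Simplified model of the adapter tables above (illustrative only).
    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    struct GeOpStub {  // stand-in for a generated ge::op type
      std::map<std::string, int> inputs;
    };

    // INPUT_MAP analogue: MindSpore input index -> setter on the GE op.
    using InputSetter = std::function<void(GeOpStub &, int tensor_id)>;

    int main() {
      std::map<int, InputSetter> input_map = {
        {1, [](GeOpStub &op, int t) { op.inputs["x1"] = t; }},
        {2, [](GeOpStub &op, int t) { op.inputs["x2"] = t; }},
      };
      GeOpStub greater_equal;
      input_map[1](greater_equal, 101);  // wire MindSpore input 1 to GE input "x1"
      input_map[2](greater_equal, 102);  // wire MindSpore input 2 to GE input "x2"
      std::cout << greater_equal.inputs.size() << " inputs wired\n";
    }
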
-// ApplyAdam
-INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
-                        {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
-                        {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)},
-                        {10, INPUT_DESC(grad)}};
-ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
-                       {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
-
-// ApplyAdamD
-INPUT_MAP(ApplyAdamD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
-                         {4, INPUT_DESC(beta1_power)}, {5, INPUT_DESC(beta2_power)}, {6, INPUT_DESC(lr)},
-                         {7, INPUT_DESC(beta1)}, {8, INPUT_DESC(beta2)}, {9, INPUT_DESC(epsilon)},
-                         {10, INPUT_DESC(grad)}};
-ATTR_MAP(ApplyAdamD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
-                        {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyAdamD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
-
-// Relu6
-INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Relu6) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Relu6) = {{0, OUTPUT_DESC(y)}};
-
-// Relu6Grad
-INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
-ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}};
-
-// ResizeBilinearV2Grad
-INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
-ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}};
-
-// ResizeBilinearV2D
-INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeBilinearV2D) = {
-  {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}};
-
-// ZerosLike
-INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ZerosLike) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(ZerosLike) = {{0, OUTPUT_DESC(y)}};
-
-// OnesLike
-INPUT_MAP(OnesLike) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(OnesLike) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(OnesLike) = {{0, OUTPUT_DESC(y)}};
-
-// NMSWithMask
-INPUT_MAP(NMSWithMask) = {{1, INPUT_DESC(box_scores)}};
-ATTR_MAP(NMSWithMask) = {{"iou_threshold", ATTR_DESC(iou_threshold, AnyTraits<float>())}};
-OUTPUT_MAP(NMSWithMask) = {
-  {0, OUTPUT_DESC(selected_boxes)}, {1, OUTPUT_DESC(selected_idx)}, {2, OUTPUT_DESC(selected_mask)}};
-
-// Unpack
-INPUT_MAP(Unpack) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Unpack) = {{"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}, {"num", ATTR_DESC(num, AnyTraits<int64_t>())}};
-DYN_OUTPUT_MAP(Unpack) = {{0, DYN_OUTPUT_DESC(y)}};
-
-// TensorScatterUpdate
-INPUT_MAP(TensorScatterUpdate) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
-ATTR_MAP(TensorScatterUpdate) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(TensorScatterUpdate) = {{0, OUTPUT_DESC(y)}};
-
-// ScatterUpdate
-INPUT_MAP(ScatterUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
-ATTR_MAP(ScatterUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ScatterUpdate) = {{0, OUTPUT_DESC(var)}};
-
-// ScatterNdUpdate
-INPUT_MAP(ScatterNdUpdate) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
-ATTR_MAP(ScatterNdUpdate) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ScatterNdUpdate) = {{0, OUTPUT_DESC(var)}};
-
-// ScatterMax
-INPUT_MAP(ScatterMax) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(updates)}};
-ATTR_MAP(ScatterMax) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ScatterMax) = {{0, OUTPUT_DESC(var)}};
-
-// CheckValid
-INPUT_MAP(CheckValid) = {{1, INPUT_DESC(bbox_tensor)}, {2, INPUT_DESC(img_metas)}};
-ATTR_MAP(CheckValid) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(CheckValid) = {{0, OUTPUT_DESC(valid_tensor)}};
-
-// SmoothL1Loss
-INPUT_MAP(SmoothL1Loss) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}};
-ATTR_MAP(SmoothL1Loss) = {{"sigma", ATTR_DESC(sigma, AnyTraits<float>())}};
-OUTPUT_MAP(SmoothL1Loss) = {{0, OUTPUT_DESC(loss)}};
-
-// SmoothL1LossGrad
-INPUT_MAP(SmoothL1LossGrad) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(label)}, {3, INPUT_DESC(dout)}};
-ATTR_MAP(SmoothL1LossGrad) = {{"sigma", ATTR_DESC(sigma, AnyTraits<float>())}};
-OUTPUT_MAP(SmoothL1LossGrad) = {{0, OUTPUT_DESC(gradient)}};
-
-// SigmoidCrossEntropyWithLogits
-INPUT_MAP(SigmoidCrossEntropyWithLogits) = {{1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}};
-ATTR_MAP(SigmoidCrossEntropyWithLogits) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(SigmoidCrossEntropyWithLogits) = {{0, OUTPUT_DESC(loss)}};
-
-// SigmoidCrossEntropyWithLogitsGrad
-INPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {
-  {1, INPUT_DESC(predict)}, {2, INPUT_DESC(target)}, {3, INPUT_DESC(dout)}};
-ATTR_MAP(SigmoidCrossEntropyWithLogitsGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(SigmoidCrossEntropyWithLogitsGrad) = {{0, OUTPUT_DESC(gradient)}};
-
-// ScatterNdD
-INPUT_MAP(ScatterNdD) = {{1, INPUT_DESC(indices)}, {2, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ScatterNdD) = {
-  {3, ATTR_DESC(shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ScatterNdD) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(ScatterNdD) = {{0, OUTPUT_DESC(y)}};
-
-// PadD
-INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>())}};
-OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}};
-
-// MirrorPad
-INPUT_MAP(MirrorPad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
-ATTR_MAP(MirrorPad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
-OUTPUT_MAP(MirrorPad) = {{0, OUTPUT_DESC(y)}};
-
-// MirrorPadGrad
-INPUT_MAP(MirrorPadGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
-ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
-OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}};
-
-// GatherNd
-INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}};
-ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(GatherNd) = {{0, OUTPUT_DESC(y)}};
-
-// ROIAlign
-INPUT_MAP(ROIAlign) = {{1, INPUT_DESC(features)}, {2, INPUT_DESC(rois)}};
-OUTPUT_MAP(ROIAlign) = {{0, OUTPUT_DESC(y)}};
-ATTR_MAP(ROIAlign) = {{"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int64_t>())},
-                      {"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int64_t>())},
-                      {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())},
-                      {"sample_num", ATTR_DESC(sample_num, AnyTraits<int64_t>())},
-                      {"roi_end_mode", ATTR_DESC(roi_end_mode, AnyTraits<int64_t>())}};
-
-// ROIAlignGrad
-INPUT_MAP(ROIAlignGrad) = {{1, INPUT_DESC(ydiff)}, {2, INPUT_DESC(rois)}};
-OUTPUT_MAP(ROIAlignGrad) = {{0, OUTPUT_DESC(xdiff)}};
-ATTR_MAP(ROIAlignGrad) = {
-  {"xdiff_shape", ATTR_DESC(xdiff_shape, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"pooled_height", ATTR_DESC(pooled_height, AnyTraits<int64_t>())},
-  {"pooled_width", ATTR_DESC(pooled_width, AnyTraits<int64_t>())},
-  {"spatial_scale", ATTR_DESC(spatial_scale, AnyTraits<float>())},
-  {"sample_num", ATTR_DESC(sample_num, AnyTraits<int64_t>())}};
-
-// ArgMaxD
-INPUT_MAP(ArgMaxD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ArgMaxD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
-                     {"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
-OUTPUT_MAP(ArgMaxD) = {{0, OUTPUT_DESC(y)}};
-
-// ArgMinD
-INPUT_MAP(ArgMinD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ArgMinD) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
-                     {"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
-OUTPUT_MAP(ArgMinD) = {{0, OUTPUT_DESC(y)}};
-
-// ArgMaxWithValue
-INPUT_MAP(ArgMaxWithValue) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ArgMaxWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
-                             {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ArgMaxWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};
-
-// ArgMinWithValue
-INPUT_MAP(ArgMinWithValue) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
-                             {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};
-
-// ReduceAllD
-INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceAllD) = {
-  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};
-
-// ReduceMeanD
-INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceMeanD) = {
-  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceMeanD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceMeanD) = {{0, OUTPUT_DESC(y)}};
-
-// HCOMAllReduce
-INPUT_MAP(HcomAllReduce) = {{1, INPUT_DESC(x)}};
-OUTPUT_MAP(HcomAllReduce) = {{0, OUTPUT_DESC(y)}};
-ATTR_MAP(HcomAllReduce) = {{"op", ATTR_DESC(reduction, AnyTraits<std::string>())},
-                           {"group", ATTR_DESC(group, AnyTraits<std::string>())},
-                           {"fusion", ATTR_DESC(fusion, AnyTraits<int64_t>())}};
-
-// HCOMBroadcast
-INPUT_MAP(HcomBroadcast) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(HcomBroadcast) = {{1, DYN_INPUT_DESC(x)}};
-DYN_OUTPUT_MAP(HcomBroadcast) = {{0, DYN_OUTPUT_DESC(y)}};
-ATTR_MAP(HcomBroadcast) = {{"root_rank", ATTR_DESC(root_rank, AnyTraits<int64_t>())},
-                           {"group", ATTR_DESC(group, AnyTraits<std::string>())}};
-
-// HCOMAllGather
-INPUT_MAP(HcomAllGather) = {{1, INPUT_DESC(x)}};
-OUTPUT_MAP(HcomAllGather) = {{0, OUTPUT_DESC(y)}};
-ATTR_MAP(HcomAllGather) = {{"group", ATTR_DESC(group, AnyTraits<std::string>())},
-                           {"rank_size", ATTR_DESC(rank_size, AnyTraits<int64_t>())}};
-
-// HCOMReduceScatter
-INPUT_MAP(HcomReduceScatter) = {{1, INPUT_DESC(x)}};
-OUTPUT_MAP(HcomReduceScatter) = {{0, OUTPUT_DESC(y)}};
-ATTR_MAP(HcomReduceScatter) = {{"group", ATTR_DESC(group, AnyTraits<std::string>())},
-                               {"op", ATTR_DESC(reduction, AnyTraits<std::string>())},
-                               {"rank_size", ATTR_DESC(rank_size, AnyTraits<int64_t>())}};
-
-// Variable
-INPUT_MAP(Variable) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Variable) = EMPTY_ATTR_MAP;
-
-// ReluGrad
-INPUT_MAP(ReluGrad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
-ATTR_MAP(ReluGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(ReluGrad) = {{0, OUTPUT_DESC(backprops)}};
-
-// BiasAddGrad
-INPUT_MAP(BiasAddGrad) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(BiasAddGrad) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(BiasAddGrad) = {{0, OUTPUT_DESC(y)}};
-
-// MaxPoolGrad
-INPUT_MAP(MaxPoolGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grad)}};
-ATTR_MAP(MaxPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                         {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                         {"padding", ATTR_DESC(padding, AnyTraits<std::string>())},
-                         {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(MaxPoolGrad) = {{0, OUTPUT_DESC(y)}};
-
-// AvgPoolGrad
-INPUT_MAP(AvgPoolGrad) = {{1, INPUT_DESC(orig_input_shape)}, {2, INPUT_DESC(input_grad)}};
-ATTR_MAP(AvgPoolGrad) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                         {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                         {"padding", ATTR_DESC(padding, AnyTraits<std::string>())},
-                         {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())}};
-OUTPUT_MAP(AvgPoolGrad) = {{0, OUTPUT_DESC(out_grad)}};
-
-// MaxPoolWithArgmax
-INPUT_MAP(MaxPoolWithArgmax) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(MaxPoolWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                               {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                               {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
-OUTPUT_MAP(MaxPoolWithArgmax) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(argmax)}};
-
-// MaxPoolGradWithArgmax
-INPUT_MAP(MaxPoolGradWithArgmax) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}, {3, INPUT_DESC(argmax)}};
-ATTR_MAP(MaxPoolGradWithArgmax) = {{"ksize", ATTR_DESC(ksize, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                                   {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                                   {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
-OUTPUT_MAP(MaxPoolGradWithArgmax) = {{0, OUTPUT_DESC(y)}};
-
-// ExtractImagePatches
-INPUT_MAP(ExtractImagePatches) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ExtractImagePatches) = {{"ksizes", ATTR_DESC(ksizes, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                                 {"strides", ATTR_DESC(strides, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                                 {"rates", ATTR_DESC(rates, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())},
-                                 {"padding", ATTR_DESC(padding, AnyTraits<std::string>())}};
-OUTPUT_MAP(ExtractImagePatches) = {{0, OUTPUT_DESC(y)}};
-
-// Conv2D
-INPUT_MAP(Conv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
-ATTR_MAP(Conv2D) = {
-  {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
-  {"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
-};
-OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}};
-
-// Conv2DBackpropInputD
-INPUT_MAP(Conv2DBackpropInputD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(filter)}};
-INPUT_ATTR_MAP(Conv2DBackpropInputD) = {
-  {3, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(Conv2DBackpropInputD) = {
-  {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
-  {"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
-};
-OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}};
-
-// Conv2DBackpropFilterD
-INPUT_MAP(Conv2DBackpropFilterD) = {{1, INPUT_DESC(out_backprop)}, {2, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(Conv2DBackpropFilterD) = {
-  {3, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(Conv2DBackpropFilterD) = {
-  {"pad_list", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"stride", ATTR_DESC(strides, "pad", AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
-  {"group", ATTR_DESC(groups, AnyTraits<int64_t>())},
-};
-OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}};
-
-// DepthwiseConv2D
-INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}};
-ATTR_MAP(DepthwiseConv2D) = {
-  {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
-};
-OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}};
-
-// DepthwiseConv2DBackpropInputD
-INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_DESC(out_backprop)}};
-INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = {
-  {1, ATTR_DESC(input_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(DepthwiseConv2DBackpropInputD) = {
-  {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-};
-OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}};
-
-// DepthwiseConv2DBackpropFilterD
-INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_DESC(out_backprop)}};
-INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
-  {2, ATTR_DESC(filter_size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(DepthwiseConv2DBackpropFilterD) = {
-  {"stride", ATTR_DESC(strides, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"pads", ATTR_DESC(pads, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-  {"dilation", ATTR_DESC(dilations, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
-};
-OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}};
-
-// MatMulV2
-INPUT_MAP(MatMulV2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(MatMulV2) = {{"transpose_a", ATTR_DESC(transpose_x1, AnyTraits<bool>())},
-                      {"transpose_b", ATTR_DESC(transpose_x2, AnyTraits<bool>())}};
-OUTPUT_MAP(MatMulV2) = {{0, OUTPUT_DESC(y)}};
-
-// Merge
-INPUT_MAP(Merge) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(Merge) = {{1, DYN_INPUT_DESC(x)}};
-ATTR_MAP(Merge) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Merge) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(value_index)}};
-
-// Switch
-INPUT_MAP(Switch) = {{1, INPUT_DESC(data)}, {2, INPUT_DESC(pred)}};
-OUTPUT_MAP(Switch) = {{0, OUTPUT_DESC(output_false)}, {1, OUTPUT_DESC(output_true)}};
-ATTR_MAP(Switch) = EMPTY_ATTR_MAP;
-
-// AddN
-INPUT_MAP(AddN) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(AddN) = {{1, DYN_INPUT_DESC(x)}};
-ATTR_MAP(AddN) = {{"n", ATTR_DESC(N, AnyTraits<int64_t>())}};
-OUTPUT_MAP(AddN) = {{0, OUTPUT_DESC(y)}};
-
-// Mul
-INPUT_MAP(Mul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Mul) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Mul) = {{0, OUTPUT_DESC(y)}};
-
-// RealDiv
-INPUT_MAP(RealDiv) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(RealDiv) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}};
-
-// Cast
-INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
-ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
-
-// Case
-INPUT_MAP(Case) = {{1, INPUT_DESC(branch_index)}};
-DYN_INPUT_MAP(Case) = {{2, DYN_INPUT_DESC(input)}};
-ATTR_MAP(Case) = EMPTY_ATTR_MAP;
-DYN_OUTPUT_MAP(Case) = {{0, DYN_OUTPUT_DESC(output)}};
-DYN_SUBGRAPH_MAP(Case) = {{0, DYN_SUBGRAPH_DESC(branches)}};
-
-// Reciprocal
-INPUT_MAP(Reciprocal) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Reciprocal) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Reciprocal) = {{0, OUTPUT_DESC(y)}};
-
-// Sub
-INPUT_MAP(Sub) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Sub) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Sub) = {{0, OUTPUT_DESC(y)}};
-
-// SplitD
-INPUT_MAP(SplitD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(SplitD) = {{"axis", ATTR_DESC(split_dim, AnyTraits<int64_t>())},
-                    {"output_num", ATTR_DESC(num_split, AnyTraits<int64_t>())}};
-DYN_OUTPUT_MAP(SplitD) = {{0, DYN_OUTPUT_DESC(y)}};
-
-// RangeD
-INPUT_MAP(RangeD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(RangeD) = {{"start", ATTR_DESC(start, AnyTraits<float>())},
-                    {"limit", ATTR_DESC(limit, AnyTraits<float>())},
-                    {"delta", ATTR_DESC(delta, AnyTraits<float>())}};
-OUTPUT_MAP(RangeD) = {{0, OUTPUT_DESC(y)}};
-
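Several adapters above (TileD, GatherV2D, Cast, the ReduceXxxD family) use INPUT_ATTR_MAP instead of INPUT_MAP for some positions: the n-th MindSpore input is expected to be a constant and is folded into a GE attribute rather than becoming a tensor edge. A self-contained sketch of that folding, with hypothetical simplified types:

    // INPUT_ATTR_MAP analogue: a constant input becomes an attribute.
    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <string>
    #include <variant>
    #include <vector>

    using AttrValue = std::variant<int64_t, std::vector<int64_t>>;

    struct GeOpStub {
      std::map<std::string, AttrValue> attrs;
    };

    int main() {
      // Input 2 of a Tile-like node is a constant multiplier vector; the
      // adapter moves it into the "multiples" attr instead of wiring an edge.
      std::vector<int64_t> constant_input_2 = {2, 2, 1, 1};
      GeOpStub tile_d;
      tile_d.attrs["multiples"] = constant_input_2;
      assert(std::get<std::vector<int64_t>>(tile_d.attrs["multiples"]).size() == 4);
    }
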
-// Neg
-INPUT_MAP(Neg) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Neg) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Neg) = {{0, OUTPUT_DESC(y)}};
-
-// Transpose
-INPUT_MAP(TransposeD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(TransposeD) = {{2, ATTR_DESC(perm, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(TransposeD) = EMPTY_ATTR_MAP;
-// Do not set Transpose operator output descriptor
-
-// DropOutGenMask
-INPUT_MAP(DropOutGenMask) = {{1, INPUT_DESC(shape)}, {2, INPUT_DESC(prob)}};
-ATTR_MAP(DropOutGenMask) = {{"Seed0", ATTR_DESC(seed, AnyTraits<int64_t>())},
-                            {"Seed1", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
-OUTPUT_MAP(DropOutGenMask) = {{0, OUTPUT_DESC(y)}};
-
-// Pack
-INPUT_MAP(Pack) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(Pack) = {{1, DYN_INPUT_DESC(x)}};
-ATTR_MAP(Pack) = {{"num", ATTR_DESC(N, AnyTraits<int64_t>())}, {"axis", ATTR_DESC(axis, AnyTraits<int64_t>())}};
-OUTPUT_MAP(Pack) = {{0, OUTPUT_DESC(y)}};
-
-// ConcatD
-INPUT_MAP(ConcatD) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(ConcatD) = {{1, DYN_INPUT_DESC(x)}};
-ATTR_MAP(ConcatD) = {
-  {"axis", ATTR_DESC(concat_dim, AnyTraits<int64_t>())},
-  {"inputNums", ATTR_DESC(N, AnyTraits<int64_t>())},
-};
-OUTPUT_MAP(ConcatD) = {{0, OUTPUT_DESC(y)}};
-
-// Less
-INPUT_MAP(Less) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Less) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Less) = {{0, OUTPUT_DESC(y)}};
-
-// Rsqrt
-INPUT_MAP(Rsqrt) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Rsqrt) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Rsqrt) = {{0, OUTPUT_DESC(y)}};
-
-// Sqrt
-INPUT_MAP(Sqrt) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Sqrt) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Sqrt) = {{0, OUTPUT_DESC(y)}};
-
-// Square
-INPUT_MAP(Square) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Square) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Square) = {{0, OUTPUT_DESC(y)}};
-
-// SquareSumAll
-INPUT_MAP(SquareSumAll) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(SquareSumAll) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(SquareSumAll) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
-
-// Tanh
-INPUT_MAP(Tanh) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Tanh) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Tanh) = {{0, OUTPUT_DESC(y)}};
-
-// TanhGrad
-INPUT_MAP(TanhGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
-ATTR_MAP(TanhGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(TanhGrad) = {{0, OUTPUT_DESC(z)}};
-
-// ReduceMinD
-INPUT_MAP(ReduceMinD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceMinD) = {
-  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceMinD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceMinD) = {{0, OUTPUT_DESC(y)}};
-
-// ReduceMaxD
-INPUT_MAP(ReduceMaxD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceMaxD) = {
-  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceMaxD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceMaxD) = {{0, OUTPUT_DESC(y)}};
-
-// Maximum
-INPUT_MAP(Maximum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Maximum) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Maximum) = {{0, OUTPUT_DESC(y)}};
-
-// Minimum
-INPUT_MAP(Minimum) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Minimum) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Minimum) = {{0, OUTPUT_DESC(y)}};
-
-// MaximumGrad
-INPUT_MAP(MaximumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}};
-ATTR_MAP(MaximumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits<bool>())},
-                         {"grad_y", ATTR_DESC(grad_y, AnyTraits<bool>())}};
-OUTPUT_MAP(MaximumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
-
-// MinimumGrad
-INPUT_MAP(MinimumGrad) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}, {3, INPUT_DESC(grads)}};
-ATTR_MAP(MinimumGrad) = {{"grad_x", ATTR_DESC(grad_x, AnyTraits<bool>())},
-                         {"grad_y", ATTR_DESC(grad_y, AnyTraits<bool>())}};
-OUTPUT_MAP(MinimumGrad) = {{0, OUTPUT_DESC(y1)}, {1, OUTPUT_DESC(y2)}};
-
-// Pow
-INPUT_MAP(Pow) = {
-  {1, INPUT_DESC(x1)},
-  {2, INPUT_DESC(x2)},
-};
-ATTR_MAP(Pow) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Pow) = {{0, OUTPUT_DESC(y)}};
-
-// Equal
-INPUT_MAP(Equal) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Equal) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Equal) = {{0, OUTPUT_DESC(y)}};
-
-// NotEqual
-INPUT_MAP(NotEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(NotEqual) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(NotEqual) = {{0, OUTPUT_DESC(y)}};
-
-// Log
-INPUT_MAP(Log) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Log) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Log) = {{0, OUTPUT_DESC(y)}};
-
-// LogicalAnd
-INPUT_MAP(LogicalAnd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(LogicalAnd) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(LogicalAnd) = {{0, OUTPUT_DESC(y)}};
-
-// LogicalOr
-INPUT_MAP(LogicalOr) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(LogicalOr) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(LogicalOr) = {{0, OUTPUT_DESC(y)}};
-
-// LogicalNot
-INPUT_MAP(LogicalNot) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(LogicalNot) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(LogicalNot) = {{0, OUTPUT_DESC(y)}};
-
-// Greater
-INPUT_MAP(Greater) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Greater) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Greater) = {{0, OUTPUT_DESC(y)}};
-
-// LogSoftmaxGrad
-INPUT_MAP(LogSoftmaxGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(grad)}};
-ATTR_MAP(LogSoftmaxGrad) = {
-  {"axis", ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(LogSoftmaxGrad) = {{0, OUTPUT_DESC(y)}};
-
-// Select
-INPUT_MAP(Select) = {{1, INPUT_DESC(condition)}, {2, INPUT_DESC(x1)}, {3, INPUT_DESC(x2)}};
-ATTR_MAP(Select) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Select) = {{0, OUTPUT_DESC(y)}};
-
-// LessEqual
-INPUT_MAP(LessEqual) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(LessEqual) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(LessEqual) = {{0, OUTPUT_DESC(y)}};
-
-// LogSoftmaxV2
-INPUT_MAP(LogSoftmaxV2) = {{1, INPUT_DESC(logits)}};
-ATTR_MAP(LogSoftmaxV2) = {
-  {"axis", ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(LogSoftmaxV2) = {{0, OUTPUT_DESC(logsoftmax)}};
-
-// RandomChoiceWithMask
-INPUT_MAP(RandomChoiceWithMask) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(RandomChoiceWithMask) = {{"count", ATTR_DESC(count, AnyTraits<int64_t>())},
-                                  {"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
-                                  {"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
-OUTPUT_MAP(RandomChoiceWithMask) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mask)}};
-
-// TruncatedNormal
-INPUT_MAP(TruncatedNormal) = {{1, INPUT_DESC(shape)}};
-ATTR_MAP(TruncatedNormal) = {{"seed", ATTR_DESC(seed, AnyTraits<int64_t>())},
-                             {"seed2", ATTR_DESC(seed2, AnyTraits<int64_t>())}};
-OUTPUT_MAP(TruncatedNormal) = {{0, OUTPUT_DESC(y)}};
-
-// StridedSliceGrad
-INPUT_MAP(StridedSliceGrad) = {
-  {1, INPUT_DESC(dy)}, {2, INPUT_DESC(shape)}, {3, INPUT_DESC(begin)}, {4, INPUT_DESC(end)}, {5, INPUT_DESC(strides)}};
-ATTR_MAP(StridedSliceGrad) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t>())},
-                              {"end_mask", ATTR_DESC(end_mask, AnyTraits<int64_t>())},
-                              {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits<int64_t>())},
-                              {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits<int64_t>())},
-                              {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits<int64_t>())}};
-OUTPUT_MAP(StridedSliceGrad) = {{0, OUTPUT_DESC(output)}};
-
-// Gelu
-INPUT_MAP(Gelu) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Gelu) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Gelu) = {{0, OUTPUT_DESC(y)}};
-
-// GeluGrad
-INPUT_MAP(GeluGrad) = {{1, INPUT_DESC(dy)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(y)}};
-ATTR_MAP(GeluGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(GeluGrad) = {{0, OUTPUT_DESC(z)}};
-
-// StridedSlice
-INPUT_MAP(StridedSlice) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(begin)}, {3, INPUT_DESC(end)}, {4, INPUT_DESC(strides)}};
-ATTR_MAP(StridedSlice) = {{"begin_mask", ATTR_DESC(begin_mask, AnyTraits<int64_t>())},
-                          {"end_mask", ATTR_DESC(end_mask, AnyTraits<int64_t>())},
-                          {"ellipsis_mask", ATTR_DESC(ellipsis_mask, AnyTraits<int64_t>())},
-                          {"new_axis_mask", ATTR_DESC(new_axis_mask, AnyTraits<int64_t>())},
-                          {"shrink_axis_mask", ATTR_DESC(shrink_axis_mask, AnyTraits<int64_t>())}};
-OUTPUT_MAP(StridedSlice) = {{0, OUTPUT_DESC(y)}};
-
-// UnsortedSegmentSum
-INPUT_MAP(UnsortedSegmentSumD) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}};
-INPUT_ATTR_MAP(UnsortedSegmentSumD) = {{3, ATTR_DESC(num_segments, AnyTraits<int64_t>())}};
-ATTR_MAP(UnsortedSegmentSumD) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(UnsortedSegmentSumD) = {{0, OUTPUT_DESC(y)}};
-
-// UnsortedSegmentMin
-INPUT_MAP(UnsortedSegmentMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(segment_ids)}, {3, INPUT_DESC(num_segments)}};
-ATTR_MAP(UnsortedSegmentMin) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(UnsortedSegmentMin) = {{0, OUTPUT_DESC(y)}};
-
-// ExpandDims
-INPUT_MAP(ExpandDims) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}};
-ATTR_MAP(ExpandDims) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(ExpandDims) = {{0, OUTPUT_DESC(y)}};
-
-// Squeeze
-INPUT_MAP(Squeeze) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Squeeze) = {{"axis", ATTR_DESC(axis, AnyTraits<int>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(Squeeze) = {{0, OUTPUT_DESC(y)}};
-
-// SGD
-INPUT_MAP(SGD) = {{1, INPUT_DESC(parameters)}, {2, INPUT_DESC(gradient)}, {3, INPUT_DESC(learning_rate)},
-                  {4, INPUT_DESC(accum)}, {5, INPUT_DESC(momentum)}, {6, INPUT_DESC(stat)}};
-ATTR_MAP(SGD) = {{"dampening", ATTR_DESC(dampening, AnyTraits<float>())},
-                 {"weight_decay", ATTR_DESC(weight_decay, AnyTraits<float>())},
-                 {"nesterov", ATTR_DESC(nesterov, AnyTraits<bool>())}};
-OUTPUT_MAP(SGD) = {{0, OUTPUT_DESC(parameters)}};
-
-// LayerNorm
-INPUT_MAP(LayerNorm) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(gamma)}, {3, INPUT_DESC(beta)}};
-ATTR_MAP(LayerNorm) = {{"begin_norm_axis", ATTR_DESC(begin_norm_axis, AnyTraits<int64_t>())},
-                       {"begin_params_axis", ATTR_DESC(begin_params_axis, AnyTraits<int64_t>())},
-                       {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())}};
-OUTPUT_MAP(LayerNorm) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(mean)}, {2, OUTPUT_DESC(variance)}};
-
-// LayerNormGrad
-INPUT_MAP(LayerNormGrad) = {
-  {1, INPUT_DESC(x)}, {2, INPUT_DESC(dy)}, {3, INPUT_DESC(variance)}, {4, INPUT_DESC(mean)}, {5, INPUT_DESC(gamma)}};
-ATTR_MAP(LayerNormGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(LayerNormGrad) = {{0, OUTPUT_DESC(pd_x)}, {1, OUTPUT_DESC(pd_gamma)}, {2, OUTPUT_DESC(pd_beta)}};
-
-// BatchMatMul
-INPUT_MAP(BatchMatMul) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(BatchMatMul) = {{"transpose_x1", ATTR_DESC(adj_x1, AnyTraits<bool>())},
-                         {"transpose_x2", ATTR_DESC(adj_x2, AnyTraits<bool>())}};
-OUTPUT_MAP(BatchMatMul) = {{0, OUTPUT_DESC(y)}};
-
-// DropOutDoMask
-INPUT_MAP(DropOutDoMask) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(mask)}, {3, INPUT_DESC(keep_prob)}};
-ATTR_MAP(DropOutDoMask) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(DropOutDoMask) = {{0, OUTPUT_DESC(y)}};
-
-// NPUGetFloatStatus
-INPUT_MAP(NPUGetFloatStatus) = {{1, INPUT_DESC(addr)}};
-OUTPUT_MAP(NPUGetFloatStatus) = {{0, OUTPUT_DESC(data)}};
-ATTR_MAP(NPUGetFloatStatus) = EMPTY_ATTR_MAP;
-
-// NPUAllocFloatStatus
-INPUT_MAP(NPUAllocFloatStatus) = EMPTY_INPUT_MAP;
-ATTR_MAP(NPUAllocFloatStatus) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(NPUAllocFloatStatus) = {{0, OUTPUT_DESC(data)}};
-
-// NPUClearFloatStatus
-INPUT_MAP(NPUClearFloatStatus) = {{1, INPUT_DESC(addr)}};
-OUTPUT_MAP(NPUClearFloatStatus) = {{0, OUTPUT_DESC(data)}};
-ATTR_MAP(NPUClearFloatStatus) = EMPTY_ATTR_MAP;
-
-// Abs
-INPUT_MAP(Abs) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Abs) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Abs) = {{0, OUTPUT_DESC(y)}};
-
-// AbsGrad
-INPUT_MAP(AbsGrad) = {{1, INPUT_DESC(y)}, {2, INPUT_DESC(dy)}};
-ATTR_MAP(AbsGrad) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(AbsGrad) = {{0, OUTPUT_DESC(z)}};
-
-// BinaryCrossEntropy
-INPUT_MAP(BinaryCrossEntropy) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(weight)}};
-ATTR_MAP(BinaryCrossEntropy) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
-OUTPUT_MAP(BinaryCrossEntropy) = {{0, OUTPUT_DESC(output)}};
-
-// BinaryCrossEntropyGrad
-INPUT_MAP(BinaryCrossEntropyGrad) = {
-  {1, INPUT_DESC(x)}, {2, INPUT_DESC(y)}, {3, INPUT_DESC(grad_output)}, {4, INPUT_DESC(weight)}};
-ATTR_MAP(BinaryCrossEntropyGrad) = {{"reduction", ATTR_DESC(reduction, AnyTraits<std::string>())}};
-OUTPUT_MAP(BinaryCrossEntropyGrad) = {{0, OUTPUT_DESC(output)}};
-
-// SparseApplyAdagradD
-INPUT_MAP(SparseApplyAdagradD) = {
-  {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
-ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
-                                 {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
-
-// ApplyProximalAdagradD
-INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
-                                    {4, INPUT_DESC(l1)}, {5, INPUT_DESC(l2)}, {6, INPUT_DESC(grad)}};
-ATTR_MAP(ApplyProximalAdagradD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyProximalAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
-
-// SparseApplyFtrlD
-INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
-                               {2, INPUT_DESC(accum)},
-                               {3, INPUT_DESC(linear)},
-                               {4, INPUT_DESC(grad)},
-                               {5, INPUT_DESC(indices)}};
-ATTR_MAP(SparseApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
-                              {"lr", ATTR_DESC(lr, AnyTraits<float>())},
-                              {"l1", ATTR_DESC(l1, AnyTraits<float>())},
-                              {"l2", ATTR_DESC(l2, AnyTraits<float>())},
-                              {"lr_power", ATTR_DESC(lr_power, AnyTraits<float>())}};
-OUTPUT_MAP(SparseApplyFtrlD) = {{0, OUTPUT_DESC(var)}};
-
-// SpaceToDepth
-INPUT_MAP(SpaceToDepth) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(SpaceToDepth) = {{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())}};
-OUTPUT_MAP(SpaceToDepth) = {{0, OUTPUT_DESC(y)}};
-
-// DepthToSpace
-INPUT_MAP(DepthToSpace) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(DepthToSpace) = {{"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())}};
-OUTPUT_MAP(DepthToSpace) = {{0, OUTPUT_DESC(y)}};
-
-// Sign
-INPUT_MAP(Sign) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Sign) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Sign) = {{0, OUTPUT_DESC(y)}};
-
-// Round
-INPUT_MAP(Round) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Round) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Round) = {{0, OUTPUT_DESC(y)}};
-
-// ApplyFtrlD
-INPUT_MAP(ApplyFtrlD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(linear)},
-                         {4, INPUT_DESC(grad)}, {5, INPUT_DESC(lr)}, {6, INPUT_DESC(l1)},
-                         {7, INPUT_DESC(l2)}, {8, INPUT_DESC(lr_power)}};
-ATTR_MAP(ApplyFtrlD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyFtrlD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}, {2, OUTPUT_DESC(linear)}};
-
-// Diag
-INPUT_MAP(Diag) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(Diag) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Diag) = {{0, OUTPUT_DESC(y)}};
-
-// DiagPart
-INPUT_MAP(DiagPart) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(DiagPart) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(DiagPart) = {{0, OUTPUT_DESC(y)}};
-
-// SpaceToBatchD
-INPUT_MAP(SpaceToBatchD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(SpaceToBatchD) = {
-  {"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
-  {"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(SpaceToBatchD) = {{0, OUTPUT_DESC(y)}};
-
-// BatchToSpaceD
-INPUT_MAP(BatchToSpaceD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(BatchToSpaceD) = {
-  {"block_size", ATTR_DESC(block_size, AnyTraits<int64_t>())},
-  {"crops", ATTR_DESC(crops, AnyTraits<std::vector<std::vector<int64_t>>>(), AnyTraits<std::vector<int64_t>>())}};
-OUTPUT_MAP(BatchToSpaceD) = {{0, OUTPUT_DESC(y)}};
-
-// Atan2
-INPUT_MAP(Atan2) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
-ATTR_MAP(Atan2) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(Atan2) = {{0, OUTPUT_DESC(y)}};
-
-// ApplyRMSPropD
-INPUT_MAP(ApplyRMSPropD) = {
-  {1, INPUT_DESC(var)}, {2, INPUT_DESC(ms)}, {3, INPUT_DESC(mom)}, {4, INPUT_DESC(lr)}, {5, INPUT_DESC(grad)}};
-INPUT_ATTR_MAP(ApplyRMSPropD) = {{6, ATTR_DESC(rho, AnyTraits<float>())},
-                                 {7, ATTR_DESC(momentum, AnyTraits<float>())},
-                                 {8, ATTR_DESC(epsilon, AnyTraits<float>())}};
-ATTR_MAP(ApplyRMSPropD) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyRMSPropD) = {{0, OUTPUT_DESC(var)}};
-
-// ApplyCenteredRMSProp
-INPUT_MAP(ApplyCenteredRMSProp) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(mg)}, {3, INPUT_DESC(ms)},
-                                   {4, INPUT_DESC(mom)}, {5, INPUT_DESC(grad)}, {6, INPUT_DESC(lr)},
-                                   {7, INPUT_DESC(rho)}, {8, INPUT_DESC(momentum)}, {9, INPUT_DESC(epsilon)}};
-ATTR_MAP(ApplyCenteredRMSProp) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(ApplyCenteredRMSProp) = {{0, OUTPUT_DESC(var)}};
-
-// L2Loss
-INPUT_MAP(L2Loss) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(L2Loss) = EMPTY_ATTR_MAP;
-OUTPUT_MAP(L2Loss) = {{0, OUTPUT_DESC(y)}};
-
-// CTCLoss
-INPUT_MAP(CTCLoss) = {{1, INPUT_DESC(inputs)},
-                      {2, INPUT_DESC(labels_indices)},
-                      {3, INPUT_DESC(labels_values)},
-                      {4, INPUT_DESC(sequence_length)}};
-ATTR_MAP(CTCLoss) = {
-  {"preprocess_collapse_repeated", ATTR_DESC(preprocess_collapse_repeated, AnyTraits<bool>())},
-  {"ctc_merge_repeated", ATTR_DESC(ctc_merge_repeated, AnyTraits<bool>())},
-  {"ignore_longer_outputs_than_inputs", ATTR_DESC(ignore_longer_outputs_than_inputs, AnyTraits<bool>())}};
-OUTPUT_MAP(CTCLoss) = {{0, OUTPUT_DESC(loss)}, {1, OUTPUT_DESC(gradient)}};
-
-// AscendQuant
-INPUT_MAP(AscendQuant) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(AscendQuant) = {{"scale", ATTR_DESC(scale, AnyTraits<float>())},
-                         {"offset", ATTR_DESC(offset, AnyTraits<float>())},
-                         {"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
-                         {"round_mode", ATTR_DESC(round_mode, AnyTraits<std::string>())}};
-OUTPUT_MAP(AscendQuant) = {{0, OUTPUT_DESC(y)}};
-
-// AscendDequant
-INPUT_MAP(AscendDequant) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(deq_scale)}};
-ATTR_MAP(AscendDequant) = {{"sqrt_mode", ATTR_DESC(sqrt_mode, AnyTraits<bool>())},
-                           {"relu_flag", ATTR_DESC(relu_flag, AnyTraits<bool>())}};
-OUTPUT_MAP(AscendDequant) = {{0, OUTPUT_DESC(y)}};
-#ifdef ENABLE_GE
-// Print
-INPUT_MAP(Print) = EMPTY_INPUT_MAP;
-DYN_INPUT_MAP(Print) = {{1, DYN_INPUT_DESC(x)}};
-ATTR_MAP(Print) = EMPTY_ATTR_MAP;
-#endif
-}  // namespace transform
-}  // namespace mindspore
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
deleted file mode 100755
index 2dfbf11fc4..0000000000
--- a/mindspore/ccsrc/transform/op_declare.h
+++ /dev/null
@@ -1,505 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef TRANSFORM_OP_DECLARE_H_
-#define TRANSFORM_OP_DECLARE_H_
-
-#include <string>
-#include <unordered_map>
-#include "transform/op_adapter.h"
-
-namespace mindspore {
-namespace transform {
-#define DECLARE_OP_ADAPTER(T)                                        \
-  using T = ge::op::T;                                               \
-  template <>                                                        \
-  const std::unordered_map<int, InputDesc> OpAdapter<T>::input_map_; \
-  template <>                                                        \
-  const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
-
-#define DECLARE_OP_USE_OUTPUT(T) \
-  template <>                    \
-  const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
-
-#define DECLARE_OP_USE_ENUM(T) \
-  template <>                  \
-  const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
-
-#define DECLARE_OP_USE_INPUT_ATTR(T) \
-  template <>                        \
-  const std::unordered_map<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_;
-
-#define DECLARE_OP_USE_DYN_INPUT(T) \
-  template <>                       \
-  const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
-
-#define DECLARE_OP_USE_DYN_SUBGRAPH(T) \
-  template <>                          \
-  const std::unordered_map<int, DynSubGraphDesc> OpAdapter<T>::dyn_subgraph_map_;
-
-#define DECLARE_OP_USE_DYN_OUTPUT(T) \
-  template <>                        \
-  const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
-
-template <>
-std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<ge::Operator>::cus_input_map_;
-template <>
-std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<ge::Operator>::cus_output_map_;
-
-DECLARE_OP_ADAPTER(GreaterEqual)
-DECLARE_OP_USE_OUTPUT(GreaterEqual)
-DECLARE_OP_ADAPTER(SliceD)
-DECLARE_OP_USE_INPUT_ATTR(SliceD)
-DECLARE_OP_USE_OUTPUT(SliceD)
-DECLARE_OP_ADAPTER(AssignAdd)
-DECLARE_OP_USE_OUTPUT(AssignAdd)
-DECLARE_OP_ADAPTER(AssignSub)
-DECLARE_OP_USE_OUTPUT(AssignSub)
-
-DECLARE_OP_ADAPTER(ReduceMean)
-DECLARE_OP_ADAPTER(Multiply)
-DECLARE_OP_USE_OUTPUT(Multiply)
-
-// ** Distributed Operations **
-DECLARE_OP_ADAPTER(HcomReduceScatter)
-DECLARE_OP_USE_OUTPUT(HcomReduceScatter)
-DECLARE_OP_ADAPTER(HcomBroadcast)
-DECLARE_OP_USE_DYN_INPUT(HcomBroadcast)
-DECLARE_OP_USE_DYN_OUTPUT(HcomBroadcast)
-DECLARE_OP_ADAPTER(HcomAllReduce)
-DECLARE_OP_USE_OUTPUT(HcomAllReduce)
-DECLARE_OP_ADAPTER(HcomAllGather)
-DECLARE_OP_USE_OUTPUT(HcomAllGather)
-DECLARE_OP_ADAPTER(Variable)
-DECLARE_OP_ADAPTER(ReluGrad)
-DECLARE_OP_USE_OUTPUT(ReluGrad)
-DECLARE_OP_ADAPTER(BiasAddGrad)
-DECLARE_OP_USE_OUTPUT(BiasAddGrad)
-DECLARE_OP_ADAPTER(MaxPoolWithArgmax)
-DECLARE_OP_USE_OUTPUT(MaxPoolWithArgmax)
-DECLARE_OP_ADAPTER(MaxPoolGradWithArgmax)
-DECLARE_OP_USE_OUTPUT(MaxPoolGradWithArgmax)
-DECLARE_OP_ADAPTER(Conv2D)
-DECLARE_OP_USE_ENUM(Conv2D)
-DECLARE_OP_USE_OUTPUT(Conv2D)
-DECLARE_OP_ADAPTER(ExtractImagePatches)
-DECLARE_OP_USE_OUTPUT(ExtractImagePatches)
-DECLARE_OP_ADAPTER(Conv2DBackpropInputD)
-DECLARE_OP_USE_ENUM(Conv2DBackpropInputD)
-DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropInputD)
-DECLARE_OP_USE_OUTPUT(Conv2DBackpropInputD)
-DECLARE_OP_ADAPTER(Conv2DBackpropFilterD)
-DECLARE_OP_USE_ENUM(Conv2DBackpropFilterD)
-DECLARE_OP_USE_INPUT_ATTR(Conv2DBackpropFilterD)
-DECLARE_OP_USE_OUTPUT(Conv2DBackpropFilterD)
-DECLARE_OP_ADAPTER(DepthwiseConv2D)
-DECLARE_OP_USE_ENUM(DepthwiseConv2D)
-DECLARE_OP_USE_OUTPUT(DepthwiseConv2D)
-DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropFilterD)
-DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropFilterD)
-DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropFilterD)
-DECLARE_OP_ADAPTER(DepthwiseConv2DBackpropInputD)
-DECLARE_OP_USE_INPUT_ATTR(DepthwiseConv2DBackpropInputD)
-DECLARE_OP_USE_OUTPUT(DepthwiseConv2DBackpropInputD)
-DECLARE_OP_ADAPTER(Reshape)
-DECLARE_OP_USE_OUTPUT(Reshape)
-DECLARE_OP_ADAPTER(TransShape)
-DECLARE_OP_USE_INPUT_ATTR(TransShape)
-DECLARE_OP_USE_OUTPUT(TransShape)
-DECLARE_OP_ADAPTER(Iou)
-DECLARE_OP_USE_OUTPUT(Iou)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborV2D)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2D)
-DECLARE_OP_ADAPTER(ResizeNearestNeighborV2Grad)
-DECLARE_OP_USE_OUTPUT(ResizeNearestNeighborV2Grad)
-DECLARE_OP_ADAPTER(ApplyAdam)
-DECLARE_OP_USE_OUTPUT(ApplyAdam)
-DECLARE_OP_ADAPTER(ApplyAdamD)
-DECLARE_OP_USE_OUTPUT(ApplyAdamD)
-DECLARE_OP_ADAPTER(Relu6)
-DECLARE_OP_USE_OUTPUT(Relu6)
-DECLARE_OP_ADAPTER(Relu6Grad)
-DECLARE_OP_USE_OUTPUT(Relu6Grad)
-DECLARE_OP_ADAPTER(ResizeBilinearV2D)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearV2D)
-DECLARE_OP_ADAPTER(ResizeBilinearV2Grad)
-DECLARE_OP_USE_OUTPUT(ResizeBilinearV2Grad)
-DECLARE_OP_ADAPTER(ZerosLike)
-DECLARE_OP_USE_OUTPUT(ZerosLike)
-DECLARE_OP_ADAPTER(OnesLike)
-DECLARE_OP_USE_OUTPUT(OnesLike)
-DECLARE_OP_ADAPTER(TensorScatterUpdate)
-DECLARE_OP_USE_OUTPUT(TensorScatterUpdate)
-DECLARE_OP_ADAPTER(ScatterUpdate)
-DECLARE_OP_USE_OUTPUT(ScatterUpdate)
-DECLARE_OP_ADAPTER(ScatterNdUpdate)
-DECLARE_OP_USE_OUTPUT(ScatterNdUpdate)
-DECLARE_OP_ADAPTER(ScatterMax)
-DECLARE_OP_USE_OUTPUT(ScatterMax)
-DECLARE_OP_ADAPTER(NMSWithMask)
-DECLARE_OP_USE_OUTPUT(NMSWithMask)
-DECLARE_OP_ADAPTER(Unpack)
-DECLARE_OP_USE_DYN_OUTPUT(Unpack)
-DECLARE_OP_ADAPTER(CheckValid)
-DECLARE_OP_USE_OUTPUT(CheckValid)
-DECLARE_OP_ADAPTER(SmoothL1Loss)
-DECLARE_OP_USE_OUTPUT(SmoothL1Loss)
-DECLARE_OP_ADAPTER(SmoothL1LossGrad)
-DECLARE_OP_USE_OUTPUT(SmoothL1LossGrad)
-DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogits)
-DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogits)
-DECLARE_OP_ADAPTER(SigmoidCrossEntropyWithLogitsGrad)
-DECLARE_OP_USE_OUTPUT(SigmoidCrossEntropyWithLogitsGrad)
-DECLARE_OP_ADAPTER(ScatterNdD)
-DECLARE_OP_USE_INPUT_ATTR(ScatterNdD)
-DECLARE_OP_USE_OUTPUT(ScatterNdD)
-DECLARE_OP_ADAPTER(PadD)
-DECLARE_OP_USE_OUTPUT(PadD)
-DECLARE_OP_ADAPTER(MirrorPad)
-DECLARE_OP_USE_OUTPUT(MirrorPad)
-DECLARE_OP_ADAPTER(MirrorPadGrad)
-DECLARE_OP_USE_OUTPUT(MirrorPadGrad)
-DECLARE_OP_ADAPTER(BoundingBoxEncode)
-DECLARE_OP_USE_OUTPUT(BoundingBoxEncode)
-DECLARE_OP_ADAPTER(BoundingBoxDecode)
-DECLARE_OP_USE_OUTPUT(BoundingBoxDecode)
-DECLARE_OP_ADAPTER(GatherNd)
-DECLARE_OP_USE_OUTPUT(GatherNd)
-DECLARE_OP_ADAPTER(ArgMaxD)
-DECLARE_OP_USE_OUTPUT(ArgMaxD)
-DECLARE_OP_ADAPTER(ArgMinD)
-DECLARE_OP_USE_OUTPUT(ArgMinD)
-DECLARE_OP_ADAPTER(ArgMaxWithValue)
-DECLARE_OP_USE_OUTPUT(ArgMaxWithValue)
-DECLARE_OP_ADAPTER(ArgMinWithValue)
-DECLARE_OP_USE_OUTPUT(ArgMinWithValue)
-DECLARE_OP_ADAPTER(Mul)
-DECLARE_OP_USE_OUTPUT(Mul)
-DECLARE_OP_ADAPTER(AddN)
-DECLARE_OP_USE_DYN_INPUT(AddN)
-DECLARE_OP_USE_OUTPUT(AddN)
-DECLARE_OP_ADAPTER(Less)
-DECLARE_OP_USE_OUTPUT(Less)
-DECLARE_OP_ADAPTER(Rsqrt)
-DECLARE_OP_USE_OUTPUT(Rsqrt)
-DECLARE_OP_ADAPTER(Sqrt)
-DECLARE_OP_USE_OUTPUT(Sqrt)
-DECLARE_OP_ADAPTER(Square)
-DECLARE_OP_USE_OUTPUT(Square)
-DECLARE_OP_ADAPTER(SplitD)
-DECLARE_OP_USE_DYN_OUTPUT(SplitD)
-DECLARE_OP_ADAPTER(SGD)
-DECLARE_OP_USE_OUTPUT(SGD)
-DECLARE_OP_ADAPTER(SquareSumAll)
-DECLARE_OP_USE_OUTPUT(SquareSumAll)
-
-DECLARE_OP_ADAPTER(Tanh)
-DECLARE_OP_USE_OUTPUT(Tanh)
-DECLARE_OP_ADAPTER(TanhGrad)
-DECLARE_OP_USE_OUTPUT(TanhGrad)
-DECLARE_OP_ADAPTER(Maximum)
-DECLARE_OP_USE_OUTPUT(Maximum)
-DECLARE_OP_ADAPTER(Minimum)
-DECLARE_OP_USE_OUTPUT(Minimum)
-DECLARE_OP_ADAPTER(MaximumGrad)
-DECLARE_OP_USE_OUTPUT(MaximumGrad)
-DECLARE_OP_ADAPTER(MinimumGrad)
-DECLARE_OP_USE_OUTPUT(MinimumGrad)
-DECLARE_OP_ADAPTER(ReduceMinD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceMinD)
-DECLARE_OP_USE_OUTPUT(ReduceMinD)
-DECLARE_OP_ADAPTER(ReduceMaxD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceMaxD)
-DECLARE_OP_USE_OUTPUT(ReduceMaxD)
-DECLARE_OP_ADAPTER(Merge)
-DECLARE_OP_USE_DYN_INPUT(Merge)
-DECLARE_OP_USE_OUTPUT(Merge)
-DECLARE_OP_ADAPTER(Switch)
-DECLARE_OP_USE_OUTPUT(Switch)
-
-DECLARE_OP_ADAPTER(TopK)
-DECLARE_OP_USE_OUTPUT(TopK)
-
-DECLARE_OP_ADAPTER(RealDiv)
-DECLARE_OP_USE_OUTPUT(RealDiv)
-
-DECLARE_OP_ADAPTER(Cast)
-DECLARE_OP_USE_INPUT_ATTR(Cast)
-DECLARE_OP_USE_OUTPUT(Cast)
-DECLARE_OP_ADAPTER(Case)
-DECLARE_OP_USE_DYN_INPUT(Case)
-DECLARE_OP_USE_DYN_SUBGRAPH(Case)
-DECLARE_OP_USE_DYN_OUTPUT(Case)
-DECLARE_OP_ADAPTER(Reciprocal)
-DECLARE_OP_USE_OUTPUT(Reciprocal)
-DECLARE_OP_ADAPTER(Neg)
-DECLARE_OP_USE_OUTPUT(Neg)
-DECLARE_OP_ADAPTER(TransposeD)
-DECLARE_OP_USE_INPUT_ATTR(TransposeD)
-// Do not set Transpose operator output descriptor
-DECLARE_OP_ADAPTER(Sub)
-DECLARE_OP_USE_OUTPUT(Sub)
-DECLARE_OP_ADAPTER(DropOutGenMask)
-DECLARE_OP_USE_OUTPUT(DropOutGenMask)
-DECLARE_OP_ADAPTER(ConcatD)
-DECLARE_OP_USE_DYN_INPUT(ConcatD)
-DECLARE_OP_USE_OUTPUT(ConcatD)
-DECLARE_OP_ADAPTER(Pack)
-DECLARE_OP_USE_DYN_INPUT(Pack)
-DECLARE_OP_USE_OUTPUT(Pack)
-
-DECLARE_OP_ADAPTER(Pow)
-DECLARE_OP_USE_OUTPUT(Pow)
-DECLARE_OP_ADAPTER(Equal)
-DECLARE_OP_USE_OUTPUT(Equal)
-DECLARE_OP_ADAPTER(NotEqual)
-DECLARE_OP_USE_OUTPUT(NotEqual)
-DECLARE_OP_ADAPTER(Log)
-DECLARE_OP_USE_OUTPUT(Log)
-DECLARE_OP_ADAPTER(LogicalAnd)
-DECLARE_OP_USE_OUTPUT(LogicalAnd)
-DECLARE_OP_ADAPTER(LogicalOr)
-DECLARE_OP_USE_OUTPUT(LogicalOr)
-DECLARE_OP_ADAPTER(LogicalNot)
-DECLARE_OP_USE_OUTPUT(LogicalNot)
-DECLARE_OP_ADAPTER(LogSoftmaxGrad)
-DECLARE_OP_USE_OUTPUT(LogSoftmaxGrad)
-
-DECLARE_OP_ADAPTER(RandomChoiceWithMask)
-DECLARE_OP_USE_OUTPUT(RandomChoiceWithMask)
-
-DECLARE_OP_ADAPTER(Select)
-DECLARE_OP_USE_OUTPUT(Select)
-DECLARE_OP_ADAPTER(LessEqual)
-DECLARE_OP_USE_OUTPUT(LessEqual)
-DECLARE_OP_ADAPTER(LogSoftmaxV2)
-DECLARE_OP_USE_OUTPUT(LogSoftmaxV2)
-DECLARE_OP_ADAPTER(TruncatedNormal)
-DECLARE_OP_USE_OUTPUT(TruncatedNormal)
-DECLARE_OP_ADAPTER(StridedSliceGrad)
-DECLARE_OP_USE_OUTPUT(StridedSliceGrad)
-DECLARE_OP_ADAPTER(Gelu)
-DECLARE_OP_USE_OUTPUT(Gelu)
-DECLARE_OP_ADAPTER(GeluGrad)
-DECLARE_OP_USE_OUTPUT(GeluGrad)
-DECLARE_OP_ADAPTER(StridedSlice)
-DECLARE_OP_USE_OUTPUT(StridedSlice)
-DECLARE_OP_ADAPTER(UnsortedSegmentSumD)
-DECLARE_OP_USE_INPUT_ATTR(UnsortedSegmentSumD)
-DECLARE_OP_USE_OUTPUT(UnsortedSegmentSumD)
-DECLARE_OP_ADAPTER(UnsortedSegmentMin)
-DECLARE_OP_USE_OUTPUT(UnsortedSegmentMin)
-DECLARE_OP_ADAPTER(ExpandDims)
-DECLARE_OP_USE_OUTPUT(ExpandDims)
-DECLARE_OP_ADAPTER(Squeeze)
-DECLARE_OP_USE_OUTPUT(Squeeze)
-DECLARE_OP_ADAPTER(LayerNorm)
-DECLARE_OP_USE_OUTPUT(LayerNorm)
-DECLARE_OP_ADAPTER(LayerNormGrad)
-DECLARE_OP_USE_OUTPUT(LayerNormGrad)
-DECLARE_OP_ADAPTER(BatchMatMul)
-DECLARE_OP_USE_OUTPUT(BatchMatMul)
-DECLARE_OP_ADAPTER(DropOutDoMask)
-DECLARE_OP_USE_OUTPUT(DropOutDoMask)
-// ** Mix-precision Operations **
-DECLARE_OP_ADAPTER(NPUGetFloatStatus)
-DECLARE_OP_USE_OUTPUT(NPUGetFloatStatus)
-DECLARE_OP_ADAPTER(NPUAllocFloatStatus)
-DECLARE_OP_USE_OUTPUT(NPUAllocFloatStatus)
-DECLARE_OP_ADAPTER(NPUClearFloatStatus)
-DECLARE_OP_USE_OUTPUT(NPUClearFloatStatus)
-DECLARE_OP_ADAPTER(MatMulV2)
-DECLARE_OP_USE_OUTPUT(MatMulV2)
-
-DECLARE_OP_ADAPTER(SoftmaxCrossEntropyWithLogits)
-DECLARE_OP_USE_OUTPUT(SoftmaxCrossEntropyWithLogits)
-
-DECLARE_OP_ADAPTER(MeanGrad)
-DECLARE_OP_USE_INPUT_ATTR(MeanGrad)
-
-DECLARE_OP_ADAPTER(Assign)
-DECLARE_OP_USE_OUTPUT(Assign)
-DECLARE_OP_ADAPTER(Constant)
-DECLARE_OP_USE_OUTPUT(Constant)
-DECLARE_OP_ADAPTER(ApplyMomentumD)
-DECLARE_OP_USE_OUTPUT(ApplyMomentumD)
-// ** Summary Operations **
-DECLARE_OP_ADAPTER(Summary)
-
-// fully supported
-DECLARE_OP_ADAPTER(Add)
-DECLARE_OP_USE_OUTPUT(Add)
-DECLARE_OP_ADAPTER(Const)
-DECLARE_OP_USE_OUTPUT(Const)
-DECLARE_OP_ADAPTER(Cos)
-DECLARE_OP_USE_OUTPUT(Cos)
-
-DECLARE_OP_ADAPTER(Acos)
-DECLARE_OP_USE_OUTPUT(Acos)
-DECLARE_OP_ADAPTER(AcosGrad)
-DECLARE_OP_USE_OUTPUT(AcosGrad)
-DECLARE_OP_ADAPTER(Acosh)
-DECLARE_OP_USE_OUTPUT(Acosh)
-DECLARE_OP_ADAPTER(AcoshGrad)
-DECLARE_OP_USE_OUTPUT(AcoshGrad)
-
-DECLARE_OP_ADAPTER(Floor)
-DECLARE_OP_USE_OUTPUT(Floor)
-DECLARE_OP_ADAPTER(FloorDiv)
-DECLARE_OP_USE_OUTPUT(FloorDiv)
-DECLARE_OP_ADAPTER(FloorMod)
-DECLARE_OP_USE_OUTPUT(FloorMod)
-DECLARE_OP_ADAPTER(Sin)
-DECLARE_OP_USE_OUTPUT(Sin)
-DECLARE_OP_ADAPTER(Exp)
-DECLARE_OP_USE_OUTPUT(Exp)
-
-DECLARE_OP_ADAPTER(ReduceAllD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceAllD)
-DECLARE_OP_USE_OUTPUT(ReduceAllD)
-DECLARE_OP_ADAPTER(ReduceSumD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceSumD)
-DECLARE_OP_USE_OUTPUT(ReduceSumD)
-DECLARE_OP_ADAPTER(ReduceMeanD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceMeanD)
-DECLARE_OP_USE_OUTPUT(ReduceMeanD)
-DECLARE_OP_ADAPTER(ReduceProdD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceProdD)
-DECLARE_OP_USE_OUTPUT(ReduceProdD)
-DECLARE_OP_ADAPTER(CumprodD)
-DECLARE_OP_USE_INPUT_ATTR(CumprodD)
-DECLARE_OP_USE_OUTPUT(CumprodD)
-
-DECLARE_OP_ADAPTER(TileD)
-DECLARE_OP_USE_INPUT_ATTR(TileD)
-DECLARE_OP_USE_OUTPUT(TileD)
-DECLARE_OP_ADAPTER(OneHot)
-DECLARE_OP_USE_OUTPUT(OneHot)
-DECLARE_OP_ADAPTER(GatherV2D)
-DECLARE_OP_USE_INPUT_ATTR(GatherV2D)
-DECLARE_OP_USE_OUTPUT(GatherV2D)
-DECLARE_OP_ADAPTER(RangeD)
-DECLARE_OP_USE_OUTPUT(RangeD)
-
-DECLARE_OP_ADAPTER(Data)
-DECLARE_OP_ADAPTER(BiasAdd)
-DECLARE_OP_USE_OUTPUT(BiasAdd)
-DECLARE_OP_ADAPTER(BatchNorm)
-DECLARE_OP_USE_OUTPUT(BatchNorm)
-DECLARE_OP_ADAPTER(BatchNormGrad)
-DECLARE_OP_USE_OUTPUT(BatchNormGrad)
-DECLARE_OP_ADAPTER(Relu)
-DECLARE_OP_USE_OUTPUT(Relu)
-DECLARE_OP_ADAPTER(PRelu)
-DECLARE_OP_USE_OUTPUT(PRelu)
-DECLARE_OP_ADAPTER(Elu)
-DECLARE_OP_USE_OUTPUT(Elu)
-
-DECLARE_OP_ADAPTER(EluGrad)
-DECLARE_OP_USE_OUTPUT(EluGrad)
-DECLARE_OP_ADAPTER(PReluGrad)
-DECLARE_OP_USE_OUTPUT(PReluGrad)
-
-DECLARE_OP_ADAPTER(L2Normalize)
-DECLARE_OP_USE_OUTPUT(L2Normalize)
-
-DECLARE_OP_ADAPTER(CumsumD)
-DECLARE_OP_USE_INPUT_ATTR(CumsumD)
-DECLARE_OP_USE_OUTPUT(CumsumD)
-DECLARE_OP_ADAPTER(L2NormalizeGrad)
-DECLARE_OP_USE_OUTPUT(L2NormalizeGrad)
-DECLARE_OP_ADAPTER(Sigmoid)
-DECLARE_OP_USE_OUTPUT(Sigmoid)
-DECLARE_OP_ADAPTER(SigmoidGrad)
-DECLARE_OP_USE_OUTPUT(SigmoidGrad)
-DECLARE_OP_ADAPTER(SoftmaxV2)
-DECLARE_OP_USE_OUTPUT(SoftmaxV2)
-DECLARE_OP_ADAPTER(SoftmaxGrad)
-DECLARE_OP_USE_OUTPUT(SoftmaxGrad)
-DECLARE_OP_ADAPTER(Greater)
-DECLARE_OP_USE_OUTPUT(Greater)
-DECLARE_OP_ADAPTER(Flatten)
-DECLARE_OP_USE_OUTPUT(Flatten)
-DECLARE_OP_ADAPTER(GatherV2)
-DECLARE_OP_USE_OUTPUT(GatherV2)
-DECLARE_OP_ADAPTER(MaxPool)
-DECLARE_OP_USE_OUTPUT(MaxPool)
-DECLARE_OP_ADAPTER(MaxPoolGrad)
-DECLARE_OP_USE_OUTPUT(MaxPoolGrad)
-DECLARE_OP_ADAPTER(AvgPool)
-DECLARE_OP_USE_OUTPUT(AvgPool)
-DECLARE_OP_ADAPTER(AvgPoolGrad)
-DECLARE_OP_USE_OUTPUT(AvgPoolGrad)
-DECLARE_OP_ADAPTER(ROIAlign)
-DECLARE_OP_USE_OUTPUT(ROIAlign)
-DECLARE_OP_ADAPTER(ROIAlignGrad)
-DECLARE_OP_USE_OUTPUT(ROIAlignGrad)
-DECLARE_OP_ADAPTER(Abs)
-DECLARE_OP_USE_OUTPUT(Abs)
-DECLARE_OP_ADAPTER(AbsGrad)
-DECLARE_OP_USE_OUTPUT(AbsGrad)
-DECLARE_OP_ADAPTER(BinaryCrossEntropy)
-DECLARE_OP_USE_OUTPUT(BinaryCrossEntropy)
-DECLARE_OP_ADAPTER(BinaryCrossEntropyGrad)
-DECLARE_OP_USE_OUTPUT(BinaryCrossEntropyGrad)
-DECLARE_OP_ADAPTER(SparseApplyAdagradD)
-DECLARE_OP_USE_OUTPUT(SparseApplyAdagradD)
-DECLARE_OP_ADAPTER(ApplyProximalAdagradD)
-DECLARE_OP_USE_OUTPUT(ApplyProximalAdagradD)
-DECLARE_OP_ADAPTER(SpaceToDepth)
-DECLARE_OP_USE_OUTPUT(SpaceToDepth)
-DECLARE_OP_ADAPTER(DepthToSpace)
-DECLARE_OP_USE_OUTPUT(DepthToSpace)
-DECLARE_OP_ADAPTER(Sign)
-DECLARE_OP_USE_OUTPUT(Sign)
-DECLARE_OP_ADAPTER(LarsV2Update)
-DECLARE_OP_USE_OUTPUT(LarsV2Update)
-DECLARE_OP_ADAPTER(Round)
-DECLARE_OP_USE_OUTPUT(Round)
-DECLARE_OP_ADAPTER(ApplyFtrlD)
-DECLARE_OP_USE_OUTPUT(ApplyFtrlD)
-DECLARE_OP_ADAPTER(SparseApplyFtrlD)
-DECLARE_OP_USE_OUTPUT(SparseApplyFtrlD)
-DECLARE_OP_ADAPTER(Diag)
-DECLARE_OP_USE_OUTPUT(Diag)
-DECLARE_OP_ADAPTER(DiagPart)
-DECLARE_OP_USE_OUTPUT(DiagPart)
-DECLARE_OP_ADAPTER(SpaceToBatchD)
-DECLARE_OP_USE_OUTPUT(SpaceToBatchD)
-DECLARE_OP_ADAPTER(BatchToSpaceD)
-DECLARE_OP_USE_OUTPUT(BatchToSpaceD)
-DECLARE_OP_ADAPTER(Atan2)
-DECLARE_OP_USE_OUTPUT(Atan2)
-DECLARE_OP_ADAPTER(ApplyRMSPropD)
-DECLARE_OP_USE_INPUT_ATTR(ApplyRMSPropD)
-DECLARE_OP_USE_OUTPUT(ApplyRMSPropD)
-DECLARE_OP_ADAPTER(ApplyCenteredRMSProp)
-DECLARE_OP_USE_OUTPUT(ApplyCenteredRMSProp)
-DECLARE_OP_ADAPTER(L2Loss)
-DECLARE_OP_USE_OUTPUT(L2Loss)
-DECLARE_OP_ADAPTER(CTCLoss)
-DECLARE_OP_USE_OUTPUT(CTCLoss)
-DECLARE_OP_ADAPTER(AscendQuant)
-DECLARE_OP_USE_OUTPUT(AscendQuant)
-DECLARE_OP_ADAPTER(AscendDequant)
-DECLARE_OP_USE_OUTPUT(AscendDequant)
-#ifdef ENABLE_GE
-DECLARE_OP_ADAPTER(Print)
-DECLARE_OP_USE_DYN_INPUT(Print)
-#endif
-}  // namespace transform
-}  // namespace mindspore
-#endif  // TRANSFORM_OP_DECLARE_H_
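The DECLARE_* macros in the deleted header only forward-declare explicit specializations of OpAdapter's static tables; the definitions are the INPUT_MAP/ATTR_MAP assignments in op_declare.cc. A compilable miniature of that declare-in-header, define-in-source pattern (Adapter and AddOp are illustrative names, not the real classes):

    // Explicit specialization of a static data member: declare, then define.
    #include <string>
    #include <unordered_map>

    template <typename T>
    class Adapter {
     public:
      static const std::unordered_map<int, std::string> input_map_;
    };

    struct AddOp {};

    // Header side: declare that the specialization exists (DECLARE_OP_ADAPTER analogue).
    template <>
    const std::unordered_map<int, std::string> Adapter<AddOp>::input_map_;

    // Source side: the actual table (the INPUT_MAP(...) = {...}; lines).
    template <>
    const std::unordered_map<int, std::string> Adapter<AddOp>::input_map_ = {{1, "x1"}, {2, "x2"}};

    int main() { return Adapter<AddOp>::input_map_.at(1) == "x1" ? 0 : 1; }
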
diff --git a/mindspore/ccsrc/transform/util.cc b/mindspore/ccsrc/transform/util.cc
deleted file mode 100644
index b848ec117b..0000000000
--- a/mindspore/ccsrc/transform/util.cc
+++ /dev/null
@@ -1,452 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "transform/util.h"
-
-#include <algorithm>
-#include <map>
-#include <utility>
-
-#include "securec/include/securec.h"
-#include "utils/convert_utils.h"
-#include "utils/utils.h"
-
-namespace mindspore {
-namespace transform {
-using std::make_shared;
-using std::shared_ptr;
-using std::string;
-using std::vector;
-
-const size_t kErrorSize = 0;
-
-vector<int64_t> TransformUtil::ConvertIntToList(int64_t data, int size) {
-  vector<int64_t> list{};
-  if (size <= 0) {
-    MS_LOG(WARNING) << "size <= 0";
-    return list;
-  }
-  for (int i = 0; i < size; ++i) {
-    list.push_back(data);
-  }
-  return list;
-}
-
-static std::map<MeDataType, GeDataType> datatype_trans_map = {
-  {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16}, {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
-  {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},  {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
-  {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},     {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
-  {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},     {MeDataType::kNumberTypeUInt8, GeDataType::DT_UINT8},
-  {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16},   {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
-  {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64},   {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}};
-
-GeDataType TransformUtil::ConvertDataType(const MeDataType &type) {
-  MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type";
-  if (datatype_trans_map.find(type) != datatype_trans_map.end()) {
-    return datatype_trans_map[type];
-  } else {
-    return GeDataType::DT_UNDEFINED;
-  }
-}
-
-static std::map<MeDataType, size_t> datatype_size_map = {
-  {MeDataType::kNumberTypeFloat16, sizeof(float) / 2}, {MeDataType::kNumberTypeFloat32, sizeof(float)},  // 1/2 of float
-  {MeDataType::kNumberTypeFloat64, sizeof(double)},    {MeDataType::kNumberTypeInt8, sizeof(int8_t)},
-  {MeDataType::kNumberTypeInt16, sizeof(int16_t)},     {MeDataType::kNumberTypeInt32, sizeof(int32_t)},
-  {MeDataType::kNumberTypeInt64, sizeof(int64_t)},     {MeDataType::kNumberTypeUInt8, sizeof(uint8_t)},
-  {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)},   {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)},
-  {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)},   {MeDataType::kNumberTypeBool, sizeof(bool)}};
-
-size_t TransformUtil::GetDataTypeSize(const MeDataType &type) {
-  if (datatype_size_map.find(type) != datatype_size_map.end()) {
-    return datatype_size_map[type];
-  } else {
-    MS_LOG(ERROR) << "Illegal tensor data type!";
-    return kErrorSize;
-  }
-}
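For context on how the two tables above cooperate: a tensor is only converted when ConvertDataType yields a defined GE type, and its buffer size is the element count times GetDataTypeSize. A small worked example with assumed shape values:

    // Worked example: buffer size for a float16 tensor of shape {2, 3, 4}.
    #include <cstddef>
    #include <iostream>

    int main() {
      const std::size_t elements = 2 * 3 * 4;           // 24 elements
      const std::size_t fp16_size = sizeof(float) / 2;  // matches datatype_size_map
      std::cout << "float16 buffer: " << elements * fp16_size << " bytes\n";  // 48
    }
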
me format to ge format - GeFormat ge_format = ConvertFormat(format); - if (ge_format == GeFormat::FORMAT_ND) { - MS_LOG(ERROR) << "undefined data format : " << static_cast(ge_format); - return nullptr; - } - // convert me datatype to ge datatype - GeDataType data_type = ConvertDataType(me_type); - if (data_type == GeDataType::DT_UNDEFINED) { - MS_LOG(ERROR) << "undefined data type :" << me_type; - return nullptr; - } - - auto desc = std::make_shared(shape, ge_format, data_type); - if (desc == nullptr) { - MS_LOG(ERROR) << "Create GeTensorDesc failed!"; - return nullptr; - } - MS_LOG(INFO) << "SetRealDimCnt is :" << me_shape.size(); - desc->SetRealDimCnt(SizeToInt(me_shape.size())); - return desc; -} - -// if failed, return empty vector. -std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, - const std::string &format) { - std::vector ge_tensors; - - for (size_t index = 0; index < me_tensors.size(); index++) { - MS_EXCEPTION_IF_NULL(me_tensors[index]); - MS_LOG(INFO) << "me_tensor " << index << " 's data size is: " << me_tensors[index]->DataSize(); - auto shape = me_tensors[index]->shape(); - std::string shape_str; - for (size_t i = 0; i < shape.size(); i++) { - shape_str += std::to_string(shape[i]); - shape_str += " "; - } - MS_LOG(INFO) << "me_tensor " << index << " 's shape is: { " << shape_str << "}"; - MS_LOG(INFO) << "me_tensor " << index << " 's type is: " << me_tensors[index]->data_type(); - - auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensors[index], format); - if (ge_tensor_ptr != nullptr) { - ge_tensors.emplace_back(ge_tensor_ptr); - } else { - MS_LOG(ERROR) << "Convert me_tensor " << index << " to Ge Tensor failed!"; - ge_tensors.clear(); - return ge_tensors; - } - } - return ge_tensors; -} - -GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::string &format) { - // get tensor data type size - MS_EXCEPTION_IF_NULL(tensor); - size_t type_size = GetDataTypeSize(tensor->data_type()); - if (type_size == kErrorSize) { - MS_LOG(ERROR) << "The Me Tensor data type size is wrong, type size is: " << type_size; - return nullptr; - } - size_t elements_num = IntToSize(tensor->ElementsNum()); - if (UINT_MAX / type_size < elements_num) { - MS_LOG(ERROR) << "The required Me Tensor data buff size " << elements_num << " x " << type_size - << " overflowed UINT_MAX: " << UINT_MAX << "."; - return nullptr; - } - - // get tensor buff size - size_t data_buff_size = elements_num * type_size; - if (data_buff_size == 0) { - MS_LOG(INFO) << "The Me Tensor data buff size is 0."; - } - // create ge tensor - auto desc = GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format); - if (desc == nullptr) { - MS_LOG(ERROR) << "Failed to get Tensor Desc"; - return nullptr; - } - GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); - if (tensor_ptr != nullptr) { - MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; - } - return tensor_ptr; -} - -std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims) { - std::vector outputs; - - for (size_t index = 0; index < ge_tensors.size(); index++) { - MeTensorPtr me_tensor_ptr = nullptr; - if (index < request_dims.size()) { - me_tensor_ptr = ConvertGeTensor(ge_tensors[index], request_dims[index]); - } else { - std::vector empty_shape; - me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape); - } - - if (me_tensor_ptr != nullptr) { - outputs.emplace_back(me_tensor_ptr); - } else { - MS_LOG(ERROR) << 
"Convert Ge Tensor " << index << " to Me Tensor failed!"; - return outputs; - } - } - return outputs; -} - -std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { - std::vector outputs; - - for (size_t index = 0; index < ge_tensors.size(); index++) { - MeTensorPtr me_tensor_ptr = ConvertGeTensor(ge_tensors[index]); - if (me_tensor_ptr != nullptr) { - outputs.emplace_back(me_tensor_ptr); - } else { - MS_LOG(ERROR) << "Convert Ge Tensor " << index << " to Me Tensor failed!"; - return outputs; - } - } - return outputs; -} - -MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) { - switch (type) { - case GeDataType::DT_FLOAT16: - return MeDataType::kNumberTypeFloat16; - case GeDataType::DT_FLOAT: - return MeDataType::kNumberTypeFloat32; - case GeDataType::DT_DOUBLE: - return MeDataType::kNumberTypeFloat64; - case GeDataType::DT_INT64: - return MeDataType::kNumberTypeInt64; - case GeDataType::DT_INT32: - return MeDataType::kNumberTypeInt32; - case GeDataType::DT_INT16: - return MeDataType::kNumberTypeInt16; - case GeDataType::DT_INT8: - return MeDataType::kNumberTypeInt8; - case GeDataType::DT_BOOL: - return MeDataType::kNumberTypeBool; - case GeDataType::DT_UINT8: - return MeDataType::kNumberTypeUInt8; - case GeDataType::DT_UINT16: - return MeDataType::kNumberTypeUInt16; - case GeDataType::DT_UINT32: - return MeDataType::kNumberTypeUInt32; - case GeDataType::DT_UINT64: - return MeDataType::kNumberTypeUInt64; - case GeDataType::DT_UNDEFINED: - case GeDataType::DT_DUAL_SUB_UINT8: - case GeDataType::DT_DUAL_SUB_INT8: - case GeDataType::DT_DUAL: - return MeDataType::kTypeUnknown; - default: - return MeDataType::kTypeUnknown; - } -} - -namespace { -bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { - MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); - MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); - - const int GE_DIMS = 4; - std::vector ge_dims = ge_shape.GetDims(); - if (request_dims.size() > ge_dims.size()) { - MS_LOG(ERROR) << "Request shape's dims count greater than ge shape's"; - return false; - } - - // convert NHWC to NCHW - if ((request_dims.size() == 1) && (ge_dims.size() == GE_DIMS) && (request_dims[0] == ge_dims[1]) && - (ge_dims[0] == 1) && (ge_dims[2] == 1) && (ge_dims[3] == 1)) { - MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; - return true; - } - - std::string::size_type i = 0; - for (; i < request_dims.size(); i++) { - if (ge_dims[i] != request_dims[i]) { - MS_LOG(ERROR) << "Request shape's dims value not equal to ge shape's"; - return false; - } - } - - for (; i < ge_dims.size(); i++) { - if (ge_dims[i] != 1) { - MS_LOG(ERROR) << "GeShape's extend dims is not equal to 1"; - return false; - } - } - MS_LOG(INFO) << "Ge tensor shape and request shape is compatible"; - return true; -} -} // namespace - -GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { - std::vector ge_dims; - (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); - return GeShape(ge_dims); -} - -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { - std::vector me_dims; - std::vector ge_dims = ge_shape.GetDims(); - (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); - return me_dims; -} - -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { - vector ret; - if (ge_shape.GetDimNum() == 0) { - MS_LOG(DEBUG) << "GeTensor's 
shape is scalar"; - return ret; - } - - if (IsGeShapeCompatible(ge_shape, request_dims) == true) { - ret = request_dims; - } else { - MS_LOG(ERROR) << "GeShape and Me request shape are incompatible, return GeShape"; - ret = ConvertGeShape(ge_shape); - } - return ret; -} - -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, - const TypeId &me_type) { - MeTensor me_tensor(me_type, me_dims); - - // Get the writable data pointer of the tensor and cast it to its data type - auto me_data_ptr = reinterpret_cast(me_tensor.data_c()); - size_t me_data_size = static_cast(me_tensor.data().nbytes()); - MS_EXCEPTION_IF_NULL(me_data_ptr); - MS_EXCEPTION_IF_NULL(ge_tensor); - if (me_data_size < ge_tensor->GetSize()) { - MS_LOG(ERROR) << "ME tensor data size[" << me_data_size << " bytes] is less than GE tensor [" - << ge_tensor->GetSize() << " bytes]"; - return nullptr; - } - - // Copy or use the writable data pointer of the ME tensor - MS_EXCEPTION_IF_NULL(ge_tensor->GetData()); - if (ge_tensor->GetSize() == 0) { - MS_LOG(ERROR) << "GE tensor data size is zero!"; - return nullptr; - } - - // Use memcpy here, not memcpy_s, just because the size of ge_tensor may be bigger than 2GB - // which is the size limit of memcpy_s - memcpy(me_data_ptr, ge_tensor->GetData(), ge_tensor->GetSize()); - - return make_shared(me_tensor); -} - -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { - MS_EXCEPTION_IF_NULL(ge_tensor); - GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); - vector me_dims = ConvertGeShape(ge_shape); - - TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); - if (type_id == MeDataType::kTypeUnknown) { - MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " - << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - return nullptr; - } - return GenerateMeTensor(ge_tensor, me_dims, type_id); -} - -// if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { - MS_EXCEPTION_IF_NULL(ge_tensor); - GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); - vector me_dims = ConvertGeShape(ge_shape, request_dims); - MS_LOG(INFO) << "GE tensor type is " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - // Create a tensor with wanted data type and shape - TypeId type_id = ConvertGeDataType(ge_tensor->GetTensorDesc().GetDataType()); - if (type_id == MeDataType::kTypeUnknown) { - MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: " - << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - return nullptr; - } - return GenerateMeTensor(ge_tensor, me_dims, type_id); -} - -std::string TransformUtil::PrintGeTensor(const GeTensorPtr ge_tensor) { - std::string ret; - if (ge_tensor == nullptr) { - MS_LOG(ERROR) << "Input ge tensor is nullptr"; - return ret; - } - - MS_LOG(INFO) << "Ge Tensor data type is : " << static_cast(ge_tensor->GetTensorDesc().GetDataType()); - switch (ge_tensor->GetTensorDesc().GetDataType()) { - case GeDataType::DT_UINT32: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_FLOAT: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT32: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_DOUBLE: - ret = 
PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT64: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT64: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_INT16: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT16: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_DUAL_SUB_INT8: - case GeDataType::DT_INT8: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_UINT8: - case GeDataType::DT_DUAL_SUB_UINT8: - ret = PrintVector(MakeVector(ge_tensor->GetData(), ge_tensor->GetSize())); - break; - case GeDataType::DT_FLOAT16: - case GeDataType::DT_BOOL: - case GeDataType::DT_UNDEFINED: - case GeDataType::DT_DUAL: - default: - MS_LOG(ERROR) << "Unsupported to print type:" << static_cast(ge_tensor->GetTensorDesc().GetDataType()) - << " ge tensor"; - break; - } - return ret; -} -} // namespace transform -} // namespace mindspore diff --git a/mindspore/ccsrc/transform/util.h b/mindspore/ccsrc/transform/util.h deleted file mode 100644 index 5d8db26ad1..0000000000 --- a/mindspore/ccsrc/transform/util.h +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
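The PrintGeTensor switch above routes every supported dtype through the two template helpers declared near the end of this header: MakeVector reinterprets the raw GE byte buffer as a typed vector, and PrintVector renders at most the first 100 elements (MAX_PRINT_NUM). A distilled sketch of the buffer-to-vector step, written as a free function with plain memcpy instead of the securec memcpy_s the header uses:

    #include <cstring>
    #include <vector>
    // Copy raw tensor bytes into a typed vector of size / sizeof(T) elements.
    template <typename T>
    std::vector<T> MakeVector(const uint8_t *data, size_t size) {
      std::vector<T> dest(size / sizeof(T));
      if (data != nullptr && !dest.empty()) {
        std::memcpy(dest.data(), data, dest.size() * sizeof(T));
      }
      return dest;
    }

One behavioral difference worth noting: this sketch silently truncates trailing bytes when size is not a multiple of sizeof(T), while the original returns an empty vector whenever memcpy_s reports an error.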
- */ - -#ifndef TRANSFORM_UTIL_H_ -#define TRANSFORM_UTIL_H_ - -#include -#include -#include -#include -#include "securec/include/securec.h" -#include "ir/anf.h" -#include "ir/dtype.h" -#include "ir/tensor.h" -#include "transform/types.h" - -#include "graph/tensor.h" - -namespace mindspore { -namespace transform { -class TransformUtil { - public: - /* - * Parameters: - * data: [int64_t] the value to fill the list with - * size: [int] the length of the list to build - * Return: - * [std::vector] the list holding `size` copies of `data` - * */ - static std::vector ConvertIntToList(int64_t data, int size); - - /* - * Parameters: - * type: [MeDataType] the data type for ME tensor - * Return: - * [GeDataType] the data type for ge tensor - * */ - static GeDataType ConvertDataType(const MeDataType &type); - - /* - * Parameters: - * format: [string] the data format in ME op - * Return: - * [GeFormat] the data format for ge tensor - * */ - static GeFormat ConvertFormat(const std::string &format); - - /* - * Parameters: - * type: [MeDataType] the data type for ME tensor - * Return: - * [size_t] the buffer size for the type in ME - * */ - static size_t GetDataTypeSize(const MeDataType &type); - - /* - * Parameters: - * shape: [std::vector] the shape of the ME tensor - * me_type: [MeDataType] the data type of the ME tensor - * format: [string] the data format in ME - * Return: - * [shared_ptr] the shared pointer of ge tensor description - * */ - static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, - const std::string &format); - - /* - * Parameters: - * tensor: [MeTensor] the data tensor in ME - * format: [string] the data format in ME op - * Return: - * [GeTensor] the data tensor in GE - * */ - static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); - - /* - * Parameters: - * me_tensors: [vector] the data tensors in ME - * format: [string] the data format in ME op - * Return: - * [std::vector] the data tensors in GE - * */ - static std::vector ConvertInputTensors(const std::vector &me_tensors, - const std::string &format); - - /* - * Parameters: - * tensor: [GeTensor] the data tensor in GE - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); - - /* - * Parameters: - * tensor: [GeTensor] the data tensor in GE - * request_dims: [std::vector] the shape the output ME tensor must be adjusted to - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); - /* - * Parameters: - * ge_tensors: [std::vector] the data tensors in GE - * request_dims: [std::vector>] the shapes the output ME tensors must be adjusted to - * Return: - * [std::vector] the data tensors in ME - * */ - static std::vector ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims); - /* - * Parameters: - * ge_tensors: [std::vector] the data tensors in GE - * Return: - * [std::vector] the data tensors in ME - * */ - static std::vector ConvertGeTensors(const std::vector &ge_tensors); - /* - * Parameters: - * ge_tensor: [GeTensor] the data tensor in GE - * me_dims: [std::vector] the shape of the created ME tensor - * me_type: [TypeId] the data type of the created ME tensor - * Return: - * [MeTensor] the data tensor in ME - * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, - const TypeId &me_type); - /* - * Parameters: - * type:
[GeDataType] the ge tensor data type - * Return: - * [MeDataType] the me tensor data type - * */ - static MeDataType ConvertGeDataType(const GeDataType &type); - - /* - * Parameters: - * me_dims: [std::vector] the me shape - * Return: - * [GeShape] the ge shape - * */ - static GeShape ConvertMeShape(const std::vector &me_dims); - - /* - * Parameters: - * ge_shape: [GeShape] the ge shape - * Return: - * [vector] the me shape - * */ - static std::vector ConvertGeShape(const GeShape &ge_shape); - - /* Function: - * Convert GeShape to Me request shape, Support pattern: - * {1, x, 1, 1} --> {x} - * {x, 1, 1, 1} --> {x} - * {x, x, 1, 1} --> {x, x} - * {x, x, x, 1} --> {x, x, x} - * {x, x, x, x} --> {x, x, x, x} - * If unmatch upon patterns, return original ge dims - * Parameters: - * ge_shape: [GeShape] the ge shape - * request_dims: [vector] request dims - * Return: - * [vector] the me shape - * */ - static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); - - /* - * Parameters: - * vec: [std::vector] the vector to print - * Return: - * [string] value string - * */ - template ::value>::type> - static std::string PrintVector(const std::vector &vec) { - const int MAX_PRINT_NUM = 100; - std::stringstream ss; - ss << "{ "; - int i = 0; - for (auto it = vec.begin(); it != vec.end(); ++it) { - ss << std::to_string(*it) << ", "; - i++; - if (i >= MAX_PRINT_NUM) { - break; - } - } - - if (i >= MAX_PRINT_NUM) { - ss << "... to be continue}"; - } else { - ss << "}"; - } - return ss.str(); - } - - /* - * Parameters: - * ge_tensor: [GeTensorPtr] the ge tensor - * Return: - * [stringstream] value string - * */ - static std::string PrintGeTensor(const GeTensorPtr ge_tensor); - - /* - * Parameters: - * data: [uint8_t *] the ge tensor data pointer - * size: [size_t] the ge tensor data bytes - * Return: - * [shared_ptr] vector pointer - * */ - template ::value>::type> - static std::vector MakeVector(const uint8_t *const data, size_t size) { - auto dest = std::vector(size / sizeof(T)); - if (data == nullptr) { - return dest; - } - - errno_t ret = memcpy_s(dest.data(), dest.size() * sizeof(T), data, size); - if (EOK != ret) { - return std::vector(); - } - return dest; - } -}; -} // namespace transform -} // namespace mindspore - -#endif // TRANSFORM_UTIL_H_ diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index 427cc5e568..ceb95d5c8c 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -20,8 +20,8 @@ #include #include #include "pybind11/pybind11.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/visible.h" namespace mindspore { diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 55125ebe91..6001b295ad 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -17,10 +17,10 @@ #include "utils/callbacks_ge.h" #include "pybind11/pybind11.h" #include "ir/param_value.h" -#include "transform/df_graph_manager.h" -#include "transform/util.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/parse/python_adapter.h" +#include "transform/graph_ir/df_graph_manager.h" +#include "transform/graph_ir/util.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/visible.h" namespace mindspore { diff --git 
a/mindspore/ccsrc/utils/callbacks_ge.h b/mindspore/ccsrc/utils/callbacks_ge.h index 9735c3000a..f0ef583aaa 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.h +++ b/mindspore/ccsrc/utils/callbacks_ge.h @@ -20,8 +20,8 @@ #include #include #include -#include "transform/types.h" -#include "transform/util.h" +#include "transform/graph_ir/types.h" +#include "transform/graph_ir/util.h" #include "ir/tensor.h" namespace mindspore { diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 92bf92abea..37b6bf638b 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -27,7 +27,7 @@ #include "tdt/data_common.h" #endif #ifdef ENABLE_GE -#include "transform/df_graph_manager.h" +#include "transform/graph_ir/df_graph_manager.h" #endif #include "ir/tensor.h" #include "common/utils.h" diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index a5a618dff4..b1847d1df5 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -26,8 +26,8 @@ #include "pybind11/pybind11.h" #include "abstract/abstract_value.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/parse_base.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" #include "ir/value.h" #include "ir/tensor.h" #include "ir/param_value.h" diff --git a/mindspore/ccsrc/utils/graph_utils_extends.cc b/mindspore/ccsrc/utils/graph_utils_extends.cc index 0740c24236..852dd0e3f2 100644 --- a/mindspore/ccsrc/utils/graph_utils_extends.cc +++ b/mindspore/ccsrc/utils/graph_utils_extends.cc @@ -31,8 +31,8 @@ #include "debug/label.h" #include "utils/log_adapter.h" #include "common/utils.h" -#include "pipeline/parse/function_block.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/function_block.h" +#include "pipeline/jit/parse/python_adapter.h" namespace mindspore { namespace { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index d676be895e..fa1137e3f6 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -23,7 +23,7 @@ #include "google/protobuf/io/zero_copy_stream_impl.h" #include "ir/tensor.h" #include "ir/param_value.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "abstract/abstract_value.h" #include "proto/onnx.pb.h" #include "utils/log_adapter.h" diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index 97fa954e12..490e2517a9 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -15,7 +15,7 @@ */ #include "utils/primitive_utils.h" -#include "pipeline/parse/python_adapter.h" +#include "pipeline/jit/parse/python_adapter.h" #include "utils/log_adapter.h" #include "common/utils.h" diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index cdaa826c82..08cd4e4291 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -21,7 +21,7 @@ #include #include #include "ir/tensor.h" -#include "device/convert_tensor_utils.h" +#include "runtime/device/convert_tensor_utils.h" #include "./securec.h" #ifndef NO_DLIB #include "tdt/tsd_client.h" diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 88a07c7c12..0290ee57fc 100644 --- 
a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -23,7 +23,7 @@ #include "utils/callbacks.h" #include "utils/graph_utils.h" #include "utils/base_ref_extends.h" -#include "session/session_factory.h" +#include "backend/session/session_factory.h" #include "common/utils.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index c8d0696fa4..208c4010fb 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -26,7 +26,7 @@ #include "ir/anf.h" #include "vm/segment_runner.h" #include "vm/vm.h" -#include "session/session_basic.h" +#include "backend/session/session_basic.h" namespace mindspore { namespace compile { diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index db27506134..540b77bcaf 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -31,7 +31,7 @@ #include "utils/utils.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" namespace mindspore { const char kMsConvert[] = "ms"; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index ccad0112c3..2cf6ead813 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -28,7 +28,7 @@ #include "abstract/abstract_value.h" #ifdef ENABLE_GE -#include "transform/convert.h" +#include "transform/graph_ir/convert.h" #endif #include "utils/graph_utils.h" #include "utils/context/ms_context.h" diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 55c32ea4e3..d08a24d188 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -28,7 +28,7 @@ #include "vm/vm.h" #include "ir/anf.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "vm/segment_runner.h" #include "vm/backend.h" diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index 047b330158..baa5b0ea11 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -23,7 +23,7 @@ #include "vm/vmimpl.h" #include "vm/backend.h" #include "vm/transform.h" -#include "pipeline/parse/data_converter.h" +#include "pipeline/jit/parse/data_converter.h" #include "utils/base_ref_extends.h" namespace mindspore { diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index cb23cdaf43..2aebf8ad0d 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -27,7 +27,7 @@ #include #include "ir/tensor.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/manager.h" #include "ir/func_graph_cloner.h" #include "ir/primitive_py.h" diff --git a/mindspore/ccsrc/ir/CMakeLists.txt b/mindspore/core/ir/CMakeLists.txt similarity index 100% rename from mindspore/ccsrc/ir/CMakeLists.txt rename to mindspore/core/ir/CMakeLists.txt diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc new file mode 100644 index 0000000000..0d96ddf263 --- /dev/null +++ b/mindspore/core/ir/anf.cc @@ -0,0 +1,221 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
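The new mindspore/core/ir/anf.cc starting here moves the core ANF-node implementation out of ccsrc; the matching headers (anf.h, dtype.h and friends) follow as pure renames into mindspore/core/ir/ with 100% similarity. A hedged usage sketch for two helpers this file defines, IsPrimitiveCNode and GetCNodePrimitive (`node` is a hypothetical AnfNodePtr; kPrimScalarSummary is one of the primitives referenced elsewhere in this patch):

    // Test whether `node` applies a given primitive, then fetch it.
    if (IsPrimitiveCNode(node, prim::kPrimScalarSummary)) {
      PrimitivePtr prim = GetCNodePrimitive(node);  // non-null on this path
      MS_LOG(INFO) << "summary op: " << prim->name();
    }

GetCNodePrimitive returns nullptr for parameters, value nodes, and CNodes whose first input is not a Primitive, so callers can use it as a combined test-and-fetch.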
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/anf.h" + +#include +#include +#include +#include + +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "utils/context/ms_context.h" +#include "frontend/operator/ops.h" + +namespace mindspore { +// namespace to support intermediate representation definition +CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) + : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} + +// Check if CNode is an apply with the specific Primitive. +bool CNode::IsApply(const PrimitivePtr &value) const { + if (value == nullptr) { + return false; + } + + if (inputs_.size() != 0 && IsValueNode(inputs_[0])) { + PrimitivePtr fn_value = GetValueNode(inputs_[0]); + if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { + return true; + } + } + + return false; +} + +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } + +std::string CNode::DebugString(int recursive_level) const { + std::ostringstream buffer; + if (recursive_level > 0) { + if (func_graph() != nullptr) { + buffer << func_graph()->ToString() << ":"; + } + buffer << ToString() << "{"; + bool is_first_node = true; + int idx = 0; + for (auto &node : inputs_) { + MS_EXCEPTION_IF_NULL(node); + if (is_first_node) { + is_first_node = false; + } else { + buffer << ", "; + } + buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1); + idx++; + } + buffer << "}"; + } else { + buffer << ToString(); + } + return buffer.str(); +} + +std::string ValueNode::ToString() const { + MS_EXCEPTION_IF_NULL(value_); + if (value_->isa()) { + return value_->cast()->ToString(); + } + std::ostringstream buffer; + buffer << AnfNode::ToString(); + buffer << "(" << value_->ToString() << ")"; + return buffer.str(); +} + +std::string ValueNode::DebugString(int) const { + MS_EXCEPTION_IF_NULL(value_); + std::ostringstream buffer; + buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString(); + return buffer.str(); +} + +std::string ValueNode::fullname_with_scope() { + if (!fullname_with_scope_.empty()) { + return fullname_with_scope_; + } + + MS_EXCEPTION_IF_NULL(scope()); + fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base()); + return fullname_with_scope_; +} + +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + if (value != nullptr) { + return cnode->IsApply(value); + } + const auto &prim = GetValueNode(cnode->input(0)); + return prim != nullptr; +} + +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { + if (node == nullptr) { + return nullptr; + } + auto cnode = node->cast(); + if (cnode != nullptr) { + if (cnode->size() > 0) { + auto prim = GetValueNode(cnode->input(0)); + return prim; + } + } + return nullptr; +} + +std::string GetCNodeFuncName(const CNodePtr cnode) { + if (cnode->inputs().empty()) { + return ""; + } + + AnfNodePtr valuenode = cnode->input(0); + if (valuenode->isa()) { + auto value = GetValueNode(valuenode); + // check 
whether the valuenode is primitive + if (value->isa()) { + return value->cast()->name(); + } + return value->ToString(); + } + return ""; +} + +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { + if (IsValueNode(node)) { + PrimitivePtr fn_value = GetValueNode(node); + MS_EXCEPTION_IF_NULL(value); + if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) { + return true; + } + } + return false; +} + +size_t NewSeenGeneration() { + static size_t seen_generation = 0; + return ++seen_generation; +} + +namespace id_generator { +static std::unordered_map node_ids; +std::string get_id(const AnfNodePtr &node) { + auto type_name = node->type_name(); + if (node_ids.find(type_name) == node_ids.end()) { + node_ids[type_name] = 0; + } else { + node_ids[type_name]++; + } + return std::to_string(node_ids[type_name]); +} + +void reset_id() { node_ids.clear(); } +} // namespace id_generator + +std::string GetCNodeTarget(const AnfNodePtr &node) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->device_target(); + if (!node->isa()) { + return default_target; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto attr_input = cnode->input(0); + if (attr_input == nullptr) { + return default_target; + } + auto value_node = attr_input->cast(); + if (value_node == nullptr) { + return default_target; + } + auto value = value_node->value(); + if (value == nullptr) { + return default_target; + } + if (!value->isa()) { + return default_target; + } + auto primitive = value->cast(); + auto att_target = primitive->GetAttr("primitive_target"); + if (att_target != nullptr) { + if (!att_target->isa()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + auto target = GetValue(att_target); + if (kTargetSet.find(target) == kTargetSet.end()) { + MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target"; + } + return target; + } + return default_target; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/core/ir/anf.h similarity index 100% rename from mindspore/ccsrc/ir/anf.h rename to mindspore/core/ir/anf.h diff --git a/mindspore/core/ir/anf_extends.cc b/mindspore/core/ir/anf_extends.cc new file mode 100644 index 0000000000..1caf7f1b36 --- /dev/null +++ b/mindspore/core/ir/anf_extends.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/anf.h" + +#include +#include +#include +#include + +#include "ir/visitor.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "frontend/operator/ops.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "debug/label.h" + +namespace mindspore { +// namespace to support intermediate representation definition +// Methods of AnfNode +TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? 
nullptr : abstract_->BuildType(); } +BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } + +std::string AnfNode::ToString() const { + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); +} + +OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { + if (operator_info_ != nullptr) { + MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() + << ", using the new one: " << operator_info->name(); + auto old_ptr = operator_info_; + operator_info_ = operator_info; + return old_ptr; + } + operator_info_ = operator_info; + return nullptr; +} + +std::string CNode::fullname_with_scope() { + // if full name is set, return its name immediately + if (!fullname_with_scope_.empty()) { + return fullname_with_scope_; + } + + if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) || + IsApply(prim::kPrimHistogramSummary)) { + std::string tag = GetValue(GetValueNode(input(1))); + std::string name; + if (IsApply(prim::kPrimScalarSummary)) { + name = tag + "[:Scalar]"; + } else if (IsApply(prim::kPrimImageSummary)) { + name = tag + "[:Image]"; + } else if (IsApply(prim::kPrimHistogramSummary)) { + name = tag + "[:Histogram]"; + } else { + name = tag + "[:Tensor]"; + } + fullname_with_scope_ = name; + } else { + // cnode input 0 should be primitive ptr or funcgraph ptr + auto value_ptr = input(0)->cast(); + if (value_ptr == nullptr) { + MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << "."; + fullname_with_scope_ = id_generator::get_id(shared_from_base()); + return fullname_with_scope_; + } + auto input_value = value_ptr->value(); + if (input_value == nullptr) { + MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr."; + fullname_with_scope_ = id_generator::get_id(shared_from_base()); + return fullname_with_scope_; + } + + auto prim = input_value->cast(); + MS_EXCEPTION_IF_NULL(scope()); + fullname_with_scope_ = scope()->name() + "/"; + if (prim != nullptr) { + fullname_with_scope_ += prim->name(); + } else { + auto func_graph = input_value->cast(); + MS_EXCEPTION_IF_NULL(func_graph); + auto fg_flag = func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL); + if (fg_flag != nullptr) { + auto fg_name = GetValue(fg_flag); + fullname_with_scope_ += "GraphKernel_" + fg_name; + } else { + fullname_with_scope_ += func_graph->ToString(); + } + } + fullname_with_scope_ += "-op" + id_generator::get_id(shared_from_base()); + } + + return fullname_with_scope_; +} + +void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/anf_py.cc b/mindspore/core/ir/anf_py.cc similarity index 100% rename from mindspore/ccsrc/ir/anf_py.cc rename to mindspore/core/ir/anf_py.cc diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/core/ir/dtype.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype.cc rename to mindspore/core/ir/dtype.cc diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/core/ir/dtype.h similarity index 100% rename from mindspore/ccsrc/ir/dtype.h rename to mindspore/core/ir/dtype.h diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/core/ir/dtype/container.cc similarity index 100% rename from 
mindspore/ccsrc/ir/dtype/container.cc rename to mindspore/core/ir/dtype/container.cc diff --git a/mindspore/ccsrc/ir/dtype/container.h b/mindspore/core/ir/dtype/container.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/container.h rename to mindspore/core/ir/dtype/container.h diff --git a/mindspore/ccsrc/ir/dtype/empty.cc b/mindspore/core/ir/dtype/empty.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/empty.cc rename to mindspore/core/ir/dtype/empty.cc diff --git a/mindspore/ccsrc/ir/dtype/empty.h b/mindspore/core/ir/dtype/empty.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/empty.h rename to mindspore/core/ir/dtype/empty.h diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/core/ir/dtype/number.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/number.cc rename to mindspore/core/ir/dtype/number.cc diff --git a/mindspore/ccsrc/ir/dtype/number.h b/mindspore/core/ir/dtype/number.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/number.h rename to mindspore/core/ir/dtype/number.h diff --git a/mindspore/ccsrc/ir/dtype/ref.cc b/mindspore/core/ir/dtype/ref.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/ref.cc rename to mindspore/core/ir/dtype/ref.cc diff --git a/mindspore/ccsrc/ir/dtype/ref.h b/mindspore/core/ir/dtype/ref.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/ref.h rename to mindspore/core/ir/dtype/ref.h diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/core/ir/dtype/type.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/type.cc rename to mindspore/core/ir/dtype/type.cc diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/core/ir/dtype/type.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/type.h rename to mindspore/core/ir/dtype/type.h diff --git a/mindspore/ccsrc/ir/dtype/type_extends.cc b/mindspore/core/ir/dtype/type_extends.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype/type_extends.cc rename to mindspore/core/ir/dtype/type_extends.cc diff --git a/mindspore/ccsrc/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h similarity index 100% rename from mindspore/ccsrc/ir/dtype/type_id.h rename to mindspore/core/ir/dtype/type_id.h diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype_extends.cc rename to mindspore/core/ir/dtype_extends.cc diff --git a/mindspore/ccsrc/ir/dtype_py.cc b/mindspore/core/ir/dtype_py.cc similarity index 100% rename from mindspore/ccsrc/ir/dtype_py.cc rename to mindspore/core/ir/dtype_py.cc diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc new file mode 100644 index 0000000000..fabdd3e7d3 --- /dev/null +++ b/mindspore/core/ir/func_graph.cc @@ -0,0 +1,628 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
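The rewritten func_graph.cc beginning here keeps per-graph usage bookkeeping (value_nodes_, free_variables_, func_graphs_used_, func_graph_cnodes_index_, j_func_graphs_) in counter maps with paired Add*/Drop* methods, all following one insert-or-increment, decrement-and-erase idiom. A distilled, self-contained version of that idiom (hypothetical names; string keys stand in for the node and graph pointers, C++17):

    #include <string>
    #include <unordered_map>
    std::unordered_map<std::string, int> counts;
    // Insert the key or bump its count; true when newly tracked.
    bool Add(const std::string &k, int n) {
      auto [it, inserted] = counts.try_emplace(k, 0);
      it->second += n;
      return inserted;
    }
    // Decrement; erase on last use; true when fully released.
    bool Drop(const std::string &k) {
      auto it = counts.find(k);
      if (it == counts.end()) return false;
      if (it->second == 1) { counts.erase(it); return true; }
      --it->second;  // the real methods raise an exception on underflow
      return false;
    }

The boolean return mirrors the AddFreeVariable/DropFreeVariable contract below: true when a key was newly tracked or fully released.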
+ */ + +#include "ir/func_graph.h" + +#include +#include +#include + +#include "debug/trace.h" +#include "ir/manager.h" +#include "frontend/operator/ops.h" +#include "utils/ordered_set.h" +#include "utils/convert_utils_base.h" + +namespace mindspore { +/* + * Methods of Graph + */ +FuncGraph::FuncGraph() + : attrs_(), + transforms_(), + parameter_default_value_(), + seen_(0), + parameters_(), + has_vararg_(false), + has_kwarg_(false), + kwonlyargs_count_(0), + hyper_param_count_(0), + is_generated_(false), + return_(nullptr), + manager_(std::weak_ptr()), + stub_(false) { + debug_info_ = std::make_shared(); +} + +AnfNodePtr FuncGraph::output() const { + // If return value is set, return should have two inputs. + if (return_ != nullptr && return_->inputs().size() == 2) { + return return_->input(1); + } else { + // If not set yet, return nullptr. + return nullptr; + } +} + +ParameterPtr FuncGraph::add_parameter() { + FuncGraphPtr this_func_graph = shared_from_base(); + ParameterPtr p = std::make_shared(this_func_graph); + add_parameter(p); + return p; +} + +void FuncGraph::add_parameter(const ParameterPtr &p) { + if (manager_.lock()) { + manager_.lock()->AddParameter(shared_from_base(), p); + } else { + parameters_.push_back(p); + } +} + +ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { + FuncGraphPtr this_graph = shared_from_base(); + ParameterPtr p = std::make_shared(this_graph); + p->set_name(name); + p->debug_info()->set_name(name); + + if (manager_.lock()) { + manager_.lock()->AddParameter(shared_from_base(), p); + } else { + parameters_.push_back(p); + } + hyper_param_count_++; + return p; +} + +bool FuncGraph::has_flag(const std::string &key) { + auto iter = attrs_.find(key); + if (iter != attrs_.cend()) { + if (iter->second->isa()) { + return GetValue(iter->second); + } + MS_LOG(WARNING) << "key " << key << " is not a flag, please use has_attr function."; + } + return false; +} + +bool FuncGraph::has_attr(const std::string &key) { + auto iter = attrs_.find(key); + return !(iter == attrs_.cend()); +} + +ValuePtr FuncGraph::get_attr(const std::string &key) { + auto iter = attrs_.find(key); + return iter == attrs_.cend() ? 
nullptr : iter->second; +} + +CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { + CNodePtr cnode = std::make_shared(inputs, shared_from_base()); + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + order_.push_back(cnode); + MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order."; + } + return cnode; +} + +CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { + CNodePtr app = NewCNode(inputs); + app->set_scope(scope); + return app; +} + +void FuncGraph::DumpCNodeList() { + MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; + for (const auto &cnode : order_) { + MS_LOG(INFO) << cnode->DebugString(); + } +} + +std::string FuncGraph::ToString() const { + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); +} + +GraphDebugInfoPtr FuncGraph::debug_info() { + MS_EXCEPTION_IF_NULL(this->debug_info_); + if (this->debug_info_->get_graph() == nullptr) { + this->debug_info_->set_graph(shared_from_base()); + } + return this->debug_info_; +} + +const AnfNodeSet &FuncGraph::nodes() { return nodes_; } + +void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); } + +void FuncGraph::ClearNodes() { nodes_.clear(); } + +void FuncGraph::AddNode(AnfNodePtr node) { nodes_.add(node); } + +void FuncGraph::DropNode(AnfNodePtr node) { + nodes_.erase(node); + auto graph = node->func_graph(); + // Remove the node from order list. + if (graph) { + graph->EraseUnusedNodeInOrder(node); + } +} + +const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } + +void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) { + auto &others = source->value_nodes(); + for (auto it = others.begin(); it != others.end(); it++) { + AddValueNode(it->first, it->second); + } +} + +void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } + +void FuncGraph::AddValueNode(AnfNodePtr node, int count) { + if (value_nodes_.count(node) == 0) { + value_nodes_[node] = count; + } else { + value_nodes_[node] += count; + } +} + +void FuncGraph::DropValueNode(AnfNodePtr node) { + if (value_nodes_.count(node) != 0) { + if (value_nodes_[node] == 1) { + (void)value_nodes_.erase(node); + } else { + value_nodes_[node]--; + if (value_nodes_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of ValueNode '" << node + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } + +void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) { + auto &others = source->free_variables(); + for (auto it = others.begin(); it != others.end(); it++) { + if (it->first->func_graph().get() != this) { + (void)AddFreeVariable(it->first, it->second); + } + } +} + +void FuncGraph::ClearFreeVariables() { free_variables_.clear(); } + +bool FuncGraph::AddFreeVariable(AnfNodePtr node, int count) { + if (free_variables_.count(node) == 0) { + free_variables_[node] = count; + return true; + } else { + free_variables_[node] += count; + return false; + } +} + +bool FuncGraph::DropFreeVariable(AnfNodePtr node) { + if (free_variables_.count(node) != 0) { + if (free_variables_[node] == 1) { + (void)free_variables_.erase(node); + return true; + } else { + free_variables_[node]--; + if (free_variables_[node] < 0) { + MS_LOG(EXCEPTION) << "Count of free variable '" << node + << "' dec from 0. 
NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; +} + +const BaseRefCounterMap &FuncGraph::free_variables_total() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + auto &fv_total = mng->free_variables_total(); + return fv_total[shared_from_base()]; +} + +std::vector FuncGraph::free_variables_nodes() { + std::vector nodes; + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { + auto key = p.first; + if (utils::isa(key)) { + nodes.push_back(utils::cast(key)); + } + } + + return nodes; +} + +std::vector FuncGraph::free_variables_func_graphs() { + std::vector func_graphs; + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { + auto key = p.first; + if (utils::isa(key)) { + func_graphs.push_back(utils::cast(key)); + } + } + + return func_graphs; +} + +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; } + +void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) { + auto &others = source->func_graphs_used(); + for (auto it = others.begin(); it != others.end(); it++) { + (void)AddFuncGraphUsed(it->first, it->second); + } + func_graphs_used_.erase(source); +} + +void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); } + +bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) { + if (func_graphs_used_.count(fg) == 0) { + func_graphs_used_[fg] = count; + return true; + } else { + func_graphs_used_[fg] += count; + return false; + } +} + +bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) { + if (func_graphs_used_.count(fg) != 0) { + if (func_graphs_used_[fg] == 1) { + (void)func_graphs_used_.erase(fg); + return true; + } else { + func_graphs_used_[fg]--; + if (func_graphs_used_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } + return false; +} + +const FuncGraphSet &FuncGraph::func_graphs_used_total() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + auto &used = mng->func_graphs_used_total(shared_from_base()); + return used; +} + +const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } + +void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) { + auto &others = source->func_graph_cnodes_index(); + for (auto it = others.begin(); it != others.end(); it++) { + // Ignore the user graph who may own itself. + auto fg = it->first->first->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + if (fg.get() != this) { + AddFuncGraphCNodeIndex(it->first, it->second); + } + } +} + +void FuncGraph::ClearFuncGraphCNodesIndex() { func_graph_cnodes_index_.clear(); } + +void FuncGraph::AddFuncGraphCNodeIndex(CNodeIndexPairPtr pair, int count) { + if (func_graph_cnodes_index_.count(pair) == 0) { + func_graph_cnodes_index_[pair] = count; + } else { + func_graph_cnodes_index_[pair] += count; + } +} + +void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) { + if (func_graph_cnodes_index_.count(pair) != 0) { + if (func_graph_cnodes_index_[pair] == 1) { + (void)func_graph_cnodes_index_.erase(pair); + } else { + func_graph_cnodes_index_[pair]--; + if (func_graph_cnodes_index_[pair] < 0) { + MS_LOG(EXCEPTION) << "Count of CNode/Index '" << pair->first << "/" << pair->second + << "' dec from 0. 
NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; } + +void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) { + auto &others = source->j_func_graphs(); + for (auto it = others.begin(); it != others.end(); it++) { + AddJFuncGraph(it->first, it->second); + } +} + +void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); } + +void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) { + if (j_func_graphs_.count(fg) == 0) { + j_func_graphs_[fg] = count; + } else { + j_func_graphs_[fg] += count; + } +} + +void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) { + if (j_func_graphs_.count(fg) != 0) { + if (j_func_graphs_[fg] == 1) { + (void)j_func_graphs_.erase(fg); + } else { + j_func_graphs_[fg]--; + if (j_func_graphs_[fg] < 0) { + MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg + << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + } + } +} + +FuncGraphPtr FuncGraph::parent() { + // report the bug early. + if (manager_.lock() == nullptr) { + MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString() + << " NodeInfo: " << trace::GetDebugInfo(debug_info()); + } + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->parent(shared_from_base()); +} + +const FuncGraphSet &FuncGraph::children() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->children(shared_from_base()); +} + +const FuncGraphSet &FuncGraph::scope() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->scopes(shared_from_base()); +} + +bool FuncGraph::recursive() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->recursive(shared_from_base()); +} + +std::shared_ptr> FuncGraph::recursive_graphs() { + auto mng = manager_.lock(); + MS_EXCEPTION_IF_NULL(mng); + return mng->recursive_graphs(shared_from_base()); +} + +AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { + auto itr = this->parameter_default_value_.find(name); + if (itr == parameter_default_value_.end()) { + return nullptr; + } + auto default_value = itr->second; + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Graph parameter " << name << " not exist"; + } + if (IsValueNode(default_value)) { + return nullptr; + } + return default_value; +} + +// set the default values +void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { + auto all_is_null = + std::all_of(value_list.begin(), value_list.end(), [](const AnfNodePtr &node) { return IsValueNode(node); }); + if (value_list.empty()) { + all_is_null = true; + } + for (size_t i = 0; i < name_list.size(); ++i) { + if (!all_is_null) { + this->parameter_default_value_[name_list[i]] = value_list[i]; + } + } +} + +void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } + +size_t FuncGraph::GetDefaultValueCount() { + int null_count = + std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), + [](const std::pair &pair) { return IsValueNode(pair.second); }); + return parameter_default_value_.size() - IntToSize(null_count); +} + +AnfNodePtr FuncGraph::GetVariableArgParameter() { + if (!has_vararg_) { + return nullptr; + } + + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 2) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; + } + return 
parameters_[parameters_.size() - hyper_param_count_ - 2]; + } + + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]; +} + +std::string FuncGraph::GetVariableArgName() { + if (!has_vararg_) { + return ""; + } + + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 2) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 2 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 2]->cast()->name(); + } + + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); +} + +AnfNodePtr FuncGraph::GetVariableKwargParameter() { + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]; + } + return nullptr; +} + +std::string FuncGraph::GetVariableKwargName() { + if (has_kwarg_) { + if (parameters_.size() < hyper_param_count_ + 1) { + MS_LOG(EXCEPTION) << "Length of parameters is " << parameters_.size() << ", hyper_param_count is " + << hyper_param_count_ << ", parameters is less than 1 + hyper_param_count"; + } + return parameters_[parameters_.size() - hyper_param_count_ - 1]->cast()->name(); + } + return ""; +} + +int FuncGraph::GetPositionalArgsCount() const { + int count = SizeToInt(parameters_.size()); + if (has_kwarg_) { + count--; + } + if (has_vararg_) { + count--; + } + return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); +} + +AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { + for (size_t i = 0; i < parameters_.size(); ++i) { + MS_EXCEPTION_IF_NULL(parameters_[i]); + auto param_cast = parameters_[i]->cast(); + MS_EXCEPTION_IF_NULL(param_cast); + if (param_cast->name() == name) { + return parameters_[i]; + } + } + return nullptr; +} + +void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } + +std::list FuncGraph::GetOrderedCnodes() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + MS_LOG(DEBUG) << "Return ordered cnodes."; + return order_; + } else { + auto this_ptr = shared_from_base(); + auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); + auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); + + std::list cnodes; + auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); + for (const auto &node : nodes) { + auto cnode = dyn_cast(node); + if (cnode) { + cnodes.push_back(cnode); + } + } + return cnodes; + } +} + +void FuncGraph::EraseUnusedNodeInOrder() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + auto mng = manager_.lock(); + if (mng) { + auto &all_nodes = nodes(); + // Erase unused cnode. 
+ for (auto it = order_.begin(); it != order_.end();) { + if (all_nodes.count(*it)) { + (void)it++; + } else { + MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; + it = order_.erase(it); + } + } + } + } +} + +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { + if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { + order_.remove(n->cast()); + MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; + } +} + +void FuncGraph::CheckOrder() { + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + MS_LOG(DEBUG) << "Check graph " << ToString(); + for (auto it = order_.begin(); it != order_.end(); (void)it++) { + for (const auto &input_node : (*it)->inputs()) { + if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { + // Need to reorder the wrong order node. + auto found = std::find(order_.begin(), it, input_node); + if (found == it) { + DumpCNodeList(); + MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString() + << " doesn't obey the input dependency, " + << "as input " << input_node->DebugString() << " is not ahead of itself."; + } + } + } + } + auto mng = manager_.lock(); + if (mng != nullptr) { + const auto &all_nodes = nodes(); + if (all_nodes.size() != (order_.size() + parameters_.size())) { + DumpCNodeList(); + MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " + << all_nodes.size() - parameters_.size() << "."; + } + } + MS_LOG(DEBUG) << "Check order okay."; + } +} + +size_t NewFgSeenGeneration() { + static size_t fg_seen_generation = 0; + return ++fg_seen_generation; +} + +const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); +const char kFuncGraphFlagUndetermined[] = "Undeterminate"; +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/core/ir/func_graph.h similarity index 100% rename from mindspore/ccsrc/ir/func_graph.h rename to mindspore/core/ir/func_graph.h diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc new file mode 100644 index 0000000000..0857770cad --- /dev/null +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -0,0 +1,650 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
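func_graph_cloner.cc, added next, drives every clone through one replacement table: repl_node_ maps each original node to its clone, the CloneParameter/CloneCNode/CloneValueNode methods below populate it, and later passes rewrite node inputs through it. A distilled sketch of the lookup side (hypothetical helper name; AnfNodePtr as declared in ir/anf.h):

    // Resolve a node through the clone table, falling back to the original
    // so nodes outside the cloned graphs pass through unchanged.
    AnfNodePtr Resolve(const std::unordered_map<AnfNodePtr, AnfNodePtr> &repl,
                       const AnfNodePtr &node) {
      auto it = repl.find(node);
      return it == repl.end() ? node : it->second;
    }

Also worth noting in CloneParameter below: default values are shared rather than deep-copied, which the inline comment justifies ("Default parameter can be shared since it is readonly").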
+ */ + +#include "ir/func_graph_cloner.h" + +#include + +#include "ir/manager.h" +#include "ir/param_value.h" +#include "frontend/operator/ops.h" +#include "utils/convert_utils_base.h" +#include "utils/log_adapter.h" +#include "utils/profile.h" +#include "utils/context/ms_context.h" + +// namespace to support intermediate representation definition +namespace mindspore { +Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, + bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) + : clone_all_valuenodes_(clone_all_valuenodes), + clone_all_child_graphs_(clone_all_child_graphs), + clone_all_used_graphs_(clone_all_used_graphs), + relation_(relation), + target_relation_(target_relation == nullptr ? relation : target_relation) { + for (auto &func_graph : func_graphs) { + AddClone(func_graph); + } + scope_ = kDefaultScope; + type_ = kBasic; +} + +void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList ¶ms, CloneType type) { + if (func_graph != nullptr) { + todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); + type_ = type; + } +} + +void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + if (repl_node_.find(node) != repl_node_.end() || node->isa()) { + return; + } + if (node->isa()) { + CloneParameter(node, target); + } else if (node->isa()) { + CloneCNode(node, target); + } +} + +void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + auto new_param = (is_add) ? target->add_parameter() : std::make_shared(target); + auto old_param = node->cast(); + new_param->set_abstract(old_param->abstract()); + new_param->set_name(old_param->name()); + if (old_param->has_default()) { + // Default parameter can be shared since it is readonly. + new_param->set_default_param(old_param->default_param()); + } + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_param->set_scope(scope); + repl_node_[node] = new_param; + TraceManager::EndTrace(); +} + +void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + CNodePtr new_node = std::make_shared(AnfNodePtrList{}, target); + auto old_node = node->cast(); + new_node->set_abstract(old_node->abstract()); + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_node->set_scope(scope); + new_node->set_kernel_info(old_node->kernel_info_ptr()); + repl_node_[old_node] = new_node; + nodes_.emplace_back(old_node, new_node); + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNode(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + TraceManager::DebugTrace(node->debug_info(), relation_); + ValueNodePtr new_const = NewValueNode(GetValueNode(node)); + ScopePtr scope = (node->scope() != kDefaultScope) ? 
node->scope() : this->scope(); + new_const->set_scope(scope); + new_const->set_abstract(node->abstract()); + repl_node_[node] = new_const; + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(target); + TraceManager::DebugTrace(node->debug_info(), relation_); + ValueNodePtr new_const = NewValueNode(target); + ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); + new_const->set_scope(scope); + new_const->set_abstract(node->abstract()); + repl_node_[node] = new_const; + TraceManager::EndTrace(); +} + +void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_valuenodes_) { + return; + } + auto &value_nodes = func_graph->value_nodes(); + for (auto &value_node : value_nodes) { + auto old_node = value_node.first; + MS_EXCEPTION_IF_NULL(old_node); + if (repl_node_.count(old_node) == 0) { + CloneValueNode(old_node); + } + } +} + +void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_child_graphs_) { + return; + } + auto &scopes = manager_->scopes(func_graph); + for (auto &graph : scopes) { + if (graph != func_graph) { + todo_.push_back({graph, nullptr, {}}); + } + } +} + +void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager_); + if (!clone_all_used_graphs_) { + return; + } + auto &used = func_graph->func_graphs_used(); + for (auto &fg : used) { + todo_.push_back({fg.first, nullptr, {}}); + } +} + +void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + for (auto &item : func_graph->parameter_default_value()) { + auto nodes = DeepLinkedGraphSearch(item.second); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + CloneNode(node, target_func_graph); + } else if (node->isa()) { + CloneValueNode(node); + } + } + } +} + +void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + MS_EXCEPTION_IF_NULL(manager_); + auto return_node = repl_node_[func_graph->get_return()]->cast(); + if (return_node == nullptr) { + MS_LOG(EXCEPTION) << "Can't find replicate node for return."; + } + target_func_graph->set_return(return_node); + + auto &cnodes = func_graph->func_graph_cnodes_index(); + for (auto &cnode : cnodes) { + auto parent = cnode.first->first->cast(); + auto valuenode = parent->input(cnode.first->second); + CloneValueNode(valuenode, target_func_graph); + } +} + +void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { + MS_EXCEPTION_IF_NULL(func_graph); + auto &old_params = func_graph->parameters(); + if (old_params.size() != params.size()) { + MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; + return; + } + for (size_t i = 0; i < old_params.size(); ++i) { + repl_node_[old_params[i]] = params[i]; + } +} + +void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + 
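+  // Create the clone shell and copy the graph-level metadata (attrs, transforms,
+  // vararg/kwarg flags, parameter counts); parameters and nodes are cloned separately.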
TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); + *target_func_graph = std::make_shared(); + (*target_func_graph)->set_attrs(func_graph->attrs()); + (*target_func_graph)->set_transforms(func_graph->transforms()); + (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); + (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); + (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count()); + (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count()); + (*target_func_graph)->set_is_generate(func_graph->is_generated()); + (*target_func_graph)->set_stub(func_graph->stub()); + TraceManager::EndTrace(); +} + +void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + auto ¶ms = func_graph->parameters(); + for (auto ¶m : params) { + CloneParameter(param, target_func_graph, true); + } + repl_func_graph_[func_graph] = target_func_graph; +} + +void Cloner::GenParameters(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + auto &free_vars = manager_->free_variables_total(); + auto iter = free_vars.find(func_graph); + if (iter == free_vars.end()) { + return; + } + + for (auto &fv_map : iter->second) { + auto &free_var = fv_map.first; + if (utils::isa(free_var)) { + repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); + } + } +} + +void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { + param->set_abstract(node->abstract()); + if (node->isa()) { + ParameterPtr old_param = dyn_cast(node); + if (old_param->has_default()) { + // Default parameter can be shared since it is readonly. + param->set_default_param(old_param->default_param()); + } + param->set_name(old_param->name()); + } +} + +ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { + TraceManager::DebugTrace(std::make_shared(node->debug_info())); + ParameterPtr param = std::make_shared(func_graph); + TraceManager::EndTrace(); + CloneParameter(param, node); + if (is_add) { + func_graph->add_parameter(param); + } + repl_node_[param] = node; + repl_map_node_[func_graph][node] = param; + return param; +} + +void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, + AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { + AnfNodePtrList parameters; + std::unordered_set old_params; + for (auto ¶m : func_graph->parameters()) { + auto iter = repl_node_.find(param); + if (iter != repl_node_.end()) { + (void)old_params.insert(iter->second); + parameters.push_back(param); + } else { + parameters.push_back(AddParameter(func_graph, param, false)); + (void)old_params.insert(param); + } + } + AnfNodePtr new_param = nullptr; + for (auto ¶m : params) { + auto old_param = repl_node_[param]; + if (old_param->isa() && old_param->func_graph() == func_graph) { + repl_node_[old_param] = old_param; + repl_map_node_[func_graph][old_param] = old_param; + input_params->push_back(old_param); + continue; + } + if (old_params.find(old_param) != old_params.end()) { + new_param = repl_map_node_[func_graph][old_param]; + input_params->push_back(new_param); + continue; + } + new_param = AddParameter(func_graph, old_param, false); + parameters.push_back(new_param); + lift_params->push_back(new_param); + input_params->push_back(new_param); + } + func_graph->set_parameters(parameters); +} + +void 
Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { + AnfNodePtr node = nullptr; + auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; + auto iter = repl_func_graph.find(func_graph); + if (iter == repl_func_graph.end()) { + node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); + repl_func_graph[func_graph] = node; + } else { + node = iter->second; + } + if (node == nullptr || !node->isa()) { + return; + } + auto cnode = node->cast(); + auto inputs = cnode->inputs(); + (void)std::copy(params.begin(), params.end(), std::back_inserter(inputs)); + cnode->set_inputs(inputs); + OrderParameters(func_graph, inputs); +} + +void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { + std::unordered_set old_params; + for (auto ¶m : func_graph->parameters()) { + (void)old_params.insert(repl_node_[param]); + } + std::unordered_set new_params; + AnfNodePtrList parameters; + // Ignore the 1st and 2nd param of inputs(such as. partial graph) + for (size_t i = 2; i < inputs.size(); ++i) { + auto input = inputs[i]; + auto param = repl_node_[input]; + if (old_params.find(param) != old_params.end()) { + auto new_param = repl_map_node_[func_graph][param]; + parameters.push_back(new_param); + (void)new_params.insert(new_param); + } + } + for (auto ¶m : func_graph->parameters()) { + if (new_params.find(param) == new_params.end()) { + parameters.push_back(param); + } + } + func_graph->set_parameters(parameters); +} + +void Cloner::SetEdges(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + for (auto &node : func_graph->nodes()) { + if (node == nullptr) { + continue; + } + // Only cnode needed to be handled + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + auto &input = inputs[i]; + if (IsValueNode(input)) { + auto graph = GetValueNode(input); + auto &repl_func_graph = repl_map_func_graph_[func_graph]; + if (repl_func_graph.find(graph) != repl_func_graph.end()) { + transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); + } + } else { + auto &repl_node = repl_map_node_[func_graph]; + if (repl_node.find(input) != repl_node.end()) { + transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); + } + } + } + } +} + +void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { + AnfNodePtrList lift_params; + AnfNodePtrList input_params; + AddParameters(func_graph_user, params, &lift_params, &input_params); + AddInputs(func_graph_user, func_graph, input_params); + if (lift_params.empty()) { + return; + } + for (auto &cnode : func_graph_user->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params); + } +} + +void Cloner::Lift() { + for (auto &func_graph_params : repl_func_graph_params_) { + auto &func_graph = func_graph_params.first; + auto ¶ms = func_graph_params.second; + for (auto &cnode : func_graph->func_graph_cnodes_index()) { + LiftParameters(cnode.first->first->func_graph(), func_graph, params); + } + } +} + +void Cloner::LiftParameters() { + MS_EXCEPTION_IF_NULL(manager_); + transaction_ = manager_->Transact(); + const FuncGraphSet &func_graphs = manager_->func_graphs(); + for (auto &func_graph : func_graphs) { + GenParameters(func_graph); + } + Lift(); + for (auto &func_graph : func_graphs) { + 
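+    // With the lifted parameters generated, rewrite each graph's edges so former
+    // free variables point at them; everything is committed in one transaction.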
SetEdges(func_graph); + } + transaction_.Commit(); +} + +bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { + MS_EXCEPTION_IF_NULL(func_graph); + // Make sure only inline once + if (status_.count(func_graph) != 0) { + if (is_inline == status_[func_graph]) { + return false; + } + if (clone_all_used_graphs_) { + MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False."; + return false; + } + } + return true; +} + +void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + MS_EXCEPTION_IF_NULL(manager_); + const AnfNodeSet &nodes = func_graph->nodes(); + for (auto &node : nodes) { + CloneNode(node, target_func_graph); + } +} + +void Cloner::Run() { + if (todo_.empty()) { + return; + } + + if (type_ < kLifting) { + // Basic and Inline Clone + FuncGraphPtrList func_graphs; + (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), + [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); + manager_ = Manage(func_graphs, false); + CloneNodes(); + LinkEdges(); + SetDefaults(); + } else { + // Lifting Clone + CloneInfo item = todo_.back(); + manager_ = Manage(item.origin); + LiftParameters(); + } +} + +void Cloner::CloneNodes() { + while (!todo_.empty()) { + CloneInfo item = todo_.back(); + todo_.pop_back(); + + bool is_inline = (item.target != nullptr); + FuncGraphPtr func_graph = item.origin; + FuncGraphPtr target_func_graph = item.target; + (void)graph_set_.insert(func_graph); + + if (!CheckStatus(func_graph, is_inline)) { + continue; + } + + if (is_inline) { + InlineCloneParameters(func_graph, item.params); + CloneAllNodes(func_graph, target_func_graph); + } else { + SetFuncGraphInfo(func_graph, &target_func_graph); + CloneParameters(func_graph, target_func_graph); + CloneAllNodes(func_graph, target_func_graph); + CloneFuncGraphValueNodes(func_graph, target_func_graph); + CloneFuncGraphDefaultValues(func_graph, target_func_graph); + } + + CloneValueNodes(func_graph); + AddChildGraphs(func_graph); + AddTotalGraphs(func_graph); + status_[func_graph] = is_inline; + } +} + +void Cloner::LinkEdges() { + for (auto &node_pair : nodes_) { + CNodePtr old_node = node_pair.first; + CNodePtr new_node = node_pair.second; + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + for (auto &input : old_node->inputs()) { + auto &new_input = (repl_node_.count(input) == 0) ? 
input : repl_node_[input]; + new_node->add_input(new_input); + } + } +} + +// For the graphs cloned, update its default value map to the cloned nodes +void Cloner::SetDefaults() { + for (auto &item : graph_set_) { + MS_EXCEPTION_IF_NULL(item); + if (repl_func_graph_.count(item) != 0) { + for (auto ¶m_def : item->parameter_default_value()) { + MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); + if (repl_node_.count(param_def.second) != 0) { + repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); + } else { + repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second); + } + } + } + } +} + +AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { + MS_EXCEPTION_IF_NULL(root); + if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { + MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; + } + CloneNode(root, repl_func_graph_[root->func_graph()]); + auto iter = repl_node_.find(root); + if (iter != repl_node_.end()) { + return iter->second; + } + MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; +} + +AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time); +#endif + return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); +} + +FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time); +#endif + return ((repl_func_graph_.count(func_graph) == 0) ? 
func_graph : repl_func_graph_[func_graph]); +} + +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); + return cloner[func_graph]; +} + +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(target_func_graph); + Cloner cloner({}, false); + if (scope != nullptr) { + cloner.set_scope(scope); + } + cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline); + return cloner[func_graph->output()]; +} + +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + Cloner cloner({}, false); + cloner.AddClone(func_graph, nullptr, {}, kLifting); + return cloner[func_graph]; +} + +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphPtrList func_graphs = {func_graph}; + ClonerPtr cloner = + std::make_shared(func_graphs, false, false, false, std::make_shared(), relation); +#ifdef ENABLE_PROFILE + double time = GetTime(); +#endif + cloner->Run(); +#ifdef ENABLE_PROFILE + MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time); +#endif + return cloner; +} + +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { + MS_EXCEPTION_IF_NULL(func_graph); + TraceManager::DebugTrace(func_graph->debug_info(), relation); + auto new_func_graph = std::make_shared(); + TraceManager::EndTrace(); + + auto ¶meters = func_graph->parameters(); + (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { + MS_EXCEPTION_IF_NULL(param); + TraceManager::DebugTrace(std::make_shared(param->debug_info())); + (void)new_func_graph->add_parameter(); + TraceManager::EndTrace(); + }); + + Cloner cloner = Cloner(); + cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters()); + AnfNodePtr output = cloner[func_graph->output()]; + new_func_graph->set_output(output); + new_func_graph->set_has_vararg(func_graph->has_vararg()); + new_func_graph->set_has_kwarg(func_graph->has_kwarg()); + new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); + new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); + new_func_graph->set_is_generate(func_graph->is_generated()); + new_func_graph->set_stub(func_graph->stub()); + for (auto &item : func_graph->parameter_default_value()) { + new_func_graph->set_param_default_value(item.first, cloner[item.second]); + } + + if (MsContext::GetInstance()->is_multi_graph_sink()) { + if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { + new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); + } + } + + if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)); + } + + return new_func_graph; +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/core/ir/func_graph_cloner.h similarity index 100% rename from mindspore/ccsrc/ir/func_graph_cloner.h rename to mindspore/core/ir/func_graph_cloner.h diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc new file mode 100644 index 0000000000..27f9958a5e --- /dev/null +++ 
b/mindspore/core/ir/func_graph_extends.cc
@@ -0,0 +1,422 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/func_graph.h"
+
+#include <algorithm>
+#include <memory>
+#include <vector>
+
+#include "ir/manager.h"
+#include "ir/func_graph_cloner.h"
+#include "frontend/operator/ops.h"
+#include "utils/ordered_set.h"
+#include "abstract/abstract_value.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+
+#include "debug/anf_ir_dump.h"
+#include "debug/trace.h"
+#include "debug/draw.h"
+#include "debug/label.h"
+
+namespace mindspore {
+using mindspore::abstract::AbstractFunction;
+using mindspore::abstract::AbstractFunctionPtr;
+using mindspore::abstract::AnalysisContextPtr;
+using mindspore::abstract::PrimitiveAbstractClosure;
+using mindspore::abstract::VirtualAbstractClosure;
+
+AbstractFunctionPtr FuncGraph::abstract() {
+  AbstractBasePtrList args_spec_list;
+
+  for (auto &p : parameters_) {
+    MS_EXCEPTION_IF_NULL(p);
+    if (p->abstract() == nullptr) {
+      MS_LOG(ERROR) << "Parameter " << p->ToString() << " has no abstract value.";
+      return nullptr;
+    }
+    args_spec_list.push_back(p->abstract());
+  }
+
+  if (output() == nullptr) {
+    MS_LOG(ERROR) << "Func graph has no output.";
+    return nullptr;
+  }
+
+  return std::make_shared<VirtualAbstractClosure>(args_spec_list, output()->abstract());
+}
+
+abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) {
+  AnalysisContextPtr temp_context = context;
+  if (temp_context == nullptr) {
+    temp_context = abstract::AnalysisContext::DummyContext();
+  }
+  return std::make_shared<abstract::FuncGraphAbstractClosure>(shared_from_base<FuncGraph>(), temp_context);
+}
+
+void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) {
+  if (force_new_ret || return_ == nullptr) {
+    std::vector<AnfNodePtr> params({NewValueNode(prim::kPrimReturn), value});
+    FuncGraphPtr this_graph = shared_from_base<FuncGraph>();
+    return_ = this_graph->NewCNode(params);
+  } else {
+    if (manager_.lock()) {
+      manager_.lock()->SetEdge(return_, 1, value);
+    } else {
+      return_->set_input(1, value);
+    }
+  }
+
+  return_->set_abstract(value->abstract());
+
+  AnfNodePtr input0 = return_->input(0);
+
+  PrimitivePtr return_prim = prim::kPrimReturn;
+  auto f = std::make_shared<PrimitiveAbstractClosure>(return_prim, input0);
+  input0->set_abstract(f);
+}
+
+void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base<FuncGraph>()); }
+
+void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
+                                  std::vector<AnfNodePtr> *specialized_parameter_list,
+                                  std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes, int variable_args_count,
+                                  int pos_args_input_count) {
+  // If there is a variable argument, pass the input arguments that do not match the
+  // positional parameters to it as a tuple.
+  if (specialized_graph->has_vararg()) {
+    TraceManager::DebugTrace(
+      std::make_shared<TraceGenerateVarArg>(specialized_graph->GetVariableArgParameter()->debug_info()));
+    std::vector<AnfNodePtr> var_param_tuple_nodes;
+    var_param_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
+
+    if (variable_args_count < 0) {
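+      // Defensive check: a negative count should only arise from an inconsistent
+      // caller-side argument count, so fail loudly here.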
MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count + << " were given."; + } + // for python variable argument input , there is no upper limit + for (int i = 0; i < variable_args_count; ++i) { + ParameterPtr p = std::make_shared(specialized_graph); + std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); + p->set_name(param_name); + MS_EXCEPTION_IF_NULL(p->debug_info()); + p->debug_info()->set_name(param_name); + var_param_tuple_nodes.push_back(p); + MS_EXCEPTION_IF_NULL(specialized_parameter_list); + specialized_parameter_list->push_back(p); + } + auto var_tuple_param = specialized_graph->NewCNode(var_param_tuple_nodes); + (void)repl_nodes->emplace(specialized_graph->GetVariableArgParameter(), var_tuple_param); + TraceManager::EndTrace(); + } else if (variable_args_count > 0) { + MS_LOG(EXCEPTION) << "Function:" << this->ToString() << " takes " << this->GetPositionalArgsCount() + << " positional arguments, but " << pos_args_input_count << " were given."; + } +} + +void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes) { + std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; + std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; + + for (const auto &kwarg : kwarg_list) { + MS_EXCEPTION_IF_NULL(kwarg); + std::string kw_param_name = kwarg->get_key(); + MS_EXCEPTION_IF_NULL(specialized_graph); + AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name); + // if not find correspoding parameter node + if (param_node == nullptr) { + if (!has_kwarg()) { + MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name; + } else { + ParameterPtr p = std::make_shared(specialized_graph); + std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; + MS_EXCEPTION_IF_NULL(specialized_parameter_list); + auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), + [param_name](const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto param = node->cast(); + return param != nullptr && param->name() == param_name; + }); + if (find_kw_arg_in_list) { + MS_LOG(EXCEPTION) << "Multiply values for keyword argument:" << kw_param_name; + } + p->set_name(param_name); + p->debug_info()->set_name(param_name); + kwarg_keys_tuple_nodes.push_back(NewValueNode(kw_param_name)); + auto extract_node = + specialized_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), p}); + kwarg_values_tuple_nodes.push_back(extract_node); + specialized_parameter_list->push_back(p); + } + } else { + auto node_itr = std::find(specialized_parameter_list->begin(), specialized_parameter_list->end(), param_node); + // multiply values found given for parameter + if (node_itr != specialized_parameter_list->end()) { + MS_LOG(EXCEPTION) << "Multiply values for specific argument:" << kw_param_name; + } else { + specialized_parameter_list->push_back(param_node); + auto extract_node = specialized_graph->NewCNode( + {NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kw_param_name), param_node}); + (void)repl_nodes->emplace(param_node, extract_node); + } + } + } + + GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); +} + +void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr 
&specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes) { + if (has_kwarg()) { + MS_EXCEPTION_IF_NULL(specialized_graph); + TraceManager::DebugTrace( + std::make_shared(specialized_graph->GetVariableKwargParameter()->debug_info())); + auto make_tuple_keys = specialized_graph->NewCNode(kwarg_keys_tuple_nodes); + auto make_tuple_values = specialized_graph->NewCNode(kwarg_values_tuple_nodes); + auto make_dict_node = + specialized_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), make_tuple_keys, make_tuple_values}); + MS_EXCEPTION_IF_NULL(repl_nodes); + (void)repl_nodes->emplace(specialized_graph->GetVariableKwargParameter(), make_dict_node); + TraceManager::EndTrace(); + } +} + +bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { + // if the function does not have any vararg/kwarg/kwonly/default value/kw args input + // return the original graph + if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { + return false; + } + + // if the graph is generated for specific input, do not need to generate again + if (is_generated()) { + return false; + } + return true; +} + +void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes) { + MS_EXCEPTION_IF_NULL(specialized_graph); + for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { + auto param_node = specialized_graph->parameters()[i]; + MS_EXCEPTION_IF_NULL(param_node); + auto param_name = param_node->cast()->name(); + auto node_itr = std::find(specialized_parameter_list.begin(), specialized_parameter_list.end(), param_node); + if (node_itr != specialized_parameter_list.end()) { + continue; + } + if (param_name == specialized_graph->GetVariableArgName() || + param_name == specialized_graph->GetVariableKwargName()) { + continue; + } + auto default_value = specialized_graph->GetDefaultValueByName(param_name); + if (default_value == nullptr) { + MS_LOG(EXCEPTION) << "Miss argument input for parameter:" << param_name; + } + MS_EXCEPTION_IF_NULL(repl_nodes); + (void)repl_nodes->emplace(param_node, default_value); + } +} + +FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { + std::vector kwarg_list; + size_t arguments_count = args_spec_list.size(); + for (const auto &arg : args_spec_list) { + // if it is a keyword argument + MS_EXCEPTION_IF_NULL(arg); + if (arg->isa()) { + kwarg_list.push_back(dyn_cast(arg)); + } + } + if (!NeedGenerate(kwarg_list)) { + return shared_from_base(); + } + FuncGraphPtr specialized_graph = BasicClone(shared_from_base()); + size_t kwarg_count = kwarg_list.size(); + int pos_args_input_count = SizeToInt(arguments_count - kwarg_count - hyper_param_count()); + int pos_args_count = std::min(pos_args_input_count, this->GetPositionalArgsCount()); + int variable_args_count = pos_args_input_count - pos_args_count; + std::vector specialized_parameter_list; + std::unordered_map repl_nodes; + // the parameters that has arg input, copy from original parameters + for (size_t i = 0; i < IntToSize(pos_args_count); ++i) { + specialized_parameter_list.push_back(specialized_graph->parameters()[i]); + } + + GenerateVarParams(specialized_graph, &specialized_parameter_list, &repl_nodes, variable_args_count, + pos_args_input_count); + + GenerateKwParams(specialized_graph, &specialized_parameter_list, kwarg_list, 
&repl_nodes); + + GenerateDefaultValue(specialized_graph, specialized_parameter_list, &repl_nodes); + + // append hyper parameter to specialized_parameter_list + MS_EXCEPTION_IF_NULL(specialized_graph); + auto params = specialized_graph->parameters(); + (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), + std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); + + std::shared_ptr manager = mindspore::Manage(specialized_graph, false); + auto tr = manager->Transact(); + for (auto &node_pair : repl_nodes) { + MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" + << node_pair.second->DebugString(); + (void)tr.Replace(node_pair.first, node_pair.second); + } + tr.SetParameters(specialized_graph, specialized_parameter_list); + tr.Commit(); + specialized_graph->set_has_kwarg(false); + specialized_graph->set_has_vararg(false); + specialized_graph->set_kwonlyargs_count(0); + specialized_graph->ClearDefaultValues(); + specialized_graph->set_is_generate(true); + return specialized_graph; +} + +const char kPrimHasEffect[] = "_side_effect_flag"; + +bool FuncGraph::HasEffect(const CNodePtr &cnode) { + auto prim = GetCNodePrimitive(cnode); + if (prim != nullptr && prim->isa()) { + auto do_sig = prim->cast(); + auto prim_val = do_sig->function(); + if (prim_val != nullptr && prim_val->isa()) { + prim = prim_val->cast(); + } else { + prim = nullptr; + } + } + if (prim != nullptr) { + auto effect_val = prim->GetAttr(kPrimHasEffect); + if (effect_val && effect_val->isa()) { + auto effect_bool = GetValue(effect_val); + return effect_bool; + } + } + return false; +} + +std::shared_ptr> FindRoots(const std::vector &segment) { + std::shared_ptr> roots = std::make_shared>(segment); + for (const auto &node : segment) { + if (roots->size() == 1) { + return roots; + } + auto input_size = node->size(); + for (size_t i = 0; i < input_size; i++) { + auto in_node = node->input(i); + auto in_cnode = in_node->cast(); + if (in_cnode != nullptr) { + (void)roots->erase(in_cnode); + } + } + } + return roots; +} + +std::shared_ptr> FindLeaves(const std::vector &segment) { + std::shared_ptr> nodes = std::make_shared>(segment); + for (const auto &node : segment) { + if (nodes->size() == 1) { + return nodes; + } + if (IsPrimitiveCNode(node, prim::kPrimSwitch)) { + (void)nodes->erase(node); + continue; + } + auto input_size = node->size(); + for (size_t i = 0; i < input_size; i++) { + auto in_node = node->input(i); + if (!in_node->isa()) { + continue; + } + auto in_cnode = in_node->cast(); + if (in_cnode != nullptr) { + if (std::find(segment.begin(), segment.end(), in_cnode) != segment.end()) { + (void)nodes->erase(node); + break; + } + } + } + } + return nodes; +} + +void FuncGraph::ReleaseFullOrderToEffectOrder() { + MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << "."; + if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { + std::list depends_order; + std::vector segment; + for (const auto &cnode : order_) { + if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { + continue; + } + if (HasEffect(cnode)) { + MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << "."; + if (segment.size() > 0) { + auto roots = FindRoots(segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + depends_order.push_back(*iter); + } + } + segment.clear(); + depends_order.push_back(cnode); + } else { + MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << "."; + segment.push_back(cnode); + } + } + 
if (segment.size() > 1) { + auto roots = FindRoots(segment); + for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { + depends_order.push_back(*iter); + } + } + std::vector depend_inputs; + auto old_ret = output(); + for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) { + if (*iter != old_ret) { + depend_inputs.push_back(*iter); + } + } + set_flag(GRAPH_FLAG_HAS_EFFECT, false); + set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); + if (!depend_inputs.empty()) { + SetEffectDepends(depend_inputs); + } + } +} + +void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { + auto old_ret = output(); + std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; + (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); + auto new_ret = NewCNode(inputs); + auto mng = manager(); + if (mng) { + (void)mng->Replace(old_ret, new_ret); + } else { + return_->set_input(1, new_ret); + } +} +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph_py.cc b/mindspore/core/ir/func_graph_py.cc similarity index 100% rename from mindspore/ccsrc/ir/func_graph_py.cc rename to mindspore/core/ir/func_graph_py.cc diff --git a/mindspore/ccsrc/ir/lite/param_value_lite.h b/mindspore/core/ir/lite/param_value_lite.h similarity index 100% rename from mindspore/ccsrc/ir/lite/param_value_lite.h rename to mindspore/core/ir/lite/param_value_lite.h diff --git a/mindspore/ccsrc/ir/lite/tensor.cc b/mindspore/core/ir/lite/tensor.cc similarity index 100% rename from mindspore/ccsrc/ir/lite/tensor.cc rename to mindspore/core/ir/lite/tensor.cc diff --git a/mindspore/ccsrc/ir/lite/tensor.h b/mindspore/core/ir/lite/tensor.h similarity index 100% rename from mindspore/ccsrc/ir/lite/tensor.h rename to mindspore/core/ir/lite/tensor.h diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc new file mode 100644 index 0000000000..00c39679cd --- /dev/null +++ b/mindspore/core/ir/manager.cc @@ -0,0 +1,914 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ir/manager.h" + +#include +#include +#include + +#include "debug/trace_base.h" +#include "ir/func_graph.h" +#include "utils/profile.h" +#include "utils/convert_utils_base.h" +#include "frontend/operator/ops.h" + +namespace mindspore { + +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { + auto m = std::make_shared(func_graphs, manage); + m->Init(); + return m; +} + +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { + FuncGraphManagerPtr m = nullptr; + bool root = false; + + for (auto &fg : func_graphs) { + if (fg == nullptr) { + continue; + } + if (fg->manager() != nullptr) { + m = fg->manager(); + break; + } + } + + if (m == nullptr) { + std::vector tmp; + m = MakeManager(tmp, manage); + root = true; + } + + for (auto &fg : func_graphs) { + if (fg == nullptr) { + continue; + } + m->AddFuncGraph(fg, root); + } + return m; +} + +FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { + std::vector func_graphs = {func_graph}; + return Manage(func_graphs, manage); +} + +FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) + : roots_(roots), is_manage_(manage) { + Reset(); +} + +void FuncGraphManager::Reset() { + func_graphs_ = FuncGraphSet(); + all_nodes_ = AnfNodeSet(); + node_users_ = NodeUsersMap(); + + signals_ = std::make_shared(); + + func_graph_parents_total_ = std::make_shared(this); + func_graph_parent_ = std::make_shared(this); + children_ = std::make_shared(this); + scopes_ = std::make_shared(this); + free_variables_total_ = std::make_shared(this); + func_graphs_used_total_ = std::make_shared(this); + recursive_ = std::make_shared(this); + j_total_ = std::make_shared(this); + + limit_ = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); +} + +void FuncGraphManager::Init() { + auto roots = roots_; + roots_ = FuncGraphSet(); + + for (auto &fg : roots) { + AddFuncGraph(fg, true); + } +} + +FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); + func_graph_parents_total_->Recompute(fg); + MS_LOG(DEBUG) << "End func_graph_parents func graph " << fg->ToString(); + return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; +} + +FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(func_graph_parent_); + MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); + func_graph_parent_->Recompute(fg); + if (func_graph_parent_->parent_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager:" << fg->ToString(); + return nullptr; + } + MS_LOG(DEBUG) << "End parents func graph " << fg->ToString(); + return func_graph_parent_->parent_analysis()[fg]; +} + +FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(children_); + MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); + children_->Recompute(fg); + return children_->children_analysis()[fg]; +} + +FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(scopes_); + MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); + scopes_->Recompute(fg); + MS_LOG(DEBUG) << "End scopes func graph:" << fg->ToString(); + return scopes_->scope_analysis()[fg]; +} + +FVTotalMap &FuncGraphManager::free_variables_total() const { + 
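+  // Analyses are cached: Recompute() is a no-op until a change to the managed
+  // graphs invalidates the computer (see DepComputer below).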
MS_EXCEPTION_IF_NULL(free_variables_total_); + free_variables_total_->Recompute(); + return free_variables_total_->fv_total_analysis(); +} + +FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(func_graphs_used_total_); + func_graphs_used_total_->Recompute(fg); + return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; +} + +bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + recursive_->Recompute(fg); + if (recursive_->recursive_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return false; + } + return recursive_->recursive_analysis()[fg]; +} + +std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(fg); + if (recursive(fg)) { + if (!recursive_->recursive_map().count(fg)) { + auto trace = std::list(); + recursive_->CheckRecursiveGraphs(fg, &trace); + } + if (recursive_->recursive_map().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return nullptr; + } + return recursive_->recursive_map()[fg]; + } else { + return nullptr; + } +} + +bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { + MS_EXCEPTION_IF_NULL(j_total_); + MS_EXCEPTION_IF_NULL(fg); + j_total_->Recompute(fg); + if (j_total_->j_total_analysis().count(fg) == 0) { + MS_LOG(WARNING) << "This func graph is not in manager: " << fg->ToString(); + return false; + } + return j_total_->j_total_analysis()[fg]; +} + +// add a func graph to this manager, optionally as a root func graph. +void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { + MS_EXCEPTION_IF_NULL(func_graph); + if (is_root) { + roots_.add(func_graph); + } + if (func_graphs_.contains(func_graph)) { + return; + } + AddIntoManaged(func_graph); + std::vector para = func_graph->parameters(); + AcquireNodes(para); + std::vector return_vec({func_graph->get_return()}); + AcquireNodes(return_vec); +} + +// clear the all information in manager +void FuncGraphManager::Clear() { + func_graphs_.clear(); + all_nodes_.clear(); + node_users_.clear(); + roots_.clear(); + + signals_->InvalidateComputer(); +} + +void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { + MS_LOG(DEBUG) << "Start keep roots"; + bool root_exist = false; + for (auto &item : func_graphs) { + if (roots_.contains(item)) { + root_exist = true; + break; + } + } + + // if the new_root in roots_, we add new_root first, then calculate the func_graphs + // relation to new_root, remove the func_graphs not relation to new_root + // if the new_root not in roots_, we clear the all func_graphs in manager + // then add the new_root + if (root_exist || func_graphs.empty()) { + FuncGraphSet roots(func_graphs); + if (roots.empty()) { + roots = roots_; + } else { + roots_.clear(); + for (auto &item : roots) { + AddFuncGraph(item, true); + } + } + + FuncGraphSet keep; + for (auto &item : roots) { + MS_LOG(DEBUG) << "roots: " << item->ToString(); + keep.update(func_graphs_used_total(item)); +#ifdef DEBUG + for (auto &k : keep) { + MS_LOG(DEBUG) << "keep: " << k->ToString(); + } +#endif + } + MaybeDropFuncGraphs(func_graphs_ - keep, true); + } else { + Clear(); + FuncGraphSet roots(func_graphs); + for (auto &item : roots) { + AddFuncGraph(item, true); + } + } +} + +void FuncGraphManager::RemoveRoots() { + MS_LOG(DEBUG) << "Start remove roots"; + roots_.clear(); + MaybeDropFuncGraphs(func_graphs_, true); +} + 
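+// Illustrative usage sketch (not from the original file): a func graph belongs
+// to at most one manager at a time, and mutations go through that manager, e.g.
+//   FuncGraphManagerPtr mng = Manage(fg, true);  // fg->manager() == mng
+//   mng->Replace(old_node, new_node);            // updates users and analyses
+// Re-adding a graph that is already managed elsewhere only logs a warning below.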
+void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + if (is_manage_) { + if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { + MS_LOG(WARNING) << "A func graph can only have one manager."; + } + FuncGraphManagerPtr this_manager = shared_from_this(); + fg->set_manager(this_manager); + } + func_graphs_.add(fg); +} + +void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { + FuncGraphSet todo(func_graphs); + std::set dropped; + // int count = 0; + while (!todo.empty()) { + FuncGraphPtr func_graph = todo.pop(); + MS_EXCEPTION_IF_NULL(func_graph); + MS_LOG(DEBUG) << "Maybe drop func graph " << func_graph->ToString(); + if (roots_.contains(func_graph)) { + MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString(); + continue; + } + auto &users_cnode_index = func_graph->func_graph_cnodes_index(); + if (!users_cnode_index.empty() && !ignore_users) { + MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); + continue; + } + if (dropped.find(func_graph) != dropped.end()) { + MS_LOG(DEBUG) << "Func graph had been dropped " << func_graph->ToString(); + continue; + } + (void)dropped.insert(func_graph); + std::vector return_vec = {func_graph->get_return()}; + todo.update(MaybeDropNodes(return_vec)); + } + for (auto &fg : dropped) { + MS_EXCEPTION_IF_NULL(fg); + all_nodes_.difference_update(fg->parameters()); + (void)func_graphs_.erase(fg); + if (fg->manager().get() == this) { + fg->set_manager(nullptr); + } + MS_LOG(DEBUG) << "Func graph dropped " << fg->ToString(); + } +} + +void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(inp); + if (direction == kDecEdge) { + MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); + auto &users_node = node_users_[inp]; + if (!users_node.contains(make_pair(node, index))) { + return; + } + (void)users_node.erase(make_pair(node, index)); + DropEdge(node, index, inp); + } else { + MS_LOG(DEBUG) << "Add node " << node->ToString() << " input[" << index << "] " << inp->ToString(); + if (IsValueNode(inp)) { + MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); + AddFuncGraph(GetValueNode(inp)); + } + auto &users_node = node_users_[inp]; + users_node.add(make_pair(node, index)); + AddEdge(node, index, inp); + } +} + +void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + int index = 0; + for (auto &inp : cnode->inputs()) { + ProcessEdge(cnode, index, inp, direction); + ++index; + } + } +} + +IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { + if (all_nodes_.contains(node)) { + return EXCLUDE; + } else { + return FOLLOW; + } +} + +void FuncGraphManager::AcquireNodes(const std::vector &nodes) { + AnfNodeSet acq; + for (auto &node : nodes) { + AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit_)); + + all_nodes_.update(new_nodes); + acq.update(new_nodes); + } + + for (auto &node : acq) { + MS_EXCEPTION_IF_NULL(node); + auto fg = node->func_graph(); + if (fg != nullptr) { + fg->AddNode(node); + } + ProcessInputs(node, kIncEdge); + } +} + +FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { + AnfNodeSet nodes_ordered(nodes); + FuncGraphSetPtr func_graphs_to_check = std::make_shared(); + while 
(!nodes_ordered.empty()) { + AnfNodePtr node = nodes_ordered.pop(); + MS_EXCEPTION_IF_NULL(node); + if (!all_nodes_.contains(node)) { + continue; + } + AnfNodeIndexSet &users = node_users_[node]; + + std::vector parameters; + if (!users.empty() || + (node->isa() && parameters.end() != std::find(parameters.begin(), parameters.end(), node))) { + continue; + } + if (IsValueNode(node)) { + auto fg = GetValueNode(node); + func_graphs_to_check->add(fg); + MS_LOG(DEBUG) << "Set value of node " << node->DebugString() << " from func graph " << fg->ToString() + << " to null"; + } + ProcessInputs(node, kDecEdge); + (void)all_nodes_.erase(node); + if (node->func_graph() != nullptr) { + node->func_graph()->DropNode(node); + } + + if (node->isa()) { + auto cnode = node->cast(); + nodes_ordered.update(cnode->inputs()); + } + (void)node_users_.erase(node); + } + return func_graphs_to_check; +} + +void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { + auto tr = Transact(); + tr.SetParameters(fg, parameters); + tr.Commit(); +} + +void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter) { + auto tr = Transact(); + tr.AddParameter(fg, parameter); + tr.Commit(); +} + +bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + auto tr = Transact(); + bool success = tr.Replace(old_node, new_node); + if (success) { + tr.Commit(); + } + return success; +} + +void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { + auto tr = Transact(); + tr.SetEdge(node, index, value); + tr.Commit(); +} + +void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { + AnfNodePtr source_return = source->get_return(); + AnfNodePtr source_output = source->output(); + AnfNodePtr source_prim = source_return->cast()->input(0); + + int index = 0; + (void)node_users_[source_prim].erase(make_pair(source_return, index)); + DropEdge(source_return, index, source_prim); + index = 1; + (void)node_users_[source_output].erase(make_pair(source_return, index)); + DropEdge(source_return, index, source_output); + (void)all_nodes_.erase(source_return); + (void)node_users_.erase(source_return); + source->DropNode(source_return); + for (auto &node : source->nodes()) { + node->set_func_graph(target); + if (node->scope() == kDefaultScope) { + node->set_scope(scope); + } + } + + MoveAllNodes(source, target); + all_nodes_.difference_update(source->parameters()); + (void)func_graphs_.erase(source); + if (source->manager().get() == this) { + source->set_manager(nullptr); + } +} + +void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->AddValueNode(input); + if (IsValueNode(input)) { + auto used = GetValueNode(input); + used->AddFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if (fg->AddFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->AddJFuncGraph(used); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->AddFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr input) { + auto fg = node->func_graph(); + if (input->isa()) { + fg->DropValueNode(input); + if (IsValueNode(input)) { + auto used = GetValueNode(input); + used->DropFuncGraphCNodeIndex(std::make_shared(std::make_pair(node, index))); + if 
(fg->DropFuncGraphUsed(used)) { + signals_->InvalidateComputer(); + } + if (IsPrimitiveCNode(node, prim::kPrimJ)) { + fg->DropJFuncGraph(used); + } + } + } else if (fg != nullptr && fg != input->func_graph()) { + if (fg->DropFreeVariable(input)) { + signals_->InvalidateComputer(); + } + } +} + +void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { + target->CopyNodes(source); + target->CopyValueNodes(source); + target->CopyFuncGraphCNodesIndex(source); + target->CopyFreeVariables(source); + target->CopyFuncGraphsUsed(source); + target->CopyJFuncGraphs(source); + signals_->InvalidateComputer(); + source->ClearNodes(); + source->ClearValueNodes(); + source->ClearFuncGraphCNodesIndex(); + source->ClearFreeVariables(); + source->ClearFuncGraphsUsed(); + source->ClearJFuncGraphs(); +} + +FuncGraphTransaction FuncGraphManager::Transact() { + auto tr = FuncGraphTransaction(this); + return tr; +} + +void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, + EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { + for (auto &iter : changes) { + auto operation = iter.op; + auto args = iter.args; + switch (operation) { + case Change::kTxSetEdge: { + auto edge = args.cast(); + auto old_node = edge.root_node->input(edge.index); + (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; + (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; + (*rms)[old_node] += 1; + (*adds)[edge.new_node] += 1; + edge.root_node->set_input(edge.index, edge.new_node); + } break; + case Change::kTxSetParams: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + auto old_parameters = param.func_graph->parameters(); + for (auto &p : param.params) { + (*adds)[p] += 1; + } + for (auto &p : old_parameters) { + (*rms)[p] += 1; + } + param.func_graph->set_parameters(param.params); + } break; + case Change::kTxAddParam: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + (*adds)[param.param] += 1; + auto param_node = param.param->cast(); + param.func_graph->append_parameter(param_node); + } break; + default: + break; + } + } +} + +void FuncGraphManager::CommitChanges(const std::vector &changes) { + EdgeTupleCounter add_edges; + EdgeTupleCounter rm_edges; + Counter adds; + Counter rms; + ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); + + auto sub_edges = add_edges - rm_edges; + for (auto &iter : sub_edges) { + auto root_node = iter.first.first; + int index = iter.first.second.first; + auto new_node = iter.first.second.second; + ProcessEdge(root_node, index, new_node, kIncEdge); + } + + auto sub_nodes = adds - rms; + std::vector nodes; + (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); + + AcquireNodes(nodes); + + auto sub_edges_reverse = rm_edges - add_edges; + for (auto &iter : sub_edges_reverse) { + auto root_node = iter.first.first; + int index = iter.first.second.first; + auto old_node = iter.first.second.second; + ProcessEdge(root_node, index, old_node, kDecEdge); + } + + auto sub_nodes_reverse = rms - adds; + std::vector nodes_reverse; + + (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); + + auto drop_func_graphs = MaybeDropNodes(nodes_reverse); + MaybeDropFuncGraphs(*drop_func_graphs); +} + +void 
FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { + changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); +} + +void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m) { + changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); +} + +bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + FuncGraphPtr old_func_graph = old_node->func_graph(); + if (old_func_graph != nullptr && old_func_graph->get_return() == old_node) { + MS_LOG(WARNING) << "Cannot replace the return node of a func graph " << old_func_graph->ToString(); + return false; + } + auto users = manager_->node_users()[old_node]; + for (auto &node : users) { + SetEdge(node.first, node.second, new_node); + } + + return true; +} + +void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { + if (k < 0) { + MS_LOG(EXCEPTION) << "Invalid value k = " << k; + } + MS_EXCEPTION_IF_NULL(src_node); + auto cnode = src_node->cast(); + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed."; + } + changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)}); +} + +void FuncGraphTransaction::Commit() { + std::vector changes; + changes_.swap(changes); + manager_->CommitChanges(changes); +} + +DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) { + MS_EXCEPTION_IF_NULL(manager_); + manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); + validate_ = false; +} + +void DepComputer::Recompute() { + if (!validate_) { + RealRecompute(); + validate_ = true; + } +} + +void DepComputer::Recompute(const FuncGraphPtr &fg) { + if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { + RealRecompute(fg); + func_graphs_validate_[fg] = true; + } +} + +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) { + if (fg->seen_ == seen_num) { + return std::make_shared(); + } + FuncGraphSetPtr parents = std::make_shared(); + + // Append all the fvs in fg. + auto &fvs = fg->free_variables(); + for (auto fv : fvs) { + parents->add(fv.first->func_graph()); + } + + // Search the fv in fg's child func graph. + auto &fgs = fg->func_graphs_used(); + for (auto &item : fgs) { + fg->seen_ = seen_num; + auto gt = item.first; + parents->update(SeekParents(gt, seen_num)); + } + (void)parents->erase(fg); + return parents; +} + +void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { + MS_EXCEPTION_IF_NULL(fg); + func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration())); +} + +bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { + auto l1 = lhs.second.size(); + auto l2 = rhs.second.size(); + return l1 < l2; +} + +void ParentComputer::RealRecompute(FuncGraphPtr fg) { + this->parent_analysis_[fg] = nullptr; + // Note: must be a copy other than reference as it is modified thereafter. 
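+  // (A reference would alias the computer's cached set, and the pop()/erase()
+  // calls below would then corrupt the cached analysis.)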
+
+void ParentComputer::RealRecompute(FuncGraphPtr fg) {
+  this->parent_analysis_[fg] = nullptr;
+  // Note: must be a copy rather than a reference, as it is modified below.
+  auto deps = this->manager_->func_graph_parents_total(fg);
+
+  if (deps.empty()) {
+    this->parent_analysis_[fg] = nullptr;
+    return;
+  } else if (deps.size() == 1) {
+    this->parent_analysis_[fg] = deps.pop();
+    return;
+  } else {
+    // return the nearest parent
+    FuncGraphSet deps_copy(deps);
+    for (auto &dep : deps) {
+      auto parent_deps = this->manager_->func_graph_parents_total(dep);
+      for (auto &p_d : parent_deps) {
+        if (deps_copy.count(p_d)) {
+          (void)deps_copy.erase(p_d);
+        }
+      }
+      if (deps_copy.size() == 1) {
+        this->parent_analysis_[fg] = deps_copy.pop();
+        return;
+      }
+    }
+  }
+}
+
+void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
+  MS_EXCEPTION_IF_NULL(manager_);
+  auto used_fg_total = manager_->func_graphs_used_total(fg);
+  for (auto &used_fg : used_fg_total) {
+    if (manager_->parent(used_fg) == fg) {
+      children_analysis_[fg].add(used_fg);
+    }
+  }
+}
+
+void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
+  MS_EXCEPTION_IF_NULL(manager_);
+  auto &children = manager_->children(fg);
+
+  scope_analysis_[fg] = FuncGraphSet();
+  scope_analysis_[fg].add(fg);
+  for (auto &child : children) {
+    scope_analysis_[fg].add(child);
+  }
+}
+
+void FVTotalComputer::RealRecompute() {
+  auto manager = DepComputer::manager_;
+  MS_EXCEPTION_IF_NULL(manager);
+
+  for (auto &fg : manager->func_graphs()) {
+    fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
+  }
+
+  for (auto &fg : manager->func_graphs()) {
+    // add all free variable nodes
+    AnfNodeCounterMap items = fg->free_variables();
+    for (auto &iter : items) {
+      auto curr = fg;
+      while (curr != nullptr) {
+        fv_total_analysis_[curr][iter.first] = iter.second;
+        curr = manager->parent(curr);
+        if (curr != nullptr) {
+          const AnfNodeSet &all_nodes = curr->nodes();
+          if (all_nodes.contains(iter.first)) {
+            break;
+          }
+        }
+      }
+    }
+
+    // add all FGs of free variables
+    auto &used = fg->func_graphs_used();
+    for (auto &iter : used) {
+      auto p = manager->parent(iter.first);
+      if (p == nullptr) {
+        continue;
+      }
+      auto curr = fg;
+      while (curr != p) {
+        fv_total_analysis_[curr][iter.first] = iter.second;
+        curr = manager->parent(curr);
+      }
+    }
+  }
+}
+
+void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
+  MS_EXCEPTION_IF_NULL(manager_);
+  std::vector<FuncGraphPtr> todo;
+  std::vector<FuncGraphPtr> todo_new;
+
+  todo.push_back(fg);
+  while (!todo.empty()) {
+    todo_new.clear();
+    for (auto &gt : todo) {
+      for (auto &item : gt->func_graphs_used()) {
+        auto used_fg = item.first;
+        if (used_fg == fg) {
+          func_graph_used_total_analysis_[fg].add(used_fg);
+          continue;
+        }
+        if (func_graph_used_total_analysis_[fg].count(used_fg) == 0) {
+          todo_new.push_back(used_fg);
+        }
+        MS_LOG(DEBUG) << fg->ToString() << " add func graph " << used_fg->ToString();
+        func_graph_used_total_analysis_[fg].add(used_fg);
+      }
+    }
+    todo = todo_new;
+  }
+}
+
+bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) {
+  MS_EXCEPTION_IF_NULL(manager);
+  std::vector<FuncGraphPtr> todo;
+  std::vector<FuncGraphPtr> todo_new;
+  todo.push_back(fg);
+  FuncGraphSet used_total;
+  while (!todo.empty()) {
+    todo_new.clear();
+    for (auto &gt : todo) {
+      for (auto &item : gt->func_graphs_used()) {
+        auto used_g = item.first;
+        if (used_g == fg) {
+          return true;
+        }
+        if (used_total.count(used_g) == 0) {
+          todo_new.push_back(used_g);
+        }
+        used_total.add(used_g);
+      }
+    }
+    todo = todo_new;
+  }
+  return false;
+}
+
+void RecursiveComputer::RealRecompute(FuncGraphPtr fg) {
+  this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg);
+}
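
[Editor's sketch] CheckRecursive is a worklist reachability test: fg is recursive exactly when fg can reach itself through the graphs it uses. The same loop over a plain adjacency map (illustrative types only):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// Worklist reachability: does `start` reach itself through `uses` edges?
bool IsRecursive(const std::map<std::string, std::vector<std::string>> &uses, const std::string &start) {
  std::set<std::string> seen;
  std::vector<std::string> todo{start};
  while (!todo.empty()) {
    std::vector<std::string> todo_new;
    for (const auto &g : todo) {
      auto it = uses.find(g);
      if (it == uses.end()) continue;
      for (const auto &used : it->second) {
        if (used == start) return true;  // found a cycle back to start
        if (seen.insert(used).second) todo_new.push_back(used);
      }
    }
    todo = todo_new;
  }
  return false;
}

int main() {
  std::map<std::string, std::vector<std::string>> uses{{"f", {"g"}}, {"g", {"f"}}};
  std::cout << std::boolalpha << IsRecursive(uses, "f") << "\n";  // true
  return 0;
}
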
+
+void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<FuncGraphPtr> *trace) {
+  MS_EXCEPTION_IF_NULL(trace);
+  auto res = std::find(trace->begin(), trace->end(), fg);
+  // find recursive
+  if (res != trace->end()) {
+    auto recur_ptr = std::make_shared<std::list<FuncGraphPtr>>(res, trace->end());
+    for (auto iter = res; iter != trace->end(); (void)iter++) {
+      MS_LOG(DEBUG) << "Recursive graph " << (*iter)->ToString();
+      recursive_map_[*iter] = recur_ptr;
+    }
+  } else {
+    trace->push_back(fg);
+    auto &items = fg->func_graphs_used();
+    for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
+      CheckRecursiveGraphs(iter->first, trace);
+    }
+    trace->pop_back();
+    if (!recursive_map_.count(fg)) {
+      recursive_map_[fg] = nullptr;
+    }
+  }
+}
+
+bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
+  if (fg->seen_ == seen_num) {
+    MS_LOG(DEBUG) << fg->ToString() << " has already been checked";
+    return false;
+  }
+  auto &j_fgs = fg->j_func_graphs();
+  if (!j_fgs.empty()) {
+    // check g1->J(fg)->g2->g cycle;
+    auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
+      return iter.first->seen_ != seen_num;
+    });
+    if (contains_j != j_fgs.end()) {
+      MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
+      return true;
+    }
+  }
+  fg->seen_ = seen_num;
+
+  // check if func graphs used contains J(func_graph);
+  for (auto &item : fg->func_graphs_used()) {
+    auto used_g = item.first;
+    if (SeekJ(used_g, seen_num)) {
+      MS_LOG(DEBUG) << fg->ToString() << " uses func graph " << used_g->ToString()
+                    << " which contains J(func_graph)";
+      return true;
+    }
+  }
+  MS_LOG(DEBUG) << fg->ToString() << " doesn't contain J(func_graph)";
+  return false;
+}
+
+void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
+  this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
+}
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/core/ir/manager.h
similarity index 100%
rename from mindspore/ccsrc/ir/manager.h
rename to mindspore/core/ir/manager.h
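
[Editor's sketch] SeekJ shows the generation-stamped traversal used throughout this file: instead of clearing a visited set for every query, NewFgSeenGeneration bumps a counter and each graph's seen_ stamp is compared against it. The idea in isolation (toy node type):

#include <iostream>
#include <vector>

// Generation-stamped DFS: bump a generation counter per query and compare
// stamps, rather than resetting a visited flag on every node.
struct Node {
  std::vector<Node *> used;
  bool has_j = false;
  size_t seen = 0;
};

bool SeekJ(Node *n, size_t gen) {
  if (n->seen == gen) return false;  // already checked in this query
  n->seen = gen;
  if (n->has_j) return true;
  for (Node *u : n->used) {
    if (SeekJ(u, gen)) return true;
  }
  return false;
}

int main() {
  Node a, b;
  a.used.push_back(&b);
  b.has_j = true;
  size_t generation = 0;
  std::cout << std::boolalpha << SeekJ(&a, ++generation) << "\n";  // true
  return 0;
}
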
diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc
new file mode 100644
index 0000000000..df07ea1b67
--- /dev/null
+++ b/mindspore/core/ir/meta_func_graph.cc
@@ -0,0 +1,58 @@
+/**
+ * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+ *
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/meta_func_graph.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+
+// namespace to support intermediate representation definition
+namespace mindspore {
+abstract::AbstractBasePtr MetaFuncGraph::MakeAbstractClosure(const AnfNodePtr &anf_node) {
+  abstract::MetaFuncGraphAbstractClosurePtr meta_func_graph_fn;
+  if (anf_node == nullptr) {
+    meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>());
+  } else {
+    meta_func_graph_fn =
+      std::make_shared<abstract::MetaFuncGraphAbstractClosure>(shared_from_base<MetaFuncGraph>(), anf_node->scope());
+  }
+  return meta_func_graph_fn;
+}
+
+FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
+  TypePtrList types;
+  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
+                       [](const AbstractBasePtr &arg) -> TypePtr {
+                         MS_EXCEPTION_IF_NULL(arg);
+                         return arg->BuildType();
+                       });
+  // filter unsafe characters in log print since name_ is from outside
+  auto iter = cache_.find(types);
+  if (iter == cache_.end()) {
+    FuncGraphPtr fg = GenerateFromTypes(types);
+    MS_EXCEPTION_IF_NULL(fg);
+    MS_LOG(INFO) << "MetaFuncgraph: cache miss for types: " << mindspore::ToString(args_spec_list)
+                 << ", g: " << fg->ToString();
+    cache_[types] = fg;
+    return fg;
+  } else {
+    MS_LOG(DEBUG) << "MetaFuncgraph: cache hit for types: " << mindspore::ToString(args_spec_list)
+                  << ", g: " << iter->second->ToString();
+    return iter->second;
+  }
+}
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/core/ir/meta_func_graph.h
similarity index 100%
rename from mindspore/ccsrc/ir/meta_func_graph.h
rename to mindspore/core/ir/meta_func_graph.h
diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc
similarity index 100%
rename from mindspore/ccsrc/ir/meta_tensor.cc
rename to mindspore/core/ir/meta_tensor.cc
diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h
similarity index 100%
rename from mindspore/ccsrc/ir/meta_tensor.h
rename to mindspore/core/ir/meta_tensor.h
diff --git a/mindspore/ccsrc/ir/meta_tensor_extends.cc b/mindspore/core/ir/meta_tensor_extends.cc
similarity index 100%
rename from mindspore/ccsrc/ir/meta_tensor_extends.cc
rename to mindspore/core/ir/meta_tensor_extends.cc
diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/core/ir/named.cc
similarity index 100%
rename from mindspore/ccsrc/ir/named.cc
rename to mindspore/core/ir/named.cc
diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/core/ir/named.h
similarity index 100%
rename from mindspore/ccsrc/ir/named.h
rename to mindspore/core/ir/named.h
diff --git a/mindspore/ccsrc/ir/optimizer_caller.h b/mindspore/core/ir/optimizer_caller.h
similarity index 100%
rename from mindspore/ccsrc/ir/optimizer_caller.h
rename to mindspore/core/ir/optimizer_caller.h
diff --git a/mindspore/ccsrc/ir/param_value.h b/mindspore/core/ir/param_value.h
similarity index 100%
rename from mindspore/ccsrc/ir/param_value.h
rename to mindspore/core/ir/param_value.h
diff --git a/mindspore/ccsrc/ir/param_value_py.cc b/mindspore/core/ir/param_value_py.cc
similarity index 100%
rename from mindspore/ccsrc/ir/param_value_py.cc
rename to mindspore/core/ir/param_value_py.cc
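
[Editor's sketch] GenerateFuncGraph above specializes a meta graph per argument-type list and memoizes the result in cache_. The lookup-or-build shape reduced to a std::map keyed by a type-name vector (hypothetical stand-ins for TypePtrList/FuncGraphPtr):

#include <iostream>
#include <map>
#include <string>
#include <vector>

using TypeList = std::vector<std::string>;

// Stand-in for the real graph builder.
std::string GenerateFromTypes(const TypeList &types) {
  return "graph(" + std::to_string(types.size()) + " args)";
}

// Cache specializations by the list of argument types, as cache_[types] does.
const std::string &Specialize(std::map<TypeList, std::string> *cache, const TypeList &types) {
  auto iter = cache->find(types);
  if (iter == cache->end()) {
    iter = cache->emplace(types, GenerateFromTypes(types)).first;  // cache miss: build once
  }
  return iter->second;
}

int main() {
  std::map<TypeList, std::string> cache;
  std::cout << Specialize(&cache, {"Int32", "Int32"}) << "\n";
  std::cout << Specialize(&cache, {"Int32", "Int32"}) << "\n";  // cache hit
  return 0;
}
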
diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h
new file mode 100644
index 0000000000..94ba4a381a
--- /dev/null
+++ b/mindspore/core/ir/pattern_matcher.h
@@ -0,0 +1,310 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
+#define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
+
+#include <tuple>
+#include <vector>
+
+#include "ir/anf.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+
+///
+/// Base class for all recognizable patterns.
+/// We implement an Expression Template approach using static polymorphism based on
+/// the Curiously Recurring Template Pattern (CRTP) which "achieves a similar effect
+/// to the use of virtual functions without the costs..." as described in:
+/// https://en.wikipedia.org/wiki/Expression_templates and
+/// https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
+/// The TryCapture function tries to capture the pattern with the given node.
+/// The GetNode function builds a new node using the captured values.
+///
+
+template <typename T>
+class PBase {
+ public:
+  bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) {
+    return func(get_object().GetNode(node));
+  }
+
+  const T &get_object() const { return *static_cast<const T *>(this); }
+
+  template <typename TN>
+  bool TryCapture(const TN &value) const {
+    get_object().Reset();
+    return get_object().TryCapture_(value);
+  }
+
+  using Internal = T;
+};
+
+template <typename T>
+class PIsEqual {
+ public:
+  bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
+};
+
+template <typename T>
+class PatternNode : public PBase<PatternNode<T> > {
+ public:
+  T GetNode(const AnfNodePtr &node) const {
+    if (!captured_) {
+      MS_EXCEPTION(ValueError) << "A Pattern wasn't captured for this Token before the call to GetNode.";
+    }
+    return captured_node_;
+  }
+
+  bool TryCapture_(const T &node) const {
+    if (!captured_) {
+      captured_node_ = node;
+      captured_ = true;
+      return true;
+    }
+    return PIsEqual<T>()(captured_node_, node);
+  }
+
+  void Reset() const { captured_ = false; }
+  using Internal = const PatternNode<T> &;
+
+ protected:
+  mutable T captured_node_;
+  mutable bool captured_{false};
+};
+
+template <typename T, typename T2>
+class PBinOperation : public PBase<PBinOperation<T, T2> > {
+ public:
+  PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {}
+
+  AnfNodePtr GetNode(const AnfNodePtr &node) const {
+    AnfNodePtr lhs = x_.GetNode(node->func_graph());
+    AnfNodePtr rhs = y_.GetNode(node->func_graph());
+    AnfNodePtrList list = {prim_->cast(), lhs, rhs};
+    return NewCNode(list, node->func_graph());
+  }
+
+  bool TryCapture_(const AnfNodePtr &node) const {
+    if (IsPrimitiveCNode(node, prim_)) {
+      auto cnode = node->cast<CNodePtr>();
+      auto inputs = cnode->inputs();
+      if (inputs.size() == 3) {
+        // Binary Prim assumes only two inputs
+        if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) {
+          return false;
+        }
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void Reset() const {
+    x_.Reset();
+    y_.Reset();
+  }
+
+ private:
+  const PrimitivePtr prim_;
+  typename T::Internal x_;
+  typename T2::Internal y_;
+};
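
[Editor's sketch] The PBase arrangement above is plain CRTP: the base class knows the concrete pattern type at compile time, so get_object() dispatch costs nothing at runtime. A free-standing illustration of the same mechanism (toy names, unrelated to the header's classes):

#include <iostream>

// CRTP: the base knows the concrete type T statically; no virtual call needed.
template <typename T>
struct MatcherBase {
  bool TryCapture() const { return static_cast<const T *>(this)->TryCaptureImpl(); }
};

struct AlwaysMatches : MatcherBase<AlwaysMatches> {
  bool TryCaptureImpl() const { return true; }
};

int main() {
  AlwaysMatches m;
  std::cout << std::boolalpha << m.TryCapture() << "\n";  // true, dispatched statically
  return 0;
}
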
+
+///
+/// Helper functions to apply a pattern function on all elements of a tuple
+///
+namespace tuple_utils {
+template <bool stop, size_t Index, typename Func>
+struct apply_func_tuple_item {
+  template <typename TTuple>
+  static void apply(Func *func, const TTuple &tuple) {
+    (*func)(Index, std::get<Index>(tuple));
+    apply_func_tuple_item<(Index + 1) == std::tuple_size<TTuple>::value, (Index + 1), Func>::apply(func, tuple);
+  }
+};
+
+template <size_t Index, typename Func>
+struct apply_func_tuple_item<true, Index, Func> {
+  template <typename TTuple>
+  static void apply(Func *func, const TTuple &tuple) {}
+};
+
+template <typename Func, typename TTuple>
+inline void apply_func_tuple(Func *func, const TTuple &tuple) {
+  apply_func_tuple_item<std::tuple_size<TTuple>::value == 0, 0, Func>::apply(func, tuple);
+}
+
+struct PTupleResetCapture {
+  template <typename T>
+  void operator()(size_t i, const T &pattern) const {
+    pattern.Reset();
+  }
+};
+
+struct PTupleCapture {
+  explicit PTupleCapture(const AnfNodePtrList tuple) : tuple_(tuple) {}
+
+  template <typename TPattern>
+  void operator()(size_t i, const TPattern &pattern) {
+    // Check if the first node is a Primitive
+    if (i == 0 && tuple_[i]->isa<Primitive>()) {
+      auto prim = tuple_[i]->cast<PrimitivePtr>();
+      if (tuple_[i] != pattern.GetNode(tuple_[i])) {
+        captured_ = false;
+      }
+    } else {
+      captured_ = captured_ && pattern.TryCapture_(tuple_[i]);
+    }
+  }
+
+  const AnfNodePtrList tuple_;
+  bool captured_{true};
+};
+
+struct PTupleGetNode {
+  explicit PTupleGetNode(const AnfNodePtr &node) : node_(node) {}
+
+  template <typename TPattern>
+  void operator()(size_t, const TPattern &pattern) {
+    args_.push_back(pattern.GetNode(node_));
+  }
+
+  const AnfNodePtr &node_;
+  std::vector<AnfNodePtr> args_;
+};
+}  // namespace tuple_utils
+
+template <typename... TArgs>
+class PCNode : public PBase<PCNode<TArgs...> > {
+ public:
+  explicit PCNode(const TArgs &... args) : args_(args...) {}
+
+  AnfNodePtr GetNode(const AnfNodePtr &node) const {
+    tuple_utils::PTupleGetNode get_node(node);
+    tuple_utils::apply_func_tuple(&get_node, args_);
+    return NewCNode(get_node.args_, node->func_graph());
+  }
+
+  bool TryCapture_(const AnfNodePtr &node) const {
+    if (node->isa<CNode>()) {
+      auto cnode = node->cast<CNodePtr>();
+      auto inputs = cnode->inputs();
+      if (inputs.size() != sizeof...(TArgs)) {
+        return false;
+      }
+      tuple_utils::PTupleCapture capture_func(inputs);
+      tuple_utils::apply_func_tuple(&capture_func, args_);
+      return capture_func.captured_;
+    }
+
+    return false;
+  }
+
+  void Reset() const {
+    tuple_utils::PTupleResetCapture reset;
+    tuple_utils::apply_func_tuple(&reset, args_);
+  }
+
+ private:
+  std::tuple<typename TArgs::Internal...> args_;
+};
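
[Editor's sketch] apply_func_tuple walks a std::tuple at compile time via recursive specialization because the header targets pre-C++17 compilers. Under C++17 the same indexed iteration collapses to a fold expression; a standalone equivalent for comparison (assumes C++17):

#include <cstddef>
#include <iostream>
#include <string>
#include <tuple>

// C++17 fold doing what apply_func_tuple_item's recursion does: call
// func(index, element) for every tuple element, left to right.
template <typename Func, typename... Ts>
void for_each_indexed(Func func, const std::tuple<Ts...> &tuple) {
  std::size_t i = 0;
  std::apply([&](const Ts &... elems) { (func(i++, elems), ...); }, tuple);
}

int main() {
  auto t = std::make_tuple(1, 2.5, std::string("x"));
  for_each_indexed([](std::size_t i, const auto &v) { std::cout << i << ": " << v << "\n"; }, t);
  return 0;
}
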
+
+template <typename... TArgs>
+class PPrimitive : public PBase<PPrimitive<TArgs...> > {
+ public:
+  explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
+
+  AnfNodePtr GetNode(const AnfNodePtr &node) const {
+    tuple_utils::PTupleGetNode get_node(node);
+    tuple_utils::apply_func_tuple(&get_node, args_);
+    auto prim_cnode = get_node.args_;
+    prim_cnode.insert(prim_cnode.begin(), NewValueNode(prim_));
+    return NewCNode(prim_cnode, node->func_graph());
+  }
+
+  bool TryCapture_(const AnfNodePtr &node) const {
+    if (IsPrimitiveCNode(node, prim_)) {
+      auto cnode = node->cast<CNodePtr>();
+      auto inputs = cnode->inputs();
+      if ((inputs.size() - 1) != sizeof...(TArgs)) {
+        return false;
+      }
+
+      AnfNodePtrList rest(inputs.begin() + 1, inputs.end());
+      tuple_utils::PTupleCapture capture_func(rest);
+      tuple_utils::apply_func_tuple(&capture_func, args_);
+
+      return capture_func.captured_;
+    }
+
+    return false;
+  }
+
+  void Reset() const {
+    tuple_utils::PTupleResetCapture reset;
+    tuple_utils::apply_func_tuple(&reset, args_);
+  }
+
+ private:
+  const PrimitivePtr prim_;
+  std::tuple<typename TArgs::Internal...> args_;
+};
+
+// Macro for binary operation functions
+#define BIN_OPERATION_PATTERN(Operator, MSPrimitive)                              \
+  template <typename T, typename T2>                                              \
+  inline PBinOperation<T, T2> Operator(const PBase<T> &x, const PBase<T2> &y) {   \
+    return PBinOperation<T, T2>(MSPrimitive, x.get_object(), y.get_object());     \
+  }
+
+// Arithmetic operations
+BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd);
+BIN_OPERATION_PATTERN(operator*, prim::kPrimMul);
+
+// Macros for match and replace
+#define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith)  \
+  if ((CaptureNode).TryCapture(OrigNode)) {                \
+    return (ReplaceWith).GetNode(OrigNode);                \
+  }
+
+#define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition)  \
+  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {               \
+    return (ReplaceWith).GetNode(OrigNode);                              \
+  }
+
+#define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode)  \
+  if ((CaptureNode).TryCapture(OrigNode)) {                                             \
+    if ((Condition)) {                                                                  \
+      return (ReplaceWith).GetNode(OrigNode);                                           \
+    }                                                                                   \
+    return (ElseNode).GetNode(OrigNode);                                                \
+  }
+
+#define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda)  \
+  if ((CaptureNode).TryCapture(OrigNode)) {                  \
+    return (Lambda)();                                       \
+  }
+
+#define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition)  \
+  if ((CaptureNode).TryCapture(OrigNode) && (Condition)) {                 \
+    return (Lambda)();                                                     \
+  }
+
+}  // namespace mindspore
+
+#endif  // #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_
diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/core/ir/primitive.cc
similarity index 100%
rename from mindspore/ccsrc/ir/primitive.cc
rename to mindspore/core/ir/primitive.cc
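
[Editor's sketch] A hypothetical rewrite rule connecting PatternNode, PPrimitive and MATCH_REPLACE_IF. This is not from the patch: it assumes the MindSpore tree (ir/anf.h plus this header), C++17 class template argument deduction, and an invented IsZeroValueNode helper; it only illustrates how the pieces are meant to compose.

// Fold x + 0 -> x (sketch only; IsZeroValueNode is hypothetical).
AnfNodePtr SimplifyAddZero(const AnfNodePtr &node) {
  PatternNode<AnfNodePtr> x;
  PatternNode<AnfNodePtr> zero;
  auto pattern = PPrimitive(prim::kPrimTensorAdd, x, zero);
  MATCH_REPLACE_IF(node, pattern, x, IsZeroValueNode(zero.GetNode(node)));
  return nullptr;  // no rewrite applies
}
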
diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h
new file mode 100644
index 0000000000..5471b58063
--- /dev/null
+++ b/mindspore/core/ir/primitive.h
@@ -0,0 +1,152 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_H_
+#define MINDSPORE_CCSRC_IR_PRIMITIVE_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "ir/dtype/type.h"
+#include "abstract/abstract_value.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "utils/base_ref_extends.h"
+
+namespace mindspore {
+// Supported meta type
+enum PrimType {
+  kPrimTypeUnknown = 0,
+  kPrimTypeBegin = kTypeUnknown,
+  kPrimTypeBuiltIn,        // Built-in primitive operator
+  kPrimTypePyInferShape,   // Custom-defined primitive operator (Python infer-shape)
+  kPrimTypePyInferTensor,  // Custom-defined primitive operator (Python infer-tensor)
+  kPrimTypeUserCustom
+};
+
+class Primitive : public Named {
+ public:
+  explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
+      : Named(name),
+        is_base_(is_base),
+        has_signature_(false),
+        prim_type_(prim_type),
+        record_evaluate_add_attr_(false) {}
+
+  Primitive(const Primitive &prim)
+      : Named(prim),
+        attrs_(prim.attrs_),
+        instance_name_(prim.instance_name_),
+        is_base_(prim.is_base_),
+        has_signature_(prim.has_signature_),
+        prim_type_(prim.prim_type_),
+        record_evaluate_add_attr_(false) {}
+
+  MS_DECLARE_PARENT(Primitive, Named);
+
+  abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
+  std::string ToString() const override { return name(); }
+  void BeginRecordAddAttr() {
+    evaluate_added_attrs_.clear();
+    record_evaluate_add_attr_ = true;
+  }
+  void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
+  Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
+    attrs_[name] = attr;
+    if (record_evaluate_add_attr_) {
+      evaluate_added_attrs_[name] = attr;
+    }
+    return *this;
+  }
+
+  Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
+    for (auto &attr : attrs) {
+      attrs_[attr.first] = attr.second;
+    }
+    return *this;
+  }
+
+  void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
+  void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
+
+  ValuePtr GetAttr(const std::string &attrName) const {
+    auto iter = attrs_.find(attrName);
+    return iter == attrs_.cend() ? nullptr : iter->second;
+  }
+
+  const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
+  const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
+
+  // Whether this Primitive has any attribute; some Primitives, like scalar_add and return, have none.
+  bool HasAttr() const { return !attrs_.empty(); }
+  bool HasAttr(const std::string &attrName) const {
+    auto iter = attrs_.find(attrName);
+    return !(iter == attrs_.cend());
+  }
+  void set_prim_type(const PrimType t) { prim_type_ = t; }
+  void set_instance_name(const std::string s) { instance_name_ = s; }
+  bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
+  bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }
+  bool IsCustomPrim() const { return prim_type_ == kPrimTypeUserCustom; }
+
+  PrimType prim_type() const { return prim_type_; }
+  std::string instance_name() const { return instance_name_; }
+  std::string GetAttrsText() const;
+  bool operator==(const Value &other) const override;
+  bool operator==(const Primitive &other) const;
+  ~Primitive() override = default;
+
+  void set_has_signature(bool has_signature) { has_signature_ = has_signature; }
+  bool has_signature() const { return has_signature_; }
+  bool is_base() const { return is_base_; }
+  virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "call an empty function!"; }
+  virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "call an empty function!"; }
+
+ protected:
+  std::unordered_map<std::string, ValuePtr> attrs_;
+  std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
+
+ private:
+  std::string instance_name_;
+  bool is_base_;
+  bool has_signature_;
+  PrimType prim_type_;
+  bool record_evaluate_add_attr_;
+};
+
+inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
+  os << *p;
+  return os;
+}
+
+struct PrimitiveEqual {
+  bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
+    MS_EXCEPTION_IF_NULL(t1);
+    MS_EXCEPTION_IF_NULL(t2);
+    return t1->name() == t2->name();
+  }
+};
+
+struct PrimitiveHasher {
+  std::size_t operator()(PrimitivePtr const &prim) const {
+    MS_EXCEPTION_IF_NULL(prim);
+    return prim->Hash();
+  }
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_IR_PRIMITIVE_H_
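
[Editor's sketch] PrimitiveHasher and PrimitiveEqual exist so a PrimitivePtr can key hash containers by primitive name rather than by pointer identity. A minimal self-contained analogue with a simplified Named stand-in:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Name-based hashing/equality for shared_ptr keys: two distinct pointers to
// "Mul" land in the same bucket, mirroring PrimitiveHasher/PrimitiveEqual.
struct Prim {
  std::string name;
};
using PrimPtr = std::shared_ptr<Prim>;

struct PrimHasher {
  std::size_t operator()(const PrimPtr &p) const { return std::hash<std::string>()(p->name); }
};
struct PrimEqual {
  bool operator()(const PrimPtr &a, const PrimPtr &b) const { return a->name == b->name; }
};

int main() {
  std::unordered_map<PrimPtr, int, PrimHasher, PrimEqual> counts;
  counts[std::make_shared<Prim>(Prim{"Mul"})] += 1;
  counts[std::make_shared<Prim>(Prim{"Mul"})] += 1;  // distinct pointer, same key
  std::cout << counts.size() << "\n";                // 1
  return 0;
}
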
diff --git a/mindspore/core/ir/primitive_extends.cc b/mindspore/core/ir/primitive_extends.cc
new file mode 100644
index 0000000000..8e04ba8233
--- /dev/null
+++ b/mindspore/core/ir/primitive_extends.cc
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/primitive.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+
+namespace mindspore {
+abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) {
+  auto prim_func = std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), anf_node);
+  return prim_func;
+}
+}  // namespace mindspore
diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc
new file mode 100644
index 0000000000..1a97487ddc
--- /dev/null
+++ b/mindspore/core/ir/primitive_py.cc
@@ -0,0 +1,195 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/primitive_py.h"
+#include
+#include
+#include "ir/signature.h"
+#include "frontend/operator/ops.h"
+#include "./common.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "pipeline/jit/parse/data_converter.h"
+#include "pybind11/pytypes.h"
+#include "utils/convert_utils_base.h"
+#include "utils/primitive_utils.h"
+#include "utils/base_ref_py.h"
+#include "pybind_api/api_register.h"
+#include "pybind_api/export_flags.h"
+
+namespace mindspore {
+namespace {
+constexpr auto kBpropAttrName = "bprop";
+constexpr auto kCellHookAttrName = "cell_hook";
+constexpr auto kCellIDAttrName = "cell_id";
+void SyncData(const py::object &arg) {
+  if (py::isinstance<py::tuple>(arg)) {
+    py::tuple arg_list = py::cast<py::tuple>(arg);
+    for (size_t i = 0; i < arg_list.size(); i++) {
+      SyncData(arg_list[i]);
+    }
+  }
+  if (py::isinstance<tensor::Tensor>(arg)) {
+    auto tensor = py::cast<tensor::TensorPtr>(arg);
+    (void)tensor->data_sync();
+  }
+}
+}  // namespace
+std::map<std::string, py::object> PrimitivePy::hook_grad_;
+static ValuePtr PyArgToValue(const py::object &arg) {
+  if (py::isinstance<SignatureEnumKind>(arg) &&
+      py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
+    return nullptr;
+  }
+  return parse::data_converter::PyDataToValue(arg);
+}
+
+void PrimitivePy::set_signatures(
+  std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
+    signatures) {
+  signatures_.clear();
+  for (auto &signature : signatures) {
+    auto [name, rw, kind, arg_default, dtype] = signature;
+    auto default_value = PyArgToValue(arg_default);
+    signatures_.emplace_back(name, rw, kind, default_value, dtype);
+  }
+  set_has_signature(true);
+}
+
+py::function PrimitivePy::GetBpropFunction() {
+  static const char *const get_bprop_func_name = "get_bprop";
+  if (py::hasattr(python_obj_, get_bprop_func_name)) {
+    py::function fn = python_obj_.attr(get_bprop_func_name)().cast<py::function>();
+    return fn;
+  } else {
+    auto fn = GetBpropFunctionByObj(python_obj_);
+    return fn;
+  }
+}
+
+BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
+  auto py_args = py::tuple(args.size());
+  size_t i = 0;
+  for (auto &arg : args) {
+    py_args[i] = BaseRefToPyData(arg);
+    MS_LOG(DEBUG) << "arg:" << i << ":";
+    i++;
+  }
+  py::object obj;
+  bool is_bprop = this->HasAttr(kBpropAttrName);
+  if (is_bprop) {
+    SyncData(py_args);
+    obj = hook_(*py_args);
+    return std::make_shared<PyObjectRef>(obj);
+  }
+  SyncData(py_args[2]);
+  bool is_cell = this->HasAttr(kCellHookAttrName);
+  if (is_cell) {
+    auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
+    auto iter = hook_grad_.find(cell_id);
+    if (iter != hook_grad_.end()) {
+      auto hook_args = py::tuple(3);
+      hook_args[0] = cell_id;
+      hook_args[1] = py::make_tuple(iter->second);
+      hook_args[2] = py::make_tuple(py_args[2]);
+      obj = hook_(*hook_args);
+      if (py::isinstance<py::none>(obj)) {
+        obj = py_args[2];
+      }
+      hook_grad_.erase(cell_id);
+    } else {
+      hook_grad_[cell_id] = py_args[2];
+      obj = py_args[2];
+    }
+  } else {
+    // Hook operator for executing a variable hook function
+    obj = hook_(py::make_tuple(py_args[2]));
+    if (py::isinstance<py::none>(obj)) {
+      obj = py_args[2];
+    }
+  }
+  obj = py::make_tuple(obj);
+  return std::make_shared<PyObjectRef>(obj);
+}
+
+py::function PrimitivePy::GetComputeFunction() {
+  static const char *const compute_func_name = "vm_impl";
+
+  if (py::hasattr(python_obj_, compute_func_name)) {
+    MS_LOG(INFO) << name() << " compute_func_name";
+    py::function fn = python_obj_.attr(compute_func_name).cast<py::function>();
+    return fn;
+  }
+
+  static const std::string vm_module = "mindspore.ops.vm_impl_registry";
+  static const std::string get_vm_impl_fn = "get_vm_impl_fn";
+  MS_LOG(INFO) << name() << ": get_vm_impl_fn";
+  py::function get_fn = parse::python_adapter::GetPyFn(vm_module, get_vm_impl_fn);
+  py::function vm_fn = get_fn(python_obj_);
+
+  if (py::isinstance<py::none>(vm_fn)) {
+    MS_LOG(WARNING) << "Cannot find " << python_obj_.attr("__class__").attr("__name__").cast<std::string>();
+    vm_fn = mindspore::GetComputeFunction(Primitive::name());
+  }
+  return vm_fn;
+}
+
+void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
+  std::string attr_name = name;
+  ValuePtr converted_ret = nullptr;
+  if (py::isinstance<py::module>(obj)) {
+    MS_LOG(EXCEPTION) << "AddPyAttr failed, obj should not be py::module";
+  }
+  bool converted = parse::ConvertData(obj, &converted_ret);
+  if (!converted) {
+    MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj));
+  }
+  (void)this->AddAttr(attr_name, converted_ret);
+}
+
+py::dict PrimitivePy::GetAttrDict() {
+  py::dict attr_dict;
+  for (auto &attr : attrs_) {
+    attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second);
+  }
+  return attr_dict;
+}
+
+void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  if (!primitive->isa<PrimitivePy>()) {
+    MS_LOG(EXCEPTION) << "Cannot copy the hook function of a primitive that is not a python primitive!";
+  }
+  auto primitive_py = primitive->cast<PrimitivePyPtr>();
+  MS_EXCEPTION_IF_NULL(primitive_py);
+  this->set_hook(primitive_py->hook());
+}
+
+REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
+                         (void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
+                           .value("unknown", PrimType::kPrimTypeUnknown)
+                           .value("builtin", PrimType::kPrimTypeBuiltIn)
+                           .value("py_infer_shape", PrimType::kPrimTypePyInferShape)
+                           .value("user_custom", PrimType::kPrimTypeUserCustom);
+                         (void)py::class_<PrimitivePy, std::shared_ptr<PrimitivePy>>(*m, "Primitive_")
+                           .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
+                           .def(py::init<py::str, py::object>())
+                           .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
+                           .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
+                           .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
+                           .def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
+                           .def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
+                           .def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
+                       }));
+}  // namespace mindspore
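
[Editor's sketch] The cell-hook branch above pairs gradients by cell id: the first arrival is stashed in hook_grad_, the second completes the pair and fires the hook. The stash-or-fire state machine on its own, with toy payloads instead of py::object:

#include <iostream>
#include <map>
#include <string>

// First gradient for a cell id is stashed; the matching one fires the hook.
std::map<std::string, int> pending;

void OnGrad(const std::string &cell_id, int grad) {
  auto iter = pending.find(cell_id);
  if (iter != pending.end()) {
    std::cout << "hook(" << cell_id << ", " << iter->second << ", " << grad << ")\n";
    pending.erase(iter);
  } else {
    pending[cell_id] = grad;  // wait for the matching gradient
  }
}

int main() {
  OnGrad("cell0", 7);  // stashed
  OnGrad("cell0", 9);  // fires hook(cell0, 7, 9)
  return 0;
}
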
diff --git a/mindspore/core/ir/primitive_py.h b/mindspore/core/ir/primitive_py.h
new file mode 100644
index 0000000000..2dc45ac341
--- /dev/null
+++ b/mindspore/core/ir/primitive_py.h
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
+#define MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "abstract/abstract_value.h"
+#include "utils/misc.h"
+#include "pybind11/pybind11.h"
+#include "utils/log_adapter.h"
+#include "ir/primitive.h"
+#include "ir/signature.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+
+namespace py = pybind11;
+namespace mindspore {
+class PrimitivePy : public Primitive {
+ public:
+  PrimitivePy(const py::str &name, const py::object &python_obj)
+      : Primitive(name, false), python_obj_(python_obj), signatures_() {}
+  ~PrimitivePy() override = default;
+  MS_DECLARE_PARENT(PrimitivePy, Primitive);
+  py::function GetBpropFunction();
+  py::function GetComputeFunction();
+
+  void set_signatures(
+    std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
+      signatures);
+
+  const std::vector<Signature> &signatures() const { return signatures_; }
+
+  void CopyHookFunction(const PrimitivePtr &primitive) override;
+
+  void AddPyAttr(const py::str &name, const py::object &obj);
+
+  py::dict GetAttrDict();
+  void set_hook(const py::function &hook) { hook_ = hook; }
+  py::function hook() const { return hook_; }
+  BaseRef RunHookFunction(const VectorRef &args) const override;
+  const bool parse_info_ = true;
+  const py::object &GetPyObj() const { return python_obj_; }
+  bool is_tuple_input_ = false;
+
+ private:
+  py::object python_obj_;
+  py::function hook_;
+  std::vector<Signature> signatures_;
+  static std::map<std::string, py::object> hook_grad_;
+};
+
+using PrimitivePyPtr = std::shared_ptr<PrimitivePy>;
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_IR_PRIMITIVE_PY_H_
diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/core/ir/scalar.h
similarity index 100%
rename from mindspore/ccsrc/ir/scalar.h
rename to mindspore/core/ir/scalar.h
diff --git a/mindspore/ccsrc/ir/scope.cc b/mindspore/core/ir/scope.cc
similarity index 100%
rename from mindspore/ccsrc/ir/scope.cc
rename to mindspore/core/ir/scope.cc
diff --git a/mindspore/ccsrc/ir/scope.h b/mindspore/core/ir/scope.h
similarity index 100%
rename from mindspore/ccsrc/ir/scope.h
rename to mindspore/core/ir/scope.h
diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/core/ir/signature.h
similarity index 100%
rename from mindspore/ccsrc/ir/signature.h
rename to mindspore/core/ir/signature.h
diff --git a/mindspore/core/ir/signature_py.cc b/mindspore/core/ir/signature_py.cc
new file mode 100644
index 0000000000..f513df8533
--- /dev/null
+++ b/mindspore/core/ir/signature_py.cc
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/signature.h"
+#include "pybind11/operators.h"
+#include "pybind_api/api_register.h"
+#include "pipeline/jit/parse/data_converter.h"
+
+namespace py = pybind11;
+
+namespace mindspore {
+// Bind SignatureEnumRW as a python class.
+REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
+                         (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
+                           .value("RW_READ", SignatureEnumRW::kRWRead)
+                           .value("RW_WRITE", SignatureEnumRW::kRWWrite)
+                           .value("RW_REF", SignatureEnumRW::kRWRef)
+                           .value("RW_EMPTY_DEFAULT_VALUE", SignatureEnumRW::kRWEmptyDefaultValue);
+                         (void)py::enum_<SignatureEnumKind>(*m, "signature_kind", py::arithmetic())
+                           .value("KIND_POSITIONAL_KEYWORD", SignatureEnumKind::kKindPositionalKeyword)
+                           .value("KIND_VAR_POSITIONAL", SignatureEnumKind::kKindVarPositional)
+                           .value("KIND_KEYWORD_ONLY", SignatureEnumKind::kKindKeywordOnly)
+                           .value("KIND_VAR_KEYWARD", SignatureEnumKind::kKindVarKeyword)
+                           .value("KIND_EMPTY_DEFAULT_VALUE", SignatureEnumKind::kKindEmptyDefaultValue);
+                         (void)py::enum_<SignatureEnumDType>(*m, "signature_dtype", py::arithmetic())
+                           .value("T", SignatureEnumDType::kDType)
+                           .value("T1", SignatureEnumDType::kDType1)
+                           .value("T2", SignatureEnumDType::kDType2)
+                           .value("T3", SignatureEnumDType::kDType3)
+                           .value("T4", SignatureEnumDType::kDType4)
+                           .value("T5", SignatureEnumDType::kDType5)
+                           .value("T6", SignatureEnumDType::kDType6)
+                           .value("T7", SignatureEnumDType::kDType7)
+                           .value("T8", SignatureEnumDType::kDType8)
+                           .value("T9", SignatureEnumDType::kDType9)
+                           .value("T_EMPTY_DEFAULT_VALUE", SignatureEnumDType::kDTypeEmptyDefaultValue);
+                       }));
+}  // namespace mindspore
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
new file mode 100644
index 0000000000..6c966b32e3
--- /dev/null
+++ b/mindspore/core/ir/tensor.cc
@@ -0,0 +1,506 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/tensor.h"
+
+#include <atomic>
+#include <functional>
+#include <numeric>
+#include <vector>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <iomanip>
+#include <algorithm>
+#include <type_traits>
+
+#include "runtime/device/device_address.h"
+#include "abstract/abstract_value.h"
+
+namespace mindspore {
+namespace tensor {
+constexpr auto kEllipsis = "...";
+constexpr auto kThreshold = 6;
+
+constexpr auto kThreshold1DFloat = kThreshold * 2;
+constexpr auto kThreshold1DInt = kThreshold * 4;
+constexpr auto kThreshold1DBool = kThreshold * 2;
+
+static std::string MakeId() {
+  // Use atomic to make id generator thread safe.
+  static std::atomic<uint64_t> last_id{1};
+  return "T" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
+}
+
+static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
+  return data_type ? data_type->type_id() : defaultTypeId;
+}
+
+static size_t SizeOf(const std::vector<int> &shape) {
+  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
+}
+
+template <typename T>
+std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) {
+  const size_t count = SizeOf(shape);
+  switch (data_type) {
+    case kNumberTypeBool:
+    case kNumberTypeUInt8: {
+      auto buf = static_cast<uint8_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeInt8: {
+      auto buf = static_cast<int8_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeInt16: {
+      auto buf = static_cast<int16_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeInt32: {
+      auto buf = static_cast<int32_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeInt64: {
+      auto buf = static_cast<int64_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeUInt16: {
+      auto buf = static_cast<uint16_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeUInt32: {
+      auto buf = static_cast<uint32_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeUInt64: {
+      auto buf = static_cast<uint64_t *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeFloat16: {
+      auto buf = static_cast<float16 *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeFloat32: {
+      const float *buf = static_cast<const float *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    case kNumberTypeFloat64: {
+      auto buf = static_cast<double *>(data);
+      return std::vector<T>(buf, buf + count);
+    }
+    default:
+      break;
+  }
+  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
+}
+
+template <typename T>
+std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_len) {
+  size_t size = SizeOf(shape);
+  if (size * sizeof(T) != data_len) {
+    MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
+                      << " item size " << sizeof(T);
+  }
+  auto buf = static_cast<T *>(data);
+  return {buf, buf + size};
+}
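
[Editor's sketch] CopyData switches on the runtime TypeId to reinterpret the source buffer, then lets the vector constructor convert element-wise into the destination type T. The dispatch idea in miniature, with two types only:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class MiniTypeId { kInt32, kFloat32 };

// Runtime type tag picks the source element type; vector<T>'s range
// constructor performs the per-element conversion.
template <typename T>
std::vector<T> CopyBuffer(void *data, std::size_t count, MiniTypeId src) {
  switch (src) {
    case MiniTypeId::kInt32: {
      auto buf = static_cast<int32_t *>(data);
      return std::vector<T>(buf, buf + count);
    }
    case MiniTypeId::kFloat32: {
      auto buf = static_cast<float *>(data);
      return std::vector<T>(buf, buf + count);
    }
  }
  return {};
}

int main() {
  int32_t raw[] = {1, 2, 3};
  auto as_double = CopyBuffer<double>(raw, 3, MiniTypeId::kInt32);
  std::cout << as_double[2] << "\n";  // 3
  return 0;
}
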
+
+// Tensor data implementation.
+template <typename T>
+class TensorDataImpl : public TensorData {
+ public:
+  explicit TensorDataImpl(const std::vector<int> &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
+  ~TensorDataImpl() = default;
+
+  TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len)
+      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_len)) {}
+
+  TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type)
+      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData<T>(shape, data, data_type)) {}
+
+  template <typename InputIt>
+  TensorDataImpl(const std::vector<int> &shape, InputIt first, InputIt last)
+      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(first, last) {}
+
+  template <typename Scalar>
+  TensorDataImpl(const std::vector<int> &shape, Scalar scalar)
+      : ndim_(shape.size()), data_size_(SizeOf(shape)), data_({static_cast<T>(scalar)}) {}
+
+  ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
+
+  ssize_t itemsize() const override { return static_cast<ssize_t>(sizeof(T)); }
+
+  ssize_t nbytes() const override { return size() * itemsize(); }
+
+  ssize_t ndim() const override { return static_cast<ssize_t>(ndim_); }
+
+  void *data() override {
+    static std::vector<T> empty_data(1);
+    if (data_size_ == 0) {
+      // Prevent null pointer for empty shape.
+      return empty_data.data();
+    }
+    // Lazy allocation.
+    if (data_.empty()) {
+      data_.resize(data_size_);
+    }
+    return data_.data();
+  }
+
+  bool equals(const TensorData &other) const override {
+    auto ptr = dynamic_cast<const TensorDataImpl<T> *>(&other);
+    if (ptr) {
+      return (ptr == this) || ((ndim_ == ptr->ndim_) && (data_size_ == ptr->data_size_) && (data_ == ptr->data_));
+    }
+    return false;
+  }
+
+  std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
+    constexpr auto valid =
+      std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
+      std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
+      std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value ||
+      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
+    static_assert(valid, "Type is invalid");
+    if (data_size_ == 0) {
+      return "";
+    }
+    if (data_.empty()) {
+      return "<uninitialized>";
+    }
+
+    std::ostringstream ss;
+    ssize_t cursor = 0;
+    SummaryStringRecursive(ss, type, shape, &cursor, 0);
+    return ss.str();
+  }
+
+ private:
+  void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
+    int linefeedThreshold;
+    constexpr auto isFloat =
+      std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
+    for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
+      const auto value = data_[cursor + i];
+      if constexpr (isFloat) {
+        ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
+           << value;
+        linefeedThreshold = kThreshold1DFloat;
+      } else if (type == kNumberTypeBool) {
+        ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
+        linefeedThreshold = kThreshold1DBool;
+      } else {
+        constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
+                                  std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
+        if constexpr (isSigned) {
+          if (static_cast<int64_t>(value) >= 0) {
+            ss << ' ';
+          }
+        }
+        if constexpr (std::is_same<T, int8_t>::value) {
+          ss << static_cast<int16_t>(value);
+        } else if constexpr (std::is_same<T, uint8_t>::value) {
+          ss << static_cast<uint16_t>(value);
+        } else {
+          ss << value;
+        }
+        linefeedThreshold = kThreshold1DInt;
+      }
+      if (i != end - 1) {
+        ss << ' ';
+      }
+      if (ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {  // Add a line feed every {threshold of type} for 1D tensor.
+        ss << '\n' << ' ';
+      }
+    }
+  }
+
+  void SummaryStringRecursive(std::ostringstream &ss, const TypeId type, const std::vector<int> &shape,
+                              ssize_t *cursor, ssize_t depth) const {
+    if (depth >= static_cast<ssize_t>(ndim_)) {
+      return;
+    }
+    ss << '[';
+    if (depth == static_cast<ssize_t>(ndim_) - 1) {  // Bottom dimension
+      ssize_t num = shape[depth];
+      if (num > kThreshold && ndim_ > 1) {
+        OutputDataString(ss, type, *cursor, 0, kThreshold / 2);
+        ss << ' ' << kEllipsis << ' ';
+        OutputDataString(ss, type, *cursor, num - kThreshold / 2, num);
+      } else {
+        OutputDataString(ss, type, *cursor, 0, num);
+      }
+      *cursor += num;
+    } else {  // Middle dimension
+      ssize_t num = shape[depth];
+      // Handle the first half.
+      for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
+        if (i > 0) {
+          ss << '\n';
+          ss << std::setw(depth + 1) << ' ';  // Add the indent.
+        }
+        SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
+      }
+      // Handle the ignored part.
+      if (num > kThreshold) {
+        ss << '\n';
+        ss << std::setw(depth + 1) << ' ';  // Add the indent.
+        ss << kEllipsis;
+        // Ignored at this layer.
+        ssize_t ignored = shape[depth + 1];
+        for (ssize_t i = depth + 2; i < static_cast<ssize_t>(ndim_); i++) {
+          ignored *= shape[i];
+        }
+        // Multiply by the number of ignored layers.
+        ignored *= num - kThreshold;
+
+        *cursor += ignored;
+      }
+      // Handle the second half.
+      if (num > kThreshold / 2) {
+        for (ssize_t i = num - kThreshold / 2; i < num; i++) {
+          ss << '\n';
+          ss << std::setw(depth + 1) << ' ';  // Add the indent.
+          SummaryStringRecursive(ss, type, shape, cursor, depth + 1);
+        }
+      }
+    }
+    ss << ']';
+  }
+
+  size_t ndim_{0};
+  size_t data_size_{0};
+  std::vector<T> data_;
+};
+
+template <typename... Args>
+TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, const Args... args) {
+  switch (data_type) {
+    case kNumberTypeBool:
+    case kNumberTypeUInt8:
+      return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...);
+    case kNumberTypeInt8:
+      return std::make_shared<TensorDataImpl<int8_t>>(shape, args...);
+    case kNumberTypeInt16:
+      return std::make_shared<TensorDataImpl<int16_t>>(shape, args...);
+    case kNumberTypeInt32:
+      return std::make_shared<TensorDataImpl<int32_t>>(shape, args...);
+    case kNumberTypeInt64:
+      return std::make_shared<TensorDataImpl<int64_t>>(shape, args...);
+    case kNumberTypeUInt16:
+      return std::make_shared<TensorDataImpl<uint16_t>>(shape, args...);
+    case kNumberTypeUInt32:
+      return std::make_shared<TensorDataImpl<uint32_t>>(shape, args...);
+    case kNumberTypeUInt64:
+      return std::make_shared<TensorDataImpl<uint64_t>>(shape, args...);
+    case kNumberTypeFloat16:
+      return std::make_shared<TensorDataImpl<float16>>(shape, args...);
+    case kNumberTypeFloat32:
+      return std::make_shared<TensorDataImpl<float>>(shape, args...);
+    case kNumberTypeFloat64:
+      return std::make_shared<TensorDataImpl<double>>(shape, args...);
+    default:
+      break;
+  }
+  MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << ".";
+}
+
+Tensor::Tensor(const Tensor &tensor)
+    : MetaTensor(tensor),
+      init_flag_(tensor.init_flag_),
+      data_(tensor.data_),
+      dirty_(tensor.dirty_),
+      id_(tensor.id_),
+      device_address_(tensor.device_address_) {}
+
+Tensor::Tensor(const Tensor &tensor, TypeId data_type)
+    : MetaTensor(data_type, tensor.shape_),
+      init_flag_(tensor.init_flag_),
+      data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)),
+      dirty_(tensor.dirty_),
+      id_(tensor.id_),
+      device_address_(tensor.device_address_) {}
+
+Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
+    : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
+
+Tensor::Tensor(TypeId data_type, const std::vector<int> &shape)
+    : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}
+
+Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len)
+    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}
+
+Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type)
+    : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}
+
+Tensor::Tensor(const std::vector<int64_t> &input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {static_cast<int>(input.size())}),
+      data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())),
+      id_(MakeId()) {}
+
+Tensor::Tensor(const std::vector<double> &input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {static_cast<int>(input.size())}),
+      data_(MakeTensorData(data_type_, shape_, input.begin(), input.end())),
+      id_(MakeId()) {}
+
+Tensor::Tensor(int64_t input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeInt32), {}),
+      data_(MakeTensorData(data_type_, {}, input)),
+      id_(MakeId()) {}
+
+Tensor::Tensor(double input, const TypePtr &data_type)
+    : MetaTensor(TypeIdOf(data_type, kNumberTypeFloat32), {}),
+      data_(MakeTensorData(data_type_, {}, input)),
+      id_(MakeId()) {}
+
+bool Tensor::operator==(const Tensor &tensor) const {
+  return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_));
+}
+
+bool Tensor::ValueEqual(const Tensor &tensor) const {
+  return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
+}
+// assign value to this tensor
+Tensor &Tensor::AssignValue(const Tensor &tensor) {
+  if (this != &tensor) {
+    MetaTensor::operator=(tensor);
+    dirty_ = tensor.is_dirty();
+    device_address_ = tensor.device_address();
+    data_ = tensor.data_;
+    id_ = tensor.id();
+  }
+  return *this;
+}
+abstract::AbstractBasePtr Tensor::ToAbstract() {
+  auto tens = shared_from_base<Tensor>();
+  auto dtype = tens->Dtype();
+  if (!IsSubType(dtype, kNumber)) {
+    MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << ".";
+  }
+  auto tensor_shape = tens->shape();
+  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape);
+  abs_tensor->set_value(shared_from_base<Tensor>());
+  return abs_tensor;
+}
+
+std::string Tensor::GetShapeAndDataTypeInfo() const {
+  std::ostringstream buf;
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
+  return buf.str();
+}
+
+std::string Tensor::ToString() const {
+  const int small_tensor_size = 30;
+  std::ostringstream buf;
+  buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString();
+  // only print small tensors
+  if (DataSize() < small_tensor_size) {
+    buf << ", value:" << data().ToString(data_type_, shape());
+  }
+  return buf.str();
+}
+
+std::string Tensor::ToStringRepr() const {
+  std::ostringstream buf;
+  auto type_ptr = this->Dtype();
+  MS_EXCEPTION_IF_NULL(type_ptr);
+  buf << "Tensor shape:[" << shape() << "]" << type_ptr->ToString();
+  buf << "\nvalue:" << data().ToString(data_type_, shape());
+  return buf.str();
+}
+
+void Tensor::data_sync() const {
+  if (device_address_ != nullptr) {
+    if (!device_address_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
+      MS_LOG(EXCEPTION) << "SyncDeviceToHost failed when asnumpy.";
+    }
+  }
+}
+
+TypeId Tensor::set_data_type(const TypeId data_type) {
+  if (data_type != data_type_) {
+    data_ = MakeTensorData(data_type, shape_, data_->data(), data_type_);
+    return MetaTensor::set_data_type(data_type);
+  }
+  return data_type;
+}
+}  // namespace tensor
+
+namespace inference {
+MSTensor *MSTensor::CreateTensor(TypeId data_type, const std::vector<int> &shape) {
+  return new Tensor(data_type, shape);
+}
+
+Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) {
+  this->tensor_impl_ = std::make_shared<tensor::Tensor>(data_type, shape);
+}
+
+Tensor::Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr) { this->tensor_impl_ = std::move(tensor_ptr); }
+
+TypeId Tensor::data_type() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->data_type();
+}
+
+TypeId Tensor::set_data_type(TypeId data_type) {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->set_data_type(data_type);
+}
+
+std::vector<int> Tensor::shape() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->shape();
+}
+
+size_t Tensor::set_shape(const std::vector<int> &shape) {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->set_shape(shape);
+}
+
+int Tensor::DimensionSize(size_t index) const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->DimensionSize(index);
+}
+
+int Tensor::ElementsNum() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->ElementsNum();
+}
+
+std::size_t Tensor::hash() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->hash();
+}
+
+std::shared_ptr<tensor::Tensor> Tensor::tensor() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_;
+}
+
+size_t Tensor::Size() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->data().nbytes();
+}
+
+void *Tensor::MutableData() const {
+  MS_ASSERT(this->tensor_impl_ != nullptr);
+  return this->tensor_impl_->data_c();
+}
+
+}  // namespace inference
+}  // namespace mindspore
diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h
new file mode 100644
index 0000000000..f2ed2c1609
--- /dev/null
+++ b/mindspore/core/ir/tensor.h
@@ -0,0 +1,278 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_IR_TENSOR_H_
+#define MINDSPORE_CCSRC_IR_TENSOR_H_
+
+#include
+#include
+#include
+#include
+
+#include "Eigen/Core"
+#include "runtime/device/device_address.h"
+#include "ir/meta_tensor.h"
+#include "include/ms_tensor.h"
+#include "utils/log_adapter.h"
+
+using float16 = Eigen::half;
+
+using mindspore::device::DeviceAddress;
+using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;
+// brief mindspore namespace.
+//
+// mindspore namespace is the top level namespace of MindSpore project.
+// Other namespace should be a sub namespace of mindspore namespace in the ME project.
+namespace mindspore {
+// brief mindspore::tensor namespace
+//
+// A sub namespace in ME to support tensor related definition.
+namespace tensor {
+// Tensor data interface.
+class TensorData {
+ public:
+  /// Total number of elements.
+  virtual ssize_t size() const = 0;
+  /// Byte size of a single element.
+  virtual ssize_t itemsize() const = 0;
+  /// Total number of bytes.
+  virtual ssize_t nbytes() const = 0;
+  /// Number of dimensions.
+  virtual ssize_t ndim() const = 0;
+  /// Data pointer.
+  virtual void *data() = 0;
+  /// Is data equal.
+  virtual bool equals(const TensorData &other) const = 0;
+  /// To string.
+  virtual std::string ToString(const TypeId type, const std::vector<int> &shape) const = 0;
+};
+
+using TensorDataPtr = std::shared_ptr<TensorData>;
+
+// Tensor entity class
+class Tensor : public MetaTensor {
+ public:
+  abstract::AbstractBasePtr ToAbstract() override;
+
+  // brief Create tensor from another tensor, data is shared.
+  //
+  // param tensor [Tensor] The input tensor.
+  explicit Tensor(const Tensor &tensor);
+
+  // brief Create tensor with given data type from another tensor.
+  //
+  // param tensor [Tensor] The input tensor.
+  // param data_type [TypeId] The new tensor data type.
+  Tensor(const Tensor &tensor, TypeId data_type);
+
+  // brief Create tensor with the given shared tensor data.
+  //
+  // param data_type [TypeId] Data type of the tensor.
+  // param shape The shape represented by std::vector<int> of the tensor.
+  // param data The shared tensor data.
+  Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data);
+
+  // brief Create an all zero tensor.
+  //
+  // param data_type [TypeId] Data type of the tensor.
+  // param shape The shape represented by std::vector<int> of the tensor.
+  Tensor(TypeId data_type, const std::vector<int> &shape);
+
+  // brief Create a tensor with input data buffer.
+  //
+  // param data_type [TypeId] Data type of the tensor.
+  // param shape The shape represented by std::vector<int> of the tensor.
+  // param data The input data to be copied into tensor.
+  // param data_len The length of data in bytes.
+  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len);
+
+  // brief Create a tensor with input data buffer and given source data type.
+  //
+  // param data_type [TypeId] Data type of the tensor.
+  // param shape The shape represented by std::vector<int> of the tensor.
+  // param data The input data to be copied into tensor.
+  // param src_data_type The source data type.
+  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type);
+
+  // brief Create a 1-dimensional tensor from an int vector.
+  //
+  // param input [std::vector<int64_t>] the data for the tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(const std::vector<int64_t> &input, const TypePtr &data_type = nullptr);
+
+  // brief Create a 1-dimensional tensor from a float vector.
+  //
+  // param input [std::vector<double>] the data for the tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(const std::vector<double> &input, const TypePtr &data_type = nullptr);
+
+  // brief Create a 0-dimensional tensor from an int scalar.
+  //
+  // param input [int64] the data for the tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(int64_t input, const TypePtr &data_type = nullptr);
+
+  // brief Create a 0-dimensional tensor from a float scalar.
+  //
+  // param input [double] the data for the tensor
+  // param data_type [TypeId] data type
+  explicit Tensor(double input, const TypePtr &data_type = nullptr);
+
+  ~Tensor() override = default;
+
+  MS_DECLARE_PARENT(Tensor, MetaTensor);
+
+  // brief Compares two Tensor objects.
+  //
+  // Compare two tensor objects to see if they have the same data type, shape and data address.
+  //
+  // param tensor The Tensor object to be compared.
+  // return true: If having same type, shape and data address, return true, or return false.
+  bool operator==(const Tensor &tensor) const;
+
+  // It is different from 'operator==', which only compares shape/type/address;
+  // this does a real value comparison.
+  bool ValueEqual(const Tensor &tensor) const;
+
+  // assign value to this tensor
+  Tensor &AssignValue(const Tensor &tensor);
+
+  bool operator==(const Value &other) const override {
+    if (other.isa<Tensor>()) {
+      auto &other_ = static_cast<const Tensor &>(other);
+      return *this == other_;
+    }
+    return false;
+  }
+
+  // brief Gets tensor's dimension
+  //
+  // return The number of dimensions of the tensor data.
+  int DataDim() const { return static_cast<int>(data().ndim()); }
+
+  // brief Getting tensor data size
+  //
+  // return The total number of elements of the tensor data.
+  int DataSize() const { return static_cast<int>(data().size()); }
+
+  // brief Get the data type of the tensor for C++
+  //
+  // return [int] The tensor's data type will be cast to int to return.
+  int data_type_c() const { return static_cast<int>(data_type_); }
+
+  // brief Get the tensor's shape for C++
+  //
+  // return [std::vector<int>]
+  std::vector<int> shape_c(void) const { return shape(); }
+
+  // brief Get Tensor data pointer for c++ type
+  //
+  // return The pointer to the object
+  void *data_c() { return data().data(); }
+
+  // brief Get Tensor data byte-size for c++ type
+  //
+  // return byte size of Tensor data
+  size_t Size() const { return data().nbytes(); }
+
+  void *data_c() const { return data_->data(); }
+
+  // brief Sync data with device.
+  void data_sync() const;
+
+  // brief Get the internal data object.
+  //
+  // return The reference to internal data object.
+  TensorData &data() { return *data_; }
+
+  // brief Get the internal data shared pointer.
+  //
+  // return The shared pointer to internal data object.
+  const TensorDataPtr &data_ptr() const { return data_; }
+
+  // brief Get the internal data object.
+  //
+  // return The reference to internal data object.
+  const TensorData &data() const { return *data_; }
+
+  TypeId set_data_type(const TypeId data_type) override;
+
+  std::string GetShapeAndDataTypeInfo() const;
+
+  std::string ToString() const override;
+
+  std::string ToStringRepr() const;
+
+  bool is_init() const { return init_flag_; }
+  void set_init_flag(bool flag) { init_flag_ = flag; }
+
+  bool is_dirty() const { return dirty_; }
+  void set_dirty(const bool dirty) { dirty_ = dirty; }
+
+  DeviceAddressPtr device_address() const { return device_address_; }
+  void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; }
+
+  std::string id() const { return id_; }
+
+  const bool parse_info_ = true;
+
+ private:
+  bool init_flag_{false};
+  TensorDataPtr data_{nullptr};
+  bool dirty_{true};
+  std::string id_{""};
+  DeviceAddressPtr device_address_{nullptr};
+};
+using TensorPtr = std::shared_ptr<Tensor>;
+using TensorPtrList = std::vector<std::shared_ptr<Tensor>>;
+}  // namespace tensor
+
+namespace inference {
+class Tensor : public MSTensor {
+ public:
+  Tensor(TypeId data_type, const std::vector<int> &shape);
+
+  explicit Tensor(std::shared_ptr<tensor::Tensor> tensor_ptr);
+
+  ~Tensor() = default;
+
+  TypeId data_type() const override;
+
+  TypeId set_data_type(const TypeId data_type) override;
+
+  std::vector<int> shape() const override;
+
+  size_t set_shape(const std::vector<int> &shape) override;
+
+  int DimensionSize(size_t index) const override;
+
+  int ElementsNum() const override;
+
+  std::size_t hash() const override;
+
+  std::shared_ptr<tensor::Tensor> tensor() const;
+
+  size_t Size() const override;
+
+  void *MutableData() const override;
+
+ protected:
+  std::shared_ptr<tensor::Tensor> tensor_impl_;
+};
+}  // namespace inference
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_IR_TENSOR_H_
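tensor.h splits storage from metadata: TensorData is the storage contract, and Tensor only holds a TensorDataPtr plus bookkeeping flags (init/dirty) and an optional DeviceAddress. A trimmed, self-contained sketch of one possible back-end satisfying that contract — FlatData is an invented name, and the equals()/ToString() members are omitted for brevity:

    #include <sys/types.h>
    #include <vector>

    // Reduced copy of the interface so the sketch compiles on its own.
    class TensorData {
     public:
      virtual ~TensorData() = default;
      virtual ssize_t size() const = 0;      // element count
      virtual ssize_t itemsize() const = 0;  // bytes per element
      virtual ssize_t nbytes() const = 0;    // total bytes
      virtual ssize_t ndim() const = 0;      // number of dimensions
      virtual void *data() = 0;              // raw buffer
    };

    // Hypothetical back-end: one contiguous byte buffer plus a shape.
    class FlatData : public TensorData {
     public:
      FlatData(std::vector<int> shape, ssize_t item_size) : shape_(shape), item_size_(item_size) {
        count_ = 1;
        for (int dim : shape_) count_ *= dim;
        bytes_.resize(static_cast<size_t>(count_ * item_size_));
      }
      ssize_t size() const override { return count_; }
      ssize_t itemsize() const override { return item_size_; }
      ssize_t nbytes() const override { return count_ * item_size_; }
      ssize_t ndim() const override { return static_cast<ssize_t>(shape_.size()); }
      void *data() override { return bytes_.data(); }

     private:
      std::vector<int> shape_;
      ssize_t item_size_;
      ssize_t count_ = 0;
      std::vector<char> bytes_;
    };

    int main() {
      FlatData d({2, 3}, 4);  // a 2x3 buffer of 4-byte elements
      return (d.nbytes() == 24 && d.ndim() == 2) ? 0 : 1;
    }

Keeping the buffer behind an interface is what lets the same Tensor front end sit on top of host memory, lazily-synced device memory, or a numpy-backed buffer without changing its API.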
diff --git a/mindspore/core/ir/tensor_py.cc b/mindspore/core/ir/tensor_py.cc
new file mode 100644
index 0000000000..f5f83d0e07
--- /dev/null
+++ b/mindspore/core/ir/tensor_py.cc
@@ -0,0 +1,390 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ir/tensor_py.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "runtime/device/device_address.h"
+#include "pybind_api/api_register.h"
+#include "pybind_api/export_flags.h"
+#include "abstract/abstract_value.h"
+
+namespace mindspore {
+namespace tensor {
+
+static TypeId GetDataType(const py::buffer_info &buf) {
+  if (buf.format.size() == 1) {
+    switch (buf.format.front()) {
+      case 'e':
+      case 'f':
+      case 'd':
+        switch (buf.itemsize) {
+          case 2:
+            return TypeId::kNumberTypeFloat16;
+          case 4:
+            return TypeId::kNumberTypeFloat32;
+          case 8:
+            return TypeId::kNumberTypeFloat64;
+        }
+        break;
+      case 'b':
+      case 'h':
+      case 'i':
+      case 'l':
+      case 'q':
+        switch (buf.itemsize) {
+          case 1:
+            return TypeId::kNumberTypeInt8;
+          case 2:
+            return TypeId::kNumberTypeInt16;
+          case 4:
+            return TypeId::kNumberTypeInt32;
+          case 8:
+            return TypeId::kNumberTypeInt64;
+        }
+        break;
+      case 'B':
+      case 'H':
+      case 'I':
+      case 'L':
+      case 'Q':
+        switch (buf.itemsize) {
+          case 1:
+            return TypeId::kNumberTypeUInt8;
+          case 2:
+            return TypeId::kNumberTypeUInt16;
+          case 4:
+            return TypeId::kNumberTypeUInt32;
+          case 8:
+            return TypeId::kNumberTypeUInt64;
+        }
+        break;
+      case '?':
+        return TypeId::kNumberTypeBool;
+    }
+  }
+  MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
+  return TypeId::kTypeUnknown;
+}
+
+static std::string GetPyTypeFormat(TypeId data_type) {
+  switch (data_type) {
+    case TypeId::kNumberTypeFloat16:
+      return "e";
+    case TypeId::kNumberTypeFloat32:
+      return py::format_descriptor<float>::format();
+    case TypeId::kNumberTypeFloat64:
+      return py::format_descriptor<double>::format();
+    case TypeId::kNumberTypeUInt8:
+      return py::format_descriptor<uint8_t>::format();
+    case TypeId::kNumberTypeUInt16:
+      return py::format_descriptor<uint16_t>::format();
+    case TypeId::kNumberTypeUInt32:
+      return py::format_descriptor<uint32_t>::format();
+    case TypeId::kNumberTypeUInt64:
+      return py::format_descriptor<uint64_t>::format();
+    case TypeId::kNumberTypeInt8:
+      return py::format_descriptor<int8_t>::format();
+    case TypeId::kNumberTypeInt16:
+      return py::format_descriptor<int16_t>::format();
+    case TypeId::kNumberTypeInt32:
+      return py::format_descriptor<int32_t>::format();
+    case TypeId::kNumberTypeInt64:
+      return py::format_descriptor<int64_t>::format();
+    case TypeId::kNumberTypeBool:
+      return py::format_descriptor<bool>::format();
+    default:
+      MS_LOG(WARNING) << "Unsupported DataType " << data_type << ".";
+      return "";
+  }
+}
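GetDataType above keys on both the buffer-protocol format character and the item size, because characters such as 'l' are platform dependent (an 8-byte long on LP64 Linux, 4 bytes on LLP64 Windows), so the format alone does not pin down a width. A compilable distillation of that dispatch; TypeName and FromFormat are illustrative stand-ins, not names from the patch:

    #include <string>

    // Stand-in for MindSpore's TypeId, reduced to the cases shown here.
    enum class TypeName { Unknown, Float16, Float32, Float64, Int8, Int16, Int32, Int64, Bool };

    // Map a NumPy buffer format character plus item size to a concrete type.
    TypeName FromFormat(const std::string &format, int itemsize) {
      if (format.size() != 1) return TypeName::Unknown;
      switch (format.front()) {
        case 'e': case 'f': case 'd':  // IEEE floating point family
          return itemsize == 2 ? TypeName::Float16
               : itemsize == 4 ? TypeName::Float32
               : itemsize == 8 ? TypeName::Float64 : TypeName::Unknown;
        case 'b': case 'h': case 'i': case 'l': case 'q':  // signed integers
          return itemsize == 1 ? TypeName::Int8
               : itemsize == 2 ? TypeName::Int16
               : itemsize == 4 ? TypeName::Int32
               : itemsize == 8 ? TypeName::Int64 : TypeName::Unknown;
        case '?':
          return TypeName::Bool;
        default:
          return TypeName::Unknown;
      }
    }

    int main() {
      // 'l' with itemsize 8 must resolve to a 64-bit integer regardless of platform.
      return (FromFormat("l", 8) == TypeName::Int64 && FromFormat("e", 2) == TypeName::Float16) ? 0 : 1;
    }

Falling through to an Unknown/warning path, as the patch does, keeps an unexpected dtype from silently reinterpreting memory.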
+
+static bool IsCContiguous(const py::array &input) {
+  auto flags = static_cast<unsigned int>(input.flags());
+  return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
+}
+
+TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
+  // Get input buffer info.
+  py::buffer_info buf = input.request();
+  // Check data types.
+  auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kTypeUnknown;
+  auto buf_type = GetDataType(buf);
+  if (buf_type == TypeId::kTypeUnknown && data_type == TypeId::kTypeUnknown) {
+    MS_LOG(EXCEPTION) << "Unsupported tensor type!";
+  }
+  // Use buf type as data type if type_ptr is not set.
+  if (data_type == TypeId::kTypeUnknown) {
+    data_type = buf_type;
+  }
+  // Convert the input array to a C-contiguous buffer if needed.
+  std::unique_ptr<char[]> tmp_buf;
+  if (!IsCContiguous(input)) {
+    Py_buffer pybuf;
+    if (PyObject_GetBuffer(input.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS)) {
+      MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
+    }
+    tmp_buf = std::make_unique<char[]>(pybuf.len);
+    if (PyBuffer_ToContiguous(tmp_buf.get(), &pybuf, pybuf.len, 'C')) {
+      MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
+    }
+    PyBuffer_Release(&pybuf);
+    buf.ptr = tmp_buf.get();
+  }
+  // Get tensor shape.
+  std::vector<int> shape(buf.shape.begin(), buf.shape.end());
+  if (data_type == buf_type) {
+    // Use memory copy if the input data type is the same as the required type.
+    return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf.size * buf.itemsize);
+  }
+  // Create tensor with data type converted.
+  return std::make_shared<Tensor>(data_type, shape, buf.ptr, buf_type);
+}
+
+static std::vector<ssize_t> GetStrides(const std::vector<ssize_t> &shape, ssize_t item_size) {
+  std::vector<ssize_t> strides;
+  strides.reserve(shape.size());
+  const auto ndim = shape.size();
+  for (size_t i = 0; i < ndim; ++i) {
+    auto stride = item_size;
+    for (size_t j = i + 1; j < ndim; ++j) {
+      stride *= shape[j];
+    }
+    strides.push_back(stride);
+  }
+  return strides;
+}
+
+static py::buffer_info GetPyBufferInfo(const Tensor &tensor) {
+  std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
+  std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
+  return py::buffer_info{
+    tensor.data_c(), tensor.data().itemsize(), GetPyTypeFormat(tensor.data_type()), tensor.DataDim(), shape, strides};
+}
+
+py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
+  auto &shape = tensor.shape();
+  py::tuple dims(shape.size());
+  for (size_t i = 0; i < dims.size(); ++i) {
+    dims[i] = py::int_(shape[i]);
+  }
+  return dims;
+}
+
+py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
+  tensor.data_sync();
+  auto info = GetPyBufferInfo(tensor);
+  py::object self = py::cast(&tensor);
+  return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
+}
+
+py::array TensorPy::AsNumpy(const Tensor &tensor) {
+  auto info = GetPyBufferInfo(tensor);
+  py::object self = py::cast(&tensor);
+  return py::array(py::dtype(info), info.shape, info.strides, info.ptr, self);
+}
+
+static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
+  std::vector<int> shape;
+  const size_t size = tuple.size();
+  shape.reserve(tuple.size());
+  for (size_t i = 0; i < size; ++i) {
+    shape.push_back(py::int_(tuple[i]));
+  }
+  return shape;
+}
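GetStrides above produces row-major (C-order) strides in bytes: stride[i] = item_size * product(shape[i+1:]). For shape {2, 3, 4} with 4-byte elements that is {48, 16, 4}. A standalone check of the same rule, written as a single backward pass instead of the patch's nested loops (the output is identical; the accumulator just reuses the running product):

    #include <cstddef>
    #include <vector>

    // Row-major strides in bytes: stride[i] = itemsize * product(shape[i+1:]).
    std::vector<long> RowMajorStrides(const std::vector<long> &shape, long itemsize) {
      std::vector<long> strides(shape.size());
      long acc = itemsize;
      // Walk from the innermost axis outward, accumulating the running product.
      for (size_t i = shape.size(); i-- > 0;) {
        strides[i] = acc;
        acc *= shape[i];
      }
      return strides;
    }

    int main() {
      auto s = RowMajorStrides({2, 3, 4}, 4);
      return (s == std::vector<long>{48, 16, 4}) ? 0 : 1;
    }

These strides, together with the data pointer and format string, are exactly what py::buffer_info needs for AsNumpy to expose the tensor to numpy without copying; passing the tensor itself as the base object keeps it alive while the array views its memory.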
+
+REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
+  // Define python MetaTensor class.
+  (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
+    .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
+    .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
+    .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
+    .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
+    .def(py::pickle(
+      [](const MetaTensor &t) {  // __getstate__
+        /* Return a tuple that fully encodes the state of the object */
+        return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
+      },
+      [](const py::tuple &t) {  // __setstate__
+        if (t.size() != 2) {
+          throw std::runtime_error("Invalid state!");
+        }
+        /* Create a new C++ instance */
+        MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
+        return tensor;
+      }));
+  // Define python Tensor class.
+  // dtype should be defined before Tensor, because Tensor's init depends on dtype.
+  (void)py::class_<Tensor, MetaTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
+    .def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }),
+         py::arg("input"))
+    .def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
+           TypeId data_type = type_ptr ? type_ptr->type_id() : kTypeUnknown;
+           if (data_type == kTypeUnknown || tensor.data_type() == data_type) {
+             return std::make_shared<Tensor>(tensor);
+           }
+           return std::make_shared<Tensor>(tensor, data_type);
+         }),
+         py::arg("input"), py::arg("dtype"))
+    .def(py::init([](const TypePtr &type_ptr, const py::tuple &shape) {
+           auto data_type = type_ptr ? type_ptr->type_id() : TypeId::kNumberTypeFloat64;
+           return std::make_shared<Tensor>(data_type, GetShapeFromTuple(shape));
+         }),
+         py::arg("dtype"), py::arg("shape"))
+    .def(py::init([](const py::array &input, const TypePtr &type_ptr) {
+           return TensorPy::MakeTensor(input, type_ptr);
+         }),
+         py::arg("input"), py::arg("dtype") = nullptr)
+    .def(py::init([](py::float_ input, const TypePtr &type_ptr) {
+           return TensorPy::MakeTensor(py::array(input), type_ptr);
+         }),
+         py::arg("input"), py::arg("dtype") = nullptr)
+    .def(py::init([](py::int_ input, const TypePtr &type_ptr) {
+           return TensorPy::MakeTensor(py::array(input), type_ptr);
+         }),
+         py::arg("input"), py::arg("dtype") = nullptr)
+    .def(py::init([](py::list input, const TypePtr &type_ptr) {
+           return TensorPy::MakeTensor(py::array(input), type_ptr);
+         }),
+         py::arg("input"), py::arg("dtype") = nullptr)
+    .def(py::init([](py::tuple input, const TypePtr &type_ptr) {
+           return TensorPy::MakeTensor(py::array(input), type_ptr);
+         }),
+         py::arg("input"), py::arg("dtype") = nullptr)
+    .def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_)
+    .def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
+    .def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
+        Get the tensor's data type.
+
+        Returns:
+            type, the data type of tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
+            >>> data.dtype
+            Int32
+        )mydelimiter")
+    .def_property_readonly("shape", TensorPy::GetPyTupleShape, R"mydelimiter(
+        Get the tensor's shape.
+
+        Returns:
+            tuple[int], the shape of tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((3, 3)))
+            >>> data.shape
+            (3, 3)
+        )mydelimiter")
+    .def("asnumpy", TensorPy::SyncAsNumpy, R"mydelimiter(
+        Convert tensor to numpy.ndarray.
+
+        Returns:
+            numpy.ndarray.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 3)))
+            >>> array = data.asnumpy()
+            >>> array
+            array([[1., 1., 1.],
+                   [1., 1., 1.]])
+        )mydelimiter")
+    .def("size", &Tensor::DataSize, R"mydelimiter(
+        Get tensor's data size.
+
+        Returns:
+            int, the size of tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 3)))
+            >>> data.size()
+            6
+        )mydelimiter")
+    .def("is_init", &Tensor::is_init, R"mydelimiter(
+        Get tensor init_flag.
+
+        Returns:
+            bool, whether the tensor has been initialized.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 3)))
+            >>> data.is_init()
+            False
+        )mydelimiter")
+    .def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
+        Set tensor init_flag.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 3)))
+            >>> data.set_init_flag(True)
+        )mydelimiter")
+    .def("dim", &Tensor::DataDim, R"mydelimiter(
+        Get tensor's data dimension.
+
+        Returns:
+            int, the dimension of tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((2, 3)))
+            >>> data.dim()
+            2
+        )mydelimiter")
+    .def("assign_value", &Tensor::AssignValue, R"mydelimiter(
+        Assign another tensor value to this.
+
+        Arg:
+            value (:class:`mindspore.tensor`): The value tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
+            >>> data2 = mindspore.Tensor(np.ones((2, 2), np.float32))
+            >>> data.assign_value(data2)
+            >>> data.shape
+            (2, 2)
+        )mydelimiter")
+    .def("set_dtype", &Tensor::SetDtype, R"mydelimiter(
+        Set the tensor's data type.
+
+        Arg:
+            dtype (:class:`mindspore.dtype`): The type of output tensor.
+
+        Examples:
+            >>> data = mindspore.Tensor(np.ones((1, 2), np.float32))
+            >>> data.set_dtype(mindspore.int32)
+            mindspore.int32
+        )mydelimiter")
+    .def("__str__", &Tensor::ToString)
+    .def("__repr__", &Tensor::ToStringRepr)
+    .def(py::pickle(
+      [](const Tensor &t) {  // __getstate__
+        /* Return a tuple that fully encodes the state of the object */
+        return py::make_tuple(TensorPy::AsNumpy(t));
+      },
+      [](const py::tuple &t) {  // __setstate__
+        if (t.size() != 1) {
+          throw std::runtime_error("Invalid state!");
+        }
+        /* Create a new C++ instance */
+        return TensorPy::MakeTensor(t[0].cast<py::array>());
+      }));
+}));
+}  // namespace tensor
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/ir/tensor_py.h b/mindspore/core/ir/tensor_py.h
similarity index 100%
rename from mindspore/ccsrc/ir/tensor_py.h
rename to mindspore/core/ir/tensor_py.h
diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/core/ir/value.cc
similarity index 100%
rename from mindspore/ccsrc/ir/value.cc
rename to mindspore/core/ir/value.cc
diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/core/ir/value.h
similarity index 100%
rename from mindspore/ccsrc/ir/value.h
rename to mindspore/core/ir/value.h
diff --git a/mindspore/ccsrc/ir/value_extends.cc b/mindspore/core/ir/value_extends.cc
similarity index 100%
rename from mindspore/ccsrc/ir/value_extends.cc
rename to mindspore/core/ir/value_extends.cc
diff --git a/mindspore/ccsrc/ir/value_py.cc b/mindspore/core/ir/value_py.cc
similarity index 100%
rename from mindspore/ccsrc/ir/value_py.cc
rename to mindspore/core/ir/value_py.cc
diff --git a/mindspore/ccsrc/ir/visitor.cc b/mindspore/core/ir/visitor.cc
similarity index 100%
rename from mindspore/ccsrc/ir/visitor.cc
rename to mindspore/core/ir/visitor.cc
diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/core/ir/visitor.h
similarity index 100%
rename from mindspore/ccsrc/ir/visitor.h
rename to mindspore/core/ir/visitor.h
diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 65fbb43133..ef19433c4d 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -17,6 +17,7 @@ message("PYTHON_INCLUDE_DIRS = ${PYTHON_INCLUDE_DIRS}")
 message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
 include_directories(${PYTHON_INCLUDE_DIRS})
 include_directories(${MS_CCSRC_PATH})
+include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
 include_directories(${CMAKE_BINARY_DIR})
@@ -27,8 +28,8 @@ link_directories(${MS_CCSRC_BUILD_PATH})
 
 if(ENABLE_MINDDATA)
     add_definitions(-D ENABLE_MINDDATA)
-    link_directories(${MS_CCSRC_BUILD_PATH}/dataset)
-    link_directories(${MS_CCSRC_BUILD_PATH}/mindrecord)
+    link_directories(${MS_CCSRC_BUILD_PATH}/minddata/dataset)
+    link_directories(${MS_CCSRC_BUILD_PATH}/minddata/mindrecord)
 endif()
 # fetch ut test files
 if(ENABLE_MINDDATA)
@@ -53,82 +54,81 @@ endif()
 
 file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         "../../../mindspore/ccsrc/base/*.cc"
        "../../../mindspore/ccsrc/abstract/*.cc"
-        "../../../mindspore/ccsrc/ir/*.cc"
+
"../../../mindspore/core/ir/*.cc" "../../../mindspore/ccsrc/common/*.cc" "../../../mindspore/ccsrc/utils/*.cc" - "../../../mindspore/ccsrc/parallel/*.cc" - "../../../mindspore/ccsrc/pipeline/parse/*.cc" - "../../../mindspore/ccsrc/pipeline/static_analysis/*.cc" - "../../../mindspore/ccsrc/pipeline/pipeline.cc" - "../../../mindspore/ccsrc/pipeline/resource.cc" - "../../../mindspore/ccsrc/pipeline/pass.cc" - "../../../mindspore/ccsrc/pipeline/action.cc" - "../../../mindspore/ccsrc/pipeline/validator.cc" - "../../../mindspore/ccsrc/pipeline/remove_value_node_dup.cc" - "../../../mindspore/ccsrc/optimizer/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/parse/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/static_analysis/*.cc" + "../../../mindspore/ccsrc/pipeline/jit/pipeline.cc" + "../../../mindspore/ccsrc/pipeline/jit/resource.cc" + "../../../mindspore/ccsrc/pipeline/jit/pass.cc" + "../../../mindspore/ccsrc/pipeline/jit/action.cc" + "../../../mindspore/ccsrc/pipeline/jit/validator.cc" + "../../../mindspore/ccsrc/pipeline/jit/remove_value_node_dup.cc" + "../../../mindspore/ccsrc/frontend/optimizer/*.cc" + "../../../mindspore/ccsrc/frontend/parallel/*.cc" "../../../mindspore/ccsrc/debug/*.cc" - "../../../mindspore/ccsrc/operator/*.cc" - "../../../mindspore/ccsrc/transform/*.cc" - "../../../mindspore/ccsrc/session/anf_runtime_algorithm.cc" - "../../../mindspore/ccsrc/session/ascend_session.cc" - "../../../mindspore/ccsrc/session/ascend_control_parser.cc" - "../../../mindspore/ccsrc/session/kernel_graph.cc" - "../../../mindspore/ccsrc/session/session_basic.cc" - "../../../mindspore/ccsrc/session/session_factory.cc" + "../../../mindspore/ccsrc/frontend/operator/*.cc" + "../../../mindspore/ccsrc/transform/graph_ir/*.cc" + "../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc" + "../../../mindspore/ccsrc/backend/session/ascend_session.cc" + "../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc" + "../../../mindspore/ccsrc/backend/session/kernel_graph.cc" + "../../../mindspore/ccsrc/backend/session/session_basic.cc" + "../../../mindspore/ccsrc/backend/session/session_factory.cc" "../../../mindspore/ccsrc/vm/*.cc" - "../../../mindspore/ccsrc/pynative/*.cc" + "../../../mindspore/ccsrc/pipeline/pynative/*.cc" "../../../mindspore/ccsrc/pybind_api/*.cc" - "../../../mindspore/ccsrc/kernel/akg/*.cc" - "../../../mindspore/ccsrc/kernel/kash/*.cc" - "../../../mindspore/ccsrc/kernel/cce/*.cc" - "../../../mindspore/ccsrc/kernel/rts/*.cc" - "../../../mindspore/ccsrc/kernel/hccl/*.cc" - "../../../mindspore/ccsrc/kernel/kernel_query.cc" - "../../../mindspore/ccsrc/kernel/kernel_build_info.cc" - "../../../mindspore/ccsrc/pre_activate/ascend/*.cc" - "../../../mindspore/ccsrc/pre_activate/common/*.cc" - "../../../mindspore/ccsrc/pre_activate/gpu/*.cc" - "../../../mindspore/ccsrc/pre_activate/mem_reuse/*.cc" - "../../../mindspore/ccsrc/pre_activate/pass/*.cc" - "../../../mindspore/ccsrc/kernel/aicpu/aicpu_kernel_metadata.cc" - "../../../mindspore/ccsrc/kernel/rts/rt_kernel_info.cc" - "../../../mindspore/ccsrc/kernel/common_utils.cc" - "../../../mindspore/ccsrc/kernel/oplib/*.cc" - "../../../mindspore/ccsrc/kernel/tbe/*.cc" - "../../../mindspore/ccsrc/device/kernel_runtime.cc" - "../../../mindspore/ccsrc/device/memory_manager.cc" - "../../../mindspore/ccsrc/device/kernel_runtime_manager.cc" - "../../../mindspore/ccsrc/device/kernel_info.cc" - "../../../mindspore/ccsrc/device/ascend/profiling/*.cc" - "../../../mindspore/ccsrc/device/ascend/kernel_select_ascend.cc" - 
"../../../mindspore/ccsrc/device/ascend/kernel_select_graph_kernel.cc" - "../../../mindspore/ccsrc/device/convert_tensor_utils.cc" - "../../../mindspore/ccsrc/device/ascend/kernel_build_ascend.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_memory_manager.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_device_address.cc" - "../../../mindspore/ccsrc/device/ascend/ascend_memory_pool.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/akg/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kash/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/rts/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/hccl/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc" + "../../../mindspore/ccsrc/backend/optimizer/ascend/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/common/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/gpu/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/mem_reuse/*.cc" + "../../../mindspore/ccsrc/backend/optimizer/pass/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/common_utils.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/oplib/*.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/tbe/*.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc" + "../../../mindspore/ccsrc/runtime/device/memory_manager.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc" + "../../../mindspore/ccsrc/runtime/device/kernel_info.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/profiling/*.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc" + "../../../mindspore/ccsrc/runtime/device/convert_tensor_utils.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc" + "../../../mindspore/ccsrc/runtime/device/ascend/ascend_memory_pool.cc" "../../../mindspore/ccsrc/predict/generator/utils/ir_model_util.cc" "../../../mindspore/ccsrc/predict/predict.cc" "../../../mindspore/ccsrc/predict/converter/*.cc" "../../../mindspore/ccsrc/predict/converter/attr_utils/*.cc" "../../../mindspore/ccsrc/predict/converter/lite_model/*.cc" "../../../mindspore/ccsrc/predict/converter/lite_model/operations/*.cc" - "../../../mindspore/ccsrc/kernel/cpu/cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/cpu_kernel_factory.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_adam_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.cc" - "../../../mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.cc" + 
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.cc" + "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.cc" ) list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/debug/dump_proto.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ir/lite/tensor.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/ps/util.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/ps/scheduler.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/ps/optimizer_info.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/parallel/ps/optimizer_info_builder.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/core/ir/lite/tensor.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/util.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/scheduler.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/frontend/parallel/ps/optimizer_info_builder.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/anf_ir.pb.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/node_strategy.pb.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc") diff --git a/tests/ut/cpp/abstract/abstract_test.cc b/tests/ut/cpp/abstract/abstract_test.cc index ea0b5e5b61..2e3a2a8d1a 100644 --- a/tests/ut/cpp/abstract/abstract_test.cc +++ b/tests/ut/cpp/abstract/abstract_test.cc @@ -18,13 +18,13 @@ #include "common/common_test.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/static_analysis.h" #include "abstract/utils.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/resolve.h" -#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/resolve.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" namespace mindspore { namespace abstract { diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/abstract/utils_test.cc index fbc6b3c3e2..33cada28d7 100644 --- a/tests/ut/cpp/abstract/utils_test.cc +++ b/tests/ut/cpp/abstract/utils_test.cc @@ -16,7 +16,7 @@ #include "abstract/utils.h" #include "common/common_test.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "pipeline/jit/static_analysis/static_analysis.h" namespace mindspore { namespace abstract { diff --git a/tests/ut/cpp/common/backend_common_test.cc b/tests/ut/cpp/common/backend_common_test.cc index 060b170a8c..3710349298 100644 --- a/tests/ut/cpp/common/backend_common_test.cc +++ b/tests/ut/cpp/common/backend_common_test.cc @@ -20,11 +20,11 @@ #include #include "utils/log_adapter.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" -#include "session/ascend_session.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" +#include 
"backend/session/ascend_session.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" #include "ir/anf.h" #include "ir/manager.h" diff --git a/tests/ut/cpp/common/backend_common_test.h b/tests/ut/cpp/common/backend_common_test.h index fb3334182a..f5bfc9d6dd 100644 --- a/tests/ut/cpp/common/backend_common_test.h +++ b/tests/ut/cpp/common/backend_common_test.h @@ -17,7 +17,7 @@ #define TESTS_UT_CPP_COMMON_UT_BACKEND_COMMON_H_ #include "common/common_test.h" #include "utils/context/ms_context.h" -#include "session/kernel_graph.h" +#include "backend/session/kernel_graph.h" namespace mindspore { class BackendCommon : public UT::Common { diff --git a/tests/ut/cpp/common/py_func_graph_fetcher.h b/tests/ut/cpp/common/py_func_graph_fetcher.h index 98552a96b5..d864842760 100644 --- a/tests/ut/cpp/common/py_func_graph_fetcher.h +++ b/tests/ut/cpp/common/py_func_graph_fetcher.h @@ -22,8 +22,8 @@ #include "ir/primitive.h" #include "ir/manager.h" #include "ir/func_graph.h" -#include "pipeline/parse/parse_base.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse_base.h" +#include "pipeline/jit/parse/parse.h" #include "./common.h" namespace UT { diff --git a/tests/ut/cpp/common/test_main.cc b/tests/ut/cpp/common/test_main.cc index f0cfc1778c..fa456ed260 100644 --- a/tests/ut/cpp/common/test_main.cc +++ b/tests/ut/cpp/common/test_main.cc @@ -16,8 +16,8 @@ #include #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "pipeline/pipeline.h" -#include "pipeline/resource.h" +#include "pipeline/jit/pipeline.h" +#include "pipeline/jit/resource.h" namespace mindspore { extern void InitSubModulesLogLevel(); diff --git a/tests/ut/cpp/dataset/arena_test.cc b/tests/ut/cpp/dataset/arena_test.cc index e8698ad979..10d27b51c6 100644 --- a/tests/ut/cpp/dataset/arena_test.cc +++ b/tests/ut/cpp/dataset/arena_test.cc @@ -15,7 +15,7 @@ */ #include -#include "dataset/util/arena.h" +#include "minddata/dataset/util/arena.h" #include "common/common.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/batch_op_test.cc b/tests/ut/cpp/dataset/batch_op_test.cc index a04da06e4e..3e1f3c0b32 100644 --- a/tests/ut/cpp/dataset/batch_op_test.cc +++ b/tests/ut/cpp/dataset/batch_op_test.cc @@ -16,14 +16,14 @@ #include #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" +#include "minddata/dataset/core/global_context.h" #include "utils/log_adapter.h" #include "securec.h" -#include "dataset/util/status.h" +#include "minddata/dataset/util/status.h" namespace common = mindspore::common; namespace de = mindspore::dataset; diff --git a/tests/ut/cpp/dataset/bit_functions_test.cc b/tests/ut/cpp/dataset/bit_functions_test.cc index 02b6a25f76..cf1c1562db 100644 --- a/tests/ut/cpp/dataset/bit_functions_test.cc +++ b/tests/ut/cpp/dataset/bit_functions_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" #include "common/common.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc b/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc index 4633eefe35..dc59d39fac 100644 --- a/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc +++ b/tests/ut/cpp/dataset/bounding_box_augment_op_test.cc @@ -14,8 +14,8 @@ * limitations under the License. */ #include "common/bboxop_common.h" -#include "dataset/kernels/image/bounding_box_augment_op.h" -#include "dataset/kernels/image/random_rotation_op.h" +#include "minddata/dataset/kernels/image/bounding_box_augment_op.h" +#include "minddata/dataset/kernels/image/random_rotation_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/btree_test.cc b/tests/ut/cpp/dataset/btree_test.cc index 67b6c4e6c7..9fa4fce812 100644 --- a/tests/ut/cpp/dataset/btree_test.cc +++ b/tests/ut/cpp/dataset/btree_test.cc @@ -15,10 +15,10 @@ */ #include -#include "dataset/util/btree.h" -#include "dataset/util/auto_index.h" -#include "dataset/util/system_pool.h" -#include "dataset/util/task_manager.h" +#include "minddata/dataset/util/btree.h" +#include "minddata/dataset/util/auto_index.h" +#include "minddata/dataset/util/system_pool.h" +#include "minddata/dataset/util/task_manager.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/c_api_test.cc b/tests/ut/cpp/dataset/c_api_test.cc index 385b327768..902bc9a43b 100644 --- a/tests/ut/cpp/dataset/c_api_test.cc +++ b/tests/ut/cpp/dataset/c_api_test.cc @@ -24,12 +24,12 @@ #include "common/common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/include/datasets.h" -#include "dataset/include/status.h" -#include "dataset/include/transforms.h" -#include "dataset/include/iterator.h" -#include "dataset/core/constants.h" -#include "dataset/include/samplers.h" +#include "minddata/dataset/include/datasets.h" +#include "minddata/dataset/include/status.h" +#include "minddata/dataset/include/transforms.h" +#include "minddata/dataset/include/iterator.h" +#include "minddata/dataset/core/constants.h" +#include "minddata/dataset/include/samplers.h" using namespace mindspore::dataset::api; using mindspore::MsLogLevel::ERROR; diff --git a/tests/ut/cpp/dataset/cache_op_test.cc b/tests/ut/cpp/dataset/cache_op_test.cc index a31a8f8ddf..bdb7c861b2 100644 --- a/tests/ut/cpp/dataset/cache_op_test.cc +++ b/tests/ut/cpp/dataset/cache_op_test.cc @@ -14,19 +14,19 @@ * limitations under the License. 
*/ #include -#include "dataset/core/client.h" -#include "dataset/engine/cache/cache_client.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/cache_op.h" -#include "dataset/engine/datasetops/cache_lookup_op.h" -#include "dataset/engine/datasetops/cache_merge_op.h" -#include "dataset/engine/datasetops/source/image_folder_op.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/cache/cache_client.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/cache_op.h" +#include "minddata/dataset/engine/datasetops/cache_lookup_op.h" +#include "minddata/dataset/engine/datasetops/cache_merge_op.h" +#include "minddata/dataset/engine/datasetops/source/image_folder_op.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "dataset/util/storage_container.h" // lint !e322 -#include "dataset/engine/datasetops/source/random_data_op.h" -#include "dataset/engine/data_schema.h" +#include "minddata/dataset/util/storage_container.h" // lint !e322 +#include "minddata/dataset/engine/datasetops/source/random_data_op.h" +#include "minddata/dataset/engine/data_schema.h" using namespace mindspore::dataset; using mindspore::LogStream; diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index a109739fda..ccaed122f4 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -19,11 +19,11 @@ #include #include "common/common.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/celeba_op.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/celeba_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/center_crop_op_test.cc b/tests/ut/cpp/dataset/center_crop_op_test.cc index 54c45c957e..cd0f362f64 100644 --- a/tests/ut/cpp/dataset/center_crop_op_test.cc +++ b/tests/ut/cpp/dataset/center_crop_op_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/center_crop_op.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/center_crop_op.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/channel_swap_test.cc b/tests/ut/cpp/dataset/channel_swap_test.cc index f1dc1396ca..2000de15b2 100644 --- a/tests/ut/cpp/dataset/channel_swap_test.cc +++ b/tests/ut/cpp/dataset/channel_swap_test.cc @@ -15,8 +15,8 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/hwc_to_chw_op.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" +#include "minddata/dataset/core/data_type.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cifar_op_test.cc b/tests/ut/cpp/dataset/cifar_op_test.cc index b37b9acaee..ed22f4f347 100644 --- a/tests/ut/cpp/dataset/cifar_op_test.cc +++ b/tests/ut/cpp/dataset/cifar_op_test.cc @@ -20,14 +20,14 @@ #include 
"common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/cifar_op.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/cifar_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/circular_pool_test.cc b/tests/ut/cpp/dataset/circular_pool_test.cc index c42b08ddcd..d06f846684 100644 --- a/tests/ut/cpp/dataset/circular_pool_test.cc +++ b/tests/ut/cpp/dataset/circular_pool_test.cc @@ -15,9 +15,9 @@ */ #include #include -#include "dataset/util/task_manager.h" -#include "dataset/util/circular_pool.h" -#include "dataset/util/services.h" +#include "minddata/dataset/util/task_manager.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/util/services.h" #include "common/common.h" #include "common/utils.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/client_config_test.cc b/tests/ut/cpp/dataset/client_config_test.cc index a907d50134..5cc9600b4e 100644 --- a/tests/ut/cpp/dataset/client_config_test.cc +++ b/tests/ut/cpp/dataset/client_config_test.cc @@ -20,11 +20,11 @@ #include #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" -#include "dataset/core/global_context.h" -#include "dataset/util/status.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/clue_op_test.cc b/tests/ut/cpp/dataset/clue_op_test.cc index ff2f01a9ff..0935434a06 100644 --- a/tests/ut/cpp/dataset/clue_op_test.cc +++ b/tests/ut/cpp/dataset/clue_op_test.cc @@ -17,13 +17,13 @@ #include #include -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "dataset/engine/datasetops/source/clue_op.h" -#include "dataset/util/status.h" +#include "minddata/dataset/engine/datasetops/source/clue_op.h" +#include "minddata/dataset/util/status.h" namespace common = mindspore::common; diff --git a/tests/ut/cpp/dataset/coco_op_test.cc b/tests/ut/cpp/dataset/coco_op_test.cc index bcb82f8ec1..6e6d3c26e5 100644 --- a/tests/ut/cpp/dataset/coco_op_test.cc +++ b/tests/ut/cpp/dataset/coco_op_test.cc @@ -20,18 +20,18 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" -#include "dataset/core/global_context.h" -#include "dataset/engine/datasetops/source/coco_op.h" -#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h" -#include 
"dataset/engine/datasetops/source/sampler/pk_sampler.h" -#include "dataset/engine/datasetops/source/sampler/random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/sampler.h" -#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" -#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" -#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" -#include "dataset/util/path.h" -#include "dataset/util/status.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/core/global_context.h" +#include "minddata/dataset/engine/datasetops/source/coco_op.h" +#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h" +#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/util/status.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "securec.h" diff --git a/tests/ut/cpp/dataset/common/bboxop_common.cc b/tests/ut/cpp/dataset/common/bboxop_common.cc index e4be1fbbe6..62c9f85348 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.cc +++ b/tests/ut/cpp/dataset/common/bboxop_common.cc @@ -26,9 +26,9 @@ #include "./tinyxml2.h" #include "opencv2/opencv.hpp" #include "common/utils.h" -#include "dataset/core/cv_tensor.h" -#include "dataset/util/path.h" -#include "dataset/core/constants.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/util/path.h" +#include "minddata/dataset/core/constants.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/common/bboxop_common.h b/tests/ut/cpp/dataset/common/bboxop_common.h index ba3ceb62d9..243908e7a3 100644 --- a/tests/ut/cpp/dataset/common/bboxop_common.h +++ b/tests/ut/cpp/dataset/common/bboxop_common.h @@ -17,7 +17,7 @@ #define TESTS_DATASET_UT_CORE_COMMON_DE_UT_BBOXOP_COMMON_H_ #include "cvop_common.h" -#include "dataset/util/path.h" +#include "minddata/dataset/util/path.h" namespace UT { namespace CVOP { diff --git a/tests/ut/cpp/dataset/common/cvop_common.cc b/tests/ut/cpp/dataset/common/cvop_common.cc index 6f66229e80..48d69564fd 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.cc +++ b/tests/ut/cpp/dataset/common/cvop_common.cc @@ -18,9 +18,9 @@ #include #include #include "cvop_common.h" -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" #include "common/utils.h" -#include "dataset/core/cv_tensor.h" +#include "minddata/dataset/core/cv_tensor.h" #include "utils/log_adapter.h" #include #include diff --git a/tests/ut/cpp/dataset/common/cvop_common.h b/tests/ut/cpp/dataset/common/cvop_common.h index 02c079fd68..59134091fd 100644 --- a/tests/ut/cpp/dataset/common/cvop_common.h +++ b/tests/ut/cpp/dataset/common/cvop_common.h @@ -19,7 +19,7 @@ #include #include #include "common.h" -#include "dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/image_utils.h" namespace UT { namespace CVOP { diff --git a/tests/ut/cpp/dataset/concat_op_test.cc b/tests/ut/cpp/dataset/concat_op_test.cc index 70d0268ec7..9e991ce0d3 100644 
--- a/tests/ut/cpp/dataset/concat_op_test.cc +++ b/tests/ut/cpp/dataset/concat_op_test.cc @@ -19,7 +19,7 @@ #include "common/common.h" #include "common/utils.h" -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/concatenate_op_test.cc b/tests/ut/cpp/dataset/concatenate_op_test.cc index 1ceedbac38..dc2fc69266 100644 --- a/tests/ut/cpp/dataset/concatenate_op_test.cc +++ b/tests/ut/cpp/dataset/concatenate_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/concatenate_op.h" +#include "minddata/dataset/kernels/data/concatenate_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/connector_test.cc b/tests/ut/cpp/dataset/connector_test.cc index 7ee36cc2c0..0fc5b100d7 100644 --- a/tests/ut/cpp/dataset/connector_test.cc +++ b/tests/ut/cpp/dataset/connector_test.cc @@ -23,8 +23,8 @@ #include "common/common.h" -#include "dataset/engine/connector.h" -#include "dataset/util/task_manager.h" +#include "minddata/dataset/engine/connector.h" +#include "minddata/dataset/util/task_manager.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cut_out_op_test.cc b/tests/ut/cpp/dataset/cut_out_op_test.cc index 462fb3a875..5d24d9c3f9 100644 --- a/tests/ut/cpp/dataset/cut_out_op_test.cc +++ b/tests/ut/cpp/dataset/cut_out_op_test.cc @@ -15,7 +15,7 @@ */ #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/cut_out_op.h" +#include "minddata/dataset/kernels/image/cut_out_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/cyclic_array_test.cc b/tests/ut/cpp/dataset/cyclic_array_test.cc index 55f75c403f..380436de1b 100644 --- a/tests/ut/cpp/dataset/cyclic_array_test.cc +++ b/tests/ut/cpp/dataset/cyclic_array_test.cc @@ -19,7 +19,7 @@ #include "common/cvop_common.h" #include "gtest/gtest.h" #include "securec.h" -#include "dataset/engine/perf/cyclic_array.h" +#include "minddata/dataset/engine/perf/cyclic_array.h" #include using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/datatype_test.cc b/tests/ut/cpp/dataset/datatype_test.cc index 8cb2210228..b81618dc24 100644 --- a/tests/ut/cpp/dataset/datatype_test.cc +++ b/tests/ut/cpp/dataset/datatype_test.cc @@ -15,11 +15,11 @@ */ #include #include "./securec.h" -#include "dataset/core/data_type.h" +#include "minddata/dataset/core/data_type.h" #include "common/common.h" #include "gtest/gtest.h" #include -#include "dataset/core/constants.h" +#include "minddata/dataset/core/constants.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/decode_op_test.cc b/tests/ut/cpp/dataset/decode_op_test.cc index 7f3e129ac0..1cd03099ce 100644 --- a/tests/ut/cpp/dataset/decode_op_test.cc +++ b/tests/ut/cpp/dataset/decode_op_test.cc @@ -16,7 +16,7 @@ #include #include "common/common.h" #include "common/cvop_common.h" -#include "dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/decode_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/duplicate_op_test.cc b/tests/ut/cpp/dataset/duplicate_op_test.cc index b7ce32f655..93779b084d 100644 --- a/tests/ut/cpp/dataset/duplicate_op_test.cc +++ b/tests/ut/cpp/dataset/duplicate_op_test.cc @@ -13,11 +13,11 @@ * See the License for the specific language 
governing permissions and * limitations under the License. */ -#include "dataset/core/client.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" -#include "dataset/core/tensor.h" -#include "dataset/kernels/data/duplicate_op.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/data/duplicate_op.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/execution_tree_test.cc b/tests/ut/cpp/dataset/execution_tree_test.cc index 529644331a..b871dd00d8 100644 --- a/tests/ut/cpp/dataset/execution_tree_test.cc +++ b/tests/ut/cpp/dataset/execution_tree_test.cc @@ -14,11 +14,11 @@ * limitations under the License. */ #include -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" +#include "minddata/dataset/engine/execution_tree.h" +#include "minddata/dataset/engine/datasetops/shuffle_op.h" +#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/fill_op_test.cc b/tests/ut/cpp/dataset/fill_op_test.cc index d43b7d7548..20e323cc8d 100644 --- a/tests/ut/cpp/dataset/fill_op_test.cc +++ b/tests/ut/cpp/dataset/fill_op_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ #include "common/common.h" -#include "dataset/kernels/data/fill_op.h" +#include "minddata/dataset/kernels/data/fill_op.h" #include "utils/log_adapter.h" using namespace mindspore::dataset; diff --git a/tests/ut/cpp/dataset/filter_op_test.cc b/tests/ut/cpp/dataset/filter_op_test.cc index 45ee714337..3e5be8dc04 100644 --- a/tests/ut/cpp/dataset/filter_op_test.cc +++ b/tests/ut/cpp/dataset/filter_op_test.cc @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "dataset/util/circular_pool.h" -#include "dataset/core/client.h" +#include "minddata/dataset/util/circular_pool.h" +#include "minddata/dataset/core/client.h" #include "common/common.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/dataset/global_context_test.cc b/tests/ut/cpp/dataset/global_context_test.cc index bb75d941aa..cd4c970ae6 100644 --- a/tests/ut/cpp/dataset/global_context_test.cc +++ b/tests/ut/cpp/dataset/global_context_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/global_context.h"
 #include "common/common.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc
index 584fde5cef..c4dd7b055c 100644
--- a/tests/ut/cpp/dataset/gnn_graph_test.cc
+++ b/tests/ut/cpp/dataset/gnn_graph_test.cc
@@ -20,9 +20,9 @@
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/util/status.h"
-#include "dataset/engine/gnn/node.h"
-#include "dataset/engine/gnn/graph_loader.h"
+#include "minddata/dataset/util/status.h"
+#include "minddata/dataset/engine/gnn/node.h"
+#include "minddata/dataset/engine/gnn/graph_loader.h"
 using namespace mindspore::dataset;
 using namespace mindspore::dataset::gnn;
diff --git a/tests/ut/cpp/dataset/image_folder_op_test.cc b/tests/ut/cpp/dataset/image_folder_op_test.cc
index 576c5abbfc..3168efa196 100644
--- a/tests/ut/cpp/dataset/image_folder_op_test.cc
+++ b/tests/ut/cpp/dataset/image_folder_op_test.cc
@@ -19,18 +19,18 @@
 #include
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/source/image_folder_op.h"
-#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
-#include "dataset/util/path.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/interrupt_test.cc b/tests/ut/cpp/dataset/interrupt_test.cc
index 7ab608b9ae..8a06413175 100644
--- a/tests/ut/cpp/dataset/interrupt_test.cc
+++ b/tests/ut/cpp/dataset/interrupt_test.cc
@@ -15,10 +15,10 @@
 */
 #include "common/common.h"
 #include "utils/log_adapter.h"
-#include "dataset/util/services.h"
-#include "dataset/util/intrp_service.h"
-#include "dataset/util/task_manager.h"
-#include "dataset/util/queue.h"
+#include "minddata/dataset/util/services.h"
+#include "minddata/dataset/util/intrp_service.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "minddata/dataset/util/queue.h"
 using namespace mindspore::dataset;
 using mindspore::MsLogLevel::INFO;
diff --git a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc
index 849943beb1..85b3384d36 100644
--- a/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc
+++ b/tests/ut/cpp/dataset/jieba_tokenizer_op_test.cc
@@ -18,7 +18,7 @@
 #include
 #include "common/common.h"
-#include "dataset/text/kernels/jieba_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/jieba_tokenizer_op.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/manifest_op_test.cc b/tests/ut/cpp/dataset/manifest_op_test.cc
index 6317a6a345..a6eef4aaa2 100644
--- a/tests/ut/cpp/dataset/manifest_op_test.cc
+++ b/tests/ut/cpp/dataset/manifest_op_test.cc
@@ -20,12 +20,12 @@
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/source/manifest_op.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc
index e5deac723f..4e9cfe9ec9 100644
--- a/tests/ut/cpp/dataset/map_op_test.cc
+++ b/tests/ut/cpp/dataset/map_op_test.cc
@@ -19,12 +19,12 @@
 #include "common/common.h"
-#include "dataset/core/client.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/datasetops/source/image_folder_op.h"
-#include "dataset/kernels/image/decode_op.h"
-#include "dataset/kernels/image/resize_op.h"
-#include "dataset/kernels/tensor_op.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/kernels/image/resize_op.h"
+#include "minddata/dataset/kernels/tensor_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/mask_test.cc b/tests/ut/cpp/dataset/mask_test.cc
index 9ff5f51fce..609d5bf447 100644
--- a/tests/ut/cpp/dataset/mask_test.cc
+++ b/tests/ut/cpp/dataset/mask_test.cc
@@ -15,15 +15,15 @@
 */
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "securec.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/data_type.h"
-#include "dataset/kernels/data/mask_op.h"
-#include "dataset/kernels/data/data_utils.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/kernels/data/mask_op.h"
+#include "minddata/dataset/kernels/data/data_utils.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/memory_pool_test.cc b/tests/ut/cpp/dataset/memory_pool_test.cc
index 136f3fe1b8..b5907655dc 100644
--- a/tests/ut/cpp/dataset/memory_pool_test.cc
+++ b/tests/ut/cpp/dataset/memory_pool_test.cc
@@ -14,10 +14,10 @@
 * limitations under the License.
 */
-#include "dataset/util/memory_pool.h"
-#include "dataset/util/circular_pool.h"
-#include "dataset/util/system_pool.h"
-#include "dataset/util/allocator.h"
+#include "minddata/dataset/util/memory_pool.h"
+#include "minddata/dataset/util/circular_pool.h"
+#include "minddata/dataset/util/system_pool.h"
+#include "minddata/dataset/util/allocator.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
diff --git a/tests/ut/cpp/dataset/mind_record_op_test.cc b/tests/ut/cpp/dataset/mind_record_op_test.cc
index b2cbdf027e..c9067535d6 100644
--- a/tests/ut/cpp/dataset/mind_record_op_test.cc
+++ b/tests/ut/cpp/dataset/mind_record_op_test.cc
@@ -16,14 +16,14 @@
 #include
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "common/utils.h"
 #include "gtest/gtest.h"
-#include "mindrecord/include/shard_category.h"
-#include "mindrecord/include/shard_error.h"
-#include "mindrecord/include/shard_sample.h"
-#include "mindrecord/include/shard_shuffle.h"
+#include "minddata/mindrecord/include/shard_category.h"
+#include "minddata/mindrecord/include/shard_error.h"
+#include "minddata/mindrecord/include/shard_sample.h"
+#include "minddata/mindrecord/include/shard_shuffle.h"
 #include "utils/log_adapter.h"
 namespace common = mindspore::common;
diff --git a/tests/ut/cpp/dataset/mnist_op_test.cc b/tests/ut/cpp/dataset/mnist_op_test.cc
index da78cb6f7f..dfceeaa06a 100644
--- a/tests/ut/cpp/dataset/mnist_op_test.cc
+++ b/tests/ut/cpp/dataset/mnist_op_test.cc
@@ -20,18 +20,18 @@
 #include "common/utils.h"
 #include "common/common.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/source/mnist_op.h"
-#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
-#include "dataset/util/path.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/normalize_op_test.cc b/tests/ut/cpp/dataset/normalize_op_test.cc
index 05ac3f6289..31791e0e66 100644
--- a/tests/ut/cpp/dataset/normalize_op_test.cc
+++ b/tests/ut/cpp/dataset/normalize_op_test.cc
@@ -15,8 +15,8 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/normalize_op.h"
-#include "dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/normalize_op.h"
+#include "minddata/dataset/core/cv_tensor.h"
 #include "utils/log_adapter.h"
 #include
diff --git a/tests/ut/cpp/dataset/one_hot_op_test.cc b/tests/ut/cpp/dataset/one_hot_op_test.cc
index c414e371e5..2617ae4536 100644
--- a/tests/ut/cpp/dataset/one_hot_op_test.cc
+++ b/tests/ut/cpp/dataset/one_hot_op_test.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "common/common.h"
-#include "dataset/kernels/data/one_hot_op.h"
+#include "minddata/dataset/kernels/data/one_hot_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/pad_end_op_test.cc b/tests/ut/cpp/dataset/pad_end_op_test.cc
index 2787501aa9..1c838da8e8 100644
--- a/tests/ut/cpp/dataset/pad_end_op_test.cc
+++ b/tests/ut/cpp/dataset/pad_end_op_test.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "common/common.h"
-#include "dataset/kernels/data/pad_end_op.h"
+#include "minddata/dataset/kernels/data/pad_end_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/pad_op_test.cc b/tests/ut/cpp/dataset/pad_op_test.cc
index b659d009f3..e2bd822d02 100644
--- a/tests/ut/cpp/dataset/pad_op_test.cc
+++ b/tests/ut/cpp/dataset/pad_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/pad_op.h"
+#include "minddata/dataset/kernels/image/pad_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/path_test.cc b/tests/ut/cpp/dataset/path_test.cc
index 4cf3b17968..b36b38bbc7 100644
--- a/tests/ut/cpp/dataset/path_test.cc
+++ b/tests/ut/cpp/dataset/path_test.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/util/path.h"
+#include "minddata/dataset/util/path.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/perf_data_test.cc b/tests/ut/cpp/dataset/perf_data_test.cc
index 048ee1f21a..486209be21 100644
--- a/tests/ut/cpp/dataset/perf_data_test.cc
+++ b/tests/ut/cpp/dataset/perf_data_test.cc
@@ -17,8 +17,8 @@
 #include "common/cvop_common.h"
 #include "gtest/gtest.h"
 #include "securec.h"
-#include "dataset/engine/perf/cyclic_array.h"
-#include "dataset/engine/perf/perf_data.h"
+#include "minddata/dataset/engine/perf/cyclic_array.h"
+#include "minddata/dataset/engine/perf/perf_data.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/project_op_test.cc b/tests/ut/cpp/dataset/project_op_test.cc
index 484396321c..45ef11b88f 100644
--- a/tests/ut/cpp/dataset/project_op_test.cc
+++ b/tests/ut/cpp/dataset/project_op_test.cc
@@ -19,7 +19,7 @@
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/queue_test.cc b/tests/ut/cpp/dataset/queue_test.cc
index 05c80ea50f..ec40cc2ae4 100644
--- a/tests/ut/cpp/dataset/queue_test.cc
+++ b/tests/ut/cpp/dataset/queue_test.cc
@@ -16,8 +16,8 @@
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/util/task_manager.h"
-#include "dataset/util/queue.h"
+#include "minddata/dataset/util/task_manager.h"
+#include "minddata/dataset/util/queue.h"
 #include
 #include
 #include
diff --git a/tests/ut/cpp/dataset/random_color_adjust_op_test.cc b/tests/ut/cpp/dataset/random_color_adjust_op_test.cc
index 82df108ad1..96f4dd8145 100644
--- a/tests/ut/cpp/dataset/random_color_adjust_op_test.cc
+++ b/tests/ut/cpp/dataset/random_color_adjust_op_test.cc
@@ -15,8 +15,8 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_color_adjust_op.h"
-#include "dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
+#include "minddata/dataset/core/cv_tensor.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc b/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc
index 3d5298b071..fd59a90117 100644
--- a/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc
+++ b/tests/ut/cpp/dataset/random_crop_and_resize_op_test.cc
@@ -16,7 +16,7 @@
 #include "common/common.h"
 #include "common/cvop_common.h"
 #include
-#include "dataset/kernels/image/random_crop_and_resize_op.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc
index a1d4481f55..4efdcb8b78 100644
--- a/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc
+++ b/tests/ut/cpp/dataset/random_crop_and_resize_with_bbox_op_test.cc
@@ -14,11 +14,11 @@
 * limitations under the License.
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_with_bbox_op.h"
 #include "utils/log_adapter.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/global_context.h"
 using namespace mindspore::dataset;
 using mindspore::LogStream;
diff --git a/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc b/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
index a2ed2fe9f1..170525b4e7 100644
--- a/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
+++ b/tests/ut/cpp/dataset/random_crop_decode_resize_op_test.cc
@@ -16,10 +16,10 @@
 #include
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/decode_op.h"
-#include "dataset/kernels/image/random_crop_and_resize_op.h"
-#include "dataset/kernels/image/random_crop_decode_resize_op.h"
-#include "dataset/core/config_manager.h"
+#include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
+#include "minddata/dataset/kernels/image/random_crop_decode_resize_op.h"
+#include "minddata/dataset/core/config_manager.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_crop_op_test.cc b/tests/ut/cpp/dataset/random_crop_op_test.cc
index 2f3b19e2f4..9c8f1f31ed 100644
--- a/tests/ut/cpp/dataset/random_crop_op_test.cc
+++ b/tests/ut/cpp/dataset/random_crop_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_crop_op.h"
+#include "minddata/dataset/kernels/image/random_crop_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc
index 3790574e02..fcf8ba2605 100644
--- a/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc
+++ b/tests/ut/cpp/dataset/random_crop_with_bbox_op_test.cc
@@ -15,11 +15,11 @@
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/random_crop_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/random_crop_with_bbox_op.h"
 #include "utils/log_adapter.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/global_context.h"
 using namespace mindspore::dataset;
 using mindspore::LogStream;
diff --git a/tests/ut/cpp/dataset/random_data_op_test.cc b/tests/ut/cpp/dataset/random_data_op_test.cc
index f8a7440c03..3cb7b57ad6 100644
--- a/tests/ut/cpp/dataset/random_data_op_test.cc
+++ b/tests/ut/cpp/dataset/random_data_op_test.cc
@@ -14,15 +14,15 @@
 * limitations under the License.
 */
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include
 #include
 #include
-#include "dataset/core/tensor_shape.h"
-#include "dataset/engine/datasetops/source/random_data_op.h"
-#include "dataset/engine/data_schema.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
+#include "minddata/dataset/engine/data_schema.h"
 using namespace mindspore::dataset;
 using mindspore::MsLogLevel::INFO;
diff --git a/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc b/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc
index eb2f753554..bb4ba7498d 100644
--- a/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc
+++ b/tests/ut/cpp/dataset/random_horizontal_flip_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_horizontal_flip_op.h"
+#include "minddata/dataset/kernels/image/random_horizontal_flip_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc b/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc
index 7bdd547918..ed4e866478 100644
--- a/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc
+++ b/tests/ut/cpp/dataset/random_horizontal_flip_with_bbox_test.cc
@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/random_horizontal_flip_with_bbox_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_resize_op_test.cc b/tests/ut/cpp/dataset/random_resize_op_test.cc
index ee185f2fc6..d9e85de6e5 100644
--- a/tests/ut/cpp/dataset/random_resize_op_test.cc
+++ b/tests/ut/cpp/dataset/random_resize_op_test.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/kernels/image/random_resize_op.h"
+#include "minddata/dataset/kernels/image/random_resize_op.h"
 #include "common/common.h"
 #include "common/cvop_common.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc
index 01e2bf3fbb..e106f57375 100644
--- a/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc
+++ b/tests/ut/cpp/dataset/random_resize_with_bbox_op_test.cc
@@ -15,11 +15,11 @@
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/random_resize_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/random_resize_with_bbox_op.h"
 #include "utils/log_adapter.h"
-#include "dataset/core/config_manager.h"
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/global_context.h"
 using namespace mindspore::dataset;
 using mindspore::LogStream;
diff --git a/tests/ut/cpp/dataset/random_rotation_op_test.cc b/tests/ut/cpp/dataset/random_rotation_op_test.cc
index 8b82ef1dcd..a6eb5a1ff3 100644
--- a/tests/ut/cpp/dataset/random_rotation_op_test.cc
+++ b/tests/ut/cpp/dataset/random_rotation_op_test.cc
@@ -16,8 +16,8 @@
 #include
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_rotation_op.h"
-#include "dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/image/random_rotation_op.h"
+#include "minddata/dataset/core/cv_tensor.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc b/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc
index a2583cab96..db8cc89893 100644
--- a/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc
+++ b/tests/ut/cpp/dataset/random_vertical_flip_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_vertical_flip_op.h"
+#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc b/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc
index 2fea8c6c34..d1946ef700 100644
--- a/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc
+++ b/tests/ut/cpp/dataset/random_vertical_flip_with_bbox_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/rename_op_test.cc b/tests/ut/cpp/dataset/rename_op_test.cc
index b6849ec53e..f2091ff466 100644
--- a/tests/ut/cpp/dataset/rename_op_test.cc
+++ b/tests/ut/cpp/dataset/rename_op_test.cc
@@ -17,15 +17,15 @@
 #include
 #include
 #include
-#include "dataset/core/client.h"
-#include "dataset/core/constants.h"
-#include "dataset/engine/datasetops/map_op.h"
-#include "dataset/engine/datasetops/rename_op.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include "minddata/dataset/engine/datasetops/rename_op.h"
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_buffer.h"
 #include "gtest/gtest.h"
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/global_context.h"
 #include "utils/log_adapter.h"
 namespace common = mindspore::common;
diff --git a/tests/ut/cpp/dataset/repeat_op_test.cc b/tests/ut/cpp/dataset/repeat_op_test.cc
index 42549546ba..74d494c0dc 100644
--- a/tests/ut/cpp/dataset/repeat_op_test.cc
+++ b/tests/ut/cpp/dataset/repeat_op_test.cc
@@ -13,8 +13,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/util/circular_pool.h"
-#include "dataset/core/client.h"
+#include "minddata/dataset/util/circular_pool.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/rescale_op_test.cc b/tests/ut/cpp/dataset/rescale_op_test.cc
index 86abbe972e..5d9bf32a9f 100644
--- a/tests/ut/cpp/dataset/rescale_op_test.cc
+++ b/tests/ut/cpp/dataset/rescale_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/rescale_op.h"
+#include "minddata/dataset/kernels/image/rescale_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/resize_bilinear_op_test.cc b/tests/ut/cpp/dataset/resize_bilinear_op_test.cc
index 8642484149..910c8af2a2 100644
--- a/tests/ut/cpp/dataset/resize_bilinear_op_test.cc
+++ b/tests/ut/cpp/dataset/resize_bilinear_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/resize_bilinear_op.h"
+#include "minddata/dataset/kernels/image/resize_bilinear_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/resize_op_test.cc b/tests/ut/cpp/dataset/resize_op_test.cc
index e23320a65a..807668dde4 100644
--- a/tests/ut/cpp/dataset/resize_op_test.cc
+++ b/tests/ut/cpp/dataset/resize_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/resize_op.h"
+#include "minddata/dataset/kernels/image/resize_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc b/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc
index b81e4f9649..f9eaf85a55 100644
--- a/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc
+++ b/tests/ut/cpp/dataset/resize_with_bbox_op_test.cc
@@ -15,7 +15,7 @@
 */
 #include "common/bboxop_common.h"
-#include "dataset/kernels/image/resize_with_bbox_op.h"
+#include "minddata/dataset/kernels/image/resize_with_bbox_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/schema_test.cc b/tests/ut/cpp/dataset/schema_test.cc
index 2da61bc047..95b9c75d9e 100644
--- a/tests/ut/cpp/dataset/schema_test.cc
+++ b/tests/ut/cpp/dataset/schema_test.cc
@@ -19,11 +19,11 @@
 #include
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/data_schema.h"
-#include "dataset/util/path.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/data_schema.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/shuffle_op_test.cc b/tests/ut/cpp/dataset/shuffle_op_test.cc
index c9bcb24c4e..98b4878efb 100644
--- a/tests/ut/cpp/dataset/shuffle_op_test.cc
+++ b/tests/ut/cpp/dataset/shuffle_op_test.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "common/utils.h"
 #include "gtest/gtest.h"
diff --git a/tests/ut/cpp/dataset/skip_op_test.cc b/tests/ut/cpp/dataset/skip_op_test.cc
index 697745512d..387d2f69ff 100644
--- a/tests/ut/cpp/dataset/skip_op_test.cc
+++ b/tests/ut/cpp/dataset/skip_op_test.cc
@@ -13,8 +13,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/util/circular_pool.h"
-#include "dataset/core/client.h"
+#include "minddata/dataset/util/circular_pool.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
index dfe15a8f15..96e9652bbc 100644
--- a/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
+++ b/tests/ut/cpp/dataset/stand_alone_samplers_test.cc
@@ -15,13 +15,13 @@
 */
 #include "common/common.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/status_test.cc b/tests/ut/cpp/dataset/status_test.cc
index c64a86b8ba..195da1c119 100644
--- a/tests/ut/cpp/dataset/status_test.cc
+++ b/tests/ut/cpp/dataset/status_test.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/util/status.h"
+#include "minddata/dataset/util/status.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/subset_random_sampler_test.cc b/tests/ut/cpp/dataset/subset_random_sampler_test.cc
index 22200ccbac..c389686014 100644
--- a/tests/ut/cpp/dataset/subset_random_sampler_test.cc
+++ b/tests/ut/cpp/dataset/subset_random_sampler_test.cc
@@ -16,11 +16,11 @@
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
 #include
 #include
diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc
index b7be066d6c..a8bfe40b10 100644
--- a/tests/ut/cpp/dataset/take_op_test.cc
+++ b/tests/ut/cpp/dataset/take_op_test.cc
@@ -19,7 +19,7 @@
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/task_manager_test.cc b/tests/ut/cpp/dataset/task_manager_test.cc
index 3d34ec9ec5..7b8101fa56 100644
--- a/tests/ut/cpp/dataset/task_manager_test.cc
+++ b/tests/ut/cpp/dataset/task_manager_test.cc
@@ -16,7 +16,7 @@
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/util/task_manager.h"
+#include "minddata/dataset/util/task_manager.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
index 1849227877..70832c04b5 100644
--- a/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
+++ b/tests/ut/cpp/dataset/tensor_op_fusion_pass_test.cc
@@ -16,13 +16,13 @@
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/kernels/image/random_crop_and_resize_op.h"
-#include "dataset/kernels/image/decode_op.h"
-#include "dataset/engine/datasetops/source/image_folder_op.h"
-#include "dataset/engine/execution_tree.h"
+#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
+#include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
+#include "minddata/dataset/engine/execution_tree.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/tensor_string_test.cc b/tests/ut/cpp/dataset/tensor_string_test.cc
index 43b235304d..fe336a34c5 100644
--- a/tests/ut/cpp/dataset/tensor_string_test.cc
+++ b/tests/ut/cpp/dataset/tensor_string_test.cc
@@ -15,13 +15,13 @@
 */
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "securec.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/data_type.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc
index 72181a0caf..fce4652b47 100644
--- a/tests/ut/cpp/dataset/tensor_test.cc
+++ b/tests/ut/cpp/dataset/tensor_test.cc
@@ -15,13 +15,13 @@
 */
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "securec.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/data_type.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/tensorshape_test.cc b/tests/ut/cpp/dataset/tensorshape_test.cc
index 1af0bf9c82..65ab386db0 100644
--- a/tests/ut/cpp/dataset/tensorshape_test.cc
+++ b/tests/ut/cpp/dataset/tensorshape_test.cc
@@ -15,10 +15,10 @@
 */
 #include
 #include "./securec.h"
-#include "dataset/core/client.h"
-#include "dataset/core/data_type.h"
-#include "dataset/core/tensor_shape.h"
-#include "dataset/engine/data_schema.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/engine/data_schema.h"
 #include "common/common.h"
 #include "common/utils.h"
 #include "gtest/gtest.h"
diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc
index 7887eda955..bc2674a6a3 100644
--- a/tests/ut/cpp/dataset/text_file_op_test.cc
+++ b/tests/ut/cpp/dataset/text_file_op_test.cc
@@ -17,13 +17,13 @@
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "common/utils.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
-#include "dataset/engine/datasetops/source/text_file_op.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
+#include "minddata/dataset/util/status.h"
 namespace common = mindspore::common;
diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc
index 9b312296d8..30fde33ff9 100644
--- a/tests/ut/cpp/dataset/tfReader_op_test.cc
+++ b/tests/ut/cpp/dataset/tfReader_op_test.cc
@@ -17,8 +17,8 @@
 #include
 #include
-#include "dataset/core/client.h"
-#include "dataset/engine/data_schema.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/engine/data_schema.h"
 #include "common/common.h"
 #include "common/utils.h"
 #include "gtest/gtest.h"
diff --git a/tests/ut/cpp/dataset/to_float16_op_test.cc b/tests/ut/cpp/dataset/to_float16_op_test.cc
index 9c49c67b2c..5c886690c9 100644
--- a/tests/ut/cpp/dataset/to_float16_op_test.cc
+++ b/tests/ut/cpp/dataset/to_float16_op_test.cc
@@ -15,9 +15,9 @@
 */
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/image/random_rotation_op.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/kernels/data/to_float16_op.h"
+#include "minddata/dataset/kernels/image/random_rotation_op.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/kernels/data/to_float16_op.h"
 #include "utils/log_adapter.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/tokenizer_op_test.cc b/tests/ut/cpp/dataset/tokenizer_op_test.cc
index afac92aa4b..cc2d7473ff 100644
--- a/tests/ut/cpp/dataset/tokenizer_op_test.cc
+++ b/tests/ut/cpp/dataset/tokenizer_op_test.cc
@@ -18,14 +18,14 @@
 #include
 #include "common/common.h"
-#include "dataset/text/kernels/basic_tokenizer_op.h"
-#include "dataset/text/kernels/case_fold_op.h"
-#include "dataset/text/kernels/normalize_utf8_op.h"
-#include "dataset/text/kernels/regex_replace_op.h"
-#include "dataset/text/kernels/regex_tokenizer_op.h"
-#include "dataset/text/kernels/unicode_char_tokenizer_op.h"
-#include "dataset/text/kernels/unicode_script_tokenizer_op.h"
-#include "dataset/text/kernels/whitespace_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/basic_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/case_fold_op.h"
+#include "minddata/dataset/text/kernels/normalize_utf8_op.h"
+#include "minddata/dataset/text/kernels/regex_replace_op.h"
+#include "minddata/dataset/text/kernels/regex_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/unicode_char_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/unicode_script_tokenizer_op.h"
+#include "minddata/dataset/text/kernels/whitespace_tokenizer_op.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/treap_test.cc b/tests/ut/cpp/dataset/treap_test.cc
index b454ab108e..b9c534719c 100644
--- a/tests/ut/cpp/dataset/treap_test.cc
+++ b/tests/ut/cpp/dataset/treap_test.cc
@@ -13,7 +13,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#include "dataset/util/treap.h"
+#include "minddata/dataset/util/treap.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/dataset/trucate_pair_test.cc b/tests/ut/cpp/dataset/trucate_pair_test.cc
index 95e2aaa11b..af7e61c16a 100644
--- a/tests/ut/cpp/dataset/trucate_pair_test.cc
+++ b/tests/ut/cpp/dataset/trucate_pair_test.cc
@@ -15,12 +15,12 @@
 */
 #include
 #include
-#include "dataset/core/client.h"
+#include "minddata/dataset/core/client.h"
 #include "common/common.h"
 #include "gtest/gtest.h"
 #include "securec.h"
-#include "dataset/core/tensor.h"
-#include "mindspore/ccsrc/dataset/text/kernels/truncate_sequence_pair_op.h"
+#include "minddata/dataset/core/tensor.h"
+#include "mindspore/ccsrc/minddata/dataset/text/kernels/truncate_sequence_pair_op.h"
 using namespace mindspore::dataset;
diff --git a/tests/ut/cpp/dataset/type_cast_op_test.cc b/tests/ut/cpp/dataset/type_cast_op_test.cc
index 543eb71637..a94a7fedba 100644
--- a/tests/ut/cpp/dataset/type_cast_op_test.cc
+++ b/tests/ut/cpp/dataset/type_cast_op_test.cc
@@ -17,12 +17,12 @@
 #include
 #include "common/common.h"
 #include "common/cvop_common.h"
-#include "dataset/kernels/data/type_cast_op.h"
-#include "dataset/core/client.h"
-#include "dataset/core/cv_tensor.h"
-#include "dataset/core/data_type.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/pybind_support.h"
+#include "minddata/dataset/kernels/data/type_cast_op.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/cv_tensor.h"
+#include "minddata/dataset/core/data_type.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/pybind_support.h"
 #include "gtest/gtest.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/voc_op_test.cc b/tests/ut/cpp/dataset/voc_op_test.cc
index 05dc28b487..4bb212ffc7 100644
--- a/tests/ut/cpp/dataset/voc_op_test.cc
+++ b/tests/ut/cpp/dataset/voc_op_test.cc
@@ -20,18 +20,18 @@
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/core/client.h"
-#include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/source/voc_op.h"
-#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
-#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
-#include "dataset/util/path.h"
-#include "dataset/util/status.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/engine/datasetops/source/voc_op.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/distributed_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
+#include "minddata/dataset/util/path.h"
+#include "minddata/dataset/util/status.h"
 #include "gtest/gtest.h"
 #include "utils/log_adapter.h"
 #include "securec.h"
diff --git a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
index d146ed10ac..bb3079aec8 100644
--- a/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
+++ b/tests/ut/cpp/dataset/weighted_random_sampler_test.cc
@@ -16,11 +16,11 @@
 #include "common/common.h"
 #include "gtest/gtest.h"
-#include "dataset/core/constants.h"
-#include "dataset/core/tensor.h"
-#include "dataset/engine/data_buffer.h"
-#include "dataset/engine/datasetops/source/sampler/sampler.h"
-#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
+#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
 #include "utils/log_adapter.h"
 #include
diff --git a/tests/ut/cpp/dataset/zip_op_test.cc b/tests/ut/cpp/dataset/zip_op_test.cc
index b387341398..3ff6d1697e 100644
--- a/tests/ut/cpp/dataset/zip_op_test.cc
+++ b/tests/ut/cpp/dataset/zip_op_test.cc
@@ -21,17 +21,17 @@
 #include
 #include
 #include
-#include "dataset/core/client.h"
-#include "dataset/core/constants.h"
-#include "dataset/engine/datasetops/map_op.h"
-#include "dataset/engine/datasetops/zip_op.h"
-#include "dataset/core/tensor.h"
-#include "dataset/core/config_manager.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/datasetops/map_op.h"
+#include "minddata/dataset/engine/datasetops/zip_op.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/config_manager.h"
 #include "common/common.h"
 #include "common/utils.h"
-#include "dataset/engine/data_buffer.h"
+#include "minddata/dataset/engine/data_buffer.h"
 #include "gtest/gtest.h"
-#include "dataset/core/global_context.h"
+#include "minddata/dataset/core/global_context.h"
 #include "utils/log_adapter.h"
 namespace common = mindspore::common;
diff --git a/tests/ut/cpp/device/ascend_kernel_runtime_test.cc b/tests/ut/cpp/device/ascend_kernel_runtime_test.cc
index effa0b212d..2aa9512808 100644
--- a/tests/ut/cpp/device/ascend_kernel_runtime_test.cc
+++ b/tests/ut/cpp/device/ascend_kernel_runtime_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
-#include "device/kernel_runtime.h"
+#include "runtime/device/kernel_runtime.h"
 #include "./common.h"
 namespace mindspore {
diff --git a/tests/ut/cpp/device/ascend_profiling_test.cc b/tests/ut/cpp/device/ascend_profiling_test.cc
index 2829a5fd4a..f862d84c4a 100644
--- a/tests/ut/cpp/device/ascend_profiling_test.cc
+++ b/tests/ut/cpp/device/ascend_profiling_test.cc
@@ -18,12 +18,12 @@
 #include "./prof_reporter.h"
 #include "common/common_test.h"
-#include "device/ascend/profiling/profiling_manager.h"
+#include "runtime/device/ascend/profiling/profiling_manager.h"
 #include "./common.h"
 #define private public
-#include "device/ascend/profiling/plugin_impl.h"
+#include "runtime/device/ascend/profiling/plugin_impl.h"
 #undef private
-#include "device/ascend/profiling/profiling_engine_impl.h"
+#include "runtime/device/ascend/profiling/profiling_engine_impl.h"
 namespace mindspore {
 namespace device {
diff --git a/tests/ut/cpp/ir/anf_test.cc b/tests/ut/cpp/ir/anf_test.cc
index c649518e21..9b217a2321 100644
--- a/tests/ut/cpp/ir/anf_test.cc
+++ b/tests/ut/cpp/ir/anf_test.cc
@@ -19,7 +19,7 @@
 #include "common/common_test.h"
 #include "ir/anf.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "./common.h"
 namespace mindspore {
diff --git a/tests/ut/cpp/ir/clone_test.cc b/tests/ut/cpp/ir/clone_test.cc
index bb8cae7fbb..20da3fb8b5 100644
--- a/tests/ut/cpp/ir/clone_test.cc
+++ b/tests/ut/cpp/ir/clone_test.cc
@@ -21,7 +21,7 @@
 #include "ir/manager.h"
 #include "utils/log_adapter.h"
 #include "ir/func_graph_cloner.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "utils/graph_utils.h"
 #include "debug/draw.h"
 #include "./common.h"
diff --git a/tests/ut/cpp/ir/manager_test.cc b/tests/ut/cpp/ir/manager_test.cc
index 04b584ec10..3e6d1a312c 100644
--- a/tests/ut/cpp/ir/manager_test.cc
+++ b/tests/ut/cpp/ir/manager_test.cc
@@ -18,8 +18,8 @@
 #include "ir/dtype.h"
 #include "ir/manager.h"
 #include "ir/func_graph_cloner.h"
-#include "pipeline/parse/parse.h"
-#include "operator/ops.h"
+#include "pipeline/jit/parse/parse.h"
+#include "frontend/operator/ops.h"
 #include "utils/log_adapter.h"
 #include "debug/draw.h"
 #include "debug/label.h"
diff --git a/tests/ut/cpp/kernel/common_utils_test.cc b/tests/ut/cpp/kernel/common_utils_test.cc
index 4bc05b5c05..83f7c59e52 100644
--- a/tests/ut/cpp/kernel/common_utils_test.cc
+++ b/tests/ut/cpp/kernel/common_utils_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "common/common_test.h"
-#include "kernel/common_utils.h"
+#include "backend/kernel_compiler/common_utils.h"
 namespace mindspore {
 namespace kernel {
diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc
index dfd6147389..e5cba86230 100644
--- a/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc
+++ b/tests/ut/cpp/kernel/cpu/sparse_apply_adam_cpu_kernel_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
 #define private public
 #define protected public
-#include "kernel/cpu/sparse_apply_adam_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h"
 #undef private
 #undef protected
diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc
index a7df66cf9a..230c8cbf9e 100644
--- a/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc
+++ b/tests/ut/cpp/kernel/cpu/sparse_apply_ftrl_cpu_kernel_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
 #define private public
 #define protected public
-#include "kernel/cpu/sparse_apply_ftrl_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/sparse_apply_ftrl_cpu_kernel.h"
 #undef private
 #undef protected
diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc
index 63e8706d1b..a829ead90e 100644
--- a/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc
+++ b/tests/ut/cpp/kernel/cpu/sparse_apply_lazy_adam_cpu_kernel_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
 #define private public
 #define protected public
-#include "kernel/cpu/sparse_apply_lazy_adam_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/sparse_apply_lazy_adam_cpu_kernel.h"
 #undef private
 #undef protected
diff --git a/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc b/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc
index 0d679d7e5c..64bd5d3ef3 100644
--- a/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc
+++ b/tests/ut/cpp/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
 #define private public
 #define protected public
-#include "kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h"
 #undef private
 #undef protected
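The four sparse_apply_*_cpu_kernel tests above share one subtlety: they defeat C++ access control with #define private public / #define protected public so the test can inspect kernel internals, and the kernel header must be included inside that define/undef window. The renamed backend/kernel_compiler/... include therefore has to land exactly where the old kernel/... include sat. Schematically (the pattern and header name are copied from the hunks above):

// White-box include pattern used by the CPU kernel tests above; copied
// from the hunks, with the new backend/kernel_compiler path. The include
// must stay between the defines and the undefs.
#include "common/common_test.h"
#define private public
#define protected public
#include "backend/kernel_compiler/cpu/sparse_apply_adam_cpu_kernel.h"
#undef private
#undef protected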
#include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_writer.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_header.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_header.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc index 140fff4166..8e264aafa0 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_index_generator_test.cc @@ -29,10 +29,10 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_error.h" -#include "mindrecord/include/shard_index_generator.h" -#include "mindrecord/include/shard_index.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_error.h" +#include "minddata/mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_index.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 7fe60c3bfa..4501ea0800 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -24,11 +24,11 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_category.h" -#include "mindrecord/include/shard_pk_sample.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" -#include "mindrecord/include/shard_shuffle.h" +#include "minddata/mindrecord/include/shard_category.h" +#include "minddata/mindrecord/include/shard_pk_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_shuffle.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc index dabd3d819f..a7e444c80f 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc @@ -21,7 +21,7 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_page.h" +#include "minddata/mindrecord/include/shard_page.h" #include "ut_common.h" using json = nlohmann::json; diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index c532fe28b8..8b5eb2cf69 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -24,8 +24,8 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_sample.h" +#include "minddata/mindrecord/include/shard_reader.h" 
+#include "minddata/mindrecord/include/shard_sample.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc index 8d9654a5ef..6863a25791 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc @@ -29,9 +29,9 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_page.h" -#include "mindrecord/include/shard_schema.h" -#include "mindrecord/include/shard_statistics.h" +#include "minddata/mindrecord/include/shard_page.h" +#include "minddata/mindrecord/include/shard_schema.h" +#include "minddata/mindrecord/include/shard_statistics.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index 3fa6812352..6b99e44d89 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -30,7 +30,7 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_segment.h" +#include "minddata/mindrecord/include/shard_segment.h" #include "ut_common.h" using mindspore::LogStream; diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 159efbf2f8..046b4f93d5 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -24,9 +24,9 @@ #include "common/utils.h" #include "gtest/gtest.h" #include "utils/log_adapter.h" -#include "mindrecord/include/shard_reader.h" -#include "mindrecord/include/shard_writer.h" -#include "mindrecord/include/shard_index_generator.h" +#include "minddata/mindrecord/include/shard_reader.h" +#include "minddata/mindrecord/include/shard_writer.h" +#include "minddata/mindrecord/include/shard_index_generator.h" #include "securec.h" #include "ut_common.h" diff --git a/tests/ut/cpp/operator/cc_implementations_test.cc b/tests/ut/cpp/operator/cc_implementations_test.cc index bac885db88..4bc5aea964 100644 --- a/tests/ut/cpp/operator/cc_implementations_test.cc +++ b/tests/ut/cpp/operator/cc_implementations_test.cc @@ -18,7 +18,7 @@ #include #include "common/common_test.h" -#include "operator/cc_implementations.h" +#include "frontend/operator/cc_implementations.h" namespace mindspore { namespace prim { diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index ce852175a6..a2108998bc 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -18,10 +18,10 @@ #include "common/common_test.h" #include "ir/anf.h" #include "ir/value.h" -#include "operator/composite/composite.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/prim.h" -#include "pipeline/static_analysis/abstract_function.h" +#include "frontend/operator/composite/composite.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/abstract_function.h" #include "debug/trace.h" namespace mindspore { diff --git a/tests/ut/cpp/operator/grad_implementations_test.cc b/tests/ut/cpp/operator/grad_implementations_test.cc index e9035e63b6..f55553ab72 100644 --- a/tests/ut/cpp/operator/grad_implementations_test.cc +++ b/tests/ut/cpp/operator/grad_implementations_test.cc @@ -20,7 +20,7 @@ #include "ir/value.h" #include "ir/manager.h" #include 
"common/common_test.h" -#include "optimizer/ad/dfunctor.h" +#include "frontend/optimizer/ad/dfunctor.h" #include "debug/draw.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 87d32f3e76..789b1cab25 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -20,7 +20,7 @@ #include "common/common_test.h" #include "ir/value.h" #include "ir/primitive_py.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "./common.h" namespace mindspore { diff --git a/tests/ut/cpp/operator/prim2func_test.cc b/tests/ut/cpp/operator/prim2func_test.cc index 8f7c73a064..3952128b52 100644 --- a/tests/ut/cpp/operator/prim2func_test.cc +++ b/tests/ut/cpp/operator/prim2func_test.cc @@ -21,7 +21,7 @@ #include "ir/anf.h" #include "ir/dtype.h" -#include "operator/prim_to_function.h" +#include "frontend/operator/prim_to_function.h" namespace mindspore { namespace prim { diff --git a/tests/ut/cpp/optimizer/ad/ad_test.cc b/tests/ut/cpp/optimizer/ad/ad_test.cc index 34612b5474..3f861d3604 100644 --- a/tests/ut/cpp/optimizer/ad/ad_test.cc +++ b/tests/ut/cpp/optimizer/ad/ad_test.cc @@ -16,7 +16,7 @@ #include #include -#include "optimizer/ad/grad.h" +#include "frontend/optimizer/ad/grad.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" @@ -24,10 +24,10 @@ #include "ir/func_graph_cloner.h" #include "utils/log_adapter.h" #include "utils/graph_utils.h" -#include "pipeline/resource.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" namespace mindspore { namespace ad { diff --git a/tests/ut/cpp/optimizer/cconv_test.cc b/tests/ut/cpp/optimizer/cconv_test.cc index 8bd6957e85..c004409058 100644 --- a/tests/ut/cpp/optimizer/cconv_test.cc +++ b/tests/ut/cpp/optimizer/cconv_test.cc @@ -20,7 +20,7 @@ #include "ir/func_graph_cloner.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/optimizer/clean_test.cc b/tests/ut/cpp/optimizer/clean_test.cc index c4f393c233..82bec1b5a8 100644 --- a/tests/ut/cpp/optimizer/clean_test.cc +++ b/tests/ut/cpp/optimizer/clean_test.cc @@ -19,9 +19,9 @@ #include "common/py_func_graph_fetcher.h" #include "utils/log_adapter.h" -#include "pipeline/parse/parse.h" +#include "pipeline/jit/parse/parse.h" #include "debug/draw.h" -#include "optimizer/clean.h" +#include "frontend/optimizer/clean.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index bc8561f171..751b301283 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -25,11 +25,11 @@ #include "ir/manager.h" #include "ir/value.h" #include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "pipeline/resource.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/irpass.h" +#include "pipeline/jit/resource.h" #include "debug/draw.h" -#include "pipeline/parse/data_converter.h" +#include "pipeline/jit/parse/data_converter.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 2428d0dddb..c329adc4a5 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ 
b/tests/ut/cpp/optimizer/opt_test.cc @@ -22,13 +22,13 @@ #include "ir/anf.h" #include "ir/visitor.h" #include "ir/func_graph_cloner.h" -#include "optimizer/opt.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/arithmetic_simplify.h" +#include "frontend/optimizer/opt.h" +#include "frontend/optimizer/irpass.h" +#include "frontend/optimizer/irpass/arithmetic_simplify.h" #include "debug/draw.h" -#include "operator/ops.h" -#include "optimizer/cse.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/cse.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/optimizer_test.cc b/tests/ut/cpp/optimizer/optimizer_test.cc index ca7c589d47..c5c99531e4 100644 --- a/tests/ut/cpp/optimizer/optimizer_test.cc +++ b/tests/ut/cpp/optimizer/optimizer_test.cc @@ -20,10 +20,10 @@ #include "common/py_func_graph_fetcher.h" #include "ir/anf.h" -#include "operator/ops.h" -#include "optimizer/cse.h" -#include "optimizer/optimizer.h" -#include "optimizer/irpass.h" +#include "frontend/operator/ops.h" +#include "frontend/optimizer/cse.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/irpass.h" #include "debug/draw.h" namespace mindspore { diff --git a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc index 0462993672..a500afc859 100644 --- a/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/dp_algo_test.cc @@ -15,12 +15,12 @@ */ #include "common/common_test.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/ops_info/tmp_identity_info.h" -#include "parallel/auto_parallel/dp_algo_costmodel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/auto_parallel/dp_algo_costmodel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc index 291539c27d..190a189a2d 100644 --- a/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/edge_costmodel_test.cc @@ -16,9 +16,9 @@ #include "common/common_test.h" #include "ir/dtype/number.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/ops_info/matmul_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc index 78d05c7235..7d63f03179 100644 --- a/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/graph_costmodel_test.cc @@ -15,9 +15,9 @@ */ #include "common/common_test.h" -#include "parallel/device_manager.h" -#include "parallel/auto_parallel/graph_costmodel.h" -#include "parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" +#include 
"frontend/parallel/ops_info/matmul_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc index 919c5b43ec..b9b6bb67d9 100644 --- a/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/operator_costmodel_test.cc @@ -15,10 +15,10 @@ */ #include -#include "parallel/tensor_layout/tensor_layout.h" -#include "parallel/tensor_layout/tensor_info.h" -#include "parallel/auto_parallel/operator_costmodel.h" -#include "parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_layout.h" +#include "frontend/parallel/tensor_layout/tensor_info.h" +#include "frontend/parallel/auto_parallel/operator_costmodel.h" +#include "frontend/parallel/device_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc index 1eb65b468f..7942fa2a10 100644 --- a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc +++ b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc @@ -15,9 +15,9 @@ */ #include "common/common_test.h" -#include "parallel/auto_parallel/rec_core/rec_tensor.h" -#include "parallel/auto_parallel/rec_core/rec_graph.h" -#include "parallel/auto_parallel/rec_core/rec_partition.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" +#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h" #include #include "ir/value.h" diff --git a/tests/ut/cpp/parallel/device_manager_test.cc b/tests/ut/cpp/parallel/device_manager_test.cc index 056896f514..0c048d647b 100644 --- a/tests/ut/cpp/parallel/device_manager_test.cc +++ b/tests/ut/cpp/parallel/device_manager_test.cc @@ -15,9 +15,9 @@ */ #include #include "common/common_test.h" -#include "parallel/device.h" -#include "parallel/device_manager.h" -#include "parallel/group_manager.h" +#include "frontend/parallel/device.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/group_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/device_matrix_test.cc b/tests/ut/cpp/parallel/device_matrix_test.cc index 877a211df8..57a438e76e 100644 --- a/tests/ut/cpp/parallel/device_matrix_test.cc +++ b/tests/ut/cpp/parallel/device_matrix_test.cc @@ -16,7 +16,7 @@ #include #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/device_matrix.h" +#include "frontend/parallel/device_matrix.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/group_manager_test.cc b/tests/ut/cpp/parallel/group_manager_test.cc index e3d2b3a364..fa4abfcb7e 100644 --- a/tests/ut/cpp/parallel/group_manager_test.cc +++ b/tests/ut/cpp/parallel/group_manager_test.cc @@ -14,10 +14,10 @@ * limitations under the License. 
*/ #include -#include "parallel/device_manager.h" +#include "frontend/parallel/device_manager.h" #include "common/common_test.h" -#include "parallel/device.h" -#include "parallel/group_manager.h" +#include "frontend/parallel/device.h" +#include "frontend/parallel/group_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc index a9fe9b4c48..5f09de9e48 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/activation_test.cc b/tests/ut/cpp/parallel/ops_info/activation_test.cc index 9af7203799..9d129b7a18 100644 --- a/tests/ut/cpp/parallel/ops_info/activation_test.cc +++ b/tests/ut/cpp/parallel/ops_info/activation_test.cc @@ -18,9 +18,9 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc index e54d1f2423..e49ed4e79d 100644 --- a/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/gelu_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc index 947ad60cca..125723868a 100644 --- a/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc +++ b/tests/ut/cpp/parallel/ops_info/generate_strategy_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc index 503edf2eda..029e0f2dc6 100644 --- a/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/get_next_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" 
-#include "parallel/ops_info/get_next_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/get_next_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc index b59481e1f6..7037a85699 100644 --- a/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/l2_normalize_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/l2_normalize_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/l2_normalize_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc index cf5a4239a2..8de5c07226 100644 --- a/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/log_softmax_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc index f710f51265..2d5676f211 100644 --- a/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/matmul_info_test.cc @@ -18,11 +18,11 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/matmul_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" -#include "parallel/auto_parallel/graph_costmodel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/matmul_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/auto_parallel/graph_costmodel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc index 07d150a294..074e4582f0 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/onehot_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/onehot_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc 
b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc index c89bf97fb3..769d5bec45 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/onehot_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/onehot_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc index 7b37a90fd8..f582640db8 100644 --- a/tests/ut/cpp/parallel/ops_info/pow_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/pow_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/prelu_test.cc b/tests/ut/cpp/parallel/ops_info/prelu_test.cc index d6db1b8460..1d4cf5eff0 100644 --- a/tests/ut/cpp/parallel/ops_info/prelu_test.cc +++ b/tests/ut/cpp/parallel/ops_info/prelu_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/prelu_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/prelu_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc index a1fe46ca33..64ba6af70b 100644 --- a/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reduce_method_test.cc @@ -18,11 +18,11 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/reduce_method_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/reduce_method_info.h" #include "common/py_func_graph_fetcher.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc index fb60c6d250..8cc8390e9a 100644 --- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc +++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/reshape_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/reshape_info.h" +#include "frontend/parallel/device_manager.h" +#include 
"frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc index 03634b9a6f..d370c168c9 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_entropy_loss_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/loss_info.h" -#include "parallel/device_manager.h" -#include "parallel/tensor_layout/tensor_redistribution.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/loss_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/tensor_layout/tensor_redistribution.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc index bba6e89626..9c4205672b 100644 --- a/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/softmax_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc index a892c5c84a..2be6c5bf7f 100644 --- a/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tanh_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/activation_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/activation_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc index 42d292c605..b523652fcb 100644 --- a/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tensor_add_info_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/arithmetic_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/arithmetic_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc index eabac51e17..461a27d4ed 100644 --- a/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc +++ b/tests/ut/cpp/parallel/ops_info/tmpidentity_test.cc @@ -15,10 +15,10 @@ */ #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/device_manager.h" -#include "parallel/ops_info/operator_info.h" -#include 
"parallel/ops_info/tmp_identity_info.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/parallel/ops_info/tmp_identity_info.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/ops_info/transpose_test.cc b/tests/ut/cpp/parallel/ops_info/transpose_test.cc index 991ec47820..fe5cbb01b3 100644 --- a/tests/ut/cpp/parallel/ops_info/transpose_test.cc +++ b/tests/ut/cpp/parallel/ops_info/transpose_test.cc @@ -18,10 +18,10 @@ #include #include #include "common/common_test.h" -#include "parallel/strategy.h" -#include "parallel/ops_info/transpose_info.h" -#include "parallel/device_manager.h" -#include "parallel/step_parallel.h" +#include "frontend/parallel/strategy.h" +#include "frontend/parallel/ops_info/transpose_info.h" +#include "frontend/parallel/device_manager.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/step_auto_parallel_test.cc b/tests/ut/cpp/parallel/step_auto_parallel_test.cc index a1474ca244..6cf7ec66c6 100644 --- a/tests/ut/cpp/parallel/step_auto_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_auto_parallel_test.cc @@ -14,12 +14,12 @@ * limitations under the License. */ #include "common/common_test.h" -#include "parallel/step_parallel.h" -#include "parallel/step_auto_parallel.h" -#include "parallel/auto_parallel/edge_costmodel.h" -#include "parallel/ops_info/operator_info.h" -#include "operator/ops.h" -#include "pipeline/static_analysis/static_analysis.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/step_auto_parallel.h" +#include "frontend/parallel/auto_parallel/edge_costmodel.h" +#include "frontend/parallel/ops_info/operator_info.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/static_analysis/static_analysis.h" namespace mindspore { namespace parallel { diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index d8f8681a34..5657db8790 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -14,12 +14,12 @@ * limitations under the License. 
  */
 #include "common/common_test.h"
-#include "parallel/step_parallel.h"
-#include "parallel/graph_util/generate_graph.h"
+#include "frontend/parallel/step_parallel.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "common/py_func_graph_fetcher.h"
 #include "debug/draw.h"
-#include "operator/ops.h"
-#include "pipeline/static_analysis/static_analysis.h"
+#include "frontend/operator/ops.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/strategy_test.cc b/tests/ut/cpp/parallel/strategy_test.cc
index 9a2f92f018..c13b71944e 100644
--- a/tests/ut/cpp/parallel/strategy_test.cc
+++ b/tests/ut/cpp/parallel/strategy_test.cc
@@ -17,7 +17,7 @@
 #include
 #include
 #include "common/common_test.h"
-#include "parallel/strategy.h"
+#include "frontend/parallel/strategy.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc
index 2ba8cc9dfc..b80f199035 100644
--- a/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/construct_operator_test.cc
@@ -17,10 +17,10 @@
 #include
 #include "common/common_test.h"
 #include "ir/value.h"
-#include "parallel/strategy.h"
-#include "parallel/ops_info/matmul_info.h"
-#include "parallel/device_manager.h"
-#include "parallel/tensor_layout/construct_operator.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/ops_info/matmul_info.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/tensor_layout/construct_operator.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc
index 5291e2f48d..4ddc130a45 100644
--- a/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_layout_transfer_test.cc
@@ -17,8 +17,8 @@
 #include
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/tensor_layout.h"
-#include "parallel/tensor_layout/redistribution_layout_transfer.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/redistribution_layout_transfer.h"
 #include "util_layout_gen_test.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc
index 1b1dd4af04..f6caad2f9d 100644
--- a/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/redistribution_operator_infer_test.cc
@@ -16,8 +16,8 @@
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/redistribution_operator_infer.h"
-#include "parallel/device_manager.h"
+#include "frontend/parallel/tensor_layout/redistribution_operator_infer.h"
+#include "frontend/parallel/device_manager.h"
 #include "util_layout_gen_test.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc
index 9d6152721e..11f471ea33 100644
--- a/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/reshape_layout_transfer_test.cc
@@ -17,8 +17,8 @@
 #include
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/tensor_layout.h"
-#include "parallel/tensor_layout/reshape_layout_transfer.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/reshape_layout_transfer.h"
 #include "util_layout_gen_test.h"
 #include "utils/log_adapter.h"
diff --git a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc
index b5e2ea3e5b..824ab876cd 100644
--- a/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/shape_util_test.cc
@@ -16,7 +16,7 @@
 #include
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/shape_util.h"
+#include "frontend/parallel/tensor_layout/shape_util.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc
index bae05d650a..15fb16f088 100644
--- a/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/tensor_layout_test.cc
@@ -17,7 +17,7 @@
 #include
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc
index 572763faa3..40a4017c4b 100644
--- a/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/tensor_redistribution_test.cc
@@ -17,7 +17,7 @@
 #include
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
index 6f5c1e49ed..330b571ae7 100644
--- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
+++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.cc
@@ -21,7 +21,7 @@
 #include
 #include
 #include
-#include "parallel/tensor_layout/shape_util.h"
+#include "frontend/parallel/tensor_layout/shape_util.h"
 #include "common/common_test.h"

 using std::pow;
diff --git a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h
index a359cadbea..c16a1fc6d4 100644
--- a/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h
+++ b/tests/ut/cpp/parallel/tensor_layout/util_layout_gen_test.h
@@ -20,7 +20,7 @@
 #include
 #include
-#include "parallel/tensor_layout/tensor_layout.h"
+#include "frontend/parallel/tensor_layout/tensor_layout.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/parallel/virtual_dataset_test.cc b/tests/ut/cpp/parallel/virtual_dataset_test.cc
index 1d3ff081c7..4cafdebc17 100644
--- a/tests/ut/cpp/parallel/virtual_dataset_test.cc
+++ b/tests/ut/cpp/parallel/virtual_dataset_test.cc
@@ -17,10 +17,10 @@
 #include
 #include
 #include "common/common_test.h"
-#include "parallel/strategy.h"
-#include "parallel/ops_info/virtual_dataset_info.h"
-#include "parallel/device_manager.h"
-#include "parallel/step_parallel.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/ops_info/virtual_dataset_info.h"
+#include "frontend/parallel/device_manager.h"
+#include "frontend/parallel/step_parallel.h"

 namespace mindspore {
 namespace parallel {
diff --git a/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc b/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc
index 3c97cfb203..2d21b591ea 100644
--- a/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc
+++ b/tests/ut/cpp/pipeline/parse/parser_abnormal_test.cc
@@ -19,7 +19,7 @@
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
 #include "utils/profile.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/parse/parser_class_test.cc b/tests/ut/cpp/pipeline/parse/parser_class_test.cc
index dcedc32b1b..8d9cc8ebc8 100644
--- a/tests/ut/cpp/pipeline/parse/parser_class_test.cc
+++ b/tests/ut/cpp/pipeline/parse/parser_class_test.cc
@@ -19,7 +19,7 @@
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc b/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc
index fd8438503f..1f54298a81 100644
--- a/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc
+++ b/tests/ut/cpp/pipeline/parse/parser_integrate_test.cc
@@ -18,7 +18,7 @@
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc b/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc
index adc09cca32..937ad1fe5e 100644
--- a/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc
+++ b/tests/ut/cpp/pipeline/parse/parser_primitive_test.cc
@@ -19,7 +19,7 @@
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/parse/parser_test.cc b/tests/ut/cpp/pipeline/parse/parser_test.cc
index 4d7731dfd1..f1d9087110 100644
--- a/tests/ut/cpp/pipeline/parse/parser_test.cc
+++ b/tests/ut/cpp/pipeline/parse/parser_test.cc
@@ -19,7 +19,7 @@
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/parse/resolve_test.cc b/tests/ut/cpp/pipeline/parse/resolve_test.cc
index 8ade92bb34..5a2d0ebd7f 100644
--- a/tests/ut/cpp/pipeline/parse/resolve_test.cc
+++ b/tests/ut/cpp/pipeline/parse/resolve_test.cc
@@ -19,7 +19,7 @@
 #include "common/py_func_graph_fetcher.h"
 #include "utils/log_adapter.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/resource_test.cc b/tests/ut/cpp/pipeline/resource_test.cc
index 09bd2060dc..b6be393652 100644
--- a/tests/ut/cpp/pipeline/resource_test.cc
+++ b/tests/ut/cpp/pipeline/resource_test.cc
@@ -18,9 +18,9 @@
 #include "common/common_test.h"
 #include "utils/log_adapter.h"
-#include "pipeline/resource.h"
+#include "pipeline/jit/resource.h"
 #include "ir/primitive.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"

 namespace mindspore {
 namespace pipeline {
diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
index d431dcc0ec..fb9d8b1f7e 100644
--- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc
@@ -18,8 +18,8 @@
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
 #include "abstract/utils.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
index eebe6c252b..664f353faa 100644
--- a/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/evaluator_test.cc
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#include "pipeline/static_analysis/evaluator.h"
-#include "pipeline/static_analysis/prim.h"
+#include "pipeline/jit/static_analysis/evaluator.h"
+#include "pipeline/jit/static_analysis/prim.h"
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pipeline/static_analysis/helper.cc b/tests/ut/cpp/pipeline/static_analysis/helper.cc
index db697e95e0..ebf8c233e2 100644
--- a/tests/ut/cpp/pipeline/static_analysis/helper.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/helper.cc
@@ -16,7 +16,7 @@
 #include "pipeline/static_analysis/helper.h"
-#include "pipeline/static_analysis/prim.h"
+#include "pipeline/jit/static_analysis/prim.h"

 namespace mindspore {
 namespace abstract {
diff --git a/tests/ut/cpp/pipeline/static_analysis/helper.h b/tests/ut/cpp/pipeline/static_analysis/helper.h
index 7ca902a1e9..44c647779e 100644
--- a/tests/ut/cpp/pipeline/static_analysis/helper.h
+++ b/tests/ut/cpp/pipeline/static_analysis/helper.h
@@ -17,7 +17,7 @@
 #ifndef TESTS_UT_PIPELINE_STATIC_ANALYSIS_HELPER_H_
 #define TESTS_UT_PIPELINE_STATIC_ANALYSIS_HELPER_H_
-#include "pipeline/static_analysis/evaluator.h"
+#include "pipeline/jit/static_analysis/evaluator.h"

 namespace mindspore {
 namespace abstract {
diff --git a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc
index 04a14a0f29..8ebea4d212 100644
--- a/tests/ut/cpp/pipeline/static_analysis/prim_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/prim_test.cc
@@ -21,9 +21,9 @@
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "ir/manager.h"
-#include "pipeline/static_analysis/prim.h"
+#include "pipeline/jit/static_analysis/prim.h"
 #include "pipeline/static_analysis/helper.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "debug/draw.h"
 #include "ir/tensor.h"
 #include "utils/symbolic.h"
diff --git a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc
index 23ea55f8f7..e32a86d9be 100644
--- a/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc
+++ b/tests/ut/cpp/pipeline/static_analysis/specialize_test.cc
@@ -20,8 +20,8 @@
 #include "common/py_func_graph_fetcher.h"
 #include "ir/manager.h"
-#include "pipeline/static_analysis/prim.h"
-#include "pipeline/static_analysis/program_specialize.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "pipeline/jit/static_analysis/program_specialize.h"
#include "pipeline/static_analysis/helper.h" #include "utils/log_adapter.h" #include "utils/graph_utils.h" diff --git a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc index 8a58969e12..78d3a7083a 100644 --- a/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/static_analysis_test.cc @@ -16,16 +16,16 @@ #include #include -#include "pipeline/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/prim.h" #include "pipeline/static_analysis/helper.h" #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "ir/manager.h" #include "ir/tensor.h" -#include "operator/ops.h" -#include "pipeline/parse/parse.h" -#include "pipeline/parse/data_converter.h" -#include "pipeline/resource.h" +#include "frontend/operator/ops.h" +#include "pipeline/jit/parse/parse.h" +#include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/resource.h" #include "debug/draw.h" #include "utils/log_adapter.h" diff --git a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc index 483c144930..58b810a3e1 100644 --- a/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/buffer_fusion/buffer_fusion_test.cc @@ -17,23 +17,23 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "kernel/kernel.h" -#include "device/kernel_info.h" -#include "pre_activate/common/optimizer.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" -#include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" -#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" +#include "backend/kernel_compiler/kernel.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.h" +#include "backend/optimizer/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_single_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_double_in_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" +#include 
"backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" +#include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc index e4ab2431b7..ba64c206af 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc @@ -15,14 +15,14 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "mindspore/ccsrc/backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc index 56bf0ae4e0..2be25212e8 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -15,16 +15,16 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/ops.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_getnext.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc index 22cf70ded3..103d0f21a4 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -15,16 +15,16 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include 
"kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" +#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" #undef private #undef protected namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc index 72ce73e20f..89d680f442 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/check_consistency_test.cc @@ -16,18 +16,18 @@ #include "common/common_test.h" #include "common/py_func_graph_fetcher.h" #include "common/backend_common_test.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "pipeline/action.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "pipeline/jit/action.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/format_type/check_consistency.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/format_type/check_consistency.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc index 317eace6c6..2b61a49048 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_cast_test.cc @@ -14,17 +14,17 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "device/kernel_info.h" -#include "pre_activate/ascend/format_type/insert_cast.h" -#include "kernel/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/ascend/format_type/insert_cast.h" +#include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc index 8c57238e0a..0a5cf3dd9e 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc @@ -14,18 +14,18 @@ * limitations under the License. 
*/ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "session/anf_runtime_algorithm.h" -#include "device/kernel_info.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/kernel_info.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc index c0017c2deb..69e7fa8b27 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/merge_cast_to_op_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/merge_cast_to_op.h" +#include "backend/optimizer/ascend/format_type/merge_cast_to_op.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc index 90174636b1..8ec2b22a79 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/addn_fission_test.cc @@ -18,7 +18,7 @@ #include "common/py_func_graph_fetcher.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/addn_fission.h" +#include "backend/optimizer/ascend/ir_fission/addn_fission.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc index 06895cb081..f793e0371b 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_bert_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fission/batch_norm_bert_fission.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc index ea4a5c0d5d..80f30c8938 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h" +#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_infer_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc index dc437221f8..f0a5a857b9 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_grad_split_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/bn_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_grad_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc index c5ebc28b48..9f4f31bf82 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/bn_split_test.cc @@ -15,20 +15,20 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "session/ascend_session.h" -#include "session/anf_runtime_algorithm.h" -#include "pipeline/resource.h" -#include "operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "pipeline/jit/resource.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/bn_split.h" +#include "backend/optimizer/ascend/ir_fission/bn_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc index c0a0cc455e..c726142e99 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/lars_v2_fission_test.cc @@ -16,7 +16,7 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fission/lars_v2_fission.h" +#include "backend/optimizer/ascend/ir_fission/lars_v2_fission.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc index 1df87960e3..4303485d85 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/layer_norm_grad_split_test.cc @@ -15,17 +15,17 @@ */ #include "common/backend_common_test.h" #include 
"common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "operator/ops.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "debug/anf_ir_dump.h" #include "utils/utils.h" -#include "kernel/kernel_build_info.h" -#include "pre_activate/common/optimizer.h" +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/optimizer/common/optimizer.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h" +#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc index b0aa455a0a..9f84f22678 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/single_batch_norm_fission_test.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pre_activate/ascend/ir_fission/single_batch_norm_fission.h" +#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h" #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc index ab70e83480..30de43be4e 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/split_fission_test.cc @@ -18,7 +18,7 @@ #include "common/py_func_graph_fetcher.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/split_fission.h" +#include "backend/optimizer/ascend/ir_fission/split_fission.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index b09268aa66..2ab614d4c2 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -16,13 +16,13 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "pre_activate/pass/convert_const_input_to_attr.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/pass/convert_const_input_to_attr.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" +#include "backend/session/anf_runtime_algorithm.h" #define private public #define protected public -#include "pre_activate/ascend/ir_fission/topk_split.h" +#include "backend/optimizer/ascend/ir_fission/topk_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc index f2b975a08e..220e45f10a 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc @@ -16,16 +16,16 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include 
"backend/kernel_compiler/oplib/oplib.h" #include "debug/anf_ir_dump.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/ir_fission/transdata_split.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/ir_fission/transdata_split.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc index c2ee7b6519..2759864037 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_fusion_test.cc @@ -15,7 +15,7 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc index 014e60f579..78c815bf50 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc index 8b44fa6dc4..5d42ff7069 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc @@ -19,7 +19,7 @@ #define private public #define protected public -#include "pre_activate/ascend/ir_fusion/add_input_to_output.h" +#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc index 466cba8e67..d9d0baf7be 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnorm_to_bninfer_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/batchnorm_to_bninfer.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc index d1fc2783ac..1b64e5fd00 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad_test.cc +++ 
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc
index 0c8bf67391..aa56d79239 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc
index 4160c3a8e4..ac01f9b1dd 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/clip_by_value_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/clip_by_value_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
index 2044857841..be6bd95b02 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc
index 05fa2c65df..068cc0d12e 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/confusion_softmax_grad_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/confusion_softmax_grad_rule.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc
index ffa5a42b4d..663ed309ee 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/derelu_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/derelu_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
index 597b7b18ff..f7cbfdc678 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/fused_batch_norm_fusion_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
index 6ea622d030..64c004ff27 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_rule_test.cc
@@ -17,7 +17,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "debug/anf_ir_dump.h"
-#include "pre_activate/ascend/ir_fusion/lamb_next_mv_rule.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
index 36f0321511..776ce625b7 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule_test.cc
@@ -16,7 +16,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_rule.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc
index fbb1f5e913..bf21649672 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule_test.cc
@@ -16,7 +16,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc
index f1ca92c811..6a7c866ab4 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_next_right_rule_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_next_right_rule.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc
index 7a2806162b..4de2de2700 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc
index 05262e72ab..5be6195da2 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2_test.cc
@@ -17,7 +17,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "debug/anf_ir_dump.h"
-#include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_v2.h"
+#include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_v2.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
index 44b9b3df69..7392d05b98 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
@@ -15,13 +15,13 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "device/kernel_info.h"
+#include "runtime/device/kernel_info.h"
 #include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
+#include "backend/session/anf_runtime_algorithm.h"
 #define private public
 #define protected public
-#include "pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.h"
 #undef private
 #undef protected
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc
index c8f97be290..f67eda9776 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/matmul_biasadd_fusion_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc
index 114fcf4233..50dfd66f54 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/momentum_lossscale_fusion_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-#include "pre_activate/ascend/ir_fusion/momentum_lossscale_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc
index 87bb21f89a..b293cdeecb 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_add_fusion_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/ascend/ir_fusion/mul_add_fusion.h"
+#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc
index ab9718d80a..8ac106f81c 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/mul_addn_fusion_test.cc
@@ -15,7 +15,7 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.h"
+#include "mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
 #include "debug/anf_ir_dump.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
index 59140e91a1..6792f4720a 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
@@ -17,8 +17,8 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/ascend/ir_fusion/reshape_transpose_fusion.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc
index 5f02f0e9c1..f6e8a1194c 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h"
"backend/optimizer/ascend/ir_fusion/softmax_grad_ext_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc index 2dd858a0fc..efe5433d75 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/square_sum_fusion_test.cc @@ -15,8 +15,8 @@ */ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/ascend/ir_fusion/square_sum_fusion.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" #include "debug/anf_ir_dump.h" namespace mindspore { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc index 3290acd42f..6ec407d2ea 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc @@ -17,8 +17,8 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "debug/anf_ir_dump.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/ascend/ir_fusion/transpose_reshape_fusion.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 98dc9e9efc..d156959c4c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -16,14 +16,14 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" -#include "device/kernel_info.h" -#include "session/anf_runtime_algorithm.h" -#include "kernel/oplib/oplib.h" +#include "runtime/device/kernel_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/oplib/oplib.h" #include "utils/context/ms_context.h" #define private public #define protected public -#include "pre_activate/ascend/format_type/insert_trans_op.h" -#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" +#include "backend/optimizer/ascend/format_type/insert_trans_op.h" +#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" #undef private #undef protected diff --git a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc index 7b0e2cc9db..12030433fc 100644 --- a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc +++ b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc @@ -20,8 +20,8 @@ #include #include "common/common_test.h" -#include "pre_activate/common/pattern_engine.h" -#include "pre_activate/common/visit.h" +#include "backend/optimizer/common/pattern_engine.h" +#include "backend/optimizer/common/visit.h" #include "utils/base_ref.h" #include "ir/anf.h" diff --git a/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc index 5b237fda58..8b6d3e061a 100644 --- a/tests/ut/cpp/pre_activate/mem_reuse/kernel_ref_test.cc +++ 
@@ -18,7 +18,7 @@
 #include
 #include
-#include "pre_activate/mem_reuse/kernel_refcount.h"
+#include "backend/optimizer/mem_reuse/kernel_refcount.h"
 #include "utils/utils.h"
 #include "common/common_test.h"
diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc
index e0966d2d12..2a6904658e 100644
--- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc
+++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_allocator_test.cc
@@ -17,9 +17,9 @@
 #include
 #include
 #include
-#include "operator/ops.h"
-#include "pre_activate/mem_reuse/mem_reuse.h"
-#include "pre_activate/mem_reuse/mem_reuse_allocator.h"
+#include "frontend/operator/ops.h"
+#include "backend/optimizer/mem_reuse/mem_reuse.h"
+#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc
index a36463d297..31ae923c0a 100644
--- a/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc
+++ b/tests/ut/cpp/pre_activate/mem_reuse/mem_reuse_test.cc
@@ -16,19 +16,19 @@
 #include
 #include
 #include
-#include "session/kernel_graph.h"
-#include "session/session_basic.h"
-#include "session/ascend_session.h"
-#include "pre_activate/mem_reuse/kernel_refcount.h"
-#include "pre_activate/mem_reuse/mem_reuse_allocator.h"
-#include "device/kernel_info.h"
-#include "kernel/tbe/tbe_kernel_mod.h"
-#include "operator/ops.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/session/session_basic.h"
+#include "backend/session/ascend_session.h"
+#include "backend/optimizer/mem_reuse/kernel_refcount.h"
+#include "backend/optimizer/mem_reuse/mem_reuse_allocator.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h"
+#include "frontend/operator/ops.h"
 #include "utils/log_adapter.h"
-#include "session/anf_runtime_algorithm.h"
+#include "backend/session/anf_runtime_algorithm.h"
 #include "common/utils.h"
-#include "pipeline/resource.h"
-#include "pre_activate/mem_reuse/mem_reuse.h"
+#include "pipeline/jit/resource.h"
+#include "backend/optimizer/mem_reuse/mem_reuse.h"
 #include "common/common_test.h"
 #include "common/py_func_graph_fetcher.h"
diff --git a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc
index 69a330614e..02e1865a82 100644
--- a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc
@@ -15,16 +15,16 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "ir/tensor.h"
 #include "ir/manager.h"
 #include "debug/anf_ir_dump.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/pass/communication_op_fusion.h"
-#include "pre_activate/common/optimizer.h"
-#include "device/kernel_info.h"
-#include "pre_activate/common/pass_manager.h"
-#include "kernel/kernel_build_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/pass/communication_op_fusion.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "runtime/device/kernel_info.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/kernel_compiler/kernel_build_info.h"
 #include "utils/utils.h"
 #include "utils/context/ms_context.h"
"utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc index 12c4d35db5..cfcc34970b 100644 --- a/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc +++ b/tests/ut/cpp/pre_activate/pass/common_subexpression_elimination_test.cc @@ -14,17 +14,17 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "ir/tensor.h" #include "ir/manager.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "device/kernel_info.h" -#include "pre_activate/pass/common_subexpression_elimination.h" -#include "kernel/kernel_build_info.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/pass/common_subexpression_elimination.h" +#include "backend/kernel_compiler/kernel_build_info.h" #include "utils/utils.h" #include "utils/context/ms_context.h" diff --git a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc index 8fc709433e..25e4b3c111 100644 --- a/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc +++ b/tests/ut/cpp/pre_activate/pass/const_to_attr_strided_slice_grad_test.cc @@ -14,13 +14,13 @@ * limitations under the License. */ #include "common/backend_common_test.h" -#include "operator/ops.h" +#include "frontend/operator/ops.h" #include "debug/anf_ir_dump.h" #include "common/py_func_graph_fetcher.h" -#include "session/anf_runtime_algorithm.h" -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pass_manager.h" -#include "pre_activate/pass/const_to_attr_strided_slice_grad.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/common/pass_manager.h" +#include "backend/optimizer/pass/const_to_attr_strided_slice_grad.h" #include "utils/utils.h" #include "common/utils.h" diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc index fcb3b19a24..ac3272317a 100644 --- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc +++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_attr_test.cc @@ -14,13 +14,13 @@ * limitations under the License. 
  */
 #include "common/backend_common_test.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "debug/anf_ir_dump.h"
 #include "common/py_func_graph_fetcher.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/convert_const_input_to_attr.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/pass/convert_const_input_to_attr.h"
 #include "utils/utils.h"
 #include "common/utils.h"
diff --git a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc
index 1749e54d94..5b303d15a5 100644
--- a/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/convert_const_input_to_tensor_input_test.cc
@@ -18,10 +18,10 @@
 #include "ir/tensor.h"
 #include "debug/anf_ir_dump.h"
 #include "common/py_func_graph_fetcher.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/convert_const_input_to_tensor_input.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/pass/convert_const_input_to_tensor_input.h"
 #include "utils/utils.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc b/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc
index aded376536..2c1dfc1c6c 100644
--- a/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/convert_tuple_input_to_dynamic_input_test.cc
@@ -18,10 +18,10 @@
 #include "ir/tensor.h"
 #include "debug/anf_ir_dump.h"
 #include "common/py_func_graph_fetcher.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/convert_tuple_input_to_dynamic_input.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/pass/convert_tuple_input_to_dynamic_input.h"
 #include "utils/utils.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc b/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc
index eeb01270e2..458c854218 100644
--- a/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/convert_tuple_output_to_maketuple_test.cc
@@ -18,10 +18,10 @@
 #include "ir/tensor.h"
 #include "debug/anf_ir_dump.h"
 #include "common/py_func_graph_fetcher.h"
-#include "session/anf_runtime_algorithm.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/convert_tuple_output_to_maketuple.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/pass_manager.h"
+#include "backend/optimizer/pass/convert_tuple_output_to_maketuple.h"
 #include "utils/utils.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
index 3e43155011..07bef7a042 100644
--- a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
@@ -15,26 +15,26 @@
  */
 #include "common/backend_common_test.h"
-#include "kernel/kernel.h"
-#include "operator/ops.h"
+#include "backend/kernel_compiler/kernel.h"
+#include "frontend/operator/ops.h"
 #include "ir/tensor.h"
 #include "ir/manager.h"
 #include "debug/anf_ir_dump.h"
 #include "common/py_func_graph_fetcher.h"
-// #include "device/optimizer/pass/insert_trans_op.h"
-#include "pre_activate/ascend/format_type/insert_cast.h"
-#include "pre_activate/pass/eliminate_redundant_op.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/common/pass_manager.h"
+// #include "runtime/device/optimizer/pass/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/insert_cast.h"
+#include "backend/optimizer/pass/eliminate_redundant_op.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/common/pass_manager.h"
 #include "utils/utils.h"
 #include "utils/context/ms_context.h"
-#include "session/anf_runtime_algorithm.h"
-#include "device/kernel_info.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "runtime/device/kernel_info.h"
 #include "utils/context/ms_context.h"
 #define private public
 #define protected public
-#include "pre_activate/ascend/format_type/insert_trans_op.h"
+#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
 #undef private
 #undef protected
diff --git a/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc b/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc
index b172e1b351..555dd95426 100644
--- a/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/getitem_tuple_test.cc
@@ -15,14 +15,14 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "session/ascend_session.h"
-#include "pipeline/resource.h"
-#include "operator/ops.h"
+#include "backend/session/ascend_session.h"
+#include "pipeline/jit/resource.h"
+#include "frontend/operator/ops.h"
 #include "ir/manager.h"
 #include "debug/anf_ir_dump.h"
 #include "utils/utils.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/pass/getitem_tuple.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/pass/getitem_tuple.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc
index 04461e6602..f9cfe273bc 100644
--- a/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/optimize_dependence_test.cc
@@ -15,8 +15,8 @@
  */
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pre_activate/common/optimizer.h"
-#include "pre_activate/pass/optimize_dependence.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/pass/optimize_dependence.h"

 namespace mindspore {
 namespace opt {
diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc
index a0d1516b58..c5f25ca484 100644
--- a/tests/ut/cpp/pynative/pynative_execute_test.cc
+++ b/tests/ut/cpp/pynative/pynative_execute_test.cc
@@ -16,10 +16,10 @@
 #include
 #include
 #include "common/common_test.h"
-#include "pipeline/parse/python_adapter.h"
-#include "pipeline/parse/data_converter.h" -#include "operator/ops.h" -#include "pynative/pynative_execute.h" +#include "pipeline/jit/parse/python_adapter.h" +#include "pipeline/jit/parse/data_converter.h" +#include "frontend/operator/ops.h" +#include "pipeline/pynative/pynative_execute.h" #include "utils/context/ms_context.h" #include "utils/utils.h" diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index 6769775b3f..e81870fd4f 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -16,11 +16,11 @@ #include "common/common_test.h" #include "ir/param_value.h" -#include "operator/ops.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "mindspore/ccsrc/device/kernel_info.h" -#include "mindspore/ccsrc/device/ascend/ascend_device_address.h" +#include "frontend/operator/ops.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "mindspore/ccsrc/runtime/device/kernel_info.h" +#include "mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc index 318cbc982a..fb78a150b6 100644 --- a/tests/ut/cpp/session/kernel_graph_test.cc +++ b/tests/ut/cpp/session/kernel_graph_test.cc @@ -16,10 +16,10 @@ #include "common/common_test.h" #include "ir/param_value.h" -#include "operator/ops.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" -#include "mindspore/ccsrc/device/kernel_info.h" +#include "frontend/operator/ops.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "mindspore/ccsrc/runtime/device/kernel_info.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/session/session_basic_test.cc b/tests/ut/cpp/session/session_basic_test.cc index 1a7ca68065..c438c92b52 100644 --- a/tests/ut/cpp/session/session_basic_test.cc +++ b/tests/ut/cpp/session/session_basic_test.cc @@ -15,10 +15,10 @@ */ #include "common/common_test.h" -#include "operator/ops.h" -#include "session/ascend_session.h" -#include "session/kernel_graph.h" -#include "session/anf_runtime_algorithm.h" +#include "frontend/operator/ops.h" +#include "backend/session/ascend_session.h" +#include "backend/session/kernel_graph.h" +#include "backend/session/anf_runtime_algorithm.h" #include "utils/utils.h" namespace mindspore { diff --git a/tests/ut/cpp/stub/aicpu/aicpu_stub.cc b/tests/ut/cpp/stub/aicpu/aicpu_stub.cc index 78ada6de18..5516d1fdc8 100644 --- a/tests/ut/cpp/stub/aicpu/aicpu_stub.cc +++ b/tests/ut/cpp/stub/aicpu/aicpu_stub.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
  */
-#include "kernel/kernel.h"
+#include "backend/kernel_compiler/kernel.h"

 namespace mindspore {
 namespace kernel {
diff --git a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
index 9b48adb574..234ffdaf6b 100644
--- a/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
+++ b/tests/ut/cpp/stub/ge/ge_task_launch_stub.cc
@@ -15,7 +15,7 @@
  */
 #include
 #include "framework/ge_runtime/model_runner.h"
-#include "device/ascend/tasksink/runtime_utils.h"
+#include "runtime/device/ascend/tasksink/runtime_utils.h"

 namespace ge {
 namespace model_runner {
diff --git a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc
index ba642dfe18..87ab543c7c 100755
--- a/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc
+++ b/tests/ut/cpp/stub/kernel/kernel_fusion_stub.cc
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "kernel/kernel_fusion.h"
-#include "kernel/tbe/tbe_kernel_mod.h"
+#include "backend/kernel_compiler/kernel_fusion.h"
+#include "backend/kernel_compiler/tbe/tbe_kernel_mod.h"
 #include "common/utils.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
index 43d0dd4b3f..f6f2f45092 100644
--- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
+++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
@@ -15,7 +15,7 @@
  */
 #include
 #include
-#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
+#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "utils/log_adapter.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
index 8c00e518c3..85470e2315 100755
--- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
+++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc
@@ -13,9 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "device/ascend/ascend_stream_assign.h"
-#include "device/ascend/ascend_label_assign.h"
-#include "device/kernel_adjust.h"
+#include "runtime/device/ascend/ascend_stream_assign.h"
+#include "runtime/device/ascend/ascend_label_assign.h"
+#include "runtime/device/kernel_adjust.h"

 namespace mindspore {
 namespace device {
diff --git a/tests/ut/cpp/stub/tasksink/task_sink_stub.cc b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc
index b4318488c0..0b12a3862c 100644
--- a/tests/ut/cpp/stub/tasksink/task_sink_stub.cc
+++ b/tests/ut/cpp/stub/tasksink/task_sink_stub.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
-#include "device/ascend/tasksink/task_generator.h"
+#include "runtime/device/ascend/tasksink/task_generator.h"

 namespace mindspore {
 namespace device {
diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc
index f8f48920e0..6902f7d90d 100644
--- a/tests/ut/cpp/transform/convert_test.cc
+++ b/tests/ut/cpp/transform/convert_test.cc
@@ -20,16 +20,16 @@
 #include "transform/transform_base_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "debug/draw.h"
 #include "debug/anf_ir_dump.h"
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"
 #include "common/common_test.h"

 #define private public
-#include "transform/types.h"
-#include "transform/convert.h"
+#include "transform/graph_ir/types.h"
+#include "transform/graph_ir/convert.h"
 #include "securec/include/securec.h"
 #include "utils/utils.h"

 using std::cout;
diff --git a/tests/ut/cpp/transform/graph_builder_test.cc b/tests/ut/cpp/transform/graph_builder_test.cc
index e92463e2dc..e4d72b33cb 100644
--- a/tests/ut/cpp/transform/graph_builder_test.cc
+++ b/tests/ut/cpp/transform/graph_builder_test.cc
@@ -25,8 +25,8 @@
 #endif

 #define private public
-#include "transform/graph_builder.h"
-#include "transform/df_graph_manager.h"
+#include "transform/graph_ir/graph_builder.h"
+#include "transform/graph_ir/df_graph_manager.h"

 using UT::Common;
diff --git a/tests/ut/cpp/transform/graph_manager_test.cc b/tests/ut/cpp/transform/graph_manager_test.cc
index 699f81ca4c..9e55e1725b 100644
--- a/tests/ut/cpp/transform/graph_manager_test.cc
+++ b/tests/ut/cpp/transform/graph_manager_test.cc
@@ -25,7 +25,7 @@
 #endif

 #define private public
-#include "transform/df_graph_manager.h"
+#include "transform/graph_ir/df_graph_manager.h"

 using UT::Common;
diff --git a/tests/ut/cpp/transform/graph_runner_test.cc b/tests/ut/cpp/transform/graph_runner_test.cc
index 1b87cea464..b91ec959d2 100644
--- a/tests/ut/cpp/transform/graph_runner_test.cc
+++ b/tests/ut/cpp/transform/graph_runner_test.cc
@@ -21,10 +21,10 @@
 #include "ir/tensor_py.h"
 #include "transform/transform_base_test.h"
 #include "common/py_func_graph_fetcher.h"
-#include "pipeline/static_analysis/static_analysis.h"
-#include "operator/ops.h"
-#include "transform/df_graph_manager.h"
-#include "transform/convert.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "frontend/operator/ops.h"
+#include "transform/graph_ir/df_graph_manager.h"
+#include "transform/graph_ir/convert.h"
 #include "utils/utils.h"

 #ifdef OPEN_SOURCE
@@ -34,7 +34,7 @@
 #endif

 #define private public
-#include "transform/graph_runner.h"
+#include "transform/graph_ir/graph_runner.h"

 using mindspore::tensor::TensorPy;
diff --git a/tests/ut/cpp/transform/op_adapter_test.cc b/tests/ut/cpp/transform/op_adapter_test.cc
index 254452bb42..2aa6ba37e3 100644
--- a/tests/ut/cpp/transform/op_adapter_test.cc
+++ b/tests/ut/cpp/transform/op_adapter_test.cc
@@ -19,9 +19,9 @@

 #include "common/common_test.h"

-#include "transform/op_declare.h"
+#include "transform/graph_ir/op_declare.h"

-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "./common.h"

 using std::cout;
diff --git a/tests/ut/cpp/transform/transform_base_test.h b/tests/ut/cpp/transform/transform_base_test.h
index 92147dfbbf..4886b25748 100644
--- a/tests/ut/cpp/transform/transform_base_test.h
+++ b/tests/ut/cpp/transform/transform_base_test.h
@@ -20,11 +20,11 @@
 #include
 #include
 #include
-#include "transform/util.h"
+#include "transform/graph_ir/util.h"
 #include "ir/tensor.h"
 #include "common/common_test.h"

-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "./common.h"

 #include "graph/tensor.h"
diff --git a/tests/ut/cpp/utils/any_test.cc b/tests/ut/cpp/utils/any_test.cc
index d11831d602..8a49017d95 100644
--- a/tests/ut/cpp/utils/any_test.cc
+++ b/tests/ut/cpp/utils/any_test.cc
@@ -20,7 +20,7 @@
 #include

 #include "common/common_test.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "utils/any.h"
 #include "utils/misc.h"
diff --git a/tests/ut/cpp/utils/callback_test.cc b/tests/ut/cpp/utils/callback_test.cc
index c63f68f000..0a4ffb8190 100644
--- a/tests/ut/cpp/utils/callback_test.cc
+++ b/tests/ut/cpp/utils/callback_test.cc
@@ -18,9 +18,9 @@
 #include "pybind11/pybind11.h"
 #include "utils/callbacks.h"
 #include "common/common_test.h"
-#include "pipeline/pipeline.h"
-#include "pipeline/parse/python_adapter.h"
-#include "transform/df_graph_manager.h"
+#include "pipeline/jit/pipeline.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "transform/graph_ir/df_graph_manager.h"
 #include "debug/draw.h"
 #ifdef ENABLE_GE
 #include "utils/callbacks_ge.h"
diff --git a/tests/ut/cpp/utils/graph_utils_test.cc b/tests/ut/cpp/utils/graph_utils_test.cc
index ce5a4318d3..35fa9cdc6a 100644
--- a/tests/ut/cpp/utils/graph_utils_test.cc
+++ b/tests/ut/cpp/utils/graph_utils_test.cc
@@ -24,8 +24,8 @@
 #include "ir/anf.h"
 #include "utils/graph_utils.h"
-#include "pipeline/parse/parse_base.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse_base.h"
+#include "pipeline/jit/parse/parse.h"

 namespace mindspore {
diff --git a/tests/ut/cpp/utils/ir_import_test.cc b/tests/ut/cpp/utils/ir_import_test.cc
index 5e7db98a38..374c36b4e8 100644
--- a/tests/ut/cpp/utils/ir_import_test.cc
+++ b/tests/ut/cpp/utils/ir_import_test.cc
@@ -19,10 +19,10 @@
 #include "utils/log_adapter.h"
 #include "debug/anf_ir_utils.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "ir/manager.h"
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"

 namespace mindspore {
 class TestIrImporter : public UT::Common {
diff --git a/tests/ut/cpp/utils/symbolic_test.cc b/tests/ut/cpp/utils/symbolic_test.cc
index f259b62d6b..c0abd388d5 100644
--- a/tests/ut/cpp/utils/symbolic_test.cc
+++ b/tests/ut/cpp/utils/symbolic_test.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 #include "common/common_test.h"
-#include "pipeline/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
 #include "utils/symbolic.h"

 using std::cout;
diff --git a/tests/ut/cpp/utils/validator_test.cc b/tests/ut/cpp/utils/validator_test.cc
index 8eef44bde5..93334d7664 100644
--- a/tests/ut/cpp/utils/validator_test.cc
+++ b/tests/ut/cpp/utils/validator_test.cc
@@ -18,11 +18,11 @@
 #include "common/common_test.h"
 #include "utils/log_adapter.h"
-#include "pipeline/validator.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/validator.h"
+#include "pipeline/jit/parse/parse.h"
 #include "ir/manager.h"
-#include "pipeline/static_analysis/prim.h"
-#include "operator/ops.h"
+#include "pipeline/jit/static_analysis/prim.h"
+#include "frontend/operator/ops.h"

 namespace mindspore {
 namespace validator {
diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc
index b9bc552d90..c83b1b3434 100644
--- a/tests/ut/cpp/vm/segment_runner_test.cc
+++ b/tests/ut/cpp/vm/segment_runner_test.cc
@@ -20,11 +20,11 @@
 #include "ir/manager.h"
 #include "utils/log_adapter.h"
 #include "ir/func_graph_cloner.h"
-#include "pipeline/parse/parse.h"
+#include "pipeline/jit/parse/parse.h"
 #include "utils/graph_utils.h"
-#include "pipeline/resource.h"
+#include "pipeline/jit/resource.h"
 #include "debug/draw.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "vm/segment_runner.h"
 #include "vm/transform.h"
 #include "ir/tensor.h"
diff --git a/tests/ut/cpp/vm/vm_test.cc b/tests/ut/cpp/vm/vm_test.cc
index 04633043af..9168d408c3 100644
--- a/tests/ut/cpp/vm/vm_test.cc
+++ b/tests/ut/cpp/vm/vm_test.cc
@@ -15,7 +15,7 @@
  */
 #include "vm/vm.h"
 #include "common/common_test.h"
-#include "operator/ops.h"
+#include "frontend/operator/ops.h"
 #include "vm/backend.h"

 namespace mindspore {
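Note: every hunk in this part of the patch is a mechanical include-path rewrite for the ccsrc source-tree reorganization; no test logic changes. The renames applied throughout are pre_activate/ -> backend/optimizer/, session/ -> backend/session/, kernel/ -> backend/kernel_compiler/, device/ -> runtime/device/, operator/ -> frontend/operator/, parallel/ -> frontend/parallel/, pynative/ -> pipeline/pynative/, pipeline/{parse,static_analysis,resource,validator,pipeline}/ -> pipeline/jit/..., and transform/ -> transform/graph_ir/.

Many hunks re-point headers that sit between `#define private public` / `#define protected public` and the matching `#undef` pair; only the include path between the defines changes. For readers unfamiliar with that test idiom, here is a minimal, self-contained sketch of what it does. The `Widget` class is hypothetical and stands in for a real header under test such as backend/optimizer/ascend/ir_fission/topk_split.h:

    // The preprocessor rewrites the access-specifier keywords before the
    // header under test is parsed, so the test body can reach private
    // members directly. (Redefining keywords this way is formally undefined
    // behavior, but it works on the toolchains these unit tests target.)
    #include <cassert>

    #define private public
    #define protected public
    // In the real tests, the #include of the header under test goes here.
    class Widget {
     private:  // expands to "public:" because of the define above
      int internal_state_ = 0;
    };
    #undef private
    #undef protected

    int main() {
      Widget w;
      w.internal_state_ = 42;  // reachable only thanks to the define trick
      assert(w.internal_state_ == 42);
      return 0;
    }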